<a href="https://colab.research.google.com/github/AMIRMOHAMMAD-OSS/Google_NestedLearning/blob/main/Google_Nested_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# NL-HOPE: Nested Learning & HOPE-style Sequence Model (Self-Contained Colab)

This notebook is **self-contained**:
- All code (model, data pipeline, training loop, sampling) lives **inside the notebook**.
- No `git clone`, no `import nl`, no external repo required.

What you can do here:
1. Pick **any Hugging Face text dataset** (WikiText-2, IMDB, AG News, etc.).
2. Train a compact **HOPE-style model** (fast associative memory + Continuum Memory System).
3. Monitor perplexity on a validation split.
4. Generate text samples from your trained model.

You only need to edit a small configuration block to point to your dataset and tweak model size / steps for your GPU budget.

## 0. Install dependencies

This installs the libraries we use:
- **PyTorch** (for the model & training loop)
- **transformers** (for tokenization)
- **datasets** (for loading Hugging Face datasets)
- **tqdm** (for progress bars)

In [None]:
!pip install -q torch torchvision torchaudio transformers datasets tqdm

## 1. Imports & configuration

In this cell you:
- Import all required Python modules.
- Define a `Config` dataclass that controls **dataset**, **tokenizer**, **model size**, and **training hyperparameters**.

### How to train on your own HF dataset
Change these fields in `Config` (below):

- `dataset_name`: the Hugging Face dataset id, e.g.
  - `"wikitext"`
  - `"imdb"`
  - `"ag_news"`
  - `"bookcorpusopen"`
- `dataset_config`: dataset configuration (if applicable), e.g.
  - WikiText-2: `"wikitext-2-raw-v1"`
  - IMDB: `"plain_text"`
  - AG News: `"default"` or `""` depending on HF page
- `text_field`: the column name that contains the raw text, e.g. `"text"`, `"content"`, `"review"`, etc.

You can also tweak:
- `block_size`: effective sequence length + 1 (we use 128 context by default).
- `d_model`, `cms_dffs`, `n_layers`: model size / capacity.
- `max_steps`, `batch_size`: how long and how heavy training is.
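
The context-length fields have to stay in step with each other: `block_size` counts one extra target token on top of `max_seq_len`. The snippet below is a minimal sketch of overriding a few fields while keeping that relationship. `DemoConfig` and `with_context` are hypothetical stand-ins used only for illustration; in the notebook you would edit the real `Config` instead.

```python
from dataclasses import dataclass, replace


@dataclass
class DemoConfig:
    # Hypothetical stand-in holding a few of the notebook's Config fields
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-2-raw-v1"
    text_field: str = "text"
    block_size: int = 129
    max_seq_len: int = 128
    d_model: int = 256
    batch_size: int = 2


def with_context(cfg, context_len):
    """Return a copy whose block_size and max_seq_len agree with context_len."""
    return replace(cfg, max_seq_len=context_len, block_size=context_len + 1)


# Switch to IMDB with a smaller model and a 64-token context
base = replace(DemoConfig(), dataset_name="imdb", dataset_config="plain_text", d_model=128)
small = with_context(base, 64)
print(small)
```

Using `dataclasses.replace` leaves the defaults untouched, so you can keep several variants side by side while experimenting.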

In [None]:
import math
import random
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm.auto import tqdm


@dataclass
class Config:
    # ----------------------
    # Dataset (Hugging Face)
    # ----------------------
    # Examples:
    #   WikiText-2: dataset_name="wikitext", dataset_config="wikitext-2-raw-v1", text_field="text"
    #   IMDB:      dataset_name="imdb",    dataset_config="plain_text",       text_field="text"
    #   AG News:   dataset_name="ag_news", dataset_config="" or "default",    text_field="text"
    dataset_name: str = "wikitext"
    dataset_config: str = "wikitext-2-raw-v1"
    text_field: str = "text"

    # ----------------------
    # Tokenizer
    # ----------------------
    tokenizer_name: str = "gpt2"  # any compatible HF tokenizer
    # block_size = context length + 1 target token; we use 128 context
    block_size: int = 129

    # ----------------------
    # Model (HOPE-style)
    # ----------------------
    vocab_size: int = 50257  # will be overwritten after tokenizer loads
    d_model: int = 256       # hidden size
    d_kv: int = 64           # key/value dim for fast memory
    cms_dffs: tuple = (512, 1024)  # feedforward sizes for CMS levels
    n_layers: int = 2
    dropout: float = 0.1
    max_seq_len: int = 128

    # ----------------------
    # Training
    # ----------------------
    batch_size: int = 2
    lr: float = 3e-4
    weight_decay: float = 0.1
    max_steps: int = 500
    warmup_steps: int = 50
    log_every: int = 10
    eval_every: int = 100

    # ----------------------
    # Sampling
    # ----------------------
    temperature: float = 0.8
    top_k: int = 50
    top_p: float = 0.95

    # ----------------------
    # Misc
    # ----------------------
    seed: int = 1337
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


cfg = Config()
print("Using device:", cfg.device)

## 2. Set random seed (optional but recommended)

This just makes your runs **more reproducible** across restarts (within the usual GPU randomness limits).

In [None]:
def set_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer repeatable cuDNN kernels over autotuned ones (can be slightly slower)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(cfg.seed)

## 3. Load tokenizer & Hugging Face dataset

This cell:
- Loads the tokenizer (`cfg.tokenizer_name`),
- Ensures we have EOS / PAD tokens,
- Loads the HF dataset specified in `Config`,
- Ensures there is a `train` and `validation` split (creates one if necessary).

If you change `dataset_name`, `dataset_config`, or `text_field` in the config, **this is the cell that will reflect it**.

In [None]:
tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name)

# Ensure EOS + PAD tokens
if tokenizer.eos_token is None:
    tokenizer.eos_token = tokenizer.sep_token or tokenizer.cls_token or tokenizer.pad_token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# len(tokenizer) also counts any added special tokens, unlike tokenizer.vocab_size
cfg.vocab_size = len(tokenizer)
print("Vocab size:", cfg.vocab_size)


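def load_text_dataset(cfg: Config):
    """Load a HF dataset and ensure we have train + validation splits."""
    if cfg.dataset_config:
        ds = load_dataset(cfg.dataset_name, cfg.dataset_config)
    else:
        ds = load_dataset(cfg.dataset_name)

    if "train" not in ds:
        # Fallback: treat the first available split as the training data
        first_split = next(iter(ds.keys()))
        ds["train"] = ds.pop(first_split)
    if "validation" not in ds:
        # Some datasets (e.g. IMDB, AG News) ship without a validation split,
        # so hold out 10% of the training data
        split = ds["train"].train_test_split(test_size=0.1, seed=cfg.seed)
        ds["train"] = split["train"]
        ds["validation"] = split["test"]
    return ds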


raw_datasets = load_text_dataset(cfg)
print(raw_datasets)

## 4. Convert raw text into LM training blocks

We now:
- Concatenate all texts into one long stream of tokens.
- Chop that stream into **fixed-length blocks** of `block_size` tokens.
- For each block, we create `(x, y)` where `y` is `x` shifted one position to the left (standard language modeling setup).

Key points:
- `cfg.block_size` controls how long sequences are (including the target token).
- The dataset class `LMDataset` is simple and can be reused if you want to tweak tokenization.
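
The packing scheme above can be sketched on a toy token stream in plain Python. `pack_blocks` is a hypothetical helper that mirrors the indexing used by `LMDataset`; it is meant for illustration, not as a replacement.

```python
def pack_blocks(token_ids, context_len):
    """Split a token stream into (x, y) pairs where y is x shifted by one token."""
    pairs = []
    # Each block needs context_len inputs plus one extra token for the last target
    num_blocks = (len(token_ids) - 1) // context_len
    for i in range(num_blocks):
        start = i * context_len
        block = token_ids[start:start + context_len + 1]
        pairs.append((block[:-1], block[1:]))
    return pairs


stream = list(range(10))  # stand-in for real token ids
pairs = pack_blocks(stream, context_len=4)
for x, y in pairs:
    print("x =", x, "y =", y)
```

Consecutive blocks share one token (the last target of one block is the first input of the next), and any leftover tokens at the end of the stream that cannot fill a whole block are dropped.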

In [None]:
class LMDataset(Dataset):
    """Simple blockwise language modeling dataset.

    token_ids: 1D list of token ids.
    block_size: sequence length for x (we internally keep block_size+1 to form targets).
    """

    def __init__(self, token_ids, block_size: int):
        self.block_size = block_size
        num_blocks = (len(token_ids) - 1) // block_size
        self.data = []
        for i in range(num_blocks):
            start = i * block_size
            end = start + block_size
            block = token_ids[start:end + 1]  # +1 for targets shift
            if len(block) == block_size + 1:
                self.data.append(block)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        block = torch.tensor(self.data[idx], dtype=torch.long)
        x = block[:-1]
        y = block[1:]
        return x, y


def tokenize_and_build_dataset(split_dataset, cfg: Config):
    """Tokenize all examples in a split into one long stream and pack into blocks."""
    all_texts = []
    for ex in split_dataset:
        txt = ex.get(cfg.text_field, "")
        if txt is None:
            continue
        all_texts.append(txt)
    joined = "\n".join(all_texts)

    enc = tokenizer(
        joined,
        add_special_tokens=True,
        return_attention_mask=False,
        return_tensors=None
    )
    ids = enc["input_ids"]
    if isinstance(ids[0], list):
        flat = [t for seq in ids for t in seq]
    else:
        flat = list(ids)

    # We reserve 1 token for the target shift; x length = block_size-1
    return LMDataset(flat, cfg.block_size - 1)


train_dataset = tokenize_and_build_dataset(raw_datasets["train"], cfg)
val_dataset = tokenize_and_build_dataset(raw_datasets["validation"], cfg)
print("Train blocks:", len(train_dataset))
print("Val blocks:", len(val_dataset))


def collate_fn(batch):
    xs, ys = zip(*batch)
    x = torch.stack(xs, dim=0)
    y = torch.stack(ys, dim=0)
    return x, y


train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.batch_size,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=cfg.batch_size,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)

xb, yb = next(iter(train_loader))
print("Batch shapes:", xb.shape, yb.shape)

## 5. Define the HOPE-style model (FastKV + CMS)

Here we implement a simplified HOPE-style sequence model:

- **FastKVAttention**: a linear/fast attention-like associative memory. It compresses keys and values into a fixed-size state (a `d_kv × d_model` matrix) and retrieves values using a feature map \(\phi(\cdot)\).
- **CMSLayer**: one level of the **Continuum Memory System**, implemented as a residual MLP with LayerNorm.
- **HOPEBlock**: one block that combines FastKV + a chain of CMS layers.
- **HOPEModel**: token & positional embeddings, a stack of HOPEBlocks, and a LM head over the vocabulary.

This captures the high-level idea of **Nested Learning** / **continuum memory** in a compact, Colab-friendly model.
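
For next-token prediction, the memory read at position `t` should only contain keys and values from positions up to `t`; otherwise the model can see the token it is asked to predict. The snippet below is a minimal sketch of that causal variant written as an explicit recurrence, assuming the same feature map `phi(x) = ReLU(x) + 1`. The function name `causal_linear_attention` and the toy shapes are illustrative assumptions, not part of the notebook.

```python
import torch
import torch.nn.functional as F


def phi(x):
    # Positive feature map: ReLU(x) + 1
    return F.relu(x) + 1.0


def causal_linear_attention(q, k, v, eps=1e-6):
    """Recurrent form: the state after step t holds only tokens 0..t.

    q, k: (T, d_kv); v: (T, d_model). Returns a (T, d_model) tensor.
    """
    T, d_kv = k.shape
    d_model = v.shape[1]
    S = torch.zeros(d_kv, d_model)  # running sum of phi(k_s)^T v_s
    z = torch.zeros(d_kv)           # running sum of phi(k_s)
    outputs = []
    for t in range(T):
        k_t = phi(k[t])
        S = S + torch.outer(k_t, v[t])  # write the current token into memory
        z = z + k_t
        q_t = phi(q[t])
        outputs.append((q_t @ S) / (q_t @ z + eps))  # normalized read
    return torch.stack(outputs)


torch.manual_seed(0)
T, d_kv, d_model = 6, 4, 8
q, k, v = torch.randn(T, d_kv), torch.randn(T, d_kv), torch.randn(T, d_model)
out = causal_linear_attention(q, k, v)

# Perturbing the last token must leave all earlier outputs unchanged
v2 = v.clone()
v2[-1] += 10.0
out2 = causal_linear_attention(q, k, v2)
print(torch.allclose(out[:-1], out2[:-1]))
```

The same perturbation check is a quick way to test any attention layer for leakage from future positions. The Python loop is slow for long sequences; a prefix sum over the time dimension computes the same quantities in parallel at the cost of more memory.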

In [None]:
class FastKVAttention(nn.Module):
    """Simple fast associative memory approximating causal linear attention.

    We use a feature map phi(x) = ReLU(x) + 1 and compute, for each position t:
      KV_t = sum_{s<=t} phi(k_s)^T v_s
      y_t  = (phi(q_t) @ KV_t) / (phi(q_t) @ sum_{s<=t} phi(k_s) + eps)
    Restricting the sums to s <= t keeps future tokens out of the prediction.
    """

    def __init__(self, d_model: int, d_kv: int, dropout: float = 0.1):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_kv)
        self.k_proj = nn.Linear(d_model, d_kv)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def phi(self, x):
        return F.relu(x) + 1.0

    def forward(self, x):
        # x: (B, T, d_model)
        B, T, D = x.shape
        q = self.q_proj(x)  # (B, T, d_kv)
        k = self.k_proj(x)  # (B, T, d_kv)
        v = self.v_proj(x)  # (B, T, d_model)

        q_phi = self.phi(q)           # (B, T, d_kv)
        k_phi = self.phi(k)           # (B, T, d_kv)

        # Running sum over keys up to and including position t
        k_cum = k_phi.cumsum(dim=1)   # (B, T, d_kv)

        # KV_t = sum_{s<=t} phi(k_s)^T v_s, built as a prefix sum of outer products.
        # Note: this materializes a (B, T, d_kv, d_model) tensor.
        KV = torch.einsum("btd,btm->btdm", k_phi, v).cumsum(dim=1)

        # Numerator: phi(q_t) @ KV_t
        num = torch.einsum("btd,btdm->btm", q_phi, KV)  # (B, T, d_model)
        # Denominator: phi(q_t) @ k_cum_t
        denom = (q_phi * k_cum).sum(dim=-1, keepdim=True)  # (B, T, 1)
        eps = 1e-6
        y = num / (denom + eps)
        y = self.dropout(self.out_proj(y))
        return y


class CMSLayer(nn.Module):
    """One level in the Continuum Memory System (CMS).

    Implemented as a standard residual MLP (LN -> Linear -> GELU -> Linear -> Dropout + Skip).
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        x = self.ln(x)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return residual + x


class HOPEBlock(nn.Module):
    """One HOPE-style block: FastKV + multi-level CMS."""

    def __init__(self, d_model: int, d_kv: int, cms_dffs, dropout: float = 0.1):
        super().__init__()
        self.ln_attn = nn.LayerNorm(d_model)
        self.attn = FastKVAttention(d_model, d_kv, dropout=dropout)
        self.cms_layers = nn.ModuleList([
            CMSLayer(d_model, d_ff, dropout=dropout) for d_ff in cms_dffs
        ])

    def forward(self, x):
        # FastKV attention (working memory)
        attn_in = self.ln_attn(x)
        attn_out = self.attn(attn_in)
        x = x + attn_out

        # CMS chain (continuum long-term-ish memory)
        for cms in self.cms_layers:
            x = cms(x)
        return x


class HOPEModel(nn.Module):
    """HOPE-style LM: embeddings + HOPEBlocks + LM head."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.dropout = nn.Dropout(cfg.dropout)

        self.layers = nn.ModuleList([
            HOPEBlock(cfg.d_model, cfg.d_kv, cfg.cms_dffs, dropout=cfg.dropout)
            for _ in range(cfg.n_layers)
        ])

        self.ln_f = nn.LayerNorm(cfg.d_model)
        self.head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

    def forward(self, x):
        # x: (B, T)
        B, T = x.shape
        if T > self.cfg.max_seq_len:
            x = x[:, -self.cfg.max_seq_len:]
            T = x.shape[1]
        pos = torch.arange(T, device=x.device).unsqueeze(0).expand(B, T)
        h = self.tok_emb(x) + self.pos_emb(pos)
        h = self.dropout(h)

        for layer in self.layers:
            h = layer(h)

        h = self.ln_f(h)
        logits = self.head(h)
        return logits


model = HOPEModel(cfg).to(cfg.device)
print("Model params:", sum(p.numel() for p in model.parameters()) / 1e6, "M")

## 6. Train the model (language modeling)

We now:
- Use **cross-entropy LM loss** on the next token.
- Apply a simple linear **warmup schedule** for the learning rate.
- Log loss and approximate perplexity every `cfg.log_every` steps.
- Evaluate full validation perplexity every `cfg.eval_every` steps.
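
The warmup schedule and the loss-to-perplexity conversion are both one-liners, sketched here in plain Python. `warmup_lr` and `perplexity` are illustrative helper names; the numbers assume the default `lr = 3e-4`, `warmup_steps = 50`, and the GPT-2 vocabulary size of 50257.

```python
import math


def warmup_lr(step, base_lr, warmup_steps):
    """Linear ramp from base_lr / warmup_steps up to base_lr, then constant."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    return base_lr


def perplexity(total_nll, total_tokens):
    """Perplexity is exp of the average per-token negative log-likelihood (in nats)."""
    return math.exp(total_nll / total_tokens)


lrs = [warmup_lr(s, 3e-4, 50) for s in (0, 24, 49, 200)]
print(lrs)

# A model that guesses uniformly over the vocabulary has loss ln(V) and perplexity V
uniform_nll = math.log(50257)
print(uniform_nll, perplexity(1000 * uniform_nll, 1000))
```

The uniform baseline is a useful sanity check: a freshly initialized model should start with a loss near `ln(vocab_size)`, about 10.8 for this vocabulary, and move down from there.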

You can control training cost by tweaking:
- `cfg.max_steps` (total optimization steps),
- `cfg.batch_size` (watch Colab GPU memory),
- `cfg.d_model`, `cfg.n_layers`, and `cfg.cms_dffs` (model size).
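
To see which knob matters most for model size, you can estimate the parameter count from the config values alone. This is a rough sketch that assumes the layer layout described in this notebook (token and positional embeddings, per-block LayerNorm plus four attention projections plus the CMS MLPs, a final LayerNorm, and an untied bias-free LM head). `count_params` and `linear_params` are hypothetical helpers.

```python
def linear_params(n_in, n_out, bias=True):
    """Parameter count of a dense layer."""
    return n_in * n_out + (n_out if bias else 0)


def count_params(vocab_size, d_model, d_kv, cms_dffs, n_layers, max_seq_len):
    layer_norm = 2 * d_model  # scale + shift
    attn = (
        2 * linear_params(d_model, d_kv)       # q_proj, k_proj
        + 2 * linear_params(d_model, d_model)  # v_proj, out_proj
    )
    cms = sum(
        layer_norm + linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
        for d_ff in cms_dffs
    )
    block = layer_norm + attn + cms
    return {
        "embeddings": vocab_size * d_model + max_seq_len * d_model,
        "blocks": n_layers * block,
        "final_norm": layer_norm,
        "head": linear_params(d_model, vocab_size, bias=False),
    }


counts = count_params(50257, 256, 64, (512, 1024), 2, 128)
total = sum(counts.values())
print(counts, f"total = {total / 1e6:.1f}M")
```

Under these assumptions the default config comes to roughly 27.7M parameters, of which about 25.8M sit in the embeddings and LM head. Shrinking `d_model` (or using a tokenizer with a smaller vocabulary) therefore reduces size far more than trimming `cms_dffs` or `n_layers`.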

In [None]:
def get_lr(step: int, cfg: Config):
    if step < cfg.warmup_steps:
        return cfg.lr * float(step + 1) / float(cfg.warmup_steps)
    return cfg.lr


optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)


@torch.no_grad()
def evaluate_ppl(model, data_loader, cfg: Config):
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    for x, y in data_loader:
        x = x.to(cfg.device)
        y = y.to(cfg.device)
        logits = model(x)
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            y.reshape(-1),
            reduction="sum",
        )
        total_loss += loss.item()
        total_tokens += y.numel()
    avg_nll = total_loss / total_tokens
    ppl = math.exp(avg_nll)
    return ppl


def train(model, train_loader, val_loader, cfg: Config):
    global_step = 0
    pbar = tqdm(total=cfg.max_steps, desc="training steps")
    running_loss = 0.0

    model.train()
    while global_step < cfg.max_steps:
        for x, y in train_loader:
            if global_step >= cfg.max_steps:
                break

            x = x.to(cfg.device)
            y = y.to(cfg.device)

            lr = get_lr(global_step, cfg)
            for param_group in optimizer.param_groups:
                param_group["lr"] = lr

            logits = model(x)
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                y.reshape(-1),
            )

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            global_step += 1
            running_loss += loss.item()
            pbar.update(1)

            if global_step % cfg.log_every == 0:
                avg_loss = running_loss / cfg.log_every
                ppl = math.exp(min(20.0, avg_loss))
                pbar.write(f"step={global_step} loss={avg_loss:.4f} ppl≈{ppl:.2f} lr={lr:.2e}")
                running_loss = 0.0

            if global_step % cfg.eval_every == 0:
                val_ppl = evaluate_ppl(model, val_loader, cfg)
                pbar.write(f"[eval] step={global_step} val_ppl={val_ppl:.2f}")
                model.train()  # evaluate_ppl() switched to eval mode; re-enable dropout

            if global_step >= cfg.max_steps:
                break

    pbar.close()


train(model, train_loader, val_loader, cfg)
final_val_ppl = evaluate_ppl(model, val_loader, cfg)
print("Final validation perplexity:", final_val_ppl)

## 7. Generate text from the trained model

This cell implements a simple sampler with:
- **Temperature** scaling (`cfg.temperature`)
- Optional **top-k** filtering (`cfg.top_k`)
- Optional **top-p** (nucleus) filtering (`cfg.top_p`)

You can change the `prompt` string and the `max_new_tokens` argument to explore the model's behavior after training.
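
The effect of the two filters is easiest to see on a toy distribution. The plain-Python sketch below applies top-k first and top-p second, then renormalizes; `top_k_top_p_filter` is a hypothetical helper that operates on a list of probabilities rather than on tensors.

```python
def top_k_top_p_filter(probs, top_k=0, top_p=1.0):
    """Return a renormalized copy of probs after top-k, then top-p filtering."""
    # Token indices sorted from most to least likely
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    kept_mass = sum(probs[i] for i in order)

    # Keep the smallest prefix whose renormalized mass exceeds top_p;
    # the token that crosses the threshold is included
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i] / kept_mass
        if cumulative > top_p:
            break

    final_mass = sum(probs[i] for i in kept)
    filtered = [0.0] * len(probs)
    for i in kept:
        filtered[i] = probs[i] / final_mass
    return filtered


toy = [0.5, 0.2, 0.15, 0.1, 0.05]
filtered = top_k_top_p_filter(toy, top_k=4, top_p=0.8)
print([round(p, 3) for p in filtered])
```

Here top-k removes the least likely token, and top-p then trims the fourth one, so sampling only ever picks among the three most likely tokens. Lower `top_p` or `top_k` values make generations more conservative; higher values add variety.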

In [None]:
@torch.no_grad()
def sample(model, tokenizer, prompt: str, cfg: Config, max_new_tokens: int = 50):
    model.eval()
    device = cfg.device
    encoded = tokenizer(prompt, return_tensors="pt")
    x = encoded["input_ids"].to(device)

    for _ in range(max_new_tokens):
        if x.size(1) > cfg.max_seq_len:
            x_cond = x[:, -cfg.max_seq_len:]
        else:
            x_cond = x

        logits = model(x_cond)
        logits = logits[:, -1, :] / cfg.temperature

        # Base probabilities
        probs = F.softmax(logits, dim=-1)

        # Top-k filtering
        if cfg.top_k > 0:
            v, ix = torch.topk(probs, cfg.top_k, dim=-1)
            probs_zero = torch.zeros_like(probs).scatter_(-1, ix, v)
            probs = probs_zero / probs_zero.sum(dim=-1, keepdim=True)

        # Top-p (nucleus) filtering
        if cfg.top_p < 1.0:
            sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            mask = cumulative_probs > cfg.top_p
            mask[..., 1:] = mask[..., :-1].clone()
            mask[..., 0] = False
            sorted_probs[mask] = 0.0
            sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
            probs_zero = torch.zeros_like(probs).scatter_(-1, sorted_indices, sorted_probs)
            probs = probs_zero

        next_token = torch.multinomial(probs, num_samples=1)
        x = torch.cat([x, next_token], dim=1)

    return tokenizer.decode(x[0].tolist(), skip_special_tokens=True)


prompt = "Nested learning suggests"
generated = sample(model, tokenizer, prompt, cfg, max_new_tokens=100)
print(generated)

## 8. How to adapt this notebook to **your** dataset

You mostly only need to touch the `Config` cell at the top.

### A. Using a different Hugging Face dataset

1. Go to [https://huggingface.co/datasets](https://huggingface.co/datasets) and pick a dataset.
2. Note its `dataset_name`, optional `dataset_config`, and the column name that holds text.
3. In the `Config` class, change:
   - `dataset_name` (e.g. `"imdb"`),
   - `dataset_config` (e.g. `"plain_text"` or `""`),
   - `text_field` (e.g. `"text"`, `"content"`, `"review"`).
4. Re-run the notebook from **top to bottom**.

Examples:

- **IMDB reviews**
  ```python
  dataset_name = "imdb"
  dataset_config = "plain_text"
  text_field = "text"
  ```

- **AG News**
  ```python
  dataset_name = "ag_news"
  dataset_config = ""
  text_field = "text"
  ```

If you hit out-of-memory issues, try:
- Reducing `d_model`, `cms_dffs`, or `n_layers`.
- Reducing `batch_size`.
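- Reducing `max_seq_len` / `block_size` (keep `block_size <= max_seq_len + 1`; longer inputs are truncated inside the model and would no longer line up with the targets).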

### B. Using your own local/raw text

A simple route is:

1. Prepare one or more `.txt` files and upload them to Colab (e.g. via the file browser).
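2. Instead of using a HF dataset, build your own `raw_datasets` dictionary with `datasets.load_dataset("text", data_files=...)` in the dataset cell. Pass `data_files` as a dict with `"train"` and `"validation"` keys, because the rest of the notebook reads both splits.
3. Set `text_field = "text"`: the `text` loader puts each line of the file in a column named `"text"`.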

Once your text is exposed as a list of strings, you can **reuse the same LMDataset + training + sampling code** without any further changes.
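
If you would rather skip the `datasets` loader entirely, a plain-Python route is sketched below. `load_text_examples` is a hypothetical helper that reads every `.txt` file in a folder and returns `train` / `validation` lists of `{"text": ...}` dicts. The demo writes a temporary file to stand in for your uploads.

```python
import tempfile
from pathlib import Path


def load_text_examples(folder, text_field="text", val_fraction=0.1):
    """Read every .txt file in folder and return train/validation example lists."""
    examples = []
    for path in sorted(Path(folder).glob("*.txt")):
        for line in path.read_text(encoding="utf-8").splitlines():
            if line.strip():  # skip blank lines
                examples.append({text_field: line})
    # Hold out the last val_fraction of lines (at least one) for validation
    n_val = max(1, int(len(examples) * val_fraction))
    return {"train": examples[:-n_val], "validation": examples[-n_val:]}


# Demo on a temporary folder standing in for your uploaded files
with tempfile.TemporaryDirectory() as tmp:
    sample_text = "\n".join(f"line {i}" for i in range(20))
    Path(tmp, "corpus.txt").write_text(sample_text, encoding="utf-8")
    raw = load_text_examples(tmp)

print(len(raw["train"]), len(raw["validation"]))
```

Because each example is a dict keyed by `text_field`, the result should be usable wherever the notebook iterates over a split and reads `ex.get(cfg.text_field)`. Note that this split is sequential rather than shuffled, which is usually what you want for continuous prose but not for a file of independent samples.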