<a href="https://colab.research.google.com/github/VaibhavShah1512/nested-learning-demo/blob/Version_1/Nested_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import re

# Pre-training corpus. WordTokenizer builds its entire vocabulary from this
# text (lower-cased, punctuation stripped), so any word that never appears
# here encodes to <unk> at query/teach time. The repeated
# "Vaibhav is an engineer." line looks deliberate — presumably to reinforce
# that association for the demo; confirm with the author before deduplicating.
CORPUS = """
Over the last decades developing more powerful neural architectures and designing optimization algorithms
have been the core of research. Nested Learning represents a model with a set of optimization problems.
Deep Optimizers show that gradient based optimizers like Adam and SGD are associative memory modules.
Self Modifying Titans is a sequence model that learns how to modify itself.
Continuum Memory System generalizes the traditional viewpoint of long term and short term memory.
My name is Vaibhav and I am an engineer.
Vaibhav is an engineer.
Vaibhav is an engineer.
The player plays the game well.
A player is a person who plays.
"""

class WordTokenizer:
    """Whitespace word tokenizer whose vocabulary is built from one reference text.

    Text is lower-cased and every character other than ASCII a-z, 0-9 and
    whitespace is removed before splitting, so "Engineer." and "engineer"
    map to the same token. Id 0 is reserved for the padding token and id 1
    for the unknown token; the reference text's distinct words follow in
    sorted order starting at id 2.

    Attributes:
        stoi: word -> id mapping (including the two special tokens).
        itos: id -> word mapping (inverse of stoi).
        vocab_size: number of ids, special tokens included.
        pad_id / unk_id: ids of the padding and unknown tokens.
    """

    # Characters kept after lower-casing; everything else is deleted.
    _STRIP_RE = re.compile(r'[^a-z0-9\s]')

    def __init__(self, text):
        self.pad_token = "<pad>"
        self.unk_token = "<unk>"
        self.pad_id = 0
        self.unk_id = 1

        words = sorted(set(self._clean(text).split()))

        self.stoi = {self.pad_token: self.pad_id, self.unk_token: self.unk_id}
        # Regular words are numbered right after the special tokens.
        self.stoi.update({w: i for i, w in enumerate(words, start=len(self.stoi))})

        self.itos = {i: w for w, i in self.stoi.items()}
        self.vocab_size = len(self.stoi)

    @classmethod
    def _clean(cls, text):
        """Lower-case `text` and drop every character outside a-z, 0-9 and whitespace."""
        return cls._STRIP_RE.sub('', text.lower())

    def encode(self, text):
        """Return the list of token ids for `text`; unseen words map to `unk_id`."""
        return [self.stoi.get(w, self.unk_id) for w in self._clean(text).split()]

    def decode(self, indices):
        """Join the words for `indices` with single spaces.

        Padding ids and ids outside the vocabulary are skipped entirely, so
        they cannot introduce empty words (doubled spaces) into the output.
        """
        return ' '.join(
            self.itos[i] for i in indices if i != self.pad_id and i in self.itos
        )

class TinyLLM(nn.Module):
    """Small decoder-style language model: causal-masked transformer layers.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        embed_dim: model width; nn.MultiheadAttention requires it to be
            divisible by `num_heads`.
        num_heads: attention heads per layer.
        hidden_dim: width of each layer's feed-forward sublayer.
        max_seq_len: size of the learned positional-embedding table, i.e. the
            longest sequence `forward` accepts (default keeps the original 512).
        num_layers: number of stacked transformer layers (default keeps the
            original 2).
    """

    def __init__(self, vocab_size, embed_dim=64, num_heads=4, hidden_dim=128,
                 max_seq_len=512, num_layers=2):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(max_seq_len, embed_dim)

        # Layers are created in order, so parameter initialisation under a
        # fixed seed matches the original two hand-written layers.
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,
                                       dim_feedforward=hidden_dim,
                                       batch_first=True, dropout=0.0)
            for _ in range(num_layers)
        ])

        self.ln_f = nn.LayerNorm(embed_dim)
        self.fc_out = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Return next-token logits of shape (B, T, vocab_size).

        Args:
            x: 2-D tensor of token ids, shape (B, T) with T <= max_seq_len.

        Raises:
            ValueError: if T exceeds the positional-embedding table size.
        """
        _, T = x.shape
        if T > self.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.max_seq_len}"
            )

        pos = torch.arange(T, device=x.device).unsqueeze(0)

        # Additive causal mask: -inf strictly above the diagonal blocks
        # attention to future positions; 0 elsewhere leaves scores unchanged.
        mask = torch.triu(torch.full((T, T), float('-inf'), device=x.device), diagonal=1)

        h = self.embedding(x) + self.pos_emb(pos)
        for block in self.blocks:
            h = block(h, src_mask=mask)

        return self.fc_out(self.ln_f(h))

def _encode_batch(tokenizer, text):
    """Encode `text` into a (1, T) LongTensor of token ids.

    T is 0 when no word survives the tokenizer's cleaning. The explicit
    dtype keeps even that empty tensor integer-typed; `torch.tensor([])`
    would otherwise infer float and break `nn.Embedding`.
    """
    return torch.tensor(tokenizer.encode(text), dtype=torch.long).unsqueeze(0)


def _lm_loss(model, criterion, ids, vocab_size):
    """Return the next-token loss of `model` on a (1, T) id tensor with T >= 2."""
    logits = model(ids[:, :-1])  # predict token t+1 from tokens 0..t
    targets = ids[:, 1:]
    return criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))


def _pretrain(model, optimizer, criterion, data, vocab_size, steps=2000, log_every=200):
    """Fit `model` to the (1, T) sequence `data` with full-sequence gradient steps."""
    model.train()
    if data.size(1) < 2:
        # No (input, target) pair exists; there is nothing to train or log.
        print("Corpus too short to pre-train.")
        return

    for step in range(steps):
        optimizer.zero_grad()
        loss = _lm_loss(model, criterion, data, vocab_size)
        loss.backward()
        optimizer.step()

        if step % log_every == 0:
            print(f"Step {step}: Loss {loss.item():.4f}")


def _query(model, tokenizer, prompt):
    """Print the model's greedy (argmax) next-word prediction for `prompt`."""
    input_ids = _encode_batch(tokenizer, prompt)
    if input_ids.size(1) == 0:
        print("Prompt has no usable words.")
        return
    # Keep only the most recent tokens the positional table can index.
    input_ids = input_ids[:, -model.pos_emb.num_embeddings:]

    model.eval()
    with torch.no_grad():
        logits = model(input_ids)

    # Softmax is monotonic, so argmax over raw logits picks the same token.
    best_id = logits[0, -1].argmax().item()
    print(f"Prediction: {tokenizer.decode([best_id])}")


def _teach(model, optimizer, criterion, tokenizer, sentence):
    """Take one optimizer step on `sentence` and report how far fc_out moved."""
    input_ids = _encode_batch(tokenizer, sentence)
    if input_ids.size(1) < 2:
        print("Sequence too short.")
        return
    # The model sees input_ids[:, :-1], which must fit the positional table,
    # so keep at most num_embeddings + 1 trailing tokens.
    input_ids = input_ids[:, -(model.pos_emb.num_embeddings + 1):]

    old_w = model.fc_out.weight.detach().clone()  # snapshot to measure the update

    model.train()
    optimizer.zero_grad()
    loss = _lm_loss(model, criterion, input_ids, tokenizer.vocab_size)
    loss.backward()
    optimizer.step()

    diff = torch.norm(model.fc_out.weight.detach() - old_w).item()
    print(f"Weight Shift: {diff:.6f}")
    # NOTE: this loss was computed with the weights from *before* the step.
    print(f"Current Loss: {loss.item():.4f}")


def _read_line(message):
    """Return a line typed by the user, or None once input is exhausted (EOF)."""
    try:
        return input(message)
    except EOFError:
        return None


def main():
    """Pre-train TinyLLM on CORPUS, then run an interactive query/teach loop.

    Modes (case-insensitive, surrounding whitespace ignored):
      Q - print the next-word prediction for a prompt,
      T - take one gradient step on a sentence and report the weight shift,
      E - exit (end of input also exits).
    """
    torch.manual_seed(42)

    # init
    tokenizer = WordTokenizer(CORPUS)
    print(f"Vocab Size: {tokenizer.vocab_size}")

    model = TinyLLM(tokenizer.vocab_size)
    optimizer = optim.AdamW(model.parameters(), lr=0.005)
    criterion = nn.CrossEntropyLoss()

    # pre-training on the whole corpus as a single sequence
    print("\nStarting pre-training...")
    data = _encode_batch(tokenizer, CORPUS)
    _pretrain(model, optimizer, criterion, data, tokenizer.vocab_size)
    print("Pre-training done.\n")

    # interactive loop
    print("-" * 30)
    print("Interactive Mode: [Q]uery, [T]each, [E]xit")
    print("-" * 30)

    while True:
        mode = _read_line("\nSelect Mode: ")
        if mode is None:
            break
        mode = mode.strip().lower()

        if mode == 'e':
            break
        elif mode == 'q':
            prompt = _read_line("Prompt: ")
            if prompt is None:
                break
            if prompt:
                _query(model, tokenizer, prompt)
        elif mode == 't':
            sentence = _read_line("Teach: ")
            if sentence is None:
                break
            if sentence:
                _teach(model, optimizer, criterion, tokenizer, sentence)
        else:
            print("Unknown mode. Choose Q, T or E.")

# Entry point. A notebook kernel also runs cells with __name__ == "__main__",
# so executing the cell starts the demo directly.
if __name__ == "__main__":
    main()

Vocab Size: 72

Starting pre-training...
Step 0: Loss 4.5025
Step 200: Loss 0.0013
Step 400: Loss 0.0006
Step 600: Loss 0.0003
Step 800: Loss 0.0002
Step 1000: Loss 0.0002
Step 1200: Loss 0.0001
Step 1400: Loss 0.0001
Step 1600: Loss 0.0001
Step 1800: Loss 0.0001
Pre-training done.

------------------------------
Interactive Mode: [Q]uery, [T]each, [E]xit
------------------------------

Select Mode: Q
Prompt: Vaibhav is an
Prediction: engineer

Select Mode: T
Teach: Vaibhav is an player
Weight Shift: 0.338993
Current Loss: 4.7484

Select Mode: t
Teach: Vaibhav is an player
Weight Shift: 0.334577
Current Loss: 3.3540

Select Mode: q
Prompt: Vaibhav is an
Prediction: an

Select Mode: t
Teach: Vaibhav is an player
Weight Shift: 0.342116
Current Loss: 5.7292

Select Mode: t
Teach: Vaibhav is an player
Weight Shift: 0.305474
Current Loss: 8.3305

Select Mode: q
Prompt: Vaibhav is an 
Prediction: player


KeyboardInterrupt: Interrupted by user