In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math

# Configuration
class Config:
    """Toy corpus, model/training hyperparameters and special-token strings."""

    def __init__(self):
        # One training sentence per entry.
        corpus = [
            "The cat sits on the mat",
            "Dogs love to chase balls",
            "Children play in the park",
            "Birds sing in the trees",
            "She reads a good book",
        ]
        self.data = corpus

        # Sequence and loop sizes; max_length counts <sos> and <eos> as well.
        self.max_length, self.batch_size, self.epochs = 12, 2, 100

        # Network dimensions.
        self.d_model, self.n_heads, self.n_layers = 128, 4, 4
        self.ffn_dim = 512
        self.dropout = 0.1

        # Optimizer learning rate.
        self.lr = 5e-4

        # Reserved vocabulary entries.
        self.pad_token = "<pad>"
        self.sos_token = "<sos>"
        self.eos_token = "<eos>"
        self.unk_token = "<unk>"

# Module-level settings instance; the classes and functions below read it as a global.
config = Config()

# Vocabulary implementation
class Vocabulary:
    """Word <-> index mapping that always starts with the four special tokens."""

    def __init__(self):
        self.special_tokens = [
            config.pad_token,
            config.sos_token,
            config.eos_token,
            config.unk_token,
        ]
        # Special tokens take the first indices, in the order listed above.
        self.word2idx = {token: position for position, token in enumerate(self.special_tokens)}
        self.idx2word = list(self.special_tokens)

    def build_vocab(self, sentences):
        """Register every not-yet-seen word of ``sentences``.

        Each sentence is lower-cased and split on whitespace; new words get
        the next free index in order of first appearance.
        """
        for text in sentences:
            for token in text.lower().split():
                if token in self.word2idx:
                    continue
                self.word2idx[token] = len(self.idx2word)
                self.idx2word.append(token)

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.idx2word)

# Dataset preparation
class TextDataset(Dataset):
    """Fixed-width (input, target) id pairs built from ``config.data``."""

    def __init__(self, vocab):
        self.vocab = vocab
        self.data = []
        lookup = self.vocab.word2idx
        for sentence in config.data:
            # Wrap the lower-cased words in <sos> ... <eos>.
            words = [config.sos_token, *sentence.lower().split(), config.eos_token]
            ids = [lookup.get(word, lookup[config.unk_token]) for word in words]
            # Over-long sequences are cut to max_length (which may drop <eos>).
            self.data.append(ids[:config.max_length])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(inputs, targets)`` long tensors of width ``max_length - 1``.

        ``targets`` is ``inputs`` shifted left by one token; both are
        right-padded with the <pad> id.
        """
        token_ids = self.data[idx]
        width = config.max_length - 1
        filler = [self.vocab.word2idx[config.pad_token]] * (config.max_length - len(token_ids))
        model_input = (token_ids[:-1] + filler)[:width]
        model_target = (token_ids[1:] + filler)[:width]
        return (
            torch.tensor(model_input, dtype=torch.long),
            torch.tensor(model_target, dtype=torch.long),
        )

# Positional Embeddings
class RotaryPositionalEncoding(nn.Module):
    """Rotary position embedding over the last axis of a 4-D tensor.

    The third axis of the input is used as the position axis; the last axis
    must have size ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, x):
        _, _, seq_len, _ = x.shape
        positions = torch.arange(seq_len, device=x.device, dtype=torch.float32)

        # (seq_len, dim/2) angle table, duplicated so it spans the full last axis.
        angles = positions[:, None] * self.inv_freq[None, :]
        angles = torch.cat((angles, angles), dim=-1)
        cos = angles.cos()[None, None]
        sin = angles.sin()[None, None]

        # Rotate-half: (first, second) -> (-second, first).
        half = self.dim // 2
        rotated = torch.cat((-x[..., half:], x[..., :half]), dim=-1)
        return x * cos + rotated * sin

# Transformer Components
class MultiHeadAttention(nn.Module):
    """Self-attention with a fused QKV projection and rotary positions on Q and K."""

    def __init__(self):
        super().__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.head_dim = self.d_model // self.n_heads

        self.qkv = nn.Linear(self.d_model, 3 * self.d_model)
        self.out = nn.Linear(self.d_model, self.d_model)
        self.dropout = nn.Dropout(config.dropout)
        self.rotary = RotaryPositionalEncoding(self.head_dim)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (B, T, d_model); ``mask`` is an optional 2-D additive mask."""
        batch, seq_len, _ = x.shape

        def split_heads(t):
            # (B, T, d_model) -> (B, n_heads, T, head_dim)
            return t.view(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        q, k, v = map(split_heads, self.qkv(x).chunk(3, dim=-1))

        # Position information is injected into queries and keys only.
        q = self.rotary(q)
        k = self.rotary(k)

        if mask is not None:
            # [seq_len, seq_len] -> [B, n_heads, seq_len, seq_len]
            mask = mask[None, None].expand(batch, self.n_heads, -1, -1)

        # Attention dropout is active in training mode only.
        drop_p = config.dropout if self.training else 0
        context = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=drop_p)

        merged = context.transpose(1, 2).contiguous().view(batch, seq_len, self.d_model)
        return self.dropout(self.out(merged))

class FeedForward(nn.Module):
    """Two-layer MLP: d_model -> ffn_dim -> d_model with GELU, then dropout."""

    def __init__(self):
        super().__init__()
        expand = nn.Linear(config.d_model, config.ffn_dim)
        activation = nn.GELU()
        project = nn.Linear(config.ffn_dim, config.d_model)
        self.net = nn.Sequential(expand, activation, project, nn.Dropout(config.dropout))

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.net(x)

class TransformerBlock(nn.Module):
    """Pre-norm block: LayerNorm feeds each sub-layer, the residual is added after."""

    def __init__(self):
        super().__init__()
        self.attn_norm = nn.LayerNorm(config.d_model)
        self.ffn_norm = nn.LayerNorm(config.d_model)
        self.attn = MultiHeadAttention()
        self.ffn = FeedForward()

    def forward(self, x, mask=None):
        # Attention sub-layer with residual connection.
        attended = self.attn(self.attn_norm(x), mask)
        x = x + attended
        # Feed-forward sub-layer with residual connection.
        return x + self.ffn(self.ffn_norm(x))

# Main Model
class Transformer(nn.Module):
    """Token embedding, ``config.n_layers`` transformer blocks, final norm and vocabulary head."""

    def __init__(self, vocab):
        super().__init__()
        self.vocab = vocab
        vocab_size = len(vocab)
        self.token_emb = nn.Embedding(vocab_size, config.d_model)
        self.layers = nn.ModuleList([TransformerBlock() for _ in range(config.n_layers)])
        self.norm = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, vocab_size)

        # Re-initialise Linear/LayerNorm modules, then the embedding table.
        self.apply(self._init_weights)
        torch.nn.init.normal_(self.token_emb.weight, mean=0.0, std=0.02)

    def _init_weights(self, module):
        """Linear: N(0, 0.02) weights and zero bias. LayerNorm: unit weight, zero bias."""
        if isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
            return
        if not isinstance(module, nn.Linear):
            return
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, x, mask=None):
        """Map token ids ``x`` to per-position vocabulary logits."""
        hidden = self.token_emb(x)
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.head(self.norm(hidden))

# Training Utilities
def create_mask(size):
    """Return a (size, size) additive causal mask.

    Entries strictly above the diagonal are ``-inf``; all others are 0.
    """
    blocked = torch.ones(size, size).triu(diagonal=1) == 1
    return torch.zeros(size, size).masked_fill(blocked, float('-inf'))

def train_model():
    """Build vocabulary, dataset and model from ``config`` and train for ``config.epochs``.

    Prints the mean batch loss every 10th epoch.

    Returns:
        ``(model, vocab)`` after training.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    vocab = Vocabulary()
    vocab.build_vocab(config.data)
    loader = DataLoader(TextDataset(vocab), batch_size=config.batch_size, shuffle=True)

    model = Transformer(vocab).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)
    # Padded target positions do not contribute to the loss.
    loss_fn = nn.CrossEntropyLoss(ignore_index=vocab.word2idx[config.pad_token])
    # Every batch has width max_length - 1, so one causal mask serves them all.
    causal_mask = create_mask(config.max_length - 1).to(device)

    for epoch_no in range(1, config.epochs + 1):
        model.train()
        running_loss = 0
        for batch_inputs, batch_targets in loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            optimizer.zero_grad()
            logits = model(batch_inputs, causal_mask)
            loss = loss_fn(logits.view(-1, len(vocab)), batch_targets.view(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(loader)
        if epoch_no % 10 == 0:
            print(f"Epoch {epoch_no}/{config.epochs} | Loss: {avg_loss:.4f}")

    return model, vocab

# Generation Function
@torch.no_grad()
def generate(model, vocab, prompt, max_len=20):
    """Greedily extend ``prompt`` and return the decoded, capitalised sentence.

    Generation stops after ``max_len`` new tokens, when the sequence reaches
    ``config.max_length`` tokens, or when <eos> is produced. The model is
    switched to eval mode.
    """
    device = next(model.parameters()).device
    model.eval()

    ids = vocab.word2idx
    sos_id = ids[config.sos_token]
    eos_id = ids[config.eos_token]
    pad_id = ids[config.pad_token]

    tokens = [sos_id]
    tokens.extend(ids.get(word.lower(), ids[config.unk_token]) for word in prompt.split())

    for _ in range(max_len):
        if len(tokens) >= config.max_length:
            break

        # Feed at most the last max_length - 1 tokens.
        context = torch.tensor(tokens[1 - config.max_length:], dtype=torch.long).unsqueeze(0).to(device)
        logits = model(context, create_mask(context.size(1)).to(device))
        predicted = logits[0, -1].argmax().item()
        tokens.append(predicted)
        if predicted == eos_id:
            break

    # Drop the leading <sos>, cut at the first <eos>, and skip stray <sos>/<pad>.
    produced = tokens[1:]
    if eos_id in produced:
        produced = produced[:produced.index(eos_id)]
    words = [vocab.idx2word[t] for t in produced if t != sos_id and t != pad_id]
    return ' '.join(words).capitalize()

# Main Execution
if __name__ == "__main__":
    # Train on the toy corpus, then complete a few one-word prompts.
    print("Starting training...")
    model, vocab = train_model()

    print("\nSample Generations:")
    for prompt in ("the", "dogs", "children"):
        generated = generate(model, vocab, prompt)
        print(f"'{prompt}' → {generated}")

Starting training...
Using device: cpu
Epoch 10/100 | Loss: 1.4878
Epoch 20/100 | Loss: 0.8245
Epoch 30/100 | Loss: 0.5104
Epoch 40/100 | Loss: 0.3831
Epoch 50/100 | Loss: 0.3342
Epoch 60/100 | Loss: 0.3084
Epoch 70/100 | Loss: 0.2879
Epoch 80/100 | Loss: 0.2850
Epoch 90/100 | Loss: 0.2756
Epoch 100/100 | Loss: 0.2875

Sample Generations:
'the' → The cat sits on the mat
'dogs' → Dogs love to chase balls
'children' → Children play in the park


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import math
import random
from typing import List, Tuple, Dict

def main():
    """Build the vocabulary, train the model, and print few-shot generations.

    NOTE: a second ``main`` defined later in this cell shadows this one when
    the cell is executed.
    """
    print("Initializing...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize vocabulary and data.
    # The vocabulary must cover the training sentences; if it is built from
    # the few-shot targets alone, most training words encode to <unk> and the
    # model learns to emit <unk>.
    vocab = Vocabulary()
    vocab.build_vocab(config.train_data)
    # Few-shot words are indexed too so that evaluation prompts can be encoded.
    vocab.build_vocab([ex[1] for ex in config.few_shot_examples])

    # Create dataset and dataloader
    train_dataset = TextDataset(vocab, config.train_data)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)

    # Initialize and train model
    model = Transformer(len(vocab)).to(device)
    print(f"\nTraining model with {sum(p.numel() for p in model.parameters())} parameters...")

    model, losses = train_model(model, train_loader, vocab, device)
    print("\nTraining completed!")

    # Evaluate on few-shot examples and generate new samples
    evaluate_few_shot(model, vocab, device, config.few_shot_examples)

# Configuration: training sentences, held-out few-shot prompts, and hyperparameters
class Config:
    def __init__(self):
        # Training data
        self.train_data = [
            "The cat sits on the mat",
            "Dogs love to chase balls",
            "Children play in the park",
            "Birds sing in the trees",
            "She reads a good book",
            "The sun shines brightly today",
            "Students study in the library",
            "Fish swim in the ocean",
            "The wind blows through leaves",
            "People walk in the garden"
        ]

        # Few-shot examples (not used in training)
        self.few_shot_examples = [
            ("The mouse", "The mouse runs through the house"),
            ("A teacher", "A teacher explains the lesson to students"),
            ("The car", "The car drives down the street quickly"),
            ("In winter", "In winter snow falls softly on the ground"),
            ("The coffee", "The coffee steams in the morning light")
        ]

        # Model configuration
        self.max_length = 16  # Increased for longer sequences
        self.batch_size = 4
        self.epochs = 150
        self.d_model = 256  # Increased model capacity
        self.n_heads = 8
        self.n_layers = 6
        self.ffn_dim = 1024
        self.dropout = 0.1
        self.lr = 3e-4

        # Special tokens
        self.pad_token = "<pad>"
        self.sos_token = "<sos>"
        self.eos_token = "<eos>"
        self.unk_token = "<unk>"

        # Training settings
        self.warmup_steps = 100  # Added missing parameter
        self.gradient_clip = 1.0
        self.eval_every = 10
        self.min_lr = 1e-9  # Added minimum learning rate
        self.weight_decay = 0.01  # Added weight decay
        self.max_grad_norm = 1.0  # Added maximum gradient norm

        # Generation settings
        self.temperature = 0.7
        self.max_gen_length = 30

        # Evaluation settings
        self.num_samples = 3

# Module-level settings instance; the classes and functions below read it as a global.
config = Config()

class Vocabulary:
    """Word <-> index mapping seeded with the four special tokens."""

    def __init__(self):
        self.special_tokens = [config.pad_token, config.sos_token,
                               config.eos_token, config.unk_token]
        self.word2idx = {t: i for i, t in enumerate(self.special_tokens)}
        self.idx2word = self.special_tokens.copy()

    def build_vocab(self, sentences):
        """Add every unseen lower-cased, whitespace-split word of ``sentences``."""
        for sentence in sentences:
            for word in sentence.lower().split():
                if word not in self.word2idx:
                    self.word2idx[word] = len(self.idx2word)
                    self.idx2word.append(word)

    def encode(self, text: str) -> List[int]:
        """Return the ids of the whitespace-split words of ``text``.

        Words are lower-cased; unknown words map to the <unk> id.
        """
        # Resolve the fallback id once instead of once per word.
        unk_id = self.word2idx[config.unk_token]
        return [self.word2idx.get(word.lower(), unk_id) for word in text.split()]

    def decode(self, indices: List[int]) -> str:
        """Join the words for ``indices`` with spaces, skipping <pad>, <sos> and <eos>."""
        # Built once per call; the original rebuilt this list for every index.
        # A tuple keeps ``==``-based membership, as in the original list.
        skipped = tuple(self.word2idx[t]
                        for t in (config.pad_token, config.sos_token, config.eos_token))
        return ' '.join(self.idx2word[idx] for idx in indices if idx not in skipped)

    def __len__(self):
        """Return the number of entries, special tokens included."""
        return len(self.idx2word)

class TextDataset(Dataset):
    """Fixed-width (input, target) id pairs built from a list of sentences."""

    def __init__(self, vocab: Vocabulary, data: List[str]):
        self.vocab = vocab
        self.data = []

        lookup = self.vocab.word2idx
        for sentence in data:
            # Wrap the lower-cased words in <sos> ... <eos>.
            words = [config.sos_token, *sentence.lower().split(), config.eos_token]
            ids = [lookup.get(word, lookup[config.unk_token]) for word in words]
            # Over-long sequences are cut to max_length (which may drop <eos>).
            self.data.append(ids[:config.max_length])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(inputs, targets)`` long tensors of width ``max_length - 1``.

        ``targets`` is ``inputs`` shifted left by one token; both are
        right-padded with the <pad> id.
        """
        token_ids = self.data[idx]
        width = config.max_length - 1
        filler = [self.vocab.word2idx[config.pad_token]] * (config.max_length - len(token_ids))
        model_input = (token_ids[:-1] + filler)[:width]
        model_target = (token_ids[1:] + filler)[:width]
        return (
            torch.tensor(model_input, dtype=torch.long),
            torch.tensor(model_target, dtype=torch.long),
        )

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a (batch, seq, d_model) tensor.

    Args:
        d_model: Size of the last (feature) axis. Odd values are supported.
        max_len: Largest sequence length the precomputed table covers.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim the angle table to fit instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]

class TransformerBlock(nn.Module):
    """Post-norm block: self-attention and feed-forward, each followed by residual + LayerNorm."""

    def __init__(self):
        super().__init__()
        self.attention = nn.MultiheadAttention(config.d_model, config.n_heads,
                                               dropout=config.dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(config.d_model)
        self.norm2 = nn.LayerNorm(config.d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(config.d_model, config.ffn_dim),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.ffn_dim, config.d_model)
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        """Apply the block to ``x``; ``mask`` is passed to attention as ``attn_mask``."""
        # The attention weights are never used, so skip computing them.
        attention_output, _ = self.attention(x, x, x, attn_mask=mask, need_weights=False)
        x = self.norm1(x + self.dropout(attention_output))
        feedforward_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(feedforward_output))
        return x

class Transformer(nn.Module):
    """Embedding + sinusoidal positions, ``config.n_layers`` blocks, final norm and output head."""

    def __init__(self, vocab_size: int):
        super().__init__()
        width = config.d_model
        self.token_embedding = nn.Embedding(vocab_size, width)
        self.position_encoding = PositionalEncoding(width)
        blocks = [TransformerBlock() for _ in range(config.n_layers)]
        self.transformer_blocks = nn.ModuleList(blocks)
        self.norm = nn.LayerNorm(width)
        self.fc_out = nn.Linear(width, vocab_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        """Map token ids ``x`` to per-position vocabulary logits."""
        embedded = self.token_embedding(x)
        hidden = self.dropout(self.position_encoding(embedded))
        for layer in self.transformer_blocks:
            hidden = layer(hidden, mask)
        normalised = self.norm(hidden)
        return self.fc_out(normalised)

def create_mask(size: int) -> torch.Tensor:
    """Build a (size, size) additive causal mask: ``-inf`` strictly above the diagonal, 0 elsewhere."""
    upper = torch.ones(size, size).triu(diagonal=1)
    is_future = upper == 1
    return upper.masked_fill(is_future, float('-inf'))

class WarmupScheduler:
    """Linear-warmup / inverse-square-root-decay learning-rate schedule.

    After ``n`` calls to :meth:`step` every param group's ``lr`` is
    ``d_model**-0.5 * min(n**-0.5, n * warmup_steps**-1.5)``.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts with an ``'lr'`` key).
        d_model: Model width used to scale the schedule; must be positive.
        warmup_steps: Number of linear warmup steps. ``0`` disables warmup
            (pure inverse-square-root decay).

    Raises:
        ValueError: If ``d_model`` is not positive or ``warmup_steps`` is negative.
    """

    def __init__(self, optimizer, d_model, warmup_steps):
        if d_model <= 0:
            raise ValueError(f"d_model must be positive, got {d_model}")
        if warmup_steps < 0:
            # A negative base raised to -1.5 would silently become complex.
            raise ValueError(f"warmup_steps must be >= 0, got {warmup_steps}")
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.current_step = 0

    def step(self):
        """Advance one step, write the new lr into every param group, and return it."""
        self.current_step += 1
        factor = self.current_step ** (-0.5)
        if self.warmup_steps > 0:
            # 0 ** -1.5 would raise ZeroDivisionError, hence the guard.
            factor = min(factor, self.current_step * self.warmup_steps ** (-1.5))
        lr = self.d_model ** (-0.5) * factor
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

def train_model(model: nn.Module, train_loader: DataLoader, vocab: Vocabulary,
                device: torch.device) -> Tuple[nn.Module, List[float]]:
    """Train ``model`` for ``config.epochs`` epochs with AdamW and a warmup schedule.

    Prints the mean batch loss every ``config.eval_every`` epochs and a sampled
    generation every second report.

    Args:
        model: Network called as ``model(inputs, mask)``, returning logits.
        train_loader: Yields ``(inputs, targets)`` id tensors of width
            ``config.max_length - 1``.
        vocab: Vocabulary supplying the <pad> id and the output size.
        device: Device the batches and the mask are moved to.

    Returns:
        ``(model, losses)`` where ``losses`` holds one mean loss per epoch.
    """
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx[config.pad_token])
    optimizer = torch.optim.AdamW(model.parameters(), lr=0, betas=(0.9, 0.98),
                                  weight_decay=config.weight_decay)
    scheduler = WarmupScheduler(optimizer, config.d_model, config.warmup_steps)

    # The raw schedule peaks at d_model**-0.5 * warmup_steps**-0.5 (6.25e-3 for
    # the current config), far above the configured learning rate. Rescale the
    # curve so its peak equals config.lr and never drops below config.min_lr.
    raw_peak = config.d_model ** (-0.5) * config.warmup_steps ** (-0.5)
    lr_scale = config.lr / raw_peak

    losses = []
    mask = create_mask(config.max_length-1).to(device)

    for epoch in range(config.epochs):
        model.train()
        total_loss = 0

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()

            outputs = model(inputs, mask)
            loss = criterion(outputs.view(-1, len(vocab)), targets.view(-1))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip)

            # Set this step's learning rate before the update; stepping the
            # scheduler afterwards would run the first update with lr=0.
            # scheduler.step() overwrites lr each call, so the scaling below
            # does not compound.
            scheduler.step()
            for param_group in optimizer.param_groups:
                param_group['lr'] = max(param_group['lr'] * lr_scale, config.min_lr)
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        losses.append(avg_loss)

        if (epoch + 1) % config.eval_every == 0:
            print(f"Epoch {epoch+1}/{config.epochs} | Loss: {avg_loss:.4f}")

            # Generate a sample during training
            if (epoch + 1) % (config.eval_every * 2) == 0:
                # generate() leaves the model in eval mode; model.train() at
                # the top of the next epoch restores training mode.
                model.eval()
                sample_prompt = random.choice(config.few_shot_examples)[0]
                generated = generate(model, vocab, sample_prompt, device)
                print(f"Sample generation: '{sample_prompt}' → {generated}\n")

    return model, losses

@torch.no_grad()
def generate(model: nn.Module, vocab: Vocabulary, prompt: str, device: torch.device,
            max_len: int = 30, temperature: float = 0.7) -> str:
    """Extend ``prompt`` token by token and return the decoded text.

    Generation stops after ``max_len`` new tokens, when the sequence reaches
    ``config.max_length`` tokens, or when <eos> is produced. The model is
    switched to eval mode.

    Args:
        model: Network called as ``model(input_ids, mask)``, returning logits.
        vocab: Vocabulary used to encode the prompt and decode the result.
        prompt: Whitespace-separated seed text.
        device: Device the input ids and mask are placed on.
        max_len: Maximum number of tokens to generate.
        temperature: Softmax temperature for sampling. Values ``<= 0`` select
            greedy (argmax) decoding instead of dividing by zero or negating
            the logits.

    Returns:
        The decoded sequence without <sos>, <eos> and <pad>.
    """
    model.eval()
    eos_id = vocab.word2idx[config.eos_token]
    tokens = [vocab.word2idx[config.sos_token]] + vocab.encode(prompt)

    for _ in range(max_len):
        if len(tokens) >= config.max_length:
            break

        # Feed at most the last max_length - 1 tokens.
        input_seq = torch.tensor(tokens[-config.max_length+1:], dtype=torch.long).unsqueeze(0).to(device)
        mask = create_mask(input_seq.size(1)).to(device)

        logits = model(input_seq, mask)[0, -1, :]
        if temperature <= 0:
            next_token = logits.argmax().item()
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, 1).item()

        tokens.append(next_token)
        if next_token == eos_id:
            break

    return vocab.decode(tokens[1:])  # Skip <sos>

def evaluate_few_shot(model: nn.Module, vocab: Vocabulary, device: torch.device,
                     examples: List[Tuple[str, str]], num_samples: int = 3) -> None:
    """Print generations for the first ``num_samples`` examples and for fixed extra prompts.

    Each example is a ``(prompt, target)`` pair; the target is printed next to
    the generated text for comparison. The model is switched to eval mode.
    """
    model.eval()

    print("\nFew-shot Generation Examples:")
    for prompt, target in examples[:num_samples]:
        completion = generate(model, vocab, prompt, device)
        print(f"\nPrompt: {prompt}")
        print(f"Generated: {completion}")
        print(f"Target: {target}")

    # Prompts that are not part of the example list.
    print("\nGeneration with Unseen Prompts:")
    unseen_prompts = (
        "The rainbow",
        "In space",
        "The robot",
        "During summer",
        "The musician",
    )
    for prompt in unseen_prompts:
        completion = generate(model, vocab, prompt, device)
        print(f"\nPrompt: {prompt}")
        print(f"Generated: {completion}")

def main():
    """Build the vocabulary, train the model, and print few-shot generations."""
    print("Initializing...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize vocabulary and data.
    # The vocabulary must cover the training sentences; if it is built from
    # the few-shot targets alone, most training words encode to <unk> and the
    # model learns to emit <unk>.
    vocab = Vocabulary()
    vocab.build_vocab(config.train_data)
    # Few-shot words are indexed too so that evaluation prompts can be encoded.
    vocab.build_vocab([ex[1] for ex in config.few_shot_examples])

    # Create dataset and dataloader
    train_dataset = TextDataset(vocab, config.train_data)
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)

    # Initialize and train model
    model = Transformer(len(vocab)).to(device)
    print(f"\nTraining model with {sum(p.numel() for p in model.parameters())} parameters...")

    model, losses = train_model(model, train_loader, vocab, device)
    print("\nTraining completed!")

    # Evaluate on few-shot examples and generate new samples
    evaluate_few_shot(model, vocab, device, config.few_shot_examples)



# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

Initializing...
Using device: cpu

Training model with 4754975 parameters...
Epoch 10/150 | Loss: 0.7461
Epoch 20/150 | Loss: 1.3769
Sample generation: 'In winter' → in winter <unk> <unk>

Epoch 30/150 | Loss: 1.4621
Epoch 40/150 | Loss: 1.4615
Sample generation: 'The coffee' → the coffee <unk> <unk> <unk> <unk> the the <unk> <unk> <unk> <unk> through <unk> <unk>

Epoch 50/150 | Loss: 1.4704
Epoch 60/150 | Loss: 1.4679
Sample generation: 'The coffee' → the coffee

Epoch 70/150 | Loss: 1.5305
Epoch 80/150 | Loss: 1.4937
Sample generation: 'The car' → the car <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> the the <unk> <unk> <unk>

Epoch 90/150 | Loss: 1.5428
Epoch 100/150 | Loss: 1.4665
Sample generation: 'A teacher' → a teacher the <unk> <unk> <unk> <unk> in in <unk> <unk>

Epoch 110/150 | Loss: 1.4065
Epoch 120/150 | Loss: 1.4934
Sample generation: 'The coffee' → the coffee <unk> <unk> <unk> <unk> <unk> <unk> students <unk> <unk>

Epoch 130/150 | Loss: 1.4496
Epoch 140/150 | Loss: 1.