In [None]:
import torch , math
from torch.utils.data import Dataset
from tqdm import tqdm
import os , re , json
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
import matplotlib.pyplot as plt
from collections import  Counter
import os
import torch
import matplotlib.pyplot as plt
import seaborn as sns


In [None]:
def train_epoch(model, dataloader, optimizer, criterion, device, epoch):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Every batch is split along its second axis: columns ``[:-1]`` are fed to
    the model and columns ``[1:]`` are the targets, i.e. each position is
    trained to predict the following one. Gradients are clipped to a global
    norm of 1.0 before each optimizer step.

    Args:
        model: Module called as ``model(inputs)``; its output is flattened over
            all but the last axis before being passed to ``criterion``.
        dataloader: Iterable of 2-D batches supporting ``len()``.
        optimizer: Optimizer updating ``model.parameters()``.
        criterion: Loss called as ``criterion(flat_logits, flat_targets)``.
        device: Device every batch is moved to.
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    running_loss = 0
    batches = tqdm(dataloader, desc=f"Epoch {epoch} - Training")

    for token_ids in batches:
        token_ids = token_ids.to(device)
        # Shift by one position: inputs drop the last column, targets the first.
        inputs, targets = token_ids[:, :-1], token_ids[:, 1:]

        optimizer.zero_grad()
        logits = model(inputs)
        last_dim = logits.size(-1)
        loss = criterion(logits.reshape(-1, last_dim), targets.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        batches.set_postfix({'loss': f'{batch_loss:.4f}'})

    return running_loss / len(dataloader)

def calculate_perplexity(loss):
    """Return ``exp(loss)``, with ``loss`` capped at 20 before exponentiation.

    The cap keeps the reported perplexity bounded (at ``exp(20)``) when the
    loss is very large.
    """
    capped = 20 if loss > 20 else loss
    return math.exp(capped)


In [None]:
def training_generate(model, prompt, vocab, max_length=50, temperature=0.8, top_k=40, device='cuda'):
    """Sample a continuation of ``prompt`` from ``model`` and return the decoded text.

    The prompt is encoded with ``vocab`` and prefixed with the SOS token.
    Tokens are then drawn one at a time from the temperature-scaled, top-k
    filtered distribution over the last position, until ``max_length`` new
    tokens were produced, the sequence reaches ``model.max_seq_len``, or the
    EOS token is drawn (EOS itself is not appended).

    Args:
        model: Module called as ``model(tokens)`` returning logits indexed as
            ``[:, -1, :]``; must expose ``max_seq_len``. It is put in eval mode.
        prompt: Text to continue.
        vocab: Object providing ``word2idx``, ``SOS_TOKEN``, ``EOS_TOKEN``,
            ``encode`` and ``decode``.
        max_length: Maximum number of tokens to generate.
        temperature: Softmax temperature. Values ``<= 0`` select the most
            likely token deterministically instead of dividing by zero.
        top_k: Number of highest-scoring tokens kept for sampling; ``<= 0``
            disables the filter. Values above the vocabulary size are clamped,
            since ``torch.topk`` rejects ``k`` larger than the dimension.
        device: Device the token tensor is placed on.

    Returns:
        Whatever ``vocab.decode`` returns for the full token sequence
        (prompt plus generated tokens).
    """
    model.eval()
    eos_id = vocab.word2idx[vocab.EOS_TOKEN]
    tokens = [vocab.word2idx[vocab.SOS_TOKEN]] + vocab.encode(prompt)
    tokens = torch.tensor(tokens).unsqueeze(0).to(device)

    with torch.no_grad():
        for _ in range(max_length):
            # The model cannot attend beyond its maximum sequence length.
            if tokens.size(1) >= model.max_seq_len:
                break

            logits = model(tokens)[:, -1, :]

            if temperature <= 0:
                # Greedy decoding: the limit of sampling as temperature -> 0.
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature

                # Clamp k so small vocabularies do not make torch.topk fail.
                k = min(top_k, logits.size(-1))
                if k > 0:
                    top_k_logits, top_k_indices = torch.topk(logits, k)
                    # Everything outside the top-k gets probability zero.
                    logits = torch.full_like(logits, float('-inf'))
                    logits.scatter_(1, top_k_indices, top_k_logits)

                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, 1)

            if next_token.item() == eos_id:
                break

            tokens = torch.cat([tokens, next_token], dim=1)

    return vocab.decode(tokens.squeeze(0).tolist())



In [None]:
def visualize_attention(model, text, vocab, device, save_path='attention_viz'):
    """Save one heat-map figure per layer showing every head's attention for ``text``.

    The text is encoded with ``vocab``, prefixed with the SOS token and
    truncated to the first 20 tokens. The model is run once with
    ``return_attention=True``; for each returned layer a figure with one
    heat map per head is written to ``<save_path>/layer_<n>_attention.png``.

    Args:
        model: Module called as ``model(tokens, return_attention=True)`` and
            returning ``(output, attention_weights)``, where
            ``attention_weights`` is iterable per layer and each item is
            indexed as (batch, head, query, key).
        text: Input text to visualise.
        vocab: Object providing ``word2idx``, ``idx2word``, ``SOS_TOKEN`` and
            ``encode``.
        device: Device the token tensor is placed on.
        save_path: Output directory; created if missing.
    """
    os.makedirs(save_path, exist_ok=True)

    model.eval()
    max_tokens = 20  # keeps the tick labels of the heat maps readable
    tokens = ([vocab.word2idx[vocab.SOS_TOKEN]] + vocab.encode(text))[:max_tokens]
    token_words = [vocab.idx2word[t] for t in tokens]
    tokens = torch.tensor(tokens).unsqueeze(0).to(device)

    with torch.no_grad():
        _, attention_weights = model(tokens, return_attention=True)

    for layer_idx, layer_attn in enumerate(attention_weights):
        # Drop the batch axis: what remains is (num_heads, seq_len, seq_len).
        layer_attn = layer_attn[0].cpu().numpy()

        num_heads = layer_attn.shape[0]
        # Two rows of panels. Rounding the column count up (and never below
        # one) keeps the grid large enough for odd head counts and for a
        # single head; squeeze=False guarantees a 2-D axes array to flatten.
        n_cols = max(1, math.ceil(num_heads / 2))
        fig, axes = plt.subplots(2, n_cols, figsize=(20, 8), squeeze=False)
        axes = axes.flatten()

        for head_idx in range(num_heads):
            ax = axes[head_idx]
            sns.heatmap(layer_attn[head_idx], ax=ax, cmap='viridis',
                        xticklabels=token_words,
                        yticklabels=token_words,
                        cbar=True, square=True)
            ax.set_title(f'Head {head_idx + 1}')
            ax.set_xlabel('Key')
            ax.set_ylabel('Query')

        # Hide any panel left over when the head count does not fill the grid.
        for spare_ax in axes[num_heads:]:
            spare_ax.axis('off')

        fig.suptitle(f'Layer {layer_idx + 1} - Attention Patterns')
        fig.tight_layout()
        fig.savefig(os.path.join(save_path, f'layer_{layer_idx + 1}_attention.png'), dpi=150)
        plt.close(fig)

    print(f"Attention visualizations saved to {save_path}/")

def plot_training_curves(train_losses, val_losses, train_perplexities, val_perplexities, save_path='plots'):
    """Plot loss and perplexity per epoch and save them as one PNG.

    The figure has two panels (loss on the left, perplexity on the right),
    each showing the training and validation series, and is written to
    ``<save_path>/training_curves.png``. The directory is created if needed.

    Args:
        train_losses: Training loss per epoch.
        val_losses: Validation loss per epoch.
        train_perplexities: Training perplexity per epoch.
        val_perplexities: Validation perplexity per epoch.
        save_path: Output directory for the figure.
    """
    os.makedirs(save_path, exist_ok=True)
    # Save inside save_path, the directory that was just created and that the
    # message below reports, rather than a fixed, possibly missing location.
    output_file = os.path.join(save_path, 'training_curves.png')

    fig, (loss_ax, ppl_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Loss curves
    loss_ax.plot(train_losses, label='Train Loss', marker='o')
    loss_ax.plot(val_losses, label='Val Loss', marker='s')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    # Perplexity curves
    ppl_ax.plot(train_perplexities, label='Train Perplexity', marker='o')
    ppl_ax.plot(val_perplexities, label='Val Perplexity', marker='s')
    ppl_ax.set_xlabel('Epoch')
    ppl_ax.set_ylabel('Perplexity')
    ppl_ax.set_title('Training and Validation Perplexity')
    ppl_ax.legend()
    ppl_ax.grid(True)

    fig.tight_layout()
    fig.savefig(output_file, dpi=150)
    plt.close(fig)
    print(f"Training curves saved to {save_path}/training_curves.png")


In [None]:
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradients; return the mean batch loss.

    Batches are split the same way as in training: columns ``[:-1]`` are the
    inputs and columns ``[1:]`` the targets. The model is switched to eval
    mode and left in that mode.

    Args:
        model: Module called as ``model(inputs)``.
        dataloader: Iterable of 2-D batches supporting ``len()``.
        criterion: Loss called as ``criterion(flat_logits, flat_targets)``.
        device: Device every batch is moved to.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.eval()
    loss_sum = 0

    with torch.no_grad():
        for token_ids in tqdm(dataloader, desc="Validating"):
            token_ids = token_ids.to(device)
            inputs, targets = token_ids[:, :-1], token_ids[:, 1:]

            logits = model(inputs)
            flat_logits = logits.reshape(-1, logits.size(-1))
            loss_sum += criterion(flat_logits, targets.reshape(-1)).item()

    return loss_sum / len(dataloader)


In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last axis with a learnable scale and shift.

    The input is standardised with the mean and biased variance of its last
    axis, then multiplied by ``gamma`` (initialised to ones) and shifted by
    ``beta`` (initialised to zeros).
    """

    def __init__(self, normalized_shape, eps=1e-5):
        """Create the scale/shift parameters of shape ``normalized_shape``.

        Args:
            normalized_shape: Shape of ``gamma`` and ``beta``.
            eps: Constant added to the variance before the square root.
        """
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        """Return ``gamma * (x - mean) / sqrt(var + eps) + beta``."""
        last_axis = dict(dim=-1, keepdim=True)
        mu = x.mean(**last_axis)
        sigma_sq = x.var(unbiased=False, **last_axis)
        standardised = (x - mu) / (sigma_sq + self.eps).sqrt()
        return standardised * self.gamma + self.beta

In [None]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to its input.

    A ``(max_len, d_model)`` table is precomputed and registered as the
    buffer ``pe``: even columns hold sines, odd columns cosines, of the
    position multiplied by geometrically decreasing frequencies.
    """

    def __init__(self, d_model, max_len=5000):
        """Precompute the encoding table.

        Args:
            d_model: Width of the encoding (size of the last input axis).
                Both even and odd values are supported.
            max_len: Number of positions in the table.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                            (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the angles; for even d_model the slice keeps everything.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the first ``x.size(1)`` rows of the table."""
        return x + self.pe[:x.size(1), :]


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are linear projections of the same input, split
    into ``num_heads`` heads of width ``d_model // num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        """Create the projection layers.

        Args:
            d_model: Width of the input and output.
            num_heads: Number of attention heads; must divide ``d_model``.
            dropout: Dropout probability applied to the attention weights.

        Raises:
            ValueError: If ``d_model`` is not divisible by ``num_heads``.
        """
        super().__init__()
        # Fail at construction: otherwise the head split in forward() breaks
        # later with an opaque reshape error.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)``.
            mask: Optional tensor broadcastable to the score tensor
                ``(batch, num_heads, seq_len, seq_len)``; positions where it
                equals 0 are excluded from attention.

        Returns:
            ``(output, attn_weights)``: the projected context of shape
            ``(batch, seq_len, d_model)`` and the attention weights of shape
            ``(batch, num_heads, seq_len, seq_len)``. The weights are the
            values after dropout.
        """
        batch_size, seq_len, d_model = x.shape
        # (batch, seq, d_model) -> (batch, heads, seq, d_k)
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # -inf becomes exactly zero weight after the softmax.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context = torch.matmul(attn_weights, V)
        # Merge the heads back: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        output = self.W_o(context)
        return output, attn_weights



In [None]:
class FeedForward(nn.Module):
    """Two-layer position-wise network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        """Create the layers.

        Args:
            d_model: Input and output width.
            d_ff: Width of the hidden layer.
            dropout: Dropout probability applied after the activation.
        """
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Expand to ``d_ff``, apply ReLU and dropout, project back to ``d_model``."""
        hidden = torch.relu(self.linear1(x))
        return self.linear2(self.dropout(hidden))


In [None]:
class TransformerBlock(nn.Module):
    """One transformer layer: self-attention then feed-forward.

    Each sub-layer is wrapped as ``norm(x + dropout(sublayer(x)))``, i.e. the
    normalisation is applied after the residual addition.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """Create the attention, feed-forward, normalisation and dropout layers.

        Args:
            d_model: Width of the layer's input and output.
            num_heads: Number of attention heads.
            d_ff: Hidden width of the feed-forward network.
            dropout: Dropout probability used in both sub-layers and on their
                outputs.
        """
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = LayerNorm(d_model)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Return ``(output, attn_weights)`` for input ``x``.

        ``mask`` is forwarded unchanged to the attention sub-layer, and
        ``attn_weights`` are the weights that sub-layer returns.
        """
        attended, attn_weights = self.attention(x, mask)
        hidden = self.norm1(x + self.dropout(attended))
        hidden = self.norm2(hidden + self.dropout(self.feed_forward(hidden)))
        return hidden, attn_weights


In [None]:
class DecoderTransformer(nn.Module):
    """Decoder-only transformer language model with a causal attention mask.

    Token ids are embedded, scaled, optionally projected to ``d_model``,
    combined with positional encodings, passed through ``num_layers``
    transformer blocks and a final layer norm, and mapped to vocabulary
    logits.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads,
                 d_ff, max_seq_len, dropout=0.1, pretrained_embeddings=None):
        """Build the model.

        Args:
            vocab_size: Number of rows in the embedding table and size of the
                output layer.
            d_model: Width of the transformer layers.
            num_layers: Number of transformer blocks.
            num_heads: Attention heads per block.
            d_ff: Hidden width of each block's feed-forward network.
            max_seq_len: Number of positions in the positional-encoding table.
            dropout: Dropout probability.
            pretrained_embeddings: Optional 2-D tensor copied into the
                embedding table. Its second dimension sets the embedding
                width, and a linear layer then maps that width to ``d_model``.
        """
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        use_pretrained = pretrained_embeddings is not None

        # The embedding width follows the pretrained vectors when given
        # (300 for FastText); otherwise it is simply d_model.
        self.embedding_dim = pretrained_embeddings.shape[1] if use_pretrained else d_model
        self.embedding = nn.Embedding(vocab_size, self.embedding_dim)

        if use_pretrained:
            self.embedding.weight.data.copy_(pretrained_embeddings)
            # Bridge from the pretrained width to the transformer width.
            self.embedding_proj = nn.Linear(self.embedding_dim, d_model)
        else:
            # Widths already match, so no projection is required.
            self.embedding_proj = nn.Identity()

        self.pos_encoding = PositionalEncoding(d_model, max_seq_len)
        blocks = [
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.norm = LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_causal_mask(self, seq_len, device):
        """Return a ``(1, 1, seq_len, seq_len)`` lower-triangular mask of ones."""
        lower_triangle = torch.ones(seq_len, seq_len, device=device).tril()
        return lower_triangle[None, None, :, :]

    def forward(self, x, return_attention=False):
        """Compute vocabulary logits for a 2-D batch of token ids.

        Args:
            x: Tensor of token ids with shape ``(batch, seq_len)``.
            return_attention: When true, also return the attention weights of
                every layer.

        Returns:
            The logits, or ``(logits, attention_weights)`` with one entry per
            layer when ``return_attention`` is true.
        """
        _, seq_len = x.shape
        mask = self.create_causal_mask(seq_len, x.device)

        # Embed at the native embedding width, scaled by sqrt of that width,
        # then project to d_model.
        scale = math.sqrt(self.embedding_dim)
        x = self.embedding_proj(self.embedding(x) * scale)

        x = self.dropout(self.pos_encoding(x))

        attention_weights = []
        for layer in self.layers:
            x, layer_attn = layer(x, mask)
            if return_attention:
                attention_weights.append(layer_attn)

        logits = self.output_projection(self.norm(x))
        return (logits, attention_weights) if return_attention else logits



In [None]:
class Vocabulary:
    """Word-level vocabulary with special tokens and optional pretrained vectors.

    Four special tokens are registered first, so they get ids 0-3 in this
    order: ``<pad>``, ``<sos>``, ``<eos>``, ``<unk>``.
    """

    def __init__(self, fasttext_model=None):
        """Create a vocabulary that contains only the special tokens.

        Args:
            fasttext_model: Optional mapping-like object supporting
                ``word in model`` and ``model[word]``; used by
                ``create_embedding_matrix``.
        """
        self.word2idx = {}
        self.idx2word = {}
        self.word_counts = Counter()
        self.PAD_TOKEN = '<pad>'
        self.SOS_TOKEN = '<sos>'
        self.EOS_TOKEN = '<eos>'
        self.UNK_TOKEN = '<unk>'
        self.add_word(self.PAD_TOKEN)
        self.add_word(self.SOS_TOKEN)
        self.add_word(self.EOS_TOKEN)
        self.add_word(self.UNK_TOKEN)
        self.fasttext_model = fasttext_model

    def add_word(self, word):
        """Register ``word`` if it is new and increment its count."""
        if word not in self.word2idx:
            idx = len(self.word2idx)
            self.word2idx[word] = idx
            self.idx2word[idx] = word
        self.word_counts[word] += 1

    def __len__(self):
        """Return the number of distinct words, special tokens included."""
        return len(self.word2idx)

    def encode(self, text):
        """Tokenize ``text`` and map each token to its id (unknown -> ``<unk>``)."""
        unk_id = self.word2idx[self.UNK_TOKEN]
        return [self.word2idx.get(token, unk_id) for token in self.tokenize(text)]

    def decode(self, indices):
        """Turn ids back into a space-joined string.

        ``<pad>`` and ``<sos>`` ids are skipped, decoding stops at the first
        ``<eos>`` id, and ids without a word are rendered as ``<unk>``.
        """
        skip_ids = (self.word2idx[self.PAD_TOKEN], self.word2idx[self.SOS_TOKEN])
        eos_id = self.word2idx[self.EOS_TOKEN]
        words = []
        for idx in indices:
            if idx in skip_ids:
                continue
            if idx == eos_id:
                break
            words.append(self.idx2word.get(idx, self.UNK_TOKEN))
        return ' '.join(words)

    def tokenize(self, text):
        """Lower-case ``text`` and split it into words and ``. , ! ? ;`` marks."""
        return re.findall(r'\b\w+\b|[.,!?;]', text.lower())

    def create_embedding_matrix(self):
        """Build a ``(len(vocab), dim)`` embedding matrix.

        Rows start as small random values (standard normal times 0.01). When
        a FastText model is attached, the row of every word it contains is
        replaced by that word's vector. ``dim`` is the model's
        ``vector_size`` when it exposes one, otherwise 300.
        """
        dim = getattr(self.fasttext_model, 'vector_size', 300)
        embedding_matrix = torch.randn(len(self.word2idx), dim) * 0.01
        if self.fasttext_model is not None:
            found = 0
            for word, idx in self.word2idx.items():
                if word in self.fasttext_model:
                    embedding_matrix[idx] = torch.tensor(self.fasttext_model[word])
                    found += 1
            print(f"Found {found}/{len(self.word2idx)} words in FastText")
        return embedding_matrix

    def save(self, path):
        """Write the vocabulary to ``path`` as JSON, creating parent directories."""
        parent = os.path.dirname(path)
        if parent:
            # Without this, a first save into a directory that does not exist
            # yet fails with FileNotFoundError.
            os.makedirs(parent, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump({
                'word2idx': self.word2idx,
                'idx2word': {int(k): v for k, v in self.idx2word.items()},
                'word_counts': dict(self.word_counts)
            }, f)

    @classmethod
    def load(cls, path, fasttext_model=None):
        """Recreate a vocabulary from a JSON file written by ``save``."""
        vocab = cls(fasttext_model)
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        vocab.word2idx = data['word2idx']
        # JSON object keys are always strings; restore the integer ids.
        vocab.idx2word = {int(k): v for k, v in data['idx2word'].items()}
        vocab.word_counts = Counter(data['word_counts'])
        return vocab


In [None]:
class TinyStoriesDataset(Dataset):
    """Fixed-length token windows cut from a collection of texts.

    Each text is wrapped as ``<sos> tokens <eos>``. A window of
    ``context_length + 1`` ids is taken at every start position except the
    last token, and windows that run past the end of the text are
    right-padded with the ``<pad>`` id.
    """

    def __init__(self, texts, vocab, context_length, max_samples=None):
        """Encode ``texts`` and materialise all windows in memory.

        Args:
            texts: Iterable of strings.
            vocab: Object providing ``word2idx``, the special-token names and
                ``encode``.
            context_length: Number of input positions per sample; every stored
                sequence has one extra id so it can be shifted into
                inputs and targets.
            max_samples: If truthy, only the first ``max_samples`` texts are
                used.
        """
        self.vocab = vocab
        self.context_length = context_length
        self.sequences = []

        sos_id = vocab.word2idx[vocab.SOS_TOKEN]
        eos_id = vocab.word2idx[vocab.EOS_TOKEN]
        pad_id = vocab.word2idx[vocab.PAD_TOKEN]
        window = context_length + 1

        print("Preparing dataset...")
        for idx, text in enumerate(tqdm(texts)):
            if max_samples and idx >= max_samples:
                break

            tokens = [sos_id] + vocab.encode(text) + [eos_id]

            for start in range(len(tokens) - 1):
                # Slicing clips at the end of the list, so late windows come
                # out short and are padded up to the full length.
                seq = tokens[start:start + window]
                seq = seq + [pad_id] * (window - len(seq))
                self.sequences.append(seq)

        print(f"Created {len(self.sequences)} sequences")

    def __len__(self):
        """Return the number of stored windows."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return window ``idx`` as a ``torch.long`` tensor."""
        return torch.tensor(self.sequences[idx], dtype=torch.long)



In [None]:
# Hyper-parameters and output locations for the baseline run.
CONFIG = dict(
    name='baseline',
    description='Standard baseline configuration from assignment',
    # Model shape: d_model is divisible by num_heads (296 / 8 = 37 per head)
    # and d_ff is four times d_model.
    context_length=64,
    num_layers=3,
    num_heads=8,
    d_model=296,
    d_ff=1184,
    dropout=0.1,
    # Optimisation
    batch_size=32,
    learning_rate=3e-4,
    num_epochs=5,
    # Number of stories taken from each split
    max_train_samples=50000,
    max_val_samples=15000,
    # Output directories
    save_dir='checkpoints/baseline',
    plot_dir='plots/baseline',
)

In [None]:
import gensim.downloader as api
from gensim.models import KeyedVectors
import os

def load_fasttext_model():
    """Return the FastText vectors, downloading them on first use.

    The model is cached at ``fasttext/fasttext_model.bin``. If that file
    exists it is loaded with ``KeyedVectors.load``; otherwise
    ``fasttext-wiki-news-subwords-300`` is fetched through the gensim
    downloader and saved there for later runs.
    """
    model_path = 'fasttext/fasttext_model.bin'

    os.makedirs('fasttext', exist_ok=True)

    if os.path.exists(model_path):
        print("Loading FastText model from cache...")
        model = KeyedVectors.load(model_path)
        print("Model loaded successfully!")
        return model

    print("Model not found. Downloading FastText model...")
    model = api.load('fasttext-wiki-news-subwords-300')
    model.save(model_path)
    print("Model downloaded and saved successfully!")
    return model

In [None]:
# Select the compute device and echo the run configuration.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print("\nConfiguration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

# Load FastText
print("\n" + "="*50)
print("Loading FastText embeddings...")
print("="*50)
fasttext_model = load_fasttext_model()

# Load Dataset
print("\n" + "="*50)
print("Loading TinyStories dataset...")
print("="*50)
dataset = load_dataset("roneneldan/TinyStories")

# Build the vocabulary, or reuse the one cached by an earlier run.
print("\n" + "="*50)
print("Building vocabulary...")
print("="*50)
vocab_path = f"{CONFIG['save_dir']}/vocab.json"

if os.path.exists(vocab_path):
    print("Loading existing vocabulary...")
    vocab = Vocabulary.load(vocab_path, fasttext_model)
else:
    vocab = Vocabulary(fasttext_model)
    # Build vocabulary from training data
    num_samples = min(CONFIG['max_train_samples'], len(dataset['train']))
    for i in tqdm(range(num_samples), desc="Building vocabulary"):
        text = dataset['train'][i]['text']
        for word in vocab.tokenize(text):
            vocab.add_word(word)
    # The checkpoint directory does not exist on a fresh run; create it so
    # that writing the vocabulary file cannot fail with FileNotFoundError.
    os.makedirs(CONFIG['save_dir'], exist_ok=True)
    vocab.save(vocab_path)

print(f"Vocabulary size: {len(vocab)}")

# Create Datasets
print("\n" + "="*50)
print("Creating datasets...")
print("="*50)

# Take at most the configured number of stories from each split.
train_texts = [dataset['train'][i]['text'] for i in range(min(CONFIG['max_train_samples'], len(dataset['train'])))]
val_texts = [dataset['validation'][i]['text'] for i in range(min(CONFIG['max_val_samples'], len(dataset['validation'])))]

train_dataset = TinyStoriesDataset(
    train_texts,
    vocab,
    CONFIG['context_length'],
    CONFIG['max_train_samples']
)

val_dataset = TinyStoriesDataset(
    val_texts,
    vocab,
    CONFIG['context_length'],
    CONFIG['max_val_samples']
)

# Create DataLoaders (only the training data is shuffled)
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'],
                         shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'],
                       shuffle=False, num_workers=0)

# Initialize Model
print("\n" + "="*50)
print("Initializing model...")
print("="*50)
embedding_matrix = vocab.create_embedding_matrix()



In [None]:
# Build the language model from CONFIG; the embedding table is initialised
# from the matrix produced by the vocabulary.
model = DecoderTransformer(
    len(vocab),
    CONFIG['d_model'],
    CONFIG['num_layers'],
    CONFIG['num_heads'],
    CONFIG['d_ff'],
    CONFIG['context_length'],
    dropout=CONFIG['dropout'],
    pretrained_embeddings=embedding_matrix,
)
model = model.to(device)


In [None]:
# `checkpoint` is the best-so-far snapshot and `checkpoint_last` the most
# recent one. The weights are restored from the most recent snapshot.
checkpoint, checkpoint_last = (
    torch.load(ckpt_file, map_location=device)
    for ckpt_file in ('best_model.pt', 'last_model.pt')
)
model.load_state_dict(checkpoint_last['model_state_dict'])
model.eval()

In [None]:
# Summarise the best checkpoint: the stored keys first, then its metrics.
print("Checkpoint keys:", checkpoint.keys())

for label, field in (("\nEpoch:", 'epoch'),
                     ("Train Loss:", 'train_loss'),
                     ("Validation Loss:", 'val_loss')):
    print(label, checkpoint[field])

# The configuration the checkpoint was saved with
print("\nConfig:")
print(checkpoint['config'])

In [None]:
# Summarise the most recent checkpoint: the stored keys first, then its metrics.
print("Checkpoint keys:", checkpoint_last.keys())

for label, field in (("\nEpoch:", 'epoch'),
                     ("Train Loss:", 'train_loss'),
                     ("Validation Loss:", 'val_loss')):
    print(label, checkpoint_last[field])

# The configuration the checkpoint was saved with
print("\nConfig:")
print(checkpoint_last['config'])

In [None]:
# Recreate the optimizer and restore its state from the most recent
# checkpoint so that training resumes where it stopped.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=CONFIG['learning_rate'],
)
optimizer.load_state_dict(checkpoint_last['optimizer_state_dict'])

# Padding positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(
    ignore_index=vocab.word2idx[vocab.PAD_TOKEN],
)


In [None]:
print("\n" + "="*50)
print("Starting training...")
print("="*50)

train_losses = []
val_losses = []
train_perplexities = []
val_perplexities = []
# Resume: the bar to beat is the best checkpoint's validation loss, and epoch
# numbering continues after the most recent checkpoint.
best_val_loss = checkpoint['val_loss']
initial_epoch = checkpoint_last['epoch']
final_epoch = initial_epoch + CONFIG['num_epochs']

# Per-epoch metrics are appended to this JSON file.
file_path = r"results/training/training_results.json"
os.makedirs(os.path.dirname(file_path), exist_ok=True)

for epoch in range(initial_epoch + 1, final_epoch + 1):
    print(f"\n{'='*50}")
    # The last epoch of this run is final_epoch (not final_epoch + 1).
    print(f"Epoch {epoch}/{final_epoch}")
    print(f"{'='*50}")

    # Train
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device, epoch)
    train_ppl = calculate_perplexity(train_loss)

    # Validate
    val_loss = validate(model, val_loader, criterion, device)
    val_ppl = calculate_perplexity(val_loss)

    # Store metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_perplexities.append(train_ppl)
    val_perplexities.append(val_ppl)

    print(f"\nEpoch {epoch} Results:")
    print(f"  Train Loss: {train_loss:.4f} | Train PPL: {train_ppl:.2f}")
    print(f"  Val Loss:   {val_loss:.4f} | Val PPL:   {val_ppl:.2f}")

    # Metrics of this epoch, rounded for the results file
    results = {
            "epoch": epoch,
            "train_loss": round(train_loss, 4),
            "train_ppl": round(train_ppl, 2),
            "val_loss": round(val_loss, 4),
            "val_ppl": round(val_ppl, 2)
    }

    # Step 1: Read existing content or initialize empty list
    if os.path.exists(file_path):
        with open(file_path, "r") as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                data = []  # If file is empty or invalid
    else:
        data = []

    # Step 2: Append new results
    data.append(results)

    # Step 3: Write back to file
    with open(file_path, "w") as f:
        json.dump(data, f, indent=4)

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'config': CONFIG
        }, "best_model.pt")
        print(f"  ✓ Saved best model (val_loss: {val_loss:.4f})")

    # Generate samples
    print(f"\n  Sample Generations:")
    for prompt in ["Once upon a time", "The little girl", "In the forest"]:
        generated = training_generate(model, prompt, vocab, max_length=30, device=device)
        print(f"    '{prompt}' → {generated}")

In [None]:
# Plot training curves
print("\n" + "="*50)
print("Plotting training curves...")
print("="*50)
plot_training_curves(train_losses, val_losses, train_perplexities,
                    val_perplexities, CONFIG['plot_dir'])


# Visualize Attention
print("\n" + "="*50)
print("Generating attention visualizations...")
print("="*50)

sample_texts = [
        "Once upon a time there was a little girl",
        "The cat sat on the mat and looked around",
        "A boy went to the park to play"
]

# One output directory per sample sentence
for i, text in enumerate(sample_texts):
    visualize_attention(model, text, vocab, device,
                          save_path=f"results/training/attention_sample_{i+1}")

print("\n" + "="*50)
print("Training complete!")
print("="*50)
# The training loop writes the best checkpoint to "best_model.pt" in the
# working directory, so report that path rather than CONFIG['save_dir'].
print("Best model saved to: best_model.pt")
print(f"Plots saved to: {CONFIG['plot_dir']}/")

In [None]:
from inference.inference import evaluate_model

# Evaluate on the device selected at the top of the notebook instead of a
# hard-coded 'cuda', which fails on CPU-only machines. The device is passed as
# a string because the original call passed a string.
results = evaluate_model(
        model=model,
        val_dataset=val_dataset,
        vocab=vocab,
        num_samples=50,
        prompt_length=5,
        max_generation_length=50,
        device=str(device)
    )

print("\n" + "="*50)
print("EVALUATION RESULTS")
print("="*50)
print(f"Number of samples evaluated: {results['num_samples_evaluated']}")
print(f"Average Perplexity: {results['avg_perplexity']:.4f}")
print(f"BLEU Score: {results['bleu_score']:.6f}")
print("="*50)

In [None]:
# Persist the state after the final epoch so a later run can resume from it.
# train_loss / val_loss are the values left by the last training epoch.
torch.save(
    dict(
        epoch=initial_epoch + CONFIG['num_epochs'],
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        train_loss=train_loss,
        val_loss=val_loss,
        config=CONFIG,
    ),
    "last_model.pt",
)