In [None]:
import torch , math
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import DataLoader
from datasets import load_dataset
import matplotlib.pyplot as plt
import os
import torch
import matplotlib.pyplot as plt
import seaborn as sns


In [None]:
from model.DecoderTransformer import DecoderTransformer
from Dataset.Vocabulary import Vocabulary
from Dataset.TinyStories import TinyStoriesDataset
from Dataset.load_fasttext_model import load_fasttext_model

# Train

In [3]:
def train_epoch(model, dataloader, optimizer, criterion, device, epoch):
    """Train ``model`` for one pass over ``dataloader`` and return the mean batch loss.

    Every batch is moved to ``device`` and split along its second axis into
    inputs (all positions but the last) and targets (all positions but the
    first), so the model is trained to predict the following token at each
    position. Gradients are clipped to a maximum norm of 1.0 before each
    optimizer step.

    Args:
        model: Network to optimise; called as ``model(inputs)`` to get logits.
        dataloader: Iterable of batches; assumed to yield token-id tensors
            indexed as (batch, sequence) -- inferred from the slicing below.
        optimizer: Optimizer whose ``zero_grad``/``step`` drive the update.
        criterion: Loss taking flattened logits and flattened targets.
        device: Device the batches are moved to.
        epoch: Epoch number, used only for the progress-bar caption.

    Returns:
        Sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    model.train()
    running_loss = 0
    batches = tqdm(dataloader, desc=f"Epoch {epoch} - Training")

    for token_batch in batches:
        token_batch = token_batch.to(device)
        # Shift by one position: inputs drop the last token, targets drop the first.
        model_inputs, next_tokens = token_batch[:, :-1], token_batch[:, 1:]

        optimizer.zero_grad()
        scores = model(model_inputs)
        vocab_dim = scores.size(-1)
        batch_loss = criterion(scores.reshape(-1, vocab_dim), next_tokens.reshape(-1))
        batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_value = batch_loss.item()
        running_loss += loss_value
        batches.set_postfix({'loss': f'{loss_value:.4f}'})

    return running_loss / len(dataloader)

def calculate_perplexity(loss):
    """Convert a cross-entropy loss into perplexity, ``exp(loss)``.

    The loss is capped at 20 before exponentiation so that a very large loss
    does not produce an astronomically large (or overflowing) perplexity.
    """
    capped_loss = 20 if 20 < loss else loss
    return math.exp(capped_loss)


In [4]:
def training_generate(model, prompt, vocab, max_length=50, temperature=0.8, top_k=40, device='cuda'):
    """Sample a continuation of ``prompt`` from ``model`` and return the decoded text.

    The prompt is encoded with ``vocab`` and prefixed with the start-of-sequence
    token. Tokens are then sampled one at a time from the temperature-scaled
    distribution over the last position, optionally restricted to the
    ``top_k`` highest-scoring candidates. Generation stops when ``max_length``
    tokens have been added, when the sequence reaches ``model.max_seq_len``,
    or when the end-of-sequence token is sampled (that token is not appended).

    Args:
        model: Language model; ``model(tokens)`` must return logits whose last
            axis is the vocabulary, and it must expose ``max_seq_len``.
        prompt: Text to condition on.
        vocab: Vocabulary providing ``word2idx``, ``SOS_TOKEN``, ``EOS_TOKEN``,
            ``encode`` and ``decode``.
        max_length: Maximum number of tokens to sample.
        temperature: Divisor applied to the logits before softmax; must be
            non-zero.
        top_k: Number of candidates kept for sampling; ``0`` or less disables
            the filter. Values larger than the vocabulary are clamped to it.
        device: Device the token tensor is placed on.

    Returns:
        ``vocab.decode`` applied to the start token, the prompt tokens and the
        sampled tokens.
    """
    model.eval()
    eos_id = vocab.word2idx[vocab.EOS_TOKEN]
    tokens = [vocab.word2idx[vocab.SOS_TOKEN]] + vocab.encode(prompt)
    tokens = torch.tensor(tokens).unsqueeze(0).to(device)

    with torch.no_grad():
        for _ in range(max_length):
            # The model cannot attend past its maximum sequence length.
            if tokens.size(1) >= model.max_seq_len:
                break

            logits = model(tokens)
            logits = logits[:, -1, :] / temperature

            if top_k > 0:
                # torch.topk raises if k exceeds the size of the last axis,
                # so never ask for more candidates than the vocabulary holds.
                k = min(top_k, logits.size(-1))
                top_k_logits, top_k_indices = torch.topk(logits, k)
                # Everything outside the top-k gets -inf, i.e. zero probability.
                logits = torch.full_like(logits, float('-inf'))
                logits.scatter_(1, top_k_indices, top_k_logits)

            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)

            if next_token.item() == eos_id:
                break

            tokens = torch.cat([tokens, next_token], dim=1)

    generated_tokens = tokens.squeeze(0).tolist()
    return vocab.decode(generated_tokens)


In [5]:
def visualize_attention(model, text, vocab, device, save_path='attention_viz', max_tokens=20):
    """Save one heatmap figure per layer showing every head's attention over ``text``.

    The text is encoded with ``vocab``, prefixed with the start-of-sequence
    token and truncated to ``max_tokens`` tokens. The model is run once with
    ``return_attention=True``; for each returned layer a figure with one
    heatmap per head is written to ``{save_path}/layer_{n}_attention.png``.

    Args:
        model: Model returning ``(output, attention_weights)`` when called with
            ``return_attention=True``. Each element of ``attention_weights`` is
            assumed to be shaped (batch, num_heads, seq_len, seq_len).
        text: Sentence to visualise.
        vocab: Vocabulary providing ``word2idx``, ``idx2word``, ``SOS_TOKEN``
            and ``encode``.
        device: Device the token tensor is placed on.
        save_path: Directory that receives the PNG files; created if missing.
        max_tokens: Maximum number of tokens (including the start token) to
            plot. Presumably must not exceed the model's context length.
    """
    os.makedirs(save_path, exist_ok=True)

    model.eval()
    tokens = [vocab.word2idx[vocab.SOS_TOKEN]] + vocab.encode(text)
    tokens = tokens[:max_tokens]  # keep the heatmaps readable
    token_words = [vocab.idx2word[t] for t in tokens]
    tokens = torch.tensor(tokens).unsqueeze(0).to(device)

    with torch.no_grad():
        _, attention_weights = model(tokens, return_attention=True)

    # Visualize each layer
    for layer_idx, layer_attn in enumerate(attention_weights):
        # layer_attn shape: (batch, num_heads, seq_len, seq_len)
        layer_attn = layer_attn[0].cpu().numpy()  # Remove batch dimension

        num_heads = layer_attn.shape[0]
        # Two rows; round the column count up so an odd number of heads (or a
        # single head) still gets one axis per head.
        num_cols = max(1, (num_heads + 1) // 2)
        fig, axes = plt.subplots(2, num_cols, figsize=(20, 8), squeeze=False)
        axes = axes.flatten()

        for head_idx in range(num_heads):
            ax = axes[head_idx]
            sns.heatmap(layer_attn[head_idx], ax=ax, cmap='viridis',
                        xticklabels=token_words,
                        yticklabels=token_words,
                        cbar=True, square=True)
            ax.set_title(f'Head {head_idx + 1}')
            ax.set_xlabel('Key')
            ax.set_ylabel('Query')

        # Hide grid cells that have no head to show.
        for spare_ax in axes[num_heads:]:
            spare_ax.set_visible(False)

        fig.suptitle(f'Layer {layer_idx + 1} - Attention Patterns')
        fig.tight_layout()
        fig.savefig(f'{save_path}/layer_{layer_idx + 1}_attention.png', dpi=150)
        plt.close(fig)  # release the figure; one is created per layer

    print(f"Attention visualizations saved to {save_path}/")

def plot_training_curves(train_losses, val_losses, train_perplexities, val_perplexities, save_path='plots'):
    """Plot loss and perplexity per epoch and save the figure under ``save_path``.

    Draws two side-by-side panels (loss on the left, perplexity on the right),
    each with a training and a validation curve, and writes them to
    ``{save_path}/training_curves.png``.

    Args:
        train_losses: Training loss per epoch.
        val_losses: Validation loss per epoch.
        train_perplexities: Training perplexity per epoch.
        val_perplexities: Validation perplexity per epoch.
        save_path: Output directory; created if it does not exist.
    """
    os.makedirs(save_path, exist_ok=True)

    fig, (loss_ax, ppl_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Loss curves
    loss_ax.plot(train_losses, label='Train Loss', marker='o')
    loss_ax.plot(val_losses, label='Val Loss', marker='s')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    # Perplexity curves
    ppl_ax.plot(train_perplexities, label='Train Perplexity', marker='o')
    ppl_ax.plot(val_perplexities, label='Val Perplexity', marker='s')
    ppl_ax.set_xlabel('Epoch')
    ppl_ax.set_ylabel('Perplexity')
    ppl_ax.set_title('Training and Validation Perplexity')
    ppl_ax.legend()
    ppl_ax.grid(True)

    fig.tight_layout()
    # Write into the directory created above -- the same location the message
    # below reports -- rather than a fixed path that may not exist.
    output_file = f"{save_path}/training_curves.png"
    fig.savefig(output_file, dpi=150)
    plt.close(fig)
    print(f"Training curves saved to {output_file}")


In [6]:
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` and return the mean batch loss.

    Runs in eval mode with gradient tracking disabled. Each batch is split the
    same way as in training: inputs are all positions but the last, targets
    are all positions but the first.

    Args:
        model: Network to evaluate; called as ``model(inputs)`` to get logits.
        dataloader: Iterable of batches; assumed to yield token-id tensors
            indexed as (batch, sequence) -- inferred from the slicing below.
        criterion: Loss taking flattened logits and flattened targets.
        device: Device the batches are moved to.

    Returns:
        Sum of the per-batch loss values divided by ``len(dataloader)``.
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for token_batch in tqdm(dataloader, desc="Validating"):
            token_batch = token_batch.to(device)
            model_inputs, next_tokens = token_batch[:, :-1], token_batch[:, 1:]

            scores = model(model_inputs)
            flat_scores = scores.reshape(-1, scores.size(-1))
            running_loss += criterion(flat_scores, next_tokens.reshape(-1)).item()

    return running_loss / len(dataloader)


In [15]:
# Single source of truth for the run: model size, optimisation settings,
# dataset caps and output locations. Key order is preserved when printed.
CONFIG = dict(
    # Identification
    name='baseline',
    description='Standard baseline configuration from assignment',
    # Model architecture (d_model = 8 heads x 37; d_ff = 4 x d_model)
    context_length=64,
    num_layers=3,
    num_heads=8,
    d_model=296,
    d_ff=1184,
    dropout=0.1,
    # Optimisation
    batch_size=32,
    learning_rate=3e-4,
    num_epochs=5,
    # Dataset caps (number of stories read from each split)
    max_train_samples=50000,
    max_val_samples=15000,
    # Output locations
    save_dir='checkpoints/baseline',
    plot_dir='plots/baseline',
)

In [17]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Echo the full run configuration so it is recorded with the cell output.
print(f"\nConfiguration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

# Load FastText
print("\n" + "="*50, "Loading FastText embeddings...", "="*50, sep="\n")
fasttext_model = load_fasttext_model()

# Load Dataset
print("\n" + "="*50, "Loading TinyStories dataset...", "="*50, sep="\n")
dataset = load_dataset("roneneldan/TinyStories")

print("\n" + "="*50)
print("Building vocabulary...")
print("="*50)
vocab_path = f"{CONFIG['save_dir']}/vocab.json"

# Reuse a previously saved vocabulary when there is one; otherwise build it
# from the first `max_train_samples` training stories and cache it to disk.
if os.path.exists(vocab_path):
    print("Loading existing vocabulary...")
    vocab = Vocabulary.load(vocab_path, fasttext_model)
else:
    vocab = Vocabulary(fasttext_model)
    # Build vocabulary from training data
    num_samples = min(CONFIG['max_train_samples'], len(dataset['train']))
    for i in tqdm(range(num_samples), desc="Building vocabulary"):
        text = dataset['train'][i]['text']
        for word in vocab.tokenize(text):
            vocab.add_word(word)
    # On a first run the save directory may not exist yet; create it so the
    # freshly built vocabulary is not lost to a missing-directory error.
    os.makedirs(CONFIG['save_dir'], exist_ok=True)
    vocab.save(vocab_path)

print(f"Vocabulary size: {len(vocab)}")

# Create Datasets
print("\n" + "="*50, "Creating datasets...", "="*50, sep="\n")

# Take at most the configured number of stories from each split.
train_split = dataset['train']
val_split = dataset['validation']
train_texts = [
    train_split[idx]['text']
    for idx in range(min(CONFIG['max_train_samples'], len(train_split)))
]
val_texts = [
    val_split[idx]['text']
    for idx in range(min(CONFIG['max_val_samples'], len(val_split)))
]

train_dataset = TinyStoriesDataset(
    train_texts, vocab, CONFIG['context_length'], CONFIG['max_train_samples']
)
val_dataset = TinyStoriesDataset(
    val_texts, vocab, CONFIG['context_length'], CONFIG['max_val_samples']
)

# Create DataLoaders: only the training loader shuffles.
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=0,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=0,
)

# Initialize Model
print("\n" + "="*50, "Initializing model...", "="*50, sep="\n")
embedding_matrix = vocab.create_embedding_matrix()



Using device: cuda

Configuration:
  name: baseline
  description: Standard baseline configuration from assignment
  context_length: 64
  num_layers: 3
  num_heads: 8
  d_model: 296
  d_ff: 1184
  dropout: 0.1
  batch_size: 32
  learning_rate: 0.0003
  num_epochs: 5
  max_train_samples: 50000
  max_val_samples: 15000
  save_dir: checkpoints/baseline
  plot_dir: plots/baseline

Loading FastText embeddings...
Loading FastText model from cache...
Model loaded successfully!

Loading TinyStories dataset...

Building vocabulary...
Loading existing vocabulary...
Vocabulary size: 10598

Creating datasets...
Preparing dataset...


100%|██████████| 50000/50000 [00:27<00:00, 1828.93it/s]


Created 10251722 sequences
Preparing dataset...


100%|██████████| 15000/15000 [00:07<00:00, 1952.54it/s]


Created 2900659 sequences

Initializing model...
Found 9972/10598 words in FastText


In [18]:
# Build the decoder from CONFIG, seed its embedding table with the matrix
# created from the vocabulary, and move it to the selected device.
model_kwargs = dict(
    vocab_size=len(vocab),
    d_model=CONFIG['d_model'],
    num_layers=CONFIG['num_layers'],
    num_heads=CONFIG['num_heads'],
    d_ff=CONFIG['d_ff'],
    max_seq_len=CONFIG['context_length'],
    dropout=CONFIG['dropout'],
    pretrained_embeddings=embedding_matrix,
)
model = DecoderTransformer(**model_kwargs).to(device)


In [19]:
# Two checkpoints are kept on disk: the most recent state (used to resume)
# and the best-validation state (used later as the score to beat).
checkpoint_last = torch.load('last_model.pt', map_location=device)
checkpoint = torch.load('best_model.pt', map_location=device)

# Resume from the most recent weights.
model.load_state_dict(checkpoint_last['model_state_dict'])
model.eval()

DecoderTransformer(
  (embedding): Embedding(10598, 300)
  (embedding_proj): Linear(in_features=300, out_features=296, bias=True)
  (pos_encoding): PositionalEncoding()
  (layers): ModuleList(
    (0-2): 3 x TransformerBlock(
      (attention): MultiHeadAttention(
        (W_q): Linear(in_features=296, out_features=296, bias=True)
        (W_k): Linear(in_features=296, out_features=296, bias=True)
        (W_v): Linear(in_features=296, out_features=296, bias=True)
        (W_o): Linear(in_features=296, out_features=296, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (norm1): LayerNorm()
      (feed_forward): FeedForward(
        (linear1): Linear(in_features=296, out_features=1184, bias=True)
        (linear2): Linear(in_features=1184, out_features=296, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (norm2): LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (norm): LayerNorm()
  (output_projection): Linear(in_fe

In [20]:
# Summarise the checkpoint loaded from best_model.pt: stored keys first,
# then the headline metrics, then the configuration it was trained with.
print("Checkpoint keys:", checkpoint.keys())

for heading, field in [("\nEpoch:", 'epoch'),
                       ("Train Loss:", 'train_loss'),
                       ("Validation Loss:", 'val_loss')]:
    print(heading, checkpoint[field])

print("\nConfig:", checkpoint['config'], sep="\n")

Checkpoint keys: dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'train_loss', 'val_loss', 'config'])

Epoch: 4
Train Loss: 2.131404585684074
Validation Loss: 2.353617798573099

Config:
{'name': 'baseline', 'description': 'Standard baseline configuration from assignment', 'context_length': 64, 'num_layers': 3, 'num_heads': 8, 'd_model': 296, 'd_ff': 1184, 'dropout': 0.1, 'batch_size': 32, 'learning_rate': 0.0003, 'num_epochs': 5, 'max_train_samples': 50000, 'max_val_samples': 15000, 'save_dir': 'checkpoints/baseline', 'plot_dir': 'plots/baseline'}


In [21]:
# Summarise the checkpoint loaded from last_model.pt: stored keys first,
# then the headline metrics, then the configuration it was trained with.
print("Checkpoint keys:", checkpoint_last.keys())

for heading, field in [("\nEpoch:", 'epoch'),
                       ("Train Loss:", 'train_loss'),
                       ("Validation Loss:", 'val_loss')]:
    print(heading, checkpoint_last[field])

print("\nConfig:", checkpoint_last['config'], sep="\n")

Checkpoint keys: dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'train_loss', 'val_loss', 'config'])

Epoch: 7
Train Loss: 2.1328181815887066
Validation Loss: 2.3675910726597627

Config:
{'name': 'baseline', 'description': 'Standard baseline configuration from assignment', 'context_length': 64, 'num_layers': 3, 'num_heads': 8, 'd_model': 296, 'd_ff': 1184, 'dropout': 0.1, 'batch_size': 32, 'learning_rate': 0.0003, 'num_epochs': 5, 'max_train_samples': 50000, 'max_val_samples': 15000, 'save_dir': 'checkpoints/baseline', 'plot_dir': 'plots/baseline'}


In [22]:
# Loss: padding positions are excluded via ignore_index.
pad_index = vocab.word2idx[vocab.PAD_TOKEN]
criterion = nn.CrossEntropyLoss(ignore_index=pad_index)

# Optimizer: created with the configured learning rate, then overwritten with
# the state stored in the most recent checkpoint so training resumes from it.
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
optimizer.load_state_dict(checkpoint_last['optimizer_state_dict'])


In [23]:
import json

print("\n" + "="*50)
print("Starting training...")
print("="*50)

# Per-epoch metric history for this session (used for plotting afterwards).
train_losses = []
val_losses = []
train_perplexities = []
val_perplexities = []

# Resume: the score to beat comes from the best checkpoint, the epoch counter
# from the most recent one.
best_val_loss = checkpoint['val_loss']
initial_epoch = checkpoint_last['epoch']
final_epoch = initial_epoch + CONFIG['num_epochs']  # last epoch this session runs

# Per-epoch results are appended to a JSON list that persists across sessions.
file_path = r"results/training/training_results.json"
os.makedirs(os.path.dirname(file_path), exist_ok=True)

for epoch in range(initial_epoch + 1, final_epoch + 1):
    print(f"\n{'='*50}")
    # Show the true last epoch; the range above ends at final_epoch inclusive.
    print(f"Epoch {epoch}/{final_epoch}")
    print(f"{'='*50}")

    # Train
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device, epoch)
    train_ppl = calculate_perplexity(train_loss)

    # Validate
    val_loss = validate(model, val_loader, criterion, device)
    val_ppl = calculate_perplexity(val_loss)

    # Store metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_perplexities.append(train_ppl)
    val_perplexities.append(val_ppl)

    print(f"\nEpoch {epoch} Results:")
    print(f"  Train Loss: {train_loss:.4f} | Train PPL: {train_ppl:.2f}")
    print(f"  Val Loss:   {val_loss:.4f} | Val PPL:   {val_ppl:.2f}")

    results = {
        "epoch": epoch,
        "train_loss": round(train_loss, 4),
        "train_ppl": round(train_ppl, 2),
        "val_loss": round(val_loss, 4),
        "val_ppl": round(val_ppl, 2)
    }

    # Step 1: Read existing content or initialize empty list
    if os.path.exists(file_path):
        with open(file_path, "r") as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                data = []  # If file is empty or invalid
    else:
        data = []

    # Step 2: Append new results
    data.append(results)

    # Step 3: Write back to file
    with open(file_path, "w") as f:
        json.dump(data, f, indent=4)

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'config': CONFIG
        }, "best_model.pt")
        print(f"  ✓ Saved best model (val_loss: {val_loss:.4f})")

    # Generate samples
    print(f"\n  Sample Generations:")
    for prompt in ["Once upon a time", "The little girl", "In the forest"]:
        generated = training_generate(model, prompt, vocab, max_length=30, device=device)
        print(f"    '{prompt}' → {generated}")


Starting training...

Epoch 8/13


Epoch 8 - Training:   0%|          | 515/320367 [00:18<3:11:20, 27.86it/s, loss=2.1003]


KeyboardInterrupt: 

In [None]:
# Plot training curves
print("\n" + "="*50)
print("Plotting training curves...")
print("="*50)
plot_training_curves(train_losses, val_losses, train_perplexities,
                     val_perplexities, CONFIG['plot_dir'])


# Visualize Attention
print("\n" + "="*50)
print("Generating attention visualizations...")
print("="*50)

sample_texts = [
    "Once upon a time there was a little girl",
    "The cat sat on the mat and looked around",
    "A boy went to the park to play"
]

# One output directory per sample sentence.
for i, text in enumerate(sample_texts):
    visualize_attention(model, text, vocab, device,
                        save_path=f"results/training/attention_sample_{i+1}")

print("\n" + "="*50)
print("Training complete!")
print("="*50)
# The training loop writes the best checkpoint to "best_model.pt" in the
# working directory (not under CONFIG['save_dir']), so report that path.
print("Best model saved to: best_model.pt")
print(f"Plots saved to: {CONFIG['plot_dir']}/")

# Inference Evaluation

In [None]:
from inference.inference import evaluate_model

# Evaluate on the notebook-wide `device` (which falls back to CPU when no GPU
# is available and is where `model` already lives) instead of a fixed 'cuda'.
results = evaluate_model(
    model=model,
    val_dataset=val_dataset,
    vocab=vocab,
    num_samples=50,
    prompt_length=5,
    max_generation_length=50,
    device=device
)

print("\n" + "="*50)
print("EVALUATION RESULTS")
print("="*50)
print(f"Number of samples evaluated: {results['num_samples_evaluated']}")
print(f"Average Perplexity: {results['avg_perplexity']:.4f}")
print(f"BLEU Score: {results['bleu_score']:.6f}")
print("="*50)

In [None]:
# Persist the latest state so a later session can resume from it.
# NOTE(review): the recorded epoch assumes the training loop ran to completion;
# train_loss / val_loss are whatever the last finished epoch left behind.
last_checkpoint_state = {
    'epoch': initial_epoch + CONFIG['num_epochs'],
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_loss': train_loss,
    'val_loss': val_loss,
    'config': CONFIG,
}
torch.save(last_checkpoint_state, "last_model.pt")