In [None]:
!pip install evaluate gensim

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Collecting gensim
  Downloading gensim-4.3.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.1 kB)
Collecting numpy>=1.17 (from evaluate)
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting scipy<1.14.0,>=1.7.0 (from gensim)
  Downloading scipy-1.13.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.6/60.6 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gensim-4.3.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.w

In [None]:
import copy
import hashlib
import os , re , json
import pickle
import time
from collections import defaultdict , Counter

import torch , math
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets import load_dataset
from gensim.models import KeyedVectors
import gensim.downloader as api
import matplotlib.pyplot as plt
import numpy as np


In [None]:
def load_fasttext_model(model_path='./fasttext_model.bin'):
    """Load the locally cached FastText vectors.

    Args:
        model_path: File to hand to ``KeyedVectors.load``. Defaults to the
            location this notebook has always used.

    Returns:
        Whatever ``KeyedVectors.load`` returns for ``model_path``.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist. Checked up front
            so the message says which file is missing; this function never
            downloads the vectors itself.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"FastText cache not found at '{model_path}'. "
            "Place the saved KeyedVectors file there or pass model_path explicitly."
        )

    print("Loading FastText model from cache...")
    model = KeyedVectors.load(model_path)
    print("Model loaded successfully!")

    return model

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learned scale and shift."""

    def __init__(self, normalized_shape, eps=1e-5):
        """Create the scale (``gamma``, ones) and shift (``beta``, zeros) parameters.

        Args:
            normalized_shape: Shape of the two learnable parameters.
            eps: Constant added to the variance before the square root.
        """
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(normalized_shape))
        self.beta = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x):
        """Normalise ``x`` along its last axis, then apply ``gamma`` and ``beta``."""
        reduce_last = dict(dim=-1, keepdim=True)
        centered = x - x.mean(**reduce_last)
        # Population variance (unbiased=False), i.e. divide by N rather than N - 1.
        std = torch.sqrt(x.var(unbiased=False, **reduce_last) + self.eps)
        return self.gamma * (centered / std) + self.beta


In [None]:

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table that is added to the input."""

    def __init__(self, d_model, max_len=5000):
        """Precompute a ``(max_len, d_model)`` table and register it as buffer ``pe``.

        Even columns hold sines and odd columns hold cosines of
        ``position * 10000^(-2i / d_model)``.

        Args:
            d_model: Width of the table (last dimension of the inputs).
            max_len: Number of positions to precompute.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine column fewer than sine
        # columns, so trim the angles to fit instead of failing on a shape
        # mismatch. For even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the first ``x.size(1)`` rows of the table.

        Raises:
            ValueError: If the sequence (dimension 1 of ``x``) is longer than
                the precomputed table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"Sequence length {seq_len} exceeds the positional table size {self.pe.size(0)}"
            )
        return x + self.pe[:seq_len, :]


In [None]:

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, d_model, num_heads, dropout=0.1):
        """Build the query/key/value/output projections.

        Args:
            d_model: Input and output feature width.
            num_heads: Number of heads; must divide ``d_model`` evenly.
            dropout: Dropout probability applied to the attention weights.

        Raises:
            ValueError: If ``d_model`` is not a multiple of ``num_heads``.
        """
        super().__init__()
        # Without this check a bad pair only surfaces later as an opaque
        # ``view`` size error inside forward().
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(batch, seq_len, d_model)``.

        Args:
            x: Input tensor; its three dimensions are unpacked as
                batch size, sequence length and feature width.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, seq_len, seq_len)``. Positions where it
                equals 0 are excluded from attention.

        Returns:
            ``(output, attn_weights)``: ``output`` has the shape of ``x``;
            ``attn_weights`` is ``(batch, num_heads, seq_len, seq_len)`` and is
            the post-dropout tensor.
        """
        batch_size, seq_len, d_model = x.shape
        # Split the feature axis into heads: (batch, heads, seq_len, d_k).
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # NOTE: a row that is masked everywhere becomes all -inf and
            # softmax then yields NaN; a causal mask always keeps the diagonal.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context = torch.matmul(attn_weights, V)
        # Merge the heads back into a single feature axis.
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        output = self.W_o(context)
        return output, attn_weights



In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        """Create the expanding (``d_model -> d_ff``) and contracting layers.

        Args:
            d_model: Input and output feature width.
            d_ff: Hidden width between the two linear layers.
            dropout: Dropout probability applied after the ReLU.
        """
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        hidden = self.dropout(torch.relu(self.linear1(x)))
        return self.linear2(hidden)


In [None]:
class TransformerBlock(nn.Module):
    """Post-norm transformer layer: self-attention then feed-forward, each with a residual."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        """Assemble the attention and feed-forward sub-layers with their norms.

        Args:
            d_model: Feature width shared by every sub-layer.
            num_heads: Number of attention heads.
            d_ff: Hidden width of the feed-forward sub-layer.
            dropout: Dropout probability, used inside both sub-layers and on
                each sub-layer output before the residual addition.
        """
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = LayerNorm(d_model)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run one layer and return ``(hidden_states, attention_weights)``."""
        attended, weights = self.attention(x, mask)
        after_attention = self.norm1(x + self.dropout(attended))

        transformed = self.feed_forward(after_attention)
        after_ffn = self.norm2(after_attention + self.dropout(transformed))
        return after_ffn, weights


In [None]:

class DecoderTransformer(nn.Module):
    """Causally masked transformer stack that maps token ids to vocabulary logits."""

    def __init__(self, vocab_size, d_model, num_layers, num_heads,
                 d_ff, max_seq_len, dropout=0.1, pretrained_embeddings=None):
        """Build the embedding, the stack of blocks and the output head.

        Args:
            vocab_size: Number of rows in the embedding and of output logits.
            d_model: Width used by the transformer blocks.
            num_layers: Number of ``TransformerBlock`` layers.
            num_heads: Attention heads per block.
            d_ff: Feed-forward hidden width per block.
            max_seq_len: Number of positions in the positional table.
            dropout: Dropout probability.
            pretrained_embeddings: Optional ``(vocab_size, dim)`` tensor copied
                into the embedding table. When given, the table keeps that
                ``dim`` and a linear layer projects it to ``d_model``.
        """
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        has_pretrained = pretrained_embeddings is not None

        # The table keeps the pretrained width (300 for FastText) when vectors
        # are supplied; otherwise it is created directly at d_model.
        if has_pretrained:
            self.embedding_dim = pretrained_embeddings.shape[1]
        else:
            self.embedding_dim = d_model

        self.embedding = nn.Embedding(vocab_size, self.embedding_dim)

        if has_pretrained:
            self.embedding.weight.data.copy_(pretrained_embeddings)
            # Bridge from the pretrained width to d_model.
            self.embedding_proj = nn.Linear(self.embedding_dim, d_model)
        else:
            # Widths already agree, so nothing to project.
            self.embedding_proj = nn.Identity()

        self.pos_encoding = PositionalEncoding(d_model, max_seq_len)
        blocks = [
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.norm = LayerNorm(d_model)
        self.output_projection = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def create_causal_mask(self, seq_len, device):
        """Return a ``(1, 1, seq_len, seq_len)`` lower-triangular mask of ones."""
        lower = torch.ones(seq_len, seq_len, device=device).tril()
        return lower[None, None, :, :]

    def forward(self, x, return_attention=False):
        """Compute logits for a ``(batch, seq_len)`` tensor of token ids.

        Args:
            x: Token ids.
            return_attention: When true, also return the per-layer attention
                weights as a list.

        Returns:
            ``logits`` of shape ``(batch, seq_len, vocab_size)``, or
            ``(logits, attention_weights)`` when ``return_attention`` is set.
        """
        _, seq_len = x.shape
        causal_mask = self.create_causal_mask(seq_len, x.device)

        # Scale embeddings by sqrt of their own width, then project to d_model.
        hidden = self.embedding(x) * math.sqrt(self.embedding_dim)
        hidden = self.embedding_proj(hidden)
        hidden = self.dropout(self.pos_encoding(hidden))

        collected = []
        for block in self.layers:
            hidden, block_attention = block(hidden, causal_mask)
            if return_attention:
                collected.append(block_attention)

        logits = self.output_projection(self.norm(hidden))
        return (logits, collected) if return_attention else logits



In [None]:
class Vocabulary:
    """Word-level vocabulary with optional pretrained word vectors.

    The four special tokens are registered first, so they always receive the
    ids 0 (``<pad>``), 1 (``<sos>``), 2 (``<eos>``) and 3 (``<unk>``).
    """

    def __init__(self, fasttext_model=None):
        """Create a vocabulary holding only the special tokens.

        Args:
            fasttext_model: Optional mapping-like object supporting
                ``word in model`` and ``model[word]``; used only by
                ``create_embedding_matrix``.
        """
        self.word2idx = {}
        self.idx2word = {}
        self.word_counts = Counter()
        self.PAD_TOKEN = '<pad>'
        self.SOS_TOKEN = '<sos>'
        self.EOS_TOKEN = '<eos>'
        self.UNK_TOKEN = '<unk>'
        self.add_word(self.PAD_TOKEN)
        self.add_word(self.SOS_TOKEN)
        self.add_word(self.EOS_TOKEN)
        self.add_word(self.UNK_TOKEN)
        self.fasttext_model = fasttext_model

    def add_word(self, word):
        """Register ``word`` if it is new and increment its occurrence count."""
        if word not in self.word2idx:
            idx = len(self.word2idx)
            self.word2idx[word] = idx
            self.idx2word[idx] = word
        self.word_counts[word] += 1

    def __len__(self):
        """Number of distinct tokens, special tokens included."""
        return len(self.word2idx)

    def encode(self, text):
        """Tokenize ``text`` and map each token to its id (``<unk>`` if unseen)."""
        unk_idx = self.word2idx[self.UNK_TOKEN]
        return [self.word2idx.get(token, unk_idx) for token in self.tokenize(text)]

    def decode(self, indices):
        """Turn ids back into a space-joined string.

        ``<pad>`` and ``<sos>`` are skipped and decoding stops at the first
        ``<eos>``. ``indices`` may contain Python ints, NumPy integers or 0-d
        tensors: each entry is converted with ``int()`` first, because tensor
        objects hash by identity and would otherwise never match a dictionary
        key (every word would decode as ``<unk>``).
        """
        skipped = (self.word2idx[self.PAD_TOKEN], self.word2idx[self.SOS_TOKEN])
        eos_idx = self.word2idx[self.EOS_TOKEN]
        words = []
        for idx in indices:
            idx = int(idx)
            if idx in skipped:
                continue
            if idx == eos_idx:
                break
            words.append(self.idx2word.get(idx, self.UNK_TOKEN))
        return ' '.join(words)

    def tokenize(self, text):
        """Lower-case ``text`` and split it into word runs and ``. , ! ? ;`` marks."""
        text = text.lower()
        tokens = re.findall(r'\b\w+\b|[.,!?;]', text)
        return tokens

    def create_embedding_matrix(self, embedding_dim=None):
        """Build a ``(len(vocab), dim)`` matrix of initial embedding weights.

        Every row starts as small Gaussian noise (std 0.01); rows whose word is
        present in ``fasttext_model`` are overwritten with its vector.

        Args:
            embedding_dim: Width of the matrix. When omitted it is read from
                ``fasttext_model.vector_size`` if that attribute exists and
                falls back to 300 otherwise.

        Returns:
            A float tensor with one row per vocabulary entry.
        """
        if embedding_dim is None:
            embedding_dim = getattr(self.fasttext_model, 'vector_size', 300)
        embedding_matrix = torch.randn(len(self.word2idx), embedding_dim) * 0.01
        if self.fasttext_model is not None:
            found = 0
            for word, idx in self.word2idx.items():
                if word in self.fasttext_model:
                    embedding_matrix[idx] = torch.tensor(self.fasttext_model[word])
                    found += 1
            print(f"Found {found}/{len(self.word2idx)} words in FastText")
        return embedding_matrix

    def save(self, path):
        """Write the mappings and counts to ``path`` as JSON."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump({
                'word2idx': self.word2idx,
                'idx2word': {int(k): v for k, v in self.idx2word.items()},
                'word_counts': dict(self.word_counts)
            }, f)

    @classmethod
    def load(cls, path, fasttext_model=None):
        """Rebuild a vocabulary from a file written by ``save``.

        JSON stores the ``idx2word`` keys as strings, so they are converted
        back to ints here.
        """
        vocab = cls(fasttext_model)
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        vocab.word2idx = data['word2idx']
        vocab.idx2word = {int(k): v for k, v in data['idx2word'].items()}
        vocab.word_counts = Counter(data['word_counts'])
        return vocab

    @classmethod
    def inference_load(cls, path):
        """Load a saved vocabulary without attaching any word-vector model."""
        return cls.load(path)



In [None]:
class TinyStoriesDataset(Dataset):
    """Fixed-length token windows built from a list of story texts.

    Each text is wrapped as ``<sos> tokens <eos>`` and one window of
    ``context_length + 1`` ids is emitted for every start position except the
    last, right-padded with ``<pad>`` when it runs past the end of the text.
    The resulting list is cached on disk as a pickle.
    """

    def __init__(self, texts, vocab, context_length, max_samples=None, cache_dir="/"):
        """Load the windows from cache or build and cache them.

        Args:
            texts: Indexable sequence of strings.
            vocab: Object exposing ``word2idx``, ``encode`` and the
                ``PAD_TOKEN`` / ``SOS_TOKEN`` / ``EOS_TOKEN`` attributes.
            context_length: Model context size; windows hold one extra id so
                inputs and shifted targets can be sliced from the same row.
            max_samples: If truthy, only the first ``max_samples`` texts are used.
            cache_dir: Directory for the pickle cache.
                NOTE(review): the default is the filesystem root, which is
                often not writable; callers should pass a project directory.
        """
        self.vocab = vocab
        self.context_length = context_length
        self.sequences = []

        # Create cache directory if it doesn't exist
        os.makedirs(cache_dir, exist_ok=True)

        # Create a unique cache filename based on dataset parameters
        cache_key = self._generate_cache_key(texts, context_length, max_samples)
        cache_file = os.path.join(cache_dir, f"dataset_{cache_key}.pkl")

        # Try to load from cache
        if os.path.exists(cache_file):
            print(f"Loading dataset from cache: {cache_file}")
            # SECURITY: pickle executes arbitrary code on load; only point
            # cache_dir at a location you control.
            with open(cache_file, 'rb') as f:
                self.sequences = pickle.load(f)
            print(f"Loaded {len(self.sequences)} sequences from cache")
        else:
            # Create dataset from scratch
            print("Preparing dataset...")
            pad_idx = vocab.word2idx[vocab.PAD_TOKEN]
            sos_idx = vocab.word2idx[vocab.SOS_TOKEN]
            eos_idx = vocab.word2idx[vocab.EOS_TOKEN]
            window = context_length + 1

            for idx, text in enumerate(tqdm(texts)):
                if max_samples and idx >= max_samples:
                    break

                tokens = [sos_idx] + vocab.encode(text) + [eos_idx]

                # One window per start position; the final position is skipped
                # because it would contain no target token.
                for i in range(len(tokens) - 1):
                    seq = tokens[i:i + window]

                    if len(seq) < window:
                        seq = seq + [pad_idx] * (window - len(seq))

                    self.sequences.append(seq)

            print(f"Created {len(self.sequences)} sequences")

            # Save to cache
            print(f"Saving dataset to cache: {cache_file}")
            with open(cache_file, 'wb') as f:
                pickle.dump(self.sequences, f)
            print("Dataset cached successfully")

    def _generate_cache_key(self, texts, context_length, max_samples):
        """Generate a 16-hex-character key for this dataset configuration.

        The key covers the number of texts, ``context_length``,
        ``max_samples``, a hash of the first and last text, and a fingerprint
        of the vocabulary. The vocabulary is included because cached rows
        store token ids: reusing them after the word-to-id mapping changed
        would silently feed the model wrong tokens.
        """
        key_str = f"{len(texts)}_{context_length}_{max_samples}"

        # Add hash of first and last text (to detect if texts changed)
        if len(texts) > 0:
            first_text_hash = hashlib.md5(texts[0].encode()).hexdigest()[:8]
            last_text_hash = hashlib.md5(texts[-1].encode()).hexdigest()[:8]
            key_str += f"_{first_text_hash}_{last_text_hash}"

        key_str += f"_{self._vocab_fingerprint()}"

        return hashlib.md5(key_str.encode()).hexdigest()[:16]

    def _vocab_fingerprint(self):
        """Short hash of the ``word -> id`` mapping of ``self.vocab``."""
        digest = hashlib.md5()
        for word, idx in sorted(self.vocab.word2idx.items()):
            digest.update(f"{word}\t{idx}\n".encode())
        return digest.hexdigest()[:8]

    def __len__(self):
        """Number of windows."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return window ``idx`` as a ``torch.long`` tensor."""
        return torch.tensor(self.sequences[idx], dtype=torch.long)


In [None]:
# Baseline experiment settings. d_ff is 4 * d_model, and d_model is a multiple
# of num_heads (296 / 8 = 37) so the attention heads split evenly.
CONFIG = dict(
    name='baseline',
    description='Standard baseline configuration from assignment',
    # Model
    context_length=64,
    num_layers=3,
    num_heads=8,
    d_model=296,
    d_ff=1184,
    dropout=0.1,
    # Optimisation
    batch_size=32,
    learning_rate=3e-4,
    num_epochs=4,
    # Data subset sizes (number of stories)
    max_train_samples=15000,
    max_val_samples=5000,
    # Output locations
    save_dir='checkpoints/baseline',
    plot_dir='plots/baseline',
)

In [None]:
os.makedirs(CONFIG['save_dir'], exist_ok=True)
os.makedirs(CONFIG['plot_dir'], exist_ok=True)

# Seed PyTorch so weight initialisation, dropout and DataLoader shuffling are
# repeatable across runs. CONFIG may override the default via a 'seed' entry.
torch.manual_seed(CONFIG.get('seed', 42))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"\nConfiguration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

# Load FastText
print("\n" + "="*50)
print("Loading FastText embeddings...")
print("="*50)
fasttext_model = load_fasttext_model()

# Load Dataset
print("\n" + "="*50)
print("Loading TinyStories dataset...")
print("="*50)
dataset = load_dataset("roneneldan/TinyStories")

# Pull the text subsets out once with a slice. Indexing the dataset row by row
# is much slower, and the training texts are needed for both the vocabulary
# and the training windows.
num_train = min(CONFIG['max_train_samples'], len(dataset['train']))
num_val = min(CONFIG['max_val_samples'], len(dataset['validation']))
train_texts = dataset['train'][:num_train]['text']
val_texts = dataset['validation'][:num_val]['text']

print("\n" + "="*50)
print("Building vocabulary...")
print("="*50)
vocab_path = f"{CONFIG['save_dir']}/vocab.json"

# NOTE: an existing vocab.json is reused as-is, even if max_train_samples has
# changed since it was written; delete the file to force a rebuild.
if os.path.exists(vocab_path):
    print("Loading existing vocabulary...")
    vocab = Vocabulary.load(vocab_path, fasttext_model)
else:
    vocab = Vocabulary(fasttext_model)
    # Build vocabulary from training data
    for text in tqdm(train_texts, desc="Building vocabulary"):
        for word in vocab.tokenize(text):
            vocab.add_word(word)
    vocab.save(vocab_path)

print(f"Vocabulary size: {len(vocab)}")

# Create Datasets
print("\n" + "="*50)
print("Creating datasets...")
print("="*50)

# Keep the pickled windows next to the checkpoints. The dataset class defaults
# to the filesystem root, which is usually not writable outside Colab.
dataset_cache_dir = os.path.join(CONFIG['save_dir'], 'dataset_cache')

train_dataset = TinyStoriesDataset(
    train_texts,
    vocab,
    CONFIG['context_length'],
    CONFIG['max_train_samples'],
    cache_dir=dataset_cache_dir
)

val_dataset = TinyStoriesDataset(
    val_texts,
    vocab,
    CONFIG['context_length'],
    CONFIG['max_val_samples'],
    cache_dir=dataset_cache_dir
)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'],
                          shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'],
                        shuffle=False, num_workers=0)

# Initialize Model
print("\n" + "="*50)
print("Initializing model...")
print("="*50)
embedding_matrix = vocab.create_embedding_matrix()

model = DecoderTransformer(
    vocab_size=len(vocab),
    d_model=CONFIG['d_model'],
    num_layers=CONFIG['num_layers'],
    num_heads=CONFIG['num_heads'],
    d_ff=CONFIG['d_ff'],
    max_seq_len=CONFIG['context_length'],
    dropout=CONFIG['dropout'],
    pretrained_embeddings=embedding_matrix
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
# Padding positions carry no information, so they are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx[vocab.PAD_TOKEN])

Using device: cuda

Configuration:
  name: baseline
  description: Standard baseline configuration from assignment
  context_length: 64
  num_layers: 3
  num_heads: 8
  d_model: 296
  d_ff: 1184
  dropout: 0.1
  batch_size: 32
  learning_rate: 0.0003
  num_epochs: 10
  max_train_samples: 15000
  max_val_samples: 5000
  save_dir: checkpoints/baseline
  plot_dir: plots/baseline

Loading FastText embeddings...

Loading TinyStories dataset...

Building vocabulary...
Loading existing vocabulary...
Vocabulary size: 9551

Creating datasets...
Preparing dataset...


100%|██████████| 15000/15000 [00:11<00:00, 1314.70it/s]


Created 3083375 sequences
Preparing dataset...


100%|██████████| 5000/5000 [00:01<00:00, 2778.66it/s]


Created 925828 sequences

Initializing model...
Found 9037/9551 words in FastText


In [None]:
def train_with_gradient_accumulation(
    model, dataloader, optimizer, criterion, device, accumulation_steps=1, epoch=1
):
    """Train for one epoch, stepping the optimizer every ``accumulation_steps`` batches.

    Each batch is split into inputs (all but the last column) and targets (all
    but the first column). The batch loss is divided by ``accumulation_steps``
    before ``backward`` so the accumulated gradient is the mean over the group.
    Gradients are clipped to norm 1.0 before every optimizer step.

    Args:
        model: Module mapping ``(batch, seq)`` ids to ``(batch, seq, vocab)`` logits.
        dataloader: Sized iterable of id tensors.
        optimizer: Optimizer over the model parameters.
        criterion: Loss taking flattened logits and targets.
        device: Device the batches are moved to.
        accumulation_steps: Number of batches per optimizer step (>= 1).
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        ``(avg_loss, perplexity)`` where ``avg_loss`` is the mean of the
        per-batch losses and ``perplexity`` is ``exp(avg_loss)``.

    Raises:
        ValueError: If ``accumulation_steps`` is below 1 or the dataloader is empty.
    """
    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; nothing to train on")

    model.train()
    total_loss = 0
    pending = 0  # batches whose gradients have not been applied yet
    optimizer.zero_grad()

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch} (accum={accumulation_steps})")

    for batch in progress_bar:
        batch = batch.to(device)
        inputs = batch[:, :-1]
        targets = batch[:, 1:]

        # Forward pass
        logits = model(inputs)
        loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
        batch_loss = loss.item()

        # Normalize loss by accumulation steps
        (loss / accumulation_steps).backward()
        pending += 1

        # Update weights every accumulation_steps
        if pending == accumulation_steps:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

        total_loss += batch_loss
        progress_bar.set_postfix({"loss": f"{batch_loss:.4f}"})

    # Handle remaining gradients. NOTE: this last partial group was still
    # divided by the full accumulation_steps, so its gradient is scaled down
    # by pending / accumulation_steps relative to a complete group.
    if pending:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

    avg_loss = total_loss / num_batches
    perplexity = np.exp(avg_loss)
    return avg_loss, perplexity


def evaluate_model(model, dataloader, criterion, device):
    """Evaluate the model without gradients and report loss and perplexity.

    The model is switched to eval mode and is left in that mode on return.

    Args:
        model: Module mapping ``(batch, seq)`` ids to ``(batch, seq, vocab)`` logits.
        dataloader: Sized iterable of id tensors; each row provides inputs
            (all but the last column) and targets (all but the first).
        criterion: Loss taking flattened logits and targets.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, perplexity)``: the mean of the per-batch losses and its
        exponential.

    Raises:
        ValueError: If the dataloader yields no batches (the average would
            otherwise be a division by zero).
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; nothing to evaluate")

    model.eval()
    total_loss = 0

    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            inputs = batch[:, :-1]
            targets = batch[:, 1:]

            logits = model(inputs)
            loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
            total_loss += loss.item()

    avg_loss = total_loss / num_batches
    perplexity = np.exp(avg_loss)
    return avg_loss, perplexity


def experiment_gradient_accumulation(
    model, train_loader, val_loader, optimizer, criterion, device,
    num_epochs=3, save_dir="results"
):
    """Train with several gradient-accumulation settings from the same start.

    For each setting in ``[1, 2, 4, 8]`` the model is trained for
    ``num_epochs`` epochs, evaluated after every epoch, and the metrics are
    written to ``{save_dir}/result_accum_{steps}.json``. The model and
    optimizer are then reset to the state they had when this function was
    called, so every setting starts from identical weights.

    Args:
        model: Model to train; it ends in its initial state.
        train_loader: Training DataLoader (its ``batch_size`` is used to
            report the effective batch size).
        val_loader: Validation DataLoader.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function.
        device: Device for the batches.
        num_epochs: Epochs per accumulation setting.
        save_dir: Directory for the JSON result files (created if missing).
    """
    print("\nGradient Accumulation Experiment...")

    os.makedirs(save_dir, exist_ok=True)

    accumulation_configs = [1, 2, 4, 8]

    # state_dict() returns references to the live tensors, which training
    # updates in place. Without a deep copy the "restore" below would reload
    # the already-trained weights and each setting would continue from the
    # previous one.
    initial_model_state = copy.deepcopy(model.state_dict())
    initial_opt_state = copy.deepcopy(optimizer.state_dict())

    for accum_steps in accumulation_configs:
        effective_batch_size = train_loader.batch_size * accum_steps

        print(f"\n{'='*50}")
        print(f"Accumulation Steps: {accum_steps}")
        print(f"Effective Batch Size: {effective_batch_size}")
        print(f"{'='*50}")

        epoch_times = []
        train_losses = []
        train_perplexities = []
        val_losses = []
        val_perplexities = []

        for epoch in range(1, num_epochs + 1):
            start_time = time.time()

            train_loss, train_ppl = train_with_gradient_accumulation(
                model,
                train_loader,
                optimizer,
                criterion,
                device,
                accumulation_steps=accum_steps,
                epoch=epoch,
            )

            # Evaluate on validation set
            val_loss, val_ppl = evaluate_model(model, val_loader, criterion, device)

            # Covers training plus validation for this epoch.
            epoch_time = time.time() - start_time
            epoch_times.append(epoch_time)
            train_losses.append(float(train_loss))
            train_perplexities.append(float(train_ppl))
            val_losses.append(float(val_loss))
            val_perplexities.append(float(val_ppl))

            print(f"Epoch {epoch}:")
            print(f"  Train - Loss: {train_loss:.4f}, Perplexity: {train_ppl:.2f}")
            print(f"  Val   - Loss: {val_loss:.4f}, Perplexity: {val_ppl:.2f}")
            print(f"  Time: {epoch_time:.2f}s")

        # Save results for this accumulation config. epoch_times and
        # effective_batch_size are included because the plotting cell reads them.
        result = {
            "accumulation_steps": accum_steps,
            "effective_batch_size": effective_batch_size,
            "num_epochs": num_epochs,
            "epoch_times": epoch_times,
            "train_losses": train_losses,
            "train_perplexities": train_perplexities,
            "val_losses": val_losses,
            "val_perplexities": val_perplexities,
            "best_train_loss": min(train_losses),
            "best_train_ppl": min(train_perplexities),
            "best_val_loss": min(val_losses),
            "best_val_ppl": min(val_perplexities)
        }

        save_path = f"{save_dir}/result_accum_{accum_steps}.json"
        with open(save_path, 'w') as f:
            json.dump(result, f, indent=2)

        print(f"✓ Saved: {save_path}")

        # Restore model state for fair comparison
        model.load_state_dict(initial_model_state)
        optimizer.load_state_dict(initial_opt_state)

    print(f"\n{'='*50}")
    print(f"All results saved to: {save_dir}/")
    print(f"{'='*50}\n")


# Run the accumulation sweep for a single epoch per setting; result files are
# written under "results".
experiment_gradient_accumulation(
    model, train_loader, val_loader, optimizer, criterion, device,
    num_epochs=1,
    save_dir="results",
)


Gradient Accumulation Experiment...

Accumulation Steps: 1
Effective Batch Size: 32


Epoch 1 (accum=1): 100%|██████████| 96356/96356 [1:00:25<00:00, 26.58it/s, loss=2.0688]


Epoch 1: Loss=2.1992, Time=3625.41s
✓ Saved: results/result_accum_1.json

Accumulation Steps: 2
Effective Batch Size: 64


Epoch 1 (accum=2):   2%|▏         | 1738/96356 [01:03<57:04, 27.63it/s, loss=1.9779]

In [None]:
import json
import os
import matplotlib.pyplot as plt
import numpy as np

# Load data from all accumulation step files
accumulation_steps = [1, 2, 4, 8]
results_dir = 'results'
data = {}

for step in accumulation_steps:
    # The experiment cell writes results/result_accum_<n>.json; the second
    # candidate is the older layout this cell used to read.
    candidates = [
        os.path.join(results_dir, f'result_accum_{step}.json'),
        os.path.join(results_dir, 'accumulation', f'accum_{step}.json'),
    ]
    filename = next((path for path in candidates if os.path.exists(path)), None)
    if filename is None:
        print(f"Warning: {candidates[0]} not found, skipping...")
        continue
    try:
        with open(filename, 'r') as f:
            data[step] = json.load(f)
        print(f"Loaded {filename}")
    except json.JSONDecodeError as e:
        print(f"Error decoding {filename}: {e}")

if not data:
    # exit() would shut the kernel down inside a notebook; raise instead so
    # the cell simply stops with a readable error.
    raise FileNotFoundError("No data files found!")

# Create figure with two subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Plot 1: Training Loss vs Accumulation Steps
steps_list = sorted(data.keys())
final_losses = [data[step]['train_losses'][-1] for step in steps_list]

ax1.plot(steps_list, final_losses, marker='o', linewidth=2.5,
         color='#3b82f6', markersize=10, markerfacecolor='#60a5fa',
         markeredgewidth=2, markeredgecolor='#3b82f6')

# Add value labels on each point
for step, loss in zip(steps_list, final_losses):
    ax1.annotate(f'{loss:.4f}',
                 xy=(step, loss),
                 xytext=(0, 10),
                 textcoords='offset points',
                 ha='center',
                 fontsize=10,
                 fontweight='bold',
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7))

ax1.set_xlabel('Accumulation Steps', fontsize=12, fontweight='bold')
ax1.set_ylabel('Training Loss', fontsize=12, fontweight='bold')
ax1.set_title('Training Loss vs Accumulation Steps', fontsize=14, fontweight='bold')
ax1.set_xticks(steps_list)
ax1.grid(True, alpha=0.3)
ax1.set_xlim(0.5, max(steps_list) + 0.5)

# Plot 2: Epoch Times vs Accumulation Steps
# Result files written before timing was recorded have no 'epoch_times', so
# only plot the steps that actually carry it.
timed_steps = [step for step in steps_list if data[step].get('epoch_times')]
epoch_times = [data[step]['epoch_times'][0] for step in timed_steps]

if timed_steps:
    ax2.plot(timed_steps, epoch_times, marker='s', linewidth=2.5,
             color='#10b981', markersize=10, markerfacecolor='#34d399',
             markeredgewidth=2, markeredgecolor='#10b981')

    # Add value labels on each point. The loop variable is deliberately not
    # called `time`, which would shadow the time module for later cells.
    for step, epoch_time in zip(timed_steps, epoch_times):
        ax2.annotate(f'{epoch_time:.1f}s',
                     xy=(step, epoch_time),
                     xytext=(0, 10),
                     textcoords='offset points',
                     ha='center',
                     fontsize=10,
                     fontweight='bold',
                     bbox=dict(boxstyle='round,pad=0.3', facecolor='lightblue', alpha=0.7))
else:
    ax2.text(0.5, 0.5, 'No epoch timing recorded', transform=ax2.transAxes,
             ha='center', va='center', fontsize=12)

ax2.set_xlabel('Accumulation Steps', fontsize=12, fontweight='bold')
ax2.set_ylabel('Epoch Time (seconds)', fontsize=12, fontweight='bold')
ax2.set_title('Training Time vs Accumulation Steps', fontsize=14, fontweight='bold')
ax2.set_xticks(steps_list)
ax2.grid(True, alpha=0.3)
ax2.set_xlim(0.5, max(steps_list) + 0.5)

# Adjust layout and save
plt.tight_layout()
output_filename = 'results/accumulation/comparison_plot.png'
os.makedirs(os.path.dirname(output_filename), exist_ok=True)
plt.savefig(output_filename, dpi=300, bbox_inches='tight')
print(f"\nPlot saved as: {output_filename}")

# Print summary statistics
print("\n" + "="*60)
print("SUMMARY STATISTICS")
print("="*60)
for step in steps_list:
    entry = data[step]
    recorded_times = entry.get('epoch_times')
    print(f"\nAccumulation Steps = {step}:")
    print(f"  Effective batch size: {entry.get('effective_batch_size', 'n/a')}")
    print(f"  Final training loss: {entry['train_losses'][-1]:.4f}")
    if recorded_times:
        print(f"  Epoch time: {recorded_times[0]:.2f} seconds")
    else:
        print("  Epoch time: n/a")
print("="*60)

plt.show()