In [1]:
import torch
import torch.nn as nn
import time
import math
import gc
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from dataclasses import dataclass
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import GPT2Config, GPT2LMHeadModel, AutoTokenizer
from datasets import load_dataset
# from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
# from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

# --- Configuration ---
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)

CONF = {
    "vocab_size": 50257,
    "context_length": 1024,
    "batch_size": 1,           # Adjust based on VRAM
    "grad_accum": 4,           # Effective batch size = 32
    "lr": 3e-4,
    "train_steps": 1000,       # Short run for demo (increase to >10k for real results)
    "eval_every": 500,
    "dataset_name": "DKYoon/SlimPajama-6B",
    "dataset_subset": "default", 
}

# --- Model Configs for ~370M Parameters ---
# GPT-2 Medium-like
gpt2_config = GPT2Config(
    vocab_size=CONF["vocab_size"],
    n_positions=CONF["context_length"],
    n_embd=1024,
    n_layer=24,
    n_head=16,
    bos_token_id=50256,
    eos_token_id=50256,
)

# Mamba-2 (approx 370M)
# d_model=1024, n_layer=48 is standard for ~370M in Mamba papers
mamba2_config_dict = {
    "d_model": 1024,
    "n_layer": 48,
    "vocab_size": CONF["vocab_size"],
    "ssm_cfg": {"layer": "Mamba2"}, 
    "rms_norm": True,
    "residual_in_fp32": True,
    "fused_add_norm": True,
    "pad_vocab_size_multiple": 8
}

In [2]:
# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def get_dataloader(split="train"):
    """Streams SlimPajama, tokenizes, and returns a DataLoader."""
    # Streaming=True avoids downloading the whole 20GB+ dataset
    dataset = load_dataset(CONF["dataset_name"], split=split, streaming=True)
    
    def tokenize_function(examples):
        return tokenizer(
            examples["text"], 
            truncation=True, 
            max_length=CONF["context_length"], 
            padding="max_length"
        )

    # Only keep input_ids to save memory
    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text", "meta"])
    tokenized_dataset = tokenized_dataset.with_format("torch")
    
    return DataLoader(tokenized_dataset, batch_size=CONF["batch_size"])

# Initialize Loaders
train_loader = get_dataloader("train")
lambada_dataset = load_dataset("cimec/lambada", split="validation")
print("Datasets initialized (Streaming Mode).")

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/48 [00:00<?, ?it/s]

Datasets initialized (Streaming Mode).


In [3]:
class GPT2Trainer:
    def __init__(self, config):
        self.model = GPT2LMHeadModel(config).to(device)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=CONF["lr"])
        # Store metrics to save later
        self.metrics = {
            "model_name": "GPT-2 (350M)",
            "loss": [], 
            "memory_gb": [], 
            "throughput_tok_sec": [], 
            "lambada_acc": 0.0,
            "inference_speed": 0.0
        }

    def count_parameters(self):
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    def train(self, steps):
        self.model.train()
        data_iter = iter(train_loader)
        progress_bar = tqdm(range(steps), desc="Training GPT-2")
        
        for i in progress_bar:
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                batch = next(data_iter)
            
            input_ids = batch["input_ids"].to(device)
            labels = input_ids.clone()
            
            # Tracking
            torch.cuda.reset_peak_memory_stats()
            t0 = time.time()
            
            # Forward & Backward
            outputs = self.model(input_ids, labels=labels)
            loss = outputs.loss / CONF["grad_accum"]
            loss.backward()
            
            if (i + 1) % CONF["grad_accum"] == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                
                # Metrics
                dt = time.time() - t0
                tokens_processed = input_ids.numel()
                mem_gb = torch.cuda.max_memory_allocated() / 1e9
                
                # Log actual loss (not scaled)
                current_loss = loss.item() * CONF["grad_accum"]
                self.metrics["loss"].append(current_loss)
                self.metrics["memory_gb"].append(mem_gb)
                self.metrics["throughput_tok_sec"].append(tokens_processed / dt)
                
                progress_bar.set_postfix({
                    "Loss": f"{current_loss:.4f}", 
                    "Mem": f"{mem_gb:.2f}GB"
                })

    def evaluate_lambada(self, n_samples=200):
        """Zero-shot evaluation on LAMBADA (Last Word Prediction)."""
        print("Running LAMBADA Evaluation...")
        self.model.eval()
        correct = 0
        total = 0
        
        for i, row in enumerate(lambada_dataset):
            if i >= n_samples: break
            
            full_text = row['text']
            split_text = full_text.split()
            context = " ".join(split_text[:-1])
            target_word = " " + split_text[-1]
            
            inputs = tokenizer(context, return_tensors="pt").to(device)
            target_id = tokenizer(target_word)["input_ids"][0]
            
            with torch.no_grad():
                outputs = self.model(inputs["input_ids"])
                pred_id = torch.argmax(outputs.logits[0, -1, :]).item()
            
            if pred_id == target_id:
                correct += 1
            total += 1
            
        acc = correct / total
        self.metrics["lambada_acc"] = acc
        print(f"LAMBADA Accuracy: {acc:.2%}")

    def measure_inference_speed(self, n_tokens=50):
        """Benchmarks generation speed."""
        print("Benchmarking Inference Speed...")
        self.model.eval()
        input_ids = torch.randint(0, CONF["vocab_size"], (1, 128)).to(device)
        
        # Warmup
        _ = self.model(input_ids)
        
        start = time.time()
        curr_ids = input_ids
        with torch.no_grad():
            for _ in range(n_tokens):
                outputs = self.model(curr_ids)
                next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(1)
                curr_ids = torch.cat([curr_ids, next_token], dim=1)
                
        duration = time.time() - start
        speed = n_tokens / duration
        self.metrics["inference_speed"] = speed
        print(f"Inference Speed: {speed:.2f} tokens/sec")

    def save_metrics(self):
        """Saves metrics to JSON for future comparison with Mamba."""
        with open(CONF["metrics_filename"], "w") as f:
            json.dump(self.metrics, f)
        print(f"Metrics saved to {CONF['metrics_filename']}")

In [4]:
# Initialize
trainer = GPT2Trainer(gpt2_config)
print(f"Model Parameters: {trainer.count_parameters():,}")

# 1. Train
trainer.train(CONF["train_steps"])

# 2. Evaluate
trainer.evaluate_lambada()
trainer.measure_inference_speed()

# 3. Save Data for later Mamba comparison
trainer.save_metrics()

# 4. Simple Visualization for GPT-2 only
def smooth(scalars, weight=0.9):
    last = scalars[0]
    smoothed = []
    for point in scalars:
        smoothed_val = last * weight + (1 - weight) * point
        smoothed.append(smoothed_val)
        last = smoothed_val
    return smoothed

plt.figure(figsize=(10, 5))
plt.plot(smooth(trainer.metrics["loss"]), label="GPT-2 Loss")
plt.title("GPT-2 Training Loss (Smoothed)")
plt.xlabel("Steps")
plt.ylabel("Cross Entropy Loss")
plt.legend()
plt.show()

Model Parameters: 354,823,168


Training GPT-2:   0%|          | 0/1000 [00:00<?, ?it/s]

`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


OutOfMemoryError: CUDA out of memory. Tried to allocate 394.00 MiB. GPU 0 has a total capacity of 11.64 GiB of which 70.88 MiB is free. Process 1634211 has 11.56 GiB memory in use. Of the allocated memory 10.90 GiB is allocated by PyTorch, and 544.09 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)