# Nano GPT: From Scratch Implementation

This notebook implements a GPT (Generative Pre-trained Transformer) model from scratch using PyTorch. The model is trained on song lyrics data and can generate text conditioned on song titles.

## Features
- Full GPT architecture implementation (multi-head attention, feed-forward networks, layer normalization)
- SentencePiece tokenization
- Mixed precision training with gradient scaling
- Learning rate scheduling with warmup
- Model checkpointing

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.amp import GradScaler, autocast
import sentencepiece as spm
import random
import time

In [None]:
# Configuration
torch.set_float32_matmul_precision('high')
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
scaler = GradScaler('cuda' if device == 'cuda' else 'cpu')

print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"Number of CUDA devices: {torch.cuda.device_count()}")
    print(f"Current CUDA device: {torch.cuda.current_device()}")

## Model Configuration

In [None]:
# --- Run mode ---
init_from = 'scratch'  # 'scratch' starts a fresh model; 'resume' continues from a checkpoint

# --- Architecture ---
n_layers = 8          # number of transformer blocks stacked in the model
n_heads = 8           # attention heads per block
embedding_dim = 640   # width of token / positional embeddings
context_length = 128  # number of tokens in each training window
vocab_size = 50304    # placeholder; overwritten with the tokenizer's size once it is loaded

# --- Regularisation ---
dropout = 0.1
weight_decay = 1e-2

# --- Optimisation schedule ---
batch_size = 8
learning_rate = 2e-5  # peak learning rate, reached at the end of warmup
warmup_steps = 200    # steps of linear warmup before cosine annealing takes over
max_iters = 3000
eval_interval = 100   # estimate train/val loss every this many steps
eval_iters = 50       # batches averaged per loss estimate

## Tokenization Setup

In [None]:
# Load the SentencePiece tokenizer
# Note: Update the path to your tokenizer model file
tokenizer_path = 'm.model'  # Update this path as needed

sp = spm.SentencePieceProcessor()
sp.load(tokenizer_path)


# Named functions instead of lambdas bound to names: they get a real
# __name__ in tracebacks and a place for documentation.
def encode(s):
    """Return the list of SentencePiece token ids for the string `s`."""
    return sp.encode_as_ids(s)


def decode(l):
    """Return the text decoded from the list of token ids `l`."""
    return sp.decode_ids(l)


# The tokenizer, not the placeholder constant, defines the real vocabulary size
vocab_size = sp.GetPieceSize()
print(f"Vocabulary size: {vocab_size}")


## Data Loading and Preprocessing

In [None]:
# Load and preprocess song lyrics data
# Note: Update the path to your data file
data_path = 'album-song-lyrics.json'  # Update this path as needed

full_df = pd.read_json(data_path)
# Only the 'Songs' column is needed; select it directly instead of a row-wise apply
albums = full_df['Songs']

# Flatten every album's song records into parallel lists of lyrics and titles.
# A song's lyrics are the 'Text' fields of its 'Lyrics' entries, one per line.
songs = []
titles = []
for album in albums:
    for song in album:
        songs.append('\n'.join(line['Text'] for line in song['Lyrics']))
        titles.append(song['Title'])

# Create DataFrame and tokenize both columns once, up front
lyrics_df = pd.DataFrame({'title': titles, 'lyrics': songs})
lyrics_df['encoded_lyrics'] = lyrics_df['lyrics'].apply(encode)
lyrics_df['encoded_title'] = lyrics_df['title'].apply(encode)

# Split into train/val sets: the first 90% of songs train, the rest validate.
# Each split is then shuffled with a fixed seed so the row order is reproducible.
n = int(len(lyrics_df) * 0.9)
train_data = lyrics_df[:n].sample(frac=1, random_state=42).reset_index(drop=True)
val_data = lyrics_df[n:].sample(frac=1, random_state=42).reset_index(drop=True)

print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")



In [None]:
# Summarise how long the training songs are, measured in lyric tokens
print("Song length statistics (in tokens):")
lyric_token_counts = train_data.apply(lambda row: len(row['encoded_lyrics']), axis=1)
print(lyric_token_counts.describe())

In [None]:
# List the title of every song held out for validation
print("Validation set songs:")
for title in val_data['title']:
    print(f"  - {title}")

In [None]:
def get_batch(split):
    """
    Generate a batch of (input, target) token windows.

    Each song is tokenized as: [BOS] "Title: " <title> newline "Lyrics: " <lyrics>,
    with every piece encoded separately. One random window of `context_length`
    tokens is taken per song; targets are the inputs shifted left by one token.

    Songs with `context_length` tokens or fewer cannot fill a whole window.
    Their inputs are right-padded with token id 0 and the matching targets are
    set to -100, the default `ignore_index` of `F.cross_entropy`, so padded
    positions add nothing to the loss. Without this the per-song tensors have
    different lengths and `torch.stack` raises.

    Args:
        split: 'train' samples from `train_data`; any other value samples
            from `val_data`.

    Returns:
        Tuple `(x, y)` of long tensors on `device`, each of shape
        (batch_size, context_length).
    """
    data = train_data if split == 'train' else val_data
    # Sample with replacement only when the split has fewer songs than a batch
    batch = data.sample(n=batch_size, replace=len(data) < batch_size)

    # These prefixes are the same for every song, so encode them once per call
    title_prefix = encode("Title: ")
    lyrics_prefix = encode("\nLyrics: ")

    window = context_length + 1  # inputs plus one extra token for the shifted targets
    pad_input_id = 0             # any valid id works: padded positions are never scored
    ignore_target_id = -100      # default ignore_index of F.cross_entropy

    x = []
    y = []
    for _, song in batch.iterrows():
        # Format: BOS token (2) + "Title: " + title + "\nLyrics: " + lyrics
        tokens = [2] + title_prefix + song['encoded_title'] + lyrics_prefix + song['encoded_lyrics']

        # Sample a random context window
        max_idx = max(0, len(tokens) - window)
        idx = random.randint(0, max_idx)
        chunk = tokens[idx:idx + window]

        inputs = chunk[:context_length]
        targets = chunk[1:]

        # Right-pad short songs so every row has exactly context_length entries
        inputs = inputs + [pad_input_id] * (context_length - len(inputs))
        targets = targets + [ignore_target_id] * (context_length - len(targets))

        x.append(torch.tensor(inputs, device=device, dtype=torch.long))
        y.append(torch.tensor(targets, device=device, dtype=torch.long))

    return torch.stack(x), torch.stack(y)



In [None]:
# Smoke-test batch generation by decoding the first validation sequence back to text
x, y = get_batch('val')
first_sequence_ids = x[0].tolist()
print("Sample batch (first sequence):")
print(decode(first_sequence_ids))

## Model Architecture

In [None]:
class Head(nn.Module):
    """Single causal self-attention head implementing scaled dot-product attention.

    Args:
        head_size: Output width of the key, query and value projections.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.value = nn.Linear(embedding_dim, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Causal mask: lower triangular matrix
        self.register_buffer('tril', torch.tril(torch.ones(context_length, context_length)))

    def forward(self, x):
        """Attend over `x` of shape (B, T, embedding_dim); return (B, T, head_size)."""
        _, T, _ = x.shape
        k = self.key(x)
        q = self.query(x)

        # Scale by 1/sqrt(d_k), where d_k is the key width of THIS head. Scaling
        # by the full embedding width shrinks the scores more than intended
        # whenever there is more than one head.
        wei = (q @ k.transpose(-2, -1)) * (k.shape[-1] ** -0.5)
        # Block attention to future positions, then normalise each row
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v = self.value(x)
        out = wei @ v
        return out

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: run several heads on the same input and mix their outputs.

    Args:
        n_heads: Number of attention heads to create.
        head_size: Output width of each head.
    """

    def __init__(self, n_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_heads)])
        # The concatenated head outputs are n_heads * head_size wide. That only
        # equals embedding_dim when embedding_dim is divisible by n_heads, so
        # size the projection from what it actually receives.
        self.proj = nn.Linear(n_heads * head_size, embedding_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Concatenate every head's output on the last axis, then project and apply dropout."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out

In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP: expand to 4x the embedding width, ReLU, project back, dropout."""

    def __init__(self):
        super().__init__()
        hidden_dim = 4 * embedding_dim
        layers = [
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embedding_dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to `x` and return the result."""
        out = self.net(x)
        return out

In [None]:
class Block(nn.Module):
    """One transformer layer: self-attention then feed-forward, each on a residual path."""

    def __init__(self, n_heads):
        super().__init__()
        # Split the embedding width evenly across the heads
        self.sa = MultiHeadAttention(n_heads, embedding_dim // n_heads)
        self.ffwd = FeedForward()
        self.ln1 = nn.LayerNorm(embedding_dim)
        self.ln2 = nn.LayerNorm(embedding_dim)

    def forward(self, x):
        """Normalise before each sub-layer and add the sub-layer output back onto its input."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))

## Model Initialization

In [None]:
class Gpt(nn.Module):
    """GPT model with token embeddings, positional embeddings, and transformer blocks."""

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, embedding_dim)
        self.positional_embedding_table = nn.Embedding(context_length, embedding_dim)
        self.blocks = nn.Sequential(*[Block(n_heads) for _ in range(n_layers)])
        self.ln = nn.LayerNorm(embedding_dim)
        self.lm = nn.Linear(embedding_dim, vocab_size)
        # Position indices 0..context_length-1, kept as a buffer so they follow the module's device
        self.register_buffer('pos_ids', torch.arange(context_length, device=device))

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: Long tensor of token ids, shape (B, T).
            targets: Optional long tensor of target ids, shape (B, T).

        Returns:
            `(logits, loss)`. Without targets, logits has shape (B, T, vocab)
            and loss is None. With targets, logits is flattened to
            (B*T, vocab) and loss is the cross-entropy over all positions.
        """
        B, T = idx.shape

        # Token and positional embeddings
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.positional_embedding_table(self.pos_ids[:T])
        x = tok_emb + pos_emb

        # Transformer blocks, final layer norm, projection to the vocabulary
        x = self.blocks(x)
        x = self.ln(x)
        logits = self.lm(x)

        if targets is None:
            return logits, None

        # Flatten batch and time so cross_entropy sees one prediction per row
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Append `max_new_tokens` sampled tokens to `idx` and return the result.

        Runs with gradients disabled: sampling needs no autograd graph, and
        building one per step wastes memory when the caller forgets `no_grad`.

        Args:
            idx: Long tensor of prompt token ids, shape (B, T).
            max_new_tokens: Number of tokens to sample.

        Returns:
            Long tensor of shape (B, T + max_new_tokens).
        """
        # Crop to the size of this model's positional table rather than a global
        # that may have been changed since the model was built
        max_context = self.positional_embedding_table.num_embeddings
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -max_context:]
            logits, _ = self(idx_cond)
            # Only the last time step predicts the next token
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx



In [None]:
# Initialize model
model = Gpt()
m = model.to(device)

# Load pre-trained weights if resuming
if init_from == 'resume':
    checkpoint_path = 'checkpoint.pt'  # Update this path as needed
    m.load_state_dict(torch.load(checkpoint_path, weights_only=True, map_location=device))
    print("Loaded model from checkpoint")
else:
    print("Initializing model from scratch")

# Count parameters
total_params = sum(p.numel() for p in m.parameters())
trainable_params = sum(p.numel() for p in m.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## Training Setup

In [None]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the train and validation splits.

    Averages the loss over `eval_iters` freshly sampled batches per split,
    with the model in eval mode. The mode the model was in on entry is
    restored on exit, including when evaluation raises, so the call neither
    forces train mode nor leaves the model stuck in eval mode.

    Returns:
        Dict with float entries 'train' and 'val'.
    """
    out = {}
    was_training = m.training
    m.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters, device=device)
            for k in range(eval_iters):
                x, y = get_batch(split)
                _, loss = m(x, y)
                losses[k] = loss.item()
            out[split] = losses.mean().item()
    finally:
        m.train(was_training)
    return out


## Training Loop

In [None]:
# Setup optimizer and learning rate scheduler.
# The scheduler is created while the optimizer still holds the peak learning
# rate, so cosine annealing starts from `learning_rate` once warmup ends.
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iters - warmup_steps)

# Initialize learning rate to 0 for warmup
for param_group in optimizer.param_groups:
    param_group['lr'] = 0

# Resume from checkpoint if needed
if init_from == 'resume':
    checkpoint_path = 'checkpoint.pt'  # Update this path as needed
    checkpoint = torch.load(checkpoint_path, map_location=device)
    m.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scaler.load_state_dict(checkpoint['scaler'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    # A checkpoint is written after the step stored under 'iter' has finished,
    # so training continues with the next step rather than repeating that one
    start_iter = checkpoint['iter'] + 1
    print(f"Resuming from iteration {start_iter}")
else:
    start_iter = 0

# Training loop
checkpoint_interval = 1000  # Save checkpoint every N iterations
checkpoint_path = 'checkpoint.pt'  # Update this path as needed

# Resolve the autocast dtype once instead of at every call site
autocast_dtype = torch.bfloat16 if dtype == 'bfloat16' else torch.float16

# `step` rather than `iter`, so the builtin iter() is not shadowed. It is
# initialised here so the final report works even if the loop body never runs.
step = start_iter
for step in range(start_iter, max_iters):
    # Evaluation
    if step % eval_interval == 0:
        with torch.autocast(device_type=device, dtype=autocast_dtype):
            losses = estimate_loss()
        current_lr = optimizer.param_groups[0]['lr']
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}, lr: {current_lr:.2e}")

    # Training step
    x, y = get_batch('train')

    with torch.autocast(device_type=device, dtype=autocast_dtype):
        logits, loss = m(x, y)

    optimizer.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # Learning rate scheduling: the value set here applies to the NEXT step
    if step < warmup_steps:
        # Linear warmup
        lr = learning_rate * (step + 1) / warmup_steps
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    else:
        scheduler.step()

    # Save checkpoint (the 'iter' key is kept for compatibility with existing files)
    if step > 0 and step % checkpoint_interval == 0:
        torch.save({
            'model': m.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scaler': scaler.state_dict(),
            'scheduler': scheduler.state_dict(),
            'iter': step
        }, checkpoint_path)
        print(f"Checkpoint saved at iteration {step}")

# Final evaluation, under the same autocast context as the periodic ones so the
# reported numbers are comparable
with torch.autocast(device_type=device, dtype=autocast_dtype):
    losses = estimate_loss()
print(f"\nFinal - step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")


## Text Generation and Model Saving

In [None]:
# Generate text from a prompt
m.eval()
title = "A Guitar"
# Build the prompt from the same separately-encoded pieces that get_batch uses
# for training: BOS (2) + "Title: " + title + "\nLyrics: ". Encoding one
# combined string with different whitespace may tokenize differently from what
# the model was trained on.
prompt_ids = [2] + encode("Title: ") + encode(title) + encode("\nLyrics: ")
# A batch containing the single prompt sequence
context = torch.tensor([prompt_ids], device=device, dtype=torch.long)

with torch.no_grad():
    generated = m.generate(context, max_new_tokens=200)

output = decode(generated[0].tolist())
print("Generated text:")
print(output)
print("\n" + "="*50 + "\n")
print("\n" + "="*50 + "\n")

In [None]:

# Write the model's state_dict to disk (weights only; no optimizer or scheduler state)
model_save_path = 'gpt_model.pth'  # Update this path as needed
final_weights = m.state_dict()
torch.save(final_weights, model_save_path)
print(f"Model saved to {model_save_path}")
