In [None]:
# Imports
import math
import os
import urllib.request

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
# Use the current CUDA device when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# ============================================================================
# DATA PREPARATION
# ============================================================================

# Download tiny shakespeare once. Fetching in Python (instead of `!wget`) works
# without wget installed. Skipping the download when the file already exists
# means re-running the cell neither re-fetches it nor piles up wget-style
# input.txt.1, input.txt.2, ... copies.
DATA_URL = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
DATA_PATH = 'input.txt'

if not os.path.exists(DATA_PATH):
    # Read the whole response before creating the file, so a failed download
    # does not leave a truncated input.txt that later runs would trust.
    with urllib.request.urlopen(DATA_URL, timeout=60) as response:
        payload = response.read()
    with open(DATA_PATH, 'wb') as fh:
        fh.write(payload)

# Explicit encoding so the result does not depend on the platform default.
with open(DATA_PATH, 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level tokenization: the vocabulary is every distinct character in
# the corpus, sorted so the char <-> index mapping is deterministic.
chars = sorted(set(text))
vocab_size = len(chars)
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = dict(enumerate(chars))


def encode(s):
    """Map a string to a list of token ids (raises KeyError on characters not in `chars`)."""
    return [char_to_idx[c] for c in s]


def decode(l):
    """Map an iterable of token ids back to a string."""
    return ''.join(idx_to_char[i] for i in l)


# Sanity check: the mapping must round-trip on real corpus text.
assert decode(encode(text[:1000])) == text[:1000]

# Encode dataset and split it contiguously: first 90% train, last 10% val.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

print(f"Vocab size: {vocab_size}")
print(f"Dataset size: {len(data):,} characters")
print(f"Train: {len(train_data):,}, Val: {len(val_data):,}")

In [None]:
# ============================================================================
# DATASET CLASS
# ============================================================================

class CharDataset(Dataset):
    """Sliding-window next-character dataset over a 1-D sequence of token ids.

    Item ``i`` is ``(data[i:i+block_size], data[i+1:i+block_size+1])``: an
    input window and the same window shifted one position ahead as targets.

    Args:
        data: sliceable 1-D sequence (e.g. a LongTensor) of token ids.
        block_size: window length.
    """

    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        # Each start index needs block_size + 1 tokens (input + shifted target).
        # Clamp at 0 so a split shorter than that is simply empty, instead of
        # len() raising on a negative value.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` window pair starting at ``idx``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``;
                without this check slicing would silently return short windows.
        """
        n_windows = len(self)
        if not -n_windows <= idx < n_windows:
            raise IndexError(f"index {idx} out of range for dataset of length {n_windows}")
        if idx < 0:
            idx += n_windows
        x = self.data[idx:idx + self.block_size]
        y = self.data[idx + 1:idx + self.block_size + 1]
        return x, y

In [None]:
# ============================================================================
# MODEL COMPONENTS
# ============================================================================

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input embeddings.

    ``pe[pos, 2i] = sin(pos * 10000^(-2i/d_model))`` and
    ``pe[pos, 2i+1] = cos(pos * 10000^(-2i/d_model))``, precomputed for
    ``max_len`` positions and stored as a buffer of shape (1, max_len, d_model).

    Args:
        d_model: embedding width (odd values are supported).
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so
        # only the first d_model // 2 frequencies are used for the cosines.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions to ``x``.

        Raises:
            ValueError: if ``x.size(1)`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of heads; each head has width ``d_model // num_heads``.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # A single Linear emits Q, K and V side by side along the last axis.
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Self-attend over ``x`` of shape (batch, seq_len, d_model).

        ``mask`` (optional) must broadcast against the score tensor of shape
        (batch, num_heads, seq_len, seq_len); entries equal to 0 are blocked.
        """
        B, T, C = x.shape

        # (B, T, 3*C) -> (B, T, 3, heads, d_k) -> (3, B, heads, T, d_k)
        fused = self.qkv_proj(x).reshape(B, T, 3, self.num_heads, self.d_k)
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(0)

        affinity = query @ key.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            affinity = affinity.masked_fill(mask == 0, float('-inf'))

        weights = self.dropout(affinity.softmax(dim=-1))

        # Mix values, then fold the heads back into the channel dimension.
        merged = (weights @ value).transpose(1, 2).reshape(B, T, C)
        return self.out_proj(merged)


class FeedForward(nn.Module):
    """Position-wise MLP: Linear(d_model -> d_ff), GELU, dropout, Linear(d_ff -> d_model)."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)   # expand
        self.linear2 = nn.Linear(d_ff, d_model)   # project back to model width
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.gelu(self.linear1(x))
        hidden = self.dropout(hidden)
        return self.linear2(hidden)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: ``x + Attn(LN1(x))`` then ``x + FFN(LN2(x))``.

    Each sub-layer's output passes through dropout before being added back to
    the residual stream.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run the attention then feed-forward residual sub-layers; ``mask`` is
        passed unchanged to the attention layer."""
        residual_steps = (
            (self.norm1, lambda h: self.attn(h, mask)),
            (self.norm2, self.ff),
        )
        for norm, sublayer in residual_steps:
            x = x + self.dropout(sublayer(norm(x)))
        return x


class GPT(nn.Module):
    """Decoder-only (GPT-style) language model.

    Token embeddings, scaled by sqrt(d_model), plus sinusoidal positional
    encodings feed a stack of ``TransformerBlock``s. A final LayerNorm and a
    linear head then produce logits over the vocabulary.

    Args:
        vocab_size: number of distinct token ids.
        d_model: embedding / residual-stream width.
        num_heads: attention heads per block.
        num_layers: number of transformer blocks.
        d_ff: hidden width of each block's feed-forward layer.
        max_len: longest sequence ``forward`` accepts; also the context window
            ``generate`` crops to.
        dropout: dropout probability used throughout.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len

        # Token embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        # Output layer
        self.norm = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

        self.dropout = nn.Dropout(dropout)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) and Linear biases to zero."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            x: LongTensor of token ids, shape (batch, seq_len), seq_len <= max_len.
            targets: optional LongTensor of next-token ids with the same shape.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (batch, seq_len, vocab_size).
            ``loss`` is the mean cross-entropy over all positions, or None when
            ``targets`` is None.

        Raises:
            ValueError: if seq_len exceeds ``max_len``.
        """
        _, seq_len = x.shape
        if seq_len > self.max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.max_len}")

        # Lower-triangular causal mask (1 = may attend). It is built directly on
        # the input's device rather than on the CPU and copied over on every call.
        mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).view(1, 1, seq_len, seq_len)

        # Embed tokens and add positional encoding
        x = self.token_embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)

        # Pass through transformer blocks
        for block in self.blocks:
            x = block(x, mask)

        # Output
        x = self.norm(x)
        logits = self.fc_out(x)

        # Calculate loss if targets provided; reshape (not view) also accepts
        # non-contiguous target tensors.
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: LongTensor of prompt token ids, shape (batch, t).
            max_new_tokens: number of tokens to append.
            temperature: positive divisor applied to the logits (<1 sharpens,
                >1 flattens the distribution).
            top_k: if given, sample only among the ``top_k`` highest logits.

        Returns:
            LongTensor of shape (batch, t + max_new_tokens): prompt followed by
            the sampled tokens.

        Raises:
            ValueError: if ``temperature`` is not positive (it would otherwise
                produce inf/NaN probabilities).
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        for _ in range(max_new_tokens):
            # Only the last max_len tokens fit the positional encoding; the
            # slice is a no-op for shorter contexts.
            idx_cond = idx[:, -self.max_len:]

            # Forward pass; keep only the logits for the last position.
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            # Optional top-k sampling
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            # Sample from distribution
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

In [None]:
# ============================================================================
# HYPERPARAMETERS
# ============================================================================

# Model hyperparameters
block_size = 128      # Context length (characters per training window)
d_model = 256         # Embedding dimension
num_heads = 8         # Number of attention heads
num_layers = 6        # Number of transformer blocks
d_ff = 1024           # Feed-forward dimension
dropout = 0.2         # Dropout probability

# Training hyperparameters
batch_size = 64
learning_rate = 3e-4
max_iters = 5000      # Optimizer steps
eval_interval = 500   # Steps between loss evaluations
eval_iters = 200      # Batches averaged per split in each evaluation

# Reproducibility: seed torch's RNGs (all devices) before the model is built,
# data is shuffled or dropout masks are sampled.
seed = 1337
torch.manual_seed(seed)

# ============================================================================
# TRAINING SETUP
# ============================================================================

# Sliding-window datasets over the encoded train / val character streams.
train_dataset, val_dataset = (
    CharDataset(split, block_size) for split in (train_data, val_data)
)

# Only the training stream is shuffled; validation keeps corpus order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Build the model; its context window equals the training block size.
model_config = dict(
    vocab_size=vocab_size,
    d_model=d_model,
    num_heads=num_heads,
    num_layers=num_layers,
    d_ff=d_ff,
    max_len=block_size,
    dropout=dropout,
)
model = GPT(**model_config).to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f"\nModel parameters: {n_params:,}")

# Optimizer over every model parameter.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================

@torch.no_grad()
def estimate_loss():
    """Estimate the mean cross-entropy loss on the train and val splits.

    For each split, ``eval_iters`` batches of ``loader.batch_size`` windows are
    drawn uniformly at random from the loader's dataset. The first
    ``eval_iters`` batches of the loader are not used: the val loader is
    unshuffled, so those batches would always cover only the first
    ~eval_iters * batch_size windows of the val text.

    The model is put in eval mode (dropout off) for the estimate. Its previous
    train/eval mode is restored afterwards.

    Returns:
        dict mapping ``'train'`` and ``'val'`` to the mean batch loss (float,
        NaN if a split has no windows).
    """
    was_training = model.training
    model.eval()
    try:
        means = {}
        for split, loader in (('train', train_loader), ('val', val_loader)):
            dataset = loader.dataset
            n_windows = len(dataset)
            batch_losses = []
            if n_windows > 0:
                for _ in range(eval_iters):
                    picks = torch.randint(n_windows, (loader.batch_size,)).tolist()
                    pairs = [dataset[i] for i in picks]
                    x = torch.stack([p[0] for p in pairs]).to(device)
                    y = torch.stack([p[1] for p in pairs]).to(device)
                    _, loss = model(x, y)
                    batch_losses.append(loss.item())
            means[split] = float(np.mean(batch_losses)) if batch_losses else float('nan')
    finally:
        model.train(was_training)
    return means


def train():
    """Train ``model`` for ``max_iters`` optimizer steps, one mini-batch per step.

    Every ``eval_interval`` steps, and on the final step, the train/val loss is
    estimated with ``estimate_loss()`` and printed.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    model.train()

    # Keep one iterator alive across steps. Re-entering `for ... in train_loader`
    # and breaking after one batch would start a brand-new (reshuffled) pass
    # over the whole dataset on every step.
    batches = iter(train_loader)

    for iter_num in range(max_iters):
        try:
            x, y = next(batches)
        except StopIteration:
            # Epoch exhausted: start a new, reshuffled pass.
            batches = iter(train_loader)
            x, y = next(batches)
        x, y = x.to(device), y.to(device)

        # Forward pass (logits are not needed here, only the loss)
        _, loss = model(x, y)

        # Backward pass
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Evaluate
        if iter_num % eval_interval == 0 or iter_num == max_iters - 1:
            losses = estimate_loss()
            print(f"Step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    print("\nTraining complete!")

In [None]:
# ============================================================================
# GENERATION FUNCTION
# ============================================================================

def generate_text(prompt, max_new_tokens=500, temperature=0.8, top_k=40):
    """Generate a continuation of ``prompt`` with the trained model.

    Characters that are not in the training vocabulary are dropped from the
    prompt, because ``encode`` would raise KeyError on them (e.g. arbitrary
    user input from ``chat``). If nothing usable remains, or ``prompt`` is
    empty, generation starts from token id 0.

    Args:
        prompt: text to continue.
        max_new_tokens: number of characters to generate.
        temperature: sampling temperature passed to ``model.generate``.
        top_k: top-k cutoff passed to ``model.generate``.

    Returns:
        The decoded (filtered) prompt followed by the generated characters.
    """
    was_training = model.training
    model.eval()
    try:
        usable = ''.join(c for c in prompt if c in char_to_idx) if prompt else ''
        if usable:
            context = torch.tensor([encode(usable)], dtype=torch.long, device=device)
        else:
            context = torch.zeros((1, 1), dtype=torch.long, device=device)

        generated = model.generate(context, max_new_tokens, temperature, top_k)
    finally:
        # Do not leave the model stuck in eval mode if training resumes later.
        model.train(was_training)

    return decode(generated[0].tolist())

In [None]:
# ============================================================================
# RUN TRAINING
# ============================================================================

rule = "=" * 50
print("\n" + rule)
print("Starting Training")
print(rule + "\n")

train()

In [None]:
# ============================================================================
# TEST GENERATION
# ============================================================================

rule = "=" * 50
print("\n" + rule)
print("Testing Generation")
print(rule + "\n")

# Sample a fixed-length continuation for a few well-known opening lines.
prompts = [
    "ROMEO:",
    "To be or not to be,",
    "What light through yonder window"
]

for start in prompts:
    print(f"\nPrompt: '{start}'")
    print("-" * 50)
    sample = generate_text(start, max_new_tokens=200, temperature=0.8, top_k=40)
    print(sample)
    print("\n")

In [None]:
# ============================================================================
# INTERACTIVE CHAT FUNCTION
# ============================================================================

def chat():
    """Run an interactive prompt/response loop on stdin.

    Each line the user enters is used as a prompt for ``generate_text`` and the
    model's continuation is printed. The session ends when the user types
    'quit', 'exit' or 'q' (any case, surrounding whitespace ignored), when
    input reaches EOF, or on Ctrl-C.
    """
    print("\n" + "="*50)
    print("GPT Chat Mode (type 'quit' to exit)")
    print("="*50 + "\n")

    while True:
        try:
            prompt = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin (non-interactive run) or user interrupt: end the
            # session cleanly instead of with a traceback.
            print("\nGoodbye!")
            break

        if prompt.strip().lower() in ['quit', 'exit', 'q']:
            print("Goodbye!")
            break

        response = generate_text(prompt, max_new_tokens=300, temperature=0.8, top_k=40)
        print(f"\nGPT: {response}\n")

# chat() blocks on input(), so it is off by default to keep "Restart & Run All"
# non-interactive. Set RUN_CHAT = True to start chatting.
RUN_CHAT = False
if RUN_CHAT:
    chat()