!pip install -q torch torchvision matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import random, math, os, textwrap

SEED = 42

# Seed every RNG the notebook touches (Python, NumPy, torch CPU and all CUDA
# devices) so batch sampling, weight initialisation and dropout are repeatable.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn

# Prefer the GPU whenever torch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

In [None]:
!wget -q https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O shakespeare.txt

# Load the Tiny Shakespeare corpus fetched by the `wget` line above.
with open("shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

print("Corpus length (chars):", len(text))
print("Sample:")
print(text[:500])

# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)
print("Vocab size:", vocab_size)
print("Chars:", chars)

# Two lookup tables: integer id -> character, and its inverse.
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}

def encode(s: str):
    """Convert a string into a list of integer character ids via ``stoi``.

    Raises KeyError for any character that is not in ``stoi``.
    """
    return list(map(stoi.__getitem__, s))

def decode(idx_list):
    """Convert a sequence of integer ids back into a string via ``itos``.

    Raises KeyError for any id that is not in ``itos``.
    """
    return "".join(map(itos.__getitem__, idx_list))

# Encode the whole corpus once into a 1-D int64 tensor of character ids.
data = torch.tensor(encode(text), dtype=torch.long)
print("Encoded data shape:", data.size())

In [None]:
# Contiguous 90/10 split: the first 90% of the corpus trains, the tail validates.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

# Context length (characters per example) and number of examples per batch.
block_size = 128
batch_size = 64

def get_batch(split="train"):
    """Sample a random batch of (input, next-character target) windows.

    Args:
        split: ``"train"`` samples from ``train_data``; any other value samples
            from ``val_data``.

    Returns:
        ``(x, y)``: LongTensors of shape ``(batch_size, block_size)`` moved to
        ``device``. ``y`` is ``x`` shifted one position to the right, i.e. the
        next-character targets.

    Raises:
        ValueError: if the selected split has fewer than ``block_size + 1``
            characters, so not even one window fits.
    """
    src = train_data if split == "train" else val_data

    # A window needs block_size inputs plus one extra target character, so the
    # last valid start index is len(src) - block_size - 1. torch.randint's
    # `high` is exclusive, hence high = len(src) - block_size. (Subtracting a
    # further 1 made the final window unreachable and crashed on a split of
    # exactly block_size + 1 characters.)
    num_starts = len(src) - block_size
    if num_starts <= 0:
        raise ValueError(
            f"split {split!r} has {len(src)} characters; need at least "
            f"block_size + 1 = {block_size + 1}"
        )

    ix = torch.randint(0, num_starts, (batch_size,))
    x = torch.stack([src[i:i + block_size] for i in ix])
    y = torch.stack([src[i + 1:i + block_size + 1] for i in ix])
    return x.to(device), y.to(device)

# Sanity check: draw one training batch, report its shapes and show the first
# example decoded both as input (x) and as shifted target (y).
xb, yb = get_batch("train")
print("Batch x shape:", xb.shape)  # (B, T)
print("Batch y shape:", yb.shape)
for _header, _tensor in (("Example x (decoded):", xb), ("Example y (decoded):", yb)):
    print(_header)
    print(decode(_tensor[0].cpu().tolist()))

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Projects the input into per-head queries, keys and values, computes scaled
    dot-product attention under a lower-triangular mask (position t can only
    attend to positions <= t), applies dropout to the attention weights, then
    merges the heads through an output projection.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        block_size: largest sequence length the causal mask is built for.
    """

    def __init__(self, d_model, num_heads, dropout=0.1, block_size=128):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Query, key, value and output projections, built in this fixed order:
        # each nn.Linear draws its initial weights from the global RNG, so the
        # order determines the seeded initialisation.
        self.W_q, self.W_k, self.W_v, self.W_o = (
            nn.Linear(d_model, d_model, bias=False) for _ in range(4)
        )

        self.dropout = nn.Dropout(dropout)

        # 0/1 lower-triangular mask shaped (1, 1, block_size, block_size) so it
        # broadcasts over batch and heads; kept as a (non-trainable) buffer.
        lower = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("causal_mask", lower[None, None, :, :])

    def _split_heads(self, t, batch, seq_len):
        """Reshape (B, T, d_model) -> (B, num_heads, T, head_dim)."""
        return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply causal self-attention.

        x: (B, T, d_model) with T <= block_size.
        Returns a tensor of shape (B, T, d_model).
        """
        batch, seq_len, width = x.shape

        q = self._split_heads(self.W_q(x), batch, seq_len)
        k = self._split_heads(self.W_k(x), batch, seq_len)
        v = self._split_heads(self.W_v(x), batch, seq_len)

        # Scaled dot-product scores: (B, H, T, T).
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Hide future positions. The diagonal is always visible, so no row of
        # scores becomes entirely -inf.
        visible = self.causal_mask[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float('-inf'))

        weights = self.dropout(scores.softmax(dim=-1))

        # Weighted sum of values, then merge heads: (B, H, T, hd) -> (B, T, D).
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.W_o(merged)

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> ReLU -> Dropout -> Linear -> Dropout.

    Expands each position from ``d_model`` to ``d_ff`` features and projects it
    back to ``d_model``, with dropout after the activation and after the output.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # The expand layer is created before the project layer, matching the
        # order in which their initial weights are drawn from the RNG.
        expand = nn.Linear(d_model, d_ff)
        project = nn.Linear(d_ff, d_model)
        self.net = nn.Sequential(
            expand,
            nn.ReLU(),
            nn.Dropout(dropout),
            project,
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the MLP independently to every position of ``x``."""
        out = self.net(x)
        return out


class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block.

    Two residual sub-layers, each fed a LayerNorm-ed copy of its input:
    ``x = x + mha(ln1(x))`` followed by ``x = x + ffn(ln2(x))``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1, block_size=128):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.mha = MultiHeadAttention(d_model, num_heads, dropout=dropout, block_size=block_size)
        self.ffn = FeedForward(d_model, d_ff, dropout=dropout)

    def forward(self, x):
        """Run attention then the MLP, each as a pre-norm residual update."""
        for norm, sublayer in ((self.ln1, self.mha), (self.ln2, self.ffn)):
            x = x + sublayer(norm(x))
        return x

In [None]:
class CharTransformerLM(nn.Module):
    """Causal character-level transformer language model.

    Token embeddings plus learned absolute position embeddings go through
    dropout, ``num_layers`` pre-LN ``TransformerBlock``s, a final LayerNorm and
    a bias-free linear head that produces next-character logits.

    Args:
        vocab_size: number of distinct characters (rows of the embedding).
        d_model: model width.
        num_heads: attention heads per block.
        num_layers: number of transformer blocks.
        d_ff: hidden width of each block's feed-forward MLP.
        block_size: maximum context length (size of the position table).
        dropout: dropout probability used throughout.
    """

    def __init__(
        self,
        vocab_size,
        d_model=128,
        num_heads=4,
        num_layers=4,
        d_ff=512,
        block_size=128,
        dropout=0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.block_size = block_size

        # token embedding + position embedding
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(block_size, d_model)

        self.dropout = nn.Dropout(dropout)

        # Stack of transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout=dropout, block_size=block_size)
            for _ in range(num_layers)
        ])

        # Final LayerNorm + projection to vocabulary logits
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        self._init_parameters()

    def _init_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension.

        This covers both embedding tables and every Linear weight matrix; 1-D
        parameters (LayerNorm affine parameters, Linear biases) keep PyTorch's
        default initialisation.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, idx, targets=None):
        """Compute next-character logits and, optionally, the loss.

        Args:
            idx: (B, T) integer character ids, with T <= block_size.
            targets: optional (B, T) integer ids of the next character.

        Returns:
            ``(logits, loss)``: logits of shape (B, T, vocab_size); ``loss`` is
            the mean cross-entropy over all positions, or None when
            ``targets`` is None.
        """
        _batch, T = idx.shape
        assert T <= self.block_size, "Sequence length exceeds block_size"

        # token + position embedding; positions broadcast over the batch.
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)  # (1, T)
        x = self.token_embed(idx) + self.pos_embed(pos)           # (B, T, d_model)
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x)

        logits = self.head(self.ln_f(x))  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted as well.
            loss = F.cross_entropy(
                logits.reshape(-1, self.vocab_size),
                targets.reshape(-1),
            )

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=100, temperature=1.0):
        """Autoregressively extend ``idx`` by ``max_new_tokens`` characters.

        Args:
            idx: (B, T) integer ids used as the prompt.
            max_new_tokens: number of characters to append.
            temperature: softmax temperature for sampling. ``0`` selects greedy
                (argmax) decoding.

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the new ids.

        Raises:
            ValueError: if ``temperature`` is negative.

        The model runs in eval mode (dropout disabled) while generating, and its
        previous train/eval mode is restored afterwards. Sampling in the middle
        of training therefore no longer leaves dropout switched off.
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")

        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Only the last block_size ids fit in the position table.
                idx_cond = idx[:, -self.block_size:]

                logits, _loss = self(idx_cond)  # (B, T_cond, vocab_size)
                logits = logits[:, -1, :]       # (B, vocab_size)

                if temperature == 0:
                    # Dividing by 0 would yield NaN probabilities; decode greedily.
                    next_idx = logits.argmax(dim=-1, keepdim=True)  # (B, 1)
                else:
                    probs = F.softmax(logits / temperature, dim=-1)
                    next_idx = torch.multinomial(probs, num_samples=1)  # (B, 1)

                idx = torch.cat([idx, next_idx], dim=1)
        finally:
            self.train(was_training)

        return idx

In [None]:
# NOTE(review): `model` is not created in any cell visible here; this cell
# assumes a trained CharTransformerLM was bound to `model` earlier — confirm.
model.eval()


def generate_text(prompt="JULIET:", max_new_tokens=400, temperature=0.8):
    """Sample a continuation of ``prompt`` from the global ``model``.

    The prompt is encoded to character ids, given to ``model.generate`` as a
    batch of one, and the full sequence (prompt plus ``max_new_tokens`` new
    characters) is decoded back into a string.
    """
    prompt_ids = encode(prompt)
    context = torch.tensor([prompt_ids], dtype=torch.long, device=device)
    out_ids = model.generate(
        context,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    return decode(out_ids[0].cpu().tolist())


generated = generate_text("JULIET:", max_new_tokens=500, temperature=0.8)
print("=== Generated Text ===")
print(generated)

