!pip install -q torch torchvision matplotlib

In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import random, math, os, textwrap

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [18]:
!wget -q https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O shakespeare.txt

with open("shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

print("Corpus length (chars):", len(text))
print("Sample:")
print(text[:500])

chars = sorted(list(set(text)))
vocab_size = len(chars)
print("Vocab size:", vocab_size)
print("Chars:", chars)

stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

def encode(s: str):
    return [stoi[c] for c in s]

def decode(idx_list):
    return "".join(itos[i] for i in idx_list)

data = torch.tensor(encode(text), dtype=torch.long)
print("Encoded data shape:", data.shape)

Corpus length (chars): 1115394
Sample:
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor
Vocab size: 65
Chars: ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
Encoded data shape: torch.Size([1115394])


In [19]:
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

block_size = 128
batch_size = 64

def get_batch(split="train"):
    src = train_data if split == "train" else val_data
    # random sample initialization
    ix = torch.randint(0, len(src) - block_size - 1, (batch_size,))
    x = torch.stack([src[i:i+block_size] for i in ix])
    y = torch.stack([src[i+1:i+block_size+1] for i in ix])
    x = x.to(device)
    y = y.to(device)
    return x, y

# sanity check
xb, yb = get_batch("train")
print("Batch x shape:", xb.shape)  # (B, T)
print("Batch y shape:", yb.shape)
print("Example x (decoded):")
print(decode(xb[0].cpu().tolist()))
print("Example y (decoded):")
print(decode(yb[0].cpu().tolist()))

Batch x shape: torch.Size([64, 128])
Batch y shape: torch.Size([64, 128])
Example x (decoded):
 passage.

Second Watchman:
Ay, wherefore else guard we his royal tent,
But to defend his person from night-foes?

WARWICK:
This
Example y (decoded):
passage.

Second Watchman:
Ay, wherefore else guard we his royal tent,
But to defend his person from night-foes?

WARWICK:
This 


In [20]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1, block_size=128):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Q, K, V projection
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

        self.W_o = nn.Linear(d_model, d_model, bias=False)

        self.dropout = nn.Dropout(dropout)

        # causal mask: shape (1, 1, T, T)
        mask = torch.tril(torch.ones(block_size, block_size))
        # register buffer
        self.register_buffer("causal_mask", mask.view(1, 1, block_size, block_size))

    def forward(self, x):
        """
        x: (B, T, d_model), T <= block_size
        return: (B, T, d_model)
        """
        B, T, D = x.shape
        H = self.num_heads
        Hd = self.head_dim

        # Q, K, V: (B, T, D)
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        Q = Q.view(B, T, H, Hd).transpose(1, 2)
        K = K.view(B, T, H, Hd).transpose(1, 2)
        V = V.view(B, T, H, Hd).transpose(1, 2)

        # attentions: (B, H, T, T)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Hd)

        # causal_mask: (1,1,block_size,block_size) -> next T
        mask = self.causal_mask[:, :, :T, :T]  # (1,1,T,T)
        attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))


        # softmax
        attn_weights = F.softmax(attn_scores, dim=-1)

        attn_weights = self.dropout(attn_weights)

        # each head (B, H, T, Hd)
        context = torch.matmul(attn_weights, V)

        # combine heads: (B, H, T, Hd) -> (B, T, D)
        context = context.transpose(1, 2).contiguous().view(B, T, D)

        out = self.W_o(context)
        return out

In [21]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class TransformerBlock(nn.Module):
    """
    Pre-LN
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1, block_size=128):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.mha = MultiHeadAttention(d_model, num_heads, dropout=dropout, block_size=block_size)
        self.ffn = FeedForward(d_model, d_ff, dropout=dropout)

    def forward(self, x):
        # Self-attention + residual
        x = x + self.mha(self.ln1(x))
        # FFN + residual
        x = x + self.ffn(self.ln2(x))
        return x

In [22]:
class CharTransformerLM(nn.Module):
    def __init__(
        self,
        vocab_size,
        d_model=128,
        num_heads=4,
        num_layers=4,
        d_ff=512,
        block_size=128,
        dropout=0.1,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.block_size = block_size

        # token embedding + position embedding
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(block_size, d_model)

        self.dropout = nn.Dropout(dropout)

        # Multiple TransformerBlock
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout=dropout, block_size=block_size)
            for _ in range(num_layers)
        ])

        # LayerNorm + output to vocab
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        self._init_parameters()

    def _init_parameters(self):
        # some initialization

        for name, p in self.named_parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, idx, targets=None):
        """
        idx: (B, T)
        targets: (B, T)
        """
        B, T = idx.shape
        assert T <= self.block_size, "Sequence length exceeds block_size"

        # token + position embedding
        pos = torch.arange(0, T, device=idx.device)
        pos = pos.unsqueeze(0)  # (1, T)
        tok_emb = self.token_embed(idx)       # (B, T, d_model)
        pos_emb = self.pos_embed(pos)         # (1, T, d_model)
        x = tok_emb + pos_emb                 # (B, T, d_model)
        x = self.dropout(x)

        # TransformerBlocks
        for layer in self.layers:
            x = layer(x)

        # LN + linearhead
        x = self.ln_f(x)
        logits = self.head(x)  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            logits_flat = logits.view(-1, self.vocab_size)
            targets_flat = targets.view(-1)
            loss = F.cross_entropy(logits_flat, targets_flat)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=100, temperature=1.0):
        """
        idx: (B, T)
        return: (B, T+max_new_tokens)
        """
        self.eval()
        B, T = idx.shape

        for _ in range(max_new_tokens):

            idx_cond = idx[:, -self.block_size:]

            logits, _ = self(idx_cond)  # (B, T_cond, vocab_size)
            logits = logits[:, -1, :]   # (B, vocab_size)

            # temperature
            logits = logits / temperature

            probs = F.softmax(logits, dim=-1)  # (B, vocab_size)

            # next token
            next_idx = torch.multinomial(probs, num_samples=1)  # (B, 1)

            # concat
            idx = torch.cat([idx, next_idx], dim=1)  # (B, T+1)

        return idx

In [23]:
d_model = 128
num_heads = 4
num_layers = 4
d_ff = 4 * d_model
dropout = 0.1

model = CharTransformerLM(
    vocab_size=vocab_size,
    d_model=d_model,
    num_heads=num_heads,
    num_layers=num_layers,
    d_ff=d_ff,
    block_size=block_size,
    dropout=dropout,
).to(device)

print("Model parameters:",
      sum(p.numel() for p in model.parameters()) / 1e6, "M")

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

Model parameters: 0.82432 M


In [24]:
max_steps = 500
eval_interval = 200
log_interval = 50

def estimate_loss(num_batches=20):
    model.eval()
    out = {}
    with torch.no_grad():
        for split in ["train", "val"]:
            losses = []
            for _ in range(num_batches):
                x, y = get_batch(split)
                _, loss = model(x, y)
                losses.append(loss.item())
            out[split] = sum(losses) / len(losses)
    model.train()
    return out

step = 0
while step < max_steps:
    if step % eval_interval == 0:
        losses = estimate_loss(num_batches=20)
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    x, y = get_batch("train")
    logits, loss = model(x, y)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    if step % log_interval == 0:
        print(f"step {step}: loss {loss.item():.4f}")

    step += 1

print("Training finished.")

step 0: train loss 4.6236, val loss 4.6244
step 0: loss 4.6351
step 50: loss 2.9761
step 100: loss 2.6370
step 150: loss 2.5872
step 200: train loss 2.5010, val loss 2.5050
step 200: loss 2.5320
step 250: loss 2.5249
step 300: loss 2.4897
step 350: loss 2.4755
step 400: train loss 2.4013, val loss 2.4187
step 400: loss 2.4446
step 450: loss 2.4046
Training finished.


In [25]:
model.eval()

def generate_text(prompt="JULIET:", max_new_tokens=400, temperature=0.8):
    # prompt to tensor
    start_ids = torch.tensor([encode(prompt)], dtype=torch.long, device=device)

    sample_ids = model.generate(
        start_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
    sample_text = decode(sample_ids[0].cpu().tolist())
    return sample_text

generated = generate_text("JULIET:", max_new_tokens=500, temperature=0.8)
print("=== Generated Text ===")
print(generated)



=== Generated Text ===
JULIET: is flerdom; tomrd, noow brous:
Tat'st
Cit o wimilen thimyour fow wie t tanelondthn ss.


TOLLOS:
AZy ist thowlee ce sshe thisthatirivik wime haing ches y me.
Ml mad wawicune th pree incat pat sthevil thireveasoug th mo t thaJust t anrd'sel
And ththathass, nd cothit my re ais hush s sis
Frangulll h ho se athand t I llis, I k imathime ay sberet thanghanghoun:
Ad gar t hithig heoul thaFamy t ible s fano ceriner thay chthaus telris
Th ve chand ing t we taven anoure bly it s daricke y,
Yono ite thth
