In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyTransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_hidden_dim):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)

        self.ff = nn.Sequential(
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, embed_dim)
        )
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Self-attention
        attn_output, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_output)  # Add & Norm

        # Feedforward
        ff_output = self.ff(x)
        x = self.norm2(x + ff_output)  # Add & Norm
        return x


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GPTMini(nn.Module):
    def __init__(self, vocab_size, block_size, embed_dim, num_heads, ff_hidden_dim, num_layers):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(block_size, embed_dim)

        self.transformer_blocks = nn.ModuleList([
            TinyTransformerBlock(embed_dim, num_heads, ff_hidden_dim)
            for _ in range(num_layers)
        ])

        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

        self.block_size = block_size

    def forward(self, idx):
        B, T = idx.size()
        assert T <= self.block_size, "Input sequence too long"

        # Token + positional embeddings
        tok_emb = self.token_embedding(idx)  # (B, T, embed_dim)
        pos_emb = self.position_embedding(torch.arange(T, device=idx.device))  # (T, embed_dim)
        x = tok_emb + pos_emb  # (B, T, embed_dim)

        for block in self.transformer_blocks:
            x = block(x)

        x = self.ln_f(x)
        logits = self.head(x)  # (B, T, vocab_size)
        return logits

In [3]:
# Toy dataset
text = "hello world"
chars = sorted(list(set(text)))
vocab_size = len(chars)

# Mapping from char to index and vice versa
stoi = { ch: i for i, ch in enumerate(chars) }
itos = { i: ch for ch, i in stoi.items() }

# Encode/decode helpers
def encode(s): return [stoi[c] for c in s]
def decode(l): return ''.join([itos[i] for i in l])

# Prepare dataset
block_size = 8  # context length
data = torch.tensor(encode(text), dtype=torch.long)

def get_batch():
    i = torch.randint(len(data) - block_size, (1,))
    x = data[i:i+block_size].unsqueeze(0)      # shape: (1, block_size)
    y = data[i+1:i+block_size+1].unsqueeze(0)  # next tokens
    return x, y

In [4]:
model = GPTMini(
    vocab_size=vocab_size,
    block_size=block_size,
    embed_dim=32,
    num_heads=4,
    ff_hidden_dim=64,
    num_layers=2
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for step in range(1000):
    x_batch, y_batch = get_batch()
    logits = model(x_batch)
    loss = F.cross_entropy(logits.view(-1, vocab_size), y_batch.view(-1))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f"Step {step}: loss = {loss.item():.4f}")

Step 0: loss = 2.2375
Step 100: loss = 0.0881
Step 200: loss = 0.0323
Step 300: loss = 0.0170
Step 400: loss = 0.0111
Step 500: loss = 0.0077
Step 600: loss = 0.0056
Step 700: loss = 0.0031
Step 800: loss = 0.0033
Step 900: loss = 0.0027


In [5]:
def sample(model, start_text, length):
    model.eval()
    idx = torch.tensor([encode(start_text)], dtype=torch.long)
    for _ in range(length):
        idx_crop = idx[:, -block_size:]  # crop context
        logits = model(idx_crop)
        logits = logits[:, -1, :]  # last token logits
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_token], dim=1)
    return decode(idx[0].tolist())

In [23]:
print(sample(model, start_text="h", length=5))

hello 
