In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import math
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

In [2]:
#!pip install torch --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m46.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m54.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m35.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
# Seed every RNG this notebook imports so that Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)         # Python's `random` is imported in the first cell; seed it alongside torch
torch.manual_seed(SEED);  # the trailing ';' keeps the Generator repr out of the cell output

<torch._C.Generator at 0x7e1ae4af99f0>

### Part 1: Load & Tokenize Dataset (Shakespeare)

In [3]:

# Tiny character-level training corpus (a few lines of Shakespeare).
# The newline right after the opening quotes and the one before the closing
# quotes are part of the string, so the very first training window begins
# with "\n".
text = """
To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles
And by opposing end them.
"""


In [4]:
# Character-level vocabulary: every distinct character of the corpus, sorted
# so that the character -> id assignment is deterministic.
chars = sorted(set(text))
vocab_size = len(chars)
print(f"Vocab size: {vocab_size}")

Vocab size: 31


In [5]:
# Lookup tables between characters and integer token ids
# (stoi: character -> id, itost: id -> character).
stoi = {ch: idx for idx, ch in enumerate(chars)}
itost = dict(enumerate(chars))


def encode(s):
    """Return the list of token ids for string `s` (KeyError for a character outside the vocabulary)."""
    return [stoi[c] for c in s]


def decode(l):
    """Return the string spelled by the iterable of token ids `l`."""
    return ''.join(itost[i] for i in l)

In [6]:
# Encode the whole corpus as a 1-D tensor of token ids (one id per character).
# dtype is torch.long because these ids are later fed to nn.Embedding and used
# as class targets for the cross-entropy loss.
data = torch.tensor(encode(text), dtype=torch.long)

### Part 2: Dataset Preparation

In [7]:
# Context window: number of tokens in each training sequence. The same value
# sizes the causal attention mask and the positional-embedding table below.
block_size = 8  # sequence length


def get_batch(split='train', tokens=None, seq_len=None):
    """Build every (input, target) next-token pair from a sequence of token ids.

    Args:
        split: Accepted for API compatibility but currently ignored; there is
            no train/validation split, the whole sequence is always used.
        tokens: Tensor of token ids indexed along its first dimension.
            Defaults to the notebook-level ``data``.
        seq_len: Length of every window. Defaults to the notebook-level
            ``block_size``.

    Returns:
        Tuple ``(x, y)`` of freshly allocated tensors with
        ``len(tokens) - seq_len`` rows each. Row ``i`` of ``x`` is
        ``tokens[i:i + seq_len]``; row ``i`` of ``y`` is the same window
        shifted one position to the right (``tokens[i + 1:i + seq_len + 1]``).

    Raises:
        ValueError: If ``tokens`` has no more than ``seq_len`` elements, i.e.
            not even one (input, target) pair fits. (Previously this surfaced
            as an opaque error from ``torch.stack`` on an empty list.)
    """
    # Resolve the notebook-level defaults at call time so later re-assignments
    # of `data` / `block_size` are picked up, exactly as before.
    if tokens is None:
        tokens = data
    if seq_len is None:
        seq_len = block_size

    num_windows = len(tokens) - seq_len
    if num_windows <= 0:
        raise ValueError(
            f"need more than {seq_len} tokens to build a training pair, got {len(tokens)}"
        )

    # window_idx[i, j] = i + j  ->  row i selects tokens[i : i + seq_len].
    starts = torch.arange(num_windows, device=tokens.device).unsqueeze(1)
    offsets = torch.arange(seq_len, device=tokens.device)
    window_idx = starts + offsets

    # Integer-array indexing returns new contiguous tensors (no aliasing of
    # `tokens`), so downstream `.view(...)` calls keep working.
    return tokens[window_idx], tokens[window_idx + 1]

# Materialise the full set of training pairs once; X holds the input windows
# and Y the same windows shifted one character to the right.
X, Y = get_batch()
print(X.shape)
# Decode the first pair back to text as a sanity check of the one-step shift.
print("Input:", decode(X[0].tolist()))
print("Target:", decode(Y[0].tolist()))


torch.Size([191, 8])
Input: 
To be, 
Target: To be, o


### Part 3: Self-Attention Layer (1-head)

In [30]:
class SelfAttention(nn.Module):
  """Single-head causal (masked) scaled dot-product self-attention.

  Args:
    emb_size: Embedding width C. Queries, keys and values all keep this width.
    max_seq_len: Longest sequence length the causal mask must cover. Defaults
      to the notebook-level ``block_size`` so existing ``SelfAttention(emb_size)``
      calls behave exactly as before.
  """

  def __init__(self, emb_size, max_seq_len=None):
    # nn.Module's constructor sets up parameter/buffer tracking and must run first.
    super().__init__()
    if max_seq_len is None:
      max_seq_len = block_size  # notebook-level default, resolved at construction time
    self.query = nn.Linear(emb_size, emb_size)
    self.key = nn.Linear(emb_size, emb_size)
    self.value = nn.Linear(emb_size, emb_size)
    # Lower-triangular matrix of ones: position t may attend to positions <= t
    # only, so the model cannot "cheat" by looking at future tokens. Stored as
    # a buffer: part of the module state, but not a trainable parameter.
    self.register_buffer("mask", torch.tril(torch.ones(max_seq_len, max_seq_len)))

  def forward(self, x):
    """Apply causal self-attention.

    Args:
      x: Tensor of shape (B, T, C) -- batch size, sequence length, embedding
        width. T must not exceed the size of the causal mask.

    Returns:
      Tensor of shape (B, T, C); position t is a weighted sum of the value
      vectors at positions 0..t.
    """
    B, T, C = x.shape
    q = self.query(x)  # (B, T, C)
    k = self.key(x)    # (B, T, C)
    v = self.value(x)  # (B, T, C)

    # (B, T, C) @ (B, C, T) -> (B, T, T); scale by sqrt(C) to keep the logits well-conditioned.
    attn_score = q @ k.transpose(-2, -1) / math.sqrt(C)
    # Future positions (upper triangle) get -inf, i.e. zero weight after the softmax.
    attn_score = attn_score.masked_fill(self.mask[:T, :T] == 0, float('-inf'))
    attn_weights = F.softmax(attn_score, dim=-1)  # each row sums to 1 over the visible positions
    return attn_weights @ v  # (B, T, T) @ (B, T, C) -> (B, T, C)



### Part 4: Transformer block (called "encoder-only" in the code, but built on causally masked self-attention, i.e. GPT/decoder-style)

In [31]:
class Transformer_Block_EncoderOnly(nn.Module):
  """One transformer layer: self-attention -> LayerNorm -> two-layer ReLU MLP.

  The three stages are applied strictly in sequence.

  NOTE(review): no residual (skip) connections are used, and the attention
  module is this notebook's causally masked SelfAttention, so despite the
  class name the block cannot look at future positions.

  Args:
    emb_size: Embedding width of both the input and the output.
  """

  def __init__(self, emb_size):
    super().__init__()
    hidden_size = 2 * emb_size  # the MLP widens to twice the embedding width
    # Sub-modules are created in the same order as before so that seeded
    # weight initialisation is unchanged.
    self.attn = SelfAttention(emb_size)
    self.ff = nn.Sequential(
        nn.Linear(emb_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, emb_size),
    )
    self.ln = nn.LayerNorm(emb_size)

  def forward(self, x):
    """Map x of shape (B, T, emb_size) to a tensor of the same shape."""
    attended = self.attn(x)
    normed = self.ln(attended)
    return self.ff(normed)


### Part 5: Mini GPT

In [32]:
class Mini_GPT(nn.Module):
  """Character-level language model.

  Pipeline: token embedding + learned positional embedding -> two transformer
  blocks -> LayerNorm -> linear projection to vocabulary logits.

  Args:
    vocab_size: Number of distinct token ids.
    emb_size: Embedding width used throughout the model.
    block_size: Number of rows in the positional-embedding table, i.e. the
      longest sequence `forward` can embed.
  """

  def __init__(self, vocab_size, emb_size, block_size):
    super().__init__()
    self.embed = nn.Embedding(vocab_size, emb_size)      # token id -> vector
    self.pos_embed = nn.Embedding(block_size, emb_size)  # position -> vector
    layers = [Transformer_Block_EncoderOnly(emb_size) for _ in range(2)]
    self.trans = nn.Sequential(*layers)
    self.ln = nn.LayerNorm(emb_size)
    self.logit_output = nn.Linear(emb_size, vocab_size)

  def forward(self, idx):
    """Return logits of shape (B, T, vocab_size) for token ids `idx` of shape (B, T)."""
    _, seq_len = idx.shape
    # Positions 0..T-1, created on the same device as the input ids.
    positions = torch.arange(seq_len, device=idx.device)
    hidden = self.embed(idx) + self.pos_embed(positions)  # (T, C) broadcasts over the batch
    hidden = self.trans(hidden)
    return self.logit_output(self.ln(hidden))

### Part 6: Train Loop

In [33]:
# Embedding width 64; block_size is passed so the positional-embedding table
# covers exactly the window length used to build X and Y.
model = Mini_GPT(vocab_size=vocab_size, emb_size=64, block_size=block_size)
# AdamW over every model parameter, learning rate 1e-3.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Cross-entropy between predicted next-token logits and the true next token.
loss_fn = nn.CrossEntropyLoss()

In [36]:
# generate() switches the model to eval mode and never switches it back, so
# explicitly re-enter training mode every time this cell is (re-)run.
model.train()

# Full-batch training: each "epoch" is one optimisation step over all windows in X.
for epoch in range(200):
  logits = model(X)                 # (B, T, vocab_size)
  B, T, C = logits.shape
  # CrossEntropyLoss expects (N, classes) logits and (N,) targets, so flatten
  # the batch and time dimensions together.
  loss = loss_fn(logits.view(B*T, C), Y.view(B*T))

  optimizer.zero_grad()             # clear gradients left over from the previous step
  loss.backward()
  optimizer.step()

  if epoch % 25 == 0:
      print(f"Epoch {epoch} | Loss: {loss.item():.4f}")

Epoch 0 | Loss: 2.6639
Epoch 25 | Loss: 1.7269
Epoch 50 | Loss: 0.9452
Epoch 75 | Loss: 0.5538
Epoch 100 | Loss: 0.4200
Epoch 125 | Loss: 0.3647
Epoch 150 | Loss: 0.3378
Epoch 175 | Loss: 0.3256


### Part 7: Generate

In [37]:
def generate(model, idx, max_new_tokens, context_len=None):
    """Autoregressively sample `max_new_tokens` tokens after the prompt `idx`.

    Args:
        model: Module mapping (B, T) token ids to (B, T, vocab) logits.
        idx: (B, T0) integer tensor holding the prompt token ids.
        max_new_tokens: Number of tokens to append to every sequence.
        context_len: How many of the most recent tokens are fed to the model
            at each step. Defaults to the notebook-level ``block_size``.

    Returns:
        Tensor of shape (B, T0 + max_new_tokens): the prompt followed by the
        sampled tokens.

    Side effects:
        Puts `model` into eval mode and leaves it there (unchanged from the
        previous behaviour).
    """
    if context_len is None:
        context_len = block_size  # notebook-level default, resolved at call time
    model.eval()
    # Sampling needs no gradients; skipping autograd avoids building a graph
    # for every forward pass.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -context_len:]       # crop to the context window
            logits = model(idx_cond)[:, -1, :]     # logits for the next token only
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, next_token), dim=1)
    return idx

In [38]:
# Prompt: a batch of one sequence containing the single character 'T' -> shape (1, 1).
start = torch.tensor([[stoi['T']]])
# Append 100 sampled tokens. Sampling uses torch.multinomial, so the text
# depends on the current RNG state.
output = generate(model, start, 100)
print("\nGenerated Text:\n", decode(output[0].tolist()))


Generated Text:
 The r'tis nobler in the mind the mind to suffer
The slings and arrows of outrageous fortune,
Or to ta


## Let's experiment using subword tokens