# Exercise
The n-dimensional tensor mastery challenge: Combine the `Head` and `MultiHeadAttention` into one class that processes all the heads in parallel, treating the heads as another batch dimension (answer is in nanoGPT).

You can even compute q, k, and v in parallel by stacking them along an extra "batch dimension" of size 3 and splitting them afterwards. Use `nn.Linear` instead of `torch.randn`!

In [56]:
import os
import torch
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm

In [57]:
# Build the character vocabulary from the raw corpus
path = './data/HarryPotterPreprocessed.txt'
with open(path, encoding='utf-8') as f:
    text = f.read()
chars = sorted(set(text))  # every distinct character, in a stable order

# Character <-> integer id lookup tables
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Map each character of `s` to its integer id (KeyError on unknown chars)."""
    return [stoi[c] for c in s]


def decode(l):
    """Map an iterable of integer ids back to the string they spell."""
    return ''.join(itos[i] for i in l)

In [58]:
# Hyperparameters
VOCAB_SIZE = len(chars)  # number of distinct characters in the corpus
EMBEDDING_SIZE = 32  # width of the token and position embeddings
CONTEXT_SIZE = 8  # number of characters the model sees per sample
BATCH_SIZE = 64  # sequences drawn per call to get_batch
MAX_STEPS = 5000  # optimizer steps in the training loop
LEARNING_RATE = 3E-4  # learning rate passed to AdamW
BLOCK_COUNT = 2  # number of stacked transformer blocks
NUM_HEADS = 4  # attention heads per block
DROPOUT = 0.2  # dropout probability used in attention and feed-forward layers
HEAD_SIZE = EMBEDDING_SIZE // NUM_HEADS # per-head width of the query, key and value projections
device = 'cuda' if torch.cuda.is_available() else "cpu"  # prefer the GPU when one is available
EVAL_INTERVAL = 500  # estimate train/val loss every this many steps
EVAL_LOSS_BATCHES = 200  # batches averaged per split for each loss estimate

# NOTE(review): presumably the checkpoint filename; it is not used in the visible cells — confirm
this_model_name = "model_EX1(2).pth"

In [59]:
# Encode the whole corpus once, then hold out the final 10% for validation
data = torch.tensor(encode(text), dtype=torch.long)
n = int(len(data) * 0.9)  # index of the first validation token
train_data = data[:n]
val_data = data[n:]

# Loader that returns a batch
def get_batch(split):
    """Draw one random batch of (input, target) windows from a data split.

    `split == "train"` samples from the training data; any other value
    samples from the validation data. Targets are the inputs shifted one
    position to the right. Both tensors have shape
    (BATCH_SIZE, CONTEXT_SIZE) and are moved to `device`.
    """
    source = train_data if split == "train" else val_data
    # Latest valid start still leaves room for the one-step-shifted target
    starts = torch.randint(len(source) - CONTEXT_SIZE, (BATCH_SIZE,))
    inputs = torch.stack([source[s:s + CONTEXT_SIZE] for s in starts])
    targets = torch.stack([source[s + 1:s + CONTEXT_SIZE + 1] for s in starts])
    return inputs.to(device), targets.to(device)

In [60]:
# calculate mean loss for {EVAL_LOSS_BATCHES}x batches
@torch.no_grad()
def estimate_loss():
    """Average the model loss over EVAL_LOSS_BATCHES random batches per split.

    The global `model` is put in eval mode for the measurement and switched
    back to train mode before returning. Returns a dict with the keys
    "train" and "val", each holding a scalar tensor with the mean loss.
    """
    model.eval()
    averages = {}
    for split in ("train", "val"):
        batch_losses = torch.zeros(EVAL_LOSS_BATCHES, device=device)
        for batch_idx in tqdm(range(EVAL_LOSS_BATCHES)):
            _, batch_loss = model(*get_batch(split))
            batch_losses[batch_idx] = batch_loss.item()
        averages[split] = batch_losses.mean()
    model.train()
    return averages

In [61]:
""" Multiple Heads of Self-Attention in parallel """
class CausalSelfAttention(nn.Module):
    """Masked multi-head self-attention with all heads computed in one pass.

    A single bias-free linear layer produces the queries, keys and values of
    every head at once; the heads are then treated as an extra batch
    dimension for the attention computation.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size

        # Fused projection: for each head, head_size columns each for q, k and v
        self.sa_heads = nn.Linear(EMBEDDING_SIZE, 3 * num_heads * head_size, bias=False)
        self.dropout = nn.Dropout(DROPOUT)  # applied to the attention weights

        # Output projection mixing the concatenated heads back to the embedding width
        self.proj = nn.Linear(num_heads * head_size, EMBEDDING_SIZE)
        self.dropout2 = nn.Dropout(DROPOUT)  # applied to the projected output

        # Lower-triangular causal mask; a buffer because it is not a learned parameter
        causal_mask = torch.tril(torch.ones(CONTEXT_SIZE, CONTEXT_SIZE))
        self.register_buffer('tril', causal_mask)

    def forward(self, x):
        """Attend over `x` of shape (batch, context, embedding); same shape out."""
        batch, steps, _ = x.shape

        # (batch, steps, heads, 3*head_size) -> (batch, heads, steps, 3*head_size)
        fused = self.sa_heads(x)
        fused = fused.view(batch, steps, self.num_heads, 3 * self.head_size)
        fused = fused.transpose(1, 2)
        # Each of q, k, v: (batch, heads, steps, head_size)
        q, k, v = torch.split(fused, self.head_size, dim=-1)

        # Scaled dot-product scores: (batch, heads, steps, steps)
        scale = q.shape[-1] ** -0.5
        scores = q @ k.transpose(-2, -1) * scale
        # Hide future positions so each step only attends to itself and the past
        future = self.tril[:steps, :steps] == 0
        scores = scores.masked_fill(future, float('-inf'))
        weights = self.dropout(scores.softmax(dim=-1))

        # Weighted sum of values, then concatenate the heads again
        mixed = weights @ v  # (batch, heads, steps, head_size)
        mixed = mixed.transpose(1, 2)
        mixed = mixed.reshape(batch, steps, self.num_heads * self.head_size)

        projected = self.proj(mixed)  # (batch, steps, EMBEDDING_SIZE)
        return self.dropout2(projected)

In [62]:
class FeedForward(nn.Module):
    """Per-position MLP: expand to 4x the width, ReLU, project back, dropout."""

    def __init__(self, in_feat):
        super().__init__()
        hidden_feat = in_feat * 4
        layers = [
            nn.Linear(in_feat, hidden_feat),
            nn.ReLU(),
            nn.Linear(hidden_feat, in_feat),
            nn.Dropout(DROPOUT),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of `x`; the shape is unchanged."""
        return self.net(x)

In [63]:
# Transformer Block: Communication (MultiHead Attention) followed by computation (MLP - FeedForward)
class Block(nn.Module):
    """Pre-norm transformer block: self-attention, then a feed-forward MLP.

    Both sub-layers are wrapped in a residual connection and receive a
    layer-normalized copy of their input.
    """

    def __init__(self, n_heads, head_size):
        super().__init__()
        self.sa_heads = CausalSelfAttention(n_heads, head_size)
        self.ffwd = FeedForward(EMBEDDING_SIZE)

        self.ln1 = nn.LayerNorm(EMBEDDING_SIZE)
        self.ln2 = nn.LayerNorm(EMBEDDING_SIZE)

    def forward(self, x):
        """Transform `x` of shape (batch, context, EMBEDDING_SIZE); same shape out."""
        # Residual connection around the masked multi-head attention
        attended = x + self.sa_heads(self.ln1(x))
        # Residual connection around the feed-forward network
        return attended + self.ffwd(self.ln2(attended))


In [64]:
class Decoder(nn.Module):
    """Character-level decoder-only transformer language model.

    Token and learned position embeddings are summed, passed through
    BLOCK_COUNT transformer blocks and a final layer norm, and projected to
    VOCAB_SIZE logits per position.
    """

    def __init__(self):
        super().__init__()

        # Embedding tables for the characters and for their positions in the context
        self.token_embedding_table = nn.Embedding(VOCAB_SIZE, EMBEDDING_SIZE)
        self.position_embedding_table = nn.Embedding(CONTEXT_SIZE, EMBEDDING_SIZE)
        self.blocks = nn.Sequential(*[Block(NUM_HEADS, HEAD_SIZE) for _ in range(BLOCK_COUNT)])
        self.ln_f = nn.LayerNorm(EMBEDDING_SIZE) # final layer norm
        self.lm_head = nn.Linear(EMBEDDING_SIZE, VOCAB_SIZE)

        # Re-initialize every Linear/Embedding with small normal weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero the biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, y=None):
        """Compute next-character logits and, if targets are given, the loss.

        Args:
            x: integer token ids of shape (batch, context), with
                context <= CONTEXT_SIZE.
            y: optional target ids with the same shape as `x`.

        Returns:
            (logits, loss). Without `y`, logits has shape
            (batch, context, VOCAB_SIZE) and loss is None. With `y`, logits
            is flattened to (batch * context, VOCAB_SIZE) and loss is the
            mean cross-entropy against `y`.
        """
        n_batch, n_context = x.shape

        tok_emb = self.token_embedding_table(x) # (batch, context, EMBEDDING_SIZE)
        # Build the positions on the input's device so the model does not
        # depend on a module-level `device` variable.
        positions = torch.arange(n_context, device=x.device)
        pos_emb = self.position_embedding_table(positions) # (context, EMBEDDING_SIZE)
        x = tok_emb + pos_emb # (batch, context, EMBEDDING_SIZE)
        x = self.blocks(x)
        x = self.ln_f(x) # (batch, context, EMBEDDING_SIZE)
        logits = self.lm_head(x) # (batch, context, VOCAB_SIZE)

        if y is None:
            loss = None
        else:
            # Flatten with the actual sequence length (not CONTEXT_SIZE) so
            # that shorter-than-full contexts also work.
            logits = logits.view(n_batch * n_context, VOCAB_SIZE)
            y = y.view(n_batch * n_context)
            loss = F.cross_entropy(logits, y)

        return logits, loss

    @torch.no_grad()
    def generate(self, previous_text, max_new_tokens):
        """Extend `previous_text` by `max_new_tokens` sampled characters.

        Sampling runs without gradient tracking and with the model in eval
        mode, so dropout does not perturb the predicted distribution; the
        previous train/eval mode is restored afterwards.

        Args:
            previous_text: non-empty prompt made of characters from the
                vocabulary; only its last CONTEXT_SIZE characters are used
                as context.
            max_new_tokens: number of characters to append.

        Returns:
            The prompt followed by the generated characters.

        Raises:
            ValueError: if the prompt is empty and characters are requested.
            KeyError: if the prompt contains a character outside the vocabulary.
        """
        if max_new_tokens > 0 and not previous_text:
            raise ValueError("previous_text must contain at least one character")

        model_device = self.token_embedding_table.weight.device
        was_training = self.training
        self.eval()
        try:
            # Encode the prompt once and keep a sliding window of token ids
            # instead of re-encoding the growing output string every step.
            window = encode(previous_text[-CONTEXT_SIZE:])
            new_chars = []
            for _ in tqdm(range(max_new_tokens)):
                context = torch.tensor(window, dtype=torch.long, device=model_device)

                # add batch dimension and feed to model
                logits, _ = self(context.view(1, -1))
                probs_next_char = F.softmax(logits[0, -1], dim=-1)
                next_token = torch.multinomial(probs_next_char, num_samples=1).item()

                new_chars.append(itos[next_token])
                window = (window + [next_token])[-CONTEXT_SIZE:]
        finally:
            self.train(was_training)

        return previous_text + ''.join(new_chars)

In [65]:
# Build the model on the target device and attach the optimizer
model = Decoder().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

for step in tqdm(range(MAX_STEPS)):
    # Periodically report the averaged train/validation loss
    if step % EVAL_INTERVAL == 0:
        losses = estimate_loss()
        print(f"Step {step}/{MAX_STEPS}) train: {losses['train']:.4f}, val: {losses['val']:.4f}")

    # One optimization step on a fresh random training batch
    xb, yb = get_batch("train")
    optimizer.zero_grad(set_to_none=True)
    logits, loss = model(xb, yb)
    loss.backward()
    optimizer.step()

  0%|          | 0/5000 [00:00<?, ?it/s]

100%|██████████| 200/200 [00:02<00:00, 93.39it/s] 
100%|██████████| 200/200 [00:01<00:00, 104.23it/s]
  0%|          | 1/5000 [00:04<5:49:53,  4.20s/it]

Step 0/5000) train: 4.4371, val: 4.4379


100%|██████████| 200/200 [00:01<00:00, 109.46it/s]
100%|██████████| 200/200 [00:01<00:00, 136.39it/s]
 10%|█         | 504/5000 [00:18<16:02,  4.67it/s]

Step 500/5000) train: 2.6307, val: 2.6180


100%|██████████| 200/200 [00:01<00:00, 109.07it/s]
100%|██████████| 200/200 [00:01<00:00, 108.11it/s]
 20%|██        | 1008/5000 [00:31<11:04,  6.01it/s]

Step 1000/5000) train: 2.3817, val: 2.3719


100%|██████████| 200/200 [00:01<00:00, 115.16it/s]]
100%|██████████| 200/200 [00:01<00:00, 116.75it/s]
 30%|███       | 1508/5000 [00:47<08:34,  6.78it/s]

Step 1500/5000) train: 2.2576, val: 2.2529


100%|██████████| 200/200 [00:01<00:00, 121.21it/s]]
100%|██████████| 200/200 [00:01<00:00, 106.66it/s]
 40%|████      | 2008/5000 [01:02<08:21,  5.97it/s]

Step 2000/5000) train: 2.1812, val: 2.1837


100%|██████████| 200/200 [00:01<00:00, 122.84it/s]]
100%|██████████| 200/200 [00:01<00:00, 110.03it/s]
 50%|█████     | 2506/5000 [01:16<06:10,  6.74it/s]

Step 2500/5000) train: 2.1469, val: 2.1635


100%|██████████| 200/200 [00:01<00:00, 108.16it/s]]
100%|██████████| 200/200 [00:02<00:00, 95.47it/s]
 60%|██████    | 3007/5000 [01:31<05:50,  5.69it/s]

Step 3000/5000) train: 2.1238, val: 2.1259


100%|██████████| 200/200 [00:01<00:00, 125.41it/s]]
100%|██████████| 200/200 [00:01<00:00, 119.95it/s]
 70%|███████   | 3505/5000 [01:45<05:05,  4.89it/s]

Step 3500/5000) train: 2.1055, val: 2.1098


100%|██████████| 200/200 [00:02<00:00, 92.46it/s] ]
100%|██████████| 200/200 [00:01<00:00, 103.38it/s]
 80%|████████  | 4009/5000 [02:00<03:08,  5.25it/s]

Step 4000/5000) train: 2.0823, val: 2.0922


100%|██████████| 200/200 [00:01<00:00, 116.51it/s]]
100%|██████████| 200/200 [00:01<00:00, 129.69it/s]
 90%|█████████ | 4507/5000 [02:13<01:16,  6.46it/s]

Step 4500/5000) train: 2.0724, val: 2.0769


100%|██████████| 5000/5000 [02:24<00:00, 34.71it/s]


In [66]:
# Inference (Generate Harry Potter'ish text)
# Seed the model with a single newline and sample 2000 further characters;
# `generate` returns the prompt followed by the sampled text.
output = model.generate("\n", 2000)
print(output)

100%|██████████| 2000/2000 [00:09<00:00, 203.99it/s]


of the sive dat untulingene froned o2 froby an —"
Harliss now"
layintan its corthoatif be bed him flootheinne the. If sher thaom eand flioumie wlunes.” mome, and
to grosyou now bad Propestly.f Halshing to
hoks
yod— "
"I sud mone't sto lastes arod das and at and louct betaine he pottirsthing awallly dneck red Ronce. "Chescorred sta mat cead
aad Vottropsed to couldls. 7is
Thim flabp int topichtly bbeuth dow Roff to said. Bangeld grous you -” Whe. Med. Thefthing wlouth Luncan's moppry groftie whe to
a
weamed eell andidling McAintion,"
“Wed ClanPlems to at? Chan'n that a happ af a hrelver - beh pot for Ron, dating, ih ligced on the Pells. Harklef it ssose alock.
EPelhilase. Dobed tlotl quice, at snomser arover tundheriftind," Cut He gughforgudtared iting lotgapetired hen of him scrounaled sapined thoated herl ldostpelf lad deyard were
had and ssome cmiarion
chack?" - seeok are tichoset. “Ang fat—
Ant
Harry mipstorte Snw Pertrorg in her butoroun, tied gropo dubunn’th eexppiatined him tit


