# Transformer Architecture

![Transformer Architectural Diagram](https://d2l.ai/_images/transformer.svg)

This notebook implements the right hand side of this diagram. The `DecoderTransformer` is constructed such that it created an embedding and positional encoding of the input, which are combined before it enters into the `Block`s of `MultiheadAttention` networks. The output of the `MultiheadAttention` is passed through a fully connected MLP to produce the final output. 

What makes the implemented transformer a `DecoderTransformer` rather than an encoder is the inclusion of the following line in the attention `Head`: `self.register_buffer('tril', torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE)))`. This line ensures that the tokens only receive information from tokens that occur before the current token being processed. In other words, it ensures that they only see the past and not into the future. 

## Optimisations
### Layer Norm
- Mention the skip connections and dropout for the improved training of deep networks

### Skip Connections
TODO:

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [22]:
# Hyperparameters
torch.manual_seed(1337)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 64 # How many sequences to process at in parallel
BLOCK_SIZE = 256 # What is the maximum context length for predictions?
MAX_ITERATIONS = 101
LEARNING_RATE = 3e-4
EVAL_EVERY = 100
N_EMBEDDING_DIMENSIONS = 384
N_HEAD = 6
N_LAYER = 6
DROPOUT = 0.2

# Single Attention Head
 - Talk about the upper traingular portion

In [23]:
class Head(nn.Module):
    def __init__(self, head_size) -> None:
        super().__init__()
        self.head_size = head_size
        self.query = nn.Linear(N_EMBEDDING_DIMENSIONS, head_size, bias=False)
        self.key = nn.Linear(N_EMBEDDING_DIMENSIONS, head_size, bias=False)
        self.value = nn.Linear(N_EMBEDDING_DIMENSIONS, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE)))
        self.dropout = nn.Dropout(DROPOUT)


    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x) # (B,T,head_size)
        q = self.query(x) # (B,T,head_size)

        wei = q @ k.transpose(-2,-1) * (self.head_size**-0.5)
        wei = wei.masked_fill(self.tril[:T, :T]==0, float('-inf')) # This makes it a decoder block
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v = self.value(x)
        
        out = wei @ v
        
        return out

# Multi-Head Attention
- How does going to "multi head" improve things

In [24]:
class MultiHead(nn.Module):
    def __init__(self, n_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_heads)])
        self.proj = nn.Linear(N_EMBEDDING_DIMENSIONS, N_EMBEDDING_DIMENSIONS) # Added for skip connections
        self.dropout = nn.Dropout(DROPOUT)
    

    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.proj(out)
        out = self.dropout(out)
        return out

# Feed Forward Network

In [25]:
class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed), # Multiply by 4 to copy the Transformer paper
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed), # Projection layer added for skip connections
            nn.Dropout(DROPOUT)
        )


    def forward(self, x):
        return self.net(x)

# Multi-Head Attention Block

In [26]:
class Block(nn.Module):
    def __init__(self, n_embed, n_head):
        super().__init__()
        head_size = n_embed // n_head
        self.self_attention = MultiHead(n_head, head_size)
        self.feed_forward = FeedForward(n_embed)
        self.layer_norm_1 = nn.LayerNorm(n_embed)
        self.layer_norm_2 = nn.LayerNorm(n_embed)
    

    def forward(self, x):
        # NOTE: The x + self.self_attention(x) is the residual connection/skip connection
        x = x + self.self_attention(self.layer_norm_1(x))
        x = x + self.feed_forward(self.layer_norm_2(x))
        return x

# Decoder Transformer

In [27]:
class DecoderTransformer(nn.Module):
    def __init__(self, vocab_size, embedding_size, block_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, embedding_size)
        self.position_embedding_table = nn.Embedding(block_size, embedding_size)
        self.blocks = nn.Sequential(*[Block(embedding_size, N_HEAD) for _ in range(N_LAYER)])
        self.layer_norm = nn.LayerNorm(embedding_size)
        self.lm_head = nn.Linear(embedding_size, vocab_size)
    

    def forward(self, idx, targets=None):
        B,T = idx.shape

        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.layer_norm(x)

        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)

            loss = F.cross_entropy(logits, targets)

        return logits, loss


    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx = idx[:, -BLOCK_SIZE:] # Crop to block size
            logits, loss = self(idx)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx

# Helper Functions

In [28]:
train_data = None
val_data = None

def load_dataset(path: str) -> str:
    with open(path, 'r') as f:
        text = f.read()
    return text


def get_vocab(text: str) -> list:
    return sorted(list(set(text)))


def get_batch(split: str):
    assert split in ['train', 'val']
    data = train_data if split == 'train' else val_data
    
    ix = torch.randint(0, len(data) - BLOCK_SIZE, (BATCH_SIZE,))
    x = torch.stack([data[i:i+BLOCK_SIZE] for i in ix])
    y = torch.stack([data[i+1:i+BLOCK_SIZE+1] for i in ix])
    return x, y

@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()

    for split in ['train', 'val']:
        losses = torch.zeros(EVAL_EVERY)
        for k in range(100):
            xb, yb = get_batch(split)
            logits, loss = model(xb, yb)
            losses[k] = loss.item()
        out[split] = losses.mean()
    
    model.train()
    return out

In [29]:
# Load things
dataset = load_dataset('../datasets/tiny-shakespeare.txt')
vocab = get_vocab(dataset)

stoi = {u:i for i, u in enumerate(vocab)}
itos = {i:u for i, u in enumerate(vocab)}
encode = lambda x: [stoi[c] for c in x]
decode = lambda x: ''.join([itos[c] for c in x])

data = torch.tensor(encode(dataset), dtype=torch.long).to(DEVICE)
train_val_ratio = int(0.9 * len(data))
train_data = data[:train_val_ratio]
val_data = data[train_val_ratio:]

model = DecoderTransformer(vocab_size=len(vocab), 
                           embedding_size=N_EMBEDDING_DIMENSIONS, 
                           block_size=BLOCK_SIZE)
model.to(DEVICE)

optimiser = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [30]:
# Training Loop
for iters in range(MAX_ITERATIONS):
    if iters % EVAL_EVERY == 0:
        losses = estimate_loss(model)
        print(f"Step {iters}: train loss {losses['train']:.3f}, val loss {losses['val']:.3f}")

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimiser.zero_grad(set_to_none=True)
    loss.backward()
    optimiser.step()

Step 0: train loss 4.285, val loss 4.282
Step 100: train loss 2.473, val loss 2.490


In [31]:
input_idx = torch.zeros(1, 1, dtype=torch.long, device=DEVICE) # start with a single token {0: '\n'}
print(decode(model.generate(input_idx, max_new_tokens=1000)[0].tolist()))

athqissthetanchemave mashimen tand, cevont bld yomuthy daime my towindss t byofflir d t's ene.
Thofind bero gh ns ourerrs ve ttow.

I thomeder;
K:
Me ld, ted bout adowhoitheith t, bo, fot g. hice's ithe Herit d n's O:
Courdon ben o s PA imou trospangin; gte
