In [12]:
import os
import urllib
import urllib.request

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Hyperparameters shared by every cell below
block_size = 64        # context length: number of tokens per training sequence
batch_size = 256       # number of sequences sampled per batch
max_iters = 5000       # number of optimisation steps
eval_interval = 500    # estimate train/val loss every this many steps
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200       # batches averaged per split when estimating the loss
n_embd = 384           # embedding width
n_heads = 6            # attention heads per block
n_layers = 6           # number of transformer blocks
# A layer size has to be an int: `/` would produce the float 64.0, so use
# floor division and make the divisibility requirement explicit.
assert n_embd % n_heads == 0, "n_embd must be divisible by n_heads"
head_size = n_embd // n_heads
dropout = 0.2

In [13]:
# Fetch the Tiny Shakespeare text from the internet
shakespeare_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

# Download the file from the URL only if there is no local copy yet
shakespeare_path = "shakespeare.txt"
if not os.path.isfile(shakespeare_path):
    urllib.request.urlretrieve(shakespeare_url, shakespeare_path)

# Read with an explicit encoding so the text (and therefore the vocabulary)
# does not depend on the platform's default locale encoding.
with open(shakespeare_path, encoding="utf-8") as f:
    shakespeare_text = f.read()
print(shakespeare_text[:148])


First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?



In [14]:
# Create a sorted list of the unique characters in the text (the vocabulary)
chars = sorted(set(shakespeare_text))
vocab_size = len(chars)
print(f"Number of unique characters: {vocab_size}")

# Create a mapping from characters to indices and vice versa
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}


def encode(s):
    """Map a string to a list of integer ids, one per character (via `stoi`)."""
    return [stoi[ch] for ch in s]


def decode(i):
    """Map an iterable of integer ids back to the string they spell (via `itos`)."""
    return "".join(itos[ch] for ch in i)


# Convert the whole text to a 1-D tensor of character ids
data = torch.tensor(encode(shakespeare_text), dtype=torch.long)

# First 90% of the ids for training, the remaining 10% for validation
n_train = int(0.9 * len(data))
train_data = data[:n_train]
val_data = data[n_train:]

# Test the round trip: this should print the start of the original text
print(decode(data[:148].tolist()))

Number of unique characters: 65
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?



In [15]:
def get_batch(split):
    """
    Sample a random batch of input/target sequences.

    Args:
        split: "train" samples from train_data; any other value samples
            from val_data.

    Returns:
        (x, y): two tensors of shape (batch_size, block_size), both moved to
        `device`. y holds the same windows as x shifted one position to the
        right, i.e. the next-character target for every position of x.
    """
    source = train_data if split == "train" else val_data
    # randint's upper bound is exclusive, so the largest start index is
    # len(source) - block_size - 1 and the shifted target slice stays in range.
    idx = torch.randint(len(source) - block_size, (batch_size,))
    x = torch.stack([source[i:i + block_size] for i in idx])
    y = torch.stack([source[i + 1:i + block_size + 1] for i in idx])
    # Return the tensors that actually live on `device`; returning the
    # un-moved ones breaks as soon as the model sits on a GPU.
    return x.to(device), y.to(device)

# Smoke-test get_batch: print the shapes of the sampled input and target batches
inputs, outputs = get_batch("train")
print(inputs.shape, outputs.shape)

# Decode the first example of the batch; the second line printed is the
# target, i.e. the same window of text shifted by one character
print(decode(inputs[0].tolist()))
print(decode(outputs[0].tolist()))


torch.Size([256, 64]) torch.Size([256, 64])
 some more mightier member
That sets them on: let me have way, m
some more mightier member
That sets them on: let me have way, my


In [16]:
@torch.no_grad()
def estimate_loss():
    """
    Average the model's loss over `eval_iters` random batches per split.

    Returns:
        dict with the keys "train" and "val", each holding the mean loss over
        the sampled batches as a Python float.

    The model is switched to eval mode while measuring and back to training
    mode before returning. Gradients are disabled by the decorator.
    """
    model.eval()
    averages = {}
    for split in ("train", "val"):
        per_batch = torch.zeros(eval_iters)
        for k in range(eval_iters):
            _, batch_loss = model(*get_batch(split))
            per_batch[k] = batch_loss.item()
        averages[split] = per_batch.mean().item()
    model.train()
    return averages

#### Self-attention:
Each token emits three vectors:
* Query: represents what the token is looking for
* Key: represents what the token contains
* Value: represents the content that gets aggregated (weighted-averaged) once the attention weights are known

For each node/token we compute a similarity score between its query and the keys of the tokens it is allowed to attend to (itself included) using a dot product. After a softmax, these scores determine how much attention the token at time t pays to each of those tokens, and its output is the corresponding weighted average of their values.


### Notes:
* Decoding: only care about the past
* Encoding: include the future as well (no masking)
* Attention is just a communication protocol between a set of nodes. It's like learning a probabilistic graph where the edge probabilities determine how much attention/communication a node derives from each other node.
* Positional encoding is necessary because attention by itself has no notion of order; without it the nodes are just an unordered set placed arbitrarily in space.
* Self-attention: because the keys and values come from x (the data itself)
* Cross-attention: if keys and values are supplied externally
* Dividing by sqrt(head_size) preserves the variance of the attention scores at initialization. Otherwise the softmax may be too peaky (favoring a few values). We want diffuse weights at the start of training.

In [17]:
class Head(nn.Module):
    """
    One head of causal (masked) self-attention.

    Args:
        C: size of the last dimension of the input.
        head_size: size of the key/query/value projections and of the output.
    """

    def __init__(self, C, head_size):
        super().__init__()
        # Bias-free linear projections of the input into key/query/value space
        self.key = nn.Linear(C, head_size, bias=False)
        self.query = nn.Linear(C, head_size, bias=False)
        self.value = nn.Linear(C, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular matrix of ones, stored as a buffer (not a parameter)
        lower_triangle = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', lower_triangle)

    def forward(self, x):
        """Map x of shape (B, T, C) to attention output of shape (B, T, head_size)."""
        _, T, _ = x.shape

        keys = self.key(x)        # (B, T, head_size)
        queries = self.query(x)   # (B, T, head_size)
        values = self.value(x)    # (B, T, head_size)

        # Scaled dot-product scores between every query and every key: (B, T, T)
        scale = np.sqrt(keys.shape[-1])
        scores = queries @ keys.transpose(-2, -1) / scale

        # Positions above the diagonal are set to -inf so that they receive
        # zero weight after the softmax
        blocked = self.tril[:T, :T] == 0
        scores = scores.masked_fill(blocked, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ values


In [18]:
class MultiHeadAttention(nn.Module):
    """
    Several attention heads run side by side on the same input.

    The head outputs are concatenated along the last dimension, passed
    through a linear layer mapping n_heads * head_size features to C, and
    then through dropout.
    """

    def __init__(self, C, head_size, n_heads):
        super().__init__()
        head_modules = [Head(C, head_size) for _ in range(n_heads)]
        self.heads = nn.ModuleList(head_modules)
        self.linear = nn.Linear(n_heads * head_size, C)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply every head to x, concatenate the results, project and apply dropout."""
        per_head = [head(x) for head in self.heads]
        concatenated = torch.cat(per_head, dim=-1)
        return self.dropout(self.linear(concatenated))

In [19]:
class FeedForwardBlock(nn.Module):
    """
    Feed-forward sub-layer: Linear(n_embd -> 4 * n_embd), GELU,
    Linear(4 * n_embd -> n_embd), then dropout.
    """

    def __init__(self):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run x through the stack; the size of the last dimension is unchanged."""
        return self.net(x)

In [20]:
class Block(nn.Module):
    """
    One transformer block: multi-head self-attention followed by a
    feed-forward sub-layer. Each sub-layer is applied to a LayerNorm of its
    input and its result is added back to that input (residual connection).
    """

    def __init__(self):
        super().__init__()
        per_head = n_embd // n_heads  # the heads together span n_embd features
        self.sa_heads = MultiHeadAttention(n_embd, per_head, n_heads)
        self.ff = FeedForwardBlock()
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers, each with a residual connection."""
        after_attention = x + self.sa_heads(self.ln1(x))
        return after_attention + self.ff(self.ln2(after_attention))

In [21]:
# NOTE: the class keeps the name "BigramLanguageModel", but it is not a bigram
# model (one that predicts the next character from the current character only).
# Each position is predicted from up to `block_size` previous characters:
# token + position embeddings -> n_layers Blocks -> LayerNorm -> linear head.
# Text is generated by repeatedly sampling the next character from the
# predicted distribution.

class BigramLanguageModel(torch.nn.Module):
    """
    Character-level language model built from a stack of `Block` modules.

    forward(x, target) returns (logits, loss); generate(x, length) samples a
    continuation of x one token at a time.
    """

    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, n_embd)
        self.position_embedding = torch.nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block() for _ in range(n_layers)])
        self.ln = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, x, target=None):
        """
        Compute next-token logits and, optionally, the training loss.

        Args:
            x: integer tensor of token ids, shape (B, T). T must not exceed
                block_size, the size of the position-embedding table.
            target: optional integer tensor of shape (B, T) holding the
                expected next token id for every position of x.

        Returns:
            (logits, loss): logits has shape (B, T, vocab_size); loss is the
            cross-entropy over all B*T positions, or None without a target.
        """
        B, T = x.shape

        tok_embds = self.embedding(x)                    # (B,T) -> (B,T,C)
        # Create the position ids directly on x's device instead of building
        # them on the CPU and copying them over.
        positions = torch.arange(T, device=x.device)
        pos_embds = self.position_embedding(positions)   # (T,C)
        x = tok_embds + pos_embds                        # (B,T,C), broadcast over B
        x = self.blocks(x)
        x = self.ln(x)
        logits = self.lm_head(x)                         # (B,T,vocab_size)

        if target is None:
            return logits, None

        # cross_entropy expects (N, classes) logits and (N,) class indices, so
        # the batch and time dimensions are flattened together.
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, x, length=1):
        """
        Sample `length` new tokens and append them to x.

        Args:
            x: integer tensor of token ids, shape (B, T); used as the prompt.
            length: number of tokens to sample. For length <= 0 the prompt is
                returned unchanged.

        Returns:
            Integer tensor of shape (B, T + length).

        Sampling runs without gradient tracking and in eval mode, so dropout
        does not perturb the predicted distribution. The module's previous
        train/eval mode is restored before returning.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(length):
                # Only the last block_size tokens fit the position-embedding table
                logits, _ = self(x[:, -block_size:])
                # Distribution over the next token, taken at the last position
                probs = F.softmax(logits[:, -1], dim=-1)
                next_char = torch.multinomial(probs, 1)   # (B,1)
                x = torch.cat([x, next_char], dim=1)
        finally:
            self.train(was_training)
        return x
    
# Instantiate the language model and move it to `device`
model = BigramLanguageModel().to(device)

# Smoke test: one forward pass with targets on a training batch
inputs, outputs = get_batch("train")  # inputs: (B,T), outputs: (B,T)
logits, loss = model(inputs, outputs) # logits: (B,T,vocab_size), loss: scalar
print(logits.shape, loss)

# Optional generation check, left disabled:
# inputs, outputs = get_batch("train")
# generated = model.generate(inputs, 100)
# print(decode(generated[0].tolist()))

torch.Size([256, 64, 65]) tensor(4.3140, grad_fn=<NllLossBackward0>)


In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Make sure dropout is active even if an earlier cell left the model in eval mode
model.train()

# The counter is called `step` so that the builtin iter() is not shadowed for
# the rest of the notebook.
for step in range(max_iters):
    X, Y = get_batch("train")
    optimizer.zero_grad()
    logits, loss = model(X, Y)
    loss.backward()
    optimizer.step()

    # Report every eval_interval steps, and after the last step so the loss
    # at the end of training is visible as well.
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"Iter: {step}, Train loss: {losses['train']}, Val loss: {losses['val']}")

# Generate text, starting from a single token with id 0
model.eval()  # switch dropout off while sampling
with torch.no_grad():
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    print(decode(model.generate(context, 500)[0].tolist()))

### Training Notes:

* Dropout:
    * Added before residual connection
    * Added after projection in multihead
    * Can be added on the weights in the heads before multiplication
* LayerNorm:
    * Normalized on an embedding level. Has parameters so that it can scale back up if necessary (depending on optimization)
* BatchNorm:
    * Normalizes over examples in a batch
* Reduce learning_rate for deeper networks


### Notes on Architecture:

* Decoder-only architecture (attention weights are masked with the lower-triangular `tril` matrix, so tokens only attend to the past)
* The Encoder in the original paper was for language translation:
    * You encode the original sentence and condition on it
    * Full (unmasked) attention weights
    * Feed keys and values from the encoder to the decoder
* Training Stages: 
    1) Pretraining: decoding only
    2) Finetuning / Alignment: through RLHF (Labelers label or rank outputs and we use RL to maximize reward as per the feedback)

### Intermediate tests and layers

In [None]:
# Fix the RNG seed so the random example tensor is the same on every run
torch.manual_seed(1337)
# Toy dimensions; x below is created with shape (B, T, C)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# Loop-based reference: for an input of dim (B, T, C), average each position
# at time t with all the positions before it
def xbow(x):
    """
    Running average over the second dimension.

    Entry [b, t] of the result is the mean of x[b, 0], ..., x[b, t]. The
    result has the same shape as x and dtype torch.float.
    """
    running = torch.zeros_like(x, dtype=torch.float)

    for b, sequence in enumerate(x):
        for t in range(x.size(1)):
            # Mean over everything up to and including position t
            running[b, t] = sequence[: t + 1].mean(dim=0)
    return running

# Show the running averages for the batch element at index 1
print(xbow(x)[1])

def xbow_efficient(x):
    """
    Running average over the second dimension, computed with one matmul.

    Args:
        x: tensor of shape (B, T, C).

    Returns:
        float32 tensor of shape (B, T, C) whose entry [b, t] is the mean of
        x[b, 0], ..., x[b, t].
    """
    # Take the sequence length from the input itself instead of relying on a
    # notebook-level variable `T` that may be missing or stale.
    T = x.shape[1]
    # Row t of the matrix holds 1/(t+1) in columns 0..t and zeros elsewhere,
    # so multiplying by it averages the first t+1 positions.
    counts = torch.arange(1, T + 1, device=x.device, dtype=torch.float).unsqueeze(1)
    avg_mat = torch.tril(torch.ones(T, T, dtype=torch.float, device=x.device)) / counts
    return torch.matmul(avg_mat, x.float())

def xbow_softmax(x):
    """
    Running average over the second dimension, expressed as a masked softmax.

    Args:
        x: floating-point tensor of shape (B, T, C).

    Returns:
        Tensor with x's shape, dtype and device whose entry [b, t] is the
        mean of x[b, 0], ..., x[b, t].
    """
    T = x.shape[1]
    # Build the mask and the weights on x's device and in x's dtype, so that
    # the final matmul also works for float64 or GPU inputs.
    tril = torch.tril(torch.ones(T, T, device=x.device))
    wei = torch.zeros((T, T), dtype=x.dtype, device=x.device)
    # All-zero scores with -inf above the diagonal: the softmax of row t is
    # uniform over positions 0..t, which is exactly the averaging weight.
    wei = wei.masked_fill(tril == 0, float('-inf'))
    wei = torch.nn.functional.softmax(wei, dim=-1)
    return wei @ x


In [None]:
# Single-head causal self-attention, written out step by step
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn((B, T, C))

head_size = 16
# Three bias-free projections, created in the order key, query, value
key, query, value = (nn.Linear(C, head_size, bias=False) for _ in range(3))

k, q, v = key(x), query(x), value(x)    # each (B, T, head_size)

# Scaled dot product of every query with every key -> (B, T, T)
scores = q @ k.transpose(-2, -1) / np.sqrt(head_size)

# Entries above the diagonal are masked out before the softmax
tril = torch.tril(torch.ones(T, T))
wei = F.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)

# Weighted sum of the values -> (B, T, head_size)
avg = wei @ v

avg.shape
