# BiGram Model Implementation

This Jupyter Notebook follows Andrej Karpathy's video: [Let's Build GPT: from scratch](https://youtu.be/kCc8FmEb1nY). I've added comments where important concepts come up, along with notes in a few other places.

In [42]:
# read the whole text file into a single string
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# establish the characters and the vocabulary size
chars = sorted(list(set(text)))
vocab_size = len(chars)

Create the encoder and decoder functions.

In [43]:
# initialise lookup tables
char_ind = {}
ind_char = {}
for i, ch in enumerate(chars):
    char_ind[ch] = i
    ind_char[i] = ch

In [44]:
# encoder and decoder functions
def encode(input_str:str) -> list:
    return [char_ind[s] for s in input_str]

def decode(input_list:list) -> str:
    return ''.join([ind_char[l] for l in input_list])
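
As a quick sanity check, the sketch below rebuilds the two lookup tables on a toy string and confirms that decoding an encoded string returns the original. The string `"hello world"` and the `toy_` names are assumptions for illustration only; the notebook itself works on `input.txt`.

```python
# Toy stand-in for the notebook's text (assumption: not the real input.txt).
toy_text = "hello world"

# Same construction as above: the sorted unique characters define the vocabulary.
toy_chars = sorted(set(toy_text))
toy_char_ind = {ch: i for i, ch in enumerate(toy_chars)}
toy_ind_char = {i: ch for i, ch in enumerate(toy_chars)}

def toy_encode(s: str) -> list:
    return [toy_char_ind[ch] for ch in s]

def toy_decode(ids: list) -> str:
    return ''.join(toy_ind_char[i] for i in ids)

toy_encoded = toy_encode("hello")
print(toy_chars)                 # [' ', 'd', 'e', 'h', 'l', 'o', 'r', 'w']
print(toy_encoded)               # [3, 2, 4, 4, 5]
print(toy_decode(toy_encoded))   # hello
```

Because the vocabulary is sorted, whitespace characters end up with the smallest indices.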

Andrej Karpathy points out a trade-off between vocabulary size and encoded sequence length. With a large vocabulary, a given piece of text is encoded as a shorter sequence of integers; with a small vocabulary, the same text needs a longer sequence.
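
The trade-off can be made concrete with a rough sketch that tokenises one sentence at three granularities. The sentence and the three schemes (bits, characters, two-character chunks) are assumptions chosen for illustration; they are not tokenisers used in the notebook or the video.

```python
sentence = "to be or not to be"  # hypothetical example text

def summarise(tokens):
    """Return (vocabulary size, sequence length) for a list of tokens."""
    return len(set(tokens)), len(tokens)

# Smallest vocabulary: every character spelled out as 8 bits.
bit_tokens = [bit for ch in sentence for bit in format(ord(ch), '08b')]
# Character level, as in this notebook.
char_tokens = list(sentence)
# Larger units: non-overlapping chunks of two characters.
pair_tokens = [sentence[i:i + 2] for i in range(0, len(sentence), 2)]

for name, tokens in [("bits", bit_tokens), ("chars", char_tokens), ("pairs", pair_tokens)]:
    vocab, length = summarise(tokens)
    print(f"{name:>5}: vocab size {vocab:>2}, sequence length {length:>3}")
```

On this sentence the vocabulary grows from 2 to 7 to 9 while the sequence shrinks from 144 to 18 to 9 tokens.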

#### Convert the data into tensors with PyTorch

In this step we convert our encoded data into a PyTorch tensor. All later operations work on tensors, since this is the data type that PyTorch models and optimisers operate on.

In [45]:
import torch

data = torch.tensor(encode(text), dtype=torch.long)

The output below suggests that index `0` encodes the newline character and index `1` encodes a space.

In [46]:
print(data.shape, data.dtype, data[:100], text[:100], sep = '\n\n')

torch.Size([1115394])

torch.int64

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


#### Split up the data into training and evaluation

Here we've chosen a 90/10 split. 90% of our data will be used to train the model and the other 10% will be used to evaluate and test the model.

In [47]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

The code below is taken directly from Andrej Karpathy. It creates the batches of data we will train our model on.

In [48]:
torch.manual_seed(1337)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 4
block_size = 8 

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [49]:
xb, yb = get_batch('train')
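
To see what a single row of `xb` and `yb` represents, here is a small framework-free sketch using the first nine token ids printed earlier (they spell "First Cit"). The `demo_` names are introduced only for this illustration. Each row of a batch packs `block_size` training examples: every prefix of the input is a context, and the target is the token that follows it.

```python
# First nine token ids from the printed data[:100] above.
demo_chunk = [18, 47, 56, 57, 58, 1, 15, 47, 58]
demo_block_size = 8

demo_x = demo_chunk[:demo_block_size]        # inputs
demo_y = demo_chunk[1:demo_block_size + 1]   # targets: the inputs shifted left by one

demo_pairs = []
for t in range(demo_block_size):
    context = demo_x[:t + 1]   # everything up to and including position t
    target = demo_y[t]         # the token that comes next
    demo_pairs.append((context, target))
    print(f"when the context is {context} the target is {target}")
```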

### Bigram Language Model Architecture

The model architecture is fairly standard and templates are available, for example in the PyTorch documentation. Here's a link to their `NGramLanguageModeler` example:

https://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html

In [50]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__ (self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)

        if targets is None:
            loss = None

        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1 , :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [51]:
m = BigramLanguageModel(vocab_size).to(device)  # keep the model on the same device as the batches
logits, loss = m(xb, yb)

In [52]:
idx = torch.zeros((1,1), dtype=torch.long, device=device)  # kickstart generation from token 0 (the newline character).
generated_tokens = m.generate(idx, max_new_tokens=100)[0]  # generate 100 new tokens after the starting token.
encoded_list = generated_tokens.tolist()                   # convert tensor into list.
print(decode(encoded_list))                                # decode the list to give the actual output.
print(loss)


SKIcLT;AcELMoTbvZv C?nq-QE33:CJqkOKH-q;:la!oiywkHjgChzbQ?u!3bLIgwevmyFJGUGp
wnYWmnxKWWev-tDqXErVKLgJ
tensor(4.8786, grad_fn=<NllLossBackward0>)


The result above is complete gibberish: the model hasn't been trained yet, so it is just sampling characters from the near-random probabilities given by its initial weights.
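
A useful reference point for the printed loss is the cross-entropy of a model that predicts every character with equal probability, which is $\ln(\text{vocab\_size})$. The sketch below computes it by hand, assuming a vocabulary of 65 characters (the size of the tiny Shakespeare character set used in the video; the notebook does not print `vocab_size`, so treat this value as an assumption).

```python
import math

assumed_vocab_size = 65  # assumption: not printed anywhere in the notebook

# Cross-entropy of a uniform prediction: -log(1 / V) = log(V).
uniform_loss = math.log(assumed_vocab_size)
print(f"{uniform_loss:.4f}")   # 4.1744

def cross_entropy_by_hand(logits, target):
    """Negative log-probability of `target` under softmax(logits)."""
    top = max(logits)
    log_sum_exp = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_sum_exp - logits[target]

# All-equal logits reproduce the uniform loss, whatever the target is.
flat_loss = cross_entropy_by_hand([0.0] * assumed_vocab_size, 10)
print(f"{flat_loss:.4f}")      # 4.1744
```

The untrained model's loss of 4.8786 is above this baseline, which is plausible because the randomly initialised embedding table produces uneven, and often confidently wrong, predictions rather than uniform ones.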

In [53]:
optimizer = torch.optim.AdamW(m.parameters(), lr=0.001)

batch_size = 32
for steps in range(1000):
    xb, yb = get_batch('train')

    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

3.704137086868286


In [54]:
print(decode((m.generate(idx = torch.zeros((1,1), dtype=torch.long, device=device), max_new_tokens=200)[0].tolist())))


Wh;;Sq.f ustNzknc
kwgOj$dhPWr,SV?hsusiKpgXXUh;Apmem d?hESXI.i;TrJgkiF-oKbXCAA -botrngFCHAUQkn$

pn$w-gHoi?wtd!
LLULIfSK'bAw :M.ZtOptXEQcL?hfaofqbPd?OnonQQJMap$aypupIBYGUsZaI'ottllo..k$W$Akp?yl?ajKlzY!


The output now shows some visible structure, with the occasional word-like fragment, rather than pure gibberish.
This is the simplest possible model because "the tokens aren't talking to each other", as Andrej Karpathy puts it in his video. Further below we implement self-attention and then multi-headed attention so that they can.

The loss reported above is very noisy because it is just the loss of the last training batch. Averaging the loss over many batches, on both the training and validation splits, gives a much more reliable picture of how the model is doing.
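
The effect of averaging can be illustrated without a model at all. In the sketch below, `noisy_batch_loss` is a hypothetical stand-in for a single mini-batch loss (a true value of 2.5 plus Gaussian noise); the numbers are assumptions, not measurements from this notebook.

```python
import random
import statistics

random.seed(0)

def noisy_batch_loss():
    # Hypothetical single-batch loss: true value 2.5 plus random noise.
    return random.gauss(2.5, 0.3)

def averaged_loss(n_batches=200):
    # Mean over many batches, mirroring what an eval_iters-style average does.
    return sum(noisy_batch_loss() for _ in range(n_batches)) / n_batches

single_estimates = [noisy_batch_loss() for _ in range(50)]
averaged_estimates = [averaged_loss() for _ in range(50)]

spread_single = statistics.stdev(single_estimates)
spread_averaged = statistics.stdev(averaged_estimates)
print(f"spread of single-batch losses: {spread_single:.3f}")
print(f"spread of 200-batch averages:  {spread_averaged:.3f}")
```

For independent batches the spread of the average shrinks roughly with the square root of the number of batches, so 200 batches should be about 14 times less noisy than one.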

In [55]:
# generalise some of our previously established code
model = BigramLanguageModel(vocab_size = len(chars))
device = 'cuda' if torch.cuda.is_available() else 'cpu'
m = model.to(device)
eval_iters = 200

In [56]:
# create the new loss estimation function
@torch.no_grad()  # gradients are not needed for evaluation, which saves memory and time
def estimate_loss():
    out = {}
    model.eval() # put the model into evaluation mode

    # evaluate the model on the training data and the evaluation data
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X,Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [57]:
max_iters = 10000
eval_interval = 1000

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    
    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.6001, val loss 4.6191
step 1000: train loss 3.6539, val loss 3.6814
step 2000: train loss 3.0974, val loss 3.1083
step 3000: train loss 2.7838, val loss 2.8092
step 4000: train loss 2.6390, val loss 2.6522
step 5000: train loss 2.5513, val loss 2.5750
step 6000: train loss 2.5237, val loss 2.5498
step 7000: train loss 2.4866, val loss 2.5079
step 8000: train loss 2.4819, val loss 2.5129
step 9000: train loss 2.4698, val loss 2.4920


In [58]:
context = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=200)[0].tolist()))


ke t HAduriamy vel ive-fcome s te col MEOfel.
O:
ANG IUCKERO:
MPO:
LAS:
T:
QUGIs.
K:
If hasevep, lenghen mbru-get;xas
KIThar, kis Whe s to s owor whaty al wigalakee I myfoutong llik coulond:
MBus wor,


# Attention Modules.

Before we start, here is a "mathematical trick" in the self-attention module.

In [59]:
torch.manual_seed(1337)
B, T, C = 4, 8, 2            # size constraints -- Batch | Time | Channel.
x = torch.randn(B, T, C)     # initialise random vector with required size.
x.shape

torch.Size([4, 8, 2])

There are 8 tokens in each sequence of the tensor above. We want the tokens to start talking to each other, with the constraint that the token $x_i$ may receive information from itself and all previous tokens but NOT from $x_{i+1}$ or any later token. This is because the goal of our model is to predict the next token: if a position could see the next token, the model could simply copy it during training and would learn nothing useful for generation.

The simplest way to pass information along is to average: for every position $i$ we take the mean of all token vectors up to and including $x_i$. There are 3 different methods we can employ to perform our averaging.

## Version 1: `FOR` Loop Iteration.

In [60]:
x_bow = torch.zeros((B, T, C))
for b in range(B):
    for t in range(T):
        x_prev = x[b,:t + 1]
        x_bow[b,t] = torch.mean(x_prev, 0)

print(x[0], x_bow[0], sep='\n\n')

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
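
The first rows of the second tensor can be checked by hand. This sketch recomputes the running averages for the first three rows of `x[0]` in plain Python, using the rounded values printed above (so the results agree only to about four decimals). The `demo_` names are introduced for this illustration.

```python
# First three rows of x[0] as printed above (rounded to four decimals).
demo_rows = [
    [0.1808, -0.0700],
    [-0.3596, -0.9152],
    [0.6258, 0.0255],
]

demo_prefix_means = []
for t in range(len(demo_rows)):
    prev = demo_rows[:t + 1]                              # tokens 0..t inclusive
    mean = [sum(col) / len(prev) for col in zip(*prev)]   # average each channel
    demo_prefix_means.append(mean)

for mean in demo_prefix_means:
    print([round(v, 4) for v in mean])
```

The three printed rows should match the first three rows of `x_bow[0]`.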


## Version 2: Matrix Multiplication.

We can make this much faster by using matrix multiplication instead of a `for` loop. Here is a general demonstration.

In [61]:
torch.manual_seed(69)
a = torch.tril(torch.ones(3, 3)) # 3x3 lower-triangular matrix of ones
a = a / torch.sum(a, 1, keepdim=True)  # normalise so that each row sums to 1.
b = torch.randint(0,10,(3,2)).float() # 3x2 matrix

c = a @ b  # row i of c is the average of rows 0..i of b; c has size 3x2

print(f'a=\n{a}', f'b=\n{b}', f'c=\n{c}', sep='\n')


a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b=
tensor([[0., 5.],
        [9., 3.],
        [2., 5.]])
c=
tensor([[0.0000, 5.0000],
        [4.5000, 4.0000],
        [3.6667, 4.3333]])
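
To make explicit why multiplying by `a` averages the rows, the sketch below repeats the computation with plain Python lists, using the values of `b` printed above. The `demo_` names are introduced for this illustration only.

```python
n = 3
# Lower-triangular matrix of ones, then divide each row by its own sum.
demo_tril = [[1.0 if j <= i else 0.0 for j in range(n)] for i in range(n)]
demo_avg = [[v / sum(row) for v in row] for row in demo_tril]

demo_b = [[0.0, 5.0], [9.0, 3.0], [2.0, 5.0]]   # the matrix b printed above

# Plain matrix product: c[i][j] = sum over k of demo_avg[i][k] * demo_b[k][j].
demo_c = [
    [sum(demo_avg[i][k] * demo_b[k][j] for k in range(n)) for j in range(2)]
    for i in range(n)
]
for row in demo_c:
    print([round(v, 4) for v in row])
```

Row `i` of the weight matrix puts weight `1/(i+1)` on rows `0..i` and zero on later rows, so the product is exactly a running mean.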


Let's apply this to our situation.

In [62]:
# initialise the weight matrix.
weights = torch.tril(torch.ones(T, T))
weights = weights/ weights.sum(1, keepdim=True)

weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

This is the same row-normalised lower-triangular averaging matrix as `a` above, now of size $T \times T$.

In [63]:
x_bow_mm = weights @ x # (T, T) @ (B, T, C) --> (B, T, C): weights is broadcast across the batch dimension.
x_bow_mm[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [64]:
torch.allclose(x_bow_mm, x_bow) # check that both methods agree (up to floating-point tolerance).

True

## Version 3: Softmax.

This will be our preferred method; the explanation follows after the next couple of cells.

In [65]:
tril = torch.tril(torch.ones(T, T))  # starts off as a lower triangular matrix with all entries being 1.
weights_sm = torch.zeros((T, T))     # starts off as the zero matrix.
# for all elements where tril == 0, set that entry to -inf
weights_sm = weights_sm.masked_fill(tril == 0, float('-inf')) # "the future can't communicate with the past."

weights_sm

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

This method is beneficial because we can think of the $-\infty$ values as masks, essentially telling the model that those future tokens don't exist: after the softmax they receive a weight of exactly 0.
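
Here is a small hand-rolled softmax that shows why the mask works: $e^{-\infty} = 0$, so masked entries get zero weight and the remaining equal scores share the probability evenly. This is a plain-Python sketch on a 3x3 example, not the PyTorch implementation.

```python
import math

def softmax_row(row):
    top = max(row)                              # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in row]     # math.exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

neg_inf = float('-inf')
demo_masked = [
    [0.0, neg_inf, neg_inf],
    [0.0, 0.0, neg_inf],
    [0.0, 0.0, 0.0],
]
demo_weights = [softmax_row(row) for row in demo_masked]
for row in demo_weights:
    print([round(p, 4) for p in row])
```

The zeros in the unmasked positions are only a starting point; once they are replaced by learned scores, the same masking and softmax give a weighted rather than uniform average.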

In [66]:
from torch.nn import functional as F

weights_sm = F.softmax(weights_sm, dim=-1)
weights_sm

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [67]:
x_bow_sm = weights_sm @ x

torch.allclose(x_bow_mm, x_bow_sm)

True

# Re-Defining Model Architecture.

In [68]:
n_embeddings = 32
vocab_size = len(chars)

In [69]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__ (self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embeddings)
        self.position_embedding_table = nn.Embedding(block_size, n_embeddings)
        self.lm_head = nn.Linear(n_embeddings, vocab_size)
        
    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_embeddings = self.token_embedding_table(idx)
        pos_embeddings = self.position_embedding_table(torch.arange(T, device=device))

        x = tok_embeddings + pos_embeddings # x now holds the token identities and the positions where they occur.

        logits = self.lm_head(x)

        if targets is None:
            loss = None

        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:, -1 , :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
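
The line `x = tok_embeddings + pos_embeddings` in `forward` is what lets the model distinguish the same character at different positions. A minimal sketch with hand-picked 2-dimensional tables (all values here are hypothetical, not learned embeddings):

```python
# Hypothetical embedding tables, chosen by hand for illustration.
demo_tok_table = {0: [1.0, 0.0], 1: [0.0, 1.0]}          # one row per token id
demo_pos_table = [[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]]    # one row per position

demo_token_ids = [1, 0, 1]   # token 1 appears at positions 0 and 2

demo_combined = [
    [t + p for t, p in zip(demo_tok_table[tok], demo_pos_table[pos])]
    for pos, tok in enumerate(demo_token_ids)
]
for pos, vec in enumerate(demo_combined):
    print(pos, vec)
```

Positions 0 and 2 hold the same token but end up with different vectors, because a different position row was added to each.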

## Version 4: Self Attention.

Below we implement one head of attention.

In [70]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

In [71]:
k = key(x)
q = query(x)
weights = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) --> (B, T, T)

Now, instead of `weights` starting out as a lower-triangular matrix of ones (which gives uniform averaging), the raw `weights` are the data-dependent query-key affinities computed above.

In [72]:
tril = torch.tril(torch.ones((T, T)))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
out = weights @ x

We'll tweak our self-attention head slightly by introducing `value`: instead of aggregating the raw `x`, each token contributes `value(x)`.

In [73]:
torch.manual_seed(1337)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16

# each new nn.Linear() draws its initial weights from the random number generator, so key, query and value start with different weights; the manual seed only makes those draws reproducible.
key = nn.Linear(C, head_size, bias=False)    
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)
q = query(x)
weights = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) --> (B, T, T)

tril = torch.tril(torch.ones((T, T)))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)
out = weights @ value(x)

In [74]:
out.shape  # (B, T, head_size): weights is (B, T, T) and value(x) is (B, T, 16)

torch.Size([4, 8, 16])

If we wanted to create an `encoder` block, we would remove the masking of `weights`, allowing all tokens to communicate with each other. What we have here is a `decoder` block, because the mask stops tokens from attending to later positions.
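
The difference between the two kinds of block comes down to a single flag. The following plain-Python sketch of one unscaled attention head makes that explicit; the inputs are hypothetical toy vectors and the function names are introduced here for illustration.

```python
import math

def softmax_scores(row):
    top = max(row)
    exps = [math.exp(v - top) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def toy_attention(q, k, v, causal):
    """Single attention head over lists of vectors; `causal` toggles the mask."""
    T = len(q)
    out = []
    for i in range(T):
        scores = []
        for j in range(T):
            if causal and j > i:
                scores.append(float('-inf'))     # decoder: hide later positions
            else:
                scores.append(sum(a * b for a, b in zip(q[i], k[j])))
        w = softmax_scores(scores)
        out.append([sum(w[j] * v[j][d] for j in range(T)) for d in range(len(v[0]))])
    return out

# Hypothetical toy inputs: 3 tokens, 2 channels.
toy_q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
toy_k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
toy_v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

decoder_out = toy_attention(toy_q, toy_k, toy_v, causal=True)
encoder_out = toy_attention(toy_q, toy_k, toy_v, causal=False)
print(decoder_out[0])   # first token only sees itself: [1.0, 2.0]
print(encoder_out[0])   # first token also mixes in later tokens
```

The last token sees every position in both modes, so its output is the same either way.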

# Applying Self Attention with Transformers.

In [75]:
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key   = nn.Linear(n_embeddings, head_size, bias=False)
        self.query = nn.Linear(n_embeddings, head_size, bias=False)
        self.value = nn.Linear(n_embeddings, head_size, bias=False)

        # 'tril' is not a parameter of the module (it is never trained), so in PyTorch terms it is a buffer.
        # register_buffer creates self.tril and makes it follow the module across devices.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

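        # scale the scores so the softmax doesn't saturate; C is the embedding width here.
        # note: the standard formulation scales by head_size ** -0.5 (k.shape[-1]), which only equals C ** -0.5 when head_size == C.
        weights = q @ k.transpose(-2, -1) * (C)**(-0.5)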
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)

        out = weights @ v

        return out
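
The notebook doesn't explain the scaling factor in `forward`. One common motivation, given here as background, is that large-magnitude scores push softmax toward a one-hot vector, so each token would attend to almost a single other token. The sketch below compares softmax on a set of hypothetical scores and on the same scores multiplied by 8.

```python
import math

def softmax_demo(row):
    top = max(row)
    exps = [math.exp(v - top) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

small_scores = [0.1, -0.2, 0.3, -0.2, 0.5]     # hypothetical attention scores
large_scores = [s * 8 for s in small_scores]   # same pattern, larger magnitude

soft = softmax_demo(small_scores)
sharp = softmax_demo(large_scores)
print([round(p, 4) for p in soft])
print([round(p, 4) for p in sharp])
print(max(soft) < max(sharp))   # True: larger scores concentrate the weight
```

Multiplying the raw scores by a factor below 1 keeps their magnitude moderate, which keeps the attention weights more diffuse at initialisation.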

### Re-Defining BiGram Model.

In [76]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__ (self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embeddings)
        self.position_embedding_table = nn.Embedding(block_size, n_embeddings)
        self.lm_head = nn.Linear(n_embeddings, vocab_size)
        self.sa_head = Head(n_embeddings)
        
    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_embeddings = self.token_embedding_table(idx)
        pos_embeddings = self.position_embedding_table(torch.arange(T, device=device))

        x = tok_embeddings + pos_embeddings # x now holds the token identities and the positions where they occur.
        x = self.sa_head(x) # NEW LINE!!!

        logits = self.lm_head(x)

        if targets is None:
            loss = None

        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]  # NEW LINE!!! crop the context to the last block_size tokens (the position table only has block_size rows).
            logits, loss = self(idx_cond)
            logits = logits[:, -1 , :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
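
The cropping in `generate` matters because `position_embedding_table` only has `block_size` rows, so the model can never be given more than `block_size` positions at once. A list-based sketch of the same slice (the `toy_` names and the 12 generated ids are hypothetical):

```python
toy_block_size = 8
toy_generated = list(range(12))   # hypothetical: 12 token ids generated so far

toy_cond = toy_generated[-toy_block_size:]   # same slice as idx[:, -block_size:]
toy_positions = list(range(len(toy_cond)))   # what torch.arange(T) would hold

print(toy_cond)        # [4, 5, 6, 7, 8, 9, 10, 11]
print(toy_positions)   # [0, 1, 2, 3, 4, 5, 6, 7]

# Early in generation the slice simply returns everything available.
print(toy_generated[:3][-toy_block_size:])   # [0, 1, 2]
```

Every position index stays below `block_size`, so the embedding lookup is always in range.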

### Training New Model W/ Attention.

In [77]:
# generalise some of our previously established code
model = BigramLanguageModel()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
m = model.to(device)
eval_iters = 200

max_iters = 10000
eval_interval = 1000

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    
    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.2468, val loss 4.2452
step 1000: train loss 2.5396, val loss 2.5553
step 2000: train loss 2.4580, val loss 2.4627
step 3000: train loss 2.4234, val loss 2.4371
step 4000: train loss 2.4024, val loss 2.4225
step 5000: train loss 2.3987, val loss 2.4041
step 6000: train loss 2.3813, val loss 2.3972
step 7000: train loss 2.3718, val loss 2.4105
step 8000: train loss 2.3611, val loss 2.3811
step 9000: train loss 2.3698, val loss 2.3763


In [78]:
context = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=200)[0].tolist()))


xquorouseee mons'd:
Anet bet ckeeitle, thar lan forowas yed theel:
THAer, fo yumy fols yow fesenem sn bel Thy mous adetrt mond gepr
Pany.

SCO asth ofr meeriwinf thay fothans sn wes: mowoutrly;
Ho d'd


While we have gotten our training loss down to $\approx 2.3698$ (validation $\approx 2.3763$), we can improve this further by implementing multi-headed attention. Multi-headed attention simply runs several smaller attention heads in parallel on the same input and concatenates their outputs along the channel dimension to give our final vector.

In [79]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_head)])
        
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return out
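
A quick bit of arithmetic shows what splitting into heads does and does not change. Assuming `n_embeddings = 32` as set earlier and the configuration of 4 heads of size 8, the concatenated output keeps the original width, and the key/query/value layers hold the same number of weights as one head of size 32. The helper below is hypothetical and only counts weights.

```python
demo_emb = 32   # n_embeddings in the notebook

def kqv_weight_count(n_head, head_size):
    # Each head owns three bias-free Linear(demo_emb, head_size) layers.
    return n_head * 3 * demo_emb * head_size

single_head_weights = kqv_weight_count(n_head=1, head_size=32)
multi_head_weights = kqv_weight_count(n_head=4, head_size=8)
concat_width = 4 * 8   # channels after concatenating 4 heads of size 8

print(single_head_weights, multi_head_weights)   # 3072 3072
print(concat_width)                              # 32
```

So any gain from multiple heads comes from having several independent attention patterns, not from extra parameters in these layers.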

### Re-defining BiGram Model W/ Multi-Headed Attention.

In [80]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__ (self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embeddings)
        self.position_embedding_table = nn.Embedding(block_size, n_embeddings)
        self.lm_head = nn.Linear(n_embeddings, vocab_size)

        # instead of having 1 head of 32 (single-headed attention)
        # we have 4 heads of 8 (multi-headed attention)
        self.sa_heads = MultiHeadAttention(4, n_embeddings // 4) 
        
    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_embeddings = self.token_embedding_table(idx)
        pos_embeddings = self.position_embedding_table(torch.arange(T, device=device))

        x = tok_embeddings + pos_embeddings
        x = self.sa_heads(x) # NEW LINE!!!

        logits = self.lm_head(x)

        if targets is None:
            loss = None

        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]  # crop the context to the last block_size tokens
            logits, loss = self(idx_cond)
            logits = logits[:, -1 , :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

### Training the Model w/ Multi-Headed Attention.

In [81]:
model = BigramLanguageModel()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
m = model.to(device)
eval_iters = 200

max_iters = 10000
eval_interval = 1000

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    
    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.2418, val loss 4.2400
step 1000: train loss 2.4670, val loss 2.4791
step 2000: train loss 2.3648, val loss 2.3585
step 3000: train loss 2.2957, val loss 2.3113
step 4000: train loss 2.2579, val loss 2.2828
step 5000: train loss 2.2464, val loss 2.2614
step 6000: train loss 2.2192, val loss 2.2546
step 7000: train loss 2.1982, val loss 2.2572
step 8000: train loss 2.1851, val loss 2.2284
step 9000: train loss 2.1951, val loss 2.2346


In [82]:
context = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=200)[0].tolist()))



Rhorouse erathe'd:
And, bet coreiell, thar laive.

Nos yeints
SA:
To en, fory,
Ands mave thesencmown be, Thild viry curt mond geladight.

SCOLUMME:
Br me, in nfning hawth,
Youns
This
Thou ous;
Ho dem


Note that we have now reduced our validation loss to $\approx 2.2346$ (training loss $\approx 2.1951$), a clear improvement over the previous single-headed attention run (validation $\approx 2.3763$).

# Adding a N.N.

We were able to get the tokens talking to each other. However, we then computed the `logits` straight away, without giving the tokens a chance to 'think' about what they found. In this next step we implement `FeedForward`, a small neural network module (a linear layer followed by a ReLU) that is applied to every token independently before the `logits` are computed.

In [84]:
class FeedForward(nn.Module):
    def __init__(self, n_embeddings):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embeddings, n_embeddings),
            nn.ReLU(),
        )
    def forward(self, x):
        return self.net(x)
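
The key property of this layer is that it acts on each token separately: attention is where tokens exchange information, and the feed-forward step then processes each token's vector on its own. A plain-Python sketch with a hypothetical 2-to-2 layer and hand-picked weights:

```python
def relu_vec(vec):
    return [max(0.0, v) for v in vec]

def linear_vec(vec, weight, bias):
    # weight has one row per output channel, like nn.Linear's (out, in) layout.
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]

# Hypothetical weights, chosen by hand for illustration.
ff_weight = [[1.0, -1.0], [0.5, 0.5]]
ff_bias = [0.0, -1.0]

def toy_feed_forward(tokens):
    # The same layer is applied to every token vector independently.
    return [relu_vec(linear_vec(tok, ff_weight, ff_bias)) for tok in tokens]

seq_a = [[1.0, 2.0], [3.0, 1.0]]
seq_b = [[3.0, 1.0], [1.0, 2.0]]    # same tokens in swapped order

out_a = toy_feed_forward(seq_a)
out_b = toy_feed_forward(seq_b)
print(out_a)   # [[0.0, 0.5], [2.0, 1.0]]
print(out_b)   # [[2.0, 1.0], [0.0, 0.5]]
```

Swapping the tokens just swaps the outputs, which confirms that no information flows between positions inside this layer.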
        

In [85]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__ (self):
        super().__init__()
        self.token_embedding_table    = nn.Embedding(vocab_size, n_embeddings)
        self.position_embedding_table = nn.Embedding(block_size, n_embeddings)

        self.lm_head = nn.Linear(n_embeddings, vocab_size)
        self.ffwd    = FeedForward(n_embeddings) # NEW LINE!!

        # single-headed attention -- 1 head of 32
        # multi-headed attention  -- 4 heads of 8
        self.sa_heads = MultiHeadAttention(4, n_embeddings // 4) 
        
    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_embeddings = self.token_embedding_table(idx)
        pos_embeddings = self.position_embedding_table(torch.arange(T, device=device))

        x = tok_embeddings + pos_embeddings
        x = self.sa_heads(x)
        x = self.ffwd(x)       # NEW LINE!!!
        
        logits = self.lm_head(x)

        if targets is None:
            loss = None

        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1 , :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [86]:
model = BigramLanguageModel()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
m = model.to(device)
eval_iters = 200

max_iters = 10000
eval_interval = 1000

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for iter in range(max_iters):
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    
    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.1975, val loss 4.1968
step 1000: train loss 2.4594, val loss 2.4683
step 2000: train loss 2.3524, val loss 2.3566
step 3000: train loss 2.2875, val loss 2.3015
step 4000: train loss 2.2516, val loss 2.2712
step 5000: train loss 2.2339, val loss 2.2399
step 6000: train loss 2.2005, val loss 2.2327
step 7000: train loss 2.1802, val loss 2.2404
step 8000: train loss 2.1635, val loss 2.1934
step 9000: train loss 2.1728, val loss 2.1986


In [88]:
context = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=200)[0].tolist()))


CAUMESTESS:
Mery and''s sut, your but lear tallomese sire And cus or your.

KINTEDY Jlivight manke up hat her arpofell ow eves ings:
Whe tint
Hhos be hust und blome is sicht inots vingbobe
Kollys seth
