# NanoGpt

Writing and training a transformer on the [tiny Shakespeare dataset](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) as a character-level language model.

Papers used:
1. [Attention is all you need](https://arxiv.org/pdf/1706.03762.pdf)

## Read and Explore data

In [1]:
# Get the data 
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-05-09 09:20:48--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-05-09 09:20:49 (5.01 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
with open('input.txt', 'r') as f:
    text = f.read()

In [5]:
# Check out first 1000 chars
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [6]:
# Vocabulary: every unique character in the text, i.e. all the values the model can see or emit
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(f"Vocab size: {vocab_size}")
print(f"Possible characters: {''.join(chars)}")

Vocab size: 65
Possible characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


We have a total of 65 possible characters: the newline, space, punctuation marks, the digit 3, and upper- and lowercase letters. These are the only characters the model can see or emit.

## Tracking setup

In [7]:
import mlflow

def log_experiment(run_name, params: dict, metric: dict):
    """Log one run's params and metrics to the "nanogpt" MLflow experiment."""
    mlflow.set_experiment("nanogpt")
    mlflow.end_run()  # close any run left open by a previous call
    mlflow.start_run(run_name=run_name)
    for k, v in params.items():
        mlflow.log_param(k, v)
    for k, v in metric.items():
        mlflow.log_metric(k, v)

## Tokenization, train/val split

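* Tokenization: converting the string of characters into a sequence of integers, based on the vocabulary (all possible characters).
* Train split: the data the model is trained on.
* Val split: held-out data the model never trains on, used to check for overfitting and to tune hyperparameters.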

In [8]:
# Create mapping from string to integers and vice versa
itos = {i: s for i, s in enumerate(chars)}
stoi = {v:k for k, v in itos.items()}
encode = lambda s: [stoi[c] for c in s] # encoder: takes a string, outputs a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: takes a list of integers, outputs a string

In [9]:
print(encode("hi there"))
print(decode(encode('hi there')))

[46, 47, 1, 58, 46, 43, 56, 43]
hi there


What we're using here is a very simple character-level tokenizer (encoder/decoder). Many other tokenizers exist; a few examples:

1. [SentencePiece](https://github.com/google/sentencepiece#) - Google's sub-word tokenizer.
2. [tiktoken](https://github.com/openai/tiktoken) - OpenAI's BPE tokenizer.

There is a trade-off: we can use long sequences of tokens over a small vocabulary, or short sequences of tokens over a large vocabulary.

In practice, sub-word tokenizers are typically used.

This notebook uses the first option: long sequences of tokens over a small vocabulary of characters.
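To make the sequence-length side of this trade-off concrete, here is a small sketch comparing our character-level scheme with a naive whitespace word-level tokenizer on one line of the dataset. The word-level tokenizer is a hypothetical stand-in for illustration only, not how SentencePiece or tiktoken work.

```python
# Sequence length under two tokenization granularities (illustrative only).
sample = "Before we proceed any further, hear me speak."

# Character level (what this notebook uses): one token per character.
char_tokens = list(sample)

# Word level (hypothetical stand-in): split on whitespace, so
# punctuation stays attached to the word before it.
word_tokens = sample.split()

print(len(char_tokens), len(word_tokens))  # 45 8
print(word_tokens[:4])
```

Over the whole dataset the picture flips for vocabulary size: the character vocabulary stays at 65 symbols, while a word-level vocabulary needs one entry per distinct word (here `speak.` and `Speak,` would even count as different words).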

In [10]:
# Tokenize entire text
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [11]:
# Split train/val
n = int(len(data) * 0.9) # first 90% for training, the remaining 10% for validation
train_data = data[:n]
val_data = data[n:]
len(train_data), len(val_data)

(1003854, 111540)

## Data loader: batches of chunks of data

`block_size`: the maximum length of a chunk, i.e. the maximum context length the model processes.


In [12]:
block_size = 8
train_data[:block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

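A chunk of `block_size + 1` characters packs multiple training examples, because every character is followed by the next one. A chunk of 9 characters contains 8 examples: at each position t, the characters up to t are the context and the character at t+1 is the target. The code below illustrates this.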

In [13]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"When input is {context} the target: {target}")

When input is tensor([18]) the target: 47
When input is tensor([18, 47]) the target: 56
When input is tensor([18, 47, 56]) the target: 57
When input is tensor([18, 47, 56, 57]) the target: 58
When input is tensor([18, 47, 56, 57, 58]) the target: 1
When input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
When input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
When input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


We train on every context length from 1 up to `block_size` (8), so the transformer gets used to seeing contexts of all sizes. This is useful at inference time, where sampling can start from as little as one character of context and use any context length up to `block_size`.


There is one more dimension: the batch.
During training we feed mini-batches of multiple chunks at once, so they can be processed in parallel.
The chunks in a batch are processed completely independently of each other.

In [14]:
# Batching
torch.manual_seed(1337)
batch_size = 4 # how many independent sequences to process in parallel
block_size = 8 # maximum context length for predictions

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size, )) # random start offsets
    x = torch.stack([data[i: i+block_size] for i in ix]) # (batch_size, block_size)
    y = torch.stack([data[i+1: i+block_size+1] for i in ix]) # same chunks offset by one: y[:, t] is the character after x[:, t]
    return x, y

In [15]:
xb, yb = get_batch('train')
print(f'inputs shape: {xb.shape}')
print(f'inputs: {xb}')
print(f'targets shape: {yb.shape}')
print(f'targets: {yb}')


inputs shape: torch.Size([4, 8])
inputs: tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets shape: torch.Size([4, 8])
targets: tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


In [16]:
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f"When input is {context.tolist()} the target: {target.tolist()}")

When input is [24] the target: 43
When input is [24, 43] the target: 58
When input is [24, 43, 58] the target: 5
When input is [24, 43, 58, 5] the target: 57
When input is [24, 43, 58, 5, 57] the target: 1
When input is [24, 43, 58, 5, 57, 1] the target: 46
When input is [24, 43, 58, 5, 57, 1, 46] the target: 43
When input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39
When input is [44] the target: 53
When input is [44, 53] the target: 56
When input is [44, 53, 56] the target: 1
When input is [44, 53, 56, 1] the target: 58
When input is [44, 53, 56, 1, 58] the target: 46
When input is [44, 53, 56, 1, 58, 46] the target: 39
When input is [44, 53, 56, 1, 58, 46, 39] the target: 58
When input is [44, 53, 56, 1, 58, 46, 39, 58] the target: 1
When input is [52] the target: 58
When input is [52, 58] the target: 1
When input is [52, 58, 1] the target: 58
When input is [52, 58, 1, 58] the target: 46
When input is [52, 58, 1, 58, 46] the target: 39
When input is [52, 58, 1, 58, 46, 39] the t

The input batch therefore holds 32 independent examples (4 chunks of 8 positions each), and the targets are laid out the same way.

## Baseline: Bigram model

This bigram model is similar to the bigram model in the makemore series. There we built a 27 × 27 lookup table from bigram counts (statistics).

Here we use PyTorch's `nn.Embedding` as a learnable `vocab_size × vocab_size` lookup table instead.
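`nn.Embedding(vocab_size, vocab_size)` is just a learnable `vocab_size × vocab_size` weight matrix; calling it on a tensor of token ids selects rows of that matrix. A minimal check of this, using a made-up batch of token ids:

```python
import torch
import torch.nn as nn

torch.manual_seed(1337)
vocab_size = 65  # matches the vocabulary size computed above

# A (vocab_size, vocab_size) table: row i holds the next-token logits for token i.
table = nn.Embedding(vocab_size, vocab_size)

idx = torch.tensor([[18, 47, 56], [24, 43, 58]])  # (B, T) = (2, 3) token ids
logits = table(idx)  # (B, T, C) = (2, 3, 65)

# The forward pass is plain row indexing into the weight matrix.
same = torch.equal(logits, table.weight[idx])
print(logits.shape, same)  # torch.Size([2, 3, 65]) True
```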

In [17]:
import torch 
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    
    def __init__(self, vocab_size):
        super().__init__()
        # each token reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        
        # idx and targets are both (B, T) tensor of integers
        logits = self.token_embedding_table(idx) # (B, T, C)

        return logits

In [18]:
m = BigramLanguageModel(vocab_size)
out = m(xb, yb)
print(out.shape)

torch.Size([4, 8, 65])


In [19]:
out[0].shape, out[0, 0].shape

(torch.Size([8, 65]), torch.Size([65]))

Now we have logits (unnormalized scores for the next character) for every individual example in the batch.
Next, let's calculate the loss.

In [20]:
import torch 
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    
    def __init__(self, vocab_size):
        super().__init__()
        # each token reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        
        # idx and targets are both (B, T) tensor of integers
        logits = self.token_embedding_table(idx) # (B, T, C)

        loss = F.cross_entropy(logits, targets)

        return logits, loss
    
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)

RuntimeError: Expected target size [4, 65], got [4, 8]

This fails because PyTorch's `F.cross_entropy` expects the class dimension to come second: for multi-dimensional input it wants logits of shape (B, C, T), not (B, T, C).
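There are two equivalent fixes: move the class dimension to position 1, or flatten batch and time into a single dimension. A quick sanity check on random logits (shapes chosen to match `xb`, `yb`) that both give the same mean loss:

```python
import torch
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 65
logits = torch.randn(B, T, C)          # (B, T, C), as produced by the model
targets = torch.randint(0, C, (B, T))  # (B, T) integer class ids

# Option 1: move the class dimension to position 1 -> (B, C, T).
loss_transposed = F.cross_entropy(logits.transpose(1, 2), targets)

# Option 2: flatten batch and time -> logits (B*T, C), targets (B*T,).
loss_flat = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))

print(loss_transposed.item(), loss_flat.item())
print(torch.allclose(loss_transposed, loss_flat))  # True
```

The notebook takes the flattening route with `view`.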

In [21]:
# Fix: flatten the (B, T) = (4, 8) dimensions into a single dimension of 32 with view

import torch 
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    
    def __init__(self, vocab_size):
        super().__init__()
        # each token reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets):
        
        # idx and targets are both (B, T) tensor of integers
        logits = self.token_embedding_table(idx) # (B, T, C)

        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = targets.view(B*T)
        loss = F.cross_entropy(logits, targets)

        return logits, loss
    
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)

torch.Size([32, 65])


In [22]:
print(loss)

tensor(4.8786, grad_fn=<NllLossBackward0>)


If the model's initial predictions were uniform over the 65 possible characters, the loss would be -ln(1/65):

In [23]:
-torch.log(torch.tensor(1/65))

tensor(4.1744)
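This -ln(1/65) baseline is exactly the cross-entropy of a model whose logits are all equal, whatever the targets are. A small check with random targets:

```python
import math
import torch
from torch.nn import functional as F

vocab_size = 65
B, T = 4, 8

# All-zero logits: softmax gives every character probability 1/65.
logits = torch.zeros(B * T, vocab_size)
targets = torch.randint(0, vocab_size, (B * T,))

loss = F.cross_entropy(logits, targets)
print(loss.item(), math.log(vocab_size))  # both ~4.1744
```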

The loss we got is 4.8786, higher than this uniform baseline, which means the randomly initialized predictions are not uniform and are somewhat confidently wrong.

Let's generate some text to check the quality of the model.

In [24]:
# Same model as before, with optional targets and a generate() method for sampling

import torch 
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    
    def __init__(self, vocab_size):
        super().__init__()
        # each token reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        
        # idx and targets are both (B, T) tensor of integers
        logits = self.token_embedding_table(idx) # (B, T, C)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in current context
        for _ in range(max_new_tokens):
            # get predictions
            logits, loss = self(idx) # calls forward
            # focus only on last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probs
            probs = F.softmax(logits, dim=-1) # -1 is C
            # sample 1 token from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled idx to running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
    
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)

torch.Size([32, 65])


In [25]:
# token 0 is the newline character '\n'; use it as the starting context for generate
idx=torch.zeros(1, 1, dtype=torch.long)
idx


tensor([[0]])

In [26]:
# output shape is (1, 1 + max_new_tokens)
m.generate(idx=torch.zeros(1, 1, dtype=torch.long), max_new_tokens=100).shape

torch.Size([1, 101])

In [27]:
print(decode(m.generate(idx=torch.zeros(1, 1, dtype=torch.long), max_new_tokens=100)[0].tolist()))


pJ:Bpm&yiltNCjeO3:Cx&vvMYW-txjuAd IRFbTpJ$zkZelxZtTlHNzdXXUiQQY:qFINTOBNLI,&oTigq z.c:Cq,SDXzetn3XVj


The output is garbage because the model is untrained: the lookup table is randomly initialized, so the sampled characters are essentially random.

In [28]:
params = {
        "context size": block_size,
        "batch size": batch_size,
        "vocab size": vocab_size,
        "description": "Model not trained and a simple lookup for a (4, 8) sample"
    }    
metric={
        "train loss": loss.item()
    }
log_experiment(run_name="bigram-baseline-0", params=params, metric=metric)

## Training the bigram model

In [32]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)
batch_size = 32
for steps in range(50000):
    xb, yb = get_batch(split="train")

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss.item())

2.3720762729644775


In [33]:
print(decode(m.generate(idx=torch.zeros(1, 1, dtype=torch.long), max_new_tokens=500)[0].tolist()))


OUng t bu s yove tend n, I:
LO:
cthy cotscu
ICHEdee r chidots
Whe methare f mave.
TEYofe o ou, t h Exf it se ventifitou, osingave brd Gilinourthathele MAS: ad g torulor wrail Je t ts osh se thay hofoutreputeeed siveica he qureades't insecoof wheagheateril th anonswa tthe'd d n sm; she CEr yourern ENES: s y wird
Whit; alothers ws t y cthn;
Whasery se h ofe fise pest, oun Mou thar tlul be fearsthee ous,
My well bencis,
IOUK:
Tr igr l har,
INEMomeame irtou o fisiswiouco aperir fulatithowe s ceseroy


In [34]:
params = {
        "context size": block_size,
        "batch size": batch_size,
        "vocab size": vocab_size,
        "learning rate": 1e-3,
        "Optimizer": "AdamW",
        "description": "Model trained for 50000 steps with AdamW optimizer"
    }    
metric={
        "train loss": loss.item()
    }
log_experiment(run_name="bigram-baseline-1", params=params, metric=metric)

## Building the *self-attention*

Right now the tokens do not interact with each other: the bigram model predicts the next character from the current character alone. Attention lets the tokens interact.

### Version 1: averaging past context with for loops, the weakest form of aggregation

The weakest form of interaction is to average all tokens up to and including the current timestep. Tokens must not look into the future (the next timesteps), because that is what we are trying to predict.

In [35]:
# Toy example
torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [36]:
# Averaging contexts
# We want xbow[b, t] = mean_{i<=t} x[b, i]
# bow --> bag of words, since we're just averaging
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        print(f"Batch: {b}")
        print(f"Timestep: {t}")
        xprev = x[b, :t+1] # (t+1, C)
        print(f"X: {x[b, t]}")
        print(f"Xprevious: {xprev}")
        xbow[b, t] = torch.mean(xprev, 0)
        print(f"Xaverage: {xbow[b, t]}\n")


Batch: 0
Timestep: 0
X: tensor([ 0.1808, -0.0700])
Xprevious: tensor([[ 0.1808, -0.0700]])
Xaverage: tensor([ 0.1808, -0.0700])

Batch: 0
Timestep: 1
X: tensor([-0.3596, -0.9152])
Xprevious: tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152]])
Xaverage: tensor([-0.0894, -0.4926])

Batch: 0
Timestep: 2
X: tensor([0.6258, 0.0255])
Xprevious: tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255]])
Xaverage: tensor([ 0.1490, -0.3199])

Batch: 0
Timestep: 3
X: tensor([0.9545, 0.0643])
Xprevious: tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643]])
Xaverage: tensor([ 0.3504, -0.2238])

Batch: 0
Timestep: 4
X: tensor([0.3612, 1.1679])
Xprevious: tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679]])
Xaverage: tensor([0.3525, 0.0545])

Batch: 0
Timestep: 5
X: tensor([-1.3499, -0.5102])
Xprevious: tensor([[ 0.1808, -0.

In [37]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [38]:
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

Averaging is a weak form of context, because it throws away the order of the tokens (spatial information).
The nested for loop is also inefficient; we can compute the same aggregation with a matrix multiplication.
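Even before switching to matrix multiplication, the running mean itself can be vectorized. A minimal sketch using `torch.cumsum` (a different technique from the matrix-multiply trick that follows) that reproduces the loop's result, assuming the same toy `x` of shape (4, 8, 2):

```python
import torch

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn((B, T, C))

# Reference: the double for loop from version 1.
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        xbow[b, t] = torch.mean(x[b, :t + 1], 0)

# Running sum along time, divided by how many elements have been seen (1..T).
counts = torch.arange(1, T + 1, dtype=x.dtype).view(1, T, 1)  # broadcasts over B and C
xbow_cumsum = torch.cumsum(x, dim=1) / counts

print(torch.allclose(xbow, xbow_cumsum))  # True
```

This only works for uniform averaging. The matrix-multiply form generalizes to arbitrary, data-dependent weights, which is what self-attention needs.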

### The trick in self-attention: Matrix multiply as weighted aggregation

In [39]:
torch.manual_seed(42)
a = torch.ones(3, 3)
b = torch.randint(low=0, high=10, size=(3, 2), dtype=torch.float)
c = a @ b
print("a=")
print(a)
print("--")
print("b=")
print(b)
print("--")
print("c=")
print(c)

a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


This is matrix multiplication: element (i, j) of c is the dot product of row i of a with column j of b. For example, c[0, 0] = 1·2 + 1·6 + 1·6 = 14. Because every row of a is all ones, every row of c holds the column sums of b.
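The same dot-product view of c, spelled out in code with the cell's own `a` and `b`:

```python
import torch

torch.manual_seed(42)
a = torch.ones(3, 3)
b = torch.randint(low=0, high=10, size=(3, 2), dtype=torch.float)
c = a @ b

# Element (i, j) of c is the dot product of row i of a and column j of b.
c00 = (a[0, :] * b[:, 0]).sum()
print(c00.item(), c[0, 0].item())

# With a all ones, every row of c equals the column sums of b.
print(torch.equal(c, b.sum(0, keepdim=True).expand(3, 2)))  # True
```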

In [55]:
# Let's bring in tril
torch.tril(torch.ones(3, 3))

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

The lower triangle (including the diagonal) is ones and everything above it is zero. Let's use this in the matrix multiplication and see what we get.

In [56]:
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
b = torch.randint(low=0, high=10, size=(3, 2), dtype=torch.float)
c = a @ b
print("a=")
print(a)
print("--")
print("b=")
print(b)
print("--")
print("c=")
print(c)

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


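Each row of c is now a running sum: row t adds up the rows of b from 0 to t. Normalizing each row of a to sum to 1 turns this into the running average we computed with the for loop in version 1.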

In [57]:
# Let's normalize this.
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True) # normalize each row to sum to 1
b = torch.randint(low=0, high=10, size=(3, 2), dtype=torch.float)
c = a @ b
print("a=")
print(a)
print("--")
print("b=")
print(b)
print("--")
print("c=")
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


### Version 2: matrix multiply

Let's use the above to rewrite the for loop aggregation as a weighted aggregation via matrix multiplication.

In [68]:
# wei: the aggregation weights
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
# wei (T, T) @ x (B, T, C)
# PyTorch broadcasts wei across the batch dimension:
# (B, T, T) @ (B, T, C) --> (B, T, C)
xbow2 = wei @ x

In [69]:
torch.allclose(xbow, xbow2)

True

In [70]:
xbow[0], xbow2[0]

(tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))

### Version 3: Using softmax

In [72]:
tril = torch.tril(torch.ones(T, T))
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [73]:
# Start from zero affinities; in self-attention these become data-dependent token affinities
wei = torch.zeros((T, T))
wei

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [78]:
# Set future positions to -inf so they get zero weight after softmax (no communication with the future)
wei = wei.masked_fill(tril == 0, float('-inf'))
wei

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [80]:
# Softmax exponentiates each element and normalizes by dividing by the sum of the exponentials
from torch.nn import functional as F
wei = F.softmax(wei, dim=-1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [81]:
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

True

In [83]:
# Putting all together
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3)

True

This is a weighted aggregation of past elements, done with a matrix multiplication by a lower-triangular weight matrix. The values in the lower-triangular part determine how much each past position contributes to the current one.
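One way to convince yourself that the lower-triangular weights never look into the future: change the input at the last timestep and check that every earlier output row is unchanged. A small sketch reusing the toy shapes from above:

```python
import torch
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn((B, T, C))

# Masked, softmax-normalized weights (version 3).
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

out = wei @ x

# Perturb only the last timestep of every sequence.
x_perturbed = x.clone()
x_perturbed[:, -1, :] += 100.0
out_perturbed = wei @ x_perturbed

# Outputs at timesteps 0..T-2 do not depend on the last input.
print(torch.allclose(out[:, :-1], out_perturbed[:, :-1]))  # True
# Only the final timestep, which attends to itself, changes.
print(torch.allclose(out[:, -1], out_perturbed[:, -1]))    # False
```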

### Version 4: Self attention - Crux of this notebook

In [13]:
# Current setup: uniform averaging of past context
import torch
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 32 # batch, time, channels
x = torch.randn((4, 8, 32))

# weights init
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
# Masking
wei = wei.masked_fill(tril == 0, float('-inf'))
# Normalize weights
wei = F.softmax(wei, dim=-1)
out = wei @ x

out.shape

torch.Size([4, 8, 32])

In [15]:
wei.shape

torch.Size([8, 8])

In [16]:
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In the current setup the affinity matrix `wei` is a uniform average, and it is identical for every sequence in the batch.
What we want next: instead of averaging, each token should gather information from the past in a data-dependent way, picking out the specific tokens it finds relevant.
This is what self-attention does.

In self-attention, each token emits a ***query*** and a ***key***:

* ***query***: what the token is looking for
* ***key***: what the token contains

The affinity between the current token and an earlier token is the dot product of the current token's ***query*** with that token's ***key***. If the dot product is high, the earlier token is relevant to the current token and gets a large weight; if it is low, it gets little weight.

Let's implement this now.

In [17]:
# Implementing self-attention with a single head
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 32 # batch, time, channels
x = torch.randn((4, 8, 32))

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)

# dot product query with keys
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) --> (B, T, T)

# weights init
tril = torch.tril(torch.ones(T, T))
# wei = torch.zeros((T, T))
# Masking
wei = wei.masked_fill(tril == 0, float('-inf'))
# Normalize weights
wei = F.softmax(wei, dim=-1)
out = wei @ x

out.shape

torch.Size([4, 8, 32])

In [18]:
wei.shape

torch.Size([4, 8, 8])

In [19]:
wei

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
         [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
         [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
         [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1687, 0.8313, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2477, 0.0514, 0.7008, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4410, 0.0957, 0.3747, 0.0887, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0069, 0.0456, 0.0300, 0.7748, 0.1427, 0.0000, 0.0000, 0.0000],
         [0.0660, 0.089

Previously a single affinity matrix was shared by every sequence in the batch. Now each sequence in the batch has its own affinity matrix, because the weights depend on the input tokens.

In [21]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

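Finally, we add a ***value***. Instead of aggregating the raw `x`, each token emits a value vector (what it communicates to tokens that attend to it), and the attention weights aggregate those values.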

In [22]:
# Adding value
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 32 # batch, time, channels
x = torch.randn((4, 8, 32))

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)

# dot product query with keys
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) --> (B, T, T)

# weights init
tril = torch.tril(torch.ones(T, T))
# wei = torch.zeros((T, T))
# Masking
wei = wei.masked_fill(tril == 0, float('-inf'))
# Normalize weights
wei = F.softmax(wei, dim=-1)

# aggregate the values (not the raw x) with the attention weights
v = value(x)
out = wei @ v

out.shape

torch.Size([4, 8, 16])

### Notes:

* Attention is a ***communication mechanism***. It can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
* There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
* Each example across the batch dimension is processed completely independently; examples never *talk* to each other.
* In an *encoder* attention block, delete the masking line with `tril` so all tokens can communicate with each other (useful for problems like sentiment analysis). The block we have is a *decoder* attention block, because it has triangular masking; it is usually used in autoregressive settings such as language modeling.
* *Self-attention* means that keys and values are produced from the same source as the queries. In *cross-attention*, the queries still get produced from x, but the keys and values come from some other, external source.
* *Scaled* attention additionally divides `wei` by `sqrt(head_size)`. This way, when the inputs Q and K have unit variance, `wei` has unit variance too, and softmax stays diffuse instead of saturating. If `wei` contains very large values, softmax converges towards a one-hot vector (1 at the largest value and 0 everywhere else), so each token would aggregate information from essentially a single other token.
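To see why the saturation matters, compare softmax on a small vector of scores with the same scores multiplied by 8. The particular numbers are arbitrary, chosen for illustration:

```python
import torch
from torch.nn import functional as F

# Arbitrary attention scores for one query over 5 positions.
scores = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])

# Moderate scores: the weights stay spread out over all positions.
p_soft = F.softmax(scores, dim=-1)
# Same scores multiplied by 8, like large unscaled dot products:
# the weights pile up on the largest score.
p_sharp = F.softmax(scores * 8, dim=-1)

print(p_soft)
print(p_sharp)
```

The largest weight goes from roughly 0.29 to roughly 0.8 just by multiplying the scores by 8; with even larger scores the result approaches a one-hot vector.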

In [25]:
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
wei = q @ k.transpose(-2, -1)

In [26]:
k.var(), q.var(), wei.var()

(tensor(1.0700), tensor(0.9006), tensor(18.0429))

In [29]:
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
wei = q @ k.transpose(-2, -1) * head_size**-0.5

In [30]:
k.var(), q.var(), wei.var()

(tensor(0.9416), tensor(1.0104), tensor(1.0879))
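Putting the pieces of this section together, a single causal self-attention head can be packaged as an `nn.Module`. This is a minimal sketch under our own naming: the class `Head` and its constructor arguments `n_embd`, `head_size`, `block_size` are assumptions, not something defined above. It combines the key/query/value projections, the `tril` mask, the `head_size**-0.5` scaling and the weighted aggregation of values.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F


class Head(nn.Module):
    """One head of causal (decoder-style) scaled self-attention (sketch)."""

    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # tril is not a learnable parameter, so store it as a buffer.
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # Scaled affinities: (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # Causal mask: no token attends to positions after itself.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        v = self.value(x)  # (B, T, head_size)
        return wei @ v     # (B, T, head_size)


torch.manual_seed(1337)
head = Head(n_embd=32, head_size=16, block_size=8)
x = torch.randn(4, 8, 32)
out = head(x)
print(out.shape)  # torch.Size([4, 8, 16])
```

Slicing `self.tril[:T, :T]` lets the same head handle sequences shorter than `block_size`, which is what happens early in generation.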