In [64]:
!nvidia-smi

Mon May 22 17:04:08 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.41.03              Driver Version: 530.41.03    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce GTX 1060         Off| 00000000:01:00.0  On |                  N/A |
| N/A   53C    P8                6W /  N/A|     65MiB /  6144MiB |     28%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [3]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math



# Plan of action

## Steps
* Download the data
* Tokenizer
* Batch creator
* Create a basic forward pass
* self attention layer
* Create a training process


In [4]:
import urllib.request

# download tiny shakespeare
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'

In [5]:
# download the file directly to a variable
text = urllib.request.urlopen(url).read().decode('utf-8')

## Create a tokenizer at the character level

In [6]:
tokens = list(set(text))
vocab_size = len(tokens)
print(vocab_size)

# Create an encoder decoder for our tokens to turn them into numbers and back
encoder_decoder = {token: i for i, token in enumerate(tokens)}
decoder_encoder = {i: token for i, token in enumerate(tokens)}

encode = lambda x: [encoder_decoder[i] for i in x]
decode = lambda x: "".join([decoder_encoder[i] for i in x])

print(encode("hii there"))
print(decode(encode("hii there")))

65
[51, 64, 64, 55, 36, 51, 3, 6, 3]
hii there


## Creating our dataset
We split the data into training and validation with 90/10 split

In [7]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)


# Split the data into training and validation sets
split_val = int(len(data) * 0.9)
train_data = data[:split_val]
val_data = data[split_val:]


In [8]:
len(train_data), len(val_data)

(1003854, 111540)

### Turning our data into batches

In [9]:
batch_size = 4
block_size = 8



def get_batch(split):
    if split == 'train':
        data = train_data
    else:
        data = val_data
    batch_start_indexes = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in batch_start_indexes])
    y = torch.stack([data[i+1:i+block_size+1] for i in batch_start_indexes])
    return x,y

xb, yb = get_batch('train')

print(f"""
inputs:
{xb.shape}
{xb}

targets:
{yb.shape}
{yb}""")

for b in range(batch_size):
    for t in range(block_size):
        x = xb[b][:t+1]
        y = yb[b][t]
        print(f"Input is {x} and target is {y}")





inputs:
torch.Size([4, 8])
tensor([[51,  3, 11, 55,  1,  6,  3, 55],
        [14, 27, 16, 55, 41, 27, 14, 36],
        [55,  1, 55, 31,  3, 16, 36, 35],
        [ 3, 55, 25,  3,  6, 11, 55, 39]])

targets:
torch.Size([4, 8])
tensor([[ 3, 11, 55,  1,  6,  3, 55,  0],
        [27, 16, 55, 41, 27, 14, 36, 15],
        [ 1, 55, 31,  3, 16, 36, 35,  3],
        [55, 25,  3,  6, 11, 55, 39,  3]])
Input is tensor([51]) and target is 3
Input is tensor([51,  3]) and target is 11
Input is tensor([51,  3, 11]) and target is 55
Input is tensor([51,  3, 11, 55]) and target is 1
Input is tensor([51,  3, 11, 55,  1]) and target is 6
Input is tensor([51,  3, 11, 55,  1,  6]) and target is 3
Input is tensor([51,  3, 11, 55,  1,  6,  3]) and target is 55
Input is tensor([51,  3, 11, 55,  1,  6,  3, 55]) and target is 0
Input is tensor([14]) and target is 27
Input is tensor([14, 27]) and target is 16
Input is tensor([14, 27, 16]) and target is 55
Input is tensor([14, 27, 16, 55]) and target is 41
Input 

### Creating our model
Our goal is to create a simple bigram model using pytorch nn.Module as our basis

In [10]:

torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        logits = self.token_embedding_table(idx)

        if targets == None:
            loss = None
        
        else:
            # Where 
            # B = batch_size = 4
            # T = time = 8
            # C = channel = 65 = vocab_size
            #  We change the shapes of our logits to get them in the shape needed to use pytorch's cross_entropy function

            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
        return logits
    
    def generate(self, x_input, max_new_tokens):

        for _ in range(max_new_tokens):
            logits, loss = self(x_input) # we're not using loss, as we're generating

            next_token = logits[:, -1,:]

            probabilities = F.softmax(next_token, dim=-1)

            top_answer = torch.multinomial(probabilities, num_samples=1)

            x_input = torch.cat((x_input, top_answer), dim=1) # B, T+1. Appending to 1st dimension which is the time dimension

        return x_input
        


model = BigramLanguageModel(vocab_size)
logits, loss = model(xb, yb)
print(logits.shape)
print(loss) # Loss is very high at this point, 4.6 
        


torch.Size([32, 65])
tensor(4.6067, grad_fn=<NllLossBackward0>)


In [11]:
x_input = torch.zeros((1,1),dtype=torch.long )
print(decode(model.generate(x_input, max_new_tokens=100)[0].tolist())) 
# Output is garbage, as we have not begun any training

wgOY fz?l3dtVBcQBnpqoBJiW'$Vn DHAXJcNBBSbdQXmv-jBF'gf:bVP EiXeMXeiUzjryCF3iSQq,ZU,AXxKllPo!in;SRavXe&


### Creating our backward pass
In this step we create an optimizer and demonstrate a basic gradient descent loop. 

So far our model is just an embedding table with the dimensions of vocab_size * vocab_size

In [12]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [13]:
batch_size = 32

for i in range(5000):
    xb, yb = get_batch(batch_size)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss)

tensor(2.4832, grad_fn=<NllLossBackward0>)


In [14]:
x_input = torch.zeros((1,1),dtype=torch.long )
print(decode(model.generate(x_input, max_new_tokens=100)[0].tolist())) 
# Output should look somewhat more sensible, and it does! 
# This is because the tokens have some idea about what should come next just through information encoded in their own embeddings.
# However, we observe a plateau in loss of around 2.3. We'll need to implement new tricks to break through.

wa PERO:
BAQZHAforthe, GTino y mGO:
Whithy IO:uthe GRJcastequ-s.
INe
ore f terr ow t my Js are yof he


### Adding self-attention


In [15]:
torch.manual_seed(1337)
a = torch.tril(torch.ones(3,3))
a = a/torch.sum(a,1,keepdim=True)
b = torch.randint(0,10, (3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[5., 7.],
        [2., 0.],
        [5., 3.]])
--
c=
tensor([[5.0000, 7.0000],
        [3.5000, 3.5000],
        [4.0000, 3.3333]])


The purpose of the following example is to demonstrate the simplest implementation of how tokens can communicate with each other.

In this case we just average out all the values of the previous token's channels, which is obviously very lossy, but this is simply illustrative.

We will have a way to add all that back.

In [16]:
B,T,C = 4,8,2
x = torch.randn(B,T,C)
x.shape


# Here we use a bag of words (bow) to illustrate our averaging example
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1]
        xbow[b,t] = torch.mean(xprev, 0)

# The purpose of this is to show that the rows of xbow are equal to the average of the values in all previous rows of x

print(xbow[0][1] == torch.mean(x[0][:2],0), xbow[0][2] == torch.mean(x[0][:3],0))


tensor([True, True]) tensor([True, True])


In [17]:
wei = torch.tril(torch.ones(T,T))
print(wei.sum(1, keepdim=True))
wei = wei/wei.sum(1, keepdim=True)
print(wei)
xbow2 = wei@x
torch.allclose(xbow,xbow2)

tensor([[1.],
        [2.],
        [3.],
        [4.],
        [5.],
        [6.],
        [7.],
        [8.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

Our next step is to demonstrate that we can do the above using softmax.


In [41]:
tril = torch.tril(torch.ones(T,T))
wei = torch.zeros(T,T)
print(wei)
wei = wei.masked_fill(tril==0, float('-inf')) 
print(wei)
wei = torch.softmax(wei,dim=1)
print(wei)
xbow3 = wei@x
torch.allclose(xbow3,xbow2)

tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000,

True

In [26]:
torch.tensor(float('-inf'))

tensor(-inf)

To determine the attention of words (more exactly tokens) we use ‘queries’, ‘keys’ and ‘values’.

All of them are presented in vectors. 

Keys activate depending on the strength of closeness with the query vector as determined by dot product.

Keys are an encoded representation for values, in simple cases they can be the same. 




In [63]:
torch.manual_seed(1337)
B,T,C = 4,8,32
x = torch.randn(B,T,C)

#Attention head
head_size = 16
key = nn.Linear(C,head_size,  bias=False)
query = nn.Linear(C,head_size, bias=False)
value = nn.Linear(C,head_size, bias=False)
k = key(x)      # B,T,16
q = query(x)    # B,T,16

wei = q @ k.transpose(-2,-1) * C**-0.5 

wei = wei.masked_fill(tril==0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x)
out = wei @ v
print(out.shape)

torch.Size([4, 8, 16])


### Code explanations of the above:

wei = q @ k.transpose(-2,-1) * C**-0.5

The tranpose is used so we end up with a matrix of B,T,T:

(B,T,16) @ (B,16,T) ---> B, T, T: our desired shape


This lets us do batch matrix multiplication on our tril matrix which is size(16,16)


We apply the normalisation of  C**-0.5 to our wei variable as a normalisation step. We divide by the square route of our head size so that we avoid peaks that are too high in our initial weights.

In [61]:
wei

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4264, 0.5736, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3151, 0.3022, 0.3827, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3007, 0.2272, 0.2467, 0.2253, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1635, 0.2048, 0.1776, 0.1616, 0.2926, 0.0000, 0.0000, 0.0000],
         [0.1403, 0.2272, 0.1454, 0.1244, 0.2678, 0.0949, 0.0000, 0.0000],
         [0.1554, 0.1815, 0.1224, 0.1213, 0.1428, 0.1603, 0.1164, 0.0000],
         [0.0952, 0.1217, 0.1130, 0.1453, 0.1137, 0.1180, 0.1467, 0.1464]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4300, 0.5700, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3379, 0.2559, 0.4061, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2867, 0.2188, 0.2786, 0.2159, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1297, 0.1813, 0.1683, 0.2990, 0.2217, 0.0000, 0.0000, 0.0000],
         [0.1584, 0.167