In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

<torch._C.Generator at 0x246dfd25790>

In [2]:
# the dataset we use is called 'tiny shakespeare', containing all works of Shakespeare in a 1mb txt file

# load the whole corpus into memory as a single string
with open('input.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [3]:
print("length of dataset in characters:", len(text))

length of dataset in characters: 1115394


In [4]:
# the first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [5]:
# vocabulary: every distinct character that occurs in the text, in sorted order
chars = sorted(set(text))
vocab_size = len(chars)  # number of distinct characters
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [6]:
# create mappings of characters to and from integers

stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id (its position in chars)
itos = dict(enumerate(chars))  # integer id -> character
print(itos)


def encode(s):
    """Encoder: take a string, return the list of integer ids of its characters.

    Raises KeyError if `s` contains a character that is not a key of `stoi`.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take an iterable of integer ids, return the string they spell.

    Raises KeyError if an id is not a key of `itos`.
    """
    return ''.join(itos[i] for i in l)


print(encode('hello there'))
print(decode([46, 43, 50, 50, 53, 1, 58, 46, 43, 56, 43]))

{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z', 39: 'a', 40: 'b', 41: 'c', 42: 'd', 43: 'e', 44: 'f', 45: 'g', 46: 'h', 47: 'i', 48: 'j', 49: 'k', 50: 'l', 51: 'm', 52: 'n', 53: 'o', 54: 'p', 55: 'q', 56: 'r', 57: 's', 58: 't', 59: 'u', 60: 'v', 61: 'w', 62: 'x', 63: 'y', 64: 'z'}
[46, 43, 50, 50, 53, 1, 58, 46, 43, 56, 43]
hello there


In [7]:
# let's now encode the entire dataset and store it in a tensor
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000])  # this is what GPT sees

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [8]:
# hold out the tail of the encoded text as a validation set

n = int(0.9*len(data))  # split point: the first 90% is train, the remaining 10% is validation
train_data, val_data = data[:n], data[n:]

In [9]:
block_size = 8  # context length
train_data[:block_size+1]  # we do plus one to get 9 characters, because in a chunk of 9 characters, there are 8 possible examples

# for example:
# given 18 ---> 47 comes next
# given 18, 47 ---> 56 comes next
# ...and so on

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [10]:
# inputs, and the targets obtained by shifting the same window one position to the right
x, y = train_data[:block_size], train_data[1:block_size+1]
for t in range(block_size):
    context, target = x[:t+1], y[t]
    print(f'when input is {context}, output is {target}')

# training on every prefix gets the transformer used to contexts of every length from 1 up to block_size (8 here),
# so it can start generating when it is given just one character as input
# and it can keep generating when it is given up to 8 characters as input

# This is like our makemore version, but the difference is that in makemore we had a fixed input size:
# all names were 8 characters long, and shorter ones were padded with the . character to make them 8 long.
# Here, instead, we extend the context one character at a time so that the model can learn the combinations

when input is tensor([18]), output is 47
when input is tensor([18, 47]), output is 56
when input is tensor([18, 47, 56]), output is 57
when input is tensor([18, 47, 56, 57]), output is 58
when input is tensor([18, 47, 56, 57, 58]), output is 1
when input is tensor([18, 47, 56, 57, 58,  1]), output is 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]), output is 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), output is 58


In [11]:
batch_size = 4  # how many independent sequences will we process in parallel
block_size = 8  # maximum context length for predictions

def get_batch(split):
    """Sample a random batch of inputs x and next-token targets y.

    `split` picks the source tensor: 'train' reads from train_data, any other value reads from val_data.
    Returns (x, y), each with batch_size rows of block_size consecutive tokens; y is x shifted one position to the right.
    """
    if split == 'train':
        source = train_data
    else:
        source = val_data
    # random window starts; the upper bound keeps the one-step-shifted target window inside the tensor
    starts = torch.randint(len(source) - block_size, (batch_size,))
    # one row of consecutive positions per start: (batch_size, 1) + (block_size,) -> (batch_size, block_size)
    positions = starts.unsqueeze(1) + torch.arange(block_size)
    x = source[positions]
    y = source[positions + 1]
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print('outputs:')
print(yb.shape)
print(yb)

print('-----')

# spell out every (context, target) training example packed into the batch
for b in range(batch_size):  # batch dimension
    inputs_row, targets_row = xb[b], yb[b]
    for t in range(block_size):  # block dimension
        context = inputs_row[:t+1]
        target = targets_row[t]  # easy to confuse: the target is read from yb, not xb
        context_ids = context.tolist()
        print(f'when input is {context_ids}, the target is: {target}')
        print(f'{decode(context_ids)}, ---> {decode([target.tolist()])}')  # same example as text, just for clarity
        print()
    print()

inputs:
torch.Size([4, 8])
outputs:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
-----
when input is [24], the target is: 43
L, ---> e

when input is [24, 43], the target is: 58
Le, ---> t

when input is [24, 43, 58], the target is: 5
Let, ---> '

when input is [24, 43, 58, 5], the target is: 57
Let', ---> s

when input is [24, 43, 58, 5, 57], the target is: 1
Let's, --->  

when input is [24, 43, 58, 5, 57, 1], the target is: 46
Let's , ---> h

when input is [24, 43, 58, 5, 57, 1, 46], the target is: 43
Let's h, ---> e

when input is [24, 43, 58, 5, 57, 1, 46, 43], the target is: 39
Let's he, ---> a


when input is [44], the target is: 53
f, ---> o

when input is [44, 53], the target is: 56
fo, ---> r

when input is [44, 53, 56], the target is: 1
for, --->  

when input is [44, 53, 56, 1], the target is: 58
for , ---> t

when input is [44, 53, 

In [28]:
# bigram language model

class BigramLanguageModel(nn.Module):
    """Bigram language model: the next-token logits are looked up from the current token alone."""

    def __init__(self, vocab_size):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table:
        # row i of the (vocab_size, vocab_size) table holds the logits that follow token i
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the cross-entropy loss.

        idx and targets are both (B, T) tensors of integers.
        Returns (logits, loss):
          - targets is None: logits is (B, T, C) and loss is None
          - otherwise: logits is flattened to (B*T, C) and loss is the scalar cross-entropy
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)  the third dimension is the embeddings

        # identity check: `== None` on a tensor only works through Python's comparison fallback
        if targets is None:
            loss = None
        else:
            # cross_entropy expects arguments to be of shape (N, C, d1, d2, ...) where N is batch size, C is number of classes
            B, T, C = logits.shape
            logits = logits.view(B*T, C)  # flatten batch and time, e.g. 4*8=32 samples with 65 classes each
            targets = targets.view(-1)  # matching flat targets of shape (B*T,)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()  # sampling never backpropagates, so skip building the autograd graph
    def generate(self, idx, max_new_tokens):
        """Extend idx by sampling max_new_tokens tokens, one at a time.

        idx is a (B, T) array of indices in the current context.
        Returns a (B, T + max_new_tokens) tensor of indices.
        """
        for _ in range(max_new_tokens):
            # get the predictions; no targets are passed, so the loss is None and is ignored
            logits, _ = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :]  # (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape, loss)  # loss at initialization given a prob of 1/65 should be ~4.17

# start from a single token with id 0 (batch size: 1, time size: 1) and sample 100 new tokens
start_idx = torch.zeros((1, 1), dtype=torch.long)
generated = m.generate(idx=start_idx, max_new_tokens=100)
print(decode(generated[0].tolist()))

torch.Size([32, 65]) tensor(4.7341, grad_fn=<NllLossBackward0>)

Pzb.nAIijJooLkmUc fAe?fyKmo:IOuc.rkGKqRYtKJcyJSldKsH?;Fi?X&aQEbtQpUkw'u b:HBVVTVoEy3X.:sgdfg;nOJk:.c


In [29]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [184]:
batch_size = 64  # get_batch reads this global, so every batch sampled below has 64 sequences

for steps in range(10_000):

    # draw a fresh random training batch
    xb, yb = get_batch('train')

    # clear the gradients of the previous step, then forward, backward, and update
    optimizer.zero_grad(set_to_none=True)
    logits, loss = m(xb, yb)
    loss.backward()
    optimizer.step()

# loss of the last batch only
print(loss.item())

2.4595069885253906


In [199]:
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))


I mee hon he ithauleen's?-
a gheo ing, pisourthest t ameer:
myen have,

Malset ubedy acepy ng t INCERBUCEr ick'l;
e'e IZAsth t we ndofrendere the thofakeon wiche alik n cem, buran aw hice me r gh:
Angak beene, p; mind merin s,
Toth bll minoutalo de IULLOPSore ERe y p tirthove,
CII I'de this ed;
Ay tin:

APowall o aporere w ts
BUKI ff o k,
PEEENAs:
KI l thed de gbr hion: upsth med pr'dou m ourere,
NERI al't wear n, anched, iset shit RAndas towilat t ithlise fad,
Wr-d ye ftibeouty;
PUThe?
Whegenth


### The mathematical trick in self-attention

In [238]:
# consider the following toy example:

torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch, time, channels
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [239]:
# We want xbow[b, t] = mean_{i<=t} x[b, i]
# i.e. each 't'th token becomes the average of itself and the tokens behind it.
# Note that the average never includes the tokens in front of it: we want to predict the future tokens, not already know them

# version 1: for loops
xbow = torch.zeros((B, T, C))  # bow is short for bag of words
for b in range(B):
    for t in range(T):
        xprev = x[b, :t+1]  # (t+1, C): positions 0..t of batch element b
        xbow[b, t] = xprev.mean(dim=0)  # average over the time axis -> (C,)

# but this method with for loops is very inefficient

In [228]:
# version 2: matrix multiply
lower_ones = torch.tril(torch.ones(T, T))  # row t has ones in columns 0..t and zeros after
wei = lower_ones / lower_ones.sum(dim=1, keepdim=True)  # each row now sums to 1, i.e. it averages positions 0..t

xbow2 = torch.matmul(wei, x)
# (T, T) @ (B, T, C) ---> the shapes do not line up as written, so pytorch broadcasts wei across the batch dimension:
# (B, T, T) @ (B, T, C) = (B, T, C)  ---> for each batch element, there will be a (T, T) @ (T, C), giving out a (T, C)

torch.allclose(xbow, xbow2, atol=1e-6, rtol=1e-4)  # tolerances loosened for floating point differences

True

In [245]:
# version 3: use Softmax
# the lower-triangular mask ensures that each position only attends to previous positions (including itself)
tril = torch.tril(torch.ones(T, T))
# start from all-zero affinities and put negative infinity wherever the mask is 0 (the upper triangle),
# so that no position can attend to future positions
wei = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
wei = wei.softmax(dim=-1)  # turn each row into probabilities; the -inf entries become zero
xbow3 = torch.matmul(wei, x)
torch.allclose(xbow, xbow3, atol=1e-6, rtol=1e-4)

# NOTE: what the softmax does here, row by row:
# negative infinity entries (upper triangle) give e^(-inf) = 0
# zero entries (lower triangle) give e^0 = 1
# dividing by the row total then yields a probability distribution, so each row sums to 1

RuntimeError: The size of tensor a (2) must match the size of tensor b (32) at non-singleton dimension 2

In [246]:
# version 4: self-attention!
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)

# a single head of self-attention: three bias-free linear projections of the same input
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# every token in every batch element, in parallel and independently, produces a key and a query;
# no communication between tokens has happened YET
k, q = key(x), query(x)  # each (B, T, head_size)

# each token (i.e. character in this case) now has two vectors, a key and a query
# the query lets a token "ask a question", the key lets a token "provide information"
# before, we initialized the weights with all zeros
# now the weights are computed in a data dependent manner

# each char looks at what has come before and decides which of those tokens matter most to it
# this is done by matrix multiplying its query with the keys of all tokens that came before, and with its own key as well
# the mat mul 'aligns' (by dot product) that token's query with the keys of the chars that came before
# a high dot product means the query token finds that key token relevant

# transpose only the last two dims of k so the batch dim is kept: (B, T, 16) @ (B, 16, T) = (B, T, T)
# scaling by head_size**-0.5 keeps wei at unit variance when q and k are unit variance too, which prevents saturation of the softmax
scale = head_size**-0.5
affinities = torch.matmul(q, k.transpose(-2, -1)) * scale

tril = torch.tril(torch.ones(T, T))
# the affinities replace the all-zero (T, T) matrix used earlier: we don't want them to be uniform,
# because different tokens will find different other tokens more or less interesting
wei = affinities.masked_fill(tril == 0, float('-inf')).softmax(dim=-1)

# x stays private to each token; v is the thing that gets aggregated for the purpose of this single head
v = value(x)
out = torch.matmul(wei, v)  # aggregate v, rather than the raw x, with the attention weights

out.shape

# Query (q): determines what each token is looking for
# Key (k): determines what each token offers
# Value (v): represents the actual information that's being passed

# The attention weights (wei) determine how important each token is to each other token
# we don't necessarily want to pass the raw input (x) directly. Instead, we pass a transformed version (v)
# By having separate transformations for key, query, and value,
# the model can learn to represent the information in different ways for different purposes

# The q @ k.transpose(-2, -1) operation computes a similarity or relevance score
# But we might want to pass different information than what we used to compute that score
# v allows us to separate "what determines attention" from "what information is passed"

torch.Size([4, 8, 16])

In [244]:
wei[0]
# here in the 8th row (the last token), each entry is the weight that token puts on one of the positions it can see; its weight on itself is 0.2391
# the 8th token knows what content it has (through character embedding table), and it knows what position it's at (position embedding table)
# and now it can send out a 'query', to look through the 'keys' of the previous characters, to know how important each of them is to it
# so if a certain character is important to it, the two will have a high affinity (a large query-key dot product)
# as an example here, the 8th token puts a weight of 0.2297 on the 4th token, about as much as the 0.2391 it puts on itself, so it has a high affinity for the 4th token

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

In [None]:
# ------------------------- notes and deriving the steps along the way ------------------------- #

In [205]:
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [206]:
xbow[0]
# the first element remains the same, as there is nothing behind that to average over
# the next one is an average of itself and the one behind it
# and so on

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [210]:
torch.manual_seed(42)

a = torch.ones(3, 3)  # all ones: every row of c becomes the sum of all rows of b
b = torch.randint(0, 10, (3, 2)).to(torch.float32)  # random integers in [0, 10) as floats; (3, 2) is the shape of the tensor
c = torch.matmul(a, b)
for item in ('a=', a, '---', b, '---', 'c=', c):
    print(item)

a=
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
---
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


In [208]:
torch.tril(torch.ones(3, 3))  # gives a lower trianguar matrix i.e. 0 above in the upper half right triangle

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [211]:
torch.manual_seed(42)

a = torch.ones(3, 3).tril()  # lower triangular ones: row t of c sums rows 0..t of b
b = torch.randint(0, 10, (3, 2)).to(torch.float32)  # random integers in [0, 10) as floats; (3, 2) is the shape of the tensor
c = torch.matmul(a, b)
for item in ('a=', a, '---', b, '---', 'c=', c):
    print(item)

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
---
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


In [None]:
# so we end up just plucking out the 1st row of b to be the 1st row of c
# for the 2nd row of c, we end up taking the sums of the 1st and 2nd rows of b
# for the 3rd row of c, we take the sums of the 1st, 2nd, and 3rd rows of b

# so we have a pretty good way to take the sums we want now (step for mean calculation remains)
# and more efficiently instead of having to use for loops

In [212]:
torch.manual_seed(42)

a = torch.tril(torch.ones(3, 3))
a = a / a.sum(dim=1, keepdim=True)  # normalize each row to sum to 1, so the product directly gives the mean
b = torch.randint(0, 10, (3, 2)).to(torch.float32)  # random integers in [0, 10) as floats; (3, 2) is the shape of the tensor
c = torch.matmul(a, b)
for item in ('a=', a, '---', b, '---', 'c=', c):
    print(item)

# the c we get here is directly the running average we wanted all along
# matrix multiplication FTW!

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
---
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
---
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [230]:
xbow[0], xbow2[0]

(tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]),
 tensor([[ 0.1808, -0.0700],
         [-0.0894, -0.4926],
         [ 0.1490, -0.3199],
         [ 0.3504, -0.2238],
         [ 0.3525,  0.0545],
         [ 0.0688, -0.0396],
         [ 0.0927, -0.0682],
         [-0.0341,  0.1332]]))