Training a transformer-based LLM on a small dataset for fun

In [None]:
# Dowload the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-01-14 19:15:05--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2024-01-14 19:15:06 (27.3 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



In [None]:
# Read the whole corpus into memory as one string.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Inspect the data: total number of characters in the corpus.
print("length of dataset in characters: ", len(text))

# Look at the first `char_` characters to get a feel for the text.
char_ = 100
print(text[:char_])
print("end of first", char_, "characters")

length of dataset in characters:  1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You
end of first 100 characters


In [None]:
# The vocabulary is the sorted collection of distinct characters found in the corpus.
chars = sorted(set(text))
vocab_size = len(chars)
# Show the vocabulary as a single string, followed by its size on the next line.
print(''.join(chars), vocab_size, sep='\n')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


Next, tokenize the input text, i.e. convert the raw text string into a sequence of integers according to the vocabulary. Here we build a simple character-level tokenizer from scratch instead of using a plug-and-play tokenizer.

In [None]:
# TOKENIZER
# Two lookup tables derived from the sorted vocabulary `chars`:
# itos maps an integer id to its character, stoi is the inverse mapping.
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}
encode = lambda s: [stoi[c] for c in s]  # encoder: string -> list of integer ids
decode = lambda l: ''.join(itos[i] for i in l)  # decoder: list of integer ids -> string

# Example round trip: encode a string to integers, then decode it back to the same string.
print(encode("hii there"))
print(decode(encode("hii there")))


[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [None]:
# Encode the entire text dataset and store it in a torch.Tensor
import torch # we use PyTorch: https://pytorch.org
data = torch.tensor(encode(text), dtype=torch.long) # encode the full text and wrap the resulting list of integers in a tensor
print(data.shape, data.dtype) # shape and dtype of the data tensor
print(data[:1000]) # the first 1000 characters of the text, as the integer sequence the model will see

torch.Size([1115394]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

In [None]:
# Hold out a validation set so we can check that the model generalizes instead of just memorizing the training text.
# So, split the data into train and validation sets.
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

Next, we use parts of the text to train the transformer so that it can learn the patterns. We train it on chunks of the dataset instead of feeding in all of the text at once, because that would be computationally too expensive.

The random chunks of the dataset have a maximum length that we can set, referred to as the block size or context length.

In [None]:
block_size = 8
train_data[:block_size+1]
# a chunk of block_size+1 tokens contains block_size prediction examples:
# in the context of 18, 47 comes next
# in the context of 18 and 47, 56 comes next
# see the next cell for an illustration

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [None]:
x = train_data[:block_size] # input to the transformer: the first block_size tokens
y = train_data[1:block_size+1] # targets: the same tokens shifted one position to the right
for t in range(block_size):
    context = x[:t+1] # everything up to and including position t
    target = y[t] # the token that follows that context
    print(f"when input is {context} the target: {target}")

when input is tensor([18]) the target: 47
when input is tensor([18, 47]) the target: 56
when input is tensor([18, 47, 56]) the target: 57
when input is tensor([18, 47, 56, 57]) the target: 58
when input is tensor([18, 47, 56, 57, 58]) the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]) the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]) the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the target: 58


In [None]:
torch.manual_seed(1337) # seed the random number generator for reproducibility
batch_size = 4 # the number of independent sequences that are processed in parallel
block_size = 8 # the maximum context length to make predictions

def get_batch(split):
    """Sample a random batch of (inputs, targets) from the train or validation data.

    split: 'train' selects train_data, anything else selects val_data.
    Returns two (batch_size, block_size) tensors; the targets are the inputs
    shifted one position to the right.
    """
    source = train_data if split == 'train' else val_data
    # One random start offset per sequence, chosen so the shifted window still fits.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])
    # Stack the 1-D windows as the rows of a 2-D tensor.
    return torch.stack(inputs), torch.stack(targets)

# sample a batch xb and yb
xb, yb = get_batch('train')
print('inputs:')
print(xb.shape) # this integer tensor is what gets fed into the transformer
print(xb)
print('targets:')
print(yb.shape) # yb holds, for every position in xb, the integer the model should predict next
print(yb)

print('----')

# spell out every (context, target) training example contained in the batch
for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
targets:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])
----
when input is [24] the target: 43
when input is [24, 43] the target: 58
when input is [24, 43, 58] the target: 5
when input is [24, 43, 58, 5] the target: 57
when input is [24, 43, 58, 5, 57] the target: 1
when input is [24, 43, 58, 5, 57, 1] the target: 46
when input is [24, 43, 58, 5, 57, 1, 46] the target: 43
when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39
when input is [44] the target: 53
when input is [44, 53] the target: 56
when input is [44, 53, 56] the target: 1
when input is [44, 53, 56, 1] the target: 58
when input is [44, 53, 56, 1, 58] the target: 46
when input is [44, 53

In [None]:
print(xb) # the batch of inputs that is fed to the model

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])


Now we have the batch of inputs that we want to feed into the transformer. It can be fed into a neural network; we start with the simplest one, a bigram language model.

In [None]:
# import pytorch NN module
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337) # re-seed so that the model initialization and the sampling below are reproducible

class BigramLanguageModel(nn.Module):
    """Bigram language model: one embedding table maps each token straight to next-token logits."""

    def __init__(self, vocab_size):
        super().__init__()
        # Row i of the table holds the logits for the token that follows token i.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Return (logits, loss) for a batch of token indices.

        idx and targets are (B, T) integer tensors. Without targets, the logits
        keep their (B, T, C) shape and the loss is None. With targets, the
        logits are returned flattened to (B*T, C) together with the
        cross-entropy loss (the negative log likelihood of the targets).
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)
        if targets is None:
            # Inference path: nothing to compare against, so there is no loss.
            return logits, None

        # F.cross_entropy expects (N, C) logits and (N,) targets, so merge batch and time.
        n_batch, n_time, n_channels = logits.shape
        logits = logits.view(n_batch * n_time, n_channels)
        loss = F.cross_entropy(logits, targets.view(n_batch * n_time))
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Extend idx, a (B, T) tensor of token indices, by max_new_tokens sampled tokens.

        Returns a (B, T + max_new_tokens) tensor.
        """
        for _step in range(max_new_tokens):
            # Only the logits of the last time step matter for the next token: (B, C).
            last_logits = self(idx)[0][:, -1, :]
            # Turn the logits into probabilities and sample one token per sequence: (B, 1).
            idx_next = torch.multinomial(F.softmax(last_logits, dim=-1), num_samples=1)
            # Grow the running sequence along the time dimension: (B, T+1).
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

m = BigramLanguageModel(vocab_size) # construct the bigram model, which is a subclass of nn.Module
logits, loss = m(xb, yb) # call the model, passing it the inputs and the targets
print(logits.shape)
print(loss)


# Generation, read from the inside out:
# - torch.zeros((1, 1), dtype=torch.long) is a 1x1 integer tensor holding index 0; it kicks off the generation
# - m.generate(..., max_new_tokens=100) appends 100 sampled tokens to it
# - [0].tolist() takes row 0 of the result and turns it into a Python list that decode can consume
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


torch.Size([32, 65])
tensor(4.8786, grad_fn=<NllLossBackward0>)

Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3


The prediction above makes no sense because the model has not been trained yet. The model now has to be trained to produce non-random predictions.

In [None]:
# create a PyTorch optimizer object, which takes the gradients and uses them to update the parameters.
# the AdamW optimizer is used here.
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [None]:
# get_batch reads the global batch_size, so this assignment makes it return 32 sequences per batch from here on
batch_size = 32
for steps in range(10000): # increase number of steps for good results...

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True) # clear the gradients left over from the previous step
    loss.backward() # get the gradients for all the parameters
    optimizer.step() # use the gradients to update the parameters

# loss of the last batch only
print(loss.item())


2.5727508068084717


In [None]:
# see what the trained model generates now
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))

# you can play around with max_new_tokens to generate more or less text.



Iyoteng h hasbe pave pirance
Rie hicomyonthar's
Plinseard ith henoure wounonthioneir thondy, y helti


The model above is simple, as only the previous character is considered when predicting the next token. The transformer enables the tokens to "talk" to one another, so that the prediction can take the whole context into account.

See example below to consider the maths that creates self attention.

The mathematical trick in self-attention

In [None]:
# Toy example: multiplying by a row-normalized lower-triangular matrix performs a
# "weighted aggregation" - here, a running average over the rows seen so far.
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / a.sum(dim=1, keepdim=True)  # every row of a now sums to 1
b = torch.randint(0,10,(3,2)).float()
c = torch.matmul(a, b)  # row i of c is the mean of rows 0..i of b
# Print the three matrices, separated by '--' lines.
print('a=', a, '--', 'b=', b, '--', 'c=', c, sep='\n')

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [None]:
# consider the following toy example:

# for every t'th token, calculate the average of that token and all the preceding tokens [:t+1]
# this creates a kind of feature vector of the t'th token in the context of its history - a first step towards self-attention.
# averaging is a weak form of attention; it is improved on later
torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [None]:
# We want xbow[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C)) # bag of words
for b in range(B): # iterate over all the batch dimensions
    for t in range(T): # iterate over time
        xprev = x[b,:t+1] # (t+1,C): the tokens up to and including position t, each with C channels
        xbow[b,t] = torch.mean(xprev, 0) # average over the time dimension and store it in the bag of words


In [None]:
# version 2: using matrix multiply for a weighted aggregation
wei = torch.tril(torch.ones(T, T)) # lower-triangular ones: row t selects positions 0..t
wei = wei / wei.sum(1, keepdim=True) # normalize each row so that it averages instead of sums
xbow2 = wei @ x # (T, T), broadcast over the batch, @ (B, T, C) ----> (B, T, C) = batch matrix multiplication
# The loop version and the matmul add up the same numbers in a different order, so the two
# results can differ by float32 rounding. The default atol of torch.allclose (1e-8) is too
# strict for entries close to zero, hence the explicit absolute tolerance.
torch.allclose(xbow, xbow2, atol=1e-6)

False

In [None]:
# version 3: use Softmax
# this is the form used for self-attention in our model: the zeros that wei starts out with
# can later be replaced by data-dependent scores, while the -inf entries keep masking the future.
# it allows weighted aggregation of the previous elements via matrix multiplication with a lower-triangular matrix.

tril = torch.tril(torch.ones(T, T)) # ones on and below the diagonal, zeros above it
wei = torch.zeros((T,T)) # starts out full of zeros
wei = wei.masked_fill(tril == 0, float('-inf')) # positions above the diagonal (the future) become -inf
wei = F.softmax(wei, dim=-1) # softmax over every row: -inf turns into weight 0, the equal zeros turn into equal weights
xbow3 = wei @ x
# Same numbers as version 1, but summed in a different order, so compare with an explicit
# absolute tolerance: the default atol of 1e-8 is too strict for float32 entries close to zero.
torch.allclose(xbow, xbow3, atol=1e-6)


False

In [None]:
# version 4: self-attention!

# every token emits two vectors: a query and a key
# query = what the token is looking for (e.g. a vowel looking for consonants)
# key = what the token contains itself (e.g. "I am a vowel")
# the dot product between a query and all the keys becomes wei.
# the better a query and a key align, the higher the weight (wei), so more attention is paid to that token.
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C) # per-token information, which self-attention will mix across positions

# let's see a single Head perform self-attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)   # (B, T, head_size = 16)
q = query(x) # (B, T, head_size = 16)
wei =  q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T) # transpose(-2, -1) swaps the last two dimensions

# mask the wei matrix and normalize it
tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T))  # (version 3 started from zeros; here wei holds the query-key scores instead)
wei = wei.masked_fill(tril == 0, float('-inf')) # mask the upper triangle - tokens cannot see the characters that come after them and only use the previous ones
wei = F.softmax(wei, dim=-1) # exponentiate and normalize every row, so the weights are non-negative and sum to 1

v = value(x) # v is what a token communicates to the tokens that attend to it
out = wei @ v
#out = wei @ x  # (version 3 aggregated the raw x; here the value projection is aggregated instead)

out.shape

torch.Size([4, 8, 16])

In [None]:
# the weights are now context dependent: every sequence in the batch gets its own (T, T) weight matrix
wei


tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
         [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
         [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
         [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1687, 0.8313, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2477, 0.0514, 0.7008, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4410, 0.0957, 0.3747, 0.0887, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0069, 0.0456, 0.0300, 0.7748, 0.1427, 0.0000, 0.0000, 0.0000],
         [0.0660, 0.089

In [None]:
# attention weights of the first sequence in the batch
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

Notes:
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Each example across batch dimension is of course processed completely independently and never "talk" to each other
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additionally scales `wei` by 1/sqrt(head_size). This makes it so that when the inputs Q, K have unit variance, wei has unit variance too, and Softmax stays diffuse and does not saturate too much. Illustration below.


Softmax produces a sharp peak when the values in the tensor are far apart. We therefore keep the values from drifting too far apart by scaling them with 1/sqrt(head_size).

In [None]:
# random keys and queries drawn from a standard normal distribution
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = q @ k.transpose(-2, -1) * head_size**-0.5 # scale the scores by 1/sqrt(head_size)

In [None]:
k.var() # variance of the keys

tensor(1.0449)

In [None]:
q.var() # variance of the queries

tensor(1.0700)

In [None]:
wei.var() # variance of the scaled attention scores

tensor(1.0918)

In [None]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1) # not very peaky (about 0.14 to 0.29): the inputs are close to each other and to zero

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])

In [None]:
torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1) # too peaky (about 0.003 to 0.8): the *8 creates large differences between the values,
                                                                  # so the softmax converges towards one-hot

tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])

next, we apply multi head attention instead of single head attention

Multi-head attention applies multiple attention heads in parallel and concatenates the results.

In [None]:
# Preview of the multi-head attention module. It is kept as comments because it depends on
# Head, n_embd and dropout, which are not defined at this point in the notebook; the runnable
# class is defined in the "Full finished code" section.
# (The previous version of this cell wrapped the code in several string literals at different
# indentation levels, which raises an IndentationError when the cell is run.)
#
# class MultiHeadAttention(nn.Module):
#     """ multiple heads of self-attention in parallel """
#
#     def __init__(self, num_heads, head_size):
#         super().__init__()
#         self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
#         self.proj = nn.Linear(n_embd, n_embd)
#         self.dropout = nn.Dropout(dropout)
#
#     def forward(self, x):
#         out = torch.cat([h(x) for h in self.heads], dim=-1) # concatenate over the channel (C) dimension
#         out = self.dropout(self.proj(out))
#         return out



Full finished code

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [None]:
# The corpus has to be present as input.txt; it can be downloaded with:
# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [None]:
# Character-level vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
# Lookup tables between characters and integer token ids (itos: id -> char, stoi: char -> id).
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}
encode = lambda s: [stoi[c] for c in s]  # encoder: string -> list of integer ids
decode = lambda l: ''.join(itos[i] for i in l)  # decoder: list of integer ids -> string

# Train and validation splits: the first 90% of the encoded corpus is used for training, the rest for validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data, val_data = data[:n], data[n:]


In [None]:
# hyperparameters
# -- data --
batch_size, block_size = 16, 32  # sequences processed in parallel, maximum context length for predictions
# -- optimization --
max_iters, learning_rate = 5000, 1e-3
# -- evaluation --
eval_interval, eval_iters = 100, 200  # steps between loss estimates, batches averaged per estimate
# -- model --
n_embd, n_head, n_layer = 64, 4, 4  # embedding width, attention heads per block, number of blocks
dropout = 0.0
# -- hardware --
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ------------

torch.manual_seed(1337)

<torch._C.Generator at 0x78cab66b9870>

In [None]:

# data loading
def get_batch(split):
    """Sample a random batch of (inputs, targets) and move it to `device`.

    split: 'train' selects train_data, anything else selects val_data.
    Returns two (batch_size, block_size) tensors; the targets are the inputs
    shifted one position to the right.
    """
    source = train_data if split == 'train' else val_data
    # One random start offset per sequence, chosen so the shifted window still fits.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs = torch.stack([source[s:s + block_size] for s in starts]).to(device)
    targets = torch.stack([source[s + 1:s + block_size + 1] for s in starts]).to(device)
    return inputs, targets

@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss of the global `model` on the train and validation splits.

    Averages the loss over eval_iters random batches per split, with the model
    in eval mode and gradient tracking disabled; the model is switched back to
    train mode before returning. Returns {'train': tensor, 'val': tensor}.
    """
    model.eval()
    averages = {}
    for split in ['train', 'val']:
        per_batch = torch.zeros(eval_iters)
        for k in range(eval_iters):
            _, batch_loss = model(*get_batch(split))
            per_batch[k] = batch_loss.item()
        averages[split] = per_batch.mean()
    model.train()
    return averages

class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Projects the input to keys, queries and values of width head_size and
    returns, for every position, an attention-weighted sum of the values at
    that position and the positions before it.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask, registered as a buffer: it is part of the module
        # (and follows it across devices) but is not a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map x of shape (B, T, n_embd), with T <= block_size, to (B, T, head_size)."""
        B, T, _ = x.shape
        k = self.key(x)   # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)
        # Attention scores ("affinities"), scaled by 1/sqrt(head_size). The scale must use
        # the width of q and k (head_size), not the embedding width of x: the score is a
        # sum over head_size products, so this is what keeps its variance near 1.
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        # Hide the future: positions above the diagonal get -inf, i.e. weight 0 after softmax.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # Weighted aggregation of the values.
        v = self.value(x) # (B, T, head_size)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention in parallel, followed by an output projection.

    num_heads: number of Head modules run side by side.
    head_size: output width of each head.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide. Sizing the projection
        # from that (instead of assuming it equals n_embd) gives the same layer for the
        # configuration used in this notebook and also works when the two differ.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map x of shape (B, T, n_embd) to (B, T, n_embd)."""
        # Concatenate the head outputs over the channel dimension.
        out = torch.cat([h(x) for h in self.heads], dim=-1) # (B, T, num_heads * head_size)
        out = self.dropout(self.proj(out)) # (B, T, n_embd)
        return out

class FeedFoward(nn.Module):
    """Per-position MLP: expand to 4 * n_embd, apply ReLU, project back to n_embd, then dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of x; the shape is unchanged."""
        return self.net(x)

class Block(nn.Module):
    """Transformer block: self-attention (communication) followed by a feed-forward net (computation).

    Both sub-layers see a layer-normalized input and are wrapped in a residual connection.
    """

    def __init__(self, n_embd, n_head):
        """n_embd: embedding dimension; n_head: number of attention heads."""
        super().__init__()
        # The embedding width is divided evenly over the heads.
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Map x of shape (B, T, n_embd) to a tensor of the same shape."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))

# Despite the class name (kept so that existing references keep working), this is no longer a
# bigram model: it is a small decoder-only transformer language model.
class BigramLanguageModel(nn.Module):
    """Character-level transformer language model.

    Token and position embeddings are summed, passed through n_layer
    transformer blocks and a final layer norm, and projected to vocabulary logits.
    """

    def __init__(self):
        super().__init__()
        # Token identity and position in the context each get an n_embd-wide embedding.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Return (logits, loss) for a batch of token indices.

        idx and targets are (B, T) integer tensors with T <= block_size.
        Without targets, the logits have shape (B, T, vocab_size) and the loss is None.
        With targets, the logits are returned flattened to (B*T, vocab_size)
        together with the cross-entropy loss.
        """
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        # Create the positions on the device of the input rather than on the global `device`,
        # so the forward pass keeps working when the model and its inputs live elsewhere.
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C), pos_emb is broadcast over the batch
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects (N, C) logits and (N,) targets, so merge batch and time.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Extend idx, a (B, T) tensor of token indices, by max_new_tokens sampled tokens."""
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens: the position table has no more entries than that
            idx_cond = idx[:, -block_size:]
            # get the predictions; the loss is None here and is ignored
            logits, _ = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx



In [None]:
model = BigramLanguageModel()
m = model.to(device) # nn.Module.to returns the module itself, so m and model refer to the same object
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# The loop variable is called `step` rather than `iter`, so that the built-in iter() is not
# shadowed in the notebook's global namespace for every cell that runs afterwards.
for step in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True) # clear the gradients left over from the previous step
    loss.backward()
    optimizer.step()

# generate from the model, starting from a single token with index 0
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))


0.209729 M parameters
step 0: train loss 4.4116, val loss 4.4022
step 100: train loss 2.6568, val loss 2.6670
step 200: train loss 2.5091, val loss 2.5060
step 300: train loss 2.4199, val loss 2.4337
step 400: train loss 2.3500, val loss 2.3563
step 500: train loss 2.2961, val loss 2.3126
step 600: train loss 2.2408, val loss 2.2501
step 700: train loss 2.2053, val loss 2.2187
step 800: train loss 2.1636, val loss 2.1870
step 900: train loss 2.1226, val loss 2.1483
step 1000: train loss 2.1017, val loss 2.1283
step 1100: train loss 2.0683, val loss 2.1174
step 1200: train loss 2.0376, val loss 2.0798
step 1300: train loss 2.0256, val loss 2.0645
step 1400: train loss 1.9919, val loss 2.0362
step 1500: train loss 1.9696, val loss 2.0304
step 1600: train loss 1.9625, val loss 2.0470
step 1700: train loss 1.9402, val loss 2.0119
step 1800: train loss 1.9085, val loss 1.9957
step 1900: train loss 1.9080, val loss 1.9869
step 2000: train loss 1.8834, val loss 1.9941
step 2100: train loss 1.