In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [3]:
# Load the dataset.
# The file contains some works by Shakespeare as plain text.

# The encoding is passed explicitly: without it, open() falls back to the
# platform's locale encoding, so the same file could decode differently
# (or fail to decode) on another machine.
with open(r'../text_input/input.txt', encoding='utf-8') as file:
    text = file.read()

# last expression of the cell: number of characters and the first 100 of them
len(text), text[:100]

(1115394,
 'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou')

In [67]:
# Preparatory work.
# `alphabet` (could also be called a token dictionary) holds every distinct
# character of the text in sorted order.
# `encode` and `decode` tokenize and detokenize a text at character level.
# Lastly, the data is split into a training and a validation set.

alphabet = ''.join(sorted(set(text)))

stoi = {ch: idx for idx, ch in enumerate(alphabet)}   # character -> integer id
itos = dict(enumerate(alphabet))                      # integer id -> character


def encode(x):
    """Tokenize: return the list of integer ids for the characters of `x`.

    Raises KeyError if `x` contains a character that is not in `alphabet`.
    """
    return [stoi[ch] for ch in x]


def decode(x):
    """Detokenize: join the characters belonging to the integer ids in `x`.

    Raises KeyError if `x` contains an id that is not in `itos`.
    """
    return ''.join(itos[idx] for idx in x)


data = torch.tensor(encode(text), dtype=torch.long)
split_id = int(0.9*len(data))   # first 90% for training, remaining 10% for validation
train_data = data[:split_id]
val_data = data[split_id:]

train_data.size(), val_data.size(), alphabet, len(alphabet)

(torch.Size([1003854]),
 torch.Size([111540]),
 "\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz",
 65)

In [5]:
# batch: number of sequences processed in parallel
# time: number of token positions per sequence (previous and posterior tokens)
# channels: embedding dim of each token
B,T,C = 4,8,10                                  # B: batch, T: time, C: channels
nH, sH = 5, 2                                   # nH: num of heads, sH: size of heads (in multi-head attention -> nH * sH = C)

tril = torch.tril(torch.ones((T,T)))            # lower triangular matrix

# key, query and value projections; random here, in a model they are trained parameters
key = torch.randn((1,nH,C,sH))                  # how strongly a token gives to another time step
query = torch.randn((1,nH,C,sH))                # how strongly a token receives from another time step
value = torch.randn((1,nH,C,sH))                # what value to transmit

# void data & dimension magic: the extra singleton axis broadcasts over the heads
x1 = torch.randn((B,T,C))
x = x1.view((B,1,T,C))

# compute key and query values
k = x @ key                                     # (B,1,T,C) x (1,nH,C,sH) -> (B,nH,T,sH)
q = x @ query                                   # (B,1,T,C) x (1,nH,C,sH) -> (B,nH,T,sH)

# combine keys and queries to weights, i.e. the strength of connection between two tokens in time dimension.
# scores are scaled by 1/sqrt(head size), like the single-head version further down,
# so that their spread does not grow with the head size and saturate the softmax
wei = q @ k.transpose(-1,-2) * sH**-0.5         # (B,nH,T,sH) x (B,nH,sH,T) -> (B,nH,T,T)
wei = wei.masked_fill(tril == 0, float('-inf')) # ignore anticausal time
wei = F.softmax(wei, dim=-1)                    # rescale to [0,1], every row sums to 1

# create values - the actual information to give - for each token and
# aggregate them with the attention weights
v = x @ value                                   # (B,1,T,C) x (1,nH,C,sH) -> (B,nH,T,sH)
out = wei @ v                                   # (B,nH,T,T) x (B,nH,T,sH) -> (B,nH,T,sH)
out = out.permute(0,2,1,3).reshape((B,T,-1))    # (B,nH,T,sH) -> (B,T,nH,sH) -> (B,T,nH*sH) = (B,T,C)
x1.shape, k.shape, x.shape, v.shape, wei.shape, out.shape

### RECAP
# input has batches (the number of parallel computations), time (the number of adjacent tokens that are taken into consideration), and channels (num of embedding dim)
# attention creates connections with varying strengths between tokens in the time domain. strength is determined by the (scaled) dot product of keys and queries.
# value is the actual information that is transmitted over the built connections.
# because we are interested in predicting the next character, we do not allow connections from future tokens to past tokens (tril matrix)
# all matrix multiplications could also be implemented as linear layers of a neural network
# in multi-head attention, we split the channel dimension into multiple separate parts. each part will be processed by a different head (potentially different keys, queries, values, etc.)

(torch.Size([4, 8, 10]),
 torch.Size([4, 5, 8, 2]),
 torch.Size([4, 1, 8, 10]),
 torch.Size([4, 5, 8, 2]),
 torch.Size([4, 5, 8, 8]),
 torch.Size([4, 8, 10]))

In [6]:
# function that prepares data for training.
# x is of shape (batchsize, blocksize)
# y is of shape (batchsize, blocksize)
# for every entry x[i,j],
# y[i,j] is the target (the next token) given the context x[i,:j+1].
# this structure makes sense given the attention blocks that are 
# introduced later!

blocksize = 8   # context length: number of consecutive tokens per example
batchsize = 4   # number of examples drawn per batch
def get_batch(split):
    """Draw a random batch of input/target windows from one data split.

    Args:
        split: 'train' selects `train_data`; any other value selects `val_data`.

    Returns:
        Tuple (x, y) of tensors with shape (batchsize, blocksize). y is x
        shifted one position to the right within the source data, so y[i, j]
        is the token that follows x[i, j].
    """
    data = train_data if split == 'train' else val_data
    # sample from the selected split (previously train_data was always used,
    # so 'val' batches silently came from the training set).
    # the upper bound keeps the shifted target window i+1 : i+blocksize+1 in range
    ix = torch.randint(len(data)-blocksize, (batchsize,))
    x = torch.stack([data[i:i+blocksize] for i in ix])
    y = torch.stack([data[i+1:i+blocksize+1] for i in ix])
    return x,y

# draw one example batch from the training split; Xb and Yb are reused further down
Xb,Yb = get_batch('train')
# print the inputs and, on a new line, the corresponding targets
print(Xb,Yb,sep='\n')
# last expression of the cell: display both shapes
Xb.shape,Yb.shape

tensor([[53,  1, 39,  1, 51, 39, 56, 40],
        [45, 53, 53, 42,  1, 50, 53, 56],
        [ 1, 40, 56, 47, 41, 49,  6,  0],
        [16, 59, 49, 43,  1, 53, 44,  1]])
tensor([[ 1, 39,  1, 51, 39, 56, 40, 50],
        [53, 53, 42,  1, 50, 53, 56, 42],
        [40, 56, 47, 41, 49,  6,  0, 35],
        [59, 49, 43,  1, 53, 44,  1, 26]])


(torch.Size([4, 8]), torch.Size([4, 8]))

In [7]:
# the simplest prediction model.
# takes in a single character and predicts the next.
# the embedding provides a prob. distr. (as logits) for every character in alphabet.
# watch out - torch usually requires (B,C,d1,d2,...), 
# we mostly use (B,d1,d2,...,C), so occasionally some modifications are necessary. 
# forward() specifies how the network works. 
# generate() applies the network to generate text based on its learning.
# notice that in this network the batch AND the time dimension act as
# if they were batch dimensions. there is no special meaning to the
# time dimension as of now. 
# that is the reason why we can flatten these dimensions into one in forward().


class BigramModel(nn.Module):
    """Bigram language model: the next-token logits depend only on the current token.

    The model is a single embedding table of shape (vocab_size, vocab_size);
    row i holds the logits over the next token given current token i.

    Args:
        vocab_size: number of distinct tokens. When omitted, falls back to
            ``len(alphabet)`` of the module-level ``alphabet``, as before.
    """

    def __init__(self, vocab_size=None):
        super().__init__()
        if vocab_size is None:
            vocab_size = len(alphabet)
        self.emb = nn.Embedding(vocab_size, vocab_size)

    def forward(self, x, y=None):
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            x: integer token ids of shape (B, T).
            y: optional target token ids of shape (B, T).

        Returns:
            (logits, loss). Without targets, logits has shape (B, T, C) and
            loss is None. With targets, logits is returned flattened to
            (B*T, C) and loss is the cross-entropy over all B*T positions.
        """
        logits = self.emb(x)

        if y is None:
            loss = None
        else:
            # cross_entropy expects the class dimension second, so batch and
            # time are merged into one leading dimension
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            y = y.view(B*T)
            loss = F.cross_entropy(logits, y)

        return logits, loss

    @torch.no_grad()  # sampling needs no gradients; avoids building an autograd graph
    def generate(self, x, max_tokens):
        """Extend each sequence in `x` by `max_tokens` sampled tokens.

        Args:
            x: integer token ids of shape (B, T) used as the starting context.
            max_tokens: number of tokens to append.

        Returns:
            Tensor of shape (B, T + max_tokens) containing `x` followed by
            the sampled tokens.
        """
        for _ in range(max_tokens):
            logits, _loss = self(x)
            logits = logits[:,-1,:]                         # keep only the last time step
            probs = F.softmax(logits, dim=-1)
            pred = torch.multinomial(probs, num_samples=1)  # (B, 1) sampled token ids
            x = torch.cat((x, pred), dim=1)
        return x


In [71]:
# create a fresh model and run one forward pass on the example batch;
# the returned logits and loss are not used further in this cell
m = BigramModel()
log, los = m(Xb, Yb)

# sample 50 tokens, starting from a single token with id 0, and decode them to text
print(decode(m.generate(x=torch.zeros((1,1), dtype=torch.long), max_tokens=50)[0].tolist()))


i'iWXutSQyMy&;aYVs&GgAFk jW-dpHYr
;kQm!PE'zB,OIcHy


In [22]:
# AdamW over the parameters of the current model `m`.
# NOTE(review): the optimizer keeps references to exactly these parameters, so
# re-run this cell whenever `m` is re-created, otherwise the old model is trained.
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)


In [57]:
# 1000 optimisation steps, each on a freshly sampled training batch
for _ in range(1000):
    xb,yb = get_batch('train')
    logits, loss = m(xb,yb)
    # clear the gradients of the previous step before computing new ones
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# loss of the last batch only (not an average over the run)
print(loss.item())

2.729191541671753


In [66]:
# sample 50 tokens again, starting from a single token with id 0, and decode them to text
print(decode(m.generate(x=torch.zeros((1,1), dtype=torch.long), max_tokens=50)[0].tolist()))


Aqber'scou have usure w?
Wininorry'the
MI t pry at


In [301]:
# version 1: causal averaging mask
# row i has equal weights on positions 0..i and zeros after them, so each row sums to 1
tril = torch.ones(blocksize, blocksize).tril()
atmask = tril / tril.sum(dim=1, keepdim=True)


In [304]:
# version 3: the same causal averaging, built with masked_fill + softmax
# xtest is created here so the cell runs top to bottom on a fresh kernel;
# it was previously only defined in a later cell.
xtest = torch.randn((batchsize, blocksize, 2))      # dummy input (B,T,C) with C = 2
tril = torch.tril(torch.ones((blocksize,blocksize)))
wei = torch.zeros((blocksize,blocksize))            # equal scores for all positions
wei = wei.masked_fill(tril == 0, float('-inf'))     # future positions get weight 0 after softmax
wei = F.softmax(wei, dim=1)                         # every row sums to 1
(wei @ xtest).shape                                 # (T,T) @ (B,T,C) -> (B,T,C)

torch.Size([4, 8, 2])

In [320]:
# version 4: self-attention (single head)
B,T,C = 4,8,2
head_size = 6
xtest = torch.randn((B,T,C))    # (B,T,C)
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(xtest)      # (B,T,H)
q = query(xtest)    # (B,T,H)
# wei[b,i,j] is the score of query i against key j, the same convention as in
# the multi-head example above (previously computed as k @ q^T, its transpose).
# scaling by 1/sqrt(head_size) keeps the spread of the scores independent of the head size
wei = q @ k.transpose(-2,-1) * head_size**-0.5   # (B,T,H) @ (B,H,T) --> (B,T,T)

tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))  # token i may only attend to tokens j <= i
wei = F.softmax(wei, dim=-1)                     # every row sums to 1

v = value(xtest)    # (B,T,H)
out = wei @ v       # (B,T,T) @ (B,T,H) --> (B,T,H)

wei.shape, v.shape, out.shape

(torch.Size([4, 8, 8]), torch.Size([4, 8, 6]), torch.Size([4, 8, 6]))