# Introduction (WIP)
- This notebook is based on the "Let's build GPT: from scratch, in code, spelled out" tutorial by Andrej Karpathy. You can find the tutorial here --> https://www.youtube.com/watch?v=kCc8FmEb1nY&list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ&index=7
- There are several different approaches in this notebook that do not strictly follow the original video. Some implementations are my own.
- I am using the same shakespeare text as in the video.

# Import needed libraries

In [1]:
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, random_split

# Device agnostic code

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)
generator = torch.Generator(device=device)
print(f"default device set to {device}")

default device set to cpu


# Prepare the data

In [3]:
with open("/kaggle/input/shakespeare/input.txt", "r", encoding="utf-8") as f:
    text = f.read()

vocab = sorted(set(text))
vocab_size = len(vocab)

print(vocab)
print(vocab_size)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
65


### Tokenizer

In [4]:
stoi = {c: v for v, c in enumerate(vocab)}
itos = {v: c for c, v in stoi.items()}
print(stoi["h"])
print(itos[46])

46
h


In [5]:
encode = lambda d: [stoi[idx] for idx in d]
decode = lambda e: "".join([itos[idx] for idx in e])

encoded = encode("hello, how are you?!")
decoded = decode(encoded)
print(encoded)
print(decoded)

[46, 43, 50, 50, 53, 6, 1, 46, 53, 61, 1, 39, 56, 43, 1, 63, 53, 59, 12, 2]
hello, how are you?!


# Prepare the dataset

In [6]:
context_size = 8
n_embd = 5
vocab_size = len(vocab)

In [7]:
def make_dataset(text, context_size):
    data = torch.tensor(encode(text), dtype=torch.long)

    #random_idx = torch.randint(0, len(data)-context_size, (int(len(data)/context_size),))
    random_idx = torch.randperm(len(data)-context_size)
    inputs = torch.stack([data[idx:idx+context_size] for idx in random_idx])
    labels = torch.stack([data[idx+1:idx+context_size+1] for idx in random_idx])

    return TensorDataset(inputs, labels)


In [8]:
# sicne randint might give the same random_idx, randperm is going to be preffered
print(torch.randint(0, 10, (10,)))
print(torch.randperm(10))

tensor([3, 7, 5, 7, 9, 3, 1, 3, 4, 1])
tensor([5, 2, 4, 1, 7, 3, 6, 8, 9, 0])


In [9]:
dataset = make_dataset(text=text[:100000], context_size=8)

In [10]:
train_split = int(len(dataset)*0.8)
test_split = int(len(dataset)-train_split)

train_dataset, test_dataset = random_split(dataset=dataset, lengths=[train_split, test_split], generator=generator)

In [11]:
batch_size = 32
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, generator=generator)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, generator=generator)

In [12]:
def sample_from_data(dataloader):
    for batch, (X, y) in enumerate(dataloader):
        #print(f"batch {batch}, input {X}, label {y}")
        #print(batch)
        pass

In [13]:
sample_from_data(dataloader=train_dataloader)

# Base model (MLP)

In [14]:
class MLP(nn.Module):
    def __init__(self, context_size, n_embd, vocab_size):
        super().__init__()


        self.context_size = context_size
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd) # B x T x C; B --> batches, T --> time (context_size), C --> n_embd
        self.pos_embedding_table = nn.Embedding(context_size, n_embd) # T x C; this is from the posisitional encoding part of the video
        
        self.linear1 = nn.Linear(in_features=context_size*n_embd, out_features=8*8) # B x T*C @ T*C x H; H --> number of hidden_units
        self.linear2 = nn.Linear(in_features=8*8, out_features=8*8*8)
        self.linear3 = nn.Linear(in_features=8*8*8, out_features=vocab_size)
        self.act_fn = nn.Tanh()

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        B, T = idx.shape
        C = self.n_embd
        positions = torch.arange(start=0, end=T, step=1)
        x = self.token_embedding_table(idx) + self.pos_embedding_table(positions)
        x = x.view(B, T*C)

        x = self.act_fn(self.linear1(x))
        x = self.act_fn(self.linear2(x))
        x = self.linear3(x)

        return x

    def generate(self, idx: torch.Tensor, randomize: bool, max_length: int, num_samples: int) -> torch.Tensor:
        outputs = []
        for sample in range(num_samples):
            full_text = "" 
            for i in range(max_length):
                logits = self(idx)
                percents = torch.softmax(logits, dim=1)

                if randomize:
                    pred = torch.multinomial(percents, num_samples=1)
                    full_text += decode(pred.tolist()[0])
                    idx = torch.cat([idx[:, 1:], pred], dim=1) # update the context, remove the first element of the tensor and add the new prediction made by the model
                else:
                    pred = torch.argmax(percents)
                    full_text += decode([pred.item()])
                    idx = torch.cat([idx[:, 1:], pred.view(1, 1)], dim=1) # update the context, remove the first element of the tensor and add the new prediction made by the model
                    # in the argmax the output is a single element, pred.view(1, 1) turns it into a batch of dim 1, so it can be concatenated to the previous context

            outputs.append(full_text)

        return outputs

            


# Define the base model, optimizer and loss function

In [15]:
mlp = MLP(context_size=context_size, n_embd=n_embd, vocab_size=vocab_size)
optimizer = torch.optim.Adam(params=mlp.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Take samples from the base model

In [16]:
@torch.no_grad
def model_sampler(model, context, randomize, max_length, num_samples):
    #print(len(context))
    #print(context_size)
    test = torch.tensor([[20, 53, 61,  1, 39, 56, 58, 39]])
    #print(test[:, 1:])
    result = torch.cat((test[:, 1:], torch.tensor([[99]])), dim=1)
    #print(result)
    #print("\n")

    mlp.eval()
    idx = torch.tensor(encode(context), dtype=torch.long).view(1, len(encode(context))) # inputs must be batched
    outputs = mlp.generate(idx=idx, randomize=randomize, max_length=max_length, num_samples=num_samples)
    for output in outputs:
        print(f"{output} \n\n")

model_sampler(model=mlp, context="How are ", randomize=True, max_length=10, num_samples=5)

XOV,L
ERH' 


ubyl$Fwizf 


 z:z:ITQ:N 


r'qxW:I$'I 


LPVmD$yx,K 




# Training loop

In [17]:
def train_model(model, dataloader, loss_fn, optimizer, epochs):
    model.train()
    for epoch in range(epochs):
        for batch, (X, y) in tqdm(enumerate(dataloader)):
            logits = model(X)
            loss = loss_fn(logits, y[:, -1])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 200 == 0:
                print(f"loss for batch {batch} --> {loss} at epoch {epoch}")

    print(f"loss for the very last batch --> {loss}")

In [18]:
mlp.train()
train_model(model=mlp, dataloader=train_dataloader, loss_fn=loss_fn, optimizer=optimizer, epochs=1)

0it [00:00, ?it/s]

loss for batch 0 --> 4.203449249267578 at epoch 0


254it [00:00, 336.15it/s]

loss for batch 200 --> 3.1416616439819336 at epoch 0


459it [00:01, 336.09it/s]

loss for batch 400 --> 2.828216075897217 at epoch 0


669it [00:02, 345.57it/s]

loss for batch 600 --> 2.7205810546875 at epoch 0


844it [00:02, 328.87it/s]

loss for batch 800 --> 2.6857826709747314 at epoch 0


1055it [00:03, 342.28it/s]

loss for batch 1000 --> 2.4109437465667725 at epoch 0


1257it [00:03, 304.58it/s]

loss for batch 1200 --> 2.7299134731292725 at epoch 0


1462it [00:04, 335.34it/s]

loss for batch 1400 --> 2.5381977558135986 at epoch 0


1668it [00:05, 325.84it/s]

loss for batch 1600 --> 2.264328956604004 at epoch 0


1845it [00:05, 345.83it/s]

loss for batch 1800 --> 2.2220375537872314 at epoch 0


2057it [00:06, 343.03it/s]

loss for batch 2000 --> 2.772373676300049 at epoch 0


2268it [00:06, 344.83it/s]

loss for batch 2200 --> 2.3911831378936768 at epoch 0


2443it [00:07, 329.09it/s]

loss for batch 2400 --> 2.3883790969848633 at epoch 0


2500it [00:07, 331.95it/s]

loss for the very last batch --> 2.1770222187042236





# Base model inference

In [19]:
@torch.no_grad
def model_inference(model, dataloader):
    mlp.eval()
    X, y = next(iter(dataloader))
    logits = model(X)
    percents = torch.softmax(logits, dim=1) # dim=1 since the input was batched
    preds = torch.argmax(percents, dim=1) # dim=1 since the input was batched
    print(f"for {X} \n model predicted {preds}")
    print(f"expected --> {y[:, -1]}")
    print(y)

In [20]:
model_inference(model=mlp, dataloader=train_dataloader)

for tensor([[ 1, 58, 46, 43, 47, 56,  1, 54],
        [ 1, 39, 50, 50,  1, 58, 53,  1],
        [63, 53, 59,  1, 46, 43, 52, 41],
        [56, 52, 47, 52, 45,  1, 49, 47],
        [49,  1, 51, 43,  8,  1, 21,  6],
        [ 0, 14, 30, 33, 32, 33, 31, 10],
        [39, 57,  1, 53, 44, 58, 43, 52],
        [46, 53, 56, 53, 59, 45, 46, 50],
        [59, 50, 42,  1, 58, 46, 43, 63],
        [54, 50, 59, 41, 49,  1, 53, 59],
        [56, 43, 40, 59, 49, 43,  1, 63],
        [58,  1, 44, 53, 56, 58, 46,  0],
        [ 0, 25, 13, 30, 15, 21, 33, 31],
        [ 0, 18, 53, 56,  1, 61, 46, 39],
        [59, 57,  6,  1, 58, 53,  1, 39],
        [46, 63,  1, 44, 56, 47, 43, 52],
        [44,  1, 39, 45, 39, 47, 52,  6],
        [43, 51,  0, 32, 47, 51, 43,  7],
        [47, 59, 57,  1, 15, 53, 56, 47],
        [39, 47, 52,  1, 59, 57,  1, 51],
        [47, 56, 52, 43, 57, 57,  1, 53],
        [55, 59, 43, 57, 58,  1, 63, 53],
        [50, 50, 63,  1, 57, 51, 47, 50],
        [43,  1, 46, 53, 51, 4

In [21]:
model_sampler(model=mlp, context="How are ", randomize=True, max_length=500, num_samples=1)

trd, int on thald co s, cut you m'd ink ioL

VOalys ant matr th: youck it the mrle.

ThabIL

hithe t wheil theit meinld, pthas te of thplols: vill sil
 ingd t se hi to enr ursns, pilla,s thel oulbog

Sld ou himCse Micou, ruoos thout pin te re'cut inq to cer ounirro: th m;
Aaditn, brav.

CORINIA:
Lhaintt frowt iseger, uf r, yucimec:

Fid tarv, boty then ous dese t oll yo vosila meat, bale ie'tarnevened,ere: sheod me as anlt tar heroms ined, rhak. I ce siverl:
Nofa Ted. ehelpt aar of tr itus. toou 




# Self attention math

In [22]:
sample_batch = next(iter(test_dataloader))[0]
B, T = sample_batch.shape # batch of B by T
print(B, T)
example_emb = nn.Embedding(vocab_size, 4)
embedded = example_emb(sample_batch)
B, T, C = embedded.shape # embedded is Batches by Time (context_sie) by Channels (num of values per token)
print(embedded.shape)

32 8
torch.Size([32, 8, 4])


In [23]:
bag_of_words = torch.zeros(size=(B, T, C)) # each of the values has a unique value

for batch_idx in range(B):
    for context_idx in range(T):
        xprev = embedded[batch_idx, :context_idx+1]
        bag_of_words[batch_idx, context_idx] = torch.mean(xprev, dim=0)
print(bag_of_words)

tensor([[[-1.7578e-01, -2.7511e-01,  3.8973e-01, -3.1879e-02],
         [-3.7397e-01,  4.2724e-01,  1.3090e-01, -1.0879e-01],
         [-2.7785e-01,  5.9413e-01, -3.4005e-01,  5.0945e-01],
         ...,
         [-3.5089e-01,  4.3939e-01, -5.0236e-01,  2.2753e-01],
         [-4.8447e-01,  2.7865e-01, -4.5576e-01,  2.7088e-01],
         [-4.4588e-01,  2.0943e-01, -3.5007e-01,  2.3303e-01]],

        [[-5.7216e-01,  1.1296e+00, -1.2793e-01, -1.8570e-01],
         [-2.5776e-01,  6.2494e-01, -5.0406e-01, -1.9584e-01],
         [-4.6103e-01,  6.3299e-01, -2.6764e-01, -1.3422e-01],
         ...,
         [-3.3368e-01,  3.5445e-01, -6.2654e-02,  9.3825e-03],
         [-3.6775e-01,  4.6519e-01, -7.1979e-02, -1.8486e-02],
         [-4.0152e-01,  3.7491e-01, -3.1470e-01, -1.2347e-01]],

        [[ 2.4236e+00,  1.7779e+00, -6.6982e-01,  1.9079e+00],
         [ 7.4922e-01,  3.9269e-01, -9.7558e-02,  1.6171e+00],
         [ 3.0909e-01,  1.0370e-01, -2.3918e-01,  1.2321e+00],
         ...,
         

In [24]:
ones = torch.ones(size=(3, 3))
tril = torch.tril(ones) # lower triangular part of a matrix
print(tril)
a = torch.randint(0, 10, (3, 2), dtype=torch.float32)
b = torch.randint(0, 10, (2, 3), dtype=torch.float32)
matmul_output = a @ b
matmul_tril_output = torch.tril(a) @ b
print(matmul_output)
print(matmul_tril_output)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[ 0.,  0.,  0.],
        [16., 12., 30.],
        [64., 34., 99.]])
tensor([[ 0.,  0.,  0.],
        [16., 12., 30.],
        [64., 34., 99.]])


In [25]:
# do the same as bag of words but with matrix multiplication (dot product)
a = torch.ones(size=(3, 3), dtype=torch.float32)
b = torch.randint(0, 10, (3, 2), dtype=torch.float32)

a = torch.tril(a)
"""
b = torch.tensor(
    [
        [2, 7],
        [6, 4],
        [6, 5]
    ], dtype=torch.float32
)
"""
print(a)
a = a/a.sum(dim=1, keepdim=True)
print(a)

output = a @ b
print(output)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[1.0000, 4.0000],
        [4.0000, 3.0000],
        [2.6667, 4.6667]])


In [26]:
sample_batch = next(iter(test_dataloader))[0]
B, T = sample_batch.shape # batch of B by T
example_emb = nn.Embedding(vocab_size, 4)
embedded = example_emb(sample_batch)
B, T, C = embedded.shape # embedded is Batches by Time (context_sie) by Channels (num of values per token)
#print(embedded.shape)

wei = torch.tril(torch.ones(size=(T, T)))
wei = wei / wei.sum(dim=1, keepdim=True)
print(embedded.shape) # B x T x C
print(wei.shape) # T x T
#  1xTxT @ BxTxC
bag_of_words = wei @ embedded
print(bag_of_words)

torch.Size([32, 8, 4])
torch.Size([8, 8])
tensor([[[ 1.9280,  0.6144,  0.5746,  0.1969],
         [ 1.2407,  0.8683,  0.2741, -0.2585],
         [ 1.0031,  0.7238, -0.0335,  0.1450],
         ...,
         [ 0.5927,  0.6425,  0.1193, -0.2535],
         [ 0.6924,  0.7620,  0.2443, -0.4628],
         [ 0.8469,  0.7436,  0.2855, -0.3804]],

        [[ 0.5534,  1.1223, -0.0263, -0.7139],
         [ 0.6984, -0.4539,  0.1814,  0.0345],
         [ 0.6050, -0.2214, -0.3025,  0.6756],
         ...,
         [ 0.4217,  0.3487,  0.1173, -0.1781],
         [ 0.4406,  0.4592,  0.0967, -0.2546],
         [ 0.3808,  0.5652,  0.2071, -0.4068]],

        [[-0.7959,  2.3944, -0.3589,  0.4237],
         [-0.9841,  0.6067,  0.9223,  1.2838],
         [-0.3471,  0.1231,  0.4180,  0.9470],
         ...,
         [-0.2035,  0.3260,  0.3895,  0.4830],
         [-0.0953,  0.4397,  0.3301,  0.3120],
         [ 0.0384,  0.4003,  0.4160,  0.2190]],

        ...,

        [[ 1.5730,  1.6455,  1.7858, -0.4104],
   

# Bag of words type aggregation with a mask

In [27]:
tril = torch.tril(torch.ones(size=(T, T)))
wei = torch.zeros(size=(T, T)) # zeros just so there's a plaaceholder for masked_fill
wei = wei.masked_fill(tril==0, float("-inf")) # whenever the value in tril is 0, it will get replaced with -inf; this allows softmax to come into place, since -inf will get a percent of 0
wei = torch.softmax(wei, dim=1)
print(wei)
bag_of_words = wei @ embedded
print(bag_of_words)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
tensor([[[ 1.9280,  0.6144,  0.5746,  0.1969],
         [ 1.2407,  0.8683,  0.2741, -0.2585],
         [ 1.0031,  0.7238, -0.0335,  0.1450],
         ...,
         [ 0.5927,  0.6425,  0.1193, -0.2535],
         [ 0.6924,  0.7620,  0.2443, -0.4628],
         [ 0.8469,  0.7436,  0.2855, -0.3804]],

        [[ 0.5534,  1.1223, -0.0263, -0.7139],
         [ 0.6984, -0.4539,  0.1814,  0.0345],
         [ 0.60

# MLP model with agreggation
- a problem that needs to be addressed with the previous model is that it needs to always receive a input of B x T (batch_size by context_size), whereas it would be best if the model could adapt to inputs of different context_size

In [28]:
# Code to allow comunication between past tokens
C = n_embd
T = context_size
B = batch_size
wei = torch.zeros(size=(T, T))
tril = torch.tril(torch.ones(size=(T, T)))
wei = wei.masked_fill(tril==0, float('-inf'))
xbow = wei.softmax(dim=1)

test_tensor = torch.randn(size=(B, T, C))
print(xbow @ test_tensor)

tensor([[[ 6.0203e-01,  7.7065e-01,  5.7534e-01, -4.4121e-01,  9.4248e-01],
         [ 5.0862e-02,  2.8102e-01, -5.0223e-01,  9.0502e-01,  3.9423e-01],
         [ 5.9646e-01, -3.6989e-01, -4.1166e-01,  6.3679e-01,  1.3495e-02],
         ...,
         [ 3.3532e-01, -4.9190e-01, -4.1953e-01,  2.7829e-01, -3.8374e-02],
         [ 2.0919e-01, -2.6207e-01, -4.4235e-02,  2.3070e-01,  3.0028e-02],
         [ 2.9527e-01, -1.9783e-01, -1.1896e-01,  2.9992e-01,  9.6746e-02]],

        [[-6.8395e-01,  4.1657e-02,  7.7012e-01,  5.0897e-01,  1.1499e-02],
         [-7.3110e-01,  5.9835e-01,  5.5834e-01, -1.5275e-01, -2.6736e-01],
         [-9.1908e-01,  6.3536e-01,  4.8470e-01, -3.6135e-01,  4.0921e-02],
         ...,
         [-3.0068e-01,  8.3545e-01,  3.7348e-01, -7.0396e-01, -1.3963e-01],
         [-2.7897e-01,  7.4900e-01,  4.8425e-01, -6.7691e-01, -7.3942e-02],
         [-1.1215e-01,  5.6384e-01,  4.4792e-01, -7.1354e-01, -6.7875e-02]],

        [[ 1.9688e+00, -1.3080e-01, -1.7727e-01,  4.9248

In [29]:
# a problem that needs to be addressed with the previous model is that it needs to always receive a 
batch_sample_inputs, batch_sample_labels = next(iter(train_dataloader))
print(batch_sample_inputs.shape)
print(batch_sample_labels.shape)

torch.Size([32, 8])
torch.Size([32, 8])


In [30]:
class MLPv2(nn.Module):
    def __init__(self, vocab_size, n_embd, context_size):
        super().__init__()

        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.context_size = context_size

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.positional_embedding_table = nn.Embedding(context_size, n_embd) # the position of each token has a separate table of values; this helps the model keep track of the order of characters

        #self.lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size) # this time in_features=n_embd so it's not context_size dependant
        self.linear1 = nn.Linear(in_features=n_embd, out_features=8*8)
        self.linear2 = nn.Linear(in_features=8*8, out_features=8*8*8)
        self.linear3 = nn.Linear(in_features=8*8*8, out_features=vocab_size)

        self.act_fn = nn.Tanh()

    def info(self):
        info_dict = {
            "vocab_size": self.vocab_size,
            "n_embd": self.n_embd,
            "context_size": self.context_size
        }

        return info_dict

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T = x.shape
        C = self.n_embd

        #print(B, T, C)

        positions = torch.arange(start=0, end=T, step=1)
        token_emb = self.token_embedding_table(x) # batch_size x context_size x n_embd --> BxTxC
        pos_emb = self.positional_embedding_table(positions) # T x C (each context has a position)

        x = token_emb + pos_emb
        x = x.view(B*T, C)

        #logits = self.lm_head(token_emb.view(B*T, C)) # output of shape B*T x hidden_units1; this might be a problem for the labels since they are of shape B x T, thus they need to be reshaped aswell
        x = self.act_fn(self.linear1(token_emb.view(B*T, C))) # hidden_units1 x hidden_units2
        x = self.act_fn(self.linear2(x)) # hidden_units2 x hidden_units3
        x = self.linear3(x) # hidden_units3 x vocab_size
        return x

    def generate(self, starting_idx: torch.Tensor, max_length) -> torch.Tensor:

        full_text = itos[starting_idx.item()]
        for i in range(max_length):
            logits = self(starting_idx)
            percents = torch.softmax(logits, dim=1)
            pred = torch.multinomial(percents, num_samples=1)
            starting_idx = pred
            full_text += decode([pred.item()])
        return full_text


In [31]:
mlpv2 = MLPv2(vocab_size=vocab_size, n_embd=32, context_size=context_size)

In [32]:
mlpv2.info()

{'vocab_size': 65, 'n_embd': 32, 'context_size': 8}

In [33]:
mlpv2_loss_fn = nn.CrossEntropyLoss()
batch_sample_inputs, batch_sample_labels = next(iter(train_dataloader))
sample_input = batch_sample_inputs[0]
sample_label = batch_sample_labels[0]
print(sample_input) # 1 x T (B x T)
print(sample_label.view(1, -1)) # 1 x T (B x T)
mlpv2.eval()
with torch.inference_mode():
    logits = mlpv2(sample_input.view(1, -1))
    labels = sample_label.view(-1) # from B x T to B*T to match the shape of the logits
    print(logits.shape) # B*T x vocab_size
    print(labels.shape)

    loss = mlpv2_loss_fn(logits, labels)
    print(loss)

tensor([ 1, 58, 46, 43, 47, 56,  1, 54])
tensor([[58, 46, 43, 47, 56,  1, 54, 53]])
torch.Size([8, 65])
torch.Size([8])
tensor(4.1342)


In [34]:
@torch.no_grad
def generate_from_model(model, num_outputs, starting_char, max_length):
    mlpv2.eval()
    outputs = []
    starting_idx = torch.tensor([stoi[starting_char]], dtype=torch.long).view(1, -1)
    for i in range(num_outputs):
        output = mlpv2.generate(starting_idx=starting_idx, max_length=max_length) # must be batched
        outputs.append(output)

    return outputs

In [35]:
test_output = generate_from_model(model=mlpv2, num_outputs=1, starting_char="a", max_length=5)
print(test_output[0])

aO-,yF


In [36]:
def train_model(model, dataloader, loss_fn, optimizer, epochs):
    model.train()

    for epoch in range(epochs):
        for batch, (X, y) in enumerate(dataloader):
            logits = model(X)
            loss = loss_fn(logits, y.view(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


            if batch % 1200 == 0:
                print(f"loss for batch {batch} --> {loss} at epoch {epoch}")

    print(f"loss for the very last batch --> {loss}")

In [37]:
mlpv2_optimizer = torch.optim.Adam(params=mlpv2.parameters(), lr=1e-3)
mlpv2_loss_fn = nn.CrossEntropyLoss()

In [38]:
train_model(model=mlpv2, dataloader=train_dataloader, loss_fn=mlpv2_loss_fn, optimizer=mlpv2_optimizer, epochs=4)

loss for batch 0 --> 4.193885326385498 at epoch 0
loss for batch 1200 --> 2.415966033935547 at epoch 0
loss for batch 2400 --> 2.2945008277893066 at epoch 0
loss for batch 0 --> 2.498746633529663 at epoch 1
loss for batch 1200 --> 2.4111926555633545 at epoch 1
loss for batch 2400 --> 2.2905168533325195 at epoch 1
loss for batch 0 --> 2.4965763092041016 at epoch 2
loss for batch 1200 --> 2.410950183868408 at epoch 2
loss for batch 2400 --> 2.2882120609283447 at epoch 2
loss for batch 0 --> 2.495180130004883 at epoch 3
loss for batch 1200 --> 2.4093832969665527 at epoch 3
loss for batch 2400 --> 2.2853567600250244 at epoch 3
loss for the very last batch --> 2.2023563385009766


In [39]:
test_outputs = generate_from_model(model=mlpv2, max_length=100, num_outputs=2, starting_char="b")
for output in test_outputs:
    print(f"{output}\n\n")

bes o yo pesthe cave grth o veat yosean ce ad towome nant und cteoth inavenkexthone ous sshesthes tlo


bret m se d qul ter
CO ulu OLAse whe mon y.
IUS:
COLA:


TIUS:
Wheranin d cbe th iak' be.
Tom orise a




In [40]:
torch.arange(start=0, end=T, step=1) # from 0 to T-1

tensor([0, 1, 2, 3, 4, 5, 6, 7])

# Self attention
- with xbow you can add information about the tokens, but the model itself does not attribute any weight to them. This is what self attention solves by using Keys, Queries and Values
- every token will have a specific Query (Q) and Key (K) attatched to it
    - Query --> what the model is looking for
    - Key --> the weight the model is giving to this certain token
    - Value --> Q @ K to get the affinities between themselves

In [41]:
# Code to allow comunication between past tokens

B, T, C = 2, 4, 8
wei = torch.zeros(size=(T, T))
tril = torch.tril(torch.ones(size=(T, T)))
wei = wei.masked_fill(tril==0, float('-inf'))
xbow = wei.softmax(dim=1)

test_tensor = torch.randn(size=(B, T, C))
output = xbow @ test_tensor
print(output.shape)

torch.Size([2, 4, 8])


In [42]:
head_size = 32
key = nn.Linear(in_features=C, out_features=head_size, bias=False) # bias = False so that it's just a multiplication
query = nn.Linear(in_features=C, out_features=head_size, bias=False) 

#print(test_tensor.shape)
k = key(test_tensor) # BxTxC @ BxTxhead_size --> BxTxhead_size; each batch has a context and each context character has a key value
q = query(test_tensor) # BxTxC @ BxTxhead_size --> BxTxhead_size; each batch has a context and each context character has a query value
print(k.shape)
print(q.shape)

torch.Size([2, 4, 32])
torch.Size([2, 4, 32])


In [43]:
print(k.transpose(-2, -1).shape) # same as k.permute(0, 2, 1)
print(k.permute(0, 2, 1).shape)

wei = q @ k.transpose(-2, -1) # BxTxhead_size @ Bxhead_sizexT --> BxTxT
print(wei.shape) #  B x T x T

tril = torch.tril(torch.ones(size=(T, T)))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = torch.softmax(wei, dim=-1) # dim=-1 in this case, since wei is of shape

output = wei @ test_tensor

print(output.shape) # B x T x C

torch.Size([2, 32, 4])
torch.Size([2, 32, 4])
torch.Size([2, 4, 4])
torch.Size([2, 4, 8])


# Update the model with a self attention head

In [44]:
class Head(nn.Module):
    def __init__(self, head_size, n_embd):
        super(Head, self).__init__()

        self.Q = nn.Linear(in_features=n_embd, out_features=head_size)
        self.K = nn.Linear(in_features=n_embd, out_features=head_size)
        self.V = nn.Linear(in_features=n_embd, out_features=head_size)


    def forward(self, x: torch.Tensor, mask) -> torch.Tensor:
        B, T, C = x.shape
        #mask = torch.tril(torch.ones(size=(T, T)))

        q = self.Q(x) # B x T x head_size
        k = self.K(x) # B x T x head_size
        v = self.V(x) # B x T x head_size

        k = k.transpose(-2, -1) # B x head_size x T

        wei = q @ k # BxTxhead_size @ Bxhead_sizexT --> BxTxT

        wei = wei.masked_fill(mask==0, float('-inf'))
        wei = torch.softmax(wei, dim=-1)

        #x = wei @ v # BxTxT @ BxTxhead_size --> BxTxhead_size
        x = wei @ x

        return x



In [45]:
class MLPv3(nn.Module):
    def __init__(self, vocab_size, n_embd, context_size):
        super(MLPv3, self).__init__()

        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.context_size = context_size

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.positional_embedding_table = nn.Embedding(context_size, n_embd) # the position of each token has a separate table of values; this helps the model keep track of the order of characters

        self.head = Head(head_size=32, n_embd=n_embd)

        #self.lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size) # this time in_features=n_embd so it's not context_size dependant
        self.linear1 = nn.Linear(in_features=n_embd, out_features=8*8)
        self.linear2 = nn.Linear(in_features=8*8, out_features=8*8*8)
        self.linear3 = nn.Linear(in_features=8*8*8, out_features=vocab_size)

        self.act_fn = nn.Tanh()

        self.mask = torch.tril(torch.ones(size=(context_size, context_size)))

    def info(self):
        info_dict = {
            "vocab_size": self.vocab_size,
            "n_embd": self.n_embd,
            "context_size": self.context_size
        }

        return info_dict

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T = x.shape
        C = self.n_embd

        #print(B, T, C)

        positions = torch.arange(start=0, end=T, step=1)
        token_emb = self.token_embedding_table(x) # batch_size x context_size x n_embd --> BxTxC
        pos_emb = self.positional_embedding_table(positions) # T x C (each context has a position)


        x = token_emb + pos_emb
        x = self.head(x, self.mask)
        x = x.view(B*T, C)

        return x

    def generate(self, starting_idx: torch.Tensor, max_length) -> torch.Tensor:

        full_text = itos[starting_idx.item()]
        for i in range(max_length):
            logits = self(starting_idx)
            percents = torch.softmax(logits, dim=1)
            pred = torch.multinomial(percents, num_samples=1)
            starting_idx = pred
            full_text += decode([pred.item()])
        return full_text


In [46]:
n_embd = 32
mlpv3 = MLPv3(vocab_size=vocab_size, n_embd=n_embd, context_size=context_size)
mlpv3.optimizer = torch.optim.Adam(params=mlpv3.parameters(), lr=1e-2)
mlpv3_loss_fn = nn.CrossEntropyLoss()
mlpv3.info()

{'vocab_size': 65, 'n_embd': 32, 'context_size': 8}

In [47]:
outputs = generate_from_model(model=mlpv3, max_length=200, num_outputs=2, starting_char="a")
for output in outputs:
    print(f"{output} \n\n")

anstil?

Fofg band d cowithor,
ENIUTore: t
Whityo thuchy me, the tin, tl' hesoobl pre gon ioeshald, y?
Ashe

MENENININUS:

Mar t te datre ce ik, bulbe h!
Fr wira
Thees! e isolo vel t goue ty, atitiumy  


anerdy ngrck, y Vode nowig st,
Tort glopimionire cith ous t tlle u' he ple.
OLA s wntwotonom f prsserd andengin io she tthe.
Mar ou f d tt mshindot r f pe sher:

Th tn beetre y kne and, pppe d tris twn 




In [48]:
#train_model(model=mlpv3, dataloader=train_dataloader, loss_fn=mlpv3_loss_fn, optimizer=mlpv2_optimizer, epochs=1)

In [49]:
outputs = generate_from_model(model=mlpv3, max_length=200, num_outputs=2, starting_char="a")
for output in outputs:
    print(f"{output} \n\n")

at s m of n waneye m, t uththot y poish hen
ENICOLAwhe t Lof t yosisethee at ere t, elst fru pethioto weshe.
An oofacoomim thej d, ere wily mefer? mathe,
rs as be s s shie s wido't harto Cin ce t, tll  


atwhive--alaasark yo ns n m o y f f nst m Pu streart thevaveptyo an s, cot ho ay coulun, merowethananowire rilowe m, ho bl wazech.
Whe we ich stod t tod dr wowe vie tod po yovether priple umy wou ac'
T 




# Self-attetion x cross-attention
- in self-attention the values for queries (Q), keys (K) and values (V) all come from x itself, thus self-attention
- in cross-attention those values can come from somewhere else