# Introduction (WIP)
- This notebook is based on the "Let's build GPT: from scratch, in code, spelled out" tutorial by Andrej Karpathy. You can find the tutorial here --> https://www.youtube.com/watch?v=kCc8FmEb1nY&list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ&index=7
- There are several different approaches in this notebook that do not strictly follow the original video. Some implementations are my own.
- I am using the same shakespeare text as in the video.

# Import needed libraries

In [1]:
from tqdm import tqdm

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, random_split

# Device agnostic code

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)
generator = torch.Generator(device=device)
print(f"default device set to {device}")

default device set to cpu


# Prepare the data

In [3]:
with open("/kaggle/input/shakespeare/input.txt", "r", encoding="utf-8") as f:
    text = f.read()

vocab = sorted(set(text))
vocab_size = len(vocab)

print(vocab)
print(vocab_size)

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
65


### Tokenizer

In [4]:
stoi = {c: v for v, c in enumerate(vocab)}
itos = {v: c for c, v in stoi.items()}
print(stoi["h"])
print(itos[46])

46
h


In [5]:
encode = lambda d: [stoi[idx] for idx in d]
decode = lambda e: "".join([itos[idx] for idx in e])

encoded = encode("hello, how are you?!")
decoded = decode(encoded)
print(encoded)
print(decoded)

[46, 43, 50, 50, 53, 6, 1, 46, 53, 61, 1, 39, 56, 43, 1, 63, 53, 59, 12, 2]
hello, how are you?!


# Prepare the dataset

In [6]:
context_size = 8
n_embd = 5
vocab_size = len(vocab)

In [7]:
def make_dataset(text, context_size):
    data = torch.tensor(encode(text), dtype=torch.long)

    #random_idx = torch.randint(0, len(data)-context_size, (int(len(data)/context_size),))
    random_idx = torch.randperm(len(data)-context_size)
    inputs = torch.stack([data[idx:idx+context_size] for idx in random_idx])
    labels = torch.stack([data[idx+1:idx+context_size+1] for idx in random_idx])

    return TensorDataset(inputs, labels)


In [8]:
# sicne randint might give the same random_idx, randperm is going to be preffered
print(torch.randint(0, 10, (10,)))
print(torch.randperm(10))

tensor([2, 9, 0, 3, 4, 9, 2, 8, 3, 9])
tensor([5, 0, 8, 3, 6, 9, 1, 4, 2, 7])


In [9]:
dataset = make_dataset(text=text[:100000], context_size=8)

In [10]:
train_split = int(len(dataset)*0.8)
test_split = int(len(dataset)-train_split)

train_dataset, test_dataset = random_split(dataset=dataset, lengths=[train_split, test_split], generator=generator)

In [11]:
batch_size = 32
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, generator=generator)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, generator=generator)

In [12]:
def sample_from_data(dataloader):
    for batch, (X, y) in enumerate(dataloader):
        #print(f"batch {batch}, input {X}, label {y}")
        #print(batch)
        pass

In [13]:
sample_from_data(dataloader=train_dataloader)

# Base model (MLP)

In [14]:
class MLP(nn.Module):
    def __init__(self, context_size, n_embd, vocab_size):
        super().__init__()


        self.context_size = context_size
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd) # B x T x C; B --> batches, T --> time (context_size), C --> n_embd
        self.pos_embedding_table = nn.Embedding(context_size, n_embd) # T x C; this is from the posisitional encoding part of the video
        
        self.linear1 = nn.Linear(in_features=context_size*n_embd, out_features=8*8) # B x T*C @ T*C x H; H --> number of hidden_units
        self.linear2 = nn.Linear(in_features=8*8, out_features=8*8*8)
        self.linear3 = nn.Linear(in_features=8*8*8, out_features=vocab_size)
        self.act_fn = nn.Tanh()

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        B, T = idx.shape
        C = self.n_embd
        positions = torch.arange(start=0, end=T, step=1)
        x = self.token_embedding_table(idx) + self.pos_embedding_table(positions)
        x = x.view(B, T*C)

        x = self.act_fn(self.linear1(x))
        x = self.act_fn(self.linear2(x))
        x = self.linear3(x)

        return x

    def generate(self, idx: torch.Tensor, randomize: bool, max_length: int, num_samples: int) -> torch.Tensor:
        outputs = []
        for sample in range(num_samples):
            full_text = "" 
            for i in range(max_length):
                logits = self(idx)
                percents = torch.softmax(logits, dim=1)

                if randomize:
                    pred = torch.multinomial(percents, num_samples=1)
                    full_text += decode(pred.tolist()[0])
                    idx = torch.cat([idx[:, 1:], pred], dim=1) # update the context, remove the first element of the tensor and add the new prediction made by the model
                else:
                    pred = torch.argmax(percents)
                    full_text += decode([pred.item()])
                    idx = torch.cat([idx[:, 1:], pred.view(1, 1)], dim=1) # update the context, remove the first element of the tensor and add the new prediction made by the model
                    # in the argmax the output is a single element, pred.view(1, 1) turns it into a batch of dim 1, so it can be concatenated to the previous context

            outputs.append(full_text)

        return outputs

            


# Define the base model, optimizer and loss function

In [15]:
mlp = MLP(context_size=context_size, n_embd=n_embd, vocab_size=vocab_size)
optimizer = torch.optim.Adam(params=mlp.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

# Take samples from the base model

In [16]:
@torch.no_grad
def model_sampler(model, context, randomize, max_length, num_samples):
    #print(len(context))
    #print(context_size)
    test = torch.tensor([[20, 53, 61,  1, 39, 56, 58, 39]])
    #print(test[:, 1:])
    result = torch.cat((test[:, 1:], torch.tensor([[99]])), dim=1)
    #print(result)
    #print("\n")

    mlp.eval()
    idx = torch.tensor(encode(context), dtype=torch.long).view(1, len(encode(context))) # inputs must be batched
    outputs = mlp.generate(idx=idx, randomize=randomize, max_length=max_length, num_samples=num_samples)
    for output in outputs:
        print(f"{output} \n\n")

model_sampler(model=mlp, context="How are ", randomize=True, max_length=10, num_samples=5)

C:v'h.KN&K 


OU$Ad;k?'A 


cK$lvGfjb: 


-xI3bDZKL  


:Y!:uoCJEx 




# Training loop

In [17]:
def train_model(model, dataloader, loss_fn, optimizer, epochs):
    model.train()
    for epoch in range(epochs):
        for batch, (X, y) in tqdm(enumerate(dataloader)):
            logits = model(X)
            loss = loss_fn(logits, y[:, -1])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 200 == 0:
                print(f"loss for batch {batch} --> {loss} at epoch {epoch}")

    print(f"loss for the very last batch --> {loss}")

In [18]:
mlp.train()
train_model(model=mlp, dataloader=train_dataloader, loss_fn=loss_fn, optimizer=optimizer, epochs=1)

0it [00:00, ?it/s]

loss for batch 0 --> 4.146660804748535 at epoch 0


255it [00:01, 268.05it/s]

loss for batch 200 --> 3.1106534004211426 at epoch 0


447it [00:01, 268.07it/s]

loss for batch 400 --> 2.642275810241699 at epoch 0


646it [00:02, 272.13it/s]

loss for batch 600 --> 2.6673550605773926 at epoch 0


847it [00:03, 283.74it/s]

loss for batch 800 --> 2.695472002029419 at epoch 0


1050it [00:03, 280.73it/s]

loss for batch 1000 --> 2.1004154682159424 at epoch 0


1255it [00:04, 282.31it/s]

loss for batch 1200 --> 2.489316463470459 at epoch 0


1457it [00:05, 280.87it/s]

loss for batch 1400 --> 1.9982508420944214 at epoch 0


1631it [00:06, 283.35it/s]

loss for batch 1600 --> 2.2123279571533203 at epoch 0


1833it [00:06, 274.56it/s]

loss for batch 1800 --> 2.3320584297180176 at epoch 0


2031it [00:07, 278.24it/s]

loss for batch 2000 --> 2.4104225635528564 at epoch 0


2234it [00:08, 284.43it/s]

loss for batch 2200 --> 2.572524309158325 at epoch 0


2436it [00:08, 279.03it/s]

loss for batch 2400 --> 2.1622660160064697 at epoch 0


2500it [00:09, 273.75it/s]

loss for the very last batch --> 1.9016098976135254





# Base model inference

In [19]:
@torch.no_grad
def model_inference(model, dataloader):
    mlp.eval()
    X, y = next(iter(dataloader))
    logits = model(X)
    percents = torch.softmax(logits, dim=1) # dim=1 since the input was batched
    preds = torch.argmax(percents, dim=1) # dim=1 since the input was batched
    print(f"for {X} \n model predicted {preds}")
    print(f"expected --> {y[:, -1]}")
    print(y)

In [20]:
model_inference(model=mlp, dataloader=train_dataloader)

for tensor([[58, 46, 39, 52,  1, 47, 42, 50],
        [57, 39, 60, 43, 42,  0, 37, 53],
        [47, 41, 46,  1, 46, 43,  0, 42],
        [ 0, 13, 57,  1, 39, 52, 63,  1],
        [43, 52,  1, 54, 53, 59, 52, 42],
        [51, 47, 50, 50, 47, 53, 52, 57],
        [43,  5, 57,  1, 57, 58, 39, 58],
        [ 1, 63, 53, 59, 56,  1, 41, 53],
        [53, 50, 42,  7, 47, 52,  1, 58],
        [51, 63,  1, 42, 43, 57, 47, 56],
        [47, 58, 58, 50, 43,  1, 55, 59],
        [41, 43, 57, 57, 53, 56, 57,  1],
        [ 1, 44, 47, 52, 43,  1, 57, 54],
        [61, 53, 59, 50, 42,  1, 44, 53],
        [10,  0, 37, 53, 59,  1, 39, 56],
        [57,  6,  0, 13, 57,  1,  5, 58],
        [ 0, 40, 43, 52, 41, 46, 43, 56],
        [13, 10,  0, 31, 61, 43, 43, 58],
        [52, 42, 47, 52, 45,  1, 46, 43],
        [33, 31, 10,  0, 35, 43, 50, 50],
        [63, 53, 59,  1, 58, 46, 47, 57],
        [41, 43,  1, 58, 53,  1, 57, 43],
        [63, 43, 58, 12,  0, 27,  1, 51],
        [45, 46, 58, 47, 52, 4

In [21]:
model_sampler(model=mlp, context="How are ", randomize=True, max_length=500, num_samples=1)

faks. ive ancen he, he sat yacagou;d, fou, bes honl:
Bf
MS.RAeI,ar fones
AAn bust be coucou',
Hat st lo,
 have
iouled, voup stselt, boce iut,
Aars, tounw, pseceetoud? trsintt, nt itotds-and tombland
Yuld;oot

Sen, srdedls, ofped
Mulf at torim.

MTUENENWpps pine dolledet!'e nements,, pe kithit s bo sha inud lou fo't bt furaes
Loat
Aad oms ppouedst. amkt ham;;d.

MANENIUIChon, oru poik cou heu.

BlUIM:NE'rUestep; po ton. wamkml.
WhMLO,:
LEOICNIUS:
Hertill, hes sa'te tous.,ar nrhard in sou.

Forl:, 




# Self attention math

In [22]:
sample_batch = next(iter(test_dataloader))[0]
B, T = sample_batch.shape # batch of B by T
print(B, T)
example_emb = nn.Embedding(vocab_size, 4)
embedded = example_emb(sample_batch)
B, T, C = embedded.shape # embedded is Batches by Time (context_sie) by Channels (num of values per token)
print(embedded.shape)

32 8
torch.Size([32, 8, 4])


In [23]:
bag_of_words = torch.zeros(size=(B, T, C)) # each of the values has a unique value

for batch_idx in range(B):
    for context_idx in range(T):
        xprev = embedded[batch_idx, :context_idx+1]
        bag_of_words[batch_idx, context_idx] = torch.mean(xprev, dim=0)
print(bag_of_words)

tensor([[[-1.7152, -1.6628,  1.9220, -0.5699],
         [-1.8801, -1.7981, -0.3845, -0.4875],
         [-0.8771, -1.0641, -0.4788,  0.0515],
         ...,
         [-0.1167, -0.4897, -0.4951, -0.1277],
         [-0.0074, -0.4484, -0.5282, -0.2701],
         [ 0.0087, -0.3526, -0.4009, -0.1639]],

        [[ 1.4021, -0.3635, -1.8786,  0.0854],
         [ 1.0147, -0.3344, -0.3362, -0.1383],
         [ 1.1161, -0.1369, -0.3775, -0.2427],
         ...,
         [ 0.4374, -0.1645, -0.0404,  0.1385],
         [ 0.0827, -0.4172, -0.4191,  0.0608],
         [ 0.0466, -0.4464, -0.6699,  0.2086]],

        [[ 1.1288,  0.4040, -0.6674,  1.1294],
         [ 0.4700, -0.6070, -0.4009,  0.8013],
         [ 0.5224, -0.5065,  0.1348,  0.4135],
         ...,
         [ 0.7478, -0.3227, -0.3659,  0.7559],
         [ 0.5560, -0.1962, -0.1366,  0.6243],
         [ 0.3498, -0.2706,  0.0725,  0.5135]],

        ...,

        [[ 0.3887, -0.4573, -0.0542,  2.0803],
         [-0.8473, -0.7564, -0.1124,  1.1075]

In [24]:
ones = torch.ones(size=(3, 3))
tril = torch.tril(ones) # lower triangular part of a matrix
print(tril)
a = torch.randint(0, 10, (3, 2), dtype=torch.float32)
b = torch.randint(0, 10, (2, 3), dtype=torch.float32)
matmul_output = a @ b
matmul_tril_output = torch.tril(a) @ b
print(matmul_output)
print(matmul_tril_output)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[ 56.,  75.,  73.],
        [  7.,   6.,   8.],
        [ 42., 117.,  75.]])
tensor([[ 56.,  48.,  64.],
        [  7.,   6.,   8.],
        [ 42., 117.,  75.]])


In [25]:
# do the same as bag of words but with matrix multiplication (dot product)
a = torch.ones(size=(3, 3), dtype=torch.float32)
b = torch.randint(0, 10, (3, 2), dtype=torch.float32)

a = torch.tril(a)
"""
b = torch.tensor(
    [
        [2, 7],
        [6, 4],
        [6, 5]
    ], dtype=torch.float32
)
"""
print(a)
a = a/a.sum(dim=1, keepdim=True)
print(a)

output = a @ b
print(output)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[6.0000, 8.0000],
        [3.0000, 4.0000],
        [4.6667, 3.6667]])


In [26]:
sample_batch = next(iter(test_dataloader))[0]
B, T = sample_batch.shape # batch of B by T
example_emb = nn.Embedding(vocab_size, 4)
embedded = example_emb(sample_batch)
B, T, C = embedded.shape # embedded is Batches by Time (context_sie) by Channels (num of values per token)
#print(embedded.shape)

wei = torch.tril(torch.ones(size=(T, T)))
wei = wei / wei.sum(dim=1, keepdim=True)
print(embedded.shape) # B x T x C
print(wei.shape) # T x T
#  1xTxT @ BxTxC
bag_of_words = wei @ embedded
print(bag_of_words)

torch.Size([32, 8, 4])
torch.Size([8, 8])
tensor([[[-0.3020,  0.8613,  0.8509,  1.0694],
         [-0.1474,  0.8005, -0.0386,  0.7087],
         [ 0.2306,  0.8734,  0.1007,  0.5407],
         ...,
         [-0.0237,  0.9753,  0.2149,  0.6148],
         [-0.1876,  1.1710,  0.3889,  0.6553],
         [ 0.1801,  1.2350,  0.3246,  0.6314]],

        [[-2.0734,  0.3787, -0.4855, -1.4053],
         [-0.6292,  0.5835, -0.7629, -0.4855],
         [ 0.0906,  0.8567, -1.0150, -0.4248],
         ...,
         [-0.2257, -0.0535, -0.1501, -0.4011],
         [-0.1924,  0.0598, -0.2613, -0.2940],
         [-0.1832,  0.0983, -0.1115, -0.1161]],

        [[ 0.9866,  1.0192,  0.3793,  0.2046],
         [ 0.0192,  0.2940, -0.0295,  0.1856],
         [ 0.2845,  0.4588, -0.3664,  0.2685],
         ...,
         [ 0.1857,  0.4227, -0.4971,  0.1486],
         [ 0.2495,  0.3161, -0.3195, -0.0099],
         [ 0.1306,  0.3211, -0.3383, -0.0613]],

        ...,

        [[ 1.3477, -0.2381, -1.7774,  1.2868],
   

# Bag of words type aggregation with a mask

In [27]:
tril = torch.tril(torch.ones(size=(T, T)))
wei = torch.zeros(size=(T, T)) # zeros just so there's a plaaceholder for masked_fill
wei = wei.masked_fill(tril==0, float("-inf")) # whenever the value in tril is 0, it will get replaced with -inf; this allows softmax to come into place, since -inf will get a percent of 0
wei = torch.softmax(wei, dim=1)
print(wei)
bag_of_words = wei @ embedded
print(bag_of_words)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])
tensor([[[-0.3020,  0.8613,  0.8509,  1.0694],
         [-0.1474,  0.8005, -0.0386,  0.7087],
         [ 0.2306,  0.8734,  0.1007,  0.5407],
         ...,
         [-0.0237,  0.9753,  0.2149,  0.6148],
         [-0.1876,  1.1710,  0.3889,  0.6553],
         [ 0.1801,  1.2350,  0.3246,  0.6314]],

        [[-2.0734,  0.3787, -0.4855, -1.4053],
         [-0.6292,  0.5835, -0.7629, -0.4855],
         [ 0.09

# MLP model with agreggation
- a problem that needs to be addressed with the previous model is that it needs to always receive a input of B x T (batch_size by context_size), whereas it would be best if the model could adapt to inputs of different context_size

In [28]:
# Code to allow comunication between past tokens
C = n_embd
T = context_size
B = batch_size
wei = torch.zeros(size=(T, T))
tril = torch.tril(torch.ones(size=(T, T)))
wei = wei.masked_fill(tril==0, float('-inf'))
xbow = wei.softmax(dim=1)

test_tensor = torch.randn(size=(B, T, C))
print(xbow @ test_tensor)

tensor([[[ 0.1772, -0.8563, -0.8250,  0.7784,  0.3994],
         [-1.4278, -0.2129, -0.9921,  0.2713,  0.2778],
         [-0.6630, -0.1838, -0.6789,  0.2007,  0.2597],
         ...,
         [-0.1470, -0.4072, -0.3895, -0.1202, -0.3026],
         [-0.0819, -0.2754, -0.4424,  0.0961, -0.4651],
         [-0.1635, -0.0382, -0.5058,  0.0578, -0.3659]],

        [[-0.1855,  0.0972,  0.8307, -1.3274, -0.5651],
         [-0.4229, -0.4281,  0.3734, -0.5946, -0.3629],
         [-0.1432, -0.3768, -0.1316, -0.6000, -0.5317],
         ...,
         [ 0.0350, -0.6950, -0.6717,  0.3662, -0.1211],
         [-0.0757, -0.5785, -0.4903, -0.1183, -0.2821],
         [-0.1859, -0.5367, -0.4898, -0.1930, -0.2658]],

        [[ 1.2801, -0.5447,  0.9786, -0.2297, -0.3808],
         [ 0.5822, -0.5246,  0.6279, -0.3887,  0.4966],
         [ 0.4285,  0.0777,  0.2231, -0.3176,  0.0764],
         ...,
         [ 0.4156,  0.0136, -0.2311, -0.1844, -0.2442],
         [ 0.6149, -0.0090, -0.1830, -0.0991,  0.0663],
  

In [29]:
# a problem that needs to be addressed with the previous model is that it needs to always receive a 
batch_sample_inputs, batch_sample_labels = next(iter(train_dataloader))
print(batch_sample_inputs.shape)
print(batch_sample_labels.shape)

torch.Size([32, 8])
torch.Size([32, 8])


In [30]:
class MLPv2(nn.Module):
    def __init__(self, vocab_size, n_embd, context_size):
        super().__init__()

        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.context_size = context_size

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.positional_embedding_table = nn.Embedding(context_size, n_embd) # the position of each token has a separate table of values; this helps the model keep track of the order of characters

        #self.lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size) # this time in_features=n_embd so it's not context_size dependant
        self.linear1 = nn.Linear(in_features=n_embd, out_features=8*8)
        self.linear2 = nn.Linear(in_features=8*8, out_features=8*8*8)
        self.linear3 = nn.Linear(in_features=8*8*8, out_features=vocab_size)

        self.act_fn = nn.Tanh()

    def info(self):
        info_dict = {
            "vocab_size": self.vocab_size,
            "n_embd": self.n_embd,
            "context_size": self.context_size
        }

        return info_dict

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T = x.shape
        C = self.n_embd

        #print(B, T, C)

        positions = torch.arange(start=0, end=T, step=1)
        token_emb = self.token_embedding_table(x) # batch_size x context_size x n_embd --> BxTxC
        pos_emb = self.positional_embedding_table(positions) # T x C (each context has a position)

        x = token_emb + pos_emb
        x = x.view(B*T, C)

        #logits = self.lm_head(token_emb.view(B*T, C)) # output of shape B*T x hidden_units1; this might be a problem for the labels since they are of shape B x T, thus they need to be reshaped aswell
        x = self.act_fn(self.linear1(token_emb.view(B*T, C))) # hidden_units1 x hidden_units2
        x = self.act_fn(self.linear2(x)) # hidden_units2 x hidden_units3
        x = self.linear3(x) # hidden_units3 x vocab_size
        return x

    def generate(self, starting_idx: torch.Tensor, max_length) -> torch.Tensor:

        full_text = itos[starting_idx.item()]
        for i in range(max_length):
            logits = self(starting_idx)
            percents = torch.softmax(logits, dim=1)
            pred = torch.multinomial(percents, num_samples=1)
            starting_idx = pred
            full_text += decode([pred.item()])
        return full_text


In [31]:
mlpv2 = MLPv2(vocab_size=vocab_size, n_embd=32, context_size=context_size)

In [32]:
mlpv2.info()

{'vocab_size': 65, 'n_embd': 32, 'context_size': 8}

In [33]:
mlpv2_loss_fn = nn.CrossEntropyLoss()
batch_sample_inputs, batch_sample_labels = next(iter(train_dataloader))
sample_input = batch_sample_inputs[0]
sample_label = batch_sample_labels[0]
print(sample_input) # 1 x T (B x T)
print(sample_label.view(1, -1)) # 1 x T (B x T)
mlpv2.eval()
with torch.inference_mode():
    logits = mlpv2(sample_input.view(1, -1))
    labels = sample_label.view(-1) # from B x T to B*T to match the shape of the logits
    print(logits.shape) # B*T x vocab_size
    print(labels.shape)

    loss = mlpv2_loss_fn(logits, labels)
    print(loss)

tensor([58, 46, 39, 52,  1, 47, 42, 50])
tensor([[46, 39, 52,  1, 47, 42, 50, 63]])
torch.Size([8, 65])
torch.Size([8])
tensor(4.1575)


In [34]:
@torch.no_grad
def generate_from_model(model, num_outputs, starting_char, max_length):
    mlpv2.eval()
    outputs = []
    starting_idx = torch.tensor([stoi[starting_char]], dtype=torch.long).view(1, -1)
    for i in range(num_outputs):
        output = mlpv2.generate(starting_idx=starting_idx, max_length=max_length) # must be batched
        outputs.append(output)

    return outputs

In [35]:
test_output = generate_from_model(model=mlpv2, num_outputs=1, starting_char="a", max_length=5)
print(test_output[0])

ah $
S


In [36]:
def train_model(model, dataloader, loss_fn, optimizer, epochs):
    model.train()

    for epoch in range(epochs):
        for batch, (X, y) in enumerate(dataloader):
            logits = model(X)
            loss = loss_fn(logits, y.view(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


            if batch % 1200 == 0:
                print(f"loss for batch {batch} --> {loss} at epoch {epoch}")

    print(f"loss for the very last batch --> {loss}")

In [37]:
mlpv2_optimizer = torch.optim.Adam(params=mlpv2.parameters(), lr=1e-3)
mlpv2_loss_fn = nn.CrossEntropyLoss()

In [38]:
train_model(model=mlpv2, dataloader=train_dataloader, loss_fn=mlpv2_loss_fn, optimizer=mlpv2_optimizer, epochs=4)

loss for batch 0 --> 4.192663669586182 at epoch 0
loss for batch 1200 --> 2.4094653129577637 at epoch 0
loss for batch 2400 --> 2.5396831035614014 at epoch 0
loss for batch 0 --> 2.3995587825775146 at epoch 1
loss for batch 1200 --> 2.358095645904541 at epoch 1
loss for batch 2400 --> 2.5309691429138184 at epoch 1
loss for batch 0 --> 2.393231153488159 at epoch 2
loss for batch 1200 --> 2.344420909881592 at epoch 2
loss for batch 2400 --> 2.526885986328125 at epoch 2
loss for batch 0 --> 2.3893609046936035 at epoch 3
loss for batch 1200 --> 2.3379037380218506 at epoch 3
loss for batch 2400 --> 2.523819923400879 at epoch 3
loss for the very last batch --> 2.4213552474975586


In [39]:
test_outputs = generate_from_model(model=mlpv2, max_length=100, num_outputs=2, starting_char="b")
for output in test_outputs:
    print(f"{output}\n\n")

bus gom hesse'lo tof pintr howowan tigo pupr yor
pl,
Fonce, nt on nd neeit,
Tot nile gem hiticrm
S:
S


bart d
Arel; yond
The yor w o ateathe g, tthealle;
OLe pl hengiredof RCitr;
fo INUSIAncondis
Yot,
Who




In [40]:
torch.arange(start=0, end=T, step=1) # from 0 to T-1

tensor([0, 1, 2, 3, 4, 5, 6, 7])

# Self attention
- with xbow you can add information about the tokens, but the model itself does not attribute any weight to them. This is what self attention solves by using Keys, Queries and Values
- every token will have a specific Query (Q) and Key (K) attatched to it
    - Query --> what the model is looking for
    - Key --> the weight the model is giving to this certain token
    - Value --> Q @ K to get the affinities between themselves

In [41]:
# Code to allow comunication between past tokens

B, T, C = 2, 4, 8
wei = torch.zeros(size=(T, T))
tril = torch.tril(torch.ones(size=(T, T)))
wei = wei.masked_fill(tril==0, float('-inf'))
xbow = wei.softmax(dim=1)

test_tensor = torch.randn(size=(B, T, C))
output = xbow @ test_tensor
print(output.shape)

torch.Size([2, 4, 8])


In [42]:
head_size = 32
key = nn.Linear(in_features=C, out_features=head_size, bias=False) # bias = False so that it's just a multiplication
query = nn.Linear(in_features=C, out_features=head_size, bias=False) 

#print(test_tensor.shape)
k = key(test_tensor) # BxTxC @ BxTxhead_size --> BxTxhead_size; each batch has a context and each context character has a key value
q = query(test_tensor) # BxTxC @ BxTxhead_size --> BxTxhead_size; each batch has a context and each context character has a query value
print(k.shape)
print(q.shape)

torch.Size([2, 4, 32])
torch.Size([2, 4, 32])


In [43]:
print(k.transpose(-2, -1).shape) # same as k.permute(0, 2, 1)
print(k.permute(0, 2, 1).shape)

wei = q @ k.transpose(-2, -1) # BxTxhead_size @ Bxhead_sizexT --> BxTxT
print(wei.shape) #  B x T x T

tril = torch.tril(torch.ones(size=(T, T)))
wei = wei.masked_fill(tril==0, float('-inf'))
wei = torch.softmax(wei, dim=-1) # dim=-1 in this case, since wei is of shape

output = wei @ test_tensor

print(output.shape) # B x T x C

torch.Size([2, 32, 4])
torch.Size([2, 32, 4])
torch.Size([2, 4, 4])
torch.Size([2, 4, 8])


# Update the model with a self attention head

In [44]:
class Head(nn.Module):
    def __init__(self, head_size, n_embd):
        super(Head, self).__init__()

        self.Q = nn.Linear(in_features=n_embd, out_features=head_size)
        self.K = nn.Linear(in_features=n_embd, out_features=head_size)
        self.V = nn.Linear(in_features=n_embd, out_features=head_size)


    def forward(self, x: torch.Tensor, mask) -> torch.Tensor:
        B, T, C = x.shape
        #mask = torch.tril(torch.ones(size=(T, T)))

        q = self.Q(x) # B x T x head_size
        k = self.K(x) # B x T x head_size
        v = self.V(x) # B x T x head_size

        k = k.transpose(-2, -1) # B x head_size x T

        wei = q @ k # BxTxhead_size @ Bxhead_sizexT --> BxTxT

        wei = wei.masked_fill(mask==0, float('-inf'))
        wei = torch.softmax(wei, dim=-1)

        #x = wei @ v # BxTxT @ BxTxhead_size --> BxTxhead_size
        x = wei @ x

        return x



In [45]:
class MLPv3(nn.Module):
    def __init__(self, vocab_size, n_embd, context_size):
        super(MLPv3, self).__init__()

        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.context_size = context_size

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.positional_embedding_table = nn.Embedding(context_size, n_embd) # the position of each token has a separate table of values; this helps the model keep track of the order of characters

        self.head = Head(head_size=32, n_embd=n_embd)

        #self.lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size) # this time in_features=n_embd so it's not context_size dependant
        self.linear1 = nn.Linear(in_features=n_embd, out_features=8*8)
        self.linear2 = nn.Linear(in_features=8*8, out_features=8*8*8)
        self.linear3 = nn.Linear(in_features=8*8*8, out_features=vocab_size)

        self.act_fn = nn.Tanh()

        self.mask = torch.tril(torch.ones(size=(context_size, context_size)))

    def info(self):
        info_dict = {
            "vocab_size": self.vocab_size,
            "n_embd": self.n_embd,
            "context_size": self.context_size
        }

        return info_dict

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T = x.shape
        C = self.n_embd

        #print(B, T, C)

        positions = torch.arange(start=0, end=T, step=1)
        token_emb = self.token_embedding_table(x) # batch_size x context_size x n_embd --> BxTxC
        pos_emb = self.positional_embedding_table(positions) # T x C (each context has a position)


        x = token_emb + pos_emb
        x = self.head(x, self.mask)
        x = x.view(B*T, C)

        return x

    def generate(self, starting_idx: torch.Tensor, max_length) -> torch.Tensor:

        full_text = itos[starting_idx.item()]
        for i in range(max_length):
            logits = self(starting_idx)
            percents = torch.softmax(logits, dim=1)
            pred = torch.multinomial(percents, num_samples=1)
            starting_idx = pred
            full_text += decode([pred.item()])
        return full_text


In [46]:
n_embd = 32
mlpv3 = MLPv3(vocab_size=vocab_size, n_embd=n_embd, context_size=context_size)
mlpv3.optimizer = torch.optim.Adam(params=mlpv3.parameters(), lr=1e-2)
mlpv3_loss_fn = nn.CrossEntropyLoss()
mlpv3.info()

{'vocab_size': 65, 'n_embd': 32, 'context_size': 8}

In [47]:
outputs = generate_from_model(model=mlpv3, max_length=200, num_outputs=2, starting_char="a")
for output in outputs:
    print(f"{output} \n\n")

a w'dinforaso s?
MENINat m navo cive tasu,
Thed beey gug.

Ans bases wesulotingre tyoowoulorde
Wir igorieethhonong
Fimuy holy antin,, mo f o, t d teiothaipoulebowasllealid
aleoraminls trallert n s me u 


an he y terasere ise,
Or s; sum ukenaik:
Frat s.
VOn ano thowiuth aldrs
ORI coofein.
Bu d brony tuollearilithue tithen;
Troncendse w
Thirs ccotee d powavot upe t spr.
S:
Thod,
MIUS:

And ck-bendowathat 




In [48]:
#train_model(model=mlpv3, dataloader=train_dataloader, loss_fn=mlpv3_loss_fn, optimizer=mlpv2_optimizer, epochs=1)

In [49]:
outputs = generate_from_model(model=mlpv3, max_length=200, num_outputs=2, starting_char="a")
for output in outputs:
    print(f"{output} \n\n")

anves yoofaty hous,

Foucet mer us ss br,
TUMENERUSil,
Soilldy iouptstcorein s she hear sond y y ntesy bloron n t athis celo cr h,
Vondy sppr
n ine te f tt s miroreerushat plllon:
shear
Gou st;

At f t 


af t sce;
Frid as
ARUS:
Totous thandrdecraber t s ind beleativersteal ched r uvilendizend hendst d m toizel, brd,
Buede feres d hakldencean:
Coun.
Th bu aitor bur. st.
NI tinshoneespas l grothyord as.
 




# Self-attetion x cross-attention
- in self-attention the values for queries (Q), keys (K) and values (V) all come from x itself, thus self-attention
- in cross-attention those values can come from somewhere else