<a href="https://www.kaggle.com/code/evelynartoria/decoder-transformer-model-from-scratch-pytorch?scriptVersionId=187738267" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
!mkdir ./models

In [2]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, random_split
from tqdm import tqdm

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = torch.Generator(device=device)
torch.set_default_device(device)
print(f"default device set to {device}")

default device set to cuda


In [4]:
with open("/kaggle/input/shakespeare/input.txt", 'r', encoding="utf-8") as f:
    text = f.read()

In [5]:
vocab = sorted(set(text))
vocab_size = len(vocab)

In [6]:
stoi = {c: v for v, c in enumerate(vocab)}
itos = {v: c for c, v in stoi.items()}

print(stoi["h"])
print(itos[46])

46
h


In [7]:
encode = lambda e: [stoi[ch] for ch in e]
decode = lambda d: "".join([itos[idx] for idx in d])

encoded = encode("hello how are you?")
decoded = decode(encoded)

print(encoded, decoded)

[46, 43, 50, 50, 53, 1, 46, 53, 61, 1, 39, 56, 43, 1, 63, 53, 59, 12] hello how are you?


In [8]:
context_size = 128
random_idx_tensor = torch.randperm(10000//context_size) * context_size
print(random_idx_tensor)

tensor([6656, 2304,  896, 5888, 8704, 2816, 2560, 9472, 3712, 9728, 8448, 8064,
        3968, 8576, 4992, 8960, 8192, 6272, 4608, 3200, 3072, 6144, 3840, 9856,
        9088, 1792, 2176, 2688, 3584, 6400, 9216, 6912, 5120, 7424, 4864, 7808,
        7168, 7040, 2432, 4352, 5248, 2048, 4480, 7680,  768, 7552, 4096, 1920,
         256, 7296, 9344, 1664, 5632, 2944, 5504, 4224, 3328, 8320, 4736, 1152,
        9600, 7936,  640,  512, 6784,  128, 1280,    0, 5376, 1024, 3456, 5760,
         384, 1408, 6016, 8832, 6528, 1536], device='cuda:0')


In [9]:
def make_dataset(data):
    random_idx_tensor = torch.randperm((len(data)-context_size)//context_size) * context_size
    inputs = torch.stack([data[idx:idx+context_size] for idx in random_idx_tensor])
    labels = torch.stack([data[idx+1:idx+context_size+1] for idx in random_idx_tensor])
    
    return TensorDataset(inputs.to(torch.long), labels.to(torch.long))

In [10]:
data = torch.tensor(encode(text))
dataset = make_dataset(data=data)
sample_input = dataset[0][0]
sample_label = dataset[0][1]

print(sample_input)
print(sample_label)

print(f"dataset length --> {len(dataset)} ({len(dataset) * context_size} characters), that is, about the length of text {len(text)} - context_size --> {len(text)-context_size}")

tensor([43,  5, 57,  1, 58, 53,  1, 57, 53, 61,  8,  0,  0, 28, 56, 53, 60, 53,
        57, 58, 10,  0, 15, 53, 51, 43,  1, 46, 47, 58, 46, 43, 56,  6,  1, 57,
        47, 56, 56, 39, 46,  8,  1, 15, 39, 52,  1, 63, 53, 59,  1, 41, 59, 58,
         1, 53, 44, 44,  1, 39,  1, 51, 39, 52,  5, 57,  1, 46, 43, 39, 42, 12,
         0,  0, 28, 27, 25, 28, 17, 37, 10,  0, 21, 44,  1, 58, 46, 43,  1, 51,
        39, 52,  1, 40, 43,  1, 39,  1, 40, 39, 41, 46, 43, 50, 53, 56,  6,  1,
        57, 47, 56,  6,  1, 21,  1, 41, 39, 52, 11,  1, 40, 59, 58,  1, 47, 44,
         1, 46], device='cuda:0')
tensor([ 5, 57,  1, 58, 53,  1, 57, 53, 61,  8,  0,  0, 28, 56, 53, 60, 53, 57,
        58, 10,  0, 15, 53, 51, 43,  1, 46, 47, 58, 46, 43, 56,  6,  1, 57, 47,
        56, 56, 39, 46,  8,  1, 15, 39, 52,  1, 63, 53, 59,  1, 41, 59, 58,  1,
        53, 44, 44,  1, 39,  1, 51, 39, 52,  5, 57,  1, 46, 43, 39, 42, 12,  0,
         0, 28, 27, 25, 28, 17, 37, 10,  0, 21, 44,  1, 58, 46, 43,  1, 51, 39,
      

In [11]:
train_split = int(len(dataset)*0.75)
test_split = int(len(dataset)-train_split)

train_dataset, test_dataset = random_split(dataset=dataset, lengths=[train_split, test_split], generator=generator)

In [12]:
batch_size = 32
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, generator=generator)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True, generator=generator)

In [13]:
class Head(nn.Module):
    def __init__(self, n_embd, head_size, context_size):
        super(Head, self).__init__()
        
        self.Q = nn.Linear(in_features=n_embd, out_features=head_size) # takes in BxTxC and return BxTxHead_size
        self.K = nn.Linear(in_features=n_embd, out_features=head_size)
        self.V = nn.Linear(in_features=n_embd, out_features=head_size)
        
        self.register_buffer("tril", torch.tril(torch.ones(size=(context_size, context_size))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape # batch_size by context_size by n_embd
        q = self.Q(x) # BxTxHead_size
        k = self.K(x) # BxTxHead_size
        
        wei = q @ k.transpose(-2, -1) * (C ** -0.5) # BxTxHead_size @ BxHead_sizexT --> BxTxT then divided by the square root of n_embd
        
        wei = wei.masked_fill(self.tril[:T, :T]==0, float('-inf')) # :T and :T is needed in case context is smaller than context_size
        wei = torch.softmax(wei, dim=-1)
        v = self.V(x) # BxTxHead_size
        output = wei @ v # BxTxT @ BxTxHead_size --> BxTxHead_size
        
        return output

In [14]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, n_embd, context_size, n_heads, head_size):
        super(MultiHeadedAttention, self).__init__()
        self.heads = nn.ModuleList([Head(n_embd=n_embd, head_size=head_size, context_size=context_size) for _ in range(n_heads)]) # BxTx (n_heads * head_size)
        self.projection = nn.Linear(in_features=n_heads*head_size, out_features=n_embd) # ensures the output is going to be o shape BxTxn_embd (BxTxC) so that is can go through multiple attention block
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.cat([head(x) for head in self.heads], dim=-1) # cat in the Channels dimension; output shape is BxTx (n_heads * head_size)
        x = self.projection(x)
        return  x

In [15]:
class FeedForward(nn.Module):
    def __init__(self, in_features):
        super(FeedForward, self).__init__()
        self.ffwrd_layer = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=in_features * 4), # scale by 4, according to the attention is all you need paper
            nn.ReLU(),
            nn.Linear(in_features=in_features * 4, out_features=in_features) # another projection layer
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ffwrd_layer(x)

In [16]:
class Block(nn.Module):
    def __init__(self, n_heads, head_size, n_embd, context_size):
        super(Block, self).__init__()
        self.multiheaded_self_attetion = MultiHeadedAttention(n_embd=n_embd, context_size=context_size, n_heads=n_heads, head_size=head_size) # create a multiheaded attention block; returns shape BxTx (num_heads*head_size)
        self.ffwrd = FeedForward(in_features=n_embd)
        self.layer_norm1 = nn.LayerNorm(n_embd)
        self.layer_norm2 = nn.LayerNorm(n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.multiheaded_self_attetion(self.layer_norm1(x))
        x = x + self.ffwrd(self.layer_norm2(x))

        return x

In [17]:
class Decoder(nn.Module):
    def __init__(self, n_embd, context_size, vocab_size, num_sa_heads, sa_head_size):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.context_size = context_size
        self.n_embd = n_embd
        
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd) # each character from the vocab has n_embd values associated to it
        self.positional_embedding_table = nn.Embedding(context_size, n_embd) # each character position in the context has n_embd values associated to it
        
        self.attention_blocks = nn.Sequential(
            Block(n_heads=num_sa_heads, head_size=sa_head_size, context_size=self.context_size, n_embd=self.n_embd), # takes in BxTxC, calculate logits of BxTx (num_heads * head_size), then project it as BxTxC
            Block(n_heads=num_sa_heads, head_size=sa_head_size, context_size=self.context_size, n_embd=self.n_embd),
            Block(n_heads=num_sa_heads, head_size=sa_head_size, context_size=self.context_size, n_embd=self.n_embd),
            nn.LayerNorm(n_embd) # normalize the layers
        )
        

        self.lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size) # (B, T, vocab_size)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T = x.shape # batch_size and context_size
        positions = torch.arange(start=0, end=T, step=1)
        
        pos_emb = self.positional_embedding_table(positions) # T x C --> in broadcasting, pytorch adds a batch dim=1
        token_emb = self.token_embedding_table(x) # B x T x C
        
        x = token_emb + pos_emb # BxTxC
        x = self.attention_blocks(x) # returns logits of shape BxTx (self.sa_head_size * self.num_sa_heads) projected to BxTxC
        x = self.lm_head(x) # BxTxvocab_size --> BxTxHead_size @ BxTxVocab_size return BxTxVocab_size

        return x.view(B*T, self.vocab_size) # easier shape to work with the labels
    
    def generate(self, starting_idx: torch.Tensor, max_length: int, debug: bool) -> torch.Tensor:
        full_text = decode([starting_idx.item()])
        context = starting_idx
        
        for _ in range(max_length):
            context = context[:, -self.context_size:] # make sure the context is of size context_size
            
            if debug:
                print(f"predicting on context: {decode(context[0].tolist())}")
            
            logits = self(context) # B*T x vocab_size --> 1*2 x vocab_size
            logits = logits[-1, :].view(1, self.vocab_size) # only take the prediction for the last character
            percents = torch.softmax(logits, dim=1) # 1*2xvocab_size
            pred = torch.multinomial(percents, num_samples=1) 
            full_text += decode(pred.tolist()[0])
            context = torch.cat([context, pred], dim=1) # add to the context dimension instead of the batch dim
            
        return full_text


In [18]:
class model_generator:
    def __init__(self, model: object, max_length: int, num_samples: int, vocab_size: int):
        self.model = model
        self.max_length = max_length
        self.num_samples = num_samples
        self.vocab_size = vocab_size
        
        self.last_output = ""
        
        self.params_dict = {
            "model": model,
            "max_length": max_length,
            "num_samples": num_samples,
            "previous_outputs": []
        }
    
    @torch.no_grad
    def generate(self, starting_char: str = None, clear_outputs: bool = True, debug: bool = False):
        self.model.eval()
        
        if clear_outputs:
            self.clear_ouptuts()
            
        if starting_char is None:
            starting_char = decode([torch.randint(0, vocab_size, (1,)).item()])
            
        for _ in range(self.num_samples):
            starting_idx = torch.tensor(encode(starting_char), dtype=torch.long).view(1, 1)
            output = self.model.generate(starting_idx=starting_idx, max_length=self.max_length, debug=debug)
            self.params_dict["previous_outputs"].append(output)
            self.last_output = output
    
    def update_params(self, model: object = None, max_length: int = None, num_samples: int = None, clear_outputs: bool = None):
        if clear_outputs:
            self.clear_outputs()
            
        updated_dict = {
            "model": model,
            "max_length": max_length,
            "num_samples": num_samples
        }
        
        for attribute, value in updated_dict.items():
            if value is not None:
                self.params_dict[attribute] = value
                setattr(self, attribute, value)
    
    def clear_ouptuts(self):
        self.params_dict["previous_outputs"] = []
        self.last_output = ""
        
    def print_outputs(self, last: bool = None):
        if last:
            print(self.last_output)
        else:
            for output in self.params_dict["previous_outputs"]:
                print(f"{output}\n\n")

In [19]:
n_embd = 784
vocab_size = len(vocab)
context_size = 128 # same as previously set
num_sa_heads = 16
sa_head_size = 64

decoder = Decoder(vocab_size=vocab_size, n_embd=n_embd, context_size=context_size, num_sa_heads=num_sa_heads, sa_head_size=sa_head_size)
optimizer = torch.optim.Adam(params=decoder.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

In [20]:
decoder_generator = model_generator(model=decoder, max_length=32, num_samples=1, vocab_size=vocab_size)

In [21]:
#decoder_generator.update_params(max_length=100, num_samples=1)
decoder_generator.generate()
decoder_generator.print_outputs()

plMoNKZi$KabPMAAEgD;KP
z$ ;X
dBi





In [22]:
decoder.eval()
test_loss_fn = nn.CrossEntropyLoss()
batch_sample_inputs, batch_sample_labels = next(iter(train_dataloader))
with torch.inference_mode():
    #print(batch_sample_inputs)
    #print(batch_sample_labels)
    logits = decoder(batch_sample_inputs)
    labels = batch_sample_labels.view(-1) # turns the label shape into a B*T
    #print(logits) # 4 batches of 8 characters each, the model is trying to predict the next sequence
    #print(labels)
    loss = test_loss_fn(logits, batch_sample_labels.view(-1))
    print(loss)

tensor(4.3012, device='cuda:0')


In [23]:
def train_model(model, dataloader, loss_fn, optimizer, epochs):
    model.train()
    
    for epoch in range(epochs):
        for batch, (X, y) in tqdm(enumerate(dataloader)):
            logits = model(X) # shape of B*T x vocab_size
            labels = y.view(-1) # shape of B*T --> each character has it's own prediction
            loss = loss_fn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if batch % 100 == 0:
                print(f"loss for batch {batch} --> {loss} at epoch {epoch}")

In [24]:
train_model(model=decoder, dataloader=train_dataloader, loss_fn=loss_fn, optimizer=optimizer, epochs=20)

1it [00:00,  2.38it/s]

loss for batch 0 --> 4.3230791091918945 at epoch 0


101it [00:22,  4.43it/s]

loss for batch 100 --> 2.4952898025512695 at epoch 0


201it [00:45,  4.33it/s]

loss for batch 200 --> 2.2436976432800293 at epoch 0


205it [00:46,  4.43it/s]
1it [00:00,  4.32it/s]

loss for batch 0 --> 2.181868314743042 at epoch 1


101it [00:23,  4.21it/s]

loss for batch 100 --> 1.995511770248413 at epoch 1


201it [00:47,  4.09it/s]

loss for batch 200 --> 1.72166109085083 at epoch 1


205it [00:48,  4.24it/s]
1it [00:00,  4.11it/s]

loss for batch 0 --> 1.825104832649231 at epoch 2


101it [00:24,  3.97it/s]

loss for batch 100 --> 1.698589563369751 at epoch 2


201it [00:50,  3.96it/s]

loss for batch 200 --> 1.6682977676391602 at epoch 2


205it [00:51,  4.01it/s]
1it [00:00,  3.92it/s]

loss for batch 0 --> 1.5682876110076904 at epoch 3


101it [00:25,  4.05it/s]

loss for batch 100 --> 1.6100648641586304 at epoch 3


201it [00:49,  3.99it/s]

loss for batch 200 --> 1.6017972230911255 at epoch 3


205it [00:50,  4.05it/s]
1it [00:00,  4.03it/s]

loss for batch 0 --> 1.4753879308700562 at epoch 4


101it [00:25,  3.99it/s]

loss for batch 100 --> 1.415902018547058 at epoch 4


201it [00:50,  4.01it/s]

loss for batch 200 --> 1.4691063165664673 at epoch 4


205it [00:51,  4.01it/s]
1it [00:00,  4.04it/s]

loss for batch 0 --> 1.4025795459747314 at epoch 5


101it [00:25,  4.01it/s]

loss for batch 100 --> 1.4091607332229614 at epoch 5


201it [00:49,  4.01it/s]

loss for batch 200 --> 1.4854360818862915 at epoch 5


205it [00:50,  4.04it/s]
1it [00:00,  4.01it/s]

loss for batch 0 --> 1.3166743516921997 at epoch 6


101it [00:25,  3.98it/s]

loss for batch 100 --> 1.3324304819107056 at epoch 6


201it [00:50,  3.99it/s]

loss for batch 200 --> 1.3443679809570312 at epoch 6


205it [00:51,  4.01it/s]
1it [00:00,  3.98it/s]

loss for batch 0 --> 1.2808433771133423 at epoch 7


101it [00:25,  4.00it/s]

loss for batch 100 --> 1.3003160953521729 at epoch 7


201it [00:50,  4.00it/s]

loss for batch 200 --> 1.3413220643997192 at epoch 7


205it [00:50,  4.02it/s]
1it [00:00,  3.99it/s]

loss for batch 0 --> 1.242486596107483 at epoch 8


101it [00:25,  4.00it/s]

loss for batch 100 --> 1.2870001792907715 at epoch 8


201it [00:50,  4.00it/s]

loss for batch 200 --> 1.2393996715545654 at epoch 8


205it [00:50,  4.02it/s]
1it [00:00,  3.99it/s]

loss for batch 0 --> 1.1703169345855713 at epoch 9


101it [00:25,  3.99it/s]

loss for batch 100 --> 1.2035621404647827 at epoch 9


201it [00:50,  4.01it/s]

loss for batch 200 --> 1.2139416933059692 at epoch 9


205it [00:50,  4.03it/s]
1it [00:00,  3.99it/s]

loss for batch 0 --> 1.0516431331634521 at epoch 10


101it [00:25,  4.01it/s]

loss for batch 100 --> 1.1842637062072754 at epoch 10


201it [00:49,  4.01it/s]

loss for batch 200 --> 1.1409140825271606 at epoch 10


205it [00:50,  4.03it/s]
1it [00:00,  4.03it/s]

loss for batch 0 --> 1.0247409343719482 at epoch 11


101it [00:25,  4.01it/s]

loss for batch 100 --> 1.0253081321716309 at epoch 11


201it [00:49,  4.01it/s]

loss for batch 200 --> 1.1303226947784424 at epoch 11


205it [00:50,  4.04it/s]
1it [00:00,  3.97it/s]

loss for batch 0 --> 0.88817298412323 at epoch 12


101it [00:25,  4.00it/s]

loss for batch 100 --> 0.992388904094696 at epoch 12


201it [00:49,  4.01it/s]

loss for batch 200 --> 1.0694316625595093 at epoch 12


205it [00:50,  4.04it/s]
1it [00:00,  4.02it/s]

loss for batch 0 --> 0.8838292360305786 at epoch 13


101it [00:25,  4.01it/s]

loss for batch 100 --> 0.9680342078208923 at epoch 13


201it [00:49,  4.02it/s]

loss for batch 200 --> 0.9759601354598999 at epoch 13


205it [00:50,  4.04it/s]
1it [00:00,  3.95it/s]

loss for batch 0 --> 0.7714415788650513 at epoch 14


101it [00:25,  4.00it/s]

loss for batch 100 --> 0.8669955730438232 at epoch 14


201it [00:50,  4.01it/s]

loss for batch 200 --> 0.898356556892395 at epoch 14


205it [00:50,  4.03it/s]
1it [00:00,  3.99it/s]

loss for batch 0 --> 0.7013464570045471 at epoch 15


101it [00:25,  4.01it/s]

loss for batch 100 --> 0.721276044845581 at epoch 15


201it [00:49,  4.01it/s]

loss for batch 200 --> 0.8632165193557739 at epoch 15


205it [00:50,  4.04it/s]
1it [00:00,  3.99it/s]

loss for batch 0 --> 0.6010428667068481 at epoch 16


101it [00:25,  4.02it/s]

loss for batch 100 --> 0.6803231239318848 at epoch 16


201it [00:49,  4.02it/s]

loss for batch 200 --> 0.7550438642501831 at epoch 16


205it [00:50,  4.03it/s]
1it [00:00,  3.97it/s]

loss for batch 0 --> 0.4882310628890991 at epoch 17


101it [00:25,  4.01it/s]

loss for batch 100 --> 0.6192662715911865 at epoch 17


201it [00:49,  4.02it/s]

loss for batch 200 --> 0.6561490297317505 at epoch 17


205it [00:50,  4.03it/s]
1it [00:00,  4.04it/s]

loss for batch 0 --> 0.42731937766075134 at epoch 18


101it [00:25,  4.01it/s]

loss for batch 100 --> 0.494682639837265 at epoch 18


201it [00:49,  4.00it/s]

loss for batch 200 --> 0.593704104423523 at epoch 18


205it [00:50,  4.04it/s]
1it [00:00,  3.93it/s]

loss for batch 0 --> 0.36597058176994324 at epoch 19


101it [00:25,  4.02it/s]

loss for batch 100 --> 0.4475789964199066 at epoch 19


201it [00:49,  4.02it/s]

loss for batch 200 --> 0.49575209617614746 at epoch 19


205it [00:50,  4.04it/s]


In [25]:
print("generating 5 sample of 500 characters each")

decoder_generator.update_params(max_length=500, num_samples=5)
decoder_generator.generate()
decoder_generator.print_outputs()

generating 5 sample of 500 characters each
u, since you mercy dent you, sir?

of I thank you.

DUCHESS OF YORK:
My lord sob deservice, but not lord; and now,
Be cold Mowbray he for he standing to bewish;
And think ither lets him an a frant, and hour
Let him by nally. What with all men a!--

PETRUCHIO:
O let's stridle! There with all not come,' sweet:
All thing she is breakness to the limb,
And a gone and forcer and ro age.

FRIAR LAURENCE:
I cannot me: better was you.
This my demand you.

ROMMO:
This lord: he is is the man conservest, sir


umber gent for plaints,
Your discontenting, and grace the quality;
Your poor is good dares not and when men
Romansport the sale way; I will go, sir.

ROMEO:
A lean good this is the where is sweet lies: if the will
will Ie let me lappy thee: if an he stay, he's
But lets it noy the such a maken nentreature
Hell I cominius 'dies '' the land. Your move,
A mong and enough even so man alliant home;
And fall he stay it be so walk'd with Lord And
That he an

# Save and load the model

In [26]:
# save the model
torch.save(decoder.state_dict(), "./models/shakespeare_like_text_generator.pt")

In [27]:
n_embd = 784
vocab_size = len(vocab)
context_size = 128 # same as previously set
num_sa_heads = 16
sa_head_size = 64

test_model = Decoder(vocab_size=vocab_size, n_embd=n_embd, context_size=context_size, num_sa_heads=num_sa_heads, sa_head_size=sa_head_size)

test_model.load_state_dict(torch.load("./models/shakespeare_like_text_generator.pt"))
test_model_generator = model_generator(model=test_model, max_length=320, num_samples=1, vocab_size=vocab_size) 
test_model_generator.generate()
test_model_generator.print_outputs()


xoationed your soled that would within
To action a pitcheon, whom your mother silence!

BIANCA.

PETRUCHIO:
The shall be thee end forgot you that said:
You shall with all will in bring in Hereford.

BARNIA:
Comess me thee, gentlement, general,
Alas, make my apping you? 
FRIAR LA:
My grant sheet cut not shut
I not forg t


