<a href="https://www.kaggle.com/code/evelynartoria/decoder-transformer-model-from-scratch-pytorch?scriptVersionId=187728602" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
!mkdir ./models

In [2]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, random_split
from tqdm import tqdm

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = torch.Generator(device=device)
torch.set_default_device(device)
print(f"default device set to {device}")

default device set to cuda


In [4]:
with open("/kaggle/input/shakespeare/input.txt", 'r', encoding="utf-8") as f:
    text = f.read()

In [5]:
vocab = sorted(set(text))
vocab_size = len(vocab)

In [6]:
stoi = {c: v for v, c in enumerate(vocab)}
itos = {v: c for c, v in stoi.items()}

print(stoi["h"])
print(itos[46])

46
h


In [7]:
encode = lambda e: [stoi[ch] for ch in e]
decode = lambda d: "".join([itos[idx] for idx in d])

encoded = encode("hello how are you?")
decoded = decode(encoded)

print(encoded, decoded)

[46, 43, 50, 50, 53, 1, 46, 53, 61, 1, 39, 56, 43, 1, 63, 53, 59, 12] hello how are you?


In [8]:
context_size = 128
random_idx_tensor = torch.randperm(10000//context_size) * context_size
print(random_idx_tensor)

tensor([ 256, 7168, 2560, 6272, 6784, 1152, 7296, 1536, 4864, 4480, 9856, 6144,
         128, 9728, 2176, 9344, 8576, 5632, 8960, 9216, 6656, 4736, 3072, 5120,
         640, 1408, 5248, 6400, 8448, 3584,  384, 4992,  896, 6528, 2432, 7936,
        6016, 2048, 3456, 7808, 6912, 4608, 9600, 3840, 4224, 1920, 2944, 5888,
        2688, 1280,    0, 2816, 3968, 8192, 1792, 5504, 8704, 7040, 8320, 5376,
        4096, 3200, 5760, 9088, 2304, 7424, 1664, 3712, 9472, 7552, 3328, 8064,
        7680,  512,  768, 8832, 4352, 1024], device='cuda:0')


In [9]:
def make_dataset(data):
    random_idx_tensor = torch.randperm((len(data)-context_size)//context_size) * context_size
    inputs = torch.stack([data[idx:idx+context_size] for idx in random_idx_tensor])
    labels = torch.stack([data[idx+1:idx+context_size+1] for idx in random_idx_tensor])
    
    return TensorDataset(inputs.to(torch.long), labels.to(torch.long))

In [10]:
data = torch.tensor(encode(text))
dataset = make_dataset(data=data)
sample_input = dataset[0][0]
sample_label = dataset[0][1]

print(sample_input)
print(sample_label)

print(f"dataset length --> {len(dataset)} ({len(dataset) * context_size} characters), that is, about the length of text {len(text)} - context_size --> {len(text)-context_size}")

tensor([56, 42, 57,  6,  1, 61, 46, 43, 52,  1, 63, 53, 59,  1, 57, 46, 39, 50,
        50,  1, 49, 52, 53, 61,  7,  7, 39, 57,  1, 47, 52,  1, 58, 46, 47, 57,
         1, 56, 39, 45, 43,  6,  0, 28, 56, 53, 60, 53, 49, 43, 42,  1, 40, 63,
         1, 46, 47, 51,  6,  1, 63, 53, 59,  1, 41, 39, 52, 52, 53, 58,  7,  7,
        58, 46, 43,  1, 45, 56, 43, 39, 58,  1, 42, 39, 52, 45, 43, 56,  0, 35,
        46, 47, 41, 46,  1, 58, 46, 47, 57,  1, 51, 39, 52,  5, 57,  1, 50, 47,
        44, 43,  1, 42, 47, 42,  1, 53, 61, 43,  1, 63, 53, 59,  6,  1, 63, 53,
        59,  5], device='cuda:0')
tensor([42, 57,  6,  1, 61, 46, 43, 52,  1, 63, 53, 59,  1, 57, 46, 39, 50, 50,
         1, 49, 52, 53, 61,  7,  7, 39, 57,  1, 47, 52,  1, 58, 46, 47, 57,  1,
        56, 39, 45, 43,  6,  0, 28, 56, 53, 60, 53, 49, 43, 42,  1, 40, 63,  1,
        46, 47, 51,  6,  1, 63, 53, 59,  1, 41, 39, 52, 52, 53, 58,  7,  7, 58,
        46, 43,  1, 45, 56, 43, 39, 58,  1, 42, 39, 52, 45, 43, 56,  0, 35, 46,
      

In [11]:
train_split = int(len(dataset)*0.75)
test_split = int(len(dataset)-train_split)

train_dataset, test_dataset = random_split(dataset=dataset, lengths=[train_split, test_split], generator=generator)

In [12]:
batch_size = 32
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, generator=generator)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True, generator=generator)

In [13]:
class Head(nn.Module):
    def __init__(self, n_embd, head_size, context_size):
        super(Head, self).__init__()
        
        self.Q = nn.Linear(in_features=n_embd, out_features=head_size) # takes in BxTxC and return BxTxHead_size
        self.K = nn.Linear(in_features=n_embd, out_features=head_size)
        self.V = nn.Linear(in_features=n_embd, out_features=head_size)
        
        self.register_buffer("tril", torch.tril(torch.ones(size=(context_size, context_size))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape # batch_size by context_size by n_embd
        q = self.Q(x) # BxTxHead_size
        k = self.K(x) # BxTxHead_size
        
        wei = q @ k.transpose(-2, -1) * (C ** -0.5) # BxTxHead_size @ BxHead_sizexT --> BxTxT then divided by the square root of n_embd
        
        wei = wei.masked_fill(self.tril[:T, :T]==0, float('-inf')) # :T and :T is needed in case context is smaller than context_size
        wei = torch.softmax(wei, dim=-1)
        v = self.V(x) # BxTxHead_size
        output = wei @ v # BxTxT @ BxTxHead_size --> BxTxHead_size
        
        return output

In [14]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, n_embd, context_size, n_heads, head_size):
        super(MultiHeadedAttention, self).__init__()
        self.heads = nn.ModuleList([Head(n_embd=n_embd, head_size=head_size, context_size=context_size) for _ in range(n_heads)]) # BxTx (n_heads * head_size)
        self.projection = nn.Linear(in_features=n_heads*head_size, out_features=n_embd) # ensures the output is going to be o shape BxTxn_embd (BxTxC) so that is can go through multiple attention block
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.cat([head(x) for head in self.heads], dim=-1) # cat in the Channels dimension; output shape is BxTx (n_heads * head_size)
        x = self.projection(x)
        return  x

In [15]:
class FeedForward(nn.Module):
    def __init__(self, in_features):
        super(FeedForward, self).__init__()
        self.ffwrd_layer = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=in_features * 4), # scale by 4, according to the attention is all you need paper
            nn.ReLU(),
            nn.Linear(in_features=in_features * 4, out_features=in_features) # another projection layer
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ffwrd_layer(x)

In [16]:
class Block(nn.Module):
    def __init__(self, n_heads, head_size, n_embd, context_size):
        super(Block, self).__init__()
        self.multiheaded_self_attetion = MultiHeadedAttention(n_embd=n_embd, context_size=context_size, n_heads=n_heads, head_size=head_size) # create a multiheaded attention block; returns shape BxTx (num_heads*head_size)
        self.ffwrd = FeedForward(in_features=n_embd)
        self.layer_norm1 = nn.LayerNorm(n_embd)
        self.layer_norm2 = nn.LayerNorm(n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.multiheaded_self_attetion(self.layer_norm1(x))
        x = x + self.ffwrd(self.layer_norm2(x))

        return x

In [17]:
class Decoder(nn.Module):
    def __init__(self, n_embd, context_size, vocab_size, num_sa_heads, sa_head_size):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.context_size = context_size
        self.n_embd = n_embd
        
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd) # each character from the vocab has n_embd values associated to it
        self.positional_embedding_table = nn.Embedding(context_size, n_embd) # each character position in the context has n_embd values associated to it
        
        self.attention_blocks = nn.Sequential(
            Block(n_heads=num_sa_heads, head_size=sa_head_size, context_size=self.context_size, n_embd=self.n_embd), # takes in BxTxC, calculate logits of BxTx (num_heads * head_size), then project it as BxTxC
            Block(n_heads=num_sa_heads, head_size=sa_head_size, context_size=self.context_size, n_embd=self.n_embd),
            Block(n_heads=num_sa_heads, head_size=sa_head_size, context_size=self.context_size, n_embd=self.n_embd),
            nn.LayerNorm(n_embd) # normalize the layers
        )
        

        self.lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size) # (B, T, vocab_size)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T = x.shape # batch_size and context_size
        positions = torch.arange(start=0, end=T, step=1)
        
        pos_emb = self.positional_embedding_table(positions) # T x C --> in broadcasting, pytorch adds a batch dim=1
        token_emb = self.token_embedding_table(x) # B x T x C
        
        x = token_emb + pos_emb # BxTxC
        x = self.attention_blocks(x) # returns logits of shape BxTx (self.sa_head_size * self.num_sa_heads) projected to BxTxC
        x = self.lm_head(x) # BxTxvocab_size --> BxTxHead_size @ BxTxVocab_size return BxTxVocab_size

        return x.view(B*T, self.vocab_size) # easier shape to work with the labels
    
    def generate(self, starting_idx: torch.Tensor, max_length: int, debug: bool) -> torch.Tensor:
        full_text = decode([starting_idx.item()])
        context = starting_idx
        
        for _ in range(max_length):
            context = context[:, -self.context_size:] # make sure the context is of size context_size
            
            if debug:
                print(f"predicting on context: {decode(context[0].tolist())}")
            
            logits = self(context) # B*T x vocab_size --> 1*2 x vocab_size
            logits = logits[-1, :].view(1, self.vocab_size) # only take the prediction for the last character
            percents = torch.softmax(logits, dim=1) # 1*2xvocab_size
            pred = torch.multinomial(percents, num_samples=1) 
            full_text += decode(pred.tolist()[0])
            context = torch.cat([context, pred], dim=1) # add to the context dimension instead of the batch dim
            
        return full_text


In [18]:
class model_generator:
    def __init__(self, model: object, max_length: int, num_samples: int, vocab_size: int):
        self.model = model
        self.max_length = max_length
        self.num_samples = num_samples
        self.vocab_size = vocab_size
        
        self.last_output = ""
        
        self.params_dict = {
            "model": model,
            "max_length": max_length,
            "num_samples": num_samples,
            "previous_outputs": []
        }
    
    @torch.no_grad
    def generate(self, starting_char: str = None, clear_outputs: bool = True, debug: bool = False):
        self.model.eval()
        
        if clear_outputs:
            self.clear_ouptuts()
            
        if starting_char is None:
            starting_char = decode([torch.randint(0, vocab_size, (1,)).item()])
            
        for _ in range(self.num_samples):
            starting_idx = torch.tensor(encode(starting_char), dtype=torch.long).view(1, 1)
            output = self.model.generate(starting_idx=starting_idx, max_length=self.max_length, debug=debug)
            self.params_dict["previous_outputs"].append(output)
            self.last_output = output
    
    def update_params(self, model: object = None, max_length: int = None, num_samples: int = None, clear_outputs: bool = None):
        if clear_outputs:
            self.clear_outputs()
            
        updated_dict = {
            "model": model,
            "max_length": max_length,
            "num_samples": num_samples
        }
        
        for attribute, value in updated_dict.items():
            if value is not None:
                self.params_dict[attribute] = value
                setattr(self, attribute, value)
    
    def clear_ouptuts(self):
        self.params_dict["previous_outputs"] = []
        self.last_output = ""
        
    def print_outputs(self, last: bool = None):
        if last:
            print(self.last_output)
        else:
            for output in self.params_dict["previous_outputs"]:
                print(f"{output}\n\n")

In [19]:
n_embd = 784
vocab_size = len(vocab)
context_size = 128 # same as previously set
num_sa_heads = 16
sa_head_size = 64

decoder = Decoder(vocab_size=vocab_size, n_embd=n_embd, context_size=context_size, num_sa_heads=num_sa_heads, sa_head_size=sa_head_size)
optimizer = torch.optim.Adam(params=decoder.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

In [20]:
decoder_generator = model_generator(model=decoder, max_length=32, num_samples=1, vocab_size=vocab_size)

In [21]:
#decoder_generator.update_params(max_length=100, num_samples=1)
decoder_generator.generate()
decoder_generator.print_outputs()

RcJq!rTUw:RyhyBp$L ucAUg;K!qFB,N&




In [22]:
decoder.eval()
test_loss_fn = nn.CrossEntropyLoss()
batch_sample_inputs, batch_sample_labels = next(iter(train_dataloader))
with torch.inference_mode():
    #print(batch_sample_inputs)
    #print(batch_sample_labels)
    logits = decoder(batch_sample_inputs)
    labels = batch_sample_labels.view(-1) # turns the label shape into a B*T
    #print(logits) # 4 batches of 8 characters each, the model is trying to predict the next sequence
    #print(labels)
    loss = test_loss_fn(logits, batch_sample_labels.view(-1))
    print(loss)

tensor(4.3892, device='cuda:0')


In [23]:
def train_model(model, dataloader, loss_fn, optimizer, epochs):
    model.train()
    
    for epoch in range(epochs):
        for batch, (X, y) in tqdm(enumerate(dataloader)):
            logits = model(X) # shape of B*T x vocab_size
            labels = y.view(-1) # shape of B*T --> each character has it's own prediction
            loss = loss_fn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if batch % 100 == 0:
                print(f"loss for batch {batch} --> {loss} at epoch {epoch}")

In [24]:
train_model(model=decoder, dataloader=train_dataloader, loss_fn=loss_fn, optimizer=optimizer, epochs=20)

1it [00:00,  2.52it/s]

loss for batch 0 --> 4.394429683685303 at epoch 0


101it [00:23,  4.19it/s]

loss for batch 100 --> 2.4578640460968018 at epoch 0


201it [00:48,  4.08it/s]

loss for batch 200 --> 2.235759735107422 at epoch 0


205it [00:48,  4.19it/s]
1it [00:00,  4.07it/s]

loss for batch 0 --> 2.2955939769744873 at epoch 1


101it [00:25,  3.91it/s]

loss for batch 100 --> 2.000166177749634 at epoch 1


201it [00:51,  3.90it/s]

loss for batch 200 --> 1.7390151023864746 at epoch 1


205it [00:51,  3.95it/s]
1it [00:00,  3.99it/s]

loss for batch 0 --> 1.771932601928711 at epoch 2


101it [00:25,  3.95it/s]

loss for batch 100 --> 1.700164556503296 at epoch 2


201it [00:50,  3.95it/s]

loss for batch 200 --> 1.6728771924972534 at epoch 2


205it [00:51,  3.97it/s]
1it [00:00,  3.99it/s]

loss for batch 0 --> 1.5364112854003906 at epoch 3


101it [00:25,  3.91it/s]

loss for batch 100 --> 1.5434497594833374 at epoch 3


201it [00:51,  3.92it/s]

loss for batch 200 --> 1.586573600769043 at epoch 3


205it [00:51,  3.94it/s]
1it [00:00,  3.93it/s]

loss for batch 0 --> 1.4715495109558105 at epoch 4


101it [00:25,  3.94it/s]

loss for batch 100 --> 1.5315405130386353 at epoch 4


201it [00:51,  3.91it/s]

loss for batch 200 --> 1.4816144704818726 at epoch 4


205it [00:51,  3.95it/s]
1it [00:00,  3.95it/s]

loss for batch 0 --> 1.3621184825897217 at epoch 5


101it [00:25,  3.93it/s]

loss for batch 100 --> 1.4543262720108032 at epoch 5


201it [00:51,  3.95it/s]

loss for batch 200 --> 1.4408073425292969 at epoch 5


205it [00:51,  3.95it/s]
1it [00:00,  3.99it/s]

loss for batch 0 --> 1.3393572568893433 at epoch 6


101it [00:25,  3.95it/s]

loss for batch 100 --> 1.3958181142807007 at epoch 6


201it [00:50,  3.93it/s]

loss for batch 200 --> 1.3905450105667114 at epoch 6


205it [00:51,  3.96it/s]
1it [00:00,  4.05it/s]

loss for batch 0 --> 1.2626875638961792 at epoch 7


101it [00:25,  3.94it/s]

loss for batch 100 --> 1.3121176958084106 at epoch 7


201it [00:50,  3.94it/s]

loss for batch 200 --> 1.3210190534591675 at epoch 7


205it [00:51,  3.96it/s]
1it [00:00,  3.99it/s]

loss for batch 0 --> 1.2767424583435059 at epoch 8


101it [00:25,  3.94it/s]

loss for batch 100 --> 1.2819160223007202 at epoch 8


201it [00:50,  3.95it/s]

loss for batch 200 --> 1.2620248794555664 at epoch 8


205it [00:51,  3.96it/s]
1it [00:00,  3.95it/s]

loss for batch 0 --> 1.1620242595672607 at epoch 9


101it [00:25,  3.94it/s]

loss for batch 100 --> 1.1902016401290894 at epoch 9


201it [00:50,  3.92it/s]

loss for batch 200 --> 1.2189648151397705 at epoch 9


205it [00:51,  3.96it/s]
1it [00:00,  4.00it/s]

loss for batch 0 --> 1.0956398248672485 at epoch 10


101it [00:25,  3.92it/s]

loss for batch 100 --> 1.1308989524841309 at epoch 10


201it [00:50,  3.93it/s]

loss for batch 200 --> 1.2420809268951416 at epoch 10


205it [00:51,  3.96it/s]
1it [00:00,  3.97it/s]

loss for batch 0 --> 1.0267654657363892 at epoch 11


101it [00:25,  3.93it/s]

loss for batch 100 --> 1.097220778465271 at epoch 11


201it [00:50,  3.94it/s]

loss for batch 200 --> 1.1269652843475342 at epoch 11


205it [00:51,  3.96it/s]
1it [00:00,  3.94it/s]

loss for batch 0 --> 0.9361605048179626 at epoch 12


101it [00:25,  3.93it/s]

loss for batch 100 --> 1.0088902711868286 at epoch 12


201it [00:50,  3.95it/s]

loss for batch 200 --> 1.1004530191421509 at epoch 12


205it [00:51,  3.96it/s]
1it [00:00,  3.98it/s]

loss for batch 0 --> 0.8747413158416748 at epoch 13


101it [00:25,  3.96it/s]

loss for batch 100 --> 0.9314310550689697 at epoch 13


201it [00:50,  3.94it/s]

loss for batch 200 --> 0.9656411409378052 at epoch 13


205it [00:51,  3.97it/s]
1it [00:00,  3.97it/s]

loss for batch 0 --> 0.7812402248382568 at epoch 14


101it [00:25,  3.96it/s]

loss for batch 100 --> 0.8872048854827881 at epoch 14


201it [00:50,  3.95it/s]

loss for batch 200 --> 0.8955422043800354 at epoch 14


205it [00:51,  3.97it/s]
1it [00:00,  3.89it/s]

loss for batch 0 --> 0.7058507204055786 at epoch 15


101it [00:25,  3.96it/s]

loss for batch 100 --> 0.7878478169441223 at epoch 15


201it [00:50,  3.94it/s]

loss for batch 200 --> 0.8347071409225464 at epoch 15


205it [00:51,  3.97it/s]
1it [00:00,  3.90it/s]

loss for batch 0 --> 0.5803107619285583 at epoch 16


101it [00:25,  3.94it/s]

loss for batch 100 --> 0.6996538639068604 at epoch 16


201it [00:50,  3.94it/s]

loss for batch 200 --> 0.7376101613044739 at epoch 16


205it [00:51,  3.96it/s]
1it [00:00,  4.00it/s]

loss for batch 0 --> 0.4922633767127991 at epoch 17


101it [00:25,  3.95it/s]

loss for batch 100 --> 0.5820157527923584 at epoch 17


201it [00:50,  3.94it/s]

loss for batch 200 --> 0.6707831025123596 at epoch 17


205it [00:51,  3.98it/s]
1it [00:00,  3.98it/s]

loss for batch 0 --> 0.4281761348247528 at epoch 18


101it [00:25,  3.97it/s]

loss for batch 100 --> 0.5230690240859985 at epoch 18


201it [00:50,  3.94it/s]

loss for batch 200 --> 0.5732739567756653 at epoch 18


205it [00:51,  3.97it/s]
1it [00:00,  3.99it/s]

loss for batch 0 --> 0.3629069924354553 at epoch 19


101it [00:25,  3.94it/s]

loss for batch 100 --> 0.4555146396160126 at epoch 19


201it [00:50,  3.96it/s]

loss for batch 200 --> 0.49032214283943176 at epoch 19


205it [00:51,  3.97it/s]


In [25]:
print("generating 5 sample of 500 characters each")

decoder_generator.update_params(max_length=500, num_samples=5)
decoder_generator.generate()
decoder_generator.print_outputs()

generating 5 sample of 500 characters each
y to this of rance,
Take the rose of it of our leade. Now, weak, it ever
For I have to me, that continue it now
He thus smay's to be to contine; and he, more play, sir,
He was thee to Rome, here strongly this roam;
I will come to this come to her:
Nay, or be there.

MERCUTITUS:
I hought or well tell me.

CORIONAS:
A noble lamb thought something was to me
doth us; but the fow of that 'tis so dazen.

SICINIWSCe:
Let'l not me:
Nay, be the oath! what with absence: I wannot do't,
What nay, hads, been 


y brother, forth thee,
What defenderstands sulment to swear John.
Three, come, know, by Saint Planta! Whither blest!
But plucky, gently, with a fatal holy--
Stray of Capulet hath
And be tour that an enemy!

GREMIO:
So this temper.

GLOUCESTER:

BUCH:
Nay, with you not what?

ISABELLA:
What's a methort?

ARINCELO:
Sweet or wear no!

ANGELO:
He hath and all it?

ISABELLA:
Grandam, hand not to seen the eaus is worth
an excellen of that traitor.

ISABEL

# Save and load the model

In [26]:
# save the model
torch.save(decoder.state_dict(), "./models/shakespeare_like_text_generator.pt")

In [27]:
n_embd = 784
vocab_size = len(vocab)
context_size = 128 # same as previously set
num_sa_heads = 16
sa_head_size = 64

test_model = Decoder(vocab_size=vocab_size, n_embd=n_embd, context_size=context_size, num_sa_heads=num_sa_heads, sa_head_size=sa_head_size)

test_model.load_state_dict(torch.load("./models/shakespeare_like_text_generator.pt"))
test_model_generator = model_generator(model=test_model, max_length=320, num_samples=1, vocab_size=vocab_size) 
test_model_generator.generate()
test_model_generator.print_outputs()


& never the grost:
A man age that the bid you cannot joy
To greature of the face that God's all that guess
Which am untimely in all together:
Come, concluded to the meats,
If would be tumble-buried, with angled boy.

CLIFFORD:
O, the gold!
I told Barest thou liest whereof that makest
Addest thou understance?

PETRUCHIO:


