<a href="https://www.kaggle.com/code/evelynartoria/decoder-transformer-model-from-scratch-pytorch?scriptVersionId=187731215" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
!mkdir ./models

In [2]:
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader, random_split
from tqdm import tqdm

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
generator = torch.Generator(device=device)
torch.set_default_device(device)
print(f"default device set to {device}")

default device set to cuda


In [4]:
with open("/kaggle/input/shakespeare/input.txt", 'r', encoding="utf-8") as f:
    text = f.read()

In [5]:
vocab = sorted(set(text))
vocab_size = len(vocab)

In [6]:
stoi = {c: v for v, c in enumerate(vocab)}
itos = {v: c for c, v in stoi.items()}

print(stoi["h"])
print(itos[46])

46
h


In [7]:
encode = lambda e: [stoi[ch] for ch in e]
decode = lambda d: "".join([itos[idx] for idx in d])

encoded = encode("hello how are you?")
decoded = decode(encoded)

print(encoded, decoded)

[46, 43, 50, 50, 53, 1, 46, 53, 61, 1, 39, 56, 43, 1, 63, 53, 59, 12] hello how are you?


In [8]:
context_size = 128
random_idx_tensor = torch.randperm(10000//context_size) * context_size
print(random_idx_tensor)

tensor([9344, 2176, 1024, 1280, 9088,    0, 7296, 5888, 2048, 6528, 2688, 4224,
        9728, 1152, 4096, 5504, 3712, 5248,  640, 6016, 4992, 1920, 6272, 2816,
        9600, 2304, 3328, 2944, 3968,  256, 1664, 6656, 7424, 8192, 6400, 7808,
        7680, 3200,  128,  768, 4352, 1408, 8448, 4480, 1792, 8704, 7552, 8064,
        7040,  896, 1536, 9856, 6144, 4608,  384, 7168, 3072, 5120, 8960, 3840,
        3456, 8576, 4736, 5376, 5632, 4864, 2560, 7936, 8832, 3584, 8320, 2432,
        9472, 9216, 5760, 6784, 6912,  512], device='cuda:0')


In [9]:
def make_dataset(data):
    random_idx_tensor = torch.randperm((len(data)-context_size)//context_size) * context_size
    inputs = torch.stack([data[idx:idx+context_size] for idx in random_idx_tensor])
    labels = torch.stack([data[idx+1:idx+context_size+1] for idx in random_idx_tensor])
    
    return TensorDataset(inputs.to(torch.long), labels.to(torch.long))

In [10]:
data = torch.tensor(encode(text))
dataset = make_dataset(data=data)
sample_input = dataset[0][0]
sample_label = dataset[0][1]

print(sample_input)
print(sample_label)

print(f"dataset length --> {len(dataset)} ({len(dataset) * context_size} characters), that is, about the length of text {len(text)} - context_size --> {len(text)-context_size}")

tensor([ 1, 21,  1, 58, 53, 50, 42,  1, 63, 53, 59,  1, 53, 44, 10,  0, 21,  1,
        54, 56, 39, 63,  1, 63, 53, 59,  1, 57, 58, 39, 52, 42,  1, 45, 53, 53,
        42,  1, 44, 39, 58, 46, 43, 56,  1, 58, 53,  1, 51, 43,  1, 52, 53, 61,
         6,  0, 19, 47, 60, 43,  1, 51, 43,  1, 14, 47, 39, 52, 41, 39,  1, 44,
        53, 56,  1, 51, 63,  1, 54, 39, 58, 56, 47, 51, 53, 52, 63,  8,  0,  0,
        28, 43, 42, 39, 52, 58, 10,  0, 31, 53, 44, 58,  1, 57, 53, 52,  2,  0,
        31, 47, 56,  6,  1, 40, 63,  1, 63, 53, 59, 56,  1, 50, 43, 39, 60, 43,
        10,  1], device='cuda:0')
tensor([21,  1, 58, 53, 50, 42,  1, 63, 53, 59,  1, 53, 44, 10,  0, 21,  1, 54,
        56, 39, 63,  1, 63, 53, 59,  1, 57, 58, 39, 52, 42,  1, 45, 53, 53, 42,
         1, 44, 39, 58, 46, 43, 56,  1, 58, 53,  1, 51, 43,  1, 52, 53, 61,  6,
         0, 19, 47, 60, 43,  1, 51, 43,  1, 14, 47, 39, 52, 41, 39,  1, 44, 53,
        56,  1, 51, 63,  1, 54, 39, 58, 56, 47, 51, 53, 52, 63,  8,  0,  0, 28,
      

In [11]:
train_split = int(len(dataset)*0.75)
test_split = int(len(dataset)-train_split)

train_dataset, test_dataset = random_split(dataset=dataset, lengths=[train_split, test_split], generator=generator)

In [12]:
batch_size = 32
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, generator=generator)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True, generator=generator)

In [13]:
class Head(nn.Module):
    def __init__(self, n_embd, head_size, context_size):
        super(Head, self).__init__()
        
        self.Q = nn.Linear(in_features=n_embd, out_features=head_size) # takes in BxTxC and return BxTxHead_size
        self.K = nn.Linear(in_features=n_embd, out_features=head_size)
        self.V = nn.Linear(in_features=n_embd, out_features=head_size)
        
        self.register_buffer("tril", torch.tril(torch.ones(size=(context_size, context_size))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape # batch_size by context_size by n_embd
        q = self.Q(x) # BxTxHead_size
        k = self.K(x) # BxTxHead_size
        
        wei = q @ k.transpose(-2, -1) * (C ** -0.5) # BxTxHead_size @ BxHead_sizexT --> BxTxT then divided by the square root of n_embd
        
        wei = wei.masked_fill(self.tril[:T, :T]==0, float('-inf')) # :T and :T is needed in case context is smaller than context_size
        wei = torch.softmax(wei, dim=-1)
        v = self.V(x) # BxTxHead_size
        output = wei @ v # BxTxT @ BxTxHead_size --> BxTxHead_size
        
        return output

In [14]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, n_embd, context_size, n_heads, head_size):
        super(MultiHeadedAttention, self).__init__()
        self.heads = nn.ModuleList([Head(n_embd=n_embd, head_size=head_size, context_size=context_size) for _ in range(n_heads)]) # BxTx (n_heads * head_size)
        self.projection = nn.Linear(in_features=n_heads*head_size, out_features=n_embd) # ensures the output is going to be o shape BxTxn_embd (BxTxC) so that is can go through multiple attention block
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.cat([head(x) for head in self.heads], dim=-1) # cat in the Channels dimension; output shape is BxTx (n_heads * head_size)
        x = self.projection(x)
        return  x

In [15]:
class FeedForward(nn.Module):
    def __init__(self, in_features):
        super(FeedForward, self).__init__()
        self.ffwrd_layer = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=in_features * 4), # scale by 4, according to the attention is all you need paper
            nn.ReLU(),
            nn.Linear(in_features=in_features * 4, out_features=in_features) # another projection layer
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ffwrd_layer(x)

In [16]:
class Block(nn.Module):
    def __init__(self, n_heads, head_size, n_embd, context_size):
        super(Block, self).__init__()
        self.multiheaded_self_attetion = MultiHeadedAttention(n_embd=n_embd, context_size=context_size, n_heads=n_heads, head_size=head_size) # create a multiheaded attention block; returns shape BxTx (num_heads*head_size)
        self.ffwrd = FeedForward(in_features=n_embd)
        self.layer_norm1 = nn.LayerNorm(n_embd)
        self.layer_norm2 = nn.LayerNorm(n_embd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.multiheaded_self_attetion(self.layer_norm1(x))
        x = x + self.ffwrd(self.layer_norm2(x))

        return x

In [17]:
class Decoder(nn.Module):
    def __init__(self, n_embd, context_size, vocab_size, num_sa_heads, sa_head_size):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.context_size = context_size
        self.n_embd = n_embd
        
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd) # each character from the vocab has n_embd values associated to it
        self.positional_embedding_table = nn.Embedding(context_size, n_embd) # each character position in the context has n_embd values associated to it
        
        self.attention_blocks = nn.Sequential(
            Block(n_heads=num_sa_heads, head_size=sa_head_size, context_size=self.context_size, n_embd=self.n_embd), # takes in BxTxC, calculate logits of BxTx (num_heads * head_size), then project it as BxTxC
            Block(n_heads=num_sa_heads, head_size=sa_head_size, context_size=self.context_size, n_embd=self.n_embd),
            Block(n_heads=num_sa_heads, head_size=sa_head_size, context_size=self.context_size, n_embd=self.n_embd),
            nn.LayerNorm(n_embd) # normalize the layers
        )
        

        self.lm_head = nn.Linear(in_features=n_embd, out_features=vocab_size) # (B, T, vocab_size)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T = x.shape # batch_size and context_size
        positions = torch.arange(start=0, end=T, step=1)
        
        pos_emb = self.positional_embedding_table(positions) # T x C --> in broadcasting, pytorch adds a batch dim=1
        token_emb = self.token_embedding_table(x) # B x T x C
        
        x = token_emb + pos_emb # BxTxC
        x = self.attention_blocks(x) # returns logits of shape BxTx (self.sa_head_size * self.num_sa_heads) projected to BxTxC
        x = self.lm_head(x) # BxTxvocab_size --> BxTxHead_size @ BxTxVocab_size return BxTxVocab_size

        return x.view(B*T, self.vocab_size) # easier shape to work with the labels
    
    def generate(self, starting_idx: torch.Tensor, max_length: int, debug: bool) -> torch.Tensor:
        full_text = decode([starting_idx.item()])
        context = starting_idx
        
        for _ in range(max_length):
            context = context[:, -self.context_size:] # make sure the context is of size context_size
            
            if debug:
                print(f"predicting on context: {decode(context[0].tolist())}")
            
            logits = self(context) # B*T x vocab_size --> 1*2 x vocab_size
            logits = logits[-1, :].view(1, self.vocab_size) # only take the prediction for the last character
            percents = torch.softmax(logits, dim=1) # 1*2xvocab_size
            pred = torch.multinomial(percents, num_samples=1) 
            full_text += decode(pred.tolist()[0])
            context = torch.cat([context, pred], dim=1) # add to the context dimension instead of the batch dim
            
        return full_text


In [18]:
class model_generator:
    def __init__(self, model: object, max_length: int, num_samples: int, vocab_size: int):
        self.model = model
        self.max_length = max_length
        self.num_samples = num_samples
        self.vocab_size = vocab_size
        
        self.last_output = ""
        
        self.params_dict = {
            "model": model,
            "max_length": max_length,
            "num_samples": num_samples,
            "previous_outputs": []
        }
    
    @torch.no_grad
    def generate(self, starting_char: str = None, clear_outputs: bool = True, debug: bool = False):
        self.model.eval()
        
        if clear_outputs:
            self.clear_ouptuts()
            
        if starting_char is None:
            starting_char = decode([torch.randint(0, vocab_size, (1,)).item()])
            
        for _ in range(self.num_samples):
            starting_idx = torch.tensor(encode(starting_char), dtype=torch.long).view(1, 1)
            output = self.model.generate(starting_idx=starting_idx, max_length=self.max_length, debug=debug)
            self.params_dict["previous_outputs"].append(output)
            self.last_output = output
    
    def update_params(self, model: object = None, max_length: int = None, num_samples: int = None, clear_outputs: bool = None):
        if clear_outputs:
            self.clear_outputs()
            
        updated_dict = {
            "model": model,
            "max_length": max_length,
            "num_samples": num_samples
        }
        
        for attribute, value in updated_dict.items():
            if value is not None:
                self.params_dict[attribute] = value
                setattr(self, attribute, value)
    
    def clear_ouptuts(self):
        self.params_dict["previous_outputs"] = []
        self.last_output = ""
        
    def print_outputs(self, last: bool = None):
        if last:
            print(self.last_output)
        else:
            for output in self.params_dict["previous_outputs"]:
                print(f"{output}\n\n")

In [19]:
n_embd = 784
vocab_size = len(vocab)
context_size = 128 # same as previously set
num_sa_heads = 16
sa_head_size = 64

decoder = Decoder(vocab_size=vocab_size, n_embd=n_embd, context_size=context_size, num_sa_heads=num_sa_heads, sa_head_size=sa_head_size)
optimizer = torch.optim.Adam(params=decoder.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

In [20]:
decoder_generator = model_generator(model=decoder, max_length=32, num_samples=1, vocab_size=vocab_size)

In [21]:
#decoder_generator.update_params(max_length=100, num_samples=1)
decoder_generator.generate()
decoder_generator.print_outputs()

rnKPkvRm$AUSplMJ?IrCuBLpXBK
gZO?q




In [22]:
decoder.eval()
test_loss_fn = nn.CrossEntropyLoss()
batch_sample_inputs, batch_sample_labels = next(iter(train_dataloader))
with torch.inference_mode():
    #print(batch_sample_inputs)
    #print(batch_sample_labels)
    logits = decoder(batch_sample_inputs)
    labels = batch_sample_labels.view(-1) # turns the label shape into a B*T
    #print(logits) # 4 batches of 8 characters each, the model is trying to predict the next sequence
    #print(labels)
    loss = test_loss_fn(logits, batch_sample_labels.view(-1))
    print(loss)

tensor(4.3425, device='cuda:0')


In [23]:
def train_model(model, dataloader, loss_fn, optimizer, epochs):
    model.train()
    
    for epoch in range(epochs):
        for batch, (X, y) in tqdm(enumerate(dataloader)):
            logits = model(X) # shape of B*T x vocab_size
            labels = y.view(-1) # shape of B*T --> each character has it's own prediction
            loss = loss_fn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if batch % 100 == 0:
                print(f"loss for batch {batch} --> {loss} at epoch {epoch}")

In [24]:
train_model(model=decoder, dataloader=train_dataloader, loss_fn=loss_fn, optimizer=optimizer, epochs=20)

1it [00:00,  2.59it/s]

loss for batch 0 --> 4.361568927764893 at epoch 0


101it [00:23,  4.33it/s]

loss for batch 100 --> 2.4810991287231445 at epoch 0


201it [00:46,  4.14it/s]

loss for batch 200 --> 2.20521879196167 at epoch 0


205it [00:47,  4.33it/s]
1it [00:00,  4.14it/s]

loss for batch 0 --> 2.21132493019104 at epoch 1


101it [00:24,  4.03it/s]

loss for batch 100 --> 1.9597002267837524 at epoch 1


201it [00:50,  3.80it/s]

loss for batch 200 --> 1.8027087450027466 at epoch 1


205it [00:51,  4.01it/s]
1it [00:00,  3.87it/s]

loss for batch 0 --> 1.7473392486572266 at epoch 2


101it [00:27,  3.69it/s]

loss for batch 100 --> 1.735757827758789 at epoch 2


201it [00:53,  3.80it/s]

loss for batch 200 --> 1.6752307415008545 at epoch 2


205it [00:54,  3.75it/s]
1it [00:00,  3.87it/s]

loss for batch 0 --> 1.557845115661621 at epoch 3


101it [00:26,  3.79it/s]

loss for batch 100 --> 1.5470411777496338 at epoch 3


201it [00:52,  3.75it/s]

loss for batch 200 --> 1.6084872484207153 at epoch 3


205it [00:53,  3.82it/s]
1it [00:00,  3.75it/s]

loss for batch 0 --> 1.5185537338256836 at epoch 4


101it [00:26,  3.79it/s]

loss for batch 100 --> 1.4281761646270752 at epoch 4


201it [00:53,  3.77it/s]

loss for batch 200 --> 1.4587103128433228 at epoch 4


205it [00:53,  3.80it/s]
1it [00:00,  3.81it/s]

loss for batch 0 --> 1.4311699867248535 at epoch 5


101it [00:26,  3.79it/s]

loss for batch 100 --> 1.4260568618774414 at epoch 5


201it [00:52,  3.77it/s]

loss for batch 200 --> 1.415411114692688 at epoch 5


205it [00:53,  3.81it/s]
1it [00:00,  3.83it/s]

loss for batch 0 --> 1.3092536926269531 at epoch 6


101it [00:26,  3.76it/s]

loss for batch 100 --> 1.4066685438156128 at epoch 6


201it [00:53,  3.76it/s]

loss for batch 200 --> 1.2929439544677734 at epoch 6


205it [00:54,  3.79it/s]
1it [00:00,  3.84it/s]

loss for batch 0 --> 1.2929799556732178 at epoch 7


101it [00:26,  3.78it/s]

loss for batch 100 --> 1.2788715362548828 at epoch 7


201it [00:53,  3.78it/s]

loss for batch 200 --> 1.2940754890441895 at epoch 7


205it [00:53,  3.80it/s]
1it [00:00,  3.85it/s]

loss for batch 0 --> 1.1887211799621582 at epoch 8


101it [00:26,  3.78it/s]

loss for batch 100 --> 1.2219927310943604 at epoch 8


201it [00:52,  3.78it/s]

loss for batch 200 --> 1.2758433818817139 at epoch 8


205it [00:53,  3.81it/s]
1it [00:00,  3.91it/s]

loss for batch 0 --> 1.119668960571289 at epoch 9


101it [00:26,  3.77it/s]

loss for batch 100 --> 1.1692415475845337 at epoch 9


201it [00:52,  3.76it/s]

loss for batch 200 --> 1.2558178901672363 at epoch 9


205it [00:53,  3.81it/s]
1it [00:00,  3.84it/s]

loss for batch 0 --> 1.061602234840393 at epoch 10


101it [00:26,  3.78it/s]

loss for batch 100 --> 1.112863302230835 at epoch 10


201it [00:52,  3.77it/s]

loss for batch 200 --> 1.1433919668197632 at epoch 10


205it [00:53,  3.81it/s]
1it [00:00,  3.81it/s]

loss for batch 0 --> 1.029674768447876 at epoch 11


101it [00:26,  3.80it/s]

loss for batch 100 --> 1.0813753604888916 at epoch 11


201it [00:52,  3.78it/s]

loss for batch 200 --> 1.1094677448272705 at epoch 11


205it [00:53,  3.81it/s]
1it [00:00,  3.94it/s]

loss for batch 0 --> 0.920134425163269 at epoch 12


101it [00:26,  3.77it/s]

loss for batch 100 --> 1.0397216081619263 at epoch 12


201it [00:52,  3.79it/s]

loss for batch 200 --> 1.0370795726776123 at epoch 12


205it [00:53,  3.81it/s]
1it [00:00,  3.80it/s]

loss for batch 0 --> 0.8501040935516357 at epoch 13


101it [00:26,  3.77it/s]

loss for batch 100 --> 0.9073486924171448 at epoch 13


201it [00:52,  3.78it/s]

loss for batch 200 --> 0.984785258769989 at epoch 13


205it [00:53,  3.81it/s]
1it [00:00,  3.87it/s]

loss for batch 0 --> 0.7806637287139893 at epoch 14


101it [00:26,  3.79it/s]

loss for batch 100 --> 0.8624745011329651 at epoch 14


201it [00:52,  3.78it/s]

loss for batch 200 --> 0.855772852897644 at epoch 14


205it [00:53,  3.81it/s]
1it [00:00,  3.81it/s]

loss for batch 0 --> 0.6865339875221252 at epoch 15


101it [00:26,  3.78it/s]

loss for batch 100 --> 0.7341650128364563 at epoch 15


201it [00:52,  3.79it/s]

loss for batch 200 --> 0.8220940232276917 at epoch 15


205it [00:53,  3.81it/s]
1it [00:00,  3.90it/s]

loss for batch 0 --> 0.5495465397834778 at epoch 16


101it [00:26,  3.79it/s]

loss for batch 100 --> 0.6301834583282471 at epoch 16


201it [00:52,  3.78it/s]

loss for batch 200 --> 0.7209587097167969 at epoch 16


205it [00:53,  3.81it/s]
1it [00:00,  3.89it/s]

loss for batch 0 --> 0.45450475811958313 at epoch 17


101it [00:26,  3.78it/s]

loss for batch 100 --> 0.556039571762085 at epoch 17


201it [00:52,  3.79it/s]

loss for batch 200 --> 0.6120920181274414 at epoch 17


205it [00:53,  3.81it/s]
1it [00:00,  3.91it/s]

loss for batch 0 --> 0.3820245563983917 at epoch 18


101it [00:26,  3.77it/s]

loss for batch 100 --> 0.48824411630630493 at epoch 18


201it [00:52,  3.78it/s]

loss for batch 200 --> 0.5217015147209167 at epoch 18


205it [00:53,  3.81it/s]
1it [00:00,  3.81it/s]

loss for batch 0 --> 0.35087838768959045 at epoch 19


101it [00:26,  3.79it/s]

loss for batch 100 --> 0.3872120976448059 at epoch 19


201it [00:52,  3.79it/s]

loss for batch 200 --> 0.459830105304718 at epoch 19


205it [00:53,  3.81it/s]


In [25]:
print("generating 5 sample of 500 characters each")

decoder_generator.update_params(max_length=500, num_samples=5)
decoder_generator.generate()
decoder_generator.print_outputs()

generating 5 sample of 500 characters each
And I wash atting wart confest; thou sway one
It is canno arms of hath been such a jevy
Too slew that pay your mants of happy soul.

CLARENCE:
That breat not une is when I can by him ere
To come the buried to see him I tall of York
Whostly accountryman: eyest, I have hearts,
I'll my vice and a heart's name is land all.

ARIVLLET:
There is to see it.

KING RICHARD III:
Then lame the mirth all my largarden's lady
And take a pinghambiar of moress' tond.
I, in that is I than seizard from me, and then


AR-ICLARD Io:
Why comes this?

MARCIUS:
Are you done for me;
And, by therefore come, my lord, which not make
To make your shield mellast upon is of me.

CLARENCE:

QUEEN MARGARET:
Then I, I can I darly in that die, of heavenment
That I charge and be my eyes together.

GLOUCESTER:
Thy widow as I instrument.

LADY ANNE:
I thought meet then shall have heart most go.

GLOUCESTER:
I do not seek bite merift.

LADY GLO:
Then flow see there mean there remem

# Save and load the model

In [26]:
# save the model
torch.save(decoder.state_dict(), "./models/shakespeare_like_text_generator.pt")

In [27]:
n_embd = 784
vocab_size = len(vocab)
context_size = 128 # same as previously set
num_sa_heads = 16
sa_head_size = 64

test_model = Decoder(vocab_size=vocab_size, n_embd=n_embd, context_size=context_size, num_sa_heads=num_sa_heads, sa_head_size=sa_head_size)

test_model.load_state_dict(torch.load("./models/shakespeare_like_text_generator.pt"))
test_model_generator = model_generator(model=test_model, max_length=320, num_samples=1, vocab_size=vocab_size) 
test_model_generator.generate()
test_model_generator.print_outputs()


CESTER:
I go; and if you plant deserved
The Katharina, did I know not him princ.

KING EDWARD IV:
This off that he be offence, that approached their
Thou shalt the tertainmed traitor flatter.

RICHARD:
Marriumph, help in earth, heaving thee, my lord
Death is heaven; for what my would languance
And Baptista'e maint blood


