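# Decoder-Only Transformer

A minimal decoder-only transformer (embedding, positional encoding, one masked self-attention layer with a residual connection, and a linear output layer). It is trained on two toy sequences so that the prompts "what is statquest" and "statquest is what" are answered with "awesome".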

## Importing the libraries

In [1]:
!pip install pytorch-lightning

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.5.1.post0-py3-none-any.whl.metadata (20 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.1.0->pytorch-lightning)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=2.1.0->pytorch-lightning)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=2.1.0->pytorch-lightning)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=2.1.0->pytorch-lightning)
  Dow

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset,DataLoader

import pytorch_lightning as L

## Creating inputs and labels for training

In [3]:
# Vocabulary: every token string gets an integer id equal to its position in this tuple.
token_to_id = {token: token_id for token_id, token in enumerate(("what", "is", "statquest", "awesome", "<EOS>"))}

In [4]:
# Inverse vocabulary (id -> token), used to turn predicted ids back into words.
id_to_token = {token_id: token for token, token_id in token_to_id.items()}

In [5]:
# Two training sequences. Each row is a prompt ("what is statquest" or
# "statquest is what"), then <EOS> marking the end of the prompt, then the
# desired response "awesome". The response token is part of the input too,
# so the model learns to continue the sequence after <EOS>.
inputs = torch.tensor([
    [token_to_id[token] for token in ("what", "is", "statquest", "<EOS>", "awesome")],
    [token_to_id[token] for token in ("statquest", "is", "what", "<EOS>", "awesome")],
])

In [6]:
# Next-token targets: labels[i][t] is the token that follows inputs[i][t].
labels = torch.tensor([
    [token_to_id[token] for token in ("is", "statquest", "<EOS>", "awesome", "<EOS>")],
    [token_to_id[token] for token in ("is", "what", "<EOS>", "awesome", "<EOS>")],
])

# The model is trained to predict the next token at every position, including
# positions inside the prompt: input "what" -> label "is", ...,
# input "<EOS>" -> label "awesome", and input "awesome" -> label "<EOS>".

In [7]:
# One (inputs[i], labels[i]) pair per training sample.
dataset = TensorDataset(inputs, labels)
# batch_size=1 and shuffle=False are DataLoader's defaults, spelled out here
# because the model processes one sequence at a time.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

## Positional Encoding

In [8]:
class PositionEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to word embeddings.

    Precomputes a (max_len, d_model) table. Even columns hold
    sin(pos / 10000^(i/d_model)) and odd columns hold the matching cos, where
    i is the even column index of the sin/cos pair. ``forward`` adds the first
    ``seq_len`` rows of that table to its input.

    Args:
        d_model: embedding dimension (2 in this toy example). Odd values are
            supported; the last column is then a sin column with no cos partner.
        max_len: longest sequence the table covers. Longer inputs raise
            ValueError.
    """

    def __init__(self, d_model=2, max_len=6):
        super().__init__()

        pe = torch.zeros(max_len, d_model)

        # Positions 0..max_len-1 as a column vector of shape (max_len, 1), so
        # it broadcasts against div_term below.
        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)
        # Even column indices 0, 2, 4, ...; each sin/cos pair shares one frequency.
        embedding_index = torch.arange(start=0, end=d_model, step=2).float()

        # 1 / 10000^(i/d_model). Note that ** binds tighter than /.
        div_term = 1 / torch.tensor(10000.0) ** (embedding_index / d_model)

        # Columns 0, 2, 4, ... (every other column starting at 0) get sin.
        pe[:, 0::2] = torch.sin(position * div_term)
        # Columns 1, 3, 5, ... get cos. With an odd d_model there is one fewer
        # odd column than even columns, so the last frequency is dropped to
        # keep the shapes aligned.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # A buffer, not a parameter: it is saved with the module and follows
        # .to(device), but it is not trained.
        self.register_buffer('pe', pe)

    def forward(self, word_embeddings):
        """Add positional encodings to ``word_embeddings``.

        Accepts (seq_len, d_model) input or batched (..., seq_len, d_model)
        input. The sequence axis is the second-to-last one, and the table
        broadcasts over any leading batch axes.

        Raises:
            ValueError: if seq_len exceeds the ``max_len`` given at construction.
        """
        seq_len = word_embeddings.size(-2)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.pe.size(0)} "
                "of the positional encoding"
            )
        return word_embeddings + self.pe[:seq_len, :]

## Attention

In [9]:
class Attention(nn.Module):
    """Single-head scaled dot-product attention with learned Q/K/V projections.

    Works on unbatched (seq_len, d_model) inputs and on batched
    (..., seq_len, d_model) inputs. The last two axes are treated as
    (tokens, features).

    Args:
        d_model: feature size of the inputs and of each projection.
    """

    def __init__(self, d_model=2):
        super().__init__()

        # Bias-free linear projections that produce queries, keys and values.
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        # Token axis and feature axis. Negative indices allow a leading batch
        # axis; for 2-D input these are the same axes as 0 and 1.
        self.row_dim = -2
        self.col_dim = -1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Args:
            encodings_for_q / _k / _v: (..., seq_len, d_model) tensors used to
                build queries, keys and values respectively.
            mask: optional bool tensor broadcastable to the (q_len, k_len)
                score matrix. True entries are blocked from attention.

        Returns:
            Attention-weighted values, with the same shape as ``encodings_for_q``.
        """
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        # Similarity of every query with every key: (..., q_len, k_len).
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # Divide by sqrt(d_k), the key dimension, as in the standard formula.
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # Use the most negative finite value of the score dtype. A literal
            # -1e9 overflows (and raises) for float16 scores.
            scaled_sims = scaled_sims.masked_fill(
                mask=mask, value=torch.finfo(scaled_sims.dtype).min
            )

        # Each row of attention weights sums to 1 over the keys.
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores


## Network

In [10]:
class DecoderOnlyTransformer(L.LightningModule):
    """Minimal decoder-only transformer for next-token prediction.

    Pipeline: token embedding -> positional encoding -> one causally masked
    self-attention layer with a residual connection -> linear layer producing
    one logit per vocabulary token at every position.

    Args:
        num_tokens: vocabulary size (embedding rows and output logits).
        d_model: embedding / model width.
        max_len: longest sequence the positional encoding supports.
        lr: learning rate for Adam. The default 0.1 is the value this class
            previously hard-coded.
    """

    def __init__(self, num_tokens=4, d_model=2, max_len=6, lr=0.1):
        super().__init__()

        self.lr = lr

        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)
        self.self_attention = Attention(d_model=d_model)
        self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)

        # CrossEntropyLoss applies log-softmax itself, so forward() returns raw logits.
        self.loss = nn.CrossEntropyLoss()

    def forward(self, token_ids):
        """Return next-token logits, one row per position in ``token_ids``.

        Expects a single unbatched sequence of token ids. The causal mask is
        sized from dim 0 of ``token_ids``.
        """
        word_embeddings = self.we(token_ids)
        position_encoded = self.pe(word_embeddings)

        seq_len = token_ids.size(dim=0)
        # Causal mask: True strictly above the diagonal, so token i cannot
        # attend to any later token j > i. It is built on the input's device so
        # masked_fill also works when the model runs on a GPU.
        mask = torch.triu(
            torch.ones((seq_len, seq_len), dtype=torch.bool, device=token_ids.device),
            diagonal=1,
        )

        self_attention_values = self.self_attention(position_encoded,
                                                    position_encoded,
                                                    position_encoded,
                                                    mask=mask)

        # Residual connection around the attention layer.
        residual_connection_values = position_encoded + self_attention_values

        fc_layer_output = self.fc_layer(residual_connection_values)

        return fc_layer_output

    def configure_optimizers(self):
        """Adam over all parameters, using the learning rate given at construction."""
        return Adam(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        """Mean next-token cross-entropy over the sequences in ``batch``.

        forward() handles one sequence at a time, so each (input, label) pair
        in the batch is scored separately and the losses are averaged. With
        batch_size=1 this is exactly the loss of that single sequence.
        """
        input_tokens, labels = batch
        losses = [self.loss(self.forward(tokens), targets)
                  for tokens, targets in zip(input_tokens, labels)]
        return torch.stack(losses).mean()

## Training

In [11]:
# Seed torch, numpy and Python's random module so weight initialisation, and
# therefore the trained model and the predictions below, are reproducible on re-run.
L.seed_everything(42)

# max_len=6 must cover the longest sequence fed to the model during
# generation (prompt plus generated tokens).
model = DecoderOnlyTransformer(num_tokens=len(token_to_id), d_model=2, max_len=6)
trainer = L.Trainer(max_epochs=30)
trainer.fit(model, train_dataloaders=dataloader)

INFO:pytorch_lightning.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name           | Type             | Params | Mode 
------------------------------------------------------------
0 | we             | Embedding        | 10     | train
1 | pe             | PositionEncoding | 0      | train
2 | self_attention | Attention        | 12     | train
3 | fc_layer       | Linear           | 15     | train
4 | loss           | CrossEntropyLoss | 0      | train
------------------------------------------------------------
37        Trainable params
0         Non-trainable params
3

Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.


## Generating a response with the trained model

In [12]:
# Greedy generation: run the prompt through the model, take the most likely
# next token from the last position's logits, append it, and repeat until the
# model emits <EOS> or the sequence reaches max_length.
model.eval()  # the network has no dropout/batch-norm, but eval mode is correct for inference

model_input = torch.tensor([token_to_id["what"],
                            token_to_id["is"],
                            token_to_id["statquest"],
                            token_to_id["<EOS>"]])
input_length = model_input.size(dim=0)

# Total sequence length (prompt plus generated tokens); must not exceed the model's max_len.
max_length = 6

with torch.no_grad():  # inference only, so don't build an autograd graph
    predictions = model(model_input)
    # predictions has one row of logits per input token. Only the last row
    # predicts the token that comes next; keepdim gives a shape-(1,) id tensor.
    predicted_id = predictions[-1, :].argmax(dim=-1, keepdim=True)
    predicted_ids = predicted_id

    for _ in range(input_length, max_length):
        if predicted_id == token_to_id["<EOS>"]:
            break

        # Feed the sequence, extended with the latest prediction, back in.
        model_input = torch.cat((model_input, predicted_id))

        predictions = model(model_input)

        predicted_id = predictions[-1, :].argmax(dim=-1, keepdim=True)
        predicted_ids = torch.cat((predicted_ids, predicted_id))

print("Predicted Tokens :\n")
for token_id in predicted_ids:
    print("\t", id_to_token[token_id.item()])

Predicted Tokens :

	 awesome
	 <EOS>
