* The objective of this assignment is to build the **Decoder** part of the Transformer architecture.
* We will use the **PyTorch** framework to implement the following components:
  * Decoder Layer that contains
    * Multi-Head Masked Attention (MHMA) Module
    * Multi-Head Cross Attention (MHCA) Module
    * Position-wise Feed Forward Neural Network

  * Implement CLM (causal language modeling)

* **DO NOT** use built-in **TRANSFORMER LAYERS**, as they affect reproducibility.

* You will be given a configuration file that contains hyperparameters such as the embedding dimension, vocabulary size, number of heads, and so on.

* Use the ReLU activation function and the Stochastic Gradient Descent optimizer
* Here is a list of helpful PyTorch functions (you do not have to use all of them) for this and subsequent assignments
  * [torch.matmul](https://pytorch.org/docs/stable/generated/torch.matmul.html#torch-matmul)
  * [torch.bmm](https://pytorch.org/docs/stable/generated/torch.bmm.html)
  * torch.swapdims
  * torch.unsqueeze
  * torch.squeeze
  * torch.argmax
  * [torch.Tensor.view](https://pytorch.org/docs/stable/generated/torch.Tensor.view.html)
  * [torch.nn.Embedding](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html)
  * [torch.nn.Parameter](https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html)
  * torch.nn.Linear
  * torch.nn.LayerNorm
  * torch.nn.ModuleList
  * torch.nn.Sequential
  * [torch.nn.CrossEntropyLoss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)
  
* Important: Do not set any global seeds.

* Helpful resources to get started with

 * [Andrej Karpathy's nanoGPT](https://github.com/karpathy/nanoGPT)
 * [PyTorch Source code of Transformer Layer](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html)



In [None]:
import torch
from torch import Tensor

import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch.nn.functional import one_hot

import torch.optim as optim

from pprint import pprint
from yaml import safe_load
import copy
import requests
from io import BytesIO

In [None]:
config_url = "https://raw.githubusercontent.com/Arunprakash-A/LLM-from-scratch-PyTorch/main/config_files/dec_config.yml"
response = requests.get(config_url)
config = response.content.decode("utf-8")
config = safe_load(config)
pprint(config)

{'input': {'batch_size': 10, 'embed_dim': 32, 'seq_len': 8, 'vocab_size': 12},
 'model': {'d_ff': 128,
           'd_model': 32,
           'dk': 4,
           'dq': 4,
           'dv': 4,
           'n_heads': 8,
           'n_layers': 6}}


In [None]:
vocab_size = config['input']['vocab_size']
batch_size = config['input']['batch_size']
seq_len = config['input']['seq_len']
embed_dim = config['input']['embed_dim']
dmodel = embed_dim
dq = torch.tensor(config['model']['dq'])
dk = torch.tensor(config['model']['dk'])
dv = torch.tensor(config['model']['dv'])
heads = torch.tensor(config['model']['n_heads'])
d_ff = config['model']['d_ff']

# Input tokens

* Generate the raw input ids (without any special tokens appended)

* Since these ids serve as the labels once the special \<start\> token is added to form the input, we use the variable name "label_ids"

* Keep the size of `label_ids` as `(bs, seq_len-1)`, since we insert a special token id in the next step
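
If the provided token file is unavailable, ids of the same shape can be produced locally. The following is a minimal sketch, assuming that id 0 is reserved for the start token and that a local `torch.Generator` is acceptable; the seed 123 and the names `bs_toy`, `seq_len_toy`, `vocab_size_toy`, and `raw_ids` are illustrative and not part of the assignment.

```python
import torch

bs_toy, seq_len_toy, vocab_size_toy = 10, 8, 12  # values mirror the config printed above

# A local generator leaves the global RNG state untouched.
gen = torch.Generator().manual_seed(123)  # arbitrary seed, chosen only for illustration

# Draw ids from 1..vocab_size-1 so that 0 stays free for the start token (assumption).
raw_ids = torch.randint(low=1, high=vocab_size_toy,
                        size=(bs_toy, seq_len_toy - 1), generator=gen)
print(raw_ids.shape)  # torch.Size([10, 7])
```

The values will differ from the tensor loaded below, so the reference loss and accuracy quoted later would not apply to such locally generated ids.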

In [None]:
data_url = 'https://github.com/Arunprakash-A/LLM-from-scratch-PyTorch/raw/main/config_files/w2_input_tokens'
r = requests.get(data_url)
label_ids = torch.load(BytesIO(r.content))



In [None]:
label_ids

tensor([[ 7,  8,  7,  7,  9,  2,  6],
        [10,  1, 10,  5,  3,  6,  8],
        [ 3,  4,  8,  2, 10, 10, 10],
        [ 4, 10,  1,  3,  4,  9,  7],
        [ 8,  4,  7,  3,  8, 10,  5],
        [ 9,  1,  8,  5,  9,  9, 10],
        [ 7,  3,  8,  2,  5,  1,  5],
        [ 3,  3,  2,  1,  4,  1,  1],
        [10,  9,  9,  9,  6,  9,  2],
        [ 3,  6,  6,  3,  5,  4,  5]])

* Let the first token_id be a special `[start]` token (mapped to integer 0)
* If label_ids=$\begin{bmatrix}1&2\\3&4 \end{bmatrix}$, then we modify it as $\begin{bmatrix}0&1&2\\0&3&4 \end{bmatrix}$
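
The 2×2 example above can be reproduced directly. This is a small sketch; the names `label_ids_toy`, `start_col`, and `token_ids_toy` are illustrative.

```python
import torch

# Toy label ids from the example above (2 sequences, 2 tokens each).
label_ids_toy = torch.tensor([[1, 2], [3, 4]])

# One [start] id (0) per sequence, using the same integer dtype as the labels.
start_col = torch.zeros(label_ids_toy.shape[0], 1, dtype=label_ids_toy.dtype)

# Concatenate along the sequence dimension so the zeros become the first column.
token_ids_toy = torch.cat((start_col, label_ids_toy), dim=1)
print(token_ids_toy)
# tensor([[0, 1, 2],
#         [0, 3, 4]])
```

Matching the dtype of the start column to the labels avoids an implicit type promotion in `torch.cat`.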

In [None]:
begin_token = torch.zeros(label_ids.shape[0], 1, dtype=int)

token_ids = torch.cat((begin_token, label_ids), dim=1) # the first column of token_ids should be zeros and the rest of the columns come from label_ids

In [None]:
token_ids

tensor([[ 0,  7,  8,  7,  7,  9,  2,  6],
        [ 0, 10,  1, 10,  5,  3,  6,  8],
        [ 0,  3,  4,  8,  2, 10, 10, 10],
        [ 0,  4, 10,  1,  3,  4,  9,  7],
        [ 0,  8,  4,  7,  3,  8, 10,  5],
        [ 0,  9,  1,  8,  5,  9,  9, 10],
        [ 0,  7,  3,  8,  2,  5,  1,  5],
        [ 0,  3,  3,  2,  1,  4,  1,  1],
        [ 0, 10,  9,  9,  9,  6,  9,  2],
        [ 0,  3,  6,  6,  3,  5,  4,  5]])

# Implement the following components of a decoder layer

 * Multi-head Masked Attention (MHMA)
 * Multi-head Cross Attention (MHCA)
 * Position-wise FFN

* Randomly initialize the parameters from a normal distribution with the following seed values
  * $W_Q:$(seed=43)
  * $W_K:$(seed=44)
  * $W_V:$(seed=45)
  * $W_O:$(seed=46)

* Remember that multi-head cross attention takes two representations: the encoder output and the output of the masked attention sub-layer.

* However, in this assignment, we will fix the encoder output to a random matrix.
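
To make the two inputs concrete, here is a single-head sketch of cross attention in which the queries come from the decoder side and the keys and values come from the encoder side. The tensor sizes, the seed, and all variable names are assumptions chosen for illustration; the learned projections and the multi-head split are omitted.

```python
import math
import torch
import torch.nn.functional as F

gen = torch.Generator().manual_seed(0)  # arbitrary seed for a repeatable toy example
batch, tgt_len, src_len, d = 2, 3, 5, 4

Q = torch.randn(batch, tgt_len, d, generator=gen)  # from the masked attention sub-layer
K = torch.randn(batch, src_len, d, generator=gen)  # from the encoder output
V = torch.randn(batch, src_len, d, generator=gen)  # from the encoder output

# Each decoder position scores every encoder position: (batch, tgt_len, src_len).
scores = Q @ K.transpose(-2, -1) / math.sqrt(d)
weights = F.softmax(scores, dim=-1)

# Weighted sum of encoder values, one context vector per decoder position.
context = weights @ V
print(scores.shape, context.shape)
# torch.Size([2, 3, 5]) torch.Size([2, 3, 4])
```

No mask is applied here, because every decoder position is allowed to look at the whole encoder sequence. The output keeps the decoder length even when the encoder length differs.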

In [None]:
class MHCA(nn.Module):

  def __init__(self,dmodel,dq,dk,dv,heads):
    super(MHCA,self).__init__()

    self.dmodel = dmodel
    self.dq = dq
    self.dk = dk
    self.dv = dv
    self.heads = heads

    torch.manual_seed(43)
    self.W_Q = nn.Parameter(torch.randn(dq * heads, dmodel))

    torch.manual_seed(44)
    self.W_K = nn.Parameter(torch.randn(dk * heads, dmodel))

    torch.manual_seed(45)
    self.W_V = nn.Parameter(torch.randn(dv * heads, dmodel))

    torch.manual_seed(46)
    self.W_O = nn.Parameter(torch.randn(dmodel, dv * heads))


  def forward(self, dec_query, enc_key, enc_value):

    Q = torch.matmul(dec_query, self.W_Q.T)
    K = torch.matmul(enc_key, self.W_K.T)
    V = torch.matmul(enc_value, self.W_V.T)

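    batch_size = Q.shape[0]
    seq_len = Q.shape[1]   # decoder (query) length
    src_len = K.shape[1]   # encoder (key/value) length; may differ from seq_len

    # Split the projections into heads: (batch, heads, length, d_head)
    Q = Q.view(batch_size, seq_len, self.heads, self.dq).transpose(1, 2)
    K = K.view(batch_size, src_len, self.heads, self.dk).transpose(1, 2)
    V = V.view(batch_size, src_len, self.heads, self.dv).transpose(1, 2)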

    interim_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(self.dk)
    interim_scores_sm = F.softmax(interim_scores, dim=-1)

    attn_scores = torch.matmul(interim_scores_sm, V)

    attn_scores_T = attn_scores.transpose(1, 2).contiguous().view(batch_size, seq_len, self.heads * self.dv)

    out = torch.matmul(attn_scores_T, self.W_O.T)

    return out

* By default, `mask=None`. Therefore, create and apply the mask while computing the attention scores
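
The effect of a causal mask is easiest to see on uniform scores. The sketch below is illustrative only (the names `n_toy`, `scores_toy`, `causal_mask`, and `weights_toy` are not from the assignment): entries above the diagonal are set to `-inf` before the softmax, so each position attends only to itself and to earlier positions.

```python
import torch
import torch.nn.functional as F

n_toy = 4
scores_toy = torch.zeros(n_toy, n_toy)  # equal scores everywhere

# True above the main diagonal marks the future positions to hide.
causal_mask = torch.triu(torch.ones(n_toy, n_toy), diagonal=1).bool()

# -inf scores become exactly zero probability after the softmax.
weights_toy = F.softmax(scores_toy.masked_fill(causal_mask, float("-inf")), dim=-1)
print(weights_toy)
# tensor([[1.0000, 0.0000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000, 0.0000],
#         [0.3333, 0.3333, 0.3333, 0.0000],
#         [0.2500, 0.2500, 0.2500, 0.2500]])
```

Every row still sums to 1, because the softmax renormalizes over the positions that remain visible.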


In [None]:
class MHMA(nn.Module):

  def __init__(self,dmodel,dq,dk,dv,heads,mask=None):
    super(MHMA,self).__init__()

    self.dmodel = dmodel
    self.dq = dq
    self.dk = dk
    self.dv = dv
    self.heads = heads
    self.mask = mask

    torch.manual_seed(43)
    self.W_Q = nn.Parameter(torch.randn(dq * heads, dmodel))

    torch.manual_seed(44)
    self.W_K = nn.Parameter(torch.randn(dk * heads, dmodel))

    torch.manual_seed(45)
    self.W_V = nn.Parameter(torch.randn(dv * heads, dmodel))

    torch.manual_seed(46)
    self.W_O = nn.Parameter(torch.randn(dmodel, dv * heads))

  def forward(self, x):

    Q = torch.matmul(x, self.W_Q.T)
    K = torch.matmul(x, self.W_K.T)
    V = torch.matmul(x, self.W_V.T)

    batch_size = Q.shape[0]
    seq_len = Q.shape[1]

    Q = Q.view(batch_size, seq_len, self.heads, self.dq).transpose(1, 2)
    K = K.view(batch_size, seq_len, self.heads, self.dk).transpose(1, 2)
    V = V.view(batch_size, seq_len, self.heads, self.dv).transpose(1, 2)

    interim_scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(self.dk)

    if self.mask is None:
      # Causal mask: True above the diagonal marks future positions to hide
      self.mask = torch.triu(torch.ones((seq_len, seq_len)), diagonal=1)
      self.mask = self.mask == 1

    masked_scores = interim_scores.masked_fill(self.mask.unsqueeze(0).unsqueeze(0), float('-inf'))

    masked_scores_sm = F.softmax(masked_scores, dim=-1)

    attn_scores = torch.matmul(masked_scores_sm, V)

    # Concatenate the heads; note that W_O is defined in __init__ but is not applied here
    out = attn_scores.transpose(1, 2).contiguous().view(batch_size, seq_len, self.heads * self.dv)

    return out


* Implement the FFN and OutputLayer modules (same as the ones you implemented for the encoder)
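
"Position-wise" means that the same two linear layers are applied to each position independently. A quick check of that property, using a hypothetical `nn.Sequential` stand-in (`ffn_toy`) rather than the class required by the assignment:

```python
import torch
import torch.nn as nn

# Stand-in FFN with the sizes from the config: d_model=32, d_ff=128.
ffn_toy = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 32))
x_toy = torch.randn(2, 5, 32)  # (batch, seq_len, d_model)

with torch.no_grad():
    full = ffn_toy(x_toy)          # applied to the whole sequence at once
    single = ffn_toy(x_toy[:, 3])  # applied to position 3 only

# The two results agree up to floating-point rounding.
print(torch.allclose(full[:, 3], single, atol=1e-5))
```

Because no information flows between positions inside the FFN, mixing across positions happens only in the attention sub-layers.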

In [None]:
class FFN(nn.Module):
  def __init__(self,dmodel,d_ff):
    super(FFN,self).__init__()

    self.linear1 = nn.Linear(dmodel, d_ff)
    self.linear2 = nn.Linear(d_ff, dmodel)

  def forward(self,x):

    x = self.linear1(x)
    x = F.relu(x)
    out = self.linear2(x)

    return out

In [None]:
class OutputLayer(nn.Module):

  def __init__(self,dmodel,vocab_size):
    super(OutputLayer,self).__init__()

    self.linear = nn.Linear(dmodel, vocab_size)

  def forward(self, x):

    out = self.linear(x)

    return out

* Implement the final decoder layer.

In [None]:
class DecoderLayer(nn.Module):

  def __init__(self,dmodel,dq,dk,dv,d_ff,heads,mask=None):
    super(DecoderLayer,self).__init__()
    self.mhma = MHMA(dmodel,dq,dk,dv,heads,mask=mask)
    self.mhca = MHCA(dmodel,dq,dk,dv,heads)
    self.layer_norm_mhma = torch.nn.LayerNorm(dmodel)
    self.layer_norm_mhca = torch.nn.LayerNorm(dmodel)
    self.layer_norm_ffn = torch.nn.LayerNorm(dmodel)
    self.ffn = FFN(dmodel,d_ff)

  def forward(self, x, enc_output):

    mhma_output = self.mhma(x)
    x = self.layer_norm_mhma(x + mhma_output)

    mhca_output = self.mhca(x, enc_output, enc_output)
    x = self.layer_norm_mhca(x + mhca_output)

    ffn_output = self.ffn(x)
    out = self.layer_norm_ffn(x + ffn_output)

    return out

* Create an embedding layer that takes in token_ids and returns the embeddings for those token_ids

 * Use seed value: 70
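
As a shape check, an embedding layer maps ids of shape `(bs, seq_len)` to vectors of shape `(bs, seq_len, embed_dim)`, and equal ids always receive equal vectors. A small sketch with illustrative names (`embed_toy`, `ids_toy`, `vectors`):

```python
import torch
import torch.nn as nn

torch.manual_seed(70)  # seed value given in the assignment
embed_toy = nn.Embedding(num_embeddings=12, embedding_dim=32)

ids_toy = torch.tensor([[0, 7, 8], [0, 10, 1]])  # two short sequences
vectors = embed_toy(ids_toy)
print(vectors.shape)  # torch.Size([2, 3, 32])

# Both sequences start with id 0, so their first vectors are identical.
print(torch.equal(vectors[0, 0], vectors[1, 0]))  # True
```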

In [None]:
class Embed(nn.Module):

  def __init__(self,vocab_size,embed_dim):
    super(Embed,self).__init__()

    torch.manual_seed(70)
    self.embed= nn.Embedding(vocab_size, embed_dim)

  def forward(self,x):
    out = self.embed(x)
    return out

# Decoder

 * Implement the decoder that has `num_layers` decoder layers
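
One common way to build such a stack is to deep-copy a template layer into an `nn.ModuleList`. The sketch below uses an `nn.Linear` as a stand-in for a decoder layer (an assumption made to keep the example short) and shows that the copies start from identical values but own separate parameters.

```python
import copy
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)  # stand-in for one decoder layer
stack = nn.ModuleList([copy.deepcopy(layer) for _ in range(3)])

same_values = torch.equal(stack[0].weight, stack[1].weight)
shared_storage = stack[0].weight.data_ptr() == stack[1].weight.data_ptr()
n_params = sum(p.numel() for p in stack.parameters())  # 3 * (16 weights + 4 biases)
print(same_values, shared_storage, n_params)  # True False 60
```

Using `nn.ModuleList` rather than a plain Python list matters here: it registers the layers, so `model.parameters()` hands all of them to the optimizer.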

In [None]:
import copy

class Decoder(nn.Module):

  def __init__(self,vocab_size,dmodel,dq,dk,dv,d_ff,heads,mask,num_layers=1):
    super(Decoder,self).__init__()

    self.embed_lookup = Embed(vocab_size, dmodel)

    decoder_layer = DecoderLayer(dmodel, dq, dk, dv, d_ff, heads, mask)
    self.dec_layers = nn.ModuleList([copy.deepcopy(decoder_layer) for i in range(num_layers)])

    self.output_layer = OutputLayer(dmodel, vocab_size)

  def forward(self,enc_rep,tar_token_ids):

    x = self.embed_lookup(tar_token_ids)

    for dec_layer in self.dec_layers:
        x = dec_layer(x, enc_rep)

    out = self.output_layer(x)

    return out

* Representation from encoder

 * Since every decoder layer needs the encoder representation to compute cross attention, we feed in random values (note that this tensor does not require gradients during training)
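
The "does not require gradient" remark can be checked on a toy tensor. This is a sketch under assumptions: it draws from a local `torch.Generator` so the global RNG state is left alone, and the names `enc_rep_toy` and `w_toy` are illustrative.

```python
import torch

gen = torch.Generator().manual_seed(10)
enc_rep_toy = torch.randn(10, 8, 32, generator=gen)  # (batch_size, seq_len, embed_dim)
print(enc_rep_toy.shape, enc_rep_toy.requires_grad)  # torch.Size([10, 8, 32]) False

# A trainable weight that consumes the fixed representation.
w_toy = torch.randn(32, 1, requires_grad=True)
(enc_rep_toy @ w_toy).sum().backward()

# Gradients reach the weight but not the fixed encoder representation.
print(w_toy.grad is not None, enc_rep_toy.grad is None)  # True True
```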

In [None]:
enc_rep = torch.randn(size=(batch_size,seq_len,embed_dim),generator=torch.random.manual_seed(10))

# Instantiate the model

In [None]:
model = Decoder(vocab_size,dmodel,dq,dk,dv,d_ff,heads,mask=None)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [None]:
def train(enc_rep,tar_token_ids,label_ids,epochs=1000):
  loss_trace = []
  for epoch in range(epochs):
    out = model(enc_rep,tar_token_ids)
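    out = out.view(-1, vocab_size)   # flatten logits to (batch*seq_len, vocab_size)

    # Targets are the input ids at the same positions (label_ids is not used here)
    target = tar_token_ids.view(-1)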

    loss = criterion(out, target)
    loss_trace.append(loss.item())

    if (epoch+1)%100 == 0:
      print("Epoch :", epoch, "Loss :", loss.item())

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

* Train the model for 1000 epochs
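
`nn.CrossEntropyLoss` expects logits of shape `(N, C)` and integer targets of shape `(N,)`, which is why the logits and targets are flattened before the loss is computed. As a sanity check (a sketch with illustrative names), all-zero logits correspond to a uniform prediction, whose loss is $\ln(\text{vocab\_size})$:

```python
import math
import torch
import torch.nn as nn

batch, seq, vocab = 2, 3, 12
logits = torch.zeros(batch, seq, vocab)           # uniform prediction over the vocabulary
targets = torch.randint(0, vocab, (batch, seq))   # any targets give the same loss here

criterion_toy = nn.CrossEntropyLoss()
loss = criterion_toy(logits.view(-1, vocab), targets.view(-1))
print(round(loss.item(), 4), round(math.log(vocab), 4))  # 2.4849 2.4849
```

A training loss well below this value therefore indicates that the model does better than a uniform guess over the 12 vocabulary entries.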

In [None]:
train(enc_rep,token_ids,label_ids,1000)

Epoch : 99 Loss : 1.9622176885604858
Epoch : 199 Loss : 1.5867887735366821
Epoch : 299 Loss : 1.302473783493042
Epoch : 399 Loss : 1.0674269199371338
Epoch : 499 Loss : 0.877951443195343
Epoch : 599 Loss : 0.7048269510269165
Epoch : 699 Loss : 0.5797830820083618
Epoch : 799 Loss : 0.4468874931335449
Epoch : 899 Loss : 0.35477375984191895
Epoch : 999 Loss : 0.28072211146354675


In [None]:
with torch.inference_mode():
  predictions = torch.argmax(model(enc_rep,token_ids),dim=-1)

* The loss will be around 0.17 after 1000 epochs

In [None]:
# number of correct predictions
print(torch.count_nonzero(label_ids==predictions[:,1:]))

tensor(69)


* The number of correct predictions is close to 66
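
The count can be turned into an accuracy by dividing by the number of label positions (10 × 7 = 70 for the tensors above). A toy sketch of the same comparison, with illustrative names and values:

```python
import torch

labels_toy = torch.tensor([[7, 8], [3, 4]])
preds_toy = torch.tensor([[0, 7, 8], [0, 3, 9]])  # column 0 corresponds to the start token

# Drop the first column of the predictions so both tensors align position by position.
correct = torch.count_nonzero(labels_toy == preds_toy[:, 1:])
accuracy = correct.item() / labels_toy.numel()
print(correct, accuracy)  # tensor(3) 0.75
```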