# Decoder-only transformer architecture

## Manual computation of the intermediate outputs (Verification)

### Initialising the same transformer model as used in the other notebook

In [25]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy 

# Seed the global RNG for this cell. Note that the model-construction cell
# below re-seeds with torch.manual_seed(0) before building the model.
torch.manual_seed(6)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional-encoding table.

    ``forward`` returns only the table (it does NOT add it to ``x``), shaped
    ``(1, n, d_model)`` with ``n = min(x.shape[1], max_len)``. In other words,
    ``x`` is treated as batch-first ``(batch, seq_len, ...)``. Callers holding
    sequence-first data should pass ``x.transpose(0, 1)``.

    Args:
        d_model: embedding width; may be odd.
        max_len: number of positions precomputed.
        dropout: accepted for API compatibility but not applied, because
            ``forward`` returns the raw table.
    """

    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))

        encoding = torch.zeros(max_len, d_model)
        encoding[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so
        # only the first d_model // 2 frequencies are used for the cosines.
        encoding[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Non-persistent buffer: moves with .to(device)/.cuda() together with the
        # module, but is excluded from state_dict(), so checkpoint keys are unchanged.
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        # Slice by x.shape[1] (batch-first sequence length); the table is constant.
        return self.encoding[:, :x.shape[1]].detach()
    
    

class DecoderOnlyTransformer(nn.Module):
    """Decoder-only transformer made of causally masked ``nn.TransformerEncoderLayer`` blocks.

    An encoder layer with a square subsequent (causal) mask and no
    cross-attention acts as a decoder-only block.

    Data layout is sequence-first: ``src`` holds token ids shaped
    ``(seq_len, batch)``, and ``forward`` returns logits shaped
    ``(seq_len, batch, vocab_size)``.
    """

    def __init__(self,num_layers, d_model, nhead, dim_feedforward, vocab_size, max_seq_len, dropout = 0):
        super(DecoderOnlyTransformer, self).__init__()

        self.src_embedding = nn.Embedding(vocab_size, d_model)

        self.positional_encoding = PositionalEncoding(d_model = d_model, dropout=0, max_len=max_seq_len)

        # Template layer that is deep-copied into self.layers. It is also kept as a
        # registered submodule: removing it would change the state_dict keys
        # ("decoder_layer.*") and the RNG draws made during construction.
        self.decoder_layer = nn.TransformerEncoderLayer(d_model = d_model, nhead = nhead, dim_feedforward = dim_feedforward, dropout = dropout)
        self.layers = nn.ModuleList([copy.deepcopy(self.decoder_layer) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, vocab_size)

    def generate_square_subsequent_mask(self, tgt):
        """Return a ``(n, n)`` bool causal mask for ``n = tgt.size(0)``.

        Entries strictly above the diagonal are ``True``; these are the future
        positions that must not be attended to. The mask is created on
        ``tgt``'s device.
        """
        seq_length = tgt.size(0)
        nopeak_mask = torch.triu(torch.ones(seq_length, seq_length, device=tgt.device), diagonal=1).bool()
        return nopeak_mask

    def forward(self, src, src_mask=None):
        """Compute next-token logits for sequence-first token ids ``src``.

        Args:
            src: ``(seq_len, batch)`` token ids.
            src_mask: ignored and kept for signature compatibility; the causal
                mask is always applied, as in the original implementation.
        """
        src_embed = self.src_embedding(src)

        # PositionalEncoding slices by x.shape[1], i.e. it expects batch-first
        # input. Hand it the (batch, seq_len) view so it returns one encoding per
        # position, then bring it to (seq_len, 1, d_model) to broadcast over batch.
        pe_src = self.positional_encoding(src.transpose(0, 1)).transpose(0, 1)

        x = src_embed + pe_src

        # The mask depends only on seq_len, so build it once for all layers.
        causal_mask = self.generate_square_subsequent_mask(src)
        for layer in self.layers:
            x = layer(x, causal_mask)

        return self.fc(x)



In [26]:
# Seed before building the model so that weight initialisation and the random
# batch below are reproducible (statement order matters for the RNG stream).
torch.manual_seed(0)


vocab_size = 20  # Source language vocabulary size
d_model = 6  # Dimension of the model
num_heads = 2
d_ff = 4
num_layers = 2


max_seq_len = 5
# The positional table is sized for max_seq_len - 1 positions. The manual pass
# further down feeds src_data[:-1, :], i.e. one step shorter than the samples.
model = DecoderOnlyTransformer(d_model=d_model, vocab_size=vocab_size, num_layers=num_layers , nhead=num_heads, max_seq_len = max_seq_len-1, dim_feedforward=d_ff)


batch_size = 10

# Generate random sample data
# Token ids are drawn from [1, vocab_size); id 0 is never sampled
# (presumably reserved, e.g. for padding -- not confirmed here).
src_data = torch.randint(1, vocab_size, (max_seq_len , batch_size))  # (seq_length, batch_size)


In [27]:
import copy

# Weights of the freshly initialised model.
state_dict = model.state_dict()

# Deep copy used as the reference weights for the manual computation, so later
# changes to `model` (e.g. training) cannot alter it.
state_dict1 = copy.deepcopy(state_dict)

In [28]:
# Shape of the model input (last time step dropped) vs. the full sampled batch.
src_data[:-1, :].shape , src_data.shape


(torch.Size([4, 10]), torch.Size([5, 10]))

### Functions to get the intermediate outputs 

#### Function to fetch word embeddings with help of token indices

In [29]:
def look_up_table(sentence, vocab_embeds, embedding):
    """Fill ``embedding`` with the rows of ``vocab_embeds`` selected by ``sentence``.

    Args:
        sentence: 2D integer tensor of token ids.
        vocab_embeds: embedding matrix, one row per token id.
        embedding: output tensor, written in place at ``[i, j, :]`` for every
            position ``(i, j)`` of ``sentence``.

    Returns:
        ``embedding`` (the same object that was passed in).

    Raises:
        ValueError: if a token id is outside ``[0, vocab_embeds.size(0))``.

    Each lookup is printed; a blank line is printed at the end.
    """
    vocab_size = vocab_embeds.size(0)

    for row in range(sentence.size(0)):
        for col in range(sentence.size(1)):
            token = sentence[row, col].item()

            if not 0 <= token < vocab_size:
                raise ValueError(f"Invalid word index: {token}")

            vector = vocab_embeds[token, :]
            embedding[row, col, :] = vector

            print(f"Word index: {token}, Embedding: {vector}")
    print()

    return embedding

### Embeddings and Positional encoding

In [30]:
def get_embedding_outputs(src_data, max_seq_len, state_dict, d_model):
    """Recompute the model's input embeddings: token embedding + positional encoding.

    Args:
        src_data: ``(seq_len, batch)`` token ids (sequence-first, as fed to the model).
        max_seq_len: number of positions in the positional-encoding table.
        state_dict: model weights; ``"src_embedding.weight"`` is read.
        d_model: embedding width.

    Returns:
        ``(seq_len, batch, d_model)`` tensor of position-encoded embeddings.
    """
    src_vocab_embeds = state_dict["src_embedding.weight"]

    src_embedding = torch.zeros(src_data.size(0), src_data.size(1), d_model)
    print("Source sentence embedding")
    src_embedding =  look_up_table(src_data, src_vocab_embeds, src_embedding)
    print(src_embedding.shape)

    pe = PositionalEncoding(d_model = d_model, dropout=0, max_len=max_seq_len)

    # PositionalEncoding slices by x.shape[1] (batch-first), but src_data is
    # sequence-first. Pass the transposed view so there is one encoding per
    # position, giving (seq_len, 1, d_model), which broadcasts over the batch.
    pe_src = pe(src_data.transpose(0, 1)).transpose(0, 1)

    print("PE of src data:")
    print(pe_src)
    print()

    pe_src_embeds = src_embedding + pe_src

    print("PE source embeddings : \n")
    print(pe_src_embeds)
    print()

    return pe_src_embeds



## Decoder function to display the intermediate outputs and get the final outputs from the decoder

### Masked self attention 

#### Functions to perform the attention calculation with Q,K and V matrices

In [31]:
def atten_product_needs_wts_false(Q, V, K, bsz, head_dim, src_len, tgt_len, embed_dim, attn_mask):
    """Multi-head scaled dot-product attention, mirroring PyTorch's SDPA path.

    Args:
        Q: queries, ``(bsz * num_heads, tgt_len, head_dim)``.
        V: values, ``(bsz * num_heads, src_len, head_dim)``. Note the V-before-K
            parameter order.
        K: keys, ``(bsz * num_heads, src_len, head_dim)``.
        bsz, head_dim, src_len, tgt_len, embed_dim: sizes used for reshaping.
            The head count is inferred as ``Q.size(0) // bsz``.
        attn_mask: ``None``, a bool mask (``True`` = not attended, becomes
            ``-inf``) or an additive float mask. Either must be broadcastable to
            ``(bsz, num_heads, tgt_len, src_len)``.

    Returns:
        ``(tgt_len * bsz, embed_dim)`` attention output. Rows are ordered
        time-major and heads are concatenated; the output projection is not applied.
    """
    # Infer the head count from the packed leading dim instead of relying on a
    # notebook-level global.
    num_heads = Q.size(0) // bsz

    # (bsz*num_heads, len, head_dim) -> (bsz, num_heads, len, head_dim)
    Q1 = Q.view(bsz, num_heads, tgt_len, head_dim)
    K1 = K.view(bsz, num_heads, src_len, head_dim)
    V1 = V.view(bsz, num_heads, src_len, head_dim)

    L, S = Q1.size(-2), K1.size(-2)

    scale_factor = 1 / math.sqrt(Q1.size(-1))

    # Additive bias shaped (1, 1, L, S) so it broadcasts over batch and heads.
    attn_bias = torch.zeros(1, 1, L, S, dtype=Q1.dtype, device=Q1.device)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # True marks a position that must not be attended to -> -inf.
            attn_mask = torch.zeros(attn_mask.shape, dtype=Q1.dtype, device=attn_mask.device).masked_fill(attn_mask, float("-inf"))

            print("Attnetion mask infunction = ")
            print(attn_mask)
            print()

        # Out-of-place add: a full (bsz, num_heads, L, S) mask must be able to
        # broadcast the (1, 1, L, S) bias up, which an in-place += cannot do.
        attn_bias = attn_bias + attn_mask

    # (bsz, num_heads, L, head_dim) @ (bsz, num_heads, head_dim, S) -> (bsz, num_heads, L, S)
    attn_weight = Q1 @ K1.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias

    attn_weight = torch.softmax(attn_weight, dim=-1)

    sum_last_dim = attn_weight.sum(dim=-1)
    tolerance = 1e-6
    assert torch.allclose(sum_last_dim, torch.ones_like(sum_last_dim), atol=tolerance), "Attention weights sum is not approximately equal to 1"

    # (bsz, num_heads, L, S) @ (bsz, num_heads, S, head_dim) -> (bsz, num_heads, L, head_dim)
    attn_output = attn_weight @ V1

    # (L, bsz, num_heads, head_dim) -> (L*bsz, embed_dim): time-major rows, heads concatenated.
    attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

    # This prints the softmax weights (the output itself is printed by the caller).
    print("Attention weights = ")
    print(attn_weight.shape, attn_weight)

    return attn_output



def atten_product_needs_wts_true(Q, K, V, bsz, tgt_len, embed_dim, attn_mask):
    """Multi-head scaled dot-product attention via explicit bmm, also returning the weights.

    Args:
        Q: queries, ``(bsz * num_heads, tgt_len, head_dim)``.
        K, V: keys/values, ``(bsz * num_heads, src_len, head_dim)``.
        bsz, tgt_len, embed_dim: sizes used to lay out the output.
        attn_mask: ``None``, a bool mask (``True`` = not attended) or an
            additive float mask. It may have rank 2, 3 or 4, with trailing dims
            ``(tgt_len, src_len)``. Any leading dims must flatten to 1 or to
            ``bsz * num_heads`` (head packing ``b * num_heads + h``).

    Returns:
        ``(attn_output, weights)``: ``attn_output`` is ``(tgt_len * bsz, embed_dim)``
        (time-major rows, heads concatenated) and ``weights`` is
        ``(bsz * num_heads, tgt_len, src_len)``.
    """
    # head_dim; scaling Q before the product matches PyTorch's bmm path.
    E = Q.size(-1)
    Q_scaled = Q / math.sqrt(E)

    if attn_mask is not None:
        # baddbmm needs an additive float mask of rank <= 3.
        if attn_mask.dtype == torch.bool:
            attn_mask = torch.zeros(attn_mask.shape, dtype=Q.dtype, device=attn_mask.device).masked_fill(attn_mask, float("-inf"))
        # (L,S) / (1,1,L,S) -> (1,L,S); (bsz,nh,L,S) -> (bsz*nh,L,S)
        attn_mask = attn_mask.reshape(-1, attn_mask.size(-2), attn_mask.size(-1))
        temp_pdt_matrix = torch.baddbmm(attn_mask, Q_scaled, K.transpose(-2, -1))
    else:
        temp_pdt_matrix = torch.bmm(Q_scaled, K.transpose(-2, -1))

    attn_wt_matrix = torch.nn.functional.softmax(temp_pdt_matrix, dim=-1)

    attn_output = torch.bmm(attn_wt_matrix, V)

    # (bsz*nh, L, hd) -> (L, bsz*nh, hd) -> (L*bsz, nh*hd)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)

    sum_last_dim = attn_wt_matrix.sum(dim=-1)

    tolerance = 1e-6
    assert torch.allclose(sum_last_dim, torch.ones_like(sum_last_dim), atol=tolerance), "Attention weights sum is not approximately equal to 1"

    print("Encoder Attention output = ")
    print(attn_output)
    print()

    return attn_output, attn_wt_matrix


#### Function to get the Q, K and V matrices from the model's initialised weights

In [32]:
def get_qkv(query, key, value ,W, b):
    """Apply the packed input projection of ``nn.MultiheadAttention``.

    ``W`` is ``(3E, E)``, stacked as ``[W_q; W_k; W_v]``. ``b`` is the matching
    ``(3E,)`` bias ``[b_q; b_k; b_v]``, or ``None``.

    Three cases are handled:
    * ``query is key is value``: a single matmul for all three (self-attention).
    * ``key is value`` only: one projection for q and one packed for k/v.
    * all different: three separate projections.

    Returns:
        ``(q, k, v)``, each shaped like its input with last dim ``E``.
    """
    # embed_dim
    E = query.size(-1)

    # Fully qualified on purpose: later in this notebook `F` is rebound to
    # torch.functional, which does not provide `linear`.
    linear = torch.nn.functional.linear

    if key is value:
        if query is key:
            # (..., E) @ (3E, E).T + (3E,) -> (..., 3E)
            proj = linear(query, W, b)

            # reshape to 3, E and not E, 3 is deliberate for better memory coalescing and keeping same order as chunk()
            proj = proj.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()

            return proj[0], proj[1], proj[2]

        # W_q: (E, E), W_kv: (2E, E)
        W_q, W_kv = W.split([E, E * 2])
        if b is None:
            b_q = b_kv = None
        else:
            b_q, b_kv = b.split([E, E * 2])

        q_proj = linear(query, W_q, b_q)
        kv_proj = linear(key, W_kv, b_kv)
        kv_proj = kv_proj.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()

        return q_proj, kv_proj[0], kv_proj[1]

    W_q, W_k, W_v = W.chunk(3)
    if b is None:
        b_q = b_k = b_v = None
    else:
        b_q, b_k, b_v = b.chunk(3)

    return linear(query, W_q, b_q), linear(key, W_k, b_k), linear(value, W_v, b_v)



    

#### Decoder block's self attention output function

In [33]:
def decoder_block_self_attn_output(x, state_dict, layer_num, num_heads, tgt_mask = None,need_weights = False):
    """Manually recompute the masked self-attention sub-block of decoder layer ``layer_num``.

    Reads the layer's ``self_attn`` in/out projection weights from
    ``state_dict``, projects ``x`` to Q/K/V, runs scaled dot-product attention
    (optionally masked by ``tgt_mask``) and applies the output projection.
    Q/K/V and intermediate results are printed for inspection.

    Args:
        x: layer input, ``(tgt_len, bsz, embed_dim)``.
        state_dict: model weights (``layers.<n>.self_attn.*`` keys are read).
        layer_num: index of the layer.
        num_heads: number of attention heads.
        tgt_mask: optional ``(tgt_len, src_len)`` or
            ``(bsz * num_heads, tgt_len, src_len)`` mask.
        need_weights: exactly ``False`` selects the SDPA-style path and returns
            ``None`` as weights. Any other value selects the bmm path and also
            returns the per-head attention weights.

    Returns:
        ``(attn_out, weights)``, where ``attn_out`` is ``(tgt_len, bsz, embed_dim)``.

    Raises:
        RuntimeError: if ``tgt_mask`` has an unsupported rank or shape.
    """
    tgt_len, bsz, embed_dim = x.shape

    in_w = state_dict["layers.{}.self_attn.in_proj_weight".format(layer_num)]
    in_b = state_dict["layers.{}.self_attn.in_proj_bias".format(layer_num)]
    head_dim = embed_dim // num_heads

    # The same tensor is passed three times, so get_qkv takes its packed
    # self-attention path.
    q, k, v = get_qkv(x, x, x, in_w, in_b)

    # (len, bsz, embed_dim) -> (bsz * num_heads, len, head_dim)
    packed = bsz * num_heads
    q = q.reshape(tgt_len, packed, head_dim).transpose(0, 1)
    k = k.reshape(k.shape[0], packed, head_dim).transpose(0, 1)
    v = v.reshape(v.shape[0], packed, head_dim).transpose(0, 1)

    for template, tensor in (("Q_dec_{} = ", q), ("K_dec_{} = ", k), ("V_dec_{} = ", v)):
        print(template.format(layer_num))
        print(tensor)
        print()

    src_len = k.size(1)

    attn_mask = tgt_mask
    if attn_mask is not None:
        rank = attn_mask.dim()
        if rank == 2:
            expected = (tgt_len, src_len)
            if attn_mask.shape != expected:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {expected}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif rank == 3:
            expected = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != expected:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {expected}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

        # Lift to 4D for the attention helpers: (1, L, S) -> (1, 1, L, S);
        # (N*num_heads, L, S) -> (N, num_heads, L, S).
        if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(0)
        else:
            attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

    if need_weights is False:
        attn_concat = atten_product_needs_wts_false(Q = q, V = v, K = k, bsz = bsz, head_dim=head_dim, src_len=src_len, tgt_len=tgt_len, attn_mask = attn_mask, embed_dim=embed_dim)
        print("Decoder Self Attention = ")
        print(attn_concat)
        print()
        weights = None
    else:
        attn_concat, weights = atten_product_needs_wts_true(Q=q, K=k, V=v, bsz=bsz, tgt_len=tgt_len, attn_mask = attn_mask, embed_dim=embed_dim)
        print("Decoder Attention output = ")
        print(weights)
        print()

    # Output projection shared by both paths: (L*bsz, E) -> (L, bsz, E).
    out_w = state_dict["layers.{}.self_attn.out_proj.weight".format(layer_num)]
    out_b = state_dict["layers.{}.self_attn.out_proj.bias".format(layer_num)]
    projected = torch.matmul(attn_concat, out_w.t()) + out_b

    return projected.view(tgt_len, bsz, attn_concat.size(1)), weights


    

### Post self attention in decoder

#### Function to perform the residual, layer-norm and feed-forward calculations after the decoder's self-attention

In [34]:
def _layer_norm(t, weight, bias, eps=1e-5):
    """Normalise ``t`` over its last dim, then apply ``weight``/``bias`` (the nn.LayerNorm formula)."""
    mean = t.mean(dim=-1, keepdim=True)
    # Biased variance with eps inside the square root, as nn.LayerNorm computes it.
    var = ((t - mean) ** 2).mean(dim=-1, keepdim=True)
    return (t - mean) / torch.sqrt(var + eps) * weight + bias


def dec_post_self_attn(x, attn_dec_output, state_dict, layer_num, bsz, tgt_len):
    """Finish a post-norm ``nn.TransformerEncoderLayer`` given its self-attention output.

    Computes ``norm2(h + linear2(relu(linear1(h))))`` with
    ``h = norm1(x + attn_dec_output)``. Each LayerNorm's affine parameters are
    applied AFTER normalisation, exactly as ``nn.LayerNorm`` does, so the
    result stays correct for trained norms. (Previously it matched only
    freshly initialised norms with weight=1 and bias=0.)

    Args:
        x: sub-block input, ``(tgt_len, bsz, embed_dim)``.
        attn_dec_output: projected self-attention output, same shape as ``x``.
        state_dict: model weights (``layers.<n>.norm1/2``, ``linear1/2`` are read).
        layer_num: index of the layer.
        bsz, tgt_len: unused; kept for signature compatibility.

    Returns:
        ``(tgt_len, bsz, embed_dim)`` output of the layer.
    """
    def param(name):
        return state_dict["layers.{}.{}".format(layer_num, name)]

    # Residual around self-attention, then norm1 (post-norm).
    h = _layer_norm(x + attn_dec_output, param("norm1.weight"), param("norm1.bias"))

    # Position-wise feed-forward: linear1 -> ReLU -> linear2 (dropout is 0).
    ff = torch.matmul(h, param("linear1.weight").t()) + param("linear1.bias")
    ff = torch.nn.functional.relu(ff)
    ff = torch.matmul(ff, param("linear2.weight").t()) + param("linear2.bias")

    # Residual around the feed-forward block, then norm2.
    output_dec_final = _layer_norm(h + ff, param("norm2.weight"), param("norm2.bias"))

    print("Final Encoder {} Output :".format(layer_num))
    print("norm2(norm1(x + self_atten(x)) + feed_fwd_op)\n")
    print(output_dec_final)
    print()

    return output_dec_final


#### Function to perform the linear layer calculations after transformer blocks

In [35]:
def feef_fwd_transformer(dec_output_final, state_dict):
    """Map decoder outputs to vocabulary logits using the model's final ``fc`` layer.

    Args:
        dec_output_final: ``(..., embed_dim)`` decoder output.
        state_dict: model weights; ``"fc.weight"`` (``(vocab_size, embed_dim)``)
            and ``"fc.bias"`` are read.

    Returns:
        ``(..., vocab_size)`` logits.
    """
    weight = state_dict["fc.weight"]
    bias = state_dict["fc.bias"]
    logits = torch.matmul(dec_output_final, weight.t())
    return logits + bias

In [36]:
def get_all_intermediate_outputs_mask(src_data ,d_model, state_dict , num_decoder_layers, tgt_mask, max_seq_len, d_ff, num_heads=None):
    """Run the full manual forward pass (embeddings, decoder layers, final fc) with prints.

    Args:
        src_data: ``(seq_len, batch)`` token ids.
        d_model: embedding width.
        state_dict: model weights.
        num_decoder_layers: number of decoder layers to apply.
        tgt_mask: causal mask passed to every layer's self-attention.
        max_seq_len: size of the positional-encoding table.
        d_ff: unused; the feed-forward sizes come from ``state_dict``.
        num_heads: attention heads per layer. Defaults to the notebook-level
            ``num_heads`` global, which the original implementation always read.

    Returns:
        ``(seq_len, batch, vocab_size)`` logits.
    """
    if num_heads is None:
        # Backward-compatible fallback to the notebook global of the same name.
        num_heads = globals()["num_heads"]

    pe_src_embeds = get_embedding_outputs(src_data=src_data,  state_dict=state_dict, max_seq_len=max_seq_len, d_model = d_model)
    print("###"*25)
    print("### Decoder Start ###")
    print()

    x_dec = pe_src_embeds
    # Every layer preserves the (tgt_len, bsz, embed_dim) shape, so read it once.
    tgt_len, bsz, _ = x_dec.shape
    for lno in range(num_decoder_layers):
        attn_dec_output, _ = decoder_block_self_attn_output(x_dec, state_dict, layer_num = lno, num_heads=num_heads, need_weights = False, tgt_mask=tgt_mask)
        x_dec = dec_post_self_attn(x_dec, attn_dec_output, state_dict, layer_num = lno , bsz = bsz, tgt_len = tgt_len)

    print("### Decoder End ###")

    return feef_fwd_transformer(x_dec, state_dict)
    

In [37]:
# NOTE(review): this rebinds `F` from torch.nn.functional (imported in the first
# cell) to the torch.functional module, which does not provide the nn.functional
# ops (e.g. relu/softmax). Code after this cell that needs those should call
# torch.nn.functional explicitly.
from torch import functional as F

need_weights = False

src_mask = None

tgt_mask = None

memory_mask = None

# Should equal d_model from the model-construction cell (6).
embed_dim = 6

# Head count used by the manual pass below, which reads the notebook-level
# `num_heads` when it is not passed explicitly.
num_heads = 2

max_seq_len = 5

num_decoder_layers = 2
 

def generate_square_subsequent_mask(self, tgt):
    """Standalone copy of ``DecoderOnlyTransformer.generate_square_subsequent_mask``.

    ``self`` is unused. Returns an ``(n, n)`` bool mask with ``n = tgt.size(0)``
    whose entries strictly above the diagonal are ``True`` (future positions).
    """
    n = tgt.size(0)
    return torch.ones(n, n).triu(diagonal=1).bool()

In [38]:
def generate_mask(src):
    """Build the causal mask for sequence-first ``src``.

    Returns an ``(n, n)`` bool tensor with ``n = src.size(0)``. Entry ``[i, j]``
    is ``True`` when ``j > i``, i.e. a future position that step ``i`` must not
    attend to.
    """
    n = src.size(0)
    rows = torch.arange(n).unsqueeze(1)
    cols = torch.arange(n).unsqueeze(0)
    return cols > rows

# Causal mask sized for the shortened input src_data[:-1, :] (True = future position).
tgt_mask = generate_mask(src_data[:-1, :])

In [39]:
# Manual forward pass on the model's input (src_data[:-1, :]) using the deep-copied
# initial weights. num_heads is not passed here, so the notebook-level
# `num_heads` global is used.
final_op = get_all_intermediate_outputs_mask(src_data[:-1, :], state_dict = state_dict1 , num_decoder_layers = num_decoder_layers, d_model=d_model,  d_ff = d_ff, tgt_mask = tgt_mask, max_seq_len = src_data[:-1, :].shape[0])


Source sentence embedding
Word index: 2, Embedding: tensor([ 0.1198,  1.2377,  1.1168, -0.2473, -1.3527, -1.6959])
Word index: 15, Embedding: tensor([-1.8817, -0.0497, -1.0450, -0.9565,  0.0335,  0.7101])
Word index: 9, Embedding: tensor([ 0.5433, -0.3952, -0.4462,  0.7440,  1.5210,  3.4105])
Word index: 16, Embedding: tensor([ 1.6459, -1.3602,  0.3446,  0.5199, -2.6133, -1.6965])
Word index: 5, Embedding: tensor([ 0.9463, -0.8437, -0.6136,  0.0316, -0.4927,  0.2484])
Word index: 4, Embedding: tensor([ 0.7502, -0.5855, -0.1734,  0.1835,  1.3894,  1.5863])
Word index: 12, Embedding: tensor([-1.6293, -0.5497, -0.4798, -0.4997, -1.0670,  1.1149])
Word index: 11, Embedding: tensor([ 1.1108,  1.2899, -1.4782,  2.5672, -0.4731,  0.3356])
Word index: 7, Embedding: tensor([-0.2897,  0.0525,  0.5229,  2.3022, -1.4689, -1.5867])
Word index: 10, Embedding: tensor([-1.5312, -1.2341,  1.8197, -0.5515, -0.5692,  0.9200])
Word index: 12, Embedding: tensor([-1.6293, -0.5497, -0.4798, -0.4997, -1.0670,