In [1]:
import torch
import torch.nn as nn
import math

In [None]:
## Defining all the transformer classes



class AttentionHead(nn.Module):
    """A single scaled dot-product attention head.

    Args:
        input_size: feature size of the incoming token vectors.
        query_size: output size of the query and key projections (W_q, W_k).
        value_size: output size of the value projection (W_v).
        self_regress: if True, query position i may only attend to key
            positions <= i (upper-triangular causal mask).
        enc_dec: if True, this is an encoder-decoder attention head: queries
            are projected from ``input`` while keys/values are projected from
            the ``K_inp`` / ``V_inp`` arguments of ``forward``.
    """

    def __init__(self,input_size,query_size,value_size,self_regress=False,enc_dec=False):
        super().__init__()
        self.wq=nn.Linear(input_size,query_size,bias=False) # W_q matrix
        self.wk=nn.Linear(input_size,query_size,bias=False) # W_k matrix
        self.wv=nn.Linear(input_size,value_size,bias=False) # W_v matrix
        self.ec=enc_dec # True -> encoder-decoder attention (K/V come from K_inp/V_inp)
        self.self_regress=self_regress

    def SelfAttention(self,Q,K,V,mask):
        """Scaled dot-product attention with an additive mask.

        Shapes (N = batch, Lq = query length, Lk = key length):
            Q    -> (N, Lq, eq)
            K    -> (N, Lk, eq)
            V    -> (N, Lk, ev)
            mask -> broadcastable to (N, Lq, Lk); 0 where attention is allowed,
                    -inf where it is blocked.
        Returns:
            (N, Lq, ev) attention-weighted combination of V.
        """
        key_size=K.shape[-1]
        scores=torch.matmul(Q,K.transpose(-2,-1))
        # the mask is added before dividing by sqrt(d_k); with a 0 / -inf mask
        # this equals the usual "scale, then mask" formulation
        attention_weights=torch.softmax((scores+mask)/math.sqrt(key_size),dim=-1)
        return torch.matmul(attention_weights,V)

    def _build_mask(self,padding_mask,q_len,k_len,dtype,device):
        """Build the additive attention mask of shape (N, q_len, k_len).

        Entries are 0 where attention is allowed and -inf where it is blocked.
        Key positions whose ``padding_mask`` entry is non-zero are blocked for
        every query; with ``self_regress`` every key position after the query
        position is blocked as well. The mask is created directly on ``device``
        with ``dtype`` so it can be added to the attention scores without a
        device/dtype mismatch.
        """
        blocked=padding_mask.to(device=device).bool().unsqueeze(1) # (N,1,k_len)
        if self.self_regress:
            causal=torch.ones(q_len,k_len,dtype=torch.bool,device=device).triu(diagonal=1)
            blocked=blocked|causal # broadcasts to (N,q_len,k_len)
        else:
            blocked=blocked.expand(-1,q_len,-1)
        mask=torch.zeros(blocked.shape,dtype=dtype,device=device)
        return mask.masked_fill(blocked,float('-inf'))

    def forward(self,input,padding_mask,K_inp=None,V_inp=None):
        """Apply the attention head.

        Args:
            input: (N, Lq, input_size) tokens the queries are computed from.
            padding_mask: (N, Lk) mask over the key positions,
                0 = pay attention, 1 = do not pay attention.
            K_inp, V_inp: (N, Lk, input_size) key/value sources; only used when
                the head was built with ``enc_dec=True``, otherwise ``input``
                is used for both.
        Returns:
            tuple ``(out, Q, K, V)`` with ``out`` of shape (N, Lq, value_size)
            and the projected Q, K, V matrices.
        """
        if not self.ec:
            # plain self-attention: keys and values come from the query tokens
            K_inp=input
            V_inp=input
        Q=self.wq(input)
        K=self.wk(K_inp)
        V=self.wv(V_inp)

        mask=self._build_mask(padding_mask,Q.shape[1],K.shape[1],Q.dtype,Q.device)
        out=self.SelfAttention(Q,K,V,mask)
        return out,Q,K,V

        
        



In [None]:
class Multi_HeadAttention(nn.Module):
    """Multi-head attention built from ``head_count`` AttentionHead modules.

    Every head sees the same inputs; their (N, L, value_size) outputs are
    concatenated on the feature axis and projected back to ``value_size`` by
    ``finLinear``.

    Args:
        head_count: number of attention heads (an int).
        input_size, query_size, value_size: forwarded to each AttentionHead.
        self_regress: forwarded to each head (causal masking).
        enc_dec: forwarded to each head; True for encoder-decoder attention.
    """

    def __init__(self,head_count,input_size,query_size,value_size,self_regress=False,enc_dec=False):
        super().__init__()
        self.finLinear=nn.Linear(head_count*value_size,value_size)
        self.ec=enc_dec
        # nn.ModuleList (not a plain list) so the heads' parameters are
        # registered: they appear in .parameters()/state_dict() and follow .to()
        self.heads=nn.ModuleList(
            AttentionHead(input_size,query_size,value_size,self_regress,enc_dec)
            for _ in range(head_count)
        )

    def forward(self,input,padding_mask,K_inp=None,V_inp=None):
        """Run all heads and merge their outputs.

        Args:
            input: (N, L, input_size) query-side tokens.
            padding_mask: (N, Lk) key padding mask, 0 = pay attention,
                1 = do not pay attention.
            K_inp, V_inp: key/value sources, only used by encoder-decoder heads.
        Returns:
            (N, L, value_size) tensor; the per-head K/V matrices are not returned.
        """
        # each head returns (out, Q, K, V); only the attention output is merged
        head_outputs=[head(input,padding_mask,K_inp,V_inp)[0] for head in self.heads]
        return self.finLinear(torch.cat(head_outputs,dim=-1))

In [None]:
# one encoder block
# take care of passing ks and vs to decoder
class EncoderBlock(nn.Module):
    """One encoder layer: multi-head self-attention followed by a feed-forward
    layer, each wrapped in a residual connection plus layer normalisation.

    NOTE(review): the same ``self.LN`` instance normalises both sub-layers, so
    they share affine parameters, and ``feedForward`` is a single Linear layer
    with no activation; the notebook's TODO cell lists both as open items.
    """

    def __init__(self,input_size,head_count):
        super().__init__()
        self.LN=nn.LayerNorm(input_size)
        self.feedForward=nn.Linear(input_size,input_size)
        # query, key and value sizes all equal the model width
        self.multiHAttention=Multi_HeadAttention(head_count,input_size,input_size,input_size)

    def forward(self,inputs,padding_mask):
        """Encode one layer.

        Args:
            inputs: (N, L, input_size) token representations.
            padding_mask: (N, L) padding mask, 1 marks positions to ignore.
        Returns:
            tuple ``(encoded, padding_mask)``; the mask is passed through
            unchanged so the next block can reuse it.
        """
        attended=self.LN(inputs+self.multiHAttention(inputs,padding_mask))
        encoded=self.LN(attended+self.feedForward(attended))
        return encoded,padding_mask


In [None]:
class DecoderBlock(nn.Module):
    """One decoder layer: causal self-attention, encoder-decoder attention and a
    feed-forward layer, each followed by a residual connection and layer norm.

    NOTE(review): one ``self.LN`` instance is shared by all three sub-layers and
    ``feedForward`` is a single Linear layer without an activation.
    """

    def __init__(self,input_size,head_count):
        super().__init__()
        self.LN=nn.LayerNorm(input_size)
        self.feedForward=nn.Linear(input_size,input_size)
        # decoder self-attention is causal (self_regress=True)
        self.multiHAttention=Multi_HeadAttention(head_count,input_size,input_size,input_size,self_regress=True)
        # cross attention: queries from the decoder, keys/values from K_inp/V_inp
        self.encdecAttention=Multi_HeadAttention(head_count,input_size,input_size,input_size,enc_dec=True)

    def forward(self,inputs,padding_mask,K_inp,V_inp,padding_mask_enc):
        """Decode one layer.

        Args:
            inputs: (N, L, input_size) decoder token representations.
            padding_mask: (N, L) decoder padding mask (1 = ignore).
            K_inp, V_inp: (N, L_enc, input_size) key/value sources for the
                encoder-decoder attention.
            padding_mask_enc: (N, L_enc) encoder padding mask used by the
                encoder-decoder attention.
        Returns:
            tuple ``(decoded, padding_mask, K_inp, V_inp, padding_mask_enc)``
            with every argument but the hidden states passed through unchanged.
        """
        self_attended=self.LN(inputs+self.multiHAttention(inputs,padding_mask))
        cross=self.encdecAttention(self_attended,padding_mask_enc,K_inp=K_inp,V_inp=V_inp)
        cross_attended=self.LN(self_attended+cross)
        decoded=self.LN(cross_attended+self.feedForward(cross_attended))
        return decoded,padding_mask,K_inp,V_inp,padding_mask_enc

In [None]:
class EncoderStack(nn.Module):
    """A stack of ``layers`` EncoderBlock modules applied one after another.

    The blocks live in an ``nn.Sequential`` only so their parameters are
    registered (state-dict keys ``es.0.*``, ``es.1.*``, ...). ``forward`` walks
    them explicitly because each block takes and returns two values
    (hidden states, padding mask), whereas ``nn.Sequential.forward`` accepts a
    single input.
    """

    def __init__(self,layers,input_size,head_count):
        super().__init__()
        self.es=nn.Sequential(*(EncoderBlock(input_size,head_count) for _ in range(layers)))

    def forward(self,inputs,padding_mask):
        """Run the encoder.

        Args:
            inputs: (N, L, input_size) token representations (positional
                information is expected to be included already).
            padding_mask: (N, L), 0 = pay attention, 1 = do not pay attention.
        Returns:
            tuple ``(hidden, padding_mask)``; ``hidden`` is the last block's
            output, or ``inputs`` itself when ``layers == 0``.
        """
        hidden=inputs
        for block in self.es:
            hidden,padding_mask=block(hidden,padding_mask)
        return hidden,padding_mask


In [None]:
class DecoderStack(nn.Module):
    """A stack of ``layers`` DecoderBlock modules applied one after another.

    The blocks are stored in an ``nn.Sequential`` for parameter registration
    (state-dict keys ``ds.0.*``, ...), but ``forward`` iterates them explicitly
    because each block takes five arguments, which ``nn.Sequential.forward``
    cannot pass along.
    """

    def __init__(self,layers,input_size,head_count):
        super().__init__()
        self.ds=nn.Sequential(*(DecoderBlock(input_size,head_count) for _ in range(layers)))

    def forward(self,inputs,padding_mask,enc_outputs,padding_mask_enc):
        """Run the decoder.

        Args:
            inputs: (N, L, input_size) decoder token representations.
            padding_mask: (N, L) decoder padding mask (1 = ignore).
            enc_outputs: (N, L_enc, input_size) encoder output, used as both
                keys and values of every encoder-decoder attention.
            padding_mask_enc: (N, L_enc) encoder padding mask.
        Returns:
            (N, L, input_size) hidden states of the last block (``inputs`` when
            ``layers == 0``). A plain tensor is returned because
            ``Transformer_custom`` feeds the result directly into a Linear layer.
        """
        hidden=inputs
        for block in self.ds:
            # DecoderBlock returns (hidden, padding_mask, K, V, padding_mask_enc);
            # only the hidden states change from block to block
            hidden=block(hidden,padding_mask,enc_outputs,enc_outputs,padding_mask_enc)[0]
        return hidden

In [None]:
class Transformer_custom(nn.Module):
    """Encoder-decoder transformer that maps source token ids to a probability
    distribution over the target vocabulary for every target position.

    Args:
        layers: number of encoder blocks and of decoder blocks.
        embedding_size: model width (embedding and hidden size).
        head_count: attention heads per multi-head attention layer.
        inp_vocab_size: source vocabulary size (index 0 is the pad token).
        out_vocab_size: target vocabulary size (index 0 is the pad token).

    NOTE(review): no positional encoding is added to the embeddings yet (marked
    as TODO in forward). ``forward`` returns softmax probabilities, not logits,
    so do not combine it with ``nn.CrossEntropyLoss``, which expects logits.
    """

    def __init__(self,layers,embedding_size,head_count,inp_vocab_size,out_vocab_size):
        super().__init__()
        # token embeddings; padding_idx=0 marks index 0 as the pad token
        self.embeddingsEnc=nn.Embedding(inp_vocab_size,embedding_size,padding_idx=0)
        self.embeddingsDec=nn.Embedding(out_vocab_size,embedding_size,padding_idx=0)
        self.encoder=EncoderStack(layers,embedding_size,head_count)
        self.decoder=DecoderStack(layers,embedding_size,head_count)
        # projection from model width to target-vocabulary scores
        self.toVocab=nn.Linear(embedding_size,out_vocab_size)
        self.sft=nn.Softmax(dim=2)

    def forward(self,inputs,inp_padding,outputs,out_padding):
        """
        Args:
            inputs: (N, Ls) source token ids.
            inp_padding: (N, Ls) source padding mask, 1 = do not pay attention.
            outputs: (N, Lt) target token ids fed to the decoder.
            out_padding: (N, Lt) target padding mask, 1 = do not pay attention.
        Returns:
            (N, Lt, out_vocab_size) softmax probabilities.
        """
        # TODO: add positional encodings to both embedding sequences
        encoded=self.encoder(self.embeddingsEnc(inputs),inp_padding)
        memory=encoded[0]
        memory_padding=encoded[1]
        hidden=self.decoder(self.embeddingsDec(outputs),out_padding,memory,memory_padding)
        return self.sft(self.toVocab(hidden))
        


In [None]:
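# Open questions / TODOs:
# - Must input_size == query_size == value_size? (Every block currently passes the same size for all three.)
# - Add activation functions after the feed-forward linear layers.
# - Use separate LayerNorm modules instead of one LayerNorm shared by all sub-layers of a block.
# - Masking of encoder states in encoder-decoder attention (two masks? self-regressive?) -- resolved.
# - How the key and value matrices are propagated into the decoder -- resolved.
# - Whether the encoder-decoder attention is self-regressive -- resolved.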

In [18]:
# testing-------------------------------
# N -> Batch Size
# L -> Sequence Length
# Q -> (N,L,eq)
# K -> (N,L,ek)
# V -> (N,L,ev)
# mask -> (N,L,L)
# out -> (N,L,ev)

# a=torch.tensor([[0,0,0,0,1,1],
#                 [0,0,0,1,1,1]],dtype=torch.float32)
# c=torch.tensor([[0,0,1,1,1,1],
#                 [0,0,0,1,1,1]],dtype=torch.float32)
# a=torch.unsqueeze(a,1)
# c=torch.unsqueeze(c,1)
# print(a.shape)
# b=torch.nan_to_num(a.repeat(1,4,1)*float('-inf'),nan=0,neginf=float('-inf'))
# d=torch.nan_to_num(c.repeat(1,4,1)*float('-inf'),nan=0,neginf=float('-inf'))
# f=torch.add(b,d)
# print(f)
# sft=nn.Softmax(dim=2)
# print(sft(f))
# a=[1,2,3]
# for k,v in enumerate(a):
#     print(k,v)

0 1
1 2
2 3
