<a href="https://colab.research.google.com/github/Krithika25/A-study-on-the-repercussions-of-the-Covid-19-pandemic-in-the-mental-health/blob/main/Transformer_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

class SelfAttention(nn.Module):

  def __init__(self, temp, dropout=0.1):
    super().__init__()
    self.temp=temp
    self.dropout=nn.Dropout(dropout)
  
  def forward(self, query, key, value, mask):

    # Normalize it with (key_dim)^(1/2)
    q=query/self.temp

    # dot_p of shape query_size x query_size
    dot_p=torch.matmul(q, key.transpose(1,2))

    if mask is not None:
        
      dot_p=dot_p.masked_fill(mask==0,float(-1e9))
    
    attention=self.dropout(F.softmax(dot_p))

    # final_output is of same shape as value
    final_output=torch.matmul(attention, value)
  
    return final_output, attention


# Multihead Attention
class MultiheadAttention(nn.Module):

  def __init__(self, heads, model_dim, key_dim, value_dim, dropout=0.1):
    super().__init__()
    self.heads=heads
    self.model_dim=model_dim
    self.batch_size=1
    #self.head_dim=model_dim // heads
    self.key_dim=key_dim
    self.value_dim=value_dim
    self.dropout=nn.Dropout(dropout)
    self.attention=SelfAttention(temp=key_dim**(1/2))
    self.q_s=nn.Linear(model_dim, heads*key_dim, bias=False)
    self.k_s=nn.Linear(model_dim, heads*key_dim, bias=False)
    self.v_s=nn.Linear(model_dim, heads*value_dim, bias=False)
    self.fc=nn.Linear(heads*self.key_dim, model_dim, bias=False)

    self.norm=nn.LayerNorm(model_dim, eps=1e-6)

  def forward(self, query, key, value, mask=None, batch_size=10):

    #Divide the input queries, keys and values into heads to perform Attention
    
    q = self.q_s(query).view(batch_size*self.heads, query.size(1), self.key_dim) # q is of shape [heads, query_size, model_dims]
    k = self.k_s(key).view(batch_size*self.heads, key.size(1),  self.key_dim) # k is of shape [heads, key_size, model_dims]
    v = self.v_s(value).view(batch_size*self.heads, value.size(1), self.value_dim) # v is of shape [heads, query_size, model_dims]

    que, attent=self.attention(q,k,v,mask)
    que=que.transpose(1, 2).contiguous().view(query.size(0), query.size(1), self.heads*self.key_dim) # que is of shape [query_size, model_dims]

    #Apply Dropout
    que = self.dropout(self.fc(que))
    
    # Skip connection
    que += query

    #Perform Layer Normalization
    que = self.norm(que)

    return que, attent

#Feed Forward Neural Network
class FeedForward(nn.Module):

  def __init__(self, input, hidden, dropout=0.1):
    super().__init__()
    #Feed Forward Neural network with given input and hidden layers
    self.feed_forward=nn.Sequential(
        nn.Linear(input, hidden),
        nn.ReLU(),
        nn.Linear(hidden, input)
    )
    self.dropout=nn.Dropout(dropout)
    self.norm=nn.LayerNorm(input, eps=1e-6)
  
  def forward(self, x):
    #Implement Feed Forward network layer
    res=self.feed_forward(x)
    res=self.dropout(res)

    #Including Skip connections
    res+=x

    #Apply Layer Normalization
    res=self.norm(res)
    return res

class DecoderBlock(nn.Module):

  def __init__(self, model_dim, hidden_dim, head, value_dim, key_dim, dropout=0.1):
    super(DecoderBlock, self).__init__()
    self.key_dim=key_dim
    self.model_dim=model_dim
    self.attention=MultiheadAttention(head, model_dim, key_dim, value_dim, dropout)
    self.global_attention=MultiheadAttention(head, model_dim, key_dim, value_dim, dropout)
    self.cross_attention=MultiheadAttention(head, model_dim, key_dim, value_dim, dropout)
    self.ffn=FeedForward(model_dim, hidden_dim,dropout)
    self.output_ffn=nn.Linear(10, model_dim*model_dim, bias=False)
    self.dropout=nn.Dropout(dropout)

  def forward(self, dec_input, output, batch_size, attn_mask=None):

    #The Multihead Self Attention block implemented on the Decoder part
    dec_output, dec_slf_attn = self.attention(dec_input, dec_input, dec_input, mask=attn_mask)

    #Cross_Attention
    enc_output=self.output_ffn(output.float())
    enc_output=enc_output.view(batch_size, self.model_dim, self.model_dim)
    dec_output,dec_crs_attn=self.cross_attention(dec_output, enc_output, enc_output)

    #Apply Feed Forward Layer
    dec_output = self.ffn(dec_output)
    return dec_output, dec_slf_attn


class Decoder(nn.Module):
  def __init__(self, n_layers, heads, key_dim, value_dim, model_dim, hidden_dim, dropout=0.1):
    super().__init__()
    self.dropout=nn.Dropout(dropout)
    self.decoder_layer=nn.ModuleList([DecoderBlock(model_dim, hidden_dim, heads, value_dim, key_dim, dropout) for i in range(n_layers)])
    self.norm=nn.LayerNorm(model_dim, eps=1e-16)
    self.model_dim=model_dim
  
  def forward(self, tar_seq, trg_mask, output, batch_size, return_attns=False):
    dec_input = self.dropout(tar_seq)
    dec_input = self.norm(dec_input)
    for dec_layer in self.decoder_layer:
            #Call the decoder block
            dec_output, dec_slf_attn = dec_layer(dec_input, output, batch_size, attn_mask=trg_mask)
    return dec_output

class Transformer(nn.Module):
  
  def __init__(self, model_dim=30, hidden_dim=2048, n_layers=6, heads=6, key_dim=5, value_dim=5, dropout=0.1):
    super().__init__()

    self.model_dim=model_dim

    #Decoder block
    self.decoder = Decoder(n_layers, heads, key_dim, value_dim, model_dim, hidden_dim, dropout)

    #Project the decoder output
    self.trg_img=nn.Linear(1, 255,bias=False)

    #Calculate probabilities
    self.soft=nn.Softmax(dim=2)

  def forward(self, trg_seq, one_hot, batch_size=10):

    #Get the mask
    trg_mask = get_trg_mask(trg_seq)

    #Decoder Block
    dec_output = self.decoder(trg_seq.float(), trg_mask, one_hot, batch_size)

    #Project the decoder output to linear
    img=self.trg_img(dec_output.view(batch_size, self.model_dim*self.model_dim, 1))
    img=img.view(batch_size, self.model_dim*self.model_dim, 255)

    #Predict the value 0-255 for each pixel
    pred=self.soft(img)

    #Sample the probabilities using top-k
    ind=torch.topk(pred, k=50, dim=2)

    #Assign the pixel value that has maximum probability for each pixel in each image
    img=ind[1][:,:,0].view(batch_size, self.model_dim, self.model_dim)

    return img.double(), ind

# Calculate Loss using Categorical Cross_entropy
def calc_loss(expected, pred):
  loss = F.cross_entropy(pred, expected, reduction='sum')
  return loss

class CustomizeOptim():
    def __init__(self, optimizer, model_dims=30, n_warmup_steps=400):
        self._optimizer = optimizer
        self.d_model = model_dims
        self.n_warmup_steps = n_warmup_steps
        self.n_steps = 0

    def step_and_update_lr(self):
        self.update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        d_model = self.d_model
        n_steps, n_warmup_steps = self.n_steps, self.n_warmup_steps
        return (d_model ** -0.5) * min(n_steps ** (-0.5), n_steps * n_warmup_steps ** (-1.5))

    def update_learning_rate(self):
        self.n_steps += 1
        lr = self._get_lr_scale()
        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

# Mask the future values
def get_trg_mask(seq):
    _, sz_b, len_s = seq.size()
    subsequent_mask = (1 - torch.triu(torch.ones((1, sz_b, sz_b), device=seq.device), diagonal=1)).bool()
    #Get 3 random attention rows
    tokens=random.sample(range(0,len(subsequent_mask[0])), 3)
    for i in range(len(tokens)):
        subsequent_mask[0][tokens[i]]=True # Set them to True to remove mask
    return subsequent_mask


def train(Res, y, epochs, optimizer):

  #Create one-hot vector
  one_hot=torch.zeros(len(y), 10)
  for i in range(len(y)):
    one_hot[i][y[i]]=1.0
  
  one_hot.requires_grad=True
  
  # Training
  for i in range(epochs):
    optimizer.zero_grad()
    A, B= t(Res, one_hot)

    # Calculate loss
    loss=calc_loss(A, Res)
    print("The loss at epoch ", i, " is ", loss.item())
    loss.backward()
    optimizer.step_and_update_lr()
  return A, B


A=0
B=0
t=Transformer()

if __name__=='__main__':
  # Load the dataset
  epochs=1000
  X=pd.read_csv('test_image.csv')
  y=pd.read_csv('test_label.csv')
  X['Label']=y
  data=pd.DataFrame(columns=X.columns)
  
  #Take a sample from each class
  for i in range(0,10):
    label=X.loc[X['Label']==i].sample(n=1, random_state=1)
    data=pd.concat([data, label])
    x=data.drop('Label',1)

  #Store it as a seperate file
  x.to_csv('data.csv', index=False)
  x=pd.read_csv('data.csv')
  A=np.array(x)
  A=A/1.0
  # Reshape
  A=A.reshape(len(A),28,28)
  
  
  # Pad it to 30x30 to make it suitable for multi-head attention
  Res=[np.pad(x, [(1, 1), (1, 1)], mode='constant') for x in A]
  device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
  Res=torch.tensor(Res, requires_grad=True).to(device)
  """
  Sequence length=30
  Model dimensions=30
  Heads=6
  """

  # Initializing Optimizer for training
  optimizer = CustomizeOptim(torch.optim.Adam(t.parameters(), betas=(0.9, 0.98), eps=1e-09) , 30)
  optimizer.zero_grad()
  y=[0,1,2,3,4,5,6,7,8,9]
  A,B=train(Res, y, epochs, optimizer)



The loss at epoch  0  is  96761360.34737034
The loss at epoch  1  is  96530500.96668693
The loss at epoch  2  is  96222000.7820535
The loss at epoch  3  is  96753530.41295753
The loss at epoch  4  is  96593526.45813347
The loss at epoch  5  is  97091484.75579935
The loss at epoch  6  is  97275523.23395145
The loss at epoch  7  is  97459437.16871452
The loss at epoch  8  is  96434764.4278722
The loss at epoch  9  is  96790430.27695534
The loss at epoch  10  is  97331074.43365873
The loss at epoch  11  is  97387572.62988098
The loss at epoch  12  is  96333714.53015077
The loss at epoch  13  is  96446389.99339753
The loss at epoch  14  is  96619412.67159665
The loss at epoch  15  is  95905262.06590624
The loss at epoch  16  is  97500411.46064006
The loss at epoch  17  is  96495309.15276065
The loss at epoch  18  is  97051391.16215828
The loss at epoch  19  is  97501431.20611137
The loss at epoch  20  is  97355982.68532906
The loss at epoch  21  is  97831037.7469835
The loss at epoch  22  