<a href="https://colab.research.google.com/github/aditya161205/Data-DaVinci/blob/main/Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.optim as optim
import math
import copy

Implementing multi-head attention

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The query, key and value inputs are each passed through a learned linear
    projection, split into ``num_heads`` heads of size ``d_k``, attended per
    head, merged back together and passed through an output projection.

    Args:
        d_mod: Embedding (model) size. Must be divisible by ``num_heads``.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``d_mod`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_mod, num_heads):
        super().__init__()
        if d_mod % num_heads != 0:
            # Otherwise split_heads' view() fails later with an opaque
            # shape error; fail early with a clear message instead.
            raise ValueError(
                f"d_mod ({d_mod}) must be divisible by num_heads ({num_heads})"
            )
        self.d_mod = d_mod  # embedding size
        self.num_heads = num_heads
        self.d_k = d_mod // num_heads  # size of each head

        # Linear projections for keys, values, queries and the final output.
        self.W_k = nn.Linear(self.d_mod, self.d_mod)
        self.W_v = nn.Linear(self.d_mod, self.d_mod)
        self.W_q = nn.Linear(self.d_mod, self.d_mod)
        self.W_o = nn.Linear(self.d_mod, self.d_mod)

    def ScaledDotProduct(self, Q, K, V, mask=None):
        """Return softmax(Q K^T / sqrt(d_k)) V.

        Args:
            Q, K, V: Per-head tensors as produced by ``split_heads``.
            mask: Optional tensor broadcastable to the score matrix.
                Positions where ``mask == 0`` are excluded from attention.
                ``None`` (the default) disables masking.
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # A 0 in the mask means "ignore this position": the score is set
            # to -inf so softmax assigns it zero probability.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn_prob = torch.softmax(scores, dim=-1)
        return torch.matmul(attn_prob, V)

    def split_heads(self, x):
        """Reshape (batch, seq, d_mod) into (batch, num_heads, seq, d_k)."""
        batch_size, seq_length, _ = x.size()
        return x.view(
            batch_size, seq_length, self.num_heads, self.d_k
        ).transpose(1, 2)

    def combine_heads(self, x):
        """Reshape (batch, num_heads, seq, d_k) back into (batch, seq, d_mod)."""
        batch_size, _, seq_length, _ = x.size()
        # contiguous() is required because transpose() returns a strided view
        # that view() cannot reshape directly.
        return x.transpose(1, 2).contiguous().view(
            batch_size, seq_length, self.d_mod
        )

    def forward(self, Q, K, V, mask=None):
        """Apply multi-head attention.

        Args:
            Q, K, V: Tensors of shape (batch, seq, d_mod).
            mask: Optional mask forwarded to ``ScaledDotProduct``.

        Returns:
            Tensor of shape (batch, query_seq, d_mod).
        """
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_val = self.ScaledDotProduct(Q, K, V, mask)
        return self.W_o(self.combine_heads(attn_val))


In [4]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied to the last dimension.

    Computes ``l2(ReLU(l1(x)))``, expanding from ``d_mod`` to ``d_ff`` and
    projecting back to ``d_mod``.

    Args:
        d_mod: Input and output feature size.
        d_ff: Hidden (inner) feature size.
    """

    def __init__(self, d_mod, d_ff):
        super().__init__()
        self.l1 = nn.Linear(d_mod, d_ff)
        self.l2 = nn.Linear(d_ff, d_mod)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return the feed-forward transform of ``x`` (last dim ``d_mod``)."""
        return self.l2(self.activation(self.l1(x)))

The positional encoding code below is adapted from a Medium article.

In [5]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings.

    Even feature columns hold ``sin(position * div_term)`` and odd columns
    hold ``cos(position * div_term)``. The table is precomputed once for
    ``max_seq_length`` positions and stored as a non-trainable buffer.

    Args:
        d_mod: Embedding size (may be even or odd).
        max_seq_length: Number of positions to precompute.
    """

    def __init__(self, d_mod, max_seq_length):
        super().__init__()

        pe = torch.zeros(max_seq_length, d_mod)
        position = torch.arange(
            0, max_seq_length, dtype=torch.float
        ).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_mod, 2).float() * -(math.log(10000.0) / d_mod)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_mod there is one fewer odd-indexed column than
        # even-indexed ones, so trim div_term to match.
        pe[:, 1::2] = torch.cos(position * div_term[: d_mod // 2])

        # Buffer: moves with the module across devices but is not a parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1)]

In [7]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward.

    Each sub-layer's output goes through dropout, is added to its input
    (residual connection) and is then layer-normalised.

    Args:
        d_mod: Embedding size.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_mod, num_heads, d_ff, dropout):
        super().__init__()
        self.attention = MultiHeadAttention(d_mod, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_mod, d_ff)
        self.norm1 = nn.LayerNorm(d_mod)
        self.dropout1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_mod)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Run ``x`` through the block; ``mask`` is passed to self-attention."""
        # Sub-layer 1: self-attention (query, key and value are all x).
        attended = self.attention(x, x, x, mask)
        residual = x + self.dropout1(attended)
        x = self.norm1(residual)

        # Sub-layer 2: position-wise feed-forward.
        transformed = self.feed_forward(x)
        residual = x + self.dropout2(transformed)
        return self.norm2(residual)


In [8]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, feed-forward.

    Each sub-layer's output goes through dropout, is added to its input
    (residual connection) and is then layer-normalised.

    Args:
        d_mod: Embedding size.
        num_heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward sub-layer.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, d_mod, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_mod, num_heads)
        self.cross_attn = MultiHeadAttention(d_mod, num_heads)
        # NOTE: attribute spelling kept as-is so existing attribute
        # references and state_dict keys are unchanged.
        self.feen_forward = PositionWiseFeedForward(d_mod, d_ff)
        self.norm1 = nn.LayerNorm(d_mod)
        self.dropout1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_mod)
        self.dropout2 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_mod)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, trgt_mask):
        """Run the decoder block.

        Args:
            x: Decoder-side input.
            enc_output: Encoder output attended to by cross-attention.
            src_mask: Mask applied in cross-attention (over encoder positions).
            trgt_mask: Mask applied in decoder self-attention.
        """
        # Sub-layer 1: masked self-attention over the decoder input.
        self_attn_out = self.self_attn(x, x, x, trgt_mask)
        x = self.norm1(x + self.dropout1(self_attn_out))

        # Sub-layer 2: cross-attention. Queries come from the decoder state;
        # keys and values come from the encoder output, so the source mask
        # (not the target mask) applies here.
        cross_attn_out = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout2(cross_attn_out))

        # Sub-layer 3: position-wise feed-forward.
        feed_output = self.feen_forward(x)
        x = self.norm3(x + self.dropout3(feed_output))
        return x