<a href="https://colab.research.google.com/github/ShyamSundhar1411/My-ML-Notebooks/blob/master/Transformers/Attention_Is_All_You_Need.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np
import pandas as pd

In [1]:
!pip install torch torchvision

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


## Attention Mechanism

In [3]:
import torch
import torch.nn as nn

In [6]:
class SelfAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The embedding is split evenly into ``heads`` slices of size ``head_dim``.
  Each slice is projected by value/key/query linear maps of shape
  ``head_dim -> head_dim``. These maps act on the last axis, so every head
  shares the same projection weights. Each head attends independently, and
  the concatenated head outputs are mixed by a final
  ``embed_size -> embed_size`` linear layer.

  Args:
    embed_size: Size of the last dimension of the inputs and of the output.
    heads: Number of attention heads; must divide ``embed_size`` exactly.
  """

  def __init__(self,embed_size,heads):
    super(SelfAttention,self).__init__()
    self.embed_size = embed_size
    self.heads = heads
    self.head_dim = embed_size//heads

    assert(self.head_dim*heads == embed_size), "Embed Size need to be divisible by Heads"
    # Per-head projections, applied to the trailing head_dim axis.
    self.values = nn.Linear(self.head_dim,self.head_dim,bias = False)
    self.keys = nn.Linear(self.head_dim,self.head_dim,bias = False)
    self.queries = nn.Linear(self.head_dim,self.head_dim,bias = False)
    self.fully_connected_output = nn.Linear(embed_size,embed_size)

  def forward(self,values,keys,query,mask):
    """Attend from ``query`` over ``keys``/``values``.

    Args:
      values: Tensor of shape (N, value_len, embed_size).
      keys: Tensor of shape (N, key_len, embed_size); key_len must equal
        value_len, because both are contracted over the same axis.
      query: Tensor of shape (N, query_len, embed_size).
      mask: ``None``, or a tensor broadcastable to
        (N, heads, query_len, key_len). Positions where ``mask == 0``
        receive (effectively) zero attention weight.

    Returns:
      Tensor of shape (N, query_len, embed_size).
    """
    n_samples = query.shape[0]
    value_len,key_len,query_len = values.shape[1],keys.shape[1],query.shape[1]

    # Split the embedding into head pieces: (N, len, heads, head_dim).
    values = values.reshape(n_samples,value_len,self.heads,self.head_dim)
    keys = keys.reshape(n_samples,key_len,self.heads,self.head_dim)
    queries = query.reshape(n_samples,query_len,self.heads,self.head_dim)

    # Learned per-head projections of V, K and Q.
    values = self.values(values)
    keys = self.keys(keys)
    queries = self.queries(queries)

    # Raw scores, shape (N, heads, query_len, key_len).
    energy = torch.einsum("nqhd,nkhd->nhqk",[queries,keys])
    if mask is not None:
      # A huge negative score makes softmax assign these positions ~0 weight.
      energy = energy.masked_fill(mask==0,float("-1e20"))
    # Scale by sqrt(d_k), the per-head key size (head_dim), not the full width.
    attention = torch.softmax(energy/(self.head_dim**0.5),dim = 3)
    # attention (N, heads, query_len, key_len) x values (N, value_len, heads, head_dim)
    # -> (N, query_len, heads, head_dim), then concatenate the heads.
    output = torch.einsum("nhql,nlhd->nqhd",[attention,values]).reshape(n_samples,query_len,self.embed_size)
    output = self.fully_connected_output(output)
    return output


## Transformer Block

In [8]:
class TransformerBlock(nn.Module):
  """Post-norm Transformer layer built from attention and feed-forward sub-layers.

  Each sub-layer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``. The
  attention residual uses ``query`` as ``x``.

  Args:
    embed_size: Model width (last dimension of the inputs and the output).
    heads: Number of attention heads passed to ``SelfAttention``.
    dropout: Dropout probability applied to each sub-layer's output.
    forward_expansion: Hidden width of the feed-forward network, as a
      multiple of ``embed_size``.
  """

  def __init__(self,embed_size,heads,dropout,forward_expansion):
    super(TransformerBlock,self).__init__()
    self.attention = SelfAttention(embed_size,heads)
    self.norm1 = nn.LayerNorm(embed_size)
    self.norm2 = nn.LayerNorm(embed_size)
    self.feed_forward = nn.Sequential(
        nn.Linear(embed_size,forward_expansion*embed_size),
        nn.ReLU(),
        nn.Linear(forward_expansion*embed_size,embed_size)
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self,value,key,query,mask):
    """Run attention then feed-forward, each with residual dropout and LayerNorm.

    Args:
      value, key, query, mask: Forwarded unchanged to ``self.attention``.
        The residual ``query + attention`` requires the attention output to
        have the same shape as ``query``. Presumably that is
        (N, query_len, embed_size) — confirm against ``SelfAttention``.

    Returns:
      The normalised output of the feed-forward sub-layer.
    """
    attention = self.attention(value,key,query,mask)
    # Residual dropout: drop the sub-layer output *before* the residual add
    # and normalisation, so the skip path itself is never dropped.
    x = self.norm1(query + self.dropout(attention))
    ff_out = self.feed_forward(x)
    return self.norm2(x + self.dropout(ff_out))

## Encoder Block

In [None]:
class Encoder(nn.Module):
  def __init__(
      self,src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,
      dropout,max_length):
    super(Encoder).__init__()
    self.embed_size = embed_size
    