<a href="https://colab.research.google.com/github/ShyamSundhar1411/My-ML-Notebooks/blob/master/Transformers/Attention_Is_All_You_Need.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [100]:
import numpy as np
import pandas as pd

In [101]:
!pip install torch torchvision

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


## Attention Mechanism

In [102]:
import torch
import torch.nn as nn

In [103]:
class SelfAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  The embedding dimension is split into ``heads`` equal slices of size
  ``head_dim``. Each slice is projected with the shared per-head linear
  maps (``values``, ``keys``, ``queries``), attention is computed per head,
  and the concatenated heads go through a final linear layer.

  Args:
    embed_size: Size of the embedding dimension; must be divisible by
      ``heads``.
    heads: Number of attention heads.
  """

  def __init__(self,embed_size,heads):
    super(SelfAttention,self).__init__()
    self.embed_size = embed_size
    self.heads = heads
    self.head_dim = embed_size//heads

    assert(self.head_dim*heads == embed_size), "Embed Size need to be divisible by Heads"
    # Bias-free projections applied to the last (head_dim) axis, so the same
    # weights are shared across all heads.
    self.values = nn.Linear(self.head_dim,self.head_dim,bias = False)
    self.keys = nn.Linear(self.head_dim,self.head_dim,bias = False)
    self.queries = nn.Linear(self.head_dim,self.head_dim,bias = False)
    self.fully_connected_output = nn.Linear(embed_size,embed_size)

  def forward(self,values,keys,query,mask):
    """Compute attention of ``query`` over ``keys``/``values``.

    Args:
      values: Tensor reshaped here as (N, value_len, heads, head_dim).
      keys: Tensor reshaped here as (N, key_len, heads, head_dim).
      query: Tensor reshaped here as (N, query_len, heads, head_dim).
      mask: Optional tensor broadcastable to (N, heads, query_len, key_len);
        positions where it equals 0 are excluded from attention. ``None``
        disables masking.

    Returns:
      Tensor of shape (N, query_len, embed_size).
    """
    n_samples = query.shape[0]
    value_len,key_len,query_len = values.shape[1],keys.shape[1],query.shape[1]

    # Split the embedding into per-head pieces.
    values = values.reshape(n_samples,value_len,self.heads,self.head_dim)
    keys = keys.reshape(n_samples,key_len,self.heads,self.head_dim)
    queries = query.reshape(n_samples,query_len,self.heads,self.head_dim)

    # Apply the learned projections. Without this step the three linear
    # layers created in __init__ would be dead parameters.
    values = self.values(values)
    keys = self.keys(keys)
    queries = self.queries(queries)

    # energy: (N, heads, query_len, key_len)
    energy = torch.einsum("nqhd,nkhd->nhqk",queries,keys)

    if mask is not None:
      # A very large negative score becomes ~0 probability after softmax.
      energy = energy.masked_fill(mask == 0,float("-1e20"))

    # Scale by sqrt(d_k), where d_k is the per-head key dimension.
    attention = torch.softmax(energy/(self.head_dim**(0.5)),dim = 3)

    # attention: (N, heads, query_len, key_len)
    # values:    (N, value_len, heads, head_dim)
    # result:    (N, query_len, heads, head_dim), flattened back to embed_size
    output = torch.einsum("nhql,nlhd->nqhd",attention,values).reshape(n_samples,query_len,self.embed_size)
    return self.fully_connected_output(output)


## Transformer Block

In [104]:
class TransformerBlock(nn.Module):
  """Attention sub-layer followed by a position-wise feed-forward sub-layer.

  Each sub-layer output is added to its input, layer-normalised, and then
  passed through dropout.

  Args:
    embed_size: Size of the embedding dimension.
    heads: Number of attention heads.
    dropout: Dropout probability.
    forward_expansion: Multiplier for the hidden width of the feed-forward
      network relative to ``embed_size``.
  """

  def __init__(self,embed_size,heads,dropout,forward_expansion):
    super().__init__()
    hidden_size = forward_expansion*embed_size

    self.attention = SelfAttention(embed_size,heads)
    self.norm1 = nn.LayerNorm(embed_size)
    self.norm2 = nn.LayerNorm(embed_size)
    self.feed_forward = nn.Sequential(
        nn.Linear(embed_size,hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size,embed_size),
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self,value,key,query,mask):
    """Run the block; the residual connection is taken from ``query``."""
    # Sub-layer 1: attention + residual + norm + dropout.
    attended = self.attention(value,key,query,mask)
    x = self.norm1(attended+query)
    x = self.dropout(x)

    # Sub-layer 2: feed-forward + residual + norm + dropout.
    transformed = self.feed_forward(x)
    return self.dropout(self.norm2(transformed+x))

## Encoder Block

In [120]:
class Encoder(nn.Module):
  """Stack of ``TransformerBlock`` layers over token + position embeddings.

  Args:
    src_vocab_size: Number of rows in the token embedding table.
    embed_size: Size of the embedding dimension.
    num_layers: Number of ``TransformerBlock`` layers.
    heads: Number of attention heads per layer.
    device: Stored on the instance as ``self.device``; ``forward`` does not
      use it and places new tensors on the input's device instead.
    forward_expansion: Feed-forward width multiplier passed to each layer.
    dropout: Dropout probability.
    max_length: Number of rows in the position embedding table, i.e. the
      longest sequence that can be embedded.
  """

  def __init__(self,src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length):
    super(Encoder,self).__init__()
    self.embed_size = embed_size
    self.device = device
    self.word_embedding = nn.Embedding(src_vocab_size,embed_size)
    self.position_embedding = nn.Embedding(max_length,embed_size)

    self.layers = nn.ModuleList(
        [
            TransformerBlock(embed_size,heads,dropout,forward_expansion) for _ in range(num_layers)
        ]
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self,x,mask):
    """Encode a batch of token indices.

    Args:
      x: Integer tensor of shape (n, seq_length) holding token indices.
      mask: Mask forwarded unchanged to every layer's attention.

    Returns:
      Tensor of shape (n, seq_length, embed_size).
    """
    n,seq_length = x.shape
    # Build the position indices directly on the input's device. Relying on
    # the device recorded at construction time breaks when that value is
    # stale or unavailable (e.g. "cuda" on a CPU-only machine).
    positions = torch.arange(0,seq_length,device = x.device).expand(n,seq_length)

    out = self.dropout(self.word_embedding(x)+self.position_embedding(positions))
    for layer in self.layers:
      # Self-attention: value, key and query are all the running output.
      out = layer(out,out,out,mask)
    return out


## Decoder Block

In [121]:
class DecoderBlock(nn.Module):
  """Decoder layer: masked self-attention followed by a ``TransformerBlock``.

  The self-attention output (with residual, norm and dropout) becomes the
  query of the inner ``TransformerBlock``, whose value and key are supplied
  by the caller.

  Args:
    embed_size: Size of the embedding dimension.
    heads: Number of attention heads.
    forward_expansion: Feed-forward width multiplier for the inner block.
    dropout: Dropout probability.
    device: Accepted for signature compatibility; not used by this class.
  """

  def __init__(self,embed_size,heads,forward_expansion,dropout,device):
    super().__init__()
    self.attention = SelfAttention(embed_size,heads)
    self.norm = nn.LayerNorm(embed_size)
    self.transformer_block = TransformerBlock(embed_size,heads,dropout,forward_expansion)
    self.dropout = nn.Dropout(dropout)

  def forward(self,x,value,key,src_mask,target_mask):
    """Attend over ``x`` itself, then over the supplied ``value``/``key``."""
    # Self-attention over the decoder input, restricted by target_mask.
    self_attended = self.attention(x,x,x,target_mask)
    normed = self.norm(self_attended+x)
    query = self.dropout(normed)

    # Attention over value/key using the query built above.
    return self.transformer_block(value,key,query,src_mask)
  

In [133]:
class Decoder(nn.Module):
  """Stack of ``DecoderBlock`` layers with a final vocabulary projection.

  Args:
    trg_vocab_size: Number of rows in the token embedding table and size of
      the output projection.
    embed_size: Size of the embedding dimension.
    num_layers: Number of ``DecoderBlock`` layers.
    heads: Number of attention heads per layer.
    forward_expansion: Feed-forward width multiplier passed to each layer.
    dropout: Dropout probability.
    device: Stored on the instance as ``self.device`` and passed to each
      layer; ``forward`` does not use it and places new tensors on the
      input's device instead.
    max_length: Number of rows in the position embedding table, i.e. the
      longest sequence that can be embedded.
  """

  def __init__(self,trg_vocab_size,embed_size,num_layers,heads,forward_expansion,dropout,device,max_length):
    super(Decoder,self).__init__()
    self.device = device
    self.word_embedding = nn.Embedding(trg_vocab_size,embed_size)
    self.position_embedding = nn.Embedding(max_length,embed_size)
    self.layers = nn.ModuleList(
        [
            DecoderBlock(embed_size,heads,forward_expansion,dropout,device) for _ in range(num_layers)
        ]
    )
    self.forward_out = nn.Linear(embed_size,trg_vocab_size)
    self.dropout = nn.Dropout(dropout)

  def forward(self,x,encoder_out,src_mask,trg_mask):
    """Decode a batch of target token indices against the encoder output.

    Args:
      x: Integer tensor of shape (n, seq_length) holding target token indices.
      encoder_out: Encoder output, used as both value and key in every layer.
      src_mask: Mask forwarded to each layer for attention over
        ``encoder_out``.
      trg_mask: Mask forwarded to each layer for self-attention over ``x``.

    Returns:
      Tensor of shape (n, seq_length, trg_vocab_size) with unnormalised
      scores.
    """
    n,seq_length = x.shape
    # Build the position indices directly on the input's device. Relying on
    # the device recorded at construction time breaks when that value is
    # stale or unavailable (e.g. "cuda" on a CPU-only machine).
    positions = torch.arange(0,seq_length,device = x.device).expand(n,seq_length)
    x = self.dropout(self.word_embedding(x)+self.position_embedding(positions))
    for layer in self.layers:
      x = layer(x,encoder_out,encoder_out,src_mask,trg_mask)
    out = self.forward_out(x)
    return out


## Transformer

In [134]:
class Transformer(nn.Module):
  """Encoder-decoder model built from ``Encoder`` and ``Decoder``.

  Args:
    src_vocab_size: Source vocabulary size.
    trg_vocab_size: Target vocabulary size.
    src_pad_index: Source token index treated as padding by
      ``make_src_mask``.
    trg_pad_index: Target padding index. It is stored on the instance but
      not used by the mask helpers in this class.
    embed_size: Size of the embedding dimension.
    num_layers: Number of layers in both encoder and decoder.
    forward_expansion: Feed-forward width multiplier.
    heads: Number of attention heads.
    dropout: Dropout probability.
    device: Stored on the instance and passed to the encoder and decoder.
      The mask helpers place masks on the device of their input tensor.
    max_length: Maximum sequence length supported by the position
      embeddings.
  """

  def __init__(self,src_vocab_size,trg_vocab_size,src_pad_index,trg_pad_index,embed_size = 256,num_layers = 6,forward_expansion = 4,heads = 8,dropout = 0,device = "cuda",max_length = 100):
    super(Transformer,self).__init__()
    self.encoder = Encoder(
        src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length
    )
    self.decoder = Decoder(
        trg_vocab_size,embed_size,num_layers,heads,forward_expansion,dropout,device,max_length
    )

    self.src_pad_index = src_pad_index
    self.trg_pad_index = trg_pad_index
    self.device = device

  def make_src_mask(self,src):
    """Return a boolean mask of shape (n, 1, 1, src_len); False marks padding."""
    # The comparison result already lives on src's device, so no transfer to
    # a separately stored (and possibly stale or unavailable) device is needed.
    src_mask = (src!=self.src_pad_index).unsqueeze(1).unsqueeze(2)
    return src_mask

  def make_trg_mask(self,trg):
    """Return a lower-triangular mask of shape (n, 1, trg_len, trg_len).

    Entry [.., i, j] is 1 when j <= i, so position i cannot attend to later
    positions.
    """
    n,trg_length = trg.shape
    # Allocate on the input's device rather than on CPU followed by a move.
    trg_mask = torch.tril(torch.ones((trg_length,trg_length),device = trg.device)).expand(
        n,1,trg_length,trg_length
    )
    return trg_mask

  def forward(self,src,trg):
    """Encode ``src`` and decode ``trg`` against it.

    Args:
      src: Integer tensor of shape (n, src_len) with source token indices.
      trg: Integer tensor of shape (n, trg_len) with target token indices.

    Returns:
      The decoder output for ``trg``.
    """
    src_mask = self.make_src_mask(src)
    trg_mask = self.make_trg_mask(trg)
    enc_src = self.encoder(src,src_mask)
    out = self.decoder(trg,enc_src,src_mask,trg_mask)
    return out

## Sample

In [135]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Toy setup: 10-token vocabularies on both sides, index 0 used as padding.
src_pad_idx = 0
trg_pad_idx = 0
src_vocab_size = 10
trg_vocab_size = 10

# Two source sequences of length 9 and two target sequences of length 8.
x = torch.tensor(
    [[1, 5, 6, 4, 3, 9, 5, 2, 0], [1, 8, 7, 3, 4, 5, 6, 7, 2]],
    device=device,
)
trg = torch.tensor(
    [[1, 7, 4, 3, 5, 9, 2, 0], [1, 5, 6, 2, 4, 7, 6, 2]],
    device=device,
)

model = Transformer(
    src_vocab_size,
    trg_vocab_size,
    src_pad_idx,
    trg_pad_idx,
    device=device,
).to(device)

# Feed the target without its last token.
out = model(x, trg[:, :-1])
print(out.shape)

cuda
torch.Size([2, 7, 10])
