<a href="https://colab.research.google.com/github/Kushagra481/Attention_in_Transformers/blob/main/Encoder_and_Multihead.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Multi-Head Attention

In [1]:
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The inputs are linearly projected, split into ``heads`` slices of width
    ``embed_size // heads``, attended independently, concatenated back to
    ``embed_size`` and passed through a final linear layer.

    Args:
        embed_size: Width of the input and output feature dimension.
        heads: Number of attention heads; must be positive and divide
            ``embed_size`` evenly.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide
            ``embed_size``.
    """

    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        # Fail early: a non-divisible embed_size would otherwise only surface
        # later as an opaque shape error inside ``forward``'s ``view`` calls.
        if heads < 1 or embed_size % heads != 0:
            raise ValueError(
                f"embed_size ({embed_size}) must be divisible by a positive "
                f"number of heads (got {heads})"
            )
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query, mask=None):
        """Attend from ``query`` positions over ``keys``/``values``.

        Args:
            values: Tensor of shape ``(N, value_len, embed_size)``.
            keys: Tensor of shape ``(N, key_len, embed_size)``.
            query: Tensor of shape ``(N, query_len, embed_size)``.
            mask: Optional tensor broadcastable to
                ``(N, heads, query_len, key_len)``. Positions where the mask
                equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        values = self.value(values)
        keys = self.key(keys)
        query = self.query(query)

        # Split the feature axis into heads: (N, len, E) -> (N, heads, len, head_dim).
        values = values.view(N, value_len, self.heads, self.head_dim).transpose(1, 2)
        keys = keys.view(N, key_len, self.heads, self.head_dim).transpose(1, 2)
        query = query.view(N, query_len, self.heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product scores: (N, heads, query_len, key_len).
        energy = torch.matmul(query, keys.transpose(-1, -2)) / math.sqrt(self.head_dim)

        if mask is not None:
            # Use the most negative value representable in the score dtype.
            # A hard-coded -1e20 is out of range for float16 and makes
            # masked_fill fail under half precision; for float32 the softmax
            # result is unchanged.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)
        out = torch.matmul(attention, values)
        # Merge heads back: (N, heads, query_len, head_dim) -> (N, query_len, E).
        out = out.transpose(1, 2).contiguous().view(N, query_len, self.embed_size)
        out = self.fc_out(out)
        return out


## Encoder Decoder ✈

In [2]:
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """One encoder layer: self-attention, then a two-layer feed-forward net.

    Each sub-layer's output is added to that sub-layer's input and the sum is
    passed through its own ``LayerNorm``.

    Args:
        embed_size: Feature width of the block's input and output.
        heads: Number of attention heads handed to ``MultiHeadAttention``.
        ff_hidden: Width of the hidden layer in the feed-forward network.
    """

    def __init__(self, embed_size, heads, ff_hidden):
        super().__init__()
        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        # Feed-forward: widen to ff_hidden, apply ReLU, narrow back.
        widen = nn.Linear(embed_size, ff_hidden)
        activation = nn.ReLU()
        narrow = nn.Linear(ff_hidden, embed_size)
        self.feed_forward = nn.Sequential(widen, activation, narrow)

    def forward(self, x, mask):
        """Run ``x`` through the block, masking attention with ``mask``.

        ``mask`` is forwarded unchanged to the attention module.
        """
        # Self-attention: the same tensor serves as values, keys and query.
        attended = self.attention(x, x, x, mask)
        after_attention = self.norm1(attended + x)

        transformed = self.feed_forward(after_attention)
        return self.norm2(transformed + after_attention)

class Encoder(nn.Module):
    """Token embedding followed by a stack of ``TransformerBlock`` layers.

    Only the embedding lookup is applied before the stack; this module adds
    no positional information of its own.

    Args:
        input_dim: Number of rows in the embedding table.
        embed_size: Embedding width, shared by every block.
        heads: Attention heads per block.
        depth: Number of stacked blocks.
        ff_hidden: Hidden width of each block's feed-forward network.
    """

    def __init__(self, input_dim, embed_size, heads, depth, ff_hidden):
        super().__init__()
        self.embed = nn.Embedding(input_dim, embed_size)
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(TransformerBlock(embed_size, heads, ff_hidden))

    def forward(self, x, mask):
        """Embed the indices in ``x`` and apply every block with ``mask``."""
        hidden = self.embed(x)
        for block in self.layers:
            hidden = block(hidden, mask)
        return hidden

class DecoderBlock(nn.Module):
    """One decoder layer with three sub-layers.

    In order: self-attention over the decoder input, attention over the
    encoder output, and a two-layer feed-forward network. Each sub-layer's
    output is added to its input and normalised by its own ``LayerNorm``.

    Args:
        embed_size: Feature width of the block's input and output.
        heads: Number of attention heads in both attention modules.
        ff_hidden: Width of the hidden layer in the feed-forward network.
    """

    def __init__(self, embed_size, heads, ff_hidden):
        super().__init__()
        self.attn1 = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.attn2 = MultiHeadAttention(embed_size, heads)
        self.norm2 = nn.LayerNorm(embed_size)

        # Feed-forward: widen to ff_hidden, apply ReLU, narrow back.
        widen = nn.Linear(embed_size, ff_hidden)
        activation = nn.ReLU()
        narrow = nn.Linear(ff_hidden, embed_size)
        self.ff = nn.Sequential(widen, activation, narrow)
        self.norm3 = nn.LayerNorm(embed_size)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        """Apply the three sub-layers to ``x``.

        ``tgt_mask`` is passed to the self-attention call and ``src_mask`` to
        the attention call over ``enc_out``.
        """
        # Self-attention: x serves as values, keys and query.
        self_attended = self.attn1(x, x, x, tgt_mask)
        state = self.norm1(self_attended + x)

        # Values and keys come from the encoder; the query is the decoder state.
        cross_attended = self.attn2(enc_out, enc_out, state, src_mask)
        state = self.norm2(cross_attended + state)

        transformed = self.ff(state)
        return self.norm3(transformed + state)

class Decoder(nn.Module):
    """Token embedding, a stack of ``DecoderBlock`` layers, and a linear head.

    The final linear layer maps each position from ``embed_size`` to
    ``output_dim``; no softmax is applied to its output.

    Args:
        output_dim: Number of rows in the embedding table and width of the
            final projection.
        embed_size: Embedding width, shared by every block.
        heads: Attention heads per block.
        depth: Number of stacked blocks.
        ff_hidden: Hidden width of each block's feed-forward network.
    """

    def __init__(self, output_dim, embed_size, heads, depth, ff_hidden):
        super().__init__()
        self.embed = nn.Embedding(output_dim, embed_size)
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(DecoderBlock(embed_size, heads, ff_hidden))
        self.fc_out = nn.Linear(embed_size, output_dim)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        """Embed ``x``, run every block against ``enc_out``, then project."""
        hidden = self.embed(x)
        for block in self.layers:
            hidden = block(hidden, enc_out, src_mask, tgt_mask)
        projected = self.fc_out(hidden)
        return projected

class Transformer(nn.Module):
    """Encoder-decoder model wiring an ``Encoder`` to a ``Decoder``.

    Both halves share ``embed_size``, ``heads``, ``depth`` and ``ff_hidden``.

    Args:
        src_vocab_size: Passed to the encoder as its ``input_dim``.
        tgt_vocab_size: Passed to the decoder as its ``output_dim``.
        embed_size: Embedding / model width (default 256).
        heads: Attention heads per layer (default 8).
        depth: Number of layers in each of encoder and decoder (default 6).
        ff_hidden: Feed-forward hidden width (default 512).
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, embed_size=256, heads=8, depth=6, ff_hidden=512):
        super().__init__()
        shared = (embed_size, heads, depth, ff_hidden)
        self.encoder = Encoder(src_vocab_size, *shared)
        self.decoder = Decoder(tgt_vocab_size, *shared)

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode ``src``, then decode ``tgt`` against the encoder output.

        ``src_mask`` is used by the encoder and again by the decoder;
        ``tgt_mask`` is used by the decoder only.
        """
        memory = self.encoder(src, src_mask)
        return self.decoder(tgt, memory, src_mask, tgt_mask)
