<a href="https://colab.research.google.com/github/Akor-Michael/Implementation-of-Transformer-Architecture/blob/main/Attention_is_all_you_need_implementation_by_Akor_Michael.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Implementation of the Transformer model from the landmark paper "Attention Is All You Need", by Akor Michael. The last two cells show how to train the Transformer on your own data.**


In [None]:
# Importing the necessary Libraries
import torch
import torch.nn as nn
import math
from torch.utils.data import DataLoader, Dataset


In [None]:
# Pick the compute device once for the whole notebook:
# CUDA when PyTorch can see a GPU, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Define the PositionalEncoding class
class PositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding.

    A table of shape ``(max_len, 1, embed_size)`` is precomputed and stored in
    the ``pe`` buffer: even feature columns hold ``sin`` values, odd columns
    hold ``cos`` values.

    ``forward`` supports two calling conventions:

    * **Integer tensor** of position indices (any shape, e.g.
      ``(batch, seq_len)``): returns the encoding of every index, with a
      trailing ``embed_size`` dimension appended. No dropout is applied on
      this path; the caller is expected to add the result to its embeddings.
    * **Floating-point tensor** whose first dimension is the sequence
      dimension (``(seq_len, batch, embed_size)``): the first ``seq_len`` rows
      of the table are added to it and the module's dropout (p=0.1) is applied.

    Args:
        max_len: Number of positions in the precomputed table.
        embed_size: Size of each encoding vector (odd sizes are supported).
    """

    def __init__(self, max_len, embed_size):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(0.1)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_size, 2).float() * (-math.log(10000.0) / embed_size)
        )

        pe = torch.zeros(max_len, embed_size)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embed_size there is one cosine column fewer than sine
        # columns, so trim div_term to the number of odd-indexed columns.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_size // 2])

        # Buffer keeps the (max_len, 1, embed_size) layout so it broadcasts
        # over the batch dimension of a sequence-first input.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return encodings for integer indices, or add encodings to float input.

        Raises:
            IndexError: On the integer path, if an index falls outside the
                precomputed table.
        """
        if not torch.is_floating_point(x):
            # Position indices: plain table lookup -> x.shape + (embed_size,).
            return self.pe.squeeze(1)[x.long()]

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)



In [None]:
# Implementing Self Attention
class SelfAttention(nn.Module):
    """Scaled dot-product attention over ``heads`` equally sized heads.

    The inputs are linearly projected, split into ``heads`` chunks of size
    ``head_dim = embed_size // heads`` along the feature dimension, attended
    per head, re-concatenated and passed through a final linear layer.

    Args:
        embed_size: Feature size of the inputs and of the output.
        heads: Number of heads; must divide ``embed_size`` exactly.
    """

    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (
            self.head_dim * heads == embed_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend ``query`` over ``keys``/``values``.

        Args:
            values: Tensor of shape ``(N, value_len, embed_size)``.
            keys: Tensor of shape ``(N, key_len, embed_size)``.
            query: Tensor of shape ``(N, query_len, embed_size)``.
            mask: ``None`` or a tensor broadcastable to
                ``(N, heads, query_len, key_len)``; positions where it equals
                0 are excluded from attention.

        Returns:
            Tensor of shape ``(N, query_len, embed_size)``.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project, then split the feature dimension into (heads, head_dim).
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)

        # Raw scores: (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        if mask is not None:
            # A very large negative score becomes ~0 after the softmax.
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        # Each dot product runs over head_dim features, so that is the
        # dimension to normalise by (not the full embed_size).
        attention = torch.nn.functional.softmax(
            energy / math.sqrt(self.head_dim), dim=3
        )

        # Weighted sum of values, heads merged back into one feature axis.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        return self.fc_out(out)


In [None]:
#  Implementing Multi Head Attention
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from ``heads`` independent single-head units.

    Every head owns three input projections (``embed_size -> head_dim``) and
    one ``SelfAttention(head_dim, 1)`` module. The head outputs are
    concatenated along the feature dimension and mixed by ``fc_out``.

    The per-head ``SelfAttention`` modules are created once here and held in
    an ``nn.ModuleList``. Building them inside ``forward`` would give every
    call freshly initialised weights that are never registered as parameters,
    so they would not be trained, saved, or moved by ``.to(device)``.

    Args:
        embed_size: Feature size of the inputs and of the output.
        heads: Number of attention heads.
    """

    def __init__(self, embed_size, heads):
        super(MultiHeadAttention, self).__init__()
        self.heads = heads
        self.head_dim = embed_size // heads

        self.values = nn.ModuleList([nn.Linear(embed_size, self.head_dim) for _ in range(heads)])
        self.keys = nn.ModuleList([nn.Linear(embed_size, self.head_dim) for _ in range(heads)])
        self.queries = nn.ModuleList([nn.Linear(embed_size, self.head_dim) for _ in range(heads)])

        # One persistent, registered attention module per head.
        self.attention_heads = nn.ModuleList(
            [SelfAttention(self.head_dim, 1) for _ in range(heads)]
        )

        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Run every head and merge the results.

        Args:
            values, keys, query: Tensors whose last dimension is ``embed_size``.
            mask: Forwarded unchanged to each head (may be ``None``).

        Returns:
            Tensor with the same leading dimensions as ``query`` and a last
            dimension of ``embed_size``.
        """
        head_outputs = [
            attend(value_proj(values), key_proj(keys), query_proj(query), mask)
            for value_proj, key_proj, query_proj, attend in zip(
                self.values, self.keys, self.queries, self.attention_heads
            )
        ]

        # Concatenate the per-head outputs along the feature dimension.
        merged = torch.cat(head_outputs, dim=2)

        return self.fc_out(merged)


In [None]:
# Feed-forward network with ReLU for non-linearity
class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    Args:
        embed_size: Input and output feature size.
        feedforward_dim: Width of the hidden layer.
        dropout: Dropout probability applied after the activation.
    """

    def __init__(self, embed_size, feedforward_dim, dropout):
        super().__init__()
        self.fc1 = nn.Linear(in_features=embed_size, out_features=feedforward_dim)
        self.fc2 = nn.Linear(in_features=feedforward_dim, out_features=embed_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Map ``x`` (last dim ``embed_size``) to a tensor of the same shape."""
        hidden = self.dropout(nn.functional.relu(self.fc1(x)))
        return self.fc2(hidden)


In [None]:
# Implementation of the actual Transformer block with multi-head self-attention, layer normalization, a feed-forward network, and dropout layers
class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer's output is added to its input (residual connection),
    layer-normalised, and then passed through dropout.

    Args:
        embed_size: Feature size of the inputs and of the output.
        heads: Number of attention heads.
        dropout: Dropout probability, used after both normalisations and
            inside the feed-forward network.
        forward_expansion: The feed-forward hidden width is
            ``forward_expansion * embed_size``.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        hidden_dim = forward_expansion * embed_size

        self.attention = MultiHeadAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.feed_forward = FeedForward(embed_size, hidden_dim, dropout)
        self.norm2 = nn.LayerNorm(embed_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Return the block output; the residual path starts from ``query``."""
        # Sub-layer 1: attention + residual -> norm -> dropout.
        attended = self.attention(value, key, query, mask)
        x = self.norm1(attended + query)
        x = self.dropout(x)

        # Sub-layer 2: feed-forward + residual -> norm -> dropout.
        transformed = self.feed_forward(x)
        return self.dropout(self.norm2(transformed + x))


In [None]:
# Encoder comprising multiple Transformer blocks stacked together
class Encoder(nn.Module):
    """Token embedding plus positional information, followed by a stack of
    ``TransformerBlock`` layers.

    Args:
        src_vocab_size: Number of rows in the token embedding table.
        embed_size: Embedding / model feature size.
        num_layers: Number of stacked ``TransformerBlock`` layers.
        heads: Number of attention heads per layer.
        device: Stored on ``self.device`` for callers that read it; the
            forward pass itself follows the device of its input.
        forward_expansion: Feed-forward width multiplier of each layer.
        dropout: Dropout probability.
        max_length: Number of positions passed to ``PositionalEncoding``.
    """

    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length,
    ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = PositionalEncoding(max_length, embed_size)

        # Stack of Transformer Block layers
        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion,
                )
                for _ in range(num_layers)
            ]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of token ids with shape ``(N, seq_length)``.
            mask: Attention mask handed unchanged to every layer.

        Returns:
            The output of the last layer.
        """
        N, seq_length = x.shape

        # Build the position indices directly on the input's device. Using the
        # constructor-time ``device`` would break as soon as the model is moved
        # elsewhere with ``.to(...)``, and would cost an extra host->device copy.
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)

        out = self.dropout(
            self.word_embedding(x) + self.position_embedding(positions)
        )

        # Self-attention: every layer uses its input as value, key and query.
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out


In [None]:
# Decoder block comprising masked multi-head self-attention and a subsequent Transformer block
class DecoderBlock(nn.Module):
    """Masked self-attention over the decoder input, followed by a
    ``TransformerBlock`` that attends over the supplied ``value``/``key``.

    Args:
        embed_size: Feature size of the inputs and of the output.
        heads: Number of attention heads.
        forward_expansion: Feed-forward width multiplier of the inner block.
        dropout: Dropout probability.
        device: Accepted for signature compatibility; not used by this class.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout, device):
        super().__init__()
        self.norm = nn.LayerNorm(embed_size)
        self.attention = MultiHeadAttention(embed_size, heads=heads)
        self.transformer_block = TransformerBlock(
            embed_size, heads, dropout, forward_expansion
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, value, key, src_mask, trg_mask):
        """Return the block output for decoder input ``x``.

        ``trg_mask`` is applied to the self-attention over ``x``; ``src_mask``
        is applied inside the inner transformer block.
        """
        # Self-attention on the decoder input, then residual -> norm -> dropout.
        self_attended = self.attention(x, x, x, trg_mask)
        query = self.norm(self_attended + x)
        query = self.dropout(query)

        # The normalised result acts as the query against value/key.
        return self.transformer_block(value, key, query, src_mask)


In [None]:
# Decoder which consists of stacks of multiple Decoder Blocks
class Decoder(nn.Module):
    """Token embedding plus positional information, a stack of
    ``DecoderBlock`` layers, and a final projection to vocabulary size.

    Args:
        trg_vocab_size: Rows of the embedding table and size of the output
            projection.
        embed_size: Embedding / model feature size.
        num_layers: Number of stacked ``DecoderBlock`` layers.
        heads: Number of attention heads per layer.
        forward_expansion: Feed-forward width multiplier of each layer.
        dropout: Dropout probability.
        device: Stored on ``self.device`` and forwarded to each
            ``DecoderBlock``; the forward pass follows the device of its input.
        max_length: Number of positions passed to ``PositionalEncoding``.
    """

    def __init__(
        self,
        trg_vocab_size,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout,
        device,
        max_length,
    ):
        super(Decoder, self).__init__()
        self.device = device
        self.word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.position_embedding = PositionalEncoding(max_length, embed_size)

        # Create a stack of DecoderBlock layers
        self.layers = nn.ModuleList(
            [
                DecoderBlock(embed_size, heads, forward_expansion, dropout, device)
                for _ in range(num_layers)
            ]
        )

        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, trg_mask):
        """Decode a batch of target token ids against the encoder output.

        Args:
            x: Integer tensor of target token ids, shape ``(N, seq_length)``.
            enc_out: Encoder output, used as both value and key in every layer.
            src_mask: Mask forwarded to each layer for attention over ``enc_out``.
            trg_mask: Mask forwarded to each layer for self-attention over ``x``.

        Returns:
            Tensor whose last dimension is ``trg_vocab_size`` (unnormalised
            scores from ``fc_out``).
        """
        N, seq_length = x.shape

        # Build the position indices directly on the input's device. Using the
        # constructor-time ``device`` would break as soon as the model is moved
        # elsewhere with ``.to(...)``, and would cost an extra host->device copy.
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)

        x = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        # Pass the input through the stacked Decoder Block layers
        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, trg_mask)

        # Final linear transformation to vocabulary-sized scores
        return self.fc_out(x)


In [None]:
# Integrating the encoder and decoder to create the complete Transformer model
class Transformer(nn.Module):
    """Encoder-decoder Transformer.

    Args:
        src_vocab_size: Source vocabulary size.
        trg_vocab_size: Target vocabulary size.
        src_pad_idx: Source token id that marks padding; masked out of
            attention over the source.
        trg_pad_idx: Target padding id. Stored on the instance but not used
            when building the target mask, which is purely causal.
        embed_size: Model feature size (default 512).
        num_layers: Layers in both the encoder and the decoder (default 6).
        forward_expansion: Feed-forward width multiplier (default 4).
        heads: Attention heads per layer (default 8).
        dropout: Dropout probability (default 0).
        device: Forwarded to the encoder and decoder and stored on
            ``self.device``. Masks follow the device of the input tensors.
        max_length: Number of positions passed to the positional encodings
            (default 100).
    """

    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=512,
        num_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0,
        device=device,
        max_length=100,
    ):

        super(Transformer, self).__init__()

        # Create Encoder and Decoder
        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length,
        )

        self.decoder = Decoder(
            trg_vocab_size,
            embed_size,
            num_layers,
            heads,
            forward_expansion,
            dropout,
            device,
            max_length,
        )

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):
        """Return a boolean mask of shape ``(N, 1, 1, src_len)``.

        Entries are ``True`` where ``src`` is not the padding id. The mask is
        derived from ``src`` and therefore already lives on its device.
        """
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def make_trg_mask(self, trg):
        """Return a lower-triangular mask of shape ``(N, 1, trg_len, trg_len)``.

        Position ``i`` may attend to positions ``<= i`` only. The mask is
        allocated on ``trg``'s device so it always matches the activations,
        even if the model was moved after construction.
        """
        N, trg_len = trg.shape
        causal = torch.tril(torch.ones((trg_len, trg_len), device=trg.device))
        return causal.expand(N, 1, trg_len, trg_len)

    def forward(self, src, trg):
        """Run the full model.

        Args:
            src: Source token ids, shape ``(N, src_len)``.
            trg: Target token ids, shape ``(N, trg_len)``.

        Returns:
            The decoder output for ``trg`` given the encoded ``src``.
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        return self.decoder(trg, enc_src, src_mask, trg_mask)


In [None]:
# Create an instance of the Transformer model and pass sample input data through it to obtain the output.
# Toy batch: two source sequences of length 9 and two target sequences of length 8
src_input = torch.tensor(
    [
        [1, 5, 6, 4, 3, 9, 5, 2, 0],
        [1, 8, 7, 3, 4, 5, 6, 7, 2],
    ]
).to(device)
trg_input = torch.tensor(
    [
        [1, 7, 4, 3, 5, 9, 2, 0],
        [1, 5, 6, 2, 4, 7, 6, 2],
    ]
).to(device)

# Vocabulary sizes and padding indices
src_vocab_size = 10
trg_vocab_size = 10
src_pad_idx = 0
trg_pad_idx = 0

# Build the model with its default hyper-parameters and move it to the device
transformer_model = Transformer(
    src_vocab_size=src_vocab_size,
    trg_vocab_size=trg_vocab_size,
    src_pad_idx=src_pad_idx,
    trg_pad_idx=trg_pad_idx,
    device=device,
)
transformer_model = transformer_model.to(device)

# Forward pass; the last target token is dropped from the decoder input
output = transformer_model(src_input, trg_input[:, :-1])

# The output shape depends on the target sequence length and the vocabulary size
print("Output shape:", output.shape)


In [None]:
# Training loop. A `train_dataloader` is REQUIRED: it must yield dicts holding
# the tensors "src_input", "trg_input" and "trg_target" (see the example cell
# below). The loss function, optimizer and epoch count are defaults you may change.

# Reuse a `train_dataloader` that was already defined (for instance by running
# the example cell first), or assign your own DataLoader to that name before
# running this cell.
train_dataloader = globals().get("train_dataloader")
if train_dataloader is None:
    raise RuntimeError(
        "`train_dataloader` is not defined. Create a DataLoader that yields dicts "
        "with 'src_input', 'trg_input' and 'trg_target' tensors before running "
        "the training loop."
    )

# Cross-entropy by default; padded target positions are excluded from the loss
loss_fn = nn.CrossEntropyLoss(ignore_index=trg_pad_idx)
optimizer = torch.optim.Adam(transformer_model.parameters(), lr=0.001) # Using ADAM Optimizer by default

# Training loop
num_epochs = 100 # Epoch Count default
for epoch in range(num_epochs):
    transformer_model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in train_dataloader:
        src_input = batch["src_input"].to(device)
        trg_input = batch["trg_input"].to(device)
        trg_target = batch["trg_target"].to(device)

        optimizer.zero_grad()
        output = transformer_model(src_input, trg_input)

        # Flatten to (tokens, vocab) and (tokens,) as CrossEntropyLoss expects
        output_dim = output.shape[-1]
        output = output.reshape(-1, output_dim)
        trg_target = trg_target.reshape(-1)

        loss = loss_fn(output, trg_target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Count batches ourselves: avoids dividing by zero on an empty loader
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")


In [None]:
# Example of loading a dataset. Replace the sample data to train on your own data.

# Sample dataset class
class CustomDataset(Dataset):
    """Thin ``Dataset`` wrapper around an indexable collection of examples."""

    def __init__(self, data):
        # Kept by reference; the examples are not copied or transformed.
        self.data = data

    def __len__(self) -> int:
        """Number of stored examples."""
        examples = self.data
        return len(examples)

    def __getitem__(self, index):
        """Return the example stored at ``index`` unchanged."""
        examples = self.data
        return examples[index]

# Sample input data
# Two toy examples. In each one, "trg_target" is "trg_input" shifted left by
# one token, with one extra token appended at the end.
sample_data = [
    {
        "src_input": [1, 5, 6, 4, 3, 9, 5, 2],
        "trg_input": [1, 7, 4, 3, 5, 9, 2],
        "trg_target": [7, 4, 3, 5, 9, 2, 0],
    },
    {
        "src_input": [1, 8, 7, 3, 4, 5, 6, 7],
        "trg_input": [1, 5, 6, 2, 4, 7, 6],
        "trg_target": [5, 6, 2, 4, 7, 6, 2],
    },
]

# Preprocessing function
def preprocess_data(data, src_pad_idx, trg_pad_idx, max_len=10):
    """Right-pad every example to ``max_len`` tokens and wrap it in a dataset.

    Each field is stored as a 1-D ``torch.long`` tensor rather than a Python
    list. ``DataLoader``'s default collation stacks tensors into one
    ``(batch, max_len)`` tensor per field, whereas lists of ints are collated
    position by position into a list of tensors, which cannot be moved with
    ``.to(device)`` or fed to the model.

    Args:
        data: Iterable of dicts with the keys ``"src_input"``,
            ``"trg_input"`` and ``"trg_target"``, each a sequence of token ids.
        src_pad_idx: Padding id appended to ``"src_input"``.
        trg_pad_idx: Padding id appended to ``"trg_input"`` and
            ``"trg_target"``.
        max_len: Length to pad to (default 10). Sequences already longer than
            this are left at their own length, neither padded nor truncated.

    Returns:
        A ``CustomDataset`` over the padded examples.
    """

    def pad_to_tensor(tokens, pad_idx):
        # A negative repeat count yields an empty list, so over-long
        # sequences pass through unchanged.
        padded = list(tokens) + [pad_idx] * (max_len - len(tokens))
        return torch.tensor(padded, dtype=torch.long)

    preprocessed_data = [
        {
            "src_input": pad_to_tensor(example["src_input"], src_pad_idx),
            "trg_input": pad_to_tensor(example["trg_input"], trg_pad_idx),
            "trg_target": pad_to_tensor(example["trg_target"], trg_pad_idx),
        }
        for example in data
    ]
    return CustomDataset(preprocessed_data)

# Vocabulary sizes and padding indices for the example data
src_vocab_size = trg_vocab_size = 11
src_pad_idx = trg_pad_idx = 0

# Pad the sample data and expose it through a shuffling DataLoader
preprocessed_dataset = preprocess_data(
    sample_data,
    src_pad_idx=src_pad_idx,
    trg_pad_idx=trg_pad_idx,
)
train_dataloader = DataLoader(
    dataset=preprocessed_dataset,
    batch_size=2,
    shuffle=True,
)

# Show only the first batch, then stop iterating
for batch in train_dataloader:
    print(batch)
    break
