<a href="https://colab.research.google.com/github/DevanshPatel234/FMML_Project_and_Labs/blob/main/Implementation%20of%20Transformer%20using%20Pytorch3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementation of Transformer using PyTorch

Transformers are a type of deep learning model introduced in the paper "Attention is All You Need" by Vaswani et al. in 2017. They have revolutionized natural language processing (NLP) and have been adapted to other domains like computer vision. Here's an overview of the key concepts and components of transformers:

## Key Concepts

### Attention Mechanism

- Self-Attention: Allows the model to attend to every other position in the input sequence when encoding a particular token, so it can capture dependencies regardless of their distance in the sequence.
- Scaled Dot-Product Attention: The core operation in transformers. The attention scores are the dot products of the queries (Q) with the keys (K), divided by the square root of the key dimension $d_k$. A softmax turns the scores into weights, which are then applied to the values (V): $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(QK^\top / \sqrt{d_k}\right) V$.
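
As a minimal sketch of the computation described above (the tensor sizes are arbitrary choices for illustration), scaled dot-product attention fits in a few lines of PyTorch. Each row of the resulting weight matrix is a probability distribution over the keys:

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    # Q: (..., n_queries, d_k), K: (..., n_keys, d_k), V: (..., n_keys, d_v)
    d_k = Q.size(-1)
    # Dot product of every query with every key, divided by sqrt(d_k)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    # Softmax over the key dimension turns the scores into weights
    weights = torch.softmax(scores, dim=-1)
    # Each output row is a weighted average of the value vectors
    return weights @ V, weights

torch.manual_seed(0)
Q = torch.randn(3, 4)  # 3 queries, d_k = 4
K = torch.randn(5, 4)  # 5 keys
V = torch.randn(5, 2)  # 5 values, d_v = 2
out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape)            # torch.Size([3, 2])
print(weights.sum(dim=-1))  # each row sums to 1
```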

### Positional Encoding

Since transformers do not have a built-in notion of sequence order (unlike RNNs), positional encodings are added to the input embeddings to retain information about the position of tokens in the sequence.
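
The original paper uses fixed sinusoidal encodings: for position pos and dimension index i, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). A direct, unoptimized sketch of that formula follows; the helper name `sinusoidal_encoding` and the sizes (4 positions, d_model = 8) are illustrative choices:

```python
import math
import torch

def sinusoidal_encoding(num_positions, d_model):
    # Assumes an even d_model so that sine/cosine pairs fill every column
    pe = torch.zeros(num_positions, d_model)
    for pos in range(num_positions):
        # i is the even column index, i.e. "2i" in the formula
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))
            pe[pos, i] = math.sin(angle)      # even dimensions: sine
            pe[pos, i + 1] = math.cos(angle)  # odd dimensions: cosine
    return pe

pe = sinusoidal_encoding(num_positions=4, d_model=8)
print(pe[0])  # position 0: sin(0) = 0 and cos(0) = 1 alternate
```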

## Transformer Architecture

### Encoder-Decoder Structure

The transformer model is composed of an encoder and a decoder, both built using stacked layers of self-attention and feed-forward neural networks.

### Encoder

Consists of multiple identical layers, each with two sub-layers:
- Multi-Head Self-Attention: Lets the model attend to different positions of the input sequence simultaneously, with several attention heads working in parallel.
- Feed-Forward Neural Network: Applies the same fully connected feed-forward network to each position independently.

Each sub-layer is wrapped in a residual connection followed by layer normalization to facilitate training.
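
Concretely, each sub-layer computes LayerNorm(x + Dropout(Sublayer(x))), the post-norm arrangement that the EncoderLayer class in this notebook also uses. A minimal generic sketch; the class name `ResidualSublayer` and the `nn.Linear` stand-in sub-layer are hypothetical, for illustration only:

```python
import torch
import torch.nn as nn

class ResidualSublayer(nn.Module):
    """Hypothetical helper: post-norm residual wrapper around any sub-layer."""
    def __init__(self, sublayer, d_model, dropout=0.1):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Residual connection first, then layer normalization
        return self.norm(x + self.dropout(self.sublayer(x)))

torch.manual_seed(0)
block = ResidualSublayer(nn.Linear(16, 16), d_model=16)
block.eval()  # disable dropout for a deterministic check
x = torch.randn(2, 5, 16)  # (batch, seq_len, d_model)
y = block(x)
print(y.shape)  # torch.Size([2, 5, 16])
```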

### Decoder

Similar to the encoder, but with an additional sub-layer that performs multi-head attention over the encoder's output.

Each decoder layer includes:
- Masked Multi-Head Self-Attention: Ensures that the prediction for a position depends only on the known outputs at earlier positions.
- Multi-Head Cross-Attention: Attends over the encoder's output to incorporate information from the input sequence.
- Feed-Forward Neural Network: As in the encoder.
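
The masking in the first decoder sub-layer can be illustrated with a lower-triangular (causal) mask: filling the scores of future positions with a large negative number before the softmax gives them zero weight. This small sketch uses identical (all-zero) raw scores so the resulting weights are easy to read; the -1e9 fill value mirrors the one in this notebook's MultiHeadAttention code:

```python
import torch

seq_len = 4
# True where attention is allowed: position i may see positions 0..i
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

scores = torch.zeros(seq_len, seq_len)           # identical raw scores everywhere
scores = scores.masked_fill(~causal_mask, -1e9)  # block future positions
weights = torch.softmax(scores, dim=-1)
print(weights)
# Row i spreads its weight evenly over positions 0..i and gives 0 to later ones
```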

## Applications

### Natural Language Processing

- Machine Translation: Translating text from one language to another.
- Text Summarization: Generating concise summaries of long documents.
- Question Answering: Finding answers to questions in a given context.

### Computer Vision

- Vision Transformers (ViTs): Applying transformer models to image recognition tasks by dividing images into patches and treating the patches as a sequence of tokens.
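
To make the patch idea concrete, here is a sketch that cuts an image tensor into non-overlapping square patches and flattens each patch into one token vector. The helper name `image_to_patch_sequence`, the 32x32 image, and the 8x8 patch size are illustrative assumptions; a ViT would additionally apply a learned linear projection to each patch and add position embeddings:

```python
import torch

def image_to_patch_sequence(images, patch_size):
    # images: (batch, channels, height, width); height and width must be divisible by patch_size
    b, c, h, w = images.shape
    p = patch_size
    x = images.reshape(b, c, h // p, p, w // p, p)
    # Bring the two patch-grid axes forward, then flatten each patch
    x = x.permute(0, 2, 4, 1, 3, 5)  # (b, h/p, w/p, c, p, p)
    return x.reshape(b, (h // p) * (w // p), c * p * p)

images = torch.randn(2, 3, 32, 32)
tokens = image_to_patch_sequence(images, patch_size=8)
print(tokens.shape)  # torch.Size([2, 16, 192]): 16 patches, each with 3*8*8 values
```

Patches are ordered row by row, so token 1 is the second patch in the top row and token 4 is the first patch in the second row.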

In [1]:
# Library
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

## MultiHeadAttention

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output

    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))

        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output
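
`split_heads` and `combine_heads` only rearrange the tensor; no values change. The standalone sketch below repeats the same `view`/`transpose` steps outside the class (the sizes are illustrative) and checks that each head receives a contiguous slice of `d_k` features and that combining undoes splitting:

```python
import torch

batch_size, seq_length, d_model, num_heads = 2, 5, 16, 4
d_k = d_model // num_heads

def split_heads(x):
    # (batch, seq, d_model) -> (batch, heads, seq, d_k)
    return x.view(batch_size, seq_length, num_heads, d_k).transpose(1, 2)

def combine_heads(x):
    # (batch, heads, seq, d_k) -> (batch, seq, d_model)
    return x.transpose(1, 2).contiguous().view(batch_size, seq_length, d_model)

x = torch.randn(batch_size, seq_length, d_model)
heads = split_heads(x)
print(heads.shape)  # torch.Size([2, 4, 5, 4])
# Head 1 sees feature columns d_k .. 2*d_k - 1 of every position
assert torch.equal(heads[:, 1], x[:, :, d_k:2 * d_k])
assert torch.equal(combine_heads(heads), x)
```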

## PositionWiseFeedForward

In [3]:
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))
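
"Position-wise" means the same two-layer network is applied to every position independently, so running it on a whole sequence gives the same result as running it on each position separately. A quick check, using an `nn.Sequential` stand-in with the same Linear-ReLU-Linear structure and arbitrary small sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 8, 32
# Same structure as PositionWiseFeedForward: Linear -> ReLU -> Linear
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, 5, d_model)  # (batch, seq_len, d_model)
whole = ffn(x)
# Apply the network one position at a time and stack the results
per_position = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)
print(torch.allclose(whole, per_position, atol=1e-5))  # True
```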

## PositionalEncoding

In [4]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
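
The `div_term` line computes 1 / 10000^(2i/d_model) in log space, as exp(2i * (-ln(10000) / d_model)); the two forms are mathematically equivalent. A quick numerical check that they agree (d_model = 16 is an arbitrary choice):

```python
import math
import torch

d_model = 16
two_i = torch.arange(0, d_model, 2).float()  # even dimension indices 0, 2, 4, ...
# Log-space form used in PositionalEncoding
div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
# Direct form of the formula: 1 / 10000^(2i / d_model)
direct = 1.0 / (10000.0 ** (two_i / d_model))
print(torch.allclose(div_term, direct))  # True
```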

## EncoderLayer

In [5]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

## DecoderLayer

In [6]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x
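
In cross-attention the queries come from the decoder while the keys and values come from the encoder output, so the source and target lengths may differ and the output always follows the target length. A shape-only sketch that uses PyTorch's built-in `nn.MultiheadAttention` (with `batch_first=True`) as a stand-in for this notebook's custom MultiHeadAttention class; the sizes are arbitrary:

```python
import torch
import torch.nn as nn

d_model, num_heads = 16, 4
cross_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)

dec_x = torch.randn(2, 7, d_model)        # decoder states: target length 7
enc_output = torch.randn(2, 10, d_model)  # encoder output: source length 10

# query = decoder states, key = value = encoder output
out, attn_weights = cross_attn(dec_x, enc_output, enc_output)
print(out.shape)           # torch.Size([2, 7, 16])
print(attn_weights.shape)  # torch.Size([2, 7, 10]): one weight per (target, source) pair
```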

## Transformer Model

In [7]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = nn.Embedding(src_vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

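    def generate_mask(self, src, tgt):
        # Source padding mask, shape (batch, 1, 1, src_len); token id 0 is treated as padding
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2)
        # Target padding mask, shape (batch, 1, tgt_len, 1)
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        # Causal ("no-peek") mask: position i may only attend to positions <= i.
        # Created on tgt's device so the model also runs on a GPU.
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length, device=tgt.device), diagonal=1)).bool()
        # Broadcasts to (batch, 1, tgt_len, tgt_len)
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask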

    def forward(self, src, tgt):
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output
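
To see what `generate_mask` produces, the sketch below repeats its logic on a toy batch with right-padding (token id 0, as in this notebook) and prints the resulting boolean masks. The token ids are made up:

```python
import torch

src = torch.tensor([[5, 7, 9, 0]])  # one source sentence, last position padded
tgt = torch.tensor([[3, 4, 0]])     # one target sentence, last position padded

src_mask = (src != 0).unsqueeze(1).unsqueeze(2)  # (batch, 1, 1, src_len)
tgt_pad = (tgt != 0).unsqueeze(1).unsqueeze(3)   # (batch, 1, tgt_len, 1)
seq_length = tgt.size(1)
nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool()
tgt_mask = tgt_pad & nopeak_mask                 # (batch, 1, tgt_len, tgt_len)

print(src_mask.shape, tgt_mask.shape)  # torch.Size([1, 1, 1, 4]) torch.Size([1, 1, 3, 3])
# Rows are query positions: causal lower triangle, with the padded last row fully masked
print(tgt_mask[0, 0].int())
```

Note that the target mask blocks padded query rows rather than padded key columns. With right-padding this still works, because the causal part already prevents real tokens from attending to the later padding positions.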

## Preparing Sample Data

In [8]:
src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

# Generate random sample data (token ids start at 1 because 0 is used as the padding id)
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

## Training the Model

In [9]:
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
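
With `ignore_index=0`, target positions holding the padding id 0 contribute nothing to the loss, and the default `'mean'` reduction averages only over the remaining positions. A small check with made-up logits:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size = 6
logits = torch.randn(4, vocab_size)   # 4 target positions
targets = torch.tensor([2, 5, 0, 0])  # last two positions are padding

criterion = nn.CrossEntropyLoss(ignore_index=0)
loss_with_padding = criterion(logits, targets)
# Same value as the plain loss over the two non-padding positions only
loss_real_only = nn.CrossEntropyLoss()(logits[:2], targets[:2])
print(torch.allclose(loss_with_padding, loss_real_only))  # True
```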

In [10]:
transformer.train()

Transformer(
  (encoder_embedding): Embedding(5000, 512)
  (decoder_embedding): Embedding(5000, 512)
  (positional_encoding): PositionalEncoding()
  (encoder_layers): ModuleList(
    (0-5): 6 x EncoderLayer(
      (self_attn): MultiHeadAttention(
        (W_q): Linear(in_features=512, out_features=512, bias=True)
        (W_k): Linear(in_features=512, out_features=512, bias=True)
        (W_v): Linear(in_features=512, out_features=512, bias=True)
        (W_o): Linear(in_features=512, out_features=512, bias=True)
      )
      (feed_forward): PositionWiseFeedForward(
        (fc1): Linear(in_features=512, out_features=2048, bias=True)
        (fc2): Linear(in_features=2048, out_features=512, bias=True)
        (relu): ReLU()
      )
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (decoder_layers): ModuleList(
    (0-5): 6 x DecoderLayer(
 

In [11]:
for epoch in range(100):
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1])
    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[:, 1:].contiguous().view(-1))
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

Epoch: 1, Loss: 8.677218437194824
Epoch: 2, Loss: 8.535113334655762
Epoch: 3, Loss: 8.466719627380371
Epoch: 4, Loss: 8.414207458496094
Epoch: 5, Loss: 8.3574800491333
Epoch: 6, Loss: 8.290533065795898
Epoch: 7, Loss: 8.211080551147461
Epoch: 8, Loss: 8.123992919921875
Epoch: 9, Loss: 8.046772956848145
Epoch: 10, Loss: 7.965337753295898
Epoch: 11, Loss: 7.879910945892334
Epoch: 12, Loss: 7.810268402099609
Epoch: 13, Loss: 7.719799995422363
Epoch: 14, Loss: 7.633733749389648
Epoch: 15, Loss: 7.556983470916748
Epoch: 16, Loss: 7.469480991363525
Epoch: 17, Loss: 7.3881659507751465
Epoch: 18, Loss: 7.309781074523926
Epoch: 19, Loss: 7.2236409187316895
Epoch: 20, Loss: 7.147046089172363
Epoch: 21, Loss: 7.063130855560303
Epoch: 22, Loss: 6.987451553344727
Epoch: 23, Loss: 6.903902530670166
Epoch: 24, Loss: 6.831885814666748
Epoch: 25, Loss: 6.750979423522949
Epoch: 26, Loss: 6.677883625030518
Epoch: 27, Loss: 6.604788303375244
Epoch: 28, Loss: 6.538572311401367
Epoch: 29, Loss: 6.4709300994
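
The loop above uses teacher forcing: the decoder receives `tgt_data[:, :-1]` and is trained to predict `tgt_data[:, 1:]`, i.e. the same sequence shifted left by one position. A toy illustration with arbitrary token ids:

```python
import torch

tgt = torch.tensor([[1, 11, 12, 13, 2]])  # toy target sequence (ids are arbitrary)
decoder_input = tgt[:, :-1]  # what the decoder sees
labels = tgt[:, 1:]          # what it must predict at each position
for t in range(decoder_input.size(1)):
    # Under the causal mask, position t sees decoder_input[:, :t+1] and predicts labels[:, t]
    print(decoder_input[0, : t + 1].tolist(), "->", labels[0, t].item())
# [1] -> 11
# [1, 11] -> 12
# [1, 11, 12] -> 13
# [1, 11, 12, 13] -> 2
```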

In [15]:
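def evaluate(model, src, tgt):
    model.eval()
    with torch.no_grad():
        # Teacher forcing: feed tgt[:, :-1] and compare predictions with tgt[:, 1:]
        output = model(src, tgt[:, :-1])
        _, predicted = torch.max(output, dim=-1)
        target = tgt[:, 1:]
        non_pad = target != 0  # token id 0 is padding
        # Count only non-padding positions in both the numerator and the denominator
        correct = ((predicted == target) & non_pad).sum().item()
        total = non_pad.sum().item()
        accuracy = correct / total
    return accuracy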

In [16]:
accuracy = evaluate(transformer, src_data, tgt_data)
print(f"Accuracy: {accuracy}")

Accuracy: 0.9919507575757576
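
Two caveats apply to this number. First, `evaluate` is called on `src_data` and `tgt_data`, the same random batch the model was trained on, so the accuracy measures how well the model memorized that batch rather than how it generalizes; a fair check would use data not seen during training. Second, when targets contain padding (id 0), matches at padded positions must be excluded from the count of correct predictions as well as from the total. A toy illustration of the second point with made-up predictions:

```python
import torch

target = torch.tensor([[4, 7, 9, 0, 0]])     # last two positions are padding
predicted = torch.tensor([[4, 3, 9, 0, 0]])  # the model also "predicts" 0 at the padding

non_pad = target != 0
naive = (predicted == target).sum().item() / non_pad.sum().item()
masked = ((predicted == target) & non_pad).sum().item() / non_pad.sum().item()
print(naive)   # 1.3333333333333333: padding matches inflate the score above 1
print(masked)  # 0.6666666666666666: 2 of the 3 real tokens are correct
```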
