
Encoder-decoder architecture with self-attention and multi-head attention

We build the network with PyTorch.


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

This is where multi-head self-attention happens:

We project the input into Queries (Q), Keys (K), and Values (V) with three separate linear layers

We split each projection into multiple heads so that different heads can capture different attention patterns

We use scaled dot-product attention to compute how strongly each token attends to every other token, then take the weighted sum of the values
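
Before the multi-head version, the three steps above can be traced for a single head. This is a minimal sketch with made-up sizes (3 tokens, 4 dimensions) and random tensors standing in for the projected Q, K, and V; none of these names come from the notebook's classes.

```python
import math
import torch

torch.manual_seed(0)

# Hypothetical toy sizes: 3 tokens, 4-dimensional queries, keys, and values.
seq_len, d_k = 3, 4
q = torch.randn(seq_len, d_k)
k = torch.randn(seq_len, d_k)
v = torch.randn(seq_len, d_k)

# Similarity of every query with every key, scaled by sqrt(d_k).
scores = q @ k.T / math.sqrt(d_k)        # shape: (3, 3)

# Softmax over the key axis turns each row into attention weights.
weights = torch.softmax(scores, dim=-1)  # each row sums to 1

# Each output row is a weighted average of the value vectors.
out = weights @ v                        # shape: (3, 4)

print(weights.sum(dim=-1))
print(out.shape)
```

Row i of weights says how much token i attends to each token in the sequence, which is why the softmax runs over the key axis.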

In [13]:
class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads
        assert self.head_dim * heads == embed_size, "Embed size must be divisible by heads"

        # Linear layers for Q, K, V
        self.values = nn.Linear(embed_size, embed_size)
        self.keys = nn.Linear(embed_size, embed_size)
        self.queries = nn.Linear(embed_size, embed_size)

        # Final linear layer after attention
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, values, keys, query):
        N = query.shape[0]  # batch size
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project Q, K, V and split the embedding dimension into heads
        values = self.values(values).view(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).view(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).view(N, query_len, self.heads, self.head_dim)

        # Scaled dot-product attention: scale by sqrt(head_dim), the per-head key size
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        attention = torch.softmax(energy / (self.head_dim ** 0.5), dim=3)

        # Weighted sum of values
        out = torch.einsum("nhqk,nkhd->nqhd", [attention, values])
        out = out.reshape(N, query_len, self.embed_size)  # Concatenate heads: (N, query_len, embed_size)

        return self.fc_out(out)
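
The einsum strings in the class above can look opaque. The sketch below, using made-up sizes and random tensors, checks that they compute the same thing as an explicit permute followed by a batched matrix multiplication. The scaling uses the square root of head_dim, which is the usual convention for scaled dot-product attention.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: batch 1, 3 queries, 2 keys, 2 heads, head_dim 4.
N, q_len, k_len, heads, head_dim = 1, 3, 2, 2, 4
queries = torch.randn(N, q_len, heads, head_dim)
keys = torch.randn(N, k_len, heads, head_dim)
values = torch.randn(N, k_len, heads, head_dim)

# einsum form: dot product over d for every (query, key) pair, per head.
energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

# Explicit form: move heads next to the batch axis, then batched matmul.
q_h = queries.permute(0, 2, 1, 3)        # (N, heads, q_len, head_dim)
k_h = keys.permute(0, 2, 1, 3)           # (N, heads, k_len, head_dim)
v_h = values.permute(0, 2, 1, 3)         # (N, heads, k_len, head_dim)
energy_mm = q_h @ k_h.transpose(-2, -1)  # (N, heads, q_len, k_len)

attention = torch.softmax(energy / head_dim ** 0.5, dim=3)

# Weighted sum of values, again in both forms.
out = torch.einsum("nhqk,nkhd->nqhd", attention, values)
out_mm = (attention @ v_h).permute(0, 2, 1, 3)  # (N, q_len, heads, head_dim)

print(energy.shape)
print(torch.allclose(energy, energy_mm, atol=1e-6))
print(torch.allclose(out, out_mm, atol=1e-6))
```

The einsum version avoids the explicit permutes, at the cost of having to read the index letters: n is the batch, q and k are query and key positions, h is the head, and d is the per-head dimension.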

The encoder layer:

Applies self-attention to its input sequence

Adds a residual connection and applies Layer Normalization

Passes the result through a Feed-Forward Network, followed by a second residual connection and Layer Normalization

In [8]:
class EncoderLayer(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion):
        super(EncoderLayer, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )

    def forward(self, x):
        attention = self.attention(x, x, x)
        x = self.norm1(attention + x)  # Residual connection
        forward = self.feed_forward(x)
        return self.norm2(forward + x)
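
To see what the residual connection plus Layer Normalization step does in isolation, here is a small sketch. The tensor named sublayer_out is a random stand-in for the attention output, not something produced by the classes above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_size = 8
x = torch.randn(1, 2, embed_size)             # (batch, seq_len, embed_size)
sublayer_out = torch.randn(1, 2, embed_size)  # stand-in for the attention output

norm = nn.LayerNorm(embed_size)

# Residual connection: add the sublayer output back onto its input,
# then normalize each token vector over the embedding dimension.
y = norm(sublayer_out + x)

print(y.shape)
print(y.mean(dim=-1))                 # close to 0 for every token
print(y.var(dim=-1, unbiased=False))  # close to 1 for every token
```

With freshly initialized parameters (scale 1, shift 0), each token vector comes out with roughly zero mean and unit variance. The shape stays the same, which is what lets layers be stacked.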


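The decoder layer has two attention steps followed by a feed-forward network:

Self-Attention on the target sequence (so it learns relationships in the output sentence); this minimal version applies no causal mask

Encoder-Decoder Attention (cross-attention), where queries come from the target and keys and values come from the encoder's output

Feed-forward network

Each step is wrapped in a residual connection and Layer Normalization.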

In [15]:
class DecoderLayer(nn.Module):
    def __init__(self, embed_size, heads, forward_expansion):
        super(DecoderLayer, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.encoder_attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.norm3 = nn.LayerNorm(embed_size)

        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )

    def forward(self, x, enc_out):
        attention = self.attention(x, x, x) # Self-attention on target sequence
        x = self.norm1(attention + x)

        # Cross-attention: query from target (x), keys and values from encoder output (enc_out)
        enc_attention = self.encoder_attention(enc_out, enc_out, x)
        x = self.norm2(enc_attention + x)

        forward = self.feed_forward(x)
        return self.norm3(forward + x)
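
In cross-attention the source and target can have different lengths, so the argument order (values, keys, query) matters. The sketch below uses random tensors with the lengths of the example sentences used later (2 source tokens, 3 target tokens) and skips the linear projections. It shows the shapes that result, and that keys and values taken from sequences of different lengths make the second einsum fail.

```python
import torch

torch.manual_seed(0)

N, heads, head_dim = 1, 2, 4
src_len, tgt_len = 2, 3  # e.g. "hello world" and "i am chatgpt"

# Random stand-ins, already split into heads (projections omitted).
dec_queries = torch.randn(N, tgt_len, heads, head_dim)
enc_keys = torch.randn(N, src_len, heads, head_dim)
enc_values = torch.randn(N, src_len, heads, head_dim)

energy = torch.einsum("nqhd,nkhd->nhqk", dec_queries, enc_keys)
attention = torch.softmax(energy / head_dim ** 0.5, dim=3)
out = torch.einsum("nhqk,nkhd->nqhd", attention, enc_values)
out = out.reshape(N, tgt_len, heads * head_dim)

print(attention.shape)  # one row per target token, one column per source token
print(out.shape)        # output follows the query (target) length

# Keys and values must come from the same sequence. Values with the
# target length (3) do not match attention weights over 2 source tokens.
wrong_values = torch.randn(N, tgt_len, heads, head_dim)
try:
    torch.einsum("nhqk,nkhd->nqhd", attention, wrong_values)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("Mismatch:", err)
```

The output always has the query length, so the decoder keeps its own sequence length while reading from the encoder.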

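We combine one encoder layer and one decoder layer with embedding layers for the source and target tokens and a final linear layer that maps each decoder output to vocabulary logits. No positional encoding is added in this minimal version.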



In [10]:
class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size, embed_size=16, heads=2, forward_expansion=4):
        super(SimpleTransformer, self).__init__()
        self.src_embedding = nn.Embedding(vocab_size, embed_size)
        self.tgt_embedding = nn.Embedding(vocab_size, embed_size)

        self.encoder = EncoderLayer(embed_size, heads, forward_expansion)
        self.decoder = DecoderLayer(embed_size, heads, forward_expansion)

        self.fc_out = nn.Linear(embed_size, vocab_size)

    def forward(self, src, tgt):
        src_embed = self.src_embedding(src)
        tgt_embed = self.tgt_embedding(tgt)

        enc_out = self.encoder(src_embed)
        dec_out = self.decoder(tgt_embed, enc_out)

        out = self.fc_out(dec_out)
        return out, enc_out, dec_out
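
The two ends of the model, the embedding lookup and the final projection, can be checked on their own. This sketch uses the same sizes as the demo below (6 vocabulary entries, embedding size 8) but builds standalone layers rather than reusing the class above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, embed_size = 6, 8
embedding = nn.Embedding(vocab_size, embed_size)
fc_out = nn.Linear(embed_size, vocab_size)

tokens = torch.tensor([[2, 3, 4]])  # (batch=1, seq_len=3)

# The embedding is a table lookup: token id 2 selects row 2 of the weight matrix.
embedded = embedding(tokens)        # (1, 3, 8)

# The final layer produces one score (logit) per vocabulary entry per position.
logits = fc_out(embedded)           # (1, 3, 6)

print(embedded.shape, logits.shape)
print(torch.equal(embedded[0, 0], embedding.weight[2]))
```

In the full model the encoder and decoder sit between these two steps, but the input and output shapes are the same as here.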


We make a tiny vocabulary for demonstration. Words that are not in the vocabulary are mapped to the <pad> token.


In [11]:
# Toy vocabulary
vocab = {"hello": 0, "world": 1, "i": 2, "am": 3, "chatgpt": 4, "<pad>": 5}
inv_vocab = {v: k for k, v in vocab.items()}

def tokenize(sentence):
    return [vocab.get(word, vocab["<pad>"]) for word in sentence.lower().split()]

def detokenize(tokens):
    return " ".join(inv_vocab.get(t, "<unk>") for t in tokens)
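
The <pad> token becomes useful once sentences of different lengths go into one batch. The helper pad_batch below is hypothetical (it is not part of the notebook) and is only a sketch of right-padding. Note that the model here does not mask padded positions, so padding would still be attended to.

```python
import torch

vocab = {"hello": 0, "world": 1, "i": 2, "am": 3, "chatgpt": 4, "<pad>": 5}

def tokenize(sentence):
    return [vocab.get(word, vocab["<pad>"]) for word in sentence.lower().split()]

def pad_batch(sentences):
    """Hypothetical helper: tokenize and right-pad to the longest sentence."""
    token_lists = [tokenize(s) for s in sentences]
    max_len = max(len(t) for t in token_lists)
    pad_id = vocab["<pad>"]
    padded = [t + [pad_id] * (max_len - len(t)) for t in token_lists]
    return torch.tensor(padded)

batch = pad_batch(["hello world", "i am chatgpt"])
print(batch.tolist())  # [[0, 1, 5], [2, 3, 4]]
```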


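We give the model a source and a target sentence, run one forward pass, and print the intermediate shapes and the predicted tokens. The model is untrained, so the prediction only reflects the random initial weights.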


In [12]:
# Initialize model
model = SimpleTransformer(vocab_size=len(vocab), embed_size=8, heads=2)

# Example sentences
src_sentence = "hello world"
tgt_sentence = "i am chatgpt"

# Tokenize
src_tokens = torch.tensor([tokenize(src_sentence)])  # shape: (1, seq_len)
tgt_tokens = torch.tensor([tokenize(tgt_sentence)])

# Forward pass
output, enc_out, dec_out = model(src_tokens, tgt_tokens)

# Display results
print("\nSource sentence:", src_sentence)
print("Source tokens:", src_tokens.tolist())
print("Encoder output shape:", enc_out.shape)

print("\nTarget sentence:", tgt_sentence)
print("Target tokens:", tgt_tokens.tolist())
print("Decoder output shape:", dec_out.shape)

# Index the batch instead of squeeze(), which would also drop a length-1 sequence axis
predicted_tokens = torch.argmax(output, dim=2)[0].tolist()
print("\nPredicted tokens:", predicted_tokens)
print("Predicted sentence:", detokenize(predicted_tokens))


Source sentence: hello world
Source tokens: [[0, 1]]
Encoder output shape: torch.Size([1, 2, 8])

Target sentence: i am chatgpt
Target tokens: [[2, 3, 4]]
Decoder output shape: torch.Size([1, 3, 8])

Predicted tokens: [1, 5, 1]
Predicted sentence: world <pad> world
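
The prediction step can be traced by hand. The logits below are hand-written stand-in values, not the model's actual output. They were chosen so that the argmax reproduces the token ids printed above.

```python
import torch

inv_vocab = {0: "hello", 1: "world", 2: "i", 3: "am", 4: "chatgpt", 5: "<pad>"}

# Hand-written stand-in logits of shape (batch=1, seq_len=3, vocab_size=6).
logits = torch.tensor([[
    [0.1, 2.0, 0.3, 0.0, 0.0, 0.5],  # highest score at id 1
    [0.0, 0.2, 0.1, 0.0, 0.0, 1.5],  # highest score at id 5
    [0.4, 3.0, 0.0, 0.2, 0.1, 0.0],  # highest score at id 1
]])

# argmax over the vocabulary axis picks one token id per position.
predicted = torch.argmax(logits, dim=2)  # shape: (1, 3)
tokens = predicted[0].tolist()           # drop the batch dimension

print(tokens)                                  # [1, 5, 1]
print(" ".join(inv_vocab[t] for t in tokens))  # world <pad> world
```

Because the real model's weights are random, a different run would generally give different ids. Only training would move the highest logit at each position toward the target token.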
