In [2]:
import torch
import math
from torch import nn
import torch.nn.functional as F


These imports bring in:

torch – the PyTorch library for tensor operations.

math – used for mathematical constants and functions (e.g., square root).

nn – PyTorch’s neural network module, where all layers and models come from.

F – torch.nn.functional, which provides stateless functions such as softmax and activation operations.

In [3]:
def scaled_dot_product(q, k, v, mask=None):
    """Scaled dot-product attention.

    Args:
        q: Query tensor; its last dimension is the per-head feature size.
        k: Key tensor whose last two dimensions are transposed before the
            product with ``q``.
        v: Value tensor combined using the attention weights.
        mask: Optional additive mask. It is added in place to the scaled
            scores, so it must broadcast to the shape of the score matrix.

    Returns:
        A ``(values, attention)`` tuple: the attention-weighted values and
        the softmax-normalized attention weights.
    """
    feature_dim = q.shape[-1]
    # Dividing by sqrt(feature_dim) keeps the logits in a range where the
    # softmax does not saturate.
    scores = (q @ k.transpose(-1, -2)) / math.sqrt(feature_dim)
    if mask is not None:
        scores += mask
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights


Purpose: Compute attention scores between queries (Q), keys (K), and values (V).

q, k, v shapes: [batch, heads, seq_len, head_dim].

Scaling: Divide by √d_k to stabilize gradients.

Masking: Adds the mask to the scaled scores. Masked positions hold -inf, so softmax assigns them zero weight, which prevents attending to future tokens.

Softmax: Converts scores into probabilities.

Weighted sum: Applies attention weights to values (V).

Output: Returns the attended values and the attention map.

In [4]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied to the last dimension.

    Args:
        d_model: Size of the input and output feature dimension.
        hidden: Size of the intermediate (expanded) feature dimension.
        drop_prob: Dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        # Expansion and projection layers.
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        # Parameter-free pieces used between the two projections.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Return ``linear2(dropout(relu(linear1(x))))``."""
        expanded = self.dropout(self.relu(self.linear1(x)))
        return self.linear2(expanded)


Applies two linear transformations with a ReLU in between.

Expands from d_model → hidden → d_model (512 → 2048 → 512 with the settings used later in this notebook).

Adds nonlinearity and dropout for better generalization.

Operates independently on each position (hence “position-wise”).

In [5]:
class LayerNormalization(nn.Module):
    """Layer normalization over the last dimension with a learnable affine.

    Args:
        parameters_shape: Shape of the learnable scale (``gamma``) and shift
            (``beta``) parameters.
        eps: Constant added to the variance before the square root to avoid
            division by zero.
    """

    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))
        self.eps = eps

    def forward(self, inputs):
        """Normalize ``inputs`` along the last axis, then scale and shift."""
        centered = inputs - inputs.mean(dim=-1, keepdim=True)
        # Population (biased) variance: mean of squared deviations.
        variance = centered.pow(2).mean(dim=-1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        return self.gamma * normalized + self.beta


Normalizes each position's feature vector across the last dimension (e.g., the embedding size) to zero mean and unit variance.

Gamma and Beta allow rescaling and shifting after normalization.

Prevents internal covariate shift, helping stabilize training.

In [6]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention.

    Args:
        d_model: Feature size of the input and output.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # A single projection produces queries, keys and values together.
        self.qkv_layer = nn.Linear(d_model, 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``[batch, seq_length, d_model]``.
            mask: Optional additive mask forwarded to ``scaled_dot_product``.

        Returns:
            Tensor of shape ``[batch, seq_length, d_model]``.
        """
        batch_size, seq_length, _ = x.size()
        qkv = self.qkv_layer(x)
        qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3 * self.head_dim)
        # -> [batch, heads, seq, 3 * head_dim]
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)
        values, attention = scaled_dot_product(q, k, v, mask)
        # Move the head axis back next to head_dim BEFORE flattening. Reshaping
        # [batch, heads, seq, head_dim] directly would interleave different
        # positions (and heads) into one output row.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, seq_length, self.num_heads * self.head_dim
        )
        return self.linear_layer(values)


Projects input into queries, keys, and values.

Splits them into multiple heads for parallel attention learning.

Each head focuses on a different subspace of representation.

Concatenates and linearly transforms them back to d_model dimension.

In [7]:
class MultiHeadCrossAttention(nn.Module):
    """Multi-head cross-attention: keys/values from ``x``, queries from ``y``.

    Args:
        d_model: Feature size of both inputs and of the output.
        num_heads: Number of attention heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_layer = nn.Linear(d_model, 2 * d_model)
        self.q_layer = nn.Linear(d_model, d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, y, mask=None):
        """Attend from ``y`` (queries) to ``x`` (keys and values).

        Args:
            x: Tensor of shape ``[batch, kv_length, d_model]`` that supplies
                keys and values.
            y: Tensor of shape ``[batch, q_length, d_model]`` that supplies
                queries. ``q_length`` may differ from ``kv_length``.
            mask: Optional additive mask forwarded to ``scaled_dot_product``.

        Returns:
            Tensor of shape ``[batch, q_length, d_model]``.
        """
        batch_size, kv_length, _ = x.size()
        # The query length comes from y, not x: the two sequences are
        # independent and need not have the same length.
        q_length = y.size(1)
        kv = self.kv_layer(x)
        q = self.q_layer(y)
        kv = kv.reshape(batch_size, kv_length, self.num_heads, 2 * self.head_dim)
        q = q.reshape(batch_size, q_length, self.num_heads, self.head_dim)
        # -> [batch, heads, length, features]
        kv = kv.permute(0, 2, 1, 3)
        q = q.permute(0, 2, 1, 3)
        k, v = kv.chunk(2, dim=-1)
        values, attention = scaled_dot_product(q, k, v, mask)
        # Restore [batch, q_length, heads, head_dim] before merging heads so
        # that each output row holds exactly one query position.
        values = values.permute(0, 2, 1, 3).reshape(batch_size, q_length, self.d_model)
        return self.linear_layer(values)


Used in Decoder–Encoder connections.

K and V come from the encoder output (x),
while Q comes from the decoder (y).

Helps the decoder attend to relevant encoder features.

In [8]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is followed by dropout, a residual addition
    and layer normalization.

    Args:
        d_model: Feature size of the token representations.
        ffn_hidden: Hidden size of the position-wise feed-forward network.
        num_heads: Number of attention heads in both attention sub-layers.
        drop_prob: Dropout probability used throughout the block.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super().__init__()
        # Sub-layer 1: self-attention over the decoder input.
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = LayerNormalization([d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)
        # Sub-layer 2: attention over `x` driven by the decoder state.
        self.encoder_decoder_attention = MultiHeadCrossAttention(d_model, num_heads)
        self.norm2 = LayerNormalization([d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)
        # Sub-layer 3: position-wise feed-forward network.
        self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob)
        self.norm3 = LayerNormalization([d_model])
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, x, y, decoder_mask):
        """Run the block.

        Args:
            x: Tensor passed as the key/value source of the cross-attention.
            y: Decoder-side tensor that is transformed and returned.
            decoder_mask: Mask passed to the self-attention sub-layer only;
                the cross-attention is called without a mask.

        Returns:
            The updated decoder-side tensor.
        """
        shortcut = y
        attended = self.dropout1(self.self_attention(y, mask=decoder_mask))
        y = self.norm1(attended + shortcut)

        shortcut = y
        crossed = self.dropout2(self.encoder_decoder_attention(x, y))
        y = self.norm2(crossed + shortcut)

        shortcut = y
        transformed = self.dropout3(self.ffn(y))
        return self.norm3(transformed + shortcut)


**Flow:**

**Masked Self-Attention:**
The decoder looks at previous tokens (future ones masked).

**Cross-Attention:**
Connects to the encoder’s output for context.

**Feedforward Network:**
Adds nonlinearity and deeper representation power.

**Residual + LayerNorm:**
After each step, residual connections and normalization stabilize gradients.

In [9]:
class SequentialDecoder(nn.Sequential):
    """``nn.Sequential`` variant whose layers take ``(x, y, mask)``.

    Only ``y`` is threaded from layer to layer; ``x`` and ``mask`` are handed
    unchanged to every layer.
    """

    def forward(self, *inputs):
        """Unpack ``(x, y, mask)``, run each layer in order, return ``y``."""
        source, state, mask = inputs
        for layer in self:
            state = layer(source, state, mask)
        return state


**Stacks multiple DecoderLayers.**

Feeds the output of one layer into the next in sequence.

In [10]:
class Decoder(nn.Module):
    """Stack of ``num_layers`` identically configured ``DecoderLayer`` blocks.

    Args:
        d_model: Feature size of the token representations.
        ffn_hidden: Hidden size of each layer's feed-forward network.
        num_heads: Number of attention heads per layer.
        drop_prob: Dropout probability used in every layer.
        num_layers: Number of decoder layers to stack (default 1).
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers=1):
        super().__init__()
        blocks = [
            DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob)
            for _ in range(num_layers)
        ]
        self.layers = SequentialDecoder(*blocks)

    def forward(self, x, y, mask):
        """Pass ``(x, y, mask)`` through the stack and return the final ``y``."""
        return self.layers(x, y, mask)

Combines multiple layers into a full decoder stack.

Each layer repeats self-attention → cross-attention → feedforward.

num_layers defines decoder depth (the default is 1; the example below uses 5 layers).

In [11]:
# Hyperparameters for the demo run below.
d_model = 512
num_heads = 8
drop_prob = 0.1
batch_size = 30
max_sequence_length = 200
ffn_hidden = 2048
num_layers = 5

# Random tensors used as inputs: `x` is passed as the key/value source of the
# cross-attention, `y` as the decoder-side input.
# NOTE: no random seed is set, so values differ from run to run.
x = torch.randn((batch_size, max_sequence_length, d_model))
y = torch.randn((batch_size, max_sequence_length, d_model))
# Additive look-ahead mask: -inf strictly above the main diagonal, 0 elsewhere.
mask = torch.full([max_sequence_length, max_sequence_length], float('-inf'))
mask = torch.triu(mask, diagonal=1)
decoder = Decoder(d_model, ffn_hidden, num_heads, drop_prob, num_layers)
out = decoder(x, y, mask)

In [12]:
# Show the decoder output computed in the previous cell.
print(out)

tensor([[[ 0.8483, -1.5130, -1.0294,  ..., -1.7971, -0.6627, -0.2493],
         [-0.1286,  0.4190,  0.5515,  ..., -1.0953,  1.8280,  0.2260],
         [-1.0160, -0.7728,  0.8552,  ..., -0.3770,  0.7206, -1.7009],
         ...,
         [-1.4675, -1.3779,  0.4696,  ..., -0.7159, -0.5289,  0.3690],
         [-0.3545,  0.9258,  1.3856,  ...,  0.1064, -0.7601, -1.3119],
         [ 0.5759, -0.3608, -0.6522,  ...,  0.1425, -0.3651,  0.8350]],

        [[-0.8018, -1.6261, -1.2081,  ...,  0.1909,  0.0668,  1.0435],
         [-2.7187,  0.7994,  0.7035,  ..., -0.6052,  0.3100, -0.4639],
         [-0.3874,  2.0777,  1.1576,  ..., -0.2313, -1.9917,  1.4528],
         ...,
         [ 0.0914,  0.1942,  0.9081,  ..., -0.7985, -0.6961, -1.2718],
         [ 0.8409, -0.4229,  1.5350,  ...,  0.2965, -1.4381, -0.6059],
         [-1.8653,  0.6308, -0.3067,  ..., -0.0729, -0.1651,  0.2196]],

        [[-0.0908, -0.2383,  0.1555,  ...,  0.5616,  0.1010,  0.6915],
         [-0.2132, -1.6608, -1.1445,  ..., -0