# Transformer from Scratch
### Based on "Attention Is All You Need" (Vaswani et al., 2017)

This notebook implements the core components of the Transformer model,
including scaled dot-product attention, multi-head attention, positional
encoding, and encoder–decoder blocks, with explanations mapping directly
to the original research paper.


## Mapping to Research Paper

This implementation directly follows Section 3 of the paper
"Attention Is All You Need" (Vaswani et al., 2017):

- Section 3.1: Encoder–Decoder architecture
- Section 3.2.1: Scaled dot-product attention
- Section 3.2.2: Multi-head attention
- Section 3.2.3: Masked self-attention in decoder
- Section 3.5: Positional encoding


In [2]:
import torch
import torch.nn as nn
import math


In [3]:
# Model configuration (small on purpose for clarity)
d_model = 128
num_heads = 8
d_k = d_model // num_heads


## 1. Scaled Dot-Product Attention

Scaled dot-product attention computes the similarity between queries and keys
using a dot product, scales the result by √d_k to stabilize gradients, applies
softmax to obtain attention weights, and uses these weights to compute a
weighted sum of values.


In [4]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super().__init__()
        self.scale = math.sqrt(d_k)

    def forward(self, Q, K, V, mask=None):
        # Q, K, V: (batch_size, seq_len, d_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # scores: (batch_size, seq_len, seq_len)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        attention_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, V)

        return output, attention_weights


In [5]:
B, S, d_k = 2, 5, 16
Q = torch.rand(B, S, d_k)
K = torch.rand(B, S, d_k)
V = torch.rand(B, S, d_k)

attention = ScaledDotProductAttention(d_k)
output, attn_weights = attention(Q, K, V)

print("Output shape:", output.shape)
print("Attention weights shape:", attn_weights.shape)


Output shape: torch.Size([2, 5, 16])
Attention weights shape: torch.Size([2, 5, 5])


## 2. Multi-Head Attention

Multi-head attention allows the model to jointly attend to information from
different representation subspaces at different positions. Each head performs
scaled dot-product attention independently, and the results are concatenated
and linearly transformed.


In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(self.d_k)

    def forward(self, Q, K, V, mask=None):
        B, S, _ = Q.size()

        Q = self.W_q(Q)
        K = self.W_k(K)
        V = self.W_v(V)

        Q = Q.view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(B, S, self.num_heads, self.d_k).transpose(1, 2)

        out, _ = self.attention(Q, K, V, mask)

        out = out.transpose(1, 2).contiguous()
        out = out.view(B, S, self.num_heads * self.d_k)

        return self.W_o(out)


In [7]:
B, S, d_model = 2, 6, 128
x = torch.rand(B, S, d_model)

mha = MultiHeadAttention(d_model, num_heads=8)
out = mha(x, x, x)

print("Output shape:", out.shape)


Output shape: torch.Size([2, 6, 128])


## 3. Positional Encoding

Since the Transformer does not use recurrence or convolution, positional
encoding is added to the input embeddings to provide information about
the relative or absolute position of tokens in the sequence. The original
paper uses fixed sinusoidal functions.


In [8]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


In [9]:
B, S, d_model = 2, 10, 128
x = torch.zeros(B, S, d_model)

pe = PositionalEncoding(d_model)
out = pe(x)

print("Output shape:", out.shape)


Output shape: torch.Size([2, 10, 128])


## 4. Encoder Block

Each encoder layer consists of a multi-head self-attention mechanism
followed by a position-wise feed-forward network. Residual connections
and layer normalization are applied after each sub-layer.


In [10]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.net(x)


In [11]:
class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        attn_out = self.mha(x, x, x, mask)
        x = self.norm1(x + attn_out)

        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)

        return x


In [12]:
B, S, d_model = 2, 8, 128
x = torch.rand(B, S, d_model)

encoder_block = EncoderBlock(d_model, num_heads=8)
out = encoder_block(x)

print("Encoder output shape:", out.shape)


Encoder output shape: torch.Size([2, 8, 128])


## 5. Decoder Block

Each decoder layer contains masked multi-head self-attention, followed by
encoder–decoder attention and a position-wise feed-forward network.
Masking prevents the model from attending to future positions during training.


In [13]:
def generate_subsequent_mask(size):
    mask = torch.tril(torch.ones(size, size))
    return mask


In [14]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.enc_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
        attn1 = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + attn1)

        attn2 = self.enc_attn(x, enc_out, enc_out, src_mask)
        x = self.norm2(x + attn2)

        ffn_out = self.ffn(x)
        x = self.norm3(x + ffn_out)

        return x


In [15]:
B, S, d_model = 2, 6, 128
x = torch.rand(B, S, d_model)
enc_out = torch.rand(B, S, d_model)

tgt_mask = generate_subsequent_mask(S)

decoder = DecoderBlock(d_model, num_heads=8)
out = decoder(x, enc_out, tgt_mask=tgt_mask)

print("Decoder output shape:", out.shape)


Decoder output shape: torch.Size([2, 6, 128])


## 6. Full Transformer Model

The full Transformer model consists of stacked encoder and decoder blocks.
Token embeddings are combined with positional encodings before being passed
through the encoder. The decoder generates outputs using masked self-attention
and encoder–decoder attention.


In [16]:
class Transformer(nn.Module):
    def __init__(self, d_model, num_heads, num_layers=2, vocab_size=1000):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)

        self.encoders = nn.ModuleList(
            [EncoderBlock(d_model, num_heads) for _ in range(num_layers)]
        )

        self.decoders = nn.ModuleList(
            [DecoderBlock(d_model, num_heads) for _ in range(num_layers)]
        )

        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        src = self.embedding(src)
        src = self.positional_encoding(src)

        for encoder in self.encoders:
            src = encoder(src, src_mask)

        tgt = self.embedding(tgt)
        tgt = self.positional_encoding(tgt)

        for decoder in self.decoders:
            tgt = decoder(tgt, src, src_mask, tgt_mask)

        return self.fc_out(tgt)


In [17]:
B, S = 2, 5
vocab_size = 1000

src = torch.randint(0, vocab_size, (B, S))
tgt = torch.randint(0, vocab_size, (B, S))

tgt_mask = generate_subsequent_mask(S)

model = Transformer(d_model=128, num_heads=8, num_layers=2, vocab_size=vocab_size)
out = model(src, tgt, tgt_mask=tgt_mask)

print("Transformer output shape:", out.shape)


Transformer output shape: torch.Size([2, 5, 1000])


In [18]:
criterion = nn.CrossEntropyLoss()

# Shift target for teacher forcing
loss = criterion(out.view(-1, vocab_size), tgt.view(-1))
print("Sample loss:", loss.item())


Sample loss: 7.0142412185668945


In [19]:
# Minimal training demonstration (sanity check)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

model.train()
for epoch in range(3):
    optimizer.zero_grad()

    src = torch.randint(0, vocab_size, (B, S))
    tgt = torch.randint(0, vocab_size, (B, S))
    tgt_mask = generate_subsequent_mask(S)

    output = model(src, tgt, tgt_mask=tgt_mask)
    loss = criterion(output.view(-1, vocab_size), tgt.view(-1))

    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")


Epoch 1, Loss: 7.1733
Epoch 2, Loss: 6.9667
Epoch 3, Loss: 7.1999
