# In [1]:
# Transformer block replication from "Attention is All You Need"
# Simplified for training affordability (1 block only)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Follows Vaswani et al. (2017), Sec. 3.2.2: queries, keys and values are
    each produced by a learned projection of the *full* embedding, split into
    ``heads`` independent heads of width ``embed_size // heads``, attended
    separately, concatenated, and mixed by the output projection ``fc_out``.

    Args:
        embed_size: model width (d_model); must be divisible by ``heads``.
        heads: number of attention heads.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, "Embedding size must be divisible by heads"

        # One embed_size -> embed_size projection per role. Its output is split
        # into heads afterwards, so every head gets its own weights and sees the
        # whole embedding (equivalent to h separate d_model -> head_dim matrices).
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape (N, seq_len, embed_size).
            mask: optional tensor broadcastable to (N, heads, seq_len, seq_len).
                Query/key pairs where ``mask == 0`` receive no attention, e.g.
                ``torch.tril(torch.ones(seq_len, seq_len))`` for causal masking.

        Returns:
            Tensor of shape (N, seq_len, embed_size).
        """
        N, seq_len, embed_size = x.shape
        split_shape = (N, seq_len, self.heads, self.head_dim)

        # Project first, then split into heads. reshape (not view) also
        # accepts non-contiguous inputs such as transposed tensors.
        values = self.values(x).reshape(split_shape)
        keys = self.keys(x).reshape(split_shape)
        queries = self.queries(x).reshape(split_shape)

        # Scaled dot-product scores: (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys) / math.sqrt(self.head_dim)
        if mask is not None:
            # Use the dtype's most negative finite value rather than -inf, so
            # a fully masked row yields uniform weights instead of NaN.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)
        attention = F.softmax(energy, dim=-1)

        # Weighted sum of values per head, then concatenate heads back.
        out = torch.einsum("nhqk,nkhd->nqhd", attention, values).reshape(N, seq_len, embed_size)
        return self.fc_out(out)

class TransformerBlock(nn.Module):
    """One post-LayerNorm Transformer encoder layer (Vaswani et al., 2017).

    Two sub-layers, each wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``:
    multi-head self-attention, then a position-wise feed-forward network
    ``Linear -> ReLU -> Linear``.

    Args:
        embed_size: model width (d_model).
        heads: number of attention heads; must divide ``embed_size``.
        dropout: dropout probability applied to each sub-layer's output.
        forward_expansion: hidden width of the feed-forward network as a
            multiple of ``embed_size``.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        self.attention = MultiHeadSelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run both sub-layers on ``x`` and return a tensor of the same shape.

        ``x`` is presumably (batch, seq_len, embed_size), the layout the
        attention module unpacks; the LayerNorms act on the last dimension.
        """
        # Residual dropout (paper Sec. 5.4): drop the sub-layer output *before*
        # the residual add and normalization. The residual path and the block's
        # final output are never dropped.
        x = self.norm1(x + self.dropout(self.attention(x)))
        return self.norm2(x + self.dropout(self.feed_forward(x)))

# Smoke test: push one random batch through a single block and print the
# resulting shape, which should equal the input shape.
if __name__ == "__main__":
    batch_size, seq_len, embed_size = 32, 10, 512
    sample_input = torch.rand(batch_size, seq_len, embed_size)
    block = TransformerBlock(
        embed_size=embed_size,
        heads=8,
        dropout=0.1,
        forward_expansion=4,
    )
    out = block(sample_input)
    print(out.shape)  # expected: torch.Size([32, 10, 512])


# Out: torch.Size([32, 10, 512])
