In [1]:
import torch
import torch.nn as nn
import math


In [2]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of embeddings.

    A table of shape ``(1, max_len, d_model)`` is precomputed once and stored
    as the non-trainable buffer ``pe``. Even feature columns hold
    ``sin(pos * 10000^(-k / d_model))`` and odd columns hold the matching
    ``cos`` term, where ``k`` is the even column index.

    Args:
        d_model: Size of the feature dimension. Both even and odd sizes are
            supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)

        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        # (max_len, ceil(d_model / 2)) table of angles shared by sin and cos.
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # When d_model is odd there is one fewer odd column than even columns,
        # so drop the last frequency for the cosine half instead of failing
        # with a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Leading batch axis of size 1 so the table broadcasts over the batch.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence axis and whose last
                dimension is ``d_model``.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]


In [10]:
class ScaledDotProductAttention(nn.Module):
    """Compute ``softmax(Q K^T / sqrt(d_k)) V`` with optional masking."""

    def forward(self, Q, K, V, mask=None):
        """Attend over ``V`` using query/key similarity.

        Args:
            Q: Query tensor; its last dimension is ``d_k``.
            K: Key tensor; its last two dimensions are transposed before the
                product with ``Q``.
            V: Value tensor, combined using the attention weights.
            mask: Optional tensor broadcastable to the score matrix. Positions
                where ``mask == 0`` are suppressed before the softmax.

        Returns:
            The attention-weighted combination of ``V``.
        """
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # -1e9 is not representable in float16, where masked_fill would
            # fail with an overflow error. Clamp to the dtype's finite minimum;
            # float32 / float64 / bfloat16 still use exactly -1e9.
            fill_value = max(-1e9, torch.finfo(scores.dtype).min)
            scores = scores.masked_fill(mask == 0, fill_value)

        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, V)


In [11]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``ScaledDotProductAttention``.

    Inputs are projected with separate linear layers, reshaped into
    ``heads`` slices of width ``d_model // heads``, attended per head,
    merged back to ``d_model`` features and passed through an output linear
    layer.

    Args:
        d_model: Feature size of the inputs and outputs; must be divisible
            by ``heads``.
        heads: Number of attention heads.
    """

    def __init__(self, d_model, heads):
        super().__init__()
        assert d_model % heads == 0

        self.d_k = d_model // heads
        self.heads = heads

        # Input projections for queries, keys and values.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection applied after the heads are merged.
        self.fc = nn.Linear(d_model, d_model)
        self.attn = ScaledDotProductAttention()

    def _split_heads(self, projected, batch):
        """Reshape ``(batch, seq, d_model)`` to ``(batch, heads, seq, d_k)``."""
        per_head = projected.view(batch, -1, self.heads, self.d_k)
        return per_head.transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        """Run attention over all heads and return ``d_model``-sized features.

        Args:
            Q, K, V: Tensors whose first dimension is the batch and whose last
                dimension is ``d_model``.
            mask: Optional mask forwarded unchanged to the attention module.
        """
        batch = Q.size(0)

        q_heads = self._split_heads(self.W_q(Q), batch)
        k_heads = self._split_heads(self.W_k(K), batch)
        v_heads = self._split_heads(self.W_v(V), batch)

        context = self.attn(q_heads, k_heads, v_heads, mask)

        # Move the head axis next to d_k and fold both into one feature axis.
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(batch, -1, self.heads * self.d_k)

        return self.fc(merged)


In [12]:
class FeedForward(nn.Module):
    """Two-layer position-wise MLP: ``d_model -> d_ff -> d_model`` with ReLU.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden feature size.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        expand = nn.Linear(d_model, d_ff)
        activation = nn.ReLU()
        project = nn.Linear(d_ff, d_model)
        # Kept as a Sequential named ``net`` so parameter keys stay
        # ``net.0.*`` and ``net.2.*``.
        self.net = nn.Sequential(expand, activation, project)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        transformed = self.net(x)
        return transformed


In [13]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention then feed-forward, each followed by
    a residual connection and layer normalisation (post-norm).

    Args:
        d_model: Feature size of the layer's input and output.
        heads: Number of attention heads.
        d_ff: Hidden size of the feed-forward sub-layer.
    """

    def __init__(self, d_model, heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Return the encoded representation of ``x``.

        Args:
            x: Input features; used as query, key and value.
            mask: Optional attention mask passed to the self-attention.
        """
        # Sub-layer 1: self-attention with residual + norm.
        attended = self.attn(x, x, x, mask)
        x = self.norm1(x + attended)

        # Sub-layer 2: feed-forward with residual + norm.
        transformed = self.ffn(x)
        return self.norm2(x + transformed)


In [14]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, attention over the encoder
    output, then feed-forward; each sub-layer is followed by a residual
    connection and layer normalisation.

    Args:
        d_model: Feature size of the layer's input and output.
        heads: Number of attention heads in both attention sub-layers.
        d_ff: Hidden size of the feed-forward sub-layer.
    """

    def __init__(self, d_model, heads, d_ff):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, heads)
        self.enc_attn = MultiHeadAttention(d_model, heads)
        self.ffn = FeedForward(d_model, d_ff)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        """Return the decoded representation of ``x``.

        Args:
            x: Decoder-side features.
            enc_out: Encoder output, used as key and value in the second
                attention sub-layer.
            src_mask: Mask applied when attending over ``enc_out``.
            tgt_mask: Mask applied in the decoder self-attention.
        """
        # Sub-layer 1: self-attention over the decoder input.
        self_ctx = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self_ctx)

        # Sub-layer 2: queries from the decoder, keys/values from the encoder.
        cross_ctx = self.enc_attn(x, enc_out, enc_out, src_mask)
        x = self.norm2(x + cross_ctx)

        # Sub-layer 3: feed-forward.
        transformed = self.ffn(x)
        return self.norm3(x + transformed)


In [15]:
class Transformer(nn.Module):
    """Encoder-decoder transformer producing target-vocabulary logits.

    Args:
        src_vocab: Number of rows in the source embedding table.
        tgt_vocab: Number of rows in the target embedding table and size of
            the output logits.
        d_model: Embedding / hidden feature size.
        heads: Attention heads per layer.
        d_ff: Hidden size of each feed-forward sub-layer.
        layers: Number of encoder layers and of decoder layers.
    """

    def __init__(self, src_vocab, tgt_vocab, d_model, heads, d_ff, layers):
        super().__init__()
        self.src_embed = nn.Embedding(src_vocab, d_model)
        self.tgt_embed = nn.Embedding(tgt_vocab, d_model)
        # A single positional-encoding module is shared by both sides.
        self.pos = PositionalEncoding(d_model)

        encoder_stack = [EncoderLayer(d_model, heads, d_ff) for _ in range(layers)]
        self.encoder = nn.ModuleList(encoder_stack)

        decoder_stack = [DecoderLayer(d_model, heads, d_ff) for _ in range(layers)]
        self.decoder = nn.ModuleList(decoder_stack)

        self.fc_out = nn.Linear(d_model, tgt_vocab)

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode ``src``, decode ``tgt`` against it, and return logits.

        Args:
            src: Source token ids.
            tgt: Target token ids fed to the decoder.
            src_mask: Mask used by encoder self-attention and by the
                decoder's attention over the encoder output.
            tgt_mask: Mask used by decoder self-attention.

        Returns:
            Tensor whose last dimension has size ``tgt_vocab``.
        """
        # Encoder pass.
        memory = self.pos(self.src_embed(src))
        for enc_layer in self.encoder:
            memory = enc_layer(memory, src_mask)

        # Decoder pass, attending over the final encoder output.
        decoded = self.pos(self.tgt_embed(tgt))
        for dec_layer in self.decoder:
            decoded = dec_layer(decoded, memory, src_mask, tgt_mask)

        return self.fc_out(decoded)


In [16]:
def create_mask(src, tgt, pad_idx=0):
    """Build padding and causal attention masks for a source/target batch.

    Args:
        src: Source token ids; assumed to be ``(batch, src_len)``.
        tgt: Target token ids; assumed to be ``(batch, tgt_len)`` (dimension 1
            is used as the target length).
        pad_idx: Token id treated as padding. Defaults to 0.

    Returns:
        ``(src_mask, tgt_mask)`` boolean tensors where ``True`` marks a
        position that may be attended to. For 2-D inputs ``src_mask`` is
        ``(batch, 1, 1, src_len)`` and ``tgt_mask`` is
        ``(batch, 1, tgt_len, tgt_len)``.
    """
    src_mask = (src != pad_idx).unsqueeze(1).unsqueeze(2)

    size = tgt.size(1)
    # Lower-triangular mask so position i only sees positions <= i. Built on
    # tgt's device so the `&` below also works when the batch lives on GPU.
    no_peek = torch.tril(torch.ones(size, size, device=tgt.device)).bool()
    tgt_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(2) & no_peek

    return src_mask, tgt_mask


In [17]:
# Build a small model and run a single optimisation step on random data.
model = Transformer(
    src_vocab=1000,
    tgt_vocab=1000,
    d_model=64,
    heads=8,
    d_ff=256,
    layers=2
)

# Index 0 is the padding id used by create_mask, so it is excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Random token ids drawn from [1, 999), so no position is padding.
src = torch.randint(1, 999, (2, 10))
tgt = torch.randint(1, 999, (2, 10))

src_mask, tgt_mask = create_mask(src, tgt)

# Clear any gradients left from a previous step so re-running or looping this
# step does not silently accumulate them.
optimizer.zero_grad()

# Teacher forcing: the decoder sees tgt[:, :-1] and is trained to predict
# tgt[:, 1:]; the target mask is cropped to the shortened length.
output = model(src, tgt[:, :-1], src_mask, tgt_mask[:, :, :-1, :-1])

# Flatten to (N, vocab) / (N,) as expected by CrossEntropyLoss; the vocabulary
# size is read from the logits instead of repeating the literal 1000.
loss = criterion(
    output.reshape(-1, output.size(-1)),
    tgt[:, 1:].reshape(-1)
)

loss.backward()
optimizer.step()

print("Training loss:", loss.item())


Training loss: 6.86077356338501


In [18]:
# NOTE(review): `TransformerEncoder` is not defined in any visible cell, so
# this cell raised NameError on Restart & Run All. The class below is a minimal
# encoder-only reconstruction that matches the call signature used here and the
# recorded output shape (batch, seq_len, d_model) -- confirm it against the
# originally intended definition.
class TransformerEncoder(nn.Module):
    """Embedding + positional encoding + a stack of ``EncoderLayer`` blocks."""

    def __init__(self, vocab_size, d_model, heads, d_ff, layers):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, heads, d_ff)
            for _ in range(layers)
        ])

    def forward(self, x, mask=None):
        """Return per-token features of size ``d_model`` for token ids ``x``."""
        x = self.pos(self.embed(x))
        for layer in self.layers:
            x = layer(x, mask)
        return x


# Hyperparameters
vocab_size = 1000
d_model = 64
num_heads = 8
d_ff = 256
num_layers = 2
seq_len = 10
batch_size = 2

# Dummy input
x = torch.randint(0, vocab_size, (batch_size, seq_len))

# Model (note: this rebinds `model`, replacing the Transformer built earlier)
model = TransformerEncoder(vocab_size, d_model, num_heads, d_ff, num_layers)

output = model(x)
print("Output shape:", output.shape)


Output shape: torch.Size([2, 10, 64])
