In [1]:
import torch
import torch.nn as nn
import math

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)  # shape: (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x


In [3]:
def scaled_dot_product_attention(q, k, v, mask=None):
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attn = torch.softmax(scores, dim=-1)
    return attn @ v, attn


In [30]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        B, T, _ = q.size()

        def transform(x, linear):
            B, T, _ = x.size()
            x = linear(x)
            return x.view(B, T, self.num_heads, self.d_k).transpose(1, 2)

        q = transform(q, self.q_linear)
        k = transform(k, self.k_linear)
        v = transform(v, self.v_linear)

        scores, attn = scaled_dot_product_attention(q, k, v, mask)
        scores = scores.transpose(1, 2).contiguous().view(B, T, -1)
        return self.out(scores)


In [31]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(self.dropout(torch.relu(self.linear1(x))))


In [32]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        x = self.norm1(x + self.attn(x, x, x, mask))
        x = self.norm2(x + self.ff(x))
        return x

class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])

    def forward(self, x, mask):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        for layer in self.layers:
            x = layer(x, mask)
        return x

In [33]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.enc_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out, src_mask, tgt_mask):
        x = self.norm1(x + self.self_attn(x, x, x, tgt_mask))
        x = self.norm2(x + self.enc_attn(x, enc_out, enc_out, src_mask))
        x = self.norm3(x + self.ff(x))
        return x

class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])

    def forward(self, x, enc_out, src_mask, tgt_mask):
        x = self.embedding(x)
        x = self.pos_encoding(x)
        for layer in self.layers:
            x = layer(x, enc_out, src_mask, tgt_mask)
        return x

In [34]:
class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model, num_layers, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, num_layers, num_heads, d_ff, dropout)
        self.decoder = Decoder(tgt_vocab, d_model, num_layers, num_heads, d_ff, dropout)
        self.fc_out = nn.Linear(d_model, tgt_vocab)

    def forward(self, src, tgt, src_mask, tgt_mask):
        enc_out = self.encoder(src, src_mask)
        dec_out = self.decoder(tgt, enc_out, src_mask, tgt_mask)
        return self.fc_out(dec_out)


In [35]:
def create_padding_mask(seq, pad_token=0):
    # seq: (batch_size, seq_len)
    return (seq != pad_token).unsqueeze(1).unsqueeze(2)  # shape: (B, 1, 1, T)

def create_look_ahead_mask(size):
    return torch.triu(torch.ones((size, size)), diagonal=1).bool()

def create_combined_mask(tgt_seq, pad_token=0):
    padding_mask = create_padding_mask(tgt_seq, pad_token)  # shape: (B, 1, 1, T)
    look_ahead_mask = create_look_ahead_mask(tgt_seq.size(1)).to(tgt_seq.device)  # shape: (T, T)
    look_ahead_mask = look_ahead_mask.unsqueeze(0).unsqueeze(1)  # shape: (1, 1, T, T)
    return padding_mask & ~look_ahead_mask


In [40]:
PAD_TOKEN_ID = 0
VOCAB_SIZE = 5000
SEQ_LEN = 32         # Max length of input/output token sequences
BATCH_SIZE = 16
NUM_EPOCHS = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Transformer(
    src_vocab=5000,
    tgt_vocab=5000,
    d_model=512,
    num_layers=6,
    num_heads=8,
    d_ff=2048
).to(device)

loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN_ID)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


In [41]:
from torch.utils.data import Dataset, DataLoader
import torch

class DummyTranslationDataset(Dataset):
    def __init__(self, num_samples=1000, seq_len=SEQ_LEN, vocab_size=VOCAB_SIZE):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.vocab_size = vocab_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        src = torch.randint(1, self.vocab_size, (self.seq_len,))
        tgt = torch.randint(1, self.vocab_size, (self.seq_len,))
        src[torch.randint(0, self.seq_len, (1,))] = PAD_TOKEN_ID  # randomly add a PAD
        tgt[torch.randint(0, self.seq_len, (1,))] = PAD_TOKEN_ID
        return {'src': src, 'tgt': tgt}


In [42]:
dataset = DummyTranslationDataset()
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


In [43]:
for epoch in range(NUM_EPOCHS):
    model.train()
    for batch in dataloader:
        src = batch['src'].to(device)
        tgt = batch['tgt'].to(device)

        tgt_input = tgt[:, :-1]
        tgt_output = tgt[:, 1:]

        src_mask = create_padding_mask(src)
        tgt_mask = create_combined_mask(tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask)
        logits = logits.view(-1, logits.size(-1))
        tgt_output = tgt_output.contiguous().view(-1)

        loss = loss_fn(logits, tgt_output)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch + 1} | Loss: {loss.item():.4f}")


Epoch 1 | Loss: nan
Epoch 2 | Loss: nan
Epoch 3 | Loss: nan
Epoch 4 | Loss: nan
Epoch 5 | Loss: nan
