In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))

        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)

        self.pe = pe.unsqueeze(0)   # shape: (1, max_len, d_model)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)].to(x.device)


In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_head = d_model // num_heads
        self.num_heads = num_heads

        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, d_model)
        self.WV = nn.Linear(d_model, d_model)
        self.WO = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        B, T, D = q.size()

        # Linear projections
        Q = self.WQ(q)  # (B,T,D)
        K = self.WK(k)
        V = self.WV(v)

        # Split into heads: (B, num_heads, T, d_head)
        Q = Q.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        K = K.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        V = V.view(B, T, self.num_heads, self.d_head).transpose(1, 2)

        # Attention scores: (B, num_heads, T, T)
        scores = Q @ K.transpose(-2, -1) / (self.d_head ** 0.5)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = F.softmax(scores, dim=-1)

        out = weights @ V  # (B, num_heads, T, d_head)

        # Merge heads
        out = out.transpose(1, 2).contiguous().view(B, T, D)

        return self.WO(out)


In [4]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.net(x)


In [5]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ffn  = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        x = self.norm1(x + self.attn(x, x, x, mask))
        x = self.norm2(x + self.ffn(x))
        return x


In [6]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn  = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn        = FeedForward(d_model, d_ff)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out, self_mask=None, cross_mask=None):
        x = self.norm1(x + self.self_attn(x, x, x, self_mask))
        x = self.norm2(x + self.cross_attn(x, enc_out, enc_out, cross_mask))
        x = self.norm3(x + self.ffn(x))
        return x


In [7]:
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model=256, num_heads=4, d_ff=512, num_layers=3, max_len=200):
        super().__init__()

        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model, max_len)

        self.encoder = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)])

        self.fc_out = nn.Linear(d_model, vocab_size)

    def make_subsequent_mask(self, size):
        # Mask that prevents attending to future positions
        return torch.tril(torch.ones(size, size)).unsqueeze(0).unsqueeze(0)

    def forward(self, src, tgt):
        # ---- ENCODER ----
        src_emb = self.pos(self.embed(src))
        enc_out = src_emb
        for layer in self.encoder:
            enc_out = layer(enc_out)

        # ---- DECODER ----
        tgt_emb = self.pos(self.embed(tgt))
        self_mask = self.make_subsequent_mask(tgt.size(1)).to(tgt.device)

        dec_out = tgt_emb
        for layer in self.decoder:
            dec_out = layer(dec_out, enc_out, self_mask)

        # Output vocab logits
        return self.fc_out(dec_out)


In [8]:
vocab_size = 50
model = Transformer(vocab_size)

src = torch.tensor([[4, 17, 23, 8, 2]])       # shape (batch=1, src_len=5)
tgt = torch.tensor([[1, 7, 7, 9, 10, 2]])     # shape (batch=1, tgt_len=6)

logits = model(src, tgt)
print(logits.shape)


RuntimeError: shape '[1, 6, 4, 64]' is invalid for input of size 1280

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
import numpy as np

# =========================
# Positional Encoding
# =========================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)

        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)].to(x.device)


# =========================
# Multi-Head Attention (fixed)
# =========================
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.WQ = nn.Linear(d_model, d_model)
        self.WK = nn.Linear(d_model, d_model)
        self.WV = nn.Linear(d_model, d_model)
        self.WO = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        B, Tq, D = q.shape
        Tk = k.shape[1]
        Tv = v.shape[1]

        Q = self.WQ(q)
        K = self.WK(k)
        V = self.WV(v)

        Q = Q.view(B, Tq, self.num_heads, self.d_head).transpose(1, 2)
        K = K.view(B, Tk, self.num_heads, self.d_head).transpose(1, 2)
        V = V.view(B, Tv, self.num_heads, self.d_head).transpose(1, 2)

        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_head)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        out = weights @ V

        out = out.transpose(1, 2).contiguous().view(B, Tq, D)

        return self.WO(out)


# =========================
# Feed-Forward Network
# =========================
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )

    def forward(self, x):
        return self.net(x)


# =========================
# Encoder Layer
# =========================
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=512):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        x = self.norm1(x + self.attn(x, x, x, mask))
        x = self.norm2(x + self.ff(x))
        return x


# =========================
# Decoder Layer
# =========================
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff=512):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ff = FeedForward(d_model, d_ff)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out, self_mask=None, cross_mask=None):
        x = self.norm1(x + self.self_attn(x, x, x, self_mask))
        x = self.norm2(x + self.cross_attn(x, enc_out, enc_out, cross_mask))
        x = self.norm3(x + self.ff(x))
        return x


# =========================
# Full Transformer Model
# =========================
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model=128, num_heads=4, num_layers=2, d_ff=512, max_len=200):
        super().__init__()

        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = PositionalEncoding(d_model, max_len)

        self.encoder = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)
        ])
        self.decoder = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff) for _ in range(num_layers)
        ])

        self.fc_out = nn.Linear(d_model, vocab_size)

    def make_sub_mask(self, size):
        mask = torch.tril(torch.ones(size, size))
        return mask.unsqueeze(0).unsqueeze(0)

    def forward(self, src, tgt):
        src_emb = self.pos(self.embed(src))
        enc = src_emb
        for layer in self.encoder:
            enc = layer(enc)

        tgt_emb = self.pos(self.embed(tgt))
        tgt_mask = self.make_sub_mask(tgt.size(1)).to(tgt.device)

        dec = tgt_emb
        for layer in self.decoder:
            dec = layer(dec, enc, tgt_mask)

        return self.fc_out(dec)


# =========================
# Toy dataset
# =========================
train_pairs = [
    ("1 2 3", "one two three"),
    ("4 5 6", "four five six"),
    ("7 8", "seven eight"),
    ("9", "nine")
]

words = set()
for src, tgt in train_pairs:
    words.update(src.split())
    words.update(tgt.split())

words = ["<pad>", "<bos>", "<eos>"] + sorted(words)
stoi = {w: i for i, w in enumerate(words)}
itos = {i: w for w, i in stoi.items()}
vocab_size = len(words)

def encode(text):
    return [stoi["<bos>"]] + [stoi[t] for t in text.split()] + [stoi["<eos>"]]

def pad(seq, max_len=10):
    return seq + [stoi["<pad>"]] * (max_len - len(seq))

dataset = []
for src, tgt in train_pairs:
    s = pad(encode(src))
    t = pad(encode(tgt))
    dataset.append((torch.tensor(s), torch.tensor(t)))


# =========================
# Training loop
# =========================
model = Transformer(vocab_size)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss(ignore_index=stoi["<pad>"])

EPOCHS = 200
print("Training...")

for epoch in range(EPOCHS):
    total_loss = 0
    for src, tgt in dataset:
        src = src.unsqueeze(0)
        tgt_in = tgt[:-1].unsqueeze(0)
        tgt_out = tgt[1:].unsqueeze(0)

        logits = model(src, tgt_in)
        loss = criterion(logits.reshape(-1, vocab_size), tgt_out.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {total_loss:.3f}")


# =========================
# Inference
# =========================
def translate(model, src_text, max_len=10):
    model.eval()
    src = torch.tensor(pad(encode(src_text))).unsqueeze(0)

    tgt = torch.tensor([[stoi["<bos>"]]])

    for _ in range(max_len):
        logits = model(src, tgt)
        next_token = logits[0, -1].argmax().item()
        tgt = torch.cat([tgt, torch.tensor([[next_token]])], dim=1)
        if next_token == stoi["<eos>"]:
            break

    return " ".join(itos[t.item()] for t in tgt[0])

print("\nTranslations:")
print("1 2 3  -->", translate(model, "1 2 3"))
print("4 5 6  -->", translate(model, "4 5 6"))
print("9      -->", translate(model, "9"))


Training...
Epoch 0, Loss: 11.326
Epoch 20, Loss: 0.050
Epoch 40, Loss: 0.028
Epoch 60, Loss: 0.019
Epoch 80, Loss: 0.014
Epoch 100, Loss: 0.010
Epoch 120, Loss: 0.008
Epoch 140, Loss: 0.007
Epoch 160, Loss: 0.005
Epoch 180, Loss: 0.005

Translations:
1 2 3  --> <bos> one two three <eos>
4 5 6  --> <bos> four five six <eos>
9      --> <bos> nine <eos>
