### Cell 1 — Setup (imports, seed, device)

In [None]:
import math, random, time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

SEED = 42
# Seed both RNGs used in this notebook: Python's `random` (toy data sampling)
# and torch's global generator.
for _seed_fn in (random.seed, torch.manual_seed):
    _seed_fn(SEED)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)


Device: cpu


### Cell 2 — Create a toy “translation” task + vocabulary + encode/decode

In [None]:
# Content words of the toy language.
COLORS = [
    "red", "blue", "green", "yellow",
    "black", "white", "orange", "purple",
]
OBJECTS = [
    "car", "bike", "house", "shirt",
    "phone", "book", "cup", "bag",
]
SEP = "and"  # joins consecutive "<color> <object>" phrases

# Special tokens, in the order they receive ids (so <pad> is id 0).
PAD, BOS, EOS, UNK = "<pad>", "<bos>", "<eos>", "<unk>"
SPECIALS = [PAD, BOS, EOS, UNK]

def make_example(max_pairs=3):
    """Sample one toy (source, target) pair of token lists.

    The source is 1..max_pairs "<color> <object>" phrases separated by SEP; the
    target contains the same phrases with each color/object order swapped.
    Randomness comes from Python's global `random` module.

    Returns:
        (src_tokens, tgt_tokens): two lists of word strings.
    """
    n_pairs = random.randint(1, max_pairs)
    pairs = []
    for _ in range(n_pairs):
        color = random.choice(COLORS)
        obj = random.choice(OBJECTS)
        pairs.append((color, obj))

    src_tokens, tgt_tokens = [], []
    for idx, (color, obj) in enumerate(pairs):
        joiner = [SEP] if idx else []
        src_tokens.extend(joiner + [color, obj])
        tgt_tokens.extend(joiner + [obj, color])
    return src_tokens, tgt_tokens

# Build vocab: special tokens first (so <pad> -> 0), then the sorted content words.
vocab_tokens = SPECIALS + sorted({*COLORS, *OBJECTS, SEP})
stoi = dict(zip(vocab_tokens, range(len(vocab_tokens))))
itos = {idx: tok for tok, idx in stoi.items()}

pad_id, bos_id, eos_id, unk_id = (stoi[special] for special in (PAD, BOS, EOS, UNK))

def encode(tokens):
    """Map word tokens to ids wrapped as [<bos>, ..., <eos>]; unknown words become <unk>."""
    ids = [bos_id]
    ids.extend(stoi.get(tok, unk_id) for tok in tokens)
    ids.append(eos_id)
    return ids

def decode(ids):
    """Map ids back to word tokens.

    <bos> and <pad> are skipped, decoding stops at the first <eos>, and every
    other token (including <unk>) is kept. Accepts ints or 0-d tensors.
    """
    tokens = []
    for raw_id in ids:
        tok = itos[int(raw_id)]
        if tok == EOS:
            break
        if tok not in (BOS, PAD):
            tokens.append(tok)
    return tokens

# Quick sanity check: draw one random example and show it as text and as ids.
src, tgt = make_example()
src_text = " ".join(src)
tgt_text = " ".join(tgt)
print("SRC:", src_text)
print("TGT:", tgt_text)
print("Encoded SRC:", encode(src))


SRC: blue car and black shirt and yellow house
TGT: car blue and shirt black and house yellow
Encoded SRC: [1, 8, 10, 4, 7, 18, 4, 20, 13, 2]


### Cell 3 — Dataset + padding collate + DataLoaders

In [None]:
class ToyTranslationDataset(Dataset):
    """Fixed pool of toy (source, target) examples served as id tensors.

    All examples are sampled up front with `make_example(max_pairs=3)`; each
    item is encoded (with <bos>/<eos>) into a LongTensor when accessed.

    Args:
        n_samples: number of examples to draw.
    """

    def __init__(self, n_samples):
        samples = []
        for _ in range(n_samples):
            samples.append(make_example(max_pairs=3))
        self.data = samples

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _to_ids(tokens):
        # encode() adds <bos>/<eos> and maps unknown words to <unk>.
        return torch.tensor(encode(tokens), dtype=torch.long)

    def __getitem__(self, idx):
        src_tokens, tgt_tokens = self.data[idx]
        return self._to_ids(src_tokens), self._to_ids(tgt_tokens)

def collate_batch(batch):
    """Right-pad a list of (src_ids, tgt_ids) pairs into two (B, L) long tensors.

    Sources and targets are padded independently, each to the longest sequence
    of its kind in the batch, using `pad_id`.
    """
    def _pad_to_longest(seqs):
        # One row per sequence, filled with pad_id, then overwritten with the real ids.
        longest = max(seq.size(0) for seq in seqs)
        padded = torch.full((len(seqs), longest), pad_id, dtype=torch.long)
        for row, seq in enumerate(seqs):
            padded[row, :seq.size(0)] = seq
        return padded

    srcs, tgts = zip(*batch)
    return _pad_to_longest(srcs), _pad_to_longest(tgts)

# Data configuration. Each dataset samples its examples once at construction from
# the seeded `random` module, so the train set must be built before the val set
# to reproduce the same split.
N_TRAIN = 8000
N_VAL = 1000
BATCH_SIZE = 128

train_ds = ToyTranslationDataset(n_samples=N_TRAIN)
val_ds   = ToyTranslationDataset(n_samples=N_VAL)

# Only the training loader shuffles; validation order is fixed.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collate_batch)
val_dl   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

print("Vocab size:", len(stoi))


Vocab size: 21


In [None]:
# Inspect one encoded training pair: (src ids, tgt ids), each wrapped in <bos> ... <eos>.
train_ds[0]

(tensor([ 1,  8,  6,  4, 14, 10,  4, 17,  6,  2]),
 tensor([ 1,  6,  8,  4, 10, 14,  4,  6, 17,  2]))

### Cell 4 — Define the Transformer Encoder–Decoder model

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input, then apply dropout.

        pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: feature size of the inputs. May be odd; the last (sin) column
            then has no cos partner.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length that can be encoded.
    """
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe[:, 0::2] = torch.sin(angles)
        # There are only floor(d_model / 2) odd columns; for odd d_model the last
        # angle has no cos slot, so drop it instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A buffer, not a parameter: follows .to(device) and state_dict but is never trained.
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Expects x as (batch, seq_len, d_model) with seq_len <= max_len."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

def pad_mask(x):
    """Boolean mask of x's shape that is True exactly where x holds the <pad> id."""
    is_pad = x == pad_id
    return is_pad

def subsequent_mask(sz, device):
    """Causal (sz, sz) bool mask: entry [i, j] is True (blocked) when j > i.

    True marks future positions a query must not attend to.
    """
    idx = torch.arange(sz, device=device)
    return idx.unsqueeze(0) > idx.unsqueeze(1)

class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer mapping source ids to target-vocabulary logits.

    Source and target tokens use separate embedding tables but share one
    PositionalEncoding module. All tensors are batch-first.
    """

    def __init__(self, vocab_size, d_model=128, nhead=4, enc_layers=2, dec_layers=2, ff_dim=256, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Registration order is kept as-is: it fixes parameter init order and the printed repr.
        self.src_emb = nn.Embedding(vocab_size, d_model)
        self.tgt_emb = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, dropout)

        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=enc_layers, num_decoder_layers=dec_layers,
            dim_feedforward=ff_dim, dropout=dropout,
            batch_first=True,
        )

        self.out = nn.Linear(d_model, vocab_size)

    def _embed(self, table, ids):
        # Scale token embeddings by sqrt(d_model) before adding positions.
        return self.pos_enc(table(ids) * math.sqrt(self.d_model))

    def forward(self, src, tgt_in):
        """Teacher-forced pass.

        Args:
            src: (batch, src_len) source ids.
            tgt_in: (batch, tgt_len) decoder input ids.

        Returns:
            (batch, tgt_len, vocab_size) logits.
        """
        src_pad = pad_mask(src)
        tgt_pad = pad_mask(tgt_in)
        causal = subsequent_mask(tgt_in.size(1), src.device)

        src_e = self._embed(self.src_emb, src)
        tgt_e = self._embed(self.tgt_emb, tgt_in)

        hidden = self.transformer(
            src_e, tgt_e,
            tgt_mask=causal,
            src_key_padding_mask=src_pad,
            tgt_key_padding_mask=tgt_pad,
            memory_key_padding_mask=src_pad,  # decoder must not attend to padded source positions
        )
        return self.out(hidden)

# Instantiate with default hyperparameters and move the weights to the selected device.
model = Seq2SeqTransformer(len(stoi))
model = model.to(device)
print(model)


Seq2SeqTransformer(
  (src_emb): Embedding(21, 128)
  (tgt_emb): Embedding(21, 128)
  (pos_enc): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
          )
          (linear1): Linear(in_features=128, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=256, out_features=128, bias=True)
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )

### Cell 5 — Train (teacher forcing + loss + accuracy)

In [None]:
criterion = nn.CrossEntropyLoss(ignore_index=pad_id)  # padded target positions contribute no loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def run_epoch(dl, training=True):
    """Run one pass over `dl` with teacher forcing, optionally updating the model.

    The decoder is fed tgt[:, :-1] and scored on predicting tgt[:, 1:].

    Args:
        dl: DataLoader yielding padded (src, tgt) id tensors.
        training: if True, use train mode and step the optimizer; otherwise use
            eval mode with autograd disabled.

    Returns:
        (loss averaged over non-pad target tokens, token accuracy over non-pad targets)
    """
    model.train(training)
    total_loss, total_tokens, correct_tokens = 0.0, 0, 0

    # Evaluation needs no autograd graph; disabling it saves memory and time.
    with torch.set_grad_enabled(training):
        for src, tgt in dl:
            src = src.to(device)
            tgt = tgt.to(device)

            tgt_in  = tgt[:, :-1]   # input to decoder
            tgt_out = tgt[:, 1:]    # what we want to predict

            logits = model(src, tgt_in)  # (B, T, V)
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

            if training:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            mask = tgt_out.ne(pad_id)
            n_tokens = mask.sum().item()

            # `loss` is the mean over this batch's non-pad tokens. Weighting by token
            # count makes the epoch figure a true per-token average even when batches
            # (e.g. the smaller last one) contain different numbers of tokens.
            total_loss += loss.item() * n_tokens
            total_tokens += n_tokens

            preds = logits.argmax(dim=-1)
            correct_tokens += (preds.eq(tgt_out) & mask).sum().item()

    denom = max(1, total_tokens)
    return total_loss / denom, correct_tokens / denom

EPOCHS = 25
for epoch in range(1, EPOCHS + 1):
    # perf_counter is a monotonic, high-resolution clock intended for measuring
    # durations; time.time() can jump if the system clock is adjusted.
    t0 = time.perf_counter()
    train_loss, train_acc = run_epoch(train_dl, training=True)
    val_loss, val_acc     = run_epoch(val_dl,   training=False)
    dt = time.perf_counter() - t0

    print(f"Epoch {epoch:02d} | "
          f"train loss {train_loss:.3f}, acc {train_acc:.3f} | "
          f"val loss {val_loss:.3f}, acc {val_acc:.3f} | {dt:.1f}s")


Epoch 01 | train loss 0.296, acc 0.870 | val loss 0.159, acc 0.938 | 16.7s
Epoch 02 | train loss 0.262, acc 0.886 | val loss 0.148, acc 0.941 | 16.0s
Epoch 03 | train loss 0.249, acc 0.890 | val loss 0.142, acc 0.944 | 16.7s
Epoch 04 | train loss 0.240, acc 0.896 | val loss 0.124, acc 0.956 | 15.6s
Epoch 05 | train loss 0.241, acc 0.895 | val loss 0.125, acc 0.954 | 16.6s
Epoch 06 | train loss 0.229, acc 0.900 | val loss 0.116, acc 0.959 | 16.7s
Epoch 07 | train loss 0.219, acc 0.907 | val loss 0.090, acc 0.979 | 15.8s
Epoch 08 | train loss 0.207, acc 0.913 | val loss 0.080, acc 0.976 | 15.6s
Epoch 09 | train loss 0.198, acc 0.917 | val loss 0.077, acc 0.973 | 15.8s
Epoch 10 | train loss 0.188, acc 0.923 | val loss 0.057, acc 0.988 | 16.3s
Epoch 11 | train loss 0.181, acc 0.926 | val loss 0.050, acc 0.986 | 15.6s
Epoch 12 | train loss 0.166, acc 0.934 | val loss 0.064, acc 0.983 | 15.8s
Epoch 13 | train loss 0.162, acc 0.936 | val loss 0.036, acc 0.993 | 16.1s
Epoch 14 | train loss 0.1

### Cell 6 — Inference (greedy decoding) + test examples

In [None]:
@torch.no_grad()
def greedy_decode(src_tokens, max_len=30):
    """Translate one source token list by greedy (argmax) decoding.

    The source is encoded once. The decoder is then re-run on the growing prefix
    (starting from <bos>), appending the most likely next token each step, and
    stops at <eos> or after `max_len` generated tokens. The model's train/eval
    mode is restored on exit, so calling this mid-training leaves dropout as it was.

    Args:
        src_tokens: list of source words (unknown words map to <unk>).
        max_len: maximum number of tokens to generate.

    Returns:
        Decoded target tokens, with <bos>/<pad> skipped and everything from <eos> on dropped.
    """
    was_training = model.training
    model.eval()
    try:
        src_ids = torch.tensor(encode(src_tokens), dtype=torch.long, device=device).unsqueeze(0)
        src_pad = pad_mask(src_ids)

        # Encoder -> memory, computed once and reused at every decoding step.
        src_e = model.pos_enc(model.src_emb(src_ids) * math.sqrt(model.d_model))
        memory = model.transformer.encoder(src_e, src_key_padding_mask=src_pad)

        ys = torch.tensor([[bos_id]], dtype=torch.long, device=device)
        for _ in range(max_len):
            tgt_e = model.pos_enc(model.tgt_emb(ys) * math.sqrt(model.d_model))
            out = model.transformer.decoder(
                tgt_e, memory,
                tgt_mask=subsequent_mask(ys.size(1), device=device),
                tgt_key_padding_mask=pad_mask(ys),
                memory_key_padding_mask=src_pad,
            )
            # Only the last position's logits choose the next token.
            next_id = model.out(out[:, -1, :]).argmax(dim=-1, keepdim=True)  # (1, 1) long
            ys = torch.cat([ys, next_id], dim=1)
            if next_id.item() == eos_id:
                break
    finally:
        model.train(was_training)

    return decode(ys.squeeze(0))

tests = [
    ["red","car"],
    ["blue","bike"],
    ["yellow","book","and","black","phone"],
    ["white","cup","and","orange","shirt","and","green","phone"]
]

def reference_translation(src_tokens):
    """Ground-truth target for a source built like make_example: swap each color/object pair.

    Assumes the layout "c o [and c o ...]", i.e. phrases start at every third token.
    """
    tgt_tokens = list(src_tokens)
    for i in range(0, len(tgt_tokens) - 1, 3):
        tgt_tokens[i], tgt_tokens[i + 1] = tgt_tokens[i + 1], tgt_tokens[i]
    return tgt_tokens

# Compare each prediction against the known answer instead of relying on eyeballing;
# greedy decoding can reorder phrases in longer inputs.
n_exact = 0
for src_tokens in tests:
    pred = greedy_decode(src_tokens)
    ref = reference_translation(src_tokens)
    is_exact = pred == ref
    n_exact += is_exact
    print("SRC:", " ".join(src_tokens))
    print("PRD:", " ".join(pred))
    print("REF:", " ".join(ref), "(OK)" if is_exact else "(MISMATCH)")
    print("-"*50)
print(f"{n_exact}/{len(tests)} exact matches")


SRC: red car
PRD: car red
--------------------------------------------------
SRC: blue bike
PRD: bike blue
--------------------------------------------------
SRC: yellow book and black phone
PRD: book yellow and phone black
--------------------------------------------------
SRC: white cup and orange shirt and green phone
PRD: cup white and phone orange and shirt green
--------------------------------------------------
