Transformer Architecture

In [206]:
import torch
from torch import nn, optim
import re
from collections import Counter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

eng_sentences = []
jpn_sentences = []
seen = set()

with open("jpn.txt", "r", encoding="utf-8") as f:
    for line in f:
        parts = line.strip().split("\t")

        eng = parts[0].strip().lower()
        jpn = parts[1].strip().lower()

        # Clean English: keep lowercase letters, numbers, and spaces
        eng = re.sub(r"[^a-z0-9\s]", "", eng)

        # Clean Japanese: keep hiragana, katakana, kanji, and Japanese punctuation
        jpn = re.sub(r"[^\u3040-\u30ff\u4e00-\u9fff。、！？\s]", "", jpn)

        if eng not in seen:
            eng_sentences.append(eng)
            jpn_sentences.append(jpn)
            seen.add(eng)

print(f"English Sentences (sample): {eng_sentences[:3]}")
print(f"English Sentences Length: {len(eng_sentences)}")
print(f"Japanese Sentences (sample): {jpn_sentences[:3]}")
print(f"Japanese Sentences Length: {len(jpn_sentences)}")

# English word-level vocab
eng_counter = Counter()
for sent in eng_sentences:
    eng_counter.update(sent.strip().split())

eng_to_ind = {'<pad>': 0, '<unk>': 1}
ind = 2
for word in eng_counter:
    if eng_counter[word] > 2:
        eng_to_ind[word] = ind
        ind += 1

# Japanese char-level vocab
jpn_counter = Counter()
for sent in jpn_sentences:
    jpn_counter.update(list(sent))

jpn_to_ind = {'<pad>': 0, '<unk>': 1, '<bos>': 2, '<eos>': 3}
ind_to_jpn = {0: '<pad>', 1: '<unk>', 2: '<bos>', 3: '<eos>'}
ind = 4
for ch in jpn_counter:
    if jpn_counter[ch] > 2:
        jpn_to_ind[ch] = ind
        ind_to_jpn[ind] = ch
        ind += 1

print("English Vocabulary Size:", len(eng_to_ind))
print("Japanese Vocabulary Size:", len(jpn_to_ind))

eng_encoded = []
for sentence in eng_sentences:
    s = []
    for word in sentence.split():
        if word in eng_to_ind:
            s.append(eng_to_ind[word])
        else:
            s.append(1)
    eng_encoded.append(s)

jpn_encoded = []
for sentence in jpn_sentences:
    s = [2] # 2 is <bos>
    for ch in sentence:
        if ch in jpn_to_ind:
            s.append(jpn_to_ind[ch])
        else:
            s.append(1)   # 1 is <unk>
    s.append(3)           # 3 is <eos>
    jpn_encoded.append(s)

print("Encoded English sentence:", eng_encoded[0:5])
print("Encoded Japanese sentence:", jpn_encoded[0:5])


Using device: cpu
English Sentences (sample): ['go', 'hi', 'run']
English Sentences Length: 94468
Japanese Sentences (sample): ['行け。', 'こんにちは。', '走れ。']
Japanese Sentences Length: 94468
English Vocabulary Size: 6410
Japanese Vocabulary Size: 2020
Encoded English sentence: [[2], [3], [4], [5], [6]]
Encoded Japanese sentence: [[2, 4, 5, 6, 3], [2, 7, 8, 9, 10, 11, 6, 3], [2, 12, 13, 6, 3], [2, 14, 15, 3], [2, 16, 17, 18, 19, 3]]


![Architecture](figures/Encoder_Decoder.png)

In [207]:
class PositionEncoding(nn.Module):
    def __init__(self, d_model, length=200):
        super().__init__()
        PE = torch.zeros(length, d_model)
        position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
        denominator = torch.pow(10000, 2*torch.arange(0, d_model, 2) / d_model)

        PE[:, 0::2] = torch.sin(position * denominator)
        PE[:, 1::2] = torch.cos(position * denominator)

        PE = PE.unsqueeze(0) # (1, length, d_model)
        self.register_buffer("PE", PE) # This is something that I didn't know was necessary. 
        # It lets PE move to the GPU and includes it in the state dict.

    
    def forward(self, x):
        x = x + self.PE[:, :x.size(1)]
        return x

Embed seperately.

In [208]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, nheads, dim_feedforward, dropout=0.2):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nheads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm2 = nn.LayerNorm(d_model)
    
    def forward(self, x, pad_mask):
        attn_output, _ = self.self_attn(x, x, x, key_padding_mask=pad_mask)
        x = self.norm1(x + attn_output)
        
        ff = self.linear2(self.dropout(self.relu(self.linear1(x))))
        x = self.norm2(x + ff)
        return x

In [209]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, nheads, dim_feedforward, dropout=0.2):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nheads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, nheads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm3 = nn.LayerNorm(d_model)
    
    def forward(self, x, enc_out, y_pad_mask=None, cross_pad_mask=None, causal_mask=None):
        attn_output, _ = self.self_attn(x, x, x, attn_mask=causal_mask, 
                                        key_padding_mask=y_pad_mask)
        x = self.norm1(x + attn_output)

        cross_attn_output, _ = self.cross_attn(x, enc_out, enc_out, 
                                               key_padding_mask=cross_pad_mask)
        x = self.norm2(x + cross_attn_output)

        ff = self.linear2(self.dropout(self.relu(self.linear1(x))))
        x = self.norm3(x + ff)
        return x

In [210]:
def padding_mask(seq):
    # Padding Mask to prevent attention from looking at padding.
    return (seq == 0)

def causal_mask(seq_len, device):
    # Creates upper triangular matrix filled with -inf above the diagonal
    return torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1).bool()

In [211]:
class Transformer(nn.Module):
    def __init__(self, d_model, dim_feedforward, heads, num_encoder_lay, 
                 num_decoder_lay, max_len, x_vocab_size, y_vocab_size):
        super().__init__()
        self.encoder_embedding = nn.Embedding(x_vocab_size, d_model, padding_idx=0)
        self.decoder_embedding = nn.Embedding(y_vocab_size, d_model, padding_idx=0)
        
        self.pos_encoder = PositionEncoding(d_model, max_len)
        self.pos_decoder = PositionEncoding(d_model, max_len)

        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, heads, dim_feedforward) for _ in range(num_encoder_lay)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, heads, dim_feedforward) for _ in range(num_decoder_lay)
        ])
        self.output_layer = nn.Linear(d_model, y_vocab_size)
    
    def forward(self, x, y):
        x_pad_mask = padding_mask(x)
        y_pad_mask = padding_mask(y)
        causal = causal_mask(y.size(1), y.device)

        x = self.encoder_embedding(x)
        x = self.pos_encoder(x)
        y = self.decoder_embedding(y)
        y = self.pos_decoder(y)

        for layer in self.encoder_layers:
            x = layer(x, x_pad_mask)
        
        for layer in self.decoder_layers:
            y = layer(y, x, y_pad_mask, x_pad_mask, causal)

        return self.output_layer(y)

In [212]:
from torch.utils.data import Dataset

class TransformerDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


In [213]:
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    x_batch, y_batch = zip(*batch)

    x_batch = pad_sequence(x_batch, batch_first=True, padding_value=0)
    y_batch = pad_sequence(y_batch, batch_first=True, padding_value=0)

    return x_batch, y_batch  # no masks returned


In [214]:
from torch.utils.data import DataLoader

eng_tensors = [torch.tensor(seq, dtype=torch.long) for seq in eng_encoded]
jpn_tensors = [torch.tensor(seq, dtype=torch.long) for seq in jpn_encoded]
dataset = TransformerDataset(eng_tensors, jpn_tensors)

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=collate_fn
)

Custom LR scheduler mirrors original Transformer paper

In [215]:
class TransformerLRScheduler:
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0

    def step(self):
        self.step_num += 1
        lr = self._compute_lr()

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def _compute_lr(self):
        return (self.d_model ** -0.5) * min(
            self.step_num ** -0.5,
            self.step_num * (self.warmup_steps ** -1.5)
        )


In [216]:
epochs = 10
d_model = 180
model = Transformer(d_model=d_model,
                    dim_feedforward=720,
                    heads=6,
                    num_decoder_lay=4,
                    num_encoder_lay=4,
                    max_len=200,
                    x_vocab_size=len(eng_to_ind),
                    y_vocab_size=len(jpn_to_ind)
                    )
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)

scheduler = TransformerLRScheduler(optimizer, d_model)

for epoch in range(epochs):
    total_loss = 0
    for xb, yb in dataloader:
        xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad()
        output = model(xb, yb[:, :-1])
        target = yb[:, 1:]

        output = output.flatten(0, 1)
        target = target.flatten()
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        scheduler.step()

        total_loss += loss.item()
        break
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

torch.save(model.state_dict(), "torch_params.pt")


Epoch 1, Loss: 7.8286
Epoch 2, Loss: 7.6616
Epoch 3, Loss: 7.6526
Epoch 4, Loss: 7.6496
Epoch 5, Loss: 7.6503
Epoch 6, Loss: 7.6302
Epoch 7, Loss: 7.6611
Epoch 8, Loss: 7.6603
Epoch 9, Loss: 7.6585
Epoch 10, Loss: 7.6264


In [217]:
import torch.nn.functional as F
from nltk.translate.bleu_score import sentence_bleu

def translate_beam(model, src_seq, eng_to_ind, jpn_to_ind, ind_to_jpn, max_len=40, beam_width=3):
    model.eval()
    bos_id = jpn_to_ind["<bos>"]
    eos_id = jpn_to_ind["<eos>"]

    # Convert input sentence to tensor
    src_indices = [eng_to_ind.get(tok, 1) for tok in src_seq]  # 1 = <unk>
    src_tensor = torch.tensor([src_indices], dtype=torch.long).to(device)

    # Initial decoder input
    beam = [(0.0, [bos_id])]  # (score, token_id_sequence)
    completed = []

    with torch.no_grad():
        for _ in range(max_len):
            new_beam = []

            for score, seq in beam:
                if seq[-1] == eos_id:
                    completed.append((score, seq))
                    continue

                tgt_tensor = torch.tensor([seq], dtype=torch.long).to(device)  # shape: [1, seq_len]
                out = model(src_tensor, tgt_tensor)  # shape: [1, seq_len, vocab]
                logits = out[:, -1, :]  # take last token output
                log_probs = F.log_softmax(logits, dim=-1)

                topk_log_probs, topk_ids = torch.topk(log_probs, beam_width, dim=-1)

                for i in range(beam_width):
                    word_id = topk_ids[0, i].item()
                    new_score = score + topk_log_probs[0, i].item()
                    new_seq = seq + [word_id]

                    if word_id == eos_id:
                        completed.append((new_score, new_seq))
                    else:
                        new_beam.append((new_score, new_seq))

            if not new_beam:
                break

            # Keep top-k
            beam = sorted(new_beam, key=lambda x: x[0], reverse=True)[:beam_width]

        if not completed:
            completed = beam

        best_seq = sorted(completed, key=lambda x: x[0], reverse=True)[0][1]

    # Convert to Japanese tokens (skip <bos> and <eos>)
    return [ind_to_jpn.get(idx, '<unk>') for idx in best_seq[1:-1]]

In [218]:
eng_sentence = "Lets eat"
tokens = eng_sentence.strip().lower().split()
output_chars = translate_beam(model, tokens, eng_to_ind, jpn_to_ind, ind_to_jpn, beam_width=5)
print("".join(output_chars))


適依跳脅依跳依り脅依種葡列葬脅依製依製依製依製依り痛贈依毒齢適度六列脅依製依繰
