# Transformer Encoder-Decoder for English-to-Japanese Translation

Note: This project is less documented than my other projects because I developed it while implementing a Swin Transformer, which left me little time for detailed documentation. I may add more documentation if time allows.

This project implements a Transformer-based sequence-to-sequence model in PyTorch that translates English sentences into Japanese. The encoder and decoder layers, sinusoidal positional encoding, padding and causal masks, beam search decoding, and the warmup learning-rate scheduler are written by hand; the attention sublayers use PyTorch's `nn.MultiheadAttention`. The goal is to build a deeper understanding of Transformers before tackling the Swin Transformer.

## Architecture
![Architecture](figures/Encoder_Decoder.png)

Image source: Zhang, Aston and Lipton, Zachary C. and Li, Mu and Smola, Alexander J. - https://github.com/d2l-ai/d2l-en

## Preprocessing
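The dataset consists of English–Japanese sentence pairs from the Tatoeba Project (`jpn.txt`). The preprocessing is reused from my earlier NumPy implementation, which used the same dataset, but this version uses the entire dataset instead of a small portion. English is tokenized at the word level and Japanese at the character level; tokens that appear fewer than three times are mapped to `<unk>`.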

```python
import re
from collections import Counter

import torch
from torch import nn, optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

eng_sentences = []
jpn_sentences = []
seen = set()

with open("jpn.txt", "r", encoding="utf-8") as f:
    for line in f:
        parts = line.strip().split("\t")
        if len(parts) < 2:  # skip malformed lines
            continue

        eng = parts[0].strip().lower()
        jpn = parts[1].strip().lower()

        # Clean English: keep lowercase letters, digits, and whitespace
        eng = re.sub(r"[^a-z0-9\s]", "", eng)

        # Clean Japanese: keep hiragana, katakana, kanji, Japanese punctuation, and whitespace
        jpn = re.sub(r"[^\u3040-\u30ff\u4e00-\u9fff。、！？\s]", "", jpn)

        # Keep only the first translation of each English sentence
        if eng not in seen:
            eng_sentences.append(eng)
            jpn_sentences.append(jpn)
            seen.add(eng)

print(f"English Sentences (sample): {eng_sentences[:3]}")
print(f"English Sentences Length: {len(eng_sentences)}")
print(f"Japanese Sentences (sample): {jpn_sentences[:3]}")
print(f"Japanese Sentences Length: {len(jpn_sentences)}")

# English word-level vocab: words seen more than twice get their own index
eng_counter = Counter()
for sent in eng_sentences:
    eng_counter.update(sent.split())

eng_to_ind = {"<pad>": 0, "<unk>": 1}
for word, count in eng_counter.items():
    if count > 2:
        eng_to_ind[word] = len(eng_to_ind)

# Japanese char-level vocab with <bos>/<eos> markers for the decoder
jpn_counter = Counter()
for sent in jpn_sentences:
    jpn_counter.update(list(sent))

jpn_to_ind = {"<pad>": 0, "<unk>": 1, "<bos>": 2, "<eos>": 3}
for ch, count in jpn_counter.items():
    if count > 2:
        jpn_to_ind[ch] = len(jpn_to_ind)
ind_to_jpn = {i: ch for ch, i in jpn_to_ind.items()}

print("English Vocabulary Size:", len(eng_to_ind))
print("Japanese Vocabulary Size:", len(jpn_to_ind))

# Unknown words map to <unk> (1)
eng_encoded = [[eng_to_ind.get(word, 1) for word in sentence.split()]
               for sentence in eng_sentences]

# Each Japanese sentence becomes <bos> (2) + characters + <eos> (3)
jpn_encoded = [[2] + [jpn_to_ind.get(ch, 1) for ch in sentence] + [3]
               for sentence in jpn_sentences]

print("Encoded English sentence:", eng_encoded[0:5])
print("Encoded Japanese sentence:", jpn_encoded[0:5])
```

```
Using device: cuda
English Sentences (sample): ['go', 'hi', 'run']
English Sentences Length: 94468
Japanese Sentences (sample): ['行け。', 'こんにちは。', '走れ。']
Japanese Sentences Length: 94468
English Vocabulary Size: 6410
Japanese Vocabulary Size: 2020
Encoded English sentence: [[2], [3], [4], [5], [6]]
Encoded Japanese sentence: [[2, 4, 5, 6, 3], [2, 7, 8, 9, 10, 11, 6, 3], [2, 12, 13, 6, 3], [2, 14, 15, 3], [2, 16, 17, 18, 19, 3]]
```
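The encoding scheme is easiest to see on a toy corpus. The sketch below uses a hypothetical `build_vocab` helper (not part of the notebook) that applies the same frequency cutoff (`count > 2`, i.e. at least three occurrences) and the same special-token layout as the cell above.

```python
from collections import Counter

def build_vocab(counter, specials, min_count=3):
    # Specials come first, then every token seen at least min_count times,
    # in first-seen order (Counter preserves insertion order)
    vocab = {tok: i for i, tok in enumerate(specials)}
    for tok, count in counter.items():
        if count >= min_count:
            vocab[tok] = len(vocab)
    return vocab

# Toy corpus, made up for illustration
eng = ["go now", "go home", "go"]
jpn = ["行け。", "行こう。", "行く。"]

eng_vocab = build_vocab(Counter(w for s in eng for w in s.split()), ["<pad>", "<unk>"])
jpn_vocab = build_vocab(Counter(ch for s in jpn for ch in s),
                        ["<pad>", "<unk>", "<bos>", "<eos>"])

def encode_eng(sentence):
    return [eng_vocab.get(w, 1) for w in sentence.split()]            # 1 = <unk>

def encode_jpn(sentence):
    return [2] + [jpn_vocab.get(ch, 1) for ch in sentence] + [3]      # <bos> ... <eos>

print(eng_vocab)               # {'<pad>': 0, '<unk>': 1, 'go': 2}
print(encode_eng("go home"))   # [2, 1]  ("home" appears only once, so it is <unk>)
print(encode_jpn("行け。"))     # [2, 4, 1, 5, 3]
```

Only 行 and 。 occur three times in the toy Japanese data, so every other character falls back to `<unk>`.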


## Positional Encoding

Self-attention has no built-in notion of token order. To inject information about token position, sinusoidal positional encodings are added to the input embeddings:

$$PE_{(i,\,2k)} = \sin\left(\frac{i}{10000^{2k/d}}\right), \qquad PE_{(i,\,2k+1)} = \cos\left(\frac{i}{10000^{2k/d}}\right)$$

where $i$ is the token position, $k$ indexes the pair of embedding dimensions $(2k, 2k+1)$, and $d$ is the embedding size (`d_model`).


```python
class PositionEncoding(nn.Module):
    def __init__(self, d_model, length=200):
        super().__init__()

        PE = torch.zeros(length, d_model)  # (length, d_model)

        position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)  # (length, 1)
        # Denominator 10000^(2k/d); arange(0, d_model, 2) already yields the even indices 2k
        denominator = torch.pow(10000.0, torch.arange(0, d_model, 2, dtype=torch.float) / d_model)

        # Even dimensions get sin, odd dimensions get cos of position / denominator
        PE[:, 0::2] = torch.sin(position / denominator)
        PE[:, 1::2] = torch.cos(position / denominator)

        PE = PE.unsqueeze(0)  # (1, length, d_model)
        # register_buffer (something new I learned): PE moves to the GPU with .to(device)
        # and is saved in the state dict, but it is not a trainable parameter.
        self.register_buffer("PE", PE)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the encodings for the first seq_len positions
        x = x + self.PE[:, :x.size(1)]
        return x
```
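A quick way to check a sinusoidal table against the formula is to build a small one and compare entries term by term. This standalone sketch uses a hypothetical helper `sinusoidal_pe` (same construction as the module, without the batch dimension) and assumes an even `d_model`.

```python
import math
import torch

def sinusoidal_pe(length: int, d_model: int) -> torch.Tensor:
    # (length, d_model) table; d_model is assumed to be even
    position = torch.arange(length, dtype=torch.float).unsqueeze(1)                 # (length, 1)
    denominator = torch.pow(10000.0, torch.arange(0, d_model, 2, dtype=torch.float) / d_model)
    pe = torch.zeros(length, d_model)
    pe[:, 0::2] = torch.sin(position / denominator)
    pe[:, 1::2] = torch.cos(position / denominator)
    return pe

d = 8
pe = sinusoidal_pe(length=50, d_model=d)

# Compare a few (position i, pair k) entries with the closed-form expressions
for i, k in [(0, 0), (3, 1), (17, 3)]:
    angle = i / 10000 ** (2 * k / d)
    assert math.isclose(pe[i, 2 * k].item(), math.sin(angle), abs_tol=1e-5)
    assert math.isclose(pe[i, 2 * k + 1].item(), math.cos(angle), abs_tol=1e-5)

# Position 0 is sin(0) = 0 on even dims and cos(0) = 1 on odd dims
print(pe[0].tolist())
```

Every entry lies in $[-1, 1]$, and low dimensions oscillate quickly while high dimensions change slowly, which is what lets the model tell nearby positions apart.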

## Transformer Encoder & Decoder Layers

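Each Transformer layer is built from:
- Multi-head attention: self-attention in the encoder; masked self-attention plus encoder-decoder cross-attention in the decoder. Each head can attend to information at different positions.
- Feedforward network: a two-layer MLP (`Linear` → ReLU → dropout → `Linear`) applied to each position independently.
- Residual connections and layer normalization: each sublayer is wrapped as `LayerNorm(x + sublayer(x))` (post-norm, as in the original Transformer).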


```python
class EncoderLayer(nn.Module):
    def __init__(self, d_model, nheads, dim_feedforward, dropout=0.2):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nheads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)

        # Feedforward network: two linear layers with ReLU and dropout
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, pad_mask):
        # x: (batch, src_len, d_model); pad_mask: (batch, src_len), True = <pad>
        # Self-attention with padding mask to ignore padding tokens
        attn_output, _ = self.self_attn(x, x, x, key_padding_mask=pad_mask)
        x = self.norm1(x + attn_output)  # residual + LayerNorm

        ff = self.linear2(self.dropout(self.relu(self.linear1(x))))
        x = self.norm2(x + ff)
        return x
```
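The effect of `key_padding_mask` can be checked directly on a tiny `nn.MultiheadAttention`. In this sketch the sizes, seed, and token ids are arbitrary; by default the module also returns attention weights averaged over heads, with shape `(batch, query_len, key_len)`.

```python
import torch
from torch import nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(embed_dim=8, num_heads=2, dropout=0.0, batch_first=True)
attn.eval()

# Batch of 2 sequences of length 4; the second one ends with two <pad> (0) tokens
tokens = torch.tensor([[5, 7, 9, 3],
                       [4, 6, 0, 0]])
x = torch.randn(2, 4, 8)
pad_mask = tokens == 0  # True = this key must be ignored

out, weights = attn(x, x, x, key_padding_mask=pad_mask)

# For the padded sequence, the columns of the two <pad> keys are exactly zero
print(weights[1])
```

Padded positions still produce outputs when they act as queries; in this model those outputs are masked again as keys in later layers and in cross-attention, and on the decoder side their predictions are skipped by `ignore_index=0` in the loss.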

```python
class DecoderLayer(nn.Module):
    def __init__(self, d_model, nheads, dim_feedforward, dropout=0.2):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nheads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, nheads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_out, y_pad_mask=None, cross_pad_mask=None, causal_mask=None):
        # 1) Masked self-attention over the target prefix
        attn_output, _ = self.self_attn(x, x, x, attn_mask=causal_mask,
                                        key_padding_mask=y_pad_mask)
        x = self.norm1(x + attn_output)

        # 2) Cross-attention: queries from the decoder, keys/values from the encoder output
        cross_attn_output, _ = self.cross_attn(x, enc_out, enc_out,
                                               key_padding_mask=cross_pad_mask)
        x = self.norm2(x + cross_attn_output)

        # 3) Position-wise feedforward network
        ff = self.linear2(self.dropout(self.relu(self.linear1(x))))
        x = self.norm3(x + ff)
        return x
```

```python
def padding_mask(seq):
    # (batch, seq_len) bool mask: True marks <pad> tokens (index 0) that attention must ignore
    return seq == 0

def causal_mask(seq_len, device):
    # (seq_len, seq_len) bool mask: True above the diagonal blocks attention to future positions.
    # nn.MultiheadAttention treats True in a boolean attn_mask as "not allowed to attend".
    return torch.triu(torch.ones((seq_len, seq_len), device=device), diagonal=1).bool()
```
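To see what the causal mask does, this sketch builds it for a length-4 sequence (a hypothetical variant without the `device` argument) and applies it the way attention does: blocked scores are set to `-inf` before the softmax, so their weights become zero.

```python
import torch

def causal_mask(seq_len: int) -> torch.Tensor:
    # True strictly above the diagonal = position j > i is blocked for query i
    return torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

mask = causal_mask(4)
print(mask.int().tolist())  # [[0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]

# With equal raw scores, each position spreads its attention over itself and earlier positions
scores = torch.zeros(4, 4)
weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
print(weights[1].tolist())  # [0.5, 0.5, 0.0, 0.0]
```

Row 0 can only see position 0, so no row is ever fully masked and the softmax never produces NaNs.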

## Transformer Model

```python
class Transformer(nn.Module):
    def __init__(self, d_model, dim_feedforward, heads, num_encoder_lay,
                 num_decoder_lay, max_len, x_vocab_size, y_vocab_size):
        super().__init__()
        # padding_idx=0: the <pad> embedding starts at zero and receives no gradient
        self.encoder_embedding = nn.Embedding(x_vocab_size, d_model, padding_idx=0)
        self.decoder_embedding = nn.Embedding(y_vocab_size, d_model, padding_idx=0)

        self.pos_encoder = PositionEncoding(d_model, max_len)
        self.pos_decoder = PositionEncoding(d_model, max_len)

        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, heads, dim_feedforward) for _ in range(num_encoder_lay)
        ])
        self.decoder_layers = nn.ModuleList([
            DecoderLayer(d_model, heads, dim_feedforward) for _ in range(num_decoder_lay)
        ])
        self.output_layer = nn.Linear(d_model, y_vocab_size)

    def forward(self, x, y):
        # x: (batch, src_len) English ids; y: (batch, tgt_len) Japanese decoder-input ids
        x_pad_mask = padding_mask(x)
        y_pad_mask = padding_mask(y)
        causal = causal_mask(y.size(1), y.device)

        x = self.encoder_embedding(x)
        x = self.pos_encoder(x)
        y = self.decoder_embedding(y)
        y = self.pos_decoder(y)

        for layer in self.encoder_layers:
            x = layer(x, x_pad_mask)

        for layer in self.decoder_layers:
            y = layer(y, x, y_pad_mask, x_pad_mask, causal)

        # Logits over the Japanese vocabulary: (batch, tgt_len, y_vocab_size)
        return self.output_layer(y)
```

```python
from torch.utils.data import Dataset

class TransformerDataset(Dataset):
    # Holds parallel lists of (English, Japanese) id tensors
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
```


```python
from torch.nn.utils.rnn import pad_sequence

def collate_fn(batch):
    x_batch, y_batch = zip(*batch)

    # Pad every sequence in the batch to the longest one with <pad> (0)
    x_batch = pad_sequence(x_batch, batch_first=True, padding_value=0)
    y_batch = pad_sequence(y_batch, batch_first=True, padding_value=0)

    return x_batch, y_batch  # no masks returned; Transformer.forward builds them
```
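Here is what `pad_sequence` produces for two toy Japanese targets in the `<bos>=2 ... <eos>=3` format, and how the training loop later splits the padded batch into decoder input and target (teacher forcing). The token ids are made up for illustration.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

batch = [torch.tensor([2, 4, 5, 6, 3]), torch.tensor([2, 12, 3])]
padded = pad_sequence(batch, batch_first=True, padding_value=0)
print(padded.tolist())   # [[2, 4, 5, 6, 3], [2, 12, 3, 0, 0]]

# Decoder input drops the last column; the target drops <bos>
dec_in, target = padded[:, :-1], padded[:, 1:]
print(dec_in.tolist())   # [[2, 4, 5, 6], [2, 12, 3, 0]]
print(target.tolist())   # [[4, 5, 6, 3], [12, 3, 0, 0]]

# Only non-pad targets count toward the loss (ignore_index=0)
print((target != 0).sum().item())  # 6
```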


```python
from torch.utils.data import DataLoader

eng_tensors = [torch.tensor(seq, dtype=torch.long) for seq in eng_encoded]
jpn_tensors = [torch.tensor(seq, dtype=torch.long) for seq in jpn_encoded]
dataset = TransformerDataset(eng_tensors, jpn_tensors)

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=collate_fn
)
```

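Custom learning-rate scheduler that mirrors the schedule from the original Transformer paper: the learning rate increases linearly for the first $w$ steps and then decays proportionally to $s^{-1/2}$:

$$\mathrm{lr}(s) = d_{\text{model}}^{-0.5} \cdot \min\left(s^{-0.5},\ s \cdot w^{-1.5}\right)$$

where $s$ is the step number and $w$ is `warmup_steps` (4000 by default).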

```python
class TransformerLRScheduler:
    # Warmup + inverse-square-root decay; call step() after every optimizer.step()
    def __init__(self, optimizer, d_model, warmup_steps=4000):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0

    def step(self):
        self.step_num += 1
        lr = self._compute_lr()

        # Overwrite the learning rate of every parameter group
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def _compute_lr(self):
        return (self.d_model ** -0.5) * min(
            self.step_num ** -0.5,
            self.step_num * (self.warmup_steps ** -1.5)
        )
```
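The shape of the schedule is easy to confirm numerically. This sketch re-implements `_compute_lr` as a plain function (`transformer_lr` is a hypothetical name) with this notebook's `d_model=180` and the default `warmup_steps=4000`.

```python
import math

def transformer_lr(step: int, d_model: int = 180, warmup_steps: int = 4000) -> float:
    # Same formula as TransformerLRScheduler._compute_lr; step starts at 1
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

peak_step = max(range(1, 20001), key=transformer_lr)
peak_lr = transformer_lr(peak_step)
print(peak_step, f"{peak_lr:.3e}")   # 4000 1.179e-03

# At the peak both terms are equal, so lr = (d_model * warmup_steps) ** -0.5
print(math.isclose(peak_lr, 1 / math.sqrt(180 * 4000)))   # True
```

Because the peak is $(d_{\text{model}} \cdot w)^{-1/2}$, a smaller model gets a larger peak learning rate, and the `lr=1e-4` passed to `Adam` only matters for the very first update before the scheduler overwrites it.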

## Training Loop

```python
epochs = 50
d_model = 180
model = Transformer(d_model=d_model,
                    dim_feedforward=720,
                    heads=6,
                    num_decoder_lay=5,
                    num_encoder_lay=5,
                    max_len=200,
                    x_vocab_size=len(eng_to_ind),
                    y_vocab_size=len(jpn_to_ind)
                    ).to(device)
# ignore_index=0 excludes <pad> targets from the loss
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
# Adam settings from the original paper; lr is overwritten by the scheduler after each step
optimizer = optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)

scheduler = TransformerLRScheduler(optimizer, d_model)

past_loss = []
for epoch in range(epochs):
    model.train()
    total_loss = 0
    total_tokens = 0
    for xb, yb in dataloader:
        xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad()
        # Teacher forcing: the decoder sees <bos> ... y_{n-1} and predicts y_1 ... <eos>
        output = model(xb, yb[:, :-1])
        target = yb[:, 1:]

        num_tokens = (target != 0).sum().item()
        total_tokens += num_tokens

        output = output.flatten(0, 1)   # (batch * (tgt_len - 1), vocab)
        target = target.flatten()       # (batch * (tgt_len - 1),)
        loss = loss_fn(output, target)  # mean over non-pad tokens
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Weight by token count so the epoch loss is a per-token average
        total_loss += loss.item() * num_tokens
    print(f"Epoch {epoch+1}, Loss: {total_loss / total_tokens:.4f}")
    # Early stopping: stop if the loss rose for two consecutive epochs
    if len(past_loss) > 5:
        if past_loss[-1] < total_loss and past_loss[-2] < past_loss[-1]:
            break
    past_loss.append(total_loss)

torch.save(model.state_dict(), "torch_params2.pt")
```

```
Epoch 1, Loss: 3.7658
Epoch 2, Loss: 2.2293
Epoch 3, Loss: 1.8411
Epoch 4, Loss: 1.6123
Epoch 5, Loss: 1.4417
Epoch 6, Loss: 1.3264
Epoch 7, Loss: 1.2383
Epoch 8, Loss: 1.1697
Epoch 9, Loss: 1.1112
Epoch 10, Loss: 1.0618
Epoch 11, Loss: 1.0193
Epoch 12, Loss: 0.9825
Epoch 13, Loss: 0.9489
Epoch 14, Loss: 0.9190
Epoch 15, Loss: 0.8911
Epoch 16, Loss: 0.8663
Epoch 17, Loss: 0.8441
Epoch 18, Loss: 0.8223
Epoch 19, Loss: 0.8029
Epoch 20, Loss: 0.7836
Epoch 21, Loss: 0.7675
Epoch 22, Loss: 0.7513
Epoch 23, Loss: 0.7356
Epoch 24, Loss: 0.7215
Epoch 25, Loss: 0.7088
Epoch 26, Loss: 0.6946
Epoch 27, Loss: 0.6835
Epoch 28, Loss: 0.6723
Epoch 29, Loss: 0.6616
Epoch 30, Loss: 0.6522
Epoch 31, Loss: 0.6421
Epoch 32, Loss: 0.6317
Epoch 33, Loss: 0.6217
Epoch 34, Loss: 0.6136
Epoch 35, Loss: 0.6056
Epoch 36, Loss: 0.5974
Epoch 37, Loss: 0.5897
Epoch 38, Loss: 0.5825
Epoch 39, Loss: 0.5747
Epoch 40, Loss: 0.5678
Epoch 41, Loss: 0.5607
Epoch 42, Loss: 0.5546
Epoch 43, Loss: 0.5491
Epoch 44, Loss: 0.54
```
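The loop multiplies each batch's mean loss by its number of non-pad tokens because batches contain different numbers of real tokens. This sketch, on random toy logits, shows that the weighted sum divided by the token total equals the exact per-token average (computed here with `reduction="sum"`), whereas a plain mean of batch losses generally does not.

```python
import torch
from torch import nn

torch.manual_seed(0)
loss_fn = nn.CrossEntropyLoss(ignore_index=0)                   # mean over non-pad targets
sum_fn = nn.CrossEntropyLoss(ignore_index=0, reduction="sum")   # sum over non-pad targets

vocab = 6
batches = [
    (torch.randn(4, vocab), torch.tensor([3, 2, 0, 0])),        # 2 real tokens
    (torch.randn(5, vocab), torch.tensor([1, 4, 5, 2, 0])),     # 4 real tokens
]

total_loss, total_tokens, exact_sum = 0.0, 0, 0.0
for logits, target in batches:
    num_tokens = (target != 0).sum().item()
    total_loss += loss_fn(logits, target).item() * num_tokens   # as in the training loop
    total_tokens += num_tokens
    exact_sum += sum_fn(logits, target).item()

weighted = total_loss / total_tokens
naive = sum(loss_fn(l, t).item() for l, t in batches) / len(batches)
print(weighted, exact_sum / total_tokens, naive)  # first two agree; naive weights batches equally
```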

## Beam Search Implementation

```python
import torch
import torch.nn.functional as F

def translate_beam(model, src_seq, eng_to_ind, jpn_to_ind, ind_to_jpn, max_len=40, beam_width=3):
    model.eval()
    device = next(model.parameters()).device
    bos_id = jpn_to_ind["<bos>"]
    eos_id = jpn_to_ind["<eos>"]

    # Convert the tokenized English sentence to a (1, src_len) tensor; unknown words -> <unk> (1)
    src_indices = [eng_to_ind.get(tok, 1) for tok in src_seq]
    src_tensor = torch.tensor([src_indices], dtype=torch.long, device=device)

    beam = [(0.0, [bos_id])]  # (cumulative log-prob, token ids) of unfinished hypotheses
    completed = []            # hypotheses that emitted <eos>

    with torch.no_grad():
        for _ in range(max_len):
            new_beam = []

            for score, seq in beam:
                # Re-run the model on the full prefix (no caching of past keys/values)
                tgt_tensor = torch.tensor([seq], dtype=torch.long, device=device)  # (1, seq_len)
                out = model(src_tensor, tgt_tensor)                # (1, seq_len, vocab)
                log_probs = F.log_softmax(out[:, -1, :], dim=-1)   # next-token distribution

                topk_log_probs, topk_ids = torch.topk(log_probs, beam_width, dim=-1)

                for i in range(beam_width):
                    word_id = topk_ids[0, i].item()
                    new_score = score + topk_log_probs[0, i].item()
                    new_seq = seq + [word_id]

                    if word_id == eos_id:
                        completed.append((new_score, new_seq))
                    else:
                        new_beam.append((new_score, new_seq))

            if not new_beam:
                break

            # Keep the beam_width highest-scoring unfinished hypotheses
            beam = sorted(new_beam, key=lambda x: x[0], reverse=True)[:beam_width]

    # Fall back to unfinished hypotheses if none reached <eos> within max_len
    candidates = completed if completed else beam
    best_seq = max(candidates, key=lambda x: x[0])[1]

    # Drop <bos>, and drop the last token only if it really is <eos>
    tokens = best_seq[1:]
    if tokens and tokens[-1] == eos_id:
        tokens = tokens[:-1]
    return [ind_to_jpn.get(idx, "<unk>") for idx in tokens]
```
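Why keep several hypotheses instead of decoding greedily? This model-free sketch runs the same keep-top-k / collect-at-`<eos>` procedure over a hand-made next-token table. The `beam_search` function, its `next_log_probs(prefix)` interface, and the probabilities are all hypothetical, chosen so that the greedy choice at step 1 leads to a worse sentence.

```python
import math

def beam_search(next_log_probs, bos, eos, beam_width=2, max_len=10):
    # next_log_probs(prefix) -> {token: log_prob}; stands in for the model
    beam, completed = [(0.0, [bos])], []
    for _ in range(max_len):
        candidates = []
        for score, seq in beam:
            dist = next_log_probs(seq)
            top = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)[:beam_width]
            for tok, lp in top:
                hyp = (score + lp, seq + [tok])
                if tok == eos:
                    completed.append(hyp)
                else:
                    candidates.append(hyp)
        if not candidates:
            break
        beam = sorted(candidates, key=lambda h: h[0], reverse=True)[:beam_width]
    return max(completed or beam, key=lambda h: h[0])

TABLE = {
    ("<s>",): {"A": 0.6, "B": 0.4},
    ("<s>", "A"): {"A": 0.36, "B": 0.34, "</s>": 0.30},
    ("<s>", "B"): {"</s>": 0.9, "A": 0.05, "B": 0.05},
}

def toy_model(prefix):
    # Any prefix not in the table ends the sentence with certainty
    probs = TABLE.get(tuple(prefix), {"</s>": 1.0})
    return {tok: math.log(p) for tok, p in probs.items()}

score, seq = beam_search(toy_model, "<s>", "</s>", beam_width=2)
print(seq, round(math.exp(score), 3))     # ['<s>', 'B', '</s>'] 0.36

g_score, g_seq = beam_search(toy_model, "<s>", "</s>", beam_width=1)  # greedy
print(g_seq, round(math.exp(g_score), 3)) # ['<s>', 'A', 'A', '</s>'] 0.216
```

Scores are summed log-probabilities with no length normalization, as in `translate_beam`; since every extra token adds a non-positive term, this tends to favor shorter outputs.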

## Load Trained Model

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the model with the same hyperparameters used for training
model = Transformer(
    d_model=180,
    dim_feedforward=720,
    heads=6,
    num_decoder_lay=5,
    num_encoder_lay=5,
    max_len=200,
    x_vocab_size=len(eng_to_ind),
    y_vocab_size=len(jpn_to_ind)
).to(device)

state = torch.load("torch_params.pt", map_location=device)

# strict=False reports mismatches instead of raising; both lists should be empty
missing, unexpected = model.load_state_dict(state, strict=False)
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

model.eval()
```

```
Missing keys: []
Unexpected keys: []
```


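Because `PositionEncoding` stores its table with `register_buffer`, the table is part of the state dict, and `load_state_dict` replaces the freshly computed table with the one saved in the checkpoint. A minimal sketch with a hypothetical one-buffer module `WithBuffer`:

```python
import torch
from torch import nn

class WithBuffer(nn.Module):
    # Hypothetical stand-in for PositionEncoding: one buffer, no parameters
    def __init__(self, value: float):
        super().__init__()
        self.register_buffer("PE", torch.full((1, 3, 2), value))

    def forward(self, x):
        return x + self.PE[:, : x.size(1)]

saved = WithBuffer(1.0).state_dict()
print(list(saved.keys()))            # ['PE']: buffers are saved alongside parameters

fresh = WithBuffer(0.0)              # newly constructed table of zeros...
missing, unexpected = fresh.load_state_dict(saved, strict=False)
print(missing, unexpected)           # [] []
print(fresh.PE.flatten().tolist())   # ...overwritten by the saved table: all 1.0
```

In practice this means a checkpoint carries the exact positional table the model was trained with, even if the `PositionEncoding` code changes later.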

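```python
eng_sentence = "This worked pretty well"
# Apply the same cleaning as the training data, then split into words
eng_sentence = re.sub(r"[^a-z0-9\s]", "", eng_sentence.strip().lower())
tokens = eng_sentence.split()
output_chars = translate_beam(model, tokens, eng_to_ind, jpn_to_ind, ind_to_jpn, beam_width=5)
print("".join(output_chars))
```

```
これ、かなりうまく動くよ。
```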


![Translation examples](figures/Translation.png)

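Translation works well for short and medium-length sentences. Longer sentences are translated less accurately, which is expected, likely because most Tatoeba sentence pairs are short and the model sees relatively few long examples during training.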