In [77]:
import torch
from torch import nn, optim
import numpy as np
import re
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lists for cleaned sentences
eng_sentences = []
jpn_sentences = []

# To avoid duplicates
seen = set()

with open("jpn.txt", "r", encoding="utf-8") as f:
    for line in f:
        parts = line.strip().split("\t")

        # Extract English and Japanese parts
        eng = parts[0].strip().lower()
        jpn = parts[1].strip().lower()

        # Clean English: keep lowercase letters, numbers, and spaces
        eng = re.sub(r"[^a-z0-9\s]", "", eng)

        # Clean Japanese: keep the punctuation
        jpn = re.sub(r"[^\u3040-\u30ff\u4e00-\u9fff。、！？\s]", "", jpn)
        # Filtering out sentences, max 3 words
        if eng not in seen:
            eng_sentences.append(eng)
            jpn_sentences.append(jpn)
            seen.add(eng)

print(f"English Sentences: {eng_sentences[0:10]}")
print(f"English Sentences Length: {len(eng_sentences)}")
print(f"Japanese Sentences: {jpn_sentences[0:10]}")
print(f"Japanese Sentences Length: {len(jpn_sentences)}")

English Sentences: ['go', 'hi', 'run', 'who', 'wow', 'duck', 'fire', 'help', 'hide', 'jump']
English Sentences Length: 94468
Japanese Sentences: ['行け。', 'こんにちは。', '走れ。', '誰？', 'すごい！', '頭を下げろ！', '火事だ！', '助けて！', '隠れろ。', '飛び越えろ！']
Japanese Sentences Length: 94468


In [78]:
from collections import Counter

counter = Counter()

# Count word frequency
for sent in eng_sentences:
    counter.update(sent.split())
    
# English to index mapping
eng_to_ind = {'<pad>': 0, '<unk>':1}
ind = 2

# Include words that appear 3+ times.
for word in counter:
    if counter[word] > 2:
        eng_to_ind[word] = ind
        ind += 1

jcounter = Counter()
# Count character frequency across Japanese sentences
for sent in jpn_sentences:
    jcounter.update(list(sent))

# Japanese to index and index to Japanese mapping
jpn_to_ind = {'<pad>': 0, '<unk>': 1, '<bos>': 2, '<eos>': 3}
ind_to_jpn = {0: '<pad>', 1: '<unk>', 2: '<bos>', 3: '<eos>'}
ind = 4

# Include characters that appear 3+ times
for word in jcounter:
    if jcounter[word] > 2:
        jpn_to_ind[word] = ind
        ind += 1

print("English Vocabulary Size:", len(eng_to_ind))
print("Japanese Vocabulary Size:", len(jpn_to_ind))

English Vocabulary Size: 6410
Japanese Vocabulary Size: 2020


In [79]:
eng_encoded = []

# Encode each English sentence at word level
# Maximum of 18 words per sentence
for sentence in eng_sentences:
    s = []
    for word in sentence.split():
        if word in eng_to_ind:
            s.append(eng_to_ind[word])
        else:
            s.append(1)
        if len(s) > 17:
            break
    while(len(s) < 18):
        s.append(0) # Padding
    eng_encoded.append(s)

jpn_encoded = []

# Encode each Japanese sentence character level
# Maximum of 45 characters per sentence including <bos> and <eos>
for sentence in jpn_sentences:
    s = [2] # 2 is <bos>
    for ch in sentence:
        if len(s) > 43:
            break # We need room to fit <eos> token.
        if ch in jpn_to_ind:
            s.append(jpn_to_ind[ch])
        else:
            s.append(1)   # 1 is <unk>
    s.append(3)           # 3 is <eos>
    while(len(s) < 45):
        s.append(0)       # 0 is <pad>
    jpn_encoded.append(s)

print(f"English Sentences Encoded Length: {len(eng_encoded)}")
print(f"Japanese Sentences Encoded Length: {len(jpn_encoded)}")

English Sentences Encoded Length: 94468
Japanese Sentences Encoded Length: 94468


int(min(600, 1.6 * eng_vocab_size ** 0.56)) = 216 ~= 215
int(min(600, 1.6 * jpn_vocab_size ** 0.56)) = 114 ~= 115

In [80]:
class Translator(nn.Module):
    def __init__(self, eng_vocab_size, jpn_vocab_size, enc_embed_size, dec_embed_size, hidden_size):
        super().__init__()
        self.enc_embed = nn.Embedding(num_embeddings=eng_vocab_size, embedding_dim=enc_embed_size, padding_idx=0)
        self.encoder = nn.LSTM(enc_embed_size, hidden_size, batch_first=True, num_layers=2, dropout=0.3, bidirectional=True)
        
        self.dec_embed = nn.Embedding(num_embeddings=jpn_vocab_size, embedding_dim=dec_embed_size, padding_idx=0)
        self.decoder = nn.LSTM(dec_embed_size, hidden_size, batch_first=True, num_layers=2, dropout=0.3)

        self.out = nn.Linear(hidden_size, jpn_vocab_size)
    
    def forward(self, x, y):
        x = self.enc_embed(x)
        _, (h, c) = self.encoder(x)

        # Reshape from (num_layers*2, batch, hidden) to (num_layers, 2, batch, hidden)
        h = h.view(self.encoder.num_layers, 2, x.size(0), self.encoder.hidden_size)
        c = c.view(self.encoder.num_layers, 2, x.size(0), self.encoder.hidden_size)

        # Sum forward and backward directions: (num_layers, batch, hidden)
        h = h.sum(dim=1)
        c = c.sum(dim=1)


        y = self.dec_embed(y)
        output, _ = self.decoder(y, (h, c))

        logits = self.out(output)
        return logits

In [81]:
from torch.utils.data import DataLoader, TensorDataset

X_tensor = torch.tensor(eng_encoded)
y_tensor = torch.tensor(jpn_encoded)

train_dataset = TensorDataset(X_tensor, y_tensor)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [None]:
model = Translator(len(eng_to_ind), len(jpn_to_ind), 215, 115, 128).to(device)
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
optim = optim.Adam(model.parameters(), lr=0.001)

epochs = 1
for epoch in range(epochs):
    total_loss = 0

    for xb, yb in train_dataloader:
        xb = xb.to(device)
        yb = yb.to(device)
        optim.zero_grad()

        logits = model(xb, yb[:, :-1])
        targets = yb[:, 1:]

        logits = logits.reshape(-1, logits.size(-1))
        targets = targets.reshape(-1)

        loss = loss_fn(logits, targets)
        loss.backward()
        optim.step()

        total_loss += loss.item() * xb.size(0)
    print(total_loss / len(eng_encoded))

torch.save(model.state_dict(), "torch_params.pt")

0.005156783259691774


In [83]:
import torch.nn.functional as F

def translate_beam(model, src_seq, eng_to_ind, jpn_to_ind, ind_to_jpn, max_len=20, beam_width=3):
    model.eval()

    bos_id = jpn_to_ind["<bos>"]
    eos_id = jpn_to_ind["<eos>"]

    # Convert English words to indices
    src_indices = [eng_to_ind.get(tok, 1) for tok in src_seq]
    src_tensor = torch.tensor([src_indices], dtype=torch.long).to(device)
    print(src_seq)
    print(src_indices)
    print([k for k in src_seq if k not in eng_to_ind])
    with torch.no_grad():
        enc_embed = model.enc_embed(src_tensor)  # (1, T, D)
        _, (h, c) = model.encoder(enc_embed)

        # Combine forward & backward states for bidirectional encoder
        num_layers = model.encoder.num_layers
        batch_size = src_tensor.size(0)
        hidden_size = model.encoder.hidden_size

        # h and c are (num_layers*2, batch, hidden_size)
        h = h.view(num_layers, 2, batch_size, hidden_size).sum(dim=1)  # (num_layers, batch, hidden_size)
        c = c.view(num_layers, 2, batch_size, hidden_size).sum(dim=1)

        # Each beam: (score, sequence, h, c)
        beam = [(0.0, [bos_id], h, c)]
        completed_sequences = []

        for _ in range(max_len):
            new_beam = []

            for score, seq, h_prev, c_prev in beam:
                inputs = torch.tensor([[seq[-1]]], dtype=torch.long).to(device)  # (1, 1)
                dec_embed = model.dec_embed(inputs)  # (1, 1, D)
                out, (h_new, c_new) = model.decoder(dec_embed, (h_prev, c_prev))  # out: (1, 1, H)
                logits = model.out(out[:, -1, :])  # (1, vocab_size)
                log_probs = F.log_softmax(logits, dim=-1)  # (1, vocab_size)

                topk_log_probs, topk_indices = torch.topk(log_probs, beam_width, dim=-1)  # (1, k)

                for i in range(beam_width):
                    word_id = topk_indices[0, i].item()
                    word_score = topk_log_probs[0, i].item()
                    new_seq = seq + [word_id]
                    new_score = score + word_score

                    if word_id == eos_id:
                        completed_sequences.append((new_score, new_seq))
                    else:
                        new_beam.append((new_score, new_seq, h_new, c_new))

            if not new_beam:
                break

            # Keep top-k
            beam = sorted(new_beam, key=lambda x: x[0], reverse=True)[:beam_width]

        # If no sequences finished with <eos>, use current best from beam
        if not completed_sequences:
            completed_sequences = [(score, seq) for score, seq, _, _ in beam]

        completed_sequences.sort(key=lambda x: x[0], reverse=True)
        best_seq = completed_sequences[0][1]

    # Convert to Japanese characters (skip <bos>, <eos>)
    return [ind_to_jpn.get(idx, '<unk>') for idx in best_seq[1:-1]]


In [84]:
eng_sentence = "Tom likes food"
tokens = eng_sentence.strip().lower().split()
output_chars = translate_beam(model, tokens, eng_to_ind, jpn_to_ind, ind_to_jpn, beam_width=5)
print("".join(output_chars))


['tom', 'likes', 'food']
[58, 1102, 878]
[]
<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>
