In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import sacrebleu
import warnings
import math
from tqdm import tqdm
from collections import Counter
from torch.utils.data import Dataset, DataLoader

warnings.filterwarnings("ignore", message="Converting mask")

BATCH_SIZE = 64
EMBED_DIM = 512
NUM_HEADS = 8
NUM_LAYERS = 6
FF_DIM = 2048
DROPOUT = 0.1
LEARNING_RATE = 1e-4
N_EPOCHS = 10
CLIP = 1.0
MIN_FREQ = 2
MAX_LEN = 100

UNK_TOKEN = "<unk>"
PAD_TOKEN = "<pad>"
BOS_TOKEN = "<bos>"
EOS_TOKEN = "<eos>"

TRAIN_DE_PATH = "./train.de-en.de"
TRAIN_EN_PATH = "./train.de-en.en"
VAL_DE_PATH   = "./val.de-en.de"
VAL_EN_PATH   = "./val.de-en.en"
TEST_DE_PATH  = "./test1.de-en.de"
OUTPUT_TEST_PATH = "test1.de-en.en"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_data(de_path, en_path=None):
    with open(de_path, 'r', encoding='utf-8') as f_de:
        data_de = [line.strip().split() for line in f_de]
    if en_path is not None:
        with open(en_path, 'r', encoding='utf-8') as f_en:
            data_en = [line.strip().split() for line in f_en]
        return data_de, data_en
    return data_de

train_de, train_en = load_data(TRAIN_DE_PATH, TRAIN_EN_PATH)
val_de, val_en = load_data(VAL_DE_PATH, VAL_EN_PATH)
test_de = load_data(TEST_DE_PATH, None)

def build_vocab(tokenized_texts, min_freq=MIN_FREQ):
    counter = Counter()
    for tokens in tokenized_texts:
        counter.update(tokens)
    specials = [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]
    filtered_tokens = [tok for tok, count in counter.items() if count >= min_freq and tok not in specials]
    filtered_tokens.sort(key=lambda x: (-counter[x], x))
    idx2token = specials + filtered_tokens
    token2idx = {token: idx for idx, token in enumerate(idx2token)}
    return idx2token, token2idx

idx2de, de2idx = build_vocab(train_de, min_freq=MIN_FREQ)
idx2en, en2idx = build_vocab(train_en, min_freq=MIN_FREQ)

pad_idx = en2idx[PAD_TOKEN]
bos_idx = en2idx[BOS_TOKEN]
eos_idx = en2idx[EOS_TOKEN]
unk_idx = en2idx[UNK_TOKEN]

def numericalize(tokens, token2idx):
    return [token2idx.get(tok, token2idx[UNK_TOKEN]) for tok in tokens]

def prepare_data(src_texts, trg_texts, de2idx, en2idx, bos_idx, eos_idx):
    data_src = []
    data_trg = []
    for s, t in zip(src_texts, trg_texts):
        s_idx = numericalize(s, de2idx)
        t_idx = numericalize(t, en2idx)
        t_idx = [bos_idx] + t_idx + [eos_idx]
        data_src.append(s_idx)
        data_trg.append(t_idx)
    return data_src, data_trg

train_src, train_trg = prepare_data(train_de, train_en, de2idx, en2idx, bos_idx, eos_idx)
val_src, val_trg = prepare_data(val_de, val_en, de2idx, en2idx, bos_idx, eos_idx)

class TranslationDataset(Dataset):
    def __init__(self, src_list, trg_list):
        super().__init__()
        self.src_list = src_list
        self.trg_list = trg_list
    def __len__(self):
        return len(self.src_list)
    def __getitem__(self, idx):
        return self.src_list[idx], self.trg_list[idx]

def collate_fn(batch):
    src_list, trg_list = zip(*batch)
    max_src_len = max(len(s) for s in src_list)
    max_trg_len = max(len(t) for t in trg_list)
    padded_src = []
    padded_trg = []
    for s, t in zip(src_list, trg_list):
        s_padded = s + [de2idx[PAD_TOKEN]] * (max_src_len - len(s))
        t_padded = t + [en2idx[PAD_TOKEN]] * (max_trg_len - len(t))
        padded_src.append(s_padded)
        padded_trg.append(t_padded)
    return torch.tensor(padded_src, dtype=torch.long), torch.tensor(padded_trg, dtype=torch.long)

train_dataset = TranslationDataset(train_src, train_trg)
val_dataset = TranslationDataset(val_src, val_trg)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    def forward(self, x):
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

class TransformerTranslator(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size, embed_dim, nheads, num_encoder_layers, num_decoder_layers, ff_dim, dropout, pad_idx):
        super().__init__()
        self.src_vocab_size = src_vocab_size
        self.trg_vocab_size = trg_vocab_size
        self.embed_dim = embed_dim
        self.src_embedding = nn.Embedding(src_vocab_size, embed_dim, padding_idx=pad_idx)
        self.trg_embedding = nn.Embedding(trg_vocab_size, embed_dim, padding_idx=pad_idx)
        self.pos_encoder = PositionalEncoding(embed_dim, dropout)
        self.pos_decoder = PositionalEncoding(embed_dim, dropout)
        self.transformer = nn.Transformer(d_model=embed_dim, nhead=nheads, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers, dim_feedforward=ff_dim, dropout=dropout, batch_first=True)
        self.fc_out = nn.Linear(embed_dim, trg_vocab_size)
        self._init_weights()
    def _init_weights(self):
        nn.init.xavier_uniform_(self.src_embedding.weight)
        nn.init.xavier_uniform_(self.trg_embedding.weight)
        nn.init.xavier_uniform_(self.fc_out.weight)
    def make_src_key_padding_mask(self, src):
        return (src == de2idx[PAD_TOKEN])
    def make_trg_key_padding_mask(self, trg):
        return (trg == en2idx[PAD_TOKEN])
    def forward(self, src, trg):
        src_pad_mask = self.make_src_key_padding_mask(src)
        trg_pad_mask = self.make_trg_key_padding_mask(trg)
        seq_len = trg.size(1)
        nopeak_mask = torch.triu(torch.ones(seq_len, seq_len, device=src.device), diagonal=1).bool()
        embedded_src = self.src_embedding(src) * math.sqrt(self.embed_dim)
        embedded_src = self.pos_encoder(embedded_src)
        embedded_trg = self.trg_embedding(trg) * math.sqrt(self.embed_dim)
        embedded_trg = self.pos_decoder(embedded_trg)
        out = self.transformer(src=embedded_src, tgt=embedded_trg, src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=trg_pad_mask, memory_key_padding_mask=src_pad_mask, tgt_mask=nopeak_mask)
        logits = self.fc_out(out)
        return logits

@torch.no_grad()
def generate_translation(model, src, max_len, bos_idx, eos_idx):
    model.eval()
    src_pad_mask = (src == de2idx[PAD_TOKEN])
    src_embed = model.src_embedding(src) * math.sqrt(model.embed_dim)
    src_embed = model.pos_encoder(src_embed)
    memory = model.transformer.encoder(src_embed, src_key_padding_mask=src_pad_mask)
    batch_size = src.size(0)
    ys = torch.full((batch_size, 1), bos_idx, dtype=torch.long, device=src.device)
    for _ in range(max_len - 1):
        tgt_embed = model.trg_embedding(ys) * math.sqrt(model.embed_dim)
        tgt_embed = model.pos_decoder(tgt_embed)
        nopeak_mask = torch.triu(torch.ones(tgt_embed.size(1), tgt_embed.size(1), device=src.device), diagonal=1).bool()
        tgt_pad_mask = (ys == pad_idx)
        out = model.transformer.decoder(tgt_embed, memory, tgt_mask=nopeak_mask, memory_key_padding_mask=src_pad_mask, tgt_key_padding_mask=tgt_pad_mask)
        prob = model.fc_out(out[:, -1])
        next_word = prob.argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_word], dim=1)
        if (next_word == eos_idx).all():
            break
    return ys

def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0
    all_hyps = []
    all_refs = []
    with torch.no_grad():
        for src_batch, trg_batch in loader:
            src_batch = src_batch.to(device)
            trg_batch = trg_batch.to(device)
            output = model(src_batch, trg_batch[:, :-1])
            loss = criterion(output.reshape(-1, model.trg_vocab_size), trg_batch[:, 1:].reshape(-1))
            total_loss += loss.item()
            gen = generate_translation(model, src_batch, MAX_LEN, bos_idx, eos_idx)
            for pred_seq, ref_seq in zip(gen, trg_batch):
                pred_tokens = []
                for idx in pred_seq.tolist():
                    if idx == eos_idx:
                        break
                    if idx not in [bos_idx, pad_idx, eos_idx]:
                        pred_tokens.append(idx2en[idx])
                all_hyps.append(" ".join(pred_tokens))
                ref_tokens = []
                for idx in ref_seq.tolist():
                    if idx == eos_idx:
                        break
                    if idx not in [bos_idx, pad_idx, eos_idx]:
                        ref_tokens.append(idx2en[idx])
                all_refs.append(" ".join(ref_tokens))
    avg_loss = total_loss / len(loader)
    bleu_score = sacrebleu.corpus_bleu(all_hyps, [all_refs], tokenize='none').score
    return avg_loss, bleu_score

def train_model(model, train_loader, val_loader, epochs):
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    best_bleu = 0.0
    for epoch in range(1, epochs + 1):
        model.train()
        total_train_loss = 0.0
        for src_batch, trg_batch in tqdm(train_loader, desc=f"[Epoch {epoch}/{epochs}]"):
            src_batch = src_batch.to(device)
            trg_batch = trg_batch.to(device)
            optimizer.zero_grad()
            output = model(src_batch, trg_batch[:, :-1])
            loss = criterion(output.reshape(-1, model.trg_vocab_size), trg_batch[:, 1:].reshape(-1))
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), CLIP)
            optimizer.step()
            total_train_loss += loss.item()
        avg_train_loss = total_train_loss / len(train_loader)
        val_loss, val_bleu = evaluate(model, val_loader, criterion)
        print(f"\nEpoch {epoch}/{epochs}:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val   Loss: {val_loss:.4f}")
        print(f"  Val   BLEU: {val_bleu:.2f}\n")
        if val_bleu > best_bleu:
            best_bleu = val_bleu
    print(f"Training completed. Best Val BLEU: {best_bleu:.2f}")

def main():
    print("Initializing model from scratch...")
    model = TransformerTranslator(src_vocab_size=len(idx2de), trg_vocab_size=len(idx2en), embed_dim=EMBED_DIM, nheads=NUM_HEADS, num_encoder_layers=NUM_LAYERS, num_decoder_layers=NUM_LAYERS, ff_dim=FF_DIM, dropout=DROPOUT, pad_idx=pad_idx).to(device)
    print("Starting training from scratch (no old checkpoints)...")
    train_model(model, train_loader, val_loader, N_EPOCHS)
    print("Generating translations for test data...")
    test_src_list = []
    for sent in test_de:
        idxs = [de2idx.get(tok, unk_idx) for tok in sent]
        test_src_list.append(idxs)
    all_preds = []
    model.eval()
    with torch.no_grad():
        for src_idxs in test_src_list:
            src_tensor = torch.tensor([src_idxs], dtype=torch.long, device=device)
            gen_seq = generate_translation(model, src_tensor, MAX_LEN, bos_idx, eos_idx)[0]
            pred_tokens = []
            for idx in gen_seq.tolist():
                if idx == eos_idx:
                    break
                if idx not in [bos_idx, pad_idx, eos_idx]:
                    pred_tokens.append(idx2en[idx])
            all_preds.append(" ".join(pred_tokens))
    with open(OUTPUT_TEST_PATH, "w", encoding="utf-8") as f:
        for line in all_preds:
            f.write(line + "\n")
    print(f"Done. Translations are in '{OUTPUT_TEST_PATH}'.")

if __name__ == "__main__":
    main()


Initializing model from scratch...
Starting training from scratch (no old checkpoints)...


[Epoch 1/10]: 100%|██████████| 3062/3062 [06:27<00:00,  7.90it/s]
That's 100 lines that end in a tokenized period ('.')
It looks like you forgot to detokenize your test data, which may hurt your score.
If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter.



Epoch 1/10:
  Train Loss: 4.9388
  Val   Loss: 4.0904
  Val   BLEU: 2.99



[Epoch 2/10]: 100%|██████████| 3062/3062 [06:29<00:00,  7.86it/s]
That's 100 lines that end in a tokenized period ('.')
It looks like you forgot to detokenize your test data, which may hurt your score.
If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter.



Epoch 2/10:
  Train Loss: 3.6802
  Val   Loss: 2.9951
  Val   BLEU: 15.83



[Epoch 3/10]: 100%|██████████| 3062/3062 [06:29<00:00,  7.87it/s]
That's 100 lines that end in a tokenized period ('.')
It looks like you forgot to detokenize your test data, which may hurt your score.
If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter.



Epoch 3/10:
  Train Loss: 2.8969
  Val   Loss: 2.4954
  Val   BLEU: 24.53



[Epoch 4/10]: 100%|██████████| 3062/3062 [06:28<00:00,  7.89it/s]
That's 100 lines that end in a tokenized period ('.')
It looks like you forgot to detokenize your test data, which may hurt your score.
If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter.



Epoch 4/10:
  Train Loss: 2.4556
  Val   Loss: 2.2875
  Val   BLEU: 28.34



[Epoch 5/10]: 100%|██████████| 3062/3062 [06:28<00:00,  7.88it/s]
That's 100 lines that end in a tokenized period ('.')
It looks like you forgot to detokenize your test data, which may hurt your score.
If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter.



Epoch 5/10:
  Train Loss: 2.1711
  Val   Loss: 2.1647
  Val   BLEU: 30.70



[Epoch 6/10]: 100%|██████████| 3062/3062 [06:27<00:00,  7.89it/s]
That's 100 lines that end in a tokenized period ('.')
It looks like you forgot to detokenize your test data, which may hurt your score.
If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter.



Epoch 6/10:
  Train Loss: 1.9660
  Val   Loss: 2.0939
  Val   BLEU: 31.32



[Epoch 7/10]: 100%|██████████| 3062/3062 [06:29<00:00,  7.87it/s]
That's 100 lines that end in a tokenized period ('.')
It looks like you forgot to detokenize your test data, which may hurt your score.
If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter.



Epoch 7/10:
  Train Loss: 1.8062
  Val   Loss: 2.0312
  Val   BLEU: 32.39



[Epoch 8/10]: 100%|██████████| 3062/3062 [06:29<00:00,  7.87it/s]
That's 100 lines that end in a tokenized period ('.')
It looks like you forgot to detokenize your test data, which may hurt your score.
If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter.



Epoch 8/10:
  Train Loss: 1.6727
  Val   Loss: 2.0274
  Val   BLEU: 33.33



[Epoch 9/10]: 100%|██████████| 3062/3062 [06:28<00:00,  7.89it/s]
That's 100 lines that end in a tokenized period ('.')
It looks like you forgot to detokenize your test data, which may hurt your score.
If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter.



Epoch 9/10:
  Train Loss: 1.5570
  Val   Loss: 2.0076
  Val   BLEU: 33.44



[Epoch 10/10]: 100%|██████████| 3062/3062 [06:28<00:00,  7.88it/s]
That's 100 lines that end in a tokenized period ('.')
It looks like you forgot to detokenize your test data, which may hurt your score.
If you insist your data is detokenized, or don't care, you can suppress this message with the `force` parameter.



Epoch 10/10:
  Train Loss: 1.4525
  Val   Loss: 2.0378
  Val   BLEU: 33.13

Training completed. Best Val BLEU: 33.44
Generating translations for test data...
Done. Translations are in 'test1.de-en.en'.


In [13]:
import sys
print(sys.executable)


/home/user/myenv/bin/python
