In [1]:
%%writefile requirements.txt
torch==2.3.0
torchtext==0.18
pandas
sentencepiece
tqdm
wandb
sacrebleu==2.3.1

Writing requirements.txt


In [2]:
!pip install -r requirements.txt

Collecting torch==2.3.0 (from -r requirements.txt (line 1))
  Downloading torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl.metadata (26 kB)
Collecting torchtext==0.18 (from -r requirements.txt (line 2))
  Downloading torchtext-0.18.0-cp310-cp310-manylinux1_x86_64.whl.metadata (7.9 kB)
Collecting sacrebleu==2.3.1 (from -r requirements.txt (line 7))
  Downloading sacrebleu-2.3.1-py3-none-any.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.0/57.0 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.3.0->-r requirements.txt (line 1))
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.3.0->-r requirements.txt (line 1))
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.3.0->-r requirements.txt (li

In [3]:
import os
import shutil

src_dir = "/kaggle/input/de2en-translation/data"
dst_dir = "data"

os.makedirs(dst_dir, exist_ok=True)

for file_name in os.listdir(src_dir):
    src_file = os.path.join(src_dir, file_name)
    dst_file = os.path.join(dst_dir, file_name)
    
    shutil.copy(src_file, dst_file)

In [4]:
os.makedirs("src", exist_ok=True)
os.makedirs("data/train_inference", exist_ok=True)

In [5]:
import wandb
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
wandb_api_key = user_secrets.get_secret("wandb")
wandb.login(key=wandb_api_key)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mkdzhr[0m ([33mkdzhr-hse-university[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

Using code from https://github.com/pytorch/examples/tree/main/language_translation

Using code from https://github.com/pytorch/examples/tree/main/language_translation

In [None]:
%%writefile src/data.py
import torch
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator


class DataIter:
    def __init__(self, src_path, tgt_path):
        with open(src_path, "r") as src_file:
            self.src_list = [line for line in src_file.readlines()]

        with open(tgt_path, "r") as tgt_file:
            self.tgt_list = [line for line in tgt_file.readlines()]

        assert len(self.src_list) == len(self.tgt_list)

    def __len__(self):
        return len(self.src_list)
    
    def __getitem__(self, index):
        return (self.src_list[index], self.tgt_list[index])


class SimpleTokenizer:
    def __call__(self, line):
        return line.split()


def _yield_tokens(iterable_data, tokenizer, index):
    for data in iterable_data:
        yield tokenizer(data[index])


def get_all_lines(path):
    with open(path, "r") as f:
        return [line[:-1] for line in f.readlines()]


def get_data(config):
    special_symbols = {
        "<unk>":0,
        "<pad>":1,
        "<bos>":2,
        "<eos>":3
    }

    src_train_lines = get_all_lines(config.locs.src_train_path)
    tgt_train_lines = get_all_lines(config.locs.tgt_train_path)
    src_val_lines = get_all_lines(config.locs.src_val_path)
    tgt_val_lines = get_all_lines(config.locs.tgt_val_path)

    src_lines = src_train_lines + src_val_lines
    tgt_lines = tgt_train_lines + tgt_val_lines
    
    np.random.seed(52)
    indices = np.arange(len(src_lines))
    np.random.shuffle(indices)
    
    lines = [(src_lines[ind], tgt_lines[ind]) for ind in indices]

    print(lines[:5])
    
    val_size = int(len(src_lines) * config.train.val_size)
    valid_iterator, train_iterator = lines[:val_size], lines[val_size:]
    
    test_iterator = [(line, line) for line in get_all_lines(config.locs.src_test_path)]

    src_tokenizer, tgt_tokenizer = SimpleTokenizer(), SimpleTokenizer()

    src_vocab = build_vocab_from_iterator(
        _yield_tokens(train_iterator, src_tokenizer, 0),
        min_freq=config.train.src_min_freq,
        specials=list(special_symbols.keys()),
        special_first=True
    )

    tgt_vocab = build_vocab_from_iterator(
        _yield_tokens(train_iterator, tgt_tokenizer, 1),
        min_freq=config.train.tgt_min_freq,
        specials=list(special_symbols.keys()),
        special_first=True
    )

    src_vocab.set_default_index(special_symbols["<unk>"])
    tgt_vocab.set_default_index(special_symbols["<unk>"])

    def _seq_transform(*transforms):
        def func(txt_input):
            for transform in transforms:
                txt_input = transform(txt_input)
            return txt_input
        return func

    def _tensor_transform(token_ids):
        return torch.cat(
            (torch.tensor([special_symbols["<bos>"]]),
             torch.tensor(token_ids),
             torch.tensor([special_symbols["<eos>"]]))
        )

    src_lang_transform = _seq_transform(src_tokenizer, src_vocab, _tensor_transform)
    tgt_lang_transform = _seq_transform(tgt_tokenizer, tgt_vocab, _tensor_transform)
    
    def _collate_fn(batch):
        src_batch, tgt_batch = [], []
        for src_sample, tgt_sample in batch:
            src_batch.append(src_lang_transform(src_sample.rstrip("\n")))
            tgt_batch.append(tgt_lang_transform(tgt_sample.rstrip("\n")))

        src_batch = pad_sequence(src_batch, padding_value=special_symbols["<pad>"], batch_first=True)
        tgt_batch = pad_sequence(tgt_batch, padding_value=special_symbols["<pad>"], batch_first=True)
        return src_batch, tgt_batch

    train_dataloader = DataLoader(train_iterator, batch_size=config.train.batch_size, collate_fn=_collate_fn)
    valid_dataloader = DataLoader(valid_iterator, batch_size=config.train.batch_size, collate_fn=_collate_fn)
    test_dataloader = DataLoader(test_iterator, batch_size=config.train.batch_size, collate_fn=_collate_fn)

    return train_dataloader, valid_dataloader, test_dataloader, src_vocab, tgt_vocab, src_lang_transform, tgt_lang_transform, special_symbols, train_iterator, valid_iterator, test_iterator


def generate_square_subsequent_mask(size, device):
    mask = (torch.triu(torch.ones((size, size), device=device)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt, pad_idx, device):
    src_seq_len = src.shape[1]
    tgt_seq_len = tgt.shape[1]

    src_mask = torch.zeros((src_seq_len, src_seq_len),device=device).type(torch.bool)
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len, device)
    
    src_padding_mask = (src == pad_idx)
    tgt_padding_mask = (tgt == pad_idx)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [None]:
%%writefile src/model.py
import math

import torch
from torch.nn import functional as F
from torch import nn

class PositionalEncoding(nn.Module):
    def __init__(
        self,
        emb_size,
        dropout,
        maxlen=5000
    ):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(0)
                     
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding):
        return self.dropout(token_embedding + self.pos_embedding[:, :token_embedding.size(1), :])


def noisy_emb(emb, noise):
    return emb + torch.randn_like(emb) * noise


class Translator(nn.Module):
    def __init__(
            self,
            num_encoder_layers,
            num_decoder_layers,
            embed_size,
            num_heads,
            src_vocab_size,
            tgt_vocab_size,
            dim_feedforward,
            dropout,
            emb_noise,
        ):
        super(Translator, self).__init__()

        self.emb_noise = emb_noise

        self.src_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, embed_size)

        self.pos_enc = PositionalEncoding(embed_size, dropout)

        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )

        self.ff = nn.Linear(embed_size, tgt_vocab_size)

        self._init_weights()

    def _init_weights(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, trg, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):

        src_emb = self.pos_enc(self.src_embedding(src))
        tgt_emb = self.pos_enc(self.tgt_embedding(trg))

        if self.training and self.emb_noise != 0:
            src_emb = noisy_emb(src_emb, self.emb_noise)
            tgt_emb = noisy_emb(tgt_emb, self.emb_noise)
        
        outs = self.transformer(
            src_emb,
            tgt_emb,
            src_mask[0] if len(src_mask.shape) == 3 else src_mask,
            tgt_mask[0] if len(tgt_mask.shape) == 3 else tgt_mask,
            None,
            src_padding_mask,
            tgt_padding_mask,
            memory_key_padding_mask
        )

        return self.ff(outs)

    def encode(self, src, src_mask):
        embed = self.src_embedding(src)
        pos_enc = self.pos_enc(embed)
        return self.transformer.encoder(pos_enc, src_mask)

    def decode(self, tgt, memory, tgt_mask):
        embed = self.tgt_embedding(tgt)
        pos_enc = self.pos_enc(embed)
        return self.transformer.decoder(pos_enc, memory, tgt_mask)


class ParallelTranslator(nn.Module):
    def __init__(
            self,
            num_encoder_layers,
            num_decoder_layers,
            embed_size,
            num_heads,
            src_vocab_size,
            tgt_vocab_size,
            dim_feedforward,
            dropout,
            emb_noise,
        ):
        super().__init__()
        
        self.model = Translator(
            num_encoder_layers,
            num_decoder_layers,
            embed_size,
            num_heads,
            src_vocab_size,
            tgt_vocab_size,
            dim_feedforward,
            dropout,
            emb_noise,
        )
        self.par_model = nn.DataParallel(self.model)
    
    def forward(self, src, trg, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        return self.par_model(src, trg, src_mask.unsqueeze(0).repeat(src.shape[0], 1, 1), tgt_mask.unsqueeze(0).repeat(src.shape[0], 1, 1), src_padding_mask, tgt_padding_mask, memory_key_padding_mask)

    @torch.no_grad()
    def encode(self, src, src_mask):
        self.model.eval()
        return self.model.encode(src, src_mask)

    @torch.no_grad()
    def decode(self, tgt, memory, tgt_mask):
        self.model.eval()
        return self.model.decode(tgt, memory, tgt_mask)

    @torch.no_grad()
    def ff(self, x):
        self.model.eval()
        return self.model.ff(x)

In [None]:
%%writefile main.py
# Using code from https://github.com/pytorch/examples/tree/main/language_translation


import os
import logging
import random
import json
from time import time

import torch
import numpy as np
import sacrebleu
import wandb
from torch.nn import functional as F
from tqdm import tqdm

from src.model import Translator, ParallelTranslator
from src.data import get_data, create_mask, generate_square_subsequent_mask
from argparse import ArgumentParser


BOS_STRING = "<bos>"
EOS_STRING = "<eos>"
PAD_STRING = "<pad>"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# We'll use this common hack to aesthetically access config elements
class DotDict(dict):
    def __getattr__(self, key):
        assert key in self
        value = self.get(key)
        if isinstance(value, dict):
            return DotDict(value)
        return value


class CosineScheduler:
    def __init__(self, optimizer, init_lr=1e-6, warmup_epochs=20, decay_epochs=180, lr=4e-3, decay_lr=1e-6):
        self.optimizer = optimizer
        self.init_lr = init_lr
        self.warmup_epochs = warmup_epochs
        self.decay_epochs = decay_epochs
        self.lr = lr
        self.decay_lr = decay_lr
        self.last_lr = None

    def get_last_lr(self):
        return [self.last_lr]
    
    def update_optimizer(self, lr):
        self.last_lr = lr
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr
    
    def step(self, epoch):
        if epoch < self.warmup_epochs:
            self.update_optimizer(self.init_lr + epoch / self.warmup_epochs * (self.lr - self.init_lr))
        else:
            epoch -= self.warmup_epochs
            self.update_optimizer(self.lr * (1 + np.cos(np.pi * epoch / self.decay_epochs)) / 2)


def greedy_decode_multi(model, src_batch, src_mask, special_symbols, max_length):
    model.eval()
    batch_size = src_batch.shape[0]
    memory = model.encode(src_batch, src_mask)

    tgt_batch = torch.ones((batch_size, 1), dtype=torch.long).fill_(special_symbols[BOS_STRING]).to(DEVICE)

    finished_lines = torch.zeros(batch_size, dtype=torch.bool).to(DEVICE)

    for _ in range(max_length-1):
        tgt_mask = generate_square_subsequent_mask(tgt_batch.size(1), DEVICE)
        output = model.decode(tgt_batch, memory, tgt_mask)
        prob = model.ff(output[:, -1])
        
        next_tokens = prob.argmax(dim=1)
        next_tokens[finished_lines] = special_symbols[PAD_STRING]
        finished_lines |= (next_tokens == special_symbols[EOS_STRING])
        tgt_batch = torch.cat([tgt_batch, next_tokens.unsqueeze(1)], dim=1)
        
        if finished_lines.all():
            break
            
    return tgt_batch


# it's really slow, but smh gets the best score
def beam_decode_naive(model, src, src_mask, special_symbols, max_length, beam_topk, all_topk):
    model.eval()
    memory = model.encode(src, src_mask)
    
    start_token = torch.ones(1, 1).fill_(special_symbols[BOS_STRING]).type(torch.long).to(DEVICE)
    nodes = [(start_token, 0)]    
    for _ in range(max_length - 2):
        nodes_new = list()
        all_ended = True
        for seq, score in nodes:
            if seq[0, -1] == special_symbols[EOS_STRING]:
                nodes_new.append((seq, score))
            else:
                all_ended = False
                tgt_mask = generate_square_subsequent_mask(seq.size(1), DEVICE)
                output = model.decode(seq, memory, tgt_mask)
                logits = model.ff(output[:, -1])
                
                probs = F.log_softmax(logits, dim=1)[0]
                topk_scores, topk_tokens = torch.topk(probs, beam_topk)

                for cur_score, token_num in zip(topk_scores, topk_tokens):
                    token = torch.ones(1, 1).fill_(token_num).type(torch.long).to(DEVICE)
                    nodes_new.append((torch.cat((seq, token), dim=1), score + cur_score))
        if all_ended:
            break
        nodes_new.sort(key=lambda x: x[1], reverse=True)
        nodes = nodes_new[:all_topk]
    return nodes[0][0]


def beam_decode(model, src, src_mask, special_symbols, max_length, beam_topk, all_topk):
    model.eval()
    memory = model.encode(src.repeat(all_topk, 1), src_mask)
    
    tgt_batch = torch.ones(all_topk, 1).fill_(special_symbols[BOS_STRING]).type(torch.long).to(DEVICE)
    scores = torch.zeros(1, 1).type(torch.float64).to(DEVICE)

    best_end_score = None
    end_seq = None

    end_scores = list()
    
    for len_i in range(max_length - 2):
        cur_cnt = scores.shape[0]
        nodes_new = list()
        all_ended = True

        tgt_mask = generate_square_subsequent_mask(tgt_batch.size(1), DEVICE)
        output = model.decode(tgt_batch, memory, tgt_mask)
        logits = model.ff(output[:, -1])
        
        topk_scores, topk_tokens = torch.topk(F.log_softmax(logits[:cur_cnt], dim=1), beam_topk, dim=1)
        topk_scores += scores

        next_size = min(all_topk, beam_topk * cur_cnt)
        new_scores, raw_top_ind = torch.topk(topk_scores.flatten(), next_size)
        top_ind_i, top_ind_j = raw_top_ind // beam_topk, raw_top_ind % beam_topk

        new_scores = new_scores.unsqueeze(1)

        cur_tokens = topk_tokens[top_ind_i, top_ind_j].unsqueeze(1)
        end_mask = cur_tokens == special_symbols[EOS_STRING]
        
        new_tokens = torch.ones(all_topk, 1).fill_(special_symbols[BOS_STRING]).type(torch.long).to(DEVICE)
        new_tokens[:next_size] = topk_tokens[top_ind_i, top_ind_j].unsqueeze(1)
        
        new_batch = torch.ones(all_topk, len_i + 1).fill_(special_symbols[BOS_STRING]).type(torch.long).to(DEVICE)
        new_batch[:next_size] = tgt_batch[top_ind_i]
        
        tgt_batch = torch.cat((new_batch, new_tokens), dim=1)
        scores = new_scores

        end_mask = new_tokens[:next_size] == special_symbols[EOS_STRING]
        if end_mask.any():
            indices = np.arange(scores.shape[0])[end_mask.flatten().cpu()]
            ind_max = indices[scores[end_mask].flatten().argmax().item()]
            assert end_mask[ind_max]
            end_scores += list(scores[end_mask].flatten())
            end_scores = sorted(end_scores, reverse=True)[:all_topk]
            if best_end_score is None or scores[ind_max] > best_end_score:
                best_end_score = scores[ind_max]
                end_seq = tgt_batch[ind_max].unsqueeze(0)
            scores[end_mask] -= 1e30

        comb_scores = sorted(list(scores.flatten()) + end_scores, reverse=True)
        if len(comb_scores) > all_topk:
            threshold = comb_scores[all_topk]
            scores[scores <= threshold] -= 1e30
        
        if best_end_score is not None and best_end_score > scores.max():
            break
            
    ind_scores = scores.argmax()
    res = end_seq if best_end_score is not None and best_end_score > scores[ind_scores] else tgt_batch[ind_scores].unsqueeze(0)
    if res[0, -1] != special_symbols[EOS_STRING]:
        end_token = torch.ones(1, 1).fill_(special_symbols[EOS_STRING]).type(torch.long).to(DEVICE)
        res = torch.cat((res, end_token), dim=1)
    return res


def beam_decode_multi(model, src_batch, src_mask, special_symbols, max_lengths, beam_topk, all_topk):
    max_length = max_lengths.max().item()
    tgt_batch = torch.ones(src_batch.shape[0], max_length).fill_(special_symbols[PAD_STRING]).type(torch.long).to(DEVICE)
    for i in range(src_batch.shape[0]):
        res = beam_decode(model, src_batch[i].unsqueeze(0), src_mask, special_symbols, max_lengths[i], beam_topk, all_topk)
        tgt_batch[i:i+1, :res.shape[1]] = res
    return tgt_batch


@torch.no_grad()
def generate_translations(model, dl, special_symbols, tgt_vocab, gen_method, gen_config):
    translations = list()
    original_sentences = list()
    for src_batch, _ in tqdm(dl):
        src_mask = torch.zeros((src_batch.shape[1], src_batch.shape[1]), dtype=torch.bool)
        
        src_batch = src_batch.to(DEVICE)
        src_mask = src_mask.to(DEVICE)

        max_length = src_batch.shape[1] + 5

        max_lengths = torch.zeros(src_batch.shape[0], dtype=torch.int32)
        for i in range(src_batch.shape[0]):
            cur_max_length = (src_batch[i] == special_symbols[EOS_STRING]).nonzero()
            if cur_max_length.numel() == 0:
                cur_max_length = max_length
            else:
                cur_max_length = cur_max_length[0].item() + 5
            max_lengths[i] = cur_max_length
        
        if gen_method == "greedy":
            tgt_batch = greedy_decode_multi(model, src_batch, src_mask, special_symbols, max_length)
        elif gen_method == "beam":
            tgt_batch = beam_decode_multi(
                model,
                src_batch,
                src_mask,
                special_symbols,
                max_lengths,
                beam_topk=gen_config.beam.beam_topk,
                all_topk=gen_config.beam.all_topk,
            )

        for i, tgt_tokens in enumerate(tgt_batch):
            cur_max_length = max_lengths[i]
            output_as_list = list(tgt_tokens.cpu().numpy())[:cur_max_length]
            output_list_words = filter(lambda elem: elem not in special_symbols, tgt_vocab.lookup_tokens(output_as_list))
            translation = " ".join(output_list_words)
            print(translation)
            translations.append(translation)

    return translations


@torch.no_grad()
def calc_metrics(model, val_dl, tgt_vocab, special_symbols, valid_iterator, gen_method, gen_config):
    model.eval()
    translations = generate_translations(model, val_dl, special_symbols, tgt_vocab, gen_method, gen_config)
    original_sentences = [tgt for src, tgt in valid_iterator]
    for ind in random.choices(range(len(original_sentences)), k=5):
        print(f"original_sentences[{ind}]: {original_sentences[ind]}")
        print(f"translations[{ind}]: {translations[ind]}")

    bleu_scores = [
        sacrebleu.sentence_bleu(translation, [reference]).score
        for translation, reference in zip(translations, original_sentences)
    ]
    
    metric_values = {
        # "BLEU": sacrebleu.corpus_bleu(translations, [[elem] for elem in original_sentences]),
        "BLEU": sum(bleu_scores) / len(bleu_scores),
    }
    return metric_values


@torch.no_grad()
def save_translations(model, dl, tgt_vocab, special_symbols, file_path, gen_method, gen_config):
    model.eval()
    translations = generate_translations(model, dl, special_symbols, tgt_vocab, gen_method, gen_config)
    with open(file_path, "w") as f:
        f.write("\n".join(translations) + "\n")


def inference(config, model_path):
    _, val_dl, test_dl, src_vocab, tgt_vocab, src_transform, _, special_symbols, _, valid_iterator, _ = get_data(config)

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    model = (ParallelTranslator if config.train.parallel else Translator)(
        num_encoder_layers=config.model.enc_layers,
        num_decoder_layers=config.model.dec_layers,
        embed_size=config.model.embed_size,
        num_heads=config.model.attn_heads,
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        dim_feedforward=config.model.dim_feedforward,
        dropout=config.model.dropout,
        emb_noise=config.model.emb_noise,
    ).to(DEVICE)

    model.load_state_dict(torch.load(model_path))
    model.eval()

    # print("Metrics on val set:")
    # print(calc_metrics(model, val_dl, tgt_vocab, special_symbols, valid_iterator, config.gen.inference, config.gen))
    
    output_lines = generate_translations(model, test_dl, special_symbols, tgt_vocab, config.gen.inference, config.gen)
        
    with open(config.locs.tgt_test_path, "w") as file:
        file.write("\n".join(output_lines) + "\n")
    

def train(model, train_dl, loss_fn, optim, scheduler, epoch, special_symbols, config):    
    # Object for accumulating losses
    losses = 0
    iters = len(train_dl)

    model.train()
    for i, (src, tgt) in tqdm(enumerate(train_dl), ascii=True, total=len(train_dl)):
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)
        
        tgt_input = tgt[:, :-1]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, special_symbols["<pad>"], DEVICE)
        logits = model(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
        optim.zero_grad()
        tgt_out = tgt[:, 1:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        
        loss.backward()
        optim.step()
        if scheduler is not None:
            scheduler.step(epoch + i / iters)
        losses += loss.item()

    return losses / len(list(train_dl))


def validate(model, valid_dl, loss_fn, special_symbols):    
    losses = 0
    model.eval()
    for src, tgt in tqdm(valid_dl):
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:, :-1]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, special_symbols["<pad>"], DEVICE)
        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)
        
        tgt_out = tgt[:, 1:]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()

    return losses / len(list(valid_dl))


def main(config, load_weights):
    os.makedirs(config.locs.logging_dir, exist_ok=True)
    logger = logging.getLogger(__name__)
    logging.basicConfig(filename=config.locs.logging_dir + "log.txt", level=logging.INFO)

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger().addHandler(console)

    logging.info(f"Translation task: {config.lang.src} -> {config.lang.tgt}")
    logging.info(f"Using device: {DEVICE}")

    train_dl, valid_dl, test_dl, src_vocab, tgt_vocab, _, _, special_symbols, _, valid_iterator, _ = get_data(config)

    logging.info("Loaded data")

    src_vocab_size = len(src_vocab)
    tgt_vocab_size = len(tgt_vocab)

    logging.info(f"{config.lang.src} vocab size: {src_vocab_size}")
    logging.info(f"{config.lang.tgt} vocab size: {tgt_vocab_size}")

    model = (ParallelTranslator if config.train.parallel else Translator)(
        num_encoder_layers=config.model.enc_layers,
        num_decoder_layers=config.model.dec_layers,
        embed_size=config.model.embed_size,
        num_heads=config.model.attn_heads,
        src_vocab_size=src_vocab_size,
        tgt_vocab_size=tgt_vocab_size,
        dim_feedforward=config.model.dim_feedforward,
        dropout=config.model.dropout,
        emb_noise=config.model.emb_noise,
    ).to(DEVICE)

    if load_weights:
        model.load_state_dict(torch.load(config.locs.model_path))
    
    logging.info("Model created... starting training!")

    loss_fn = torch.nn.CrossEntropyLoss(
        ignore_index=special_symbols["<pad>"],
        label_smoothing=config.train.label_smoothing,
    )

    if config.train.optim.name == "Adam":
        opt_params = config.train.optim.adam
        optim = torch.optim.Adam(
            model.parameters(),
            lr=opt_params.lr,
            betas=(opt_params.beta1,opt_params.beta2),
            eps=opt_params.eps,
            weight_decay=opt_params.weight_decay,
        )
    
    if config.train.scheduler.name == "Cosine":
        sch_params = config.train.scheduler.cosine
        scheduler = CosineScheduler(
            optim,
            init_lr=sch_params.init_lr,
            warmup_epochs=sch_params.warmup_epochs,
            decay_epochs=sch_params.decay_epochs,
            lr=sch_params.lr,
            decay_lr=sch_params.decay_lr,
        )
    else:
        scheduler = None

    best_val_loss = 1e6
    
    for idx, epoch in enumerate(range(1, config.train.epochs+1)):

        start_time = time()
        train_loss = train(model, train_dl, loss_fn, optim, scheduler, epoch, special_symbols, config)
        epoch_time = time() - start_time
        val_loss   = validate(model, valid_dl, loss_fn, special_symbols)
        metrics    = calc_metrics(model, valid_dl, tgt_vocab, special_symbols, valid_iterator, config.gen.train, config.gen)
        
        bleu_score = metrics["BLEU"]

        mod_str = f"e{epoch}_val{val_loss:.3f}_BLEU{bleu_score:.3f}"
        torch.save(model.state_dict(), config.locs.logging_dir + mod_str + ".pt")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            logging.info("New best model, saving...")
            # torch.save(model.state_dict(), config.locs.logging_dir + "best.pt")

        save_path = config.locs.tgt_test_dir + mod_str + ".en"
        save_translations(model, test_dl, tgt_vocab, special_symbols, save_path, config.gen.train, config.gen)
        wandb.save(save_path)

        # torch.save(model.state_dict(), config.locs.logging_dir + "last.pt")

        logger.info(f"Epoch: {epoch}\n\tTrain loss: {train_loss:.3f}\n\tVal loss: {val_loss:.3f}\n\tEpoch time = {epoch_time:.1f} seconds\n\tETA = {epoch_time*(config.train.epochs-idx-1):.1f} seconds\n\tMetrics:{metrics}")

        log_dict = {
            "Epoch": epoch,
            "Train loss": train_loss,
            "Val loss": val_loss,
            "Epoch time": epoch_time,
            **metrics,
        }

        if scheduler is not None:
            log_dict["Learning rate"] = scheduler.get_last_lr()[0]
        
        wandb.log(log_dict)

    torch.save(model.state_dict(), config.locs.logging_dir + "last.pt")


if __name__ == "__main__":
    wandb.finish()

    random.seed(52)
    torch.manual_seed(52)
    
    parser = ArgumentParser(
        prog="Machine Translator training and inference",
    )

    parser.add_argument("--config", type=str, default="config.json")
    parser.add_argument("--exp_name", type=str, default="???")
    parser.add_argument("--inference", action="store_true")
    parser.add_argument("--model_path", type=str, default="NONE")
    parser.add_argument("--load", action="store_true")
    
    args = parser.parse_args()

    with open(args.config) as f:
        config = json.load(f)

    if args.inference:
        inference(DotDict(config), args.model_path)
    else:    
        wandb.init(project='DL_bhw2', config=config, name=args.exp_name)
        config = DotDict(wandb.config)
        DEVICE = torch.device("cuda" if config.train.backend == "gpu" and torch.cuda.is_available() else "cpu")    
        main(config, args.load)


In [None]:
%%writefile config.json
{
    "lang": {
        "src": "de",
        "tgt": "en"
    },
    "locs": {
        "src_train_path": "data/train.de-en.de",
        "src_val_path": "data/val.de-en.de",
        "tgt_train_path": "data/train.de-en.en",
        "tgt_val_path": "data/val.de-en.en",
        "logging_dir": "saves/",
        
        "src_test_path": "data/test1.de-en.de",
        "tgt_test_path": "data/test1.de-en.en",
        "tgt_test_dir": "data/train_inference/"
    },
    "train": {
        "epochs": 50,
        "val_size": 0.1,
        "optim": {
            "name": "Adam",
            "adam": {
                "lr": 3e-4,
                "beta1": 0.9,
                "beta2": 0.98,
                "eps": 1e-9,
                "weight_decay": 0
            },
        },
        "scheduler": {
            "name": "Cosine",
            "cosine": {
                "init_lr": 1e-6,
                "warmup_epochs": 5,
                "decay_epochs": 45,
                "lr": 2e-3,
                "decay_lr": 1e-5
            }
        },
        "batch_size": 128,
        "src_min_freq": 7,
        "tgt_min_freq": 7,
        "backend": "gpu",
        "parallel": true,
        "label_smoothing": 0.1
    },
    "model": {
        "attn_heads": 4,
        "enc_layers": 4,
        "dec_layers": 4,
        "embed_size": 128,
        "dim_feedforward": 256,
        "dropout": 0.1,
        "emb_noise": 0.2
    },
    "gen": {
    "train": "greedy",
    "inference": "beam",
    "greedy": {
        "meow": "meow"
    },
    "beam": {
        "beam_topk": 3,
        "all_topk": 8
    }
    }
}

Writing config.json


In [None]:
!python3 main.py --exp_name ed4_mf7_en02

[34m[1mwandb[0m: Currently logged in as: [33mkdzhr[0m ([33mkdzhr-hse-university[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.19.1
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20250313_074739-rgys72xq[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33med4_mf7_en02[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/kdzhr-hse-university/DL_bhw2[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/kdzhr-hse-university/DL_bhw2/runs/rgys72xq[0m
Translation task: de -> en
Using device: cuda
[('also wenn diese vorhersagen korrekt sind , wird diese lücke sich nicht schließen .', 'so if these predictions are accurate , that gap is not going to close .'), ('wir nehmen das gerüst und bepflanzen es mit zellen ; und hier sehen sie , wie sich die klappensegel öffnen und schließen .', 'we take the scaff

In [None]:
!python3 main.py --inference --model_path saves/last.pt

### Final submission is now located in data/test1.de-en.en

In [None]:
# !python3 main.py --exp_name ed4_mf7_en02_fine --load

New best model, saving...
100%|███████████████████████████████████████████| 24/24 [00:15<00:00,  1.60it/s]
Epoch: 8
	Train loss: 2.992
	Val loss: 3.113
	Epoch time = 239.4 seconds
	ETA = 10054.1 seconds
	Metrics:{'BLEU': 26.432411396040642}
100%|#######################################| 1385/1385 [03:46<00:00,  6.12it/s]
100%|█████████████████████████████████████████| 154/154 [00:13<00:00, 11.73it/s]
100%|█████████████████████████████████████████| 154/154 [01:44<00:00,  1.47it/s]
original_sentences[560]: and i put away the newspaper -- and i was getting on a plane -- and i sat there , and i did something i hadn 't done for a long time -- which is i did nothing .
translations[560]: and i -- i put the newspaper aside -- i just got into an airplane -- and then i sat there and i did something i didn 't do for a long time -- which is nothing .
original_sentences[84]: well , the supreme court considered this 100-years tradition and said , in an opinion written by justice douglas , that the ca

In [None]:
# import shutil

# files = ['main.py', 'src/model.py', 'src/data.py', '2025-03-03/best.pt', 'requirements.txt', '']
# arc = 'submission.zip'

# shutil.make_archive(archive_name.replace('.zip', ''), 'zip', root_dir='.', base_dir=None, files=files_to_archive)

# print(f"Archive created: {archive_name}")