In [None]:
pip install sacrebleu wandb

W&B training report (BHW-2 Final): https://wandb.ai/artempotarusov-/bhw-checkpoint/reports/BHW-2-Final--VmlldzoxMTc1OTY2Ng

In [None]:
import os
import zipfile
import math
import time
import torch
import torchtext
import wandb
import sacrebleu
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchtext.vocab import build_vocab_from_iterator
from collections import Counter
from tqdm import tqdm

In [None]:
def load_data_from_zip(file_path):
    """Read a UTF-8 text file and return its lines.

    Despite the name, this opens an already-extracted plain text file (no zip
    handling). Leading/trailing whitespace of the whole file is stripped before
    splitting on newlines, so trailing blank lines are dropped.
    """
    with open(file_path, mode="r", encoding="utf-8") as handle:
        text = handle.read().strip()
    lines = text.split("\n")
    return lines

def yield_tokens_from_lines(lines):
    """Lazily yield the whitespace-split token list of each line in ``lines``."""
    yield from (line.split() for line in lines)

def collate_fn(batch):
    """Pad a list of ``(src, tgt)`` 1-D tensors into two batch-first tensors.

    Padding ids are read from the module-level ``src_vocab`` / ``tgt_vocab``
    (their ``"<pad>"`` entries).
    """
    sources, targets = zip(*batch)
    pad = nn.utils.rnn.pad_sequence
    padded_src = pad(sources, batch_first=True, padding_value=src_vocab["<pad>"])
    padded_tgt = pad(targets, batch_first=True, padding_value=tgt_vocab["<pad>"])
    return padded_src, padded_tgt

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to batch-first embeddings, then apply dropout.

    Even feature indices get ``sin(pos * w_i)`` and odd ones ``cos(pos * w_i)``, with
    ``w_i = 10000 ** (-2i / d_model)``. An odd ``d_model`` is supported: the last sine
    column simply has no cosine partner.

    Args:
        d_model: embedding width (last dimension of the input).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0)/d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Leading dim of 1 broadcasts over the batch; a buffer follows .to(device) and is saved in state_dict.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add encodings for the first ``x.size(1)`` positions to ``x`` of shape (batch, seq, d_model)."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

In [None]:
class TranslationDataset(Dataset):
    """Map-style dataset of whitespace-tokenized sentence pairs.

    Source sentences are numericalized as-is. Target sentences, when given, are
    wrapped in ``<bos>`` ... ``<eos>``. Items are ``(src, tgt)`` int64 tensors, or
    only ``src`` when ``tgt_sentences`` is None.
    """

    def __init__(self, src_sentences, tgt_sentences=None, src_vocab=None, tgt_vocab=None):
        self.src_sentences = src_sentences
        self.tgt_sentences = tgt_sentences
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def numericalize(self, sentence, vocab, add_tokens=False):
        """Split ``sentence`` on whitespace and map each token to ``vocab[token]``.

        When ``add_tokens`` is True the token list is wrapped in ``<bos>``/``<eos>`` first.
        """
        tokens = sentence.split()
        if add_tokens:
            tokens = ["<bos>"] + tokens + ["<eos>"]
        return [vocab[tok] for tok in tokens]

    def __len__(self):
        return len(self.src_sentences)

    def __getitem__(self, idx):
        # dtype is explicit: torch.tensor([]) defaults to float32, so an empty line would
        # yield a float tensor that can turn a padded batch into float (rejected by nn.Embedding).
        src = torch.tensor(self.numericalize(self.src_sentences[idx], self.src_vocab, add_tokens=False),
                           dtype=torch.long)
        if self.tgt_sentences is None:
            return src
        tgt = torch.tensor(self.numericalize(self.tgt_sentences[idx], self.tgt_vocab, add_tokens=True),
                           dtype=torch.long)
        return src, tgt

In [None]:
class TransformerMT(nn.Module):
    """Encoder-decoder Transformer for translation.

    Public inputs ``src`` / ``tgt`` are batch-first ``(batch, seq)`` token-id tensors.
    Internally they are transposed to the sequence-first layout that ``nn.Transformer``
    uses here (it is constructed without ``batch_first=True``). ``encode``/``decode``
    therefore return sequence-first tensors, and ``forward`` returns batch-first
    logits of shape ``(batch, tgt_len, tgt_vocab_size)``.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, src_pad_idx, tgt_pad_idx,
                 d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6,
                 dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.src_pad_idx = src_pad_idx
        self.tgt_pad_idx = tgt_pad_idx
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        self.pos_decoder = PositionalEncoding(d_model, dropout)
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers,
                                          num_decoder_layers, dim_feedforward, dropout)
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def generate_square_subsequent_mask(self, sz):
        """Return a (sz, sz) float causal mask: 0.0 on/below the diagonal, -inf above it."""
        return torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)

    def _embed(self, tokens, embedding, positional):
        """Scale embeddings by sqrt(d_model), add positions, and transpose to (seq, batch, d_model)."""
        return positional(embedding(tokens) * math.sqrt(self.d_model)).transpose(0, 1)

    def encode(self, src, src_mask):
        """Encode ``src`` (batch, src_len) into memory of shape (src_len, batch, d_model).

        ``src_mask`` is accepted for call-site compatibility but is not used. Padding
        is masked via ``src_pad_idx``.
        """
        src_padding_mask = src == self.src_pad_idx
        src_emb = self._embed(src, self.src_embedding, self.pos_encoder)
        return self.transformer.encoder(src_emb, src_key_padding_mask=src_padding_mask)

    def decode(self, tgt, memory, tgt_mask, memory_key_padding_mask=None):
        """Run the decoder over ``tgt`` (batch, tgt_len) and return (tgt_len, batch, d_model).

        Args:
            tgt: target-side token ids; positions equal to ``tgt_pad_idx`` are masked.
            memory: encoder output from :meth:`encode`.
            tgt_mask: causal mask, e.g. from :meth:`generate_square_subsequent_mask`.
            memory_key_padding_mask: optional (batch, src_len) bool mask of padded source
                positions. Pass it when ``memory`` comes from a padded batch so that
                cross-attention ignores those positions.
        """
        tgt_padding_mask = tgt == self.tgt_pad_idx
        tgt_emb = self._embed(tgt, self.tgt_embedding, self.pos_decoder)
        return self.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask,
                                        tgt_key_padding_mask=tgt_padding_mask,
                                        memory_key_padding_mask=memory_key_padding_mask)

    def forward(self, src, tgt):
        """Teacher-forced pass: return logits of shape (batch, tgt_len, tgt_vocab_size)."""
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        memory = self.encode(src, None)
        out = self.decode(tgt, memory, tgt_mask,
                          memory_key_padding_mask=src == self.src_pad_idx)
        return self.fc_out(out).transpose(0, 1)

In [None]:
@torch.no_grad()
def greedy_decode_new(model, src, src_mask, max_len, start_symbol, end_symbol):
    """Greedily decode one source sentence.

    Args:
        model: TransformerMT, optionally wrapped (its ``.module`` is used when present).
        src: 1-D tensor of source token ids; a batch dimension of 1 is added here.
        src_mask: forwarded to ``model.encode``.
        max_len: maximum length of the returned sequence, counting ``start_symbol``.
        start_symbol: id that seeds decoding.
        end_symbol: id that stops decoding once produced.

    Returns:
        1-D long tensor that starts with ``start_symbol`` and ends with ``end_symbol``
        if it was generated within ``max_len``.
    """
    net = getattr(model, "module", model)
    dev = src.device
    memory = net.encode(src.unsqueeze(0), src_mask.to(dev))
    generated = torch.tensor([[start_symbol]], dtype=torch.long, device=dev)
    for _step in range(1, max_len):
        causal_mask = net.generate_square_subsequent_mask(generated.size(1)).to(dev)
        # TransformerMT.decode returns sequence-first output; make it batch-first to read the newest position.
        hidden = net.decode(generated, memory, causal_mask).transpose(0, 1)
        token = net.fc_out(hidden[:, -1, :]).max(dim=1).indices.item()
        generated = torch.cat([generated, torch.tensor([[token]], device=dev)], dim=1)
        if token == end_symbol:
            break
    return generated.squeeze(0)

In [None]:
def train_epoch(model, dataloader, optimizer, criterion, scheduler, clip=1.0):
    """Run one teacher-forced training pass over ``dataloader``.

    Each target batch is split into decoder input (all but the last token) and
    labels (all but the first token). Gradients are clipped to L2 norm ``clip``,
    the optimizer steps, and then ``scheduler.step()`` is called. Its return value
    is logged to W&B as ``lr`` next to the batch loss.

    Relies on the module-level ``device``.

    Returns:
        Mean of the per-batch losses (not token-weighted).
    """
    model.train()
    total_loss = 0
    for src_batch, tgt_batch in tqdm(dataloader, desc="Training"):
        src_batch, tgt_batch = src_batch.to(device), tgt_batch.to(device)
        tgt_input = tgt_batch[:, :-1]
        tgt_output = tgt_batch[:, 1:]
        optimizer.zero_grad()
        output = model(src_batch, tgt_input)
        loss = criterion(output.reshape(-1, output.shape[-1]), tgt_output.reshape(-1))
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        lr = scheduler.step()
        # Read the scalar once: each .item() is a host-device sync.
        batch_loss = loss.item()
        total_loss += batch_loss
        wandb.log({"batch_loss": batch_loss, "lr": lr})
    return total_loss / len(dataloader)

In [None]:
def evaluate(model, dataloader, criterion):
    """Compute the mean per-batch teacher-forced loss over ``dataloader`` without gradients.

    Puts ``model`` in eval mode and leaves it there. Relies on the module-level ``device``.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for src_batch, tgt_batch in tqdm(dataloader, desc="Evaluating"):
            src_batch = src_batch.to(device)
            tgt_batch = tgt_batch.to(device)
            decoder_in, labels = tgt_batch[:, :-1], tgt_batch[:, 1:]
            logits = model(src_batch, decoder_in)
            vocab_size = logits.shape[-1]
            batch_loss = criterion(logits.reshape(-1, vocab_size), labels.reshape(-1))
            batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(dataloader)

In [None]:
class WarmupInverseSquareRootSchedule:
    """Linear-warmup / inverse-square-root learning-rate schedule.

    ``lr(step) = d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)``

    Every call to :meth:`step` advances the step counter and writes the new rate
    into all param groups of ``optimizer``, overwriting the optimizer's own ``lr``.
    With ``warmup_steps <= 0`` there is no warmup and the rate is
    ``d_model ** -0.5 * step ** -0.5``.
    """

    def __init__(self, optimizer, warmup_steps, d_model):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.d_model = d_model
        self.current_step = 0
        self.scale = d_model ** (-0.5)

    def _lr_at(self, step):
        """Learning rate for 1-based ``step``."""
        inv_sqrt = step ** (-0.5)
        if self.warmup_steps <= 0:
            # No warmup; avoids 0 ** -1.5 (ZeroDivisionError).
            return self.scale * inv_sqrt
        return self.scale * min(inv_sqrt, step * (self.warmup_steps ** (-1.5)))

    def step(self):
        """Advance one step, apply the new rate to the optimizer, and return it."""
        self.current_step += 1
        lr = self._lr_at(self.current_step)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def state_dict(self):
        """Return the schedule state so a checkpoint can resume without restarting warmup."""
        return {"current_step": self.current_step,
                "warmup_steps": self.warmup_steps,
                "d_model": self.d_model}

    def load_state_dict(self, state_dict):
        """Restore state produced by :meth:`state_dict`."""
        self.current_step = state_dict["current_step"]
        self.warmup_steps = state_dict.get("warmup_steps", self.warmup_steps)
        self.d_model = state_dict.get("d_model", self.d_model)
        self.scale = self.d_model ** (-0.5)

In [None]:
# Start the W&B run. `config` mirrors these values and is read by the cells below.
# NOTE: "lr" is logged but not read by any cell here: the optimizer is created with
# lr=0 and WarmupInverseSquareRootSchedule sets the learning rate on every step.
wandb.init(
    project="bhw-checkpoint",
    config=dict(
        d_model=512,
        nhead=8,
        num_encoder_layers=4,
        num_decoder_layers=4,
        dim_feedforward=512,
        dropout=0.2,
        batch_size=64,
        epochs=20,
        lr=3e-4,
        warmup_steps=10000,
    ),
)
config = wandb.config

In [None]:
# Load the parallel corpus: the `.de` files are the source side, the `.en` files the target side.
DATA_DIR = "/home/jupyter/datasphere/datasets/bhw2"  # change when running outside DataSphere
train_src_lines = load_data_from_zip(os.path.join(DATA_DIR, "train.de-en.de"))
train_tgt_lines = load_data_from_zip(os.path.join(DATA_DIR, "train.de-en.en"))
valid_src_lines = load_data_from_zip(os.path.join(DATA_DIR, "val.de-en.de"))
valid_tgt_lines = load_data_from_zip(os.path.join(DATA_DIR, "val.de-en.en"))
test_src_lines = load_data_from_zip(os.path.join(DATA_DIR, "test1.de-en.de"))

# TranslationDataset pairs source and target lines by index, so both sides must be the same length.
assert len(train_src_lines) == len(train_tgt_lines), "train source/target line counts differ"
assert len(valid_src_lines) == len(valid_tgt_lines), "validation source/target line counts differ"

# Expected id of each special token. With special_first=True the specials get ids 0..3 in `specials` order.
special_symbols = {
        "<unk>":2,
        "<pad>":3,
        "<bos>":0,
        "<eos>":1
    }
specials = ["<bos>", "<eos>", "<unk>", "<pad>"]
src_vocab = build_vocab_from_iterator(yield_tokens_from_lines(train_src_lines),
                                      min_freq=10,
                                      specials=specials,
                                      special_first=True)
tgt_vocab = build_vocab_from_iterator(yield_tokens_from_lines(train_tgt_lines),
                                      min_freq=10,
                                      specials=specials,
                                      special_first=True)
# Sanity check that both vocabularies place the specials where special_symbols says.
assert all(vocab[tok] == idx
           for vocab in (src_vocab, tgt_vocab)
           for tok, idx in special_symbols.items()), "special-token ids differ from special_symbols"

# Map out-of-vocabulary tokens to <unk> instead of raising.
src_vocab.set_default_index(src_vocab["<unk>"])
tgt_vocab.set_default_index(tgt_vocab["<unk>"])

inv_tgt_vocab = {idx: token for idx, token in enumerate(tgt_vocab.get_itos())}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Train/validation datasets; collate_fn pads each batch to its longest sequence.
train_dataset = TranslationDataset(train_src_lines, train_tgt_lines,
                                   src_vocab=src_vocab, tgt_vocab=tgt_vocab)
valid_dataset = TranslationDataset(valid_src_lines, valid_tgt_lines,
                                   src_vocab=src_vocab, tgt_vocab=tgt_vocab)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size,
                          shuffle=False, collate_fn=collate_fn)

In [None]:
# Build the Transformer from the tracked hyperparameters.
model_config = {
    key: config[key]
    for key in ("d_model", "nhead", "num_encoder_layers", "num_decoder_layers",
                "dim_feedforward", "dropout")
}
model = TransformerMT(
    len(src_vocab),
    len(tgt_vocab),
    src_pad_idx=src_vocab["<pad>"],
    tgt_pad_idx=tgt_vocab["<pad>"],
    **model_config,
).to(device)
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
# Xavier-uniform init for every parameter with more than one dimension; 1-D ones (biases, norm scales) are skipped.
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
# Padding positions do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tgt_vocab["<pad>"], label_smoothing=0.1)
# lr=0 is a placeholder: WarmupInverseSquareRootSchedule.step() overwrites every param group's lr.
optimizer = torch.optim.AdamW(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.01)
scheduler = WarmupInverseSquareRootSchedule(optimizer, warmup_steps=config.warmup_steps, d_model=config.d_model)

In [None]:
import warnings

# Silence every UserWarning for the rest of the kernel session (presumably to hide framework noise during training).
warnings.simplefilter("ignore", category=UserWarning)

In [None]:
# Train for config.epochs epochs, logging per-epoch losses and wall time to W&B and stdout.
for epoch in range(1, config.epochs + 1):
    epoch_start = time.time()
    train_loss = train_epoch(model, train_loader, optimizer, criterion, scheduler)
    valid_loss = evaluate(model, valid_loader, criterion)
    elapsed = time.time() - epoch_start
    wandb.log(dict(epoch=epoch, train_loss=train_loss, valid_loss=valid_loss, epoch_time=elapsed))
    print(f"Epoch {epoch} Train Loss: {train_loss} Valid Loss: {valid_loss} Time: {elapsed}")
wandb.finish()

In [None]:
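# Optional: restore trained weights instead of retraining.
# NOTE: strict=False silently skips keys that do not match (e.g. a "module." prefix left by DataParallel), so check that the weights actually loaded.
# model.load_state_dict(torch.load("/home/jupyter/datasphere/project/checkpoint.pth")['model_state_dict'], strict=False)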

In [None]:
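# Save model and optimizer state. The LR scheduler's `current_step` is not stored here,
# so resuming from this checkpoint would restart the warmup schedule.
# checkpoint = {
#     'model_state_dict': model.state_dict(),
#     'optimizer_state_dict': optimizer.state_dict(),
# }
# torch.save(checkpoint, 'checkpoint.pth')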

In [None]:
@torch.no_grad()
def beam_search_decode(model, src, src_mask, max_len, start_symbol, end_symbol,
                       beam_size=3, length_penalty=0.75):
    """Beam-search decode a single source sentence.

    Live hypotheses are ranked by the sum of their token log-probabilities. A
    hypothesis that has emitted ``end_symbol`` moves to the completed list with
    score ``sum_logp / len(tokens) ** length_penalty``, where tokens include
    ``start_symbol`` and ``end_symbol``. Search stops after ``max_len`` expansion
    steps, when no live hypotheses remain, or once ``beam_size`` hypotheses have
    completed.

    Args:
        model: TransformerMT, optionally wrapped (its ``.module`` is used when present).
        src: source token ids, shape ``(src_len,)`` or ``(1, src_len)``.
        src_mask: forwarded to ``model.encode``.
        max_len: maximum number of tokens generated after ``start_symbol``.
        start_symbol: BOS id.
        end_symbol: EOS id.
        beam_size: live hypotheses kept per step; also the number of candidates drawn
            per hypothesis, capped at the vocabulary size.
        length_penalty: exponent of the length normalization for completed hypotheses.

    Returns:
        Token ids of the best hypothesis as a list, without the leading ``start_symbol``
        and trailing ``end_symbol``. If nothing completed, the best-scoring live
        hypothesis is returned.
    """
    if hasattr(model, "module"):
        model = model.module
    device = src.device
    if src.dim() == 1:
        src = src.unsqueeze(0)
    memory = model.encode(src, src_mask.to(device))

    def finished(tokens, score):
        # Length-normalized entry for a hypothesis that ended with end_symbol.
        return tokens, score / (len(tokens) ** length_penalty)

    beams = [([start_symbol], 0.0)]
    completed = []
    for _ in range(max_len):
        new_beams = []
        for tokens, score in beams:
            if tokens[-1] == end_symbol:
                completed.append(finished(tokens, score))
                continue
            tgt = torch.tensor([tokens], dtype=torch.long, device=device)
            tgt_mask = model.generate_square_subsequent_mask(tgt.size(1)).to(device)
            # TransformerMT.decode returns sequence-first output; make it batch-first.
            out = model.decode(tgt, memory, tgt_mask).transpose(0, 1)
            log_probs = torch.log_softmax(model.fc_out(out[:, -1, :]), dim=-1).squeeze(0)
            # topk raises if k exceeds the number of candidates (vocabulary size).
            k = min(beam_size, log_probs.size(-1))
            topk_log_probs, topk_indices = log_probs.topk(k)
            for log_p, idx in zip(topk_log_probs.tolist(), topk_indices.tolist()):
                new_beams.append((tokens + [idx], score + log_p))
        if not new_beams:
            break
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam_size]
        if len(completed) >= beam_size:
            break
    else:
        # max_len exhausted: hypotheses that reached end_symbol on the last step were never
        # revisited by the loop, so score them here so they can beat earlier completions.
        completed.extend(finished(tokens, score) for tokens, score in beams if tokens[-1] == end_symbol)

    if completed:
        best_tokens, _ = max(completed, key=lambda x: x[1])
    else:
        best_tokens, _ = max(beams, key=lambda x: x[1])
    if best_tokens and best_tokens[0] == start_symbol:
        best_tokens = best_tokens[1:]
    if best_tokens and best_tokens[-1] == end_symbol:
        best_tokens = best_tokens[:-1]
    return best_tokens

In [None]:
# Translate the test set with beam search and write one hypothesis per line.
# Set eval mode explicitly: elsewhere it is only set inside evaluate(), which does not run
# when weights are restored from a checkpoint instead of training, and dropout would stay on.
model.eval()
bos_idx, eos_idx, pad_idx = tgt_vocab["<bos>"], tgt_vocab["<eos>"], tgt_vocab["<pad>"]
test_predictions = []
with torch.no_grad():
    for src_sentence in tqdm(test_src_lines):
        src_indices = [src_vocab[tok] for tok in src_sentence.split()]
        if not src_indices:
            # Nothing to encode; emit an empty line so the output stays aligned with the input.
            test_predictions.append("")
            continue
        src_tensor = torch.tensor(src_indices, dtype=torch.long).to(device)
        src_mask = torch.zeros(len(src_indices), len(src_indices), dtype=torch.bool).to(device)
        decoded_indices = beam_search_decode(
            model, src_tensor, src_mask, max_len=100, beam_size=5,
            start_symbol=bos_idx, end_symbol=eos_idx
        )
        decoded_words = [inv_tgt_vocab[idx] for idx in decoded_indices if idx not in (bos_idx, pad_idx)]
        if "<eos>" in decoded_words:
            decoded_words = decoded_words[:decoded_words.index("<eos>")]
        test_predictions.append(" ".join(decoded_words))

with open("test_predictions.en", "w", encoding="utf-8") as f:
    for line in test_predictions:
        f.write(line + "\n")