In [None]:
# Downgrade pip before installing fairseq. The next cell resolves fairseq 0.12.2,
# which requires omegaconf<2.1; presumably the newer pip this replaces (24.1.2)
# refuses that omegaconf release's package metadata.
# NOTE(review): confirm the pin is still needed on the current runtime image.
!pip install pip==23.0.1



Collecting pip==23.0.1
  Downloading pip-23.0.1-py3-none-any.whl.metadata (4.1 kB)
Downloading pip-23.0.1-py3-none-any.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m16.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-23.0.1


In [None]:
!pip install sacrebleu fairseq sentencepiece wandb gdown pandas torch numpy matplotlib tqdm



Collecting sacrebleu
  Downloading sacrebleu-2.4.3-py3-none-any.whl (103 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.0/104.0 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting fairseq
  Downloading fairseq-0.12.2.tar.gz (9.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.6/9.6 MB[0m [31m50.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting colorama
  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Collecting portalocker
  Downloading portalocker-2.10.1-py3-none-any.whl (18 kB)
Collecting omegaconf<2.1
  Downloading omegaconf-2.0.6-py3-none-any.whl (36 kB)
Collecting bitarray
  Downloading bitarray-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (278 kB)
[2K     [90m━━━━━━━━━━

In [None]:
# Fetch and unpack the TED2020 en-zh corpus into DATA/rawdata/ted2020/{train_dev,test}.raw.{en,zh}.
# NOTE(review): newer gdown releases deprecate `--id`; passing the id positionally
# (`gdown <id> --output DATA.zip`) is presumably the forward-compatible form — confirm
# against the installed gdown version.
# The archive also unpacks macOS __MACOSX/ metadata folders; nothing below reads them.
!gdown --id '1Sys5keuiw4C27_cG3LMyYi6v5eHJvu1L' --output DATA.zip
!unzip -o DATA.zip

Downloading...
From (original): https://drive.google.com/uc?id=1Sys5keuiw4C27_cG3LMyYi6v5eHJvu1L
From (redirected): https://drive.google.com/uc?id=1Sys5keuiw4C27_cG3LMyYi6v5eHJvu1L&confirm=t&uuid=7b34948a-854d-4d76-b03e-5830d3330a7a
To: /content/DATA.zip
100% 28.2M/28.2M [00:00<00:00, 33.9MB/s]
Archive:  DATA.zip
   creating: DATA/
  inflating: DATA/.DS_Store          
  inflating: __MACOSX/DATA/._.DS_Store  
   creating: DATA/rawdata/
   creating: DATA/rawdata/ted2020/
  inflating: DATA/rawdata/.DS_Store  
  inflating: __MACOSX/DATA/rawdata/._.DS_Store  
  inflating: DATA/rawdata/ted2020/test.raw.zh  
  inflating: __MACOSX/DATA/rawdata/ted2020/._test.raw.zh  
  inflating: DATA/rawdata/ted2020/test.raw.en  
  inflating: __MACOSX/DATA/rawdata/ted2020/._test.raw.en  
  inflating: DATA/rawdata/ted2020/train_dev.raw.zh  
  inflating: __MACOSX/DATA/rawdata/ted2020/._train_dev.raw.zh  
  inflating: DATA/rawdata/ted2020/train_dev.raw.en  
  inflating: __MACOSX/DATA/rawdata/ted2020/._train_dev

In [None]:
import os
import pprint
import random
import shutil
import subprocess
import sys
from argparse import Namespace
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sacrebleu
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm.auto as tqdm
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from torch.utils import data

from fairseq import utils
from fairseq.data import iterators
from fairseq.models import FairseqEncoderDecoderModel
from fairseq.models.transformer import TransformerEncoder, TransformerDecoder
from fairseq.tasks.translation import TranslationConfig, TranslationTask

# --- Reproducibility ---------------------------------------------------------
# A single seed for every RNG used in the notebook (Python, NumPy, PyTorch CPU and
# CUDA), plus deterministic cuDNN kernels in exchange for some speed.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --- Experiment configuration ------------------------------------------------
_RAW_ROOT = './DATA/rawdata'

config = Namespace(
    # Translation direction.
    src_lang='en',
    tgt_lang='zh',

    # Raw corpus location; `prefix` (absolute) is the directory the data cells read.
    data_dir=_RAW_ROOT,
    dataset_name='ted2020',
    prefix=Path(_RAW_ROOT).absolute() / 'ted2020',
    train_prefix=f'{_RAW_ROOT}/ted2020/train_dev.raw',
    test_prefix=f'{_RAW_ROOT}/ted2020/test.raw',

    # SentencePiece vocabulary size / model file, and per-line subword truncation length.
    vocab_size=8000,
    max_len=1024,
    spm_model_path='./DATA/sentencepiece.model',

    # Output directory of fairseq-preprocess.
    data_bin_path='./DATA/data-bin/ted2020',

    # Training loop and optimisation.
    savedir='./checkpoints/transformer',
    num_workers=4,
    max_tokens=2048,        # tokens per batch
    accum_steps=4,          # batches accumulated per optimizer update
    lr_factor=1.0,          # Noam schedule scale
    lr_warmup=4000,         # Noam warmup steps
    clip_norm=1.0,
    max_epoch=5,
    start_epoch=1,

    # Beam-search settings handed to task.build_generator, and output detokenisation.
    beam=5,
    max_len_a=1.2,
    max_len_b=10,
    post_process="sentencepiece",

    # Checkpoint retention, resume file and early stopping.
    keep_last_epochs=5,
    resume=None,
    early_stopping=True,
    early_stopping_patience=7,
)


In [None]:
# Locate the raw parallel files, then train (or reuse) a joint en+zh SentencePiece model.
# `spm` comes from the top-level import cell and is needed on BOTH paths below: the
# loader at the bottom runs even when training is skipped because the model exists.
prefix = Path(config.prefix)
train_dev_en = prefix / 'train_dev.raw.en'
train_dev_zh = prefix / 'train_dev.raw.zh'
test_en = prefix / 'test.raw.en'
test_zh = prefix / 'test.raw.zh'

tokenized_data_dir = Path('./DATA/tokenized')
tokenized_data_dir.mkdir(parents=True, exist_ok=True)

spm_model_path = Path(config.spm_model_path)

if not spm_model_path.exists():
    # One shared vocabulary: interleave both languages in a single training text,
    # keeping only pairs where both sides are non-empty.
    combined_train_valid = tokenized_data_dir / 'combined_train_valid.txt'
    with open(combined_train_valid, 'w', encoding='utf-8') as outfile, \
         open(train_dev_en, 'r', encoding='utf-8') as en_f, \
         open(train_dev_zh, 'r', encoding='utf-8') as zh_f:
        for en_line, zh_line in zip(en_f, zh_f):
            if en_line.strip() and zh_line.strip():
                outfile.write(en_line.strip() + '\n')
                outfile.write(zh_line.strip() + '\n')
    spm.SentencePieceTrainer.train(
        input=str(combined_train_valid),
        # model_prefix without ".model" so the trained file lands at spm_model_path.
        model_prefix=str(spm_model_path.with_suffix('')),
        vocab_size=config.vocab_size,
        model_type='bpe',
        character_coverage=1.0,
        unk_id=0,
        pad_id=1,
        bos_id=2,
        eos_id=3,
    )
    print("SentencePiece model trained.")
else:
    print("SentencePiece model already exists. Skipping training.")

spm_model = spm.SentencePieceProcessor()
spm_model.load(str(spm_model_path))

def tokenize_and_truncate(input_file, output_file, spm_model, max_len, keep_empty_lines=True):
    """Encode a raw text file into space-separated subword pieces, one output line per input line.

    Args:
        input_file: path of the raw UTF-8 text file.
        output_file: path to write (overwritten).
        spm_model: object providing ``encode(text, out_type=str)`` that returns a list
            of pieces (e.g. a SentencePieceProcessor).
        max_len: maximum number of pieces kept per line; longer lines are truncated.
        keep_empty_lines: when True (default), a blank input line produces a blank
            output line, so output line *i* always corresponds to input line *i*.
            This keeps the en/zh files of a parallel corpus aligned; drop empty pairs
            jointly afterwards. Pass False to drop blank lines (only safe for a
            single, non-parallel file).
    """
    with open(input_file, 'r', encoding='utf-8') as infile, \
         open(output_file, 'w', encoding='utf-8') as outfile:
        for raw_line in infile:
            text = raw_line.strip()
            if not text:
                # Skipping here independently per language would shift every later
                # line and silently misalign the parallel corpus.
                if keep_empty_lines:
                    outfile.write('\n')
                continue
            pieces = spm_model.encode(text, out_type=str)[:max_len]
            outfile.write(' '.join(pieces) + '\n')

# Subword-tokenize the raw train/dev and test pairs once; cached outputs are reused.
train_en_tok = tokenized_data_dir / 'train.en'
train_zh_tok = tokenized_data_dir / 'train.zh'
test_en_tok = tokenized_data_dir / 'test.en'
test_zh_tok = tokenized_data_dir / 'test.zh'


def _tokenize_pair_once(raw_paths, tok_paths, start_msg, done_msg, skip_msg):
    """Tokenize each raw file into its tokenized counterpart unless every output already exists."""
    if all(path.exists() for path in tok_paths):
        print(skip_msg)
        return
    print(start_msg)
    for raw_path, tok_path in zip(raw_paths, tok_paths):
        tokenize_and_truncate(raw_path, tok_path, spm_model, config.max_len)
    print(done_msg)


_tokenize_pair_once(
    (train_dev_en, train_dev_zh),
    (train_en_tok, train_zh_tok),
    "Tokenizing training data...",
    "Training data tokenization completed.",
    "Tokenized training data already exists. Skipping tokenization.",
)
_tokenize_pair_once(
    (test_en, test_zh),
    (test_en_tok, test_zh_tok),
    "Tokenizing test data...",
    "Test data tokenization completed.",
    "Tokenized test data already exists. Skipping tokenization.",
)

# Hold out a random 5% of the tokenized training pairs as the validation split.
# (`train_test_split` is imported in the top-level import cell.)
valid_size = 0.05

train_en_final = tokenized_data_dir / 'train.final.en'
train_zh_final = tokenized_data_dir / 'train.final.zh'
valid_en_final = tokenized_data_dir / 'valid.en'
valid_zh_final = tokenized_data_dir / 'valid.zh'

split_outputs = (train_en_final, train_zh_final, valid_en_final, valid_zh_final)
if not all(path.exists() for path in split_outputs):
    print("Splitting training data into train and validation sets...")
    with open(train_en_tok, 'r', encoding='utf-8') as en_f, \
         open(train_zh_tok, 'r', encoding='utf-8') as zh_f:
        en_raw = en_f.readlines()
        zh_raw = zh_f.readlines()

    if len(en_raw) != len(zh_raw):
        raise ValueError(
            f"Tokenized parallel files are misaligned: {train_en_tok} has {len(en_raw)} lines, "
            f"{train_zh_tok} has {len(zh_raw)}. Delete the tokenized files and re-run tokenization."
        )

    # Drop a pair when EITHER side is empty. Filtering each language separately
    # would shift all later lines and pair sentences with the wrong translation.
    pairs = [
        (en.strip(), zh.strip())
        for en, zh in zip(en_raw, zh_raw)
        if en.strip() and zh.strip()
    ]
    en_lines = [en for en, _ in pairs]
    zh_lines = [zh for _, zh in pairs]

    en_train, en_valid, zh_train, zh_valid = train_test_split(
        en_lines, zh_lines, test_size=valid_size, random_state=seed
    )

    with open(train_en_final, 'w', encoding='utf-8') as en_train_f, \
         open(train_zh_final, 'w', encoding='utf-8') as zh_train_f:
        for en, zh in zip(en_train, zh_train):
            en_train_f.write(en + '\n')
            zh_train_f.write(zh + '\n')

    with open(valid_en_final, 'w', encoding='utf-8') as en_valid_f, \
         open(valid_zh_final, 'w', encoding='utf-8') as zh_valid_f:
        for en, zh in zip(en_valid, zh_valid):
            en_valid_f.write(en + '\n')
            zh_valid_f.write(zh + '\n')
    print("Data splitting completed.")
else:
    print("Train and validation sets already exist. Skipping data splitting.")


In [None]:
# Binarize the tokenized train/valid/test splits with a joined source/target dictionary.
# The existence of data_bin_path is the "already done" marker, so a failed run must not
# leave the directory behind.
data_bin_dir = Path(config.data_bin_path)
if not data_bin_dir.exists():
    print("Preprocessing data with Fairseq...")
    os.makedirs(data_bin_dir, exist_ok=True)
    preprocess_cmd = [
        "fairseq-preprocess",
        "--source-lang", config.src_lang,
        "--target-lang", config.tgt_lang,
        "--trainpref", f"{tokenized_data_dir}/train.final",
        "--validpref", f"{tokenized_data_dir}/valid",
        "--testpref", f"{tokenized_data_dir}/test",
        "--destdir", str(data_bin_dir),
        "--joined-dictionary",
        "--workers", str(config.num_workers),
    ]
    try:
        # Argument list (no shell) and check=True: a non-zero exit raises instead of
        # being silently ignored as the os.system return code was.
        subprocess.run(preprocess_cmd, check=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Remove partial output so the next run retries instead of skipping.
        shutil.rmtree(data_bin_dir, ignore_errors=True)
        raise
    print("Fairseq preprocessing completed.")
else:
    print("Fairseq binary data already exists. Skipping preprocessing.")


In [None]:
# Pick the compute device and load the binarized splits into a fairseq TranslationTask.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

task_cfg = TranslationConfig(
    data=config.data_bin_path,
    source_lang=config.src_lang,
    target_lang=config.tgt_lang,
    dataset_impl="mmap",
    train_subset="train",
    upsample_primary=1,
    required_seq_len_multiple=8,
)
task = TranslationTask.setup_task(task_cfg)

print("Loading datasets...")
for split_name in ("train", "valid", "test"):
    # Only the training split is loaded with combine=True; the others use the default (False).
    task.load_dataset(split=split_name, epoch=1, combine=(split_name == "train"))


In [None]:
def init_transformer_params(module):
    """Re-initialise a single submodule in place; meant to be passed to ``nn.Module.apply``.

    - ``nn.Linear``: Xavier-uniform weight, zero bias.
    - ``nn.Embedding``: N(0, 0.02) weight, then the ``padding_idx`` row (if any) is reset
      to zero, so the padding token keeps the null vector ``nn.Embedding`` gives it at
      construction.
    Every other module type is left untouched. ``torch.nn.MultiheadAttention`` needs no
    special case: its ``out_proj`` is itself an ``nn.Linear`` submodule, and ``apply()``
    already visits it.
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0, std=0.02)
        if module.padding_idx is not None:
            with torch.no_grad():
                module.weight[module.padding_idx].fill_(0)


def build_model(args, task):
    """Build a fairseq encoder-decoder Transformer for ``task`` and re-initialise its weights.

    Args:
        args: Namespace of architecture hyper-parameters (``encoder_*``/``decoder_*``
            fields etc.) forwarded to fairseq's TransformerEncoder/TransformerDecoder.
        task: fairseq task supplying ``source_dictionary`` and ``target_dictionary``.

    Returns:
        A FairseqEncoderDecoderModel with separate (unshared) source/target token
        embeddings, initialised by ``init_transformer_params``.
    """
    src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

    encoder_embed_tokens = nn.Embedding(len(src_dict), args.encoder_embed_dim, src_dict.pad())
    decoder_embed_tokens = nn.Embedding(len(tgt_dict), args.decoder_embed_dim, tgt_dict.pad())

    encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
    decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens)
    model = FairseqEncoderDecoderModel(encoder, decoder)

    # apply() recurses into every submodule, including the embeddings built above.
    model.apply(init_transformer_params)
    return model

# Architecture hyper-parameters for build_model: a small 3+3-layer Transformer with
# identical encoder/decoder stacks, normalize_before=True (fairseq's pre-LayerNorm
# option), and separate decoder input/output embeddings.
arch_args = Namespace(
    # Shared settings.
    dropout=0.1,
    activation_fn="relu",
    share_decoder_input_output_embed=False,
    max_source_positions=1024,
    max_target_positions=1024,

    # Encoder stack.
    encoder_layers=3,
    encoder_embed_dim=256,
    encoder_ffn_embed_dim=1024,
    encoder_attention_heads=4,
    encoder_normalize_before=True,

    # Decoder stack.
    decoder_layers=3,
    decoder_embed_dim=256,
    decoder_ffn_embed_dim=1024,
    decoder_attention_heads=4,
    decoder_normalize_before=True,
)

model = build_model(arch_args, task)
print(model)


In [None]:
class LabelSmoothedCrossEntropyCriterion(nn.Module):
    """Label-smoothed negative log-likelihood computed from log-probabilities.

    The smoothed target puts ``1 - smoothing`` on the gold token and spreads
    ``smoothing`` uniformly over all ``V`` vocabulary entries:
    ``loss = (1 - smoothing) * nll + (smoothing / V) * sum_v(-lprobs_v)``.

    Args:
        smoothing: label-smoothing epsilon.
        ignore_index: target id (e.g. padding) whose positions contribute zero loss.
        reduce: if True, return the loss SUMMED (not averaged) over positions as a
            scalar; otherwise return one loss value per target position.
    """

    def __init__(self, smoothing, ignore_index=None, reduce=True):
        super().__init__()
        self.smoothing = smoothing
        self.ignore_index = ignore_index
        self.reduce = reduce

    def forward(self, lprobs, target):
        """Compute the loss.

        Args:
            lprobs: log-probabilities; the last dimension is the vocabulary.
            target: gold ids, with either one dimension fewer than ``lprobs`` or a
                trailing singleton dimension.
        """
        if target.dim() == lprobs.dim() - 1:
            target = target.unsqueeze(-1)
        nll_loss = -lprobs.gather(dim=-1, index=target)
        smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
        if self.ignore_index is not None:
            pad_mask = target.eq(self.ignore_index)
            nll_loss.masked_fill_(pad_mask, 0.0)
            smooth_loss.masked_fill_(pad_mask, 0.0)
        # Squeeze on both paths so reduce=False yields the same per-position shape
        # whether or not ignore_index is set.
        nll_loss = nll_loss.squeeze(-1)
        smooth_loss = smooth_loss.squeeze(-1)
        if self.reduce:
            nll_loss = nll_loss.sum()
            smooth_loss = smooth_loss.sum()
        eps_i = self.smoothing / lprobs.size(-1)
        return (1.0 - self.smoothing) * nll_loss + eps_i * smooth_loss

# Training/validation criterion: 10% label smoothing; target padding positions add no loss.
criterion = LabelSmoothedCrossEntropyCriterion(
    smoothing=0.1,
    ignore_index=task.target_dictionary.pad(),
).to(device)

def get_rate(d_model, step_num, warmup_step):
    """Noam learning-rate factor: d_model^-0.5 * min(step^-0.5, step * warmup^-1.5).

    The rate grows linearly for ``warmup_step`` steps, then decays with the inverse
    square root of the step. Step 0 is treated as step 1 to avoid dividing by zero.
    """
    step = step_num if step_num != 0 else 1
    decay = step ** -0.5
    warmup = step * warmup_step ** -1.5
    return d_model ** -0.5 * min(decay, warmup)

class NoamOpt:
    """Wrap a torch optimizer and drive its learning rate with the Noam schedule.

    The rate at step ``s`` is ``factor * get_rate(model_size, s, warmup)``. It is written
    into every param group right before each underlying ``optimizer.step()``. ``_step``
    counts completed updates and is what the checkpoints save and restore.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self.model_size = model_size
        self.factor = factor
        self.warmup = warmup
        self._step = 0
        self._rate = 0

    @property
    def param_groups(self):
        """The wrapped optimizer's param groups (lets GradScaler treat this object as an optimizer)."""
        return self.optimizer.param_groups

    def multiply_grads(self, c):
        """Scale every existing parameter gradient in place by ``c``."""
        grads = (
            param.grad
            for group in self.param_groups
            for param in group['params']
            if param.grad is not None
        )
        for grad in grads:
            grad.data.mul_(c)

    def step(self):
        """Advance the step counter, set the scheduled learning rate, then step the optimizer."""
        self._step += 1
        self._rate = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = self._rate
        self.optimizer.step()

    def rate(self, step=None):
        """Scheduled learning rate at ``step`` (defaults to the current step counter)."""
        step = self._step if step is None else step
        return self.factor * get_rate(self.model_size, step, self.warmup)

# AdamW (no initial lr; the schedule sets it) wrapped in the Noam schedule, sized by
# the encoder embedding dimension.
optimizer = NoamOpt(
    model_size=arch_args.encoder_embed_dim,
    factor=config.lr_factor,
    warmup=config.lr_warmup,
    optimizer=torch.optim.AdamW(
        model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=0.0001,
    ),
)

# Preview the schedule over the first 10k updates; rate() only evaluates it, no step is taken.
steps = np.arange(1, 10001)
lrs = list(map(optimizer.rate, steps))
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(steps, lrs)
ax.set_title("Noam Learning Rate Schedule")
ax.set_xlabel("Step")
ax.set_ylabel("Learning Rate")
ax.grid(True)
plt.show()


In [None]:
def decode(toks, dictionary):
    """Turn a tensor of token ids back into text via ``dictionary.string``.

    ``config.post_process`` is passed as the post-processing mode. An empty result is
    replaced by the placeholder "<unk>".
    """
    text = dictionary.string(toks.int().cpu(), config.post_process)
    return text or "<unk>"

def inference_step(sample, model):
    """Beam-search one batch and return ``(sources, hypotheses, references)`` as text lists.

    A generator is built from ``config`` on every call; only the top-ranked hypothesis
    per sentence is kept. Padding is stripped from source and reference ids before
    decoding.
    """
    generator = task.build_generator([model], config)
    gen_out = generator.generate([model], sample)

    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    src_tokens = sample["net_input"]["src_tokens"]
    ref_tokens = sample["target"]

    srcs, hyps, refs = [], [], []
    for row, nbest in enumerate(gen_out):
        srcs.append(decode(utils.strip_pad(src_tokens[row], src_dict.pad()), src_dict))
        hyps.append(decode(nbest[0]["tokens"], tgt_dict))
        refs.append(decode(utils.strip_pad(ref_tokens[row], tgt_dict.pad()), tgt_dict))
    return srcs, hyps, refs

def validate(model, task, criterion, log_to_console=True):
    """Evaluate ``model`` on the "valid" split: mean per-token loss plus corpus BLEU.

    Returns a dict with "loss" (float, mean over batches of token-normalised loss),
    "bleu" (sacrebleu result computed with the 'zh' tokenizer) and the decoded
    "srcs"/"hyps"/"refs" lists. Leaves the model in eval mode.
    """
    print('Beginning validation...')
    valid_itr = load_data_iterator(
        task, "valid", 1, config.max_tokens, config.num_workers
    ).next_epoch_itr(shuffle=False)

    batch_losses = []
    srcs, hyps, refs = [], [], []

    model.eval()
    progress = tqdm.tqdm(valid_itr, desc="Validation", leave=False)
    with torch.no_grad():
        for sample in progress:
            sample = utils.move_to_cuda(sample, device=device)
            logits = model.forward(**sample["net_input"])[0]
            lprobs = F.log_softmax(logits, -1)
            # The criterion returns a token-summed loss; normalise by the batch token count.
            summed = criterion(lprobs.view(-1, lprobs.size(-1)), sample["target"].view(-1))
            batch_loss = summed / sample["ntokens"]
            progress.set_postfix(valid_loss=batch_loss.item())
            batch_losses.append(batch_loss)

            batch_srcs, batch_hyps, batch_refs = inference_step(sample, model)
            srcs.extend(batch_srcs)
            hyps.extend(batch_hyps)
            refs.extend(batch_refs)

    stats = {
        "loss": torch.stack(batch_losses).mean().item(),
        "bleu": sacrebleu.corpus_bleu(hyps, [refs], tokenize='zh'),
        "srcs": srcs,
        "hyps": hyps,
        "refs": refs,
    }

    if log_to_console:
        if hyps:
            showid = np.random.randint(len(hyps))
            print("Example Source: " + srcs[showid])
            print("Example Hypothesis: " + hyps[showid])
            print("Example Reference: " + refs[showid])
        print(f"Validation Loss: {stats['loss']:.4f}")
        print(f"Validation BLEU: {stats['bleu'].score:.2f}")

    return stats

def validate_and_save(model, task, criterion, optimizer, epoch, save=True):
    """Validate, optionally write checkpoints, and update early-stopping state.

    The best BLEU so far and the remaining patience are stored as attributes on this
    function (``validate_and_save.best_bleu`` / ``.patience``). They persist across
    calls within a kernel session and reset when this cell is re-executed.

    Args:
        model, task, criterion: forwarded to ``validate``.
        optimizer: only its ``_step`` counter is stored in the checkpoint.
        epoch: epoch number used in the checkpoint and sample file names.
        save: if True, write checkpoint{epoch}.pt, checkpoint_last.pt and
            samples{epoch}.en-zh.txt, write checkpoint_best.pt on BLEU improvement,
            and delete the epoch checkpoint ``keep_last_epochs`` epochs back.

    Returns:
        False when early stopping says training should stop, otherwise True.
    """
    stats = validate(model, task, criterion)
    bleu = stats['bleu']
    loss = stats['loss']

    # Decide "improved" exactly once, before any state changes. If best_bleu were
    # updated first (when saving), the early-stopping comparison below would never
    # see an improvement and would decrement patience every epoch.
    improved = bleu.score > getattr(validate_and_save, "best_bleu", 0)
    if improved:
        validate_and_save.best_bleu = bleu.score

    if save:
        savedir = Path(config.savedir).absolute()
        savedir.mkdir(parents=True, exist_ok=True)

        check = {
            "model": model.state_dict(),
            "stats": {"bleu": bleu.score, "loss": loss},
            "optim": {"step": optimizer._step}
        }
        epoch_ckpt = savedir / f"checkpoint{epoch}.pt"
        torch.save(check, epoch_ckpt)
        shutil.copy(epoch_ckpt, savedir / "checkpoint_last.pt")
        print(f"Saved epoch checkpoint: {savedir}/checkpoint{epoch}.pt")

        with open(savedir / f"samples{epoch}.en-zh.txt", "w", encoding='utf-8') as f:
            for s, h in zip(stats["srcs"], stats["hyps"]):
                f.write(f"{s}\t{h}\n")

        if improved:
            torch.save(check, savedir / "checkpoint_best.pt")
            print(f"New best checkpoint saved: {savedir}/checkpoint_best.pt")

        # Retain only the most recent `keep_last_epochs` per-epoch checkpoints.
        del_epoch = epoch - config.keep_last_epochs
        if del_epoch >= config.start_epoch:
            del_file = savedir / f"checkpoint{del_epoch}.pt"
            if del_file.exists():
                del_file.unlink()
                print(f"Deleted old checkpoint: {del_file}")

    if config.early_stopping:
        if improved:
            validate_and_save.patience = config.early_stopping_patience
        else:
            # Start from full patience if no improvement has been recorded yet.
            validate_and_save.patience = getattr(
                validate_and_save, "patience", config.early_stopping_patience
            ) - 1
            print(f"Early stopping patience remaining: {validate_and_save.patience}")
            if validate_and_save.patience <= 0:
                print("Early stopping triggered.")
                return False

    return True


In [None]:
# Move the model to the training device and summarise the training setup.
model = model.to(device=device)

print("Task: {}".format(type(task).__name__))
print("Encoder: {}".format(type(model.encoder).__name__))
print("Decoder: {}".format(type(model.decoder).__name__))
print("Criterion: {}".format(type(criterion).__name__))
print("Optimizer: {}".format(type(optimizer).__name__))

n_params = 0
n_trainable = 0
for param in model.parameters():
    n_params += param.numel()
    if param.requires_grad:
        n_trainable += param.numel()
print("Number of model parameters: {:,} (trained: {:,})".format(n_params, n_trainable))
print(f"Max tokens per batch = {config.max_tokens}, Accumulate steps = {config.accum_steps}")

def load_data_iterator(task, split, epoch=1, max_tokens=4000, num_workers=0, cached=True):
    """Build fairseq's epoch batch iterator for ``split``, batching by token count.

    The position limit is ``resolve_max_positions(task.max_positions(), max_tokens)``
    and is passed with ``ignore_invalid_inputs=True``. The global ``seed`` is passed
    through, and ``cached=False`` disables fairseq's iterator cache.
    """
    position_limit = utils.resolve_max_positions(task.max_positions(), max_tokens)
    return task.get_batch_iterator(
        dataset=task.dataset(split),
        max_tokens=max_tokens,
        max_sentences=None,
        max_positions=position_limit,
        ignore_invalid_inputs=True,
        seed=seed,
        num_workers=num_workers,
        epoch=epoch,
        disable_iterator_cache=not cached,
    )


# Iterator for the first training epoch; the training loop rebuilds it each epoch.
epoch_itr = load_data_iterator(task, "train", config.start_epoch, config.max_tokens, config.num_workers)

def try_load_checkpoint(model, optimizer=None, name=None):
    """Restore model weights (and the optimizer step counter) from ``config.savedir``.

    ``name`` defaults to "checkpoint_last.pt". Only ``optimizer._step`` is restored,
    because the saved checkpoint holds no optimizer moment state. If the file is
    missing, a message is printed and nothing changes.
    NOTE: torch.load unpickles the file — only load checkpoints you produced yourself.
    """
    checkpath = Path(config.savedir) / (name or "checkpoint_last.pt")
    if not checkpath.exists():
        print(f"No checkpoints found at {checkpath}!")
        return

    check = torch.load(checkpath, map_location=device)
    model.load_state_dict(check["model"])
    stats = check["stats"]
    step = "unknown"
    if optimizer is not None:
        step = check["optim"]["step"]
        optimizer._step = step
    print(f"Loaded checkpoint {checkpath}: step={step} loss={stats['loss']} bleu={stats['bleu']}")


try_load_checkpoint(model, optimizer, name=config.resume)

def train_one_epoch(epoch_itr, model, task, criterion, optimizer, accum_steps=1, scaler=None):
    """Run one training epoch with gradient accumulation and mixed precision.

    Each optimizer update consumes ``accum_steps`` batches. The criterion returns a
    token-SUMMED loss, so the accumulated gradients are divided by the total number of
    target tokens before clipping; each update then uses a per-token average gradient
    and ``config.clip_norm`` acts on that average.

    Args:
        epoch_itr: fairseq epoch iterator over the training split.
        model, task, criterion: training components.
        optimizer: must expose ``param_groups``, ``multiply_grads`` and ``step``
            (e.g. NoamOpt).
        accum_steps: number of batches accumulated per update.
        scaler: optional GradScaler to reuse across epochs. If omitted, a fresh one is
            created (enabled only on CUDA), and its loss scale restarts every epoch.

    Returns:
        dict with "loss": list of per-update, per-token training losses.
    """
    itr = iterators.GroupedIterator(epoch_itr.next_epoch_itr(shuffle=True), accum_steps)

    stats = {"loss": []}
    if scaler is None:
        scaler = GradScaler(enabled=(device.type == "cuda"))
    use_amp = scaler.is_enabled()

    model.train()
    progress = tqdm.tqdm(itr, desc=f"Train Epoch {epoch_itr.epoch}", leave=True)
    for samples in progress:
        model.zero_grad()
        accum_loss = 0.0
        sample_size = 0

        for sample in samples:
            sample = utils.move_to_cuda(sample, device=device)
            target = sample["target"]
            sample_size += sample["ntokens"]

            # Only the forward pass and loss run under autocast; backward runs outside it,
            # as the PyTorch AMP docs recommend.
            with autocast(enabled=use_amp):
                net_output = model.forward(**sample["net_input"])
                lprobs = F.log_softmax(net_output[0], -1)
                loss = criterion(lprobs.view(-1, lprobs.size(-1)), target.view(-1))
            scaler.scale(loss).backward()
            accum_loss += loss.item()

        # `or 1.0` guards against a group with no target tokens.
        denom = sample_size or 1.0
        # Unscale to true gradients, convert the token sum to a per-token mean, then clip.
        scaler.unscale_(optimizer)
        optimizer.multiply_grads(1.0 / denom)
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip_norm)

        scaler.step(optimizer)
        scaler.update()

        loss_print = accum_loss / denom
        stats["loss"].append(loss_print)
        # Per-update loss goes to the progress bar instead of one print per step.
        progress.set_postfix(loss=loss_print)

    loss_mean = np.mean(stats["loss"])
    print(f"Training Loss for Epoch {epoch_itr.epoch}: {loss_mean:.4f}")
    return stats

# Train until max_epoch (inclusive), validating and checkpointing after every epoch;
# stop early if validate_and_save reports that patience has run out.
while epoch_itr.next_epoch_idx <= config.max_epoch:
    train_one_epoch(epoch_itr, model, task, criterion, optimizer, config.accum_steps)
    keep_training = validate_and_save(model, task, criterion, optimizer, epoch=epoch_itr.epoch)
    print(f"End of Epoch {epoch_itr.epoch}")
    stop_now = config.early_stopping and not keep_training
    if stop_now:
        break
    next_epoch = epoch_itr.next_epoch_idx
    epoch_itr = load_data_iterator(task, "train", next_epoch, config.max_tokens, config.num_workers)

# Average the most recent per-epoch checkpoints into a single model file.
checkdir = Path(config.savedir)

avg_checkpoint_path = checkdir / "avg_last_5_checkpoint.pt"
if not avg_checkpoint_path.exists():
    # Only numbered epoch checkpoints (checkpoint<N>.pt). checkpoint_best.pt and
    # checkpoint_last.pt are copies of epoch checkpoints and would be double-counted.
    # Sort numerically so that checkpoint10 comes after checkpoint9.
    epoch_prefix_len = len("checkpoint")
    epoch_checkpoints = sorted(
        (p for p in checkdir.glob("checkpoint*.pt") if p.stem[epoch_prefix_len:].isdecimal()),
        key=lambda p: int(p.stem[epoch_prefix_len:]),
    )
    last_five = epoch_checkpoints[-5:]
    if last_five:
        # Pass the files explicitly and without --num-epoch-checkpoints, which makes the
        # script treat --inputs as a single checkpoint directory. check=True surfaces
        # failures instead of printing a success message regardless.
        subprocess.run(
            [
                "fairseq-average-checkpoints",
                "--inputs", *(str(cp) for cp in last_five),
                "--output", str(avg_checkpoint_path),
            ],
            check=True,
        )
        print(f"Averaged last {len(last_five)} checkpoints into {avg_checkpoint_path}")
    else:
        print("Not enough checkpoints to average.")
else:
    print(f"Averaged checkpoint already exists at {avg_checkpoint_path}")

# Reload the best-BLEU checkpoint (if training produced one) and report its validation scores.
best_checkpoint = checkdir / "checkpoint_best.pt"
if not best_checkpoint.exists():
    print("Best checkpoint not found. Skipping loading best checkpoint.")
else:
    best_state = torch.load(best_checkpoint, map_location=device)
    check = best_state
    model.load_state_dict(best_state["model"])
    print("Loaded best checkpoint for final evaluation.")
    validate(model, task, criterion, log_to_console=True)

# Translate the test split with the currently loaded model and write prediction.csv.
task.load_dataset(split="test", epoch=1)
test_itr = load_data_iterator(task, "test", epoch=1, max_tokens=config.max_tokens, num_workers=0).next_epoch_itr(shuffle=False)

idxs = []
hyps = []

model.eval()
progress = tqdm.tqdm(test_itr, desc="Prediction", leave=True)
with torch.no_grad():
    for sample in progress:
        sample = utils.move_to_cuda(sample, device=device)
        _, batch_hyps, _ = inference_step(sample, model)
        hyps.extend(batch_hyps)
        # Batches are not in dataset order, so keep each sentence's id (as a plain int).
        idxs.extend(sample['id'].tolist())

# Restore dataset order, and label rows with the real ids rather than range(len(...)).
# That keeps every row correctly labelled even if the iterator skipped some sentences.
order = sorted(range(len(idxs)), key=idxs.__getitem__)
sorted_ids = [idxs[k] for k in order]
sorted_hyps = [hyps[k] for k in order]

n_skipped = len(task.dataset("test")) - len(sorted_ids)
if n_skipped:
    print(f"Warning: {n_skipped} test sentences were skipped by the iterator and have no prediction.")

pred = pd.DataFrame({
    'id': sorted_ids,
    'sentence': sorted_hyps
})

pred.to_csv('prediction.csv', index=False)
print("Saved predictions to prediction.csv")

# Clean-up: the interleaved en+zh text was only needed to train SentencePiece.
combined_train_valid = tokenized_data_dir.joinpath('combined_train_valid.txt')
needs_cleanup = combined_train_valid.exists()
if needs_cleanup:
    combined_train_valid.unlink()
    print(f"Deleted {combined_train_valid}")

