In [None]:
import os
import json
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq,
    TrainerCallback
)

In [None]:
# --- Paths -----------------------------------------------------------------
DATA_PATH = '../data/qrels.train.tsv'   # tab-separated training rows, parsed by the loading cell
OUTPUT_PATH = "./output"                # Trainer output_dir; final model is saved here too

# --- Model -----------------------------------------------------------------
MODEL = "t5-base"                       # checkpoint name passed to from_pretrained

# --- Optimisation ----------------------------------------------------------
EPOCHS = 1
LEARNING_RATE = 1e-4
BATCH_SIZE = 8                          # per-device train batch size

In [None]:
# Converting MS-MARCO to torch dataset for fine-tuning


class dataset(Dataset):
    """Map-style dataset that turns (query, document, label) triples into T5 prompts.

    Each item is a dict with two keys:
      - 'text':   the prompt "Query: <query> Document: <document> Relevant:"
      - 'labels': the third element of the triple, passed through unchanged
                  (in this notebook the strings 'true' / 'false').
    """

    def __init__(self, data):
        # Indexable sequence of triples; only positions 0, 1 and 2 of each entry are read.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        query, document, label = record[0], record[1], record[2]
        prompt = f'Query: {query} Document: {document} Relevant:'
        return dict(text=prompt, labels=label)

In [None]:
# Load the seq2seq checkpoint named by MODEL (T5-base by default) and its tokenizer.
# The seed is set before the model is loaded; training itself is re-seeded by the
# Trainer's `seed` argument.
torch.manual_seed(101)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL)
model = model.to(device)

In [None]:
# Processing MS-MARCO dataset
# Each data line is expected to be `query \t positive_passage \t negative_passage`.
# Every line yields two seq2seq examples: (query, positive, 'true') and
# (query, negative, 'false').

MAX_TRAIN_LINES = int(6.4e5 * 10)  # stop after this many line indices (6,400,000)

train_samples = []
with open(DATA_PATH, 'r', encoding='utf-8') as file:
    for num, line in tqdm(enumerate(file)):
        # Line 0 is skipped (presumably a header row -- TODO confirm against the file).
        if num == 0:
            continue
        if num > MAX_TRAIN_LINES:
            break
        # Drop the line terminator; otherwise it stays attached to the last field
        # and every negative passage ends with "\n".
        line = line.rstrip("\r\n")
        if not line:
            continue  # tolerate blank lines, e.g. a trailing empty line at EOF
        fields = line.split("\t")
        if len(fields) != 3:
            raise ValueError(
                f"{DATA_PATH}:{num + 1}: expected 3 tab-separated fields, got {len(fields)}"
            )
        query, positive, negative = fields
        train_samples.append((query, positive, 'true'))
        train_samples.append((query, negative, 'false'))

In [None]:
dataset_train = dataset(data=train_samples)  # wrap the (query, passage, label) triples for the Trainer

In [None]:
# Setting training arguments and parameters
# NOTE(review): the Trainer seeds itself from `seed=1` when it is built, which likely
# supersedes the earlier torch.manual_seed(101) for training randomness -- confirm intent.

train_args = Seq2SeqTrainingArguments(
    # Output location / mode
    output_dir=OUTPUT_PATH,
    do_train=True,
    # Optimisation
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=16,
    warmup_steps=1000,
    weight_decay=5e-5,
    adafactor=True,
    seed=1,
    # Checkpointing and logging
    save_strategy='steps',
    save_steps=10000,
    logging_steps=100,
    load_best_model_at_end=False,
    disable_tqdm=False,
    # Data pipeline
    remove_unused_columns=False,   # keep the 'text' key available to the custom collator
    dataloader_pin_memory=False,   # presumably because data_collate already moves batches to `device`
    predict_with_generate=True,    # only affects evaluation/prediction; no eval dataset is passed here
)

In [None]:
def data_collate(batch):
    """Tokenize a list of dataset items into a seq2seq training batch.

    Args:
        batch: list of dicts as produced by `dataset.__getitem__`, each holding a
            'text' prompt and a 'labels' target string.

    Returns:
        The tokenizer's BatchEncoding for the prompts (padded, truncated to at most
        512 tokens) with an added 'labels' tensor; every entry is moved to `device`.
        Padded label positions are set to -100 so the loss ignores them.
    """
    texts = [example["text"] for example in batch]
    tokenized = tokenizer(
        texts,
        padding=True,
        truncation="longest_first",
        return_tensors="pt",
        max_length=512,
    )

    # Pad targets to a common length: without padding, return_tensors="pt" fails as
    # soon as two target strings tokenize to different lengths. Padding positions are
    # replaced by -100, the index ignored by the seq2seq cross-entropy loss.
    labels = tokenizer(
        [example["labels"] for example in batch],
        padding=True,
        return_tensors="pt",
    )["input_ids"]
    if tokenizer.pad_token_id is not None:
        labels[labels == tokenizer.pad_token_id] = -100
    tokenized["labels"] = labels

    # NOTE(review): moving tensors inside the collator requires a single-process
    # DataLoader (num_workers=0) when `device` is CUDA -- confirm this is intended.
    for name in tokenized:
        tokenized[name] = tokenized[name].to(device)

    return tokenized

In [None]:
# Build the trainer around the model, tokenizer, custom collator and training data.
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=train_args,
    data_collator=data_collate,
    train_dataset=dataset_train,
)

trainer.train()

# Persist the fine-tuned model to OUTPUT_PATH, then the trainer state
# (written to the args' output_dir).
trainer.save_model(OUTPUT_PATH)
trainer.save_state()