In [None]:
from datasets import load_from_disk
from trl import SFTConfig, SFTTrainer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq, AutoConfig, TrainingArguments, Trainer
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader

In [None]:
# Location of the preprocessed Visa QA dataset on disk (relative to this notebook).
DATASET_PATH = "../Datasets/Visa_QA_V3/"
# Identifier of the pretrained checkpoint; the tokenizer, config and model
# cells below all load from this same name.
CHECKPOINT = "MaRiOrOsSi/t5-base-finetuned-question-answering"

In [None]:
## Tokenization code
# Tokenizer that matches CHECKPOINT.
# `return_tensors` is an option of the tokenizer *call*, not of `from_pretrained`,
# so passing it to the loader had no effect on the encoded output; it is omitted here.
checkpoint_tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
def tokenize_and_create_prompt(sample):
    """Build the instruction prompt for one dataset row and tokenize it.

    Args:
        sample: Mapping with a 'question' entry and a 'meta_tags' entry holding
            comma-separated keywords. 'meta_tags' may be None or empty, in
            which case the prompt is built with an empty keyword list.

    Returns:
        Whatever `checkpoint_tokenizer` returns for the prompt, with special
        tokens added.
    """
    question_string = sample['question']
    # Strip whitespace around each tag and drop empty entries, so values such as
    # "visa, schengen," do not render as "visa,  schengen, " in the prompt.
    raw_tags = sample['meta_tags'] or ""
    tags_meta_data = [tag.strip() for tag in raw_tags.split(",") if tag.strip()]
    keywords = ", ".join(tags_meta_data)
    prompt = f"""
    You are an expert in dealing with questions of immigration and international travel. Answer the following question and use the keywords to get some hints about the answer and context
    Question: {question_string} Keywords: {keywords}
    """
    # NOTE(review): only the prompt is tokenized; this function adds no `labels`.
    # Unless the dataset already carries a `labels` column, a seq2seq model has no
    # target to compute a loss against -- confirm where the answers are tokenized.
    tokenized_output = checkpoint_tokenizer(prompt, add_special_tokens=True)
    return tokenized_output

In [None]:
# Read the saved dataset from DATASET_PATH, then apply the prompt tokenizer
# example by example (map is called without batching).
visa_qa_dataset = load_from_disk(dataset_path=DATASET_PATH)
preprocessed_visa_questions = visa_qa_dataset.map(function=tokenize_and_create_prompt)

In [None]:
# Load the pretrained seq2seq model for CHECKPOINT.
# Auto* model classes cannot be constructed directly (calling the class raises);
# `from_pretrained` is the supported way to obtain a model with its weights.
qa_model = AutoModelForSeq2SeqLM.from_pretrained(
    CHECKPOINT,
    config=AutoConfig.from_pretrained(CHECKPOINT),
)


In [None]:
# Batch collator for the seq2seq model, built from the tokenizer and the model.
ques_data_collator = DataCollatorForSeq2Seq(tokenizer=checkpoint_tokenizer, model=qa_model)

train_args = TrainingArguments(
    output_dir="../Model_Checkpoints/closed-generative-qa",
    # --- optimisation schedule ---
    num_train_epochs=1,
    gradient_accumulation_steps=16,
    warmup_steps=500,
    weight_decay=0.01,
    # --- logging / evaluation / checkpoint cadence, all counted in steps ---
    logging_steps=10,
    # NOTE(review): newer transformers releases name this argument `eval_strategy`;
    # confirm against the installed version before renaming.
    evaluation_strategy="steps",
    eval_steps=500,
    # Presumably set this high to avoid intermediate checkpoints -- confirm.
    save_steps=1e6,
)

# Wire the model, arguments, collator and the train/validation splits into a Trainer.
trainer = Trainer(
    model=qa_model,
    args=train_args,
    train_dataset=preprocessed_visa_questions['train'],
    eval_dataset=preprocessed_visa_questions['validation'],
    data_collator=ques_data_collator,
)

# Start fine-tuning; as the cell's last expression, the return value is displayed.
trainer.train()