In [1]:
# Import the necessary libraries
from datasets import load_dataset, Dataset
from transformers import T5Tokenizer, DataCollatorForSeq2Seq
from transformers import T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
from normalizer import normalize
import torch

In [None]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Report whether CUDA is available
print(torch.cuda.is_available())

In [None]:
# Load the conversational QA dataset registered under "arbitropy/bcoqa".
# Later cells read its 'train' and 'validation' splits.
dataset_coqa = load_dataset("arbitropy/bcoqa")

In [None]:
prefix = "Continue conversation:\n"

def create_history_format(example, history, turnCount):
    """
    Build the model prompt for one question of a conversation.

    The prompt consists of the module-level ``prefix``, the context, an
    optional ``History:`` section listing earlier turns as ``+question`` /
    ``-answer`` lines, and finally the current question followed by an open
    ``-`` line for the answer.

    Args:
        example (dict): Provides the 'context' and the current 'question'.
        history (list): Earlier turns, each with 'question' and 'answer' keys.
        turnCount (int): How many of the most recent turns to include. 0, or a
            value larger than the history length, includes every turn.

    Returns:
        str: The assembled prompt.

    """
    total_turns = len(history)
    parts = [prefix, f"Context:\n{example['context']}\n"]
    if total_turns != 0:
        parts.append("History:\n")
        # Index of the oldest turn to keep: everything when no limit applies,
        # otherwise only the trailing `turnCount` turns.
        if turnCount == 0 or total_turns < turnCount:
            oldest_kept = 0
        else:
            oldest_kept = total_turns - turnCount
        for turn_index in range(oldest_kept, total_turns):
            turn = history[turn_index]
            parts.append(f"+{turn['question']}\n-{turn['answer']}\n")
    # The current question goes last, with the answer line left open.
    parts.append(f"+{example['question']}\n-")
    return "".join(parts)

def denest_element(example):
    """
    Split one conversation into a separate item per question/answer turn.

    Each produced item keeps the conversation id, the turn's question and
    answer, and a 'context' that holds the full prompt: the story, every
    earlier turn of the same conversation, and the current question.

    Args:
        example (dict): A whole conversation with 'id', 'story', 'questions'
            and 'answers' entries.

    Returns:
        list: One dict per turn, in conversation order.

    """
    flattened = []
    for turn_index, question_entry in enumerate(example['questions']):
        raw_turn = {
            'id': example['id'],
            'context': example['story'],
            'question': question_entry['question'],
            'answer': example['answers'][turn_index]['answer'],
        }
        # Items already collected serve as the history; a turn count of zero
        # asks for the entire history to be included in the prompt.
        prompt = create_history_format(raw_turn, flattened, 0)
        # Store the turn with the story replaced by the finished prompt.
        flattened.append({**raw_turn, 'context': prompt})
    return flattened

def denest_dataset_with_context(dataset):
    """
    Flatten a dataset of conversations into a dataset of single turns.

    Every conversation is expanded with ``denest_element`` and the resulting
    items are concatenated in their original order.

    Args:
        dataset (dataset): Iterable of whole conversations.

    Returns:
        dataset: A ``Dataset`` built from the flattened list of items.

    """
    flat_items = []
    for conversation in dataset:
        flat_items.extend(denest_element(conversation))
    return Dataset.from_list(flat_items)

In [None]:
# Flatten every conversation into one row per question/answer turn; each row's
# 'context' field holds the full prompt (story + earlier turns + current question).
train_dataset_coqa = denest_dataset_with_context(dataset_coqa['train'])
valid_dataset_coqa = denest_dataset_with_context(dataset_coqa['validation'])


In [None]:
# Load the tokenizer and the seq2seq model from the same pretrained checkpoint.
model_checkpoint = "csebuetnlp/banglat5"
tokenizer = T5Tokenizer.from_pretrained(model_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)
# The collator receives both the tokenizer and the model.
# NOTE(review): presumably it pads inputs and labels per batch — confirm in the DataCollatorForSeq2Seq docs.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [None]:
def tokenize_label(examples):
    """
    Turn a prompt/answer pair into tokenized model inputs with labels.

    Both texts are passed through ``normalize`` before tokenization. The
    prompt (the 'context' field) is truncated to 1024 tokens and the answer
    to 128 tokens; the answer's token ids become the 'labels' entry.

    Args:
        examples (dict): Provides the 'context' and 'answer' fields.
            NOTE(review): this notebook maps the function without
            ``batched=True``, so presumably a single row is passed — confirm
            before enabling batching.

    Returns:
        dict: The tokenized prompt with an added 'labels' entry.

    """
    # Encode the normalized prompt as the model input.
    encoded = tokenizer(
        normalize(examples['context']),
        max_length=1024,
        truncation=True,
    )
    # Encode the normalized answer as the target side.
    target_encoding = tokenizer(
        text_target=normalize(examples["answer"]),
        max_length=128,
        truncation=True,
    )
    encoded["labels"] = target_encoding["input_ids"]
    return encoded

In [None]:
# Tokenize both splits with `tokenize_label`. `map` is called without
# `batched=True`, so the function is applied to one row at a time.
train_tokenized_coqa = train_dataset_coqa.map(tokenize_label)
valid_tokenized_coqa = valid_dataset_coqa.map(tokenize_label)

In [None]:
# Per-device batch size, shared by training and evaluation below.
batch_size = 8
# Training configuration: evaluation and checkpointing follow the same
# 10000-step schedule, at most two checkpoints are kept on disk, and the best
# checkpoint is reloaded once training finishes.
# NOTE(review): no compute_metrics is passed to the trainer, so best-model
# selection presumably falls back to the evaluation loss — confirm.
training_args = Seq2SeqTrainingArguments(
    output_dir="bcoqa-bt5",
    evaluation_strategy="steps",
    eval_steps = 10000,
    save_strategy = 'steps',
    save_steps = 10000,
    optim="adafactor",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    save_total_limit=2,
    num_train_epochs=2,
    predict_with_generate=True,
    load_best_model_at_end=True,
)

# Wire the model, the tokenized splits, the tokenizer and the collator
# into a single trainer object.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized_coqa,
    eval_dataset=valid_tokenized_coqa,
    tokenizer=tokenizer,
    data_collator=data_collator,
)



In [None]:
# Train the model: run fine-tuning with the trainer configured in the previous cell.
trainer.train()

In [None]:
# Write the trainer's model to 'bcoqa-bt5', the same path used as output_dir.
# NOTE(review): with load_best_model_at_end=True this is presumably the best
# checkpoint rather than the last training step — confirm.
trainer.save_model('bcoqa-bt5')