In [None]:
import json
from datasets import load_dataset
from transformers import AutoTokenizer

from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch

from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer
from transformers.optimization import AdamW

import numpy as np

from datasets import load_metric

In [None]:
# Directory holding the raw and converted *.jsonl files. No trailing slash:
# paths are built as f"{FILE_DIR}/<name>", which would otherwise yield "data//<name>".
FILE_DIR = 'data'
# NOTE(review): MODEL_DIR is not referenced by any visible cell.
MODEL_DIR = '.'

In [None]:
def convert_dataset_MC(input_filepath, output_filepath):
    """
    Convert a JSONL file of fill-in-the-blank questions into the multiple-choice
    layout read by ``preprocess_function`` ('sent1', 'sent2', 'ending0', ...).

    Each non-blank input line must be a JSON object with:
      - 'qID': question id,
      - 'sentence': text containing exactly one '_' blank,
      - 'option1' / 'option2': the two candidate fillers,
      - optional 'answer': 1-based index of the correct option.

    For each question one JSON line is appended to ``output_filepath``. Each line has:
      - 'qID',
      - 'sent1' and 'startphrase': the text before the blank,
      - 'sent2': an empty string,
      - 'ending0' / 'ending1': option1 / option2 followed by the text after the blank,
      - 'label': a 0-based label.

    The output file is opened in APPEND mode. Truncate it first if it may already
    exist, otherwise rows are duplicated.

    Raises:
        ValueError: if a sentence does not contain exactly one '_'.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a required field is missing.
    """
    # Open the output once for the whole conversion instead of once per row.
    with open(input_filepath, 'r', encoding="utf-8") as ip, \
            open(output_filepath, 'a', encoding="utf-8") as op:
        for line_no, row in enumerate(ip, start=1):
            if not row.strip():
                continue  # tolerate blank / trailing empty lines
            data = json.loads(row)
            parts = data['sentence'].split("_")
            if len(parts) != 2:
                raise ValueError(
                    f"line {line_no} (qID={data.get('qID')!r}): expected exactly one '_' "
                    f"blank in sentence, found {len(parts) - 1}"
                )
            sent_part1, sent_part2 = parts
            sentence = {
                'qID': data['qID'],
                'sent1': sent_part1,
                'sent2': '',
                'startphrase': sent_part1,
                'ending0': data['option1'] + sent_part2,
                'ending1': data['option2'] + sent_part2,
            }
            ans = data.get('answer')
            if ans:
                sentence['label'] = int(ans) - 1
            else:
                # NOTE(review): with no (or an empty) answer, the label is taken from the
                # numeric suffix of qID ("...-N" -> N-1). This looks like a placeholder
                # for unlabeled splits -- confirm it is intended.
                sentence['label'] = int(data['qID'].split("-")[-1]) - 1
            op.write(json.dumps(sentence) + "\n")

In [None]:
# Convert every split. convert_dataset_MC appends to its output file, so each
# output is truncated first; this keeps the cell idempotent when re-run.
data = ['train', 'test', 'dev']
for d in data:
    input_file = f"{FILE_DIR}/{d}.jsonl"
    output_file = f"{FILE_DIR}/MC_converted_{d}.jsonl"
    open(output_file, 'w', encoding='utf-8').close()
    convert_dataset_MC(input_file, output_file)

In [None]:
# Columns holding the candidate endings written by the conversion step
# ("ending0", "ending1"); one column per answer choice.
ending_names = [f"ending{idx}" for idx in range(2)]
# Tokenizer; must match the checkpoint the model is loaded from.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [None]:

def preprocess_function(examples, tokenize_fn=None, endings=None):
    """
    Tokenize a batch of multiple-choice examples (intended for ``Dataset.map(batched=True)``).

    Each choice of each example is encoded as the pair
    (``sent1``, ``f"{sent2} {ending}"``). The flat tokenizer output is then
    regrouped so that every returned field holds one list per example, with one
    entry per choice.

    Args:
        examples: batch dict mapping column name -> list of values. It must contain
            "sent1", "sent2" and every column named in ``endings``.
        tokenize_fn: callable ``(first, second, truncation=True)`` returning a mapping
            of field -> list with one entry per pair. Defaults to the module-level
            ``tokenizer``.
        endings: names of the ending columns. Defaults to the module-level
            ``ending_names``. The number of choices is ``len(endings)``.

    Returns:
        dict mapping each tokenizer field to a list (per example) of lists (per choice).

    Raises:
        ValueError: if ``endings`` is empty.
    """
    if tokenize_fn is None:
        tokenize_fn = tokenizer
    if endings is None:
        endings = ending_names
    num_choices = len(endings)
    if num_choices == 0:
        raise ValueError("at least one ending column is required")

    # Flatten to one (first, second) pair per (example, choice), example-major.
    # Comprehensions avoid the quadratic sum(list_of_lists, []) flatten.
    first_sentences = [
        context for context in examples["sent1"] for _ in range(num_choices)
    ]
    second_sentences = [
        f"{header} {examples[end][i]}"
        for i, header in enumerate(examples["sent2"])
        for end in endings
    ]

    tokenized_examples = tokenize_fn(first_sentences, second_sentences, truncation=True)

    # Regroup the flat per-pair outputs into chunks of num_choices (one chunk per example).
    return {
        k: [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
        for k, v in tokenized_examples.items()
    }

# Load the converted JSONL splits (the dev file becomes the "validation" split).
d = load_dataset(
    'json',
    data_files={
        'train': f'{FILE_DIR}/MC_converted_train.jsonl',
        'validation': f'{FILE_DIR}/MC_converted_dev.jsonl',
        'test': f'{FILE_DIR}/MC_converted_test.jsonl',
    },
)
# Tokenize every split in batches with preprocess_function.
tokenized_d = d.map(preprocess_function, batched=True)

In [None]:
@dataclass
class DataCollatorForMultipleChoice:
    """
    Collate tokenized multiple-choice examples into padded tensors.

    Each feature is a dict whose tokenizer fields (e.g. "input_ids") hold one
    sequence per choice, plus an optional "label" / "labels" entry. The choices are
    flattened, padded together with ``tokenizer.pad``, and reshaped back to
    ``(batch_size, num_choices, -1)``.

    ``padding``, ``max_length`` and ``pad_to_multiple_of`` are forwarded to
    ``tokenizer.pad``.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """
        Build a batch from a list of feature dicts.

        Returns:
            dict of tensors viewed as ``(batch_size, num_choices, -1)``. When the
            features carry labels, it also contains an int64 "labels" tensor with
            one entry per feature.
        """
        first = features[0]
        if "label" in first:
            label_name = "label"
        elif "labels" in first:
            label_name = "labels"
        else:
            label_name = None  # unlabeled features: no "labels" tensor is produced

        batch_size = len(features)
        num_choices = len(first["input_ids"])

        # One dict per (example, choice), example-major. The label is skipped here
        # rather than popped, so the caller's feature dicts are left unmodified.
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != label_name}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Restore the per-example choice dimension.
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}

        if label_name is not None:
            batch["labels"] = torch.tensor(
                [feature[label_name] for feature in features], dtype=torch.int64
            )
        return batch

In [None]:
def compute_metrics(eval_predictions):
    """
    Accuracy callback for ``Trainer(compute_metrics=...)``.

    Args:
        eval_predictions: pair ``(predictions, label_ids)``. The predicted choice
            for each row is the argmax over axis 1 of ``predictions``.

    Returns:
        dict with key "accuracy": the fraction of rows whose argmax equals the label.
    """
    scores, references = eval_predictions
    chosen = np.argmax(scores, axis=1)
    hits = (chosen == references).astype(np.float32)
    return {"accuracy": hits.mean().item()}

In [None]:
# Base pre-trained model with a multiple-choice head.
model = AutoModelForMultipleChoice.from_pretrained("distilbert-base-uncased")
# AdamW optimizer, later handed to the Trainer through `optimizers=`.
# torch.optim.AdamW replaces the deprecated transformers.optimization.AdamW;
# with weight_decay=0.0 both apply the same Adam update (they differ only in
# where eps enters the denominator).
# NOTE: because this optimizer is passed to the Trainer explicitly, the
# learning_rate / weight_decay set in TrainingArguments are not applied to it.
# NOTE(review): eps=1e-30 is far below the usual 1e-8, and lr=1e-3 is high for
# fine-tuning DistilBERT -- confirm both are intentional.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    eps=1e-30,
    weight_decay=0.0,
)


In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
# Training configuration. An optimizer is passed explicitly via `optimizers=`
# below, so the learning_rate / weight_decay given here are not applied to it.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` and `tokenizer=` to `processing_class=`; keep these in sync
# with the pinned library version.
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",  # evaluate on eval_dataset after every epoch
    num_train_epochs=4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=5e-4,
    weight_decay=0.01,
)

# Trainer wiring: padded multiple-choice batches, the accuracy callback, and the
# explicit AdamW optimizer. The scheduler slot is None, so the Trainer builds its own.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_d["train"],
    eval_dataset=tokenized_d["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
    optimizers=(optimizer, None),
)
  

In [None]:
def compute_test(trainer, data):
    """
    Compute the accuracy of ``trainer``'s model on ``data``.

    Runs ``trainer.predict(data)``, takes the highest-scoring choice per example
    (argmax over the last axis), and compares it with the returned label ids.

    Args:
        trainer: object with a ``predict(data)`` method (e.g. ``transformers.Trainer``).
            The result of ``predict`` must expose ``predictions`` and ``label_ids``.
        data: dataset accepted by ``trainer.predict``.

    Returns:
        dict ``{"accuracy": float}``: the fraction of examples predicted correctly.

    Raises:
        ValueError: if the prediction output carries no labels.
    """
    # Computed with numpy instead of datasets.load_metric("accuracy"), which is
    # deprecated (removed in datasets 3.0) and loads a metric script.
    output = trainer.predict(data)
    if output.label_ids is None:
        raise ValueError("dataset has no labels; accuracy cannot be computed")
    preds = np.argmax(output.predictions, axis=-1)  # highest score = predicted choice
    labels = np.asarray(output.label_ids)
    return {"accuracy": float(np.mean(preds == labels))}

In [None]:
# Baseline accuracy of the not-yet-fine-tuned model on each split, then fine-tune.
for split_title, split_name in (("Train Dataset", "train"), ("Evaluation Dataset", "validation")):
    print(split_title, compute_test(trainer, tokenized_d[split_name]))
trainer.train()

In [None]:
# Accuracy on each split after fine-tuning, for comparison with the baseline.
for split_title, split_name in (("Train Dataset", "train"), ("Evaluation Dataset", "validation")):
    print(split_title, compute_test(trainer, tokenized_d[split_name]))