# Libraries

In [None]:
from datasets import load_dataset,load_metric
from transformers import AutoTokenizer, AutoModelForMultipleChoice, TrainingArguments, Trainer, set_seed
from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch
import numpy as np
from huggingface_hub import notebook_login
from datasets import set_caching_enabled
# Turn off the `datasets` cache so each run of the notebook recomputes the
# tokenization below instead of reusing a previously cached result.
set_caching_enabled(False)

# Set seed 

In [None]:
# Fix the random seed before the model is instantiated so runs are repeatable.
set_seed(42)

# Model and tokenizer

In [None]:
# Checkpoint to fine-tune. Swap the two assignments below to use BioLinkBERT
# instead of the generic BERT baseline.
# model_checkpoint = "michiyasunaga/BioLinkBERT-base"
model_checkpoint = 'bert-base-uncased'

# Load the checkpoint through the multiple-choice auto class.
model = AutoModelForMultipleChoice.from_pretrained(model_checkpoint)
# Alternative without an explicit length cap:
# tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=512) # 512-token cap; long questions are cut to this length when tokenized with truncation=True

# Load preprocessed QA dataset

In [None]:
# The preprocessed question files come from the LinkBERT repository:
# https://github.com/michiyasunaga/LinkBERT

# Map each split name to its JSON file (the dev file is the validation split).
datafiles = {
    "train": "preprocessed_questions/train.json",
    "test": "preprocessed_questions/test.json",
    "validation": "preprocessed_questions/dev.json",
}

qa_dataset = load_dataset("json", data_files=datafiles)

# Tokenize the dataset

In [None]:
# Column names holding the candidate answers; their count is the number of choices.
ending_names = ["ending0", "ending1", "ending2", "ending3"]


def preprocess_function(examples):
    """Tokenize a batch of multiple-choice examples, one sequence pair per choice.

    For every example the context (``sent1``) is paired with the question
    header (``sent2``) followed by each candidate ending, giving
    ``len(ending_names)`` sequence pairs per example.

    Args:
        examples: A batch in column form: ``examples["sent1"]``,
            ``examples["sent2"]`` and each ``examples[<ending name>]`` are
            equally long sequences indexed by example.

    Returns:
        A dict with one entry per tokenizer output field. Each value is a list
        with one item per example, and each item is the list of that field's
        encodings for the example's choices (in ``ending_names`` order).
    """
    num_choices = len(ending_names)

    # Repeat each context once per choice. Built flat in a single pass; the
    # previous `sum(list_of_lists, [])` re-copied the accumulator on every step.
    first_sentences = [
        context for context in examples["sent1"] for _ in range(num_choices)
    ]
    # "<header> <ending>" for every (example, choice), in the same flat order.
    second_sentences = [
        f"{header} {examples[end][i]}"
        for i, header in enumerate(examples["sent2"])
        for end in ending_names
    ]

    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)

    # Regroup the flat encodings into chunks of `num_choices`, one per example.
    return {
        k: [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
        for k, v in tokenized_examples.items()
    }

# Apply the tokenization in batches (preprocess_function receives columns, not single rows).
tokenized_qa = qa_dataset.map(preprocess_function, batched=True)
# Bare expression so the notebook displays the resulting dataset as the cell output.
tokenized_qa


# Data Collator

In [None]:
@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that dynamically pads multiple-choice inputs into one batch.

    Each feature is a dict holding a ``label`` (or ``labels``) entry plus
    tokenizer fields such as ``input_ids``, where every field maps to a list
    with one encoding per answer choice.

    The input features are not modified, so the same feature objects can be
    collated more than once.
    """

    # Tokenizer whose `pad` method pads the flattened encodings.
    tokenizer: PreTrainedTokenizerBase
    # The three settings below are forwarded unchanged to `tokenizer.pad`.
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Pad and stack ``features`` into tensors viewed as (batch, choices, -1).

        Args:
            features: Non-empty list of feature dicts as described on the class.
                The label key is detected from the first feature.

        Returns:
            A dict with every padded tokenizer field viewed as
            ``(batch_size, num_choices, -1)`` and a ``"labels"`` int64 tensor
            with one entry per feature.
        """
        label_name = "label" if "label" in features[0] else "labels"
        # Read the labels without popping them so the caller's dicts stay intact.
        labels = [feature[label_name] for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])

        # One flat dict per (feature, choice), skipping the label entry. Built in
        # a single pass rather than concatenating nested lists with `sum(..., [])`.
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != label_name}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Restore the choice dimension that was flattened for padding.
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch

# Log into the hub

In [None]:
# Authenticate with the Hugging Face Hub. The training arguments below set
# push_to_hub=True and the last cell calls trainer.push_to_hub().

notebook_login()

# Compute metrics for accuracy

In [None]:
# Accuracy metric object, loaded once and reused for every evaluation pass.
metric = load_metric("accuracy")


def compute_metrics(eval_pred):
    """Score an evaluation pass by the accuracy of the top-scoring choice.

    Args:
        eval_pred: A pair ``(predictions, labels)``; the arg-max over axis 1 of
            ``predictions`` is taken as the predicted choice for each example.

    Returns:
        Whatever ``metric.compute`` returns for those predictions and labels.
    """
    scores, references = eval_pred
    chosen = np.argmax(scores, axis=1)
    return metric.compute(predictions=chosen, references=references)

# Training arguments and trainer

In [None]:
# Fine-tuning configuration. Swap the output_dir lines when training the
# BioLinkBERT checkpoint instead of generic BERT.
training_args = TrainingArguments(
    # output_dir="medqa_fine_tuned_linkbert",
    output_dir="medqa_fine_tuned_generic_bert",
    # Optimisation schedule.
    num_train_epochs=5,
    learning_rate=1e-5,
    warmup_steps=100,
    weight_decay=0.01,
    # Batching: 4 examples per device per step, accumulated over 8 steps.
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,
    # Evaluation and checkpointing.
    evaluation_strategy="epoch",
    save_total_limit=1,
    # NOTE(review): the transformers docs describe this field as not used
    # directly by Trainer; to actually resume, trainer.train() would need its
    # own resume_from_checkpoint argument. Confirm the intended behaviour.
    resume_from_checkpoint=True,
    # Hub upload.
    push_to_hub=True,
)


# Connect the model, the tokenized train/validation splits, the
# multiple-choice collator and the accuracy callback.
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    train_dataset=tokenized_qa["train"],
    eval_dataset=tokenized_qa["validation"],
    compute_metrics=compute_metrics,
)

# Training

In [None]:
# Run fine-tuning with the arguments configured above. No resume_from_checkpoint
# argument is passed to this call.
trainer.train()

# Push to hub

In [None]:
# Upload the trained model to the Hugging Face Hub (requires the login performed earlier).
trainer.push_to_hub()