In [None]:
# Pinned versions keep the run reproducible. %pip (rather than !pip) installs into the
# environment of the running kernel instead of whatever `pip` the subshell resolves to.
%pip install -q transformers==4.34.0 datasets==2.14.5 accelerate==0.23.0 evaluate==0.4.1 peft==0.5.0

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m34.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.6/519.6 kB[0m [31m26.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m258.1/258.1 kB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.6/85.6 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m311.2/311.2 kB[0m [31m28.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m87.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m69.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import get_peft_model, LoraConfig, TaskType

# Checkpoint identifier; the same string is reused later to build the Trainer output directory name.
model_name = 'google/flan-t5-base'
# Tokenizer and seq2seq model are both loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Optional LoRA (PEFT) wrapping, currently disabled.
# NOTE(review): the earlier draft of this config listed target_modules=["query", "vavlue"];
# "vavlue" is a typo, and T5 attention projections are presumably named "q"/"v" -- confirm
# against the model's module names before enabling.
# peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, r=8, lora_alpha=32, lora_dropout=0.05, target_modules=["q", "v"], bias="none")
# model = get_peft_model(model, peft_config)

In [None]:
# modified_MCQA (MedMCQA & FrenchMedMCQA): fetch the archive with gdown, then unpack it.
!gdown 1e5FHk1tmAy_VmyfgTi42lnzvHUM2Zmo0
# -o overwrites already-extracted files without prompting, so re-running this cell
# does not stall waiting for interactive input.
!unzip -o modified_MCQA.zip

In [None]:
from datasets import load_dataset
from datasets import DatasetDict

# Directory that holds the extracted MCQA JSON files.
data_dir = '.'

# Each file is loaded on its own; the 'json' builder exposes a single file under the
# 'train' key, so that key is selected for both the training and the dev file.
train_split = load_dataset('json', data_files=f'{data_dir}/MCQA_train.json')['train']
valid_split = load_dataset('json', data_files=f'{data_dir}/MCQA_dev.json')['train']

# Bundle both splits so they can be mapped / indexed together downstream.
raw_dataset = DatasetDict({'train': train_split, 'valid': valid_split})

In [None]:
# Dataloader
import torch
def id_labeling(num_opts):
    """Build the id <-> label lookup tables for questions with ``num_opts`` options.

    A label is a lowercase letter string naming the set of correct options
    (e.g. ``'a'``, ``'bd'``, ``'abcde'``). For 3, 4 and 5 options every non-empty
    combination of letters is a label; for 2 options only ``'a'`` and ``'b'`` are.

    Args:
        num_opts: Number of answer options; must be 2, 3, 4 or 5.

    Returns:
        A tuple ``(id2label, label2id, num_labels)`` where ``id2label`` maps the
        integer class id to its label string, ``label2id`` is the inverse mapping
        and ``num_labels`` is the number of labels.

    Raises:
        ValueError: If ``num_opts`` is not a supported option count. The function
            used to fall through and return ``None`` here, which only surfaced
            later as an unpacking error at the call site.
    """
    option_dict = {2: ['a', 'b'],
                   3: ['a', 'b', 'c', 'ab', 'ac', 'bc', 'abc'],
                   4: ['a', 'b', 'c', 'd', 'ab', 'ac', 'ad', 'bc', 'bd', 'cd', 'abc', 'abd', 'acd', 'bcd', 'abcd'],
                   5: ['a', 'b', 'c', 'd', 'e', 'ab', 'ac', 'ad', 'ae', 'bc', 'bd', 'be', 'cd', 'ce', 'de', 'abc', 'abd', 'abe', 'acd',
                       'ace', 'ade', 'bcd', 'bce', 'bde', 'cde', 'abcd', 'abce', 'abde', 'acde', 'bcde', 'abcde']}
    try:
        label_list = option_dict[num_opts]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments such as lists.
        raise ValueError(
            f"Unsupported number of options: {num_opts!r}; expected one of {sorted(option_dict)}"
        ) from None

    # Class ids follow the position of each label in the table above.
    id2label = dict(enumerate(label_list))
    label2id = {label: idx for idx, label in id2label.items()}
    return id2label, label2id, len(id2label)

def preprocess_function(examples, max_seq_length, tokenizer, id_to_label=None):
    """Turn a batch of MCQA records into tokenized seq2seq inputs and label ids.

    For every record a text prompt is built from the question, the available
    answer options and the context; the target text is the label string that
    corresponds to the record's integer ``label``.

    Args:
        examples: Batch as a mapping of column name to list of values. The columns
            ``question``, ``context``, ``answer_a`` .. ``answer_e`` and ``label``
            are read.
        max_seq_length: Length the prompts are padded / truncated to.
        tokenizer: Callable tokenizer exposing ``pad_token_id``.
        id_to_label: Optional mapping from integer class id to label string. When
            omitted, the notebook-level ``id2label`` (created by ``id_labeling``)
            is used, so that global must exist before this function runs.

    Returns:
        The tokenizer output for the prompts with an added ``"labels"`` entry in
        which padding token ids are replaced by ``-100``.
    """
    if id_to_label is None:
        # Backward-compatible fallback to the mapping defined in a later notebook cell.
        id_to_label = id2label

    option_columns = ("answer_a", "answer_b", "answer_c", "answer_d", "answer_e")
    rows = zip(examples["question"], examples["context"], examples["label"],
               *(examples[column] for column in option_columns))

    sentences = []
    targets = []
    for question, context, label, *options in rows:
        choices = ''
        for opt in options:
            # Skip missing options: None (e.g. a JSON null) as well as NaN / the
            # literal string 'nan'. Without the None check the concatenation below
            # raises TypeError.
            if opt is None or str(opt) == 'nan':
                continue
            choices += ". \n " + opt
        # The prompt wording is kept exactly as before so that inputs stay comparable
        # with earlier runs.
        sentences.append(f"Question: {question}. Choice the correct answers from: {choices}. Context: {context}.")
        targets.append(id_to_label[int(label)])

    model_inputs = tokenizer(sentences,
                             padding="max_length",
                             max_length=max_seq_length,
                             truncation=True)
    tokenized_targets = tokenizer(targets, padding=True)

    # Padding positions in the targets are replaced by -100, the same value this
    # notebook later configures as label_pad_token_id for the data collator.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [[(token if token != pad_id else -100) for token in sequence]
                              for sequence in tokenized_targets["input_ids"]]
    return model_inputs

In [None]:
from functools import partial
# Label vocabulary for questions with five answer options; these module-level names
# are read by the preprocessing and metric helpers.
id2label, label2id, num_labels = id_labeling(5)

# Bind the fixed arguments so that `map` only has to supply the batch.
tokenize_batch = partial(preprocess_function, max_seq_length=256, tokenizer=tokenizer)

# Raw columns dropped after tokenization so that only model inputs remain.
raw_columns = ['id', 'question', 'answer_a', 'answer_b', 'answer_c',
               'answer_d', 'answer_e', 'label', 'context', 'bert_text']

processed_dataset = raw_dataset.map(
    tokenize_batch,
    batched=True,
    load_from_cache_file=False,
    remove_columns=raw_columns,
    desc="Running tokenizer on dataset",
)

In [None]:
#Metric
import numpy as np
import evaluate
from transformers import EvalPrediction

def postprocess_text(predictions, labels, label_to_id=None):
    """Convert decoded prediction / reference strings into integer class ids.

    Args:
        predictions: Decoded model outputs (strings).
        labels: Decoded reference labels (strings); each must be a known label.
        label_to_id: Optional mapping from label string to class id. When omitted,
            the notebook-level ``label2id`` is used, so that global must exist
            before this function runs.

    Returns:
        A tuple ``(prediction_ids, label_ids)`` of two integer lists. A prediction
        that is not a known label is mapped to ``-100``.

    Raises:
        KeyError: If a reference label is not present in the mapping.
    """
    if label_to_id is None:
        # Backward-compatible fallback to the mapping defined in another notebook cell.
        label_to_id = label2id

    # Integer sentinel for generations that are not a valid label. Earlier this was
    # the *string* '-100', which mixed str and int in the list handed to the metric.
    # The reference ids produced here are the mapping's values, so with the
    # non-negative ids from id_labeling a -100 prediction always counts as wrong.
    invalid_id = -100

    prediction_ids = [label_to_id.get(prediction.strip(), invalid_id) for prediction in predictions]
    label_ids = [label_to_id[label.strip()] for label in labels]
    return prediction_ids, label_ids

def load_metric(metric_name):
    """Load an ``evaluate`` metric by name.

    Only ``"accuracy"`` and ``"f1"`` are recognised; any other name yields ``None``.
    """
    for known_name in ("accuracy", "f1"):
        if metric_name == known_name:
            return evaluate.load(known_name)
    return None

def seq2seq_compute_metrics(tokenizer, metric):
    """Create the ``compute_metrics`` callback used for generation-based evaluation.

    Args:
        tokenizer: Tokenizer used to decode generated ids and reference label ids.
        metric: Object with a ``compute(predictions=..., references=...)`` method.

    Returns:
        A function that takes an ``EvalPrediction`` and returns the result of
        ``metric.compute`` on the class ids derived from the decoded texts.
    """
    def compute_metrics(eval_pred: EvalPrediction):
        predictions, labels = eval_pred
        if isinstance(predictions, tuple):
            predictions = predictions[0]

        # -100 cannot be decoded. Besides the labels, the generated ids may also
        # contain it: presumably the Trainer pads predictions with -100 when it
        # concatenates batches of different lengths -- harmless if it never occurs.
        predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

        decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Map the decoded strings to class ids before handing them to the metric.
        prediction_ids, label_ids = postprocess_text(decoded_predictions, decoded_labels)
        return metric.compute(predictions=prediction_ids, references=label_ids)

    return compute_metrics

In [None]:
#Trainer
import transformers
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments

# --- Run configuration -------------------------------------------------------
label_pad_token_id = -100          # same value the preprocessing uses for padded label positions
task = "MedMCQA&FrenchMCQA"        # only used to build the output directory name
EPOCHS = 3

# --- Metric ------------------------------------------------------------------
metric = load_metric("accuracy")
compute_metrics = seq2seq_compute_metrics(tokenizer, metric)

# --- Batching ----------------------------------------------------------------
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8,
)

# --- Training arguments --------------------------------------------------------
# Evaluation and checkpointing share the same step interval, and the checkpoint
# with the best 'accuracy' is reloaded when training finishes.
training_args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned-{task}-v{1}",
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    evaluation_strategy='steps',
    eval_steps=2000,
    save_strategy='steps',
    save_steps=2000,
    save_total_limit=EPOCHS,
    predict_with_generate=True,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
)

# --- Trainer -------------------------------------------------------------------
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset['train'],
    eval_dataset=processed_dataset['valid'],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Start fine-tuning with the trainer built in the previous cell.
trainer.train()