In [None]:
import os
# Set before torch is imported in the next cell. CUDA_VISIBLE_DEVICES="0" is
# intended to expose only the first GPU to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import torch
# Display the name of the current CUDA device as a quick check that a GPU is usable.
# NOTE(review): presumably this raises when no CUDA device is available — confirm.
torch.cuda.get_device_name()

In [None]:
from transformers import set_seed
# Fix the random seed through transformers' set_seed helper, for reproducibility.
set_seed(2)

In [None]:
from datasets import load_dataset

datasets = load_dataset("indic_glue","csqa.bn",split="test")

In [None]:
datasets = datasets.train_test_split(
    train_size=0.9, seed=42
)
datasets

In [None]:
# Show the first training example to inspect the raw fields.
datasets["train"][0]

In [None]:
from transformers import PreTrainedTokenizerFast, AutoModelForSequenceClassification, AutoTokenizer

# Load the fast tokenizer from the local "unigram-long-text" path.
# Alternative path (disabled):
# tokenizer = PreTrainedTokenizerFast.from_pretrained("../Bengali Pretraining/models/unigram/bert-base-pretrained-bengali")
tokenizer = PreTrainedTokenizerFast.from_pretrained("../Bengali Pretraining/models/unigram/unigram-long-text")

In [None]:
# Name of the dataset column that holds the candidate answers.
# NOTE(review): `choice_names` is not referenced anywhere else in the visible notebook.
choice_names = ['options']

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of cloze questions as (question, option) pairs.

    The ``<MASK>`` placeholder of every question is rewritten to ``[MASK]`` and
    the question is paired with each of its candidate options. The flat
    tokenizer output is then regrouped so that every example owns one list of
    encodings, one entry per option.

    Args:
        examples: A batch in column form (as passed by ``datasets.map`` with
            ``batched=True``) providing a ``"question"`` list of strings and an
            ``"options"`` list holding one list of option strings per example.

    Returns:
        A dict mapping every key produced by the tokenizer to a list with one
        entry per example; each entry is the list of per-option values.
    """
    option_lists = examples["options"]
    contexts = [context.replace("<MASK>", "[MASK]") for context in examples["question"]]

    # Repeat each question once per option of *that* example (instead of a
    # hard-coded 4) so questions and options can never drift out of alignment.
    premise = [context for context, options in zip(contexts, option_lists) for _ in options]
    # Flatten with a comprehension; sum(list_of_lists, []) copies the growing
    # result on every step and is quadratic in the batch size.
    cause = [option for options in option_lists for option in options]

    tokenized_examples = tokenizer(premise, cause, truncation=True, max_length=128)

    # Start/stop offsets of every example inside the flat tokenizer output.
    bounds = []
    start = 0
    for options in option_lists:
        stop = start + len(options)
        bounds.append((start, stop))
        start = stop

    return {k: [v[lo:hi] for lo, hi in bounds] for k, v in tokenized_examples.items()}

In [None]:
# Run the preprocessing on a slice holding only the first training example
# to sanity-check its output.
temp = preprocess_function(datasets["train"][:1])
# temp

In [None]:
# Show the raw slice that was passed to preprocess_function, for comparison.
datasets["train"][:1]

In [None]:
# Decode every encoded sequence of the first example back to text to check
# how the question and each option were paired.
for chunk in temp['input_ids'][0]:
    print(tokenizer.decode(chunk))

In [None]:
# Tokenize the whole dataset; batched=True hands preprocess_function batches of examples.
tokenized_datasets = datasets.map(preprocess_function, batched=True)

In [None]:
def assign_label(example):
    """Store the position of the correct option under ``example['label']``.

    Args:
        example: A mapping with an ``'options'`` list and an ``'answer'`` value
            that must be one of those options.

    Returns:
        The same ``example`` object, updated in place with the new ``'label'``.

    Raises:
        ValueError: If ``example['answer']`` is not contained in the options.
    """
    # list.index returns the first matching position and raises if there is none.
    gold_position = example['options'].index(example['answer'])
    example['label'] = gold_position
    return example

In [None]:
# Add the integer "label" column by mapping assign_label over the examples (non-batched).
tokenized_datasets = tokenized_datasets.map(assign_label)

In [None]:
# Spot-check the label assigned to the training example at index 2.
tokenized_datasets["train"][2]["label"]

In [None]:
# Show the raw training record at index 2 (the index whose label is inspected in the previous cell).
datasets["train"][2]

In [None]:
# from transformers import set_seed
# set_seed(30)

In [None]:
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer, AutoModel

# Initialise a multiple-choice model from the local "unigram-long-text" path.
# Alternative path (disabled):
# model = AutoModelForMultipleChoice.from_pretrained("../Bengali Pretraining/models/unigram/bert-base-pretrained-bengali")
model = AutoModelForMultipleChoice.from_pretrained("../Bengali Pretraining/models/unigram/unigram-long-text")

In [None]:
# Reload the fine-tuned weights from "qa_model" only when a saved model is there.
# That directory is produced by the training/saving cells later in this notebook,
# so on a fresh run it does not exist yet and an unconditional load would fail;
# in that case the model loaded in the previous cell is kept.
if os.path.isfile(os.path.join("qa_model", "config.json")):
    model = AutoModelForMultipleChoice.from_pretrained("qa_model")
else:
    print("No saved model found in 'qa_model'; keeping the previously loaded model.")

In [None]:
# Display the tokenized dataset object to review its contents.
tokenized_datasets

In [None]:
from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch


@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that will dynamically pad the inputs for multiple choice received.

    Each incoming feature holds, for every tokenizer key, one sequence per answer
    choice, plus a label stored under ``"label"`` or ``"labels"``. All choices of
    all examples are flattened into one list, padded together with
    ``tokenizer.pad`` and reshaped to ``(batch_size, num_choices, -1)``.
    """

    # Tokenizer whose ``pad`` method pads the flattened choices.
    tokenizer: PreTrainedTokenizerBase
    # The three options below are forwarded unchanged to ``tokenizer.pad``.
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Pad a list of multiple-choice features and stack them into tensors.

        Args:
            features: Non-empty list of dicts. Besides the label entry, every
                value must be indexable by choice position; the number of
                choices is taken from the first feature's ``"input_ids"``.

        Returns:
            A dict with every padded tokenizer output viewed as
            ``(batch_size, num_choices, -1)`` and a ``"labels"`` int64 tensor.
        """
        label_name = "label" if "label" in features[0].keys() else "labels"
        # Read the labels instead of popping them so the caller's feature dicts
        # are left untouched and the same features can be collated again.
        labels = [feature[label_name] for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])

        # One dict per (example, choice), example-major. Built in a single
        # comprehension rather than with sum(list_of_lists, []), which re-copies
        # the accumulated list for every example.
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != label_name}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Restore the (example, choice) structure that flattening removed.
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch

In [None]:
# Collator instance bound to the notebook's tokenizer.
# NOTE(review): the Trainer cell further down builds its own collator instead of using this variable.
data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer)

In [None]:
import evaluate

# Load the "accuracy" metric through the evaluate library; compute_metrics calls it.
accuracy = evaluate.load("accuracy")

In [None]:
import numpy as np


def compute_metrics(eval_pred):
    """Score evaluation outputs with the notebook-level ``accuracy`` metric.

    Args:
        eval_pred: A pair ``(predictions, labels)``. The predictions are reduced
            with an argmax over axis 1 before scoring (assumes one row of
            per-choice scores per example — TODO confirm).

    Returns:
        The result of ``accuracy.compute`` for the argmax predictions and labels.
    """
    scores, references = eval_pred
    chosen = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=chosen, references=references)

In [None]:
# Disable Weights & Biases logging; set before the Trainer is created in the next cell.
import os
os.environ["WANDB_DISABLED"] = "true"

In [None]:
# Fine-tuning configuration. Evaluation and checkpointing both run once per epoch,
# and the checkpoint with the best 'accuracy' is reloaded when training ends.
training_args = TrainingArguments(
    output_dir="qa_model",
    # -- schedule --
    num_train_epochs=4,               # alternative value noted by the author: 3
    evaluation_strategy="epoch",
    save_strategy="epoch",
    # -- optimisation --
    learning_rate=2e-5,               # alternative value noted by the author: 3e-5
    weight_decay=0.01,                # alternative value noted by the author: 0.04
    # warmup_ratio=0.1 is left disabled
    fp16=True,
    # -- batching --
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    # -- model selection --
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
)

# Dedicated collator instance for this Trainer.
multiple_choice_collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)

# The "test" split serves as the evaluation set during training.
trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=multiple_choice_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    compute_metrics=compute_metrics,
)

In [None]:
# Evaluate before training to get a baseline for the currently loaded model.
trainer.evaluate()

In [None]:
# Fine-tune the model with the settings configured in training_args.
trainer.train()

In [None]:
# Persist the trained model.
# NOTE(review): called without a path, so presumably it writes to the Trainer's
# output_dir ("qa_model") — confirm.
trainer.save_model()

In [None]:
y_preds, y_true, _ = trainer.predict(tokenized_datasets["test"])

In [None]:
y_preds = np.argmax(y_preds, axis=-1)

In [None]:
from sklearn.metrics import classification_report
# Display names for the four answer positions; used as tick labels by the confusion-matrix cell.
target_names = ['choice1', 'choice2','choice3','choice4']

In [None]:
import matplotlib.pyplot as plt
from seaborn import heatmap
from sklearn.metrics import confusion_matrix

#plot heatmap of confusion matrix
# Pin the label set to every choice index so the matrix always has one row and
# column per entry of target_names, even if some choice never occurs in
# y_true or y_preds (otherwise the tick labels would no longer line up).
mat = confusion_matrix(y_true, y_preds, labels=list(range(len(target_names))))
heatmap(mat, cmap="Pastel1_r", fmt="d", xticklabels=target_names, yticklabels=target_names, annot=True)

# confusion_matrix puts the true classes on the rows and the predictions on the columns.
plt.xlabel('Predicted choice')
plt.ylabel('True choice')

#add overall title to plot
plt.title('Confusion matrix for Clozed QA', fontsize = 12) # title with fontsize 12

In [None]:
# Test-set positions where the true label is choice index 1 but the model predicted index 0.
misclassified = [
    position
    for position, predicted in enumerate(y_preds)
    if y_true[position] == 1 and predicted == 0
]

In [None]:
# Keep only the tokenized test rows whose positions are listed in `misclassified`.
misclassified_dataset = tokenized_datasets['test'].select(misclassified)

In [None]:
# Look up the test-set position stored at index `idx` of `misclassified`.
idx=0
misclassified[idx]

In [None]:
# Inspect one record of the misclassified subset.
# NOTE(review): the index 33 is hard-coded and the `idx` variable from the previous
# cell is not used here; presumably this fails when the subset has fewer than
# 34 rows — confirm.
misclassified_dataset[33]
# misclassified_dataset[30:60]['title']