In [None]:
!pip install -q pandas scikit-learn matplotlib datasets transformers wandb seaborn captum ipywidgets tqdm

In [None]:
# Experiment-tracking switch: gates wandb.login() below and is passed to TrainingPipeline as `track`.
TRACK: bool = False

In [None]:
import wandb
# Authenticate with Weights & Biases only when experiment tracking is switched on.
if TRACK:
    wandb.login()

In [None]:
from datasets import load_dataset

# Load the three ARC-Easy splits as pandas DataFrames, in (train, test, validation) order.
train, test, dev = (
    load_dataset("ai2_arc", 'ARC-Easy', split=split_name).to_pandas()
    for split_name in ('train', 'test', 'validation')
)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForMultipleChoice, BertConfig, BertForMultipleChoice
import random
import os
import numpy as np
import logging
# Only ERROR-level (and above) records from transformers' modeling_utils logger are emitted.
logging.getLogger(name="transformers.modeling_utils").setLevel(level=logging.ERROR)

SEED = 2023


def seed_everything(seed=2023, deterministic=False):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Value fed to Python's ``random``, NumPy's global RNG and torch
            (CPU generator and every visible CUDA device).
        deterministic: If True, additionally force deterministic cuDNN kernels
            and disable cuDNN autotuning. This is slower but removes a source of
            run-to-run variation on GPU. Defaults to False (no cuDNN changes).

    Note:
        Setting ``PYTHONHASHSEED`` here only influences Python interpreters
        started after this call; the running interpreter's string-hash seed
        was fixed at startup.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs explicitly (torch.cuda.manual_seed only covers the current one);
    # this is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


seed_everything(SEED)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Model is in device: {device}')

# BioBERT checkpoint (SQuAD-tuned, per its name) loaded with a multiple-choice head;
# swap in e.g. 'bert-base-uncased' to compare against vanilla BERT.
model_name = 'dmis-lab/biobert-base-cased-v1.1-squad'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMultipleChoice.from_pretrained(model_name)
model = model.to(device)
print(f'Model {model.config._name_or_path} is being used.')

In [None]:
from multiple_choice_processor import MultipleChoiceProcessor

BATCH_SIZE = 16
# Number of rows taken from each split. 10 is a quick smoke-test subset: training and
# test metrics on 10 questions are not meaningful. Set to None to use the full ARC-Easy
# splits (`.iloc[:None]` keeps every row).
N_SAMPLES = 10

processor = MultipleChoiceProcessor(
    tokenizer,
    train.iloc[:N_SAMPLES],
    dev.iloc[:N_SAMPLES],
    test.iloc[:N_SAMPLES],
    set='easy',
)
train_loader, val_loader, test_loader = processor.create_datasets(batch_size=BATCH_SIZE, train_batch_size=BATCH_SIZE)

In [None]:
from train_eval import TrainingPipeline
from torch.optim import AdamW
from torch.optim.lr_scheduler import CyclicLR
from transformers import get_scheduler

EPOCHS = 6
LR = 0.0001

optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.3)

# Triangular cyclic schedule between LR/10 and LR whose amplitude halves each cycle.
# step_size_up=2000 is CyclicLR's default, spelled out so it is visible here.
# NOTE(review): assuming TrainingPipeline calls scheduler.step() once per batch, 2000 may
# exceed len(train_loader) * EPOCHS, in which case the LR never reaches max_lr — confirm.
scheduler = CyclicLR(
    optimizer,
    base_lr=LR / 10,
    max_lr=LR,
    step_size_up=2000,
    mode='triangular2',
    cycle_momentum=False,
)

training_pipeline = TrainingPipeline(
    model, device, train_loader, val_loader,
    optimizer=optimizer, scheduler=scheduler,
    track=TRACK, num_epochs=EPOCHS, lr=LR, model_checkpoint=False,
)

(train_loss_list, train_accuracy_list,
 val_loss_list, val_accuracy_list) = training_pipeline.train()

In [None]:
# Confusion matrix of the trained model's predictions. Which split it evaluates and what
# report=True adds are defined in train_eval.TrainingPipeline — check there.
training_pipeline.confusion_matrix(report=True)

Saliency maps: visualizations that highlight the tokens that most influence the model's prediction. They emphasize the most influential parts of the question and of each answer choice, and are produced further below with Captum's Layer Integrated Gradients.

In [None]:
from plot import LossAccuracyPlotter

# Plot the per-epoch loss/accuracy histories returned by training_pipeline.train().
plotter = LossAccuracyPlotter(
    train_loss_list,
    val_loss_list,
    train_accuracy_list,
    val_accuracy_list,
    EPOCHS,
)
plotter.visualize()

In [None]:
# Evaluate the trained model on the test loader; the call returns two values, unpacked
# here as (loss, accuracy).
test_loss, test_accuracy = training_pipeline.test(test_loader)

In [None]:
from transformers import BertTokenizer, BertForQuestionAnswering
from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients
import torch
import gc


def tokenize(tokenizer, question, answer, max_length=45):
    """Encode a (question, answer) pair for Captum attribution on a BERT-style model.

    Args:
        tokenizer: Hugging Face tokenizer exposing ``encode_plus`` and the
            ``cls_token_id`` / ``sep_token_id`` / ``pad_token_id`` attributes.
        question: First text segment.
        answer: Second text segment (a single answer choice).
        max_length: Length every sequence is padded / truncated to.

    Returns:
        dict with:
          * ``input_ids``, ``token_type_ids``, ``attention_mask`` as returned by
            the tokenizer;
          * ``position_ids``: 0..seq_len-1 for every row (same shape as input_ids);
          * ``baseline``: Integrated-Gradients reference input, where every token
            is replaced by the pad id except [CLS] / [SEP];
          * ``answer``: the raw answer text.
    """
    encoded_dict = tokenizer.encode_plus(
        question,
        answer,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
        return_attention_mask=True,
    )
    input_ids = encoded_dict['input_ids']
    cls_token_id = tokenizer.cls_token_id
    sep_token_id = tokenizer.sep_token_id
    pad_token_id = tokenizer.pad_token_id
    # Fall back to id 0 if the tokenizer defines no pad token.
    reference_id = pad_token_id if pad_token_id is not None else 0

    # IG baseline: keep the [CLS]/[SEP] structure, blank out every other token with the pad id.
    baseline = torch.full_like(input_ids, reference_id)
    baseline[input_ids == cls_token_id] = cls_token_id
    baseline[input_ids == sep_token_id] = sep_token_id

    # Positions must count 0..seq_len-1. The attention mask is not a substitute: it would
    # give every real token position 1 and every pad token position 0.
    position_ids = torch.arange(input_ids.size(-1), dtype=torch.long).expand_as(input_ids).contiguous()

    return {
        'input_ids': input_ids,
        'token_type_ids': encoded_dict['token_type_ids'],
        'position_ids': position_ids,
        'attention_mask': encoded_dict['attention_mask'],
        'baseline': baseline,
        'answer': answer
    }
def predict(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    """Run the notebook-global QA ``model`` and return its (start_logits, end_logits) pair."""
    outputs = model(
        inputs,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
    )
    start_logits = outputs.start_logits
    end_logits = outputs.end_logits
    return start_logits, end_logits

def squad_pos_forward_func(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
    """Captum forward function: the best span logit per example.

    ``position`` selects which output of ``predict`` to use (0 = start logits,
    1 = end logits); the maximum over dimension 1 is returned for each row
    (presumably logits are (batch, seq_len) — confirm).
    """
    logits = predict(
        inputs,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
    )[position]
    best_logit, _ = torch.max(logits, dim=1)
    return best_logit
def summarize_attributions(attributions):
    """Collapse per-dimension attributions to one L2-normalised score per token.

    Sums over the last dimension, drops a leading size-1 batch dimension, and
    divides by the L2 norm. An all-zero result is returned unchanged instead
    of being divided by zero, which would yield NaNs.
    """
    summed = attributions.sum(dim=-1).squeeze(0)
    norm = torch.norm(summed)
    if norm == 0:
        return summed
    return summed / norm

def _answer_span(input_ids, token_type_ids, sep_token_id):
    """Return (start, end) indices of the second segment (the answer choice), excluding its [SEP].

    The answer is located by segment id rather than by searching for its last token,
    because that token can also occur in the question (e.g. 'water' and 'cube' below),
    in which case a value search returns the question's position.
    """
    positions = [
        pos
        for pos, (token_id, segment) in enumerate(zip(input_ids[0].tolist(), token_type_ids[0].tolist()))
        if segment == 1 and token_id != sep_token_id
    ]
    if not positions:
        raise ValueError('answer segment is empty after tokenization/truncation')
    return positions[0], positions[-1]


# NOTE(review): this cell rebinds the notebook-wide `model`, `tokenizer` and `device` used by
# the training cells above (predict() reads the global `model`), so re-running earlier cells
# after this one uses the QA model. 'bert-base-uncased' is a plain pre-trained checkpoint, so
# its QA head is presumably freshly initialised; these attributions therefore do not explain
# the fine-tuned multiple-choice model trained above. Confirm, or load a SQuAD-fine-tuned
# checkpoint.
model_path = 'bert-base-uncased'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForQuestionAnswering.from_pretrained(model_path).to(device)
model.eval()
model.zero_grad()

torch.cuda.empty_cache()
gc.collect()

# Alternative example question:
#question = 'A hiker wants to know if air is warmer in a forest than in the nearby farm field. Which activity would best help the hiker find out which area is warmer?'
#answer = [
#    'reading a book about farm fields',
#    'making a weather prediction for the forests',
#    'measuring the wind speed at both locations',
#    'recording the temperatures at both locations'
#]
question = 'A small aluminum cube is dropped into a beaker of water to determine the buoyancy. Which of these is not necessary in determining the buoyant force acting on the cube?'
answer = [
    'density of water',
    'displaced volume',
    'gravitational pull on Earth',
    'density of the aluminum cube'
]

choices = ['A', 'B', 'C', 'D']
print(f'Question: {question}')
print('Answer Choices:')
for i, choice in enumerate(answer):
    print(f'\t[{choices[i]}]: {choice}')

# The attributed layer (the model's embeddings) is the same for every choice, so build it once.
lig = LayerIntegratedGradients(squad_pos_forward_func, layer=model.bert.embeddings)

for i, choice_text in enumerate(answer):
    print(f"Visualizations for Answer {choices[i]}")
    sample = tokenize(tokenizer, question, choice_text)

    input_ids = sample['input_ids'].to(device)
    token_type_ids = sample['token_type_ids'].to(device)
    position_ids = sample['position_ids'].to(device)
    attention_mask = sample['attention_mask'].to(device)
    baseline = sample['baseline'].to(device)

    # Last element of additional_forward_args is `position`: 0 = start logits, 1 = end logits.
    attributions_start, delta_start = lig.attribute(inputs=input_ids,
                                                   baselines=baseline,
                                                   additional_forward_args=(token_type_ids, position_ids, attention_mask, 0),
                                                   return_convergence_delta=True)
    attributions_end, delta_end = lig.attribute(inputs=input_ids,
                                                 baselines=baseline,
                                                 additional_forward_args=(token_type_ids, position_ids, attention_mask, 1),
                                                 return_convergence_delta=True)

    attributions_start_sum = summarize_attributions(attributions_start)
    attributions_end_sum = summarize_attributions(attributions_end)

    # These logits are only displayed, so no autograd graph is needed.
    with torch.no_grad():
        start_scores, end_scores = predict(input_ids, token_type_ids=token_type_ids,
                                           position_ids=position_ids, attention_mask=attention_mask)

    indices = input_ids[0].detach().tolist()
    all_tokens = tokenizer.convert_ids_to_tokens(indices)
    ground_truth_start_ind, ground_truth_end_ind = _answer_span(input_ids, token_type_ids, tokenizer.sep_token_id)

    start_position_vis = viz.VisualizationDataRecord(
                            attributions_start_sum,
                            torch.max(torch.softmax(start_scores[0], dim=0)),
                            torch.argmax(start_scores),
                            torch.argmax(start_scores),
                            str(ground_truth_start_ind),
                            attributions_start_sum.sum(),
                            all_tokens,
                            delta_start)
    end_position_vis = viz.VisualizationDataRecord(
                            attributions_end_sum,
                            torch.max(torch.softmax(end_scores[0], dim=0)),
                            torch.argmax(end_scores),
                            torch.argmax(end_scores),
                            str(ground_truth_end_ind),
                            attributions_end_sum.sum(),
                            all_tokens,
                            delta_end)
    print('\033[1m', f'Visualizations For Start Position - Answer {choices[i]}', '\033[0m')
    viz.visualize_text([start_position_vis])

    print('\033[1m', f'Visualizations For End Position - Answer {choices[i]}', '\033[0m')
    viz.visualize_text([end_position_vis])
