In [1]:
import json
import torch
from transformers import BertTokenizer, BertForMultipleChoice
import json
from transformers import RobertaTokenizer, RobertaForMultipleChoice
from word2number import w2n

In [2]:
def load_data(file_path):
    with open(file_path, 'r') as f:
        return json.load(f)

In [3]:
def digit_based_reframing(text):
    #input reframing by converting textual numbers into their digit based representation
    words = text.split()
    for i, word in enumerate(words):
        try:
            words[i] = str(w2n.word_to_num(word))
        except ValueError:  # Ignore non numeric words
            pass
    return ' '.join(words)


In [4]:
# Preprocessing Function, using reframing type as a parameter
def preprocess_data(data, reframing_type="original"):
    questions, choices, labels = [], [], []
    
    for item in data:
        if reframing_type == "original":
            question = item['question']
        elif reframing_type == "scientific":
            question = item['question_sci_10E']
        elif reframing_type == "digit_based":
            question = digit_based_reframing(item['question'])
        else:
            raise ValueError("Invalid reframing type!")

        questions.append(question)
        choices.append([item['Option1'], item['Option2']])
        labels.append(0 if item['answer'] == 'Option 1' else 1)
    
    return questions, choices, labels

In [5]:
# Load datasets
# train_data = load_data("QQA_train.json")
dev_data = load_data("QQA_dev.json")
# test_data = load_data("QQA_test.json")

In [6]:
# Preprocess datasets without any reframing
# train_qs_orig, train_choices_orig, train_labels_orig = preprocess_data(train_data, "original")
dev_qs_orig, dev_choices_orig, dev_labels_orig = preprocess_data(dev_data, "original")

In [7]:
# Preprocess datasets based on scientific notation reframing
# train_qs_sci, train_choices_sci, train_labels_sci = preprocess_data(train_data, "scientific")
dev_qs_sci, dev_choices_sci, dev_labels_sci = preprocess_data(dev_data, "scientific")

In [8]:
# Preprocess datasets based on digit based input reframing
# train_qs_digit, train_choices_digit, train_labels_digit = preprocess_data(train_data, "digit_based")
dev_qs_digit, dev_choices_digit, dev_labels_digit = preprocess_data(dev_data, "digit_based")

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [10]:
# Load BERT and RoBERTa tokenizer and pretrained model
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertForMultipleChoice.from_pretrained("bert-base-uncased").to(device)
roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
roberta_model = RobertaForMultipleChoice.from_pretrained("roberta-base").to(device)

Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaForMultipleChoice were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
# Tokenization of questions and their choices
def encode_example(tokenizer, question, choices):
    encodings = tokenizer([question] * len(choices), choices, padding=True, truncation=True, return_tensors="pt")
    return {key: val.unsqueeze(0).to(device) for key, val in encodings.items()}

In [12]:
# Evaluation Function which takes model, tokenizer and processed data
def evaluate(model, tokenizer, dev_qs, dev_choices, dev_labels):
    correct = 0
    model.eval()
    with torch.no_grad():
        for q, c, label in zip(dev_qs, dev_choices, dev_labels):
            inputs = encode_example(tokenizer, q, c)
            logits = model(**inputs).logits
            pred = torch.argmax(logits, dim=1).item()
            correct += (pred == label)
    return correct / len(dev_labels)

### Evaluation of BERT Model

In [13]:
print("Evaluating BERT withot any input reframing")
accuracy = evaluate(bert_model, bert_tokenizer, dev_qs_orig, dev_choices_orig, dev_labels_orig)
print(f"BERT Accuracy (Original): {accuracy:.4f}")

Evaluating BERT withot any input reframing
BERT Accuracy (Original): 0.5185


In [14]:
# Run Evaluation
print("Evaluating BERT with Scientific Notation Based Reframing")
accuracy = evaluate(bert_model, bert_tokenizer, dev_qs_sci, dev_choices_sci, dev_labels_sci)
print(f"BERT Accuracy (Original): {accuracy:.4f}")

Evaluating BERT with Scientific Notation Based Reframing
BERT Accuracy (Original): 0.4815


In [15]:
# Run Evaluation
print("Evaluating BERT with Digit Based Reframing")
accuracy = evaluate(bert_model, bert_tokenizer, dev_qs_digit, dev_choices_digit, dev_labels_digit)
print(f"BERT Accuracy (Original): {accuracy:.4f}")

Evaluating BERT with Digit Based Reframing
BERT Accuracy (Original): 0.5309


### Evaluating Roberta Model

In [16]:
# Run Evaluation
print("Evaluating Roberta withot any input reframing")
accuracy = evaluate(roberta_model, roberta_tokenizer, dev_qs_orig, dev_choices_orig, dev_labels_orig)
print(f"BERT Accuracy (Original): {accuracy:.4f}")

Evaluating Roberta withot any input reframing
BERT Accuracy (Original): 0.5185


In [17]:
# Run Evaluation
print("Evaluating Roberta with Scientific Notaion Based Reframing")
accuracy = evaluate(roberta_model, roberta_tokenizer, dev_qs_sci, dev_choices_sci, dev_labels_sci)
print(f"BERT Accuracy (Original): {accuracy:.4f}")

Evaluating Roberta with Scientific Notaion Based Reframing
BERT Accuracy (Original): 0.5062


In [18]:
# Run Evaluation
print("Evaluating Roberta with Digit Based Reframing ")
accuracy = evaluate(roberta_model, roberta_tokenizer, dev_qs_digit, dev_choices_digit, dev_labels_digit)
print(f"BERT Accuracy (Original): {accuracy:.4f}")

Evaluating Roberta with Digit Based Reframing 
BERT Accuracy (Original): 0.5926
