In [None]:
! pip install datasets transformers

In [None]:
import json

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AdamW,
    AutoTokenizer,
    RobertaForMultipleChoice,
    RobertaForQuestionAnswering,
    RobertaTokenizer,
    Trainer,
    TrainingArguments,
)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load RoBERTa with a span-prediction (question answering) head. The rest of
# this notebook builds `start_positions` / `end_positions` labels and reads
# `start_logits` / `end_logits` from the model output; a multiple-choice head
# neither accepts those labels nor returns those logits.
model = RobertaForQuestionAnswering.from_pretrained('roberta-base').to(device)
tokenizer = AutoTokenizer.from_pretrained('roberta-base')

In [None]:
import transformers
# Fail early unless the loaded tokenizer is a `PreTrainedTokenizerFast`.
# NOTE(review): the batched tokenization cell requests
# `return_offsets_mapping`, which presumably needs a fast tokenizer — confirm.
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

In [None]:
def read_jsonl(path: str):
    """Load a JSON Lines file into a list of parsed records.

    Args:
        path: Location of a text file holding one JSON document per line.

    Returns:
        A list with one decoded object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    with open(path, encoding="utf-8") as fh:
        # Iterate lazily instead of materialising every line with readlines().
        for line in fh:
            # Lines keep their trailing newline, so a bare truthiness check
            # never filters anything; strip first so blank lines are skipped
            # instead of crashing json.loads.
            stripped = line.strip()
            if stripped:
                records.append(json.loads(stripped))
    return records

In [None]:
# Load the `derek-thomas/ScienceQA` dataset through `datasets.load_dataset`.
dataset = load_dataset("derek-thomas/ScienceQA")

In [None]:
# Peek at the first training example to see which fields each record carries.
dataset["train"][0]

In [None]:
def select_features(example):
    """Return a new dict holding only the fields used for training.

    Args:
        example: Mapping that provides at least the keys ``'question'``,
            ``'answer'`` and ``'choices'``.

    Returns:
        A dict with exactly those three keys (in that order), each mapped to
        the same value object found in ``example``.

    Raises:
        KeyError: If ``example`` lacks one of the three keys.
    """
    wanted = ('question', 'answer', 'choices')
    kept = {}
    for name in wanted:
        kept[name] = example[name]
    return kept

# Reduce the training split to the three text fields used below.
# The columns to drop are derived from the split itself instead of a
# hand-written list: a hand-written list silently keeps every column it does
# not name, and such a column is then carried through every later row access.
# NOTE(review): for ScienceQA this presumably affects the `image` column — confirm.
train_split = dataset["train"]
kept_columns = ("question", "answer", "choices")
data = train_split.map(
    select_features,
    remove_columns=[name for name in train_split.column_names if name not in kept_columns],
)

Map:   0%|          | 0/12726 [00:00<?, ? examples/s]

In [None]:
# Limits passed to the batched tokenizer call below as `max_length` and `stride`.
max_length = 384 # Maximum number of tokens in one encoded feature.
doc_stride = 128 # Number of overlapping tokens between consecutive chunks when an input has to be split.

In [None]:
# Tokenize every training question in one batched call.
# - The `question` column is read directly instead of iterating over whole
#   rows, so no other field has to be materialised per example.
# - Each input is a single sequence, so plain truncation is used;
#   "only_second" is meant for (first, second) sequence pairs and cannot be
#   applied when a lone question needs truncating.
tokenized_example = tokenizer(
    list(data["question"]),
    max_length=max_length,
    truncation=True,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    stride=doc_stride,
)

In [None]:
# Sanity check: turn the first two encoded questions back into text.
first_two_encodings = tokenized_example["input_ids"][:2]
for token_ids in first_two_encodings:
    decoded_text = tokenizer.decode(token_ids)
    print(decoded_text)

<s>Which of these states is farthest north?</s>
<s>Identify the question that Tom and Justin's experiment can best answer.</s>


In [None]:
class MathWordProblemDataset(Dataset):
    """Map-style dataset turning question/choice records into model inputs.

    Every record of ``data`` must provide ``'question'`` (text),
    ``'choices'`` (a sequence of texts) and ``'answer'`` (an index into
    ``'choices'``).

    Args:
        data: Indexable, sized collection of records as described above.
        tokenizer: Callable invoked as ``tokenizer(text, padding='max_length',
            truncation=True, max_length=..., return_tensors='pt')``. Its
            result must expose ``'input_ids'`` and ``'attention_mask'``
            tensors whose first axis is a batch axis of size one.
        max_length: Length every encoded text is padded/truncated to.
            Defaults to 128.
    """

    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize one text, padded/truncated to ``self.max_length``."""
        return self.tokenizer(text, padding='max_length', truncation=True,
                              max_length=self.max_length, return_tensors='pt')

    def __getitem__(self, idx):
        """Build the training item for record ``idx``.

        Returns:
            A dict with ``'input_ids'`` and ``'attention_mask'`` (the encoded
            question with the batch axis removed) plus ``'start_positions'``
            and ``'end_positions'`` (one-element integer tensors).
        """
        # Look the record up once instead of once per field.
        example = self.data[idx]
        question = example['question']
        answer = example['choices'][example['answer']]

        inputs = self._encode(question)
        outputs = self._encode(answer)

        # len() of a tokenizer result counts its fields, not its tokens, which
        # would give every example the same labels. The number of real
        # (non-padding) tokens is the sum of the attention mask.
        question_len = int(inputs['attention_mask'].sum())
        answer_len = int(outputs['attention_mask'].sum())

        # NOTE(review): the span [question_len, question_len + answer_len - 1]
        # describes an answer placed directly after the question, but only the
        # question is encoded in `input_ids` — confirm the intended labels.
        # The label tensors stay on the CPU; `collate_fn` moves the assembled
        # batch to the target device.
        start_positions = torch.tensor([question_len])
        end_positions = torch.tensor([question_len + answer_len - 1])

        return {'input_ids': inputs['input_ids'][0],
                'attention_mask': inputs['attention_mask'][0],
                'start_positions': start_positions,
                'end_positions': end_positions}
        

def collate_fn(batch):
    """Assemble a list of dataset items into one batch on ``device``.

    Args:
        batch: List of dicts, each with tensor entries ``'input_ids'``,
            ``'attention_mask'``, ``'start_positions'`` and
            ``'end_positions'``.

    Returns:
        A dict with the same four keys. Items are ordered from longest to
        shortest ``'input_ids'``; the two sequence fields are zero-padded to
        a common length with the batch axis first, the two position fields
        are gathered into one tensor each, and every tensor is moved to the
        module-level ``device``.
    """
    # Longest sequence first.
    by_length = sorted(batch, key=lambda item: item['input_ids'].size(0), reverse=True)

    def column(key):
        # Values of one field across the ordered batch.
        return [item[key] for item in by_length]

    collated = {}
    for key in ('input_ids', 'attention_mask'):
        collated[key] = torch.nn.utils.rnn.pad_sequence(column(key), batch_first=True)
    for key in ('start_positions', 'end_positions'):
        collated[key] = torch.tensor(column(key))

    return {key: value.to(device) for key, value in collated.items()}

from tqdm import tqdm
import matplotlib.pyplot as plt

# create dataset and dataloader
batch_size = 32
num_epochs = 3
learning_rate = 5e-5
# A separate name keeps the Hugging Face `dataset` loaded earlier intact, so
# the preprocessing cells can still be re-run after this one.
train_dataset = MathWordProblemDataset(data, tokenizer)
dataloader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=collate_fn)
losses = []  # one detached CPU loss tensor per optimisation step

# set up optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# finetune the model
for epoch in range(num_epochs):
    model.train(True)
    epoch_losses = []
    for batch in tqdm(dataloader):
        optimizer.zero_grad()

        # forward pass; the labels have to be supplied, otherwise the model
        # output carries no loss to backpropagate
        outputs = model(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            start_positions=batch['start_positions'],
            end_positions=batch['end_positions'],
        )

        # calculate loss and backpropagate
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # Record a detached CPU copy: keeping the live loss tensor would hold
        # on to the autograd graph (and device memory) of every step.
        step_loss = loss.detach().cpu()
        losses.append(step_loss)
        epoch_losses.append(step_loss.item())

    # Report the mean over the epoch rather than only the last batch.
    mean_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('nan')
    print(f"Epoch {epoch+1} loss: {mean_loss}")

# save the finetuned model
model.save_pretrained('finetuned_math_word_problem_model')

In [None]:
# Convert the recorded losses to plain floats. float() works for scalar
# tensors on any device as well as for plain numbers, and writing to a new
# name (instead of rebinding `losses`) keeps this cell safe to re-run.
loss_values = [float(step_loss) for step_loss in losses]

# Plot the training loss; one value was recorded per optimisation step, so the
# x axis counts steps rather than epochs.
fig, ax = plt.subplots()
ax.plot(loss_values, label='Train Loss')
ax.set_xlabel('Training step')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
# example math word problem
question = "John has 5 apples. They gave 2 apples to Mary. How many apples does John have now?"
actual_reason = "John had 5 apples before giving 2 apples to Mary. Now he has only 3 apples left."

# tokenize inputs
inputs = tokenizer(question, padding='max_length', truncation=True, max_length=128, return_tensors='pt')

# make prediction: the training loop leaves the model in train mode, so switch
# to eval mode (no dropout) and skip gradient tracking for inference
model.eval()
with torch.no_grad():
    output = model(input_ids=inputs['input_ids'].to(device), attention_mask=inputs['attention_mask'].to(device))

# pick the most likely start/end token and decode the tokens in between
start_scores, end_scores = output.start_logits, output.end_logits
answer_start = torch.argmax(start_scores)
answer_end = torch.argmax(end_scores) + 1  # slice end is exclusive
answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][answer_start:answer_end]))

In [None]:
# Display the answer span decoded in the previous cell.
answer

In [None]:
# Pass the example through the model and decode the predicted answer span.
# The tensors are built from `inputs` (the tokenized example question) so that
# this cell does not depend on names left over from an earlier kernel session.
input_ids = inputs['input_ids'].to(device)
attention_mask = inputs['attention_mask'].to(device)
input_tokens = inputs['input_ids'][0]

# Inference only: disable dropout and gradient tracking.
model.eval()
with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
start_logits, end_logits = outputs.start_logits, outputs.end_logits
start_index = torch.argmax(start_logits)
end_index = torch.argmax(end_logits)
# The end index is inclusive, hence the +1 on the slice bound.
answer_tokens = input_tokens[start_index:end_index+1]
answer = tokenizer.decode(answer_tokens)

print(f"Answer: {answer}")