In [None]:
import torch
from transformers import RobertaTokenizer, RobertaForMultipleChoice, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score
import json
import json_lines
import os
from tqdm import tqdm

In [None]:
class MultipleChoiceDataset(Dataset):
    """Map-style dataset over pre-built ``(input_ids, attention_mask, label)`` triples.

    The sequence handed to the constructor is stored as-is; indexing unpacks
    one triple and re-labels its parts with the keyword names the training
    loop forwards to the model.
    """

    def __init__(self, data):
        # Indexable collection of 3-tuples; kept by reference, not copied.
        self.data = data

    def __len__(self):
        """Number of stored examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict keyed by model argument name."""
        ids, mask, target = self.data[idx]
        return dict(input_ids=ids, attention_mask=mask, labels=target)

In [None]:
# Maps the letter stored under 'answerKey' to the index of the correct choice.
answer_map = {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4}


def load_data(file_path, max_length=512):
    """Read a JSON-lines multiple-choice file and tokenise every record.

    Each record is expected to provide ``item['question']['stem']``, a list
    ``item['question']['choices']`` whose entries carry a ``'text'`` field,
    and ``item['answerKey']`` holding one of the letters in ``answer_map``.

    For every choice the stem and the choice text are joined with a single
    space and the resulting strings are tokenised together in one batched
    call, so the encoded tensors come back already shaped
    ``(num_choices, max_length)``. The choice axis is therefore kept even
    for a record with exactly one choice (a bare ``squeeze()`` would drop it).

    Relies on the module-level ``tokenizer`` being defined before the call.

    Args:
        file_path: Path to the ``.jsonl`` file.
        max_length: Length every sequence is truncated/padded to
            (defaults to 512, the value previously hard-coded).

    Returns:
        A list of ``(input_ids, attention_mask, label)`` tuples, one per
        record, where ``label`` is a 0-dim tensor with the correct index.

    Raises:
        KeyError: If a record lacks an expected field or its ``answerKey``
            is not present in ``answer_map``.
    """
    processed_data = []
    with open(file_path, 'rb') as f:
        # Records are tokenised as they are read; no intermediate copy of the
        # raw file contents is kept.
        for item in json_lines.reader(f):
            question = item['question']['stem']
            texts = [
                question + " " + choice['text']
                for choice in item['question']['choices']
            ]
            # One batched call pads every choice to the same length and
            # returns 2-D tensors with the choice axis first.
            encoded = tokenizer(
                texts,
                truncation=True,
                max_length=max_length,
                padding='max_length',
                return_attention_mask=True,
                return_tensors='pt',
            )
            label = torch.tensor(answer_map[item['answerKey']])
            processed_data.append(
                (encoded['input_ids'], encoded['attention_mask'], label)
            )

    return processed_data

In [None]:
# Both objects are restored from the same 'roberta-large' checkpoint so the
# vocabulary used for encoding matches the model's embedding table.
tokenizer = RobertaTokenizer.from_pretrained('roberta-large')
model = RobertaForMultipleChoice.from_pretrained('roberta-large')

In [None]:
# Tokenise the training split, then the development split; each call returns
# a list of (input_ids, attention_mask, label) tuples.
train_data = load_data(file_path="../data/rs_train.jsonl")
valid_data = load_data(file_path="../data/rs_dev.jsonl")

In [None]:
# Number of questions per mini-batch (shared by both loaders).
batch_size = 4

# Wrap the pre-tokenised tuples so DataLoader can index and collate them.
train_dataset = MultipleChoiceDataset(train_data)
valid_dataset = MultipleChoiceDataset(valid_data)

# Training batches are reshuffled every epoch; validation keeps file order.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# One scheduler step is taken per training batch, so the schedule spans
# every batch of every epoch.
epochs = 3
total_steps = epochs * len(train_loader)

# NOTE(review): `AdamW` is the name imported from `transformers`; upstream
# appears to have deprecated it in favour of `torch.optim.AdamW` — confirm
# against the installed version before upgrading.
optimizer = AdamW(model.parameters(), lr=1e-5)

# Linear decay of the learning rate over `total_steps`, with no warm-up phase.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

In [None]:
# Fine-tuning loop: one training pass and one validation pass per epoch.
for epoch in range(epochs):
    # NOTE(review): this snapshot is written to the same directory as the
    # end-of-epoch save below and is overwritten by it; presumably it serves
    # as an early check that the output path is writable — confirm.
    model.save_pretrained('model_{}_directory'.format(epoch))

    # ---- training -------------------------------------------------------
    model.train()
    total_loss = 0.0    # summed over the whole epoch
    window_loss = 0.0   # summed since the last progress print
    window_steps = 0    # number of steps contributing to `window_loss`
    # tqdm wraps the loader itself (not the enumerate object) so the progress
    # bar knows how many batches there are.
    for index, batch in enumerate(tqdm(train_loader)):
        inputs = {key: val.to(device) for key, val in batch.items() if key != "labels"}
        labels = batch["labels"].to(device)

        # Clear gradients left over from the previous step; without this
        # every backward() call adds onto the gradients of all earlier steps.
        optimizer.zero_grad()
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        step_loss = loss.item()
        total_loss += step_loss
        window_loss += step_loss
        window_steps += 1
        if index % 100 == 0:
            # Divide by the number of steps actually accumulated: the first
            # window (index 0) holds a single step, later ones hold 100.
            print("###{}####: Average loss: {}".format(index, window_loss / window_steps))
            window_loss = 0.0
            window_steps = 0

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    avg_train_loss = total_loss / max(len(train_loader), 1)

    # ---- validation -----------------------------------------------------
    model.eval()
    preds = []
    true_labels = []
    for batch in tqdm(valid_loader):
        inputs = {key: val.to(device) for key, val in batch.items() if key != "labels"}
        labels = batch["labels"]
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        # The highest-scoring entry along dim 1 is taken as the predicted choice.
        preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
        true_labels.extend(labels.tolist())

    acc = accuracy_score(true_labels, preds)
    print(f'Epoch: {epoch+1}, Train loss: {avg_train_loss}, Validation accuracy: {acc}')
    # Persist the weights reached at the end of this epoch.
    model.save_pretrained('model_{}_directory'.format(epoch))