In [None]:
import torch
from transformers import RobertaTokenizer, RobertaForMultipleChoice, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score
import json
import json_lines
import os
import numpy as np
from tqdm import tqdm
from utils import *

In [None]:
# Locations of the saved sentence and word-play data files; fill these in
# before running the notebook.
sentence_data_path = ""
# The path gets its own name so it is not overwritten by the loaded list below.
wordplay_data_path = ""

# NOTE(security): allow_pickle=True lets np.load unpickle arbitrary Python
# objects, which can execute code on load -- only use files you trust.
sentence_data_list = list(np.load(sentence_data_path, allow_pickle=True))
wordplay_data_list = list(np.load(wordplay_data_path, allow_pickle=True))

In [None]:
# Evaluate both splits in one pass: sentence items first, then word-play items.
test_data_list = [*sentence_data_list, *wordplay_data_list]

In [None]:
# Checkpoint identifier handed to from_pretrained; fill in before running.
model_path = ''
model = RobertaForMultipleChoice.from_pretrained(
    pretrained_model_name_or_path=model_path,
)
# NOTE(review): the tokenizer is loaded from the stock 'roberta-large' vocabulary
# rather than from model_path -- confirm that this matches the checkpoint.
tokenizer = RobertaTokenizer.from_pretrained(
    pretrained_model_name_or_path='roberta-large',
)

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Move the model's weights onto the chosen device.
model = model.to(device)

In [None]:
class MultipleChoiceDataset(Dataset):
    """Map-style dataset over already-encoded multiple-choice examples.

    Every element of ``data`` must unpack into exactly three values, in the
    order ``(input_ids, attention_mask, label)``.
    """

    def __init__(self, data):
        # Keep a reference to the caller's sequence; nothing is copied.
        self.data = data

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict with keys ``input_ids``,
        ``attention_mask`` and ``labels``."""
        ids, mask, target = self.data[idx]
        return dict(input_ids=ids, attention_mask=mask, labels=target)

In [None]:
# Display the first record so its fields can be inspected before encoding.
test_data_list[0]

In [None]:
# Letter -> index lookup for answer choices. load_data below does not use it
# (labels are read directly from item['label']); it is kept as a module-level name.
answer_map = {'A':0,'B':1,'C':2,'D':3}


def load_data(test_data_list, max_length=512):
    """Encode multiple-choice items into per-example tensors.

    Each item must provide ``'question'`` (str), ``'choice_list'`` (iterable of
    str) and ``'label'``. Every choice is encoded as the single string
    ``question + " " + choice`` with the notebook-level ``tokenizer``, then
    truncated and padded to ``max_length`` tokens.

    Args:
        test_data_list: Iterable of item mappings as described above.
        max_length: Token length every encoded choice is padded/truncated to.
            Defaults to 512, the value this notebook originally hard-coded.

    Returns:
        A list with one ``(input_ids, attention_mask, label)`` tuple per item.
        ``input_ids`` and ``attention_mask`` have shape
        ``(num_choices, max_length)``; ``label`` is ``torch.tensor(item['label'])``.
    """
    processed_data = []
    for item in test_data_list:
        question = item['question']

        encoded_choices = [
            tokenizer.encode_plus(
                question + " " + option,
                truncation=True,
                max_length=max_length,
                padding='max_length',
                return_attention_mask=True,
                return_tensors='pt',
            )
            for option in item['choice_list']
        ]

        # With return_tensors='pt' a single text comes back with a leading
        # batch axis of size 1. Concatenating along that axis yields
        # (num_choices, max_length) directly. The previous stack(...).squeeze()
        # removed *every* size-1 axis, so an item with a single choice lost
        # its choice axis as well.
        input_ids = torch.cat([enc['input_ids'] for enc in encoded_choices], dim=0)
        attention_mask = torch.cat([enc['attention_mask'] for enc in encoded_choices], dim=0)

        label = torch.tensor(item['label'])

        processed_data.append((input_ids, attention_mask, label))

    return processed_data

In [None]:
# Number of examples fed to the model per forward pass.
batch_size = 4

# Encode every item once up front, then wrap the result for batching.
valid_data = load_data(test_data_list)
valid_dataset = MultipleChoiceDataset(valid_data)

# shuffle=False keeps batches in the same order as test_data_list.
valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Put the model in evaluation mode before running inference.
model.eval()

# One predicted choice index per example, in loader order.
preds = []

# No gradients are needed for inference, so autograd is disabled for the whole pass.
with torch.no_grad():
    # The loop runs over every batch; a leftover debugging `break` used to
    # stop after the first one, leaving most examples without a prediction.
    for batch in tqdm(valid_loader):
        # Forward everything except the gold labels; only the logits are needed.
        inputs = {key: val.to(device) for key, val in batch.items() if key != "labels"}
        logits = model(**inputs).logits
        # Take the arg-max over dim 1 and store plain Python ints.
        preds.extend(torch.argmax(logits, dim=1).cpu().tolist())

In [None]:
# Position i holds the letter written out for predicted choice index i.
pred_map = ['A','B','C','D']

# Attach each prediction to its source item in place. zip stops at the shorter
# sequence, so preds must cover every entry of test_data_list.
for pred, item in zip(preds, test_data_list):
    # int(pred) is a call: the previous int[pred] subscripted the type itself
    # and raised TypeError. The cast also accepts NumPy integer scalars.
    item['predict'] = pred_map[int(pred)]

In [None]:
# getResultdata / getSeperateResult are not defined in this notebook; presumably
# they come from `from utils import *` -- TODO confirm.
# NOTE(review): the items passed in carry the 'predict' key added in the previous
# cell; the exact fields these helpers read are not visible here.
word_play,sentence_play = getResultdata(test_data_list)
final_result = getSeperateResult(word_play,sentence_play)