In [None]:
import os
import re
import json
from tqdm import tqdm
import torch
import logging
import argparse
import numpy as np
from overrides import overrides
from torch.nn import CrossEntropyLoss
from transformers import RobertaTokenizer, RobertaForMaskedLM
import csv
from tqdm import tqdm
import xlrd
import random
from utils import *

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# Checkpoint to score with: a Hugging Face hub name or a local directory.
lm = "roberta-large"
# Prefer the first GPU but fall back to CPU so the notebook still runs on
# machines without CUDA (the hard-coded "cuda:0" made model.to(device) fail there).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Alternative local checkpoint, kept for reference:
# lm = '../Model/roberta_cskg/'
tokenizer = RobertaTokenizer.from_pretrained(lm)
model = RobertaForMaskedLM.from_pretrained(lm)
model.to(device)
# Inference only: eval() switches off dropout. The trailing ';' keeps the
# notebook from printing the full module repr as the cell output.
model.eval();

In [None]:
# Paths to the two pickled .npy datasets; fill these in before running.
sentence_data_path = ""
wordplay_data_path = ""

# SECURITY NOTE: allow_pickle=True unpickles arbitrary Python objects, which can
# execute code embedded in the file. Only load files from a trusted source.
# Each loaded item is later read as a dict with 'question', 'label' and
# 'choice_list' keys (see the scoring loop).
sentence_data_list = list(np.load(sentence_data_path, allow_pickle=True))
wordplay_data_list = list(np.load(wordplay_data_path, allow_pickle=True))

In [None]:
# Evaluate both datasets in one pass: sentence items first, then word-play items.
test_data_list = [*sentence_data_list, *wordplay_data_list]

In [None]:
# Upper bound on how many masked copies of a sequence are sent through the
# model in one forward pass (read as a global by token_wise_scoring).
MAX_SEQUENCE_PER_TIME = 80
gold = []
# NOTE(review): `predictions`, `results` and `pad_token_id` are not used in this
# cell; they are kept in case other cells rely on them — confirm and remove.
predictions = []
results = []
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

for sample_data in tqdm(test_data_list):
    question = sample_data['question']
    label = sample_data['label']
    # Lower-case only the first character of every choice. Slicing with [:1]
    # (instead of indexing [0]) keeps an empty choice string from raising
    # IndexError while giving the same result for non-empty strings.
    choices = [sub_answer[:1].lower() + sub_answer[1:] for sub_answer in sample_data['choice_list']]
    gold.append(label)
    prediction = score_task(question, choices, tokenizer, device, model)
    # The predicted choice index is stored on the sample itself; later cells
    # read it back through item['predict'].
    sample_data['predict'] = prediction


In [None]:
# Map the predicted choice index to its answer letter.
predict_to_choine = {0:'A',1:'B',2:'C',3:'D'}
for item in test_data_list:
    current = item['predict']
    # Re-running this cell must not fail: values that are already letters were
    # converted by a previous run and are left untouched.
    if current in predict_to_choine.values():
        continue
    # Anything else that is not a known index still raises KeyError, as before.
    item['predict'] = predict_to_choine[current]

In [None]:
# Aggregate the annotated samples into the final report.
# NOTE(review): getResultdata and getSeperateResult are not defined in the
# visible part of this notebook (presumably they come from `from utils import *`);
# judging by the names, the first splits the samples into word-play and
# sentence-play groups and the second computes per-group results — TODO confirm
# their exact inputs and return schema against utils.
word_play,sentence_play = getResultdata(test_data_list)
final_result = getSeperateResult(word_play,sentence_play)

### Utils: RoBERTa masked-LM scoring helpers

In [None]:
###############Roberta#######################
def score_task(question, choices, tokenizer, device, model):
    """Pick the answer choice that the masked LM scores best for a question.

    For every choice one token sequence is built: the encoded question without
    its last id, followed by the encoded choice without its first and last ids
    (presumably the tokenizer's special tokens — confirm for non-RoBERTa
    tokenizers), followed by the separator id. Every token except the first and
    the last of that sequence is marked for scoring.

    Args:
        question: Question text.
        choices: Iterable of answer strings.
        tokenizer: Tokenizer providing ``encode``, ``pad_token_id`` and
            ``sep_token_id``.
        device: Device passed through to ``token_wise_scoring``.
        model: Masked language model passed through to ``token_wise_scoring``.

    Returns:
        The value returned by ``token_wise_scoring`` (the index of the choice
        with the lowest averaged loss).
    """
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = 0

    # Shared prefix: the question ids with the trailing id removed.
    prompt_ids = tokenizer.encode(question)[:-1]

    sequences = []
    label_ids = []
    for choice in choices:
        answer_ids = tokenizer.encode(choice, add_prefix_space=True)[1:-1]
        full_ids = prompt_ids + answer_ids + [tokenizer.sep_token_id]
        sequences.append(full_ids)
        # -100 marks positions that must not be scored (first and last token).
        label_ids.append([-100] + full_ids[1:-1] + [-100])

    sequences, label_ids, attention_mask = prepare_input(sequences, label_ids, pad_id)
    return token_wise_scoring(sequences, label_ids, attention_mask, tokenizer, device, model)


def prepare_input(sequences, label_ids, pad_token_id):
    """Right-pad token and label lists to a common length and build the mask.

    Args:
        sequences: List of token-id lists.
        label_ids: List of label-id lists, one per sequence.
        pad_token_id: Id appended to short sequences.

    Returns:
        A tuple ``(sequences, label_ids, attention_mask)`` where the first two
        are new lists padded to the longest sequence (labels are padded with
        -100) and ``attention_mask`` is a float64 NumPy array of shape
        ``(len(sequences), max_length)`` holding 1 for real tokens and 0 for
        padding.
    """
    lengths = [len(token_ids) for token_ids in sequences]
    width = max(lengths)

    # Column index < sequence length  <=>  position holds a real token.
    real_token = np.arange(width)[None, :] < np.asarray(lengths, dtype=np.int64)[:, None]
    attention_mask = real_token.astype(np.float64)

    padded_sequences = []
    for token_ids in sequences:
        padded_sequences.append(token_ids + [pad_token_id] * (width - len(token_ids)))

    padded_labels = []
    for labels in label_ids:
        padded_labels.append(labels + [-100] * (width - len(labels)))

    return padded_sequences, padded_labels, attention_mask


def token_wise_scoring(sequences, label_ids, attention_mask, tokenizer, device, model,
                       max_sequences_per_batch=None):
    """Score each candidate sequence token by token and return the best index.

    For every position of a sequence whose label is not -100, a copy of the
    sequence is made with that position replaced by the tokenizer's mask token;
    the model's loss for recovering the original token is obtained through
    ``get_lm_score``. The losses of one sequence are averaged, and the index of
    the sequence with the smallest average is returned.

    Args:
        sequences: Padded token-id lists, one per candidate.
        label_ids: Label lists aligned with ``sequences``; -100 marks positions
            that are not scored.
        attention_mask: Per-candidate attention mask rows (array-like).
        tokenizer: Tokenizer providing ``mask_token_id``.
        device: Torch device the input tensors are moved to.
        model: Masked language model handed to ``get_lm_score``.
        max_sequences_per_batch: Maximum number of masked copies per forward
            pass. When ``None`` (default) the module/notebook global
            ``MAX_SEQUENCE_PER_TIME`` is used if it exists, otherwise 80. This
            removes the hard dependency on a global defined in another cell.

    Returns:
        Index (int) of the candidate with the lowest mean masked-token loss.

    Raises:
        ValueError: If a candidate has no position to score, or if
            ``sequences`` is empty.
    """
    if max_sequences_per_batch is None:
        max_sequences_per_batch = globals().get("MAX_SEQUENCE_PER_TIME", 80)

    choice_loss = [0 for _ in range(len(sequences))]
    for i, token_ids in enumerate(sequences):
        positions = [j for j, t in enumerate(label_ids[i]) if t != -100]
        if not positions:
            raise ValueError(
                "candidate %d has no token to score (all labels are -100)" % i
            )
        n_copies = len(positions)

        # Build all masked copies at once instead of one tensor per token:
        # row k is the full sequence with position positions[k] masked out.
        base = torch.tensor(token_ids).long()
        rows = torch.arange(n_copies)
        cols = torch.tensor(positions).long()
        masked_inputs = base.repeat(n_copies, 1)
        # Only the masked position carries a target; the rest is ignored (-100).
        targets = torch.full_like(masked_inputs, -100)
        targets[rows, cols] = base[cols]
        masked_inputs[rows, cols] = tokenizer.mask_token_id
        mask_rows = torch.tensor(attention_mask[i]).long().repeat(n_copies, 1)

        # One device transfer per tensor rather than one per token.
        masked_inputs = masked_inputs.to(device)
        targets = targets.to(device)
        mask_rows = mask_rows.to(device)

        # Chunked forward passes bound memory use; a single chunk covers the
        # short-sequence case, so no separate branch is needed.
        losses = []
        for start in range(0, n_copies, max_sequences_per_batch):
            stop = start + max_sequences_per_batch
            losses.append(
                get_lm_score(model, masked_inputs[start:stop], targets[start:stop], mask_rows[start:stop])
            )
        loss = np.concatenate(losses)
        choice_loss[i] = sum(loss) / len(loss)

    return choice_loss.index(min(choice_loss))


def get_lm_score(model, batch, label_ids, attention_mask):
    """Return the summed cross-entropy loss of every row in ``batch``.

    Args:
        model: Callable taking ``(batch, attention_mask=...)`` whose first
            output holds the per-token logits.
        batch: 2-D tensor of token ids, one row per sequence.
        label_ids: Tensor of target ids with the same number of elements as
            ``batch``; positions set to -100 are ignored by the loss.
        attention_mask: Attention mask forwarded to the model.

    Returns:
        1-D NumPy array with one value per row: the per-token losses of that
        row added together. Gradients are disabled during the computation.
    """
    criterion = CrossEntropyLoss(reduction="none")
    with torch.no_grad():
        # Unpacking also enforces that the batch is two-dimensional.
        n_rows, _ = batch.shape
        flat_targets = label_ids.view(-1)
        logits = model(batch, attention_mask=attention_mask)[0]
        vocab_size = logits.size(-1)
        token_losses = criterion(logits.view(-1, vocab_size), flat_targets)
        row_losses = token_losses.view(n_rows, -1).sum(1)
        result = row_losses.cpu().numpy()
    return result