In [None]:
import argparse
import json
import logging
import os
import re

import numpy as np
import torch
from overrides import overrides
from torch.nn import CrossEntropyLoss
from transformers import DebertaV2ForMaskedLM
from transformers import DebertaV2Tokenizer
from tqdm import tqdm

In [None]:
%load_ext autoreload
%autoreload 2
device="cuda:0"

In [None]:
# Load the tokenizer and the masked-LM head model identified by `model_name`;
# `cache_dir` is passed to `from_pretrained` as the local cache location.
# NOTE(review): despite the directory name 'best_model', the weights requested
# here are those of `model_name` -- confirm whether a fine-tuned checkpoint
# was intended instead.
model_name = 'microsoft/deberta-v3-large'
cache_dir = '../Model/best_model'
tokenizer = DebertaV2Tokenizer.from_pretrained(model_name, cache_dir=cache_dir)
model = DebertaV2ForMaskedLM.from_pretrained(model_name, cache_dir=cache_dir)
# Move the weights to the selected device and switch to evaluation mode
# (the model is only used for scoring, never trained, in the visible cells).
model.to(device)
model.eval()

In [None]:
# TODO: set both paths to the .npy files before running; the empty strings
# are placeholders and `np.load` fails on them.
sentence_data_path = ""
wordplay_data_path = ""
# NOTE(security): allow_pickle=True unpickles arbitrary Python objects, which
# can execute code. Only load files from a trusted source.
sentence_data_list = list(np.load(sentence_data_path, allow_pickle=True))
wordplay_data_list = list(np.load(wordplay_data_path, allow_pickle=True))

In [None]:
# Evaluate on both datasets at once: sentence samples first, then word-play
# samples. The concatenation builds a new list but keeps references to the
# same sample objects.
test_data_list = sentence_data_list + wordplay_data_list
# Display one sample as a quick sanity check of the record layout.
test_data_list[1]

In [None]:
# Empty accumulators.
# NOTE(review): `gold`, `predictions` and `results` are not referenced in the
# visible cells -- confirm they are still needed.
gold = []
predictions = []
results = []
# Padding id taken from the tokenizer, falling back to 0 when it defines none.
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
  

In [None]:
# Peek at the first sample. Note that `sample` is rebound by the evaluation
# loop in a later cell.
sample = test_data_list[0]
sample

In [None]:
# Upper bound on how many masked copies are sent through the model in one
# forward pass; read as a global by `token_wise_scoring`.
MAX_SEQUENCE_PER_TIME = 80
# Maps the predicted choice index to its answer letter.
# NOTE(review): assumes every sample has at most four choices -- confirm.
choice_list= ['A','B','C','D']
# NOTE(review): `score_task` is defined in a later cell, so on a fresh kernel
# the helper cells must be executed before this one.
for sample in tqdm(test_data_list):
    predict = score_task(sample['question'],sample['choice_list'],tokenizer, device, model)
    # The prediction is stored in place on the sample under the key 'predict'.
    sample['predict'] = choice_list[int(predict)]

In [None]:
from utils import *

In [None]:
# Aggregate the per-sample predictions stored on `test_data_list`.
# NOTE(review): `getResultdata` and `getSeperateResult` are not defined in the
# visible cells; presumably they come from `from utils import *` -- confirm.
word_play,reverse_play = getResultdata(test_data_list)
final_result = getSeperateResult(word_play,reverse_play)

# Methods (run the cells below before the evaluation loop above)

In [None]:
def score_task(question, choices, tokenizer, device, model):
    """
    Pick the answer choice that the masked language model scores best.

    The question is encoded once and its last id is dropped; each choice is
    encoded separately, stripped of its first and last id, appended to the
    question ids and closed with the tokenizer's separator id. Every id of
    the resulting sequence except the first and the last is marked as a
    scoring target (the two ends get the label -100).

    Args:
        question: Question text.
        choices: Iterable of answer texts.
        tokenizer: Tokenizer providing `encode`, `pad_token_id` and
            `sep_token_id` (and `mask_token_id` for the scoring step).
        device: Device on which the model is evaluated.
        model: Masked language model handed to `token_wise_scoring`.

    Returns:
        Index (into `choices`) of the selected choice, as returned by
        `token_wise_scoring`.
    """
    if tokenizer.pad_token_id is not None:
        pad_token_id = tokenizer.pad_token_id
    else:
        pad_token_id = 0

    # Question ids without their final element, shared by every candidate.
    question_prefix = tokenizer.encode(question)[:-1]

    sequences = []
    label_ids = []
    for choice in choices:
        answer_ids = tokenizer.encode(choice, add_prefix_space=True)[1:-1]
        candidate = question_prefix + answer_ids + [tokenizer.sep_token_id]
        sequences.append(candidate)
        # First and last positions are excluded from scoring.
        label_ids.append([-100] + candidate[1:-1] + [-100])

    sequences, label_ids, attention_mask = prepare_input(sequences, label_ids, pad_token_id)
    return token_wise_scoring(sequences, label_ids, attention_mask, tokenizer, device, model)

In [None]:
def token_wise_scoring(sequences, label_ids, attention_mask, tokenizer, device, model,
                       max_sequences_per_batch=None):
    """
    Score every candidate sequence by masking one token at a time.

    For each candidate, one copy of the sequence is created per scorable
    position (label != -100). In each copy that single position is replaced by
    the tokenizer's mask id and is the only position with a real label. The
    copies are run through `get_lm_score` in batches, and the candidate's
    score is the mean of the returned per-copy losses. The candidate with the
    lowest mean loss wins; ties go to the lowest index.

    Args:
        sequences: Padded token-id lists, one per candidate (as produced by
            `prepare_input`).
        label_ids: Per-candidate label lists aligned with `sequences`; a value
            of -100 marks a position that must not be scored.
        attention_mask: Per-candidate attention-mask rows aligned with
            `sequences`.
        tokenizer: Tokenizer providing `mask_token_id`.
        device: Device on which the tensors are created.
        model: Masked language model passed through to `get_lm_score`.
        max_sequences_per_batch: Optional cap on the number of masked copies
            per forward pass. Defaults to the notebook-level constant
            `MAX_SEQUENCE_PER_TIME`.

    Returns:
        Index of the candidate with the smallest mean loss.

    Raises:
        ValueError: If `sequences` is empty or the batch cap is smaller than 1.
        RuntimeError: If a candidate has no scorable position.
    """
    if max_sequences_per_batch is None:
        # Keep the original behaviour of reading the notebook-level constant.
        max_sequences_per_batch = MAX_SEQUENCE_PER_TIME
    if max_sequences_per_batch < 1:
        raise ValueError("max_sequences_per_batch must be at least 1")
    if len(sequences) == 0:
        raise ValueError("token_wise_scoring needs at least one candidate sequence")

    choice_loss = []
    for i in range(len(sequences)):
        positions = [j for j, label in enumerate(label_ids[i]) if label != -100]
        if not positions:
            # Same exception type the previous implementation surfaced
            # (from stacking an empty list), but with an actionable message.
            raise RuntimeError(
                "candidate %d has no scorable token (all labels are -100)" % i)
        num_masked = len(positions)

        # Build all masked copies at once instead of one tensor per token.
        token_ids = torch.tensor(sequences[i]).long().to(device)
        position_idx = torch.tensor(positions).long().to(device)
        row_idx = torch.arange(num_masked, device=device)

        masked_inputs = token_ids.repeat(num_masked, 1)
        masked_inputs[row_idx, position_idx] = tokenizer.mask_token_id

        # Only the masked position of each copy carries a label.
        masked_labels = torch.full_like(masked_inputs, -100)
        masked_labels[row_idx, position_idx] = token_ids[position_idx]

        # The attention mask is identical for every copy: convert it once.
        mask_rows = torch.tensor(attention_mask[i]).long().to(device).repeat(num_masked, 1)

        # One code path for both the small and the chunked case; a candidate
        # with fewer copies than the cap simply yields a single chunk.
        losses = []
        for start in range(0, num_masked, max_sequences_per_batch):
            stop = start + max_sequences_per_batch
            losses.append(get_lm_score(model,
                                       masked_inputs[start:stop],
                                       masked_labels[start:stop],
                                       mask_rows[start:stop]))
        loss = np.concatenate(losses)
        choice_loss.append(sum(loss) / len(loss))

    prediction = choice_loss.index(min(choice_loss))
    return prediction


def prepare_input(sequences, label_ids, pad_token_id):
    """
    Right-pad sequences and labels to a common length and build the mask.

    Args:
        sequences: List of token-id lists of possibly different lengths.
        label_ids: List of label lists, padded with -100.
        pad_token_id: Id used to pad `sequences`.

    Returns:
        A tuple `(sequences, label_ids, attention_mask)` where the first two
        are new padded lists (the inputs are not modified) and
        `attention_mask` is a NumPy array of shape
        `(len(sequences), longest_sequence)` holding 1 on real tokens and 0 on
        padding.
    """
    lengths = [len(seq) for seq in sequences]
    width = max(lengths)

    attention_mask = np.zeros((len(sequences), width))
    for row, real_length in enumerate(lengths):
        attention_mask[row, :real_length] = 1

    padded_sequences = []
    for seq in sequences:
        padded_sequences.append(seq + [pad_token_id] * (width - len(seq)))

    padded_labels = []
    for labels in label_ids:
        padded_labels.append(labels + [-100] * (width - len(labels)))

    return padded_sequences, padded_labels, attention_mask

In [None]:
def get_lm_score(model, batch, label_ids, attention_mask):
    """
    Compute the summed cross-entropy loss of each row in `batch`.

    The model is called without gradient tracking; the first element of its
    output is taken as the logits. Positions labelled -100 contribute nothing
    (the default ignore index of `CrossEntropyLoss`).

    Args:
        model: Callable language model; `model(batch, attention_mask=...)[0]`
            must give logits with one score vector per token.
        batch: 2-D tensor of token ids, one row per sequence.
        label_ids: Tensor of target ids with the same shape as `batch`.
        attention_mask: Tensor forwarded to the model unchanged.

    Returns:
        1-D NumPy array with one summed loss per row of `batch`.
    """
    criterion = CrossEntropyLoss(reduction="none")
    with torch.no_grad():
        row_count, _ = batch.shape
        flat_targets = label_ids.view(-1)
        logits = model(batch, attention_mask=attention_mask)[0]
        vocab_size = logits.size(-1)
        token_losses = criterion(logits.view(-1, vocab_size), flat_targets)
        # Fold the flat per-token losses back into rows and add them up.
        row_losses = token_losses.view(row_count, -1).sum(1)
        result = row_losses.cpu().numpy()
    return result