In [44]:
from mtrfg.utils.data_utils import read_conllu_dataset_allennlp
import string
from pprint import pprint
from collections import OrderedDict

In [33]:
## CoNLL-U file where the results are stored, and the CoNLL-U test file used as input
output_file = '/user/d.bhatt/Multitask-RFG/saved_models/QA_as_FGP/_deepset-bert-base-uncased-squad2_2023-01-13--13:23:24/test_out.conllu'
input_file = '/data/Multitask_RFG/squad_conllu_dataset/test.conllu'

In [34]:
## read both CoNLL-U files with the project reader:
##   pred_data <- output_file (the stored results)
##   gt_data   <- input_file  (used further down as the ground truth)
pred_data = read_conllu_dataset_allennlp(output_file)
gt_data = read_conllu_dataset_allennlp(input_file)

In [45]:
## let's iterate through data and get answers
def get_question_context_answer(datapoint_pred, datapoint_gt):
    """
        Collect the question, context, ground-truth answers and predicted
        answers of one example into a dictionary.

        Args:
            datapoint_pred: datapoint whose tags are the predictions.
            datapoint_gt: datapoint whose tags are the ground truth; the
                question and the context are read from this one as well.

        Both datapoints are accessed as datapoint['words'].tokens (items with
        a .text attribute) and datapoint['pos_tags'].labels (tags matched to
        the tokens by position).

        Returns:
            dict with the keys 'question' (str), 'gt_ans' (list of str),
            'pred_ans' (list of str) and 'context' (str), in that order.

        Raises:
            ValueError: if the ground-truth words contain no '[SEP]' token.
    """
    def remove_space_punctuations(input_string):
        """
            Undo the spacing that tokenization leaves around punctuation.
            For every character of string.punctuation, the space directly
            before it and the space directly after it are removed, e.g.
            'it ?' -> 'it?' and 'U . S .' -> 'U.S.'.
            Note that this also turns 'a , b' into 'a,b'.
        """
        output_string = input_string
        for punc in string.punctuation:
            output_string = output_string.replace(f' {punc}', punc)
            output_string = output_string.replace(f'{punc} ', punc)

        return output_string

    def get_words(datapoint):
        """
            Text of every token of the datapoint, in order
        """
        return [word.text for word in datapoint['words'].tokens]

    def get_question(datapoint):
        """
            Extract question from datapoint received, i.e. everything
            before the first '[SEP]' token
        """
        words = get_words(datapoint)
        sep_index = words.index('[SEP]')
        return remove_space_punctuations(' '.join(words[:sep_index]))

    def get_context(datapoint):
        """
            Extract context from datapoint received, i.e. everything
            after the first '[SEP]' token
        """
        words = get_words(datapoint)
        sep_index = words.index('[SEP]') + 1
        return remove_space_punctuations(' '.join(words[sep_index:]))

    def get_answer(datapoint):
        """
            Extract answers based on tags.
            A B-ANS tag opens a new answer and the I-ANS tags that directly
            follow it extend that answer. Every other tag closes the answer,
            and an I-ANS tag without an open answer is ignored.
        """
        tokens = datapoint['words'].tokens
        spans = []  ## one list of token texts per answer
        inside_answer = False
        for i, tag in enumerate(datapoint['pos_tags'].labels):
            if tag == 'B-ANS':
                inside_answer = True
                spans.append([tokens[i].text])
            elif tag == 'I-ANS':
                if inside_answer:
                    spans[-1].append(tokens[i].text)
            else:
                ## not only 'O': any non-answer tag must end the span, otherwise
                ## a later I-ANS would be glued to this answer across the gap
                inside_answer = False

        return [remove_space_punctuations(' '.join(span)) for span in spans]

    answer_dict = {}
    answer_dict['question'] = get_question(datapoint_gt)
    answer_dict['gt_ans'] = get_answer(datapoint_gt)
    answer_dict['pred_ans'] = get_answer(datapoint_pred)
    answer_dict['context'] = get_context(datapoint_gt)

    return answer_dict

In [46]:
## Here, we iterate through gt and prediction data, and build a
## dictionary with GT and Predicted answers, keyed by a zero-padded
## running index ('00000', '00001', ...)
pred_data_points = list(pred_data)
gt_data_points = list(gt_data)

## zip() silently drops the unmatched tail when the two inputs differ in
## length, which would hide an incomplete or misaligned prediction file
if len(pred_data_points) != len(gt_data_points):
    raise ValueError(
        f'Prediction/ground-truth size mismatch: '
        f'{len(pred_data_points)} predicted vs {len(gt_data_points)} ground-truth datapoints'
    )

all_answers = {}

for i, (pred_data_point, gt_data_point) in enumerate(zip(pred_data_points, gt_data_points)):
    all_answers[f'{i:05d}'] = get_question_context_answer(pred_data_point, gt_data_point)



In [None]:
## display every collected entry (question, context, GT and predicted answers)
pprint(all_answers)