In [None]:
# Mount Google Drive at /content/gdrive so the project folder used by the next cell is reachable.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Work from the project root on the mounted Drive: the relative ./data/... paths used later depend on it.
%cd '/content/gdrive/MyDrive/VLSP2021-MRC-BLANC'

In [None]:
!pip install transformers trankit

In [None]:
import string
import json
from trankit import Pipeline
from tqdm import tqdm

# Characters trimmed from both ends of an answer span: ASCII punctuation plus the ellipsis '…'.
punctuation = ''.join([string.punctuation, '…'])

# trankit pipeline for Vietnamese on the xlm-roberta-large embedding; used below via p.tokenize
# to split contexts/questions into sentences and (possibly multi-syllable) tokens.
p = Pipeline('vietnamese', embedding='xlm-roberta-large')

In [None]:
def _nearest_boundary(position, boundaries, prefer_higher=False):
    """Return the value in ``boundaries`` closest to ``position``.

    On a tie the lower value wins, or the higher one when ``prefer_higher`` is true.
    ``boundaries`` must be non-empty.
    """
    tie_sign = -1 if prefer_higher else 1
    return min(boundaries, key=lambda boundary: (abs(boundary - position), tie_sign * boundary))


def algin_segmented_answer(tokenized_context, answer, punctuation_chars=None):
    """Re-anchor a character-level answer onto a word-segmented, answer-bearing context.

    (The historical "algin" spelling is kept so existing callers keep working.)

    The answer offsets are snapped to the nearest token boundaries, leading/trailing
    punctuation and spaces are trimmed, and a new context is built from only the
    sentences that the answer spans. Tokens are joined with single spaces and any
    space inside a token is replaced by '_'.

    Args:
        tokenized_context: trankit ``Pipeline.tokenize`` output for the context. Reads
            'text' and 'sentences'; each sentence has 'tokens' whose 'text' and
            'dspan' = (start, end) are character offsets into 'text'. Tokens are
            assumed to be in document order.
        answer: dict with 'answer_start' (offset into tokenized_context['text']) and
            'text'. Only the whitespace-normalised length of 'text' is used.
        punctuation_chars: characters stripped from both ends of the answer;
            defaults to the module-level ``punctuation``.

    Returns:
        (new_context_text, new_answer), where new_answer is
        {'answer_start': offset into new_context_text, 'text': segmented answer}.
        new_context_text is '' (and new_answer is {'answer_start': None, 'text': ''})
        when the answer cannot be aligned: missing start, empty context, or nothing
        left after trimming punctuation.
    """
    if punctuation_chars is None:
        punctuation_chars = punctuation
    not_found = ('', {'answer_start': None, 'text': ''})

    text = tokenized_context['text']
    sentences = tokenized_context['sentences']
    tokens = [token for sentence in sentences for token in sentence['tokens']]
    answer_start = answer['answer_start']
    if answer_start is None or not tokens:
        return not_found

    # Collapse whitespace the same way the caller collapses the context, so lengths agree.
    answer_text = ' '.join(answer['text'].split())
    answer_end = answer_start + len(answer_text)

    left_spans = [token['dspan'][0] for token in tokens]
    right_spans = [token['dspan'][1] for token in tokens]
    left_set, right_set = set(left_spans), set(right_spans)

    # Snap offsets that fall inside a token to the nearest boundary.
    # Ties widen the answer: the start rounds down, the end rounds up.
    if answer_start not in left_set:
        answer_start = _nearest_boundary(answer_start, left_spans)
    if answer_end not in right_set:
        answer_end = _nearest_boundary(answer_end, right_spans, prefer_higher=True)

    strippable = set(punctuation_chars) | {' '}
    # Skip leading punctuation/spaces, then step back onto a token start.
    while answer_start < len(text) and text[answer_start] in strippable:
        answer_start += 1
    while answer_start not in left_set:
        answer_start -= 1
    # Drop trailing punctuation/spaces, then step forward onto a token end.
    while answer_end - 1 > 0 and text[answer_end - 1] in strippable:
        answer_end -= 1
    while answer_end not in right_set:
        answer_end += 1

    if answer_end <= answer_start:
        # The span held only punctuation/whitespace; nothing to align.
        return not_found

    # Offsets are measured in the rebuilt context, which begins at the sentence holding
    # the answer start, so the running offset restarts per sentence until that point.
    new_answer_start = None
    answer_tokens = []
    idx_sentence_begin = idx_sentence_end = None
    cur_offset = 0
    for idx, sentence in enumerate(sentences):
        if idx_sentence_begin is None:
            cur_offset = 0
        for token in sentence['tokens']:
            token_start, token_end = token['dspan'][0], token['dspan'][1]
            segmented = token['text'].replace(' ', '_')
            if token_start == answer_start:
                new_answer_start = cur_offset
                idx_sentence_begin = idx
            if idx_sentence_begin is not None and idx_sentence_end is None:
                answer_tokens.append(segmented)
            if token_end == answer_end:
                idx_sentence_end = idx
            # Advance by the text actually written to the context, plus the joining space.
            cur_offset += len(segmented) + 1
        if idx_sentence_end is not None:
            break

    if idx_sentence_begin is None or idx_sentence_end is None:
        return not_found

    kept_sentences = sentences[idx_sentence_begin:idx_sentence_end + 1]
    new_context_text = ' '.join(
        token['text'].replace(' ', '_')
        for sentence in kept_sentences
        for token in sentence['tokens']
    )
    return new_context_text, {'answer_start': new_answer_start, 'text': ' '.join(answer_tokens)}

In [None]:
# Convert the original training set into word-segmented, answer-sentence-trimmed records:
# one record per (question, answer) pair, with the answer re-anchored by algin_segmented_answer.
INPUT_PATH = './data/train_data/original_data/train.json'
OUTPUT_PATH = './data/train_data/tokenized_data/segmented_top_1_data.json'


def _normalized_offset(raw_text, offset):
    """Map a character offset in ``raw_text`` to the same position in ``' '.join(raw_text.split())``."""
    prefix = raw_text[:offset]
    normalized_prefix = ' '.join(prefix.split())
    # Whitespace right before `offset` collapses to one separating space, unless it is
    # leading whitespace, which the normalisation drops entirely.
    if normalized_prefix and prefix[-1:].isspace():
        return len(normalized_prefix) + 1
    return len(normalized_prefix)


def _segment(tokens):
    """Space-join trankit tokens, turning spaces inside a token into '_'."""
    return ' '.join(token['text'].replace(' ', '_') for token in tokens)


with open(INPUT_PATH, 'r', encoding='utf-8') as input_file:
    original_data = json.load(input_file)

new_data = []

for article in tqdm(original_data['data']):
    for paragraph in article['paragraphs']:
        raw_context = paragraph['context']
        # Collapse whitespace runs; answer offsets index raw_context and are remapped below.
        context = ' '.join(raw_context.split())
        tokenized_context = p.tokenize(context)
        for qa in paragraph['qas']:
            question = ' '.join(qa['question'].split())
            segmented_question = _segment(p.tokenize(question, is_sent=True)['tokens'])

            # Entries carrying 'plausible_answers' get an empty 'answers' list
            # (presumably the is_impossible questions, SQuAD 2.0-style — confirm against the data).
            is_plausible = 'plausible_answers' in qa
            for answer in (qa['plausible_answers'] if is_plausible else qa['answers']):
                remapped_answer = answer
                if answer['answer_start'] is not None:
                    remapped_answer = {
                        **answer,
                        'answer_start': _normalized_offset(raw_context, answer['answer_start']),
                    }
                new_context_text, new_answer = algin_segmented_answer(tokenized_context, remapped_answer)
                if new_context_text == '':
                    # Alignment failed: report the question id and drop this answer.
                    print(qa['id'])
                    continue

                # Key order matches the original output layout.
                new_qa = {'question': segmented_question}
                if is_plausible:
                    new_qa['answers'] = []
                    new_qa['plausible_answers'] = [new_answer]
                else:
                    new_qa['answers'] = [new_answer]
                new_qa['id'] = qa['id']
                new_qa['is_impossible'] = qa['is_impossible']

                new_data.append({
                    'title': article['title'],
                    'paragraphs': [{'qas': [new_qa], 'context': new_context_text}],
                })

with open(OUTPUT_PATH, 'w', encoding='utf-8') as output_file:
    json.dump(
        {
            'version': 'viquad2_top1_training_set',
            'data': new_data
        },
        output_file,
        ensure_ascii=False
    )