In [None]:
from IPython.display import clear_output

In [None]:
# Environment setup for a Colab runtime: fetch the XQuAD data and install the
# system / Python dependencies used by the rest of the notebook.
# The XQuAD JSON files are read later from /content/xquad/.
!git clone https://github.com/google-deepmind/xquad
!apt-get install libgoogle-perftools-dev libsparsehash-dev
!pip install sentencepiece
!pip install accelerate
!pip install stanza
!pip install langdetect
!pip install sudachipy sudachidict_core
# NOTE(review): the two WSPAlign repositories are cloned, but no visible cell
# imports from them (the model is loaded by name in a later cell) - confirm
# they are still needed.
!git clone https://github.com/qiyuw/WSPAlign.git
!git clone https://github.com/qiyuw/WSPAlign.InferEval.git
!pip install sentence-transformers==2.2.2
!pip install numba==0.56.4
!pip install sentence-splitter==1.4
!pip install faiss-gpu==1.7.2
!pip install googletrans==4.0.0rc1
# Bertalign is installed from a cloned fork; it is imported in the next cell.
!git clone https://github.com/alihejazi97/bertalign.git
!cd /content/bertalign/ && pip install .
# Drop the install logs so the cell output stays short.
clear_output(wait=False)

In [None]:
from transformers import (AutoConfig,AutoModelForQuestionAnswering,AutoTokenizer,pipeline)
import stanza
import torch
from google.colab import drive
from sentence_transformers import SentenceTransformer, util
from stanza.pipeline.core import DownloadMethod
import copy
import json
import os
from tqdm.notebook import tqdm
from bertalign import Bertalign
from difflib import SequenceMatcher
clear_output(wait=False)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# transformers question-answering pipeline loaded with the WSPAlign mBERT
# checkpoint; it is handed to the conversion helpers as `pipe_align`.
pipe_align = pipeline("question-answering", model="qiyuw/WSPAlign-mbert-base", device=device)
clear_output(wait=False)
# Mount Google Drive; the generated file is copied there at the end of the notebook.
drive.mount('/content/drive')
print(f'device = {device}')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
device = cuda


In [None]:
# Full language list kept for reference; only the pair below is processed in this run.
# lang_codes = ['ar','de','el','en','es','fa','fr','hi','ja','ro','ru','sw','te','th','tr','vi','zh']
# lang_codes[0] is later used as the source language of the conversion.
lang_codes = ['hi', 'en']
# Note: 'bn' (Bengali) is not supported.
# Download the Stanza models for every selected language up front.
for lang_code in lang_codes:
    stanza.download(lang_code)
clear_output(wait=False)

In [None]:
def get_stanza_pipeline(lang_code, use_gpu=False):
    """Build a Stanza pipeline for ``lang_code``.

    The pipeline is created with ``download_method=None`` and
    ``tokenize_no_ssplit=True``. Per the Stanza documentation the latter turns
    off the tokenizer's own sentence segmentation, so callers are expected to
    pass text that is already split into sentences.

    Args:
        lang_code: Stanza language code, e.g. ``'hi'`` or ``'en'``.
        use_gpu: Whether the pipeline should run on the GPU.

    Returns:
        The configured ``stanza.Pipeline`` instance.
    """
    pipeline_options = {
        'use_gpu': use_gpu,
        'download_method': None,
        'tokenize_no_ssplit': True,
    }
    return stanza.Pipeline(lang_code, **pipeline_options)

# def contains_unmatched_snts_length(doc1, doc2, data_idx, context_idx):
#     if len(doc1.sentences) != len(doc2.sentences):
#         print(f'data_id = {data_idx} -- context_id = {context_idx}')
#         print(f"number of sentences don't match. src length = {len(doc1.sentences)} -- target length = {len(doc2.sentences)}")
#         return True
#     return False

def contains_illegal_phrases(phrase, illegal_phrases, max_common_chars=4):
    """Tell whether an aligned phrase collides with any protected phrase.

    A collision is reported when the phrase's source text contains, or is
    contained in, a protected phrase, or when the two share a contiguous run
    of more than ``max_common_chars`` characters.

    Args:
        phrase: Mapping with a ``'src_text'`` entry (the aligned source text).
        illegal_phrases: Iterable of protected strings, or ``None``/empty for
            "nothing is protected".
        max_common_chars: Longest common substring length that is still
            tolerated; anything longer counts as a collision.

    Returns:
        ``True`` if the phrase must not be replaced, ``False`` otherwise.
    """
    if not illegal_phrases:
        return False
    src_text = phrase['src_text']
    for illegal_phrase in illegal_phrases:
        if not illegal_phrase:
            # An empty string is a substring of every text; without this guard
            # a single empty protected phrase would veto every replacement.
            continue
        if src_text in illegal_phrase or illegal_phrase in src_text:
            return True
        matcher = SequenceMatcher(None, src_text, illegal_phrase)
        # Explicit bounds keep this working on Python versions where
        # find_longest_match() has no default arguments.
        longest = matcher.find_longest_match(0, len(src_text), 0, len(illegal_phrase))
        if longest.size > max_common_chars:
            return True
    return False

def align_sentences(src, tgt, src_lang, tgt_lang):
    """Align two texts with Bertalign and return the paired lines.

    Args:
        src: Source-language text.
        tgt: Target-language text.
        src_lang: Language code of ``src``.
        tgt_lang: Language code of ``tgt``.

    Returns:
        ``(src_lines, tgt_lines)``: two lists of equal length, one entry per
        alignment bead in ``aligner.result``, built with the aligner's
        ``_get_line`` helper.
    """
    aligner = Bertalign(src, tgt, src_lang, tgt_lang)
    aligner.align_sents()
    # Build each (source, target) pair together so the helper is invoked in
    # the same order for every bead.
    line_pairs = [
        (
            aligner._get_line(bead[0], aligner.src_sents),
            aligner._get_line(bead[1], aligner.tgt_sents),
        )
        for bead in aligner.result
    ]
    src_lines = [pair[0] for pair in line_pairs]
    tgt_lines = [pair[1] for pair in line_pairs]
    return src_lines, tgt_lines

def trim_phrases(phrases, illegal_phrases=None, threshold=0.4):
    """Filter aligned phrases down to the ones that may be substituted.

    A phrase is discarded when its ``'align_score'`` is below ``threshold``,
    when ``contains_illegal_phrases`` flags it, or when its ``'target_start'``
    lies before the ``'target_end'`` of the previously kept phrase. Only the
    target-side spans are compared; ``phrases`` is expected to arrive sorted.

    Args:
        phrases: Phrase dicts carrying ``'align_score'``, ``'target_start'``
            and ``'target_end'``.
        illegal_phrases: Protected strings forwarded to
            ``contains_illegal_phrases``.
        threshold: Minimum alignment score to keep a phrase.

    Returns:
        A new list with the surviving phrase dicts, in input order.
    """
    kept = []
    for candidate in phrases:
        too_weak = candidate['align_score'] < threshold
        if too_weak or contains_illegal_phrases(candidate, illegal_phrases):
            continue
        # The first survivor is always accepted; later ones must start at or
        # after the end of the last accepted target span.
        if not kept or candidate['target_start'] >= kept[-1]['target_end']:
            kept.append(candidate)
    return kept

def get_alignments(src, target, phrases, pipe, context_sep=' \u00b6 '):
    """Align every target phrase to a span of ``src``, editing ``phrases`` in place.

    For each phrase the target text is rebuilt with ``context_sep`` inserted
    before and after the phrase span, and ``pipe`` is called with that marked
    text as the question and ``src`` as the context. The answer is stored on
    the phrase as ``'src_start'``, ``'src_end'``, ``'src_text'`` and
    ``'align_score'``.

    Phrases for which ``pipe`` raises, or returns a result lacking one of the
    expected keys, are removed from ``phrases``.

    Args:
        src: Source sentence searched for the aligned span.
        target: Target sentence the phrase offsets refer to.
        phrases: List of dicts with ``'target_start'`` / ``'target_end'``;
            mutated in place.
        pipe: Callable ``pipe(question, context)`` returning a mapping with
            ``'start'``, ``'end'``, ``'answer'`` and ``'score'``.
        context_sep: Marker placed around the phrase inside the question.

    Returns:
        ``None``; the outcome is the mutated ``phrases`` list.
    """
    aligned = []
    for phrase in phrases:
        start, end = phrase['target_start'], phrase['target_end']
        marked_target = target[:start] + context_sep + target[start:end] + context_sep + target[end:]
        try:
            result = pipe(marked_target, src)
            span = {
                'src_start': result['start'],
                'src_end': result['end'],
                'src_text': result['answer'],
                'align_score': result['score'],
            }
        except Exception:
            # Best effort: a phrase the aligner cannot handle is dropped.
            # `Exception` (not a bare except) lets Ctrl-C / SystemExit through.
            continue
        # Only touch the phrase once the whole result is known to be usable.
        phrase.update(span)
        aligned.append(phrase)
    # Slice assignment keeps the caller's list object, as the in-place contract requires.
    phrases[:] = aligned

def replace_phrases(snt_src, phrases, snt_target, pipe_align, illegal_phrases):
    """Swap aligned spans of the source sentence for their target-side phrases.

    The target phrases are aligned to the source with ``get_alignments``,
    filtered with ``trim_phrases``, and each surviving source span
    ``[src_start, src_end)`` is replaced by the target text
    ``[target_start, target_end)``.

    The substitutions are applied in source order. Phrases arrive ordered by
    target position, but word order can differ between the two languages, so
    the aligned source spans are re-sorted before stitching; a span that
    overlaps an already replaced source region is skipped. When the source
    spans are already ascending and disjoint the result is unchanged.

    Args:
        snt_src: Source sentence text.
        phrases: Dicts with ``'target_start'`` / ``'target_end'`` offsets into
            ``snt_target``; the dicts are annotated by ``get_alignments``.
        snt_target: Target sentence text.
        pipe_align: Alignment pipeline forwarded to ``get_alignments``.
        illegal_phrases: Protected strings forwarded to ``trim_phrases``.

    Returns:
        The source sentence with the selected spans replaced, or ``snt_src``
        itself when no phrase survives.
    """
    # Longest-first among phrases sharing a start, so trim_phrases keeps the widest span.
    phrases = sorted(phrases, key=lambda phrase: (phrase['target_start'], -phrase['target_end']))
    get_alignments(snt_src, snt_target, phrases, pipe_align)
    phrases = trim_phrases(phrases, illegal_phrases)
    if not phrases:
        return snt_src

    pieces = []
    cursor = 0  # end of the source text consumed so far
    for phrase in sorted(phrases, key=lambda item: (item['src_start'], item['src_end'])):
        if phrase['src_start'] < cursor:
            # Overlaps a source span that was already replaced.
            continue
        pieces.append(snt_src[cursor:phrase['src_start']])
        pieces.append(snt_target[phrase['target_start']:phrase['target_end']])
        cursor = phrase['src_end']
    pieces.append(snt_src[cursor:])
    return ''.join(pieces)


def convert(src_lines, tgt_lines, pipe_src, pipe_target, pipe_align, illegal_phrases, data_idx, context_idx):
    """Code-switch aligned lines by replacing source spans with target phrases.

    For each ``(src_line, tgt_line)`` pair, the first sentence of each parsed
    line is used. Named entities and ``ADJ`` words of the target sentence are
    collected as candidate phrases and passed to ``replace_phrases``.

    * A source line that parses to no sentence is dropped from the output.
    * A target line that parses to no sentence leaves the source line as is.

    Args:
        src_lines: Source-language lines.
        tgt_lines: Target-language lines paired with ``src_lines``.
        pipe_src: Stanza pipeline for the source language.
        pipe_target: Stanza pipeline for the target language; its sentences
            must expose ``entities`` and ``words`` with ``upos``.
        pipe_align: Alignment pipeline forwarded to ``replace_phrases``.
        illegal_phrases: Protected strings forwarded to ``replace_phrases``.
        data_idx: Unused; kept for interface compatibility.
        context_idx: Unused; kept for interface compatibility.

    Returns:
        ``(text, True)`` where ``text`` is the converted lines joined by a
        single space. The flag is always ``True``.
    """
    p_out = []
    for src_line, tgt_line in zip(src_lines, tgt_lines):
        # Each pipeline is run once per line and the parse reused below.
        src_sentences = pipe_src(src_line).sentences
        if len(src_sentences) == 0:
            continue
        tgt_sentences = pipe_target(tgt_line).sentences
        if len(tgt_sentences) == 0:
            p_out.append(src_line)
            continue
        snt_src = src_sentences[0]
        snt_target = tgt_sentences[0]

        # Stanza character offsets are relative to the parsed document, while
        # the phrases index into the sentence text, which starts at the first
        # token. Shift by that token's offset so both agree.
        tokens = getattr(snt_target, 'tokens', None)
        offset = (tokens[0].start_char if tokens else None) or 0

        spans = [(entity.start_char, entity.end_char) for entity in snt_target.entities]
        spans.extend(
            (word.start_char, word.end_char)
            for word in snt_target.words
            if word.upos == 'ADJ'
        )
        phrases = [
            {'target_start': start - offset, 'target_end': end - offset}
            for start, end in spans
            # Spans without character offsets cannot be located in the text.
            if start is not None and end is not None
        ]
        snt_result = replace_phrases(snt_src.text, phrases, snt_target.text, pipe_align, illegal_phrases)
        p_out.append(snt_result)
    return ' '.join(p_out), True

In [None]:
# NOTE(review): `context_sep` is not referenced by any later visible cell;
# get_alignments() has its own default with the same value - confirm it is still needed.
context_sep=' \u00b6 '
# The first configured language is the source language of this run.
lang = lang_codes[0]
# Stanza pipelines for the source language and for English, both on the GPU.
x_nlp = get_stanza_pipeline(lang, use_gpu=True)
en_nlp = get_stanza_pipeline('en', use_gpu=True)
clear_output(wait=False)

In [None]:
# Load the XQuAD split for the source language and the parallel English split.
# The encoding is stated explicitly so reading does not depend on the platform
# default and matches the UTF-8 used when the result is written back.
with open(f'/content/xquad/xquad.{lang}.json', 'r', encoding='utf-8') as f:
    x_obj = json.load(f)
with open('/content/xquad/xquad.en.json', 'r', encoding='utf-8') as f:
    en_obj = json.load(f)

In [None]:
# Code-switch every paragraph context of the source-language split in place.
# A converted context is kept only if the first answer of every question can
# still be found in it; otherwise the paragraph is left untouched and
# `count_forget_it` grows by the number of answers that went missing.
count_forget_it = 0
data_pairs = zip(x_obj['data'], en_obj['data'])
# Wrapping the iterables in tqdm advances the bars on every iteration,
# including paragraphs that are skipped with `continue`.
for data_idx, (x_data, en_data) in enumerate(tqdm(data_pairs, total=len(x_obj['data']))):
    paragraph_pairs = zip(x_data['paragraphs'], en_data['paragraphs'])
    for context_idx, (x_paragraph, en_paragraph) in enumerate(
            tqdm(paragraph_pairs, total=len(x_data['paragraphs']), leave=False)):
        x_context = x_paragraph['context'].replace('\ufeff', '')
        en_context = en_paragraph['context'].replace('\ufeff', '')

        # Only the first answer of each question is considered; its text is
        # protected from replacement.
        answers = [x_qas['answers'][0] for x_qas in x_paragraph['qas']]
        illegal_phrases = [answer['text'] for answer in answers]

        src_lines, tgt_lines = align_sentences(x_context, en_context, lang, 'en')

        cs_context, changed = convert(src_lines, tgt_lines, x_nlp, en_nlp, pipe_align, illegal_phrases, data_idx, context_idx)
        if not changed:
            continue

        # str.find returns the first occurrence, or -1 when the answer was lost.
        answer_starts = [cs_context.find(text) for text in illegal_phrases]
        missing = answer_starts.count(-1)
        if missing:
            count_forget_it += missing
            continue

        x_paragraph['context'] = cs_context
        for answer, answer_start in zip(answers, answer_starts):
            answer['answer_start'] = answer_start

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

Exception: The language Hindi is not suppored yet.

In [None]:
# x_context = x_obj['data'][31]['paragraphs'][4]['context'].replace('\ufeff','')
# en_context =  en_obj['data'][31]['paragraphs'][4]['context'].replace('\ufeff','')
# qas_s = en_obj['data'][31]['paragraphs'][4]['qas']

# illegal_phrases = []
# for x_qas in qas_s:
#     illegal_phrases.append(x_qas['answers'][0]['text'])

# src_lines, tgt_lines = align_sentences(x_context,en_context, lang, 'en')

# cs_context, changed = convert(src_lines, tgt_lines, x_nlp, en_nlp, pipe_align, illegal_phrases, 31, 4)

In [None]:
# Write the modified dataset (contexts code-switched in place) to a local JSON file.
path = f'/content/xquad_cs_context_only/xquad.{lang}.json'
# Create the output directory if it does not exist yet.
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, 'w+', encoding='utf-8') as f:
    # ensure_ascii=False writes non-ASCII characters as-is instead of \u escapes.
    json.dump(x_obj, f, ensure_ascii=False)

In [None]:
# Copy the generated file into the mounted Google Drive; `$path` is
# interpolated from the Python variable set in the previous cell.
!cp $path /content/drive/MyDrive/

In [None]:
# Scratch check of spaCy's rule-based sentencizer using the Arabic language class.
# `Arabic` is defined in `spacy.lang.ar`; importing it from `spacy.lang.en`
# raises ImportError.
from spacy.lang.ar import Arabic

nlp = Arabic()
nlp.add_pipe("sentencizer")
doc = nlp("This is a sentence. This is another sentence.")
assert len(list(doc.sents)) == 2

In [None]:
# `doc.sents` is a generator, so it cannot be indexed directly; wrap it in
# list(...) to access sentences by position.
doc.sents

TypeError: 'generator' object is not subscriptable