In [1]:
from collections import defaultdict, Counter, OrderedDict
import ujson
import pathlib
from pathlib import Path
import sqlite3
import sys
import textwrap
import torch
from tqdm import tqdm
from typing import Dict, List, Set, Union

import unicodedata

import uuid

import numpy as np

from aic_nlp_utils.json import read_jsonl, read_json, write_json, write_jsonl
from aic_nlp_utils.fever import fever_detokenize, import_fever_corpus_from_sqlite
from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs
import stanza
# stanza.download("en")

sys.path.append('Claim_Generation')
from T5_QG import pipeline
from distractor_generation import Distractor_Generation

sys.path.append('Models')
from arguments import ModelArguments, DataTrainingArguments
from load import load_tokenizer_and_model, find_last_checkpoint

%load_ext autoreload
%autoreload 2

- This is full version aimed at generating SUPPORTED and REFUTED claims needed for evidence retrieval, keeping NEIs for later.
- It is a simplified version of the `claim_generation_paragraphs_wiki.ipynb`.
- Aimed to generate data for post LREV EnFEVER models (e.g., ColBERT v2) for CsFEVER corpus.
- Fixed input and output formats for those we use in AIC.

**Notes**
- Currently ignoring multi-hops - single evidence documents are used only.

In [2]:
FEVER_ROOT = Path("/mnt/data/factcheck/fever/data_full_nli-filtered-cs")
FEVER_DATA = Path(FEVER_ROOT, "fever-data/F1_titles_anserininew_threshold")
FEVER_CORPUS_SQLITE = Path(FEVER_ROOT, "fever/cs_wiki_revid_db_sqlite.db")
QACG_ROOT = Path(FEVER_ROOT, "qacg")
QACG_ROOT.mkdir(parents=True, exist_ok=True)

# TODO move to utils
def batch_apply(func, data, batch_size) -> List:
    n = len(data)
    first = 0
    res = []
    while first < n:
        last = min(first + batch_size, n)
        res += func(data[first:last])
        first = last
    assert n == len(res)
    return res

def nfc(s: str):
    return unicodedata.normalize("NFC", s)

In [3]:
import glob
import html
import re
from lxml.html.clean import Cleaner

# TODO: this is the same code as in drchajan:/download_wiki.ipynb; move it to aic_nlp_utils?
# this version is slightly modified (glob string & textcol parameter)

def _filter_and_fix_wiki_extract(root_dir, min_length, textcol="text", remove_re=None):
    # min_length - shorter pages (text) are removed
    # remove_re - matching pages (text) are removed
    records = []
    glob_str = f"{root_dir}/**"
    print(glob_str)
    source_jsonls = sorted(glob.glob(glob_str))
    for source_json in tqdm(source_jsonls):
        for r in read_jsonl(source_json):
            r["original_id"] = int(r["id"])
            r["id"] = r["title"].strip()
            r["text"] = html.unescape(r[textcol].strip())
            if textcol != "text":
                del r[textcol]
            r["id"] = nfc(r["id"])
            r["title"] = nfc(r["title"])
            r["text"] = nfc(r["text"])
            records.append(r)
    print(f"# all: {len(records)}")
    cntr = Counter((r["text"] for r in records))
    records = [r for r in records if cntr[r["text"]] == 1]
    print(f"# without duplicate texts: {len(records)}")
    records = [r for r in records if len(r["text"]) >= min_length]
    print(f"# without short texts: {len(records)}")
    if remove_re:
        pattern = re.compile(remove_re)
        records = [r for r in tqdm(records) if not re.match(pattern, r["text"])]
        print(f"# without text removed based on RE: {len(records)}")
    return records

def _fix_html(records, kill_tags, remove_tags, remove_text):
    cleaner = Cleaner(page_structure=False, links=False, kill_tags=kill_tags, remove_tags=remove_tags)
    TAG_RE = re.compile(r'<[^>]+>')

    n_html = 0
    n_unfixed = 0
    n_err = 0
    for r in tqdm(records):
        for rt in remove_text:
            r["text"] = r["text"].replace(rt, "")
        if "<" in r["text"]: # simple rough tag detection
            n_html += 1
            try:
                r["text"] = cleaner.clean_html(r["text"])
            except:
                n_err += 1
            r["text"] = TAG_RE.sub('', r["text"]).strip()
            if "<" in r["text"]:
                n_unfixed += 1
    print(f"# fixed HTML: {n_html}, remaining: {n_unfixed}, errors: {n_err}")
    return records


def filter_and_fix_wiki_extract_for_lang(root_dir, output_jsonl, lang, textcol="text"):
    assert lang in ["cs", "en", "pl", "sk"]
    remove_text = []
    if lang == "cs":
        min_length = 10
        remove_re = r"\#PŘESMĚRUJ[^\n]*|\#REDIRECT[^\n]*|\#redirect[^\n]*"
        kill_tags = ["id", "minor", "ns", "parentid", "timestamp", "contributor", "comment", "model", "format", "templatestyles"]
        remove_tags = ["revision"]
        remove_text = ["Externí odkazy.", "Odkazy.", "Reference.", "Literatura.", "Související články.", '"V tomto článku byl použit textu z článku ."', ': | | | | |', ': | | | |', ': | | |', ': | |', ': |', '• \xa0 \xa0 • • • •', '"V tomto článku byly použity textů z článků']
    elif lang == "en":
        min_length = 10
        remove_re=r"[^\n]* refers? to:?$|[^\n]* stands? for:?$|[^\n]*\n.* refers? to:?$|[^\n]*refers? to:\n[^\n]*|[^\n]*the following:?$|[^\n]* mean:?$|[^\n]* be:$|[^\n]* is:$|\#REDIRECT[^\n]*"
        kill_tags = ["ref", "onlyinclude"]
        remove_tags = []
    elif lang == "pl":
        min_length = 30
        remove_re=r"[^\n]* to:$|\#PATRZ[^\n]*|[^\n]* może oznaczać:$"
        kill_tags = []
        remove_tags  = ["poem"]
        remove_text = ["\right]</math>"]
    elif lang == "sk":
        min_length = 10
        remove_re=r"[^\n]* je:?$|[^\n]* môže byť:?$"
        kill_tags = ["indicator"]
        remove_tags  = []
        remove_text = ['je zatiaľ „“. Pomôž Wikipédii tým, že ho [ doplníš a rozšíriš.]"']

    records = _filter_and_fix_wiki_extract(root_dir, min_length, remove_re=remove_re, textcol=textcol)
    records = _fix_html(records, kill_tags, remove_tags, remove_text)
    write_jsonl(output_jsonl, records)
    return records

EXTRACTED_ROOT = "/mnt/data/factcheck/fever/data_full_nli-filtered-cs/"
corpus = filter_and_fix_wiki_extract_for_lang(
    Path(EXTRACTED_ROOT, "wiki-pages"),
    Path(EXTRACTED_ROOT, "fever", "wiki_extract_filtered_and_fixed_drchajan.jsonl"), "cs", textcol="contents")
corpus_id2idx = {r["id"]: i for i, r in enumerate(corpus)}

/mnt/data/factcheck/fever/data_full_nli-filtered-cs/wiki-pages/**


100%|██████████| 1476/1476 [00:11<00:00, 124.36it/s]


# all: 825078
# without duplicate texts: 501199
# without short texts: 501158


100%|██████████| 501158/501158 [00:00<00:00, 667662.44it/s]


# without text removed based on RE: 501142


100%|██████████| 501142/501142 [00:05<00:00, 84824.75it/s] 


# fixed HTML: 3992, remaining: 0, errors: 0


In [4]:
corpus[2003]

{'id': 'České středohoří',
 'revid': '503435',
 'url': 'https://cs.wikipedia.org/wiki?curid=4612',
 'title': 'České středohoří',
 'original_id': 4612,
 'text': 'České středohoří () je geomorfologický celek o rozloze 1265 km². Z hlediska horopisného patří do Podkrušnohorské oblasti, která je součástí Krušnohorské subprovincie. Na 84 % území Českého středohoří zaujímá Chráněná krajinná oblast České středohoří (CHKO České středohoří) o výměře 1063,17 km². Nejvyšším vrcholem je Milešovka (837 m). Nejnižším bodem je hladina Labe v Děčíně (121,9 m). Maximální výškový rozdíl tedy činí 715,1 m. Geomorfologické členění. Hluboké údolí Labe rozděluje České středohoří na dva geomorfologické podcelky: Verneřické středohoří (IIIB-5A) na pravém břehu Labe a Milešovské středohoří (IIIB-5B) na levém břehu Labe. Tyto podcelky se dále člení do celkem osmi okrsků: Geologie. Rozlohou 1266 km², délkou přes 70 km a šířkou až 25 km patří České středohoří k menším orografickým celkům. Přesto je však nejmohutně

The following extracts corpus pages used as evidence in annotated CsFEVER data. We can use any page, but this will give us better comparison of what QACG generates when compared to FEVER. 

In [5]:
# note, this differs from EnFEVER!
# Tomas Mlynar uses different format...
def extract_fever_evidence_pages(split_jsonls: List, corpus_id2idx: Dict, corpus):
    fever_pages = set()
    corpus_records = []
    not_found = 0
    for jsonl in split_jsonls:
        print(jsonl)
        split = read_jsonl(jsonl)
        for rec in split:
            if rec["verifiable"] == "VERIFIABLE":
                for ev in rec["evidence_cs"].keys():
                    ev = nfc(ev)
                    if ev in corpus_id2idx:
                        if ev not in fever_pages:
                            corpus_records.append(corpus[corpus_id2idx[ev]])
                        fever_pages.add(ev)
                    else:
                        not_found += 1
    print(f"missing pages: {not_found}/{not_found+len(fever_pages)}")
    return fever_pages, corpus_records



fever_pages_trn, corpus_recs_trn = extract_fever_evidence_pages([Path(FEVER_DATA, "train_fb_cs_nli_split_F1_titles_anserininew.jsonl")], corpus_id2idx, corpus)
fever_pages_dev, corpus_recs_dev = extract_fever_evidence_pages([Path(FEVER_DATA, "paper_dev_fb_cs_nli_split_F1_titles_anserininew.jsonl")], corpus_id2idx, corpus)
fever_pages_tst, corpus_recs_tst = extract_fever_evidence_pages([Path(FEVER_DATA, "paper_test_fb_cs_nli_split_F1_titles_anserininew.jsonl")], corpus_id2idx, corpus)

/mnt/data/factcheck/fever/data_full_nli-filtered-cs/fever-data/F1_titles_anserininew_threshold/train_fb_cs_nli_split_F1_titles_anserininew.jsonl
missing pages: 53/5084
/mnt/data/factcheck/fever/data_full_nli-filtered-cs/fever-data/F1_titles_anserininew_threshold/paper_dev_fb_cs_nli_split_F1_titles_anserininew.jsonl
missing pages: 3/553
/mnt/data/factcheck/fever/data_full_nli-filtered-cs/fever-data/F1_titles_anserininew_threshold/paper_test_fb_cs_nli_split_F1_titles_anserininew.jsonl
missing pages: 0/584


In [6]:
corpus_recs_trn[0]

{'id': 'Jeff Bridges',
 'revid': '546074',
 'url': 'https://cs.wikipedia.org/wiki?curid=481829',
 'title': 'Jeff Bridges',
 'original_id': 481829,
 'text': 'Jeffrey Leon „Jeff“ Bridges (* 4. prosince 1949 v Los Angeles v Kalifornii, USA) je americký herec a hudebník. Mezi jeho nejvýznamnější filmy patři "Poslední představení" (v originále "The Last Picture Show"), "Tron", "Starman", "Báječní Bakerovi hoši (The Fabulous Baker Boys)", "Král rybář (The Fisher King)", "Beze strachu (Fearless)", "Big Lebowski", "Kandidáti (The Contender)" a "Iron Man". Českým divákům je znám svojí rolí Jacka Prescotta v americkém dobrodružném filmu režiséra Johna Guillermina King Kong z roku 1976. Osobní život. Narodil se v Los Angeles v Kalifornii manželům Lloydovi a Dorothy Bridgesovým. Má staršího bratra Beaua a mladší sestru Lucindu. Jeho druhý bratr Garret zemřel v roce 1948 na syndrom náhlého úmrtí kojenců. Se svým bratrem Beauem, který se o něj staral, když byl jeho otec zaneprázdněn prací, měl velic

## NER
Not one of the CZECH NER models extracts ordinals or cardinals, i.e., numbers beyond dates. This might be a problem. We need a better one...

In [14]:
from transformers import AutoTokenizer, AutoModelForTokenClassification, BertTokenizerFast
from transformers import pipeline
from ufal.nametag import Ner, Forms, TokenRanges, NamedEntities


def load_czert_ner_pipeline(model_name):
    model = AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = BertTokenizerFast(Path(model_name, "vocab.txt"), strip_accents=False, do_lower_case=False, truncate=True, model_max_length=512)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ner_pipeline = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="first", device=device)
    def ner_pipeline_pairs(text):
        ner_dicts = ner_pipeline(text)
        # ner_pairs = [(e["word"], e["entity_group"]) for e in ner_dicts]
        ner_pairs = [(text[e["start"]:e["end"]], e["entity_group"]) for e in ner_dicts]
        return ner_pairs
    return ner_pipeline_pairs


class UFALNERExtractor:
    def __init__(self, model):
        self.ner = Ner.load(model)
        self.forms = Forms()
        self.tokens = TokenRanges()
        self.entities = NamedEntities()
        self.tokenizer = self.ner.newTokenizer()
        
    def __call__(self, text):
        self.tokenizer.setText(text)
        ners = []
        nertypes = []
        while self.tokenizer.nextSentence(self.forms, self.tokens):
            self.ner.recognize(self.forms, self.entities)
            
            entities = sorted(self.entities, key=lambda entity: (entity.start, -entity.length))
            
            prev_end = -1
            for entity in entities:
                if (entity.start + entity.length) <= prev_end: # take only the highest level entities
                    continue
                ners.append(" ".join(self.forms[entity.start:entity.start+entity.length]))
                nertypes.append(entity.type)
                prev_end = entity.start + entity.length
        ner_pairs = [(ner, nertype) for ner, nertype in zip(ners, nertypes)]
        return ner_pairs


def extract_ners(corpus_recs, ner_json, model_name):
    # for each text gives a triplet (ner, ner_type, ner-ner_type count in text)
    # the triplets are sorted by decreasing count

    if model_name == "ufal.nametag":
        ner_pipeline = UFALNERExtractor("/mnt/data/factcheck/ufal/ner/czech-cnec2.0-140304-no_numbers.ner")
    else:
        ner_pipeline = load_czert_ner_pipeline(model_name)
    entity_dict = OrderedDict()
    for l in tqdm(corpus_recs):
        text = l["text"]
        ner_pairs = ner_pipeline(text)
        ner_cnts = Counter(ner_pairs) # their 
        ners_unique_with_counts =  [(p[0], p[1], ner_cnts[(p[0], p[1])]) for p in set(ner_pairs)]
        ners_unique_with_counts = sorted(ners_unique_with_counts, key=lambda n: -n[2])
        entity_dict[l["id"]] = ners_unique_with_counts
    write_json(ner_json, entity_dict)

In [17]:
model_short="ufal.nametag"
model_name="ufal.nametag"
extract_ners(corpus_recs_dev, Path(QACG_ROOT, "ner", f"dev_ners-{model_short}.json"), model_name)
extract_ners(corpus_recs_tst, Path(QACG_ROOT, "ner", f"test_ners-{model_short}.json"), model_name)
extract_ners(corpus_recs_trn, Path(QACG_ROOT, "ner", f"train_ners-{model_short}.json"), model_name)

100%|██████████| 550/550 [00:07<00:00, 68.96it/s] 
100%|██████████| 584/584 [00:09<00:00, 61.10it/s] 
100%|██████████| 5031/5031 [01:05<00:00, 76.71it/s] 


In [15]:
model_short="CZERT-B-ner-CNEC"
model_name=f"/mnt/data/factcheck/models/czert/{model_short}"
extract_ners(corpus_recs_dev, Path(QACG_ROOT, "ner", f"dev_ners-{model_short}.json"), model_name)
extract_ners(corpus_recs_tst, Path(QACG_ROOT, "ner", f"test_ners-{model_short}.json"), model_name)
extract_ners(corpus_recs_trn, Path(QACG_ROOT, "ner", f"train_ners-{model_short}.json"), model_name)

100%|██████████| 550/550 [00:11<00:00, 46.11it/s]
100%|██████████| 584/584 [00:12<00:00, 46.48it/s]
100%|██████████| 5031/5031 [01:43<00:00, 48.59it/s]


In [16]:
model_short="PAV-ner-CNEC"
model_name=f"/mnt/data/factcheck/models/czert/{model_short}"
extract_ners(corpus_recs_dev, Path(QACG_ROOT, "ner", f"dev_ners-{model_short}.json"), model_name)
extract_ners(corpus_recs_tst, Path(QACG_ROOT, "ner", f"test_ners-{model_short}.json"), model_name)
extract_ners(corpus_recs_trn, Path(QACG_ROOT, "ner", f"train_ners-{model_short}.json"), model_name)

100%|██████████| 550/550 [00:11<00:00, 46.49it/s]
100%|██████████| 584/584 [00:13<00:00, 43.53it/s]
100%|██████████| 5031/5031 [01:49<00:00, 46.10it/s]


I am not merging outputs of multiple NER methods anymore -- keeping for future?

In [18]:
def merge_ners(fins, fout):
    iners = [read_json(fin) for fin in fins]
    page2ners = defaultdict(dict)
    for iner in iners:
        for page, ners in iner.items():
            for ner in ners:
                word, type_, cnt = ner
                if word in page2ners[page]:
                    if page2ners[page][word][1] >= cnt:
                        continue
                page2ners[page][word] = [type_, cnt]
    page2nerlsts = {page: sorted([[word] + params for word, params in ners.items()], key=lambda n: -n[2]) for page, ners in page2ners.items()}
    write_json(fout, page2nerlsts)

merge_ners(
    [Path(QACG_ROOT, "ner", "dev_ners-CZERT-B-ner-CNEC.json"),
     Path(QACG_ROOT, "ner", "dev_ners-PAV-ner-CNEC.json")], 
    Path(QACG_ROOT, "ner", "dev_ners.json"))

merge_ners(
    [Path(QACG_ROOT, "ner", "test_ners-CZERT-B-ner-CNEC.json"),
     Path(QACG_ROOT, "ner", "test_ners-PAV-ner-CNEC.json")], 
    Path(QACG_ROOT, "ner", "test_ners.json"))

merge_ners(
    [Path(QACG_ROOT, "ner", "train_ners-CZERT-B-ner-CNEC.json"),
     Path(QACG_ROOT, "ner", "train_ners-PAV-ner-CNEC.json")], 
    Path(QACG_ROOT, "ner", "train_ners.json"))

## Question Generation (QG)

In [18]:
class BatchQuestionGenerator:
    def __init__(self, tokenizer, model, highlight=False, highlight_tag="<hl>", max_source_length=1024, padding=False, device="cuda", debug=False):
        self.tokenizer = tokenizer
        self.model = model.to(device)
        self.highlight = highlight
        self.highlight_tag = highlight_tag
        self.max_source_length = max_source_length
        self.padding = padding
        self.device = device
        self.debug = debug

    def generate(self, contexts, answers, batch_size=32):
        def highlight_fun(answer, context):
            offset = context.index(answer)
            return f"{context[:offset]}<hl>{answer}<hl>{context[offset + len(answer):]}"

        n = len(contexts)
        assert n == len(answers), (n, len(answers))
        offset = 0
        failures = 0
        predictions = []
        while offset < n:
            last = min(offset + batch_size, n)
            if self.highlight:
                inputs = []
                for context, answer in zip(contexts[offset:last], answers[offset:last]):
                    # if answer in context:
                    inputs.append(highlight_fun(answer, context) )
                    # else:
                        # failures += 1
            else:
                inputs = [answer + "</s>" + context for context, answer in zip(contexts[offset:last], answers[offset:last])]
            model_inputs = self.tokenizer(inputs, max_length=self.max_source_length, padding=self.padding, truncation=True, return_tensors="pt")
            model_inputs = {k: v.to(self.device) for k, v in model_inputs.items()}
            with torch.no_grad():
                Y = self.model.generate(**model_inputs, max_new_tokens=768)
                batch_predictions = self.tokenizer.batch_decode(
                    Y, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
            predictions += batch_predictions
            offset += batch_size

        assert n == len(predictions)
        if self.debug:
            for input, pred in zip(inputs, predictions):
                print(textwrap.fill(input))
                print()
                print(pred)
                print("----------------------------")
        # print(f"#failures: {failures}, #predictions: {len(predictions)}/{n}")
        return predictions

In [19]:
highlight = False
# def batch_evaluator():
# model_args = ModelArguments(model_name_or_path="/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/facebook/mbart-large-cc25_cs_CZ/checkpoint-10000")
# model_args = ModelArguments(model_name_or_path="/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/google/mt5-base_cs_CZ/checkpoint-40000")
# model_args = ModelArguments(model_name_or_path="/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/google/mt5-large_cs_CZ/checkpoint-59000")

highlight = True
model_args = ModelArguments(model_name_or_path="/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qg/google/mt5-large_cs_CZ_hl/checkpoint-48000")

tokenizer, model, data_collator = load_tokenizer_and_model(model_args, lang="cs_CZ", fp16=True)

batch_question_generator = BatchQuestionGenerator(tokenizer, model, highlight=True, padding=True, debug=False)

In [20]:
def generate_qas(corpus_recs, ner_json, qas_json, generator):
    # QG NLP object

    # print('Loading QG module >>>>>>>>')
    # print('QG module loaded.')

    ners = read_json(ner_json)

    qas = OrderedDict()
    invalid_sample = 0
    for l in tqdm(corpus_recs):
        id_ = str(l['id'])
        if id_ not in ners: # no NERs in this text
            continue
        entities = ners[id_]

        # create a batch
        contexts, answers = [], []
        for ent_text, ent_type, ent_cnt in entities:
            contexts.append(l['text'])
            answers.append(ent_text)


        # question generation
        if len(contexts) > 0 and len(contexts) == len(answers):
            questions = []
            # try:
            questions = generator.generate(contexts, answers)
            # except:
                # invalid_sample += 1

            if len(questions) == 0:
                continue
            
            assert len(questions) == len(contexts)
            # save results
            result_for_sample = {}
            for entity, question, answer, context in zip(entities, questions, answers, contexts):
                ent_text, ent_type, _ = entity
                result_for_sample[f'{ent_text}:::{ent_type}'] = [question, answer]

            qas[str(l['id'])] = result_for_sample
        else:
            invalid_sample += 1

    # print(f'#invalid samples: {invalid_sample}')
    Path(qas_json).parent.mkdir(parents=True, exist_ok=True)
    write_json(qas_json, qas)


# generate_qas(corpus_recs_dev[0:1], Path(QACG_ROOT, "ner", "dev_ners.json"), Path(QACG_ROOT, "qa", "dev_qas.json"), batch_question_generator)
# generate_qas(corpus_recs_dev, Path(QACG_ROOT, "ner", "dev_ners-PAV-ner-CNEC.json"), Path(QACG_ROOT, "qa", "dev_qas-PAV-ner-CNEC_cp10000.json"), batch_question_generator)
# generate_qas(corpus_recs_dev, Path(QACG_ROOT, "ner", "dev_ners-PAV-ner-CNEC.json"), Path(QACG_ROOT, "qa","dev_qas-PAV-ner-CNEC_mt5-base-cp40000.json"), batch_question_generator)
# generate_qas(corpus_recs_tst, Path(QACG_ROOT, "ner", "test_ners.json"), Path(QACG_ROOT, "qa", "test_qas.json"), batch_question_generator)
# generate_qas(corpus_recs_trn, Path(QACG_ROOT, "ner", "train_ners.json"), Path(QACG_ROOT, "qa", "train_qas.json"), batch_question_generator)

In [21]:
confs = [
    (corpus_recs_dev, "dev"),
    (corpus_recs_tst, "test"),
    (corpus_recs_trn, "train"),
]

for corpus_recs, name in confs:
    generate_qas(corpus_recs, Path(QACG_ROOT, "ner", f"{name}_ners-PAV-ner-CNEC.json"), Path(QACG_ROOT, "qa", f"{name}_qas-PAV-ner-CNEC_mt5-large-cp59000.json"), batch_question_generator)

QACG_ROOT

100%|██████████| 550/550 [22:13<00:00,  2.43s/it] 
100%|██████████| 584/584 [24:16<00:00,  2.49s/it]
100%|██████████| 5031/5031 [3:19:49<00:00,  2.38s/it]   


PosixPath('/mnt/data/factcheck/fever/data_full_nli-filtered-cs/qacg')

## Claim Generator (QA2D + Replacement Generator)

In [8]:
# Just for testing
# QA2D_model_path="/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/facebook/mbart-large-cc25_cs_CZ/checkpoint-26000"
# sense_to_vec_path="dependencies/s2v_old"

# model_args = ModelArguments(model_name_or_path=QA2D_model_path)
# tokenizer, model, data_collator = load_tokenizer_and_model(model_args, lang="cs_CZ")

In [7]:
class SameDocumentNERReplacementGenerator:
    def __init__(self, seed=1234):
        self.rng = np.random.RandomState(seed)

    def get_options(self, answer, entity, passage_entities, **kwargs):
        ent_name, ent_type = entity.split(":::")
        selected_entity_names = set()
        for passage_entity in passage_entities:
            pent_name, pent_type = passage_entity.split(":::")
            if pent_type == ent_type and pent_name != ent_name:
                selected_entity_names.add(pent_name)
        if len(selected_entity_names) == 0:
            return None
        selected_entity_names = list(selected_entity_names)
        selected_entity_name = self.rng.choice(selected_entity_names)
        selected_entity = (selected_entity_name, ent_type)
        # print(f"{entity} -> {selected_entity}")
        return selected_entity
    
# replacement_generator = Distractor_Generation(sense2vec_path=sense_to_vec_path, T=0.7) # original EN replacement generator
replacement_generator = SameDocumentNERReplacementGenerator()

In [8]:
class ClaimGenerator:
    def __init__(self, replacement_generator, corpus_recs, ner_json, qas_json, QA2D_model_path, device="cuda"):
        # QA2D model object
        print('Loading QA2D module >>>>>>>>')
        model_args = ModelArguments(model_name_or_path=QA2D_model_path)
        self.tokenizer, self.model, data_collator = load_tokenizer_and_model(model_args, lang="cs_CZ")
        print(f'Running on device: {device}')
        # self.model, self.tokenizer = model, tokenizer # TODO REMOVE
        self.device = device
        self.model.to(device)

        self.replacement_generator = replacement_generator

        self.corpus_recs = corpus_recs
        self.ners = read_json(ner_json)
        self.qas = read_json(qas_json)

    def predict(self, inputs, max_source_length=1024, batch_size=16):
        def pred_func(input_texts: List[str]) -> List[str]:
            with torch.no_grad():
                X = self.tokenizer(input_texts, max_length=max_source_length, padding=True, truncation=True, return_tensors="pt")
                X = {k: X[k].to(self.device) for k in X.keys()}
                Y = self.model.generate(**X, max_new_tokens=768)
                output_texts = self.tokenizer.batch_decode(
                    Y, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
            return output_texts
            
        predictions = batch_apply(pred_func, inputs, batch_size=batch_size)
        return predictions

    def _load_passage_entities(self, id_):
        passage_entities = []
        for ent_text, ent_type, _ in self.ners[id_]:
            passage_entities.append(f'{ent_text}:::{ent_type}') # group by entity name and type as in the QAS file
        return passage_entities
    
    def _load_precomputed_qas_for_entities(self, id_, passage_entities):
        if id_ not in self.qas:
            print(f"missing id: {id_}")
            return None
        QA_for_sample = self.qas[id_]
        QA_pairs = []
        for entity in passage_entities:
            if entity in QA_for_sample:
                ent_text, ent_type = entity.split(':::')
                question, answer = QA_for_sample[entity]
                QA_pairs.append({'question': question, 'answer': answer, 'answer_type': ent_type})
            else:
                print(f"missing entity: {entity} for id: {id_}")
                return None
        if len(QA_pairs) == 0:
            print(f"zero length pairs for id: {id_}")
            return None
        return QA_pairs 
        

    def generate_supported_claims(self, sample):
        texts, id_ = sample['text'], str(sample['id'])

        # Step 1: load entities in text
        passage_entities = self._load_passage_entities(id_)
        if len(passage_entities) == 0: # no NERs
            return None 

        # Step 2: load precomputed QAs for entities
        QA_pairs = self._load_precomputed_qas_for_entities(id_, passage_entities)
        if QA_pairs is None:
            return None

        # Step 3: QA2D
        # to_predict = [qa['question'] + ' [SEP] ' + qa['answer'] for qa in QA_pairs] # original model
        to_predict = [qa['answer'] + '</s>' + qa['question'] for qa in QA_pairs]
        results = []
        # try:
        results = self.predict(to_predict)
        # except:
            # return None
        if len(results) == 0:
            print(f"zero length results for id: {id_}")
            return None

        assert len(results) == len(QA_pairs)

        claims_for_sample = OrderedDict()
        for ent, claim in zip(passage_entities, results):
            claims_for_sample[ent] = claim
        return claims_for_sample

    def generate_refute_global_claims(self, sample):
        texts, id_ = sample['text'], str(sample['id'])

        # Step 1: load entities in text
        passage_entities = self._load_passage_entities(id_)
        if len(passage_entities) == 0: # no NERs
            return None 
        
        # Step 2: get entity replacement
        entity_replacement_dict = {} # get replacement beforehand to save time

        valid_entities = set()
        for ent in passage_entities:
            ent_text, _ = ent.split(':::')
            replacement = self.replacement_generator.get_options(ent_text, entity=ent, passage_entities=passage_entities)
            if replacement is not None:
                entity_replacement_dict[ent_text] = replacement
                valid_entities.add(ent)
        # print(f"entity_replacement_dict={entity_replacement_dict}")

        # Step 3: load precomputed QAs for entities
        QA_pairs = self._load_precomputed_qas_for_entities(id_, passage_entities)
        if QA_pairs is None:
            return None

        # Step 4: Answer Replacement
        to_predict = []
        replace_type = []
        replace_keys = []
    
        for qa in QA_pairs:
            ans_ent_text = qa['answer']
            ans_ent_type = qa['answer_type']
            if ans_ent_text == "" or ans_ent_type == "":
                continue
            replacement = entity_replacement_dict.get(ans_ent_text)
            if replacement is not None:
                # print(f'"{ans_ent_text}:::{ans_ent_type}" -> "{replacement}"')
                # predict_input = qa['question'] + ' [SEP] ' + replacement[0] # original model
                predict_input = qa['question'] + '</s>' + replacement[0]
                # print(f">>> {predict_input}")
                to_predict.append(predict_input)
                replace_keys.append(f"{ans_ent_text}:::{ans_ent_type}")
                replace_type.append(ans_ent_type)

        # Step 5: QA2D
        if len(to_predict) == 0:
            return None
        # results = []
        # try:
        results = self.predict(to_predict)
            # print(f"results={results}")
        # except:
            # return None
        if len(results) == 0:
            return None
        
        claims_for_sample = OrderedDict()
        for ent, claim in zip(replace_keys, results):
            claims_for_sample[ent] = claim
        return claims_for_sample


    def generate(self, claims_json, claim_type: str, save_every=0, cont=False):
        claim_type = claim_type.lower()
        assert claim_type in ["supported", "refuted"]
        start = 0
        if Path(claims_json).is_file():
            if cont:
                generated_claims = read_json(claims_json)
                print(f"file exists: {claims_json}, completed: {len(generated_claims)-1}/{len(self.corpus_recs)}")
                start = len(generated_claims)
            else:
                # print("--------------FIX!!!!!!!!!!!-------------------------")
                # generated_claims = read_json(claims_json)
                raise FileExistsError(f"File already exists: {claims_json} !!!")
        else:
            generated_claims = dict() # ordered since P3.7
        cnt = 1
        for sample in tqdm(self.corpus_recs[start:], initial=start, total=len(self.corpus_recs)):
            id_ = str(sample['id'])
            if claim_type == "supported":
                claims = self.generate_supported_claims(sample)
            elif claim_type == "refuted":
                claims = self.generate_refute_global_claims(sample)
            if claims is None:
                claims = {}
            generated_claims[id_] = claims
            cnt += 1
            if save_every > 0 and cnt % save_every == 0:
                write_json(claims_json, generated_claims)

        write_json(claims_json, generated_claims)

In [9]:
confs = [
    (corpus_recs_dev, "dev"),
    (corpus_recs_tst, "test"),
    (corpus_recs_trn, "train"),
]

device = "cuda" if torch.cuda.is_available() else "cpu"

for corpus_recs, name in confs:
    claim_generator = ClaimGenerator(replacement_generator, 
                                 corpus_recs, 
                                 ner_json=Path(QACG_ROOT, "ner", f"{name}_ners-PAV-ner-CNEC.json"), 
                                 qas_json=Path(QACG_ROOT, "qa", f"{name}_qas-PAV-ner-CNEC_mt5-large-cp59000.json"), 
                                 QA2D_model_path="/home/drchajan/devel/python/FC/Zero-shot-Fact-Verification/experiments/qa2d/facebook/mbart-large-cc25_cs_CZ/checkpoint-26000", 
                                 device=device)

    claim_generator.generate(Path(QACG_ROOT, "claim", f"{name}_sup_claims-PAV-ner-CNEC_mt5-large-cp59000.json"), "supported", save_every=100)
    # claim_generator.generate(Path(QACG_ROOT, "claim", f"{name}_ref_claims-PAV-ner-CNEC_mt5-large-cp59000.json"), "refuted", save_every=100)
    # claim_generator.generate(Path(QACG_ROOT, f"{name}_ref_claims.json"), "refuted", save_every=100, cont=True)

Loading QA2D module >>>>>>>>
Running on device: cuda


100%|██████████| 550/550 [09:40<00:00,  1.05s/it] 


Loading QA2D module >>>>>>>>
Running on device: cuda


100%|██████████| 584/584 [10:19<00:00,  1.06s/it] 


Loading QA2D module >>>>>>>>
Running on device: cuda


100%|██████████| 5031/5031 [1:23:50<00:00,  1.00it/s]  


In [36]:
replacement_generator.get_options("King George")

('American colonists', 'NOUN')