In [1]:
from collections import Counter, OrderedDict
import pathlib
from pathlib import Path
import sqlite3
import sys
from typing import Dict, List, Set
import unicodedata
import uuid

from aic_nlp_utils.fever import fever_detokenize
from aic_nlp_utils.json import read_json, read_jsonl, write_json, write_jsonl
from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs
import stanza
# stanza.download("en")
from tqdm import tqdm
import ujson

sys.path.append('Claim_Generation')
from T5_QG import pipeline
from distractor_generation import Distractor_Generation

%load_ext autoreload
%autoreload 2

- This is the full version, aimed at generating the SUPPORTED and REFUTED claims needed for evidence retrieval; NEI claims are kept for later.
- It is an extended version of the `claim_generation.ipynb`.
- Aimed to generate data for post LREV EnFEVER models (e.g., ColBERT v2).
- The source data are based on the **full** Wikipedia dump split into paragraphs. See `drchajan/notebooks/download_wiki.ipynb`.
- Only a sample of the full corpus is matched to pages appearing in the LREV EnFEVER so the generated claims (and models trained on them) are somewhat comparable to original EnFEVER dataset.
- NEI contexts are retrieved from other paragraphs of the same Wikipedia page.
- Fixed input and output formats for those we use in AIC.

**Notes**
- Multi-hop claims are currently ignored: only single evidence documents are used.

In [2]:
# DATA_DIR="/mnt/data/factcheck/claim_extraction/csfeversum/en/0.0.2"
# Configuration: language and dump date drive every Wikipedia-derived path below,
# so switching to another dump only requires changing DATE.
LANG = "en"
DATE = "20230220"
WIKI_ROOT = f"/mnt/data/factcheck/wiki/{LANG}/{DATE}"
# QACG artefacts (NERs, QAs, claims) are stored under the dump directory.
QACG_ROOT = f"{WIKI_ROOT}/qacg"
WIKI_CORPUS = f"{WIKI_ROOT}/paragraphs/{LANG}wiki-{DATE}-paragraphs.jsonl"
FEVER_ROOT = "/mnt/data/factcheck/fever/data-en-lrev/fever-data"

In [3]:
# Load the paragraph-level corpus, then build a paragraph-id -> list-index lookup
# and the set of distinct page identifiers ("did") present in it.
corpus = read_jsonl(WIKI_CORPUS)
corpus_id2idx = dict((rec["id"], idx) for idx, rec in enumerate(corpus))
corpus_pages = {rec["did"] for rec in corpus}
len(corpus_pages)

6204729

In [4]:
# Peek at the first record to see the fields each corpus paragraph carries.
corpus[0]

{'id': 'Anarchism_0',
 'did': 'Anarchism',
 'bid': 0,
 'text': 'Anarchism',
 'url': 'https://en.wikipedia.org/wiki?curid=12',
 'revid': '6068332'}

In [5]:
def extract_fever_evidence_pages(split_jsonls: List):
    """Collect the evidence pages referenced by verifiable claims.

    Args:
        split_jsonls: paths of JSONL split files; each is printed and then read
            with ``read_jsonl``.

    Returns:
        Set with the third field (``ev[2]``) of every evidence item found in
        records whose ``"verifiable"`` value is ``"VERIFIABLE"``
        (presumably the Wikipedia page title - confirm against the FEVER schema).
    """
    pages = set()
    for split_path in split_jsonls:
        print(split_path)
        for record in read_jsonl(split_path):
            if record["verifiable"] != "VERIFIABLE":
                continue
            pages.update(
                evidence[2]
                for evidence_set in record["evidence"]
                for evidence in evidence_set
            )
    return pages



# Evidence pages per split; the generator is consumed in order, so the splits
# are read (and their paths printed) as train, dev, test.
fever_pages_trn, fever_pages_dev, fever_pages_tst = (
    extract_fever_evidence_pages([Path(FEVER_ROOT, split_file)])
    for split_file in ("train.jsonl", "paper_dev.jsonl", "paper_test.jsonl")
)
len(fever_pages_trn), len(fever_pages_dev), len(fever_pages_tst)

/mnt/data/factcheck/fever/data-en-lrev/fever-data/train.jsonl
/mnt/data/factcheck/fever/data-en-lrev/fever-data/paper_dev.jsonl
/mnt/data/factcheck/fever/data-en-lrev/fever-data/paper_test.jsonl


(12549, 1460, 1499)

In [7]:
# trying to match fever pages to corpus pages so we can generate claims based on topics comparable to EnFEVER
# corpus is based on newer dump so the match can't be perfect
def match_fever_to_corpus_pages(fever_pages: Set[str], corpus_pages: Set[str]):
    """Return the detokenized FEVER page names that also exist in the corpus.

    Every FEVER page name is passed through ``fever_detokenize`` before the
    comparison; the number of matches out of the detokenized total is printed.
    """
    detokenized_pages = {fever_detokenize(page) for page in fever_pages}
    common_pages = detokenized_pages.intersection(corpus_pages)
    print(f"matched {len(common_pages)}/{len(detokenized_pages)} pages")
    return common_pages

# Match each split's FEVER pages against the corpus (train, dev, test in that order).
sel_corpus_pages_trn, sel_corpus_pages_dev, sel_corpus_pages_tst = (
    match_fever_to_corpus_pages(split_pages, corpus_pages)
    for split_pages in (fever_pages_trn, fever_pages_dev, fever_pages_tst)
)

matched 11247/12549 pages
matched 1339/1460 pages
matched 1358/1499 pages


In [8]:
def extract_corpus_pages(corpus, corpus_id2idx, sel_corpus_pages: Set[str]):
    """Return the first-paragraph record of every selected page.

    Args:
        corpus: list of paragraph records.
        corpus_id2idx: mapping from paragraph id to its index in ``corpus``.
        sel_corpus_pages: page identifiers to extract.

    Returns:
        List of corpus records with id ``"<page>_1"``, ordered by page name so
        the result is reproducible across kernel restarts (set iteration order
        is not). Pages without a ``"_1"`` paragraph are skipped and reported
        instead of raising ``KeyError``.
    """
    recs = []
    missing = []
    for page in sorted(sel_corpus_pages):
        # take the first paragraph - should roughly mimic the leading parts from EnFEVER
        idx = corpus_id2idx.get(page + "_1")
        if idx is None:
            missing.append(page)
            continue
        recs.append(corpus[idx])
    if missing:
        print(f"skipped {len(missing)} pages without a first paragraph")
    return recs
    
# First-paragraph records for each split (train, dev, test in that order).
corpus_recs_trn, corpus_recs_dev, corpus_recs_tst = (
    extract_corpus_pages(corpus, corpus_id2idx, split_pages)
    for split_pages in (sel_corpus_pages_trn, sel_corpus_pages_dev, sel_corpus_pages_tst)
)

In [9]:
def extract_ners(corpus_recs, ner_json):
    """Run stanza NER over each record's text and save the entity counts as JSON.

    For every record the output maps its ``"id"`` to a list of triplets
    ``(entity text, entity type, number of occurrences in the text)``, sorted
    by decreasing count. Equal counts keep the order of first occurrence in the
    text, so the output is reproducible.

    Args:
        corpus_recs: iterable of records with ``"id"`` and ``"text"`` keys.
        ner_json: output path; missing parent directories are created.
    """
    stanza_nlp = stanza.Pipeline('en', use_gpu=True, processors="tokenize,ner")
    entity_dict = OrderedDict()
    for rec in tqdm(corpus_recs):
        doc = stanza_nlp(rec["text"])
        # count identical (text, type) pairs
        ner_cnts = Counter((ent.text, ent.type) for ent in doc.ents)
        # most_common() is sorted by decreasing count and stable for ties
        entity_dict[rec["id"]] = [
            (ent_text, ent_type, cnt)
            for (ent_text, ent_type), cnt in ner_cnts.most_common()
        ]
    Path(ner_json).parent.mkdir(parents=True, exist_ok=True)
    write_json(ner_json, entity_dict)

# Extract and store NERs for every split under QACG_ROOT (train runs last).
extract_ners(corpus_recs_dev, Path(QACG_ROOT, "dev_ners.json"))
extract_ners(corpus_recs_tst, Path(QACG_ROOT, "test_ners.json"))
extract_ners(corpus_recs_trn, Path(QACG_ROOT, "train_ners.json"))

2023-04-03 12:35:46 INFO: Loading these models for language: en (English):
| Processor | Package   |
-------------------------
| tokenize  | ewt       |
| ner       | ontonotes |

2023-04-03 12:35:46 INFO: Use device: gpu
2023-04-03 12:35:46 INFO: Loading: tokenize
2023-04-03 12:35:48 INFO: Loading: ner
2023-04-03 12:35:49 INFO: Done loading processors!
100%|██████████| 1339/1339 [00:53<00:00, 24.98it/s]
2023-04-03 12:36:42 INFO: Loading these models for language: en (English):
| Processor | Package   |
-------------------------
| tokenize  | ewt       |
| ner       | ontonotes |

2023-04-03 12:36:42 INFO: Use device: gpu
2023-04-03 12:36:42 INFO: Loading: tokenize
2023-04-03 12:36:42 INFO: Loading: ner
2023-04-03 12:36:43 INFO: Done loading processors!
100%|██████████| 1358/1358 [00:57<00:00, 23.58it/s]
2023-04-03 12:37:40 INFO: Loading these models for language: en (English):
| Processor | Package   |
-------------------------
| tokenize  | ewt       |
| ner       | ontonotes |

2023

In [11]:
def generate_qas(corpus_recs, ner_json, qas_json):
    """Generate one question per named entity of every record and save them as JSON.

    The output maps each record id (as a string) to a dict keyed by
    ``"<entity text>:::<entity type>"`` whose values are ``[question, answer]``.
    Records without entities, and records for which question generation
    raises, are counted as invalid and left out of the output.

    Args:
        corpus_recs: iterable of records with ``"id"`` and ``"text"`` keys.
        ner_json: path of the entity file read with ``read_json``; it must hold
            an entry for every record id.
        qas_json: output path; missing parent directories are created.
    """
    # QG NLP object
    gpu_index = 0

    print('Loading QG module >>>>>>>>')
    qg_nlp = pipeline("question-generation", model='valhalla/t5-base-qg-hl', qg_format="highlight", gpu_index=gpu_index)
    print('QG module loaded.')

    ners = read_json(ner_json)

    qas = OrderedDict()
    invalid_sample = 0
    for rec in tqdm(corpus_recs):
        rec_id = str(rec['id'])
        entities = ners[rec_id]
        if not entities:
            invalid_sample += 1
            continue

        # create a batch: the same source text paired with each entity as the answer
        answers = [ent_text for ent_text, _ent_type, _ent_cnt in entities]
        sources = [rec['text']] * len(answers)

        # question generation
        try:
            results = qg_nlp.batch_qg_with_answer(sources, answers)
        except Exception as err:
            # Best effort: skip the record, but do not swallow KeyboardInterrupt /
            # SystemExit and do not hide why the record failed.
            tqdm.write(f"question generation failed for id: {rec_id}: {err!r}")
            invalid_sample += 1
            continue

        if len(results) == 0:
            continue

        # save results; results are assumed to be aligned with `entities` by position
        result_for_sample = {}
        for (ent_text, ent_type, _), qa in zip(entities, results):
            result_for_sample[f'{ent_text}:::{ent_type}'] = [qa['question'], qa['answer']]

        qas[rec_id] = result_for_sample

    print(f'#invalid samples: {invalid_sample}')
    Path(qas_json).parent.mkdir(parents=True, exist_ok=True)
    write_json(qas_json, qas)


# Generate questions for every split from the NER files written under QACG_ROOT.
generate_qas(corpus_recs_dev, Path(QACG_ROOT, "dev_ners.json"), Path(QACG_ROOT, "dev_qas.json"))
generate_qas(corpus_recs_tst, Path(QACG_ROOT, "test_ners.json"), Path(QACG_ROOT, "test_qas.json"))
generate_qas(corpus_recs_trn, Path(QACG_ROOT, "train_ners.json"), Path(QACG_ROOT, "train_qas.json"))

Loading QG module >>>>>>>>
QG module loaded.


100%|██████████| 1358/1358 [14:21<00:00,  1.58it/s]


#invalid samples: 51
Loading QG module >>>>>>>>
QG module loaded.


100%|██████████| 11247/11247 [1:59:33<00:00,  1.57it/s] 


#invalid samples: 373


In [9]:
class ClaimGenerator:
    """Turns precomputed entities and question/answer pairs into claims.

    SUPPORTED claims are produced by feeding ``"<question> [SEP] <answer>"`` to
    the QA2D model; REFUTED claims use the same questions with the answer
    swapped for a replacement returned by the replacement generator.
    """

    def __init__(self, corpus_recs, ner_json, qas_json, QA2D_model_path, sense_to_vec_path, gpu_index=0):
        """Load the QA2D model, the replacement generator and the precomputed files.

        Args:
            corpus_recs: records (with ``"id"`` and ``"text"``) to generate claims for.
            ner_json: path of the per-record entity file.
            qas_json: path of the per-record question/answer file.
            QA2D_model_path: passed to ``Seq2SeqModel`` as ``encoder_decoder_name``.
            sense_to_vec_path: passed to ``Distractor_Generation`` as ``sense2vec_path``.
            gpu_index: passed to ``Seq2SeqModel`` as ``cuda_device``.
        """
        # QA2D model object
        print('Loading QA2D module >>>>>>>>')
        model_args = Seq2SeqArgs()
        model_args.max_length = 64
        model_args.silent = True

        self.QA2D_model = Seq2SeqModel(
            encoder_decoder_type="bart",
            encoder_decoder_name=QA2D_model_path,
            cuda_device=gpu_index,
            args=model_args
        )

        print('Loading Replacement Generator module >>>>>>>>')
        self.replacement_generator = Distractor_Generation(sense2vec_path=sense_to_vec_path, T=0.7)

        self.corpus_recs = corpus_recs
        self.ners = read_json(ner_json)
        self.qas = read_json(qas_json)

    def _load_passage_entities(self, id_):
        """Return the ``"<text>:::<type>"`` keys of the entities of record ``id_``."""
        # group by entity name and type as in the QAS file
        return [f'{ent_text}:::{ent_type}' for ent_text, ent_type, _ in self.ners[id_]]

    def _load_precomputed_qas_for_entities(self, id_, passage_entities):
        """Look up the precomputed QA pair of every entity of record ``id_``.

        Returns:
            List of ``{'question', 'answer', 'answer_type'}`` dicts aligned with
            ``passage_entities``, or ``None`` (after printing the reason) when
            the record or any of its entities has no QA pair, or the list is empty.
        """
        if id_ not in self.qas:
            print(f"missing id: {id_}")
            return None
        QA_for_sample = self.qas[id_]
        QA_pairs = []
        for entity in passage_entities:
            if entity not in QA_for_sample:
                print(f"missing entity: {entity} for id: {id_}")
                return None
            # split on the LAST separator so entity text containing ':::' still parses
            _, ent_type = entity.rsplit(':::', 1)
            question, answer = QA_for_sample[entity]
            QA_pairs.append({'question': question, 'answer': answer, 'answer_type': ent_type})
        if len(QA_pairs) == 0:
            print(f"zero length pairs for id: {id_}")
            return None
        return QA_pairs

    def generate_supported_claims(self, sample):
        """Generate one claim per entity of ``sample`` from its original QA pairs.

        Returns:
            ``OrderedDict`` mapping ``"<text>:::<type>"`` to the QA2D output, or
            ``None`` when the record has no entities, lacks QA pairs, or the
            model returns nothing.
        """
        id_ = str(sample['id'])

        # Step 1: load entities in text
        passage_entities = self._load_passage_entities(id_)
        if len(passage_entities) == 0:  # no NERs
            return None

        # Step 2: load precomputed QAs for entities
        QA_pairs = self._load_precomputed_qas_for_entities(id_, passage_entities)
        if QA_pairs is None:
            return None

        # Step 3: QA2D
        to_predict = [qa['question'] + ' [SEP] ' + qa['answer'] for qa in QA_pairs]
        results = self.QA2D_model.predict(to_predict)
        if len(results) == 0:
            print(f"zero length results for id: {id_}")
            return None

        assert len(results) == len(QA_pairs)

        # QA_pairs is aligned with passage_entities, hence so are the results
        return OrderedDict(zip(passage_entities, results))

    def generate_refute_global_claims(self, sample):
        """Generate claims for ``sample`` with each answer swapped for a replacement.

        Only entities for which the replacement generator returns something are
        used; the first element of the returned value replaces the answer
        (presumably the replacement text - confirm against
        ``Distractor_Generation.get_options``).

        Returns:
            ``OrderedDict`` mapping ``"<answer>:::<type>"`` to the QA2D output,
            or ``None`` when nothing could be generated.
        """
        id_ = str(sample['id'])

        # Step 1: load entities in text
        passage_entities = self._load_passage_entities(id_)
        if len(passage_entities) == 0:  # no NERs
            return None

        # Step 2: load precomputed QAs for entities; done before the replacement
        # lookups so records without QAs are rejected cheaply
        QA_pairs = self._load_precomputed_qas_for_entities(id_, passage_entities)
        if QA_pairs is None:
            return None

        # Step 3: get entity replacements beforehand to save time
        entity_replacement_dict = {}
        for ent in passage_entities:
            ent_text, _ = ent.rsplit(':::', 1)
            replacement = self.replacement_generator.get_options(ent_text)
            if replacement is not None:
                entity_replacement_dict[ent_text] = replacement

        # Step 4: Answer Replacement
        to_predict = []
        replace_keys = []
        for qa in QA_pairs:
            ans_ent_text = qa['answer']
            ans_ent_type = qa['answer_type']
            if ans_ent_text == "" or ans_ent_type == "":
                continue
            replacement = entity_replacement_dict.get(ans_ent_text)
            if replacement is not None:
                to_predict.append(qa['question'] + ' [SEP] ' + replacement[0])
                replace_keys.append(f"{ans_ent_text}:::{ans_ent_type}")

        # Step 5: QA2D
        if len(to_predict) == 0:
            return None
        try:
            results = self.QA2D_model.predict(to_predict)
        except Exception as err:
            # Best effort per record, but never swallow KeyboardInterrupt / SystemExit.
            print(f"QA2D prediction failed for id: {id_}: {err!r}")
            return None
        if len(results) == 0:
            return None

        return OrderedDict(zip(replace_keys, results))

    def generate(self, claims_json, claim_type: str):
        """Generate claims of ``claim_type`` for all records and write them as JSON.

        Args:
            claims_json: output path; missing parent directories are created.
            claim_type: ``"supported"`` or ``"refuted"`` (case-insensitive).

        The output maps every record id to its claims; records for which no
        claims could be generated map to an empty dict.
        """
        claim_type = claim_type.lower()
        assert claim_type in ["supported", "refuted"], f"unknown claim type: {claim_type}"
        generate_for_sample = {
            "supported": self.generate_supported_claims,
            "refuted": self.generate_refute_global_claims,
        }[claim_type]

        generated_claims = OrderedDict()
        for sample in tqdm(self.corpus_recs):
            claims = generate_for_sample(sample)
            generated_claims[str(sample['id'])] = {} if claims is None else claims

        # same convention as the NER and QA writers: make sure the target directory exists
        Path(claims_json).parent.mkdir(parents=True, exist_ok=True)
        write_json(claims_json, generated_claims)

In [10]:
# (records, split name) pairs; the split name selects the "<name>_ners.json" /
# "<name>_qas.json" inputs and the "<name>_*_claims.json" outputs under QACG_ROOT.
confs = [
    (corpus_recs_dev, "dev"),
    (corpus_recs_tst, "test"),
    (corpus_recs_trn, "train"),
]

# A new ClaimGenerator (and hence a new model load) is created for every split.
for corpus_recs, name in confs:
    claim_generator = ClaimGenerator(corpus_recs, 
                                 ner_json=Path(QACG_ROOT, f"{name}_ners.json"), 
                                 qas_json=Path(QACG_ROOT, f"{name}_qas.json"), 
                                 QA2D_model_path="dependencies/QA2D_model", 
                                 sense_to_vec_path="dependencies/s2v_old")

    # Supported claims first, then refuted claims, each written to its own file.
    claim_generator.generate(Path(QACG_ROOT, f"{name}_sup_claims.json"), "supported")
    claim_generator.generate(Path(QACG_ROOT, f"{name}_ref_claims.json"), "refuted")

Loading QA2D module >>>>>>>>
Loading Replacement Generator module >>>>>>>>


100%|██████████| 1339/1339 [1:50:44<00:00,  4.96s/it] 


Loading QA2D module >>>>>>>>
Loading Replacement Generator module >>>>>>>>


100%|██████████| 1358/1358 [1:50:21<00:00,  4.88s/it] 


Loading QA2D module >>>>>>>>
Loading Replacement Generator module >>>>>>>>


100%|██████████| 11247/11247 [15:36:13<00:00,  4.99s/it]  


In [19]:
# Stand-alone replacement generator for interactive spot checks (same T as ClaimGenerator uses).
replacement_generator = Distractor_Generation(sense2vec_path="dependencies/s2v_old/", T=0.7)

In [36]:
# Spot-check what the generator returns for a sample entity.
replacement_generator.get_options("King George")

('American colonists', 'NOUN')