In [1]:
from collections import Counter, OrderedDict
import json
import pathlib
from pathlib import Path
from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs
import stanza
# stanza.download("en")
import sys
import uuid

sys.path.append('Claim_Generation')

from T5_QG import pipeline


from tqdm import tqdm

This is a simplified version of the claim-generation pipeline that generates only SUPPORTED claims. The input and output formats have also been changed to the ones we use in AIC.

In [2]:
# Dataset root: the split files (train.jsonl / test.jsonl) are read from here and
# every intermediate artifact (NERs, QAs, claims) goes to its "qacg" subdirectory.
DATA_DIR = "/mnt/data/factcheck/claim_extraction/csfeversum/en/0.0.2"

In [3]:
def read_json(fname, object_pairs_hook=OrderedDict):
    """Load and return the JSON document stored in `fname`.

    The file is decoded explicitly as UTF-8. `write_json` writes UTF-8 with
    ensure_ascii=False, so relying on the platform's default encoding would fail
    (or mis-decode) on machines whose locale is not UTF-8.

    Args:
        fname: path (str or os.PathLike) of the JSON file.
        object_pairs_hook: constructor used for every JSON object; the default
            OrderedDict keeps the key order found in the file.
    """
    with open(fname, 'r', encoding='utf8') as json_file:
        return json.load(json_file, object_pairs_hook=object_pairs_hook)

def write_json(fname, data, indent=3):
    """Serialise `data` to `fname` as indented JSON encoded in UTF-8.

    Non-ASCII characters are written as-is (ensure_ascii=False), and any value
    the json module cannot encode natively is converted with `str`.
    """
    target_path = str(fname)
    with open(target_path, 'w', encoding='utf8') as out_file:
        json.dump(
            data,
            out_file,
            indent=indent,
            default=str,
            ensure_ascii=False,
        )


def read_jsonl(jsonl, object_pairs_hook=OrderedDict):
    """Load a JSON Lines file and return its records as a list.

    Each non-blank line is parsed as one JSON document. Blank lines (e.g. a
    stray empty line at the end of a hand-edited file) are skipped instead of
    crashing `json.loads`.

    The file is decoded as UTF-8 to match `write_jsonl`, which writes UTF-8 with
    ensure_ascii=False.

    Args:
        jsonl: path (str or os.PathLike) of the .jsonl file.
        object_pairs_hook: constructor for JSON objects. It defaults to OrderedDict,
            as before, and is exposed for consistency with `read_json`.
    """
    with open(jsonl, 'r', encoding='utf8') as json_file:
        return [
            json.loads(line, object_pairs_hook=object_pairs_hook)
            for line in json_file
            if line.strip()
        ]
    

def write_jsonl(jsonl, data):
    """Write the records of the iterable `data` to `jsonl`, one JSON document per line.

    Output is UTF-8 with non-ASCII characters kept verbatim. Values the json
    module cannot encode natively are converted with `str`. Every record,
    including the last one, is terminated by a newline.
    """
    with open(jsonl, 'w', encoding='utf8') as out_file:
        for record in data:
            json.dump(record, out_file, default=str, ensure_ascii=False)
            out_file.write("\n")

In [4]:
def extract_ners(input_jsonl, ner_json, use_gpu=True):
    """Run Stanza NER on every record of `input_jsonl` and save per-record entity counts.

    Args:
        input_jsonl: JSONL file; every record must have "id" and "text" keys.
        ner_json: output JSON path. It maps each record id to a list of
            (entity_text, entity_type, count) triplets, stored as JSON arrays.
            `count` is how many times that (text, type) pair occurs in the record's
            text. Triplets are sorted by decreasing count, and ties keep the order
            in which the entities first appear.
        use_gpu: forwarded to `stanza.Pipeline`. It defaults to True, matching the
            previous hard-coded value.
    """
    stanza_nlp = stanza.Pipeline('en', use_gpu=use_gpu, processors="tokenize,ner")
    entity_dict = OrderedDict()
    for sample in tqdm(read_jsonl(input_jsonl)):
        doc = stanza_nlp(sample["text"])
        # Count each distinct (text, type) pair. most_common() sorts by count
        # (descending) and breaks ties by first occurrence, so the output is
        # deterministic. The previous iteration over set(...) was not: string
        # hashing is randomised per interpreter run.
        ner_counts = Counter((ent.text, ent.type) for ent in doc.ents)
        entity_dict[sample["id"]] = [
            (ent_text, ent_type, count)
            for (ent_text, ent_type), count in ner_counts.most_common()
        ]
    write_json(ner_json, entity_dict)

# NER extraction for the train split (recorded run: 42,383 records in ~49 min on GPU).
# For the test split, swap in "test.jsonl" and "test_ners.json".
extract_ners(
    input_jsonl=Path(DATA_DIR, "train.jsonl"),
    ner_json=Path(DATA_DIR, "qacg", "train_ners.json"),
)

2023-02-16 11:32:02 INFO: Loading these models for language: en (English):
| Processor | Package   |
-------------------------
| tokenize  | ewt       |
| ner       | ontonotes |

2023-02-16 11:32:02 INFO: Use device: gpu
2023-02-16 11:32:02 INFO: Loading: tokenize
2023-02-16 11:32:08 INFO: Loading: ner
2023-02-16 11:32:10 INFO: Done loading processors!
100%|██████████| 42383/42383 [49:25<00:00, 14.29it/s]  


In [5]:
def generate_qas(input_jsonl, ner_json, qas_json, gpu_index=0):
    """Generate one question per named entity of every record, using the entity as the answer.

    Args:
        input_jsonl: JSONL file; every record must have "id" and "text" keys.
        ner_json: the file produced by `extract_ners` for the same input.
        qas_json: output JSON path. It maps str(record id) to a dict
            {"<entity_text>:::<entity_type>": [question, answer]}, where question
            and answer are taken from the QG pipeline's output.
        gpu_index: forwarded to the question-generation pipeline. It defaults to 0,
            the previously hard-coded value.

    Some records are left out of the output and counted as invalid:
    - records with no entry or no entities in `ner_json`;
    - records whose question generation raises;
    - records whose question generation returns nothing.
    """
    print('Loading QG module >>>>>>>>')
    qg_nlp = pipeline("question-generation", model='valhalla/t5-base-qg-hl', qg_format="highlight", gpu_index=gpu_index)
    print('QG module loaded.')

    src = read_jsonl(input_jsonl)
    ners = read_json(ner_json)

    qas = OrderedDict()
    invalid_sample = 0
    for sample in tqdm(src):
        id_ = str(sample['id'])
        # .get: a record missing from the NER file is skipped, not a KeyError hours into the run
        entities = ners.get(id_, [])
        if not entities:
            invalid_sample += 1
            continue

        # One batch per record: the same passage, with each entity highlighted as the answer.
        sources = [sample['text']] * len(entities)
        answers = [ent_text for ent_text, _, _ in entities]
        try:
            results = qg_nlp.batch_qg_with_answer(sources, answers)
        except Exception as exc:
            # Best effort: skip this record and keep going. Unlike the former bare
            # `except:`, this does not swallow KeyboardInterrupt/SystemExit.
            tqdm.write(f'QG failed for id {id_}: {exc!r}')
            invalid_sample += 1
            continue
        if not results:
            invalid_sample += 1
            continue

        # zip (rather than indexing entities[i]) cannot raise IndexError if the
        # pipeline returns fewer results than entities.
        qas[id_] = {
            f'{ent_text}:::{ent_type}': [qa['question'], qa['answer']]
            for (ent_text, ent_type, _), qa in zip(entities, results)
        }

    print(f'#invalid samples: {invalid_sample}')
    write_json(qas_json, qas)


# Question generation for the train split.
# For the test split, swap in "test.jsonl", "test_ners.json" and "test_qas.json".
generate_qas(
    input_jsonl=Path(DATA_DIR, "train.jsonl"),
    ner_json=Path(DATA_DIR, "qacg", "train_ners.json"),
    qas_json=Path(DATA_DIR, "qacg", "train_qas.json"),
)

Loading QG module >>>>>>>>


Downloading (…)okenizer_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/15.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/1.79k [00:00<?, ?B/s]

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/892M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/629 [00:00<?, ?B/s]

Downloading (…)ve/main/spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/31.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/892M [00:00<?, ?B/s]

QG module loaded.


  8%|▊         | 3479/42383 [1:29:47<15:16:24,  1.41s/it]

  8%|▊         | 3482/42383 [1:29:51<13:56:17,  1.29s/it]

In [8]:
def generate_claims(input_jsonl, ner_json, qas_json, claims_json, QA2D_model_path, gpu_index=0):
    """Convert the precomputed question/answer pairs of every record into claims with a QA2D model.

    Args:
        input_jsonl: JSONL file; every record must have "id" and "text" keys.
        ner_json: the file produced by `extract_ners` for the same input.
        qas_json: the file produced by `generate_qas` for the same input.
        claims_json: output JSON path. It maps str(record id) to a dict
            {"<entity_text>:::<entity_type>": claim}. The dict is empty for
            records that could not be processed.
        QA2D_model_path: BART encoder-decoder checkpoint loaded through
            simpletransformers' Seq2SeqModel.
        gpu_index: CUDA device index for the model.
    """
    print('Loading QA2D module >>>>>>>>')
    model_args = Seq2SeqArgs()
    model_args.max_length = 64
    model_args.silent = True

    QA2D_model = Seq2SeqModel(
        encoder_decoder_type="bart",
        encoder_decoder_name=QA2D_model_path,
        cuda_device=gpu_index,
        args=model_args
    )

    src = read_jsonl(input_jsonl)
    ners = read_json(ner_json)
    qas = read_json(qas_json)

    def claims_for_sample(sample):
        """Return an OrderedDict entity-key -> claim for one record, or None if it cannot be processed."""
        id_ = str(sample['id'])

        # Step 1: entity keys, in the same "<text>:::<type>" format generate_qas uses.
        # .get: a record absent from the NER file is skipped instead of raising KeyError.
        passage_entities = [f'{ent_text}:::{ent_type}' for ent_text, ent_type, _ in ners.get(id_, [])]
        if not passage_entities:
            return None  # no NERs

        # Step 2: precomputed QA pair for every entity; any gap discards the whole record.
        # tqdm.write (not print) keeps the progress bar intact.
        if id_ not in qas:
            tqdm.write(f"missing id: {id_}")
            return None
        qa_for_sample = qas[id_]
        qa_pairs = []
        for entity in passage_entities:
            if entity not in qa_for_sample:
                tqdm.write(f"missing entity: {entity} for id: {id_}")
                return None
            question, answer = qa_for_sample[entity]
            qa_pairs.append((question, answer))
        # qa_pairs now has exactly one entry per passage entity, in the same order.

        # Step 3: QA2D
        to_predict = [question + ' [SEP] ' + answer for question, answer in qa_pairs]
        results = QA2D_model.predict(to_predict)
        if len(results) == 0:
            tqdm.write(f"zero length results for id: {id_}")
            return None
        assert len(results) == len(qa_pairs), f"QA2D returned {len(results)} claims for {len(qa_pairs)} QA pairs (id {id_})"

        return OrderedDict(zip(passage_entities, results))

    generated_claims = OrderedDict()
    for sample in tqdm(src):
        # Unprocessable records are kept in the output with an empty dict.
        generated_claims[str(sample['id'])] = claims_for_sample(sample) or {}

    write_json(claims_json, generated_claims)

In [9]:
# generate_claims(
#     input_jsonl=Path(DATA_DIR, "test.jsonl"),
#     ner_json=Path(DATA_DIR, "qacg", "test_ners.json"),
#     qas_json=Path(DATA_DIR, "qacg", "test_qas.json"),
#     claims_json=Path(DATA_DIR, "qacg", "test_claims.json"),
#     QA2D_model_path="dependencies/QA2D_model",
#     gpu_index=0)

Loading QA2D module >>>>>>>>


 77%|███████▋  | 4058/5288 [1:12:18<13:58,  1.47it/s]

missing id: 33582


 80%|███████▉  | 4209/5288 [1:15:10<29:11,  1.62s/it]

missing id: 11387


 82%|████████▏ | 4328/5288 [1:17:25<14:17,  1.12it/s]

missing id: 10898
missing id: 10956
missing id: 12681
missing id: 12937
missing id: 14703
missing id: 16676
missing id: 18741
missing id: 18985
missing id: 19645
missing id: 23393
missing id: 23443
missing id: 26663
missing id: 28920
missing id: 29516
missing id: 30371
missing id: 30613
missing id: 31112
missing id: 31325
missing id: 34091
missing id: 35831
missing id: 38082
missing id: 40064
missing id: 40149
missing id: 42307
missing id: 45006
missing id: 45622
missing id: 46943
missing id: 49045
missing id: 51562
missing id: 54051
missing id: 59046
missing id: 65215
missing id: 6899


 88%|████████▊ | 4633/5288 [1:21:53<09:13,  1.18it/s]

missing id: 11537
missing id: 12943
missing id: 13271
missing id: 13442
missing id: 13651
missing id: 13980
missing id: 14049
missing id: 15454
missing id: 15749
missing id: 17596
missing id: 18178
missing id: 18589
missing id: 20040
missing id: 21972
missing id: 22560
missing id: 25971
missing id: 26397
missing id: 27224
missing id: 28231
missing id: 28999
missing id: 29643
missing id: 30025
missing id: 30544
missing id: 30693
missing id: 31944
missing id: 32674
missing id: 35182
missing id: 35637
missing id: 38876
missing id: 40427
missing id: 41853
missing id: 43778
missing id: 44256
missing id: 44399
missing id: 47566
missing id: 48167
missing id: 48216
missing id: 49577
missing id: 50056
missing id: 51427
missing id: 51450
missing id: 52280
missing id: 54286
missing id: 54292
missing id: 54921
missing id: 57709
missing id: 58912
missing id: 61076
missing id: 61761
missing id: 62458
missing id: 64294
missing id: 64875
missing id: 65908
missing id: 66854
missing id: 67035
missing id

100%|██████████| 5288/5288 [1:30:31<00:00,  1.03s/it]


In [None]:
# Claim generation for the train split; the test split uses the same call with "test_*" file names.
generate_claims(
    QA2D_model_path="dependencies/QA2D_model",
    gpu_index=0,
    input_jsonl=Path(DATA_DIR, "train.jsonl"),
    ner_json=Path(DATA_DIR, "qacg", "train_ners.json"),
    qas_json=Path(DATA_DIR, "qacg", "train_qas.json"),
    claims_json=Path(DATA_DIR, "qacg", "train_claims.json"),
)