In [1]:
import numpy as np
import sqlite3
import torch
from tqdm import tqdm
import unicodedata

from collections import defaultdict, OrderedDict, Counter
from dataclasses import dataclass
import datetime as dt
from itertools import chain
import os
import pathlib
from pathlib import Path
import string
import pandas as pd
import unicodedata as ud
from time import time
from typing import Dict, Type, Callable, List, Union
import sys
import ujson

from aic_nlp_utils.json import read_jsonl, read_json, write_json, write_jsonl
from aic_nlp_utils.encoding import nfc
from aic_nlp_utils.fever import fever_detokenize
from sentence_transformers import CrossEncoder, util
import textwrap

sys.path.insert(0, '/home/drchajan/devel/python/FC/ColBERTv2') # ignore other ColBERT installations

from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Indexer, Searcher
from colbert.data import Queries, Collection
from colbert import Trainer
from colbert.utilities.prepare_data import import_qacg_split, import_qacg_split_subsample, generate_original_id2pid_mapping, export_as_anserini_collection, anserini_retrieve_claims, sbert_CE_rerank, generate_triples_by_retrieval, generate_triples_by_retrieval_nway

%load_ext autoreload
%autoreload 2

  from tqdm.autonotebook import tqdm
No CUDA runtime is found, using CUDA_HOME='/mnt/appl/software/CUDA/11.7.0'
  warn("The installed version of bitsandbytes was compiled without GPU support. "


/home/drchajan/devel/python/FC/fc_env_hflarge/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32


In [16]:
def import_sqlite(in_path, out_jsonl):
    """Export the `documents` table of an SQLite database to a JSONL file.

    Args:
        in_path: path of the SQLite database file.
        out_jsonl: target path handed over to `write_jsonl`.

    Returns:
        The list of exported records; each record is a dict with the keys
        `id`, `did`, `bid`, `date`, `keywords` and `text`.
    """
    from contextlib import closing

    columns = ("id", "did", "bid", "date", "keywords", "text")
    # `with sqlite3.connect(...)` alone only commits/rolls back the transaction;
    # closing() is what actually releases the database handle.
    with closing(sqlite3.connect(in_path, detect_types=sqlite3.PARSE_DECLTYPES)) as connection:
        cursor = connection.cursor()
        cursor.execute("SELECT id, did, bid, date, keywords, text FROM documents")
        # iterate the cursor directly instead of materializing an extra copy via fetchall()
        recs = [dict(zip(columns, row)) for row in cursor]
    write_jsonl(out_jsonl, recs)
    return recs
    

# One-off conversion: dump the filtered CTK SQLite database into the JSONL file that
# DATA_CORPUS points to when the ctknews configuration is active.
data = import_sqlite("/mnt/data/ctknews/factcheck/par6/interim/ctk_filtered.db", "/mnt/data/ctknews/factcheck/par6/interim/jsonl/ctk_filtered_all_cols.jsonl")

In [17]:
# Variant of the generated data to export (exactly one value is active):
# APPROACH = "full" # all generated data
# APPROACH = "balanced" # balanced classes
# APPROACH = "fever_size" # subsampled to have exact fever distribution
APPROACH = "balanced_shuf" # balanced classes, shuffled

LANG = "cs"
NER_DIR = "PAV-ner-CNEC"
ANSERINI_LANG = "cs"

# Corpus selection: keep exactly one DATA_ROOT / DATA_CORPUS / sizes group un-commented.

# DATA_ROOT = f"/mnt/data/cro/factcheck/v1"
# DATA_CORPUS = Path(DATA_ROOT, "interim", "cro_paragraphs_filtered.jsonl")
# TRN_SIZE, DEV_SIZE, TST_SIZE = 20000, 2000, 2000

DATA_ROOT = "/mnt/data/ctknews/factcheck/par6"
DATA_CORPUS = Path(DATA_ROOT) / "interim" / "jsonl" / "ctk_filtered_all_cols.jsonl"
TRN_SIZE = 40000
DEV_SIZE = 4000
TST_SIZE = 4000

# DATA_ROOT = f"/mnt/data/factcheck/denikn/v1"
# DATA_CORPUS = Path(DATA_ROOT, "interim", "denikn_paragraphs.jsonl")
# TRN_SIZE, DEV_SIZE, TST_SIZE = 20000, 2000, 2000

# DATA_ROOT = f"/mnt/data/newton/parlamentni_listy/factcheck/v1"
# DATA_CORPUS = Path(DATA_ROOT, "interim", "plisty_paragraphs.jsonl")
# TRN_SIZE, DEV_SIZE, TST_SIZE = 20000, 2000, 2000

QACG_ROOT = Path(DATA_ROOT) / "qacg"
# WIKI_PREDICTIONS = f"{WIKI_ROOT}/predictions"

QG_DIR = "mt5-large_all-cp126k"
QACG_DIR = "mt5-large_all-cp156k"

# Sub-path shared by every artefact that depends on the NER / QG / claim-generation models.
_MODEL_SUBDIR = Path(NER_DIR, QG_DIR, QACG_DIR)

SPLIT_DIR = Path("splits") / _MODEL_SUBDIR
SPLIT_ROOT = QACG_ROOT / SPLIT_DIR

CLAIM_DIR = Path("claim") / _MODEL_SUBDIR
CLAIM_ROOT = QACG_ROOT / CLAIM_DIR


def _claim_files(split):
    """Map the label codes s/r/n to the claim files of one split (train/dev/test)."""
    return {
        label: CLAIM_ROOT / f"{split}_{kind}.json"
        for label, kind in (("s", "support"), ("r", "refute"), ("n", "nei"))
    }


# The "n" (NEI) files are JUST to generate splits for NLI, nothing else concerning ColBERT!
TRAIN_FILES = _claim_files("train")
DEV_FILES = _claim_files("dev")
TEST_FILES = _claim_files("test")

COLBERT_ROOT = Path(DATA_ROOT) / "colbertv2" / "qacg"
LINENO2ID = COLBERT_ROOT / "paragraphs_lineno2id.json"


QUERIES_ROOT = COLBERT_ROOT / "queries" / _MODEL_SUBDIR
TRIPLES_ROOT = COLBERT_ROOT / "triples" / _MODEL_SUBDIR

ANSERINI_ROOT = Path(DATA_ROOT) / "anserini"
ANSERINI_COLLECTION = str(ANSERINI_ROOT / "collection")
ANSERINI_INDEX = str(ANSERINI_ROOT / "index")
ANSERINI_RETRIEVED = ANSERINI_ROOT / "retrieved" / _MODEL_SUBDIR
SPLIT_ROOT

PosixPath('/mnt/data/ctknews/factcheck/par6/qacg/splits/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k')

In [18]:
def import_corpus(corpus_file):
    """Load the paragraph corpus from a JSONL file and normalize its string fields.

    The records are expected to already have the corpus format; this function only
    NFC-normalizes `id`, `did` and `text` and derives `did`/`bid` from `id` for
    records that do not carry a `did`.

    Args:
        corpus_file: path of the JSONL corpus.

    Returns:
        The list of records as returned by `read_jsonl`, modified in place.
    """
    raw = read_jsonl(corpus_file, show_progress=True)
    for e in raw:
        e["id"] = nfc(e["id"])
        if "did" not in e:
            # The id has the form "<did>_<bid>"; split on the LAST underscore so that
            # document ids which themselves contain "_" do not raise on unpacking.
            did, bid = e["id"].rsplit("_", 1)
            e["bid"] = bid
            e["did"] = did
        # did may be stored as a non-string value; it is always converted to a string
        e["did"] = nfc(str(e["did"]))
        e["text"] = nfc(e["text"])
    return raw

In [19]:
# Load the whole paragraph corpus into memory; the position of a record in this list
# is what the lineno2id mapping is built from later on.
corpus = import_corpus(DATA_CORPUS)
# Peek at a single record to sanity-check the schema.
corpus[1]

0.00it [00:00, ?it/s]

{'id': '20000818F01557_1',
 'did': '20000818F01557',
 'bid': 1,
 'date': '2000-08-18T14:24:00Z',
 'keywords': 'EU;Kolumbie;automoto;Volkswagen',
 'text': 'BRUSEL 18. srpna (ČTK) - Evropská komise (EK) jako protimonopolní orgán Evropské unie se začala zabývat stížností německé automobilky Volkswagen, která není spokojena s daňovým režimem v jihoamerické Kolumbii. Automobilka je přesvědčena, že Kolumbie svou daňovou politikou diskriminuje zahraniční dovozce malých automobilů.'}

In [20]:
def print_stats(corpus):
    """Print basic size statistics of `corpus`.

    Reports the number of distinct documents (by `did`), the number of paragraphs,
    their ratio, and min/max/mean/median of the paragraph text lengths in characters.
    """
    n_paragraphs = len(corpus)
    n_documents = len({rec["did"] for rec in corpus})
    per_document = n_paragraphs / n_documents
    print(f"documents: {n_documents} paragraphs: {n_paragraphs}, paragraphs per document: {per_document}")

    lengths = [len(rec["text"]) for rec in corpus]
    shortest, longest = np.min(lengths), np.max(lengths)
    average, middle = np.mean(lengths), np.median(lengths)
    print(f"paragraph len: min:{shortest}, max:{longest}, mean:{average}, median:{middle}")

# Report corpus size and paragraph-length statistics for the loaded corpus.
print_stats(corpus)

documents: 2151114 paragraphs: 11754759, paragraphs per document: 5.464498394785213
paragraph len: min:1, max:341451, mean:300.6763692900892, median:303.0


In [21]:
# Lookup from original paragraph ids to pids (exact semantics are defined by
# generate_original_id2pid_mapping in colbert.utilities.prepare_data).
original_id2pid = generate_original_id2pid_mapping(corpus)
# Position of each paragraph in `corpus` -> its original id.
lineno2id = dict(enumerate(rec["id"] for rec in corpus))

In [22]:
# RUN this just for the first time
# Persists the id -> pid mapping, the normalized corpus and the line-number -> id mapping
# under COLBERT_ROOT (LINENO2ID is a path below COLBERT_ROOT as well).
# NOTE(review): lineno2id has integer keys; JSON object keys are strings, so presumably
# they come back as strings when the file is read again - confirm against write_json/read_json.
write_json(Path(COLBERT_ROOT, "original_id2pid.json"), original_id2pid, mkdir=True)
write_jsonl(Path(COLBERT_ROOT, "collection.jsonl"), corpus, mkdir=True)
write_json(LINENO2ID, lineno2id, mkdir=True)

In [32]:
# Show where the line-number -> id mapping is stored.
# NOTE(review): the stored output of this cell lists the parlamentni_listy root although the
# configuration cell selects ctknews - the saved outputs apparently stem from a different run.
LINENO2ID

PosixPath('/mnt/data/newton/parlamentni_listy/factcheck/v1/colbertv2/qacg/paragraphs_lineno2id.json')

In [127]:
# Build the train/dev/test claim splits according to APPROACH.
SPLIT_SEED = 1234  # shared by the subsampling and by the shuffling below

if APPROACH == "full":
    trn_data = import_qacg_split(TRAIN_FILES)
    dev_data = import_qacg_split(DEV_FILES)
    tst_data = import_qacg_split(TEST_FILES)
elif APPROACH in ("balanced", "balanced_shuf"):
    trn_data = import_qacg_split_subsample(TRAIN_FILES, subsample=TRN_SIZE, seed=SPLIT_SEED)
    dev_data = import_qacg_split_subsample(DEV_FILES, subsample=DEV_SIZE, seed=SPLIT_SEED)
    tst_data = import_qacg_split_subsample(TEST_FILES, subsample=TST_SIZE, seed=SPLIT_SEED)
    if APPROACH == "balanced_shuf":
        # A single generator shuffles the splits in place in a fixed order
        # (train, dev, test); changing the order would change the result.
        rng = np.random.RandomState(SPLIT_SEED)
        rng.shuffle(trn_data)
        rng.shuffle(dev_data)
        rng.shuffle(tst_data)
elif APPROACH == "fever_size":
    print("Created in prepare_data_fever.ipynb")
    trn_data = read_jsonl(Path(SPLIT_ROOT, "train_fever_size.jsonl"))
    dev_data = read_jsonl(Path(SPLIT_ROOT, "dev_fever_size.jsonl"))
    tst_data = read_jsonl(Path(SPLIT_ROOT, "test_fever_size.jsonl"))
else:
    # Explicit exception instead of `assert False`, which is removed under `python -O`
    # and would let the notebook continue with undefined or stale splits.
    raise ValueError(f"unknown APPROACH: {APPROACH!r}")

reading: /mnt/data/newton/parlamentni_listy/factcheck/v1/qacg/claim/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k/train_support.json
reading: /mnt/data/newton/parlamentni_listy/factcheck/v1/qacg/claim/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k/train_refute.json
reading: /mnt/data/newton/parlamentni_listy/factcheck/v1/qacg/claim/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k/train_nei.json
reading: /mnt/data/newton/parlamentni_listy/factcheck/v1/qacg/claim/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k/dev_support.json
reading: /mnt/data/newton/parlamentni_listy/factcheck/v1/qacg/claim/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k/dev_refute.json
reading: /mnt/data/newton/parlamentni_listy/factcheck/v1/qacg/claim/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k/dev_nei.json
reading: /mnt/data/newton/parlamentni_listy/factcheck/v1/qacg/claim/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k/test_support.json
reading: /mnt/data/newt

In [128]:
def unique_claims(data):
    """Drop samples whose `claim` text already occurred earlier in `data`.

    The first sample of every distinct claim is kept and the input order is
    preserved. Prints the sizes before and after and returns the new list.
    """
    first_by_claim = {}
    for sample in data:
        # setdefault keeps the first sample seen for a claim; dicts preserve insertion order
        first_by_claim.setdefault(sample["claim"], sample)
    deduplicated = list(first_by_claim.values())
    print(f"original claims: {len(data)}, unique: {len(deduplicated)}")
    return deduplicated

# Deduplicated splits that still contain all labels (written out as the full split files).
trn_data_all = unique_claims(trn_data)
dev_data_all = unique_claims(dev_data)
tst_data_all = unique_claims(tst_data)

# Deduplicated splits without the 'n' (NEI) label; these are the ones used for queries and retrieval.
# NOTE(review): trn_data/dev_data/tst_data are rebound here, so re-running this cell without
# re-running the import cell makes the *_all splits lose their NEI claims as well.
trn_data = unique_claims([e for e in trn_data if e["label"] != 'n'])
dev_data = unique_claims([e for e in dev_data if e["label"] != 'n'])
tst_data = unique_claims([e for e in tst_data if e["label"] != 'n'])

original claims: 60000, unique: 57951
original claims: 6000, unique: 5904
original claims: 6000, unique: 5872
original claims: 40000, unique: 38639
original claims: 4000, unique: 3916
original claims: 4000, unique: 3889


In [129]:
# Persist the deduplicated splits: first the variants with all labels, then the NEI-free ones.
for _stem, _records in {
    f"train_{APPROACH}": trn_data_all,
    f"dev_{APPROACH}": dev_data_all,
    f"test_{APPROACH}": tst_data_all,
    f"train_{APPROACH}_no_nei": trn_data,
    f"dev_{APPROACH}_no_nei": dev_data,
    f"test_{APPROACH}_no_nei": tst_data,
}.items():
    write_jsonl(Path(SPLIT_ROOT, f"{_stem}.jsonl"), _records, mkdir=True)
SPLIT_ROOT

PosixPath('/mnt/data/newton/parlamentni_listy/factcheck/v1/qacg/splits/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k')

In [130]:
# Spot-check a few samples of the NEI-free training split.
trn_data[0], trn_data[1], trn_data[5]

({'claim': 'Pirátská strana se snaží dostat seniory na palubu.',
  'label': 'r',
  'evidence': ['H59A21CY0077_15']},
 {'claim': 'Cílem útoku je Ukrajina.',
  'label': 'r',
  'evidence': ['H59A22930039_2']},
 {'claim': 'Česká Lípa byla úklidem města.',
  'label': 's',
  'evidence': ['H59A22AH0102_2']})

In [131]:
def export_queries(data, out_file):
    """Write the claims of `data` to `out_file` as JSONL query records.

    Each output record has the single key `query` holding the sample's `claim`;
    missing parent directories are created by `write_jsonl`.
    """
    as_queries = [{"query": sample["claim"]} for sample in tqdm(data)]
    write_jsonl(out_file, as_queries, mkdir=True)

# Export the claims of the NEI-free splits as query files under QUERIES_ROOT;
# the trailing expression displays the target directory.
export_queries(trn_data, Path(QUERIES_ROOT, f"train_qacg_queries_{APPROACH}.jsonl"))
export_queries(dev_data, Path(QUERIES_ROOT, f"dev_qacg_queries_{APPROACH}.jsonl"))
export_queries(tst_data, Path(QUERIES_ROOT, f"test_qacg_queries_{APPROACH}.jsonl"))
QUERIES_ROOT

100%|██████████| 38639/38639 [00:00<00:00, 1230533.42it/s]
100%|██████████| 3916/3916 [00:00<00:00, 1488166.57it/s]
100%|██████████| 3889/3889 [00:00<00:00, 1669223.11it/s]


PosixPath('/mnt/data/newton/parlamentni_listy/factcheck/v1/colbertv2/qacg/queries/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k')

## Anserini Hard Negatives

We will use Anserini in the first stage to get hard negatives. 

In [73]:
# Write the corpus to ANSERINI_COLLECTION, which is the -input directory of the indexing command in the next cell.
export_as_anserini_collection(corpus, ANSERINI_COLLECTION)

In [74]:
# Build the Lucene index over the exported JSON collection with Pyserini; the braces are
# filled in by IPython from the notebook variables of the configuration cell.
!python -m pyserini.index.lucene \
    -collection JsonCollection \
    -generator DefaultLuceneDocumentGenerator \
    -threads 4 \
    -input {ANSERINI_COLLECTION} \
    -language {ANSERINI_LANG} \
    -index {ANSERINI_INDEX} \
    -storePositions -storeDocvectors -storeRaw

2023-11-26 15:36:21,795 INFO  [main] index.IndexCollection (IndexCollection.java:380) - Setting log level to INFO
2023-11-26 15:36:21,797 INFO  [main] index.IndexCollection (IndexCollection.java:383) - Starting indexer...
2023-11-26 15:36:21,797 INFO  [main] index.IndexCollection (IndexCollection.java:385) - DocumentCollection path: /mnt/data/ctknews/factcheck/par6/anserini/collection
2023-11-26 15:36:21,798 INFO  [main] index.IndexCollection (IndexCollection.java:386) - CollectionClass: JsonCollection
2023-11-26 15:36:21,798 INFO  [main] index.IndexCollection (IndexCollection.java:387) - Generator: DefaultLuceneDocumentGenerator
2023-11-26 15:36:21,798 INFO  [main] index.IndexCollection (IndexCollection.java:388) - Threads: 4
2023-11-26 15:36:21,798 INFO  [main] index.IndexCollection (IndexCollection.java:389) - Language: cs
2023-11-26 15:36:21,799 INFO  [main] index.IndexCollection (IndexCollection.java:390) - Stemmer: porter
2023-11-26 15:36:21,799 INFO  [main] index.IndexCollection

In [132]:
# Directory that receives the first-stage retrieval results written in the next cell.
ANSERINI_RETRIEVED

PosixPath('/mnt/data/newton/parlamentni_listy/factcheck/v1/anserini/retrieved/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k')

In [133]:
# First-stage retrieval for every split (dev, test, train - in this order); the split that
# was passed to anserini_retrieve_claims is serialized right after the call.
for _split_name, _split_data in (("dev", dev_data), ("test", tst_data), ("train", trn_data)):
    anserini_retrieve_claims(ANSERINI_INDEX, _split_data, 128)
    write_jsonl(Path(ANSERINI_RETRIEVED, f"{_split_name}_{APPROACH}.jsonl"), _split_data, mkdir=True)

100%|██████████| 3916/3916 [00:37<00:00, 105.19it/s]
100%|██████████| 3889/3889 [00:36<00:00, 105.96it/s]
100%|██████████| 38639/38639 [06:01<00:00, 107.01it/s]


In [134]:
# Reload the splits from the files written by the retrieval cell, so that the following
# cells work with exactly what was serialized.
trn_data = read_jsonl(Path(ANSERINI_RETRIEVED, f"train_{APPROACH}.jsonl"))
dev_data = read_jsonl(Path(ANSERINI_RETRIEVED, f"dev_{APPROACH}.jsonl"))
tst_data = read_jsonl(Path(ANSERINI_RETRIEVED, f"test_{APPROACH}.jsonl"))

In [135]:
# Display the directory the splits were just reloaded from.
ANSERINI_RETRIEVED

PosixPath('/mnt/data/newton/parlamentni_listy/factcheck/v1/anserini/retrieved/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k')

## Triplet Generation

In [136]:
# Entry point of the triplet-generation stage: load the splits together with their
# first-stage retrieval results from the files under ANSERINI_RETRIEVED.
trn_data = read_jsonl(Path(ANSERINI_RETRIEVED, f"train_{APPROACH}.jsonl"))
dev_data = read_jsonl(Path(ANSERINI_RETRIEVED, f"dev_{APPROACH}.jsonl"))
tst_data = read_jsonl(Path(ANSERINI_RETRIEVED, f"test_{APPROACH}.jsonl"))
# Disabled alternative inputs (presumably Anserini results re-ranked by a MiniLM cross-encoder - TODO confirm):
# trn_data = read_jsonl(Path(ANSERINI_RETRIEVED, "train_anserini+minilm.jsonl"))
# dev_data = read_jsonl(Path(ANSERINI_RETRIEVED, "dev_anserini+minilm.jsonl"))

In [137]:
def show_retrieval(sample, corpus, search=None, k=3, id2pid=None):
    """Print a claim, its first evidence paragraph and the top retrieved paragraphs.

    Args:
        sample: dict with the keys `claim`, `evidence` (list of paragraph ids) and
            `retrieved` (list of retrieved paragraph ids).
        corpus: list of paragraph records (dicts with a `text` key) indexed by pid.
        search: optional substring; when truthy, each printed paragraph is tagged with
            FOUND/NOT FOUND depending on whether it contains the substring.
        k: maximum number of retrieved paragraphs to print.
        id2pid: mapping from paragraph id to the index into `corpus`; defaults to the
            notebook-level `original_id2pid`.
    """
    if id2pid is None:
        # backward-compatible fallback to the mapping defined at notebook level
        id2pid = original_id2pid

    def _found_tag(text):
        """Return the FOUND/NOT FOUND tag for `text`, or an empty string without `search`."""
        if not search:
            return ""
        return f"FOUND '{search}'" if search in text else f"NOT FOUND '{search}'"

    claim = sample["claim"]
    bid = sample["evidence"][0]
    evidence = corpus[id2pid[bid]]["text"]
    print(f"CLAIM: {claim}")
    print()
    print(f"EVIDENCE ({bid}) {_found_tag(evidence)}:\n" + textwrap.fill(evidence))

    # Slicing instead of indexing 0..k-1 tolerates samples with fewer than k retrieved
    # paragraphs; the k > 0 guard keeps k <= 0 printing nothing (a negative slice bound
    # would otherwise select elements).
    top_retrieved = sample["retrieved"][:k] if k > 0 else []
    for rank, bid in enumerate(top_retrieved, start=1):
        ret = corpus[id2pid[bid]]["text"]
        print(f"\nRETRIEVED {rank} ({bid}) {_found_tag(ret)}:\n" + textwrap.fill(ret))



# Inspect one test sample; `search` is a substring that is looked up in the evidence and in
# each of the k retrieved paragraphs.
show_retrieval(tst_data[11], corpus, search="Niles", k=3)

CLAIM: Hnutí ANO bylo historicky prvním hnutí, které vyhrálo.

EVIDENCE (H59A21DS0046_4) NOT FOUND 'Niles':
Ještě v dubnu mohli gratulace za historicky první pokoření hnutí ANO
přijímat Piráti v koalici se Starosty a Nezávislými (STAN). Ale jak se
ukázalo i zde, první vyhrání leckdy z kapsy vyhání. PirSTAN vyhnalo z
prvního místa, křivka úspěšnosti této koalice se vydala na trajektorii
postupného, leč vytrvalého pádu, v jehož průběhu ji dokázalo přeskočit
konkurenční koaliční uskupení SPOLU (ODS+TOP 09+KDU-ČSL) a staronový
hegemon žebříčku hnutí ANO.

RETRIEVED 1 (H59A21GE0032_11) NOT FOUND 'Niles':
To je sice pravda, ale když se na to podívám celkově, tak vyhrálo
hnutí ANO, vždyť kolik hlasů dostaly ty strany a straničky v
jednotlivých koalicích? Kolik dohromady dostaly dvě historicky
největší a nejúspěšnější strany ODS a ČSSD, které dřív měly přece
tolik hlasů. Ony hrály, řečeno hokejovou hantýrkou, NHL a pak najednou
padaly a padaly o soutěže dolů.

RETRIEVED 2 (H59A22F50063_9) NOT 

### Triples by Retrieval

In [138]:
nway = 128


def _triples_file(prefix):
    """Target path of the n-way triples of one split (`prefix` is trn, dev or tst)."""
    return Path(TRIPLES_ROOT, f"{prefix}_triples_nway{nway}_evidence+anserini_{APPROACH}.jsonl")


# Generate and store the n-way triples split by split (train, dev, test).
trn_triples = generate_triples_by_retrieval_nway(trn_data, corpus, original_id2pid, nway=nway, use_evidence=True)
write_jsonl(_triples_file("trn"), trn_triples, mkdir=True)

dev_triples = generate_triples_by_retrieval_nway(dev_data, corpus, original_id2pid, nway=nway, use_evidence=True)
write_jsonl(_triples_file("dev"), dev_triples, mkdir=True)

tst_triples = generate_triples_by_retrieval_nway(tst_data, corpus, original_id2pid, nway=nway, use_evidence=True)
write_jsonl(_triples_file("tst"), tst_triples, mkdir=True)

  5%|▍         | 1822/38639 [00:00<00:11, 3293.73it/s]



  9%|▉         | 3581/38639 [00:01<00:09, 3512.89it/s]



 13%|█▎        | 5198/38639 [00:01<00:09, 3657.27it/s]



100%|██████████| 38639/38639 [00:08<00:00, 4309.05it/s]


generated 38639 triples with 0 failures and 24 random fixes


 65%|██████▍   | 2539/3916 [00:00<00:00, 4609.43it/s]



100%|██████████| 3916/3916 [00:00<00:00, 4660.01it/s]

generated 3916 triples with 0 failures and 2 random fixes



  0%|          | 6/3889 [00:00<02:11, 29.60it/s]



 32%|███▏      | 1253/3889 [00:00<00:01, 1869.04it/s]



100%|██████████| 3889/3889 [00:01<00:00, 2088.67it/s]

generated 3889 triples with 0 failures and 7 random fixes





In [139]:
# Display the directory the triple files were written to.
TRIPLES_ROOT

PosixPath('/mnt/data/newton/parlamentni_listy/factcheck/v1/colbertv2/qacg/triples/PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k')

# Combine Sources

In [157]:
def combine_queries_triples_collection(
        split_files, # already created combined split
        triple_files, # target files for triples (keep same order of splits!)
        query_files, # target files for queries (keep same order of splits!)
        collection_dir, # combined collection dir
        source2triple_files, # triple files for each particular source
        source2collection_files # collection file for each particular source
        ):
    """Merge per-source triples and collections into one combined data set.

    For every combined split the claims are exported as queries, and the triple of
    each sample is looked up in its source's triple file (via the sample's `source`
    and `orig_idx`). The pids of that triple are translated into pids of a new,
    combined collection that contains only the passages actually referenced.

    Args:
        split_files: JSONL files of the combined splits; each sample needs the keys
            `claim`, `source` and `orig_idx`.
        triple_files: output files of the remapped triples, one per split, same order.
        query_files: output files of the queries, one per split, same order.
        collection_dir: directory that receives the combined `collection.jsonl`.
        source2triple_files: source -> (directory, list of triple file names); the
            file names must follow the order of `split_files`.
        source2collection_files: source -> collection JSONL file; a pid is the
            position of a record in that file.

    Raises:
        ValueError: if `split_files`, `query_files` and `triple_files` differ in length.
    """
    split_files, query_files, triple_files = list(split_files), list(query_files), list(triple_files)
    # Fail fast: zip() below would silently ignore unmatched splits, and a mismatch with
    # triple_files would otherwise surface only after all loading, remapping and query writing.
    if not (len(split_files) == len(query_files) == len(triple_files)):
        raise ValueError(
            "split_files, query_files and triple_files must have the same length, got "
            f"{len(split_files)}, {len(query_files)} and {len(triple_files)}"
        )

    # import splits and generate queries (just claims)
    splits = []
    for split_file, query_file in zip(split_files, query_files):
        data = read_jsonl(split_file)
        splits.append(data)
        recs = [{"query": r["claim"]} for r in data]
        print(f'writing query file: "{query_file}"')
        write_jsonl(query_file, recs, mkdir=True)

    print("loading collections")
    source2collection = {
        source: read_jsonl(collection_file, show_progress=True)
        for source, collection_file in source2collection_files.items()
    }

    print("loading triples")
    source2triples = {
        source: [read_jsonl(Path(triple_file_lst[0], triple_file_name)) for triple_file_name in triple_file_lst[1]]
        for source, triple_file_lst in source2triple_files.items()
    }

    new_collection = []
    used_collection_pids = defaultdict(dict) # source -> original pid -> new pid
    new_triples = [[] for _ in range(len(splits))] # one list of triples per split

    for split_idx, split in enumerate(splits):
        for s in tqdm(split, desc=f"split idx: {split_idx}"):
            source = s["source"]
            # index of the sample within the split of its original source
            orig_idx = int(s["orig_idx"])
            trp = source2triples[source][split_idx][orig_idx]
            # the first element of a triple is its id: regenerate it as an ever increasing counter
            new_trp = [len(new_triples[split_idx])]
            # the remaining elements are pids: translate them to pids of the new collection
            pid_map = used_collection_pids[source]
            for orig_pid in trp[1:]:
                if orig_pid not in pid_map:
                    # the new pid is the position of the passage in `new_collection`
                    pid_map[orig_pid] = len(new_collection)
                    new_collection.append(source2collection[source][orig_pid])
                new_trp.append(pid_map[orig_pid])
            new_triples[split_idx].append(new_trp)

    for nt, triple_file in zip(new_triples, triple_files):
        print(f'writing triple file: "{triple_file}"')
        write_jsonl(triple_file, nt, mkdir=True)

    collection_file = Path(collection_dir, "collection.jsonl")
    print(f'writing collection file: "{collection_file}"')
    write_jsonl(collection_file, new_collection, mkdir=True, show_progress=True)

    # Most likely makes no sense as the bids from different Wikipedia may overlap 
    # original_id2pid = {r["id"]: i for i, r in enumerate(new_collection)}
    # original_id2pid_file = Path(collection_dir, "original_id2pid.json")
    # print(f'writing original_id2pid file: "{original_id2pid_file}"')
    # write_json(original_id2pid_file, original_id2pid, mkdir=True)

In [160]:
# Configuration of the combined ("news_sum") data set.
# NOTE: APPROACH, QACG_ROOT and COLBERT_ROOT of the per-corpus configuration cell are rebound here.
APPROACH = "balanced_shuf"
QACG_ROOT = "/mnt/data/factcheck/qacg/news_sum/qacg"
COLBERT_ROOT = "/mnt/data/factcheck/qacg/news_sum/colbertv2/qacg"
COLBERT_ROOT_CRO = "/mnt/data/cro/factcheck/v1/colbertv2/qacg"
COLBERT_ROOT_CTK = "/mnt/data/ctknews/factcheck/par6/colbertv2/qacg"
COLBERT_ROOT_DENIKN = "/mnt/data/factcheck/denikn/v1/colbertv2/qacg"
COLBERT_ROOT_PLISTY = "/mnt/data/newton/parlamentni_listy/factcheck/v1/colbertv2/qacg"
MODELS_CS = "PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k"

# Split order is dev, test, train everywhere below; the triple files use the dev/tst/trn prefixes.
TRIPLE_SPLITS = [
    f"{prefix}_triples_nway128_evidence+anserini_{APPROACH}.jsonl"
    for prefix in ("dev", "tst", "trn")
]

# source name -> ColBERT root of that source
_source_roots = {
    "cro": COLBERT_ROOT_CRO,
    "ctk": COLBERT_ROOT_CTK,
    "denikn": COLBERT_ROOT_DENIKN,
    "plisty": COLBERT_ROOT_PLISTY,
}

st = combine_queries_triples_collection(
    split_files=[
        f"{QACG_ROOT}/splits/{name}_{APPROACH}_no_nei.jsonl"
        for name in ("dev", "test", "train")
    ],
    query_files=[
        f"{COLBERT_ROOT}/queries/{name}_qacg_queries_{APPROACH}.jsonl"
        for name in ("dev", "test", "train")
    ],
    triple_files=[f"{COLBERT_ROOT}/triples/{ts}" for ts in TRIPLE_SPLITS],
    collection_dir=COLBERT_ROOT,
    source2triple_files={
        source: (f"{root}/triples/{MODELS_CS}", TRIPLE_SPLITS)
        for source, root in _source_roots.items()
    },
    source2collection_files={
        source: f"{root}/collection.jsonl"
        for source, root in _source_roots.items()
    },
)

writing query file: "/mnt/data/factcheck/qacg/news_sum/colbertv2/qacg/queries/dev_qacg_queries_balanced_shuf.jsonl"
writing query file: "/mnt/data/factcheck/qacg/news_sum/colbertv2/qacg/queries/test_qacg_queries_balanced_shuf.jsonl"
writing query file: "/mnt/data/factcheck/qacg/news_sum/colbertv2/qacg/queries/train_qacg_queries_balanced_shuf.jsonl"
loading collections


0.00it [00:00, ?it/s]

0.00it [00:00, ?it/s]

0.00it [00:00, ?it/s]

0.00it [00:00, ?it/s]

loading triples


split idx: 0: 100%|██████████| 19644/19644 [00:01<00:00, 11312.34it/s]
split idx: 1: 100%|██████████| 19657/19657 [00:01<00:00, 11925.49it/s]
split idx: 2: 100%|██████████| 194972/194972 [00:14<00:00, 13077.11it/s]


writing triple file: "/mnt/data/factcheck/qacg/news_sum/colbertv2/qacg/triples/dev_triples_nway128_evidence+anserini_balanced_shuf.jsonl"
writing triple file: "/mnt/data/factcheck/qacg/news_sum/colbertv2/qacg/triples/tst_triples_nway128_evidence+anserini_balanced_shuf.jsonl"
writing triple file: "/mnt/data/factcheck/qacg/news_sum/colbertv2/qacg/triples/trn_triples_nway128_evidence+anserini_balanced_shuf.jsonl"
writing collection file: "/mnt/data/factcheck/qacg/news_sum/colbertv2/qacg/collection.jsonl"


  0%|          | 0/7677707 [00:00<?, ?it/s]

In [161]:
# Display the QACG root of the combined data set (rebound to a plain string in the previous cell).
QACG_ROOT

'/mnt/data/factcheck/qacg/news_sum/qacg'