In [1]:
import numpy as np
import sqlite3
import torch
from tqdm import tqdm
import unicodedata

from collections import defaultdict, OrderedDict, Counter
from dataclasses import dataclass
import datetime as dt
from itertools import chain
import os
import pathlib
from pathlib import Path
import string
import pandas as pd
import unicodedata as ud
from time import time
from typing import Dict, Type, Callable, List, Union
import sys
import ujson

from aic_nlp_utils.json import read_jsonl, read_json, write_json, write_jsonl
from aic_nlp_utils.encoding import nfc
from aic_nlp_utils.fever import fever_detokenize
from sentence_transformers import CrossEncoder, util
import textwrap

sys.path.insert(0, '/home/drchajan/devel/python/FC/ColBERTv2') # ignore other ColBERT installations

from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Indexer, Searcher
from colbert.data import Queries, Collection
from colbert import Trainer
from colbert.utilities.prepare_data import import_qacg_split, import_qacg_split_subsample, generate_original_id2pid_mapping, export_as_anserini_collection, anserini_retrieve_claims, sbert_CE_rerank, generate_triples_by_retrieval, generate_triples_by_retrieval_nway

%load_ext autoreload
%autoreload 2

  from tqdm.autonotebook import tqdm


In [2]:
# APPROACH = "full" # all generated data
# APPROACH = "balanced" # balanced classes
APPROACH = "balanced_shuf" # balanced classes, shuffled
# APPROACH = "fever_size" # subsampled to have exact fever distribution

# Anserini fails for Polish so use the default English config
# LANG, NER_DIR, ANSERINI_LANG = "cs", "PAV-ner-CNEC", "cs"
LANG, NER_DIR, ANSERINI_LANG = "en", "stanza", "en"
# LANG, NER_DIR, ANSERINI_LANG = "sk", "crabz_slovakbert-ner", "sk"
# LANG, NER_DIR, ANSERINI_LANG = "pl", "stanza", "en"

# DATE = "20230220"
# DATE = "20230801"
DATE = "20240201" # CEDMO

WIKI_ROOT = f"/mnt/data/factcheck/wiki/{LANG}/{DATE}"
WIKI_CORPUS = f"{WIKI_ROOT}/paragraphs/{LANG}wiki-{DATE}-paragraphs.jsonl"
WIKI_LINENO2ID = f"{WIKI_ROOT}/paragraphs/{LANG}wiki-{DATE}-paragraphs_lineno2id.json"
# WIKI_PREDICTIONS = f"{WIKI_ROOT}/predictions"

QACG_ROOT = f"{WIKI_ROOT}/qacg"

QG_DIR = "mt5-large_all-cp126k"
QACG_DIR = "mt5-large_all-cp156k"

SPLIT_DIR = Path("splits", NER_DIR, QG_DIR, QACG_DIR)
SPLIT_ROOT = Path(QACG_ROOT, SPLIT_DIR)

CLAIM_DIR = Path("claim", NER_DIR, QG_DIR, QACG_DIR)
CLAIM_ROOT = Path(QACG_ROOT, CLAIM_DIR)


TRAIN_FILES = {
    "s": Path(CLAIM_ROOT, "train_support.json"), 
    "r": Path(CLAIM_ROOT, "train_refute.json"),
    # "n": Path(CLAIM_ROOT, "train_nei.json")
    }

DEV_FILES = {
    "s": Path(CLAIM_ROOT, "dev_support.json"), 
    "r": Path(CLAIM_ROOT, "dev_refute.json"),
    # "n": Path(CLAIM_ROOT, "dev_nei.json")
    }

TEST_FILES = {
    "s": Path(CLAIM_ROOT, "test_support.json"), 
    "r": Path(CLAIM_ROOT, "test_refute.json"),
    # "n": Path(CLAIM_ROOT, "test_nei.json")
    }

COLBERT_ROOT = f"/mnt/data/factcheck/wiki/{LANG}/{DATE}/colbertv2/qacg"

QUERIES_ROOT = Path(COLBERT_ROOT, "queries", NER_DIR, QG_DIR, QACG_DIR)
TRIPLES_ROOT = Path(COLBERT_ROOT, "triples", NER_DIR, QG_DIR, QACG_DIR)

ANSERINI_ROOT = Path(WIKI_ROOT, "anserini")
ANSERINI_COLLECTION = str(Path(ANSERINI_ROOT, "collection"))
ANSERINI_INDEX = str(Path(ANSERINI_ROOT, "index"))
ANSERINI_RETRIEVED = Path(ANSERINI_ROOT, "retrieved", NER_DIR, QG_DIR, QACG_DIR)
SPLIT_ROOT

PosixPath('/mnt/data/factcheck/wiki/en/20240201/qacg/splits/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k')

In [3]:
def import_corpus(corpus_file):
    # it already has correct format
    raw = read_jsonl(corpus_file, show_progress=True)
    for e in raw:
        e["id"] = nfc(e["id"])
        e["did"] = nfc(e["did"])
        e["text"] = nfc(e["text"])
    return raw

In [4]:
corpus = import_corpus(WIKI_CORPUS)
corpus[1]

0.00it [00:00, ?it/s]

{'id': 'Anarchism_2',
 'did': 'Anarchism',
 'bid': 2,
 'text': 'Anarchism\n\nAnarchists employ diverse approaches, which may be generally divided into revolutionary and evolutionary strategies; there is significant overlap between the two. Evolutionary methods try to simulate what an anarchist society might be like, but revolutionary tactics, which have historically taken a violent turn, aim to overthrow authority and the state. Many facets of human civilization have been influenced by anarchist theory, critique, and praxis.\n\nAnarchism\n\nEtymology, terminology, and definition.\n\nThe etymological origin of "anarchism" is from the Ancient Greek "anarkhia", meaning "without a ruler", composed of the prefix "an-" ("without") and the word "arkhos" ("leader" or "ruler"). The suffix "-ism" denotes the ideological current that favours anarchy. "Anarchism" appears in English from 1642 as "anarchisme" and "anarchy" from 1539; early English usages emphasised a sense of disorder. Various facti

In [5]:
def print_stats(corpus):
    did_set = set([r["did"] for r in corpus])
    ndoc = len(did_set)
    npar = len(corpus)
    print(f"documents: {ndoc} paragraphs: {npar}, paragraphs per document: {npar/ndoc}")
    plens = [len(r["text"]) for r in corpus]
    print(f"paragraph len: min:{np.min(plens)}, max:{np.max(plens)}, mean:{np.mean(plens)}, median:{np.median(plens)}")

print_stats(corpus)

documents: 6206719 paragraphs: 15689008, paragraphs per document: 2.527745818684558
paragraph len: min:100, max:116665, mean:1084.785491982667, median:1142.0


In [6]:
original_id2pid = generate_original_id2pid_mapping(corpus)
lineno2id = {i: r["id"] for i, r in enumerate(corpus)}

In [7]:
# RUN this just for the first time
write_json(Path(COLBERT_ROOT, "original_id2pid.json"), original_id2pid, mkdir=True)
write_jsonl(Path(COLBERT_ROOT, "collection.jsonl"), corpus, mkdir=True)
write_json(WIKI_LINENO2ID, lineno2id, mkdir=True)

In [7]:
if APPROACH == "full":
    trn_data = import_qacg_split(TRAIN_FILES)
    dev_data = import_qacg_split(DEV_FILES)
    tst_data = import_qacg_split(TEST_FILES)
elif APPROACH in ["balanced", "balanced_shuf"]:
    trn_data = import_qacg_split_subsample(TRAIN_FILES, subsample=98403, seed=1234)
    dev_data = import_qacg_split_subsample(DEV_FILES, subsample=10029, seed=1234)
    tst_data = import_qacg_split_subsample(TEST_FILES, subsample=9480, seed=1234)
    if APPROACH == "balanced_shuf":
        rng = np.random.RandomState(1234)
        rng.shuffle(trn_data)
        rng.shuffle(dev_data)
        rng.shuffle(tst_data)
elif APPROACH == "fever_size":
    print("Created in prepare_data_fever.ipynb")
    trn_data = read_jsonl(Path(SPLIT_ROOT, "train_fever_size.jsonl"))
    dev_data = read_jsonl(Path(SPLIT_ROOT, "dev_fever_size.jsonl"))
    tst_data = read_jsonl(Path(SPLIT_ROOT, "test_fever_size.jsonl"))
else:
    assert False, APPROACH

reading: /mnt/data/factcheck/wiki/pl/20230801/qacg/claim/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k/train_support.json
WARN>> claim not NFC, fixing...
WARN>> claim not NFC, fixing...
WARN>> claim not NFC, fixing...
WARN>> claim not NFC, fixing...
reading: /mnt/data/factcheck/wiki/pl/20230801/qacg/claim/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k/train_refute.json
reading: /mnt/data/factcheck/wiki/pl/20230801/qacg/claim/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k/dev_support.json
reading: /mnt/data/factcheck/wiki/pl/20230801/qacg/claim/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k/dev_refute.json
reading: /mnt/data/factcheck/wiki/pl/20230801/qacg/claim/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k/test_support.json
reading: /mnt/data/factcheck/wiki/pl/20230801/qacg/claim/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k/test_refute.json


In [8]:
def unique_claims(data):
    claims = set()
    new_data = []
    for e in data:
        if e["claim"] not in claims:
            new_data.append(e)
            claims.add(e["claim"])
    print(f"original claims: {len(data)}, unique: {len(new_data)}")
    return new_data

trn_data = unique_claims(trn_data)
dev_data = unique_claims(dev_data)
tst_data = unique_claims(tst_data)

original claims: 196806, unique: 183204
original claims: 20058, unique: 18685
original claims: 18960, unique: 17731


In [9]:
write_jsonl(Path(SPLIT_ROOT, f"train_{APPROACH}.jsonl"), trn_data, mkdir=True)
write_jsonl(Path(SPLIT_ROOT, f"dev_{APPROACH}.jsonl"), dev_data, mkdir=True)
write_jsonl(Path(SPLIT_ROOT, f"test_{APPROACH}.jsonl"), tst_data, mkdir=True)
SPLIT_ROOT

PosixPath('/mnt/data/factcheck/wiki/pl/20230801/qacg/splits/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k')

In [10]:
len(trn_data)

183204

In [11]:
trn_data[0], trn_data[1]

({'claim': 'Iwan Własow został zamordowany w obozie dla jeńców wojennych.',
  'label': 's',
  'evidence': ['Iwan_Własow_(dyplomata)_1']},
 {'claim': 'W drugiej rundzie US Open wygrała z Eugenie Bouchard.',
  'label': 's',
  'evidence': ['Angelique_Kerber_5']})

In [12]:
def export_queries(data, out_file):
    queries = []
    for r in tqdm(data):
        queries.append({"query": r["claim"]})
    write_jsonl(out_file, queries, mkdir=True)

export_queries(trn_data, Path(QUERIES_ROOT, f"train_qacg_queries_{APPROACH}.jsonl"))
export_queries(dev_data, Path(QUERIES_ROOT, f"dev_qacg_queries_{APPROACH}.jsonl"))
export_queries(tst_data, Path(QUERIES_ROOT, f"test_qacg_queries_{APPROACH}.jsonl"))
QUERIES_ROOT

100%|██████████| 183204/183204 [00:00<00:00, 1511458.23it/s]
100%|██████████| 18685/18685 [00:00<00:00, 1626451.60it/s]
100%|██████████| 17731/17731 [00:00<00:00, 1766993.07it/s]


PosixPath('/mnt/data/factcheck/wiki/pl/20230801/colbertv2/qacg/queries/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k')

## Anserini Hard Negatives

We will use Anserini in the first stage to get hard negatives. 

In [83]:
export_as_anserini_collection(corpus, ANSERINI_COLLECTION)

In [84]:
!python -m pyserini.index.lucene \
    -collection JsonCollection \
    -generator DefaultLuceneDocumentGenerator \
    -threads 4 \
    -input {ANSERINI_COLLECTION} \
    -language {ANSERINI_LANG} \
    -index {ANSERINI_INDEX} \
    -storePositions -storeDocvectors -storeRaw

2023-09-13 19:16:56,787 INFO  [main] index.IndexCollection (IndexCollection.java:380) - Setting log level to INFO
2023-09-13 19:16:56,788 INFO  [main] index.IndexCollection (IndexCollection.java:383) - Starting indexer...
2023-09-13 19:16:56,788 INFO  [main] index.IndexCollection (IndexCollection.java:385) - DocumentCollection path: /mnt/data/factcheck/wiki/pl/20230801/anserini/collection
2023-09-13 19:16:56,789 INFO  [main] index.IndexCollection (IndexCollection.java:386) - CollectionClass: JsonCollection
2023-09-13 19:16:56,789 INFO  [main] index.IndexCollection (IndexCollection.java:387) - Generator: DefaultLuceneDocumentGenerator
2023-09-13 19:16:56,789 INFO  [main] index.IndexCollection (IndexCollection.java:388) - Threads: 4
2023-09-13 19:16:56,789 INFO  [main] index.IndexCollection (IndexCollection.java:389) - Language: en
2023-09-13 19:16:56,789 INFO  [main] index.IndexCollection (IndexCollection.java:390) - Stemmer: porter
2023-09-13 19:16:56,789 INFO  [main] index.IndexCollec

In [13]:
ANSERINI_RETRIEVED

PosixPath('/mnt/data/factcheck/wiki/pl/20230801/anserini/retrieved/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k')

In [14]:
anserini_retrieve_claims(ANSERINI_INDEX, dev_data, 128)
write_jsonl(Path(ANSERINI_RETRIEVED, f"dev_{APPROACH}.jsonl"), dev_data, mkdir=True)

anserini_retrieve_claims(ANSERINI_INDEX, tst_data, 128)
write_jsonl(Path(ANSERINI_RETRIEVED, f"test_{APPROACH}.jsonl"), tst_data, mkdir=True)

anserini_retrieve_claims(ANSERINI_INDEX, trn_data, 128)
write_jsonl(Path(ANSERINI_RETRIEVED, f"train_{APPROACH}.jsonl"), trn_data, mkdir=True)

100%|██████████| 28440/28440 [05:58<00:00, 79.38it/s]
100%|██████████| 295209/295209 [1:01:02<00:00, 80.60it/s]


In [None]:
trn_data = read_jsonl(Path(ANSERINI_RETRIEVED, f"train_{APPROACH}.jsonl"))
dev_data = read_jsonl(Path(ANSERINI_RETRIEVED, f"dev_{APPROACH}.jsonl"))
tst_data = read_jsonl(Path(ANSERINI_RETRIEVED, f"test_{APPROACH}.jsonl"))

In [13]:
def match_retrieval(data, other_jsonl):
    odata = read_jsonl(other_jsonl)
    claim2retrieved = {e["claim"]: e["retrieved"] for e in odata}
    print(f"other jsonl claims: {len(odata)} unique claims: {len(claim2retrieved)}")
    for e in data:
        e["retrieved"] = claim2retrieved[e["claim"]]

match_retrieval(dev_data, Path(ANSERINI_RETRIEVED, f"dev_balanced.jsonl"))
match_retrieval(tst_data, Path(ANSERINI_RETRIEVED, f"test_balanced.jsonl"))
match_retrieval(trn_data, Path(ANSERINI_RETRIEVED, f"train_balanced.jsonl"))

write_jsonl(Path(ANSERINI_RETRIEVED, f"dev_{APPROACH}.jsonl"), dev_data, mkdir=True)
write_jsonl(Path(ANSERINI_RETRIEVED, f"test_{APPROACH}.jsonl"), tst_data, mkdir=True)
write_jsonl(Path(ANSERINI_RETRIEVED, f"train_{APPROACH}.jsonl"), trn_data, mkdir=True)

other jsonl claims: 30087 unique claims: 28476
other jsonl claims: 28440 unique claims: 26991
other jsonl claims: 295209 unique claims: 278746


In [15]:
ANSERINI_RETRIEVED

PosixPath('/mnt/data/factcheck/wiki/pl/20230801/anserini/retrieved/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k')

## SBERT Reranked Anserini

Did not work well enough. Skip for now...

In [None]:
# import splits with hard negatives retrieved by Anserini
trn_data_anserini = read_jsonl(Path(ANSERINI_RETRIEVED, "train.jsonl"), show_progress=True)
dev_data_anserini = read_jsonl(Path(ANSERINI_RETRIEVED, "dev.jsonl"), show_progress=True)

231kit [00:09, 25.6kit/s] 
27.9kit [00:00, 36.8kit/s]


In [37]:
sbert_CE_rerank(dev_data_anserini, corpus)
write_jsonl(Path(ANSERINI_RETRIEVED, "dev_anserini+minilm.jsonl"), dev_data_anserini)

100%|██████████| 27893/27893 [1:04:37<00:00,  7.19it/s]


In [38]:
sbert_CE_rerank(trn_data_anserini, corpus)
write_jsonl(Path(ANSERINI_RETRIEVED, "train_anserini+minilm.jsonl"), trn_data_anserini)

100%|██████████| 231296/231296 [9:09:20<00:00,  7.02it/s]  


## Triplet Generation

In [14]:
trn_data = read_jsonl(Path(ANSERINI_RETRIEVED, f"train_{APPROACH}.jsonl"))
dev_data = read_jsonl(Path(ANSERINI_RETRIEVED, f"dev_{APPROACH}.jsonl"))
tst_data = read_jsonl(Path(ANSERINI_RETRIEVED, f"test_{APPROACH}.jsonl"))
# trn_data = read_jsonl(Path(ANSERINI_RETRIEVED, "train_anserini+minilm.jsonl"))
# dev_data = read_jsonl(Path(ANSERINI_RETRIEVED, "dev_anserini+minilm.jsonl"))

In [16]:
def show_retrieval(sample, corpus, search, k=3):
    claim = sample["claim"]
    bid = sample["evidence"][0]
    rec = corpus[original_id2pid[bid]]
    evidence = rec["text"]
    # retrieved = sample["retrieved"]
    print(f"CLAIM: {claim}")
    print()

    if search:
        found = f"FOUND '{search}'" if search in evidence else f"NOT FOUND '{search}'"
    else:
        found = ""

    print(f"EVIDENCE ({bid}) {found}:\n" + textwrap.fill(evidence))

    for i in range(k):
        bid = sample["retrieved"][i]
        ret = corpus[original_id2pid[bid]]["text"]
        if search:
            found = f"FOUND '{search}'" if search in ret else f"NOT FOUND '{search}'"
        else:
            found = ""
        print(f"\nRETRIEVED {i+1} ({bid}) {found}:\n" + textwrap.fill(ret))



show_retrieval(tst_data[11], corpus, search="Niles", k=3)

CLAIM: Leonardo Emilio Bonzi urodził się w Mediolanie.

EVIDENCE (Leonardo_Bonzi_1) NOT FOUND 'Niles':
Leonardo Bonzi  Leonardo Emilio Bonzi (ur. 22 grudnia 1902 w
Mediolanie, zm. 30 grudnia 1977 w Ripalta Cremasca) – włoski
bobsleista, olimpijczyk, a także reżyser i producent filmów
dokumentalnych.  Był również tenisistą. W 1019 dotarł do 3. rundy gry
pojedynczej French Championships i Wimbledonu.  Leonardo Bonzi

RETRIEVED 1 (Leonardo_Bonzi_1) NOT FOUND 'Niles':
Leonardo Bonzi  Leonardo Emilio Bonzi (ur. 22 grudnia 1902 w
Mediolanie, zm. 30 grudnia 1977 w Ripalta Cremasca) – włoski
bobsleista, olimpijczyk, a także reżyser i producent filmów
dokumentalnych.  Był również tenisistą. W 1019 dotarł do 3. rundy gry
pojedynczej French Championships i Wimbledonu.  Leonardo Bonzi

RETRIEVED 2 (San_Leonardo_(stacja_metra)_1) NOT FOUND 'Niles':
San Leonardo (stacja metra)  San Leonardo – stacja metra w Mediolanie,
na linii M1. Znajduje się na via Fichera, w dzielnicy San Leonardo, w
Mediolanie 

### Triples by Retrieval

In [27]:
n_preretrieve = 64
offset = 0
# trn_triples = generate_triples_by_retrieval(trn_data, corpus, original_id2pid, n_preretrieve, offset=offset)
# dev_triples = generate_triples_by_retrieval(dev_data, corpus, original_id2pid, n_preretrieve, offset=offset)
# write_jsonl(Path(COLBERT_ROOT, f"train_triples{n_preretrieve}_o{offset}.jsonl"), trn_triples)
# write_jsonl(Path(COLBERT_ROOT, f"dev_triples{n_preretrieve}_o{offset}.jsonl"), dev_triples)
# write_jsonl(Path(COLBERT_ROOT, f"train_triples{n_preretrieve}_o{offset}_anserini+minilm.jsonl"), trn_triples)
# write_jsonl(Path(COLBERT_ROOT, f"dev_triples{n_preretrieve}_o{offset}_anserini+minilm.jsonl"), dev_triples)

In [26]:
nway = 128

trn_triples = generate_triples_by_retrieval_nway(trn_data, corpus, original_id2pid, nway=nway)
write_jsonl(Path(TRIPLES_ROOT, f"trn_triples_nway{nway}_anserini_{APPROACH}.jsonl"), trn_triples, mkdir=True)

dev_triples = generate_triples_by_retrieval_nway(dev_data, corpus, original_id2pid, nway=nway)
write_jsonl(Path(TRIPLES_ROOT, f"dev_triples_nway{nway}_anserini_{APPROACH}.jsonl"), dev_triples, mkdir=True)

tst_triples = generate_triples_by_retrieval_nway(tst_data, corpus, original_id2pid, nway=nway)
write_jsonl(Path(TRIPLES_ROOT, f"tst_triples_nway{nway}_anserini_{APPROACH}.jsonl"), tst_triples, mkdir=True)

# trn_triples = generate_triples_by_retrieval_nway(trn_data, corpus, original_id2pid, nway=nway)
# write_jsonl(Path(COLBERT_ROOT, f"trn_triples_nway{nway}_anserini+minilm.jsonl"), trn_triples)

# dev_triples = generate_triples_by_retrieval_nway(dev_data, corpus, original_id2pid, nway=nway)
# write_jsonl(Path(COLBERT_ROOT, f"dev_triples_nway{nway}_anserini+minilm.jsonl"), dev_triples)

  2%|▏         | 3784/183485 [00:00<00:20, 8949.08it/s] 



  8%|▊         | 14785/183485 [00:01<00:15, 10581.07it/s]



 10%|█         | 18948/183485 [00:01<00:15, 10584.34it/s]



100%|██████████| 183485/183485 [00:21<00:00, 8683.94it/s] 


generated 183485 triples with 0 failures and 78 random fixes


 27%|██▋       | 5006/18750 [00:00<00:01, 9232.34it/s] 



 39%|███▉      | 7406/18750 [00:00<00:01, 7735.18it/s]



 60%|█████▉    | 11191/18750 [00:01<00:00, 8530.76it/s] 



100%|██████████| 18750/18750 [00:02<00:00, 8336.40it/s]


generated 18750 triples with 0 failures and 8 random fixes


  8%|▊         | 1346/17727 [00:00<00:02, 6549.37it/s]



 46%|████▌     | 8162/17727 [00:00<00:00, 10577.28it/s]



 70%|██████▉   | 12341/17727 [00:01<00:00, 9850.58it/s] 



100%|██████████| 17727/17727 [00:01<00:00, 9865.66it/s] 


generated 17727 triples with 0 failures and 5 random fixes


In [17]:
nway = 128

trn_triples = generate_triples_by_retrieval_nway(trn_data, corpus, original_id2pid, nway=nway, use_evidence=True)
write_jsonl(Path(TRIPLES_ROOT, f"trn_triples_nway{nway}_evidence+anserini_{APPROACH}.jsonl"), trn_triples, mkdir=True)

dev_triples = generate_triples_by_retrieval_nway(dev_data, corpus, original_id2pid, nway=nway, use_evidence=True)
write_jsonl(Path(TRIPLES_ROOT, f"dev_triples_nway{nway}_evidence+anserini_{APPROACH}.jsonl"), dev_triples, mkdir=True)

tst_triples = generate_triples_by_retrieval_nway(tst_data, corpus, original_id2pid, nway=nway, use_evidence=True)
write_jsonl(Path(TRIPLES_ROOT, f"tst_triples_nway{nway}_evidence+anserini_{APPROACH}.jsonl"), tst_triples, mkdir=True)

  1%|          | 2020/183204 [00:00<01:15, 2401.78it/s]



  6%|▌         | 11114/183204 [00:02<01:02, 2764.91it/s]



  9%|▉         | 16090/183204 [00:03<00:52, 3189.15it/s]



 12%|█▏        | 21206/183204 [00:04<00:22, 7099.99it/s]



100%|██████████| 183204/183204 [00:38<00:00, 4805.47it/s]


generated 183204 triples with 0 failures and 24 random fixes


 41%|████▏     | 7744/18685 [00:01<00:03, 3206.99it/s]



 86%|████████▌ | 16047/18685 [00:03<00:01, 2296.24it/s]



100%|██████████| 18685/18685 [00:04<00:00, 4099.11it/s]


generated 18685 triples with 0 failures and 3 random fixes


  4%|▎         | 641/17731 [00:00<00:16, 1043.30it/s]



100%|██████████| 17731/17731 [00:02<00:00, 6202.84it/s]


generated 17731 triples with 0 failures and 1 random fixes


In [18]:
TRIPLES_ROOT

PosixPath('/mnt/data/factcheck/wiki/pl/20230801/colbertv2/qacg/triples/stanza/mt5-large_all-cp126k/mt5-large_all-cp156k')

In [21]:
da = read_jsonl("/mnt/data/factcheck/wiki/sk/20230801/colbertv2/qacg/triples/crabz_slovakbert-ner/mt5-large_all-cp126k/mt5-large_all-cp156k/dev_triples_nway128_anserini_balanced.jsonl")
db = read_jsonl("/mnt/data/factcheck/wiki/sk/20230801/colbertv2/qacg/triples/crabz_slovakbert-ner/mt5-large_all-cp126k/mt5-large_all-cp156k/dev_triples_nway128_evidence+anserini_balanced.jsonl")

cnt = 0
for a, b in zip(da, db):
    a, b = np.array(a), np.array(b)
    # print(len(a), len(b))
    if not np.all(a == b) and len(a) == len(b):
    # if len(a) != len(b):
        # break
        cnt += 1
cnt, len(da)

(14084, 30087)

In [22]:
a

array([ 30086, 152437, 155842, 155683, 155291, 188084, 293071, 164405,
       177648, 194284, 156128, 153482, 155843, 177643, 151892, 159393,
       293070, 152403, 162089, 160233, 177622, 157550, 152538, 194292,
       157642, 155415, 154850, 154418, 194285, 155562, 153558, 194351,
       193373, 194293, 155685, 155292, 156775, 151721, 177653, 193369,
       156019, 293073, 193355, 194289, 193371, 194340, 256442, 194309,
       156018, 155416, 167926, 155296, 201761, 194287, 193366, 151580,
       155290, 193356, 194359, 264176, 194335, 250302, 194303, 194348,
       194345, 157452, 156741, 160653, 177659, 151891, 253088, 194306,
       155845, 194308, 253087,  98021, 194355, 154051, 281091, 156979,
       202027, 201966, 293078, 291248, 177610, 154120, 201970, 161965,
       250311, 152439, 193341, 160295, 154556, 160764, 201963, 152458,
       159394, 154363, 155662, 201948, 201965, 155684, 201954, 293132,
       201950, 154568, 194326, 152504, 167925, 344184, 202020,  98020,
      

In [23]:
b

array([ 30086, 155845, 152437, 155842, 155683, 155291, 188084, 293071,
       164405, 177648, 194284, 156128, 153482, 155843, 177643, 151892,
       159393, 293070, 152403, 162089, 160233, 177622, 157550, 152538,
       194292, 157642, 155415, 154850, 154418, 194285, 155562, 153558,
       194351, 193373, 194293, 155685, 155292, 156775, 151721, 177653,
       193369, 156019, 293073, 193355, 194289, 193371, 194340, 256442,
       194309, 156018, 155416, 167926, 155296, 201761, 194287, 193366,
       151580, 155290, 193356, 194359, 264176, 194335, 250302, 194303,
       194348, 194345, 157452, 156741, 160653, 177659, 151891, 253088,
       194306, 194308, 253087,  98021, 194355, 154051, 281091, 156979,
       202027, 201966, 293078, 291248, 177610, 154120, 201970, 161965,
       250311, 152439, 193341, 160295, 154556, 160764, 201963, 152458,
       159394, 154363, 155662, 201948, 201965, 155684, 201954, 293132,
       201950, 154568, 194326, 152504, 167925, 344184, 202020,  98020,
      

In [24]:
corpus[290620]

{'id': 'The_2nd_Law_4',
 'did': 'The_2nd_Law',
 'bid': 4,
 'text': 'The 2nd Law\n\nAlbum vyšiel ako set súborov na stiahnutie, CD, bonusová edícia CD+DVD (so zábermi z výroby, "The Making of The 2nd Law", na DVD), a aj ako vinylová LP platňa. Edícia „Deluxe“ bola boxsetom "The 2nd Law" obsahujúcim CD, DVD vinylový dvojalbum a tri postery.\n\nThe 2nd Law\n\nPrezentácia albumu.\n\nDňa 16. júna 2012 kapela Muse vydala pre plánovaný album "The 2nd Law" trailer, ku ktorému bolo na ich oficiálnej webstránke zverejnené počítadlo, ktorého deň nula (t.j. plánované vydanie) pripadal na 17. september toho istého roku. Trailer obsahujúci aj dubstepové prvky sa stretol s rozpačitými reakciami skalných fanúšikov. Dňa 9. augusta dala kapela Muse pre fanúšikov, ktorí si objednali album k dispozícii nahrávku piesne „The 2nd Law: Unsustainable“. Dňa 10. augusta 2012, kapela zverejnila aj video k tejto piesni na svojej oficiálnej stránke na YouTube. Skupina zároveň urobil súťaž na produkciu hudobného vid

In [25]:
corpus[104930]

{'id': 'Muse_17',
 'did': 'Muse',
 'bid': 17,
 'text': 'Muse\n\nDňa 6. júna 2012 skupina Muse vydala trailer nového albumu "The 2nd Law" s tým, že na svojom webe začali odpočítavať dni do 17. septembra, kedy by mali tento album vydať. Trailer, ktorý skutočne obsahoval aj dubstepové prvky bol od fanúšikov prijatý so zmiešanými reakciami.\n\nMuse\n\nNa 7. jún 2012 bol ohlásený prvý koncert The 2nd Law Tour. Turné sa začalo v Európe. Konalo sa vo Francúzsku, Španielsku, Spojenom kráľovstve a v ďalších krajinách. Prvý singel tohto albumu, „Survival“, bol oficiálnym songom Letných olympijských hier 2012, ktoré sa konali v Londýne. Singel mal spolu s prelúdiom premiéru v show, ktoré v BBC Radio 1 uvádzal DJ Zane Lowe. Taktiež s touto piesňou kapela Muse vystúpila počas záverečnej ceremónie týchto hier.\n\nMuse\n\nZoznam skladieb albumu "The 2nd Law" bol oficiálne zverejnený 13. júla 2012. Druhým jeho oficiálnym singlom bola nahrávka „Madness“. Vyšiel 20. augusta 2012 a hudobné video k tejto 

In [25]:
q = read_jsonl("/mnt/data/factcheck/wiki/sk/20230801/colbertv2/qacg/queries/crabz_slovakbert-ner/mt5-large_all-cp126k/mt5-large_all-cp156k/dev_qacg_queries_balanced.jsonl")
q[22]

{'query': 'Muse vydali album The 2nd Law.'}

### Randomly sampled triples

Not used...

In [29]:
# replace with Anserini, random does not need "retrieved" field
trn_data = import_split(Path(FEVER_ROOT, "train.jsonl"), original_id2pid, fixer=enfever_lrev_id_fixer)
dev_data = import_split(Path(FEVER_ROOT, "paper_dev.jsonl"), original_id2pid, fixer=enfever_lrev_id_fixer)

100%|██████████| 145449/145449 [00:00<00:00, 543252.96it/s]


Not found 0/263822 evidence documents, 0 claims had zero evidence


100%|██████████| 9999/9999 [00:00<00:00, 779793.35it/s]

Not found 0/14475 evidence documents, 0 claims had zero evidence





In [14]:
def generate_triples_random(data, corpus, original_id2pid, k, seed=1234):
    # generate soft negatives by choosing random documents
    idx2id = {i: doc["id"] for i, doc in enumerate(corpus)}
    id2idx = {doc["id"]: i for i, doc in enumerate(corpus)}
    
    rng = np.random.RandomState(seed)

    def random_docs(posidx: int):
        docs = set()
        while len(docs) < k:
            doc = rng.choice(len(id2idx))
            if doc != pos and doc not in docs:
                docs.add(doc)
        return list(docs)

    triples = []
    for qid, r in enumerate(tqdm(data)):
        for pos in r["evidence"]:
            posidx = id2idx[pos]
            for neg in random_docs(posidx):
                neg = idx2id[neg]
                triples.append((qid, original_id2pid[pos], original_id2pid[neg]))
    print(f"generated {len(triples)} triples")
    return triples

In [15]:
k = 32
trn_triples = generate_triples_random(trn_data, corpus, original_id2pid, k, seed=1234)
dev_triples = generate_triples_random(dev_data, corpus, original_id2pid, k, seed=1235)
tst_triples = generate_triples_random(tst_data, corpus, original_id2pid, k, seed=1236)
write_jsonl(Path(COLBERT_ROOT, f"train_triples{k}_random.jsonl"), trn_triples)
write_jsonl(Path(COLBERT_ROOT, f"dev_triples{k}_random.jsonl"), dev_triples)
write_jsonl(Path(COLBERT_ROOT, f"test_triples{k}_random.jsonl"), tst_triples)

100%|██████████| 109810/109810 [00:41<00:00, 2625.70it/s]


generated 4482720 triples


100%|██████████| 6666/6666 [00:02<00:00, 3034.82it/s]


generated 258528 triples


In [20]:
# TODO implement this, but for full Wikipedia corpora, EnFEVER has only single paragraph=document per page
def generate_triples_by_page(data, corpus, original_id2pid, k):
    # documents = paragraphs, I use page as term describing whole original text (e.g., Wikipedia page composed of documents=paragraphs)
    # takes the positive document and adds k-1 negatives from the same page
    id2txt = {doc["id"]: doc["text"] for doc in corpus}
    failures = 0
    triples = []
    # for qid, r in enumerate(data):
    #     # those retrieved but not in the annotated evidence will become hard negatives 
    #     retrieved = set(r["retrieved"][offset:]).difference(r["evidence"])
    #     for pos in r["evidence"]:
    #         if pos not in id2txt:
    #             # may happen for EnFEVER when the snapshot does not exactly match 
    #             failures += 1
    #             continue
    #         for neg in list(retrieved)[:k]:
    #             triples.append((qid, original_id2pid[pos], original_id2pid[neg]))
    # print(f"generated {len(triples)} triples with {failures} failures")
    return triples

# k = 8
# dev_triples = generate_triples_by_page(dev_data, corpus, original_id2pid, k=k)

In [12]:
corpus[1]

{'id': '2018–19_Dhaka_Premier_Division_Twenty20_Cricket_League',
 'text': 'The 2018–19 Dhaka Premier Division Twenty20 Cricket League was the first edition of the Dhaka Premier Division Twenty20 Cricket League, a Twenty20 cricket competition that was held in Bangladesh. It started on 25 February 2019 and concluded on 4 March 2019. The tournament took place directly before the 2018–19 Dhaka Premier Division Cricket League, and features the same twelve teams. The final of the competition was played as a night game at the Sher-e-Bangla National Stadium in Mirpur. The Bangladesh Cricket Board (BCB) instigated the tournament in order to give Bangladeshi players more experience in the 20-over format, in the hope that local players will become more prominent in the Bangladesh Premier League. For this reason the tournament featured local cricketers exclusively, unlike the Dhaka Premier Division Cricket League, in which foreign players take part. Shinepukur Cricket Club were the first team to q

In [22]:
c = Counter(pid2original_id[t[1]] for t in trn_triples)
c.most_common(20)

[('Snoop_Dogg', 6912),
 ('Marlon_Brando', 6880),
 ('United_States', 6112),
 ('Wyatt_Earp', 5888),
 ('Michael_Jackson', 5856),
 ('United_Kingdom', 5760),
 ('Adele', 5152),
 ('Miley_Cyrus', 5024),
 ('Tim_Rice', 4832),
 ('The_Beatles', 4544),
 ('International_relations', 4480),
 ('A_Song_of_Ice_and_Fire', 4480),
 ('Abraham_Lincoln', 4448),
 ('David_Beckham', 4320),
 ('Anne_Hathaway', 4192),
 ('One_Direction', 4192),
 ('Frank_Sinatra', 4160),
 ('Oliver_Reed', 4096),
 ('Deadpool_-LRB-_film_-RRB-', 4064),
 ('Bradley_Cooper', 4064)]

In [15]:
write_jsonl(Path(COLBERT_ROOT, f"train_triples{n_preretrieve}_o{offset}.jsonl"), trn_triples)
write_jsonl(Path(COLBERT_ROOT, f"dev_triples{n_preretrieve}_o{offset}.jsonl"), dev_triples)

In [52]:
dev_triples = read_jsonl(Path(COLBERT_ROOT, f"dev_triples1_random.jsonl"))

In [56]:
import textwrap
for qid, pos, neg in dev_triples[1:]:
    claim = dev_data[qid]['claim']
    # pos = textwrap.fill(corpus[pos]['text'])
    # neg = textwrap.fill(corpus[neg]['text'])
    pos = corpus[pos]["id"]
    neg = corpus[neg]["id"]

    print(f"{claim}")
    print(f"{pos}")
    print()
    print(f"{neg}")
    print("-----------")
    break

Telemundo is a English-language television network.
Hispanic_and_Latino_Americans

Cnaphalocrocis_poeyalis
-----------


# Combine Languages

In [22]:
def combine_queries_triples_collection(
        split_files, # already created combined split
        triple_files, # target files for triples (keep same order of splits!)
        query_files, # target files for queries (keep same order of splits!)
        collection_dir, # combined collection dir
        lang2triple_files, # triple files for each particular source language
        lang2collection_files
        ):
    
    #import splits and generate queries (just claims) 
    splits = []
    for split_file, query_file in zip(split_files, query_files):
        data = read_jsonl(split_file)
        splits.append(data)
        recs = [{"query": r["claim"]} for r in data]
        print(f'writing query file: "{query_file}"')
        write_jsonl(query_file, recs, mkdir=True)

    print("loading collections")
    lang2collection = {lang: read_jsonl(collection_file, show_progress=True) for lang, collection_file in lang2collection_files.items()}

    print("loading triples")
    lang2triples = {lang: [read_jsonl(Path(triple_file_lst[0], triple_file_name)) for triple_file_name in triple_file_lst[1]] for lang, triple_file_lst in lang2triple_files.items()}

    new_collection = []
    used_collection_pids = defaultdict(dict) # language to original pid to new pid 
    new_triples = [[] for _ in range(len(splits))] # dict of list (splits) of lists (triples)

    for split_idx, split in enumerate(splits):
        # run through the matching split and select all samples with matching `lang`
        for s in tqdm(split, desc=f"split idx: {split_idx}"):
            # print(f"DEBUG: sample={s}")
            lang = s["lang"]
            # index from the original language split
            orig_idx = int(s["orig_idx"])
            # print(f"DEBUG: orig_idx={orig_idx}")
            trp = lang2triples[lang][split_idx][orig_idx]
            # print(f"DEBUG: trp={trp}")
            new_trp_id = len(new_triples[split_idx]) # generate ever increasing triple ids
            new_trp = [new_trp_id]
            # now translate original pids to new pids
            for orig_pid in trp[1:]:
                if orig_pid not in used_collection_pids[lang]:
                    new_pid = len(new_collection) # ever increasing pid given by position in `the new_collection`
                    used_collection_pids[lang][orig_pid] = new_pid
                    col_item = lang2collection[lang][orig_pid]
                    new_collection.append(col_item)
                else:
                    new_pid = used_collection_pids[lang][orig_pid]
                new_trp.append(new_pid)
            # print(f"DEBUG: new_trp={new_trp}")
                    
            new_triples[split_idx].append(new_trp)
    
    assert len(new_triples) == len(triple_files)
    for nt, triple_file in zip(new_triples, triple_files):
        print(f'writing triple file: "{triple_file}"')
        write_jsonl(triple_file, nt, mkdir=True)

    collection_file = Path(collection_dir, "collection.jsonl")
    print(f'writing collection file: "{collection_file}"')
    write_jsonl(collection_file, new_collection, mkdir=True, show_progress=True)

    # Most likely makes no sense as the bids from different Wikipedia may overlap 
    # original_id2pid = {r["id"]: i for i, r in enumerate(new_collection)}
    # original_id2pid_file = Path(collection_dir, "original_id2pid.json")
    # print(f'writing original_id2pid file: "{original_id2pid_file}"')
    # write_json(original_id2pid_file, original_id2pid, mkdir=True)

In [24]:
APPROACH="balanced_shuf"
DATE = "20230801"
QACG_ROOT=f"/mnt/data/factcheck/wiki/cs_en_pl_sk/{DATE}/qacg"
COLBERT_ROOT=f"/mnt/data/factcheck/wiki/cs_en_pl_sk/{DATE}/colbertv2/qacg"
COLBERT_ROOT_CS=f"/mnt/data/factcheck/wiki/cs/{DATE}/colbertv2/qacg"
COLBERT_ROOT_EN=f"/mnt/data/factcheck/wiki/en/{DATE}/colbertv2/qacg"
COLBERT_ROOT_PL=f"/mnt/data/factcheck/wiki/pl/{DATE}/colbertv2/qacg"
COLBERT_ROOT_SK=f"/mnt/data/factcheck/wiki/sk/{DATE}/colbertv2/qacg"
MODELS_CS=f"PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k"
MODELS_EN=f"stanza/mt5-large_all-cp126k/mt5-large_all-cp156k"
MODELS_PL=f"stanza/mt5-large_all-cp126k/mt5-large_all-cp156k"
MODELS_SK=f"crabz_slovakbert-ner/mt5-large_all-cp126k/mt5-large_all-cp156k"

# TRIPLE_SPLITS = [
#     f"dev_triples_nway128_anserini_{APPROACH}.jsonl",
#     f"tst_triples_nway128_anserini_{APPROACH}.jsonl",
#     f"trn_triples_nway128_anserini_{APPROACH}.jsonl",
#     ]

TRIPLE_SPLITS = [
    f"dev_triples_nway128_evidence+anserini_{APPROACH}.jsonl",
    f"tst_triples_nway128_evidence+anserini_{APPROACH}.jsonl",
    f"trn_triples_nway128_evidence+anserini_{APPROACH}.jsonl",
    ]

combine_queries_triples_collection(
    split_files=[
        f"{QACG_ROOT}/splits/dev_{APPROACH}.jsonl",
        f"{QACG_ROOT}/splits/test_{APPROACH}.jsonl",
        f"{QACG_ROOT}/splits/train_{APPROACH}.jsonl",
    ],
    query_files=[
        f"{COLBERT_ROOT}/queries/dev_qacg_queries_{APPROACH}.jsonl",
        f"{COLBERT_ROOT}/queries/test_qacg_queries_{APPROACH}.jsonl",
        f"{COLBERT_ROOT}/queries/train_qacg_queries_{APPROACH}.jsonl",
    ],
    triple_files=[f"{COLBERT_ROOT}/triples/{ts}" for ts in TRIPLE_SPLITS],
    collection_dir=COLBERT_ROOT,
    lang2triple_files = {
        "cs": (f"{COLBERT_ROOT_CS}/triples/{MODELS_CS}", TRIPLE_SPLITS),
        "en": (f"{COLBERT_ROOT_EN}/triples/{MODELS_EN}", TRIPLE_SPLITS),
        "pl": (f"{COLBERT_ROOT_PL}/triples/{MODELS_PL}", TRIPLE_SPLITS),
        "sk": (f"{COLBERT_ROOT_SK}/triples/{MODELS_SK}", TRIPLE_SPLITS),
    },
    lang2collection_files = {
        "cs": f"{COLBERT_ROOT_CS}/collection.jsonl",
        "en": f"{COLBERT_ROOT_EN}/collection.jsonl",
        "pl": f"{COLBERT_ROOT_PL}/collection.jsonl",
        "sk": f"{COLBERT_ROOT_SK}/collection.jsonl",
    }
)

writing query file: "/mnt/data/factcheck/wiki/cs_en_pl_sk/20230801/colbertv2/qacg/queries/dev_qacg_queries_balanced_shuf.jsonl"
writing query file: "/mnt/data/factcheck/wiki/cs_en_pl_sk/20230801/colbertv2/qacg/queries/test_qacg_queries_balanced_shuf.jsonl"
writing query file: "/mnt/data/factcheck/wiki/cs_en_pl_sk/20230801/colbertv2/qacg/queries/train_qacg_queries_balanced_shuf.jsonl"
loading collections


0.00it [00:00, ?it/s]

0.00it [00:00, ?it/s]

0.00it [00:00, ?it/s]

0.00it [00:00, ?it/s]

loading triples


split idx: 0: 100%|██████████| 18750/18750 [00:01<00:00, 12830.77it/s]
split idx: 1: 100%|██████████| 18146/18146 [00:01<00:00, 12137.26it/s]
split idx: 2: 100%|██████████| 186489/186489 [00:14<00:00, 12693.15it/s]


writing triple file: "/mnt/data/factcheck/wiki/cs_en_pl_sk/20230801/colbertv2/qacg/triples/dev_triples_nway128_evidence+anserini_balanced_shuf.jsonl"
writing triple file: "/mnt/data/factcheck/wiki/cs_en_pl_sk/20230801/colbertv2/qacg/triples/tst_triples_nway128_evidence+anserini_balanced_shuf.jsonl"
writing triple file: "/mnt/data/factcheck/wiki/cs_en_pl_sk/20230801/colbertv2/qacg/triples/trn_triples_nway128_evidence+anserini_balanced_shuf.jsonl"
writing collection file: "/mnt/data/factcheck/wiki/cs_en_pl_sk/20230801/colbertv2/qacg/collection.jsonl"


  0%|          | 0/7450195 [00:00<?, ?it/s]

In [25]:
QACG_ROOT

'/mnt/data/factcheck/wiki/cs_en_pl_sk/20230801/qacg'

In [26]:
# the same for the language SUM dataset
APPROACH="balanced_shuf"
DATE = "20230801"
QACG_ROOT=f"/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/{DATE}/qacg"
COLBERT_ROOT=f"/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/{DATE}/colbertv2/qacg"
COLBERT_ROOT_CS=f"/mnt/data/factcheck/wiki/cs/{DATE}/colbertv2/qacg"
COLBERT_ROOT_EN=f"/mnt/data/factcheck/wiki/en/{DATE}/colbertv2/qacg"
COLBERT_ROOT_PL=f"/mnt/data/factcheck/wiki/pl/{DATE}/colbertv2/qacg"
COLBERT_ROOT_SK=f"/mnt/data/factcheck/wiki/sk/{DATE}/colbertv2/qacg"
MODELS_CS=f"PAV-ner-CNEC/mt5-large_all-cp126k/mt5-large_all-cp156k"
MODELS_EN=f"stanza/mt5-large_all-cp126k/mt5-large_all-cp156k"
MODELS_PL=f"stanza/mt5-large_all-cp126k/mt5-large_all-cp156k"
MODELS_SK=f"crabz_slovakbert-ner/mt5-large_all-cp126k/mt5-large_all-cp156k"
# TRIPLE_SPLITS = [
#     f"dev_triples_nway128_anserini_{APPROACH}.jsonl",
#     f"tst_triples_nway128_anserini_{APPROACH}.jsonl",
#     f"trn_triples_nway128_anserini_{APPROACH}.jsonl",
#     ]

TRIPLE_SPLITS = [
    f"dev_triples_nway128_evidence+anserini_{APPROACH}.jsonl",
    f"tst_triples_nway128_evidence+anserini_{APPROACH}.jsonl",
    f"trn_triples_nway128_evidence+anserini_{APPROACH}.jsonl",
    ]

combine_queries_triples_collection(
    split_files=[
        f"{QACG_ROOT}/splits/dev_{APPROACH}.jsonl",
        f"{QACG_ROOT}/splits/test_{APPROACH}.jsonl",
        f"{QACG_ROOT}/splits/train_{APPROACH}.jsonl",
    ],
    query_files=[
        f"{COLBERT_ROOT}/queries/dev_qacg_queries_{APPROACH}.jsonl",
        f"{COLBERT_ROOT}/queries/test_qacg_queries_{APPROACH}.jsonl",
        f"{COLBERT_ROOT}/queries/train_qacg_queries_{APPROACH}.jsonl",
    ],
    triple_files=[f"{COLBERT_ROOT}/triples/{ts}" for ts in TRIPLE_SPLITS],
    collection_dir=COLBERT_ROOT,
    lang2triple_files = {
        "cs": (f"{COLBERT_ROOT_CS}/triples/{MODELS_CS}", TRIPLE_SPLITS),
        "en": (f"{COLBERT_ROOT_EN}/triples/{MODELS_EN}", TRIPLE_SPLITS),
        "pl": (f"{COLBERT_ROOT_PL}/triples/{MODELS_PL}", TRIPLE_SPLITS),
        "sk": (f"{COLBERT_ROOT_SK}/triples/{MODELS_SK}", TRIPLE_SPLITS),
    },
    lang2collection_files = {
        "cs": f"{COLBERT_ROOT_CS}/collection.jsonl",
        "en": f"{COLBERT_ROOT_EN}/collection.jsonl",
        "pl": f"{COLBERT_ROOT_PL}/collection.jsonl",
        "sk": f"{COLBERT_ROOT_SK}/collection.jsonl",
    }
)

writing query file: "/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/colbertv2/qacg/queries/dev_qacg_queries_balanced_shuf.jsonl"
writing query file: "/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/colbertv2/qacg/queries/test_qacg_queries_balanced_shuf.jsonl"
writing query file: "/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/colbertv2/qacg/queries/train_qacg_queries_balanced_shuf.jsonl"
loading collections


0.00it [00:00, ?it/s]

0.00it [00:00, ?it/s]

0.00it [00:00, ?it/s]

0.00it [00:00, ?it/s]

loading triples


split idx: 0: 100%|██████████| 75573/75573 [00:05<00:00, 13991.37it/s]
split idx: 1: 100%|██████████| 71607/71607 [00:05<00:00, 13119.88it/s]
split idx: 2: 100%|██████████| 741542/741542 [00:56<00:00, 13156.83it/s]


writing triple file: "/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/colbertv2/qacg/triples/dev_triples_nway128_evidence+anserini_balanced_shuf.jsonl"
writing triple file: "/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/colbertv2/qacg/triples/tst_triples_nway128_evidence+anserini_balanced_shuf.jsonl"
writing triple file: "/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/colbertv2/qacg/triples/trn_triples_nway128_evidence+anserini_balanced_shuf.jsonl"
writing collection file: "/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/colbertv2/qacg/collection.jsonl"


  0%|          | 0/12243036 [00:00<?, ?it/s]

In [27]:
COLBERT_ROOT

'/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/colbertv2/qacg'

In [6]:
corpus = import_corpus("/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/colbertv2/qacg/collection.jsonl")
triples = read_jsonl("/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/colbertv2/qacg/triples/dev_triples_nway128_anserini_balanced.jsonl")
queries = read_jsonl("/mnt/data/factcheck/wiki/sum_cs_en_pl_sk/20230801/colbertv2/qacg/queries/dev_qacg_queries_balanced.jsonl")

# corpus = import_corpus("/mnt/data/factcheck/wiki/sk/20230801/colbertv2/qacg/collection.jsonl")
# triples = read_jsonl("/mnt/data/factcheck/wiki/sk/20230801/colbertv2/qacg/triples/crabz_slovakbert-ner/mt5-large_all-cp126k/mt5-large_all-cp156k/dev_triples_nway128_anserini_balanced.jsonl")
# queries = read_jsonl("/mnt/data/factcheck/wiki/sk/20230801/colbertv2/qacg/queries/crabz_slovakbert-ner/mt5-large_all-cp126k/mt5-large_all-cp156k/dev_qacg_queries_balanced.jsonl")

0.00it [00:00, ?it/s]

In [10]:
idx = 2
print(triples[idx])
print(queries[idx])
print(corpus[triples[idx][1]])
print(corpus[triples[idx][2]])
print(corpus[triples[idx][3]])

[2, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383]
{'query': 'Najbližšia železničná stanica v Žerčice je Dobrovice.'}
{'id': 'Žerčice_2', 'did': 'Žerčice', 'bid': 2, 'text': 'Žerčice\n\nvýroba cementového tovaru, holič, 2 hostince, kolár, 2 kováči, mlyn, obchod s obuvou Baťa, 3 obuvníci, pekár, 2 mäsiari, 3 obchody so zmiešaným tovarom, sporiteľní a záložní spolok pre Žerčice, trafika, 3 stolári, veľkost

In [11]:
Counter([r["lang"] for r in corpus])

KeyError: 'lang'

In [33]:
corpus_sk[146171]

{'id': 'Medzinárodné_letisko_Vancouver_4',
 'did': 'Medzinárodné_letisko_Vancouver',
 'bid': 4,
 'text': 'Medzinárodné letisko Vancouver\n\nLetisko získalo v roku 2001 ocenenie "Airport Management Award" od B.C. Aviation Council.\n\nMedzinárodné letisko Vancouver\n\nZdroje.\n\n"Tento článok je čiastočný alebo úplný preklad článku [ Vancouver International Airport] na anglickej Wikipédii."',
 'url': 'https://sk.wikipedia.org/wiki?curid=222084',
 'revid': '6909'}

# Retrieve

In [4]:
class ColBERTv2Retriever:
    def __init__(self, index_name, original_id2pid_file):
            with Run().context(RunConfig(experiment='FEVER predictions')):
                self.searcher = Searcher(index=index_name)
            self.original_id2pid = read_json(original_id2pid_file)
            self.pid2original_id = {v: k for k, v in self.original_id2pid.items()}

    def retrieve(self, query: str, k: int):
        results = self.searcher.search(query, k=k)
        pids, ranks, scores = results
        ids = [self.pid2original_id[pid] for pid in pids]
        return ids, scores
    
# IDX_NAME = "/mnt/data/factcheck/fever/data-en-lrev/colbertv2-sqlite/indices/bert-base-uncased/msmarco/msmarco.2bits"
# IDX_NAME = "/mnt/data/factcheck/fever/data-en-lrev/colbertv2/indices/bert-base-multilingual-cased/enfever_lrev/triples32_random_cp20k.2bits"
# IDX_NAME = "/mnt/data/factcheck/fever/data-en-lrev/colbertv2/indices/bert-base-multilingual-cased/enfever_lrev/triples64_o0_anserini+minilm_cp10k.2bits"
# IDX_NAME = "/mnt/data/factcheck/fever/data-en-lrev/colbertv2/indices/bert-base-multilingual-cased/enfever_lrev/triples64_o0_anserini+minilm_cp30k.2bits"
IDX_NAME = "/mnt/data/factcheck/fever/data-en-lrev/colbertv2/indices/xlm-roberta-large-squad2/enfever_lrev/nway32_anserini+minilm.2bits"
# IDX_NAME = "/mnt/data/factcheck/fever/data-en-lrev/colbertv2/indices/bert-base-multilingual-cased/enfever_lrev/nway96_anserini+minilm.2bits"

retriever = ColBERTv2Retriever(IDX_NAME, Path(COLBERT_ROOT, "original_id2pid.json"))

[Apr 03, 08:19:36] #> Loading collection from JSONL...
0M 1M 2M 3M 4M 5M 
ColBERT: self.colbert_config=ColBERTConfig(query_token_id='[unused0]', doc_token_id='[unused1]', query_token='[Q]', doc_token='[D]', ncells=None, centroid_score_threshold=None, ndocs=None, index_path=None, nbits=2, kmeans_niters=4, resume=False, similarity='cosine', bsize=64, accumsteps=8, lr=3e-06, maxsteps=500000, save_every=None, warmup=None, warmup_bert=None, relu=False, nway=32, use_ib_negatives=False, reranker=False, distillation_alpha=1.0, ignore_scores=False, model_name='deepset/xlm-roberta-large-squad2', query_maxlen=32, attend_to_mask_tokens=False, interaction='colbert', dim=128, doc_maxlen=300, mask_punctuation=True, checkpoint='/mnt/data/factcheck/fever/data-en-lrev/colbertv2/checkpoints/xlm-roberta-large-squad2/enfever_lrev/train_nway32_anserini+minilm/colbert', triples='/mnt/data/factcheck/fever/data-en-lrev/colbertv2/trn_triples_nway128_anserini+minilm.jsonl', collection=<colbert.data.collection.Co

100%|██████████| 216/216 [00:00<00:00, 784.84it/s]


[Apr 03, 08:20:19] #> Loading codes and residuals...


100%|██████████| 216/216 [00:09<00:00, 23.28it/s]


In [5]:
retriever.retrieve("Obama was a president", k=3)


#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==
#> Input: . Obama was a president, 		 True, 		 None
#> Output IDs: torch.Size([32]), tensor([     0,      3,      5,  16042,    509,     10,  13918,      2, 250001,
        250001, 250001, 250001, 250001, 250001, 250001, 250001, 250001, 250001,
        250001, 250001, 250001, 250001, 250001, 250001, 250001, 250001, 250001,
        250001, 250001, 250001, 250001, 250001])
#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])



(['Barack_Obama',
  'Electoral_history_of_Barack_Obama',
  'Barack_-LRB-disambiguation-RRB-'],
 [24.9375, 24.8125, 23.703125])

In [6]:
def generate_fever_predictions(claims_jsonl, predictions_jsonl, retriever: ColBERTv2Retriever, k:int=500):
    test_data = read_jsonl(claims_jsonl)
    for r in tqdm(test_data[:]):
        ids, _ = retriever.retrieve(r["claim"], k=k)
        r["predicted_pages"] = ids
    write_jsonl(predictions_jsonl, test_data, mkdir=True)

In [7]:
SPLIT = "paper_test"

# generate_fever_predictions(Path(FEVER_ROOT, f"{SPLIT}.jsonl"), 
#                            Path(FEVER_PREDICTIONS, f"{SPLIT}/colbertv2/bert-base-uncased/ms-marco+enfever_lrev/triples64_o0_anserini_negatives_cp20k.jsonl"), retriever, k=500)

# generate_fever_predictions(Path(FEVER_ROOT, f"{SPLIT}.jsonl"), 
#                            Path(FEVER_PREDICTIONS, f"{SPLIT}/colbertv2/bert-base-uncased/msmarco/k500.jsonl"), retriever, k=500)

# generate_fever_predictions(Path(FEVER_ROOT, f"{SPLIT}.jsonl"), 
#                            Path(FEVER_PREDICTIONS, f"{SPLIT}/colbertv2/bert-base-multilingual-cased/enfever_lrev/triples64_o0_anserini+minilm_cp30k.2bits.jsonl"), retriever, k=500)

# generate_fever_predictions(Path(FEVER_ROOT, f"{SPLIT}.jsonl"), 
#                            Path(FEVER_PREDICTIONS, f"{SPLIT}/colbertv2/bert-base-multilingual-cased/enfever_lrev/nway96_anserini+minilm.2bits.jsonl"), retriever, k=500)

# generate_fever_predictions(Path(FEVER_ROOT, f"{SPLIT}.jsonl"), 
#                            Path(FEVER_PREDICTIONS, f"{SPLIT}/colbertv2/bert-base-uncased/enfever_lrev/nway32_anserini+minilm.2bits.jsonl"), retriever, k=500)

generate_fever_predictions(Path(FEVER_ROOT, f"{SPLIT}.jsonl"), 
                           Path(FEVER_PREDICTIONS, f"{SPLIT}/colbertv2/xlm-roberta-large-squad2/enfever_lrev/nway32_anserini+minilm.2bits.jsonl"), retriever, k=500)

100%|██████████| 9999/9999 [05:46<00:00, 28.89it/s]


In [13]:
pid2original_id = {v: k for k, v in original_id2pid.items()}

In [18]:
c = Counter(pid2original_id[t[1]] for t in trn_triples)
c.most_common(200)