In [10]:
import numpy as np
import torch

from collections import defaultdict, OrderedDict, Counter
from dataclasses import dataclass
import datetime as dt
from itertools import chain
import os
import pathlib
from pathlib import Path
import string
import pandas as pd
import unicodedata as ud
from time import time
from typing import Dict, Type, Callable, List
import sys
import json


sys.path.insert(0, '/home/drchajan/devel/python/FC/ColBERTv2') # ignore other ColBERT installations

%load_ext autoreload
%autoreload 2

from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Indexer, Searcher
from colbert.data import Queries, Collection
from colbert import Trainer

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
def read_json(fname, object_pairs_hook=OrderedDict):
    """Load a single JSON document from a file.

    Args:
        fname: Path of the JSON file to read.
        object_pairs_hook: Callable passed through to ``json.load``; it builds
            each JSON object from its list of (key, value) pairs. Defaults to
            ``OrderedDict``.

    Returns:
        The decoded JSON value.
    """
    # Decode explicitly as UTF-8 (the encoding write_jsonl uses) instead of
    # relying on the platform/locale default of open().
    with open(fname, 'r', encoding='utf8') as json_file:
        return json.load(json_file, object_pairs_hook=object_pairs_hook)

def read_jsonl(jsonl):
    """Read a JSON Lines file into a list of records.

    Each non-blank line is decoded as one JSON value; JSON objects are returned
    as ``OrderedDict`` so that key order follows the file.

    Args:
        jsonl: Path of the JSON Lines file to read.

    Returns:
        A list with one decoded value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8 (the encoding write_jsonl uses) instead of
    # relying on the platform/locale default of open().
    with open(jsonl, 'r', encoding='utf8') as json_file:
        return [
            json.loads(jline, object_pairs_hook=OrderedDict)
            for jline in json_file
            # Skip blank lines (e.g. a trailing empty line) instead of failing on them.
            if jline.strip()
        ]

def write_jsonl(jsonl, data):
    """Write records to a JSON Lines file, one JSON document per line.

    Args:
        jsonl: Path of the output file; it is opened in write mode, so existing
            content is replaced.
        data: Iterable of JSON-compatible structures (e.g. a list of
            ``OrderedDict``). Values the encoder cannot serialise are converted
            with ``str``.
    """
    with open(jsonl, mode='w', encoding='utf8') as sink:
        for record in data:
            # ensure_ascii=False writes non-ASCII characters as-is rather than as \u escapes.
            json.dump(record, sink, ensure_ascii=False, default=str)
            sink.write("\n")

In [12]:
# Load the retriever configuration from a JSON file. The path is relative, so it
# is resolved against the kernel's current working directory.
cfg = read_json("cfg/colbertv2_wiki_en_20230220_100k.json")

In [13]:
cfg

OrderedDict([('index_name',
              '/mnt/data/factcheck/wiki/en/20230220/colbertv2/indexes/enwiki-20230220-paragraphs-100k.2bits'),
             ('lineno2id_mapping',
              '/mnt/data/factcheck/wiki/en/20230220/paragraphs/enwiki-20230220-paragraphs-100k_lineno2id.json'),
             ('port', 8050)])

In [6]:
class ColBERTv2Retriever:
    """Wrapper around a ColBERT ``Searcher`` that reports mapped document ids.

    The passage ids (pids) returned by ``Searcher.search`` are translated
    through the ``lineno2id`` mapping loaded from a JSON file.
    """

    def __init__(self, cfg: dict):
        """Open the index and load the pid -> id mapping.

        Args:
            cfg: Mapping with the keys ``"index_name"`` (index handed to
                ``Searcher``) and ``"lineno2id_mapping"`` (path of a JSON file
                holding one object whose keys are integer-like strings).

        Raises:
            AssertionError: if the mapping and the searcher's collection do not
                have the same number of entries.
        """
        with Run().context(RunConfig(experiment='REST api')):
            self.searcher = Searcher(index=str(cfg["index_name"]))

        # Use the standard-library json module: `ujson` is never imported in this
        # notebook, so the previous `ujson.load` raised NameError on a fresh kernel.
        with open(cfg["lineno2id_mapping"], "r", encoding="utf8") as f:
            raw_mapping = json.load(f)

        # JSON object keys are always strings; convert them to int so the mapping
        # can be indexed with the pids used in retrieve().
        self.lineno2id = {int(k): v for k, v in raw_mapping.items()}
        assert len(self.lineno2id) == len(self.searcher.collection.data), f"not maching collection size: {len(self.lineno2id)} != {len(self.searcher.collection.data)}"

    def retrieve(self, query: str, k: int):
        """Search the index and translate the hits to mapped ids.

        Args:
            query: Query text passed to ``Searcher.search``.
            k: Forwarded to ``Searcher.search`` as ``k``.

        Returns:
            A tuple ``(ids, scores)``: ``ids`` holds the mapped id for each
            returned pid, in the order the searcher returned them; ``scores`` is
            returned exactly as the searcher produced it. The ranks are dropped.

        Raises:
            KeyError: if a returned pid is missing from ``lineno2id``.
        """
        pids, _ranks, scores = self.searcher.search(query, k=k)
        ids = [self.lineno2id[pid] for pid in pids]
        return ids, scores
    
# Build the retriever from the loaded config; construction creates the Searcher
# and reads the lineno2id mapping file.
retriever = ColBERTv2Retriever(cfg)

[Mar 21, 15:12:09] #> Loading collection from JSONL...
0M 
[Mar 21, 15:12:25] #> Loading codec...
[Mar 21, 15:12:25] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Mar 21, 15:12:26] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
[Mar 21, 15:12:26] #> Loading IVF...
[Mar 21, 15:12:26] #> Loading doclens...


100%|██████████| 4/4 [00:00<00:00, 731.54it/s]

[Mar 21, 15:12:26] #> Loading codes and residuals...



100%|██████████| 4/4 [00:00<00:00, 28.06it/s]


In [14]:
# Name of the data split; it selects which `<SPLIT>.jsonl` file is read below.
SPLIT = "paper_test"
# NOTE(review): absolute, machine-specific path; consider moving the data
# directory into the config so the notebook runs elsewhere.
test_data = read_jsonl(f"/mnt/data/factcheck/fever/data-en-lrev/fever-data/{SPLIT}.jsonl")

In [15]:
test_data[0]

OrderedDict([('id', 113501),
             ('verifiable', 'NOT VERIFIABLE'),
             ('label', 'NOT ENOUGH INFO'),
             ('claim', 'Grease had bad reviews.'),
             ('evidence', [[[133128, None, None, None]]])])