In [1]:
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Trainer

In [2]:
config = ColBERTConfig(
        bsize=64,
        root=r"../retrain_colbert",
    )

In [None]:
with Run().context(RunConfig(nranks=1, experiment="msmarco")):
    trainer = Trainer(
        triples= r"../data/triples.train.small.tsv",
        queries= r"../data/queries.train.tsv",
        collection= r"../data/collection.tsv",
        config=config,
    )

    checkpoint_path = trainer.train()

    print(f"Saved checkpoint to {checkpoint_path}...")

#> Starting...


In [1]:
from colbert.evaluation.loaders import *

In [5]:
from tqdm.auto import tqdm

In [2]:
param = {
    'triples': '../data/triples.train.small.tsv',
    'queries': '../data/queries.train.tsv',
    'collection': '../data/collection.tsv'
}

In [171]:
def load_queries(queries_path):
    queries = OrderedDict()

    print_message("#> Loading the queries from", queries_path, "...")

    with open(queries_path, encoding='utf-8') as f:
        for line in f:
            qid, query, *_ = line.replace("\xa0", " ").strip().split('\t')
            qid = int(qid)

            assert (qid not in queries), ("Query QID", qid, "is repeated!")
            queries[re.sub('[^ 0-9a-zA-Z_-]', '', query.strip(" "))] = qid

    print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")

    return queries

In [9]:
def load_collection(collection_path):
    print_message("#> Loading collection...")

    collection = {}

    with open(collection_path, encoding="utf-8") as f:
        for line_idx, line in enumerate(f):
            if line_idx % (1000*1000) == 0:
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            pid, passage, *rest = line.strip('\n\r ').split('\t')
            assert pid == 'id' or int(pid) == line_idx

            if len(rest) >= 1:
                title = rest[0]
                passage = title + ' | ' + passage

            collection[passage] = line_idx

    print()

    return collection

In [172]:
queries = load_queries(param['queries'])

[Sep 26, 16:31:04] #> Loading the queries from ../data/queries.train.tsv ...
[Sep 26, 16:31:05] #> Got 806349 queries. All QIDs are unique.



In [None]:
collection = load_collection(param['collection'])

[Sep 26, 15:43:32] #> Loading collection...
0M 1M 2M 3M 4M 5M 6M 7M 8M 


In [209]:
def get_id(text, data):
    _id = data.get(exceptions.get(text, text), None)
    text = 
    if _id is None: 
        _id = data.get(exceptions.get(re.sub('[^ 0-9a-zA-Z_-]', '', text), re.sub('[^ 0-9a-zA-Z_-]', '', text)), None)
    if _id is None: _id = data.get(text.strip(' '), None)
    
    if _id is None:
        print(text)
        raise Exception(text)
    return _id

In [227]:
import thefuzz

In [176]:
import re

In [105]:
exceptions = {
    "divorce et sÃ©paration": 'divorce et séparation',
    "what is intelÂ® vpro technology": 'what is intel® vpro technology',
    "what is aÂ\xa0shock wave":'what is a shock wave',
    'Germanyâ\x80\x99s perspective, the Treaty of Versailles was a fair settlement for its national interests': 'Germany’s perspective, the Treaty of Versailles was a fair settlement for its national interests',
}

In [234]:
examples = []
with open(param['triples'], encoding='utf-8') as f:
    for line in tqdm(f):
        q_str, p_str_p, p_str_n = line.strip('\n').split('\t')
        qid = get_id(q_str, queries)
        pid_p = get_id(p_str_p, collection)
        pid_n = get_id(p_str_n, collection)
        example = [qid, pid_p, pid_n]
        examples.append(example)

0it [00:00, ?it/s]

which of the followingÂ is an ascending tract of the spinal cord?


Exception: which of the followingÂ is an ascending tract of the spinal cord?

In [235]:
from difflib import SequenceMatcher

In [228]:
from thefuzz import process

In [243]:
%%timeit -n 1 -r 5
sorted( [[q, SequenceMatcher(None, q_str, q).quick_ratio()] for q in data.keys()], key=lambda x: x[1], reverse=True)

13.9 s ± 57 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)


In [244]:
%%timeit -n 1 -r 5
hit, max_ratio = '', 0
for q in data.keys():
    score = SequenceMatcher(None, q_str, q).quick_ratio()
    if score >= max_ratio:
        hit, max_ratio = q, score

13.6 s ± 72.2 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)


In [246]:
q_str

'which of the followingÂ\xa0is an ascending tract of the spinal cord?'

In [248]:
q_str.replace("\xa0", ' ')

'which of the followingÂ is an ascending tract of the spinal cord?'

In [247]:
re.sub('[^ 0-9a-zA-Z_-]', '', q_str)

'which of the followingis an ascending tract of the spinal cord'

In [232]:
# data = collection
data = queries
qid = [k for k in data.keys() if "he lithosphere consis" in k]
qid

['the lithosphere consists of  ____________']

In [None]:
list(exceptions.keys())[-1]

In [233]:
exceptions = {**exceptions, **{q_str: qid[0]}}
exceptions

{'divorce et sÃ©paration': 'divorce et séparation',
 'what is intelÂ® vpro technology': 'what is intel® vpro technology',
 'what is aÂ\xa0shock wave': 'what is a shock wave',
 'Germanyâ\x80\x99s perspective, the Treaty of Versailles was a fair settlement for its national interests': 'Germany’s perspective, the Treaty of Versailles was a fair settlement for its national interests',
 'yesÃ¼n temÃ¼r khan emperor taiding of yuan': 'yesün temür khan emperor taiding of yuan',
 ' The vitamin that prevents beriberi is ': ' The vitamin that prevents beriberi is',
 ' phosphates as food ingredients ': ' phosphates as food ingredients',
 ' who invented the periodic table ': ' who invented the periodic table',
 'what does bokmÃ¥l mean': 'what does bokmål mean',
 'which action should youÂ\xa0never take when selecting quotations': 'which action should you never take when selecting quotations',
 'dermatitis, anemia, convulsions, depressions, and confusion are all signs of a vitamin _________Â\xa0defic