In [None]:
import torch
# Sanity check for the GPU-dependent cells further down: print, in order,
# whether CUDA is usable and how many CUDA devices torch can see.
for cuda_probe in (torch.cuda.is_available, torch.cuda.device_count):
    print(cuda_probe())

In [None]:
import pyterrier as pt
import pandas as pd
from itertools import filterfalse
import os
import wget
import zipfile

# Initialise PyTerrier once, before any pt.* component below is used.
# NOTE(review): newer PyTerrier releases reportedly deprecate pt.init() in
# favour of automatic start-up — confirm against the installed version.
pt.init()

In [None]:
from pyterrier_colbert.indexing import ColBERTIndexer
from pyterrier_colbert.ranking import ColBERTFactory

In [None]:
from ance.pyterrier_ance import ANCEIndexer, ANCETextScorer

In [None]:
# Dataset handle shared by every later cell: it supplies the corpus iterator
# for indexing as well as the topics and qrels for evaluation.
antique = pt.get_dataset(
    "irds:antique/test",
)

In [None]:
# Build (once) and open the Terrier index used by BM25 and RM3.
antique_index_src = os.path.abspath("antique-index")

# Guard on the index's properties file rather than on the bare directory: the
# directory may exist without a usable index (interrupted indexing, or created
# by another tool — the ColBERT cell passes the same folder as `index_root`).
# In that case a directory-only check would skip indexing and
# IndexFactory.of() below would fail.
antique_index_properties = os.path.join(antique_index_src, "data.properties")
if not os.path.exists(antique_index_properties):
    print("Creating a new Antique index for BM25 and RM3...")
    # blocks=True records positional information; `meta` stores docno and the
    # raw text as document metadata (the numbers are presumably the maximum
    # stored lengths — confirm against the PyTerrier indexing docs).
    pt.index.IterDictIndexer(antique_index_src, blocks=True, meta={"docno": 20, "text": 131072}).index(antique.get_corpus_iter(), fields=["docno", "text"])

antique_index = pt.IndexFactory.of(antique_index_src)

In [None]:
# Evaluation inputs for every pt.Experiment below: the topics (queries) and
# their relevance judgements, both taken from the ANTIQUE dataset handle.
queries, qrels = antique.get_topics(), antique.get_qrels()

In [None]:
# First-stage retriever. The "text" metadata is returned with each hit so that
# downstream text-based rerankers can read the passage directly.
bm25 = pt.terrier.Retriever(
    antique_index,
    wmodel="BM25",
    metadata=["docno", "text"],
)

# RM3 query expansion backed by the same Terrier index.
rm3 = pt.rewrite.RM3(antique_index)

In [None]:
# --- ColBERT checkpoint: download + extract (both steps are cached) ---------
checkpoint_url = "http://www.dcs.gla.ac.uk/~craigm/colbert.dnn.zip"
extract_dir = "colbert_checkpoint"
checkpoint_path = "colbert_checkpoint.zip"

# Only fetch the archive when it is actually needed for extraction; this
# avoids re-downloading it after the zip was deleted but the extracted
# checkpoint is still present.
if not os.path.exists(extract_dir):
    if not os.path.exists(checkpoint_path):
        print("Downloading checkpoint...")
        wget.download(checkpoint_url, checkpoint_path)
    with zipfile.ZipFile(checkpoint_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)

# --- ColBERT index ----------------------------------------------------------
# Derive the checkpoint location from `extract_dir` so the two cannot drift.
colbert_checkpoint_path = os.path.abspath(os.path.join(extract_dir, "colbert.dnn"))
# NOTE(review): `index_root` points at the Terrier index folder while
# `index_name` is an absolute path of its own — confirm this is intentional.
index_root = os.path.abspath("antique-index")
index_name = os.path.abspath("antique-colbert-index")

if not os.path.exists(index_name):
    print("Index not found. Creating a new Antique index for ColBERT...")
    colbert_index = ColBERTIndexer(
        checkpoint=colbert_checkpoint_path,
        index_root=index_root,
        index_name=index_name,
        # Could be raised up to 128 (the allowed maximum); it regulates the
        # size of the PyTorch temp files created by the indexer.
        chunksize=64,
        # Use the GPU only when torch reports one, instead of requiring a
        # manual edit of this cell on CPU-only machines.
        gpu=torch.cuda.is_available()
    )
    colbert_index.index(antique.get_corpus_iter())
    print("Index successfully created!")

In [None]:
# Factory bound to the checkpoint and index location defined in the previous
# cell; the pipelines below obtain their ColBERT text scorer from it.
colbert_reranker = ColBERTFactory(
    colbert_checkpoint_path,
    index_root,
    index_name,
)

In [None]:
# Unpack the ANCE checkpoint archive (once) into `ance_extract_dir`.
ance_extract_dir = "ance_checkpoint"
# Deliberately NOT called `ance_checkpoint_path`: that name is assigned the
# extracted directory in the next cell, and sharing it meant that re-running
# this cell silently pointed later cells at the zip file.
ance_checkpoint_zip = "ance_checkpoint.zip"

if not os.path.exists(ance_extract_dir):
    # There is no download step for this archive, so fail with an actionable
    # message instead of an opaque error from zipfile.
    if not os.path.exists(ance_checkpoint_zip):
        raise FileNotFoundError(
            f"'{ance_checkpoint_zip}' not found: place the ANCE checkpoint archive "
            f"next to this notebook (or extract it into '{ance_extract_dir}') and re-run this cell."
        )
    with zipfile.ZipFile(ance_checkpoint_zip, 'r') as zip_ref:
        zip_ref.extractall(ance_extract_dir)

In [None]:
# Locations of the extracted ANCE checkpoint and of the ANCE index on disk.
ance_checkpoint_path = os.path.abspath("ance_checkpoint")
ance_index_name = os.path.abspath("antique-ance-index")

# Value handed to the indexer as `num_docs`
# (presumably the ANTIQUE corpus size — confirm).
ANTIQUE_NUM_DOCS = 403666

# Indexing is skipped entirely when the index path already exists.
ance_index_missing = not os.path.exists(ance_index_name)
if ance_index_missing:
    print("Index not found. Creating a new Antique index for ANCE...")
    ance_index = ANCEIndexer(
        ance_checkpoint_path,
        ance_index_name,
        num_docs=ANTIQUE_NUM_DOCS,
    )
    ance_index.index(antique.get_corpus_iter())
    print("Index successfully created!")

In [None]:
# ANCE text scorer built from the extracted checkpoint directory; used as the
# ANCE reranking stage in the pipelines below.
ance_reranker = ANCETextScorer(
    ance_checkpoint_path,
)

In [None]:
# Reusable stages, built once instead of once per dictionary entry.
colbert_scorer = colbert_reranker.text_scorer()
# Re-attaches the passage text from the Terrier index (used after ColBERT,
# before ANCE, exactly where the original pipelines did).
fetch_text = pt.text.get_text(antique_index)

# RM3 expansion followed by a second BM25 retrieval.
rm3_requery = rm3 >> bm25
# Same, but for use in front of a neural reranker: RM3 replaces `query` with
# its expanded, weighted reformulation, which is meant for Terrier and not for
# ColBERT/ANCE. pt.rewrite.reset() restores the original query text, which is
# also what the staged experiments (replaying saved runs against the original
# topics) feed to the rerankers.
rm3_requery_reset = rm3_requery >> pt.rewrite.reset()

# Every ordering of the optional stages (RM3, ColBERT, ANCE) on top of BM25,
# keyed by the order in which the stages are applied.
pipe_dict = {
    "BM25": bm25,
    "BM25_RM3": bm25 >> rm3_requery,
    "BM25_COLBERT": bm25 >> colbert_scorer,
    "BM25_ANCE": bm25 >> ance_reranker,
    "BM25_RM3_COLBERT": bm25 >> rm3_requery_reset >> colbert_scorer,
    "BM25_RM3_ANCE": bm25 >> rm3_requery_reset >> ance_reranker,
    "BM25_COLBERT_RM3": bm25 >> colbert_scorer >> rm3_requery,
    "BM25_COLBERT_ANCE": bm25 >> colbert_scorer >> fetch_text >> ance_reranker,
    "BM25_ANCE_RM3": bm25 >> ance_reranker >> rm3_requery,
    "BM25_ANCE_COLBERT": bm25 >> ance_reranker >> colbert_scorer,
    "BM25_RM3_COLBERT_ANCE": bm25 >> rm3_requery_reset >> colbert_scorer >> fetch_text >> ance_reranker,
    "BM25_RM3_ANCE_COLBERT": bm25 >> rm3_requery_reset >> ance_reranker >> colbert_scorer,
    "BM25_COLBERT_RM3_ANCE": bm25 >> colbert_scorer >> rm3_requery_reset >> ance_reranker,
    "BM25_COLBERT_ANCE_RM3": bm25 >> colbert_scorer >> fetch_text >> ance_reranker >> rm3_requery,
    "BM25_ANCE_RM3_COLBERT": bm25 >> ance_reranker >> rm3_requery_reset >> colbert_scorer,
    "BM25_ANCE_COLBERT_RM3": bm25 >> ance_reranker >> colbert_scorer >> rm3_requery
}

In [None]:
# Stage 1: BM25 alone versus BM25 followed by exactly one extra stage.
# Per-system runs are saved in `twofold_dir` (save_mode="reuse"), and the
# summary table is cached as a CSV next to them.
twofold_dir = "antique-twofold"
twofold_csv = os.path.join(twofold_dir, "results.csv")

if os.path.exists(twofold_csv):
    # Cached: reload the table so `twofold_results` is defined on every run,
    # not only on the run that computed it.
    twofold_results = pd.read_csv(twofold_csv)
else:
    # Ensure the output folder exists before the experiment and to_csv write
    # into it.
    os.makedirs(twofold_dir, exist_ok=True)
    twofold_results = pt.Experiment(
        [
            bm25,
            bm25 >> rm3 >> bm25,
            bm25 >> colbert_reranker.text_scorer(),
            bm25 >> ance_reranker
        ],
        queries,
        qrels,
        ["map", "ndcg_cut_10", "recip_rank", "mrt"],
        ["BM25", "BM25_RM3", "BM25_COLBERT", "BM25_ANCE"],
        save_dir=twofold_dir,
        save_mode="reuse",
        baseline=0,  # significance is tested against the first system (BM25)
        correction="bonferroni"
    )
    twofold_results.to_csv(twofold_csv, sep=',', na_rep="NaN", header=True, index=False)

twofold_results

In [None]:
# Stage 2: extend each two-stage run from "antique-twofold" with one more
# stage. The earlier runs are replayed from their saved result files rather
# than recomputed.
threefold_dir = "antique-threefold"
threefold_csv = os.path.join(threefold_dir, "results.csv")

if os.path.exists(threefold_csv):
    # Cached: reload the table so `threefold_results` is defined on every run.
    threefold_results = pd.read_csv(threefold_csv)
else:
    # Ensure the output folder exists before the experiment and to_csv write
    # into it.
    os.makedirs(threefold_dir, exist_ok=True)

    # Saved stage-1 runs wrapped as transformers (requires the previous
    # experiment cell to have produced these files).
    bm25_rm3 = pt.Transformer.from_df(pt.io.read_results("antique-twofold/BM25_RM3.res.gz"), uniform=False)
    bm25_colbert = pt.Transformer.from_df(pt.io.read_results("antique-twofold/BM25_COLBERT.res.gz"), uniform=False)
    bm25_ance = pt.Transformer.from_df(pt.io.read_results("antique-twofold/BM25_ANCE.res.gz"), uniform=False)

    # Replayed runs are followed by get_text before any text-based reranker.
    # NOTE(review): with replayed first stages, "mrt" presumably times only the
    # appended stage — confirm before comparing with end-to-end timings.
    threefold_results = pt.Experiment(
        [
            bm25,
            bm25_rm3 >> pt.text.get_text(antique_index) >> colbert_reranker.text_scorer(),
            bm25_rm3 >> pt.text.get_text(antique_index) >> ance_reranker,
            bm25_colbert >> rm3 >> bm25,
            bm25_colbert >> pt.text.get_text(antique_index) >> ance_reranker,
            bm25_ance >> rm3 >> bm25,
            bm25_ance >> pt.text.get_text(antique_index) >> colbert_reranker.text_scorer()
        ],
        queries,
        qrels,
        ["map", "ndcg_cut_10", "recip_rank", "mrt"],
        ["BM25", "BM25_RM3_COLBERT", "BM25_RM3_ANCE", "BM25_COLBERT_RM3", "BM25_COLBERT_ANCE", "BM25_ANCE_RM3", "BM25_ANCE_COLBERT"],
        save_dir=threefold_dir,
        save_mode="reuse",
        baseline=0,  # significance is tested against the first system (BM25)
        correction="bonferroni"
    )
    threefold_results.to_csv(threefold_csv, sep=',', na_rep="NaN", header=True, index=False)

threefold_results

In [None]:
# Stage 3: extend each three-stage run from "antique-threefold" with the one
# remaining stage. The earlier runs are replayed from their saved result files
# rather than recomputed.
fourfold_dir = "antique-fourfold"
fourfold_csv = os.path.join(fourfold_dir, "results.csv")

if os.path.exists(fourfold_csv):
    # Cached: reload the table so `fourfold_results` is defined on every run.
    fourfold_results = pd.read_csv(fourfold_csv)
else:
    # Ensure the output folder exists before the experiment and to_csv write
    # into it.
    os.makedirs(fourfold_dir, exist_ok=True)

    # Saved stage-2 runs wrapped as transformers (requires the previous
    # experiment cell to have produced these files).
    bm25_rm3_colbert = pt.Transformer.from_df(pt.io.read_results("antique-threefold/BM25_RM3_COLBERT.res.gz"), uniform=False)
    bm25_rm3_ance = pt.Transformer.from_df(pt.io.read_results("antique-threefold/BM25_RM3_ANCE.res.gz"), uniform=False)
    bm25_colbert_rm3 = pt.Transformer.from_df(pt.io.read_results("antique-threefold/BM25_COLBERT_RM3.res.gz"), uniform=False)
    bm25_colbert_ance = pt.Transformer.from_df(pt.io.read_results("antique-threefold/BM25_COLBERT_ANCE.res.gz"), uniform=False)
    bm25_ance_rm3 = pt.Transformer.from_df(pt.io.read_results("antique-threefold/BM25_ANCE_RM3.res.gz"), uniform=False)
    bm25_ance_colbert = pt.Transformer.from_df(pt.io.read_results("antique-threefold/BM25_ANCE_COLBERT.res.gz"), uniform=False)

    # Replayed runs are followed by get_text before any text-based reranker.
    # NOTE(review): with replayed first stages, "mrt" presumably times only the
    # appended stage — confirm before comparing with end-to-end timings.
    fourfold_results = pt.Experiment(
        [
            bm25,
            bm25_rm3_colbert >> pt.text.get_text(antique_index) >> ance_reranker,
            bm25_rm3_ance >> pt.text.get_text(antique_index) >> colbert_reranker.text_scorer(),
            bm25_colbert_rm3 >> pt.text.get_text(antique_index) >> ance_reranker,
            bm25_colbert_ance >> rm3 >> bm25,
            bm25_ance_rm3 >> pt.text.get_text(antique_index) >> colbert_reranker.text_scorer(),
            bm25_ance_colbert >> rm3 >> bm25
        ],
        queries,
        qrels,
        ["map", "ndcg_cut_10", "recip_rank", "mrt"],
        ["BM25", "BM25_RM3_COLBERT_ANCE", "BM25_RM3_ANCE_COLBERT", "BM25_COLBERT_RM3_ANCE", "BM25_COLBERT_ANCE_RM3", "BM25_ANCE_RM3_COLBERT", "BM25_ANCE_COLBERT_RM3"],
        save_dir=fourfold_dir,
        save_mode="reuse",
        baseline=0,  # significance is tested against the first system (BM25)
        correction="bonferroni"
    )
    fourfold_results.to_csv(fourfold_csv, sep=',', na_rep="NaN", header=True, index=False)

fourfold_results