In [None]:
import os

import fastrank
import lightgbm as lgb
import numpy as np
import pyterrier as pt


# Initialise PyTerrier once per kernel. A second pt.init() call raises, so the guard keeps
# this cell (and Restart & Run All on an already-initialised kernel) safe to re-run.
if not pt.started():
    pt.init()

In [None]:
# Build the Terrier index over ./corpus, or reuse it if an earlier run already built it.
INDEX_DIR = "./index2"
index_properties = os.path.join(INDEX_DIR, "data.properties")

if os.path.isfile(index_properties):
    # Indexing into a directory that already holds an index fails on re-run,
    # so load the existing index reference instead of rebuilding it.
    indexref = pt.IndexRef.of(index_properties)
else:
    files = pt.io.find_files("./corpus")
    indexer = pt.TRECCollectionIndexer(INDEX_DIR, verbose=True, blocks=False)
    indexref = indexer.index(files)

In [None]:
# Open the index referenced by `indexref` and print its collection statistics summary.
index = pt.IndexFactory.of(indexref)
collection_stats = index.getCollectionStatistics()
print(collection_stats.toString())

In [None]:
topics_path = "./topics.msmarco-doc.dev.txt"
qrels_path = "./qrels.msmarco-doc.dev.txt"

qrels = pt.io.read_qrels(qrels_path)
topics = pt.io.read_topics(topics_path, format="singleline")

# Split the topics in file order (no shuffling): first 60% train, next 20% validation,
# last 20% test. Positional .iloc slicing replaces np.split, which on a DataFrame goes
# through DataFrame.swapaxes (deprecated in pandas 2.1, removed in pandas 3.0).
# Assumes `topics` is a pandas DataFrame, as pt.io.read_topics returns.
n_topics = len(topics)
train_end, valid_end = int(0.6 * n_topics), int(0.8 * n_topics)
train_topics = topics.iloc[:train_end]
valid_topics = topics.iloc[train_end:valid_end]
test_topics = topics.iloc[valid_end:]

# Every topic lands in exactly one split.
assert len(train_topics) + len(valid_topics) + len(test_topics) == n_topics

In [None]:
# One BatchRetrieve per Terrier weighting model, all over the same index.
BM25, TF_IDF, PL2 = (
    pt.BatchRetrieve(indexref, controls={"wmodel": wmodel})
    for wmodel in ("BM25", "TF_IDF", "PL2")
)

In [None]:
# NOTE(review): `fbr` is constructed but never used in this notebook. It appears to compute
# BM25 + TF_IDF + PL2 features in a single retrieval and could replace `pipe` below, or be
# deleted. Confirm the feature values match before swapping it in.
fbr = pt.FeaturesBatchRetrieve(
    indexref,
    controls={"wmodel": "BM25"},
    features=["SAMPLE", "WMODEL:TF_IDF", "WMODEL:PL2"],
)

# Feature pipeline used by both learners:
# - take the top 100 BM25 results;
# - apply PyTerrier's feature union (`**`) over the identity transformer (the BM25 score
#   passed through), TF_IDF and PL2.
first_stage = BM25 % 100
feature_union = pt.transformer.IdentityTransformer() ** TF_IDF ** PL2
pipe = first_stage >> feature_union

In [None]:
# Coordinate Ascent learner (FastRank) on top of the feature pipeline `pipe`.
train_request = fastrank.TrainRequest.coordinate_ascent()

# The parameter object is modified through attribute assignment and never re-assigned to
# train_request. This assumes `params` is a live reference, not a copy.
params = train_request.params
params.init_random = True
params.normalize = True
params.seed = 1234567  # fixed seed so the random initialisation is reproducible

ca_pipe = pipe >> pt.ltr.apply_learned_model(train_request, form="fastrank")
ca_pipe.fit(train_topics, qrels)

In [None]:
# LambdaMART re-ranker via LightGBM's scikit-learn API.
# Each setting is passed exactly once under its scikit-learn name. Passing several LightGBM
# aliases of the same parameter makes LightGBM pick one and warn that the rest are ignored.
# `silent` is not passed: it was removed from the scikit-learn API in LightGBM 4.0.
lmart_l = lgb.LGBMRanker(
    objective="lambdarank",
    metric="ndcg",
    n_estimators=100,        # LightGBM alias: num_iterations
    learning_rate=0.1,
    num_leaves=31,
    max_bin=255,
    min_child_samples=1,     # LightGBM alias: min_data_in_leaf
    min_child_weight=1,      # LightGBM alias: min_sum_hessian_in_leaf
    importance_type="gain",
    early_stopping_round=5,  # stop after 5 rounds without validation improvement
)

# The NDCG cutoff is configured once, through LGBMRanker.fit(eval_at=...).
lmart_x_pipe = pipe >> pt.ltr.apply_learned_model(lmart_l, form="ltr", fit_kwargs={"eval_at": [10]})

# Train on train_topics; valid_topics supplies the validation set used for early stopping.
lmart_x_pipe.fit(train_topics, qrels, valid_topics, qrels)

In [None]:
# Compare the BM25 top-100 baseline with the two learned re-rankers on the held-out test topics.
# baseline=0 designates the first system (BM25) as the baseline; "mrt" is PyTerrier's
# mean-response-time measure.
pt.Experiment(
    [BM25 % 100, ca_pipe, lmart_x_pipe],
    test_topics,
    qrels,
    eval_metrics=["map", "ndcg", "ndcg_cut_10", "mrt", "recip_rank"],
    names=["BM25", "BM25 + CA", "BM25 + LMart"],
    baseline=0,
)