In [16]:
import os
os.environ["JAVA_HOME"] = "/home/andrew/Java"
os.environ["JVM_PATH"] = '/home/andrew/Java/jre/lib/server/libjvm.so'
import pyterrier as pt
if not pt.started():
    pt.init(boot_packages=["com.github.terrierteam:terrier-prf:-SNAPSHOT"])

In [17]:
dataset = pt.get_dataset("msmarco_passage")
eval_topics = pt.get_dataset('irds:msmarco-passage/trec-dl-2019/judged').get_topics()

In [18]:
index = pt.IndexFactory.of('/home/andrew/Documents/Data/paced/marcoblocks')

15:48:55.544 [main] WARN org.terrier.structures.BaseCompressingMetaIndex - Structure meta reading data file directly from disk (SLOW) - try index.meta.data-source=fileinmem in the index properties file. 1.9 GiB of memory would be required.


In [19]:
bm25 = pt.BatchRetrieve(index, wmodel="BM25", controls={"bm25.k_1": 0.45, "bm25.b": 0.55, "bm25.k_3": 0.5})
dir_LM = pt.BatchRetrieve(index, wmodel="DirichletLM", controls={"dirichletlm.mu": 200})
DPH = pt.BatchRetrieve(index, wmodel="DPH")

In [20]:
smd = pt.rewrite.SequentialDependence()
bo1 = pt.rewrite.Bo1QueryExpansion(index)
kl = pt.rewrite.KLQueryExpansion(index)
rm3 = pt.rewrite.RM3(index)

In [21]:
expansions = {
    'smd' : lambda x : smd >> x % 1000,
    'bo1' : lambda x : x >> bo1 >> x % 1000,
    'kl' : lambda x : x >> kl >> x % 1000,
    'rm3' : lambda x : x >> rm3 >> x % 1000,
}
models = {
    'bm25' : bm25,
    'dir_LM' : dir_LM,
    'DPH' : DPH,
}

In [22]:
pipes = {}
for model_n, model in models.items():
    pipes[model_n] = model
for exp_n, exp in expansions.items():
    for model_n, model in models.items():
        pipes[f'{exp_n}_{model_n}'] = exp(model)

In [23]:
'''
bm25_rm3 = bm25 >> rm3 >> bm25 % 1000
bm25_bo1 = bm25 >> bo1 >> bm25 % 1000
smd_bm25 = smd >> bm25 % 1000
'''

'\nbm25_rm3 = bm25 >> rm3 >> bm25 % 1000\nbm25_bo1 = bm25 >> bo1 >> bm25 % 1000\nsmd_bm25 = smd >> bm25 % 1000\n'

In [24]:
# name : pipeline dict
'''
pipes = {
    'bm25' : bm25,
    'dph' : dph,
    'bm25_rm3' : bm25_rm3,
    'bm25_bo1' : bm25_bo1,
    'dph_rm3' : dph_rm3,
    'dph_bo1' : dph_bo1,
    'smd_bm25' : smd_bm25,
    'smd_dph' : smd_dph,
}
'''

"\npipes = {\n    'bm25' : bm25,\n    'dph' : dph,\n    'bm25_rm3' : bm25_rm3,\n    'bm25_bo1' : bm25_bo1,\n    'dph_rm3' : dph_rm3,\n    'dph_bo1' : dph_bo1,\n    'smd_bm25' : smd_bm25,\n    'smd_dph' : smd_dph,\n}\n"

In [25]:
scores = {}
for name, pipe in pipes.items():
    scores[name] = pipe.transform(eval_topics)
    print("success for {}".format(name))

success for bm25
success for dir_LM
success for DPH
success for smd_bm25
success for smd_dir_LM
success for smd_DPH
success for bo1_bm25
success for bo1_dir_LM
success for bo1_DPH
success for kl_bm25
success for kl_dir_LM
success for kl_DPH
success for rm3_bm25
success for rm3_dir_LM
success for rm3_DPH


In [26]:
for k, v in scores.items():
    v.to_csv(f'/home/andrew/Documents/Data/paced/members/{k}.csv', index=False)

In [27]:
def convert_to_dict(result):
    result = result.groupby('qid').apply(lambda x: dict(zip(x['docno'], zip(x['score'], x['rank'])))).to_dict()
    return result

In [28]:
score_dicts = {k : convert_to_dict(v) for k, v in scores.items()}

In [29]:
import json
with open('/home/andrew/Documents/Data/paced/scores.json', 'w') as f:
    json.dump(score_dicts, f)

In [30]:
score_dicts['bm25']

{'1037798': {'8760864': (25.7061307957552, 0),
  '8760867': (25.67560318318698, 1),
  '3641634': (25.20785267973889, 2),
  '3620983': (23.986325876258437, 3),
  '4788864': (23.797944544219046, 4),
  '2787508': (23.600291498247536, 5),
  '4291373': (23.29833831823784, 6),
  '8760868': (22.9136095790802, 7),
  '994978': (22.299874632719703, 8),
  '8760873': (22.227822430830557, 9),
  '2863296': (22.06348875832687, 10),
  '3295328': (21.947165237888992, 11),
  '3641640': (21.832061850537638, 12),
  '8760870': (21.71815949921796, 13),
  '2572535': (21.605439483339723, 14),
  '3387556': (21.493883488541876, 15),
  '4083926': (21.493883488541876, 16),
  '2157456': (21.254608602022362, 17),
  '7943228': (20.958128301202713, 18),
  '3030655': (20.283525908210706, 19),
  '3030654': (19.931256963338548, 20),
  '1308037': (19.829893617378197, 21),
  '8537479': (19.759464937122424, 22),
  '8820474': (19.324703653542056, 23),
  '5099520': (19.26822039534153, 24),
  '8417086': (19.21549376678812, 25