In [1]:
import torch
import pyterrier as pt

# Initialise PyTerrier only when it has not been started yet, so that
# re-running this cell in a live kernel skips the init call.
if not pt.started():
    # tqdm="notebook" is forwarded to PyTerrier; presumably it selects
    # notebook-style progress bars.
    pt.init(tqdm="notebook")

PyTerrier 0.10.1 has loaded Terrier 5.9 (built by craigm on 2024-05-02 17:40) and terrier-helper 0.0.8

No etc/terrier.properties, using terrier.default.properties for bootstrap configuration.


In [2]:
# Dataset handle used below for topics (get_topics) and relevance judgements (get_qrels).
dataset = pt.get_dataset("irds:beir/trec-covid")
# First-stage retriever with the BM25 weighting model.
# NOTE(review): "../datam/trec-covid" is presumably a prebuilt Terrier index; this
# notebook does not build it, so it must exist before running.
bm25 = pt.BatchRetrieve("../datam/trec-covid", wmodel="BM25")

In [43]:
import pandas as pd
# Read only CSV columns 1 and 2, and cast the query ids to strings in the same
# expression so they have the same dtype as the string qids used downstream.
reduced_topics = pd.read_csv(
    "reduced_trec_cov.csv",
    usecols=[1, 2],
).astype({"qid": str})


In [44]:
def remove_non_alphanumeric(text):
    """Return ``text`` with every character removed except alphanumerics and spaces.

    Removed characters are dropped outright (not replaced by a space), so
    ``"covid-19"`` becomes ``"covid19"``.
    """
    kept = [symbol for symbol in text if symbol.isalnum() or symbol == " "]
    return "".join(kept)

# Clean every query string in place; the function is passed directly, no lambda wrapper needed.
reduced_topics['query'] = reduced_topics['query'].apply(remove_non_alphanumeric)
reduced_topics

Unnamed: 0,qid,query
0,1,origin of covid19
1,2,coronavirus response to weather changes
2,3,SARSCoV2 infected people developing immunity
3,4,complications from covid19
4,5,drugs active against sars covsars cov 2
5,6,rapid testing for covid 19
6,7,serological tests for coronavirus antibodies
7,8,Lack of testing availability leading to underr...
8,9,covid 19 in Canada
9,10,social distancing impact on slowing COVID19 sp...


In [25]:
# Original (unreduced) topics, using the "text" variant of each query.
topics = dataset.get_topics(variant="text")
topics

Unnamed: 0,qid,query
0,1,what is the origin of covid 19
1,2,how does the coronavirus respond to changes in...
2,3,will sars cov2 infected people develop immunit...
3,4,what causes death from covid 19
4,5,what drugs have been active against sars cov o...
5,6,what types of rapid testing for covid 19 have ...
6,7,are there serological tests that detect antibo...
7,8,how has lack of testing availability led to un...
8,9,how has covid 19 affected canada
9,10,has social distancing had an impact on slowing...


In [48]:
from pyterrier.measures import nDCG, RR, MAP


# Baseline: plain BM25 evaluated against the dataset's qrels.
# NOTE(review): this cell requests variant="query", whereas the other cells in this
# notebook call get_topics(variant="text"); confirm the difference is intended,
# otherwise this baseline is not measured on the same query strings.
# NOTE(review): MAP is cut off at 1000 here but at 100 in the Fast-Forward
# experiments further down, so the AP columns are not directly comparable.
pt.Experiment(
    [bm25],
    dataset.get_topics(variant="query"),
    dataset.get_qrels(),
    eval_metrics=[RR @ 10, nDCG @ 10, MAP @ 1000],
)

Unnamed: 0,name,RR@10,nDCG@10,AP@1000
0,BR(BM25),0.817214,0.576109,0.18352


In [45]:
# Same BM25 pipeline and metrics as the baseline above, but run on the
# reduced (cleaned) topics loaded from reduced_trec_cov.csv.
pt.Experiment(
    [bm25],
    reduced_topics,
    dataset.get_qrels(),
    eval_metrics=[RR @ 10, nDCG @ 10, MAP @ 1000],
)

Unnamed: 0,name,RR@10,nDCG@10,AP@1000
0,BR(BM25),0.677024,0.468119,0.146861


In [52]:
# BM25 candidates for the reduced topics; `% 5` applies a rank cutoff
# (the displayed result has ranks 0-4 per qid).
candidates_reduced = (bm25 % 5)(reduced_topics)
candidates_reduced

Unnamed: 0,qid,docid,docno,rank,score,query
0,1,34849,6ck2ntid,0,17.748472,origin of covid19
1,1,94087,8a6flxl6,1,17.748472,origin of covid19
2,1,50316,myfw58yo,2,15.791321,origin of covid19
3,1,106877,offpyz12,3,15.791321,origin of covid19
4,1,92337,kqj4wwx5,4,15.426276,origin of covid19
...,...,...,...,...,...,...
48745,50,106269,puqcbf8t,0,23.643785,MRNA vaccine for SARSCoV2
48746,50,70315,4qcbwezv,1,23.194012,MRNA vaccine for SARSCoV2
48747,50,73900,ea1r90io,2,22.702625,MRNA vaccine for SARSCoV2
48748,50,156873,vui9ybc1,3,22.414568,MRNA vaccine for SARSCoV2


In [53]:
# BM25 candidates for the original "text" topics, with the same rank cutoff
# of 5 as used for the reduced topics.
candidates = (bm25 % 5)(dataset.get_topics(variant="text"))
candidates

Unnamed: 0,qid,docid,docno,rank,score,query
0,1,81848,dv9m19yk,0,15.639456,what is the origin of covid 19
1,1,103419,kgifmjvb,1,15.447703,what is the origin of covid 19
2,1,123191,wmfcey6f,2,15.423234,what is the origin of covid 19
3,1,67367,4dtk1kyh,3,14.592304,what is the origin of covid 19
4,1,97901,deee71uw,4,14.186018,what is the origin of covid 19
...,...,...,...,...,...,...
49000,50,117999,xbze5s3c,0,31.478040,what is known about an mrna vaccine for the sa...
49001,50,132765,aju2nr9x,1,29.408917,what is known about an mrna vaccine for the sa...
49002,50,149559,1v0f2dtx,2,29.211110,what is known about an mrna vaccine for the sa...
49003,50,75412,ptvsie6m,3,26.888638,what is known about an mrna vaccine for the sa...


## Arctic-Embed

In [54]:
from fast_forward.encoder import TransformerEncoder
import torch

class SnowFlakeDocumentEncoder(TransformerEncoder):
  """Document-side encoder: embeds texts as given, without any prefix.

  Each text is represented by the model output at sequence position 0,
  L2-normalised, and returned as a NumPy array.
  """

  def __call__(self, texts):
    """Encode ``texts`` and return one L2-normalised vector per input text."""
    encoded = self.tokenizer(
        texts,
        max_length=512,
        truncation=True,
        padding=True,
        return_tensors='pt',
    )
    # The return value is not reassigned: this relies on .to() updating the batch in place.
    encoded.to(self.device)
    self.model.eval()

    # Forward pass without gradient tracking; keep only the vector at position 0.
    with torch.no_grad():
        model_output = self.model(**encoded)
        first_token_vectors = model_output[0][:, 0]

    unit_vectors = torch.nn.functional.normalize(first_token_vectors, p=2, dim=1)
    return unit_vectors.detach().cpu().numpy()
  
class SnowFlakeQueryEncoder(TransformerEncoder):
  """Query-side encoder: prepends a fixed retrieval prefix before embedding.

  Each prefixed query is represented by the model output at sequence
  position 0, L2-normalised, and returned as a NumPy array.
  """

  def __call__(self, texts):
    """Encode ``texts`` as queries and return one L2-normalised vector per input."""
    query_prefix = 'Represent this sentence for searching relevant passages: '
    prefixed = [f"{query_prefix}{query}" for query in texts]

    encoded = self.tokenizer(
        prefixed,
        max_length=512,
        truncation=True,
        padding=True,
        return_tensors='pt',
    )
    # The return value is not reassigned: this relies on .to() updating the batch in place.
    encoded.to(self.device)
    self.model.eval()

    # Forward pass without gradient tracking; keep only the vector at position 0.
    with torch.no_grad():
        model_output = self.model(**encoded)
        first_token_vectors = model_output[0][:, 0]

    unit_vectors = torch.nn.functional.normalize(first_token_vectors, p=2, dim=1)
    return unit_vectors.detach().cpu().numpy()
  
# Both encoders load the same checkpoint, so the model is instantiated twice
# (the "Some weights ... newly initialized" warning appears twice in the output).
# NOTE(review): only the document encoder is given device="cuda:0"; the query encoder
# falls back to TransformerEncoder's default device - confirm this is intended.
doc_encoder_artic = SnowFlakeDocumentEncoder('Snowflake/snowflake-arctic-embed-m', device="cuda:0")
q_encoder_artic = SnowFlakeQueryEncoder('Snowflake/snowflake-arctic-embed-m')

Some weights of BertModel were not initialized from the model checkpoint at Snowflake/snowflake-arctic-embed-m and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertModel were not initialized from the model checkpoint at Snowflake/snowflake-arctic-embed-m and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [59]:
from fast_forward import OnDiskIndex, Mode
from pathlib import Path

# Open the Fast-Forward index stored on disk and attach the query encoder to it;
# mode=Mode.MAXP is passed as the scoring mode.
# NOTE(review): the .h5 file is not built in this notebook and must already exist.
ff_index_artic = OnDiskIndex.load(
    Path("../datam/ffindex_trec-covid_snowflake.h5"), query_encoder=q_encoder_artic, mode=Mode.MAXP
)

100%|██████████| 171332/171332 [00:00<00:00, 445337.00it/s]


In [60]:
# Replace the on-disk index with whatever to_memory() returns, under the same name.
# NOTE(review): because the name is rebound, re-running this cell alone calls
# to_memory() on the already-converted object rather than on the on-disk index.
ff_index_artic = ff_index_artic.to_memory()

In [61]:
from fast_forward.util.pyterrier import FFScore

# Wrap the index as an FFScore object so it can be called on result frames
# and chained with `>>` in the pipelines below.
ff_score_artic = FFScore(ff_index_artic)

In [62]:
# Score the BM25 candidates for the original topics with the Fast-Forward index.
# In the displayed result the incoming BM25 score appears as `score_0` and the
# new score as `score`.
re_ranked = ff_score_artic(candidates)  # TODO: time this step separately from the experiments
re_ranked

Unnamed: 0,qid,docno,score_0,score,query
0,1,dv9m19yk,15.639456,0.490659,what is the origin of covid 19
1,1,kgifmjvb,15.447703,0.462174,what is the origin of covid 19
2,1,wmfcey6f,15.423234,0.471067,what is the origin of covid 19
3,1,4dtk1kyh,14.592304,0.450673,what is the origin of covid 19
4,1,deee71uw,14.186018,0.434492,what is the origin of covid 19
...,...,...,...,...,...
245,50,xbze5s3c,31.478040,0.577503,what is known about an mrna vaccine for the sa...
246,50,aju2nr9x,29.408917,0.495342,what is known about an mrna vaccine for the sa...
247,50,1v0f2dtx,29.211110,0.501794,what is known about an mrna vaccine for the sa...
248,50,ptvsie6m,26.888638,0.472996,what is known about an mrna vaccine for the sa...


In [63]:
# Same scoring step as for the original topics, applied to the candidates
# retrieved with the reduced topics.
re_ranked_reduced = ff_score_artic(candidates_reduced)  # TODO: time this step separately from the experiments
re_ranked_reduced

Unnamed: 0,qid,docno,score_0,score,query
0,1,6ck2ntid,17.748472,0.442458,origin of covid19
1,1,8a6flxl6,17.748472,0.442458,origin of covid19
2,1,myfw58yo,15.791321,0.218028,origin of covid19
3,1,offpyz12,15.791321,0.218028,origin of covid19
4,1,kqj4wwx5,15.426276,0.359929,origin of covid19
...,...,...,...,...,...
245,50,puqcbf8t,23.643785,0.369228,MRNA vaccine for SARSCoV2
246,50,4qcbwezv,23.194012,0.356289,MRNA vaccine for SARSCoV2
247,50,ea1r90io,22.702625,0.370357,MRNA vaccine for SARSCoV2
248,50,vui9ybc1,22.414568,0.376446,MRNA vaccine for SARSCoV2


In [64]:
from fast_forward.util.pyterrier import FFInterpolate

# Interpolation stage used after FFScore in the pipelines below.
# NOTE(review): presumably combines `score_0` and `score` using weight alpha;
# confirm which of the two scores alpha=0.1 is applied to.
ff_int = FFInterpolate(alpha=0.1)

In [65]:
# Full pipeline on the original "text" topics:
# BM25 with a rank cutoff of 1000 -> Fast-Forward scoring -> score interpolation.
# NOTE(review): MAP is cut off at 100 here, while the BM25-only experiments use
# MAP @ 1000, so the AP columns of the two sets of results are not comparable.
pt.Experiment(
    [bm25 % 1000 >> ff_score_artic >> ff_int],
    dataset.get_topics(variant="text"),
    dataset.get_qrels(),
    eval_metrics=[RR @ 10, nDCG @ 10, MAP @ 100],
    names=["BM25 >> FF"],
)


Unnamed: 0,name,RR@10,nDCG@10,AP@100
0,BM25 >> FF,0.871667,0.672455,0.090754


In [66]:
# Same BM25 >> Fast-Forward pipeline and metrics as the previous experiment,
# evaluated on the reduced (cleaned) topics instead of the original ones.
pt.Experiment(
    [bm25 % 1000 >> ff_score_artic >> ff_int],
    reduced_topics,
    dataset.get_qrels(),
    eval_metrics=[RR @ 10, nDCG @ 10, MAP @ 100],
    names=["BM25 >> FF"],
)

Unnamed: 0,name,RR@10,nDCG@10,AP@100
0,BM25 >> FF,0.744024,0.528044,0.066244
