In [None]:
import json
import numpy as np
from intent_exs import IntentEXS
from pyserini.search.lucene import LuceneSearcher  
# Path of the pre-built Lucene index (relative to the notebook's working directory).
# The same path is handed to IntentEXS further down, so both stages read one index.
index_path = 'datasets/dbpedia-entity/pyserini/dbpedia-entity-small.index'
searcher = LuceneSearcher(index_path)   # load a searcher from pre-computed index.

In [2]:
# First-stage retrieval: run the free-text query against the Lucene index.
query = 'Szechwan dish food cuisine'
hits = searcher.search(query)
# Print up to the first 10 hits as "rank docid score".
# Iterating over a slice (instead of indexing positions 0..9) avoids an
# IndexError when the search returns fewer than 10 results.
for rank, hit in enumerate(hits[:10], start=1):
    print(f'{rank:2} {hit.docid:15} {hit.score:.5f}')

 1 <dbpedia:Dish_(food)> 8.47820
 2 <dbpedia:National_dish> 7.26490
 3 <dbpedia:2007_Vietnam_food_scare> 7.10600
 4 <dbpedia:Side_dish> 7.02890
 5 <dbpedia:Chifle> 6.98190
 6 <dbpedia:Food_presentation> 6.96550
 7 <dbpedia:Street_food_of_Chennai> 6.77350
 8 <dbpedia:Ragda_pattice> 6.73670
 9 <dbpedia:Khichdi> 6.63490
10 <dbpedia:Khmer_(food)> 6.49920


In [3]:
# Keep the retrieved doc ids in retrieval order, and map every id to the
# 'contents' field of its stored raw JSON document.
doc_ids = [hit.docid for hit in hits]
docs = {
    hit.docid: json.loads(searcher.doc(hit.docid).raw())['contents']
    for hit in hits
}

In [4]:
# Load a reranking model
# `model` is the checkpoint name handed to BEIR's CrossEncoder wrapper; the
# resulting `reranker` is later asked to score [query, document] pairs via
# `reranker.predict` and is also passed to IntentEXS.
from beir.reranking.models import CrossEncoder
model = 'cross-encoder/ms-marco-electra-base'
reranker = CrossEncoder(model)

In [None]:
# Pair the query with the text of every retrieved document (same order as
# doc_ids) and let the reranker score all pairs in batches of 10.
sentence_pairs = [[query, docs[doc_id]] for doc_id in doc_ids]
rerank_scores = reranker.predict(sentence_pairs, batch_size=10)

In [17]:
# Show the doc ids ordered by reranker score, highest score first.
# argsort yields ascending positions, so the index array is reversed.
descending_order = np.argsort(rerank_scores)[::-1]
reranked_docids = np.array(doc_ids)[descending_order]
for doc_id in reranked_docids:
    print(doc_id)

<dbpedia:Khichdi>
<dbpedia:Ragda_pattice>
<dbpedia:Dish_(food)>
<dbpedia:National_dish>
<dbpedia:Street_food_of_Chennai>
<dbpedia:Khmer_(food)>
<dbpedia:Food_presentation>
<dbpedia:Chifle>
<dbpedia:Side_dish>
<dbpedia:2007_Vietnam_food_scare>


In [6]:
# Input bundle for IntentEXS.explain: the query string, a doc_id -> reranker
# score mapping, and the doc_id -> text mapping built earlier.
corpus = {
    'query': query,
    'scores': dict(zip(doc_ids, rerank_scores)),
    'docs': docs,
}
# Settings forwarded to IntentEXS.explain alongside the corpus.
params = {
    'top_idf': 10,
    'topk': 5,
    'max_pair': 100,
    'max_intent': 10,
    'style': 'random',
}

In [7]:
# Init the IntentEXS object.
# It receives the reranker to be explained, the path of the Lucene index, and
# the string 'bm25' (presumably the name of the scoring scheme it uses
# internally -- TODO confirm against intent_exs).
Intent = IntentEXS(reranker, index_path, 'bm25')

In [None]:
# Run the explanation on the corpus bundle (query, scores, docs) with the
# settings from `params`; the result is kept in `expansion`.
expansion = Intent.explain(corpus, params)

In [9]:
# Display the explanation result (a list of terms in the recorded output).
expansion

['dish', 'indian', 'many', 'variety', 'bhel', 'traditional', 'rice']