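# Baseline System

BM25 retrieval over the `cast_base` Elasticsearch index, optionally re-ranked with monoBERT, run on the TREC CAsT 2020 automatic evaluation topics.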

In [1]:
from elasticsearch  import Elasticsearch
from typing import Dict, List, Optional
import json

In [2]:
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
nltk.download('punkt')
# stopwords.words() needs the separate 'stopwords' corpus; downloading only
# 'punkt' makes the next line raise LookupError on a fresh machine.
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

from pygaggle.rerank.base import Query, Text
from pygaggle.rerank.transformer import MonoBERT

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/subankankarunakaran/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
2021-11-20 16:55:50 [INFO] loader: Loading faiss with AVX2 support.
2021-11-20 16:55:50 [INFO] loader: Could not load library with AVX2 support due to:
ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")
2021-11-20 16:55:50 [INFO] loader: Loading faiss.
2021-11-20 16:55:50 [INFO] loader: Successfully loaded faiss.


In [4]:
# Elasticsearch index name; CAsT_base hard-codes the same value in self.INDEX_NAME.
INDEX_NAME: str = "cast_base"

In [5]:
# Module-level Elasticsearch client. No hosts are passed, so elasticsearch-py's
# defaults apply (localhost:9200 in the 7.x client — confirm for the installed version).
es = Elasticsearch()

In [6]:
class CAsT_base():
    """Baseline retrieval system: BM25 search in Elasticsearch, optionally
    re-ranked with pygaggle's MonoBERT.

    ``queries`` and ``responses`` are kept as (clearable) conversation state,
    but ``query`` below does not read them.
    """

    def __init__(self, context_responses: int = 0, reranking: bool = False, remove_stopwords: bool = True) -> None:
        """
        Args:
            context_responses: stored on the instance; not read by any method here.
            reranking: if True, load MonoBERT and re-score the BM25 hits with it.
            remove_stopwords: if True, strip English stopwords and pure-punctuation
                tokens from the query before it is sent to Elasticsearch.
        """
        self.INDEX_NAME = "cast_base"
        self.es = Elasticsearch()
        self.queries = []
        self.responses = []
        self.reranking = reranking
        self.reranker = MonoBERT() if reranking else None
        self.context_responses = context_responses
        self.remove_stopwords = remove_stopwords
        # Load the stopword list once rather than on every query() call.
        self.stop_words = set(stopwords.words('english'))

    def clear_context(self, clear_queries: bool = True, clear_responses: bool = True):
        """Reset the stored query and/or response history."""
        if clear_queries:
            self.queries = []
        if clear_responses:
            self.responses = []

    def listToString(self, s: List) -> str:
        """Join the items of ``s`` with single spaces."""
        return " ".join(s)

    def _preprocess(self, q: str) -> str:
        """Return the query text that is sent to Elasticsearch.

        When ``remove_stopwords`` is set, the query is tokenized with NLTK. A token
        is dropped if it is a stopword (compared case-insensitively, because
        NLTK's list is lowercase) or has no alphanumeric character. Bare
        punctuation such as ``?`` or ``/`` has special meaning in Lucene
        query-string syntax.
        """
        if not self.remove_stopwords:
            return q
        kept = [
            tok for tok in word_tokenize(q)
            if tok.lower() not in self.stop_words and any(ch.isalnum() for ch in tok)
        ]
        return self.listToString(kept)

    @staticmethod
    def _clean_hits(hits: List[Dict]) -> List[Dict]:
        """Flatten raw ES hits to ``{passage, _id, _score}`` dicts.

        ``_id`` is prefixed with ``MARCO_`` when the document's ``origin`` is
        ``msmarco``, and with ``CAR_`` otherwise.
        """
        cleaned = []
        for hit in hits:
            source = hit.get("_source") or {}
            prefix = "MARCO_" if source.get("origin") == "msmarco" else "CAR_"
            cleaned.append({
                "passage": source.get("passage"),
                "_id": prefix + hit.get("_id"),
                "_score": hit.get("_score", "FAILED"),
            })
        return cleaned

    def _rerank(self, q: str, hits_cleaned: List[Dict]) -> List[Dict]:
        """Re-score ``hits_cleaned`` with MonoBERT; return them best-first."""
        texts = [Text(hit.get("passage"), {'_id': hit.get("_id", "FAILED")}, 0)
                 for hit in hits_cleaned]
        reranked = self.reranker.rerank(Query(q), texts)
        # Some pygaggle versions return rescored texts in input order. Sort here
        # so list position (callers use it as the rank) agrees with the score.
        reranked = sorted(reranked, key=lambda t: t.score, reverse=True)
        return [{"passage": t.text, "_id": t.metadata["_id"], "_score": t.score}
                for t in reranked]

    def query(self, q: str) -> List[Dict]:
        """Preprocess ``q``, retrieve up to 100 passages with BM25, and optionally re-rank.

        Returns:
            A list of dicts with keys ``passage``, ``_id`` and ``_score``, or
            ``[]`` if Elasticsearch returned no hits.
        """
        q = self._preprocess(q)
        response = self.es.search(index=self.INDEX_NAME, q=q, _source=True, size=100)
        # A response without hits must not crash the loop below.
        hits = response.get("hits", {}).get("hits") or []
        if not hits:
            return []

        hits_cleaned = self._clean_hits(hits)
        if self.reranking:
            print("RERANKING")
            hits_cleaned = self._rerank(q, hits_cleaned)

        print("Query: " + q)
        return hits_cleaned

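## Running queries

`run_queries` sends every turn of every topic to a system object and writes the hits to a TREC run file named `<run_id>.trec`.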

In [8]:
def run_queries(query_file: str, key: str, CAsT: object, run_id: str):
    """Query ``CAsT`` with every turn of every topic and write a TREC run file.

    Args:
        query_file: path to a JSON list of topics. Each topic has a ``number`` and
            a ``turn`` list whose entries have a ``number`` and the ``key`` field.
        key: turn field to use as the query text (e.g. ``"raw_utterance"``).
        CAsT: object with ``clear_context()`` and ``query(text)``. ``query`` returns
            a list of dicts carrying ``_id`` and ``_score``, best first.
        run_id: run tag written in the last column; output goes to ``<run_id>.trec``.

    Raises:
        ValueError: if the query file contains no topics.
        KeyError: if the first turn of the first topic lacks ``key``.
    """
    with open(query_file) as fh:
        queries = json.load(fh)
    if not queries:
        raise ValueError("Query file contains no topics: " + query_file)
    first_turns = queries[0].get("turn") or [{}]
    if first_turns[0].get(key) is None:
        raise KeyError("Provided key: " + key +
                       " is not a valid key for queryfile")

    total_num = len(queries)
    with open(run_id + ".trec", "w") as out:
        for i, topic in enumerate(queries, start=1):
            print("Topic: {}/{}".format(i, total_num))
            # Reset per-topic state before the topic's first turn.
            CAsT.clear_context()
            topic_id = topic.get("number")
            for turn in topic.get("turn", []):
                turn_id = turn.get("number")
                hits = CAsT.query(turn.get(key))
                # TREC run line: qid Q0 docno rank score tag (ranks are 1-based).
                for rank, hit in enumerate(hits, start=1):
                    out.write("\t".join([
                        "{}_{}".format(topic_id, turn_id),
                        "Q0",
                        str(hit.get("_id")),
                        str(rank),
                        str(hit.get("_score")),
                        str(run_id),
                    ]) + "\n")

In [9]:
# Turn field used as query text, and the CAsT 2020 automatic-evaluation topic file.
key: str = "raw_utterance"
path: str = "../eval/2020_automatic_evaluation_topics_v1.0.json"

In [None]:
# BM25 + MonoBERT re-ranking with stopword removal; the run is written to Test.trec.
cast = CAsT_base(remove_stopwords=True, reranking=True)
run_queries(query_file=path, key=key, CAsT=cast, run_id="Test")