In [None]:
import heapq
import json
import os
from typing import Any, Dict, Generator, List, Optional

from llama_index import Document
from rank_bm25 import BM25Okapi
from tqdm import tqdm

# Staging flag. When True, the __main__ block runs on the small sample files
# (data/sample-corpus.json and data/sample-rag.json). Set it to False to run
# on the real collection (data/corpus.json and data/rag.json).
STAGING=True

def save_list_to_json(lst, filename):
  """Serialize ``lst`` as JSON into ``filename``.

  The target is opened in write mode, so any existing content is replaced.
  """
  with open(filename, mode='w') as out_handle:
    json.dump(lst, fp=out_handle)

def wr_dict(filename, dic):
  """Append ``dic`` to the JSON list stored in ``filename``.

  If ``filename`` does not exist it is created holding ``[dic]``. Otherwise
  the existing JSON list is loaded, ``dic`` is appended, and the file is
  rewritten with the extended list.

  This is a best-effort helper: any failure (unreadable or malformed file,
  stored value that is not a list, non-serializable ``dic``, I/O error) is
  reported on stdout as ``Save Error: ...`` and never raised.

  Args:
    filename: Path of the JSON file to create or extend.
    dic: JSON-serializable entry to append.

  Returns:
    None.
  """
  try:
    if os.path.isfile(filename):
      with open(filename, 'r') as f:
        data = json.load(f)
      data.append(dic)
    else:
      data = [dic]
    # Serialize before opening the file for writing: opening with 'w'
    # truncates it, so a serialization failure after that point would
    # destroy the entries that were already stored.
    payload = json.dumps(data)
    with open(filename, 'w') as f:
      f.write(payload)
  except Exception as e:
    # Deliberately broad: callers rely on this helper never raising.
    print("Save Error:", str(e))
            
def rm_file(file_path):
  """Delete ``file_path`` if it exists and report the removal on stdout.

  A path that does not exist is skipped silently.
  """
  if not os.path.exists(file_path):
    return
  os.remove(file_path)
  print(f"File {file_path} removed successfully.")

def _depth_first_yield(json_data: Any, levels_back: int, collapse_length: 
                       Optional[int], path: List[str], ensure_ascii: bool = False,
                      ) -> Generator[str, None, None]:
  """ Do depth first yield of all of the leaf nodes of a JSON.
      Combines keys in the JSON tree using spaces.
      If levels_back is set to 0, prints all levels.
      If collapse_length is not None and the json_data is <= that number
      of characters, then we collapse it into one line.
  """
  if isinstance(json_data, (dict, list)):
    # only try to collapse if we're not at a leaf node
    json_str = json.dumps(json_data, ensure_ascii=ensure_ascii)
    if collapse_length is not None and len(json_str) <= collapse_length:
      new_path = path[-levels_back:]
      new_path.append(json_str)
      yield " ".join(new_path)
      return
    elif isinstance(json_data, dict):
      for key, value in json_data.items():
        new_path = path[:]
        new_path.append(key)
        yield from _depth_first_yield(value, levels_back, collapse_length, new_path)
    elif isinstance(json_data, list):
      for _, value in enumerate(json_data):
        yield from _depth_first_yield(value, levels_back, collapse_length, path)
    else:
      new_path = path[-levels_back:]
      new_path.append(str(json_data))
      yield " ".join(new_path)


# JSONReader parses the JSON corpus; the queries are read with json.load in gen_bm25.
class JSONReader():
  """Reader that turns a JSON corpus file into ``Document`` objects.

  The corpus file is expected to hold a JSON array of objects, each with
  the keys ``title``, ``published_at``, ``source`` and ``body``.
  """
  def __init__(self, is_jsonl: Optional[bool] = False,) -> None:
    """Store the reader options.

    Args:
      is_jsonl: Kept on the instance as ``self.is_jsonl``. NOTE(review):
        ``load_data`` does not consult this flag; the input is always
        parsed as a single JSON document.
    """
    super().__init__()
    self.is_jsonl = is_jsonl

  def load_data(self, input_file: str) -> List[Document]:
    """Load every corpus record from ``input_file``.

    Args:
      input_file: Path to the JSON corpus file.

    Returns:
      One ``Document`` per record, in file order, with ``body`` as the text
      and ``title``, ``published_at`` and ``source`` as metadata.

    Raises:
      KeyError: If a record lacks one of the required keys.
    """
    # Read as UTF-8 explicitly: the platform default encoding can
    # mis-decode or reject non-ASCII article text.
    with open(input_file, 'r', encoding='utf-8') as file:
      records = json.load(file)
    return [
      Document(
        text=record['body'],
        metadata={
          "title": record['title'],
          "published_at": record['published_at'],
          "source": record['source'],
        },
      )
      for record in records
    ]
    

In [None]:
def gen_bm25(corpus: str, queries: str, output_name: str, top_k: int = 10) -> None:
    """Run BM25 retrieval for every query and save the ranked passages as JSON.

    The corpus is loaded with ``JSONReader``, indexed with ``BM25Okapi``
    using single-space tokenization, and each query is scored against every
    corpus text. The ``top_k`` best-scoring texts are stored per query
    together with the query's answer, question type and evidence list.

    Args:
        corpus: Path to the corpus JSON file.
        queries: Path to the queries JSON file: a list of objects with the
            keys ``query``, ``answer``, ``question_type`` and
            ``evidence_list``.
        output_name: Path of the JSON file to write. An existing file is
            removed first and missing parent directories are created.
        top_k: Number of passages kept per query. Defaults to 10.

    Raises:
        ValueError: If the corpus contains no documents.
    """
    print('Remove save file if exists.')
    rm_file(output_name)

    # Create the output directory up front so a missing folder cannot fail
    # the final write after all retrieval work has been done.
    out_dir = os.path.dirname(output_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Read the corpus json file
    reader = JSONReader()
    documents = reader.load_data(corpus)
    if not documents:
        raise ValueError(f"Corpus file {corpus!r} contains no documents.")

    print('Corpus Data')
    print('--------------------------')
    print(documents[0])
    print('--------------------------')

    corpus_texts = [doc.text for doc in documents]
    # Queries are tokenized the same way below; keep the two in sync.
    tokenized_corpus = [text.split(" ") for text in corpus_texts]

    # Initialize BM25
    bm25 = BM25Okapi(tokenized_corpus)
    print('BM25 Initialized ...')

    # Parse the queries
    with open(queries, 'r', encoding='utf-8') as file:
        query_data = json.load(file)

    print('Query Data')
    print('--------------------------')
    if query_data:
        print(query_data[0])
    print('--------------------------')

    retrieval_save_list = []
    print("Running BM25 Retrieval ...")

    for entry in tqdm(query_data):
        scores = bm25.get_scores(entry['query'].split(" "))

        # Same result as sorting all (score, text) pairs in descending order
        # and slicing, without sorting the whole corpus for every query.
        top_results = heapq.nlargest(top_k, zip(scores, corpus_texts))

        retrieval_list = [
            {'text': text, 'score': float(score)}
            for score, text in top_results
        ]

        retrieval_save_list.append({
            'query': entry['query'],
            'answer': entry['answer'],
            'question_type': entry['question_type'],
            'retrieval_list': retrieval_list,
            'gold_list': entry['evidence_list'],
        })

    print('Retrieval complete. Saving Results')
    with open(output_name, 'w') as json_file:
        json.dump(retrieval_save_list, json_file)

if __name__ == '__main__':
    # STAGING selects the small sample files; otherwise the full collection
    # is used.
    corpus, queries = (
        ("data/sample-corpus.json", "data/sample-rag.json")
        if STAGING
        else ("data/corpus.json", "data/rag.json")
    )
    output_name = "output/bm25-retrieval.json"

    gen_bm25(corpus=corpus, queries=queries, output_name=output_name)