In [None]:
import json
import sys
import pandas as pd
import nltk
from nltk.tokenize import word_tokenize

from rank_bm25 import BM25Okapi
from nltk.tokenize import word_tokenize
from sentence_transformers import SentenceTransformer
import torch

#Use the BM25 algorithm to find the most relevant table based on the query tokens and the corpus.
def bm25_retrieve_relevant_table(query, corpus, list_tables_info, actual_table_info, top_k_list):
    """Report, for each cut-off k, whether the expected table is in the BM25 top-k.

    Args:
        query: Query string; tokenized with ``word_tokenize``.
        corpus: Texts of the candidate tables. Each entry is tokenized with
            ``word_tokenize`` and indexed with ``BM25Okapi``.
        list_tables_info: Candidate entries, index-aligned with ``corpus``
            (entry ``i`` is what gets returned for document ``i``).
        actual_table_info: The entry that counts as a hit; compared to the
            ranked entries with ``in`` (i.e. by equality).
        top_k_list: Cut-off values k to evaluate.

    Returns:
        A dict mapping every k in ``top_k_list`` to ``True`` when
        ``actual_table_info`` is among the k highest-scoring entries,
        ``False`` otherwise. An empty ``top_k_list`` yields ``{}``.
    """
    top_k_list = list(top_k_list)
    if not top_k_list:
        # Nothing to evaluate; max() below would raise ValueError on an empty list.
        return {}

    corpus = list(corpus)
    if not corpus:
        # No candidates, so the expected table cannot be retrieved at any k.
        # NOTE(review): BM25Okapi presumably cannot be built from an empty
        # corpus (it averages over the document count) - confirm.
        return {k: False for k in top_k_list}

    query_tokens = word_tokenize(query)
    corpus_tokenize = [word_tokenize(table) for table in corpus]

    bm25 = BM25Okapi(corpus_tokenize)
    scores = bm25.get_scores(query_tokens)

    # Rank document indices by descending score once, keeping only as many as
    # the largest requested cut-off; smaller cut-offs are prefixes of this list.
    largest_k = max(top_k_list)
    ranked_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    top_k_tables = [list_tables_info[i] for i in ranked_indices[:largest_k]]

    # Slicing past the end is safe, so k may exceed the number of candidates.
    return {k: actual_table_info in top_k_tables[:k] for k in top_k_list}

# Fetch the NLTK tokenizer resources at import time; nltk.download is called
# once per resource, in this order.
for _resource in ("punkt", "punkt_tab"):
    nltk.download(_resource)
def preprocess(text):
    """Lowercase ``text`` and return its tokens as produced by ``word_tokenize``."""
    lowered = text.lower()
    tokens = word_tokenize(lowered)
    return tokens

In [None]:

def table_to_text(table):
    """
    Flatten a table (an iterable of rows) into a single string.

    Every cell is converted with ``str``. The cells of a row are joined with
    a single space, and the resulting row strings are joined with a single
    space as well (no newlines are inserted).
    """
    row_texts = []
    for row in table:
        cells = [str(cell) for cell in row]
        row_texts.append(" ".join(cells))
    return " ".join(row_texts)



if __name__ == "__main__":
    # Evaluate BM25 retrieval on each ToTTo retrieval set and write one CSV of
    # per-query hit flags per set.
    import os  # local import: only the script entry point needs it

    # Retrieval cut-offs evaluated for every query.
    top_k_list = [1, 5, 10, 15, 20]
    # Table representations; each data point stores one under the key
    # "list_<name>_retrieval".
    table_info_kinds = [
        "title_tab-description",
        "title_column_header",
        "title_col_table",
        "exact_row",
    ]

    # DataFrame.to_csv does not create missing directories.
    os.makedirs("./results", exist_ok=True)

    for k in [50, 100, 200, 500]:
        totto_retrieval_dataset_path = f"./data/retrieval_data/totto_retrieval_{k}.json"
        results_path = f"./results/results_{k}.csv"

        # Start from an empty list for every dataset so that results_{k}.csv
        # and the printed mean cover only this dataset, not rows carried over
        # from the previously processed ones.
        result = []

        with open(totto_retrieval_dataset_path, "r", encoding="utf-8") as file:
            data = json.load(file)
        print("total tables: ", len(data))

        for idx, data_point in enumerate(data):
            # Prepare the query text
            query_text = f"{data_point['summary']}".lower()
            is_progress_step = idx % 50 == 0
            if is_progress_step:
                print(f"{idx}", flush=True)

            for table_info in table_info_kinds:
                # Each entry unpacks into (expected entry, all candidate entries).
                actual_table_info, list_all_table_info = data_point[f"list_{table_info}_retrieval"]
                corpus = [table.lower() for table in list_all_table_info]
                bm25_top_k_found_it = bm25_retrieve_relevant_table(
                    query_text, corpus, list_all_table_info, actual_table_info, top_k_list
                )

                result.extend(
                    {
                        "idx": idx,
                        "top_k": top_k,
                        "table_info": table_info,
                        "bm_25": bm25_top_k_found_it[top_k],
                    }
                    for top_k in top_k_list
                )
                if is_progress_step:
                    print(f"\t{table_info}: BM: {bm25_top_k_found_it}", flush=True)

            if is_progress_step:
                # Periodic checkpoint: print the running hit rate and save the
                # rows gathered so far, so an interrupted run keeps its work.
                df = pd.DataFrame(result)
                print(df["bm_25"].mean())
                df.to_csv(results_path, index=False)

        if result:
            df = pd.DataFrame(result)
            print(df["bm_25"].mean())
            df.to_csv(results_path, index=False)
        else:
            # An empty dataset yields a frame without a "bm_25" column.
            print(f"no results for {totto_retrieval_dataset_path}")