In [1]:
!pip install tqdm rank_bm25 tiktoken nltk

Defaulting to user installation because normal site-packages is not writeable


In [2]:
import json
import os
import csv
from tqdm import tqdm
from rank_bm25 import BM25Okapi
import numpy as np
import random
import re
import tiktoken
import nltk
from nltk.corpus import stopwords
nltk.download('stopwords')

stop_words = set(stopwords.words('english'))

def tokenize(text):
    # Lowercase
    text = text.lower()

    # Remove punctuation but keep pipes/colons
    text = re.sub(r"[^a-z0-9|: ]+", " ", text)

    # Tokenize using tiktoken
    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.decode(enc.encode(text)).split()

    # Remove stopwords
    filtered_tokens = [token for token in tokens if token not in stop_words and len(token) > 1]

    return filtered_tokens

################################ Chunking #################################

def chunk_rows_in_groups(rows, table_name, columns, group_size=3):
    grouped_chunks = []
    for i in range(0, len(rows), group_size):
        subset = rows[i:i + group_size]
        text = ' | '.join([
            f"{columns[j]['text']}: {row['cells'][j]['text']}"
            for row in subset
            for j in range(len(columns))
            if columns[j]['text']
        ])
        chunk = {
            "text": text,
            "metadata": {
                "chunk_id": f"{table_name}_group_{i // group_size}",
                "table_name": table_name,
                "chunk_type": "row_group",
                "metadata_text": f"grouped rows from {table_name} starting at row {i}"
            }
        }
        grouped_chunks.append(chunk)
    return grouped_chunks

def chunk_column(rows, col_id, col_name, table_name):
    column_text = ' | '.join([row['cells'][col_id]['text'] for row in rows if row['cells'][col_id]['text']])
    return {
        "text": f"{col_name if col_name else ''}: {column_text}",
        "metadata": {
            "table_name": table_name,
            "col_id": col_id,
            "chunk_id": f"{table_name}_column_{col_id}",
            "chunk_type": "column",
            "metadata_text": f"table: {table_name}, col: {col_name if col_name else ''}, chunk_id: {table_name}_column_{col_id}, chunk_type: column"
        }
    }

def chunk_table(rows, table_id, columns):
    column_names = " | ".join([col['text'] for col in columns])
    table_text = '\n'.join([column_names] + [' | '.join([cell['text'] for cell in row['cells']]) for row in rows])
    return {
        "text": table_text,
        "metadata": {
            "table_name": table_id,
            "chunk_id": f"{table_id}_table",
            "chunk_type": "table",
            "columns": [col["text"] for col in columns],
            "metadata_text": f"table_name: {table_id}, chunk_id: {table_id}_table, chunk_type: table, columns: {', '.join([col['text'] for col in columns])}"
        }
    }

######################## Processing ##################################

def process_jsonl(file_path):
    metadata_list, chunks, table_chunks = [], [], []

    with open(file_path, 'r') as f:
        for line in tqdm(f):
            data = json.loads(line.strip())
            table_id = data['tableId']
            rows = data['rows']
            columns = data['columns']

            grouped_row_chunks = chunk_rows_in_groups(rows, table_id, columns, group_size=3)
            for chunk in grouped_row_chunks:
                chunks.append(chunk)
                metadata_list.append(chunk["metadata"])

            # for col_id, col in enumerate(columns):
            #     if col["text"]:
            #         col_chunk = chunk_column(rows, col_id, col["text"], table_id)
            #         chunks.append(col_chunk)
            #         metadata_list.append(col_chunk["metadata"])

            table_chunk = chunk_table(rows, table_id, columns)
            chunks.append(table_chunk)
            table_chunks.append(table_chunk)

    return metadata_list, chunks, table_chunks

def rank_chunks_with_bm25(bm25, tokenized_chunks, query, top_n):
    scores = bm25.get_scores(query)
    ranked_chunks = sorted(zip(scores, tokenized_chunks), reverse=True, key=lambda x: x[0])
    return ranked_chunks[:top_n]

def save_top_chunks(top_chunks, output_dir, output_filename):
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_filename)
    with open(output_path, 'w') as f:
        json.dump(top_chunks, f, indent=2)
    print(f"Saved top chunks to {output_path}")

def calculate_recall(ranked_chunks, correct_table_id, top_n):
    for idx, (_, chunk) in enumerate(ranked_chunks):
        if chunk['table_id'] == correct_table_id:
            rank = idx + 1
            is_in_top_10 = 1 if rank <= top_n * 0.1 else 0
            is_in_top_20 = 1 if rank <= top_n * 0.2 else 0
            return 1, rank, is_in_top_10, is_in_top_20
    return 0, None, 0, 0

def save_tokenized_chunks(tokenized_chunks, filepath):
    with open(filepath, 'w', encoding='utf-8') as f:
        for chunk in tokenized_chunks:
            json.dump(chunk, f)
            f.write('\n')
    print(f"Saved tokenized chunks to {filepath}")

def load_tokenized_chunks(filepath):
    tokenized_chunks = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            tokenized_chunks.append(json.loads(line.strip()))
    print(f"Loaded {len(tokenized_chunks)} tokenized chunks from {filepath}")
    return tokenized_chunks

############################## MAIN #####################################

def main(tables_file_path, file_paths, output_dir, query_output_dir, top_n_values, queries_count, saved_queries_path, tokenized_chunks_file=None):
    # Check if the random queries already exist
    if os.path.exists(saved_queries_path):
        print(f"Loading saved queries from {saved_queries_path}")
        with open(saved_queries_path, 'r') as f:
            selected_queries = [json.loads(line) for line in f][500-queries_count:]
    else:
        # Combine all queries from the three files
        all_queries = []
        for file_path in file_paths:
            with open(file_path, 'r') as f:
                for line in f:
                    data = json.loads(line.strip())
                    all_queries.append(data)
        
        # Shuffle queries and select the random queries_count
        random.shuffle(all_queries)
        selected_queries = all_queries[:queries_count]
        
        # Save the random queries to a JSONL file for future use
        print(f"Saving random queries to {saved_queries_path}")
        with open(saved_queries_path, 'w') as f:
            for query in selected_queries:
                json.dump(query, f)
                f.write("\n")

    if tokenized_chunks_file and os.path.exists(tokenized_chunks_file):
        tokenized_chunks = load_tokenized_chunks(tokenized_chunks_file)
    else:
        metadata, chunks, table_chunks = process_jsonl(tables_file_path)
        chunks = sorted(chunks, key=lambda x: x["metadata"]["table_name"])

        tokenized_chunks = []
        for chunk in tqdm(chunks, desc="Tokenizing Chunks", unit="chunk"):
            table_id = chunk['metadata']['table_name']
            tokenized_text = tokenize(chunk['text'] + str(chunk['metadata']))
            tokenized_chunks.append({
                "table_id": table_id,
                "tokenized_text": tokenized_text,
            })
        if tokenized_chunks_file:
            save_tokenized_chunks(tokenized_chunks, tokenized_chunks_file)

    bm25 = BM25Okapi([chunk['tokenized_text'] for chunk in tokenized_chunks], k1=1.5, b=0.75)
    total_recall = {top_n: 0 for top_n in top_n_values}
    total_top_10 = {top_n: 0 for top_n in top_n_values}
    total_top_20 = {top_n: 0 for top_n in top_n_values}
    results = {top_n: [] for top_n in top_n_values}

    for top_n in top_n_values:
        queries = [(tokenize(q['questions'][0]['originalText']), q['questions'][0]['originalText'], q['table']['tableId'], q['questions'][0]['answer']['answerTexts'][0]) for q in selected_queries]
        scores = [bm25.get_scores(query) for query, _, _, _ in queries]
        
        top_chunks_output_path = os.path.join(output_dir, f"top_chunks_top_{top_n}.jsonl")
        with open(top_chunks_output_path, 'w', encoding='utf-8') as jsonl_file:
            # Create a folder for this top_n value
            query_output_dir = os.path.join(output_dir, f"top_chunks_top_{top_n}")
            os.makedirs(query_output_dir, exist_ok=True)
            
            for i, (query, query_text, correct_table_id, answer) in enumerate(tqdm(queries, desc=f"Processing Top {top_n} Queries", unit="query")):                
                ranked_chunks = sorted(zip(scores[i], tokenized_chunks), reverse=True, key=lambda x: x[0])[:top_n]
                recall, rank, is_in_top_10, is_in_top_20 = calculate_recall(ranked_chunks, correct_table_id, top_n)

                rows = []

                for score, chunk in ranked_chunks:
                    rows.append({
                        "query": query_text,
                        "top tables": chunk["table_id"],
                        "target table": correct_table_id,
                        "target answer": answer,
                        "score": float(score)
                    })
    
                total_recall[top_n] += recall
                total_top_10[top_n] += is_in_top_10
                total_top_20[top_n] += is_in_top_20
    
                results[top_n].append({
                    "Recall": recall * 100,
                    "Rank": rank if rank is not None else "Not found",
                    "Ans table in top 10%": is_in_top_10 * 100,
                    "Ans table in top 20%": is_in_top_20 * 100
                })

                # Save each query result as an individual CSV file
                csv_file_path = os.path.join(query_output_dir, f"query_{i}_top_{top_n}.csv")
                with open(csv_file_path, 'w', newline='', encoding='utf-8') as csvfile:
                    writer = csv.DictWriter(csvfile, fieldnames=["query", "top tables", "target table", "target answer", "score"])
                    writer.writeheader()
                    writer.writerows(rows)

        recall_percentage = (total_recall[top_n] / queries_count) * 100
        top_10_percent = (total_top_10[top_n] / queries_count) * 100
        top_20_percent = (total_top_20[top_n] / queries_count) * 100

        print(f"Overall Recall (Top {top_n}): {recall_percentage:.2f}%, Top 10%: {top_10_percent:.2f}%, Top 20%: {top_20_percent:.2f}%")
        csv_filename = f"query_results_top_{top_n}.csv"
        with open(os.path.join(output_dir, csv_filename), 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=["Recall", "Rank", "Ans table in top 10%", "Ans table in top 20%"])
            writer.writeheader()
            writer.writerows(results[top_n])

######################## Execution Parameters ########################

# Execution parameters
tables_file_path = "tables.jsonl"
output_dir = "result"
query_output_dir = "query_results"
tokenized_chunks_file = "tokenized_chunks.jsonl"
saved_queries_path = "saved_random_queries.jsonl"
file_paths = ["dev.jsonl", "train.jsonl", "test.jsonl"]
top_n = [5000, 2500, 1250, 625, 312, 156, 78, 39, 18, 10]
queries_count = 150

main(tables_file_path, file_paths, output_dir, query_output_dir, top_n, queries_count, saved_queries_path, tokenized_chunks_file)

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/sjain300/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


Loading saved queries from saved_random_queries.jsonl


169898it [00:12, 13619.97it/s]
Tokenizing Chunks: 100%|██████████| 850707/850707 [04:13<00:00, 3349.58chunk/s]


Saved tokenized chunks to tokenized_chunks.jsonl


Processing Top 5000 Queries: 100%|██████████| 150/150 [01:10<00:00,  2.12query/s]


Overall Recall (Top 5000): 97.33%, Top 10%: 85.33%, Top 20%: 89.33%


Processing Top 2500 Queries: 100%|██████████| 150/150 [01:05<00:00,  2.29query/s]


Overall Recall (Top 2500): 94.00%, Top 10%: 82.67%, Top 20%: 85.33%


Processing Top 1250 Queries: 100%|██████████| 150/150 [01:04<00:00,  2.34query/s]


Overall Recall (Top 1250): 92.00%, Top 10%: 75.33%, Top 20%: 82.67%


Processing Top 625 Queries: 100%|██████████| 150/150 [01:12<00:00,  2.06query/s]


Overall Recall (Top 625): 85.33%, Top 10%: 67.33%, Top 20%: 75.33%


Processing Top 312 Queries: 100%|██████████| 150/150 [01:29<00:00,  1.68query/s]


Overall Recall (Top 312): 84.00%, Top 10%: 61.33%, Top 20%: 67.33%


Processing Top 156 Queries: 100%|██████████| 150/150 [01:38<00:00,  1.52query/s]


Overall Recall (Top 156): 76.00%, Top 10%: 52.67%, Top 20%: 61.33%


Processing Top 78 Queries: 100%|██████████| 150/150 [01:12<00:00,  2.08query/s]


Overall Recall (Top 78): 69.33%, Top 10%: 44.00%, Top 20%: 52.67%


Processing Top 39 Queries: 100%|██████████| 150/150 [01:09<00:00,  2.17query/s]


Overall Recall (Top 39): 64.00%, Top 10%: 40.00%, Top 20%: 44.00%


Processing Top 18 Queries: 100%|██████████| 150/150 [01:08<00:00,  2.19query/s]


Overall Recall (Top 18): 56.67%, Top 10%: 27.33%, Top 20%: 40.00%


Processing Top 10 Queries: 100%|██████████| 150/150 [01:08<00:00,  2.19query/s]


Overall Recall (Top 10): 46.67%, Top 10%: 27.33%, Top 20%: 38.00%
