In [1]:
!pip install -U sentence-transformers
!pip install transformers

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import numpy as np
import sys
import pandas as pd
from transformers import AutoModel, AutoTokenizer, BertTokenizer, BertModel, TapasTokenizer, TapasForSequenceClassification
import jsonlines
import time
import os
from sentence_transformers import SentenceTransformer, util
from typing import List, Dict

# Set tokenizers parallelism to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [20]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# The scoring functions below read this module-level `device` directly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [21]:
# Load once (simulate TaBERT-style scoring).
# `tokenizer` and `model` are module-level globals used by the
# score_*_tabert functions; this is plain `bert-base-uncased`, not a
# TaBERT checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased').to(device)
# Switch to evaluation mode (disables dropout) so embeddings are deterministic.
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

In [4]:
def fetch_chunks(table_id, chunks_path='chunks.json'):
    """
    Load the row and column chunks that belong to one table from a JSON Lines file.

    Args:
        table_id (str or collection of str): The target table identifier, or a
            collection of acceptable identifiers (e.g. an ID plus alternative IDs).
        chunks_path (str): Path of the JSON Lines chunk file. Defaults to
            'chunks.json', the file the notebook originally hard-coded.

    Returns:
        tuple: (row_chunks, column_chunks). A matching chunk whose
        metadata['chunk_type'] is 'row' goes to row_chunks; every other
        matching chunk goes to column_chunks.
    """
    print("Loading and filtering chunks for target table...")

    # A bare string is a single ID. Testing `name in "<id string>"` would be a
    # substring check (so "" or any fragment of the ID would match), therefore
    # normalise to a set and compare table names exactly.
    if isinstance(table_id, str):
        wanted_ids = {table_id}
    else:
        wanted_ids = set(table_id)

    row_chunks = []
    column_chunks = []
    chunk_count = 0
    matched_chunks = 0

    with jsonlines.open(chunks_path, 'r') as reader:
        for chunk in reader:
            chunk_count += 1

            # Skip chunks that carry no table name at all.
            if 'metadata' not in chunk or 'table_name' not in chunk['metadata']:
                continue
            if chunk['metadata']['table_name'] not in wanted_ids:
                continue

            if chunk['metadata']['chunk_type'] == 'row':
                row_chunks.append(chunk)
            else:
                column_chunks.append(chunk)
            matched_chunks += 1

    print(f"Found {matched_chunks} chunks that match table ID '{table_id}' out of {chunk_count} total chunks")

    return row_chunks, column_chunks

In [24]:
def score_column_chunks_tabert(query: str, column_chunks: List[Dict]):
    """
    Score each column chunk against the query by cosine similarity of [CLS] embeddings.

    The notebook-level `tokenizer`, `model` and `device` objects are used
    directly; none of them is passed in as an argument.

    Args:
        query (str): Natural language query.
        column_chunks (List[Dict]): Column chunks; each chunk's "text" is embedded
            and its column name is read from the "col:" entry of the
            comma-separated metadata["metadata_text"].

    Returns:
        List[Tuple[str, float]]: One (column_name, score) pair per chunk, in input
        order. Scores are cosine similarities rescaled from [-1, 1] to [0, 1] and
        rounded to 4 decimals. column_name is "" when no "col:" entry is found.
    """
    def _cls_embedding(text):
        # Tokenize, move every tensor to the target device, and keep only the
        # hidden state of the first ([CLS]) token.
        encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        on_device = {name: tensor.to(device) for name, tensor in encoded.items()}
        return model(**on_device).last_hidden_state[:, 0, :]

    def _column_name(chunk):
        # The first comma-separated entry that starts with "col:" names the column.
        described = chunk.get("metadata", {}).get("metadata_text", "")
        for piece in (raw.strip() for raw in described.split(',')):
            if piece.startswith("col:"):
                return piece.replace("col:", "").strip()
        return ""

    scored = []
    with torch.no_grad():
        query_vec = _cls_embedding(query)

        for chunk in column_chunks:
            name = _column_name(chunk)
            col_vec = _cls_embedding(chunk.get("text", ""))

            cosine = torch.nn.functional.cosine_similarity(query_vec, col_vec).item()
            # Map the cosine from [-1, 1] onto [0, 1].
            scored.append((name, round((cosine + 1) / 2, 4)))

    return scored
    

In [6]:
def filter_chunks_by_column_score(chunks, scored_columns, threshold):
    """
    Keep the chunks whose column name scored at or above the threshold.

    Args:
        chunks (list): Chunk dictionaries; each must have a "metadata" dict whose
            optional "metadata_text" holds comma-separated entries such as "col: Name".
        scored_columns (list of tuples): (column_name, score) pairs.
        threshold (float): Minimum score a column needs for its chunks to be kept.

    Returns:
        filtered_chunks (list): Chunks, in input order, whose first "col:" entry
        names a column with score >= threshold. Chunks without a "col:" entry
        are dropped.
    """
    # Names of the columns that pass the threshold.
    keep = set()
    for name, value in scored_columns:
        if value >= threshold:
            keep.add(name)

    def _first_column_name(chunk):
        # Return the name from the first "col:" entry, or None when there is none.
        entries = (raw.strip() for raw in chunk["metadata"].get("metadata_text", "").split(','))
        tagged = next((entry for entry in entries if entry.startswith("col:")), None)
        if tagged is None:
            return None
        return tagged.replace("col:", "").strip()

    selected = []
    for chunk in chunks:
        column = _first_column_name(chunk)
        if column is not None and column in keep:
            selected.append(chunk)

    return selected

In [25]:
def score_row_chunks_tabert(query: str, row_chunks: List[Dict]):
    """
    Score each row chunk against the query by cosine similarity of [CLS] embeddings.

    The notebook-level `tokenizer`, `model` and `device` objects are used
    directly; none of them is passed in as an argument.

    Args:
        query (str): Natural language query.
        row_chunks (List[Dict]): Row chunks; each chunk's "text" is embedded and
            its identifier is read from metadata["chunk_id"].

    Returns:
        List[Dict]: One {"chunk_id": ..., "score": ...} dict per chunk, in input
        order. Scores are cosine similarities rescaled from [-1, 1] to [0, 1] and
        rounded to 4 decimals. chunk_id is "" when the metadata has none.
    """
    def _cls_embedding(text):
        # Tokenize, move every tensor to the target device, and keep only the
        # hidden state of the first ([CLS]) token.
        encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        on_device = {name: tensor.to(device) for name, tensor in encoded.items()}
        return model(**on_device).last_hidden_state[:, 0, :]

    scored = []
    with torch.no_grad():
        query_vec = _cls_embedding(query)

        for chunk in row_chunks:
            identifier = chunk.get("metadata", {}).get("chunk_id", "")
            row_vec = _cls_embedding(chunk.get("text", ""))

            cosine = torch.nn.functional.cosine_similarity(query_vec, row_vec).item()
            # Map the cosine from [-1, 1] onto [0, 1].
            scored.append({"chunk_id": identifier, "score": round((cosine + 1) / 2, 4)})

    return scored

In [8]:
def filter_chunks_by_row_score(chunks, scored_rows, threshold):
    """
    Keep the row chunks whose similarity score reaches the threshold.

    Args:
        chunks (list): Chunk dictionaries, each with metadata['chunk_id'].
        scored_rows (list of dict): {'chunk_id': str, 'score': float} entries.
        threshold (float): Minimum score a chunk needs to be kept.

    Returns:
        filtered_chunks (list): Chunks, in input order, whose chunk_id has a
        score >= threshold.
    """
    # Collect the identifiers that pass the threshold.
    passing_ids = set()
    for entry in scored_rows:
        if entry['score'] >= threshold:
            passing_ids.add(entry['chunk_id'])

    # Keep only chunks whose identifier passed.
    return list(filter(lambda chunk: chunk['metadata']['chunk_id'] in passing_ids, chunks))

In [9]:
def dynamic_threshold(scores, alpha=0.5):
    """
    Calculate a dynamic threshold as median + alpha * standard deviation.

    The median (rather than the mean) is used as the centre of the scores.

    Args:
        scores (list or np.array): Similarity scores (floats between 0 and 1).
        alpha (float): Multiplier for the standard deviation (default 0.5).
            Negative values place the threshold below the median.

    Returns:
        float: Threshold value, or NaN when `scores` is empty. Every
        `score >= NaN` comparison is False, so an empty input keeps nothing.
    """
    values = np.asarray(scores, dtype=float)

    # np.median / np.std on an empty array emit RuntimeWarnings and yield NaN;
    # return NaN directly instead.
    if values.size == 0:
        return float("nan")

    return float(np.median(values) + alpha * np.std(values))

In [10]:
def column_chunks_to_dataframe(column_chunks):
    """
    Converts a list of column chunks into a pandas DataFrame.

    Each chunk's 'text' is expected to look like "Header: v1 | v2 | v3". Cell
    positions are preserved: an empty cell between separators stays an empty
    string in its own row, so values from different columns remain aligned
    row by row.

    Args:
        column_chunks (list): List of chunks, each containing a column in 'text'.

    Returns:
        pd.DataFrame: One column per header, in chunk order. Shorter columns are
        padded at the end with empty strings. Chunks without a ':' or with an
        empty header are skipped; a later chunk with a repeated header replaces
        the earlier one.
    """
    data = {}

    for chunk in column_chunks:
        text = chunk.get("text", "")

        # Split on the first ':' only, so values may themselves contain colons.
        header, separator, values_str = text.partition(":")
        header = header.strip()
        if not separator or not header:
            continue

        values_str = values_str.strip()
        # Keep empty cells in place: dropping them would shift later values up
        # and pair them with the wrong row of the other columns.
        data[header] = [cell.strip() for cell in values_str.split("|")] if values_str else []

    # Pad every column to the longest one without mutating the parsed lists.
    max_len = max((len(cells) for cells in data.values()), default=0)
    padded = {
        header: cells + [""] * (max_len - len(cells))
        for header, cells in data.items()
    }

    return pd.DataFrame.from_dict(padded, orient='columns')

In [11]:
def row_chunks_to_dataframe(row_chunks):
    """
    Build a DataFrame from row chunks written as "Header: Value | Header: Value".

    Args:
        row_chunks (list): Row-type chunk dicts. Each chunk's metadata["columns"]
            lists the headers that are accepted for that row.

    Returns:
        pd.DataFrame: One row per chunk. Columns are the union of all declared
        headers, sorted alphabetically. A declared header with no matching
        "Header: Value" pair in the text is an empty string.
    """
    records = []
    seen_columns = set()

    for chunk in row_chunks:
        declared = chunk.get("metadata", {}).get("columns", [])
        # Start every declared column as an empty cell.
        record = {name: "" for name in declared}

        # Each pipe-separated segment holding a ':' is a "key: value" pair;
        # split on the first ':' only so values may contain colons.
        pairs = (
            segment.split(":", 1)
            for segment in chunk.get("text", "").split("|")
            if ":" in segment
        )
        for raw_key, raw_value in pairs:
            name = raw_key.strip()
            if name in declared:
                record[name] = raw_value.strip()

        records.append(record)
        seen_columns.update(declared)

    return pd.DataFrame(records, columns=sorted(seen_columns))

In [12]:
def intersect_row_and_column_dfs(df_row, df_col):
    """
    Return the rows that appear in both DataFrames, compared on their shared columns.

    Args:
        df_row (pd.DataFrame): DataFrame constructed from row chunks.
        df_col (pd.DataFrame): DataFrame constructed from column chunks.

    Returns:
        pd.DataFrame: Inner join of the two frames restricted to their common
        columns (sorted by name), after dropping rows with missing values and
        duplicate rows. An empty DataFrame when no column is shared.
    """
    shared = sorted(set(df_row.columns) & set(df_col.columns))
    if len(shared) == 0:
        print("⚠️ No overlapping columns found between the DataFrames.")
        return pd.DataFrame()

    def _prepare(frame):
        # Restrict to the shared columns, then remove incomplete and repeated
        # rows so they cannot cause mismatches or duplicate matches.
        return frame[shared].copy().dropna().drop_duplicates()

    # With identical column sets, merge joins on every shared column.
    return pd.merge(_prepare(df_row), _prepare(df_col), how='inner')

In [13]:
def dataframe_to_json_entry(df, table_id):
    """
    Turn a DataFrame into a JSON-serializable table entry.

    Args:
        df (pd.DataFrame): The pruned table.
        table_id (str): Unique table identifier, stored as both "id" and "tableId".

    Returns:
        dict: {"id", "table": {"columns", "rows", "tableId"}} where every header
        and every cell is wrapped as {"text": <string form of the value>}.
    """
    # One {"text": ...} wrapper per column header.
    header_cells = []
    for column in df.columns:
        header_cells.append({"text": str(column)})

    # One {"cells": [...]} wrapper per DataFrame row, cells in column order.
    body_rows = []
    for _, record in df.iterrows():
        body_rows.append({"cells": [{"text": str(value)} for value in record]})

    return {
        "id": table_id,
        "table": {
            "columns": header_cells,
            "rows": body_rows,
            "tableId": table_id,
        },
    }

In [16]:
def prune_table(table_id, question, column_alpha=-0.2, row_alpha=0.5):
    """
    Prune one table down to the columns and rows most similar to the question.

    Steps: fetch the table's chunks, score column chunks and row chunks against
    the question, keep those at or above a dynamic threshold, rebuild a
    DataFrame from each kept set, and intersect the two.

    Args:
        table_id (str): Identifier of the table whose chunks are fetched.
        question (str): Natural language query used for scoring.
        column_alpha (float): Std-dev multiplier for the column threshold.
            Defaults to -0.2, the value previously hard-coded here.
        row_alpha (float): Std-dev multiplier for the row threshold.
            Defaults to 0.5, the value previously hard-coded here.

    Returns:
        pd.DataFrame: The pruned table (may be empty).
    """
    ### Fetch chunks for the given table ID ###
    row_chunks, column_chunks = fetch_chunks(table_id)

    ### Column Level Pruning ###
    column_scores = score_column_chunks_tabert(question, column_chunks)
    for col, score in column_scores:
        print(f"  {col:15s} → {score:.4f}")

    column_scores_only = [score for _, score in column_scores]
    column_threshold = dynamic_threshold(column_scores_only, column_alpha)
    print("Column Threshold:", column_threshold)

    filtered_column_chunks = filter_chunks_by_column_score(column_chunks, column_scores, column_threshold)

    ### Row level Pruning ###
    row_scores = score_row_chunks_tabert(question, row_chunks)

    for item in row_scores:
        print(f"{item['chunk_id']} → Score: {item['score']}")

    row_scores_only = [chunk["score"] for chunk in row_scores]
    row_threshold = dynamic_threshold(row_scores_only, row_alpha)
    print("Row Threshold:", row_threshold)

    filtered_row_chunks = filter_chunks_by_row_score(row_chunks, row_scores, row_threshold)

    ### Rebuild tables from the surviving chunks and intersect them ###
    filtered_columns_df = column_chunks_to_dataframe(filtered_column_chunks)
    filtered_row_df = row_chunks_to_dataframe(filtered_row_chunks)

    pruned_df = intersect_row_and_column_dfs(filtered_row_df, filtered_columns_df)

    print('-' * 80)
    print("final DF")
    print(pruned_df)

    return pruned_df

In [26]:
# Main execution block: prune the top candidate tables for one query and write
# each non-empty pruned table as one JSON object per line.
if __name__ == "__main__":

    # Read the query, its candidate tables and the expected target/answer from the CSV.
    query_csv = "query6_TopTables.csv"
    query_csv_df = pd.read_csv(query_csv)
    # The query text is taken from the first row only.
    question = query_csv_df['query'][0]
    print(question)
    table_list = query_csv_df['top tables'].tolist()
    # NOTE(review): target_table_id and goal_answer are read but not used in this
    # cell; presumably kept for later evaluation — confirm.
    target_table_id = query_csv_df['target table'][0]
    goal_answer = query_csv_df['target answer'][0]

    # all_entries = []

    start_time = time.time()

    # Opening with "w" overwrites any previous output file on every run.
    with open("all_pruned_tables_tabert.json", "w", encoding="utf-8") as f:
        # Only the first 10 candidate tables are processed.
        for table_id in table_list[:10]:

            pruned_df = prune_table(table_id, question)

            # Tables that pruned down to nothing are not written.
            if not pruned_df.empty:
                entry = dataframe_to_json_entry(pruned_df,table_id)
                f.write(json.dumps(entry, ensure_ascii=False) + '\n')
    end_time = time.time()

    print(f"Execution time: {end_time - start_time:.4f} seconds")


who is the chief of the gods according to ancient greek myth
Loading and filtering chunks for target table...
Found 13 chunks that match table ID 'List of pharaohs_A3680D6D69C5E013' out of 2881668 total chunks
  Name            → 0.8877
  Image           → 0.9160
  Comments        → 0.8399
  Dates           → 0.8858
                  → 0.8339
Column Threshold: 0.8795692555821957
List of pharaohs_A3680D6D69C5E013_row_0 → Score: 0.8662
List of pharaohs_A3680D6D69C5E013_row_1 → Score: 0.879
List of pharaohs_A3680D6D69C5E013_row_2 → Score: 0.8642
List of pharaohs_A3680D6D69C5E013_row_3 → Score: 0.86
List of pharaohs_A3680D6D69C5E013_row_4 → Score: 0.8677
List of pharaohs_A3680D6D69C5E013_row_5 → Score: 0.8123
List of pharaohs_A3680D6D69C5E013_row_6 → Score: 0.8694
List of pharaohs_A3680D6D69C5E013_row_7 → Score: 0.8645
Row Threshold: 0.8748020314450123
--------------------------------------------------------------------------------
final DF
          Dates Image   Name
0  2589–2566 BC     