# **Table Manners**: Improving Digestion of Tabular Data via Contextualization and Handling of Non-Uniform Tables
| Title    | **Table Manners**: Improving Digestion of Tabular Data via Contextualization and Handling of Non-Uniform Tables   |
|-----------|----------------|
| Authors  | Ashish Thakur & Emily Okabe|
| Mentor   | Heidi Zhang         |

## ColBERTv2: Indexing & Search Notebook
Based on [ColBERTv2 notebook](https://colab.research.google.com/github/stanford-futuredata/ColBERT/blob/main/docs/intro2new.ipynb), sourced from [ColBERTv2](https://github.com/stanford-futuredata/ColBERT).

If you're working in Google Colab, we recommend selecting "GPU" as your hardware accelerator in the runtime settings.

## **Preparing collections, queries, and answers**
*First, we'll clone ColBERT and install its dependencies. Then we'll import the relevant classes; `Indexer` and `Searcher` are the key actors here.*

In [None]:
!git -C ColBERT/ pull || git clone https://github.com/stanford-futuredata/ColBERT.git
import sys; sys.path.insert(0, 'ColBERT/')


In [None]:
try:
    import google.colab
    !pip install -U pip
    !pip install -e ColBERT/['faiss-gpu','torch']
except Exception:
    # Not running in Colab: ColBERT must already be installed (conda is recommended).
    import sys; sys.path.insert(0, 'ColBERT/')
    try:
        from colbert import Indexer, Searcher
    except Exception:
        print("If you're running outside Colab, please install ColBERT in a conda environment following the instructions in ColBERTv2's README. You can also install it with pip (as above), but pip may pull in slower or less stable faiss or torch builds. Conda is recommended.")
        raise

In [None]:
import colbert

In [None]:
from colbert import Indexer, Searcher
from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert.data import Queries, Collection

In [None]:
from datasets import load_dataset

#### Downloading the "train" split of the CompMix dataset
*Queries are restricted to those whose answers come from table data (`answer_src == 'table'`).*

In [None]:
queries_dataset = load_dataset("pchristm/CompMix", split="train")
# Keep only questions answered from tables; filtering once keeps queries and answers aligned.
table_rows = [row for row in queries_dataset if row['answer_src'] == 'table']
queries = [row['question'] for row in table_rows]
answers = [row['answer_text'] for row in table_rows]
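*Because `search_and_evaluate` pairs queries with answers via `zip`, both lists must come from the same filtered rows in the same order. The toy sketch below shows that filter on a few hypothetical rows that only mimic the CompMix fields used above (`question`, `answer_text`, `answer_src`); the values and the helper name `table_only` are invented for illustration.*

```python
# Hypothetical rows that only mimic the CompMix fields used above;
# the values are placeholders, not real CompMix data.
toy_rows = [
    {"question": "Q1?", "answer_text": "A1", "answer_src": "table"},
    {"question": "Q2?", "answer_text": "A2", "answer_src": "kb"},
    {"question": "Q3?", "answer_text": "A3", "answer_src": "table"},
]

def table_only(rows):
    """Return index-aligned (questions, answers) for rows answered from a table."""
    table_rows = [row for row in rows if row["answer_src"] == "table"]
    return ([row["question"] for row in table_rows],
            [row["answer_text"] for row in table_rows])

toy_queries, toy_answers = table_only(toy_rows)
print(toy_queries, toy_answers)  # ['Q1?', 'Q3?'] ['A1', 'A3']
```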

#### Downloading the `wiki-all-8-4-tamber` variant of the `castorini/odqa-wiki-corpora` dataset
*Please make sure you have sufficient disk space before downloading the wiki corpora dataset (we recommend at least 50 GB).*

In [None]:
collection_dataset = load_dataset("castorini/odqa-wiki-corpora", "wiki-all-8-4-tamber", split="train")
original_collection = [row['text'] for row in collection_dataset]

print(f'Loaded {len(queries):,} queries and {len(original_collection):,} passages')

### Preparing the modified collection
*The modified collection is built by removing the linearized tables from the original collection and adding our enhanced linearized table data in their place (improved contextualization, handling of more non-uniform/non-standard tables, etc.).*

In [None]:
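import csv

# Keep only passages that do not look like linearized tables
# (heuristic: a passage with more than 3 '|' separators is treated as a table).
modified_collection = [text for text in original_collection if text.count('|') <= 3]

tsv_file_path = "/content/linearized_table_data.tsv"

# Append our enhanced linearized tables (read from the `text` column) to the MODIFIED collection.
with open(tsv_file_path, 'r', newline='', encoding='utf-8') as file:
    tsv_reader = csv.DictReader(file, delimiter='\t')
    for row in tsv_reader:
        modified_collection.append(row['text'])

print(f'Modified collection: {len(modified_collection):,} passages')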
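*A minimal sketch of how the modified collection is assembled, on hypothetical toy data: passages with more than three `|` characters (the notebook's own table-detection heuristic) are dropped, and every `text` value from a tab-separated file is appended. The passages, the in-memory TSV contents, and the helper `build_modified_collection` are illustrative assumptions; the real file's columns other than `text` are not specified here.*

```python
import csv
import io

def build_modified_collection(passages, tsv_file, max_pipes=3):
    """Drop table-like passages (more than `max_pipes` '|' characters) and
    append the `text` column of every row in a tab-separated file object."""
    kept = [p for p in passages if p.count("|") <= max_pipes]
    reader = csv.DictReader(tsv_file, delimiter="\t")
    kept.extend(row["text"] for row in reader)
    return kept

# Hypothetical passages: the second mimics a linearized table (4 pipes).
toy_passages = [
    "Paris is the capital of France.",
    "Country | Capital | Population | Area | Currency",
    "Berlin is the capital of Germany.",
]

# Hypothetical TSV contents with a header row and a `text` column.
toy_tsv = "id\ttext\n1\tIn a table of European capitals, France has the capital Paris.\n"

toy_modified = build_modified_collection(toy_passages, io.StringIO(toy_tsv))
print(toy_modified)
# ['Paris is the capital of France.', 'Berlin is the capital of Germany.',
#  'In a table of European capitals, France has the capital Paris.']
```

*Since `|` can also appear in ordinary prose, the threshold of 3 is a heuristic; spot-checking a sample of the dropped passages is a cheap way to confirm it removes tables and little else.*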

## **Indexing**

*For efficient search, we pre-compute the ColBERT representation of every passage in each collection and build a separate index for each collection.*

*Below, the `Indexer` takes a model checkpoint and writes a (compressed) index to disk. We then prepare a `Searcher` for retrieval from each index.*
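*For intuition about what the index stores and how the `Searcher` ranks passages: ColBERT keeps one embedding per token and scores a passage by late interaction (MaxSim), summing, over query tokens, the maximum similarity to any passage token. The sketch below is a plain-Python toy with hand-picked 2-dimensional unit vectors, not ColBERT's implementation, which operates on compressed 128-dimensional embeddings at scale; the function names are our own.*

```python
from typing import List

Vector = List[float]

def dot(u: Vector, v: Vector) -> float:
    """Dot product; equals cosine similarity when both vectors are unit length."""
    return sum(a * b for a, b in zip(u, v))

def maxsim_score(query_embs: List[Vector], doc_embs: List[Vector]) -> float:
    """Late interaction: each query token takes its best-matching passage token
    (max dot product); the passage score is the sum over query tokens."""
    return sum(max(dot(q, d) for d in doc_embs) for q in query_embs)

# Toy unit-length 2-d token embeddings, chosen by hand.
toy_query = [[1.0, 0.0], [0.0, 1.0]]
toy_doc_a = [[1.0, 0.0], [0.6, 0.8]]
toy_doc_b = [[0.0, 1.0]]

print(maxsim_score(toy_query, toy_doc_a))  # 1.8  (1.0 + 0.8)
print(maxsim_score(toy_query, toy_doc_b))  # 1.0  (0.0 + 1.0)
```

*Here `toy_doc_a` outranks `toy_doc_b` because each query token finds a close match in it; this per-token matching is what lets late interaction reward passages that cover several query terms.*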

In [None]:
nbits = 2   # encode each dimension with 2 bits
doc_maxlen = 500 # truncate passages at 500 tokens

original_dataset_index_name = f'wiki-all-8-4-tamber.original.{nbits}bits'
modified_dataset_index_name = f'wiki-all-8-4-tamber.modified.{nbits}bits'
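*`nbits` controls ColBERTv2's residual compression: each token embedding is stored as the id of its nearest centroid plus a residual whose every dimension is quantized to `nbits` bits, i.e. one of `2**nbits` buckets. With 128-dimensional embeddings and `nbits = 2`, the residual takes 128 × 2 = 256 bits = 32 bytes per token, versus 256 bytes for the same vector stored in 16-bit floats. The toy below illustrates the idea on a 4-dimensional vector; the cutoffs and bucket values are hypothetical (ColBERT derives its own from the data during indexing), and the helper names are our own.*

```python
from bisect import bisect_right

toy_nbits = 2
# Hypothetical bucket boundaries (2**toy_nbits - 1 of them) and the value each
# of the 2**toy_nbits buckets decodes to; ColBERT derives its own from data.
toy_cutoffs = [-0.05, 0.0, 0.05]
toy_bucket_values = [-0.08, -0.02, 0.02, 0.08]

def encode_residual(vector, centroid):
    """Quantize each dimension of (vector - centroid) to a bucket id in range(2**toy_nbits)."""
    return [bisect_right(toy_cutoffs, v - c) for v, c in zip(vector, centroid)]

def decode_residual(codes, centroid):
    """Approximate the original vector as centroid + decoded residual."""
    return [c + toy_bucket_values[code] for c, code in zip(centroid, codes)]

toy_centroid = [0.5, 0.5, 0.5, 0.5]
toy_vector = [0.43, 0.51, 0.47, 0.62]

toy_codes = encode_residual(toy_vector, toy_centroid)
print(toy_codes)                                                         # [0, 2, 1, 3]
print([round(x, 2) for x in decode_residual(toy_codes, toy_centroid)])  # [0.42, 0.52, 0.48, 0.58]

# Storage for one 128-dim residual at toy_nbits bits per dimension, in bytes.
print(128 * toy_nbits // 8)  # 32
```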

*Now run the `Indexer` on each collection. This may take over 24 hours on a T4 GPU.*

In [None]:
checkpoint = 'colbert-ir/colbertv2.0'

with Run().context(RunConfig(nranks=1, experiment='notebook')):  # nranks specifies the number of GPUs to use
    config = ColBERTConfig(doc_maxlen=doc_maxlen, nbits=nbits, kmeans_niters=4) # kmeans_niters specifies the number of iterations of k-means clustering; 4 is a good and fast default.

    original_indexer = Indexer(checkpoint=checkpoint, config=config)
    original_indexer.index(name=original_dataset_index_name, collection=original_collection, overwrite=True)

    modified_indexer = Indexer(checkpoint=checkpoint, config=config)
    modified_indexer.index(name=modified_dataset_index_name, collection=modified_collection, overwrite=True)
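*`kmeans_niters` sets how many k-means iterations are run when ColBERT picks the centroids it uses for compression and candidate lookup. As a rough illustration of what one iteration does (assign every point to its nearest centroid, then move each centroid to the mean of its points), here is a plain-Python Lloyd's k-means on hypothetical 2-D points. As we understand it, ColBERT delegates this step to faiss on sampled embeddings and initializes centroids differently; the evenly spaced initialization below exists only to keep the toy deterministic, and `kmeans_sketch` is a hypothetical name.*

```python
from typing import List, Tuple

Point = Tuple[float, ...]

def sq_dist(a: Point, b: Point) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))

def kmeans_sketch(points: List[Point], k: int, niters: int) -> List[Point]:
    """Lloyd's k-means for a fixed number of iterations (cf. kmeans_niters)."""
    # Deterministic toy initialization: k evenly spaced input points.
    centroids = [points[i * len(points) // k] for i in range(k)]
    for _ in range(niters):
        clusters: List[List[Point]] = [[] for _ in range(k)]
        for p in points:
            nearest = min(range(k), key=lambda i: sq_dist(p, centroids[i]))
            clusters[nearest].append(p)
        # Move each centroid to its cluster mean (keep it if the cluster is empty).
        centroids = [
            tuple(sum(dim) / len(c) for dim in zip(*c)) if c else centroids[i]
            for i, c in enumerate(clusters)
        ]
    return centroids

toy_points = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0),
              (10.0, 10.0), (10.0, 11.0), (11.0, 10.0), (11.0, 11.0)]
print(kmeans_sketch(toy_points, k=2, niters=4))  # [(0.5, 0.5), (10.5, 10.5)]
```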

## **Searching and Evaluation**

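*With both indexes built, we now load a `Searcher` for each and compare top-k retrieval accuracy on the two collections. A query counts as a hit at rank k if any of its top-k passages contains at least one word of the preprocessed answer.*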

In [None]:
def is_valid_word(word: str):
    """Check if the word is valid (not a common/trivial word and not a single character unless a digit)."""
    common_words = {"an", "the", "or", "of", "and"}
    # Compare case-insensitively so that e.g. "The" is also filtered out.
    return word.lower() not in common_words and (len(word) > 1 or word.isdigit())

def preprocess_answer(answer: str):
    """Preprocess the answer string by replacing punctuation with spaces and filtering out invalid words."""
    punctuation = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
    answer_no_punct = answer.translate(str.maketrans(punctuation, ' ' * len(punctuation)))
    return [word for word in answer_no_punct.split() if is_valid_word(word)]

def check_passage_for_answer(passage: str, answer_words: list[str]):
    """Check if the passage contains at least one of the preprocessed answer words."""
    return any(word in passage for word in answer_words)
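*Note that `check_passage_for_answer` uses plain substring matching, so an answer word can match inside a longer word (for example, `his` matches `This`), which can inflate the accuracy numbers. One possible stricter variant, shown as a hypothetical alternative rather than what this notebook evaluates with, counts only matches on word boundaries:*

```python
import re
from typing import List

def check_passage_whole_words(passage: str, answer_words: List[str]) -> bool:
    """Return True if any answer word occurs in the passage as a whole word."""
    return any(re.search(rf"\b{re.escape(word)}\b", passage) for word in answer_words)

toy_passage = "This is the capital city."
print(any(word in toy_passage for word in ["his"]))         # True: substring of "This"
print(check_passage_whole_words(toy_passage, ["his"]))      # False
print(check_passage_whole_words(toy_passage, ["capital"]))  # True
```

*Since the same matcher is applied to both collections, either choice keeps the comparison like-for-like; the stricter one mainly lowers absolute numbers.*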

In [None]:
def search_and_evaluate(searcher, queries, answers, top_k=10):
    """Print and return top-1..top-k retrieval accuracy: a query is a hit at k
    if any of its top-k passages contains at least one answer word."""
    accuracy = {k: 0 for k in range(1, top_k+1)}
    total_queries = len(queries)

    for query, answer in zip(queries, answers):
        # searcher.search returns parallel lists: (passage_ids, ranks, scores)
        results = searcher.search(query, k=top_k)
        answer_words = preprocess_answer(answer)

        for rank, (passage_id, _, _) in enumerate(zip(*results), start=1):
            if check_passage_for_answer(searcher.collection[passage_id], answer_words):
                # The first hit at `rank` counts toward every k >= rank.
                for k in range(rank, top_k+1):
                    accuracy[k] += 1
                break

    for k in range(1, top_k+1):
        accuracy[k] /= total_queries
        print(f"Top-{k} Retrieval Accuracy: {accuracy[k]:.2%}")

    return accuracy
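*The cumulative update inside `search_and_evaluate` is equivalent to recording the rank of each query's first matching passage and, for each k, counting the queries whose first hit is at rank at most k. A minimal sketch with a mocked search result, assuming (as the loop above does) that `searcher.search` returns parallel lists of passage ids, ranks, and scores; `first_hit_rank`, `top_k_accuracy`, and all toy values are hypothetical:*

```python
from typing import Callable, Dict, List, Optional, Sequence, Tuple

def first_hit_rank(results: Tuple[Sequence[int], Sequence[int], Sequence[float]],
                   passages: Sequence[str],
                   is_hit: Callable[[str], bool]) -> Optional[int]:
    """1-based rank of the first retrieved passage that is a hit, or None."""
    for rank, (pid, _, _) in enumerate(zip(*results), start=1):
        if is_hit(passages[pid]):
            return rank
    return None

def top_k_accuracy(first_hits: List[Optional[int]], top_k: int) -> Dict[int, float]:
    """Fraction of queries whose first hit is within the top k, for k = 1..top_k."""
    n = len(first_hits)
    if n == 0:
        return {k: 0.0 for k in range(1, top_k + 1)}
    return {k: sum(1 for r in first_hits if r is not None and r <= k) / n
            for k in range(1, top_k + 1)}

toy_passages = ["Tokyo is in Japan.", "Lima is in Peru.", "Oslo is in Norway."]
# Mocked search output: (passage ids, ranks, scores).
toy_results = ([2, 0, 1], [1, 2, 3], [9.1, 8.7, 7.5])
print(first_hit_rank(toy_results, toy_passages, lambda p: "Peru" in p))  # 3

print(top_k_accuracy([1, 3, None, 2], top_k=3))  # {1: 0.25, 2: 0.5, 3: 0.75}
```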

In [None]:
with Run().context(RunConfig(experiment='notebook')):
    original_searcher = Searcher(index=original_dataset_index_name, collection=original_collection)
    modified_searcher = Searcher(index=modified_dataset_index_name, collection=modified_collection)

    print("Evaluating Original Collection:")
    search_and_evaluate(original_searcher, queries, answers, top_k=10)

    print("\nEvaluating Modified Collection:")
    search_and_evaluate(modified_searcher, queries, answers, top_k=10)