In [None]:
import os
# Move the working directory one level up, so every relative path used later
# resolves from the parent of the directory the kernel was started in.
# NOTE: the target is relative, so re-running this cell moves up one more level
# each time; run it exactly once per kernel session.
os.chdir("../")

# Overlap between the ToTTo train set and the WikiTableQuestions (WikiTQ) test set

In [None]:
from tqdm import tqdm
from joblib import Parallel, delayed
import copy

In [None]:
from datasets import load_dataset

In [None]:
# Load the "train" split of the "GEM/totto" dataset (the ToTTo side of the comparison).
totto_dataset = load_dataset("GEM/totto")["train"]

In [None]:
# Load the "test" split of the "wikitablequestions" dataset (the WikiTQ side of the comparison).
wikitq_dataset = load_dataset("wikitablequestions")["test"]

In [None]:
# Containers for the linearised table strings of each dataset.
# Both names are re-initialised in later cells right before they are filled.
totto_tables, wikitq_tables = [], []

In [None]:
def _add_adjusted_col_offsets(table):
    """Add adjusted column offsets to take into account multi-column cells."""
    adjusted_table = []
    for row in table:
        real_col_index = 0
        adjusted_row = []
        for cell in row:
            adjusted_cell = copy.deepcopy(cell)
            adjusted_cell["adjusted_col_start"] = real_col_index
            adjusted_cell["adjusted_col_end"] = (
                adjusted_cell["adjusted_col_start"] + adjusted_cell["column_span"]
            )
            real_col_index += adjusted_cell["column_span"]
            adjusted_row.append(adjusted_cell)
        adjusted_table.append(adjusted_row)
    return adjusted_table


def _get_heuristic_col_headers(adjusted_table, row_index, col_index):
    """Heuristic to find column headers."""
    adjusted_cell = adjusted_table[row_index][col_index]
    adjusted_col_start = adjusted_cell["adjusted_col_start"]
    adjusted_col_end = adjusted_cell["adjusted_col_end"]
    col_headers = []
    for r in range(0, row_index):
        row = adjusted_table[r]
        for cell in row:
            if (
                cell["adjusted_col_start"] < adjusted_col_end
                and cell["adjusted_col_end"] > adjusted_col_start
            ):
                if cell["is_header"]:
                    col_headers.append(cell)

    return col_headers


def get_totto_full_table(table, cell_indices, table_page_title=None, table_section_title=None):
    """Convert a ToTTo-style table into a header/rows dict plus highlighted values.

    The first row (index 0) is treated as the header row: it is not included in
    ``rows`` and highlighted cells located in it are not reported. Column names
    are the values of the row-0 cells with a truthy ``"is_header"`` whose column
    span overlaps a cell of row 1, de-duplicated in order of first appearance.
    Tables with fewer than two rows therefore get an empty header.

    Args:
        table: Sliceable sequence of rows; each row is a sequence of cell
            dicts. Every cell of rows 1.. needs ``"value"``; cells of the first
            two rows also need ``"column_span"``, and row-0 cells that overlap
            a row-1 cell need ``"is_header"`` and ``"value"``.
        cell_indices: Collection of ``[row_index, col_index]`` pairs marking
            highlighted cells. Membership is tested with a list ``[r, c]``, so
            the pairs must compare equal to a list.
        table_page_title: Accepted for backward compatibility; not used.
        table_section_title: Accepted for backward compatibility; not used.

    Returns:
        A tuple ``(table_dict, highlighted_cells)`` where ``table_dict`` is
        ``{"header": [column names], "rows": [[cell values], ...]}`` and
        ``highlighted_cells`` lists the values of the highlighted cells in
        row-major order.
    """
    # Column names only depend on rows 0 and 1, so only those two rows are
    # copied and annotated with their real column offsets.
    adjusted_table = _add_adjusted_col_offsets(table[:2])

    col_headers = []
    if len(adjusted_table) > 1:
        for c_index in range(len(adjusted_table[1])):
            for header_cell in _get_heuristic_col_headers(adjusted_table, 1, c_index):
                # Keep each header value once, in order of first appearance.
                if header_cell["value"] not in col_headers:
                    col_headers.append(header_cell["value"])

    highlighted_cells = []
    table_dict = {"header": col_headers, "rows": []}
    for r_index, row in enumerate(table):
        # Row 0 is the header row and is represented by `col_headers` instead.
        if r_index == 0:
            continue

        row_values = []
        for c_index, col in enumerate(row):
            if [r_index, c_index] in cell_indices:
                highlighted_cells.append(col["value"])
            row_values.append(col["value"])
        table_dict["rows"].append(row_values)

    return table_dict, highlighted_cells

In [None]:
# Convert every ToTTo example in parallel on all available cores (n_jobs=-1).
# Each result is the (table_dict, highlighted_cells) tuple returned by
# get_totto_full_table; page and section titles are not passed.
processed_data = Parallel(n_jobs=-1)(
    delayed(get_totto_full_table)(example["table"], example["highlighted_cells"])
    for example in tqdm(totto_dataset, total=len(totto_dataset), position=0, leave=True)
)

In [None]:
# Number of results from the parallel conversion; one result is produced per
# submitted ToTTo example, so this should equal len(totto_dataset).
len(processed_data)

In [None]:
# Start from an empty list so that re-running the linearisation cell below
# does not append to results left over from an earlier run.
totto_tables = []

In [None]:
# Linearise every processed ToTTo table into one lower-cased string of the form
#   "[HEADER] h1 | h2 [ROW] 0: a | b [ROW] 1: c | d ..."
# (same layout as the WikiTQ tables, so the two can be compared as strings).
for result in tqdm(processed_data, position=0, leave=True, total=len(processed_data)):
    table_dict = result[0]  # result is (table_dict, highlighted_cells)
    header_text = " | ".join(name.lower() for name in table_dict["header"])
    row_texts = [
        f" [ROW] {row_id}: " + " | ".join(cell.lower() for cell in row_cells)
        for row_id, row_cells in enumerate(table_dict["rows"])
    ]
    totto_tables.append("[HEADER] " + header_text + "".join(row_texts))

In [None]:
# Start from an empty list so that re-running the linearisation cell below
# does not append to results left over from an earlier run.
wikitq_tables = []

In [None]:
# Linearise the table of every WikiTQ example into one lower-cased string of the form
#   "[HEADER] h1 | h2 [ROW] 0: a | b [ROW] 1: c | d ..."
# One string is appended per dataset example.
for example in tqdm(wikitq_dataset, position=0, leave=True, total=len(wikitq_dataset)):
    wikitq_table = example["table"]
    header_text = " | ".join(name.lower() for name in wikitq_table["header"])
    row_texts = [
        f" [ROW] {row_id}: " + " | ".join(cell.lower() for cell in row_cells)
        for row_id, row_cells in enumerate(wikitq_table["rows"])
    ]
    wikitq_tables.append("[HEADER] " + header_text + "".join(row_texts))

In [None]:
# Start from an empty list so that re-running the overlap cell below does not
# append to results left over from an earlier run.
overlap_tables = []

In [None]:
# Exact-match overlap: keep every linearised WikiTQ table whose string also
# occurs among the linearised ToTTo tables.
# The ToTTo strings are put in a set once so each membership test is a hash
# lookup; testing against the `totto_tables` list would rescan the whole list
# for every WikiTQ table.
totto_table_lookup = set(totto_tables)
for table in tqdm(wikitq_tables, position = 0, leave = True, total = len(wikitq_tables)):
    if table in totto_table_lookup:
        overlap_tables.append(table)

In [None]:
# Number of entries of `wikitq_tables` whose linearised text exactly matches a
# ToTTo table. Entries are not de-duplicated, so a table string that occurs
# several times in `wikitq_tables` is counted each time.
len(overlap_tables)