# Fine-tuning TAPAS with SPIDER and QATCH

TAPAS is a question-answering model over tables. To train or fine-tune TAPAS on a custom dataset, in our case SPIDER, we need:
- Tables
- Natural language questions
- Answers from the tables
- Answer coordinates in the tables

Unfortunately, we lack the latter two (answers and their coordinates) since SPIDER only provides the natural language questions and SQL queries. Therefore, we need to execute the queries, retrieve the answers, and determine their coordinates for the entire dataset. This task is challenging to complete in a single step.
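A minimal sketch of the execution step, assuming that the answer texts are simply the stringified values of the result rows (TAPAS expects `answer_text` as a list of strings). The helper `execute_for_answers` and the toy `singer` rows are hypothetical:

```python
import sqlite3


def execute_for_answers(conn, query):
    """Run `query` and flatten the result rows into a list of answer strings."""
    rows = conn.execute(query).fetchall()
    # Row-major flattening: answers keep the order of the result set.
    return [str(value) for row in rows for value in row]


# Hypothetical toy table standing in for a Spider table.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE singer (name TEXT, country TEXT, age INTEGER)")
conn.executemany(
    "INSERT INTO singer VALUES (?, ?, ?)",
    [("Ana", "Spain", 52), ("Ben", "France", 32)],
)

answers = execute_for_answers(conn, "SELECT name FROM singer WHERE age > 40")
print(answers)  # ['Ana']
conn.close()
```

The answer coordinates still have to be recovered separately, since SQLite does not report which cells the returned values came from.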

After preparing and cleaning the data, we need to feed it to TAPAS and start the fine-tuning process. I followed [the Hugging Face/Google Research tutorial](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TAPAS/Fine_tuning_TapasForQuestionAnswering_on_SQA.ipynb). This Colab notebook (Hugging Face documentation for TAPAS) fine-tunes TAPAS on a very small portion of [the Microsoft Research Sequential Question Answering (SQA) Dataset](https://www.microsoft.com/en-us/research/publication/search-based-neural-structured-learning-sequential-question-answering/), which is already provided in the format TAPAS expects and therefore requires no further preprocessing or cleaning.

I aimed to apply a similar approach with a single database from the SPIDER dataset, as a proof of concept to construct a data preparation pipeline that can then be applied to the entire dataset.

The goal is to observe improvements in the model's performance as we provide it with increasingly more synthetic data generated from QATCH.

SPIDER contains SQL examples that are incompatible with QATCH, since QATCH only supports projection, selection, and conditions (no joins, no aggregations). SPIDER includes numerous queries that use these operations, so we must remove them.


In [None]:
!pip install --quiet datasets
!pip install --quiet frozendict
!pip install --quiet transformers
!pip install --quiet qatch

In [None]:
# ! pip install --quiet torch-scatter

### The Coordinates of the Answers

Getting the coordinates of the answers in the format that TAPAS expects is a very difficult task, and doing it for the entire dataset is even more challenging. I found a script in a [GitHub repo](https://github.com/NielsRogge/tapas_utils?tab=readme-ov-file) that performs this task. However, it struggles with duplicate values in the tables: when an answer matches several cells, the coordinates become ambiguous and the script can throw an error.

By coordinates, I mean that if the answer is a tuple (V1, V2, V3), we need a list like [(x1, y1), (x2, y2), (x3, y3)], where x is the row index and y is the column index of each value. SQLite3, where the data is stored, does not guarantee any row order (without an `ORDER BY`, rows come back in an unspecified order), so we import the tables into pandas DataFrames to establish a fixed ordering.
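A minimal sketch of this mapping, assuming exact string equality after `.astype(str)` (the script in the next cell additionally tries several string normalizations). The helper `candidate_coordinates` and the toy table are hypothetical. It also shows where the ambiguity comes from: a value that occurs in several cells yields several candidate coordinates.

```python
import pandas as pd


def candidate_coordinates(table, answer_texts):
    """Map each answer string to every (row, column) position whose cell matches it."""
    # Compare as strings; to_numpy() gives positional indices, independent of the index labels.
    values = table.astype(str).to_numpy()
    candidates = {}
    for answer in answer_texts:
        candidates[answer] = [
            (row, col)
            for row in range(values.shape[0])
            for col in range(values.shape[1])
            if values[row, col] == answer
        ]
    return candidates


table = pd.DataFrame({"name": ["Ana", "Ben", "Cleo"],
                      "country": ["Spain", "France", "Spain"]})
coords = candidate_coordinates(table, ["Ben", "Spain"])
print(coords)  # {'Ben': [(1, 0)], 'Spain': [(0, 1), (2, 1)]}

# Answers with more than one candidate cell are the ambiguous ones.
ambiguous = [answer for answer, cells in coords.items() if len(cells) > 1]
print(ambiguous)  # ['Spain']
```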


In [None]:
#@title Script to get the coordinates of the answers
# coding=utf-8
# Copyright 2019 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""This module implements a simple parser that can be used for TAPAS.

Given a table, a question and one or more answer_texts, it will parse the texts
to populate other fields (e.g. answer_coordinates, float_value) that are required
by TAPAS.

Please note that exceptions in this module are concise and not parameterized,
since they are used as counter names in a BEAM pipeline.
"""

import enum
from typing import Callable, List, Text, Optional

import six
import struct
import unicodedata
import re

import frozendict
import numpy as np
import scipy.optimize


class SupervisionMode(enum.Enum):
  # Don't filter out any supervised information.
  NONE = 0
  # Remove all the supervised signals and recompute them by parsing answer
  # texts.
  REMOVE_ALL = 2
  # Same as above but discard ambiguous examples
  # (where an answer matches multiple cells).
  REMOVE_ALL_STRICT = 3


def _find_matching_coordinates(table, answer_text,
                               normalize):
  normalized_text = normalize(answer_text)
  for row_index, row in table.iterrows():
    for column_index, cell in enumerate(row):
      if normalized_text == normalize(str(cell)):
        yield (row_index, column_index)


def _compute_cost_matrix_inner(
    table,
    answer_texts,
    normalize,
    discard_ambiguous_examples,
):
  """Returns a cost matrix M where the value M[i,j] contains a matching cost from answer i to cell j.

  The matrix is a binary matrix and -1 is used to indicate a possible match from
  a given answer_texts to a specific cell table. The cost matrix can then be
  used to compute the optimal assignments that minimize the cost using the
  hungarian algorithm (see scipy.optimize.linear_sum_assignment).

  Args:
    table: a Pandas dataframe.
    answer_texts: a list of strings.
    normalize: a function that normalizes a string.
    discard_ambiguous_examples: If true discard if answer has multiple matches.

  Raises:
    ValueError if:
      - we cannot correctly construct the cost matrix or the text-cell
      assignment is ambiguous.
      - we cannot find a matching cell for a given answer_text.

  Returns:
    A numpy matrix with shape (num_answer_texts, num_rows * num_columns).
  """
  max_candidates = 0
  n_rows, n_columns = table.shape[0], table.shape[1]
  num_cells = n_rows * n_columns
  num_candidates = np.zeros((n_rows, n_columns))
  cost_matrix = np.zeros((len(answer_texts), num_cells))

  for index, answer_text in enumerate(answer_texts):
    found = 0
    for row, column in _find_matching_coordinates(table, answer_text,
                                                  normalize):
      found += 1
      cost_matrix[index, (row * len(table.columns)) + column] = -1
      num_candidates[row, column] += 1
      max_candidates = max(max_candidates, num_candidates[row, column])
    if found == 0:
      return None
    if discard_ambiguous_examples and found > 1:
      raise ValueError("Found multiple cells for answers")

  # TODO(piccinno): Shall we allow ambiguous assignments?
  # if max_candidates > 1:
  #   raise ValueError("Assignment is ambiguous")

  return cost_matrix


def _compute_cost_matrix(
    table,
    answer_texts,
    discard_ambiguous_examples,
):
  """Computes cost matrix."""
  for index, normalize_fn in enumerate(STRING_NORMALIZATIONS):
    try:
      result = _compute_cost_matrix_inner(
          table,
          answer_texts,
          normalize_fn,
          discard_ambiguous_examples,
      )
      if result is None:
        continue
      return result
    except ValueError:
      if index == len(STRING_NORMALIZATIONS) - 1:
        raise
  return None


def _parse_answer_coordinates(table,
                              answer_texts,
                              discard_ambiguous_examples):
  """Populates answer_coordinates using answer_texts.

  Args:
    table: a Table message, needed to compute the answer coordinates.
    answer_texts: a list of strings
    discard_ambiguous_examples: If true discard if answer has multiple matches.

  Raises:
    ValueError if the conversion fails.
  """

  cost_matrix = _compute_cost_matrix(
      table,
      answer_texts,
      discard_ambiguous_examples,
  )
  if cost_matrix is None:
    return
  row_indices, column_indices = scipy.optimize.linear_sum_assignment(
      cost_matrix)

  # create answer coordinates as list of tuples
  answer_coordinates = []
  for row_index in row_indices:
    flatten_position = column_indices[row_index]
    row_coordinate = flatten_position // len(table.columns)
    column_coordinate = flatten_position % len(table.columns)
    answer_coordinates.append((row_coordinate, column_coordinate))

  return answer_coordinates


### START OF UTILITIES FROM TEXT_UTILS.PY ###

def wtq_normalize(x):
  """Returns the normalized version of x.
  This normalization function is taken from WikiTableQuestions github, hence the
  wtq prefix. For more information, see
  https://github.com/ppasupat/WikiTableQuestions/blob/master/evaluator.py
  Args:
    x: the object (integer type or string) to normalize.
  Returns:
    A normalized string.
  """
  x = x if isinstance(x, six.text_type) else six.text_type(x)
  # Remove diacritics.
  x = "".join(
      c for c in unicodedata.normalize("NFKD", x)
      if unicodedata.category(c) != "Mn")
  # Normalize quotes and dashes.
  x = re.sub(u"[‘’´`]", "'", x)
  x = re.sub(u"[“”]", '"', x)
  x = re.sub(u"[‐‑‒–—−]", "-", x)
  x = re.sub(u"[‐]", "", x)
  while True:
    old_x = x
    # Remove citations.
    x = re.sub(u"((?<!^)\\[[^\\]]*\\]|\\[\\d+\\]|[•♦†‡*#+])*$", "",
               x.strip())
    # Remove details in parenthesis.
    x = re.sub(u"(?<!^)( \\([^)]*\\))*$", "", x.strip())
    # Remove outermost quotation mark.
    x = re.sub(u'^"([^"]*)"$', r"\1", x.strip())
    if x == old_x:
      break
  # Remove final '.'.
  if x and x[-1] == ".":
    x = x[:-1]
  # Collapse whitespaces and convert to lower case.
  x = re.sub(r"\s+", " ", x, flags=re.U).lower().strip()
  x = re.sub("<[^<]+?>", "", x)
  x = x.replace("\n", " ")
  return x


_TOKENIZER = re.compile(r"\w+|[^\w\s]+", re.UNICODE)


def tokenize_string(x):
  return list(_TOKENIZER.findall(x.lower()))


# List of string normalization functions to be applied in order. We go from
# simplest to more complex normalization procedures.
STRING_NORMALIZATIONS = (
    lambda x: x,
    lambda x: x.lower(),
    tokenize_string,
    wtq_normalize,
)


def to_float32(v):
  """If v is a float reduce precision to that of a 32 bit float."""
  if not isinstance(v, float):
    return v
  return struct.unpack("!f", struct.pack("!f", v))[0]


def convert_to_float(value):
  """Converts value to a float using a series of increasingly complex heuristics.
  Args:
    value: object that needs to be converted. Allowed types include
      float/int/strings.
  Returns:
    A float interpretation of value.
  Raises:
    ValueError if the float conversion of value fails.
  """
  if isinstance(value, float):
    return value
  if isinstance(value, int):
    return float(value)
  if not isinstance(value, six.string_types):
    raise ValueError("Argument value is not a string. Can't parse it as float")
  sanitized = value

  try:
    # Example: 1,000.7
    if "." in sanitized and "," in sanitized:
      return float(sanitized.replace(",", ""))
    # 1,000
    if "," in sanitized and _split_thousands(",", sanitized):
      return float(sanitized.replace(",", ""))
    # 5,5556
    if "," in sanitized and sanitized.count(",") == 1 and not _split_thousands(
        ",", sanitized):
      return float(sanitized.replace(",", "."))
    # 0.0.0.1
    if sanitized.count(".") > 1:
      return float(sanitized.replace(".", ""))
    # 0,0,0,1
    if sanitized.count(",") > 1:
      return float(sanitized.replace(",", ""))
    return float(sanitized)
  except ValueError:
    # Avoid adding the sanitized value in the error message.
    raise ValueError("Unable to convert value to float")

### END OF UTILITIES FROM TEXT_UTILS.PY ###

def _parse_answer_float(answer_texts, float_value):
  if len(answer_texts) > 1:
    raise ValueError("Cannot convert multiple answers to a single float")
  float_value = convert_to_float(answer_texts[0])

  return answer_texts, float_value


def _has_single_float_answer_equal_to(question, answer_texts, target):
  """Returns true if the question has a single answer whose value equals to target."""
  if len(answer_texts) != 1:
    return False
  try:
    float_value = convert_to_float(answer_texts[0])
    # In general answer_float is derived by applying the same convert_to_float
    # function at interaction creation time, hence here we use exact match to
    # avoid any false positive.
    return to_float32(float_value) == to_float32(target)
  except ValueError:
    return False


def _parse_question(
    table,
    original_question,
    answer_texts,
    answer_coordinates,
    float_value,
    aggregation_function,
    clear_fields,
    discard_ambiguous_examples,
):
  """Parses question's answer_texts fields to possibly populate additional fields.

  Args:
    table: a Pandas dataframe, needed to compute the answer coordinates.
    original_question: a string.
    answer_texts: a list of strings, serving as the answer to the question.
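    answer_coordinates: optional list of (row, column) tuples, if already known.
    float_value: a float, serves as float value signal.
    aggregation_function: optional aggregation function signal, if already known.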
    clear_fields: A list of strings indicating which fields need to be cleared
      and possibly repopulated.
    discard_ambiguous_examples: If true, discard ambiguous examples.

  Returns:
    A Question message with answer_coordinates or float_value field populated.

  Raises:
    ValueError if we cannot parse correctly the question message.
  """
  question = original_question

  # If we have a float value signal we just copy its string representation to
  # the answer text (if multiple answers texts are present OR the answer text
  # cannot be parsed to float OR the float value is different), after clearing
  # this field.
  if "float_value" in clear_fields and float_value is not None:
    if not _has_single_float_answer_equal_to(question, answer_texts, float_value):
      del answer_texts[:]
      float_value = float(float_value)
      if float_value.is_integer():
        number_str = str(int(float_value))
      else:
        number_str = str(float_value)
      answer_texts = []
      answer_texts.append(number_str)

  if not answer_texts:
    raise ValueError("No answer_texts provided")

  for field_name in clear_fields:
    if field_name == "answer_coordinates":
        answer_coordinates = None
    if field_name == "float_value":
        float_value = None
    if field_name == "aggregation_function":
        aggregation_function = None

  error_message = ""
  if not answer_coordinates:
    try:
      answer_coordinates = _parse_answer_coordinates(
          table,
          answer_texts,
          discard_ambiguous_examples,
      )
    except ValueError as exc:
      error_message += "[answer_coordinates: {}]".format(str(exc))
      if discard_ambiguous_examples:
        raise ValueError(f"Cannot parse answer: {error_message}")

  if not float_value:
    try:
      answer_texts, float_value = _parse_answer_float(answer_texts, float_value)
    except ValueError as exc:
      error_message += "[float_value: {}]".format(str(exc))

  # Raises an exception if we cannot set any of the two fields.
  if not answer_coordinates and not float_value:
    raise ValueError("Cannot parse answer: {}".format(error_message))

  return question, answer_texts, answer_coordinates, float_value, aggregation_function


# TODO(piccinno): Use some sort of introspection here to get the field names of
# the proto.
_CLEAR_FIELDS = frozendict.frozendict({
    SupervisionMode.REMOVE_ALL: [
        "answer_coordinates", "float_value", "aggregation_function"
    ],
    SupervisionMode.REMOVE_ALL_STRICT: [
        "answer_coordinates", "float_value", "aggregation_function"
    ]
})


def parse_question(table, question, answer_texts, answer_coordinates=None, float_value=None, aggregation_function=None,
                    mode=SupervisionMode.REMOVE_ALL):
    """Parses answer_text field of a question to populate additional fields required by TAPAS.

    Args:
        table: a Pandas dataframe, needed to compute the answer coordinates. Note that one should apply .astype(str)
        before supplying the table to this function.
        question: a string.
        answer_texts: a list of strings, containing one or more answer texts that serve as answer to the question.
        answer_coordinates: optional answer coordinates supervision signal, if you already have those.
        float_value: optional float supervision signal, if you already have this.
        aggregation_function: optional aggregation function supervised signal, if you already have this.
        mode: see SupervisionMode enum for more information.

    Returns:
        A list with the question, populated answer_coordinates or float_value.

    Raises:
        ValueError if we cannot parse correctly the question string.
    """
    if mode == SupervisionMode.NONE:
        return question, answer_texts

    clear_fields = _CLEAR_FIELDS.get(mode, None)
    if clear_fields is None:
        raise ValueError(f"Mode {mode.name} is not supported")

    return _parse_question(
        table,
        question,
        answer_texts,
        answer_coordinates,
        float_value,
        aggregation_function,
        clear_fields,
        discard_ambiguous_examples=mode == SupervisionMode.REMOVE_ALL_STRICT,
    )
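The idea behind `_compute_cost_matrix_inner` and `scipy.optimize.linear_sum_assignment` can be illustrated on a toy table. This simplified sketch uses exact string matching only (no normalization cascade) and hypothetical data; it shows that when an answer value appears twice, the Hungarian algorithm still assigns each answer text a distinct cell:

```python
import numpy as np
import scipy.optimize

# Hypothetical 2x2 table, flattened row-major: cell j sits at (j // n_columns, j % n_columns).
table = [["Spain", "Ana"],
         ["Spain", "Ben"]]
answer_texts = ["Spain", "Spain"]  # two answers, each matching two cells
n_rows, n_columns = 2, 2

# -1 marks a possible answer-to-cell match, 0 everything else (same convention as the script).
cost_matrix = np.zeros((len(answer_texts), n_rows * n_columns))
for i, answer in enumerate(answer_texts):
    for r in range(n_rows):
        for c in range(n_columns):
            if table[r][c] == answer:
                cost_matrix[i, r * n_columns + c] = -1

# Minimum-cost assignment: each answer gets a different cell.
row_ind, col_ind = scipy.optimize.linear_sum_assignment(cost_matrix)
coordinates = [(int(j // n_columns), int(j % n_columns)) for j in col_ind]
print(sorted(coordinates))  # [(0, 0), (1, 0)]
```

Several assignments can have the same minimal cost, and which one the solver returns is not something to rely on, so the result is printed after sorting.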

### Downloading the Spider Dataset
I have downloaded the [Spider dataset](https://drive.usercontent.google.com/download?id=1iRDVHLr4mX2wQKSgA9J8Pire73Jahh0m&export=download&authuser=0) and stored it on Google Drive. It might be more convenient to download and unzip it directly into the Colab environment for easier access and experimentation in the future.

This download is necessary for accessing the tables: loading the dataset through the Hugging Face `datasets` library only provides the `db_id` of the database each question refers to, not the actual data.


In [None]:
import os
import pandas as pd
import sqlite3
import ast
from google.colab import drive
from datasets import load_dataset
import pickle

In [None]:
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
os.chdir("/content/drive/MyDrive/spider")

In [None]:
# We also load Spider from the Hugging Face Hub: the downloaded archive is
# useful for its tables, but its queries and natural-language questions come as
# raw JSON files, whereas here we get them in a structure that converts to a DataFrame.
dataset = load_dataset("spider")

In [None]:
print(dataset)


train_set = dataset['train']
validation_set = dataset['validation']

# print(train_set[0])

# print("Question:", train_set[0]['question'])
# print("SQL Query:", train_set[0]['query'])

DatasetDict({
    train: Dataset({
        features: ['db_id', 'query', 'question', 'query_toks', 'query_toks_no_value', 'question_toks'],
        num_rows: 7000
    })
    validation: Dataset({
        features: ['db_id', 'query', 'question', 'query_toks', 'query_toks_no_value', 'question_toks'],
        num_rows: 1034
    })
})


In [None]:
# Convert the train and validation splits into pandas DataFrames
df_train = pd.DataFrame(train_set)
df_valid = pd.DataFrame(validation_set)


Build a dictionary mapping each `db_id` in the validation set to the tables its database contains.

In [None]:
# unique_db_ids = set(instance['db_id'] for instance in train_set)
unique_db_ids_valid = set(instance['db_id'] for instance in validation_set)
def get_table_names(db_path):
    tables_in_db = []
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;")
    tables = cursor.fetchall()
    for table in tables:
        tables_in_db.append(table[0])
    conn.close()
    return tables_in_db

base_path = '/content/drive/MyDrive/spider/test_database/'


db_to_tables_valid = {}

for db_id in sorted(unique_db_ids_valid):

    db_path = f'{base_path}{db_id}/{db_id}.sqlite'

    if os.path.exists(db_path):
        tables_in_db = get_table_names(db_path)
        db_to_tables_valid[db_id] = tables_in_db
    else:
        print(f"Database path for {db_id} not found or inaccessible.")

with open('/content/drive/MyDrive/spider/db_to_tables_valid.pkl', 'wb') as file:
    pickle.dump(db_to_tables_valid, file)

In [None]:
with open('/content/drive/MyDrive/spider/db_to_tables_valid.pkl', 'rb') as file:
    db_to_tables_valid = pickle.load(file)

Build a dictionary mapping each `db_id` in the validation set to its tables that have between 3 and 15 rows (inclusive).

In [None]:
def get_table_names_with_row_constraints(db_path, min_rows=3, max_rows=15):
    tables_in_db = []
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;")
    tables = cursor.fetchall()

    for table in tables:
        table_name = table[0]

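        # Quote the identifier in case a table name is an SQL keyword.
        cursor.execute(f'SELECT COUNT(*) FROM "{table_name}"')
        row_count = cursor.fetchone()[0]

        # Keep tables with min_rows..max_rows rows, both bounds inclusive.
        if min_rows <= row_count <= max_rows:
            tables_in_db.append(table_name)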

    conn.close()
    return tables_in_db

base_path = '/content/drive/MyDrive/spider/test_database/'


db_to_tables_constraint_valid = {}

for db_id in sorted(unique_db_ids_valid):
    db_path = f'{base_path}{db_id}/{db_id}.sqlite'

    if os.path.exists(db_path):
        tables_in_db = get_table_names_with_row_constraints(db_path)
        db_to_tables_constraint_valid[db_id] = tables_in_db
    else:
        print(f"Database path for {db_id} not found or inaccessible.")

with open('/content/drive/MyDrive/spider/db_to_tables_constraint_valid.pkl', 'wb') as file:
        pickle.dump(db_to_tables_constraint_valid, file)

In [None]:
with open('/content/drive/MyDrive/spider/db_to_tables_constraint_valid.pkl', 'rb') as file:
    db_to_tables_constraint_valid = pickle.load(file)

Build a dictionary that maps each `{db_id}_X_{table_name}` key to the contents of that table, loaded as a pandas DataFrame.

In [None]:
db_table_to_df_valid = {}
base_path = '/content/drive/MyDrive/spider/test_database/'

for db_id, tables in db_to_tables_constraint_valid.items():

    db_path = f'{base_path}{db_id}/{db_id}.sqlite'
    # Open one connection per database and close it once its tables are loaded.
    conn = sqlite3.connect(db_path)

    for table_name in tables:
        composite_key = f"{db_id}_X_{table_name}"
        df = pd.read_sql_query(f'SELECT * FROM "{table_name}"', conn)

        db_table_to_df_valid[composite_key] = df

    conn.close()

with open('/content/drive/MyDrive/spider/db_table_to_df_valid.pkl', 'wb') as file:
    pickle.dump(db_table_to_df_valid, file)

In [None]:
with open('/content/drive/MyDrive/spider/db_table_to_df_valid.pkl', 'rb') as file:
    db_table_to_df_valid = pickle.load(file)

We also need to identify the tables each query interacts with. Although I've tried several methods, the simplest and most reliable approach involves knowing the tables in the database and then checking if they are mentioned in the query using some basic regex patterns.


In [None]:
#@title pipeline
import re


def find_tables_in_query(query, db_id, db_to_tables_dict):
    tables_used = []

    tables_list = db_to_tables_dict.get(db_id, [])
    for table in tables_list:
        if re.search(r'\b' + re.escape(table) + r'\b', query, re.IGNORECASE):
            tables_used.append(table)
    return tables_used

# Apply the function to each row in the DataFrame

df_valid['tables_used'] = df_valid.apply(lambda row: find_tables_in_query(row['query'], row['db_id'], db_to_tables_valid), axis=1)
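A small illustration of why the pattern above uses `\b` word boundaries. The hypothetical helper `mentions_table` repeats the same check, and the join query is a made-up example in Spider's style: a table whose name is only part of a longer identifier is not counted as a mention.

```python
import re


def mentions_table(query, table):
    """True if `table` appears in `query` as a whole word, ignoring case."""
    return re.search(r"\b" + re.escape(table) + r"\b", query, re.IGNORECASE) is not None


query = ("SELECT T2.name FROM singer_in_concert AS T1 "
         "JOIN singer AS T2 ON T1.singer_id = T2.singer_id")
print(mentions_table(query, "singer"))             # True
print(mentions_table(query, "singer_in_concert"))  # True
print(mentions_table(query, "concert"))            # False
```

One limitation worth keeping in mind: a column name or string literal that happens to equal a table name will also count as a mention.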


In [None]:
#@title pipeline
df_valid

,db_id,query,question,query_toks,query_toks_no_value,question_toks,tables_used
0,concert_singer,SELECT count(*) FROM singer,How many singers do we have?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[How, many, singers, do, we, have, ?]",[singer]
1,concert_singer,SELECT count(*) FROM singer,What is the total number of singers?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[What, is, the, total, number, of, singers, ?]",[singer]
2,concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[Show, name, ,, country, ,, age, for, all, sin...",[singer]
3,concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[What, are, the, names, ,, countries, ,, and, ...",[singer]
4,concert_singer,"SELECT avg(age) , min(age) , max(age) FROM s...","What is the average, minimum, and maximum age ...","[SELECT, avg, (, age, ), ,, min, (, age, ), ,,...","[select, avg, (, age, ), ,, min, (, age, ), ,,...","[What, is, the, average, ,, minimum, ,, and, m...",[singer]
...,...,...,...,...,...,...,...
1029,singer,SELECT Citizenship FROM singer WHERE Birth_Yea...,What are the citizenships that are shared by s...,"[SELECT, Citizenship, FROM, singer, WHERE, Bir...","[select, citizenship, from, singer, where, bir...","[What, are, the, citizenships, that, are, shar...",[singer]
1030,real_estate_properties,SELECT count(*) FROM Other_Available_Features,How many available features are there in total?,"[SELECT, count, (, *, ), FROM, Other_Available...","[select, count, (, *, ), from, other_available...","[How, many, available, features, are, there, i...",[Other_Available_Features]
1031,real_estate_properties,SELECT T2.feature_type_name FROM Other_Availab...,What is the feature type name of feature AirCon?,"[SELECT, T2.feature_type_name, FROM, Other_Ava...","[select, t2, ., feature_type_name, from, other...","[What, is, the, feature, type, name, of, featu...","[Other_Available_Features, Ref_Feature_Types]"
1032,real_estate_properties,SELECT T2.property_type_description FROM Prope...,Show the property type descriptions of propert...,"[SELECT, T2.property_type_description, FROM, P...","[select, t2, ., property_type_description, fro...","[Show, the, property, type, descriptions, of, ...","[Properties, Ref_Property_Types]"


As mentioned before, QATCH doesn't support joins, so each query may touch only one table. We therefore filter out the queries that reference more than one table.

In [None]:
#@title pipeline
df_single_table = df_valid[df_valid['tables_used'].apply(lambda x: len(x) == 1)].copy()


df_single_table['table_used'] = df_single_table['tables_used'].apply(lambda x: x[0] if x else None)

df_single_table = df_single_table.drop(columns=['tables_used'])

df_single_table


,db_id,query,question,query_toks,query_toks_no_value,question_toks,table_used
0,concert_singer,SELECT count(*) FROM singer,How many singers do we have?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[How, many, singers, do, we, have, ?]",singer
1,concert_singer,SELECT count(*) FROM singer,What is the total number of singers?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[What, is, the, total, number, of, singers, ?]",singer
2,concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[Show, name, ,, country, ,, age, for, all, sin...",singer
3,concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[What, are, the, names, ,, countries, ,, and, ...",singer
4,concert_singer,"SELECT avg(age) , min(age) , max(age) FROM s...","What is the average, minimum, and maximum age ...","[SELECT, avg, (, age, ), ,, min, (, age, ), ,,...","[select, avg, (, age, ), ,, min, (, age, ), ,,...","[What, is, the, average, ,, minimum, ,, and, m...",singer
...,...,...,...,...,...,...,...
1017,singer,"SELECT Citizenship , max(Net_Worth_Millions) ...","For each citizenship, what is the maximum net ...","[SELECT, Citizenship, ,, max, (, Net_Worth_Mil...","[select, citizenship, ,, max, (, net_worth_mil...","[For, each, citizenship, ,, what, is, the, max...",singer
1028,singer,SELECT Citizenship FROM singer WHERE Birth_Yea...,Show the citizenship shared by singers with bi...,"[SELECT, Citizenship, FROM, singer, WHERE, Bir...","[select, citizenship, from, singer, where, bir...","[Show, the, citizenship, shared, by, singers, ...",singer
1029,singer,SELECT Citizenship FROM singer WHERE Birth_Yea...,What are the citizenships that are shared by s...,"[SELECT, Citizenship, FROM, singer, WHERE, Bir...","[select, citizenship, from, singer, where, bir...","[What, are, the, citizenships, that, are, shar...",singer
1030,real_estate_properties,SELECT count(*) FROM Other_Available_Features,How many available features are there in total?,"[SELECT, count, (, *, ), FROM, Other_Available...","[select, count, (, *, ), from, other_available...","[How, many, available, features, are, there, i...",Other_Available_Features


Filter out queries on tables that don't satisfy our cardinality constraint (3 to 15 rows).

In [None]:
#@title pipeline
filtered_df = df_single_table[df_single_table.apply(lambda row: row['table_used'] in db_to_tables_constraint_valid.get(row['db_id'], []), axis=1)]

filtered_df

,db_id,query,question,query_toks,query_toks_no_value,question_toks,table_used
0,concert_singer,SELECT count(*) FROM singer,How many singers do we have?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[How, many, singers, do, we, have, ?]",singer
1,concert_singer,SELECT count(*) FROM singer,What is the total number of singers?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[What, is, the, total, number, of, singers, ?]",singer
2,concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[Show, name, ,, country, ,, age, for, all, sin...",singer
3,concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[What, are, the, names, ,, countries, ,, and, ...",singer
4,concert_singer,"SELECT avg(age) , min(age) , max(age) FROM s...","What is the average, minimum, and maximum age ...","[SELECT, avg, (, age, ), ,, min, (, age, ), ,,...","[select, avg, (, age, ), ,, min, (, age, ), ,,...","[What, is, the, average, ,, minimum, ,, and, m...",singer
...,...,...,...,...,...,...,...
1016,singer,"SELECT Citizenship , max(Net_Worth_Millions) ...",Show different citizenships and the maximum ne...,"[SELECT, Citizenship, ,, max, (, Net_Worth_Mil...","[select, citizenship, ,, max, (, net_worth_mil...","[Show, different, citizenships, and, the, maxi...",singer
1017,singer,"SELECT Citizenship , max(Net_Worth_Millions) ...","For each citizenship, what is the maximum net ...","[SELECT, Citizenship, ,, max, (, Net_Worth_Mil...","[select, citizenship, ,, max, (, net_worth_mil...","[For, each, citizenship, ,, what, is, the, max...",singer
1028,singer,SELECT Citizenship FROM singer WHERE Birth_Yea...,Show the citizenship shared by singers with bi...,"[SELECT, Citizenship, FROM, singer, WHERE, Bir...","[select, citizenship, from, singer, where, bir...","[Show, the, citizenship, shared, by, singers, ...",singer
1029,singer,SELECT Citizenship FROM singer WHERE Birth_Yea...,What are the citizenships that are shared by s...,"[SELECT, Citizenship, FROM, singer, WHERE, Bir...","[select, citizenship, from, singer, where, bir...","[What, are, the, citizenships, that, are, shar...",singer
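`db_to_tables_constraint_valid` is built earlier in the notebook and is not shown in this section. As a minimal sketch, assuming the constraint is simply an inclusive row count between 3 and 15, such a per-database list of valid tables could be derived from `sqlite_master`. The helper name `tables_within_row_range` is hypothetical, and the database below is a toy in-memory one.

```python
import sqlite3


def tables_within_row_range(conn, min_rows=3, max_rows=15):
    """Return the user tables whose row count lies in [min_rows, max_rows].

    Hypothetical helper: assumes the cardinality constraint is a plain
    inclusive row-count range.
    """
    table_names = [
        name
        for (name,) in conn.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type = 'table' AND name NOT LIKE 'sqlite_%'"
        )
    ]
    valid = []
    for name in table_names:
        # Quote the identifier so names that clash with SQL keywords still work.
        (n_rows,) = conn.execute(f'SELECT COUNT(*) FROM "{name}"').fetchone()
        if min_rows <= n_rows <= max_rows:
            valid.append(name)
    return valid


# Toy database: one table with 5 rows (valid) and one with a single row (too small).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE singer (name TEXT)")
conn.executemany("INSERT INTO singer VALUES (?)", [(f"s{i}",) for i in range(5)])
conn.execute("CREATE TABLE stadium (name TEXT)")
conn.execute("INSERT INTO stadium VALUES ('x')")

db_to_tables = {"concert_singer": tables_within_row_range(conn)}
print(db_to_tables)  # {'concert_singer': ['singer']}
```

On the real data, one would open each `{db_id}.sqlite` file in turn and store the result under its `db_id` key.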


We need to filter out anything that is not a simple selection, projection, or condition, so that our data comes from the same distribution as QATCH's and integrates coherently with it. Given that we are not tuning the aggregation head, we omit count(*), max, min, ... examples as well (to be confirmed).

In [None]:
import numpy as np
np.nan

nan

In [None]:
#@title pipeline
# def is_simple_query(query):
#     # disallowed_keywords = {
#     #     'join', 'group by', 'having', 'union', 'intersect', 'except',
#     #     'count', 'sum', 'avg', 'min', 'max', 'limit', 'offset', 'inner',
#     #     'outer', 'left', 'right', 'full'
#     # }
#     disallowed_keywords = {
#         'join', 'union', 'intersect', 'except', 'inner',
#         'outer', 'left', 'right', 'full'
#     }


#     query_lower = query.lower()


#     for keyword in disallowed_keywords:
#         if keyword in query_lower:
#             return False
#     return True
# def is_simple_query_with_single_aggregation(query):
#     # Define both disallowed general keywords and specific aggregation keywords.
#     disallowed_keywords = {
#         'join', 'union', 'intersect', 'except', 'inner',
#         'outer', 'left', 'right', 'full'
#     }

#     aggregation_keywords = {
#         'count', 'sum', 'avg', 'min', 'max'
#     }

#     # Convert the query to lowercase for comparison.
#     query_lower = query.lower()

#     # Check for disallowed general keywords.
#     for keyword in disallowed_keywords:
#         if keyword in query_lower:
#             return False

#     # Check for aggregation keywords, ensuring only one is present.
#     aggregation_found = None
#     for keyword in aggregation_keywords:
#         if keyword in query_lower:
#             # If we've already found an aggregation, return False (more than one aggregation).
#             if aggregation_found:
#                 return False
#             # Mark this aggregation as found.
#             aggregation_found = keyword

#     return True


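def is_simple_query_with_single_aggregation(query_toks):
    """Classify a tokenized SPIDER query.

    Returns a tuple (is_simple, operator):
    - is_simple is False for joins, set operations (UNION/INTERSECT/EXCEPT),
      LIMIT, LIKE, nested SELECTs, more than one aggregation keyword, or an
      aggregation combined with several projected columns.
    - operator is the aggregation keyword in the SELECT clause (e.g. 'count'),
      or np.nan if there is none there.
    Note that GROUP BY, HAVING and ORDER BY are not rejected.
    """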
    # Define both disallowed general keywords and specific aggregation keywords.
    disallowed_keywords = {
        'join', 'union', 'intersect', 'except', 'inner',
        'outer', 'left', 'right', 'full', 'limit', 'like'
    }
    aggregation_keywords = {
        'count', 'sum', 'avg', 'min', 'max'
    }

    # Convert the query tokens to lowercase for comparison.
    query_toks_lower = [tok.lower() for tok in query_toks]

    # Check for disallowed general keywords.
    if any(keyword in query_toks_lower for keyword in disallowed_keywords):
        return False, np.nan

    # Check for nested queries by counting the occurrences of 'select'.
    select_count = query_toks_lower.count('select')
    if select_count > 1:
        return False, np.nan

    # Check for aggregation keywords.
    aggregation_count = sum(tok in aggregation_keywords for tok in query_toks_lower)

    # Check if there is at most one aggregation keyword.
    if aggregation_count > 1:
        return False, np.nan

    # If an aggregation keyword is present, ensure only one column is projected.
    if aggregation_count == 1:
        select_index = query_toks_lower.index('select')
        from_index = query_toks_lower.index('from')
        columns = query_toks[select_index + 1: from_index]
        if ',' in columns:
            return False, np.nan
        else:
            # Extract the aggregation operator between 'select' and 'from'.
            operator = next((tok for tok in columns if tok.lower() in aggregation_keywords), np.nan)
            return True, operator

    return True, np.nan



filtered_df_simple=filtered_df.copy()

filtered_df_simple['is_simple'] = filtered_df_simple['query_toks_no_value'].apply(lambda x: is_simple_query_with_single_aggregation(x)[0])
filtered_df_simple['operator'] = filtered_df_simple['query_toks_no_value'].apply(lambda x: is_simple_query_with_single_aggregation(x)[1])

simple_queries_df = filtered_df_simple[filtered_df_simple['is_simple']].copy()


simple_queries_df.drop(columns=['is_simple'], inplace=True)
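The commented-out versions above matched keywords as substrings of the raw query string, while the active version compares whole tokens. The toy comparison below (queries made up for illustration, with `str.split()` as a crude stand-in for SPIDER's `query_toks`) shows why the token-based check is safer: column names such as `country` contain `count`, and `is_left_handed` contains `left`.

```python
def has_keyword_substring(query, keywords):
    # Old approach: a keyword anywhere in the lowercased string counts as a hit.
    q = query.lower()
    return any(k in q for k in keywords)


def has_keyword_token(query_toks, keywords):
    # Token approach: a keyword must match a whole token.
    toks = [t.lower() for t in query_toks]
    return any(k in toks for k in keywords)


aggregations = {"count", "sum", "avg", "min", "max"}
joins = {"join", "inner", "outer", "left", "right", "full"}

query = "SELECT country FROM singer WHERE is_left_handed = 1"
toks = query.split()

print(has_keyword_substring(query, aggregations))  # True ('count' inside 'country')
print(has_keyword_token(toks, aggregations))       # False
print(has_keyword_substring(query, joins))         # True ('left' inside 'is_left_handed')
print(has_keyword_token(toks, joins))              # False
```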



In [None]:
complex_queries_df = filtered_df_simple[~filtered_df_simple['is_simple']].copy()
complex_queries_df

Unnamed: 0,db_id,query,question,query_toks,query_toks_no_value,question_toks,table_used,is_simple,operator
3,department_management,"SELECT max(budget_in_billions) , min(budget_i...",What are the maximum and minimum budget of the...,"[SELECT, max, (, budget_in_billions, ), ,, min...","[select, max, (, budget_in_billions, ), ,, min...","[What, are, the, maximum, and, minimum, budget...",department,False,
8,department_management,SELECT creation FROM department GROUP BY creat...,In which year were most departments established?,"[SELECT, creation, FROM, department, GROUP, BY...","[select, creation, from, department, group, by...","[In, which, year, were, most, departments, est...",department,False,
15,department_management,"SELECT head_id , name FROM head WHERE name LI...",Which head's name has the substring 'Ha'? List...,"[SELECT, head_id, ,, name, FROM, head, WHERE, ...","[select, head_id, ,, name, from, head, where, ...","[Which, head, 's, name, has, the, substring, '...",head,False,
26,farm,"SELECT max(Cows) , min(Cows) FROM farm",What are the maximum and minimum number of cow...,"[SELECT, max, (, Cows, ), ,, min, (, Cows, ), ...","[select, max, (, cows, ), ,, min, (, cows, ), ...","[What, are, the, maximum, and, minimum, number...",farm,False,
27,farm,"SELECT max(Cows) , min(Cows) FROM farm",Return the maximum and minimum number of cows ...,"[SELECT, max, (, Cows, ), ,, min, (, Cows, ), ...","[select, max, (, cows, ), ,, min, (, cows, ), ...","[Return, the, maximum, and, minimum, number, o...",farm,False,
...,...,...,...,...,...,...,...,...,...
6987,culture_company,"SELECT title , director FROM movie WHERE YEAR...",Return the title and director of the movie rel...,"[SELECT, title, ,, director, FROM, movie, WHER...","[select, title, ,, director, from, movie, wher...","[Return, the, title, and, director, of, the, m...",movie,False,
6988,culture_company,SELECT director FROM movie WHERE YEAR = 2000...,Show all director names who have a movie in bo...,"[SELECT, director, FROM, movie, WHERE, YEAR, =...","[select, director, from, movie, where, year, =...","[Show, all, director, names, who, have, a, mov...",movie,False,
6989,culture_company,SELECT director FROM movie WHERE YEAR = 2000...,Which directors had a movie both in the year 1...,"[SELECT, director, FROM, movie, WHERE, YEAR, =...","[select, director, from, movie, where, year, =...","[Which, directors, had, a, movie, both, in, th...",movie,False,
6992,culture_company,"SELECT avg(budget_million) , max(budget_milli...","What is the average, maximum, and minimum budg...","[SELECT, avg, (, budget_million, ), ,, max, (,...","[select, avg, (, budget_million, ), ,, max, (,...","[What, is, the, average, ,, maximum, ,, and, m...",movie,False,


In [None]:
for index,row in complex_queries_df.iterrows():
  print(row["query"])

In [None]:
for index,row in simple_queries_df.iterrows():
  print(row["query"])

SELECT count(*) FROM head WHERE age  >  56
SELECT name ,  born_state ,  age FROM head ORDER BY age
SELECT creation ,  name ,  budget_in_billions FROM department
SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15
SELECT name FROM head WHERE born_state != 'California'
SELECT born_state FROM head GROUP BY born_state HAVING count(*)  >=  3
SELECT count(DISTINCT temporary_acting) FROM management
SELECT count(*) FROM farm
SELECT count(*) FROM farm
SELECT Total_Horses FROM farm ORDER BY Total_Horses ASC
SELECT Total_Horses FROM farm ORDER BY Total_Horses ASC
SELECT Hosts FROM farm_competition WHERE Theme !=  'Aliens'
SELECT Hosts FROM farm_competition WHERE Theme !=  'Aliens'
SELECT Theme FROM farm_competition ORDER BY YEAR ASC
SELECT Theme FROM farm_competition ORDER BY YEAR ASC
SELECT avg(Working_Horses) FROM farm WHERE Total_Horses  >  5000
SELECT avg(Working_Horses) FROM farm WHERE Total_Horses  >  5000
SELECT count(DISTINCT Status) FROM city
SELECT count(DISTINCT S

In [None]:
simple_queries_df

Unnamed: 0,db_id,query,question,query_toks,query_toks_no_value,question_toks,table_used,operator
0,concert_singer,SELECT count(*) FROM singer,How many singers do we have?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[How, many, singers, do, we, have, ?]",singer,count
1,concert_singer,SELECT count(*) FROM singer,What is the total number of singers?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[What, is, the, total, number, of, singers, ?]",singer,count
2,concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[Show, name, ,, country, ,, age, for, all, sin...",singer,
3,concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[What, are, the, names, ,, countries, ,, and, ...",singer,
8,concert_singer,SELECT DISTINCT country FROM singer WHERE age ...,What are all distinct countries where singers ...,"[SELECT, DISTINCT, country, FROM, singer, WHER...","[select, distinct, country, from, singer, wher...","[What, are, all, distinct, countries, where, s...",singer,
...,...,...,...,...,...,...,...,...
1005,singer,"SELECT Birth_Year , Citizenship FROM singer",What are the birth years and citizenships of t...,"[SELECT, Birth_Year, ,, Citizenship, FROM, sin...","[select, birth_year, ,, citizenship, from, sin...","[What, are, the, birth, years, and, citizenshi...",singer,
1006,singer,"SELECT Name FROM singer WHERE Citizenship != ""...",List the name of singers whose citizenship is ...,"[SELECT, Name, FROM, singer, WHERE, Citizenshi...","[select, name, from, singer, where, citizenshi...","[List, the, name, of, singers, whose, citizens...",singer,
1007,singer,"SELECT Name FROM singer WHERE Citizenship != ""...",What are the names of the singers who are not ...,"[SELECT, Name, FROM, singer, WHERE, Citizenshi...","[select, name, from, singer, where, citizenshi...","[What, are, the, names, of, the, singers, who,...",singer,
1008,singer,SELECT Name FROM singer WHERE Birth_Year = 1...,Show the name of singers whose birth year is e...,"[SELECT, Name, FROM, singer, WHERE, Birth_Year...","[select, name, from, singer, where, birth_year...","[Show, the, name, of, singers, whose, birth, y...",singer,


In [None]:
float_df = simple_queries_df[simple_queries_df['operator'].notna()]

float_df["operator"].value_counts()

count    431
avg       77
sum       39
max       12
min        3
Name: operator, dtype: int64
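If the aggregation head were trained later, the `operator` column would need to be mapped onto TAPAS's aggregation classes. The Hugging Face WTQ/WikiSQL configurations of TAPAS use four classes, NONE, SUM, AVERAGE and COUNT; MIN and MAX have no class of their own. The sketch below assumes that label order and, as a further assumption, treats `max`/`min` as plain cell selection (NONE), since their answer is an existing cell. It is not part of the original pipeline.

```python
import math

# Assumed label order, following TAPAS' NONE/SUM/AVERAGE/COUNT aggregation classes.
TAPAS_AGGREGATION_LABELS = {"NONE": 0, "SUM": 1, "AVERAGE": 2, "COUNT": 3}

# Hypothetical mapping from the SPIDER `operator` column to those classes.
# max/min select an existing cell, so they are treated as NONE here.
SPIDER_TO_TAPAS = {"sum": "SUM", "avg": "AVERAGE", "count": "COUNT",
                   "max": "NONE", "min": "NONE"}


def aggregation_label(operator):
    """Map an operator (or NaN for plain selections) to a TAPAS label index."""
    if operator is None or (isinstance(operator, float) and math.isnan(operator)):
        return TAPAS_AGGREGATION_LABELS["NONE"]
    return TAPAS_AGGREGATION_LABELS[SPIDER_TO_TAPAS[operator.lower()]]


print([aggregation_label(op) for op in ["count", "avg", "sum", "max", float("nan")]])
# [3, 2, 1, 0, 0]
```

In the notebook this could be applied with `simple_queries_df['operator'].map(aggregation_label)`.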

Add the `seq_id` column (`<db_id>_X_<table_name>`).

In [None]:
#@title pipeline
simple_df=simple_queries_df.copy()
simple_df['seq_id'] = simple_df.apply(lambda row: f"{row['db_id']}_X_{row['table_used']}", axis=1)
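The row-wise `apply` works, but the same key can also be built with vectorised string concatenation, which avoids calling a Python lambda for every row. A minimal sketch on toy data:

```python
import pandas as pd

df = pd.DataFrame({"db_id": ["concert_singer", "singer"],
                   "table_used": ["singer", "singer"]})

# Equivalent to the row-wise f-string above, computed column-wise.
df["seq_id"] = df["db_id"] + "_X_" + df["table_used"]
print(df["seq_id"].tolist())  # ['concert_singer_X_singer', 'singer_X_singer']
```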

In [None]:
#@title pipeline
simple_df.reset_index(inplace=True)
simple_df.rename(columns={'index': 'ID'}, inplace=True)
simple_df.head()

Unnamed: 0,ID,db_id,query,question,query_toks,query_toks_no_value,question_toks,table_used,operator,seq_id
0,0,concert_singer,SELECT count(*) FROM singer,How many singers do we have?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[How, many, singers, do, we, have, ?]",singer,count,concert_singer_X_singer
1,1,concert_singer,SELECT count(*) FROM singer,What is the total number of singers?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[What, is, the, total, number, of, singers, ?]",singer,count,concert_singer_X_singer
2,2,concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[Show, name, ,, country, ,, age, for, all, sin...",singer,,concert_singer_X_singer
3,3,concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[What, are, the, names, ,, countries, ,, and, ...",singer,,concert_singer_X_singer
4,8,concert_singer,SELECT DISTINCT country FROM singer WHERE age ...,What are all distinct countries where singers ...,"[SELECT, DISTINCT, country, FROM, singer, WHER...","[select, distinct, country, from, singer, wher...","[What, are, all, distinct, countries, where, s...",singer,,concert_singer_X_singer


In [None]:
simple_df

Unnamed: 0,ID,db_id,query,question,query_toks,query_toks_no_value,question_toks,table_used,operator,seq_id
0,0,concert_singer,SELECT count(*) FROM singer,How many singers do we have?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[How, many, singers, do, we, have, ?]",singer,count,concert_singer_X_singer
1,1,concert_singer,SELECT count(*) FROM singer,What is the total number of singers?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[What, is, the, total, number, of, singers, ?]",singer,count,concert_singer_X_singer
2,2,concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[Show, name, ,, country, ,, age, for, all, sin...",singer,,concert_singer_X_singer
3,3,concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[What, are, the, names, ,, countries, ,, and, ...",singer,,concert_singer_X_singer
4,8,concert_singer,SELECT DISTINCT country FROM singer WHERE age ...,What are all distinct countries where singers ...,"[SELECT, DISTINCT, country, FROM, singer, WHER...","[select, distinct, country, from, singer, wher...","[What, are, all, distinct, countries, where, s...",singer,,concert_singer_X_singer
...,...,...,...,...,...,...,...,...,...,...
211,1005,singer,"SELECT Birth_Year , Citizenship FROM singer",What are the birth years and citizenships of t...,"[SELECT, Birth_Year, ,, Citizenship, FROM, sin...","[select, birth_year, ,, citizenship, from, sin...","[What, are, the, birth, years, and, citizenshi...",singer,,singer_X_singer
212,1006,singer,"SELECT Name FROM singer WHERE Citizenship != ""...",List the name of singers whose citizenship is ...,"[SELECT, Name, FROM, singer, WHERE, Citizenshi...","[select, name, from, singer, where, citizenshi...","[List, the, name, of, singers, whose, citizens...",singer,,singer_X_singer
213,1007,singer,"SELECT Name FROM singer WHERE Citizenship != ""...",What are the names of the singers who are not ...,"[SELECT, Name, FROM, singer, WHERE, Citizenshi...","[select, name, from, singer, where, citizenshi...","[What, are, the, names, of, the, singers, who,...",singer,,singer_X_singer
214,1008,singer,SELECT Name FROM singer WHERE Birth_Year = 1...,Show the name of singers whose birth year is e...,"[SELECT, Name, FROM, singer, WHERE, Birth_Year...","[select, name, from, singer, where, birth_year...","[Show, the, name, of, singers, whose, birth, y...",singer,,singer_X_singer


We now need to get the answers to our queries by executing them.

In [None]:
#@title pipeline
import sqlite3
from contextlib import closing
from itertools import chain

import pandas as pd

base_path = '/content/drive/MyDrive/spider/test_database/'

rows = []
for _, row in simple_df.iterrows():
    db_id = row['db_id']
    db_path = f'{base_path}{db_id}/{db_id}.sqlite'

    # sqlite3's connection context manager only commits or rolls back;
    # closing() makes sure the connection is actually closed.
    with closing(sqlite3.connect(db_path)) as conn:
        answers = conn.execute(row['query']).fetchall()

    # Flatten the list of result tuples into a single list of cell values.
    list_answers = list(chain.from_iterable(answers))
    rows.append({'ID': row['ID'], 'answer_text': list_answers})

# Build the DataFrame once instead of calling pd.concat inside the loop.
results_df = pd.DataFrame(rows, columns=['ID', 'answer_text'])
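`fetchall()` returns one tuple per result row, and `chain.from_iterable` flattens those tuples into the single list stored in `answer_text`. A toy in-memory example of that step is below. Note that flattening discards row boundaries; this is presumably acceptable here because the cell coordinates are recovered afterwards from the answer texts.

```python
import sqlite3
from contextlib import closing
from itertools import chain

with closing(sqlite3.connect(":memory:")) as conn:
    conn.execute("CREATE TABLE singer (name TEXT, country TEXT, age INTEGER)")
    conn.executemany(
        "INSERT INTO singer VALUES (?, ?, ?)",
        [("Singer A", "Netherlands", 52), ("Singer B", "France", 43)],
    )
    rows = conn.execute(
        "SELECT name, country, age FROM singer ORDER BY age DESC"
    ).fetchall()

flat = list(chain.from_iterable(rows))
print(rows)  # [('Singer A', 'Netherlands', 52), ('Singer B', 'France', 43)]
print(flat)  # ['Singer A', 'Netherlands', 52, 'Singer B', 'France', 43]
```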

In [None]:
#@title pipeline
results_df

Unnamed: 0,ID,answer_text
0,0,[6]
1,1,[6]
2,2,"[Joe Sharp, Netherlands, 52, John Nizinik, Fra..."
3,3,"[Joe Sharp, Netherlands, 52, John Nizinik, Fra..."
4,8,"[Netherlands, United States, France]"
...,...,...
211,1005,"[1944.0, France, 1948.0, United States, 1949.0..."
212,1006,"[Christy Walton, Alice Walton, Iris Fontbona, ..."
213,1007,"[Christy Walton, Alice Walton, Iris Fontbona, ..."
214,1008,"[Christy Walton, Alice Walton]"


In [None]:
#@title pipeline
merged_df = pd.merge(results_df, simple_df, on='ID')
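# Drop queries that returned no rows; .copy() avoids a SettingWithCopyWarning
# when new columns are added to merged_df later on.
merged_df = merged_df[merged_df['answer_text'].apply(lambda x: x != [])].copy()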
merged_df

Unnamed: 0,ID,answer_text,db_id,query,question,query_toks,query_toks_no_value,question_toks,table_used,operator,seq_id
0,0,[6],concert_singer,SELECT count(*) FROM singer,How many singers do we have?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[How, many, singers, do, we, have, ?]",singer,count,concert_singer_X_singer
1,1,[6],concert_singer,SELECT count(*) FROM singer,What is the total number of singers?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[What, is, the, total, number, of, singers, ?]",singer,count,concert_singer_X_singer
2,2,"[Joe Sharp, Netherlands, 52, John Nizinik, Fra...",concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[Show, name, ,, country, ,, age, for, all, sin...",singer,,concert_singer_X_singer
3,3,"[Joe Sharp, Netherlands, 52, John Nizinik, Fra...",concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[What, are, the, names, ,, countries, ,, and, ...",singer,,concert_singer_X_singer
4,8,"[Netherlands, United States, France]",concert_singer,SELECT DISTINCT country FROM singer WHERE age ...,What are all distinct countries where singers ...,"[SELECT, DISTINCT, country, FROM, singer, WHER...","[select, distinct, country, from, singer, wher...","[What, are, all, distinct, countries, where, s...",singer,,concert_singer_X_singer
...,...,...,...,...,...,...,...,...,...,...,...
211,1005,"[1944.0, France, 1948.0, United States, 1949.0...",singer,"SELECT Birth_Year , Citizenship FROM singer",What are the birth years and citizenships of t...,"[SELECT, Birth_Year, ,, Citizenship, FROM, sin...","[select, birth_year, ,, citizenship, from, sin...","[What, are, the, birth, years, and, citizenshi...",singer,,singer_X_singer
212,1006,"[Christy Walton, Alice Walton, Iris Fontbona, ...",singer,"SELECT Name FROM singer WHERE Citizenship != ""...",List the name of singers whose citizenship is ...,"[SELECT, Name, FROM, singer, WHERE, Citizenshi...","[select, name, from, singer, where, citizenshi...","[List, the, name, of, singers, whose, citizens...",singer,,singer_X_singer
213,1007,"[Christy Walton, Alice Walton, Iris Fontbona, ...",singer,"SELECT Name FROM singer WHERE Citizenship != ""...",What are the names of the singers who are not ...,"[SELECT, Name, FROM, singer, WHERE, Citizenshi...","[select, name, from, singer, where, citizenshi...","[What, are, the, names, of, the, singers, who,...",singer,,singer_X_singer
214,1008,"[Christy Walton, Alice Walton]",singer,SELECT Name FROM singer WHERE Birth_Year = 1...,Show the name of singers whose birth year is e...,"[SELECT, Name, FROM, singer, WHERE, Birth_Year...","[select, name, from, singer, where, birth_year...","[Show, the, name, of, singers, whose, birth, y...",singer,,singer_X_singer
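The merge relies on `ID` being unique in both frames. As a minimal sketch on toy data, pandas can check that with `validate="one_to_one"` (which raises `MergeError` on duplicates), and `.str.len()` also works on list-valued cells, so empty answers can be filtered without a lambda.

```python
import pandas as pd

results = pd.DataFrame({"ID": [0, 1, 2], "answer_text": [[6], [], ["Netherlands"]]})
queries = pd.DataFrame({"ID": [0, 1, 2], "query": ["q0", "q1", "q2"]})

# Raises pandas.errors.MergeError if an ID is duplicated on either side.
merged = pd.merge(results, queries, on="ID", validate="one_to_one")

# Keep only rows whose answer list has at least one element.
non_empty = merged[merged["answer_text"].str.len() > 0].copy()
print(non_empty["ID"].tolist())  # [0, 2]
```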


In [None]:
merged_df_cleaned = merged_df[merged_df['answer_text'].apply(lambda x: x != [])]

In [None]:
merged_df_cleaned

Unnamed: 0,ID,answer_text,db_id,query,question,query_toks,query_toks_no_value,question_toks,table_used,operator,seq_id
0,0,[6],concert_singer,SELECT count(*) FROM singer,How many singers do we have?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[How, many, singers, do, we, have, ?]",singer,count,concert_singer_X_singer
1,1,[6],concert_singer,SELECT count(*) FROM singer,What is the total number of singers?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[What, is, the, total, number, of, singers, ?]",singer,count,concert_singer_X_singer
2,2,"[Joe Sharp, Netherlands, 52, John Nizinik, Fra...",concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[Show, name, ,, country, ,, age, for, all, sin...",singer,,concert_singer_X_singer
3,3,"[Joe Sharp, Netherlands, 52, John Nizinik, Fra...",concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[What, are, the, names, ,, countries, ,, and, ...",singer,,concert_singer_X_singer
4,8,"[Netherlands, United States, France]",concert_singer,SELECT DISTINCT country FROM singer WHERE age ...,What are all distinct countries where singers ...,"[SELECT, DISTINCT, country, FROM, singer, WHER...","[select, distinct, country, from, singer, wher...","[What, are, all, distinct, countries, where, s...",singer,,concert_singer_X_singer
...,...,...,...,...,...,...,...,...,...,...,...
211,1005,"[1944.0, France, 1948.0, United States, 1949.0...",singer,"SELECT Birth_Year , Citizenship FROM singer",What are the birth years and citizenships of t...,"[SELECT, Birth_Year, ,, Citizenship, FROM, sin...","[select, birth_year, ,, citizenship, from, sin...","[What, are, the, birth, years, and, citizenshi...",singer,,singer_X_singer
212,1006,"[Christy Walton, Alice Walton, Iris Fontbona, ...",singer,"SELECT Name FROM singer WHERE Citizenship != ""...",List the name of singers whose citizenship is ...,"[SELECT, Name, FROM, singer, WHERE, Citizenshi...","[select, name, from, singer, where, citizenshi...","[List, the, name, of, singers, whose, citizens...",singer,,singer_X_singer
213,1007,"[Christy Walton, Alice Walton, Iris Fontbona, ...",singer,"SELECT Name FROM singer WHERE Citizenship != ""...",What are the names of the singers who are not ...,"[SELECT, Name, FROM, singer, WHERE, Citizenshi...","[select, name, from, singer, where, citizenshi...","[What, are, the, names, of, the, singers, who,...",singer,,singer_X_singer
214,1008,"[Christy Walton, Alice Walton]",singer,SELECT Name FROM singer WHERE Birth_Year = 1...,Show the name of singers whose birth year is e...,"[SELECT, Name, FROM, singer, WHERE, Birth_Year...","[select, name, from, singer, where, birth_year...","[Show, the, name, of, singers, whose, birth, y...",singer,,singer_X_singer


In [None]:
testSpider=pd.read_pickle('/content/drive/MyDrive/spider/SPIDER_Simple_cleaned_final_valid.pkl')

In [None]:
testSpider

Unnamed: 0,ID,answer_text,db_id,query,question,query_toks,query_toks_no_value,question_toks,table_used,operator,seq_id
0,0,[6],concert_singer,SELECT count(*) FROM singer,How many singers do we have?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[How, many, singers, do, we, have, ?]",singer,count,concert_singer_X_singer
1,1,[6],concert_singer,SELECT count(*) FROM singer,What is the total number of singers?,"[SELECT, count, (, *, ), FROM, singer]","[select, count, (, *, ), from, singer]","[What, is, the, total, number, of, singers, ?]",singer,count,concert_singer_X_singer
2,2,"[Joe Sharp, Netherlands, 52, John Nizinik, Fra...",concert_singer,"SELECT name , country , age FROM singer ORDE...","Show name, country, age for all singers ordere...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[Show, name, ,, country, ,, age, for, all, sin...",singer,,concert_singer_X_singer
3,3,"[Joe Sharp, Netherlands, 52, John Nizinik, Fra...",concert_singer,"SELECT name , country , age FROM singer ORDE...","What are the names, countries, and ages for ev...","[SELECT, name, ,, country, ,, age, FROM, singe...","[select, name, ,, country, ,, age, from, singe...","[What, are, the, names, ,, countries, ,, and, ...",singer,,concert_singer_X_singer
4,8,"[Netherlands, United States, France]",concert_singer,SELECT DISTINCT country FROM singer WHERE age ...,What are all distinct countries where singers ...,"[SELECT, DISTINCT, country, FROM, singer, WHER...","[select, distinct, country, from, singer, wher...","[What, are, all, distinct, countries, where, s...",singer,,concert_singer_X_singer
...,...,...,...,...,...,...,...,...,...,...,...
211,1005,"[1944.0, France, 1948.0, United States, 1949.0...",singer,"SELECT Birth_Year , Citizenship FROM singer",What are the birth years and citizenships of t...,"[SELECT, Birth_Year, ,, Citizenship, FROM, sin...","[select, birth_year, ,, citizenship, from, sin...","[What, are, the, birth, years, and, citizenshi...",singer,,singer_X_singer
212,1006,"[Christy Walton, Alice Walton, Iris Fontbona, ...",singer,"SELECT Name FROM singer WHERE Citizenship != ""...",List the name of singers whose citizenship is ...,"[SELECT, Name, FROM, singer, WHERE, Citizenshi...","[select, name, from, singer, where, citizenshi...","[List, the, name, of, singers, whose, citizens...",singer,,singer_X_singer
213,1007,"[Christy Walton, Alice Walton, Iris Fontbona, ...",singer,"SELECT Name FROM singer WHERE Citizenship != ""...",What are the names of the singers who are not ...,"[SELECT, Name, FROM, singer, WHERE, Citizenshi...","[select, name, from, singer, where, citizenshi...","[What, are, the, names, of, the, singers, who,...",singer,,singer_X_singer
214,1008,"[Christy Walton, Alice Walton]",singer,SELECT Name FROM singer WHERE Birth_Year = 1...,Show the name of singers whose birth year is e...,"[SELECT, Name, FROM, singer, WHERE, Birth_Year...","[select, name, from, singer, where, birth_year...","[Show, the, name, of, singers, whose, birth, y...",singer,,singer_X_singer


In [None]:
import pickle

with open('/content/drive/MyDrive/spider/SPIDER_Simple_cleaned_final_valid.pkl', 'wb') as file:
    pickle.dump(merged_df_cleaned, file)
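Pickle suits this step because `answer_text` holds Python lists: a pickle round trip keeps them as lists, whereas a CSV round trip would turn them into strings. `DataFrame.to_pickle` is the pandas counterpart of the `pickle.dump` call used above. A toy comparison:

```python
import os
import tempfile

import pandas as pd

df = pd.DataFrame({"ID": [0], "answer_text": [["a", 52]]})

with tempfile.TemporaryDirectory() as tmp:
    pkl_path = os.path.join(tmp, "answers.pkl")
    csv_path = os.path.join(tmp, "answers.csv")

    df.to_pickle(pkl_path)
    df.to_csv(csv_path, index=False)

    from_pickle = pd.read_pickle(pkl_path)
    from_csv = pd.read_csv(csv_path)

print(type(from_pickle.loc[0, "answer_text"]))  # <class 'list'>
print(type(from_csv.loc[0, "answer_text"]))     # <class 'str'>
```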

In [None]:
#@title pipeline
# merged_df.columns

Here I use a wrapper around the function we imported earlier that calculates the answer coordinates: `parse_question(table=table_data, question=question, answer_texts=answer_texts)`.

These are the IDs of queries that turned out to be problematic because the answer and the table do not fit within the sequence length limit, so we omit them here.

In [None]:
#@title pipeline
# prob=[5, 294, 1290, 1500, 1977, 2035, 2526, 3774, 3775, 3776, 3777, 3796, 3797, 3798, 4168, 4169, 4170, 4184, 4185, 4186, 5432, 5522, 5523, 5651, 5860, 6177,15, 2036, 2527, 5433, 6178,6193,6194]
# merged_df = merged_df[~merged_df['ID'].isin(prob)]
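Rather than hard-coding the problematic IDs, they could be detected up front. The sketch below takes a pluggable `encode` callable; the function names and the toy encoder are hypothetical. With Hugging Face's `TapasTokenizer`, one might pass something like `lambda t, q: tokenizer(table=t, queries=q)["input_ids"]` (with the table as a DataFrame of strings), but that wiring is an assumption, not code from this notebook. The 512-token default is likewise an assumed limit.

```python
def find_overlong_ids(rows, encode, max_length=512):
    """Return the IDs whose encoded (table, question) pair exceeds max_length.

    `rows` is an iterable of (ID, table, question) triples and `encode` is any
    callable returning a list of token ids; both are placeholders here.
    """
    return [row_id for row_id, table, question in rows
            if len(encode(table, question)) > max_length]


def toy_encode(table, question):
    # Toy stand-in for a real tokenizer: one "token" per question word
    # and one per table cell.
    cells = [str(cell) for row in table for cell in row]
    return question.split() + cells


rows = [
    (101, [["a", "b"]] * 3, "short question"),            # 2 + 6 = 8 tokens
    (102, [["x"] * 10] * 10, "another short question"),   # 3 + 100 = 103 tokens
]
print(find_overlong_ids(rows, toy_encode, max_length=50))  # [102]
```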


In [None]:
#@title get_answer_coordinates Old
# def get_answer_coordinates(row):
#     seq_id= row['seq_id']
#     table_data = db_table_to_df.get(seq_id)

#     if table_data is not None:
#         question = row['question']
#         # Ensure answer_texts are strings, even if they are initially integers
#         answer_texts = [str(ans) for ans in row['answer_text']]

#         try:
#             _, new_answer_texts, answer_coordinates, float_value, _ = parse_question(
#                 table=table_data, question=question, answer_texts=answer_texts)
#             return answer_coordinates
#         except Exception as e:
#             print(f"Error processing row: {e}")
#             return None
#     else:
#         return None

In [None]:
#@title pipeline
problematic_rows_ids=[]
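def get_answer_coordinates(row):
    """Return (answer_coordinates, new_answer_texts, float_value) for one row.

    Looks up the row's table in db_table_to_df via seq_id and calls
    parse_question on it; if the table is missing or parsing fails, the row
    ID is appended to problematic_rows_ids and (None, None, None) is returned.
    """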
    seq_id = row['seq_id']
    table_data = db_table_to_df.get(seq_id)

    if table_data is not None:
        question = row['question']
        answer_texts = [str(ans) for ans in row['answer_text']]

        try:
            _, new_answer_texts, answer_coordinates, float_value, _ = parse_question(
                table=table_data, question=question, answer_texts=answer_texts)
            return answer_coordinates, new_answer_texts, float_value
        except Exception as e:
            problematic_rows_ids.append(row['ID'])
            print(f"Error processing row: {e}")
            return None, None, None
    else:
        problematic_rows_ids.append(row['ID'])
        return None, None, None
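TAPAS expects answer coordinates as `(row_index, column_index)` pairs pointing at table cells. To illustrate what the wrapper returns, here is a deliberately simplified, hypothetical stand-in for `parse_question` that only does exact string matching; the real function also handles numeric answers (hence its `float_value` output), which this sketch does not.

```python
import pandas as pd


def find_answer_coordinates(table, answer_texts):
    """Return (row_index, column_index) pairs of cells whose text equals an answer.

    Hypothetical simplification: exact string match only, no number
    normalisation, and every matching cell is returned.
    """
    targets = {str(a) for a in answer_texts}
    return [
        (r, c)
        for r in range(len(table))
        for c in range(len(table.columns))
        if str(table.iat[r, c]) in targets
    ]


table = pd.DataFrame({"name": ["A", "B", "C"],
                      "country": ["France", "Chile", "France"]})
print(find_answer_coordinates(table, ["France"]))  # [(0, 1), (2, 1)]
```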

In [None]:
#@title pipeline
# merged_df['answer_coordinates'] = merged_df.apply(get_answer_coordinates, axis=1)
# Applying the function and creating new columns
result = merged_df.apply(get_answer_coordinates, axis=1, result_type='expand')
merged_df['answer_coordinates'], merged_df['float_value'] = result[0], result[2]

Error processing row: Cannot parse answer: [float_value: Cannot convert to multiple answers to single float]
Error processing row: name '_split_thousands' is not defined
Error processing row: name '_split_thousands' is not defined
Error processing row: name '_split_thousands' is not defined
Error processing row: name '_split_thousands' is not defined
Error processing row: Cannot parse answer: 
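With `result_type='expand'`, each 3-tuple returned by `get_answer_coordinates` becomes three columns labelled 0, 1 and 2. That is why the cell above reads `result[0]` (coordinates) and `result[2]` (float value) and ignores `result[1]` (the normalised answer texts). A toy example with a function of the same shape:

```python
import pandas as pd

df = pd.DataFrame({"ID": [0, 1], "x": [2, 3]})


def three_values(row):
    # Same shape as get_answer_coordinates: (coordinates, texts, float value).
    return [(row["x"], 0)], row["x"] * 10, float(row["x"])


expanded = df.apply(three_values, axis=1, result_type="expand")
print(list(expanded.columns))  # [0, 1, 2]

df["answer_coordinates"] = expanded[0]
df["float_value"] = expanded[2]
print(df["float_value"].tolist())  # [2.0, 3.0]
```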


In [None]:
#@title pipeline
# merged_df

Drop the queries that could not be parsed (e.g. those with empty answers); their IDs were collected in `problematic_rows_ids`.

In [None]:
#@title pipeline
# problematic_rows = merged_df[merged_df['ID'].isin(problematic_rows_ids)].copy()
# problematic_rows.head(50)

In [None]:
#@title pipeline
merged_df_cleaned = merged_df[~merged_df['ID'].isin(problematic_rows_ids)].copy()

In [None]:
merged_df_cleaned

Unnamed: 0,ID,answer_text,db_id,query,question,query_toks,query_toks_no_value,question_toks,table_used,operator,seq_id,answer_coordinates,float_value
0,0,[5],department_management,SELECT count(*) FROM head WHERE age > 56,How many heads of the departments are older th...,"[SELECT, count, (, *, ), FROM, head, WHERE, ag...","[select, count, (, *, ), from, head, where, ag...","[How, many, heads, of, the, departments, are, ...",head,count,department_management_X_head,"[(4, 0)]",5.000000
1,1,"[Pádraig Harrington, Connecticut, 43.0, Stewar...",department_management,"SELECT name , born_state , age FROM head ORD...","List the name, born state and age of the heads...","[SELECT, name, ,, born_state, ,, age, FROM, he...","[select, name, ,, born_state, ,, age, from, he...","[List, the, name, ,, born, state, and, age, of...",head,,department_management_X_head,"[(8, 1), (8, 2), (8, 3), (6, 1), (6, 2), (6, 3...",
2,2,"[1789, State, 9.96, 1789, Treasury, 11.1, 1947...",department_management,"SELECT creation , name , budget_in_billions ...","List the creation year, name and budget of eac...","[SELECT, creation, ,, name, ,, budget_in_billi...","[select, creation, ,, name, ,, budget_in_billi...","[List, the, creation, year, ,, name, and, budg...",department,,department_management_X_department,"[(0, 2), (0, 1), (0, 4), (1, 2), (1, 1), (1, 4...",
3,4,[105468.16666666667],department_management,SELECT avg(num_employees) FROM department WHER...,What is the average number of employees of the...,"[SELECT, avg, (, num_employees, ), FROM, depar...","[select, avg, (, num_employees, ), from, depar...","[What, is, the, average, number, of, employees...",department,avg,department_management_X_department,,105468.166667
4,5,"[Tiger Woods, K. J. Choi, Jeff Maggert, Stewar...",department_management,SELECT name FROM head WHERE born_state != 'Cal...,What are the names of the heads who are born o...,"[SELECT, name, FROM, head, WHERE, born_state, ...","[select, name, from, head, where, born_state, ...","[What, are, the, names, of, the, heads, who, a...",head,,department_management_X_head,"[(0, 1), (2, 1), (4, 1), (6, 1), (8, 1), (9, 1)]",
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1689,6983,"[Jill Rips, 2000, Anthony Hickox, Storm Catche...",culture_company,"SELECT title , YEAR , director FROM movie OR...","What are the titles, years, and directors of a...","[SELECT, title, ,, YEAR, ,, director, FROM, mo...","[select, title, ,, year, ,, director, from, mo...","[What, are, the, titles, ,, years, ,, and, dir...",movie,,culture_company_X_movie,"[(3, 1), (3, 2), (2, 3), (2, 1), (0, 2), (3, 3...",
1690,6984,[9],culture_company,SELECT COUNT (DISTINCT director) FROM movie,How many movie directors are there?,"[SELECT, COUNT, (, DISTINCT, director, ), FROM...","[select, count, (, distinct, director, ), from...","[How, many, movie, directors, are, there, ?]",movie,count,culture_company_X_movie,"[(8, 0)]",9.000000
1691,6985,[9],culture_company,SELECT COUNT (DISTINCT director) FROM movie,Count the number of different directors.,"[SELECT, COUNT, (, DISTINCT, director, ), FROM...","[select, count, (, distinct, director, ), from...","[Count, the, number, of, different, directors, .]",movie,count,culture_company_X_movie,"[(8, 0)]",9.000000
1692,6990,"[Troy Duffy, John Swanbeck, Anthony Hickox, An...",culture_company,SELECT director FROM movie WHERE YEAR = 1999...,Show all director names who have a movie in th...,"[SELECT, director, FROM, movie, WHERE, YEAR, =...","[select, director, from, movie, where, year, =...","[Show, all, director, names, who, have, a, mov...",movie,,culture_company_X_movie,"[(0, 3), (1, 3), (2, 3), (3, 3), (4, 3), (5, 3...",


In [None]:
#@title pipeline
# merged_df_cleaned

I noticed that the model performs better when the questions for the same table are given in sequence, so we assign each question a random position within its table's sequence.
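A position inside each table's sequence can be computed with `groupby(...).cumcount()`, the same call as in the commented-out `position` line later in the notebook. A minimal sketch on toy data, assuming the rows are first shuffled (seeded, for reproducibility) so the order inside each sequence is random, and then sorted by `seq_id` with a stable sort so questions about the same table stay contiguous:

```python
import pandas as pd

# Toy questions; seq_id identifies the table each question is asked about
toy_df = pd.DataFrame({
    'seq_id': ['db_X_head', 'db_X_dept', 'db_X_head', 'db_X_head', 'db_X_dept'],
    'question': ['q1', 'q2', 'q3', 'q4', 'q5'],
})

# Shuffle the rows (seeded) so the order within each sequence is random
toy_shuffled = toy_df.sample(frac=1, random_state=0)
# A stable sort groups rows by seq_id while keeping the shuffled order inside each group
toy_shuffled = toy_shuffled.sort_values('seq_id', kind='stable').reset_index(drop=True)
# Position of each question inside its sequence: 0, 1, 2, ...
toy_shuffled['position'] = toy_shuffled.groupby('seq_id').cumcount()
print(toy_shuffled)
```

The stable sort matters here: a non-stable sort could reorder rows within a `seq_id` and undo part of the shuffle.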

In [None]:
#@title pipeline
# Drop rows for which no answer coordinates could be found
merged_df_cleaned = merged_df_cleaned.dropna(subset=['answer_coordinates'])
merged_df_cleaned

Unnamed: 0,ID,answer_text,db_id,query,question,query_toks,query_toks_no_value,question_toks,table_used,operator,seq_id,answer_coordinates,float_value
0,0,[5],department_management,SELECT count(*) FROM head WHERE age > 56,How many heads of the departments are older th...,"[SELECT, count, (, *, ), FROM, head, WHERE, ag...","[select, count, (, *, ), from, head, where, ag...","[How, many, heads, of, the, departments, are, ...",head,count,department_management_X_head,"[(4, 0)]",5.0
1,1,"[Pádraig Harrington, Connecticut, 43.0, Stewar...",department_management,"SELECT name , born_state , age FROM head ORD...","List the name, born state and age of the heads...","[SELECT, name, ,, born_state, ,, age, FROM, he...","[select, name, ,, born_state, ,, age, from, he...","[List, the, name, ,, born, state, and, age, of...",head,,department_management_X_head,"[(8, 1), (8, 2), (8, 3), (6, 1), (6, 2), (6, 3...",
2,2,"[1789, State, 9.96, 1789, Treasury, 11.1, 1947...",department_management,"SELECT creation , name , budget_in_billions ...","List the creation year, name and budget of eac...","[SELECT, creation, ,, name, ,, budget_in_billi...","[select, creation, ,, name, ,, budget_in_billi...","[List, the, creation, year, ,, name, and, budg...",department,,department_management_X_department,"[(0, 2), (0, 1), (0, 4), (1, 2), (1, 1), (1, 4...",
4,5,"[Tiger Woods, K. J. Choi, Jeff Maggert, Stewar...",department_management,SELECT name FROM head WHERE born_state != 'Cal...,What are the names of the heads who are born o...,"[SELECT, name, FROM, head, WHERE, born_state, ...","[select, name, from, head, where, born_state, ...","[What, are, the, names, of, the, heads, who, a...",head,,department_management_X_head,"[(0, 1), (2, 1), (4, 1), (6, 1), (8, 1), (9, 1)]",
5,7,[California],department_management,SELECT born_state FROM head GROUP BY born_stat...,What are the names of the states where at leas...,"[SELECT, born_state, FROM, head, GROUP, BY, bo...","[select, born_state, from, head, group, by, bo...","[What, are, the, names, of, the, states, where...",head,,department_management_X_head,"[(1, 2)]",
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1689,6983,"[Jill Rips, 2000, Anthony Hickox, Storm Catche...",culture_company,"SELECT title , YEAR , director FROM movie OR...","What are the titles, years, and directors of a...","[SELECT, title, ,, YEAR, ,, director, FROM, mo...","[select, title, ,, year, ,, director, from, mo...","[What, are, the, titles, ,, years, ,, and, dir...",movie,,culture_company_X_movie,"[(3, 1), (3, 2), (2, 3), (2, 1), (0, 2), (3, 3...",
1690,6984,[9],culture_company,SELECT COUNT (DISTINCT director) FROM movie,How many movie directors are there?,"[SELECT, COUNT, (, DISTINCT, director, ), FROM...","[select, count, (, distinct, director, ), from...","[How, many, movie, directors, are, there, ?]",movie,count,culture_company_X_movie,"[(8, 0)]",9.0
1691,6985,[9],culture_company,SELECT COUNT (DISTINCT director) FROM movie,Count the number of different directors.,"[SELECT, COUNT, (, DISTINCT, director, ), FROM...","[select, count, (, distinct, director, ), from...","[Count, the, number, of, different, directors, .]",movie,count,culture_company_X_movie,"[(8, 0)]",9.0
1692,6990,"[Troy Duffy, John Swanbeck, Anthony Hickox, An...",culture_company,SELECT director FROM movie WHERE YEAR = 1999...,Show all director names who have a movie in th...,"[SELECT, director, FROM, movie, WHERE, YEAR, =...","[select, director, from, movie, where, year, =...","[Show, all, director, names, who, have, a, mov...",movie,,culture_company_X_movie,"[(0, 3), (1, 3), (2, 3), (3, 3), (4, 3), (5, 3...",


In [None]:
import pickle

with open('/content/drive/MyDrive/spider/SPIDER_Simple_cleaned_final_valid.pkl', 'wb') as file:
    pickle.dump(simple_df, file)

In [None]:
import pickle

with open('/content/drive/MyDrive/spider/SPIDER_Simple_cleaned_final.pkl', 'wb') as file:
    pickle.dump(merged_df_cleaned, file)
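Pickle suits these DataFrames because columns such as `answer_coordinates` hold Python objects (lists of `(row, column)` tuples). Pickle round-trips them unchanged, whereas CSV stores only their string representation (the commented-out CSV loader later in the notebook needs a `json.loads` converter to rebuild `answer_text`). A self-contained check on one toy row:

```python
import os
import tempfile
import pandas as pd

# One toy row with the same kinds of objects as the real data
toy_df = pd.DataFrame({'ID': [0], 'answer_text': [[5]], 'answer_coordinates': [[(4, 0)]]})

with tempfile.TemporaryDirectory() as tmp_dir:
    pkl_path = os.path.join(tmp_dir, 'toy.pkl')
    csv_path = os.path.join(tmp_dir, 'toy.csv')
    toy_df.to_pickle(pkl_path)
    toy_df.to_csv(csv_path, index=False)
    toy_from_pickle = pd.read_pickle(pkl_path)
    toy_from_csv = pd.read_csv(csv_path)

# Pickle restores the list of tuples; CSV gives back its string representation
print(toy_from_pickle.loc[0, 'answer_coordinates'])  # [(4, 0)]
print(repr(toy_from_csv.loc[0, 'answer_coordinates']))  # '[(4, 0)]'
```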

In [None]:
#@title Cleaning QATCH
db_ids_train_spider= set(instance['db_id'] for instance in train_set)


from qatch.database_reader import MultipleDatabases

# The path to multiple databases
db_save_path = '/content/drive/MyDrive/spider/test_database'
databases = MultipleDatabases(db_save_path)


from qatch import TestGenerator

# init generator
test_generator = TestGenerator(databases=databases)

# generate tests for each database and for each generator
tests_df = test_generator.generate()


# Create a mask where each row is True if 'db_id' is in db_ids_train_spider, else False
mask = tests_df['db_id'].isin(db_ids_train_spider)

# Apply the mask to filter the DataFrame
tests_df = tests_df[mask]

unique_db_ids=tests_df['db_id'].unique()


tests_df = tests_df[tests_df.apply(lambda row: row['tbl_name'] in db_to_tables_constraint.get(row['db_id'], []), axis=1)]

tests_df=tests_df.rename(columns={'tbl_name':'table_used'})


# qatch_synthetic=tests_df.drop(columns=["sql_tags"], axis=1)
qatch_synthetic=tests_df.copy()
qatch_synthetic['seq_id'] = qatch_synthetic.apply(lambda row: f"{row['db_id']}_X_{row['table_used']}", axis=1)


qatch_synthetic.reset_index(inplace=True)
qatch_synthetic.rename(columns={'index': 'ID'}, inplace=True)


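# Execute every QATCH query against its SQLite database and store the flattened result.
# Rows are collected in a list and turned into a DataFrame once at the end, which avoids
# the quadratic cost of calling pd.concat inside the loop.
import sqlite3
from contextlib import closing
from itertools import chain

base_path = '/content/drive/MyDrive/spider/test_database/'

answer_rows = []
for _, row in qatch_synthetic.iterrows():
    db_id = row['db_id']
    db_path = f'{base_path}{db_id}/{db_id}.sqlite'

    # closing() guarantees the connection is closed; sqlite3's own context
    # manager only commits or rolls back the transaction.
    with closing(sqlite3.connect(db_path)) as conn:
        answers = conn.execute(row['query']).fetchall()

    # Flatten the result rows into a single list of values, row by row
    answer_rows.append({'ID': row['ID'], 'answer_text': list(chain.from_iterable(answers))})

results_df_qatch = pd.DataFrame(answer_rows, columns=['ID', 'answer_text'])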


merged_df_qatch = pd.merge(results_df_qatch, qatch_synthetic, on='ID')
merged_df_qatch = merged_df_qatch[merged_df_qatch['answer_text'].apply(lambda x: x != [])]


problematic_rows_ids=[]


result = merged_df_qatch.apply(get_answer_coordinates, axis=1, result_type='expand')
merged_df_qatch['answer_coordinates'],merged_df_qatch['float_value'] = result[0],result[2]

merged_df_cleaned_qatch = merged_df_qatch[~merged_df_qatch['ID'].isin(problematic_rows_ids)].copy()

merged_df_cleaned_qatch = merged_df_cleaned_qatch.dropna(subset=['answer_coordinates'])
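The core of the QATCH cell is: run each SQL query on its SQLite file and flatten the returned rows into one list with `itertools.chain.from_iterable`. Flattening discards the row/column structure of the result, which is why the answer coordinates have to be recovered separately afterwards. A self-contained sketch on an in-memory database; `run_and_flatten` is a hypothetical helper and the table and values are made up:

```python
import sqlite3
from contextlib import closing
from itertools import chain

def run_and_flatten(conn, query):
    """Execute `query` and flatten the result rows into one list, row by row."""
    rows = conn.execute(query).fetchall()
    return list(chain.from_iterable(rows))

with closing(sqlite3.connect(':memory:')) as toy_conn:
    toy_conn.execute('CREATE TABLE head (name TEXT, age INTEGER)')
    toy_conn.executemany('INSERT INTO head VALUES (?, ?)',
                         [('Alice', 67), ('Bob', 69), ('Carol', 53)])
    toy_answers = run_and_flatten(toy_conn, 'SELECT name, age FROM head WHERE age > 56 ORDER BY age')

print(toy_answers)  # ['Alice', 67, 'Bob', 69]
```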

In [None]:
#@title pipeline
# float_df = merged_df_test.dropna(subset=['float_value'])
# float_df.head(50)

In [None]:
merged_df_cleaned = pd.read_pickle('/content/drive/MyDrive/spider/cleaned_SPIDER_train.pkl')
merged_df_cleaned

Unnamed: 0,ID,answer_text,db_id,query,question,table_used,seq_id,answer_coordinates,float_value
0,0,[5],department_management,SELECT count(*) FROM head WHERE age > 56,How many heads of the departments are older th...,head,department_management_X_head,"[(4, 0)]",5.0
1,1,"[Pádraig Harrington, Connecticut, 43.0, Stewar...",department_management,"SELECT name , born_state , age FROM head ORD...","List the name, born state and age of the heads...",head,department_management_X_head,"[(8, 1), (8, 2), (8, 3), (6, 1), (6, 2), (6, 3...",
2,2,"[1789, State, 9.96, 1789, Treasury, 11.1, 1947...",department_management,"SELECT creation , name , budget_in_billions ...","List the creation year, name and budget of eac...",department,department_management_X_department,"[(0, 2), (0, 1), (0, 4), (1, 2), (1, 1), (1, 4...",
4,5,"[Tiger Woods, K. J. Choi, Jeff Maggert, Stewar...",department_management,SELECT name FROM head WHERE born_state != 'Cal...,What are the names of the heads who are born o...,head,department_management_X_head,"[(0, 1), (2, 1), (4, 1), (6, 1), (8, 1), (9, 1)]",
5,7,[California],department_management,SELECT born_state FROM head GROUP BY born_stat...,What are the names of the states where at leas...,head,department_management_X_head,"[(1, 2)]",
...,...,...,...,...,...,...,...,...,...
2454,6985,[9],culture_company,SELECT COUNT (DISTINCT director) FROM movie,Count the number of different directors.,movie,culture_company_X_movie,"[(8, 0)]",9.0
2455,6986,"[The Whole Nine Yards, Jonathan Lynn]",culture_company,"SELECT title , director FROM movie WHERE YEAR...",What is the title and director for the movie w...,movie,culture_company_X_movie,"[(4, 1), (4, 3)]",
2456,6987,"[The Whole Nine Yards, Jonathan Lynn]",culture_company,"SELECT title , director FROM movie WHERE YEAR...",Return the title and director of the movie rel...,movie,culture_company_X_movie,"[(4, 1), (4, 3)]",
2457,6990,"[Troy Duffy, John Swanbeck, Anthony Hickox, An...",culture_company,SELECT director FROM movie WHERE YEAR = 1999...,Show all director names who have a movie in th...,movie,culture_company_X_movie,"[(0, 3), (1, 3), (2, 3), (3, 3), (4, 3), (5, 3...",


In [None]:
# Keep only the rows where 'float_value' is not NaN (questions with a numeric answer)
float_df = merged_df_cleaned[merged_df_cleaned['float_value'].notna()]

float_df.head(50)

Unnamed: 0,ID,answer_text,db_id,query,question,table_used,seq_id,answer_coordinates,float_value
0,0,[5],department_management,SELECT count(*) FROM head WHERE age > 56,How many heads of the departments are older th...,head,department_management_X_head,"[(4, 0)]",5.0
6,8,[1789],department_management,SELECT creation FROM department GROUP BY creat...,In which year were most departments established?,department,department_management_X_department,"[(0, 2)]",1789.0
7,10,[2],department_management,SELECT count(DISTINCT temporary_acting) FROM m...,How many acting statuses are there?,management,department_management_X_management,"[(0, 0)]",2.0
19,28,[2],farm,SELECT count(DISTINCT Status) FROM city,How many different statuses do cities have?,city,farm_X_city,"[(1, 0)]",2.0
20,29,[2],farm,SELECT count(DISTINCT Status) FROM city,Count the number of different statuses.,city,farm_X_city,"[(1, 0)]",2.0
35,58,[111],student_assessment,SELECT student_id FROM student_course_registra...,what is id of students who registered some cou...,Student_Course_Registrations,student_assessment_X_Student_Course_Registrations,"[(0, 0)]",111.0
36,59,[111],student_assessment,SELECT student_id FROM student_course_registra...,What are the ids of the students who registere...,Student_Course_Registrations,student_assessment_X_Student_Course_Registrations,"[(0, 0)]",111.0
39,75,[121],student_assessment,SELECT candidate_id FROM candidate_assessments...,Find id of the candidate who most recently acc...,Candidate_Assessments,student_assessment_X_Candidate_Assessments,"[(1, 0)]",121.0
40,76,[121],student_assessment,SELECT candidate_id FROM candidate_assessments...,What is the id of the candidate who most recen...,Candidate_Assessments,student_assessment_X_Candidate_Assessments,"[(1, 0)]",121.0
45,89,[171],student_assessment,SELECT student_id FROM student_course_attendan...,What is the id of the student who most recentl...,Student_Course_Attendance,student_assessment_X_Student_Course_Attendance,"[(4, 0)]",171.0


In [None]:
import re

def is_simple_query_with_single_aggregation(query):
    """Return True if the query uses no joins or set operations and
    contains exactly one aggregation function."""
    # Keywords that introduce joins or set operations.
    disallowed_keywords = {
        'join', 'union', 'intersect', 'except', 'inner',
        'outer', 'left', 'right', 'full'
    }
    # Aggregation functions; exactly one of them must appear.
    aggregation_keywords = {
        'count', 'sum', 'avg', 'min', 'max'
    }

    # Convert the query to lowercase for comparison.
    query_lower = query.lower()

    # Check for disallowed keywords as whole words only, so that identifiers
    # such as "winner" or "fullname" are not mistaken for INNER or FULL.
    disallowed_pattern = r'\b(' + '|'.join(disallowed_keywords) + r')\b'
    if re.search(disallowed_pattern, query_lower):
        return False

    # Find the aggregation keywords, again matching whole words only.
    aggregation_pattern = r'\b(' + '|'.join(aggregation_keywords) + r')\b'
    aggregation_matches = re.findall(aggregation_pattern, query_lower)

    # The query is simple only if there is exactly one aggregation keyword.
    return len(aggregation_matches) == 1

In [None]:
query="SELECT AVG() FROM head WHERE age > 56"
print(is_simple_query_with_single_aggregation(query))

True
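A pitfall with keyword filters like this one: a plain substring test (`keyword in query_lower`) also fires inside identifiers, for example `full` in `fullname` or `inner` in `winner`, while a regular expression with `\b` word boundaries only matches whole words. A standalone comparison on a made-up query:

```python
import re

toy_query = "SELECT winner, fullname FROM match_result"
toy_keywords = ['inner', 'full', 'join']

# Substring test: also matches inside longer identifiers
toy_substring_hits = [k for k in toy_keywords if k in toy_query.lower()]
# Word-boundary regex: matches whole words only
toy_pattern = r'\b(' + '|'.join(toy_keywords) + r')\b'
toy_regex_hits = re.findall(toy_pattern, toy_query.lower())

print(toy_substring_hits)  # ['inner', 'full']
print(toy_regex_hits)      # []
```

Note that `\b` treats `_` as a word character, so an identifier such as `left_id` is not reported as the keyword `left` either.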


In [None]:
# merged_df_cleaned['position'] = merged_df_cleaned.groupby('seq_id').cumcount()

This is just for testing purposes: like the Google Research notebook, which tests the TAPAS encoding this way, we group each table with its questions.

In [None]:
grouped = merged_df_cleaned.groupby('seq_id').agg(lambda x: x.tolist())
grouped = grouped.reset_index()

grouped

Unnamed: 0,seq_id,ID,answer_text,db_id,query,question,table_used,answer_coordinates,float_value
0,aircraft_X_aircraft,"[4797, 4798, 4799, 4800]","[[5], [5], [Light utility helicopter, Turbosha...","[aircraft, aircraft, aircraft, aircraft]","[SELECT count(*) FROM aircraft, SELECT count(*...","[How many aircrafts are there?, What is the nu...","[aircraft, aircraft, aircraft, aircraft]","[[(4, 0)], [(4, 0)], [(0, 2), (1, 2), (2, 2), ...","[5.0, 5.0, nan, nan]"
1,aircraft_X_airport,"[4803, 4804, 4835, 4836]","[[61344438.0, 5562516.0], [61344438.0, 5562516...","[aircraft, aircraft, aircraft, aircraft]","[SELECT International_Passengers , Domestic_P...",[What are the number of international and dome...,"[airport, airport, airport, airport]","[[(0, 4), (0, 5)], [(0, 4), (0, 5)], [(0, 0), ...","[nan, nan, nan, nan]"
2,aircraft_X_pilot,"[4809, 4810, 4811, 4812, 4813, 4814, 4827, 4828]","[[Ayana Spencer, Ellen Ledner III, Elisha Hick...","[aircraft, aircraft, aircraft, aircraft, aircr...","[SELECT Name FROM pilot WHERE Age >= 25, SEL...",[What are the name of pilots aged 25 or older?...,"[pilot, pilot, pilot, pilot, pilot, pilot, pil...","[[(3, 1), (4, 1), (5, 1), (6, 1), (7, 1), (8, ...","[nan, nan, nan, nan, nan, nan, nan, nan]"
3,allergy_1_X_Allergy_Type,"[443, 444, 445, 446, 447, 448, 449, 450, 455, ...","[[food, environmental, animal], [food, environ...","[allergy_1, allergy_1, allergy_1, allergy_1, a...",[SELECT DISTINCT allergytype FROM Allergy_type...,"[Show all allergy types., What are the differe...","[Allergy_Type, Allergy_Type, Allergy_Type, All...","[[(0, 1), (7, 1), (10, 1)], [(0, 1), (7, 1), (...","[nan, nan, nan, nan, nan, nan, nan, nan, nan, ..."
4,apartment_rentals_X_Apartment_Bookings,"[1194, 1195, 1196, 1197, 1248, 1249]","[[15], [15], [2016-09-26 17:13:49, 2017-10-07 ...","[apartment_rentals, apartment_rentals, apartme...","[SELECT count(*) FROM Apartment_Bookings, SELE...",[How many apartment bookings are there in tota...,"[Apartment_Bookings, Apartment_Bookings, Apart...","[[(1, 1)], [(1, 1)], [(0, 4), (0, 5), (1, 4), ...","[15.0, 15.0, nan, nan, nan, nan]"
...,...,...,...,...,...,...,...,...,...
328,wedding_X_wedding,"[1645, 1648]","[[2], [2]]","[wedding, wedding]",[SELECT count(*) FROM wedding WHERE YEAR = 2...,"[How many weddings are there in year 2016?, Ho...","[wedding, wedding]","[[(0, 2)], [(0, 2)]]","[2.0, 2.0]"
329,workshop_paper_X_submission,"[5814, 5815, 5816, 5817, 5818, 5819, 5820, 582...","[[10], [10], [Steve Niehaus, Sherman Smith, Sa...","[workshop_paper, workshop_paper, workshop_pape...","[SELECT count(*) FROM submission, SELECT count...","[How many submissions are there?, Count the nu...","[submission, submission, submission, submissio...","[[(9, 0)], [(9, 0)], [(0, 2), (2, 2), (1, 2), ...","[10.0, 10.0, nan, nan, nan, nan, nan, nan, nan..."
330,workshop_paper_X_workshop,"[5840, 5841]","[[July 5, 2011, Istanbul Turkey, August 18, 20...","[workshop_paper, workshop_paper]","[SELECT Date , Venue FROM workshop ORDER BY V...",[Show the date and venue of each workshop in a...,"[workshop, workshop]","[[(5, 1), (5, 2), (0, 1), (0, 2), (1, 1), (1, ...","[nan, nan]"
331,wrestler_X_Elimination,"[1854, 1855, 1882, 1883]","[[Go To Sleep, Spear], [Go To Sleep, Spear], [...","[wrestler, wrestler, wrestler, wrestler]",[SELECT Elimination_Move FROM Elimination WHER...,[What are the elimination moves of wrestlers w...,"[Elimination, Elimination, Elimination, Elimin...","[[(0, 4), (5, 4)], [(0, 4), (5, 4)], [(0, 5), ...","[nan, nan, nan, nan]"
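On toy data, the same `groupby(...).agg(lambda x: x.tolist())` call collapses every column into one list per `seq_id`, so each row of the result holds all the questions (and answers) asked about one table. The rows below are toy examples modeled on the output above:

```python
import pandas as pd

toy_df = pd.DataFrame({
    'seq_id': ['farm_X_city', 'wedding_X_wedding', 'farm_X_city'],
    'question': ['How many different statuses do cities have?',
                 'How many weddings are there in year 2016?',
                 'Count the number of different statuses.'],
    'answer_text': [[2], [2], [2]],
})

# One row per table; every other column is collected into a list, in original row order
toy_grouped = toy_df.groupby('seq_id').agg(lambda x: x.tolist()).reset_index()
print(toy_grouped)
```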


In [None]:
import torch
from transformers import TapasTokenizer

# initialize the tokenizer
tokenizer = TapasTokenizer.from_pretrained("google/tapas-base")




In [None]:
item = grouped.iloc[5]
table=db_table_to_df[item.seq_id].astype(str)
display(table)
print("")
print(table.columns)
print("")
print(item.question)
print("")
print(item.answer_text)

Unnamed: 0,building_id,building_short_name,building_full_name,building_description,building_address,building_manager,building_phone
0,133,Normandie Court,Normandie Court,Studio,"7950 Casper Vista Apt. 176\nMarquiseberg, CA 7...",Emma,(948)040-1064x387
1,153,Mercedes House,Mercedes House,Studio,"354 Otto Villages\nCharliefort, VT 71664",Brenden,915-617-2408x832
2,191,The Eugene,The Eugene,Flat,"71537 Gorczany Inlet\nWisozkburgh, AL 08256",Melyssa,(609)946-0491
3,196,VIA 57 WEST,VIA 57 WEST,Studio,"959 Ethel Viaduct\nWest Efrainburgh, DE 40074",Kathlyn,681.772.2454
4,225,Columbus Square,Columbus Square,Studio,"0703 Danika Mountains Apt. 362\nMohrland, AL 5...",Kyle,1-724-982-9507x640
5,532,Avalon Park,Avalon Park,Duplex,"6827 Kessler Parkway Suite 908\nAhmedberg, WI ...",Albert,376-017-3538
6,556,Peter Cooper Village,Peter Cooper Village,Flat,"861 Narciso Glens Suite 392\nEast Ottis, ND 73970",Darlene,1-224-619-0295x13195
7,624,Stuyvesant Town,Stuyvesant Town,Studio,101 Queenie Mountains Suite 619\nNew Korbinmou...,Marie,(145)411-6406
8,644,The Anthem,The Anthem,Flat,"50804 Mason Isle Suite 844\nWest Whitney, ID 6...",Ewald,(909)086-5221x3455
9,673,Barclay Tower,Barclay Tower,Flat,"1579 Runte Forges Apt. 548\nLeuschkeland, OK 1...",Rogers,1-326-267-3386x613



Index(['building_id', 'building_short_name', 'building_full_name',
       'building_description', 'building_address', 'building_manager',
       'building_phone'],
      dtype='object')

['Show all distinct building descriptions.', 'Give me a list of all the distinct building descriptions.', 'Show the short names of the buildings managed by "Emma".', 'Which buildings does "Emma" manage? Give me the short names of the buildings.', 'Show the addresses and phones of all the buildings managed by "Brenden".', 'What are the address and phone number of the buildings managed by "Brenden"?']

[['Studio', 'Flat', 'Duplex'], ['Studio', 'Flat', 'Duplex'], ['Normandie Court'], ['Normandie Court'], ['354 Otto Villages\nCharliefort, VT 71664', '915-617-2408x832'], ['354 Otto Villages\nCharliefort, VT 71664', '915-617-2408x832']]
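Eyeballing one table like this works for a handful of rows. For lookup questions (rows whose `float_value` is NaN), where `answer_coordinates` and `answer_text` appear to line up one-to-one, a small programmatic check can confirm that every coordinate lies inside the table and points at a cell whose text equals the answer. `coordinates_match_answers` is a hypothetical helper, not part of the notebook or of the `transformers` API, and the toy table reuses a few values from the output above:

```python
import pandas as pd

def coordinates_match_answers(table, coordinates, answers):
    """Hypothetical sanity check: True if every (row, col) coordinate lies
    inside `table` and the cell text equals the matching answer's text."""
    if len(coordinates) != len(answers):
        return False
    n_rows, n_cols = table.shape
    for (r, c), answer in zip(coordinates, answers):
        if not (0 <= r < n_rows and 0 <= c < n_cols):
            return False
        # Compare as strings, since the tables are cast with .astype(str)
        if str(table.iat[r, c]) != str(answer):
            return False
    return True

# Toy table cast to strings, like db_table_to_df[...].astype(str) in the notebook
toy_table = pd.DataFrame({
    'building_id': [133, 153],
    'building_short_name': ['Normandie Court', 'Mercedes House'],
    'building_manager': ['Emma', 'Brenden'],
}).astype(str)

print(coordinates_match_answers(toy_table, [(0, 1)], ['Normandie Court']))  # True
print(coordinates_match_answers(toy_table, [(1, 1)], ['Normandie Court']))  # False
```

Aggregation rows (for example `count` answers mapped to a single cell) would not pass this check by design, so it should only be applied to lookup questions.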


# Fine-tuning
The table we feed to TAPAS is `merged_df_cleaned`; the grouped table created above was only for experimentation.

In [None]:
# @title Old tokenization
# class TableDataset(torch.utils.data.Dataset):
#     def __init__(self, df, tokenizer):
#         self.df = df
#         self.tokenizer = tokenizer
#         self.problematic_ids = []  # Initialize a list to store IDs of problematic questions

#     def __getitem__(self, idx):
#         item = self.df.iloc[idx]
#         table_name = item['seq_id']
#         table = db_table_to_df[table_name].astype(str)

#         try:
#             if item.position != 0:
#                 previous_item = self.df.iloc[idx-1]
#                 encoding = self.tokenizer(table=table,
#                                           queries=[previous_item.question, item.question],
#                                           answer_coordinates=[previous_item.answer_coordinates, item.answer_coordinates],
#                                           answer_text=[previous_item.answer_text, item.answer_text],
#                                           padding="max_length",
#                                           truncation=True,
#                                           return_tensors="pt"
#                                          )
#                 # Use encodings of second table-question pair in the batch
#                 encoding = {key: val[-1] for key, val in encoding.items()}
#             else:
#                 # This means it's the first table-question pair in a sequence
#                 encoding = self.tokenizer(table=table,
#                                           queries=item.question,
#                                           answer_coordinates=item.answer_coordinates,
#                                           answer_text=item.answer_text,
#                                           padding="max_length",
#                                           truncation=True,
#                                           return_tensors="pt"
#                                          )
#                 # Remove the batch dimension which the tokenizer adds
#                 encoding = {key: val.squeeze(0) for key, val in encoding.items()}

#         except Exception as e:
#             print(f"Error processing index {idx} (ID: {item['ID']}): {e}")
#             self.problematic_ids.append(item['ID'])  # Store the ID of the problematic question
#             encoding = None

#         return encoding

#     def __len__(self):
#         return len(self.df)

#     def get_problematic_ids(self):
#         # Returns the list of problematic IDs
#         return self.problematic_ids

# # Custom collate function to filter out None values
# def custom_collate_fn(batch):
#     batch = [item for item in batch if item is not None]  # Filter out None values
#     if not batch:
#         return None
#     return torch.utils.data.dataloader.default_collate(batch)

# train_dataset = TableDataset(df=merged_df_cleaned, tokenizer=tokenizer)
# train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, collate_fn=custom_collate_fn)



# for idx, batch in enumerate(train_dataloader):
#     if batch is not None:  # Only process batches that are not None
#         print(f"Batch {idx} loaded successfully")
#     else:
#         print(f"Batch {idx} skipped due to problematic data")

# print(train_dataset.get_problematic_ids())

In [None]:
#@title Old TableDataset (newer than the one above)
# import torch
# import pandas as pd
# import numpy as np

# class TableDataset(torch.utils.data.Dataset):
#     def __init__(self, df, tokenizer):
#         self.df = df
#         self.tokenizer = tokenizer

#     def __getitem__(self, idx):
#         item = self.df.iloc[idx]
#         table_name = item['seq_id']
#         table = db_table_to_df[table_name].astype(str)
#         # Check if float_value is not NaN
#         # if item.float_value:
#         encoding = self.tokenizer(
#             table=table,
#             queries=item.question,
#             answer_coordinates=item.answer_coordinates,
#             answer_text=item.answer_text,
#             truncation=True,
#             padding="max_length",
#             return_tensors="pt",
#         )
#         float_answer = torch.tensor(item.float_value)
#         # Remove the batch dimension which the tokenizer adds by default
#         encoding = {key: val.squeeze(0) for key, val in encoding.items()}
#         # Add the float_answer
#         encoding["float_answer"] = float_answer
#         return encoding

#         # else:
#         #     encoding = self.tokenizer(
#         #         table=table,
#         #         queries=item.question,
#         #         answer_coordinates=item.answer_coordinates,
#         #         answer_text=item.answer_text,
#         #         truncation=True,
#         #         padding="max_length",
#         #         return_tensors="pt",
#         #     )
#         #     # float_answer = torch.tensor(item.float_value)

#         #     # Remove the batch dimension which the tokenizer adds by default
#         #     encoding = {key: val.squeeze(0) for key, val in encoding.items()}
#         #     return encoding

#     def __len__(self):
#         return len(self.df)

# train_dataset = TableDataset(df=merged_df_cleaned, tokenizer=tokenizer)
# train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16)

# for idx, batch in enumerate(train_dataloader):
#   continue


In [None]:
# merged_df_cleaned.to_pickle('/content/drive/MyDrive/spider/cleaned_SPIDER_train.pkl')

In [None]:
from sklearn.model_selection import train_test_split

# Split merged_df_cleaned into a training set (70%) and a validation set (30%);
# no separate test set is created here.
train_df, val_df = train_test_split(merged_df_cleaned, test_size=0.3, random_state=42)

print("Training set size:", len(train_df))
print("Validation set size:", len(val_df))


Training set size: 1462
Validation set size: 627
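The cell above produces a 70/30 train/validation split and no test set. If a held-out test set is needed, for example a 60/20/20 split, one option (a sketch, not what this notebook does) is to call `train_test_split` twice: first hold out 40% of the rows, then split that portion in half. A toy frame stands in for `merged_df_cleaned`, and the `toy_` names avoid overwriting the notebook's `train_df` and `val_df`:

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy stand-in for merged_df_cleaned
toy_df = pd.DataFrame({'ID': range(100)})

# Hold out 40% of the rows, then split that portion evenly into validation and test
toy_train, toy_temp = train_test_split(toy_df, test_size=0.4, random_state=42)
toy_val, toy_test = train_test_split(toy_temp, test_size=0.5, random_state=42)

print(len(toy_train), len(toy_val), len(toy_test))  # 60 20 20
```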


In [None]:
# import json
# df22 = pd.read_csv('/content/drive/MyDrive/spider/QATCH_SPIDER.csv', converters={'answer_text': json.loads})
# df22_sampled=df22.sample(n=1000, random_state=42)

In [None]:
#@title Cleaning QATCH
db_ids_train_spider= set(instance['db_id'] for instance in train_set)


from qatch.database_reader import MultipleDatabases

# The path to multiple databases
db_save_path = '/content/drive/MyDrive/spider/test_database'
databases = MultipleDatabases(db_save_path)


from qatch import TestGenerator

# init generator
test_generator = TestGenerator(databases=databases)

# generate tests for each database and for each generator
tests_df = test_generator.generate()


# Create a mask where each row is True if 'db_id' is in db_ids_train_spider, else False
mask = tests_df['db_id'].isin(db_ids_train_spider)

# Apply the mask to filter the DataFrame
tests_df = tests_df[mask]

unique_db_ids=tests_df['db_id'].unique()


tests_df = tests_df[tests_df.apply(lambda row: row['tbl_name'] in db_to_tables_constraint.get(row['db_id'], []), axis=1)]

tests_df=tests_df.rename(columns={'tbl_name':'table_used'})


qatch_synthetic=tests_df.drop(columns=["sql_tags"], axis=1)
qatch_synthetic=qatch_synthetic.copy()
qatch_synthetic['seq_id'] = qatch_synthetic.apply(lambda row: f"{row['db_id']}_X_{row['table_used']}", axis=1)


qatch_synthetic.reset_index(inplace=True)
qatch_synthetic.rename(columns={'index': 'ID'}, inplace=True)


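# Execute every QATCH query against its SQLite database and store the flattened result.
# Rows are collected in a list and turned into a DataFrame once at the end, which avoids
# the quadratic cost of calling pd.concat inside the loop.
import sqlite3
from contextlib import closing
from itertools import chain

base_path = '/content/drive/MyDrive/spider/test_database/'

answer_rows = []
for _, row in qatch_synthetic.iterrows():
    db_id = row['db_id']
    db_path = f'{base_path}{db_id}/{db_id}.sqlite'

    # closing() guarantees the connection is closed; sqlite3's own context
    # manager only commits or rolls back the transaction.
    with closing(sqlite3.connect(db_path)) as conn:
        answers = conn.execute(row['query']).fetchall()

    # Flatten the result rows into a single list of values, row by row
    answer_rows.append({'ID': row['ID'], 'answer_text': list(chain.from_iterable(answers))})

results_df_qatch = pd.DataFrame(answer_rows, columns=['ID', 'answer_text'])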


merged_df_qatch = pd.merge(results_df_qatch, qatch_synthetic, on='ID')
merged_df_qatch = merged_df_qatch[merged_df_qatch['answer_text'].apply(lambda x: x != [])]


problematic_rows_ids=[]


result = merged_df_qatch.apply(get_answer_coordinates, axis=1, result_type='expand')
merged_df_qatch['answer_coordinates'],merged_df_qatch['float_value'] = result[0],result[2]

merged_df_cleaned_qatch = merged_df_qatch[~merged_df_qatch['ID'].isin(problematic_rows_ids)].copy()

merged_df_cleaned_qatch = merged_df_cleaned_qatch.dropna(subset=['answer_coordinates'])

In [None]:
# qatch_df_sampled=merged_df_cleaned_qatch.sample(n=3000, random_state=42)

In [None]:
# merged_df_cleaned_qatch.to_pickle('/content/drive/MyDrive/spider/QATCH_SPIDER.pkl')

In [None]:
qatch_pickle = pd.read_pickle('/content/drive/MyDrive/spider/QATCH_SPIDER.pkl')

In [None]:
qatch_pickle

Unnamed: 0,ID,answer_text,db_id,table_used,query,question,seq_id,answer_coordinates,float_value
0,10,"[1, Internet Explorer, 28.96, 2, Firefox, 18.1...",browser_web,browser,"SELECT * FROM ""browser""",Show all the rows in the table browser,browser_web_X_browser,"[(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2...",
1,11,"[1, 2, 3, 4]",browser_web,browser,"SELECT ""id"" FROM ""browser""","Show all ""id"" in the table browser",browser_web_X_browser,"[(0, 0), (1, 0), (2, 0), (3, 0)]",
2,12,"[1, Internet Explorer, 2, Firefox, 3, Safari, ...",browser_web,browser,"SELECT ""id"", ""name"" FROM ""browser""","Show all ""id"", ""name"" in the table browser",browser_web_X_browser,"[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1...",
3,13,"[Internet Explorer, Firefox, Safari, Opera]",browser_web,browser,"SELECT ""name"" FROM ""browser""","Show all ""name"" in the table browser",browser_web_X_browser,"[(0, 1), (1, 1), (2, 1), (3, 1)]",
4,14,"[Internet Explorer, 28.96, Firefox, 18.11, Saf...",browser_web,browser,"SELECT ""name"", ""market_share"" FROM ""browser""","Show all ""name"", ""market_share"" in the table b...",browser_web_X_browser,"[(0, 1), (0, 2), (1, 1), (1, 2), (2, 1), (2, 2...",
...,...,...,...,...,...,...,...,...,...
26165,63103,[6],school_player,school_details,"SELECT COUNT(DISTINCT""Nickname"") FROM ""school_...","How many different ""Nickname"" are in table ""sc...",school_player_X_school_details,"[(5, 0)]",6.0
26166,63104,[6],school_player,school_details,"SELECT COUNT(DISTINCT""Colors"") FROM ""school_de...","How many different ""Colors"" are in table ""scho...",school_player_X_school_details,"[(5, 0)]",6.0
26167,63105,[1],school_player,school_details,"SELECT COUNT(DISTINCT""League"") FROM ""school_de...","How many different ""League"" are in table ""scho...",school_player_X_school_details,"[(0, 0)]",1.0
26168,63106,[3],school_player,school_details,"SELECT COUNT(DISTINCT""Class"") FROM ""school_det...","How many different ""Class"" are in table ""schoo...",school_player_X_school_details,"[(2, 0)]",3.0


In [None]:
qatch_pickle_sampled=qatch_pickle.sample(n=3000, random_state=42)

In [None]:
train_df = pd.concat([train_df, qatch_pickle_sampled], axis=0)

In [None]:
train_df

Unnamed: 0,ID,answer_text,db_id,query,question,table_used,seq_id,answer_coordinates,float_value
50,98,"[Dariana, Hoyt, Lizeth, Mayra, Nova, Shannon, ...",student_assessment,SELECT first_name FROM people ORDER BY first_name,What are the first names of the people in alph...,People,student_assessment_X_People,"[(2, 1), (4, 1), (6, 1), (5, 1), (7, 1), (0, 1...",
1930,5521,[15],products_gen_characteristics,SELECT count(*) FROM CHARACTERISTICS,Count the number of characteristics.,Characteristics,products_gen_characteristics_X_Characteristics,"[(14, 0)]",15.0
299,854,[3],chinook_1,SELECT COUNT(DISTINCT city) FROM EMPLOYEE,Find the number of different cities that emplo...,Employee,chinook_1_X_Employee,"[(2, 0)]",3.0
1356,4011,"[Battle ship, 3, Cargo ship, 5]",ship_mission,"SELECT TYPE , COUNT(*) FROM ship GROUP BY TYPE","For each type, how many ships are there?",ship,ship_mission_X_ship,"[(1, 2), (2, 0), (0, 2), (4, 0)]",
109,309,[Murray Coffee shop],product_catalog,SELECT distinct(catalog_publisher) FROM catalo...,Find all the catalog publishers whose name con...,Catalogs,product_catalog_X_Catalogs,"[(1, 2)]",
...,...,...,...,...,...,...,...,...,...
25105,61802,"[340, 161, 45, 2017-05-01 17:32:26, 2018-03-09...",student_assessment,"SELECT * FROM ""People_Addresses"" WHERE ""person...","Show the data of the table ""People_Addresses"" ...",People_Addresses,student_assessment_X_People_Addresses,"[(5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (6, 0...",
844,1377,"[68476213, 68421879, 65465421, 36546321]",hospital_1,"SELECT ""InsuranceID"" FROM ""Patient"" ORDER BY ""...","Project the ""InsuranceID"" ordered in descendin...",Patient,hospital_1_X_Patient,"[(0, 4), (3, 4), (2, 4), (1, 4)]",
1514,2307,"[Westport, CT, San Antonio, TX, Placentia, CA,...",perpetrator,"SELECT ""Home Town"" FROM ""people"" ORDER BY ""Hom...","Project the ""Home Town"" ordered in descending ...",people,perpetrator_X_people,"[(3, 4), (5, 4), (2, 4), (0, 4), (7, 4), (8, 4...",
11219,25355,"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...",customer_deliveries,"SELECT ""truck_id"" FROM ""Trucks""","Show all ""truck_id"" in the table Trucks",Trucks,customer_deliveries_X_Trucks,"[(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0...",
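In the `train_df` rows shown above, `float_value` is filled only for aggregate queries that return a single number. For example, `SELECT count(*) FROM CHARACTERISTICS` gives `[15]` and `15.0`. Everything else is empty (NaN), including `SELECT TYPE , COUNT(*) ... GROUP BY TYPE`, which returns several values. The hypothetical rule below reproduces that pattern. It is inferred from the displayed rows and is not taken from `get_answer_coordinates`. It is restricted to `COUNT`, `SUM` and `AVG` because, apart from NONE, these are the only aggregation operators of the WTQ-style TAPAS head.

```python
import math
import re
from typing import Sequence

# Aggregates whose single result can serve as a float supervision target (assumption)
_AGGREGATE = re.compile(r"\b(COUNT|SUM|AVG)\s*\(", re.IGNORECASE)

def float_answer(query: str, answer_text: Sequence[object]) -> float:
    """Hypothetical rule: a float target only for single-value aggregate answers."""
    if len(answer_text) == 1 and _AGGREGATE.search(query):
        try:
            return float(answer_text[0])
        except ValueError:
            return math.nan
    return math.nan

print(float_answer("SELECT count(*) FROM CHARACTERISTICS", ["15"]))  # 15.0
print(float_answer("SELECT TYPE , COUNT(*) FROM ship GROUP BY TYPE",
                   ["Battle ship", "3", "Cargo ship", "5"]))         # nan
```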


In [None]:
train_df=train_df.drop(columns=["ID"])
train_df.reset_index(inplace=True)
train_df.rename(columns={'index': 'ID'}, inplace=True)
train_df.head()

Unnamed: 0,ID,answer_text,db_id,query,question,table_used,seq_id,answer_coordinates,float_value
0,50,"[Dariana, Hoyt, Lizeth, Mayra, Nova, Shannon, ...",student_assessment,SELECT first_name FROM people ORDER BY first_name,What are the first names of the people in alph...,People,student_assessment_X_People,"[(2, 1), (4, 1), (6, 1), (5, 1), (7, 1), (0, 1...",
1,1930,[15],products_gen_characteristics,SELECT count(*) FROM CHARACTERISTICS,Count the number of characteristics.,Characteristics,products_gen_characteristics_X_Characteristics,"[(14, 0)]",15.0
2,299,[3],chinook_1,SELECT COUNT(DISTINCT city) FROM EMPLOYEE,Find the number of different cities that emplo...,Employee,chinook_1_X_Employee,"[(2, 0)]",3.0
3,1356,"[Battle ship, 3, Cargo ship, 5]",ship_mission,"SELECT TYPE , COUNT(*) FROM ship GROUP BY TYPE","For each type, how many ships are there?",ship,ship_mission_X_ship,"[(1, 2), (2, 0), (0, 2), (4, 0)]",
4,109,[Murray Coffee shop],product_catalog,SELECT distinct(catalog_publisher) FROM catalo...,Find all the catalog publishers whose name con...,Catalogs,product_catalog_X_Catalogs,"[(1, 2)]",
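The `reset_index` / `rename` cell turns the pre-concatenation index into the new `ID` column. The IDs shown above, such as 50 and 1930, are old index values. Because `pd.concat` keeps each frame's original index, a SPIDER row and a QATCH row can end up with the same ID, and a later `isin(prob)` filter would then drop both. The toy example below uses made-up frames to show the collision and one way to avoid it with `ignore_index=True`.

```python
import pandas as pd

# Two toy frames whose default indices overlap (0, 1), like two independently indexed sources
spider = pd.DataFrame({"question": ["q_a", "q_b"]})
qatch = pd.DataFrame({"question": ["q_c", "q_d"]})

# Promoting the old index to an ID column after concat produces duplicate IDs
kept = pd.concat([spider, qatch], axis=0).reset_index().rename(columns={"index": "ID"})
print(kept["ID"].tolist())   # [0, 1, 0, 1]
print(kept["ID"].is_unique)  # False

# ignore_index=True renumbers the rows 0..n-1, so every row gets a unique ID
fresh = pd.concat([spider, qatch], axis=0, ignore_index=True)
fresh["ID"] = fresh.index
print(fresh["ID"].tolist())  # [0, 1, 2, 3]
```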


In [None]:
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)


In [None]:
import torch
import pandas as pd
import numpy as np

class TableDataset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer):
        self.df = df
        self.tokenizer = tokenizer
        self.problematic_ids = []  # Initialize a list to store IDs of problematic questions

    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        table_name = item['seq_id']
        try:
            table = db_table_to_df[table_name].astype(str)
            encoding = self.tokenizer(
                table=table,
                queries=item.question,
                answer_coordinates=item.answer_coordinates,
                answer_text=item.answer_text,
                truncation=True,
                padding="max_length",
                return_tensors="pt",
            )
            float_answer = torch.tensor(item.float_value)
            encoding = {key: val.squeeze(0) for key, val in encoding.items()}
            encoding["float_answer"] = float_answer


        except Exception as e:
            print(f"Error processing index {idx} (ID: {item['ID']}): {e}")
            self.problematic_ids.append(item.ID)  # Store the ID of the problematic question
            encoding = None

        return encoding


    def __len__(self):
        return len(self.df)

    def get_problematic_ids(self):
        # Returns the list of problematic IDs
        return self.problematic_ids

# Custom collate function: drop examples that failed to encode (None) before batching
def custom_collate_fn(batch):
    batch = [item for item in batch if item is not None]  # Filter out None values
    if not batch:
        return None  # every example in this batch failed
    return torch.utils.data.dataloader.default_collate(batch)

train_dataset = TableDataset(df=merged_df_cleaned, tokenizer=tokenizer)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, collate_fn=custom_collate_fn)

# One full pass over the data records the IDs of all examples that fail to encode
for idx, batch in enumerate(train_dataloader):
    if batch is None:
        print(f"Batch {idx} skipped due to problematic data")

print(train_dataset.get_problematic_ids())

Error processing index 769 (ID: 3774): Couldn't find all answers
Error processing index 770 (ID: 3775): Couldn't find all answers
Error processing index 771 (ID: 3776): Couldn't find all answers
Error processing index 772 (ID: 3777): Couldn't find all answers
Error processing index 781 (ID: 3796): Couldn't find all answers
Error processing index 782 (ID: 3797): Couldn't find all answers
Error processing index 839 (ID: 4168): Couldn't find all answers
Error processing index 840 (ID: 4169): Couldn't find all answers
Error processing index 847 (ID: 4184): Couldn't find all answers
Error processing index 848 (ID: 4185): Couldn't find all answers
Error processing index 1125 (ID: 5522): Couldn't find all answers
Error processing index 1126 (ID: 5523): Couldn't find all answers
Error processing index 1137 (ID: 5536): Couldn't find all answers
Error processing index 1138 (ID: 5537): Couldn't find all answers
Error processing index 1196 (ID: 5848): Couldn't find all answers
[3774, 3775, 3776, 3
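`custom_collate_fn` follows a general pattern: wrap a collate function so that failed examples (encoded as `None`) are dropped. The plain-Python sketch below shows the two side effects of that pattern. A batch can contain fewer than `batch_size` examples, and a batch in which every example failed becomes `None`, which any training loop must skip. `skip_none_collate` and the toy `dict_of_lists` collate are hypothetical stand-ins for `torch.utils.data.dataloader.default_collate`.

```python
from typing import Any, Callable, Dict, List, Optional

def skip_none_collate(base_collate: Callable[[List[Any]], Any]) -> Callable[[List[Any]], Optional[Any]]:
    """Wrap a collate function so that examples encoded as None are dropped."""
    def collate(batch: List[Any]) -> Optional[Any]:
        kept = [item for item in batch if item is not None]
        if not kept:
            return None  # every example in this batch failed to encode
        return base_collate(kept)
    return collate

def dict_of_lists(batch: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Toy stand-in for default_collate: gather each key into a list."""
    return {key: [item[key] for item in batch] for key in batch[0]}

collate = skip_none_collate(dict_of_lists)
print(collate([{"x": 1}, None, {"x": 3}]))  # {'x': [1, 3]}  (the batch shrank from 3 to 2)
print(collate([None, None]))                # None
```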

In [None]:
#@title legacy stuff
# import torch
# import pandas as pd
# import numpy as np

# class TableDataset(torch.utils.data.Dataset):
#     def __init__(self, df, tokenizer):
#         self.df = df
#         self.tokenizer = tokenizer
#         self.problematic_ids = []  # Initialize a list to store IDs of problematic questions

#     def __getitem__(self, idx):
#         item = self.df.iloc[idx]
#         table_name = item['seq_id']
#         try:
#             table = db_table_to_df[table_name].astype(str)
#             encoding = self.tokenizer(
#                 table=table,
#                 queries=item.question,
#                 answer_coordinates=item.answer_coordinates,
#                 answer_text=item.answer_text,
#                 truncation=True,
#                 padding="max_length",
#                 return_tensors="pt",
#             )
#             float_answer = torch.tensor(item.float_value)
#             encoding = {key: val.squeeze(0) for key, val in encoding.items()}
#             encoding["float_answer"] = float_answer

#         except ValueError as e:  # Catch ValueError specifically
#             if str(e) == "Couldn't find all answers":
#                 print(f"Specific error processing index {idx} (ID: {item['ID']}): {e}")
#                 self.problematic_ids.append(item.ID)  # Store the ID of the problematic question
#                 encoding = None
#             else:
#                 raise  # Re-raise the exception if it's not the one you're looking for
#         except Exception as e:  # Catch other exceptions
#             print(f"General error processing index {idx} (ID: {item['ID']}): {e}")
#             encoding = None  # It might be a good idea to handle general errors as well

#         return encoding

#     def __len__(self):
#         return len(self.df)

#     def get_problematic_ids(self):
#         # Returns the list of problematic IDs
#         return self.problematic_ids

# # Custom collate function to filter out None values, if not already implemented
# def custom_collate_fn(batch):
#     batch = [item for item in batch if item is not None]  # Filter out None values
#     if not batch:
#         return None
#     return torch.utils.data.dataloader.default_collate(batch)

# # Example usage (assuming train_df and tokenizer are defined)
# train_dataset = TableDataset(df=train_df, tokenizer=tokenizer)
# train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, collate_fn=custom_collate_fn)

# for idx, batch in enumerate(train_dataloader):
#     if batch is not None:
#         continue
#     else:
#         print(f"Batch {idx} skipped due to problematic data")

# print(train_dataset.get_problematic_ids())


In [None]:
# prob_list=[3776, 5523, 5522, 4185, 5537, 3774, 4168, 5848, 3797, 3796, 3775, 4169, 52768, 48172, 57739, 53019, 10050, 56016, 33902, 60889, 34287, 23285, 24387, 23988, 48735, 61339, 34361, 28029, 8175, 61734, 25276, 10208, 50328, 21063, 50108, 24419, 2710, 57971, 27532, 7282, 18727, 58255, 57568, 35438, 9721, 5272, 6017, 52112, 51759, 47965, 24367, 53167, 3567, 15534, 30590, 59200, 61127, 2403, 5044, 54678, 14421, 35811, 10072, 23160, 50390, 14279, 5918, 29174, 24845, 52921, 30481, 28586, 48033, 51788, 60013, 959, 10658, 7210, 61112, 27207, 30207, 58144, 56903, 15413, 8046, 19849, 14447, 62012, 17677, 50117, 15520, 48111, 10699, 8120, 51491, 24409, 23588, 35720, 9626, 34488, 3503, 33792, 27498, 29830, 28981, 25828, 1974, 1501, 49787, 3220, 19869, 19514, 60343, 7516, 50932, 55660, 24820, 795, 51357, 52471, 48381, 18521, 51390, 15586, 60847, 8039, 2219, 24653, 61534, 24736, 58021, 59948, 58032, 8282, 30206, 34173, 60247, 18802, 16789, 18524, 24422, 21230, 9971, 33821, 15230, 50105, 18420, 31474, 7820, 15248, 37421, 49718, 134, 30420, 16780, 53173, 27253, 5683, 24045, 25321, 35546, 49061, 30032, 25150, 56925, 52870, 51342, 34343, 15322, 17493, 30634, 2951, 14342, 25167, 63094, 3639, 48197, 51960, 30231, 2917, 58209, 18667, 1808, 18250, 766, 8279, 2936, 2879, 27320, 34048, 48346, 9556, 56778, 16621, 35664, 21846, 61089, 18499, 35258, 48377, 7425, 36903, 27156, 35340, 15364, 27084, 37045, 17603, 14349, 23279, 48432, 928, 10755, 590, 2166, 21227, 36736, 37108, 8384, 62592, 60319, 29259, 1215, 2461, 49173, 5226, 10096, 112, 50693, 2889, 60188, 30459, 15795, 9528, 37381, 23261, 5965, 48363, 7525, 8787, 27163, 16720, 50186, 8234, 25086, 49994, 23246, 8326, 37264, 9771, 5939, 28707, 24617, 27161, 20025, 37242, 48093, 28929, 60283, 49912, 52632, 25335, 24325, 37051, 49405, 27076, 35576, 25809, 539, 47931, 21691, 35821, 35030, 5230, 36901, 57466, 21102, 28772, 25495, 54518, 23152, 9899, 24276, 37195, 27644, 25810, 33778, 57473, 5055, 3642, 60089, 16426, 27108, 49607, 4967, 
30570, 52473, 51944, 63103, 16712, 19846, 8291, 48414, 5867, 24542, 541, 9525, 51690, 24897, 49080, 48548, 28918, 62111, 28729, 52617, 25133, 18501, 50080, 27002, 61247, 49471, 30090, 2993, 16611, 50808, 53311, 15019, 47969, 33759, 49351, 23433, 2835, 50086, 60397, 48633, 7574, 23265, 29932, 51086, 3585, 51822, 28516, 35692, 15496, 2333, 51148, 63022, 48026, 30029, 50003, 16776, 49250, 58298, 34209, 21094, 54575, 54905, 47995, 10056, 5019, 3681, 61184, 27092, 10093, 52414, 20114, 62321, 50247, 18561, 50196, 7873, 9509, 34179, 2343, 60896, 5058, 8212, 34217, 62155, 50169, 23292, 54483, 48082, 18689, 23355, 18206, 60083, 28591, 14529, 51456, 30282, 35076, 57595, 2453, 2013, 51632, 18456, 25798, 33898, 10666, 59144, 25331, 35146, 35470, 50590, 36771, 27391, 54795, 55890, 9833, 27030, 1526, 27679, 35268, 52154, 56783, 50729, 9710, 1287, 8458, 18287, 59943, 52854, 3498, 59213, 16798, 15615, 30476, 63032, 48385, 31450, 15500, 21124, 16749, 49780, 54836, 24859, 52068, 52568, 15429, 27301, 25043, 17536, 51812, 18218, 24212, 25641, 14443, 1125, 58176, 18209, 61044, 14436, 29021, 61518, 28903, 36898, 34384, 53083, 5874, 28593, 8325, 37289, 18201, 55371, 59981, 2224, 49112, 61684, 33730, 34406, 33972, 16382, 25930, 28980, 52637, 61463, 53150, 23395, 9896, 27113, 58179, 56024, 7685, 62724, 61883, 54534, 58027, 19884, 30186, 19890, 28094, 7230, 2189, 61806, 21062, 52097, 59933, 5427, 9840, 35057, 51542, 60264, 8131, 60712, 14522, 50715, 18370, 14432, 60406, 51210, 50346, 8258, 53535, 15607, 28851, 18413, 30045, 17672, 29298, 23201, 27402, 51868, 16574, 30043, 960, 19864, 61241, 14311, 62312, 35680, 37414, 2451, 15873, 49197, 5420, 62671, 21248, 1200, 48875, 24018, 18819, 18820, 18517, 770, 18418, 21275, 17547, 28704, 23470, 25561, 62365, 47958, 50073, 8216, 16392, 7784, 16784, 16781, 10159, 24671, 51729, 2434, 37417, 2915, 48155, 60381, 29025, 54594, 33745, 7855, 53079, 61572, 51299, 50326, 55693, 53264, 7376, 48550, 35122, 5131, 28651, 15811, 62244, 53442, 48824, 51303, 35158, 
24770, 17550, 53139, 51799, 51917, 15601, 61661, 34164, 19896, 33843, 24274, 50405, 60892, 36784, 54565, 8408, 61077, 52153, 28663, 24263, 24681, 23295, 25929, 50610, 2802, 5875, 3155, 52903, 2952, 58177, 48000, 9678, 5304, 29315, 49928, 14375, 33774, 24868, 24466, 3505, 1088, 36835, 57614, 5756, 28508, 7832, 37424, 34491, 50761, 28746, 61212, 19886, 34199, 5406, 21223, 10210, 7813, 3620, 9777, 5921, 10160, 35261, 24232, 27015, 49014, 36886, 60462, 27523, 60038, 23153, 57727, 28720, 48810, 16738, 50178, 57682, 28594, 49166, 5240, 25312, 51498, 3525, 48928, 51600, 52210, 23268, 5820, 47959, 5828, 23167, 49023, 35803, 16773, 50854, 15049, 24038, 27120, 52994, 7183, 57844, 51372, 57736, 60336, 57656, 37071, 52937, 62850, 34284, 61102, 27581, 8248, 5381, 239, 24718, 59206, 53199, 49754, 25875, 50754, 28879, 50837, 8171, 5718, 51110, 5751, 27432, 51828, 51213, 28854, 28559, 27634, 15361, 15301, 19553, 814, 30979, 37513, 62561, 50205, 15498, 53025, 23192, 57446, 1048, 33814, 30361, 62407, 59110, 19865, 24198, 50084, 10806, 20096, 62068, 29209, 18459, 52669, 20001, 10062, 54732, 991, 1349, 8044, 27179, 27024, 30467, 2354, 18531, 50857, 29086, 10811, 56819, 21170, 51044, 48464, 17455, 58212, 34513, 61542, 60119, 52918, 30391, 51409, 18692, 14296, 30540, 27216, 29134, 30407, 8263, 8793, 51128, 10748, 52775, 35895, 55771, 50616, 55605, 52593, 51005, 58025, 8970, 7356, 56927, 25307, 34446, 24139, 3592, 5263, 51445, 62223, 5837, 34068, 48337, 54600, 1939, 16910, 57832, 56425, 30290, 49935, 61421, 60107, 37376, 61539, 52136, 35706, 18817, 7141, 20033, 7418, 34352, 23223, 19535, 50310, 15742, 18217, 25386, 57428, 48684, 35831, 35434, 1086, 2259, 4938, 49752, 23301, 7456, 50777, 33772, 24304, 35575, 50265, 61140, 51027, 23378, 48722, 29939, 15717, 30485, 245, 48311, 62224, 5887, 2962, 60329, 3125, 18791, 9736, 57679, 48303, 18716, 48443, 15575, 57582, 33798, 35825, 23523, 25305, 9549, 47919, 27229, 30434, 52746, 1862, 52133, 62510, 7822, 61481, 52140, 60933, 37091, 27341, 23137, 
24728, 21258, 29314, 5205, 50032, 16855, 60421, 2203, 50844, 1743, 52767, 8086, 54585, 52516, 61256, 18278, 50638, 37154, 23247, 29941, 15504, 10785, 52449, 7687, 60303, 5693, 7463, 25878, 61139, 50577, 59254, 27509, 16585, 29976, 18480, 48064, 18803, 27287, 52574, 57000, 1734, 50049, 25114, 24049, 48556, 5098, 59052, 27173, 25127, 51227, 15729, 9921, 18317, 52600, 9536, 50396, 17606, 8799, 10814, 61287, 9806, 51191, 15405, 2371, 744, 27171, 52009, 7220, 25865, 2170, 23234, 2892, 10781, 30601, 61237, 24041, 60288, 34232, 29333, 5997, 51599, 33744, 57749, 31448, 18519, 7472, 5992, 16576, 54708, 8383, 29945, 57521, 51127, 34366, 15077, 57492, 52174, 7457, 9983, 48434, 62618, 54634, 10105, 56868, 25486, 8912, 23996, 59115, 8003, 49219, 15626, 60101, 49641, 25780, 60426, 16969, 48450, 23501, 19540, 5377, 57519, 56724, 35169, 7494, 17633, 51221, 55415, 28100, 16847, 53430]
# print(len(prob_list))

1012


In [None]:
prob=train_dataset.get_problematic_ids()
train_df = merged_df_cleaned[~merged_df_cleaned['ID'].isin(prob)]
train_dataset = TableDataset(df=train_df, tokenizer=tokenizer)  # build the dataset from the filtered rows
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, collate_fn=custom_collate_fn)

In [None]:
# Rebuild the dataset from the filtered rows and re-scan until a full pass finds no problematic examples
while True:
  train_dataset = TableDataset(df=train_df, tokenizer=tokenizer)
  train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, collate_fn=custom_collate_fn)
  for idx, batch in enumerate(train_dataloader):
      if batch is None:
          print(f"Batch {idx} skipped due to problematic data")
  prob = train_dataset.get_problematic_ids()
  print(f"problematic_ids: {prob}")
  if not prob:
      break
  train_df = train_df[~train_df['ID'].isin(prob)]

problematic_ids: []
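The failures reported earlier come from the tokenizer. In the Hugging Face implementation, "Couldn't find all answers" is raised when some answer cell has no tokens left in the encoded sequence, typically because truncation to 512 tokens removed it. Whether a row fails depends only on that row and its table. Assuming the encoding is deterministic, a single pass should therefore already collect every failing ID, and a second pass only confirms the result, as the `problematic_ids: []` output shows. The sketch below is a single-pass filter in plain Python. `split_encodable` and `toy_encode` are hypothetical, and `toy_encode` only stands in for the real tokenizer call.

```python
from typing import Any, Callable, Dict, List, Tuple

def split_encodable(
    rows: List[Dict[str, Any]], encode: Callable[[Dict[str, Any]], Any]
) -> Tuple[List[Dict[str, Any]], List[Any]]:
    """Return (rows that encode, IDs of rows whose encoding raised)."""
    kept, failed_ids = [], []
    for row in rows:
        try:
            encode(row)
        except Exception:  # TableDataset in the notebook also catches every exception
            failed_ids.append(row["ID"])
        else:
            kept.append(row)
    return kept, failed_ids

def toy_encode(row: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical stand-in for the tokenizer: fails when an answer is absent from the cells
    if not set(row["answers"]) <= set(row["cells"]):
        raise ValueError("Couldn't find all answers")
    return row

rows = [
    {"ID": 1, "cells": ["a", "b"], "answers": ["a"]},
    {"ID": 2, "cells": ["a", "b"], "answers": ["c"]},
]
kept, failed = split_encodable(rows, toy_encode)
print([r["ID"] for r in kept], failed)       # [1] [2]
print(split_encodable(kept, toy_encode)[1])  # [] (a second pass finds nothing new)
```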


In [None]:
train_dataset[0]["token_type_ids"].shape

torch.Size([512, 7])

In [None]:
train_dataset[1]["input_ids"].shape

torch.Size([512])

In [None]:
batch = next(iter(train_dataloader))

In [None]:
batch["input_ids"].shape

torch.Size([8, 512])

In [None]:
batch["token_type_ids"].shape

torch.Size([8, 512, 7])
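The last dimension of `token_type_ids` has 7 channels. In the Hugging Face `TapasTokenizer` they are, in order: segment IDs, column IDs, row IDs, previous labels, column ranks, inverse column ranks, and numeric relations. This is why channel 3 is treated as `prev_labels` in the cells that follow. The small helper below names the channels of one token's 7 values. `TOKEN_TYPE_CHANNELS` and `describe_token` are illustrative names and are not part of the library.

```python
from typing import Dict, Sequence

# Order of the 7 token type channels produced by TapasTokenizer
TOKEN_TYPE_CHANNELS = (
    "segment_ids",        # 0 = question/special token, 1 = table token
    "column_ids",         # 1-based column index, 0 outside the table
    "row_ids",            # 1-based data row index, 0 for question and header tokens
    "prev_labels",        # 1 if the cell answered the previous question (SQA setting)
    "column_ranks",       # rank of the cell's value within its column (0 if not numeric)
    "inv_column_ranks",   # the same ranking in reverse order
    "numeric_relations",  # relation between numbers in the question and the cell value
)

def describe_token(token_type_row: Sequence[int]) -> Dict[str, int]:
    """Map one token's 7 token-type values to channel names."""
    if len(token_type_row) != len(TOKEN_TYPE_CHANNELS):
        raise ValueError(f"expected {len(TOKEN_TYPE_CHANNELS)} values, got {len(token_type_row)}")
    return dict(zip(TOKEN_TYPE_CHANNELS, token_type_row))

# A table token in column 2, data row 3, that was not a previous answer
print(describe_token([1, 2, 3, 0, 0, 0, 0]))
```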

In [None]:
tokenizer.decode(batch["input_ids"][0])

'[CLS] what are the names of all the documents, as well as the access counts of each, ordered alphabetically? [SEP] document _ code document _ structure _ code document _ type _ code access _ count document _ name 217 8 book 1864 learning english 621 1 paper 8208 research about art history 958 8 book 3769 learning database 961 5 advertisement 6661 summer sails 989 9 book 2910 learning japanese 930 9 cv 6345 david cv 928 8 book 2045 how to cook pasta 510 6 paper 3479 humanity : a fact 706 9 advertisement 8623 winter sails 465 9 cv 5924 john cv 713 8 cv 2294 joe cv 566 5 advertisement 3289 spring sails 349 9 book 1219 life about claude monet 675 1 advertisement 7509 fall sails 714 6 paper 9948 relationships between history and arts [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

In [None]:
# The first example should not have any prev_labels set (channel 3 of token_type_ids)
assert batch["token_type_ids"][0][:,3].sum() == 0

In [None]:
tokenizer.decode(batch["input_ids"][1])

'[CLS] return the lot details and investor ids. [SEP] lot _ id investor _ id lot _ details 1 13 r 2 16 z 3 10 s 4 19 s 5 6 q 6 20 d 7 7 m 8 7 h 9 20 z 10 9 x 11 1 d 12 19 m 13 7 z 14 6 d 15 1 h [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

In [None]:
for token_id, prev_label in zip(batch["input_ids"][1], batch["token_type_ids"][1][:,3]):
  if token_id != 0:  # skip [PAD] tokens (id 0)
    print(tokenizer.decode([token_id]), prev_label.item())

[CLS] 0
return 0
the 0
lot 0
details 0
and 0
investor 0
id 0
##s 0
. 0
[SEP] 0
lot 0
_ 0
id 0
investor 0
_ 0
id 0
lot 0
_ 0
details 0
1 0
13 0
r 0
2 0
16 0
z 0
3 0
10 0
s 0
4 0
19 0
s 0
5 0
6 0
q 0
6 0
20 0
d 0
7 0
7 0
m 0
8 0
7 0
h 0
9 0
20 0
z 0
10 0
9 0
x 0
11 0
1 0
d 0
12 0
19 0
m 0
13 0
7 0
z 0
14 0
6 0
d 0
15 0
1 0
h 0
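`compute_prediction_sequence` further down uses the same channels to recover which table cell each token belongs to. A token belongs to a data cell when its segment ID is 1 and both its column and row IDs are at least 1, and its 0-based key is obtained by subtracting 1 from each. Note that these keys are `(column, row)`, whereas `answer_coordinates` use `(row, column)`. The plain-Python sketch below groups token positions by cell. `tokens_per_cell` is a hypothetical name.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

def tokens_per_cell(token_type_ids: Sequence[Sequence[int]]) -> Dict[Tuple[int, int], List[int]]:
    """Group token positions by their 0-based (column, row) cell.

    Each element is one token's 7 channel values: channel 0 = segment,
    1 = column id, 2 = row id (both 1-based for data cells).
    """
    cells = defaultdict(list)
    for position, channels in enumerate(token_type_ids):
        segment_id, col_id, row_id = channels[0], channels[1] - 1, channels[2] - 1
        if segment_id == 1 and col_id >= 0 and row_id >= 0:
            cells[(col_id, row_id)].append(position)
    return dict(cells)

# [CLS], a question token, a header token, two tokens of one cell, one token of another
example = [
    [0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0, 0],
    [1, 1, 1, 0, 0, 0, 0],
    [1, 1, 1, 0, 0, 0, 0],
    [1, 2, 1, 0, 0, 0, 0],
]
print(tokens_per_cell(example))  # {(0, 0): [3, 4], (1, 0): [5]}
```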


In [None]:
# from transformers import TapasForQuestionAnswering
# model = TapasForQuestionAnswering.from_pretrained("google/tapas-base")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model.to(device)

# The training

In [None]:
# from torch.optim import AdamW
# optimizer = AdamW(model.parameters(), lr=5e-5)

In [None]:
from transformers import TapasConfig, TapasForQuestionAnswering
from torch.optim import AdamW
# this is the default WTQ configuration
config = TapasConfig(
    num_aggregation_labels=4,  # NONE, SUM, AVERAGE, COUNT
    use_answer_as_supervision=True,
    answer_loss_cutoff=0.664694,
    cell_selection_preference=0.207951,
    huber_loss_delta=0.121194,
    init_cell_selection_weights_to_zero=True,
    select_one_column=True,
    allow_empty_column_selection=False,
    temperature=0.0352513,
)
model = TapasForQuestionAnswering.from_pretrained("google/tapas-base", config=config)
optimizer = AdamW(model.parameters(), lr=5e-5)  # initial learning rate
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model.to(device)
# model.to('cuda')


Some weights of TapasForQuestionAnswering were not initialized from the model checkpoint at google/tapas-base and are newly initialized: ['aggregation_classifier.bias', 'aggregation_classifier.weight', 'column_output_bias', 'column_output_weights', 'output_bias', 'output_weights']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
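With `num_aggregation_labels=4`, the WTQ setup maps aggregation indices 0 to 3 to NONE, SUM, AVERAGE and COUNT. This is the mapping used in the Hugging Face TAPAS WTQ examples. The sketch below shows how a predicted aggregation index and the selected cell values could be combined into a final answer. `ID2AGGREGATION` and `apply_aggregation` are hypothetical names and are not library functions.

```python
from statistics import mean
from typing import List, Union

# Aggregation operators of a WTQ-style TAPAS head (index -> name)
ID2AGGREGATION = {0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}

def apply_aggregation(agg_index: int, cells: List[str]) -> Union[List[str], float, int]:
    """Combine the selected cells according to the predicted aggregation operator."""
    op = ID2AGGREGATION[agg_index]
    if op == "NONE":
        return cells                    # the answer is the selected cells themselves
    if op == "COUNT":
        return len(cells)
    values = [float(c) for c in cells]  # SUM and AVERAGE need numeric cells
    return sum(values) if op == "SUM" else mean(values)

print(apply_aggregation(0, ["Firefox"]))      # ['Firefox']
print(apply_aggregation(3, ["a", "b", "c"]))  # 3
print(apply_aggregation(1, ["1.5", "2.5"]))   # 4.0
print(apply_aggregation(2, ["1.5", "2.5"]))   # 2.0
```

Under this convention, MIN and MAX are not aggregation operators. A `MIN(...)` query has to be answered by selecting the single cell that holds the value.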


In [None]:
import torch
import os


save_path = '/content/drive/MyDrive/model_checkpoints'
checkpoint_path = os.path.join(save_path, "100prct_spider_checkpoint4.pt")


if os.path.exists(checkpoint_path):
    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path)  # pass map_location=torch.device('cpu') when no GPU is available

    # Restore the model and optimizer states
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print("Checkpoint loaded successfully.")
else:
    print(f"No checkpoint found at {checkpoint_path}")


Checkpoint loaded successfully.


In [None]:
model.to('cuda')

TapasForQuestionAnswering(
  (tapas): TapasModel(
    (embeddings): TapasEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(1024, 768)
      (token_type_embeddings_0): Embedding(3, 768)
      (token_type_embeddings_1): Embedding(256, 768)
      (token_type_embeddings_2): Embedding(256, 768)
      (token_type_embeddings_3): Embedding(2, 768)
      (token_type_embeddings_4): Embedding(256, 768)
      (token_type_embeddings_5): Embedding(256, 768)
      (token_type_embeddings_6): Embedding(10, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): TapasEncoder(
      (layer): ModuleList(
        (0-11): 12 x TapasLayer(
          (attention): TapasAttention(
            (self): TapasSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias

In [None]:
# max_grad_norm = 5.0

# for epoch in range(10):
#     print("Epoch:", epoch)
#     for idx, batch in enumerate(train_dataloader):
#         # get the inputs;
#         input_ids = batch["input_ids"].to(device)
#         attention_mask = batch["attention_mask"].to(device)
#         token_type_ids = batch["token_type_ids"].to(device)
#         labels = batch["labels"].to(device)

#         # zero the parameter gradients
#         optimizer.zero_grad()

#         # forward + backward + optimize
#         outputs = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels)
#         loss = outputs.loss
#         print("Loss:", loss.item())

#         loss.backward()

#         # Clip gradients to help prevent the "exploding gradient" problem
#         torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

#         optimizer.step()


In [None]:
# torch.cuda.memory_summary(device=None, abbreviated=False)

In [None]:
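optimizer = AdamW(model.parameters(), lr=1e-5)  # re-create the optimizer with a lower learning rate (this discards the optimizer state restored from the checkpoint)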

In [None]:
# from torch.optim.lr_scheduler import ReduceLROnPlateau

# # Create the scheduler
# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)

# # model.train()
# for epoch in range(10):  # loop over the dataset multiple times
#     print("Epoch:", epoch)
#     epoch_loss = 0.0
#     for batch in train_dataloader:
#         # get the inputs;
#         input_ids = batch["input_ids"].to('cuda')
#         attention_mask = batch["attention_mask"].to('cuda')
#         token_type_ids = batch["token_type_ids"].to('cuda')
#         labels = batch["labels"].to('cuda')
#         numeric_values = batch["numeric_values"].to('cuda')
#         numeric_values_scale = batch["numeric_values_scale"].to('cuda')
#         float_answer = batch["float_answer"].to('cuda')

#         # zero the parameter gradients
#         optimizer.zero_grad()

#         # forward + backward + optimize
#         outputs = model(
#             input_ids=input_ids,
#             attention_mask=attention_mask,
#             token_type_ids=token_type_ids,
#             labels=labels,
#             numeric_values=numeric_values,
#             numeric_values_scale=numeric_values_scale,
#             float_answer=float_answer,
#         )
#         loss = outputs.loss
#         print("Loss:", loss.item())
#         epoch_loss += loss.item()

#         loss.backward()
#         optimizer.step()

#     # Calculate average loss for the epoch
#     avg_loss = epoch_loss / len(train_dataloader)

#     # Update the learning rate based on the average loss
#     scheduler.step(avg_loss)

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Keep the model and every batch on the same device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Create the scheduler: divide the learning rate by 10 when the average epoch loss plateaus
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2, verbose=True)

# from_pretrained() returns the model in evaluation mode (dropout disabled), so switch to training mode
model.train()
for epoch in range(10):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    epoch_loss = 0.0
    num_batches = 0
    for batch in train_dataloader:
        if batch is None:  # custom_collate_fn returns None when every example in the batch failed
            continue
        # get the inputs and move them to the model's device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        token_type_ids = batch["token_type_ids"].to(device)
        labels = batch["labels"].to(device)
        numeric_values = batch["numeric_values"].to(device)
        numeric_values_scale = batch["numeric_values_scale"].to(device)
        float_answer = batch["float_answer"].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=labels,
            numeric_values=numeric_values,
            numeric_values_scale=numeric_values_scale,
            float_answer=float_answer,
        )
        loss = outputs.loss
        print("Loss:", loss.item())
        epoch_loss += loss.item()
        num_batches += 1

        loss.backward()
        optimizer.step()

    # Calculate average loss for the epoch (over the batches actually used)
    avg_loss = epoch_loss / max(num_batches, 1)

    # Update the learning rate based on the average loss
    scheduler.step(avg_loss)
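`ReduceLROnPlateau(mode='min', factor=0.1, patience=2)` multiplies the learning rate by 0.1 once the monitored value (here the average epoch loss) has failed to improve for more than 2 consecutive epochs. The sketch below re-implements that rule in plain Python to make the timing concrete. It is a simplification of PyTorch's scheduler: it assumes the default relative threshold of 1e-4 and ignores `cooldown`, `min_lr` and `eps`. `plateau_schedule` is a hypothetical name, and the epoch losses in the example are made up.

```python
import math
from typing import List

def plateau_schedule(losses: List[float], lr: float, factor: float = 0.1,
                     patience: int = 2, threshold: float = 1e-4) -> List[float]:
    """Learning rate after each epoch under a simplified ReduceLROnPlateau (mode='min')."""
    best = math.inf
    bad_epochs = 0
    lrs = []
    for loss in losses:
        if loss < best * (1 - threshold):  # relative improvement over the best loss so far
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:          # more than `patience` epochs without improvement
            lr *= factor
            bad_epochs = 0
        lrs.append(lr)
    return lrs

# Illustrative losses: improvement, then three epochs without a new best
print(plateau_schedule([2.4, 2.2, 2.3, 2.25, 2.21, 2.0], lr=1e-5))
```

With these losses, the rate stays at 1e-5 for the first four epochs and drops to 1e-6 after the fifth, the third epoch in a row without a new best loss.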

In [None]:
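model.train()  # from_pretrained() leaves the model in eval mode; enable dropout for training

for epoch in range(10):  # loop over the dataset multiple times
    print("Epoch:", epoch)
    for batch in train_dataloader:
        if batch is None:  # custom_collate_fn returns None when every example in the batch failed
            continue
        # get the inputs and move them to the GPU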
        input_ids = batch["input_ids"].to('cuda')
        attention_mask = batch["attention_mask"].to('cuda')
        token_type_ids = batch["token_type_ids"].to('cuda')
        labels = batch["labels"].to('cuda')
        numeric_values = batch["numeric_values"].to('cuda')
        numeric_values_scale = batch["numeric_values_scale"].to('cuda')
        float_answer = batch["float_answer"].to('cuda')

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=labels,
            numeric_values=numeric_values,
            numeric_values_scale=numeric_values_scale,
            float_answer=float_answer,
        )

        loss = outputs.loss
        print("Loss:", loss.item())
        loss.backward()
        optimizer.step()

Epoch: 0
Loss: 2.572195529937744
Loss: 3.356193780899048
Loss: 2.1603939533233643
Loss: 2.390073299407959
Loss: 3.2742481231689453
Loss: 2.6931276321411133
Loss: 2.071603775024414
Loss: 2.6375935077667236
Loss: 2.453601121902466
Loss: 1.9995925426483154
Loss: 2.4221837520599365
Loss: 2.218867301940918
Loss: 2.341974973678589
Loss: 2.2829134464263916
Loss: 2.422752618789673
Loss: 2.2769386768341064
Loss: 2.2779934406280518
Loss: 2.353764772415161
Loss: 1.997378945350647
Loss: 2.120710849761963
Loss: 2.331651449203491
Loss: 2.21628999710083
Loss: 2.345754861831665
Loss: 2.15533709526062
Loss: 2.4877965450286865
Loss: 2.1091699600219727
Loss: 1.9629642963409424
Loss: 1.8696364164352417
Loss: 2.0381860733032227
Loss: 1.9559087753295898
Loss: 2.028841972351074
Loss: 2.416954517364502
Loss: 2.648376703262329
Loss: 2.4997520446777344
Loss: 1.982934594154358
Loss: 2.348442554473877
Loss: 2.351020097732544
Loss: 1.9429798126220703
Loss: 2.404539108276367
Loss: 1.9663386344909668
Loss: 2.1592187

In [None]:
# model.eval()
# val_loss = 0.0
# with torch.no_grad():
#     for batch in eval_dataloader:
#         # put batch on device
#         batch = {k: v.to(device) for k, v in batch.items()}
#
#         # forward pass
#         outputs = model(**batch)
#         loss = outputs.loss
#
#         val_loss += loss.item()
#
# print(f"Validation loss after epoch {epoch}:", val_loss / len(eval_dataloader))

In [None]:
# import numpy as np
# f1_scores=np.array(f1_scores)
# accuracy_scores=np.array(accuracy_scores)

In [None]:
# print(f"mean of f1 scores: {f1_scores.mean()}\n")
# print(f"mean of accuracy scores: {accuracy_scores.mean()}")

In [None]:
# import torch
# torch.cuda.empty_cache()


In [None]:
# import gc
# gc.collect()
# torch.cuda.empty_cache()

In [None]:
# final_accuracy_score = accuracy.compute()
# final_f1_score = f1.compute()

In [None]:
# print(f"Final accuracy: {final_accuracy_score['accuracy']}")
# print(f"Final F1 score: {final_f1_score['f1']}")

In [None]:
save_path = '/content/drive/MyDrive/model_checkpoints'

# Save the model and optimizer state at the end of training
os.makedirs(save_path, exist_ok=True)
checkpoint_path = os.path.join(save_path, "100prct_spider_checkpoint5Dinetunedwqa.pt")
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, checkpoint_path)
print(f"Final checkpoint saved to {checkpoint_path}")

Final checkpoint saved to /content/drive/MyDrive/model_checkpoints/100prct_spider_checkpoint5Dinetunedwqa.pt


In [None]:
import collections
import numpy as np
import torch

def compute_prediction_sequence(model, data, device):
  """Computes predictions using model's answers to the previous questions."""

  # prepare data
  input_ids = data["input_ids"].to(device)
  attention_mask = data["attention_mask"].to(device)
  token_type_ids = data["token_type_ids"].to(device)

  all_logits = []
  prev_answers = None

  num_batch = data["input_ids"].shape[0]

  for idx in range(num_batch):

    if prev_answers is not None:
        coords_to_answer = prev_answers[idx]
        # select the current example first, so its ids are the ones updated below
        token_type_ids_example = token_type_ids[idx] # shape (seq_len, 7)
        # Next, set the label ids predicted by the model
        prev_label_ids_example = token_type_ids_example[:,3] # shape (seq_len,)
        model_label_ids = np.zeros_like(prev_label_ids_example.cpu().numpy()) # shape (seq_len,)

        # for each token in the sequence:
        for i in range(model_label_ids.shape[0]):
          segment_id = token_type_ids_example[:,0].tolist()[i]
          col_id = token_type_ids_example[:,1].tolist()[i] - 1
          row_id = token_type_ids_example[:,2].tolist()[i] - 1
          if row_id >= 0 and col_id >= 0 and segment_id == 1:
            model_label_ids[i] = int(coords_to_answer[(col_id, row_id)])

        # set the prev label ids of the example (shape (1, seq_len) )
        token_type_ids_example[:,3] = torch.from_numpy(model_label_ids).type(torch.long).to(device)

    prev_answers = {}
    # get the example
    input_ids_example = input_ids[idx] # shape (seq_len,)
    attention_mask_example = attention_mask[idx] # shape (seq_len,)
    token_type_ids_example = token_type_ids[idx] # shape (seq_len, 7)
    # forward pass to obtain the logits
    outputs = model(input_ids=input_ids_example.unsqueeze(0),
                    attention_mask=attention_mask_example.unsqueeze(0),
                    token_type_ids=token_type_ids_example.unsqueeze(0))
    logits = outputs.logits
    all_logits.append(logits)

    # convert logits to probabilities (which are of shape (1, seq_len))
    dist_per_token = torch.distributions.Bernoulli(logits=logits)
    probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to(dist_per_token.probs.device)

    # Compute average probability per cell, aggregating over tokens.
    # Dictionary maps coordinates to a list of one or more probabilities
    coords_to_probs = collections.defaultdict(list)
    for i, p in enumerate(probabilities.squeeze().tolist()):
      segment_id = token_type_ids_example[:,0].tolist()[i]
      col = token_type_ids_example[:,1].tolist()[i] - 1
      row = token_type_ids_example[:,2].tolist()[i] - 1
      if col >= 0 and row >= 0 and segment_id == 1:
        coords_to_probs[(col, row)].append(p)

    # Next, map cell coordinates to 1 or 0 (depending on whether the mean prob of all cell tokens is > 0.5)
    coords_to_answer = {}
    for key in coords_to_probs:
      coords_to_answer[key] = np.array(coords_to_probs[key]).mean() > 0.5
    prev_answers[idx+1] = coords_to_answer

  logits_batch = torch.cat(tuple(all_logits), 0)

  return logits_batch
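
The core of `compute_prediction_sequence` is the step that turns per-token probabilities into per-cell decisions: every token belonging to a table cell contributes its probability, the probabilities are averaged per cell, and a cell counts as selected when the mean exceeds 0.5. The plain-Python sketch below isolates that step so it can be checked without a model; the function name `cells_selected` and the toy inputs are hypothetical, and it assumes the TAPAS `token_type_ids` convention used above (segment id in column 0, 1-based column id in column 1, 1-based row id in column 2, with 0 meaning "not a table cell").

```python
from collections import defaultdict
from statistics import mean


def cells_selected(token_probs, segment_ids, col_ids, row_ids, threshold=0.5):
    """Average per-token probabilities per table cell, then threshold.

    col_ids/row_ids use the 1-based convention of TAPAS token_type_ids
    (0 = not a table cell); segment_id 1 marks table tokens.
    Returns {(col, row): bool} with 0-based coordinates.
    """
    per_cell = defaultdict(list)
    for p, seg, col, row in zip(token_probs, segment_ids, col_ids, row_ids):
        col, row = col - 1, row - 1
        if seg == 1 and col >= 0 and row >= 0:
            per_cell[(col, row)].append(p)
    # a cell is selected when the mean probability of its tokens exceeds the threshold
    return {cell: mean(probs) > threshold for cell, probs in per_cell.items()}


# Two question tokens (segment 0), then two tokens per cell for cells (0, 0) and (1, 0)
print(cells_selected(
    token_probs=[0.9, 0.1, 0.8, 0.6, 0.2, 0.3],
    segment_ids=[0, 0, 1, 1, 1, 1],
    col_ids=[0, 0, 1, 1, 2, 2],
    row_ids=[0, 0, 1, 1, 1, 1],
))
# {(0, 0): True, (1, 0): False}
```

Header tokens have row id 0, so they are dropped by the `row >= 0` check, exactly as in the notebook function.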

In [None]:
queries = ['Show all distinct building descriptions.', 'Give me a list of all the distinct building descriptions.', 'Show the short names of the buildings managed by "Emma".', 'Which buildings does "Emma" manage? Give me the short names of the buildings.', 'Show the addresses and phones of all the buildings managed by "Brenden".', 'What are the address and phone number of the buildings managed by "Brenden"?']

inputs = tokenizer(table=table, queries=queries, padding='max_length', return_tensors="pt", truncation=True)  # I added truncation; maybe remove it
logits = compute_prediction_sequence(model, inputs, 'cuda')

In [None]:
predicted_answer_coordinates, = tokenizer.convert_logits_to_predictions(inputs, logits.cpu().detach())

In [None]:
# Helper: map the predicted cell coordinates back to values in the pandas DataFrame
answers = []
for coordinates in predicted_answer_coordinates:
  if len(coordinates) == 1:
    # only a single cell:
    answers.append(table.iat[coordinates[0]])
  else:
    # multiple cells
    cell_values = []
    for coordinate in coordinates:
      cell_values.append(table.iat[coordinate])
    answers.append(", ".join(cell_values))

display(table)
print("")
for query, answer in zip(queries, answers):
  print(query)
  print("Predicted answer: " + answer)

Unnamed: 0,building_id,building_short_name,building_full_name,building_description,building_address,building_manager,building_phone
0,133,Normandie Court,Normandie Court,Studio,"7950 Casper Vista Apt. 176\nMarquiseberg, CA 7...",Emma,(948)040-1064x387
1,153,Mercedes House,Mercedes House,Studio,"354 Otto Villages\nCharliefort, VT 71664",Brenden,915-617-2408x832
2,191,The Eugene,The Eugene,Flat,"71537 Gorczany Inlet\nWisozkburgh, AL 08256",Melyssa,(609)946-0491
3,196,VIA 57 WEST,VIA 57 WEST,Studio,"959 Ethel Viaduct\nWest Efrainburgh, DE 40074",Kathlyn,681.772.2454
4,225,Columbus Square,Columbus Square,Studio,"0703 Danika Mountains Apt. 362\nMohrland, AL 5...",Kyle,1-724-982-9507x640
5,532,Avalon Park,Avalon Park,Duplex,"6827 Kessler Parkway Suite 908\nAhmedberg, WI ...",Albert,376-017-3538
6,556,Peter Cooper Village,Peter Cooper Village,Flat,"861 Narciso Glens Suite 392\nEast Ottis, ND 73970",Darlene,1-224-619-0295x13195
7,624,Stuyvesant Town,Stuyvesant Town,Studio,101 Queenie Mountains Suite 619\nNew Korbinmou...,Marie,(145)411-6406
8,644,The Anthem,The Anthem,Flat,"50804 Mason Isle Suite 844\nWest Whitney, ID 6...",Ewald,(909)086-5221x3455
9,673,Barclay Tower,Barclay Tower,Flat,"1579 Runte Forges Apt. 548\nLeuschkeland, OK 1...",Rogers,1-326-267-3386x613



Show all distinct building descriptions.
Predicted answer: Studio, Flat, Duplex
Give me a list of all the distinct building descriptions.
Predicted answer: Studio, Flat, Duplex
Show the short names of the buildings managed by "Emma".
Predicted answer: Normandie Court
Which buildings does "Emma" manage? Give me the short names of the buildings.
Predicted answer: Normandie Court
Show the addresses and phones of all the buildings managed by "Brenden".
Predicted answer: 915-617-2408x832
What are the address and phone number of the buildings managed by "Brenden"?
Predicted answer: 915-617-2408x832


# Test with QATCH

In [None]:
from qatch.database_reader import MultipleDatabases

# The path to multiple databases
db_save_path = '/content/drive/MyDrive/spider/test_database'
databases = MultipleDatabases(db_save_path)

In [None]:
# from qatch import TestGenerator

# # init generator
# test_generator = TestGenerator(databases=databases)

# # generate tests for each database and for each generator
# tests_df = test_generator.generate()

Generating test for each database: 100%|██████████| 206/206 [04:09<00:00,  1.21s/it]


In [None]:
# excluded_tags = [
#     'GROUPBY-NO-AGGR', 'GROUPBY-COUNT', 'GROUPBY-AGG-MIN',
#     'GROUPBY-AGG-MAX', 'GROUPBY-AGG-AVG', 'GROUPBY-AGG-SUM',
#     'HAVING-COUNT-GR', 'HAVING-COUNT-LS', 'HAVING-COUNT-EQ',
#     'HAVING-AGG-AVG-GR', 'HAVING-AGG-AVG-LS', 'HAVING-AGG-SUM-GR',
#     'HAVING-AGG-SUM-LS', 'SIMPLE-AGG-COUNT',
#     'SIMPLE-AGG-COUNT-DISTINCT', 'SIMPLE-AGG-MAX', 'SIMPLE-AGG-MIN',
#     'SIMPLE-AGG-AVG', 'NULL-COUNT', 'NOT-NULL-COUNT',
#     'WHERE-CAT-MOST-FREQUENT', 'WHERE-CAT-LEAST-FREQUENT',
#     'WHERE-NOT-MOST-FREQUENT', 'WHERE-NOT-LEAST-FREQUENT', 'WHERE-NUM-MAX-VALUES-EMPTY',
#     'WHERE-NUM-MAX-VALUES', 'WHERE-NUM-MIN-VALUES',
#     'WHERE-NUM-MIN-VALUES-EMPTY', 'WHERE-NUM-MEAN-VALUES']

In [None]:
# filtered_tests_df = tests_df[~tests_df['sql_tags'].apply(lambda tag: any(excluded_tag in tag for excluded_tag in excluded_tags))]
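
Assuming `sql_tags` holds a single tag string per row (which is what the displayed DataFrame suggests), `excluded_tag in tag` is a substring test, not an equality test: an excluded name also removes every tag that merely contains it. The sketch below contrasts the two behaviours on a hypothetical two-tag subset of the list; with the full list used here the result is probably the same, because the longer tags (e.g. `NOT-NULL-COUNT`) are excluded explicitly anyway, but the difference matters if the list is edited later.

```python
# Hypothetical subset of excluded_tags, for illustration only
EXCLUDED = {"NULL-COUNT", "WHERE-NUM-MAX-VALUES"}


def excluded_by_substring(tag, excluded=EXCLUDED):
    # Mirrors the notebook: `excluded_tag in tag` on a string is a substring test
    return any(ex in tag for ex in excluded)


def excluded_exactly(tag, excluded=EXCLUDED):
    # Exact membership: only tags spelled exactly as in the list are dropped
    return tag in excluded


for tag in ["SELECT-ALL", "NOT-NULL-COUNT", "WHERE-NUM-MAX-VALUES-EMPTY"]:
    print(tag, excluded_by_substring(tag), excluded_exactly(tag))
```

In pandas, the exact-match variant is simply `tests_df[~tests_df['sql_tags'].isin(excluded_tags)]`.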

In [None]:
# filtered_tests_df

Unnamed: 0,db_id,tbl_name,sql_tags,query,question
0,browser_web,Web_client_accelerator,SELECT-ALL,"SELECT * FROM ""Web_client_accelerator""",Show all the rows in the table Web_client_acce...
1,browser_web,Web_client_accelerator,SELECT-ADD-COL,"SELECT ""id"" FROM ""Web_client_accelerator""","Show all ""id"" in the table Web_client_accelerator"
2,browser_web,Web_client_accelerator,SELECT-ADD-COL,"SELECT ""id"", ""name"" FROM ""Web_client_accelerator""","Show all ""id"", ""name"" in the table Web_client_..."
3,browser_web,Web_client_accelerator,SELECT-ADD-COL,"SELECT ""id"", ""name"", ""Operating_system"" FROM ""...","Show all ""id"", ""name"", ""Operating_system"" in t..."
4,browser_web,Web_client_accelerator,SELECT-ADD-COL,"SELECT ""id"", ""name"", ""Operating_system"", ""Clie...","Show all ""id"", ""name"", ""Operating_system"", ""Cl..."
...,...,...,...,...,...
62788,school_player,player,DISTINCT-SINGLE,"SELECT DISTINCT ""Team"" FROM ""player""","Show the different ""Team"" in the table player"
62789,school_player,player,DISTINCT-SINGLE,"SELECT DISTINCT ""Position"" FROM ""player""","Show the different ""Position"" in the table player"
62790,school_player,player,DISTINCT-MULT,"SELECT DISTINCT ""Team"" FROM ""player""","Show the different ""Team"" in the table ""player"""
62791,school_player,player,DISTINCT-MULT,"SELECT DISTINCT ""Team"", ""Position"" FROM ""player""","Show the different ""Team"", ""Position"" in the t..."


In [None]:
# filtered_df_tests = filtered_tests_df[filtered_tests_df.apply(lambda row: row['tbl_name'] in db_to_tables_constraint.get(row['db_id'], []), axis=1)]

In [None]:
# filtered_df_tests

Unnamed: 0,db_id,tbl_name,sql_tags,query,question
10,browser_web,browser,SELECT-ALL,"SELECT * FROM ""browser""",Show all the rows in the table browser
11,browser_web,browser,SELECT-ADD-COL,"SELECT ""id"" FROM ""browser""","Show all ""id"" in the table browser"
12,browser_web,browser,SELECT-ADD-COL,"SELECT ""id"", ""name"" FROM ""browser""","Show all ""id"", ""name"" in the table browser"
13,browser_web,browser,SELECT-RANDOM-COL,"SELECT ""name"" FROM ""browser""","Show all ""name"" in the table browser"
14,browser_web,browser,SELECT-RANDOM-COL,"SELECT ""name"", ""market_share"" FROM ""browser""","Show all ""name"", ""market_share"" in the table b..."
...,...,...,...,...,...
62726,school_player,school_details,DISTINCT-MULT,"SELECT DISTINCT ""Class"" FROM ""school_details""","Show the different ""Class"" in the table ""schoo..."
62727,school_player,school_details,DISTINCT-MULT,"SELECT DISTINCT ""Class"", ""Division"" FROM ""scho...","Show the different ""Class"", ""Division"" in the ..."
62728,school_player,school_details,DISTINCT-MULT,"SELECT DISTINCT ""Class"", ""Division"", ""Colors"" ...","Show the different ""Class"", ""Division"", ""Color..."
62729,school_player,school_details,DISTINCT-MULT,"SELECT DISTINCT ""Class"", ""Division"", ""Colors"",...","Show the different ""Class"", ""Division"", ""Color..."


In [None]:
# filtered_df_sampled

Unnamed: 0,db_id,tbl_name,sql_tags,query,question
15511,phone_1,phone,ORDERBY-PROJECT,"SELECT ""Accreditation_level"" FROM ""phone"" ORDE...","Project the ""Accreditation_level"" ordered in d..."
49809,election,party,DISTINCT-MULT,"SELECT DISTINCT ""Comptroller"", ""US_Senate"", ""A...","Show the different ""Comptroller"", ""US_Senate"",..."
51446,tracking_orders,Order_Items,ORDERBY-PROJECT,"SELECT ""order_id"" FROM ""Order_Items"" ORDER BY ...","Project the ""order_id"" ordered in descending o..."
9584,culture_company,movie,DISTINCT-SINGLE,"SELECT DISTINCT ""Publisher"" FROM ""book_club""","Show the different ""Publisher"" in the table bo..."
55896,coffee_shop,shop,DISTINCT-SINGLE,"SELECT DISTINCT ""Address"" FROM ""shop""","Show the different ""Address"" in the table shop"
...,...,...,...,...,...
7264,store_product,store,DISTINCT-SINGLE,"SELECT DISTINCT ""Type"" FROM ""store""","Show the different ""Type"" in the table store"
24852,icfp_1,Inst,ORDERBY-SINGLE,"SELECT * FROM ""Inst"" ORDER BY ""instID"" DESC",Show all data ordered by instID in descending ...
50883,sports_competition,player,DISTINCT-SINGLE,"SELECT DISTINCT ""name"" FROM ""club""","Show the different ""name"" in the table club"
52432,phone_market,phone,ORDERBY-PROJECT,"SELECT ""Carrier"" FROM ""phone"" ORDER BY ""Carrie...","Project the ""Carrier"" ordered in ascending ord..."


In [None]:
# filtered_df_sampled = filtered_df_sampled.reset_index().rename(columns={'index': 'row_index'})

In [None]:
device='cuda'

In [None]:
def getAnswer(row):
    row_index = row['row_index']
    print(f"Processing row {row_index}...")

    seq_id = row["db_id"] + "_X_" + row["tbl_name"]

    # Check if the table exists in the dictionary and is not empty
    if seq_id not in db_table_to_df or db_table_to_df[seq_id].empty:
        print(f"Row {row_index} skipped: table not found or empty.")
        return None

    table = db_table_to_df[seq_id]
    table = table.astype(str)
    query_nl = [row["question"]]

    try:
        inputs = tokenizer(table=table, queries=query_nl, padding='max_length', return_tensors="pt", truncation=True)
        logits = compute_prediction_sequence(model, inputs, device)
        predicted_answer_coordinates, = tokenizer.convert_logits_to_predictions(inputs, logits.cpu().detach())

        answers = []
        for coordinates in predicted_answer_coordinates:
            if len(coordinates) == 1:
                # Only a single cell:
                answers.append(table.iat[coordinates[0]])
            else:
                # Multiple cells
                cell_values = []
                for coordinate in coordinates:
                    cell_values.append(table.iat[coordinate])
                answers.append(", ".join(cell_values))
        print(f"Row {row_index} processed successfully.")
    except IndexError:
        # If there's an index out of range error, skip this table
        print(f"Row {row_index} skipped due to IndexError.")
        return None

    return answers
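
`getAnswer` returns one entry per query, and for a multi-cell answer that entry is a single `", ".join(...)` string, so the row structure of the answer is lost (and cells that themselves contain commas, such as addresses, can no longer be separated). If QATCH's QA evaluation compares predictions cell by cell and row by row (an assumption worth checking against the QATCH documentation), a row-structured prediction may fit it better. The sketch below groups TAPAS' predicted `(row, column)` coordinates into one list of values per table row; `coordinates_to_rows` is a hypothetical helper, and `table_rows` stands in for `table.values.tolist()`.

```python
from collections import defaultdict


def coordinates_to_rows(coordinates, table_rows):
    """Group predicted (row, col) cells into one list of values per table row.

    `table_rows` is a list of rows of cell strings, e.g. table.values.tolist().
    Rows and cells are returned in table order.
    """
    by_row = defaultdict(list)
    for row, col in sorted(coordinates):
        by_row[row].append(table_rows[row][col])
    return [by_row[r] for r in sorted(by_row)]


table_rows = [["Normandie Court", "Emma"], ["Mercedes House", "Brenden"]]
print(coordinates_to_rows([(1, 1), (0, 0), (1, 0)], table_rows))
# [['Normandie Court'], ['Mercedes House', 'Brenden']]
```

TAPAS selects a set of cells without any ordering, so sorting by table position loses nothing the model actually predicted.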


In [None]:
# # Apply the function row-wise
# filtered_df_sampled['predictions_TAPAS'] = filtered_df_sampled.apply(getAnswer, axis=1)

Processing row 15511...
Row 15511 processed successfully.
Processing row 49809...
Row 49809 processed successfully.
Processing row 51446...
Row 51446 processed successfully.
Processing row 9584...
Row 9584 processed successfully.
Processing row 55896...
Row 55896 processed successfully.
Processing row 28534...
Row 28534 processed successfully.
Processing row 56856...
Row 56856 processed successfully.
Processing row 47919...
Row 47919 processed successfully.
Processing row 5194...
Row 5194 processed successfully.
Processing row 4934...
Row 4934 processed successfully.
Processing row 27034...
Row 27034 processed successfully.
Processing row 50225...
Row 50225 processed successfully.
Processing row 48130...
Row 48130 processed successfully.
Processing row 62729...
Row 62729 processed successfully.
Processing row 25422...
Row 25422 processed successfully.
Processing row 33967...
Row 33967 processed successfully.
Processing row 16303...
Row 16303 processed successfully.
Processing row 59988

In [None]:
# filtered_df_sampled

Unnamed: 0,row_index,db_id,tbl_name,sql_tags,query,question,predictions_TAPAS
0,15511,phone_1,phone,ORDERBY-PROJECT,"SELECT ""Accreditation_level"" FROM ""phone"" ORDE...","Project the ""Accreditation_level"" ordered in d...","[XPERIA T, XPERIA J, LG-P760, GT-I9300, Z520e,..."
1,49809,election,party,DISTINCT-MULT,"SELECT DISTINCT ""Comptroller"", ""US_Senate"", ""A...","Show the different ""Comptroller"", ""US_Senate"",...","[Carl McCall, Alan Hevesi, John Faso]"
2,51446,tracking_orders,Order_Items,ORDERBY-PROJECT,"SELECT ""order_id"" FROM ""Order_Items"" ORDER BY ...","Project the ""order_id"" ordered in descending o...","[4, 15, 12, 8, 11]"
3,9584,culture_company,movie,DISTINCT-SINGLE,"SELECT DISTINCT ""Publisher"" FROM ""book_club""","Show the different ""Publisher"" in the table bo...","[The Boondock Saints, The Big Kahuna, Storm Ca..."
4,55896,coffee_shop,shop,DISTINCT-SINGLE,"SELECT DISTINCT ""Address"" FROM ""shop""","Show the different ""Address"" in the table shop","[1200 Main Street, 1111 Main Street, 1330 Balt..."
...,...,...,...,...,...,...,...
495,7264,store_product,store,DISTINCT-SINGLE,"SELECT DISTINCT ""Type"" FROM ""store""","Show the different ""Type"" in the table store","[City Mall, Village Store]"
496,24852,icfp_1,Inst,ORDERBY-SINGLE,"SELECT * FROM ""Inst"" ORDER BY ""instID"" DESC",Show all data ordered by instID in descending ...,"[1000, 1010, 1020, 1030, 1040, 1050, 1060, 1070]"
497,50883,sports_competition,player,DISTINCT-SINGLE,"SELECT DISTINCT ""name"" FROM ""club""","Show the different ""name"" in the table club","[Michael Platt, Dave Halley, James Evans, Tame..."
498,52432,phone_market,phone,ORDERBY-PROJECT,"SELECT ""Carrier"" FROM ""phone"" ORDER BY ""Carrie...","Project the ""Carrier"" ordered in ascending ord...","[Sprint, Sprint, TMobile, TMobile]"


In [None]:
# testing=filtered_df_sampled.drop(columns=['sql_tags'])

In [None]:
# testing

Unnamed: 0,row_index,db_id,tbl_name,query,question,predictions_TAPAS
0,15511,phone_1,phone,"SELECT ""Accreditation_level"" FROM ""phone"" ORDE...","Project the ""Accreditation_level"" ordered in d...","[XPERIA T, XPERIA J, LG-P760, GT-I9300, Z520e,..."
1,49809,election,party,"SELECT DISTINCT ""Comptroller"", ""US_Senate"", ""A...","Show the different ""Comptroller"", ""US_Senate"",...","[Carl McCall, Alan Hevesi, John Faso]"
2,51446,tracking_orders,Order_Items,"SELECT ""order_id"" FROM ""Order_Items"" ORDER BY ...","Project the ""order_id"" ordered in descending o...","[4, 15, 12, 8, 11]"
3,9584,culture_company,movie,"SELECT DISTINCT ""Publisher"" FROM ""book_club""","Show the different ""Publisher"" in the table bo...","[The Boondock Saints, The Big Kahuna, Storm Ca..."
4,55896,coffee_shop,shop,"SELECT DISTINCT ""Address"" FROM ""shop""","Show the different ""Address"" in the table shop","[1200 Main Street, 1111 Main Street, 1330 Balt..."
...,...,...,...,...,...,...
495,7264,store_product,store,"SELECT DISTINCT ""Type"" FROM ""store""","Show the different ""Type"" in the table store","[City Mall, Village Store]"
496,24852,icfp_1,Inst,"SELECT * FROM ""Inst"" ORDER BY ""instID"" DESC",Show all data ordered by instID in descending ...,"[1000, 1010, 1020, 1030, 1040, 1050, 1060, 1070]"
497,50883,sports_competition,player,"SELECT DISTINCT ""name"" FROM ""club""","Show the different ""name"" in the table club","[Michael Platt, Dave Halley, James Evans, Tame..."
498,52432,phone_market,phone,"SELECT ""Carrier"" FROM ""phone"" ORDER BY ""Carrie...","Project the ""Carrier"" ordered in ascending ord...","[Sprint, Sprint, TMobile, TMobile]"


In [None]:
from qatch import MetricEvaluator

# evaluator = MetricEvaluator(databases=databases)
# tests_df_results = evaluator.evaluate_with_df(filtered_df_sampled,
#                                       prediction_col_name=f'predictions_TAPAS',
#                                       task="QA")

In [None]:
# tests_df_results.head(50)

Unnamed: 0,db_id,row_index,tbl_name,sql_tags,query,question,predictions_TAPAS,cell_precision_predictions_TAPAS,cell_recall_predictions_TAPAS,tuple_cardinality_predictions_TAPAS,tuple_constraint_predictions_TAPAS,tuple_order_predictions_TAPAS
0,aircraft,9920,match,DISTINCT-MULT,"SELECT DISTINCT ""Fastest_Qualifying"", ""Winning...","Show the different ""Fastest_Qualifying"", ""Winn...","[Hannes Arch, Paul Bonhomme, Nigel Lamb]",0.0,0.0,0.143,0.0,
1,aircraft,9855,match,SELECT-RANDOM-COL,"SELECT ""Date"", ""Winning_Pilot"" FROM ""match""","Show all ""Date"", ""Winning_Pilot"" in the table ...","[March 26–27, April 17–18, May 8–9, June 5–6, ...",0.149,0.357,0.143,0.0,
2,aircraft,9909,airport,ORDERBY-PROJECT,"SELECT ""%_Change_2007"" FROM ""airport"" ORDER BY...","Project the ""%_Change_2007"" ordered in descend...","[1.5%, 2.9%, 6.0%, 4.0%, 2.6%, 4.3%, 0.5%, 7.0...",0.0,0.0,0.1,0.0,0.5
3,aircraft,9860,match,SELECT-RANDOM-COL,"SELECT ""Date"", ""Winning_Pilot"", ""Winning_Aircr...","Show all ""Date"", ""Winning_Pilot"", ""Winning_Air...","[March 26–27, April 17–18, May 8–9, June 5–6, ...",0.141,0.184,0.143,0.0,
4,aircraft,9892,match,ORDERBY-PROJECT,"SELECT ""Winning_Aircraft"" FROM ""match"" ORDER B...","Project the ""Winning_Aircraft"" ordered in asce...","[Mina' Zayid , Abu Dhabi, Swan River , Perth, ...",0.0,0.0,0.143,0.0,0.5
5,aircraft,9881,airport_aircraft,SELECT-ADD-COL,"SELECT ""ID"", ""Airport_ID"" FROM ""airport_aircraft""","Show all ""ID"", ""Airport_ID"" in the table airpo...","[1, 3]",0.5,0.375,0.25,0.0,
6,aircraft,9886,match,ORDERBY-SINGLE,"SELECT * FROM ""match"" ORDER BY ""Winning_Aircra...","Show all data ordered by ""Winning_Aircraft"" in...","[Mina' Zayid , Abu Dhabi, Swan River , Perth, ...",0.0,0.0,0.143,0.0,0.5
7,apartment_rentals,52750,Apartment_Buildings,SELECT-ADD-COL,"SELECT ""building_id"", ""building_short_name"", ""...","Show all ""building_id"", ""building_short_name"",...","[Studio, Studio, Flat, Duplex]",0.0,0.0,0.067,0.0,
8,apartment_rentals,52753,Apartment_Buildings,SELECT-RANDOM-COL,"SELECT ""building_description"" FROM ""Apartment_...","Show all ""building_description"" in the table A...","[Studio, Flat, Duplex]",0.0,0.0,0.067,0.0,
9,apartment_rentals,52773,Apartments,SELECT-RANDOM-COL,"SELECT ""apt_number"", ""bedroom_count"", ""room_co...","Show all ""apt_number"", ""bedroom_count"", ""room_...","[1, 2, 13]",0.5,0.219,0.067,0.0,


In [None]:
# stats_df = tests_df_results[['cell_precision_predictions_TAPAS', 'cell_recall_predictions_TAPAS', 'tuple_cardinality_predictions_TAPAS', 'tuple_constraint_predictions_TAPAS', 'tuple_order_predictions_TAPAS']].describe()

# mean_cell_precision = tests_df_results['cell_precision_predictions_TAPAS'].mean()
# mean_cell_recall = tests_df_results['cell_recall_predictions_TAPAS'].mean()
# mean_tuple_cardinality = tests_df_results['tuple_cardinality_predictions_TAPAS'].mean()
# mean_tuple_constraint = tests_df_results['tuple_constraint_predictions_TAPAS'].mean()
# mean_tuple_order = tests_df_results['tuple_order_predictions_TAPAS'].mean()

# print("Model Performance Metrics:")
# print("----------------------------")
# print("Average Cell Precision:", mean_cell_precision)
# print("Average Cell Recall:", mean_cell_recall)
# print("Average Tuple Cardinality Accuracy:", mean_tuple_cardinality)
# print("Average Tuple Constraint Accuracy:", mean_tuple_constraint)
# print("Average Tuple Order Accuracy:", mean_tuple_order)


# Group by sql_tag and calculate the mean for each category
# stats_df = tests_df_results.groupby('sql_tags')[['cell_precision_predictions_TAPAS', 'cell_recall_predictions_TAPAS', 'tuple_cardinality_predictions_TAPAS', 'tuple_constraint_predictions_TAPAS', 'tuple_order_predictions_TAPAS']].mean()

# # Print the performance metrics for each category
# print("Model Performance Metrics:")
# print("----------------------------")
# for category, row in stats_df.iterrows():
#     print(f"Category: {category}")
#     print("Average Cell Precision:", row['cell_precision_predictions_TAPAS'])
#     print("Average Cell Recall:", row['cell_recall_predictions_TAPAS'])
#     print("Average Tuple Cardinality Accuracy:", row['tuple_cardinality_predictions_TAPAS'])
#     print("Average Tuple Constraint Accuracy:", row['tuple_constraint_predictions_TAPAS'])
#     print("Average Tuple Order Accuracy:", row['tuple_order_predictions_TAPAS'])
#     print("----------------------------")

Model Performance Metrics:
----------------------------
Category: DISTINCT-MULT
Average Cell Precision: 0.017065573770491806
Average Cell Recall: 0.01043076923076923
Average Tuple Cardinality Accuracy: 0.1804615384615384
Average Tuple Constraint Accuracy: 0.0
Average Tuple Order Accuracy: nan
----------------------------
Category: DISTINCT-SINGLE
Average Cell Precision: 0.061224489795918366
Average Cell Recall: 0.014862745098039216
Average Tuple Cardinality Accuracy: 0.2268627450980392
Average Tuple Constraint Accuracy: 0.014862745098039216
Average Tuple Order Accuracy: nan
----------------------------
Category: ORDERBY-PROJECT
Average Cell Precision: 0.1383768115942029
Average Cell Recall: 0.200231884057971
Average Tuple Cardinality Accuracy: 0.10408695652173917
Average Tuple Constraint Accuracy: 0.0
Average Tuple Order Accuracy: 0.5
----------------------------
Category: ORDERBY-SINGLE
Average Cell Precision: 0.16995454545454544
Average Cell Recall: 0.09210294117647057
Average Tuple 
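
The `nan` values for tuple order in the DISTINCT-* categories are expected rather than a bug: in the train-set evaluation log later in the notebook, tuple order is computed for fewer rows than the other metrics (424 vs 1241), which suggests it is only defined for ORDER BY queries and is NaN elsewhere. pandas' `groupby(...).mean()` skips NaN by default, so a category whose values are all NaN averages to `nan`. The plain-Python sketch below mirrors that behaviour; `mean_by_tag` and the row dicts are made up for illustration.

```python
import math
from collections import defaultdict


def mean_by_tag(rows, metric):
    """Average `metric` per SQL tag, skipping NaN (pandas' default skipna=True).

    A tag whose values are all NaN averages to NaN.
    """
    groups = defaultdict(list)
    for row in rows:
        bucket = groups[row["sql_tags"]]  # creates the group even if the value is NaN
        if not math.isnan(row[metric]):
            bucket.append(row[metric])
    return {tag: sum(v) / len(v) if v else math.nan for tag, v in groups.items()}


rows = [
    {"sql_tags": "ORDERBY-PROJECT", "tuple_order": 0.5},
    {"sql_tags": "ORDERBY-PROJECT", "tuple_order": 1.0},
    {"sql_tags": "DISTINCT-MULT", "tuple_order": math.nan},
]
print(mean_by_tag(rows, "tuple_order"))
# {'ORDERBY-PROJECT': 0.75, 'DISTINCT-MULT': nan}
```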

In [None]:
# qatch_synthetic=filtered_df_sampled.drop(columns=["sql_tags", "tbl_name", "row_index", "predictions_TAPAS"], axis=1)
# qatch_synthetic

Unnamed: 0,db_id,query,question
0,phone_1,"SELECT ""Accreditation_level"" FROM ""phone"" ORDE...","Project the ""Accreditation_level"" ordered in d..."
1,election,"SELECT DISTINCT ""Comptroller"", ""US_Senate"", ""A...","Show the different ""Comptroller"", ""US_Senate"",..."
2,tracking_orders,"SELECT ""order_id"" FROM ""Order_Items"" ORDER BY ...","Project the ""order_id"" ordered in descending o..."
3,culture_company,"SELECT DISTINCT ""Publisher"" FROM ""book_club""","Show the different ""Publisher"" in the table bo..."
4,coffee_shop,"SELECT DISTINCT ""Address"" FROM ""shop""","Show the different ""Address"" in the table shop"
...,...,...,...
495,store_product,"SELECT DISTINCT ""Type"" FROM ""store""","Show the different ""Type"" in the table store"
496,icfp_1,"SELECT * FROM ""Inst"" ORDER BY ""instID"" DESC",Show all data ordered by instID in descending ...
497,sports_competition,"SELECT DISTINCT ""name"" FROM ""club""","Show the different ""name"" in the table club"
498,phone_market,"SELECT ""Carrier"" FROM ""phone"" ORDER BY ""Carrie...","Project the ""Carrier"" ordered in ascending ord..."


In [None]:
train_df

Unnamed: 0,ID,answer_text,db_id,query,question,table_used,seq_id,answer_coordinates,new_answer_texts,float_value
1556,4500,"[David CV, 6345, Fall Sails, 7509, How to cook...",document_management,"SELECT document_name , access_count FROM docu...","What are the names of all the documents, as we...",Documents,document_management_X_Documents,"[(5, 4), (5, 3), (13, 4), (13, 3), (6, 4), (6,...","[David CV, 6345, Fall Sails, 7509, How to cook...",
2056,5860,"[r, 13, z, 16, s, 10, s, 19, q, 6, d, 20, m, 7...",tracking_share_transactions,"SELECT lot_details , investor_id FROM LOTS",Return the lot details and investor ids.,Lots,tracking_share_transactions_X_Lots,"[(0, 2), (0, 1), (1, 2), (1, 1), (2, 2), (2, 1...","[r, 13, z, 16, s, 10, s, 19, q, 6, d, 20, m, 7...",
355,1006,"[University of Delaware, Lebanon Valley Colleg...",university_basketball,SELECT school FROM university WHERE founded >...,What are the schools that were either founded ...,university,university_basketball_X_university,"[(0, 1), (1, 1), (2, 1), (3, 1), (4, 1)]","[University of Delaware, Lebanon Valley Colleg...",
615,1847,"[The Great Sasuke §, Gran Hamada, Shinjiro Ota...",wrestler,SELECT Name FROM wrestler ORDER BY Days_held DESC,"What are the names of the wrestlers, ordered d...",wrestler,wrestler_X_wrestler,"[(5, 1), (3, 1), (9, 1), (6, 1), (8, 1), (0, 1...","[The Great Sasuke §, Gran Hamada, Shinjiro Ota...",
2308,6405,"[Review on Canadian files, 42, Review on USA f...",cre_Docs_and_Epenses,"SELECT document_name , document_id FROM Docum...",Find names and ids of all documents with docum...,Documents,cre_Docs_and_Epenses_X_Documents,"[(1, 4), (1, 0), (3, 4), (3, 0), (0, 4), (6, 0...","[Review on Canadian files, 42, Review on USA f...",
...,...,...,...,...,...,...,...,...,...,...
1929,5520,[15],products_gen_characteristics,SELECT count(*) FROM CHARACTERISTICS,How many characteristics are there?,Characteristics,products_gen_characteristics_X_Characteristics,"[(14, 0)]",[15],15.0
1262,3751,[3],program_share,SELECT count(DISTINCT program_id) FROM broadca...,"How many distinct programs are broadcast at ""N...",broadcast,program_share_X_broadcast,"[(2, 0)]",[3],3.0
1303,3858,"[930, 49743]",insurance_policies,"SELECT Amount_Settled , Amount_Claimed FROM C...","Among all the claims, what is the amount claim...",Claims,insurance_policies_X_Claims,"[(6, 5), (6, 4)]","[930, 49743]",
1501,4396,[researcher],tracking_grants_for_research,SELECT role_code FROM Project_Staff GROUP BY r...,Which role is most common for the staff?,Project_Staff,tracking_grants_for_research_X_Project_Staff,"[(2, 2)]",[researcher],


In [None]:
train_df_qatch = train_df.reset_index().rename(columns={'index': 'row_index','table_used':'tbl_name'})
# Apply the function row-wise
train_df_qatch['predictions_TAPAS'] = train_df_qatch.apply(getAnswer, axis=1)

Processing row 1556...
Row 1556 processed successfully.
Processing row 2056...
Row 2056 processed successfully.
Processing row 355...
Row 355 processed successfully.
Processing row 615...
Row 615 processed successfully.
Processing row 2308...
Row 2308 processed successfully.
Processing row 767...
Row 767 processed successfully.
Processing row 955...
Row 955 processed successfully.
Processing row 1474...
Row 1474 processed successfully.
Processing row 372...
Row 372 processed successfully.
Processing row 688...
Row 688 processed successfully.
Processing row 165...
Row 165 processed successfully.
Processing row 2326...
Row 2326 processed successfully.
Processing row 393...
Row 393 processed successfully.
Processing row 91...
Row 91 processed successfully.
Processing row 1759...
Row 1759 processed successfully.
Processing row 55...
Row 55 processed successfully.
Processing row 510...
Row 510 processed successfully.
Processing row 177...
Row 177 processed successfully.
Processing row 2250.

In [None]:
evaluator = MetricEvaluator(databases=databases)
tests_df_results = evaluator.evaluate_with_df(train_df_qatch,
                                              prediction_col_name='predictions_TAPAS',
                                              task="QA")

Getting target results: 100%|██████████| 126/126 [00:01<00:00, 69.75it/s]
  return round(sum_cell_match / prediction.size, 3)
Evaluating cell_precision_predictions_TAPAS: 100%|██████████| 1241/1241 [00:00<00:00, 5120.39it/s]
Evaluating cell_recall_predictions_TAPAS: 100%|██████████| 1241/1241 [00:00<00:00, 4567.45it/s]
Evaluating tuple_cardinality_predictions_TAPAS: 100%|██████████| 1241/1241 [00:00<00:00, 10113.18it/s]
Evaluating tuple_constraint_predictions_TAPAS: 100%|██████████| 1241/1241 [00:00<00:00, 7786.30it/s]
Evaluating tuple_order_predictions_TAPAS: 100%|██████████| 424/424 [00:00<00:00, 6320.65it/s]
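
The stray `return round(sum_cell_match / prediction.size, 3)` line in the log above looks like the tail of a warning raised while computing cell precision: precision is divided by the number of predicted cells, which is zero whenever TAPAS selects no cell. As a rough illustration (my simplified reading of cell precision and recall, not QATCH's actual implementation), the sketch below guards against empty predictions and also shows why a multi-cell answer joined into one string scores zero under cell matching.

```python
def _cells(table):
    # Flatten a list of rows into a flat list of cell values
    return [cell for row in table for cell in row]


def cell_precision(prediction, target):
    """Share of predicted cells that appear anywhere in the target.

    Simplified stand-in, not QATCH's code: an empty prediction scores 0.0
    instead of dividing by zero.
    """
    pred = _cells(prediction)
    if not pred:
        return 0.0
    target_set = set(_cells(target))
    return round(sum(c in target_set for c in pred) / len(pred), 3)


def cell_recall(prediction, target):
    """Share of target cells that appear anywhere in the prediction."""
    gold = _cells(target)
    if not gold:
        return 0.0
    pred_set = set(_cells(prediction))
    return round(sum(c in pred_set for c in gold) / len(gold), 3)


target = [["Studio"], ["Flat"], ["Duplex"]]
print(cell_precision([["Studio, Flat, Duplex"]], target))  # 0.0: one joined string
print(cell_precision([["Studio"], ["Flat"], ["Duplex"]], target))  # 1.0
print(cell_recall([["Studio"]], target))  # 0.333
print(cell_precision([], target))  # 0.0 instead of dividing by zero
```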


In [None]:
stats_df = tests_df_results[['cell_precision_predictions_TAPAS', 'cell_recall_predictions_TAPAS', 'tuple_cardinality_predictions_TAPAS', 'tuple_constraint_predictions_TAPAS', 'tuple_order_predictions_TAPAS']].describe()

mean_cell_precision = tests_df_results['cell_precision_predictions_TAPAS'].mean()
mean_cell_recall = tests_df_results['cell_recall_predictions_TAPAS'].mean()
mean_tuple_cardinality = tests_df_results['tuple_cardinality_predictions_TAPAS'].mean()
mean_tuple_constraint = tests_df_results['tuple_constraint_predictions_TAPAS'].mean()
mean_tuple_order = tests_df_results['tuple_order_predictions_TAPAS'].mean()

print("Model Performance Metrics:")
print("----------------------------")
print("Average Cell Precision:", mean_cell_precision)
print("Average Cell Recall:", mean_cell_recall)
print("Average Tuple Cardinality Accuracy:", mean_tuple_cardinality)
print("Average Tuple Constraint Accuracy:", mean_tuple_constraint)
print("Average Tuple Order Accuracy:", mean_tuple_order)

Model Performance Metrics:
----------------------------
Average Cell Precision: 0.03239685169842588
Average Cell Recall: 0.0725914585012087
Average Tuple Cardinality Accuracy: 0.5791216760676887
Average Tuple Constraint Accuracy: 0.005908944399677679
Average Tuple Order Accuracy: 0.5047169811320755
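In the evaluation log above, `tuple_order` was evaluated on 424 rows. The other metrics were evaluated on all 1241. pandas `Series.mean()` skips missing values by default (`skipna=True`). Assuming the rows without a tuple-order score are stored as NaN or None, the average tuple-order accuracy is therefore computed over that smaller subset, and its denominator differs from the other averages. The standard-library sketch below mimics this behaviour and also returns the number of rows behind each average. `summarize_metrics` and the list-of-dicts layout are hypothetical illustrations and are not part of QATCH.

```python
import math
from statistics import fmean

def summarize_metrics(rows, columns):
    """Mean of each metric column, skipping None/NaN like pandas' default skipna=True.

    rows: list of dicts (a hypothetical stand-in for the rows of tests_df_results).
    Returns {column: (mean or None, number of rows actually averaged)}.
    """
    summary = {}
    for col in columns:
        # Keep only rows where the metric was actually computed.
        values = [
            r[col] for r in rows
            if r.get(col) is not None and not math.isnan(r[col])
        ]
        summary[col] = (fmean(values) if values else None, len(values))
    return summary

rows = [
    {"cell_precision": 1.0, "tuple_order": None},
    {"cell_precision": 0.0, "tuple_order": 1.0},
    {"cell_precision": 0.5, "tuple_order": float("nan")},
]
print(summarize_metrics(rows, ["cell_precision", "tuple_order"]))
```

Putting the row count next to each mean shows when two averages were not computed over the same questions.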


In [None]:
stats_df = tests_df_results[['cell_precision_predictions_TAPAS', 'cell_recall_predictions_TAPAS', 'tuple_cardinality_predictions_TAPAS', 'tuple_constraint_predictions_TAPAS', 'tuple_order_predictions_TAPAS']].describe()

mean_cell_precision = tests_df_results['cell_precision_predictions_TAPAS'].mean()
mean_cell_recall = tests_df_results['cell_recall_predictions_TAPAS'].mean()
mean_tuple_cardinality = tests_df_results['tuple_cardinality_predictions_TAPAS'].mean()
mean_tuple_constraint = tests_df_results['tuple_constraint_predictions_TAPAS'].mean()
mean_tuple_order = tests_df_results['tuple_order_predictions_TAPAS'].mean()

print("Model Performance Metrics:")
print("----------------------------")
print("Average Cell Precision:", mean_cell_precision)
print("Average Cell Recall:", mean_cell_recall)
print("Average Tuple Cardinality Accuracy:", mean_tuple_cardinality)
print("Average Tuple Constraint Accuracy:", mean_tuple_constraint)
print("Average Tuple Order Accuracy:", mean_tuple_order)

Model Performance Metrics:
----------------------------
Average Cell Precision: 0.13624979724249794
Average Cell Recall: 0.14808541498791303
Average Tuple Cardinality Accuracy: 0.5791216760676887
Average Tuple Constraint Accuracy: 0.10475423045930701
Average Tuple Order Accuracy: 0.5200471698113207
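The same summary cell was run twice and reported different averages. Cell precision went from about 0.032 to about 0.136, and tuple constraint rose from about 0.006 to about 0.105. Tuple cardinality stayed at 0.5791. This section does not show what changed between the two runs. A small, hypothetical helper such as `metric_deltas` below can tabulate the per-metric differences between two printouts. The values in it are copied from the two outputs above.

```python
def metric_deltas(first, second):
    """Return {metric: second - first} for metrics present in both runs."""
    return {name: second[name] - first[name] for name in first if name in second}

# Averages copied from the two printouts above.
first_run = {
    "cell_precision": 0.03239685169842588,
    "cell_recall": 0.0725914585012087,
    "tuple_cardinality": 0.5791216760676887,
    "tuple_constraint": 0.005908944399677679,
    "tuple_order": 0.5047169811320755,
}
second_run = {
    "cell_precision": 0.13624979724249794,
    "cell_recall": 0.14808541498791303,
    "tuple_cardinality": 0.5791216760676887,
    "tuple_constraint": 0.10475423045930701,
    "tuple_order": 0.5200471698113207,
}

# Signed change per metric, printed with four decimals.
for name, delta in metric_deltas(first_run, second_run).items():
    print(f"{name:>18}: {delta:+.4f}")
```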


In [None]:
val_df_qatch = val_df.reset_index().rename(columns={'index': 'row_index','table_used':'tbl_name'})
# Apply the function row-wise
val_df_qatch['predictions_TAPAS'] = val_df_qatch.apply(getAnswer, axis=1)

Processing row 576...
Row 576 processed successfully.
Processing row 1698...
Row 1698 processed successfully.
Processing row 415...
Row 415 processed successfully.
Processing row 640...
Row 640 processed successfully.
Processing row 2202...
Row 2202 processed successfully.
Processing row 2154...
Row 2154 processed successfully.
Processing row 259...
Row 259 processed successfully.
Processing row 2159...
Row 2159 processed successfully.
Processing row 888...
Row 888 processed successfully.
Processing row 58...
Row 58 processed successfully.
Processing row 826...
Row 826 processed successfully.
Processing row 1173...
Row 1173 processed successfully.
Processing row 1158...
Row 1158 processed successfully.
Processing row 108...
Row 108 processed successfully.
Processing row 2393...
Row 2393 processed successfully.
Processing row 2051...
Row 2051 processed successfully.
Processing row 1269...
Row 1269 processed successfully.
Processing row 1895...
Row 1895 processed successfully.
Processing
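The row-by-row log above prints two lines per question, and the captured output is cut off. `getAnswer` is defined elsewhere in the notebook and is not reproduced here. If it signals failures by raising an exception, a quieter option is to record the indices of failing rows and print one summary line. The sketch below assumes this. `apply_collecting_failures` and `fake_get_answer` are hypothetical, and `fake_get_answer` stands in for `getAnswer`.

```python
def apply_collecting_failures(rows, fn):
    """Apply fn to each (index, row) pair; keep results and remember which indices failed.

    rows: iterable of (index, row) pairs, e.g. what DataFrame.iterrows() yields.
    Failed rows get None as their result so the output stays aligned with the input.
    """
    results, failures = [], {}
    for index, row in rows:
        try:
            results.append(fn(row))
        except Exception as exc:  # record and continue instead of aborting the whole pass
            results.append(None)
            failures[index] = repr(exc)
    print(f"Processed {len(results)} rows, {len(failures)} failed.")
    return results, failures

def fake_get_answer(row):
    """Hypothetical stand-in for getAnswer: fails when the question is empty."""
    if not row["question"]:
        raise ValueError("empty question")
    return [row["question"].upper()]

rows = [(576, {"question": "a"}), (1698, {"question": ""}), (415, {"question": "b"})]
results, failures = apply_collecting_failures(rows, fake_get_answer)
```

With pandas, you could pass `val_df_qatch.iterrows()` as `rows` and `getAnswer` as `fn`. The `results` list stays in the DataFrame's row order, so it can then be assigned to the `predictions_TAPAS` column.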

In [None]:
evaluator = MetricEvaluator(databases=databases)
tests_df_results = evaluator.evaluate_with_df(val_df_qatch,
                                              prediction_col_name='predictions_TAPAS',
                                              task="QA")

Getting target results: 100%|██████████| 119/119 [01:06<00:00,  1.78it/s]
Evaluating cell_precision_predictions_TAPAS: 100%|██████████| 627/627 [00:00<00:00, 5587.09it/s]
Evaluating cell_recall_predictions_TAPAS: 100%|██████████| 627/627 [00:00<00:00, 5256.03it/s]
Evaluating tuple_cardinality_predictions_TAPAS: 100%|██████████| 627/627 [00:00<00:00, 8525.23it/s]
Evaluating tuple_constraint_predictions_TAPAS: 100%|██████████| 627/627 [00:00<00:00, 9345.48it/s]
Evaluating tuple_order_predictions_TAPAS: 100%|██████████| 223/223 [00:00<00:00, 6845.96it/s]


In [None]:
stats_df = tests_df_results[['cell_precision_predictions_TAPAS', 'cell_recall_predictions_TAPAS', 'tuple_cardinality_predictions_TAPAS', 'tuple_constraint_predictions_TAPAS', 'tuple_order_predictions_TAPAS']].describe()

mean_cell_precision = tests_df_results['cell_precision_predictions_TAPAS'].mean()
mean_cell_recall = tests_df_results['cell_recall_predictions_TAPAS'].mean()
mean_tuple_cardinality = tests_df_results['tuple_cardinality_predictions_TAPAS'].mean()
mean_tuple_constraint = tests_df_results['tuple_constraint_predictions_TAPAS'].mean()
mean_tuple_order = tests_df_results['tuple_order_predictions_TAPAS'].mean()

print("Model Performance Metrics:")
print("----------------------------")
print("Average Cell Precision:", mean_cell_precision)
print("Average Cell Recall:", mean_cell_recall)
print("Average Tuple Cardinality Accuracy:", mean_tuple_cardinality)
print("Average Tuple Constraint Accuracy:", mean_tuple_constraint)
print("Average Tuple Order Accuracy:", mean_tuple_order)

Model Performance Metrics:
----------------------------
Average Cell Precision: 0.09839065108514192
Average Cell Recall: 0.11360765550239234
Average Tuple Cardinality Accuracy: 0.5826124401913871
Average Tuple Constraint Accuracy: 0.06698564593301436
Average Tuple Order Accuracy: 0.5089686098654709


In [None]:
import pandas as pd

def is_number(s):
    """Return True if the string can be parsed as a float."""
    try:
        float(s)
        return True
    except ValueError:
        return False

def normalize_number(s):
    """Map numeric strings to a canonical form so that '3' and '3.0' compare equal."""
    return str(float(s)) if is_number(s) else s

def calculate_metrics(answer_list, predictions_list):
    """Set-based precision, recall and F1 between gold answers and TAPAS predictions.

    Both sides are converted to sets of normalized strings, so duplicates and order are ignored.
    """
    answer_set = {normalize_number(str(x)) for x in answer_list}
    predictions_set = {normalize_number(str(x)) for x in predictions_list}

    true_positives = len(answer_set & predictions_set)    # predicted and correct
    false_negatives = len(answer_set - predictions_set)   # correct but missed
    false_positives = len(predictions_set - answer_set)   # predicted but wrong

    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return precision, recall, f1

def benchmark_tapas(df):
    """Add per-row precision/recall/F1 columns to df and print their averages."""
    df[['precision', 'recall', 'f1']] = df.apply(
        lambda row: calculate_metrics(row['answer_text'], row['predictions_TAPAS']),
        axis=1, result_type='expand')
    avg_precision = df['precision'].mean()
    avg_recall = df['recall'].mean()
    avg_f1 = df['f1'].mean()

    print("Average Precision:", avg_precision)
    print("Average Recall:", avg_recall)
    print("Average F1 Score:", avg_f1)


benchmark_tapas(val_df_qatch)

Average Precision: 0.2503987240829346
Average Recall: 0.22607655502392343
Average F1 Score: 0.23334092807777018
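`calculate_metrics` scores each question by comparing sets of normalized strings. Duplicates and ordering are ignored, `"3"` and `3.0` count as the same answer, and a question whose gold and predicted answers are both empty gets 0 on all three metrics. `benchmark_tapas` then averages the per-question scores, which is a macro average. The standalone sketch below restates the same set-based scoring on two toy questions. For comparison, it also computes a micro average that first pools the counts over all questions. The micro variant is an alternative aggregation and is not what `benchmark_tapas` reports. All names and data in the sketch are illustrative.

```python
def normalize(value):
    """Canonical string form: numeric strings become str(float), so '3' and 3.0 match."""
    text = str(value)
    try:
        return str(float(text))
    except ValueError:
        return text

def counts(answers, predictions):
    """True positives, false positives, false negatives over sets of normalized values."""
    gold = {normalize(x) for x in answers}
    pred = {normalize(x) for x in predictions}
    return len(gold & pred), len(pred - gold), len(gold - pred)

def prf(tp, fp, fn):
    """Precision, recall, F1 from counts, using 0.0 when a denominator is zero."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Toy (gold, prediction) pairs; values are illustrative only.
pairs = [
    (["3", "Paris"], [3.0]),             # '3' and 3.0 match; 'Paris' is missed
    (["a", "b", "c", "d"], ["a", "x"]),  # one hit, one wrong prediction
]

# Macro average: score each question, then average the scores.
per_question = [prf(*counts(g, p)) for g, p in pairs]
macro_precision = sum(p for p, _, _ in per_question) / len(per_question)

# Micro average: pool tp/fp/fn over all questions, then score once.
tp, fp, fn = (sum(c) for c in zip(*(counts(g, p) for g, p in pairs)))
micro_precision, micro_recall, micro_f1 = prf(tp, fp, fn)
print(macro_precision, micro_precision)
```

On this toy data the macro precision is 0.75 and the pooled precision is 2/3. The micro average gives more weight to questions with more predicted values.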


In [None]:
stats_df = tests_df_results[['cell_precision_predictions_TAPAS', 'cell_recall_predictions_TAPAS', 'tuple_cardinality_predictions_TAPAS', 'tuple_constraint_predictions_TAPAS', 'tuple_order_predictions_TAPAS']].describe()

mean_cell_precision = tests_df_results['cell_precision_predictions_TAPAS'].mean()
mean_cell_recall = tests_df_results['cell_recall_predictions_TAPAS'].mean()
mean_tuple_cardinality = tests_df_results['tuple_cardinality_predictions_TAPAS'].mean()
mean_tuple_constraint = tests_df_results['tuple_constraint_predictions_TAPAS'].mean()
mean_tuple_order = tests_df_results['tuple_order_predictions_TAPAS'].mean()

print("Model Performance Metrics:")
print("----------------------------")
print("Average Cell Precision:", mean_cell_precision)
print("Average Cell Recall:", mean_cell_recall)
print("Average Tuple Cardinality Accuracy:", mean_tuple_cardinality)
print("Average Tuple Constraint Accuracy:", mean_tuple_constraint)
print("Average Tuple Order Accuracy:", mean_tuple_order)

Model Performance Metrics:
----------------------------
Average Cell Precision: 0.10164764267990076
Average Cell Recall: 0.11635406698564595
Average Tuple Cardinality Accuracy: 0.5767679425837317
Average Tuple Constraint Accuracy: 0.06937799043062201
Average Tuple Order Accuracy: 0.5138888888888888
