In [None]:
!pip install --quiet qatch
!pip install --quiet frozendict
!pip install --quiet datasets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.9/71.9 kB[0m [31m834.9 kB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m38.3/38.3 MB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m337.4/337.4 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m16.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
!git clone https://github.com/spapicchio/QATCH.git

Cloning into 'QATCH'...
remote: Enumerating objects: 1302, done.[K
remote: Counting objects: 100% (357/357), done.[K
remote: Compressing objects: 100% (240/240), done.[K
remote: Total 1302 (delta 176), reused 185 (delta 105), pack-reused 945[K
Receiving objects: 100% (1302/1302), 6.07 MiB | 18.67 MiB/s, done.
Resolving deltas: 100% (763/763), done.


In [None]:
import os
import pandas as pd
import sqlite3
import ast
import pickle

In [None]:
#@title Script to get the coordinates of the answers
# coding=utf-8
# Copyright 2019 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""This module implements a simple parser that can be used for TAPAS.

Given a table, a question and one or more answer_texts, it will parse the texts
to populate other fields (e.g. answer_coordinates, float_value) that are required
by TAPAS.

Please note that exceptions in this module are concise and not parameterized,
since they are used as counter names in a BEAM pipeline.
"""

import enum
from typing import Callable, List, Text, Optional

import six
import struct
import unicodedata
import re

import frozendict
import numpy as np
import scipy.optimize


class SupervisionMode(enum.Enum):
  """How much of an example's existing supervision signal to keep.

  The numeric values are fixed; value 1 is not used by any member.
  """
  NONE = 0  # Keep every supervised field untouched.
  REMOVE_ALL = 2  # Drop all supervised signals and rebuild them from answer texts.
  REMOVE_ALL_STRICT = 3  # As REMOVE_ALL, but discard answers matching several cells.


def _find_matching_coordinates(table, answer_text,
                               normalize):
  normalized_text = normalize(answer_text)
  for row_index, row in table.iterrows():
    for column_index, cell in enumerate(row):
      if normalized_text == normalize(str(cell)):
        yield (row_index, column_index)


def _compute_cost_matrix_inner(
    table,
    answer_texts,
    normalize,
    discard_ambiguous_examples,
):
  """Returns a cost matrix M where M[i, j] is the cost of matching answer i to cell j.

  Cells are flattened row-major: cell (row, column) is matrix column
  row * num_columns + column. An entry is -1 when answer i matches that cell
  under `normalize` and 0 otherwise, so the matrix can be passed to the
  Hungarian algorithm (scipy.optimize.linear_sum_assignment), whose minimum-cost
  assignment maximizes the number of answers placed on matching cells.

  Args:
    table: a Pandas dataframe.
    answer_texts: a list of strings.
    normalize: a function that normalizes a string.
    discard_ambiguous_examples: If true discard if answer has multiple matches.

  Raises:
    ValueError: if `discard_ambiguous_examples` is true and some answer text
      matches more than one cell.

  Returns:
    A numpy array with shape (num_answer_texts, num_rows * num_columns), or
    None if some answer text matches no cell under this normalization.
  """
  n_rows, n_columns = table.shape
  cost_matrix = np.zeros((len(answer_texts), n_rows * n_columns))

  for index, answer_text in enumerate(answer_texts):
    found = 0
    for row, column in _find_matching_coordinates(table, answer_text,
                                                  normalize):
      found += 1
      cost_matrix[index, row * n_columns + column] = -1
    if found == 0:
      # No match under this normalization; the caller may try another one.
      return None
    if discard_ambiguous_examples and found > 1:
      raise ValueError("Found multiple cells for answers")

  # TODO(piccinno): Shall we allow ambiguous assignments, i.e. a single cell
  # matched by several answer texts? Detecting that would require counting
  # matches per cell while filling the matrix.
  return cost_matrix


def _compute_cost_matrix(
    table,
    answer_texts,
    discard_ambiguous_examples,
):
  """Builds the answer-to-cell cost matrix, trying each normalization in turn.

  The functions in STRING_NORMALIZATIONS are tried in order. The first one
  whose cost matrix is not None (every answer text matched at least one cell)
  wins. A ValueError from any normalization except the last is swallowed and
  the next one is tried; a ValueError from the last one propagates.

  Returns:
    The cost matrix from _compute_cost_matrix_inner, or None when no
    normalization produced one.
  """
  last_position = len(STRING_NORMALIZATIONS) - 1
  for position, normalizer in enumerate(STRING_NORMALIZATIONS):
    try:
      candidate = _compute_cost_matrix_inner(
          table, answer_texts, normalizer, discard_ambiguous_examples)
    except ValueError:
      if position == last_position:
        raise
      continue
    if candidate is not None:
      return candidate
  return None


def _parse_answer_coordinates(table,
                              answer_texts,
                              discard_ambiguous_examples):
  """Computes answer coordinates by matching answer_texts to table cells.

  Args:
    table: a Pandas dataframe, needed to compute the answer coordinates.
    answer_texts: a list of strings
    discard_ambiguous_examples: If true discard if answer has multiple matches.

  Returns:
    A list of (row, column) position tuples, one per answer text that was
    assigned to a cell it actually matches. Returns None when no normalization
    produced a cost matrix (some answer matched no cell).

  Raises:
    ValueError if the conversion fails.
  """
  cost_matrix = _compute_cost_matrix(
      table,
      answer_texts,
      discard_ambiguous_examples,
  )
  if cost_matrix is None:
    return None
  row_indices, column_indices = scipy.optimize.linear_sum_assignment(
      cost_matrix)

  n_columns = len(table.columns)
  answer_coordinates = []
  # Pair the two index arrays positionally. When there are more answers than
  # cells, `row_indices` is only a subset of the answers, so it must not be
  # used to index `column_indices`.
  for answer_index, flat_position in zip(row_indices, column_indices):
    # A zero-cost assignment means the solver had to give this answer a cell
    # it does not match (e.g. two identical answers but a single matching
    # cell); such a cell is not a real answer coordinate.
    if cost_matrix[answer_index, flat_position] == 0:
      continue
    answer_coordinates.append(divmod(flat_position, n_columns))

  return answer_coordinates


### START OF UTILITIES FROM TEXT_UTILS.PY ###

def wtq_normalize(x):
  """Returns the normalized version of x.

  This normalization function is taken from WikiTableQuestions github, hence the
  wtq prefix. For more information, see
  https://github.com/ppasupat/WikiTableQuestions/blob/master/evaluator.py

  Steps, in order:
  - convert to text;
  - strip diacritics;
  - unify quote and dash characters;
  - repeatedly strip trailing citations, trailing parenthesised details and
    outer double quotes until nothing changes;
  - drop one final '.';
  - collapse whitespace and lower-case;
  - remove HTML-like tags.

  Args:
    x: the object (integer type or string) to normalize.

  Returns:
    A normalized string.
  """
  # six.text_type is str on Python 3.
  x = x if isinstance(x, six.text_type) else six.text_type(x)
  # Remove diacritics.
  x = "".join(
      c for c in unicodedata.normalize("NFKD", x)
      if unicodedata.category(c) != "Mn")
  # Normalize quotes and dashes.
  x = re.sub(u"[‘’´`]", "'", x)
  x = re.sub(u"[“”]", '"', x)
  x = re.sub(u"[‐‑‒–—−]", "-", x)
  # NOTE(review): this character class looks like the same hyphen the previous
  # line already maps to "-", which would make this substitution a no-op; it
  # may originally have held a different (invisible) character — verify.
  x = re.sub(u"[‐]", "", x)
  # Strip trailing decorations repeatedly until a fixed point is reached.
  while True:
    old_x = x
    # Remove citations.
    x = re.sub(u"((?<!^)\\[[^\\]]*\\]|\\[\\d+\\]|[•♦†‡*#+])*$", "",
               x.strip())
    # Remove details in parenthesis.
    x = re.sub(u"(?<!^)( \\([^)]*\\))*$", "", x.strip())
    # Remove outermost quotation mark.
    x = re.sub(u'^"([^"]*)"$', r"\1", x.strip())
    if x == old_x:
      break
  # Remove final '.'.
  if x and x[-1] == ".":
    x = x[:-1]
  # Collapse whitespaces and convert to lower case.
  x = re.sub(r"\s+", " ", x, flags=re.U).lower().strip()
  # Remove HTML-like tags. This runs after whitespace collapsing, so removing a
  # tag between two spaces leaves a double space behind.
  x = re.sub("<[^<]+?>", "", x)
  # Appears to be a no-op: the \s+ collapse above already replaced newlines.
  x = x.replace("\n", " ")
  return x


_TOKENIZER = re.compile(r"\w+|[^\w\s]+", re.UNICODE)


def tokenize_string(x):
  """Lower-cases `x` and splits it into word runs and punctuation runs.

  Whitespace is discarded. Returns a new list of strings.
  """
  lowered = x.lower()
  tokens = _TOKENIZER.findall(lowered)
  return tokens


# List of string normalization functions to be applied in order. We go from
# simplest to more complex normalization procedures.
# _compute_cost_matrix tries them in this order and uses the first one under
# which every answer matches some cell. Each function is applied to both the
# answer text and str(cell), and the results are compared with ==, so for
# tokenize_string (which returns a list) token lists are compared.
STRING_NORMALIZATIONS = (
    lambda x: x,  # exact match
    lambda x: x.lower(),  # case-insensitive match
    tokenize_string,
    wtq_normalize,
)


def to_float32(v):
  """Rounds a float to 32-bit precision; returns any other value unchanged.

  The round trip through struct's single-precision "!f" format drops the extra
  mantissa bits, and the result is returned as a Python float again.
  """
  if isinstance(v, float):
    (narrowed,) = struct.unpack("!f", struct.pack("!f", v))
    return narrowed
  return v


def _split_thousands(delimiter, value):
  """Returns True if `value` split on `delimiter` looks like thousands grouping.

  convert_to_float (copied from TAPAS text_utils.py) relies on this helper,
  but it was not copied into this notebook. Without it, every value
  containing "," raised NameError instead of being parsed.
  """
  parts = value.split(delimiter)
  return len(parts) > 1 and any(len(part) == 3 for part in parts)


def convert_to_float(value):
  """Converts value to a float using a series of increasingly complex heuristics.

  Args:
    value: object that needs to be converted. Allowed types include
      float/int/strings.

  Returns:
    A float interpretation of value.

  Raises:
    ValueError if the float conversion of value fails.
  """
  if isinstance(value, float):
    return value
  if isinstance(value, int):
    return float(value)
  # `str` is what six.string_types resolves to on Python 3.
  if not isinstance(value, str):
    raise ValueError("Argument value is not a string. Can't parse it as float")
  sanitized = value

  try:
    # Example: 1,000.7
    if "." in sanitized and "," in sanitized:
      return float(sanitized.replace(",", ""))
    # 1,000
    if "," in sanitized and _split_thousands(",", sanitized):
      return float(sanitized.replace(",", ""))
    # 5,5556
    if "," in sanitized and sanitized.count(",") == 1 and not _split_thousands(
        ",", sanitized):
      return float(sanitized.replace(",", "."))
    # 0.0.0.1
    if sanitized.count(".") > 1:
      return float(sanitized.replace(".", ""))
    # 0,0,0,1
    if sanitized.count(",") > 1:
      return float(sanitized.replace(",", ""))
    return float(sanitized)
  except ValueError:
    # Avoid adding the sanitized value in the error message.
    raise ValueError("Unable to convert value to float")

### END OF UTILITIES FROM TEXT_UTILS.PY ###

def _parse_answer_float(answer_texts, float_value):
  if len(answer_texts) > 1:
    raise ValueError("Cannot convert to multiple answers to single float")
  float_value = convert_to_float(answer_texts[0])
  float_value = float_value

  return answer_texts, float_value


def _has_single_float_answer_equal_to(question, answer_texts, target):
  """Returns true if the question has a single answer whose value equals to target."""
  if len(answer_texts) != 1:
    return False
  try:
    float_value = convert_to_float(answer_texts[0])
    # In general answer_float is derived by applying the same conver_to_float
    # function at interaction creation time, hence here we use exact match to
    # avoid any false positive.
    return to_float32(float_value) == to_float32(target)
  except ValueError:
    return False


def _parse_question(
    table,
    original_question,
    answer_texts,
    answer_coordinates,
    float_value,
    aggregation_function,
    clear_fields,
    discard_ambiguous_examples,
):
  """Parses question's answer_texts fields to possibly populate additional fields.

  Args:
    table: a Pandas dataframe, needed to compute the answer coordinates.
    original_question: a string.
    answer_texts: a list of strings, serving as the answer to the question.
    anser_coordinates:
    float_value: a float, serves as float value signal.
    aggregation_function:
    clear_fields: A list of strings indicating which fields need to be cleared
      and possibly repopulated.
    discard_ambiguous_examples: If true, discard ambiguous examples.

  Returns:
    A Question message with answer_coordinates or float_value field populated.

  Raises:
    ValueError if we cannot parse correctly the question message.
  """
  question = original_question

  # If we have a float value signal we just copy its string representation to
  # the answer text (if multiple answers texts are present OR the answer text
  # cannot be parsed to float OR the float value is different), after clearing
  # this field.
  if "float_value" in clear_fields and float_value is not None:
    if not _has_single_float_answer_equal_to(question, answer_texts, float_value):
      del answer_texts[:]
      float_value = float(float_value)
      if float_value.is_integer():
        number_str = str(int(float_value))
      else:
        number_str = str(float_value)
      answer_texts = []
      answer_texts.append(number_str)

  if not answer_texts:
    raise ValueError("No answer_texts provided")

  for field_name in clear_fields:
    if field_name == "answer_coordinates":
        answer_coordinates = None
    if field_name == "float_value":
        float_value = None
    if field_name == "aggregation_function":
        aggregation_function = None

  error_message = ""
  if not answer_coordinates:
    try:
      answer_coordinates = _parse_answer_coordinates(
          table,
          answer_texts,
          discard_ambiguous_examples,
      )
    except ValueError as exc:
      error_message += "[answer_coordinates: {}]".format(str(exc))
      if discard_ambiguous_examples:
        raise ValueError(f"Cannot parse answer: {error_message}")

  if not float_value:
    try:
      answer_texts, float_value = _parse_answer_float(answer_texts, float_value)
    except ValueError as exc:
      error_message += "[float_value: {}]".format(str(exc))

  # Raises an exception if we cannot set any of the two fields.
  if not answer_coordinates and not float_value:
    raise ValueError("Cannot parse answer: {}".format(error_message))

  return question, answer_texts, answer_coordinates, float_value, aggregation_function


# TODO(piccinno): Use some sort of introspection here to get the field names of
# the proto.
# Supervision fields that each mode clears before they are recomputed from the
# answer texts. SupervisionMode.NONE has no entry: parse_question returns early
# for it, and rejects any other mode missing here with a ValueError.
_CLEAR_FIELDS = frozendict.frozendict({
    SupervisionMode.REMOVE_ALL: [
        "answer_coordinates", "float_value", "aggregation_function"
    ],
    SupervisionMode.REMOVE_ALL_STRICT: [
        "answer_coordinates", "float_value", "aggregation_function"
    ]
})


def parse_question(table, question, answer_texts, answer_coordinates=None, float_value=None, aggregation_function=None,
                    mode=SupervisionMode.REMOVE_ALL):
    """Parse a question's answer texts into the extra supervision fields TAPAS needs.

    Args:
        table: a Pandas dataframe, needed to compute the answer coordinates. Note that one should apply .astype(str)
        before supplying the table to this function.
        question: a string.
        answer_texts: a list of strings, containing one or more answer texts that serve as answer to the question.
        answer_coordinates: optional answer coordinates supervision signal, if you already have those.
        float_value: optional float supervision signal, if you already have this.
        aggregation_function: optional aggregation function supervised signal, if you already have this.
        mode: see SupervisionMode enum for more information.

    Returns:
        For SupervisionMode.NONE, the pair (question, answer_texts) unchanged.
        Otherwise the 5-tuple (question, answer_texts, answer_coordinates,
        float_value, aggregation_function) produced by _parse_question.

    Raises:
        ValueError: if `mode` has no entry in _CLEAR_FIELDS, or the answer cannot be parsed.
    """
    if mode == SupervisionMode.NONE:
        # Nothing is cleared, so there is nothing to recompute.
        return question, answer_texts

    if (clear_fields := _CLEAR_FIELDS.get(mode)) is None:
        raise ValueError(f"Mode {mode.name} is not supported")

    # Strict mode additionally rejects answers that match several cells.
    strict = mode == SupervisionMode.REMOVE_ALL_STRICT
    return _parse_question(
        table, question, answer_texts, answer_coordinates, float_value,
        aggregation_function, clear_fields,
        discard_ambiguous_examples=strict,
    )

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Work from inside the cloned QATCH repository, presumably so that the
# `from utils import *` in the next cell resolves to QATCH's utils module.
os.chdir("/content/QATCH")

In [None]:
from utils import *
def Merge(dict1, dict2):
    """Return a new dict with the entries of dict1 updated by dict2 (dict2 wins on clashes)."""
    return dict1 | dict2

In [None]:
#@title read data overwrite : mainly to control the number of rows
def read_data(db_id: str, model_name: str,
              input_base_path_data='./data',
              seed: int = 2023, inject_null_percentage: float = 0.0
              ) -> dict[str, pd.DataFrame]:
    """Load both tables of one proprietary domain, sub-sampled per model.

    Overrides QATCH's reader mainly to control how many rows are kept for each
    (domain, model, table) combination.

    Args:
        db_id: one of 'medicine', 'ecommerce', 'miscellaneous', 'finance'.
        model_name: model identifier; normalised with `check_model_names`
            before it is used as part of the sample-size key.
        input_base_path_data: folder with one sub-folder of CSV files per
            domain (``<input_base_path_data>/<db_id>/<file>.csv``).
        seed: random state passed to every table reader and to
            `inject_null_values_in_tables`.
        inject_null_percentage: passed to `inject_null_values_in_tables`.

    Returns:
        Dict mapping table name (e.g. 'heartAttack') to its DataFrame.

    Raises:
        ValueError: if `db_id` is not a known domain.
        KeyError: if the (domain, model, table) combination has no sample size.
    """
    model_name = check_model_names(model_name)

    # (domain, model, table key) -> rows to sample. None is handed straight to
    # the utils readers (presumably meaning "keep every row" — verify there).
    sample_size = {
        ('medicine', 'chatgpt_qa', 'heart-attack'): 30,
        ('medicine', 'llama_qa', 'heart-attack'): 20,
        ('medicine', 'tapas', 'heart-attack'): 15,
        ('medicine', 'tapex', 'heart-attack'): 30,
        ('medicine', 'omnitab', 'heart-attack'): 20,
        ('medicine', 'sp', 'heart-attack'): None,
        ('medicine', 'tapas', 'breast-cancer'): 15,
        ('medicine', 'tapex', 'breast-cancer'): 30,
        ('medicine', 'chatgpt_qa', 'breast-cancer'): 35,
        ('medicine', 'llama_qa', 'breast-cancer'): 25,
        ('medicine', 'omnitab', 'breast-cancer'): 20,
        ('medicine', 'sp', 'breast-cancer'): None,

        ('ecommerce', 'tapas', 'sales-transactions'): 15,
        ('ecommerce', 'tapex', 'sales-transactions'): 20,
        ('ecommerce', 'chatgpt_qa', 'sales-transactions'): 40,
        ('ecommerce', 'llama_qa', 'sales-transactions'): 20,
        ('ecommerce', 'omnitab', 'sales-transactions'): 20,
        ('ecommerce', 'sp', 'sales-transactions'): 30000,
        ('ecommerce', 'tapas', 'fitness-trackers'): 15,
        ('ecommerce', 'tapex', 'fitness-trackers'): 20,
        ('ecommerce', 'chatgpt_qa', 'fitness-trackers'): 30,
        ('ecommerce', 'llama_qa', 'fitness-trackers'): 20,
        ('ecommerce', 'omnitab', 'fitness-trackers'): 20,
        ('ecommerce', 'sp', 'fitness-trackers'): None,

        ('finance', 'tapas', 'fraud'): 15,
        ('finance', 'tapex', 'fraud'): 15,
        ('finance', 'chatgpt_qa', 'fraud'): 30,
        ('finance', 'llama_qa', 'fraud'): 20,
        ('finance', 'omnitab', 'fraud'): 20,
        ('finance', 'sp', 'fraud'): 30000,
        ('finance', 'tapas', 'ibm'): 15,
        ('finance', 'tapex', 'ibm'): 20,
        ('finance', 'chatgpt_qa', 'ibm'): 25,
        ('finance', 'llama_qa', 'ibm'): 20,
        ('finance', 'omnitab', 'ibm'): 20,
        ('finance', 'sp', 'ibm'): None,

        ('miscellaneous', 'tapas', 'mush'): 15,
        ('miscellaneous', 'tapex', 'mush'): 15,
        ('miscellaneous', 'chatgpt_qa', 'mush'): 30,
        ('miscellaneous', 'llama_qa', 'mush'): 20,
        ('miscellaneous', 'omnitab', 'mush'): 20,
        ('miscellaneous', 'sp', 'mush'): None,
        ('miscellaneous', 'tapas', 'adult'): 15,
        ('miscellaneous', 'tapex', 'adult'): 15,
        ('miscellaneous', 'chatgpt_qa', 'adult'): 30,
        ('miscellaneous', 'llama_qa', 'adult'): 20,
        ('miscellaneous', 'omnitab', 'adult'): 20,
        ('miscellaneous', 'sp', 'adult'): 30000
    }

    # domain -> (output table name, reader, CSV file stem, sample-size key),
    # listed in the order the tables are loaded.
    domain_tables = {
        'medicine': (
            ('heartAttack', read_heart_attack_dataset, 'heart-attack', 'heart-attack'),
            ('breastCancer', read_breast_cancer_dataset, 'breast-cancer', 'breast-cancer'),
        ),
        'ecommerce': (
            ('salesTransactions', read_sales_transactions_dataset, 'sales-transactions', 'sales-transactions'),
            ('fitnessTrackers', read_fitness_trackers_dataset, 'fitness-trackers', 'fitness-trackers'),
        ),
        'miscellaneous': (
            ('mushrooms', read_mushroom_dataset, 'mushrooms', 'mush'),
            ('adultCensus', read_adult_dataset, 'adult-census', 'adult'),
        ),
        'finance': (
            ('accountFraud', read_bank_fraud_dataset, 'account-fraud', 'fraud'),
            ('latePayment', read_finance_factory_ibm, 'late-payment', 'ibm'),
        ),
    }

    if db_id not in domain_tables:
        raise ValueError('Unknown dataset name')

    db_tables = {}
    for table_name, reader, file_stem, size_key in domain_tables[db_id]:
        raw_frame = pd.read_csv(f'{input_base_path_data}/{db_id}/{file_stem}.csv')
        db_tables[table_name] = reader(
            raw_frame,
            sample_size=sample_size[(db_id, model_name, size_key)],
            random_state=seed,
        )

    # inject null values
    return inject_null_values_in_tables(inject_null_percentage, db_tables, seed)


In [None]:
# Switch to the Drive folder holding the proprietary CSV datasets.
os.chdir("/content/drive/MyDrive/ProprietaryDatasets")

In [None]:
# Load every domain and combine their tables into one {table_name: DataFrame} dict.
ProprietaryDatasets=['medicine','miscellaneous','ecommerce','finance']
db_to_df={}
base_path='/content/drive/MyDrive/ProprietaryDatasets'
for db_id in ProprietaryDatasets:
    # Row counts come from read_data's sample_size table for the 'tapas' model
    # (after check_model_names normalisation).
    db_to_df_temp=read_data(db_id, 'tapas',
                  input_base_path_data=base_path,
                  seed= 2023)
    # Merge(a, b) is a | b, so on a duplicate table name the entry already
    # collected in db_to_df wins.
    db_to_df=Merge(db_to_df_temp,db_to_df)

In [None]:
# Inspect the loaded tables.
db_to_df

{'mushrooms':         class capshape capsurface capcolor  bruises     odor gillattachment  \
 0   poisonous     flat     smooth    white  bruises  pungent           free   
 1      edible   convex    fibrous     gray       no     none           free   
 2      edible     bell     smooth    white  bruises   almond           free   
 3      edible     flat      scaly    brown  bruises     none           free   
 4   poisonous   convex    fibrous   yellow       no     foul           free   
 5   poisonous   convex     smooth    white  bruises  pungent           free   
 6      edible     flat     smooth    brown       no     none       attached   
 7   poisonous   convex     smooth      red       no    fishy           free   
 8      edible   convex    fibrous     gray       no     none           free   
 9      edible   convex    fibrous    white       no     none           free   
 10  poisonous  knobbed      scaly      red       no     foul           free   
 11  poisonous   convex    

In [None]:
# Switch working directory to the Drive "spider" folder.
os.chdir("/content/drive/MyDrive/spider")

In [None]:
# Alternative: reload previously pickled tables instead of rebuilding them.
# with open('/content/drive/MyDrive/ProprietaryDatasets.pkl', 'rb') as f:
#       db_to_df = pickle.load(f)
# with open('/content/drive/MyDrive/ProprietaryDatasetsDict.pkl', 'rb') as f:
#       db_to_df = pickle.load(f)

# Resolve the object *before* opening the file. open(..., 'wb') truncates the
# file immediately, so a NameError raised inside pickle.dump(...) used to leave
# an empty pickle behind.
# NOTE(review): `domain_to_dict` is not assigned in any visible cell of this
# notebook; confirm which object is meant to be saved here.
object_to_pickle = domain_to_dict
with open('/content/drive/MyDrive/Proprietary/domain_to_dict.pkl', 'wb') as f:
    pickle.dump(object_to_pickle, f)

In [None]:
from qatch.database_reader import SingleDatabase
import pandas as pd  # already imported in an earlier cell; harmless re-import


# define where to store the sqlite database
db_save_path = '/content/drive/MyDrive/ProprietaryDatasets/database8'

# define the name of the database. The spelling 'propiretary' is part of the
# on-disk name; changing it would point at a different database.
db_id = 'propiretary8'

# create database connection (presumably also writes db_to_df's tables into a
# SQLite database under db_save_path — see QATCH's SingleDatabase)
db = SingleDatabase(db_path=db_save_path, db_name=db_id, tables=db_to_df)

In [None]:
from qatch.database_reader import MultipleDatabases

# Wrap the database folder so the test generator can iterate over its
# database(s) — presumably every database stored under db_save_path.
databases = MultipleDatabases(db_save_path)

In [None]:
from qatch import TestGenerator

# init generator
test_generator = TestGenerator(databases=databases)

# generate tests for each database and for each generator; the result is a
# DataFrame with db_id, tbl_name, sql_tags, query and question columns (as the
# displayed output below shows)
tests_df = test_generator.generate()

Generating test for each database: 100%|██████████| 1/1 [00:00<00:00,  2.25it/s]


In [None]:
# Which QATCH test categories (SQL tags) were generated.
tests_df['sql_tags'].unique()

array(['SELECT-ALL', 'SELECT-ADD-COL', 'SELECT-RANDOM-COL',
       'ORDERBY-SINGLE', 'ORDERBY-PROJECT', 'DISTINCT-SINGLE',
       'DISTINCT-MULT', 'WHERE-CAT-MOST-FREQUENT',
       'WHERE-CAT-LEAST-FREQUENT', 'WHERE-NOT-MOST-FREQUENT',
       'WHERE-NOT-LEAST-FREQUENT', 'WHERE-NUM-MAX-VALUES',
       'WHERE-NUM-MIN-VALUES', 'WHERE-NUM-MEAN-VALUES', 'GROUPBY-NO-AGGR',
       'GROUPBY-COUNT', 'GROUPBY-AGG-MIN', 'GROUPBY-AGG-MAX',
       'GROUPBY-AGG-AVG', 'GROUPBY-AGG-SUM', 'HAVING-COUNT-GR',
       'HAVING-COUNT-LS', 'HAVING-COUNT-EQ', 'HAVING-AGG-AVG-GR',
       'HAVING-AGG-AVG-LS', 'HAVING-AGG-SUM-GR', 'HAVING-AGG-SUM-LS',
       'SIMPLE-AGG-COUNT', 'SIMPLE-AGG-COUNT-DISTINCT', 'SIMPLE-AGG-MAX',
       'SIMPLE-AGG-MIN', 'SIMPLE-AGG-AVG'], dtype=object)

In [None]:
# Preview the first generated tests.
tests_df.head()

Unnamed: 0,db_id,tbl_name,sql_tags,query,question
0,propiretary7,salesTransactions,SELECT-ALL,SELECT * FROM `salesTransactions`,Show all the rows in the table salesTransactions
1,propiretary7,salesTransactions,SELECT-ADD-COL,SELECT `transactionno` FROM `salesTransactions`,Show all `transactionno` in the table salesTra...
2,propiretary7,salesTransactions,SELECT-ADD-COL,"SELECT `transactionno`, `date` FROM `salesTran...","Show all `transactionno`, `date` in the table ..."
3,propiretary7,salesTransactions,SELECT-ADD-COL,"SELECT `transactionno`, `date`, `productno` FR...","Show all `transactionno`, `date`, `productno` ..."
4,propiretary7,salesTransactions,SELECT-ADD-COL,"SELECT `transactionno`, `date`, `productno`, `...","Show all `transactionno`, `date`, `productno`,..."


In [None]:
tests_df.shape  # (rows, columns); an earlier run recorded 1848 rows — the count depends on the generated database

(379, 5)

In [None]:
# Later cells refer to the table column as 'table_used'.
tests_df=tests_df.rename(columns={'tbl_name':'table_used'})

In [None]:
# Table name -> domain it was loaded from (see read_data).
domain_mapping = {
    'fitnessTrackers': 'ecommerce',
    'salesTransactions': 'ecommerce',
    'accountFraud': 'finance',
    'latePayment': 'finance',
    'breastCancer': 'medicine',
    'heartAttack': 'medicine',
    'adultCensus': 'miscellaneous',
    'mushrooms': 'miscellaneous'
}


def get_domain(table_name):
    """Return the domain of `table_name`, or 'unknown' if it is not mapped."""
    if table_name in domain_mapping:
        return domain_mapping[table_name]
    return 'unknown'

In [None]:
# Tag each test with its domain ('unknown' for tables missing from domain_mapping).
tests_df['domain']=tests_df['table_used'].apply(get_domain)

In [None]:
# Tables that received at least one test.
tests_df['table_used'].unique()

array(['accountFraud', 'latePayment', 'salesTransactions',
       'fitnessTrackers', 'mushrooms', 'adultCensus', 'heartAttack',
       'breastCancer'], dtype=object)

In [None]:
# Number of tests per domain.
tests_df['domain'].value_counts()

domain
finance          522
medicine         486
ecommerce        461
miscellaneous    379
Name: count, dtype: int64

In [None]:
# Work on a copy so tests_df stays as generated.
qatch_synthetic=tests_df.copy()

In [None]:
# Turn the row index into an explicit 'ID' column; later cells join on it.
qatch_synthetic.reset_index(inplace=True)
qatch_synthetic.rename(columns={'index': 'ID'}, inplace=True)
qatch_synthetic.head()

Unnamed: 0,ID,db_id,table_used,sql_tags,query,question,domain
0,0,propiretary4,accountFraud,SELECT-ALL,SELECT * FROM `accountFraud`,Show all the rows in the table accountFraud,finance
1,1,propiretary4,accountFraud,SELECT-ADD-COL,SELECT `hasothercards` FROM `accountFraud`,Show all `hasothercards` in the table accountF...,finance
2,2,propiretary4,accountFraud,SELECT-ADD-COL,"SELECT `hasothercards`, `housingstatus` FROM `...","Show all `hasothercards`, `housingstatus` in t...",finance
3,3,propiretary4,accountFraud,SELECT-ADD-COL,"SELECT `hasothercards`, `housingstatus`, `date...","Show all `hasothercards`, `housingstatus`, `da...",finance
4,4,propiretary4,accountFraud,SELECT-ADD-COL,"SELECT `hasothercards`, `housingstatus`, `date...","Show all `hasothercards`, `housingstatus`, `da...",finance


In [None]:
# Replacing apostrophes with double quotes in the 'query' column
# Replace backtick identifier quoting with double quotes in the 'query' column
# (standard SQL identifier quoting, which SQLite also accepts).
qatch_synthetic['query'] = qatch_synthetic['query'].str.replace("`", '"', regex=False)

In [None]:
import numpy as np

import re

# Quoted identifiers/literals ("col", 'text', `col`). They are blanked out before
# tokenizing so their contents are never mistaken for SQL keywords.
_SQL_QUOTED_SPAN = re.compile(r'"[^"]*"|\'[^\']*\'|`[^`]*`')
# Word runs or single punctuation characters, so 'MIN("x")' becomes 'MIN', '(', ')'.
_SQL_TOKEN = re.compile(r"\w+|[^\w\s]")


def is_simple_query_with_single_aggregation(query):
    """Classify a SQL query as "simple" and extract its aggregation operator.

    A query is simple when:
    - it contains none of the disallowed keywords (joins, set operations,
      LIMIT, LIKE);
    - it has no nested SELECT;
    - it uses at most one aggregation call.
    If it has exactly one aggregation call, the projection between SELECT and
    FROM must also be a single column.

    Punctuation is split away from words, so calls such as MIN("col"),
    COUNT(*) and sub-queries written as "(SELECT ..." are recognised. A plain
    whitespace split glued them into a single token and missed them.

    Args:
        query: SQL string.

    Returns:
        (is_simple, operator):
        - is_simple: a bool.
        - operator: the aggregation keyword as written in the projection
          (e.g. 'MIN'), or np.nan when the query is not simple or projects no
          aggregation.
    """
    # Define both disallowed general keywords and specific aggregation keywords.
    disallowed_keywords = {
        'join', 'union', 'intersect', 'except', 'inner',
        'outer', 'left', 'right', 'full', 'limit', 'like'
    }
    aggregation_keywords = {
        'count', 'sum', 'avg', 'min', 'max'
    }

    query_toks = _SQL_TOKEN.findall(_SQL_QUOTED_SPAN.sub(' ', query))
    query_toks_lower = [tok.lower() for tok in query_toks]

    # Check for disallowed general keywords.
    if any(tok in disallowed_keywords for tok in query_toks_lower):
        return False, np.nan

    # A second SELECT means a nested query.
    if query_toks_lower.count('select') > 1:
        return False, np.nan

    # Aggregation *calls*: the keyword must be immediately followed by "(".
    aggregation_count = sum(
        1 for i, tok in enumerate(query_toks_lower[:-1])
        if tok in aggregation_keywords and query_toks_lower[i + 1] == '('
    )
    if aggregation_count > 1:
        return False, np.nan

    # If an aggregation is present, ensure only one column is projected.
    if aggregation_count == 1:
        if 'select' not in query_toks_lower or 'from' not in query_toks_lower:
            # No SELECT ... FROM projection to inspect.
            return False, np.nan
        select_index = query_toks_lower.index('select')
        from_index = query_toks_lower.index('from')
        projection = query_toks[select_index + 1:from_index]
        if ',' in projection:
            return False, np.nan
        # The aggregation may sit outside the projection (e.g. in HAVING),
        # in which case no operator is reported.
        operator = next(
            (tok for tok in projection if tok.lower() in aggregation_keywords),
            np.nan)
        return True, operator

    return True, np.nan



filtered_df_simple=qatch_synthetic.copy()

# Classify each query once (it was previously parsed twice per row) and split
# the (is_simple, operator) pairs into two columns.
classification = filtered_df_simple['query'].apply(is_simple_query_with_single_aggregation)
filtered_df_simple['is_simple'] = classification.map(lambda pair: pair[0]).astype(bool)
filtered_df_simple['operator'] = classification.map(lambda pair: pair[1])

simple_queries_df = filtered_df_simple[filtered_df_simple['is_simple']].copy()


simple_queries_df.drop(columns=['is_simple'], inplace=True)

In [None]:
# Inspect the queries kept as simple, with their aggregation operator.
simple_queries_df

Unnamed: 0,ID,db_id,table_used,sql_tags,query,question,domain,operator
0,0,propiretary4,accountFraud,SELECT-ALL,"SELECT * FROM ""accountFraud""",Show all the rows in the table accountFraud,finance,
1,1,propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"" FROM ""accountFraud""",Show all `hasothercards` in the table accountF...,finance,
2,2,propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"" FROM ""...","Show all `hasothercards`, `housingstatus` in t...",finance,
3,3,propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"", ""date...","Show all `hasothercards`, `housingstatus`, `da...",finance,
4,4,propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"", ""date...","Show all `hasothercards`, `housingstatus`, `da...",finance,
...,...,...,...,...,...,...,...,...
1843,1843,propiretary4,breastCancer,SIMPLE-AGG-MIN,"SELECT MIN(""progesteronereceptor"") FROM ""breas...","Find the minimum ""progesteronereceptor"" for th...",medicine,
1844,1844,propiretary4,breastCancer,SIMPLE-AGG-AVG,"SELECT AVG(""progesteronereceptor"") FROM ""breas...","Find the average ""progesteronereceptor"" for th...",medicine,
1845,1845,propiretary4,breastCancer,SIMPLE-AGG-MAX,"SELECT MAX(""estrogenreceptor"") FROM ""breastCan...","Find the maximum ""estrogenreceptor"" for the ta...",medicine,
1846,1846,propiretary4,breastCancer,SIMPLE-AGG-MIN,"SELECT MIN(""estrogenreceptor"") FROM ""breastCan...","Find the minimum ""estrogenreceptor"" for the ta...",medicine,


In [None]:
# Tables that still have at least one simple query.
simple_queries_df['table_used'].unique()

array(['accountFraud', 'latePayment', 'salesTransactions',
       'fitnessTrackers', 'mushrooms', 'adultCensus', 'heartAttack',
       'breastCancer'], dtype=object)

In [None]:
from itertools import chain

db_path = '/content/drive/MyDrive/ProprietaryDatasets/database4/propiretary4/propiretary4.sqlite'

# Execute every simple query against the proprietary SQLite database and keep
# its result set flattened row by row into one list per query ID.
#
# sqlite3.connect() silently creates an empty database file for a wrong path,
# which would then fail later with a confusing "no such table" error.
if not os.path.isfile(db_path):
    raise FileNotFoundError(f'SQLite database not found: {db_path}')

answer_records = []
# A single connection serves all queries and is closed explicitly:
# sqlite3's `with connect(...)` block only commits/rolls back, it never closes.
conn = sqlite3.connect(db_path)
try:
    for query_id, query in zip(simple_queries_df['ID'], simple_queries_df['query']):
        answers = conn.execute(query).fetchall()
        answer_records.append(
            {'ID': query_id, 'answer_text': list(chain.from_iterable(answers))})
finally:
    conn.close()

# Build the frame once; calling pd.concat per row re-copied it every iteration.
results_df = pd.DataFrame(answer_records, columns=['ID', 'answer_text'])

In [None]:
# Preview the query results: one row per ID with its flattened answer list.
results_df

Unnamed: 0,ID,answer_text
0,0,"[False, BA, 8, 0.7000000000000001, AB, CA, 181..."
1,1,"[False, False, False, False, False, False, Fal..."
2,2,"[False, BA, False, BB, False, BB, False, BC, F..."
3,3,"[False, BA, 8, False, BB, 8, False, BB, 1, Fal..."
4,4,"[False, BA, 8, 0.7000000000000001, False, BB, ..."
...,...,...
1843,1843,[0]
1844,1844,[143.93333333333334]
1845,1845,[339]
1846,1846,[0]


In [None]:
# Attach the query metadata to each result and drop queries whose result set
# is empty. The .copy() makes merged_df an independent frame, so the later
# assignment of the answer_coordinates / float_value columns does not raise
# SettingWithCopyWarning on a filtered slice.
merged_df = pd.merge(results_df, simple_queries_df, on='ID')
merged_df = merged_df[merged_df['answer_text'].map(bool)].copy()
merged_df

Unnamed: 0,ID,answer_text,db_id,table_used,sql_tags,query,question,domain,operator
0,0,"[False, BA, 8, 0.7000000000000001, AB, CA, 181...",propiretary4,accountFraud,SELECT-ALL,"SELECT * FROM ""accountFraud""",Show all the rows in the table accountFraud,finance,
1,1,"[False, False, False, False, False, False, Fal...",propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"" FROM ""accountFraud""",Show all `hasothercards` in the table accountF...,finance,
2,2,"[False, BA, False, BB, False, BB, False, BC, F...",propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"" FROM ""...","Show all `hasothercards`, `housingstatus` in t...",finance,
3,3,"[False, BA, 8, False, BB, 8, False, BB, 1, Fal...",propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"", ""date...","Show all `hasothercards`, `housingstatus`, `da...",finance,
4,4,"[False, BA, 8, 0.7000000000000001, False, BB, ...",propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"", ""date...","Show all `hasothercards`, `housingstatus`, `da...",finance,
...,...,...,...,...,...,...,...,...,...
1843,1843,[0],propiretary4,breastCancer,SIMPLE-AGG-MIN,"SELECT MIN(""progesteronereceptor"") FROM ""breas...","Find the minimum ""progesteronereceptor"" for th...",medicine,
1844,1844,[143.93333333333334],propiretary4,breastCancer,SIMPLE-AGG-AVG,"SELECT AVG(""progesteronereceptor"") FROM ""breas...","Find the average ""progesteronereceptor"" for th...",medicine,
1845,1845,[339],propiretary4,breastCancer,SIMPLE-AGG-MAX,"SELECT MAX(""estrogenreceptor"") FROM ""breastCan...","Find the maximum ""estrogenreceptor"" for the ta...",medicine,
1846,1846,[0],propiretary4,breastCancer,SIMPLE-AGG-MIN,"SELECT MIN(""estrogenreceptor"") FROM ""breastCan...","Find the minimum ""estrogenreceptor"" for the ta...",medicine,


In [None]:
def get_answer_coordinates(row, total_rows=1815, progress_every=100):
    """Locate the cells of one query's SQL answer inside its source table.

    Meant for ``df.apply(get_answer_coordinates, axis=1, result_type='expand')``.
    The table is looked up by name in the global ``db_to_df`` mapping and passed,
    together with the question and the answers converted to strings, to
    ``parse_question``.

    Args:
        row: Row exposing ``'ID'``, ``'table_used'``, ``'question'`` and
            ``'answer_text'`` (an iterable of SQL result values).
        total_rows: Total shown in progress messages. The default keeps the
            previously hard-coded value; pass e.g.
            ``df.apply(..., total_rows=len(df))`` for an accurate figure.
        progress_every: Print a progress line for successfully parsed rows whose
            ``ID`` is a multiple of this value; ``0``/``None`` disables it.

    Returns:
        ``(answer_coordinates, new_answer_texts, float_value)`` as produced by
        ``parse_question``, or ``(None, None, None)`` when the table is missing
        from ``db_to_df`` or ``parse_question`` raises. In both failure cases
        the row's ID is appended to the global ``problematic_rows_ids`` list,
        which must be reset before each apply.
    """
    row_id = row['ID']
    table_data = db_to_df.get(row['table_used'])
    if table_data is None:
        problematic_rows_ids.append(row_id)
        return None, None, None

    answer_texts = [str(ans) for ans in row['answer_text']]
    try:
        _, new_answer_texts, answer_coordinates, float_value, _ = parse_question(
            table=table_data, question=row['question'], answer_texts=answer_texts)
    except Exception as e:  # per-row boundary: record the failure, keep going
        problematic_rows_ids.append(row_id)
        print(f"Error processing row {row_id}: {e}")
        return None, None, None

    if progress_every and row_id % progress_every == 0:
        print(f"Progress: Processed {row_id + 1} out of {total_rows} rows.")
    return answer_coordinates, new_answer_texts, float_value



In [None]:
# Run get_answer_coordinates on every merged row. With result_type='expand' its
# (answer_coordinates, new_answer_texts, float_value) tuple becomes columns
# 0, 1 and 2; column 1 is not kept.
problematic_rows_ids = []
result = merged_df.apply(get_answer_coordinates, axis=1, result_type='expand')
merged_df['answer_coordinates'] = result[0]
merged_df['float_value'] = result[2]

Progress: Processed 1 out of 1815 rows.
Progress: Processed 101 out of 1815 rows.
Progress: Processed 201 out of 1815 rows.
Progress: Processed 301 out of 1815 rows.
Progress: Processed 401 out of 1815 rows.
Progress: Processed 501 out of 1815 rows.
Progress: Processed 601 out of 1815 rows.
Progress: Processed 701 out of 1815 rows.
Error processing row 791: Cannot parse answer: [float_value: Cannot convert to multiple answers to single float]
Error processing row 792: Cannot parse answer: [float_value: Cannot convert to multiple answers to single float]
Error processing row 793: Cannot parse answer: [float_value: Cannot convert to multiple answers to single float]
Error processing row 794: Cannot parse answer: [float_value: Cannot convert to multiple answers to single float]
Error processing row 795: Cannot parse answer: [float_value: Cannot convert to multiple answers to single float]
Progress: Processed 801 out of 1815 rows.
Error processing row 804: Cannot parse answer: [float_value

In [None]:
# IDs get_answer_coordinates could not process (table missing from db_to_df,
# or parse_question raised).
problematic_rows_ids

[791,
 792,
 793,
 794,
 795,
 804,
 805,
 806,
 807,
 808,
 809,
 810,
 811,
 818,
 819,
 822,
 832,
 833,
 834,
 835,
 838,
 839,
 849,
 858,
 860,
 861,
 864,
 865,
 872,
 873,
 875,
 876,
 877,
 886,
 887,
 888,
 889,
 890,
 891,
 892,
 893,
 904,
 905,
 906,
 907,
 908,
 909,
 910,
 911,
 912,
 913,
 923,
 924,
 925,
 926,
 927,
 928,
 929,
 930,
 931,
 936,
 937,
 938,
 939,
 945,
 946,
 947,
 948,
 949,
 958,
 959,
 960,
 961,
 962,
 963,
 964,
 965,
 969,
 980,
 981,
 982,
 983,
 985,
 987,
 1619,
 1679,
 1680,
 1683,
 1684,
 1687,
 1688,
 1690,
 1692,
 1695,
 1697,
 1699,
 1701,
 1703,
 1705]

In [None]:
# Exclude every row whose ID get_answer_coordinates recorded as problematic.
merged_df_cleaned = merged_df.loc[~merged_df.ID.isin(problematic_rows_ids)].copy()

In [None]:
# Keep only rows for which answer coordinates were actually produced.
merged_df_cleaned = merged_df_cleaned[merged_df_cleaned['answer_coordinates'].notna()]
merged_df_cleaned

Unnamed: 0,ID,answer_text,db_id,table_used,sql_tags,query,question,domain,operator,answer_coordinates,float_value
0,0,"[False, BA, 8, 0.7000000000000001, AB, CA, 181...",propiretary4,accountFraud,SELECT-ALL,"SELECT * FROM ""accountFraud""",Show all the rows in the table accountFraud,finance,,"[(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5...",
1,1,"[False, False, False, False, False, False, Fal...",propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"" FROM ""accountFraud""",Show all `hasothercards` in the table accountF...,finance,,"[(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0...",
2,2,"[False, BA, False, BB, False, BB, False, BC, F...",propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"" FROM ""...","Show all `hasothercards`, `housingstatus` in t...",finance,,"[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1...",
3,3,"[False, BA, 8, False, BB, 8, False, BB, 1, Fal...",propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"", ""date...","Show all `hasothercards`, `housingstatus`, `da...",finance,,"[(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2...",
4,4,"[False, BA, 8, 0.7000000000000001, False, BB, ...",propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"", ""date...","Show all `hasothercards`, `housingstatus`, `da...",finance,,"[(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1...",
...,...,...,...,...,...,...,...,...,...,...,...
1840,1840,[1],propiretary4,breastCancer,SIMPLE-AGG-MIN,"SELECT MIN(""numberpositivelymphnodes"") FROM ""b...","Find the minimum ""numberpositivelymphnodes"" fo...",medicine,,"[(1, 5)]",1.0
1842,1842,[412],propiretary4,breastCancer,SIMPLE-AGG-MAX,"SELECT MAX(""progesteronereceptor"") FROM ""breas...","Find the maximum ""progesteronereceptor"" for th...",medicine,,"[(3, 6)]",412.0
1843,1843,[0],propiretary4,breastCancer,SIMPLE-AGG-MIN,"SELECT MIN(""progesteronereceptor"") FROM ""breas...","Find the minimum ""progesteronereceptor"" for th...",medicine,,"[(6, 6)]",0.0
1845,1845,[339],propiretary4,breastCancer,SIMPLE-AGG-MAX,"SELECT MAX(""estrogenreceptor"") FROM ""breastCan...","Find the maximum ""estrogenreceptor"" for the ta...",medicine,,"[(3, 7)]",339.0


In [None]:
# Display merged_df_cleaned again (unchanged since the previous cell).
merged_df_cleaned

Unnamed: 0,ID,answer_text,db_id,table_used,sql_tags,query,question,domain,operator,answer_coordinates,float_value
0,0,"[False, BA, 8, 0.7000000000000001, AB, CA, 181...",propiretary4,accountFraud,SELECT-ALL,"SELECT * FROM ""accountFraud""",Show all the rows in the table accountFraud,finance,,"[(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5...",
1,1,"[False, False, False, False, False, False, Fal...",propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"" FROM ""accountFraud""",Show all `hasothercards` in the table accountF...,finance,,"[(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0...",
2,2,"[False, BA, False, BB, False, BB, False, BC, F...",propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"" FROM ""...","Show all `hasothercards`, `housingstatus` in t...",finance,,"[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1...",
3,3,"[False, BA, 8, False, BB, 8, False, BB, 1, Fal...",propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"", ""date...","Show all `hasothercards`, `housingstatus`, `da...",finance,,"[(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2...",
4,4,"[False, BA, 8, 0.7000000000000001, False, BB, ...",propiretary4,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"", ""date...","Show all `hasothercards`, `housingstatus`, `da...",finance,,"[(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1...",
...,...,...,...,...,...,...,...,...,...,...,...
1840,1840,[1],propiretary4,breastCancer,SIMPLE-AGG-MIN,"SELECT MIN(""numberpositivelymphnodes"") FROM ""b...","Find the minimum ""numberpositivelymphnodes"" fo...",medicine,,"[(1, 5)]",1.0
1842,1842,[412],propiretary4,breastCancer,SIMPLE-AGG-MAX,"SELECT MAX(""progesteronereceptor"") FROM ""breas...","Find the maximum ""progesteronereceptor"" for th...",medicine,,"[(3, 6)]",412.0
1843,1843,[0],propiretary4,breastCancer,SIMPLE-AGG-MIN,"SELECT MIN(""progesteronereceptor"") FROM ""breas...","Find the minimum ""progesteronereceptor"" for th...",medicine,,"[(6, 6)]",0.0
1845,1845,[339],propiretary4,breastCancer,SIMPLE-AGG-MAX,"SELECT MAX(""estrogenreceptor"") FROM ""breastCan...","Find the maximum ""estrogenreceptor"" for the ta...",medicine,,"[(3, 7)]",339.0


In [None]:
# Persist the cleaned proprietary benchmark as the _v2 pickle on Drive.
pd.to_pickle(merged_df_cleaned, '/content/drive/MyDrive/spider/QATCH_proprietary_v2.pkl')

In [None]:
# Load the previously saved proprietary pickle (the non-_v2 file).
# NOTE(review): this is not the _v2 file written by the cell above; its recorded
# output shows db_id 'propiretary_dbs' and different answers (e.g. ID 1842) —
# confirm which file is intended here.
qatch_pickle = pd.read_pickle('/content/drive/MyDrive/spider/QATCH_proprietary.pkl')

In [None]:
# Inspect the loaded pickle.
qatch_pickle

Unnamed: 0,ID,answer_text,db_id,table_used,sql_tags,query,question,domain,operator,answer_coordinates,float_value
0,0,"[False, BA, 8, 0.7000000000000001, AB, CA, 181...",propiretary_dbs,accountFraud,SELECT-ALL,"SELECT * FROM ""accountFraud""",Show all the rows in the table accountFraud,finance,,"[(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5...",
1,1,"[False, False, False, False, False, False, Fal...",propiretary_dbs,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"" FROM ""accountFraud""",Show all `hasothercards` in the table accountF...,finance,,"[(0, 0), (1, 0), (2, 0), (3, 0), (4, 0), (5, 0...",
2,2,"[False, BA, False, BB, False, BB, False, BC, F...",propiretary_dbs,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"" FROM ""...","Show all `hasothercards`, `housingstatus` in t...",finance,,"[(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (2, 1...",
3,3,"[False, BA, 8, False, BB, 8, False, BB, 1, Fal...",propiretary_dbs,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"", ""date...","Show all `hasothercards`, `housingstatus`, `da...",finance,,"[(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2...",
4,4,"[False, BA, 8, 0.7000000000000001, False, BB, ...",propiretary_dbs,accountFraud,SELECT-ADD-COL,"SELECT ""hasothercards"", ""housingstatus"", ""date...","Show all `hasothercards`, `housingstatus`, `da...",finance,,"[(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1...",
...,...,...,...,...,...,...,...,...,...,...,...
1840,1840,[1],propiretary_dbs,breastCancer,SIMPLE-AGG-MIN,"SELECT MIN(""numberpositivelymphnodes"") FROM ""b...","Find the minimum ""numberpositivelymphnodes"" fo...",medicine,,"[(1, 5)]",1.0
1842,1842,[422],propiretary_dbs,breastCancer,SIMPLE-AGG-MAX,"SELECT MAX(""progesteronereceptor"") FROM ""breas...","Find the maximum ""progesteronereceptor"" for th...",medicine,,"[(37, 6)]",422.0
1843,1843,[0],propiretary_dbs,breastCancer,SIMPLE-AGG-MIN,"SELECT MIN(""progesteronereceptor"") FROM ""breas...","Find the minimum ""progesteronereceptor"" for th...",medicine,,"[(6, 6)]",0.0
1845,1845,[442],propiretary_dbs,breastCancer,SIMPLE-AGG-MAX,"SELECT MAX(""estrogenreceptor"") FROM ""breastCan...","Find the maximum ""estrogenreceptor"" for the ta...",medicine,,"[(17, 7)]",442.0


In [None]:
#@title Cleaning QATCH
# Build the QATCH synthetic benchmark over the Spider databases, execute every
# generated query, and attach answer coordinates as done for the proprietary set.
from itertools import chain

from qatch import TestGenerator
from qatch.database_reader import MultipleDatabases

# Spider databases that appear in the training split.
db_ids_train_spider = {instance['db_id'] for instance in train_set}

# The path to multiple databases
db_save_path = '/content/drive/MyDrive/spider/test_database'
databases = MultipleDatabases(db_save_path)

# Reuse tests_df when it already exists in the session (the previous version of
# this cell relied on that and had generation commented out); otherwise
# generate and filter it here, so the cell also runs on a fresh kernel.
if 'tests_df' not in globals():
    test_generator = TestGenerator(databases=databases)
    # generate tests for each database and for each generator
    tests_df = test_generator.generate()
    # keep only databases from the Spider training split
    tests_df = tests_df[tests_df['db_id'].isin(db_ids_train_spider)]
    unique_db_ids = tests_df['db_id'].unique()
    # keep only the tables listed for each database in db_to_tables_constraint
    tests_df = tests_df[tests_df.apply(
        lambda row: row['tbl_name'] in db_to_tables_constraint.get(row['db_id'], []), axis=1)]
    tests_df = tests_df.rename(columns={'tbl_name': 'table_used'})

qatch_synthetic = tests_df.drop(columns=['sql_tags']).copy()
# seq_id names the (database, table) pair a question belongs to. A vectorized
# concat also works on an empty frame, where a row-wise apply would not.
qatch_synthetic['seq_id'] = (
    qatch_synthetic['db_id'].astype(str) + '_X_' + qatch_synthetic['table_used'].astype(str))
# The (possibly non-contiguous) original index becomes the ID column.
qatch_synthetic = qatch_synthetic.reset_index().rename(columns={'index': 'ID'})

base_path = '/content/drive/MyDrive/spider/test_database/'

# Execute each generated query on its own database. Consecutive rows with the
# same db_id reuse one connection. Connections are closed explicitly because
# sqlite3's `with connect(...)` block only commits/rolls back and never closes.
answer_records = []
conn, current_db_id = None, None
try:
    for query_id, db_id, query in zip(
            qatch_synthetic['ID'], qatch_synthetic['db_id'], qatch_synthetic['query']):
        if db_id != current_db_id:
            if conn is not None:
                conn.close()
            db_path = os.path.join(base_path, db_id, f'{db_id}.sqlite')
            # sqlite3.connect() would silently create an empty file for a bad path.
            if not os.path.isfile(db_path):
                raise FileNotFoundError(f'SQLite database not found: {db_path}')
            conn, current_db_id = sqlite3.connect(db_path), db_id
        answers = conn.execute(query).fetchall()
        answer_records.append(
            {'ID': query_id, 'answer_text': list(chain.from_iterable(answers))})
finally:
    if conn is not None:
        conn.close()

# Build the frame once instead of calling pd.concat per row (quadratic).
results_df_qatch = pd.DataFrame(answer_records, columns=['ID', 'answer_text'])

merged_df_qatch = pd.merge(results_df_qatch, qatch_synthetic, on='ID')
# Drop queries with empty result sets. Copy so that the columns added below
# are set on an independent frame, not a filtered slice.
merged_df_qatch = merged_df_qatch[merged_df_qatch['answer_text'].map(bool)].copy()

problematic_rows_ids = []

# Coordinates must be computed from the QATCH rows themselves (this previously
# applied to merged_df, the proprietary frame, and aligned that result by index).
# NOTE(review): get_answer_coordinates resolves tables by name through db_to_df
# — confirm db_to_df contains the Spider tables, and that table names do not
# collide across db_ids, before trusting these coordinates.
result = merged_df_qatch.apply(get_answer_coordinates, axis=1, result_type='expand')
merged_df_qatch['answer_coordinates'] = result[0]
merged_df_qatch['float_value'] = result[2]

merged_df_cleaned_qatch = merged_df_qatch[~merged_df_qatch['ID'].isin(problematic_rows_ids)].copy()

merged_df_cleaned_qatch = merged_df_cleaned_qatch.dropna(subset=['answer_coordinates'])