In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel
import pandas as pd
import random
from typing import List, Tuple
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def apply_erosion(schema: str, sql: str,p_drop:int,p_add:int,schemas) -> Tuple[str, str]:
        sql_cols = re.findall(r'<col\d+>', sql)
        tables = re.findall(r'<table\d+>\[(.*?)\]\s\{(.*?)\}', schema)
        table_nums = re.findall(r'<col(\d+)>', schema)
        table_nums = list(map(int, table_nums))
        max_table_num = max(table_nums)+1
        modified_tables = []

        for table in tables:
            modified_table=[]
            table_name=table[0]
            columns = re.findall(r'<col\d+>\[.*?\]:\[.*?\]', table[1])
            
            # Permutation: Randomly reorder the columns
            random.shuffle(columns)
            
            # Removal: Randomly remove columns based on the drop probability
            for col in columns:
                if random.random() < p_drop:
                    columns.remove(col)
                    removed_col = re.search(r'<col\d+>', col).group(0)
                    sql_cols = list(map(lambda x: x.replace(removed_col, "<unk>"), sql_cols))
                    sql = sql.replace(removed_col,"<unk>")
                    
                      
                 
            
            # Addition: Add columns from other schemas based on the add probability
            if random.random() < p_add:
                extra_table = random.choice(schemas)
                extra_columns = re.findall(r'\[.*?\]:\[.*?\]', extra_table)
                
                added_cols= random.sample(extra_columns, k=min(len(extra_columns), 1))
                for col in added_cols:
                    modified_col = f'<col{max_table_num}>{col}'
                    max_table_num += 1
                    columns.append(modified_col)
            
            
            modified_table.append(table_name)
            modified_table.append(" ".join(columns))
            modified_tables.append(modified_table)
        
        modified_schema = " ".join([f"<table{i}>[{table[0]}] {{{table[1]}}}" for i, table in enumerate(modified_tables)])
        
        return modified_schema, sql

In [6]:
from datasets import load_dataset

ds = load_dataset("prkhar05/SeaD_smalltrain",split="train")

In [7]:
schema = ds[0]['schema']
print(schema)
sql = ds[0]['query']
print(sql)

<table0>[table_name_44] { <col0>[numer_of_jamaicans_granted_british_citizenship]:[INTEGER] <col1>[year]:[VARCHAR] <col2>[registration_of_a_minor_child]:[VARCHAR]  } 
SELECT SUM( `<col0>` ) FROM `<table0>` WHERE `<col1>` = 2004 AND `<col2>` > 640


In [12]:
def get_all_text(ds):
    for item in ds:
        return item['schema'] + " " + item['question'] + " " +  item['query']

In [13]:
all_text = get_all_text(ds)
all_text

'<table0>[table_name_44] { <col0>[numer_of_jamaicans_granted_british_citizenship]:[INTEGER] <col1>[year]:[VARCHAR] <col2>[registration_of_a_minor_child]:[VARCHAR]  }   Tell me the sum of number of jamaicans given british citizenship for 2004 and registration of a minor child more than 640\n SELECT SUM( `<col0>` ) FROM `<table0>` WHERE `<col1>` = 2004 AND `<col2>` > 640'

In [14]:
len(ds)

9999

In [5]:
p_drop = 0.2
p_add = 0.4
schemas = ds['train']['schema']

In [220]:
modified_schema , modified_sql = apply_erosion(schema,sql,p_drop,p_add,schemas)
print(modified_schema)
print(modified_sql)

<table0>[table_name_44] {<col0>[numer_of_jamaicans_granted_british_citizenship]:[INTEGER] <col1>[year]:[VARCHAR]}
SELECT SUM( `<col0>` ) FROM `<table0>` WHERE `<col1>` = 2004 AND `<unk>` > 640


In [222]:
sql

'SELECT SUM( `<col0>` ) FROM `<table0>` WHERE `<col1>` = 2004 AND `<col2>` > 640'

In [281]:
def shuffle(sql: str) -> str:
    # Define regex patterns for <col>, <table>, and numbers
    col_pattern = r'<col\d+>'
    table_pattern = r'<table\d+>'
    number_pattern = r'\b\d+\b'  # Matches standalone numbers

    # Function to shuffle a list of items preserving non-matching items
    def shuffle_entities(entities, tokens):
        shuffled_entities = entities[:]
        random.shuffle(shuffled_entities)
        result = []
        index = 0
        for token in tokens:
            if token in entities:
                result.append(shuffled_entities[index])
                index += 1
            else:
                result.append(token)
        return result

    # Tokenize SQL by spaces
    split_pattern = r'\s+|([`()"\'\[\]])'
    tokens = re.split(split_pattern,sql)
    tokens = filter(None, tokens)

    # Collect all <col>, <table>, and number entities
    entities_to_shuffle = []

    cols = re.findall(col_pattern,sql)
    tables = re.findall(table_pattern,sql)
    nums = re.findall(number_pattern,sql)

    entities_to_shuffle = cols + tables + nums
    
    # Shuffle entities while preserving the rest of the tokens
    shuffled_tokens = shuffle_entities(entities_to_shuffle, tokens)
    # Reconstruct SQL query
    shuffled_sql = " ".join(shuffled_tokens)
    return shuffled_sql

In [283]:
shuffled_sql = shuffle(sql)
shuffled_sql

'SELECT SUM ( ` 2004 ` ) FROM ` <table0> ` WHERE ` <col1> ` = 640 AND ` <col0> ` > <col2>'