In [1]:
################# USUALLY PREINSTALLED IN COLAB #################
# !pip install numpy pandas sqlalchemy sqlglot
# !pip install torch transformers spacy

!pip install Levenshtein
!pip install accelerate
!pip install bitsandbytes
!pip install sentence-transformers
!pip install spacy-transformers

!python -m spacy download en_core_web_md
!python -m spacy download en_core_web_trf

Collecting Levenshtein
  Downloading levenshtein-0.26.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.2 kB)
Collecting rapidfuzz<4.0.0,>=3.9.0 (from Levenshtein)
  Downloading rapidfuzz-3.9.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading levenshtein-0.26.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (162 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m162.6/162.6 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rapidfuzz-3.9.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m25.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, Levenshtein
Successfully installed Levenshtein-0.26.0 rapidfuzz-3.9.7
Collecting bitsandbytes
  Downloading bitsandbytes-0.43.3-py3-none-manylinux_2_24_x86_64.whl.metadata (3.5 kB)
Downloading bitsandbytes-0.43.3-py3-none

In [2]:
import os
import numpy as np
import pandas as pd

### Input data

We had previously downloaded the [Young People Survey Dataset](https://www.kaggle.com/datasets/miroslavsabo/young-people-survey) from Kaggle and loaded it into a `sqlite` DB. Let's take a look at the data once more.

In [3]:
# Read the raw survey export, then normalise the headers in a single chain:
# drop hyphens, turn spaces into underscores, collapse doubled underscores.
df = pd.read_csv('/content/young_people_survey.csv')
df.columns = (
    df.columns
    .str.replace('-', '')
    .str.replace(' ', '_')
    .str.replace('__', '_')
)
df.head()

Unnamed: 0,Music,Slow_songs_or_fast_songs,Dance,Folk,Country,Classical_music,Musical,Pop,Rock,Metal_or_Hardrock,...,Age,Height,Weight,Number_of_siblings,Gender,Left_right_handed,Education,Only_child,Village_town,House_block_of_flats
0,5.0,3.0,2.0,1.0,2.0,2.0,1.0,5.0,5.0,1.0,...,20.0,163.0,48.0,1.0,female,right handed,college/bachelor degree,no,village,block of flats
1,4.0,4.0,2.0,1.0,1.0,1.0,2.0,3.0,5.0,4.0,...,19.0,163.0,58.0,2.0,female,right handed,college/bachelor degree,no,city,block of flats
2,5.0,5.0,2.0,2.0,3.0,4.0,5.0,3.0,5.0,3.0,...,20.0,176.0,67.0,2.0,female,right handed,secondary school,no,city,block of flats
3,5.0,3.0,2.0,1.0,1.0,1.0,1.0,2.0,2.0,1.0,...,22.0,172.0,59.0,1.0,female,right handed,college/bachelor degree,yes,city,house/bungalow
4,5.0,3.0,4.0,3.0,2.0,4.0,3.0,5.0,3.0,1.0,...,20.0,170.0,59.0,1.0,female,right handed,secondary school,no,village,house/bungalow


### Checking the SQLite DB

Let's connect to our `sqlite` file and look at its contents

In [4]:
from sqlalchemy import create_engine

# SQLite database file that every query in this notebook runs against.
engine = create_engine("sqlite:///mysqlitedb.db")
table_name = 'young_people_survey'

# With the default if_exists='fail', `to_sql` raises ValueError when the table
# is already present. Only that specific case is tolerated (so the cell can be
# re-run); any other ValueError is re-raised instead of being silently dropped.
try:
    df.to_sql(table_name, engine, index=False)
except ValueError as e:
    if 'already exists' not in str(e):
        raise
    print('Table already exists in SQLite DB')
else:
    print('Data loaded from CSV file!')

Data loaded from CSV file!


In [5]:
# Read every row of SQLite's `sqlite_master` table into a DataFrame; its `sql`
# column holds the CREATE statement used to rebuild the schema further below.
with engine.connect() as conn, conn.begin():
    sqlite_master = pd.read_sql_query("SELECT * FROM sqlite_master", conn)

# Bare last expression so the notebook renders the frame.
sqlite_master

Unnamed: 0,type,name,tbl_name,rootpage,sql
0,table,young_people_survey,young_people_survey,2,"CREATE TABLE young_people_survey (\n\t""Music"" ..."


As before, we can get the table DDL from the `sqlite_master` table

In [6]:
# For each CREATE statement: drop the first and last lines, then split every
# remaining line into [column, type] at its last space.
sqlite_master['sql_fmt'] = sqlite_master['sql'].map(
    lambda ddl: [
        ddl_line.strip().strip(',').rsplit(' ', maxsplit=1)
        for ddl_line in ddl.split('\n')[1:-1]
    ]
)

# Object name -> raw DDL text, in the order the rows appear in sqlite_master.
table_desc_dict = dict(zip(sqlite_master['name'], sqlite_master['sql']))

schema = table_desc_dict[table_name]
print(schema)

CREATE TABLE young_people_survey (
	"Music" FLOAT, 
	"Slow_songs_or_fast_songs" FLOAT, 
	"Dance" FLOAT, 
	"Folk" FLOAT, 
	"Country" FLOAT, 
	"Classical_music" FLOAT, 
	"Musical" FLOAT, 
	"Pop" FLOAT, 
	"Rock" FLOAT, 
	"Metal_or_Hardrock" FLOAT, 
	"Punk" FLOAT, 
	"Hiphop,_Rap" FLOAT, 
	"Reggae,_Ska" FLOAT, 
	"Swing,_Jazz" FLOAT, 
	"Rock_n_roll" FLOAT, 
	"Alternative" FLOAT, 
	"Latino" FLOAT, 
	"Techno,_Trance" FLOAT, 
	"Opera" FLOAT, 
	"Movies" FLOAT, 
	"Horror" FLOAT, 
	"Thriller" FLOAT, 
	"Comedy" FLOAT, 
	"Romantic" FLOAT, 
	"Scifi" FLOAT, 
	"War" FLOAT, 
	"Fantasy/Fairy_tales" FLOAT, 
	"Animated" FLOAT, 
	"Documentary" FLOAT, 
	"Western" FLOAT, 
	"Action" FLOAT, 
	"History" FLOAT, 
	"Psychology" FLOAT, 
	"Politics" FLOAT, 
	"Mathematics" FLOAT, 
	"Physics" FLOAT, 
	"Internet" FLOAT, 
	"PC" FLOAT, 
	"Economy_Management" FLOAT, 
	"Biology" FLOAT, 
	"Chemistry" FLOAT, 
	"Reading" FLOAT, 
	"Geography" FLOAT, 
	"Foreign_languages" FLOAT, 
	"Medicine" FLOAT, 
	"Law" FLOAT, 
	"Cars" FLOAT,

## Loading the LLM - SQLCoder

In [7]:
import torch
import sqlparse
from transformers import AutoTokenizer, AutoModelForCausalLM

  _torch_pytree._register_pytree_node(


In [8]:
# Hugging Face Hub id of the text-to-SQL checkpoint used for every generation below.
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the causal LM in 8-bit; the float16 alternative is left commented out
# and the two options are mutually exclusive (see the inline notes).
# NOTE(review): trust_remote_code=True presumably lets code shipped in the model
# repository run locally — confirm it is actually required for this checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    # torch_dtype=torch.float16, # Disable if enabling the below line
    load_in_8bit=True, # Disable if enabling the above line
    device_map="auto",
    use_cache=True,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/515 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/691 [00:00<?, ?B/s]

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/3.59G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

### Inference method

In [9]:
def generate_query(prompt):
    """Generate SQL for ``prompt`` using the notebook-level ``model`` and ``tokenizer``.

    Decoding is greedy (``do_sample=False``, ``num_beams=1``) and capped at
    400 new tokens.

    Args:
        prompt: Fully formatted prompt text. The templates in this notebook
            end with a ``[SQL]`` marker, which is used to locate the answer.

    Returns:
        The decoded text that follows the last ``[SQL]`` marker, or the whole
        decoded text when the marker is absent.
    """
    # Send the inputs to wherever the model was placed (device_map="auto")
    # instead of assuming a "cuda" device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Empty the CUDA cache so that further generations do not run out of GPU
    # memory. This matters mostly on Colab; memory management is much more
    # straightforward when running on an inference service.
    # Guarded so the function also works on a machine without a GPU.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    return outputs[0].split("[SQL]")[-1]

## Prompting - NL2SQL

In [10]:
# Example natural-language questions. Later cells build the full-schema prompt
# from `question2` and use `question` for column pruning and the pruned prompts;
# `question1` (no quoted literals) is not referenced again in this notebook.
question1 = 'Fetch the count of male and female who are right handed and are very interested in Cars ?'
question2 = 'Fetch the count of \'male\' and \'female\' who are \'right handed\' and are very interested in Cars ?'
question = 'Fetch the count of \'male\' and \'female\' who are \'right handed\' and afraid of public speaking ?'

In [15]:
# Prompt template with two placeholders: {question} (used twice) and {db_schema}.
# It ends with the "[SQL]" marker that generate_query() splits on.
prompt_template = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'

## Database Schema
This query will run on a database whose schema is represented in this string: {db_schema}

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION] [SQL]
"""

# Baseline prompt: the complete, unpruned table DDL is passed as the schema.
prompt1 = prompt_template.format(question=question2, db_schema = schema)
print(prompt1)

### Task
Generate a SQL query to answer [QUESTION]Fetch the count of 'male' and 'female' who are 'right handed' and are very interested in Cars ?[/QUESTION] 

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'

## Database Schema 
This query will run on a database whose schema is represented in this string: CREATE TABLE young_people_survey (
	"Music" FLOAT, 
	"Slow_songs_or_fast_songs" FLOAT, 
	"Dance" FLOAT, 
	"Folk" FLOAT, 
	"Country" FLOAT, 
	"Classical_music" FLOAT, 
	"Musical" FLOAT, 
	"Pop" FLOAT, 
	"Rock" FLOAT, 
	"Metal_or_Hardrock" FLOAT, 
	"Punk" FLOAT, 
	"Hiphop,_Rap" FLOAT, 
	"Reggae,_Ska" FLOAT, 
	"Swing,_Jazz" FLOAT, 
	"Rock_n_roll" FLOAT, 
	"Alternative" FLOAT, 
	"Latino" FLOAT, 
	"Techno,_Trance" FLOAT, 
	"Opera" FLOAT, 
	"Movies" FLOAT, 
	"Horror" FLOAT, 
	"Thriller" FLOAT, 
	"Comedy" FLOAT, 
	"Romantic" FLOAT, 
	"Scifi" FLOAT, 
	"War" FLOAT, 
	"Fantasy/Fairy_tales" FLOAT, 
	"Animated" FLOAT, 
	"Documentary" 

In [16]:
# Generate SQL from the full-schema prompt and pretty-print it with sqlparse.
generated_sql1 = generate_query(prompt1)
print(sqlparse.format(generated_sql1, reindent=True))


SELECT SUM(CASE
               WHEN yps.Gender = 'male'
                    AND yps.Left_right_handed = 'right handed'
                    AND yps.Cars > 4.5 THEN 1
               ELSE 0
           END) AS male_right_handed_cars_interested,
       SUM(CASE
               WHEN yps.Gender = 'female'
                    AND yps.Left_right_handed = 'right handed'
                    AND yps.Cars > 4.5 THEN 1
               ELSE 0
           END) AS female_right_handed_cars_interested
FROM young_people_survey yps;


In [17]:
# Run the SQL generated for prompt1 against the SQLite file and show the result.
# NOTE(review): the model output is executed as-is, with no validation step in
# this cell — acceptable for a local demo DB; confirm before reusing elsewhere.
engine = create_engine(f"sqlite:///mysqlitedb.db")
with engine.connect() as conn, conn.begin():
    query_result1 = pd.read_sql_query(generated_sql1, conn)

print(query_result1)

   male_right_handed_cars_interested  female_right_handed_cars_interested
0                                111                                   35


In [14]:
# Cross-check the "female" count in pandas. This filters on Cars > 4 while the
# generated SQL used Cars > 4.5; the two agree only if answers are whole
# numbers — presumably a 1-5 scale, verify against the dataset description.
df[(df['Gender'] == 'female') & (df['Left_right_handed'] == 'right handed') & (df['Cars'] > 4)].shape

(35, 150)

### Why Preprocess: LLM Prompt Pruning

Even though LLMs are offering longer context lengths every day, they still tend to get [lost in the middle](https://arxiv.org/abs/2307.03172), i.e. the model pays the most attention to the beginning and the end of the prompt. Since we add our table DDL to the prompt as context, and a DDL can span hundreds of tables and columns, the LLM often fails to "pick" the relevant column names, especially when they occur in the middle of the DDL.

To fix this, we "prune" the DDL, i.e. we do not include all the columns but only those which we deem relevant to our question. With this shortened DDL, the LLM should have a much easier time selecting the correct column, and we also leave more room in the context window for the generated SQL.

To find the columns which are "relevant" to our question, we compute a similarity metric between the column names/descriptions and the question.

In [18]:
import spacy
import torch
import sqlglot
import numpy as np
import os, re, logging, pickle
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

We use an embedding model called `mxbai-embed-large-v1` to compute the embeddings for our column names and our question. We also use a parameter `top_k`: the number of columns, ranked in descending order of similarity score, that we want to keep from the complete table DDL.

In addition to this, we also include all columns of type `DATE` or `TIMESTAMP` if we detect any time-related terms in our question.

### Load Embedding Model

In [19]:
# Sentence-embedding model, kept on the CPU; used below to embed the column descriptions.
encoder = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", device='cpu')

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/171 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/114k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/677 [00:00<?, ?B/s]

  _torch_pytree._register_pytree_node(


model.safetensors:   0%|          | 0.00/670M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.24k [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/695 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/297 [00:00<?, ?B/s]

### Methods

#### 1. Retrieve top-k columns using KNN model
#### 2. Create schema from top-k columns

In [20]:
def knn_(query: str, all_embs: torch.tensor, top_k: int, threshold: float, encoder=None) -> tuple[torch.tensor, torch.tensor]:
    """
    Return the columns most similar to ``query`` by cosine similarity.

    Args:
        query: Question text to embed.
        all_embs: One embedding per row; each row is compared with the query
            embedding.
        top_k: Maximum number of results returned when more than one row
            clears the threshold.
        threshold: Only rows whose similarity is strictly greater than this
            value are considered.
        encoder: Optional object exposing
            ``encode(text, convert_to_tensor=True, device='cpu')``. When
            omitted, the ``mixedbread-ai/mxbai-embed-large-v1`` model is loaded
            once and reused on later calls instead of being re-loaded on every
            call.

    Returns:
        ``(scores, indices)``. Both are empty tensors when nothing clears the
        threshold; otherwise ``scores`` is sorted in descending order and
        ``indices`` are row positions into ``all_embs``.
    """
    if encoder is None:
        # Loading the model is expensive, so cache it on the function object.
        encoder = getattr(knn_, "_default_encoder", None)
        if encoder is None:
            encoder = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", device='cpu')
            knn_._default_encoder = encoder
    query_emb = encoder.encode(query, convert_to_tensor=True, device='cpu').unsqueeze(0)

    similarity_scores = F.cosine_similarity(query_emb, all_embs)
    top_results = torch.nonzero(similarity_scores > threshold).squeeze()

    # Nothing cleared the threshold: return a pair of empty tensors.
    if top_results.numel() == 0:
        return torch.tensor([]), torch.tensor([])

    # A single hit: squeeze() left a 0-dim index, so restore a length-1 axis
    # on both outputs.
    if top_results.numel() == 1:
        return similarity_scores[top_results].reshape(1), top_results.reshape(1)

    top_k_scores, top_k_indices = torch.topk(similarity_scores[top_results], k=min(top_k, top_results.numel()))
    # topk indexes into the filtered scores; map back to positions in all_embs.
    return top_k_scores, top_results[top_k_indices]



def format_topk_sql(topk_table_columns: dict[str, list[tuple[str, str, str]]], shuffle: bool) -> str:
    """Render the selected columns as ``CREATE TABLE`` text for the prompt.

    Args:
        topk_table_columns: Maps table name to a list of column records,
            ``(name, type)`` or ``(name, type, description)``. The mapping and
            its lists are never modified.
        shuffle: When true, tables and columns are emitted in a fixed
            pseudo-random order (seed 0) instead of their given order.

    Returns:
        ``""`` for an empty mapping. Otherwise one ``CREATE TABLE`` statement
        per table, wrapped in a leading and a trailing newline. Records with a
        third element get it appended as a ``--`` comment.
    """
    if len(topk_table_columns) == 0:
        return ""

    table_names = list(topk_table_columns.keys())
    if shuffle:
        # A private RandomState(0) yields the same permutation as seeding the
        # global generator with 0, without disturbing the global RNG state.
        np.random.RandomState(0).shuffle(table_names)

    statements = []
    for table_name in table_names:
        # Work on a copy so the caller's column order is left untouched.
        columns = list(topk_table_columns[table_name])
        if shuffle:
            np.random.RandomState(0).shuffle(columns)

        column_lines = []
        for column_tuple in columns:
            if len(column_tuple) > 2:
                column_lines.append(
                    f"\n  {column_tuple[0]} {column_tuple[1]}, --{column_tuple[2]}"
                )
            else:
                column_lines.append(f"\n  {column_tuple[0]} {column_tuple[1]}, ")
        statements.append(f"CREATE TABLE {table_name} ({''.join(column_lines)}\n);\n")

    return "\n" + "".join(statements) + "\n"

In [21]:
EMBEDDING_PATH = 'embs'
TOP_K_LIMIT = 25 # number of columns to include in the prompt
PRUNE_LIMIT = 5 # minimum number of columns above which a given DDL will be pruned
# NOTE(review): the three constants above are not read by any cell visible in
# this notebook (the top-k step passes top_k=5 directly) — confirm intent.

# `schema` is the CREATE TABLE text read back from sqlite_master, so it is
# parsed with the SQLite dialect. Each record is
# [table, column name, data type as text, first column constraint or None].
TAB_DETAILS = [
    [
        table_name,
        col.alias_or_name,
        str(col.find(sqlglot.exp.DataType)),
        col.find(sqlglot.exp.ColumnConstraint),
    ]
    for col in sqlglot.parse_one(schema, dialect='sqlite').find_all(sqlglot.exp.ColumnDef)
]
num_cols = len(TAB_DETAILS)

# Two parallel lists, index-aligned with TAB_DETAILS:
#   column_descriptions        "table.column:description"       (text that gets embedded)
#   column_descriptions_typed  "table.column,dtype,description"
column_descriptions = []
column_descriptions_typed = []

for tab_name, col_name, col_dtype, col_desc in TAB_DETAILS:
    column_descriptions.append(f"{tab_name}.{col_name}:{col_desc}")
    column_descriptions_typed.append(f"{tab_name}.{col_name},{col_dtype},{col_desc}")

column_descriptions_typed

['young_people_survey.Music,FLOAT,None',
 'young_people_survey.Slow_songs_or_fast_songs,FLOAT,None',
 'young_people_survey.Dance,FLOAT,None',
 'young_people_survey.Folk,FLOAT,None',
 'young_people_survey.Country,FLOAT,None',
 'young_people_survey.Classical_music,FLOAT,None',
 'young_people_survey.Musical,FLOAT,None',
 'young_people_survey.Pop,FLOAT,None',
 'young_people_survey.Rock,FLOAT,None',
 'young_people_survey.Metal_or_Hardrock,FLOAT,None',
 'young_people_survey.Punk,FLOAT,None',
 'young_people_survey.Hiphop,_Rap,FLOAT,None',
 'young_people_survey.Reggae,_Ska,FLOAT,None',
 'young_people_survey.Swing,_Jazz,FLOAT,None',
 'young_people_survey.Rock_n_roll,FLOAT,None',
 'young_people_survey.Alternative,FLOAT,None',
 'young_people_survey.Latino,FLOAT,None',
 'young_people_survey.Techno,_Trance,FLOAT,None',
 'young_people_survey.Opera,FLOAT,None',
 'young_people_survey.Movies,FLOAT,None',
 'young_people_survey.Horror,FLOAT,None',
 'young_people_survey.Thriller,FLOAT,None',
 'young_peopl

In [22]:
column_embs = encoder.encode(column_descriptions, convert_to_tensor=True, device='cpu')


# 1a) get top k columns
top_k_scores, top_k_indices = knn_(question, column_embs, top_k=5, threshold=0.0)
topk_table_columns = {}
table_column_names = set()

for score, index in zip(top_k_scores, top_k_indices):
    # Take the structured record from TAB_DETAILS (index-aligned with the
    # embedded descriptions) instead of re-splitting the comma-joined string:
    # splitting breaks column names that contain commas, e.g. "Hiphop,_Rap"
    # would become name "Hiphop" with type "_Rap".
    table_name, col_name, col_dtype, col_desc = TAB_DETAILS[int(index)]
    column_tuple = [col_name, col_dtype, str(col_desc)]
    topk_table_columns.setdefault(table_name, []).append(column_tuple)
    table_column_names.add(f"{table_name}.{col_name}")
    # print("INCLUDED by embs: ", column_tuple)

topk_table_columns

{'young_people_survey': [['Fear_of_public_speaking', 'FLOAT', 'None'],
  ['Left_right_handed', 'TEXT', 'None'],
  ['Assertiveness', 'FLOAT', 'None'],
  ['Gender', 'TEXT', 'None'],
  ['Public_speaking', 'FLOAT', 'None']]}

In [23]:
# 1b) get columns which match terms in question
nlp = spacy.load("en_core_web_trf")
question_doc = nlp(question)
# Lower-cased lemmas of the question, without stop words ...
q_filtered_tokens = [token.lemma_.lower() for token in question_doc if not token.is_stop]
# ... keeping only purely alphabetic tokens longer than one character.
q_alpha_tokens = [i for i in q_filtered_tokens if (len(i)>1 and i.isalpha())]
q_token_set = set(q_alpha_tokens)


TIME_TERMS = ['when', 'time', 'hour', 'minute', 'second',
            'day', 'yesterday', 'today', 'tomorrow',
            'week', 'month', 'year',
            'duration', 'date']

nlp_ner = spacy.load("en_core_web_md")
q_ner_doc = nlp_ner(question)
ent_types = [w.label_ for w in q_ner_doc.ents]

# The question counts as time-related when NER finds a DATE/TIME entity, when
# any time term occurs as a substring of the question, or when a question
# lemma is itself a time term.
time_in_q = (
    'DATE' in ent_types
    or 'TIME' in ent_types
    or any(term in question.lower() for term in TIME_TERMS)
    or bool(q_token_set.intersection(TIME_TERMS))
)

# Iterate over the structured records rather than re-splitting the comma-joined
# description strings, which corrupts column names containing commas
# (e.g. "Hiphop,_Rap").
for table_name, col_name, col_dtype, col_desc in TAB_DETAILS:
    column_tuple = [col_name, col_dtype, str(col_desc)]

    # Already selected by the embedding step. `.get` avoids a KeyError when
    # that step selected nothing for this table.
    if column_tuple in topk_table_columns.get(table_name, []):
        # print("SKIPPING: ", column_tuple)
        continue

    # if question concerns time, add time-related columns
    if time_in_q and any(timetype in col_dtype for timetype in ['DATE', 'TIMESTAMP']):
        topk_table_columns.setdefault(table_name, []).append(column_tuple)
        table_column_names.add(f"{table_name}.{col_name}")
        continue

    # if question-token-lemmas overlap with column-token-lemmas, add the column
    column_doc = nlp(col_name.replace('_', ' '))
    col_tokens = [token.lemma_.lower() for token in column_doc if not token.is_stop]
    col_alpha_tokens = [i for i in col_tokens if (len(i)>1 and i.isalpha())]
    if q_token_set.intersection(col_alpha_tokens):
        topk_table_columns.setdefault(table_name, []).append(column_tuple)
        table_column_names.add(f"{table_name}.{col_name}")

  model.load_state_dict(torch.load(filelike, map_location=device))
  with torch.cuda.amp.autocast(self._mixed_precision):


In [24]:
# 4) format metadata string
# Render the selected columns as CREATE TABLE text, keeping their current order (shuffle=False).
pruned_schema = format_topk_sql(topk_table_columns, shuffle=False)
print(pruned_schema)


CREATE TABLE young_people_survey (
  Fear_of_public_speaking FLOAT, --None
  Left_right_handed TEXT, --None
  Assertiveness FLOAT, --None
  Gender TEXT, --None
  Public_speaking FLOAT, --None
  Knowing_the_right_people FLOAT, --None
);




## LLM Pruned Prompting - NL2SQL

In [25]:
# Same template text as the full-schema prompt cell; redefined here so this
# section can be read on its own.
prompt_template = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'

## Database Schema
This query will run on a database whose schema is represented in this string: {db_schema}

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION] [SQL]
"""

# This time the schema slot receives the pruned DDL instead of the full table DDL.
prompt2 = prompt_template.format(question=question, db_schema = pruned_schema)
print(prompt2)

### Task
Generate a SQL query to answer [QUESTION]Fetch the count of 'male' and 'female' who are 'right handed' and afraid of public speaking ?[/QUESTION] 

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'

## Database Schema 
This query will run on a database whose schema is represented in this string: 
CREATE TABLE young_people_survey (
  Fear_of_public_speaking FLOAT, --None
  Left_right_handed TEXT, --None
  Assertiveness FLOAT, --None
  Gender TEXT, --None
  Public_speaking FLOAT, --None
  Knowing_the_right_people FLOAT, --None
);

 

### Answer 
Given the database schema, here is the SQL query that answers [QUESTION]Fetch the count of 'male' and 'female' who are 'right handed' and afraid of public speaking ?[/QUESTION] [SQL]



In [26]:
# Generate SQL from the pruned-schema prompt and pretty-print it with sqlparse.
generated_sql2 = generate_query(prompt2)
print(sqlparse.format(generated_sql2, reindent=True))


SELECT yps.Gender,
       COUNT(*)
FROM young_people_survey yps
WHERE yps.Left_right_handed = 'right handed'
  AND yps.Fear_of_public_speaking > 3
GROUP BY yps.Gender;


## Adding Instruction in Pruned Prompting - NL2SQL

In [27]:
# Pruned-schema template with one extra instruction line that describes the
# 1-5 answer scale ("Not afraid at all" ... "Very afraid of").
prompt_template = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]

### Instructions
- Not afraid at all 1-2-3-4-5 Very afraid of (integer)
- If you cannot answer the question with the available database schema, return 'I do not know'

## Database Schema
This query will run on a database whose schema is represented in this string: {db_schema}

### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION] [SQL]
"""

# Same question and pruned schema as prompt2; only the instructions differ.
prompt3 = prompt_template.format(question=question, db_schema = pruned_schema)
print(prompt3)

### Task
Generate a SQL query to answer [QUESTION]Fetch the count of 'male' and 'female' who are 'right handed' and afraid of public speaking ?[/QUESTION] 

### Instructions
- Not afraid at all 1-2-3-4-5 Very afraid of (integer)
- If you cannot answer the question with the available database schema, return 'I do not know'

## Database Schema 
This query will run on a database whose schema is represented in this string: 
CREATE TABLE young_people_survey (
  Fear_of_public_speaking FLOAT, --None
  Left_right_handed TEXT, --None
  Assertiveness FLOAT, --None
  Gender TEXT, --None
  Public_speaking FLOAT, --None
  Knowing_the_right_people FLOAT, --None
);

 

### Answer 
Given the database schema, here is the SQL query that answers [QUESTION]Fetch the count of 'male' and 'female' who are 'right handed' and afraid of public speaking ?[/QUESTION] [SQL]



In [28]:
# Generate SQL from the prompt that carries the extra scale instruction and pretty-print it.
generated_sql3 = generate_query(prompt3)
print(sqlparse.format(generated_sql3, reindent=True))


SELECT yps.Gender,
       COUNT(*)
FROM young_people_survey yps
WHERE yps.Left_right_handed = 'right handed'
  AND yps.Fear_of_public_speaking > 1
GROUP BY yps.Gender;


In [29]:
# Run the SQL generated for prompt3 against the SQLite file and show the result.
# NOTE(review): the model output is executed as-is, with no validation step in
# this cell — acceptable for a local demo DB; confirm before reusing elsewhere.
engine = create_engine(f"sqlite:///mysqlitedb.db")
with engine.connect() as conn, conn.begin():
    query_result3 = pd.read_sql_query(generated_sql3, conn)

print(query_result3)

   Gender  COUNT(*)
0    None         4
1  female       466
2    male       279


In [30]:
# Cross-check the "female" count in pandas using the same filter as the generated SQL (> 1).
df[(df['Gender'] == 'female') & (df['Left_right_handed'] == 'right handed') & (df['Fear_of_public_speaking'] > 1)].shape

(466, 150)

<!-- ![](https://github.com/abhijeet3922/NL2SQL/blob/main/assets/implementation-NL2SQL.PNG?raw=1)   -->
<img src='https://github.com/abhijeet3922/NL2SQL/blob/main/assets/implementation-NL2SQL.PNG?raw=1'>  