In [None]:
!pip install spacy
!pip install sqlglot
!pip install -U pandas sqlalchemy
!pip install -U sentence-transformers
!pip install bitsandbytes>=0.39.0
!pip install -U transformers accelerate datasets
!python -m spacy download en_core_web_md
!python -m spacy download en_core_web_trf
!pip install Levenshtein

In [1]:
import os
import pandas as pd
import numpy as np

### Input data

We had previously downloaded the [Young People Survey Dataset](https://www.kaggle.com/datasets/miroslavsabo/young-people-survey) from Kaggle and loaded it into a `sqlite` DB. Let's take a look at the data once more.

In [2]:
df = pd.read_csv('data/kaggle-young-people-survey-dataset/responses.csv')
df.head()

Unnamed: 0,Music,Slow songs or fast songs,Dance,Folk,Country,Classical music,Musical,Pop,Rock,Metal or Hardrock,...,Age,Height,Weight,Number of siblings,Gender,Left - right handed,Education,Only child,Village - town,House - block of flats
0,5.0,3.0,2.0,1.0,2.0,2.0,1.0,5.0,5.0,1.0,...,20.0,163.0,48.0,1.0,female,right handed,college/bachelor degree,no,village,block of flats
1,4.0,4.0,2.0,1.0,1.0,1.0,2.0,3.0,5.0,4.0,...,19.0,163.0,58.0,2.0,female,right handed,college/bachelor degree,no,city,block of flats
2,5.0,5.0,2.0,2.0,3.0,4.0,5.0,3.0,5.0,3.0,...,20.0,176.0,67.0,2.0,female,right handed,secondary school,no,city,block of flats
3,5.0,3.0,2.0,1.0,1.0,1.0,1.0,2.0,2.0,1.0,...,22.0,172.0,59.0,1.0,female,right handed,college/bachelor degree,yes,city,house/bungalow
4,5.0,3.0,4.0,3.0,2.0,4.0,3.0,5.0,3.0,1.0,...,20.0,170.0,59.0,1.0,female,right handed,secondary school,no,village,house/bungalow


### Checking the SQLite DB

Let's connect to our `sqlite` file and look at its contents

In [3]:
from sqlalchemy import create_engine
engine = create_engine(f"sqlite:///mysqlitedb.db")
table_name = 'young-people-survey'

try:
    df = pd.read_csv('data/kaggle-young-people-survey-dataset/responses.csv')
    df.to_sql(table_name, engine, index=False)
    print('Data loaded from CSV file!')
except ValueError as e:
    err_msg = e.args[0]
    if 'already exists' in err_msg:
        print('Table already exists in SQLite DB')

Table already exists in SQLite DB


In [4]:
with engine.connect() as conn, conn.begin():
    sqlite_master = pd.read_sql_query("SELECT * FROM sqlite_master", conn)
sqlite_master['sql_fmt'] = sqlite_master['sql'].apply(lambda z: [x.strip().strip(',').rsplit(' ', maxsplit=1) for x in z.split('\n')[1:-1]])

table_desc_dict = {}
for _, row in sqlite_master.iterrows():
    table_desc_dict[row['name']] = row['sql']

schema = table_desc_dict[table_name]
print(schema)

CREATE TABLE "young-people-survey" (
	"Music" FLOAT, 
	"Slow songs or fast songs" FLOAT, 
	"Dance" FLOAT, 
	"Folk" FLOAT, 
	"Country" FLOAT, 
	"Classical music" FLOAT, 
	"Musical" FLOAT, 
	"Pop" FLOAT, 
	"Rock" FLOAT, 
	"Metal or Hardrock" FLOAT, 
	"Punk" FLOAT, 
	"Hiphop, Rap" FLOAT, 
	"Reggae, Ska" FLOAT, 
	"Swing, Jazz" FLOAT, 
	"Rock n roll" FLOAT, 
	"Alternative" FLOAT, 
	"Latino" FLOAT, 
	"Techno, Trance" FLOAT, 
	"Opera" FLOAT, 
	"Movies" FLOAT, 
	"Horror" FLOAT, 
	"Thriller" FLOAT, 
	"Comedy" FLOAT, 
	"Romantic" FLOAT, 
	"Sci-fi" FLOAT, 
	"War" FLOAT, 
	"Fantasy/Fairy tales" FLOAT, 
	"Animated" FLOAT, 
	"Documentary" FLOAT, 
	"Western" FLOAT, 
	"Action" FLOAT, 
	"History" FLOAT, 
	"Psychology" FLOAT, 
	"Politics" FLOAT, 
	"Mathematics" FLOAT, 
	"Physics" FLOAT, 
	"Internet" FLOAT, 
	"PC" FLOAT, 
	"Economy Management" FLOAT, 
	"Biology" FLOAT, 
	"Chemistry" FLOAT, 
	"Reading" FLOAT, 
	"Geography" FLOAT, 
	"Foreign languages" FLOAT, 
	"Medicine" FLOAT, 
	"Law" FLOAT, 
	"Cars" FLO

In [14]:
table_desc_df = {}
for _, row in sqlite_master.iterrows():
    table_desc_df[row['name']] = pd.DataFrame(columns=['name', 'type'], data=row['sql_fmt'])
    table_desc_df[row['name']]['comment'] = np.nan

table_desc_df

{'young-people-survey':                            name   type  comment
 0                       "Music"  FLOAT      NaN
 1    "Slow songs or fast songs"  FLOAT      NaN
 2                       "Dance"  FLOAT      NaN
 3                        "Folk"  FLOAT      NaN
 4                     "Country"  FLOAT      NaN
 ..                          ...    ...      ...
 145       "Left - right handed"   TEXT      NaN
 146                 "Education"   TEXT      NaN
 147                "Only child"   TEXT      NaN
 148            "Village - town"   TEXT      NaN
 149    "House - block of flats"   TEXT      NaN
 
 [150 rows x 3 columns]}

### DDL Pruning and Prompt Generation

As we saw in the previous chapter, we prune the table DDL so that the model doesn't get lost inside hundreds of column names and descriptions and only has to look at the most relevant ones to generate the SQL query.

In [5]:
import spacy
import torch
import sqlglot
import numpy as np
import os, re, logging, pickle
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

  from tqdm.autonotebook import tqdm, trange


In [6]:
def knn_(query: str, all_embs: torch.tensor, top_k: int, threshold: float) -> tuple[torch.tensor, torch.tensor]:
    
    """
    Get top most similar columns' embeddings to query using cosine similarity.
    """
    encoder = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", device='cpu')
    query_emb = encoder.encode(query, convert_to_tensor=True, device='cpu').unsqueeze(0)

    similarity_scores = F.cosine_similarity(query_emb, all_embs)
    top_results = torch.nonzero(similarity_scores > threshold).squeeze()

    # if top_results is empty, return empty tensors
    if top_results.numel() == 0:
        return torch.tensor([]), torch.tensor([])

    # if only 1 result is returned, we need to convert it to a tensor
    elif top_results.numel() == 1:
        return torch.tensor([similarity_scores[top_results]]), torch.tensor([top_results])
    else:
        top_k_scores, top_k_indices = torch.topk(similarity_scores[top_results], k=min(top_k, top_results.numel()))
        return top_k_scores, top_results[top_k_indices]
    
    
    
def format_topk_sql(topk_table_columns: dict[str, list[tuple[str, str, str]]], shuffle: bool) -> str:
    if len(topk_table_columns) == 0:
        return ""

    md_str = "\n"
    # shuffle the keys in topk_table_columns
    table_names = list(topk_table_columns.keys())
    if shuffle:
        np.random.seed(0)
        np.random.shuffle(table_names)
    for table_name in table_names:
        columns_str = ""
        columns = topk_table_columns[table_name]
        if shuffle:
            np.random.seed(0)
            np.random.shuffle(columns)
        for column_tuple in columns:
            if len(column_tuple) > 2:
                columns_str += (
                    f"\n  {column_tuple[0]} {column_tuple[1]}, --{column_tuple[2]}"
                )
            else:
                columns_str += f"\n  {column_tuple[0]} {column_tuple[1]}, "
        md_str += f"CREATE TABLE {table_name} ({columns_str}\n);\n"
    md_str += "\n"
    return md_str

In [None]:
question = 'Fetch the count of male and female in the data.'
EMBEDDING_PATH = 'embs'
TOP_K_LIMIT = 25 # number of columns to include in the prompt
PRUNE_LIMIT = 5 # minimum number of columns above which a given DDL will be pruned
num_cols = 0
TAB_DETAILS = []


for col in sqlglot.parse_one(schema, dialect='snowflake').find_all(sqlglot.exp.ColumnDef):
    num_cols += 1
    TAB_DETAILS.append([table_name, col.alias_or_name, col.find(sqlglot.exp.DataType).__str__(), col.find(sqlglot.exp.ColumnConstraint)])
        
# print(TAB_DETAILS)


encoder = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1", device='cpu')

column_descriptions = []
column_descriptions_typed = []

for row in TAB_DETAILS:
    tab_name, col_name, col_dtype, col_desc = row

    col_str = f"{tab_name}.{col_name}:{col_desc}"
    col_str_typed = f"{tab_name}.{col_name},{col_dtype},{col_desc}"

    column_descriptions.append(col_str)
    column_descriptions_typed.append(col_str_typed)
column_embs = encoder.encode(column_descriptions, convert_to_tensor=True, device='cpu')


# 1a) get top k columns
top_k_scores, top_k_indices = knn_(question, column_embs, top_k=5, threshold=0.0)
topk_table_columns = {}
table_column_names = set()

for score, index in zip(top_k_scores, top_k_indices):
    table_name, column_info = column_descriptions_typed[index].split(".", 1)
    column_tuple = re.split(r',\s*(?![^()]*\))', column_info, maxsplit=2) #split only on commas outside parantheses
    if table_name not in topk_table_columns:
        topk_table_columns[table_name] = []
    topk_table_columns[table_name].append(column_tuple)
    table_column_names.add(f"{table_name}.{column_tuple[0]}")
    # print("INCLUDED by embs: ", column_tuple)

# 1b) get columns which match terms in question
nlp = spacy.load("en_core_web_trf")
question_doc = nlp(question)
q_filtered_tokens = [token.lemma_.lower() for token in question_doc if not token.is_stop]
q_alpha_tokens = [i for i in q_filtered_tokens if (len(i)>1 and i.isalpha())]


TIME_TERMS = ['when', 'time', 'hour', 'minute', 'second', 
            'day', 'yesterday', 'today', 'tomorrow', 
            'week', 'month', 'year', 
            'duration', 'date']

time_in_q = False

nlp_ner = spacy.load("en_core_web_md")
q_ner_doc = nlp_ner(question)
ent_types = [w.label_ for w in q_ner_doc.ents]

if 'DATE' in ent_types or 'TIME' in ent_types:
    time_in_q = True
elif any([term in question.lower() for term in TIME_TERMS]):
    time_in_q = True
elif set(q_alpha_tokens).intersection(set(TIME_TERMS)):
    time_in_q = True

for col_details in column_descriptions_typed:
    table_name, column_info = col_details.split(".", 1)
    column_tuple = re.split(r',\s*(?![^()]*\))', column_info, maxsplit=2) #split only on commas outside parantheses
    col_name = column_tuple[0]

    if column_tuple in topk_table_columns[table_name]:
        # print("SKIPPING: ", column_tuple)
        continue

    # if question concerns time, add time-related columns
    if time_in_q and any([timetype in column_tuple[1] for timetype in ['DATE', 'TIMESTAMP']]):
        if table_name not in topk_table_columns:
            topk_table_columns[table_name] = []
        if column_tuple not in topk_table_columns[table_name]:
            topk_table_columns[table_name].append(column_tuple)
        table_column_names.add(f"{table_name}.{column_tuple[0]}")
        continue

    # if question-token-lemmas overlap with column-token-lemmas, add the column
    column_doc = nlp(col_name.replace('_', ' '))
    col_tokens = [token.lemma_.lower() for token in column_doc if not token.is_stop]
    col_alpha_tokens = [i for i in col_tokens if (len(i)>1 and i.isalpha())]
    if set(col_alpha_tokens).intersection(set(q_alpha_tokens)):
        if table_name not in topk_table_columns:
            topk_table_columns[table_name] = []
        if column_tuple not in topk_table_columns[table_name]:
            topk_table_columns[table_name].append(column_tuple)
        table_column_names.add(f"{table_name}.{column_tuple[0]}")

# 4) format metadata string
pruned_schema = format_topk_sql(topk_table_columns, shuffle=False)
print(pruned_schema)

In [8]:
pruned_schema = '''CREATE TABLE young-people-survey (
  Gender TEXT, --None
  Number of siblings FLOAT, --None
  Height FLOAT, --None
  Age FLOAT, --None
  Alternative FLOAT, --None
);'''

In [9]:
prompt_template = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION] 

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'

## Database Schema 
This query will run on a database whose schema is represented in this string: {db_schema} 

### Answer 
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION] [SQL]
"""

prompt = prompt_template.format(question=question, db_schema = pruned_schema)
print(prompt)

### Task
Generate a SQL query to answer [QUESTION]Fetch the count of male and female in the data.[/QUESTION] 

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'

## Database Schema 
This query will run on a database whose schema is represented in this string: CREATE TABLE young-people-survey (
  Gender TEXT, --None
  Number of siblings FLOAT, --None
  Height FLOAT, --None
  Age FLOAT, --None
  Alternative FLOAT, --None
); 

### Answer 
Given the database schema, here is the SQL query that answers [QUESTION]Fetch the count of male and female in the data.[/QUESTION] [SQL]



### Prompt the LLM 

We will use Ollama (running locally or over the network) to prompt our LLM of choice (in this case, we are using `sqlglot-7b-2`)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# else, load in 8 bits – this is a bit slower
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    # torch_dtype=torch.float16,
    load_in_8bit=True,
    device_map="auto",
    use_cache=True,
)

In [None]:
import sqlparse

def generate_query(quest):
    updated_prompt = prompt.format(question=quest)
    print(updated_prompt)
    inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    # empty cache so that you do generate more results w/o memory crashing
    # particularly important on Colab – memory management is much more straightforward
    # when running on an inference service
    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)

In [None]:
generated_sql = generate_query(question)

In [None]:
print(generated_sql)

### Fixing LLM hallucinations


In [11]:
from sqlglot import parse_one, exp, parse, table, column, to_identifier
from Levenshtein import distance
from sentence_transformers import SentenceTransformer
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import time

class queryPostprocessing:
    
    def __init__(self, query, table_metadata, embedding_model_name):
        """
            table metadata should be in upper case
        """
        self.query = query.upper().split(';')[0]
        self.table_metadata = table_metadata
        self.col_mapping = {}
        self.table_mapping = {}
        self.embedding_model = SentenceTransformer(embedding_model_name, device='cpu')

    def getIncorrectColumns(self):
        column_names_query = [col.name for col in parse_one(self.query).find_all(exp.Column)]
        column_names_query = np.unique(column_names_query).tolist()
        column_names_table = self.table_metadata['columns']

        matching_cols = list(set(column_names_query).intersection(set(column_names_table)))
        invalid_cols = list(set(column_names_query).difference(set(matching_cols)))
        available_set = list(set(column_names_table).difference(set(matching_cols)))

        # check if enough columns are available to substitute (3 columns in query, but only 2 in table

        for col in matching_cols:
            self.col_mapping[col] = col

        for col in invalid_cols:
            closest_col = self.embedding_distance(col, available_set)
            self.col_mapping[col] = closest_col
            available_set.remove(closest_col)

        return self.col_mapping
    
    def getIncorrectTables(self):
        ast = parse_one(self.query)
        table_list = []
        for tbl in ast.find_all(exp.Table):
            tbl.set('alias',None)
            table_name = tbl.sql()
            table_list.append(table_name)

        unique_tables = np.unique(table_list).tolist()
        if len(unique_tables):
            table_name = unique_tables[0]
            self.table_mapping[table_name] = self.table_metadata['table_name']

        return self.table_mapping

    def lv_distance(self, col, available_set):
        
        distances = []
        for column in available_set:
            d = distance(col, column)
            distances.append(d)

        most_similar_column = available_set[np.argmin(distances)]
        return most_similar_column

    def embedding_distance(self, col, available_set):
        col_embedding = self.embedding_model.encode([col])
        available_set_embedding = self.embedding_model.encode(available_set)

        similarity_list = cosine_similarity(col_embedding, available_set_embedding)
        most_similar_column = available_set[np.argmax(similarity_list)]

        return most_similar_column
 
    def formatQuery(self):

        _ = self.getIncorrectColumns()
        _ = self.getIncorrectTables()

        updated_query = self.query

        for tbl, updated_tbl in self.table_mapping.items():
            updated_query = updated_query.replace(tbl, updated_tbl)

        for col, updated_col in self.col_mapping.items():
            updated_query = updated_query.replace(col, updated_col)
        
        return updated_query
    
    def formatQuerySQLglot(self):

        _ = self.getIncorrectColumns()
        _ = self.getIncorrectTables()

        query_ast = parse_one(self.query)
        alias_cols = [al.alias for al in query_ast.find_all(exp.Alias)]

        for col in alias_cols:
            self.col_mapping.pop(col, None)

        for tbl in query_ast.find_all(exp.Table):
            table_alias = None
            if 'alias' in tbl.args:
                table_alias = tbl.alias
            tbl.set('alias',None)
            table_name = tbl.sql()
            if table_name in self.table_mapping:
                new_table = table(table=self.table_mapping[table_name], quoted=False, alias = table_alias)
                tbl.replace(new_table)

        for col in query_ast.find_all(exp.Column):
            column_name = col.this.this
            if column_name in self.col_mapping:
                col.this.set('this',self.col_mapping[column_name]) 
        
        return query_ast.sql(dialect='snowflake')

In [12]:
generated_sql = 'SELECT * from table'

In [15]:
print('Running post-processing...')
try:
    qp = queryPostprocessing(generated_sql.upper(), 
                             {'table_name':table_name, 'columns':table_desc_df[table_name]['name'].tolist()}, 
                             embedding_model_name='mixedbread-ai/mxbai-embed-large-v1')
    processed_query = qp.formatQuerySQLglot()
except Exception as e:
    print(e)
    processed_query = generated_sql

processed_query = processed_query.replace('""', '"')
print(f"Done! Processed query: {processed_query}")

Running post-processing...




Done! Processed query: SELECT * FROM young-people-survey


In [18]:
processed_query = 'SELECT * FROM "young-people-survey"'

In [19]:
engine = create_engine(f"sqlite:///mysqlitedb.db")
with engine.connect() as conn, conn.begin():
    query_result = pd.read_sql_query(processed_query, conn)

query_result

Unnamed: 0,Music,Slow songs or fast songs,Dance,Folk,Country,Classical music,Musical,Pop,Rock,Metal or Hardrock,...,Age,Height,Weight,Number of siblings,Gender,Left - right handed,Education,Only child,Village - town,House - block of flats
0,5.0,3.0,2.0,1.0,2.0,2.0,1.0,5.0,5.0,1.0,...,20.0,163.0,48.0,1.0,female,right handed,college/bachelor degree,no,village,block of flats
1,4.0,4.0,2.0,1.0,1.0,1.0,2.0,3.0,5.0,4.0,...,19.0,163.0,58.0,2.0,female,right handed,college/bachelor degree,no,city,block of flats
2,5.0,5.0,2.0,2.0,3.0,4.0,5.0,3.0,5.0,3.0,...,20.0,176.0,67.0,2.0,female,right handed,secondary school,no,city,block of flats
3,5.0,3.0,2.0,1.0,1.0,1.0,1.0,2.0,2.0,1.0,...,22.0,172.0,59.0,1.0,female,right handed,college/bachelor degree,yes,city,house/bungalow
4,5.0,3.0,4.0,3.0,2.0,4.0,3.0,5.0,3.0,1.0,...,20.0,170.0,59.0,1.0,female,right handed,secondary school,no,village,house/bungalow
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1005,5.0,2.0,5.0,2.0,2.0,5.0,4.0,4.0,4.0,3.0,...,20.0,164.0,57.0,1.0,female,right handed,secondary school,no,city,house/bungalow
1006,4.0,4.0,5.0,1.0,3.0,4.0,1.0,4.0,1.0,1.0,...,27.0,183.0,80.0,5.0,male,left handed,masters degree,no,village,house/bungalow
1007,4.0,3.0,1.0,1.0,2.0,2.0,2.0,3.0,4.0,1.0,...,18.0,173.0,75.0,0.0,female,right handed,secondary school,yes,city,block of flats
1008,5.0,3.0,3.0,3.0,1.0,3.0,1.0,3.0,4.0,1.0,...,25.0,173.0,58.0,1.0,female,right handed,college/bachelor degree,no,city,block of flats
