In [1]:
!pip install spacy
!pip install sqlglot
!pip install sqlparse
!pip install accelerate
!pip install pandas sqlalchemy
!pip install -U sentence-transformers
!pip install -i https://pypi.org/simple/ bitsandbytes


%env http_proxy=http://http.proxy.aws.fmrcloud.com:8000/
%env https_proxy=http://http.proxy.aws.fmrcloud.com:8000/
!python -m spacy download en_core_web_md

/bin/bash: /home/dsp/miniconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)
Looking in indexes: https://:****@artifactory.fmr.com/api/pypi/pypi-prereleases/simple
/bin/bash: /home/dsp/miniconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)
Looking in indexes: https://:****@artifactory.fmr.com/api/pypi/pypi-prereleases/simple
/bin/bash: /home/dsp/miniconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)
Looking in indexes: https://:****@artifactory.fmr.com/api/pypi/pypi-prereleases/simple
/bin/bash: /home/dsp/miniconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)
Looking in indexes: https://:****@artifactory.fmr.com/api/pypi/pypi-prereleases/simple
/bin/bash: /home/dsp/miniconda3/lib/libtinfo.so.6: no version information available (required by /bin/bash)
Looking in indexes: https://:****@artifactory.fmr.com/api/pypi/pypi-prereleases/simple
/bin/bash: /home/dsp/mini

In [2]:
import os
import numpy as np
import pandas as pd

## Chat with Tabular Data

This notebook contains code to demonstrate NL2SQL capabilities of modern LLMs and only uses open-source data and libraries

### Download input data

We will be using the Churn Modelling dataset from Kaggle. Please note that a Kaggle account is needed to download the dataset.

This data set contains details of a bank's customers and the target variable is a binary variable reflecting the fact whether the customer left the bank (closed his account) or he continues to be a customer.

In [3]:
input_filepath = 'Churn_Modelling.csv'
table_name = input_filepath.rsplit('.', maxsplit=1)[0]

df = pd.read_csv(input_filepath)
df

Unnamed: 0,RowNumber,CustomerId,Surname,CreditScore,Geography,Gender,Age,Tenure,Balance,NumOfProducts,HasCrCard,IsActiveMember,EstimatedSalary,Exited
0,1,15634602,Hargrave,619,France,Female,42,2,0.00,1,1,1,101348.88,1
1,2,15647311,Hill,608,Spain,Female,41,1,83807.86,1,0,1,112542.58,0
2,3,15619304,Onio,502,France,Female,42,8,159660.80,3,1,0,113931.57,1
3,4,15701354,Boni,699,France,Female,39,1,0.00,2,0,0,93826.63,0
4,5,15737888,Mitchell,850,Spain,Female,43,2,125510.82,1,1,1,79084.10,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9996,15606229,Obijiaku,771,France,Male,39,5,0.00,2,1,0,96270.64,0
9996,9997,15569892,Johnstone,516,France,Male,35,10,57369.61,1,1,1,101699.77,0
9997,9998,15584532,Liu,709,France,Female,36,7,0.00,1,0,1,42085.58,1
9998,9999,15682355,Sabbatini,772,Germany,Male,42,3,75075.31,2,1,0,92888.52,1


### Load CSV data into a SQLite DB

Since our LLM will generate SQL code, we must load our CSV data into a SQL compatible database.  
We also use the DB engine to create a DDL or description of the table which will be used as part of the LLM prompt.

In [4]:
from sqlalchemy import create_engine

engine = create_engine(f"sqlite:///mysqlitedb.db")

try:
    df.to_sql(table_name, engine, index=False)
except Exception as e:
    print(f'Error: {e}')
    print('')

with engine.connect() as conn, conn.begin():
    sqlite_master = pd.read_sql_query("SELECT * FROM sqlite_master", conn)
sqlite_master['sql_fmt'] = sqlite_master['sql'].apply(lambda z: [x.strip().strip(',').rsplit(' ', maxsplit=1) for x in z.split('\n')[1:-1]])

table_desc_dict = {}
for _, row in sqlite_master.iterrows():
    table_desc_dict[row['name']] = row['sql']

schema = table_desc_dict[table_name]
print(schema)

Error: Table 'Churn_Modelling' already exists.

CREATE TABLE "Churn_Modelling" (
	"RowNumber" BIGINT, 
	"CustomerId" BIGINT, 
	"Surname" TEXT, 
	"CreditScore" BIGINT, 
	"Geography" TEXT, 
	"Gender" TEXT, 
	"Age" BIGINT, 
	"Tenure" BIGINT, 
	"Balance" FLOAT, 
	"NumOfProducts" BIGINT, 
	"HasCrCard" BIGINT, 
	"IsActiveMember" BIGINT, 
	"EstimatedSalary" FLOAT, 
	"Exited" BIGINT
)


In [5]:
table_desc_df = {}
for _, row in sqlite_master.iterrows():
    table_desc_df[row['name']] = pd.DataFrame(columns=['name', 'type'], data=row['sql_fmt'])
    table_desc_df[row['name']]['comment'] = np.nan

table_desc_df['Churn_Modelling']

Unnamed: 0,name,type,comment
0,"""RowNumber""",BIGINT,
1,"""CustomerId""",BIGINT,
2,"""Surname""",TEXT,
3,"""CreditScore""",BIGINT,
4,"""Geography""",TEXT,
5,"""Gender""",TEXT,
6,"""Age""",BIGINT,
7,"""Tenure""",BIGINT,
8,"""Balance""",FLOAT,
9,"""NumOfProducts""",BIGINT,


### Prompt the LLM 

We will use Ollama (running locally or over the network) to prompt our LLM of choice (in this case, we are using `sqlglot-7b-2`)

In [8]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [9]:
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# else, load in 8 bits – this is a bit slower
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    # torch_dtype=torch.float16,
    load_in_8bit=True,
    device_map="auto",
    use_cache=True,
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [10]:
import sqlparse

def generate_query(quest):
    updated_prompt = prompt.format(question=quest)
    print(updated_prompt)
    inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda")
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        do_sample=False,
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    # empty cache so that you do generate more results w/o memory crashing
    # particularly important on Colab – memory management is much more straightforward
    # when running on an inference service
    return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)

In [11]:
question = 'Fetch the count of male and female in the data.'

prompt_template = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION] 

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'

## Database Schema 
This query will run on a database whose schema is represented in this string: {db_schema} 

### Answer 
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION] [SQL]
"""

prompt = prompt_template.format(question=question, db_schema = schema)
print(prompt)

### Task
Generate a SQL query to answer [QUESTION]Fetch the count of male and female in the data.[/QUESTION] 

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'

## Database Schema 
This query will run on a database whose schema is represented in this string: CREATE TABLE "Churn_Modelling" (
	"RowNumber" BIGINT, 
	"CustomerId" BIGINT, 
	"Surname" TEXT, 
	"CreditScore" BIGINT, 
	"Geography" TEXT, 
	"Gender" TEXT, 
	"Age" BIGINT, 
	"Tenure" BIGINT, 
	"Balance" FLOAT, 
	"NumOfProducts" BIGINT, 
	"HasCrCard" BIGINT, 
	"IsActiveMember" BIGINT, 
	"EstimatedSalary" FLOAT, 
	"Exited" BIGINT
) 

### Answer 
Given the database schema, here is the SQL query that answers [QUESTION]Fetch the count of male and female in the data.[/QUESTION] [SQL]



In [12]:
generated_sql = generate_query(question)

### Task
Generate a SQL query to answer [QUESTION]Fetch the count of male and female in the data.[/QUESTION] 

### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'

## Database Schema 
This query will run on a database whose schema is represented in this string: CREATE TABLE "Churn_Modelling" (
	"RowNumber" BIGINT, 
	"CustomerId" BIGINT, 
	"Surname" TEXT, 
	"CreditScore" BIGINT, 
	"Geography" TEXT, 
	"Gender" TEXT, 
	"Age" BIGINT, 
	"Tenure" BIGINT, 
	"Balance" FLOAT, 
	"NumOfProducts" BIGINT, 
	"HasCrCard" BIGINT, 
	"IsActiveMember" BIGINT, 
	"EstimatedSalary" FLOAT, 
	"Exited" BIGINT
) 

### Answer 
Given the database schema, here is the SQL query that answers [QUESTION]Fetch the count of male and female in the data.[/QUESTION] [SQL]



In [13]:
print(generated_sql)


SELECT c.gender,
       COUNT(c.gender) AS gender_count
FROM "Churn_Modelling" c
GROUP BY c.gender;


In [14]:
engine = create_engine(f"sqlite:///mysqlitedb.db")
with engine.connect() as conn, conn.begin():
    query_result = pd.read_sql_query(generated_sql, conn)

query_result

Unnamed: 0,Gender,gender_count
0,Female,4543
1,Male,5457
