In [1]:
import pandas as pd
import duckdb
import time 
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import accelerate
import bitsandbytes

'NoneType' object has no attribute 'cadam32bit_grad_fp32'


  warn("The installed version of bitsandbytes was compiled without GPU support. "


In [2]:
token = os.getenv('HUGGING_FACE_TOKEN')

In [3]:
path = "./data_raw"

files = [x for x in os.listdir(path = path) if ".csv" in x]

print(files)

['chicago_crime_2022.csv', 'chicago_crime_2023.csv', 'chicago_crime_2021.csv']


In [4]:
chicago_crime = pd.concat((pd.read_csv(path +"/" + f) for f in files), ignore_index=True)

chicago_crime

Unnamed: 0,ID,Case Number,Date,Block,IUCR,Primary Type,Description,Location Description,Arrest,Domestic,...,Ward,Community Area,FBI Code,X Coordinate,Y Coordinate,Year,Updated On,Latitude,Longitude,Location
0,12589893,JF109865,01/11/2022 03:00:00 PM,087XX S KINGSTON AVE,1565,SEX OFFENSE,INDECENT SOLICITATION OF A CHILD,RESIDENCE,False,True,...,7.0,46,17,1194660.0,1847481.0,2022,09/14/2023 03:41:59 PM,41.736409,-87.562410,"(41.736409029, -87.562410309)"
1,12592454,JF113025,01/14/2022 03:55:00 PM,067XX S MORGAN ST,2826,OTHER OFFENSE,HARASSMENT BY ELECTRONIC MEANS,RESIDENCE,False,True,...,16.0,68,26,1170805.0,1860170.0,2022,09/14/2023 03:41:59 PM,41.771782,-87.649437,"(41.771782439, -87.649436929)"
2,12601676,JF124024,01/13/2022 04:00:00 PM,031XX W AUGUSTA BLVD,1752,OFFENSE INVOLVING CHILDREN,AGGRAVATED CRIMINAL SEXUAL ABUSE BY FAMILY MEMBER,RESIDENCE,False,True,...,36.0,23,17,1155171.0,1906486.0,2022,09/14/2023 03:41:59 PM,41.899206,-87.705506,"(41.899206068, -87.705505587)"
3,12785595,JF346553,08/05/2022 09:00:00 PM,072XX S UNIVERSITY AVE,1544,SEX OFFENSE,SEXUAL EXPLOITATION OF A CHILD,APARTMENT,True,False,...,5.0,69,17,1185135.0,1857211.0,2022,09/14/2023 03:41:59 PM,41.763338,-87.597001,"(41.763337967, -87.597001131)"
4,12808281,JF373517,08/14/2022 02:00:00 PM,055XX W ARDMORE AVE,1562,SEX OFFENSE,AGGRAVATED CRIMINAL SEXUAL ABUSE,RESIDENCE,False,False,...,39.0,11,17,1138383.0,1937953.0,2022,09/14/2023 03:41:59 PM,41.985875,-87.766404,"(41.985875279, -87.766403857)"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
648826,26461,JE455267,11/24/2021 12:51:00 AM,107XX S LANGLEY AVE,0110,HOMICIDE,FIRST DEGREE MURDER,VACANT LOT,False,False,...,9.0,50,01A,1182822.0,1833730.0,2021,09/19/2022 03:41:05 PM,41.698957,-87.606206,"(41.698957409, -87.606205674)"
648827,26041,JE281927,06/28/2021 01:12:00 AM,117XX S LAFLIN ST,0110,HOMICIDE,FIRST DEGREE MURDER,AUTO,False,False,...,34.0,53,01A,1168442.0,1826982.0,2021,09/01/2022 03:42:17 PM,41.680761,-87.659052,"(41.680760863, -87.659051873)"
648828,26238,JE353715,08/29/2021 03:07:00 AM,010XX N LAWNDALE AVE,0110,HOMICIDE,FIRST DEGREE MURDER,STREET,False,False,...,27.0,23,01A,1151525.0,1906643.0,2021,09/19/2022 03:41:05 PM,41.899709,-87.718893,"(41.899709327, -87.718893208)"
648829,26479,JE465230,12/03/2021 08:37:00 PM,000XX W 78TH PL,0110,HOMICIDE,FIRST DEGREE MURDER,PORCH,True,False,...,6.0,69,01A,1177156.0,1852951.0,2021,09/01/2022 03:42:17 PM,41.751832,-87.626374,"(41.751831742, -87.626373808)"


In [5]:
def create_message(table_name, query):

    class message:
        def __init__(message, system, user, column_names, column_attr):
            message.system = system
            message.user = user
            message.column_names = column_names
            message.column_attr = column_attr

    
    system_template = """

    Given the following SQL table, your job is to write queries given a user’s request. \n

    CREATE TABLE {} ({}) \n
    """

    user_template = "Write a SQL query that returns - {}"
    
    tbl_describe = duckdb.sql("DESCRIBE SELECT * FROM " + table_name +  ";")
    col_attr = tbl_describe.df()[["column_name", "column_type"]]
    col_attr["column_joint"] = col_attr["column_name"] + " " +  col_attr["column_type"]
    col_names = str(list(col_attr["column_joint"].values)).replace('[', '').replace(']', '').replace('\'', '')

    system = system_template.format(table_name, col_names)
    user = user_template.format(query)

    m = message(system = system, user = user, column_names = col_attr["column_name"], column_attr = col_attr["column_type"])
    return m

def add_quotes(query, col_names):
    for i in col_names:
        if i in query:
            query = str(query).replace(i, '"' + i + '"') 
    return(query)


In [6]:
query = "How many cases ended up with arrest?"
msg = create_message(table_name = "chicago_crime", query = query)

print(msg.system)
print(msg.user)
print(msg.column_names)
print(msg.column_attr)



    Given the following SQL table, your job is to write queries given a user’s request. 


    CREATE TABLE chicago_crime (ID BIGINT, Case Number VARCHAR, Date VARCHAR, Block VARCHAR, IUCR VARCHAR, Primary Type VARCHAR, Description VARCHAR, Location Description VARCHAR, Arrest BOOLEAN, Domestic BOOLEAN, Beat BIGINT, District BIGINT, Ward DOUBLE, Community Area BIGINT, FBI Code VARCHAR, X Coordinate DOUBLE, Y Coordinate DOUBLE, Year BIGINT, Updated On VARCHAR, Latitude DOUBLE, Longitude DOUBLE, Location VARCHAR) 

    
Write a SQL query that returns - How many cases ended up with arrest?
0                       ID
1              Case Number
2                     Date
3                    Block
4                     IUCR
5             Primary Type
6              Description
7     Location Description
8                   Arrest
9                 Domestic
10                    Beat
11                District
12                    Ward
13          Community Area
14                FBI Code

In [7]:
m = create_message(table_name = "chicago_crime", query = query)

messages = [
    {
      "role": "system",
      "content": m.system
    },
    {
      "role": "user",
      "content": m.user
    }
    ]
print(messages)

[{'role': 'system', 'content': '\n\n    Given the following SQL table, your job is to write queries given a user’s request. \n\n\n    CREATE TABLE chicago_crime (ID BIGINT, Case Number VARCHAR, Date VARCHAR, Block VARCHAR, IUCR VARCHAR, Primary Type VARCHAR, Description VARCHAR, Location Description VARCHAR, Arrest BOOLEAN, Domestic BOOLEAN, Beat BIGINT, District BIGINT, Ward DOUBLE, Community Area BIGINT, FBI Code VARCHAR, X Coordinate DOUBLE, Y Coordinate DOUBLE, Year BIGINT, Updated On VARCHAR, Latitude DOUBLE, Longitude DOUBLE, Location VARCHAR) \n\n    '}, {'role': 'user', 'content': 'Write a SQL query that returns - How many cases ended up with arrest?'}]


In [8]:
# Source: https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha
pipe = pipeline("text-generation", 
                model="HuggingFaceH4/zephyr-7b-alpha", 
                torch_dtype=torch.bfloat16, 
                device_map="auto")

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [10]:
prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
outputs = pipe(prompt, 
               max_new_tokens=256, 
               do_sample=True, 
               temperature=0.1, 
               top_k=1, 
               top_p=0.95)


In [12]:
print(outputs[0]["generated_text"])

<|system|>


    Given the following SQL table, your job is to write queries given a user’s request. 


    CREATE TABLE chicago_crime (ID BIGINT, Case Number VARCHAR, Date VARCHAR, Block VARCHAR, IUCR VARCHAR, Primary Type VARCHAR, Description VARCHAR, Location Description VARCHAR, Arrest BOOLEAN, Domestic BOOLEAN, Beat BIGINT, District BIGINT, Ward DOUBLE, Community Area BIGINT, FBI Code VARCHAR, X Coordinate DOUBLE, Y Coordinate DOUBLE, Year BIGINT, Updated On VARCHAR, Latitude DOUBLE, Longitude DOUBLE, Location VARCHAR) 

    </s>
<|user|>
Write a SQL query that returns - How many cases ended up with arrest?</s>
<|assistant|>
SELECT COUNT(*)
FROM chicago_crime
WHERE Arrest = TRUE;

This query uses the COUNT() function to count the number of rows that meet the condition Arrest = TRUE. The result will be the total number of cases that ended up with an arrest.


In [13]:
print(outputs)

[{'generated_text': '<|system|>\n\n\n    Given the following SQL table, your job is to write queries given a user’s request. \n\n\n    CREATE TABLE chicago_crime (ID BIGINT, Case Number VARCHAR, Date VARCHAR, Block VARCHAR, IUCR VARCHAR, Primary Type VARCHAR, Description VARCHAR, Location Description VARCHAR, Arrest BOOLEAN, Domestic BOOLEAN, Beat BIGINT, District BIGINT, Ward DOUBLE, Community Area BIGINT, FBI Code VARCHAR, X Coordinate DOUBLE, Y Coordinate DOUBLE, Year BIGINT, Updated On VARCHAR, Latitude DOUBLE, Longitude DOUBLE, Location VARCHAR) \n\n    </s>\n<|user|>\nWrite a SQL query that returns - How many cases ended up with arrest?</s>\n<|assistant|>\nSELECT COUNT(*)\nFROM chicago_crime\nWHERE Arrest = TRUE;\n\nThis query uses the COUNT() function to count the number of rows that meet the condition Arrest = TRUE. The result will be the total number of cases that ended up with an arrest.'}]


In [8]:
# pipe = pipeline("text-generation", 
#                 model="HuggingFaceH4/zephyr-7b-alpha", 
#                 torch_dtype=torch.bfloat16, 
#                 device_map="auto")


pipe = pipeline("text-generation", model="defog/sqlcoder-7b")

Downloading (…)lve/main/config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00002.bin:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00002.bin:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

: 

In [None]:
outputs = pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.1, top_k=50, top_p=0.95)

In [None]:
# https://colab.research.google.com/drive/13BIKsqHnPOBcQ-ba2p77L5saiepTIwu0#scrollTo=8bAMjQKJfG3d

model_name = "defog/sqlcoder"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    load_in_4bit=True,
    device_map="auto",
    use_cache=True,
)