In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM,BitsAndBytesConfig
#FINE-TUNING PRE-TRAINED LLAMA ON TPCDSA DATA

import torch
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training
)

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)

In [2]:
model_name = "codellama/CodeLlama-34b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)

quantization_config = BitsAndBytesConfig(
   load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)


base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    quantization_config=quantization_config,
    device_map='auto',
    use_cache=True
    )

    

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [3]:
eos_token_id = tokenizer.eos_token_id

The schema-induced prompt templates are provided in the `prompt_folder`. The below script automatically picks the prompts and generates the requires schema specific to each query and then executes on the SQLCoder-34B. 

In [8]:
#generating responses for all test dataset
import os
import time
import sqlparse
query_time=""
prompt_folder="../../data/sqlcoder/question_and_schema"

for queryIdx in range(1):
    start=time.time()
    with open(f"{prompt_folder}/combine{queryIdx+1}.txt","r") as file:
        codellama_prompt= file.read()
        print(codellama_prompt)
    print(f"started infernece for {queryIdx+1}")
    inputs = tokenizer(codellama_prompt, return_tensors="pt").to("cuda")
    generated_ids = base_model.generate(**inputs,num_return_sequences=1,eos_token_id=eos_token_id,pad_token_id=eos_token_id,max_new_tokens=1500,do_sample=False,num_beams=1)
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    query_time+= str(time.time() - start)  +"\n"
    print(f"inference finished for query{queryIdx+1}")
    with open(f"../../data/sqlcoder/queries/query_{queryIdx+1}.txt","w") as response_file:
        content=sqlparse.format(outputs[0].split("```sql")[-1], reindent=True)
        response_file.write(content)
        response_file.close()
        print("completed writing response to a file")

with open(f"../../data/sqlcoder/querytime/query-time.txt","w") as time_file:
            time_file.write(query_time)
            time_file.close()


Retrieve the customer IDs of customers who have returned items significantly more frequently (over 20% above the average return rate) than the average customer return rate for a store in the state of South Dakota in the year 2000. This query calculates the total returns for each customer at each store and compares it to the store's average returns for the specified year. It then retrieves the customer IDs of those whose returns exceed the calculated threshold. The results are sorted by customer ID, and a limit of 100 records is applied

CREATE TABLE store_returns (  sr_returned_date_sk,  sr_return_time_sk,  sr_item_sk,  sr_customer_sk,  sr_cdemo_sk,  sr_hdemo_sk,  sr_addr_sk,  sr_store_sk,  sr_reason_sk,  sr_ticket_number,  sr_return_quantity,  sr_return_amt,  sr_return_tax,  sr_return_amt_inc_tax,  sr_fee,  sr_return_ship_cost,  sr_refunded_cash,  sr_reversed_charge,  sr_store_credit,  sr_net_loss );

CREATE TABLE date_dim (  d_date_sk,  d_date_id,  d_date,  d_month_seq,  d_week_seq, 

# NOTE

As discussed in the article, automatic execution of the SQL queries directly with duckDB cannot work in case of SQLCoder-34B because of the above boilerplate text. One requires to manually edit the file stored in the `data/sqlcoder/queries/query_{queryIdx+1}.txt` (Yep, we know it is a headache but it is what baseline model generates as of now). This requires constant human intervention. 

To save time and resources, we have preloaded the sqlcoder queries folder with only the successfully executeable queires.

In [12]:
import duckdb
import os
import time

# Define the connection parameters
query_path='../../data/sqlcoder/queries/'
query_result='../../data/sqlcoder/results/'
query_time_file='../../data/sqlcoder/results/querytime/log.txt'

# Connect to the DB 
db_con = duckdb.connect()

# Load the TPCDS database
db_con.execute("IMPORT DATABASE '/workspace/data/cs598-tpcds/data/duckdb/tpcds_sf100'")

print("Loading of TPCDS DB complete.")

# Create a folder to store query results
os.makedirs(query_result, exist_ok=True)

query_files = [f for f in os.listdir(query_path) if os.path.isfile(os.path.join(query_path, f))]
query_files = sorted(query_files)
query_exec_time=""


Loading of TPCDS DB complete.


In [13]:
# Iterate through the queries and execute them
for query_file in query_files:
    print("Executing ", query_file)
    # extract query number from the query file name.
    query_name = os.path.splitext(query_file)[0]

    # Read the query from the file
    with open(os.path.join(query_path, query_file), 'r') as f:
        query=f.read()

    # Execute the query and measure execution time
    try:
        start_time = time.time()
        result = db_con.execute(query)
        execution_time = time.time()-start_time
    
        # Save the query result to a file
        output_file_path = query_result + query_name + ".csv"
        with open(output_file_path, "w") as out_file:
            for row in result.fetchall():
                out_file.write(str(row) + "\n")
        with open(query_time_file, "a") as time_file:
                time_file.write(f'Query {query_name} executed in {execution_time:.2f} seconds' + "\n")

        print(f'Query {query_name} executed in {execution_time:.2f} seconds')
    except Exception as e: 
        print(f'Query {query_name} executed in error: {e}')
# Close the database connection
db_con.close()

Executing  query_1.txt
Query query_1 executed in error: Parser Error: syntax error at or near "Retrieve"
LINE 1: Retrieve the customer IDs of customers ...
        ^
Executing  query_22.sql
Query query_22 executed in 0.35 seconds
Executing  query_37.sql
Query query_37 executed in 0.01 seconds
Executing  query_95.sql
Query query_95 executed in 0.01 seconds


: 