In [4]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM,BitsAndBytesConfig
#FINE-TUNING PRE-TRAINED LLAMA ON TPCDSA DATA

import torch
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training
)

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)

In [5]:
model_name = "codellama/CodeLlama-34b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)

quantization_config = BitsAndBytesConfig(
   load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)


base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    quantization_config=quantization_config,
    device_map='auto',
    use_cache=True
    )

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [7]:
eos_token_id = tokenizer.eos_token_id

In [27]:
#generating responses for all test dataset
import os
import time
import sqlparse
query_time=""
prompt_folder="data/sqlcoder/question_and_schema"

for queryIdx in range(1):
    start=time.time()
    with open(f"{prompt_folder}/combine{queryIdx+1}.txt","r") as file:
        codellama_prompt= file.read()
        print(codellama_prompt)
        print(f"{prompt_folder}/combine{queryIdx+1}.txt")
    print(f"started infernece for {queryIdx+1}")
    inputs = tokenizer(codellama_prompt, return_tensors="pt").to("cuda")
    generated_ids = base_model.generate(**inputs,num_return_sequences=1,eos_token_id=eos_token_id,pad_token_id=eos_token_id,max_new_tokens=1500,do_sample=False,num_beams=1)
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    query_time+= str(time.time() - start)  +"\n"
    print(f"inference finished for query{queryIdx+1}")
    with open(f"data/sqlcoder/queries/query_{queryIdx+1}.txt","w") as response_file:
        content=sqlparse.format(outputs[0].split("```sql")[-1], reindent=True)
        response_file.write(content)
        response_file.close()
        print("completed writing response to a file")

with open(f"data/sqlcoder/querytime/query-time.txt","w") as time_file:
            time_file.write(query_time)
            time_file.close()


Retrieve the customer IDs of customers who have returned items significantly more frequently (over 20% above the average return rate) than the average customer return rate for a store in the state of South Dakota in the year 2000. This query calculates the total returns for each customer at each store and compares it to the store's average returns for the specified year. It then retrieves the customer IDs of those whose returns exceed the calculated threshold. The results are sorted by customer ID, and a limit of 100 records is applied

CREATE TABLE store_returns (  sr_returned_date_sk,  sr_return_time_sk,  sr_item_sk,  sr_customer_sk,  sr_cdemo_sk,  sr_hdemo_sk,  sr_addr_sk,  sr_store_sk,  sr_reason_sk,  sr_ticket_number,  sr_return_quantity,  sr_return_amt,  sr_return_tax,  sr_return_amt_inc_tax,  sr_fee,  sr_return_ship_cost,  sr_refunded_cash,  sr_reversed_charge,  sr_store_credit,  sr_net_loss );

CREATE TABLE date_dim (  d_date_sk,  d_date_id,  d_date,  d_month_seq,  d_week_seq, 

In [None]:
#execute randomly sampled 99 queries on duckdb


import duckdb
import os
import time

# Define the connection parameters
query_path='data/sqlcoder/queries/'
query_result='data/sqlcoder/results/'
query_time_file='data/sqlcoder/querytime/log.txt'

# Connect to the DB 
db_con = duckdb.connect()

# Load the TPCDS database
db_con.execute("IMPORT DATABASE '/workspace/data/duckdb/build/release/tpcdssf100'")

print("Loading of TPCDS DB complete.")

# Create a folder to store query results
os.makedirs(query_result, exist_ok=True)

query_files = [f for f in os.listdir(query_path) if os.path.isfile(os.path.join(query_path, f))]
query_files = sorted(query_files)
query_exec_time=""


In [3]:
query_files

['query_1.sql',
 'query_10.sql',
 'query_12.sql',
 'query_13.sql',
 'query_14.sql',
 'query_15.sql',
 'query_16.sql',
 'query_17.sql',
 'query_18.sql',
 'query_19.sql',
 'query_2.sql',
 'query_20.sql',
 'query_21.sql',
 'query_22.sql',
 'query_23.sql',
 'query_24.sql',
 'query_25.sql',
 'query_26.sql',
 'query_27.sql',
 'query_28.sql',
 'query_29.sql',
 'query_3.sql',
 'query_30.sql',
 'query_31.sql',
 'query_32.sql',
 'query_33.sql',
 'query_34.sql',
 'query_35.sql',
 'query_36.sql',
 'query_37.sql',
 'query_38.sql',
 'query_39.sql',
 'query_4.sql',
 'query_40.sql',
 'query_41.sql',
 'query_42.sql',
 'query_43.sql',
 'query_44.sql',
 'query_45.sql',
 'query_46.sql',
 'query_47.sql',
 'query_48.sql',
 'query_49.sql',
 'query_5.sql',
 'query_50.sql',
 'query_52.sql',
 'query_53.sql',
 'query_54.sql',
 'query_55.sql',
 'query_56.sql',
 'query_57.sql',
 'query_58.sql',
 'query_59.sql',
 'query_6.sql',
 'query_60.sql',
 'query_61.sql',
 'query_62.sql',
 'query_63.sql',
 'query_65.sql',
 'q

In [4]:

# Iterate through the queries and execute them
for query_file in query_files:
    print("Executing ", query_file)
    # extract query number from the query file name.
    query_name = os.path.splitext(query_file)[0]

    # Read the query from the file
    with open(os.path.join(query_path, query_file), 'r') as f:
        query=f.read()

    # Execute the query and measure execution time
    try:
        start_time = time.time()
        result = db_con.execute(query)
        execution_time = time.time()-start_time
    
        # Save the query result to a file
        output_file_path = query_result + query_name + ".csv"
        with open(output_file_path, "w") as out_file:
            for row in result.fetchall():
                out_file.write(str(row) + "\n")
        with open(query_time_file, "a") as time_file:
                time_file.write(f'Query {query_name} executed in {execution_time:.2f} seconds' + "\n")

        print(f'Query {query_name} executed in {execution_time:.2f} seconds')
    except Exception as e: 
        print(f'Query {query_name} executed in error: {e}')
# Close the database connection
db_con.close()

Executing  query_1.sql
Query query_1 executed in 0.01 seconds
Executing  query_10.sql
Query query_10 executed in error: Binder Error: Table "c" does not have a column named "c_gender"
LINE 20: GROUP BY c.c_gender,
                  ^
Executing  query_12.sql
Query query_12 executed in 0.13 seconds
Executing  query_13.sql
Query query_13 executed in error: Binder Error: Referenced column "ss_quantity" not found in FROM clause!
Executing  query_14.sql
Query query_14 executed in error: Parser Error: syntax error at or near "'I do not know'"
LINE 1: 'I do not know'
        ^
Executing  query_15.sql
Query query_15 executed in 1.02 seconds
Executing  query_16.sql
Query query_16 executed in error: Binder Error: No function matches the given name and argument types '+(VARCHAR, INTERVAL)'. You might need to add explicit type casts.
	Candidate functions:
	+(TINYINT) -> TINYINT
	+(TINYINT, TINYINT) -> TINYINT
	+(SMALLINT) -> SMALLINT
	+(SMALLINT, SMALLINT) -> SMALLINT
	+(INTEGER) -> INTEGER
	+(INTE

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

Query query_22 executed in 16.44 seconds
Executing  query_23.sql
Query query_23 executed in 0.00 seconds
Executing  query_24.sql
Query query_24 executed in error: Binder Error: aggregate function calls cannot be nested
LINE 16: ...c_last_name,
       c.c_first_name,
       s.s_store_name,
       SUM(ss.ss_ext_sales_price) AS total_sales_price
FROM store_sales ss
JOIN item i ON ss.ss_item_sk = i.i_item_sk
JOIN store s ON ss.ss_store_sk = s.s_store_sk
JOIN customer c ON ss.ss_customer_sk = c.c_customer_sk
JOIN customer_address ca ON c.c_current_addr_sk = ca.ca_address_sk
WHERE i.i_color = 'beige'
  AND s.s_market_id = 8
  AND c.c_birth_country != ca.ca_country
GROUP BY c.c_last_name,
         c.c_first_name,
         s.s_store_name
HAVING SUM(ss.ss_ext_sales_price) > (0.05 * AVG(SUM(ss.ss_ext_sales_price)))
                                                   ^
Executing  query_25.sql
Query query_25 executed in error: Conversion Error: Unimplemented type for cast (INTEGER -> DATE)
Executin

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

Query query_39 executed in 4.05 seconds
Executing  query_4.sql
Query query_4 executed in error: Parser Error: syntax error at or near ";"
LINE 61: AND s;
              ^
Executing  query_40.sql
Query query_40 executed in 0.04 seconds
Executing  query_41.sql
Query query_41 executed in 0.02 seconds
Executing  query_42.sql
Query query_42 executed in 0.10 seconds
Executing  query_43.sql
Query query_43 executed in error: Catalog Error: Table with name fact_online_sales does not exist!
Did you mean "store_sales"?
Executing  query_44.sql
Query query_44 executed in 0.43 seconds
Executing  query_45.sql
Query query_45 executed in error: Binder Error: No function matches the given name and argument types 'date_part(VARCHAR, INTEGER)'. You might need to add explicit type casts.
	Candidate functions:
	date_part(VARCHAR, DATE) -> BIGINT
	date_part(VARCHAR, TIMESTAMP) -> BIGINT
	date_part(VARCHAR, TIME) -> BIGINT
	date_part(VARCHAR, INTERVAL) -> BIGINT
	date_part(VARCHAR[], DATE) -> STRUCT()
	date_pa

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

Query query_72 executed in 2.84 seconds
Executing  query_73.sql
Query query_73 executed in 0.01 seconds
Executing  query_74.sql
Query query_74 executed in error: Binder Error: Referenced table "c" not found!
Candidate tables: "year_total"
LINE 1: SELECT c.c_first_name,
               ^
Executing  query_75.sql
Query query_75 executed in error: Parser Error: syntax error at or near ";"
LINE 44: WHERE CAST(curr_yr.;
                            ^
Executing  query_76.sql
Query query_76 executed in error: Parser Error: syntax error at or near "UNION"
LINE 23: UNION ALL
         ^
Executing  query_77.sql
Query query_77 executed in error: Parser Error: syntax error at or near "UNION"
LINE 14: UNION ALL
         ^
Executing  query_78.sql
Query query_78 executed in 1.52 seconds
Executing  query_79.sql
Query query_79 executed in 0.01 seconds
Executing  query_8.sql
Query query_8 executed in 0.01 seconds
Executing  query_80.sql
Query query_80 executed in error: Binder Error: Table "ss" does not hav