Dear user, you are now entering uncharted territory, so you must be very adventurous. Be warned: at first you will be surprised by the seemingly superior performance of the model. However, you will soon realize that the model only memorized the queries and did not actually learn. Given this limitation, if you still want to proceed, the code below will work, but it may not yet provide the insights you are looking for.

In [None]:
#FINE-TUNING PRE-TRAINED LLAMA ON TPCDSA DATA

import torch
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    prepare_model_for_kbit_training
)

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from datasets import load_dataset
from trl import SFTTrainer
from transformers import pipeline
from transformers import DataCollatorForLanguageModeling

In [None]:
from datasets import load_from_disk

# Load the train / validation / test splits that were previously saved to
# disk, in that order.
train_splits, val_splits, test_splits = (
    load_from_disk(split_dir)
    for split_dir in (
        './dataset//dataset-splits/train-split',
        './dataset/dataset-splits/val-split',
        './dataset/dataset-splits/test-split',
    )
)


In [None]:
#INFERENCE MODE - fine-tuned model

# LoRA fine-tuning produces adapter weights that sit on top of the base model,
# so the base model has to be loaded before the adapter can be attached to it.
# Both checkpoints below are loaded with the same 4-bit NF4 quantization
# (double quantization enabled, float16 compute dtype).
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Keyword arguments shared by both from_pretrained() calls.
shared_load_kwargs = dict(
    low_cpu_mem_usage=True,
    return_dict=True,
    quantization_config=quantization_config,
    device_map='auto',
)

# Base model checkpoint.
model_id = "../models/7B/output"
base_model = AutoModelForCausalLM.from_pretrained(model_id, **shared_load_kwargs)

# Fine-tuned text-to-SQL checkpoint directory.
# NOTE(review): new_model does not appear to be used by the other cells of
# this notebook (they attach the adapter to base_model via new_model_id), so
# this load may only cost extra memory -- confirm before removing it.
new_model_id = "../models/llama-2-7b-finetuned-text2SQL"
new_model = AutoModelForCausalLM.from_pretrained(new_model_id, **shared_load_kwargs)

In [None]:
# Attach the fine-tuned adapter found under new_model_id to the already loaded
# base model; the resulting PeftModel is the object the later cells call
# generate() on.
model = PeftModel.from_pretrained(base_model, new_model_id)


In [None]:
# Load the tokenizer that belongs to the base model. No device argument is
# passed: from_pretrained() has no such parameter for tokenizers, and the
# encoded tensors are moved to the GPU with .to("cuda") where they are used.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# Reuse the end-of-sequence token as the padding token and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"


In [None]:
# Reuse the base model's generation settings (the same object, modified in
# place) and override only the output length limit and the pad / EOS token
# ids. Sampling options such as temperature, top_p and num_return_sequences
# are deliberately left at whatever the loaded config already specifies.
generation_config = base_model.generation_config
for option_name, option_value in (
    ("max_new_tokens", 2048),
    ("pad_token_id", tokenizer.eos_token_id),
    ("eos_token_id", tokenizer.eos_token_id),
):
    setattr(generation_config, option_name, option_value)

In [None]:
#generating responses for all test dataset
import os
import time

response_dir = "./test-results/llm-results"
time_dir = "./test-results/querytime"

# Create the output folders up front so that open(..., "w") below cannot fail
# on a missing directory after an expensive generation step.
os.makedirs(response_dir, exist_ok=True)
os.makedirs(time_dir, exist_ok=True)

# Elapsed seconds per test sample, in dataset order.
query_times = []
for query_number, sample in enumerate(test_splits, start=1):
    start = time.time()
    print(f"started infernece for {query_number}")
    with torch.inference_mode():
        encoding = tokenizer(sample["text"], return_tensors="pt").to("cuda")
        outputs = model.generate(
            input_ids=encoding.input_ids,
            attention_mask=encoding.attention_mask,
            generation_config=generation_config,
        )
    # Only tokenization and generation are timed; decoding and file I/O are
    # excluded from the measurement.
    query_times.append(time.time() - start)
    print(f"inference finished for query{query_number}")

    # NOTE(review): the whole output sequence is decoded, which presumably
    # includes the prompt tokens in front of the generated SQL -- confirm this
    # is what the downstream query-execution step expects.
    response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    with open(f"{response_dir}/query_{query_number}.txt", "w") as response_file:
        response_file.write(response_text)
    print("completed writing response ot a file")

# One elapsed time (in seconds) per line, same order as the test samples.
query_time = "".join(f"{elapsed}\n" for elapsed in query_times)
with open(f"{time_dir}/query-time.txt", "w") as time_file:
    time_file.write(query_time)



    
    


In [None]:
#this code randomly selects 99 queries from llama results (random sampling) to execute on db

import os
import shutil
import random

# Paths for source and destination folders.
# The source must be the folder the inference loop writes its query_<n>.txt
# files to ('./test-results/llm-results').
src_folder = './test-results/llm-results'
dest_folder = './test-results/sampled-queries'
sample_size = 99  # Adjust for the number of files you want to copy

# Ensure the destination folder exists
os.makedirs(dest_folder, exist_ok=True)

# Sample from the response files that actually exist instead of from a
# hard-coded number range, so a missing file can never be selected.
available_files = sorted(
    name for name in os.listdir(src_folder)
    if name.startswith('query_') and name.endswith('.txt')
)

# random.sample() raises ValueError when k exceeds the population size, so the
# sample size is capped at the number of available files.
selected_files = random.sample(
    available_files, k=min(sample_size, len(available_files))
)

# Copy the selected files from the source to the destination
for file_name in selected_files:
    shutil.copy(
        os.path.join(src_folder, file_name),
        os.path.join(dest_folder, file_name),
    )

print(f"Files successfully copied to {dest_folder}")


In [None]:
#execute randomly sampled 99 queries on duckdb


import duckdb
import os
import time

# Folder holding the sampled query files, folder receiving one result file per
# query, and the log file that collects the execution times.
query_path='./test-results/sampled-queries/'
query_result='./test-results/db-results/llm-resp-db-results/results/'
query_time_file='./test-results/db-results/llm-resp-db-results/querytime/log.txt'

# Connect to the DB
db_con = duckdb.connect()

try:
    # Load the TPCDS database
    db_con.execute("IMPORT DATABASE '/workspace/data/cs598-tpcds/data/duckdb/tpcds_sf100'")
    print("Loading of TPCDS DB complete.")

    # Create the folders for the query results and for the timing log; the
    # log file is opened in append mode below, which fails if its folder is
    # missing.
    os.makedirs(query_result, exist_ok=True)
    os.makedirs(os.path.dirname(query_time_file), exist_ok=True)

    query_files = sorted(
        f for f in os.listdir(query_path)
        if os.path.isfile(os.path.join(query_path, f))
    )

    # One log line per query file, in sorted file order. A failed query is
    # logged as "nan" so that the lines stay aligned with the query files
    # without recording a stale time from a previous query.
    timing_lines = []
    for query_file in query_files:
        print("Executing ", query_file)
        # extract query name (file name without extension)
        query_name = os.path.splitext(query_file)[0]

        # Read the query from the file
        with open(os.path.join(query_path, query_file), 'r') as f:
            query = f.read()

        # Execute the query and measure execution time (the fetch is not
        # part of the measurement).
        try:
            start_time = time.perf_counter()
            result = db_con.execute(query)
            execution_time = time.perf_counter() - start_time
            rows = result.fetchall()
        except Exception as e:
            # Best effort: generated SQL is expected to fail sometimes, so
            # report the error and continue with the next query.
            print(f'Query {query_name} executed in error: {e}')
            timing_lines.append("nan\n")
            continue

        # Save the query result to a file, one row per line
        output_file_path = query_result + query_name + ".csv"
        with open(output_file_path, "w") as out_file:
            out_file.writelines(str(row) + "\n" for row in rows)

        timing_lines.append(str(execution_time) + "\n")
        print(f'Query {query_name} executed in {execution_time:.2f} seconds')

    query_exec_time = "".join(timing_lines)
    with open(query_time_file, "a") as time_file:
        time_file.write(query_exec_time)
finally:
    # Close the database connection even when a step above raised
    db_con.close()


In [None]:
#execute all sampled queries's golden answer on duckdb for comparison


import duckdb
import os
import time
import re
from datasets import load_from_disk

# Folder holding the sampled query files, folder receiving one result file per
# golden query, and the log file that collects the execution times.
query_path='./test-results/sampled-queries/'
query_result='./test-results/db-results/golden-query-db-results/results/'
query_time_file='./test-results/db-results/golden-query-db-results/querytime/log.txt'

# Load the test split, as each of its samples carries the golden SQL.
# The path matches the one the test split is loaded from for inference, so
# that query numbers map to the same samples.
test_splits=load_from_disk('./dataset/dataset-splits/test-split')

# Connect to the DB
db_con = duckdb.connect()

# Load the TPCDS database
db_con.execute("IMPORT DATABASE '/workspace/data/cs598-tpcds/data/duckdb/tpcds_sf100'")

print("Loading of TPCDS DB complete.")

# Create the folders for the query results and for the timing log; the log
# file is later opened in append mode, which fails if its folder is missing.
os.makedirs(query_result, exist_ok=True)
os.makedirs(os.path.dirname(query_time_file), exist_ok=True)

# Sampled query files in sorted order; only their names are used, to select
# the matching golden queries.
query_files = sorted(
    f for f in os.listdir(query_path)
    if os.path.isfile(os.path.join(query_path, f))
)
query_exec_time=""


In [None]:

# Run the golden SQL of every sampled query on DuckDB, store each result set
# in its own file and append the execution times to the timing log.
# One log line per query file, in sorted file order; a skipped or failed query
# is logged as "nan" so that the lines stay aligned with the query files.
timing_lines = []
try:
    for query_file in query_files:
        print("Executing ", query_file)
        # extract query name (file name without extension)
        query_name = os.path.splitext(query_file)[0]

        # regex to extract query number from e.g. query_10.txt; numbers are
        # 1-based while the test split is indexed from 0
        number_match = re.search(r'\d+', query_name)
        query_number = int(number_match.group()) if number_match else 0
        if not 1 <= query_number <= len(test_splits):
            print(f'Skipping {query_file}: no matching sample in the test split')
            timing_lines.append("nan\n")
            continue

        sql_query = test_splits[query_number - 1]["output"]

        # Execute the query and measure execution time (the fetch is not
        # part of the measurement).
        try:
            start_time = time.perf_counter()
            result = db_con.execute(sql_query)
            execution_time = time.perf_counter() - start_time
            rows = result.fetchall()
        except Exception as e:
            # Best effort: report the error and continue with the next query.
            print(f'Query {query_name} executed in error: {e}')
            timing_lines.append("nan\n")
            continue

        # Save the query result to a file, one row per line
        output_file_path = query_result + query_name + ".csv"
        with open(output_file_path, "w") as out_file:
            out_file.writelines(str(row) + "\n" for row in rows)

        timing_lines.append(str(execution_time) + "\n")
        print(f'Query {query_name} executed in {execution_time:.2f} seconds')

    # Rebuilt from scratch on every run so that re-running this cell does not
    # append earlier timings to the log a second time.
    query_exec_time = "".join(timing_lines)
    # The log is opened in append mode, which fails if its folder is missing.
    os.makedirs(os.path.dirname(query_time_file), exist_ok=True)
    with open(query_time_file, "a") as time_file:
        time_file.write(query_exec_time)
finally:
    # Close the database connection even when a step above raised
    db_con.close()


In [None]:
#sample inference run on the final fine-tuned model
test_prompt='''
'\n\n[SYSTEM]:"You are an expert Text-to-SQL generator assistant. Your goal is to provide correct SQL queries to the given text description. Your output only contains the SQL code. No explanation or introductory sentences surrounding the SQL response is needed. You are given schema information. Here is the schema information: \n<tableName>store_returns</tableName>\n<columns>sr_returned_date_sk,  sr_return_time_sk,  sr_item_sk,  sr_customer_sk,  sr_cdemo_sk,  sr_hdemo_sk,  sr_addr_sk,  sr_store_sk,  sr_reason_sk,  sr_ticket_number,  sr_return_quantity,  sr_return_amt,  sr_return_tax,  sr_return_amt_inc_tax,  sr_fee,  sr_return_ship_cost,  sr_refunded_cash,  sr_reversed_charge,  sr_store_credit,  sr_net_loss</columns>\n<tableName>date_dim</tableName>\n<columns>d_date_sk,  d_date_id,  d_date,  d_month_seq,  d_week_seq,  d_quarter_seq,  d_year,  d_dow,  d_moy,  d_dom,  d_qoy,  d_fy_year,  d_fy_quarter_seq,  d_fy_week_seq,  d_day_name,  d_quarter_name,  d_holiday,  d_weekend,  d_following_holiday,  d_first_dom,  d_last_dom,  d_same_day_ly,  d_same_day_lq,  d_current_day,  d_current_week,  d_current_month,  d_current_quarter,  d_current_year</columns>\n<tableName>store</tableName>\n<columns>s_store_sk,  s_store_id,  s_rec_start_date,  s_rec_end_date,  s_closed_date_sk,  s_store_name,  s_number_employees,  s_floor_space,  s_hours,  s_manager,  s_market_id,  s_geography_class,  s_market_desc,  s_market_manager,  s_division_id,  s_division_name,  s_company_id,  s_company_name,  s_street_number,  s_street_name,  s_street_type,  s_suite_number,  s_city,  s_county,  s_state,  s_zip,  s_country,  s_gmt_offset,  s_tax_percentage</columns>\n<tableName>customer</tableName>\n<columns>c_customer_sk,  c_customer_id,  c_current_cdemo_sk,  c_current_hdemo_sk,  c_current_addr_sk,  c_first_shipto_date_sk,  c_first_sales_date_sk,  c_salutation,  c_first_name,  c_last_name,  c_preferred_cust_flag,  c_birth_day,  c_birth_month,  c_birth_year,  c_birth_country,  c_login,  c_email_address,  
c_last_review_date_sk</columns>\n. Here are the 5 critical rules for the interactions you must abide: <rules> 1. Do not wrap the generated SQL code within SQL code markdown format. Also, do not include the SQL keyword in the beginning of the response. 2. If I don\'t tell you to find the limited set of results, limit to 100. 3. Only use table and columns from the list provided 4. When performing aliasing, make sure to refer the aliased tables as alias.column_name and not as alias_column_name. 5. For US state names, use abbreviated forms. For example, for South Dakota state, use SD.</rules> \n\n"Here is the user question:"[/SYSTEM]\n[HUMAN]: For the state of South Dakota in the year 2000, identify the first 100 customers, sorted by their IDs, whose returns are notably higher, exceeding the store\'s average by more than 20%, indicating a trend of higher returns.\n[/HUMAN]\n\n
'''
# Tokenize the sample prompt and move the resulting tensors to the GPU.
encoding = tokenizer(test_prompt, return_tensors="pt").to("cuda")

# Generate under inference mode using the generation settings configured
# earlier in the notebook.
generate_inputs = {
    "input_ids": encoding.input_ids,
    "attention_mask": encoding.attention_mask,
    "generation_config": generation_config,
}
with torch.inference_mode():
    outputs = model.generate(**generate_inputs)

# Decode the first returned sequence once and use the text for both the
# console output and the result file.
decoded_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded_response)
with open("./test-q1-SD.txt", "w") as file:
    file.write(str(decoded_response))
