<a href="https://colab.research.google.com/github/SupriyaUpadhyaya/HCNLP-NL2SQL-Project/blob/main/Final_NL2SQL_Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
# Installs Unsloth, Xformers (Flash Attention) and all other packages!
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps "xformers<0.0.26" trl peft accelerate bitsandbytes
!pip install langchain_community
!pip install langchain

In [2]:
 from langchain.memory import ChatMessageHistory
 from langchain_core.prompts import MessagesPlaceholder, ChatPromptTemplate
 from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
 from langchain_community.utilities.sql_database import SQLDatabase
 from unsloth import FastLanguageModel

In [3]:
 history = ChatMessageHistory()

In [4]:
db = SQLDatabase.from_uri("sqlite:////content/worlddb.db", sample_rows_in_table_info=2)

print(db.table_info)


CREATE TABLE city (
	"ID" INTEGER NOT NULL, 
	"Name" CHAR(35) DEFAULT '' NOT NULL, 
	"CountryCode" CHAR(3) DEFAULT '' NOT NULL, 
	"District" CHAR(20) DEFAULT '' NOT NULL, 
	"Population" INTEGER DEFAULT '0' NOT NULL, 
	PRIMARY KEY ("ID"), 
	FOREIGN KEY("CountryCode") REFERENCES country ("Code")
)

/*
2 rows from city table:
ID	Name	CountryCode	District	Population
1	Kabul	AFG	Kabol	1780000
2	Qandahar	AFG	Qandahar	237500
*/


CREATE TABLE country (
	"Code" CHAR(3) DEFAULT '' NOT NULL, 
	"Name" CHAR(52) DEFAULT '' NOT NULL, 
	"Continent" TEXT DEFAULT 'Asia' NOT NULL, 
	"Region" CHAR(26) DEFAULT '' NOT NULL, 
	"SurfaceArea" DECIMAL(10, 2) DEFAULT '0.00' NOT NULL, 
	"IndepYear" SMALLINT DEFAULT NULL, 
	"Population" INTEGER DEFAULT '0' NOT NULL, 
	"LifeExpectancy" DECIMAL(3, 1) DEFAULT NULL, 
	"GNP" DECIMAL(10, 2) DEFAULT NULL, 
	"GNPOld" DECIMAL(10, 2) DEFAULT NULL, 
	"LocalName" CHAR(45) DEFAULT '' NOT NULL, 
	"GovernmentForm" CHAR(45) DEFAULT '' NOT NULL, 
	"HeadOfState" CHAR(60) DEFAULT 

In [5]:
text_to_sql_tmpl_str = """\
### Instruction:\n{system_message}{user_message}\n\n### Response:\n{response}"""

text_to_sql_inference_tmpl_str = """\
### Instruction:\n{system_message}{user_message}\n\n### Response:\n"""

def _generate_prompt_sql(input, context, dialect="sqlite", output="", messages=""):
    system_message = f"""You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.

You must output the SQL query that answers the question. Use the previous conversation to answer the follow up questions. Do not provide any explanation

    """
    user_message = f"""### Dialect:
{dialect}

### Input:
{input}

### Context:
{context}

### Previous Conversation:
{messages}

### Response:
"""
    if output:
        return text_to_sql_tmpl_str.format(
            system_message=system_message,
            user_message=user_message,
            response=output,
        )
    else:
        return text_to_sql_inference_tmpl_str.format(
            system_message=system_message, user_message=user_message
        )

In [6]:
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

In [7]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "supriyaupadhyaya/llama-3-8b-bnb-4bit-text-to-sql",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
FastLanguageModel.for_inference(model)



config.json:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

==((====))==  Unsloth: Fast Llama patching release 2024.4
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.1+cu121. CUDA = 7.5. CUDA Toolkit = 12.1.
\        /    Bfloat16 = FALSE. Xformers = 0.0.25.post1. FA = False.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


adapter_config.json:   0%|          | 0.00/732 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.14k [00:00<?, ?B/s]

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


model.safetensors:   0%|          | 0.00/5.70G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/131 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/168M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/449 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [8]:
question = "How many cities in the USA?"
context = db.table_info
messages = history.messages
text2sql_tmpl_str = _generate_prompt_sql(
        question, context, dialect="sqlite", output="", messages=messages
    )

In [9]:
text2sql_tmpl_str

'### Instruction:\nYou are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.\n\nYou must output the SQL query that answers the question. Use the previous conversation to answer the follow up questions. Do not provide any explanation\n\n    ### Dialect:\nsqlite\n\n### Input:\nHow many cities in the USA?\n\n### Context:\n\nCREATE TABLE city (\n\t"ID" INTEGER NOT NULL, \n\t"Name" CHAR(35) DEFAULT \'\' NOT NULL, \n\t"CountryCode" CHAR(3) DEFAULT \'\' NOT NULL, \n\t"District" CHAR(20) DEFAULT \'\' NOT NULL, \n\t"Population" INTEGER DEFAULT \'0\' NOT NULL, \n\tPRIMARY KEY ("ID"), \n\tFOREIGN KEY("CountryCode") REFERENCES country ("Code")\n)\n\n/*\n2 rows from city table:\nID\tName\tCountryCode\tDistrict\tPopulation\n1\tKabul\tAFG\tKabol\t1780000\n2\tQandahar\tAFG\tQandahar\t237500\n*/\n\n\nCREATE TABLE country (\n\t"Code" CHAR(3) DEFAULT \'\' NOT NULL, \n\t"Name" CHAR(52) DEFAULT \'\' NOT NULL, \

In [49]:
inputs = tokenizer(text2sql_tmpl_str, return_tensors = "pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True)
input_length = inputs["input_ids"].shape[1]
response = tokenizer.batch_decode(
      outputs[:, input_length:], skip_special_tokens=True
    )
query = response[0]

# Print the generated SQL query.
print(query)
history.add_user_message(question)
history.add_ai_message(query)
if len(messages) > 10:
  messages.pop()
  messages.pop()

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


SELECT COUNT("Name") FROM city WHERE "CountryCode" = 'USA'


In [50]:
execute_query = QuerySQLDataBaseTool(db=db)
answer = execute_query.invoke(response[0])
answer

'[(274,)]'

In [12]:
refiner_template = """
【Instruction】
When executing SQL below, some errors occurred, please fix up SQL based on query and database info.
Solve the task step by step if you need to. Using SQL format in the code block, and indicate script type in the code block.
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
【Constraints】
- In `SELECT <column>`, just select needed columns in the 【Question】 without any unnecessary column or value
- In `FROM <table>` or `JOIN <table>`, do not include unnecessary table
- If use max or min func, `JOIN <table>` FIRST, THEN use `SELECT MAX(<column>)` or `SELECT MIN(<column>)`
- If [Value examples] of <column> has 'None' or None, use `JOIN <table>` or `WHERE <column> is NOT NULL` is better
- If use `ORDER BY <column> ASC|DESC`, add `GROUP BY <column>` before to select distinct values
【Query】
-- {query}
【Evidence】
{evidence}
【Database info】
{desc_str}
【Foreign keys】
{fk_str}
【old SQL】
```sql
{sql}
```
【SQLite error】
{sqlite_error}
【Exception class】
{exception_class}

Now please fixup old SQL and generate new SQL again.
【correct SQL】
"""

In [30]:
from core.utils import parse_json, parse_sql_from_string, add_prefix, load_json_file, extract_world_info, is_email, is_valid_date_column
import sqlite3

In [73]:
class Refiner():

  def __init__(self, data_path: str, dataset_name: str):
        super().__init__()
        self.data_path = data_path  # path to all databases
        self.dataset_name = dataset_name
        #self._message = {}

  def _execute_sql(self, sql: str, question: str) -> dict:
        # Get database connection
        db_path = self.data_path
        conn = sqlite3.connect(db_path)
        conn.text_factory = lambda b: b.decode(errors="ignore")
        cursor = conn.cursor()
        try:
            cursor.execute(sql)
            result = cursor.fetchall()
            return {
                "question": question,
                "sql": str(sql),
                "data": result[:5],
                "sqlite_error": "",
                "exception_class": ""
            }
        except sqlite3.Error as er:
            return {
                "question": question,
                "sql": str(sql),
                "sqlite_error": str(' '.join(er.args)),
                "exception_class": str(er.__class__)
            }
        except Exception as e:
            return {
                "question": question,
                "sql": str(sql),
                "sqlite_error": str(e.args),
                "exception_class": str(type(e).__name__)
            }

  def _is_need_refine(self, exec_result: dict):
        # spider exist dirty values, even gold sql execution result is None
        if self.dataset_name == 'worlddb':
            if 'data' not in exec_result:
                return True
            return False

        data = exec_result.get('data', None)
        if data is not None:
            if len(data) == 0:
                exec_result['sqlite_error'] = 'no data selected'
                return True
            for t in data:
                for n in t:
                     if n is None:  # fixme fixme fixme fixme fixme
                        exec_result['sqlite_error'] = 'exist None value, you can add `NOT NULL` in SQL'
                        return True
            return False
        else:
            return True

  def _refine(self,
               query: str,
               evidence:str,
               schema_info: str,
               fk_info: str,
               error_info: dict) -> dict:

        sql_arg = add_prefix(error_info.get('sql'))
        sqlite_error = error_info.get('sqlite_error')
        exception_class = error_info.get('exception_class')
        prompt = refiner_template.format(query=query, evidence=evidence, desc_str=schema_info, \
                                       fk_str=fk_info, sql=sql_arg, sqlite_error=sqlite_error, \
                                        exception_class=exception_class)

        #word_info = extract_world_info(self._message)
        inputs = tokenizer(text2sql_tmpl_str, return_tensors = "pt").to("cuda")
        outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True)
        input_length = inputs["input_ids"].shape[1]
        response = tokenizer.batch_decode(
        outputs[:, input_length:], skip_special_tokens=True)
        query = response[0]
        return query

In [80]:
count = 0
refiner = Refiner(data_path="/content/worlddb.db", dataset_name='worlddb')
query_generated = query
exec_result = refiner._execute_sql(sql=query_generated, question=question)
print(exec_result)
while count <= 5:
  is_refine_required = refiner._is_need_refine(exec_result=exec_result)
  print("is_refine_required :", is_refine_required)
  if is_refine_required:
    print("In if condiition")
    query_generated = refiner._refine(query=query_generated, evidence=exec_result, schema_info=db.table_info, fk_info="", error_info=exec_result)
    exec_result = refiner._execute_sql(sql=query_generated, question=question)
    print(exec_result)
    count += 1
    print(query_generated)
  else:
    print("in else condition")
    count = 6


{'question': 'How many cities in the USA?', 'sql': 'SELECT COUNT("Name") FROM city WHERE "CountryCode" = \'USA\'', 'data': [(274,)], 'sqlite_error': '', 'exception_class': ''}
is_refine_required : False
in else condition


In [79]:
answer_prompt = f'''Given the following user question, corresponding SQL query, and SQL result, answer the user question in a sentence.

 Question: {exec_result['question']}
 SQL Query: {exec_result['sql']}
 SQL Result: {exec_result['data']}
 Answer:'''

inputs = tokenizer(answer_prompt, return_tensors = "pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True)
input_length = inputs["input_ids"].shape[1]
response = tokenizer.batch_decode(
      outputs[:, input_length:], skip_special_tokens=True
    )

response[0]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


' There are 274 cities in the USA.'