<a href="https://colab.research.google.com/github/CalvHobbes/ecomm_ai/blob/main/nsql6b_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Runtime dependencies for this notebook.
# NOTE(review): versions are unpinned; pin them (pkg==x.y.z) so a fresh run reproduces results.
# psycopg2 — presumably the PostgreSQL driver behind vn.connect_to_postgres; TODO confirm.
%pip install -q psycopg2-binary
%pip install -q 'vanna[chromadb,ollama,postgres]'
# NOTE(review): no visible cell imports langchain_community or uses the ollama extra; confirm they are needed.
%pip install -q langchain_community
# NOTE(review): no visible from_pretrained call requests quantization, so bitsandbytes may be unused;
# accelerate is presumably needed for device_map="auto" — TODO confirm.
# NOTE(review): transformers, torch and sqlparse are imported later but not installed here;
# assumed to be preinstalled by the runtime (Colab/Kaggle) — TODO confirm.
%pip install -q bitsandbytes accelerate

In [None]:
def extract_from_default_prompt(prompt):
    """Split a Vanna SQL prompt (a list of chat messages) into the parts the SQLCoder template needs.

    The system message is expected to hold the schema, optionally followed by an
    ``===Additional Context`` section (documentation) and a ``===Response Guidelines``
    section. Each user message followed by assistant message(s) forms a few-shot
    question/SQL example, and the last user message is the question to answer.

    Args:
        prompt: iterable of dicts with ``'role'`` (``'system'``, ``'user'`` or
            ``'assistant'``) and ``'content'`` keys. Other roles are ignored.

    Returns:
        tuple ``(ddl_statement, user_question, instruct)``:
        - ``ddl_statement``: system content before ``===Additional Context`` or, when that
          section is absent, before ``===Response Guidelines``.
        - ``user_question``: content of the last user message (``""`` if there is none).
        - ``instruct``: the additional-context text (when present), followed by the
          ``SAMPLE QUESTION-QUERY`` list of example pairs.
    """
    context_marker = '===Additional Context'
    guidelines_marker = '===Response Guidelines'

    system_content = ""
    user_question = ""
    examples = []
    # Most recent user message; it is the question paired with the next assistant reply.
    prev_user_content = None

    for entry in prompt:
        role = entry['role']
        content = entry['content']

        if role == 'system':
            system_content = content
        elif role == 'user':
            user_question = content
            prev_user_content = content
        elif role == 'assistant' and prev_user_content is not None:
            # An assistant reply with no preceding user message has no question to pair
            # with; skip it instead of raising UnboundLocalError.
            examples.append((prev_user_content, content))

    samples = 'SAMPLE QUESTION-QUERY:\n\n' + ''.join(
        f'Sample question: {question}\nSample query: {sql}\n\n'
        for question, sql in examples
    )

    sections = system_content.split(context_marker)
    if len(sections) > 1:
        ddl_statement = sections[0]
        context = sections[1].split(guidelines_marker)[0]
        preamble = 'Additional Context: ' + context + 'Refer to the sample question-query below.\n\n'
    else:
        # The Additional Context section can be missing (presumably when no documentation
        # was retrieved); fall back instead of raising IndexError, and still cut off the
        # response guidelines so only the schema part is returned.
        ddl_statement = sections[0].split(guidelines_marker)[0]
        preamble = 'Refer to the sample question-query below.\n\n'

    return ddl_statement, user_question, preamble + samples

In [None]:
from vanna.base import VannaBase
from transformers import AutoTokenizer, AutoModelForCausalLM

import json
import sqlparse
import torch

class SQLCoder8b(VannaBase):
    """Vanna LLM backend that runs a SQLCoder checkpoint locally with Hugging Face transformers.

    Vanna hands ``submit_prompt`` a list of chat messages; this class rewrites them into a
    single Llama-3 style prompt (via ``extract_from_default_prompt``) and decodes greedily.

    Config keys:
        model_name (required): Hugging Face model id or local directory of the checkpoint.
        max_new_tokens (optional, default 400): generation budget for the SQL answer.

    Raises:
        ValueError: if ``config`` is None or lacks ``model_name``.
    """

    def __init__(self, config=None):
        if config is None:
            raise ValueError(
                "For SQLCoder, config must be provided at least model_name"
            )

        self.model_name = config.get('model_name')
        if not self.model_name:
            raise ValueError("config must contain model_name")

        self.max_new_tokens = config.get('max_new_tokens', 400)

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # NOTE(review): trust_remote_code=True executes Python code shipped with the
        # checkpoint; only point model_name at a source you trust.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto",
            use_cache=True,
        )

    def system_message(self, message: str) -> dict:
        """Build a system chat message for Vanna's prompt list."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Build a user chat message for Vanna's prompt list."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Build an assistant chat message for Vanna's prompt list."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Run the model on a Vanna prompt and return the generated SQL text.

        Args:
            prompt: list of chat messages built by Vanna.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            The text generated after the opening ```` ```sql ```` fence of the template,
            with a closing fence (and anything after it) removed.

        Raises:
            Exception: if ``prompt`` is None or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        create_table_statements, user_question, instructions = extract_from_default_prompt(prompt)
        updated_prompt = f'''<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{user_question}`
{instructions}

DDL statements:
{create_table_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{user_question}`:
```sql'''

        # The template already starts with <|begin_of_text|>; stop the tokenizer from
        # prepending a second BOS token.
        inputs = self.tokenizer(updated_prompt, return_tensors="pt", add_special_tokens=False)
        # With device_map="auto" the weights may sit on GPU or CPU; send the inputs to the
        # model's own device instead of guessing "cuda".
        inputs = inputs.to(self.model.device)

        generated_ids = self.model.generate(
            **inputs,
            num_return_sequences=1,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id,
            max_new_tokens=self.max_new_tokens,
            do_sample=False,  # greedy decoding
            num_beams=1,
        )

        # For a causal LM, generate() returns prompt + continuation; decode only the new tokens.
        prompt_length = inputs["input_ids"].shape[1]
        completion = self.tokenizer.decode(
            generated_ids[0][prompt_length:], skip_special_tokens=True
        )

        # torch.cuda.synchronize() raises on hosts without CUDA, so only touch CUDA when present.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        # The prompt already opened a ```sql fence: keep the text after a re-opened fence
        # (if the model emitted one) and drop the closing fence plus anything after it.
        return completion.split("```sql")[-1].split("```")[0]

In [None]:
from vanna.chromadb import ChromaDB_VectorStore

class MyVanna(ChromaDB_VectorStore, SQLCoder8b):
    """Vanna client pairing a ChromaDB vector store with the local SQLCoder8b LLM.

    Each base is initialised explicitly with the same ``config``: the vector store first,
    then the model, in base-class order.
    """

    def __init__(self, config=None):
        for base in (ChromaDB_VectorStore, SQLCoder8b):
            base.__init__(self, config=config)

In [None]:
# MODEL_PATH may be a local directory that already holds the weights (here, a Kaggle input
# dataset) or the Hugging Face id 'defog/llama-3-sqlcoder-8b', in which case the model is
# downloaded on first use.
MODEL_PATH = '/kaggle/input/defog-llama-3-sqlcoder-8b'
vn = MyVanna(config={'model_name': MODEL_PATH})

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# Read connection settings from environment variables instead of hard-coding them in the
# notebook (anything typed into a saved cell is shared with the notebook). A missing
# variable raises KeyError naming it; the password falls back to an interactive prompt.
import os
from getpass import getpass

vn.connect_to_postgres(
    host=os.environ['PG_HOST'],
    dbname=os.environ['PG_DBNAME'],
    user=os.environ['PG_USER'],
    password=os.environ.get('PG_PASSWORD') or getpass('PostgreSQL password: '),
    # Numeric port; 5432 is PostgreSQL's default (the old placeholder was a non-numeric string).
    port=int(os.environ.get('PG_PORT', '5432')),
)

In [None]:
# Paste your DDL here: one CREATE TABLE statement per string.
list_ddl = [
	"""
	CREATE TABLE Persons (
    PersonID int,
    LastName varchar(255),
    FirstName varchar(255),
    Address varchar(255),
    City varchar(255)
	);
	""",
]

# Free-text hints that end up in the prompt's "Additional Context".
# NOTE(review): training ids are presumably derived from the text, so edited hints are stored
# as new entries; if the Chroma store persists between runs, remove the old wording
# (e.g. with vn.remove_training_data).
# NOTE(review): in the saved query results, first_name holds family + middle name ("Ngô Sỹ")
# and last_name the given name ("Hùng"); confirm the "tên"/"họ" mappings below.
docs = [
    'When the question contains exact string, ALWAYS use case-insensitive LOWER or ILIKE.',
    'Make sure ALWAYS use correct table name.',
    'When question ask about "tên" in name, use "first_name"',
    'When question ask about "họ" in name, use "last_name"',
    'City is different from country',
    "If you cannot answer the question with the available database schema, return 'I do not know'",
    '"nơi đăng ký kết hôn" is also known as "marriage registration place"',
    # The old hint said "parse into integer" while its example quoted the value as a string.
    # A quoted literal is the safe choice in PostgreSQL: it is coerced to a numeric column's
    # type, whereas comparing a varchar column with a bare integer is an error.
    "When a condition compares a column with a number (e.g. a year), write the number as a quoted string literal, example: where year_birth=1990 ==> where year_birth='1990'",
    "If the question contains an address, you should query based on the type of address name located in a certain cell, not address equal a cell value, example: LOWER(mp.marriage_registration_place) = '%hải phòng%'; ==> LOWER(mp.marriage_registration_place) LIKE '%hải phòng%';",
]

# Optional few-shot (question, sql) example pairs.
qs_pair = []

In [None]:
# Populate Vanna's vector store: table schemas first, then free-text documentation,
# then any few-shot question/SQL pairs.
for ddl_text in list_ddl:
    vn.train(ddl=ddl_text)

for doc_text in docs:
    vn.train(documentation=doc_text)

for question_text, sql_text in qs_pair:
    vn.train(question=question_text, sql=sql_text)

Adding ddl: 
    CREATE TABLE hr_employee (
	id SERIAL NOT NULL, 
	resource_id INTEGER NOT NULL, 
	company_id INTEGER NOT NULL, 
	resource_calendar_id INTEGER, 
	message_main_attachment_id INTEGER, 
	color INTEGER, 
	department_id INTEGER, 
	job_id INTEGER, 
	address_id INTEGER, 
	work_contact_id INTEGER, 
	work_location_id INTEGER, 
	user_id INTEGER, 
	parent_id INTEGER, 
	coach_id INTEGER, 
	address_home_id INTEGER, 
	country_id INTEGER, 
	children INTEGER, 
	country_of_birth INTEGER, 
	bank_account_id INTEGER, 
	km_home_work INTEGER, 
	departure_reason_id INTEGER, 
	create_uid INTEGER, 
	write_uid INTEGER, 
	name VARCHAR, 
	job_title VARCHAR, 
	work_phone VARCHAR, 
	mobile_phone VARCHAR, 
	work_email VARCHAR, 
	employee_type VARCHAR NOT NULL, 
	gender VARCHAR, 
	marital VARCHAR, 
	spouse_complete_name VARCHAR, 
	place_of_birth VARCHAR, 
	ssnid VARCHAR, 
	sinid VARCHAR, 
	identification_id VARCHAR, 
	passport_id VARCHAR, 
	permit_no VARCHAR, 
	visa_no VARCHAR, 
	certificate VARCHAR, 

/root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx.tar.gz: 100%|██████████| 79.3M/79.3M [00:01<00:00, 53.2MiB/s]


Adding ddl: 
    CREATE TABLE medical_patient (
	id SERIAL NOT NULL, 
	message_main_attachment_id INTEGER, 
	end_period_menses INTEGER, 
	end_time_menses INTEGER, 
	start_time_menses INTEGER, 
	iui_qty INTEGER, 
	ivf_qty INTEGER, 
	year_wait_child INTEGER, 
	month_wait_child INTEGER, 
	doctor_id INTEGER, 
	patient_record_count INTEGER, 
	age INTEGER, 
	company_id INTEGER NOT NULL, 
	origin_patient_id INTEGER, 
	ethnic_id INTEGER, 
	religion_id INTEGER, 
	state_id INTEGER, 
	country_id INTEGER, 
	current_relationship_id INTEGER, 
	create_uid INTEGER, 
	write_uid INTEGER, 
	name VARCHAR NOT NULL, 
	pulse_rate VARCHAR, 
	breathing_rate VARCHAR, 
	notes_others VARCHAR, 
	is_contracept VARCHAR, 
	contraceptive VARCHAR, 
	notes_father VARCHAR, 
	notes_mother VARCHAR, 
	notes_grandparent VARCHAR, 
	has_medical_disease VARCHAR, 
	right_ovary VARCHAR, 
	left_ovary VARCHAR, 
	uterus_comment VARCHAR, 
	left_epididymis VARCHAR, 
	right_epididymis VARCHAR, 
	left_varicose_testicle VARCHAR, 
	right_

In [None]:
# Inspect the stored training entries (id, question, content and training_data_type per row).
training_data = vn.get_training_data()
training_data

Unnamed: 0,id,question,content,training_data_type
0,b22f0e34-01dc-52d1-9054-944dbd01dd7c-ddl,,\n CREATE TABLE medical_patient (\n\tid SER...,ddl
1,eeb1cf4c-4e08-5dc1-b58d-c0355eb369af-ddl,,\n CREATE TABLE hr_employee (\n\tid SERIAL ...,ddl
0,15efc567-67c0-5027-8549-844cf3bc52a3-doc,,"""nơi đăng ký kết hôn"" is also known as ""marria...",documentation
1,23685f6f-aea9-52fd-9510-fe3b59092ee5-doc,,"When question ask about ""họ"" in name, use ""las...",documentation
2,3a0acfec-aab4-591e-8837-b7704b9b24cc-doc,,try to parse into integer when you have a cond...,documentation
3,cd028ba8-a454-5600-bd6e-ef6e57302147-doc,,"If the question contains an address, you shoul...",documentation
4,cd119275-7b70-578f-9dc5-a6f649aa14e1-doc,,"When the question contains exact string, ALWAY...",documentation
5,e1a21420-01c9-5963-b625-0e0d2f4fd60e-doc,,City is difference from country,documentation
6,fa6658bc-bd5f-5cb7-abcb-1b80eb76fbb1-doc,,If you cannot answer the question with the ava...,documentation
7,fafb7b27-6b8b-5e37-a380-6f919bab58ef-doc,,Make sure ALWAYS use correct table name.,documentation


In [None]:
%%time
s = vn.ask(
    auto_train=False,
    visualize=False,
    print_results=False,
    allow_llm_to_see_data=True,
    question="Liệt kê tất cả các bệnh nhân nam sinh năm 1990 và đăng ký kết hôn tại Hải Phòng")

SQL Prompt: [{'role': 'system', 'content': 'You are a PostgreSQL expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. \n===Tables \n\n    CREATE TABLE medical_patient (\n\tid SERIAL NOT NULL, \n\tmessage_main_attachment_id INTEGER, \n\tend_period_menses INTEGER, \n\tend_time_menses INTEGER, \n\tstart_time_menses INTEGER, \n\tiui_qty INTEGER, \n\tivf_qty INTEGER, \n\tyear_wait_child INTEGER, \n\tmonth_wait_child INTEGER, \n\tdoctor_id INTEGER, \n\tpatient_record_count INTEGER, \n\tage INTEGER, \n\tcompany_id INTEGER NOT NULL, \n\torigin_patient_id INTEGER, \n\tethnic_id INTEGER, \n\treligion_id INTEGER, \n\tstate_id INTEGER, \n\tcountry_id INTEGER, \n\tcurrent_relationship_id INTEGER, \n\tcreate_uid INTEGER, \n\twrite_uid INTEGER, \n\tname VARCHAR NOT NULL, \n\tpulse_rate VARCHAR, \n\tbreathing_rate VARCHAR, \n\tnotes_others VARCHAR, \n\tis_contracept VARCHA



LLM Response: 
SELECT mp.first_name, mp.last_name FROM medical_patient mp WHERE LOWER(mp.marriage_registration_place) LIKE '%hải phòng%' AND mp.year_birth = '1990' AND mp.gender ='male';
Extracted SQL: SELECT mp.first_name, mp.last_name FROM medical_patient mp WHERE LOWER(mp.marriage_registration_place) LIKE '%hải phòng%' AND mp.year_birth = '1990' AND mp.gender ='male';
CPU times: user 42.4 s, sys: 1.91 s, total: 44.3 s
Wall time: 44.5 s


In [None]:
# First element of the vn.ask result: the SQL the model generated.
generated_sql = s[0]
generated_sql

"SELECT mp.first_name, mp.last_name FROM medical_patient mp WHERE LOWER(mp.marriage_registration_place) LIKE '%hải phòng%' AND mp.year_birth = '1990' AND mp.gender ='male';"

In [None]:
# Execute the SQL that vn.ask actually produced rather than a pasted copy, which goes
# stale as soon as the prompt, model or training data change.
sql_to_run = s[0] if s else None
if not sql_to_run:
    raise ValueError("vn.ask returned no SQL; re-run the question cell first")
patients_df = vn.run_sql(sql_to_run)
# NOTE(review): the result contains patient names (PII); show only a preview and clear
# outputs before sharing the notebook.
patients_df.head()

Unnamed: 0,first_name,last_name
0,Ngô Sỹ,Hùng
1,Đàm Văn,Thức
2,Bùi Bình,Đông
3,Bùi Văn,Lợi
4,Vũ Văn,Hợp
...,...,...
76,Đỗ Huy,Hoàng
77,Đoàn Vũ,Linh
78,Nguyễn Trung,Thành
79,Nguyễn Văn,Thuấn
