# LLMs for Text-to-SQL

In [30]:
import os
import json
import boto3

from botocore.exceptions import ClientError
from botocore.config import Config

def call_claude_msg(messages, system, bedrock_region_name='us-west-2', model="sonnet", streaming=False,
                    max_tokens=512):
    """Send a Messages-API request to an Anthropic Claude 3 model on Amazon Bedrock.

    Args:
        messages: List of messages in the Anthropic Messages format, e.g.
            ``[{"role": "user", "content": [{"type": "text", "text": "..."}]}]``.
        system: System prompt string.
        bedrock_region_name: AWS region used for the ``bedrock-runtime`` client.
        model: Model alias, ``"sonnet"`` (Claude 3 Sonnet) or ``"haiku"``
            (Claude 3 Haiku).
        streaming: If True, call ``invoke_model_with_response_stream`` and pass the
            response to ``bedrock_streemer``; otherwise call ``invoke_model``.
        max_tokens: Maximum number of tokens to generate (default 512).

    Returns:
        Non-streaming: the ``text`` of the first content block of the response.
        Streaming: whatever ``bedrock_streemer`` returns.

    Raises:
        ValueError: If ``model`` is not a supported alias.
    """
    # Short aliases accepted by this helper -> Bedrock model IDs.
    model_ids = {
        "sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
        "haiku": "anthropic.claude-3-haiku-20240307-v1:0",
    }
    # Validate before doing any AWS work: an unknown alias used to fall through an
    # if/elif chain and fail later with UnboundLocalError on the model ID.
    if model not in model_ids:
        raise ValueError(
            f"Unsupported model {model!r}; expected one of {sorted(model_ids)}"
        )
    modelId = model_ids[model]

    # botocore "standard" retry mode, up to 10 attempts per request.
    retry_config = Config(
        region_name=bedrock_region_name,
        retries={
            "max_attempts": 10,
            "mode": "standard",
        },
    )

    bedrock_runtime = boto3.client(service_name='bedrock-runtime', config=retry_config)

    prompt_config = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": max_tokens,
        "temperature": 0.0,
        "top_k": 350,
        "top_p": 0.999,
        "system": system,
        "messages": messages,
    }
    body = json.dumps(prompt_config)

    accept = "application/json"
    contentType = "application/json"

    if streaming:
        response = bedrock_runtime.invoke_model_with_response_stream(
            body=body, modelId=modelId, accept=accept, contentType=contentType
        )
        # NOTE(review): bedrock_streemer is not defined in the visible part of this
        # notebook; presumably it is defined elsewhere — confirm before streaming=True.
        results = bedrock_streemer(response)
    else:
        response = bedrock_runtime.invoke_model(
            body=body, modelId=modelId, accept=accept, contentType=contentType
        )
        response_body = json.loads(response.get("body").read())
        results = response_body.get("content")[0].get("text")
    return results

In [None]:
# Abbreviated preview of the prompt inputs. The "..." / "...." rows are
# placeholders for the remaining CSV rows; the next cell redefines all four
# names with the complete data.
# "PostgreSQL" is the product's canonical name. The old value "Postgre SQL" was
# a misspelling that, combined with the prompt templates' "{} SQL" wording,
# rendered as "Postgre SQL SQL" in the prompt.
SQL_DIALECT = "PostgreSQL"

EXAMPLE_DB_SCHEMA = """
"table_schema","table_name","column_name","data_type","character_maximum_length","numeric_precision","column_default","is_nullable"
"public","listings","vehicle_id","integer",NULL,32,NULL,"YES"
"public","listings","price","integer",NULL,32,NULL,"NO"
"public","listings","mileage","integer",NULL,32,NULL,"NO"
"public","listings","region","character varying",20,NULL,NULL,"YES"
...
"""

EXAMPLE_DB_CONSTRAINTS = """
"table_schema","table_name","constraint_name","constraint_type","column_name","foreign_table_schema","foreign_table_name","foreign_column_name"
"public","listings","listings_vehicle_id_fkey","FOREIGN KEY","vehicle_id","public","vehicles","vehicle_id"
....
"""
# "How many cars were sold in California?"
QUESTION = "캘리포니아에서 몇 대의 차가 팔렸습니까?"

In [41]:
# Canonical product name. The old value "Postgre SQL" was a misspelling that,
# combined with the prompt templates' "{} SQL" wording, rendered as
# "Postgre SQL SQL" in the prompt.
SQL_DIALECT = "PostgreSQL"

# Column metadata for the example tables as CSV text (header row first).
# The header matches information_schema.columns column names — presumably an
# export of that view; confirm if regenerating.
# NOTE(review): regions.state is declared character varying(2), so state values
# are presumably two-letter codes (e.g. 'CA') — keep this in mind when writing
# ground-truth examples that filter on state.
EXAMPLE_DB_SCHEMA = """
"table_schema","table_name","column_name","data_type","character_maximum_length","numeric_precision","column_default","is_nullable"
"public","listings","vehicle_id","integer",NULL,32,NULL,"YES"
"public","listings","price","integer",NULL,32,NULL,"NO"
"public","listings","mileage","integer",NULL,32,NULL,"NO"
"public","listings","region","character varying",20,NULL,NULL,"YES"
"public","regions","vehicle_id","integer",NULL,32,NULL,"YES"
"public","regions","city","character varying",100,NULL,NULL,"NO"
"public","regions","state","character varying",2,NULL,NULL,"NO"
"public","regions","county","character varying",100,NULL,NULL,"YES"
"public","regions","region","character varying",20,NULL,NULL,"YES"
"public","vehicles","vehicle_id","integer",NULL,32,"nextval('vehicles_vehicle_id_seq'::regclass)","NO"
"public","vehicles","make","character varying",50,NULL,NULL,"NO"
"public","vehicles","model","character varying",100,NULL,NULL,"NO"
"public","vehicles","year","smallint",NULL,16,NULL,"NO"
"public","vehicles","vin","character varying",17,NULL,NULL,"NO"
"public","vehicles","created_on","timestamp without time zone",NULL,NULL,"CURRENT_TIMESTAMP","NO"
"""

# Primary-key, foreign-key and unique constraints for the same tables, as CSV text.
EXAMPLE_DB_CONSTRAINTS = """
"table_schema","table_name","constraint_name","constraint_type","column_name","foreign_table_schema","foreign_table_name","foreign_column_name"
"public","listings","listings_vehicle_id_fkey","FOREIGN KEY","vehicle_id","public","vehicles","vehicle_id"
"public","regions","regions_vehicle_id_fkey","FOREIGN KEY","vehicle_id","public","vehicles","vehicle_id"
"public","vehicles","vehicles_pkey","PRIMARY KEY","vehicle_id","public","vehicles","vehicle_id"
"public","vehicles","vehicles_vin_key","UNIQUE","vin","public","vehicles","vin"
"""

# English version of the question, kept for reference (previously a live
# assignment that the Korean one immediately overwrote):
# QUESTION = "How many cars were sold in California?"
QUESTION = "캘리포니아에서 몇 대의 차가 팔렸습니까?"


In [42]:
# System prompt; {dialect} receives the SQL dialect name.
system_message = """
    당신은 SQL 쿼리를 생성하는 {dialect} bot입니다.
"""

# Text-to-SQL user prompt (Korean). Named placeholders make the mapping from
# arguments to slots explicit instead of relying on positional order.
# Also fixes the guideline wording "질문에 세부 사항에" -> "질문의 세부 사항에"
# ("the details of the question").
user_message = """
    필요한 테이블 이름에 대해 '데이터베이스 제약 조건' 및 '테이블 요약'의 데이터셋(.csv 문자열로 포맷됨)이 제공됩니다. 
    질문에 가장 정확하게 답변할 수 있도록 테이블에서 데이터를 검색하는 구문적으로 올바른 {dialect} SQL 쿼리를 생성해야 합니다.

    쿼리 생성 방법 가이드:
    - 질문의 세부 사항에 주의를 기울이고 정확히 지침을 따르세요.
    - 쿼리만 반환하고 다른 것은 반환하지 마세요. 텍스트로 질문에 답변하지 마세요.
    - 예시 출력 쿼리: 'SELECT * FROM table_name'.

    관련 테이블의 '데이터베이스 제약 조건'은 {constraints}입니다.
    관련 테이블의 '테이블 요약'은 {schema}입니다.

    질문은 {question}입니다.
"""

# English version of the same prompt, kept for reference:
# user_message = """
#     You are given the datasets of database constraints and table summaries (formatted as .csv strings) for the required table names. 
#     You are required to generate a syntactically correct {dialect} SQL query that retrieves the data from the tables that would help answer the question most accurately. 

#     Guide on how to generate the query:
#        - Pay attention to the details of the question: accurately follow the instructions. 
#        - Return only the query and nothing else. Do not return anything other than a query. Do not answer the question with text.
#        - Example output query: 'SELECT * FROM table_name'. 

#     Database constraints for the relevant tables are: {constraints}. 
#     Table summaries for the relevant tables are: {schema}.  

#     The question is: {question}.
# """

system = system_message.format(dialect=SQL_DIALECT)
text_prompt = user_message.format(
    dialect=SQL_DIALECT,
    constraints=EXAMPLE_DB_CONSTRAINTS,
    schema=EXAMPLE_DB_SCHEMA,
    question=QUESTION,
)

In [45]:
# Single-turn conversation: one user message carrying the formatted prompt.
messages = [
    {
        "role": "user",
        "content": [{"type": "text", "text": text_prompt}],
    }
]
bedrock_region_name = 'us-west-2'
response = call_claude_msg(
    messages,
    system,
    bedrock_region_name=bedrock_region_name,
    model='haiku',
)
print(response)

SELECT COUNT(*) 
FROM listings l
JOIN regions r ON l.vehicle_id = r.vehicle_id
WHERE r.state = 'CA';


In [46]:
# English original of the domain-knowledge example, kept for reference:
# EXAMPLE_KNOWLEDGE = {
#     "Question": "How many cars were sold in California?", 
#     "Explanation": "The number of cars sold in California State in the USA", 
#     "Query": """
#              SELECT COUNT(*) FROM listings 
#              JOIN regions ON listings.vehicle_id = regions.vehicle_id 
#              WHERE regions.state = 'California' OR regions.state = 'CALIFORNIA';
#              """, 
# }

# Ground-truth example (question / explanation / query) inserted into the
# knowledge-augmented prompt. The "쿼리" (query) entry must hold real SQL: it
# previously contained a Korean prose paraphrase of the query (an over-eager
# translation), which gave the model no SQL to imitate.
# NOTE(review): the schema declares regions.state as character varying(2), so
# filtering on 'California' / 'CALIFORNIA' looks inconsistent with two-letter
# state codes such as 'CA' — confirm against the actual data.
EXAMPLE_KNOWLEDGE = {
    "문의": "캘리포니아에서 몇 대의 자동차가 판매되었나요?",
    "설명": "미국 캘리포니아 주에서 판매된 자동차의 수",
    "쿼리": """
             SELECT COUNT(*) FROM listings
             JOIN regions ON listings.vehicle_id = regions.vehicle_id
             WHERE regions.state = 'California' OR regions.state = 'CALIFORNIA';
             """,
}

# English: "How many cars were sold in California in 2003?"
QUESTION_W_KNOWLEDGE = "2003년 캘리포니아에서 몇 대의 자동차가 판매되었습니까?"

In [47]:
# English version of the knowledge-augmented prompt, kept for reference. It was
# previously a live assignment that the Korean template immediately overwrote.
# user_message = """
#     You are given the datasets of database constraints and table summaries (formatted as .csv strings) for the required table names. 
#     You are required to generate a syntactically correct {dialect} SQL query that retrieves the data from the tables that would help answer the question most accurately. 
#
#     Guide on how to generate the query:
#        - Pay attention to the details of the question: accurately follow the instructions. 
#        - Return only the query and nothing else. Do not return anything other than a query. Do not answer the question with text.
#        - You may or may not be provided a relevant ground truth example. Use it to generate a more accurate query.
#        - Example output query: 'SELECT * FROM table_name'. 
#
#     Database constraints for the relevant tables are: {constraints}. 
#     Table summaries for the relevant tables are: {schema}.  
#     Ground truth example is: {knowledge}.
#
#     The question is: {question}.
# """

# Knowledge-augmented Text-to-SQL prompt (Korean): the plain prompt plus an
# optional ground-truth example. Named placeholders make the argument-to-slot
# mapping explicit. Also fixes "질문에 세부 사항에" -> "질문의 세부 사항에".
user_message = """
    필요한 테이블 이름에 대해 '데이터베이스 제약 조건' 및 '테이블 요약'의 데이터셋(.csv 문자열로 포맷됨)이 제공됩니다. 
    질문에 가장 정확하게 답변할 수 있도록 테이블에서 데이터를 검색하는 구문적으로 올바른 {dialect} SQL 쿼리를 생성해야 합니다.

    쿼리 생성 방법 가이드:
    - 질문의 세부 사항에 주의를 기울이고 정확히 지침을 따르세요.
    - 쿼리만 반환하고 다른 것은 반환하지 마세요. 텍스트로 질문에 답변하지 마세요.
    - 관련된 실제 예시가 제공될 수도 있고 그렇지 않을 수도 있습니다. 더 정확한 쿼리를 만들기 위해 예시를 활용하세요.
    - 예시 출력 쿼리: 'SELECT * FROM table_name'.

    관련 테이블의 '데이터베이스 제약 조건'은 {constraints}입니다.
    관련 테이블의 '테이블 요약'은 {schema}입니다.
    정답 예시는: {knowledge}입니다.

    질문은 {question}입니다.
"""

# str.format inserts str(EXAMPLE_KNOWLEDGE); for the dict defined in the
# previous cell that is its Python repr.
text_prompt = user_message.format(
    dialect=SQL_DIALECT,
    constraints=EXAMPLE_DB_CONSTRAINTS,
    schema=EXAMPLE_DB_SCHEMA,
    knowledge=EXAMPLE_KNOWLEDGE,
    question=QUESTION_W_KNOWLEDGE,
)

In [48]:
# Same single-turn request, now with the knowledge-augmented prompt; `system`
# is the system prompt built in the earlier template cell.
bedrock_region_name = 'us-west-2'
messages = [dict(role="user", content=[dict(type="text", text=text_prompt)])]
response = call_claude_msg(messages, system, model='haiku', bedrock_region_name=bedrock_region_name)
print(response)

SELECT COUNT(*) 
FROM listings l
JOIN regions r ON l.vehicle_id = r.vehicle_id
WHERE r.state IN ('California', 'CALIFORNIA')
  AND l.created_on BETWEEN '2003-01-01' AND '2003-12-31';
