In [1]:
# https://blog.gopenai.com/advanced-rag-for-database-without-exposing-db-data-text-to-sql-a0e71f00e010
# https://github.com/paras55/advanced-chat-with-db
# License: MIT

from langchain_openai import ChatOpenAI
from dotenv import load_dotenv
from langchain import OpenAI, LLMChain
from langchain.prompts import PromptTemplate
from langchain.utilities import SQLDatabase
from sqlalchemy import create_engine, MetaData, Table, Column, inspect
from langchain_experimental.sql import SQLDatabaseChain
import os

# Pull settings from the local .env file; override=True lets its values
# replace any variables already present in the process environment.
load_dotenv(override=True, dotenv_path='.env')
openai_api_key = os.environ.get("OPENAI_API_KEY")

# Chat model shared by the rest of the notebook; temperature=0 keeps its
# answers as repeatable as the API allows.
llm = ChatOpenAI(
    model='gpt-4o-mini',
    openai_api_key=openai_api_key,
    temperature=0,
)

In [2]:
def extract_schema(db_url):
    """Return a plain-text summary of every table and column in a database.

    Args:
        db_url: SQLAlchemy connection URL, e.g. ``sqlite:///new.db``.

    Returns:
        A newline-joined string containing one ``Table: <name>`` line per
        table, each followed by one ``  - <column> (<type>)`` line per
        column. An empty string when the database has no tables.
    """
    engine = create_engine(db_url)
    try:
        inspector = inspect(engine)
        schema_info = []
        for table_name in inspector.get_table_names():
            schema_info.append(f"Table: {table_name}")
            schema_info.extend(
                f"  - {column['name']} ({column['type']})"
                for column in inspector.get_columns(table_name)
            )
        return "\n".join(schema_info)
    finally:
        # The engine is private to this call; dispose of it so its pooled
        # connections are closed instead of lingering after every call.
        engine.dispose()

In [3]:
# Prompt text sent to the model (written in Portuguese). It has two
# placeholders: {schema} for the database schema text and {question} for the
# user's request.
# NOTE(review): nothing here asks the model to answer with SQL only; the
# recorded reply mixes prose with a fenced ```sql block, so the query has to
# be pulled out of the reply afterwards.
prompt_template = """
Você é um assistente de IA que gera consultas SQL com base em solicitações de usuários.
Você tem acesso ao seguinte esquema de banco de dados:
{schema}
Com base neste esquema, gere uma consulta SQL para responder à seguinte pergunta:
{question}
SQL Query:
"""
# Wrap the raw text in a PromptTemplate, declaring the two variables it expects.
prompt = PromptTemplate(
    input_variables=["schema", "question"],
    template=prompt_template,
)

In [4]:
# Feed the filled-in prompt straight into the chat model.
chain = prompt | llm

def generate_sql_query(question, db_schema=None):
    """Ask the LLM for a SQL query that answers ``question``.

    Args:
        question: Natural-language request from the user.
        db_schema: Schema text to show the model. When omitted, the
            notebook-level ``schema`` variable is used, so that variable
            must have been assigned before the first call.

    Returns:
        Whatever ``chain.invoke`` returns; the notebook reads the generated
        text from its ``.content`` attribute.
    """
    if db_schema is None:
        # Backward-compatible default: fall back to the notebook global.
        db_schema = schema
    return chain.invoke({"schema": db_schema, "question": question})

In [5]:
# Connection URL for the local SQLite file.
db_url = "sqlite:///new.db"

# Text summary of the tables and columns; this is the value substituted for
# the {schema} placeholder when a query is generated.
schema = extract_schema(db_url)
print(schema)

Table: Comments
  - CommentID (INTEGER)
  - PostID (INTEGER)
  - UserID (INTEGER)
  - Content (TEXT)
  - CreatedAt (DATETIME)
Table: EventRegistrations
  - RegistrationID (INTEGER)
  - EventID (INTEGER)
  - UserID (INTEGER)
  - RegisteredAt (DATETIME)
Table: Events
  - EventID (INTEGER)
  - GroupID (INTEGER)
  - EventName (TEXT)
  - EventDate (DATETIME)
Table: GroupMemberships
  - MembershipID (INTEGER)
  - GroupID (INTEGER)
  - UserID (INTEGER)
  - JoinedAt (DATETIME)
Table: Groups
  - GroupID (INTEGER)
  - GroupName (TEXT)
  - Description (TEXT)
Table: Likes
  - LikeID (INTEGER)
  - PostID (INTEGER)
  - UserID (INTEGER)
Table: Messages
  - MessageID (INTEGER)
  - SenderID (INTEGER)
  - ReceiverID (INTEGER)
  - Content (TEXT)
  - SentAt (DATETIME)
Table: Posts
  - PostID (INTEGER)
  - UserID (INTEGER)
  - Content (TEXT)
  - CreatedAt (DATETIME)
Table: Profiles
  - ProfileID (INTEGER)
  - UserID (INTEGER)
  - FullName (TEXT)
  - Bio (TEXT)
Table: Users
  - UserID (INTEGER)
  - Username

In [6]:
# Natural-language request to turn into SQL.
user_question = "Find me the registration id of the hackathon"
# The reply object's generated text is read from `.content` below.
sql_query = generate_sql_query(user_question)
print(f"Generated SQL Query: {sql_query.content}")

Generated SQL Query: Para encontrar o `RegistrationID` do hackathon, precisamos primeiro identificar o `EventID` correspondente ao hackathon na tabela `Events`. Em seguida, podemos buscar o `RegistrationID` na tabela `EventRegistrations` que corresponde a esse `EventID`. 

Aqui está a consulta SQL que faz isso:

```sql
SELECT er.RegistrationID
FROM EventRegistrations er
JOIN Events e ON er.EventID = e.EventID
WHERE e.EventName = 'hackathon';
```

Essa consulta faz um `JOIN` entre as tabelas `EventRegistrations` e `Events` para obter o `RegistrationID` do evento chamado "hackathon".


In [9]:
import re
# Opening fence tagged ``sql`` (any letter case). A ``sqlite``/``sqlite3`` tag
# is also consumed when whitespace or a backtick follows it, so the tail of
# the tag does not end up inside the captured query.
_SQL_FENCE_RE = re.compile(
    r"```sql(?:ite\d*(?=\s|`))?\s*(.*?)\s*```",
    re.DOTALL | re.IGNORECASE,
)


def extract_sql_query(text: str) -> str:
    """Extract the SQL query from the first ```sql ... ``` fenced block.

    Args:
        text: Model reply that may wrap a query in a Markdown code fence
            tagged ``sql`` (``sqlite``/``sqlite3`` are accepted as well).

    Returns:
        The fenced query with surrounding whitespace stripped, or an empty
        string when ``text`` is empty/``None`` or contains no such block.
    """
    if not text:
        # Nothing to search; also avoids a TypeError from re on None.
        return ""
    match = _SQL_FENCE_RE.search(text)
    return match.group(1).strip() if match else ""


In [10]:
# Keep only the fenced SQL from the model's reply, dropping the surrounding explanation.
print(extract_sql_query(sql_query.content))

SELECT er.RegistrationID
FROM EventRegistrations er
JOIN Events e ON er.EventID = e.EventID
WHERE e.EventName = 'hackathon';
