In [6]:
import pkg.dbs.postgres
from pkg import config
from typing import Dict, Optional, List
from langchain_anthropic import ChatAnthropic  # Updated import
from langchain.prompts import ChatPromptTemplate
from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field
import json
import os

In [None]:
# Load the API key through the project's config helper; it is later handed to
# SQLQueryGenerator, which forwards it to ChatAnthropic as anthropic_api_key.
# NOTE(review): the helper name is plural (get_API_keys) — confirm it returns a
# single key string rather than a collection of keys.
api_key = config.get_API_keys()

In [3]:
class SQLResponse(BaseModel):
    """Structured result of one text-to-SQL request.

    Both fields default to ``None`` so that a reply carrying only one of them
    (a query on success, an explanation on failure) still validates. Without an
    explicit default, Pydantic v2 treats ``Optional[str]`` fields as required.
    """

    # SQL text produced by the model; None when no query was generated.
    query: Optional[str] = Field(
        default=None,
        description="The generated SQL query",
    )
    # Reason given when generation (or response parsing) fails; None otherwise.
    explanation: Optional[str] = Field(
        default=None,
        description="Explanation if query generation fails",
    )

In [None]:
class SQLQueryGenerator:
    """Generate SQL from natural-language questions by prompting a Claude chat model.

    The prompt is assembled from the table schema, example question/query
    pairs, foreign-key information and DB type information supplied by the
    caller; the model's reply is parsed into a ``SQLResponse``.
    """

    def __init__(self,
                 api_key: str,
                 model: str = "claude-3-sonnet-20240229",
                 max_tokens: int = 1024):
        """
        Set up the chat model, the output parser and the prompt template.

        Args:
            api_key: Anthropic API key passed to ``ChatAnthropic``.
            model: Model identifier passed to ``ChatAnthropic``.
            max_tokens: Token limit passed to ``ChatAnthropic``.
        """
        self.llm = ChatAnthropic(
            model=model,
            anthropic_api_key=api_key,
            max_tokens=max_tokens,
        )
        self.output_parser = PydanticOutputParser(pydantic_object=SQLResponse)

        # Prompt template. Single-brace names are placeholders filled in by
        # generate_query; the doubled braces around the JSON example are
        # escapes so they are emitted as literal braces.
        self.system_template = """You are a text to SQL query expert. Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions.

                                    Table Schema:
                                    {schema}

                                    Example Q&A Pairs:
                                    {examples}

                                    Foreign Key Information:
                                    {foreign_keys}
                                    
                                    DB Types Information:
                                    {types}

                                    Response Guidelines:
                                    1. If the provided context is sufficient, please generate a valid query without any explanations for the question.
                                    2. If the provided context is insufficient, please explain why it can't be generated.
                                    3. Please use the most relevant table(s).
                                    4. Please format the query before responding.

                                    The output should be a JSON object with the following format:
                                    {{
                                        "query": "A generated SQL query when context is sufficient.",
                                        "explanation": "An explanation of failing to generate the query."
                                    }}

                                    Question: {question}"""

    def _format_examples(self, examples: List[Dict[str, str]]) -> str:
        """
        Render example Q&A pairs as text for the prompt.

        Each example becomes ``### <question>``, a newline, the query, and a
        blank line. An empty (or ``None``) list yields an empty string.

        Args:
            examples: Dicts with ``"question"`` and ``"query"`` keys.

        Returns:
            The concatenated example text.

        Raises:
            KeyError: If an example lacks ``"question"`` or ``"query"``.
        """
        if not examples:
            return ""
        return "".join(
            f"### {example['question']}\n{example['query']}\n\n"
            for example in examples
        )

    def generate_query(self,
                       question: str,
                       tables: str,
                       examples: List[Dict[str, str]],
                       foreign_keys: str,
                       types: str = "") -> "SQLResponse":
        """
        Generate a SQL query based on the natural language question.

        Args:
            question: The natural language question
            tables: Tables Information
            examples: List of example Q&A pairs
            foreign_keys: Foreign key relationships between tables
            types: User-defined types information. Optional; when omitted the
                "DB Types Information" section of the prompt is left empty.

        Returns:
            SQLResponse object containing the generated query or explanation.
            If the model's reply cannot be parsed, the returned object has
            ``query=None`` and an explanation starting with
            "Failed to parse response:".

        Note:
            Only parsing is guarded; an exception raised by the model call
            itself propagates to the caller.
        """
        # Build the template and fill in every placeholder with caller context.
        prompt = ChatPromptTemplate.from_template(self.system_template)
        formatted_prompt = prompt.format(
            schema=tables,
            examples=self._format_examples(examples),
            foreign_keys=foreign_keys,
            types=types,
            question=question,
        )

        # Generate response using Claude
        response = self.llm.invoke(formatted_prompt)

        # Parse the reply into a SQLResponse; a malformed reply is reported
        # through the explanation field instead of raising.
        try:
            return self.output_parser.parse(response.content)
        except Exception as e:
            return SQLResponse(
                query=None,
                explanation=f"Failed to parse response: {str(e)}"
            )

In [None]:
# Connect to Postgres using the [postgres] section of the project config file.
conf = config.getConfig("./config/config.ini")
pconf = conf['postgres']
psql = pkg.dbs.postgres.PSQL(
    pconf['host'],
    int(pconf['port']),
    pconf['db'],
    pconf['user'],
    pconf['pw'],
)

# Tables
tables = psql.get_db_schema()

# Foreign Keys
# NOTE(review): "get_foregin_keys" looks like a misspelling of "foreign". It is
# kept as-is because PSQL is defined outside this notebook — confirm the real
# method name before renaming.
foreign_keys = psql.get_foregin_keys()

# Types
typs = psql.get_db_types()

# Initialize with your API key
generator = SQLQueryGenerator(api_key)


# Define example Q&A pairs
examples = [
    {
        "question": "How many people whose name is Hong Gil-dong are there?",
        "query": "SELECT count(*) FROM 인적정보 WHERE Name = '홍길동';"
    },
    {
        "question": "How many certificates have the name 통신자격증 are there?",
        "query": "SELECT count(*) FROM 자격증 WHERE Name = '통신자격증';"
    },
    {
        "question": "How many ReqGrade 2 are there?",
        "query": "SELECT count(*) FROM 병과 WHERE ReqGrade = 2;"
    }
]


# Generate query for a sample question.
# The DB type information must be passed as the fifth argument: it fills the
# "DB Types Information" section of the prompt, and omitting it does not match
# generate_query's five-parameter signature.
question = "What are the Name and Affiliation IDs of people whose 신체검사 Grade 1 or 2?"
result = generator.generate_query(question, tables, examples, foreign_keys, typs)

# Print the query, or the explanation when no query was produced.
print("Generated SQL Query:")
print(result.query if result.query else result.explanation)


Generated SQL Query:
SELECT COUNT(*) FROM 부대 WHERE Organization LIKE '%사단%';
