In [1]:
import ollama 
import json

In [2]:
# Create the Ollama client used by the cells below; no arguments are passed,
# so the library's default connection settings apply.
client = ollama.Client()

In [36]:
# Ask the "sqls" model to turn the natural-language question into SQL.
sql = client.generate(
    model="sqls",
    prompt="What is the total sales for each region segment by product type?",
)

# The reply text is parsed as JSON; only its "sql_ans" field is kept.
sql_payload = json.loads(sql.response)
generated_sql = sql_payload['sql_ans']

In [None]:
# Notebook display: show the value taken from the "sql_ans" field of the model reply.
generated_sql

In [38]:
# What the user picked in the UI: the columns to show and the values to
# filter on, keyed by column name.
user_selections = {
    "columns": ["region", "total"],
    "selected_values": {"region": ["East", "South"]},
}

# Request for the checker step: the selections, the original question and
# the SQL generated for it.
query = {
    "user_selections": user_selections,
    "user_query": "What is the total sales for each region segment by product type?",
    "generated_sql": generated_sql,
}

In [39]:
# Serialize the request dict to a JSON string so it can be used as a model prompt.
query_str = json.dumps(query)

In [40]:
# Send the JSON-encoded request to the "checker" model and keep its reply in `ver`.
ver = client.generate(model="checker", prompt=query_str)

In [None]:
# Notebook display: show the request dict that was sent to the checker.
query

In [None]:
# Notebook display: parse the checker reply as JSON and show its "updated_sql" field.
# Raises if the reply is not valid JSON or the field is missing.
json.loads(ver.response)['updated_sql']

In [1]:
class SQLQueryValidator:
    """
    Generate SQL from natural-language questions and reconcile it with user selections.

    Two models are queried through the client: ``"sqls"`` produces an initial
    SQL query from the user's question, and ``"checker"`` returns an updated
    query given the question, the generated SQL and the user's selections.
    Both replies are parsed as JSON.

    The methods that talk to a model never raise; they return
    ``{'success': True, 'sql_query': ...}`` on success and
    ``{'success': False, 'error': ...}`` on failure.
    """

    def __init__(self, client=None):
        """
        Initialize the SQLQueryValidator with an Ollama client.

        Args:
            client: Object exposing ``generate(model=..., prompt=...)`` whose
                result has a ``response`` attribute, typically an
                ``ollama.Client``. When omitted, a default
                ``ollama.Client()`` is created.
        """
        if client is None:
            # Imported lazily so the class stays usable with an injected
            # client even where the ollama package is not available.
            import ollama

            client = ollama.Client()
        # Use the caller's client so its host/options configuration is honoured.
        self.client = client

    def _ask_model(self, model, prompt):
        """Send *prompt* to *model* and return the reply text parsed as JSON."""
        reply = self.client.generate(model=model, prompt=prompt)
        return json.loads(reply.response)

    def generate_initial_sql(self, user_query: str) -> dict:
        """
        Generate initial SQL query based on user's natural language query.

        Args:
            user_query (str): Natural language query from user

        Returns:
            dict: ``{'success': True, 'sql_query': <sql>}`` on success,
            otherwise ``{'success': False, 'error': <message>}``. Client and
            JSON-parsing errors are reported through the error message rather
            than raised, matching ``validate_and_update_sql``.
        """
        try:
            response = self._ask_model("sqls", user_query)
        except Exception as e:
            # Boundary handler: callers rely on a result dict, never an exception.
            return {
                'success': False,
                'error': str(e)
            }

        # Valid JSON that is not an object (list, number, null) is a failed
        # generation too, not just an object without the expected key.
        if not isinstance(response, dict) or 'sql_ans' not in response:
            return {
                'success': False,
                'error': 'Failed to generate SQL query'
            }

        # The string 'nan' in "sql_ans" is treated as "no SQL could be produced".
        if response['sql_ans'] == 'nan':
            return {
                'success': False,
                'error': 'Invalid query or question not related to database'
            }

        return {
            'success': True,
            'sql_query': response['sql_ans']
        }

    def create_query_object(self, user_query, generated_sql, columns, selected_values):
        """
        Create a query object for validation.

        Args:
            user_query (str): Original user query
            generated_sql (str): Generated SQL query
            columns (list): Selected columns
            selected_values (dict): Selected filter values

        Returns:
            dict: Query object with the keys ``user_selections`` (holding
            ``columns`` and ``selected_values``), ``user_query`` and
            ``generated_sql``.
        """
        return {
            "user_selections": {
                "columns": columns,
                "selected_values": selected_values
            },
            "user_query": user_query,
            "generated_sql": generated_sql
        }

    def validate_and_update_sql(self, query_object: dict) -> dict:
        """
        Validate and update SQL query based on user selections.

        Args:
            query_object (dict): Query object containing user selections and SQL

        Returns:
            dict: ``{'success': True, 'sql_query': <updated sql>}`` on success,
            otherwise ``{'success': False, 'error': <message>}``.
        """
        try:
            # The checker model receives the whole query object as a JSON prompt;
            # serialization errors are reported like any other failure.
            response = self._ask_model("checker", json.dumps(query_object))
        except Exception as e:
            # Boundary handler: callers rely on a result dict, never an exception.
            return {
                'success': False,
                'error': str(e)
            }

        if not isinstance(response, dict) or 'updated_sql' not in response:
            return {
                'success': False,
                'error': 'Failed to validate SQL query'
            }

        return {
            'success': True,
            'sql_query': response['updated_sql']
        }

# Example usage:
# validator = SQLQueryValidator(client)
# initial_response = validator.generate_initial_sql("What is the total sales for each region segment by product type?")
# if initial_response['success']:
#     query_obj = validator.create_query_object(
#         "What is the total sales for each region segment by product type?",
#         initial_response['sql_query'],
#         ["region", "total"],
#         {"region": ["East", "South"]}
#     )
#     final_response = validator.validate_and_update_sql(query_obj)



In [None]:
# Example usage: generate SQL for one question, then have the checker
# reconcile it with the user's column and filter selections.
validator = SQLQueryValidator(client)
question = "What is the total sales for each region segment by product type?"

initial_response = validator.generate_initial_sql(question)
if initial_response['success']:
    # Only run the checker when the first step produced SQL.
    query_obj = validator.create_query_object(
        user_query=question,
        generated_sql=initial_response['sql_query'],
        columns=["region", "total"],
        selected_values={"region": ["East", "South"]},
    )
    final_response = validator.validate_and_update_sql(query_obj)


In [None]:
# Import the OllamaLLM class from the local models package (models/llm_ollama)
from models.llm_ollama import OllamaLLM

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("yasserrmd/Text2SQL-1.5B")
model = AutoModelForCausalLM.from_pretrained("yasserrmd/Text2SQL-1.5B")

# Define the pipeline
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Define system instruction
system_instruction = "Always separate code and explanation. Return SQL code in a separate block, followed by the explanation in a separate paragraph. Use markdown triple backticks (```sql for SQL) to format the code properly. Write the SQL query first in a separate code block. Then, explain the query in plain text. Do not merge them into one response. The query should always include the table structure using a CREATE TABLE statement before executing the main SQL query."

# Define user query: the question followed by the table definitions it refers to.
# The text spans several lines, so it must be a triple-quoted string; a
# single-quoted literal cannot contain raw newlines and is a SyntaxError.
user_query = """Show the total sales for each customer who has spent more than $50,000.
CREATE TABLE sales (
    id INT PRIMARY KEY,
    customer_id INT,
    total_amount DECIMAL(10,2),
    FOREIGN KEY (customer_id) REFERENCES customers(id)
);

CREATE TABLE customers (
    id INT PRIMARY KEY,
    name VARCHAR(255)
);
"""

# Define messages for input
messages = [
    {"role": "system", "content": system_instruction},
    {"role": "user", "content": user_query},
]

# Generate SQL output
response = pipe(messages)

# With chat-style (list of messages) input, 'generated_text' holds the whole
# conversation and the model's reply is its last message; with plain-string
# input it is already the generated text.
generated = response[0]['generated_text']
if isinstance(generated, list):
    generated = generated[-1]['content']

# Print the generated SQL query
print(generated)
