In [None]:
import json
import os
import re
import sqlite3
import time
import warnings
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import AIMessage, HumanMessage
from langchain.vectorstores import Chroma
from openai import AzureOpenAI

In [None]:
# Containers for the schema description ("training plan") that is passed to the LLM
class TrainingPlanItem:
    """
    A single piece of schema documentation destined for the LLM prompt.

    In this notebook each item describes one table: ``item_group`` is the
    database name, ``item_name`` the table name and ``item_value`` the
    rendered column description.
    """

    # Type tag for items built from schema (information_schema-style) data.
    ITEM_TYPE_IS = "Information_Schema"

    def __init__(self, item_type, item_group, item_name, item_value):
        # item_group is dynamic, e.g. the database name.
        self.item_type, self.item_group, self.item_name, self.item_value = (
            item_type,
            item_group,
            item_name,
            item_value,
        )

class TrainingPlan:
    """Ordered collection of training-plan items (objects exposing
    ``item_group``, ``item_name`` and ``item_value``)."""

    def __init__(self, plan=None):
        """
        Args:
            plan: optional list of items to start from. A caller-supplied list
                is used as-is (not copied). When omitted, a fresh empty list is
                created for this instance.
        """
        # A literal ``[]`` default would be evaluated once and shared by every
        # TrainingPlan created without arguments, leaking items between plans.
        self._plan = [] if plan is None else plan

    def add_item(self, item):
        """Append ``item`` to the end of the plan."""
        self._plan.append(item)

    @staticmethod
    def _format_item(item):
        """Render one item as the 'Group/Name/Value' text block (no trailing blank line)."""
        return f"Group: {item.item_group}\nName: {item.item_name}\nValue:\n{item.item_value}\n"

    def display_plan(self):
        """Print every item, separated by blank lines."""
        for item in self._plan:
            print(self._format_item(item))

    def get_plan_as_text(self):
        """
        Return the training plan as a string, which can be passed to the LLM or other systems.

        Each item is rendered as 'Group/Name/Value' lines followed by a blank line;
        an empty plan yields ''.
        """
        return "".join(self._format_item(item) + "\n" for item in self._plan)

# Function to extract schema information from any SQLite database
def get_training_plan_generic(cursor, conn) -> TrainingPlan:
    """
    Build a TrainingPlan describing every table of a SQLite database.

    For each row of ``sqlite_master`` with ``type='table'`` one TrainingPlanItem
    is added. Its value is an introductory sentence plus a markdown table of the
    columns reported by ``PRAGMA table_info``: name, declared type, NOT NULL
    flag and default value.

    Args:
        cursor: sqlite3 cursor on the database to describe.
        conn: the owning connection; not used here, kept for API compatibility.

    Returns:
        TrainingPlan with one item per table, grouped under the database name.
    """
    plan = TrainingPlan([])

    # PRAGMA database_list rows are (seq, name, file). SQLite has no real
    # database name, so use the file name of the first ("main") database. The
    # file column is empty for in-memory databases; fall back to the schema
    # name (e.g. "main") then.
    cursor.execute("PRAGMA database_list;")
    _seq, schema_name, db_file = cursor.fetchall()[0]
    # Split on both separators so Windows and POSIX paths both give the bare file name.
    database_name = re.split(r"[\\/]", db_file)[-1] if db_file else schema_name

    # Materialize the table list before reusing the cursor inside the loop.
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [row[0] for row in cursor.fetchall()]

    for table_name in table_names:
        # Quote the identifier so names with spaces or quotes don't break the PRAGMA.
        quoted_name = '"' + table_name.replace('"', '""') + '"'
        cursor.execute(f"PRAGMA table_info({quoted_name});")
        # table_info rows: (cid, name, type, notnull, dflt_value, pk)
        column_data = [
            [column[1], column[2], column[3], column[4]]
            for column in cursor.fetchall()
        ]

        df_columns = pd.DataFrame(
            column_data,
            columns=["Column Name", "Data Type", "Not Null", "Default Value"],
        )
        doc = (
            f"The following columns are in the {table_name} table in the {database_name} database:\n\n"
            + df_columns.to_markdown()
        )

        plan.add_item(
            TrainingPlanItem(
                item_type=TrainingPlanItem.ITEM_TYPE_IS,
                item_group=database_name,
                item_name=table_name,
                item_value=doc,
            )
        )

    return plan

In [None]:
def _to_chat_message(msg):
    """Convert one history entry to an OpenAI chat message dict, or None to skip it."""
    # ChatManager stores history as {"role": ..., "content": ...} dicts.
    if isinstance(msg, dict):
        role = msg.get("role")
        if role in ("user", "assistant"):
            return {"role": role, "content": msg.get("content", "")}
        return None
    # LangChain message objects are also accepted.
    if isinstance(msg, HumanMessage):
        return {"role": "user", "content": msg.content}
    if isinstance(msg, AIMessage):
        return {"role": "assistant", "content": msg.content}
    return None


def llm_response(client, user_question, ddl_schema, chatHistory, model="gpt-4o-mini"):
    """
    Ask the chat model to turn ``user_question`` into SQL for the given schema.

    Args:
        client: OpenAI/AzureOpenAI client exposing ``chat.completions.create``.
        user_question: the natural-language question to answer.
        ddl_schema: text description of the tables, embedded in the system prompt.
        chatHistory: prior turns, either ``{"role", "content"}`` dicts or
            LangChain HumanMessage/AIMessage objects. Other entries are ignored.
        model: model / deployment name passed to the completion call.

    Returns:
        The raw text content of the model's first choice. It may contain SQL,
        markdown fences or an explanation; see ``extract_sql``.
    """
    initial_prompt = 'You are a SQL expert. Please help to generate a SQL query to answer the question. Your response should only be based on the given context and follow the response guidelines and format instructions.'
    response_guidelines = (
        "\n===Response Guidelines \n"
        "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
        "2. If the provided context is insufficient, please explain why it can't be generated. \n"
        "3. Please use the most relevant table(s). \n"
        "4. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
        "5. Ensure that the output SQL is SQL-compliant and executable, and free of syntax errors. \n"
    )

    system_message = initial_prompt + "\n\n===Tables and their details:\n" + ddl_schema + response_guidelines
    messages = [{"role": "system", "content": system_message}]

    for msg in chatHistory or []:
        converted = _to_chat_message(msg)
        if converted is not None:
            messages.append(converted)

    # A caller may already have recorded the current question as the last
    # history turn; don't send it twice.
    last = messages[-1]
    if not (last["role"] == "user" and last["content"] == user_question):
        messages.append({"role": "user", "content": user_question})

    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0.0,
    )
    return response.choices[0].message.content

def extract_sql(llm_response):
    """
    Extract the SQL query from an LLM response that may also contain prose.

    Strategies, tried in order (the LAST match of the first successful one wins):
      1. a CTE: from ``WITH`` up to the first following ``;``;
      2. a plain query: from ``SELECT`` up to the first following ``;``;
      3. a markdown code fence, with or without an ``sql`` tag.

    Newlines in the result are replaced with spaces and surrounding whitespace
    is stripped. Override this function if your LLM responses need custom
    extraction logic.

    Returns:
        The SQL string, or None when no SQL is found (e.g. the model explained
        why it could not answer, or the response is empty/None).
    """
    if not llm_response:
        return None

    def _flatten(sql):
        return sql.replace("\n", " ").strip()

    # If the response contains a CTE (WITH clause), take the last one.
    sqls = re.findall(r"\bWITH\b .*?;", llm_response, re.DOTALL)
    if sqls:
        return _flatten(sqls[-1])

    # Not markdown formatted: take the last SELECT ... ; statement.
    sqls = re.findall(r"SELECT.*?;", llm_response, re.DOTALL)
    if sqls:
        return _flatten(sqls[-1])

    # Markdown fence(s), with or without an ``sql`` tag. Non-greedy so that each
    # fence is matched separately and the LAST block is returned, instead of
    # everything between the first opening and the last closing fence.
    blocks = re.findall(
        r"```(?:sql)?[ \t]*\n?(.*?)```", llm_response, re.DOTALL | re.IGNORECASE
    )
    blocks = [block for block in blocks if block.strip()]
    if blocks:
        return _flatten(blocks[-1])

    return None

def run_sql(query, connection):
    """
    Run ``query`` on ``connection`` and return the full result set.

    Thin wrapper around ``pandas.read_sql_query``; any execution error raised
    by pandas propagates to the caller.
    """
    return pd.read_sql_query(query, connection)

In [None]:
def generate_followup_questions(client, question, sql, df, n_questions = 3, model="gpt-4o-mini", **kwargs):
    """
    Ask the chat model for follow-up questions about a query result.

    Args:
        client: OpenAI/AzureOpenAI client exposing ``chat.completions.create``.
        question: the question the user originally asked.
        sql: the SQL generated for that question.
        df: the query result; only ``df.to_markdown()`` is used, to embed it in the prompt.
        n_questions: how many questions to request. The model is not forced to
            comply, so the returned list may be shorter or longer.
        model: model / deployment name passed to the completion call.
        **kwargs: accepted for compatibility; unused.

    Returns:
        List of non-empty question strings, with leading list markers
        (``1.``, ``1)``, ``-``, ``*``, ``•``) removed.
    """
    system_message = f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
    user_message = f"Generate a list of {n_questions} followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]

    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0.0,
    )

    # content can be None (e.g. filtered responses); treat it as "no questions".
    text = response.choices[0].message.content or ""

    questions = []
    for line in text.splitlines():
        # Strip numbering/bullets; drop blank lines so callers don't get empty "questions".
        cleaned = re.sub(r"^\s*(?:\d+[.)]|[-*\u2022])\s*", "", line).strip()
        if cleaned:
            questions.append(cleaned)
    return questions

In [15]:
# Set up the database connection.
# The path can be overridden with the CHINOOK_DB_PATH environment variable
# instead of editing this cell; the original location is the fallback.
DB_PATH = os.environ.get(
    "CHINOOK_DB_PATH",
    "C://Users//VRBRAHMB//Documents//sqlite-tools//chinook.db",
)
# sqlite3.connect silently creates an empty database for a missing file, which
# would later yield an empty schema; fail loudly instead.
if DB_PATH != ":memory:" and not os.path.exists(DB_PATH):
    raise FileNotFoundError(f"SQLite database not found: {DB_PATH}")
conn = sqlite3.connect(DB_PATH)
cursor = conn.cursor()


In [None]:
# Describe every table once; the resulting text is passed as `ddl_schema`
# into each SQL-generation request.
training_plan = get_training_plan_generic(cursor=cursor, conn=conn)
ddl_schema = training_plan.get_plan_as_text()

In [38]:
def get_azureopenai_cleint(key, version, endpoint, deployment):
    """
    Construct an ``AzureOpenAI`` client.

    Args:
        key: Azure OpenAI API key.
        version: API version string.
        endpoint: Azure resource endpoint URL.
        deployment: name of the model deployment to target.
    """
    return AzureOpenAI(
        api_key=key,
        api_version=version,
        azure_endpoint=endpoint,
        azure_deployment=deployment,
    )

In [None]:
class ChatManager:
    """
    Per-chat text-to-SQL assistant backed by a semantic answer cache.

    Every chat ID gets its own directory under ``base_directory`` holding:
      * ``chroma_db``: a Chroma vector store caching previously answered
        (normalized) questions with their SQL and result rows as metadata;
      * ``chat_history.json``: the user/assistant turns for that chat.
    """

    def __init__(self, base_directory: str = "./chats"):
        self.base_directory = base_directory
        # Embedding function Chroma uses for stored and incoming questions.
        self.embed_model = HuggingFaceEmbeddings(
            model_name='sentence-transformers/all-MiniLM-L6-v2'
        )
        self.active_chat_id = None
        self.active_vectorstore = None
        # chat_id -> list of {"role": "user"|"assistant", "content": str} dicts.
        self.chat_history: Dict[str, List] = {}

        # Create base directory if it doesn't exist.
        os.makedirs(base_directory, exist_ok=True)

        # Chat IDs that already have a directory on disk.
        self.existing_chats = self._load_existing_chats()

    def _load_existing_chats(self) -> List[str]:
        """Return the names of all sub-directories of the base directory (one per chat)."""
        if not os.path.exists(self.base_directory):
            return []
        return [d for d in os.listdir(self.base_directory)
                if os.path.isdir(os.path.join(self.base_directory, d))]

    def _get_chat_directory(self, chat_id: str) -> str:
        """Get the directory path for a specific chat."""
        return os.path.join(self.base_directory, chat_id)

    def _normalize_question(self, question: str) -> str:
        """Lower-case and collapse whitespace so trivially different phrasings match."""
        return ' '.join(question.lower().split())

    def initialize_chat(self, chat_id: str) -> None:
        """Create (if needed) and activate the chat ``chat_id``, loading its history from disk."""
        chat_dir = self._get_chat_directory(chat_id)

        if chat_id not in self.existing_chats:
            os.makedirs(chat_dir, exist_ok=True)
            self.existing_chats.append(chat_id)

        # An existing chat directory may have no history file yet (e.g. the chat
        # was opened but never received an LLM answer). Guarantee an entry so
        # ask() cannot hit a KeyError.
        self.chat_history.setdefault(chat_id, [])

        # Initialize or switch to the chat's vector store.
        self.active_vectorstore = Chroma(
            embedding_function=self.embed_model,
            persist_directory=os.path.join(chat_dir, "chroma_db")
        )
        self.active_chat_id = chat_id

        history_path = os.path.join(chat_dir, "chat_history.json")
        if os.path.exists(history_path):
            with open(history_path, 'r') as f:
                self.chat_history[chat_id] = json.load(f)

    def cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Cosine similarity of two vectors; 0 if either has zero norm."""
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        return np.dot(vec1, vec2) / (norm1 * norm2) if norm1 and norm2 else 0

    def save_chat_history(self) -> None:
        """Write the active chat's history to its ``chat_history.json`` (no-op without an active chat)."""
        if self.active_chat_id:
            history_path = os.path.join(
                self._get_chat_directory(self.active_chat_id),
                "chat_history.json"
            )
            with open(history_path, 'w') as f:
                json.dump(self.chat_history[self.active_chat_id], f)

    def find_similar_question(self, question: str,
                              question_embedding: Optional[List[float]] = None,
                              similarity_threshold: float = 0.85) -> Tuple[bool, Optional[Dict]]:
        """
        Look up the most similar cached question in the active vector store.

        Args:
            question: normalized question text; Chroma embeds it with the
                store's embedding function.
            question_embedding: accepted for backward compatibility; unused.
            similarity_threshold: minimum relevance score counted as a hit.

        Returns:
            ``(True, metadata)`` for a hit, otherwise ``(False, None)``. Lookup
            failures are reported with a warning and treated as a miss, because
            the cache is an optimization only.
        """
        try:
            # Only emptiness matters here; avoid loading every cached result row.
            collection = self.active_vectorstore.get(limit=1)
            if not collection['ids']:
                return False, None

            results = self.active_vectorstore.similarity_search_with_relevance_scores(
                question,
                k=1  # Get the most similar result
            )
        except Exception as exc:
            warnings.warn(f"Similarity cache lookup failed: {exc!r}")
            return False, None

        if not results:
            return False, None

        most_similar_doc, similarity_score = results[0]
        if similarity_score >= similarity_threshold:
            return True, most_similar_doc.metadata
        return False, None

    def ask(self, client, question: str, db_connection, ddl_schema,
            similarity_threshold: float = 0.85) -> Tuple[str, pd.DataFrame]:
        """
        Answer ``question`` in the context of the active chat.

        A sufficiently similar, previously answered question is served from the
        cache, including its stored result rows. These rows are not re-queried,
        so they may be stale if the database changed. Otherwise the LLM
        generates SQL, which is run on ``db_connection``. The question, SQL and
        results are then cached and the turn is appended to the chat history.

        Returns:
            ``(sql_query, results)``.

        Raises:
            ValueError: no chat is active, or the LLM response contained no SQL.
                In the second case the message includes the LLM's reply.
            Whatever ``run_sql`` raises when the generated SQL fails.
        """
        if not self.active_chat_id:
            raise ValueError("No active chat session. Please initialize a chat first.")

        normalized_question = self._normalize_question(question)

        found_similar, metadata = self.find_similar_question(
            normalized_question,
            similarity_threshold=similarity_threshold,
        )
        if found_similar and metadata:
            return metadata['text2sql'], pd.DataFrame(json.loads(metadata['df']))

        chat_history = self.chat_history.setdefault(self.active_chat_id, [])

        # Send only the prior turns; the current question is appended by llm_response.
        response = llm_response(client, question, ddl_schema, chat_history)
        sql_query = extract_sql(response)
        if sql_query is None:
            raise ValueError(f"The model did not return a SQL query: {response}")
        results = run_sql(sql_query, db_connection)

        # Cache the question and its results in the vector store.
        self.active_vectorstore.add_texts(
            texts=[normalized_question],
            metadatas=[{
                'text2sql': sql_query,
                'df': results.to_json(orient='records'),
                'original_question': question
            }],
            ids=[f"{self.active_chat_id}_{str(time.time())}"]
        )
        self.active_vectorstore.persist()

        # Record the turn only after it succeeded, so failures leave no orphan user message.
        chat_history.append({"role": "user", "content": question})
        chat_history.append({"role": "assistant", "content": sql_query})
        self.save_chat_history()

        return sql_query, results

In [39]:
def main():
    """
    Interactive loop for the text-to-SQL chat.

    Prompts for a chat ID, then answers questions in that chat. Typing 'switch'
    picks another chat; 'exit' (or 'quit' at the question prompt) stops. Uses
    the module-level ``conn`` and ``ddl_schema`` from earlier cells.

    Azure OpenAI settings are read from the environment variables
    AZURE_OPENAI_API_KEY, AZURE_OPENAI_API_VERSION, AZURE_OPENAI_ENDPOINT and
    AZURE_OPENAI_DEPLOYMENT. The original placeholder strings are used when
    they are unset, so no secret needs to be written into the notebook.
    """
    chat_manager = ChatManager()

    key = os.environ.get("AZURE_OPENAI_API_KEY", "your_api_key")
    version = os.environ.get("AZURE_OPENAI_API_VERSION", "version_name")
    endpoint = os.environ.get("AZURE_OPENAI_ENDPOINT", "your_endpoint")
    deployment = os.environ.get("AZURE_OPENAI_DEPLOYMENT", "deployment_name")
    client = get_azureopenai_cleint(key, version, endpoint, deployment)

    while True:
        chat_id = input("Enter chat ID (or 'exit' to quit): ").strip()

        if chat_id.lower() == 'exit':
            break
        if not chat_id:
            # An empty ID would place the chat's files directly in the base directory.
            continue

        chat_manager.initialize_chat(chat_id)
        print(f"Active chat session: {chat_id}")

        while True:
            user_question = input(f"[Chat {chat_id}] Ask your question (or 'switch' to change chat): ").strip()

            if user_question.lower() == 'switch':
                break
            if user_question.lower() in ['exit', 'quit']:
                return
            if not user_question:
                continue

            start_time = time.time()
            try:
                sql_query, results = chat_manager.ask(client,
                    user_question,
                    conn,
                    ddl_schema
                )
            except (ValueError, sqlite3.Error, pd.errors.DatabaseError) as exc:
                # A question the model can't answer, or SQL that fails to run,
                # shouldn't end the whole session.
                print(f"\nCould not answer the question: {exc}\n")
                continue
            end_time = time.time()

            print(f"\nGenerated SQL Query: {sql_query}")
            print(f"Query Results: {results}\n")
            print(f'Time taken: {end_time - start_time:.2f} seconds')

            followup_questions = generate_followup_questions(
                client=client,
                question=user_question,
                sql=sql_query,
                df=results,
                n_questions=3
            )
            print("\nSuggested follow-up questions:")
            for i, q in enumerate(followup_questions, 1):
                print(f"{i}. {q}")

In [None]:
# Entry point: start the interactive chat loop. In an IPython/Jupyter kernel the
# cell namespace is typically "__main__" too, so running this cell starts the loop.
if __name__ == "__main__":
    main()