# Imports

In [1]:
# Set the mode manually.
# DATA_MODE is read by the config cell further down, which picks the knowledge-base
# files, S3 directory and download-link endpoint for the chosen environment and
# raises ValueError for any value not listed here.
# Options:
#   "AHP-QA"        -> Standard QA environment (non-anonymized)
#   "AHP-PROD"      -> Standard Production environment (non-anonymized)
#   "AHP-ANON-QA"   -> QA environment with anonymized data
#   "AHP-ANON-PROD" -> Production environment with anonymized data

DATA_MODE = "AHP-QA"
# Banner so the active environment is visible in the notebook output.
print(f'\n\t*** WORKING WITH - {DATA_MODE} DATA ***\n')


	*** WORKING WITH - AHP-QA DATA ***



### ✅ Section 1: All Imports | Load Environment & Initialize DB Pool

In [2]:
# ===================================
# Generic Python Libraries
# ===================================
import os
import json
import time
import uuid
import pickle
import calendar
import re
from datetime import datetime
from textwrap import indent
from typing import cast, Dict, List, Any, Optional

# ===================================
# Data Science & Display
# ===================================
import pandas as pd
import numpy as np
from IPython.display import display, Markdown

# ===================================
# ML & Similarity
# ===================================
from sklearn.metrics.pairwise import cosine_similarity

# ===================================
# Network / API
# ===================================
import requests

# ===================================
# Langchain & Pydantic
# ===================================
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
from langchain_core.prompts import (
    HumanMessagePromptTemplate, PromptTemplate, SystemMessagePromptTemplate,
    ChatPromptTemplate, MessagesPlaceholder
)
from langchain_core.runnables.history import RunnableWithMessageHistory
from pydantic import BaseModel, Field

# ===================================
# PostgreSQL & .env
# ===================================
import psycopg2
from psycopg2 import pool
from psycopg2.extras import RealDictCursor
from dotenv import load_dotenv

### ✅ Section 2: Load Environment & Initialize DB Pool

# Load environment variables (including any defined in a local .env file) into os.environ
load_dotenv()

# Langchain's OpenAI clients pick up OPENAI_API_KEY from the environment.
# load_dotenv() has already exported it when it is defined, so re-assigning it is a
# no-op; when it is missing, the old self-assignment failed with an opaque
# "TypeError: str expected, not NoneType". Fail early with an actionable message instead.
if os.getenv('OPENAI_API_KEY') is None:
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Define it in the environment or in the .env file."
    )

# Initialize DB connection pool
def initialize_connection_pool(DB_HOST, DB_NAME, DB_PORT, DB_USER, DB_PASSWORD, minconn=1, maxconn=20):
    """
    Create a psycopg2 SimpleConnectionPool for the given PostgreSQL database.

    The connection parameters are passed as keyword arguments rather than being
    interpolated into a quoted DSN string, so credentials containing quotes,
    backslashes or spaces cannot corrupt the connection string.

    Args:
        DB_HOST: Database host name or address.
        DB_NAME: Database name.
        DB_PORT: Database port.
        DB_USER: Database user.
        DB_PASSWORD: Database password.
        minconn (int): Minimum number of connections kept in the pool (default: 1).
        maxconn (int): Maximum number of connections the pool may hand out (default: 20).

    Returns:
        psycopg2.pool.SimpleConnectionPool: The initialized pool.
    """
    return psycopg2.pool.SimpleConnectionPool(
        minconn,
        maxconn,
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
        host=DB_HOST,
        port=DB_PORT,
    )

# DB Credentials
# Each value is read from the environment; os.getenv returns None when a variable is unset.
DB_HOST = os.getenv('DB_HOST')
DB_NAME = os.getenv('DB_NAME')
DB_PORT = os.getenv('DB_PORT')
DB_USER = os.getenv('DB_USER')
DB_PASSWORD = os.getenv('DB_PASSWORD')

# Shared PostgreSQL connection pool, built with the helper's default minconn/maxconn.
connection_pool = initialize_connection_pool(DB_HOST, DB_NAME, DB_PORT, DB_USER, DB_PASSWORD)
# Embedding client created with library defaults (no explicit model or key passed here).
embedding_function = OpenAIEmbeddings()

# AWS
# NOTE(review): region_name is passed explicitly to S3Utils below, so when AWS_REGION is
# unset the value None replaces that class's 'us-east-2' default — confirm this is intended.
aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
aws_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
region_name = os.getenv("AWS_REGION")


### ✅ Section 3: Config Setup (QA/PROD/ANON) + Paths

##### Previous Tables

In [3]:
# # Set defaults
# ROBOT_TABLE = None
# FINANCE_TABLE = None

# # Table Descriptions
# ROBOT_DESC = "Robot Assisted Procedure Analysis"
# FINANCE_DESC = "Overall Financial Data"
# P_EVENT_DESC = "Procedure/Case Level Utilization Data"
# PRODUCT_LEVEL_DESC = "Product Level Utilization Data"

# # DATA_MODE-specific logic
# if DATA_MODE == "QA":
#     collection_name = 'knowledge_base_ahp_qa'
#     knowledge_base_file_name = 'ahp_qa_llm_knowledge_base.xlsx'
#     QUESTION_BANK_FILE_NAME = 'Question_Bank_For_BEE_AHP.xlsx'

#     TABLE_INFO = {
#         "production_load.l_mt_opportunity_case_procedure_p_event_with_outliers_v2_chatbot_qa": P_EVENT_DESC,
#         "production_load.l_mt_opportunity_case_procedure_with_outliers_v2_chatbot_qa": PRODUCT_LEVEL_DESC,
#         "production_load.l_mt_opportunity_case_procedure_p_event_w_outliers_rbt_chatbot_qa": ROBOT_DESC
#     }

#     s3_directory = os.getenv("S3_DIRECTORY_QA")
#     endpoint_for_redirect_link = os.getenv("API_ENDPOINT_FOR_DOWNLOAD_LINK_QA")

# elif DATA_MODE == "PROD":
#     collection_name = 'knowledge_base_ahp_prod'
#     knowledge_base_file_name = 'ahp_prod_llm_knowledge_base.xlsx'
#     QUESTION_BANK_FILE_NAME = 'Question_Bank_For_BEE_AHP.xlsx'

#     TABLE_INFO = {
#         "production_load.l_mt_opportunity_case_procedure_p_event_with_outliers_v2_chatbot_prod": P_EVENT_DESC,
#         "production_load.l_mt_opportunity_case_procedure_with_outliers_v2_chatbot_prod": PRODUCT_LEVEL_DESC,
#         "production_load.l_mt_opportunity_case_procedure_p_event_w_outliers_rbt_chatbot_prod": ROBOT_DESC
#     }

#     s3_directory = os.getenv("S3_DIRECTORY_PROD")
#     endpoint_for_redirect_link = os.getenv("API_ENDPOINT_FOR_DOWNLOAD_LINK_PROD")

# elif DATA_MODE == "ANON":
#     collection_name = 'knowledge_base_anonymized'
#     knowledge_base_file_name = 'anonymized_llm_knowledge_base.xlsx'
#     QUESTION_BANK_FILE_NAME = 'Question_Bank_For_BEE.xlsx'

#     TABLE_INFO = {
#         "production_load.l_mt_opportunity_case_procedure_p_event_w_outliers_av2_chatbot_preview": P_EVENT_DESC,
#         "production_load.l_mt_opportunity_case_procedure_with_outliers_v2_av2_chatbot_preview": PRODUCT_LEVEL_DESC,
#         "production_load.l_mt_opportunity_case_procedure_p_event_w_outliers_av2_rbt_chatbot_preview": ROBOT_DESC
#     }

#     s3_directory = os.getenv("S3_DIRECTORY_QA")
#     endpoint_for_redirect_link = os.getenv("API_ENDPOINT_FOR_DOWNLOAD_LINK_ANON_QA")
#     # endpoint_for_redirect_link = os.getenv("API_ENDPOINT_FOR_DOWNLOAD_LINK_ANON_PROD")

# else:
#     raise ValueError(f"Invalid DATA_MODE: {DATA_MODE}")

# # Assign individual tables by description
# P_EVENT_TABLE = next((k for k, v in TABLE_INFO.items() if v == P_EVENT_DESC), None)
# PRODUCT_LEVEL_TABLE = next((k for k, v in TABLE_INFO.items() if v == PRODUCT_LEVEL_DESC), None)
# FINANCE_TABLE = next((k for k, v in TABLE_INFO.items() if v == FINANCE_DESC), None)
# ROBOT_TABLE = next((k for k, v in TABLE_INFO.items() if v == ROBOT_DESC), None)

# # Directory setup
# knowledge_base_directory = './Data_For_LLM'
# persist_directory = f'{knowledge_base_directory}/{collection_name}'

# # File paths
# question_file_name = "prompt_for_complex_queries.xlsx"
# question_df = pd.read_excel(f'{knowledge_base_directory}/{question_file_name}').fillna('')
# knowledge_base_path = os.path.join(knowledge_base_directory, knowledge_base_file_name)
# question_bank_file_path = os.path.join(knowledge_base_directory, QUESTION_BANK_FILE_NAME)


# # Debug Print
# print("-----------------------------------------------------------------")
# print(f"DATA_MODE: {DATA_MODE}")
# print(f"P_EVENT_TABLE: {P_EVENT_TABLE}")
# print(f"PRODUCT_LEVEL_TABLE: {PRODUCT_LEVEL_TABLE}")
# if ROBOT_TABLE:
#     print(f"ROBOT_TABLE: {ROBOT_TABLE}")
# if FINANCE_TABLE:
#     print(f"FINANCE_TABLE: {FINANCE_TABLE}")
# print(f"Data Dictionary: {knowledge_base_file_name}")
# print(f"Collection Name: {collection_name}")
# print("-----------------------------------------------------------------")


## New Tables

In [None]:
# ===================
# Defaults and table descriptions
# ===================
ROBOT_TABLE = None
FINANCE_TABLE = None

ROBOT_DESC = "Robot Assisted Procedure Analysis"
P_EVENT_DESC = "Procedure/Case Level Utilization Data"
PRODUCT_LEVEL_DESC = "Product Level Utilization Data"

# ===================
# DATA_MODE-specific settings
# ===================
# One row per supported mode:
#   (collection name, knowledge-base workbook, question-bank workbook,
#    env var holding the S3 directory, env var holding the download-link endpoint)
_MODE_SETTINGS = {
    "AHP-QA": (
        'knowledge_base_ahp_qa',
        'ahp_qa_llm_knowledge_base.xlsx',
        'Question_Bank_For_BEE_AHP.xlsx',
        "S3_DIRECTORY_QA",
        "API_ENDPOINT_FOR_DOWNLOAD_LINK_QA",
    ),
    "AHP-PROD": (
        'knowledge_base_ahp_prod',
        'ahp_prod_llm_knowledge_base.xlsx',
        'Question_Bank_For_BEE_AHP.xlsx',
        "S3_DIRECTORY_PROD",
        "API_ENDPOINT_FOR_DOWNLOAD_LINK_PROD",
    ),
    "AHP-ANON-QA": (
        'knowledge_base_ahp_anonymized_qa',
        'ahp_qa_anonymized_llm_knowledge_base.xlsx',
        'Question_Bank_For_BEE.xlsx',
        "S3_DIRECTORY_QA",
        "API_ENDPOINT_FOR_DOWNLOAD_LINK_ANON_QA",
    ),
    "AHP-ANON-PROD": (
        'knowledge_base_ahp_anonymized_prod',
        'ahp_prod_anonymized_llm_knowledge_base.xlsx',
        'Question_Bank_For_BEE.xlsx',
        "S3_DIRECTORY_PROD",
        "API_ENDPOINT_FOR_DOWNLOAD_LINK_ANON_PROD",
    ),
}

if DATA_MODE not in _MODE_SETTINGS:
    raise ValueError(f"Invalid DATA_MODE: {DATA_MODE}")

(
    collection_name,
    knowledge_base_file_name,
    QUESTION_BANK_FILE_NAME,
    _s3_directory_env_var,
    _endpoint_env_var,
) = _MODE_SETTINGS[DATA_MODE]

# Every mode currently maps the same three tables to their descriptions.
TABLE_INFO = {
    "cqo_chatbot_case_level": P_EVENT_DESC,
    "cqo_chatbot_product_level": PRODUCT_LEVEL_DESC,
    "cqo_chatbot_robotics": ROBOT_DESC
}

s3_directory = os.getenv(_s3_directory_env_var)
endpoint_for_redirect_link = os.getenv(_endpoint_env_var)

# ===================
# Resolve individual tables from their descriptions
# ===================
# Reverse lookup (description -> first table name registered with that description).
_TABLE_BY_DESC = {}
for _table_name, _table_desc in TABLE_INFO.items():
    _TABLE_BY_DESC.setdefault(_table_desc, _table_name)

P_EVENT_TABLE = _TABLE_BY_DESC.get(P_EVENT_DESC)
PRODUCT_LEVEL_TABLE = _TABLE_BY_DESC.get(PRODUCT_LEVEL_DESC)
ROBOT_TABLE = _TABLE_BY_DESC.get(ROBOT_DESC)

# ===================
# Directories and file paths
# ===================
knowledge_base_directory = './Data_For_LLM'
persist_directory = f'{knowledge_base_directory}/{collection_name}'

question_file_name = "prompt_for_complex_queries.xlsx"
question_df = pd.read_excel(f'{knowledge_base_directory}/{question_file_name}').fillna('')
knowledge_base_path = os.path.join(knowledge_base_directory, knowledge_base_file_name)
question_bank_file_path = os.path.join(knowledge_base_directory, QUESTION_BANK_FILE_NAME)

# ===================
# Summary of the active configuration
# ===================
_summary_lines = [
    "-----------------------------------------------------------------",
    f"DATA_MODE: {DATA_MODE}",
    f"P_EVENT_TABLE: {P_EVENT_TABLE}",
    f"PRODUCT_LEVEL_TABLE: {PRODUCT_LEVEL_TABLE}",
]
# Optional tables are only reported when they resolved to a name.
if ROBOT_TABLE:
    _summary_lines.append(f"ROBOT_TABLE: {ROBOT_TABLE}")
if FINANCE_TABLE:
    _summary_lines.append(f"FINANCE_TABLE: {FINANCE_TABLE}")
_summary_lines.extend([
    f"Data Dictionary: {knowledge_base_file_name}",
    f"Collection Name: {collection_name}",
    "-----------------------------------------------------------------",
])
print("\n".join(_summary_lines))


-----------------------------------------------------------------
DATA_MODE: AHP-QA
P_EVENT_TABLE: production_load.l_mt_cqo_p_event_with_outliers_chatbot_qa
PRODUCT_LEVEL_TABLE: production_load.l_mt_cqo_product_with_outliers_chatbot_qa
ROBOT_TABLE: production_load.l_mt_cqo_p_event_with_outliers_rbt_chatbot_qa
Data Dictionary: ahp_qa_llm_knowledge_base.xlsx
Collection Name: knowledge_base_ahp_qa
-----------------------------------------------------------------


### Print Function (Not Important)

In [5]:
import pprint
import textwrap
import json
from typing import Any

# pygments is optional: when it is missing, output falls back to plain text.
try:
    from pygments import highlight
    from pygments.lexers import SqlLexer, JsonLexer, PythonLexer
    from pygments.formatters import TerminalFormatter
    HAS_PYGMENTS = True
except ImportError:
    HAS_PYGMENTS = False

def pretty_display(var: Any, title="Output", width=100):
    """
    Pretty-print any variable inside a titled box.

    Strings are printed (syntax-highlighted when pygments is available), dicts and
    lists are rendered as indented JSON, and every other object goes through pprint.

    Args:
        var: The value to display.
        title (str): Text centered in the header of the box (default: "Output").
        width (int): Total width of the box in characters (default: 100).
    """
    _print_header(title, width)
    print()  # Line break before content

    # Handle string-based types
    if isinstance(var, str):
        _print_highlighted_string(var)

    # Handle dict or list
    elif isinstance(var, (dict, list)):
        try:
            # default=str keeps values JSON cannot encode natively (dates, sets, ...) printable.
            pretty_json = json.dumps(var, indent=4, ensure_ascii=False, default=str)
        except (TypeError, ValueError):
            # Non-string-like keys or circular references: fall back to pprint.
            pprint.pprint(var, width=width)
        else:
            # JsonLexer only exists when the optional pygments import succeeded;
            # referencing it unconditionally raised NameError without pygments.
            json_lexer = JsonLexer if HAS_PYGMENTS else None
            _print_highlighted_string(pretty_json, lexer=json_lexer)

    # Other objects
    else:
        pprint.pprint(var, width=width)

    print(f"\n{'╰' + '─' * (width - 2) + '╯'}\n")

def _print_header(title, width):
    """
    Print the top of the box: a border line, the centered title, and a divider.

    Args:
        title (str): Text to center in the title row.
        width (int): Total width of the box in characters.
    """
    print(f"\n{'╭' + '─' * (width - 2) + '╮'}")
    centered_title = f" {title} ".center(width - 2, '─')
    print(f"│{centered_title}│")
    print(f"{'├' + '─' * (width - 2) + '┤'}")

def _print_highlighted_string(text: str, lexer=None):
    """
    Print a string, syntax-highlighted with pygments when it is installed.

    Args:
        text (str): The text to print.
        lexer: Optional pygments lexer class. When omitted, SqlLexer is used if the
            text contains "SELECT" (case-insensitive) and PythonLexer otherwise.
            Ignored when pygments is unavailable.
    """
    if HAS_PYGMENTS:
        if not lexer:
            lexer = SqlLexer if "SELECT" in text.upper() else PythonLexer
        print(highlight(text.strip(), lexer(), TerminalFormatter()))
    else:
        print(textwrap.dedent(text).strip())


# S3 Utils | Get Redirect Link

In [6]:
import boto3
import io
import base64
from urllib.parse import urlparse
import posixpath  # Add this at the top
import requests
# Add these imports (some may already exist in your file)
import os
import fnmatch
import mimetypes
from pathlib import Path
import posixpath
import hashlib
from typing import Optional, Tuple

class S3Utils:
    """
    Convenience wrapper around a boto3 S3 client.

    Provides helpers to list and delete objects, upload a DataFrame as CSV,
    create pre-signed download URLs, exchange them for redirect links through an
    external HTTP endpoint, and copy directory trees to and from S3.
    """

    def __init__(self, aws_access_key_id: str, aws_secret_access_key: str, endpoint_for_redirect_link: str, s3_directory: str, region_name: str = 'us-east-2'):
        """
        Create the boto3 session and S3 client and store the upload configuration.

        Args:
            aws_access_key_id (str): AWS access key passed to boto3.Session.
            aws_secret_access_key (str): AWS secret key passed to boto3.Session.
            endpoint_for_redirect_link (str): URL that get_redirect_link_from_df
                POSTs the pre-signed URL to.
            s3_directory (str): S3 URI prefix under which upload_file stores files.
            region_name (str): AWS region passed to boto3.Session (default: 'us-east-2').
        """
        self.session = boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name
        )
        self.s3 = self.session.client('s3')
        self.endpoint_for_redirect_link = endpoint_for_redirect_link
        self.s3_directory = s3_directory

    def _parse_s3_uri(self, s3_uri: str):
        """
        Split an 's3://bucket/key' URI into its bucket and key/prefix.

        Args:
            s3_uri (str): URI starting with 's3://'.

        Returns:
            tuple: (bucket, key) where key has no leading '/'.

        Raises:
            ValueError: If the URI does not start with 's3://'.
        """
        if not s3_uri.startswith("s3://"):
            raise ValueError("Invalid S3 URI")
        parsed = urlparse(s3_uri)
        bucket = parsed.netloc
        prefix = parsed.path.lstrip("/")
        return bucket, prefix

    def list_objects(self, bucket=None, prefix=None, s3_uri=None):
        """
        List every object under a prefix as full 's3://bucket/key' URIs.

        Args:
            bucket (str, optional): Bucket name (used together with prefix).
            prefix (str, optional): Key prefix to list.
            s3_uri (str, optional): Alternative to bucket+prefix; takes precedence.

        Returns:
            list[str]: URIs of the objects found, excluding the object whose key
            equals the prefix itself (the "folder" placeholder).

        Raises:
            ValueError: If neither s3_uri nor both bucket and prefix are given.
        """
        if s3_uri:
            bucket, prefix = self._parse_s3_uri(s3_uri)
        elif not (bucket and prefix):
            raise ValueError("Provide either s3_uri or bucket and prefix")

        paginator = self.s3.get_paginator('list_objects_v2')
        full_s3_uris = []

        normalized_prefix = prefix.rstrip("/")

        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            contents = page.get("Contents", [])
            for obj in contents:
                key = obj["Key"].rstrip("/")
                if key == normalized_prefix:
                    continue  # Skip folder object matching the prefix itself
                full_s3_uris.append(f"s3://{bucket}/{obj['Key']}")

        print(f"Found {len(full_s3_uris)} valid objects in {bucket}/{prefix}")
        return full_s3_uris

    def delete_file(self, s3_uri: str):
        """
        Deletes a specific file from S3. Validates existence and ensures it's not a directory.

        Args:
            s3_uri (str): Full S3 URI of the file to delete.

        Raises:
            ValueError: If the URI is malformed or is treated as a directory/prefix.
            FileNotFoundError: If the object does not exist.
        """
        if not s3_uri.startswith("s3://"):
            raise ValueError("s3_uri must start with 's3://'")

        bucket, key = self._parse_s3_uri(s3_uri)

        # Validate it's a file (not a folder key or just a prefix)
        # NOTE(review): the second condition is true for every key without a '/',
        # so objects stored at the bucket root are rejected too — confirm intended.
        if key.endswith("/") or key.rstrip("/") == key.split("/")[-1]:
            raise ValueError("Only files can be deleted. Directories are not allowed.")

        # Check existence first
        try:
            self.s3.head_object(Bucket=bucket, Key=key)
        except self.s3.exceptions.ClientError as e:
            if e.response['Error']['Code'] == '404':
                raise FileNotFoundError(f"File not found: {s3_uri}")
            else:
                raise

        # If exists, delete it
        self.s3.delete_object(Bucket=bucket, Key=key)
        print(f"Deleted file from {s3_uri}")

    def upload_file(self, df, file_format='csv'):
        """
        Upload a DataFrame as a CSV file named 'atb_<8 hex chars>.csv' under s3_directory.

        Args:
            df: Object exposing DataFrame.to_csv(buffer, index=..., encoding=...).
            file_format (str): Output format; only 'csv' is supported.

        Returns:
            str: S3 URI of the uploaded file.

        Raises:
            ValueError: If the format is unsupported, or s3_directory is unset or
                not a valid 's3://' URI.
        """
        if file_format not in ['csv']:
            raise ValueError("Supported formats: csv")

        if not self.s3_directory:
            raise ValueError("s3_directory is not configured; cannot build the upload URI")

        # Drop trailing slashes so "s3://bucket/dir/" does not yield the key "dir//atb_x.csv"
        s3_uri = self.s3_directory.rstrip('/')
        ## Naming of the CSV file
        name_suffix = str(uuid.uuid4()).split('-')[0]
        uploaded_file_name = f'atb_{name_suffix}.csv'
        s3_uri_ = f'{s3_uri}/{uploaded_file_name}'

        bucket, key = self._parse_s3_uri(s3_uri_)

        buffer = io.BytesIO()

        if file_format == 'csv':
            # utf-8-sig writes a BOM at the start of the file
            df.to_csv(buffer, index=False, encoding='utf-8-sig')
            content_type = 'text/csv'

        buffer.seek(0)
        self.s3.upload_fileobj(buffer, Bucket=bucket, Key=key,
                               ExtraArgs={'ContentType': content_type})

        return s3_uri_

    def generate_presigned_url(self, s3_uri: str, expiration: int = 600) -> str:
        """
        Generates a pre-signed URL for downloading a file from S3.

        Args:
            s3_uri (str): Full S3 URI (must include filename).
            expiration (int): Link expiration time in seconds (default: 600 = 10 minutes).

        Returns:
            str: Pre-signed HTTPS URL for download.
        """
        bucket, key = self._parse_s3_uri(s3_uri)
        filename = key.split("/")[-1]  # Extract just the filename from key

        print(f'bucket: {bucket}')
        print(f'key: {key}')

        try:
            url = self.s3.generate_presigned_url(
                'get_object',
                Params={
                    'Bucket': bucket,
                    'Key': key,
                    # Ask the browser to save the object under its own file name
                    'ResponseContentDisposition': f'attachment; filename="{filename}"'
                },
                ExpiresIn=expiration,
            )
            return url
        except Exception as e:
            print(f"Failed to generate presigned URL: {e}")
            raise

    def get_redirect_link_from_df(self, df):
        """
        Upload a DataFrame, pre-sign it, and request a redirect link for it.

        Each step is best-effort: a failure is printed and the values gathered so
        far are returned, with the remaining ones left as None.

        Args:
            df: DataFrame to upload (a copy is passed to upload_file).

        Returns:
            tuple: (s3_uri_of_the_file, download_link, redirect_link).
        """
        s3_uri_of_the_file = None
        download_link = None
        redirect_link = None

        try:
            # Step 1: Upload the DataFrame to S3
            try:
                s3_uri_of_the_file = self.upload_file(df.copy())
            except Exception as upload_err:
                print(f"[Upload Error] Failed to upload file to S3: {str(upload_err)[:300]}")
                return s3_uri_of_the_file, download_link, redirect_link

            # Step 2: Generate presigned URL
            try:
                download_link = self.generate_presigned_url(s3_uri_of_the_file)
            except Exception as url_err:
                print(f"[Presigned URL Error] Failed to generate presigned URL: {str(url_err)[:300]}")
                return s3_uri_of_the_file, download_link, redirect_link

            # Step 3: Request a redirect link from external API
            try:
                response = requests.post(
                    self.endpoint_for_redirect_link,
                    json={"presigned_url": download_link},
                    timeout=20
                )

                if response.status_code == 200:
                    data = response.json()
                    redirect_link = data.get("download_url")
                    if not redirect_link:
                        print("[Redirect Link Error] Response JSON did not contain 'download_url'")
                else:
                    print(f"[Redirect Link Error] Non-200 response: {response.status_code} - {response.text[:300]}")

            except Exception as api_err:
                print(f"[API Request Error] Failed to get redirect link: {str(api_err)[:300]}")

        except Exception as e:
            print(f"[General Error] Unexpected failure: {str(e)[:500]}")

        return s3_uri_of_the_file, download_link, redirect_link

    def upload_directory(
        self,
        src_dir: str,
        dest_s3_uri: str,
        *,
        exclude_exts: tuple[str, ...] = (".pkl",),
        exclude_dirs: tuple[str, ...] = ("__pycache__", ".venv", ".git", ".mypy_cache", ".pytest_cache"),
        exclude_files: tuple[str, ...] = (".DS_Store",),
        exclude_globs: tuple[str, ...] = ("*.pyc", "*.pyo", ".coverage*", "*.egg-info*", "*.ipynb_checkpoints*", "*.log"),
        dry_run: bool = False,
    ) -> dict:
        """
        Upload an entire local directory tree to an S3 prefix, preserving structure.
        Excludes common cache/virtualenv/artifact files and directories.

        Args:
            src_dir: Local directory to upload.
            dest_s3_uri: Destination S3 URI prefix (e.g., 's3://my-bucket/path/to/base/').
            exclude_exts: File extensions to exclude (exact suffix match, case-insensitive).
            exclude_dirs: Directory names to skip (match on any path part).
            exclude_files: Specific file basenames to skip.
            exclude_globs: Glob patterns (fnmatch) applied to the basename and to the
                POSIX relative path.
            dry_run: If True, don't upload—just return what would be uploaded/skipped.

        Returns:
            dict with:
                - "uploaded": list of s3://bucket/key uploaded (or would upload if dry_run=True)
                - "skipped": list of local file paths skipped
                - "errors":  list of (local_path, error_message)

        Raises:
            ValueError: If src_dir is missing or not a directory, or dest_s3_uri is
                not an 's3://' URI.
        """
        src_path = Path(src_dir).expanduser().resolve()
        if not src_path.exists() or not src_path.is_dir():
            raise ValueError(f"Source directory not found or not a directory: {src_dir}")

        # Parse destination S3 URI
        bucket, prefix = self._parse_s3_uri(dest_s3_uri)
        # Ensure a non-empty prefix ends with '/'
        if prefix and not prefix.endswith("/"):
            prefix = prefix + "/"

        uploaded: list[str] = []
        skipped: list[str] = []
        errors: list[tuple[str, str]] = []

        # Lower-case the extensions once for case-insensitive suffix checks
        exclude_exts_lower = tuple(e.lower() for e in exclude_exts)

        for root, dirs, files in os.walk(src_path):
            # Prune excluded directories in-place so os.walk doesn't descend into them
            dirs[:] = [d for d in dirs if d not in exclude_dirs]

            for fname in files:
                local_path = Path(root) / fname
                rel_path = local_path.relative_to(src_path)
                rel_posix = rel_path.as_posix()

                suffix_lower = local_path.suffix.lower()

                if fname in exclude_files:
                    skipped.append(str(local_path))
                    continue

                if suffix_lower in exclude_exts_lower:
                    skipped.append(str(local_path))
                    continue

                # Match globs against basename and posix rel path (to allow patterns like '*/.ipynb_checkpoints/*')
                if any(fnmatch.fnmatch(fname, pattern) or fnmatch.fnmatch(rel_posix, pattern) for pattern in exclude_globs):
                    skipped.append(str(local_path))
                    continue

                # Also skip if any path part matches excluded directory names (defense-in-depth)
                if any(part in exclude_dirs for part in rel_path.parts):
                    skipped.append(str(local_path))
                    continue

                # Build S3 key using POSIX joining to ensure '/'
                key = posixpath.join(prefix, rel_posix) if prefix else rel_posix

                # Guess content-type; only sent when a type could be determined
                ctype, _ = mimetypes.guess_type(fname)
                extra_args = {"ContentType": ctype} if ctype else None

                s3_uri_out = f"s3://{bucket}/{key}"

                try:
                    if not dry_run:
                        if extra_args:
                            self.s3.upload_file(str(local_path), Bucket=bucket, Key=key, ExtraArgs=extra_args)
                        else:
                            self.s3.upload_file(str(local_path), Bucket=bucket, Key=key)
                    uploaded.append(s3_uri_out)
                except Exception as e:
                    # Record the failure and keep going with the remaining files
                    errors.append((str(local_path), str(e)))

        return {"uploaded": uploaded, "skipped": skipped, "errors": errors}

    # ==================================================== #
    ## Download Feature ##
    # ==================================================== #

    def download_uris_to_tree(
        self,
        file_uris: list[str],
        dest_dir: str = ".",
        anchor: str = "atb_all_modules_import/",
        skip_suffixes: tuple[str, ...] = (),
    ) -> dict:
        """
        Download a list of S3 URIs into the local filesystem, preserving the folder
        structure *after* the given `anchor` segment. For example, for:
            s3://atb-chatbot/atb_all_modules_import/modules/chart_modules/bar_chart_generator.py
        the part after `anchor` is:
            modules/chart_modules/bar_chart_generator.py
        which will be recreated under `dest_dir`.

        Keys whose relative part would land outside `dest_dir` (for example because
        it contains '..') are not downloaded; they are reported in "errors".

        Args:
            file_uris: List of S3 URIs to download.
            dest_dir: Local base directory to create (defaults to current directory).
            anchor: Path segment in S3 key after which the structure is preserved.
            skip_suffixes: File suffixes to skip (e.g., (".DS_Store",)).

        Returns:
            dict with:
                - "downloaded": list of local file paths successfully downloaded.
                - "errors": list of (s3_uri, error_message) for failures.
        """
        dest_root = Path(dest_dir).expanduser().resolve()
        dest_root.mkdir(parents=True, exist_ok=True)

        downloaded: list[str] = []
        errors: list[tuple[str, str]] = []

        for s3_uri in file_uris:
            try:
                bucket, key = self._parse_s3_uri(s3_uri)

                # Skip folder-like keys
                if key.endswith("/"):
                    continue

                # Optionally skip unwanted files by suffix
                if skip_suffixes and any(key.endswith(sfx) for sfx in skip_suffixes):
                    continue

                # Determine the relative key after the anchor
                anchor_idx = key.find(anchor)
                if anchor_idx != -1:
                    rel_key = key[anchor_idx + len(anchor):]
                else:
                    # If anchor not found, fallback to just the filename
                    rel_key = key.split("/")[-1]

                # Normalize to a local path
                local_rel_path = Path(*rel_key.split("/"))
                local_path = dest_root / local_rel_path

                # Object keys come from S3, not from us: refuse any key that would
                # resolve to dest_root itself or escape it (e.g. via '..' segments).
                normalized_target = Path(os.path.normpath(local_path))
                if dest_root not in normalized_target.parents:
                    raise ValueError(
                        f"Refusing to write outside destination directory: {rel_key!r}"
                    )

                # Ensure parent directories exist
                local_path.parent.mkdir(parents=True, exist_ok=True)

                # Download
                self.s3.download_file(Bucket=bucket, Key=key, Filename=str(local_path))

                downloaded.append(str(local_path))
            except Exception as e:
                # Record the failure and continue with the remaining URIs
                errors.append((s3_uri, str(e)))

        return {"downloaded": downloaded, "errors": errors}


## Initialize
# Shared S3 helper for this notebook, configured from the credentials, endpoint and
# S3 directory resolved in the earlier cells.
s3_utils = S3Utils(
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
    endpoint_for_redirect_link=endpoint_for_redirect_link,
    s3_directory=s3_directory,
    region_name=region_name,
)

In [7]:
# # ## List All the Files
# # file_list = s3_utils.list_objects(s3_uri="s3://document-management-new/dev-test/atb_csv")
# file_list = s3_utils.list_objects(s3_uri="s3://atb-chatbot/ahp/qa/atb_csv")

# file_list
# s3_utils.generate_presigned_url('s3://atb-chatbot/ahp/qa/atb_csv/atb_196655e3.csv')
# s3_utils.generate_presigned_url('s3://atb-chatbot/ahp/qa/atb_csv/atb_010d427c.csv')

### Upload Directory

In [8]:
# --------
# Example usage:
# --------
# NOTE: with dry_run=False this cell performs a real upload to S3 every time it runs
# (including "Run All"). Switch dry_run to True to only list what would be uploaded.
# The exclude_* arguments replace the method's defaults rather than extending them,
# so the default ".venv" entry is not excluded here unless it is added to exclude_dirs.

result = s3_utils.upload_directory(
    src_dir="../BEE_API/sc-bee-app-ahp",                          # local directory to upload
    dest_s3_uri="s3://atb-chatbot/atb_all_modules_import/ahp/qa",  # destination S3 prefix
    exclude_exts=(".pkl", ".db", ".sqlite", ".sqlite3"),
    exclude_dirs=("__pycache__", "venv", "bee-venv", ".git", ".mypy_cache", ".pytest_cache"),
    exclude_files=(".DS_Store",),
    exclude_globs=("*.pyc", "*.pyo", "*.log", "*.ipynb_checkpoints*", "*.egg-info*", "*/node_modules/*"),
    dry_run=False,                                    # set True to preview
)

# Summary counts; the full lists are available in result["uploaded"] / result["skipped"].
print("Uploaded:", len(result["uploaded"]))
print("Skipped:", len(result["skipped"]))

# Per-file failures are collected instead of raised, so surface them explicitly.
if result["errors"]:
    print("Errors:", result["errors"])

Uploaded: 49
Skipped: 5


### Download Directory

In [9]:
# files_to_list = s3_utils.list_objects(s3_uri="s3://atb-chatbot/atb_all_modules_import/ahp/qa/OOP_CODING")

# # Example usage:
# results = s3_utils.download_uris_to_tree(
#     file_uris=files_to_list,
#     dest_dir="./temp_use_case_atb_modules",                           # optional; default "."
#     anchor="OOP_CODING/",                   # keep structure AFTER this segment
#     skip_suffixes=(".DS_Store",),                       # optional skipping
# )
# print("Downloaded:", len(results["downloaded"]))
# if results["errors"]:
#     print("Errors:", results["errors"])

# Reading Knowledge Base - Data Dictionary | DONE

In [10]:
class KnowledgeBaseReader:
    """Reads the knowledge-base (data dictionary) Excel file and cleans its column metadata."""

    def __init__(self, knowledge_base_directory: str, knowledge_base_file_name: str):
        """
        Initializes the KnowledgeBaseReader class with the file directory and file name for the knowledge base.

        Args:
            knowledge_base_directory (str): The directory where the knowledge base file is located.
            knowledge_base_file_name (str): The file name of the knowledge base (Excel file).
        """
        self.knowledge_base_directory = knowledge_base_directory
        self.knowledge_base_file_name = knowledge_base_file_name

    @staticmethod
    def _strip_if_str(value):
        """Return value.strip() for strings; leave any other value untouched."""
        return value.strip() if isinstance(value, str) else value

    def get_cleaned_columns_info(self) -> pd.DataFrame:
        """
        Loads the knowledge base from the Excel file, cleans it, and processes the column information,
        and returns the final DataFrame.

        Cleaning steps:
            1. Missing cells become empty strings.
            2. Header names and string cells are stripped of surrounding whitespace
               (non-string headers/cells are left as they are).
            3. The 'Data Set Type' and 'Definition' columns are dropped when present.
            4. For every row except 'Asa Rating', a title-cased, human-readable form of
               'Column Name' (without 'emr_'/'fin_') is prepended to 'Synonyms' unless
               that text already occurs in it.

        Returns:
            pd.DataFrame: Processed DataFrame with the column information.
        """
        # Load the Excel file
        file_path = f'{self.knowledge_base_directory}/{self.knowledge_base_file_name}'
        data_dictionary_dataframe = pd.read_excel(file_path).fillna('')

        # Clean column names by stripping whitespace (only where the header is a string)
        data_dictionary_dataframe.columns = [
            self._strip_if_str(col) for col in data_dictionary_dataframe.columns
        ]

        # Remove specific columns
        columns_to_delete = ['Data Set Type', 'Definition']
        data_dictionary_dataframe = data_dictionary_dataframe.drop(
            columns=[col for col in columns_to_delete if col in data_dictionary_dataframe.columns]
        )

        # Clean the column values by stripping whitespace.
        # The `.str` accessor raises on non-string columns and turns non-string cells
        # of mixed columns into NaN, so strip string cells individually instead.
        for col in data_dictionary_dataframe.columns:
            data_dictionary_dataframe[col] = data_dictionary_dataframe[col].map(self._strip_if_str)

        # Process the 'Column Name' and 'Synonyms' fields
        for idx in data_dictionary_dataframe.index:
            column_name = data_dictionary_dataframe.at[idx, 'Column Name']

            # Clean and format the 'Column Name', e.g. 'emr_patient_age' -> 'Patient Age'
            processed_column_name = column_name.replace('emr_', '').replace('fin_', '')
            processed_column_name = ' '.join(processed_column_name.split('_')).title()

            # Skip 'Asa Rating'
            if processed_column_name == 'Asa Rating':
                continue

            # Update the 'Synonyms' field (substring check against the existing text)
            existing_synonyms = data_dictionary_dataframe.at[idx, 'Synonyms']
            if processed_column_name not in existing_synonyms:
                updated_synonyms = processed_column_name + ', ' + existing_synonyms
                # strip(', ') removes the dangling separator when there were no synonyms
                data_dictionary_dataframe.at[idx, 'Synonyms'] = updated_synonyms.strip(', ').strip()

        return data_dictionary_dataframe

# EmbeddingManager -> Data Dictionary

- Knowledge Base Chroma | Spine Tables | Synonyms -> table_module.py 
- llm_knowledge_base_clinical_fin.xlsx

In [11]:
class EmbeddingManager:
    """
    Builds, persists and incrementally refreshes the data-dictionary embedding collection.

    Every data-dictionary row becomes one document string of the form
    'Column Name | Synonyms | Brief Details'. The documents and their embeddings are
    stored together as a pickled DataFrame at '<persist_directory>/<collection_name>.pkl'.
    """

    def __init__(self, embedding_function, persist_directory, collection_name):
        """
        Initializes the EmbeddingManager with the necessary parameters.

        Args:
            embedding_function: Object exposing ``embed_documents(list_of_str)``
                (e.g., OpenAIEmbeddings).
            persist_directory (str): Directory path where embeddings are saved.
            collection_name (str): Name of the collection to manage (used as the file stem).
        """
        self.embedding_function = embedding_function
        self.persist_directory = persist_directory
        self.collection_name = collection_name

    def _collection_path(self) -> str:
        """Return the path of the Pickle file that backs this collection."""
        return os.path.join(self.persist_directory, f'{self.collection_name}.pkl')

    @staticmethod
    def _build_documents(data_dictionary_dataframe: pd.DataFrame) -> list:
        """Return one 'Column Name | Synonyms | Brief Details' string per data-dictionary row."""
        return (
            data_dictionary_dataframe['Column Name'] + ' | ' +
            data_dictionary_dataframe['Synonyms'] + ' | ' +
            data_dictionary_dataframe['Brief Details']
        ).tolist()

    def _write_collection(self, df_embeddings: pd.DataFrame) -> str:
        """Pickle ``df_embeddings`` to the collection file (creating the directory) and return its path."""
        if self.persist_directory:
            os.makedirs(self.persist_directory, exist_ok=True)

        file_path = self._collection_path()
        with open(file_path, 'wb') as f:
            pickle.dump(df_embeddings, f)
        return file_path

    def is_collection_available(self) -> bool:
        """
        Check if the collection is available in the persist directory by looking for a Pickle (.pkl) file.

        Returns:
            bool: True if the collection exists, False otherwise.
        """
        return os.path.exists(self._collection_path())

    def save_to_persist_directory(self, data_dictionary_dataframe: pd.DataFrame) -> pd.DataFrame:
        """
        Embed every data-dictionary row and save the result to the persist directory using Pickle.

        Args:
            data_dictionary_dataframe (pd.DataFrame): DataFrame containing column details with
                'Column Name', 'Synonyms', and 'Brief Details' columns.

        Returns:
            pd.DataFrame: DataFrame with 'text' and 'embeddings' columns that was saved.
        """
        # String to embed
        em_docs = self._build_documents(data_dictionary_dataframe)

        # Generate embeddings
        embeddings = self.embedding_function.embed_documents(em_docs)

        # Create DataFrame with text and embeddings
        df_embeddings = pd.DataFrame({
            'text': em_docs,
            'embeddings': embeddings
        })

        file_path = self._write_collection(df_embeddings)
        print(f'Saved to: {file_path}')

        return df_embeddings

    def load_and_update_embeddings(self, data_dictionary_dataframe: pd.DataFrame) -> pd.DataFrame:
        """
        Load the existing Pickle collection, drop documents that no longer exist in the data
        dictionary, embed documents that are new or changed, and persist the result.

        The collection file is rewritten whenever anything changed, including the case where
        documents were only removed, so the file on disk always matches the returned DataFrame.

        Args:
            data_dictionary_dataframe (pd.DataFrame): DataFrame containing column details with
                'Column Name', 'Synonyms', and 'Brief Details' columns.

        Returns:
            pd.DataFrame: Updated DataFrame containing the text and embeddings.
        """
        file_path = self._collection_path()

        # Load existing embeddings
        if os.path.exists(file_path):
            # NOTE(security): unpickling can execute arbitrary code; only load collection
            # files that were written by this class from a trusted persist directory.
            updated_df = pd.read_pickle(file_path).copy()
        else:
            updated_df = pd.DataFrame(columns=['text', 'embeddings'])
        existing_documents = updated_df['text'].tolist()

        # Prepare new documents list
        new_documents = self._build_documents(data_dictionary_dataframe)

        # Sets give O(1) membership checks; the lists keep the original ordering.
        existing_lookup = set(existing_documents)
        new_lookup = set(new_documents)

        # Find new or changed documents (a changed row shows up as remove + add)
        new_or_changed_documents = [doc for doc in new_documents if doc not in existing_lookup]

        # Find documents to remove
        documents_to_remove = [doc for doc in existing_documents if doc not in new_lookup]

        if not documents_to_remove and not new_or_changed_documents:
            print('No changes detected. No updates were made inside the Knowledge Base.')
            return updated_df

        # Remove old documents
        if documents_to_remove:
            updated_df = updated_df[~updated_df['text'].isin(documents_to_remove)].reset_index(drop=True)
            print('\n\t# Documents Removed:')
            for idx, doc in enumerate(documents_to_remove):
                print(f'\t{idx+1}: {doc}')

        # Add new or changed documents (only these are sent to the embedding function)
        if new_or_changed_documents:
            new_embeddings = self.embedding_function.embed_documents(new_or_changed_documents)
            new_df = pd.DataFrame({
                'text': new_or_changed_documents,
                'embeddings': new_embeddings
            })
            if updated_df.empty:
                updated_df = new_df
            else:
                updated_df = pd.concat([updated_df, new_df], ignore_index=True)

            print('\n\t# Documents Added:')
            for idx, doc in enumerate(new_or_changed_documents):
                print(f'\t{idx+1}: {doc}')

        # Persist after removals as well as additions so stale rows do not come back on reload.
        file_path = self._write_collection(updated_df)
        print(f'\nKnowledge Base Updated and saved to: {file_path}')
        return updated_df

# DBInfoManager

- table_info_from_db
- column_dictionary

In [12]:
class DBInfoManager:
    """
    Reads table/column metadata from the database and keeps it as a DataFrame
    (``table_info_from_db``) and a ``{table_name: [columns]}`` dictionary (``column_dictionary``).
    """

    def __init__(self, connection_pool):
        """
        Initializes the DBInfoManager class with a connection pool. The table information
        DataFrame and column dictionary stay None until ``_initialize_table_data`` is called.

        Args:
            connection_pool: Connection pool object exposing ``getconn()`` / ``putconn(conn)``.
        """
        self.connection_pool = connection_pool
        self.table_info_from_db = None
        self.column_dictionary = None

    def _get_column_names(self, table_name: str):
        """
        Fetches column names for a given table from the database using the connection pool.

        Args:
            table_name (str): The fully qualified table name (e.g., 'schema.table'). The text
                before the first '.' is used as the schema and the text after the last '.'
                as the table.

        Returns:
            list: Column names ordered by ordinal position; an empty list if the lookup fails
                (the error is printed, not raised).
        """
        schema_name = table_name.split('.')[0]
        table_name = table_name.split('.')[-1]

        # Identifiers are passed as bound parameters (never interpolated into the SQL text)
        # so a crafted table name cannot alter the statement.
        query = '''
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = %s AND table_name = %s
        ORDER BY ordinal_position;
        '''

        conn = None
        column_list = []

        try:
            conn = self.connection_pool.getconn()

            with conn.cursor() as cursor:
                cursor.execute(query, (schema_name, table_name))
                col_result = cursor.fetchall()

            # Each result row is a 1-tuple holding the column name
            column_list = [row[0] for row in col_result]

        except Exception as error:
            # Best effort: report and fall through with an empty list.
            print(f"Error fetching column names for {table_name}: {error}")

        finally:
            # Always hand the connection back to the pool, even after a failure.
            if conn is not None:
                self.connection_pool.putconn(conn)

        return column_list

    def _initialize_table_data(self, table_info):
        """
        Initializes table information and column dictionary by fetching the column names for each table.

        Args:
            table_info (dict): Maps fully qualified table name -> short description of the table.

        Returns:
            tuple: ``(table_info_from_db, column_dictionary)`` where the DataFrame has the columns
                'table_for', 'table_name', 'columns', 'table_related_info'.
        """
        # Columns that are dropped from every table's column list
        delete_columns = {'emr_patient_class'}

        # Fetch column names for each table, leaving out the redundant ones
        table_columns = {
            table: [col for col in self._get_column_names(table) if col not in delete_columns]
            for table in table_info
        }

        # Create the table_info_from_db DataFrame
        data = [
            [description, table, table_columns[table], f"Table {table} contains {description.lower()}."]
            for table, description in table_info.items()
        ]
        self.table_info_from_db = pd.DataFrame(
            data, columns=["table_for", "table_name", "columns", "table_related_info"]
        )

        # Create the column dictionary
        self.column_dictionary = dict(
            zip(self.table_info_from_db['table_name'], self.table_info_from_db['columns'])
        )

        return self.table_info_from_db, self.column_dictionary

# QueryEvaluationEngine -> Comparative Question or Not? 

In [13]:
# -----------------------------
# Pydantic Schema for Output
# -----------------------------
class AnalyticalQueryIntent(BaseModel):
    """
    Structured output indicating the type(s) of analytical intent present in the query.
    More fields (like utilization, price parity, etc.) can be added over time.
    """
    # This model is the schema handed to `with_structured_output` in QueryEvaluationEngine.
    # NOTE(review): presumably the class docstring and the field description below are
    # forwarded to the LLM as part of that schema — confirm before rewording either.
    # NOTE(review): the field is a plain int, so values other than 0/1 are not rejected here.
    is_matrix_comparison: int = Field(..., description="1 if the query compares two or more matrices, otherwise 0")

# -----------------------------
# Query Evaluation Engine Class
# -----------------------------

class QueryEvaluationEngine:


    """
    Classifies what kind of analytical question the user is asking.

    Currently the only intent detected is "matrix comparison": whether the question
    compares 2-3 explicitly named entities (see the prompt built in
    `_build_prompt_template`). When that intent is detected,
    `get_sql_template_if_applicable` returns a generic SQL template with usage guidance.

    The result schema (`AnalyticalQueryIntent`) is meant to grow with further intents
    such as utilization, price comparisons, or trends.
    """


    def __init__(self, model: str = "gpt-4.1-mini", temperature: float = 0.0):
        """
        Initializes the engine with the specified model and temperature.

        Args:
            model (str): OpenAI model name. Default is 'gpt-4.1-mini'.
            temperature (float): LLM temperature passed to ChatOpenAI. Default is 0.0.
        """
        self.llm = ChatOpenAI(model=model, temperature=temperature)
        # The prompt has to exist before the chain: _build_chain reads self.prompt_template.
        self.prompt_template = self._build_prompt_template()
        self.chain = self._build_chain()

    def _build_prompt_template(self) -> PromptTemplate:
        """
        Builds the prompt template used to evaluate whether a query involves matrix comparison.

        The template has a single input variable, ``text``, which receives the user query.

        Returns:
            PromptTemplate: Prompt containing the classification rules, examples and the
                ``{text}`` placeholder.
        """
        return PromptTemplate(
            input_variables=["text"],
            template="""MATRIX COMPARISON CLASSIFIER - DECISION TASK

Goal: Classify whether the user's question involves comparing between 2-3 explicitly named entities.

SCHEMA:
--------
- is_matrix_comparison (integer)
  - 1 → The query explicitly compares 2-3 specifically named entities (markets, facilities, etc.)
  - 0 → The query does not involve comparison between specific named entities.

RULES FOR CLASSIFICATION:
--------------------------
Return '1' if:
- The question explicitly compares 2-3 individually named entities of the same type:
  * Named markets (e.g., "LORAIN and YOUNGSTOWN", "RICHMOND vs DALLAS")
  * Named facilities (e.g., "Facility A compared to Facility B")
  * Named physicians (e.g., "Dr. Smith compared to Dr. Jones")
  * Named matrices (e.g., "Matrix A vs Matrix B")
  * Named regions (e.g., "East region vs West region")

IMPORTANT: Focus on what's being compared, not what's being analyzed:
- If comparing specific named markets/facilities/entities → Return 1
- Even if analyzing "top N" items or multiple procedures WITHIN those markets → Still Return 1

Return '0' if:
- The question uses "across" terminology without naming specific entities to compare
- The question is about general patterns without naming specific entities to compare
- The question only mentions comparison over time periods
- It asks for trends or variations within a single matrix/entity
- It uses comparison language but doesn't specify which entities to compare
- Uses terms like "between markets" or "across facilities" without naming specific entities

EXAMPLES:
----------
✅ "Compare the cost matrices between Facility A and Facility B" → 1
✅ "What's the difference between performance in RICHMOND and DALLAS?" → 1
✅ "How do metrics vary between Dr. Smith, Dr. Jones and Dr. Williams?" → 1
✅ "How does avg CPC vary across top 20 procedures for @market LORAIN and @market YOUNGSTOWN?" → 1
✅ "Compare the top 10 primary procedures by avg cpc for @market LORAIN and @market YOUNGSTOWN" → 1

❌ "Compare the Total Encounters, Total Acquisition Cost across Markets" → 0
❌ "Compare the Outcome Scores across facilities for @Market RICHMOND" → 0
❌ "What is the Supply Cost to Actual Patient Revenue ratio across facilities?" → 0
❌ "Compare the top 10 physicians based on performance" → 0
❌ "Compare facilities based on cost per case" → 0
❌ "Compare the Total Charge across facilities" → 0
❌ "Show me how the values changed from January to December" → 0

NOTE:
- The key requirement is comparison between specific named entities (like LORAIN and YOUNGSTOWN).
- It doesn't matter if we're analyzing "top N" procedures or multiple metrics within those entities.
- Terms like "across markets" without naming specific markets should return 0.
- Return only valid JSON with 'is_matrix_comparison' field.

INPUT:
{text}
"""
        )

    def _build_chain(self):
        """
        Creates the structured output chain combining the prompt with the LLM.

        The LLM is wrapped with ``with_structured_output(AnalyticalQueryIntent)`` and piped
        after ``self.prompt_template``.

        Returns:
            Runnable chain for evaluating user input.
        """
        structured_llm = self.llm.with_structured_output(AnalyticalQueryIntent)
        return self.prompt_template | structured_llm

    def classify(self, user_query: str) -> AnalyticalQueryIntent:
        """
        Evaluates the user query to determine if it involves a matrix comparison.

        This invokes the chain (and therefore the LLM) once per call.

        Args:
            user_query (str): The input query from the user; passed to the prompt as ``text``.

        Returns:
            AnalyticalQueryIntent: A structured result whose is_matrix_comparison is expected
                to be 0 or 1.
        """
        return self.chain.invoke({"text": user_query})

    def get_sql_template_if_applicable(self, user_query: str) -> str | None:
        """
        Returns a fully generic SQL template for matrix comparison if the intent matches.

        The query is classified first; the template text is returned only when
        ``is_matrix_comparison == 1``.

        Args:
            user_query (str): Input question.

        Returns:
            str | None: SQL guidance and template if matrix comparison is detected,
                otherwise None.
        """
        result = self.classify(user_query)

        if result.is_matrix_comparison == 1:
            # The literal below is prompt text for the SQL-generating LLM; the "\t" sequences
            # are escape codes that become real tab characters at runtime.
            return """⸻

🧩 Fully Generic SQL Template for Comparison Queries

Question Sample: 
\t1.\tExplicit comparison using “compare”
Compare the top N entities based on a specific metric across multiple groups within a selected category and time period.
\t2.\tImplicit comparison using “how does X vary”
How does the average value of a metric vary across top entities between two selected groups for a given timeframe and category?
\t3.\tMulti-metric comparison
How do multiple metrics (e.g., cost, encounter count, and score) vary across top-ranked entities within two or more comparison groups?
\t4.\tGroup-based performance comparison
Show the differences in performance metrics across departments or facilities for a specific category over a defined period.
\t5.\tEntity-to-entity variation
How do key metrics differ between selected individuals or groups (e.g., physicians, vendors, facilities) for top items or procedures in a specific domain?

SQL Query Template:
⸻⸻⸻⸻

WITH base_filtered AS (
    SELECT 
        -- required entity columns (e.g., procedure, physician, item),
        -- grouping columns (e.g., market, facility),
        -- unique identifier for event-level data,
        -- relevant metric column (e.g., cost, score, revenue)
    FROM <your_source_table>
    WHERE 
        -- apply filters based on question (e.g., year, category, market, service line, etc.)
        -- exclude irrelevant or unknown values if needed
),
top_entities AS (
    SELECT 
        <entity_column>,
        -- calculate metric (e.g., AVG, SUM, COUNT) as required by the question
    FROM base_filtered
    GROUP BY <entity_column>
    ORDER BY <calculated_metric> DESC
    LIMIT <N>
),
comparison_result AS (
    SELECT 
        <entity_column>,
        <comparison_group_column>,  -- e.g., market, facility, physician
        -- calculate metric again for each comparison group
    FROM base_filtered bf
    JOIN top_entities te 
      ON bf.<entity_column> = te.<entity_column>
    GROUP BY <entity_column>, <comparison_group_column>
)
SELECT * 
FROM comparison_result
ORDER BY <comparison_group_column>, <calculated_metric> DESC;

⸻⸻⸻⸻⸻⸻

🧭 How to Use It (LLM Guidance)

Placeholder\tWhat to Replace It With
<your_source_table>\tThe source dataset or view
<entity_column>\tWhat you’re comparing (e.g., procedure, physician, product, etc.)
<comparison_group_column>\tThe groups you’re comparing across (e.g., market, region, etc.)
<calculated_metric>\tMetric to compare (e.g., average cost, count, score, etc.)
<N>\tNumber of top entities to include in the comparison
⸻⸻⸻⸻⸻⸻
"""
        return None
    
# Module-level engine with the default model/temperature; the commented-out rough-run
# cells that follow call `query_evaluation_engine.classify(...)` and
# `query_evaluation_engine.get_sql_template_if_applicable(...)` on it.
query_evaluation_engine = QueryEvaluationEngine()

#### Rough Run - Query Evaluation Engine

In [14]:
# query = "compare the top 10 primary procedures by average cpc for Market Lorain and Market Youngstown for Service line WOMEN'S HEALTH for 2024"
# query = "Whats my name?"
# query = "Compare the spend variations over time."
# query = "Compare the  Total Encounters, Total Acquisition Cost, and Outcome Score across Facilities for @Market RICHMOND."
# query = "How does avg CPC, and total encounter vary across top 20 procedures for @market LORAIN and @market @YOUNGSTOWN for @service line @WOMEN'S HEALTH?"
# query = "Compare the top 10 physicians based on the number of encounters and total acquisition cost."
# query = 'Whats the total encounter count between Physician A and PHysician B?'


# # ## Optional
# classify_result = query_evaluation_engine.classify(query).is_matrix_comparison
# print(f'Query: {query}\nClassified as =>> {classify_result}')  # ➝ 1

# # ## Main Code
# comparison_sql_template_prompt = query_evaluation_engine.get_sql_template_if_applicable(query)
# print(f'\nPrompt To Be Used: \n{comparison_sql_template_prompt}')

In [15]:
# check_questions = pd.read_excel(f'Data_For_LLM/Question_Bank_For_BEE_AHP.xlsx')
# all_questions = check_questions['Questions'].tolist()

# records = []
# for q in range(len(all_questions)):
#     question = all_questions[q]
#     print(f'\nQuestion: {question}')
#     classify_result = query_evaluation_engine.classify(question).is_matrix_comparison
#     print(f'Classified as =>> {classify_result}')  # ➝ 1
#     print('-----------------------------------')

#     records.append([classify_result, question])

# temp_df = pd.DataFrame(records, columns=['is_matrix_comparison', 'Question'])
# # temp_df.to_excel('Matrix_Comparison_Report.xlsx', index=False)

# EmbeddingTaskHelper 

- generating_df_from_embeddings
- helper_function_for_table_selection
- find_appropriate_row_from_complex_query_sheet
- get_column_info_for_sql_generation

In [16]:
class EmbeddingTaskHelper:
    def __init__(
        self,
        embedding_function,
        df_embeddings,
        column_dictionary,
        data_dictionary_dataframe,
        table_info_from_db,
        question_df,
    ):
        """
        Keep references to the shared resources used by this helper's methods.

        Args:
            embedding_function: The embedding function or model used to generate embeddings.
            df_embeddings: Stored embeddings; kept on the instance unchanged.
            column_dictionary: Kept on the instance unchanged.
            data_dictionary_dataframe: Kept on the instance unchanged.
            table_info_from_db: Table metadata; kept on the instance unchanged.
            question_df: Kept on the instance unchanged.

        A dedicated QueryEvaluationEngine (default settings) is also created and stored
        as ``self.query_evaluation_engine``.
        """
        # Embedding side
        self.embedding_function = embedding_function
        self.df_embeddings = df_embeddings

        # Metadata / reference tables
        self.data_dictionary_dataframe = data_dictionary_dataframe
        self.column_dictionary = column_dictionary
        self.table_info_from_db = table_info_from_db
        self.question_df = question_df

        # Created last, after all plain attribute assignments
        self.query_evaluation_engine = QueryEvaluationEngine()

    def generating_df_from_embeddings(
        self,
        user_prompt: str,
        df_embeddings: pd.DataFrame,
        list_of_columns: list,
        data_dictionary_dataframe: pd.DataFrame,
        n_top: int = 8
    ) -> tuple:
        """
        Generates a DataFrame from embeddings by finding the best matching columns based on the user's query.

        Args:
            user_prompt (str): The query from the user.
            df_embeddings (pd.DataFrame): DataFrame containing stored embeddings and text.
            list_of_columns (list): List of columns to filter.
            data_dictionary_dataframe (pd.DataFrame): DataFrame containing column information for tables.
            n_top (int, optional): The number of top matches to return. Default is 8.

        Returns:
            tuple: A tuple containing the filtered DataFrame, list of best match columns, and the original DataFrame with similarity scores.
                - filtered_df (pd.DataFrame): DataFrame with filtered and merged results.
                - best_match_columns (list): List of best matching column names.
                - open_df (pd.DataFrame): Original DataFrame with similarity scores.
        """
        # Generate embedding for the user prompt
        user_embedding = self.embedding_function.embed_query(user_prompt)
        
        # Calculate cosine similarity between user embedding and stored embeddings
        stored_embeddings = np.vstack(df_embeddings['embeddings'].to_numpy())
        similarity_scores = cosine_similarity([user_embedding], stored_embeddings)[0]
        
        # Prepare the DataFrame for results
        df_embeddings = df_embeddings.copy()  # To avoid modifying the original DataFrame
        df_embeddings['similarity_score'] = similarity_scores
        df_embeddings['col_name'] = df_embeddings['text'].apply(lambda x: x.split(' | ')[0])
        df_embeddings['synonyms'] = df_embeddings['text'].apply(lambda x: x.split(' | ')[1])

        # Filter based on columns and similarity score
        open_df = df_embeddings[df_embeddings['col_name'].isin(list_of_columns)]
        open_df = open_df[open_df['similarity_score'] > 0.40]
        open_df = open_df.reset_index(drop=True)
        best_match_columns = []

        # Reward matching words in synonyms
        if not open_df.empty:
            user_prompt_cleaned = re.sub(r'[^a-zA-Z\s]', '', user_prompt).lower()
            user_prompt_cleaned = ' '.join(user_prompt_cleaned.split())
            
            for index, row in open_df.iterrows():
                vocab_list = [vocab.strip().lower() for vocab in row['synonyms'].split(',')]
                
                for vocab in vocab_list:
                    if vocab in user_prompt_cleaned:
                        open_df.at[index, 'similarity_score'] += 0.15
                        break
            
            open_df = open_df.sort_values('similarity_score', ascending=False).head(n_top)
            best_match_columns = open_df['col_name'].tolist()

        # Merge with the original table for additional info
        filtered_df = pd.merge(
            open_df,
            data_dictionary_dataframe,
            left_on='col_name',
            right_on='Column Name',
            how='left'
        )

        # Combine columns while avoiding duplicates
        final_columns = list(open_df.columns) + [
            col for col in data_dictionary_dataframe.columns
            if col not in ['Column Name', 'Synonyms']
        ]
        filtered_df = filtered_df[final_columns]

        return filtered_df, best_match_columns, open_df
    
    def helper_function_for_table_selection(self, filtered_df):
        """
        Generates a knowledge base string from the provided DataFrame for use in SQL query generation.

        Args:
            filtered_df (DataFrame): The DataFrame containing the required columns ('col_name', 'synonyms', 'Possible Values', 'Special Instruction').

        Returns:
            str: A formatted string representing the knowledge base, including column names, synonyms, possible values, and special instructions.
        """

        # Assuming filtered_df is already defined and contains the required columns
        knowledge_base = []
        for index, row in filtered_df[['col_name', 'synonyms', 'Possible Values', 'Special Insturction']].iterrows():
            # Prepare the information in the desired format
            info = f"\t\t{index+1}. Column Name: {row['col_name']}\n"
            # info += f"\t\t- Synonyms of {row['col_name']}: {row['synonyms'].split(', ')[:4]}\n"
            info += f"\t\t- Synonyms of {row['col_name']}: {row['synonyms'].split(', ')}\n"

        
            # Append the information to the knowledge_base
            knowledge_base.append(info)
            if len(knowledge_base) == 6:
                break

        knowledge_base = ('\n').join(knowledge_base)
        return knowledge_base

    def knowledge_base_function_for_sql_generation(self, filtered_df):
        """
        Generates a knowledge base string from the provided DataFrame for use in SQL query generation.

        Args:
            filtered_df (DataFrame): The DataFrame containing the required columns ('col_name', 'synonyms', 'Possible Values', 'Special Instruction').

        Returns:
            str: A formatted string representing the knowledge base, including column names, synonyms, possible values, and special instructions.
        """

        # Assuming filtered_df is already defined and contains the required columns
        knowledge_base = []
        for index, row in filtered_df[['col_name', 'synonyms', 'Possible Values', 'Special Insturction']].iterrows():
            # Prepare the information in the desired format
            info = f"\t{index+1}. Column Name: {row['col_name']}\n"
            info += f"\t\t- Synonyms: {row['synonyms'].split(', ')}\n"

            possible_values = row.get('Possible Values', '')
            if possible_values:
                possible_values = possible_values.replace("\n", "\n\t\t\t")
                info += f"""\t\t- Possible Values:\n\t\t\t{possible_values}\n"""

            special_instruction = row.get('Special Insturction', '')
            if special_instruction:
                special_instruction = special_instruction.replace("\n", "\n\t\t\t")
                info += f"""\t\t- Note:\n\t\t\t{special_instruction}\n"""

            # Append the information to the knowledge_base
            knowledge_base.append(info)

        # Now knowledge_base contains all the rows in the desired format
        # This can be fed into a model as a knowledge base
        knowledge_base = ('\n\n').join(knowledge_base)


        knowledge_base = f"\n\t- Most Probable Columns, its synonyms and related information are given below in descending order:\n\t============== Context for Probable Columns ================\n{knowledge_base}\n\t=========================\n\tWhile generating the sql query, give great importance to the above context."
        return knowledge_base
    
    def clean_prompt(self, user_prompt):
        """
        Cleans the user prompt by:
        - Removing special characters (except alphanumeric, spaces, and basic punctuation).
        - Normalizing whitespace.
        - Converting to lowercase.

        Args:
            user_prompt (str): The user prompt to clean.

        Returns:
            str: The cleaned user prompt.
        """
        # Replace newline and tab characters with a space, remove unwanted characters, and normalize spaces
        clean_user_prompt = re.sub(
            r'\s+', ' ',
            re.sub(r'[^a-zA-Z0-9\s.,!?]', '', user_prompt.replace('\n', ' ').replace('\t', ' '))
        ).strip().lower()
        return clean_user_prompt

    def check_prompt_conditions_for_system_benchmark(self, user_prompt):
        """
        Check if the word "System Benchmark" is present and "System Benchmark CPC" (or its elaborated version) is absent in the user prompt.

        Args:
            user_prompt (str): The user prompt to check.

        Returns:
            bool: True if the conditions are met, False otherwise.
        """
        # Clean the user prompt
        clean_user_prompt = self.clean_prompt(user_prompt)

        # Check for "system benchmark"
        has_system_benchmark = "system benchmark" in clean_user_prompt

        # Check for "system benchmark cpc" or its elaborated version
        has_system_benchmark_cpc = any(
            phrase in clean_user_prompt
            for phrase in [
                "system benchmark cpc",
                "system benchmark cost per case"
            ]
        )
        # Return True if "system benchmark" is present and "system benchmark cpc" is absent
        return has_system_benchmark and not has_system_benchmark_cpc

    def find_appropriate_row_from_complex_query_sheet(self, df, prompt):
        # Clean the user prompt
        # prompt = re.sub(r'[^a-zA-Z0-9\s]', '', prompt).lower()
        # prompt = re.sub(r'\s+', ' ', re.sub(r'[^a-zA-Z0-9\s.,!?]', '', prompt.replace('\n', ' ').replace('\t', ' '))).strip().lower()

        prompt = self.clean_prompt(prompt)
        for index, row in df.iterrows():
            synonyms = [syn.strip().lower() for syn in row['synonyms'].split(',')]
            for syn in synonyms:
                if syn in prompt:
                    
                    # If the word "System Benchmark" exists and 
                    # "System Benchmark CPC" or its elaborated form is not found, the function returns True.
                    if self.check_prompt_conditions_for_system_benchmark(prompt):
                        continue  
                    return row['prompt']
        
        return None

    ## Updated Prompt ## With Potentital Savings Calculation
    def get_robot_analysis_prompt(self, table_to_use):
        """
        Returns the mandatory robotic vs. non-robotic analysis SQL template
        WITHOUT any time-frame/date filter.
        Now includes potential savings calculation for relevant questions.
        """
        robot_prompt = ''
        try:
            if table_to_use == self.table_info_from_db[self.table_info_from_db['table_for'] == 'Robot Assisted Procedure Analysis']['table_name'].values[0]:
                robot_prompt = '''# MANDATORY TEMPLATE: ROBOTIC PROCEDURE ANALYSIS

    ALWAYS USE THIS EXACT TEMPLATE FOR ANY ROBOTIC VS. NON-ROBOTIC PROCEDURE ANALYSIS

    ## REQUIRED FILTERS
    The following filters MUST be applied to ANY analysis involving robotic procedures:

    1. PATIENT ELIGIBILITY:
    - Include ONLY patients with ASA scores 1-3
    - SQL: WHERE emr_asa_rating NOT IN ('4 - CONSTANT THREAT TO LIFE', '5 - MORIBUND PATIENT', '6 - BRAIN DEAD PATIENT')

    2. ROBOTICS DATA QUALITY:
    - Exclude cases with unknown robotics status 
    - SQL: AND emr_robotics_y_n <> 'NOT AVAILABLE'

    3. PATIENT TYPE STANDARDIZATION:
    - Combine Emergency Cases with Outpatient
    - SQL: CASE WHEN emr_patient_type_bucket = 'EMERGENCY' THEN 'OUTPATIENT' ELSE emr_patient_type_bucket END AS patient_type

    4. PROCEDURE ELIGIBILITY:
    - Include ONLY procedures with BOTH robotic AND non-robotic cases
    - Each procedure MUST HAVE at least one robotic case AND at least one non-robotic case

    5. FACILITY ELIGIBILITY
    - Include only facilities that have at least one robotic case
    - Facilities with zero robotic cases are completely excluded.

    ## MANDATORY SQL IMPLEMENTATION
    WITH
    -- Step 0: Facilities that performed at least one robotic case
    robotics_facility_list AS (
        SELECT DISTINCT
            emr_facility_name
        FROM
            procedures_table           -- ← replace with your table name
        WHERE
            emr_robotics_y_n = 'Y'
            AND emr_asa_rating NOT IN (
                '4 - CONSTANT THREAT TO LIFE',
                '5 - MORIBUND PATIENT',
                '6 - BRAIN DEAD PATIENT'
            )
    ),
    -- Step 1: Procedures that have BOTH robotic and non-robotic cases
    eligible_procedures AS (
        SELECT
            emr_primary_procedure
        FROM
            procedures_table
        WHERE
            emr_asa_rating NOT IN (
                '4 - CONSTANT THREAT TO LIFE',
                '5 - MORIBUND PATIENT',
                '6 - BRAIN DEAD PATIENT'
            )
            AND emr_robotics_y_n <> 'NOT AVAILABLE'
        GROUP BY
            emr_primary_procedure
        HAVING
            SUM(CASE WHEN emr_robotics_y_n = 'Y' THEN 1 ELSE 0 END) > 0
            AND SUM(CASE WHEN emr_robotics_y_n = 'N' THEN 1 ELSE 0 END) > 0
    ),
    -- Step 2: Main filtered dataset
    filtered_robotic_data AS (
        SELECT
            p.*,
            CASE
                WHEN p.emr_patient_type_bucket = 'EMERGENCY' THEN 'OUTPATIENT'
                ELSE p.emr_patient_type_bucket
            END AS patient_type
        FROM
            procedures_table p
            JOIN eligible_procedures ep
                ON p.emr_primary_procedure = ep.emr_primary_procedure
            JOIN robotics_facility_list rfl
                ON p.emr_facility_name = rfl.emr_facility_name
        WHERE
            p.emr_asa_rating NOT IN (
                '4 - CONSTANT THREAT TO LIFE',
                '5 - MORIBUND PATIENT',
                '6 - BRAIN DEAD PATIENT'
            )
            AND p.emr_robotics_y_n <> 'NOT AVAILABLE'
    )
    -- Step 3: Final query
    SELECT
        frd.*,
        /* POTENTIAL SAVINGS CALCULATION */
        CASE
            WHEN AVG(CASE WHEN frd.emr_robotics_y_n = 'N' THEN frd.emr_total_acquisition_cost END) 
                > AVG(CASE WHEN frd.emr_robotics_y_n = 'Y' THEN frd.emr_total_acquisition_cost END)
            THEN ROUND(
                (
                    AVG(CASE WHEN frd.emr_robotics_y_n = 'N' THEN frd.emr_total_acquisition_cost END) 
                    - AVG(CASE WHEN frd.emr_robotics_y_n = 'Y' THEN frd.emr_total_acquisition_cost END)
                ) * COUNT(DISTINCT CASE WHEN frd.emr_robotics_y_n = 'N' THEN frd.emr_p_event END)
            ,1)
        END AS potential_savings
    FROM
        filtered_robotic_data frd;

    ## CRITICAL WARNING
    DO NOT MODIFY THESE CRITERIA UNDER ANY CIRCUMSTANCES. These filters are REQUIRED for all robotic procedure analyses to ensure data consistency and valid comparisons.

    ## IMPLEMENTATION NOTES
    1. Replace "procedures_table" with the actual table name in your database
    2. You may add additional columns or filters as needed for specific analyses
    3. You must NEVER remove or modify the base filters defined above
    4. All metrics (LOS, cost, complications, potential savings, etc.) must be calculated using this filtered dataset
    5. Return only the SQL query, and ensure it ends with a semicolon (‘;’).
    6. ** average_cost_per_case = SUM(emr_total_acquisition_cost) / COUNT(DISTINCT emr_p_event) ** 
    '''
        except:
            pass

        return robot_prompt

    def get_column_info_for_sql_generation(self, user_prompt_to_use, selected_table_text):
        """
        Collect the table and column context used to generate a SQL query.

        Args:
            user_prompt_to_use: The user question; passed to the embedding lookup and to the
                complex-question / comparison prompt lookups.
            selected_table_text: Text naming the selected table. The first entry of
                `self.table_info_from_db['table_name']` that contains it (as a literal
                substring) is used.

        Returns:
            tuple: `(table_to_use, filtered_df, best_match_columns, table_columns,
            column_info_from_knowledge_base, prompt_to_use_for_complex_question,
            columns_info_to_generate_sql)`. Seven `None` values are returned when no table
            matches `selected_table_text`.
        """
        ## Selected Table and Column Information to Use
        try:
            # regex=False: the text is a table name, not a pattern (e.g. '.' must not act as a wildcard).
            # na=False: rows with a missing table name simply do not match instead of breaking the mask.
            table_name_matches = self.table_info_from_db['table_name'].str.contains(
                selected_table_text, regex=False, na=False
            )
            table_to_use = self.table_info_from_db[table_name_matches]['table_name'].values.tolist()[0]
        except Exception:
            # Best-effort lookup: report the failure and hand back an all-None result.
            print(f'\nFailed to select a table. \nSelected Table Text Used: {selected_table_text}\n')
            return None, None, None, None, None, None, None

        filtered_df, best_match_columns, _ = self.generating_df_from_embeddings(
            user_prompt_to_use,
            self.df_embeddings,
            self.column_dictionary[table_to_use],
            self.data_dictionary_dataframe,
        )

        column_info_from_knowledge_base = ''
        prompt_to_use_for_complex_question = ''

        ## Final Column Information that will be used to generate sql query
        table_columns = self.table_info_from_db[self.table_info_from_db['table_name'] == table_to_use]['columns'].tolist()[0]

        ## Knowledge-base details (including synonyms) are only added when the embedding lookup matched columns
        if best_match_columns:
            column_info_from_knowledge_base = self.knowledge_base_function_for_sql_generation(filtered_df)

            ## Get Prompt for Robotic Analysis
            robot_analysis_prompt = self.get_robot_analysis_prompt(table_to_use)

            if robot_analysis_prompt != '':
                prompt_to_use_for_complex_question = robot_analysis_prompt

            else:  ## No robotic-analysis prompt: fall back to the complex-question sheet
                prompt_to_use_for_complex_question = self.find_appropriate_row_from_complex_query_sheet(self.question_df, user_prompt_to_use)

                # Ask the evaluation engine for a comparison-specific SQL template
                # (used when the question compares two or more metrics, e.g. "compare X and Y").
                sql_comparison_prompt = self.query_evaluation_engine.get_sql_template_if_applicable(user_prompt_to_use)
                if sql_comparison_prompt:
                    print("\n" + "="*60)
                    print("⚖️  COMPARISON MODE ACTIVATED: Question requires comparing two matrices.")
                    print("="*60 + "\n")

                    if prompt_to_use_for_complex_question:
                        # Keep both layers: comparison template first, then the base prompt.
                        prompt_to_use_for_complex_question = sql_comparison_prompt + "\n" + prompt_to_use_for_complex_question
                    else:
                        # No base prompt was found, so the comparison template stands alone.
                        prompt_to_use_for_complex_question = sql_comparison_prompt

        # Always assemble the final context here. Previously this was only assigned inside the
        # `if best_match_columns:` branch, so an empty match list raised UnboundLocalError at the
        # return statement. With no matches the context is just the table's column list.
        columns_info_to_generate_sql = str(table_columns) + column_info_from_knowledge_base
        if prompt_to_use_for_complex_question:  # Complex-question / comparison / robotic-analysis prompt available
            columns_info_to_generate_sql += prompt_to_use_for_complex_question

        return table_to_use, filtered_df, best_match_columns, table_columns, column_info_from_knowledge_base, prompt_to_use_for_complex_question, columns_info_to_generate_sql


# <mark >====== Preliminary Readings ====== </mark>

In [17]:
# ---------------------------------------------------------------------------------------------------------
# 1) DB connection -> table and column information
# ---------------------------------------------------------------------------------------------------------
information_from_db = DBInfoManager(connection_pool)
table_info_from_db, column_dictionary = information_from_db._initialize_table_data(TABLE_INFO)

# ---------------------------------------------------------------------------------------------------------
# 2) Data dictionary (knowledge base file) -> cleaned column-information DataFrame
# ---------------------------------------------------------------------------------------------------------
knowledge_base_reader = KnowledgeBaseReader(
    knowledge_base_directory=knowledge_base_directory,
    knowledge_base_file_name=knowledge_base_file_name,
)
data_dictionary_dataframe = knowledge_base_reader.get_cleaned_columns_info()

# ---------------------------------------------------------------------------------------------------------
# 3) Embeddings -> df_embeddings
#    A missing collection is embedded and saved; an existing one is loaded and updated.
# ---------------------------------------------------------------------------------------------------------
embedding_manager = EmbeddingManager(
    embedding_function=embedding_function,
    persist_directory=persist_directory,
    collection_name=collection_name,
)

if not embedding_manager.is_collection_available():
    print(f'Embedding and Saving Function -> "save_to_persist_directory". \nFolder Name: {collection_name}\npersist_directory: {persist_directory}\n')
    df_embeddings = embedding_manager.save_to_persist_directory(data_dictionary_dataframe)
else:
    print(f'Embedding Function Running -> "load_and_update_embeddings".\n')
    df_embeddings = embedding_manager.load_and_update_embeddings(data_dictionary_dataframe)

# ---------------------------------------------------------------------------------------------------------
# 4) Helper for complex questions
#    Retrieves the column information and metadata relevant to the user's prompt and the selected table,
#    so that an appropriate SQL query can be generated.
# ---------------------------------------------------------------------------------------------------------
embedding_task_helper = EmbeddingTaskHelper(
    embedding_function,
    df_embeddings,
    column_dictionary,
    data_dictionary_dataframe,
    table_info_from_db,
    question_df,
)

Embedding Function Running -> "load_and_update_embeddings".


	# Documents Removed:
	1: emr_month_yr | Month Yr, Discharge Date, Release Date, Hospital Discharge Date, Patient Discharge Date, Exit Date, Discharge Month, Discharge Year, Month, Year | Month and Year when a patient is officially discharged from the hospital or healthcare facility.

	# Documents Added:
	1: emr_month_yr | Month Yr | Month and Year when a patient is officially discharged from the hospital or healthcare facility.

Knowledge Base Updated and saved to: ./Data_For_LLM/knowledge_base_ahp_qa\knowledge_base_ahp_qa.pkl


# TableSelectionChain

In [18]:
class TableSelectionChain:
    """
    Builds the table-selection prompt and the LLM chain that picks the table
    best suited to answer a user question.
    """

    # (index into the table-name list, rule text). A rule is emitted only when the
    # referenced table exists, so a shorter table list does not raise IndexError.
    _SPECIAL_INSTRUCTIONS = (
        (2, "*For any question related to ras, robot-assisted procedures, non-robotic procedures, or robot usage comparisons, always use the table: "),
        (1, "*For Questions related to UNSPSC, Contract Category, Manufacturer, Supplier, Manufacturer Part Number - Must select this table: "),
        (1, "*For Questions related to 'Benchmark' (calculated for UNSPSC, Contract Category, Manufacturer, Supplier, Manufacturer Part Number) - Select this table: "),
        (0, "*For Questions related to 'Benchmark' (calculated for other than UNSPSC, Contract Category, Manufacturer, Supplier, Manufacturer Part Number), 'Benchmark CPC', 'Percentile' - Select this table: "),
        (1, "*For Questions related to 'Savings & Opportunities' (calculated for UNSPSC, Contract Category, Manufacturer, Supplier, Manufacturer Part Number) - Select this table: "),
        (0, "*For Questions related to 'Savings & Opportunities'  (calculated for other than UNSPSC, Contract Category, Manufacturer, Supplier, Manufacturer Part Number)- Must Select this table: "),
        (1, "*For Questions related to 'Price Parity' - Must Select this table: "),
    )

    def __init__(self, df_embeddings, embedding_handler_function, table_info_from_db, column_dictionary, data_dictionary_dataframe):
        """
        Store the data needed to build the table-selection prompt.

        Args:
            df_embeddings: Embeddings passed through to `generating_df_from_embeddings`.
            embedding_handler_function: Object providing `generating_df_from_embeddings` and
                `helper_function_for_table_selection`.
            table_info_from_db (pd.DataFrame): Table metadata; the columns `table_name`,
                `table_for` and `columns` are read. A copy is kept so the caller's frame is not modified.
            column_dictionary (dict): Maps table names to the columns used for the embedding lookup.
            data_dictionary_dataframe (pd.DataFrame): Passed through to `generating_df_from_embeddings`.
        """
        self.df_embeddings = df_embeddings
        self.embedding_handler_function = embedding_handler_function
        self.table_info_from_db = table_info_from_db.copy()  # To avoid modifying the original DataFrame
        self.column_dictionary = column_dictionary
        self.data_dictionary_dataframe = data_dictionary_dataframe

    @staticmethod
    def _escape_template_braces(text: str) -> str:
        """Double every brace so f-string style prompt templates treat the text as literal."""
        return text.replace('{', '{{').replace('}', '}}')

    @staticmethod
    def _parse_columns(value):
        """Turn a stringified list such as "['a', 'b']" into ['a', 'b']; other values pass through."""
        if isinstance(value, str):
            return value.strip("[]").replace("'", "").split(", ")
        return value

    def _describe_table(self, row, user_prompt: str) -> str:
        """Build the prompt paragraph describing one table row for the given user question."""
        table_name = row['table_name']

        if table_name in self.column_dictionary:
            # Look up the columns of this table that are most relevant to the question.
            filtered_df, _, _ = self.embedding_handler_function.generating_df_from_embeddings(
                user_prompt=user_prompt,
                df_embeddings=self.df_embeddings,
                list_of_columns=self.column_dictionary[table_name],
                data_dictionary_dataframe=self.data_dictionary_dataframe
            )
            column_values_for_table_selection = self.embedding_handler_function.helper_function_for_table_selection(filtered_df)

            if column_values_for_table_selection:
                specific_table_info = f"\t\tTABLE_NAME: '{table_name}'\n"
                return (
                    f"\n\t- Information on {row['table_for']} can be found inside the table '{table_name}' "
                    f"and most relevant columns to user's query are given below:\n{specific_table_info}\n"
                    f"{column_values_for_table_selection}"
                )

        # Fallback: no column dictionary entry, or no relevant column details were produced.
        return (
            f"\n\t- Information on {row['table_for']} can be found inside the table '{table_name}' "
            f"with columns: {', '.join(row['columns'])}"
        )

    def refined_langchain_table_selection_prompt(self, user_prompt: str) -> str:
        """
        Construct the system message used to select the most appropriate table for a question.

        Args:
            user_prompt (str): The query from the user.

        Returns:
            str: The constructed prompt (leading/trailing whitespace stripped).
        """
        # Convert 'columns' from string representation of list to actual list
        # (already-parsed values pass through, so repeated calls are safe).
        self.table_info_from_db['columns'] = self.table_info_from_db['columns'].apply(self._parse_columns)

        tables_description = "\n".join(
            self._describe_table(row, user_prompt)
            for _, row in self.table_info_from_db.iterrows()
        )

        table_names = self.table_info_from_db['table_name'].tolist()
        priority_list = ''.join(
            f'\n\t{position}: {table}' for position, table in enumerate(table_names, start=1)
        )

        # Only reference tables that actually exist in the configured list.
        special_instructions = ''.join(
            f"\n\t- {rule_text}{table_names[table_index]}"
            for table_index, rule_text in self._SPECIAL_INSTRUCTIONS
            if table_index < len(table_names)
        )

        # Construct the prompt
        prompt = (
            "You are an expert in identifying the most relevant table for a given question. Based on the provided table information, "
            "your task is to choose the most appropriate table name that can answer the user's question:\n"
            "\n## *Table Information*:\n--------------------------------------------\n"
            f"{tables_description}\n---------------------------------------------\n"
            f"\n## Prioritize this table order for questions unrelated to cost, spend: {priority_list}"
            f"\n\n-------------\n** Special Instruction to follow **\n-------------\n"
            f"{special_instructions}"
            "\n\n## Return only the Full Table Name.\n## Response Template:\nTABLE_NAME\n"
        )

        return prompt.strip()

    def generate_table_selection_chain(self, user_prompt: str):
        """
        Create the chain that selects the most appropriate table for the user's query.

        Args:
            user_prompt (str): The query from the user.

        Returns:
            The composed `prompt | llm | StrOutputParser()` pipeline; invoke it with
            `{'human_prompt': ...}`.
        """
        # Initialize the language model
        # NOTE(review): temperature=1 makes the selection non-deterministic - confirm this is
        # intended / required for this model.
        llm = ChatOpenAI(model='gpt-5-mini', temperature=1)

        # The system message embeds data-driven text (column names/values). Escape braces so
        # the template parser does not mistake them for input variables.
        system_message = self._escape_template_braces(
            self.refined_langchain_table_selection_prompt(user_prompt)
        )

        # Create prompt templates
        system_message_prompt = SystemMessagePromptTemplate.from_template(system_message)
        human_message_prompt = HumanMessagePromptTemplate.from_template(
            "Provide the full table name that is most suitable to answer the human question: {human_prompt}"
        )

        # Combine prompts into a chat prompt template
        chat_prompt_template = ChatPromptTemplate.from_messages([
            system_message_prompt,
            human_message_prompt
        ])

        # Create a basic single chain
        table_selection_chain = chat_prompt_template | llm | StrOutputParser()

        return table_selection_chain

    def get_table_name(self, user_prompt):
        """
        Run the table-selection chain for `user_prompt`.

        Returns:
            str: The model's answer with any ``` fences removed and whitespace stripped.
        """
        table_selection_chain = self.generate_table_selection_chain(user_prompt)
        selected_table_text = table_selection_chain.invoke({'human_prompt': user_prompt}).replace('```','').strip()

        return selected_table_text

# Notebook-level TableSelectionChain instance, built from the embeddings, helper object and table metadata defined above.
table_selection_object = TableSelectionChain(df_embeddings, embedding_task_helper, table_info_from_db, column_dictionary, data_dictionary_dataframe)

In [19]:
# ################### ONLY FOR TESTING #################

# # test_user_prompt = 'Whats the total procedure count in ras'
# # test_user_prompt = "whats the total number of Robotic cases in 2024, include only ASA ratings 1, 2 and 3"
# # test_user_prompt = 'Compare facilities based on cost per case for @primary procedure 02L73DK-OCCLUSION OF LAA WITH INTRALUM DEV, PERC APPROACH.'

# print(f'user_prompt: {test_user_prompt}\n')
# # print(table_selection_object.refined_langchain_table_selection_prompt(test_user_prompt))
# print('\n----------------------------------------\n')
# print(table_selection_object.get_table_name(test_user_prompt))

In [20]:
# for run in range(5):
#     print(table_selection_object.get_table_name(user_prompt))

# MemoryChainManager

### [NOT IN USE]

In [21]:
# class InMemoryHistory(BaseChatMessageHistory, BaseModel):
#     """In-memory implementation of chat message history."""

#     messages: List[BaseMessage] = Field(default_factory=list)

#     def add_messages(self, messages: List[BaseMessage]) -> None:
#         """Add a list of messages to the store, maintaining a maximum of 12 messages."""
#         if len(self.messages) >= 12:
#             self.messages = self.messages[2:]
#         self.messages.extend(messages)

#     def clear(self) -> None:
#         """Clear all messages from the history."""
#         self.messages = []

#     def remove_last_two_if_human_ai(self) -> None:
#         """Remove the last two messages if they are a HumanMessage followed by an AIMessage."""
#         if len(self.messages) >= 2:
#             if isinstance(self.messages[-2], HumanMessage) and isinstance(self.messages[-1], AIMessage):
#                 self.messages = self.messages[:-2]

# class MemoryChainManager:
#     def __init__(self):
#         """
#         Initializes the MemoryChainManager with a store for session histories.
#         """
#         self.store: Dict[str, InMemoryHistory] = {}

#     def get_by_session_id(self, session_id: str) -> InMemoryHistory:
#         """
#         Retrieves the message history for a given session ID, creating a new one if it doesn't exist.

#         Args:
#             session_id (str): The session ID of the user.

#         Returns:
#             InMemoryHistory: The message history associated with the session ID.
#         """
#         if session_id not in self.store:
#             self.store[session_id] = InMemoryHistory()
#             print(f"\n===========\n**New User Found: {session_id}.\n\nAvailable User Ids in the Store: {(' | ').join(list(self.store.keys()))}\n===========\n")
#         return self.store[session_id]

#     def remove_user_by_session_id(self, session_id: str) -> dict:
#         """
#         Removes a user and their message history from the store.

#         Args:
#             session_id (str): The session ID of the user to remove.

#         Returns:
#             dict: A dictionary containing the status and message of the operation.
#         """
#         result = {}
#         if session_id in self.store:
#             del self.store[session_id]
#             remaining_users = ', '.join(self.store.keys())
#             result['status'] = 'success'
#             result['message'] = f'User Removed: "{session_id}".'
#             result['remaining_users_count'] = len(self.store)
#             result['remaining_users'] = remaining_users
#         else:
#             remaining_users = ', '.join(self.store.keys())
#             result['status'] = 'error'
#             result['message'] = f'User Not Found: "{session_id}".'
#             result['remaining_users_count'] = len(self.store)
#             result['remaining_users'] = remaining_users
#         return result

#     def refinement_chain(self, user_prompt: str, user_id: str) -> str:
#         """
#         Refines the user's prompt using memory to make it clear and ready for SQL generation.

#         Args:
#             user_prompt (str): The user's original query.
#             user_id (str): The session ID of the user.

#         Returns:
#             str: The refined query.
#         """
#         llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)

#         today_date = datetime.now()
#         formatted_date = today_date.strftime('%Y-%m-%d')

#         refinement_system_message = f'''
# You are a query refinement assistant. Your task is to refine user queries using relevant memory, ensuring they are clear, concise, and ready for SQL generation.

# ### Instructions:
# 1. Include only essential details from memory (e.g., {{time_period}}, {{entity}}, {{metric}}) to make the query precise and fully formed for SQL generation.
# 2. Ensure that both the relevant parameter (e.g., {{metric}}) and the time period (e.g., {{time_period}}) are included in the refined query if available in memory.
# 3. **If the query contains an ambiguous time reference (e.g., "this month", "this year") and no specific time period is available in memory, replace the ambiguous reference with today's date.** Use the following format:
#     - Today's date is: {formatted_date}. It is the {today_date.day} day of {today_date.strftime('%B')} in the year {today_date.year}.
# 4. **If no refinement is needed, return the query as is.**
# 5. Provide only the refined query, without any explanation or description.
# 6. If the user's query refers to information not covered in the memory, refine the query to clearly state the original intent without relying on memory details.
# 7. **Output only the refined query**. Do not include any answers, explanations, or other text.

# ### Output Format:
# Only the refined query, formatted as:
# {{refined_query}}
# '''

#         # Create the prompt template
#         refinement_prompt = ChatPromptTemplate.from_messages([
#             MessagesPlaceholder(variable_name="history"),
#             ("system", refinement_system_message),
#             ("human", "{user_prompt}")
#         ])

#         # Create the chain with history management
#         refinement_chain = refinement_prompt | llm
#         refinement_chain_with_history = RunnableWithMessageHistory(
#             refinement_chain,
#             self.get_by_session_id,
#             input_messages_key="user_prompt",
#             history_messages_key="history"
#         )

#         # Invoke the chain
#         refinement_chain_result = refinement_chain_with_history.invoke(
#             {'user_prompt': user_prompt},
#             config={"configurable": {"session_id": user_id}}
#         ).content

#         # Remove the last two messages if necessary
#         self.store[user_id].remove_last_two_if_human_ai()

#         return refinement_chain_result

#     def decision_chain(self, user_prompt: str, user_id: str) -> str:
#         """
#         Determines whether the user's query can be answered from memory.

#         Args:
#             user_prompt (str): The user's query.
#             user_id (str): The session ID of the user.

#         Returns:
#             str: The answer from memory or '0' if not available.
#         """
#         # llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)
#         llm = ChatOpenAI(model='gpt-4.1-2025-04-14', temperature=0)
        
#         decision_system_message = '''
# You are an assistant for the 'Supply Copia' company. Your task is to answer questions based on previous interactions and memory in a polite and human-readable tone.

# ### Instructions:
# 1. If the User's query is not available from past interactions or memory, respond with '0' only.
# 2. If the User's query is available from past interactions or memory, provide the answer in a clear, human-readable format.
# 3. If you are uncertain about the answer or cannot find relevant data, respond with '0' only.
# 4. Do not generate or fabricate any information (no hallucinations). Always prioritize accuracy.
# 5. If the required data is not retrievable from memory, always return '0' without adding any explanations or additional information.

# ### Output Format:
# - If the answer is available: Provide it in natural language.
# - If not available: '0'
# '''

#         # Create the prompt template
#         decision_prompt = ChatPromptTemplate.from_messages([
#             MessagesPlaceholder(variable_name="history"),
#             ("system", decision_system_message),
#             ("human", "{question}")
#         ])

#         # Create the chain with history management
#         decision_chain = decision_prompt | llm
#         decision_chain_with_history = RunnableWithMessageHistory(
#             decision_chain,
#             self.get_by_session_id,
#             input_messages_key="question",
#             history_messages_key="history"
#         )

#         # Invoke the chain
#         decision_chain_result = decision_chain_with_history.invoke(
#             {"question": user_prompt},
#             config={"configurable": {"session_id": user_id}}
#         ).content

#         # If the result is '0', remove the last two messages
#         if decision_chain_result == '0':
#             self.store[user_id].remove_last_two_if_human_ai()

#         return decision_chain_result

#     def combined_chain(self, user_prompt: str, user_id: str) -> (str, str):
#         """
#         Combines the decision and refinement chains to process the user's query.

#         Args:
#             user_prompt (str): The user's query.
#             user_id (str): The session ID of the user.

#         Returns:
#             tuple:
#                 - str: The refined query to be used.
#                 - str or None: The result from memory if available, otherwise None.
#         """
#         user_prompt_to_use = user_prompt
#         result_from_memory = None

#         # Check if there is any message history
#         if self.store[user_id].messages:
#             # Use the decision chain to check if the answer is available in memory
#             decision_result = self.decision_chain(user_prompt, user_id)

#             if decision_result == '0':
#                 # If not available, refine the query
#                 user_prompt_to_use = self.refinement_chain(user_prompt, user_id)
#                 print(f'Refined Prompt: {user_prompt_to_use}\n')
#             else:
#                 # If available, use the result from memory
#                 result_from_memory = decision_result

#         return user_prompt_to_use, result_from_memory

# memory_manager = MemoryChainManager()

In [22]:
# H: Whats the total cost monthwise for the year 2024? 
# A: List of month and total cost associated on each month 

# H: Whats the total cost found on June? 
# A: In June, the cost is .... 


# H: Whats the total number of encounters on the second month from the above list? 
# Intermediary Work: What's the total number of encounters in February 2024?

### [IN USE]

In [23]:
class InMemoryHistory(BaseChatMessageHistory, BaseModel):
    """Chat message history kept in a plain in-process list."""

    messages: List[BaseMessage] = Field(default_factory=list)

    def add_messages(self, messages: List[BaseMessage]) -> None:
        """Append `messages`; when 12 or more are already stored, drop the two oldest first."""
        history_is_full = len(self.messages) >= 12
        if history_is_full:
            self.messages = self.messages[2:]
        self.messages.extend(messages)

    def clear(self) -> None:
        """Remove every stored message."""
        self.messages.clear()

    def remove_last_two_if_human_ai(self) -> None:
        """Drop the final two messages when they are a HumanMessage followed by an AIMessage."""
        if len(self.messages) < 2:
            return
        ends_with_exchange = (
            isinstance(self.messages[-2], HumanMessage)
            and isinstance(self.messages[-1], AIMessage)
        )
        if ends_with_exchange:
            self.messages = self.messages[:-2]


class MemoryChainManager:
    """Holds one chat history per session id and runs the LLM chains that read it.

    ``store`` maps a session/user id to its ``InMemoryHistory``. The public
    methods either inspect that history with regular expressions or run an LLM
    chain wrapped in ``RunnableWithMessageHistory``.
    """

    # Month names are spelled out here (instead of read from ``calendar.month_name``)
    # so that the lookup does not depend on the process locale.
    _MONTH_NAMES = (
        "January", "February", "March", "April", "May", "June",
        "July", "August", "September", "October", "November", "December",
    )
    # Numeric "MM-YYYY" stamps such as "05-2023"; only years 20xx are recognised.
    _NUMERIC_MONTH_YEAR_RE = re.compile(r"(0[1-9]|1[0-2])-(20\d{2})")
    # "<Month name> YYYY", e.g. "October 2024", matched case-insensitively.
    _NAMED_MONTH_YEAR_RE = re.compile(
        r"(" + "|".join(_MONTH_NAMES) + r") (\d{4})",
        re.IGNORECASE,
    )
    # Text that may follow an entity keyword: optional separators, then the captured value.
    _ENTITY_VALUE_SUFFIX = r"[:\s-]*([A-Z ,.'\-]+)"

    def __init__(self):
        self.store: Dict[str, "InMemoryHistory"] = {}

    def get_by_session_id(self, session_id: str) -> "InMemoryHistory":
        """Return the history for ``session_id``, creating (and announcing) it on first use."""
        if session_id not in self.store:
            self.store[session_id] = InMemoryHistory()
            print(f"\n===========\n**New User Found: {session_id}.\n===========\n")
        return self.store[session_id]

    def remove_user_by_session_id(self, session_id: str) -> dict:
        """
        Removes a user and their message history from the store.

        Args:
            session_id (str): The session ID of the user to remove.

        Returns:
            dict: ``status`` ('success' or 'error'), ``message``, ``remaining_users_count``
            and ``remaining_users`` (comma-separated ids left in the store).
        """
        found = session_id in self.store
        if found:
            del self.store[session_id]

        return {
            'status': 'success' if found else 'error',
            'message': f'User Removed: "{session_id}".' if found else f'User Not Found: "{session_id}".',
            'remaining_users_count': len(self.store),
            'remaining_users': ', '.join(self.store.keys()),
        }

    def get_all_histories_grouped(self) -> Dict[str, List[Dict[str, str]]]:
        """
        Return all session_ids with their associated messages,
        grouped so each human prompt and its AI reply are in the same dictionary element.

        A human message that is followed by another human message (no AI reply in
        between) is dropped; an AI message without a preceding human message is ignored.

        Returns:
            Dict[str, List[Dict[str, str]]]: {
                "session_id": [
                    {"human": "prompt text", "ai": "response text"},
                    ...
                ]
            }
        """
        unanswered = object()  # sentinel: no human prompt is currently waiting for a reply
        all_histories = {}
        for session_id, history in self.store.items():
            pairs = []
            pending_prompt = unanswered
            for msg in history.messages:
                if isinstance(msg, HumanMessage):
                    pending_prompt = msg.content
                elif isinstance(msg, AIMessage) and pending_prompt is not unanswered:
                    pairs.append({"human": pending_prompt, "ai": msg.content})
                    pending_prompt = unanswered
            all_histories[session_id] = pairs
        return all_histories

    def _recent_text_messages(self, user_id: str):
        """Yield the text of human/AI messages for ``user_id``, newest first.

        Unknown users yield nothing instead of raising ``KeyError``; messages whose
        content is not a plain string are skipped because they cannot be regex-searched.
        """
        history = self.store.get(user_id)
        if history is None:
            return
        for msg in reversed(history.messages):
            if isinstance(msg, (HumanMessage, AIMessage)) and isinstance(msg.content, str):
                yield msg.content

    def extract_latest_explicit_date_from_history(self, user_id: str) -> Optional[str]:
        """Return the most recent explicit month/year in the history as "MM-YYYY".

        Messages are scanned newest first. Within a message a numeric "MM-YYYY"
        stamp wins over a "<Month name> YYYY" mention.

        Args:
            user_id (str): Session id whose history is searched.

        Returns:
            Optional[str]: "MM-YYYY", or None when no date is found or the user is unknown.
        """
        for content in self._recent_text_messages(user_id):
            numeric = self._NUMERIC_MONTH_YEAR_RE.search(content)
            if numeric:
                return numeric.group(0)

            named = self._NAMED_MONTH_YEAR_RE.search(content)
            if named:
                month_num = self._MONTH_NAMES.index(named.group(1).capitalize()) + 1
                return f"{month_num:02}-{named.group(2)}"
        return None

    def extract_latest_entity_from_history(self, user_id: str, entity_keywords: List[str]) -> Optional[str]:
        """Return the value that follows the first matching keyword in the newest message.

        Args:
            user_id (str): Session id whose history is searched.
            entity_keywords (List[str]): Literal keywords (e.g. "physician"); they are
                escaped, so regex metacharacters in a keyword are matched verbatim.

        Returns:
            Optional[str]: The stripped captured text, or None when nothing usable is found.
        """
        for content in self._recent_text_messages(user_id):
            for keyword in entity_keywords:
                pattern = re.escape(keyword) + self._ENTITY_VALUE_SUFFIX
                for match in re.finditer(pattern, content, re.IGNORECASE):
                    entity = match.group(1).strip()
                    # A capture made only of separators/spaces is not an entity; keep looking.
                    if entity:
                        return entity
        return None

    def refinement_chain(self, user_prompt: str, user_id: str) -> str:
        """Rewrite ``user_prompt`` into a self-contained query using the session history.

        The refinement exchange itself is removed from the history afterwards so
        that only real question/answer pairs are remembered.

        Args:
            user_prompt (str): The user's raw question.
            user_id (str): Session id whose history provides context.

        Returns:
            str: The refined query produced by the LLM.
        """
        llm = ChatOpenAI(model='gpt-4.1', temperature=0)

        fallback_date = datetime.now().strftime('%m-%Y')
        latest_date = self.extract_latest_explicit_date_from_history(user_id) or fallback_date

        system_instructions = f"""You are a query refinement assistant. Your job is to refine vague or incomplete user questions using historical context to form a complete, executable query for SQL generation.

Guidelines:
1. When the user uses vague time references like \"this month\", \"this year\", \"last quarter\", resolve them only if:
   a. The question appears to involve a time-dependent metric (e.g., cost, trend, encounters).
   b. AND there is a clear, recent time reference (e.g., \"October 2024\") in memory.

2. DO NOT carry time references from previous memory if the current user query shifts focus to a new entity (e.g., a physician, market, or facility) — unless the time is explicitly stated again.

3. When vague entity references like \"this physician\", \"this market\", or \"this facility\" are used, replace them using the most recent specific mention in memory:
   - Example: \"this physician\" → \"ERTEM, FURKAN U\" if that was the last mentioned physician.

4. For open-ended queries like:
   - \"Which month had the highest encounters?\"
   - \"Show supply cost trend over time\"
   Do NOT inject any default or historical date — preserve the open nature.

5. If no context is found and fallback is needed (e.g., \"this month\" without prior reference), use: {latest_date}

6. If the input question is already fully formed, return it unchanged.

7. Do NOT generate explanations, comments, or extra formatting. Only return the refined query.

8. Clarify Entity Types: If a user mentions a specific value (e.g., a name or location) without specifying its entity type (like market, facility, physician, etc.), prepend the appropriate entity label based on context.
Example: “What is the total cost for Horizon West?” → “What is the total cost for market Horizon West?”

Output Format:
refined_query
"""

        refinement_prompt = ChatPromptTemplate.from_messages([
            MessagesPlaceholder(variable_name="history"),
            ("system", system_instructions),
            ("human", "{user_prompt}")
        ])

        chain = refinement_prompt | llm
        with_history = RunnableWithMessageHistory(
            chain,
            self.get_by_session_id,
            input_messages_key="user_prompt",
            history_messages_key="history"
        )

        result = with_history.invoke(
            {"user_prompt": user_prompt},
            config={"configurable": {"session_id": user_id}}
        ).content

        self.store[user_id].remove_last_two_if_human_ai()
        return result

    def decision_chain(self, user_prompt: str, user_id: str) -> str:
        """Ask the LLM whether the question can be answered from the session history.

        Args:
            user_prompt (str): The user's question.
            user_id (str): Session id whose history is consulted.

        Returns:
            str: '0' when memory cannot answer (the exchange is then removed from the
            history), otherwise the natural-language answer, which stays in the history.
        """
        llm = ChatOpenAI(model='gpt-4o', temperature=0)

        system_message = '''
You are an assistant for the 'Supply Copia' company. Your task is to answer questions based on previous interactions and memory in a polite and human-readable tone.

Instructions:
1. If the User's query is not available from past interactions or memory, respond with '0' only.
2. If the User's query is available from past interactions or memory, provide the answer in a clear, human-readable format.
3. If you are uncertain about the answer or cannot find relevant data, respond with '0' only.
4. Do not generate or fabricate any information (no hallucinations). Always prioritize accuracy.
5. If the required data is not retrievable from memory, always return '0' without adding any explanations or additional information.

Output Format:
- If the answer is available: Provide it in natural language.
- If not available: '0'
'''

        prompt = ChatPromptTemplate.from_messages([
            MessagesPlaceholder(variable_name="history"),
            ("system", system_message),
            ("human", "{question}")
        ])

        chain = prompt | llm
        with_history = RunnableWithMessageHistory(
            chain,
            self.get_by_session_id,
            input_messages_key="question",
            history_messages_key="history"
        )

        result = with_history.invoke(
            {"question": user_prompt},
            config={"configurable": {"session_id": user_id}}
        ).content

        # The prompt writes the sentinel as '0' (with quotes), so accept the model
        # echoing the quotes or padding it with whitespace.
        if isinstance(result, str) and result.strip().strip("'\"") == '0':
            result = '0'

        if result == '0':
            self.store[user_id].remove_last_two_if_human_ai()

        return result

    def combined_chain(self, user_prompt: str, user_id: str):
        """Answer from memory when possible, otherwise return a refined query.

        Users without a stored history (including ids never seen before) skip both
        LLM calls and get their prompt back unchanged.

        Args:
            user_prompt (str): The user's question.
            user_id (str): Session id whose history is consulted.

        Returns:
            tuple: ``(query_to_use, answer_from_memory)`` where the second item is
            None unless the decision chain produced an answer.
        """
        user_prompt_to_use = user_prompt
        result_from_memory = None

        history = self.store.get(user_id)
        if history is not None and history.messages:
            decision = self.decision_chain(user_prompt, user_id)
            if decision == '0':
                user_prompt_to_use = self.refinement_chain(user_prompt, user_id)
                print(f'[Refined Query]: {user_prompt_to_use}')
            else:
                result_from_memory = decision

        return user_prompt_to_use, result_from_memory


In [24]:
# # # >>>>>>>>>>>>>>>>>>>>>>>>  Memory Chain | combined_memory_chain_function <<<<<<<<<<<<<<<<<<<<<<<
# user_prompt_to_use, result_from_memory = memory_manager.combined_chain(user_prompt, user_id)
# prompt_to_save = user_prompt
# if user_prompt != user_prompt_to_use:
#     prompt_to_save = f"{user_prompt} [REFINED QUERY: ({user_prompt_to_use})]"
# # # >>>>>>>>>>>>>>>>>>>>>>>>  Memory Chain | combined_memory_chain_function <<<<<<<<<<<<<<<<<<<<<<<

#  QueryTypeClassifier | General or Specific ??

In [25]:
class QueryTypeClassifier:
    """Classifies user queries with an LLM and words replies for off-domain questions.

    Two chains are built once in ``__init__`` (query-type decision and
    out-of-context check); the out-of-scope reply chain is built per call.
    """

    _DECISION_SYSTEM_MESSAGE = '''# System Instructions:
This system evaluates the user's query and classifies it into one of two categories:

1. General Knowledge Query (Return '1'):
    - These queries seek conceptual understanding or general explanations.
    - Typically answerable without access to structured or tabular data.
    - Examples include: "What is blood utilization?" or "Explain how cost-saving works in healthcare."

2. Specific Data Query (Return '2'):
    - These queries involve comparisons, trends, metrics, differences, aggregations, or numerical insights.
    - Typically require access to structured data, such as databases, tables, or performance metrics.
    - Keywords like "how many", "difference", "compare", "utilization across", "average", "total", or references to specific dimensions (e.g., by physician, by market, over time) indicate a specific data query.
    - Questions asking *"how something varies or differs across categories"* (e.g., departments, service lines, markets, timeframes) are **always** specific data queries.

# Special Notes:
- If no question is provided or the query is clearly general knowledge, return '1'.
- Default to '2' when there is any data-specific context, even if the question seems open-ended.

# Immediate Instruction to LLM:
- Analyze the query content precisely.
- Return:
    - '1' for general knowledge queries.
    - '2' for specific data queries.
'''

    _OUT_OF_CONTEXT_SYSTEM_MESSAGE = '''# Task:
Evaluate whether the following general knowledge query (type '1') is still relevant to the same domain as our specific data queries.

# Reference Domain:
Our system handles queries related to healthcare operations, clinical procedures, outcomes, cost/utilization analysis, provider performance, and product/vendor data.

# Response:
- Return '1' → The question is general but still aligned with the above domain.
- Return '0' → The question is unrelated (e.g., about food, space, jokes, or geography).

Return only '0' or '1'. No explanation.
'''

    _OUT_OF_SCOPE_SYSTEM_MESSAGE = """You are an assistant "Ask the Bee" of "Supply Copia" designed to answer questions related to healthcare analytics, cost, quality, clinical performance, and operational metrics.

    If a user's question is out of scope (e.g., unrelated to healthcare, clinical data, or medical analysis), respond respectfully with a friendly message that:

    1. Acknowledges the user’s query.
    2. Explains briefly and kindly that it’s outside the assistant’s domain.
    3. Encourages them to ask something within the assistant’s expertise.

    Do not mention the word “out of scope” or say "I cannot help with that" in a negative tone.
    Do not hallucinate an answer to the actual query.
    Just stay polite and helpful.
    """

    def __init__(self):
        """
        Creates the language model and the two reusable classification chains.
        """
        self.llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)

        self.decision_chain = self._create_decision_chain()
        self.out_of_context_chain = self._create_out_of_context_chain()

    def _build_chain(self, system_message: str, human_template: str):
        """Assemble ``system + human`` prompt -> LLM -> plain-string output."""
        chat_prompt_template = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(system_message),
            HumanMessagePromptTemplate.from_template(human_template)
        ])
        return chat_prompt_template | self.llm | StrOutputParser()

    def _create_decision_chain(self):
        """
        Builds the chain that labels a query as general knowledge or specific data.
        """
        return self._build_chain(self._DECISION_SYSTEM_MESSAGE, "{user_query}")

    def _create_out_of_context_chain(self):
        """
        Builds the chain that checks whether a general-knowledge query is still in-domain.
        """
        return self._build_chain(self._OUT_OF_CONTEXT_SYSTEM_MESSAGE, "{user_query}")

    def classify_query(self, user_query: str) -> str:
        """
        Runs the decision chain on ``user_query``.

        Returns:
            str: The chain output with surrounding whitespace removed; the prompt asks
            for '1' (general knowledge) or '2' (specific data).
        """
        label = self.decision_chain.invoke({"user_query": user_query})
        return label.strip()

    def is_out_of_context_general_knowledge(self, user_query: str) -> str:
        """
        Runs the out-of-context chain on ``user_query``.

        Returns:
            str: The chain output with surrounding whitespace removed; the prompt asks
            for '0' (out of context) or '1' (contextually relevant).
        """
        verdict = self.out_of_context_chain.invoke({"user_query": user_query})
        return verdict.strip()

    def generate_out_of_scope_response_with_llm(self, user_prompt: str) -> str:
        """
        Asks the LLM for a polite reply to a question that falls outside the assistant's domain.

        Args:
            user_prompt (str): The user's original query.

        Returns:
            str: The LLM's reply with surrounding whitespace removed.
        """
        chain = self._build_chain(self._OUT_OF_SCOPE_SYSTEM_MESSAGE, "{user_prompt}")
        return chain.invoke({"user_prompt": user_prompt}).strip()

# Module-level instance; constructing it creates the LLM client and both classification chains.
query_type_classifier = QueryTypeClassifier()

In [26]:
# query_type = query_type_classifier.classify_query(user_prompt_to_use)
# print(f'{query_type}: {user_prompt_to_use}')

# # If it's a general question, then identify whether it's out of context or not
# if query_type == '1':
#     context_check = query_type_classifier.is_out_of_context_general_knowledge(user_prompt_to_use)
#     print("Out of Context Check:", context_check)

#     if context_check == '0':
#         response = query_type_classifier.generate_out_of_scope_response_with_llm(user_prompt_to_use)
#         print(response)

# AnalysisChainManager

In [27]:
#### ------------------------------------------------ ## ------------------------------------------------ ####
#### Analysis Chain - To Analyze the Data | Analysis Chain - To Analyze the Data
#### ------------------------------------------------ ## ------------------------------------------------ ####

class AnalysisChainManager:
    """Builds and runs the LLM prompt that narrates a query result held in a DataFrame."""

    # Only this many leading rows of the DataFrame are serialised into the prompt.
    MAX_ROWS_IN_PROMPT = 50

    def __init__(self):
        self.llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)

    @staticmethod
    def _escape_template_braces(text) -> str:
        """Double every brace so the text survives f-string style prompt templating verbatim."""
        return str(text).replace('{', '{{').replace('}', '}}')

    def generate_analysis_chain_system_message(self, question, df):
        """
        Builds the prompt-template text for analysing ``df`` in light of ``question``.

        The returned string is meant to be passed to ``PromptTemplate.from_template``:
        braces coming from the question and from the DataFrame JSON are doubled so
        they are rendered literally instead of being parsed as template variables.

        Args:
            question: The user's question (rendered with ``str``).
            df: The DataFrame to analyse; only its first ``MAX_ROWS_IN_PROMPT`` rows are embedded.

        Returns:
            str: The prompt-template text. When the DataFrame has more rows than are
            embedded, it contains a note telling the model the data was truncated.
        """
        row_limit = self.MAX_ROWS_IN_PROMPT

        # Warn whenever rows are actually cut off, i.e. as soon as the frame exceeds the embedded slice.
        if len(df) > row_limit:
            display_size_note = f"**Note: The Dataframe includes the first {row_limit} entries out of a total of {len(df)} rows. Let the user know about it."
        else:
            display_size_note = ""

        safe_question = self._escape_template_braces(question)
        dataframe_json = self._escape_template_braces(df.head(row_limit).to_json())

        system_message_prompt = f'''You are an AI Chatbot for Supplycopia, experienced in data analysis. Your task is to interpret natural language requests and provide precise **data analysis** based on the given User's Question, and Dataframe given below.

## For straightforward data, highlight key findings succinctly.
## For complex data, conduct an in-depth analysis:
    - Identify trends and patterns in detail, ensuring that any monthly trends are presented in proper chronological order.
    - Detect and discuss any outliers with thorough explanations.
    - If a specific month is mentioned, present it in full format (e.g., "July 2023" instead of "7th month") and provide the most detailed possible analysis relevant to that month.
## Communicate insights in clear, simple language suitable for non-experts.

{display_size_note}

# Tone of Response:
    - Friendly, professional, engaging, clear, and knowledgeable.

------------
USER'S QUESTION: {safe_question}
------------
DataFrame: {dataframe_json}
------------
'''
        return system_message_prompt

    def run_analysis_chain(self, question, df):
        """
        Runs the analysis prompt through the LLM.

        Args:
            question: The user's question.
            df: The DataFrame holding the data to analyse.

        Returns:
            str: The LLM's analysis with surrounding whitespace removed.
        """
        system_message_prompt = self.generate_analysis_chain_system_message(question, df)
        chat_prompt_template = PromptTemplate.from_template(system_message_prompt)
        analysis_chain = chat_prompt_template | self.llm | StrOutputParser()
        # The template has no variables left: every brace in it was escaped above.
        query_response = analysis_chain.invoke({})
        return query_response.strip()
    
# Module-level instance; the commented example in the following cell calls it.
analysis_chain_manager = AnalysisChainManager()

In [28]:
# analysis_chain_response = analysis_chain_manager.run_analysis_chain(user_prompt_to_use, df)
# print(analysis_chain_response)

# <span style="color: moccasin; font-weight: BOLD;">SQLQueryGenerator | MultiStageSQLGenerator</span>

## 1: SQLQueryGenerator

In [29]:
class SQLQueryGenerator:
    def __init__(self, connection_pool, table_info_from_db, embedding_task_helper):
        """
        Initializes the SQLQueryGenerator with the specified Redshift table and language model.

        Args:
            redshift_table (str): The name of the Redshift SQL table to query.
            model_name (str, optional): The name of the language model to use. Default is 'gpt-4o-2024-08-06'.
        """
        # self.llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)
        self.llm = ChatOpenAI(model='gpt-4.1-2025-04-14', temperature=0)


        ########## ------------- #############
        # self.llm_for_correction = ChatOpenAI(model='gpt-4o-mini', temperature=0) 
        self.llm_for_correction = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)
        ########## ------------- #############

        self.connection_pool = connection_pool
        self.embedding_task_helper = embedding_task_helper

        ## Identifying the Product level table for further use
        self.product_level_table = table_info_from_db[table_info_from_db['table_for']=='Product Level Utilization Data']['table_name'].values[0]
    
    # ------------------------------------- Finding Correct Sql Query -------------------------------------------------- # 
    def generate_prompt_for_sql_correction(self, user_query: str, sql_query: str, sql_error_message: str) -> str:
        """
        Generates a prompt for the LLM to correct and refine the SQL query.

        Args:
            user_query (str): The user's input query.
            sql_query (str): The initial SQL query.
            sql_error_message (str): The error message encountered during SQL query execution.

        Returns:
            str: The prompt for the LLM.
        """
    
        prompt = f"""You are an expert SQL assistant. Your task is to correct the provided SQL query based on the error message, user's intent, and SQL best practices. Specifically:

**Instructions:**
    - Analyze the provided SQL query and error message to identify the root cause.
    - Correct the query to fix all technical issues while maintaining SQL best practices.
    - Return only the corrected SQL query, formatted as code, and nothing else.

Example of issues you must fix:
    - Missing `GROUP BY` for non-aggregated columns in the `SELECT` clause.
    - Proper use of aliases, subqueries, and aggregate functions.
    - Syntax errors like division by zero risks, invalid column names, or casing issues.

-----------------------------
User Query: {user_query}

Incorrect SQL Query:

```
{sql_query}
```

Error Message: {sql_error_message}
-----------------------------

## Analyze the error message and correct the faulty sql query given.
## Prioritize identifying and fixing technical errors over stylistic corrections.
"""
        
        return prompt.strip()

    def correct_sql_query(self, user_query, sql_query: str, sql_error_message: str) -> str:
        """
        Corrects the SQL query using the LLM.

        Args:
            user_query (str): The user's input query.
            sql_query (str): The initial SQL query.
            sql_error_message (str): The error message encountered during SQL query execution.

        Returns:
            str: The corrected SQL query.
        """
        # Create the system message prompt
        system_message = self.generate_prompt_for_sql_correction(user_query, sql_query, sql_error_message)
        system_message_prompt = SystemMessagePromptTemplate.from_template(system_message)

        # Create the human message prompt to include the user query
        human_message_prompt = HumanMessagePromptTemplate.from_template("Correct the SQL query.")

        # Combine system and human prompts into a chat prompt template
        chat_prompt_template = ChatPromptTemplate.from_messages([
            system_message_prompt,
            human_message_prompt
        ])

        response_chain = chat_prompt_template | self.llm_for_correction | StrOutputParser()
        # Get the corrected SQL query
        corrected_query = response_chain.invoke({})
        corrected_query = corrected_query.replace('```sql','').replace('```','').strip()
        return corrected_query.strip()

    # ------------------------------------------------------------------------------------------------------------------- # 


    def _create_table_definition_prompt(self, user_prompt, column_list, redshift_table) -> str:
        """
        Generates a prompt for GPT to create SQL queries for the specified Redshift table,
        focusing only on the given columns.

        Args:
            column_list (list or str): A list of column names or a string of column names to be included in the SQL query.

        Returns:
            str: A string prompt that details the requirements for generating SQL queries with specific constraints.
        """
        # Get today's date
        today_date = datetime.now()
        # Format the date as YYYY-MM-DD
        formatted_date = today_date.strftime('%Y-%m-%d')

        # Format the columns
        if isinstance(column_list, str):
            columns_formatted = column_list
        else:
            columns_formatted = ", ".join(column_list)
      
        ######################## grouping_instruction - P-EVENT Level ########################
        grouping_instruction = ''
        if redshift_table == self.product_level_table:
            grouping_instruction = f"""\n## Grouping Requirements for the `{self.product_level_table}` Table
            
- Mandatory Grouping:
    - ***Always group by `emr_p_event` when calculating Outcome and Quality metrics, including:
        - Outcome Score
        - Length of Stay
        - Geometric Mean Length of Stay
        - Patient Age Bucket
        - Cost Per Case
        - Implant Cost Per Case
        - Medical Supply Cost Per Case
------------------------------------------------------------------------\n"""
        # ======================================================================== # 

        ######################## System Benchmark Context - Using: embedding_task_helper########################
        system_benchmark_context_to_add = ''
        if self.embedding_task_helper.check_prompt_conditions_for_system_benchmark(user_prompt):
            system_benchmark_context_to_add = """\n## System Benchmark Context:
- When the query includes "System Benchmark", special rules apply to column filters.

### Rules:
1. **Do NOT apply WHERE conditions** to specific columns if **"System Benchmark"** is present.
2. **Excluded Columns**:
    - Facility, *Market, Account Number, Age, ASA Score, BMI Bucket, Detailed Patient Type, Diabetic Status, Patient Ethnicity, Patient Gender, Patient Type, Physician, Payor Group, Smoking Status, SSI Flag, Mortality Flag, Readmission Flag, Blood Transfusion Flag, Product Name, Distributor, Contract Type, Outlier Classification.

### **Example Behavior**:
- User Query: "What are the System Benchmark total encounters across top 10 UNSPSC in @Market X and @Primary Procedure 'Y'?"
- Expected SQL Behavior: Do not apply WHERE filters on the excluded columns listed above, even if other conditions like UNSPSC or Market are present.

### Key Notes:
- Always prioritize the "System Benchmark" context and leave the excluded columns unfiltered.

------------------------------------------------------------------------\n"""

        # print(f'system_benchmark_context_to_add:\n {system_benchmark_context_to_add}\n')
         # ======================================================================== # 


         ## Time Related Prompt
        time_related_prompt = ''
        if 'time' in user_prompt.lower() or 'trend' in user_prompt.lower() or 'month' in user_prompt.lower() or 'year' in user_prompt.lower() or 'quarter' in user_prompt.lower():
            time_related_prompt = '''- Interpret 'spend over time' as a **monthly spend trend**.'''

        prompt = f'''
Craft SQL queries for the "{redshift_table}" table in Redshift, following these instructions:
- Only use these columns for crafting the SQL query: {columns_formatted}.
- Begin each query with "SELECT".
- *Response should only contain the SQL query.
- Include a "WHERE" clause for specific data retrieval.
    -- *Maintain the exact format of user-provided strings within the "WHERE" condition without altering or adding spaces, punctuation, or abbreviations.
- Today's date is: {formatted_date}. It's the {today_date.day} day of the month '{today_date.strftime('%B')}' in the year {today_date.year}.
- If the user's query does not specify or hint at a date or year, there is no need to add date filters inside the SQL query.

- Dates are given inside the columns: 'emr_month_yr', 'emr_discharge_date'.
    - emr_discharge_date: YYYY-MM-DD (e.g., 2023-05-14, use LIKE '2023-%')
    - emr_month_yr: MM-YYYY (e.g., 05-2023, use LIKE '%-2023' for year filtering)
    - Use the LIKE % operator while dealing with the dates.
    
    - **Grouping Date Filters**:
        - Always **group date filters properly** using parentheses when asked "month-wise breakdown", "trend over time", "monthly CPC in 2023" or similar to ensure correct logical grouping with other conditions in the `WHERE` clause.
            - Example:
                ```sql
                WHERE 
                    (emr_discharge_date LIKE '2023-%' AND emr_discharge_date >= '2023-11-01')
                    OR (emr_discharge_date LIKE '2024-%' AND emr_discharge_date <= '2024-10-31')
                ```
        - This ensures that date conditions are applied together and evaluated accurately with other filters.
        
        {time_related_prompt}

- Use the LIKE % operator to search for substrings within a column.
- Must use SUM(), COUNT(), AVG(), MIN(), MAX() functions appropriately while handling queries about: Cost, Charge, Spend, Score, Revenue, etc.
- Important Note: Ensure case insensitivity within the 'WHERE' condition by converting all string columns to lowercase before performing comparisons.

- Include an "ORDER BY" clause in the SQL query to ensure the results are sorted as follows:
    - Case 1: If the question involves only one parameter, sort the response in descending order of that parameter.
    - Case 2: If the question involves multiple parameters and includes either total encounter or total spend or CPC, sort the response in descending order of total encounter or total spend or cost per case (CPC).
    - Case 3: If the question involves multiple parameters without total encounter or total spend or CPC, sort the response in descending order of the first parameter mentioned in the question.
    - Case 4: If the question involves Date parameter, return results sorted by Date in ascending order.
    - Case 5: Always add NULLS LAST in the ORDER BY clause to push nulls to the bottom.

- **Ensure all fields referenced in the outer query**:
    - Include in the inner query’s `SELECT` statement any field used in the outer query’s `SELECT`, `WHERE`, `CASE`, or `GROUP BY` clauses.
    - This includes fields required for calculations (e.g., `AVG`, `SUM`), filters, or aggregations to avoid errors like "column does not exist."

- When aggregating data, construct multilevel aggregation queries. This involves nested queries that perform calculations at multiple levels.
- Only provide the SQL query.
- End the SQL query with a semicolon (;).
- Do not add any explanation before or after the SQL query response.

Example Response Format:
```
SELECT 
    col1, 
    (SUM(col3) / COUNT(col3)) * 100 AS percentage
FROM 
    (
        SELECT 
            col1, 
            AVG(col3) AS col3
        FROM 
            {redshift_table}
        WHERE 
            LOWER(col4) = LOWER(<value provided by the user>) 
            AND LOWER(col5) = LOWER(<value provided by the user>)
        GROUP BY 
            col1
    ) subquery
GROUP BY 
    col1
ORDER BY 
    percentage DESC
LIMIT 10;
```
- Construct a SQL query adhering to these guidelines.
- Generate multilevel aggregation SQL queries like the example above to make the SQL query more usable.

{grouping_instruction}
{system_benchmark_context_to_add}

📌 **Date Range Conversion (Mandatory)**  
    - ALWAYS include date range columns in every query:  
    - For `emr_discharge_date`, ALWAYS convert using:
        TO_CHAR(MIN(TO_DATE(emr_discharge_date, 'YYYY-MM-DD')), 'YYYY-MM-DD') AS start_date,
        TO_CHAR(MAX(TO_DATE(emr_discharge_date, 'YYYY-MM-DD')), 'YYYY-MM-DD') AS end_date  
    - This conversion MUST be applied even if the user does not explicitly request the date range.  

- If you have used multilevel aggregation query, make sure "PROPER AGGREGATION" IS BEING USED.
- Use WHERE to filter rows before aggregation.
- Use HAVING to filter aggregated results or groups after aggregation.

✅ **For comparison-based questions** (e.g., comparing metrics between different markets, service lines, physicians, facilities, etc.):  
- You **must** include the dimension being compared in the final output of the `SELECT` clause.  
- This **must** be done even if aggregation is performed at a different level.  
- The compared dimension **must** be clearly visible in the final result to label and distinguish each row properly.  
- You **must not** omit the compared dimension from the output — doing so will result in incomplete or misleading answers.  
- This rule **must always be followed**. No exceptions.

    ### ✅ Examples of comparison-based questions:
    - Compare the top 10 procedures by average cost for two different markets in 2024  
    - Compare average case costs across service lines  
    - Compare conversion rates between facilities over time
'''
        
        return prompt.strip()

    def generate_sql_chain(self, user_prompt, column_list, redshift_table):
        """
        Assemble the prompt -> LLM -> string-parser pipeline used to generate SQL.

        The system message is produced by ``_create_table_definition_prompt`` from
        the three arguments; the human message is a fixed template exposing a
        single ``human_prompt`` variable.

        Args:
            user_prompt: The user's question, forwarded to the system-prompt builder.
            column_list (list or str): Column names (or a string of column names)
                forwarded to the system-prompt builder.
            redshift_table: Table reference forwarded to the system-prompt builder.

        Returns:
            Chain: A runnable chain (chat prompt | ``self.llm`` | ``StrOutputParser``).
            Invoke it with a mapping that provides ``human_prompt``.
        """
        # NOTE(review): the rendered system prompt is parsed as a template here, so
        # any literal curly braces left in it would be treated as input variables —
        # confirm the prompt builder escapes them.
        system_text = self._create_table_definition_prompt(user_prompt, column_list, redshift_table)

        prompt_messages = [
            SystemMessagePromptTemplate.from_template(system_text),
            HumanMessagePromptTemplate.from_template("A query to answer:\n{human_prompt}"),
        ]
        chat_prompt = ChatPromptTemplate.from_messages(prompt_messages)

        return chat_prompt | self.llm | StrOutputParser()

    def extract_sql_query_from_string(self, query_string):
        """
        Extract the SQL statement embedded in free-form text (e.g. an LLM reply).

        The statement starts at the first ``SELECT`` or ``WITH`` keyword
        (case-insensitive) that sits at the start of the text or right after
        whitespace, and ends at the last semicolon found after that keyword.

        Args:
            query_string (str): Text that may contain a SQL statement.

        Returns:
            str or None: The stripped statement including its trailing semicolon,
            or None when the input is not a string, no starting keyword is found,
            or no semicolon follows the keyword.
        """
        # Non-string input (e.g. None from an upstream failure) cannot hold a query.
        if not isinstance(query_string, str):
            return None

        # Only the start position matters, so the pattern stops at the keyword.
        start_match = re.search(r"(?<!\S)(?:SELECT|WITH)\b", query_string, re.IGNORECASE)
        if start_match is None:
            return None
        query_start = start_match.start()

        # Search for the terminator only from the keyword onwards: a semicolon in
        # the preamble must not be mistaken for the end of the statement.
        last_semicolon_index = query_string.rfind(';', query_start)
        if last_semicolon_index == -1:
            return None

        return query_string[query_start:last_semicolon_index + 1].strip()
    
    def get_dataframe_from_sql(self, generated_sql_query):
        """
        Execute a SQL query through the connection pool and return a DataFrame.

        The query is first screened with a keyword heuristic: it must contain an
        aggregation keyword, ``WHERE`` or ``LIMIT``, and must start with
        ``SELECT``, ``WITH`` or ``-- ``. Queries that pass are executed on a pooled
        connection; if execution fails for any reason it is retried once on a
        freshly acquired connection.

        Args:
            generated_sql_query (str): The SQL query to execute.

        Returns:
            tuple: ``(df, error)`` where ``df`` is the result as a pandas DataFrame
            (empty on failure) and ``error`` is None on success or the message of
            the last failure. Non-string or rejected queries yield
            ``"Invalid or unsupported SQL query."``.
        """
        invalid_query_error = "Invalid or unsupported SQL query."

        # A missing query (e.g. None when SQL extraction failed) is rejected here
        # instead of crashing on ``.upper()``.
        if not isinstance(generated_sql_query, str):
            return pd.DataFrame(), invalid_query_error

        # List of allowed aggregation functions and query clauses
        aggregation_functions_list = ["SUM", "COUNT", "AVG", "MIN", "MAX", "DISTINCT"]
        query_upper = generated_sql_query.upper()

        # NOTE(review): these are plain substring/prefix checks, not a guarantee
        # that the statement is read-only — confirm the DB role is restricted.
        has_expected_keyword = (
            any(func in query_upper for func in aggregation_functions_list)
            or "WHERE" in query_upper
            or "LIMIT" in query_upper
        )
        has_allowed_prefix = query_upper.startswith(("SELECT", "WITH", "-- "))
        if not (has_expected_keyword and has_allowed_prefix):
            return pd.DataFrame(), invalid_query_error

        df = pd.DataFrame()
        error = None

        # First pass plus one retry on a newly acquired connection.
        for _ in range(2):
            conn = None  # reset so a failed getconn() never re-releases an old connection
            try:
                conn = self.connection_pool.getconn()
                with conn.cursor() as cursor:
                    cursor.execute(generated_sql_query)
                    data = cursor.fetchall()
                    columns = [desc[0] for desc in cursor.description]  # Fetch column names
                df = pd.DataFrame(data, columns=columns)
                error = None  # Clear any error left by the first pass
                break
            except Exception as e:
                error = str(e)
            finally:
                # Each acquired connection is returned to the pool exactly once.
                if conn is not None:
                    self.connection_pool.putconn(conn)

        return df, error

    def get_sql_query(self, user_prompt, column_list, redshift_table):
        """
        Generate a SQL query for the user's question.

        Builds the generation chain, invokes it with the user prompt, removes
        Markdown code fences and backticks from the response, and passes the
        cleaned text through ``extract_sql_query_from_string``.

        Args:
            user_prompt: The user's question.
            column_list: Column information forwarded to ``generate_sql_chain``.
            redshift_table: Table reference forwarded to ``generate_sql_chain``.

        Returns:
            Whatever ``extract_sql_query_from_string`` returns for the cleaned
            response.
        """
        chain = self.generate_sql_chain(user_prompt, column_list, redshift_table)
        raw_response = chain.invoke({'human_prompt': user_prompt})

        # Drop fences first, then any remaining single backticks.
        cleaned = raw_response
        for marker in ('```sql', '```', '`'):
            cleaned = cleaned.replace(marker, '')

        return self.extract_sql_query_from_string(cleaned.strip())
    
    def get_result_from_sql_query(
        self,
        user_prompt = None,
        column_list = None,
        redshift_table = None,
        generated_sql_query = None
    ):
        """
        Returns results from a SQL query using one of two modes:

        Mode 1: Dynamic SQL Generation Mode
            - Required Inputs: user_prompt, column_list, redshift_table
            - If 'generated_sql_query' is not provided, it will be created using get_sql_query()

        Mode 2: Direct SQL Execution Mode
            - Required Input: generated_sql_query
            - Directly uses the provided SQL to fetch results

        If execution reports an error, the query is sent to ``correct_sql_query``
        together with the error message and re-executed, for up to four
        correction attempts. A query that is missing altogether (None) is
        treated as a failed execution so it also enters the correction loop.

        Returns:
            tuple: ``(sql_query, df)`` — the query that produced the result (the
            original one if every correction attempt failed) and the resulting
            DataFrame (empty when nothing succeeded).

        Raises:
            ValueError: If required parameters are missing for Mode 1.
        """
        max_correction_attempts = 4

        def run_query(query):
            # A None query (nothing extracted from the model response) would crash
            # the executor; report it as an error so the correction loop handles it.
            if not isinstance(query, str):
                return pd.DataFrame(), "No executable SQL query was produced."
            return self.get_dataframe_from_sql(query)

        if generated_sql_query is None:
            # Mode 1: Dynamic SQL generation from user prompt
            if not all([user_prompt, column_list, redshift_table]):
                raise ValueError("To generate SQL, you must provide user_prompt, column_list, and redshift_table.")

            print("[MODE 1:] No SQL provided. Generating SQL from prompt, column list, and table...")
            generated_sql_query = self.get_sql_query(
                user_prompt=user_prompt,
                column_list=column_list,
                redshift_table=redshift_table
            )

        else:
            # Mode 2: Use the provided SQL query directly
            print("[MODE 2:] Using the provided SQL query.")

        # Run to get the dataframe
        df, sql_error_message = run_query(generated_sql_query)

        # Handle SQL errors with a retry mechanism
        if sql_error_message:
            print(f"--- Initial SQL Error Encountered ---\nError Message: {sql_error_message}\nFaulty Query: {generated_sql_query}\n──────────────────────────────────────────────────────────────────────────────────────────\n")

            input_error_message = sql_error_message
            input_error_query = generated_sql_query

            for attempt in range(max_correction_attempts):
                # Each attempt feeds the latest faulty query and its error back in.
                corrected_query = self.correct_sql_query(user_prompt, str(input_error_query), input_error_message)
                df, new_error_message = run_query(corrected_query)

                if new_error_message:
                    input_error_query = corrected_query
                    input_error_message = new_error_message
                    print(f"Attempt {attempt + 1} Failed:\nError Message: {new_error_message}\nFaulty Query: {input_error_query}\n──────────────────────────────────────────────────────────────────────────────────────────\n")
                else:
                    generated_sql_query = corrected_query
                    print(f"--- Success on Attempt {attempt + 1} ---\nCorrected Query Executed Successfully: {generated_sql_query}\n──────────────────────────────────────────────────────────────────────────────────────────\n")
                    break  # Exit the loop if no errors

        ## Rename Columns
        # NOTE(review): headers are prettified only for results with more than 100
        # rows; smaller results keep their raw column names — confirm the threshold
        # is intentional.
        if not df.empty and len(df)>100:
            df.columns = (
                df.columns
                .str.replace(r"^emr_", "", regex=True)   # remove 'emr_' prefix
                .str.replace("_", " ")                  # replace underscores with space
                .str.title()                            # capitalize each word
            )
        return generated_sql_query, df

## SQL QUERY GENERATOR - OBJECT
# Notebook-level SQLQueryGenerator instance. The three arguments are not defined
# in this cell — presumably created in earlier cells (TODO confirm), so those
# cells must run before this one.
sql_query_generator = SQLQueryGenerator(connection_pool, table_info_from_db, embedding_task_helper)

In [30]:
# ################# Inner Approach ####################
# generated_sql_query = sql_query_generator.get_sql_query(user_prompt_to_use, columns_info_to_generate_sql, table_to_use)
# print(generated_sql_query)
# ################# #################### ####################

# # generated_sql_query, df = sql_query_generator.get_result_from_sql_query(user_prompt_to_use, columns_info_to_generate_sql, table_to_use)
# # df.head()

In [31]:
# print(user_prompt_to_use)
# print(generated_sql_query)

## 2: MultiStageSQLGenerator

### Previous Class: MultiStageSQLGenerator [NOT IN USE]

In [32]:
# class MultiStageSQLGenerator:
#     def __init__(self):
#         self.llm = ChatOpenAI(model="gpt-4.1-mini", temperature=0)
#         self.parser = StrOutputParser()

#     def prompt_chain_step1(self, user_prompt: str, schema: Dict[str, list], 
#                         additional_context: str = "", context_type: str = "general") -> str:
#         """
#         First step: Generate initial SQL query using table schema and focused context.
        
#         Args:
#             user_prompt: Natural language query from user
#             schema: Dictionary with table_name and columns list
#             additional_context: Optional specialized context (data dictionary, domain knowledge)
#             context_type: Type of context provided ("general", "domain_knowledge", "data_dictionary", "complex_instruction")
        
#         Returns:
#             Initial SQL query
#         """
        
#         # Format context based on type for better interpretation
#         context_section = ""
#         if additional_context:
#             context_title = context_type.replace('_', ' ').title()
#             context_section = f"\n\n### {context_title}:\n{additional_context.strip()}"
        
#         # Customize system message based on context type
#         if context_type == "data_dictionary":
#             system = """You are an SQL translator specialized in data dictionary interpretation.
#     Focus exclusively on using the provided data dictionary to create precise column mappings.
#     Return only the SQL query with no explanations."""
#         elif context_type == "domain_knowledge":
#             system = """You are a domain expert SQL generator.
#     Apply the provided domain knowledge to create contextually appropriate filters and joins.
#     Return only the SQL query with no explanations."""
#         elif context_type == "complex_instruction":
#             system = """You are an advanced SQL engineer specializing in complex query patterns.
#     Apply the provided instructions to structure the query appropriately.
#     Return only the SQL query with no explanations."""
#         else:
#             system = """You are a basic SQL translator.
#     Generate a clear, focused SQL query using only the provided schema.
#     Return only the SQL query with no explanations.
#     The SQL query must end with a semicolon (;)."""

#         # Keep user message focused and structured
#         user = f"""User Query: {user_prompt}

#     Table: {schema['table_name']}
#     Columns: {', '.join(schema['columns'])}{context_section}

#     Generate SQL:"""

#         # Create and execute the chain
#         prompt = ChatPromptTemplate.from_messages([
#             SystemMessagePromptTemplate.from_template(system),
#             HumanMessagePromptTemplate.from_template("{msg}")
#         ])
#         chain = prompt | self.llm | self.parser
#         return chain.invoke({"msg": user.strip()})

#     def extract_sql_query_from_string(self, query_string):
#         """
#         Extracts SQL query from a string response using regex pattern matching.
        
#         Args:
#             query_string (str): String potentially containing SQL query
            
#         Returns:
#             str or None: Extracted SQL query or None if no valid query found
#         """
#         # Regular expression to find the start of the SQL query
#         sql_pattern = re.compile(r"(?is)(?<!\S)(SELECT|WITH)\b.*", re.IGNORECASE | re.DOTALL)

#         # Search for the SQL query using regex (start point)
#         match = sql_pattern.search(query_string)
#         if not match:
#             return None

#         # Get the query from the start
#         query_start = match.start()

#         # Find the index of the last semicolon in the query string
#         last_semicolon_index = query_string.rfind(';')

#         # If no semicolon is found, return None (invalid query)
#         if last_semicolon_index == -1:
#             return None

#         # Extract the query from the start of the match to the last semicolon
#         sql_query = query_string[query_start:last_semicolon_index + 1].strip()

#         return sql_query

#     def feedback_agent(self, user_prompt: str, sql: str, schema: Dict[str, list], additional_context: str = "") -> str:
#         """
#         Validate the SQL query against specific rules and provide targeted feedback.

#         Args:
#             user_prompt: Original user query
#             sql: Current SQL query to validate
#             schema: Dictionary with table_name and columns
#             additional_context: Domain-specific validation rules

#         Returns:
#             Validation feedback as a string - "YES" if valid, otherwise specific issues
#         """

#         # Extract schema components
#         columns = schema['columns']
#         table_name = schema['table_name']

#         # Date context for time-based queries
#         today_date = datetime.now()
#         date_context = (f"Today's date is: {today_date.strftime('%Y-%m-%d')} — the {today_date.day} of "
#                     f"{today_date.strftime('%B')}, {today_date.year}.")

#         # Format additional context if provided
#         context_section = f"\n\n### Additional Context:\n{additional_context.strip()}" if additional_context else ""

#         # System message focused on validation rules
#         system_message = f"""You are an SQL validator with expertise in Redshift syntax and best practices.

#     Output format:
#     - Return ONLY "YES" if the query is fully correct
#     - Otherwise, return ONLY a numbered list of specific issues

#     Validation Rules:
#     1. Use only columns that exist in the provided schema
#     2. Use correct SQL syntax and appropriate Redshift features
#     3. Match filtering requirements in the user query
#     4. Use case-insensitive comparisons with LOWER() where appropriate
#     5. Use date filtering only when relevant to the user query:
#         - Dates are given inside the columns: 'emr_month_yr', 'emr_discharge_date'.
#             - emr_discharge_date: YYYY-MM-DD (e.g., 2023-05-14, use LIKE '2023-%')
#             - emr_month_yr: MM-YYYY (e.g., 05-2023, use LIKE '%-2023' for year filtering)
#         - Use the LIKE % operator while dealing with the dates. 
#         - **Grouping Date Filters**:
#             - Always **group date filters properly** using parentheses to ensure correct logical grouping with other conditions in the `WHERE` clause.
#                 - Example:
#                     ```sql
#                     WHERE 
#                         (emr_month_yr LIKE '%-2023' AND emr_month_yr >= '11-2023')
#                         OR (emr_month_yr LIKE '%-2024' AND emr_month_yr <= '10-2024')
#                     ```
#             - This ensures that date conditions are applied together and evaluated accurately with other filters.
#     6. Include all dimensions mentioned in the user query
#     7. Add appropriate aggregations (SUM, COUNT, AVG, MIN, MAX) for metrics mentioned in the query
#     8. Include ORDER BY only if relevant (for sorting, ranking, etc.)
#     9. Structure subqueries and CTEs with proper grouping
#     10. For comparison queries, include all compared dimensions

#     {date_context}
#     {context_section}
#     """

#         # Human message with query details
#         human_message = f"""
#     User Query: {user_prompt}

#     Table: {table_name}

#     Available Columns: {', '.join(columns)}

#     SQL to Validate:
#     {sql}
#     """

#         # Create and execute validation chain
#         prompt = ChatPromptTemplate.from_messages([
#             SystemMessagePromptTemplate.from_template(system_message),
#             HumanMessagePromptTemplate.from_template("{msg}")
#         ])
#         chain = prompt | self.llm | self.parser
#         return chain.invoke({"msg": human_message}).strip()

#     def refine_sql_based_on_feedback(self, user_prompt: str, previous_sql: str, 
#                                     feedback: str, schema: Dict[str, list]) -> str:
#         """
#         Refine the SQL query based on specific feedback from the validation step.
        
#         Args:
#             user_prompt: Original user query
#             previous_sql: Current SQL query being refined
#             feedback: Specific issues identified by the feedback agent
#             schema: Dictionary with table_name and columns
            
#         Returns:
#             Refined SQL query addressing the identified issues
#         """
        
#         # Extract schema components
#         columns = schema['columns']
#         table_name = schema['table_name']
        
#         # Date context for time-based queries
#         today_date = datetime.now()
#         date_context = f"Today's date: {today_date.strftime('%Y-%m-%d')}, day {today_date.day} of {today_date.strftime('%B')} {today_date.year}"
        
#         # System prompt focused on targeted refinement
#         system_message = f"""You are an SQL refinement specialist.
        
#     Your task is to fix ONLY the specific issues in the feedback while preserving the query's core functionality.

#     Guidelines:
#     1. Address each numbered issue in the feedback explicitly
#     2. Maintain the original query structure where possible
#     3. Apply SQL best practices and Redshift-specific optimizations
#     4. Ensure all required columns are used correctly
#     5. Preserve the original business intent of the query

#     {date_context}

#     Return only the corrected SQL query with no explanations.
#     """

#         # Human message with detailed context
#         human_message = f"""
#     Original User Query: {user_prompt}

#     Available Schema:
#     - Table: {table_name}
#     - Columns: {', '.join(columns)}

#     Previous SQL Query:
#     {previous_sql}

#     Issues to Fix:
#     {feedback}

#     Return the corrected SQL query:
#     """

#         # Create and execute refinement chain
#         prompt = ChatPromptTemplate.from_messages([
#             SystemMessagePromptTemplate.from_template(system_message),
#             HumanMessagePromptTemplate.from_template("{msg}")
#         ])
#         chain = prompt | self.llm | self.parser
#         return chain.invoke({"msg": human_message}).strip()

#     def generate_and_iteratively_refine_sql(
#         self,
#         user_prompt: str,
#         schema: Dict[str, list],
#         domain_context: str = "",
#         complex_instructions: str = "",
#         max_iterations: int = 2,
#         verbose: bool = True
#     ) -> str:
#         """
#         Enhanced SQL generation with iterative refinement and optional debug outputs.

#         Args:
#             user_prompt: Natural language query from the user.
#             schema: Dictionary containing table names and columns.
#             domain_context: Optional domain-specific knowledge.
#             complex_instructions: Optional advanced instructions.
#             max_iterations: Maximum feedback‑refinement loops.
#             verbose: Whether to print detailed step-by-step output.

#         Returns:
#             Optimized SQL query string.
#         """
#         if verbose:
#             print("\n========== USER PROMPT ==========")
#             print(user_prompt)

#         # ---------- Phase 1 | Initial SQL ----------
#         if verbose:
#             print("\n========== PHASE 1: INITIAL SQL GENERATION ==========")
#         initial_sql = self.prompt_chain_step1(
#             user_prompt,
#             schema,
#             additional_context="",
#             context_type="general",
#         )
#         current_sql = self.extract_sql_query_from_string(initial_sql) or initial_sql
#         if verbose:
#             print("Initial SQL Generated:")
#             print(current_sql)

#         # ---------- Phase 2 | Add Domain Knowledge ----------
#         if domain_context:
#             if verbose:
#                 print("\n========== PHASE 2: DOMAIN KNOWLEDGE ENHANCEMENT ==========")
#             domain_enhanced_sql = self.prompt_chain_step1(
#                 user_prompt,
#                 schema,
#                 additional_context=f"""\
#     Existing SQL:
#     {current_sql}

#     Domain Knowledge:
#     {domain_context}
#     """,
#                 context_type="domain_knowledge",
#             )
#             current_sql = self.extract_sql_query_from_string(domain_enhanced_sql) or domain_enhanced_sql
#             if verbose:
#                 print("Domain-Enhanced SQL:")
#                 print(current_sql)

#             feedback = self.feedback_agent(
#                 user_prompt,
#                 current_sql,
#                 schema,
#                 additional_context=domain_context,
#             )
#             if verbose:
#                 print("\nDomain SQL Feedback:")
#                 print(feedback)

#             if feedback.upper() != "YES":
#                 if verbose:
#                     print("\nRefining SQL based on domain feedback...")
#                 current_sql = (
#                     self.extract_sql_query_from_string(
#                         self.refine_sql_based_on_feedback(
#                             user_prompt, current_sql, feedback, schema
#                         )
#                     )
#                     or current_sql
#                 )
#                 if verbose:
#                     print("Refined Domain SQL:")
#                     print(current_sql)
#             elif verbose:
#                 print("✅ Domain-enhanced SQL validated.")

#         # ---------- Phase 3 | Add Complex Instructions ----------
#         if complex_instructions:
#             if verbose:
#                 print("\n========== PHASE 3: COMPLEX INSTRUCTION ENHANCEMENT ==========")
#             instruction_enhanced_sql = self.prompt_chain_step1(
#                 user_prompt,
#                 schema,
#                 additional_context=f"""\
#     Existing SQL:
#     {current_sql}

#     Complex Instructions:
#     {complex_instructions}
#     """,
#                 context_type="complex_instruction",
#             )
#             current_sql = (
#                 self.extract_sql_query_from_string(instruction_enhanced_sql)
#                 or instruction_enhanced_sql
#             )
#             if verbose:
#                 print("Instruction-Enhanced SQL:")
#                 print(current_sql)

#             feedback = self.feedback_agent(
#                 user_prompt,
#                 current_sql,
#                 schema,
#                 additional_context=complex_instructions,
#             )
#             if verbose:
#                 print("\nInstruction SQL Feedback:")
#                 print(feedback)

#             if feedback.upper() != "YES":
#                 if verbose:
#                     print("\nRefining SQL based on instruction feedback...")
#                 current_sql = (
#                     self.extract_sql_query_from_string(
#                         self.refine_sql_based_on_feedback(
#                             user_prompt, current_sql, feedback, schema
#                         )
#                     )
#                     or current_sql
#                 )
#                 if verbose:
#                     print("Refined Instruction SQL:")
#                     print(current_sql)
#             elif verbose:
#                 print("✅ Instruction-enhanced SQL validated.")

#         # ---------- Phase 4 | Combined Context & Final Loops ----------
#         if domain_context and complex_instructions:
#             if verbose:
#                 print("\n========== PHASE 4: COMBINED CONTEXT ENHANCEMENT ==========")
#             combined_context = f"{domain_context}\n\n{complex_instructions}"
#             combined_enhanced_sql = self.prompt_chain_step1(
#                 user_prompt,
#                 schema,
#                 additional_context=f"""\
#     Existing SQL:
#     {current_sql}

#     Combined Context:
#     {combined_context}
#     """,
#                 context_type="complex_instruction",
#             )
#             current_sql = (
#                 self.extract_sql_query_from_string(combined_enhanced_sql)
#                 or combined_enhanced_sql
#             )
#             if verbose:
#                 print("Combined-Enhanced SQL:")
#                 print(current_sql)

#             for i in range(1, max_iterations + 1):
#                 if verbose:
#                     print(f"\n========== FINAL FEEDBACK LOOP {i} ==========")
#                 feedback = self.feedback_agent(
#                     user_prompt,
#                     current_sql,
#                     schema,
#                     additional_context=combined_context,
#                 )
#                 if verbose:
#                     print(f"Feedback: {feedback}")

#                 if feedback.upper() == "YES":
#                     if verbose:
#                         print(f"✅ SQL validated at iteration {i}")
#                     break

#                 if verbose:
#                     print("Refining SQL based on final feedback...")
#                 current_sql = (
#                     self.extract_sql_query_from_string(
#                         self.refine_sql_based_on_feedback(
#                             user_prompt, current_sql, feedback, schema
#                         )
#                     )
#                     or current_sql
#                 )
#                 if verbose:
#                     print(f"Refined SQL (Iteration {i}):")
#                     print(current_sql)

#         if verbose:
#             print("\n========== FINAL SQL OUTPUT ==========")
#             print(current_sql)

#         return current_sql
    
# multi_stage_sql_generator = MultiStageSQLGenerator()

### New Approach - Single Pass Generation

In [33]:
class MultiStageSQLGenerator:
    """
    Generates Redshift SQL for a user question with as few LLM calls as possible.

    Flow: one generation call that already carries the domain knowledge and the
    complex instructions, followed by up to `max_iterations` combined
    validate-or-fix calls.
    """

    # Start of a SQL statement: SELECT, or WITH followed by a CTE header (`name AS (`).
    # Requiring the CTE header keeps ordinary prose such as "fixed with care" from
    # being mistaken for the beginning of the query.
    _SQL_START_PATTERN = re.compile(
        r'(?<!\S)(?:SELECT\b|WITH\s+(?:RECURSIVE\s+)?(?:"[^"]+"|\w+)\s*(?:\([^)]*\)\s*)?AS\s*\()',
        re.IGNORECASE,
    )
    # Markdown code fence, optionally tagged as sql.
    _CODE_FENCE_PATTERN = re.compile(r"```(?:sql)?\s*(.*?)```", re.IGNORECASE | re.DOTALL)
    # Decoration an LLM may wrap around the bare YES verdict (quotes, backticks, punctuation).
    _VERDICT_NOISE = " \t\r\n`'\"*.!"

    def __init__(
        self,
        model: str = "gpt-4.1-mini",          # Faster default model
        temperature: float = 0,
        request_timeout: int = 20,           # Lower timeout to fail fast on network stalls
        max_retries: int = 1,                # Reduce retry loops to cut latency
        max_tokens: Optional[int] = 2400     # Cap output to keep responses snappy
    ):
        """
        Speed-oriented setup.

        Args:
            model: Chat model name passed to ChatOpenAI.
            temperature: Sampling temperature.
            request_timeout: Forwarded to ChatOpenAI as `timeout`.
            max_retries: Forwarded to ChatOpenAI as `max_retries`.
            max_tokens: Output token cap forwarded to ChatOpenAI.
        """
        self.llm = ChatOpenAI(
            model=model,
            temperature=temperature,
            timeout=request_timeout,
            max_retries=max_retries,
            max_tokens=max_tokens
        )
        self.parser = StrOutputParser()

    # -------- Internal: run one system + human exchange --------
    def _run_chat(self, system: str, human: str) -> str:
        """
        Send one system/human message pair to the LLM and return the parsed text.

        Both texts are supplied as template *variables* rather than as the template
        itself, so literal `{` / `}` inside them (for example in knowledge-base
        context) are never interpreted as placeholders.
        """
        prompt = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template("{system}"),
            HumanMessagePromptTemplate.from_template("{msg}")
        ])
        chain = prompt | self.llm | self.parser
        return chain.invoke({"system": system, "msg": human})

    # -------- Core: single generator prompt (already considers all contexts) --------
    def prompt_chain_single(self, user_prompt: str, schema: Dict[str, list],
                            domain_context: str = "", complex_instructions: str = "") -> str:
        """
        One concise generation pass that integrates domain knowledge + complex instructions if provided.

        Args:
            user_prompt: The user's natural-language question.
            schema: Mapping with 'table_name' and 'columns' keys.
            domain_context: Optional domain knowledge appended to the user message.
            complex_instructions: Optional extra instructions appended to the user message.

        Returns:
            The raw LLM reply; the prompt asks for SQL only, ending with a semicolon.
        """

        # Build a tight context block to minimize tokens but keep precision
        blocks = []
        if domain_context:
            blocks.append(f"### Domain Knowledge:\n{domain_context.strip()}")
        if complex_instructions:
            blocks.append(f"### Complex Instructions:\n{complex_instructions.strip()}")
        extra = ("\n\n" + "\n\n".join(blocks)) if blocks else ""

        system = """You are a senior Redshift SQL generator.
- Use ONLY the provided table and columns.
- Apply domain/context instructions if given.
- Use case-insensitive comparisons with LOWER() for string filters where appropriate.
- Dates:
  * emr_discharge_date: YYYY-MM-DD  (e.g., 2023-05-14; year filter: emr_discharge_date LIKE '2023-%')
  * emr_month_yr: MM-YYYY          (e.g., 05-2023; year filter: emr_month_yr LIKE '%-2023')
  * Use LIKE with % for flexible date filtering where needed.
  * Group date conditions in parentheses when mixing with other filters.

📌 **Date Range Conversion (Mandatory)**  
    - ALWAYS include date range columns in every query:  
    - For `emr_discharge_date`, ALWAYS convert using:
        TO_CHAR(MIN(TO_DATE(emr_discharge_date, 'YYYY-MM-DD')), 'YYYY-MM-DD') AS start_date,
        TO_CHAR(MAX(TO_DATE(emr_discharge_date, 'YYYY-MM-DD')), 'YYYY-MM-DD') AS end_date  
    - This conversion MUST be applied even if the user does not explicitly request the date range.  

- Include aggregations (SUM, COUNT, AVG, MIN, MAX) when the user asks for metrics.
- Include ORDER BY only if relevant.
- Structure CTEs/subqueries cleanly when helpful.
Return ONLY the SQL query, ending with a semicolon.
"""

        # Keep the user message compact to reduce tokens
        user = f"""User Query: {user_prompt}

Table: {schema['table_name']}
Columns: {', '.join(map(str, schema['columns']))}{extra}

Generate SQL:"""

        return self._run_chat(system, user.strip())

    # -------- Collapsed: validator OR fixer (one call does both) --------
    def validator_or_fix(self, user_prompt: str, sql: str, schema: Dict[str, list],
                         additional_context: str = "") -> str:
        """
        If query is fully correct → the model is asked to return EXACTLY 'YES'
        Else → it is asked to return ONLY the corrected SQL query (no explanations).
        This compresses the 'feedback + refine' cycle into a single call.

        Args:
            user_prompt: The user's natural-language question.
            sql: The SQL to validate.
            schema: Mapping with 'table_name' and 'columns' keys.
            additional_context: Optional context appended to the system message.

        Returns:
            The stripped LLM reply.
        """

        columns = schema['columns']
        table_name = schema['table_name']
        today_date = datetime.now()
        date_context = (f"Today's date is: {today_date.strftime('%Y-%m-%d')} — "
                        f"the {today_date.day} of {today_date.strftime('%B')}, {today_date.year}.")

        context_block = f"\n\n### Additional Context:\n{additional_context.strip()}" if additional_context else ""

        # `context_block` is free text from the knowledge base and may contain braces;
        # _run_chat passes the finished text as a variable so it is never parsed as a template.
        system = f"""You are an expert Redshift SQL validator and fixer.

Output rules:
- If the SQL is fully correct, respond EXACTLY with: YES
- Otherwise, respond ONLY with the corrected SQL query (no comments), ending with a semicolon.

Validation Rules:
1) Use only columns from the provided schema
2) Correct Redshift syntax
3) Match the user's filtering intent
4) Case-insensitive comparisons via LOWER() when appropriate
5) Date filtering (only if relevant):
   - emr_discharge_date: YYYY-MM-DD (use LIKE 'YYYY-%' for year)
   - Use LIKE with % when needed
   - Group date conditions in parentheses with other filters
    📌 **Date Range Conversion (Mandatory)**  
        - ALWAYS include date range columns in every query:  
        - For `emr_discharge_date`, ALWAYS convert using:
            TO_CHAR(MIN(TO_DATE(emr_discharge_date, 'YYYY-MM-DD')), 'YYYY-MM-DD') AS start_date,
            TO_CHAR(MAX(TO_DATE(emr_discharge_date, 'YYYY-MM-DD')), 'YYYY-MM-DD') AS end_date  
        - This conversion MUST be applied even if the user does not explicitly request the date range.  

6) Include all dimensions mentioned by the user
7) Use aggregations for metric asks
8) Include ORDER BY only if relevant
9) Proper grouping for CTEs/subqueries
10) For comparisons, include all compared dimensions

{date_context}{context_block}
"""

        human = f"""User Query: {user_prompt}

Table: {table_name}
Available Columns: {', '.join(map(str, columns))}

SQL to Validate:
{sql}
"""

        return self._run_chat(system, human).strip()

    # -------- Utility: robust SQL extractor --------
    def _slice_sql(self, text):
        """Return the span from the first SQL start keyword to the last semicolon, or None."""
        match = self._SQL_START_PATTERN.search(text)
        if not match:
            return None
        last_semicolon_index = text.rfind(';')
        # Covers both "no semicolon at all" (-1) and "semicolon only before the SQL start".
        if last_semicolon_index < match.start():
            return None
        return text[match.start(): last_semicolon_index + 1].strip()

    def extract_sql_query_from_string(self, query_string):
        """
        Extract SQL starting at SELECT/WITH up to the last semicolon.

        The contents of a Markdown code fence are tried first, so prose around the
        fence (which may itself contain semicolons) does not leak into the query.

        Returns:
            The extracted SQL string, or None when the input is not a string or no
            semicolon-terminated SELECT/WITH statement is found.
        """
        if not isinstance(query_string, str):
            return None

        fenced = self._CODE_FENCE_PATTERN.search(query_string)
        if fenced:
            sql_in_fence = self._slice_sql(fenced.group(1))
            if sql_in_fence:
                return sql_in_fence

        return self._slice_sql(query_string)

    def _is_approval(self, verdict) -> bool:
        """True when the validator reply is YES, ignoring case and surrounding quotes/punctuation."""
        return isinstance(verdict, str) and verdict.strip(self._VERDICT_NOISE).upper() == "YES"

    # -------- Public API: drastically fewer LLM calls while keeping quality --------
    def generate_and_iteratively_refine_sql(
        self,
        user_prompt: str,
        schema: Dict[str, list],
        domain_context: str = "",
        complex_instructions: str = "",
        max_iterations: int = 2,     # You can set 1 for even faster
        verbose: bool = True
    ) -> str:
        """
        New flow:
          1) Single-pass generation (already integrates domain + complex context)
          2) Single validator_or_fix call per iteration (YES or return corrected SQL)
          → Default path: ~3 calls total vs. many in the original

        Args:
            user_prompt: The user's natural-language question.
            schema: Mapping with 'table_name' and 'columns' keys.
            domain_context: Optional domain knowledge.
            complex_instructions: Optional extra instructions.
            max_iterations: Upper bound on validate-or-fix calls.
            verbose: When True, progress is printed.

        Returns:
            The last SQL that was generated or corrected. A validator reply that
            contains no SQL never replaces the current query.
        """
        if verbose:
            print("\n========== USER PROMPT ==========")
            print(user_prompt)

        # ---- Phase 1: Single-pass generation with combined context ----
        if verbose:
            print("\n========== PHASE 1: SINGLE-PASS GENERATION ==========")

        initial_sql = self.prompt_chain_single(
            user_prompt=user_prompt,
            schema=schema,
            domain_context=domain_context,
            complex_instructions=complex_instructions,
        )
        current_sql = self.extract_sql_query_from_string(initial_sql) or initial_sql
        if verbose:
            print("Initial SQL:")
            print(current_sql)

        # Combine contexts once for validation to keep prompts tiny
        combined_context = "\n\n".join(c for c in (domain_context, complex_instructions) if c)

        # ---- Phase 2+: Compact validation/fix loops (1 call per iteration) ----
        for i in range(1, max_iterations + 1):
            if verbose:
                print(f"\n========== VALIDATE/FIX LOOP {i} ==========")
            verdict_or_fix = self.validator_or_fix(
                user_prompt=user_prompt,
                sql=current_sql,
                schema=schema,
                additional_context=combined_context,
            )

            if self._is_approval(verdict_or_fix):
                if verbose:
                    print(f"✅ SQL validated at iteration {i}")
                break

            corrected_sql = self.extract_sql_query_from_string(verdict_or_fix)
            if not corrected_sql:
                # The reply was neither YES nor SQL (e.g. an explanation). Keep the
                # query we already have instead of returning prose as SQL.
                if verbose:
                    print("⚠️ Validator reply contained no SQL; keeping the current query.")
                break

            if verbose:
                print("Received corrected SQL:")
                print(corrected_sql)

            if corrected_sql == current_sql:
                # The "fix" is identical to the input; asking again cannot make progress.
                break
            current_sql = corrected_sql

        if verbose:
            print("\n========== FINAL SQL OUTPUT ==========")
            print(current_sql)

        return current_sql

# Instantiate the faster generator (compatible name): the instance keeps the
# `multi_stage_sql_generator` name so code that already refers to it keeps working.
# Constructing it with defaults builds a ChatOpenAI client inside __init__.
multi_stage_sql_generator = MultiStageSQLGenerator()

#### Scratch cell: manual test of MultiStageSQLGenerator

In [34]:
# # user_prompt = "What is the '42321610-SPINAL SCREWS OR SCREW EXTENSIONS' UNSPSC product utilization across physicians used in '0SG10AJ-FUSION 2-4 L JT W INTBD FUS DEV, POST APPR A COL, OPEN' procedure?"
# # user_prompt = 'Whats the total acquisition cost for the year 2024?'
# # user_prompt = 'Which Market has the highest total cost?'
# user_prompt = "compare the top 10 primary procedures by average cpc for Market Lorain and Market Youngstown for Service line WOMEN'S HEALTH for 2024"
# # user_prompt = input()

# print(f'user_prompt: {user_prompt}')

# ## Running the Chain | TableSelection
# selected_table_text = table_selection_object.get_table_name(user_prompt)

# #### ------------------------------------------------ ## ------------------------------------------------ ####
# #### Retrieving best matched column info for SQL generation based on user prompt and selected table
# #### ------------------------------------------------ ## ------------------------------------------------ ####

# table_to_use, filtered_df, best_match_columns, table_columns, column_info_from_knowledge_base, prompt_to_use_for_complex_question, columns_info_to_generate_sql = embedding_task_helper.get_column_info_for_sql_generation(user_prompt, selected_table_text)

# schema = {
#     "table_name": table_to_use,
#     "columns": table_columns
# }

# domain_dict = column_info_from_knowledge_base
# addons = prompt_to_use_for_complex_question

#### Running the SQL Multi Stage Generation Way ######
# Generate SQL using the function-based approach
# final_sql = multi_stage_sql_generator.generate_and_iteratively_refine_sql(
#     user_prompt=user_prompt,
#     schema=schema,
#     domain_context=domain_dict,
#     complex_instructions=prompt_to_use_for_complex_question,
#     max_iterations = 2)

# generated_sql_query, wf = sql_query_generator.get_result_from_sql_query(generated_sql_query=final_sql)

# SQLResponseGenerator 

In [None]:
class SQLResponseGenerator:
    """
    Turns a question, the SQL that answered it and the resulting DataFrame into a
    natural-language reply by prompting an LLM.

    Frames with at most SMALL_DF_MAX_ROWS rows are sent in full. Larger frames are
    previewed (first PREVIEW_ROWS rows) together with a download link obtained from
    `s3_utils.get_redirect_link_from_df`.
    """

    SMALL_DF_MAX_ROWS = 100   # frames up to this size are sent to the LLM in full
    PREVIEW_ROWS = 20         # rows shown for larger frames

    def __init__(self, s3_utils):
        """
        Args:
            s3_utils: Object exposing `get_redirect_link_from_df(df)`; the third item
                of its return value is used as the download link.
        """

        ## __ GPT-4.1 Response __
        model_name = 'gpt-4.1-2025-04-14'
        self.llm = ChatOpenAI(model=model_name, temperature=0)

        # ____ GPT 5 Response ____
        # model_name = 'gpt-5-2025-08-07'
        # model_name = "gpt-5-mini-2025-08-07"
        # model_name = "gpt-5-nano-2025-08-07"
        # self.llm = ChatOpenAI(model=model_name, temperature=1)

        self.get_redirect_link_from_df = s3_utils.get_redirect_link_from_df

        self.generic_response_guideline = """
# Generic Guidelines:
- If the data in the DataFrame are in abbreviated form, do not expand them.
- If the DataFrame is empty, politely let the user know that there is no available answer for their question.
- **Avoid mentioning any details related to the table used in the SQL query.
- **Do not reuse the exact column names from the SQL query when presenting data. Instead, use suitable synonyms or natural-language alternatives that convey the same meaning.**
- **State the data exactly as it appears in the DataFrame without summarizing, interpreting, or making any assumptions about trends.** Only list each value as observed.
- Review the SQL query to determine the basis of sorting, and include this information in the response to clarify the order of results.
- Only mention sorting or ordering if the user’s question explicitly asks about ranking, order, top, or lowest; otherwise, do not include it.
- *Use appropriate commas when formatting any cost-related numbers to make them easier to read.
- If the user requests any 'CHARTS' (e.g., Pie Chart, Bar Chart), do not mention it in your response.
- If all values for a parameter are identical across entities (e.g., identical scores across Entity), explicitly list the names of these entities to give the user clarity.
- Tone of Conversation: Polite and humble, professional, concise and to the point.
- Be as detailed and informative as possible while adhering strictly to these guidelines.
- ***When referencing data from the DataFrame in the response, do not cut, modify, or remove any part of the values. This includes retaining all numeric identifiers, leading information, and any symbols or abbreviations as they appear in the DataFrame.
- Do NOT add any closing or follow-up phrases such as "Let me know if you need more information", "You can ask for a breakdown", "Feel free to request", or similar open-ended invitations.
- End the response right after presenting the last fact or data point. No extra commentary or summary should follow.

## Data Analysis Response Rules (VERY IMPORTANT)
- Priority Order: Question > DataFrame > SQL Query
- **Core Rule: Only include what's explicitly requested. If a parameter isn't asked for, exclude it entirely.**
- Before responding, verify: Does this directly answer only what was asked?

"""
        
        self.formatting_requirements = """
## Date and Time Formatting Requirements:
	1.	Date Formats:
        • Always preserve the full date granularity available in the DataFrame.
        • If day is present → output in YYYY-MM-DD (Example: 2023-05-14)
        • If only month/year is present → output in Month YYYY (Example: July 2023)
        • Never drop the day when it exists.
	2.	Month Representation:
	    •	Use full month names with year (Example: July 2023)
	    •	Convert any numeric months to full text format.
	3.	Chronological Ordering:
	    •	When DataFrame contains temporal data (dates/months):
            a. Primary sort: Year (ascending)
            b. Secondary sort: Month (ascending)
            c. Tertiary sort: Day (ascending, if available)
                •	Example order: 2023-01-15, 2023-01-30, 2023-02-10
                •	Apply this sorting regardless of how data appears in the DataFrame.
                •	Maintain this order even if the SQL query specifies a different ordering.
	4.	Month-Specific Analysis:
	    •	When a specific month is referenced, provide focused insights for that timeframe using the full month format.
	    •	If the data contains day-level granularity for that month, include the exact days.

## Formatting Rules for Responses:
    1. Currency Formatting:
    - Use $ for CPC, Spend, Price, Cost, or Charge.
    - Numbers must:
        - Include commas (`,`) for readability.
        - Be rounded to two decimal places.
        - Have the currency symbol prefixed (e.g., `$36,614.58`).
        - Eg: For Percentage: 32.56%

    2. Non-Currency Numbers:
    - Exclude the currency symbol for metrics like Outcome Scores or Encounters.
    - Format with commas and two decimal places (e.g., `36,614.58`).

    3. Presentation:
    - Ensure numbers align with context (e.g., CPC is always currency).

    Example Response:
    - December 2023: $34,494.08
    - November 2023: $36,826.91
    - Percentage: 32.56%
    - Treat market share values as percentages (e.g., 0.46 → 46%)."""
        
        self.comparison_related_guideline = """
-----------------------
## Comparison Guideline
    - For comparison requests, always present results in a proper *Markdown tabular format*,
    - Each row = entity; each column = metric.
    - Show data exactly as in the DataFrame, no edits.
    - Mention sorting basis, use commas for numbers.
    - Keep tone polite, humble, professional, concise.
-----------------------
"""

    @staticmethod
    def _escape_braces(text) -> str:
        """
        Double every `{` and `}` so the text survives `PromptTemplate.from_template`,
        which would otherwise read them as placeholders.
        """
        return str(text).replace('{', '{{').replace('}', '}}')

    @staticmethod
    def _format_long_date(timestamp) -> str:
        """Format a timestamp as 'Month D, YYYY' with no leading zero on the day."""
        return f"{timestamp.strftime('%B')} {timestamp.day}, {timestamp.year}"

    def _run_prompt(self, prompt: str) -> str:
        """Send an already-rendered (brace-escaped) prompt to the LLM and return the stripped reply."""
        chain = PromptTemplate.from_template(prompt) | self.llm | StrOutputParser()
        return chain.invoke({}).strip()

    # The return annotation is quoted on purpose: annotations are evaluated when the
    # method is defined, and `Tuple` is not among the names the notebook's import cell
    # pulls from `typing`.
    def get_date_range(self, df: Optional[pd.DataFrame]) -> "tuple[str, Optional[pd.DataFrame]]":
        """
        Build the date-range instruction for the LLM from the 'start_date' and
        'end_date' columns and return it with a copy of the frame that no longer
        contains those columns.

        Column names are matched after normalization (trimmed, lower-cased, spaces
        replaced by underscores). The input frame is never modified.

        Returns:
            (instruction, frame):
              - input is not a DataFrame, is empty, or lacks either date column:
                ("", the input object unchanged);
              - neither column, or only one, holds a parseable date:
                ("", normalized copy without the two date columns);
              - otherwise: (instruction text, copy without the date columns whose
                remaining column names are re-titled, e.g. 'total_cost' -> 'Total Cost').
        """
        if not isinstance(df, pd.DataFrame):
            return "", df

        # Normalize column names (lowercase, spaces -> underscores) on a copy
        normalized_cols = {col: str(col).strip().lower().replace(" ", "_") for col in df.columns}
        df_copy = df.rename(columns=normalized_cols)

        # Nothing to extract: hand back the caller's frame untouched
        if df_copy.empty or not {'start_date', 'end_date'}.issubset(df_copy.columns):
            return "", df

        start_dates = pd.to_datetime(df_copy['start_date'], errors='coerce')
        end_dates = pd.to_datetime(df_copy['end_date'], errors='coerce')

        # The date columns are dropped whether or not they hold usable values
        df_copy = df_copy.drop(columns=['start_date', 'end_date'], errors='ignore')

        if start_dates.dropna().empty or end_dates.dropna().empty:
            return "", df_copy

        min_str = self._format_long_date(start_dates.min())
        max_str = self._format_long_date(end_dates.max())

        date_instruction = (
            f"📅 **Data Range Instruction**:\n"
            f"\t- The data used for this query spans from **{min_str}** to **{max_str}**.\n"
            f"\t- Always mention the **starting date** and **ending date** of the data used for this query "
            f"at the very **BOTTOM** of your response, even if the user did not explicitly ask for it.\n"
            f"\t- Always show the start and end dates exactly as this format: Full_Month DAY, YEAR(YYYY)."
        )

        ## Inverse Normalization
        df_copy.columns = (
                df_copy.columns
                .str.replace("_", " ")                  # replace underscores with space
                .str.title()                            # capitalize each word
            )

        return date_instruction, df_copy

    def _generate_standard_prompt_for_small_df(self, question, sqlQuery, df, date_instruction):
        """
        Build the prompt that sends the whole DataFrame to the LLM.

        The returned text is a template string: every caller-supplied value is
        brace-escaped so it can be passed to `PromptTemplate.from_template`.
        """
        
        empty_df_instruction = """
When the DataFrame is empty, inform the user that no data is available, which may be due to missing information or because the requested criteria did not match any records.
""" if df.empty else ""
        
        prompt = f'''You are a friendly assistant working for SupplyCopia. Based on the user’s question, SQL Query, and DataFrame, provide a detailed, conversational response that directly addresses the user’s question.
        
{self.generic_response_guideline}

{self.formatting_requirements}

{self.comparison_related_guideline}

------------  
USER'S QUESTION: {self._escape_braces(question)}  
------------  
SQL QUERY: {self._escape_braces(sqlQuery)}  
------------  
DataFrame: {self._escape_braces(df.to_json())}  
------------  

{empty_df_instruction}

{self._escape_braces(date_instruction)}

RESPONSE TO USER:  
'''
        return prompt

    def _generate_top_rows_with_download_link_prompt(self, question, sqlQuery, df, redirect_link, date_instruction):
        """
        Build the prompt for a large DataFrame: only the first PREVIEW_ROWS rows are
        included, together with the download link for the full result.

        The returned text is a template string: every caller-supplied value is
        brace-escaped so it can be passed to `PromptTemplate.from_template`.
        """

        top_df = df.head(self.PREVIEW_ROWS)
        total_rows = len(df)

        columns_to_use = [str(col) for col in df.columns]

        prompt = f'''You are a friendly assistant working for SupplyCopia. Based on the user’s question, SQL Query, and DataFrame, provide a detailed, professional response that shows the top results and provides a download link for all data.

# Specific Guidelines:
    - Only the top {self.PREVIEW_ROWS} rows out of a total of {total_rows} rows are shown below. Mention that the full dataset is available via the download link.

Download Link: [Click here to download full results]({self._escape_braces(redirect_link)})

{self.generic_response_guideline}

# Formatting:
    {self.formatting_requirements}

{self.comparison_related_guideline}    

------------
USER'S QUESTION: {self._escape_braces(question)}
------------
SQL QUERY: {self._escape_braces(sqlQuery)}
------------
Showing Top {self.PREVIEW_ROWS} of {total_rows} Rows (Full dataset is downloadable via the download link given below.)
------------
DataFrame: {self._escape_braces(top_df.to_json())}
------------

Columns in DataFrame: {self._escape_braces(', '.join(columns_to_use))}

{self._escape_braces(date_instruction)}

RESPONSE TO USER:
    '''
        return prompt

    def get_sql_response(self, question, sqlQuery, df):
        """
        Produce the natural-language answer for a question from its SQL and result frame.

        Args:
            question: The user's question.
            sqlQuery: The SQL that produced `df`.
            df: Result DataFrame; None is treated as an empty result.

        Returns:
            The LLM reply, stripped. On the large-frame path 'http://' is rewritten
            to 'https://' in the reply.
        """

        ## Date Instruction
        date_instruction, df_copy = self.get_date_range(df)
        if date_instruction:
            df = df_copy.copy()

        # A missing result is reported to the user the same way as an empty one
        if df is None:
            df = pd.DataFrame()

        if len(df) <= self.SMALL_DF_MAX_ROWS:
            # Use simple non-chunked prompt
            prompt = self._generate_standard_prompt_for_small_df(question, sqlQuery, df, date_instruction)
            return self._run_prompt(prompt)

        # Handle large DataFrame: show top rows + download link
        try:
            _, _, redirect_link = self.get_redirect_link_from_df(df)
        except Exception as e:
            # Best effort: a failed upload must not prevent answering with a preview
            print(f"[Error] Failed to get redirect link: {str(e)[:300]}")
            redirect_link = None

        if redirect_link:
            prompt = self._generate_top_rows_with_download_link_prompt(
                question, sqlQuery, df, redirect_link, date_instruction
            )
        else:
            # Fallback if redirect link fails — show top rows without download message
            print("[Warning] Redirect link not available. Showing top 20 results only.")
            prompt = self._generate_standard_prompt_for_small_df(
                question, sqlQuery, df.head(self.PREVIEW_ROWS), date_instruction
            )

        return self._run_prompt(prompt).replace('http://', 'https://')


# Shared instance used by the rest of the notebook. `s3_utils` must already be defined
# by an earlier cell and expose `get_redirect_link_from_df`, which the constructor reads.
sql_response_generator = SQLResponseGenerator(s3_utils=s3_utils)

In [37]:
# df = df.head(20)

In [89]:
# # s3_utils.get_redirect_link_from_df(df)
# sql_response_chain_result = sql_response_generator.get_sql_response(user_prompt_to_use, generated_sql_query, df)

# print(f'Question: {user_prompt_to_use}\n')
# Markdown(sql_response_chain_result)

In [39]:
# print(f'Question: {user_prompt_to_use}')
# Markdown(sql_response_chain_result)

# GeneralQueryAPIHandler  

In [40]:
class GeneralQueryAPIHandler:
    """Thin client that posts a user question to a prediction endpoint and returns its 'text' field."""

    def __init__(self, api_url: str, timeout: float = 60.0):
        """
        Initializes the GeneralQueryAPIHandler class with the specified API URL.

        Args:
            api_url (str): The URL of the API endpoint.
            timeout (float): Seconds to wait for the server before giving up. Without a
                timeout `requests` waits indefinitely, which would hang the caller on a
                stalled connection.
        """
        self.api_url = api_url
        self.timeout = timeout

    def get_result(self, user_prompt: str) -> str:
        """
        Sends a POST request to the API with the provided user prompt and returns the text response.

        Args:
            user_prompt (str): The user's prompt to send to the API.

        Returns:
            str: The 'text' value from the API response; 'No text found in response' when
            the JSON body has no such key (or is not a JSON object); or
            'Failed to send payload to the API.' when the request fails, times out,
            returns an error status, or the body is not valid JSON.
        """
        payload = {"question": user_prompt}

        try:
            response = requests.post(self.api_url, json=payload, timeout=self.timeout)
            response.raise_for_status()  # Raises an HTTPError for bad responses
            body = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers JSON decoding errors on requests versions whose
            # JSONDecodeError is not a RequestException subclass.
            print(f"Error during API request: {e}")
            return "Failed to send payload to the API."

        # A JSON array/scalar has no .get(); treat it like a body without 'text'
        if not isinstance(body, dict):
            return 'No text found in response'
        return body.get('text', 'No text found in response')

# Prediction endpoint for general (non-SQL) questions, and the shared handler that posts to it.
API_URL = "https://chat.supplycopia.com/api/v1/prediction/b19b865b-1ac7-448b-940f-e78857b3fe76"
general_query_api = GeneralQueryAPIHandler(API_URL)

# HighchartModule

In [41]:
class HighchartModule:
    """
    Asks an LLM for a complete Highcharts HTML page for a question and its result
    DataFrame, and writes that page to `visualization/<user_id>/`.
    """

    # Markdown code fence, optionally tagged as html, with the page as its content.
    _HTML_FENCE_PATTERN = re.compile(r"```(?:html)?[ \t]*\r?\n?(.*?)```", re.IGNORECASE | re.DOTALL)

    def __init__(self):
        pass

    def function_to_check_for_chart_possibility(self, df):
        """
        Decide whether a chart should be attempted.

        Returns:
            bool: True only when `df` has at least two rows and at least two columns.
        """
        if df is None:
            return False
        return len(df) > 1 and len(df.columns) > 1
                
    def context_generation_for_js_config(self, human_message, df):
        """
        Generates the context for JavaScript configuration by creating a formatted string
        containing the user's question and a preview of the provided DataFrame.

        Args:
            human_message (str): The user's question.
            df (DataFrame): The DataFrame containing the data; its first 50 rows are
                included as CSV.

        Returns:
            str: A formatted string with the user's question and the DataFrame preview.
        """

        output_context = f'''
        ------
        USER'S QUESTION: `{human_message}`
        ------
        DATAFRAME: \n{df.head(50).to_csv(index=False)}
        ------'''
        return output_context.strip()

    def getting_prompt_for_js_highcharts(self, context_for_chart_generation):

        """
        Constructs the prompt for generating Highcharts JS code from the chart context.

        The returned text is a *template string* meant for `PromptTemplate.from_template`:
        literal braces appear doubled, including those inside the supplied context.

        Args:
            context_for_chart_generation (str): The question/DataFrame context produced by
                `context_generation_for_js_config`.

        Returns:
            str: The prompt template text for Highcharts JS code generation.
        """

        ## Color Codes to Use

        color_codes = [
            "#006292", "#1A729D", "#98D7FC", "#1AC5FB", "#1AABD9", "#1A73CB", "#4A90E5", "#6DEDD1", "#BEFBBF", "#83C38A", "#1A9D6D",
            "#34C670", "#4ED1B6", "#28B78F", "#FFE886", "#F8EEBE", "#D59D50", "#EC8047", "#C36C41", "#FFA56F", "#EA654D", "#C64949",
            "#00BFFA", "#00A2D5", "#0063C5", "#3684E2", "#5DEBCC"]

        # The question and the CSV preview are free text; braces in them would otherwise
        # be parsed as template placeholders and make formatting fail.
        safe_context = str(context_for_chart_generation).replace('{', '{{').replace('}', '}}')

        # Brace levels below: `{{{{` in this f-string -> `{{` in the template -> `{` in the
        # prompt the model finally sees.
        system_message = f'''As a seasoned NodeJS developer with expertise in the 'highcharts' library, you are tasked with providing assistance on chart/plot related queries. Your responses should focus solely on delivering complete HTML, CSS, and JavaScript code snippets tailored to the specific requirements of each chart query based on the provided USER'S QUESTION AND DATAFRAME below.

- Use the JavaScript code generated to render and display a graph/chart/table.
- Provide a complete HTML document including HTML, CSS, and JavaScript code.

- Strictly Follow the Response Template:
    ```
    <html>
    <head>
        <script src="https://code.highcharts.com/highcharts.js"></script>
        <script src="https://code.highcharts.com/modules/exporting.js"></script>
        <script src="https://code.highcharts.com/modules/export-data.js"></script>
        <script src="https://code.highcharts.com/modules/accessibility.js"></script>
    </head>
    <body>
        <div id="container_id"></div>
        <script>
        eval(Highcharts.chart('container_id', {{{{
            credits: {{{{ enabled: false }}}},
            exporting: {{{{
                enabled: true,
                buttons: {{{{
                    contextButton: {{{{
                        menuItems: ['downloadPNG', 'downloadPDF']
                    }}}}
                }}}}
            }}}},
            ...
        }}}}));
        </script>
    </body>
    </html>
    ```
# Render the graph after the entire text is loaded.
# Replace the placeholder ... with appropriate elements.

# Instruction for the Use of Colors: 
    ** Use distinct colors for each data point to enhance visual appeal.
    ** Colors must be exclusively selected from the following list: {', '.join(color_codes)}.
    ** Do not repeat colors unless the number of data points exceeds the color options.
    ** The objective is to ensure the chart is visually engaging and easy to differentiate at a glance.

# Chart Types: Line, Bar, Column, Pie, Scatter, Bubble etc. Use the most appropriate chart type based on the given data.

# When dealing with Numbers inside the data: 
    - Must use the currency symbol appropriately ($, €) 
        - (valuePrefix: '$') or (valuePrefix: '€')
    - Round numbers to two decimal places (valueDecimals: 2).
    - Use commas (,) in between digits inside numbers (For this, Use - Highcharts.numberFormat).
    - *Example Numbers: 
        - 36,614,583.26 (Encounter, Outcome Score etc where we are not talking currency symbol is not needed.)
        - $36,614,583.26 (Spend/Price/Cost/Charge etc, where we need $ currecty symbol.)
        - €36,614,583.26 (Spend/Price/Cost/Charge etc, where we need € currecty symbol.)

# Use full month name whenever possible.
# Must Use Proper Titles inside the tooltips.

# Define the `tooltip` section inside Highcharts using a custom formatter function.
# Follow these clear rules step-by-step:
    1. **Always start the tooltip with the category label:**
        - Use: `this.key` — never use `this.x` or fallback logic.
        - It should be the first line of the tooltip.
        - Example line of code:
            ```javascript
            let tooltip = `<b>${{{{this.key}}}}</b><br/>`;
            ```
        - Only exception is for the date. If Date is given on X Axis - then use appropriate tooltip.

    2. For each point in the tooltip (loop over this.points), use the following logic to decide the correct display format based on the series name:

        - If the series name contains **'Spend'**, **'Cost'**, **'Price'**, or **'Charge'** (case-insensitive):
            - Use a **$ currency symbol**, format to two decimal places, and add commas.
            - Example: `Total Spend: <b>$36,614,583.26</b>`

        - For all other types:
            - Format as a plain number with two decimal places and commas.
            - Example: `Volume Index: <b>14,284.68</b>`

    3. Use `Highcharts.numberFormat(point.y, decimals, '.', ',')` to format numbers properly.
    4. Return the final tooltip string after the loop.
    5. Wrap the whole formatter inside:
        ```
        tooltip: 
            formatter: function () 
                // your logic here
            shared: true
        ```

# Do not skip any condition or return incomplete tooltip content.
# Be sure to apply the correct number formatting and symbol for each type.
# Always ensure tooltips are clean, aligned with the series meaning, and easily readable.
# If the question is asked in different language, use the same language to provide the response and also use relevant currency symbol.
# Your response should only contain the HTML code following the provided template.
------------------
{safe_context}
----------------
HIGHCHARTS GRAPH:
'''
        return system_message.strip()

    def highcharts_js_chain_function(self, context_for_chart_generation):
        
        """
        Creates the prompt | LLM | parser chain that generates Highcharts HTML for the given context.

        Args:
            context_for_chart_generation (str): The question/DataFrame context to embed in the prompt.

        Returns:
            A runnable chain; call `.invoke({})` on it to obtain the model's text output.
        """

        # Initialize the language model
        # llm = ChatOpenAI(model = 'gpt-4o-2024-08-06', temperature=0)
        llm = ChatOpenAI(model='gpt-4.1-2025-04-14', temperature=0)

        # Prompt text (already a template string with escaped braces)
        system_message_prompt = self.getting_prompt_for_js_highcharts(context_for_chart_generation)

        final_prompt_template_for_highcharts_js = PromptTemplate.from_template(system_message_prompt)

        # Create a basic single chain
        final_chain = final_prompt_template_for_highcharts_js | llm | StrOutputParser()

        return final_chain

    def saving_visualization_content_as_html(self, html_content, user_id):
        """
        Saves the provided HTML content to a uniquely named file under `visualization/<user_id>/`.

        Args:
            html_content (str): The HTML content to be saved.
            user_id: Identifier used as the sub-directory name.
                NOTE(review): it is interpolated into the path as-is; confirm callers never
                pass untrusted values containing path separators or '..'.

        Returns:
            tuple[str, str]: (relative path of the saved file, file name only).
        """

        ## Create the Directory if not Available yet
        folder_directory = f'visualization/{user_id}'
        os.makedirs(folder_directory, exist_ok=True)

        ## File name: last group of a random UUID keeps names short but unique enough per user
        unique_identifier = str(uuid.uuid4()).split('-')[-1]
        image_name = f"viz_{unique_identifier}.html"
        image_file_name_with_path = f"{folder_directory}/{image_name}"

        # Normalize a tooltip header the model sometimes emits with this.x instead of this.key
        html_content = html_content.replace('let tooltip = `<b>${this.x}</b><br/>`','let tooltip = `<b>${this.key}</b><br/>`').strip()
        # Open a new file in write mode and write the HTML content
        with open(image_file_name_with_path, 'w', encoding='utf-8') as file:
            file.write(html_content)

        return image_file_name_with_path, image_name

    def _strip_code_fence(self, llm_output):
        """
        Return the HTML page from the model output without its Markdown fence.

        Uses the content of the first fenced block when one is present, which also drops
        any prose before it and any trailing newline after it; otherwise falls back to
        trimming stray fence characters from both ends.
        """
        text = llm_output.strip()
        fenced = self._HTML_FENCE_PATTERN.search(text)
        if fenced:
            return fenced.group(1).strip()
        return text.replace('`html', '').strip('`').strip()
    
    def get_highchart_info(self, human_message, df, user_id):
        """
        Generate and save a Highcharts page for the question and DataFrame.

        Args:
            human_message (str): The user's question.
            df (DataFrame): Result data to chart.
            user_id: Identifier used for the output sub-directory.

        Returns:
            tuple: (saved file path, saved file name, html content), or
            (None, None, None) when the DataFrame is not chartable.
        """

        if not self.function_to_check_for_chart_possibility(df):
            return None, None, None

        context_for_chart_generation = self.context_generation_for_js_config(human_message, df)
        highcharts_js_chain = self.highcharts_js_chain_function(context_for_chart_generation)
        html_content = self._strip_code_fence(highcharts_js_chain.invoke({}))
        saved_image_file_path, saved_image_name = self.saving_visualization_content_as_html(html_content, user_id)
        return saved_image_file_path, saved_image_name, html_content
      
# Shared chart generator instance (the `hichart_module` spelling is the name used elsewhere in the notebook).
hichart_module = HighchartModule()

In [42]:
# saved_image_file_path, saved_image_name, html_content = hichart_module.get_highchart_info(user_prompt_to_use, df, user_id)
# saved_image_name

# DynamicRecommendationChainManager | Endpoint: /recommendation

In [43]:
class DynamicRecommendationChainManager:
    """
    Generates role-aware, data-driven recommendations with an LLM.

    The user's role is first classified into a management level plus a short
    responsibilities description. Both are then embedded, together with the
    question and a preview of the DataFrame, into a prompt that is run through
    ``PromptTemplate | llm | StrOutputParser``.
    """

    # Only this many leading DataFrame rows are serialised into the prompt.
    MAX_ROWS_IN_PROMPT = 50

    # Context returned when no role is supplied (no LLM call is made).
    GENERAL_ROLE_CONTEXT = (
        "- Management Level: General Management\n"
        "- Role Responsibilities: Responsible for general data analysis and insights to identify trends, "
        "opportunities, and risks. Recommendations should focus on actionable insights derived from the data."
    )

    def __init__(self):
        self.llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)

    @staticmethod
    def _escape_braces(text):
        """Double curly braces so ``PromptTemplate`` treats ``text`` as literal content."""
        return str(text).replace('{', '{{').replace('}', '}}')

    @staticmethod
    def _parse_role_context(role_context):
        """Extract ``(management_level, role_responsibilities)`` from a classification text."""

        lines = [line.strip() for line in role_context.split("\n")]

        def first_value(prefix, default):
            for line in lines:
                if line.startswith(prefix):
                    # Keep everything after the label, including any further colons.
                    return line[len(prefix):].strip()
            return default

        management_level = first_value("- Management Level:", "General Management")
        role_responsibilities = first_value(
            "- Role Responsibilities:", "General responsibilities for this role."
        )
        return management_level, role_responsibilities

    def classify_role_and_generate_context(self, role_input):
        """
        Uses the LLM to classify the user's role into a hierarchy and derive responsibilities dynamically.
        If role_input is blank (or None), returns a general recommendation context.

        Args:
            role_input (str): The user's role or title (e.g., "CMO", "Chief Marketing Officer").

        Returns:
            str: A classification response or general recommendation context.
        """
        if not (role_input or "").strip():
            # Return general recommendation context if no role is provided
            return self.GENERAL_ROLE_CONTEXT

        # The f-string below is compiled as a prompt template afterwards, so any
        # braces typed by the user must be escaped to stay literal.
        safe_role = self._escape_braces(role_input)

        # Classification prompt for non-blank role inputs
        classification_prompt = f'''
        You are an AI assistant specializing in organizational hierarchy and role classification.
        Based on the user-provided role or title, identify the appropriate management level 
        (e.g., Executive Leadership, Senior Management, Mid-Level Management, Operational Management, etc.) 
        and provide a concise description of the responsibilities associated with this role.

        Role Provided: "{safe_role}"

        Respond with:
        - Management Level: (e.g., "Executive Leadership")
        - Role Responsibilities: (e.g., "Responsible for overseeing high-level strategic operations...")
        '''

        chat_prompt_template_for_classification = PromptTemplate.from_template(classification_prompt)
        dynamic_recommendation_chain_classification = chat_prompt_template_for_classification | self.llm | StrOutputParser()
        classification_response = dynamic_recommendation_chain_classification.invoke({}).strip()
        return classification_response

    def generate_dynamic_recommendation_prompt(self, question, df, role_input):
        """
        Generates a system message prompt for dynamic role-based recommendations.

        The returned text is a prompt *template*: literal braces coming from the
        question, the role context and the DataFrame JSON are doubled so that
        ``PromptTemplate.from_template`` does not read them as variables.

        Args:
            question (str): User's query or question.
            df (DataFrame): The DataFrame containing data for analysis.
            role_input (str): The user's role or title.

        Returns:
            str: A system message prompt for generating dynamic recommendations.
        """
        role_context = self.classify_role_and_generate_context(role_input)
        management_level, role_responsibilities = self._parse_role_context(role_context)

        print(f'\n==============================\nmanagement_level: {management_level}')
        print(f'role_responsibilities: {role_responsibilities}\n==============================\n')

        # Tell the model about truncation whenever rows are actually dropped.
        preview_rows = self.MAX_ROWS_IN_PROMPT
        if len(df) > preview_rows:
            display_size_note = f"**Note: The DataFrame includes the first {preview_rows} entries out of a total of {len(df)} rows. Let the user know about it."
        else:
            display_size_note = ""

        safe_level = self._escape_braces(management_level)
        safe_responsibilities = self._escape_braces(role_responsibilities)
        safe_question = self._escape_braces(question)
        dataframe_json = self._escape_braces(df.head(preview_rows).to_json())

        system_message_prompt = f'''You are an AI Chatbot for Supplycopia that provides data-driven recommendations based on user roles and responsibilities.

Role Context:
- Management Level: {safe_level}
- Role Responsibilities: {safe_responsibilities}

Recommendation Guidelines:
1. Focus on actionable recommendations derived from DataFrame analysis
2. Each recommendation must:
   - Link directly to data evidence
   - Align with user's role/responsibilities
   - Address specific question context
   - Include quantifiable metrics where possible

3. Prioritize by:
   - Impact (based on data)
   - Implementation feasibility 
   - Role relevance
   - Urgency (if applicable)

4. Format:
   - Start with clear recommendation (Elaborately)
   - Support with specific data points
   - End with expected outcome

{display_size_note}

Tone: Professional and direct

Query: {safe_question}
DataFrame: {dataframe_json}
'''

        return system_message_prompt

    def run_dynamic_recommendation_chain(self, question, df, role_input=''):
        """
        Runs the dynamic recommendation chain to generate tailored recommendations.

        Args:
            question (str): User's query or question.
            df (DataFrame): The DataFrame containing data for analysis.
            role_input (str): The user's role or title.

        Returns:
            str: Recommendations based on the provided query and data.
        """
        system_message_prompt = self.generate_dynamic_recommendation_prompt(question, df, role_input)
        chat_prompt_template = PromptTemplate.from_template(system_message_prompt)
        dynamic_recommendation_chain = chat_prompt_template | self.llm | StrOutputParser()
        query_response = dynamic_recommendation_chain.invoke({})
        return query_response.strip()

# Module-level DynamicRecommendationChainManager instance; constructing it creates the ChatOpenAI client.
recommendation_module = DynamicRecommendationChainManager()

In [44]:
# recommendations = recommendation_module.run_dynamic_recommendation_chain(user_prompt_to_use, df, '')
# # # recommendations = dynamic_recommendation_chain_manager.run_dynamic_recommendation_chain(user_prompt_to_use, df, 'Software')

# # print(recommendations)

# FollowUpQuestions - Recommends two follow-up questions based on the question the user just asked

### Previous Approach -> FollowUpQuestions

In [45]:
# import random

# # Define the FollowUpQuestions model
# class FollowUpQuestions(BaseModel):
#     questions: List[str]

#     def to_list(self) -> List[str]:
#         return self.questions
    
# class QuestionSuggestionsGenerator:
    
#     def __init__(self):
#         """
#         Initializes the QuestionSuggestionsGenerator with a language model.
#         """
#         self.llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)
#         self.structured_llm = self.llm.with_structured_output(FollowUpQuestions)

#     def generate_system_message(self, user_prompt: str, generated_sql_query: str, table_to_use: str, table_columns: list) -> str:
#         """
#         Generates the system message prompt for suggesting questions.

#         Args:
#             user_prompt (str): The original user's prompt.
#             table_to_use (str): The name of the table being used.
#             table_columns (list): List of available columns in the table.
#             generated_sql_query (str): Sql query generated to get the result.


#         Returns:
#             str: The formatted system message prompt.
#         """

#         # Get today's date
#         today_date = datetime.now()
#         # Format the date as YYYY-MM-DD
#         formatted_date = today_date.strftime('%Y-%m-%d')

#         system_message = f'''
# You are an AI assistant tasked with generating two relevant follow-up questions based on the user's current query and the SQL query executed to obtain the results.

# **Context:**
# - **Table:** {table_to_use}
# - **Columns:** {', '.join(table_columns)}
# - **Generated SQL Query:** {generated_sql_query}

# **Guidelines:**
# 1. **Relevance:** Each follow-up question must directly relate to the user's original query and the data retrieved by the SQL query.
# 2. **Simplicity:** Generate simple and clear questions that are easy to understand.
# 3. **Column Transformation:** Transform column names to their full business names when generating follow-up questions:
#    - `order_id` → `Order Number`
#    - `cust_name` → `Customer Name`
#    - `ord_date` → `Order Date`
# 4. **Focus Areas:** Concentrate on basic counting, summing, or simple aggregations related to the data.
# 5. **Conciseness:** Keep each question under 10 words.
# 6. **Direct Answerability:** Ensure questions can be answered by direct lookups without requiring complex computations.
# 7. **Single Metrics/Periods:** Target single metrics or specific time periods to maintain focus.
# 8. **Avoid Transformations:** Do not include data transformations or complex operations in the questions.
# 9. **Use Available Columns:** Base questions solely on the columns available in the specified table.
# 10. **Date Reference:** Today's Date: {formatted_date}
# 11. **SQL Utilization:** Use the `generated_sql_query` to inform and enhance the relevance of the follow-up questions.

# **User Query:** {user_prompt}

# **Format:**
# ["...", "..."]

# **Note:** Provide only the numbered questions without any explanations or additional text.
# '''

#         return system_message.strip()

#     def generate_question_chain(self, user_prompt: str, generated_sql_query: str, table_to_use: str, table_columns: list):
#         """
#         Creates a language model chain to generate question suggestions.

#         Args:
#             user_prompt (str): The user's original query.
#             table_to_use (str): The name of the table being used.
#             table_columns (list): List of available columns in the table.

#         Returns:
#             Chain: A chain object for generating question suggestions.
#         """
#         system_message = self.generate_system_message(user_prompt, generated_sql_query, table_to_use, table_columns)
        
#         # Create the system message prompt template
#         system_message_prompt = SystemMessagePromptTemplate.from_template(system_message)
        
#         # Create the human message prompt
#         human_message_prompt = HumanMessagePromptTemplate.from_template(
#             "Generate two question suggestions based on this query."
#         )

#         # Combine prompts into a chat prompt template
#         chat_prompt_template = ChatPromptTemplate.from_messages([
#             system_message_prompt,
#             human_message_prompt
#         ])

#         # Create the chain
#         question_chain = chat_prompt_template | self.structured_llm

#         return question_chain

#     def get_suggestions(self, user_prompt: str, generated_sql_query: str, table_to_use: str, table_columns: list) -> FollowUpQuestions:
#         """
#         Gets question suggestions based on the user's prompt and table information.

#         Args:
#             user_prompt (str): The user's original query.
#             table_to_use (str): The name of the table being used.
#             table_columns (list): List of available columns in the table.

#         Returns:
#             str: Two suggested follow-up questions.
#         """

#         random.shuffle(table_columns)

#         question_chain = self.generate_question_chain(user_prompt, generated_sql_query, table_to_use, table_columns)
#         suggestions = question_chain.invoke({})

#         return cast(FollowUpQuestions, suggestions)
    
# # Initialize the Question Suggestions Generator
# question_suggestions_generator = QuestionSuggestionsGenerator()

### New Approach -> FollowUpQuestions

In [46]:
import random 

# Define the FollowUpQuestions model
# Pydantic schema passed to `with_structured_output` by QuestionSuggestionsGenerator.
# NOTE(review): documented with `#` comments on purpose; a class docstring on a
# structured-output schema may be forwarded to the model as its description — confirm before adding one.
class FollowUpQuestions(BaseModel):
    # The suggested follow-up questions.
    questions: List[str]

    def to_list(self) -> List[str]:
        """Return the `questions` list (the same object, not a copy)."""
        return self.questions

class QuestionSuggestionsGenerator:
    """
    Suggests two follow-up questions for the user's latest query.

    A system prompt is assembled from the query, the executed SQL, the table
    metadata and the previously asked questions, and is run through an LLM
    bound to the ``FollowUpQuestions`` structured-output schema.
    """

    def __init__(self):
        self.llm = ChatOpenAI(model='gpt-4.1', temperature=0.5)
        self.structured_llm = self.llm.with_structured_output(FollowUpQuestions)

    @staticmethod
    def _escape_braces(text):
        """Double curly braces so the prompt template treats ``text`` as literal content."""
        return str(text).replace('{', '{{').replace('}', '}}')

    def generate_system_message(
        self,
        user_prompt: str,
        generated_sql_query: str,
        table_to_use: str,
        table_columns: list,
        past_questions: List[str]
    ) -> str:
        """
        Build the system prompt (as a template string) for follow-up question generation.

        All caller-supplied text is brace-escaped because the result is compiled
        with ``SystemMessagePromptTemplate.from_template``; an unescaped ``{`` in a
        user query or SQL statement would otherwise be read as a template variable.

        Args:
            user_prompt (str): The user's current query.
            generated_sql_query (str): SQL that was executed for that query.
            table_to_use (str): Name of the table that was queried.
            table_columns (list): Column names available in the table.
            past_questions (List[str]): Earlier user questions to avoid repeating;
                an empty list or None renders as "None".

        Returns:
            str: The stripped system prompt template.
        """
        esc = self._escape_braces
        formatted_date = datetime.now().strftime('%Y-%m-%d')

        if past_questions:
            past_questions_str = "\n".join(f"- {esc(q)}" for q in past_questions)
        else:
            past_questions_str = "None"

        safe_table = esc(table_to_use)
        safe_columns = esc(', '.join(table_columns))
        safe_sql = esc(generated_sql_query)
        safe_user_prompt = esc(user_prompt)

        system_message = f"""
You are an AI assistant tasked with generating **two concise, relevant, and professional follow-up questions** based on the user's current query and the SQL query executed to retrieve the results.

### Context
- **Table:** {safe_table}
- **Columns:** {safe_columns}
- **Executed SQL Query:** {safe_sql}

### Previously Asked Questions
{past_questions_str}

### Guidelines for Generating Follow-Up Questions

1. **Relevance:** Each question must be directly relevant to the user's current query and the SQL results.
2. **No Repetition:** Do NOT repeat or rephrase any questions already listed above.
3. **Clarity:** Use clear, direct, and simple language. Avoid overly technical or verbose phrasing.
4. **Professional Tone:** Use a formal, business-appropriate tone. Avoid casual or conversational style.
5. **Column Name Mapping:** Translate technical column names to user-friendly business terms where applicable:
   - `order_id` → `Order Number`
   - `cust_name` → `Customer Name`
   - `ord_date` → `Order Date`
6. **Answerability:** Ensure each question can be answered using simple aggregations (e.g., COUNT, SUM) or direct lookups from the data. No advanced calculations or transformations should be assumed.
7. **Single Focus:** Target one metric or time period per question for clarity and specificity.
8. **Use Only Available Columns:** Do not introduce new concepts or columns that are not in the provided list.
9. **No Hardcoded Entities:** Avoid using specific entity names (e.g., “orthopedic”, “SELLERS, MATTHEW B”, “Texas”) **unless they appear in the user's current or past questions.**
10. **Hierarchy-Driven Suggestions:** Follow the semantic drill-down hierarchy to guide deeper exploration of the dataset:
    → Market  
    → Facility  
    → Service Line / Business Line  
    → Procedure / Physician  
    → Contract Category / UNSPSC  
    → Products  

    - If the user asks about a higher-level dimension (e.g., Market), generate questions at the next lower level (e.g., Facility).
    - If the user focuses on Physicians, suggest questions about their Procedures or Products.
    - Do not suggest questions that go *up* the hierarchy.
11. **Avoid Redundancy:** Suggested questions must **not** be semantically similar, layered, or hierarchical variants of one another (e.g., “total encounter for top service line” vs. “top facility in top service line”).
12. **Avoid Transformations:** Do not suggest derived metrics or complex computed features.
13. **Leverage SQL Output:** Use the `generated_sql_query` to inform what metrics or dimensions have already been used, and build logically connected but distinct follow-ups.

14. Clarify Ambiguous Entities:
When a question includes a specific value (e.g., a name or location) without explicitly stating its entity type (such as market, facility, physician, etc.), prepend the appropriate entity label to ensure clarity and avoid ambiguity. This helps downstream systems understand and map the query correctly.
Example:
❌ “What is the total cost for Horizon West?”
✅ “What is the total cost for market Horizon West?”

15. **Current Date Reference:** Today's date is {formatted_date} – include it only if the question context requires a temporal reference.

### Input
**User Query:** {safe_user_prompt}

### Output Format
A list of exactly two distinct, valid follow-up questions:
["<question_1>", "<question_2>"]

### Important
- Only return the list. No explanations, commentary, or additional formatting.
"""
        return system_message.strip()

    def generate_question_chain(
        self,
        user_prompt: str,
        generated_sql_query: str,
        table_to_use: str,
        table_columns: list,
        past_questions: List[str]
    ):
        """
        Create the runnable chain (chat prompt piped into the structured LLM).

        Args:
            user_prompt (str): The user's current query.
            generated_sql_query (str): SQL that was executed for that query.
            table_to_use (str): Name of the table that was queried.
            table_columns (list): Column names available in the table.
            past_questions (List[str]): Earlier user questions to avoid repeating.

        Returns:
            A runnable; invoke it with ``{}`` to obtain the suggestions.
        """
        system_message = self.generate_system_message(
            user_prompt, generated_sql_query, table_to_use, table_columns, past_questions
        )

        system_message_prompt = SystemMessagePromptTemplate.from_template(system_message)
        human_message_prompt = HumanMessagePromptTemplate.from_template(
            "Generate two question suggestions based on this query."
        )

        chat_prompt_template = ChatPromptTemplate.from_messages([
            system_message_prompt,
            human_message_prompt
        ])

        return chat_prompt_template | self.structured_llm

    def get_suggestions(
        self,
        user_prompt: str,
        generated_sql_query: str,
        table_to_use: str,
        table_columns: list,
        past_questions: Optional[List[str]] = None
    ) -> "FollowUpQuestions":
        """
        Produce follow-up question suggestions for the user's current query.

        Args:
            user_prompt (str): The user's current query.
            generated_sql_query (str): SQL that was executed for that query.
            table_to_use (str): Name of the table that was queried.
            table_columns (list): Column names available in the table. The list
                is NOT modified; a shuffled copy is used for the prompt.
            past_questions (Optional[List[str]]): Earlier user questions to avoid
                repeating. Defaults to no history.

        Returns:
            FollowUpQuestions: The structured LLM output.
        """
        # Shuffle a copy so the caller's column list keeps its order.
        # NOTE(review): the shuffle presumably exists to vary which columns the model
        # focuses on between calls — confirm the intent.
        shuffled_columns = random.sample(list(table_columns), len(table_columns))

        question_chain = self.generate_question_chain(
            user_prompt, generated_sql_query, table_to_use, shuffled_columns, past_questions or []
        )
        suggestions = question_chain.invoke({})

        return cast("FollowUpQuestions", suggestions)

## Question Suggestion Generator
# Module-level instance; constructing it creates the ChatOpenAI client and binds the
# FollowUpQuestions structured-output schema.
question_suggestions_generator = QuestionSuggestionsGenerator()

In [47]:
# previous_questions = [
#     msg.content
#     for msg in memory_manager.store[user_id].messages
#     if isinstance(msg, HumanMessage)
# ][-8:]

# suggestions = question_suggestions_generator.get_suggestions(
#     user_prompt= user_prompt_to_use,
#     generated_sql_query=generated_sql_query,
#     table_to_use=table_to_use,
#     table_columns=table_columns,
#     past_questions=previous_questions # 👈 Prevents repeats
# )

# print(suggestions.to_list())

# <span style="color: yellow; font-weight: bold;">Charting_Json | Endpoint: /charting_json</span> (Not In Use)


In [48]:
# Palette of hex colour codes for charts.
# NOTE(review): presumably this is the `colors` argument given to the chart generator
# classes (e.g. ColumnChartGenerator(colors)) — confirm at the instantiation sites.
colors = [
        "#006292", "#1A729D", "#98D7FC", "#1AC5FB", "#1AABD9", "#1A73CB", "#4A90E5",
        "#6DEDD1", "#BEFBBF", "#83C38A", "#1A9D6D", "#34C670", "#4ED1B6", "#28B78F",
        "#FFE886", "#F8EEBE", "#D59D50", "#EC8047", "#C36C41", "#FFA56F", "#EA654D",
        "#C64949", "#00BFFA", "#00A2D5", "#0063C5", "#3684E2", "#5DEBCC"
]

### ChartRecommender

In [49]:
class ChartRecommender:
    """
    Decides which chart types suit a query result.

    Structural checks on the DataFrame first produce the list of technically
    possible charts; an LLM then picks the most appropriate ones from that list.
    """

    # Number of leading DataFrame rows shown to the LLM.
    PREVIEW_ROWS = 50

    def __init__(self):
        self.llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)

    @staticmethod
    def _escape_braces(text):
        """Double curly braces so ``PromptTemplate`` treats ``text`` as literal content."""
        return str(text).replace('{', '{{').replace('}', '}}')

    @staticmethod
    def _parse_chart_types(raw_response, suitable_charts):
        """
        Turn the LLM's textual answer into a clean list of chart type names.

        Only names present in ``suitable_charts`` are kept, so an empty answer
        (``[]``), code fences or stray words never leak into the result. Order of
        first mention is preserved and duplicates are dropped.
        """
        allowed = set(suitable_charts)
        mentioned = re.findall(r'[a-z]+', raw_response.lower())
        chart_types = [name for name in dict.fromkeys(mentioned) if name in allowed]

        # If we can create a line chart, then an area chart is also possible
        if 'line' in chart_types and 'area' not in chart_types:
            chart_types.append('area')

        return chart_types

    def get_column_types(self, df):
        """
        Counts the numeric and categorical columns of the DataFrame.

        dtypes are used first. When that finds no numeric column and every
        column is object/category/bool, each column is probed with a float
        conversion instead, so numbers stored as strings are recognised.

        Args:
            df (pd.DataFrame): DataFrame to analyze.

        Returns:
            dict: ``{"numeric_columns": int, "categorical_columns": int}``.
        """
        # Step 1: Primary method using dtypes ('number' covers every numeric width)
        numeric_columns = list(df.select_dtypes(include='number').columns)
        categorical_columns = list(df.select_dtypes(include=['object', 'category', 'bool']).columns)

        # Step 2: Fallback when everything was classified as categorical
        if not numeric_columns and len(categorical_columns) == len(df.columns):
            numeric_columns, categorical_columns = [], []
            for col in df.columns:
                try:
                    df[col].astype(float)
                except (ValueError, TypeError):
                    # Non-numeric text raises ValueError; objects such as dates
                    # or dicts raise TypeError. Both mean "not numeric".
                    categorical_columns.append(col)
                else:
                    numeric_columns.append(col)

        # Step 3: Return the counts
        return {
            "numeric_columns": len(numeric_columns),
            "categorical_columns": len(categorical_columns)
        }

    def get_suitable_charts_from_dataframe(self, df):
        """
        Performs basic data validation to determine if charts are technically possible,
        leaving final chart suitability decisions to the LLM.

        Args:
            df (pd.DataFrame): DataFrame to analyze.

        Returns:
            dict: ``suitable_charts`` (list of chart names) plus the
            ``numeric_columns`` and ``categorical_columns`` counts.
        """
        # Minimum requirements: at least 2 rows and 2 columns (this also covers empty frames)
        if len(df) < 2 or len(df.columns) < 2:
            return {
                'suitable_charts': [],
                'numeric_columns': 0,
                'categorical_columns': 0
            }

        column_types = self.get_column_types(df)
        numeric_count = column_types['numeric_columns']
        categorical_count = column_types['categorical_columns']
        suitable_charts = []

        # Bubble: needs at least 3 numeric columns for x, y, and size
        if numeric_count >= 3:
            suitable_charts.append('bubble')

        # Scatter: needs at least 2 numeric columns for x and y
        if numeric_count >= 2:
            suitable_charts.append('scatter')

        # Area/Line: need a numeric series (the 2-row minimum is guaranteed above)
        if numeric_count >= 1:
            suitable_charts.extend(['area', 'line'])

        # Pie, Column, Bar: need at least 1 numeric value and 1 categorical column
        if numeric_count >= 1 and categorical_count >= 1:
            suitable_charts.append('pie')
            suitable_charts.extend(['column', 'bar'])

        # Critical data quality check: remove all charts if any column is entirely NULL
        if df.isna().all().any():
            suitable_charts = []

        return {
            'suitable_charts': suitable_charts,
            'numeric_columns': numeric_count,
            'categorical_columns': categorical_count
        }

    def generate_chart_selection_guidelines(self, suitable_charts):
        """
        Generates dynamic chart selection guidelines based on the list of suitable charts.

        Args:
            suitable_charts (list): Chart type names to include, in display order.

        Returns:
            str: The numbered guideline sections joined by blank lines. Each section is
            numbered by its position in ``suitable_charts``; unknown names are skipped.
        """
        chart_guidelines = {
            'line': """**Line Charts**:
- **Use When**: Displaying trends or changes over time or continuous intervals.
- **Data Requirements**:
    - X-axis: Continuous or time-based variable with regular intervals.
    - Y-axis: Numeric variable (e.g., totals, averages).""",

            'area': """**Area Charts**:
- **When to Select**: 
    - To visualize cumulative totals over time.
    - To emphasize the magnitude of trends across multiple categories.
    - To compare relative contributions using percentage-stacked area charts.
- **Required Data Characteristics**:
    - X-axis: Time-based or sequential data.
    - Y-axis: Numeric values that can be aggregated or compared across categories.
- **Selection Guidelines**:
    - Prefer for datasets with multiple numeric categories that can be stacked or compared.
- **Avoid Selection If**:
    - The dataset has a single numeric series (use a line or bar chart instead).
    - Individual category values need precise interpretation (stacked charts may obscure this).
- **Special Considerations**:
    - Always include a zero baseline for accurate representation.
    - Arrange stacked groups with the largest or most stable values at the bottom for better readability.""",

            'bar': """**Bar Charts**:
- **Use When**: Comparing categories with long names or a large number of categories.
- **Data Requirements**:
    - X-axis: Numeric values.
    - Y-axis: Categorical variable (e.g., labels, distinct groups).""",

            'column': """**Column Charts**:
- **Use When**: Comparing discrete categories with numerical values, especially for ordered or sequential data.
- **Data Requirements**:
    - X-axis: Categorical variable or sequentially ordered categories (e.g., months, years).
    - Y-axis: Numeric values.""",

            'pie': """**Pie Charts**:
- **Use When**: Showing part-to-whole relationships or percentages.
- **Data Requirements**:
    - Categories that sum to a whole (e.g., 100%) with corresponding numeric values.""",

            'scatter': """**Scatter Plots**:
- **Use When**: Exploring relationships or correlations between two numeric variables.
- **Data Requirements**:
    - X-axis: Numeric variable.
    - Y-axis: Numeric variable.""",

            'bubble': """**Bubble Charts**:
- **Use When**: Visualizing relationships among three numeric variables, optionally grouped by a categorical variable, to compare magnitudes or identify correlations.
- **Data Requirements**:
    - **X-axis**: Numeric variable (independent or comparison metric).
    - **Y-axis**: Numeric variable (response or comparison metric).
    - **Bubble Size**: Third numeric variable (magnitude or intensity).
    - **Optional**: Categorical variable for grouping (e.g., Markets, Facilities).
- **Considerations**:
    - Prioritize for queries with three numeric variables showing relationships or magnitudes.
    - Ensure variation in bubble sizes and validate all variables are non-null.
    - Avoid if data has excessive categories or fewer than three variables.
    - Handle overlaps with transparency or tooltips if required."""
        }

        guidelines = [
            f"{idx}. {chart_guidelines[chart]}"
            for idx, chart in enumerate(suitable_charts, start=1)
            if chart in chart_guidelines
        ]

        return "\n\n".join(guidelines)

    def generate_chart_selection_prompt(self, user_query, sql_query, df):
        """
        Build the chart-selection prompt (as a template string) for the LLM.

        Literal braces in the user query, SQL and DataFrame JSON are doubled because
        the result is compiled with ``PromptTemplate.from_template``.

        Args:
            user_query (str): The user's question.
            sql_query (str): SQL that produced ``df``.
            df (pd.DataFrame): The query result.

        Returns:
            tuple[str, dict]: The stripped prompt and the dictionary returned by
            ``get_suitable_charts_from_dataframe``.
        """
        column_types = self.get_suitable_charts_from_dataframe(df)
        suitable_charts = column_types['suitable_charts']

        chart_selection_guideline_to_be_used = self.generate_chart_selection_guidelines(suitable_charts)

        safe_user_query = self._escape_braces(user_query)
        safe_sql_query = self._escape_braces(sql_query)
        dataframe_json = self._escape_braces(df.head(self.PREVIEW_ROWS).to_json())

        system_message = f'''
You are a data visualization expert tasked with selecting the most appropriate chart types for visualizing data. Based on the user's query, SQL query, and data structure, recommend one or more of the following chart types: **column**, **bar**, **line**, **pie**, **bubble**, **scatter**. If none are suitable, return an empty list.

**Chart Selection Guidelines**:
{chart_selection_guideline_to_be_used}

**Context Provided**:
===============================================================================
User Query: {safe_user_query}
SQL Query: {safe_sql_query}
DataFrame Preview:
{dataframe_json}
===============================================================================

**Additional Information**:
- Only return the appropriate chart or charts from this list: {suitable_charts}

**Instructions**:
1. **Understand the Context**:
   - Analyze the user's query for intent.
   - Examine the data structure from the DataFrame preview.
   - Consider SQL query operations (e.g., aggregations, filters).
   - Take into account the preliminary chart suggestions based on the column types.

2. **Select Suitable Chart Types**:
   - Refer to the **Chart Selection Guidelines** to choose the most appropriate chart type(s).
   - Match the data structure and user intent with the chart requirements.

3. **Output Format**:
   - Return a Python list of chart type names in lowercase (e.g., `['column', 'line']`).
   - If no chart type is suitable, return an empty list `[]`.
   - Do not include explanations, comments, or any additional text.

**Example Output**:
['chart_type_1', 'chart_type_2']
'''
        return system_message.strip(), column_types

    def get_suitable_charts(self, user_query, sql_query, df):
        """
        Ask the LLM which of the technically possible charts fit the query.

        Args:
            user_query (str): The user's question.
            sql_query (str): SQL that produced ``df``.
            df (pd.DataFrame): The query result.

        Returns:
            tuple[list, dict]: The selected chart type names (empty when nothing is
            suitable) and the suitability dictionary for ``df``.
        """
        system_message, column_types = self.generate_chart_selection_prompt(user_query, sql_query, df)
        suitable_charts = column_types['suitable_charts']

        # Skip the LLM call when no chart is technically possible.
        if not suitable_charts:
            return [], column_types

        chart_selection_chain = PromptTemplate.from_template(system_message) | self.llm | StrOutputParser()
        result = chart_selection_chain.invoke({})

        return self._parse_chart_types(result, suitable_charts), column_types

## Initializing ChartRecommender Module
# Module-level instance; constructing it creates the ChatOpenAI client used for chart selection.
chart_recommender = ChartRecommender()

In [50]:
# print(f'user_prompt: {user_prompt}\n')

# # chart_prompt = chart_selection.generate_chart_selection_prompt(user_prompt, generated_sql_query, df)
# # print(chart_prompt)

# list_of_charts, suitable_chart_types = chart_recommender.get_suitable_charts(user_prompt, generated_sql_query, df)
# print(json.dumps(suitable_chart_types, indent = 2))
# print(f'\nlist_of_charts: {list_of_charts}\n')

###  __ Charts Modules __

#### 1. ColumnChartGenerator

In [51]:
class ColumnChartGenerator:
    """
    Generates Highcharts column-chart configurations by prompting an LLM.

    The prompt embeds a preview of the DataFrame, the user's query and the approved
    color codes; the LLM reply is then parsed and validated into a dictionary.
    """

    def __init__(self, colors):
        """
        Args:
            colors: Sequence of color-code strings the LLM is told to restrict itself to.
        """
        self.llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)
        self.colors = colors

    @staticmethod
    def _escape_braces(text: Any) -> str:
        """Double every brace so PromptTemplate renders it literally instead of as a variable."""
        return str(text).replace('{', '{{').replace('}', '}}')

    @staticmethod
    def _extract_json_object(text: str) -> str:
        """Drop markdown fences and any prose surrounding the outermost ``{...}`` object."""
        text = text.replace('```json', '').replace('```', '').strip()
        start, end = text.find('{'), text.rfind('}')
        if start != -1 and end > start:
            return text[start:end + 1]
        return text

    def _generate_column_chart_prompt(self, query: str, df: pd.DataFrame) -> str:
        """
        Generates the prompt for the LLM to analyze the data and suggest column chart configuration.

        The returned string is meant to be passed to ``PromptTemplate.from_template``,
        so every literal brace is doubled: the static JSON template uses quadruple
        braces in this f-string, and the dynamic parts (DataFrame preview and user
        query) are escaped at runtime.

        Args:
            query (str): User's query
            df (pd.DataFrame): Input DataFrame (only the first 50 rows are embedded)

        Returns:
            str: Formatted prompt for the LLM
        """
        system_message = f'''
You are an expert in data visualization specializing in creating Highcharts column chart configurations. Analyze the provided DataFrame and user query to generate the optimal column chart JSON configuration.

DataFrame Preview:
{self._escape_braces(df.head(50).to_json())}

User Query: {self._escape_braces(query)}

Available Color Codes (use only these):
{", ".join(self.colors)}

Follow this exact JSON template for column charts:
{{{{
  "chart": {{{{
    "type": "column",
    "plotBorderWidth": 1.5
  }}}},
  "accessibility": {{{{
    "description": "<accessibility_description>"
  }}}},
  "title": {{{{
    "text": "<chart_title>",
    "align": "<align>"
  }}}},
  "subtitle": {{{{
    "text": "<chart_subtitle>",
    "align": "<align>"
  }}}},
  "xAxis": {{{{
    "categories": ["Category 1", "Category 2", "Category 3", "Category 4"],
    "title": {{{{ "text": "<x_axis_title>" }}}},
    "crosshair": true,
    "gridLineWidth": 1,
    "lineWidth": 0,
    "allowDecimals": false,
    "accessibility": {{{{
      "rangeDescription": "<range_description>",
      "description": "<x_axis_accessibility_description>"
    }}}}
  }}}},
  "yAxis": {{{{
    "title": {{{{
      "text": "<y_axis_title>",
      "align": "<align>"
    }}}}
  }}}},
  "series": [
    {{{{
      "name": "<series_name>",
      "data": [<data_points>],
      "color": "<hex_color>",
      "tooltip": {{{{
        "pointFormat": "<point_format>"
      }}}},
      "dataLabels": {{{{
        "enabled": true,
        "format": "<format>"
      }}}}
    }}}}
  ],
  "plotOptions": {{{{
    "column": {{{{
      "pointPadding": 0.2,
      "borderWidth": 0,
      "groupPadding": 0.1,
      "dataLabels": {{{{
        "enabled": true,
        "format": "{{{{point.y:,.2f}}}}"
      }}}}
    }}}}
  }}}},
  "legend": {{{{
    "layout": "vertical",
    "align": "right"
  }}}},
  "responsive": {{{{
    "rules": [
      {{{{
        "condition": {{{{ "maxWidth": 500 }}}},
        "chartOptions": {{{{
          "legend": {{{{
            "layout": "horizontal",
            "align": "<align>",
            "verticalAlign": "bottom"
          }}}}
        }}}}
      }}}}
    ]
  }}}}
}}}}

Instructions:
1. Analyze the DataFrame structure and user query to determine:
   - Appropriate columns for x-axis (categories) and y-axis (values)
   - Meaningful title and labels
   - Proper series configuration

2. Number Formatting Requirements:
   - For currency values (Spend/Price/Cost/Charge):
     * Use $ prefix (valuePrefix: '$')
     * Format example: $36,614,583.26
   - For non-currency values (like Encounter, Score):
     * No currency symbol
     * Format example: 36,614,583.26
   - Always:
     * Round to 2 decimal places
     * Use commas for thousand separators
     * Use Highcharts.numberFormat
   - For percentages:
     * Format to 1 decimal place
     * Include % symbol

3. Date and Text Requirements:
   - Use full month names (e.g., "January" not "Jan")
   - Use proper capitalizations in labels and titles

4. The JSON must follow the exact template structure above
5. Only use colors from the provided color list
6. Return only the JSON, no explanations

Generate the complete column chart JSON configuration:'''

        return system_message.strip()

    def _validate_and_clean_json(self, json_str: str) -> Dict[str, Any]:
        """
        Validates and cleans the JSON string returned by the LLM.

        Markdown fences and any text around the outermost JSON object are removed
        before parsing, and ``chart.type`` is forced to ``'column'``.

        Args:
            json_str (str): JSON string from LLM

        Returns:
            Dict[str, Any]: Cleaned and validated JSON dictionary

        Raises:
            ValueError: If the reply is not parseable JSON, is not a JSON object,
                or lacks one of the required top-level keys.
        """
        try:
            chart_config = json.loads(self._extract_json_object(json_str))
            if not isinstance(chart_config, dict):
                raise ValueError("Chart configuration must be a JSON object")

            # Basic validation
            required_keys = ['chart', 'title', 'xAxis', 'yAxis', 'series', 'plotOptions']
            for key in required_keys:
                if key not in chart_config:
                    raise ValueError(f"Missing required key: {key}")

            # Ensure 'type' is 'column'
            if chart_config['chart'].get('type') != 'column':
                chart_config['chart']['type'] = 'column'

            return chart_config

        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {str(e)}") from e
        except Exception as e:
            raise ValueError(f"Error processing chart configuration: {str(e)}") from e

    def generate_chart_json(
        self,
        query: str,
        df: pd.DataFrame
    ) -> Dict[str, Any]:
        """
        Generates a column chart JSON configuration using LLM and data analysis.

        Args:
            query (str): User's query
            df (pd.DataFrame): Input DataFrame

        Returns:
            Dict[str, Any]: Complete column chart JSON configuration

        Raises:
            ValueError: Propagated from ``_validate_and_clean_json`` when the LLM
                reply cannot be turned into a valid configuration.
        """
        # Generate prompt for LLM
        prompt = self._generate_column_chart_prompt(query, df)

        # The prompt has no template variables; PromptTemplate only un-doubles the braces.
        prompt_template = PromptTemplate.from_template(prompt)
        chart_generation_chain = prompt_template | self.llm | StrOutputParser()

        # Get JSON from LLM
        json_response = chart_generation_chain.invoke({})

        # Validate and clean the JSON
        return self._validate_and_clean_json(json_response)
    
## Initialization - Column Chart
# `colors` is not defined in this cell; it must already exist in the notebook namespace.
column_chart_generator = ColumnChartGenerator(colors=colors)

#### 2. BarChartGenerator

In [52]:
class BarChartGenerator:
    """
    Generates Highcharts bar-chart configurations by prompting an LLM.

    The prompt embeds a preview of the DataFrame, the user's query and the approved
    color codes; the LLM reply is then parsed and validated into a dictionary.
    """

    def __init__(self, colors):
        """
        Args:
            colors: Sequence of color-code strings the LLM is told to restrict itself to.
        """
        self.llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)
        self.colors = colors

    @staticmethod
    def _escape_braces(text: Any) -> str:
        """Double every brace so PromptTemplate renders it literally instead of as a variable."""
        return str(text).replace('{', '{{').replace('}', '}}')

    @staticmethod
    def _extract_json_object(text: str) -> str:
        """Drop markdown fences and any prose surrounding the outermost ``{...}`` object."""
        text = text.replace('```json', '').replace('```', '').strip()
        start, end = text.find('{'), text.rfind('}')
        if start != -1 and end > start:
            return text[start:end + 1]
        return text

    def _generate_bar_chart_prompt(self, query: str, df: pd.DataFrame) -> str:
        """
        Generates the prompt for the LLM to analyze the data and suggest bar chart configuration.

        The returned string is meant to be passed to ``PromptTemplate.from_template``,
        so every literal brace is doubled: the static JSON template uses quadruple
        braces in this f-string, and the dynamic parts (DataFrame preview and user
        query) are escaped at runtime.

        Args:
            query (str): User's query
            df (pd.DataFrame): Input DataFrame (only the first 50 rows are embedded)

        Returns:
            str: Formatted prompt for the LLM
        """
        system_message = f'''
You are an expert in data visualization specializing in creating Highcharts bar chart configurations. Analyze the provided DataFrame and user query to generate the optimal bar chart JSON configuration.

DataFrame Preview:
{self._escape_braces(df.head(50).to_json())}

User Query: {self._escape_braces(query)}

Available Color Codes (use only these):
{", ".join(self.colors)}

Follow this exact JSON template for bar charts:
{{{{
  "chart": {{{{
    "type": "bar",
    "plotBorderWidth": 1
  }}}},
  "accessibility": {{{{
    "description": "<accessibility_description>"
  }}}},
  "title": {{{{
    "text": "<chart_title>",
    "align": "left",
    "style": {{{{
      "fontSize": "16px"
    }}}}
  }}}},
  "subtitle": {{{{
    "text": "<chart_subtitle>",
    "align": "left"
  }}}},
  "xAxis": {{{{
    "categories": ["Category 1", "Category 2", "Category 3", "Category 4"],
    "title": {{{{ "text": "<x_axis_title>" }}}},
    "crosshair": true,
    "gridLineWidth": 1,
    "lineWidth": 0,
    "allowDecimals": false,
    "accessibility": {{{{
      "rangeDescription": "<range_description>",
      "description": "<x_axis_accessibility_description>"
    }}}}
  }}}},
  "yAxis": {{{{
    "title": {{{{
      "text": "<y_axis_title>",
      "align": "high"
    }}}},
    "min": 0,
    "max": null,
    "labels": {{{{
      "overflow": "justify"
    }}}},
    "gridLineWidth": 0
  }}}},
  "series": [
    {{{{
      "name": "<series_name>",
      "data": [<data_points>],
      "color": "<hex_color>",
      "tooltip": {{{{
        "pointFormat": "<point_format>"
      }}}},
      "dataLabels": {{{{
        "enabled": true,
        "format": "<format>"
      }}}}
    }}}}
  ],
  "plotOptions": {{{{
    "bar": {{{{
      "borderRadius": "50%",
      "groupPadding": 0.1,
      "dataLabels": {{{{
        "enabled": true,
        "format": "{{{{point.y:,.2f}}}}"
      }}}},
      "colorByPoint": false
    }}}}
  }}}},
  "legend": {{{{
    "layout": "vertical",
    "borderWidth": 1,
    "backgroundColor": "#FFFFFF",
    "shadow": true
  }}}},
  "responsive": {{{{
    "rules": [
      {{{{
        "condition": {{{{ "maxWidth": 500 }}}},
        "chartOptions": {{{{
          "legend": {{{{
            "layout": "horizontal",
            "align": "center",
            "verticalAlign": "bottom"
          }}}}
        }}}}
      }}}}
    ]
  }}}}
}}}}

Instructions:
1. Analyze the DataFrame structure and user query to determine:
   - Appropriate columns for x-axis (categories) and y-axis (values)
   - Meaningful title and labels
   - Proper series configuration

2. Number Formatting Requirements:
   - For currency values (Spend/Price/Cost/Charge):
     * Use $ prefix (valuePrefix: '$')
     * Format example: $36,614,583.26
   - For non-currency values (like Encounter, Score):
     * No currency symbol
     * Format example: 36,614,583.26
   - Always:
     * Round to 2 decimal places
     * Use commas for thousand separators
     * Use Highcharts.numberFormat
   - For percentages:
     * Format to 1 decimal place
     * Include % symbol

3. Date and Text Requirements:
   - Use full month names (e.g., "January" not "Jan")
   - Use proper capitalizations in labels and titles

4. The JSON must follow the exact template structure above
5. Only use colors from the provided color list
6. Return only the JSON, no explanations

Generate the complete bar chart JSON configuration:'''

        return system_message.strip()

    def _validate_and_clean_json(self, json_str: str) -> Dict[str, Any]:
        """
        Validates and cleans the JSON string returned by the LLM.

        Markdown fences and any text around the outermost JSON object are removed
        before parsing, and ``chart.type`` is forced to ``'bar'``.

        Args:
            json_str (str): JSON string from LLM

        Returns:
            Dict[str, Any]: Cleaned and validated JSON dictionary

        Raises:
            ValueError: If the reply is not parseable JSON, is not a JSON object,
                or lacks one of the required top-level keys.
        """
        try:
            chart_config = json.loads(self._extract_json_object(json_str))
            if not isinstance(chart_config, dict):
                raise ValueError("Chart configuration must be a JSON object")

            # Basic validation
            required_keys = ['chart', 'title', 'xAxis', 'yAxis', 'series', 'plotOptions']
            for key in required_keys:
                if key not in chart_config:
                    raise ValueError(f"Missing required key: {key}")

            # Ensure 'type' is 'bar'
            if chart_config['chart'].get('type') != 'bar':
                chart_config['chart']['type'] = 'bar'

            return chart_config

        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {str(e)}") from e
        except Exception as e:
            raise ValueError(f"Error processing chart configuration: {str(e)}") from e

    def generate_chart_json(
        self,
        query: str,
        df: pd.DataFrame
    ) -> Dict[str, Any]:
        """
        Generates a bar chart JSON configuration using LLM and data analysis.

        Args:
            query (str): User's query
            df (pd.DataFrame): Input DataFrame

        Returns:
            Dict[str, Any]: Complete bar chart JSON configuration

        Raises:
            ValueError: Propagated from ``_validate_and_clean_json`` when the LLM
                reply cannot be turned into a valid configuration.
        """
        # Generate prompt for LLM
        prompt = self._generate_bar_chart_prompt(query, df)

        # The prompt has no template variables; PromptTemplate only un-doubles the braces.
        prompt_template = PromptTemplate.from_template(prompt)
        chart_generation_chain = prompt_template | self.llm | StrOutputParser()

        # Get JSON from LLM
        json_response = chart_generation_chain.invoke({})

        # Validate and clean the JSON
        return self._validate_and_clean_json(json_response)


# Initialize the bar chart generator.
# `colors` is not defined in this cell; it must already exist in the notebook namespace.
bar_generator = BarChartGenerator(colors=colors)

#### 3. LineChartGenerator

In [53]:
class LineChartGenerator:
    """
    Generates Highcharts line-chart configurations by prompting an LLM.

    The prompt embeds a preview of the DataFrame, the user's query and the approved
    color codes; the LLM reply is then parsed and validated into a dictionary.
    """

    def __init__(self, colors):
        """
        Args:
            colors: Sequence of color-code strings the LLM is told to restrict itself to.
        """
        self.llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)
        self.colors = colors

    @staticmethod
    def _escape_braces(text: Any) -> str:
        """Double every brace so PromptTemplate renders it literally instead of as a variable."""
        return str(text).replace('{', '{{').replace('}', '}}')

    @staticmethod
    def _extract_json_object(text: str) -> str:
        """Drop markdown fences and any prose surrounding the outermost ``{...}`` object."""
        text = text.replace('```json', '').replace('```', '').strip()
        start, end = text.find('{'), text.rfind('}')
        if start != -1 and end > start:
            return text[start:end + 1]
        return text

    def _generate_line_chart_prompt(self, query: str, df: pd.DataFrame) -> str:
        """
        Generates the prompt for the LLM to analyze the data and suggest line chart configuration.

        The returned string is meant to be passed to ``PromptTemplate.from_template``,
        so every literal brace is doubled: the static JSON template uses quadruple
        braces in this f-string, and the dynamic parts (DataFrame preview and user
        query) are escaped at runtime.

        Args:
            query (str): User's query
            df (pd.DataFrame): Input DataFrame (only the first 50 rows are embedded)

        Returns:
            str: Formatted prompt for the LLM
        """
        system_message = f'''
You are an expert in data visualization specializing in creating Highcharts line chart configurations. Analyze the provided DataFrame and user query to generate the optimal line chart JSON configuration.

DataFrame Preview:
{self._escape_braces(df.head(50).to_json())}

User Query: {self._escape_braces(query)}

Available Color Codes (use only these):
{", ".join(self.colors)}

Follow this exact JSON template for line charts:
{{{{
  "chart": {{{{
    "type": "line",
    "plotBorderWidth": 1.5
  }}}},
  "accessibility": {{{{
    "description": "<accessibility_description>"
  }}}},
  "title": {{{{
    "text": "<chart_title>",
    "align": "left"
  }}}},
  "subtitle": {{{{
    "text": "<chart_subtitle>",
    "align": "left"
  }}}},
  "xAxis": {{{{
    "categories": ["Category 1", "Category 2", "Category 3", "Category 4"],
    "title": {{{{ "text": "<x_axis_title>" }}}},
    "crosshair": false,
    "gridLineWidth": 1,
    "lineWidth": 1,
    "allowDecimals": true,
    "accessibility": {{{{
      "rangeDescription": "<range_description>",
      "description": "<x_axis_accessibility_description>"
    }}}}
  }}}},
  "yAxis": {{{{
    "title": {{{{
      "text": "<y_axis_title>",
      "align": "high"
    }}}},
    "min": 0,
    "max": null,
    "labels": {{{{
      "overflow": "justify"
    }}}},
    "gridLineWidth": 1
  }}}},
  "series": [
    {{{{
      "name": "<series_name>",
      "data": [<data_points>],
      "color": "<hex_color>",
      "tooltip": {{{{
        "pointFormat": "<point_format>"
      }}}},
      "dataLabels": {{{{
        "enabled": true,
        "format": "<format>"
      }}}}
    }}}}
  ],
  "plotOptions": {{{{
    "series": {{{{
      "label": {{{{ "connectorAllowed": false }}}},
      "pointStart": null,
      "dataLabels": {{{{
        "enabled": true,
        "format": "{{{{point.y:,.2f}}}}"
      }}}}
    }}}}
  }}}},
  "legend": {{{{
    "layout": "vertical",
    "align": "right",
    "verticalAlign": "middle",
    "floating": false,
    "borderWidth": 1,
    "backgroundColor": "<hex_color>",
    "shadow": true
  }}}},
  "responsive": {{{{
    "rules": [
      {{{{
        "condition": {{{{ "maxWidth": 500 }}}},
        "chartOptions": {{{{
          "legend": {{{{
            "layout": "horizontal",
            "align": "center",
            "verticalAlign": "bottom"
          }}}}
        }}}}
      }}}}
    ]
  }}}}
}}}}


Instructions:
1. Analyze the DataFrame structure and user query to determine:
   - Appropriate columns for x-axis (categories) and y-axis (values)
   - Meaningful title and labels
   - Proper series configuration

2. Number Formatting Requirements:
   - For currency values (Spend/Price/Cost/Charge):
     * Use $ prefix (valuePrefix: '$')
     * Format example: $36,614,583.26
   - For non-currency values (like Encounter, Score):
     * No currency symbol
     * Format example: 36,614,583.26
   - Always:
     * Round to 2 decimal places
     * Use commas for thousand separators
     * Use Highcharts.numberFormat
   - For percentages:
     * Format to 1 decimal place
     * Include % symbol

3. Date and Text Requirements:
   - Use full month names (e.g., "January" not "Jan")
   - Use proper capitalizations in labels and titles

4. The JSON must follow the exact template structure above
5. Only use colors from the provided color list
6. Return only the JSON, no explanations

Generate the complete line chart JSON configuration:'''

        return system_message.strip()

    def _validate_and_clean_json(self, json_str: str) -> Dict[str, Any]:
        """
        Validates and cleans the JSON string returned by the LLM.

        Markdown fences and any text around the outermost JSON object are removed
        before parsing, and ``chart.type`` is forced to ``'line'``.

        Args:
            json_str (str): JSON string from LLM

        Returns:
            Dict[str, Any]: Cleaned and validated JSON dictionary

        Raises:
            ValueError: If the reply is not parseable JSON, is not a JSON object,
                or lacks one of the required top-level keys.
        """
        try:
            chart_config = json.loads(self._extract_json_object(json_str))
            if not isinstance(chart_config, dict):
                raise ValueError("Chart configuration must be a JSON object")

            # Basic validation
            required_keys = ['chart', 'title', 'xAxis', 'yAxis', 'series', 'plotOptions']
            for key in required_keys:
                if key not in chart_config:
                    raise ValueError(f"Missing required key: {key}")

            # Ensure 'type' is 'line'
            if chart_config['chart'].get('type') != 'line':
                chart_config['chart']['type'] = 'line'

            return chart_config

        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {str(e)}") from e
        except Exception as e:
            raise ValueError(f"Error processing chart configuration: {str(e)}") from e

    def generate_chart_json(
        self,
        query: str,
        df: pd.DataFrame
    ) -> Dict[str, Any]:
        """
        Generates a line chart JSON configuration using LLM and data analysis.

        Args:
            query (str): User's query
            df (pd.DataFrame): Input DataFrame

        Returns:
            Dict[str, Any]: Complete line chart JSON configuration

        Raises:
            ValueError: Propagated from ``_validate_and_clean_json`` when the LLM
                reply cannot be turned into a valid configuration.
        """
        # Generate prompt for LLM
        prompt = self._generate_line_chart_prompt(query, df)

        # The prompt has no template variables; PromptTemplate only un-doubles the braces.
        prompt_template = PromptTemplate.from_template(prompt)
        chart_generation_chain = prompt_template | self.llm | StrOutputParser()

        # Get JSON from LLM
        json_response = chart_generation_chain.invoke({})

        # Validate and clean the JSON
        return self._validate_and_clean_json(json_response)

# Initialize the line chart generator.
# `colors` is not defined in this cell; it must already exist in the notebook namespace.
line_generator = LineChartGenerator(colors = colors)

#### 4. PieChartGenerator

In [54]:
class PieChartGenerator:
    """
    Generates Highcharts pie-chart configurations by prompting an LLM.

    The prompt embeds a preview of the DataFrame, the user's query and the approved
    color codes; the LLM reply is then parsed and validated into a dictionary.
    """

    def __init__(self, colors):
        """
        Args:
            colors: Sequence of color-code strings the LLM is told to restrict itself to.
        """
        self.llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)
        self.colors = colors

    @staticmethod
    def _escape_braces(text: Any) -> str:
        """Double every brace so PromptTemplate renders it literally instead of as a variable."""
        return str(text).replace('{', '{{').replace('}', '}}')

    @staticmethod
    def _extract_json_object(text: str) -> str:
        """Drop markdown fences and any prose surrounding the outermost ``{...}`` object."""
        text = text.replace('```json', '').replace('```', '').strip()
        start, end = text.find('{'), text.rfind('}')
        if start != -1 and end > start:
            return text[start:end + 1]
        return text

    def _generate_pie_chart_prompt(self, query: str, df: pd.DataFrame) -> str:
        """
        Generates the prompt for the LLM to analyze the data and suggest pie chart configuration.

        The returned string is meant to be passed to ``PromptTemplate.from_template``,
        so every literal brace is doubled: the static JSON template uses quadruple
        braces in this f-string, and the dynamic parts (DataFrame preview and user
        query) are escaped at runtime.

        Args:
            query (str): User's query
            df (pd.DataFrame): Input DataFrame (only the first 50 rows are embedded)

        Returns:
            str: Formatted prompt for the LLM
        """
        system_message = f'''
You are an expert in data visualization specializing in creating Highcharts pie chart configurations. Analyze the provided DataFrame and user query to generate the optimal pie chart JSON configuration.

DataFrame Preview:
{self._escape_braces(df.head(50).to_json())}

User Query: {self._escape_braces(query)}

Available Color Codes (use only these):
{", ".join(self.colors)}

Follow this exact JSON template for pie charts:
{{{{
  "chart": {{{{
    "type": "pie",
    "plotBorderWidth": 1
  }}}},
  "accessibility": {{{{
    "description": "<accessibility_description>",
    "announceNewData": {{{{
      "enabled": true
    }}}}
  }}}},
  "title": {{{{
    "text": "<chart_title>",
    "align": "left",
    "style": {{{{
      "fontSize": "<font_size>"
    }}}}
  }}}},
  "subtitle": {{{{
    "text": "<chart_subtitle>",
    "align": "left"
  }}}},
  "plotOptions": {{{{
    "series": {{{{
      "allowPointSelect": true,
      "cursor": "pointer",
      "tooltip": {{{{
        "pointFormat": "<point_format>"
      }}}},
      "dataLabels": [
        {{{{
          "enabled": true,
          "distance": <+distance>,
          "format": "<format>"
        }}}}
      ]
    }}}}
  }}}},
  "series": [
    {{{{
      "name": "<series_name>",
      "colorByPoint": true,
      "data": [<data_points>]
    }}}}
  ],
  "legend": {{{{
    "enabled": true,
    "layout": "vertical",
    "align": "right",
    "verticalAlign": "middle",
    "itemMarginTop": 10,
    "itemMarginBottom": 10,
    "labelFormat": "{{{{name}}}} ({{{{percentage:.1f}}}}%)"
  }}}},
  "responsive": {{{{
    "rules": [
      {{{{
        "condition": {{{{ "maxWidth": 500 }}}},
        "chartOptions": {{{{
          "legend": {{{{
            "layout": "horizontal",
            "align": "center",
            "verticalAlign": "bottom"
          }}}}
        }}}}
      }}}}
    ]
  }}}}
}}}}

Instructions:
1. Analyze the DataFrame structure and user query to determine:
   - Appropriate categories and values for pie segments
   - Meaningful title and labels
   - Proper series configuration

2. Number Formatting Requirements:
   - For currency values (Spend/Price/Cost/Charge):
     * Use $ prefix (valuePrefix: '$')
     * Format example: $36,614,583.26
   - For non-currency values (like Encounter, Score):
     * No currency symbol
     * Format example: 36,614,583.26
   - Always:
     * Round to 2 decimal places
     * Use commas for thousand separators
     * Use Highcharts.numberFormat
   - For percentages:
     * Format to 1 decimal place
     * Include % symbol

3. Date and Text Requirements:
   - Use full month names (e.g., "January" not "Jan")
   - Use proper capitalizations in labels and titles
   - Ensure segment labels are clear and readable

4. Pie Chart Specific Requirements:
   - Set largest segment as sliced out by default
   - Enable point selection
   - Show percentage in data labels for segments > 10%
   - Include percentages in legend labels

5. The JSON must follow the exact template structure above
6. Only use colors from the provided color list
7. Return only the JSON, no explanations

Generate the complete pie chart JSON configuration:'''

        return system_message.strip()

    def _validate_and_clean_json(self, json_str: str) -> Dict[str, Any]:
        """
        Validates and cleans the JSON string returned by the LLM.

        Markdown fences and any text around the outermost JSON object are removed
        before parsing. ``chart.type`` is forced to ``'pie'``, every series gets
        ``colorByPoint = True`` and ``plotOptions.series`` is made selectable.

        Args:
            json_str (str): JSON string from LLM

        Returns:
            Dict[str, Any]: Cleaned and validated JSON dictionary

        Raises:
            ValueError: If the reply is not parseable JSON, is not a JSON object,
                or lacks one of the required top-level keys.
        """
        try:
            chart_config = json.loads(self._extract_json_object(json_str))
            if not isinstance(chart_config, dict):
                raise ValueError("Chart configuration must be a JSON object")

            # Basic validation (pie charts have no axes, so none are required)
            required_keys = ['chart', 'title', 'series', 'plotOptions']
            for key in required_keys:
                if key not in chart_config:
                    raise ValueError(f"Missing required key: {key}")

            # Ensure 'type' is 'pie'
            if chart_config['chart'].get('type') != 'pie':
                chart_config['chart']['type'] = 'pie'

            # Ensure pie-specific settings
            for series in chart_config['series'] or []:
                series['colorByPoint'] = True

            series_options = chart_config['plotOptions'].setdefault('series', {})
            series_options['allowPointSelect'] = True
            series_options['cursor'] = 'pointer'

            return chart_config

        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {str(e)}") from e
        except Exception as e:
            raise ValueError(f"Error processing chart configuration: {str(e)}") from e

    def generate_chart_json(
        self,
        query: str,
        df: pd.DataFrame
    ) -> Dict[str, Any]:
        """
        Generates a pie chart JSON configuration using LLM and data analysis.

        Args:
            query (str): User's query
            df (pd.DataFrame): Input DataFrame

        Returns:
            Dict[str, Any]: Complete pie chart JSON configuration

        Raises:
            ValueError: Propagated from ``_validate_and_clean_json`` when the LLM
                reply cannot be turned into a valid configuration.
        """
        # Generate prompt for LLM
        prompt = self._generate_pie_chart_prompt(query, df)

        # The prompt has no template variables; PromptTemplate only un-doubles the braces.
        prompt_template = PromptTemplate.from_template(prompt)
        chart_generation_chain = prompt_template | self.llm | StrOutputParser()

        # Get JSON from LLM
        json_response = chart_generation_chain.invoke({})

        # Validate and clean the JSON
        return self._validate_and_clean_json(json_response)

# Initialization - Pie Chart
# `colors` is not defined in this cell; it must already exist in the notebook namespace.
pie_generator = PieChartGenerator(colors=colors)

#### 5. BubbleChartGenerator

In [55]:
class BubbleChartGenerator:
    """
    Generates Highcharts bubble-chart configurations by prompting an LLM.

    The prompt embeds a preview of the DataFrame, the user's query and the approved
    color codes; the LLM reply is then parsed, repaired and returned as a dictionary.
    """

    def __init__(self, colors):
        """
        Initializes the BubbleChartGenerator with a list of approved colors.

        Args:
            colors (list): List of approved color codes for the chart
        """
        self.llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)
        self.colors = colors

    @staticmethod
    def _escape_braces(text: Any) -> str:
        """Double every brace so PromptTemplate renders it literally instead of as a variable."""
        return str(text).replace('{', '{{').replace('}', '}}')

    @staticmethod
    def _extract_json_object(text: str) -> str:
        """Drop markdown fences and any prose surrounding the outermost ``{...}`` object."""
        text = text.replace('```json', '').replace('```', '').strip()
        start, end = text.find('{'), text.rfind('}')
        if start != -1 and end > start:
            return text[start:end + 1]
        return text

    def _generate_bubble_chart_prompt(self, query: str, df: pd.DataFrame) -> str:
        """
        Generates the prompt for the LLM to analyze the data and suggest bubble chart configuration.

        The returned string is meant to be passed to ``PromptTemplate.from_template``,
        so every literal brace is doubled: the static JSON template uses quadruple
        braces in this f-string, and the dynamic parts (DataFrame preview and user
        query) are escaped at runtime.

        Args:
            query (str): User's query
            df (pd.DataFrame): Input DataFrame (only the first 50 rows are embedded)

        Returns:
            str: Formatted prompt for the LLM
        """
        system_message = f'''
You are an expert in data visualization specializing in creating Highcharts bubble chart configurations. Analyze the provided DataFrame and user query to generate the optimal bubble chart JSON configuration.

DataFrame Preview:
{self._escape_braces(df.head(50).to_json())}

User Query: {self._escape_braces(query)}

Available Color Codes (use only these):
{", ".join(self.colors)}

Follow this exact JSON template for bubble charts:
{{{{
  "chart": {{{{
    "type": "bubble",
    "plotBorderWidth": 1,
    "zooming": {{{{
      "type": "xy"
    }}}},
    "backgroundColor": null
  }}}},
  "title": {{{{
    "text": "<chart_title>",
    "align": "left"
  }}}},
  "subtitle": {{{{
    "text": "<chart_subtitle>",
    "align": "left"
  }}}},
  "accessibility": {{{{
    "point": {{{{
      "valueDescriptionFormat": "{{{{index}}}}. {{{{point.name}}}}, x: {{{{point.x}}}}, y: {{{{point.y}}}}, z: {{{{point.z}}}}."
    }}}}
  }}}},
  "xAxis": {{{{
    "gridLineWidth": 1,
    "title": {{{{
      "text": "<x_axis_title>"
    }}}},
    "labels": {{{{
      "format": "{{{{value}}}} <unit>"
    }}}},
    "plotLines": [
      {{{{
        "color": "black",
        "dashStyle": "dot",
        "width": 2,
        "value": "<reference_line_value>",
        "label": {{{{
          "rotation": 0,
          "y": 15,
          "style": {{{{
            "fontStyle": "italic"
          }}}},
          "text": "<reference_line_label>"
        }}}},
        "zIndex": 3
      }}}}
    ],
    "accessibility": {{{{
      "rangeDescription": "Range: <x_range_description>"
    }}}}
  }}}},
  "yAxis": {{{{
    "startOnTick": false,
    "endOnTick": false,
    "title": {{{{
      "text": "<y_axis_title>"
    }}}},
    "labels": {{{{
      "format": "{{{{value}}}} <unit>"
    }}}},
    "maxPadding": 0.2,
    "plotLines": [
      {{{{
        "color": "black",
        "dashStyle": "dot",
        "width": 2,
        "value": "<reference_line_value>",
        "label": {{{{
          "align": "right",
          "style": {{{{
            "fontStyle": "italic"
          }}}},
          "text": "<reference_line_label>",
          "x": -10
        }}}},
        "zIndex": 3
      }}}}
    ],
    "accessibility": {{{{
      "rangeDescription": "Range: <y_range_description>"
    }}}}
  }}}},
  "tooltip": {{{{
    "useHTML": true,
    "headerFormat": "<table>",
    "pointFormat": "<tr><th>Category:</th><td>{{{{point.name}}}}</td></tr>\\n<tr><th>Value X:</th><td>{{{{point.x}}}}</td></tr>\\n<tr><th>Value Y:</th><td>{{{{point.y}}}}</td></tr>\\n<tr><th>Metric Z:</th><td>{{{{point.z}}}}</td></tr>",
    "footerFormat": "</table>",
    "followPointer": true
  }}}},
  "plotOptions": {{{{
    "series": {{{{
      "dataLabels": {{{{
        "enabled": true,
        "format": "{{{{point.name}}}}"
      }}}},
      "marker": {{{{
        "fillOpacity": 0.7
      }}}}
    }}}}
  }}}},
  "series": [
    {{{{
      "data": [
        {{{{
          "x": "<x_value>",
          "y": "<y_value>",
          "z": "<z_value>",
          "name": "<point_name>",
          "color": "<color_code>"
        }}}}
      ],
      "name": "<series_name>",
      "marker": {{{{
        "symbol": "circle",
        "lineColor": null,
        "lineWidth": 1
      }}}}
    }}}}
  ],
  "legend": {{{{
    "enabled": true,
    "layout": "vertical",
    "align": "right",
    "verticalAlign": "middle",
    "itemMarginTop": 10,
    "itemMarginBottom": 10,
    "bubbleLegend": {{{{
      "enabled": true,
      "borderColor": "black",
      "borderWidth": 2,
      "connectorWidth": 2,
      "labels": {{{{
        "align": "left"
      }}}}
    }}}}
  }}}},
  "responsive": {{{{
    "rules": [
      {{{{
        "condition": {{{{
          "maxWidth": 500
        }}}},
        "chartOptions": {{{{
          "legend": {{{{
            "layout": "horizontal",
            "align": "center",
            "verticalAlign": "bottom"
          }}}}
        }}}}
      }}}}
    ]
  }}}}
}}}}

Instructions:
1. Analyze the DataFrame structure and user query to determine:
   - Appropriate x, y, and z values for bubble chart
   - Meaningful axis titles and labels
   - Proper series configuration and bubble sizes

2. Number Formatting Requirements:
   - For currency values (Spend/Price/Cost/Charge):
     * Use $ prefix (valuePrefix: '$')
     * Format example: $36,614,583.26
   - For non-currency values (like Encounter, Score):
     * No currency symbol
     * Format example: 36,614,583.26
   - Always:
     * Round to 2 decimal places
     * Use commas for thousand separators
     * Use Highcharts.numberFormat
   - For percentages:
     * Format to 1 decimal place
     * Include % symbol

3. Bubble Chart Specific Requirements:
   - Ensure bubble sizes (z values) are proportional
   - Add meaningful tooltips with all three dimensions
   - Enable bubble legend if appropriate
   - Use semi-transparent bubbles to handle overlapping
   - Include appropriate reference lines if data suggests thresholds

4. The JSON must follow the exact template structure above
5. Only use colors from the provided color list
6. Return only the JSON, no explanations

Generate the complete bubble chart JSON configuration:'''
        return system_message.strip()

    def _validate_and_clean_json(self, json_str: str) -> Dict[str, Any]:
        """
        Validates and cleans the JSON string returned by the LLM.
        Ensures all required fields are present and data points are properly formatted.

        Missing ``chart``/``xAxis``/``yAxis``/``series`` blocks are filled with
        placeholders, ``chart.type`` is forced to ``'bubble'``, and every object-form
        data point gets float ``x``/``y``/``z`` values (``0.0`` when not convertible)
        plus a color from the approved list. Points that are not objects (e.g. the
        Highcharts array form ``[x, y, z]``) are left untouched.

        Args:
            json_str (str): JSON string from LLM

        Returns:
            Dict[str, Any]: Cleaned and validated JSON dictionary

        Raises:
            ValueError: If the reply cannot be parsed or is not a JSON object.
        """
        try:
            # Clean and parse JSON
            chart_config = json.loads(self._extract_json_object(json_str))
            if not isinstance(chart_config, dict):
                raise ValueError("Chart configuration must be a JSON object")

            # Ensure 'chart' block and chart type
            chart_block = chart_config.setdefault('chart', {})
            chart_block['type'] = 'bubble'
            chart_block.setdefault('plotBorderWidth', 1)

            # Ensure xAxis and yAxis exist (basic placeholder if missing)
            chart_config.setdefault('xAxis', {
                "gridLineWidth": 1,
                "title": {"text": ""},
                "labels": {"format": "{value}"},
                "plotLines": [],
                "accessibility": {"rangeDescription": ""}
            })
            chart_config.setdefault('yAxis', {
                "startOnTick": False,
                "endOnTick": False,
                "title": {"text": ""},
                "labels": {"format": "{value}"},
                "maxPadding": 0.2,
                "plotLines": [],
                "accessibility": {"rangeDescription": ""}
            })

            fallback_color = self.colors[0] if self.colors else "#000000"

            # Validate 'series' data
            for series in chart_config.setdefault('series', []):
                if not isinstance(series, dict):
                    continue

                for point in series.setdefault('data', []):
                    # Array-form points carry no keys to repair; leave them as they are.
                    if not isinstance(point, dict):
                        continue

                    # Convert x, y, z values to float if possible
                    for coord in ('x', 'y', 'z'):
                        if point.get(coord) is not None:
                            try:
                                point[coord] = float(point[coord])
                            except (ValueError, TypeError):
                                point[coord] = 0.0

                    # Validate color usage; fallback to first available color if invalid or missing
                    if point.get('color') not in self.colors:
                        point['color'] = fallback_color

            return chart_config

        except Exception as e:
            raise ValueError(f"Error processing chart configuration: {str(e)}") from e

    def generate_chart_json(
        self,
        query: str,
        df: pd.DataFrame
    ) -> Dict[str, Any]:
        """
        Generates a bubble chart JSON configuration using LLM and data analysis.

        Args:
            query (str): User's query
            df (pd.DataFrame): Input DataFrame

        Returns:
            Dict[str, Any]: Complete bubble chart JSON configuration

        Raises:
            ValueError: Propagated from ``_validate_and_clean_json`` when the LLM
                reply cannot be turned into a valid configuration.
        """
        # Generate prompt for LLM
        prompt = self._generate_bubble_chart_prompt(query, df)

        # The prompt has no template variables; PromptTemplate only un-doubles the braces.
        prompt_template = PromptTemplate.from_template(prompt)
        chart_generation_chain = prompt_template | self.llm | StrOutputParser()

        # Get JSON from LLM
        json_response = chart_generation_chain.invoke({})

        # Validate and clean the JSON
        return self._validate_and_clean_json(json_response)

# Initialization - Bubble Chart
# `colors` is not defined in this cell; it must already exist in the notebook namespace.
bubble_generator = BubbleChartGenerator(colors=colors)

In [56]:
# result_config = bubble_generator.generate_chart_json(user_prompt, df)
# # print(result_config)
# json.loads(result_config)

#### 6. ScatterChartGenerator

In [57]:
class ScatterChartGenerator:
    """
    Generates Highcharts scatter chart configurations with an LLM.

    The generator renders a prompt containing a preview of the DataFrame, the
    user query and the allowed colour palette, sends it to the LLM, and then
    parses and normalises the JSON that comes back.
    """

    # Colour used when the palette is empty (same fallback the bubble generator uses).
    _DEFAULT_COLOR = "#000000"

    def __init__(self, colors):
        self.llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)
        self.colors = colors

    @staticmethod
    def _escape_template_braces(text: str) -> str:
        """Double every brace so PromptTemplate treats the text as literal, not as a variable."""
        return text.replace('{', '{{').replace('}', '}}')

    @staticmethod
    def _color_key(color: Any) -> Optional[str]:
        """Return a comparison key for a colour (no leading '#', lower-case), or None for non-strings."""
        if not isinstance(color, str):
            return None
        return color.strip().lstrip('#').lower()

    def _generate_scatter_chart_prompt(self, query: str, df: pd.DataFrame) -> str:
        """
        Generates the prompt for the LLM to analyze the data and suggest scatter chart configuration.

        The returned text is later parsed by ``PromptTemplate.from_template``, so
        every literal brace in it must be doubled. That applies to the JSON
        template, the DataFrame preview and also to the user query: an unescaped
        ``{`` in the query would be read as a template variable and make
        ``invoke({})`` fail.

        Args:
            query (str): User's query
            df (pd.DataFrame): Input DataFrame

        Returns:
            str: Formatted prompt for the LLM (brace-escaped for PromptTemplate)
        """
        df_preview = self._escape_template_braces(df.head(50).to_json())
        safe_query = self._escape_template_braces(str(query))
        color_list = ", ".join(self.colors)

        system_message = f'''
You are an expert in data visualization specializing in creating Highcharts scatter plot configurations. Analyze the provided DataFrame and user query to generate the optimal scatter chart JSON configuration.

DataFrame Preview:
{df_preview}

User Query: {safe_query}

Available Color Codes (use only these):
{color_list}

Follow this exact JSON template for scatter charts:
{{{{
  "chart": {{{{
    "type": "scatter",
    "zooming": {{{{
      "type": "xy"
    }}}}
  }}}},
  "colors": ["<colors>"],
  "title": {{{{
    "text": "<chart_title>",
    "align": "left"
  }}}},
  "subtitle": {{{{
    "text": "<chart_subtitle>",
    "align": "left"
  }}}},
  "xAxis": {{{{
    "title": {{{{
      "text": "<x_axis_title>"
    }}}},
    "labels": {{{{
      "format": "<proper_format>"
    }}}},
    "startOnTick": true,
    "endOnTick": true,
    "showLastLabel": true,
    "gridLineWidth": 1
  }}}},
  "yAxis": {{{{
    "title": {{{{
      "text": "<y_axis_title>"
    }}}},
    "labels": {{{{
      "format": "<proper_format>"
    }}}}
  }}}},
  "legend": {{{{
    "enabled": true,
    "layout": "vertical",
    "align": "right",
    "verticalAlign": "middle",
    "borderWidth": 1
  }}}},
  "plotOptions": {{{{
    "scatter": {{{{
      "marker": {{{{
        "radius": 5,
        "symbol": "circle",
        "states": {{{{
          "hover": {{{{
            "enabled": true,
            "lineColor": "<line_color>"
          }}}}
        }}}}
      }}}},
      "states": {{{{
        "hover": {{{{
          "marker": {{{{
            "enabled": false
          }}}}
        }}}}
      }}}},
      "jitter": {{{{
        "x": 0.005
      }}}}
    }}}}
  }}}},
  "tooltip": {{{{
    "pointFormat": "<point_format>"
  }}}},
  "series": [
    {{{{
      "name": "<series_name>",
      "data": [
        {{{{
          "x": "<x_value>",
          "y": "<y_value>",
          "name": "<point_name>",
          "color": "<color_code>"
        }}}}
      ],
      "marker": {{{{
        "symbol": "circle"
      }}}}
    }}}}
  ],
  "responsive": {{{{
    "rules": [
      {{{{
        "condition": {{{{
          "maxWidth": 500
        }}}},
        "chartOptions": {{{{
          "legend": {{{{
            "align": "center",
            "layout": "horizontal",
            "verticalAlign": "bottom"
          }}}}
        }}}}
      }}}}
    ]
  }}}}
}}}}

Instructions:
1. Analyze the DataFrame structure and user query to determine:
   - Appropriate x and y variables for scatter plot
   - Meaningful title and axis labels
   - Proper series configuration

2. Number Formatting Requirements:
   - For currency values (Spend/Price/Cost/Charge):
     * Use $ prefix (valuePrefix: '$')
     * Format example: $36,614,583.26
   - For non-currency values (like Encounter, Score):
     * No currency symbol
     * Format example: 36,614,583.26
   - Always:
     * Round to 2 decimal places
     * Use commas for thousand separators
     * Use Highcharts.numberFormat
   - For percentages:
     * Format to 1 decimal place
     * Include % symbol

3. Scatter Plot Specific Requirements:
   - Choose appropriate axis scales
   - Set meaningful marker sizes
   - Enable zoom functionality
   - Configure proper tooltips
   - Handle point overlaps using jitter if needed

4. The JSON must follow the exact template structure above
5. Only use colors from the provided color list
6. Return only the JSON, no explanations

Generate the complete scatter chart JSON configuration:'''

        return system_message.strip()

    def _validate_and_clean_json(self, json_str: str) -> Dict[str, Any]:
        """
        Validates and cleans the JSON string returned by the LLM.

        Steps: strip markdown code fences, parse the JSON, check the required
        top-level keys, force ``chart.type`` to ``'scatter'``, fill in default
        scatter marker options when ``plotOptions`` is present, make sure
        zooming is configured, and replace point colours that are not part of
        the palette with the first palette colour.

        Colours are compared without the leading ``#`` and case-insensitively,
        so the palette may be given with or without ``#``. Accepted colours are
        rewritten to the palette spelling with a single leading ``#``.

        Args:
            json_str (str): JSON string from LLM

        Returns:
            Dict[str, Any]: Cleaned and validated JSON dictionary

        Raises:
            ValueError: If the text is not valid JSON, a required key is
                missing, or the structure cannot be processed.
        """
        try:
            # Remove any markdown code block syntax if present
            json_str = json_str.replace('```json', '').replace('```', '').strip()

            # Parse JSON
            chart_config = json.loads(json_str)

            # Basic validation
            required_keys = ['chart', 'title', 'xAxis', 'yAxis', 'series']
            for key in required_keys:
                if key not in chart_config:
                    raise ValueError(f"Missing required key: {key}")

            # Ensure 'type' is 'scatter'
            if chart_config['chart'].get('type') != 'scatter':
                chart_config['chart']['type'] = 'scatter'

            # Ensure scatter-specific settings (only when the LLM supplied plotOptions)
            if 'plotOptions' in chart_config:
                if 'scatter' not in chart_config['plotOptions']:
                    chart_config['plotOptions']['scatter'] = {}

                # Set default marker options if not present
                if 'marker' not in chart_config['plotOptions']['scatter']:
                    chart_config['plotOptions']['scatter']['marker'] = {
                        'radius': 5,
                        'symbol': 'circle'
                    }

            # Ensure zooming is enabled
            if 'zooming' not in chart_config['chart']:
                chart_config['chart']['zooming'] = {'type': 'xy'}

            # Map comparison key -> canonical '#RRGGBB' spelling, in palette order.
            palette: Dict[str, str] = {}
            for color in self.colors:
                key = self._color_key(color)
                if key:
                    palette.setdefault(key, f"#{color.strip().lstrip('#')}")
            fallback_color = next(iter(palette.values()), self._DEFAULT_COLOR)

            # Validate series colors; points given as [x, y] arrays or bare numbers
            # carry no colour and are left untouched.
            series_list = chart_config['series']
            if isinstance(series_list, list):
                for series in series_list:
                    if not isinstance(series, dict):
                        continue
                    data_points = series.get('data')
                    if not isinstance(data_points, list):
                        continue
                    for point in data_points:
                        if isinstance(point, dict) and 'color' in point:
                            point['color'] = palette.get(
                                self._color_key(point['color']), fallback_color
                            )

            return chart_config

        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {str(e)}") from e
        except Exception as e:
            raise ValueError(f"Error processing chart configuration: {str(e)}") from e

    def generate_chart_json(
        self,
        query: str,
        df: pd.DataFrame
    ) -> Dict[str, Any]:
        """
        Generates a scatter chart JSON configuration using LLM and data analysis.

        Args:
            query (str): User's query
            df (pd.DataFrame): Input DataFrame

        Returns:
            Dict[str, Any]: Complete scatter chart JSON configuration

        Raises:
            ValueError: If the LLM answer cannot be parsed or validated.
        """
        # Generate prompt for LLM (already brace-escaped for PromptTemplate)
        prompt = self._generate_scatter_chart_prompt(query, df)

        # Create chain
        prompt_template = PromptTemplate.from_template(prompt)
        chart_generation_chain = prompt_template | self.llm | StrOutputParser()

        # Get JSON from LLM; the prompt has no variables left to fill
        json_response = chart_generation_chain.invoke({})

        # Validate and clean the JSON
        return self._validate_and_clean_json(json_response)

# Module-level ScatterChartGenerator built from the `colors` name defined elsewhere in the notebook.
# NOTE(review): ChartGeneratorManager constructs its own ScatterChartGenerator — confirm this instance is still needed.
scatter_generator = ScatterChartGenerator(colors=colors)

#### 7. AreaChartGenerator

In [58]:
class AreaChartGenerator:
    """
    Generates Highcharts area chart configurations with an LLM.

    The generator renders a prompt containing a preview of the DataFrame, the
    user query and the allowed colour palette, sends it to the LLM, and then
    parses and normalises the JSON that comes back.
    """

    def __init__(self, colors):
        self.llm = ChatOpenAI(model='gpt-4o-2024-08-06', temperature=0)
        self.colors = colors

    @staticmethod
    def _escape_template_braces(text: str) -> str:
        """Double every brace so PromptTemplate treats the text as literal, not as a variable."""
        return text.replace('{', '{{').replace('}', '}}')

    @staticmethod
    def _default_area_options() -> Dict[str, Any]:
        """Return a fresh copy of the default ``plotOptions.area`` block."""
        return {
            'marker': {
                'enabled': False,
                'symbol': 'circle',
                'radius': 2,
                'states': {'hover': {'enabled': True}}
            }
        }

    def _generate_area_chart_prompt(self, query: str, df: pd.DataFrame) -> str:
        """
        Generates the prompt for the LLM to analyze the data and suggest area chart configuration.

        The returned text is later parsed by ``PromptTemplate.from_template``, so
        every literal brace in it must be doubled. That applies to the JSON
        template, the DataFrame preview and also to the user query: an unescaped
        ``{`` in the query would be read as a template variable and make
        ``invoke({})`` fail.

        Args:
            query (str): User's query
            df (pd.DataFrame): Input DataFrame

        Returns:
            str: Formatted prompt for the LLM (brace-escaped for PromptTemplate)
        """
        df_preview = self._escape_template_braces(df.head(50).to_json())
        safe_query = self._escape_template_braces(str(query))
        color_list = ", ".join(self.colors)

        system_message = f'''You are an expert in data visualization specializing in creating Highcharts area chart configurations. Analyze the provided DataFrame and user query to generate the optimal area chart JSON configuration.

DataFrame Preview:
{df_preview}

User Query: {safe_query}

Available Color Codes (use only these):
{color_list}

Follow this exact JSON template for area charts:
{{{{
  "chart": {{{{
    "type": "area"
  }}}},
  "accessibility": {{{{
    "description": "<accessibility_description>"
  }}}},
  "title": {{{{
    "text": "<chart_title>",
    "align": "<align>"
  }}}},
  "subtitle": {{{{
    "text": "<chart_subtitle>",
    "align": "<align>"
  }}}},
  "xAxis": {{{{
    "categories": null,
    "title": {{{{ "text": "<x_axis_title>" }}}},
    "crosshair": false,
    "gridLineWidth": null,
    "lineWidth": null,
    "allowDecimals": false,
    "accessibility": {{{{
      "rangeDescription": "<range_description>",
      "description": <description>
    }}}}
  }}}},
  "yAxis": {{{{
    "title": {{{{
      "text": "<y_axis_title>",
      "align": null
    }}}},
    "min": null,
    "max": null,
    "labels": {{{{
      "overflow": null
    }}}},
    "gridLineWidth": null
  }}}},
  "series": [
    {{{{
      "name": "<series_name>",
      "data": [<data_points>],
      "color": "<hex_color>"
    }}}}
  ],
  "tooltip": {{{{
    "pointFormat": "<point_format>",
    "valuePrefix": <value_prefix>
  }}}},
  "plotOptions": {{{{
    "area": {{{{
      "pointStart": "<start_year>",
      "marker": {{{{
        "enabled": false,
        "symbol": "circle",
        "radius": 2,
        "states": {{{{
          "hover": {{{{
            "enabled": true
          }}}}
        }}}}
      }}}}
    }}}}
  }}}},
  "legend": {{{{
    "layout": "vertical",
    "align": "right",
    "verticalAlign": "middle",
    "x": null,
    "y": null,
    "floating": null,
    "borderWidth": null,
    "backgroundColor": null,
    "shadow": null
  }}}},
  "responsive": {{{{
    "rules": [
      {{{{
        "condition": {{{{ "maxWidth": 500 }}}},
        "chartOptions": {{{{
          "legend": {{{{
            "layout": "horizontal",
            "align": "center",
            "verticalAlign": "bottom"
          }}}}
        }}}}
      }}}}
    ]
  }}}}
}}}}

Instructions:
1. Analyze the DataFrame structure and user query to determine:
   - Appropriate columns for x-axis (time series) and y-axis (values)
   - Meaningful title and labels
   - Proper series configuration

2. **Number Formatting Requirements:
   - For currency values (Spend/Price/Cost/Charge):
     * Use $ prefix (valuePrefix: '$')
     * Format example: $36,614,583.26
   - For non-currency values (like Encounter, Score):
     * No currency symbol
     * Format example: 36,614,583.26
   - Always:
     * Round to 2 decimal places
     * Use commas for thousand separators
     * Use Highcharts.numberFormat
   - For percentages:
     * Format to 1 decimal place
     * Include % symbol

3. Area Chart Specific Requirements:
   - Best suited for time-series data showing cumulative totals
   - Use pointStart for setting the starting time period
   - Configure markers appropriately for data point visibility
   - Consider using fillOpacity for overlapping areas

4. Date and Text Requirements:
   - Use full month names (e.g., "January" not "Jan")
   - Use proper capitalizations in labels and titles

5. The JSON must follow the exact template structure above
6. Only use colors from the provided color list
7. Return only the JSON, no explanations

Generate the complete area chart JSON configuration:'''

        return system_message.strip()

    def _validate_and_clean_json(self, json_str: str) -> Dict[str, Any]:
        """
        Validates and cleans the JSON string returned by the LLM.

        Steps: strip markdown code fences, parse the JSON, check the required
        top-level keys, force ``chart.type`` to ``'area'`` and repair an
        unusable ``plotOptions`` section:

        * ``plotOptions`` that is not an object (for example JSON ``null``) is
          replaced by ``{'area': <default marker options>}``;
        * an ``area`` entry that is present but not an object is replaced by
          the default marker options, leaving sibling keys of ``plotOptions``
          untouched;
        * a missing ``area`` entry is left as is.

        Args:
            json_str (str): JSON string from LLM

        Returns:
            Dict[str, Any]: Cleaned and validated JSON dictionary

        Raises:
            ValueError: If the text is not valid JSON, a required key is
                missing, or the structure cannot be processed.
        """
        try:
            # Remove any markdown code block syntax if present
            json_str = json_str.replace('```json', '').replace('```', '').strip()

            # Parse JSON
            chart_config = json.loads(json_str)

            # Basic validation
            required_keys = ['chart', 'title', 'xAxis', 'yAxis', 'series', 'plotOptions']
            for key in required_keys:
                if key not in chart_config:
                    raise ValueError(f"Missing required key: {key}")

            # Ensure 'type' is 'area'
            if chart_config['chart'].get('type') != 'area':
                chart_config['chart']['type'] = 'area'

            # Validate area-specific configurations
            plot_options = chart_config['plotOptions']
            if not isinstance(plot_options, dict):
                chart_config['plotOptions'] = {'area': self._default_area_options()}
            elif 'area' in plot_options and not isinstance(plot_options['area'], dict):
                # Only the broken entry is replaced so other plotOptions keys survive.
                plot_options['area'] = self._default_area_options()

            return chart_config

        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {str(e)}") from e
        except Exception as e:
            raise ValueError(f"Error processing chart configuration: {str(e)}") from e

    def generate_chart_json(
        self,
        query: str,
        df: pd.DataFrame
    ) -> Dict[str, Any]:
        """
        Generates an area chart JSON configuration using LLM and data analysis.

        Args:
            query (str): User's query
            df (pd.DataFrame): Input DataFrame

        Returns:
            Dict[str, Any]: Complete area chart JSON configuration

        Raises:
            ValueError: If the LLM answer cannot be parsed or validated.
        """
        # Generate prompt for LLM (already brace-escaped for PromptTemplate)
        prompt = self._generate_area_chart_prompt(query, df)

        # Create chain
        prompt_template = PromptTemplate.from_template(prompt)
        chart_generation_chain = prompt_template | self.llm | StrOutputParser()

        # Get JSON from LLM; the prompt has no variables left to fill
        json_response = chart_generation_chain.invoke({})

        # Validate and clean the JSON
        return self._validate_and_clean_json(json_response)

# Module-level AreaChartGenerator built from the `colors` name defined elsewhere in the notebook.
# NOTE(review): ChartGeneratorManager constructs its own AreaChartGenerator — confirm this instance is still needed.
area_chart_generator = AreaChartGenerator(colors=colors)

### Json_Output Function (For my use only)

#### save_chart_configs

Generates and saves a JSON file, plus a `Highcharts.chart('container', …);` text snippet, for each chart configuration in a list.

**Args:**

- `chart_configs` (list): List of chart configurations.
- `output_directory` (str): Directory to save the generated files.

**Returns:** None


In [59]:
#################################################
#                 save_chart_configs
#################################################

def save_chart_configs(chart_configs, output_directory):
    """
    Generates and saves JSON files for a list of chart configurations.

    For every configuration two files are written into ``output_directory``:
    ``<type>_chart_<idx>.json`` with the raw configuration and
    ``<type>_chart_<idx>.txt`` with a ready-to-paste
    ``Highcharts.chart('container', ...);`` call. ``<type>`` is taken from
    ``config['chart']['type']`` and falls back to ``chart_<idx>``.

    Saving is best-effort: a configuration that cannot be written is reported
    with ``print`` and the remaining configurations are still processed.

    Args:
        chart_configs (list): List of chart configurations. Each item is a
            dict or a JSON string that decodes to a dict.
        output_directory (str): Directory to save the JSON files. It is
            created (including parents) when it does not exist.

    Returns:
        None
    """

    def create_highchart_file(json_input, output_filename='highchart_config.txt'):
        """
        Takes a JSON input (as a dict or a valid JSON string) and writes a file
        with the following format:
            Highcharts.chart('container', <properly_indented_json>);

        :param json_input: A Python dictionary or a JSON string compatible with Highcharts.
        :param output_filename: File name for the output text file (default: 'highchart_config.txt').
        :raises ValueError: If ``json_input`` is a string that is not valid JSON.
        """
        # 1. If the input is a string, convert it to a Python dictionary
        if isinstance(json_input, str):
            try:
                json_input = json.loads(json_input)
            except json.JSONDecodeError as e:
                raise ValueError(f"Invalid JSON string provided: {e}")

        # 2. Convert the dictionary to a pretty-printed JSON string
        json_str = json.dumps(json_input, indent=2)

        # 3. Construct the final content
        content = f"Highcharts.chart('container', {json_str});"

        # 4. Write the content to a text file
        with open(output_filename, 'w', encoding='utf-8') as f:
            f.write(content)

        print(f"File '{output_filename}' created successfully.")

    # exist_ok avoids the check-then-create race of a separate os.path.exists() test
    os.makedirs(output_directory, exist_ok=True)

    for idx, config in enumerate(chart_configs):
        try:
            # Accept JSON text as well as dicts, matching create_highchart_file
            if isinstance(config, str):
                config = json.loads(config)

            # Extract chart type and generate output filename
            chart_type = config.get('chart', {}).get('type', f"chart_{idx}")
            output_filename = os.path.join(output_directory, f"{chart_type}_chart_{idx}.json")

            # Save the configuration as a JSON file (same encoding as the .txt file)
            with open(output_filename, 'w', encoding='utf-8') as json_file:
                json.dump(config, json_file, indent=4)

            print(f"Successfully saved {chart_type} chart configuration to {output_filename}")

            # Also write the Highcharts.chart(...) snippet next to the JSON file
            highchart_output_filename = os.path.join(output_directory, f"{chart_type}_chart_{idx}.txt")
            create_highchart_file(config, highchart_output_filename)

        except Exception as e:
            # Best-effort: report the failure and continue with the next configuration
            print(f"Error saving chart configuration {idx}: {e}")

### ChartGeneratorManager

In [60]:
import asyncio
from concurrent.futures import ThreadPoolExecutor

class ChartGeneratorManager:
    """
    Runs the per-chart-type generators concurrently on a thread pool.

    ``generators`` maps a chart type name to a generator object exposing
    ``generate_chart_json(query, df)``. Those calls are blocking, so each one
    is pushed to ``executor`` and awaited from the event loop.
    """

    def __init__(self):
        self.generators = {
            'line': LineChartGenerator(colors=colors),
            'area': AreaChartGenerator(colors=colors),
            'bar': BarChartGenerator(colors=colors),
            'column': ColumnChartGenerator(colors=colors),
            'pie': PieChartGenerator(colors=colors),
            'scatter': ScatterChartGenerator(colors=colors),
            'bubble': BubbleChartGenerator(colors=colors)
        }
        self.executor = ThreadPoolExecutor()

    def validate_dataframe(self, df):
        """
        Validates if DataFrame meets minimum requirements for charting.

        Args:
            df: DataFrame to check; ``None`` is treated as not chartable.

        Returns:
            bool: True when ``df`` has at least 2 rows and at least 2 columns.
        """
        if df is None:
            return False
        # A frame with >= 2 rows and >= 2 columns is by definition not empty.
        return len(df) >= 2 and len(df.columns) >= 2

    async def generate_single_chart(self, chart_type, query, df):
        """
        Generate the configuration for one chart type.

        Args:
            chart_type: Key into ``self.generators``.
            query: User query forwarded to the generator.
            df: DataFrame forwarded to the generator.

        Returns:
            The generator's chart configuration, or ``None`` when the chart
            type is unknown, the DataFrame fails validation, or the generator
            raises (the error is printed).
        """
        if chart_type not in self.generators or not self.validate_dataframe(df):
            return None

        try:
            generator = self.generators[chart_type]
            # get_running_loop() is the recommended way to get the loop inside a coroutine
            loop = asyncio.get_running_loop()
            chart_config = await loop.run_in_executor(
                self.executor,
                generator.generate_chart_json,
                query,
                df
            )
            return chart_config
        except Exception as e:
            # Best-effort: one failing chart type must not abort the others
            print(f"Error generating {chart_type} chart: {e}")
            return None

    async def generate_chart_configs(self, list_of_charts, query, df):
        """
        Generate configurations for several chart types concurrently.

        Args:
            list_of_charts: Iterable of chart type names.
            query: User query forwarded to every generator.
            df: DataFrame forwarded to every generator.

        Returns:
            list: The successfully generated configurations, in the order of
            ``list_of_charts``; empty when the DataFrame fails validation.
        """
        if not self.validate_dataframe(df):
            return []

        tasks = [
            self.generate_single_chart(chart_type, query, df)
            for chart_type in list_of_charts
        ]
        chart_configs = await asyncio.gather(*tasks)
        return [config for config in chart_configs if config is not None]

    def close(self):
        """Shut down the thread pool used for chart generation."""
        self.executor.shutdown()

# Shared manager instance; its constructor builds one generator per chart type and a thread pool.
chart_manager = ChartGeneratorManager()

### Run - New chart_configs

In [61]:
# ### Running - Chart Functions
# print(f'user_prompt: {user_prompt}\n')

# # ╔═══════════════════════════════════════════════════════════════════════╗
# # ║                      🌟 chart_recommender 🌟                         ║
# # ╚═══════════════════════════════════════════════════════════════════════╝

# list_of_charts, suitable_chart_types = chart_recommender.get_suitable_charts(user_prompt, generated_sql_query, df)
# print(json.dumps(suitable_chart_types, indent = 2))
# print(f'\nlist_of_charts: {list_of_charts}\n')

# # ╔═══════════════════════════════════════════════════════════════════════╗
# # ║                      🌟 chart_manager 🌟                             ║
# # ╚═══════════════════════════════════════════════════════════════════════╝

# chart_configs = await chart_manager.generate_chart_configs(list_of_charts, user_prompt, df)
# print(f"\nGenerated Chart Configs: {[chart_type['chart']['type'] for chart_type in chart_configs]}")

# # ╔═══════════════════════════════════════════════════════════════════════╗
# # ║                      🌟 save_chart_configs 🌟                        ║
# # ╚═══════════════════════════════════════════════════════════════════════╝

# print(f'\n================================= Saving the Files ==================================\n')
# output_directory = "highchart_json_files"
# if chart_configs:
#     save_chart_configs(chart_configs, output_directory)

# Generate and Execute SQL Query

In [62]:
def generate_and_execute_sql(
    user_prompt,
    table_to_use,
    table_columns,
    column_info_from_knowledge_base,
    prompt_to_use_for_complex_question,
    columns_info_to_generate_sql
):
    """
    Generate a SQL query for the user's question and execute it.

    Uses the module-level ``multi_stage_sql_generator`` and
    ``sql_query_generator`` objects; they are not parameters of this function.

    Parameters:
    -----------
    user_prompt : str
        The user's natural language query.
    table_to_use : str
        The table name to query against. When falsy, nothing is generated.
    table_columns : list
        Columns of the table; passed as the schema in multi-stage mode.
    column_info_from_knowledge_base : dict
        Column context passed as ``domain_context`` in multi-stage mode.
    prompt_to_use_for_complex_question : str
        Extra instructions for complex queries. A truthy value selects
        multi-stage generation; a falsy value selects normal generation.
    columns_info_to_generate_sql : list
        Column list passed to ``sql_query_generator`` in normal mode.

    Returns:
    --------
    tuple
        ``(generated_sql_query, df)``. When no table is given the result is
        ``('', <empty DataFrame>)``.
    """
    # Guard: without a table there is nothing to generate or run
    if not table_to_use:
        print("\n❌ No table found. Skipping query generation.\n")
        return '', pd.DataFrame()

    if not prompt_to_use_for_complex_question:
        # Mode 1: generate and execute SQL from prompt + column + table info
        print("[INFO] Running in Normal SQL Generation Mode...")
        sql_text, result_df = sql_query_generator.get_result_from_sql_query(
            user_prompt=user_prompt,
            column_list=columns_info_to_generate_sql,  # Assuming columns_info_to_generate_sql == table_columns
            redshift_table=table_to_use
        )
        return sql_text, result_df

    # Mode 2: multi-stage generation with refinement, then direct execution
    print("[INFO] Running in Multi-Stage SQL Generation Mode...")
    schema_description = {"table_name": table_to_use, "columns": table_columns}
    refined_sql = multi_stage_sql_generator.generate_and_iteratively_refine_sql(
        user_prompt=user_prompt,
        schema=schema_description,
        domain_context=column_info_from_knowledge_base,
        complex_instructions=prompt_to_use_for_complex_question,
        max_iterations=2
    )
    sql_text, result_df = sql_query_generator.get_result_from_sql_query(
        generated_sql_query=refined_sql
    )
    return sql_text, result_df

In [63]:
# generated_sql_query, df = generate_and_execute_sql(
#             user_prompt=user_prompt_to_use,
#             table_to_use=table_to_use,
#             table_columns=table_columns,
#             column_info_from_knowledge_base=column_info_from_knowledge_base,
#             prompt_to_use_for_complex_question=prompt_to_use_for_complex_question,
#             columns_info_to_generate_sql = columns_info_to_generate_sql,
#         )

# CostSavingQuestionRefiner [TEMP_USE_CASE] - refine_cost_savings

In [64]:
class CostSavingQuestionRefiner:
    """
    Temporary helper for UNSPSC cost-saving questions.

    ``refine`` rewrites a "how can I save money" question into a structured
    question with an LLM; ``calculate_unspsc_savings`` estimates the savings
    from buying everything at the lowest average unit price.
    """

    # Columns calculate_unspsc_savings reads from its input frame.
    _REQUIRED_SAVINGS_COLUMNS = (
        'emr_supply_supplier',
        'total_spend',
        'total_quantity',
        'average_unit_price',
    )

    def __init__(self, model_name: str = 'gpt-4.1-mini'):
        """
        Initializes the cost saving question refiner using an OpenAI LLM.
        """
        self.llm = ChatOpenAI(model=model_name, temperature=0)

        # Example-guided prompt
        self.prompt = PromptTemplate.from_template(
            """You are a helpful assistant.

When a user asks how to save money for a specific UNSPSC category, you must rephrase the question
into a structured form that requests supplier-level spend and pricing details.

### Rules:
- Only rewrite if the user is asking how to save money on a specific UNSPSC category.
- Do NOT change or omit the UNSPSC value.
- Always return the question in the following format exactly:
  "What is the total_spend, total_quantity, average_unit_price, and emr_supply_supplier for UNSPSC <UNSPSC_VALUE>?"

- The refined question MUST ensure the returned dataframe includes the following columns:
  - 'emr_supply_supplier'
  - 'total_spend'
  - 'total_quantity'
  - 'average_unit_price'

### Examples:

User Question:
How can I save money on @UNSPSC UNSPSC_Name?

→ Rephrased:
What is the total_spend, total_quantity, average_unit_price, and emr_supply_supplier for UNSPSC UNSPSC_Name?

User Question:
Can we reduce cost for UNSPSC UNSPSC_Name?

→ Rephrased:
What is the total_spend, total_quantity, average_unit_price, and emr_supply_supplier for UNSPSC UNSPSC_Name?

----------------------------
Now rewrite the following user question:

User Question:
{question}

→ Rephrased:"""
        )

        self.chain = self.prompt | self.llm | StrOutputParser()

    def calculate_unspsc_savings(self, df: pd.DataFrame) -> dict:
        """
        Calculates total spend, total quantity, and estimated savings
        by consolidating purchases to the supplier with the lowest average unit price.

        Numeric columns are coerced with ``pd.to_numeric(errors='coerce')``, so
        numbers stored as text are summed correctly and unparseable entries are
        ignored like missing values. The input frame is not modified.

        Parameters:
            df (pd.DataFrame): Must include columns:
                - 'emr_supply_supplier'
                - 'total_spend'
                - 'total_quantity'
                - 'average_unit_price'

        Returns:
            dict: Summary with total cost, min supplier, and estimated savings.
            Numeric values are plain Python ``float``/``int`` so the dict
            prints and serialises cleanly.

        Raises:
            KeyError: If a required column is missing.
            ValueError: If no row has a usable 'average_unit_price' (for
                example when the frame is empty).
        """
        missing = [col for col in self._REQUIRED_SAVINGS_COLUMNS if col not in df.columns]
        if missing:
            raise KeyError(f"Missing required column(s) for savings calculation: {missing}")

        spend = pd.to_numeric(df['total_spend'], errors='coerce')
        quantity = pd.to_numeric(df['total_quantity'], errors='coerce')
        unit_price = pd.to_numeric(df['average_unit_price'], errors='coerce')

        if unit_price.dropna().empty:
            raise ValueError("Cannot calculate savings: no rows with a valid 'average_unit_price'.")

        # Calculate total current cost (NaN entries are skipped by sum())
        total_cost = spend.sum()

        # Identify minimum unit price and the first supplier offering it
        min_unit_price = unit_price.min()
        min_supplier = df.loc[unit_price == min_unit_price, 'emr_supply_supplier'].iloc[0]

        # Total quantity across all suppliers
        total_quantity = quantity.sum()

        # Optimal cost if purchased from lowest-cost supplier
        optimal_cost = total_quantity * min_unit_price

        # Estimated savings
        savings = total_cost - optimal_cost
        return {
            "total_cost": round(float(total_cost), 2),
            "optimal_cost": round(float(optimal_cost), 2),
            "estimated_savings": round(float(savings), 2),
            "total_quantity": int(total_quantity),
            "lowest_unit_price": round(float(min_unit_price), 2),
            "lowest_cost_supplier": min_supplier
        }

    def refine(self, question: str) -> str:
        """
        Refines a cost-saving question into a structured query for supplier pricing.

        Parameters:
            question (str): The user's original question.

        Returns:
            str: The LLM's rephrased question with surrounding whitespace removed.
        """
        return self.chain.invoke({"question": question}).strip()

# Shared refiner instance, created with the default model_name ('gpt-4.1-mini').
refine_cost_savings = CostSavingQuestionRefiner()

In [65]:
# refine_cost_savings.refine(user_prompt)

# Running all at once

In [66]:
#### ------------------------------------------------ ## ------------------------------------------------ ####
####                                  Creating a new User to store messages
#### ------------------------------------------------ ## ------------------------------------------------ ####
# Fresh memory manager for this run; `user_id` is the key later cells use for the conversation history.
memory_manager = MemoryChainManager()
user_id = 'temp_user'
# presumably creates the session entry for `user_id` (later code indexes memory_manager.store[user_id]) — TODO confirm
creating_new_session_memory = memory_manager.get_by_session_id(user_id)


**New User Found: temp_user.



In [90]:
user_prompt = input()

# print(f'\n\t >>>>>>>> User Query: {user_prompt} <<<<<<<<<')
pretty_display(user_prompt, title="User Query")

sql_response_chain_result = True

# # >>>>>>>>>>>>>>>>>>>>>>>>  Memory Chain | combined_memory_chain_function <<<<<<<<<<<<<<<<<<<<<<<
user_prompt_to_use, result_from_memory = memory_manager.combined_chain(user_prompt, user_id)
prompt_to_save = user_prompt
if user_prompt != user_prompt_to_use:
    prompt_to_save = f"{user_prompt} [REFINED QUERY: ({user_prompt_to_use})]"
# # >>>>>>>>>>>>>>>>>>>>>>>>  Memory Chain | combined_memory_chain_function <<<<<<<<<<<<<<<<<<<<<<<

## ============ TEMPORARY USE CASE ============= ##
temp_use_case_for_cost_savings = False
if 'save' in user_prompt.lower() and 'money' in user_prompt.lower() and 'unspsc' in user_prompt.lower():

    result_from_memory = None
    prompt_to_save = user_prompt

    print(f'\n=================\nUNSPSC Savings question Found.\n=================\n')
    user_prompt_to_use = refine_cost_savings.refine(user_prompt)
    print(f'Refined Question: {user_prompt_to_use}')
    temp_use_case_for_cost_savings = True
## ============ ==================== ============= ##


df = pd.DataFrame()
if not result_from_memory:
    query_type = query_type_classifier.classify_query(user_prompt_to_use)

    if query_type == '1': # API Endpoint | General Query - Running Chroma Chain
        print(f'\n\t >>>>>>>> Classification: {query_type} | General Query - Running Chroma Chain <<<<<<<<<')
        context_check = query_type_classifier.is_out_of_context_general_knowledge(user_prompt_to_use)

        if context_check == '0':
            print(f'\n***** Its an out-of-scope Question *****')
            
            general_query_result = query_type_classifier.generate_out_of_scope_response_with_llm(user_prompt_to_use)
        else:
            general_query_result = general_query_api.get_result(user_prompt_to_use)
        
        print(f'\nResponse: {general_query_result}')

        ### Saving Memory -> memory_manager.store
        memory_manager.store[user_id].add_user_message(prompt_to_save)
        memory_manager.store[user_id].add_ai_message(general_query_result)
    
    else:
        ## Running the Chain | TableSelection
        print(f'\n\t >>>>>>>> Classification: {query_type} | Calculative Query <<<<<<<<<')
        selected_table_text = table_selection_object.get_table_name(user_prompt_to_use)

        #### ------------------------------------------------ ## ------------------------------------------------ ####
        #### Retrieving best matched column info for SQL generation based on user prompt and selected table
        #### ------------------------------------------------ ## ------------------------------------------------ ####

        for run in range(2):
            table_to_use, filtered_df, best_match_columns, table_columns, column_info_from_knowledge_base, prompt_to_use_for_complex_question, columns_info_to_generate_sql = embedding_task_helper.get_column_info_for_sql_generation(user_prompt_to_use, selected_table_text)
            
            print(f'Table Selected: {table_to_use}')
            generated_sql_query, df = generate_and_execute_sql(
                user_prompt=user_prompt_to_use,
                table_to_use=table_to_use,
                table_columns=table_columns,
                column_info_from_knowledge_base=column_info_from_knowledge_base,
                prompt_to_use_for_complex_question=prompt_to_use_for_complex_question,
                columns_info_to_generate_sql = columns_info_to_generate_sql
            )

            if not df.empty:
                break
        # --------------------------------------------------------------------------------------------------------------------------------- # 

        #################### SQL RESPONSE #########################
        if sql_response_chain_result:

            if temp_use_case_for_cost_savings and not df.empty: ## ## ============ TEMPORARY USE CASE ============= ## ONLY FOR ONE WEEK
                temp_result = refine_cost_savings.calculate_unspsc_savings(df)
                user_prompt_to_use_temp = (
                    user_prompt +
                    f"\n\nThe following table contains supplier-level data retrieved from the database for the specified UNSPSC category:\n\n"
                    f"{str(temp_result).replace('{', '{{').replace('}', '}}')}\n\n"
                    f"Please display this table clearly in your response and use its values to perform a detailed analysis of potential cost savings.\n"
                    f"Your explanation must reference the actual spend, quantity, and unit price for each supplier, and explain how consolidating purchases to the lowest-cost supplier can lead to savings.\n"
                    f"Ensure the table appears in your final response to help the user easily interpret the figures."
                )
                sql_response_chain_result = sql_response_generator.get_sql_response(user_prompt_to_use_temp, generated_sql_query, df)
            else:
                sql_response_chain_result = sql_response_generator.get_sql_response(user_prompt_to_use, generated_sql_query, df)
            
            print(f'SQL Response: \n{sql_response_chain_result}')
            # Saving Memory Context | DataFrame
            memory_manager.store[user_id].add_user_message(prompt_to_save)
            memory_manager.store[user_id].add_ai_message(sql_response_chain_result)

else:
    print(f'\n\t >>>>>>>> Result from Memory <<<<<<<<<\n')
    print(f'{result_from_memory}')

################ ---------------------- ##################### # >>>>>>>> CHARTING <<<<<<<<< ANALYSIS >>>>>>>>>>>>>>>
# >>>>>>>> CHARTING <<<<<<<<< ANALYSIS >>>>>>>>>>>>>>> # >>>>>>>> CHARTING <<<<<<<<< ANALYSIS >>>>>>>>>>>>>>>
################ ---------------------- ##################### # >>>>>>>> CHARTING <<<<<<<<< ANALYSIS >>>>>>>>>>>>>>>

if False:  # Master switch: the whole post-processing stage below is currently disabled.

    # One toggle per optional step; all are off.
    charting_flag = analysis_flag = recommendation_flag = follow_up_flag = False

    if charting_flag:
        # The helper hands back three values, unpacked here as image path, image name and HTML content.
        (
            saved_image_file_path,
            saved_image_name,
            html_content,
        ) = hichart_module.get_highchart_info(user_prompt_to_use, df, user_id)

    if analysis_flag:
        analysis_chain_response = analysis_chain_manager.run_analysis_chain(user_prompt_to_use, df)

    if recommendation_flag:
        # Third argument is passed as an empty string, exactly as before.
        recommendations = recommendation_module.run_dynamic_recommendation_chain(user_prompt_to_use, df, '')
        print(f'\n\t >>>>>>>>  Recommendations: <<<<<<<<<\n{recommendations}')

    if follow_up_flag:
        # Collect the text of this user's HumanMessage entries and keep only the latest 8,
        # which are passed on as `past_questions`.
        previous_questions = [
            message.content
            for message in memory_manager.store[user_id].messages
            if isinstance(message, HumanMessage)
        ][-8:]

        suggestions = question_suggestions_generator.get_suggestions(
            user_prompt=user_prompt_to_use,
            generated_sql_query=generated_sql_query,
            table_to_use=table_to_use,
            table_columns=table_columns,
            past_questions=previous_questions,  # 👈 Prevents repeats
        )

        print(suggestions.to_list())

print('\n================================================================================\n')

### Clear all Messages ###
# Reset this user's stored chat messages at the end of the run.
# NOTE(review): this also drops the messages saved via add_user_message/add_ai_message
# earlier in this cell -- confirm that discarding them immediately is intended.
memory_manager.store[user_id].clear()
# print(generated_sql_query)

# Preview the result frame as the cell's output. In the visible part of this cell `df`
# is only assigned on the calculative (SQL) path, so guard against it being undefined
# (e.g. a general or memory-answered question on a fresh kernel) instead of raising
# NameError. When `df` is missing the expression is None and nothing is displayed.
df.head() if "df" in globals() else None


╭──────────────────────────────────────────────────────────────────────────────────────────────────╮
│─────────────────────────────────────────── User Query ───────────────────────────────────────────│
├──────────────────────────────────────────────────────────────────────────────────────────────────┤

whats the total spend for the last 12 months?


╰──────────────────────────────────────────────────────────────────────────────────────────────────╯


	 >>>>>>>> Classification: 2 | Calculative Query <<<<<<<<<
Table Selected: production_load.l_mt_cqo_p_event_with_outliers_chatbot_qa
[INFO] Running in Normal SQL Generation Mode...
[MODE 1:] No SQL provided. Generating SQL from prompt, column list, and table...
SQL Response: 
The total spend for the last 12 months is $547,315,674.60.

Data range: October 1, 2024 to August 1, 2025.




Unnamed: 0,total_spend,start_date,end_date
0,547315700.0,2024-10-01,2025-08-01


In [91]:
# Show the SQL text produced for the most recent question. `print` is used (instead of
# echoing the variable) so the query's line breaks render as real newlines.
# Depends on `generated_sql_query` having been assigned by the question-answering cell above.
print(generated_sql_query)

SELECT 
    SUM(emr_total_acquisition_cost) AS total_spend,
    TO_CHAR(MIN(TO_DATE(emr_discharge_date, 'YYYY-MM-DD')), 'YYYY-MM-DD') AS start_date,
    TO_CHAR(MAX(TO_DATE(emr_discharge_date, 'YYYY-MM-DD')), 'YYYY-MM-DD') AS end_date
FROM 
    production_load.l_mt_cqo_p_event_with_outliers_chatbot_qa
WHERE 
    TO_DATE(emr_discharge_date, 'YYYY-MM-DD') >= DATEADD(month, -12, '2025-09-19');


In [68]:
# sql_response_chain_result = sql_response_generator.get_sql_response(user_prompt_to_use, generated_sql_query, df)
# print(sql_response_chain_result)

###  - Some Questions - 

In [69]:
# ------------------------------------
# --- Robotic Analysis ----
# ------------------------------------

# find out the total number of Inpatient Robotic Cases
# whats the total number of Robotic cases in 2024
# total number of robotic and non robotic cases for inpatient?
# total number of robotic and non robotic cases for inpatient and Outpatient ?
# whats the total number of unique procedures in robotic analysis?
# compare the ssi rate between robotic vs non-robotic cases
# whats the total number of cases available in robotic, non-robotic and overall


# ======================================

# How does avg CPC, and total encounter vary across top 20 procedures for @market LORAIN and @market @YOUNGSTOWN for @service line @WOMEN'S HEALTH?
# Compare the top 10 primary procedures by avg cpc for @market LORAIN and @market @YOUNGSTOWN for @service line @WOMEN'S HEALTH?
# What is the '42321610-SPINAL SCREWS OR SCREW EXTENSIONS' UNSPSC product utilization across physicians used in '0SG10AJ-FUSION 2-4 L JT W INTBD FUS DEV, POST APPR A COL, OPEN' procedure?

#### Response

In [70]:
# sql_response_chain_result = sql_response_generator.get_sql_response(user_prompt_to_use, generated_sql_query, df)
# Markdown(sql_response_chain_result)

In [71]:
# What is the total number of robotic and non-robotic cases for both inpatient and outpatient settings?
# How many unique procedures are included in the robotic procedure analysis?
# How does the SSI (Surgical Site Infection) rate compare between robotic and non-robotic cases?
# What is the total number of cases categorized as robotic, non-robotic, and overall?

# <h2 style="color: red; font-weight: bold;">🔥 ROUGH TASK 🔥</h2>
- Scratch cells used for correction and debugging purposes

### Save -> columns_info_to_generate_sql | Save -> prompt_to_use_for_complex_question

In [72]:
# # Save - > columns_info_to_generate_sql
# with open("../Rough - Notes/columns_info_to_generate_sql.txt", "w", encoding="utf-8") as f:
#     f.write(columns_info_to_generate_sql.strip())


# if prompt_to_use_for_complex_question:
#     with open("../Rough - Notes/prompt_to_use_for_complex_question.txt", "w", encoding="utf-8") as f:
#         f.write(prompt_to_use_for_complex_question.strip())


### Read from file -> columns_info_to_generate_sql | Read from file -> prompt_to_use_for_complex_question

In [73]:
# # Read from file - > columns_info_to_generate_sql
# with open("../Rough - Notes/columns_info_to_generate_sql.txt", "r", encoding="utf-8") as f:
#     columns_info_to_generate_sql = f.read()

# with open("../Rough - Notes/prompt_to_use_for_complex_question.txt", "r", encoding="utf-8") as f:
#     prompt_to_use_for_complex_question_new = f.read()

# print(user_prompt_to_use)

# # ------------------------------------------- # 
# # generated_sql_query, df = sql_query_generator.get_result_from_sql_query(user_prompt_to_use, columns_info_to_generate_sql, table_to_use)

# generated_sql_query, sf = generate_and_execute_sql(
#     user_prompt=user_prompt_to_use,
#     table_to_use=table_to_use,
#     table_columns=table_columns,
#     column_info_from_knowledge_base=column_info_from_knowledge_base,
#     prompt_to_use_for_complex_question=prompt_to_use_for_complex_question_new,
#     columns_info_to_generate_sql = columns_info_to_generate_sql
# )
# # ------------------------------------------- # 

# print(generated_sql_query)
# sf.head()

# # # ------------------------------------------- # 
# # sql_response_chain_result = sql_response_generator.get_sql_response(user_prompt_to_use, generated_sql_query, df)
# # Markdown(sql_response_chain_result)
# # # ------------------------------------------- # 

In [74]:
# print(generated_sql_query)

# Get Column Values from each table (Temporary Use Case)

In [75]:
# def get_column_value_descriptions(table_name: str, column_list: list, connection_pool, sample_limit: int = 20) -> dict:
#     """
#     Generates descriptions for each column in a table using SQL and a DB connection.

#     Args:
#         table_name (str): The table to query.
#         column_list (list): List of column names to evaluate.
#         connection_pool: Active DB connection pool (e.g., psycopg2.pool.SimpleConnectionPool).
#         sample_limit (int): Number of values to fetch from each column.

#     Returns:
#         dict: Column descriptions based on unique values.
#     """
#     column_descriptions = {}

#     conn = connection_pool.getconn()
#     try:
#         cursor = conn.cursor()
#         for col in column_list:
#             query = f"""
#                 SELECT DISTINCT {col}
#                 FROM {table_name}
#                 WHERE {col} IS NOT NULL
#                 LIMIT {sample_limit};
#             """

#             try:
#                 cursor.execute(query)
#                 results = cursor.fetchall()
#                 values = [r[0] for r in results if r[0] is not None]

#                 if not values:
#                     column_descriptions[col] = "No available data"
#                     continue

#                 # Determine data type
#                 sample_type = type(values[0])

#                 if all(isinstance(v, (int, float)) for v in values):
#                     if all(isinstance(v, int) for v in values):
#                         desc = "This will be a Positive Integer Value" if min(values) > 0 else "This will be a Non-Negative Integer Value"
#                     else:
#                         desc = "This will be a Positive Float Value" if min(values) > 0 else "This will be a Non-Negative Float Value"
#                 else:
#                     str_values = [str(v).strip() for v in values if str(v).strip()]
#                     unique_count = len(str_values)

#                     if unique_count > 10:
#                         desc = "Example Values:\n" + "\n".join([f"{i+1}. {v}" for i, v in enumerate(str_values[:2])])
#                     else:
#                         desc = ", ".join(str_values[:10])

#                 column_descriptions[col] = desc

#             except Exception as e:
#                 column_descriptions[col] = f"Error retrieving data: {str(e)}"

#         cursor.close()
#     finally:
#         connection_pool.putconn(conn)

#     return column_descriptions


In [76]:
# l_mt_opportunity_case_procedure_p_event_with_outliers_v2_chatbot, l_mt_opportunity_case_procedure_with_outliers_v2_chatbot, production_load.l_mt_opportunity_case_procedure_p_event_w_outliers_rbt

In [77]:
# for tb in table_info_from_db['columns'][2]:
#     if tb not in table_info_from_db['columns'][0]:
#         print(tb)

In [78]:
# # # Table 1
# # table_name = list(column_dictionary.keys())[0]
# # columns = column_dictionary[table_name]

# ## Table 2
# # table_name = list(column_dictionary.keys())[1]
# # columns = column_dictionary[table_name]


# # Table 3
# table_name = list(column_dictionary.keys())[2]
# columns = column_dictionary[table_name]



# #############################################################
# # print(f'Table Running: {table_name}')
# column_info = get_column_value_descriptions(table_name, columns, connection_pool)

# for col, desc in column_info.items():
#     print(f"\n🧠 {col}:\n{desc}")


# ## Saving the File
# sf = pd.DataFrame([
#     {"column_name": k, "description": v}
#     for k, v in column_info.items()
# ])
# # sf.to_excel('temp_use_case/Pharma_Poc_Data_Column_Details.xlsx', index=False)
# sf.to_excel('temp_use_case/ras_table.xlsx', index=False)
# # sf.to_excel('temp_use_case/Drug_Shortages_Data_Column_Details.xlsx', index=False)