In [None]:
%%capture --no-stderr
%pip install -U --quiet langchain langgraph langchain_google_genai
%pip install -U --quiet saspy

In [1]:
import warnings
# NOTE(review): this silences *every* warning for the whole kernel (including
# library deprecation notices); consider narrowing it to specific categories.
warnings.filterwarnings('ignore')
import math
import os
import getpass
import re
from collections import deque
from typing import Optional, Dict, Any, List, Tuple
from typing_extensions import TypedDict
import pandas as pd
from gradio_client import Client
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.output_parsers import PydanticOutputParser
from langchain_core.output_parsers import StrOutputParser
import google.generativeai as genai
from pydantic import BaseModel, Field
import logging
import saspy
import json
# Constants
# NOTE(review): END is not referenced in the cells visible here; presumably a
# terminal-node sentinel for a later graph/loop — confirm before relying on it.
END = "end"

In [2]:
# Configure the google.generativeai SDK without hard-coding a secret in the
# notebook: reuse GOOGLE_API_KEY from the environment, prompting (no echo)
# only when it is not already set.
if not os.environ.get("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("GOOGLE_API_KEY")
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

In [3]:
def _set_if_undefined(var: str) -> None:
    """Ensure environment variable ``var`` holds a non-empty value.

    An existing non-empty value is left untouched; otherwise the value is read
    interactively with ``getpass`` (input is not echoed) and exported.

    Args:
        var: Name of the environment variable; also used as the prompt text.
    """
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(var)

# Prompt for GOOGLE_API_KEY only if it is not already set in the environment.
_set_if_undefined("GOOGLE_API_KEY")

GOOGLE_API_KEY ········


In [4]:
# Configure root logging at INFO level. The format string contains a literal
# newline, so every record is followed by a dashed separator line.
# NOTE(review): basicConfig is a no-op if the root logger already has handlers
# (e.g. added by an earlier import); `force=True` would make this deterministic.
logging.basicConfig(level=logging.INFO, format="""%(asctime)s - %(levelname)s - %(message)s
-------------------------------------------""")

In [5]:
# Load the database/column metadata that drives metadata selection.
# The path is relative to the kernel's current working directory.
metadata_file_path = 'metadata.json'
with open(metadata_file_path, encoding='utf-8') as f:
    raw_metadata_text = f.read()
metadata_dict = json.loads(raw_metadata_text)

In [None]:
# Open a SAS session with the 'iomwin' configuration defined in sascfg.py.
# The global `sas` created here is used by execute_sas_code and extract_data.
sas = saspy.SASsession(cfgname='iomwin', cfgfile='sascfg.py')
# NOTE(review): this message is logged whenever the constructor returns; it does
# not itself check connectivity.
logging.info("SAS session initialized successfully.")

In [9]:
from typing import Any, Dict, List, Tuple, Optional
import json
import logging
from dataclasses import dataclass

from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.prompts import PromptTemplate
from pydantic import BaseModel

class DatabaseMetadata:
    """Plain container pairing a database name with its free-text description."""

    def __init__(self, name: str, description: str):
        # Both values are stored as given; no validation is performed.
        self.name, self.description = name, description

class ColumnMetadata:
    """Plain container pairing a column name with its free-text description."""

    def __init__(self, name: str, description: str):
        # Both values are stored as given; no validation is performed.
        self.name, self.description = name, description

class VerificationResult(BaseModel):
    """Outcome of one LLM verification pass over a database or column selection.

    Fields mirror the keys produced by the verifier's structured output parser;
    ``raw_response`` keeps the unparsed LLM text so it can be fed back into the
    next selection round.
    """

    is_valid: bool
    analysis: Dict[str, Any]
    improvements: Dict[str, Any]
    explanation: str
    raw_response: Optional[str] = None

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "VerificationResult":
        """Build a result from parsed LLM JSON, tolerating missing *or null* fields.

        ``dict.get(key, default)`` only applies the default when the key is
        absent. An LLM reply such as ``"improvements": null`` would otherwise
        pass ``None`` to a non-Optional ``Dict`` field and fail validation, so
        nulls are mapped to the same defaults as missing keys.

        Args:
            data: Parsed LLM output (may be ``None`` or empty).

        Returns:
            A populated ``VerificationResult``.
        """
        data = data or {}
        is_valid = data.get("is_valid")
        return cls(
            is_valid=False if is_valid is None else is_valid,
            analysis=data.get("analysis") or {},
            improvements=data.get("improvements") or {},
            explanation=data.get("explanation") or "",
            raw_response=data.get("raw_response"),
        )

class ResponseSchemas:
    """Centralized response schemas configuration.

    ``ResponseSchema`` declares only ``name``, ``description`` and ``type``, and
    ``StructuredOutputParser.get_format_instructions()`` renders just those three
    per field. A nested ``properties=[...]`` keyword is not a declared field, so
    it never reaches the LLM. Nested keys are therefore spelled out inside the
    parent field's description via ``_with_keys``.
    """

    @staticmethod
    def _with_keys(description: str, keys: List[Tuple[str, str, str]]) -> str:
        """Append ``(name, type, description)`` triples describing nested keys to ``description``."""
        rendered = "; ".join(f'"{name}" ({kind}): {text}' for name, kind, text in keys)
        return f"{description}. Keys: {rendered}"

    @staticmethod
    def get_db_schemas() -> List[ResponseSchema]:
        """Schema for stage-1 database selection: a list of database/reason objects."""
        return [
            ResponseSchema(
                name="databases",
                type="list",
                description=ResponseSchemas._with_keys(
                    "List of databases relevant to the query, each item an object",
                    [
                        ("database", "string", "Name of the database"),
                        ("reason", "string", "Reason for selecting this database"),
                    ],
                ),
            )
        ]

    @staticmethod
    def get_column_analysis_schemas() -> List[ResponseSchema]:
        """Schema for the per-database query analysis (stage 2a)."""
        return [
            ResponseSchema(name="key_entities", type="list", description="List of key entities being queried"),
            ResponseSchema(name="filter_conditions", type="list", description="List of filter conditions in the query"),
            ResponseSchema(name="required_calculations", type="list", description="List of required calculations"),
            ResponseSchema(name="related_to_description", type="list", description="Parts of the query related to the database description")
        ]

    @staticmethod
    def get_column_selection_schemas() -> List[ResponseSchema]:
        """Schema for the per-database column selection (stage 2b)."""
        return [
            ResponseSchema(name="columns", type="list", description="List of selected column names")
        ]

    @staticmethod
    def get_verification_schemas() -> List[ResponseSchema]:
        """Schema for database verification results."""
        return [
            ResponseSchema(name="is_valid", type="boolean", description="Whether the database selection is valid"),
            ResponseSchema(
                name="analysis",
                type="object",
                description=ResponseSchemas._with_keys(
                    "Analysis of the database selection",
                    [
                        ("temporal_coverage", "string", "Analysis of temporal coverage"),
                        ("processing_requirements", "string", "Analysis of processing requirements"),
                        ("completeness", "string", "Analysis of database completeness"),
                    ],
                ),
            ),
            ResponseSchema(
                name="improvements",
                type="object",
                description=ResponseSchemas._with_keys(
                    "Suggested improvements to the database selection",
                    [
                        ("databases_to_add", "object", "Databases to add, as an object mapping database name to reason"),
                        ("databases_to_remove", "object", "Databases to remove, as an object mapping database name to reason"),
                    ],
                ),
            ),
            ResponseSchema(name="explanation", type="string", description="Explanation of the verification result")
        ]

    @staticmethod
    def get_column_verification_schemas() -> List[ResponseSchema]:
        """Schema for column verification results."""
        return [
            ResponseSchema(name="is_valid", type="boolean", description="Indicates if the column selection is valid"),
            ResponseSchema(
                name="analysis",
                type="object",
                description=ResponseSchemas._with_keys(
                    "Analysis of the column selection",
                    [
                        ("mandatory_columns", "string", "Evaluation of mandatory column presence"),
                        ("query_requirements", "string", "Verification of query requirement fulfillment"),
                        ("rule_compliance", "string", "Assessment of adherence to specific rules"),
                    ],
                ),
            ),
            ResponseSchema(
                name="improvements",
                type="object",
                description=ResponseSchemas._with_keys(
                    "Suggested improvements for column selection",
                    [
                        ("columns_to_add", "array", "List of column names (strings) to be added"),
                        ("columns_to_remove", "array", "List of column names (strings) to be removed"),
                    ],
                ),
            ),
            ResponseSchema(name="explanation", type="string", description="Detailed explanation of the verification outcome")
        ]

class PromptTemplates:
    """Centralized prompt template configuration.

    Each method returns a raw f-string-style template that is wrapped in a
    ``PromptTemplate`` elsewhere. Literal braces in the JSON example are
    doubled (``{{`` / ``}}``) so they survive formatting.
    """

    @staticmethod
    def get_db_selection_template() -> str:
        """Stage-1 prompt: choose the databases relevant to the question.

        Placeholders: ``question``, ``database_descriptions``,
        ``previous_verification_feedback``. The expected JSON layout is written
        out inline; this template has no ``format_instructions`` placeholder.
        """
        return """
You are a data analyst tasked with identifying the relevant databases to answer a specific query.

**Question:** {question}

**Available Databases:**
{database_descriptions}

{previous_verification_feedback}

**Instructions:**
<redacted for privacy reasons>
     
8. **Response Format:**
   - Respond strictly in a JSON dictionary: with the structure:  
   {{
       "databases" : [
          {{"database": "database_name1", "reason": "reason for selection"}},
          {{"database": "database_name2", "reason": "reason for selection"}},
          ...
       ]
   }}
"""

    @staticmethod
    def get_column_analysis_template() -> str:
        """Stage-2a prompt: extract entities, filters and calculations for one database.

        Placeholders: ``question``, ``db_name``, ``reason``, ``format_instructions``.
        """
        return """
You are a data analyst. Analyze the following query to extract key information.

**Question:** {question}
**Database:** {db_name}
**Reason for Selecting Database:** {reason}

**Tasks:**
1. Identify the **key entities** being queried (e.g., claims, policies, etc.).
2. List the **filter conditions** mentioned in the query (e.g., date ranges, status types, specific codes).
3. Determine the **required calculations** (e.g., reglements, reserves).
4. Identify which parts of the query relate to the **database description**.

{format_instructions}
"""

    @staticmethod
    def get_column_selection_template() -> str:
        """Stage-2b prompt: pick the columns of ``db_name`` given the stage-2a analysis.

        Placeholders: ``analysis``, ``db_name``, ``column_info``,
        ``previous_verification_feedback``, ``format_instructions``.
        """
        return """
Based on the analysis below, select the relevant columns from the database.

Analysis from Prompt 1:
{analysis}

Columns in {db_name}:
{column_info}

{previous_verification_feedback}

**Specific Rules:**
<redacted for privacy reasons>
{format_instructions}
"""

    @staticmethod
    def get_verification_template() -> str:
        """Template for verifying database selections.

        Placeholders: ``question``, ``selected_dbs``, ``available_databases``,
        ``format_instructions``. Note the body lines are indented inside the
        string, so the LLM receives them with leading spaces.
        """
        return """
    Verify the database selection for the following query:

    Question: {question}

    Current Selection:
    {selected_dbs}

    Available Databases:
    {available_databases}

    Verification Tasks:
    1. Check temporal coverage:
    - Are all required years covered?
    - Are any unnecessary years included?
    2. Verify processing requirements:
    - Are databases for relevant processing periods included?
    3. Assess database completeness:
    - Are all necessary databases for the query included?
    
    {format_instructions}
    """

    @staticmethod
    def get_column_verification_template() -> str:
        """Template for verifying column selections within a database.

        Placeholders: ``question``, ``db_name``, ``selected_columns``,
        ``available_columns``, ``format_instructions``.
        """
        return """
    Verify the column selection for the following query and database:

    Question: {question}
    Database: {db_name}

    Current Selection:
    {selected_columns}

    Available Columns:
    {available_columns}

    Verification Tasks:
    <redacted for privacy reasons>


    {format_instructions}
    """

    @staticmethod
    def get_question_expansion_template() -> str:
        """Prompt that rewrites the user question into a more explicit form.

        Placeholder: ``question``. The template ends with an
        ``**Expanded Question:**`` header so the model continues with the
        rewritten question.
        """
        return """
You are a helpful assistant tasked with clarifying and expanding user questions to make them more precise and easier to understand in the context of database metadata selection.

**Original Question:** {question}

**Instructions:**

1. **Analyze the Question:** Carefully read the original question and identify any ambiguities, implicit assumptions, or missing information.
2. **Expand and Clarify:** Rephrase the question to make it more explicit and precise.
    - Explicitly state any assumptions that should be made (e.g., the current year is 2024).
    - Resolve any ambiguities by using more specific terminology or adding context.
    - Incorporate relevant domain knowledge or rules related to database selection.
    - If specific databases or terms are mentioned, ensure they are accurately represented in the expanded question.
3. **Consider the Following Rules:**
    <redacted for privacy reasons>

4. **Output Format:** Provide the expanded and clarified question as a single string.

**Example:**

<redacted for privacy reasons>

**Expanded Question:**
"""

class MetadataVerifier:
    """Handles verification of database and column selections.

    Each ``verify_*`` method sends one prompt to the chat model, parses the JSON
    reply with a ``StructuredOutputParser`` and returns ``(is_valid, result)``.
    Any failure (prompt formatting, model call, parsing) is logged and turned
    into an invalid ``VerificationResult`` instead of being raised.
    """

    def __init__(self, model: ChatGoogleGenerativeAI):
        """Build parsers and prompt templates for both verification kinds.

        Args:
            model: Chat model used for every verification call.
        """
        self.model = model
        self.verification_output_parser = StructuredOutputParser.from_response_schemas(
            ResponseSchemas.get_verification_schemas()
        )
        self.verification_prompt = PromptTemplate(
            template=PromptTemplates.get_verification_template(),
            input_variables=["question", "selected_dbs", "available_databases"],
            partial_variables={"format_instructions": self.verification_output_parser.get_format_instructions()}
        )
        self.column_verification_output_parser = StructuredOutputParser.from_response_schemas(
            ResponseSchemas.get_column_verification_schemas()
        )
        self.column_verification_prompt = PromptTemplate(
            template=PromptTemplates.get_column_verification_template(),
            input_variables=["question", "db_name", "selected_columns", "available_columns"],
            partial_variables={"format_instructions": self.column_verification_output_parser.get_format_instructions()}
        )

    def verify_databases(self, question: str, selected_dbs: Dict[str, str],
                            metadata_dict: Dict[str, Any]) -> Tuple[bool, VerificationResult]:
        """Ask the LLM whether ``selected_dbs`` is a sound choice for ``question``.

        Args:
            question: The user's question.
            selected_dbs: Mapping of selected database name -> selection reason.
            metadata_dict: Full metadata; every value must carry a ``'description'``.

        Returns:
            ``(is_valid, result)``. On any error, ``(False, result)`` whose
            ``explanation`` starts with ``"Verification failed:"``.
        """
        try:
            available_databases = {name: info['description'] for name, info in metadata_dict.items()}
            result = self._run_verification(
                self.verification_prompt,
                self.verification_output_parser,
                question=question,
                selected_dbs=self._to_json(selected_dbs),
                available_databases=self._to_json(available_databases),
            )
            logging.info(f"Database verification: {'VALID' if result.is_valid else 'INVALID'}")
            return result.is_valid, result
        except Exception as e:
            logging.error(f"Database verification error: {e}")
            return False, self._failed_result(e)

    def verify_columns(self, question: str, db_name: str, selected_columns: List[str],
                      db_metadata: Dict[str, Any]) -> Tuple[bool, VerificationResult]:
        """Ask the LLM whether ``selected_columns`` of ``db_name`` answer ``question``.

        Args:
            question: The user's question.
            db_name: Database whose columns are being checked.
            selected_columns: Column names currently selected.
            db_metadata: Metadata of ``db_name``; must carry a ``'columns'`` entry.

        Returns:
            ``(is_valid, result)``. On any error, ``(False, result)`` whose
            ``explanation`` starts with ``"Verification failed:"``.
        """
        try:
            result = self._run_verification(
                self.column_verification_prompt,
                self.column_verification_output_parser,
                question=question,
                db_name=db_name,
                selected_columns=self._to_json(selected_columns),
                available_columns=self._to_json(db_metadata['columns']),
            )
            logging.info(f"Column verification for {db_name}: {'VALID' if result.is_valid else 'INVALID'}")
            return result.is_valid, result
        except Exception as e:
            logging.error(f"Column verification error: {e}")
            return False, self._failed_result(e)

    def _run_verification(self, prompt: PromptTemplate, parser: StructuredOutputParser,
                          **prompt_values: Any) -> VerificationResult:
        """Format ``prompt``, invoke the model and parse its reply. Exceptions propagate."""
        _input = prompt.format_prompt(**prompt_values)
        response = self.model.invoke(_input.to_messages())
        result = VerificationResult.from_dict(parser.parse(response.content))
        # Keep the unparsed text so selectors can echo it back as feedback.
        result.raw_response = response.content
        return result

    @staticmethod
    def _to_json(value: Any) -> str:
        """Pretty-print ``value`` as JSON, keeping non-ASCII characters readable."""
        return json.dumps(value, indent=2, ensure_ascii=False)

    @staticmethod
    def _failed_result(error: Exception) -> VerificationResult:
        """Invalid result used when verification could not be completed."""
        return VerificationResult(
            is_valid=False,
            analysis={},
            improvements={},
            explanation=f"Verification failed: {str(error)}",
            raw_response=None
        )
            
class MetadataSelector:
    """Main class for handling metadata selection process.

    Runs two LLM-driven stages, each a select -> verify -> refine loop bounded by
    ``max_iterations``:

    1. choose the databases relevant to the question;
    2. for each chosen database, choose the relevant columns.

    LLM calls whose inputs do not change between retries (question expansion,
    per-database query analysis) are memoized per instance, so retries only
    repeat the selection step.
    """

    def __init__(self, model_name: str = "gemini-1.5-flash", max_iterations: int = 5):
        """Create the chat model, the verifier, and all parsers/prompts.

        Args:
            model_name: Gemini model name passed to ``ChatGoogleGenerativeAI``.
            max_iterations: Maximum select/verify rounds for the database stage
                and, separately, for each database in the column stage.
        """
        self.model = ChatGoogleGenerativeAI(
            model=model_name,
            temperature=0,
            convert_system_message_to_human=True
        )
        self.verifier = MetadataVerifier(self.model)
        self.max_iterations = max_iterations

        # Memo tables for loop-invariant LLM calls (failures are never cached).
        self._expanded_questions: Dict[str, str] = {}
        self._column_analyses: Dict[Tuple[str, str, str], Dict[str, Any]] = {}

        self._init_parsers()
        self._init_prompts()

    def _init_parsers(self):
        """Initialize all output parsers."""
        self.db_output_parser = StructuredOutputParser.from_response_schemas(
            ResponseSchemas.get_db_schemas()
        )
        self.column_analysis_parser = StructuredOutputParser.from_response_schemas(
            ResponseSchemas.get_column_analysis_schemas()
        )
        self.column_selection_parser = StructuredOutputParser.from_response_schemas(
            ResponseSchemas.get_column_selection_schemas()
        )

    def _init_prompts(self):
        """Initialize all prompt templates.

        ``previous_verification_feedback`` is declared for the two templates
        that reference it, so the declared inputs match what callers must pass.
        """
        self.db_prompt = PromptTemplate(
            template=PromptTemplates.get_db_selection_template(),
            input_variables=["question", "database_descriptions", "previous_verification_feedback"],
            partial_variables={"format_instructions": self.db_output_parser.get_format_instructions()}
        )
        self.column_analysis_prompt = PromptTemplate(
            template=PromptTemplates.get_column_analysis_template(),
            input_variables=["question", "db_name", "reason"],
            partial_variables={"format_instructions": self.column_analysis_parser.get_format_instructions()}
        )
        self.column_selection_prompt = PromptTemplate(
            template=PromptTemplates.get_column_selection_template(),
            input_variables=["analysis", "db_name", "column_info", "previous_verification_feedback"],
            partial_variables={"format_instructions": self.column_selection_parser.get_format_instructions()}
        )

    def _expand_question(self, question: str) -> str:
        """Return an LLM rewrite of ``question`` that makes implicit assumptions explicit.

        Memoized per question: the database loop needs it on every retry and
        the expansion does not depend on verifier feedback.
        """
        cached = self._expanded_questions.get(question)
        if cached is not None:
            return cached
        prompt = PromptTemplate(
            template=PromptTemplates.get_question_expansion_template(),
            input_variables=["question"],
        )
        _input = prompt.format_prompt(question=question)
        expanded = self.model.invoke(_input.to_messages()).content
        self._expanded_questions[question] = expanded
        return expanded

    def select_metadata(self, question: str, metadata_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Main method to select metadata in two stages.

        Args:
            question: The user's question.
            metadata_dict: Mapping of database name -> ``{'description', 'columns', ...}``.

        Returns:
            ``{db_name: {'description': ..., 'columns': {col: info}}}`` for the
            selected databases, or ``{'error': message}`` on failure.
        """
        try:
            selected_dbs = self._select_and_verify_databases(question, metadata_dict)
            if not selected_dbs:
                return {"error": "Failed to select valid databases"}

            selected_columns = self._select_and_verify_columns(question, selected_dbs, metadata_dict)

            return self._build_final_metadata(selected_dbs, selected_columns, metadata_dict)
        except Exception as e:
            logging.error(f"Metadata selection failed: {e}")
            return {"error": str(e)}

    def _select_and_verify_databases(self, question: str, metadata_dict: Dict[str, Any]) -> Dict[str, str]:
        """Select databases, verify them, and retry with feedback until valid.

        Returns the first verified selection. If no round verifies, returns the
        last (refined, unverified) selection, which may be empty.
        """
        selected_dbs: Dict[str, str] = {}
        previous_verification: Optional[VerificationResult] = None

        for iteration in range(self.max_iterations):
            logging.info(f"Database selection iteration {iteration + 1}")

            selected_dbs = self._select_databases(question, metadata_dict, previous_verification)
            if not selected_dbs:
                continue

            is_valid, feedback = self.verifier.verify_databases(question, selected_dbs, metadata_dict)
            if is_valid:
                logging.info("Database selection verified successfully")
                return selected_dbs

            # The refined set only survives if every round fails; otherwise the
            # next round re-selects from scratch, guided by this feedback.
            if isinstance(feedback, VerificationResult):
                selected_dbs = self._refine_databases(selected_dbs, feedback.improvements)
                previous_verification = feedback

        logging.warning("Max iterations reached in database selection")
        return selected_dbs

    def _select_databases(self, question: str, metadata_dict: Dict[str, Any],
                          previous_verification: Optional[VerificationResult] = None) -> Dict[str, str]:
        """Ask the LLM for relevant databases; returns ``{}`` on any error."""
        try:
            expanded_question = self._expand_question(question)
            _input = self.db_prompt.format_prompt(
                question=expanded_question,
                database_descriptions=self._format_database_descriptions(metadata_dict),
                previous_verification_feedback=self._format_feedback(previous_verification),
            )
            response = self.model.invoke(_input.to_messages())
            return self._parse_database_selection(response)
        except Exception as e:
            logging.error(f"Database selection error: {e}")
            return {}

    def _select_and_verify_columns(
        self,
        question: str,
        selected_dbs: Dict[str, str],
        metadata_dict: Dict[str, Any]
    ) -> Dict[str, List[str]]:
        """Select and verify columns for every selected database present in ``metadata_dict``.

        Each database gets up to ``max_iterations`` rounds. If none verifies,
        the last (refined, unverified) column list is kept.
        """
        final_columns: Dict[str, List[str]] = {}

        for db_name in selected_dbs:
            if db_name not in metadata_dict:
                logging.warning(f"Database '{db_name}' not found in metadata_dict.")
                continue

            selected_cols: List[str] = []
            previous_verification: Optional[VerificationResult] = None

            for iteration in range(self.max_iterations):
                logging.info(f"Column selection iteration {iteration + 1} for {db_name}")

                selected_cols = self._select_columns(
                    question,
                    db_name,
                    selected_dbs[db_name],
                    metadata_dict[db_name],
                    previous_verification,
                )
                if not selected_cols:
                    continue

                is_valid, feedback = self.verifier.verify_columns(
                    question, db_name, selected_cols, metadata_dict[db_name]
                )
                if is_valid:
                    logging.info(f"Column selection verified successfully for {db_name}")
                    final_columns[db_name] = selected_cols
                    break

                if isinstance(feedback, VerificationResult):
                    selected_cols = self._refine_columns(selected_cols, feedback.improvements)
                    previous_verification = feedback
            else:
                # Loop exhausted without a verified selection (the for-else runs
                # only when no `break` happened).
                logging.warning(
                    f"Max iterations reached in column selection for {db_name}"
                )
                final_columns[db_name] = selected_cols

        return final_columns

    def _select_columns(
        self,
        question: str,
        db_name: str,
        reason: str,
        db_metadata: Dict[str, Any],
        previous_verification: Optional[VerificationResult] = None,
    ) -> List[str]:
        """Analyze the query, then pick columns; returns ``[]`` on any error."""
        try:
            analysis = self._perform_column_analysis(question, db_name, reason)
            return self._perform_column_selection(
                analysis, db_name, db_metadata, self._format_feedback(previous_verification)
            )
        except Exception as e:
            logging.error(f"Column selection error for {db_name}: {e}")
            return []

    @staticmethod
    def _format_feedback(previous_verification: Optional[VerificationResult]) -> str:
        """Render verifier feedback for the next selection prompt ('' when there is none)."""
        if previous_verification is None:
            return ""
        return (
            "Previous Verification Feedback:\n"
            f"- Analysis: {previous_verification.analysis}\n"
            f"- Explanation: {previous_verification.explanation}\n"
            f"- Raw Response: {previous_verification.raw_response}\n"
        )

    def _format_database_descriptions(self, metadata_dict: Dict[str, Any]) -> str:
        """One ``- name: description`` line per database."""
        return "\n".join(
            f"- {db_name}: {db_info['description']}"
            for db_name, db_info in metadata_dict.items()
        )

    def _get_llm_response(self, question: str, database_descriptions: str,
                          previous_verification_feedback: str = "") -> str:
        """Run the database-selection prompt once and return the raw LLM text.

        The template references ``previous_verification_feedback``, so it is
        always supplied (empty by default); omitting it made formatting fail.
        """
        _input = self.db_prompt.format_prompt(
            question=question,
            database_descriptions=database_descriptions,
            previous_verification_feedback=previous_verification_feedback,
        )
        return self.model.invoke(_input.to_messages()).content

    def _parse_database_selection(self, response: Any) -> Dict[str, str]:
        """Parse an LLM message (or its text) into ``{database: reason}``; ``{}`` on error.

        Items without a database name are skipped; a missing reason becomes ''.
        """
        try:
            content = getattr(response, "content", response)
            parsed_output = self.db_output_parser.parse(content)
            return {
                item["database"]: item.get("reason", "")
                for item in (parsed_output.get("databases") or [])
                if isinstance(item, dict) and item.get("database")
            }
        except Exception as e:
            logging.error(f"Error parsing database selection response: {e}")
            return {}

    def _perform_column_analysis(
        self, question: str, db_name: str, reason: str
    ) -> Dict[str, Any]:
        """Extract entities/filters/calculations for ``db_name``; memoized per inputs."""
        # str() keeps the key hashable even if an LLM-supplied reason is not a string.
        key = (question, str(db_name), str(reason))
        cached = self._column_analyses.get(key)
        if cached is not None:
            return cached
        _input = self.column_analysis_prompt.format_prompt(
            question=question, db_name=db_name, reason=reason
        )
        response = self.model.invoke(_input.to_messages())
        analysis = self.column_analysis_parser.parse(response.content)
        self._column_analyses[key] = analysis
        return analysis

    def _perform_column_selection(
        self, analysis: Dict[str, Any], db_name: str, db_metadata: Dict[str, Any], feedback_info: str = ""
    ) -> List[str]:
        """Ask the LLM for the columns of ``db_name``, optionally with verifier feedback."""
        # ensure_ascii=False matches the verifier and keeps accented descriptions readable.
        column_info = json.dumps(db_metadata["columns"], indent=2, ensure_ascii=False)
        _input = self.column_selection_prompt.format_prompt(
            analysis=analysis,
            db_name=db_name,
            column_info=column_info,
            previous_verification_feedback=feedback_info,
        )
        response = self.model.invoke(_input.to_messages())
        parsed_output = self.column_selection_parser.parse(response.content)
        return parsed_output["columns"]

    def _refine_databases(
        self, current_dbs: Dict[str, str], improvements: Dict[str, Any]
    ) -> Dict[str, str]:
        """Apply ``databases_to_add`` / ``databases_to_remove`` from verifier feedback.

        ``databases_to_add`` must be a ``{name: reason}`` dict. ``databases_to_remove``
        may be a dict keyed by name or a list of names. Malformed feedback is
        ignored and ``current_dbs`` is returned unchanged.
        """
        if not isinstance(improvements, dict):
            return current_dbs
        refined_dbs = dict(current_dbs)

        to_add = improvements.get("databases_to_add", {})
        if isinstance(to_add, dict):
            refined_dbs.update(to_add)

        to_remove = improvements.get("databases_to_remove", {})
        if isinstance(to_remove, (dict, list)):
            for db in to_remove:
                if isinstance(db, str):
                    refined_dbs.pop(db, None)

        return refined_dbs

    def _refine_columns(
        self, current_columns: List[str], improvements: Dict[str, Any]
    ) -> List[str]:
        """Apply ``columns_to_add`` / ``columns_to_remove`` from verifier feedback.

        Only string column names are added. Duplicates are removed while
        keeping first-seen order, so the result is deterministic (``set`` order
        is not). Malformed feedback returns ``current_columns`` unchanged.
        """
        if not isinstance(improvements, dict):
            return current_columns
        refined_columns = list(current_columns)

        to_add = improvements.get("columns_to_add", [])
        if isinstance(to_add, list):
            refined_columns.extend(col for col in to_add if isinstance(col, str))

        to_remove = improvements.get("columns_to_remove", [])
        if isinstance(to_remove, list):
            refined_columns = [col for col in refined_columns if col not in to_remove]

        return list(dict.fromkeys(refined_columns))

    def _build_final_metadata(
        self,
        selected_dbs: Dict[str, str],
        selected_columns: Dict[str, List[str]],
        metadata_dict: Dict[str, Any],
    ) -> Dict[str, Any]:
        """Assemble the selected databases with only their selected, known columns."""
        final_metadata: Dict[str, Any] = {}
        for db_name in selected_dbs:
            if db_name not in metadata_dict:
                continue
            all_columns = metadata_dict[db_name].get("columns", {})
            chosen = selected_columns.get(db_name, [])
            final_metadata[db_name] = {
                "description": metadata_dict[db_name]["description"],
                "columns": {col: all_columns[col] for col in chosen if col in all_columns},
            }
        return final_metadata

def format_metadata_to_string(metadata_dict: Dict[str, Any]) -> str:
    """Format selected metadata into a human-readable, indented text block.

    Args:
        metadata_dict: Mapping of database name -> ``{'description', 'columns'}``.
            A non-dict value (e.g. the ``{'error': msg}`` returned by
            ``MetadataSelector.select_metadata``) is rendered as an error line.
            Column info may be a dict with a ``'description'`` or a plain value.

    Returns:
        The formatted text; ``""`` for an empty mapping.
    """
    lines: List[str] = []
    for db_name, db_info in metadata_dict.items():
        lines.append(f"\n{db_name}:\n")
        if not isinstance(db_info, dict):
            lines.append(f"  Error: {db_info}\n")
            continue
        lines.append(f"  Description: {db_info.get('description', 'N/A')}\n")
        lines.append("  Columns:\n")
        # `or {}` also covers an explicit None for 'columns'.
        for col_name, col_info in (db_info.get('columns') or {}).items():
            if isinstance(col_info, dict):
                col_description = col_info.get('description', 'N/A')
            else:
                col_description = col_info
            lines.append(f"    - {col_name}: {col_description}\n")
    return "".join(lines)

In [10]:
# --- Utility Functions ---
def execute_sas_code(sas_code: str) -> str:
    """Submit ``sas_code`` to the global ``sas`` session and return its log.

    Args:
        sas_code: SAS program text to run.

    Returns:
        The ``'LOG'`` entry of the submission result, or the exception text if
        anything fails (the error is logged, not raised).
    """
    logging.info("Executing SAS code...")
    try:
        submission = sas.submit(sas_code)
        logging.info("SAS code executed. Returning log.")
        log_text = submission['LOG']
    except Exception as exc:
        logging.error(f"Error executing SAS code: {exc}")
        return str(exc)
    return log_text

def _dataset_name_from_prediction(prediction: Any) -> str:
    """Normalize a Gradio prediction into a bare SAS dataset name ('' if none)."""
    # Client.predict returns the bare value for a single-output endpoint and a
    # tuple for multi-output ones; indexing a str with [0] would keep only its
    # first character, so only sequences are unpacked.
    if isinstance(prediction, (list, tuple)):
        prediction = prediction[0] if prediction else ""
    # LLM output may carry whitespace, quotes/backticks, or a WORK. prefix.
    name = str(prediction or "").strip().strip("`'\"").strip()
    if name.upper().startswith("WORK."):
        name = name[len("WORK."):]
    return name


def extract_data(sas_code: str, client) -> None:
    """Run ``sas_code`` and export the dataset it produces to ``<name>.xlsx``.

    The final dataset name is obtained by asking ``client`` (the
    ``/generation_code`` endpoint) to read it from the code. All errors are
    logged rather than raised.

    Args:
        sas_code: SAS program submitted to the global ``sas`` session.
        client: Gradio ``Client`` (or any object with a compatible ``predict``).
    """
    logging.info("Extracting data...")
    try:
        sas.submit(sas_code)
        df_name_result = client.predict(
            query=f"Extract final dataframe name from, dont include 'WORK.', just the raw name:\n{sas_code}",
            api_name="/generation_code"
        )

        df_name = _dataset_name_from_prediction(df_name_result)
        if not df_name:
            logging.warning("Could not determine dataframe name for extraction.")
            return

        df = sas.sd2df(df_name)
        if df is None:
            logging.warning(f"Dataframe {df_name} could not be read from SAS.")
        elif df.empty:
            logging.warning(f"Dataframe {df_name} is empty.")
        else:
            excel_file = f"{df_name}.xlsx"
            df.to_excel(excel_file, index=False)
            logging.info(f"Data extracted to {excel_file}")
    except Exception as e:
        logging.error(f"Error during data extraction: {e}")

In [11]:
# --- Prompt Templates and System Messages ---
def get_thinker_system_prompt(necessary_metadata) -> str:
    """Build the StarThinker system prompt (high-level strategies only, no code).

    Args:
        necessary_metadata: Metadata embedded verbatim at the end of the prompt
            via f-string interpolation (presumably the text produced by
            format_metadata_to_string — TODO confirm at the call site).

    Returns:
        The complete system prompt string.
    """
    return f"""
You are StarThinker, a strategic module of the StarData AI coding agent.
You have a basic understanding of actuary line of thinking, especially concerning automobile insurance.
Your primary role is to analyze a given data query and devise multiple, distinct strategies for answering it using SAS code. You do not generate code, only high-level strategies.

**Your Task:**

- Given a data query and metadata, generate distinct and exhaustive strategies outlining how to approach the problem using the available data in the SAS databases.
- Strategies should be concise (1-2 sentences each) and explore different logical paths to the solution.
- Consider various ways to utilize the provided metadata, which includes database names, descriptions, column names, descriptions, types, and potential values.
- Dynamically determine the appropriate number of strategies based on the complexity of the query. Stop generating new strategies if you deem them to be sufficient and exhaustive.
- give your response in this format :
    'strategy 1 : <details of this strategy>\n
    strategy 2 : <details of this strategy>\n
    ...'
**Key Considerations from StarData:**

<redacted for privacy reasons>

**Metadata:**

{necessary_metadata}
"""

def get_solver_system_prompt(necessary_metadata) -> str:
    """Build the system prompt for the StarSolver (SAS code-generation) role.

    Args:
        necessary_metadata: Database/column metadata; interpolated verbatim at the
            end of the prompt under "**Metadata:**".

    Returns:
        The prompt text. In the visible code it is the system message of
        Solver.generate_solution.

    NOTE(review): used as a ChatPromptTemplate message; literal "{" / "}" in the
    metadata are parsed as template variables unless the caller escapes them.
    """
    return f"""
You are StarSolver, a code generation module of the StarData AI coding agent.
You have a basic understanding of actuary line of thinking, especially concerning automobile insurance.
You are an expert in SAS programming. Your task is to generate SAS code that accurately answers data queries based on provided strategies and metadata.

**Your Task:**

- Given a data query and a specific strategy, generate a complete and executable SAS code solution.
- The code should be syntactically correct and produce the desired output based on the query and strategy.
- Output ONLY the SAS code, without any explanations or additional text.
- Save all new dataframes in the WORK library (temporary).
- When doing union joins, select only relevant columns.
- Adhere to all coding instructions and guidelines specified below.

**Coding Instructions from StarData:**

<redacted for privacy reasons>

**Metadata:**

{necessary_metadata}
"""

def get_debugger_system_prompt(necessary_metadata) -> str:
    """Build the system prompt for the StarDebugger (SAS code-refinement) role.

    Args:
        necessary_metadata: Database/column metadata; interpolated verbatim at the
            end of the prompt under "**Metadata:**".

    Returns:
        The prompt text. In the visible code it is the system message of
        Debugger.generate_refinement.

    NOTE(review): used as a ChatPromptTemplate message; literal "{" / "}" in the
    metadata are parsed as template variables unless the caller escapes them.
    """
    return f"""
You are StarDebugger, a code refinement module of the StarData AI coding agent.
You have a basic understanding of actuary line of thinking, especially concerning automobile insurance.
You are an expert in SAS programming and debugging. Your task is to refine and improve SAS code solutions based on provided feedback.

**Your Task:**

- Given a SAS code solution, and feedback (which may include error messages, requirement shortcomings, or suggestions), generate a corrected and improved version of the code.
- The refined code should address all issues mentioned in the feedback and produce the desired output according to the original query.
- Output ONLY the refined SAS code, without any explanations or additional text.

**Coding and Debugging Instructions from StarData:**
<redacted for privacy reasons>

**Metadata:**

{necessary_metadata}
"""

def get_critic_system_prompt(necessary_metadata) -> str:
    """Build the system prompt for the StarCritic (solution-evaluation) role.

    Args:
        necessary_metadata: Database/column metadata; interpolated verbatim at the
            end of the prompt under "**Metadata:**".

    Returns:
        The prompt text. In the visible code it is the system message of
        Critic.verify_solution; the other Critic reflections use their own prompts.

    NOTE(review): used as a ChatPromptTemplate message; literal "{" / "}" in the
    metadata are parsed as template variables unless the caller escapes them.
    """
    return f"""
You are StarCritic, a solution evaluator module of the StarData AI coding agent.
You have a basic understanding of actuary line of thinking, especially concerning automobile insurance.
You are an expert in SAS programming and analysis. Your task is to evaluate SAS code solutions generated by the Solver agent, provide feedback, and assess their suitability for refinement or acceptance.

**Your Task:**

- Given a SAS code solution, its corresponding strategy, and the original query, evaluate the solution's correctness, adherence to the strategy, and overall quality.
- Provide a numerical score (critic_score) that reflects the solution's quality and potential for improvement.
- Generate specific feedback (feedback) that identifies errors, missing requirements, and areas for improvement.
- Determine whether the solution should be refined, aborted, or accepted based on your evaluation.

**Evaluation Criteria:**

- **Correctness:** Does the code execute without errors? Does it produce the expected output for visible test cases?
- **Strategy Adherence:** Does the code effectively implement the given strategy? Does it logically follow the strategy and use appropriate data structures and algorithms?
- **Robustness:** Is the solution likely to be correct for unseen test cases? Does it handle potential edge cases and demonstrate a general understanding of the problem?
- **Requirements Fulfillment:** Does the code meet all the requirements stated in the original query? Are there any missing functionalities or discrepancies?

**Decision Logic:**

- **Refine:** If the solution has errors, fails to meet requirements, or does not fully adhere to the strategy, it should be refined.
- **Abort:** If the solution has major flaws, scores very low on the evaluation criteria, or is unlikely to be improved with further refinement, it should be aborted.
- **Accept:** If the solution passes all visible test cases, adheres to the strategy, meets all requirements, and is deemed robust, it should be accepted as the final solution.

**Additional Instructions:**

<redacted for privacy reasons>

**Metadata:**

{necessary_metadata}
"""

In [12]:
# --- Agent Classes ---
class Thinker:
    """Strategy and reflection generator (the "StarThinker" agent).

    Uses ``llm`` to (a) propose high-level strategies for answering a data question
    with SAS and (b) reflect on a candidate SAS solution given the Critic's feedback
    and the execution log. Both are produced one LLM call at a time and memoised per
    instance, so identical requests do not re-query the LLM.
    """

    # Parses "strategy <n> : <text>" case-insensitively, tolerating missing or extra
    # spaces around the colon (e.g. "Strategy 2: ...").
    _STRATEGY_RE = re.compile(r"strategy\s*\d+\s*:\s*(.*)", re.IGNORECASE)

    def __init__(self, llm, necessary_metadata):
        self.llm = llm
        self.necessary_metadata = necessary_metadata
        # (question, previous strategies, max_strategies) -> full strategy list.
        # The history is part of the key so "give me the next strategy after X"
        # is never answered with a list cached for a different history.
        self.strategies_cache: Dict[Tuple[str, Tuple[str, ...], int], List[str]] = {}
        # (question, solution, feedback, log, strategy) -> reflections. The strategy
        # is part of the prompt, so it must be part of the key too.
        self.reflections_cache: Dict[Tuple[str, str, str, str, str], List[str]] = {}

    @staticmethod
    def _escape_braces(text: object) -> str:
        """Return ``text`` as a string with ``{``/``}`` doubled.

        ChatPromptTemplate parses messages as f-string templates, so literal braces
        (e.g. in JSON-like metadata, SAS logs or LLM feedback) would otherwise be read
        as template variables and make ``invoke`` fail with missing-variable errors.
        """
        return str(text).replace("{", "{{").replace("}", "}}")

    def generate_strategies(self, question: str, max_strategies: int = 5,
                            previous_strategies: Optional[List[str]] = None) -> List[str]:
        """Generate strategies for ``question`` autoregressively, one LLM call each.

        Args:
            question: The natural-language data query.
            max_strategies: Maximum number of NEW strategies to request.
            previous_strategies: Strategies already explored; they are shown to the LLM
                so it proposes different ones. The caller's list is never modified.

        Returns:
            ``previous_strategies`` (if any) followed by the newly generated ones.
            Generation stops early if an LLM call fails, the reply contains no
            "strategy N : ..." line, or the LLM answers "No further strategies.".
        """
        logging.info(f"Generating strategies for question: {question}")
        # Work on a copy so appending below never mutates the caller's list.
        strategies: List[str] = list(previous_strategies) if previous_strategies is not None else []
        cache_key = (question, tuple(strategies), max_strategies)
        if cache_key in self.strategies_cache:
            logging.info(f"Using cached strategies for question: {question}")
            return list(self.strategies_cache[cache_key])

        for i in range(max_strategies):
            prompt = self._build_strategy_prompt(question, strategies)
            try:
                response = (prompt | self.llm | StrOutputParser()).invoke({
                    "question": question,
                    "necessary_metadata": self.necessary_metadata,
                    "previous_strategies": "\n".join(strategies)
                })
            except Exception as e:
                logging.error(f"Error during strategy generation: {e}")
                break

            new_strategy = self._extract_strategy(response)
            if not new_strategy or new_strategy.lower() == "no further strategies.":
                logging.info("No further strategies generated.")
                break

            strategies.append(new_strategy)
            logging.info(f"Generated strategy {i+1}: {new_strategy}")

        # Store a private copy so later mutation of the returned list cannot corrupt the cache.
        self.strategies_cache[cache_key] = list(strategies)
        return strategies

    def _build_strategy_prompt(self, question: str, previous_strategies: List[str]) -> "ChatPromptTemplate":
        """Build the chat prompt asking for the next strategy.

        The system message is the StarThinker prompt plus the numbered previous
        strategies and a stop instruction. All text is brace-escaped, so the
        resulting template has no input variables.
        """
        system_prompt = get_thinker_system_prompt(self.necessary_metadata)

        if previous_strategies:
            system_prompt += "\n\nPrevious Strategies:\n" + "\n".join(
                f"{i + 1}. {strategy}" for i, strategy in enumerate(previous_strategies)
            )
        system_prompt += "\n\nGenerate the next strategy, or write 'No further strategies.' if no further strategies can be generated."

        return ChatPromptTemplate.from_messages([
            ("system", self._escape_braces(system_prompt)),
            ("user", self._escape_braces(f"Question: {question}\nStrategy:")),
        ])

    def _extract_strategy(self, response: str) -> str:
        """Return the text after the first "strategy <n> :" marker, or "" if there is none.

        Only the remainder of that line is returned (stripped). A reply such as
        "No further strategies." has no marker and therefore yields "".
        """
        match = self._STRATEGY_RE.search(response)
        return match.group(1).strip() if match else ""

    def generate_reflections(self, question: str, solution: str, feedback: str, log: str, strategy: str,
                             num_reflections: int = 3) -> List[str]:
        """Generate up to ``num_reflections`` reflections on ``solution``, one LLM call each.

        Each call sees the reflections produced so far. Generation stops early if a
        call fails, returns an empty reply, or answers "No further reflections.".
        Results are memoised on (question, solution, feedback, log, strategy).
        """
        logging.info(f"Generating reflections for question: {question}")
        cache_key = (question, solution, feedback, log, strategy)
        if cache_key in self.reflections_cache:
            logging.info(f"Using cached reflections for question: {question}")
            return list(self.reflections_cache[cache_key])

        reflections: List[str] = []
        for _ in range(num_reflections):
            prompt = self._build_reflection_prompt(question, solution, feedback, log, strategy, reflections)
            try:
                response = (prompt | self.llm | StrOutputParser()).invoke({
                    "question": question,
                    "solution": solution,
                    "feedback": feedback,
                    "log": log,
                    "strategy": strategy,
                    "necessary_metadata": self.necessary_metadata,
                    "previous_reflections": "\n".join(reflections)
                })
            except Exception as e:
                logging.error(f"Error during reflection generation: {e}")
                break

            new_reflection = response.strip()
            if not new_reflection or new_reflection.lower() == "no further reflections.":
                logging.info("No further reflections generated.")
                break
            reflections.append(new_reflection)
            logging.info(f"Generated reflection: {new_reflection}")

        self.reflections_cache[cache_key] = list(reflections)
        return reflections

    def _build_reflection_prompt(self, question: str, solution: str, feedback: str, log: str, strategy: str,
                                 previous_reflections: List[str]) -> "ChatPromptTemplate":
        """Build the chat prompt asking for the next reflection on ``solution``.

        The system message is the StarThinker prompt plus the strategy, the
        reflection task and any earlier reflections. The user message embeds the
        question, strategy, SAS code, Critic feedback and execution log. Everything
        is brace-escaped, so the template has no input variables.
        """
        system_prompt = get_thinker_system_prompt(self.necessary_metadata)
        system_prompt += f"\n\nStrategy used: {strategy}"
        system_prompt += "\n\nYour task is to generate reflections on the following SAS code solution, considering the feedback and log provided. "
        system_prompt += "Reflections should be concise and focus on identifying areas for improvement or issues in the code."

        if previous_reflections:
            system_prompt += "\n\nPrevious Reflections:\n" + "\n".join(
                f"{i + 1}. {reflection}" for i, reflection in enumerate(previous_reflections)
            )
            system_prompt += "\n\nGenerate the next reflection, or write 'No further reflections.' if no further reflections are needed"

        user_message = (
            f"Given the question: '{question}', the strategy: '{strategy}', the following SAS code solution: \n"
            f"```sas\n{solution}\n```\n"
            f"and the feedback from the Critic: '{feedback}',\n"
            f"and the log from the code execution : '{log}', \n"
            f"please generate reflections to guide the code refinement process."
        )
        return ChatPromptTemplate.from_messages([
            ("system", self._escape_braces(system_prompt)),
            ("user", self._escape_braces(user_message)),
        ])

class Solver:
    """Code-generation agent ("StarSolver"): writes SAS code for one strategy."""

    # First Markdown fence tagged ``sas`` (any case) or untagged; group 1 is the code.
    _FENCE_RE = re.compile(r"```(?:[ \t]*sas\b|(?=\s))(.*?)```", re.IGNORECASE | re.DOTALL)

    def __init__(self, llm, necessary_metadata):
        self.llm = llm
        self.necessary_metadata = necessary_metadata
        # (question, strategy) -> generated SAS code
        self.solutions_cache: Dict[Tuple[str, str], str] = {}

    @staticmethod
    def _escape_braces(text: object) -> str:
        """Double ``{``/``}`` so ``text`` is literal inside a ChatPromptTemplate message."""
        return str(text).replace("{", "{{").replace("}", "}}")

    def generate_solution(self, question: str, strategy: str) -> str:
        """Generate SAS code that answers ``question`` by following ``strategy``.

        Returns:
            The extracted SAS code, or "" if the LLM call failed. Successful results
            are memoised per (question, strategy). Empty results are not cached, so a
            later call retries instead of replaying the failure.
        """
        logging.info(f"Generating solution for question: {question} with strategy: {strategy}")
        cache_key = (question, strategy)
        if cache_key in self.solutions_cache:
            logging.info(f"Using cached solution for question: {question} and strategy: {strategy}")
            return self.solutions_cache[cache_key]

        user_message = f"Given the strategy: '{strategy}', please write SAS code to answer this question: {question}"
        # Escape braces: metadata/strategy/question text must not be parsed as template variables.
        prompt = ChatPromptTemplate.from_messages([
            ("system", self._escape_braces(get_solver_system_prompt(self.necessary_metadata))),
            ("user", self._escape_braces(user_message)),
        ])
        try:
            response = (prompt | self.llm | StrOutputParser()).invoke({
                "question": question,
                "strategy": strategy,
                "necessary_metadata": self.necessary_metadata
            })
            solution = self._extract_code(response)
            logging.info(f"Generated solution:\n{solution}")
        except Exception as e:
            logging.error(f"Error during solution generation: {e}")
            solution = ""

        if solution:
            self.solutions_cache[cache_key] = solution
        return solution

    def _extract_code(self, response: str) -> str:
        """Extract SAS code from an LLM reply.

        If the reply contains a Markdown code fence tagged ``sas`` (any case) or
        untagged, the stripped contents of the first such fence are returned, since
        SAS cannot execute the backticks. Otherwise the whole reply is returned, stripped.
        """
        match = self._FENCE_RE.search(response)
        if match is not None:
            return match.group(1).strip()
        return response.strip()

class Debugger:
    """Code-refinement agent ("StarDebugger"): rewrites SAS code guided by reflections."""

    # First Markdown fence tagged ``sas`` (any case) or untagged; group 1 is the code.
    _FENCE_RE = re.compile(r"```(?:[ \t]*sas\b|(?=\s))(.*?)```", re.IGNORECASE | re.DOTALL)

    def __init__(self, llm, necessary_metadata):
        self.llm = llm
        self.necessary_metadata = necessary_metadata
        # (question, solution, tuple(reflections)) -> refined SAS code
        self.debugged_solutions_cache: Dict[Tuple[str, str, Tuple[str, ...]], str] = {}

    @staticmethod
    def _escape_braces(text: object) -> str:
        """Double ``{``/``}`` so ``text`` is literal inside a ChatPromptTemplate message."""
        return str(text).replace("{", "{{").replace("}", "}}")

    def generate_refinement(self, question: str, solution: str, reflections: list) -> str:
        """Refine ``solution`` for ``question`` using the given ``reflections``.

        The original question is included in the prompt so the LLM can check the
        refined code against it.

        Returns:
            The refined SAS code, or "" if the LLM call failed. Successful results are
            memoised on (question, solution, reflections). Empty results are not
            cached, so a later call retries.
        """
        logging.info(f"Generating refinement for question: {question}")
        cache_key = (question, solution, tuple(reflections))  # lists are unhashable
        if cache_key in self.debugged_solutions_cache:
            logging.info(f"Using cached refined solution for question: {question}")
            return self.debugged_solutions_cache[cache_key]

        reflections_text = "\n".join(f"- {reflection}" for reflection in reflections) or "- (none)"
        user_message = (
            f"Original query: {question}\n\n"
            f"Given the reflections:\n{reflections_text}\n\n"
            f"please refine the following SAS code:\n```sas\n{solution}\n```"
        )
        prompt = ChatPromptTemplate.from_messages([
            ("system", self._escape_braces(get_debugger_system_prompt(self.necessary_metadata))),
            ("user", self._escape_braces(user_message)),
        ])
        try:
            response = (prompt | self.llm | StrOutputParser()).invoke({
                "question": question,
                "solution": solution,
                "reflections": reflections,
                "necessary_metadata": self.necessary_metadata
            })
            refined_solution = self._extract_code(response)
            logging.info(f"Generated refined solution:\n{refined_solution}")
        except Exception as e:
            logging.error(f"Error during refinement generation: {e}")
            refined_solution = ""

        # An empty result means generation failed or produced nothing; caching it would
        # make every retry return "" without asking the LLM again.
        if refined_solution:
            self.debugged_solutions_cache[cache_key] = refined_solution
        return refined_solution

    def _extract_code(self, response: str) -> str:
        """Extract SAS code from an LLM reply.

        Returns the stripped contents of the first ``sas``-tagged (any case) or
        untagged Markdown fence if present, otherwise the whole reply, stripped.
        """
        match = self._FENCE_RE.search(response)
        if match is not None:
            return match.group(1).strip()
        return response.strip()

class Critic:
    """Solution evaluator ("StarCritic").

    Runs a candidate SAS solution with ``execute_sas_code`` and asks the LLM for three
    structured reflections: errors in the log, unmet requirements, and adherence to
    the strategy. It folds these into a score and feedback text, then decides whether
    the solution should be refined ("Refine"), dropped ("Abort") or accepted ("Accept").

    ``config`` is read for the keys ``error_weight``, ``requirement_weight``,
    ``verification_fail_score`` and ``abort_threshold``.
    """

    # First Markdown fence tagged ``sas`` (any case) or untagged; group 1 is the code.
    _FENCE_RE = re.compile(r"```(?:[ \t]*sas\b|(?=\s))(.*?)```", re.IGNORECASE | re.DOTALL)

    def __init__(self, llm, client, necessary_metadata, config: Dict):
        self.llm = llm
        # Stored for callers; not used by the methods of this class.
        self.client = client
        self.necessary_metadata = necessary_metadata
        # (question, solution, strategy) -> (evaluation, critic_score, feedback)
        self.evaluation_cache: Dict[Tuple[str, str, str], Tuple[str, float, str]] = {}
        self.config = config

    @staticmethod
    def _escape_braces(text: object) -> str:
        """Double ``{``/``}`` so ``text`` is literal inside a ChatPromptTemplate message."""
        return str(text).replace("{", "{{").replace("}", "}}")

    @classmethod
    def _strip_code_fences(cls, solution: str) -> str:
        """Return the stripped contents of the first ```sas / ``` fence in ``solution``.

        The tag is matched case-insensitively. Text without a complete fence is
        returned unchanged.
        """
        match = cls._FENCE_RE.search(solution)
        return match.group(1).strip() if match else solution

    @staticmethod
    def _parse_verdict(response: str) -> bool:
        """Interpret a free-text True/False verdict.

        The first standalone word "true" or "false" (case-insensitive) decides. A reply
        with neither counts as False. Unlike a substring test, "False. It is not true
        that ..." and "untrue" are read as False.
        """
        match = re.search(r"\b(true|false)\b", response.strip().lower())
        return match is not None and match.group(1) == "true"

    def evaluate_solution(self, question: str, solution: str, strategy: str) -> Tuple[str, float, str]:
        """Execute and evaluate ``solution``.

        Returns:
            ``(evaluation, critic_score, feedback)``, where evaluation is "Refine",
            "Abort" or "Accept". Results are memoised on (question, solution, strategy).

        Raises:
            Any exception raised while generating the three LLM reflections.
        """
        logging.info(f"Evaluating solution for question: {question}")
        cache_key = (question, solution, strategy)
        if cache_key in self.evaluation_cache:
            logging.info(f"Using cached evaluation for question: {question}")
            return self.evaluation_cache[cache_key]

        # LLM replies are often wrapped in a Markdown fence, which SAS cannot execute.
        solution = self._strip_code_fences(solution)

        log = execute_sas_code(solution)
        try:
            error_reflection = self._get_error_reflection(question, solution, log)
            requirement_reflection = self._get_requirement_reflection(question, solution, log)
            strategy_adherence_reflection = self._get_strategy_adherence_reflection(question, solution, strategy)
        except Exception as e:
            logging.error(f"Error during reflections generation : {e}")
            raise

        # --- More Detailed Evaluation ---
        evaluation = ""
        critic_score = 0.0
        feedback = ""

        if error_reflection.error_count > 0:
            feedback += "Errors:\n"
            feedback += f"- {error_reflection.error_log}\n"
            critic_score += self.config["error_weight"]
            evaluation = "Refine"

        if not requirement_reflection.requirements_met:
            feedback += "Missing Requirements:\n"
            feedback += f"- {requirement_reflection.missing_requirements}\n"
            critic_score += self.config["requirement_weight"]
            evaluation = "Refine"

        # The LLM is asked for a score in [0, 1]; clamp so a malformed value (e.g. a
        # percentage such as 85) cannot dominate the combined score.
        adherence_score = min(max(float(strategy_adherence_reflection.adherence_score), 0.0), 1.0)
        if evaluation == "Refine":
            critic_score = (critic_score + adherence_score) / 2
        else:
            critic_score = adherence_score

        feedback += f"Strategy Adherence: {adherence_score:.2f} - {strategy_adherence_reflection.reasoning}\n"

        # --- Decision Logic ---
        if error_reflection.error_count == 0 and requirement_reflection.requirements_met:
            # Solution passes visible tests, perform verification
            try:
                if self.verify_solution(question, solution):
                    evaluation = "Accept"
                    critic_score = 1.0  # Perfect score if it passes verification
                else:
                    evaluation = "Refine"
                    feedback += "Solution passes visible tests but fails verification (potential overfitting or lack of robustness).\n"
                    critic_score = self.config["verification_fail_score"]
            except Exception as e:
                logging.error(f"Error during solution verification : {e}")
                evaluation = "Refine"
                feedback += "Solution passes visible tests but fails verification (potential overfitting or lack of robustness).\n"
                critic_score = self.config["verification_fail_score"]
        elif critic_score < self.config["abort_threshold"]:
            evaluation = "Abort"

        logging.info(f"Evaluation: {evaluation}, Critic Score: {critic_score}, Feedback: {feedback}")
        self.evaluation_cache[cache_key] = (evaluation, critic_score, feedback)
        return evaluation, critic_score, feedback

    def verify_solution(self, question: str, solution: str) -> bool:
        """Ask the LLM whether a solution that passes visible checks is likely to generalise.

        Returns True only if the first standalone "true"/"false" in the reply is
        "true". Any error during the LLM call is logged and treated as False.
        """
        logging.info(f"Verifying solution for question: {question}")
        # The user message uses real template placeholders; values substituted for
        # them are not re-parsed, so braces in the SAS code are safe.
        verification_prompt = ChatPromptTemplate.from_messages([
            ("system", self._escape_braces(get_critic_system_prompt(self.necessary_metadata))),
            ("user", """
        We have a SAS code solution that passes all visible test cases for the query: '{question}'.
        ```sas
        {solution}
        ```
        Your task is to assess whether this solution is likely to be correct for unseen test cases as well.
        Consider the following:
        - Does the code seem overly tailored to the specific visible test cases, or does it demonstrate a general understanding of the problem?
        - Are there any potential edge cases or scenarios not covered by the visible tests that the code might fail on?

        Begin your answer with 'True' if the solution is likely to be correct for unseen test cases, or 'False' otherwise, then provide a brief explanation for your assessment.
        """),
        ])
        try:
            response = (verification_prompt | self.llm | StrOutputParser()).invoke({
                "question": question,
                "solution": solution,
                "necessary_metadata": self.necessary_metadata
            })

            logging.info(f"Raw verification response: {response}")
            cleaned_response = response.strip().lower()
            logging.info(f"Cleaned verification response: {cleaned_response}")

            if self._parse_verdict(response):
                logging.info("Solution verification successful.")
                return True
            logging.warning(f"Solution verification failed: {response}")
            return False

        except Exception as e:
            logging.error(f"Error during solution verification: {e}")
            return False

    def _get_error_reflection(self, query: str, sas_code: str, log: str) -> "ErrorReflection":
        """
        Identifies errors in the SAS log and provides an ErrorReflection.
        Also asks the LLM to categorize the error.

        Raises:
            Any exception from the LLM call or from parsing its output.
        """
        logging.info("Getting error reflection...")
        parser = PydanticOutputParser(pydantic_object=ErrorReflection)

        prompt = ChatPromptTemplate.from_messages([
            ("system",
             "You are a helpful assistant that identifies and categorizes errors in SAS code execution logs. {format_instructions}"),
            ("user", "Query: {query}\nCode:\n```sas\n{sas_code}\n```\nLog:\n{log}")
        ])
        try:
            response = (prompt | self.llm | parser).invoke({
                "query": query,
                "sas_code": sas_code,
                "log": log,
                "format_instructions": parser.get_format_instructions()
            })
            logging.info(
                f"Error reflection: Error Count: {response.error_count}, Error Log: {response.error_log}, Error Category: {response.error_category}")
            return response
        except Exception as e:
            logging.error(f"Error during error reflection generation : {e}")
            raise

    def _get_requirement_reflection(self, query: str, sas_code: str, log: str) -> "RequirementReflection":
        """
        Determines whether the SAS code meets the requirements of the original query.
        Returns a RequirementReflection parsed with a PydanticOutputParser.
        ``log`` is accepted for signature symmetry but not sent to the LLM.

        Raises:
            Any exception from the LLM call or from parsing its output.
        """
        logging.info("Getting requirement reflection...")
        parser = PydanticOutputParser(pydantic_object=RequirementReflection)

        # Plain strings with template placeholders: metadata and code are substituted
        # by the template, so braces inside them are never parsed as variables.
        prompt = ChatPromptTemplate.from_messages([
            ("system", """
            You are a helpful assistant that identifies missing requirements in SAS code.
            Use the metadata to understand the database and ensure the requirements are aligned with the database's structure.
            {format_instructions}
            metadata : {necessary_metadata}"""),
            ("user", """
            Original Query: {query}
            Generated SAS Code:
            ```sas
            {sas_code}
            ```
            Does the SAS code fulfill all requirements stated in the original query?
            Answer with 'True' if everything is generally satisfying to the average user; only answer "False" if there is a flagrant error showing the query was totally misunderstood.
            """)
        ])
        try:
            response = (prompt | self.llm | parser).invoke({
                "query": query,
                "sas_code": sas_code,
                "necessary_metadata": self.necessary_metadata,
                "format_instructions": parser.get_format_instructions()
            })
            logging.info(
                f"Requirement reflection: Missing Requirements: {response.missing_requirements}, Requirements Met: {response.requirements_met}")
            return response
        except Exception as e:
            logging.error(f"Error during requirement reflection generation : {e}")
            raise

    def _get_strategy_adherence_reflection(self, query: str, sas_code: str,
                                            strategy: str) -> "StrategyAdherenceReflection":
        """
        Evaluates how closely the SAS code follows the given strategy.
        Returns a StrategyAdherenceReflection parsed with a PydanticOutputParser,
        including a numerical score.

        Raises:
            Any exception from the LLM call or from parsing its output.
        """
        logging.info("Getting strategy adherence reflection...")
        parser = PydanticOutputParser(pydantic_object=StrategyAdherenceReflection)

        # Plain strings with template placeholders (see _get_requirement_reflection).
        prompt = ChatPromptTemplate.from_messages([
            ("system", """
            You are a helpful assistant that evaluates the adherence of SAS code to a given strategy.
            Use the metadata to understand the database and ensure the code aligns with the strategy and the database's structure.
            {format_instructions}
            metadata : {necessary_metadata}
            """),
            ("user", """
            Original Query: {query}
            Strategy: {strategy}
            Generated SAS Code:
            ```sas
            {sas_code}
            ```
            Assess the adherence of the SAS code to the given strategy. Consider:
            1. How well does the code logically follow the strategy?
            2. Does the code use appropriate data structures and algorithms as suggested by the strategy?
            3. Are there any deviations from the strategy, and if so, are they justified?
            Provide a numerical adherence score between 0 and 1, where 1 represents perfect adherence and 0 represents no adherence.
            """)
        ])

        try:
            response = (prompt | self.llm | parser).invoke({
                "query": query,
                "sas_code": sas_code,
                "strategy": strategy,
                "necessary_metadata": self.necessary_metadata,
                "format_instructions": parser.get_format_instructions()
            })
            logging.info(
                f"Strategy adherence reflection: Adherence Score: {response.adherence_score}, Reasoning: {response.reasoning}")
            return response
        except Exception as e:
            logging.error(f"Error during strategy adherence reflection generation: {e}")
            raise

# Structured LLM output parsed by Critic._get_error_reflection via PydanticOutputParser.
# NOTE(review): the field names, types and descriptions (and any class docstring) are
# presumably rendered into the JSON-schema format instructions shown to the LLM, so
# editing this class also changes the prompt. Comments are used instead of a docstring.
class ErrorReflection(BaseModel):
    # Number of errors the LLM found; Critic.evaluate_solution treats > 0 as "Refine".
    error_count: int = Field(description="Number of errors in SAS log")
    # Log excerpt with the error messages; copied into the Critic's feedback text.
    error_log: str = Field(description="Portion of the log that contains error messages")
    # Free-text category chosen by the LLM; in the visible code it is only logged.
    error_category: str = Field(..., description="Category of the error (e.g., syntax, runtime, logical)")

# Structured LLM output parsed by Critic._get_requirement_reflection.
# NOTE(review): the schema is presumably sent to the LLM as format instructions;
# comments are used instead of a docstring so the prompt is unchanged.
class RequirementReflection(BaseModel):
    # Free-text description of unmet requirements; added to feedback when requirements_met is False.
    missing_requirements: str = Field(description="Requirements not met in the code")
    # False makes Critic.evaluate_solution mark the solution "Refine" and add requirement_weight.
    requirements_met: bool = Field(description="Whether all requirements are met")

# Structured LLM output parsed by Critic._get_strategy_adherence_reflection.
# NOTE(review): the schema is presumably sent to the LLM as format instructions;
# comments are used instead of a docstring so the prompt is unchanged.
class StrategyAdherenceReflection(BaseModel):
    # Requested to lie in [0, 1], but no validator enforces the range here.
    adherence_score: float = Field(...,
                                   description="Numerical score between 0 and 1 representing adherence to the strategy")
    # LLM justification; appended to the Critic's feedback after the score.
    reasoning: str = Field(..., description="Explanation of the adherence score, including any deviations from the strategy")

In [13]:
# --- Node and TreeState ---
class Node:
    """One node of the strategy/solution search tree (MCTS-style).

    Holds a strategy, the SAS solution generated for it and the Critic's verdict in
    ``evaluation`` (values used in this file include "Expand", "Refine", "Abort" and
    "Accept"). ``visits`` and ``value`` are maintained by :meth:`backpropagate`;
    ``value`` is the running MEAN of all rewards propagated through the node.
    ``config`` must contain ``execution_weight`` and ``critic_weight``, which are
    read by :meth:`calculate_score`.
    """

    def __init__(
            self,
            strategy: str,
            solution: str,
            evaluation: str,
            config: Dict,
            critic_score: float = 0,
            parent: Optional["Node"] = None,
            question: Optional[str] = None,
    ):
        self.strategy = strategy
        self.solution = solution
        self.parent = parent
        self.children: List["Node"] = []
        self.evaluation = evaluation
        self.visits = 0
        self.value = 0
        self.critic_score = critic_score
        self.score = 0
        self.depth = parent.depth + 1 if parent is not None else 1
        # Inherit the question from the parent when not given explicitly.
        self.question = question if question is not None else (parent.question if parent is not None else None)
        self.config = config
        # Creating a node counts as one visit: reward 1 if already accepted, else 0.
        # This increments visits on the node and on every ancestor.
        self.backpropagate(1 if self.evaluation == "Accept" else 0, self.config)

    def calculate_score(self, config: Dict) -> float:
        """Compute and store ``self.score`` from execution results and the Critic's score.

        ``score = execution_weight * pass_rate + critic_weight * critic_score``, where
        pass_rate is derived from "passed:" / "test ok" lines in the SAS log (0 when
        there is no solution or no such lines).

        NOTE(review): this runs ``execute_sas_code`` on every call, and backpropagate
        calls it for every non-accepted ancestor, so SAS code may be re-executed often.
        """
        logging.info(f"Calculating score for node with strategy: {self.strategy}")
        execution_score = 0

        if self.solution:  # Only calculate if a solution exists
            log = execute_sas_code(self.solution)
            num_tests = 0
            passed_tests = 0

            # Simple pass/fail scoring (adapt based on your test case format)
            for line in log.splitlines():
                lowered = line.lower()
                if "passed:" in lowered:
                    num_tests += 1
                    if "true" in lowered:
                        passed_tests += 1
                if "test ok" in lowered:
                    num_tests += 1
                    passed_tests += 1

            if num_tests > 0:
                execution_score = passed_tests / num_tests

        self.score = config["execution_weight"] * execution_score + config["critic_weight"] * self.critic_score
        logging.info(f"Node score calculated: {self.score}")
        return self.score

    def upper_confidence_bound(self, exploration_weight: float = math.sqrt(2)) -> float:
        """UCB1 score used by :meth:`select_child`.

        Unvisited nodes get +inf so they are tried first. ``value`` already holds the
        mean reward (see backpropagate), so it is the exploitation term as-is rather
        than being divided by ``visits`` a second time. A node without a (visited)
        parent, i.e. the root, gets no exploration bonus instead of raising.
        """
        if self.visits == 0:
            return float('inf')
        exploitation = self.value
        if self.parent is None or self.parent.visits <= 0:
            return exploitation
        exploration = math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploitation + exploration_weight * exploration

    def backpropagate(self, reward: float, config: Dict) -> None:
        """Fold ``reward`` into the running mean ``value`` of this node and every ancestor.

        Accepted nodes get ``score = 1``; the others have their score recomputed.
        """
        logging.info(f"Backpropagating reward: {reward} for node with strategy: {self.strategy}")
        node = self
        while node:
            node.visits += 1
            # Incremental mean: new_mean = (old_mean * (n - 1) + reward) / n
            node.value = (node.value * (node.visits - 1) + reward) / node.visits
            if node.evaluation == "Accept":
                node.score = 1
            else:
                node.calculate_score(config)  # Recalculate score during backpropagation
            node = node.parent

    def select_child(self) -> "Node":
        """Return the child with the highest UCB score, or None if there are no children."""
        logging.info(f"Selecting child for node with strategy: {self.strategy}")
        best_child = None
        best_ucb = float('-inf')
        for child in self.children:
            ucb = child.upper_confidence_bound()
            if ucb > best_ucb:
                best_ucb = ucb
                best_child = child
        logging.info(f"Selected child with strategy: {best_child.strategy if best_child else 'None'}")
        return best_child

    def is_terminal(self) -> bool:
        """A node is terminal once its solution has been accepted."""
        return self.evaluation == "Accept"

    def get_best_solution(self) -> "Node":
        """Return the best node in the subtree rooted here (including this node).

        Accepted nodes always rank above non-accepted ones; ties are broken by
        ``value``, then ``score``. An accepted node with value 0 therefore still wins,
        and if nothing is accepted the highest-value/score node is returned.
        """
        logging.info(f"Getting best solution for node with strategy: {self.strategy}")
        all_nodes = [self] + self._get_all_children()
        best_node = max(
            all_nodes,
            key=lambda node: (node.evaluation == "Accept", node.value, node.score)
        )
        logging.info(f"Best solution found: {best_node.solution}")
        return best_node

    def _get_all_children(self) -> List["Node"]:
        """Return all descendants of this node in breadth-first order."""
        all_nodes = []
        queue = deque(self.children)
        while queue:
            node = queue.popleft()
            all_nodes.append(node)
            queue.extend(node.children)
        return all_nodes

# State dict passed to the tree-expansion functions (generate_initial_strategies reads
# state["input"]): an optional root Node of the search tree and the user's question
# under "input". Declared with the functional TypedDict form; keys and value types are
# the same as in a class-based declaration.
TreeState = TypedDict("TreeState", {"root": Optional[Node], "input": str})

In [14]:
# --- Tree Expansion Functions ---
def generate_initial_strategies(state: TreeState, llm, client, necessary_metadata, config,
                                max_strategies: int = 5) -> dict:
    """Create the dummy root node and seed it with up to ``max_strategies`` strategies.

    Strategies are requested one at a time. Every request after the first
    passes the strategies generated so far to the Thinker. For each new
    strategy a child node is attached to the root, the Solver writes a
    solution and the Critic evaluates it. Generation stops early when the
    Thinker returns nothing or "No further strategies.", or as soon as a
    solution is accepted.

    Args:
        state: Tree state; only ``state["input"]`` (the question) is read.
        llm: Chat model handed to the Thinker, Solver and Critic.
        client: Client handed to the Critic.
        necessary_metadata: Metadata handed to the Thinker, Solver and Critic.
        config: Scoring configuration passed to the Critic and to every Node.
        max_strategies: Maximum number of initial strategies to generate (default 5).

    Returns:
        A copy of ``state`` with ``"root"`` set to the new root node.
    """
    question = state["input"]
    logging.info(f"Generating initial strategies for question: {question}")
    thinker = Thinker(llm, necessary_metadata)
    solver = Solver(llm, necessary_metadata)
    critic = Critic(llm, client, necessary_metadata, config)

    # Dummy root node: it holds no solution, only one child per strategy.
    root = Node(strategy="Root", solution="", evaluation="Expand", question=question, config=config)

    # Generate and evaluate strategies one by one
    strategies: List[str] = []
    for i in range(max_strategies):
        # 1. Generate a strategy. After the first one, pass the previous strategies
        #    so the Thinker can propose something different.
        if strategies:
            generated = thinker.generate_strategies(question, 1, strategies)
        else:
            generated = thinker.generate_strategies(question, 1)
        # Treat an empty/None result as "no more strategies" instead of raising IndexError.
        new_strategy = generated[0] if generated else None

        if not new_strategy or new_strategy.strip().lower() == "no further strategies.":
            logging.info("No further strategies generated.")
            break

        strategies.append(new_strategy)
        logging.info(f"Generated strategy {i+1}: {new_strategy}")

        # 2. Create a child node for the new strategy
        child_node = Node(strategy=new_strategy, solution="", evaluation="Expand", parent=root, question=question,
                          config=config)
        root.children.append(child_node)

        # 3. Generate a solution for the strategy
        solution = solver.generate_solution(question, new_strategy)
        child_node.solution = solution

        # 4. Evaluate the solution (the Critic's textual feedback is not used here)
        evaluation, critic_score, _ = critic.evaluate_solution(question, solution, new_strategy)
        child_node.evaluation = evaluation
        child_node.critic_score = critic_score
        child_node.calculate_score(config)

        # 5. Stop early once a solution is verified ("Accept")
        if evaluation == "Accept":
            logging.info(f"Verified solution found for strategy: {new_strategy}. Stopping initial strategy generation.")
            return {**state, "root": root}

    logging.info("Finished generating initial strategies.")
    return {**state, "root": root}

def expand_tree(state: TreeState, llm, client, necessary_metadata, config: Dict, max_depth: int = 5) -> dict:
    """Perform one expansion or refinement step on the search tree.

    If any descendant of the root is already ``"Accept"``, ``state`` is returned
    unchanged. Otherwise a node is chosen with ``select_node_to_expand`` and
    handled according to its ``evaluation``:

    * ``"Expand"``: the Solver writes a solution for the node's strategy, the
      Critic evaluates it, the node is updated with the verdict and score, and a
      child carrying the same solution and verdict is appended.
    * ``"Refine"``: the node's SAS code is executed, and the Critic produces
      error and requirement reflections from the log. The Thinker turns them
      into reflections.
      - If the Thinker says "the code is good", the unchanged solution is
        re-evaluated and appended as a child.
      - Otherwise the Debugger refines the code, and the refined code is
        evaluated and appended as a child.
      - A refined child's Critic score above ``config["expansion_threshold"]``
        (with the node's depth below ``max_depth``) adds new ``"Expand"``
        strategy children.
      - A score below ``config["abort_threshold"]`` marks the current node
        ``"Abort"``.
    * any other evaluation: nothing is done.

    Args:
        state: Tree state with ``"root"`` (search tree) and ``"input"`` (question).
        llm: Chat model handed to the Solver, Critic, Debugger and Thinker.
        client: Client handed to the Critic.
        necessary_metadata: Metadata handed to the agents.
        config: Scoring configuration; must contain ``"expansion_threshold"`` and
            ``"abort_threshold"`` for the refinement path.
        max_depth: Depth limit for adding new strategies during refinement.

    Returns:
        A copy of ``state`` with the (mutated) root, or ``state`` itself when an
        accepted node already exists.

    Raises:
        Re-raises any exception from the Critic's reflection generation after logging it.
    """
    root = state["root"]
    question = state["input"]

    # Check if an "Accept" node already exists before expanding
    if any(node.evaluation == "Accept" for node in root._get_all_children()):
        logging.info("Acceptable solution already exists. Skipping expansion.")
        return state

    logging.info(f"Expanding tree for question: {question}")
    current_node = select_node_to_expand(root)
    logging.info(f"Selected node for expansion: {current_node.strategy}")

    solver = Solver(llm, necessary_metadata)
    critic = Critic(llm, client, necessary_metadata, config)
    debugger = Debugger(llm, necessary_metadata)
    thinker = Thinker(llm, necessary_metadata)

    if current_node.evaluation == "Expand":
        # Generate solution and evaluate
        logging.info(f"Generating solution for node: {current_node.strategy}")
        solution = solver.generate_solution(question, current_node.strategy)
        logging.info(f"Evaluating solution: {solution}")
        evaluation, critic_score, _ = critic.evaluate_solution(question, solution, current_node.strategy)

        # Update current node and create a new node
        current_node.solution = solution
        current_node.evaluation = evaluation
        current_node.critic_score = critic_score
        current_node.calculate_score(config)

        new_node = Node(strategy=current_node.strategy, solution=solution, evaluation=evaluation,
                        critic_score=critic_score, parent=current_node, question=question, config=config)
        current_node.children.append(new_node)

    elif current_node.evaluation == "Refine":
        # NOTE(review): current_node keeps its "Refine" status after this step, so
        # select_node_to_expand may pick it again on a later call and refine the same
        # original solution once more. Confirm whether that is intended.

        # Get reflections and refine solution
        log = execute_sas_code(current_node.solution)
        try:
            error_reflection = critic._get_error_reflection(current_node.question, current_node.solution, log)
            requirement_reflection = critic._get_requirement_reflection(current_node.question, current_node.solution,
                                                                        log)
        except Exception as e:
            logging.error(f"Error during reflections generation : {e}")
            raise

        feedback = ""
        if error_reflection.error_count > 0:
            feedback += "Errors:\n"
            feedback += f"- {error_reflection.error_log}\n"

        if not requirement_reflection.requirements_met:
            feedback += "Missing Requirements:\n"
            feedback += f"- {requirement_reflection.missing_requirements}\n"

        # Generate reflections using Thinker
        logging.info(f"Generating reflections for node: {current_node.strategy}")
        reflections = thinker.generate_reflections(current_node.question, current_node.solution, feedback, log,
                                                   current_node.strategy)

        # LLM output varies in case, whitespace and trailing punctuation
        # ("The code is good."), so normalise before comparing.
        code_is_good = (isinstance(reflections, str)
                        and reflections.strip().lower().rstrip(".") == "the code is good")

        if code_is_good:
            # If no issues, simply append the current node again (no changes)
            logging.info(f"No issues found, appending current node again: {current_node.strategy}")
            evaluation, critic_score, _ = critic.evaluate_solution(
                current_node.question, current_node.solution, current_node.strategy
            )
            new_node = Node(strategy=current_node.strategy, solution=current_node.solution, evaluation=evaluation,
                            critic_score=critic_score, parent=current_node, question=question, config=config)
            current_node.children.append(new_node)
        else:
            # Refine solution using Debugger
            logging.info(f"Refining solution for node: {current_node.strategy}")
            refined_solution = debugger.generate_refinement(question, current_node.solution, reflections)
            logging.info(f"Evaluating refined solution: {refined_solution}")
            evaluation, critic_score, _ = critic.evaluate_solution(question, refined_solution,
                                                                   current_node.strategy)

            # Create new node for refined solution
            new_node = Node(strategy=current_node.strategy, solution=refined_solution, evaluation=evaluation,
                            critic_score=critic_score, parent=current_node, question=question, config=config)
            current_node.children.append(new_node)

            # Dynamic expansion based on Critic's score
            if critic_score > config["expansion_threshold"] and current_node.depth < max_depth:
                logging.info(
                    f"Critic score {critic_score} above expansion threshold {config['expansion_threshold']}. Generating new strategies.")
                new_strategies = thinker.generate_strategies(question)
                for new_strategy in new_strategies or []:
                    # Skip empty entries and the Thinker's "no further strategies" sentinel,
                    # mirroring the check used when generating the initial strategies.
                    if not new_strategy or new_strategy.strip().lower() == "no further strategies.":
                        continue
                    new_node = Node(strategy=new_strategy, solution="", evaluation="Expand", parent=current_node,
                                    question=question, config=config)
                    current_node.children.append(new_node)
            elif critic_score < config["abort_threshold"]:
                logging.info(
                    f"Critic score {critic_score} below abort threshold {config['abort_threshold']}. Marking node as Abort.")
                current_node.evaluation = "Abort"

    return {**state, "root": root}

def select_node_to_expand(root: Node) -> Node:
    """Select the node that the next ``expand_tree`` step should work on.

    The tree is searched depth-first from ``root``, visiting children in their
    stored order. At each visited node:

    1. if any child is ``"Expand"``, the one with the highest UCB is returned;
    2. else if any child is ``"Refine"``, the one with the highest UCB is returned;
    3. else the search descends, in order, into each child that is not ``"Abort"``.

    Aborted nodes are never returned and their subtrees are not searched. When
    no ``"Expand"``/``"Refine"`` node is reachable, ``root`` is returned.

    Args:
        root: Root of the search tree.

    Returns:
        The selected node, or ``root`` when nothing is selectable.
    """

    def _search(node: Node) -> Optional[Node]:
        # Best pending child of ``node`` or of one of its non-aborted descendants, else None.
        expandable_children = [child for child in node.children if child.evaluation == "Expand"]
        if expandable_children:
            logging.info("Selecting node to expand...")
            return max(expandable_children, key=lambda child: child.upper_confidence_bound())

        refinable_children = [child for child in node.children if child.evaluation == "Refine"]
        if refinable_children:
            logging.info("Selecting node to refine...")
            return max(refinable_children, key=lambda child: child.upper_confidence_bound())

        # A plain depth-first walk visits each node at most once. The old
        # "step back to the parent" logic re-entered children[0] and could bounce
        # between a node and its parent forever. It also never explored siblings.
        for child in node.children:
            if child.evaluation == "Abort":
                logging.info("Found abortable node. Skipping it.")
                continue
            found = _search(child)
            if found is not None:
                return found
        return None

    selected = _search(root)
    if selected is None:
        logging.info("No expandable or refinable node found. Returning root.")
        return root
    return selected

def should_continue(state: TreeState, max_depth: int = 5) -> str:
    """Decide whether the search loop should run another ``expand_tree`` step.

    Only the descendants of ``state["root"]`` are inspected (not the root itself).
    The checks run in order:

    1. any node ``"Accept"``: stop (``END``);
    2. any node ``"Expand"`` or ``"Refine"``: ``"expand"``;
    3. every node at depth >= ``max_depth`` (vacuously true when the root has no
       descendants): stop (``END``);
    4. otherwise: ``"expand"`` by default.

    Args:
        state: Tree state containing ``"root"``.
        max_depth: Depth at which the search is considered exhausted. Defaults
            to 5, the previously hard-coded value and ``expand_tree``'s default.

    Returns:
        ``END`` or ``"expand"``.
    """
    root = state["root"]
    # Walk the tree once instead of once per check.
    nodes = root._get_all_children()

    # If a solution is found, stop
    if any(node.evaluation == "Accept" for node in nodes):
        logging.info("Solution found. Stopping...")
        return END

    # If any node is marked for expansion, continue
    if any(node.evaluation == "Expand" for node in nodes):
        logging.info("Expansion required. Continuing...")
        return "expand"

    # If any node is marked for refinement, continue
    if any(node.evaluation == "Refine" for node in nodes):
        logging.info("Refinement required. Continuing...")
        return "expand"

    # If max depth is reached for all nodes, stop
    if all(node.depth >= max_depth for node in nodes):
        logging.info("Max depth reached. Stopping...")
        return END

    # Default: keep going. This can repeat indefinitely if expand_tree makes no
    # progress, so callers should bound the number of iterations.
    logging.info("No action determined. Continuing by default...")
    return "expand"

In [None]:
# Initialize configuration (tune these parameters)
config = {
    "execution_weight": 0.6,
    "critic_weight": 0.4,
    "error_weight": 0.2,
    "requirement_weight": 0.3,
    "verification_fail_score": 0.5,
    "abort_threshold": 0.4,
    "expansion_threshold": 0.6,
}
# Upper bound on expand_tree calls. should_continue() returns "expand" by
# default when it cannot decide, so an unbounded loop could spin forever.
MAX_SEARCH_ITERATIONS = 50

# Example query -- fill in the request to solve.
query = """
    """

# Initialize the LLM.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp", convert_system_message_to_human=True)
# Remote Gradio client (not a mock) used by the Critic and by extract_data().
client = Client("Qwen/Qwen2.5-Coder-Artifacts")

# Select the metadata needed for this query.
# NOTE(review): `filtered_metadata` is assumed to be defined in an earlier cell.
selector = MetadataSelector(model_name="gemini-2.0-flash-exp", max_iterations=2)
necessary_metadata = selector.select_metadata(query, filtered_metadata)
print("necessary metadata to answer the question :", necessary_metadata)

# Initialize the tree state
state = {"input": query}
state = generate_initial_strategies(state, llm, client, necessary_metadata, config)

# Main loop (bounded)
iteration = 0
while should_continue(state) == "expand":
    if iteration >= MAX_SEARCH_ITERATIONS:
        logging.warning(f"Reached MAX_SEARCH_ITERATIONS={MAX_SEARCH_ITERATIONS}. Stopping search.")
        break
    state = expand_tree(state, llm, client, necessary_metadata, config)
    iteration += 1

# Find the best solution
best_solution_node = state["root"].get_best_solution()

# A node object is always returned (the root at worst), so also require actual code.
if best_solution_node is not None and best_solution_node.solution:
    logging.info(f"Best solution found:\n{best_solution_node.solution}")
    if best_solution_node.evaluation == "Accept":
        logging.info("Solution was accepted by the Critic.")
    else:
        logging.info("Solution was not accepted, but it's the best found within the search limits.")

    # Extract data if needed
    extract_data(best_solution_node.solution, client)
else:
    logging.info("No solution found within the search limits.")