In [1]:
%%capture --no-stderr
%pip install -U --quiet langchain langgraph langchain_google_genai
%pip install -U --quiet saspy

ERROR: Could not install packages due to an OSError: [WinError 32] Le processus ne peut pas accéder au fichier car ce fichier est utilisé par un autre processus: 'c:\\users\\talelbm\\appdata\\local\\anaconda3\\lib\\site-packages\\saspy\\java\\iomclient\\log4j-1.2-api-2.12.4.jar'
Consider using the `--user` option or check the permissions.



In [1]:
import warnings
warnings.filterwarnings('ignore')
import math
import os
import getpass
import re
from collections import deque
from typing import Optional, Dict, Any, List, Tuple
from typing_extensions import TypedDict
import pandas as pd
from gradio_client import Client
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.output_parsers import PydanticOutputParser
from langchain_core.output_parsers import StrOutputParser
import google.generativeai as genai
from pydantic import BaseModel, Field
import logging
import saspy
import json
# Constants
# Module-level sentinel name "end" (presumably the terminal node/route of the agent flow defined later — confirm).
END: str = "end"

In [2]:
# Configure the google.generativeai SDK with the real API key.
# The key is read from the GOOGLE_API_KEY environment variable, prompting once (hidden input)
# if it is missing. Passing the literal string "GOOGLE_API_KEY" would authenticate with an
# invalid key and override any correctly exported value.
if not os.environ.get("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("GOOGLE_API_KEY")
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

In [3]:
def _set_if_undefined(var: str) -> None:
    """Set environment variable if not already defined."""
    if os.environ.get(var):
        return
    os.environ[var] = getpass.getpass(var)

# Make sure GOOGLE_API_KEY is in the environment; prompts for it (hidden input) only if it is not already set.
_set_if_undefined("GOOGLE_API_KEY")

GOOGLE_API_KEY ········


In [4]:
# Configure logging: INFO level and above; each record is followed by a dashed separator line.
_LOG_FORMAT = (
    "%(asctime)s - %(levelname)s - %(message)s\n"
    "-------------------------------------------"
)
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

In [5]:
# Load the complete metadata catalogue (parsed JSON) that the metadata-selection step works from.
metadata_file_path = 'metadata_dict.json'  # Assuming the file is in the same directory
try:
    with open(metadata_file_path, 'r', encoding='utf-8') as f:
        metadata_dict = json.load(f)
except (FileNotFoundError, json.JSONDecodeError) as load_error:
    # Log a cause-specific message, then let the original exception propagate.
    if isinstance(load_error, FileNotFoundError):
        logging.error(f"Error: Metadata file not found at {metadata_file_path}")
    else:
        logging.error(f"Error: Invalid JSON format in {metadata_file_path}")
    raise

In [None]:
# Open the SAS session (profile 'iomwin' from sascfg.py) that every SAS helper below submits to.
try:
    sas = saspy.SASsession(cfgname='iomwin', cfgfile='sascfg.py')
    logging.info("SAS session initialized successfully.")
except Exception as e:
    # Record why the session could not be opened, then stop the notebook run.
    logging.error(f"Error initializing SAS session: {e}")
    raise

In [7]:
class MetadataVerifier:
    """LLM-backed reviewer that checks whether a metadata selection is enough to answer a question.

    ``model`` only needs a ``generate_content(prompt)`` method whose result exposes ``.text``.
    """

    # Finds the first "approved <sep> true|false" verdict, tolerating spacing, case and
    # markdown/JSON punctuation ("approved : True", "Approved: true", "**approved:** True",
    # '"approved": true'). An exact substring match rejected approvals that were only
    # formatted differently; the prompt itself shows two spellings.
    _VERDICT_PATTERN = re.compile(r"\bapproved\b\W*(true|false)\b", re.IGNORECASE)

    def __init__(self, model):
        self.model = model

    @classmethod
    def _is_approved(cls, verdict: str) -> bool:
        """Return True iff the first approval flag found in ``verdict`` is affirmative."""
        match = cls._VERDICT_PATTERN.search(verdict or "")
        return bool(match) and match.group(1).lower() == "true"

    def verify_metadata(self, metadata: str, question: str, original_metadata: dict) -> tuple[bool, str]:
        """Ask the model whether ``metadata`` covers everything needed to answer ``question``.

        Args:
            metadata: The selected metadata subset (text) to review.
            question: The user's data query.
            original_metadata: The complete metadata catalogue, included for comparison.

        Returns:
            ``(approved, verdict_text)``. If the model call fails, ``(False, str(error))`` is
            returned instead of raising.
        """
        logging.info("Verifying metadata...")
        verification_prompt = f"""Act as a critical data analyst reviewing metadata selection. 
Verify if the selected metadata is contains all the sufficient and necessary databases and columns for answering the query. do not go further than that.
if the database is sufficient to answer the query, approve it.

Question: {question}
Selected Metadata (to be analyzed carefully):
{metadata}

Complete Available Metadata (to be analyzed carefully):
{original_metadata}

Check for these specific issues:
1. Missing Essential Columns:
   <redacted for privacy reasons>

2. Completeness Check:
   - All columns necessary for the query's calculations
   - All columns needed for filtering conditions
   - All columns needed for grouping or aggregation
   - All columns needed for the final output

3. Business Rules Verification:
   <redacted for privacy reasons>

Respond with a string containing:
    approved: boolean ("approved: True" or "approved : False"),
    issues: [list of specific issues found, strings with "..." format],
    missing columns: [list of essential missing columns, strings with "..." format],
    suggestions: [specific suggestions for improvement, strings with "..." format],
    criticism: "detailed explanation of why the metadata is or isn't optimal"
"""
        try:
            response = self.model.generate_content(verification_prompt)
            logging.info("Metadata verification completed.")
            verdict = response.text
            return self._is_approved(verdict), verdict
        except Exception as e:
            logging.error(f"Error during metadata verification: {e}")
            return False, str(e)

class CachedMetadata:
    """Process-wide cache of the metadata subset an LLM selected (and verified) for one question.

    Use :meth:`get_instance` rather than the constructor. The first call must supply both
    ``question`` and ``metadata_dict``. Later calls without arguments return the cached instance.
    A later call with a *different* question or metadata dictionary rebuilds the cache, because
    the selected subset is specific to the question it was generated for.
    """
    _instance = None

    def __init__(self, question, metadata_dict):
        """Select and verify the metadata subset for ``question`` (prefer :meth:`get_instance`).

        Raises:
            RuntimeError: if an instance already exists (singleton guard). This is a subclass
                of ``Exception``, so existing ``except Exception`` handlers still apply.
        """
        if CachedMetadata._instance is not None:
            logging.warning("Attempted to create a new instance of Singleton CachedMetadata.")
            raise RuntimeError("This class is a singleton!")

        self.model_gemini = genai.GenerativeModel("gemini-2.0-flash-exp")
        self.question = question
        self.metadata_dict = metadata_dict
        self._metadata = None
        self._verifier = MetadataVerifier(self.model_gemini)
        self._generate_metadata()

    @classmethod
    def get_instance(cls, question=None, metadata_dict=None):
        """Return the cached instance, creating or rebuilding it when needed.

        Args:
            question: Required on first use. On later calls, a value different from the
                cached question triggers a rebuild.
            metadata_dict: Required on first use. On later calls, a value different from the
                cached dictionary triggers a rebuild.

        Raises:
            ValueError: on first use without both arguments.
        """
        if cls._instance is None:
            if question is None or metadata_dict is None:
                logging.error("Metadata dictionary and question must be provided for the first initialization.")
                raise ValueError(
                    "Metadata dictionary and question must be provided for the first initialization.")
            logging.info("Creating new instance of CachedMetadata.")
            cls._instance = CachedMetadata(question, metadata_dict)
            return cls._instance

        current = cls._instance
        question_changed = question is not None and question != current.question
        metadata_changed = metadata_dict is not None and metadata_dict != current.metadata_dict
        if question_changed or metadata_changed:
            logging.info("Question or metadata changed; rebuilding CachedMetadata.")
            cls._instance = None  # lift the singleton guard so the constructor may run
            try:
                cls._instance = CachedMetadata(
                    question if question is not None else current.question,
                    metadata_dict if metadata_dict is not None else current.metadata_dict,
                )
            except Exception:
                cls._instance = current  # keep the previous, still-consistent cache on failure
                raise
            return cls._instance

        logging.info("Returning existing instance of CachedMetadata.")
        return cls._instance

    def _generate_metadata(self):
        """Ask the model for the minimal metadata subset, verify it, and regenerate once if rejected.

        Returns:
            The selected metadata text (also stored on ``self._metadata``).
        """
        logging.info("Generating metadata...")
        prompt = f"""Act as an expert data analyst. Given a query and a database's metadata, identify the smallest set of databases 
        and columns necessary to answer the query accurately. Ensure that only the essential databases and columns are included, and 
        consider any required joins or relationships between tables. 
        Follow these specific rules:
        - give back a string in this format : 
        ""
        relevant_database1 : relevant_database1_description
            relevant_column11 : relevant_column11_type
                                relevant_column11_description
                                relevant_column11_values (if they exist)
            ...
        ""
        
        <redacted for privacy reasons>
        
        Provide the selected subset of metadata as a string, including only the necessary databases and their associated columns, without any additional explanations."
        Here is the question: {self.question}
        Here is the metadata: {self.metadata_dict}"""

        try:
            response = self.model_gemini.generate_content(prompt)
            self._metadata = response.text
        except Exception as e:
            logging.error(f"Error during metadata generation: {e}")
            raise
        logging.info("Metadata generated.")

        # Verify immediately. Regeneration errors are logged by _regenerate_metadata itself, so
        # it is called outside the try above to avoid logging the same failure twice.
        is_approved, verification_result = self._verifier.verify_metadata(
            self._metadata, self.question, self.metadata_dict
        )
        if not is_approved:
            logging.warning("Initial metadata not approved. Regenerating...")
            self._regenerate_metadata(verification_result)
        return self._metadata

    def _regenerate_metadata(self, verification_result):
        """Regenerate the metadata subset once, feeding back the previous answer and the verifier's critique."""
        logging.info("Regenerating metadata...")
        regeneration_prompt = f"""Previous metadata selection was inadequate. Please regenerate the metadata selection addressing these issues:
        task:
        Act as an expert data analyst. Given a query and a database's metadata, identify the smallest set of databases 
        and columns necessary to answer the query accurately. Ensure that only the essential databases and columns are included, and 
        consider any required joins or relationships between tables. 
        Follow these specific rules:
        - give back a string in this format : 
        ""
        relevant_database1 : relevant_database1_description
            relevant_column11 : relevant_column11_type
                                relevant_column11_description
                                relevant_column11_values (if they exist)
            ...
        ""
        <redacted for privacy reasons>
        
        Provide the selected subset of metadata as a string, including only the necessary databases and their associated columns, without any additional explanations.
        The old response:
        {self._metadata}
        The full critique:
        {verification_result}

        Original question: {self.question}
        Available metadata: {self.metadata_dict}"""
        try:
            response = self.model_gemini.generate_content(regeneration_prompt)
            self._metadata = response.text
            logging.info("Metadata regenerated.")
        except Exception as e:
            logging.error(f"Error during metadata regeneration: {e}")
            raise

    def get_metadata(self):
        """Return the selected metadata text (None only if generation never completed)."""
        return self._metadata

In [8]:
# --- Utility Functions ---
def execute_sas_code(sas_code: str) -> str:
    """Submit ``sas_code`` to the global ``sas`` session and return the resulting SAS log.

    Never raises. Any exception from submitting, or from reading the ``'LOG'`` entry, is
    logged, and its message is returned in place of the log text.
    """
    logging.info("Executing SAS code...")
    try:
        submission = sas.submit(sas_code)
        logging.info("SAS code executed. Returning log.")
        sas_log = submission['LOG']
    except Exception as exc:
        logging.error(f"Error executing SAS code: {exc}")
        sas_log = str(exc)
    return sas_log

def _dataset_name_from_prediction(prediction) -> str:
    """Normalise the name-extraction endpoint's output into a bare SAS dataset name ('' if unusable)."""
    # gradio_client returns a bare value for single-output endpoints and a tuple for
    # multi-output ones. Indexing a bare string with [0] would keep only its first character.
    if isinstance(prediction, (list, tuple)):
        prediction = prediction[0] if prediction else ""
    name = str(prediction or "").strip().strip("`'\"").strip()
    # The model is asked to omit the library prefix but may not comply.
    if name.upper().startswith("WORK."):
        name = name[len("WORK."):]
    return name


def extract_data(sas_code: str, client) -> None:
    """Run ``sas_code`` in the global SAS session and export its final dataset to Excel.

    The dataset name is obtained by asking ``client`` (its ``/generation_code`` endpoint) to read
    it out of the code. The dataset is then loaded with ``sas.sd2df`` and written to
    ``<name>.xlsx`` in the working directory. All failures are logged, never raised.

    Args:
        sas_code: SAS program to execute.
        client: Object exposing ``predict(query=..., api_name=...)`` (a Gradio client here).
    """
    logging.info("Extracting data...")
    try:
        sas.submit(sas_code)
        df_name_result = client.predict(
            query=f"Extract final dataframe name from, dont include 'WORK.', just the raw name:\n{sas_code}",
            api_name="/generation_code"
        )
        df_name = _dataset_name_from_prediction(df_name_result)
        if not df_name:
            logging.warning("Could not determine dataframe name for extraction.")
            return

        df = sas.sd2df(df_name)
        if df is None:
            # Reported explicitly instead of surfacing as an AttributeError on .empty.
            logging.warning(f"Dataframe {df_name} could not be loaded from SAS.")
        elif df.empty:
            logging.warning(f"Dataframe {df_name} is empty.")
        else:
            excel_file = f"{df_name}.xlsx"
            df.to_excel(excel_file, index=False)
            logging.info(f"Data extracted to {excel_file}")
    except Exception as e:
        logging.error(f"Error during data extraction: {e}")

In [9]:
# --- Prompt Templates and System Messages ---
def get_thinker_system_prompt(necessary_metadata):
    """Build the system prompt for the StarThinker (strategy-generation) agent.

    The prompt asks for high-level strategies only (no code), in a
    ``strategy N : <details>`` format.

    Args:
        necessary_metadata: Metadata text appended verbatim under the "**Metadata:**" heading.

    Returns:
        The complete system prompt string.
    """
    return f"""
You are StarThinker, a strategic module of the StarData AI coding agent.
You have a basic understanding of actuary line of thinking, especially concerning automobile insurance.
Your primary role is to analyze a given data query and devise multiple, distinct strategies for answering it using SAS code. You do not generate code, only high-level strategies.

**Your Task:**

- Given a data query and metadata, generate distinct and exhaustive strategies outlining how to approach the problem using the available data in the SAS databases.
- Strategies should be concise (1-2 sentences each) and explore different logical paths to the solution.
- Consider various ways to utilize the provided metadata, which includes database names, descriptions, column names, descriptions, types, and potential values.
- Dynamically determine the appropriate number of strategies based on the complexity of the query. Stop generating new strategies if you deem them to be sufficient and exhaustive.
- give your response in this format :
    'strategy 1 : <details of this strategy>\n
    strategy 2 : <details of this strategy>\n
    ...'
**Key Considerations from StarData:**

<redacted for privacy reasons>

**Metadata:**

{necessary_metadata}
"""

def get_solver_system_prompt(necessary_metadata):
    """Build the system prompt for the StarSolver (SAS code-generation) agent.

    The prompt instructs the model to output only SAS code and to keep new tables in the
    WORK library.

    Args:
        necessary_metadata: Metadata text appended verbatim under the "**Metadata:**" heading.

    Returns:
        The complete system prompt string.
    """
    return f"""
You are StarSolver, a code generation module of the StarData AI coding agent.
You have a basic understanding of actuary line of thinking, especially concerning automobile insurance.
You are an expert in SAS programming. Your task is to generate SAS code that accurately answers data queries based on provided strategies and metadata.

**Your Task:**

- Given a data query and a specific strategy, generate a complete and executable SAS code solution.
- The code should be syntactically correct and produce the desired output based on the query and strategy.
- Output ONLY the SAS code, without any explanations or additional text.
- Save all new dataframes in the WORK library (temporary).
- When doing union joins, select only relevant columns.
- Adhere to all coding instructions and guidelines specified below.

**Coding Instructions from StarData:**

<redacted for privacy reasons>

**Metadata:**

{necessary_metadata}
"""

def get_debugger_system_prompt(necessary_metadata):
    """Build the system prompt for the StarDebugger (SAS code-refinement) agent.

    The prompt asks the model to fix a SAS solution from the given feedback and to output
    only the refined code.

    Args:
        necessary_metadata: Metadata text appended verbatim under the "**Metadata:**" heading.

    Returns:
        The complete system prompt string.
    """
    return f"""
You are StarDebugger, a code refinement module of the StarData AI coding agent.
You have a basic understanding of actuary line of thinking, especially concerning automobile insurance.
You are an expert in SAS programming and debugging. Your task is to refine and improve SAS code solutions based on provided feedback.

**Your Task:**

- Given a SAS code solution, and feedback (which may include error messages, requirement shortcomings, or suggestions), generate a corrected and improved version of the code.
- The refined code should address all issues mentioned in the feedback and produce the desired output according to the original query.
- Output ONLY the refined SAS code, without any explanations or additional text.

**Coding and Debugging Instructions from StarData:**
<redacted for privacy reasons>
**Metadata:**

{necessary_metadata}
"""

def get_critic_system_prompt(necessary_metadata):
    """Build the system prompt for the StarCritic (solution-evaluation) agent.

    The prompt defines the evaluation criteria, the Refine/Abort/Accept decision logic, and
    domain-specific rules (GOUVERNORAT value mapping, sinistre lookup across yearly views,
    Défense/Recours filtering, and the afin column preference).

    Args:
        necessary_metadata: Metadata text appended verbatim under the "**Metadata:**" heading.

    Returns:
        The complete system prompt string.
    """
    return f"""
You are StarCritic, a solution evaluator module of the StarData AI coding agent.
You have a basic understanding of actuary line of thinking, especially concerning automobile insurance.
You are an expert in SAS programming and analysis. Your task is to evaluate SAS code solutions generated by the Solver agent, provide feedback, and assess their suitability for refinement or acceptance.

**Your Task:**

- Given a SAS code solution, its corresponding strategy, and the original query, evaluate the solution's correctness, adherence to the strategy, and overall quality.
- Provide a numerical score (critic_score) that reflects the solution's quality and potential for improvement.
- Generate specific feedback (feedback) that identifies errors, missing requirements, and areas for improvement.
- Determine whether the solution should be refined, aborted, or accepted based on your evaluation.

**Evaluation Criteria:**

- **Correctness:** Does the code execute without errors? Does it produce the expected output for visible test cases?
- **Strategy Adherence:** Does the code effectively implement the given strategy? Does it logically follow the strategy and use appropriate data structures and algorithms?
- **Robustness:** Is the solution likely to be correct for unseen test cases? Does it handle potential edge cases and demonstrate a general understanding of the problem?
- **Requirements Fulfillment:** Does the code meet all the requirements stated in the original query? Are there any missing functionalities or discrepancies?

**Decision Logic:**

- **Refine:** If the solution has errors, fails to meet requirements, or does not fully adhere to the strategy, it should be refined.
- **Abort:** If the solution has major flaws, scores very low on the evaluation criteria, or is unlikely to be improved with further refinement, it should be aborted.
- **Accept:** If the solution passes all visible test cases, adheres to the strategy, meets all requirements, and is deemed robust, it should be accepted as the final solution.

**Additional Instructions:**

- Consider the provided metadata when evaluating the solution's correctness and requirements fulfillment.
- Be specific in your feedback, pointing out the exact lines of code or functionalities that need improvement.
- Use a combination of numerical scores and qualitative feedback to provide a comprehensive evaluation.
- When dealing with the GOUVERNORAT COLUMN, automatically make these changes: "LE KEF" to "KEF", "MEHDIA" to "MAHDIA", and "MANOUBA" to "MANNOUBA".
- A sinistre is an accident, identified by a key and associated with its police. A sinistre can happen in year x and not get reported until later years. So when looking for a sinistre that happened in year x, you need to look for all the vue_sinistres_y where y>=x. RADHOUAN.VUE_SINISTRE is basically vue_sinistre_2024.
- When asked about characteristics of a sinistre that don't exist in the vue_sinistre database (e.g., gouvernorat or classe bm of the associated policy), you have to look for these details in the dataframes associated with the policies and make the due join operations.
- When asked about a history of a sinistre, you need to first verify if it exists in RADHOUAN.VUE_HIST_SINISTRE.
- When making unions make sure that the columns of the same name are of similar type too.
- When asked for filtering by 'Défense' or 'Recours' automatically filter by column TYPE_DOSSIER_AFIN and try to find the lines that contain 'Défense' or 'Recours'
- When asked for calculation that use reglements columns, always use reglements afin, same with reserve, use reserve_afin unless explicitly stated otherwise in the query.

**Metadata:**

{necessary_metadata}
"""

In [19]:
# --- Agent Classes ---
class Thinker:
    def __init__(self, llm, necessary_metadata):
        self.llm = llm
        self.necessary_metadata = necessary_metadata
        self.strategies_cache: Dict[str, List[str]] = {}
        self.reflections_cache: Dict[Tuple[str, str, str, str], List[str]] = {}

    def generate_strategies(self, question: str, max_strategies: int = 5, previous_strategies: List[str] = None) -> List[str]:
        """Generates multiple strategies for solving the coding problem autoregressively."""
        logging.info(f"Generating strategies for question: {question}")
        if question in self.strategies_cache:
            logging.info(f"Using cached strategies for question: {question}")
            return self.strategies_cache[question]

        strategies = []
        
        if previous_strategies is not None:
            strategies = previous_strategies

        for i in range(max_strategies):
            prompt = self._build_strategy_prompt(question, strategies)
            try:
                response = (prompt | self.llm | StrOutputParser()).invoke({
                    "question": question,
                    "necessary_metadata": self.necessary_metadata,
                    "previous_strategies": "\n".join(strategies)
                })
            except Exception as e:
                logging.error(f"Error during strategy generation: {e}")
                break

            new_strategy = self._extract_strategy(response)
            if not new_strategy or new_strategy.lower() == "no further strategies.":
                logging.info("No further strategies generated.")
                break

            strategies.append(new_strategy)
            logging.info(f"Generated strategy {i+1}: {new_strategy}")

        self.strategies_cache[question] = strategies
        return strategies

    def _build_strategy_prompt(self, question: str, previous_strategies: List[str]) -> ChatPromptTemplate:
        """Builds the prompt for generating the next strategy."""
        system_prompt = get_thinker_system_prompt(self.necessary_metadata)

        if previous_strategies:
            system_prompt += "\n\nPrevious Strategies:\n" + "\n".join(
                [f"{i+1}. {strategy}" for i, strategy in enumerate(previous_strategies)]
            )
        system_prompt += "\n\nGenerate the next strategy, or write 'No further strategies.' if no further strategies can be generated."

        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("user", f"Question: {question}\nStrategy:")
        ])
        return prompt

    def _extract_strategy(self, response: str) -> str:
        """Extracts a strategy from the LLM's response."""
        # Use regex to find the strategy after 'strategy x :'
        match = re.search(r"strategy \d+ : (.*)", response, re.IGNORECASE)
        if match:
            return match.group(1).strip()
        else:
            return ""

    def generate_reflections(self, question: str, solution: str, feedback: str, log: str, strategy: str,
                             num_reflections: int = 3) -> List[str]:
        """Generates reflections on the code autoregressively."""
        logging.info(f"Generating reflections for question: {question}")
        cache_key = (question, solution, feedback, log)
        if cache_key in self.reflections_cache:
            logging.info(f"Using cached reflections for question: {question}")
            return self.reflections_cache[cache_key]

        reflections = []
        for _ in range(num_reflections):
            prompt = self._build_reflection_prompt(question, solution, feedback, log, strategy, reflections)
            try:
                response = (prompt | self.llm | StrOutputParser()).invoke({
                    "question": question,
                    "solution": solution,
                    "feedback": feedback,
                    "log": log,
                    "strategy": strategy,
                    "necessary_metadata": self.necessary_metadata,
                    "previous_reflections": "\n".join(reflections)
                })
            except Exception as e:
                logging.error(f"Error during reflection generation: {e}")
                break

            new_reflection = response.strip()
            if not new_reflection or new_reflection.lower() == "no further reflections.":
                logging.info("No further reflections generated.")
                break
            reflections.append(new_reflection)
            logging.info(f"Generated reflection: {new_reflection}")

        self.reflections_cache[cache_key] = reflections
        return reflections

    def _build_reflection_prompt(self, question: str, solution: str, feedback: str, log: str, strategy: str,
                                 previous_reflections: List[str]) -> ChatPromptTemplate:
        """Builds the prompt for generating the next reflection."""
        system_prompt = get_thinker_system_prompt(self.necessary_metadata)
        system_prompt += f"\n\nStrategy used: {strategy}"
        system_prompt += "\n\nYour task is to generate reflections on the following SAS code solution, considering the feedback and log provided. "
        system_prompt += "Reflections should be concise and focus on identifying areas for improvement or issues in the code."

        if previous_reflections:
            system_prompt += "\n\nPrevious Reflections:\n" + "\n".join(
                [f"{i+1}. {reflection}" for i, reflection in enumerate(previous_reflections)]
            )
            system_prompt += "\n\nGenerate the next reflection, or write 'No further reflections.' if no further reflections are needed"

        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("user", f"Given the question: '{question}', the strategy: '{strategy}', the following SAS code solution: \n"
                     f"```sas\n{solution}\n```\n"
                     f"and the feedback from the Critic: '{feedback}',\n"
                     f"and the log from the code execution : '{log}', \n"
                     f"please generate reflections to guide the code refinement process.")
        ])
        return prompt

class Solver:
    """StarSolver agent: generates an initial SAS program for a (question, strategy) pair."""

    def __init__(self, llm, necessary_metadata):
        self.llm = llm
        self.necessary_metadata = necessary_metadata
        # (question, strategy) -> generated SAS code; only non-empty successes are stored
        self.solutions_cache: Dict[Tuple[str, str], str] = {}

    @staticmethod
    def _escape_braces(text: str) -> str:
        """Double ``{``/``}`` so ChatPromptTemplate keeps them literal instead of parsing placeholders."""
        # Metadata and strategies may contain braces (as does SAS syntax such as ``array v{*}``).
        return text.replace("{", "{{").replace("}", "}}")

    def generate_solution(self, question: str, strategy: str) -> str:
        """Generates an initial SAS code solution based on the given strategy.

        Args:
            question: The data query to answer.
            strategy: High-level approach the code should follow.

        Returns:
            The model's SAS code (stripped), or "" if the LLM call failed. Failures and empty
            answers are not cached, so a later call retries instead of reusing "".
        """
        logging.info(f"Generating solution for question: {question} with strategy: {strategy}")
        cache_key = (question, strategy)
        if cache_key in self.solutions_cache:
            logging.info(f"Using cached solution for question: {question} and strategy: {strategy}")
            return self.solutions_cache[cache_key]

        prompt = ChatPromptTemplate.from_messages([
            ("system", self._escape_braces(get_solver_system_prompt(self.necessary_metadata))),
            ("user", self._escape_braces(f"Given the strategy: '{strategy}', please write SAS code to answer this question: {question}"))
        ])
        try:
            response = (prompt | self.llm | StrOutputParser()).invoke({
                "question": question,
                "strategy": strategy,
                "necessary_metadata": self.necessary_metadata
            })
            solution = self._extract_code(response)
            logging.info(f"Generated solution:\n{solution}")
        except Exception as e:
            logging.error(f"Error during solution generation: {e}")
            return ""

        if solution:
            self.solutions_cache[cache_key] = solution
        return solution

    def _extract_code(self, response: str) -> str:
        """Extracts the SAS code from the LLM's response (surrounding whitespace removed).

        Markdown fences such as ```sas ... ``` are intentionally left in place: the evaluation
        step in this notebook looks for a ```sas fence and strips it itself.
        """
        return response.strip()

class Debugger:
    """StarDebugger agent: refines a SAS program using reflections/feedback."""

    def __init__(self, llm, necessary_metadata):
        self.llm = llm
        self.necessary_metadata = necessary_metadata
        # (question, solution, reflections-as-text) -> refined code; only non-empty successes are stored
        self.debugged_solutions_cache: Dict[Tuple[str, str, str], str] = {}

    @staticmethod
    def _escape_braces(text: str) -> str:
        """Double ``{``/``}`` so ChatPromptTemplate keeps them literal instead of parsing placeholders."""
        # SAS code (``array v{*}``), reflections and metadata may all contain braces.
        return text.replace("{", "{{").replace("}", "}}")

    def generate_refinement(self, question: str, solution: str, reflections: str) -> str:
        """Refines the given SAS code solution based on reflections.

        Args:
            question: The original data query. It is now included in the prompt, because the
                system prompt asks for output matching "the original query".
            solution: SAS code to refine.
            reflections: Reflection text. A list/tuple of reflections (as returned by
                ``Thinker.generate_reflections``) is joined with newlines.

        Returns:
            The refined SAS code (stripped), or "" if the LLM call failed. Failures and empty
            answers are not cached.
        """
        logging.info(f"Generating refinement for question: {question}")
        # A list would make the cache key unhashable; normalise it to text first.
        if isinstance(reflections, (list, tuple)):
            reflections = "\n".join(str(r) for r in reflections)
        cache_key = (question, solution, reflections)
        if cache_key in self.debugged_solutions_cache:
            logging.info(f"Using cached refined solution for question: {question}")
            return self.debugged_solutions_cache[cache_key]

        user_message = (
            f"Original question: {question}\n"
            f"Given the reflections: '{reflections}', please refine the following SAS code:\n```sas\n{solution}\n```"
        )
        prompt = ChatPromptTemplate.from_messages([
            ("system", self._escape_braces(get_debugger_system_prompt(self.necessary_metadata))),
            ("user", self._escape_braces(user_message))
        ])
        try:
            response = (prompt | self.llm | StrOutputParser()).invoke({
                "question": question,
                "solution": solution,
                "reflections": reflections,
                "necessary_metadata": self.necessary_metadata
            })
            refined_solution = self._extract_code(response)
            logging.info(f"Generated refined solution:\n{refined_solution}")
        except Exception as e:
            logging.error(f"Error during refinement generation: {e}")
            return ""

        if refined_solution:
            self.debugged_solutions_cache[cache_key] = refined_solution
        return refined_solution

    def _extract_code(self, response: str) -> str:
        """Extracts the SAS code from the LLM's response (surrounding whitespace removed).

        Markdown fences are left in place; the evaluation step in this notebook strips a
        ```sas fence itself.
        """
        return response.strip()

class Critic:
    """Evaluates generated SAS code and decides whether to accept, refine, or abort it.

    For each solution it executes the code (``execute_sas_code``), asks the LLM for three
    structured reflections (errors, requirements, strategy adherence), combines them into a
    score and, for error-free solutions that meet the requirements, runs an extra LLM
    verification step.

    ``config`` must provide ``error_weight``, ``requirement_weight``,
    ``verification_fail_score`` and ``abort_threshold``.
    """

    def __init__(self, llm, client, necessary_metadata, config: Dict):
        self.llm = llm
        self.client = client  # stored for callers; not used by the methods of this class
        self.necessary_metadata = necessary_metadata
        # (question, solution, strategy) -> (evaluation, critic_score, feedback)
        self.evaluation_cache = {}
        self.config = config

    @staticmethod
    def _strip_sas_fence(text: str) -> str:
        """Returns the body of the first ```sas fenced block in ``text`` (stripped).

        If ``text`` contains no ```sas marker it is returned unchanged. If the fence is never
        closed, everything after the opening marker is returned.
        """
        marker = "```sas"
        start = text.find(marker)
        if start == -1:
            return text
        start += len(marker)
        end = text.find("```", start)
        body = text[start:end] if end != -1 else text[start:]
        return body.strip()

    @staticmethod
    def _parse_verdict(response: str) -> Optional[bool]:
        """Returns True/False for the first standalone 'true'/'false' word in ``response``.

        Returns None when neither word appears. Using the first whole word (instead of a
        substring search) prevents explanations such as "False ... does not hold true" from
        being read as a positive verdict.
        """
        match = re.search(r"\b(true|false)\b", response, flags=re.IGNORECASE)
        if match is None:
            return None
        return match.group(1).lower() == "true"

    def evaluate_solution(self, question: str, solution: str, strategy: str) -> Tuple[str, float, str]:
        """Evaluates the generated solution, provides feedback, and assigns a numerical score.

        Returns:
            ``(evaluation, critic_score, feedback)`` where ``evaluation`` is "Accept", "Refine"
            or "Abort". Results are cached per ``(question, solution, strategy)``.

        If ``solution`` contains a ```sas fenced block, only the fenced code is executed and
        reviewed; otherwise the text is used as-is.

        Raises:
            Exception: errors raised while generating the reflections are logged and re-raised.
        """
        logging.info(f"Evaluating solution for question: {question}")
        cache_key = (question, solution, strategy)
        if cache_key in self.evaluation_cache:
            logging.info(f"Using cached evaluation for question: {question}")
            return self.evaluation_cache[cache_key]

        solution = self._strip_sas_fence(solution)

        log = execute_sas_code(solution)
        try:
            error_reflection = self._get_error_reflection(question, solution, log)
            requirement_reflection = self._get_requirement_reflection(question, solution, log)
            strategy_adherence_reflection = self._get_strategy_adherence_reflection(question, solution, strategy)
        except Exception as e:
            logging.error(f"Error during reflections generation : {e}")
            raise

        # --- More Detailed Evaluation ---
        evaluation = ""
        critic_score = 0.0
        feedback = ""

        # NOTE(review): the error/requirement weights are *added* to critic_score (raising it)
        # before it is averaged with the adherence score; confirm they were not meant as penalties.
        if error_reflection.error_count > 0:
            feedback += "Errors:\n"
            feedback += f"- {error_reflection.error_log}\n"
            critic_score += self.config["error_weight"]
            evaluation = "Refine"

        if not requirement_reflection.requirements_met:
            feedback += "Missing Requirements:\n"
            feedback += f"- {requirement_reflection.missing_requirements}\n"
            critic_score += self.config["requirement_weight"]
            evaluation = "Refine"

        # Strategy Adherence Score
        adherence = strategy_adherence_reflection.adherence_score
        if evaluation == "Refine":
            critic_score = (critic_score + adherence) / 2
        else:
            critic_score = adherence

        feedback += f"Strategy Adherence: {strategy_adherence_reflection.adherence_score:.2f} - {strategy_adherence_reflection.reasoning}\n"

        # --- Decision Logic ---
        if error_reflection.error_count == 0 and requirement_reflection.requirements_met:
            # Solution passes visible tests, perform verification
            try:
                if self.verify_solution(question, solution):
                    evaluation = "Accept"
                    critic_score = 1.0  # Perfect score if it passes verification
                else:
                    evaluation = "Refine"
                    feedback += "Solution passes visible tests but fails verification (potential overfitting or lack of robustness).\n"
                    critic_score = self.config["verification_fail_score"]
            except Exception as e:
                logging.error(f"Error during solution verification : {e}")
                evaluation = "Refine"
                feedback += "Solution passes visible tests but fails verification (potential overfitting or lack of robustness).\n"
                critic_score = self.config["verification_fail_score"]
        elif critic_score < self.config["abort_threshold"]:
            evaluation = "Abort"

        logging.info(f"Evaluation: {evaluation}, Critic Score: {critic_score}, Feedback: {feedback}")
        self.evaluation_cache[cache_key] = (evaluation, critic_score, feedback)
        return evaluation, critic_score, feedback

    def verify_solution(self, question: str, solution: str) -> bool:
        """Asks the LLM whether a solution that passes visible tests is likely to generalize.

        Returns True only when the first standalone true/false word in the reply is "true";
        any failure during the LLM call is logged and treated as False.
        """
        logging.info(f"Verifying solution for question: {question}")
        # `question` and `solution` are template variables (filled by invoke) so braces inside the
        # SAS code or query are not parsed as prompt placeholders.
        verification_prompt = ChatPromptTemplate.from_messages([
            ("system", get_critic_system_prompt(self.necessary_metadata)),
            ("user", """
        We have a SAS code solution that passes all visible test cases for the query: '{question}'.
        ```sas
        {solution}
        ```
        Your task is to assess whether this solution is likely to be correct for unseen test cases as well. 
        Consider the following:
        - Does the code seem overly tailored to the specific visible test cases, or does it demonstrate a general understanding of the problem?
        - Are there any potential edge cases or scenarios not covered by the visible tests that the code might fail on?

        Answer with 'True' if the solution is likely to be correct for unseen test cases, and 'False' otherwise. Provide a brief explanation for your assessment.
        """)
        ])
        try:
            response = (verification_prompt | self.llm | StrOutputParser()).invoke({
                "question": question,
                "solution": solution,
                "necessary_metadata": self.necessary_metadata
            })

            logging.info(f"Raw verification response: {response}")

            cleaned_response = response.strip().lower()
            logging.info(f"Cleaned verification response: {cleaned_response}")

            if self._parse_verdict(cleaned_response) is True:
                logging.info("Solution verification successful.")
                return True
            logging.warning(f"Solution verification failed: {response}")
            return False

        except Exception as e:
            logging.error(f"Error during solution verification: {e}")
            return False

    def _get_error_reflection(self, query: str, sas_code: str, log: str) -> "ErrorReflection":
        """
        Identifies errors in the SAS log and provides an ErrorReflection.
        Now also tries to categorize the error.
        """
        logging.info("Getting error reflection...")
        parser = PydanticOutputParser(pydantic_object=ErrorReflection)

        prompt = ChatPromptTemplate.from_messages([
            ("system",
             "You are a helpful assistant that identifies and categorizes errors in SAS code execution logs. {format_instructions}"),
            ("user", "Query: {query}\nCode:\n```sas\n{sas_code}\n```\nLog:\n{log}")
        ])
        try:
            response = (prompt | self.llm | parser).invoke({
                "query": query,
                "sas_code": sas_code,
                "log": log,
                "format_instructions": parser.get_format_instructions()
            })
            logging.info(
                f"Error reflection: Error Count: {response.error_count}, Error Log: {response.error_log}, Error Category: {response.error_category}")
            return response
        except Exception as e:
            logging.error(f"Error during error reflection generation : {e}")
            raise

    def _get_requirement_reflection(self, query: str, sas_code: str, log: str) -> "RequirementReflection":
        """
        Determines whether the SAS code meets the requirements of the original query and provides a RequirementReflection.
        Uses a PydanticOutputParser for structured output. ``log`` is accepted but not used in the prompt.
        """
        logging.info("Getting requirement reflection...")
        parser = PydanticOutputParser(pydantic_object=RequirementReflection)

        # All dynamic values are template variables; interpolating the metadata (e.g. a dict repr)
        # or code directly would inject braces that the template engine treats as placeholders.
        prompt = ChatPromptTemplate.from_messages([
            ("system", """
            You are a helpful assistant that identifies missing requirements in SAS code.
            Use the metadata to understand the database and ensure the requirements are aligned with the database's structure.
            {format_instructions}
            metadata : {necessary_metadata}"""),
            ("user", """
            Original Query: {query}
            Generated SAS Code:
            ```sas
            {sas_code}
            ```
            Does the SAS code fulfill all requirements stated in the original query?
            Answer with 'True' if everything is generally satisfying to the average user, only give back "False" if there flagrant error of totally not understanding the query.
            """)
        ])
        try:
            response = (prompt | self.llm | parser).invoke({
                "query": query,
                "sas_code": sas_code,
                "necessary_metadata": self.necessary_metadata,
                "format_instructions": parser.get_format_instructions()
            })
            logging.info(
                f"Requirement reflection: Missing Requirements: {response.missing_requirements}, Requirements Met: {response.requirements_met}")
            return response
        except Exception as e:
            logging.error(f"Error during requirement reflection generation : {e}")
            raise

    def _get_strategy_adherence_reflection(self, query: str, sas_code: str,
                                            strategy: str) -> "StrategyAdherenceReflection":
        """
        Evaluates the adherence of the SAS code to the given strategy and provides a StrategyAdherenceReflection.
        Uses a PydanticOutputParser for structured output, including a numerical score.
        """
        logging.info("Getting strategy adherence reflection...")
        parser = PydanticOutputParser(pydantic_object=StrategyAdherenceReflection)

        # Dynamic values are template variables so braces in them cannot break the template.
        prompt = ChatPromptTemplate.from_messages([
            ("system", """
            You are a helpful assistant that evaluates the adherence of SAS code to a given strategy.
            Use the metadata to understand the database and ensure the code aligns with the strategy and the database's structure.
            {format_instructions}
            metadata : {necessary_metadata}
            """),
            ("user", """
            Original Query: {query}
            Strategy: {strategy}
            Generated SAS Code:
            ```sas
            {sas_code}
            ```
            Assess the adherence of the SAS code to the given strategy. Consider:
            1. How well does the code logically follow the strategy?
            2. Does the code use appropriate data structures and algorithms as suggested by the strategy?
            3. Are there any deviations from the strategy, and if so, are they justified?
            Provide a numerical adherence score between 0 and 1, where 1 represents perfect adherence and 0 represents no adherence.
            """)
        ])

        try:
            response = (prompt | self.llm | parser).invoke({
                "query": query,
                "sas_code": sas_code,
                "strategy": strategy,
                "necessary_metadata": self.necessary_metadata,
                "format_instructions": parser.get_format_instructions()
            })
            logging.info(
                f"Strategy adherence reflection: Adherence Score: {response.adherence_score}, Reasoning: {response.reasoning}")
            return response
        except Exception as e:
            logging.error(f"Error during strategy adherence reflection generation: {e}")
            raise

# Structured output of the Critic's error-reflection LLM call (parsed with PydanticOutputParser).
# NOTE(review): avoid adding a class docstring here — pydantic typically uses it as the JSON-schema
# description, which would change the format instructions sent to the LLM.
class ErrorReflection(BaseModel):
    # Critic.evaluate_solution treats any value > 0 as "the code has errors".
    error_count: int = Field(description="Number of errors in SAS log")
    # Copied verbatim into the Critic's feedback text.
    error_log: str = Field(description="Portion of the log that contains error messages")
    # Free-text category; only logged in this notebook section.
    error_category: str = Field(..., description="Category of the error (e.g., syntax, runtime, logical)")

# Structured output of the Critic's requirement-reflection LLM call (parsed with PydanticOutputParser).
# NOTE(review): no class docstring on purpose — it would likely alter the generated format instructions.
class RequirementReflection(BaseModel):
    # Copied into the feedback when requirements_met is False.
    missing_requirements: str = Field(description="Requirements not met in the code")
    # False sends the solution to "Refine" (or "Abort" if the score drops below the threshold).
    requirements_met: bool = Field(description="Whether all requirements are met")

# Structured output of the Critic's strategy-adherence LLM call (parsed with PydanticOutputParser).
# NOTE(review): no class docstring on purpose — it would likely alter the generated format instructions.
class StrategyAdherenceReflection(BaseModel):
    # Expected in [0, 1] per the description, but no range constraint is enforced on parse.
    adherence_score: float = Field(...,
                                   description="Numerical score between 0 and 1 representing adherence to the strategy")
    # Included in the Critic's feedback next to the score.
    reasoning: str = Field(..., description="Explanation of the adherence score, including any deviations from the strategy")

In [15]:
# --- Node and TreeState ---
class Node:
    """A node of the strategy/solution search tree.

    Each node holds one strategy, the SAS solution generated for it (possibly empty), the
    Critic's evaluation label (this notebook uses "Expand", "Refine", "Abort" and "Accept")
    and running statistics (``visits``, ``value``, ``score``). Constructing a node immediately
    backpropagates a reward of 1 for an "Accept" node and 0 otherwise, which also updates every
    ancestor; for nodes with a non-empty solution this re-runs the SAS code via ``calculate_score``.
    """

    def __init__(
            self,
            strategy: str,
            solution: str,
            evaluation: str,
            config: Dict,
            critic_score: float = 0,
            parent: Optional["Node"] = None,
            question: Optional[str] = None,

    ):
        self.strategy = strategy
        self.solution = solution
        self.parent = parent
        self.children = []
        self.evaluation = evaluation
        self.visits = 0
        self.value = 0  # running mean of backpropagated rewards
        self.critic_score = critic_score
        self.score = 0
        self.depth = parent.depth + 1 if parent is not None else 1
        # Inherit the question from the parent when not given explicitly.
        self.question = question if question is not None else (parent.question if parent is not None else None)
        self.config = config
        if self.evaluation == "Accept":
            self.backpropagate(1, self.config)
        else:
            self.backpropagate(0, self.config)

    def calculate_score(self, config: Dict) -> float:
        """Calculates the node's score based on execution results and Critic's evaluation.

        The execution score is the fraction of "passed:"/"test ok" log lines that indicate a
        pass (0 when there is no solution or no such lines). The result is
        ``config["execution_weight"] * execution + config["critic_weight"] * critic_score``.
        """
        logging.info(f"Calculating score for node with strategy: {self.strategy}")
        execution_score = 0

        if self.solution:  # Only calculate if a solution exists
            log = execute_sas_code(self.solution)
            num_tests = 0
            passed_tests = 0

            # Example: Simple pass/fail scoring (adapt based on your test case format)
            for line in log.splitlines():
                if "passed:" in line.lower():
                    num_tests += 1
                    if "true" in line.lower():
                        passed_tests += 1
                if "test ok" in line.lower():
                    num_tests += 1
                    passed_tests += 1

            if num_tests > 0:
                execution_score = passed_tests / num_tests

        self.score = config["execution_weight"] * execution_score + config["critic_weight"] * self.critic_score
        logging.info(f"Node score calculated: {self.score}")
        return self.score

    def upper_confidence_bound(self, exploration_weight: float = math.sqrt(2)) -> float:
        """UCB-style priority used to choose among sibling nodes.

        Unvisited nodes get +inf. A parentless (root) node has no exploration term, so only the
        exploitation term is returned instead of failing on ``self.parent.visits``.
        """
        if self.visits == 0:
            return float('inf')
        # NOTE(review): `value` is already a running mean, so dividing by visits again shrinks it
        # further than textbook UCB1 — confirm this is intended.
        exploitation = self.value / self.visits
        if self.parent is None:
            return exploitation
        exploration = math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploitation + exploration_weight * exploration

    def backpropagate(self, reward: float, config: Dict) -> None:
        """Updates the node's value and propagates it to its ancestors.

        Each node on the path to the root gets one more visit, its running-mean ``value`` updated
        with ``reward``, and its ``score`` set to 1 ("Accept") or recalculated otherwise.
        """
        logging.info(f"Backpropagating reward: {reward} for node with strategy: {self.strategy}")
        node = self
        while node:
            node.visits += 1
            node.value = (node.value * (node.visits - 1) + reward) / node.visits
            if node.evaluation == "Accept":
                node.score = 1
            else:
                node.calculate_score(config)  # Recalculate score during backpropagation
            node = node.parent

    def select_child(self) -> "Node":
        """Selects the child with the highest UCB (None when there are no children)."""
        logging.info(f"Selecting child for node with strategy: {self.strategy}")
        best_child = None
        best_ucb = float('-inf')
        for child in self.children:
            ucb = child.upper_confidence_bound()
            if ucb > best_ucb:
                best_ucb = ucb
                best_child = child
        logging.info(f"Selected child with strategy: {best_child.strategy if best_child else 'None'}")
        return best_child

    def is_terminal(self) -> bool:
        return self.evaluation == "Accept"

    @staticmethod
    def _solution_rank(node: "Node") -> Tuple[int, float]:
        """Sort key for get_best_solution: Accept nodes (by value) > nodes with code (by score) > empty nodes."""
        if node.evaluation == "Accept":
            return (2, node.value)
        if node.solution:
            return (1, node.score)
        return (0, 0.0)

    def get_best_solution(self) -> "Node":
        """Retrieves the best solution from the subtree rooted at this node (including itself).

        "Accept" nodes always win (highest ``value`` first). Accepted nodes created with an initial
        "Expand" label may have value 0, so they are ranked by label rather than value alone.
        Without any accepted node, the node with a non-empty solution and the highest ``score`` is
        returned; if no node has a solution, this node is returned. Ties keep the first node in
        breadth-first order.
        """
        logging.info(f"Getting best solution for node with strategy: {self.strategy}")
        all_nodes = [self] + self._get_all_children()
        best_node = max(all_nodes, key=self._solution_rank)
        logging.info(f"Best solution found: {best_node.solution}")
        return best_node

    def _get_all_children(self) -> List["Node"]:
        """Retrieves all descendants of this node in breadth-first order (excluding the node itself)."""
        all_nodes = []
        queue = deque(self.children)
        while queue:
            node = queue.popleft()
            all_nodes.append(node)
            queue.extend(node.children)
        return all_nodes

# Search state passed between generate_initial_strategies, expand_tree and should_continue:
#   "root"  - root Node of the search tree (set by generate_initial_strategies)
#   "input" - the user's natural-language question
TreeState = TypedDict("TreeState", {"root": Optional[Node], "input": str})

In [12]:
# --- Tree Expansion Functions ---
def generate_initial_strategies(state: TreeState, llm, client, necessary_metadata, config,
                                max_strategies: int = 5) -> dict:
    """Creates a dummy root node plus up to ``max_strategies`` first-level strategy children.

    Strategies are requested one at a time from the Thinker; after the first request the
    strategies generated so far are passed along (presumably so new ones differ). Each strategy
    becomes a child of the root with a Solver solution and a Critic evaluation. Generation stops
    early when the Thinker returns nothing / "No further strategies." or when a solution is
    evaluated as "Accept".

    Args:
        state: must contain "input" (the question).
        llm, client, necessary_metadata, config: forwarded to Thinker/Solver/Critic/Node.
        max_strategies: maximum number of strategies to request (default 5).

    Returns:
        A copy of ``state`` with "root" set to the new root node.
    """
    question = state["input"]
    logging.info(f"Generating initial strategies for question: {question}")
    thinker = Thinker(llm, necessary_metadata)
    solver = Solver(llm, necessary_metadata)
    critic = Critic(llm, client, necessary_metadata, config)

    # Create a dummy root node
    root = Node(strategy="Root", solution="", evaluation="Expand", question=question, config=config)

    strategies: List[str] = []
    for i in range(max_strategies):
        # 1. Generate one strategy; `strategies` is empty only on the first iteration.
        if strategies:
            generated = thinker.generate_strategies(question, 1, strategies)
        else:
            generated = thinker.generate_strategies(question, 1)
        # Guard against an empty list instead of failing with IndexError.
        new_strategy = generated[0] if generated else None

        if not new_strategy or new_strategy.lower() == "no further strategies.":
            logging.info("No further strategies generated.")
            break

        strategies.append(new_strategy)
        logging.info(f"Generated strategy {i+1}: {new_strategy}")

        # 2. Create a child node for the new strategy
        child_node = Node(strategy=new_strategy, solution="", evaluation="Expand", parent=root,
                          question=question, config=config)
        root.children.append(child_node)

        # 3. Generate a solution for the strategy
        solution = solver.generate_solution(question, new_strategy)
        child_node.solution = solution

        # 4. Evaluate the solution
        evaluation, critic_score, feedback = critic.evaluate_solution(question, solution, new_strategy)
        child_node.evaluation = evaluation
        child_node.critic_score = critic_score
        child_node.calculate_score(config)

        # 5. Stop early once a verified ("Accept") solution exists
        if evaluation == "Accept":
            logging.info(f"Verified solution found for strategy: {new_strategy}. Stopping initial strategy generation.")
            return {**state, "root": root}

    logging.info("Finished generating initial strategies.")
    return {**state, "root": root}

def expand_tree(state: TreeState, llm, client, necessary_metadata, config: Dict, max_depth: int = 5) -> dict:
    """Performs one expansion step on the search tree and returns the updated state.

    Returns ``state`` unchanged if any descendant of the root is already "Accept". Otherwise picks
    a node with ``select_node_to_expand`` and:

    - "Expand": generates a solution for the node's strategy, evaluates it with the Critic,
      stores the result on the node and appends a child carrying the same solution/evaluation.
    - "Refine": builds feedback from error/requirement reflections on the node's current
      solution, asks the Thinker for reflections, then either re-evaluates the unchanged solution
      (reflections == "the code is good") or has the Debugger refine it. Either way a new child is
      appended. After a refinement, new strategy children are added when the critic score exceeds
      ``config["expansion_threshold"]`` and the node's depth is below ``max_depth``. The node is
      marked "Abort" when the score is below ``config["abort_threshold"]``.

    Any other evaluation on the selected node leaves the tree unchanged.
    Errors while generating reflections are logged and re-raised.
    """
    root = state["root"]
    question = state["input"]

    # Check if an "Accept" node already exists before expanding
    if any(node.evaluation == "Accept" for node in root._get_all_children()):
        logging.info("Acceptable solution already exists. Skipping expansion.")
        return state

    logging.info(f"Expanding tree for question: {question}")
    current_node = select_node_to_expand(root)
    logging.info(f"Selected node for expansion: {current_node.strategy}")

    solver = Solver(llm, necessary_metadata)
    critic = Critic(llm, client, necessary_metadata, config)
    debugger = Debugger(llm, necessary_metadata)
    thinker = Thinker(llm, necessary_metadata)

    if current_node.evaluation == "Expand":
        # Generate solution and evaluate
        logging.info(f"Generating solution for node: {current_node.strategy}")
        solution = solver.generate_solution(question, current_node.strategy)
        logging.info(f"Evaluating solution: {solution}")
        # `feedback` is not used further in this branch.
        evaluation, critic_score, feedback = critic.evaluate_solution(question, solution, current_node.strategy)

        # Update current node and create a new node
        current_node.solution = solution
        current_node.evaluation = evaluation
        current_node.critic_score = critic_score
        current_node.calculate_score(config)

        # Constructing the child backpropagates its reward through current_node and its ancestors.
        new_node = Node(strategy=current_node.strategy, solution=solution, evaluation=evaluation,
                        critic_score=critic_score, parent=current_node, question=question, config=config)
        current_node.children.append(new_node)

    elif current_node.evaluation == "Refine":
        # Get reflections and refine solution
        log = execute_sas_code(current_node.solution)
        try:
            error_reflection = critic._get_error_reflection(current_node.question, current_node.solution, log)
            requirement_reflection = critic._get_requirement_reflection(current_node.question, current_node.solution,
                                                                        log)
        except Exception as e:
            logging.error(f"Error during reflections generation : {e}")
            raise

        # Feedback text handed to the Thinker below.
        feedback = ""
        if error_reflection.error_count > 0:
            feedback += "Errors:\n"
            feedback += f"- {error_reflection.error_log}\n"

        if not requirement_reflection.requirements_met:
            feedback += "Missing Requirements:\n"
            feedback += f"- {requirement_reflection.missing_requirements}\n"

        # Generate reflections using Thinker
        logging.info(f"Generating reflections for node: {current_node.strategy}")
        reflections = thinker.generate_reflections(current_node.question, current_node.solution, feedback, log,
                                                   current_node.strategy)

        # Exact-match sentinel; any other wording goes to the Debugger.
        if reflections == "the code is good":
            # If no issues, simply append the current node again (no changes)
            logging.info(f"No issues found, appending current node again: {current_node.strategy}")
            evaluation, critic_score, critic_feedback = critic.evaluate_solution(
                current_node.question, current_node.solution, current_node.strategy
            )
            new_node = Node(strategy=current_node.strategy, solution=current_node.solution, evaluation=evaluation,
                            critic_score=critic_score, parent=current_node, question=question, config=config)
            current_node.children.append(new_node)
        else:
            # Refine solution using Debugger
            logging.info(f"Refining solution for node: {current_node.strategy}")
            refined_solution = debugger.generate_refinement(question, current_node.solution, reflections)
            logging.info(f"Evaluating refined solution: {refined_solution}")
            evaluation, critic_score, feedback = critic.evaluate_solution(question, refined_solution,
                                                                          current_node.strategy)

            # Create new node for refined solution
            new_node = Node(strategy=current_node.strategy, solution=refined_solution, evaluation=evaluation,
                            critic_score=critic_score, parent=current_node, question=question, config=config)
            current_node.children.append(new_node)

            # Dynamic expansion based on Critic's score
            # NOTE(review): current_node keeps its "Refine" label here, so it stays selectable
            # on later iterations unless it is marked "Abort" below.
            if critic_score > config["expansion_threshold"] and current_node.depth < max_depth:
                logging.info(
                    f"Critic score {critic_score} above expansion threshold {config['expansion_threshold']}. Generating new strategies.")
                new_strategies = thinker.generate_strategies(question)
                for new_strategy in new_strategies:
                    new_node = Node(strategy=new_strategy, solution="", evaluation="Expand", parent=current_node,
                                    question=question, config=config)
                    current_node.children.append(new_node)
            elif critic_score < config["abort_threshold"]:
                logging.info(
                    f"Critic score {critic_score} below abort threshold {config['abort_threshold']}. Marking node as Abort.")
                current_node.evaluation = "Abort"

    # NOTE(review): if the selected node is neither "Expand" nor "Refine" (e.g. "Abort"),
    # no branch runs and the tree is returned unchanged.
    return {**state, "root": root}

def select_node_to_expand(root: "Node") -> "Node":
    """Selects the node the next ``expand_tree`` call should work on.

    Starting at ``root``, each level is examined in order:

    - If any child is "Expand", the one with the highest UCB is returned.
    - Otherwise, if any child is "Refine", the one with the highest UCB is returned.
    - Otherwise, if any child is "Abort", ``root`` is returned.
    - Otherwise the walk descends into the first child.

    In the "Abort" case, walking back up one parent at a time (as the earlier code did) would
    re-descend into the same first child and loop forever. The upward walk always ends at the
    root, so the root is returned directly.
    A node without children is returned as-is.
    """
    node = root
    while node.children:
        expandable_children = [child for child in node.children if child.evaluation == "Expand"]
        if expandable_children:
            logging.info("Selecting node to expand...")
            return max(expandable_children, key=lambda child: child.upper_confidence_bound())

        refinable_children = [child for child in node.children if child.evaluation == "Refine"]
        if refinable_children:
            logging.info("Selecting node to refine...")
            return max(refinable_children, key=lambda child: child.upper_confidence_bound())

        if any(child.evaluation == "Abort" for child in node.children):
            logging.info("Found abortable node. Returning to root.")
            return root

        node = node.children[0]

    return node

def should_continue(state: TreeState, max_depth: int = 5) -> str:
    """Determines whether the search should continue ("expand") or stop (``END``).

    Only the root's descendants are inspected for "Accept", "Expand" and "Refine", in that
    order. When none of these are present (every remaining descendant is "Abort"), another
    ``expand_tree`` call can only act on the root itself. So the search continues only if the
    root is still "Expand" or "Refine". Without this check the loop would never terminate once
    the root is also "Abort", because ``expand_tree`` would be a no-op.

    Args:
        state: search state containing "root".
        max_depth: depth at which every descendant counts as exhausted (default 5).
    """
    root = state["root"]
    nodes = root._get_all_children()  # descendants only; computed once for all checks

    # If a solution is found, stop
    if any(node.evaluation == "Accept" for node in nodes):
        logging.info("Solution found. Stopping...")
        return END

    # If any node is marked for expansion, continue
    if any(node.evaluation == "Expand" for node in nodes):
        logging.info("Expansion required. Continuing...")
        return "expand"

    # If any node is marked for refinement, continue
    if any(node.evaluation == "Refine" for node in nodes):
        logging.info("Refinement required. Continuing...")
        return "expand"

    # If max depth is reached for all nodes (or there are none), stop
    if all(node.depth >= max_depth for node in nodes):
        logging.info("Max depth reached. Stopping...")
        return END

    if root.evaluation in ("Expand", "Refine"):
        logging.info("No action determined. Continuing by default...")
        return "expand"

    logging.info("No expandable or refinable nodes left. Stopping...")
    return END

In [21]:
def main(max_iterations: int = 50):
    """Runs the strategy search for one example query and extracts data for the best solution.

    Args:
        max_iterations: upper bound on ``expand_tree`` calls. It guards against a search that
            never reaches a stopping condition, e.g. nodes that remain "Refine" indefinitely.
    """
    # Initialize configuration (tune these parameters)
    config = {
        "execution_weight": 0.6,
        "critic_weight": 0.4,
        "error_weight": 0.2,
        "requirement_weight": 0.3,
        "verification_fail_score": 0.5,
        "abort_threshold": 0.4,
        "expansion_threshold": 0.6
    }
    # Example query
    query = """
La liste des véhicules ayant enregistré une forte sinistralité au cours des 5 dernières années (Volet MAT) :

-	Le nbre = 5 sinistres 
-	Le coût total : 10000dt
 
Les critères que nous aimerions inclure sont les suivants :
                                                                                                                                                                                                                                             
•	Marque et modèle du véhicule
•	Nombre total de sinistres déclarés par véhicule
•	Montant total des indemnités versées
•	Types de sinistres les plus fréquents
•	La S/P
•	Avec les autres critères habituels du tableau (date déclaration- date sinistre – garantie facultative si existe……)
    """

    # Initialize LLM and necessary metadata
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp", convert_system_message_to_human=True)
    # Gradio client for the hosted Qwen2.5-Coder space (real, not a mock); passed to the Critic
    # and to extract_data.
    client = Client("Qwen/Qwen2.5-Coder-Artifacts")
    necessary_metadata = CachedMetadata.get_instance(metadata_dict, query)._metadata

    # Initialize the tree state. The signature is (state, llm, client, necessary_metadata, config);
    # omitting `client` shifts the positional arguments and raises TypeError.
    state = {"input": query}
    state = generate_initial_strategies(state, llm, client, necessary_metadata, config)

    # Main loop, capped so a search that never hits a stopping condition still terminates.
    iterations = 0
    while should_continue(state) == "expand":
        if iterations >= max_iterations:
            logging.warning(f"Stopping search after {max_iterations} expansion iterations.")
            break
        state = expand_tree(state, llm, client, necessary_metadata, config)
        iterations += 1

    # Find the best solution
    best_solution_node = state["root"].get_best_solution()

    # A Node object is always truthy; what matters is whether it actually carries code.
    if best_solution_node is not None and best_solution_node.solution:
        logging.info(f"Best solution found:\n{best_solution_node.solution}")
        if best_solution_node.evaluation == "Accept":
            logging.info("Solution was accepted by the Critic.")
        else:
            logging.info("Solution was not accepted, but it's the best found within the search limits.")

        # Extract data if needed
        extract_data(best_solution_node.solution, client)
    else:
        logging.info("No solution found within the search limits.")

if __name__ == "__main__":
    main()

Loaded as API: https://qwen-qwen2-5-coder-artifacts.hf.space ✔


2024-12-25 12:31:20,994 - INFO - HTTP Request: GET https://qwen-qwen2-5-coder-artifacts.hf.space/config "HTTP/1.1 200 OK"
-------------------------------------------
2024-12-25 12:31:22,105 - INFO - HTTP Request: GET https://qwen-qwen2-5-coder-artifacts.hf.space/gradio_api/info?serialize=False "HTTP/1.1 200 OK"
-------------------------------------------
2024-12-25 12:31:22,117 - INFO - Returning existing instance of CachedMetadata.
-------------------------------------------
2024-12-25 12:31:22,120 - INFO - Generating initial strategies for question: 
La liste des véhicules ayant enregistré une forte sinistralité au cours des 5 dernières années (Volet MAT) :

-	Le nbre = 5 sinistres 
-	Le coût total : 10000dt
 
Les critères que nous aimerions inclure sont les suivants :
                                                                                                                                                                                                                          

ResourceExhausted: 429 Resource has been exhausted (e.g. check quota).