In [None]:
from tqdm import tqdm
from openai import OpenAI, AsyncOpenAI
import re
from typing import Optional, Union, List, get_origin, get_args, Any, Dict, Literal
import inspect
import asyncio
import json
import logging
import pandas as pd
from pydantic import BaseModel, Field, create_model
import math
# import demjson3


# Mapping of Python types to JSON Schema type names.
_JSON_TYPES = {
    str: "string",
    int: "integer",
    float: "number",
    bool: "boolean",
    list: "array",
    dict: "object",
    type(None): "null",
}


def _json_type_for(py_type):
    """Return the JSON Schema type name for ``py_type``, or None if unknown."""
    try:
        return _JSON_TYPES.get(py_type)
    except TypeError:
        # Unhashable annotation objects cannot be dictionary keys.
        return None


def _literal_to_schema(literal_args):
    """Return the JSON Schema fragment for the arguments of a ``Literal[...]``."""
    choices = list(literal_args)
    if all(isinstance(choice, str) for choice in choices):
        return {"type": "string", "enum": choices}
    # bool is a subclass of int, so it must be recognised before the integer
    # branch; otherwise Literal[True, False] would be labelled "integer".
    if all(isinstance(choice, bool) for choice in choices):
        return {"type": "boolean", "enum": choices}
    if all(isinstance(choice, int) and not isinstance(choice, bool) for choice in choices):
        return {"type": "integer", "enum": choices}
    # Mixed or unsupported literal values: fall back to a plain string.
    return {"type": "string"}


def _annotation_to_schema(annotation):
    """Translate one parameter annotation into a JSON Schema fragment."""
    import types  # local import: only needed for the PEP 604 union check

    if annotation is inspect.Parameter.empty:
        # No annotation; assume string.
        return {"type": "string"}

    origin = get_origin(annotation)

    # Literal["T", "N", "M"] -> enum
    if origin is Literal:
        return _literal_to_schema(get_args(annotation))

    # Optional[X] / Union[X, None] / X | None -> schema of X.
    # types.UnionType only exists on Python 3.10+, hence the getattr.
    pep604_union = getattr(types, "UnionType", None)
    if origin is Union or (pep604_union is not None and origin is pep604_union):
        members = [m for m in get_args(annotation) if m is not type(None)]
        if len(members) == 1:
            # Recurse so Optional[List[int]], Optional[Literal[...]] etc. keep
            # their structure instead of degrading to "string".
            return _annotation_to_schema(members[0])
        # Unions of several real types are not described; fall back to string.
        return {"type": "string"}

    # Plain known type (str, int, list, ...).
    direct = _json_type_for(annotation)
    if direct is not None:
        return {"type": direct}

    # Parameterised containers such as List[int] or Dict[str, int].
    container = _json_type_for(origin)
    if container == "array":
        item_args = get_args(annotation)
        items = _annotation_to_schema(item_args[0]) if item_args else {"type": "string"}
        return {"type": "array", "items": items}
    if container == "object":
        return {"type": "object"}

    # Anything else (custom classes, TypeVars, string annotations, ...).
    return {"type": "string"}


def generate_tools_spec(*functions):
    """
    Build a function-calling tool specification from Python callables.

    Each callable becomes one entry of the form
    ``{"type": "function", "function": {"name", "description", "parameters"}}``
    where ``name`` is ``func.__name__``, ``description`` is the stripped
    docstring (empty string if there is none) and ``parameters`` is a JSON
    Schema object derived from the signature.

    Annotation handling:
        * ``str``/``int``/``float``/``bool``/``list``/``dict``/``NoneType`` map
          to the matching JSON Schema type.
        * ``Literal[...]`` becomes an ``enum`` (string, boolean or integer);
          mixed literals fall back to ``{"type": "string"}``.
        * ``Optional[X]`` / ``X | None`` is described as ``X``.
        * ``List[X]`` becomes an array whose ``items`` describe ``X``.
        * Missing or unrecognised annotations fall back to ``"string"``.

    ``*args`` and ``**kwargs`` parameters are skipped. A parameter is listed
    under ``required`` when it has no default value; the ``required`` key is
    omitted entirely when no parameter is required.

    Args:
        *functions: The callables to describe.

    Returns:
        list: One tool dictionary per callable, in the order given.
    """
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    tools = []

    for func in functions:
        properties = {}
        required = []

        for param in inspect.signature(func).parameters.values():
            # *args / **kwargs cannot be described as named schema properties.
            if param.kind in variadic_kinds:
                continue

            properties[param.name] = _annotation_to_schema(param.annotation)

            # Mark as required if no default value.
            if param.default is inspect.Parameter.empty:
                required.append(param.name)

        parameters = {"type": "object", "properties": properties}
        if required:
            parameters["required"] = required

        tools.append({
            "type": "function",
            "function": {
                "name": func.__name__,
                "description": (func.__doc__ or "").strip(),
                "parameters": parameters,
            },
        })

    return tools


async def extract_information(info_to_extract: List[str], pathology_text: str) -> Dict[str, str]:
    """
    Extracts relevant information from a given pathology text.

    Args:
        info_to_extract (List[str]): A list of information fields to be extracted,
            e.g. ["tumor_size", "depth_of_invasion", ...].
        pathology_text (str): The full text of the pathology report.
    
    Returns:
        Dict[str, str]: A dictionary mapping each requested field to the extracted information.
    """
    # NOTE: the docstring above is kept verbatim on purpose. generate_tools_spec
    # copies func.__doc__ into the tool's "description", so editing it changes
    # the generated tool specification.
    model_name = "meta-llama/Llama-3.3-70B-Instruct"

    # Materialise once: the fields are walked twice (to build the requests and
    # to pair them with the responses), which would break on a one-shot iterable.
    fields = list(info_to_extract)

    def build_prompt(field: str) -> str:
        # One prompt per requested field, each embedding the full report.
        return f"""You are given a pathology report:

\"\"\"{pathology_text}\"\"\"

Please extract the information for the field: {field}.
Provide a concise answer containing only the relevant information for that field.
"""

    # The context manager closes the client's HTTP resources when the requests
    # are done (or fail), instead of leaving one open client behind per call.
    async with AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="dummy") as client:
        pending = [
            client.chat.completions.create(
                messages=[{"role": "user", "content": build_prompt(field)}],
                model=model_name,
                temperature=0.1,
                max_tokens=500,
            )
            for field in fields
        ]
        # Run all extraction queries concurrently; gather keeps request order.
        responses = await asyncio.gather(*pending)

    # Build a dictionary of field -> extracted text.
    extracted_info = {}
    for field, response in zip(fields, responses):
        # message.content may be None; treat that as an empty answer rather
        # than crashing on .strip().
        raw_answer = response.choices[0].message.content
        extracted_info[field] = (raw_answer or "").strip()

    return extracted_info

def compare_numerical_values(
    value: float,
    min_value: float = None,
    max_value: float = None,
    inclusive_min: bool = True,
    inclusive_max: bool = True
) -> bool:
    """
    Report whether a number lies within an optionally bounded interval.

    Args:
        value (float): The number being tested.
        min_value (float, optional): Lower bound. When None the lower side
                                     is left unchecked.
        max_value (float, optional): Upper bound. When None the upper side
                                     is left unchecked.
        inclusive_min (bool): True to accept value == min_value
                              (value >= min_value); False to require
                              value > min_value.
        inclusive_max (bool): True to accept value == max_value
                              (value <= max_value); False to require
                              value < max_value.

    Returns:
        bool: False as soon as one of the supplied bounds is violated,
              True otherwise (including when no bound is supplied).

    Note:
        The check looks for violations rather than for membership, so a NaN
        value, for which every comparison is false, yields True.

    Examples:
        compare_numerical_values(3.2, min_value=2, max_value=5)
            -> True
        compare_numerical_values(2, min_value=2, max_value=5, inclusive_min=False)
            -> False, because 2 is not strictly greater than 2
    """
    # Lower bound first: a value sitting on the bound only fails when the
    # bound is exclusive.
    if min_value is not None:
        too_low = value < min_value if inclusive_min else value <= min_value
        if too_low:
            return False

    # Same rule mirrored for the upper bound.
    if max_value is not None:
        too_high = value > max_value if inclusive_max else value >= max_value
        if too_high:
            return False

    return True

def produce_final_staging_response(agent_response: str) -> Optional[str]:
    """
    Ask the formatting model to restate the agent's final answer as JSON with
    'reasoning' and 'stage' keys.

    Args:
        agent_response (str): The final response from the agent (after all internal processing).

    Returns:
        Optional[str]: The raw ``message.content`` of the first completion choice.
            It is returned as-is and is NOT parsed into a dictionary here.
            The request passes the ResponseStage JSON schema as ``guided_json``,
            so the text is expected to be a JSON object with 'reasoning' and
            'stage'; callers that need a dict must decode it themselves
            (e.g. with ``json.loads``).
    """
    # NOTE(review): this function was previously annotated and documented as
    # returning Dict[str, str] while actually returning the message text. The
    # runtime return value is deliberately unchanged so existing callers keep
    # working; only the annotation and documentation were corrected.
    formatting_model = "meta-llama/Llama-3.3-70B-Instruct"

    class ResponseStage(BaseModel):
        # Schema handed to the server to constrain the shape of the answer.
        reasoning: str = Field(
            description="A step-by-step explanation for how you arrived at the predicted T stage."
        )
        stage: Literal["T1", "T2", "T3", "T4"] = Field(
            description="The final predicted T stage (T1, T2, T3, or T4)."
        )

    prompt = f"""You are given the final reasoning and conclusions about a cancer T-staging task:

\"\"\"{agent_response}\"\"\"

Please provide valid JSON (and ONLY JSON, without extra text) with the following structure:
{{
"reasoning": "A brief explanation of the reasoning that led to the stage conclusion.",
"stage": "The final T stage (e.g. T1, T2, T3...)"
}}

Make sure the output is strictly valid JSON.
"""

    # A dedicated client is used for the formatting request; the context
    # manager closes its HTTP resources once the call returns or raises.
    with OpenAI(base_url="http://localhost:8000/v1", api_key="dummy") as formatting_client:
        response = formatting_client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            model=formatting_model,
            temperature=0.1,
            extra_body={"guided_json": ResponseStage.model_json_schema()},
        )

    return response.choices[0].message.content


# Registry of callable tools, keyed by function name. generate_tools_spec
# uses func.__name__ as the tool name, so deriving the key from __name__
# keeps the two in step.
available_tools = {
    tool.__name__: tool
    for tool in (extract_information,)
}
# Tool specifications generated from the registered callables, in
# registration order.
tools = generate_tools_spec(*available_tools.values())