In [26]:
import logging
import os
import sys
import time
from datetime import datetime, timezone
from typing import Dict, Any, List, Optional

import requests
from dotenv import load_dotenv
from langchain.agents.middleware import AgentState, AgentMiddleware, ModelRequest, ModelResponse, hook_config
from langchain.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.tools import Tool
from langchain_openai import AzureChatOpenAI
from langgraph.runtime import Runtime
from typing_extensions import TypedDict, NotRequired


In [27]:

class AzureOpenAIConfig:
    """Azure OpenAI settings loaded from environment variables.

    Construction reads the environment, validates the required values and
    raises ``ValueError`` immediately when something is missing or malformed,
    so misconfiguration surfaces at start-up rather than on the first call.

    Environment variables read:
        AZURE_OPENAI_API_KEY: required.
        AZURE_OPENAI_ENDPOINT: required, must start with ``https://``.
        AZURE_OPENAI_API_VERSION: default ``2024-02-15-preview``.
        AZURE_OPENAI_DEPLOYMENT_NAME: default ``gpt-4.1``.
        AZURE_OPENAI_GPT4_DEPLOYMENT_NAME: default ``gpt-4.1``.
        MAX_RETRIES: integer, default 3.
        RATE_LIMIT_CALLS_PER_MINUTE: integer, default 60.
        MAX_CONVERSATION_MESSAGES: integer, default 100.
        REQUEST_TIMEOUT: integer, default 60.
    """

    # Looked up by name so the class does not depend on a module-level
    # ``logger`` variable having been created by some other notebook cell.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        """Load settings from the environment and validate them.

        Raises:
            ValueError: If a required setting is missing, the endpoint is not
                HTTPS, or an integer setting cannot be parsed.
        """
        # Connection settings
        self.api_key = os.getenv("AZURE_OPENAI_API_KEY")
        self.endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        self.api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-15-preview")
        self.deployment_name = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "gpt-4.1")
        self.gpt4_deployment_name = os.getenv("AZURE_OPENAI_GPT4_DEPLOYMENT_NAME", "gpt-4.1")

        # Operational limits
        self.max_retries = self._env_int("MAX_RETRIES", 3)
        self.rate_limit_calls = self._env_int("RATE_LIMIT_CALLS_PER_MINUTE", 60)
        self.max_conversation_messages = self._env_int("MAX_CONVERSATION_MESSAGES", 100)
        self.request_timeout = self._env_int("REQUEST_TIMEOUT", 60)

        self._validate()

    @staticmethod
    def _env_int(name: str, default: int) -> int:
        """Read an integer environment variable, naming it in any parse error."""
        raw = os.getenv(name)
        if raw is None:
            return default
        try:
            return int(raw)
        except ValueError as exc:
            raise ValueError(f"{name} must be an integer, got {raw!r}") from exc

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a timezone-aware ISO 8601 string."""
        return datetime.now(timezone.utc).isoformat()

    def _validate(self):
        """Check required settings and normalise the endpoint.

        All problems are collected and reported together in one error.
        On success the endpoint is guaranteed to end with ``/``.

        Raises:
            ValueError: If any required setting is missing or the endpoint
                does not start with ``https://``.
        """
        errors = []

        if not self.api_key:
            errors.append("AZURE_OPENAI_API_KEY is required")

        if not self.endpoint:
            errors.append("AZURE_OPENAI_ENDPOINT is required")
        elif not self.endpoint.startswith("https://"):
            errors.append("AZURE_OPENAI_ENDPOINT must start with https://")

        if not self.deployment_name:
            errors.append("AZURE_OPENAI_DEPLOYMENT_NAME is required")

        if errors:
            error_msg = "Configuration validation failed:\n" + "\n".join(f"  - {e}" for e in errors)
            self._logger.error(error_msg)
            raise ValueError(error_msg)

        # Ensure endpoint has trailing slash
        if not self.endpoint.endswith("/"):
            self.endpoint += "/"

        self._logger.info("Azure OpenAI configuration validated successfully")

    def get_model(
        self,
        deployment_name: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> "AzureChatOpenAI":
        """Build an ``AzureChatOpenAI`` client from this configuration.

        Args:
            deployment_name: Azure deployment to target; falls back to
                ``self.deployment_name`` when omitted or empty.
            temperature: Sampling temperature passed to the client.
            max_tokens: Optional completion token cap passed to the client.

        Returns:
            A new ``AzureChatOpenAI`` instance.
        """
        deployment = deployment_name or self.deployment_name

        return AzureChatOpenAI(
            azure_deployment=deployment,
            api_version=self.api_version,
            temperature=temperature,
            max_tokens=max_tokens,
            azure_endpoint=self.endpoint,
            api_key=self.api_key,
            timeout=self.request_timeout,
        )

    def health_check(self) -> Dict[str, Any]:
        """Send a one-message request to the default deployment.

        Never raises for ordinary failures: any ``Exception`` from building
        or invoking the model is logged and reported in the returned dict.

        Returns:
            ``{"status": "healthy", "endpoint", "deployment", "timestamp"}``
            on success, or ``{"status": "unhealthy", "error", "timestamp"}``
            on failure. ``timestamp`` is a UTC ISO 8601 string.
        """
        try:
            # Only success/failure matters; the reply content is discarded.
            self.get_model().invoke([HumanMessage(content="Health check")])

            return {
                "status": "healthy",
                "endpoint": self.endpoint,
                "deployment": self.deployment_name,
                "timestamp": self._utc_timestamp(),
            }
        except Exception as e:
            self._logger.error(f"Health check failed: {e}")
            return {
                "status": "unhealthy",
                "error": str(e),
                "timestamp": self._utc_timestamp(),
            }

In [28]:
# Root logging setup: INFO and above is written to agent.log and to stdout.
logging.basicConfig(
    handlers=[logging.FileHandler("agent.log"), logging.StreamHandler(sys.stdout)],
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger for this notebook.
logger = logging.getLogger(__name__)


# CUSTOM STATE SCHEMA


In [29]:


# ============================================================================
# CUSTOM STATE SCHEMA
# ============================================================================

class ProductionAgentState(AgentState):
    """Extended agent state for production use.

    Adds optional bookkeeping keys on top of ``AgentState``. Every key is
    ``NotRequired``, so the middleware in this notebook reads them with
    ``state.get(...)`` and a default instead of assuming they are present.
    """
    
    # User context
    # Caller identity: included in log records and used as the per-user
    # bucket key by the rate-limit middleware.
    user_id: NotRequired[str]
    # Logged once when the agent starts.
    session_id: NotRequired[str]
    # Free-form caller data; not read by any middleware visible in this notebook.
    user_metadata: NotRequired[Dict[str, Any]]
    
    # Tracking
    # Incremented by the logging middleware before each model call.
    model_call_count: NotRequired[int]
    # Incremented by the number of tool calls on each model response.
    tool_call_count: NotRequired[int]
    # Rough running estimate (characters // 4), not a provider-reported count.
    total_tokens: NotRequired[int]
    # time.time() value captured when the agent run begins.
    conversation_start_time: NotRequired[float]
    
    # Rate limiting
    # Calls left in the current window, written by the rate-limit middleware.
    rate_limit_remaining: NotRequired[int]
    
    # Safety
    # Number of blocked-pattern hits recorded by the safety middleware.
    safety_violations: NotRequired[int]
    # The blocked patterns that were matched, in order of detection.
    content_flags: NotRequired[List[str]]
    
    # Performance
    # Reset to 0 at agent start; NOTE(review): no visible middleware updates it.
    total_latency_ms: NotRequired[float]
    # NOTE(review): declared but not written by any middleware shown here.
    last_model_latency_ms: NotRequired[float]




# PRODUCTION MIDDLEWARE



In [30]:

class ProductionLoggingMiddleware(AgentMiddleware[ProductionAgentState]):
    """Structured logging and per-run counters for an agent invocation.

    Hooks:
        before_agent: resets the counters and records the run start time.
        before_model: increments and logs ``model_call_count``.
        after_model: updates the token estimate and ``tool_call_count``.
        after_agent: logs a summary including the run duration.

    The run start time is carried in agent state
    (``conversation_start_time``), so the reported duration belongs to the
    run being finished even when one middleware instance serves several runs.
    """

    state_schema = ProductionAgentState

    def __init__(self):
        super().__init__()
        self.logger = logging.getLogger("ProductionLogging")
        # Start time of the most recently started run. Only a fallback:
        # it is overwritten by every run that goes through this instance.
        self._start_time: Optional[float] = None

    def before_agent(self, state: ProductionAgentState, runtime: Runtime) -> Dict[str, Any]:
        """Reset tracking counters and record when the run started.

        Returns:
            State update zeroing the counters and setting
            ``conversation_start_time`` to the current ``time.time()``.
        """
        started_at = time.time()
        self._start_time = started_at

        self.logger.info(
            "Agent started",
            extra={
                "user_id": state.get("user_id", "unknown"),
                "session_id": state.get("session_id", "unknown"),
                "initial_messages": len(state.get("messages", [])),
            }
        )

        return {
            "model_call_count": 0,
            "tool_call_count": 0,
            "total_tokens": 0,
            "conversation_start_time": started_at,
            "total_latency_ms": 0,
        }

    def before_model(self, state: ProductionAgentState, runtime: Runtime) -> Dict[str, Any]:
        """Log the upcoming model call and bump ``model_call_count``."""
        count = state.get("model_call_count", 0) + 1

        self.logger.info(
            f"Model call #{count}",
            extra={
                "user_id": state.get("user_id"),
                "message_count": len(state.get("messages", [])),
            }
        )

        return {"model_call_count": count}

    def after_model(self, state: ProductionAgentState, runtime: Runtime) -> Dict[str, Any]:
        """Update token and tool-call counters from the latest message.

        Returns:
            State update with ``total_tokens`` and/or ``tool_call_count``;
            empty when there are no messages.
        """
        updates: Dict[str, Any] = {}

        messages = state.get("messages")
        if not messages:
            return updates

        last_msg = messages[-1]

        # Rough estimate only: about four characters per token.
        if hasattr(last_msg, "content"):
            tokens = len(str(last_msg.content)) // 4
            updates["total_tokens"] = state.get("total_tokens", 0) + tokens

        tool_calls = getattr(last_msg, "tool_calls", None)
        if tool_calls:
            tool_count = len(tool_calls)
            updates["tool_call_count"] = state.get("tool_call_count", 0) + tool_count

            self.logger.info(
                f"Tools requested: {tool_count}",
                extra={
                    "tools": [tc.get("name") for tc in tool_calls]
                }
            )

        return updates

    def after_agent(self, state: ProductionAgentState, runtime: Runtime) -> Optional[Dict[str, Any]]:
        """Log the run summary; makes no state changes.

        The duration is measured from ``conversation_start_time`` in state,
        falling back to the instance attribute when the key is absent.
        Nothing is logged if no start time is available.
        """
        started_at = state.get("conversation_start_time", self._start_time)
        if started_at is not None:
            duration = time.time() - started_at

            self.logger.info(
                "Agent completed",
                extra={
                    "user_id": state.get("user_id"),
                    "duration_seconds": round(duration, 2),
                    "model_calls": state.get("model_call_count", 0),
                    "tool_calls": state.get("tool_call_count", 0),
                    "total_messages": len(state.get("messages", [])),
                    "total_tokens": state.get("total_tokens", 0),
                }
            )

        return None

In [31]:
class ProductionRetryMiddleware(AgentMiddleware):
    """Retry model calls with exponential backoff and a circuit breaker.

    Each failed attempt increments ``failure_count``; a successful attempt
    resets it. Once ``failure_count`` reaches ``max_retries`` the circuit
    opens for ``circuit_open_seconds`` and, while it is open, every call is
    rejected immediately without reaching the model.

    Args:
        max_retries: Maximum attempts per call, and the failure count that
            opens the circuit. Must be at least 1.
        base_delay: Delay in seconds before the second attempt; doubled for
            each further attempt. Must not be negative.
        circuit_open_seconds: How long the circuit stays open once tripped.

    Raises:
        ValueError: If ``max_retries`` is below 1 or ``base_delay`` is negative.
    """

    def __init__(self, max_retries: int = 3, base_delay: float = 1.0,
                 circuit_open_seconds: float = 60.0):
        super().__init__()
        # With zero attempts the retry loop would never run and there would
        # be no exception to re-raise, so reject that up front.
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")
        # time.sleep() rejects negative values; fail here instead of mid-retry.
        if base_delay < 0:
            raise ValueError("base_delay must not be negative")

        self.max_retries = max_retries
        self.base_delay = base_delay
        self.circuit_open_seconds = circuit_open_seconds
        self.logger = logging.getLogger("ProductionRetry")

        # circuit breaker state
        self.failure_count = 0
        self.circuit_open = False
        self.circuit_open_until = 0

    def _is_circuit_open(self) -> bool:
        """Return whether calls are currently rejected.

        Closes the circuit and clears ``failure_count`` as a side effect
        once the cool-down deadline has passed.
        """
        if self.circuit_open and time.time() >= self.circuit_open_until:
            self.logger.info("Circuit breaker reset")
            self.circuit_open = False
            self.failure_count = 0
        return self.circuit_open

    def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse:
        """Invoke ``handler(request)``, retrying on any ``Exception``.

        Args:
            request: The model request, passed to ``handler`` unchanged.
            handler: Callable that performs the model call.

        Returns:
            The first successful response from ``handler``.

        Raises:
            Exception: If the circuit is open, or when this call's failures
                trip it; in the latter case the last error is chained as
                ``__cause__``.
        """
        if self._is_circuit_open():
            self.logger.warning("Circuit breaker open - rejecting model call")
            raise Exception("Service temporarily unavailable due to repeated failures")

        last_exception = None
        for attempt in range(1, self.max_retries + 1):
            try:
                response = handler(request)
            except Exception as e:
                last_exception = e
                self.failure_count += 1
                self.logger.error(
                    f"Model call failed on attempt {attempt}: {str(e)}",
                    exc_info=True
                )
                if self.failure_count >= self.max_retries:
                    self.circuit_open = True
                    self.circuit_open_until = time.time() + self.circuit_open_seconds
                    self.logger.error("Circuit breaker opened due to repeated failures")
                    raise Exception("Service temporarily unavailable due to repeated failures") from e

                # Backoff schedule: base_delay, 2x, 4x, ...
                delay = self.base_delay * (2 ** (attempt - 1))
                self.logger.info(f"Retrying after {delay:.2f} seconds...")
                time.sleep(delay)
            else:
                # reset failure count on success
                self.failure_count = 0
                if attempt > 1:
                    self.logger.info(f"Model call succeeded on attempt {attempt}")
                return response

        # Defensive: reached only if the attributes were changed after
        # construction so that the loop ended without tripping the circuit.
        if last_exception is None:
            raise RuntimeError("No model call attempt was made; max_retries must be at least 1")
        raise last_exception



In [32]:
class ProductionRateLimitMiddleware(AgentMiddleware[ProductionAgentState]):
    """Per-user sliding-window limiter applied before every model call.

    Call timestamps are kept in memory per ``user_id``; a call is allowed
    while fewer than ``max_calls_per_minute`` timestamps fall inside the
    last 60 seconds. Rejected calls end the run with an explanatory message.
    """
    state_schema = ProductionAgentState

    def __init__(self, max_calls_per_minute: int = 60):
        super().__init__()
        self.logger = logging.getLogger("ProductionRateLimit")
        self.max_calls_per_minute = max_calls_per_minute

        # user_id -> timestamps (time.time()) of calls still inside the window
        self.user_calls : Dict[str, List[float]] = {}

    def _check_rate_limit(self, user_id : str) -> bool:
        """Record a call for ``user_id`` if allowed.

        Drops timestamps older than 60 seconds, then either appends the
        current time and returns True, or logs a warning and returns False.
        """
        current = time.time()
        window_start = current - 60  # 1 minute window

        recent = [
            stamp for stamp in self.user_calls.get(user_id, [])
            if stamp > window_start
        ]
        self.user_calls[user_id] = recent

        if len(recent) < self.max_calls_per_minute:
            recent.append(current)
            return True

        self.logger.warning(f"User {user_id} exceeded rate limit")
        return False

    @hook_config(can_jump_to=["end"])
    def before_model(
        self,
        state: ProductionAgentState,
        runtime: Runtime
    ) -> Dict[str, Any]:
        """Allow the model call or short-circuit the run.

        Returns ``rate_limit_remaining`` when the call is allowed; otherwise
        an ``AIMessage`` explaining the limit together with a jump to "end".
        A missing ``user_id`` is tracked under the shared key "default".
        """
        user_id = state.get("user_id", "default")

        if self._check_rate_limit(user_id):
            used = len(self.user_calls[user_id])
            return {"rate_limit_remaining": self.max_calls_per_minute - used}

        calls_in_window = sum(
            1 for stamp in self.user_calls[user_id]
            if stamp > time.time() - 60
        )
        self.logger.warning(
            f"Rate limit exceeded for user {user_id}",
            extra={"calls_in_window": calls_in_window}
        )

        notice = AIMessage(
            content=(
                f"Rate limit exceeded. You can make "
                f"{self.max_calls_per_minute} requests per minute. "
                f"Please wait before trying again."
            )
        )
        return {"messages": [notice], "jump_to": "end"}



In [33]:

class ProductionSafetyMiddleware(AgentMiddleware[ProductionAgentState]):
    """Pre-model guard: conversation length, strike count, phrase filter."""

    state_schema = ProductionAgentState

    def __init__(self, max_messages: int = 100, max_violations: int = 3):
        super().__init__()
        self.logger = logging.getLogger("ProductionSafety")
        self.max_messages = max_messages
        self.max_violations = max_violations

        # Phrases matched case-insensitively as substrings of the last message.
        self.blocked_patterns = [
            "ignore previous instructions",
            "disregard your guidelines",
            "bypass safety",
            "jailbreak",
            "prompt injection",
        ]

    @hook_config(can_jump_to=["end"])
    def before_model(
        self,
        state: ProductionAgentState,
        runtime: Runtime
    ) -> Dict[str, Any]:
        """Run the safety checks in order; end the run on the first failure.

        Order: message-count limit, accumulated violation limit, then the
        blocked-phrase scan of the last message. Returns None when every
        check passes; otherwise a state update containing an ``AIMessage``
        and ``"jump_to": "end"``. A phrase hit also increments
        ``safety_violations`` and appends the phrase to ``content_flags``.
        """
        history = state.get("messages", [])
        strikes = state.get("safety_violations", 0)

        # 1. Conversation length
        if len(history) >= self.max_messages:
            self.logger.warning(
                f"Conversation length limit reached: {len(history)} messages"
            )
            limit_notice = AIMessage(
                content=f"Conversation limit of {self.max_messages} messages reached. "
                        "Please start a new conversation."
            )
            return {"messages": [limit_notice], "jump_to": "end"}

        # 2. Accumulated violations
        if strikes >= self.max_violations:
            self.logger.error(f"Maximum violations reached: {strikes}")
            termination_notice = AIMessage(
                content="This conversation has been terminated due to safety policy violations."
            )
            return {"messages": [termination_notice], "jump_to": "end"}

        # 3. Blocked phrases in the most recent message
        if not history or not hasattr(history[-1], "content"):
            return None

        text = str(history[-1].content).lower()
        matched = next(
            (phrase for phrase in self.blocked_patterns if phrase.lower() in text),
            None,
        )
        if matched is None:
            return None

        updated_strikes = strikes + 1
        self.logger.warning(
            f"Blocked pattern detected: '{matched}'",
            extra={"user_id": state.get("user_id")}
        )
        violation_notice = AIMessage(
            content=f"Content policy violation detected. "
                    f"({updated_strikes}/{self.max_violations} strikes)"
        )
        return {
            "messages": [violation_notice],
            "safety_violations": updated_strikes,
            "content_flags": state.get("content_flags", []) + [matched],
            "jump_to": "end"
        }


In [34]:

class AzureDynamicModelMiddleware(AgentMiddleware[ProductionAgentState]):
    """Route each model call to a 'simple' or 'complex' Azure deployment."""

    state_schema = ProductionAgentState

    def __init__(self, config: AzureOpenAIConfig, complexity_threshold: int = 10):
        super().__init__()
        self.config = config
        self.complexity_threshold = complexity_threshold
        self.logger = logging.getLogger("AzureModelSelection")

        # Both clients are built once, up front, from the supplied config.
        self.simple_model = config.get_model(config.deployment_name)
        self.complex_model = config.get_model(config.gpt4_deployment_name)

    def _analyze_complexity(self, request: ModelRequest) -> str:
        """Classify a request as "complex" or "simple".

        A request is "complex" when the conversation has more than
        ``complexity_threshold`` messages, or when the last message is over
        500 characters or contains one of the trigger keywords.
        """
        history = request.messages

        # Long conversations go straight to the stronger model.
        if len(history) > self.complexity_threshold:
            return "complex"

        if not history:
            return "simple"

        latest = history[-1]
        if not hasattr(latest, "content"):
            return "simple"

        text = str(latest.content).lower()
        trigger_keywords = (
            "analyze", "detailed", "comprehensive", "technical",
            "explain in detail", "step by step",
        )
        if len(text) > 500 or any(keyword in text for keyword in trigger_keywords):
            return "complex"

        return "simple"

    def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse:
        """Swap in the chosen model and forward the request to ``handler``."""
        complexity = self._analyze_complexity(request)

        model, deployment = (
            (self.complex_model, self.config.gpt4_deployment_name)
            if complexity == "complex"
            else (self.simple_model, self.config.deployment_name)
        )

        self.logger.info(
            f"Selected deployment: {deployment} (complexity: {complexity})"
        )

        return handler(request.override(model=model))
