In [19]:
import ast
import json
import logging
import operator
import os
import sys
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Dict, Any, List, Optional

from dotenv import load_dotenv
from langchain.agents import create_agent
from langchain.messages import HumanMessage, AIMessage, SystemMessage
from langchain_openai import AzureChatOpenAI
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.tools import Tool
from langchain.agents.middleware import (
    AgentMiddleware,
    AgentState,
    ModelRequest,
    ModelResponse,
    before_agent,
    after_agent,
    before_model,
    after_model,
    wrap_model_call,
    hook_config,
)
from langgraph.runtime import Runtime
from typing_extensions import NotRequired


In [20]:
# Load environment variables from a local .env file, if one exists.
load_dotenv()

# Configure logging for the whole notebook: INFO and above go both to
# agent.log and to the cell output.
# force=True replaces handlers from an earlier run of this cell (or any the
# kernel installed); without it basicConfig is a silent no-op on re-run.
# The file handler is UTF-8 so user text with non-ASCII characters can always
# be written, regardless of the platform's default encoding.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('agent.log', encoding='utf-8'),
        logging.StreamHandler(sys.stdout),
    ],
    force=True,
)
logger = logging.getLogger(__name__)


In [21]:
# Azure OpenAI settings, read from the environment (.env is loaded above).
# Non-secret values fall back to defaults. The API key has NO default: it
# must never be committed to the notebook.
# SECURITY: an API key was previously hard-coded in this cell. Treat it as
# leaked and rotate it in the Azure portal.
# Note: AzureOpenAIConfig reads these variables from the environment itself;
# the constants below are for reference / ad-hoc use in the notebook.
AZURE_OPENAI_ENDPOINT = os.getenv(
    "AZURE_OPENAI_ENDPOINT", "https://dhanush-ai507.cognitiveservices.azure.com/"
)
AZURE_OPENAI_DEPLOYMENT_NAME = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "gpt-4.1")
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION", "2024-12-01-preview")
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")


In [22]:

class AzureOpenAIConfig:
    """Production Azure OpenAI configuration with validation and health checks.

    All settings are read from environment variables. Construction fails fast
    with a single ``ValueError`` that lists every problem found.
    """

    def __init__(self):
        """Initialize and validate Azure OpenAI configuration.

        Raises:
            ValueError: if a required setting is missing, the endpoint is not
                HTTPS, or a numeric limit is not a positive integer.
        """
        # Problems found while parsing numeric limits; reported together with
        # the other validation errors in _validate().
        parse_errors: List[str] = []

        # Connection settings
        self.api_key = os.getenv("AZURE_OPENAI_API_KEY")
        self.endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        self.api_version = os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-15-preview")
        self.deployment_name = os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "gpt-4.1")
        self.gpt4_deployment_name = os.getenv("AZURE_OPENAI_GPT4_DEPLOYMENT_NAME", "gpt-4.1")

        # Configuration limits. A bad value is reported by variable name rather
        # than surfacing as an anonymous int() ValueError.
        self.max_retries = self._read_positive_int("MAX_RETRIES", 3, parse_errors)
        self.rate_limit_calls = self._read_positive_int("RATE_LIMIT_CALLS_PER_MINUTE", 60, parse_errors)
        self.max_conversation_messages = self._read_positive_int("MAX_CONVERSATION_MESSAGES", 100, parse_errors)
        self.request_timeout = self._read_positive_int("REQUEST_TIMEOUT", 60, parse_errors)

        # Validate configuration
        self._validate(parse_errors)

    @staticmethod
    def _read_positive_int(name: str, default: int, errors: List[str]) -> int:
        """Read env var ``name`` as a positive int.

        On an invalid value, a message is appended to ``errors`` and
        ``default`` is returned.
        """
        raw = os.getenv(name, str(default))
        try:
            value = int(raw)
        except ValueError:
            errors.append(f"{name} must be an integer (got {raw!r})")
            return default
        if value <= 0:
            errors.append(f"{name} must be a positive integer (got {value})")
            return default
        return value

    def _validate(self, errors: Optional[List[str]] = None):
        """Validate configuration and fail fast if invalid.

        Args:
            errors: Problems already found elsewhere (e.g. while parsing
                numeric limits). They are included in the raised message.

        Raises:
            ValueError: if any problem was found.
        """
        problems: List[str] = []

        if not self.api_key:
            problems.append("AZURE_OPENAI_API_KEY is required")

        if not self.endpoint:
            problems.append("AZURE_OPENAI_ENDPOINT is required")
        elif not self.endpoint.startswith("https://"):
            problems.append("AZURE_OPENAI_ENDPOINT must start with https://")

        if not self.deployment_name:
            problems.append("AZURE_OPENAI_DEPLOYMENT_NAME is required")

        problems.extend(errors or [])

        if problems:
            error_msg = "Configuration validation failed:\n" + "\n".join(f"  - {e}" for e in problems)
            logger.error(error_msg)
            raise ValueError(error_msg)

        # Ensure endpoint has trailing slash
        if not self.endpoint.endswith("/"):
            self.endpoint += "/"

        logger.info("Azure OpenAI configuration validated successfully")

    def get_model(
        self,
        deployment_name: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> AzureChatOpenAI:
        """Build an ``AzureChatOpenAI`` client for a deployment.

        Args:
            deployment_name: Deployment to target; defaults to
                ``self.deployment_name``.
            temperature: Sampling temperature.
            max_tokens: Optional cap on generated tokens (``None`` = no cap).
        """
        deployment = deployment_name or self.deployment_name

        return AzureChatOpenAI(
            azure_deployment=deployment,
            api_version=self.api_version,
            temperature=temperature,
            max_tokens=max_tokens,
            azure_endpoint=self.endpoint,
            api_key=self.api_key,
            timeout=self.request_timeout,
        )

    def health_check(self) -> Dict[str, Any]:
        """Send one short request to the default deployment.

        Exceptions from the request are caught and reported, not raised.

        Returns:
            A dict with ``status`` (``"healthy"``/``"unhealthy"``),
            ``endpoint``, ``deployment`` and a timezone-aware UTC ISO-8601
            ``timestamp``. It also has ``error`` when unhealthy.
        """
        try:
            # Only success matters here; the model's reply is not inspected.
            self.get_model().invoke([HumanMessage(content="Health check")])
        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return {
                "status": "unhealthy",
                "error": str(e),
                "endpoint": self.endpoint,
                "deployment": self.deployment_name,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            }

        return {
            "status": "healthy",
            "endpoint": self.endpoint,
            "deployment": self.deployment_name,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }


In [23]:

class ProductionAgentState(AgentState):
    """Extended agent state for production use.

    Adds optional (``NotRequired``) keys on top of LangChain's ``AgentState``.
    The middleware classes in this notebook declare this class as their
    ``state_schema`` and read/write the keys below.
    """
    
    # User context: user_id and session_id are placed in the initial state by
    # the caller (ProductionAgent.invoke).
    user_id: NotRequired[str]
    session_id: NotRequired[str]
    # Free-form caller metadata; not read or written by code in this notebook.
    user_metadata: NotRequired[Dict[str, Any]]
    
    # Tracking: per-run counters reset in ProductionLoggingMiddleware.before_agent.
    model_call_count: NotRequired[int]
    tool_call_count: NotRequired[int]
    # Approximate token count accumulated from model responses.
    total_tokens: NotRequired[int]
    # Run start time as returned by time.time() (seconds since the epoch).
    conversation_start_time: NotRequired[float]
    
    # Rate limiting: calls left in the current window, written by
    # ProductionRateLimitMiddleware.
    rate_limit_remaining: NotRequired[int]
    
    # Safety: strike count and the blocked patterns that triggered strikes,
    # written by ProductionSafetyMiddleware.
    safety_violations: NotRequired[int]
    content_flags: NotRequired[List[str]]
    
    # Performance: total_latency_ms is initialised to 0 at run start;
    # NOTE(review): neither latency field appears to be updated anywhere in this
    # notebook yet — TODO wire up or remove.
    total_latency_ms: NotRequired[float]
    last_model_latency_ms: NotRequired[float]


In [24]:

class ProductionLoggingMiddleware(AgentMiddleware[ProductionAgentState]):
    """Log agent runs and maintain per-run call/token counters.

    Counters and the run start time live in the agent state, not on this
    instance. A single middleware instance serves every run, so instance
    attributes would be clobbered by concurrent runs.

    Note: the ``extra`` fields are attached to log records, but they only
    appear in output when a handler's format string references them.
    """

    state_schema = ProductionAgentState

    def __init__(self):
        super().__init__()
        self.logger = logging.getLogger("ProductionLogging")
        # Start time of the most recently started run. Kept for backward
        # compatibility; after_agent prefers the per-run value in state.
        self._start_time: Optional[float] = None

    def before_agent(self, state: ProductionAgentState, runtime: Runtime) -> Dict[str, Any]:
        """Reset per-run counters and record the run's start time in state."""
        start_time = time.time()
        self._start_time = start_time

        self.logger.info(
            "Agent started",
            extra={
                "user_id": state.get("user_id", "unknown"),
                "session_id": state.get("session_id", "unknown"),
                "initial_messages": len(state.get("messages", [])),
            }
        )

        return {
            "model_call_count": 0,
            "tool_call_count": 0,
            "total_tokens": 0,
            "conversation_start_time": start_time,
            "total_latency_ms": 0,
        }

    def before_model(self, state: ProductionAgentState, runtime: Runtime) -> Dict[str, Any]:
        """Increment ``model_call_count`` and log the upcoming call."""
        count = state.get("model_call_count", 0) + 1

        self.logger.info(
            f"Model call #{count}",
            extra={
                "user_id": state.get("user_id"),
                "message_count": len(state.get("messages", [])),
            }
        )

        return {"model_call_count": count}

    @staticmethod
    def _response_tokens(message: Any) -> Optional[int]:
        """Return the output-token count of a model response.

        Uses provider-reported ``usage_metadata["output_tokens"]`` when
        present. Otherwise estimates ~4 characters per token from
        ``content``. Returns ``None`` if neither is available.
        """
        usage = getattr(message, "usage_metadata", None)
        if isinstance(usage, dict) and usage.get("output_tokens") is not None:
            return int(usage["output_tokens"])
        if hasattr(message, "content"):
            return len(str(message.content)) // 4
        return None

    def after_model(self, state: ProductionAgentState, runtime: Runtime) -> Dict[str, Any]:
        """Accumulate token and tool-call counts from the latest model response."""
        updates: Dict[str, Any] = {}
        messages = state.get("messages") or []
        if not messages:
            return updates

        last_msg = messages[-1]

        tokens = self._response_tokens(last_msg)
        if tokens is not None:
            updates["total_tokens"] = state.get("total_tokens", 0) + tokens

        tool_calls = getattr(last_msg, "tool_calls", None)
        if tool_calls:
            tool_count = len(tool_calls)
            updates["tool_call_count"] = state.get("tool_call_count", 0) + tool_count

            self.logger.info(
                f"Tools requested: {tool_count}",
                extra={
                    "tools": [tc.get("name") for tc in tool_calls]
                }
            )

        return updates

    def after_agent(self, state: ProductionAgentState, runtime: Runtime) -> Optional[Dict[str, Any]]:
        """Log final run metrics. Makes no state updates."""
        # Prefer this run's own start time; self._start_time may belong to a
        # different, concurrently started run.
        start_time = state.get("conversation_start_time") or self._start_time
        if start_time:
            duration = time.time() - start_time

            self.logger.info(
                "Agent completed",
                extra={
                    "user_id": state.get("user_id"),
                    "duration_seconds": round(duration, 2),
                    "model_calls": state.get("model_call_count", 0),
                    "tool_calls": state.get("tool_call_count", 0),
                    "total_messages": len(state.get("messages", [])),
                    "total_tokens": state.get("total_tokens", 0),
                }
            )

        return None


class ProductionRetryMiddleware(AgentMiddleware):
    """Retry failed model calls with exponential backoff and a circuit breaker.

    Args:
        max_retries: Total attempts per model call, including the first one.
            Values below 1 are treated as 1.
        base_delay: Seconds to wait before the first retry; doubles per retry.
        failure_threshold: Consecutive failures (counted across calls, reset
            on any success) that open the circuit breaker.
        reset_timeout: Seconds the circuit stays open; calls made during that
            time fail immediately with ``RuntimeError``.
        retry_on: Exception type(s) that are retried and counted as failures.
            Other exceptions propagate immediately. Defaults to all
            ``Exception`` subclasses.
    """

    def __init__(
        self,
        max_retries: int = 3,
        base_delay: float = 1.0,
        failure_threshold: int = 5,
        reset_timeout: float = 60.0,
        retry_on: tuple = (Exception,),
    ):
        super().__init__()
        self.max_retries = max_retries
        self.base_delay = base_delay
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout
        # `except` needs a class or tuple of classes; accept a single class or list too.
        self.retry_on = tuple(retry_on) if isinstance(retry_on, (list, tuple)) else (retry_on,)
        self.logger = logging.getLogger("ProductionRetry")

        # Circuit breaker state. circuit_open_until is a time.monotonic()
        # deadline, so wall-clock adjustments cannot shorten or extend it.
        self.failure_count = 0
        self.circuit_open = False
        self.circuit_open_until = 0.0

    def _is_circuit_open(self) -> bool:
        """Return True while the breaker is open; close it once the timeout passes."""
        if self.circuit_open and time.monotonic() > self.circuit_open_until:
            self.circuit_open = False
            self.failure_count = 0
            self.logger.info("Circuit breaker closed - retrying operations")

        return self.circuit_open

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler,
    ) -> ModelResponse:
        """Call ``handler(request)`` with retries and circuit breaking.

        Raises:
            RuntimeError: if the circuit breaker is currently open.
            Exception: the last failure, once attempts are exhausted or the
                breaker trips; or any exception not matched by ``retry_on``.
        """
        if self._is_circuit_open():
            raise RuntimeError("Circuit breaker is open - too many failures")

        # Guard against max_retries <= 0. The previous loop then ran zero
        # times and ended in `raise None` (a TypeError).
        attempts = max(1, self.max_retries)

        for attempt in range(1, attempts + 1):
            try:
                response = handler(request)
            except self.retry_on as e:
                self.failure_count += 1

                # Open circuit breaker if too many consecutive failures
                if self.failure_count >= self.failure_threshold:
                    self.circuit_open = True
                    self.circuit_open_until = time.monotonic() + self.reset_timeout
                    self.logger.error("Circuit breaker opened due to repeated failures")
                    raise

                if attempt == attempts:
                    self.logger.error(f"All {attempts} attempts failed")
                    raise

                delay = self.base_delay * (2 ** (attempt - 1))
                self.logger.warning(
                    f"Attempt {attempt} failed: {str(e)}. Retrying in {delay}s..."
                )
                time.sleep(delay)
                continue

            # Success: reset the consecutive-failure counter.
            self.failure_count = 0
            if attempt > 1:
                self.logger.info(f"Request succeeded on attempt {attempt}")
            return response


class ProductionRateLimitMiddleware(AgentMiddleware[ProductionAgentState]):
    """Per-user sliding-window rate limit on model calls.

    Every model call counts against the limit, so a single user message that
    triggers tool use consumes several calls.
    """

    state_schema = ProductionAgentState

    def __init__(self, max_calls_per_minute: int = 60):
        super().__init__()
        self.max_calls_per_minute = max_calls_per_minute
        self.logger = logging.getLogger("ProductionRateLimit")

        # user_id -> time.monotonic() timestamps of calls within the last
        # 60 seconds (pruned lazily on each check).
        self.user_calls: Dict[str, List[float]] = {}

    def _check_rate_limit(self, user_id: str) -> bool:
        """Prune old calls and, if under the limit, record this call.

        Returns:
            True if the call is allowed, False if the user is over the limit.
        """
        # Monotonic clock: NTP/DST wall-clock jumps must not reopen or lock
        # the window.
        now = time.monotonic()
        cutoff = now - 60  # 1 minute window

        recent = [t for t in self.user_calls.get(user_id, []) if t > cutoff]
        self.user_calls[user_id] = recent

        if len(recent) >= self.max_calls_per_minute:
            return False

        recent.append(now)
        return True

    @hook_config(can_jump_to=["end"])
    def before_model(
        self,
        state: ProductionAgentState,
        runtime: Runtime
    ) -> Dict[str, Any]:
        """Enforce the rate limit before each model call.

        When over the limit, append an explanatory AIMessage and jump to the
        end of the run. Otherwise report how many calls remain in the window.
        """
        user_id = state.get("user_id", "default")

        if not self._check_rate_limit(user_id):
            # _check_rate_limit already pruned the list to the current window.
            calls_in_window = len(self.user_calls[user_id])

            self.logger.warning(
                f"Rate limit exceeded for user {user_id}",
                extra={"calls_in_window": calls_in_window}
            )

            return {
                "messages": [
                    AIMessage(
                        content=(
                            f"Rate limit exceeded. You can make "
                            f"{self.max_calls_per_minute} requests per minute. "
                            f"Please wait before trying again."
                        )
                    )
                ],
                "rate_limit_remaining": 0,
                "jump_to": "end"
            }

        remaining = self.max_calls_per_minute - len(self.user_calls[user_id])
        return {"rate_limit_remaining": remaining}


class ProductionSafetyMiddleware(AgentMiddleware[ProductionAgentState]):
    """Conversation-length limit, strike counting and prompt-injection screening.

    Only new user input (a trailing ``HumanMessage``) is screened for blocked
    patterns. Tool output such as web-search results can legitimately
    mention phrases like "prompt injection"; it must not count as a strike
    against the user.
    """

    state_schema = ProductionAgentState

    def __init__(self, max_messages: int = 100, max_violations: int = 3):
        super().__init__()
        self.max_messages = max_messages
        self.max_violations = max_violations
        self.logger = logging.getLogger("ProductionSafety")

        # Content filters (matched case- and whitespace-insensitively)
        self.blocked_patterns = [
            "ignore previous instructions",
            "disregard your guidelines",
            "bypass safety",
            "jailbreak",
            "prompt injection",
        ]

    @staticmethod
    def _normalize(text: str) -> str:
        """Lower-case and collapse whitespace runs to single spaces.

        This stops padding such as "ignore   previous" from slipping past a
        filter.
        """
        return " ".join(text.lower().split())

    @hook_config(can_jump_to=["end"])
    def before_model(
        self,
        state: ProductionAgentState,
        runtime: Runtime
    ) -> Optional[Dict[str, Any]]:
        """Run safety checks before each model call.

        Ends the run (``jump_to: "end"``) with an explanatory AIMessage when:
        the conversation reached ``max_messages``, the strike count reached
        ``max_violations``, or the latest user message contains a blocked
        pattern (which also records a strike). Returns None when all checks
        pass.
        """
        messages = state.get("messages", [])
        violations = state.get("safety_violations", 0)

        # Check conversation length
        if len(messages) >= self.max_messages:
            self.logger.warning(
                f"Conversation length limit reached: {len(messages)} messages"
            )
            return {
                "messages": [
                    AIMessage(
                        content=f"Conversation limit of {self.max_messages} messages reached. "
                                "Please start a new conversation."
                    )
                ],
                "jump_to": "end"
            }

        # Check violation count
        if violations >= self.max_violations:
            self.logger.error(f"Maximum violations reached: {violations}")
            return {
                "messages": [
                    AIMessage(
                        content="This conversation has been terminated due to safety policy violations."
                    )
                ],
                "jump_to": "end"
            }

        # Content filtering: only when the newest message is user input.
        last_message = messages[-1] if messages else None
        if isinstance(last_message, HumanMessage):
            content = self._normalize(str(last_message.content))

            for pattern in self.blocked_patterns:
                if self._normalize(pattern) in content:
                    new_violations = violations + 1

                    self.logger.warning(
                        f"Blocked pattern detected: '{pattern}'",
                        extra={"user_id": state.get("user_id")}
                    )

                    return {
                        "messages": [
                            AIMessage(
                                content=f"Content policy violation detected. "
                                        f"({new_violations}/{self.max_violations} strikes)"
                            )
                        ],
                        "safety_violations": new_violations,
                        "content_flags": state.get("content_flags", []) + [pattern],
                        "jump_to": "end"
                    }

        return None


class AzureDynamicModelMiddleware(AgentMiddleware[ProductionAgentState]):
    """Route each model call to a "simple" or "complex" Azure deployment.

    A request is "complex" if any of these hold:
      * the conversation has more than ``complexity_threshold`` messages;
      * the latest message contains one of ``complexity_keywords``
        (case-insensitive substring match);
      * the latest message is longer than ``long_message_chars`` characters.
    Complex requests use ``config.gpt4_deployment_name``; all others use
    ``config.deployment_name``.
    """

    state_schema = ProductionAgentState

    # Default keywords that mark a request as complex.
    DEFAULT_COMPLEXITY_KEYWORDS = (
        "analyze", "detailed", "comprehensive", "technical",
        "explain in detail", "step by step",
    )

    def __init__(
        self,
        config: AzureOpenAIConfig,
        complexity_threshold: int = 10,
        complexity_keywords: Optional[List[str]] = None,
        long_message_chars: int = 500,
    ):
        super().__init__()
        self.config = config
        self.complexity_threshold = complexity_threshold
        keywords = (
            self.DEFAULT_COMPLEXITY_KEYWORDS
            if complexity_keywords is None
            else complexity_keywords
        )
        # Lower-cased once here; message content is lower-cased before matching.
        self.complexity_keywords = tuple(kw.lower() for kw in keywords)
        self.long_message_chars = long_message_chars
        self.logger = logging.getLogger("AzureModelSelection")

        # Initialize models. Reuse one client when both tiers target the same
        # deployment (the default configuration does).
        self.simple_model = config.get_model(config.deployment_name)
        if config.gpt4_deployment_name == config.deployment_name:
            self.complex_model = self.simple_model
        else:
            self.complex_model = config.get_model(config.gpt4_deployment_name)

    def _analyze_complexity(self, request: ModelRequest) -> str:
        """Classify the request as ``"complex"`` or ``"simple"`` (see class doc)."""
        messages = request.messages

        # Long conversations need the stronger model
        if len(messages) > self.complexity_threshold:
            return "complex"

        if not messages or not hasattr(messages[-1], "content"):
            return "simple"

        content = str(messages[-1].content).lower()

        if any(kw in content for kw in self.complexity_keywords):
            return "complex"

        if len(content) > self.long_message_chars:
            return "complex"

        return "simple"

    def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse:
        """Swap in the model for the chosen deployment, then delegate to ``handler``."""
        complexity = self._analyze_complexity(request)

        if complexity == "complex":
            model = self.complex_model
            deployment = self.config.gpt4_deployment_name
        else:
            model = self.simple_model
            deployment = self.config.deployment_name

        self.logger.info(
            f"Selected deployment: {deployment} (complexity: {complexity})"
        )

        return handler(request.override(model=model))



In [25]:

# Binary operators the calculator accepts. ``**`` is handled by _bounded_pow
# instead, so one expression cannot demand an astronomically large integer
# (e.g. ``9**9**9**9``) and hang the kernel.
_CALC_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
}
_CALC_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}
# Largest integer power result allowed, in bits.
_CALC_MAX_POW_BITS = 100_000


def _bounded_pow(base, exponent):
    """Return ``base ** exponent``.

    Raises:
        ValueError: if an integer power would exceed _CALC_MAX_POW_BITS bits.
    """
    if (
        isinstance(base, int)
        and isinstance(exponent, int)
        and exponent > 0
        and abs(base) > 1
        and abs(base).bit_length() * exponent > _CALC_MAX_POW_BITS
    ):
        raise ValueError("exponent too large")
    return base ** exponent


def _eval_arithmetic_node(node):
    """Evaluate an AST node made only of numeric constants and arithmetic ops.

    Raises:
        ValueError: for any other node type (names, calls, tuples, ...).
    """
    if isinstance(node, ast.Expression):
        return _eval_arithmetic_node(node.body)
    if isinstance(node, ast.Constant) and type(node.value) in (int, float):
        return node.value
    if isinstance(node, ast.BinOp):
        left = _eval_arithmetic_node(node.left)
        right = _eval_arithmetic_node(node.right)
        if isinstance(node.op, ast.Pow):
            return _bounded_pow(left, right)
        op = _CALC_BINARY_OPS.get(type(node.op))
        if op is not None:
            return op(left, right)
    elif isinstance(node, ast.UnaryOp):
        op = _CALC_UNARY_OPS.get(type(node.op))
        if op is not None:
            return op(_eval_arithmetic_node(node.operand))
    raise ValueError(f"unsupported expression element: {type(node).__name__}")


def _safe_eval_arithmetic(expression: str):
    """Evaluate an arithmetic expression without calling ``eval``."""
    return _eval_arithmetic_node(ast.parse(expression.strip(), mode="eval"))


def create_production_tools() -> "List[Tool]":
    """Create the agent's tools: a web search and a restricted calculator.

    Both tools catch their own errors and return them as text. A failing
    tool therefore produces a message the model can read instead of aborting
    the run.

    Returns:
        ``[search_tool, calc_tool]`` (names ``web_search`` and ``calculator``).
    """
    # One DuckDuckGo client, created on first use and then reused. If
    # construction fails, it is retried on the next search.
    search_client = None

    def safe_search(query: str) -> str:
        """Perform web search with error handling."""
        nonlocal search_client
        try:
            if search_client is None:
                search_client = DuckDuckGoSearchRun()
            return search_client.run(query)
        except Exception as e:
            logger.error(f"Search failed: {e}")
            return f"Search temporarily unavailable. Error: {str(e)}"

    search_tool = Tool(
        name="web_search",
        func=safe_search,
        description=(
            "Search the web for current information. "
            "Use this when you need up-to-date facts or recent events."
        )
    )

    def safe_calculate(expression: str) -> str:
        """Evaluate arithmetic (+ - * / // **, parentheses) without ``eval``."""
        try:
            # Cheap character pre-filter; the AST walk is the real guard.
            allowed = set("0123456789+-*/.() ")
            if not all(c in allowed for c in expression):
                return "Error: Invalid characters in expression"

            result = _safe_eval_arithmetic(expression)
            return f"Result: {result}"
        except Exception as e:
            return f"Calculation error: {str(e)}"

    calc_tool = Tool(
        name="calculator",
        func=safe_calculate,
        description=(
            "Perform mathematical calculations. "
            "Input should be a valid mathematical expression like '2 + 2' or '(15 * 8) / 4'."
        )
    )

    return [search_tool, calc_tool]



In [26]:

# NOTE(review): this cell repeats the create_production_tools definition from
# the previous cell; whichever cell runs last wins. Keep a single copy.

# Binary operators the calculator accepts. ``**`` is handled by _bounded_pow
# instead, so one expression cannot demand an astronomically large integer
# (e.g. ``9**9**9**9``) and hang the kernel.
_CALC_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
}
_CALC_UNARY_OPS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}
# Largest integer power result allowed, in bits.
_CALC_MAX_POW_BITS = 100_000


def _bounded_pow(base, exponent):
    """Return ``base ** exponent``.

    Raises:
        ValueError: if an integer power would exceed _CALC_MAX_POW_BITS bits.
    """
    if (
        isinstance(base, int)
        and isinstance(exponent, int)
        and exponent > 0
        and abs(base) > 1
        and abs(base).bit_length() * exponent > _CALC_MAX_POW_BITS
    ):
        raise ValueError("exponent too large")
    return base ** exponent


def _eval_arithmetic_node(node):
    """Evaluate an AST node made only of numeric constants and arithmetic ops.

    Raises:
        ValueError: for any other node type (names, calls, tuples, ...).
    """
    if isinstance(node, ast.Expression):
        return _eval_arithmetic_node(node.body)
    if isinstance(node, ast.Constant) and type(node.value) in (int, float):
        return node.value
    if isinstance(node, ast.BinOp):
        left = _eval_arithmetic_node(node.left)
        right = _eval_arithmetic_node(node.right)
        if isinstance(node.op, ast.Pow):
            return _bounded_pow(left, right)
        op = _CALC_BINARY_OPS.get(type(node.op))
        if op is not None:
            return op(left, right)
    elif isinstance(node, ast.UnaryOp):
        op = _CALC_UNARY_OPS.get(type(node.op))
        if op is not None:
            return op(_eval_arithmetic_node(node.operand))
    raise ValueError(f"unsupported expression element: {type(node).__name__}")


def _safe_eval_arithmetic(expression: str):
    """Evaluate an arithmetic expression without calling ``eval``."""
    return _eval_arithmetic_node(ast.parse(expression.strip(), mode="eval"))


def create_production_tools() -> "List[Tool]":
    """Create the agent's tools: a web search and a restricted calculator.

    Both tools catch their own errors and return them as text. A failing
    tool therefore produces a message the model can read instead of aborting
    the run.

    Returns:
        ``[search_tool, calc_tool]`` (names ``web_search`` and ``calculator``).
    """
    # One DuckDuckGo client, created on first use and then reused. If
    # construction fails, it is retried on the next search.
    search_client = None

    def safe_search(query: str) -> str:
        """Perform web search with error handling."""
        nonlocal search_client
        try:
            if search_client is None:
                search_client = DuckDuckGoSearchRun()
            return search_client.run(query)
        except Exception as e:
            logger.error(f"Search failed: {e}")
            return f"Search temporarily unavailable. Error: {str(e)}"

    search_tool = Tool(
        name="web_search",
        func=safe_search,
        description=(
            "Search the web for current information. "
            "Use this when you need up-to-date facts or recent events."
        )
    )

    def safe_calculate(expression: str) -> str:
        """Evaluate arithmetic (+ - * / // **, parentheses) without ``eval``."""
        try:
            # Cheap character pre-filter; the AST walk is the real guard.
            allowed = set("0123456789+-*/.() ")
            if not all(c in allowed for c in expression):
                return "Error: Invalid characters in expression"

            result = _safe_eval_arithmetic(expression)
            return f"Result: {result}"
        except Exception as e:
            return f"Calculation error: {str(e)}"

    calc_tool = Tool(
        name="calculator",
        func=safe_calculate,
        description=(
            "Perform mathematical calculations. "
            "Input should be a valid mathematical expression like '2 + 2' or '(15 * 8) / 4'."
        )
    )

    return [search_tool, calc_tool]


In [27]:

class ProductionAgent:
    """Production-ready Azure OpenAI agent with full middleware stack.

    Each ``invoke`` starts from a fresh agent state. Only the caller-supplied
    message history, plus a per-user safety-strike counter kept on this
    instance, carries over between calls.
    """

    def __init__(self):
        """Load config, verify Azure connectivity, and build the agent.

        Raises:
            ValueError: if the Azure OpenAI configuration is invalid.
            RuntimeError: if the startup health check fails.
        """
        logger.info("Initializing production agent...")

        # Load configuration
        self.config = AzureOpenAIConfig()

        # Perform health check
        health = self.config.health_check()
        if health["status"] != "healthy":
            raise RuntimeError(f"Azure OpenAI health check failed: {health}")

        logger.info("Azure OpenAI health check passed")

        # Safety strikes per user_id. Agent state is rebuilt on every invoke,
        # so without this the count restarted at 0 each turn and the safety
        # middleware's max_violations could never be reached.
        self._safety_violations: Dict[str, int] = {}

        # Initialize middleware
        self.middleware = self._create_middleware_stack()

        # Create tools
        self.tools = create_production_tools()

        # Create agent
        self.agent = self._create_agent()

        logger.info("Production agent initialized successfully")

    def _create_middleware_stack(self) -> List[AgentMiddleware]:
        """Build the middleware list in the order it is passed to ``create_agent``."""
        return [
            ProductionSafetyMiddleware(
                max_messages=self.config.max_conversation_messages,
                max_violations=3
            ),
            ProductionRateLimitMiddleware(
                max_calls_per_minute=self.config.rate_limit_calls
            ),
            AzureDynamicModelMiddleware(
                config=self.config,
                complexity_threshold=10
            ),
            ProductionRetryMiddleware(
                max_retries=self.config.max_retries,
                base_delay=1.0
            ),
            ProductionLoggingMiddleware(),
        ]

    def _create_agent(self):
        """Create the LangChain agent with the default model, tools and middleware."""
        default_model = self.config.get_model()

        system_prompt = (
            "You are a helpful AI assistant powered by Azure OpenAI. "
            "You have access to web search and calculation tools. "
            "Provide accurate, helpful, and concise responses. "
            "Use tools when appropriate to give the best answers."
        )

        return create_agent(
            model=default_model,
            middleware=self.middleware,
            tools=self.tools,
            system_prompt=system_prompt,
        )

    def invoke(
        self,
        message: str,
        user_id: str = "default",
        session_id: Optional[str] = None,
        conversation_history: Optional[List] = None,
    ) -> Dict[str, Any]:
        """
        Invoke agent with a message.

        Args:
            message: User message
            user_id: User identifier (also keys rate limiting and safety strikes)
            session_id: Session identifier; a random one is generated if omitted
            conversation_history: Previous messages. The list is copied, never
                modified.

        Returns:
            On success, ``{"success": True, "response", "metadata",
            "conversation_state"}``. On any exception,
            ``{"success": False, "error", "response"}``.
        """
        resolved_session_id = session_id or f"session_{uuid.uuid4().hex}"

        try:
            # Copy so the caller's list is untouched even if the call fails.
            messages = list(conversation_history or [])
            messages.append(HumanMessage(content=message))

            prior_violations = self._safety_violations.get(user_id, 0)
            state = {
                "messages": messages,
                "user_id": user_id,
                "session_id": resolved_session_id,
                "safety_violations": prior_violations,
            }

            # Invoke agent
            logger.info(f"Processing message for user {user_id}")
            result = self.agent.invoke(state)

            # Carry strikes over to this user's next turn.
            self._safety_violations[user_id] = result.get("safety_violations", prior_violations)

            # Extract response
            if result.get("messages"):
                response_message = result["messages"][-1]
                response_content = response_message.content if hasattr(response_message, "content") else str(response_message)
            else:
                response_content = "No response generated"

            return {
                "success": True,
                "response": response_content,
                "metadata": {
                    "user_id": user_id,
                    "session_id": result.get("session_id", resolved_session_id),
                    "model_calls": result.get("model_call_count", 0),
                    "tool_calls": result.get("tool_call_count", 0),
                    "tokens": result.get("total_tokens", 0),
                    "rate_limit_remaining": result.get("rate_limit_remaining", 0),
                    "safety_violations": self._safety_violations[user_id],
                },
                "conversation_state": result,
            }

        except Exception as e:
            logger.error(f"Agent invocation failed: {e}", exc_info=True)
            return {
                "success": False,
                "error": str(e),
                "response": "I apologize, but I encountered an error processing your request.",
            }

    def health_check(self) -> Dict[str, Any]:
        """Report agent health.

        ``agent`` is ``"healthy"`` only when the Azure OpenAI check passes;
        otherwise it is ``"degraded"``. The timestamp is timezone-aware UTC.
        """
        azure_health = self.config.health_check()
        return {
            "agent": "healthy" if azure_health.get("status") == "healthy" else "degraded",
            "azure_openai": azure_health,
            "middleware_count": len(self.middleware),
            "tools_count": len(self.tools),
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }

In [None]:
def main():
    """Run the interactive console for the production agent.

    Commands:
        ``quit`` exits, and ``health`` prints a health report. Any other
        non-empty input is sent to the agent.

    Conversation history is carried between turns only after a successful
    response. Ctrl-C or end-of-input (EOF) ends the loop. Exits the process
    with status 1 if the agent cannot be initialized.
    """
    print("=" * 80)
    print("PRODUCTION AZURE OPENAI AGENT")
    print("=" * 80)
    print()

    try:
        # Initialize agent
        print("Initializing production agent...")
        agent = ProductionAgent()
        print("✓ Agent initialized successfully\n")

        # Perform health check and report the real outcome.
        print("Performing health check...")
        health = agent.health_check()
        health_json = json.dumps(health, indent=2, default=str)
        if health.get("azure_openai", {}).get("status") == "healthy":
            print(f"✓ Health check passed: {health_json}\n")
        else:
            print(f"⚠ Health check reported problems: {health_json}\n")

        # Interactive loop
        print("=" * 80)
        print("Agent is ready. Type 'quit' to exit, 'health' for health check")
        print("=" * 80)
        print()

        conversation_history = []
        user_id = "demo_user"

        while True:
            try:
                user_input = input("\nYou: ").strip()

                if not user_input:
                    continue

                if user_input.lower() == 'quit':
                    print("\nGoodbye!")
                    break

                if user_input.lower() == 'health':
                    health = agent.health_check()
                    print(f"\nHealth Status:\n{json.dumps(health, indent=2, default=str)}")
                    continue

                # Invoke agent
                result = agent.invoke(
                    message=user_input,
                    user_id=user_id,
                    conversation_history=conversation_history,
                )

                # Display response
                if result["success"]:
                    print(f"\nAssistant: {result['response']}")

                    # Show metadata
                    metadata = result["metadata"]
                    print(f"\n[Stats: Calls={metadata['model_calls']}, "
                          f"Tools={metadata['tool_calls']}, "
                          f"Tokens≈{metadata['tokens']}, "
                          f"Rate limit remaining={metadata['rate_limit_remaining']}]")

                    # Update conversation history
                    conversation_history = result["conversation_state"]["messages"]
                else:
                    print(f"\nError: {result['error']}")
                    print(f"Assistant: {result['response']}")

            except KeyboardInterrupt:
                print("\n\nInterrupted. Goodbye!")
                break
            except EOFError:
                # input() keeps raising EOFError once stdin is closed. The
                # generic handler below would otherwise loop forever.
                print("\n\nInput closed. Goodbye!")
                break
            except Exception as e:
                logger.error(f"Error in main loop: {e}", exc_info=True)
                print(f"\nError: {e}")

    except Exception as e:
        logger.error(f"Fatal error: {e}", exc_info=True)
        print(f"\n Fatal error: {e}")
        sys.exit(1)


In [None]:
# Start the interactive console. It blocks on input() until the user types
# 'quit' or interrupts the kernel, so run this cell last.
main()