# In [6]:
import csv
import fcntl
import json
import logging
import os
import re
import time
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum, auto
from pathlib import Path
from typing import Callable, Dict, List, Optional, Protocol

import google.generativeai as genai
from dotenv import load_dotenv
from pydantic import BaseModel, Field, validator
load_dotenv()

# Trace logs go both to a file and to the console. The file lives in a "logs"
# directory one level above the current working directory; create it up front
# because logging.FileHandler raises FileNotFoundError when the parent
# directory is missing.
LOG_DIR = Path('../logs')
LOG_DIR.mkdir(parents=True, exist_ok=True)

# basicConfig configures the *root* logger, so DEBUG here also lets through
# debug output from third-party libraries that log via the root hierarchy.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s',
    handlers=[
        # Explicit UTF-8: logged LLM text may contain non-ASCII characters.
        logging.FileHandler(LOG_DIR / 'agent_trace.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)



class ToolName(Enum):
    """Identifiers for the tools the agent can select.

    Members are numbered 1..9 in declaration order (the same values ``auto()``
    would assign). ``str()`` returns the bare member name, e.g.
    ``"PLACE_ORDER"``, instead of Enum's default ``"ToolName.PLACE_ORDER"``.
    """

    SEARCH_PRODUCTS = 1
    GET_PRODUCT_DETAILS = 2
    CHECK_AVAILABILITY = 3
    PLACE_ORDER = 4
    GET_ORDER_HISTORY = 5
    RESPOND_TO_CUSTOMER = 6
    GET_USER_PREFERENCES = 7
    UPDATE_USER_PREFERENCES = 8
    NONE = 9

    def __str__(self) -> str:
        member_name = self.name
        return member_name


def _current_timestamp() -> str:
    """Return the current local time (naive) as an ISO-8601 string."""
    return datetime.now().isoformat()


class Message(BaseModel):
    """One chat message: who sent it, what it says, and when it was created."""

    role: str = Field(..., description="Role of the message sender")
    content: str = Field(..., description="Content of the message")
    # Evaluated per instance at construction time, not once at class definition.
    timestamp: str = Field(default_factory=_current_timestamp)


class Product(BaseModel):
    """A product record: identity, classification, pricing/stock, rating, text."""

    # Identity
    product_id: str
    product_name: str
    # Classification
    category: str
    brand: str
    # Pricing and inventory
    price: float
    stock_quantity: int
    # Rating scale is not defined here — TODO confirm against the data source.
    rating: float
    # Free-text details
    description: str


class Order(BaseModel):
    """A placed order linking a user to a product, quantity and price."""

    order_id: str
    # Product snapshot at order time.
    product_id: str
    product_name: str
    # NOTE(review): unclear whether this is the unit price or the line total — confirm.
    price: float
    quantity: int
    # Stored as a string; presumably ISO-8601 like Message.timestamp — verify.
    timestamp: str
    user_id: str


class UserPreferences(BaseModel):
    """Per-user shopping preferences; every field is optional/empty by default."""

    # Spending limit; None when unset.
    budget: Optional[float] = None
    # default_factory makes the fresh-list-per-instance semantics explicit
    # instead of relying on pydantic copying a shared mutable [] default.
    preferred_brands: List[str] = Field(default_factory=list)



class InputValidator:
    """Static validators and sanitizers for user-supplied agent inputs.

    Each ``validate_*`` method returns its argument unchanged when it is valid
    and raises ``ValueError`` otherwise.
    """

    # Patterns are applied with fullmatch(). With re.match, a "$" anchor also
    # matches just before a trailing newline, so "P001\n" would be accepted.
    # [0-9] is used instead of \d because \d also matches non-ASCII Unicode
    # decimal digits (e.g. Arabic-Indic "١").
    _USER_ID_RE = re.compile(r'[a-zA-Z0-9_-]{1,20}')
    _PRODUCT_ID_RE = re.compile(r'P[0-9]{3}')
    # Characters stripped from free-text search queries.
    _UNSAFE_QUERY_CHARS_RE = re.compile(r'[<>"\';\\]')
    _MAX_QUERY_LENGTH = 100

    @staticmethod
    def validate_user_id(user_id: str) -> str:
        """Return ``user_id`` if it is 1-20 ASCII letters, digits, ``_`` or ``-``.

        Raises:
            ValueError: if ``user_id`` has any other form.
        """
        if not InputValidator._USER_ID_RE.fullmatch(user_id):
            raise ValueError("Invalid user ID. Use only alphanumeric, underscore, hyphen (1-20 chars)")
        return user_id

    @staticmethod
    def validate_product_id(product_id: str) -> str:
        """Return ``product_id`` if it is ``"P"`` followed by exactly three ASCII digits.

        Raises:
            ValueError: if ``product_id`` has any other form.
        """
        if not InputValidator._PRODUCT_ID_RE.fullmatch(product_id):
            raise ValueError("Invalid product ID format. Expected: P### (e.g., P001)")
        return product_id

    @staticmethod
    def validate_quantity(quantity: int) -> int:
        """Return ``quantity`` if it lies in the inclusive range 1..100.

        Raises:
            ValueError: if ``quantity`` is outside that range.
        """
        if not 1 <= quantity <= 100:
            raise ValueError("Quantity must be between 1 and 100")
        return quantity

    @staticmethod
    def sanitize_search_query(query: str) -> str:
        """Strip ``< > " ' ; \\`` from ``query`` and truncate it to 100 characters.

        This removes characters often used in markup/quoting injection; it is
        not a substitute for parameterized queries or proper output escaping.
        """
        cleaned = InputValidator._UNSAFE_QUERY_CHARS_RE.sub('', query)
        return cleaned[:InputValidator._MAX_QUERY_LENGTH]


# In [7]:
class LLMProvider(ABC):
    """Abstract text-generation backend.

    Concrete providers must override ``generate``; the base class itself
    cannot be instantiated.
    """

    @abstractmethod
    def generate(self, prompt: str) -> str:
        """Return the backend's text completion for ``prompt``."""


class GeminiProvider(LLMProvider):
    """LLMProvider backed by Google's Gemini API via ``google.generativeai``.

    ``generate`` does not raise on API failures. It retries, and when every
    attempt fails it returns a JSON fallback string with ``"thought"`` and
    ``"answer"`` keys instead of model output.
    """

    # Returned when the final attempt raised an exception.
    _ERROR_FALLBACK = '{"thought": "Technical difficulties", "answer": "I\'m experiencing technical issues. Please try again."}'
    # Returned when the final attempt produced no text.
    _EXHAUSTED_FALLBACK = '{"thought": "Max retries exceeded", "answer": "Connection issues. Please try again later."}'

    def __init__(
        self,
        model_name: str = "gemini-2.5-flash",
        max_retries: int = 3,
        retry_delay: float = 1.0,
        max_output_tokens: int = 500,
    ):
        """Configure the Gemini client from the ``GOOGLE_API_KEY`` env var.

        Args:
            model_name: Gemini model identifier.
            max_retries: Total attempts per ``generate`` call (clamped to >= 1).
            retry_delay: Seconds to wait between attempts.
            max_output_tokens: Upper bound on tokens in a generated response.
        """
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            # Not fatal here, but makes the likely cause of later API errors visible.
            logger.warning("GOOGLE_API_KEY is not set; Gemini API calls will likely fail")
        genai.configure(api_key=api_key)
        self.model = genai.GenerativeModel(model_name)
        self.max_retries = max(1, int(max_retries))
        self.retry_delay = retry_delay
        # Deterministic decoding (temperature 0); built once, reused for every call.
        self.generation_config = genai.types.GenerationConfig(
            temperature=0,
            max_output_tokens=max_output_tokens,
            top_p=0.8,
            top_k=40,
        )

    def generate(self, prompt: str) -> str:
        """Send ``prompt`` to Gemini and return the response text.

        Errors and empty responses are retried up to ``max_retries`` times,
        waiting ``retry_delay`` seconds between attempts.

        Returns:
            The model's text. If the final attempt raised, returns the
            "Technical difficulties" JSON fallback. If it returned no text,
            returns the "Max retries exceeded" JSON fallback.
        """
        logger.debug("Generating response for prompt length: %d", len(prompt))
        for attempt in range(1, self.max_retries + 1):
            try:
                logger.debug("Gemini API call attempt %d", attempt)
                response = self.model.generate_content(
                    prompt,
                    generation_config=self.generation_config,
                )
                # The SDK's ``text`` accessor may itself raise (e.g. when a
                # response has no text parts); getattr/hasattr only absorb
                # AttributeError, so such errors fall into the except below
                # and are retried like any other API failure.
                text = getattr(response, "text", None)
                if text:
                    logger.debug("Gemini response received: %s...", text[:200])
                    return text
                logger.warning("Empty response from Gemini (attempt %d)", attempt)
            except Exception as e:  # API boundary: SDK/network errors vary, never propagate
                logger.error("Gemini API error (attempt %d): %s", attempt, e)
                if attempt == self.max_retries:
                    logger.debug("Returning fallback response: %s", self._ERROR_FALLBACK)
                    return self._ERROR_FALLBACK
            # Wait after an error *or* an empty response: re-sending an identical
            # temperature-0 request immediately is unlikely to help.
            if attempt < self.max_retries:
                time.sleep(self.retry_delay)

        logger.debug("Max retries exceeded, returning: %s", self._EXHAUSTED_FALLBACK)
        return self._EXHAUSTED_FALLBACK