# In [None]:
import os
from typing import Dict, Any, List, Optional, Union
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import asyncio
from pathlib import Path
import logging

# Configure logging at import time: INFO and above, with each record showing
# its timestamp, logger name and severity before the message.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-wide logger used by everything in this file.
logger = logging.getLogger("llm_deployment")

class LLMManager:
    """Manages loading and inference with small LLMs for agents"""
    
    def __init__(self, cache_dir: Optional[str] = None):
        """Initialize the LLM manager.

        Builds the registry of supported models and records whether CUDA is
        available. No model is loaded here.

        Args:
            cache_dir: Directory to cache downloaded models. Stored as-is and
                handed to the tokenizer/model loaders.
        """
        self.cache_dir = cache_dir
        # Loaded models keyed by "<model_name>_<device>"; each value holds the
        # "model", "tokenizer", "pipeline" and "info" entries.
        self.loaded_models: Dict[str, Dict[str, Any]] = {}
        # Registry of supported models keyed by short name. load_model reads
        # "model_id", "revision", "requires_gpu" and "quantization" from it.
        self.model_info: Dict[str, Dict[str, Any]] = {
            "phi-3-mini": {
                # Full hub repository id of the 4k-context instruct
                # checkpoint ("microsoft/phi-3-mini" is not a published repo).
                "model_id": "microsoft/Phi-3-mini-4k-instruct",
                "revision": "main",
                "requires_gpu": False,
                "quantization": "4bit",  # Options: None, "4bit", "8bit"
                "context_length": 4096
            },
            "gemma-2b": {
                "model_id": "google/gemma-2b",
                "revision": "main",
                "requires_gpu": False,
                "quantization": "4bit",
                "context_length": 8192
            },
            "tinyllama-1.1b": {
                "model_id": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                "revision": "main",
                "requires_gpu": False,
                "quantization": "4bit",
                "context_length": 2048
            },
            "mistral-7b": {
                "model_id": "mistralai/Mistral-7B-Instruct-v0.2",
                "revision": "main",
                "requires_gpu": True,
                "quantization": "4bit",
                "context_length": 8192
            }
        }

        # Probe CUDA once; later device decisions reuse this flag.
        self.cuda_available = torch.cuda.is_available()
        if self.cuda_available:
            logger.info(f"CUDA is available: {torch.cuda.get_device_name(0)}")
        else:
            logger.info("CUDA is not available, will use CPU")
    
    async def load_model(self, model_name: str, device: Optional[str] = None) -> str:
        """Load a model asynchronously
        
        Args:
            model_name: Name of the model to load (must be in model_info)
            device: Device to load the model on ("cuda", "cpu", etc.)
                    If None, will use CUDA if available and model requires it
        
        Returns:
            model_key: Unique key for the loaded model
        """
        if model_name not in self.model_info:
            raise ValueError(f"Unknown model: {model_name}. Available models: {list(self.model_info.keys())}")
        
        # Determine device
        if device is None:
            if self.model_info[model_name]["requires_gpu"] and not self.cuda_available:
                raise ValueError(f"Model {model_name} requires GPU but CUDA is not available")
            device = "cuda" if self.cuda_available and self.model_info[model_name]["requires_gpu"] else "cpu"
        
        # Create unique model key
        model_key = f"{model_name}_{device}"
        
        # Check if model is already loaded
        if model_key in self.loaded_models:
            logger.info(f"Model {model_name} already loaded on {device}")
            return model_key
        
        # Get model info
        model_info = self.model_info[model_name]
        model_id = model_info["model_id"]
        revision = model_info["revision"]
        quantization = model_info["quantization"]
        
        logger.info(f"Loading model {model_name} on {device}...")
        
        # Create a loading task
        def load_model_task():
            try:
                # Load tokenizer
                tokenizer = AutoTokenizer.from_pretrained(
                    model_id,
                    revision=revision,
                    cache_dir=self.cache_dir
                )
                
                # Load model with appropriate quantization
                if quantization == "4bit":
                    model = AutoModelForCausalLM.from_pretrained(
                        model_id,
                        revision=revision,
                        cache_dir=self.cache_dir,
                        device_map=device,
                        load_in_4bit=True,
                        trust_remote_code=True
                    )
                elif quantization == "8bit":
                    model = AutoModelForCausalLM.from_pretrained(
                        model_id,
                        revision=revision,
                        cache_dir=self.cache_dir,
                        device_map=device,
                        load_in_8bit=True,
                        trust_remote_code=True
                    )
                else:
                    model = AutoModelForCausalLM.from_pretrained(
                        model_id,
                        revision=revision,
                        cache_dir=self.cache_dir,
                        device_map=device,
                        trust_remote_code=True
                    )
                
                # Create pipeline
                pipe = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device_map=device,
                    max_new_tokens=512,
                    trust_remote_code=True
                )
                
                return {
                    "model": model,
                    "tokenizer": tokenizer,
                    "pipeline": pipe,
                    "info": model_info
                }
            except Exception as e:
                logger.error(f"Error loading model {model_name}: {e}")
                raise
        
        # Run the loading task in a separate thread to not block the event loop
        loop = asyncio.get_event_loop()
        model_data = await loop.run_in_executor(None, load_model_task)
        
        # Store the loaded model
        self.loaded_models[model_key] = model_data
        
        logger.info(f"Model {model_name} loaded successfully on {device}")
        return model_key
    
    async def unload_model(self, model_key: str) -> bool:
        """Drop a loaded model from the manager so its memory can be reclaimed.

        Args:
            model_key: Key of the model to unload

        Returns:
            bool: True if the model was unloaded, False if no model was
                registered under ``model_key``
        """
        if model_key not in self.loaded_models:
            logger.warning(f"Model {model_key} is not loaded")
            return False

        # Take the entry out of the registry, then clear the references it
        # holds to the heavy objects.
        entry = self.loaded_models.pop(model_key)
        for slot in ("model", "tokenizer", "pipeline"):
            entry[slot] = None

        # Collect now rather than waiting for the next automatic GC pass.
        import gc
        gc.collect()

        # Also release PyTorch's cached CUDA memory when a GPU is present.
        if self.cuda_available:
            torch.cuda.empty_cache()

        logger.info(f"Model {model_key} unloaded successfully")
        return True
    
    async def generate(
        self, 
        model_key: str, 
        prompt: str, 
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
        **kwargs
    ) -> str:
        """Generate text from the model
        
        Args:
            model_key: Key of the model to use
            prompt: Input prompt
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            repetition_penalty: Penalty for repeating tokens
            **kwargs: Additional arguments to pass to the generation function
            
        Returns:
            str: Generated text
        """
        if model_key not in self.loaded_models:
            raise ValueError(f"Model {model_key} is not loaded")
        
        # Get model data
        model_data = self.loaded_models[model_key]
        pipe = model_data["pipeline"]
        
        # Create a generation task
        def generate_task():
            try:
                outputs = pipe(
                    prompt,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    do_sample=True,
                    **kwargs
                )
                
                # Extract generated text
                generated_text = outputs[0]["generated_text"]
                
                # Remove the prompt from the generated text
                if generated_text.startswith(prompt):
                    generated_text = generated_text[len(prompt):]
                
                return generated_text.strip()
            except Exception as e:
                logger.error(f"Error generating text: {e}")
                raise
        
        # Run the generation task in a separate thread to not block the event loop
        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(None, generate_task)
        
        return result
    
    def get_available_models(self) -> List[str]:
        """Get a list of available models
        
        Returns:
            List[str]: List of available model names
        """
        return list(self.model_info.keys())
    
    def get_loaded_models(self) -> Dict[str, str]:
        """Get