In [None]:
# nb70_fastapi_endpoints.ipynb
# FastAPI 端點實作：Chat/RAG/Agent/Game 服務化

# Cell1:  Shared Cache Bootstrap
import os, pathlib, torch
import sys
from datetime import datetime

# Shared cache configuration (copied into every notebook).
# AI_CACHE_ROOT may be overridden from the environment; the default is relative
# to the notebook's current working directory.
AI_CACHE_ROOT = os.getenv("AI_CACHE_ROOT", "../ai_warehouse/cache")

# Environment variable name -> sub-directory of AI_CACHE_ROOT it should point to.
_CACHE_SUBDIRS = {
    "HF_HOME": "hf",
    "TRANSFORMERS_CACHE": "hf/transformers",
    "HF_DATASETS_CACHE": "hf/datasets",
    "HUGGINGFACE_HUB_CACHE": "hf/hub",
    "TORCH_HOME": "torch",
}

for env_var, subdir in _CACHE_SUBDIRS.items():
    cache_dir = f"{AI_CACHE_ROOT}/{subdir}"
    # Exported for libraries (e.g. Hugging Face, torch) that read these variables.
    os.environ[env_var] = cache_dir
    pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)

print("[Cache]", AI_CACHE_ROOT, "| GPU:", torch.cuda.is_available())

In [None]:
## Cell 2: 依賴導入與基礎設定
import asyncio
import json
import logging
import time
import uuid
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional

import uvicorn
from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, validator

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global registry of service components, keyed by component name.
# Every entry starts as None and is populated by initialize_models() at app startup.
app_state = dict.fromkeys(
    ("llm_adapter", "rag_retriever", "agent_orchestrator", "game_engine")
)

In [None]:
## Cell 3: Pydantic 請求/回應模型
# Request/Response schemas
class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    # Numeric fields use plain (non-Optional) types: with Optional[...] an explicit
    # JSON null passed validation and reached the LLM adapter as None instead of
    # falling back to the default. Now null is rejected with a 422.
    message: str = Field(..., min_length=1, max_length=4096)
    max_tokens: int = Field(default=256, ge=1, le=2048)  # generation budget
    temperature: float = Field(default=0.7, ge=0.0, le=2.0)  # sampling temperature
    conversation_id: Optional[str] = None  # omitted/empty -> the server generates one


class ChatResponse(BaseModel):
    """Response body for POST /chat."""

    response: str  # text returned by the LLM adapter
    conversation_id: str  # the request's conversation_id, or a server-generated one
    tokens_used: int  # estimated token count of the reply (adapter's count_tokens)
    latency_ms: float  # handler latency in milliseconds


class RAGRequest(BaseModel):
    """Request body for POST /rag."""

    # Plain (non-Optional) types so an explicit JSON null is rejected with a 422
    # instead of reaching the retriever as None (e.g. the mock retriever's
    # min(top_k, 3) raises TypeError on None).
    query: str = Field(..., min_length=1, max_length=1024)
    domain: str = Field(default="general")  # NOTE(review): not used by the /rag handler yet
    top_k: int = Field(default=5, ge=1, le=20)  # number of documents to retrieve
    use_rerank: bool = True  # NOTE(review): not used by the /rag handler yet


class RAGResponse(BaseModel):
    """Response body for POST /rag; all *_time fields are in seconds."""

    answer: str  # LLM answer, prompted to cite sources as [1], [2], ...
    sources: List[Dict[str, Any]]  # one dict per retrieved doc: id, text snippet, metadata, score
    query_embedding_time: float  # NOTE(review): /rag does not time a separate embedding step, so this is ~0
    retrieval_time: float  # time spent in the retriever's search()
    generation_time: float  # time spent generating the answer


class AgentRequest(BaseModel):
    """Request body for POST /agent."""

    task: str = Field(..., min_length=1, max_length=2048)
    # One of the supported orchestration modes (enforced by the regex pattern).
    mode: str = Field(
        default="research_write",
        pattern="^(research|plan|write|review|research_write)$",
    )
    # Plain int (not Optional) so an explicit JSON null is rejected with a 422
    # instead of being forwarded to the orchestrator as None.
    max_iterations: int = Field(default=3, ge=1, le=10)


class AgentResponse(BaseModel):
    """Response body for POST /agent."""

    result: str  # final output returned by the orchestrator
    execution_log: List[Dict[str, Any]]  # per-step records from the orchestrator
    total_time: float  # handler time in seconds
    iterations_used: int  # set to len(execution_log) by the /agent handler


class GameRequest(BaseModel):
    """Request body for POST /game."""

    action: str = Field(..., min_length=1, max_length=512)  # free-text player action
    session_id: Optional[str] = None  # omitted/empty -> the engine starts a new session
    # Plain bool (not Optional) so an explicit JSON null is rejected with a 422.
    save_state: bool = False  # NOTE(review): not used by the /game handler yet


class GameResponse(BaseModel):
    """Response body for POST /game."""

    narrative: str  # story text describing the outcome of the action
    choices: List[str]  # suggested next actions
    game_state: Dict[str, Any]  # engine-defined state (the mock uses hp/level/location)
    session_id: str  # session the action was applied to (new or existing)


class HealthResponse(BaseModel):
    """Response body for GET /health."""

    status: str  # "healthy" when every component is loaded, otherwise "degraded"
    version: str  # API version string
    models_loaded: Dict[str, bool]  # component name -> whether it is initialized
    uptime_seconds: float  # seconds since app startup

In [None]:
## Cell 4: 模型初始化邏輯
# Mock implementations - replace with actual shared_utils imports
class MockLLMAdapter:
    """Stand-in LLM adapter that echoes the last message instead of running a model.

    The endpoints in this notebook call ``generate(messages, max_tokens, temperature)``
    and ``count_tokens(text)``; a real adapter must provide both.
    """

    # Rough tokens-per-whitespace-word ratio used by count_tokens().
    TOKENS_PER_WORD = 1.3

    def __init__(self):
        self.model_name = "Mock-7B-Instruct"

    def generate(self, messages, max_tokens=256, temperature=0.7):
        """Return a canned reply echoing up to 50 chars of the last message's content.

        ``max_tokens`` and ``temperature`` are accepted for interface compatibility
        but ignored. Sleeps 0.1 s to simulate (blocking) model latency.
        """
        time.sleep(0.1)
        user_msg = messages[-1]["content"] if messages else ""
        return f"Mock response to: {user_msg[:50]}..."

    def count_tokens(self, text):
        """Estimate the number of tokens in ``text``.

        Returns an int (a token *count*), rounded from words * TOKENS_PER_WORD.
        """
        return round(len(text.split()) * self.TOKENS_PER_WORD)


class MockRAGRetriever:
    """Stand-in retriever that fabricates at most three hits per query."""

    def search(self, query, top_k=5):
        """Return ``min(top_k, 3)`` synthetic hits for ``query``.

        Each hit is a dict with ``text``, ``meta`` (``source_id`` and ``score``)
        and ``score``. Scores start at 0.9 and decrease by 0.1 per rank.
        """
        hits = []
        for rank in range(min(top_k, 3)):
            score = 0.9 - rank * 0.1
            doc_number = rank + 1
            hits.append(
                {
                    "text": f"Retrieved document {doc_number} for query: {query[:30]}...",
                    "meta": {"source_id": f"doc_{doc_number}", "score": score},
                    "score": score,
                }
            )
        return hits


class MockAgentOrchestrator:
    """Stand-in multi-agent orchestrator returning a fixed researcher/writer log."""

    def execute_task(self, task, mode="research_write", max_iterations=3):
        """Return ``(result_text, execution_log)`` for ``task``.

        ``mode`` and ``max_iterations`` are accepted for interface compatibility
        but ignored; the log always holds the same two steps.
        """
        research_step = {
            "step": 1,
            "role": "researcher",
            "action": f"Research: {task[:30]}...",
        }
        writing_step = {"step": 2, "role": "writer", "action": "Generated response"}
        summary = f"Agent result for task: {task[:50]}..."
        return summary, [research_step, writing_step]


class MockGameEngine:
    """Stand-in text-adventure engine holding per-session state in memory.

    ``self.sessions`` maps session_id -> state dict for the lifetime of this
    instance only; nothing is persisted.
    """

    def __init__(self):
        self.sessions = {}

    def process_action(self, action, session_id=None):
        """Apply ``action`` to a session; return ``(narrative, choices, state, session_id)``.

        A falsy ``session_id`` starts a new session under a freshly generated id;
        an unknown id starts that session with the default state.
        """
        if not session_id:
            # uuid4 rather than a seconds timestamp: with int(time.time()) two games
            # started in the same second got the same id and shared one state dict.
            session_id = f"game_{uuid.uuid4().hex}"

        # Mock game state (never modified by the mock).
        state = self.sessions.get(
            session_id, {"hp": 100, "level": 1, "location": "forest"}
        )

        narrative = f"You decided to {action}. The forest grows darker..."
        choices = ["Go north", "Rest here", "Check inventory"]

        self.sessions[session_id] = state
        return narrative, choices, state, session_id


# Initialize mock models (replace with real implementations)
def initialize_models():
    """Populate ``app_state`` with one fresh instance of each service component.

    Currently instantiates the mock classes; swap in the real implementations
    here. Components are constructed in the order listed below.
    """
    logger.info("Initializing models...")

    component_factories = (
        ("llm_adapter", MockLLMAdapter),
        ("rag_retriever", MockRAGRetriever),
        ("agent_orchestrator", MockAgentOrchestrator),
        ("game_engine", MockGameEngine),
    )
    for state_key, factory in component_factories:
        app_state[state_key] = factory()

    logger.info("Models initialized successfully")

In [None]:
## Cell 5: FastAPI 應用設定
# FastAPI app with lifespan management
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: load all components before serving, log on shutdown.

    Stores the startup timestamp on ``app.state.start_time``; /health reads it
    to compute uptime.
    """
    logger.info("Starting up FastAPI application...")
    startup_began = time.time()

    initialize_models()

    app.state.start_time = startup_began
    elapsed = time.time() - startup_began
    logger.info(f"Startup completed in {elapsed:.2f}s")

    yield  # the application serves requests while suspended here

    logger.info("Shutting down FastAPI application...")


# Create FastAPI app
app = FastAPI(
    title="RAGent Text Lab API",
    description="Multi-modal text AI API with Chat/RAG/Agent/Game capabilities",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS: allowed origins come from the comma-separated CORS_ORIGINS env var
# (e.g. "https://app.example.com,https://admin.example.com"). The default "*"
# (any origin) is meant for local notebook testing only.
cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
cors_allow_any_origin = "*" in cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    # With a wildcard origin plus allow_credentials=True, Starlette reflects the
    # caller's Origin on credentialed requests, i.e. any website could make
    # cookie-bearing calls. Only allow credentials for an explicit allow-list.
    allow_credentials=not cors_allow_any_origin,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request logging middleware
@app.middleware("http")
async def log_requests(request, call_next):
    """Log method, path, status code and elapsed time for every HTTP request.

    If the downstream handler raises an unhandled exception, the failure is
    logged (with traceback) and the exception is re-raised unchanged, so the
    app's error handling still produces the response.
    """
    # perf_counter is monotonic and meant for measuring durations; time.time can jump.
    start_time = time.perf_counter()
    try:
        response = await call_next(request)
    except Exception:
        process_time = time.perf_counter() - start_time
        logger.exception(
            f"{request.method} {request.url.path} - "
            f"Failed - "
            f"Time: {process_time:.3f}s"
        )
        raise
    process_time = time.perf_counter() - start_time

    logger.info(
        f"{request.method} {request.url.path} - "
        f"Status: {response.status_code} - "
        f"Time: {process_time:.3f}s"
    )
    return response

In [None]:
## Cell 6: Health Check 端點
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Health-check endpoint.

    Reports, for each component in ``app_state``, whether it has been
    initialized, plus uptime since startup. Status is "healthy" only when every
    component is loaded; otherwise "degraded".
    """
    uptime = time.time() - app.state.start_time

    component_names = (
        "llm_adapter",
        "rag_retriever",
        "agent_orchestrator",
        "game_engine",
    )
    loaded = {name: app_state[name] is not None for name in component_names}
    overall_status = "healthy" if all(loaded.values()) else "degraded"

    return HealthResponse(
        status=overall_status,
        version="1.0.0",
        models_loaded=loaded,
        uptime_seconds=uptime,
    )


@app.get("/")
async def root():
    """Root endpoint: names the API and points to the interactive docs."""
    landing = {"message": "RAGent Text Lab API", "docs": "/docs"}
    return landing

In [None]:
## Cell 7: Chat 端點實作
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Basic chat endpoint: answer a single user message with the LLM adapter.

    Returns the reply, a conversation id (the request's, or a new ``chat_<hex>``
    id), an estimated token count and the handler latency in milliseconds.

    Raises:
        HTTPException: 503 if the LLM adapter is not initialized; 500 if
            generation or response assembly fails.
    """
    llm = app_state["llm_adapter"]
    if llm is None:
        # Raised before the try block so the generic handler below does not
        # re-wrap this 503 as a 500.
        raise HTTPException(status_code=503, detail="LLM adapter not initialized")

    try:
        start_time = time.perf_counter()

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": request.message},
        ]

        # generate() is synchronous (the mock sleeps); run it in a worker thread so
        # a slow generation does not stall the event loop for every other request.
        response_text = await asyncio.to_thread(
            llm.generate,
            messages=messages,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
        )

        latency_ms = (time.perf_counter() - start_time) * 1000
        tokens_used = llm.count_tokens(response_text)
        # uuid4 instead of a seconds timestamp so conversations started in the same
        # second do not receive the same id.
        conversation_id = request.conversation_id or f"chat_{uuid.uuid4().hex}"

        return ChatResponse(
            response=response_text,
            conversation_id=conversation_id,
            tokens_used=int(tokens_used),
            latency_ms=latency_ms,
        )

    except Exception as e:
        # logger.exception also records the traceback.
        logger.exception("Chat endpoint error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Chat processing failed: {str(e)}"
        ) from e

In [None]:
## Cell 8: RAG 端點實作
@app.post("/rag", response_model=RAGResponse)
async def rag_endpoint(request: RAGRequest):
    """Retrieval-augmented answer endpoint.

    Retrieves up to ``request.top_k`` documents, builds a numbered context block
    and asks the LLM to answer with [n]-style citations. Timings are in seconds.

    NOTE(review): ``request.domain`` and ``request.use_rerank`` are accepted but
    not used yet.

    Raises:
        HTTPException: 503 if the retriever or the LLM adapter is not
            initialized; 500 if retrieval, generation or response assembly fails.
    """
    retriever = app_state["rag_retriever"]
    llm = app_state["llm_adapter"]
    if retriever is None or llm is None:
        # Raised before the try block so it is not re-wrapped as a 500 below.
        raise HTTPException(status_code=503, detail="RAG components not initialized")

    snippet_chars = 200  # max characters of each source text echoed back

    try:
        # search() and generate() are synchronous; run them in worker threads so
        # they do not block the event loop for concurrent requests.
        retrieval_start = time.perf_counter()
        retrieved_docs = await asyncio.to_thread(
            retriever.search, query=request.query, top_k=request.top_k
        )
        retrieval_time = time.perf_counter() - retrieval_start

        # Numbered context so the model can cite documents as [1], [2], ...
        context = "\n\n".join(
            f"[{i}] {doc['text']}" for i, doc in enumerate(retrieved_docs, start=1)
        )

        messages = [
            {
                "role": "system",
                "content": "Answer based on the provided context. Use citations like [1], [2].",
            },
            {
                "role": "user",
                "content": f"Question: {request.query}\n\nContext:\n{context}",
            },
        ]

        gen_start = time.perf_counter()
        # Fixed, lower temperature keeps answers close to the retrieved context.
        answer = await asyncio.to_thread(
            llm.generate, messages=messages, max_tokens=512, temperature=0.3
        )
        generation_time = time.perf_counter() - gen_start

        # Ids match the [n] citation numbers used in the context above. Missing
        # "meta"/"score" keys degrade gracefully instead of raising KeyError.
        sources = [
            {
                "id": i,
                "text": (
                    doc["text"][:snippet_chars] + "..."
                    if len(doc["text"]) > snippet_chars
                    else doc["text"]
                ),
                "metadata": doc.get("meta", {}),
                "score": doc.get("score"),
            }
            for i, doc in enumerate(retrieved_docs, start=1)
        ]

        return RAGResponse(
            answer=answer,
            sources=sources,
            # No separate query-embedding step is timed here (presumably the
            # retriever embeds internally -- confirm); report 0 until it exposes one.
            query_embedding_time=0.0,
            retrieval_time=retrieval_time,
            generation_time=generation_time,
        )

    except Exception as e:
        logger.exception("RAG endpoint error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"RAG processing failed: {str(e)}"
        ) from e

In [None]:
## Cell 9: Agent 端點實作
@app.post("/agent", response_model=AgentResponse)
async def agent_endpoint(request: AgentRequest):
    """Multi-agent collaboration endpoint: run ``request.task`` through the orchestrator.

    Raises:
        HTTPException: 503 if the orchestrator is not initialized; 500 if the
            task execution or response assembly fails.
    """
    orchestrator = app_state["agent_orchestrator"]
    if orchestrator is None:
        # Raised before the try block so it is not re-wrapped as a 500 below.
        raise HTTPException(
            status_code=503, detail="Agent orchestrator not initialized"
        )

    try:
        start_time = time.perf_counter()

        # execute_task() is synchronous; offload it so the event loop stays free.
        result, execution_log = await asyncio.to_thread(
            orchestrator.execute_task,
            task=request.task,
            mode=request.mode,
            max_iterations=request.max_iterations,
        )

        total_time = time.perf_counter() - start_time

        return AgentResponse(
            result=result,
            execution_log=execution_log,
            total_time=total_time,
            # NOTE(review): this counts log entries (steps), not orchestration
            # iterations -- confirm against the real orchestrator.
            iterations_used=len(execution_log),
        )

    except Exception as e:
        logger.exception("Agent endpoint error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Agent processing failed: {str(e)}"
        ) from e

In [None]:
## Cell 10: Game 端點實作
@app.post("/game", response_model=GameResponse)
async def game_endpoint(request: GameRequest):
    """Text-adventure endpoint: apply one player action to a game session.

    NOTE(review): ``request.save_state`` is accepted but not used yet.

    Raises:
        HTTPException: 503 if the game engine is not initialized; 500 if the
            action or response assembly fails.
    """
    engine = app_state["game_engine"]
    if engine is None:
        # Raised before the try block so it is not re-wrapped as a 500 below.
        raise HTTPException(status_code=503, detail="Game engine not initialized")

    try:
        # process_action() is synchronous; offload it so the event loop stays free.
        # NOTE(review): calls may now run concurrently in worker threads, so a real
        # engine's session store must tolerate concurrent access.
        narrative, choices, game_state, session_id = await asyncio.to_thread(
            engine.process_action,
            action=request.action,
            session_id=request.session_id,
        )

        return GameResponse(
            narrative=narrative,
            choices=choices,
            game_state=game_state,
            session_id=session_id,
        )

    except Exception as e:
        logger.exception("Game endpoint error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Game processing failed: {str(e)}"
        ) from e


In [None]:
## Cell 11: 錯誤處理

# Global exception handlers
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
    """Render an HTTPException as JSON ``{error, path, method, timestamp}``.

    Headers attached to the exception (e.g. ``WWW-Authenticate``,
    ``Retry-After``) are forwarded so they are not lost by overriding the
    default handler.

    NOTE(review): this is registered for fastapi.HTTPException; Starlette's own
    HTTPException (e.g. routing 404/405) may bypass it. FastAPI's docs suggest
    registering the handler for starlette.exceptions.HTTPException instead.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "error": exc.detail,
            "path": str(request.url.path),
            "method": request.method,
            "timestamp": time.time(),
        },
        headers=getattr(exc, "headers", None),
    )


@app.exception_handler(Exception)
async def general_exception_handler(request, exc):
    """Fallback for unhandled exceptions: log with traceback, return a generic 500.

    The exception text is intentionally not echoed to the client.
    """
    # Pass the exception explicitly so the traceback is logged, not just str(exc).
    logger.error("Unhandled exception: %s", exc, exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal server error",
            "path": str(request.url.path),
            "method": request.method,
            "timestamp": time.time(),
        },
    )

In [None]:
## Cell 12: Smoke Test - 本地啟動
# Smoke test - start server in background for testing
import threading
import requests
import time


SERVER_HOST = "127.0.0.1"
SERVER_PORT = 8000
STARTUP_TIMEOUT_S = 30.0  # generous: real model loading during lifespan can be slow

# Keep a handle on the uvicorn server (uvicorn.run() offers none) so the notebook
# can wait for startup and shut the server down again.
uvicorn_server = uvicorn.Server(
    uvicorn.Config(app, host=SERVER_HOST, port=SERVER_PORT, log_level="info")
)


def start_server():
    """Run the FastAPI app under uvicorn; blocks, so call it from a background thread."""
    uvicorn_server.run()


def stop_server(timeout=10.0):
    """Ask the background server to exit; wait up to ``timeout`` s. Returns True if stopped."""
    uvicorn_server.should_exit = True
    server_thread.join(timeout=timeout)
    return not server_thread.is_alive()


# Start server in background
server_thread = threading.Thread(target=start_server, daemon=True)
server_thread.start()

# Poll for readiness instead of a fixed sleep: stop as soon as uvicorn reports
# startup complete, or when the thread has died (e.g. the port is already in use).
deadline = time.monotonic() + STARTUP_TIMEOUT_S
while (
    not uvicorn_server.started
    and server_thread.is_alive()
    and time.monotonic() < deadline
):
    time.sleep(0.1)

server_url = f"http://{SERVER_HOST}:{SERVER_PORT}"
if uvicorn_server.started:
    print(f"🚀 FastAPI server started at {server_url}")
    print(f"📚 API Documentation: {server_url}/docs")
    print(f"🔍 Alternative docs: {server_url}/redoc")
else:
    print(f"❌ FastAPI server did not start at {server_url} (see the log output above)")

In [None]:
## Cell 13: API 測試
# Test the endpoints
base_url = "http://127.0.0.1:8000"
REQUEST_TIMEOUT_S = 30  # never hang the notebook indefinitely on a stuck server

# (label, HTTP method, path, JSON payload) for each endpoint under test.
smoke_checks = [
    ("Health Check", "GET", "/health", None),
    (
        "Chat",
        "POST",
        "/chat",
        {"message": "Hello, how are you?", "max_tokens": 100, "temperature": 0.7},
    ),
    ("RAG", "POST", "/rag", {"query": "What is machine learning?", "top_k": 3}),
    (
        "Agent",
        "POST",
        "/agent",
        {"task": "Research and write about renewable energy", "mode": "research_write"},
    ),
    ("Game", "POST", "/game", {"action": "explore the forest"}),
]

try:
    for label, method, path, payload in smoke_checks:
        resp = requests.request(
            method, f"{base_url}{path}", json=payload, timeout=REQUEST_TIMEOUT_S
        )
        print(f"{label}: {resp.status_code}")
        print(f"Response: {resp.json()}\n")
        # A 4xx/5xx is a failure, not merely something to print.
        resp.raise_for_status()

    print("✅ All endpoints responding successfully!")

except requests.exceptions.ConnectionError:
    print("❌ Server not responding. Make sure it's started.")
except Exception as e:
    print(f"❌ Test failed: {str(e)}")

In [None]:
## Cell 14: 生產環境配置建議
# Production deployment considerations.
# Illustrative settings only: this cell just prints them.
production_config = dict(
    server=dict(
        host="0.0.0.0",
        port=8000,
        workers=4,  # For uvicorn with --workers
        log_level="warning",
    ),
    security=dict(
        cors_origins=["https://yourdomain.com"],
        rate_limit="100/minute",
        api_key_required=True,
    ),
    performance=dict(max_request_size="10MB", timeout=300, keep_alive=75),
    monitoring=dict(
        health_check_interval=30,
        metrics_enabled=True,
        log_requests=True,
    ),
)

print("🔧 Production Configuration:")
print(json.dumps(production_config, indent=2))

checklist = [
    "Set environment variables (MODEL_ID, API_KEYS)",
    "Configure CORS origins properly",
    "Set up rate limiting middleware",
    "Add API authentication",
    "Configure logging aggregation",
    "Set up health monitoring",
    "Use HTTPS in production",
    "Set appropriate resource limits",
]

print("\n📋 Deployment Checklist:")
print("\n".join(f"{number}. {step}" for number, step in enumerate(checklist, start=1)))

In [None]:
## Cell 15: 停止伺服器 (清理)
# Note: In production, use proper process management.
# This cell only prints guidance; it does not itself stop the background thread.
shutdown_notes = (
    "📝 Note: Server is running in background thread.",
    "To properly stop in production, use:",
    "- uvicorn main:app --reload (development)",
    "- gunicorn main:app -w 4 -k uvicorn.workers.UvicornWorker (production)",
    "- Docker containers with proper signal handling",
    "- Process managers like supervisor or systemd",
)
for note in shutdown_notes:
    print(note)
