diff --git a/src/climatevision/api/main.py b/src/climatevision/api/main.py index bd1d2df..a155ed4 100644 --- a/src/climatevision/api/main.py +++ b/src/climatevision/api/main.py @@ -13,15 +13,20 @@ import json import logging +import time from datetime import datetime, timezone from pathlib import Path from typing import Any, Optional, Literal -from fastapi import FastAPI, File, Form, HTTPException, UploadFile, Header, Query, Depends +from pydantic import field_validator + +from fastapi import FastAPI, File, Form, HTTPException, UploadFile, Header, Query, Depends, Request from fastapi.responses import RedirectResponse from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles -from pydantic import BaseModel, Field, EmailStr +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.responses import Response +from pydantic import BaseModel, Field, EmailStr, model_validator from climatevision.db import ( get_connection, @@ -106,6 +111,36 @@ class PredictRequest(BaseModel): start_date: Optional[str] = None end_date: Optional[str] = None + @field_validator("bbox") + @classmethod + def validate_bbox(cls, v: Optional[list[float]]) -> Optional[list[float]]: + if v is None: + return v + if len(v) != 4: + raise ValueError("bbox must have exactly 4 values: [west, south, east, north]") + west, south, east, north = v + if not (-180 <= west <= 180 and -180 <= east <= 180): + raise ValueError("bbox longitude values must be between -180 and 180") + if not (-90 <= south <= 90 and -90 <= north <= 90): + raise ValueError("bbox latitude values must be between -90 and 90") + if west >= east: + raise ValueError("bbox west longitude must be less than east longitude") + if south >= north: + raise ValueError("bbox south latitude must be less than north latitude") + return v + + @model_validator(mode="after") + def validate_date_range(self) -> "PredictRequest": + if self.start_date and self.end_date: + try: + start = datetime.strptime(self.start_date, "%Y-%m-%d") + end = datetime.strptime(self.end_date, "%Y-%m-%d") + except ValueError: + raise ValueError("start_date and end_date must be in YYYY-MM-DD format") + if start >= end: + raise ValueError("start_date must be earlier than end_date") + return self + class RunRow(BaseModel): id: int @@ -283,6 +318,28 @@ async def get_current_organization( return None +# ===== Audit Logging Middleware ===== + +class AuditLogMiddleware(BaseHTTPMiddleware): + """Log every API request with method, path, status code, and duration.""" + + async def dispatch(self, request: Request, call_next: Any) -> Response: + start = time.perf_counter() + response: Response = await call_next(request) + duration_ms = round((time.perf_counter() - start) * 1000, 2) + + logger.info( + "API request | method=%s path=%s status=%s duration_ms=%s ip=%s", + request.method, + request.url.path, + response.status_code, + duration_ms, + request.client.host if request.client else "unknown", + ) + response.headers["X-Response-Time-Ms"] = str(duration_ms) + return response + + # ===== Application Factory ===== def create_app() -> FastAPI: @@ -305,6 +362,7 @@ def create_app() -> FastAPI: openapi_url="/openapi.json", ) + app.add_middleware(AuditLogMiddleware) app.add_middleware( CORSMiddleware, allow_origins=[ @@ -354,27 +412,79 @@ def get_analysis_type(analysis_type: str) -> dict[str, Any]: @app.get("/api/runs") def list_runs( limit: int = Query(default=50, le=200), + offset: int = Query(default=0, ge=0), status: Optional[str] = None, analysis_type: Optional[str] = None, - ) -> list[RunRow]: - """List analysis runs with optional filtering.""" - query = "SELECT * FROM runs WHERE 1=1" + ) -> dict[str, Any]: + """List analysis runs with optional filtering and pagination metadata.""" + where_clauses = ["1=1"] params: list = [] - + if status: - query += " AND status = ?" + where_clauses.append("status = ?") params.append(status) if analysis_type: - query += " AND analysis_type = ?" + where_clauses.append("analysis_type = ?") params.append(analysis_type) - - query += " ORDER BY id DESC LIMIT ?" - params.append(int(limit)) - + + where = " AND ".join(where_clauses) + with get_connection() as conn: - rows = conn.execute(query, params).fetchall() - - return [RunRow(**dict(r)) for r in rows] + total: int = conn.execute( + f"SELECT COUNT(*) FROM runs WHERE {where}", params + ).fetchone()[0] + rows = conn.execute( + f"SELECT * FROM runs WHERE {where} ORDER BY id DESC LIMIT ? OFFSET ?", + params + [int(limit), int(offset)], + ).fetchall() + + return { + "total": total, + "limit": limit, + "offset": offset, + "runs": [RunRow(**dict(r)) for r in rows], + } + + @app.get("/api/runs/stats") + def get_run_stats() -> dict[str, Any]: + """Return aggregated run statistics for dashboard KPI cards.""" + with get_connection() as conn: + total = conn.execute("SELECT COUNT(*) FROM runs").fetchone()[0] + + by_status = { + row["status"]: row["count"] + for row in conn.execute( + "SELECT status, COUNT(*) as count FROM runs GROUP BY status" + ).fetchall() + } + + by_analysis_type = { + row["analysis_type"]: row["count"] + for row in conn.execute( + "SELECT analysis_type, COUNT(*) as count FROM runs GROUP BY analysis_type" + ).fetchall() + } + + recent_completed = conn.execute( + "SELECT COUNT(*) FROM runs WHERE status = 'completed' " + "AND created_at >= datetime('now', '-7 days')" + ).fetchone()[0] + + alerts_total = conn.execute("SELECT COUNT(*) FROM alerts").fetchone()[0] + alerts_unacknowledged = conn.execute( + "SELECT COUNT(*) FROM alerts WHERE acknowledged = 0" + ).fetchone()[0] + + return { + "total_runs": total, + "completed_last_7_days": recent_completed, + "by_status": by_status, + "by_analysis_type": by_analysis_type, + "alerts": { + "total": alerts_total, + "unacknowledged": alerts_unacknowledged, + }, + } @app.get("/api/runs/{run_id}") def get_run(run_id: int) -> dict[str, Any]: