Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 125 additions & 15 deletions src/climatevision/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@

import json
import logging
import time
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Optional, Literal

from fastapi import FastAPI, File, Form, HTTPException, UploadFile, Header, Query, Depends
from pydantic import field_validator

from fastapi import FastAPI, File, Form, HTTPException, UploadFile, Header, Query, Depends, Request
from fastapi.responses import RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field, EmailStr
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response
from pydantic import BaseModel, Field, EmailStr, model_validator

from climatevision.db import (
get_connection,
Expand Down Expand Up @@ -106,6 +111,36 @@ class PredictRequest(BaseModel):
start_date: Optional[str] = None
end_date: Optional[str] = None

@field_validator("bbox")
@classmethod
def validate_bbox(cls, v: Optional[list[float]]) -> Optional[list[float]]:
if v is None:
return v
if len(v) != 4:
raise ValueError("bbox must have exactly 4 values: [west, south, east, north]")
west, south, east, north = v
if not (-180 <= west <= 180 and -180 <= east <= 180):
raise ValueError("bbox longitude values must be between -180 and 180")
if not (-90 <= south <= 90 and -90 <= north <= 90):
raise ValueError("bbox latitude values must be between -90 and 90")
if west >= east:
raise ValueError("bbox west longitude must be less than east longitude")
if south >= north:
raise ValueError("bbox south latitude must be less than north latitude")
return v

@model_validator(mode="after")
def validate_date_range(self) -> "PredictRequest":
if self.start_date and self.end_date:
try:
start = datetime.strptime(self.start_date, "%Y-%m-%d")
end = datetime.strptime(self.end_date, "%Y-%m-%d")
except ValueError:
raise ValueError("start_date and end_date must be in YYYY-MM-DD format")
if start >= end:
raise ValueError("start_date must be earlier than end_date")
return self


class RunRow(BaseModel):
id: int
Expand Down Expand Up @@ -283,6 +318,28 @@ async def get_current_organization(
return None


# ===== Audit Logging Middleware =====

class AuditLogMiddleware(BaseHTTPMiddleware):
"""Log every API request with method, path, status code, and duration."""

async def dispatch(self, request: Request, call_next: Any) -> Response:
start = time.perf_counter()
response: Response = await call_next(request)
duration_ms = round((time.perf_counter() - start) * 1000, 2)

logger.info(
"API request | method=%s path=%s status=%s duration_ms=%s ip=%s",
request.method,
request.url.path,
response.status_code,
duration_ms,
request.client.host if request.client else "unknown",
)
response.headers["X-Response-Time-Ms"] = str(duration_ms)
return response


# ===== Application Factory =====

def create_app() -> FastAPI:
Expand All @@ -305,6 +362,7 @@ def create_app() -> FastAPI:
openapi_url="/openapi.json",
)

app.add_middleware(AuditLogMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=[
Expand Down Expand Up @@ -354,27 +412,79 @@ def get_analysis_type(analysis_type: str) -> dict[str, Any]:
@app.get("/api/runs")
def list_runs(
limit: int = Query(default=50, le=200),
offset: int = Query(default=0, ge=0),
status: Optional[str] = None,
analysis_type: Optional[str] = None,
) -> list[RunRow]:
"""List analysis runs with optional filtering."""
query = "SELECT * FROM runs WHERE 1=1"
) -> dict[str, Any]:
"""List analysis runs with optional filtering and pagination metadata."""
where_clauses = ["1=1"]
params: list = []

if status:
query += " AND status = ?"
where_clauses.append("status = ?")
params.append(status)
if analysis_type:
query += " AND analysis_type = ?"
where_clauses.append("analysis_type = ?")
params.append(analysis_type)

query += " ORDER BY id DESC LIMIT ?"
params.append(int(limit))


where = " AND ".join(where_clauses)

with get_connection() as conn:
rows = conn.execute(query, params).fetchall()

return [RunRow(**dict(r)) for r in rows]
total: int = conn.execute(
f"SELECT COUNT(*) FROM runs WHERE {where}", params
).fetchone()[0]
rows = conn.execute(
f"SELECT * FROM runs WHERE {where} ORDER BY id DESC LIMIT ? OFFSET ?",
params + [int(limit), int(offset)],
).fetchall()

return {
"total": total,
"limit": limit,
"offset": offset,
"runs": [RunRow(**dict(r)) for r in rows],
}

@app.get("/api/runs/stats")
def get_run_stats() -> dict[str, Any]:
"""Return aggregated run statistics for dashboard KPI cards."""
with get_connection() as conn:
total = conn.execute("SELECT COUNT(*) FROM runs").fetchone()[0]

by_status = {
row["status"]: row["count"]
for row in conn.execute(
"SELECT status, COUNT(*) as count FROM runs GROUP BY status"
).fetchall()
}

by_analysis_type = {
row["analysis_type"]: row["count"]
for row in conn.execute(
"SELECT analysis_type, COUNT(*) as count FROM runs GROUP BY analysis_type"
).fetchall()
}

recent_completed = conn.execute(
"SELECT COUNT(*) FROM runs WHERE status = 'completed' "
"AND created_at >= datetime('now', '-7 days')"
).fetchone()[0]

alerts_total = conn.execute("SELECT COUNT(*) FROM alerts").fetchone()[0]
alerts_unacknowledged = conn.execute(
"SELECT COUNT(*) FROM alerts WHERE acknowledged = 0"
).fetchone()[0]

return {
"total_runs": total,
"completed_last_7_days": recent_completed,
"by_status": by_status,
"by_analysis_type": by_analysis_type,
"alerts": {
"total": alerts_total,
"unacknowledged": alerts_unacknowledged,
},
}

@app.get("/api/runs/{run_id}")
def get_run(run_id: int) -> dict[str, Any]:
Expand Down