diff --git a/.claude/settings.local.json b/.claude/settings.local.json index c3a7a075..9fdeaa6f 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -17,7 +17,12 @@ "WebFetch(domain:docs.anthropic.com)", "mcp__linear-server__list_projects", "mcp__linear-server__list_issues", - "mcp__linear-server__get_issue" + "mcp__linear-server__get_issue", + "Bash(make build:*)", + "Bash(git add:*)", + "Bash(git commit:*)", + "Bash(git push)", + "Bash(git -C:*)" ], "deny": [], "ask": [] diff --git a/.env.template b/.env.template index 99f5bbd1..0b4d3acf 100644 --- a/.env.template +++ b/.env.template @@ -36,6 +36,20 @@ # Optional: Custom cache directory for local file cache # GOMODEL_CACHE_DIR=.cache +# ============================================================================= +# Admin API & Dashboard Configuration +# ============================================================================= + +# Enable/disable admin REST API endpoints (default: true) +# When enabled, provides /admin/api/v1/* REST endpoints +# ADMIN_ENDPOINTS_ENABLED=true + +# Enable/disable admin dashboard UI (default: true) +# When enabled, provides /admin/dashboard UI +# Requires ADMIN_ENDPOINTS_ENABLED=true — if endpoints are disabled +# and UI is enabled, a warning is logged and UI is forced to disabled +# ADMIN_UI_ENABLED=true + # ============================================================================= # Storage Configuration (used by audit logging, usage tracking, future IAM, etc.) # ============================================================================= diff --git a/CLAUDE.md b/CLAUDE.md index 6f61efdf..5993a943 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -157,6 +157,23 @@ Full reference: `.env.template` and `config/config.yaml` - **Guardrails:** Configured via `config/config.yaml` only (except `GUARDRAILS_ENABLED` env var) - **Providers:** `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `GEMINI_API_KEY`, `XAI_API_KEY`, `GROQ_API_KEY`, `OLLAMA_BASE_URL` +## Documentation Maintenance + +After completing any code change, routinely check whether documentation needs updating. This applies to all three documentation layers: + +1. **README files** (`README.md`, `helm/README.md`, `tests/contract/README.md`) — Update when adding/removing features, changing setup steps, modifying CLI flags, or altering configuration options. +2. **In-code documentation** (Go doc comments on exported types, functions, interfaces) — Update when changing public APIs, adding new exported symbols, or modifying function signatures/behavior. +3. **Mintlify / technical docs** (`docs/` directory) — Update `docs/advanced/*.mdx` pages when changing configuration options or guardrails behavior. Update `docs/adr/` when making significant architectural decisions. Update `docs/plans/` if implementation diverges from existing plans. Check `docs.json` if new pages need to be added to the navigation. + +**When to update:** +- Adding a new provider, endpoint, config option, or feature +- Changing existing behavior, defaults, or API contracts +- Renaming or removing configuration variables +- Adding or modifying middleware, guardrails, or storage backends +- Changing build/test commands or requirements + +**How to check:** After making changes, scan the affected documentation layers for stale or missing information. Do not add speculative documentation for unimplemented features — only document what exists. + ## Key Details 1. Providers are registered explicitly via `factory.Register()` in main.go — order matters, first registered wins for duplicate model names diff --git a/README.md b/README.md index 05afc417..f70cbb32 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,10 @@ docker run --rm -p 8080:8080 --env-file .env gomodel | `/v1/models` | GET | List available models | | `/health` | GET | Health check | | `/metrics` | GET | Prometheus metrics (when enabled) | +| `/admin/api/v1/usage/summary` | GET | Aggregate token usage statistics | +| `/admin/api/v1/usage/daily` | GET | Per-period token usage breakdown | +| `/admin/api/v1/models` | GET | List models with provider type | +| `/admin/dashboard` | GET | Admin dashboard UI | --- @@ -194,10 +198,10 @@ See [DEVELOPMENT.md](DEVELOPMENT.md) for testing, linting, and pre-commit setup. | Full-observability | 🚧 | 🚧 | | Budget management | 🚧 | 🚧 | | Many keys support | 🚧 | 🚧 | -| Administrative endpoints | 🚧 | 🚧 | -| Guardrails | 🚧 | 🚧 | +| Administrative endpoints | ✅ | 🚧 | +| Guardrails | ✅ | 🚧 | | SSO | 🚧 | 🚧 | -| System Prompt (GuardRails) | 🚧 | 🚧 | +| System Prompt (GuardRails) | ✅ | 🚧 | ## Integrations diff --git a/config/config.go b/config/config.go index 0b9269f7..e2773a14 100644 --- a/config/config.go +++ b/config/config.go @@ -34,10 +34,24 @@ type Config struct { Usage UsageConfig `yaml:"usage"` Metrics MetricsConfig `yaml:"metrics"` HTTP HTTPConfig `yaml:"http"` + Admin AdminConfig `yaml:"admin"` Guardrails GuardrailsConfig `yaml:"guardrails"` Providers map[string]ProviderConfig `yaml:"providers"` } +// AdminConfig holds configuration for the admin API and dashboard UI. +type AdminConfig struct { + // EndpointsEnabled controls whether the admin REST API is active + // Default: true + EndpointsEnabled bool `yaml:"endpoints_enabled" env:"ADMIN_ENDPOINTS_ENABLED"` + + // UIEnabled controls whether the admin dashboard UI is active + // Requires EndpointsEnabled — if endpoints are disabled and UI is enabled, + // a warning is logged and UI is forced to false. + // Default: true + UIEnabled bool `yaml:"ui_enabled" env:"ADMIN_UI_ENABLED"` +} + // GuardrailsConfig holds configuration for the request guardrails pipeline. type GuardrailsConfig struct { // Enabled controls whether guardrails are active @@ -215,7 +229,7 @@ type RedisConfig struct { // ServerConfig holds HTTP server configuration type ServerConfig struct { Port string `yaml:"port" env:"PORT"` - MasterKey string `yaml:"master_key" env:"GOMODEL_MASTER_KEY"` // Optional: Master key for authentication + MasterKey string `yaml:"master_key" env:"GOMODEL_MASTER_KEY"` // Optional: Master key for authentication BodySizeLimit string `yaml:"body_size_limit" env:"BODY_SIZE_LIMIT"` // Max request body size (e.g., "10M", "1024K") } @@ -284,8 +298,9 @@ func defaultConfig() Config { Timeout: 600, ResponseHeaderTimeout: 600, }, + Admin: AdminConfig{EndpointsEnabled: true, UIEnabled: true}, Guardrails: GuardrailsConfig{}, - Providers: make(map[string]ProviderConfig), + Providers: make(map[string]ProviderConfig), } } diff --git a/docs/advanced/admin-endpoints.mdx b/docs/advanced/admin-endpoints.mdx new file mode 100644 index 00000000..2ab1a14b --- /dev/null +++ b/docs/advanced/admin-endpoints.mdx @@ -0,0 +1,185 @@ +--- +title: "Admin Endpoints" +description: "Built-in REST API and dashboard for monitoring usage, models, and gateway health." +--- + +## Philosophy + +GOModel ships with admin endpoints **enabled by default**. The goal is simple: you should be able to deploy GOModel and immediately have visibility into what's happening — no extra services, no separate monitoring stack, no configuration. + +The admin layer is split into two independently controllable pieces: + +1. **Admin REST API** (`/admin/api/v1/*`) — machine-readable JSON endpoints for usage data and model inventory. Protected by `GOMODEL_MASTER_KEY` like all other API routes. +2. **Admin Dashboard UI** (`/admin/dashboard`) — a lightweight, embedded HTML dashboard that visualizes the same data. No external dependencies, no JavaScript frameworks to install — it's compiled into the binary. + +Both are on by default because observability shouldn't be opt-in. If you don't need them, turn them off with a single environment variable. + +## Configuration + +| Variable | Description | Default | +| ------------------------- | ------------------------------------ | ------- | +| `ADMIN_ENDPOINTS_ENABLED` | Enable the admin REST API | `true` | +| `ADMIN_UI_ENABLED` | Enable the admin dashboard UI | `true` | + +Or in YAML: + +```yaml +admin: + endpoints_enabled: true + ui_enabled: true +``` + + + The dashboard UI requires the REST API to be enabled. If you set + `ADMIN_ENDPOINTS_ENABLED=false` but leave `ADMIN_UI_ENABLED=true`, the UI + will be automatically disabled with a warning in the logs. + + +## Authentication + +The admin REST API endpoints (`/admin/api/v1/*`) are protected by the same `GOMODEL_MASTER_KEY` authentication as the main API routes. Include the key as a Bearer token: + +```bash +curl -H "Authorization: Bearer $GOMODEL_MASTER_KEY" \ + http://localhost:8080/admin/api/v1/usage/summary +``` + +The dashboard UI pages (`/admin/dashboard`) and static assets (`/admin/static/*`) **skip authentication** so the dashboard is accessible without configuring API keys in the browser. + + + If your GOModel instance is publicly accessible, be aware that the dashboard + UI is unauthenticated. Disable it with `ADMIN_UI_ENABLED=false` or restrict + access at the network level. + + +## REST API Endpoints + +All admin API endpoints are mounted under `/admin/api/v1`. + +### GET /admin/api/v1/usage/summary + +Returns aggregate token usage statistics over a configurable time window. + +**Query parameters:** + +| Parameter | Type | Description | Default | +| ------------ | ------ | -------------------------------------------------------- | -------------------- | +| `start_date` | string | Range start in `YYYY-MM-DD` format | 29 days before end | +| `end_date` | string | Range end in `YYYY-MM-DD` format | Today | +| `days` | int | Shorthand for look-back window (ignored if dates are set) | `30` | + +Use `start_date`/`end_date` for explicit ranges or `days` as a shorthand. When both are provided, `start_date`/`end_date` take priority. + +**Response:** + +```json +{ + "total_requests": 1542, + "total_input_tokens": 2450000, + "total_output_tokens": 890000, + "total_tokens": 3340000 +} +``` + +If usage tracking is disabled, returns zeroed values. + +### GET /admin/api/v1/usage/daily + +Returns per-period token usage breakdown over a configurable time window, grouped by the specified interval. + +**Query parameters:** + +| Parameter | Type | Description | Default | +| ------------ | ------ | -------------------------------------------------------- | -------------------- | +| `start_date` | string | Range start in `YYYY-MM-DD` format | 29 days before end | +| `end_date` | string | Range end in `YYYY-MM-DD` format | Today | +| `days` | int | Shorthand for look-back window (ignored if dates are set) | `30` | +| `interval` | string | Grouping: `daily`, `weekly`, `monthly`, `yearly` | `daily` | + +The `date` field in the response changes format based on the interval: `YYYY-MM-DD` (daily), `YYYY-Www` (weekly), `YYYY-MM` (monthly), or `YYYY` (yearly). + +**Response:** + +```json +[ + { + "date": "2026-02-17", + "requests": 84, + "input_tokens": 120000, + "output_tokens": 45000, + "total_tokens": 165000 + }, + { + "date": "2026-02-16", + "requests": 102, + "input_tokens": 155000, + "output_tokens": 58000, + "total_tokens": 213000 + } +] +``` + +Returns an empty array if usage tracking is disabled or no data exists for the period. + +### GET /admin/api/v1/models + +Returns all registered models with their provider type. + +**Response:** + +```json +[ + { + "model": { + "id": "gpt-4o", + "object": "model", + "created": 1715367049, + "owned_by": "system" + }, + "provider_type": "openai" + }, + { + "model": { + "id": "claude-sonnet-4-5-20250929", + "object": "model", + "created": 1727568000, + "owned_by": "system" + }, + "provider_type": "anthropic" + } +] +``` + +This differs from the standard `/v1/models` endpoint: the admin version includes `provider_type` for each model, making it useful for understanding which provider serves which model. + +## Admin Dashboard + +The dashboard is a server-rendered HTML page embedded in the GOModel binary. Access it at: + +``` +http://localhost:8080/admin/dashboard +``` + +It provides a visual overview of usage statistics and registered models using the same data as the REST API endpoints above. + +## Disabling Admin Features + +To disable all admin features: + +```bash +export ADMIN_ENDPOINTS_ENABLED=false +``` + +This disables both the REST API and the dashboard UI. To keep the API but hide the dashboard: + +```bash +export ADMIN_ENDPOINTS_ENABLED=true +export ADMIN_UI_ENABLED=false +``` + + + The admin layer is designed to degrade gracefully. If usage tracking is off, + the usage endpoints return empty results instead of errors. If the model + registry isn't ready, the models endpoint returns an empty array. The gateway + keeps running regardless. + diff --git a/docs/advanced/configuration.mdx b/docs/advanced/configuration.mdx index fcab4ba3..f96c729c 100644 --- a/docs/advanced/configuration.mdx +++ b/docs/advanced/configuration.mdx @@ -101,6 +101,13 @@ Storage is shared by audit logging, usage tracking, and future features like IAM | `METRICS_ENABLED` | Enable Prometheus metrics | `false` | | `METRICS_ENDPOINT` | HTTP path for metrics | `/metrics` | +#### Admin + +| Variable | Description | Default | +| ------------------------- | ----------------------------- | ------- | +| `ADMIN_ENDPOINTS_ENABLED` | Enable the admin REST API | `true` | +| `ADMIN_UI_ENABLED` | Enable the admin dashboard UI | `true` | + #### HTTP Client These control timeouts for upstream API requests to LLM providers. diff --git a/docs/docs.json b/docs/docs.json index 57c39cec..e6e6262d 100644 --- a/docs/docs.json +++ b/docs/docs.json @@ -9,7 +9,7 @@ "pages": [ { "group": "Advanced", - "pages": ["advanced/configuration", "advanced/guardrails"] + "pages": ["advanced/configuration", "advanced/guardrails", "advanced/admin-endpoints"] } ] } diff --git a/internal/admin/dashboard/dashboard.go b/internal/admin/dashboard/dashboard.go new file mode 100644 index 00000000..e08aac69 --- /dev/null +++ b/internal/admin/dashboard/dashboard.go @@ -0,0 +1,57 @@ +// Package dashboard provides the embedded admin dashboard UI for GOModel. +package dashboard + +import ( + "bytes" + "embed" + "html/template" + "io/fs" + "net/http" + + "github.com/labstack/echo/v4" +) + +//go:embed templates/*.html static/css/*.css static/js/*.js static/*.svg +var content embed.FS + +// Handler serves the admin dashboard UI. +type Handler struct { + indexTmpl *template.Template + staticFS http.Handler +} + +// New creates a new dashboard handler with parsed templates and static file server. +func New() (*Handler, error) { + tmpl, err := template.ParseFS(content, "templates/layout.html", "templates/index.html") + if err != nil { + return nil, err + } + + staticSub, err := fs.Sub(content, "static") + if err != nil { + return nil, err + } + + return &Handler{ + indexTmpl: tmpl, + staticFS: http.StripPrefix("/admin/static/", http.FileServer(http.FS(staticSub))), + }, nil +} + +// Index serves GET /admin/dashboard — the main dashboard page. +func (h *Handler) Index(c echo.Context) error { + var buf bytes.Buffer + if err := h.indexTmpl.ExecuteTemplate(&buf, "layout", nil); err != nil { + return err + } + c.Response().Header().Set("Content-Type", "text/html; charset=utf-8") + c.Response().WriteHeader(http.StatusOK) + _, err := buf.WriteTo(c.Response().Writer) + return err +} + +// Static serves GET /admin/static/* — embedded CSS/JS assets. +func (h *Handler) Static(c echo.Context) error { + h.staticFS.ServeHTTP(c.Response().Writer, c.Request()) + return nil +} diff --git a/internal/admin/dashboard/dashboard_test.go b/internal/admin/dashboard/dashboard_test.go new file mode 100644 index 00000000..869cf8c3 --- /dev/null +++ b/internal/admin/dashboard/dashboard_test.go @@ -0,0 +1,139 @@ +package dashboard + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/labstack/echo/v4" +) + +func TestNew(t *testing.T) { + h, err := New() + if err != nil { + t.Fatalf("New() returned error: %v", err) + } + if h == nil { + t.Fatal("New() returned nil handler") + } +} + +func TestIndex_ReturnsHTML(t *testing.T) { + h, err := New() + if err != nil { + t.Fatalf("New() returned error: %v", err) + } + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + if err := h.Index(c); err != nil { + t.Fatalf("Index() returned error: %v", err) + } + + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + + contentType := rec.Header().Get("Content-Type") + if contentType != "text/html; charset=utf-8" { + t.Errorf("expected Content-Type text/html; charset=utf-8, got %s", contentType) + } + + body := strings.ToLower(rec.Body.String()) + if !strings.Contains(body, " + + + + diff --git a/internal/admin/dashboard/static/js/dashboard.js b/internal/admin/dashboard/static/js/dashboard.js new file mode 100644 index 00000000..fa07667f --- /dev/null +++ b/internal/admin/dashboard/static/js/dashboard.js @@ -0,0 +1,546 @@ +// GOModel Dashboard — Alpine.js + Chart.js logic + +function dashboard() { + return { + // State + page: 'overview', + days: '30', + loading: false, + authError: false, + needsAuth: false, + apiKey: '', + theme: 'system', // 'system', 'light', 'dark' + sidebarCollapsed: false, + + // Date picker + datePickerOpen: false, + selectedPreset: '30', + customStartDate: null, + customEndDate: null, + selectingDate: 'start', // 'start' or 'end' + calendarMonth: new Date(), + cursorHint: { show: false, x: 0, y: 0 }, + + // Interval + interval: 'daily', + + // Data + summary: { total_requests: 0, total_input_tokens: 0, total_output_tokens: 0, total_tokens: 0 }, + daily: [], + models: [], + + // Filters + modelFilter: '', + + // Chart + chart: null, + + init() { + this.apiKey = localStorage.getItem('gomodel_api_key') || ''; + this.theme = localStorage.getItem('gomodel_theme') || 'system'; + this.sidebarCollapsed = localStorage.getItem('gomodel_sidebar_collapsed') === 'true'; + this.applyTheme(); + + // Parse initial page from URL path + const path = window.location.pathname.replace(/\/$/, ''); + const slug = path.split('/').pop(); + this.page = (slug === 'models') ? 'models' : 'overview'; + + // Handle browser back/forward + window.addEventListener('popstate', () => { + const p = window.location.pathname.replace(/\/$/, ''); + const s = p.split('/').pop(); + this.page = (s === 'models') ? 'models' : 'overview'; + if (this.page === 'overview') this.renderChart(); + }); + + // Re-render chart when system theme changes (only matters in 'system' mode) + window.matchMedia('(prefers-color-scheme: dark)').addEventListener('change', () => { + if (this.theme === 'system') { + this.renderChart(); + } + }); + + this.fetchAll(); + }, + + toggleSidebar() { + this.sidebarCollapsed = !this.sidebarCollapsed; + localStorage.setItem('gomodel_sidebar_collapsed', this.sidebarCollapsed); + // Re-render chart after CSS transition so Chart.js picks up the new width + setTimeout(() => this.renderChart(), 220); + }, + + // Date picker methods + toggleDatePicker() { + this.datePickerOpen = !this.datePickerOpen; + if (this.datePickerOpen) { + this.calendarMonth = new Date(); + this.selectingDate = 'start'; + } + }, + + closeDatePicker() { + this.datePickerOpen = false; + this.cursorHint.show = false; + }, + + onCalendarMouseMove(e) { + this.cursorHint = { show: true, x: e.clientX, y: e.clientY }; + }, + + onCalendarMouseLeave() { + this.cursorHint.show = false; + }, + + selectPreset(days) { + this.selectedPreset = days; + this.customStartDate = null; + this.customEndDate = null; + this.selectingDate = 'start'; + this.days = days; + this.fetchUsage(); + this.closeDatePicker(); + }, + + selectionHint() { + return this.selectingDate === 'end' ? 'Select end date' : 'Select start date'; + }, + + dateRangeLabel() { + if (this.selectedPreset) return 'Last ' + this.selectedPreset + ' days'; + if (this.customStartDate && this.customEndDate) { + return this.formatDateShort(this.customStartDate) + ' \u2013 ' + this.formatDateShort(this.customEndDate); + } + if (this.customStartDate) { + return this.formatDateShort(this.customStartDate) + ' \u2013 ...'; + } + return 'Last 30 days'; + }, + + formatDateShort(date) { + const months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']; + return months[date.getMonth()] + ' ' + date.getDate() + ', ' + date.getFullYear(); + }, + + calendarTitle(offset) { + const d = new Date(this.calendarMonth.getFullYear(), this.calendarMonth.getMonth() + offset, 1); + const months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']; + return months[d.getMonth()] + ' ' + d.getFullYear(); + }, + + calendarDays(offset) { + const year = this.calendarMonth.getFullYear(); + const month = this.calendarMonth.getMonth() + offset; + const first = new Date(year, month, 1); + const last = new Date(year, month + 1, 0); + // Monday = 0, Sunday = 6 + let startDay = (first.getDay() + 6) % 7; + const days = []; + + // Padding days from previous month + const prevLast = new Date(year, month, 0); + for (let i = startDay - 1; i >= 0; i--) { + const d = prevLast.getDate() - i; + days.push({ day: d, date: new Date(year, month - 1, d), current: false, key: 'p' + d }); + } + + // Current month days + for (let d = 1; d <= last.getDate(); d++) { + days.push({ day: d, date: new Date(year, month, d), current: true, key: 'c' + d }); + } + + // Padding days from next month + const remaining = 42 - days.length; + for (let d = 1; d <= remaining; d++) { + days.push({ day: d, date: new Date(year, month + 1, d), current: false, key: 'n' + d }); + } + + return days; + }, + + prevMonth() { + this.calendarMonth = new Date(this.calendarMonth.getFullYear(), this.calendarMonth.getMonth() - 1, 1); + }, + + nextMonth() { + const next = new Date(this.calendarMonth.getFullYear(), this.calendarMonth.getMonth() + 1, 1); + const today = new Date(); + // Don't navigate past current month + if (next.getFullYear() < today.getFullYear() || + (next.getFullYear() === today.getFullYear() && next.getMonth() <= today.getMonth())) { + this.calendarMonth = next; + } + }, + + isCurrentMonth() { + const today = new Date(); + return this.calendarMonth.getFullYear() === today.getFullYear() + && this.calendarMonth.getMonth() === today.getMonth(); + }, + + selectCalendarDay(day) { + if (!day.current || this.isFutureDay(day)) return; + const clicked = new Date(day.date); + clicked.setHours(0, 0, 0, 0); + this.selectedPreset = null; + + if (this.selectingDate === 'start') { + this.customStartDate = clicked; + // Keep existing end date; if it's now before start, move it to start + if (this.customEndDate && this.customEndDate < clicked) { + this.customEndDate = clicked; + } + // If no end date yet, default to today + if (!this.customEndDate) { + const today = new Date(); + today.setHours(0, 0, 0, 0); + this.customEndDate = today; + } + this.selectingDate = 'end'; + this.fetchUsage(); + } else { + // Selecting end date + if (clicked < this.customStartDate) { + // Swap: treat click as new start, old start becomes end + this.customEndDate = this.customStartDate; + this.customStartDate = clicked; + } else { + this.customEndDate = clicked; + } + this.selectingDate = 'start'; + this.fetchUsage(); + this.closeDatePicker(); + } + }, + + isToday(day) { + if (!day.current) return false; + const today = new Date(); + return day.date.getFullYear() === today.getFullYear() && + day.date.getMonth() === today.getMonth() && + day.date.getDate() === today.getDate(); + }, + + isFutureDay(day) { + const today = new Date(); + today.setHours(23, 59, 59, 999); + return day.date > today; + }, + + isRangeStart(day) { + if (!day.current) return false; + const start = this._rangeStart(); + if (!start) return false; + return day.date.getFullYear() === start.getFullYear() && + day.date.getMonth() === start.getMonth() && + day.date.getDate() === start.getDate(); + }, + + isRangeEnd(day) { + if (!day.current) return false; + const end = this._rangeEnd(); + if (!end) return false; + return day.date.getFullYear() === end.getFullYear() && + day.date.getMonth() === end.getMonth() && + day.date.getDate() === end.getDate(); + }, + + isInRange(day) { + if (!day.current) return false; + const start = this._rangeStart(); + const end = this._rangeEnd(); + if (!start || !end) return false; + const dayDate = new Date(day.date); + dayDate.setHours(0, 0, 0, 0); + return dayDate >= start && dayDate <= end; + }, + + _rangeStart() { + if (this.customStartDate) return this.customStartDate; + if (this.selectedPreset) { + const s = new Date(); + s.setHours(0, 0, 0, 0); + s.setDate(s.getDate() - (parseInt(this.selectedPreset, 10) - 1)); + return s; + } + return null; + }, + + _rangeEnd() { + if (this.customEndDate) return this.customEndDate; + if (this.customStartDate || this.selectedPreset) { + const t = new Date(); + t.setHours(0, 0, 0, 0); + return t; + } + return null; + }, + + // Interval methods + setInterval(val) { + this.interval = val; + this.fetchUsage(); + }, + + chartTitle() { + const titles = { daily: 'Daily', weekly: 'Weekly', monthly: 'Monthly', yearly: 'Yearly' }; + return (titles[this.interval] || 'Daily') + ' Token Usage'; + }, + + navigate(page) { + this.page = page; + history.pushState(null, '', '/admin/dashboard/' + page); + if (page === 'overview') this.renderChart(); + }, + + setTheme(t) { + this.theme = t; + localStorage.setItem('gomodel_theme', t); + this.applyTheme(); + this.renderChart(); + }, + + toggleTheme() { + const order = ['light', 'system', 'dark']; + this.setTheme(order[(order.indexOf(this.theme) + 1) % order.length]); + }, + + applyTheme() { + const root = document.documentElement; + if (this.theme === 'system') { + root.removeAttribute('data-theme'); + } else { + root.setAttribute('data-theme', this.theme); + } + }, + + cssVar(name) { + return getComputedStyle(document.documentElement).getPropertyValue(name).trim(); + }, + + chartColors() { + return { + grid: this.cssVar('--chart-grid'), + text: this.cssVar('--chart-text'), + tooltipBg: this.cssVar('--chart-tooltip-bg'), + tooltipBorder: this.cssVar('--chart-tooltip-border'), + tooltipText: this.cssVar('--chart-tooltip-text'), + }; + }, + + saveApiKey() { + if (this.apiKey) { + localStorage.setItem('gomodel_api_key', this.apiKey); + } else { + localStorage.removeItem('gomodel_api_key'); + } + }, + + headers() { + const h = { 'Content-Type': 'application/json' }; + if (this.apiKey) { + h['Authorization'] = 'Bearer ' + this.apiKey; + } + return h; + }, + + async fetchAll() { + this.loading = true; + this.authError = false; + this.needsAuth = false; + await Promise.all([this.fetchUsage(), this.fetchModels()]); + this.loading = false; + }, + + handleFetchResponse(res, label) { + if (res.status === 401) { + this.authError = true; + this.needsAuth = true; + return false; + } + if (!res.ok) { + console.error(`Failed to fetch ${label}: ${res.status} ${res.statusText}`); + return false; + } + return true; + }, + + _formatDate(date) { + return date.getFullYear() + '-' + + String(date.getMonth() + 1).padStart(2, '0') + '-' + + String(date.getDate()).padStart(2, '0'); + }, + + async fetchUsage() { + try { + var queryStr; + if (this.customStartDate && this.customEndDate) { + queryStr = 'start_date=' + this._formatDate(this.customStartDate) + + '&end_date=' + this._formatDate(this.customEndDate); + } else { + queryStr = 'days=' + this.days; + } + queryStr += '&interval=' + this.interval; + + const [summaryRes, dailyRes] = await Promise.all([ + fetch('/admin/api/v1/usage/summary?' + queryStr, { headers: this.headers() }), + fetch('/admin/api/v1/usage/daily?' + queryStr, { headers: this.headers() }) + ]); + + if (!this.handleFetchResponse(summaryRes, 'usage summary') || + !this.handleFetchResponse(dailyRes, 'usage daily')) { + return; + } + + this.summary = await summaryRes.json(); + this.daily = await dailyRes.json(); + this.renderChart(); + } catch (e) { + console.error('Failed to fetch usage:', e); + } + }, + + async fetchModels() { + try { + const res = await fetch('/admin/api/v1/models', { headers: this.headers() }); + if (!this.handleFetchResponse(res, 'models')) { + this.models = []; + return; + } + this.models = await res.json(); + } catch (e) { + console.error('Failed to fetch models:', e); + this.models = []; + } + }, + + fillMissingDays(daily) { + // For non-daily intervals, return data as-is (no gap filling) + if (this.interval !== 'daily') { + return daily; + } + + const byDate = {}; + daily.forEach(d => { byDate[d.date] = d; }); + const end = this.customEndDate ? new Date(this.customEndDate) : new Date(); + end.setHours(0, 0, 0, 0); + const start = this.customStartDate ? new Date(this.customStartDate) : new Date(end); + if (!this.customStartDate) { + start.setDate(start.getDate() - (parseInt(this.days, 10) - 1)); + } + const result = []; + for (let d = new Date(start); d <= end; d.setDate(d.getDate() + 1)) { + const key = d.getFullYear() + '-' + String(d.getMonth() + 1).padStart(2, '0') + '-' + String(d.getDate()).padStart(2, '0'); + result.push(byDate[key] || { date: key, input_tokens: 0, output_tokens: 0, total_tokens: 0, requests: 0 }); + } + return result; + }, + + renderChart() { + this.$nextTick(() => { + if (this.chart) { + this.chart.destroy(); + this.chart = null; + } + + if (this.daily.length === 0) return; + + const canvas = document.getElementById('usageChart'); + if (!canvas || canvas.offsetWidth === 0) return; + + const colors = this.chartColors(); + const filled = this.fillMissingDays(this.daily); + const labels = filled.map(d => d.date); + const inputData = filled.map(d => d.input_tokens); + const outputData = filled.map(d => d.output_tokens); + + this.chart = new Chart(canvas, { + type: 'line', + data: { + labels: labels, + datasets: [ + { + label: 'Input Tokens', + data: inputData, + borderColor: '#b8956e', + backgroundColor: 'rgba(184, 149, 110, 0.1)', + fill: true, + tension: 0.3, + pointRadius: 3, + pointHoverRadius: 5 + }, + { + label: 'Output Tokens', + data: outputData, + borderColor: '#34d399', + backgroundColor: 'rgba(52, 211, 153, 0.1)', + fill: true, + tension: 0.3, + pointRadius: 3, + pointHoverRadius: 5 + } + ] + }, + options: { + responsive: true, + maintainAspectRatio: false, + animation: { duration: 0 }, + interaction: { mode: 'index', intersect: false }, + plugins: { + legend: { + labels: { color: colors.text, font: { size: 12 } } + }, + tooltip: { + backgroundColor: colors.tooltipBg, + borderColor: colors.tooltipBorder, + borderWidth: 1, + titleColor: colors.tooltipText, + bodyColor: colors.tooltipText, + callbacks: { + label: function(c) { + return c.dataset.label + ': ' + c.parsed.y.toLocaleString(); + } + } + } + }, + scales: { + x: { + grid: { color: colors.grid }, + ticks: { color: colors.text, font: { size: 11 }, maxTicksLimit: 10 } + }, + y: { + beginAtZero: true, + grid: { color: colors.grid }, + ticks: { + color: colors.text, + font: { size: 11 }, + callback: function(value) { + if (value >= 1000000) return (value / 1000000).toFixed(1) + 'M'; + if (value >= 1000) return (value / 1000).toFixed(1) + 'K'; + return value; + } + } + } + } + } + }); + }); + }, + + get filteredModels() { + if (!this.modelFilter) return this.models; + const f = this.modelFilter.toLowerCase(); + return this.models.filter(m => + (m.model?.id ?? '').toLowerCase().includes(f) || + (m.provider_type ?? '').toLowerCase().includes(f) || + (m.model?.owned_by ?? '').toLowerCase().includes(f) + ); + }, + + formatNumber(n) { + if (n == null || n === undefined) return '-'; + return n.toLocaleString(); + } + }; +} diff --git a/internal/admin/dashboard/templates/index.html b/internal/admin/dashboard/templates/index.html new file mode 100644 index 00000000..ec9f8542 --- /dev/null +++ b/internal/admin/dashboard/templates/index.html @@ -0,0 +1,182 @@ +{{define "index"}} + +
+ + + +
+ Authentication required. Enter your API key in the sidebar to view data. +
+ + +
+
+
Total Requests
+
-
+
+
+
Input Tokens
+
-
+
+
+
Output Tokens
+
-
+
+
+
Total Tokens
+
-
+
+
+ + +
+

Daily Token Usage

+
+ +
+

No usage data yet.

+
+
+ + +
+ + + +
+ Authentication required. Enter your API key in the sidebar to view data. +
+ + +
+ +
+ + +
+ + + + + + + + + + + +
Model IDProviderOwned By
+
+

No models registered.

+
+{{end}} diff --git a/internal/admin/dashboard/templates/layout.html b/internal/admin/dashboard/templates/layout.html new file mode 100644 index 00000000..5db4346f --- /dev/null +++ b/internal/admin/dashboard/templates/layout.html @@ -0,0 +1,79 @@ +{{define "layout"}} + + + + + GOModel Dashboard + + + + + + + + + + +
+ + +
+ {{template "index" .}} +
+
+ + +{{end}} diff --git a/internal/admin/handler.go b/internal/admin/handler.go new file mode 100644 index 00000000..f33d7192 --- /dev/null +++ b/internal/admin/handler.go @@ -0,0 +1,169 @@ +// Package admin provides the admin REST API and dashboard for GOModel. +package admin + +import ( + "errors" + "net/http" + "strconv" + "time" + + "github.com/labstack/echo/v4" + + "gomodel/internal/core" + "gomodel/internal/providers" + "gomodel/internal/usage" +) + +// Handler serves admin API endpoints. +type Handler struct { + usageReader usage.UsageReader + registry *providers.ModelRegistry +} + +// NewHandler creates a new admin API handler. +// usageReader may be nil if usage tracking is not available. +func NewHandler(reader usage.UsageReader, registry *providers.ModelRegistry) *Handler { + return &Handler{ + usageReader: reader, + registry: registry, + } +} + +var validIntervals = map[string]bool{ + "daily": true, + "weekly": true, + "monthly": true, + "yearly": true, +} + +// parseUsageParams extracts UsageQueryParams from the request query string. +// Returns an error if date parameters are provided but malformed. +func parseUsageParams(c echo.Context) (usage.UsageQueryParams, error) { + var params usage.UsageQueryParams + + now := time.Now().UTC() + today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) + + startStr := c.QueryParam("start_date") + endStr := c.QueryParam("end_date") + + var startParsed, endParsed bool + + if startStr != "" { + t, err := time.Parse("2006-01-02", startStr) + if err != nil { + return params, core.NewInvalidRequestError("invalid start_date format, expected YYYY-MM-DD", nil) + } + params.StartDate = t + startParsed = true + } + + if endStr != "" { + t, err := time.Parse("2006-01-02", endStr) + if err != nil { + return params, core.NewInvalidRequestError("invalid end_date format, expected YYYY-MM-DD", nil) + } + params.EndDate = t + endParsed = true + } + + if startParsed || endParsed { + // Fill in missing side + if !startParsed { + params.StartDate = params.EndDate.AddDate(0, 0, -29) + } + if !endParsed { + params.EndDate = today + } + } else { + // Fall back to days param + days := 30 + if d := c.QueryParam("days"); d != "" { + if parsed, err := strconv.Atoi(d); err == nil && parsed > 0 { + days = parsed + } + } + params.EndDate = today + params.StartDate = today.AddDate(0, 0, -(days - 1)) + } + + // Parse interval + params.Interval = c.QueryParam("interval") + if !validIntervals[params.Interval] { + params.Interval = "daily" + } + + return params, nil +} + +// handleError converts errors to appropriate HTTP responses, matching the +// format used by the main API handlers in the server package. +func handleError(c echo.Context, err error) error { + var gatewayErr *core.GatewayError + if errors.As(err, &gatewayErr) { + return c.JSON(gatewayErr.HTTPStatusCode(), gatewayErr.ToJSON()) + } + + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ + "error": map[string]interface{}{ + "type": "internal_error", + "message": "an unexpected error occurred", + }, + }) +} + +// UsageSummary handles GET /admin/api/v1/usage/summary +func (h *Handler) UsageSummary(c echo.Context) error { + if h.usageReader == nil { + return c.JSON(http.StatusOK, usage.UsageSummary{}) + } + + params, err := parseUsageParams(c) + if err != nil { + return handleError(c, err) + } + + summary, err := h.usageReader.GetSummary(c.Request().Context(), params) + if err != nil { + return handleError(c, err) + } + + return c.JSON(http.StatusOK, summary) +} + +// DailyUsage handles GET /admin/api/v1/usage/daily +func (h *Handler) DailyUsage(c echo.Context) error { + if h.usageReader == nil { + return c.JSON(http.StatusOK, []usage.DailyUsage{}) + } + + params, err := parseUsageParams(c) + if err != nil { + return handleError(c, err) + } + + daily, err := h.usageReader.GetDailyUsage(c.Request().Context(), params) + if err != nil { + return handleError(c, err) + } + + if daily == nil { + daily = []usage.DailyUsage{} + } + + return c.JSON(http.StatusOK, daily) +} + +// ListModels handles GET /admin/api/v1/models +func (h *Handler) ListModels(c echo.Context) error { + if h.registry == nil { + return c.JSON(http.StatusOK, []providers.ModelWithProvider{}) + } + + models := h.registry.ListModelsWithProvider() + if models == nil { + models = []providers.ModelWithProvider{} + } + + return c.JSON(http.StatusOK, models) +} diff --git a/internal/admin/handler_test.go b/internal/admin/handler_test.go new file mode 100644 index 00000000..4fba5b28 --- /dev/null +++ b/internal/admin/handler_test.go @@ -0,0 +1,608 @@ +package admin + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/labstack/echo/v4" + + "gomodel/internal/core" + "gomodel/internal/providers" + "gomodel/internal/usage" +) + +// mockUsageReader implements usage.UsageReader for testing. +type mockUsageReader struct { + summary *usage.UsageSummary + daily []usage.DailyUsage + summaryErr error + dailyErr error +} + +func (m *mockUsageReader) GetSummary(_ context.Context, _ usage.UsageQueryParams) (*usage.UsageSummary, error) { + if m.summaryErr != nil { + return nil, m.summaryErr + } + return m.summary, nil +} + +func (m *mockUsageReader) GetDailyUsage(_ context.Context, _ usage.UsageQueryParams) ([]usage.DailyUsage, error) { + if m.dailyErr != nil { + return nil, m.dailyErr + } + return m.daily, nil +} + +// handlerMockProvider implements core.Provider for ListModels registry testing. +type handlerMockProvider struct { + models *core.ModelsResponse +} + +func (m *handlerMockProvider) ChatCompletion(_ context.Context, _ *core.ChatRequest) (*core.ChatResponse, error) { + return nil, nil +} +func (m *handlerMockProvider) StreamChatCompletion(_ context.Context, _ *core.ChatRequest) (io.ReadCloser, error) { + return nil, nil +} +func (m *handlerMockProvider) ListModels(_ context.Context) (*core.ModelsResponse, error) { + return m.models, nil +} +func (m *handlerMockProvider) Responses(_ context.Context, _ *core.ResponsesRequest) (*core.ResponsesResponse, error) { + return nil, nil +} +func (m *handlerMockProvider) StreamResponses(_ context.Context, _ *core.ResponsesRequest) (io.ReadCloser, error) { + return nil, nil +} + +func newHandlerContext(path string) (echo.Context, *httptest.ResponseRecorder) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, path, nil) + rec := httptest.NewRecorder() + return e.NewContext(req, rec), rec +} + +// --- UsageSummary handler tests --- + +func TestUsageSummary_NilReader(t *testing.T) { + h := NewHandler(nil, nil) + c, rec := newHandlerContext("/admin/api/v1/usage/summary") + + if err := h.UsageSummary(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + var summary usage.UsageSummary + if err := json.Unmarshal(rec.Body.Bytes(), &summary); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if summary.TotalRequests != 0 || summary.TotalInput != 0 || summary.TotalOutput != 0 || summary.TotalTokens != 0 { + t.Errorf("expected zeroed summary, got %+v", summary) + } +} + +func TestUsageSummary_Success(t *testing.T) { + reader := &mockUsageReader{ + summary: &usage.UsageSummary{ + TotalRequests: 42, + TotalInput: 1000, + TotalOutput: 500, + TotalTokens: 1500, + }, + } + h := NewHandler(reader, nil) + c, rec := newHandlerContext("/admin/api/v1/usage/summary?days=30") + + if err := h.UsageSummary(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + var summary usage.UsageSummary + if err := json.Unmarshal(rec.Body.Bytes(), &summary); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if summary.TotalRequests != 42 { + t.Errorf("expected 42 requests, got %d", summary.TotalRequests) + } + if summary.TotalTokens != 1500 { + t.Errorf("expected 1500 total tokens, got %d", summary.TotalTokens) + } +} + +func TestUsageSummary_GatewayError(t *testing.T) { + reader := &mockUsageReader{ + summaryErr: core.NewProviderError("test", http.StatusBadGateway, "upstream failed", nil), + } + h := NewHandler(reader, nil) + c, rec := newHandlerContext("/admin/api/v1/usage/summary") + + if err := h.UsageSummary(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rec.Code != http.StatusBadGateway { + t.Errorf("expected 502, got %d", rec.Code) + } + body := rec.Body.String() + if !containsString(body, "provider_error") { + t.Errorf("expected provider_error in body, got: %s", body) + } +} + +func TestUsageSummary_GenericError(t *testing.T) { + reader := &mockUsageReader{ + summaryErr: errors.New("database connection lost"), + } + h := NewHandler(reader, nil) + c, rec := newHandlerContext("/admin/api/v1/usage/summary") + + if err := h.UsageSummary(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rec.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d", rec.Code) + } + body := rec.Body.String() + if !containsString(body, "internal_error") { + t.Errorf("expected internal_error in body, got: %s", body) + } + if containsString(body, "database connection lost") { + t.Errorf("original error message should be hidden, got: %s", body) + } + if !containsString(body, "an unexpected error occurred") { + t.Errorf("expected generic message, got: %s", body) + } +} + +// --- DailyUsage handler tests --- + +func TestDailyUsage_NilReader(t *testing.T) { + h := NewHandler(nil, nil) + c, rec := newHandlerContext("/admin/api/v1/usage/daily") + + if err := h.DailyUsage(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + // Should be [] not null + if rec.Body.String() != "[]\n" { + t.Errorf("expected empty JSON array, got: %q", rec.Body.String()) + } +} + +func TestDailyUsage_Success(t *testing.T) { + reader := &mockUsageReader{ + daily: []usage.DailyUsage{ + {Date: "2026-02-01", Requests: 10, InputTokens: 100, OutputTokens: 50, TotalTokens: 150}, + {Date: "2026-02-02", Requests: 20, InputTokens: 200, OutputTokens: 100, TotalTokens: 300}, + }, + } + h := NewHandler(reader, nil) + c, rec := newHandlerContext("/admin/api/v1/usage/daily?days=7") + + if err := h.DailyUsage(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + var daily []usage.DailyUsage + if err := json.Unmarshal(rec.Body.Bytes(), &daily); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if len(daily) != 2 { + t.Errorf("expected 2 entries, got %d", len(daily)) + } +} + +func TestDailyUsage_NilResult(t *testing.T) { + reader := &mockUsageReader{ + daily: nil, // reader returns nil slice + } + h := NewHandler(reader, nil) + c, rec := newHandlerContext("/admin/api/v1/usage/daily") + + if err := h.DailyUsage(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + // Should be [] not null + if rec.Body.String() != "[]\n" { + t.Errorf("expected empty JSON array, got: %q", rec.Body.String()) + } +} + +func TestDailyUsage_Error(t *testing.T) { + reader := &mockUsageReader{ + dailyErr: core.NewRateLimitError("test", "too many requests"), + } + h := NewHandler(reader, nil) + c, rec := newHandlerContext("/admin/api/v1/usage/daily") + + if err := h.DailyUsage(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rec.Code != http.StatusTooManyRequests { + t.Errorf("expected 429, got %d", rec.Code) + } + body := rec.Body.String() + if !containsString(body, "rate_limit_error") { + t.Errorf("expected rate_limit_error in body, got: %s", body) + } +} + +// --- ListModels handler tests --- + +func TestListModels_NilRegistry(t *testing.T) { + h := NewHandler(nil, nil) + c, rec := newHandlerContext("/admin/api/v1/models") + + if err := h.ListModels(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + if rec.Body.String() != "[]\n" { + t.Errorf("expected empty JSON array, got: %q", rec.Body.String()) + } +} + +func TestListModels_WithModels(t *testing.T) { + registry := providers.NewModelRegistry() + mock := &handlerMockProvider{ + models: &core.ModelsResponse{ + Object: "list", + Data: []core.Model{ + {ID: "gpt-4", Object: "model", OwnedBy: "openai"}, + {ID: "claude-3", Object: "model", OwnedBy: "anthropic"}, + }, + }, + } + registry.RegisterProviderWithType(mock, "test") + if err := registry.Initialize(context.Background()); err != nil { + t.Fatalf("failed to initialize registry: %v", err) + } + + h := NewHandler(nil, registry) + c, rec := newHandlerContext("/admin/api/v1/models") + + if err := h.ListModels(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + + var models []providers.ModelWithProvider + if err := json.Unmarshal(rec.Body.Bytes(), &models); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if len(models) != 2 { + t.Fatalf("expected 2 models, got %d", len(models)) + } + // Should be sorted by model ID + if models[0].Model.ID != "claude-3" { + t.Errorf("expected first model to be claude-3, got %s", models[0].Model.ID) + } + if models[1].Model.ID != "gpt-4" { + t.Errorf("expected second model to be gpt-4, got %s", models[1].Model.ID) + } + if models[0].ProviderType != "test" { + t.Errorf("expected provider type 'test', got %s", models[0].ProviderType) + } +} + +func TestListModels_EmptyRegistry(t *testing.T) { + // A registry with no providers initialized — ListModelsWithProvider returns nil + registry := providers.NewModelRegistry() + + h := NewHandler(nil, registry) + c, rec := newHandlerContext("/admin/api/v1/models") + + if err := h.ListModels(c); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + if rec.Body.String() != "[]\n" { + t.Errorf("expected empty JSON array, got: %q", rec.Body.String()) + } +} + +// --- handleError tests --- + +func TestHandleError_GatewayErrors(t *testing.T) { + tests := []struct { + name string + err error + expectedStatus int + expectedType string + }{ + { + name: "provider_error → 502", + err: core.NewProviderError("test", http.StatusBadGateway, "upstream error", nil), + expectedStatus: http.StatusBadGateway, + expectedType: "provider_error", + }, + { + name: "rate_limit_error → 429", + err: core.NewRateLimitError("test", "rate limited"), + expectedStatus: http.StatusTooManyRequests, + expectedType: "rate_limit_error", + }, + { + name: "invalid_request_error → 400", + err: core.NewInvalidRequestError("bad input", nil), + expectedStatus: http.StatusBadRequest, + expectedType: "invalid_request_error", + }, + { + name: "authentication_error → 401", + err: core.NewAuthenticationError("test", "invalid key"), + expectedStatus: http.StatusUnauthorized, + expectedType: "authentication_error", + }, + { + name: "not_found_error → 404", + err: core.NewNotFoundError("model not found"), + expectedStatus: http.StatusNotFound, + expectedType: "not_found_error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c, rec := newHandlerContext("/test") + + if err := handleError(c, tt.err); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rec.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, rec.Code) + } + body := rec.Body.String() + if !containsString(body, tt.expectedType) { + t.Errorf("expected %s in body, got: %s", tt.expectedType, body) + } + }) + } +} + +func TestHandleError_UnexpectedError(t *testing.T) { + c, rec := newHandlerContext("/test") + + if err := handleError(c, errors.New("something broke")); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if rec.Code != http.StatusInternalServerError { + t.Errorf("expected 500, got %d", rec.Code) + } + body := rec.Body.String() + if !containsString(body, "an unexpected error occurred") { + t.Errorf("expected generic message, got: %s", body) + } + if containsString(body, "something broke") { + t.Errorf("original error should be hidden, got: %s", body) + } +} + +// containsString is a small helper to check substring presence. +func containsString(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && stringContains(s, substr)) +} + +func stringContains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +func newContext(query string) echo.Context { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/test?"+query, nil) + rec := httptest.NewRecorder() + return e.NewContext(req, rec) +} + +func TestParseUsageParams_DaysDefault(t *testing.T) { + c := newContext("") + params, err := parseUsageParams(c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if params.Interval != "daily" { + t.Errorf("expected interval 'daily', got %q", params.Interval) + } + + today := time.Now().UTC() + expectedEnd := time.Date(today.Year(), today.Month(), today.Day(), 0, 0, 0, 0, time.UTC) + expectedStart := expectedEnd.AddDate(0, 0, -29) + + if !params.EndDate.Equal(expectedEnd) { + t.Errorf("expected end date %v, got %v", expectedEnd, params.EndDate) + } + if !params.StartDate.Equal(expectedStart) { + t.Errorf("expected start date %v, got %v", expectedStart, params.StartDate) + } +} + +func TestParseUsageParams_DaysExplicit(t *testing.T) { + c := newContext("days=7") + params, err := parseUsageParams(c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + today := time.Now().UTC() + expectedEnd := time.Date(today.Year(), today.Month(), today.Day(), 0, 0, 0, 0, time.UTC) + expectedStart := expectedEnd.AddDate(0, 0, -6) + + if !params.StartDate.Equal(expectedStart) { + t.Errorf("expected start date %v, got %v", expectedStart, params.StartDate) + } + if !params.EndDate.Equal(expectedEnd) { + t.Errorf("expected end date %v, got %v", expectedEnd, params.EndDate) + } +} + +func TestParseUsageParams_StartAndEndDate(t *testing.T) { + c := newContext("start_date=2026-01-01&end_date=2026-01-31") + params, err := parseUsageParams(c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expectedStart := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + expectedEnd := time.Date(2026, 1, 31, 0, 0, 0, 0, time.UTC) + + if !params.StartDate.Equal(expectedStart) { + t.Errorf("expected start date %v, got %v", expectedStart, params.StartDate) + } + if !params.EndDate.Equal(expectedEnd) { + t.Errorf("expected end date %v, got %v", expectedEnd, params.EndDate) + } +} + +func TestParseUsageParams_OnlyStartDate(t *testing.T) { + c := newContext("start_date=2026-01-15") + params, err := parseUsageParams(c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expectedStart := time.Date(2026, 1, 15, 0, 0, 0, 0, time.UTC) + today := time.Now().UTC() + expectedEnd := time.Date(today.Year(), today.Month(), today.Day(), 0, 0, 0, 0, time.UTC) + + if !params.StartDate.Equal(expectedStart) { + t.Errorf("expected start date %v, got %v", expectedStart, params.StartDate) + } + if !params.EndDate.Equal(expectedEnd) { + t.Errorf("expected end date %v, got %v", expectedEnd, params.EndDate) + } +} + +func TestParseUsageParams_OnlyEndDate(t *testing.T) { + c := newContext("end_date=2026-02-10") + params, err := parseUsageParams(c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expectedEnd := time.Date(2026, 2, 10, 0, 0, 0, 0, time.UTC) + expectedStart := expectedEnd.AddDate(0, 0, -29) + + if !params.StartDate.Equal(expectedStart) { + t.Errorf("expected start date %v, got %v", expectedStart, params.StartDate) + } + if !params.EndDate.Equal(expectedEnd) { + t.Errorf("expected end date %v, got %v", expectedEnd, params.EndDate) + } +} + +func TestParseUsageParams_InvalidStartDate(t *testing.T) { + c := newContext("start_date=invalid") + _, err := parseUsageParams(c) + if err == nil { + t.Fatal("expected error for invalid start_date, got nil") + } + + var gatewayErr *core.GatewayError + if !errors.As(err, &gatewayErr) { + t.Fatalf("expected GatewayError, got %T", err) + } + if gatewayErr.HTTPStatusCode() != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", gatewayErr.HTTPStatusCode()) + } +} + +func TestParseUsageParams_InvalidEndDate(t *testing.T) { + c := newContext("start_date=2026-01-01&end_date=also-invalid") + _, err := parseUsageParams(c) + if err == nil { + t.Fatal("expected error for invalid end_date, got nil") + } + + var gatewayErr *core.GatewayError + if !errors.As(err, &gatewayErr) { + t.Fatalf("expected GatewayError, got %T", err) + } + if gatewayErr.HTTPStatusCode() != http.StatusBadRequest { + t.Errorf("expected status 400, got %d", gatewayErr.HTTPStatusCode()) + } +} + +func TestParseUsageParams_IntervalWeekly(t *testing.T) { + c := newContext("interval=weekly") + params, err := parseUsageParams(c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if params.Interval != "weekly" { + t.Errorf("expected interval 'weekly', got %q", params.Interval) + } +} + +func TestParseUsageParams_IntervalMonthly(t *testing.T) { + c := newContext("interval=monthly") + params, err := parseUsageParams(c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if params.Interval != "monthly" { + t.Errorf("expected interval 'monthly', got %q", params.Interval) + } +} + +func TestParseUsageParams_IntervalInvalid(t *testing.T) { + c := newContext("interval=hourly") + params, err := parseUsageParams(c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if params.Interval != "daily" { + t.Errorf("expected default interval 'daily', got %q", params.Interval) + } +} + +func TestParseUsageParams_IntervalEmpty(t *testing.T) { + c := newContext("") + params, err := parseUsageParams(c) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if params.Interval != "daily" { + t.Errorf("expected default interval 'daily', got %q", params.Interval) + } +} + +// Ensure usage.UsageQueryParams is the type used (compile check) +var _ = func() usage.UsageQueryParams { + return usage.UsageQueryParams{ + StartDate: time.Time{}, + EndDate: time.Time{}, + Interval: "daily", + } +} diff --git a/internal/app/app.go b/internal/app/app.go index 09e37240..4efb0a1e 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -12,11 +12,14 @@ import ( "time" "gomodel/config" + "gomodel/internal/admin" + "gomodel/internal/admin/dashboard" "gomodel/internal/auditlog" "gomodel/internal/core" "gomodel/internal/guardrails" "gomodel/internal/providers" "gomodel/internal/server" + "gomodel/internal/storage" "gomodel/internal/usage" ) @@ -134,6 +137,31 @@ func New(ctx context.Context, cfg Config) (*App, error) { UsageLogger: usageResult.Logger, LogOnlyModelInteractions: cfg.AppConfig.Logging.OnlyModelInteractions, } + + // Initialize admin API and dashboard (behind separate feature flags) + adminCfg := cfg.AppConfig.Admin + if !adminCfg.EndpointsEnabled && adminCfg.UIEnabled { + slog.Warn("ADMIN_UI_ENABLED=true requires ADMIN_ENDPOINTS_ENABLED=true — forcing UI to disabled") + adminCfg.UIEnabled = false + } + if adminCfg.EndpointsEnabled { + adminHandler, dashHandler, adminErr := initAdmin(auditResult.Storage, usageResult.Storage, providerResult.Registry, adminCfg.UIEnabled) + if adminErr != nil { + slog.Warn("failed to initialize admin", "error", adminErr) + } else { + serverCfg.AdminEndpointsEnabled = true + serverCfg.AdminHandler = adminHandler + slog.Info("admin API enabled", "api", "/admin/api/v1") + if adminCfg.UIEnabled { + serverCfg.AdminUIEnabled = true + serverCfg.DashboardHandler = dashHandler + slog.Info("admin UI enabled", "url", fmt.Sprintf("http://localhost:%s/admin/dashboard", cfg.AppConfig.Server.Port)) + } + } + } else { + slog.Info("admin API disabled") + } + app.server = server.New(provider, serverCfg) return app, nil @@ -284,6 +312,42 @@ func (a *App) logStartupInfo() { } else { slog.Info("usage tracking disabled") } + +} + +// initAdmin creates the admin API handler and optionally the dashboard handler. +// Returns nil dashboard handler if uiEnabled is false. +func initAdmin(auditStorage, usageStorage storage.Storage, registry *providers.ModelRegistry, uiEnabled bool) (*admin.Handler, *dashboard.Handler, error) { + // Find a storage connection for reading usage data + var store storage.Storage + if auditStorage != nil { + store = auditStorage + } else if usageStorage != nil { + store = usageStorage + } + + // Create usage reader (may be nil if no storage) + var reader usage.UsageReader + if store != nil { + var err error + reader, err = usage.NewReader(store) + if err != nil { + return nil, nil, fmt.Errorf("failed to create usage reader: %w", err) + } + } + + adminHandler := admin.NewHandler(reader, registry) + + var dashHandler *dashboard.Handler + if uiEnabled { + var err error + dashHandler, err = dashboard.New() + if err != nil { + return nil, nil, fmt.Errorf("failed to initialize dashboard: %w", err) + } + } + + return adminHandler, dashHandler, nil } // buildGuardrailsPipeline creates a guardrails pipeline from configuration. diff --git a/internal/providers/registry.go b/internal/providers/registry.go index 5557e63d..603f79a3 100644 --- a/internal/providers/registry.go +++ b/internal/providers/registry.go @@ -355,6 +355,32 @@ func (r *ModelRegistry) GetProviderType(model string) string { return r.providerTypes[info.Provider] } +// ModelWithProvider holds a model alongside its provider type string. +type ModelWithProvider struct { + Model core.Model `json:"model"` + ProviderType string `json:"provider_type"` +} + +// ListModelsWithProvider returns all models with their provider types, sorted by model ID. +func (r *ModelRegistry) ListModelsWithProvider() []ModelWithProvider { + r.mu.RLock() + defer r.mu.RUnlock() + + result := make([]ModelWithProvider, 0, len(r.models)) + for _, info := range r.models { + result = append(result, ModelWithProvider{ + Model: info.Model, + ProviderType: r.providerTypes[info.Provider], + }) + } + + sort.Slice(result, func(i, j int) bool { + return result[i].Model.ID < result[j].Model.ID + }) + + return result +} + // ProviderCount returns the number of registered providers func (r *ModelRegistry) ProviderCount() int { r.mu.RLock() diff --git a/internal/providers/registry_test.go b/internal/providers/registry_test.go index c3c729c4..7a229779 100644 --- a/internal/providers/registry_test.go +++ b/internal/providers/registry_test.go @@ -342,6 +342,94 @@ func TestModelRegistry(t *testing.T) { }) } +func TestListModelsWithProvider_Empty(t *testing.T) { + registry := NewModelRegistry() + models := registry.ListModelsWithProvider() + if len(models) != 0 { + t.Errorf("expected empty slice, got %d models", len(models)) + } +} + +func TestListModelsWithProvider_Sorted(t *testing.T) { + registry := NewModelRegistry() + + mock1 := ®istryMockProvider{ + name: "provider1", + modelsResponse: &core.ModelsResponse{ + Object: "list", + Data: []core.Model{ + {ID: "zebra-model", Object: "model", OwnedBy: "provider1"}, + {ID: "alpha-model", Object: "model", OwnedBy: "provider1"}, + }, + }, + } + mock2 := ®istryMockProvider{ + name: "provider2", + modelsResponse: &core.ModelsResponse{ + Object: "list", + Data: []core.Model{ + {ID: "middle-model", Object: "model", OwnedBy: "provider2"}, + }, + }, + } + registry.RegisterProviderWithType(mock1, "openai") + registry.RegisterProviderWithType(mock2, "anthropic") + _ = registry.Initialize(context.Background()) + + models := registry.ListModelsWithProvider() + if len(models) != 3 { + t.Fatalf("expected 3 models, got %d", len(models)) + } + if models[0].Model.ID != "alpha-model" { + t.Errorf("expected first model alpha-model, got %s", models[0].Model.ID) + } + if models[1].Model.ID != "middle-model" { + t.Errorf("expected second model middle-model, got %s", models[1].Model.ID) + } + if models[2].Model.ID != "zebra-model" { + t.Errorf("expected third model zebra-model, got %s", models[2].Model.ID) + } +} + +func TestListModelsWithProvider_IncludesProviderType(t *testing.T) { + registry := NewModelRegistry() + + mock1 := ®istryMockProvider{ + name: "provider1", + modelsResponse: &core.ModelsResponse{ + Object: "list", + Data: []core.Model{ + {ID: "gpt-4", Object: "model", OwnedBy: "openai"}, + }, + }, + } + mock2 := ®istryMockProvider{ + name: "provider2", + modelsResponse: &core.ModelsResponse{ + Object: "list", + Data: []core.Model{ + {ID: "claude-3", Object: "model", OwnedBy: "anthropic"}, + }, + }, + } + registry.RegisterProviderWithType(mock1, "openai") + registry.RegisterProviderWithType(mock2, "anthropic") + _ = registry.Initialize(context.Background()) + + models := registry.ListModelsWithProvider() + if len(models) != 2 { + t.Fatalf("expected 2 models, got %d", len(models)) + } + + // Models are sorted: claude-3 before gpt-4 + if models[0].ProviderType != "anthropic" { + t.Errorf("expected claude-3 provider type 'anthropic', got %q", models[0].ProviderType) + } + if models[1].ProviderType != "openai" { + t.Errorf("expected gpt-4 provider type 'openai', got %q", models[1].ProviderType) + } +} + // countingRegistryMockProvider wraps registryMockProvider and counts ListModels calls type countingRegistryMockProvider struct { *registryMockProvider diff --git a/internal/server/auth.go b/internal/server/auth.go index 3d83e4d8..d66c4e6b 100644 --- a/internal/server/auth.go +++ b/internal/server/auth.go @@ -19,10 +19,16 @@ func AuthMiddleware(masterKey string, skipPaths []string) echo.MiddlewareFunc { return next(c) } - // Check if path should skip authentication + // Check if path should skip authentication. + // Paths ending with "/*" are treated as prefix matches. requestPath := c.Request().URL.Path for _, skipPath := range skipPaths { - if requestPath == skipPath { + if strings.HasSuffix(skipPath, "/*") { + prefix := strings.TrimSuffix(skipPath, "*") + if strings.HasPrefix(requestPath, prefix) { + return next(c) + } + } else if requestPath == skipPath { return next(c) } } diff --git a/internal/server/auth_test.go b/internal/server/auth_test.go index 4d588621..0f24f4e0 100644 --- a/internal/server/auth_test.go +++ b/internal/server/auth_test.go @@ -210,6 +210,68 @@ func TestAuthMiddleware_SkipPaths(t *testing.T) { }) } +func TestAuthMiddleware_WildcardSkipPaths(t *testing.T) { + skipPaths := []string{"/admin/dashboard", "/admin/dashboard/*", "/admin/static/*"} + + tests := []struct { + name string + path string + wantSkip bool + }{ + { + name: "exact match /admin/dashboard", + path: "/admin/dashboard", + wantSkip: true, + }, + { + name: "wildcard match /admin/dashboard/overview", + path: "/admin/dashboard/overview", + wantSkip: true, + }, + { + name: "wildcard match /admin/dashboard/deep/nested", + path: "/admin/dashboard/deep/nested", + wantSkip: true, + }, + { + name: "wildcard match /admin/static/css/dashboard.css", + path: "/admin/static/css/dashboard.css", + wantSkip: true, + }, + { + name: "no match /admin/api/v1/models", + path: "/admin/api/v1/models", + wantSkip: false, + }, + { + name: "no match /admin/dashboardx (not prefix of /admin/dashboard/)", + path: "/admin/dashboardx", + wantSkip: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := echo.New() + e.Use(AuthMiddleware("secret-key", skipPaths)) + + e.GET(tt.path, func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) + + req := httptest.NewRequest(http.MethodGet, tt.path, nil) + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + if tt.wantSkip { + assert.Equal(t, http.StatusOK, rec.Code, "expected path %s to skip auth", tt.path) + } else { + assert.Equal(t, http.StatusUnauthorized, rec.Code, "expected path %s to require auth", tt.path) + } + }) + } +} + func TestAuthMiddleware_ConstantTimeComparison(t *testing.T) { t.Run("constant-time comparison prevents timing attacks", func(t *testing.T) { // Test that the constant-time comparison works correctly for various inputs diff --git a/internal/server/http.go b/internal/server/http.go index 5cdeb5c9..ace9c7a1 100644 --- a/internal/server/http.go +++ b/internal/server/http.go @@ -11,6 +11,8 @@ import ( "github.com/labstack/echo/v4/middleware" "github.com/prometheus/client_golang/prometheus/promhttp" + "gomodel/internal/admin" + "gomodel/internal/admin/dashboard" "gomodel/internal/auditlog" "gomodel/internal/core" "gomodel/internal/usage" @@ -31,6 +33,10 @@ type Config struct { AuditLogger auditlog.LoggerInterface // Optional: Audit logger for request/response logging UsageLogger usage.LoggerInterface // Optional: Usage logger for token tracking LogOnlyModelInteractions bool // Only log AI model endpoints (default: true) + AdminEndpointsEnabled bool // Whether admin API endpoints are enabled + AdminUIEnabled bool // Whether admin dashboard UI is enabled + AdminHandler *admin.Handler // Admin API handler (nil if disabled) + DashboardHandler *dashboard.Handler // Dashboard UI handler (nil if disabled) } // New creates a new HTTP server @@ -68,6 +74,11 @@ func New(provider core.RoutableProvider, cfg *Config) *Server { authSkipPaths = append(authSkipPaths, metricsPath) } + // Admin dashboard pages and static assets skip auth (/* enables prefix matching) + if cfg != nil && cfg.AdminUIEnabled && cfg.DashboardHandler != nil { + authSkipPaths = append(authSkipPaths, "/admin/dashboard", "/admin/dashboard/*", "/admin/static/*") + } + // Global middleware stack (order matters) // Request logger with optional filtering for model-only interactions if cfg != nil && cfg.LogOnlyModelInteractions { @@ -133,6 +144,21 @@ func New(provider core.RoutableProvider, cfg *Config) *Server { e.POST("/v1/chat/completions", handler.ChatCompletion) e.POST("/v1/responses", handler.Responses) + // Admin API routes (behind ADMIN_ENDPOINTS_ENABLED flag) + if cfg != nil && cfg.AdminEndpointsEnabled && cfg.AdminHandler != nil { + adminAPI := e.Group("/admin/api/v1") + adminAPI.GET("/usage/summary", cfg.AdminHandler.UsageSummary) + adminAPI.GET("/usage/daily", cfg.AdminHandler.DailyUsage) + adminAPI.GET("/models", cfg.AdminHandler.ListModels) + } + + // Admin dashboard UI routes (behind ADMIN_UI_ENABLED flag) + if cfg != nil && cfg.AdminUIEnabled && cfg.DashboardHandler != nil { + e.GET("/admin/dashboard", cfg.DashboardHandler.Index) + e.GET("/admin/dashboard/*", cfg.DashboardHandler.Index) + e.GET("/admin/static/*", cfg.DashboardHandler.Static) + } + return &Server{ echo: e, handler: handler, diff --git a/internal/server/http_test.go b/internal/server/http_test.go index b4ca1cc7..5a7142ea 100644 --- a/internal/server/http_test.go +++ b/internal/server/http_test.go @@ -5,6 +5,9 @@ import ( "net/http/httptest" "strings" "testing" + + "gomodel/internal/admin" + "gomodel/internal/admin/dashboard" ) func TestMetricsEndpoint(t *testing.T) { @@ -211,6 +214,151 @@ func TestServerWithMasterKeyAndMetrics(t *testing.T) { }) } +func newDashboardHandler(t *testing.T) *dashboard.Handler { + t.Helper() + h, err := dashboard.New() + if err != nil { + t.Fatalf("failed to create dashboard handler: %v", err) + } + return h +} + +func TestAdminEndpoints_Enabled(t *testing.T) { + mock := &mockProvider{} + adminHandler := admin.NewHandler(nil, nil) + srv := New(mock, &Config{ + AdminEndpointsEnabled: true, + AdminHandler: adminHandler, + }) + + req := httptest.NewRequest(http.MethodGet, "/admin/api/v1/models", nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } +} + +func TestAdminEndpoints_Disabled(t *testing.T) { + mock := &mockProvider{} + srv := New(mock, &Config{ + AdminEndpointsEnabled: false, + }) + + req := httptest.NewRequest(http.MethodGet, "/admin/api/v1/models", nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d", rec.Code) + } +} + +func TestAdminUI_Enabled(t *testing.T) { + mock := &mockProvider{} + dashHandler := newDashboardHandler(t) + adminHandler := admin.NewHandler(nil, nil) + srv := New(mock, &Config{ + AdminEndpointsEnabled: true, + AdminUIEnabled: true, + AdminHandler: adminHandler, + DashboardHandler: dashHandler, + }) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard", nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rec.Code) + } + contentType := rec.Header().Get("Content-Type") + if !strings.Contains(contentType, "text/html") { + t.Errorf("expected text/html Content-Type, got %s", contentType) + } +} + +func TestAdminUI_Disabled(t *testing.T) { + mock := &mockProvider{} + srv := New(mock, &Config{ + AdminEndpointsEnabled: true, + AdminUIEnabled: false, + AdminHandler: admin.NewHandler(nil, nil), + }) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard", nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d", rec.Code) + } +} + +func TestAdminDashboard_SkipsAuth(t *testing.T) { + mock := &mockProvider{} + dashHandler := newDashboardHandler(t) + adminHandler := admin.NewHandler(nil, nil) + srv := New(mock, &Config{ + MasterKey: "test-secret-key", + AdminEndpointsEnabled: true, + AdminUIEnabled: true, + AdminHandler: adminHandler, + DashboardHandler: dashHandler, + }) + + // Dashboard should be accessible without auth + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard", nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected 200 (no auth), got %d", rec.Code) + } +} + +func TestAdminAPI_RequiresAuth(t *testing.T) { + mock := &mockProvider{} + adminHandler := admin.NewHandler(nil, nil) + srv := New(mock, &Config{ + MasterKey: "test-secret-key", + AdminEndpointsEnabled: true, + AdminHandler: adminHandler, + }) + + // Admin API should require auth + req := httptest.NewRequest(http.MethodGet, "/admin/api/v1/models", nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", rec.Code) + } +} + +func TestAdminStaticAssets_SkipAuth(t *testing.T) { + mock := &mockProvider{} + dashHandler := newDashboardHandler(t) + adminHandler := admin.NewHandler(nil, nil) + srv := New(mock, &Config{ + MasterKey: "test-secret-key", + AdminEndpointsEnabled: true, + AdminUIEnabled: true, + AdminHandler: adminHandler, + DashboardHandler: dashHandler, + }) + + // Static assets should be accessible without auth + req := httptest.NewRequest(http.MethodGet, "/admin/static/css/dashboard.css", nil) + rec := httptest.NewRecorder() + srv.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected 200 for static asset without auth, got %d", rec.Code) + } +} + func TestHealthEndpointAlwaysAvailable(t *testing.T) { tests := []struct { name string diff --git a/internal/usage/factory.go b/internal/usage/factory.go index 5b28097a..deda7a39 100644 --- a/internal/usage/factory.go +++ b/internal/usage/factory.go @@ -110,6 +110,44 @@ func NewWithSharedStorage(ctx context.Context, cfg *config.Config, store storage }, nil } +// NewReader creates a UsageReader from a storage backend. +// Returns nil if the storage is nil (usage data not available). +func NewReader(store storage.Storage) (UsageReader, error) { + if store == nil { + return nil, nil + } + + switch store.Type() { + case storage.TypeSQLite: + return NewSQLiteReader(store.SQLiteDB()) + + case storage.TypePostgreSQL: + pool := store.PostgreSQLPool() + if pool == nil { + return nil, fmt.Errorf("PostgreSQL pool is nil") + } + pgxPool, ok := pool.(*pgxpool.Pool) + if !ok { + return nil, fmt.Errorf("invalid PostgreSQL pool type: %T", pool) + } + return NewPostgreSQLReader(pgxPool) + + case storage.TypeMongoDB: + db := store.MongoDatabase() + if db == nil { + return nil, fmt.Errorf("MongoDB database is nil") + } + mongoDB, ok := db.(*mongo.Database) + if !ok { + return nil, fmt.Errorf("invalid MongoDB database type: %T", db) + } + return NewMongoDBReader(mongoDB) + + default: + return nil, fmt.Errorf("unknown storage type: %s", store.Type()) + } +} + // buildStorageConfig creates a storage.Config from the application config. func buildStorageConfig(cfg *config.Config) storage.Config { storageCfg := storage.Config{ diff --git a/internal/usage/reader.go b/internal/usage/reader.go new file mode 100644 index 00000000..17e242f0 --- /dev/null +++ b/internal/usage/reader.go @@ -0,0 +1,43 @@ +package usage + +import ( + "context" + "time" +) + +// UsageQueryParams specifies the query parameters for usage data retrieval. +type UsageQueryParams struct { + StartDate time.Time // Inclusive start (day precision) + EndDate time.Time // Inclusive end (day precision) + Interval string // "daily", "weekly", "monthly", "yearly" +} + +// UsageSummary holds aggregated usage statistics over a time period. +type UsageSummary struct { + TotalRequests int `json:"total_requests"` + TotalInput int64 `json:"total_input_tokens"` + TotalOutput int64 `json:"total_output_tokens"` + TotalTokens int64 `json:"total_tokens"` +} + +// DailyUsage holds usage statistics for a single period. +// Date holds the period label: YYYY-MM-DD for daily, YYYY-Www for weekly, +// YYYY-MM for monthly, or YYYY for yearly intervals. +type DailyUsage struct { + Date string `json:"date"` + Requests int `json:"requests"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + TotalTokens int64 `json:"total_tokens"` +} + +// UsageReader provides read access to usage data for the admin API. +type UsageReader interface { + // GetSummary returns aggregated usage statistics for the given date range. + // If both StartDate and EndDate are zero, returns all-time statistics. + GetSummary(ctx context.Context, params UsageQueryParams) (*UsageSummary, error) + + // GetDailyUsage returns usage statistics grouped by the specified interval. + // If both StartDate and EndDate are zero, returns all available data. + GetDailyUsage(ctx context.Context, params UsageQueryParams) ([]DailyUsage, error) +} diff --git a/internal/usage/reader_mongodb.go b/internal/usage/reader_mongodb.go new file mode 100644 index 00000000..541e4a11 --- /dev/null +++ b/internal/usage/reader_mongodb.go @@ -0,0 +1,174 @@ +package usage + +import ( + "context" + "fmt" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/mongo" +) + +// MongoDBReader implements UsageReader for MongoDB. +type MongoDBReader struct { + collection *mongo.Collection +} + +// NewMongoDBReader creates a new MongoDB usage reader. +func NewMongoDBReader(database *mongo.Database) (*MongoDBReader, error) { + if database == nil { + return nil, fmt.Errorf("database is required") + } + return &MongoDBReader{collection: database.Collection("usage")}, nil +} + +func (r *MongoDBReader) GetSummary(ctx context.Context, params UsageQueryParams) (*UsageSummary, error) { + pipeline := bson.A{} + + startZero := params.StartDate.IsZero() + endZero := params.EndDate.IsZero() + + if !startZero && !endZero { + pipeline = append(pipeline, bson.D{{Key: "$match", Value: bson.D{ + {Key: "timestamp", Value: bson.D{ + {Key: "$gte", Value: params.StartDate.UTC()}, + {Key: "$lt", Value: params.EndDate.AddDate(0, 0, 1).UTC()}, + }}, + }}}) + } else if !startZero { + pipeline = append(pipeline, bson.D{{Key: "$match", Value: bson.D{ + {Key: "timestamp", Value: bson.D{{Key: "$gte", Value: params.StartDate.UTC()}}}, + }}}) + } else if !endZero { + pipeline = append(pipeline, bson.D{{Key: "$match", Value: bson.D{ + {Key: "timestamp", Value: bson.D{{Key: "$lt", Value: params.EndDate.AddDate(0, 0, 1).UTC()}}}, + }}}) + } + + pipeline = append(pipeline, bson.D{{Key: "$group", Value: bson.D{ + {Key: "_id", Value: nil}, + {Key: "total_requests", Value: bson.D{{Key: "$sum", Value: 1}}}, + {Key: "total_input", Value: bson.D{{Key: "$sum", Value: "$input_tokens"}}}, + {Key: "total_output", Value: bson.D{{Key: "$sum", Value: "$output_tokens"}}}, + {Key: "total_tokens", Value: bson.D{{Key: "$sum", Value: "$total_tokens"}}}, + }}}) + + cursor, err := r.collection.Aggregate(ctx, pipeline) + if err != nil { + return nil, fmt.Errorf("failed to aggregate usage summary: %w", err) + } + defer cursor.Close(ctx) + + summary := &UsageSummary{} + if cursor.Next(ctx) { + var result struct { + TotalRequests int `bson:"total_requests"` + TotalInput int64 `bson:"total_input"` + TotalOutput int64 `bson:"total_output"` + TotalTokens int64 `bson:"total_tokens"` + } + if err := cursor.Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode usage summary: %w", err) + } + summary.TotalRequests = result.TotalRequests + summary.TotalInput = result.TotalInput + summary.TotalOutput = result.TotalOutput + summary.TotalTokens = result.TotalTokens + } + + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("error iterating usage summary cursor: %w", err) + } + + return summary, nil +} + +func mongoDateFormat(interval string) string { + switch interval { + case "weekly": + return "%G-W%V" + case "monthly": + return "%Y-%m" + case "yearly": + return "%Y" + default: + return "%Y-%m-%d" + } +} + +func (r *MongoDBReader) GetDailyUsage(ctx context.Context, params UsageQueryParams) ([]DailyUsage, error) { + interval := params.Interval + if interval == "" { + interval = "daily" + } + + pipeline := bson.A{} + + startZero := params.StartDate.IsZero() + endZero := params.EndDate.IsZero() + + if !startZero && !endZero { + pipeline = append(pipeline, bson.D{{Key: "$match", Value: bson.D{ + {Key: "timestamp", Value: bson.D{ + {Key: "$gte", Value: params.StartDate.UTC()}, + {Key: "$lt", Value: params.EndDate.AddDate(0, 0, 1).UTC()}, + }}, + }}}) + } else if !startZero { + pipeline = append(pipeline, bson.D{{Key: "$match", Value: bson.D{ + {Key: "timestamp", Value: bson.D{{Key: "$gte", Value: params.StartDate.UTC()}}}, + }}}) + } else if !endZero { + pipeline = append(pipeline, bson.D{{Key: "$match", Value: bson.D{ + {Key: "timestamp", Value: bson.D{{Key: "$lt", Value: params.EndDate.AddDate(0, 0, 1).UTC()}}}, + }}}) + } + + dateFormat := mongoDateFormat(interval) + + pipeline = append(pipeline, + bson.D{{Key: "$group", Value: bson.D{ + {Key: "_id", Value: bson.D{{Key: "$dateToString", Value: bson.D{ + {Key: "format", Value: dateFormat}, + {Key: "date", Value: "$timestamp"}, + }}}}, + {Key: "requests", Value: bson.D{{Key: "$sum", Value: 1}}}, + {Key: "input_tokens", Value: bson.D{{Key: "$sum", Value: "$input_tokens"}}}, + {Key: "output_tokens", Value: bson.D{{Key: "$sum", Value: "$output_tokens"}}}, + {Key: "total_tokens", Value: bson.D{{Key: "$sum", Value: "$total_tokens"}}}, + }}}, + bson.D{{Key: "$sort", Value: bson.D{{Key: "_id", Value: 1}}}}, + ) + + cursor, err := r.collection.Aggregate(ctx, pipeline) + if err != nil { + return nil, fmt.Errorf("failed to aggregate daily usage: %w", err) + } + defer cursor.Close(ctx) + + result := make([]DailyUsage, 0) + for cursor.Next(ctx) { + var row struct { + Date string `bson:"_id"` + Requests int `bson:"requests"` + InputTokens int64 `bson:"input_tokens"` + OutputTokens int64 `bson:"output_tokens"` + TotalTokens int64 `bson:"total_tokens"` + } + if err := cursor.Decode(&row); err != nil { + return nil, fmt.Errorf("failed to decode daily usage row: %w", err) + } + result = append(result, DailyUsage{ + Date: row.Date, + Requests: row.Requests, + InputTokens: row.InputTokens, + OutputTokens: row.OutputTokens, + TotalTokens: row.TotalTokens, + }) + } + + if err := cursor.Err(); err != nil { + return nil, fmt.Errorf("error iterating daily usage cursor: %w", err) + } + + return result, nil +} diff --git a/internal/usage/reader_postgresql.go b/internal/usage/reader_postgresql.go new file mode 100644 index 00000000..f82c4a28 --- /dev/null +++ b/internal/usage/reader_postgresql.go @@ -0,0 +1,111 @@ +package usage + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5/pgxpool" +) + +// PostgreSQLReader implements UsageReader for PostgreSQL databases. +type PostgreSQLReader struct { + pool *pgxpool.Pool +} + +// NewPostgreSQLReader creates a new PostgreSQL usage reader. +func NewPostgreSQLReader(pool *pgxpool.Pool) (*PostgreSQLReader, error) { + if pool == nil { + return nil, fmt.Errorf("connection pool is required") + } + return &PostgreSQLReader{pool: pool}, nil +} + +func (r *PostgreSQLReader) GetSummary(ctx context.Context, params UsageQueryParams) (*UsageSummary, error) { + var query string + var args []interface{} + + startZero := params.StartDate.IsZero() + endZero := params.EndDate.IsZero() + + if !startZero && !endZero { + query = `SELECT COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0), COALESCE(SUM(total_tokens), 0) + FROM "usage" WHERE timestamp >= $1 AND timestamp < $2` + args = append(args, params.StartDate.UTC(), params.EndDate.AddDate(0, 0, 1).UTC()) + } else if !startZero { + query = `SELECT COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0), COALESCE(SUM(total_tokens), 0) + FROM "usage" WHERE timestamp >= $1` + args = append(args, params.StartDate.UTC()) + } else { + query = `SELECT COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0), COALESCE(SUM(total_tokens), 0) + FROM "usage"` + } + + summary := &UsageSummary{} + err := r.pool.QueryRow(ctx, query, args...).Scan( + &summary.TotalRequests, &summary.TotalInput, &summary.TotalOutput, &summary.TotalTokens, + ) + if err != nil { + return nil, fmt.Errorf("failed to query usage summary: %w", err) + } + + return summary, nil +} + +func pgGroupExpr(interval string) string { + switch interval { + case "weekly": + return `to_char(DATE_TRUNC('week', timestamp AT TIME ZONE 'UTC'), 'IYYY-"W"IW')` + case "monthly": + return `to_char(DATE_TRUNC('month', timestamp AT TIME ZONE 'UTC'), 'YYYY-MM')` + case "yearly": + return `to_char(DATE_TRUNC('year', timestamp AT TIME ZONE 'UTC'), 'YYYY')` + default: + return `to_char(DATE(timestamp AT TIME ZONE 'UTC'), 'YYYY-MM-DD')` + } +} + +func (r *PostgreSQLReader) GetDailyUsage(ctx context.Context, params UsageQueryParams) ([]DailyUsage, error) { + interval := params.Interval + if interval == "" { + interval = "daily" + } + groupExpr := pgGroupExpr(interval) + + var where string + var args []interface{} + + startZero := params.StartDate.IsZero() + endZero := params.EndDate.IsZero() + + if !startZero && !endZero { + where = ` WHERE timestamp >= $1 AND timestamp < $2` + args = append(args, params.StartDate.UTC(), params.EndDate.AddDate(0, 0, 1).UTC()) + } else if !startZero { + where = ` WHERE timestamp >= $1` + args = append(args, params.StartDate.UTC()) + } + + query := fmt.Sprintf(`SELECT %s as period, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0), COALESCE(SUM(total_tokens), 0) + FROM "usage"%s GROUP BY %s ORDER BY period`, groupExpr, where, groupExpr) + + rows, err := r.pool.Query(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("failed to query daily usage: %w", err) + } + defer rows.Close() + + result := make([]DailyUsage, 0) + for rows.Next() { + var d DailyUsage + if err := rows.Scan(&d.Date, &d.Requests, &d.InputTokens, &d.OutputTokens, &d.TotalTokens); err != nil { + return nil, fmt.Errorf("failed to scan daily usage row: %w", err) + } + result = append(result, d) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating daily usage rows: %w", err) + } + + return result, nil +} diff --git a/internal/usage/reader_sqlite.go b/internal/usage/reader_sqlite.go new file mode 100644 index 00000000..24591591 --- /dev/null +++ b/internal/usage/reader_sqlite.go @@ -0,0 +1,110 @@ +package usage + +import ( + "context" + "database/sql" + "fmt" +) + +// SQLiteReader implements UsageReader for SQLite databases. +type SQLiteReader struct { + db *sql.DB +} + +// NewSQLiteReader creates a new SQLite usage reader. +func NewSQLiteReader(db *sql.DB) (*SQLiteReader, error) { + if db == nil { + return nil, fmt.Errorf("database connection is required") + } + return &SQLiteReader{db: db}, nil +} + +func (r *SQLiteReader) GetSummary(ctx context.Context, params UsageQueryParams) (*UsageSummary, error) { + var query string + var args []interface{} + + startZero := params.StartDate.IsZero() + endZero := params.EndDate.IsZero() + + if !startZero && !endZero { + query = `SELECT COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0), COALESCE(SUM(total_tokens), 0) + FROM usage WHERE timestamp >= ? AND timestamp < ?` + args = append(args, params.StartDate.UTC().Format("2006-01-02"), params.EndDate.AddDate(0, 0, 1).UTC().Format("2006-01-02")) + } else if !startZero { + query = `SELECT COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0), COALESCE(SUM(total_tokens), 0) + FROM usage WHERE timestamp >= ?` + args = append(args, params.StartDate.UTC().Format("2006-01-02")) + } else { + query = `SELECT COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0), COALESCE(SUM(total_tokens), 0) + FROM usage` + } + + summary := &UsageSummary{} + err := r.db.QueryRowContext(ctx, query, args...).Scan( + &summary.TotalRequests, &summary.TotalInput, &summary.TotalOutput, &summary.TotalTokens, + ) + if err != nil { + return nil, fmt.Errorf("failed to query usage summary: %w", err) + } + + return summary, nil +} + +func sqliteGroupExpr(interval string) string { + switch interval { + case "weekly": + return `strftime('%G-W%V', timestamp)` + case "monthly": + return `strftime('%Y-%m', timestamp)` + case "yearly": + return `strftime('%Y', timestamp)` + default: + return `DATE(timestamp)` + } +} + +func (r *SQLiteReader) GetDailyUsage(ctx context.Context, params UsageQueryParams) ([]DailyUsage, error) { + interval := params.Interval + if interval == "" { + interval = "daily" + } + groupExpr := sqliteGroupExpr(interval) + + var where string + var args []interface{} + + startZero := params.StartDate.IsZero() + endZero := params.EndDate.IsZero() + + if !startZero && !endZero { + where = ` WHERE timestamp >= ? AND timestamp < ?` + args = append(args, params.StartDate.UTC().Format("2006-01-02"), params.EndDate.AddDate(0, 0, 1).UTC().Format("2006-01-02")) + } else if !startZero { + where = ` WHERE timestamp >= ?` + args = append(args, params.StartDate.UTC().Format("2006-01-02")) + } + + query := fmt.Sprintf(`SELECT %s as period, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0), COALESCE(SUM(total_tokens), 0) + FROM usage%s GROUP BY %s ORDER BY period`, groupExpr, where, groupExpr) + + rows, err := r.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("failed to query daily usage: %w", err) + } + defer rows.Close() + + result := make([]DailyUsage, 0) + for rows.Next() { + var d DailyUsage + if err := rows.Scan(&d.Date, &d.Requests, &d.InputTokens, &d.OutputTokens, &d.TotalTokens); err != nil { + return nil, fmt.Errorf("failed to scan daily usage row: %w", err) + } + result = append(result, d) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating daily usage rows: %w", err) + } + + return result, nil +} diff --git a/tests/e2e/admin_test.go b/tests/e2e/admin_test.go new file mode 100644 index 00000000..68eafecf --- /dev/null +++ b/tests/e2e/admin_test.go @@ -0,0 +1,265 @@ +//go:build e2e + +package e2e + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gomodel/internal/admin" + "gomodel/internal/admin/dashboard" + "gomodel/internal/providers" + "gomodel/internal/server" + "gomodel/internal/usage" +) + +// setupAdminServer creates a new server instance with admin features configured. +func setupAdminServer(t *testing.T, masterKey string, endpointsEnabled, uiEnabled bool) *httptest.Server { + t.Helper() + + // Create test provider using the shared TestProvider + testProvider := NewTestProvider(mockLLMURL, "sk-test-key-12345") + + // Create registry and register mock provider with type + registry := providers.NewModelRegistry() + registry.RegisterProviderWithType(testProvider, "test") + + // Initialize registry synchronously for tests + if err := registry.Initialize(context.Background()); err != nil { + t.Fatalf("Failed to initialize registry: %v", err) + } + + // Create router + router, err := providers.NewRouter(registry) + if err != nil { + t.Fatalf("Failed to create router: %v", err) + } + + // Build server config + cfg := &server.Config{ + MasterKey: masterKey, + AdminEndpointsEnabled: endpointsEnabled, + } + + if endpointsEnabled { + cfg.AdminHandler = admin.NewHandler(nil, registry) + } + + if uiEnabled { + cfg.AdminUIEnabled = true + dashHandler, dashErr := dashboard.New() + if dashErr != nil { + t.Fatalf("Failed to create dashboard handler: %v", dashErr) + } + cfg.DashboardHandler = dashHandler + } + + srv := server.New(router, cfg) + return httptest.NewServer(srv) +} + +func TestAdminAPI_EndpointsEnabled_E2E(t *testing.T) { + ts := setupAdminServer(t, "", true, false) + defer ts.Close() + + endpoints := []string{ + "/admin/api/v1/usage/summary", + "/admin/api/v1/usage/daily", + "/admin/api/v1/models", + } + + for _, ep := range endpoints { + t.Run(ep, func(t *testing.T) { + resp, err := http.Get(ts.URL + ep) + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode, "endpoint %s should return 200", ep) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + // Should be valid JSON + assert.True(t, json.Valid(body), "response should be valid JSON for %s, got: %s", ep, string(body)) + }) + } +} + +func TestAdminAPI_EndpointsDisabled_E2E(t *testing.T) { + ts := setupAdminServer(t, "", false, false) + defer ts.Close() + + endpoints := []string{ + "/admin/api/v1/usage/summary", + "/admin/api/v1/usage/daily", + "/admin/api/v1/models", + } + + for _, ep := range endpoints { + t.Run(ep, func(t *testing.T) { + resp, err := http.Get(ts.URL + ep) + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusNotFound, resp.StatusCode, "endpoint %s should return 404 when disabled", ep) + }) + } +} + +func TestAdminAPI_RequiresAuth_E2E(t *testing.T) { + ts := setupAdminServer(t, testMasterKey, true, false) + defer ts.Close() + + t.Run("without auth returns 401", func(t *testing.T) { + resp, err := http.Get(ts.URL + "/admin/api/v1/models") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + }) + + t.Run("with valid auth returns 200", func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, ts.URL+"/admin/api/v1/models", nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+testMasterKey) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) +} + +func TestAdminDashboard_Enabled_E2E(t *testing.T) { + ts := setupAdminServer(t, "", true, true) + defer ts.Close() + + t.Run("dashboard returns 200 HTML", func(t *testing.T) { + resp, err := http.Get(ts.URL + "/admin/dashboard") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Contains(t, resp.Header.Get("Content-Type"), "text/html") + }) + + t.Run("static CSS returns 200", func(t *testing.T) { + resp, err := http.Get(ts.URL + "/admin/static/css/dashboard.css") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) +} + +func TestAdminDashboard_Disabled_E2E(t *testing.T) { + ts := setupAdminServer(t, "", true, false) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/admin/dashboard") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusNotFound, resp.StatusCode) +} + +func TestAdminDashboard_SkipsAuth_E2E(t *testing.T) { + ts := setupAdminServer(t, testMasterKey, true, true) + defer ts.Close() + + t.Run("dashboard is public (200 without auth)", func(t *testing.T) { + resp, err := http.Get(ts.URL + "/admin/dashboard") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) + + t.Run("API is protected (401 without auth)", func(t *testing.T) { + resp, err := http.Get(ts.URL + "/admin/api/v1/models") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode) + }) +} + +func TestAdminAPI_ModelsEndpoint_E2E(t *testing.T) { + ts := setupAdminServer(t, "", true, false) + defer ts.Close() + + resp, err := http.Get(ts.URL + "/admin/api/v1/models") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var models []providers.ModelWithProvider + require.NoError(t, json.Unmarshal(body, &models)) + + // TestProvider returns 3 models + assert.Len(t, models, 3) + + // Should be sorted by model ID + for i := 1; i < len(models); i++ { + assert.True(t, models[i-1].Model.ID < models[i].Model.ID, + "models should be sorted, but %s >= %s", models[i-1].Model.ID, models[i].Model.ID) + } + + // Each model should have provider_type + for _, m := range models { + assert.Equal(t, "test", m.ProviderType, "model %s should have provider_type 'test'", m.Model.ID) + } +} + +func TestAdminAPI_UsageEndpoints_E2E(t *testing.T) { + ts := setupAdminServer(t, "", true, false) + defer ts.Close() + + t.Run("summary returns zeroed object (nil reader)", func(t *testing.T) { + resp, err := http.Get(ts.URL + "/admin/api/v1/usage/summary") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + var summary usage.UsageSummary + require.NoError(t, json.NewDecoder(resp.Body).Decode(&summary)) + assert.Equal(t, 0, summary.TotalRequests) + assert.Equal(t, int64(0), summary.TotalTokens) + }) + + t.Run("daily returns empty array (nil reader)", func(t *testing.T) { + resp, err := http.Get(ts.URL + "/admin/api/v1/usage/daily") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var daily []usage.DailyUsage + require.NoError(t, json.Unmarshal(body, &daily)) + assert.Empty(t, daily) + }) + + t.Run("query params accepted", func(t *testing.T) { + resp, err := http.Get(ts.URL + "/admin/api/v1/usage/daily?days=7&interval=weekly") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + }) +} diff --git a/tests/integration/admin_test.go b/tests/integration/admin_test.go new file mode 100644 index 00000000..f53f659c --- /dev/null +++ b/tests/integration/admin_test.go @@ -0,0 +1,233 @@ +//go:build integration + +package integration + +import ( + "encoding/json" + "io" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gomodel/internal/providers" + "gomodel/internal/usage" + "gomodel/tests/integration/dbassert" +) + +func TestAdminUsageSummary_PostgreSQL(t *testing.T) { + fixture := SetupTestServer(t, TestServerConfig{ + DBType: "postgresql", + UsageEnabled: true, + AdminEndpointsEnabled: true, + OnlyModelInteractions: false, + }) + + // Clear existing usage entries + dbassert.ClearUsage(t, fixture.PgPool) + + // Send 2 chat requests + for i := 0; i < 2; i++ { + payload := newChatRequest("gpt-4", "Hello!") + resp := sendChatRequest(t, fixture.ServerURL, payload) + require.Equal(t, 200, resp.StatusCode) + closeBody(resp) + } + + // Wait for usage buffer to flush (flush interval is 1s in tests) + time.Sleep(2 * time.Second) + + // Query admin API + resp, err := http.Get(fixture.ServerURL + "/admin/api/v1/usage/summary?days=30") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var summary usage.UsageSummary + require.NoError(t, json.Unmarshal(body, &summary)) + + assert.Equal(t, 2, summary.TotalRequests, "expected 2 total requests") + assert.Equal(t, int64(20), summary.TotalInput, "expected 20 input tokens (2 * 10)") + assert.Equal(t, int64(16), summary.TotalOutput, "expected 16 output tokens (2 * 8)") + assert.Equal(t, int64(36), summary.TotalTokens, "expected 36 total tokens (2 * 18)") + + fixture.FlushAndClose(t) +} + +func TestAdminDailyUsage_PostgreSQL(t *testing.T) { + fixture := SetupTestServer(t, TestServerConfig{ + DBType: "postgresql", + UsageEnabled: true, + AdminEndpointsEnabled: true, + OnlyModelInteractions: false, + }) + + // Clear existing usage entries + dbassert.ClearUsage(t, fixture.PgPool) + + // Send requests + for i := 0; i < 2; i++ { + payload := newChatRequest("gpt-4", "Hello!") + resp := sendChatRequest(t, fixture.ServerURL, payload) + require.Equal(t, 200, resp.StatusCode) + closeBody(resp) + } + + // Wait for usage buffer to flush + time.Sleep(2 * time.Second) + + // Query admin API + resp, err := http.Get(fixture.ServerURL + "/admin/api/v1/usage/daily?days=30") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var daily []usage.DailyUsage + require.NoError(t, json.Unmarshal(body, &daily)) + + require.NotEmpty(t, daily, "expected at least one daily entry") + + // Find today's entry + today := time.Now().UTC().Format("2006-01-02") + var todayEntry *usage.DailyUsage + for i := range daily { + if daily[i].Date == today { + todayEntry = &daily[i] + break + } + } + require.NotNil(t, todayEntry, "expected entry for today %s", today) + assert.Equal(t, 2, todayEntry.Requests, "expected 2 requests today") + assert.Equal(t, int64(20), todayEntry.InputTokens, "expected 20 input tokens") + assert.Equal(t, int64(16), todayEntry.OutputTokens, "expected 16 output tokens") + assert.Equal(t, int64(36), todayEntry.TotalTokens, "expected 36 total tokens") + + fixture.FlushAndClose(t) +} + +func TestAdminUsageSummary_MongoDB(t *testing.T) { + fixture := SetupTestServer(t, TestServerConfig{ + DBType: "mongodb", + UsageEnabled: true, + AdminEndpointsEnabled: true, + OnlyModelInteractions: false, + }) + + // Clear existing usage entries + dbassert.ClearUsageMongo(t, fixture.MongoDb) + + // Send 2 chat requests + for i := 0; i < 2; i++ { + payload := newChatRequest("gpt-4", "Hello!") + resp := sendChatRequest(t, fixture.ServerURL, payload) + require.Equal(t, 200, resp.StatusCode) + closeBody(resp) + } + + // Wait for usage buffer to flush + time.Sleep(2 * time.Second) + + // Query admin API + resp, err := http.Get(fixture.ServerURL + "/admin/api/v1/usage/summary?days=30") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var summary usage.UsageSummary + require.NoError(t, json.Unmarshal(body, &summary)) + + assert.Equal(t, 2, summary.TotalRequests, "expected 2 total requests") + assert.Equal(t, int64(20), summary.TotalInput, "expected 20 input tokens (2 * 10)") + assert.Equal(t, int64(16), summary.TotalOutput, "expected 16 output tokens (2 * 8)") + assert.Equal(t, int64(36), summary.TotalTokens, "expected 36 total tokens (2 * 18)") + + fixture.FlushAndClose(t) +} + +func TestAdminDailyUsage_WithInterval_PostgreSQL(t *testing.T) { + fixture := SetupTestServer(t, TestServerConfig{ + DBType: "postgresql", + UsageEnabled: true, + AdminEndpointsEnabled: true, + OnlyModelInteractions: false, + }) + + // Send a request so there's data + payload := newChatRequest("gpt-4", "Hello!") + resp := sendChatRequest(t, fixture.ServerURL, payload) + require.Equal(t, 200, resp.StatusCode) + closeBody(resp) + + // Wait for usage buffer to flush + time.Sleep(2 * time.Second) + + // Query with weekly interval + resp, err := http.Get(fixture.ServerURL + "/admin/api/v1/usage/daily?interval=weekly") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var daily []usage.DailyUsage + require.NoError(t, json.Unmarshal(body, &daily)) + + // Should return valid JSON array (may be empty or have entries) + assert.True(t, json.Valid(body), "response should be valid JSON") + + fixture.FlushAndClose(t) +} + +func TestAdminModels_PostgreSQL(t *testing.T) { + fixture := SetupTestServer(t, TestServerConfig{ + DBType: "postgresql", + UsageEnabled: false, + AdminEndpointsEnabled: true, + OnlyModelInteractions: false, + }) + + // Query admin models endpoint + resp, err := http.Get(fixture.ServerURL + "/admin/api/v1/models") + require.NoError(t, err) + defer closeBody(resp) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + var models []providers.ModelWithProvider + require.NoError(t, json.Unmarshal(body, &models)) + + require.NotEmpty(t, models, "expected at least one model") + + // Should be sorted by model ID + for i := 1; i < len(models); i++ { + assert.True(t, models[i-1].Model.ID < models[i].Model.ID, + "models should be sorted, but %s >= %s", models[i-1].Model.ID, models[i].Model.ID) + } + + // Each model should have model.id and provider_type + for _, m := range models { + assert.NotEmpty(t, m.Model.ID, "model.id should not be empty") + assert.NotEmpty(t, m.ProviderType, "provider_type should not be empty") + } + + fixture.FlushAndClose(t) +} diff --git a/tests/integration/dbassert/usage.go b/tests/integration/dbassert/usage.go index 944f9995..299ccd8c 100644 --- a/tests/integration/dbassert/usage.go +++ b/tests/integration/dbassert/usage.go @@ -117,6 +117,16 @@ func ClearUsage(t *testing.T, pool *pgxpool.Pool) { require.NoError(t, err, "failed to clear usage entries") } +// ClearUsageMongo deletes all usage entries from MongoDB. +func ClearUsageMongo(t *testing.T, db *mongo.Database) { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + _, err := db.Collection("usage").DeleteMany(ctx, bson.M{}) + require.NoError(t, err, "failed to clear usage entries from MongoDB") +} + // SumTokensByModel returns total token usage grouped by model from PostgreSQL. func SumTokensByModel(t *testing.T, pool *pgxpool.Pool) map[string]TokenSummary { t.Helper() diff --git a/tests/integration/setup_test.go b/tests/integration/setup_test.go index 9c2640cf..61fa48a8 100644 --- a/tests/integration/setup_test.go +++ b/tests/integration/setup_test.go @@ -44,6 +44,15 @@ type TestServerConfig struct { // OnlyModelInteractions limits logging to model endpoints only OnlyModelInteractions bool + + // AdminEndpointsEnabled enables admin API endpoints + AdminEndpointsEnabled bool + + // AdminUIEnabled enables admin dashboard UI + AdminUIEnabled bool + + // MasterKey sets the authentication master key (empty = unsafe mode) + MasterKey string } // TestServerFixture holds test server resources. @@ -172,8 +181,12 @@ func buildAppConfig(t *testing.T, cfg TestServerConfig, mockLLMURL string, port appCfg := &config.Config{ Server: config.ServerConfig{ - Port: fmt.Sprintf("%d", port), - // No master key for tests (unsafe mode) + Port: fmt.Sprintf("%d", port), + MasterKey: cfg.MasterKey, + }, + Admin: config.AdminConfig{ + EndpointsEnabled: cfg.AdminEndpointsEnabled, + UIEnabled: cfg.AdminUIEnabled, }, Cache: config.CacheConfig{ Type: "local",