diff --git a/dashboard/src/components/layout/sidebar.tsx b/dashboard/src/components/layout/sidebar.tsx index ee2c9c1..5f9e209 100644 --- a/dashboard/src/components/layout/sidebar.tsx +++ b/dashboard/src/components/layout/sidebar.tsx @@ -59,6 +59,7 @@ const NAV: NavGroup[] = [ { title: "Configuration", items: [ + { to: "/tasks", label: "Tasks", icon: ListTree }, { to: "/webhooks", label: "Webhooks", icon: WebhookIcon }, { to: "/settings", label: "Settings", icon: Cog }, ], diff --git a/dashboard/src/features/tasks/api.ts b/dashboard/src/features/tasks/api.ts new file mode 100644 index 0000000..e1232fd --- /dev/null +++ b/dashboard/src/features/tasks/api.ts @@ -0,0 +1,26 @@ +import { api } from "@/lib/api-client"; +import type { QueueEntry, QueueOverridePatch, TaskEntry, TaskOverridePatch } from "./types"; + +export function listTasks(signal?: AbortSignal): Promise { + return api.get("/api/tasks", { signal }); +} + +export function listQueues(signal?: AbortSignal): Promise { + return api.get("/api/queues", { signal }); +} + +export function putTaskOverride(name: string, patch: TaskOverridePatch): Promise { + return api.put(`/api/tasks/${encodeURIComponent(name)}/override`, patch); +} + +export function clearTaskOverride(name: string): Promise<{ cleared: boolean }> { + return api.delete<{ cleared: boolean }>(`/api/tasks/${encodeURIComponent(name)}/override`); +} + +export function putQueueOverride(name: string, patch: QueueOverridePatch): Promise { + return api.put(`/api/queues/${encodeURIComponent(name)}/override`, patch); +} + +export function clearQueueOverride(name: string): Promise<{ cleared: boolean }> { + return api.delete<{ cleared: boolean }>(`/api/queues/${encodeURIComponent(name)}/override`); +} diff --git a/dashboard/src/features/tasks/components/middleware-toggles.tsx b/dashboard/src/features/tasks/components/middleware-toggles.tsx new file mode 100644 index 0000000..c61241d --- /dev/null +++ b/dashboard/src/features/tasks/components/middleware-toggles.tsx @@ -0,0 +1,99 @@ +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { Power } from "lucide-react"; +import { toast } from "sonner"; +import { ErrorState, Skeleton } from "@/components/ui"; +import { api } from "@/lib/api-client"; + +interface TaskMiddlewareEntry { + name: string; + class_path: string; + disabled: boolean; + effective: boolean; +} + +interface TaskMiddlewareResponse { + task: string; + middleware: TaskMiddlewareEntry[]; +} + +interface Props { + taskName: string; +} + +const queryKey = (task: string) => ["tasks", task, "middleware"] as const; + +export function MiddlewareToggles({ taskName }: Props) { + const qc = useQueryClient(); + const query = useQuery({ + queryKey: queryKey(taskName), + queryFn: ({ signal }) => + api.get(`/api/tasks/${encodeURIComponent(taskName)}/middleware`, { + signal, + }), + }); + + const mutation = useMutation({ + mutationFn: ({ mwName, enabled }: { mwName: string; enabled: boolean }) => + api.put( + `/api/tasks/${encodeURIComponent(taskName)}/middleware/${encodeURIComponent(mwName)}`, + { enabled }, + ), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: queryKey(taskName) }); + }, + onError: () => toast.error("Failed to update middleware"), + }); + + if (query.isLoading) { + return ; + } + if (query.error) { + return ( + + ); + } + const entries = query.data?.middleware ?? []; + if (entries.length === 0) { + return ( +
+ No middleware registered for this task. +
+ ); + } + + return ( +
    + {entries.map((entry) => { + const enabled = !entry.disabled; + return ( +
  • +
    +
    {entry.name}
    +
    {entry.class_path}
    +
    + +
  • + ); + })} +
+ ); +} diff --git a/dashboard/src/features/tasks/components/task-list-table.tsx b/dashboard/src/features/tasks/components/task-list-table.tsx new file mode 100644 index 0000000..4c69e32 --- /dev/null +++ b/dashboard/src/features/tasks/components/task-list-table.tsx @@ -0,0 +1,132 @@ +import { ListTree } from "lucide-react"; +import { useState } from "react"; +import { + Badge, + Button, + EmptyState, + Sheet, + SheetContent, + Table, + TableBody, + TableCell, + TableHead, + TableHeader, + TableRow, +} from "@/components/ui"; +import type { TaskEntry } from "../types"; +import { TaskOverrideForm } from "./task-override-form"; + +interface Props { + tasks: TaskEntry[]; +} + +export function TaskListTable({ tasks }: Props) { + const [editing, setEditing] = useState(null); + + if (tasks.length === 0) { + return ( + + ); + } + + return ( + <> +
+ + + + Task + Queue + Rate limit + Concurrency + Retries + Timeout + Override + + + + + {tasks.map((task) => ( + + {task.name} + + {task.queue} + + + (v == null ? "—" : String(v))} + /> + + + (v == null ? "—" : String(v))} + /> + + + String(v)} + /> + + + `${v}s`} + /> + + + {task.paused ? ( + Paused + ) : task.override ? ( + Override + ) : ( + Default + )} + + + + + + ))} + +
+
+ + !open && setEditing(null)}> + + {editing ? setEditing(null)} /> : null} + + + + ); +} + +interface CellProps { + effective: T; + decoratorDefault: T; + formatter: (v: T) => string; +} + +function EffectiveCell({ effective, decoratorDefault, formatter }: CellProps) { + const overridden = effective !== decoratorDefault; + return ( + + {formatter(effective)} + + ); +} diff --git a/dashboard/src/features/tasks/components/task-override-form.tsx b/dashboard/src/features/tasks/components/task-override-form.tsx new file mode 100644 index 0000000..5e962b5 --- /dev/null +++ b/dashboard/src/features/tasks/components/task-override-form.tsx @@ -0,0 +1,237 @@ +import { Save, Trash2 } from "lucide-react"; +import { type FormEvent, useState } from "react"; +import { Button, Input, Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui"; +import { useClearTaskOverride, useSetTaskOverride } from "../hooks"; +import type { TaskEntry, TaskOverridePatch } from "../types"; +import { MiddlewareToggles } from "./middleware-toggles"; + +interface Props { + task: TaskEntry; + onDone?: () => void; +} + +/** + * Side-panel form for editing a task's overrides. Empty inputs mean + * "inherit the decorator default" (the override field is omitted / + * cleared); a non-empty value overrides the default. Submit applies the + * change; ``Clear`` removes the override entirely. + */ +export function TaskOverrideForm({ task, onDone }: Props) { + const setOverride = useSetTaskOverride(); + const clearOverride = useClearTaskOverride(); + + const o = task.override ?? {}; + const [rateLimit, setRateLimit] = useState(o.rate_limit ?? ""); + const [maxConcurrent, setMaxConcurrent] = useState( + o.max_concurrent != null ? String(o.max_concurrent) : "", + ); + const [maxRetries, setMaxRetries] = useState(o.max_retries != null ? String(o.max_retries) : ""); + const [timeout, setTimeoutValue] = useState(o.timeout != null ? String(o.timeout) : ""); + const [priority, setPriority] = useState(o.priority != null ? String(o.priority) : ""); + const [paused, setPaused] = useState(o.paused ?? false); + + function buildPatch(): TaskOverridePatch | null { + const patch: TaskOverridePatch = {}; + const numOr = (raw: string, name: keyof TaskOverridePatch) => { + if (raw === "") { + patch[name] = null as never; + } else { + const v = Number(raw); + if (!Number.isFinite(v)) return false; + (patch as Record)[name] = v; + } + return true; + }; + patch.rate_limit = rateLimit ? rateLimit : null; + if (!numOr(maxConcurrent, "max_concurrent")) return null; + if (!numOr(maxRetries, "max_retries")) return null; + if (!numOr(timeout, "timeout")) return null; + if (!numOr(priority, "priority")) return null; + patch.paused = paused; + return patch; + } + + function onSubmit(event: FormEvent): void { + event.preventDefault(); + const patch = buildPatch(); + if (!patch) return; + setOverride.mutate({ name: task.name, patch }, { onSuccess: () => onDone?.() }); + } + + return ( +
+
+

{task.name}

+

Queue · {task.queue}

+
+ + + Overrides + Middleware + + + clearOverride.mutate(task.name, { onSuccess: () => onDone?.() })} + /> + + + + + +
+ ); +} + +interface OverrideFormProps { + task: TaskEntry; + onSubmit: (e: FormEvent) => void; + rateLimit: string; + setRateLimit: (v: string) => void; + maxConcurrent: string; + setMaxConcurrent: (v: string) => void; + maxRetries: string; + setMaxRetries: (v: string) => void; + timeoutValue: string; + setTimeoutValue: (v: string) => void; + priority: string; + setPriority: (v: string) => void; + paused: boolean; + setPaused: (v: boolean) => void; + saving: boolean; + clearing: boolean; + onClear: () => void; +} + +function OverrideForm({ + task, + onSubmit, + rateLimit, + setRateLimit, + maxConcurrent, + setMaxConcurrent, + maxRetries, + setMaxRetries, + timeoutValue, + setTimeoutValue, + priority, + setPriority, + paused, + setPaused, + saving, + clearing, + onClear, +}: OverrideFormProps) { + return ( +
+

+ Overrides apply on the next worker restart; pausing takes effect immediately. +

+ + + + + + +
+ + +
+ + ); +} + +interface FieldProps { + id: string; + label: string; + value: string; + onChange: (v: string) => void; + defaultValue: string; + type: "text" | "number"; + placeholder?: string; +} + +function NumberField({ id, label, value, onChange, defaultValue, type, placeholder }: FieldProps) { + return ( + + ); +} diff --git a/dashboard/src/features/tasks/hooks.ts b/dashboard/src/features/tasks/hooks.ts new file mode 100644 index 0000000..2e91188 --- /dev/null +++ b/dashboard/src/features/tasks/hooks.ts @@ -0,0 +1,100 @@ +import { queryOptions, useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { toast } from "sonner"; +import { ApiError } from "@/lib/api-client"; +import { + clearQueueOverride, + clearTaskOverride, + listQueues, + listTasks, + putQueueOverride, + putTaskOverride, +} from "./api"; +import type { QueueOverridePatch, TaskOverridePatch } from "./types"; + +const TASKS_KEY = ["tasks"] as const; +const QUEUES_KEY = ["queues-overrides"] as const; + +function describeError(error: unknown): string | undefined { + if (error instanceof ApiError && error.status >= 400 && error.status < 500) { + return error.message; + } + return undefined; +} + +export function tasksQuery() { + return queryOptions({ + queryKey: TASKS_KEY, + queryFn: ({ signal }) => listTasks(signal), + }); +} + +export function queuesQuery() { + return queryOptions({ + queryKey: QUEUES_KEY, + queryFn: ({ signal }) => listQueues(signal), + }); +} + +export function useTasks() { + return useQuery(tasksQuery()); +} + +export function useQueues() { + return useQuery(queuesQuery()); +} + +export function useSetTaskOverride() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: ({ name, patch }: { name: string; patch: TaskOverridePatch }) => + putTaskOverride(name, patch), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: TASKS_KEY }); + toast.success("Override saved", { + description: "Applied on next worker restart.", + }); + }, + onError: (error) => + toast.error("Failed to save override", { description: describeError(error) }), + }); +} + +export function useClearTaskOverride() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: (name: string) => clearTaskOverride(name), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: TASKS_KEY }); + toast.success("Override cleared"); + }, + onError: (error) => + toast.error("Failed to clear override", { description: describeError(error) }), + }); +} + +export function useSetQueueOverride() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: ({ name, patch }: { name: string; patch: QueueOverridePatch }) => + putQueueOverride(name, patch), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: QUEUES_KEY }); + toast.success("Queue override saved"); + }, + onError: (error) => + toast.error("Failed to save queue override", { description: describeError(error) }), + }); +} + +export function useClearQueueOverride() { + const qc = useQueryClient(); + return useMutation({ + mutationFn: (name: string) => clearQueueOverride(name), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: QUEUES_KEY }); + toast.success("Queue override cleared"); + }, + onError: (error) => + toast.error("Failed to clear queue override", { description: describeError(error) }), + }); +} diff --git a/dashboard/src/features/tasks/index.ts b/dashboard/src/features/tasks/index.ts new file mode 100644 index 0000000..e9eae07 --- /dev/null +++ b/dashboard/src/features/tasks/index.ts @@ -0,0 +1,19 @@ +export { TaskListTable } from "./components/task-list-table"; +export { TaskOverrideForm } from "./components/task-override-form"; +export { + queuesQuery, + tasksQuery, + useClearQueueOverride, + useClearTaskOverride, + useQueues, + useSetQueueOverride, + useSetTaskOverride, + useTasks, +} from "./hooks"; +export type { + QueueEntry, + QueueOverridePatch, + TaskDefaults, + TaskEntry, + TaskOverridePatch, +} from "./types"; diff --git a/dashboard/src/features/tasks/types.ts b/dashboard/src/features/tasks/types.ts new file mode 100644 index 0000000..01b46cb --- /dev/null +++ b/dashboard/src/features/tasks/types.ts @@ -0,0 +1,41 @@ +export interface TaskDefaults { + max_retries: number; + retry_backoff: number; + timeout: number; + priority: number; + rate_limit: string | null; + max_concurrent: number | null; +} + +export interface TaskOverridePatch { + rate_limit?: string | null; + max_concurrent?: number | null; + max_retries?: number | null; + retry_backoff?: number | null; + timeout?: number | null; + priority?: number | null; + paused?: boolean; +} + +export interface TaskEntry { + name: string; + queue: string; + defaults: TaskDefaults; + override: TaskOverridePatch | null; + effective: TaskDefaults; + paused: boolean; +} + +export interface QueueOverridePatch { + rate_limit?: string | null; + max_concurrent?: number | null; + paused?: boolean; +} + +export interface QueueEntry { + name: string; + defaults: Record; + override: QueueOverridePatch | null; + effective: Record; + paused: boolean; +} diff --git a/dashboard/src/routes/tasks.tsx b/dashboard/src/routes/tasks.tsx new file mode 100644 index 0000000..1465ba9 --- /dev/null +++ b/dashboard/src/routes/tasks.tsx @@ -0,0 +1,31 @@ +import { createFileRoute } from "@tanstack/react-router"; +import { PageHeader } from "@/components/layout/page-header"; +import { ErrorState, Skeleton } from "@/components/ui"; +import { TaskListTable, useTasks } from "@/features/tasks"; + +export const Route = createFileRoute("/tasks")({ + component: TasksPage, +}); + +function TasksPage() { + const { data, isLoading, error } = useTasks(); + + return ( +
+ + {isLoading ? ( + + ) : error ? ( + + ) : ( + + )} +
+ ); +} diff --git a/py_src/taskito/app.py b/py_src/taskito/app.py index d42b5c4..c8273c7 100644 --- a/py_src/taskito/app.py +++ b/py_src/taskito/app.py @@ -38,7 +38,9 @@ QueueInspectionMixin, QueueLifecycleMixin, QueueLockMixin, + QueueMiddlewareAdminMixin, QueueOperationsMixin, + QueueOverridesMixin, QueuePredicateMixin, QueueResourceMixin, QueueRuntimeConfigMixin, @@ -83,6 +85,8 @@ class Queue( QueueInspectionMixin, QueueOperationsMixin, QueueLockMixin, + QueueMiddlewareAdminMixin, + QueueOverridesMixin, QueueSettingsMixin, QueueWorkflowMixin, AsyncQueueMixin, diff --git a/py_src/taskito/dashboard/handlers/middleware.py b/py_src/taskito/dashboard/handlers/middleware.py new file mode 100644 index 0000000..e0fd85b --- /dev/null +++ b/py_src/taskito/dashboard/handlers/middleware.py @@ -0,0 +1,62 @@ +"""Middleware discovery + per-task enable/disable endpoints.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from taskito.dashboard.errors import _BadRequest, _NotFound + +if TYPE_CHECKING: + from taskito.app import Queue + + +def handle_list_middleware(queue: Queue, _qs: dict) -> list[dict[str, Any]]: + """Return every registered middleware with its scopes.""" + return queue.list_middleware() + + +def handle_get_task_middleware(queue: Queue, _qs: dict, task_name: str) -> dict[str, Any]: + """Return the middleware chain that fires for ``task_name`` with each + entry's enabled/disabled state.""" + chain = queue._get_middleware_chain(task_name) + disabled = set(queue.get_disabled_middleware_for(task_name)) + # Build the full would-fire chain INCLUDING disabled entries so the UI + # can render every toggle. + base_chain = queue._global_middleware + queue._task_middleware.get(task_name, []) + entries: list[dict[str, Any]] = [] + chain_names = {getattr(mw, "name", "") for mw in chain} + for mw in base_chain: + name = getattr(mw, "name", "") or f"{type(mw).__module__}.{type(mw).__qualname__}" + entries.append( + { + "name": name, + "class_path": f"{type(mw).__module__}.{type(mw).__qualname__}", + "disabled": name in disabled, + "effective": name in chain_names, + } + ) + return {"task": task_name, "middleware": entries} + + +def handle_put_task_middleware(queue: Queue, body: dict, ids: tuple[str, str]) -> dict[str, Any]: + task_name, mw_name = ids + if not isinstance(body, dict) or "enabled" not in body: + raise _BadRequest('body must include {"enabled": bool}') + if not isinstance(body["enabled"], bool): + raise _BadRequest("'enabled' must be a boolean") + # Confirm the middleware exists in the relevant chain so a typo doesn't + # silently write a no-op disable entry. + base_chain = queue._global_middleware + queue._task_middleware.get(task_name, []) + names = {getattr(mw, "name", "") for mw in base_chain} + if mw_name not in names: + raise _NotFound(f"middleware '{mw_name}' is not registered on task '{task_name}'") + if body["enabled"]: + new = queue.enable_middleware_for_task(task_name, mw_name) + else: + new = queue.disable_middleware_for_task(task_name, mw_name) + return {"task": task_name, "disabled": new} + + +def handle_delete_task_middleware(queue: Queue, task_name: str) -> dict[str, bool]: + """Clear ALL disables for a task — every middleware fires again.""" + return {"cleared": queue.clear_middleware_disables(task_name)} diff --git a/py_src/taskito/dashboard/handlers/overrides.py b/py_src/taskito/dashboard/handlers/overrides.py new file mode 100644 index 0000000..c125441 --- /dev/null +++ b/py_src/taskito/dashboard/handlers/overrides.py @@ -0,0 +1,95 @@ +"""Task & queue override endpoints.""" + +from __future__ import annotations + +from dataclasses import asdict +from typing import TYPE_CHECKING, Any + +from taskito.dashboard.errors import _BadRequest, _NotFound +from taskito.dashboard.overrides_store import ( + QUEUE_OVERRIDE_FIELDS, + TASK_OVERRIDE_FIELDS, + OverridesStore, +) + +if TYPE_CHECKING: + from taskito.app import Queue + + +def handle_list_tasks(queue: Queue, _qs: dict) -> list[dict[str, Any]]: + """Return every registered task with decorator defaults + active override.""" + return queue.registered_tasks() + + +def handle_list_queues(queue: Queue, _qs: dict) -> list[dict[str, Any]]: + return queue.registered_queues() + + +def _coerce_override_body(body: Any, allowed: frozenset[str]) -> dict[str, Any]: + if not isinstance(body, dict): + raise _BadRequest("body must be a JSON object") + unknown = set(body) - allowed + if unknown: + raise _BadRequest( + f"unknown override fields: {sorted(unknown)}; allowed: {sorted(allowed)}" + ) + return body + + +# ── Task override endpoints ─────────────────────────────────────────── + + +def handle_get_task_override(queue: Queue, _qs: dict, task_name: str) -> dict[str, Any]: + override = OverridesStore(queue).get_task(task_name) + if override is None: + raise _NotFound(f"no override set for task '{task_name}'") + return asdict(override) + + +def handle_put_task_override(queue: Queue, body: dict, task_name: str) -> dict[str, Any]: + fields = _coerce_override_body(body, TASK_OVERRIDE_FIELDS) + try: + override = OverridesStore(queue).set_task(task_name, fields) + except ValueError as e: + raise _BadRequest(str(e)) from None + return asdict(override) + + +def handle_delete_task_override(queue: Queue, task_name: str) -> dict[str, bool]: + removed = OverridesStore(queue).clear_task(task_name) + return {"cleared": removed} + + +# ── Queue override endpoints ────────────────────────────────────────── + + +def handle_get_queue_override(queue: Queue, _qs: dict, queue_name: str) -> dict[str, Any]: + override = OverridesStore(queue).get_queue(queue_name) + if override is None: + raise _NotFound(f"no override set for queue '{queue_name}'") + return asdict(override) + + +def handle_put_queue_override(queue: Queue, body: dict, queue_name: str) -> dict[str, Any]: + fields = _coerce_override_body(body, QUEUE_OVERRIDE_FIELDS) + try: + override = OverridesStore(queue).set_queue(queue_name, fields) + except ValueError as e: + raise _BadRequest(str(e)) from None + # Reflect "paused" immediately by touching the paused_queues store + # (this state DOES propagate to a running worker — independent of the + # static override consumed at worker startup). + if "paused" in fields: + try: + if fields["paused"]: + queue.pause(queue_name) + else: + queue.resume(queue_name) + except Exception: # pragma: no cover - safety net only + pass + return asdict(override) + + +def handle_delete_queue_override(queue: Queue, queue_name: str) -> dict[str, bool]: + removed = OverridesStore(queue).clear_queue(queue_name) + return {"cleared": removed} diff --git a/py_src/taskito/dashboard/middleware_store.py b/py_src/taskito/dashboard/middleware_store.py new file mode 100644 index 0000000..0c2554b --- /dev/null +++ b/py_src/taskito/dashboard/middleware_store.py @@ -0,0 +1,88 @@ +"""Per-task middleware disable list. + +Operators turn individual middlewares off for individual tasks from the +dashboard. The disable list is persisted under +``middleware:disabled:`` as a JSON array of middleware names, +read by :meth:`~taskito.mixins.decorators.QueueDecoratorMixin._get_middleware_chain` +at every task invocation so changes take effect immediately on the next +job without a worker restart. +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from taskito.app import Queue + + +DISABLE_PREFIX = "middleware:disabled:" + +logger = logging.getLogger("taskito.dashboard.middleware") + + +def _parse(raw: str | None) -> list[str]: + if not raw: + return [] + try: + data = json.loads(raw) + except json.JSONDecodeError: + logger.warning("middleware disable list is not valid JSON; treating as empty") + return [] + if not isinstance(data, list): + return [] + return [str(x) for x in data if isinstance(x, str)] + + +class MiddlewareDisableStore: + """List/set/clear per-task middleware disables.""" + + def __init__(self, queue: Queue) -> None: + self._queue = queue + + def _key(self, task_name: str) -> str: + return DISABLE_PREFIX + task_name + + def list_all(self) -> dict[str, list[str]]: + """Return ``{task_name: [disabled_mw_name, ...]}`` for every task that + has at least one disabled middleware.""" + out: dict[str, list[str]] = {} + for key, raw in self._queue.list_settings().items(): + if not key.startswith(DISABLE_PREFIX): + continue + task_name = key[len(DISABLE_PREFIX) :] + names = _parse(raw) + if names: + out[task_name] = names + return out + + def get_for(self, task_name: str) -> list[str]: + return _parse(self._queue.get_setting(self._key(task_name))) + + def is_disabled(self, task_name: str, mw_name: str) -> bool: + return mw_name in self.get_for(task_name) + + def set_disabled(self, task_name: str, mw_name: str, disabled: bool) -> list[str]: + """Flip a middleware on/off for a task and return the new disable list.""" + if not task_name: + raise ValueError("task_name must not be empty") + if not mw_name: + raise ValueError("mw_name must not be empty") + current = self.get_for(task_name) + if disabled: + if mw_name not in current: + current.append(mw_name) + else: + current = [n for n in current if n != mw_name] + if current: + self._queue.set_setting( + self._key(task_name), json.dumps(current, separators=(",", ":")) + ) + else: + self._queue.delete_setting(self._key(task_name)) + return current + + def clear_for(self, task_name: str) -> bool: + return self._queue.delete_setting(self._key(task_name)) diff --git a/py_src/taskito/dashboard/overrides_store.py b/py_src/taskito/dashboard/overrides_store.py new file mode 100644 index 0000000..d5d70f1 --- /dev/null +++ b/py_src/taskito/dashboard/overrides_store.py @@ -0,0 +1,341 @@ +"""Persistent task & queue runtime overrides. + +Operators tune individual task or queue behaviour (rate limits, concurrency +caps, retry policy, timeouts, priority, paused state) at runtime via the +dashboard. The decorator-declared values become the *defaults* — any override +recorded here wins. + +Storage layout in ``dashboard_settings``: + +- ``overrides:task:`` — JSON of overridden fields for that task +- ``overrides:queue:`` — JSON of overridden fields for that queue + +Overrides are applied at worker startup (see +:meth:`taskito.mixins.lifecycle.QueueLifecycleMixin.start_worker`). +Changes to the store DO NOT take effect on a running worker until it is +restarted — the dashboard surfaces this so operators aren't surprised. + +The contract is intentionally minimal: only the fields below can be +overridden. The store rejects anything else so a typo can't write garbage +through the dashboard. +""" + +from __future__ import annotations + +import json +import logging +import time +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from taskito.app import Queue + + +TASK_PREFIX = "overrides:task:" +QUEUE_PREFIX = "overrides:queue:" + +logger = logging.getLogger("taskito.dashboard.overrides") + + +# ── Allowed override fields ──────────────────────────────────────────── + + +TASK_OVERRIDE_FIELDS: frozenset[str] = frozenset( + { + "rate_limit", + "max_concurrent", + "max_retries", + "retry_backoff", + "timeout", + "priority", + "paused", + } +) + +QUEUE_OVERRIDE_FIELDS: frozenset[str] = frozenset( + { + "rate_limit", + "max_concurrent", + "paused", + } +) + + +# ── Data classes ─────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class TaskOverride: + """An operator-set override for a registered task.""" + + task_name: str + rate_limit: str | None = None + max_concurrent: int | None = None + max_retries: int | None = None + retry_backoff: float | None = None + timeout: int | None = None + priority: int | None = None + paused: bool = False + updated_at: int = 0 + + def as_patch(self) -> dict[str, Any]: + """Return a dict of only the non-default fields (those the operator + actually set). The empty/default values are NOT patched onto the + underlying ``PyTaskConfig`` — they continue to use the decorator + value.""" + patch: dict[str, Any] = {} + for field in TASK_OVERRIDE_FIELDS: + if field == "paused": + continue # handled separately; not a PyTaskConfig field + value = getattr(self, field) + if value is not None: + patch[field] = value + return patch + + +@dataclass(frozen=True) +class QueueOverride: + """An operator-set override for a queue.""" + + queue_name: str + rate_limit: str | None = None + max_concurrent: int | None = None + paused: bool = False + updated_at: int = 0 + + +# ── Validation ───────────────────────────────────────────────────────── + + +def _validate_task_fields(fields: dict[str, Any]) -> None: + unknown = set(fields) - TASK_OVERRIDE_FIELDS + if unknown: + raise ValueError(f"unknown task override fields: {sorted(unknown)}") + _validate_rate_limit(fields.get("rate_limit")) + _validate_max_concurrent(fields.get("max_concurrent")) + _validate_int_field(fields, "max_retries", minimum=0) + _validate_float_field(fields, "retry_backoff", minimum=0) + _validate_int_field(fields, "timeout", minimum=1) + _validate_int_field(fields, "priority") + _validate_bool_field(fields, "paused") + + +def _validate_queue_fields(fields: dict[str, Any]) -> None: + unknown = set(fields) - QUEUE_OVERRIDE_FIELDS + if unknown: + raise ValueError(f"unknown queue override fields: {sorted(unknown)}") + _validate_rate_limit(fields.get("rate_limit")) + _validate_max_concurrent(fields.get("max_concurrent")) + _validate_bool_field(fields, "paused") + + +def _validate_rate_limit(value: Any) -> None: + if value is None: + return + if not isinstance(value, str) or not value: + raise ValueError("rate_limit must be a non-empty string like '100/m'") + # Cheap shape check; rate-limit parsing happens in Rust. + if "/" not in value: + raise ValueError("rate_limit must contain a unit, e.g. '10/s', '100/m', '3600/h'") + + +def _validate_max_concurrent(value: Any) -> None: + if value is None: + return + if not isinstance(value, int) or isinstance(value, bool) or value < 0: + raise ValueError("max_concurrent must be a non-negative integer") + + +def _validate_int_field(fields: dict[str, Any], name: str, *, minimum: int | None = None) -> None: + value = fields.get(name) + if value is None: + return + if not isinstance(value, int) or isinstance(value, bool): + raise ValueError(f"{name} must be an integer") + if minimum is not None and value < minimum: + raise ValueError(f"{name} must be >= {minimum}") + + +def _validate_float_field( + fields: dict[str, Any], name: str, *, minimum: float | None = None +) -> None: + value = fields.get(name) + if value is None: + return + if isinstance(value, bool) or not isinstance(value, (int, float)): + raise ValueError(f"{name} must be a number") + if minimum is not None and value < minimum: + raise ValueError(f"{name} must be >= {minimum}") + + +def _validate_bool_field(fields: dict[str, Any], name: str) -> None: + value = fields.get(name) + if value is not None and not isinstance(value, bool): + raise ValueError(f"{name} must be a boolean") + + +# ── Store ────────────────────────────────────────────────────────────── + + +def _now() -> int: + return int(time.time()) + + +def _parse_json(raw: str | None) -> dict[str, Any]: + if not raw: + return {} + try: + data = json.loads(raw) + except json.JSONDecodeError: + logger.warning("overrides entry is not valid JSON; treating as empty") + return {} + return data if isinstance(data, dict) else {} + + +class OverridesStore: + """CRUD for per-task and per-queue runtime overrides.""" + + def __init__(self, queue: Queue) -> None: + self._queue = queue + + # ── Tasks ────────────────────────────────────────────────── + + def list_tasks(self) -> dict[str, TaskOverride]: + """Return ``{task_name: TaskOverride}`` for every task with an override.""" + out: dict[str, TaskOverride] = {} + for key, raw in self._queue.list_settings().items(): + if not key.startswith(TASK_PREFIX): + continue + task_name = key[len(TASK_PREFIX) :] + out[task_name] = self._row_to_task(task_name, _parse_json(raw)) + return out + + def get_task(self, task_name: str) -> TaskOverride | None: + raw = self._queue.get_setting(TASK_PREFIX + task_name) + if not raw: + return None + return self._row_to_task(task_name, _parse_json(raw)) + + def set_task(self, task_name: str, fields: dict[str, Any]) -> TaskOverride: + _validate_task_fields(fields) + if not task_name: + raise ValueError("task_name must not be empty") + existing = self.get_task(task_name) + merged: dict[str, Any] = {} + if existing is not None: + merged.update({k: v for k, v in asdict(existing).items() if v is not None}) + merged.pop("task_name", None) + merged.pop("updated_at", None) + for k, v in fields.items(): + if v is None: + merged.pop(k, None) + else: + merged[k] = v + merged["updated_at"] = _now() + self._queue.set_setting(TASK_PREFIX + task_name, json.dumps(merged, separators=(",", ":"))) + return self._row_to_task(task_name, merged) + + def clear_task(self, task_name: str) -> bool: + return self._queue.delete_setting(TASK_PREFIX + task_name) + + @staticmethod + def _row_to_task(task_name: str, row: dict[str, Any]) -> TaskOverride: + return TaskOverride( + task_name=task_name, + rate_limit=row.get("rate_limit"), + max_concurrent=row.get("max_concurrent"), + max_retries=row.get("max_retries"), + retry_backoff=row.get("retry_backoff"), + timeout=row.get("timeout"), + priority=row.get("priority"), + paused=bool(row.get("paused", False)), + updated_at=int(row.get("updated_at", 0)), + ) + + # ── Queues ───────────────────────────────────────────────── + + def list_queues(self) -> dict[str, QueueOverride]: + out: dict[str, QueueOverride] = {} + for key, raw in self._queue.list_settings().items(): + if not key.startswith(QUEUE_PREFIX): + continue + queue_name = key[len(QUEUE_PREFIX) :] + out[queue_name] = self._row_to_queue(queue_name, _parse_json(raw)) + return out + + def get_queue(self, queue_name: str) -> QueueOverride | None: + raw = self._queue.get_setting(QUEUE_PREFIX + queue_name) + if not raw: + return None + return self._row_to_queue(queue_name, _parse_json(raw)) + + def set_queue(self, queue_name: str, fields: dict[str, Any]) -> QueueOverride: + _validate_queue_fields(fields) + if not queue_name: + raise ValueError("queue_name must not be empty") + existing = self.get_queue(queue_name) + merged: dict[str, Any] = {} + if existing is not None: + merged.update({k: v for k, v in asdict(existing).items() if v is not None}) + merged.pop("queue_name", None) + merged.pop("updated_at", None) + for k, v in fields.items(): + if v is None: + merged.pop(k, None) + else: + merged[k] = v + merged["updated_at"] = _now() + self._queue.set_setting( + QUEUE_PREFIX + queue_name, json.dumps(merged, separators=(",", ":")) + ) + return self._row_to_queue(queue_name, merged) + + def clear_queue(self, queue_name: str) -> bool: + return self._queue.delete_setting(QUEUE_PREFIX + queue_name) + + @staticmethod + def _row_to_queue(queue_name: str, row: dict[str, Any]) -> QueueOverride: + return QueueOverride( + queue_name=queue_name, + rate_limit=row.get("rate_limit"), + max_concurrent=row.get("max_concurrent"), + paused=bool(row.get("paused", False)), + updated_at=int(row.get("updated_at", 0)), + ) + + # ── Apply (used at worker startup) ───────────────────────── + + def apply_task_overrides(self, configs: list[Any]) -> list[str]: + """Mutate each :class:`PyTaskConfig` in ``configs`` with any matching + task override. Returns a list of task names that are paused (so the + caller can skip enqueuing them). + """ + overrides = self.list_tasks() + paused: list[str] = [] + for config in configs: + override = overrides.get(config.name) + if override is None: + continue + for field, value in override.as_patch().items(): + if hasattr(config, field): + setattr(config, field, value) + if override.paused: + paused.append(config.name) + return paused + + def apply_queue_overrides( + self, queue_configs: dict[str, dict[str, Any]] + ) -> dict[str, dict[str, Any]]: + """Merge queue overrides into ``queue_configs``. Returns the merged + dict (a copy).""" + merged: dict[str, dict[str, Any]] = {k: dict(v) for k, v in queue_configs.items()} + for queue_name, override in self.list_queues().items(): + slot = merged.setdefault(queue_name, {}) + if override.rate_limit is not None: + slot["rate_limit"] = override.rate_limit + if override.max_concurrent is not None: + slot["max_concurrent"] = override.max_concurrent + if override.paused: + slot["paused"] = True + return merged diff --git a/py_src/taskito/dashboard/routes.py b/py_src/taskito/dashboard/routes.py index aeecc11..a4ab793 100644 --- a/py_src/taskito/dashboard/routes.py +++ b/py_src/taskito/dashboard/routes.py @@ -38,6 +38,22 @@ ) from taskito.dashboard.handlers.logs import _handle_logs from taskito.dashboard.handlers.metrics import _handle_metrics, _handle_metrics_timeseries +from taskito.dashboard.handlers.middleware import ( + handle_delete_task_middleware, + handle_get_task_middleware, + handle_list_middleware, + handle_put_task_middleware, +) +from taskito.dashboard.handlers.overrides import ( + handle_delete_queue_override, + handle_delete_task_override, + handle_get_queue_override, + handle_get_task_override, + handle_list_queues, + handle_list_tasks, + handle_put_queue_override, + handle_put_task_override, +) from taskito.dashboard.handlers.queues import _handle_stats_queues from taskito.dashboard.handlers.scaler import build_scaler_response from taskito.dashboard.handlers.settings import ( @@ -105,6 +121,9 @@ "/api/auth/status": handle_auth_status, "/api/webhooks": handle_list_webhooks, "/api/event-types": handle_list_event_types, + "/api/tasks": handle_list_tasks, + "/api/queues": handle_list_queues, + "/api/middleware": handle_list_middleware, } # ── Parameterized GET routes: regex → handler(queue, qs, captured_id) ── @@ -124,6 +143,9 @@ handle_list_deliveries, ), (re.compile(r"^/api/webhooks/([^/]+)$"), handle_get_webhook), + (re.compile(r"^/api/tasks/([^/]+)/override$"), handle_get_task_override), + (re.compile(r"^/api/queues/([^/]+)/override$"), handle_get_queue_override), + (re.compile(r"^/api/tasks/([^/]+)/middleware$"), handle_get_task_middleware), ] # GET routes with 2 captured groups (handler signature: queue, qs, (g1, g2)) @@ -194,12 +216,25 @@ PUT_PARAM_ROUTES: list[tuple[re.Pattern, Any]] = [ (re.compile(r"^/api/settings/(.+)$"), _handle_set_setting), (re.compile(r"^/api/webhooks/([^/]+)$"), handle_update_webhook), + (re.compile(r"^/api/tasks/([^/]+)/override$"), handle_put_task_override), + (re.compile(r"^/api/queues/([^/]+)/override$"), handle_put_queue_override), +] + +# PUT routes with 2 captured groups (handler signature: queue, body, (g1, g2)) +PUT_PARAM2_ROUTES: list[tuple[re.Pattern, Any]] = [ + ( + re.compile(r"^/api/tasks/([^/]+)/middleware/([^/]+)$"), + handle_put_task_middleware, + ), ] # ── Parameterized DELETE routes: regex → handler(queue, captured_id) ── DELETE_PARAM_ROUTES: list[tuple[re.Pattern, Any]] = [ (re.compile(r"^/api/settings/(.+)$"), _handle_delete_setting), (re.compile(r"^/api/webhooks/([^/]+)$"), handle_delete_webhook), + (re.compile(r"^/api/tasks/([^/]+)/override$"), handle_delete_task_override), + (re.compile(r"^/api/queues/([^/]+)/override$"), handle_delete_queue_override), + (re.compile(r"^/api/tasks/([^/]+)/middleware$"), handle_delete_task_middleware), ] diff --git a/py_src/taskito/dashboard/server.py b/py_src/taskito/dashboard/server.py index a4945f5..2aa1753 100644 --- a/py_src/taskito/dashboard/server.py +++ b/py_src/taskito/dashboard/server.py @@ -43,6 +43,7 @@ POST_PARAM_ROUTES, POST_ROUTES, PUBLIC_PATHS, + PUT_PARAM2_ROUTES, PUT_PARAM_ROUTES, is_csrf_exempt, is_state_changing_method, @@ -292,6 +293,17 @@ def _handle_put(self) -> None: param_handler, lambda h, m=m, body=body: h(queue, body, m.group(1)) ) return + for pattern, param_handler in PUT_PARAM2_ROUTES: + m = pattern.match(path) + if m: + body = self._read_json_body() + if body is None: + return + self._dispatch_with_handler( + param_handler, + lambda h, m=m, body=body: h(queue, body, (m.group(1), m.group(2))), + ) + return self._json_response({"error": "Not found"}, status=404) def _handle_delete(self) -> None: diff --git a/py_src/taskito/middleware.py b/py_src/taskito/middleware.py index 8650641..077ff33 100644 --- a/py_src/taskito/middleware.py +++ b/py_src/taskito/middleware.py @@ -55,12 +55,20 @@ def after(self, ctx, result, error): print(f"Finished {ctx.task_name}: {status}") """ + #: Stable identifier used to refer to this middleware from the dashboard + #: when toggling it on/off per task. Defaults to the class' fully-qualified + #: name so it survives restarts. Override on a subclass to pin a + #: shorter / more user-facing name. + name: str = "" + def __init__( self, *, predicate: Predicate | Callable[..., Any] | None = None, ) -> None: self._predicate = coerce_predicate(predicate) + if not type(self).name: + type(self).name = f"{type(self).__module__}.{type(self).__qualname__}" def _should_apply(self, ctx: JobContext | None, task_name: str = "") -> bool: """Decide whether this middleware's hooks should fire for ``ctx``. diff --git a/py_src/taskito/mixins/__init__.py b/py_src/taskito/mixins/__init__.py index f9b07ba..2d54c05 100644 --- a/py_src/taskito/mixins/__init__.py +++ b/py_src/taskito/mixins/__init__.py @@ -5,7 +5,9 @@ from taskito.mixins.inspection import QueueInspectionMixin from taskito.mixins.lifecycle import QueueLifecycleMixin from taskito.mixins.locks import QueueLockMixin +from taskito.mixins.middleware_admin import QueueMiddlewareAdminMixin from taskito.mixins.operations import QueueOperationsMixin +from taskito.mixins.overrides import QueueOverridesMixin from taskito.mixins.predicates import QueuePredicateMixin from taskito.mixins.resources import QueueResourceMixin from taskito.mixins.runtime_config import QueueRuntimeConfigMixin @@ -17,7 +19,9 @@ "QueueInspectionMixin", "QueueLifecycleMixin", "QueueLockMixin", + "QueueMiddlewareAdminMixin", "QueueOperationsMixin", + "QueueOverridesMixin", "QueuePredicateMixin", "QueueResourceMixin", "QueueRuntimeConfigMixin", diff --git a/py_src/taskito/mixins/decorators.py b/py_src/taskito/mixins/decorators.py index 671e9a2..e33c940 100644 --- a/py_src/taskito/mixins/decorators.py +++ b/py_src/taskito/mixins/decorators.py @@ -16,6 +16,7 @@ from taskito._taskito import PyTaskConfig from taskito.async_support.helpers import run_maybe_async from taskito.context import _clear_context, current_job +from taskito.dashboard.middleware_store import MiddlewareDisableStore from taskito.events import EventType from taskito.exceptions import TaskCancelledError from taskito.inject import Inject, _InjectAlias @@ -111,9 +112,18 @@ class QueueDecoratorMixin: _apply_dispatch_predicate: Callable[..., None] def _get_middleware_chain(self, task_name: str) -> list[TaskMiddleware]: - """Get the combined global + per-task middleware list.""" + """Get the combined global + per-task middleware list, minus any + middleware the operator has disabled for this task from the dashboard.""" per_task = self._task_middleware.get(task_name, []) - return self._global_middleware + per_task + chain = self._global_middleware + per_task + try: + disabled = MiddlewareDisableStore(self).get_for(task_name) # type: ignore[arg-type] + except Exception: # pragma: no cover - storage read failure is non-fatal + disabled = [] + if not disabled: + return chain + disabled_set = set(disabled) + return [mw for mw in chain if getattr(mw, "name", "") not in disabled_set] def _wrap_task( self, fn: Callable, task_name: str, soft_timeout: float | None = None diff --git a/py_src/taskito/mixins/lifecycle.py b/py_src/taskito/mixins/lifecycle.py index 874d553..9b912e7 100644 --- a/py_src/taskito/mixins/lifecycle.py +++ b/py_src/taskito/mixins/lifecycle.py @@ -16,6 +16,7 @@ import taskito from taskito._taskito import PyQueue, PyTaskConfig from taskito.context import _set_queue_ref +from taskito.dashboard.overrides_store import OverridesStore from taskito.events import EventType from taskito.log_config import configure as configure_logging from taskito.log_config import restore_asyncio_pipe_noise, silence_asyncio_pipe_noise @@ -231,7 +232,24 @@ def sighup_handler(signum: int, frame: Any) -> None: ) try: - queue_configs_json = json.dumps(self._queue_configs) if self._queue_configs else None + overrides = OverridesStore(self) # type: ignore[arg-type] + # Mutate the in-memory PyTaskConfig list so the Rust scheduler + # sees the override values; merge queue-level overrides into + # the JSON blob passed to run_worker. Paused tasks/queues get + # their pause state propagated to the existing paused_queues + # mechanism for tasks-by-queue, but per-task pause is left to + # the application-level guard in enqueue (out of scope here). + paused_tasks = overrides.apply_task_overrides(self._task_configs) + if paused_tasks: + logger.info("Paused task overrides in effect: %s", paused_tasks) + merged_queue_configs = overrides.apply_queue_overrides(self._queue_configs) + for queue_name, slot in merged_queue_configs.items(): + if slot.get("paused"): + try: + self.pause(queue_name) # type: ignore[attr-defined] + except Exception: + logger.exception("Failed to apply paused state for queue %s", queue_name) + queue_configs_json = json.dumps(merged_queue_configs) if merged_queue_configs else None self._inner.run_worker( task_registry=self._task_registry, task_configs=self._task_configs, diff --git a/py_src/taskito/mixins/middleware_admin.py b/py_src/taskito/mixins/middleware_admin.py new file mode 100644 index 0000000..dd9af80 --- /dev/null +++ b/py_src/taskito/mixins/middleware_admin.py @@ -0,0 +1,70 @@ +"""Middleware discovery and per-task disable management on :class:`Queue`.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from taskito.dashboard.middleware_store import MiddlewareDisableStore + +if TYPE_CHECKING: + from taskito.middleware import TaskMiddleware + + +class QueueMiddlewareAdminMixin: + """Discovery + per-task enable/disable for registered middlewares.""" + + _global_middleware: list[TaskMiddleware] + _task_middleware: dict[str, list[TaskMiddleware]] + + # ── Discovery ────────────────────────────────────────────────── + + def list_middleware(self) -> list[dict[str, Any]]: + """Return every registered middleware (global + per-task) with its + name, source ("global" or task name), and Python class path. The + ``name`` is the value the disable list keys on.""" + seen: dict[str, dict[str, Any]] = {} + for mw in self._global_middleware: + name = getattr(mw, "name", "") or f"{type(mw).__module__}.{type(mw).__qualname__}" + seen.setdefault( + name, + { + "name": name, + "class_path": f"{type(mw).__module__}.{type(mw).__qualname__}", + "scopes": [], + }, + )["scopes"].append({"kind": "global"}) + for task_name, mws in self._task_middleware.items(): + for mw in mws: + name = getattr(mw, "name", "") or f"{type(mw).__module__}.{type(mw).__qualname__}" + entry = seen.setdefault( + name, + { + "name": name, + "class_path": f"{type(mw).__module__}.{type(mw).__qualname__}", + "scopes": [], + }, + ) + entry["scopes"].append({"kind": "task", "task": task_name}) + return sorted(seen.values(), key=lambda x: x["name"]) + + # ── Disable management ───────────────────────────────────────── + + def list_middleware_disables(self) -> dict[str, list[str]]: + """Return every task that has at least one disabled middleware.""" + return MiddlewareDisableStore(self).list_all() # type: ignore[arg-type] + + def get_disabled_middleware_for(self, task_name: str) -> list[str]: + return MiddlewareDisableStore(self).get_for(task_name) # type: ignore[arg-type] + + def disable_middleware_for_task(self, task_name: str, mw_name: str) -> list[str]: + return MiddlewareDisableStore(self).set_disabled( # type: ignore[arg-type] + task_name, mw_name, disabled=True + ) + + def enable_middleware_for_task(self, task_name: str, mw_name: str) -> list[str]: + return MiddlewareDisableStore(self).set_disabled( # type: ignore[arg-type] + task_name, mw_name, disabled=False + ) + + def clear_middleware_disables(self, task_name: str) -> bool: + return MiddlewareDisableStore(self).clear_for(task_name) # type: ignore[arg-type] diff --git a/py_src/taskito/mixins/overrides.py b/py_src/taskito/mixins/overrides.py new file mode 100644 index 0000000..aae9ace --- /dev/null +++ b/py_src/taskito/mixins/overrides.py @@ -0,0 +1,151 @@ +"""Task & queue runtime override management on :class:`taskito.app.Queue`. + +These knobs let operators tune retry policy, concurrency caps, rate +limits, timeouts, priority, and pause/resume state without touching +code. Overrides land in the dashboard settings store and apply on the +next worker startup. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from taskito.dashboard.overrides_store import ( + OverridesStore, + QueueOverride, + TaskOverride, +) + +if TYPE_CHECKING: + from taskito._taskito import PyTaskConfig + + +class QueueOverridesMixin: + """CRUD for task + queue overrides, plus a task-discovery API for the UI.""" + + _task_configs: list[PyTaskConfig] + _queue_configs: dict[str, dict[str, Any]] + + # ── Task overrides ───────────────────────────────────────────── + + def list_task_overrides(self) -> dict[str, TaskOverride]: + """Return every persisted task override keyed by task name.""" + return OverridesStore(self).list_tasks() # type: ignore[arg-type] + + def get_task_override(self, task_name: str) -> TaskOverride | None: + return OverridesStore(self).get_task(task_name) # type: ignore[arg-type] + + def set_task_override(self, task_name: str, **fields: Any) -> TaskOverride: + """Set or update an override. Pass ``None`` for a field to clear it. + + Allowed fields: ``rate_limit``, ``max_concurrent``, ``max_retries``, + ``retry_backoff``, ``timeout``, ``priority``, ``paused``. + """ + return OverridesStore(self).set_task(task_name, fields) # type: ignore[arg-type] + + def clear_task_override(self, task_name: str) -> bool: + return OverridesStore(self).clear_task(task_name) # type: ignore[arg-type] + + # ── Queue overrides ──────────────────────────────────────────── + + def list_queue_overrides(self) -> dict[str, QueueOverride]: + return OverridesStore(self).list_queues() # type: ignore[arg-type] + + def get_queue_override(self, queue_name: str) -> QueueOverride | None: + return OverridesStore(self).get_queue(queue_name) # type: ignore[arg-type] + + def set_queue_override(self, queue_name: str, **fields: Any) -> QueueOverride: + """Set or update a queue override. Allowed fields: ``rate_limit``, + ``max_concurrent``, ``paused``.""" + return OverridesStore(self).set_queue(queue_name, fields) # type: ignore[arg-type] + + def clear_queue_override(self, queue_name: str) -> bool: + return OverridesStore(self).clear_queue(queue_name) # type: ignore[arg-type] + + # ── Task discovery (for the dashboard) ───────────────────────── + + def registered_tasks(self) -> list[dict[str, Any]]: + """Return every registered task with its decorator defaults and any + active override. Each entry contains: + + - ``name``, ``queue``, ``priority`` + - ``defaults``: the decorator-declared values + - ``override``: the override fields (or ``None`` if no override exists) + - ``effective``: the values that will be used on the next worker start + """ + overrides = self.list_task_overrides() + out: list[dict[str, Any]] = [] + for config in self._task_configs: + defaults = { + "max_retries": config.max_retries, + "retry_backoff": config.retry_backoff, + "timeout": config.timeout, + "priority": config.priority, + "rate_limit": config.rate_limit, + "max_concurrent": config.max_concurrent, + } + override = overrides.get(config.name) + override_dict: dict[str, Any] | None + if override is None: + override_dict = None + effective = dict(defaults) + paused = False + else: + patch = override.as_patch() + override_dict = dict(patch) + if override.paused: + override_dict["paused"] = True + effective = {**defaults, **patch} + paused = override.paused + out.append( + { + "name": config.name, + "queue": config.queue, + "defaults": defaults, + "override": override_dict, + "effective": effective, + "paused": paused, + } + ) + return out + + def registered_queues(self) -> list[dict[str, Any]]: + """Return every queue mentioned by a task config plus any + configured-from-Python queue, with its current overrides + paused + state.""" + queue_names: set[str] = set() + queue_names.update(self._queue_configs.keys()) + for config in self._task_configs: + queue_names.add(config.queue) + overrides = self.list_queue_overrides() + paused_set = set( + self.paused_queues() # type: ignore[attr-defined] + ) + out: list[dict[str, Any]] = [] + for name in sorted(queue_names): + base = dict(self._queue_configs.get(name, {})) + override = overrides.get(name) + override_dict: dict[str, Any] | None + if override is None: + override_dict = None + effective = dict(base) + else: + patch: dict[str, Any] = {} + if override.rate_limit is not None: + patch["rate_limit"] = override.rate_limit + if override.max_concurrent is not None: + patch["max_concurrent"] = override.max_concurrent + override_dict = dict(patch) + if override.paused: + override_dict["paused"] = True + effective = {**base, **patch} + out.append( + { + "name": name, + "defaults": base, + "override": override_dict, + "effective": effective, + "paused": name in paused_set or (override.paused if override else False), + } + ) + return out diff --git a/tests/dashboard/test_middleware_toggles.py b/tests/dashboard/test_middleware_toggles.py new file mode 100644 index 0000000..6713f09 --- /dev/null +++ b/tests/dashboard/test_middleware_toggles.py @@ -0,0 +1,234 @@ +"""Tests for per-task middleware enable/disable from the dashboard.""" + +from __future__ import annotations + +import threading +import urllib.error +from collections.abc import Generator +from http.server import ThreadingHTTPServer +from pathlib import Path +from typing import Any + +import pytest + +from taskito import Queue +from taskito.context import JobContext +from taskito.dashboard import _make_handler +from taskito.dashboard._testing import AuthedClient, seed_admin_and_session +from taskito.dashboard.middleware_store import MiddlewareDisableStore +from taskito.middleware import TaskMiddleware + + +class RecordingMiddleware(TaskMiddleware): + """Captures every ``before`` invocation so the test can assert which + tasks the middleware fired for.""" + + name = "test.recording" + + def __init__(self) -> None: + super().__init__() + self.invocations: list[str] = [] + + def before(self, ctx: JobContext) -> None: + self.invocations.append(ctx.task_name) + + +class OtherMiddleware(TaskMiddleware): + name = "test.other" + + def __init__(self) -> None: + super().__init__() + self.invocations: list[str] = [] + + def before(self, ctx: JobContext) -> None: + self.invocations.append(ctx.task_name) + + +@pytest.fixture +def middleware_pair() -> tuple[RecordingMiddleware, OtherMiddleware]: + return RecordingMiddleware(), OtherMiddleware() + + +@pytest.fixture +def queue(tmp_path: Path, middleware_pair: tuple[RecordingMiddleware, OtherMiddleware]) -> Queue: + rec, other = middleware_pair + q = Queue(db_path=str(tmp_path / "mw.db"), middleware=[rec, other]) + + @q.task() + def alpha() -> str: + return "a" + + @q.task() + def beta() -> str: + return "b" + + return q + + +@pytest.fixture +def dashboard(queue: Queue) -> Generator[tuple[AuthedClient, Queue]]: + handler = _make_handler(queue) + server = ThreadingHTTPServer(("127.0.0.1", 0), handler) + threading.Thread(target=server.serve_forever, daemon=True).start() + session = seed_admin_and_session(queue) + client = AuthedClient(base=f"http://127.0.0.1:{server.server_address[1]}", session=session) + try: + yield client, queue + finally: + server.shutdown() + + +# ── Store ────────────────────────────────────────────────────────────── + + +def test_store_starts_empty(queue: Queue) -> None: + store = MiddlewareDisableStore(queue) + assert store.list_all() == {} + assert store.get_for("alpha") == [] + + +def test_set_disabled_adds_and_removes(queue: Queue) -> None: + store = MiddlewareDisableStore(queue) + store.set_disabled("alpha", "test.other", True) + assert store.get_for("alpha") == ["test.other"] + # Idempotent — same disable twice still has just one entry. + store.set_disabled("alpha", "test.other", True) + assert store.get_for("alpha") == ["test.other"] + # Re-enable clears just that one. + store.set_disabled("alpha", "test.other", False) + assert store.get_for("alpha") == [] + + +def test_clear_for_drops_setting_key(queue: Queue) -> None: + store = MiddlewareDisableStore(queue) + store.set_disabled("alpha", "test.other", True) + assert store.clear_for("alpha") is True + assert store.clear_for("alpha") is False + assert store.get_for("alpha") == [] + + +# ── Wiring into the middleware chain ────────────────────────────────── + + +def test_chain_skips_disabled_middleware(queue: Queue) -> None: + """``_get_middleware_chain`` returns a chain that respects the disable + list at lookup time — no worker restart required.""" + full = queue._get_middleware_chain("alpha") + assert {mw.name for mw in full} == {"test.recording", "test.other"} + queue.disable_middleware_for_task("alpha", "test.other") + filtered = queue._get_middleware_chain("alpha") + assert {mw.name for mw in filtered} == {"test.recording"} + # Other tasks unaffected. + assert {mw.name for mw in queue._get_middleware_chain("beta")} == { + "test.recording", + "test.other", + } + + +def test_clear_re_enables_all(queue: Queue) -> None: + queue.disable_middleware_for_task("alpha", "test.other") + queue.disable_middleware_for_task("alpha", "test.recording") + assert queue._get_middleware_chain("alpha") == [] + queue.clear_middleware_disables("alpha") + assert len(queue._get_middleware_chain("alpha")) == 2 + + +# ── Discovery ───────────────────────────────────────────────────────── + + +def test_list_middleware_reports_globals(queue: Queue) -> None: + items = queue.list_middleware() + names = {item["name"] for item in items} + assert {"test.recording", "test.other"} <= names + for entry in items: + assert any(scope["kind"] == "global" for scope in entry["scopes"]) + + +# ── HTTP endpoints ──────────────────────────────────────────────────── + + +def test_list_middleware_endpoint(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + items = client.get("/api/middleware") + names = {item["name"] for item in items} + assert {"test.recording", "test.other"} <= names + + +def test_get_task_middleware_endpoint(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + result = client.get("/api/tasks/alpha/middleware") + by_name = {entry["name"]: entry for entry in result["middleware"]} + assert by_name["test.recording"]["disabled"] is False + assert by_name["test.recording"]["effective"] is True + + +def test_put_task_middleware_disables(dashboard: tuple[AuthedClient, Queue]) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + result = client.put(f"/api/tasks/{name}/middleware/test.other", {"enabled": False}) + assert "test.other" in result["disabled"] + # Reflected in the chain. + chain_names = {mw.name for mw in queue._get_middleware_chain(name)} + assert "test.other" not in chain_names + # Re-enabling clears it. + client.put(f"/api/tasks/{name}/middleware/test.other", {"enabled": True}) + chain_names = {mw.name for mw in queue._get_middleware_chain(name)} + assert "test.other" in chain_names + + +def test_put_task_middleware_rejects_unknown_middleware( + dashboard: tuple[AuthedClient, Queue], +) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.put(f"/api/tasks/{name}/middleware/not.a.real.mw", {"enabled": False}) + assert exc_info.value.code == 404 + + +def test_put_task_middleware_rejects_bad_body( + dashboard: tuple[AuthedClient, Queue], +) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.put(f"/api/tasks/{name}/middleware/test.other", {"enabled": "yes"}) + assert exc_info.value.code == 400 + + +def test_delete_task_middleware_clears_all( + dashboard: tuple[AuthedClient, Queue], +) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + client.put(f"/api/tasks/{name}/middleware/test.other", {"enabled": False}) + client.put(f"/api/tasks/{name}/middleware/test.recording", {"enabled": False}) + assert queue._get_middleware_chain(name) == [] + result = client.delete(f"/api/tasks/{name}/middleware") + assert result == {"cleared": True} + assert len(queue._get_middleware_chain(name)) == 2 + + +# ── End-to-end: disabled middleware doesn't fire ───────────────────── + + +def test_disabled_middleware_does_not_fire( + queue: Queue, + middleware_pair: tuple[RecordingMiddleware, OtherMiddleware], + poll_until: Any, +) -> None: + rec, other = middleware_pair + alpha_name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + queue.disable_middleware_for_task(alpha_name, "test.other") + + thread = threading.Thread(target=queue.run_worker, daemon=True) + thread.start() + try: + queue.enqueue(alpha_name) + poll_until(lambda: alpha_name in rec.invocations, message="task didn't run") + finally: + queue._inner.request_shutdown() + thread.join(timeout=5) + + assert alpha_name in rec.invocations # global fired + assert alpha_name not in other.invocations # disabled for this task diff --git a/tests/dashboard/test_task_overrides.py b/tests/dashboard/test_task_overrides.py new file mode 100644 index 0000000..ba01e2f --- /dev/null +++ b/tests/dashboard/test_task_overrides.py @@ -0,0 +1,234 @@ +"""Tests for task & queue runtime overrides.""" + +from __future__ import annotations + +import threading +import urllib.error +from collections.abc import Generator +from http.server import ThreadingHTTPServer +from pathlib import Path + +import pytest + +from taskito import Queue +from taskito.dashboard import _make_handler +from taskito.dashboard._testing import AuthedClient, seed_admin_and_session +from taskito.dashboard.overrides_store import OverridesStore + + +@pytest.fixture +def queue(tmp_path: Path) -> Queue: + q = Queue(db_path=str(tmp_path / "overrides.db")) + + @q.task(queue="default", max_retries=3, timeout=300) + def send_email(to: str) -> str: + return to + + @q.task(queue="email", max_retries=5, rate_limit="100/m", max_concurrent=10) + def deliver(message: str) -> str: + return message + + return q + + +@pytest.fixture +def dashboard(queue: Queue) -> Generator[tuple[AuthedClient, Queue]]: + handler = _make_handler(queue) + server = ThreadingHTTPServer(("127.0.0.1", 0), handler) + threading.Thread(target=server.serve_forever, daemon=True).start() + session = seed_admin_and_session(queue) + client = AuthedClient(base=f"http://127.0.0.1:{server.server_address[1]}", session=session) + try: + yield client, queue + finally: + server.shutdown() + + +# ── Store ────────────────────────────────────────────────────────────── + + +def test_overrides_store_starts_empty(queue: Queue) -> None: + store = OverridesStore(queue) + assert store.list_tasks() == {} + assert store.list_queues() == {} + + +def test_set_task_override_persists(queue: Queue) -> None: + store = OverridesStore(queue) + override = store.set_task("foo", {"max_retries": 7, "rate_limit": "50/s"}) + assert override.max_retries == 7 + assert override.rate_limit == "50/s" + fetched = store.get_task("foo") + assert fetched is not None and fetched.max_retries == 7 + + +def test_set_task_override_validates(queue: Queue) -> None: + store = OverridesStore(queue) + with pytest.raises(ValueError, match="rate_limit"): + store.set_task("foo", {"rate_limit": "no-slash"}) + with pytest.raises(ValueError, match="max_concurrent"): + store.set_task("foo", {"max_concurrent": -1}) + with pytest.raises(ValueError, match="unknown task override"): + store.set_task("foo", {"not_a_field": 1}) + + +def test_set_task_override_merges_with_existing(queue: Queue) -> None: + store = OverridesStore(queue) + store.set_task("foo", {"max_retries": 7}) + store.set_task("foo", {"rate_limit": "50/s"}) + merged = store.get_task("foo") + assert merged is not None + assert merged.max_retries == 7 + assert merged.rate_limit == "50/s" + + +def test_set_task_override_clears_field_with_none(queue: Queue) -> None: + store = OverridesStore(queue) + store.set_task("foo", {"max_retries": 7, "rate_limit": "50/s"}) + store.set_task("foo", {"max_retries": None}) + fetched = store.get_task("foo") + assert fetched is not None + assert fetched.max_retries is None + assert fetched.rate_limit == "50/s" + + +def test_clear_task_override(queue: Queue) -> None: + store = OverridesStore(queue) + store.set_task("foo", {"max_retries": 7}) + assert store.clear_task("foo") is True + assert store.clear_task("foo") is False + assert store.get_task("foo") is None + + +def test_queue_override_basics(queue: Queue) -> None: + store = OverridesStore(queue) + store.set_queue("default", {"max_concurrent": 5, "paused": True}) + fetched = store.get_queue("default") + assert fetched is not None + assert fetched.max_concurrent == 5 + assert fetched.paused is True + + +def test_apply_task_overrides_mutates_configs(queue: Queue) -> None: + """Mutating the in-memory PyTaskConfig is what makes overrides reach the + Rust scheduler at worker start.""" + store = OverridesStore(queue) + send_email = next(c for c in queue._task_configs if "send_email" in c.name) + store.set_task(send_email.name, {"max_retries": 99, "rate_limit": "1/s"}) + store.apply_task_overrides(queue._task_configs) + assert send_email.max_retries == 99 + assert send_email.rate_limit == "1/s" + + +def test_apply_task_overrides_reports_paused(queue: Queue) -> None: + store = OverridesStore(queue) + send_email = next(c for c in queue._task_configs if "send_email" in c.name) + store.set_task(send_email.name, {"paused": True}) + paused = store.apply_task_overrides(queue._task_configs) + assert send_email.name in paused + + +def test_apply_queue_overrides_merges(queue: Queue) -> None: + store = OverridesStore(queue) + queue.set_queue_concurrency("email", 10) # configured-from-Python + store.set_queue("email", {"rate_limit": "200/m"}) + merged = store.apply_queue_overrides(queue._queue_configs) + assert merged["email"]["max_concurrent"] == 10 # decorator-set survives + assert merged["email"]["rate_limit"] == "200/m" # override wins + + +# ── Queue.registered_tasks() ────────────────────────────────────────── + + +def test_registered_tasks_lists_defaults_and_overrides(queue: Queue) -> None: + tasks = queue.registered_tasks() + assert len(tasks) == 2 + by_name = {t["name"]: t for t in tasks} + deliver = next(t for n, t in by_name.items() if "deliver" in n) + assert deliver["defaults"]["rate_limit"] == "100/m" + assert deliver["defaults"]["max_retries"] == 5 + assert deliver["override"] is None + assert deliver["effective"]["rate_limit"] == "100/m" + + +def test_registered_tasks_reflects_override(queue: Queue) -> None: + send_email = next(t for t in queue.registered_tasks() if "send_email" in t["name"]) + queue.set_task_override(send_email["name"], max_retries=99) + fresh = next(t for t in queue.registered_tasks() if t["name"] == send_email["name"]) + assert fresh["override"] == {"max_retries": 99} + assert fresh["effective"]["max_retries"] == 99 + assert fresh["defaults"]["max_retries"] == 3 # original decorator value + + +# ── HTTP endpoints ──────────────────────────────────────────────────── + + +def test_list_tasks_endpoint(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + tasks = client.get("/api/tasks") + assert len(tasks) == 2 + for entry in tasks: + assert "name" in entry and "defaults" in entry and "effective" in entry + + +def test_put_task_override(dashboard: tuple[AuthedClient, Queue]) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if "send_email" in c.name) + result = client.put( + f"/api/tasks/{name}/override", + {"max_retries": 7, "rate_limit": "50/s"}, + ) + assert result["max_retries"] == 7 + assert result["rate_limit"] == "50/s" + + fetched = client.get(f"/api/tasks/{name}/override") + assert fetched["max_retries"] == 7 + + +def test_put_task_override_rejects_unknown_field( + dashboard: tuple[AuthedClient, Queue], +) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if "send_email" in c.name) + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.put(f"/api/tasks/{name}/override", {"made_up": 1}) + assert exc_info.value.code == 400 + + +def test_delete_task_override(dashboard: tuple[AuthedClient, Queue]) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if "send_email" in c.name) + client.put(f"/api/tasks/{name}/override", {"max_retries": 7}) + assert client.delete(f"/api/tasks/{name}/override") == {"cleared": True} + assert client.delete(f"/api/tasks/{name}/override") == {"cleared": False} + + +def test_get_task_override_404_when_none(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.get("/api/tasks/nonexistent/override") + assert exc_info.value.code == 404 + + +def test_put_queue_override_pauses_queue(dashboard: tuple[AuthedClient, Queue]) -> None: + """Pausing via queue override must also update the live paused_queues + state so a running worker stops dequeueing immediately.""" + client, queue = dashboard + client.put("/api/queues/email/override", {"paused": True}) + assert "email" in queue.paused_queues() + client.put("/api/queues/email/override", {"paused": False}) + assert "email" not in queue.paused_queues() + + +def test_list_queues_endpoint(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + queues = client.get("/api/queues") + names = {q["name"] for q in queues} + assert {"default", "email"} <= names + + +def test_put_queue_override_validates(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.put("/api/queues/default/override", {"max_concurrent": -1}) + assert exc_info.value.code == 400