From 87f170bc88d1a41d42a2947da49956ebfbb78347 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Sun, 17 May 2026 07:00:17 +0530 Subject: [PATCH 1/2] feat(middleware): toggle middleware per task from dashboard MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TaskMiddleware now carries a stable ``name`` (defaulting to the fully- qualified class path so it survives restarts; overridable on a subclass). Per-task disable lists live under ``middleware:disabled:`` in the dashboard settings store, consulted by ``_get_middleware_chain`` at every invocation — turning a middleware off for one task takes effect on the next job, no worker restart needed. Queue gains ``list_middleware()``, ``disable_middleware_for_task``, ``enable_middleware_for_task``, ``clear_middleware_disables``. New endpoints: GET /api/middleware (discovery), GET/DELETE /api/tasks/{task}/middleware, PUT /api/tasks/{task}/middleware/{mw_name}. --- .../tasks/components/task-override-form.tsx | 97 +++++++- py_src/taskito/app.py | 2 + .../taskito/dashboard/handlers/middleware.py | 62 +++++ py_src/taskito/dashboard/middleware_store.py | 88 +++++++ py_src/taskito/dashboard/routes.py | 17 ++ py_src/taskito/dashboard/server.py | 12 + py_src/taskito/middleware.py | 8 + py_src/taskito/mixins/__init__.py | 2 + py_src/taskito/mixins/decorators.py | 14 +- py_src/taskito/mixins/middleware_admin.py | 70 ++++++ tests/dashboard/test_middleware_toggles.py | 234 ++++++++++++++++++ 11 files changed, 592 insertions(+), 14 deletions(-) create mode 100644 py_src/taskito/dashboard/handlers/middleware.py create mode 100644 py_src/taskito/dashboard/middleware_store.py create mode 100644 py_src/taskito/mixins/middleware_admin.py create mode 100644 tests/dashboard/test_middleware_toggles.py diff --git a/dashboard/src/features/tasks/components/task-override-form.tsx b/dashboard/src/features/tasks/components/task-override-form.tsx index 9767acb..5e962b5 100644 --- a/dashboard/src/features/tasks/components/task-override-form.tsx +++ b/dashboard/src/features/tasks/components/task-override-form.tsx @@ -1,8 +1,9 @@ import { Save, Trash2 } from "lucide-react"; import { type FormEvent, useState } from "react"; -import { Button, Input } from "@/components/ui"; +import { Button, Input, Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui"; import { useClearTaskOverride, useSetTaskOverride } from "../hooks"; import type { TaskEntry, TaskOverridePatch } from "../types"; +import { MiddlewareToggles } from "./middleware-toggles"; interface Props { task: TaskEntry; @@ -58,15 +59,89 @@ export function TaskOverrideForm({ task, onDone }: Props) { } return ( -
+

{task.name}

Queue · {task.queue}

-

- Overrides apply on the next worker restart; pausing takes effect immediately. -

+ + + Overrides + Middleware + + + clearOverride.mutate(task.name, { onSuccess: () => onDone?.() })} + /> + + + + + +
+ ); +} + +interface OverrideFormProps { + task: TaskEntry; + onSubmit: (e: FormEvent) => void; + rateLimit: string; + setRateLimit: (v: string) => void; + maxConcurrent: string; + setMaxConcurrent: (v: string) => void; + maxRetries: string; + setMaxRetries: (v: string) => void; + timeoutValue: string; + setTimeoutValue: (v: string) => void; + priority: string; + setPriority: (v: string) => void; + paused: boolean; + setPaused: (v: boolean) => void; + saving: boolean; + clearing: boolean; + onClear: () => void; +} +function OverrideForm({ + task, + onSubmit, + rateLimit, + setRateLimit, + maxConcurrent, + setMaxConcurrent, + maxRetries, + setMaxRetries, + timeoutValue, + setTimeoutValue, + priority, + setPriority, + paused, + setPaused, + saving, + clearing, + onClear, +}: OverrideFormProps) { + return ( + +

+ Overrides apply on the next worker restart; pausing takes effect immediately. +

- -
-
diff --git a/py_src/taskito/app.py b/py_src/taskito/app.py index 9fe1e0c..c8273c7 100644 --- a/py_src/taskito/app.py +++ b/py_src/taskito/app.py @@ -38,6 +38,7 @@ QueueInspectionMixin, QueueLifecycleMixin, QueueLockMixin, + QueueMiddlewareAdminMixin, QueueOperationsMixin, QueueOverridesMixin, QueuePredicateMixin, @@ -84,6 +85,7 @@ class Queue( QueueInspectionMixin, QueueOperationsMixin, QueueLockMixin, + QueueMiddlewareAdminMixin, QueueOverridesMixin, QueueSettingsMixin, QueueWorkflowMixin, diff --git a/py_src/taskito/dashboard/handlers/middleware.py b/py_src/taskito/dashboard/handlers/middleware.py new file mode 100644 index 0000000..e0fd85b --- /dev/null +++ b/py_src/taskito/dashboard/handlers/middleware.py @@ -0,0 +1,62 @@ +"""Middleware discovery + per-task enable/disable endpoints.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from taskito.dashboard.errors import _BadRequest, _NotFound + +if TYPE_CHECKING: + from taskito.app import Queue + + +def handle_list_middleware(queue: Queue, _qs: dict) -> list[dict[str, Any]]: + """Return every registered middleware with its scopes.""" + return queue.list_middleware() + + +def handle_get_task_middleware(queue: Queue, _qs: dict, task_name: str) -> dict[str, Any]: + """Return the middleware chain that fires for ``task_name`` with each + entry's enabled/disabled state.""" + chain = queue._get_middleware_chain(task_name) + disabled = set(queue.get_disabled_middleware_for(task_name)) + # Build the full would-fire chain INCLUDING disabled entries so the UI + # can render every toggle. + base_chain = queue._global_middleware + queue._task_middleware.get(task_name, []) + entries: list[dict[str, Any]] = [] + chain_names = {getattr(mw, "name", "") for mw in chain} + for mw in base_chain: + name = getattr(mw, "name", "") or f"{type(mw).__module__}.{type(mw).__qualname__}" + entries.append( + { + "name": name, + "class_path": f"{type(mw).__module__}.{type(mw).__qualname__}", + "disabled": name in disabled, + "effective": name in chain_names, + } + ) + return {"task": task_name, "middleware": entries} + + +def handle_put_task_middleware(queue: Queue, body: dict, ids: tuple[str, str]) -> dict[str, Any]: + task_name, mw_name = ids + if not isinstance(body, dict) or "enabled" not in body: + raise _BadRequest('body must include {"enabled": bool}') + if not isinstance(body["enabled"], bool): + raise _BadRequest("'enabled' must be a boolean") + # Confirm the middleware exists in the relevant chain so a typo doesn't + # silently write a no-op disable entry. + base_chain = queue._global_middleware + queue._task_middleware.get(task_name, []) + names = {getattr(mw, "name", "") for mw in base_chain} + if mw_name not in names: + raise _NotFound(f"middleware '{mw_name}' is not registered on task '{task_name}'") + if body["enabled"]: + new = queue.enable_middleware_for_task(task_name, mw_name) + else: + new = queue.disable_middleware_for_task(task_name, mw_name) + return {"task": task_name, "disabled": new} + + +def handle_delete_task_middleware(queue: Queue, task_name: str) -> dict[str, bool]: + """Clear ALL disables for a task — every middleware fires again.""" + return {"cleared": queue.clear_middleware_disables(task_name)} diff --git a/py_src/taskito/dashboard/middleware_store.py b/py_src/taskito/dashboard/middleware_store.py new file mode 100644 index 0000000..0c2554b --- /dev/null +++ b/py_src/taskito/dashboard/middleware_store.py @@ -0,0 +1,88 @@ +"""Per-task middleware disable list. + +Operators turn individual middlewares off for individual tasks from the +dashboard. The disable list is persisted under +``middleware:disabled:`` as a JSON array of middleware names, +read by :meth:`~taskito.mixins.decorators.QueueDecoratorMixin._get_middleware_chain` +at every task invocation so changes take effect immediately on the next +job without a worker restart. +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from taskito.app import Queue + + +DISABLE_PREFIX = "middleware:disabled:" + +logger = logging.getLogger("taskito.dashboard.middleware") + + +def _parse(raw: str | None) -> list[str]: + if not raw: + return [] + try: + data = json.loads(raw) + except json.JSONDecodeError: + logger.warning("middleware disable list is not valid JSON; treating as empty") + return [] + if not isinstance(data, list): + return [] + return [str(x) for x in data if isinstance(x, str)] + + +class MiddlewareDisableStore: + """List/set/clear per-task middleware disables.""" + + def __init__(self, queue: Queue) -> None: + self._queue = queue + + def _key(self, task_name: str) -> str: + return DISABLE_PREFIX + task_name + + def list_all(self) -> dict[str, list[str]]: + """Return ``{task_name: [disabled_mw_name, ...]}`` for every task that + has at least one disabled middleware.""" + out: dict[str, list[str]] = {} + for key, raw in self._queue.list_settings().items(): + if not key.startswith(DISABLE_PREFIX): + continue + task_name = key[len(DISABLE_PREFIX) :] + names = _parse(raw) + if names: + out[task_name] = names + return out + + def get_for(self, task_name: str) -> list[str]: + return _parse(self._queue.get_setting(self._key(task_name))) + + def is_disabled(self, task_name: str, mw_name: str) -> bool: + return mw_name in self.get_for(task_name) + + def set_disabled(self, task_name: str, mw_name: str, disabled: bool) -> list[str]: + """Flip a middleware on/off for a task and return the new disable list.""" + if not task_name: + raise ValueError("task_name must not be empty") + if not mw_name: + raise ValueError("mw_name must not be empty") + current = self.get_for(task_name) + if disabled: + if mw_name not in current: + current.append(mw_name) + else: + current = [n for n in current if n != mw_name] + if current: + self._queue.set_setting( + self._key(task_name), json.dumps(current, separators=(",", ":")) + ) + else: + self._queue.delete_setting(self._key(task_name)) + return current + + def clear_for(self, task_name: str) -> bool: + return self._queue.delete_setting(self._key(task_name)) diff --git a/py_src/taskito/dashboard/routes.py b/py_src/taskito/dashboard/routes.py index cb39515..a4ab793 100644 --- a/py_src/taskito/dashboard/routes.py +++ b/py_src/taskito/dashboard/routes.py @@ -38,6 +38,12 @@ ) from taskito.dashboard.handlers.logs import _handle_logs from taskito.dashboard.handlers.metrics import _handle_metrics, _handle_metrics_timeseries +from taskito.dashboard.handlers.middleware import ( + handle_delete_task_middleware, + handle_get_task_middleware, + handle_list_middleware, + handle_put_task_middleware, +) from taskito.dashboard.handlers.overrides import ( handle_delete_queue_override, handle_delete_task_override, @@ -117,6 +123,7 @@ "/api/event-types": handle_list_event_types, "/api/tasks": handle_list_tasks, "/api/queues": handle_list_queues, + "/api/middleware": handle_list_middleware, } # ── Parameterized GET routes: regex → handler(queue, qs, captured_id) ── @@ -138,6 +145,7 @@ (re.compile(r"^/api/webhooks/([^/]+)$"), handle_get_webhook), (re.compile(r"^/api/tasks/([^/]+)/override$"), handle_get_task_override), (re.compile(r"^/api/queues/([^/]+)/override$"), handle_get_queue_override), + (re.compile(r"^/api/tasks/([^/]+)/middleware$"), handle_get_task_middleware), ] # GET routes with 2 captured groups (handler signature: queue, qs, (g1, g2)) @@ -212,12 +220,21 @@ (re.compile(r"^/api/queues/([^/]+)/override$"), handle_put_queue_override), ] +# PUT routes with 2 captured groups (handler signature: queue, body, (g1, g2)) +PUT_PARAM2_ROUTES: list[tuple[re.Pattern, Any]] = [ + ( + re.compile(r"^/api/tasks/([^/]+)/middleware/([^/]+)$"), + handle_put_task_middleware, + ), +] + # ── Parameterized DELETE routes: regex → handler(queue, captured_id) ── DELETE_PARAM_ROUTES: list[tuple[re.Pattern, Any]] = [ (re.compile(r"^/api/settings/(.+)$"), _handle_delete_setting), (re.compile(r"^/api/webhooks/([^/]+)$"), handle_delete_webhook), (re.compile(r"^/api/tasks/([^/]+)/override$"), handle_delete_task_override), (re.compile(r"^/api/queues/([^/]+)/override$"), handle_delete_queue_override), + (re.compile(r"^/api/tasks/([^/]+)/middleware$"), handle_delete_task_middleware), ] diff --git a/py_src/taskito/dashboard/server.py b/py_src/taskito/dashboard/server.py index a4945f5..2aa1753 100644 --- a/py_src/taskito/dashboard/server.py +++ b/py_src/taskito/dashboard/server.py @@ -43,6 +43,7 @@ POST_PARAM_ROUTES, POST_ROUTES, PUBLIC_PATHS, + PUT_PARAM2_ROUTES, PUT_PARAM_ROUTES, is_csrf_exempt, is_state_changing_method, @@ -292,6 +293,17 @@ def _handle_put(self) -> None: param_handler, lambda h, m=m, body=body: h(queue, body, m.group(1)) ) return + for pattern, param_handler in PUT_PARAM2_ROUTES: + m = pattern.match(path) + if m: + body = self._read_json_body() + if body is None: + return + self._dispatch_with_handler( + param_handler, + lambda h, m=m, body=body: h(queue, body, (m.group(1), m.group(2))), + ) + return self._json_response({"error": "Not found"}, status=404) def _handle_delete(self) -> None: diff --git a/py_src/taskito/middleware.py b/py_src/taskito/middleware.py index 8650641..077ff33 100644 --- a/py_src/taskito/middleware.py +++ b/py_src/taskito/middleware.py @@ -55,12 +55,20 @@ def after(self, ctx, result, error): print(f"Finished {ctx.task_name}: {status}") """ + #: Stable identifier used to refer to this middleware from the dashboard + #: when toggling it on/off per task. Defaults to the class' fully-qualified + #: name so it survives restarts. Override on a subclass to pin a + #: shorter / more user-facing name. + name: str = "" + def __init__( self, *, predicate: Predicate | Callable[..., Any] | None = None, ) -> None: self._predicate = coerce_predicate(predicate) + if not type(self).name: + type(self).name = f"{type(self).__module__}.{type(self).__qualname__}" def _should_apply(self, ctx: JobContext | None, task_name: str = "") -> bool: """Decide whether this middleware's hooks should fire for ``ctx``. diff --git a/py_src/taskito/mixins/__init__.py b/py_src/taskito/mixins/__init__.py index c8dfe41..2d54c05 100644 --- a/py_src/taskito/mixins/__init__.py +++ b/py_src/taskito/mixins/__init__.py @@ -5,6 +5,7 @@ from taskito.mixins.inspection import QueueInspectionMixin from taskito.mixins.lifecycle import QueueLifecycleMixin from taskito.mixins.locks import QueueLockMixin +from taskito.mixins.middleware_admin import QueueMiddlewareAdminMixin from taskito.mixins.operations import QueueOperationsMixin from taskito.mixins.overrides import QueueOverridesMixin from taskito.mixins.predicates import QueuePredicateMixin @@ -18,6 +19,7 @@ "QueueInspectionMixin", "QueueLifecycleMixin", "QueueLockMixin", + "QueueMiddlewareAdminMixin", "QueueOperationsMixin", "QueueOverridesMixin", "QueuePredicateMixin", diff --git a/py_src/taskito/mixins/decorators.py b/py_src/taskito/mixins/decorators.py index 671e9a2..e33c940 100644 --- a/py_src/taskito/mixins/decorators.py +++ b/py_src/taskito/mixins/decorators.py @@ -16,6 +16,7 @@ from taskito._taskito import PyTaskConfig from taskito.async_support.helpers import run_maybe_async from taskito.context import _clear_context, current_job +from taskito.dashboard.middleware_store import MiddlewareDisableStore from taskito.events import EventType from taskito.exceptions import TaskCancelledError from taskito.inject import Inject, _InjectAlias @@ -111,9 +112,18 @@ class QueueDecoratorMixin: _apply_dispatch_predicate: Callable[..., None] def _get_middleware_chain(self, task_name: str) -> list[TaskMiddleware]: - """Get the combined global + per-task middleware list.""" + """Get the combined global + per-task middleware list, minus any + middleware the operator has disabled for this task from the dashboard.""" per_task = self._task_middleware.get(task_name, []) - return self._global_middleware + per_task + chain = self._global_middleware + per_task + try: + disabled = MiddlewareDisableStore(self).get_for(task_name) # type: ignore[arg-type] + except Exception: # pragma: no cover - storage read failure is non-fatal + disabled = [] + if not disabled: + return chain + disabled_set = set(disabled) + return [mw for mw in chain if getattr(mw, "name", "") not in disabled_set] def _wrap_task( self, fn: Callable, task_name: str, soft_timeout: float | None = None diff --git a/py_src/taskito/mixins/middleware_admin.py b/py_src/taskito/mixins/middleware_admin.py new file mode 100644 index 0000000..dd9af80 --- /dev/null +++ b/py_src/taskito/mixins/middleware_admin.py @@ -0,0 +1,70 @@ +"""Middleware discovery and per-task disable management on :class:`Queue`.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from taskito.dashboard.middleware_store import MiddlewareDisableStore + +if TYPE_CHECKING: + from taskito.middleware import TaskMiddleware + + +class QueueMiddlewareAdminMixin: + """Discovery + per-task enable/disable for registered middlewares.""" + + _global_middleware: list[TaskMiddleware] + _task_middleware: dict[str, list[TaskMiddleware]] + + # ── Discovery ────────────────────────────────────────────────── + + def list_middleware(self) -> list[dict[str, Any]]: + """Return every registered middleware (global + per-task) with its + name, source ("global" or task name), and Python class path. The + ``name`` is the value the disable list keys on.""" + seen: dict[str, dict[str, Any]] = {} + for mw in self._global_middleware: + name = getattr(mw, "name", "") or f"{type(mw).__module__}.{type(mw).__qualname__}" + seen.setdefault( + name, + { + "name": name, + "class_path": f"{type(mw).__module__}.{type(mw).__qualname__}", + "scopes": [], + }, + )["scopes"].append({"kind": "global"}) + for task_name, mws in self._task_middleware.items(): + for mw in mws: + name = getattr(mw, "name", "") or f"{type(mw).__module__}.{type(mw).__qualname__}" + entry = seen.setdefault( + name, + { + "name": name, + "class_path": f"{type(mw).__module__}.{type(mw).__qualname__}", + "scopes": [], + }, + ) + entry["scopes"].append({"kind": "task", "task": task_name}) + return sorted(seen.values(), key=lambda x: x["name"]) + + # ── Disable management ───────────────────────────────────────── + + def list_middleware_disables(self) -> dict[str, list[str]]: + """Return every task that has at least one disabled middleware.""" + return MiddlewareDisableStore(self).list_all() # type: ignore[arg-type] + + def get_disabled_middleware_for(self, task_name: str) -> list[str]: + return MiddlewareDisableStore(self).get_for(task_name) # type: ignore[arg-type] + + def disable_middleware_for_task(self, task_name: str, mw_name: str) -> list[str]: + return MiddlewareDisableStore(self).set_disabled( # type: ignore[arg-type] + task_name, mw_name, disabled=True + ) + + def enable_middleware_for_task(self, task_name: str, mw_name: str) -> list[str]: + return MiddlewareDisableStore(self).set_disabled( # type: ignore[arg-type] + task_name, mw_name, disabled=False + ) + + def clear_middleware_disables(self, task_name: str) -> bool: + return MiddlewareDisableStore(self).clear_for(task_name) # type: ignore[arg-type] diff --git a/tests/dashboard/test_middleware_toggles.py b/tests/dashboard/test_middleware_toggles.py new file mode 100644 index 0000000..6713f09 --- /dev/null +++ b/tests/dashboard/test_middleware_toggles.py @@ -0,0 +1,234 @@ +"""Tests for per-task middleware enable/disable from the dashboard.""" + +from __future__ import annotations + +import threading +import urllib.error +from collections.abc import Generator +from http.server import ThreadingHTTPServer +from pathlib import Path +from typing import Any + +import pytest + +from taskito import Queue +from taskito.context import JobContext +from taskito.dashboard import _make_handler +from taskito.dashboard._testing import AuthedClient, seed_admin_and_session +from taskito.dashboard.middleware_store import MiddlewareDisableStore +from taskito.middleware import TaskMiddleware + + +class RecordingMiddleware(TaskMiddleware): + """Captures every ``before`` invocation so the test can assert which + tasks the middleware fired for.""" + + name = "test.recording" + + def __init__(self) -> None: + super().__init__() + self.invocations: list[str] = [] + + def before(self, ctx: JobContext) -> None: + self.invocations.append(ctx.task_name) + + +class OtherMiddleware(TaskMiddleware): + name = "test.other" + + def __init__(self) -> None: + super().__init__() + self.invocations: list[str] = [] + + def before(self, ctx: JobContext) -> None: + self.invocations.append(ctx.task_name) + + +@pytest.fixture +def middleware_pair() -> tuple[RecordingMiddleware, OtherMiddleware]: + return RecordingMiddleware(), OtherMiddleware() + + +@pytest.fixture +def queue(tmp_path: Path, middleware_pair: tuple[RecordingMiddleware, OtherMiddleware]) -> Queue: + rec, other = middleware_pair + q = Queue(db_path=str(tmp_path / "mw.db"), middleware=[rec, other]) + + @q.task() + def alpha() -> str: + return "a" + + @q.task() + def beta() -> str: + return "b" + + return q + + +@pytest.fixture +def dashboard(queue: Queue) -> Generator[tuple[AuthedClient, Queue]]: + handler = _make_handler(queue) + server = ThreadingHTTPServer(("127.0.0.1", 0), handler) + threading.Thread(target=server.serve_forever, daemon=True).start() + session = seed_admin_and_session(queue) + client = AuthedClient(base=f"http://127.0.0.1:{server.server_address[1]}", session=session) + try: + yield client, queue + finally: + server.shutdown() + + +# ── Store ────────────────────────────────────────────────────────────── + + +def test_store_starts_empty(queue: Queue) -> None: + store = MiddlewareDisableStore(queue) + assert store.list_all() == {} + assert store.get_for("alpha") == [] + + +def test_set_disabled_adds_and_removes(queue: Queue) -> None: + store = MiddlewareDisableStore(queue) + store.set_disabled("alpha", "test.other", True) + assert store.get_for("alpha") == ["test.other"] + # Idempotent — same disable twice still has just one entry. + store.set_disabled("alpha", "test.other", True) + assert store.get_for("alpha") == ["test.other"] + # Re-enable clears just that one. + store.set_disabled("alpha", "test.other", False) + assert store.get_for("alpha") == [] + + +def test_clear_for_drops_setting_key(queue: Queue) -> None: + store = MiddlewareDisableStore(queue) + store.set_disabled("alpha", "test.other", True) + assert store.clear_for("alpha") is True + assert store.clear_for("alpha") is False + assert store.get_for("alpha") == [] + + +# ── Wiring into the middleware chain ────────────────────────────────── + + +def test_chain_skips_disabled_middleware(queue: Queue) -> None: + """``_get_middleware_chain`` returns a chain that respects the disable + list at lookup time — no worker restart required.""" + full = queue._get_middleware_chain("alpha") + assert {mw.name for mw in full} == {"test.recording", "test.other"} + queue.disable_middleware_for_task("alpha", "test.other") + filtered = queue._get_middleware_chain("alpha") + assert {mw.name for mw in filtered} == {"test.recording"} + # Other tasks unaffected. + assert {mw.name for mw in queue._get_middleware_chain("beta")} == { + "test.recording", + "test.other", + } + + +def test_clear_re_enables_all(queue: Queue) -> None: + queue.disable_middleware_for_task("alpha", "test.other") + queue.disable_middleware_for_task("alpha", "test.recording") + assert queue._get_middleware_chain("alpha") == [] + queue.clear_middleware_disables("alpha") + assert len(queue._get_middleware_chain("alpha")) == 2 + + +# ── Discovery ───────────────────────────────────────────────────────── + + +def test_list_middleware_reports_globals(queue: Queue) -> None: + items = queue.list_middleware() + names = {item["name"] for item in items} + assert {"test.recording", "test.other"} <= names + for entry in items: + assert any(scope["kind"] == "global" for scope in entry["scopes"]) + + +# ── HTTP endpoints ──────────────────────────────────────────────────── + + +def test_list_middleware_endpoint(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + items = client.get("/api/middleware") + names = {item["name"] for item in items} + assert {"test.recording", "test.other"} <= names + + +def test_get_task_middleware_endpoint(dashboard: tuple[AuthedClient, Queue]) -> None: + client, _ = dashboard + result = client.get("/api/tasks/alpha/middleware") + by_name = {entry["name"]: entry for entry in result["middleware"]} + assert by_name["test.recording"]["disabled"] is False + assert by_name["test.recording"]["effective"] is True + + +def test_put_task_middleware_disables(dashboard: tuple[AuthedClient, Queue]) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + result = client.put(f"/api/tasks/{name}/middleware/test.other", {"enabled": False}) + assert "test.other" in result["disabled"] + # Reflected in the chain. + chain_names = {mw.name for mw in queue._get_middleware_chain(name)} + assert "test.other" not in chain_names + # Re-enabling clears it. + client.put(f"/api/tasks/{name}/middleware/test.other", {"enabled": True}) + chain_names = {mw.name for mw in queue._get_middleware_chain(name)} + assert "test.other" in chain_names + + +def test_put_task_middleware_rejects_unknown_middleware( + dashboard: tuple[AuthedClient, Queue], +) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.put(f"/api/tasks/{name}/middleware/not.a.real.mw", {"enabled": False}) + assert exc_info.value.code == 404 + + +def test_put_task_middleware_rejects_bad_body( + dashboard: tuple[AuthedClient, Queue], +) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + with pytest.raises(urllib.error.HTTPError) as exc_info: + client.put(f"/api/tasks/{name}/middleware/test.other", {"enabled": "yes"}) + assert exc_info.value.code == 400 + + +def test_delete_task_middleware_clears_all( + dashboard: tuple[AuthedClient, Queue], +) -> None: + client, queue = dashboard + name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + client.put(f"/api/tasks/{name}/middleware/test.other", {"enabled": False}) + client.put(f"/api/tasks/{name}/middleware/test.recording", {"enabled": False}) + assert queue._get_middleware_chain(name) == [] + result = client.delete(f"/api/tasks/{name}/middleware") + assert result == {"cleared": True} + assert len(queue._get_middleware_chain(name)) == 2 + + +# ── End-to-end: disabled middleware doesn't fire ───────────────────── + + +def test_disabled_middleware_does_not_fire( + queue: Queue, + middleware_pair: tuple[RecordingMiddleware, OtherMiddleware], + poll_until: Any, +) -> None: + rec, other = middleware_pair + alpha_name = next(c.name for c in queue._task_configs if c.name.endswith("alpha")) + queue.disable_middleware_for_task(alpha_name, "test.other") + + thread = threading.Thread(target=queue.run_worker, daemon=True) + thread.start() + try: + queue.enqueue(alpha_name) + poll_until(lambda: alpha_name in rec.invocations, message="task didn't run") + finally: + queue._inner.request_shutdown() + thread.join(timeout=5) + + assert alpha_name in rec.invocations # global fired + assert alpha_name not in other.invocations # disabled for this task From ed0a670fc9e2bc7e389c973479b99d84921a57ce Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Sun, 17 May 2026 07:00:33 +0530 Subject: [PATCH 2/2] feat(dashboard): add Middleware tab to task editor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The task override side-sheet now has two tabs: Overrides (existing form) and Middleware. The Middleware tab lists every middleware that fires for the selected task with an enabled/disabled toggle each. Toggling a middleware off takes effect on the next job — no worker restart required. --- .../tasks/components/middleware-toggles.tsx | 99 +++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 dashboard/src/features/tasks/components/middleware-toggles.tsx diff --git a/dashboard/src/features/tasks/components/middleware-toggles.tsx b/dashboard/src/features/tasks/components/middleware-toggles.tsx new file mode 100644 index 0000000..c61241d --- /dev/null +++ b/dashboard/src/features/tasks/components/middleware-toggles.tsx @@ -0,0 +1,99 @@ +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { Power } from "lucide-react"; +import { toast } from "sonner"; +import { ErrorState, Skeleton } from "@/components/ui"; +import { api } from "@/lib/api-client"; + +interface TaskMiddlewareEntry { + name: string; + class_path: string; + disabled: boolean; + effective: boolean; +} + +interface TaskMiddlewareResponse { + task: string; + middleware: TaskMiddlewareEntry[]; +} + +interface Props { + taskName: string; +} + +const queryKey = (task: string) => ["tasks", task, "middleware"] as const; + +export function MiddlewareToggles({ taskName }: Props) { + const qc = useQueryClient(); + const query = useQuery({ + queryKey: queryKey(taskName), + queryFn: ({ signal }) => + api.get(`/api/tasks/${encodeURIComponent(taskName)}/middleware`, { + signal, + }), + }); + + const mutation = useMutation({ + mutationFn: ({ mwName, enabled }: { mwName: string; enabled: boolean }) => + api.put( + `/api/tasks/${encodeURIComponent(taskName)}/middleware/${encodeURIComponent(mwName)}`, + { enabled }, + ), + onSuccess: async () => { + await qc.invalidateQueries({ queryKey: queryKey(taskName) }); + }, + onError: () => toast.error("Failed to update middleware"), + }); + + if (query.isLoading) { + return ; + } + if (query.error) { + return ( + + ); + } + const entries = query.data?.middleware ?? []; + if (entries.length === 0) { + return ( +
+ No middleware registered for this task. +
+ ); + } + + return ( +
    + {entries.map((entry) => { + const enabled = !entry.disabled; + return ( +
  • +
    +
    {entry.name}
    +
    {entry.class_path}
    +
    + +
  • + ); + })} +
+ ); +}