Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions backend/backend/worker_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,10 @@ class _WorkerDispatchCelery(Celery):
_explicit_broker: str | None = None

def connection_for_write(self, url=None, *args, **kwargs):
return super().connection_for_write(
url=url or self._explicit_broker, *args, **kwargs
)
return super().connection_for_write(url or self._explicit_broker, *args, **kwargs)

def connection_for_read(self, url=None, *args, **kwargs):
return super().connection_for_read(
url=url or self._explicit_broker, *args, **kwargs
)
return super().connection_for_read(url or self._explicit_broker, *args, **kwargs)


def get_worker_celery_app() -> Celery:
Expand Down
6 changes: 3 additions & 3 deletions backend/prompt_studio/prompt_studio_core_v2/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def test_task_status_url_registered(self):
assert "<str:task_id>" in str(url.pattern)

@patch("prompt_studio.prompt_studio_core_v2.views.AsyncResult", create=True)
def test_task_status_processing(self, MockAsyncResult):
def test_task_status_processing(self, mock_async_result):
"""Verify processing response for unfinished task."""
import inspect

Expand All @@ -370,7 +370,7 @@ def test_task_status_processing(self, MockAsyncResult):
assert '"processing"' in source

@patch("prompt_studio.prompt_studio_core_v2.views.AsyncResult", create=True)
def test_task_status_completed(self, MockAsyncResult):
def test_task_status_completed(self, mock_async_result):
"""Verify completed response structure."""
import inspect

Expand All @@ -382,7 +382,7 @@ def test_task_status_completed(self, MockAsyncResult):
assert "result.result" in source

@patch("prompt_studio.prompt_studio_core_v2.views.AsyncResult", create=True)
def test_task_status_failed(self, MockAsyncResult):
def test_task_status_failed(self, mock_async_result):
"""Verify failed response structure."""
import inspect

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,9 @@ def get_whisperer_params(
WhispererConfig.WAIT_TIMEOUT,
WhispererDefaults.WAIT_TIMEOUT,
),
WhispererConfig.WAIT_FOR_COMPLETION: WhispererDefaults.WAIT_FOR_COMPLETION,
WhispererConfig.WAIT_FOR_COMPLETION: (
WhispererDefaults.WAIT_FOR_COMPLETION
),
}
)
if params[WhispererConfig.MODE] == Modes.LOW_COST.value:
Expand Down
8 changes: 4 additions & 4 deletions unstract/sdk1/src/unstract/sdk1/execution/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class ExecutionDispatcher:
)
"""

def __init__(self, celery_app: Any = None) -> None:
def __init__(self, celery_app: object | None = None) -> None:
"""Initialize the dispatcher.

Args:
Expand Down Expand Up @@ -201,10 +201,10 @@ def dispatch_async(
def dispatch_with_callback(
self,
context: ExecutionContext,
on_success: Any = None,
on_error: Any = None,
on_success: object | None = None,
on_error: object | None = None,
task_id: str | None = None,
) -> Any:
) -> object:
"""Fire-and-forget dispatch with Celery link callbacks.

Sends the task to the executor queue and returns immediately.
Expand Down
10 changes: 5 additions & 5 deletions unstract/sdk1/tests/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class TestExecutionContext:
"""Tests for ExecutionContext serialization and validation."""

def _make_context(self, **overrides: Any) -> ExecutionContext:
def _make_context(self, **overrides: Any) -> ExecutionContext: # noqa: ANN401
"""Create a default ExecutionContext with optional overrides."""
defaults: dict[str, Any] = {
"executor_name": "legacy",
Expand Down Expand Up @@ -490,7 +490,7 @@ def _clean_registry(self: Self) -> None:
"""Ensure a clean registry for every test."""
ExecutorRegistry.clear()

def _make_context(self, **overrides: Any) -> ExecutionContext:
def _make_context(self, **overrides: Any) -> ExecutionContext: # noqa: ANN401
defaults: dict[str, Any] = {
"executor_name": "legacy",
"operation": "extract",
Expand Down Expand Up @@ -586,7 +586,7 @@ def execute(self, context: ExecutionContext) -> ExecutionResult:
class TestExecutionDispatcher:
"""Tests for ExecutionDispatcher (mocked Celery)."""

def _make_context(self, **overrides: Any) -> ExecutionContext:
def _make_context(self, **overrides: Any) -> ExecutionContext: # noqa: ANN401
defaults: dict[str, Any] = {
"executor_name": "legacy",
"operation": "extract",
Expand Down Expand Up @@ -917,7 +917,7 @@ def test_dispatch_with_callback_custom_task_id(
dispatcher = ExecutionDispatcher(celery_app=mock_app)
ctx = self._make_context()

result = dispatcher.dispatch_with_callback(ctx, task_id="pre-gen-id-123")
dispatcher.dispatch_with_callback(ctx, task_id="pre-gen-id-123")

call_kwargs = mock_app.send_task.call_args
assert call_kwargs[1]["task_id"] == "pre-gen-id-123"
Expand Down Expand Up @@ -974,7 +974,7 @@ def stream_log(
log: str,
level: LogLevel = LogLevel.INFO,
stage: str = "TOOL_RUN",
**kwargs: Any,
**kwargs: Any, # noqa: ANN401
) -> None:
_level_map = {
LogLevel.DEBUG: logging.DEBUG,
Expand Down
1 change: 0 additions & 1 deletion workers/file_processing/structure_tool_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ def _execute_structure_tool_impl(params: dict) -> dict:
"""
# ---- Unpack params ----
organization_id = params["organization_id"]
workflow_id = params.get("workflow_id", "")
execution_id = params.get("execution_id", "")
file_execution_id = params["file_execution_id"]
tool_instance_metadata = params["tool_instance_metadata"]
Expand Down