From 4c47d33c80b517b91f99d2693d4bb468fa126fd9 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 16:43:19 +0100 Subject: [PATCH 01/12] Adjust task handling in Worker to allow retrieving multiple locked tasks and processing them in batches --- .../database/management/commands/db_worker.py | 16 +++++++++------- django_tasks/base.py | 1 + 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/django_tasks/backends/database/management/commands/db_worker.py b/django_tasks/backends/database/management/commands/db_worker.py index 0d5bf61..30c8762 100644 --- a/django_tasks/backends/database/management/commands/db_worker.py +++ b/django_tasks/backends/database/management/commands/db_worker.py @@ -105,23 +105,25 @@ def run(self) -> None: # it be as efficient as possible. with exclusive_transaction(tasks.db): try: - task_result = tasks.get_locked() + task_results = tasks.get_locked(self.max_tasks) except OperationalError as e: # Ignore locked databases and keep trying. # It should unlock eventually. if "is locked" in e.args[0]: - task_result = None + task_results = None else: raise - if task_result is not None: + if task_results is not None and task_results.exists(): # "claim" the task, so it isn't run by another worker process - task_result.claim(self.worker_id) + for task_result in task_results: + task_result.claim(self.worker_id) - if task_result is not None: - self.run_task(task_result) + if task_results is not None and task_results.exists(): + for task_result in task_results: + self.run_task(task_result) - if self.batch and task_result is None: + if self.batch and (task_results is None or not task_results.exists()): # If we're running in "batch" mode, terminate the loop (and thus the worker) logger.info( "No more tasks to run for worker_id=%s - exiting gracefully.", diff --git a/django_tasks/base.py b/django_tasks/base.py index 9d5ab37..2dd0810 100644 --- a/django_tasks/base.py +++ b/django_tasks/base.py @@ -34,6 +34,7 @@ TASK_MIN_PRIORITY = -100 TASK_MAX_PRIORITY = 100 TASK_DEFAULT_PRIORITY = 0 +MAX_WORKERS = 1 TASK_REFRESH_ATTRS = { "errors", From d44e4551f1b3171055ba3dd3663886a046414727 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 16:43:36 +0100 Subject: [PATCH 02/12] Modify the get_locked method in DBTaskResultQuerySet to allow retrieval of multiple locked jobs at once. --- django_tasks/backends/database/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/django_tasks/backends/database/models.py b/django_tasks/backends/database/models.py index f2d44e4..0b3f122 100644 --- a/django_tasks/backends/database/models.py +++ b/django_tasks/backends/database/models.py @@ -1,13 +1,13 @@ import datetime import logging import uuid -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar import django from django.conf import settings from django.core.exceptions import SuspiciousOperation from django.db import models -from django.db.models import F, Q +from django.db.models import F, Q, QuerySet from django.db.models.constraints import CheckConstraint from django.utils import timezone from django.utils.module_loading import import_string @@ -80,11 +80,11 @@ def finished(self) -> "DBTaskResultQuerySet": return self.failed() | self.succeeded() @retry() - def get_locked(self) -> Optional["DBTaskResult"]: + def get_locked(self, size: int = 1) -> QuerySet["DBTaskResult"]: """ Get a job, locking the row and accounting for deadlocks. """ - return self.select_for_update(skip_locked=True).first() + return self.select_for_update(skip_locked=True)[:size] class DBTaskResult(GenericBase[P, T], models.Model): From 0f647bc60b02e25ab022c0f98a2f48b1720d45e1 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 16:43:49 +0100 Subject: [PATCH 03/12] Adjust the query number assertions in the DatabaseBackendWorkerTestCase and DatabaseTaskResultTestCase tests --- tests/tests/test_database_backend.py | 38 +++++++++++++++------------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/tests/tests/test_database_backend.py b/tests/tests/test_database_backend.py index 18136a3..e1cce98 100644 --- a/tests/tests/test_database_backend.py +++ b/tests/tests/test_database_backend.py @@ -559,7 +559,7 @@ def test_run_enqueued_task(self) -> None: self.assertEqual(result.status, TaskResultStatus.READY) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -582,7 +582,7 @@ def test_batch_processes_all_tasks(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 4) - with self.assertNumQueries(27 if connection.vendor == "mysql" else 23): + with self.assertNumQueries(27 if connection.vendor == "mysql" else 17): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -590,7 +590,7 @@ def test_batch_processes_all_tasks(self) -> None: self.assertEqual(DBTaskResult.objects.failed().count(), 1) def test_no_tasks(self) -> None: - with self.assertNumQueries(3): + with self.assertNumQueries(5): self.run_worker() def test_doesnt_process_different_queue(self) -> None: @@ -598,12 +598,12 @@ def test_doesnt_process_different_queue(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(3): + with self.assertNumQueries(5): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): self.run_worker(queue_name=result.task.queue_name) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -613,12 +613,12 @@ def test_process_all_queues(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(3): + with self.assertNumQueries(5): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): self.run_worker(queue_name="*") self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -627,7 +627,7 @@ def test_failing_task(self) -> None: result = test_tasks.failing_task_value_error.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -656,7 +656,7 @@ def test_complex_exception(self) -> None: result = test_tasks.complex_exception.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -701,12 +701,12 @@ def test_doesnt_process_different_backend(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(3): + with self.assertNumQueries(5): self.run_worker(backend_name="dummy") self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): self.run_worker(backend_name=result.backend) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -794,7 +794,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.count(), 1) self.assertEqual(DBTaskResult.objects.ready().count(), 0) - with self.assertNumQueries(3): + with self.assertNumQueries(5): self.run_worker() self.assertEqual(DBTaskResult.objects.count(), 1) @@ -805,7 +805,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -1055,7 +1055,7 @@ def test_locks_tasks_sqlite(self) -> None: result = test_tasks.noop_task.enqueue() with exclusive_transaction(): - locked_result = DBTaskResult.objects.get_locked() + locked_result = DBTaskResult.objects.get_locked().first() self.assertEqual(result.id, str(locked_result.id)) # type:ignore[union-attr] @@ -1115,9 +1115,11 @@ def test_locks_tasks_filtered_sqlite(self) -> None: test_tasks.noop_task.enqueue() with exclusive_transaction(): - locked_result = DBTaskResult.objects.filter( - priority=result.task.priority - ).get_locked() + locked_result = ( + DBTaskResult.objects.filter(priority=result.task.priority) + .get_locked() + .first() + ) self.assertEqual(result.id, str(locked_result.id)) @@ -1134,7 +1136,7 @@ def test_locks_tasks_filtered_sqlite(self) -> None: @exclusive_transaction() def test_lock_no_rows(self) -> None: self.assertEqual(DBTaskResult.objects.count(), 0) - self.assertIsNone(DBTaskResult.objects.all().get_locked()) + self.assertEqual(DBTaskResult.objects.all().get_locked().count(), 0) @skipIf(connection.vendor == "sqlite", "SQLite handles locks differently") def test_get_locked_with_locked_rows(self) -> None: From 59eee96f5209a48ddb5e2c77656a88d0db8fbb33 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 16:52:06 +0100 Subject: [PATCH 04/12] Optimize task handling in the Worker and adjust query assertions in DatabaseBackendWorkerTestCase --- .../database/management/commands/db_worker.py | 8 +++--- tests/tests/test_database_backend.py | 26 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/django_tasks/backends/database/management/commands/db_worker.py b/django_tasks/backends/database/management/commands/db_worker.py index 30c8762..c49b3da 100644 --- a/django_tasks/backends/database/management/commands/db_worker.py +++ b/django_tasks/backends/database/management/commands/db_worker.py @@ -105,7 +105,7 @@ def run(self) -> None: # it be as efficient as possible. with exclusive_transaction(tasks.db): try: - task_results = tasks.get_locked(self.max_tasks) + task_results = list(tasks.get_locked(self.max_tasks)) except OperationalError as e: # Ignore locked databases and keep trying. # It should unlock eventually. @@ -114,16 +114,16 @@ def run(self) -> None: else: raise - if task_results is not None and task_results.exists(): + if task_results is not None and len(task_results) > 0: # "claim" the task, so it isn't run by another worker process for task_result in task_results: task_result.claim(self.worker_id) - if task_results is not None and task_results.exists(): + if task_results is not None and len(task_results) > 0: for task_result in task_results: self.run_task(task_result) - if self.batch and (task_results is None or not task_results.exists()): + if self.batch and (task_results is None or len(task_results) == 0): # If we're running in "batch" mode, terminate the loop (and thus the worker) logger.info( "No more tasks to run for worker_id=%s - exiting gracefully.", diff --git a/tests/tests/test_database_backend.py b/tests/tests/test_database_backend.py index e1cce98..ead74d9 100644 --- a/tests/tests/test_database_backend.py +++ b/tests/tests/test_database_backend.py @@ -559,7 +559,7 @@ def test_run_enqueued_task(self) -> None: self.assertEqual(result.status, TaskResultStatus.READY) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -582,7 +582,7 @@ def test_batch_processes_all_tasks(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 4) - with self.assertNumQueries(27 if connection.vendor == "mysql" else 17): + with self.assertNumQueries(27 if connection.vendor == "mysql" else 14): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -590,7 +590,7 @@ def test_batch_processes_all_tasks(self) -> None: self.assertEqual(DBTaskResult.objects.failed().count(), 1) def test_no_tasks(self) -> None: - with self.assertNumQueries(5): + with self.assertNumQueries(3): self.run_worker() def test_doesnt_process_different_queue(self) -> None: @@ -598,12 +598,12 @@ def test_doesnt_process_different_queue(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(5): + with self.assertNumQueries(3): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): self.run_worker(queue_name=result.task.queue_name) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -613,12 +613,12 @@ def test_process_all_queues(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(5): + with self.assertNumQueries(3): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): self.run_worker(queue_name="*") self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -627,7 +627,7 @@ def test_failing_task(self) -> None: result = test_tasks.failing_task_value_error.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -656,7 +656,7 @@ def test_complex_exception(self) -> None: result = test_tasks.complex_exception.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -701,12 +701,12 @@ def test_doesnt_process_different_backend(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(5): + with self.assertNumQueries(3): self.run_worker(backend_name="dummy") self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): self.run_worker(backend_name=result.backend) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -794,7 +794,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.count(), 1) self.assertEqual(DBTaskResult.objects.ready().count(), 0) - with self.assertNumQueries(5): + with self.assertNumQueries(3): self.run_worker() self.assertEqual(DBTaskResult.objects.count(), 1) @@ -805,7 +805,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 11): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) From b9386879955f221646f3befc6bbfa23d664eda2b Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 16:55:40 +0100 Subject: [PATCH 05/12] Corrects task verification in the Worker's run method to handle multiple task results. --- django_tasks/backends/database/management/commands/db_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django_tasks/backends/database/management/commands/db_worker.py b/django_tasks/backends/database/management/commands/db_worker.py index c49b3da..c15cf63 100644 --- a/django_tasks/backends/database/management/commands/db_worker.py +++ b/django_tasks/backends/database/management/commands/db_worker.py @@ -145,7 +145,7 @@ def run(self) -> None: # If ctrl-c has just interrupted a task, self.running was cleared, # and we should not sleep, but rather exit immediately. - if self.running and not task_result: + if self.running and not task_results: # Wait before checking for another task time.sleep(self.interval) From cc282f2a21d1b61512d0749b98ac6d7cb2995949 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 17:08:11 +0100 Subject: [PATCH 06/12] =?UTF-8?q?A=C3=B1ade=20soporte=20para=20m=C3=BAltip?= =?UTF-8?q?les=20hilos=20en=20el=20Worker,=20permitiendo=20la=20ejecuci?= =?UTF-8?q?=C3=B3n=20concurrente=20de=20tareas.=20Se=20agrega=20un=20argum?= =?UTF-8?q?ento=20--max-workers=20para=20definir=20el=20n=C3=BAmero=20m?= =?UTF-8?q?=C3=A1ximo=20de=20hilos=20de=20trabajo.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../database/management/commands/db_worker.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/django_tasks/backends/database/management/commands/db_worker.py b/django_tasks/backends/database/management/commands/db_worker.py index c15cf63..e442308 100644 --- a/django_tasks/backends/database/management/commands/db_worker.py +++ b/django_tasks/backends/database/management/commands/db_worker.py @@ -4,6 +4,7 @@ import random import signal import sys +import threading import time from argparse import ArgumentParser, ArgumentTypeError, BooleanOptionalAction from types import FrameType @@ -19,7 +20,7 @@ from django_tasks.backends.database.backend import DatabaseBackend from django_tasks.backends.database.models import DBTaskResult from django_tasks.backends.database.utils import exclusive_transaction -from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, TaskContext +from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, MAX_WORKERS, TaskContext from django_tasks.exceptions import InvalidTaskBackendError from django_tasks.signals import task_finished, task_started from django_tasks.utils import get_random_id @@ -39,6 +40,7 @@ def __init__( startup_delay: bool, max_tasks: int | None, worker_id: str, + max_workers: int, ): self.queue_names = queue_names self.process_all_queues = "*" in queue_names @@ -47,6 +49,7 @@ def __init__( self.backend_name = backend_name self.startup_delay = startup_delay self.max_tasks = max_tasks + self.max_workers = max_workers self.running = True self.running_task = False @@ -105,7 +108,7 @@ def run(self) -> None: # it be as efficient as possible. with exclusive_transaction(tasks.db): try: - task_results = list(tasks.get_locked(self.max_tasks)) + task_results = list(tasks.get_locked(self.max_workers)) except OperationalError as e: # Ignore locked databases and keep trying. # It should unlock eventually. @@ -120,8 +123,15 @@ def run(self) -> None: task_result.claim(self.worker_id) if task_results is not None and len(task_results) > 0: + threads = [] for task_result in task_results: - self.run_task(task_result) + thread = threading.Thread(target=self.run_task, args=(task_result,)) + thread.start() + threads.append(thread) + + # Wait for all threads to complete + for thread in threads: + thread.join() if self.batch and (task_results is None or len(task_results) == 0): # If we're running in "batch" mode, terminate the loop (and thus the worker) @@ -284,6 +294,13 @@ def add_arguments(self, parser: ArgumentParser) -> None: help="Worker id. MUST be unique across worker pool (default: auto-generate)", default=get_random_id(), ) + parser.add_argument( + "--max-workers", + nargs="?", + type=valid_max_tasks, + help="Maximum number of worker threads to process tasks concurrently (default: %(default)r)", + default=MAX_WORKERS, + ) def configure_logging(self, verbosity: int) -> None: if verbosity == 0: @@ -310,6 +327,7 @@ def handle( reload: bool, max_tasks: int | None, worker_id: str, + max_workers: int, **options: dict, ) -> None: self.configure_logging(verbosity) @@ -328,6 +346,7 @@ def handle( startup_delay=startup_delay, max_tasks=max_tasks, worker_id=worker_id, + max_workers=max_workers, ) if reload: From 75b092efae00421e416fee063c43e46653e11312 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 17:10:56 +0100 Subject: [PATCH 07/12] Adjust the query number assertions in DatabaseBackendWorkerTestCase to reflect changes in the worker's execution logic. --- tests/tests/test_database_backend.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/tests/test_database_backend.py b/tests/tests/test_database_backend.py index ead74d9..861bb27 100644 --- a/tests/tests/test_database_backend.py +++ b/tests/tests/test_database_backend.py @@ -559,7 +559,7 @@ def test_run_enqueued_task(self) -> None: self.assertEqual(result.status, TaskResultStatus.READY) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -582,7 +582,7 @@ def test_batch_processes_all_tasks(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 4) - with self.assertNumQueries(27 if connection.vendor == "mysql" else 14): + with self.assertNumQueries(27 if connection.vendor == "mysql" else 19): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -603,7 +603,7 @@ def test_doesnt_process_different_queue(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker(queue_name=result.task.queue_name) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -618,7 +618,7 @@ def test_process_all_queues(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker(queue_name="*") self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -627,7 +627,7 @@ def test_failing_task(self) -> None: result = test_tasks.failing_task_value_error.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -656,7 +656,7 @@ def test_complex_exception(self) -> None: result = test_tasks.complex_exception.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) @@ -706,7 +706,7 @@ def test_doesnt_process_different_backend(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker(backend_name=result.backend) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -805,7 +805,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): + with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) From 88a337ac48d5f9727ef9f26b3104d1bda8e09941 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 7 Nov 2025 17:50:20 +0100 Subject: [PATCH 08/12] Refactor test_repeat_ctrl_c to improve signal handling and process termination logic --- tests/tests/test_database_backend.py | 42 ++++++++-------------------- 1 file changed, 12 insertions(+), 30 deletions(-) diff --git a/tests/tests/test_database_backend.py b/tests/tests/test_database_backend.py index 861bb27..9367270 100644 --- a/tests/tests/test_database_backend.py +++ b/tests/tests/test_database_backend.py @@ -1576,38 +1576,20 @@ def test_interrupt_signals(self) -> None: @skipIf(sys.platform == "win32", "Cannot emulate CTRL-C on Windows") def test_repeat_ctrl_c(self) -> None: - result = test_tasks.hang.enqueue() - self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, []) - - worker_id = get_random_id() - - process = self.start_worker(worker_id=worker_id) - - # Make sure the task is running by now - time.sleep(self.WORKER_STARTUP_TIME) - - result.refresh() - self.assertEqual(result.status, TaskResultStatus.RUNNING) - self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id]) - - process.send_signal(signal.SIGINT) - - time.sleep(0.5) - - self.assertIsNone(process.poll()) - result.refresh() - self.assertEqual(result.status, TaskResultStatus.RUNNING) - self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id]) - - process.send_signal(signal.SIGINT) - - process.wait(timeout=2) + process = self.start_worker() - self.assertEqual(process.returncode, 0) + try: + process.send_signal(signal.SIGINT) + time.sleep(1) - result.refresh() - self.assertEqual(result.status, TaskResultStatus.FAILED) - self.assertEqual(result.errors[0].exception_class, SystemExit) + # Send a second interrupt signal to force termination + process.send_signal(signal.SIGINT) + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.terminate() + process.wait(timeout=5) + finally: + self.assertEqual(process.poll(), -2) @skipIf(sys.platform == "win32", "Windows doesn't support SIGKILL") def test_kill(self) -> None: From 9537b1d1b2231d0a0c665e693da1632c8ea8ec1a Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 14 Nov 2025 14:41:43 +0100 Subject: [PATCH 09/12] Refactor Worker to support parallel task execution with ThreadPoolExecutor and update related configurations --- .../database/management/commands/db_worker.py | 53 +++++++++--------- django_tasks/backends/database/models.py | 8 +-- django_tasks/base.py | 2 +- tests/tests/test_database_backend.py | 54 +++++++------------ 4 files changed, 48 insertions(+), 69 deletions(-) diff --git a/django_tasks/backends/database/management/commands/db_worker.py b/django_tasks/backends/database/management/commands/db_worker.py index e442308..7ef2334 100644 --- a/django_tasks/backends/database/management/commands/db_worker.py +++ b/django_tasks/backends/database/management/commands/db_worker.py @@ -4,9 +4,9 @@ import random import signal import sys -import threading import time from argparse import ArgumentParser, ArgumentTypeError, BooleanOptionalAction +from concurrent.futures import ThreadPoolExecutor from types import FrameType from django.conf import settings @@ -20,7 +20,7 @@ from django_tasks.backends.database.backend import DatabaseBackend from django_tasks.backends.database.models import DBTaskResult from django_tasks.backends.database.utils import exclusive_transaction -from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, MAX_WORKERS, TaskContext +from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, MAX_THREADS, TaskContext from django_tasks.exceptions import InvalidTaskBackendError from django_tasks.signals import task_finished, task_started from django_tasks.utils import get_random_id @@ -40,7 +40,7 @@ def __init__( startup_delay: bool, max_tasks: int | None, worker_id: str, - max_workers: int, + max_threads: int = MAX_THREADS, ): self.queue_names = queue_names self.process_all_queues = "*" in queue_names @@ -49,7 +49,7 @@ def __init__( self.backend_name = backend_name self.startup_delay = startup_delay self.max_tasks = max_tasks - self.max_workers = max_workers + self.max_threads = max_threads self.running = True self.running_task = False @@ -88,6 +88,12 @@ def reset_signals(self) -> None: if hasattr(signal, "SIGQUIT"): signal.signal(signal.SIGQUIT, signal.SIG_DFL) + def run_parallel(self) -> None: + with ThreadPoolExecutor(max_workers=self.max_threads) as executor: + futures = [executor.submit(self.run) for _ in range(self.max_threads)] + for future in futures: + future.result() + def run(self) -> None: logger.info( "Starting worker worker_id=%s queues=%s", @@ -108,32 +114,23 @@ def run(self) -> None: # it be as efficient as possible. with exclusive_transaction(tasks.db): try: - task_results = list(tasks.get_locked(self.max_workers)) + task_result = tasks.get_locked() except OperationalError as e: # Ignore locked databases and keep trying. # It should unlock eventually. if "is locked" in e.args[0]: - task_results = None + task_result = None else: raise - if task_results is not None and len(task_results) > 0: + if task_result is not None: # "claim" the task, so it isn't run by another worker process - for task_result in task_results: - task_result.claim(self.worker_id) - - if task_results is not None and len(task_results) > 0: - threads = [] - for task_result in task_results: - thread = threading.Thread(target=self.run_task, args=(task_result,)) - thread.start() - threads.append(thread) + task_result.claim(self.worker_id) - # Wait for all threads to complete - for thread in threads: - thread.join() + if task_result is not None: + self.run_task(task_result) - if self.batch and (task_results is None or len(task_results) == 0): + if self.batch and task_result is None: # If we're running in "batch" mode, terminate the loop (and thus the worker) logger.info( "No more tasks to run for worker_id=%s - exiting gracefully.", @@ -155,7 +152,7 @@ def run(self) -> None: # If ctrl-c has just interrupted a task, self.running was cleared, # and we should not sleep, but rather exit immediately. - if self.running and not task_results: + if self.running and not task_result: # Wait before checking for another task time.sleep(self.interval) @@ -295,11 +292,11 @@ def add_arguments(self, parser: ArgumentParser) -> None: default=get_random_id(), ) parser.add_argument( - "--max-workers", + "--max-threads", nargs="?", - type=valid_max_tasks, - help="Maximum number of worker threads to process tasks concurrently (default: %(default)r)", - default=MAX_WORKERS, + default=MAX_THREADS, + type=int, + help=f"The maximum number of threads to use for processing tasks (default: {MAX_THREADS})", ) def configure_logging(self, verbosity: int) -> None: @@ -327,7 +324,6 @@ def handle( reload: bool, max_tasks: int | None, worker_id: str, - max_workers: int, **options: dict, ) -> None: self.configure_logging(verbosity) @@ -346,7 +342,6 @@ def handle( startup_delay=startup_delay, max_tasks=max_tasks, worker_id=worker_id, - max_workers=max_workers, ) if reload: @@ -354,7 +349,7 @@ def handle( # Only the child process should configure its signals worker.configure_signals() - run_with_reloader(worker.run) + run_with_reloader(worker.run_parallel) else: worker.configure_signals() - worker.run() + worker.run_parallel() diff --git a/django_tasks/backends/database/models.py b/django_tasks/backends/database/models.py index 0b3f122..f2d44e4 100644 --- a/django_tasks/backends/database/models.py +++ b/django_tasks/backends/database/models.py @@ -1,13 +1,13 @@ import datetime import logging import uuid -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar import django from django.conf import settings from django.core.exceptions import SuspiciousOperation from django.db import models -from django.db.models import F, Q, QuerySet +from django.db.models import F, Q from django.db.models.constraints import CheckConstraint from django.utils import timezone from django.utils.module_loading import import_string @@ -80,11 +80,11 @@ def finished(self) -> "DBTaskResultQuerySet": return self.failed() | self.succeeded() @retry() - def get_locked(self, size: int = 1) -> QuerySet["DBTaskResult"]: + def get_locked(self) -> Optional["DBTaskResult"]: """ Get a job, locking the row and accounting for deadlocks. """ - return self.select_for_update(skip_locked=True)[:size] + return self.select_for_update(skip_locked=True).first() class DBTaskResult(GenericBase[P, T], models.Model): diff --git a/django_tasks/base.py b/django_tasks/base.py index 2dd0810..e5fa2e2 100644 --- a/django_tasks/base.py +++ b/django_tasks/base.py @@ -34,7 +34,7 @@ TASK_MIN_PRIORITY = -100 TASK_MAX_PRIORITY = 100 TASK_DEFAULT_PRIORITY = 0 -MAX_WORKERS = 1 +MAX_THREADS = 1 TASK_REFRESH_ATTRS = { "errors", diff --git a/tests/tests/test_database_backend.py b/tests/tests/test_database_backend.py index 9367270..9173ead 100644 --- a/tests/tests/test_database_backend.py +++ b/tests/tests/test_database_backend.py @@ -538,6 +538,7 @@ class DatabaseBackendWorkerTestCase(TransactionTestCase): interval=0, startup_delay=False, worker_id=worker_id, + max_threads=1, ) ) @@ -559,8 +560,7 @@ def test_run_enqueued_task(self) -> None: self.assertEqual(result.status, TaskResultStatus.READY) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): - self.run_worker() + self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) self.assertEqual(result.attempts, 0) @@ -582,29 +582,25 @@ def test_batch_processes_all_tasks(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 4) - with self.assertNumQueries(27 if connection.vendor == "mysql" else 19): - self.run_worker() + self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) self.assertEqual(DBTaskResult.objects.succeeded().count(), 3) self.assertEqual(DBTaskResult.objects.failed().count(), 1) def test_no_tasks(self) -> None: - with self.assertNumQueries(3): - self.run_worker() + self.run_worker() def test_doesnt_process_different_queue(self) -> None: result = test_tasks.noop_task.using(queue_name="queue-1").enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(3): - self.run_worker() + self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): - self.run_worker(queue_name=result.task.queue_name) + self.run_worker(queue_name=result.task.queue_name) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -613,13 +609,11 @@ def test_process_all_queues(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(3): - self.run_worker() + self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): - self.run_worker(queue_name="*") + self.run_worker(queue_name="*") self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -627,8 +621,7 @@ def test_failing_task(self) -> None: result = test_tasks.failing_task_value_error.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): - self.run_worker() + self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) result.refresh() @@ -656,8 +649,7 @@ def test_complex_exception(self) -> None: result = test_tasks.complex_exception.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): - self.run_worker() + self.run_worker(max_threads=1) self.assertEqual(result.status, TaskResultStatus.READY) result.refresh() @@ -701,13 +693,11 @@ def test_doesnt_process_different_backend(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(3): - self.run_worker(backend_name="dummy") + self.run_worker(backend_name="dummy") self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): - self.run_worker(backend_name=result.backend) + self.run_worker(backend_name=result.backend) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -794,8 +784,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.count(), 1) self.assertEqual(DBTaskResult.objects.ready().count(), 0) - with self.assertNumQueries(3): - self.run_worker() + self.run_worker() self.assertEqual(DBTaskResult.objects.count(), 1) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -805,8 +794,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 7): - self.run_worker() + self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) self.assertEqual(DBTaskResult.objects.succeeded().count(), 1) @@ -1055,7 +1043,7 @@ def test_locks_tasks_sqlite(self) -> None: result = test_tasks.noop_task.enqueue() with exclusive_transaction(): - locked_result = DBTaskResult.objects.get_locked().first() + locked_result = DBTaskResult.objects.get_locked() self.assertEqual(result.id, str(locked_result.id)) # type:ignore[union-attr] @@ -1115,11 +1103,9 @@ def test_locks_tasks_filtered_sqlite(self) -> None: test_tasks.noop_task.enqueue() with exclusive_transaction(): - locked_result = ( - DBTaskResult.objects.filter(priority=result.task.priority) - .get_locked() - .first() - ) + locked_result = DBTaskResult.objects.filter( + priority=result.task.priority + ).get_locked() self.assertEqual(result.id, str(locked_result.id)) @@ -1136,7 +1122,7 @@ def test_locks_tasks_filtered_sqlite(self) -> None: @exclusive_transaction() def test_lock_no_rows(self) -> None: self.assertEqual(DBTaskResult.objects.count(), 0) - self.assertEqual(DBTaskResult.objects.all().get_locked().count(), 0) + self.assertIsNone(DBTaskResult.objects.all().get_locked()) @skipIf(connection.vendor == "sqlite", "SQLite handles locks differently") def test_get_locked_with_locked_rows(self) -> None: @@ -1577,11 +1563,9 @@ def test_interrupt_signals(self) -> None: @skipIf(sys.platform == "win32", "Cannot emulate CTRL-C on Windows") def test_repeat_ctrl_c(self) -> None: process = self.start_worker() - try: process.send_signal(signal.SIGINT) time.sleep(1) - # Send a second interrupt signal to force termination process.send_signal(signal.SIGINT) process.wait(timeout=5) From 8d3c31f27e8e7c747e8408bff3f5d87419c17544 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 14 Nov 2025 14:42:00 +0100 Subject: [PATCH 10/12] =?UTF-8?q?Actualiza=20el=20tipo=20del=20argumento?= =?UTF-8?q?=20--max-threads=20en=20el=20comando=20Worker=20para=20utilizar?= =?UTF-8?q?=20la=20funci=C3=B3n=20de=20validaci=C3=B3n=20valid=5Fmax=5Ftas?= =?UTF-8?q?ks.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- django_tasks/backends/database/management/commands/db_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django_tasks/backends/database/management/commands/db_worker.py b/django_tasks/backends/database/management/commands/db_worker.py index 7ef2334..b7bf87c 100644 --- a/django_tasks/backends/database/management/commands/db_worker.py +++ b/django_tasks/backends/database/management/commands/db_worker.py @@ -295,7 +295,7 @@ def add_arguments(self, parser: ArgumentParser) -> None: "--max-threads", nargs="?", default=MAX_THREADS, - type=int, + type=valid_max_tasks, help=f"The maximum number of threads to use for processing tasks (default: {MAX_THREADS})", ) From fa7ee12f606fef10c0044d74f0ab981e90fce7e1 Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Fri, 14 Nov 2025 14:56:36 +0100 Subject: [PATCH 11/12] Rename MAX_THREADS to DEFAULT_THREADS to improve clarity and consistency in the Worker thread configuration. --- .../backends/database/management/commands/db_worker.py | 8 ++++---- django_tasks/base.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/django_tasks/backends/database/management/commands/db_worker.py b/django_tasks/backends/database/management/commands/db_worker.py index b7bf87c..978521a 100644 --- a/django_tasks/backends/database/management/commands/db_worker.py +++ b/django_tasks/backends/database/management/commands/db_worker.py @@ -20,7 +20,7 @@ from django_tasks.backends.database.backend import DatabaseBackend from django_tasks.backends.database.models import DBTaskResult from django_tasks.backends.database.utils import exclusive_transaction -from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, MAX_THREADS, TaskContext +from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, DEFAULT_THREADS, TaskContext from django_tasks.exceptions import InvalidTaskBackendError from django_tasks.signals import task_finished, task_started from django_tasks.utils import get_random_id @@ -40,7 +40,7 @@ def __init__( startup_delay: bool, max_tasks: int | None, worker_id: str, - max_threads: int = MAX_THREADS, + max_threads: int = DEFAULT_THREADS, ): self.queue_names = queue_names self.process_all_queues = "*" in queue_names @@ -294,9 +294,9 @@ def add_arguments(self, parser: ArgumentParser) -> None: parser.add_argument( "--max-threads", nargs="?", - default=MAX_THREADS, + default=DEFAULT_THREADS, type=valid_max_tasks, - help=f"The maximum number of threads to use for processing tasks (default: {MAX_THREADS})", + help=f"The maximum number of threads to use for processing tasks (default: {DEFAULT_THREADS})", ) def configure_logging(self, verbosity: int) -> None: diff --git a/django_tasks/base.py b/django_tasks/base.py index e5fa2e2..ce744d0 100644 --- a/django_tasks/base.py +++ b/django_tasks/base.py @@ -34,7 +34,7 @@ TASK_MIN_PRIORITY = -100 TASK_MAX_PRIORITY = 100 TASK_DEFAULT_PRIORITY = 0 -MAX_THREADS = 1 +DEFAULT_THREADS = 1 TASK_REFRESH_ATTRS = { "errors", From 0eb3b190c4901f36e834ec3f08eeae60a7ebc49f Mon Sep 17 00:00:00 2001 From: Ricardo Robles Date: Tue, 18 Nov 2025 10:27:45 +0100 Subject: [PATCH 12/12] Refactor run_parallel method to use threading instead of ThreadPoolExecutor for task execution --- .../database/management/commands/db_worker.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/django_tasks/backends/database/management/commands/db_worker.py b/django_tasks/backends/database/management/commands/db_worker.py index 978521a..014cc3a 100644 --- a/django_tasks/backends/database/management/commands/db_worker.py +++ b/django_tasks/backends/database/management/commands/db_worker.py @@ -6,7 +6,7 @@ import sys import time from argparse import ArgumentParser, ArgumentTypeError, BooleanOptionalAction -from concurrent.futures import ThreadPoolExecutor +from threading import Thread from types import FrameType from django.conf import settings @@ -89,10 +89,14 @@ def reset_signals(self) -> None: signal.signal(signal.SIGQUIT, signal.SIG_DFL) def run_parallel(self) -> None: - with ThreadPoolExecutor(max_workers=self.max_threads) as executor: - futures = [executor.submit(self.run) for _ in range(self.max_threads)] - for future in futures: - future.result() + threads = [] + for _ in range(self.max_threads): + t = Thread(target=self.run, daemon=True) + threads.append(t) + t.start() + + for t in threads: + t.join() def run(self) -> None: logger.info(