From 5d4a79fd3d09942868682b3725c50b9619396a84 Mon Sep 17 00:00:00 2001 From: gopidesupavan Date: Sun, 29 Mar 2026 15:16:55 +0100 Subject: [PATCH] Fix worker CLI tests for optional pool argument --- .../unit/celery/cli/test_celery_command.py | 56 ++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/providers/celery/tests/unit/celery/cli/test_celery_command.py b/providers/celery/tests/unit/celery/cli/test_celery_command.py index cafd48f29a051..0c3f8069d6229 100644 --- a/providers/celery/tests/unit/celery/cli/test_celery_command.py +++ b/providers/celery/tests/unit/celery/cli/test_celery_command.py @@ -185,10 +185,64 @@ def test_worker_started_with_required_arguments(self, mock_celery_app, mock_pope autoscale, "--without-mingle", "--without-gossip", + ] + ) + + @mock.patch("airflow.providers.celery.cli.celery_command.maybe_patch_concurrency") + @mock.patch("airflow.providers.celery.cli.celery_command.conf") + @mock.patch("airflow.providers.celery.cli.celery_command.setup_locations") + @mock.patch("airflow.providers.celery.cli.celery_command.Process") + @mock.patch("airflow.providers.celery.executors.celery_executor.app") + def test_worker_started_with_pool_from_config( + self, + mock_celery_app, + mock_popen, + mock_locations, + mock_conf, + mock_maybe_patch_concurrency, + ): + pid_file = "pid_file" + mock_locations.return_value = (pid_file, None, None, None) + concurrency = "1" + queues = "queue" + pool = "prefork" + mock_conf.get.side_effect = lambda section, key, **kwargs: { + ("logging", "CELERY_LOGGING_LEVEL"): "INFO", + ("celery", "pool"): pool, + ("celery", "worker_umask"): "0o077", + }.get((section, key), kwargs.get("fallback")) + mock_conf.getint.return_value = 0 + mock_conf.has_option.side_effect = lambda section, key: (section, key) == ("celery", "pool") + + args = self.parser.parse_args( + [ + "celery", + "worker", + "--concurrency", + concurrency, + "--queues", + queues, + ] + ) + + celery_command.worker(args) + + mock_celery_app.worker_main.assert_called_once_with( + [ + "worker", + "-O", + "fair", + "--queues", + queues, + "--concurrency", + int(concurrency), + "--loglevel", + "INFO", "--pool", - "prefork", + pool, ] ) + mock_maybe_patch_concurrency.assert_called_once_with(["-P", pool]) @pytest.mark.backend("mysql", "postgres")