diff --git a/changelogs/fragments/69054-collection-as-str.yaml b/changelogs/fragments/69054-collection-as-str.yaml new file mode 100644 index 00000000000000..5e9845f99b254a --- /dev/null +++ b/changelogs/fragments/69054-collection-as-str.yaml @@ -0,0 +1,2 @@ +bugfixes: +- Make sure if a collection is supplied as a string that we transform it into a list. diff --git a/lib/ansible/playbook/collectionsearch.py b/lib/ansible/playbook/collectionsearch.py index 994e2e13e4c70a..d80b6a1c6a3efc 100644 --- a/lib/ansible/playbook/collectionsearch.py +++ b/lib/ansible/playbook/collectionsearch.py @@ -16,15 +16,13 @@ def _ensure_default_collection(collection_list=None): default_collection = AnsibleCollectionLoader().default_collection + # Will be None when used as the default if collection_list is None: collection_list = [] - if default_collection: # FIXME: exclude role tasks? - if isinstance(collection_list, string_types): - collection_list = [collection_list] - - if default_collection not in collection_list: - collection_list.insert(0, default_collection) + # FIXME: exclude role tasks? + if default_collection and default_collection not in collection_list: + collection_list.insert(0, default_collection) # if there's something in the list, ensure that builtin or legacy is always there too if collection_list and 'ansible.builtin' not in collection_list and 'ansible.legacy' not in collection_list: @@ -40,6 +38,10 @@ class CollectionSearch: always_post_validate=True, static=True) def _load_collections(self, attr, ds): + # We are always a mixin with Base, so we can validate this untemplated + # field early on to guarantee we are dealing with a list. + ds = self.get_validated_value('collections', self._collections, ds, None) + # this will only be called if someone specified a value; call the shared value _ensure_default_collection(collection_list=ds) @@ -47,8 +49,9 @@ def _load_collections(self, attr, ds): return None # This duplicates static attr checking logic from post_validate() - # because if the user attempts to template a collection name, it will - # error before it ever gets to the post_validate() warning. + # because if the user attempts to template a collection name, it may + # error before it ever gets to the post_validate() warning (e.g. trying + # to import a role from the collection). env = Environment() for collection_name in ds: if is_template(collection_name, env): diff --git a/lib/ansible/playbook/task.py b/lib/ansible/playbook/task.py index 99e3b7a38870e5..a2fb0d86322856 100644 --- a/lib/ansible/playbook/task.py +++ b/lib/ansible/playbook/task.py @@ -184,14 +184,14 @@ def preprocess_data(self, ds): # since this affects the task action parsing, we have to resolve in preprocess instead of in typical validator default_collection = AnsibleCollectionLoader().default_collection - # use the parent value if our ds doesn't define it - collections_list = ds.get('collections', self.collections) - + collections_list = ds.get('collections') if collections_list is None: - collections_list = [] - - if isinstance(collections_list, string_types): - collections_list = [collections_list] + # use the parent value if our ds doesn't define it + collections_list = self.collections + else: + # Validate this untemplated field early on to guarantee we are dealing with a list. + # This is also done in CollectionSearch._load_collections() but this runs before that call. + collections_list = self.get_validated_value('collections', self._collections, collections_list, None) if default_collection and not self._role: # FIXME: and not a collections role if collections_list: diff --git a/test/integration/targets/collections/posix.yml b/test/integration/targets/collections/posix.yml index 0d7c7089c4b591..61f950f50bfa22 100644 --- a/test/integration/targets/collections/posix.yml +++ b/test/integration/targets/collections/posix.yml @@ -406,3 +406,10 @@ hosts: testhost roles: - testns.testcoll.call_standalone + +# Issue https://github.com/ansible/ansible/issues/69054 +- name: Test collection as string + hosts: testhost + collections: foo + tasks: + - debug: msg="Test" diff --git a/test/units/playbook/test_collectionsearch.py b/test/units/playbook/test_collectionsearch.py index fe480d3de1e3ed..be40d85e303dcb 100644 --- a/test/units/playbook/test_collectionsearch.py +++ b/test/units/playbook/test_collectionsearch.py @@ -18,6 +18,10 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +from ansible.errors import AnsibleParserError +from ansible.playbook.play import Play +from ansible.playbook.task import Task +from ansible.playbook.block import Block from ansible.playbook.collectionsearch import CollectionSearch import pytest @@ -28,11 +32,47 @@ def test_collection_static_warning(capsys): Also, make sure that users see the warning message for the referenced name. """ - - collection_name = 'foo.{{bar}}' - cs = CollectionSearch() - assert collection_name in cs._load_collections(None, [collection_name]) - + collection_name = "foo.{{bar}}" + p = Play.load(dict( + name="test play", + hosts=['foo'], + gather_facts=False, + connection='local', + collections=collection_name, + )) + assert collection_name in p.collections std_out, std_err = capsys.readouterr() assert '[WARNING]: "collections" is not templatable, but we found: %s' % collection_name in std_err assert '' == std_out + + +def test_collection_invalid_data_play(): + """Test that collection as a dict at the play level fails with parser error""" + collection_name = {'name': 'foo'} + with pytest.raises(AnsibleParserError): + Play.load(dict( + name="test play", + hosts=['foo'], + gather_facts=False, + connection='local', + collections=collection_name, + )) + + +def test_collection_invalid_data_task(): + """Test that collection as a dict at the task level fails with parser error""" + collection_name = {'name': 'foo'} + with pytest.raises(AnsibleParserError): + Task.load(dict( + name="test task", + collections=collection_name, + )) + + +def test_collection_invalid_data_block(): + """Test that collection as a dict at the block level fails with parser error""" + collection_name = {'name': 'foo'} + with pytest.raises(AnsibleParserError): + Block.load(dict( + block=[dict(name="test task", collections=collection_name)] + )) diff --git a/test/units/playbook/test_helpers.py b/test/units/playbook/test_helpers.py index 2dc67eeebdf180..a4ed6178d4ed69 100644 --- a/test/units/playbook/test_helpers.py +++ b/test/units/playbook/test_helpers.py @@ -316,7 +316,7 @@ def test_one_include_not_static(self): # print(res) def test_one_bogus_include_role(self): - ds = [{'include_role': {'name': 'bogus_role'}}] + ds = [{'include_role': {'name': 'bogus_role'}, 'collections': []}] res = helpers.load_list_of_tasks(ds, play=self.mock_play, block=self.mock_block, variable_manager=self.mock_variable_manager, loader=self.fake_role_loader) @@ -324,7 +324,7 @@ def test_one_bogus_include_role(self): self._assert_is_task_list_or_blocks(res) def test_one_bogus_include_role_use_handlers(self): - ds = [{'include_role': {'name': 'bogus_role'}}] + ds = [{'include_role': {'name': 'bogus_role'}, 'collections': []}] res = helpers.load_list_of_tasks(ds, play=self.mock_play, use_handlers=True, block=self.mock_block, variable_manager=self.mock_variable_manager, @@ -395,7 +395,7 @@ def test_empty_block(self): loader=None) def test_block_unknown_action(self): - ds = [{'action': 'foo'}] + ds = [{'action': 'foo', 'collections': []}] mock_play = MagicMock(name='MockPlay') res = helpers.load_list_of_blocks(ds, mock_play, parent_block=None, role=None, task_include=None, use_handlers=False, variable_manager=None, loader=None)