diff --git a/agent/skyhook-agent/src/skyhook_agent/controller.py b/agent/skyhook-agent/src/skyhook_agent/controller.py index b8cf6089..c8699bb4 100644 --- a/agent/skyhook-agent/src/skyhook_agent/controller.py +++ b/agent/skyhook-agent/src/skyhook_agent/controller.py @@ -339,14 +339,22 @@ def summarize_check_results(results: list[bool], step_data: dict[Mode, list[Step return False -def do_interrupt(interrupt_data: str, root_mount: str, copy_dir: str, on_host: bool, config_data: dict) -> bool: +def do_interrupt(interrupt_data: str, root_mount: str, copy_dir: str, on_host: bool) -> bool: """ Run an interrupt if there hasn't been an interrupt already for the skyhook ID. """ - SKYHOOK_RESOURCE_ID, _ = _get_env_config() + # Interrupts don't really have config data we can read from the Package as it is run standalone. + # So read it off of SKYHOOK_RESOURCE_ID instead + # customer-f5a1d42e-74e5-4606-8bbc-b504fbe0074d-1_tuning_2.0.2 + _, package, version = SKYHOOK_RESOURCE_ID.split("_") + config_data = { + "package_name": package, + "package_version": version, + } + interrupt = interrupts.inflate(interrupt_data) # Check if the interrupt has already been run for this particular skyhook resource @@ -392,6 +400,9 @@ def main(mode: Mode, root_mount: str, copy_dir: str, interrupt_data: None|str, a logger.warning(f"This version of the Agent doesn't support the {mode} mode. Options are: {','.join(map(str, Mode))}.") return False + if mode == Mode.INTERRUPT: + return do_interrupt(interrupt_data, root_mount, copy_dir, True) + _, SKYHOOK_DATA_DIR = _get_env_config() # Check to see if the directory has already been copied down. If it hasn't assume that we @@ -421,9 +432,6 @@ def main(mode: Mode, root_mount: str, copy_dir: str, interrupt_data: None|str, a return agent_main(mode, root_mount, copy_dir, config_data, interrupt_data, always_run_step) def agent_main(mode: Mode, root_mount: str, copy_dir: str, config_data: dict, interrupt_data: None|str, always_run_step=False): - - if mode == Mode.INTERRUPT: - return do_interrupt(interrupt_data, root_mount, copy_dir, True, config_data) # Pull out step_data so it matches with existing code step_data = config_data["modes"] @@ -506,7 +514,7 @@ def cli(sys_argv: list[str]=sys.argv): # new way with interrupt data mode, root_mount, copy_dir, interrupt_data = args - if os.getenv("COPY_RESOLVE", "true").lower() == "true": + if os.getenv("COPY_RESOLV", "true").lower() == "true": shutil.copyfile("/etc/resolv.conf", f"{root_mount}/etc/resolv.conf") always_run_step = os.getenv("OVERLAY_ALWAYS_RUN_STEP", "false").lower() == "true" diff --git a/agent/skyhook-agent/tests/test_controller.py b/agent/skyhook-agent/tests/test_controller.py index 5b75d00c..23016d27 100644 --- a/agent/skyhook-agent/tests/test_controller.py +++ b/agent/skyhook-agent/tests/test_controller.py @@ -1088,13 +1088,18 @@ def test_interrupt_applies_all_commands(self, run_mock, datetime_mock): ], } with self._setup_for_main(steps) as (container_root_dir, config_data, root_dir): - controller.main( - Mode.INTERRUPT, - root_dir, - "copy_dir", - interrupts.ServiceRestart(["containerd",]).make_controller_input() - ) + with set_env(SKYHOOK_RESOURCE_ID="scr-id-1_package_version"): + controller.main( + Mode.INTERRUPT, + root_dir, + "copy_dir", + interrupts.ServiceRestart(["containerd",]).make_controller_input() + ) + config_data = { + "package_name": "package", + "package_version": "version" + } run_mock.assert_has_calls([ mock.call(["systemctl", "daemon-reload"], controller.get_log_file(root_dir, "interrupts/service_restart_0", "copy_dir", config_data), on_host=True, root_mount=root_dir, write_cmds=True), mock.call(["systemctl", "restart", "containerd"], controller.get_log_file(root_dir, "interrupts/service_restart_1", "copy_dir", config_data), on_host=True, root_mount=root_dir, write_cmds=True) @@ -1103,25 +1108,25 @@ def test_interrupt_applies_all_commands(self, run_mock, datetime_mock): @mock.patch("skyhook_agent.controller._run") def test_interrupt_isnt_run_when_skyhook_resource_id_flag_is_there(self, run_mock): run_mock.return_value = 0 - SKYHOOK_RESOURCE_ID="foo" + SKYHOOK_RESOURCE_ID="scr-id-1_package_version" with (tempfile.TemporaryDirectory() as dir, set_env(SKYHOOK_RESOURCE_ID=SKYHOOK_RESOURCE_ID)): os.makedirs(f"{controller.get_skyhook_directory(dir)}/interrupts/flags/{SKYHOOK_RESOURCE_ID}", exist_ok=True) with open(f"{controller.get_skyhook_directory(dir)}/interrupts/flags/{SKYHOOK_RESOURCE_ID}/node_restart_0.complete", 'w') as f: f.write("") - controller.do_interrupt(interrupts.NodeRestart().make_controller_input(), dir, "copy_dir", on_host=False, config_data=self.config_data) + controller.do_interrupt(interrupts.NodeRestart().make_controller_input(), dir, "copy_dir", on_host=False) run_mock.assert_not_called() @mock.patch("skyhook_agent.controller._run") def test_interrupt_create_flags_per_cmd(self, run_mock): run_mock.return_value = 0 - SKYHOOK_RESOURCE_ID="foo" + SKYHOOK_RESOURCE_ID="scr-id-1_package_version" with (tempfile.TemporaryDirectory() as dir, set_env(SKYHOOK_RESOURCE_ID=SKYHOOK_RESOURCE_ID)): interrupt_dir = f"{controller.get_skyhook_directory(dir)}/interrupts/flags/{SKYHOOK_RESOURCE_ID}" interrupt = interrupts.ServiceRestart(["foo", "bar"]) - controller.do_interrupt(interrupt.make_controller_input(), dir, "copy_dir", on_host=False, config_data=self.config_data) + controller.do_interrupt(interrupt.make_controller_input(), dir, "copy_dir", on_host=False) for i in range(len(interrupt.interrupt_cmd)): self.assertTrue(os.path.exists(f"{interrupt_dir}/{interrupt._type()}_{i}.complete")) @@ -1129,12 +1134,12 @@ def test_interrupt_create_flags_per_cmd(self, run_mock): @mock.patch("skyhook_agent.controller._run") def test_interrupt_failures_remove_flag(self, run_mock): run_mock.side_effect = [0,1,0] - SKYHOOK_RESOURCE_ID="foo" + SKYHOOK_RESOURCE_ID="scr-id-1_package_version" with (tempfile.TemporaryDirectory() as dir, set_env(SKYHOOK_RESOURCE_ID=SKYHOOK_RESOURCE_ID)): interrupt_dir = f"{controller.get_skyhook_directory(dir)}/interrupts/flags/{SKYHOOK_RESOURCE_ID}" interrupt = interrupts.ServiceRestart(["foo", "bar"]) - controller.do_interrupt(interrupt.make_controller_input(), dir, "copy_dir", on_host=False, config_data=self.config_data) + controller.do_interrupt(interrupt.make_controller_input(), dir, "copy_dir", on_host=False) self.assertTrue(os.path.exists(f"{interrupt_dir}/{interrupt._type()}_0.complete")) self.assertFalse(os.path.exists(f"{interrupt_dir}/{interrupt._type()}_1.complete")) @@ -1156,50 +1161,85 @@ def test_interrupt_failure_fails_controller(self, run_mock, datetime_mock): ], } with self._setup_for_main(steps) as (container_root_dir, config_data, root_dir): - result = controller.main( - Mode.INTERRUPT, - root_dir, - "/tmp", - interrupts.ServiceRestart("containerd").make_controller_input() - ) - + with set_env(SKYHOOK_RESOURCE_ID="scr-id-1_package_version"): + result = controller.main( + Mode.INTERRUPT, + root_dir, + "/tmp", + interrupts.ServiceRestart("containerd").make_controller_input() + ) + config_data = { + "package_name": "package", + "package_version": "version" + } run_mock.assert_has_calls([ mock.call(["systemctl", "daemon-reload"], controller.get_log_file(root_dir, "interrupts/service_restart_0", "copy_dir", config_data), on_host=True, root_mount=root_dir, write_cmds=True) ]) self.assertEqual(result, True) + @mock.patch("skyhook_agent.controller.datetime") + @mock.patch("skyhook_agent.controller._run") + def test_interrupt_makes_config_from_skyhook_resource_id(self, run_mock, datetime_mock): + now_mock = mock.MagicMock() + datetime_mock.now.return_value = now_mock + now_mock.strftime.return_value = "12345" + run_mock.return_value = 0 + steps = { + Mode.APPLY: [ + Step("foo.sh", arguments=[]), + ], + Mode.APPLY_CHECK: [ + Step("foo_check.sh", arguments=[]), + ], + } + with self._setup_for_main(steps) as (container_root_dir, config_data, root_dir): + with set_env(SKYHOOK_RESOURCE_ID="scr-id-1_package_version"): + result = controller.main( + Mode.INTERRUPT, + root_dir, + "/tmp", + interrupts.ServiceRestart("containerd").make_controller_input() + ) + config_data = { + "package_name": "package", + "package_version": "version" + } + run_mock.assert_has_calls([ + mock.call(["systemctl", "daemon-reload"], controller.get_log_file(root_dir, "interrupts/service_restart_0", "copy_dir", config_data), on_host=True, root_mount=root_dir, write_cmds=True) + ]) + @mock.patch("skyhook_agent.controller.main") def test_interrupt_mode_reads_extra_argument(self, main_mock): argv = ["controller.py", str(Mode.INTERRUPT), "root_mount", "copy_dir", "interrupt_data"] - with set_env(COPY_RESOLVE="false"): + with set_env(COPY_RESOLV="false"): controller.cli(argv) main_mock.assert_called_once_with(str(Mode.INTERRUPT), "root_mount", "copy_dir", "interrupt_data", False) @mock.patch("skyhook_agent.controller.main") def test_cli_overlay_always_run_step_is_correct(self, main_mock): - with set_env(OVERLAY_ALWAYS_RUN_STEP="true", COPY_RESOLVE="false"): + with set_env(OVERLAY_ALWAYS_RUN_STEP="true", COPY_RESOLV="false"): controller.cli(["controller.py", str(Mode.APPLY), "root_mount", "copy_dir"]) main_mock.assert_called_once_with(str(Mode.APPLY), "root_mount", "copy_dir", None, True) main_mock.reset_mock() - with set_env(OVERLAY_ALWAYS_RUN_STEP="false", COPY_RESOLVE="false"): + with set_env(OVERLAY_ALWAYS_RUN_STEP="false", COPY_RESOLV="false"): controller.cli(["controller.py", str(Mode.APPLY), "root_mount", "copy_dir"]) main_mock.assert_called_once_with(str(Mode.APPLY), "root_mount", "copy_dir", None, False) @mock.patch("skyhook_agent.controller.main") @mock.patch("skyhook_agent.controller.shutil") - def test_cli_copy_resolve(self, shutil_mock, main_mock): + def test_cli_COPY_RESOLV(self, shutil_mock, main_mock): argv = ["controller.py", str(Mode.APPLY), "root_mount", "copy_dir"] - with set_env(COPY_RESOLVE="true"): + with set_env(COPY_RESOLV="true"): controller.cli(argv) shutil_mock.copyfile.assert_called_once() shutil_mock.copyfile.reset_mock() - with set_env(COPY_RESOLVE="false"): + with set_env(COPY_RESOLV="false"): controller.cli(argv) shutil_mock.copyfile.assert_not_called()