Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions agent/skyhook-agent/src/skyhook_agent/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,14 +339,22 @@ def summarize_check_results(results: list[bool], step_data: dict[Mode, list[Step

return False

def do_interrupt(interrupt_data: str, root_mount: str, copy_dir: str, on_host: bool, config_data: dict) -> bool:
def do_interrupt(interrupt_data: str, root_mount: str, copy_dir: str, on_host: bool) -> bool:
"""
Run an interrupt if there hasn't been an interrupt already for the skyhook ID.
"""


SKYHOOK_RESOURCE_ID, _ = _get_env_config()

# Interrupts don't really have config data we can read from the Package as it is run standalone.
# So read it off of SKYHOOK_RESOURCE_ID instead
# customer-f5a1d42e-74e5-4606-8bbc-b504fbe0074d-1_tuning_2.0.2
_, package, version = SKYHOOK_RESOURCE_ID.split("_")
config_data = {
"package_name": package,
"package_version": version,
}

interrupt = interrupts.inflate(interrupt_data)

# Check if the interrupt has already been run for this particular skyhook resource
Expand Down Expand Up @@ -392,6 +400,9 @@ def main(mode: Mode, root_mount: str, copy_dir: str, interrupt_data: None|str, a
logger.warning(f"This version of the Agent doesn't support the {mode} mode. Options are: {','.join(map(str, Mode))}.")
return False

if mode == Mode.INTERRUPT:
return do_interrupt(interrupt_data, root_mount, copy_dir, True)

_, SKYHOOK_DATA_DIR = _get_env_config()

# Check to see if the directory has already been copied down. If it hasn't assume that we
Expand Down Expand Up @@ -421,9 +432,6 @@ def main(mode: Mode, root_mount: str, copy_dir: str, interrupt_data: None|str, a
return agent_main(mode, root_mount, copy_dir, config_data, interrupt_data, always_run_step)

def agent_main(mode: Mode, root_mount: str, copy_dir: str, config_data: dict, interrupt_data: None|str, always_run_step=False):

if mode == Mode.INTERRUPT:
return do_interrupt(interrupt_data, root_mount, copy_dir, True, config_data)

# Pull out step_data so it matches with existing code
step_data = config_data["modes"]
Expand Down Expand Up @@ -506,7 +514,7 @@ def cli(sys_argv: list[str]=sys.argv):
# new way with interrupt data
mode, root_mount, copy_dir, interrupt_data = args

if os.getenv("COPY_RESOLVE", "true").lower() == "true":
if os.getenv("COPY_RESOLV", "true").lower() == "true":
shutil.copyfile("/etc/resolv.conf", f"{root_mount}/etc/resolv.conf")

always_run_step = os.getenv("OVERLAY_ALWAYS_RUN_STEP", "false").lower() == "true"
Expand Down
90 changes: 65 additions & 25 deletions agent/skyhook-agent/tests/test_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,13 +1088,18 @@ def test_interrupt_applies_all_commands(self, run_mock, datetime_mock):
],
}
with self._setup_for_main(steps) as (container_root_dir, config_data, root_dir):
controller.main(
Mode.INTERRUPT,
root_dir,
"copy_dir",
interrupts.ServiceRestart(["containerd",]).make_controller_input()
)
with set_env(SKYHOOK_RESOURCE_ID="scr-id-1_package_version"):
controller.main(
Mode.INTERRUPT,
root_dir,
"copy_dir",
interrupts.ServiceRestart(["containerd",]).make_controller_input()
)

config_data = {
"package_name": "package",
"package_version": "version"
}
run_mock.assert_has_calls([
mock.call(["systemctl", "daemon-reload"], controller.get_log_file(root_dir, "interrupts/service_restart_0", "copy_dir", config_data), on_host=True, root_mount=root_dir, write_cmds=True),
mock.call(["systemctl", "restart", "containerd"], controller.get_log_file(root_dir, "interrupts/service_restart_1", "copy_dir", config_data), on_host=True, root_mount=root_dir, write_cmds=True)
Expand All @@ -1103,38 +1108,38 @@ def test_interrupt_applies_all_commands(self, run_mock, datetime_mock):
@mock.patch("skyhook_agent.controller._run")
def test_interrupt_isnt_run_when_skyhook_resource_id_flag_is_there(self, run_mock):
run_mock.return_value = 0
SKYHOOK_RESOURCE_ID="foo"
SKYHOOK_RESOURCE_ID="scr-id-1_package_version"
with (tempfile.TemporaryDirectory() as dir,
set_env(SKYHOOK_RESOURCE_ID=SKYHOOK_RESOURCE_ID)):
os.makedirs(f"{controller.get_skyhook_directory(dir)}/interrupts/flags/{SKYHOOK_RESOURCE_ID}", exist_ok=True)
with open(f"{controller.get_skyhook_directory(dir)}/interrupts/flags/{SKYHOOK_RESOURCE_ID}/node_restart_0.complete", 'w') as f:
f.write("")
controller.do_interrupt(interrupts.NodeRestart().make_controller_input(), dir, "copy_dir", on_host=False, config_data=self.config_data)
controller.do_interrupt(interrupts.NodeRestart().make_controller_input(), dir, "copy_dir", on_host=False)

run_mock.assert_not_called()

@mock.patch("skyhook_agent.controller._run")
def test_interrupt_create_flags_per_cmd(self, run_mock):
run_mock.return_value = 0
SKYHOOK_RESOURCE_ID="foo"
SKYHOOK_RESOURCE_ID="scr-id-1_package_version"
with (tempfile.TemporaryDirectory() as dir,
set_env(SKYHOOK_RESOURCE_ID=SKYHOOK_RESOURCE_ID)):
interrupt_dir = f"{controller.get_skyhook_directory(dir)}/interrupts/flags/{SKYHOOK_RESOURCE_ID}"
interrupt = interrupts.ServiceRestart(["foo", "bar"])
controller.do_interrupt(interrupt.make_controller_input(), dir, "copy_dir", on_host=False, config_data=self.config_data)
controller.do_interrupt(interrupt.make_controller_input(), dir, "copy_dir", on_host=False)

for i in range(len(interrupt.interrupt_cmd)):
self.assertTrue(os.path.exists(f"{interrupt_dir}/{interrupt._type()}_{i}.complete"))

@mock.patch("skyhook_agent.controller._run")
def test_interrupt_failures_remove_flag(self, run_mock):
run_mock.side_effect = [0,1,0]
SKYHOOK_RESOURCE_ID="foo"
SKYHOOK_RESOURCE_ID="scr-id-1_package_version"
with (tempfile.TemporaryDirectory() as dir,
set_env(SKYHOOK_RESOURCE_ID=SKYHOOK_RESOURCE_ID)):
interrupt_dir = f"{controller.get_skyhook_directory(dir)}/interrupts/flags/{SKYHOOK_RESOURCE_ID}"
interrupt = interrupts.ServiceRestart(["foo", "bar"])
controller.do_interrupt(interrupt.make_controller_input(), dir, "copy_dir", on_host=False, config_data=self.config_data)
controller.do_interrupt(interrupt.make_controller_input(), dir, "copy_dir", on_host=False)

self.assertTrue(os.path.exists(f"{interrupt_dir}/{interrupt._type()}_0.complete"))
self.assertFalse(os.path.exists(f"{interrupt_dir}/{interrupt._type()}_1.complete"))
Expand All @@ -1156,50 +1161,85 @@ def test_interrupt_failure_fails_controller(self, run_mock, datetime_mock):
],
}
with self._setup_for_main(steps) as (container_root_dir, config_data, root_dir):
result = controller.main(
Mode.INTERRUPT,
root_dir,
"/tmp",
interrupts.ServiceRestart("containerd").make_controller_input()
)

with set_env(SKYHOOK_RESOURCE_ID="scr-id-1_package_version"):
result = controller.main(
Mode.INTERRUPT,
root_dir,
"/tmp",
interrupts.ServiceRestart("containerd").make_controller_input()
)
config_data = {
"package_name": "package",
"package_version": "version"
}
run_mock.assert_has_calls([
mock.call(["systemctl", "daemon-reload"], controller.get_log_file(root_dir, "interrupts/service_restart_0", "copy_dir", config_data), on_host=True, root_mount=root_dir, write_cmds=True)
])

self.assertEqual(result, True)

@mock.patch("skyhook_agent.controller.datetime")
@mock.patch("skyhook_agent.controller._run")
def test_interrupt_makes_config_from_skyhook_resource_id(self, run_mock, datetime_mock):
now_mock = mock.MagicMock()
datetime_mock.now.return_value = now_mock
now_mock.strftime.return_value = "12345"
run_mock.return_value = 0
steps = {
Mode.APPLY: [
Step("foo.sh", arguments=[]),
],
Mode.APPLY_CHECK: [
Step("foo_check.sh", arguments=[]),
],
}
with self._setup_for_main(steps) as (container_root_dir, config_data, root_dir):
with set_env(SKYHOOK_RESOURCE_ID="scr-id-1_package_version"):
result = controller.main(
Mode.INTERRUPT,
root_dir,
"/tmp",
interrupts.ServiceRestart("containerd").make_controller_input()
)
config_data = {
"package_name": "package",
"package_version": "version"
}
run_mock.assert_has_calls([
mock.call(["systemctl", "daemon-reload"], controller.get_log_file(root_dir, "interrupts/service_restart_0", "copy_dir", config_data), on_host=True, root_mount=root_dir, write_cmds=True)
])

@mock.patch("skyhook_agent.controller.main")
def test_interrupt_mode_reads_extra_argument(self, main_mock):
argv = ["controller.py", str(Mode.INTERRUPT), "root_mount", "copy_dir", "interrupt_data"]
with set_env(COPY_RESOLVE="false"):
with set_env(COPY_RESOLV="false"):
controller.cli(argv)

main_mock.assert_called_once_with(str(Mode.INTERRUPT), "root_mount", "copy_dir", "interrupt_data", False)

@mock.patch("skyhook_agent.controller.main")
def test_cli_overlay_always_run_step_is_correct(self, main_mock):
with set_env(OVERLAY_ALWAYS_RUN_STEP="true", COPY_RESOLVE="false"):
with set_env(OVERLAY_ALWAYS_RUN_STEP="true", COPY_RESOLV="false"):
controller.cli(["controller.py", str(Mode.APPLY), "root_mount", "copy_dir"])

main_mock.assert_called_once_with(str(Mode.APPLY), "root_mount", "copy_dir", None, True)
main_mock.reset_mock()

with set_env(OVERLAY_ALWAYS_RUN_STEP="false", COPY_RESOLVE="false"):
with set_env(OVERLAY_ALWAYS_RUN_STEP="false", COPY_RESOLV="false"):
controller.cli(["controller.py", str(Mode.APPLY), "root_mount", "copy_dir"])
main_mock.assert_called_once_with(str(Mode.APPLY), "root_mount", "copy_dir", None, False)

@mock.patch("skyhook_agent.controller.main")
@mock.patch("skyhook_agent.controller.shutil")
def test_cli_copy_resolve(self, shutil_mock, main_mock):
def test_cli_COPY_RESOLV(self, shutil_mock, main_mock):
argv = ["controller.py", str(Mode.APPLY), "root_mount", "copy_dir"]
with set_env(COPY_RESOLVE="true"):
with set_env(COPY_RESOLV="true"):
controller.cli(argv)

shutil_mock.copyfile.assert_called_once()
shutil_mock.copyfile.reset_mock()

with set_env(COPY_RESOLVE="false"):
with set_env(COPY_RESOLV="false"):
controller.cli(argv)

shutil_mock.copyfile.assert_not_called()
Expand Down
Loading