diff --git a/README.md b/README.md
index ce3332b..8f5077a 100644
--- a/README.md
+++ b/README.md
@@ -13,3 +13,13 @@ AiiDA plugin for the [HyperQueue](https://github.com/It4innovations/hyperqueue)
 Allows task farming on Slurm machines through the submission of AiiDA calculations to the [HyperQueue](https://github.com/It4innovations/hyperqueue) metascheduler.
 
 See the [Documentation](http://aiida-hyperqueue.readthedocs.io/) for more information on how to install and use the plugin.
+
+## For developers
+
+Since the CLI uses the `echo` module from `aiida-core`, the CLI log level can be set through the `logging.verdi_loglevel` config option.
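+For example, to get debug output from `aiida-hq` (a sketch assuming the standard `verdi config` interface of `aiida-core`):
+
+```console
+$ verdi config set logging.verdi_loglevel DEBUG
+```
+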
diff --git a/aiida_hyperqueue/cli/__init__.py b/aiida_hyperqueue/cli/__init__.py
new file mode 100644
index 0000000..76dd0bd
--- /dev/null
+++ b/aiida_hyperqueue/cli/__init__.py
@@ -0,0 +1,8 @@
+# -*- coding: utf-8 -*-
+from aiida.cmdline.params import options as core_options
+from aiida.cmdline.params import types as core_types
+
+from .root import cmd_root
+from .install import cmd_install
+from .server import cmd_info, cmd_start, cmd_stop
+from .alloc import cmd_list, cmd_add, cmd_remove
diff --git a/aiida_hyperqueue/cli.py b/aiida_hyperqueue/cli/alloc.py
similarity index 58%
rename from aiida_hyperqueue/cli.py
rename to aiida_hyperqueue/cli/alloc.py
index dd911cb..607e655 100644
--- a/aiida_hyperqueue/cli.py
+++ b/aiida_hyperqueue/cli/alloc.py
@@ -1,65 +1,11 @@
-# -*- coding: utf-8 -*-
-"""Command line interface (CLI) for aiida_hyperqueue."""
-
 import click
 
-from aiida.cmdline.params import options, arguments
-from aiida.cmdline.utils import decorators, echo
-from aiida.cmdline.commands.cmd_data import verdi_data
-
-
-@verdi_data.group("hyperqueue")
-def data_cli():
-    """Command line interface for aiida-hyperqueue"""
+from aiida.cmdline.params import options, arguments
+from aiida.cmdline.utils import echo
 
-
-@data_cli.group("server")
-def server_group():
-    """Commands for interacting with the HQ server."""
-
-
-@server_group.command("start")
-@arguments.COMPUTER()
-@decorators.with_dbenv()
-def start_cmd(computer):
-    """Start the HyperQueue server."""
-
-    with computer.get_transport() as transport:
-        retval, _, _ = transport.exec_command_wait("hq server info")
-
-    if retval == 0:
-        echo.echo_info("server is already running!")
-        return
-
-    with computer.get_transport() as transport:
-        retval, _, stderr = transport.exec_command_wait(
-            "nohup hq server start 1>$HOME/.hq-stdout 2>$HOME/.hq-stderr &"
-        )
-
-    if retval != 0:
-        echo.echo_critical(f"unable to start the server: {stderr}")
-
-    echo.echo_success("HQ server started!")
-
-
-@server_group.command("info")
-@arguments.COMPUTER()
-@decorators.with_dbenv()
-def info_cmd(computer):
-    """Get information on the HyperQueue server."""
-
-    with computer.get_transport() as transport:
-        retval, stdout, stderr = transport.exec_command_wait("hq server info")
-
-    if retval != 0:
-        echo.echo_critical(
-            f"cannot obtain HyperQueue server information: {stderr}\n"
-            "Try starting the server with `verdi data hyperqueue server start`."
-        )
-
-    echo.echo(stdout)
-
+from .root import cmd_root
 
-@data_cli.group("alloc")
+@cmd_root.group("alloc")
 def alloc_group():
     """Commands to configure HQ allocations."""
 
@@ -102,13 +48,13 @@ def alloc_group():
     default=1,
     help=("Option to allow pooled jobs to launch on multiple nodes."),
 )
-@decorators.with_dbenv()
-def add_cmd(
+def cmd_add(
     slurm_options, computer, time_limit, hyper_threading, backlog, workers_per_alloc
 ):
     """Add a new allocation to the HQ server."""
 
-    hyper = "" if hyper_threading else "--cpus no-ht"
+    # from hq 0.13.0 onwards, ``--cpus=no-ht`` has been replaced by the ``--no-hyper-threading`` flag
+    hyper = "" if hyper_threading else "--no-hyper-threading"
 
     with computer.get_transport() as transport:
         retval, _, stderr = transport.exec_command_wait(
@@ -124,8 +70,7 @@ def add_cmd(
 
 @alloc_group.command("list")
 @arguments.COMPUTER()
-@decorators.with_dbenv()
-def list_cmd(computer):
+def cmd_list(computer):
     """List the allocations on the HQ server."""
 
     with computer.get_transport() as transport:
@@ -140,8 +85,7 @@ def list_cmd(computer):
 @alloc_group.command("remove")
 @click.argument("alloc_id")
 @options.COMPUTER(required=True)
-@decorators.with_dbenv()
-def remove_cmd(alloc_id, computer):
+def cmd_remove(alloc_id, computer):
     """Remove an allocation from the HQ server."""
 
     with computer.get_transport() as transport:
diff --git a/aiida_hyperqueue/cli/install.py b/aiida_hyperqueue/cli/install.py
new file mode 100644
index 0000000..3f83066
--- /dev/null
+++ b/aiida_hyperqueue/cli/install.py
@@ -0,0 +1,109 @@
+# -*- coding: utf-8 -*-
+import click
+import tempfile
+import requests
+import tarfile
+from pathlib import Path
+
+from aiida import orm
+from aiida.cmdline.utils import echo
+
+from .params import arguments
+from .root import cmd_root
+
+
+@cmd_root.command("install")
+@arguments.COMPUTER()
+@click.option(
+    "-p",
+    "--remote-bin-dir",
+    type=click.Path(),
+    default=Path("$HOME/bin/"),
+    help="remote path where the hq binary will be stored.",
+)
+@click.option(
+    "--hq-version", type=str, default="0.19.0", help="the hq version to install."
+)
+# TODO: separate the bashrc write and make it optional.
+# TODO: should we also support binaries for other architectures?
+def cmd_install(computer: orm.Computer, remote_bin_dir: Path, hq_version: str):
+    """Install the hq binary on the computer through its transport."""
+
+    # The minimal hq version we support is 0.13.0, check the minor version
+    try:
+        _, minor, _ = hq_version.split(".")
+    except ValueError as e:
+        echo.echo_critical(f"Cannot parse the version {hq_version}: {e}")
+    else:
+        if int(minor) < 13:
+            # ``--no-hyper-threading`` replaced ``--cpus=no-ht`` in 0.13.0.
+            # If an older version is installed, do not use ``--no-hyper-threading`` with ``aiida-hq alloc add``.
+            echo.echo_warning(
+                f"You are installing hq version {hq_version}, please do not use `--no-hyper-threading` for `aiida-hq alloc add`,"
+                " or install version >= 0.13.0 instead."
+            )
+
+    # Download the hq binary with the specified version to a local temp folder,
+    # aborting if the version is not found.
+    # Then upload it to the remote using an open transport of the computer.
+    with tempfile.TemporaryDirectory() as temp_dir:
+        url = f"https://github.com/It4innovations/hyperqueue/releases/download/v{hq_version}/hq-v{hq_version}-linux-x64.tar.gz"
+        response = requests.get(url, stream=True)
+        rcode = response.status_code
+
+        if rcode != 200:
+            echo.echo_critical(
+                "Cannot download hq, please check that the version exists."
+            )
+
+        temp_dir = Path(temp_dir)
+        tar_path = temp_dir / "hq.tar.gz"
+
+        with open(tar_path, "wb") as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
+
+        with tarfile.open(tar_path, "r") as tar:
+            tar.extractall(path=temp_dir)
+
+        echo.echo_success(f"The hq version {hq_version} binary was downloaded.")
+
+        bin_path = temp_dir / "hq"
+
+        # Upload the binary to the remote.
+        # TODO: try not to override if the binary exists; put has overwrite=True as default.
+        with computer.get_transport() as transport:
+            # Get the absolute path of the remote bin dir.
+            retval, stdout, stderr = transport.exec_command_wait(f"echo {str(remote_bin_dir)}")
+            if retval != 0:
+                echo.echo_critical(f"Unable to resolve the remote bin dir {remote_bin_dir}, exit_code={retval}")
+            else:
+                remote_bin_dir = Path(stdout.strip())
+
+            # First check whether hq already exists in the target folder.
+            if transport.isfile(str(remote_bin_dir / "hq")):
+                echo.echo_info(
+                    f"hq already exists in {remote_bin_dir} on the remote, it will be overridden."
+                )
+
+            transport.makedirs(path=remote_bin_dir, ignore_existing=True)
+            transport.put(
+                localpath=str(bin_path.resolve()), remotepath=str(remote_bin_dir)
+            )
+
+            # XXX: should transport.put take care of this already??
+            transport.exec_command_wait(f"chmod +x {str(remote_bin_dir / 'hq')}")
+
+            # Write the PATH export to the remote bashrc, guarded by an identity string so it is added only once.
+            identity_str = "by aiida-hq"
+            retval, _, stderr = transport.exec_command_wait(
+                f"grep -q '# {identity_str}' ~/.bashrc || echo '# {identity_str}\nexport PATH=$HOME/bin:$PATH' >> ~/.bashrc"
+            )
+
+            if retval != 0:
+                echo.echo_critical(
+                    f"Unable to add the path $HOME/bin to your remote bashrc, try to do it manually.\n"
+                    f"Info: {stderr}"
+                )
+
+    echo.echo_success("The hq binary was installed on the remote.")
diff --git a/aiida_hyperqueue/cli/params/__init__.py b/aiida_hyperqueue/cli/params/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/aiida_hyperqueue/cli/params/arguments.py b/aiida_hyperqueue/cli/params/arguments.py
new file mode 100644
index 0000000..cf140e3
--- /dev/null
+++ b/aiida_hyperqueue/cli/params/arguments.py
@@ -0,0 +1,4 @@
+# -*- coding: utf-8 -*-
+from aiida.cmdline.params import arguments as core_arguments
+
+COMPUTER = core_arguments.COMPUTER
diff --git a/aiida_hyperqueue/cli/params/options.py b/aiida_hyperqueue/cli/params/options.py
new file mode 100644
index 0000000..3f773be
--- /dev/null
+++ b/aiida_hyperqueue/cli/params/options.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+"""Reusable options for CLI commands."""
+
+import functools
+
+import click
+from aiida.cmdline.params import options as core_options
+from aiida.cmdline.params import types as core_types
+
+__all__ = (
+    "PROFILE",
+    "VERBOSITY",
+    "VERSION",
+)
+
+PROFILE = functools.partial(
+    core_options.PROFILE,
+    type=core_types.ProfileParamType(load_profile=True),
+    expose_value=False,
+)
+
+# Clone the ``VERBOSITY`` option from ``aiida-core`` so the ``-v`` short flag can be removed, since that overlaps with
+# the flag of the ``VERSION`` option of this CLI.
+VERBOSITY = core_options.VERBOSITY.clone()
+VERBOSITY.args = ("--verbosity",)
+
+VERSION = core_options.OverridableOption(
+    "-v",
+    "--version",
+    type=click.STRING,
+    required=False,
+    help="Select the version of the installed configuration.",
+)
diff --git a/aiida_hyperqueue/cli/root.py b/aiida_hyperqueue/cli/root.py
new file mode 100644
index 0000000..7341882
--- /dev/null
+++ b/aiida_hyperqueue/cli/root.py
@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+"""Command line interface `aiida-hq` for aiida-hyperqueue.
+The CLI implementation is prototyped after `aiida-pseudo`.
+"""
+
+import click
+
+from aiida.cmdline.groups.verdi import VerdiCommandGroup
+
+from .params import options
+
+
+class CustomVerdiCommandGroup(VerdiCommandGroup):
+    """Subclass of :class:`aiida.cmdline.groups.verdi.VerdiCommandGroup` for the CLI.
+
+    This subclass overrides the verbosity option to use a custom one that removes the ``-v`` short version of the
+    option, since that is used by other options in this CLI and so would clash.
+    """
+
+    @staticmethod
+    def add_verbosity_option(cmd):
+        """Apply the ``verbosity`` option to the command, which is common to all subcommands."""
+        if cmd is not None and "verbosity" not in [param.name for param in cmd.params]:
+            cmd = options.VERBOSITY()(cmd)
+
+        return cmd
+
+
+@click.group(
+    "aiida-hq",
+    cls=CustomVerdiCommandGroup,
+    context_settings={"help_option_names": ["-h", "--help"]},
+)
+@options.VERBOSITY()
+@options.PROFILE()
+def cmd_root():
+    """CLI for the ``aiida-hyperqueue`` plugin."""
diff --git a/aiida_hyperqueue/cli/server.py b/aiida_hyperqueue/cli/server.py
new file mode 100644
index 0000000..37babee
--- /dev/null
+++ b/aiida_hyperqueue/cli/server.py
@@ -0,0 +1,112 @@
+import click
+
+from aiida.cmdline.utils import echo
+
+from .root import cmd_root
+from .params import arguments
+
+@cmd_root.group("server")
+def server_group():
+    """Commands for interacting with the HQ server."""
+
+
+@server_group.command("start")
+@arguments.COMPUTER()
+@click.option("-d", "--domain", required=False, type=click.STRING, help="domain that will be attached to the `hostname` of the remote.")
+def cmd_start(computer, domain: str):
+    """Start the HyperQueue server."""
+
+    with computer.get_transport() as transport:
+        retval, _, _ = transport.exec_command_wait("hq server info")
+
+    if retval == 0:
+        echo.echo_info("server is already running!")
+        return
+
+    with computer.get_transport() as transport:
+        # This is mostly needed by CSCS machines, where `hostname` does not include the domain,
+        # but the fully qualified name is required to connect to the login node from a compute node.
+        # We therefore attach the domain to the hostname manually and pass it to the start command.
+
+        # Build the start command; nohup keeps the server alive after the transport disconnects.
+        start_command_lst = ["nohup", "hq", "server", "start"]
+
+        if domain is not None:
+            retval, stdout, stderr = transport.exec_command_wait(
+                "hostname"
+            )
+            if retval != 0:
+                echo.echo_critical(f"unable to get the hostname: {stderr}")
+            else:
+                hostname = stdout.strip()
+                start_command_lst.extend(["--host", f"{hostname}.{domain}"])
+
+        start_command_lst.extend(["1>$HOME/.hq-stdout", "2>$HOME/.hq-stderr", "&"])
+        start_command = " ".join(start_command_lst)
+
+        # FIXME: it is required to sleep a bit after the nohup,
+        # see https://github.com/aiidateam/aiida-core/issues/6377,
+        # but the sleep solution is incorrect, since the sleep will always return 0.
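+        # As a workaround a short ``timeout`` is passed to ``exec_command_wait`` below, so the call returns promptly.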
+        # Note that this does not rely on https://github.com/aiidateam/aiida-core/pull/6452.
+        echo.echo_debug(f"Run start command {start_command} on the remote")
+
+        retval, _, stderr = transport.exec_command_wait(
+            start_command,
+            timeout=0.1,
+        )
+
+    if retval != 0:
+        echo.echo_critical(f"unable to start the server: {stderr}")
+
+    echo.echo_success("HQ server started!")
+
+@server_group.command("stop")
+@arguments.COMPUTER()
+def cmd_stop(computer):
+    """Stop the HyperQueue server."""
+
+    with computer.get_transport() as transport:
+        retval, _, _ = transport.exec_command_wait("hq server info")
+
+    if retval != 0:
+        echo.echo_info("server is not running!")
+        return
+
+    echo.echo_info("Stopping the hq server will close all allocations.")
+
+    with computer.get_transport() as transport:
+        retval, _, stderr = transport.exec_command_wait(
+            "hq server stop"
+        )
+
+    if retval != 0:
+        echo.echo_critical(f"unable to stop the server: {stderr}")
+
+    echo.echo_success("HQ server stopped!")
+
+@server_group.command("restart")
+@arguments.COMPUTER()
+# TODO: how to pass the domain to restart?
+@click.pass_context
+def cmd_restart(ctx, computer):
+    """Restart the HyperQueue server by stopping and starting it again."""
+    ctx.forward(cmd_stop)
+    ctx.forward(cmd_start)
+
+
+@server_group.command("info")
+@arguments.COMPUTER()
+def cmd_info(computer):
+    """Get information on the HyperQueue server."""
+
+    with computer.get_transport() as transport:
+        retval, stdout, stderr = transport.exec_command_wait("hq server info")
+
+    if retval != 0:
+        echo.echo_critical(
+            f"cannot obtain HyperQueue server information: {stderr}\n"
+            "Try starting the server with `aiida-hq server start`."
+        )
+
+    echo.echo(stdout)
diff --git a/pyproject.toml b/pyproject.toml
index e599419..6bec36f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,8 +52,8 @@ pre-commit = [
 [project.entry-points.'aiida.schedulers']
 "hyperqueue" = "aiida_hyperqueue.scheduler:HyperQueueScheduler"
 
-[project.entry-points.'aiida.cmdline.data']
-"hyperqueue" = "aiida_hyperqueue.cli:data_cli"
+[project.scripts]
+aiida-hq = 'aiida_hyperqueue.cli:cmd_root'
 
 [tool.pytest.ini_options]
 python_files = "test_*.py example_*.py"
diff --git a/tests/conftest.py b/tests/conftest.py
index 26ba2f0..7caa940 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 import contextlib
 import json
 import os
@@ -20,6 +21,7 @@ ROOT_DIR = PYTEST_DIR.parent
 
 BIN_DIR = ROOT_DIR / "target"
 
+
 def pytest_sessionstart(session):
     # Download the hq binary before session start
     if (BIN_DIR / "hq").exists():
@@ -34,6 +36,7 @@ def pytest_sessionstart(session):
     if result.returncode != 0:
         raise Exception(f"Installation failed with return code {result.returncode}")
 
+
 def get_hq_binary(debug):
     return BIN_DIR / "hq"
 
@@ -53,7 +56,9 @@ def __init__(self, work_path):
         self.processes = []
         self.work_path = work_path
 
-    def start_process(self, name, args, env=None, catch_io=True, cwd=None, final_check=True):
+    def start_process(
+        self, name, args, env=None, catch_io=True, cwd=None, final_check=True
+    ):
         cwd = str(cwd or self.work_path)
         logfile = (self.work_path / name).with_suffix(".out")
         print(f"Starting process {name} with logfile {logfile}")
@@ -69,7 +74,9 @@ def start_process(self, name, args, env=None, catch_io=True, cwd=None, final_che
             )
         else:
             p = subprocess.Popen(args, cwd=cwd, env=env)
-        self.processes.append(ManagedProcess(name=name, process=p, final_check=final_check))
+        self.processes.append(
+            ManagedProcess(name=name, process=p, final_check=final_check)
+        )
         return p
 
    def check_process_exited(self, process: subprocess.Popen, expected_code=0):
@@ -83,7 +90,9 @@ def is_process_alive():
             elif expected_code is not None:
                 assert process.returncode == expected_code
 
-                self.processes = [p for p in self.processes if p.process is not process]
+                self.processes = [
+                    p for p in self.processes if p.process is not process
+                ]
                 return False
 
         raise Exception(f"Process with pid {process.pid} not found")
@@ -93,7 +102,11 @@ def check_running_processes(self):
         """Checks that everything is still running"""
         for p in self.processes:
             if p.final_check and p.process.poll() is not None:
-                raise Exception("Process {0} crashed (log in {1}/{0}.out)".format(p.name, self.work_path))
+                raise Exception(
+                    "Process {0} crashed (log in {1}/{0}.out)".format(
+                        p.name, self.work_path
+                    )
+                )
 
     def kill_all(self):
         self.sort_processes_for_kill()
@@ -163,7 +176,9 @@ def server_args(server_dir="hq-server", debug=True):
         args += ["server", "start"]
         return args
 
-    def start_server(self, server_dir="hq-server", args=None, env=None) -> subprocess.Popen:
+    def start_server(
+        self, server_dir="hq-server", args=None, env=None
+    ) -> subprocess.Popen:
         self.server_dir = os.path.join(self.work_path, server_dir)
         environment = self.make_default_env()
         if env:
@@ -222,7 +237,9 @@ def start_worker(
             worker_args += ["--cpus", str(cpus)]
         if args:
             worker_args += list(args)
-        r = self.start_process(hostname, worker_args, final_check=final_check, env=worker_env)
+        r = self.start_process(
+            hostname, worker_args, final_check=final_check, env=worker_env
+        )
 
         if wait_for_start:
             print(wait_for_start)
@@ -313,11 +330,15 @@ def command(
         if process.returncode != 0:
             if expect_fail:
                 if expect_fail not in stdout:
-                    raise Exception(f"Command should failed with message '{expect_fail}' but got:\n{stdout}")
+                    raise Exception(
+                        f"Command should have failed with message '{expect_fail}' but got:\n{stdout}"
+                    )
                 else:
                     return
             print(f"Process output: {stdout}")
-            raise Exception(f"Process failed with exit-code {process.returncode}\n\n{stdout}")
+            raise Exception(
+                f"Process failed with exit-code {process.returncode}\n\n{stdout}"
+            )
         if expect_fail is not None:
             raise Exception("Command should failed")
         if as_table:
diff --git a/tests/test_cli.py b/tests/test_cli.py
new file mode 100644
index 0000000..b05b025
--- /dev/null
+++ b/tests/test_cli.py
@@ -0,0 +1,2 @@
+import pytest
+from click.testing import CliRunner
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index e582553..6b9697e 100644
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 """Tests for command line interface."""
+
 import pytest
 import datetime
 import logging
@@ -16,30 +17,31 @@ from .conftest import HqEnv
 from .utils import wait_for_job_state
 
+
 @pytest.fixture
 def valid_submit_script():
     scheduler = HyperQueueScheduler()
 
     job_tmpl = JobTemplate()
-    job_tmpl.job_name = 'echo hello'
-    job_tmpl.shebang = '#!/bin/bash'
+    job_tmpl.job_name = "echo hello"
+    job_tmpl.shebang = "#!/bin/bash"
     job_tmpl.uuid = str(uuid.uuid4())
     job_tmpl.job_resource = scheduler.create_job_resource(num_cpus=1)
     job_tmpl.max_wallclock_seconds = 24 * 3600
     tmpl_code_info = JobTemplateCodeInfo()
-    tmpl_code_info.cmdline_params = ['echo', 'Hello']
+    tmpl_code_info.cmdline_params = ["echo", "Hello"]
     job_tmpl.codes_info = [tmpl_code_info]
     job_tmpl.codes_run_mode = CodeRunMode.SERIAL
 
     return scheduler.get_submit_script(job_tmpl)
 
+
 def test_resource_validation():
     """Tests to verify that resources are correctly validated."""
    resource = HyperQueueJobResource(num_cpus=16, memory_mb=20)
     assert resource.num_cpus == 16
     assert resource.memory_mb == 20
 
-    # If memory_mb not set, the default value 0 will be assigned
     resource = HyperQueueJobResource(num_cpus=16)
     assert resource.num_cpus == 16
@@ -48,29 +50,31 @@ def test_resource_validation():
     # raise if num_cpus is not set
     with pytest.raises(
         KeyError,
-        match='Must specify `num_cpus`',
+        match="Must specify `num_cpus`",
     ):
         HyperQueueJobResource()
 
     # raise if num_cpus is not integer
     with pytest.raises(
         ValueError,
-        match='`num_cpus` must be an integer',
+        match="`num_cpus` must be an integer",
     ):
         HyperQueueJobResource(num_cpus=1.2)
 
     # raise if memory_mb is not integer
     with pytest.raises(
         ValueError,
-        match='`memory_mb` must be an integer',
+        match="`memory_mb` must be an integer",
     ):
         HyperQueueJobResource(num_cpus=4, memory_mb=1.2)
 
+
 def test_submit_command():
     """Test submit command"""
     scheduler = HyperQueueScheduler()
 
-    assert scheduler._get_submit_command('job.sh') == "hq submit job.sh"
+    assert scheduler._get_submit_command("job.sh") == "hq submit job.sh"
+
 
 def test_parse_submit_command_output(hq_env: HqEnv, valid_submit_script):
     """Test parsing the output of submit command"""
@@ -78,7 +82,9 @@ def test_parse_submit_command_output(hq_env: HqEnv, valid_submit_script):
     hq_env.start_worker(cpus="2")
 
     Path("_aiidasubmit.sh").write_text(valid_submit_script)
-    process = hq_env.command(["submit", "_aiidasubmit.sh"], wait=False, ignore_stderr=True)
+    process = hq_env.command(
+        ["submit", "_aiidasubmit.sh"], wait=False, ignore_stderr=True
+    )
     stdout = process.communicate()[0].decode()
     stderr = ""
     retval = process.returncode
@@ -90,6 +96,7 @@ def test_parse_submit_command_output(hq_env: HqEnv, valid_submit_script):
 
     assert job_id == "1"
 
+
 def test_submit_script():
     """Test the creation of a simple submission script."""
     VALID_SCRIPT_CONTENT = """#!/bin/bash
@@ -104,32 +111,33 @@ def test_submit_script():
     scheduler = HyperQueueScheduler()
 
     job_tmpl = JobTemplate()
-    job_tmpl.shebang = '#!/bin/bash'
+    job_tmpl.shebang = "#!/bin/bash"
     job_tmpl.uuid = str(uuid.uuid4())
     job_tmpl.job_resource = scheduler.create_job_resource(num_cpus=2, memory_mb=256)
     job_tmpl.max_wallclock_seconds = 24 * 3600
     tmpl_code_info = JobTemplateCodeInfo()
-    tmpl_code_info.cmdline_params = ['mpirun', '-np', '4', 'pw.x', '-npool', '1']
-    tmpl_code_info.stdin_name = 'aiida.in'
+    tmpl_code_info.cmdline_params = ["mpirun", "-np", "4", "pw.x", "-npool", "1"]
+    tmpl_code_info.stdin_name = "aiida.in"
     job_tmpl.codes_info = [tmpl_code_info]
     job_tmpl.codes_run_mode = CodeRunMode.SERIAL
 
     submit_script_text = scheduler.get_submit_script(job_tmpl)
-    
+
     assert submit_script_text == VALID_SCRIPT_CONTENT
 
+
 def test_submit_script_mem_not_specified():
     """Test if memory_mb not pass to resource, it will not specified in job script"""
     scheduler = HyperQueueScheduler()
 
     job_tmpl = JobTemplate()
-    job_tmpl.shebang = '#!/bin/bash'
+    job_tmpl.shebang = "#!/bin/bash"
     job_tmpl.uuid = str(uuid.uuid4())
     job_tmpl.job_resource = scheduler.create_job_resource(num_cpus=2)
     job_tmpl.max_wallclock_seconds = 24 * 3600
     tmpl_code_info = JobTemplateCodeInfo()
-    tmpl_code_info.cmdline_params = ['mpirun', '-np', '4', 'pw.x', '-npool', '1']
-    tmpl_code_info.stdin_name = 'aiida.in'
+    tmpl_code_info.cmdline_params = ["mpirun", "-np", "4", "pw.x", "-npool", "1"]
+    tmpl_code_info.stdin_name = "aiida.in"
     job_tmpl.codes_info = [tmpl_code_info]
     job_tmpl.codes_run_mode = CodeRunMode.SERIAL
 
@@ -137,6 +145,7 @@
     assert "#HQ --resource mem" not in submit_script_text
 
+
 def test_submit_script_is_hq_valid(hq_env: HqEnv, valid_submit_script):
     """The generated script can actually be run by hq"""
     hq_env.start_server()
@@ -149,10 +158,11 @@ def test_submit_script_is_hq_valid(hq_env: HqEnv, valid_submit_script):
     assert "cpus: 1 compact" in table.get_row_value("Resources")
 
     wait_for_job_state(hq_env, 1, "FINISHED")
-    
+
     assert table.get_row_value("State") == "FINISHED"
 
-#class TestJoblistCommand:
+
+# class TestJoblistCommand:
 #    """Tests of the issued squeue command."""
 #
 #    def test_joblist_single(self):
@@ -171,7 +181,7 @@ def test_submit_script_is_hq_valid(hq_env: HqEnv, valid_submit_script):
 #        assert '456,456' not in command
 #
 #
-#def test_parse_out_of_memory():
+# def test_parse_out_of_memory():
 #    """Test that for job that failed due to OOM `parse_output` return the `ERROR_SCHEDULER_OUT_OF_MEMORY` code."""
 #    scheduler = SlurmScheduler()
 #    stdout = ''
@@ -186,7 +196,7 @@ def test_submit_script_is_hq_valid(hq_env: HqEnv, valid_submit_script):
 #    assert exit_code == CalcJob.exit_codes.ERROR_SCHEDULER_OUT_OF_MEMORY
 #
 #
-#def test_parse_node_failure():
+# def test_parse_node_failure():
 #    """Test that `ERROR_SCHEDULER_NODE_FAILURE` code is returned if `STATE == NODE_FAIL`."""
 #    scheduler = SlurmScheduler()
 #    detailed_job_info = {
@@ -199,7 +209,7 @@ def test_submit_script_is_hq_valid(hq_env: HqEnv, valid_submit_script):
 #    assert exit_code == CalcJob.exit_codes.ERROR_SCHEDULER_NODE_FAILURE
 #
 #
-#@pytest.mark.parametrize(
+# @pytest.mark.parametrize(
 #    'detailed_job_info, expected',
 #    [
 #        ('string', TypeError),  # Not a dictionary
@@ -211,8 +221,8 @@ def test_submit_script_is_hq_valid(hq_env: HqEnv, valid_submit_script):
 #            ValueError,
 #        ),  # `stdout` second line contains too few elements separated by pipe
 #    ],
-#)
-#def test_parse_output_invalid(detailed_job_info, expected):
+# )
+# def test_parse_output_invalid(detailed_job_info, expected):
 #    """Test `SlurmScheduler.parse_output` for various invalid arguments."""
 #    scheduler = SlurmScheduler()
 #
@@ -220,21 +230,21 @@ def test_submit_script_is_hq_valid(hq_env: HqEnv, valid_submit_script):
 #        scheduler.parse_output(detailed_job_info, '', '')
 #
 #
-#def test_parse_output_valid():
+# def test_parse_output_valid():
 #    """Test `SlurmScheduler.parse_output` for valid arguments."""
 #    detailed_job_info = {'stdout': 'State|Account|\n||\n'}
 #    scheduler = SlurmScheduler()
 #    assert scheduler.parse_output(detailed_job_info, '', '') is None
 #
 #
-#def test_parse_submit_output_invalid_account():
+# def test_parse_submit_output_invalid_account():
 #    """Test ``SlurmScheduler._parse_submit_output`` returns exit code if stderr contains error about invalid account."""
 #    scheduler = SlurmScheduler()
 #    stderr = 'Batch job submission failed: Invalid account or account/partition combination specified'
 #    result = scheduler._parse_submit_output(1, '', stderr)
 #    assert result == CalcJob.exit_codes.ERROR_SCHEDULER_INVALID_ACCOUNT
 
-#def test_parse_common_joblist_output():
+# def test_parse_common_joblist_output():
 #    """Test whether _parse_joblist_output can parse the squeue output"""
 #    scheduler = SlurmScheduler()
 #
@@ -290,7 +300,7 @@ def test_submit_script_is_hq_valid(hq_env: HqEnv, valid_submit_script):
 #        # self.assertTrue( j.num_machines==num_machines )
 #        # self.assertTrue( j.num_mpiprocs==num_mpiprocs )
 #
-#def test_parse_failed_squeue_output(self):
+# def test_parse_failed_squeue_output(self):
 #    """Test that _parse_joblist_output reacts as expected to failures."""
 #    scheduler = SlurmScheduler()
 #
@@ -303,7 +313,7 @@ def test_submit_script_is_hq_valid(hq_env: HqEnv, valid_submit_script):
 #        scheduler._parse_joblist_output(0, TEXT_SQUEUE_TO_TEST, 'error message')
 #
 #
-#@pytest.mark.parametrize(
+# @pytest.mark.parametrize(
 #    'value,expected',
 #    [
 #        ('2', 2 * 60),
@@ -324,8 +334,8 @@ def test_submit_script_is_hq_valid(hq_env: HqEnv, valid_submit_script):
 #        ('UNLIMITED', 2**31 - 1),
 #        ('NOT_SET', None),
 #    ],
-#)
-#def test_time_conversion(value, expected):
+# )
+# def test_time_conversion(value, expected):
 #    """Test conversion of (relative) times.
 #
 #    From docs, acceptable time formats include
@@ -336,7 +346,7 @@ def test_submit_script_is_hq_valid(hq_env: HqEnv, valid_submit_script):
 #    assert scheduler._convert_time(value) == expected
 #
 #
-#def test_time_conversion_errors(caplog):
+# def test_time_conversion_errors(caplog):
 #    """Test conversion of (relative) times for bad inputs."""
 #    scheduler = SlurmScheduler()
 #
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
index fc9b2f3..a7c557a 100644
--- a/tests/utils/__init__.py
+++ b/tests/utils/__init__.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 from .table import JOB_TABLE_ROWS, parse_table, parse_tables
 from .wait import wait_for_job_state, wait_for_worker_state
 
diff --git a/tests/utils/cmd.py b/tests/utils/cmd.py
index b1aad3e..7794397 100644
--- a/tests/utils/cmd.py
+++ b/tests/utils/cmd.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 from typing import List
 
 
diff --git a/tests/utils/io.py b/tests/utils/io.py
index 47161be..f5b0e9e 100644
--- a/tests/utils/io.py
+++ b/tests/utils/io.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 import socket
 from contextlib import closing
 
diff --git a/tests/utils/job.py b/tests/utils/job.py
index 5850319..d24fa40 100644
--- a/tests/utils/job.py
+++ b/tests/utils/job.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 import os
 from typing import List, Optional
 
@@ -5,7 +6,9 @@ from .table import Table
 
 
-def default_task_output(job_id=1, task_id=0, type="stdout", working_dir: Optional[str] = None) -> str:
+def default_task_output(
+    job_id=1, task_id=0, type="stdout", working_dir: Optional[str] = None
+) -> str:
     working_dir = working_dir if working_dir else os.getcwd()
     return f"{working_dir}/job-{job_id}/{task_id}.{type}"
 
diff --git a/tests/utils/mock.py b/tests/utils/mock.py
index 6e00806..4371cdf 100644
--- a/tests/utils/mock.py
+++ b/tests/utils/mock.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 import contextlib
 import os
 import sys
diff --git a/tests/utils/table.py b/tests/utils/table.py
index 7b57b76..e89526a 100644
--- a/tests/utils/table.py
+++ b/tests/utils/table.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 from typing import Dict, List, Optional
 
 JOB_TABLE_ROWS = 16
@@ -103,7 +104,9 @@ def parse_table(table_info):
             current_rows = []
             divider_count += 1
             # End early
-            if ((divider_count == 3) or (divider_count == 2 and header is None)) and (i + 1) < len(lines):
+            if ((divider_count == 3) or (divider_count == 2 and header is None)) and (
+                i + 1
+            ) < len(lines):
                 return Table(rows, header=header), lines[(i + 1) :]
             continue
         items = [x.strip() for x in line.split("|")[1:-1]]
diff --git a/tests/utils/wait.py b/tests/utils/wait.py
index b3a2b88..400f645 100644
--- a/tests/utils/wait.py
+++ b/tests/utils/wait.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 import time
 from typing import List, Union
 
@@ -47,7 +48,9 @@ def check():
         table = env.command(commands, as_table=True)
         last_table = table
         items = [row for row in table if row[0] in ids]
-        return len(items) >= len(ids) and all(j[state_index].lower() in target_states for j in items)
+        return len(items) >= len(ids) and all(
+            j[state_index].lower() in target_states for j in items
+        )
 
     try:
         wait_until(check, **kwargs)
@@ -57,11 +60,15 @@ def check():
         raise e
 
 
-def wait_for_job_state(env, ids: Union[int, List[int]], target_states: Union[str, List[str]], **kwargs):
+def wait_for_job_state(
+    env, ids: Union[int, List[int]], target_states: Union[str, List[str]], **kwargs
+):
     wait_for_state(env, ids, target_states, ["job", "list", "--all"], 2, **kwargs)
 
 
-def wait_for_worker_state(env, ids: Union[int, List[int]], target_states: Union[str, List[str]], **kwargs):
+def wait_for_worker_state(
+    env, ids: Union[int, List[int]], target_states: Union[str, List[str]], **kwargs
+):
     wait_for_state(env, ids, target_states, ["worker", "list", "--all"], 1, **kwargs)
 
 
@@ -73,4 +80,6 @@ def wait_for_pid_exit(pid: int):
 
 
 def wait_for_job_list_count(env, count: int):
-    wait_until(lambda: len(env.command(["job", "list", "--all"], as_table=True)) == count)
+    wait_until(
+        lambda: len(env.command(["job", "list", "--all"], as_table=True)) == count
+    )