feat(service): use slurm sinfo command to improve "cluster load" indicator #1664

Merged (7 commits) on Feb 2, 2024
2 changes: 2 additions & 0 deletions antarest/core/config.py
@@ -254,6 +254,7 @@ class SlurmConfig:
    default_time_limit: int = 0
    default_json_db_name: str = ""
    slurm_script_path: str = ""
+   partition: str = ""
    max_cores: int = 64
    antares_versions_on_remote_server: List[str] = field(default_factory=list)
    enable_nb_cores_detection: bool = False
@@ -290,6 +291,7 @@ def from_dict(cls, data: JSON) -> "SlurmConfig":
            default_time_limit=data.get("default_time_limit", defaults.default_time_limit),
            default_json_db_name=data.get("default_json_db_name", defaults.default_json_db_name),
            slurm_script_path=data.get("slurm_script_path", defaults.slurm_script_path),
+           partition=data.get("partition", defaults.partition),
            antares_versions_on_remote_server=data.get(
                "antares_versions_on_remote_server",
                defaults.antares_versions_on_remote_server,
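For context, the new `partition` key flows from the application configuration down to the `sinfo`/`squeue` calls introduced below. A minimal sketch of reading it through `SlurmConfig.from_dict`; the dictionary values are invented for illustration, and the remaining keys are assumed to be optional with defaults:

from antarest.core.config import SlurmConfig

# Hypothetical launcher configuration; only "partition" is new in this PR.
data = {
    "slurm_script_path": "/applis/antares/launchAntares.sh",
    "partition": "compute1",
    "antares_versions_on_remote_server": ["8.6"],
}
config = SlurmConfig.from_dict(data)
assert config.partition == "compute1"  # falls back to "" when the key is absent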
@@ -182,6 +182,7 @@ def _init_launcher_parameters(self, local_workspace: Optional[Path] = None) -> M
            json_dir=local_workspace or self.slurm_config.local_workspace,
            default_json_db_name=self.slurm_config.default_json_db_name,
            slurm_script_path=self.slurm_config.slurm_script_path,
+           partition=self.slurm_config.partition,
            antares_versions_on_remote_server=self.slurm_config.antares_versions_on_remote_server,
            default_ssh_dict={
                "username": self.slurm_config.username,
53 changes: 52 additions & 1 deletion antarest/launcher/model.py
@@ -1,12 +1,14 @@
import enum
+import json
import typing as t
from datetime import datetime

-from pydantic import BaseModel
+from pydantic import BaseModel, Field
from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, Sequence, String  # type: ignore
from sqlalchemy.orm import relationship  # type: ignore

from antarest.core.persistence import Base
+from antarest.core.utils.string import to_camel_case
from antarest.login.model import Identity, UserInfo


@@ -32,6 +34,15 @@ class LauncherParametersDTO(BaseModel):
    other_options: t.Optional[str] = None
    # add extensions field here

+    @classmethod
+    def from_launcher_params(cls, params: t.Optional[str]) -> "LauncherParametersDTO":
+        """
+        Convert the launcher parameters from a JSON string to a `LauncherParametersDTO` object.
+        """
+        if params is None:
+            return cls()
+        return cls.parse_obj(json.loads(params))


class LogType(str, enum.Enum):
    STDOUT = "STDOUT"

@@ -214,3 +225,43 @@ class JobCreationDTO(BaseModel):

class LauncherEnginesDTO(BaseModel):
    engines: t.List[str]


+class LauncherLoadDTO(
+    BaseModel,
+    extra="forbid",
+    validate_assignment=True,
+    allow_population_by_field_name=True,
+    alias_generator=to_camel_case,
+):
+    """
+    DTO representing the load of the SLURM cluster or local machine.
+
+    Attributes:
+        allocated_cpu_rate: The rate of allocated CPUs, in range [0, 100].
+        cluster_load_rate: The rate of cluster load, in range [0, 100].
+        nb_queued_jobs: The number of queued jobs.
+        launcher_status: The status of the launcher: "SUCCESS" or "FAILED".
+    """
+
+    allocated_cpu_rate: float = Field(
+        description="The rate of allocated CPUs, in range [0, 100]",
+        ge=0,
+        le=100,
+        title="Allocated CPU Rate",
+    )
+    cluster_load_rate: float = Field(
+        description="The rate of cluster load, in range [0, 100]",
+        ge=0,
+        le=100,
+        title="Cluster Load Rate",
+    )
+    nb_queued_jobs: int = Field(
+        description="The number of queued jobs",
+        ge=0,
+        title="Number of Queued Jobs",
+    )
+    launcher_status: str = Field(
+        description="The status of the launcher: 'SUCCESS' or 'FAILED'",
+        title="Launcher Status",
+    )
78 changes: 42 additions & 36 deletions antarest/launcher/service.py
@@ -1,5 +1,4 @@
import functools
-import json
import logging
import os
import shutil
@@ -33,11 +32,14 @@
    JobLogType,
    JobResult,
    JobStatus,
+    LauncherLoadDTO,
    LauncherParametersDTO,
    LogType,
    XpansionParametersDTO,
)
from antarest.launcher.repository import JobResultRepository
+from antarest.launcher.ssh_client import calculates_slurm_load
+from antarest.launcher.ssh_config import SSHConfigDTO
from antarest.study.repository import StudyFilter
from antarest.study.service import StudyService
from antarest.study.storage.utils import assert_permission, extract_output_name, find_single_output_path
@@ -502,7 +504,7 @@ def _import_output(
        launching_user = DEFAULT_ADMIN_USER

        study_id = job_result.study_id
-        job_launch_params = LauncherParametersDTO.parse_raw(job_result.launcher_params or "{}")
+        job_launch_params = LauncherParametersDTO.from_launcher_params(job_result.launcher_params)

        # this can now be a zip file instead of a directory!
        output_true_path = find_single_output_path(output_path)
@@ -585,7 +587,7 @@ def _download_fallback_output(self, job_id: str, params: RequestParameters) -> F
        export_path = Path(export_file_download.path)
        export_id = export_file_download.id

-        def export_task(notifier: TaskUpdateNotifier) -> TaskResult:
+        def export_task(_: TaskUpdateNotifier) -> TaskResult:
            try:
                #
                zip_dir(output_path, export_path)
@@ -622,43 +624,47 @@ def download_output(self, job_id: str, params: RequestParameters) -> FileDownloa
        )
        raise JobNotFound()

-    def get_load(self, from_cluster: bool = False) -> Dict[str, float]:
-        all_running_jobs = self.job_result_repository.get_running()
-        local_running_jobs = []
-        slurm_running_jobs = []
-        for job in all_running_jobs:
-            if job.launcher == "slurm":
-                slurm_running_jobs.append(job)
-            elif job.launcher == "local":
-                local_running_jobs.append(job)
-            else:
-                logger.warning(f"Unknown job launcher {job.launcher}")
-
-        load = {}
-
-        slurm_config = self.config.launcher.slurm
-        if slurm_config is not None:
-            if from_cluster:
-                raise NotImplementedError("Cluster load not implemented yet")
-            default_cpu = slurm_config.nb_cores.default
-            slurm_used_cpus = 0
-            for job in slurm_running_jobs:
-                obj = json.loads(job.launcher_params) if job.launcher_params else {}
-                launch_params = LauncherParametersDTO(**obj)
-                slurm_used_cpus += launch_params.nb_cpu or default_cpu
-            load["slurm"] = slurm_used_cpus / slurm_config.max_cores
-
-        local_config = self.config.launcher.local
-        if local_config is not None:
-            default_cpu = local_config.nb_cores.default
-            local_used_cpus = 0
-            for job in local_running_jobs:
-                obj = json.loads(job.launcher_params) if job.launcher_params else {}
-                launch_params = LauncherParametersDTO(**obj)
-                local_used_cpus += launch_params.nb_cpu or default_cpu
-            load["local"] = local_used_cpus / local_config.nb_cores.max
-
-        return load
+    def get_load(self) -> LauncherLoadDTO:
+        """
+        Get the load of the SLURM cluster or the local machine.
+        """
+        # SLURM load calculation
+        if self.config.launcher.default == "slurm":
+            if slurm_config := self.config.launcher.slurm:
+                ssh_config = SSHConfigDTO(
+                    config_path=Path(),
+                    username=slurm_config.username,
+                    hostname=slurm_config.hostname,
+                    port=slurm_config.port,
+                    private_key_file=slurm_config.private_key_file,
+                    key_password=slurm_config.key_password,
+                    password=slurm_config.password,
+                )
+                partition = slurm_config.partition
+                allocated_cpus, cluster_load, queued_jobs = calculates_slurm_load(ssh_config, partition)
+                return LauncherLoadDTO(
+                    allocated_cpu_rate=allocated_cpus,
+                    cluster_load_rate=cluster_load,
+                    nb_queued_jobs=queued_jobs,
+                    launcher_status="SUCCESS",
+                )
+            else:
+                raise KeyError("Default launcher is slurm but it is not registered in the config file")
+
+        # local load calculation
+        local_used_cpus = sum(
+            LauncherParametersDTO.from_launcher_params(job.launcher_params).nb_cpu or 1
+            for job in self.job_result_repository.get_running()
+        )
+        cluster_load_approx = min(1.0, local_used_cpus / (os.cpu_count() or 1))

+        return LauncherLoadDTO(
+            # scale the [0, 1] ratio to the DTO's [0, 100] range
+            allocated_cpu_rate=100 * cluster_load_approx,
+            cluster_load_rate=100 * cluster_load_approx,
+            nb_queued_jobs=0,
+            launcher_status="SUCCESS",
+        )

    def get_solver_versions(self, solver: str) -> List[str]:
        """
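The local fallback is a deliberate approximation: it sums the `nb_cpu` requested by each running job (counting 1 when unspecified), divides by the machine's core count, and caps the ratio at 1.0 before scaling it to a percentage. A worked example with invented numbers:

# Three running local jobs requested nb_cpu = 4, 2 and None (counted as 1):
local_used_cpus = 4 + 2 + 1
cpu_count = 8  # stand-in for os.cpu_count() on an 8-core machine
cluster_load_approx = min(1.0, local_used_cpus / cpu_count)  # 0.875
print(100 * cluster_load_approx)  # 87.5, reported for both rate fields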
106 changes: 106 additions & 0 deletions antarest/launcher/ssh_client.py
@@ -0,0 +1,106 @@
import contextlib
import shlex
import socket
from typing import Any, Iterator, List, Tuple

import paramiko

from antarest.launcher.ssh_config import SSHConfigDTO


@contextlib.contextmanager
def ssh_client(ssh_config: SSHConfigDTO) -> Iterator[paramiko.SSHClient]:
    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    client.connect(
        hostname=ssh_config.hostname,
        port=ssh_config.port,
        username=ssh_config.username,
        # the config may provide a private key, a password, or both
        pkey=paramiko.RSAKey.from_private_key_file(filename=str(ssh_config.private_key_file))
        if ssh_config.private_key_file
        else None,
        password=ssh_config.password or None,
        timeout=600,
        allow_agent=False,
    )
    with contextlib.closing(client):
        yield client


class SlurmError(Exception):
    pass


def execute_command(ssh_config: SSHConfigDTO, args: List[str]) -> Any:
    command = " ".join(args)
    try:
        with ssh_client(ssh_config) as client:
            _, stdout, stderr = client.exec_command(command, timeout=10)
            output = stdout.read().decode("utf-8").strip()
            error = stderr.read().decode("utf-8").strip()
    except (
        paramiko.AuthenticationException,
        paramiko.SSHException,
        socket.timeout,
        socket.error,
    ) as e:
        raise SlurmError(f"Can't retrieve SLURM information: {e}") from e
    if error:
        raise SlurmError(f"Can't retrieve SLURM information: {error}")
    return output


def parse_cpu_used(sinfo_output: str) -> float:
    """
    Returns the percentage of used CPUs in the cluster, in range [0, 100].
    """
    cpu_info_split = sinfo_output.split("/")
    cpu_used_count = int(cpu_info_split[0])
    cpu_inactive_count = int(cpu_info_split[1])
    return 100 * cpu_used_count / (cpu_used_count + cpu_inactive_count)


def parse_cpu_load(sinfo_output: str) -> float:
    """
    Returns the percentage of CPU load in the cluster, in range [0, 100].
    """
    lines = sinfo_output.splitlines()
    cpus_used = 0.0
    cpus_available = 0.0
    for line in lines:
        values = line.split()
        if "N/A" in values:
            continue
        cpus_used += float(values[0])
        cpus_available += float(values[1])
    ratio = cpus_used / max(cpus_available, 1)
    return 100 * min(1.0, ratio)


def calculates_slurm_load(ssh_config: SSHConfigDTO, partition: str) -> Tuple[float, float, int]:
    """
    Returns the allocated CPU rate and the load of the SLURM cluster
    (both in percent), together with the number of queued jobs.
    """
    partition_arg = f"--partition={partition}" if partition else ""

    # allocated cpus
    arg_list = ["sinfo", partition_arg, "-O", "NodeAIOT", "--noheader"]
    sinfo_cpus_used = execute_command(ssh_config, arg_list)
    if not sinfo_cpus_used:
        args = " ".join(map(shlex.quote, arg_list))
        raise SlurmError(f"Can't retrieve SLURM information: [{args}] returned no result")
    allocated_cpus = parse_cpu_used(sinfo_cpus_used)

    # cluster load
    arg_list = ["sinfo", partition_arg, "-N", "-O", "CPUsLoad,CPUs", "--noheader"]
    sinfo_cpus_load = execute_command(ssh_config, arg_list)
    if not sinfo_cpus_load:
        args = " ".join(map(shlex.quote, arg_list))
        raise SlurmError(f"Can't retrieve SLURM information: [{args}] returned no result")
    cluster_load = parse_cpu_load(sinfo_cpus_load)

    # queued jobs
    arg_list = ["squeue", partition_arg, "--noheader", "-t", "pending", "|", "wc", "-l"]
    queued_jobs = execute_command(ssh_config, arg_list)
    if not queued_jobs:
        args = " ".join(map(shlex.quote, arg_list))
        raise SlurmError(f"Can't retrieve SLURM information: [{args}] returned no result")

    return allocated_cpus, cluster_load, int(queued_jobs)
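The two parsers consume the raw text of the `sinfo` output formats requested above. A small sketch with fabricated command outputs:

from antarest.launcher.ssh_client import parse_cpu_load, parse_cpu_used

# "-O NodeAIOT" prints "allocated/idle/other/total" counts on a single line:
assert parse_cpu_used("24/8/0/32") == 75.0  # 24 allocated out of 24 + 8 idle

# "-N -O CPUsLoad,CPUs" prints one "<load> <cpus>" pair per node;
# nodes reporting "N/A" are skipped:
print(parse_cpu_load("12.5 16\nN/A 16\n8.0 16"))  # 100 * 20.5 / 32 = 64.0625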
21 changes: 21 additions & 0 deletions antarest/launcher/ssh_config.py
@@ -0,0 +1,21 @@
import pathlib
from typing import Any, Dict, Optional

import paramiko
from pydantic import BaseModel, root_validator


class SSHConfigDTO(BaseModel):
    config_path: pathlib.Path
    username: str
    hostname: str
    port: int = 22
    private_key_file: Optional[pathlib.Path] = None
    key_password: Optional[str] = ""
    password: Optional[str] = ""

    @root_validator()
    def validate_connection_information(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # `values` holds every declared field (with defaults applied), so check
        # the actual values rather than key membership.
        if values.get("private_key_file") is None and not values.get("password"):
            raise paramiko.AuthenticationException("SSH config needs at least a private key or a password")
        return values
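Finally, a usage sketch of the new SSH configuration DTO; the hostname and key path are invented:

import pathlib
from antarest.launcher.ssh_config import SSHConfigDTO

ssh_config = SSHConfigDTO(
    config_path=pathlib.Path(),
    username="antares",
    hostname="slurm-head.example.com",
    private_key_file=pathlib.Path("/home/antares/.ssh/id_rsa"),
)

# Omitting both the private key and the password is rejected by the root validator:
try:
    SSHConfigDTO(config_path=pathlib.Path(), username="antares", hostname="slurm-head.example.com")
except Exception as exc:
    print(exc)  # SSH config needs at least a private key or a password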