Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raise error for inconsistent add_ml_model and add_script parameters #324

Merged
merged 17 commits into from
Jul 29, 2023
7 changes: 2 additions & 5 deletions smartsim/_core/utils/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,8 @@ def db_is_active(hosts: t.List[str], ports: t.List[int], num_shards: int) -> boo

def set_ml_model(db_model: DBModel, client: Client) -> None:
logger.debug(f"Adding DBModel named {db_model.name}")
devices = db_model._enumerate_devices() # pylint: disable=protected-access

for device in devices:
for device in db_model.devices:
try:
if db_model.is_file:
client.set_model_from_file(
Expand Down Expand Up @@ -194,9 +193,7 @@ def set_ml_model(db_model: DBModel, client: Client) -> None:
def set_script(db_script: DBScript, client: Client) -> None:
logger.debug(f"Adding DBScript named {db_script.name}")

devices = db_script._enumerate_devices() # pylint: disable=protected-access

for device in devices:
for device in db_script.devices:
try:
if db_script.is_file:
client.set_script_from_file(
Expand Down
54 changes: 35 additions & 19 deletions smartsim/entity/dbobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from pathlib import Path
from .._core.utils import init_default
from ..error import SSUnsupportedError


__all__ = ["DBObject", "DBModel", "DBScript"]
Expand All @@ -43,7 +44,7 @@
name: str,
func: t.Optional[str],
file_path: t.Optional[str],
device: str,
device: t.Literal["CPU", "GPU"],
devices_per_node: int,
) -> None:
self.name = name
Expand All @@ -55,6 +56,11 @@
self.file = self._check_filepath(file_path)
self.device = self._check_device(device)
self.devices_per_node = devices_per_node
self._check_devices(device, devices_per_node)

@property
def devices(self) -> t.List[str]:
return self._enumerate_devices()

@property
def is_file(self) -> bool:
Expand Down Expand Up @@ -95,8 +101,8 @@
return file_path

@staticmethod
def _check_device(device: str) -> str:
device = device.upper()
def _check_device(device: t.Literal["CPU", "GPU"]) -> str:
device = t.cast(t.Literal["CPU", "GPU"], device.upper())
if not device.startswith("CPU") and not device.startswith("GPU"):
raise ValueError("Device argument must start with either CPU or GPU")
return device
Expand All @@ -109,21 +115,31 @@
:return: list of device names
:rtype: list[str]
"""
devices = []
if ":" in self.device and self.devices_per_node > 1:
msg = (
"Cannot set devices_per_node>1 if a device numeral is specified, "
f"the device was set to {self.device} and "
f"devices_per_node=={self.devices_per_node}"
)
raise ValueError(msg)
if self.device in ["CPU", "GPU"] and self.devices_per_node > 1:
for device_num in range(self.devices_per_node):
devices.append(f"{self.device}:{str(device_num)}")
else:
devices = [self.device]

return devices
if self.device == "GPU" and self.devices_per_node > 1:
return [

Check warning on line 120 in smartsim/entity/dbobject.py

View check run for this annotation

Codecov / codecov/patch

smartsim/entity/dbobject.py#L120

Added line #L120 was not covered by tests
f"{self.device}:{str(device_num)}"
for device_num in range(self.devices_per_node)
]

return [self.device]

@staticmethod
def _check_devices(
device: t.Literal["CPU", "GPU"], devices_per_node: int
) -> None:
if devices_per_node == 1:
return

if ":" in device:
msg = "Cannot set devices_per_node>1 if a device numeral is specified, "
msg += f"the device was set to {device} and \

Check warning on line 136 in smartsim/entity/dbobject.py

View check run for this annotation

Codecov / codecov/patch

smartsim/entity/dbobject.py#L135-L136

Added lines #L135 - L136 were not covered by tests
devices_per_node=={devices_per_node}"
raise ValueError(msg)

Check warning on line 138 in smartsim/entity/dbobject.py

View check run for this annotation

Codecov / codecov/patch

smartsim/entity/dbobject.py#L138

Added line #L138 was not covered by tests
if device == "CPU":
raise SSUnsupportedError(
"Cannot set devices_per_node>1 if CPU is specified under devices"
)
Copy link
Contributor

@ankona ankona Jul 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to check device?

if not self._check_device(self.device):
  raise ValueError("invalid device...")

I think we have a hole since device is an attribute. Changes after the constructor won't be validated.

Consider adding properties that use _check_device on sets!



class DBScript(DBObject):
Expand All @@ -132,7 +148,7 @@
name: str,
script: t.Optional[str] = None,
script_path: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU", "GPU"] = "CPU",
devices_per_node: int = 1,
):
"""TorchScript code represenation
Expand Down Expand Up @@ -185,7 +201,7 @@
backend: str,
model: t.Optional[str] = None,
model_file: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU", "GPU"] = "CPU",
devices_per_node: int = 1,
batch_size: int = 0,
min_batch_size: int = 0,
Expand Down
6 changes: 3 additions & 3 deletions smartsim/entity/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def add_ml_model(
backend: str,
model: t.Optional[str] = None,
model_path: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU","GPU"] = "CPU",
devices_per_node: int = 1,
batch_size: int = 0,
min_batch_size: int = 0,
Expand Down Expand Up @@ -395,7 +395,7 @@ def add_script(
name: str,
script: t.Optional[str] = None,
script_path: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU","GPU"] = "CPU",
devices_per_node: int = 1,
) -> None:
"""TorchScript to launch with every entity belonging to this ensemble
Expand Down Expand Up @@ -439,7 +439,7 @@ def add_function(
self,
name: str,
function: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU","GPU"] = "CPU",
devices_per_node: int = 1,
) -> None:
"""TorchScript function to launch with every entity belonging to this ensemble
Expand Down
6 changes: 3 additions & 3 deletions smartsim/entity/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def add_ml_model(
backend: str,
model: t.Optional[str] = None,
model_path: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU","GPU"] = "CPU",
devices_per_node: int = 1,
batch_size: int = 0,
min_batch_size: int = 0,
Expand Down Expand Up @@ -467,7 +467,7 @@ def add_script(
name: str,
script: t.Optional[str] = None,
script_path: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU","GPU"] = "CPU",
devices_per_node: int = 1,
) -> None:
"""TorchScript to launch with this Model instance
Expand Down Expand Up @@ -511,7 +511,7 @@ def add_function(
self,
name: str,
function: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU","GPU"] = "CPU",
devices_per_node: int = 1,
) -> None:
"""TorchScript function to launch with this Model instance
Expand Down
23 changes: 23 additions & 0 deletions tests/backends/test_dbmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from smartsim.error.errors import SSUnsupportedError
from smartsim.log import get_logger

from smartsim.entity.dbobject import DBModel

logger = get_logger(__name__)

should_run_tf = True
Expand Down Expand Up @@ -793,3 +795,24 @@ def test_colocated_db_model_errors(fileutils, wlmutils, mlutils):

with pytest.raises(SSUnsupportedError):
colo_ensemble.add_model(colo_model)

def test_inconsistent_params_db_model():
"""Test error when devices_per_node parameter>1 when devices is set to CPU in DBModel"""

# Create and save ML model to filesystem
model, inputs, outputs = create_tf_cnn()
with pytest.raises(SSUnsupportedError) as ex:
db_model = DBModel(
"cnn",
"TF",
model=model,
device="CPU",
devices_per_node=2,
tag="test",
inputs=inputs,
outputs=outputs,
)
assert (
ex.value.args[0]
== "Cannot set devices_per_node>1 if CPU is specified under devices"
)
18 changes: 18 additions & 0 deletions tests/backends/test_dbscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from smartsim.error.errors import SSUnsupportedError
from smartsim.log import get_logger

from smartsim.entity.dbobject import DBScript

logger = get_logger(__name__)

should_run = True
Expand Down Expand Up @@ -578,3 +580,19 @@ def test_db_script_errors(fileutils, wlmutils, mlutils):
# an in-memory script
with pytest.raises(SSUnsupportedError):
colo_ensemble.add_model(colo_model)

def test_inconsistent_params_db_script(fileutils):
"""Test error when devices_per_node>1 and when devices is set to CPU in DBScript constructor"""

torch_script = fileutils.get_test_conf_path("torchscript.py")
with pytest.raises(SSUnsupportedError) as ex:
db_script = DBScript(
name="test_script_db",
script_path = torch_script,
device="CPU",
devices_per_node=2,
)
assert (
ex.value.args[0]
== "Cannot set devices_per_node>1 if CPU is specified under devices"
)
Loading