Skip to content

Commit

Permalink
use custom dlc
Browse files Browse the repository at this point in the history
  • Loading branch information
suzhoum committed Jun 8, 2024
1 parent 55c3169 commit 7d9fda5
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
4 changes: 4 additions & 0 deletions src/autogluon/cloud/backend/ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,10 @@ def _get_image_uri(self, framework_version: str, instance_type: str, custom_imag
image_scope="training",
instance_type=instance_type,
)
if "g4dn" in instance_type:
image_uri = "369469875935.dkr.ecr.us-east-1.amazonaws.com/weisu:gpu-latest"
else:
image_uri = "369469875935.dkr.ecr.us-east-1.amazonaws.com/weisu:cpu-latest"
return image_uri

def _construct_ag_args(self, predictor_init_args, predictor_fit_args, **kwargs):
Expand Down
12 changes: 12 additions & 0 deletions src/autogluon/cloud/utils/ag_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def __init__(
image_scope="training",
instance_type=instance_type,
)
if "g4dn" in instance_type:
self.image_uri = "369469875935.dkr.ecr.us-east-1.amazonaws.com/weisu:gpu-latest"
else:
self.image_uri = "369469875935.dkr.ecr.us-east-1.amazonaws.com/weisu:cpu-latest"
super().__init__(
entry_point=entry_point,
source_dir=source_dir,
Expand Down Expand Up @@ -74,6 +78,10 @@ def create_model(
image_scope="inference",
instance_type=instance_type,
)
if instance_type == "gpu":
image_uri = "369469875935.dkr.ecr.us-east-1.amazonaws.com/weisu:gpu-inference-latest"
else:
image_uri = "369469875935.dkr.ecr.us-east-1.amazonaws.com/weisu:cpu-inference-latest"
if predictor_cls is None:

def predict_wrapper(endpoint, session):
Expand Down Expand Up @@ -141,6 +149,10 @@ def __init__(
image_scope="inference",
instance_type=instance_type,
)
if instance_type == "gpu":
image_uri = "369469875935.dkr.ecr.us-east-1.amazonaws.com/weisu:gpu-inference-latest"
else:
image_uri = "369469875935.dkr.ecr.us-east-1.amazonaws.com/weisu:cpu-inference-latest"
# setting PYTHONUNBUFFERED to disable output buffering for endpoints logging
if env is None:
env = {}
Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@


class CloudTestHelper:
cpu_training_image = "369469875935.dkr.ecr.us-east-1.amazonaws.com/autogluon-nightly-training:cpu-latest"
gpu_training_image = "369469875935.dkr.ecr.us-east-1.amazonaws.com/autogluon-nightly-training:gpu-latest"
cpu_inference_image = "369469875935.dkr.ecr.us-east-1.amazonaws.com/autogluon-nightly-inference:cpu-latest"
gpu_inference_image = "369469875935.dkr.ecr.us-east-1.amazonaws.com/autogluon-nightly-inference:gpu-latest"
cpu_training_image = "369469875935.dkr.ecr.us-east-1.amazonaws.com/weisu:cpu-latest"
gpu_training_image = "369469875935.dkr.ecr.us-east-1.amazonaws.com/weisu:gpu-latest"
cpu_inference_image = "369469875935.dkr.ecr.us-east-1.amazonaws.com/weisu:cpu-inference-latest"
gpu_inference_image = "369469875935.dkr.ecr.us-east-1.amazonaws.com/weisu:gpu-inference-latest"

@staticmethod
def get_custom_image_uri(framework_version="source", type="training", gpu=False):
Expand Down

0 comments on commit 7d9fda5

Please sign in to comment.