|
23 | 23 | import time
|
24 | 24 |
|
25 | 25 | from abc import ABC
|
| 26 | +from typing import Union, Optional |
26 | 27 |
|
27 | 28 | import attr
|
28 | 29 |
|
29 | 30 | import smdebug_rulesconfig as rule_configs
|
30 | 31 |
|
31 | 32 | from sagemaker import image_uris
|
32 | 33 | from sagemaker.utils import build_dict
|
| 34 | +from sagemaker.workflow.entities import PipelineVariable |
33 | 35 |
|
34 | 36 | framework_name = "debugger"
|
35 | 37 | DEBUGGER_FLAG = "USE_SMDEBUG"
|
@@ -311,10 +313,10 @@ def sagemaker(
|
311 | 313 | @classmethod
|
312 | 314 | def custom(
|
313 | 315 | cls,
|
314 |
| - name, |
315 |
| - image_uri, |
316 |
| - instance_type, |
317 |
| - volume_size_in_gb, |
| 316 | + name: str, |
| 317 | + image_uri: Union[str, PipelineVariable], |
| 318 | + instance_type: Union[str, PipelineVariable], |
| 319 | + volume_size_in_gb: Union[int, PipelineVariable], |
318 | 320 | source=None,
|
319 | 321 | rule_to_invoke=None,
|
320 | 322 | container_local_output_path=None,
|
@@ -610,7 +612,7 @@ class DebuggerHookConfig(object):
|
610 | 612 |
|
611 | 613 | def __init__(
|
612 | 614 | self,
|
613 |
| - s3_output_path=None, |
| 615 | + s3_output_path: Optional[Union[str, PipelineVariable]] = None, |
614 | 616 | container_local_output_path=None,
|
615 | 617 | hook_parameters=None,
|
616 | 618 | collection_configs=None,
|
@@ -679,7 +681,9 @@ def _to_request_dict(self):
|
679 | 681 | class TensorBoardOutputConfig(object):
|
680 | 682 | """Create a tensor ouput configuration object for debugging visualizations on TensorBoard."""
|
681 | 683 |
|
682 |
| - def __init__(self, s3_output_path, container_local_output_path=None): |
| 684 | + def __init__( |
| 685 | + self, s3_output_path: Union[str, PipelineVariable], container_local_output_path=None |
| 686 | + ): |
683 | 687 | """Initialize the TensorBoardOutputConfig instance.
|
684 | 688 |
|
685 | 689 | Args:
|
|
0 commit comments