Skip to content

Passing PipelineVariable as hyperparameters for Framework Estimator fails #3349

@AndreiVoinovTR

Description

@AndreiVoinovTR

Describe the bug
Generating a SageMaker Pipeline definition fails if hyperparameters defined as PipelineVariable are passed to a training step.

Expected behavior
Should PipelineVariable(s) be supported for hyperparameters?
In the Estimator constructor, the hyperparameters argument is annotated as: Optional[Dict[str, Union[str, sagemaker.workflow.entities.PipelineVariable]]] = None.

Screenshots or logs

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-5-86301c565ea6> in <cell line: 1>()
----> 1 pipeline.build_and_deploy(experiment_config=experiment_config)

~/projects/src/ml-scale/labs-MLTools_CLI/src/trlabs_mltools_cli/pipeline.py in build_and_deploy(self, experiment_config, parameters_override)
     98             service responce : see https://sagemaker.readthedocs.io/en/stable/workflows/pipelines/sagemaker.workflow.pipelines.html#sagemaker.workflow.pipeline.Pipeline.upsert
     99         """  # noqa
--> 100         pipeline = self.build(experiment_config=experiment_config, parameters_override=parameters_override)
    101         logger.info(f"Deploying SageMaker Pipeline: {pipeline} ...")
    102         return pipeline.upsert(self.experiment.workspace.iam_role)

~/projects/src/ml-scale/labs-MLTools_CLI/src/trlabs_mltools_cli/pipeline.py in build(self, experiment_config, parameters_override)
     85             self.get_template(parameters_override), DeploymentContext.from_workspace(self.experiment.workspace)
     86         ).build_pipeline(self._base_ppl_name(), experiment_config=ppl_experiment_config)
---> 87         logger.debug(f"Built SageMaker Pipeline: {pipeline.definition()}")
     88         return pipeline
     89 

~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/workflow/pipeline.py in definition(self)
    319     def definition(self) -> str:
    320         """Converts a request structure to string representation for workflow service calls."""
--> 321         request_dict = self.to_request()
    322         self._interpolate_step_collection_name_in_depends_on(request_dict["Steps"])
    323         request_dict["PipelineExperimentConfig"] = interpolate(

~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/workflow/pipeline.py in to_request(self)
    103             if self.pipeline_experiment_config is not None
    104             else None,
--> 105             "Steps": list_to_request(self.steps),
    106         }
    107 

~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/workflow/utilities.py in list_to_request(entities)
     51     for entity in entities:
     52         if isinstance(entity, Entity):
---> 53             request_dicts.append(entity.to_request())
     54         elif isinstance(entity, StepCollection):
     55             request_dicts.extend(entity.request_dicts())

~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/workflow/steps.py in to_request(self)
    497     def to_request(self) -> RequestType:
    498         """Updates the request dictionary with cache configuration."""
--> 499         request_dict = super().to_request()
    500         if self.cache_config:
    501             request_dict.update(self.cache_config.config)

~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/workflow/steps.py in to_request(self)
    349     def to_request(self) -> RequestType:
    350         """Gets the request structure for `ConfigurableRetryStep`."""
--> 351         step_dict = super().to_request()
    352         if self.retry_policies:
    353             step_dict["RetryPolicies"] = self._resolve_retry_policy(self.retry_policies)

~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/workflow/steps.py in to_request(self)
    118             "Name": self.name,
    119             "Type": self.step_type.value,
--> 120             "Arguments": self.arguments,
    121         }
    122         if self.depends_on:

~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/workflow/steps.py in arguments(self)
    476             request_dict = self.step_args
    477         else:
--> 478             self.estimator._prepare_for_training(self.job_name)
    479             train_args = _TrainingJob._get_train_args(
    480                 self.estimator, self.inputs, experiment_config=dict()

~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/trlabs_mltools/sagemaker/factory.py in _prepare_for_training(self, job_name)
     91             )
     92 
---> 93         super()._prepare_for_training(job_name=job_name)
     94 
     95 

~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/estimator.py in _prepare_for_training(self, job_name)
   2880                 constructor if applicable.
   2881         """
-> 2882         super(Framework, self)._prepare_for_training(job_name=job_name)
   2883 
   2884         self._validate_and_set_debugger_configs()

~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/estimator.py in _prepare_for_training(self, job_name)
    705             # Modify hyperparameters in-place to point to the right code directory and
    706             # script URIs
--> 707             self._script_mode_hyperparam_update(code_dir, script)
    708 
    709         self._prepare_rules()

~/projects/src/ml-scale/labs-MLTools_CLI/.venv/lib/python3.8/site-packages/sagemaker/estimator.py in _script_mode_hyperparam_update(self, code_dir, script)
   2898         hyperparams[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
   2899 
-> 2900         self._hyperparameters.update(hyperparams)
   2901 
   2902     def _validate_and_set_debugger_configs(self):

AttributeError: 'ParameterString' object has no attribute 'update'

System information
A description of your system. Please provide:

  • SageMaker Python SDK version: 2.107.0
  • PyTorch Estimator:

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions