@@ -301,3 +301,70 @@ def test_insert_wrong_step_args_into_tuning_step(inputs, pipeline_session):
301301 )
302302
303303 assert "The step_args of TuningStep must be obtained from tuner.fit()" in str (error .value )
304+
305+
306+ def test_tuning_step_with_extra_job_args (pipeline_session , entry_point ):
307+ pytorch_estimator = PyTorch (
308+ entry_point = entry_point ,
309+ role = ROLE ,
310+ framework_version = "1.5.0" ,
311+ py_version = "py3" ,
312+ instance_count = 1 ,
313+ instance_type = "ml.m5.xlarge" ,
314+ sagemaker_session = pipeline_session ,
315+ enable_sagemaker_metrics = True ,
316+ max_retry_attempts = 3 ,
317+ )
318+
319+ hyperparameter_ranges = {
320+ "batch-size" : IntegerParameter (64 , 128 ),
321+ }
322+
323+ tuner = HyperparameterTuner (
324+ estimator = pytorch_estimator ,
325+ objective_metric_name = "test:acc" ,
326+ objective_type = "Maximize" ,
327+ hyperparameter_ranges = hyperparameter_ranges ,
328+ metric_definitions = [{"Name" : "test:acc" , "Regex" : "Overall test accuracy: (.*?);" }],
329+ max_jobs = 2 ,
330+ max_parallel_jobs = 2 ,
331+ )
332+
333+ step_args = tuner .fit (inputs = TrainingInput (s3_data = "s3://my-bucket/my-training-input" ))
334+
335+ ignored_input = "s3://my-bucket/my-input-to-be-ignored"
336+ step = TuningStep (
337+ name = "MyTuningStep" ,
338+ step_args = step_args ,
339+ inputs = ignored_input ,
340+ )
341+
342+ pipeline = Pipeline (
343+ name = "MyPipeline" ,
344+ steps = [step ],
345+ sagemaker_session = pipeline_session ,
346+ )
347+
348+ step_args = get_step_args_helper (step_args , "HyperParameterTuning" )
349+ pipeline_def = pipeline .definition ()
350+ step_def = json .loads (pipeline_def )["Steps" ][0 ]
351+
352+ # delete sagemaker_job_name b/c of timestamp collision
353+ del step_args ["TrainingJobDefinition" ]["StaticHyperParameters" ]["sagemaker_job_name" ]
354+ del step_def ["Arguments" ]["TrainingJobDefinition" ]["StaticHyperParameters" ][
355+ "sagemaker_job_name"
356+ ]
357+
358+ # delete S3 path assertions for now because job name is included with timestamp. These will be re-enabled after
359+ # caching improvements phase 2.
360+ del step_args ["TrainingJobDefinition" ]["StaticHyperParameters" ]["sagemaker_submit_directory" ]
361+ del step_def ["Arguments" ]["TrainingJobDefinition" ]["StaticHyperParameters" ][
362+ "sagemaker_submit_directory"
363+ ]
364+
365+ assert step_def == {
366+ "Name" : "MyTuningStep" ,
367+ "Type" : "Tuning" ,
368+ "Arguments" : step_args ,
369+ }
370+ assert ignored_input not in pipeline_def
0 commit comments