Skip to content

Commit

Permalink
fix: correct airflow workflow for BYO estimators. (#1052)
Browse files Browse the repository at this point in the history
The airflow workflow for BYO estimators has been broken since a fix
made on 09/13/2019. This resolves that issue.
Tests to ensure this does not happen again will be added as part of
the next airflow commit that prepares training jobs as part of
defining training configuration.
  • Loading branch information
knakad committed Sep 20, 2019
1 parent 0c6c6ff commit 4b0f5af
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def model_data(self):
)["ModelArtifacts"]["S3ModelArtifacts"]
else:
logging.warning(
"No finished training job found associated with this estimator. Please make sure"
"No finished training job found associated with this estimator. Please make sure "
"this estimator is only used for building workflow config"
)
model_uri = os.path.join(
Expand Down Expand Up @@ -718,21 +718,20 @@ def transformer(

if self.latest_training_job is None:
logging.warning(
"No finished training job found associated with this estimator. Please make sure"
"No finished training job found associated with this estimator. Please make sure "
"this estimator is only used for building workflow config"
)
model_name = self._current_job_name
else:
model_name = self.latest_training_job.name
model = self.create_model(vpc_config_override=vpc_config_override)

model = self.create_model(vpc_config_override=vpc_config_override)
# not all create_model() implementations have the same kwargs
model.name = model_name
if role is not None:
model.role = role

# not all create_model() implementations have the same kwargs
model.name = model_name
if role is not None:
model.role = role

model._create_sagemaker_model(instance_type, tags=tags)
model._create_sagemaker_model(instance_type, tags=tags)

return Transformer(
model_name,
Expand Down Expand Up @@ -1716,7 +1715,7 @@ def transformer(
transform_env.update(env)
else:
logging.warning(
"No finished training job found associated with this estimator. Please make sure"
"No finished training job found associated with this estimator. Please make sure "
"this estimator is only used for building workflow config"
)
model_name = self._current_job_name
Expand Down

0 comments on commit 4b0f5af

Please sign in to comment.