Description
Reference: MLFW-2710
I might be missing something about the Airflow operators, particularly as they are demonstrated in the examples at https://sagemaker.readthedocs.io/en/stable/using_workflow.html, but it seems like the SageMaker session handling should be done within the Operator objects themselves.
In these examples, a `training_config` or `transform_config` object needs to be created prior to passing it to the respective `SageMakerTrainingOperator` / `SageMakerTransformOperator`. These config objects in turn require the creation of a `sagemaker.estimator.Estimator` or `sagemaker.transformer.Transformer` object.
So just to check my understanding of the order of operations here: for the training operator we require the following objects, in the following dependency order:

1. `sagemaker.estimator.Estimator`
2. `training_config`
3. `SageMakerTrainingOperator`
The problem with this is that the `sagemaker.estimator.Estimator` object requires a `sagemaker_session` to be passed to it. In Airflow, any kind of connection/session information is normally abstracted away behind the Operator itself. Given the order of object creation listed above, a separate SageMaker session needs to be initialized outside of the Operators, for example by using `airflow.contrib.hooks.sagemaker_hook.SageMakerHook` directly within the DAG.
For example:
```python
import sagemaker
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator

# Build a SageMaker session by hand from Airflow's connection machinery.
hook = SageMakerHook(aws_conn_id='aws_default')
sagemaker_session = sagemaker.session.Session(boto_session=hook.get_session())

estimator = sagemaker.estimator.Estimator(
    image_name=MY_IMAGE,
    role=MY_ROLE_ARN,
    sagemaker_session=sagemaker_session,
    train_instance_count=1,
    train_instance_type="ml.m5.large",
    output_path=MY_TRAINING_OUTPUT_PATH)

train_config = sagemaker.workflow.airflow.training_config(
    estimator=estimator,
    inputs={"train": MY_TRAINING_INPUT_PATH})

train_op = SageMakerTrainingOperator(
    task_id='training',
    config=train_config,
    aws_conn_id='aws_default',
    wait_for_completion=True,
    dag=dag)
```
I believe the example in the documentation assumes authentication to a boto3 session via the default AWS credential provider chain, outside the context of Airflow's hooks/connections. That is probably not a good assumption to make, since not all Airflow deployments are configured on top of IAM instance roles. It also breaks Airflow's connection management pattern, which means connections have to be managed in multiple places through separate configuration mechanisms.
IMHO, it would make sense to handle all of the session configuration within the Airflow Operators.
- I believe one way to do that would be to modify the base objects like `sagemaker.estimator.Estimator` to allow lazy loading of the SageMaker session, which means that an `Estimator` object could be initialized without a session. The `Estimator` object would instead have a session set by `SageMakerTrainingOperator` once it is passed into the operator. This, however, seems like a fairly significant overhaul of the current code just to support what is essentially supplemental functionality (Airflow integration). A rough sketch of the idea is below.
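For illustration, here is a minimal sketch of what that lazy loading might look like. This is hypothetical code, not the SDK's current behavior (today, omitting `sagemaker_session` makes `Estimator` eagerly create a default `Session`):

```python
# Hypothetical sketch: an Estimator that tolerates a missing session at
# construction time and expects the Operator to inject one later.

class Estimator:
    def __init__(self, image_name, role, sagemaker_session=None, **kwargs):
        self.image_name = image_name
        self.role = role
        self._sagemaker_session = sagemaker_session  # may stay None for now
        self._kwargs = kwargs

    @property
    def sagemaker_session(self):
        if self._sagemaker_session is None:
            raise RuntimeError(
                "No session set; the Operator is expected to inject one "
                "before any SageMaker API call is made.")
        return self._sagemaker_session

    @sagemaker_session.setter
    def sagemaker_session(self, session):
        # Called by e.g. SageMakerTrainingOperator, which builds the session
        # from its own aws_conn_id.
        self._sagemaker_session = session
```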
- Another approach could be to include lazy-loader wrappers around the base objects (under `sagemaker.workflow.airflow`) that collect all the arguments for `Estimator`, `Transformer`, etc., except for the session. The wrapper objects could then have an `.init(self, session=session)` method that can be called within the Operators. A rough sketch of such a wrapper is below.
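A minimal sketch of such a wrapper, using a hypothetical `LazyEstimator` class (nothing like this exists under `sagemaker.workflow.airflow` today):

```python
import sagemaker

class LazyEstimator:
    """Hypothetical wrapper: collects Estimator arguments in the DAG file,
    deferring construction until the Operator supplies a session."""

    def __init__(self, **estimator_kwargs):
        # Everything except the session is captured up front.
        self._kwargs = estimator_kwargs
        self._estimator = None

    def init(self, session):
        # Called inside the Operator once it has built a session from its
        # own aws_conn_id.
        self._estimator = sagemaker.estimator.Estimator(
            sagemaker_session=session, **self._kwargs)
        return self._estimator
```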
A DAG using this approach could look something like this:
```python
import sagemaker
from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator

# No session is created in the DAG file; the Operator supplies one later.
estimator = sagemaker.estimator.Estimator(
    image_name=MY_IMAGE,
    role=MY_ROLE_ARN,
    train_instance_count=1,
    train_instance_type="ml.m5.large",
    output_path=MY_TRAINING_OUTPUT_PATH)

train_config = sagemaker.workflow.airflow.training_config(
    estimator=estimator,
    inputs={"train": MY_TRAINING_INPUT_PATH})

train_op = SageMakerTrainingOperator(
    task_id='training',
    config=train_config,
    aws_conn_id='aws_default',
    wait_for_completion=True,
    dag=dag)
```
- A quicker/dirtier approach could be to modify the Operators to accept a callable for the `config` argument; the callable returns the appropriately initialized config object, and the Operator adds the session when invoking it.
A DAG using this approach could look something like this:
```python
import sagemaker
from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator

# Here the session argument would be passed into the callable by
# SageMakerTrainingOperator at execution time.
def train_config(session, **context):
    estimator = sagemaker.estimator.Estimator(
        sagemaker_session=session,
        image_name=MY_IMAGE,
        role=MY_ROLE_ARN,
        train_instance_count=1,
        train_instance_type="ml.m5.large",
        output_path=MY_TRAINING_OUTPUT_PATH)
    return sagemaker.workflow.airflow.training_config(
        estimator=estimator,
        inputs={"train": MY_TRAINING_INPUT_PATH})

train_op = SageMakerTrainingOperator(
    task_id='training',
    config=train_config,  # a callable rather than a ready-made config dict
    aws_conn_id='aws_default',
    wait_for_completion=True,
    dag=dag)
```
That actually looks pretty clean.
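On the Operator side, the change this approach would require seems small. Here is a minimal sketch (hypothetical; not the current `SageMakerTrainingOperator` behavior), reusing the same session construction as in the first example:

```python
import sagemaker
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook

def _resolve_config(config, aws_conn_id, **context):
    # If config is a callable, build the session from the Operator's own
    # connection and let the callable produce the final config object.
    if callable(config):
        hook = SageMakerHook(aws_conn_id=aws_conn_id)
        session = sagemaker.session.Session(boto_session=hook.get_session())
        return config(session, **context)
    return config
```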
Anyway, creating the session via the hook directly in the DAG technically works well enough, but it's really not best practice.