SageMaker Airflow Operators should handle sessions #1092

@andrewcstewart

Description

Reference: MLFW-2710

I might be missing something about the Airflow operators, particularly as they are demonstrated in the examples at https://sagemaker.readthedocs.io/en/stable/using_workflow.html, but it seems like the sagemaker session handling should be done within the Operator objects.

In these examples, a training_config or transform_config object needs to be created prior to passing the object to the respective SageMakerTrainingOperator / SageMakerTransformOperator operators. These config objects in turn require the creation of a sagemaker.estimator.Estimator or sagemaker.transformer.Transformer object.

So just to check my understanding of the order of operations here, for the training operator we require the following objects in the following dependency order:

  1. sagemaker.estimator.Estimator
  2. training_config
  3. SageMakerTrainingOperator

The problem with this is that the sagemaker.estimator.Estimator object requires a sagemaker_session to be passed to it. In Airflow, connection/session information is normally abstracted away behind the Operator itself. Given the order of object creation listed above, a separate SageMaker session has to be initialized outside of the Operators; for example, by using `airflow.contrib.hooks.sagemaker_hook.SageMakerHook` directly within the DAG.

For example:

import sagemaker
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator

# build a SageMaker session from the Airflow connection via the hook
aws_hook = SageMakerHook(aws_conn_id='aws_default')
sagemaker_session = sagemaker.session.Session(boto_session=aws_hook.get_session())

estimator = sagemaker.estimator.Estimator(
    image_name=MY_IMAGE,
    role=MY_ROLE_ARN,
    sagemaker_session=sagemaker_session,
    train_instance_count=1,
    train_instance_type="ml.m5.large",
    output_path=MY_TRAINING_OUTPUT_PATH)

train_config = sagemaker.workflow.airflow.training_config(
  estimator=estimator,
  inputs={"train": MY_TRAINING_INPUT_PATH})

train_op = SageMakerTrainingOperator(
    task_id='training',
    config=train_config,
    aws_conn_id='aws_default',
    wait_for_completion=True,
    dag=dag)

I believe the example in the documentation assumes the boto3 session authenticates through the default AWS credential chain, outside the context of Airflow's hooks/connections. That's probably not a good assumption, since not all Airflow deployments sit on top of IAM instance roles. It also breaks Airflow's connection management pattern, which means connections have to be managed in multiple places through separate configuration mechanisms.

IMHO, it would make sense to handle all of the session configuration within the Airflow Operators.

  1. I believe one way to do that would be to modify the base objects like sagemaker.estimator.Estimator to allow lazy loading of the SageMaker session, meaning an Estimator object could be initialized without a session and instead have one set by SageMakerTrainingOperator once it is passed into the operator (see the sketch after this list). This, however, seems like a fairly significant overhaul of the current code just to support what is essentially supplemental functionality (Airflow integration).

  2. Another approach could be to include lazy-loader wrappers around the base objects (under sagemaker.workflow.airflow) that collect all the arguments for Estimator, Transformer, etc., except for the session. The wrapper objects could then expose an .init(self, session=session) method that the Operators call once they have built a session (a sketch of such a wrapper follows the DAG example below).
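
For idea 1, one rough, untested sketch of the operator side (the class name is hypothetical, and it assumes Estimator is changed so that constructing it without a session doesn't eagerly build a default one):

import sagemaker
from sagemaker.workflow.airflow import training_config
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook
from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator

class SessionInjectingTrainingOperator(SageMakerTrainingOperator):
    # Hypothetical operator: takes the estimator itself (built without a
    # session) and wires in a session from its own Airflow connection.
    def __init__(self, estimator, inputs, **kwargs):
        super().__init__(config={}, **kwargs)
        self.estimator = estimator
        self.inputs = inputs

    def execute(self, context):
        hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
        self.estimator.sagemaker_session = sagemaker.session.Session(
            boto_session=hook.get_session())
        # the config is only generated now, with a fully wired session
        self.config = training_config(estimator=self.estimator, inputs=self.inputs)
        return super().execute(context)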

A DAG using the wrapper approach (2) could look something like this:

import sagemaker
from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator

estimator = sagemaker.workflow.airflow.Estimator(  # hypothetical lazy wrapper
    image_name=MY_IMAGE,
    role=MY_ROLE_ARN,
    train_instance_count=1,
    train_instance_type="ml.m5.large",
    output_path=MY_TRAINING_OUTPUT_PATH)  # note: no sagemaker_session here

train_config = sagemaker.workflow.airflow.training_config(
  estimator=estimator,
  inputs={"train": MY_TRAINING_INPUT_PATH})

train_op = SageMakerTrainingOperator(
    task_id='training',
    config=train_config,
    aws_conn_id='aws_default',
    wait_for_completion=True,
    dag=dag)
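
One possible shape for such a wrapper; the class name and the .init() signature are illustrative, not an existing API:

import sagemaker

class LazyEstimator:
    # Hypothetical lazy-loader wrapper: captures every Estimator argument
    # except the session, deferring construction until an operator calls
    # .init() with a session built from its Airflow connection.
    def __init__(self, **estimator_kwargs):
        self._kwargs = estimator_kwargs
        self._estimator = None

    def init(self, session):
        self._estimator = sagemaker.estimator.Estimator(
            sagemaker_session=session, **self._kwargs)
        return self._estimator

The config-generation helpers like training_config would also need to accept the wrapper, or be invoked by the operator only after .init() has run.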

  3. A quicker/dirtier approach could be to modify the Operators to accept a callable for the config argument; the Operator would call it with a session and get back the appropriately initialized config object.

A DAG using this approach could look something like this:

import sagemaker
from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator

# here the session argument is being passed into the callable by SageMakerTrainingOperator
def train_config(session, **context):
    estimator = sagemaker.estimator.Estimator(
        sagemaker_session=session,
        image_name=MY_IMAGE,
        role=MY_ROLE_ARN,
        train_instance_count=1,
        train_instance_type="ml.m5.large",
        output_path=MY_TRAINING_OUTPUT_PATH)
    return sagemaker.workflow.airflow.training_config(
        estimator=estimator,
        inputs={"train": MY_TRAINING_INPUT_PATH})

train_op = SageMakerTrainingOperator(
    task_id='training',
    config=train_config,
    aws_conn_id='aws_default',
    wait_for_completion=True,
    dag=dag)

That actually looks pretty clean.
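
On the operator side, this could amount to only a few lines at execute time, roughly (an untested sketch; resolve_config is a made-up name, and exactly where it would hook into the contrib operators may differ):

import sagemaker
from airflow.contrib.hooks.sagemaker_hook import SageMakerHook

def resolve_config(config, aws_conn_id='aws_default'):
    # Hypothetical helper the operator could run at execute time: if the
    # config is a callable, build a session from the operator's Airflow
    # connection and let the callable produce the final config dict.
    if not callable(config):
        return config
    hook = SageMakerHook(aws_conn_id=aws_conn_id)
    session = sagemaker.session.Session(boto_session=hook.get_session())
    return config(session)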

Anyway, creating the session through the hook directly in the DAG technically works well enough, but it's really not best practice.
