# Lineage Tracking and Traversal Example

SageMaker Lineage tracks the artifacts created in a machine learning workflow, such as code, datasets, models and endpoints, from start to finish.

The [SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable) lets you train and deploy models built with frameworks such as Apache MXNet. In this example, we train a simple neural network on the MNIST dataset using the Apache MXNet [Module API](https://mxnet.apache.org/api/python/module/module.html), record the code, data, model and endpoint as lineage entities, and then traverse the resulting lineage graph.


Make sure you have selected the `conda_mxnet_p36` kernel.

Before you start, make sure that:
* your account has been allowlisted for the lineage beta
* your execution role has the appropriate trust relationships
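For reference, a SageMaker execution role normally carries a trust policy that lets the SageMaker service assume it. The sketch below builds that standard policy as a Python dict; `sagemaker_trust_policy` is a hypothetical helper, and it is an assumption that the lineage beta needs nothing beyond this, so check the beta onboarding instructions for any extra principals or permissions.

```python
import json


def sagemaker_trust_policy() -> dict:
    """Return the standard trust policy that lets SageMaker assume an execution role."""
    # Assumption: the beta requires no trust entries beyond the usual SageMaker principal.
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {"Service": "sagemaker.amazonaws.com"},
                "Action": "sts:AssumeRole",
            }
        ],
    }


if __name__ == "__main__":
    # This JSON can go in the role's "Trust relationships" tab in the IAM console,
    # or be passed as AssumeRolePolicyDocument when creating the role.
    print(json.dumps(sagemaker_trust_policy(), indent=2))
```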

```python
import sys

# Install the private beta SDK (the tarball must be in the notebook's working directory).
!{sys.executable} -m pip install -q -U pip
!{sys.executable} -m pip install -q sagemaker-2.6.1.dev0.tar.gz
```

```python
import IPython

# You may need to restart the kernel after the initial install of the beta SDK:
# IPython.Application.instance().kernel.do_shutdown(True)
```

```python
from sagemaker import get_execution_role
from sagemaker.session import Session
from sagemaker.lineage import context, artifact, association, action
import boto3
from datetime import datetime
import logging
import os
```

```python
# The lineage beta is only available in us-east-2 (Ohio, CMH).
region = 'us-east-2'

boto_session = boto3.Session(region_name=region)
sagemaker_client = boto_session.client("sagemaker")

# S3 bucket for saving code and model artifacts.
# Feel free to specify a different bucket here if you wish.
# Build the SageMaker session from the same boto session so the default bucket is in that region.
bucket = Session(boto_session=boto_session).default_bucket()
```

```python
# S3 location where your custom code will be saved in the tar.gz format.
custom_code_upload_location = 's3://{}/mxnet-mnist-example/code'.format(bucket)
list_response = list(artifact.Artifact.list(source_uri=custom_code_upload_location, sagemaker_boto_client=sagemaker_client))

# Reuse an existing artifact for this location if there is one; otherwise create it.
if list_response:
    code_artifact_arn = list_response[0].artifact_arn
else:
    code_artifact_arn = artifact.Artifact.create(
        artifact_name='SourceCodeLocation',
        source_uri=custom_code_upload_location,
        artifact_type='codelocation',
        sagemaker_boto_client=sagemaker_client
    ).artifact_arn

# S3 location where the results of model training are saved.
model_artifacts_location = 's3://{}/mxnet-mnist-example/artifacts'.format(bucket)
list_response = list(artifact.Artifact.list(source_uri=model_artifacts_location, sagemaker_boto_client=sagemaker_client))
if list_response:
    model_location_artifact_arn = list_response[0].artifact_arn
else:
    model_location_artifact_arn = artifact.Artifact.create(
        artifact_name='model-artifacts-location',
        source_uri=model_artifacts_location,
        artifact_type='model-artifacts-location',
        sagemaker_boto_client=sagemaker_client,
    ).artifact_arn

# IAM execution role that gives SageMaker access to resources in your AWS account.
# We can use the SageMaker Python SDK to get the role from our notebook environment.
role = get_execution_role()
```
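The list-then-create pattern above repeats for every artifact in this notebook. A minimal sketch of factoring it into one helper follows; `get_or_create_artifact_arn` is a hypothetical name, and the list and create calls are passed in as functions so the logic can be exercised without AWS access.

```python
from types import SimpleNamespace
from typing import Any, Callable, Iterable


def get_or_create_artifact_arn(
    source_uri: str,
    list_fn: Callable[[str], Iterable[Any]],
    create_fn: Callable[[str], Any],
) -> str:
    """Return the ARN of an artifact registered for source_uri, creating one if none exists."""
    for existing in list_fn(source_uri):
        # Reuse the first match, like list_response[0] in the cell above.
        # Iterating lazily means a generator-style list call stops after one item.
        return existing.artifact_arn
    return create_fn(source_uri).artifact_arn


# Demo with an in-memory stand-in for the lineage API (no AWS calls are made).
registry = {}


def fake_list(uri):
    return [registry[uri]] if uri in registry else []


def fake_create(uri):
    registry[uri] = SimpleNamespace(artifact_arn=f"arn:example:artifact/{len(registry)}")
    return registry[uri]


first = get_or_create_artifact_arn("s3://bucket/code", fake_list, fake_create)
second = get_or_create_artifact_arn("s3://bucket/code", fake_list, fake_create)
print(first, first == second)  # arn:example:artifact/0 True
```

In the notebook, `list_fn` could be `lambda uri: artifact.Artifact.list(source_uri=uri, sagemaker_boto_client=sagemaker_client)` and `create_fn` a similar lambda around `artifact.Artifact.create(...)` carrying the name and type used above.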

### The training script

The `mnist.py` script provides all the code we need for training and hosting a SageMaker model. The script also checkpoints the model at the end of every epoch, saving the model graph, parameters and optimizer state in the folder `/opt/ml/checkpoints`; if that folder does not exist, the script skips checkpointing. The script is adapted from the Apache MXNet [MNIST tutorial](https://mxnet.incubator.apache.org/tutorials/python/mnist.html).
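The checkpoint-or-skip behaviour described above could look roughly like the sketch below. It is not the actual `mnist.py`; the helper names and the `mnist` file prefix are hypothetical. It relies on MXNet's `Module.save_checkpoint(prefix, epoch, save_optimizer_states=...)`, which writes `prefix-symbol.json`, `prefix-NNNN.params` and, when requested, `prefix-NNNN.states`.

```python
import os
from typing import Optional

CHECKPOINT_DIR = "/opt/ml/checkpoints"


def checkpoint_prefix(checkpoint_dir: str = CHECKPOINT_DIR, name: str = "mnist") -> Optional[str]:
    """Return the file prefix for checkpoints, or None when the folder does not exist."""
    if not os.path.isdir(checkpoint_dir):
        return None
    return os.path.join(checkpoint_dir, name)


def save_epoch_checkpoint(module, epoch: int, checkpoint_dir: str = CHECKPOINT_DIR) -> bool:
    """Checkpoint graph, params and optimizer state for one epoch; return False if skipped."""
    prefix = checkpoint_prefix(checkpoint_dir)
    if prefix is None:
        return False  # no checkpoint folder, so skip checkpointing
    # `module` is expected to be a bound and initialized mxnet.mod.Module.
    module.save_checkpoint(prefix, epoch, save_optimizer_states=True)
    return True
```

In MXNet 1.x, a script could instead pass `mx.callback.module_checkpoint(mod, prefix, save_optimizer_states=True)` as the `epoch_end_callback` of `Module.fit`, which saves after every epoch in the same way.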



### SageMaker's MXNet estimator class

```python
from sagemaker.mxnet import MXNet

mnist_estimator = MXNet(entry_point='mnist.py',
                        role=role,
                        output_path=model_artifacts_location,
                        code_location=custom_code_upload_location,
                        instance_count=1,
                        instance_type='ml.m4.xlarge',
                        framework_version='1.4.1',
                        py_version='py3',
                        # To train with a parameter server, uncomment (SDK v2 names this argument `distribution`):
                        # distribution={'parameter_server': {'enabled': True}},
                        hyperparameters={'learning-rate': 0.1})
```

### Running the Training Job

After we've constructed our MXNet object, we can fit it using data stored in S3. Below we run SageMaker training on two input channels: **train** and **test**.

During training, SageMaker makes the S3 data available on the local filesystem where `mnist.py` runs, so the script simply loads the train and test data from disk.
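Inside the training container, each input channel is placed under `/opt/ml/input/data/<channel>`, and the SageMaker training toolkit exposes that path through an `SM_CHANNEL_<NAME>` environment variable (here `SM_CHANNEL_TRAIN` and `SM_CHANNEL_TEST`). A minimal sketch of how a script such as `mnist.py` might resolve those directories follows; `channel_dir` is a hypothetical helper, and falling back to the default path when the variable is unset is an assumption for running outside SageMaker.

```python
import os


def channel_dir(channel: str) -> str:
    """Return the local directory SageMaker populated for an input channel."""
    env_var = "SM_CHANNEL_" + channel.upper()
    # Assumption: fall back to the conventional container path if the variable is unset.
    return os.environ.get(env_var, os.path.join("/opt/ml/input/data", channel))


train_dir = channel_dir("train")
test_dir = channel_dir("test")
print(train_dir, test_dir)
```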

```python
%%time

train_data_location = 's3://sagemaker-sample-data-{}/mxnet/mnist/train'.format(region)
test_data_location = 's3://sagemaker-sample-data-{}/mxnet/mnist/test'.format(region)

mnist_estimator.fit({'train': train_data_location, 'test': test_data_location})
```

```text
2020-09-18 19:11:17 Starting - Starting the training job...
2020-09-18 19:11:19 Starting - Launching requested ML instances......
2020-09-18 19:12:24 Starting - Preparing the instances for training......
2020-09-18 19:13:39 Downloading - Downloading input data
2020-09-18 19:13:39 Training - Downloading the training image...
2020-09-18 19:13:59 Training - Training image download completed. Training in progress.
2020-09-18 19:14:00,197 sagemaker-containers INFO     Imported framework sagemaker_mxnet_container.training
2020-09-18 19:14:00,201 sagemaker-containers INFO     No GPUs detected (normal if no gpus installed)
2020-09-18 19:14:00,215 sagemaker_mxnet_container.training INFO     MXNet training environment: {'SM_HOSTS': '["algo-1"]', 'SM_NETWORK_INTERFACE_NAME': 'eth0', 'SM_HPS': '{"learning-rate":0.1}', 'SM_USER_ENTRY_POINT': 'mnist.py', 'SM_FRAMEWORK_PARAMS': '{}', 'SM_RESOURCE_CONFIG': '{"current_host":"algo-1","hosts":["algo-1"],"network_interface_name":"eth
```

```python
# Register the training and test data locations as artifacts, reusing existing ones.
list_response = list(artifact.Artifact.list(source_uri=train_data_location, sagemaker_boto_client=sagemaker_client))
if list_response:
    train_data_location_artifact_arn = list_response[0].artifact_arn
else:
    train_data_location_artifact_arn = artifact.Artifact.create(
        artifact_name='train-data',
        artifact_type='TrainingData',
        source_uri=train_data_location,
        sagemaker_boto_client=sagemaker_client,
    ).artifact_arn

list_response = list(artifact.Artifact.list(source_uri=test_data_location, sagemaker_boto_client=sagemaker_client))
if list_response:
    test_data_location_artifact_arn = list_response[0].artifact_arn
else:
    test_data_location_artifact_arn = artifact.Artifact.create(
        artifact_name='test-data',
        artifact_type='TestData',
        source_uri=test_data_location,
        sagemaker_boto_client=sagemaker_client,
    ).artifact_arn
```

```python
from botocore.exceptions import ClientError

# Associate the artifacts with the trial component that SageMaker created for the training job.
training_job_name = mnist_estimator.latest_training_job.job_name

trial_component = sagemaker_client.describe_trial_component(TrialComponentName=training_job_name + '-aws-training-job')
trial_component_arn = trial_component['TrialComponentArn']

# Inputs (code and data) contributed to the training job.
input_artifacts = [code_artifact_arn, train_data_location_artifact_arn, test_data_location_artifact_arn]
for artifact_arn in input_artifacts:
    try:
        association.Association.create(
            source_arn=artifact_arn,
            destination_arn=trial_component_arn,
            association_type='ContributedTo',
            sagemaker_boto_client=sagemaker_client,
        )
    except ClientError as e:
        # Most likely the association already exists from a previous run.
        logging.info('could not associate %s with %s: %s', artifact_arn, trial_component_arn, e)

# The training job produced the model artifacts.
output_artifacts = [model_location_artifact_arn]
for artifact_arn in output_artifacts:
    try:
        association.Association.create(
            source_arn=trial_component_arn,
            destination_arn=artifact_arn,
            association_type='Produced',
            sagemaker_boto_client=sagemaker_client,
        )
    except ClientError as e:
        logging.info('could not associate %s with %s: %s', trial_component_arn, artifact_arn, e)
```
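The two loops above encode a direction convention: inputs point into the training job's trial component with `ContributedTo`, and the trial component points out to its outputs with `Produced`. The sketch below builds those (source, destination, type) triples as plain data so the convention can be checked without AWS; `plan_associations` and `trial_component_name` are hypothetical helpers, and the `-aws-training-job` suffix is the naming used in the cell above.

```python
from typing import List, NamedTuple


class Edge(NamedTuple):
    source_arn: str
    destination_arn: str
    association_type: str


def trial_component_name(training_job_name: str) -> str:
    # Same naming as in the cell above: job name plus "-aws-training-job".
    return training_job_name + "-aws-training-job"


def plan_associations(trial_component_arn: str, inputs: List[str], outputs: List[str]) -> List[Edge]:
    """Return the lineage edges to create around one training job."""
    edges = [Edge(arn, trial_component_arn, "ContributedTo") for arn in inputs]
    edges += [Edge(trial_component_arn, arn, "Produced") for arn in outputs]
    return edges


for edge in plan_associations("tc-arn", ["code-arn", "train-arn", "test-arn"], ["model-arn"]):
    print(edge)
```

Because the field names match the keyword arguments used above, each edge could be passed on as `association.Association.create(**edge._asdict(), sagemaker_boto_client=sagemaker_client)`.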

### Creating an inference Endpoint

After training, we use the ``MXNet estimator`` object to build and deploy an ``MXNetPredictor``. This creates a SageMaker **Endpoint**, a hosted prediction service that we can use to perform inference.

The arguments to the ``deploy`` function allow us to set the number and type of instances that will be used for the Endpoint. These do not need to be the same as the values we used for the training job. For example, you can train a model on a set of GPU-based instances, and then deploy the Endpoint to a fleet of CPU-based instances. Here we will deploy the model to a single ``ml.m4.xlarge`` instance.

```python
%%time

predictor = mnist_estimator.deploy(initial_instance_count=1,
                                   instance_type='ml.m4.xlarge')
```

```text
CPU times: user 368 ms, sys: 12.5 ms, total: 380 ms
Wall time: 6min 32s
```


```python
# Register the endpoint as a lineage context, reusing an existing one if present.
endpoint = sagemaker_client.describe_endpoint(EndpointName=predictor.endpoint_name)
endpoint_arn = endpoint['EndpointArn']

list_response = list(context.Context.list(source_uri=endpoint_arn, sagemaker_boto_client=sagemaker_client))
if list_response:
    endpoint_context_arn = list_response[0].context_arn
else:
    endpoint_context_arn = context.Context.create(
        context_name=predictor.endpoint_name,
        context_type='Endpoint',
        source_uri=endpoint_arn,
        sagemaker_boto_client=sagemaker_client,
    ).context_arn

# Link the training job's trial component to the endpoint context.
association.Association.create(
    source_arn=trial_component_arn,
    destination_arn=endpoint_context_arn,
    sagemaker_boto_client=sagemaker_client,
)
```

```text
Association(sagemaker_boto_client=<botocore.client.SageMaker object at 0x7ff40a982d68>,source_arn='arn:aws:sagemaker:us-east-2:707662012936:experiment-trial-component/mxnet-training-2020-09-18-19-07-09-802-aws-training-job',destination_arn='arn:aws:sagemaker:us-east-2:707662012936:context/mxnet-training-2020-09-18-19-21-54-609',association_type=None,response_metadata={'RequestId': 'd6391fc9-edd2-43b9-8338-d996287f1140', 'HTTPStatusCode': 200, 'HTTPHeaders': {'x-amzn-requestid': 'd6391fc9-edd2-43b9-8338-d996287f1140', 'content-type': 'application/x-amz-json-1.1', 'content-length': '246', 'date': 'Fri, 18 Sep 2020 19:28:54 GMT'}, 'RetryAttempts': 0})
```

```python
# Delete the endpoint to stop incurring charges.
predictor.delete_endpoint()
```

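```python
# lineage_visualizer.py is a helper script that ships with this notebook (it is not part of the SDK).
# %run executes it in the notebook namespace, which defines LineageVisualizer.
%run lineage_visualizer.py

vis = LineageVisualizer(sagemaker_client)
# Visualize the lineage around the endpoint context (upstream and downstream, judging by the method name).
vis.both(endpoint_context_arn)
```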
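The visualizer's internals are not shown here. Conceptually, walking lineage is a graph search over associations: upstream follows edges backwards to their sources, downstream follows them forwards to their destinations. A minimal breadth-first sketch follows; `traverse` and `make_lookups` are hypothetical, and the edge lookups are passed in as functions. Against the live service, a lookup could be built on the boto3 SageMaker client's `list_associations` call (filtering by `DestinationArn` for upstream or `SourceArn` for downstream) and its `AssociationSummaries`; pagination is omitted from this sketch.

```python
from collections import deque
from typing import Callable, Dict, Iterable, List, Set, Tuple

Edge = Tuple[str, str]  # (source_arn, destination_arn)


def traverse(start: str, neighbors: Callable[[str], Iterable[str]]) -> List[str]:
    """Breadth-first walk from start; returns ARNs in visit order, excluding start."""
    seen: Set[str] = {start}
    order: List[str] = []
    queue = deque([start])
    while queue:
        node = queue.popleft()
        for nxt in neighbors(node):
            if nxt not in seen:  # lineage graphs can share ancestors; visit each node once
                seen.add(nxt)
                order.append(nxt)
                queue.append(nxt)
    return order


def make_lookups(edges: Iterable[Edge]):
    """Index edges both ways: upstream(dst) -> sources, downstream(src) -> destinations."""
    up: Dict[str, List[str]] = {}
    down: Dict[str, List[str]] = {}
    for src, dst in edges:
        up.setdefault(dst, []).append(src)
        down.setdefault(src, []).append(dst)
    return (lambda arn: up.get(arn, [])), (lambda arn: down.get(arn, []))


# Toy graph mirroring this notebook: code and data -> training job -> model location and endpoint.
edges = [
    ("code", "training-job"),
    ("train-data", "training-job"),
    ("test-data", "training-job"),
    ("training-job", "model-location"),
    ("training-job", "endpoint"),
]
upstream, downstream = make_lookups(edges)
print(traverse("endpoint", upstream))  # ['training-job', 'code', 'train-data', 'test-data']
print(traverse("code", downstream))  # ['training-job', 'model-location', 'endpoint']
```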

```python
# write_yaml() returns the name of the file it wrote; print its contents.
file_name = vis.write_yaml()
with open(file_name, "r") as f:
    print(f.read())
```