# Lineage

This notebook shows how to use [SageMaker Lineage](https://sagemaker.readthedocs.io/en/stable/workflows/lineage/index.html) to retrieve the links between the different steps of the ML workflow.

***
This notebook was written for the `Data Science 3.0 (Python 3)` kernel.
***

In [None]:
# Install pyvis quietly into the kernel's environment (`%pip`, not `!pip`).
# NOTE(review): pyvis is not imported directly in this notebook; presumably it is
# required by the `visualizer` module imported below — confirm, and consider pinning a version.
%pip install -q pyvis

In [None]:
import pprint

import sagemaker
from sagemaker.lineage.action import Action
from sagemaker.lineage.artifact import Artifact, DatasetArtifact, ModelArtifact
from sagemaker.lineage.association import Association
from sagemaker.lineage.context import Context, EndpointContext
from sagemaker.lineage.query import (
    LineageEntityEnum,
    LineageFilter,
    LineageQuery,
    LineageQueryDirectionEnum,
    LineageSourceEnum,
)
from visualizer import Visualizer

In [None]:
# Restore `project_name` from IPython's persistent %store.
# NOTE(review): presumably it was saved with `%store project_name` by an earlier
# notebook in this series — confirm; the cells below need it to be defined.
%store -r project_name

In [None]:
# SageMaker session and the handles derived from it, shared by the cells below.
sagemaker_session = sagemaker.Session()
sm_client = sagemaker_session.sagemaker_client
region = sagemaker_session.boto_region_name
default_bucket = sagemaker_session.default_bucket()
role = sagemaker.get_execution_role()

# Resolve the ID of the SageMaker project restored from %store.
project_description = sm_client.describe_project(ProjectName=project_name)
project_id = project_description["ProjectId"]

# Pretty-printer used to display query outputs.
pp = pprint.PrettyPrinter()

print(f"Project: {project_name} ({project_id})")

In [None]:
# The endpoint name is the project name with a "-staging" suffix.
endpoint_name = f"{project_name}-staging"

# Describe the endpoint and keep only its ARN (None if the key is absent).
endpoint_description = sm_client.describe_endpoint(EndpointName=endpoint_name)
endpoint_arn = endpoint_description.get("EndpointArn")

In [None]:
# Look up the lineage context(s) whose source URI is the endpoint ARN.
contexts = Context.list(source_uri=endpoint_arn)

# Take the first match without indexing into a materialised list, so that an endpoint
# with no lineage context produces an actionable error instead of a bare
# "list index out of range".
first_context = next(iter(contexts), None)
if first_context is None:
    raise LookupError(
        f"No lineage context found for endpoint '{endpoint_name}' ({endpoint_arn}). "
        "Check that the endpoint exists and that lineage tracking has recorded it."
    )

# Load the matching context as an EndpointContext, which the cells below query.
context_name = first_context.context_name
endpoint_context = EndpointContext.load(context_name=context_name)

## Registered model package from Endpoint

In [None]:
# Restrict the traversal to artifacts whose source type is MODEL and whose
# `ApprovalStatus` property is "Approved".
query_filter = LineageFilter(
    entities=[LineageEntityEnum.ARTIFACT],
    sources=[LineageSourceEnum.MODEL],
    properties={"ApprovalStatus": "Approved"},
)

# Start from the endpoint context and traverse the lineage graph in the ASCENDANTS
# direction, keeping only the vertices that match the filter (edges are not requested).
lineage_query = LineageQuery(sagemaker_session)
query_result = lineage_query.query(
    start_arns=[endpoint_context.context_arn],
    query_filter=query_filter,
    direction=LineageQueryDirectionEnum.ASCENDANTS,
    include_edges=False,
)

# Resolve every returned vertex to its lineage object and collect its source URI.
model_artifacts = [
    vertex.to_lineage_object().source.source_uri for vertex in query_result.vertices
]

# Show the source URIs of the approved model artifacts found for this endpoint.
pp.pprint(model_artifacts)

## Pipeline and pipeline execution from Endpoint

In [None]:
# Display the pipeline execution ARN that the endpoint context reports for this endpoint.
endpoint_context.pipeline_execution_arn()

In [None]:
# Fetch the execution ARN once and derive the pipeline ARN from it.
pipeline_execution_arn = endpoint_context.pipeline_execution_arn()

if pipeline_execution_arn:
    # The pipeline ARN is the part before the trailing "/execution/<id>" segment.
    # Split on the LAST "/execution/" so a pipeline whose name itself starts with
    # "execution" is not truncated.
    pipeline_arn = pipeline_execution_arn.rsplit("/execution/", 1)[0]
else:
    # NOTE(review): the SDK documents this ARN as present "if any"; guard against an
    # empty result rather than failing with AttributeError on `.split` — confirm.
    pipeline_arn = None
    print("No pipeline execution is associated with this endpoint.")

pipeline_arn

## Training job from Endpoint

In [None]:
# Restrict the traversal to entities of type `TRIAL_COMPONENT` whose source type
# is `TRAINING_JOB`.
query_filter = LineageFilter(
    entities=[LineageEntityEnum.TRIAL_COMPONENT],
    sources=[LineageSourceEnum.TRAINING_JOB],
)

# Start from the endpoint context and traverse the lineage graph in the ASCENDANTS
# direction, keeping only the vertices that match the filter (edges are not requested).
lineage_query = LineageQuery(sagemaker_session)
query_result = lineage_query.query(
    start_arns=[endpoint_context.context_arn],
    query_filter=query_filter,
    direction=LineageQueryDirectionEnum.ASCENDANTS,
    include_edges=False,
)

# Collect the ARN of every matching trial-component vertex associated with this endpoint.
trial_components = [vertex.arn for vertex in query_result.vertices]

pp.pprint(trial_components)

In [None]:
# Ask the endpoint context for its training job ARNs and materialise them as a list.
training_jobs = list(endpoint_context.training_job_arns())
formatted_training_jobs = pp.pformat(training_jobs)
print("Training Jobs :\n", formatted_training_jobs)

## Datasets from Endpoint

In [None]:
# Print the source URI of each dataset artifact reported by the endpoint context.
print("Datasets:")
for dataset_artifact in endpoint_context.dataset_artifacts():
    dataset_uri = dataset_artifact.source.source_uri
    print(dataset_uri)

## Visualizing the lineage

In [None]:
# Instantiate the `Visualizer` imported from the `visualizer` module; the cells below
# call its `render` method on lineage query responses.
viz = Visualizer()

In [None]:
# Query the lineage graph in the "Ascendants" direction from the endpoint context,
# including edges so the graph can be drawn.
lineage_request = {
    "StartArns": [endpoint_context.context_arn],
    "Direction": "Ascendants",
    "IncludeEdges": True,
}
query_response = sm_client.query_lineage(**lineage_request)

# Render the response with the endpoint as the labelled starting entity.
viz.render(query_response, "Endpoint", height=500, width=1200)

In [None]:
# Query the lineage graph in both directions from the endpoint context,
# including edges so the graph can be drawn.
lineage_request = {
    "StartArns": [endpoint_context.context_arn],
    "Direction": "Both",
    "IncludeEdges": True,
}
query_response = sm_client.query_lineage(**lineage_request)

# Render the response on a taller canvas than the ascendants-only view.
viz.render(query_response, "Endpoint", height=700, width=1200)