# Amazon SageMaker administration and security workshop: Lab 3

This notebook contains hands-on exercises for the workshop **Amazon SageMaker administration and security** – Lab 3.

## Import packages and load variables

In [None]:
import time
import os
import json
import boto3
import numpy as np  
import pandas as pd 
import sagemaker
from sagemaker.network import NetworkConfig
from sagemaker.sklearn.processing import SKLearnProcessor
from sagemaker.processing import ProcessingInput, ProcessingOutput

sagemaker.__version__

In [None]:
%store -r 

%store

try:
    initialized
except NameError:
    print("++++++++++++++++++++++++++++++++++++++++++")
    print("[ERROR] YOU HAVE TO RUN 01-lab-01 notebook         ")
    print("++++++++++++++++++++++++++++++++++++++++++")

In [None]:
# Get some variables you need to interact with SageMaker service
boto_session = boto3.Session()
region = boto_session.region_name
sm_session = sagemaker.Session()
bucket_name = sm_session.default_bucket()
bucket_prefix = "sm-admin-workshop/xgboost"
sm_client = boto_session.client("sagemaker")
ssm = boto3.client("ssm")
sm_role = sagemaker.get_execution_role()

## Logging and monitoring
This section demonstrates the basics of SageMaker log handling.

### SageMaker CloudWatch logs
Refer to the Developer Guide documentation [Log Amazon SageMaker Events with Amazon CloudWatch](https://docs.aws.amazon.com/sagemaker/latest/dg/logging-cloudwatch.html) for basic facts about logging and managing SageMaker events with Amazon CloudWatch.
The following code shows how to describe log groups and log streams, and how to retrieve log events, using the CloudWatch Logs API through boto3.

In [None]:
logs = boto3.client("logs")

In [None]:
# Use SageMaker log group prefix
sagemaker_log_group_prefix = "/aws/sagemaker"
studio_log_group = f"{sagemaker_log_group_prefix}/studio"

In [None]:
# Retrieve log groups with SageMaker events 
logs.describe_log_groups(
    logGroupNamePrefix=sagemaker_log_group_prefix,
    limit=10,
)

In [None]:
# Get some log streams from the Studio logs
r = logs.describe_log_streams(
    logGroupName=studio_log_group,
    descending=True,
    limit=3,
)
r

In [None]:
# Get some events from the first log stream in the Studio log group
logs.get_log_events(
    logGroupName=studio_log_group,
    logStreamName=r["logStreams"][0]["logStreamName"],
    limit=10,
)
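
The raw `GetLogEvents` response is verbose. As a minimal sketch, the helper below flattens it into readable lines. `format_log_events` is a hypothetical name, and the code assumes the response shape returned by boto3, where each event carries a `timestamp` in milliseconds since the epoch and a `message`. The sample response is a stand-in so that the sketch runs without AWS access.

```python
from datetime import datetime, timezone


def format_log_events(response):
    """Return one 'ISO-8601 timestamp  message' line per event in a GetLogEvents response."""
    lines = []
    for event in response.get("events", []):
        # CloudWatch Logs timestamps are milliseconds since the Unix epoch (UTC)
        ts = datetime.fromtimestamp(event["timestamp"] / 1000, tz=timezone.utc)
        lines.append(f"{ts.isoformat()}  {event['message'].rstrip()}")
    return lines


# Stand-in for a real response (hypothetical messages)
sample_response = {
    "events": [
        {"timestamp": 0, "message": "kernel started\n"},
        {"timestamp": 1500, "message": "kernel ready"},
    ]
}
for line in format_log_events(sample_response):
    print(line)
```

In the notebook, the same helper could be applied to the result of the `logs.get_log_events(...)` call above.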

### Isolation of CloudWatch logs for multi-domain setup
This section demonstrates how to use the `DOMAIN-ID` and identity-based IAM permission policies to restrict access to CloudWatch log events to your own SageMaker domain.

In [None]:
# Get user profile name
NOTEBOOK_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json"

# Check what profile you're currently in
if os.path.exists(NOTEBOOK_METADATA_FILE):
    with open(NOTEBOOK_METADATA_FILE, "rb") as f:
        user_profile_name = json.loads(f.read())['UserProfileName']
        print(f"User profile: {user_profile_name}")

In [None]:
domain_id, user_profile_name

#### Access to log events without isolation
The user profile execution role doesn't have any restrictions on accessing SageMaker log groups, streams, and events in CloudWatch.
Run the following code cells and verify that you have access to all `/aws/sagemaker/` log groups and log streams within these groups.

In [None]:
# Access to log groups with SageMaker events 
logs.describe_log_groups(
    logGroupNamePrefix=sagemaker_log_group_prefix,
    limit=10,
)

In [None]:
# Access to log streams within Studio log group
r = logs.describe_log_streams(
    logGroupName=studio_log_group,
    descending=True,
    limit=3,
)
r

In [None]:
# Access to log events within any Studio log stream for any domain
logs.get_log_events(
    logGroupName=studio_log_group,
    logStreamName=r["logStreams"][0]["logStreamName"],
    limit=10,
)

#### Access to log events with isolation
Let's implement log access isolation based on the `DOMAIN-ID`. Navigate to the IAM console using the following link:

In [None]:
from IPython.display import display, HTML

# Execute this cell to show the execution role IAM console link
display(
    HTML(
        '<b>Add the log isolation policy to the execution role <a target="top" href="https://us-east-1.console.aws.amazon.com/iamv2/home#/roles/details/{}?section=permissions">AWS IAM console.</a></b>'.format(
            sagemaker.get_execution_role().split("/")[-1])
    )
)

In [None]:
account_id, domain_id

Add the following IAM inline permission policy to the user execution role:

```json
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Sid": "DenyNotownedDomainLogs",
            "Effect": "Deny",
            "Action": [
                "logs:GetLogEvents"
            ],
            "NotResource": [
                "arn:aws:logs:*:<ACCOUNT-ID>:log-group:/aws/sagemaker/*:log-stream:<DOMAIN-ID>*"
            ]
        }
    ]
}
```

Replace `<ACCOUNT-ID>` and `<DOMAIN-ID>` with the values printed by the previous cell, then save the changes.
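
To avoid copy-and-paste mistakes, the policy document can also be generated from the notebook variables. The following is a minimal sketch: `build_log_isolation_policy` is a hypothetical helper, and the account and domain IDs passed to it are placeholders, not real values.

```python
import json


def build_log_isolation_policy(account_id, domain_id):
    """Return the inline policy document with the placeholders filled in."""
    resource = (
        f"arn:aws:logs:*:{account_id}:log-group:/aws/sagemaker/*"
        f":log-stream:{domain_id}*"
    )
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Sid": "DenyNotownedDomainLogs",
                "Effect": "Deny",
                "Action": ["logs:GetLogEvents"],
                "NotResource": [resource],
            }
        ],
    }


# Placeholder values; in the notebook, pass account_id and domain_id instead
policy = build_log_isolation_policy("111122223333", "d-exampledomain")
print(json.dumps(policy, indent=4))
```

The printed JSON can be pasted into the IAM console policy editor.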

In [None]:
# You no longer have access to log events in streams that belong to other domains.
# The following `GetLogEvents` call fails with an AccessDeniedException when the first stream belongs to another domain
logs.get_log_events(
    logGroupName=studio_log_group,
    logStreamName=r["logStreams"][0]["logStreamName"],
    limit=10,
)
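
If you prefer a short message over a full traceback, the denied call can be wrapped in a `try`/`except` block. The sketch below assumes the error follows the botocore `ClientError` convention of exposing the service error code in `error.response["Error"]["Code"]`. `is_access_denied` and `FakeClientError` are hypothetical names, and the fake error class only exists so that the example runs offline.

```python
def is_access_denied(error):
    """Return True if a botocore-style error carries an access-denied error code."""
    # botocore's ClientError exposes the service error as error.response["Error"]["Code"]
    response = getattr(error, "response", None) or {}
    code = response.get("Error", {}).get("Code", "")
    return code in ("AccessDeniedException", "AccessDenied")


class FakeClientError(Exception):
    """Stand-in for botocore.exceptions.ClientError so the sketch runs offline."""

    def __init__(self, code):
        super().__init__(code)
        self.response = {"Error": {"Code": code, "Message": "example"}}


try:
    raise FakeClientError("AccessDeniedException")
except Exception as err:
    if is_access_denied(err):
        print("Access denied, as the isolation policy intends")
    else:
        raise
```

In the notebook, the `raise FakeClientError(...)` line would be replaced by the `logs.get_log_events(...)` call, and the `except` clause could catch `botocore.exceptions.ClientError` specifically.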

You can access only the log events in your domain-specific log stream.

In [None]:
studio_jupyter_server_log_stream = f"{domain_id}/{user_profile_name}/JupyterServer/default"

In [None]:
logs.describe_log_streams(
    logGroupName=studio_log_group,
    logStreamNamePrefix=f"{domain_id}/",
    descending=True,
    limit=3,
)

In [None]:
# Access only domain-specific log events and log streams
logs.get_log_events(
    logGroupName=studio_log_group,
    logStreamName=studio_jupyter_server_log_stream,
    limit=10,
)
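
To reason about which log streams the `Deny` statement with `NotResource` affects, you can approximate the wildcard matching locally. This is only a sketch: `is_denied` is a hypothetical helper, the account and domain IDs are placeholders, and `fnmatchcase` is used as a rough stand-in for IAM wildcard matching. It models the explicit deny only; a real authorization decision also requires a matching `Allow`.

```python
from fnmatch import fnmatchcase


def is_denied(stream_arn, not_resource_patterns):
    """Approximate Deny + NotResource: deny every ARN that matches none of the patterns."""
    # fnmatchcase treats '*' and '?' like IAM wildcards; it also interprets '[...]',
    # which IAM does not, so this is only an approximation for simple ARNs
    return not any(fnmatchcase(stream_arn, p) for p in not_resource_patterns)


# Placeholder account and domain IDs
patterns = ["arn:aws:logs:*:111122223333:log-group:/aws/sagemaker/*:log-stream:d-owndomain*"]
group = "arn:aws:logs:us-east-1:111122223333:log-group:/aws/sagemaker/studio"

own = f"{group}:log-stream:d-owndomain/user-1/JupyterServer/default"
other = f"{group}:log-stream:d-otherdomain/user-2/JupyterServer/default"

print(is_denied(own, patterns))    # False: the stream belongs to the own domain
print(is_denied(other, patterns))  # True: the stream belongs to another domain
```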

### Logging with CloudTrail
Follow the instructions in the workshop lab 3 - Step 1.
Run the following cell to generate `DescribeDomain` API log entries in the CloudTrail event history.

In [None]:
sm_client.describe_domain(DomainId=domain_id)
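
The generated entries can also be looked up programmatically with the CloudTrail `LookupEvents` API. The following is a minimal sketch, not part of the workshop instructions: `recent_events` and `StubCloudTrail` are hypothetical names, and the stub only mimics the response shape so that the example runs without AWS access.

```python
def recent_events(cloudtrail_client, event_name, max_results=5):
    """Return (event time, user name) pairs for the most recent events with this name."""
    response = cloudtrail_client.lookup_events(
        LookupAttributes=[
            {"AttributeKey": "EventName", "AttributeValue": event_name}
        ],
        MaxResults=max_results,
    )
    return [(e.get("EventTime"), e.get("Username")) for e in response.get("Events", [])]


class StubCloudTrail:
    """Offline stand-in that mimics the shape of a LookupEvents response."""

    def lookup_events(self, **kwargs):
        return {
            "Events": [
                {
                    "EventName": "DescribeDomain",
                    "EventTime": "2024-01-01T00:00:00Z",  # placeholder value
                    "Username": "example-user",  # placeholder value
                }
            ]
        }


print(recent_events(StubCloudTrail(), "DescribeDomain"))
```

In the notebook you would pass `boto3.client("cloudtrail")` instead of the stub. This assumes the execution role is allowed to call `cloudtrail:LookupEvents`, and new events may take a few minutes to appear in the event history.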

## Security controls

### Preventive
In this section you experiment with IAM policies and condition keys. Follow the instructions in the workshop lab 3 - Step 2.

In [None]:
# Account id and region
account_id = boto3.client("sts").get_caller_identity()["Account"]
region = boto3.Session().region_name

account_id, region

In [None]:
security_group_ids = ssm.get_parameter(Name=f"sagemaker-admin-workshop-{region}-{account_id}-sagemaker-sg-ids")["Parameter"]["Value"]
private_subnet_ids = ssm.get_parameter(Name=f"sagemaker-admin-workshop-{region}-{account_id}-private-subnet-ids")["Parameter"]["Value"]
ebs_key_arn = ssm.get_parameter(Name=f"sagemaker-admin-workshop-{region}-{account_id}-kms-ebs-key-arn")["Parameter"]["Value"]

security_group_ids, private_subnet_ids, ebs_key_arn

In [None]:
# Construct the NetworkConfig with the values for your environment
network_config = NetworkConfig(
        enable_network_isolation=False, 
        security_group_ids=security_group_ids.split(','),
        subnets=private_subnet_ids.split(','),
        encrypt_inter_container_traffic=True)

In [None]:
framework_version = "0.23-1"
processing_instance_type = "ml.m5.large"
processing_instance_count = 1

In [None]:
# Define processing inputs and outputs
processing_inputs = [
        ProcessingInput(
            source=input_s3_url, 
            destination="/opt/ml/processing/input",
            s3_input_mode="File",
            s3_data_distribution_type="ShardedByS3Key"
        )
]

processing_outputs = [
        ProcessingOutput(
            output_name="train_data", 
            source="/opt/ml/processing/output/train",
            destination=train_s3_url,
        ),
        ProcessingOutput(
            output_name="validation_data", 
            source="/opt/ml/processing/output/validation", 
            destination=validation_s3_url
        ),
        ProcessingOutput(
            output_name="test_data", 
            source="/opt/ml/processing/output/test", 
            destination=test_s3_url
        ),
]

In [None]:
# Create a processor
sklearn_processor = SKLearnProcessor(
    framework_version=framework_version,
    role=sm_role,
    instance_type=processing_instance_type,
    instance_count=processing_instance_count, 
    base_job_name='sm-admin-workshop-processing',
    sagemaker_session=sm_session,
    network_config=network_config,
    volume_kms_key=ebs_key_arn,
)

In [None]:
# Start the processing job - the call will be successful
sklearn_processor.run(
        inputs=processing_inputs,
        outputs=processing_outputs,
        code='preprocessing.py',
        wait=False,
)

#### Enforce the designated subnets
Next, enforce specific values in the network configuration. Update the preventive IAM policy attached to the user profile execution role as instructed in the workshop lab.
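
IAM evaluates the policy on the service side, but the same rule can be checked locally before a job is submitted. The sketch below assumes that the policy permits only an explicit list of subnet IDs. `unapproved_subnets` is a hypothetical helper and the subnet IDs are placeholders.

```python
def unapproved_subnets(requested_subnet_ids, allowed_subnet_ids):
    """Return the requested subnet IDs that are not in the allowed list, in request order."""
    allowed = set(allowed_subnet_ids)
    return [s for s in requested_subnet_ids if s not in allowed]


# Placeholder IDs; in the notebook the allowed list would come from the SSM parameter
allowed = "subnet-aaaa1111,subnet-bbbb2222".split(",")

print(unapproved_subnets(["subnet-aaaa1111"], allowed))  # []
print(unapproved_subnets(["subnet-cccc3333"], allowed))  # ['subnet-cccc3333']
```

A non-empty result indicates that the job request would use a subnet outside the designated list.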

In [None]:
# Set private_subnet_ids to a subnet that is not in the allowed list
private_subnet_ids = "subnet-011e4fcfca10fffea"

In [None]:
# Create the NetworkConfig
network_config = NetworkConfig(
        enable_network_isolation=False, 
        security_group_ids=security_group_ids.split(','),
        subnets=private_subnet_ids.split(','),
        encrypt_inter_container_traffic=True)

In [None]:
# Create a processor
sklearn_processor = SKLearnProcessor(
    framework_version=framework_version,
    role=sm_role,
    instance_type=processing_instance_type,
    instance_count=processing_instance_count, 
    base_job_name='sm-admin-workshop-processing',
    sagemaker_session=sm_session,
    network_config=network_config,
    volume_kms_key=ebs_key_arn,
)

In [None]:
# Start the processing job - the call is expected to fail with an AccessDeniedException
sklearn_processor.run(
        inputs=processing_inputs,
        outputs=processing_outputs,
        code='preprocessing.py',
        wait=True,
)

## End of lab 3

---

## Shutdown kernel

In [None]:
%%html

<p><b>Shutting down your kernel for this notebook to release resources.</b></p>
<button class="sm-command-button" data-commandlinker-command="kernelmenu:shutdown" style="display:none;">Shutdown Kernel</button>
        
<script>
try {
    els = document.getElementsByClassName("sm-command-button");
    els[0].click();
}
catch(err) {
    // NoOp
}    
</script>