# Training Job with Encrypted Static Assets

In the [notebook about creating a training job in VPC mode](https://github.com/aws/amazon-sagemaker-examples/blob/master/sagemaker-fundamentals/create-training-job/create_training_job_vpc.ipynb) you learned how to create a SageMaker training job with network isolation. Network isolation helps protect your data and model from being intercepted by attackers.

Another way to protect your static assets is to encrypt them before moving them from location A to location B. In this notebook, you will walk through a few techniques for doing that with the help of AWS Key Management Service [(AWS KMS)](https://docs.aws.amazon.com/kms/latest/developerguide/overview.html).

The following materials are helpful to get you started if you are not familiar with cryptography:

* [AWS KMS Cryptography Details](https://docs.aws.amazon.com/kms/latest/cryptographic-details/intro.html)
* [Wikipedia page](https://en.wikipedia.org/wiki/Encryption)
* [Chapter 2 of GNU Privacy Handbook](https://www.gnupg.org/gph/en/manual.html)

You are strongly encouraged to go through the [overview](https://docs.aws.amazon.com/kms/latest/developerguide/overview.html), [concepts](https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html) and [get started](https://docs.aws.amazon.com/kms/latest/developerguide/getting-started.html) sections of the KMS documentation before going through this notebook. This will help you get familiar with some of the terminology we will be using later.

Encryption is a widely used technology; in addition to the introductory material above, you can find many free lectures online.

## Symmetric Ciphers
We will focus on symmetric ciphers in this notebook. To quote the GNU Privacy Handbook:

> A symmetric cipher is a cipher that uses the same key for both encryption and decryption. Two parties communicating using a symmetric cipher must agree on the key beforehand. Once they agree, the sender encrypts a message using the key, sends it to the receiver, and the receiver decrypts the message using the key. As an example, the German Enigma is a symmetric cipher, and daily keys were distributed as code books. Each day, a sending or receiving radio operator would consult his copy of the code book to find the day's key. Radio traffic for that day was then encrypted and decrypted using the day's key. Modern examples of symmetric ciphers include 3DES, Blowfish, and IDEA.

## Environment to run this notebook
You can run this notebook on your local machine or an EC2 instance as an IAM user, or on a SageMaker Notebook Instance as a SageMaker service role. To avoid confusion, we will assume you are running it as an IAM user.

## Permissions
You will need to attach the following permissions to the IAM user:

* IAMFullAccess 
* AWSKeyManagementServicePowerUser
* AmazonEC2ContainerRegistryFullAccess

## Outline of this notebook

* Generate a symmetric KMS key
* Allow your SageMaker service role to use the KMS key
* Generate a data key from the KMS key
* Encrypt some data with the data key and upload the encrypted data to S3
* Create a SageMaker service role
* Build a training image 
* Create a SageMaker training job using the encrypted data
* Verify that data retrieved from S3 is encrypted and SageMaker needs your data key to decrypt

The process of using a data key to encrypt your data, instead of using the KMS key directly, is called [**envelope encryption**](https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html#enveloping).
You can use the KMS key to encrypt your data directly, but the KMS `Encrypt` API only accepts up to 4 KB of plaintext, and the data itself has to be sent to KMS. With a data key, the data is encrypted locally and never leaves your machine unencrypted, which reduces the risk of a [man-in-the-middle attack](https://en.wikipedia.org/wiki/Man-in-the-middle_attack).
We will discuss the use of data keys in detail later.

![envelope-encryption](assets/key-hierarchy-cmk.png)

In [9]:
# setup
import boto3
import datetime
import json
import pprint

pp = pprint.PrettyPrinter(indent=1)
kms = boto3.client('kms') 

In [41]:
# Some helper functions

def current_time():
    # timestamp such as '2021-04-06-19-14-33', safe to use in S3 bucket
    # and training job names (no colons or spaces)
    return datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

def account_id():
    return boto3.client('sts').get_caller_identity()['Account']

### Generate a symmetric KMS key

You will use the [kms:CreateKey](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/kms.html#KMS.Client.create_key) API to generate a **symmetric key** used for **encryption** and **decryption**. You need a key policy to define who has access to the key, and with what level of access.
If you create the key from the AWS console and follow the default steps, you will end up with the following key policy:

In [None]:
root_arn = f"arn:aws:iam::{account_id()}:root"
user_arn = boto3.client('sts').get_caller_identity()['Arn']

key_policy = {
    "Id": "key-consolepolicy-3",
    "Version": "2012-10-17",
    "Statement": [
        {
            "Sid": "Enable IAM User Permissions",
            "Effect": "Allow",
            "Principal": {
                "AWS": root_arn # enable root user to perform all actions
            },
            "Action": "kms:*",
            "Resource": "*"
        },
        
        {
            "Sid": "Allow access for Key Administrators",
            "Effect": "Allow",
            "Principal": {
                "AWS": [user_arn]   # give myself admin permission to this key
                                     # you can add more admin users by appending this list
            },
            "Action": [
                "kms:Create*",
                "kms:Describe*",
                "kms:Enable*",
                "kms:List*",
                "kms:Put*",
                "kms:Update*",
                "kms:Revoke*",
                "kms:Disable*",
                "kms:Get*",
                "kms:Delete*",
                "kms:TagResource",
                "kms:UntagResource",
                "kms:ScheduleKeyDeletion",
                "kms:CancelKeyDeletion"
            ],
            "Resource": "*"
        },
        {
            "Sid": "Allow use of the key",
            "Effect": "Allow",
            "Principal": {
                "AWS": [user_arn]   # allow myself to use the key
                                     # you can add more users / roles to this list
                                     # for example you can add SageMaker service role 
                                     # here. But we will allow SageMaker service role
                                     # to use this key via grant (see below)

            },
            "Action": [
                "kms:Encrypt",
                "kms:Decrypt",
                "kms:ReEncrypt*",
                "kms:GenerateDataKey*",
                "kms:DescribeKey"
            ],
            "Resource": "*"
        },
        {
            "Sid": "Allow attachment of persistent resources",
            "Effect": "Allow",
            "Principal": {
                "AWS": [user_arn] # allow myself to create grant for this key
                                   # see ref below to understand the diff 
                                   # between user and grant
                                   # https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html#grant
            },
            "Action": [
                "kms:CreateGrant",
                "kms:ListGrants",
                "kms:RevokeGrant"
            ],
            "Resource": "*",
            "Condition": {
                "Bool": {
                    "kms:GrantIsForAWSResource": "true"
                }
            }
        }
    ]
}

key_policy = json.dumps(key_policy)

You can either create a new key or use an existing one. To create a new key, set the `create_new_key` variable to `True`. To use an existing key, leave `create_new_key` set to `False` and replace `None` in the line `kms_key = None` with your key id. Note that to run this notebook, the key policy of your existing key should grant you AT LEAST the same level of access as the key policy above.

In [None]:
# create a key with the above key policy

create_new_key = False 

if create_new_key:
    ck_res = kms.create_key(
        Policy=key_policy,
        Description="a symmetric key to demonstrate KMS",
        KeyUsage="ENCRYPT_DECRYPT",                # use this key to encrypt and decrypt
        Origin='AWS_KMS',                          # created via AWS KMS
        CustomerMasterKeySpec='SYMMETRIC_DEFAULT'  # symmetric key
    )


    pp.pprint(ck_res)
    kms_key = ck_res['KeyMetadata']['KeyId']
    print("The id of the key: ")
    print(kms_key)
else:
    print("Supply an existing KMS key by setting kms_key variable to your key id")
    
    # replace None by your CMK key id
    kms_key = None
    
if kms_key is None:
    raise ValueError("Supply a valid KMS key id or create a new one")
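If you reuse an existing key, its key policy should grant you at least the access in `key_policy`. The sketch below is a rough, offline approximation of that check, not an authoritative evaluation: it only looks at `Allow` statements whose `Principal` names your ARN (or `*`) and treats wildcard actions as globs via `fnmatch`. It ignores `Deny`, `NotAction`, `Condition` blocks, delegation through the account root and IAM policies, so treat a clean result as a hint, not proof. `missing_actions` and its helpers are hypothetical names, not part of boto3.

```python
import fnmatch

def _principals(statement):
    """Return the AWS principals named in a policy statement as a list of strings."""
    principal = statement.get("Principal", {})
    if principal == "*":
        return ["*"]
    aws = principal.get("AWS", []) if isinstance(principal, dict) else []
    return [aws] if isinstance(aws, str) else list(aws)

def _actions(statement):
    """Return the statement's Action field as a list of lowercase patterns."""
    actions = statement.get("Action", [])
    if isinstance(actions, str):
        actions = [actions]
    return [a.lower() for a in actions]  # IAM action names are case-insensitive

def allowed_patterns(policy, principal_arn):
    """Collect action patterns from Allow statements that name principal_arn."""
    patterns = []
    for st in policy.get("Statement", []):
        if st.get("Effect") != "Allow":
            continue
        principals = _principals(st)
        if principal_arn in principals or "*" in principals:
            patterns.extend(_actions(st))
    return patterns

def missing_actions(reference_policy, existing_policy, principal_arn):
    """Actions the reference policy grants principal_arn that the existing one does not."""
    have = allowed_patterns(existing_policy, principal_arn)
    wanted = allowed_patterns(reference_policy, principal_arn)
    # 'kms:create*' counts as covered if an existing pattern such as
    # 'kms:*' or 'kms:create*' matches it as a glob
    return sorted({w for w in wanted
                   if not any(fnmatch.fnmatchcase(w, h) for h in have)})
```

For example, `missing_actions(json.loads(key_policy), json.loads(kms.get_key_policy(KeyId=kms_key, PolicyName='default')['Policy']), user_arn)` lists the actions your existing key's policy does not obviously grant you; an empty list suggests the key is suitable.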

In [None]:
# kms_key was set in the previous cell (new key or existing key)
print("The id of the key: ")
print(kms_key)

You can also use this KMS key to encrypt your data directly. This is not a good practice in production, but it is useful to know what you can do.

In [None]:
my_secret_message = (
    "1729 is the smallest number expressible "
    "as the sum of two cubes in two different ways"
).encode('utf-8')

# 1729 =  1^3 + 12^3 = 9^3 + 10^3 (Srinivasa Ramanujan)

# make the above secret a ciphertext
enc_res = kms.encrypt(
    KeyId=kms_key,
    Plaintext=my_secret_message)

pp.pprint(enc_res)

In [None]:
# decrypt your secret message
dec_res = kms.decrypt(
    KeyId=kms_key,
    CiphertextBlob=enc_res['CiphertextBlob']
)

print("Decrypted message:")
print(dec_res['Plaintext'].decode())
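Encrypting directly with the KMS key only works for small payloads: the KMS `Encrypt` API accepts between 1 and 4096 bytes of plaintext. A hypothetical wrapper, `encrypt_small_secret`, can check this before calling KMS and point you to envelope encryption instead. The client is passed in as a parameter, so you can call it with the `kms` client from this notebook or with a stub.

```python
KMS_ENCRYPT_MAX_BYTES = 4096  # documented Plaintext size limit of KMS Encrypt

def encrypt_small_secret(kms_client, key_id, plaintext):
    """Encrypt bytes directly under a KMS key, refusing sizes KMS would reject."""
    if not isinstance(plaintext, (bytes, bytearray)):
        raise TypeError("serialize your object to bytes first")
    if not 1 <= len(plaintext) <= KMS_ENCRYPT_MAX_BYTES:
        raise ValueError(
            f"KMS Encrypt accepts 1 to {KMS_ENCRYPT_MAX_BYTES} bytes, got {len(plaintext)}; "
            "use a data key (envelope encryption) for larger payloads"
        )
    return kms_client.encrypt(KeyId=key_id, Plaintext=bytes(plaintext))["CiphertextBlob"]
```

With the objects above: `encrypt_small_secret(kms, kms_key, my_secret_message)`.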

Note that encryption and decryption operate on **bytes**. If you want to encrypt a Python object (a list, NumPy array, pandas DataFrame, PyTorch model or string), the first step is to serialize it into bytes. One easy way to do this is the `pickle.dumps` method.
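`pickle` handles arbitrary Python objects, but for plain data such as lists, dicts, strings and numbers, JSON is a simpler serialization whose decoder only builds basic types. A minimal sketch, with `to_bytes` and `from_bytes` as illustrative names; keep in mind that JSON turns tuples into lists and dict keys into strings, and cannot represent arrays or models without extra work.

```python
import json

def to_bytes(obj):
    """Serialize JSON-compatible data to compact UTF-8 bytes, ready for encryption."""
    return json.dumps(obj, separators=(",", ":")).encode("utf-8")

def from_bytes(blob):
    """Inverse of to_bytes: decode UTF-8 JSON bytes back into Python objects."""
    return json.loads(blob.decode("utf-8"))

blob = to_bytes(list(range(10)))   # b'[0,1,2,3,4,5,6,7,8,9]'
restored = from_bytes(blob)
```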

## Client-side encryption with data key
Now let's pretend you are a data engineer and you need to move a chunk of data from location A to location B. Location A is the machine you are using now to run this notebook; location B is an S3 bucket that your data scientist buddy will use later to create a training job. You want to ensure that while the data is on its way from location A to location B, it is not intercepted and stolen by an attacker in the middle.

One solution is to generate a data key `DK` from the KMS key, use `DK` to encrypt your data at location A (client side), and save the encrypted data to the S3 bucket.

You will get a different data key each time you request one from the KMS key. The plaintext data key is intended to be **short-lived**; you should only save the **encrypted** data key for later use.

Use [kms:GenerateDataKey](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/kms.html#KMS.Client.generate_data_key) to generate a data key. 

In [23]:
key_length = 32 # 32 bytes 

data_key_res = kms.generate_data_key(
    KeyId=kms_key,
    NumberOfBytes=key_length  # your data key will be 32x8=256 bits long;
                              # Fernet (used below) splits it into a 128-bit
                              # signing key and a 128-bit AES encryption key
    )

pp.pprint(data_key_res)

{'CiphertextBlob': b'\x01\x02\x01\x00x]\x12%\xfc,\xc3\t+d\x0fmm\xb3h!n6\xd6\t'
                   b'}\x95"\x84\xa2\xe9E\x05\x8b\xc28\xe3\xe1\x01b\x99'
                   b'\xde!\xfdl\x95C\x8c\xdf\x9c\x91\x82}\xe3\x07\x00\x00'
                   b'\x00~0|\x06\t*\x86H\x86\xf7\r\x01\x07\x06\xa0o0m\x02'
                   b'\x01\x000h\x06\t*\x86H\x86\xf7\r\x01\x07\x010\x1e\x06\t`'
                   b'\x86H\x01e\x03\x04\x01.0\x11\x04\x0c]\xe2?K$\xae\xc4\xd5'
                   b'\x8b\xaa\xfe\xc2\x02\x01\x10\x80;\x1c\x8b\xecXp\x92DwaL5'
                   b'W\x0b\x0eXEn8\xfdY\xf5S\xa8*~!\xf4\x99\xa3\xd9/'
                   b'?\xdf\xb5\xcc\xef\xbe\xf4Bx\xfc\x18\xcb\tw[\x11'
                   b'\xe7\x7f\x97\xc3\xd5\xa8\xfa\xe5O\x95\xf2K',
 'KeyId': 'arn:aws:kms:us-west-2:688520471316:key/0c8582da-46d4-4be6-af1c-5f3c1d41166b',
 'Plaintext': b'dX\x8cleY\xce\\b\xae\x8d{\x91\xe7\x99\x15\x1dJ\xb9\xc0'
              b'z\x82!\xb5\xa5\xe8\xf9\xe32L\xd1\t',
 'ResponseMetadata': {'HTTPHeaders': {'cac

In [24]:
plaintext, ciphertext = data_key_res['Plaintext'], data_key_res['CiphertextBlob']

In [25]:
assert len(plaintext) == key_length

The ciphertext above is the encrypted data key; it is encrypted under the KMS key. The ciphertext is what you should keep long term. Nothing prevents you from also encrypting your plaintext data key with a different KMS key; you just need to remember which KMS key you used.

In [26]:
assert kms.decrypt(KeyId=kms_key, CiphertextBlob=ciphertext)['Plaintext'] == plaintext
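This notebook stores the ciphertext data key as its own S3 object. Another common layout, and roughly what the AWS Encryption SDK does in its message header, is to store the encrypted data key next to the encrypted data, together with the id of the KMS key that wrapped it, so whoever reads the blob knows exactly what to ask KMS to decrypt. Below is a minimal sketch using length-prefixed fields; the format (the `ENV1` marker and field order) is made up for illustration and is not compatible with the Encryption SDK.

```python
import struct

MAGIC = b"ENV1"  # hypothetical format marker

def pack_envelope(key_id, encrypted_data_key, encrypted_data):
    """Bundle KMS key id, encrypted data key and encrypted payload into one blob.

    Layout: MAGIC | len(key_id) | key_id | len(edk) | edk | payload,
    where each length is a 4-byte big-endian unsigned integer.
    """
    key_id_bytes = key_id.encode("utf-8")
    return b"".join([
        MAGIC,
        struct.pack(">I", len(key_id_bytes)), key_id_bytes,
        struct.pack(">I", len(encrypted_data_key)), encrypted_data_key,
        encrypted_data,
    ])

def _read_field(blob, offset):
    """Read a 4-byte length prefix at offset and return (field, next_offset)."""
    (n,) = struct.unpack_from(">I", blob, offset)
    start, end = offset + 4, offset + 4 + n
    if end > len(blob):
        raise ValueError("truncated envelope")
    return blob[start:end], end

def unpack_envelope(blob):
    """Split a blob produced by pack_envelope into (key_id, edk, payload)."""
    if blob[:4] != MAGIC:
        raise ValueError("not an envelope blob")
    key_id_bytes, offset = _read_field(blob, 4)
    encrypted_data_key, offset = _read_field(blob, offset)
    return key_id_bytes.decode("utf-8"), encrypted_data_key, blob[offset:]
```

With the notebook's objects you could write `pack_envelope(kms_key, ciphertext, encrypted_data)` once `encrypted_data` exists, and a reader would call `kms.decrypt(KeyId=key_id, CiphertextBlob=edk)` on the unpacked fields.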

Note that the plaintext data key is a bytes object, not a string. Because it consists of random bytes, it is generally not valid UTF-8, so decoding it into a Python string will most likely fail:

In [None]:
try:
    plaintext.decode('utf-8')
except Exception as e:
    print(e)
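When you need a printable form of binary key material (for example, to build a Fernet key), base64 is the usual choice: it maps arbitrary bytes to ASCII and back without loss. The sketch uses `os.urandom(32)` as a stand-in for the `Plaintext` field of a data key.

```python
import base64
import os

raw_key = os.urandom(32)                                     # 32 random bytes
printable = base64.urlsafe_b64encode(raw_key).decode("ascii")  # 44 ASCII characters
restored = base64.urlsafe_b64decode(printable)               # back to the original bytes
```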

There are multiple Python libraries for cryptography. We will use [cryptography](https://pypi.org/project/cryptography/).

In [None]:
!pip install cryptography

In [27]:
import base64
from cryptography.fernet import Fernet

def encrypt(data, plaintext_key):
    """Encrypt a chunk of bytes on the client side
    data: a chunk of bytes
    plaintext_key: plaintext data key
    """
    # Fernet keys must be 32 bytes, url-safe base64-encoded;
    # that's why we generated a 32-byte data key
    fernet_key = base64.urlsafe_b64encode(plaintext_key)

    f = Fernet(key=fernet_key)
    return f.encrypt(data)

def decrypt(data, ciphertext_key):
    """Decrypt a chunk of bytes on the client side
    data: encrypted binary data
    ciphertext_key: ciphertext data key
    """
    # decrypt the ciphertext data key with the KMS key
    plaintext_key = kms.decrypt(
        KeyId=kms_key, 
        CiphertextBlob=ciphertext_key)['Plaintext']
    
    # convert to a Fernet-friendly key
    fernet_key = base64.urlsafe_b64encode(plaintext_key)
    
    f = Fernet(key=fernet_key)
    return f.decrypt(data)
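What `f.encrypt` returns is a Fernet *token*. According to the Fernet specification, the 32-byte key is split into a 128-bit HMAC-SHA256 signing key (first half) and a 128-bit AES-CBC encryption key (second half), and the token is the URL-safe base64 encoding of `version (0x80) | 64-bit timestamp | 128-bit IV | ciphertext | 256-bit HMAC`. The stdlib-only sketch below splits a token into those fields, which can help when inspecting what ends up in S3. `parse_fernet_token` is a hypothetical helper: it does not verify the HMAC or decrypt anything.

```python
import base64
import struct

def parse_fernet_token(token):
    """Split a Fernet token (bytes or str) into its fields without verifying it."""
    raw = base64.urlsafe_b64decode(token)
    # smallest token: version + timestamp + IV + one AES block + HMAC
    if len(raw) < 1 + 8 + 16 + 16 + 32 or raw[0] != 0x80:
        raise ValueError("not a Fernet token")
    (timestamp,) = struct.unpack(">Q", raw[1:9])   # seconds since the Unix epoch
    return {
        "version": raw[0],
        "timestamp": timestamp,
        "iv": raw[9:25],
        "ciphertext": raw[25:-32],                  # AES-128-CBC, PKCS7-padded
        "hmac": raw[-32:],                          # HMAC-SHA256 over all preceding bytes
    }
```

For example, `parse_fernet_token(encrypted_data)["timestamp"]` tells you when the data below was encrypted.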

In [28]:
import pickle

# encrypt
data =[i for i in range(1729)]
encrypted_data = encrypt(
    pickle.dumps(data), # python object -> bytes 
    plaintext
)

Once you have finished encrypting, delete the plaintext data key as soon as possible.

In [None]:
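# drop every reference to the plaintext data key (data_key_res still holds one);
# this removes the names but does not wipe the bytes from memory
del plaintext, data_key_res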

In [None]:
# decrypt
b = decrypt(encrypted_data, ciphertext)
data_ = pickle.loads(b) # bytes -> python object

for x, y in zip(data, data_):
    assert x == y

## Save encrypted objects on S3 
Now that you understand how client-side encryption works, saving encrypted data to an S3 bucket is straightforward.

In [None]:
# create a bucket to be shared by SageMaker later

def create_bucket():
    """Create an S3 bucket that is intended to be used for short term"""
    bucket = f"sagemaker-{current_time()}"
    
    region_name = boto3.Session().region_name
    create_bucket_kwargs = {'Bucket': bucket}
    if region_name != 'us-east-1': 
        # us-east-1 is the default region for S3 buckets;
        # in any other region you must pass a LocationConstraint
        create_bucket_kwargs['CreateBucketConfiguration'] = {
            'LocationConstraint': region_name
        }
    
    boto3.client('s3').create_bucket(**create_bucket_kwargs)
    return bucket

bucket = create_bucket()

In [29]:
# put your encrypted data on the S3 bucket

s3 = boto3.client('s3')
input_prefix = "data" # will be used later as S3Prefix when calling CreateTrainingJob

put_obj_res = s3.put_object(
    Bucket=bucket, 
    Key=input_prefix +'/'+'a_chunk_of_secrets',
    Body=encrypted_data)

pp.pprint(put_obj_res)

{'ETag': '"28cc539ba7a441d436433f68c66924ff"',
 'ResponseMetadata': {'HTTPHeaders': {'content-length': '0',
                                      'date': 'Tue, 06 Apr 2021 19:14:33 GMT',
                                      'etag': '"28cc539ba7a441d436433f68c66924ff"',
                                      'server': 'AmazonS3',
                                      'x-amz-id-2': 'UiBfqVQTeqHjjEEkEdYpmsVdwJ3BvZzGVLPaGqIqxtN1Yv8sUvUlobWhrQuHMN8Ktcp8gRlb1SU=',
                                      'x-amz-request-id': '1XV2YFFFXQZ3AG82'},
                      'HTTPStatusCode': 200,
                      'HostId': 'UiBfqVQTeqHjjEEkEdYpmsVdwJ3BvZzGVLPaGqIqxtN1Yv8sUvUlobWhrQuHMN8Ktcp8gRlb1SU=',
                      'RequestId': '1XV2YFFFXQZ3AG82',
                      'RetryAttempts': 0}}


## Create a SageMaker training job with encrypted data
Now you understand how to move your data from location $A$ to location $B$ in encrypted form. Let's see how this workflow fits into a SageMaker training job. The goal is that the static assets (model and data) are encrypted before they travel over the Internet.

Let $M$ denote the customer KMS key hosted on KMS, $D$ the plaintext data key and $C$ the ciphertext data key. 

Suppose your training data is in an S3 bucket, encrypted with the data key $D$. To use the training data, the SageMaker training job needs to be able to decrypt it. Of course, you **would not** want to send $D$ (plaintext) over the Internet and hand it to a SageMaker training job. Instead, you will hand the encrypted data key (ciphertext) $C$ to the SageMaker training job.

The SageMaker training job will do the following things with $C$
- Decrypt it to plaintext using the KMS key $M$ and get $D$
- Download the encrypted data from the S3 bucket and decrypt the data with $D$
- Train the model and encrypt the model with $D$ 
- Send the encrypted model to an S3 bucket

Of course, you could use a different data key to encrypt the model.
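The four steps above can be sketched as a training entrypoint. This is **not** the notebook's `container_kms/train.py`, whose source is not reproduced here; it is a hypothetical outline that reuses the hyperparameter names defined in the hyperparameters cell further down (`key_bucket`, `ciphertext_s3_key`, `kms_key_id`, `train_channel`, `train_file`) and SageMaker's standard `/opt/ml` layout. The KMS client, S3 client and cipher factory are parameters so the flow can be exercised without AWS; in a real container they might be `boto3.client('kms')`, `boto3.client('s3')` and `lambda k: Fernet(base64.urlsafe_b64encode(k))`. The "training" step is a placeholder.

```python
import json
import os

def run_training(kms_client, s3_client, make_cipher, ml_root="/opt/ml"):
    """Hypothetical entrypoint: unwrap the data key, decrypt data, train, encrypt model."""
    config_path = os.path.join(ml_root, "input", "config", "hyperparameters.json")
    with open(config_path) as f:
        hp = json.load(f)

    # 1. Fetch the ciphertext data key C and ask KMS (key M) to unwrap it into D
    obj = s3_client.get_object(Bucket=hp["key_bucket"], Key=hp["ciphertext_s3_key"])
    data_key = kms_client.decrypt(
        KeyId=hp["kms_key_id"], CiphertextBlob=obj["Body"].read()
    )["Plaintext"]
    cipher = make_cipher(data_key)

    # 2. In File mode SageMaker has already copied the (still encrypted)
    #    channel data to /opt/ml/input/data/<channel>/; decrypt it with D
    data_path = os.path.join(ml_root, "input", "data", hp["train_channel"], hp["train_file"])
    with open(data_path, "rb") as f:
        training_bytes = cipher.decrypt(f.read())

    # 3. Placeholder "training": a real job would fit a model here
    model_bytes = json.dumps({"trained_on_bytes": len(training_bytes)}).encode("utf-8")

    # 4. Encrypt the model with D; SageMaker uploads /opt/ml/model to S3 after the job
    model_dir = os.path.join(ml_root, "model")
    os.makedirs(model_dir, exist_ok=True)
    with open(os.path.join(model_dir, "model.enc"), "wb") as f:
        f.write(cipher.encrypt(model_bytes))
```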

### How SageMaker uses your KMS key $M$

Remember that a managed service like SageMaker *assumes* an IAM role (service role) in your account and provisions resources in your AWS account based on the permissions of that service role.

When you created $M$, the key policy said that the IAM user (you) and the root user of your account are the only entities entitled to use $M$. So how does SageMaker use $M$?

Suppose your SageMaker service role is called `example-role`. There are two ways to let it use $M$:

1. Update the key policy to allow `example-role` to use $M$
2. Allow `example-role` to use $M$ via a **grant**

Quote from the [KMS docs](https://docs.aws.amazon.com/kms/latest/developerguide/concepts.html#grant)

>A grant is a policy instrument that allows AWS principals to use AWS KMS key in cryptographic operations. It also can let them view a CMK (DescribeKey) and create and manage grants. When authorizing access to a CMK, grants are considered along with key policies and IAM policies. Grants are often used for temporary permissions because you can create one, use its permissions, and delete it without changing your key policies or IAM policies. Because grants can be very specific, and are easy to create and revoke, they are often used to provide temporary permissions or more granular permissions.

We will use the grant approach in this tutorial because it involves fewer changes to your key policy. In a production environment, you should treat any change to your key policy as *a big deal*.

First, get some helper functions for creating a SageMaker service role. 

In [None]:
%%bash
file=$(ls . | grep iam_helpers.py)

if [ -f "$file" ]
then
    rm $file
fi

wget https://raw.githubusercontent.com/aws/amazon-sagemaker-examples/sagemaker-fundamentals/execution-role/iam_helpers.py

In [None]:
# set up service role for SageMaker
from iam_helpers import create_execution_role

iam = boto3.client('iam')

role_name = 'example-role'
role_arn = create_execution_role(role_name=role_name)['Role']['Arn']

iam.attach_role_policy(
    RoleName=role_name,
    PolicyArn='arn:aws:iam::aws:policy/AmazonSageMakerFullAccess'
)

Now, you will verify that `example-role` cannot use your master key $M$ at this point. The cell below is expected to raise an exception. 

In [None]:
# create a boto3 session with example-role
import time

def create_session(role_arn):
    """Create a boto3 session with an IAM role"""
    now = str(time.time()).split('.')[0]
    obj = boto3.client('sts').assume_role(
        RoleArn=role_arn,
        RoleSessionName=now
    )

    cred=obj['Credentials']
    sess = boto3.session.Session(
        aws_access_key_id=cred['AccessKeyId'],
        aws_secret_access_key=cred['SecretAccessKey'],
        aws_session_token=cred['SessionToken']
        )
    return sess

sess = create_session(role_arn)

try:
    sess.client('kms').encrypt(
        KeyId=kms_key,
        Plaintext='it will not go through'.encode('utf-8')
    )
except Exception as e:
    print(e)

In [None]:
del sess

In [30]:
grant_res = kms.create_grant(
    KeyId=kms_key,
    GranteePrincipal=role_arn, 
    Operations=['Decrypt', 'Encrypt'] # allow example-role to use M to encrypt and decrypt
)

pp.pprint(grant_res)


In [None]:
# Verify that example-role can use M
sess = create_session(role_arn)
enc_res = sess.client('kms').encrypt(
        KeyId=kms_key,
        Plaintext='it will go through this time'.encode('utf-8')
    )

pp.pprint(enc_res)

In [None]:
del sess

In [31]:
# put C to the bucket
s3.put_object(
    Bucket=bucket,
    Key='dont_look',
    Body=ciphertext
)

{'ResponseMetadata': {'RequestId': 'PF578V1B7N0WMPRG',
  'HostId': 'vqyTUQq1qqpMFX5U8/VojN2EF7F2ZZauEG8SGYPuobiU+/7090ULz62BGUBinUSjAqCCSd+seok=',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amz-id-2': 'vqyTUQq1qqpMFX5U8/VojN2EF7F2ZZauEG8SGYPuobiU+/7090ULz62BGUBinUSjAqCCSd+seok=',
   'x-amz-request-id': 'PF578V1B7N0WMPRG',
   'date': 'Tue, 06 Apr 2021 19:15:06 GMT',
   'etag': '"3428f69a48df1ce99d45ef13f11fb41e"',
   'content-length': '0',
   'server': 'AmazonS3'},
  'RetryAttempts': 0},
 'ETag': '"3428f69a48df1ce99d45ef13f11fb41e"'}

### Build a training container
You will build a training image here, as in [the notebook on basics of `CreateTrainingJob`](https://github.com/hsl89/amazon-sagemaker-examples/blob/sagemaker-fundamentals/sagemaker-fundamentals/create-training-job/create_training_job.ipynb).

In [None]:
# View the Dockerfile
!cat container_kms/Dockerfile

In [None]:
# View the entrypoint script
!pygmentize container_kms/train.py

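The entrypoint needs AWS credentials to call KMS and decrypt your ciphertext data key inside the container, so for this tutorial you will build your AWS credentials into the image. Treat this as a shortcut only: values set with `ENV` in a Dockerfile are stored in the image configuration, so anyone who can pull the image (for example from ECR) can read them with `docker inspect`. Inside a SageMaker training job, boto3 can instead pick up the credentials of the execution role, which is the role you granted access to $M$ above.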

In [14]:
cred = boto3.Session().get_credentials()
access_key, secret_key = cred.access_key, cred.secret_key
region_name = boto3.Session().region_name
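The Dockerfile turns the build args into the environment variables `ACCESS_KEY`, `SECRET_KEY` and `REGION_NAME` (steps 6 to 8 of the build output below). Since `train.py` is not shown, here is a hypothetical way an entrypoint could use them: pass explicit keys to `boto3.Session` only when both are present, and otherwise leave boto3's default credential chain in charge, which inside a SageMaker training job resolves to the execution role. That fallback keeps the container working if you later stop baking keys into the image. `session_kwargs_from_env` is an illustrative name.

```python
import os

def session_kwargs_from_env(env=None):
    """Build keyword arguments for boto3.Session from the container's environment.

    Explicit keys are used only if both ACCESS_KEY and SECRET_KEY are set;
    otherwise boto3's default credential chain decides.
    """
    env = os.environ if env is None else env
    kwargs = {}
    if env.get("ACCESS_KEY") and env.get("SECRET_KEY"):
        kwargs["aws_access_key_id"] = env["ACCESS_KEY"]
        kwargs["aws_secret_access_key"] = env["SECRET_KEY"]
    if env.get("REGION_NAME"):
        kwargs["region_name"] = env["REGION_NAME"]
    return kwargs

# inside the container (requires boto3):
# session = boto3.Session(**session_kwargs_from_env())
# kms = session.client("kms")
```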

In [35]:
%%bash -s "$access_key" "$secret_key" "$region_name"

# build the image
cd container_kms/

# tag it as example-image:latest
docker build -t example-image:latest . --build-arg ACCESS_KEY=$1 \
    --build-arg SECRET_KEY=$2 --build-arg REGION_NAME=$3

Sending build context to Docker daemon  35.84kB
Step 1/12 : FROM continuumio/miniconda3:latest
 ---> 52daacd3dd5d
Step 2/12 : ARG ACCESS_KEY
 ---> Using cache
 ---> c87cbd2d2385
Step 3/12 : ARG SECRET_KEY
 ---> Using cache
 ---> 02dd377a1a3a
Step 4/12 : ARG REGION_NAME
 ---> Using cache
 ---> 7cd6c49d9440
Step 5/12 : RUN mkdir -p /opt/ml
 ---> Using cache
 ---> b66a62e92adf
Step 6/12 : ENV ACCESS_KEY=$ACCESS_KEY
 ---> Using cache
 ---> 44380a0bbf28
Step 7/12 : ENV SECRET_KEY=$SECRET_KEY
 ---> Using cache
 ---> c03aaa930cfd
Step 8/12 : ENV REGION_NAME=$REGION_NAME
 ---> Using cache
 ---> 1dc9bd190d33
Step 9/12 : RUN pip install cryptography
 ---> Using cache
 ---> 60654dfca0aa
Step 10/12 : RUN pip install boto3
 ---> Using cache
 ---> 4af00ca6e2d2
Step 11/12 : COPY train.py /usr/bin/train
 ---> 0d9b2cdc2b81
Step 12/12 : RUN chmod +x /usr/bin/train
 ---> Running in fef21e229ed6
Removing intermediate container fef21e229ed6
 ---> 272fc7002d3a
Successfully built 272fc7002d3a
Successfully ta

## Test your container locally
You programmed the entrypoint `container_kms/train.py` so that it learns the id of the master key, as well as the S3 object key of the data key ciphertext, from the hyperparameters in `/opt/ml/input/config/hyperparameters.json`. That means you will need to [inject those hyperparameters into the container](https://github.com/aws/amazon-sagemaker-examples/blob/master/sagemaker-fundamentals/create-training-job/create_training_job_hyperparameter_injection.ipynb).

You can check out the [notebook on the basics of creating a training job](https://github.com/aws/amazon-sagemaker-examples/blob/master/sagemaker-fundamentals/create-training-job/create_training_job.ipynb) (section *Test your container*) for more details.

To recap, you will mount `container_kms/local_test/ml` (host) to `/opt/ml` (container) as a Docker volume and exchange training information with the container there.

Look at what hyperparameters we used in `container_kms/train.py`. The hyperparameters are:

In [17]:
hyperparameters = {
    "ciphertext_s3_key": "dont_look",
    "kms_key_id": kms_key,
    "train_channel": "train",
    "train_file": "a_chunk_of_secrets",
    "key_bucket": bucket
}

pp.pprint(hyperparameters)

{'ciphertext_s3_key': 'dont_look',
 'key_bucket': 'sagemaker-2021-04-01-00-17-41',
 'kms_key_id': '0c8582da-46d4-4be6-af1c-5f3c1d41166b',
 'train_channel': 'train',
 'train_file': 'a_chunk_of_secrets'}


The hyperparameters are made available to the training container at `/opt/ml/input/config/hyperparameters.json`, so for local testing you will write them to `container_kms/local_test/ml/input/config/hyperparameters.json`.

In [18]:
import json

with open("container_kms/local_test/ml/input/config/hyperparameters.json", "w") as f:
    json.dump(hyperparameters, f)

In [None]:
!cat container_kms/local_test/ml/input/config/hyperparameters.json

Also, you need to have `container_kms/local_test/ml/input/data/train/a_chunk_of_secrets` available:

In [32]:
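import os

train_dir = os.path.join('container_kms', 'local_test', 'ml', 'input', 'data', 'train')
os.makedirs(train_dir, exist_ok=True)  # create the channel directory if it is missing

with open(os.path.join(train_dir, 'a_chunk_of_secrets'), 'wb') as f:
    f.write(encrypted_data)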

In [None]:
!ls -R container_kms/local_test/ml

In [None]:
!python container_kms/local_test/test_container.py
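`test_container.py` is not reproduced here, so its exact behavior is not shown. The core of such a local test is a `docker run` that mounts `container_kms/local_test/ml` at `/opt/ml` and starts the image with the argument `train`, which is how SageMaker invokes a training container (and why the Dockerfile installs the entrypoint as `/usr/bin/train`). A hypothetical command builder:

```python
import os
import shlex

def local_train_command(image="example-image:latest",
                        host_ml_dir="container_kms/local_test/ml"):
    """Build a `docker run` argument list that mimics how SageMaker starts training."""
    return [
        "docker", "run", "--rm",
        # mount the local test tree where SageMaker would provide /opt/ml
        "-v", f"{os.path.abspath(host_ml_dir)}:/opt/ml",
        image,
        "train",  # SageMaker runs the training image with the argument `train`
    ]

print(shlex.join(local_train_command()))
# to actually run it: subprocess.run(local_train_command(), check=True)
```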

In [33]:
# create a repo in ECR called example-image
ecr = boto3.client('ecr')

try:
    # The repository might already exist
    # in your ECR
    cr_res = ecr.create_repository(
        repositoryName='example-image')
    pp.pprint(cr_res)
except Exception as e:
    print(e)

An error occurred (RepositoryAlreadyExistsException) when calling the CreateRepository operation: The repository with name 'example-image' already exists in the registry with id '688520471316'


In [37]:
%%bash
account=$(aws sts get-caller-identity --query Account | sed -e 's/^"//' -e 's/"$//')
region=$(aws configure get region)
ecr_account=${account}.dkr.ecr.${region}.amazonaws.com

# Give docker your ECR login password
aws ecr get-login-password --region $region | docker login --username AWS --password-stdin $ecr_account

# Fullname of the repo
fullname=$ecr_account/example-image:latest

#echo $fullname
# Tag the image with the fullname
docker tag example-image:latest $fullname

# Push to ECR
docker push $fullname

Login Succeeded
The push refers to repository [688520471316.dkr.ecr.us-west-2.amazonaws.com/example-image]
de7086155d0f: Preparing
449fac52d6a6: Preparing
6d038293e0e8: Preparing
19c0b14789a7: Preparing
e9d8037ebe1e: Preparing
dfef8986f350: Preparing
0553ab4c463e: Preparing
f5600c6330da: Preparing
f5600c6330da: Waiting
0553ab4c463e: Waiting
dfef8986f350: Waiting
19c0b14789a7: Layer already exists
e9d8037ebe1e: Layer already exists
6d038293e0e8: Layer already exists
dfef8986f350: Layer already exists
0553ab4c463e: Layer already exists
f5600c6330da: Layer already exists
449fac52d6a6: Pushed
de7086155d0f: Pushed
latest: digest: sha256:4db8346101a2265a7b492b08ad20f31acf85a9d00b5b08acacaef8e8eb98f13f size: 1997


https://docs.docker.com/engine/reference/commandline/login/#credentials-store



Now you have all the ingredients for a SageMaker training job.

In [42]:
# configure a training job

sm_cli = boto3.client('sagemaker')

# input
data_path = "s3://" + bucket + '/' + input_prefix

# location that SageMaker saves the model artifacts
output_prefix = 'output'
output_path = "s3://" + bucket + '/' + output_prefix

# ECR URI of your image
region = boto3.Session().region_name
account = account_id()
image_uri = "{}.dkr.ecr.{}.amazonaws.com/example-image:latest".format(account, region)

algorithm_specification = {
    'TrainingImage': image_uri,
    'TrainingInputMode': 'File',
}


input_data_config = [
    {
        'ChannelName': 'train',
        'DataSource': {
            'S3DataSource': {
                'S3DataType': 'S3Prefix',
                'S3Uri': data_path,
                'S3DataDistributionType': 'FullyReplicated',
            }
        }
    }
]


output_data_config = {
    'S3OutputPath': output_path
}

resource_config = {
    'InstanceType': 'ml.m5.large',
    'InstanceCount':1,
    'VolumeSizeInGB':10
}

stopping_condition={
    'MaxRuntimeInSeconds':120,
}


enable_network_isolation=False
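`CreateTrainingJob` expects every value in `HyperParameters` to be a string. All values in `hyperparameters` above already are, but if you add numeric or boolean settings, a hypothetical helper like the one below converts them consistently. Your entrypoint then has to parse them back, for example with `json.loads`.

```python
import json

def to_sagemaker_hyperparameters(params):
    """Convert a dict of hyperparameters into the str -> str map CreateTrainingJob expects.

    Strings pass through unchanged; other values are JSON-encoded,
    so 0.1 -> "0.1", True -> "true", [1, 2] -> "[1, 2]".
    """
    return {str(k): v if isinstance(v, str) else json.dumps(v)
            for k, v in params.items()}
```

For example, `HyperParameters=to_sagemaker_hyperparameters({**hyperparameters, "epochs": 3})`, where `epochs` is only an illustrative extra setting.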

In [43]:
# some helper functions to monitor the training job

import time

def monitor_training_job_status(training_job_name, log_freq=30):
    """Print out training job status every $log_freq seconds"""
    stopped = False
    while not stopped:
        tj_state = sm_cli.describe_training_job(
            TrainingJobName=training_job_name)
        if tj_state['TrainingJobStatus'] in ['Completed', 'Stopped', 'Failed']:
            stopped=True
        else:
            print("Training in progress")
            time.sleep(log_freq)

    if tj_state['TrainingJobStatus'] == 'Failed':
        print("Training job failed ")
        print("Failed Reason: {}".format(tj_state['FailureReason']))
    else:
        print("Training job completed")
    return


def print_logs(training_job_name):
    """Print out stdout in the container from CloudWatch"""
    logs = boto3.client('logs')

    log_res= logs.describe_log_streams(
        logGroupName='/aws/sagemaker/TrainingJobs',
        logStreamNamePrefix=training_job_name)

    for log_stream in log_res['logStreams']:
        # get one log event
        log_event = logs.get_log_events(
            logGroupName='/aws/sagemaker/TrainingJobs',
            logStreamName=log_stream['logStreamName'])

        # print out messages from the log event
        for ev in log_event['events']:
            print(ev['message'])
                    
    return

In [44]:
# name training job
training_job_name = 'example-training-job-{}'.format(current_time())

ct_res = sm_cli.create_training_job(
    TrainingJobName=training_job_name,
    AlgorithmSpecification=algorithm_specification,
    RoleArn=role_arn,
    HyperParameters=hyperparameters, # use the same hyperparameters for local testing
    InputDataConfig=input_data_config,
    OutputDataConfig=output_data_config,
    ResourceConfig=resource_config,
    StoppingCondition=stopping_condition,
    EnableNetworkIsolation=enable_network_isolation,
    EnableManagedSpotTraining=False,
)

In [45]:
monitor_training_job_status(training_job_name)

Training in progress
Training in progress
Training in progress
Training in progress
Training in progress
Training in progress
Training job completed


In [None]:
print_logs(training_job_name)

## Review and Discussion

In this notebook, you went through a typical workflow for creating a SageMaker training job with client-side encryption. You have:
* Generated a data key from a KMS key
* Used the data key to encrypt your dataset before uploading it to a SageMaker-accessible S3 bucket
* Created a SageMaker training job and passed the encrypted data key to the training job
* Decrypted the data key within the training container and used it to decrypt the training data from the S3 bucket
* Encrypted your trained model at the end of the training job


A few things to keep in mind:
* Encryption and decryption operate on bytes. There are numerous ways to serialize Python objects; you used the pickle module in this tutorial, but pickle might not be your best option in a production environment (thanks to @seebees), because unpickling untrusted data can execute arbitrary code. Check out the following content:
    * https://nedbatchelder.com/blog/202006/pickles_nine_flaws.html
    * https://lwn.net/Articles/595352/
    * https://intoli.com/blog/dangerous-pickles/
    
* There are many technical details to take care of when doing client-side encryption. The purpose of this tutorial is to show you the basic concepts involved. The [AWS Encryption SDK](https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/introduction.html) is a client-side encryption library that can take care of many of the nuts and bolts for you. You are highly encouraged to explore it.

## Clean up
You cannot delete a KMS key immediately. In this tutorial we created a key just for the sake of running the tutorial, but in a production environment deleting a key is a BIG deal: once you delete a KMS key, all data encrypted under that key becomes unrecoverable. That's why you need to exercise extreme caution when deleting a key. Check out the [section on deleting a KMS key](https://docs.aws.amazon.com/kms/latest/developerguide/deleting-keys.html) for more details.

You can schedule a key deletion in $X$ days by calling [kms:ScheduleKeyDeletion](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/kms.html#KMS.Client.schedule_key_deletion), where $X$ (the `PendingWindowInDays` parameter) must be between 7 and 30.
Once this API is called, your key's status will be **Pending Deletion**. This gives you $X$ days to find all data encrypted under this key and re-encrypt it under a different key.

During those $X$ days, if you change your mind and decide not to delete the key, you can call the [kms:CancelKeyDeletion](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/kms.html#KMS.Client.cancel_key_deletion) API.

In [None]:
def schedule_key_deletion(key_id, waiting_period):
    """Delete a key in $waiting_period days
    Args:
        key_id: id of the key to be deleted
        waiting_period: number of days to wait before key deletion
    """
    dk_res = kms.schedule_key_deletion(
        KeyId=key_id,
        PendingWindowInDays=waiting_period
    )
    
    pp.pprint(dk_res)
    return

# call schedule_key_deletion if you want to delete the key
# schedule_key_deletion(key_id=kms_key, waiting_period=7)

In [None]:
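def delete_force(bucket_name):
    """Helper function to delete a bucket and every object in it"""
    # list_objects_v2 returns at most 1,000 keys per call and omits
    # 'Contents' for an empty bucket, so paginate and default to []
    paginator = s3.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket_name):
        for obj in page.get('Contents', []):
            s3.delete_object(
                Bucket=bucket_name,
                Key=obj['Key'])
    
    return s3.delete_bucket(Bucket=bucket_name)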


def delete_ecr_repo(repo_name):
    """Helper function to delete an ECR repo"""
    ecr.delete_repository(
        repositoryName=repo_name,
        force=True
    )
    return
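The helpers above clean up S3 and ECR but leave the grant on $M$ and `example-role` in place. A sketch of removing those too, assuming the `kms`, `iam`, `kms_key`, `grant_res` and `role_name` objects created earlier. IAM refuses to delete a role that still has managed or inline policies, so those are removed first. It ignores pagination, which is fine for a role with a handful of policies, and the clients are parameters so you can dry-run it against stubs.

```python
def revoke_grant_and_delete_role(kms_client, iam_client, key_id, grant_id, role_name):
    """Revoke the grant on the KMS key, then delete the SageMaker service role."""
    # Remove the role's permission to use the KMS key
    kms_client.revoke_grant(KeyId=key_id, GrantId=grant_id)

    # Detach managed policies and delete inline policies before deleting the role
    attached = iam_client.list_attached_role_policies(RoleName=role_name)["AttachedPolicies"]
    for policy in attached:
        iam_client.detach_role_policy(RoleName=role_name, PolicyArn=policy["PolicyArn"])
    for name in iam_client.list_role_policies(RoleName=role_name)["PolicyNames"]:
        iam_client.delete_role_policy(RoleName=role_name, PolicyName=name)
    iam_client.delete_role(RoleName=role_name)

# revoke_grant_and_delete_role(kms, iam, kms_key, grant_res['GrantId'], role_name)
```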