## SageMaker Training for DDA

### Pre-requisites

1. This notebook contains elements that render correctly only in a Jupyter interface. Open it from an Amazon SageMaker Notebook Instance or Amazon SageMaker Studio.
1. Ensure that the IAM role you use has **AmazonSageMakerFullAccess**.
1. Some hands-on experience using **Amazon SageMaker** is recommended.
1. To use this algorithm successfully, ensure that:

   A. Either your IAM role has these three permissions and you have authority to make AWS Marketplace subscriptions in the AWS account used:

        a. aws-marketplace:ViewSubscriptions
        b. aws-marketplace:Unsubscribe
        c. aws-marketplace:Subscribe

   B. Or your AWS account has a subscription to [Computer Vision Defect Detection Model](https://aws.amazon.com/marketplace/pp/prodview-j72hhmlt6avp6).

### Subscribe to the algorithm

To subscribe to the algorithm:

1. Open the algorithm listing page: [Computer Vision Defect Detection Model](https://aws.amazon.com/marketplace/pp/prodview-j72hhmlt6avp6).
1. On the AWS Marketplace listing, choose **Continue to subscribe**.
1. On the **Subscribe to this software** page, review the EULA, pricing, and support terms, and choose **Accept Offer** if you agree with them.
1. Choose **Continue to configuration**, then choose a region. You will see a Product ARN; this is the algorithm ARN that you need to specify when training a custom ML model. Copy it and assign it to `algorithm_name` in the following cell.

In [None]:
# TODO: set this to the ARN of the subscribed SageMaker algorithm
algorithm_name = "<Customer to specify the algorithm ARN after subscription>"
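Marketplace algorithm ARNs are region-specific, so a mismatch between the ARN and the notebook's region is an easy mistake to make. The sketch below is a hypothetical helper (`check_algorithm_arn` is not part of any SDK). It assumes the ARN has the shape `arn:aws:sagemaker:<region>:<account>:algorithm/<name>` and fails fast if the placeholder was not replaced or the regions differ.

```python
import re

# Assumed shape of a SageMaker algorithm ARN:
# arn:aws:sagemaker:<region>:<12-digit account>:algorithm/<name>
ALGORITHM_ARN_RE = re.compile(
    r"^arn:aws[a-z-]*:sagemaker:(?P<region>[a-z0-9-]+):(?P<account>[0-9]{12})"
    r":algorithm/(?P<name>[A-Za-z0-9-]+)$"
)


def check_algorithm_arn(value: str, expected_region: str) -> str:
    """Validate an algorithm ARN and confirm it belongs to the expected region."""
    value = value.strip()
    match = ALGORITHM_ARN_RE.match(value)
    if match is None:
        # Also catches the unreplaced "<Customer to specify ...>" placeholder
        raise ValueError(f"Not a SageMaker algorithm ARN: {value!r}")
    if match.group("region") != expected_region:
        raise ValueError(
            f"ARN is for region {match.group('region')}, "
            f"but the session uses {expected_region}"
        )
    return value
```

Run it after the Set Up cell, once `region` is defined, for example `algorithm_name = check_algorithm_arn(algorithm_name, region)`.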

### Set Up

In [None]:
import boto3
import sagemaker
import json

In [None]:
session = sagemaker.Session()
region = session.boto_region_name
bucket = session.default_bucket()
# The project name is used as part of the S3 output path
project = "LFV-public-test"

### Prepare data (optional)
Download the manifest files and the dataset. The training jobs below read the manifests directly from S3, so this step is not required.

In [None]:
!aws s3 cp s3://lookoutvision-us-east-1-0e205be246/getting-started/manifests/train_class.manifest .

In [None]:
!aws s3 cp s3://lookoutvision-us-east-1-0e205be246/getting-started/manifests/train_segmentation.manifest .
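An augmented manifest is a JSON Lines file: one JSON object per line, one line per image. If you downloaded the manifests, a quick local check that every record carries the attribute names the training jobs request can save a failed job. `check_manifest` below is a hypothetical helper; the attribute names are the ones used later in this notebook.

```python
import json


def check_manifest(path, required_attributes):
    """Return (line_number, missing_attributes) for each record lacking a required key."""
    problems = []
    with open(path, "r", encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, such as a trailing newline
            record = json.loads(line)
            missing = [name for name in required_attributes if name not in record]
            if missing:
                problems.append((line_number, missing))
    return problems
```

For example, `check_manifest("train_class.manifest", ["source-ref", "anomaly-label-metadata", "anomaly-label"])` should return an empty list when every record is complete.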

The dataset is from the Lookout for Vision (LFV) Getting Started guide: https://docs.aws.amazon.com/lookout-for-vision/latest/developer-guide/getting-started.html

In [None]:
!wget https://docs.aws.amazon.com/lookout-for-vision/latest/developer-guide/samples/getting-started.zip

### Create IAM role

In [None]:
iam_client = boto3.client('iam')
trust_policy = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {
                "Service": "sagemaker.amazonaws.com"
            },
            "Action": "sts:AssumeRole"
        }
    ]
}

# Create the IAM role (create_role raises an error if a role with this name already exists)
role_name = "SageMakerExecutionRole"

response = iam_client.create_role(
    RoleName=role_name,
    AssumeRolePolicyDocument=json.dumps(trust_policy),
    Description="IAM role with full S3 and SageMaker access"
)

sm_role_arn = response['Role']['Arn']
print(f"Role created with ARN: {sm_role_arn}")

# Attach policies for full S3 and SageMaker access
iam_client.attach_role_policy(
    RoleName=role_name,
    PolicyArn="arn:aws:iam::aws:policy/AmazonS3FullAccess"
)

iam_client.attach_role_policy(
    RoleName=role_name,
    PolicyArn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess"
)
print("Attached S3 full access and SageMaker full access")
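Re-running the cell above fails because `create_role` raises `EntityAlreadyExistsException` when the role exists. A minimal sketch of an idempotent alternative, assuming you are happy to reuse an existing role as-is (its trust policy is not updated): `get_or_create_role` is a hypothetical helper that falls back to `get_role`. Note also that IAM is eventually consistent, so a freshly created role may take a few seconds before SageMaker can assume it.

```python
import json


def get_or_create_role(iam_client, role_name, trust_policy, description):
    """Create the role, or return the existing role's ARN if the name is taken."""
    try:
        response = iam_client.create_role(
            RoleName=role_name,
            AssumeRolePolicyDocument=json.dumps(trust_policy),
            Description=description,
        )
    except iam_client.exceptions.EntityAlreadyExistsException:
        # Re-running the notebook: reuse the role instead of failing
        response = iam_client.get_role(RoleName=role_name)
    return response["Role"]["Arn"]
```

Usage: `sm_role_arn = get_or_create_role(iam_client, role_name, trust_policy, "IAM role with full S3 and SageMaker access")`, followed by the same `attach_role_policy` calls (attaching an already-attached policy is harmless).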

## Classification Model
Start the training job for the classification model.

In [None]:
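import datetime

# Note: this rebinds `sagemaker` from the SageMaker Python SDK module to a low-level
# boto3 SageMaker client; the rest of the notebook uses this client.
sagemaker = boto3.Session(region_name=region).client("sagemaker")
classification_training_job_name = 'LFV-classification-' + datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')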
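Training and compilation job names must be unique and are limited in length and character set (our understanding: at most 63 characters, alphanumerics and hyphens, starting and ending with an alphanumeric). `make_job_name` below is a hypothetical helper that builds the same `prefix-timestamp` names as the notebook but trims long prefixes and validates the result before the API rejects it.

```python
import datetime
import re

# Assumed job-name constraint: 1-63 chars, alphanumerics and hyphens,
# starting and ending with an alphanumeric character.
JOB_NAME_RE = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$")


def make_job_name(prefix, now=None):
    """Append a timestamp to prefix, trimming the prefix so the name fits in 63 chars."""
    if now is None:
        now = datetime.datetime.now()
    suffix = now.strftime("%Y-%m-%d-%H-%M-%S")
    max_prefix = 63 - len(suffix) - 1  # leave room for the joining hyphen
    name = f"{prefix[:max_prefix].rstrip('-')}-{suffix}"
    if not JOB_NAME_RE.match(name):
        raise ValueError(f"Invalid job name: {name!r}")
    return name
```

Usage: `classification_training_job_name = make_job_name("LFV-classification")`. Like the original cell, it uses local time; pass an explicit `now` if you prefer UTC.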

To use the robust model feature for the classification model, set `ModelType` to `classification-robust`:
```
HyperParameters={
    'ModelType': 'classification-robust',
    'TestInputDataAttributeNames': 'source-ref,anomaly-label-metadata,anomaly-label',
    'TrainingInputDataAttributeNames': 'source-ref,anomaly-label-metadata,anomaly-label'
},
```
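The hyperparameter dictionaries in this notebook differ only in `ModelType`, the attribute list, and the optional `classification_logic` key, and SageMaker expects every hyperparameter value to be a string. A small builder, sketched below as a hypothetical helper, keeps the test and training attribute strings identical. The `ModelType` values and attribute names are taken from this notebook.

```python
# Attribute names and ModelType values as used in this notebook
CLASSIFICATION_ATTRIBUTES = ["source-ref", "anomaly-label-metadata", "anomaly-label"]
SEGMENTATION_ATTRIBUTES = CLASSIFICATION_ATTRIBUTES + [
    "anomaly-mask-ref-metadata",
    "anomaly-mask-ref",
]
MODEL_TYPES = {"classification", "classification-robust", "segmentation", "segmentation-robust"}


def build_hyperparameters(model_type, attributes, extra=None):
    """Build the HyperParameters dict; all values are converted to strings."""
    if model_type not in MODEL_TYPES:
        raise ValueError(f"Unknown ModelType: {model_type!r}")
    joined = ",".join(attributes)
    params = {
        "ModelType": model_type,
        "TestInputDataAttributeNames": joined,
        "TrainingInputDataAttributeNames": joined,
    }
    if extra:
        params.update({key: str(value) for key, value in extra.items()})
    return params
```

For example, `build_hyperparameters("classification-robust", CLASSIFICATION_ATTRIBUTES)` reproduces the robust dictionary shown above.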

In [None]:
response = sagemaker.create_training_job(
    TrainingJobName=classification_training_job_name,
    HyperParameters={
        'ModelType': 'classification',
        'TestInputDataAttributeNames': 'source-ref,anomaly-label-metadata,anomaly-label',
        'TrainingInputDataAttributeNames': 'source-ref,anomaly-label-metadata,anomaly-label'
    },
    AlgorithmSpecification={
        'AlgorithmName': algorithm_name,
        'TrainingInputMode': 'File',
        'EnableSageMakerMetricsTimeSeries': False
    },
    RoleArn=sm_role_arn,
    InputDataConfig=[
        {
            'ChannelName': 'training',
            'DataSource': {
                'S3DataSource': {
                    'S3DataType': 'AugmentedManifestFile',
                    'S3Uri': 's3://lookoutvision-us-east-1-0e205be246/getting-started/manifests/train_class.manifest',
                    'S3DataDistributionType': 'ShardedByS3Key',
                    'AttributeNames': [
                        'source-ref',
                        'anomaly-label-metadata',
                        'anomaly-label'
                    ],
                }
            },
            'CompressionType': 'None',
            'RecordWrapperType': 'RecordIO',
            'InputMode': 'Pipe'
        },
    ],
    OutputDataConfig={'S3OutputPath': 's3://'+bucket+'/'+project+'/output'},
    ResourceConfig={
        'InstanceType': 'ml.g4dn.2xlarge',
        'InstanceCount': 1,
        'VolumeSizeInGB': 20
    },
    StoppingCondition={
        'MaxRuntimeInSeconds': 7200
    }
)

In [None]:
import time
while True:
    training_response = sagemaker.describe_training_job(
        TrainingJobName=classification_training_job_name
    )
    if training_response['TrainingJobStatus'] == 'InProgress':
        print(".", end='')
    elif training_response['TrainingJobStatus'] == 'Completed':
        print("Completed")
        break
    elif training_response['TrainingJobStatus'] == 'Failed':
        print("Failed")
        break
    else:
        print("?", end='')
    time.sleep(60)
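The same polling loop appears several times in this notebook, for both training and compilation jobs. A generic version is sketched below. `wait_for_status` is a hypothetical helper; the `describe` callable and the `sleep` parameter are injected so the loop can be reused (and tested) without AWS. As an alternative for training jobs, boto3 also ships a built-in waiter, `sagemaker.get_waiter("training_job_completed_or_stopped")`.

```python
import time


def wait_for_status(describe, status_key, terminal_states, poll_seconds=60, sleep=time.sleep):
    """Call describe() until response[status_key] is terminal; return the final response."""
    while True:
        response = describe()
        status = response[status_key]
        if status in terminal_states:
            print(status)
            if "FailureReason" in response:
                print(response["FailureReason"])
            return response
        print(".", end="", flush=True)
        sleep(poll_seconds)
```

Usage for the classification job: `wait_for_status(lambda: sagemaker.describe_training_job(TrainingJobName=classification_training_job_name), "TrainingJobStatus", {"Completed", "Failed", "Stopped"})`. For compilation jobs, use `"CompilationJobStatus"` with `{"COMPLETED", "FAILED", "STOPPED"}`.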

******************

## Segmentation Model

Start the training job for the segmentation model.

In [None]:
segmentation_training_job_name = 'LFV-segmentation-'+datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

In [None]:
sagemaker = boto3.Session(region_name=region).client("sagemaker")
response = sagemaker.create_training_job(
    TrainingJobName=segmentation_training_job_name,
    HyperParameters={
        # To use robust model feature, change "ModelType" to "segmentation-robust"
        'ModelType': 'segmentation',
        'TestInputDataAttributeNames': 'source-ref,anomaly-label-metadata,anomaly-label,anomaly-mask-ref-metadata,anomaly-mask-ref',
        'TrainingInputDataAttributeNames': 'source-ref,anomaly-label-metadata,anomaly-label,anomaly-mask-ref-metadata,anomaly-mask-ref'
    },
    AlgorithmSpecification={
        'AlgorithmName': algorithm_name,
        'TrainingInputMode': 'File',
        'EnableSageMakerMetricsTimeSeries': False
    },
    RoleArn=sm_role_arn,
    InputDataConfig=[
        {
            'ChannelName': 'training',
            'DataSource': {
                'S3DataSource': {
                    'S3DataType': 'AugmentedManifestFile',
                    'S3Uri': 's3://lookoutvision-us-east-1-0e205be246/getting-started/manifests/train_segmentation.manifest',
                    'S3DataDistributionType': 'ShardedByS3Key',
                    'AttributeNames': [
                        'source-ref',
                        'anomaly-label-metadata',
                        'anomaly-label',
                        'anomaly-mask-ref-metadata',
                        'anomaly-mask-ref'
                    ],
                }
            },
            'CompressionType': 'None',
            'RecordWrapperType': 'RecordIO',
            'InputMode': 'Pipe'
        },
    ],
    OutputDataConfig={'S3OutputPath': 's3://'+bucket+'/'+project+'/output'},
    ResourceConfig={
        'InstanceType': 'ml.g4dn.2xlarge',
        'InstanceCount': 1,
        'VolumeSizeInGB': 20
    },
    StoppingCondition={
        'MaxRuntimeInSeconds': 7200
    }
)
print(response)
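The channel's `AttributeNames` list and the `*InputDataAttributeNames` hyperparameters must describe the same manifest fields. One way to keep them in sync is to build the channel from the same Python list used for the hyperparameters. `build_manifest_channel` is a hypothetical helper that mirrors the channel settings used in this notebook exactly.

```python
def build_manifest_channel(manifest_uri, attributes, channel_name="training"):
    """Build one InputDataConfig entry for an augmented manifest file."""
    return {
        "ChannelName": channel_name,
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "AugmentedManifestFile",
                "S3Uri": manifest_uri,
                "S3DataDistributionType": "ShardedByS3Key",
                # Must match the attribute order in the hyperparameters
                "AttributeNames": list(attributes),
            }
        },
        "CompressionType": "None",
        "RecordWrapperType": "RecordIO",
        "InputMode": "Pipe",
    }
```

Usage: pass `InputDataConfig=[build_manifest_channel(manifest_uri, attributes)]` to `create_training_job`, and build the hyperparameter strings with `",".join(attributes)` from the same `attributes` list.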

In [None]:
# Poll the training job every 60 seconds until it reaches a terminal state
while True:
    training_response = sagemaker.describe_training_job(
        TrainingJobName=segmentation_training_job_name
    )
    status = training_response['TrainingJobStatus']
    if status == 'InProgress':
        print(".", end='')
    elif status == 'Completed':
        print("Completed")
        break
    elif status in ('Failed', 'Stopped'):
        print(status)
        print(training_response.get('FailureReason', ''))
        break
    else:
        print("?", end='')
    time.sleep(60)

To use the segmentation head only, set the hyperparameters as follows:
```
HyperParameters={
    'ModelType': 'segmentation',
    'TestInputDataAttributeNames': 'source-ref,anomaly-label-metadata,anomaly-label,anomaly-mask-ref-metadata,anomaly-mask-ref',
    'TrainingInputDataAttributeNames': 'source-ref,anomaly-label-metadata,anomaly-label,anomaly-mask-ref-metadata,anomaly-mask-ref',
    'classification_logic': 'seg_head'
},
```

To enable the robust model feature for the segmentation model, set `ModelType` to `segmentation-robust`:
```
HyperParameters={
    'ModelType': 'segmentation-robust',
    'TestInputDataAttributeNames': 'source-ref,anomaly-label-metadata,anomaly-label,anomaly-mask-ref-metadata,anomaly-mask-ref',
    'TrainingInputDataAttributeNames': 'source-ref,anomaly-label-metadata,anomaly-label,anomaly-mask-ref-metadata,anomaly-mask-ref'
},
```

***********

## Compilation job - Classification

After the training job completes, we create a SageMaker compilation job. The compilation job specifies the target device on which the model will run with the DDA edge application.

Because a SageMaker compilation job expects an archive containing only a single PyTorch model file, we cannot use the training job's output artifact directly.

Prepare the model for compilation:
1. Download the trained model artifact.
2. Extract it and re-package only the `mochi.pt` file as `classification.tar.gz`.
3. Upload the new archive to S3.
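The steps above can also be done without extracting the whole archive to disk, which avoids writing untrusted paths with `extractall`. `repackage_model` is a hypothetical helper: it reads the training output `.tar.gz`, finds exactly one file whose base name matches `pattern` (by default `mochi.pt`, assumed to be the model file as in this notebook), and writes it at the root of a new `.tar.gz`.

```python
import fnmatch
import os
import tarfile


def repackage_model(src_tar_gz, dst_tar_gz, pattern="mochi.pt"):
    """Copy the single member whose base name matches pattern into a new .tar.gz."""
    with tarfile.open(src_tar_gz, "r:gz") as src:
        matches = [
            member
            for member in src.getmembers()
            if member.isfile() and fnmatch.fnmatch(os.path.basename(member.name), pattern)
        ]
        if len(matches) != 1:
            raise FileNotFoundError(
                f"Expected one member matching {pattern!r}, found {len(matches)}"
            )
        member = matches[0]
        # Store the file at the archive root under its base name
        info = tarfile.TarInfo(name=os.path.basename(member.name))
        info.size = member.size
        with tarfile.open(dst_tar_gz, "w:gz") as dst:
            dst.addfile(info, src.extractfile(member))
    return info.name
```

Usage: `repackage_model("./classification/model.tar.gz", "./classification/classification.tar.gz")`, then upload the result with `s3_client.upload_file` as in the cell below.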

In [None]:
res_class = sagemaker.describe_training_job(TrainingJobName=classification_training_job_name)
output_model_path = res_class['ModelArtifacts']['S3ModelArtifacts']
print(output_model_path)

In [None]:
from urllib.parse import urlparse

parsed_url = urlparse(output_model_path)
output_bucket = parsed_url.netloc
output_key = parsed_url.path.lstrip('/')
print(output_bucket)
print(output_key)
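The cells here split the artifact URI with `urlparse` and later build the upload key with `rsplit('/', 1)[0]`, which produces a wrong key if the object sits at the bucket root. A slightly more defensive sketch, using hypothetical helpers `split_s3_uri` and `sibling_key`:

```python
from urllib.parse import urlparse


def split_s3_uri(uri):
    """Split s3://bucket/key into (bucket, key)."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"Not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


def sibling_key(key, filename):
    """Return a key in the same S3 'folder' as key, with the last segment replaced."""
    prefix, _, _ = key.rpartition("/")
    return f"{prefix}/{filename}" if prefix else filename
```

Usage: `output_bucket, output_key = split_s3_uri(output_model_path)` and `target_key = sibling_key(output_key, "classification.tar.gz")`.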

In [None]:
import tarfile
import os
import fnmatch
from pathlib import Path

s3_client = boto3.client('s3')
path = "./classification"
Path(path).mkdir(parents=True, exist_ok=True)

# Download the .tar.gz file from S3
input_tar_gz = os.path.join(path, 'model.tar.gz')
s3_client.download_file(output_bucket, output_key, input_tar_gz)

# Extract the contents of the .tar.gz file
extract_dir = os.path.join(path, 'extracted')
Path(extract_dir).mkdir(parents=True, exist_ok=True)
with tarfile.open(input_tar_gz, 'r:gz') as tar:
    tar.extractall(path=extract_dir)
print(f"Extracted {input_tar_gz} to {extract_dir}.")

# Locate mochi.pt among the extracted files
model_file = os.path.join(extract_dir, 'mochi.pt')
if not os.path.isfile(model_file):
    raise FileNotFoundError(f"No mochi.pt file found in {extract_dir}.")

print(f"Found model file: {model_file}")

# Create a new .tar.gz file containing only mochi.pt
output_tar_gz = os.path.join(path, 'classification.tar.gz')
with tarfile.open(output_tar_gz, "w:gz") as tar:
    tar.add(model_file, arcname=os.path.basename(model_file))
print(f"Created tar.gz file {output_tar_gz} with {model_file}.")

# Upload the new .tar.gz file to S3
target_key = output_key.rsplit('/', 1)[0] + '/classification.tar.gz'
s3_client.upload_file(output_tar_gz, output_bucket, target_key)
print(f"Uploaded {output_tar_gz} to bucket {output_bucket} with key {target_key}.")

### Target Device: Jetson Xavier (JetPack 4)

In [None]:
compilation_job_name = "class-xavier-gpu-"+datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

In [None]:
compressed_model_path = f"s3://{output_bucket}/{target_key}"
print(f"Compressed model path {compressed_model_path}")

In [None]:
create_response = sagemaker.create_compilation_job(
    CompilationJobName=compilation_job_name,
    RoleArn=sm_role_arn,
    InputConfig={
        'S3Uri': compressed_model_path,
        'DataInputConfig': '{"input_shape": [1,3,672,480]}',
        'Framework': 'PYTORCH',
        'FrameworkVersion': '1.8'
    },
    OutputConfig={
        'S3OutputLocation': 's3://'+bucket+'/'+project+'/compilation_output',
        'TargetPlatform': {
            'Os': 'LINUX',
            'Arch': 'ARM64',
            'Accelerator': 'NVIDIA'
        },
        'CompilerOptions': '{"cuda-ver": "10.2","gpu-code": "sm_72","trt-ver": "8.2.1"}'
    },
    StoppingCondition={
        'MaxRuntimeInSeconds': 3600
    },
)
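The three classification compilation cells differ only in job name and target platform, and the JSON strings for `DataInputConfig` and `CompilerOptions` are written by hand. A sketch that generates the request with `json.dumps` instead, using target presets copied from the cells in this notebook. `build_compilation_request` and the preset keys (`jetson-xavier`, `x86-cpu`, `arm-cpu`) are hypothetical names.

```python
import json

# Target presets copied from this notebook's compilation cells
TARGETS = {
    "jetson-xavier": {
        "TargetPlatform": {"Os": "LINUX", "Arch": "ARM64", "Accelerator": "NVIDIA"},
        "CompilerOptions": {"cuda-ver": "10.2", "gpu-code": "sm_72", "trt-ver": "8.2.1"},
    },
    "x86-cpu": {"TargetPlatform": {"Os": "LINUX", "Arch": "X86_64"}},
    "arm-cpu": {"TargetPlatform": {"Os": "LINUX", "Arch": "ARM64"}},
}


def build_compilation_request(job_name, role_arn, model_uri, input_shape, output_uri, target):
    """Return keyword arguments for create_compilation_job."""
    preset = TARGETS[target]
    output_config = {
        "S3OutputLocation": output_uri,
        "TargetPlatform": dict(preset["TargetPlatform"]),
    }
    if "CompilerOptions" in preset:
        output_config["CompilerOptions"] = json.dumps(preset["CompilerOptions"])
    return {
        "CompilationJobName": job_name,
        "RoleArn": role_arn,
        "InputConfig": {
            "S3Uri": model_uri,
            "DataInputConfig": json.dumps({"input_shape": list(input_shape)}),
            "Framework": "PYTORCH",
            "FrameworkVersion": "1.8",
        },
        "OutputConfig": output_config,
        "StoppingCondition": {"MaxRuntimeInSeconds": 3600},
    }
```

Usage: `sagemaker.create_compilation_job(**build_compilation_request(compilation_job_name, sm_role_arn, compressed_model_path, (1, 3, 672, 480), 's3://'+bucket+'/'+project+'/compilation_output', "jetson-xavier"))`. For the segmentation model, pass the `(1, 3, 768, 576)` shape used below.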


In [None]:
# Poll the compilation job every 60 seconds until it reaches a terminal state
while True:
    compile_response = sagemaker.describe_compilation_job(
        CompilationJobName=compilation_job_name
    )
    status = compile_response['CompilationJobStatus']
    if status == 'INPROGRESS':
        print(".", end='')
    elif status == 'STARTING':
        print("*", end='')
    elif status == 'COMPLETED':
        print("Completed")
        break
    elif status in ('FAILED', 'STOPPED'):
        print(status)
        print(compile_response.get('FailureReason', ''))
        break
    else:
        print("?", end='')
    time.sleep(60)

### Target Device: x86 CPU

In [None]:
compilation_job_name = "class-x86-cpu-"+datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

In [None]:
create_response = sagemaker.create_compilation_job(
    CompilationJobName=compilation_job_name,
    RoleArn=sm_role_arn,
    InputConfig={
        'S3Uri': compressed_model_path,
        'DataInputConfig': '{"input_shape": [1,3,672,480]}',
        'Framework': 'PYTORCH',
        'FrameworkVersion': '1.8'
    },
    OutputConfig={
        'S3OutputLocation': 's3://'+bucket+'/'+project+'/compilation_output',
        'TargetPlatform': {
            'Os': 'LINUX',
            'Arch': 'X86_64'
        }
    },
    StoppingCondition={
        'MaxRuntimeInSeconds': 3600
    },
)


In [None]:
import time

# Poll the compilation job every 60 seconds until it reaches a terminal state
while True:
    compile_response = sagemaker.describe_compilation_job(
        CompilationJobName=compilation_job_name
    )
    status = compile_response['CompilationJobStatus']
    if status == 'INPROGRESS':
        print(".", end='')
    elif status == 'STARTING':
        print("*", end='')
    elif status == 'COMPLETED':
        print("Completed")
        break
    elif status in ('FAILED', 'STOPPED'):
        print(status)
        print(compile_response.get('FailureReason', ''))
        break
    else:
        print("?", end='')
    time.sleep(60)

In [None]:
compile_response
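The raw `describe_compilation_job` response is long. A short summary that surfaces where the compiled artifact was written (`ModelArtifacts.S3ModelArtifacts`) or why the job failed may be easier to read. `summarize_compilation` is a hypothetical helper operating on the response dictionary.

```python
def summarize_compilation(response):
    """Extract the job name, status, and either the artifact URI or the failure reason."""
    status = response["CompilationJobStatus"]
    summary = {"name": response.get("CompilationJobName"), "status": status}
    if status == "COMPLETED":
        summary["artifacts"] = response["ModelArtifacts"]["S3ModelArtifacts"]
    elif "FailureReason" in response:
        summary["failure_reason"] = response["FailureReason"]
    return summary
```

Usage: `summarize_compilation(compile_response)`.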

### Target Device: Arm CPU

In [None]:
compilation_arm_cpu = "class-arm-cpu-"+datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

In [None]:
create_arm_response = sagemaker.create_compilation_job(
    CompilationJobName=compilation_arm_cpu,
    RoleArn=sm_role_arn,
    InputConfig={
        'S3Uri': compressed_model_path,
        'DataInputConfig': '{"input_shape": [1,3,672,480]}',
        'Framework': 'PYTORCH',
        'FrameworkVersion': '1.8'
    },
    OutputConfig={
        'S3OutputLocation': 's3://'+bucket+'/'+project+'/compilation_output',
        'TargetPlatform': {
            'Os': 'LINUX',
            'Arch': 'ARM64'
        }
    },
    StoppingCondition={
        'MaxRuntimeInSeconds': 3600
    },
)


In [None]:
import time

# Poll the compilation job every 60 seconds until it reaches a terminal state
while True:
    create_arm_response = sagemaker.describe_compilation_job(
        CompilationJobName=compilation_arm_cpu
    )
    status = create_arm_response['CompilationJobStatus']
    if status == 'INPROGRESS':
        print(".", end='')
    elif status == 'STARTING':
        print("*", end='')
    elif status == 'COMPLETED':
        print("Completed")
        break
    elif status in ('FAILED', 'STOPPED'):
        print(status)
        print(create_arm_response.get('FailureReason', ''))
        break
    else:
        print("?", end='')
    time.sleep(60)

In [None]:
create_arm_response

## Compilation job - Segmentation

In [None]:
seg_training = segmentation_training_job_name

In [None]:
res_seg = sagemaker.describe_training_job(TrainingJobName=seg_training)
seg_output_model_path = res_seg['ModelArtifacts']['S3ModelArtifacts']
print(seg_output_model_path)

Prepare the model for compilation:
1. Download the trained model artifact.
2. Extract it and re-package only the `mochi.pt` file as `segmentation.tar.gz`.
3. Upload the new archive to S3.

In [None]:
from urllib.parse import urlparse

parsed_url = urlparse(seg_output_model_path)
output_bucket = parsed_url.netloc
output_key = parsed_url.path.lstrip('/')
print(output_bucket)
print(output_key)

In [None]:
import tarfile
import os
import fnmatch
from pathlib import Path

s3_client = boto3.client('s3')
path = "./segmentation"
Path(path).mkdir(parents=True, exist_ok=True)

# Download the .tar.gz file from S3
input_tar_gz = os.path.join(path, 'model.tar.gz')
s3_client.download_file(output_bucket, output_key, input_tar_gz)

# Extract the contents of the .tar.gz file
extract_dir = os.path.join(path, 'extracted')
Path(extract_dir).mkdir(parents=True, exist_ok=True)
with tarfile.open(input_tar_gz, 'r:gz') as tar:
    tar.extractall(path=extract_dir)
print(f"Extracted {input_tar_gz} to {extract_dir}.")

# Locate mochi.pt among the extracted files
model_file = os.path.join(extract_dir, 'mochi.pt')
if not os.path.isfile(model_file):
    raise FileNotFoundError(f"No mochi.pt file found in {extract_dir}.")

print(f"Found model file: {model_file}")

# Create a new .tar.gz file containing only mochi.pt
output_tar_gz = os.path.join(path, 'segmentation.tar.gz')
with tarfile.open(output_tar_gz, "w:gz") as tar:
    tar.add(model_file, arcname=os.path.basename(model_file))
print(f"Created tar.gz file {output_tar_gz} with {model_file}.")

# Upload the new .tar.gz file to S3
target_key = output_key.rsplit('/', 1)[0] + '/segmentation.tar.gz'
s3_client.upload_file(output_tar_gz, output_bucket, target_key)
print(f"Uploaded {output_tar_gz} to bucket {output_bucket} with key {target_key}.")

In [None]:
compilation_job = "seg-x86-cpu-"+datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

In [None]:
model_path = f"s3://{output_bucket}/{target_key}"
print(f"Compressed model path {model_path}")

In [None]:
seg_x86_response = sagemaker.create_compilation_job(
    CompilationJobName=compilation_job,
    RoleArn=sm_role_arn,
    InputConfig={
        'S3Uri': model_path,
        'DataInputConfig': '{"input_shape": [1,3,768,576]}',
        'Framework': 'PYTORCH',
        'FrameworkVersion': '1.8'
    },
    OutputConfig={
        'S3OutputLocation': 's3://'+bucket+'/'+project+'/compilation_output',
        'TargetPlatform': {
            'Os': 'LINUX',
            'Arch': 'X86_64'
        }
    },
    StoppingCondition={
        'MaxRuntimeInSeconds': 3600
    },
)


In [None]:
# Poll the compilation job every 60 seconds until it reaches a terminal state
while True:
    create_response = sagemaker.describe_compilation_job(
        CompilationJobName=compilation_job
    )
    status = create_response['CompilationJobStatus']
    if status == 'INPROGRESS':
        print(".", end='')
    elif status == 'STARTING':
        print("*", end='')
    elif status == 'COMPLETED':
        print("Completed")
        break
    elif status in ('FAILED', 'STOPPED'):
        print(status)
        print(create_response.get('FailureReason', ''))
        break
    else:
        print("?", end='')
    time.sleep(60)

In [None]:
create_response