# Use Ground Truth Labeled Data to Train an Object Detection Model in Chest X-rays

In this notebook, we demonstrate how to build a machine learning model to detect the trachea of a patient in an x-ray image using Amazon SageMaker.
We will be using 1099 NIH Chest X-ray images sampled from [this](https://www.kaggle.com/nih-chest-xrays/data) repository. While the images are originally from that source, we leveraged SageMaker Ground Truth to create bounding boxes around the trachea of the patient. We will thus be using **both** the raw images and also the manifest file where labellers labeled the trachea of the patient. 
An example of a labeled image is:

![image.png](chest_image.png)

This process could potentially be used as a template for detecting other objects in X-rays as well; however, we focus only on detecting the trachea of the patient, if it is present.
This notebook contains instructions for using the Ground Truth manifest file to understand the labeled data, then train, build and deploy the model as an endpoint in SageMaker. This notebook was created on an "ml.t3.medium" instance.

### Learning Objectives:

This workshop covers a basic introduction to SageMaker Ground Truth, understanding the labeled images, splitting the dataset for training and validation, and building, training, deploying and testing an object detection model in SageMaker. Here are the steps:

1. Perform basic preprocessing of images using the ground truth manifest file
2. Visualize the labeled images for data analysis and understanding
3. Process SageMaker Ground Truth [manifest files](https://docs.aws.amazon.com/lookout-for-vision/latest/developer-guide/create-dataset-ground-truth.html).
4. Build, train, deploy and test the SageMaker built-in object detection model


## Leverage image dataset located in an S3 bucket

### Public Dataset Used:

The chest X-ray images are stored in a publicly accessible S3 bucket. The code below copies the image data from the public S3 bucket to your own S3 bucket.

In [None]:
import sagemaker
import json
import numpy as np
import boto3

sagemaker_session = sagemaker.session.Session()
default_bucket = sagemaker_session.default_bucket()

In [None]:
# Data source
DATA_SOURCE='s3://aws-hcls-ml/public_assets_support_materials/x_ray_object_detection_data/'

BUCKET = default_bucket #optional: Change to your bucket
PREFIX='x_ray_image_data' #optional: Change to your directory/prefix

IMAGE_DATA_S3=f's3://{BUCKET}/{PREFIX}/' #location of image data in s3

!aws s3 cp $DATA_SOURCE $IMAGE_DATA_S3 --recursive --quiet
!echo "Image data copied to "$IMAGE_DATA_S3

## Introduction to SageMaker Ground Truth

Amazon SageMaker enables you to identify raw data, such as images, text files, and videos; add informative labels; and generate labeled synthetic data to create high-quality training datasets for your machine learning (ML) models. SageMaker offers two options, Amazon SageMaker Ground Truth Plus and Amazon SageMaker Ground Truth, which provide you with the flexibility to use an expert workforce to create and manage data labeling workflows on your behalf or manage your own data labeling workflows.

If you want the flexibility to build and manage your own data labeling workflows and workforce, you can use SageMaker Ground Truth. SageMaker Ground Truth is a data labeling service that makes it easy to label data and gives you the option to use human annotators through Amazon Mechanical Turk, third-party vendors, or your own private workforce.

You can also generate labeled synthetic data without manually collecting or labeling real-world data. SageMaker Ground Truth can generate hundreds of thousands of automatically labeled synthetic images on your behalf.

## Manifest file

A manifest file contains information about the images and image labels that you can use to train and test a model. Each line in an input manifest file is an entry containing an object, or a reference to an object, to label. An entry can also contain labels from previous jobs and, for some task types, additional information. In the example below, the bounding box that contains the trachea starts at pixel `[420, 15]` (left, top) and has a height of 108 pixels and a width of 143 pixels. The image itself is 1024 pixels wide and 1024 pixels high.

Here is a sample record from the manifest file:
```
{
  "source-ref": "s3://BUCKET/PREFIX/image_data/00000001_000.png",
  "xray-labeling-job-clone-clone-full-clone": {
    "annotations": [
      {
        "class_id": 0,
        "width": 143,
        "top": 15,
        "height": 108,
        "left": 420
      }
    ],
    "image_size": [
      {
        "width": 1024,
        "depth": 3,
        "height": 1024
      }
    ]
  },
  "xray-labeling-job-clone-clone-full-clone-metadata": {
    "class-map": {
      "0": "Trachea"
    },
    "objects": [
      {
        "confidence": 0
      }
    ],
    "job-name": "labeling-job/xray-labeling-job-clone-clone-full-clone",
    "human-annotated": "yes",
    "creation-date": "2020-07-22T15:04:38.513000",
    "type": "groundtruth/object-detection"
  }
}
```
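
To make the box geometry concrete, the sketch below converts one annotation from the sample record into corner coordinates. The helper name `box_corners` is ours, introduced only for illustration; it is not part of Ground Truth or of `ground_truth_utils`.

```python
def box_corners(annotation):
    """Return (xmin, ymin, xmax, ymax) in pixels for one Ground Truth annotation."""
    xmin = annotation["left"]
    ymin = annotation["top"]
    xmax = xmin + annotation["width"]   # right edge = left + width
    ymax = ymin + annotation["height"]  # bottom edge = top + height
    return xmin, ymin, xmax, ymax


# Annotation copied from the sample record above.
sample_annotation = {"class_id": 0, "width": 143, "top": 15, "height": 108, "left": 420}
print(box_corners(sample_annotation))  # (420, 15, 563, 123)
```

The origin is the top-left corner of the image, so `top` grows downward.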

In our workshop, we will be using a template manifest file from the Ground Truth labeling job that was used to label 1099 images from the source dataset. 

### Create a new manifest file from the template

The code below creates a new manifest file from the template by replacing the `BUCKET` and `PREFIX` placeholder strings with the bucket and prefix you are actually using in your account.

In [None]:
!mkdir -p data #make a data directory if it does not exist
with open('template.manifest', 'r') as template_file:
    output = [json.loads(line.strip().replace('BUCKET',BUCKET).replace('PREFIX',PREFIX)) for line in template_file.readlines()]

with open('data/output_manifest_clean.manifest','w') as output_file:
    for i in output:
        json.dump(i, output_file)
        output_file.write("\n")
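
The same substitution can be tried on an in-memory line before touching any file. This is a minimal sketch: `render_manifest` is a hypothetical helper, and the bucket name is a made-up example value.

```python
import json


def render_manifest(template_lines, bucket, prefix):
    """Substitute the BUCKET and PREFIX placeholders and parse each JSON line."""
    records = []
    for line in template_lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines instead of failing in json.loads
        line = line.replace("BUCKET", bucket).replace("PREFIX", prefix)
        records.append(json.loads(line))
    return records


# One line in the same shape as the template (assumed, shortened for the example).
template_lines = ['{"source-ref": "s3://BUCKET/PREFIX/image_data/00000001_000.png"}\n']
records = render_manifest(template_lines, "my-example-bucket", "x_ray_image_data")
print(records[0]["source-ref"])
```

Note that a plain string replace rewrites every occurrence of `BUCKET` or `PREFIX` in a line, so this approach assumes those strings appear only as placeholders.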

### Preprocessing the labeled images

1. Add bounding boxes to the images that correspond to the throat/trachea labels
2. Create labeled and non-labeled image lists for model training and validation. 

The cell below should report <b><u>851</u></b> labeled images and <b><u>248</u></b> non-labeled images.

In [None]:
from ground_truth_utils import extract_image_data, WorkerBoundingBox

output_images_with_bounding_box = extract_image_data(output)

# Iterate through the json files, creating bounding box objects.
output_with_answers = []  # only include images with the answers in them
output_images_with_answers = []
output_with_no_answers = []
output_images_with_no_answers = []

# Find the job name the manifest corresponds to.
keys = list(output[0].keys())
metakey = keys[np.where([("-metadata" in k) for k in keys])[0][0]]
jobname = metakey[:-9]

for i in range(0, len(output)):
    try:
        # images with class_id have answers in them
        x = output[i][jobname]["annotations"][0]["class_id"]
        output_with_answers.append(output[i])
        output_images_with_answers.append(output_images_with_bounding_box[i])
    except (KeyError, IndexError):
        # a missing or empty annotation list means no trachea was labeled
        output_with_no_answers.append(output[i])
        output_images_with_no_answers.append(output_images_with_bounding_box[i])

# add the box to the image
for i in range(0, len(output_with_answers)):
    the_output = output_with_answers[i]
    the_image = output_images_with_answers[i]
    answers = the_output[jobname]["annotations"]
    box = WorkerBoundingBox(image_id=i, boxdata=answers[0], worker_id="anon-worker")
    box.image = the_image
    the_image.worker_boxes.append(box)

print(
    f"Number of images with labeled trachea/throat: {len(output_images_with_answers)}"
)
print(f"Number of images without labeled trachea/throat: {len(output_with_no_answers)}")
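
The partitioning logic can also be written without exception handling. The sketch below is one possible formulation under the assumption that every record has a `<job name>-metadata` key, as in the sample record; `find_job_name` and `has_labeled_box` are hypothetical helper names, and the two records are toy data.

```python
def find_job_name(record):
    """Derive the labeling job name from the '<job>-metadata' key of a record."""
    for key in record:
        if key.endswith("-metadata"):
            return key[: -len("-metadata")]
    raise KeyError("no '-metadata' key found in record")


def has_labeled_box(record, job_name):
    """True when the first annotation of the record carries a class_id."""
    annotations = record.get(job_name, {}).get("annotations", [])
    return bool(annotations) and "class_id" in annotations[0]


# Toy records: one with a box, one with an empty annotation list.
records = [
    {
        "source-ref": "s3://BUCKET/PREFIX/a.png",
        "demo-job": {"annotations": [{"class_id": 0, "left": 420, "top": 15,
                                      "width": 143, "height": 108}]},
        "demo-job-metadata": {},
    },
    {
        "source-ref": "s3://BUCKET/PREFIX/b.png",
        "demo-job": {"annotations": []},
        "demo-job-metadata": {},
    },
]

job_name = find_job_name(records[0])
labeled = [r for r in records if has_labeled_box(r, job_name)]
unlabeled = [r for r in records if not has_labeled_box(r, job_name)]
print(job_name, len(labeled), len(unlabeled))
```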

## Inspect image labels

Download 5 random images from the labeled images list and store them in the `data` directory. Loop through the sample images and plot them for visualization. The visualization shows the label as a bounding box around the throat and trachea.

In [None]:
N_SHOW = 5
image_subset = np.random.choice(output_images_with_answers, N_SHOW, replace=False)

# Download image data
for img in image_subset:
    target_fname = f"data/{img.uri.split('/')[-1]}"
    !aws s3 cp {img.uri} {target_fname} --quiet

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt

# Keep only the human-labeled images in the subset.
human_labeled_subset = [img for img in image_subset if img.human]

# Show examples of each
for img in human_labeled_subset:
    fig, axes = plt.subplots(facecolor='white', dpi=100)
    fig.suptitle('Human-labeled examples', fontsize=14)
    img.download("data")
    img.plot_consolidated_bbs(axes)

## Creating Train, Validation and Test Datasets

Shuffle the labeled records read from the manifest file, then split them into training (90%), validation (5%) and holdout (5%) sets. For each set, create the corresponding manifest file and upload it to S3.

In [None]:
output=output_with_answers
EXP_NAME= 'sm-object-detection'

# Shuffle output in place.
np.random.shuffle(output)
    
dataset_size = len(output)
train_test_split_index = round(dataset_size*0.9)

train_data = output[:train_test_split_index]
test_data = output[train_test_split_index:]

train_test_split_index_2 = round(len(test_data)*0.5)
validation_data=test_data[:train_test_split_index_2]
hold_out=test_data[train_test_split_index_2:]
                                 

num_training_samples = 0
attribute_names = []
with open('data/train.manifest', 'w') as f:
    for line in train_data:
        f.write(json.dumps(line))
        f.write('\n')
        num_training_samples += 1
        attribute_names = [attrib for attrib in line.keys() if 'meta' not in attrib]
    
with open('data/validation.manifest', 'w') as f:
    for line in validation_data:
        f.write(json.dumps(line))
        f.write('\n')
        
with open('data/hold_out.manifest', 'w') as f:
    for line in hold_out:
        f.write(json.dumps(line))
        f.write('\n')
        
print(f'Training Data Set Size: {len(train_data)}')
print(f'Validation Data Set Size: {len(validation_data)}')
print(f'Hold Out Data Set Size: {len(hold_out)}')

!aws s3 cp data/train.manifest s3://{BUCKET}/{PREFIX}/{EXP_NAME}/train.manifest --quiet
!aws s3 cp data/validation.manifest s3://{BUCKET}/{PREFIX}/{EXP_NAME}/validation.manifest --quiet
!aws s3 cp data/hold_out.manifest s3://{BUCKET}/{PREFIX}/{EXP_NAME}/hold_out.manifest --quiet
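
The split arithmetic can be checked in isolation. The sketch below mirrors the 90/5/5 split using only the standard library; the `seed` parameter is our addition for reproducibility (the notebook shuffles without a fixed seed), and `split_dataset` is a hypothetical helper name.

```python
import random


def split_dataset(records, train_fraction=0.9, seed=0):
    """Shuffle a copy of records and split into train, validation and holdout."""
    shuffled = list(records)               # copy so the input is left untouched
    random.Random(seed).shuffle(shuffled)  # seeded shuffle (assumption, for repeatability)
    train_end = round(len(shuffled) * train_fraction)
    train, rest = shuffled[:train_end], shuffled[train_end:]
    val_end = round(len(rest) * 0.5)       # halve the remainder
    return train, rest[:val_end], rest[val_end:]


# 851 stands in for the number of labeled images.
train, validation, hold_out = split_dataset(range(851))
print(len(train), len(validation), len(hold_out))
```

With 851 labeled records, 766 go to training and the remaining 85 are divided between validation and holdout.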

## SageMaker training job setup

In this section, the SageMaker built-in object detection algorithm is used to train the model, with the training and validation datasets as its input channels.

In [None]:
from sagemaker import get_execution_role

role = get_execution_role()
sess = sagemaker.Session()
s3_resource = boto3.resource("s3")

# Using the builtin object detection algorithm in SageMaker
training_image = sagemaker.image_uris.retrieve(
    "object-detection", boto3.Session().region_name
)
augmented_manifest_filename_train = "train.manifest"
augmented_manifest_filename_validation = "validation.manifest"

# Defines paths for use in the training job request.
s3_output_path = f"s3://{BUCKET}/{PREFIX}/{EXP_NAME}/output"
s3_train_data_path = (
    f"s3://{BUCKET}/{PREFIX}/{EXP_NAME}/{augmented_manifest_filename_train}"
)
s3_validation_data_path = (
    f"s3://{BUCKET}/{PREFIX}/{EXP_NAME}/{augmented_manifest_filename_validation}"
)

In [None]:
print(f"Manifest location for training data: {s3_train_data_path}")
print(f"Manifest location for validation data: {s3_validation_data_path}")

## Specify SageMaker training job attributes

In this step, we set all the attributes required for model training. The attributes include a unique job name, role, output path, hyperparameters for the model, training and validation paths, number of training samples and instance details. 

Please note that this model can only be trained on a GPU instance. See https://docs.aws.amazon.com/sagemaker/latest/dg/object-detection.html

For object detection, SageMaker currently supports the following GPU instances for training: 
ml.p2.xlarge, ml.p2.8xlarge, ml.p2.16xlarge, ml.p3.2xlarge, ml.p3.8xlarge and ml.p3.16xlarge. 

We will be using ml.p3.2xlarge, as set in `ResourceConfig` below, to train this model. Training takes about an hour.

In [None]:
import time

# Create unique job name
job_name_prefix = EXP_NAME
timestamp = time.strftime("-%Y-%m-%d-%H-%M-%S", time.gmtime())
model_job_name = job_name_prefix + timestamp

training_params = {
    "AlgorithmSpecification": {
        # NB. training_image was retrieved in the training job setup cell above.
        "TrainingImage": training_image,
        "TrainingInputMode": "Pipe",
    },
    "RoleArn": role,
    "OutputDataConfig": {"S3OutputPath": s3_output_path},
    "ResourceConfig": {
        "InstanceCount": 1,
        "InstanceType": "ml.p3.2xlarge",  # Use a GPU backed instance
        "VolumeSizeInGB": 50,
    },
    "TrainingJobName": model_job_name,
    "HyperParameters": {  # NB. These hyperparameters are at the user's discretion and are beyond the scope of this demo.
        "base_network": "resnet-50",
        "use_pretrained_model": "1",
        "num_classes": "1",
        "mini_batch_size": "10",
        "epochs": "30",
        "learning_rate": "0.001",
        "lr_scheduler_step": "",
        "lr_scheduler_factor": "0.1",
        "optimizer": "sgd",
        "momentum": "0.9",
        "weight_decay": "0.0005",
        "overlap_threshold": "0.5",
        "nms_threshold": "0.45",
        "image_shape": "300",
        "label_width": "350",
        "num_training_samples": str(num_training_samples),
    },
    "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
    "InputDataConfig": [
        {
            "ChannelName": "train",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "AugmentedManifestFile",  # NB. Augmented Manifest
                    "S3Uri": s3_train_data_path,
                    "S3DataDistributionType": "FullyReplicated",
                    # NB. This must correspond to the JSON field names in your augmented manifest.
                    "AttributeNames": attribute_names,
                }
            },
            "ContentType": "application/x-recordio",
            "RecordWrapperType": "RecordIO",
            "CompressionType": "None",
        },
        {
            "ChannelName": "validation",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "AugmentedManifestFile",  # NB. Augmented Manifest
                    "S3Uri": s3_validation_data_path,
                    "S3DataDistributionType": "FullyReplicated",
                    # NB. This must correspond to the JSON field names in your augmented manifest.
                    "AttributeNames": attribute_names,
                }
            },
            "ContentType": "application/x-recordio",
            "RecordWrapperType": "RecordIO",
            "CompressionType": "None",
        },
    ],
}

print("Training job name: {}".format(model_job_name))
print(
    "\nInput Data Location: {}".format(
        training_params["InputDataConfig"][0]["DataSource"]["S3DataSource"]
    )
)
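
The job name has to be unique and has to satisfy the SageMaker naming rules. To our knowledge, training job names are limited to 63 characters drawn from letters, digits and hyphens; treat the pattern below as an assumption and check the `CreateTrainingJob` API reference before relying on it. `make_job_name` is a hypothetical helper.

```python
import re
import time

# Assumed naming constraint for training job names (verify against the API reference).
JOB_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$")


def make_job_name(prefix, now=None):
    """Append a UTC timestamp to prefix and validate the resulting job name."""
    if now is None:
        now = time.gmtime()
    name = prefix + time.strftime("-%Y-%m-%d-%H-%M-%S", now)
    if not JOB_NAME_PATTERN.match(name):
        raise ValueError(f"invalid training job name: {name!r}")
    return name


print(make_job_name("sm-object-detection"))
```

Validating locally gives a clearer error than a rejected API request, for example when a prefix contains an underscore.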

## Kick off the training job. 

Model training will take approximately 1 hour to complete. You can check the training job status in the SageMaker console or by using the code below.

In [None]:
sagemaker_client = boto3.client(service_name="sagemaker")
sagemaker_client.create_training_job(**training_params)

# Confirm that the training job has started
status = sagemaker_client.describe_training_job(TrainingJobName=model_job_name)[
    "TrainingJobStatus"
]
print(f"Training job name: {model_job_name}")
print("Training job current status: {}".format(status))

In [None]:
for i in range(0, 200):
    # Query the job once per iteration and reuse the response.
    job_description = sagemaker_client.describe_training_job(
        TrainingJobName=model_job_name
    )
    print("Training job status: ", job_description["TrainingJobStatus"])
    print("Secondary status: ", job_description["SecondaryStatus"])
    if job_description["TrainingJobStatus"] == "InProgress":
        time.sleep(60)
    else:
        break
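
The polling pattern above can be factored into a small helper that is easy to exercise without AWS. This is a sketch: `wait_for_job` and its parameters are hypothetical, and the status sequence is simulated.

```python
import time


def wait_for_job(get_status, poll_seconds=60, max_polls=200, sleep=time.sleep):
    """Poll get_status() until it stops returning 'InProgress' or polls run out."""
    status = get_status()
    for _ in range(max_polls - 1):
        if status != "InProgress":
            break
        sleep(poll_seconds)
        status = get_status()
    return status


# Simulated status sequence standing in for describe_training_job responses.
fake_statuses = iter(["InProgress", "InProgress", "Completed"])
final_status = wait_for_job(lambda: next(fake_statuses), sleep=lambda seconds: None)
print(final_status)  # Completed
```

In the notebook, `get_status` could be `lambda: sagemaker_client.describe_training_job(TrainingJobName=model_job_name)["TrainingJobStatus"]`. The boto3 SageMaker client also provides a `training_job_completed_or_stopped` waiter, which may be a simpler choice when intermediate statuses do not need to be printed.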

## Create a machine learning model

On successful model training, SageMaker will create a model artifact in the S3 output path. We will now use the model artifact (from training) to create a deployable model in SageMaker. 

In [None]:
info = sagemaker_client.describe_training_job(TrainingJobName=model_job_name)
model_data = info["ModelArtifacts"]["S3ModelArtifacts"]
print(model_data)  # Model artifact

primary_container = {
    "Image": training_image,
    "ModelDataUrl": model_data,
}

timestamp = time.strftime("-%Y-%m-%d-%H-%M-%S", time.gmtime())
model_name = "sm-object-detection-demo" + timestamp

# Create a model from training artifact
create_model_response = sagemaker_client.create_model(
    ModelName=model_name, ExecutionRoleArn=role, PrimaryContainer=primary_container
)

print(create_model_response["ModelArn"])

## Deploy the model as endpoint for real time predictions

Create an endpoint configuration that specifies the configuration name, the model and the instance type. We will be using an `ml.m4.xlarge` instance to host the model endpoint.

In [None]:
timestamp = time.strftime("-%Y-%m-%d-%H-%M-%S", time.gmtime())
endpoint_config_name = job_name_prefix + "-epc" + timestamp
endpoint_config_response = sagemaker_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.m4.xlarge",
            "InitialInstanceCount": 1,
            "ModelName": model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint configuration name: {}".format(endpoint_config_name))
print(
    "Endpoint configuration arn:  {}".format(
        endpoint_config_response["EndpointConfigArn"]
    )
)

Create the endpoint and print its status. This step takes approximately 8 minutes.

In [None]:
job_name_prefix = "chest-xray-demo"
endpoint_name = job_name_prefix + "-ep" + timestamp
print("Endpoint name: {}".format(endpoint_name))

endpoint_params = {
    "EndpointName": endpoint_name,
    "EndpointConfigName": endpoint_config_name,
}
endpoint_response = sagemaker_client.create_endpoint(**endpoint_params)
print("EndpointArn = {}".format(endpoint_response["EndpointArn"]))

# get the status of the endpoint
response = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
status = response["EndpointStatus"]
print("EndpointStatus = {}".format(status))

# wait until the status has changed
sagemaker_client.get_waiter("endpoint_in_service").wait(EndpointName=endpoint_name)

# print the status of the endpoint
endpoint_response = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)
status = endpoint_response["EndpointStatus"]
print("Endpoint creation ended with EndpointStatus = {}".format(status))

if status != "InService":
    raise Exception("Endpoint creation failed.")

## Model Predictions

In this section, we create two utility functions for visualizing the predictions. `make_predicted_image` creates a `BoxedImage` object (from the `ground_truth_utils` module) that holds the predicted bounding boxes, and `perform_inference` sends an image to the endpoint and plots it with the predictions returned by the SageMaker model. The `ground_truth_utils` library is from [here](https://github.com/aws/amazon-sagemaker-examples/blob/main/ground_truth_labeling_jobs/ground_truth_object_detection_tutorial/ground_truth_od.py).

To make a real-time prediction, the image is first downloaded and read as raw bytes. The bytes are then submitted to the model endpoint as a `bytearray` payload. Once the response (prediction) is received from the model endpoint, predictions with a confidence above 0.2 are plotted for visualization.

In [None]:
from ground_truth_utils import BoundingBox, BoxedImage

def make_predicted_image(predictions, img_id, uri):
    ''' Creates a BoxedImage object with predicted bounding boxes. '''

    img = BoxedImage(id=img_id, uri=uri)
    img.download(f'./{local_dir}')
    imread_img = img.imread()
    imh, imw, *_ = imread_img.shape

    # Create boxes.
    for batch_data in predictions:
        class_id, confidence, xmin, ymin, xmax, ymax = batch_data
        boxdata = {'class_id': class_id,
                   'height': (ymax-ymin)*imh,
                   'width': (xmax-xmin)*imw,
                   'left': xmin*imw,
                   'top': ymin*imh}
        box = BoundingBox(boxdata=boxdata, image_id=img.id)
        img.consolidated_boxes.append(box)

    return img

def perform_inference(uri, local_dir):
    '''Perform inference on an image'''

    realtime_uri = uri
    !aws s3 cp --quiet $realtime_uri data/the_image.png
    test_image='data/the_image.png'
    with open(test_image, 'rb') as f:
        payload = f.read()
        payload = bytearray(payload)

    #manually set endpoint if job is interrupted
    #endpoint_name='chest-xray-demo-ep-2020-07-24-01-48-42'
    sm_runtime_client = boto3.client('sagemaker-runtime')
    response = sm_runtime_client.invoke_endpoint(EndpointName=endpoint_name, 
                                       ContentType='application/x-image', 
                                       Body=payload)

    result = response['Body'].read()
    result = json.loads(result)
    predictions = [prediction for prediction in result['prediction'] if prediction[1] > .2]

    realtime_img = make_predicted_image(predictions, 'RealtimeTest', realtime_uri)

    # Plot the realtime prediction.
    fig, ax = plt.subplots()
    realtime_img.download(f'./{local_dir}')
    realtime_img.plot_consolidated_bbs(ax)
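
The conversion inside `make_predicted_image` can be studied on its own. Each prediction is a list `[class_id, confidence, xmin, ymin, xmax, ymax]` with coordinates normalized to the range 0 to 1, so they are scaled by the image size to get pixels. The sketch below uses made-up predictions, and `to_pixel_boxes` is a hypothetical helper name.

```python
def to_pixel_boxes(predictions, image_width, image_height, threshold=0.2):
    """Keep predictions above threshold and convert them to pixel-space boxes."""
    boxes = []
    for class_id, confidence, xmin, ymin, xmax, ymax in predictions:
        if confidence <= threshold:
            continue  # drop low-confidence detections
        boxes.append({
            "class_id": int(class_id),
            "confidence": confidence,
            "left": xmin * image_width,
            "top": ymin * image_height,
            "width": (xmax - xmin) * image_width,
            "height": (ymax - ymin) * image_height,
        })
    return boxes


# Made-up predictions: one confident box and one below the threshold.
sample_predictions = [
    [0.0, 0.9, 0.25, 0.5, 0.75, 1.0],
    [0.0, 0.1, 0.0, 0.0, 1.0, 1.0],
]
pixel_boxes = to_pixel_boxes(sample_predictions, image_width=1024, image_height=1024)
print(pixel_boxes)
```

The resulting dictionaries use the same `left`, `top`, `width` and `height` keys as the Ground Truth annotations, which is what lets the plotting utilities treat predictions and labels alike.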


### Predicting images from validation and hold out datasets

Pick 3 random images from the validation dataset and perform object detection

In [None]:
local_dir = "validation"
with open("data/validation.manifest", "r") as f:
    validation_uris = [json.loads(line.strip())["source-ref"] for line in f]
validation_sample = np.random.choice(validation_uris, 3, replace=False)

for i in range(0, len(validation_sample)):
    perform_inference(validation_sample[i], local_dir)

Now let's pick 3 random images from the holdout set for prediction

In [None]:
local_dir = "holdout"

with open("data/hold_out.manifest", "r") as f:
    hold_out_uris = [json.loads(line.strip())["source-ref"] for line in f]
hold_out_uris_sample = np.random.choice(hold_out_uris, 3, replace=False)

for i in range(0, len(hold_out_uris_sample)):
    perform_inference(hold_out_uris_sample[i], local_dir)
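
Visual inspection can be complemented with a number. One common way to compare a predicted box with its Ground Truth label is intersection over union (IoU). The notebook itself does not compute it, so the sketch below is an optional extra; `iou` is a hypothetical helper that takes boxes as `(xmin, ymin, xmax, ymax)` tuples.

```python
def iou(box_a, box_b):
    """Intersection over union of two (xmin, ymin, xmax, ymax) boxes."""
    # Width and height of the overlapping rectangle, clamped at zero.
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    intersection = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0.0


print(iou((0, 0, 2, 2), (1, 1, 3, 3)))  # overlap area 1, union area 7
```

A value of 1.0 means the boxes coincide and 0.0 means they do not overlap.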

### Clean up 

Delete the endpoint and its endpoint configuration to avoid ongoing charges.

In [None]:
sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
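
The model created earlier is not removed by the two calls above. A minimal sketch of a fuller cleanup follows, assuming you also want to delete the model object; `cleanup` is a hypothetical helper, and `RecordingClient` is a stand-in used only so that the example runs without AWS access.

```python
class RecordingClient:
    """Hypothetical stand-in for the boto3 SageMaker client; it only records calls."""

    def __init__(self):
        self.calls = []

    def __getattr__(self, name):
        # Any unknown attribute becomes a method that logs its keyword arguments.
        def record(**kwargs):
            self.calls.append((name, kwargs))
        return record


def cleanup(client, endpoint_name, endpoint_config_name, model_name):
    """Delete the endpoint, its configuration and the model, in that order."""
    client.delete_endpoint(EndpointName=endpoint_name)
    client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    client.delete_model(ModelName=model_name)


client = RecordingClient()
cleanup(client, "demo-ep", "demo-epc", "demo-model")
print([name for name, _ in client.calls])
```

In the notebook, the call would be `cleanup(sagemaker_client, endpoint_name, endpoint_config_name, model_name)`. Files copied to S3, such as the images, manifests and model artifacts, are not touched by this helper.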