# Distributed training of tissue slide images using SageMaker and Horovod

## Visualizing input SVS image

We need slideio for visualizing SVS images

In [None]:
!pip install slideio===0.5.225
!mkdir -p images

### Import libraries needed

In [None]:
import os

import boto3
import slideio
import matplotlib.pyplot as plt
import sagemaker
import numpy as np
import tensorflow as tf

from sagemaker.processing import Processor, ProcessingInput, ProcessingOutput
from sagemaker import get_execution_role
from sagemaker.tensorflow import TensorFlow
from sagemaker.tensorflow.model import TensorFlowModel
from sagemaker.session import s3_input

# Region of the default boto3 session; used when the S3 resource is created.
region = boto3.Session().region_name
# SageMaker session built from the default AWS configuration.
sagemaker_session = sagemaker.Session()
# Execution role of this notebook; handed to the training job and the model.
role = get_execution_role()

### Configurations

In [None]:
# S3 bucket that holds the downloaded tissue SVS images -- set this to your own bucket.
bucket = "tcga-data"

### TCGA SVS files

For downloading TCGA images, please refer to README file. Create a bucket in S3 and a folder `tcga-svs` within the bucket. This folder will contain all the SVS files.

Replace the value of `bucket` in the Configurations cell above with the name of the bucket you created.

In [None]:
# Download a sample SVS file from S3 and render a preview of it.
s3 = boto3.resource('s3', region_name=region)

image_file = 'TCGA-55-8514-01A-01-TS1.0e0f5cf3-96e9-4a35-aaed-4340df78d389.svs'
key = f'tcga-svs/0000b231-7c05-4e2e-8c9e-6d0675bfbb34/{image_file}'

# Make sure the download folder exists even if the setup cell was skipped.
os.makedirs('./images', exist_ok=True)
s3.Bucket(bucket).download_file(key, f'./images/{image_file}')

# Open the slide with the SVS driver and take its first scene.
slide = slideio.open_slide(path=f"./images/{image_file}", driver="SVS")
scene = slide.get_scene(0)

# Whole-slide images can be extremely large, so request a down-scaled preview
# instead of loading the full-resolution scene into memory.
# A height of 0 asks slideio to derive it from the width and the aspect ratio.
PREVIEW_WIDTH = 1024
block = scene.read_block(size=(PREVIEW_WIDTH, 0))

# Display image
plt.imshow(block, aspect="auto")
plt.show()

## Build Docker container for preprocessing SVS files into TFRecord

### Dockerfile

In [None]:
# Print the Dockerfile with syntax highlighting (pygmentize is the Pygments command-line tool).
!pygmentize Dockerfile

### Python script for preprocessing

In [None]:
# Print the preprocessing script with syntax highlighting (pygmentize is the Pygments command-line tool).
!pygmentize src/script.py

### Build container and upload it to ECR

In [None]:
from docker_utils import build_and_push_docker_image

# Short repository name handed to the helper; the value it returns is kept as
# the image reference used by the processing job.
repository_short_name = "tcga-tissue-slides-preprocess"
image_name = build_and_push_docker_image(repository_short_name)

## Launch SageMaker Processing Job


In [None]:
# Configure the processing job that runs the preprocessing container.
processor = Processor(image_uri=image_name,
                      role=role,                       # reuse the execution role resolved at the top of the notebook
                      instance_count=16,               # run the job on 16 instances
                      base_job_name='processing-base', # prefix of the generated job name
                      instance_type='ml.m5.4xlarge',
                      volume_size_in_gb=1000)

# Launch the job and block until it finishes, streaming its logs.
processor.run(
    inputs=[ProcessingInput(
        source=f's3://{bucket}/tcga-svs',            # s3 input prefix
        s3_data_type='S3Prefix',
        s3_input_mode='File',
        s3_data_distribution_type='ShardedByS3Key',  # split the data across instances
        destination='/opt/ml/processing/input')],    # local path on the container
    outputs=[ProcessingOutput(
        source='/opt/ml/processing/output',               # local output path on the container
        destination=f's3://{bucket}/tcga-svs-tfrecords/'  # output s3 location
    )],
    arguments=['10000'],  # number of tiled images per TF record for training dataset
    wait=True,
    logs=True)

### Visualize tiled images within TF records

In [None]:
%matplotlib inline

print(tf.__version__)
print(tf.executing_eagerly())

HEIGHT=512
WIDTH=512
DEPTH=3
NUM_CLASSES=3

def dataset_parser(value):
    """Decode one serialized TFRecord example into an image, a label and a slide name.

    Args:
        value: A serialized example holding the features ``label``,
            ``image_raw`` and ``slide_string``.

    Returns:
        A tuple ``(image, label, slide)``: a float32 tensor reshaped to
        ``[HEIGHT, WIDTH, DEPTH]``, the label cast to int32, and the
        ``slide_string`` feature as stored in the record.
    """
    feature_spec = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string),
        'slide_string': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(value, feature_spec)

    # The pixel bytes are interpreted as a flat float32 vector, so no further
    # dtype conversion is required before reshaping.
    flat_pixels = tf.io.decode_raw(parsed['image_raw'], tf.float32)
    flat_pixels.set_shape([DEPTH * HEIGHT * WIDTH])
    image = tf.reshape(flat_pixels, [HEIGHT, WIDTH, DEPTH])

    label = tf.cast(parsed['label'], tf.int32)
    slide = parsed['slide_string']
    return image, label, slide

# Show the first tiles stored in one TFRecord file under the test prefix.
NUM_PREVIEW_TILES = 10

key = 'tcga-svs-tfrecords/test'

# Pick the first real object under the prefix. Keys ending in '/' are folder
# placeholders without a file name, so they are skipped. Fail with a clear
# message when nothing is found instead of an opaque IndexError.
file = next(
    (obj for obj in s3.Bucket(bucket).objects.filter(Prefix=key) if not obj.key.endswith('/')),
    None,
)
if file is None:
    raise FileNotFoundError(f'No objects found under s3://{bucket}/{key}')
local_file = file.key.split('/')[-1]
s3.Bucket(bucket).download_file(file.key, f'./images/{local_file}')

raw_image_dataset = tf.data.TFRecordDataset(f'./images/{local_file}')
parsed_image_dataset = raw_image_dataset.map(dataset_parser)

# take() bounds the iteration, so no manual counter is needed.
for image_features in parsed_image_dataset.take(NUM_PREVIEW_TILES):
    image_raw = image_features[0].numpy()
    label = image_features[1].numpy()

    plt.figure()
    # Scaled for display; assumes pixel values are in the 0-255 range -- TODO confirm.
    plt.imshow(image_raw/255)
    plt.title(f'Full image: {image_features[2].numpy().decode()}, Label: {label}')

## Distributed training with Horovod and Pipe Mode input
Distributed training with Horovod can also utilize SageMaker Pipe Mode.

SageMaker Pipe Mode is a mechanism for providing S3 data to a training job via Linux fifos. Training programs can read from the fifo and get high-throughput data transfer from S3, without managing the S3 access in the program itself.
Pipe Mode is covered in more detail in the SageMaker [documentation](https://sagemaker.readthedocs.io/en/stable/frameworks/tensorflow/using_tf.html#training-with-pipe-mode-using-pipemodedataset)

In [None]:
# Cluster layout for the training job.
train_instance_type = 'ml.p3.8xlarge'
train_instance_count = 4
gpus_per_host = 4

# Total number of shards: one per GPU across all training instances.
num_of_shards = train_instance_count * gpus_per_host

# MPI settings: start one process per GPU on every host.
distributions = dict(
    mpi=dict(enabled=True, processes_per_host=gpus_per_host),
)

### Sharding the tiles 

In [None]:
# Sharding: copy the first `num_of_shards` training TFRecord files into
# numbered folders under tcga-svs-tfrecords/train_sharded/.
client = boto3.client('s3')
result = client.list_objects(Bucket=bucket, Prefix='tcga-svs-tfrecords/train/', Delimiter='/')

# Ignore folder placeholder objects (keys ending in '/'), which carry no file
# name, and tolerate a response without 'Contents' (empty prefix).
train_keys = [obj['Key'] for obj in result.get('Contents', []) if not obj['Key'].endswith('/')]
if len(train_keys) < num_of_shards:
    raise ValueError(
        f'Expected at least {num_of_shards} TFRecord files under '
        f's3://{bucket}/tcga-svs-tfrecords/train/, found {len(train_keys)}'
    )

# Each folder receives `train_instance_count` files, which gives
# `gpus_per_host` folders in total.
# NOTE(review): presumably every folder feeds one per-GPU input channel that is
# sharded across the instances -- confirm against src/train.py.
for i, source_key in enumerate(train_keys[:num_of_shards]):
    j = i // train_instance_count
    copy_source = {
        'Bucket': bucket,
        'Key': source_key
    }
    print(source_key)
    dest = f"tcga-svs-tfrecords/train_sharded/{j}/{source_key.rsplit('/', 1)[-1]}"
    print(dest)
    s3.meta.client.copy(copy_source, bucket, dest)

In [None]:
svs_tf_sharded = f's3://{bucket}/tcga-svs-tfrecords'
shuffle_config = sagemaker.session.ShuffleConfig(234)  # fixed shuffle seed
train_s3_uri_prefix = svs_tf_sharded
remote_inputs = {}

# Build one train/valid channel pair per GPU on a host, matching the numbered
# train_sharded/<idx>/ folders.
# NOTE(review): `s3_input` is the SDK v1 name; SDK v2 calls it
# sagemaker.inputs.TrainingInput -- consider migrating.
for idx in range(gpus_per_host):
    train_s3_uri = f'{train_s3_uri_prefix}/train_sharded/{idx}/'
    train_s3_input = s3_input(train_s3_uri, distribution='ShardedByS3Key', shuffle_config=shuffle_config)
    remote_inputs[f'train_{idx}'] = train_s3_input
    remote_inputs[f'valid_{idx}'] = f'{svs_tf_sharded}/valid'
remote_inputs['test'] = f'{svs_tf_sharded}/test'
remote_inputs

### Training script

In [None]:
# Print the training script with syntax highlighting (pygmentize is the Pygments command-line tool).
!pygmentize src/train.py

In [None]:
# Hyperparameters forwarded to src/train.py.
local_hyperparameters = {
    'epochs': 5,
    'batch-size': 16,
    'num-train': 160000,
    'num-val': 8192,
    'num-test': 8192,
}

# Horovod (MPI) training job that reads its inputs in Pipe mode.
estimator_dist = TensorFlow(base_job_name='svs-horovod-cloud-pipe',
                            entry_point='src/train.py',
                            role=role,
                            framework_version='2.1.0',
                            py_version='py3',
                            distribution=distributions,
                            volume_size=1024,
                            hyperparameters=local_hyperparameters,
                            output_path=f's3://{bucket}/output/',
                            # Use the configured count so the cluster size stays in
                            # sync with the shard layout derived from it.
                            instance_count=train_instance_count,
                            instance_type=train_instance_type,
                            input_mode='Pipe')

# Start training on the prepared channels and wait for completion.
estimator_dist.fit(remote_inputs, wait=True)

## Deploy the trained model

The deploy() method creates an endpoint that serves prediction requests in real-time.
The training script saves Keras artifacts. To use TensorFlow Serving for deployment, the artifacts must be saved in SavedModel format.

In [None]:
%matplotlib inline

plt.style.use('bmh')

In [None]:
# Build the model from the artifact that the training job wrote to S3.
training_job_name = estimator_dist.latest_training_job.name
model_data = f's3://{bucket}/output/{training_job_name}/output/model.tar.gz'

model = TensorFlowModel(
    model_data=model_data,
    role=role,
    framework_version='2.1.0',
)

# Deploy on a single ml.c5.xlarge instance and keep the returned predictor.
predictor = model.deploy(initial_instance_count=1, instance_type='ml.c5.xlarge')

## Make some predictions
To verify that the endpoint functions properly, we read tiles from a previously downloaded TFRecord file and request a prediction for each of them.

In [None]:
# Height, width and channel count used to reshape decoded tiles, plus the class count.
HEIGHT, WIDTH, DEPTH = 512, 512, 3
NUM_CLASSES = 3

def _dataset_parser_with_slide(value):
    """Parse a serialized TFRecord example into ``(image, label, slide)``.

    Args:
        value: A serialized example with the features ``label`` (int64),
            ``image_raw`` (bytes) and ``slide_string`` (bytes).

    Returns:
        The image as a float32 tensor of shape ``[HEIGHT, WIDTH, DEPTH]``,
        the label as an int32 tensor, and the ``slide_string`` feature
        unchanged.
    """
    schema = {
        'label': tf.io.FixedLenFeature([], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string),
        'slide_string': tf.io.FixedLenFeature([], tf.string),
    }
    fields = tf.io.parse_single_example(value, schema)

    # decode_raw already yields float32 values; fix the static length first so
    # the reshape below has a fully known shape.
    pixels = tf.io.decode_raw(fields['image_raw'], tf.float32)
    pixels.set_shape([DEPTH * HEIGHT * WIDTH])
    image = tf.reshape(pixels, [HEIGHT, WIDTH, DEPTH])

    return image, tf.cast(fields['label'], tf.int32), fields['slide_string']

### Tile-level prediction

In [None]:
# Number of tiles sent to the endpoint (indices 0 through 10).
NUM_TILES_TO_PREDICT = 11

# Pick a TFRecord file from ./images. sorted() makes the choice deterministic,
# because os.listdir() returns entries in arbitrary order.
tfrecord_files = sorted(each for each in os.listdir('./images') if each.endswith('.tfrecords'))
if not tfrecord_files:
    raise FileNotFoundError("No '.tfrecords' file found in ./images; download one before running this cell")
local_file = tfrecord_files[0]

raw_image_dataset = tf.data.TFRecordDataset(f'./images/{local_file}')  # read a TFRecord file
parsed_image_dataset = raw_image_dataset.map(_dataset_parser_with_slide)  # decode records into (image, label, slide)

pred_scores_list = []
# take() bounds the iteration, so no manual break is needed.
for i, element in enumerate(parsed_image_dataset.take(NUM_TILES_TO_PREDICT)):
    image = element[0].numpy()
    label = element[1].numpy()
    slide = element[2].numpy().decode()
    if i == 0:
        print(f'Making tile-level predictions for slide: {slide}...')

    print(f'Querying endpoint for a prediction for tile {i+1}...')
    # The endpoint expects a batch, so add a leading batch dimension of size 1.
    pred_scores = predictor.predict(np.expand_dims(image, axis=0))['predictions'][0]
    print(pred_scores)
    pred_class = np.argmax(pred_scores)
    print(pred_class)

    # Plot only when the index is a non-zero multiple of 10; the title uses the
    # same 1-based tile number as the log line above.
    if i > 0 and i % 10 == 0:
        plt.figure()
        plt.title(f'Tile {i+1} prediction: {pred_class}')
        plt.imshow(image / 255)

    pred_scores_list.append(pred_scores)
print('Done.')

### Slide-level prediction (average score over all tiles)

In [None]:
# Average the per-tile score vectors and pick the class with the highest mean score.
stacked_scores = np.vstack(pred_scores_list)
mean_pred_scores = stacked_scores.mean(axis=0)
mean_pred_class = mean_pred_scores.argmax()

print(f"Slide-level prediction for {slide}:", mean_pred_class)