In [None]:
!pip install -q sagemaker-experiments sagemaker-studio-image-build

In [None]:
import uuid
import json
import os, sys
from time import gmtime, strftime
import boto3
import sagemaker
from sagemaker.session import Session
from sagemaker.feature_store.feature_group import FeatureGroup
from sagemaker.processing import ScriptProcessor
from sagemaker.processing import ProcessingInput, ProcessingOutput

# SageMaker session and the values derived from it.
sagemaker_session = sagemaker.Session()
region = sagemaker_session.boto_region_name
default_bucket = sagemaker_session.default_bucket()

# Identity under which the jobs will run, and the AWS account that owns it.
role = sagemaker.get_execution_role()
account_id = boto3.client('sts').get_caller_identity().get('Account')

# Short random-looking tag appended to resource names created by this notebook.
suffix = uuid.uuid1().hex[:8] # to be used in resource names

## Build a container image from the Dockerfile

Here we need to build a custom Docker container image to handle the CT images in DICOM format. We are going to use the `sm-docker` utility described in [Using the Amazon SageMaker Studio Image Build CLI to build container images from your Studio notebooks](https://aws.amazon.com/blogs/machine-learning/using-the-amazon-sagemaker-studio-image-build-cli-to-build-container-images-from-your-studio-notebooks/). Note that you need to follow the prerequisites in the blog post to add the IAM policies to your SageMaker execution role.

Be sure to use the image and tag name defined in the `sm-docker build` command. We will be replacing the placeholders in the Step Functions state machine definition JSON file with your bucket and image URI.

In [None]:
%%sh -s "$region"
# The body of a %%sh cell is handed to the shell verbatim, so Python variables
# written as {region} are NOT expanded here. Only the magic line above is
# expanded by IPython; the region therefore arrives as positional argument $1.
# Abort on the first failing command so a broken build is not silently ignored.
set -e
cd src
sed -i "s|##REGION##|$1|g" Dockerfile
cat Dockerfile
sm-docker build . --repository medical-image-processing-smstudio:1.3

In [None]:
# URI of the container image in this account's ECR registry for the current region.
# The repository name and tag must match the ones given to `sm-docker build`.
ecr_image_uri=f'{account_id}.dkr.ecr.{region}.amazonaws.com/medical-image-processing-smstudio:1.3'
print(ecr_image_uri)

We set up an experiment in SageMaker to hold information of the processing jobs. 

Execute the next four cells to launch the processing jobs if this is the first time running the demo. There will be 162 processing jobs submitted in a for loop. We implemented a function `wait_for_instance_quota()` to check the current job count and limit the total running jobs in this experiment to `job_limit`. If the job count is at the limit, the function waits for the number of seconds specified in the `wait` argument and checks the job count again. This accounts for the account-level SageMaker quota that may otherwise cause an error in the for loop. The default service quotas for *Number of instances across processing jobs* and *number of ml.r5.large instances* are 4, as documented in the [Service Quota page](https://docs.aws.amazon.com/general/latest/gr/sagemaker.html#limits_sagemaker). If your account has a higher limit, you may change `job_limit` to a higher number to allow more simultaneous processing jobs (and therefore a faster run). You can also [request a quota increase](https://docs.aws.amazon.com/general/latest/gr/aws_service_limits.html).

In [None]:
from smexperiments.experiment import Experiment
from smexperiments.trial import Trial
from botocore.exceptions import ClientError
import time
from time import gmtime, strftime

# Launched processors, keyed by subject ID.
dict_processor = {}

experiment_name = f'nsclc-lung-cancer-survival-prediction-{suffix}'
trial_name = 'multimodal-1'

# Create the experiment, or reuse it when it already exists (e.g. the cell is re-run).
try:
    experiment = Experiment.create(
        experiment_name=experiment_name,
        description='Lung cancer survival prediction using multi-modal Non Small Cell Lung Cancer (NSCLC) Radiogenomic dataset')
except ClientError:
    experiment = Experiment.load(experiment_name)
    print(f'{experiment_name} experiment already exists! Reusing the existing experiment.')

# Create the trial for the experiment, or reuse it when it already exists, so
# that re-running this cell does not fail on the fixed trial name.
# NOTE(review): if a trial called `trial_name` was created earlier under a
# different experiment, the loaded trial will not belong to `experiment_name`
# -- confirm, and pick a different `trial_name` in that case.
try:
    exp_trial = Trial.create(experiment_name=experiment_name,
                             trial_name=trial_name)
except ClientError:
    exp_trial = Trial.load(trial_name)
    print(f'{trial_name} trial already exists! Reusing the existing trial.')

In [None]:
def launch_processing_job(subject, input_data_s3, output_data_s3, feature_store_name, offline_store_s3uri):
    """Submit one SageMaker Processing job for a single subject without waiting for it.

    The job runs ``./src/dcm2nifti_processing.py`` with ``python3`` inside the
    container image ``ecr_image_uri`` on one ``ml.m5.large`` instance, and is
    recorded under the notebook-level ``experiment_name`` / ``trial_name``.
    Besides its arguments, the function reads the notebook-level names
    ``experiment_name``, ``trial_name``, ``ecr_image_uri``, ``role`` and
    ``sagemaker_session``.

    Args:
        subject: Subject identifier; used in the job name, as the sub-folder of
            ``input_data_s3`` to read from, and passed to the script as ``--subject``.
        input_data_s3: S3 URI prefix under which each subject has its own folder.
        output_data_s3: S3 URI prefix that receives the ``CT-Nifti``, ``CT-SEG``
            and ``PNG`` outputs, each in a sub-folder of the same name.
        feature_store_name: Passed to the script as ``--feature_store_name``.
        offline_store_s3uri: Passed to the script as ``--offline_store_s3uri``.

    Returns:
        The ``ScriptProcessor`` on which ``run`` was called. Because ``run`` is
        invoked with ``wait=False``, the job may still be in progress.
    """
    exp_datetime = strftime('%Y-%m-%d-%H-%M-%S', gmtime())
    jobname = f'nsclc-{subject}-{exp_datetime}'

    experiment_config={'ExperimentName': experiment_name,
                       'TrialName': trial_name,
                       'TrialComponentDisplayName': f'ImageProcessing-{subject}'}

    inputs = [ProcessingInput(input_name = 'DICOM',
                              source=f'{input_data_s3}/{subject}',
                              destination='/opt/ml/processing/input')]

    # S3 URIs always use '/', so join them explicitly instead of relying on
    # os.path.join, whose separator depends on the operating system.
    output_root = output_data_s3.rstrip('/')
    outputs = [ProcessingOutput(output_name=i,
                                source='/opt/ml/processing/output/%s' % i,
                                destination=f'{output_root}/{i}')
               for i in ['CT-Nifti', 'CT-SEG', 'PNG']]

    arguments = ['--subject', subject,
                 '--feature_store_name', feature_store_name,
                 '--offline_store_s3uri', offline_store_s3uri]

    script_processor = ScriptProcessor(command=['python3'],
                                       image_uri=ecr_image_uri,
                                       role=role,
                                       instance_count=1,
                                       instance_type='ml.m5.large',
                                       volume_size_in_gb=5,
                                       sagemaker_session=sagemaker_session)

    # wait=False / logs=False: return immediately so several jobs can run in parallel.
    script_processor.run(code='./src/dcm2nifti_processing.py',
                         inputs=inputs,
                         outputs=outputs,
                         arguments=arguments,
                         job_name=jobname,
                         experiment_config=experiment_config,
                         wait=False,
                         logs=False)

    return script_processor

In [None]:
def wait_for_instance_quota(dict_processor, job_limit = 4, wait = 30):
    """Block until fewer than ``job_limit`` of the tracked jobs are in progress.

    The latest job of every processor in ``dict_processor`` is described and
    those whose ``ProcessingJobStatus`` is ``"InProgress"`` are counted. While
    the count is at or above ``job_limit`` the function sleeps ``wait`` seconds
    and counts again.

    Jobs seen in a terminal state (``Completed``, ``Failed``, ``Stopped``) are
    remembered by job name and not described again, so the time needed for one
    poll depends on the number of live jobs rather than on the number of jobs
    ever launched.

    Args:
        dict_processor: Mapping of any key to a processor object whose
            ``jobs[-1]`` exposes ``job_name`` and ``describe()``.
        job_limit: Maximum number of jobs allowed to be in progress.
        wait: Seconds to sleep between polls while at the limit.

    Returns:
        None.
    """
    terminal_statuses = ("Completed", "Failed", "Stopped")
    # Names of jobs already observed in a terminal state. Kept on the function
    # object so the knowledge survives between calls.
    finished_jobs = wait_for_instance_quota.__dict__.setdefault('_finished_job_names', set())

    def query_jobs(dict_processor):
        counter=0
        for processor in dict_processor.values():
            job = processor.jobs[-1]
            if job.job_name in finished_jobs:
                # A terminal status never changes; skip the API call and the pause.
                continue
            status = job.describe()['ProcessingJobStatus']
            # Pause between describe calls to stay clear of API throttling.
            time.sleep(2)
            if status == "InProgress":
                counter+=1
            elif status in terminal_statuses:
                finished_jobs.add(job.job_name)
        return counter

    job_count = query_jobs(dict_processor)
    while job_count >= job_limit:
        print(f'Current total running jobs {job_count} is reaching the limit {job_limit}. Waiting {wait} seconds...')
        time.sleep(wait)
        job_count = query_jobs(dict_processor)

    print(f'Current total running jobs {job_count} is below {job_limit}. Proceeding...')

We will be running this workflow for all the `R01` subjects.

In [None]:
# Subject IDs R01-001 through R01-162.
subject_list = ['R01-%03d' % subject_number for subject_number in range(1, 163)]

# Where the DICOM input is read from.
input_data_bucket = 'multimodal-image-data' # this is where you downloaded the DICOM files
input_data_prefix = 'nsclc_radiogenomics'
input_dicom_dir = f's3://{input_data_bucket}/{input_data_prefix}'

# Where the processing outputs are written.
output_data_bucket = default_bucket
output_data_prefix = 'nsclc_radiogenomics'
output_nifti_dir = f's3://{output_data_bucket}/{output_data_prefix}'

# Feature group name and offline store location handed to every job.
feature_store_name = 'imaging-feature-group-%s' % suffix # Please append suffix if you want to create a unique feature group repetitively
offline_store_s3uri = '%s/multimodal-imaging-featurestore' % output_nifti_dir

# Submit one job per subject, holding back whenever the concurrency limit is reached.
for subject in subject_list:
    print(subject)
    wait_for_instance_quota(dict_processor, job_limit=4, wait=30)
    dict_processor[subject] = launch_processing_job(
        subject,
        input_dicom_dir,
        output_nifti_dir,
        feature_store_name,
        offline_store_s3uri,
    )
    time.sleep(2)