In [None]:
import os
import time
from datetime import datetime
from IPython.display import Image
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.inputs import FileSystemInput
from sagemaker.debugger import ProfilerConfig, FrameworkProfile, DetailedProfilingConfig

In [None]:
# IAM execution role passed to the estimator as `role`.
# The ARN below is specific to one AWS account; set the SAGEMAKER_ROLE_ARN
# environment variable to run this notebook elsewhere without editing the cell.
CURR_SM_ROLE = os.environ.get(
    'SAGEMAKER_ROLE_ARN',
    'arn:aws:iam::154108359553:role/service-role/AmazonSageMaker-ExecutionRole-20210203T120788',
)

In [None]:
# ---- Experiment configuration ------------------------------------------------
# Each value below is forwarded to the estimator / training script further down.

# Backbone name, sent to the training script as 'pretrained-model-type'.
# Other value used in earlier runs: "RESNET50".
MODEL = "RESNET18"

# Mini-batch size, sent as 'batch-size'. Other values tried: 32, 128, 256.
BATCH_SIZE = 64

# Learning rate, sent as 'lr'.
LR = 1e-3

# Number of training epochs, sent as 'epochs'. Other value tried: 5.
NUM_EPOCHS = 1

# Training instance type. Other values tried: "ml.g4dn.12xlarge", "ml.p3.8xlarge".
INSTANCE_TYPE = "ml.p3.2xlarge"

# Augmentation backend, sent as 'aug'. Other values tried: "dali-cpu", "dali-gpu".
AUGMENTATION = "pytorch-cpu"

# Sent as 'aug-load'; its exact meaning is defined by the training script,
# which is not part of this notebook.
AUGMENTATION_LOAD = 5

# Set to True to request managed spot instances (also enables the checkpoint URI).
SPOT_TRAINING = False

In [None]:
# S3 bucket used for the dataset and for every training-job artefact.
BUCKET = "dali-test"

# Dataset prefix: full-size download of https://github.com/fastai/imagenette
# (1.3GB — 13,395 images for 10 classes).
# A smaller prefix, "imagenette2_subset", was used as an alternative.
train_data_s3 = f"s3://{BUCKET}/imagenette2"

In [None]:
# S3 prefixes (all inside BUCKET) used by the estimator:
#   model_ckpt_s3          -> checkpoint_s3_uri (only when spot training is on)
#   src_code_s3            -> code_location
#   training_job_output_s3 -> output_path
model_ckpt_s3 = f"s3://{BUCKET}/training_jobs_checkpoints"
src_code_s3 = f"s3://{BUCKET}/training_jobs"
training_job_output_s3 = f"s3://{BUCKET}/training_jobs_output"

In [None]:
# Framework-level profiling settings, written to /opt/ml/output/profiler/.
# NOTE(review): the profiling window length (`num_steps`) is taken from
# NUM_EPOCHS, i.e. an epoch count is used as a step count — confirm this is
# intended.
framework_profile_params = FrameworkProfile(
    local_path="/opt/ml/output/profiler/",
    start_step=1,
    num_steps=NUM_EPOCHS,
    detailed_profiling_config=DetailedProfilingConfig(
        start_step=1,
        num_steps=NUM_EPOCHS,
    ),
)

# Profiler configuration handed to the estimator: system monitoring at a
# 100 ms interval plus the framework profiling defined above.
aug_metric_config = ProfilerConfig(
    system_monitor_interval_millis=100,
    framework_profile_params=framework_profile_params,
)

In [None]:
# PyTorch estimator for the augmentation benchmark job.
# - Runs ./src/sm_augmentation_train-script.py on a single INSTANCE_TYPE instance.
# - Profiling is configured through `aug_metric_config`; the debugger hook is disabled.
# - Job output goes to `training_job_output_s3`, the packaged source to `src_code_s3`.
# - When SPOT_TRAINING is True the job uses managed spot instances and
#   checkpoints to `model_ckpt_s3`.
train_estimator = PyTorch(
    entry_point = 'sm_augmentation_train-script.py',
    source_dir =  './src',
    role = CURR_SM_ROLE,
    framework_version = '1.8.1',
    py_version = 'py3',

    profiler_config = aug_metric_config,
    debugger_hook_config = False,

    instance_count = 1,
    instance_type = INSTANCE_TYPE,

    output_path = training_job_output_s3,
    code_location = src_code_s3,

    # Forwarded to the training script as command-line hyperparameters.
    hyperparameters = {
        'epochs': NUM_EPOCHS,
        'backend': 'nccl',
        'pretrained-model-type': MODEL,
        'lr': LR,
        'batch-size': BATCH_SIZE,
        'aug': AUGMENTATION,
        'aug-load': AUGMENTATION_LOAD,
    },

    use_spot_instances = SPOT_TRAINING,
    # Managed spot training requires a maximum wait time that is >= max_run.
    # 24h is used here to match the estimator's default max_run — TODO confirm
    # against the installed SDK version. Left as None (the SDK default) for
    # on-demand training so that path is unchanged.
    max_wait = 24 * 60 * 60 if SPOT_TRAINING else None,
    checkpoint_s3_uri = model_ckpt_s3 if SPOT_TRAINING else None,
)

In [None]:
# Input channels for the training job. Both channels are File-mode S3 prefixes
# with JPEG content.
# NOTE(review): 'train' and 'val' point at the same prefix (train_data_s3);
# presumably the training script selects sub-folders itself — confirm.
train_input = sagemaker.inputs.TrainingInput(
    s3_data=train_data_s3,
    s3_data_type='S3Prefix',
    content_type='image/jpeg',
    input_mode='File',
)

val_input = sagemaker.inputs.TrainingInput(
    s3_data=train_data_s3,
    s3_data_type='S3Prefix',
    content_type='image/jpeg',
    input_mode='File',
)

# Channel name -> input definition, as passed to estimator.fit().
data_channels = {
    'train': train_input,
    'val': val_input,
}

In [None]:
# Build the training-job name from the augmentation backend and a timestamp.
# The date is included (not just the time of day) because job names must be
# unique within the account/region; a time-only suffix can collide with a job
# launched at the same second on an earlier day.
train_job_id = 'aug-{}-{}'.format(
    AUGMENTATION,
    datetime.now().strftime("%Y-%m-%d-%H-%M-%S"),
)
print ('Launching Remote Training Job: ' + str(train_job_id))

# Time the fit() call; `start` and `end` are reported in the summary cell.
start = time.time()
train_estimator.fit(inputs = data_channels, job_name = train_job_id)
end = time.time()

In [None]:
# Recap of the job: its name, the wall-clock duration of the fit() call,
# and the settings that were used.
print('Training Job: ' + str(train_job_id))

elapsed_minutes, elapsed_seconds = divmod(end - start, 60)
print('Total time {:.0f}m {:.0f}s'.format(elapsed_minutes, elapsed_seconds))

for label, setting in (
    ('Model: ', MODEL),
    ('Batch Size: ', BATCH_SIZE),
    ('Learning Rate: ', LR),
    ('Epochs: ', NUM_EPOCHS),
    ('Instance: ', INSTANCE_TYPE),
    ('Augmentation: ', AUGMENTATION),
):
    print(label, setting)

In [None]:
profile_report_s3 = 's3://' + BUCKET + '/training_jobs_output/' + train_job_id + '/rule-output'
print ('Downloading job-profile report from: '+ profile_report_s3)
!aws s3 cp $profile_report_s3 ./reports --recursive

In [None]:
'''
Check the SageMaker Studio Experiments section for detailed CPU/GPU time-series metrics
'''