In [None]:
import os
import tarfile
import time

import boto3
import numpy as np
import sagemaker
import torchvision
from matplotlib import pyplot as plt
from sagemaker.pytorch import PyTorch
from sagemaker.s3 import S3Downloader

from src.utils.utils import load_history, plot_history

# S3 key prefixes (inside the session's default bucket) for job outputs and datasets.
jobs_folder = 'jobs'
dataset_folder = 'datasets'

# Low-level boto3 handles: a session and a SageMaker service client built from it.
# NOTE(review): `sm` is not used by any visible cell — keep only if needed elsewhere.
sess = boto3.Session()
sm = sess.client('sagemaker')

# High-level SageMaker SDK session plus the IAM role the training job will assume.
# NOTE(review): get_execution_role() resolves the role of the current SageMaker
# notebook/Studio environment; outside SageMaker supply an explicit role ARN instead.
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

# Job outputs and uploaded data are written to the SDK's default bucket.
bucket_name = sagemaker_session.default_bucket()

In [None]:
# Parameters
# Distributed communication backend, forwarded to main.py via the hyperparameters
# ('gloo' is the alternative).
backend = 'smddp'
# 'local_gpu' runs the job in SageMaker local mode on this machine's GPU(s).
# For managed training use e.g. 'ml.p3.16xlarge', 'ml.p3dn.24xlarge' or 'ml.p4d.24xlarge'.
instance_type = 'local_gpu'
# Number of training instances to launch.
instance_count = 1

In [None]:
# Fetch the CIFAR-10 training split into the local 'cifar10-dataset' directory.
# torchvision skips the download when the files are already present and verified.
# Only the on-disk directory is used afterwards (it is uploaded to S3 next).
cifar10_dataset = torchvision.datasets.CIFAR10(
    root='cifar10-dataset',
    train=True,
    download=True,
)

In [None]:
# Upload the local dataset directory to the session's default bucket under
# `{dataset_folder}/cifar10-dataset`; `datasets` receives the resulting S3 URI,
# which is later passed to the job as its 'train' channel.
datasets = sagemaker_session.upload_data(
    path='cifar10-dataset',
    key_prefix=f'{dataset_folder}/cifar10-dataset',
)

In [None]:
# Unique job name stamped with the current UTC time, and the S3 prefix under which
# the job's outputs and packaged code are stored.
# NOTE(review): the name prefix says 'smddp' regardless of the `backend` setting.
job_name = f'pytorch-smddp-dist-{time.strftime("%Y-%m-%d-%H-%M-%S-%j", time.gmtime())}'
output_path = f's3://{bucket_name}/{jobs_folder}'

# Hyperparameters handed to main.py.
# NOTE(review): SageMaker delivers hyperparameters to the script as serialized
# strings, so main.py must parse types such as the boolean below — confirm.
hyperparameters = dict(
    seed=32,
    optimizer='sgd',
    momentum=0.9,
    lr=0.001,
    criterion='cross_entropy',
    pred_function='softmax',
    metric='accuracy',
    custom_function=True,  # when True, update custom_pre_process_function in src/utils/functions.py
    backend=backend,
)

In [None]:
# Launcher configuration for the training job.
# SageMaker's smdistributed data-parallel launcher is enabled only when
# backend == 'smddp'. Previously it was hard-wired on, so switching `backend`
# to e.g. 'gloo' still started the job under the smddp launcher.
# For any other backend the flag is False, so SageMaker starts the script without
# the smddp launcher.
# NOTE(review): in that case main.py is expected to set up its own process group — confirm.
distribution = {
    "smdistributed": {
        "dataparallel": {"enabled": backend == 'smddp'},
    },
}

In [None]:
# PyTorch estimator: ships this directory (entry point main.py) as the job's code.
# Model artifacts are written under `output_path/`; the packaged source goes to
# `output_path` (code_location must not carry a trailing slash).
estimator = PyTorch(
    # code
    entry_point='main.py',
    source_dir='.',
    code_location=output_path,
    output_path=output_path + '/',
    # infrastructure
    role=role,
    instance_type=instance_type,
    instance_count=instance_count,
    # framework image
    framework_version='1.11.0',
    py_version='py38',
    # training configuration
    distribution=distribution,
    hyperparameters=hyperparameters,
)

In [None]:
# Launch the training job with the uploaded data as the 'train' channel,
# blocking until the job finishes.
estimator.fit(
    inputs={'train': datasets},
    wait=True,
    job_name=job_name,
)

In [None]:
# Load the training history produced by the finished job.
# SM_MODEL_DIR is only defined inside the training container, so reading it from
# the notebook raises KeyError. SageMaker packages that directory into the job's
# model archive (estimator.model_data), so fetch and unpack the archive locally instead.
model_artifacts_uri = estimator.model_data
model_dir = os.path.join('model-artifacts', job_name)
os.makedirs(model_dir, exist_ok=True)

if model_artifacts_uri.startswith('file://'):
    # Local mode with a file:// output path: the archive is already on disk.
    model_archive = model_artifacts_uri[len('file://'):]
else:
    S3Downloader.download(model_artifacts_uri, model_dir,
                          sagemaker_session=estimator.sagemaker_session)
    model_archive = os.path.join(model_dir, os.path.basename(model_artifacts_uri))

# The archive was produced by our own training job. On Python 3.12+ consider
# extractall(..., filter='data') for extra safety.
with tarfile.open(model_archive) as archive:
    archive.extractall(model_dir)

# NOTE(review): assumes load_history takes the directory main.py saved the history
# into (it was previously called with SM_MODEL_DIR) — confirm.
history = load_history(model_dir)

In [None]:
# Plot the metrics recorded in the loaded training history.
plot_history(history)