This notebook was developed using the `Python 3 (Data Science)` kernel on an `ml.t3.medium` instance.

In [None]:
!pip install -q sagemaker-experiments

In [None]:
import sagemaker
import json
import boto3

# IAM execution role of the current environment; passed to the estimator later in the notebook.
role = sagemaker.get_execution_role()
# SageMaker session from which the region and the default bucket are read.
sess = sagemaker.Session()
# Region name of the session (not referenced again in the visible cells).
region = sess.boto_region_name
# Default S3 bucket of the session; later cells build the output and code locations from it.
bucket = sess.default_bucket()
# Key prefix placed under the bucket in the S3 paths built later in the notebook.
prefix = 'sagemaker-studio-book/chapter09'

In [None]:
%%writefile code/smmp_pytorch_mnist.py
import argparse
import math
import os
import random
import time

import smdistributed.modelparallel.torch as smp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import autocast
from torch.optim.lr_scheduler import StepLR
from torchnet.dataset import SplitDataset
from torchvision import datasets, transforms
import numpy as np

# Ask cuDNN for deterministic behaviour so that repeated runs report the same
# losses. Both settings can be dropped if they cost too much performance; see
# https://pytorch.org/docs/stable/notes/randomness.html#cudnn
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Seed every random number generator used by the script with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(1)

def parse_args():
    """Parse the hyperparameters passed to the training script.

    The hyperparameters sent by the client arrive as command-line arguments.
    Unrecognised arguments are tolerated rather than rejected.

    Returns:
        A ``(namespace, unknown)`` tuple as produced by
        ``ArgumentParser.parse_known_args``: the parsed arguments and the list
        of argument strings that were not recognised.
    """
    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--learning_rate', type=float, default=4)
    # The worker count is handed to DataLoader, which needs an integer;
    # parsing it as float would turn '--num_workers 2' into 2.0.
    parser.add_argument('--num_workers', type=int, default=1)
    # model directory /opt/ml/model default set by SageMaker
    parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR'))

    return parser.parse_known_args()


class Net1(nn.Module):
    """Convolutional first half of the MNIST classifier.

    Applies two 3x3 convolutions (1 -> 32 -> 64 channels), each followed by a
    ReLU, then a 2x2 max pooling, and flattens everything but the batch
    dimension.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)

    def forward(self, x):
        """Return the flattened feature vector for each sample in ``x``."""
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        pooled = F.max_pool2d(features, kernel_size=2)
        return torch.flatten(pooled, start_dim=1)


class Net2(nn.Module):
    """Fully connected second half of the MNIST classifier.

    Maps a 9216-element feature vector through a 128-unit hidden layer with
    ReLU to 10 outputs, returned as log-probabilities.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return the log-softmax over the 10 outputs for each row of ``x``."""
        hidden = F.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)


class GroupedNet(nn.Module):
    """Full classifier: ``Net1`` followed by ``Net2``.

    The two halves are kept as separate submodules named ``net1`` and ``net2``.
    """

    def __init__(self):
        super().__init__()
        self.net1 = Net1()
        self.net2 = Net2()

    def forward(self, x):
        """Feed ``x`` through ``net1`` and pass the result to ``net2``."""
        return self.net2(self.net1(x))

@smp.step
def train_step(model, data, target):
    """Run the forward and backward pass for one training batch.

    The backward pass goes through ``model.backward`` rather than
    ``loss.backward``.

    Returns:
        The model output and the mean negative log-likelihood loss.
    """
    predictions = model(data)
    nll = F.nll_loss(predictions, target, reduction='mean')
    model.backward(nll)

    return predictions, nll


def train(model, device, train_loader, optimizer, epoch):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Model to train; it is handed to ``train_step`` for the
            forward and backward pass.
        device: Device the input tensors are moved to.
        train_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer that is zeroed before and stepped after each batch.
        epoch: Epoch number, used only in the progress message.
    """
    progress_template = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'

    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        # Move input tensors to the GPU ID used by the current process,
        # based on the set_device call.
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # The second value returned by train_step is a StepOutput object.
        _, microbatch_losses = train_step(model, inputs, labels)

        # Average the loss across microbatches.
        mean_loss = microbatch_losses.reduce_mean()

        optimizer.step()

        # Only rank 0 reports progress, and only on every tenth batch.
        if smp.rank() != 0 or step % 10 != 0:
            continue

        samples_seen = step * len(inputs)
        total_samples = len(train_loader.dataset)
        percent_done = 100.0 * step / len(train_loader)
        print(
            progress_template.format(
                epoch, samples_seen, total_samples, percent_done, mean_loss.item()
            )
        )


# Define smp.step for model evaluation.
@smp.step
def test_step(model, data, target):
    """Evaluate one batch and return plain Python numbers.

    Returns:
        The summed negative log-likelihood loss and the number of samples
        whose highest-scoring class equals the target.
    """
    scores = model(data)
    summed_loss = F.nll_loss(scores, target, reduction='sum').item()
    # Index of the highest log-probability for every sample.
    predicted = scores.argmax(dim=1, keepdim=True)
    hits = predicted.eq(target.view_as(predicted)).sum().item()

    return summed_loss, hits


def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and report loss and accuracy.

    The averages are taken over the samples that were actually evaluated.
    Dividing by ``len(test_loader.dataset)`` would also count samples the
    loader never yields (the loaders in this script use ``drop_last=True``)
    and so understate both accuracy and average loss.

    Args:
        model: Model to evaluate; it is handed to ``test_step``.
        device: Device the input tensors are moved to.
        test_loader: Iterable yielding ``(data, target)`` batches.

    Returns:
        The average loss per evaluated sample (0 if nothing was evaluated).
    """
    model.eval()
    test_loss = 0
    correct = 0
    num_samples = 0
    with torch.no_grad():
        for data, target in test_loader:
            # SM Distributed: Moves input tensors to the GPU ID used by the current process
            # based on the set_device call.
            data, target = data.to(device), target.to(device)

            # Since test_step returns scalars instead of tensors,
            # test_step decorated with smp.step will return lists instead of StepOutput objects.
            loss_batch, correct_batch = test_step(model, data, target)
            test_loss += sum(loss_batch)
            correct += sum(correct_batch)
            num_samples += len(data)

    # Guard against an empty loader so the averages below never divide by zero.
    denominator = max(num_samples, 1)
    test_loss /= denominator
    if smp.mp_rank() == 0:
        print(
            '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                test_loss,
                correct,
                num_samples,
                100.0 * correct / denominator,
            )
        )
    return test_loss


if __name__ == '__main__':
    # Script entry point: parse the hyperparameters, initialise smp, load MNIST,
    # train and evaluate for the requested number of epochs, then save a checkpoint.
    args, _ = parse_args()
    
    if not torch.cuda.is_available():
        raise ValueError('The script requires CUDA support, but CUDA not available')

    # Initialise the SageMaker model parallel library.
    smp.init()

    # Set the device to the device on the current process (smp.local_rank).
    torch.cuda.set_device(smp.local_rank())
    device = torch.device('cuda')

    # Convert images to tensors and normalise them with a fixed mean and std.
    transform = transforms.Compose(
        [transforms.ToTensor(), 
         transforms.Normalize((0.1307,), (0.3081,))]
    )

    # Download only on the first process of each instance.
    # Without this guard the files get corrupted by multiple processes trying
    # to download and extract at the same time.
    if smp.local_rank() == 0:
        dataset1 = datasets.MNIST('../data', train=True, 
                                  download=True, transform=transform)
        
    # Wait for all processes to be ready, i.e. until the download above has finished.
    smp.barrier()
    # Every process now opens the training split without downloading again.
    dataset1 = datasets.MNIST('../data', train=True, 
                              download=False, transform=transform)

    # Open the test split (no download is requested here) and create the
    # dataloaders for the train and test datasets.
    dataset2 = datasets.MNIST('../data', train=False, 
                              transform=transform)

    # Both loaders share the same settings: no shuffling, and the last
    # incomplete batch is dropped.
    kwargs = {'batch_size': args.batch_size, 'num_workers': args.num_workers, 
              'pin_memory': True, 'shuffle': False, 'drop_last': True}
    train_loader = torch.utils.data.DataLoader(dataset1, **kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **kwargs)

    # Get the model
    model = GroupedNet()

    # SageMaker Model Parallel handles the transfer of model parameters to the right device
    # and the user doesn't need to call 'model.to(device)' explicitly.
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    # Wrap the model with smp.DistributedModel and let smp handle the model parallelism.
    # Similarly for the optimizer.
    model = smp.DistributedModel(model)
    optimizer = smp.DistributedOptimizer(optimizer)

    # Multiply the learning rate by 0.7 after every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
    for epoch in range(args.epochs):
        train(model, device, train_loader, optimizer, epoch)
        # The returned test loss is not used further in this script.
        test_loss = test(model, device, test_loader)
        scheduler.step()

    # Synchronise all processes once training is done, before the checkpoint is written.
    smp.barrier()

    # To save a model, always save on dp_rank 0 to avoid data racing
    if smp.dp_rank() == 0:
        model_dict = model.local_state_dict()
        opt_dict = optimizer.local_state_dict()
        # NOTE(review): 'model' is rebound here from the DistributedModel to the
        # checkpoint dictionary; nothing after this point uses the model itself.
        model = {'model_state_dict': model_dict, 'optimizer_state_dict': opt_dict}
        model_output_path = f'{args.model_dir}/pt_mnist_checkpoint.pt'
        smp.save(model, model_output_path, partial=True)
        # Use partial=True if you want to be able to load and further train a model that you save with smp.save().

In [None]:
from smexperiments.experiment import Experiment
from smexperiments.trial import Trial
from botocore.exceptions import ClientError

experiment_name = 'mnist-classification'

# Create the experiment, or reuse the one that already carries this name.
try:
    experiment = Experiment.create(
        experiment_name=experiment_name, 
        description='Training a classification model for mnist dataset.')
except ClientError as create_error:
    # Creation can fail for reasons other than a name clash (permissions,
    # throttling, ...). Only report a reuse when the experiment can actually be
    # loaded; this also keeps 'experiment' bound for later cells.
    try:
        experiment = Experiment.load(experiment_name=experiment_name)
    except ClientError:
        # The experiment does not exist either, so surface the original failure.
        raise create_error
    print(f'{experiment_name} experiment already exists! Reusing the existing experiment.')
    

In [None]:
from time import gmtime, strftime
from sagemaker.pytorch import PyTorch

# A UTC timestamp in the job name tells successive launches apart.
exp_datetime = strftime('%Y-%m-%d-%H-%M-%S', gmtime())
jobname = f'mnist-smmp-pt-{exp_datetime}'

# Training output and the uploaded source code share one job-specific S3 prefix.
s3_output_location = f's3://{bucket}/{prefix}/{jobname}'
code_dir = f's3://{bucket}/{prefix}/{jobname}'

train_instance_type = 'ml.p3.8xlarge'

# Settings for the SageMaker model parallel library.
_smp_parameters = {
    'partitions': 2,
    'microbatches': 4,
    'pipeline': 'interleaved',
    'optimize': 'speed',
    'ddp': False,
}

# MPI launcher settings.
_mpi_options = {
    'enabled': True,
    'processes_per_host': 2,
}

distribution = {
    'smdistributed': {
        'modelparallel': {
            'enabled': True,
            'parameters': _smp_parameters,
        },
    },
    'mpi': _mpi_options,
}

# Estimator that runs code/smmp_pytorch_mnist.py with the distribution above.
estimator = PyTorch(
    source_dir='code',
    entry_point='smmp_pytorch_mnist.py',
    output_path=s3_output_location,
    code_location=code_dir,
    instance_type=train_instance_type,
    instance_count=1,
    enable_sagemaker_metrics=True,
    sagemaker_session=sess,
    role=role,
    framework_version='1.6.0',
    py_version='py36',
    distribution=distribution,
)

Finally, you will use the estimator to launch the SageMaker training job.

In [None]:
# Create a new trial in the experiment, named after the training job.
exp_trial = Trial.create(experiment_name=experiment_name, 
                         trial_name=jobname)

# Associate the training job with the experiment and the trial created above.
experiment_config={'ExperimentName': experiment_name,
                   'TrialName': exp_trial.trial_name,
                   'TrialComponentDisplayName': 'Training'}

# Launch the training job under the generated job name; wait=True is passed so
# the call waits for the job.
estimator.fit(job_name=jobname,
              experiment_config=experiment_config,
              wait=True)