In [9]:
# %pip install --upgrade sagemaker

In [2]:
import sagemaker
sagemaker.__version__

'2.103.0'

In [3]:
!mkdir scripts

In [32]:
%%writefile scripts/requirements.txt
pytorch-lightning == 1.6.3
lightning-bolts == 0.5.0

Overwriting scripts/requirements.txt


In [5]:
%%writefile scripts/mnist.py

import os
import torch
from torch.nn import functional as F

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule

import argparse

class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.0001):
        super().__init__()
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        probs = self(x)
        # we currently return the accuracy as the validation_step/test_step is run on the IPU devices.
        # Outputs from the step functions are sent to the host device, where we calculate the metrics in
        # validation_epoch_end and test_epoch_end for the test_step.
        acc = self.accuracy(probs, y)
        return acc

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        acc = self.accuracy(logits, y)
        return acc

    def accuracy(self, logits, y):
        # currently IPU poptorch doesn't implicit convert bools to tensor
        # hence we use an explicit calculation for accuracy here. Once fixed in poptorch
        # we can use the accuracy metric.
        acc = torch.sum(torch.eq(torch.argmax(logits, -1), y).to(torch.float32)) / len(y)
        return acc

    def validation_epoch_end(self, outputs) -> None:
        # since the training step/validation step and test step are run on the IPU device
        # we must log the average loss outside the step functions.
        self.log("val_acc", torch.stack(outputs).mean(), prog_bar=True)

    def test_epoch_end(self, outputs) -> None:
        self.log("test_acc", torch.stack(outputs).mean())

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    
def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--hosts", type=list, default=os.environ["SM_HOSTS"])
    parser.add_argument("--current-host", type=str, default=os.environ["SM_CURRENT_HOST"])
    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--num-gpus", type=int, default=int(os.environ["SM_NUM_GPUS"]))

    parser.add_argument("--num_nodes", type=int, default = len(os.environ["SM_HOSTS"]))
           
    # need to double check if num_gpus is per node or in total
    world_size = int(os.environ["SM_NUM_GPUS"]) * len(os.environ["SM_HOSTS"])
                 
    parser.add_argument("--world-size", type=int, default=world_size)
    
    args = parser.parse_args()
    
    return args
    
    
if __name__ == "__main__":
    
    args = parse_args()
    
    dm = MNISTDataModule(batch_size=32)
    
    model = LitClassifier()
    
    local_rank = os.environ["LOCAL_RANK"]
    torch.cuda.set_device(int(local_rank))
    
    num_nodes = args.num_nodes
    num_gpus = args.num_gpus
    
    env = LightningEnvironment()
    
    env.world_size = lambda: int(os.environ.get("WORLD_SIZE", 0))
    env.global_rank = lambda: int(os.environ.get("RANK", 0))
    
    ddp = DDPStrategy(cluster_environment=env, accelerator="gpu")
    
    trainer = pl.Trainer(max_epochs=200, strategy=ddp, devices=num_gpus, num_nodes=num_nodes, default_root_dir = args.model_dir)
    trainer.fit(model, datamodule=dm)
    trainer.test(model, datamodule=dm)
    

Overwriting scripts/mnist.py


In [6]:
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.local import LocalSession

sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

image_uri = '570106654206.dkr.ecr.us-east-1.amazonaws.com/pt-ddp-custom:1.12.0-gpu-py38-cu113-ubuntu20.04-sagemaker-2.6.0-numproc'

estimator = PyTorch(
  entry_point="mnist.py",
  max_run=1800,
  base_job_name="lightning-ddp-mnist",
  image_uri = image_uri,
  role=role,
  source_dir="scripts",
  instance_count=1,
  instance_type="ml.g4dn.12xlarge",
  py_version="py38",
  sagemaker_session=sagemaker_session,
  distribution={"pytorchddp":{"enabled": True}},
  debugger_hook_config=False)

estimator.fit(wait=True)


2022-08-12 20:50:33 Starting - Starting the training job...
2022-08-12 20:51:03 Starting - Preparing the instances for trainingProfilerReport-1660337433: InProgress
.........
2022-08-12 20:52:21 Downloading - Downloading input data
2022-08-12 20:52:21 Training - Downloading the training image...........................
2022-08-12 20:57:03 Training - Training image download completed. Training in progress.[34mbash: cannot set terminal process group (-1): Inappropriate ioctl for device[0m
[34mbash: no job control in this shell[0m
[34m2022-08-12 20:57:02,859 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training[0m
[34m2022-08-12 20:57:02,901 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.[0m
[34m2022-08-12 20:57:02,904 sagemaker_pytorch_container.training INFO     Pytorch_ddp_enabled is:[0m
[34m2022-08-12 20:57:02,904 sagemaker_pytorch_container.training INFO     True[0m
[34m2022-08-12 20:57:02,904 

[34m[1,mpirank:0,algo-1]<stdout>:algo-1:255:255 [0] NCCL INFO Bootstrap : Using eth0:10.2.220.138<0>[0m
[34m[1,mpirank:0,algo-1]<stdout>:algo-1:255:255 [0] NCCL INFO NET/Plugin: Failed to find ncclCollNetPlugin_v4 symbol.[0m
[34m[1,mpirank:0,algo-1]<stdout>:algo-1:255:255 [0] NCCL INFO NET/OFI Using aws-ofi-nccl 1.3.0aws[0m
[34m[1,mpirank:0,algo-1]<stdout>:algo-1:255:255 [0] NCCL INFO NET/OFI Setting FI_EFA_FORK_SAFE environment variable to 1[0m
[34m[1,mpirank:0,algo-1]<stdout>:algo-1:255:255 [0] NCCL INFO NET/OFI Forcing AWS OFI ndev 2[0m
[34m[1,mpirank:0,algo-1]<stdout>:algo-1:255:255 [0] NCCL INFO NET/OFI Selected Provider is efa[0m
[34m[1,mpirank:0,algo-1]<stdout>:algo-1:255:255 [0] NCCL INFO Using network AWS Libfabric[0m
[34m[1,mpirank:0,algo-1]<stdout>:NCCL version 2.10.3+cuda11.3[0m
[34m[1,mpirank:1,algo-1]<stdout>:algo-1:103:103 [1] NCCL INFO Bootstrap : Using eth0:10.2.220.138<0>[0m
[34m[1,mpirank:2,algo-1]<stdout>:algo-1:105:105 [2] NCCL INFO Bootstrap : U

[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 30:  12% 6/48 [00:00<00:00, 157.67it/s, loss=0.523, v_num=0, val_acc=0.844][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 30:  12% 6/48 [00:00<00:00, 156.82it/s, loss=0.541, v_num=0, val_acc=0.844][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 30:  15% 7/48 [00:00<00:00, 159.22it/s, loss=0.541, v_num=0, val_acc=0.844][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 30:  15% 7/48 [00:00<00:00, 158.49it/s, loss=0.538, v_num=0, val_acc=0.844][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 30:  17% 8/48 [00:00<00:00, 160.27it/s, loss=0.538, v_num=0, val_acc=0.844][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 30:  17% 8/48 [00:00<00:00, 159.61it/s, loss=0.574, v_num=0, val_acc=0.844][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 30:  19% 9/48 [00:00<00:00, 161.17it/s, loss=0.574, v_num=0, val_acc=0.844][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 30:  19% 9/48 [00:00<00:00, 160.58it/s, loss=0.571, v_num=0, val_acc=0.844][0m
[34m[1,

[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 61:  10% 5/48 [00:00<00:00, 156.17it/s, loss=0.509, v_num=0, val_acc=0.857][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 61:  12% 6/48 [00:00<00:00, 158.98it/s, loss=0.509, v_num=0, val_acc=0.857][1,mpirank:0,algo-1]<stdout>:#015Epoch 61:  12% 6/48 [00:00<00:00, 158.08it/s, loss=0.528, v_num=0, val_acc=0.857][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 61:  15% 7/48 [00:00<00:00, 160.42it/s, loss=0.528, v_num=0, val_acc=0.857][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 61:  15% 7/48 [00:00<00:00, 159.68it/s, loss=0.524, v_num=0, val_acc=0.857][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 61:  17% 8/48 [00:00<00:00, 161.47it/s, loss=0.524, v_num=0, val_acc=0.857][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 61:  17% 8/48 [00:00<00:00, 160.80it/s, loss=0.51, v_num=0, val_acc=0.857][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 61:  19% 9/48 [00:00<00:00, 162.06it/s, loss=0.51, v_num=0, val_acc=0.857][0m
[34m[1,mpirank:0,al

[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 86:   2% 1/48 [00:00<00:00, 118.80it/s, loss=0.459, v_num=0, val_acc=0.864][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 86:   2% 1/48 [00:00<00:00, 115.91it/s, loss=0.46, v_num=0, val_acc=0.864][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 86:   4% 2/48 [00:00<00:00, 138.33it/s, loss=0.46, v_num=0, val_acc=0.864][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 86:   4% 2/48 [00:00<00:00, 136.37it/s, loss=0.459, v_num=0, val_acc=0.864][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 86:   6% 3/48 [00:00<00:00, 146.97it/s, loss=0.459, v_num=0, val_acc=0.864][1,mpirank:0,algo-1]<stdout>:#015Epoch 86:   6% 3/48 [00:00<00:00, 145.53it/s, loss=0.469, v_num=0, val_acc=0.864][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 86:   8% 4/48 [00:00<00:00, 151.97it/s, loss=0.469, v_num=0, val_acc=0.864][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 86:   8% 4/48 [00:00<00:00, 150.78it/s, loss=0.46, v_num=0, val_acc=0.864][0m
[34m[1,mpirank:0,alg

[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 118:  58% 28/48 [00:00<00:00, 163.99it/s, loss=0.403, v_num=0, val_acc=0.869][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 118:  58% 28/48 [00:00<00:00, 163.51it/s, loss=0.394, v_num=0, val_acc=0.869][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 118:  60% 29/48 [00:00<00:00, 163.91it/s, loss=0.394, v_num=0, val_acc=0.869][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 118:  60% 29/48 [00:00<00:00, 163.54it/s, loss=0.394, v_num=0, val_acc=0.869][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 118:  62% 30/48 [00:00<00:00, 163.88it/s, loss=0.394, v_num=0, val_acc=0.869][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 118:  62% 30/48 [00:00<00:00, 163.51it/s, loss=0.39, v_num=0, val_acc=0.869][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 118:  65% 31/48 [00:00<00:00, 163.90it/s, loss=0.39, v_num=0, val_acc=0.869][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 118:  65% 31/48 [00:00<00:00, 163.49it/s, loss=0.379, v_num=0, val_acc=0.869

[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 123:   4% 2/48 [00:00<00:00, 138.69it/s, loss=0.474, v_num=0, val_acc=0.873][1,mpirank:0,algo-1]<stdout>:#015Epoch 123:   4% 2/48 [00:00<00:00, 136.70it/s, loss=0.481, v_num=0, val_acc=0.873][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 123:   6% 3/48 [00:00<00:00, 147.62it/s, loss=0.481, v_num=0, val_acc=0.873][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 123:   6% 3/48 [00:00<00:00, 146.13it/s, loss=0.479, v_num=0, val_acc=0.873][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 123:   8% 4/48 [00:00<00:00, 152.26it/s, loss=0.479, v_num=0, val_acc=0.873][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 123:   8% 4/48 [00:00<00:00, 151.10it/s, loss=0.494, v_num=0, val_acc=0.873][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 123:  10% 5/48 [00:00<00:00, 155.44it/s, loss=0.494, v_num=0, val_acc=0.873][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 123:  10% 5/48 [00:00<00:00, 154.45it/s, loss=0.493, v_num=0, val_acc=0.873][0m
[34m[1,mp

[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 155:  46% 22/48 [00:00<00:00, 168.79it/s, loss=0.397, v_num=0, val_acc=0.879][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 155:  46% 22/48 [00:00<00:00, 168.53it/s, loss=0.393, v_num=0, val_acc=0.879][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 155:  48% 23/48 [00:00<00:00, 168.97it/s, loss=0.393, v_num=0, val_acc=0.879][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 155:  48% 23/48 [00:00<00:00, 168.71it/s, loss=0.396, v_num=0, val_acc=0.879][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 155:  50% 24/48 [00:00<00:00, 169.10it/s, loss=0.396, v_num=0, val_acc=0.879][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 155:  50% 24/48 [00:00<00:00, 168.86it/s, loss=0.411, v_num=0, val_acc=0.879][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 155:  52% 25/48 [00:00<00:00, 169.22it/s, loss=0.411, v_num=0, val_acc=0.879][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 155:  52% 25/48 [00:00<00:00, 168.94it/s, loss=0.416, v_num=0, val_acc=0.8

[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 160:  58% 28/48 [00:00<00:00, 168.81it/s, loss=0.364, v_num=0, val_acc=0.879][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 160:  58% 28/48 [00:00<00:00, 168.61it/s, loss=0.364, v_num=0, val_acc=0.879][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 160:  60% 29/48 [00:00<00:00, 168.89it/s, loss=0.364, v_num=0, val_acc=0.879][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 160:  60% 29/48 [00:00<00:00, 168.69it/s, loss=0.36, v_num=0, val_acc=0.879][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 160:  62% 30/48 [00:00<00:00, 168.99it/s, loss=0.36, v_num=0, val_acc=0.879][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 160:  62% 30/48 [00:00<00:00, 168.80it/s, loss=0.386, v_num=0, val_acc=0.879][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 160:  65% 31/48 [00:00<00:00, 169.11it/s, loss=0.386, v_num=0, val_acc=0.879][0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 160:  65% 31/48 [00:00<00:00, 168.89it/s, loss=0.396, v_num=0, val_acc=0.879

[34m[1,mpirank:0,algo-1]<stdout>:#015Validation DataLoader 0:  40% 4/10 [00:00<00:00, 352.64it/s]#033[A[0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 192:  88% 42/48 [00:00<00:00, 173.75it/s, loss=0.381, v_num=0, val_acc=0.884][0m
[34m[1,mpirank:0,algo-1]<stdout>:[0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Validation DataLoader 0:  50% 5/10 [00:00<00:00, 338.04it/s]#033[A[0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 192:  90% 43/48 [00:00<00:00, 175.38it/s, loss=0.381, v_num=0, val_acc=0.884][0m
[34m[1,mpirank:0,algo-1]<stdout>:[0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Validation DataLoader 0:  60% 6/10 [00:00<00:00, 327.98it/s]#033[A[0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 192:  92% 44/48 [00:00<00:00, 176.94it/s, loss=0.381, v_num=0, val_acc=0.884][0m
[34m[1,mpirank:0,algo-1]<stdout>:[0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Validation DataLoader 0:  70% 7/10 [00:00<00:00, 321.56it/s]#033[A[0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 192:  94% 45/48 [00:00<00

[34m[1,mpirank:0,algo-1]<stdout>:[0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Validation DataLoader 0:  30% 3/10 [00:00<00:00, 379.60it/s][1,mpirank:0,algo-1]<stdout>:#033[A[1,mpirank:0,algo-1]<stdout>:#015Epoch 197:  85% 41/48 [00:00<00:00, 170.37it/s, loss=0.426, v_num=0, val_acc=0.883][0m
[34m[1,mpirank:0,algo-1]<stdout>:[0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Validation DataLoader 0:  40% 4/10 [00:00<00:00, 349.48it/s]#033[A[1,mpirank:0,algo-1]<stdout>:#015Epoch 197:  88% 42/48 [00:00<00:00, 171.98it/s, loss=0.426, v_num=0, val_acc=0.883][0m
[34m[1,mpirank:0,algo-1]<stdout>:[0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Validation DataLoader 0:  50% 5/10 [00:00<00:00, 335.23it/s]#033[A[0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Epoch 197:  90% 43/48 [00:00<00:00, 173.62it/s, loss=0.426, v_num=0, val_acc=0.883][0m
[34m[1,mpirank:0,algo-1]<stdout>:[0m
[34m[1,mpirank:0,algo-1]<stdout>:#015Validation DataLoader 0:  60% 6/10 [00:00<00:00, 326.71it/s]#033[A[0m
[34m[1,mpirank:0,alg


2022-08-12 20:58:43 Uploading - Uploading generated training model
2022-08-12 20:58:43 Completed - Training job completed
ProfilerReport-1660337433: NoIssuesFound
Training seconds: 373
Billable seconds: 373


In [1]:
# !pip install boto3 --upgrade

In [31]:
!aws s3 cp "PyTorch Lightning on SageMaker - CV.ipynb" s3://dist-train/pytorch-lightning/

Completed 7.9 KiB/7.9 KiB (49.3 KiB/s) with 1 file(s) remainingupload: ./PyTorch Lightning on SageMaker - CV.ipynb to s3://dist-train/pytorch-lightning/PyTorch Lightning on SageMaker - CV.ipynb
