In [3]:
%pip install --upgrade sagemaker
!pip install boto3 --upgrade

In [2]:
import sagemaker

# Show the installed SageMaker SDK version to confirm the upgrade above.
sdk_version = sagemaker.__version__
sdk_version

'2.103.0'

In [3]:
!mkdir scripts

In [4]:
%%writefile scripts/requirements.txt
# Pinned third-party packages imported by scripts/mnist.py (pytorch_lightning, pl_bolts).
pytorch-lightning==1.6.3
lightning-bolts==0.5.0

Writing scripts/requirements.txt


In [5]:
%%writefile scripts/mnist.py

import argparse
import json
import os

import torch
from torch.nn import functional as F

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule

class LitClassifier(pl.LightningModule):
    """Two-layer MLP classifier for MNIST digits (10 output classes).

    Args:
        hidden_dim: width of the hidden layer.
        learning_rate: learning rate for the Adam optimizer.

    Both arguments are recorded with ``save_hyperparameters`` and read back
    from ``self.hparams``.
    """

    def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.0001):
        super().__init__()
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        """Map a batch of images to unnormalized class logits of shape (batch, 10).

        Each sample is flattened, so it must contain 28 * 28 values.
        """
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        # No activation on the output layer: F.cross_entropy expects raw logits,
        # and a ReLU here would clamp every negative logit to 0.
        return self.l2(x)

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one ``(images, labels)`` batch."""
        x, y = batch
        logits = self(x)
        return F.cross_entropy(logits, y)

    def validation_step(self, batch, batch_idx):
        """Return this batch's accuracy; averaged in ``validation_epoch_end``."""
        x, y = batch
        return self.accuracy(self(x), y)

    def test_step(self, batch, batch_idx):
        """Return this batch's accuracy; averaged in ``test_epoch_end``."""
        x, y = batch
        return self.accuracy(self(x), y)

    def accuracy(self, logits, y):
        """Fraction of samples whose arg-max logit equals the label in ``y``."""
        return (torch.argmax(logits, dim=-1) == y).to(torch.float32).mean()

    def validation_epoch_end(self, outputs) -> None:
        """Log the mean of the per-batch accuracies returned by ``validation_step``."""
        # NOTE: this is a mean of batch means, so a smaller final batch is weighted
        # like a full one. sync_dist=True averages the value across DDP processes;
        # without it each rank logs only the accuracy of its own data shard.
        self.log("val_acc", torch.stack(outputs).mean(), prog_bar=True, sync_dist=True)

    def test_epoch_end(self, outputs) -> None:
        """Log the mean of the per-batch accuracies returned by ``test_step``."""
        self.log("test_acc", torch.stack(outputs).mean(), sync_dist=True)

    def configure_optimizers(self):
        """Adam over all parameters, using the ``learning_rate`` hyperparameter."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    
def parse_args(argv=None):
    """Parse command-line arguments, defaulting to SageMaker's ``SM_*`` environment.

    Each flag's default comes from the matching SageMaker environment variable,
    so the script runs without explicit arguments inside a training job.

    Args:
        argv: list of arguments to parse; ``None`` (the default) parses
            ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with ``hosts`` (list of host names), ``current_host``,
        ``model_dir``, ``num_gpus`` (GPUs per node), ``num_nodes`` and
        ``world_size`` (GPUs per node times number of hosts).

    Raises:
        KeyError: if SM_HOSTS, SM_CURRENT_HOST, SM_MODEL_DIR or SM_NUM_GPUS is unset.
        json.JSONDecodeError: if SM_HOSTS is not valid JSON.
    """
    # SM_HOSTS is a JSON-encoded list such as '["algo-1", "algo-2"]'. It must be
    # decoded first: len() of the raw string counts characters, not hosts.
    hosts = json.loads(os.environ["SM_HOSTS"])
    num_gpus = int(os.environ["SM_NUM_GPUS"])

    parser = argparse.ArgumentParser()

    # A CLI override is parsed as a JSON list. The default is already a list, and
    # argparse only runs `type` on string defaults, so it is used as-is.
    parser.add_argument("--hosts", type=json.loads, default=hosts)
    parser.add_argument("--current-host", type=str, default=os.environ["SM_CURRENT_HOST"])
    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--num-gpus", type=int, default=num_gpus)

    # Both spellings are accepted; dest stays `num_nodes`.
    parser.add_argument("--num_nodes", "--num-nodes", type=int, default=len(hosts))

    # num gpus is per node
    parser.add_argument("--world-size", type=int, default=num_gpus * len(hosts))

    return parser.parse_args(argv)
    
    
if __name__ == "__main__":

    args = parse_args()

    dm = MNISTDataModule(batch_size=32)

    model = LitClassifier()

    # LOCAL_RANK is set per process when an external launcher starts the script;
    # fall back to GPU 0 when it is started as a single plain process.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    env = LightningEnvironment()

    # Prefer the process-group layout exported by the launcher (WORLD_SIZE / RANK).
    # When those variables are absent, fall back to the values Lightning itself
    # records on the environment, instead of a world size of 0, which is never valid.
    lightning_world_size = env.world_size
    lightning_global_rank = env.global_rank
    env.world_size = lambda: int(os.environ.get("WORLD_SIZE", lightning_world_size()))
    env.global_rank = lambda: int(os.environ.get("RANK", lightning_global_rank()))

    ddp = DDPStrategy(cluster_environment=env)

    trainer = pl.Trainer(
        max_epochs=200,
        accelerator="gpu",
        strategy=ddp,
        devices=args.num_gpus,  # GPUs per node
        num_nodes=args.num_nodes,
        default_root_dir=args.model_dir,
    )
    trainer.fit(model, datamodule=dm)
    trainer.test(model, datamodule=dm)
    

Writing scripts/mnist.py


In [8]:
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.local import LocalSession  # NOTE(review): unused in this cell; kept in case later cells rely on it

sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

# Hard-coded URI of the custom PyTorch DDP image (account id redacted).
# NOTE(review): the image is pinned to us-east-1 — confirm it matches the session's region.
image_uri = '< ... >.dkr.ecr.us-east-1.amazonaws.com/pt-ddp-custom:1.12.0-gpu-py38-cu113-ubuntu20.04-sagemaker-2.6.0-numproc'

# Starting a training job provisions paid instances; set to True to launch it.
RUN_TRAINING = False

estimator = PyTorch(
    entry_point="mnist.py",
    source_dir="scripts",  # holds mnist.py and requirements.txt written by the cells above
    # Job time limit. NOTE(review): mnist.py trains for max_epochs=200 — confirm it fits.
    max_run=1800,
    base_job_name="lightning-ddp-mnist",
    image_uri=image_uri,
    role=role,
    instance_count=1,
    instance_type="ml.g4dn.12xlarge",
    py_version="py38",
    sagemaker_session=sagemaker_session,
    distribution={"pytorchddp": {"enabled": True}},
    debugger_hook_config=False,
)

if RUN_TRAINING:
    estimator.fit(wait=True)
