# Data Parallel on Amazon SageMaker Training with PyTorch Lightning
In this lab we'll write our own deep neural network using PyTorch Lightning and train it on Amazon SageMaker. For more details, [see our blog post here.](https://aws.amazon.com/blogs/machine-learning/run-pytorch-lightning-and-native-pytorch-ddp-on-amazon-sagemaker-training-featuring-amazon-search/)

### Step 0. Update the SageMaker Python SDK

In [None]:
%pip install --upgrade sagemaker
%pip install boto3 --upgrade

In [None]:
import sagemaker
# make sure this is at least 2.102.0
sagemaker.__version__
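
Checking the version by eye works, but comparing version strings as text can mislead (as text, "2.99.0" sorts after "2.102.0"). Below is a minimal sketch of a numeric comparison, assuming plain dotted version strings. The helper names `version_tuple` and `is_at_least` are hypothetical and not part of the SageMaker SDK.

```python
def version_tuple(version: str) -> tuple:
    """Turn a dotted version such as '2.102.0' into (2, 102, 0).

    Non-numeric suffixes (for example 'dev0') are treated as 0.
    """
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def is_at_least(installed: str, required: str = "2.102.0") -> bool:
    """Return True when `installed` is numerically >= `required`."""
    a, b = version_tuple(installed), version_tuple(required)
    # Pad the shorter tuple with zeros so '2.103' compares like '2.103.0'.
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


print(is_at_least("2.102.0"))
print(is_at_least("2.99.0"))
```

In the notebook you would call `is_at_least(sagemaker.__version__)`.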

### Step 1. Write the training script and requirements into a local directory, here named `scripts`

In [None]:
!mkdir -p scripts

In [None]:
%%writefile scripts/requirements.txt
pytorch-lightning == 1.6.3
lightning-bolts == 0.5.0

In [None]:
%%writefile scripts/mnist.py

import os
import torch
from torch.nn import functional as F

import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule

import argparse

class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.0001):
        super().__init__()
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        probs = self(x)
        # Return the per-batch accuracy; validation_epoch_end averages these values
        # (test_epoch_end does the same for test_step).
        acc = self.accuracy(probs, y)
        return acc

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        acc = self.accuracy(logits, y)
        return acc

    def accuracy(self, logits, y):
        # Explicit accuracy calculation: the fraction of argmax predictions that match the labels.
        acc = torch.sum(torch.eq(torch.argmax(logits, -1), y).to(torch.float32)) / len(y)
        return acc

    def validation_epoch_end(self, outputs) -> None:
        # Average the per-batch accuracies returned by validation_step and log the result.
        self.log("val_acc", torch.stack(outputs).mean(), prog_bar=True)

    def test_epoch_end(self, outputs) -> None:
        self.log("test_acc", torch.stack(outputs).mean())

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    
import json  # used below to decode SM_HOSTS, which SageMaker sets to a JSON list


def parse_args():
    parser = argparse.ArgumentParser()

    # SM_HOSTS is a JSON-encoded list such as '["algo-1", "algo-2"]', so decode it
    # before counting hosts; len() of the raw string would count characters.
    hosts = json.loads(os.environ["SM_HOSTS"])
    num_gpus = int(os.environ["SM_NUM_GPUS"])

    parser.add_argument("--hosts", type=json.loads, default=hosts)
    parser.add_argument("--current-host", type=str, default=os.environ["SM_CURRENT_HOST"])
    parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--num-gpus", type=int, default=num_gpus)

    parser.add_argument("--num_nodes", type=int, default=len(hosts))

    # num gpus is per node
    parser.add_argument("--world-size", type=int, default=num_gpus * len(hosts))

    args = parser.parse_args()

    return args
    
    
if __name__ == "__main__":
    
    args = parse_args()
    
    dm = MNISTDataModule(batch_size=32)
    
    model = LitClassifier()
    
    local_rank = os.environ["LOCAL_RANK"]
    torch.cuda.set_device(int(local_rank))
    
    num_nodes = args.num_nodes
    num_gpus = args.num_gpus
    
    env = LightningEnvironment()
    
    env.world_size = lambda: int(os.environ.get("WORLD_SIZE", 0))
    env.global_rank = lambda: int(os.environ.get("RANK", 0))
    
    ddp = DDPStrategy(cluster_environment=env, accelerator="gpu")
    
    trainer = pl.Trainer(max_epochs=200, strategy=ddp, devices=num_gpus, num_nodes=num_nodes, default_root_dir = args.model_dir)
    trainer.fit(model, datamodule=dm)
    trainer.test(model, datamodule=dm)
    

### Step 2. Configure the SageMaker Training API locally
In this step you use the [SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html) as a wrapper around the core API, `create-training-job`, [as described here.](https://docs.aws.amazon.com/cli/latest/reference/sagemaker/create-training-job.html)

Read more about [training on SageMaker here,](https://docs.aws.amazon.com/sagemaker/latest/dg/how-it-works-training.html) with distributed training details [here.](https://docs.aws.amazon.com/sagemaker/latest/dg/distributed-training.html)
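
Inside a distributed job, the training script learns the cluster layout from environment variables that SageMaker sets in the container. `SM_HOSTS` is a JSON-encoded list of host names and `SM_NUM_GPUS` is the GPU count per host, so the node count comes from decoding the JSON, not from the length of the raw string. Below is a minimal sketch with hypothetical host names; the helper `cluster_shape` is illustrative and not part of any SDK.

```python
import json


def cluster_shape(sm_hosts: str, sm_num_gpus: str) -> dict:
    """Derive node count and world size from SageMaker-style env values."""
    hosts = json.loads(sm_hosts)       # e.g. ["algo-1", "algo-2"]
    gpus_per_node = int(sm_num_gpus)   # GPUs on each host
    return {
        "hosts": hosts,
        "num_nodes": len(hosts),
        "gpus_per_node": gpus_per_node,
        "world_size": len(hosts) * gpus_per_node,
    }


# Hypothetical values, shaped like what a two-node, four-GPU job would see.
example = cluster_shape('["algo-1", "algo-2"]', "4")
print(example["num_nodes"], example["world_size"])
```

In a real job you would pass `os.environ["SM_HOSTS"]` and `os.environ["SM_NUM_GPUS"]` instead of the literal strings.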


You can see [instance details for SageMaker here,](https://aws.amazon.com/sagemaker/pricing/) along with [instance specs from EC2 directly here.](https://aws.amazon.com/ec2/instance-types/)

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch
from sagemaker.local import LocalSession

sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()
region = sagemaker_session.boto_region_name

# hard-code the URI of the AWS Deep Learning Container (DLC) training image for this region
image_uri = '763104351884.dkr.ecr.{}.amazonaws.com/pytorch-training:1.12.0-gpu-py38-cu113-ubuntu20.04-sagemaker'.format(region)

estimator = PyTorch(
  entry_point="mnist.py",
  base_job_name="lightning-ddp-mnist",
  image_uri = image_uri,
  role=role,
  source_dir="scripts",
  # configures the SageMaker training resources; increase these as needed
  instance_count=1,
  instance_type="ml.g4dn.12xlarge",
  py_version="py38",
  sagemaker_session=sagemaker_session,
  distribution={"pytorchddp":{"enabled": True}},
  debugger_hook_config=False)
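
With `instance_count=1` and an `ml.g4dn.12xlarge` (4 GPUs), DDP runs one process per GPU, and each process draws its own batches of 32 as set in `mnist.py`. Below is a rough sketch of the resulting effective batch size per optimizer step; the function name is hypothetical.

```python
def effective_batch_size(per_process_batch: int, gpus_per_instance: int, instance_count: int) -> int:
    """Samples consumed per optimizer step across all DDP processes."""
    world_size = gpus_per_instance * instance_count
    return per_process_batch * world_size


# batch_size=32 from mnist.py, 4 GPUs per ml.g4dn.12xlarge, one instance
print(effective_batch_size(32, 4, 1))  # 128
```

Raising `instance_count` multiplies this number, which is worth keeping in mind when choosing a learning rate.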

In [None]:
# wait=True blocks the kernel until the training job finishes; wait=False returns immediately. Both create a training job.
estimator.fit(wait=True)
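
If you pass `wait=False`, the kernel stays free and you can poll for completion yourself. Below is a minimal sketch, assuming a boto3 SageMaker client (`boto3.client("sagemaker")`) is passed in. The helper name and polling interval are illustrative.

```python
import time

# Job statuses after which the job will not change state again.
TERMINAL_STATES = {"Completed", "Failed", "Stopped"}


def wait_for_training_job(client, job_name, poll_seconds=30, sleep=time.sleep):
    """Poll describe_training_job until the job reaches a terminal status.

    `client` is expected to behave like boto3.client("sagemaker").
    `sleep` is injectable so the loop can be exercised without real delays.
    """
    while True:
        response = client.describe_training_job(TrainingJobName=job_name)
        status = response["TrainingJobStatus"]
        if status in TERMINAL_STATES:
            return status
        sleep(poll_seconds)
```

For example, after `estimator.fit(wait=False)` you could call `wait_for_training_job(boto3.client("sagemaker"), estimator.latest_training_job.name)`.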

### Step 3. Download a profiler report
Unless disabled, SageMaker will generate a report to analyze your overall job metrics and provide helpful insights that you can use to increase performance. Details on the report, including how to access it, [are provided here.](https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-profiling-report.html)

If you need to restart your kernel, you can re-establish the estimator variable here.
`estimator = PyTorch.attach("your_training_job_name")`

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch

# replace this with the name of your own training job
job_name = 'lightning-ddp-mnist-2022-08-23-02-11-37-450'
estimator = PyTorch.attach(job_name)

In [None]:
# point to an S3 path using your estimator variable (tolerates an output_path with or without a trailing slash)
rule_output_path = estimator.output_path.rstrip("/") + "/" + estimator.latest_training_job.job_name + "/rule-output"

In [None]:
# use the AWS CLI to copy the report locally
! aws s3 cp {rule_output_path} ./ --recursive
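
The same copy can be scripted in Python with boto3, which takes the bucket and key prefix separately. Below is a small sketch of building and splitting the report location. The bucket and job names are placeholders, and both helper names are hypothetical.

```python
from urllib.parse import urlparse


def rule_output_uri(output_path: str, job_name: str) -> str:
    """Join the estimator output path and job name into the rule-output URI."""
    return f"{output_path.rstrip('/')}/{job_name}/rule-output"


def split_s3_uri(uri: str) -> tuple:
    """Split 's3://bucket/some/prefix' into ('bucket', 'some/prefix')."""
    parsed = urlparse(uri)
    if parsed.scheme != "s3" or not parsed.netloc:
        raise ValueError(f"not an S3 URI: {uri!r}")
    return parsed.netloc, parsed.path.lstrip("/")


# Placeholder values; in the notebook these come from the estimator.
uri = rule_output_uri("s3://example-bucket/", "example-job")
bucket, prefix = split_s3_uri(uri)
print(bucket, prefix)
```

The resulting `bucket` and `prefix` are what you would pass when listing and downloading the report objects.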

### Step 4. Optimize!
That's the end of this lab. However, in the real world, this is just the beginning. Here are a few extra things you can do to increase model accuracy (hopefully) and decrease job runtime (certainly).
1. **Increase throughput by adding extra nodes.** Increase the number of instances in `instance_count`, and as long as you're using some form of distribution in your training script, this will automatically replicate your model across all available accelerators and handle averaging the gradients. This is also called ***horizontal scaling.***
2. **Increase throughput by upgrading your instances.** In addition to (or sometimes instead of) adding extra nodes, you can increase your instance to something larger. This usually means adding more accelerators, more CPU, more memory, and more bandwidth. 
3. **Increase accuracy with hyperparameter tuning**. Another critical step is picking the right hyperparameters. You can use [Syne Tune](https://github.com/awslabs/syne-tune/blob/hf_blog_post/hf_blog_post/example_syne_tune_for_hf.ipynb) for a multi-objective tuning metric as one example. Here's another example with [SageMaker Training Compiler to tune the batch size!](https://github.com/aws/amazon-sagemaker-examples/blob/main/sagemaker-training-compiler/huggingface/pytorch_tune_batch_size/finding_max_batch_size_for_model_training.ipynb)
4. **Increase accuracy by adding more parameters to your model, and using a model parallel strategy.** This is the content of the next lab!
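
To make the horizontal-scaling point in item 1 concrete: under data parallelism each process sees roughly `1 / world_size` of the dataset per epoch, so the number of optimizer steps per epoch shrinks as nodes are added. Below is a back-of-the-envelope sketch, assuming the dataset is split evenly and rounded up to a whole number of samples per process, as PyTorch's `DistributedSampler` does by default. The sample count of 48,000 is illustrative.

```python
import math


def steps_per_epoch(num_samples: int, per_process_batch: int,
                    gpus_per_instance: int, instance_count: int) -> int:
    """Optimizer steps per epoch when data is sharded across all DDP processes."""
    world_size = gpus_per_instance * instance_count
    samples_per_process = math.ceil(num_samples / world_size)
    return math.ceil(samples_per_process / per_process_batch)


# 4 GPUs per instance and a per-process batch of 32, as in this lab.
for nodes in (1, 2, 4):
    print(nodes, steps_per_epoch(48_000, 32, 4, nodes))
```

Fewer steps per epoch generally means shorter epochs, provided communication between nodes does not become the bottleneck.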