# PyTorch SageMaker Training Guide

Based on the [SageMaker PyTorch docs](https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html).

## 3 Steps for Training

1. Prepare a training script

2. Create a `sagemaker.pytorch.PyTorch` estimator

3. Call the estimator’s `fit` method
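
Steps 2 and 3 might look like the sketch below. The `PyTorch` estimator class and its `fit` method come from the SageMaker Python SDK; the script name `train.py`, the framework and Python versions, the instance type, the hyperparameter values, and the helper names are placeholder assumptions to replace with your own. The SDK import is deferred into the launch function so the configuration can be inspected without the `sagemaker` package or AWS credentials.

```python
def build_estimator_kwargs(role, entry_point="train.py"):
    """Collect constructor arguments for sagemaker.pytorch.PyTorch.

    All concrete values below are illustrative assumptions.
    """
    return {
        "entry_point": entry_point,       # the training script from step 1
        "role": role,                     # IAM role ARN with SageMaker permissions
        "framework_version": "1.13.1",    # assumed PyTorch version
        "py_version": "py39",             # assumed Python version
        "instance_count": 1,
        "instance_type": "ml.m5.xlarge",  # assumed instance type
        # Passed to the script as command-line arguments (--epochs 10 ...).
        "hyperparameters": {"epochs": 10, "batch-size": 64, "learning-rate": 0.05},
    }


def launch_training(role, train_s3_uri, test_s3_uri):
    """Create the estimator (step 2) and start the training job (step 3)."""
    from sagemaker.pytorch import PyTorch  # requires the sagemaker package

    estimator = PyTorch(**build_estimator_kwargs(role))
    # Each key becomes an input channel, exposed as SM_CHANNEL_<KEY>.
    estimator.fit({"train": train_s3_uri, "test": test_s3_uri})
    return estimator
```

Calling `launch_training` with a role ARN and two S3 URIs would start a real (billable) training job, so the helper is kept separate from the configuration.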

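Prepare your script in a source file separate from the notebook. Inside the training container, the script can read properties of the training environment from environment variables, including: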

* `SM_NUM_GPUS`: number of GPUs.

* `SM_MODEL_DIR`: path to the local directory to write model artifacts to. SageMaker uploads its contents to S3 after training.

* `SM_OUTPUT_DATA_DIR`: path to the directory for output artifacts other than the model, such as checkpoints, graphs, and other files to save.

* `SM_CHANNEL_XXXX`: path to the directory containing the input data for the channel named `XXXX`, for example `SM_CHANNEL_TRAIN` and `SM_CHANNEL_TEST`.
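
As a minimal sketch, these variables could be gathered in one place as shown here. The function name `read_sagemaker_env` and the fallback values are assumptions, added so the same script can also run outside a SageMaker container where the variables are not set.

```python
import os


def read_sagemaker_env(env=None):
    """Return the SageMaker settings listed above as a plain dict.

    The fallbacks are assumptions for local runs, not SageMaker defaults.
    """
    env = os.environ if env is None else env
    prefix = "SM_CHANNEL_"
    return {
        "num_gpus": int(env.get("SM_NUM_GPUS", 0)),
        "model_dir": env.get("SM_MODEL_DIR", "./model"),
        "output_data_dir": env.get("SM_OUTPUT_DATA_DIR", "./output"),
        # One entry per input channel, e.g. SM_CHANNEL_TRAIN -> "train".
        "channels": {
            name[len(prefix):].lower(): path
            for name, path in env.items()
            if name.startswith(prefix)
        },
    }


if __name__ == "__main__":
    print(read_sagemaker_env())
```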

## Example script
```python
import argparse
import os

if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--learning-rate', type=float, default=0.05)
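    # argparse's type=bool treats any non-empty string (even 'False') as True, so parse the text explicitly.
    parser.add_argument('--use-cuda', type=lambda s: str(s).lower() in ('true', '1'), default=False)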

    # Data, model, and output directories
    parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
    parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAIN'])
    parser.add_argument('--test', type=str, default=os.environ['SM_CHANNEL_TEST'])

    args, _ = parser.parse_known_args()

    # ... load from args.train and args.test, train a model, write model to args.model_dir.
```
For training, SageMaker simply executes the provided script as main. However, if you use the same script for deployment, put your training code inside a main guard (`if __name__ == '__main__':`), because SageMaker imports the script when hosting the model and would otherwise run the training code at import time.

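Note that SageMaker doesn't support argparse actions such as `store_true`. Every hyperparameter is passed to the script as an explicit `--name value` pair, so a boolean hyperparameter must be given an explicit `True` or `False` value. Paths and resource information are supplied through [Environment Variables](https://github.com/aws/sagemaker-training-toolkit/blob/master/ENVIRONMENT_VARIABLES.md).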
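
Boolean hyperparameters are a common stumbling block here. The value arrives as a string (for example `--use-cuda False`), and Python's `bool('False')` is `True`, so a plain `type=bool` would not behave as expected. One possible workaround, using a hypothetical helper named `str2bool`:

```python
import argparse


def str2bool(value):
    """Convert 'True'/'False'-style strings into a real bool."""
    text = str(value).strip().lower()
    if text in ("true", "1", "yes"):
        return True
    if text in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean, got {!r}".format(value))


parser = argparse.ArgumentParser()
parser.add_argument("--use-cuda", type=str2bool, default=False)

# Simulates the arguments the script would receive for {'use-cuda': False}.
args, _ = parser.parse_known_args(["--use-cuda", "False"])
print(args.use_cuda)
```

Raising `argparse.ArgumentTypeError` makes an unrecognised value fail loudly at start-up rather than silently falling back to a default.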

If there are other packages you want to use with your script, you can include a `requirements.txt` file in the same directory as your training script to install other dependencies at runtime.

## Create PyTorch Model

```python
# https://github.com/awslabs/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/pytorch_mnist/mnist.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
#...

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()  # dropout layer referenced in forward()
        # ...
    
    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # ...
        
```
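
The size of the tensor leaving the two convolution blocks determines the input width of whatever layer follows in the elided part. Here is a short worked calculation, assuming 28×28 single-channel MNIST images and the default stride and padding of `nn.Conv2d` (stride 1, no padding) and `F.max_pool2d` (stride equal to the kernel size). The helper names are hypothetical.

```python
def conv_out(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution or pooling window."""
    return (size + 2 * padding - kernel_size) // stride + 1


def flattened_features(image_size=28, out_channels=20):
    """Trace one spatial dimension through conv1 -> pool -> conv2 -> pool."""
    size = conv_out(image_size, kernel_size=5)       # conv1: 28 -> 24
    size = conv_out(size, kernel_size=2, stride=2)   # max_pool2d(2): 24 -> 12
    size = conv_out(size, kernel_size=5)             # conv2: 12 -> 8
    size = conv_out(size, kernel_size=2, stride=2)   # max_pool2d(2): 8 -> 4
    return out_channels * size * size                # 20 * 4 * 4


print(flattened_features())  # 320
```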

## Create Data loaders

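Create two private functions for loading the training and test data, e.g. `_get_train_data_loader` and `_get_test_data_loader`. The snippet assumes `logger` is a module-level `logging` logger defined elsewhere in the script.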

```python
def _get_train_data_loader(batch_size, training_dir, is_distributed, **kwargs):
    logger.info("Get train data loader")
    dataset = datasets.MNIST(training_dir, train=True, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ]))
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=train_sampler is None,
                                       sampler=train_sampler, **kwargs)
```

## Define train function
```python
def train(args):
    is_distributed = len(args.hosts) > 1 and args.backend is not None
    logger.debug("Distributed training - {}".format(is_distributed))
    use_cuda = args.num_gpus > 0
    device = torch.device("cuda" if use_cuda else "cpu")
    # ...
    
    train_loader = _get_train_data_loader(args.batch_size, args.data_dir, is_distributed, **kwargs)
    model = Net().to(device)
    # ...
    
    for epoch in range(1, args.epochs + 1):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader, 1):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            if is_distributed and not use_cuda:
                # average gradients manually for multi-machine cpu case only
                _average_gradients(model)
            optimizer.step()
            if batch_idx % args.log_interval == 0:
                logger.info('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.sampler),
                    100. * batch_idx / len(train_loader), loss.item()))
        test(model, test_loader, device)
    save_model(model, args.model_dir)
    
def test(model, test_loader, device):
    # ...
    
def model_fn(model_dir):
    '''
    Before a model can be served, it must be loaded.
    The SageMaker PyTorch model server loads your model by invoking a model_fn function
    '''
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.DataParallel(Net())
    with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:
        model.load_state_dict(torch.load(f))
    return model.to(device)


def save_model(model, model_dir):
    logger.info("Saving the model.")
    path = os.path.join(model_dir, 'model.pth')
    # recommended way from http://pytorch.org/docs/master/notes/serialization.html
    torch.save(model.cpu().state_dict(), path)
```
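
The body of `test` is elided above. A minimal sketch of what such an evaluation loop could look like, assuming the model returns log-probabilities (matching the `F.nll_loss` call in `train`). The name `evaluate` and the returned tuple are illustrative choices, not taken from the original script.

```python
import torch
import torch.nn.functional as F


def evaluate(model, data_loader, device):
    """Return (average NLL loss, accuracy) over every sample in data_loader."""
    model.eval()  # disable dropout and other training-only behaviour
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Sum per-sample losses so the final average is over samples.
            total_loss += F.nll_loss(output, target, reduction="sum").item()
            correct += (output.argmax(dim=1) == target).sum().item()
            seen += target.size(0)
    return total_loss / seen, correct / seen
```

Summing the loss and dividing once at the end keeps the average correct even when the last batch is smaller than the others.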

## Prediction Functions (Deployment)

* `model_fn(model_dir)` - loads your model.
* `input_fn(serialized_input_data, content_type)` - deserializes the request data into input for `predict_fn`.
* `predict_fn(input_data, model)` - calls the model on the data deserialized by `input_fn`.
* `output_fn(prediction_output, accept)` - serializes the predictions returned by `predict_fn`.
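
How these functions chain together can be sketched as follows, assuming JSON requests and responses. The JSON handling, the `application/json` checks, and the driver `handle_request` are illustrative assumptions; in a real deployment the SageMaker PyTorch model server invokes the hooks itself.

```python
import json

import torch


def input_fn(serialized_input_data, content_type):
    """Deserialize the request body into a tensor for predict_fn."""
    if content_type != "application/json":
        raise ValueError("Unsupported content type: {}".format(content_type))
    return torch.tensor(json.loads(serialized_input_data), dtype=torch.float32)


def predict_fn(input_data, model):
    """Run the model returned by model_fn on the deserialized input."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    with torch.no_grad():
        return model(input_data.to(device))


def output_fn(prediction_output, accept):
    """Serialize the prediction into the response body."""
    if accept != "application/json":
        raise ValueError("Unsupported accept type: {}".format(accept))
    return json.dumps(prediction_output.cpu().tolist())


def handle_request(model, body, content_type="application/json",
                   accept="application/json"):
    """Hypothetical driver showing the order in which the hooks run."""
    data = input_fn(body, content_type)
    prediction = predict_fn(data, model)
    return output_fn(prediction, accept)
```

For a model loaded with `model_fn`, `handle_request(model, "[[...]]")` would return a JSON string of predictions, which makes the hooks easy to exercise locally before deploying.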