# MNIST Training using PyTorch

MNIST is a widely used dataset for handwritten digit classification. 

- It consists of 70,000 labeled 28x28 pixel grayscale images of hand-written digits.
- The dataset is split into 60,000 training images and 10,000 test images. 
- There are 10 classes (one for each of the 10 digits). 
- This tutorial will show how to train and test an MNIST model on SageMaker using PyTorch.




## Setup

Let's start by creating a SageMaker session and specifying:

- The S3 bucket and prefix that you want to use for training and model data. This should be within the same region as the Notebook Instance, training, and hosting.
- The IAM role arn used to give training and hosting access to your data. See the documentation for how to create these. Note, if more than one role is required for notebook instances, training, and/or hosting, please replace the sagemaker.get_execution_role() with a the appropriate full IAM role arn string(s).

In [1]:
import sagemaker
from sagemaker.local import LocalSession

sagemaker_session = sagemaker.Session()

bucket = sagemaker_session.default_bucket()
prefix = "sagemaker/Demo-pytorch-mnist"
role = sagemaker.get_execution_role()

In [2]:
role

'arn:aws:iam::964564632268:role/service-role/AmazonSageMaker-ExecutionRole-20220129T105728'

In [3]:
pip install torchvision==0.5.0 --no-cache-dir

Collecting torchvision==0.5.0
  Downloading torchvision-0.5.0-cp36-cp36m-manylinux1_x86_64.whl (4.0 MB)
     |████████████████████████████████| 4.0 MB 21.0 MB/s            
[?25hCollecting torch==1.4.0
  Downloading torch-1.4.0-cp36-cp36m-manylinux1_x86_64.whl (753.4 MB)
     |████████████████████████████████| 753.4 MB 91.3 MB/s            
Installing collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 1.7.1
    Uninstalling torch-1.7.1:
      Successfully uninstalled torch-1.7.1
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.8.2
    Uninstalling torchvision-0.8.2:
      Successfully uninstalled torchvision-0.8.2
Successfully installed torch-1.4.0 torchvision-0.5.0
Note: you may need to restart the kernel to use updated packages.


## Data

In [7]:
from torchvision import datasets, transforms

datasets.MNIST(
    "data",
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(), 
            transforms.Normalize((0.1307,), (0.3081))
        ]
    )
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


0it [00:00, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


0it [00:00, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
Processing...
Done!


Dataset MNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.1307,), std=0.3081)
           )

## Uploading the data to S3


In [9]:
inputs = sagemaker_session.upload_data(path="data", bucket=bucket, key_prefix=prefix)
print("input spec (in this case just an s3 path): {}".format(inputs))

input spec (in this case just an s3 path): s3://sagemaker-ap-southeast-1-964564632268/sagemaker/Demo-pytorch-mnist


## Train

The mnist.py script provides all the code we need for training and hosting a SageMaker model (model_fn function to load a model). The training script is very similar to a training script you might run outside of SageMaker, but you can access useful properties about the training environment through various environment variables

In [11]:
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point = "mnist.py",
    role = role,
    framework_version = "1.4.0",
    py_version = "py3",
    instance_count = 2,
    instance_type="ml.c4.xlarge",
    hyperparameters={
        "epochs": 6,
        "backend": "gloo"
    },
)

In [12]:
estimator.fit({"training": inputs})

2022-01-29 07:04:18 Starting - Starting the training job...
2022-01-29 07:04:45 Starting - Launching requested ML instancesProfilerReport-1643439857: InProgress
......
2022-01-29 07:05:46 Starting - Preparing the instances for training............
2022-01-29 07:07:46 Downloading - Downloading input data
2022-01-29 07:07:46 Training - Downloading the training image..[34mbash: cannot set terminal process group (-1): Inappropriate ioctl for device[0m
[34mbash: no job control in this shell[0m
[34m2022-01-29 07:08:03,500 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training[0m
[34m2022-01-29 07:08:03,503 sagemaker-containers INFO     No GPUs detected (normal if no gpus installed)[0m
[34m2022-01-29 07:08:03,515 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.[0m

2022-01-29 07:08:08 Training - Training image download completed. Training in progress.[34mbash: cannot set terminal process group (-1): Inappro

In [13]:
estimator.model_data

's3://sagemaker-ap-southeast-1-964564632268/pytorch-training-2022-01-29-07-04-17-606/output/model.tar.gz'

## Host/Deploy

In [14]:
predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type="ml.t2.medium"
)

--------!

## Evaluate

In [15]:
!ls data/MNIST/raw

t10k-images-idx3-ubyte	   train-images-idx3-ubyte
t10k-images-idx3-ubyte.gz  train-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte	   train-labels-idx1-ubyte
t10k-labels-idx1-ubyte.gz  train-labels-idx1-ubyte.gz


In [17]:
import gzip
import numpy as np
import random
import os

data_dir = "data/MNIST/raw"
with gzip.open(os.path.join(data_dir, "t10k-images-idx3-ubyte.gz"), "rb") as f: ###### complete
    images = np.frombuffer(f.read(), np.int8, offset=16).reshape(-1, 28, 28).astype(np.float32)
    
mask = random.sample(range(len(images)), 16) # randomly select some of the test images
mask = np.array(mask, dtype=np.int)
data = images[mask]


In [18]:
response = predictor.predict(np.expand_dims(data, axis=1))
print("Raw prediction result: ", response)
print()

labeled_predictions = list(zip(range(10), response[0]))
print("Labeled_predictions: ", labeled_predictions, "\n")

labeled_predictions.sort(key=lambda label_and_prod: 1.0 - label_and_prod[1]) ###### complete
print("MOST LIKELY ANSWER: {}".format(labeled_predictions[0]))

Raw prediction result:  [[-1.08076992e+01 -1.62589340e+01 -6.55671453e+00 -4.61717653e+00
  -2.85745335e+01 -2.29228884e-02 -9.13732338e+00 -3.43735771e+01
  -4.48889112e+00 -2.68527851e+01]
 [-4.49020233e+01 -7.20633087e+01 -6.75494385e+01 -3.75043221e+01
  -5.59003677e+01 -2.50339190e-06 -3.48041763e+01 -8.15338516e+01
  -1.29081831e+01 -4.37730446e+01]
 [-3.78774300e+01 -6.53409348e+01 -1.22470369e+01 -3.42104607e+01
  -6.46161957e+01 -3.98908348e+01 -4.06050339e+01 -4.87020264e+01
  -4.76836021e-06 -4.89162750e+01]
 [-2.22264385e+01 -1.16533709e+01 -8.00051028e-04 -7.14700031e+00
  -2.21742916e+01 -1.35864429e+01 -1.28857899e+01 -2.48312817e+01
  -1.77316551e+01 -3.32350693e+01]
 [-7.86819763e+01 -8.57515640e+01 -1.66698657e-02 -4.10247755e+00
  -1.86761063e+02 -8.99968033e+01 -1.36304382e+02 -6.85258865e+01
  -5.02308617e+01 -1.28020050e+02]
 [-5.58570976e+01 -8.46266556e+01  0.00000000e+00 -5.65192451e+01
  -5.62005386e+01 -8.66798401e+01 -6.22382812e+01 -6.45247574e+01
  -2.6474

In [20]:
from IPython.display import HTML
HTML(open("input.html").read())

In [25]:
import numpy as np
image = np.array([data], dtype=np.float32)
response = predictor.predict(image)
prediction = response.argmax(axis=1)[0]
print(prediction)

1


## Cleanup

In [26]:
predictor.delete_endpoint() # to save cost by deleting endpoint

## REFERENCES

1. GitHub Repository: <https://github.com/aws/amazon-sagemaker-examples/tree/master/sagemaker-python-sdk/pytorch_mnist>
2. Notebook Link: https://github.com/aws/amazon-sagemaker-examples/blob/master/sagemaker-python-sdk/pytorch_mnist/pytorch_mnist.ipynb