# Use SageMaker Batch Transform for PyTorch Batch Inference
In this notebook, we walk through running a Batch Transform job with PyTorch in Amazon SageMaker.

First, we train an image classification model on the MNIST dataset. Then, we demonstrate batch transform using the PyTorch framework classes of the SageMaker Python SDK with different configurations:
- `data_type=S3Prefix`: uses all objects that match the specified S3 prefix for batch inference.
- `data_type=ManifestFile`: a manifest file contains a list of object keys to use in batch inference.
- `instance_count>1`: distributes the batch inference dataset to multiple inference instances.

For batch transform in TensorFlow in Amazon SageMaker, you can follow other Jupyter notebooks in the [sagemaker_batch_transform](https://github.com/awslabs/amazon-sagemaker-examples/tree/master/sagemaker_batch_transform) directory.

### Runtime

This notebook takes approximately 15 minutes to run.

### Contents

1. [Setup](#Setup)
1. [Model training](#Model-training)
1. [Prepare batch inference data](#Prepare-batch-inference-data)
1. [Create model transformer](#Create-model-transformer)
1. [Batch inference](#Batch-inference)
1. [Look at all transform jobs](#Look-at-all-transform-jobs)
1. [Conclusion](#Conclusion)

## Setup
We'll begin with some necessary installs and imports, and get an Amazon SageMaker session to help perform certain tasks, as well as an IAM role with the necessary permissions.

In [2]:
!pip install nvidia-ml-py3
!yes | pip uninstall torchvision
!pip install torchvision

Found existing installation: torchvision 0.5.0+cpu
Uninstalling torchvision-0.5.0+cpu:
  Would remove:
    /opt/conda/lib/python3.6/site-packages/torchvision-0.5.0+cpu.dist-info/*
    /opt/conda/lib/python3.6/site-packages/torchvision/*
Proceed (y/n)?   Successfully uninstalled torchvision-0.5.0+cpu
yes: standard output: Broken pipe
Collecting torchvision
  Downloading torchvision-0.11.2-cp36-cp36m-manylinux1_x86_64.whl (23.3 MB)
Collecting torch==1.10.1
  Downloading torch-1.10.1-cp36-cp36m-manylinux1_x86_64.whl (881.9 MB)
Installing collected packages: torch, torchvision
  Attempting uninstall: torch
    Found existing installation: torch 1.4.0
    Uninstalling torch-1.4.0:
      Successfully uninstalled torch-1.4.0
Successfully installed torch-1.10.1 torchvision-0.11

In [3]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
from os import listdir
from os.path import isfile, join
from shutil import copyfile
import sagemaker
from sagemaker.pytorch import PyTorchModel
from sagemaker import get_execution_role

sagemaker_session = sagemaker.Session()
role = get_execution_role()

bucket = sagemaker_session.default_bucket()
prefix = "sagemaker/DEMO-pytorch-batch-inference-script"
print("Bucket: {}".format(bucket))

Bucket: sagemaker-us-west-2-521695447989


## Model training

Since the main purpose of this notebook is to demonstrate SageMaker PyTorch batch transform, we reuse a SageMaker Python SDK [PyTorch MNIST example](https://github.com/awslabs/amazon-sagemaker-examples/tree/master/sagemaker-python-sdk/pytorch_mnist) to train a PyTorch model. It takes around 7 minutes to finish the training.

In [4]:
from torchvision.datasets import MNIST
from torchvision import transforms

local_dir = "data"
MNIST.mirrors = ["https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/"]
MNIST(
    local_dir,
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    ),
)


inputs = sagemaker_session.upload_data(path=local_dir, bucket=bucket, key_prefix=prefix)
print("input spec (in this case, just an S3 path): {}".format(inputs))

from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point="model-script/mnist.py",
    role=role,
    framework_version="1.8.0",
    py_version="py3",
    instance_count=3,
    instance_type="ml.c5.2xlarge",
    hyperparameters={
        "epochs": 1,
        "backend": "gloo",
    },  # set epochs to a more realistic number for real training
)

estimator.fit({"training": inputs})

Downloading https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/train-images-idx3-ubyte.gz
Downloading https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/train-labels-idx1-ubyte.gz
Downloading https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/t10k-images-idx3-ubyte.gz
Downloading https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/t10k-labels-idx1-ubyte.gz
Downloading https://sagemaker-sample-files.s3.amazonaws.com/datasets/image/MNIST/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

input spec (in this case, just an S3 path): s3://sagemaker-us-west-2-000000000000/sagemaker/DEMO-pytorch-batch-inference-script
2022-04-18 00:17:13 Starting - Starting the training job...
2022-04-18 00:17:33 Starting - Preparing the instances for trainingProfilerReport-1650241033: InProgress
.........
2022-04-18 00:19:13 Downloading - Downloading input data...
...
INFO:__main__:Test set: Average loss: 0.4341, Accuracy: 8826/10000 (88%)
2022-04-18 00:20:14,348 sagemaker-training-toolkit INFO     Reporting training SUCCESS

2022-04-18 00:20:39 Uploading - Uploading generated training model
2022-04-18 00:20:39 Completed - Training job completed
Training seconds: 216
Billable seconds: 216


## Prepare batch inference data

Convert the test data into PNG image format.

In [5]:
!ls data/MNIST/raw

t10k-images-idx3-ubyte	   train-images-idx3-ubyte
t10k-images-idx3-ubyte.gz  train-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte	   train-labels-idx1-ubyte
t10k-labels-idx1-ubyte.gz  train-labels-idx1-ubyte.gz


In [6]:
# Decompress the gzip archive and load the raw test images into a NumPy array

import gzip
import numpy as np
import os

with gzip.open(os.path.join(local_dir, "MNIST/raw", "t10k-images-idx3-ubyte.gz"), "rb") as f:
    images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)

In [7]:
print(len(images), "test images")

10000 test images
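
The `offset=16` used when reading the archive skips the 16-byte header of the IDX image format: four big-endian 32-bit integers holding the magic number (2051 for image files), the image count, and the row and column sizes. The following is a minimal sketch of reading that header. The helper name `read_idx_image_header` and the synthetic two-image payload are assumptions made for illustration, so the block runs without the MNIST download.

```python
import gzip
import io
import struct


def read_idx_image_header(raw: bytes):
    """Return (magic, count, rows, cols) from the first 16 bytes of an IDX image file."""
    # The header is four big-endian unsigned 32-bit integers.
    magic, count, rows, cols = struct.unpack(">IIII", raw[:16])
    if magic != 2051:  # 2051 identifies an IDX file of unsigned bytes with 3 dimensions
        raise ValueError(f"unexpected magic number: {magic}")
    return magic, count, rows, cols


# Synthetic stand-in for t10k-images-idx3-ubyte.gz: a header plus two blank 28x28 images.
payload = struct.pack(">IIII", 2051, 2, 28, 28) + bytes(2 * 28 * 28)
buffer = io.BytesIO()
with gzip.GzipFile(fileobj=buffer, mode="wb") as gz:
    gz.write(payload)

raw = gzip.decompress(buffer.getvalue())
header = read_idx_image_header(raw)
print(header)
```

Checking the header before reshaping is a cheap way to confirm that the pixel data really is `count * rows * cols` bytes long.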


Randomly sample 100 test images and upload them to S3.

In [8]:
import random
from PIL import Image as im

ids = random.sample(range(len(images)), 100)
ids = np.array(ids, dtype=int)  # the np.int alias was removed in NumPy 1.24
selected_images = images[ids]

image_dir = "data/images"

if not os.path.exists(image_dir):
    os.makedirs(image_dir)

for i, img in enumerate(selected_images):
    pngimg = im.fromarray(img)
    pngimg.save(os.path.join(image_dir, f"{i}.png"))
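
Because the PNG files are renumbered from `0.png` upward, the link between a file and its position in the test set is lost unless it is recorded. The following is a small sketch of one way to keep that link, assuming a fixed seed is acceptable for the demo. The names `rng`, `sample_ids`, and `name_to_test_index` are hypothetical and not part of the notebook.

```python
import random

# A dedicated, seeded generator draws the same 100 indices on every run.
rng = random.Random(0)
sample_ids = rng.sample(range(10000), 100)

# File name -> index in the original test set, so predictions can be
# matched against the ground-truth labels later.
name_to_test_index = {f"{i}.png": test_index for i, test_index in enumerate(sample_ids)}

print(len(name_to_test_index), "files mapped")
```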

In [9]:
inference_prefix = "batch_transform"
inference_inputs = sagemaker_session.upload_data(
    path=image_dir, bucket=bucket, key_prefix=inference_prefix
)
print("Input S3 path: {}".format(inference_inputs))

Input S3 path: s3://sagemaker-us-west-2-000000000000/batch_transform


## Create model transformer
Now, we create a transformer object for creating and interacting with Amazon SageMaker transform jobs. We can create the transformer in two ways:
1. Use a fitted estimator directly.
1. First create a PyTorchModel from a saved model artifact, and then create a transformer from the PyTorchModel object.


Here, we implement the `model_fn`, `input_fn`, `predict_fn`, and `output_fn` functions in the entry-point script to override the default [PyTorch inference handler](https://github.com/aws/sagemaker-pytorch-inference-toolkit/blob/master/src/sagemaker_pytorch_serving_container/default_inference_handler.py).

In the `input_fn()` function, each image to classify arrives as a Python bytearray. That's why we use the `load_from_bytearray()` function to wrap the bytes in `io.BytesIO` and then read the image with `PIL.Image`.

```python
import io
import json
import os

import torch
from PIL import Image
from torchvision.transforms import ToTensor


# Load the trained weights; `Net` is the model class defined elsewhere in the entry-point script
def model_fn(model_dir):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.DataParallel(Net())
    with open(os.path.join(model_dir, "model.pth"), "rb") as f:
        model.load_state_dict(torch.load(f, map_location=device))
    return model.to(device)


# Decode raw image bytes into a tensor with a leading batch dimension
def load_from_bytearray(request_body):
    image_as_bytes = io.BytesIO(request_body)
    image = Image.open(image_as_bytes)
    image_tensor = ToTensor()(image).unsqueeze(0)
    return image_tensor


# Deserialize the request body into a tensor
def input_fn(request_body, request_content_type):
    # The body also arrives as a Python bytearray for other content types,
    # such as "image/jpg" or "application/x-npy"; only "application/x-image" is handled here
    if request_content_type == "application/x-image":
        image_tensor = load_from_bytearray(request_body)
    else:
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return image_tensor


# Perform prediction on the deserialized object, with the loaded model
def predict_fn(input_object, model):
    model.eval()  # switch layers such as dropout to inference mode
    with torch.no_grad():
        output = model(input_object)
    pred = output.max(1, keepdim=True)[1]

    return {"predictions": pred.item()}


# Serialize the prediction result into the desired response content type
def output_fn(predictions, response_content_type):
    return json.dumps(predictions)
```
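
To make the division of labor between the four functions above concrete, here is a simplified picture of the order in which they are called: `model_fn` once when a worker starts, then `input_fn`, `predict_fn`, and `output_fn` for each request. This is a sketch only, not the inference toolkit's actual code. The stub model (which returns the payload length modulo 10) and the `handle_request` helper are assumptions chosen so the block runs without PyTorch.

```python
import json


def model_fn(model_dir):
    # Stub model: "predicts" the payload length modulo 10 instead of running a network.
    return lambda payload: len(payload) % 10


def input_fn(request_body, request_content_type):
    if request_content_type != "application/x-image":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return bytes(request_body)


def predict_fn(input_object, model):
    return {"predictions": model(input_object)}


def output_fn(predictions, response_content_type):
    return json.dumps(predictions)


def handle_request(model, body, content_type, accept="application/json"):
    """Chain the three per-request hooks in the order described above."""
    return output_fn(predict_fn(input_fn(body, content_type), model), accept)


model = model_fn("unused-model-dir")  # loaded once, reused for every request
response = handle_request(model, bytearray(b"\x89PNG-demo"), "application/x-image")
print(response)
```

Keeping deserialization, prediction, and serialization in separate hooks means a new content type only requires a change to `input_fn`.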

In [10]:
# Use fitted estimator directly
transformer = estimator.transformer(instance_count=1, instance_type="ml.c5.xlarge")

In [11]:
# You can also create a Transformer object from saved model artifact

# Get model artifact location by estimator.model_data, or give an S3 key directly
model_artifact_s3_location = estimator.model_data  # "s3://<BUCKET>/<PREFIX>/model.tar.gz"

# Create PyTorchModel from saved model artifact
pytorch_model = PyTorchModel(
    model_data=model_artifact_s3_location,
    role=role,
    framework_version="1.8.0",
    py_version="py3",
    source_dir="model-script/",
    entry_point="mnist.py",
)

# Create transformer from PyTorchModel object
transformer = pytorch_model.transformer(instance_count=1, instance_type="ml.c5.xlarge")

## Batch inference
Next, we run batch inference on the 100 sampled MNIST images.

### Input images directly from S3 location
We set `data_type="S3Prefix"` (the `S3DataType` field of the underlying API) to use all objects that match the specified S3 prefix for batch inference.

In [12]:
transformer.transform(
    data=inference_inputs,
    data_type="S3Prefix",
    content_type="application/x-image",
    wait=True,
)

......................2022-04-18 00:25:31,379 [INFO ] main org.pytorch.serve.ModelServer -
Torchserve version: 0.3.0
TS Home: /opt/conda/lib/python3.6/site-packages
Current directory: /
Temp directory: /home/model-server/tmp
Number of GPUs: 0
Number of CPUs: 4
Max heap size: 948 M
Python executable: /opt/conda/bin/python3.6
Config file: /etc/sagemaker-ts.properties
Inference address: http://0.0.0.0:8080
Management address: http://0.0.0.0:8080
Metrics address: http://127.0.0.1:8082
Model Store: /.sagemaker/ts/models
Initial Models: model.mar
Log dir: /logs
Metrics dir: /logs
Netty threads: 0
Netty client threads: 0
Default workers per model: 4
Blacklist Regex: N/A
Maximum Response Size: 6553500
Maximum Request Size: 6553500
Prefer direct buffer: false
Allowed Urls: [fi

### Input images by manifest file
First, we generate a manifest file. Then we pass the manifest file, which lists the object keys, as the input to batch inference. Some key points:
- `content_type = "application/x-image"` (here, `content_type` describes the objects sent for inference, not the manifest file itself)
- `data_type = "ManifestFile"`
- The manifest file must follow the format described in [S3DataSource](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html#SageMaker-Type-S3DataSource-S3DataType). We create the manifest file with the jsonlines package.
``` json
[
    {"prefix": "s3://customer_bucket/some/prefix/"},
    "relative/path/to/custdata-1",
    "relative/path/custdata-2",
    ...
    "relative/path/custdata-N"
]
```
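
Since the manifest is a single JSON array, it can also be produced with the standard-library `json` module. The following is a minimal sketch of that alternative. The helper `build_manifest`, the bucket name, and the keys are hypothetical and used only for illustration.

```python
import json


def build_manifest(s3_prefix, relative_keys):
    """Return the manifest as a JSON array string: a prefix entry followed by relative keys."""
    if not s3_prefix.startswith("s3://"):
        raise ValueError("prefix must be an S3 URI")
    # Ensure exactly one trailing slash; each key is given relative to the prefix.
    prefix = s3_prefix.rstrip("/") + "/"
    return json.dumps([{"prefix": prefix}] + list(relative_keys))


# Hypothetical bucket and keys, for illustration only.
manifest_text = build_manifest("s3://example-bucket/batch_transform", ["0.png", "1.png"])
print(manifest_text)
```

Normalizing the trailing slash in one place keeps the prefix and the relative keys from being joined incorrectly.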

In [13]:
!pip install -q jsonlines

In [14]:
import jsonlines

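# Build image list
# The sampled images were uploaded to `inference_inputs` (s3://<bucket>/batch_transform),
# so the manifest prefix must point to that location
manifest_prefix = f"{inference_inputs}/"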

path = image_dir
img_files = [f for f in listdir(path) if isfile(join(path, f))]

print("img_files\n", img_files)

manifest_content = [{"prefix": manifest_prefix}]
manifest_content.extend(img_files)

print("manifest_content\n", manifest_content)

# Write the manifest as a single JSON array on one line
manifest_file = "manifest.json"
with jsonlines.open(manifest_file, mode="w") as writer:
    writer.write(manifest_content)

# Upload to S3
manifest_obj = sagemaker_session.upload_data(path=manifest_file, key_prefix=prefix)

print("manifest_obj\n", manifest_obj)

img_files
 ['2.png', '43.png', '46.png', '45.png', '16.png', '4.png', '92.png', '12.png', '18.png', '75.png', '87.png', '63.png', '81.png', '78.png', '13.png', '34.png', '42.png', '19.png', '73.png', '14.png', '83.png', '3.png', '55.png', '5.png', '38.png', '86.png', '27.png', '7.png', '15.png', '80.png', '89.png', '72.png', '88.png', '99.png', '82.png', '32.png', '70.png', '36.png', '67.png', '26.png', '94.png', '41.png', '30.png', '64.png', '28.png', '59.png', '52.png', '90.png', '69.png', '31.png', '37.png', '47.png', '40.png', '56.png', '58.png', '1.png', '93.png', '61.png', '29.png', '76.png', '23.png', '50.png', '97.png', '79.png', '85.png', '6.png', '57.png', '35.png', '11.png', '22.png', '62.png', '21.png', '95.png', '71.png', '60.png', '17.png', '77.png', '25.png', '49.png', '96.png', '53.png', '54.png', '48.png', '51.png', '74.png', '20.png', '0.png', '91.png', '33.png', '24.png', '65.png', '66.png', '10.png', '68.png', '39.png', '9.png', '98.png', '84.png', '8.png', '44.png'

In [15]:
# Batch transform with manifest file
transform_job = transformer.transform(
    data=manifest_obj,
    data_type="ManifestFile",
    content_type="application/x-image",
    wait=False,
)

In [16]:
print("Latest transform job:", transformer.latest_transform_job.name)

Latest transform job: pytorch-inference-2022-04-18-00-26-08-985


In [17]:
# look at the status of the transform job
import pprint as pp

sm_cli = sagemaker_session.sagemaker_client

job_info = sm_cli.describe_transform_job(TransformJobName=transformer.latest_transform_job.name)

pp.pprint(job_info)

{'CreationTime': datetime.datetime(2022, 4, 18, 0, 26, 9, 32000, tzinfo=tzlocal()),
 'DataProcessing': {'InputFilter': '$',
                    'JoinSource': 'None',
                    'OutputFilter': '$'},
 'ModelName': 'pytorch-inference-2022-04-18-00-21-14-564',
 'ResponseMetadata': {'HTTPHeaders': {'content-length': '870',
                                      'content-type': 'application/x-amz-json-1.1',
                                      'date': 'Mon, 18 Apr 2022 00:26:09 GMT',
                                      'x-amzn-requestid': '394fa75d-5814-41e3-a604-2a471c52c745'},
                      'HTTPStatusCode': 200,
                      'RequestId': '394fa75d-5814-41e3-a604-2a471c52c745',
                      'RetryAttempts': 0},
 'TransformInput': {'CompressionType': 'None',
                    'ContentType': 'application/x-image',
                    'DataSource': {'S3DataSource': {'S3DataType': 'ManifestFile',
                                                    'S3Uri
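
Because the manifest job was started with `wait=False`, its status has to be checked explicitly. The following is a minimal polling sketch built around the `TransformJobStatus` field returned by `describe_transform_job`. The helper `wait_for_transform_job` and the fake client are assumptions that let the block run without AWS credentials.

```python
import time

TERMINAL_STATUSES = {"Completed", "Failed", "Stopped"}


def wait_for_transform_job(describe, job_name, poll_seconds=30, max_polls=120):
    """Poll `describe` until the job reaches a terminal status, then return that status.

    `describe` is any callable with the keyword signature of
    `sagemaker_client.describe_transform_job`.
    """
    for _ in range(max_polls):
        status = describe(TransformJobName=job_name)["TransformJobStatus"]
        if status in TERMINAL_STATUSES:
            return status
        time.sleep(poll_seconds)
    raise TimeoutError(f"{job_name} did not finish after {max_polls} polls")


# Stand-in for the real client: reports InProgress twice, then Completed.
_fake_responses = iter(["InProgress", "InProgress", "Completed"])


def fake_describe(TransformJobName):
    return {"TransformJobStatus": next(_fake_responses)}


final_status = wait_for_transform_job(fake_describe, "example-job", poll_seconds=0)
print(final_status)
```

In the notebook, the same helper could be called with `sm_cli.describe_transform_job` and `transformer.latest_transform_job.name`. Calling `transformer.wait()` is a simpler option when blocking is acceptable.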

### Multiple instances
We use `instance_count > 1` to create multiple inference instances. When a batch transform job starts, Amazon SageMaker initializes the compute instances and distributes the inference or preprocessing workload between them. Batch Transform partitions the Amazon S3 objects in the input by key and maps each object to an instance. Given multiple files, one instance might process `input1.csv` while another processes `input2.csv`. Read more at [Use Batch Transform](https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform.html).
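
As a rough mental model of that partitioning, the toy function below splits object keys across instances so that each object is handled by exactly one instance. The round-robin policy is an assumption made purely for illustration, since the text above does not describe SageMaker's actual assignment strategy. `assign_objects` is a hypothetical helper.

```python
def assign_objects(keys, instance_count):
    """Toy round-robin split of object keys across instances (illustrative only)."""
    buckets = [[] for _ in range(instance_count)]
    for index, key in enumerate(sorted(keys)):
        buckets[index % instance_count].append(key)
    return buckets


keys = [f"batch_transform/{i}.png" for i in range(6)]
assignment = assign_objects(keys, instance_count=2)
for instance_id, assigned in enumerate(assignment):
    print(instance_id, assigned)
```

Whatever the real policy is, the unit of distribution is the S3 object, so a job with a single input file cannot benefit from additional instances.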

In [18]:
dist_transformer = estimator.transformer(instance_count=2, instance_type="ml.c4.xlarge")

dist_transformer.transform(
    data=inference_inputs,
    data_type="S3Prefix",
    content_type="application/x-image",
    wait=True,
)

................................2022-04-18 00:31:23,943 [INFO ] main org.pytorch.serve.ModelServer -
Torchserve version: 0.3.0
TS Home: /opt/conda/lib/python3.6/site-packages
Current directory: /
Temp directory: /home/model-server/tmp
Number of GPUs: 0
Number of CPUs: 4
Max heap size: 910 M
Python executable: /opt/conda/bin/python3.6
Config file: /etc/sagemaker-ts.properties
Inference address: http://0.0.0.0:8080
Management address: http://0.0.0.0:8080
Metrics address: http://127.0.0.1:8082
Model Store: /.sagemaker/ts/models
Initial Models: model.mar
Log dir: /logs
Metrics dir: /logs
Netty threads: 0
Netty client threads: 0
Default workers per model: 4
Blacklist Regex: N/A
Maximum Response Size: 6553500
Maximum Request Size: 6553500
Prefer direct buffer: false
Allowed


2022-04-18 00:31:35,112 [INFO ] pool-1-thread-5 ACCESS_LOG - /169.254.255.130:44472 "GET /ping HTTP/1.1" 200 18
2022-04-18 00:31:35,112 [INFO ] pool-1-thread-5 TS_METRICS - Requests2XX.Count:1|#Level:Host|#hostname:aeaba6005455,timestamp:null
2022-04-18 00:31:35,140 [INFO ] epollEventLoopGroup-3-2 ACCESS_LOG - /169.254.255.130:44484 "GET /execution-parameters HTTP/1.1" 404 1
2022-04-18 00:31:35,141 [INFO ] epollEventLoopGroup-3-2 TS_METRICS - Requests4XX.Count:1|#Level:Host|#hostname:aeaba6005455,timestamp:null
2022-04-18 00:31:35,329 [INFO ] W-9002-model_1 org.pytorch.serve.wlm.WorkerThread - Backend response time: 16
2022-04-18 00:31:35,112 [INFO ] pool-1-thread-5 ACCESS_LOG - /169.254.255.130:44472 "GET /ping HTTP/1.1" 200 18
2022-04-18 00:31:35,112 [INFO ] pool-1-thread-5 TS_METRICS - Requests2XX.Count:1|#Level:Host|#hostname:aeaba6005455,timestamp:null
2022-04-18 00:31:35,140 [INFO ] epollEventLoopGroup-3-2 ACCES

2022-04-18T00:31:36.242:[sagemaker logs]: MaxConcurrentTransforms=1, MaxPayloadInMB=6, BatchStrategy=MULTI_RECORD
2022-04-18 00:31:40,200 [INFO ] W-9002-model_1 org.pytorch.serve.wlm.WorkerThread - Backend response time: 3
2022-04-18 00:31:40,200 [INFO ] W-9002-model_1-stdout MODEL_METRICS - PredictionTime.Milliseconds:2.02|#ModelName:model,Level:Model|#hostname:6925eb526ceb,requestID:6738de45-b281-4d36-b934-b536cdfd6094,timestamp:1650241900
2022-04-18 00:31:40,200 [INFO ] W-9002-model_1 ACCESS_LOG - /169.254.255.130:33038 "POST /invocations HTTP/1.1" 200 4
2022-04-18 00:31:40,200 [INFO ] W-9002-model_1 TS_METRICS - Requests2XX.Count:1|#Level:Host|#hostname:6925eb526ceb,timestamp:null
2022-04-18 00:31:40,200 [INFO ] W-9002-model_1 TS_METRICS - QueueTime.ms:0|#Level:Host|#hostname:6925eb526ceb,timestamp:null
2022-04-18 00:31:40,201 [INFO ] W-9002-model_1 TS_METRICS - WorkerThreadTime.ms:1|#Level:Host|#hostname:6925eb526ceb,times

## Look at all transform jobs

We list and describe the transform jobs to retrieve information about them.

In [19]:
transform_jobs = sm_cli.list_transform_jobs()["TransformJobSummaries"]
for job in transform_jobs:
    pp.pprint(job)

{'CreationTime': datetime.datetime(2022, 4, 18, 0, 32, 3, 162000, tzinfo=tzlocal()),
 'LastModifiedTime': datetime.datetime(2022, 4, 18, 0, 32, 3, 162000, tzinfo=tzlocal()),
 'TransformJobArn': 'arn:aws:sagemaker:us-west-2:000000000000:transform-job/automl-churn-sdk-18-00-20-19-dpp0-rpb-1-ab06a7aa54904c088293a6d',
 'TransformJobName': 'automl-churn-sdk-18-00-20-19-dpp0-rpb-1-ab06a7aa54904c088293a6d',
 'TransformJobStatus': 'InProgress'}
{'CreationTime': datetime.datetime(2022, 4, 18, 0, 32, 0, 946000, tzinfo=tzlocal()),
 'LastModifiedTime': datetime.datetime(2022, 4, 18, 0, 32, 1, 607000, tzinfo=tzlocal()),
 'TransformJobArn': 'arn:aws:sagemaker:us-west-2:000000000000:transform-job/automl-churn-sdk-18-00-20-19-dpp8-rpb-1-9abdff16bbd34c83872e871',
 'TransformJobName': 'automl-churn-sdk-18-00-20-19-dpp8-rpb-1-9abdff16bbd34c83872e871',
 'TransformJobStatus': 'InProgress'}
{'CreationTime': datetime.datetime(2022, 4, 18, 0, 31, 58, 909000, tzinfo=tzlocal()),
 'LastModifiedTime': datetime.da
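
The listing covers every transform job in the account and Region, including jobs unrelated to this notebook. The following is a small sketch of narrowing the summaries down on the client side. The helper `jobs_matching` and the sample entries are assumptions for illustration, with only the `TransformJobName` and `TransformJobStatus` keys taken from the output above.

```python
def jobs_matching(summaries, name_contains, status=None):
    """Return summaries whose job name contains `name_contains` and, if given, match `status`."""
    selected = []
    for summary in summaries:
        if name_contains not in summary["TransformJobName"]:
            continue
        if status is not None and summary["TransformJobStatus"] != status:
            continue
        selected.append(summary)
    return selected


# Trimmed-down stand-ins for entries of `TransformJobSummaries`; the names are illustrative.
sample_summaries = [
    {"TransformJobName": "automl-churn-sdk-example", "TransformJobStatus": "InProgress"},
    {"TransformJobName": "pytorch-inference-example-1", "TransformJobStatus": "Completed"},
    {"TransformJobName": "pytorch-inference-example-2", "TransformJobStatus": "InProgress"},
]

pytorch_jobs = jobs_matching(sample_summaries, "pytorch-inference")
completed_jobs = jobs_matching(sample_summaries, "pytorch-inference", status="Completed")
print([job["TransformJobName"] for job in pytorch_jobs])
```

The `list_transform_jobs` call also accepts `NameContains` and `StatusEquals` request parameters, which apply the same kind of filter on the service side.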

In [20]:
job_info = sm_cli.describe_transform_job(
    TransformJobName=dist_transformer.latest_transform_job.name
)

pp.pprint(job_info)

{'CreationTime': datetime.datetime(2022, 4, 18, 0, 26, 12, 635000, tzinfo=tzlocal()),
 'DataProcessing': {'InputFilter': '$',
                    'JoinSource': 'None',
                    'OutputFilter': '$'},
 'Environment': {},
 'ModelName': 'pytorch-training-2022-04-18-00-26-11-963',
 'ResponseMetadata': {'HTTPHeaders': {'content-length': '907',
                                      'content-type': 'application/x-amz-json-1.1',
                                      'date': 'Mon, 18 Apr 2022 00:32:03 GMT',
                                      'x-amzn-requestid': 'c75d7d27-2982-45b7-b619-f04bffe72fff'},
                      'HTTPStatusCode': 200,
                      'RequestId': 'c75d7d27-2982-45b7-b619-f04bffe72fff',
                      'RetryAttempts': 0},
 'TransformEndTime': datetime.datetime(2022, 4, 18, 0, 31, 42, 720000, tzinfo=tzlocal()),
 'TransformInput': {'CompressionType': 'None',
                    'ContentType': 'application/x-image',
                    'DataSour

In [21]:
import re


def get_bucket_and_prefix(s3_output_path):
    # Strip the scheme, then split on the first "/" only so multi-segment prefixes survive
    trim = re.sub("^s3://", "", s3_output_path)
    bucket, prefix = trim.split("/", 1)
    return bucket, prefix


local_path = "output"  # Where to save the output locally

bucket, output_prefix = get_bucket_and_prefix(job_info["TransformOutput"]["S3OutputPath"])
print(bucket, output_prefix)

sagemaker_session.download_data(path=local_path, bucket=bucket, key_prefix=output_prefix)

sagemaker-us-west-2-521695447989 pytorch-training-2022-04-18-00-26-12-621


In [22]:
!ls {local_path}

0.png.out   24.png.out	4.png.out   55.png.out	70.png.out  86.png.out
1.png.out   25.png.out	40.png.out  56.png.out	71.png.out  87.png.out
10.png.out  26.png.out	41.png.out  57.png.out	72.png.out  88.png.out
11.png.out  27.png.out	42.png.out  58.png.out	73.png.out  89.png.out
12.png.out  28.png.out	43.png.out  59.png.out	74.png.out  9.png.out
13.png.out  29.png.out	44.png.out  6.png.out	75.png.out  90.png.out
14.png.out  3.png.out	45.png.out  60.png.out	76.png.out  91.png.out
15.png.out  30.png.out	46.png.out  61.png.out	77.png.out  92.png.out
16.png.out  31.png.out	47.png.out  62.png.out	78.png.out  93.png.out
17.png.out  32.png.out	48.png.out  63.png.out	79.png.out  94.png.out
18.png.out  33.png.out	49.png.out  64.png.out	8.png.out   95.png.out
19.png.out  34.png.out	5.png.out   65.png.out	80.png.out  96.png.out
2.png.out   35.png.out	50.png.out  66.png.out	81.png.out  97.png.out
20.png.out  36.png.out	51.png.out  67.png.out	82.png.out  98.png.out
21.png.out  37.png.out	

In [23]:
# Inspect the output

import json

for file_name in os.listdir(local_path):
    output_path = os.path.join(local_path, file_name)
    with open(output_path, "r") as f:
        pred = json.load(f)
        print(pred)

{'predictions': 9}
{'predictions': 9}
{'predictions': 7}
{'predictions': 3}
{'predictions': 4}
{'predictions': 5}
{'predictions': 7}
{'predictions': 3}
{'predictions': 0}
{'predictions': 2}
{'predictions': 2}
{'predictions': 5}
{'predictions': 8}
{'predictions': 4}
{'predictions': 3}
{'predictions': 8}
{'predictions': 6}
{'predictions': 0}
{'predictions': 0}
{'predictions': 1}
{'predictions': 2}
{'predictions': 4}
{'predictions': 4}
{'predictions': 0}
{'predictions': 7}
{'predictions': 8}
{'predictions': 4}
{'predictions': 7}
{'predictions': 5}
{'predictions': 4}
{'predictions': 6}
{'predictions': 5}
{'predictions': 2}
{'predictions': 1}
{'predictions': 5}
{'predictions': 6}
{'predictions': 4}
{'predictions': 8}
{'predictions': 0}
{'predictions': 2}
{'predictions': 8}
{'predictions': 4}
{'predictions': 7}
{'predictions': 9}
{'predictions': 2}
{'predictions': 0}
{'predictions': 9}
{'predictions': 3}
{'predictions': 5}
{'predictions': 4}
{'predictions': 6}
{'predictions': 7}
{'prediction
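
The printed predictions are not labeled with the image they belong to. Each output file is named after its input with an `.out` suffix, so the file name can serve as the key. The following is a minimal sketch of loading the results into a dictionary and counting the predicted digits. The helper `load_predictions` and the temporary demo directory are assumptions that let the block run without the downloaded output.

```python
import json
import os
import tempfile
from collections import Counter


def load_predictions(output_dir):
    """Map each input file name to its predicted digit, read from `<name>.out` files."""
    results = {}
    for file_name in os.listdir(output_dir):
        if not file_name.endswith(".out"):
            continue
        with open(os.path.join(output_dir, file_name), "r") as handle:
            results[file_name[: -len(".out")]] = json.load(handle)["predictions"]
    return results


# Build a throwaway directory that mimics the downloaded output layout.
with tempfile.TemporaryDirectory() as demo_dir:
    for name, digit in [("0.png", 9), ("1.png", 7), ("2.png", 9)]:
        with open(os.path.join(demo_dir, name + ".out"), "w") as handle:
            json.dump({"predictions": digit}, handle)
    predictions = load_predictions(demo_dir)

digit_counts = Counter(predictions.values())
print(sorted(predictions.items()))
print(digit_counts.most_common(1))
```

In the notebook, `load_predictions(local_path)` would return the same kind of mapping for the 100 sampled images.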

## Conclusion

In this notebook, we trained a PyTorch model, created a transformer from it, and then performed batch inference using an S3 prefix, a manifest file, and multiple instances. These are some of the options available when running SageMaker Batch Transform jobs for batch inference.