In [1]:
import json
import tarfile
import time

import boto3
import sagemaker
import torch
from sagemaker.utils import name_from_base


In [2]:
# Record the local PyTorch build; the model traced further down is serialized with it.
print(f"{torch.__version__}")

1.7.1


In [3]:
import urllib.request

# Fetch a sample image (presumably for a later inference check; not used by the cells shown).
# Python 3 has no `urllib.URLopener`, and a bare `import urllib` does not load `urllib.request`,
# so call the Python 3 API directly instead of relying on a bare-except fallback.
url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)

In [4]:
from torchvision import datasets

In [5]:
!pip install torchvision

You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_latest_p36/bin/python -m pip install --upgrade pip' command.


In [6]:
!pip install pillow

You should consider upgrading via the '/home/ec2-user/anaconda3/envs/pytorch_latest_p36/bin/python -m pip install --upgrade pip' command.


In [7]:

import torchvision


In [8]:
from torchvision import models

# ResNet-18 with pretrained weights; `models` is the same module as `torchvision.models`.
model = models.resnet18(pretrained=True)

In [9]:
# Trace the model with a zero-filled float32 example of shape [1, 3, 224, 224],
# save it as TorchScript, and package it as model.tar.gz for upload to the compilation job.
input_shape = [1, 3, 224, 224]
example_input = torch.zeros(input_shape, dtype=torch.float32)
trace = torch.jit.trace(model.float().eval(), example_input)
trace.save("model.pth")

with tarfile.open("model.tar.gz", mode="w:gz") as archive:
    archive.add("model.pth")

In [10]:
# SageMaker session, execution role, region and default bucket used for all artifacts.
role = sagemaker.get_execution_role()
sess = sagemaker.Session()
region = sess.boto_region_name
bucket = sess.default_bucket()

# Unique job name; the uploaded model and the compiled output are stored under it.
compilation_job_name = name_from_base("TorchVision-ResNet18-Neo")
prefix = compilation_job_name + "/model"
model_path = sess.upload_data(path="model.tar.gz", key_prefix=prefix)

# Input spec derived from the shape used for tracing, so the two cannot drift apart.
# Compact separators reproduce the exact string '{"input0":[1,3,224,224]}'.
data_shape = json.dumps({"input0": input_shape}, separators=(",", ":"))
target_platform = {'Os': 'LINUX', 'Arch': 'X86_64'}
framework = "PYTORCH"
# NOTE(review): the model was traced with the local torch build printed earlier;
# confirm this framework version is compatible with that artifact.
framework_version = "1.6"
compiled_model_path = "s3://{}/{}/output".format(bucket, compilation_job_name)


In [11]:
# Create the SageMaker client in the session's region, i.e. where the default bucket
# holding the model artifact lives (previously hardcoded to us-east-1).
sagemaker_client = boto3.client('sagemaker', region_name=region)

print(f'Compilation job for {compilation_job_name} started')

response = sagemaker_client.create_compilation_job(
    CompilationJobName=compilation_job_name,
    RoleArn=role,
    InputConfig={
        'S3Uri': model_path,
        'DataInputConfig': data_shape,
        'Framework': framework.upper(),
        # Previously defined but never sent, so the service default was used silently.
        'FrameworkVersion': framework_version,
    },
    OutputConfig={
        'S3OutputLocation': compiled_model_path,
        'TargetPlatform': target_platform,
    },
    StoppingCondition={
        'MaxRuntimeInSeconds': 900
    }
)

# Poll until the job reaches a terminal state. STOPPED is terminal too (e.g. after
# StopCompilationJob); without handling it this loop would never exit.
while True:
    response = sagemaker_client.describe_compilation_job(CompilationJobName=compilation_job_name)
    status = response['CompilationJobStatus']
    if status == 'COMPLETED':
        break
    if status in ('FAILED', 'STOPPED'):
        reason = response.get('FailureReason', 'no failure reason reported')
        raise RuntimeError(f'Compilation failed ({status}): {reason}')
    print('Compiling ...')
    time.sleep(30)
print('Done!')

Compilation job for TorchVision-ResNet18-Neo-2021-08-05-01-02-01-874 started
Compiling ...
Compiling ...
Compiling ...
Compiling ...
Compiling ...
Compiling ...
Compiling ...
Compiling ...
Compiling ...
Compiling ...
Done!


In [12]:
# S3 client plus the key Neo wrote the compiled artifact to:
# <job>/output/model-<Os>_<Arch>.tar.gz
s3_client = boto3.client('s3')
platform_os, platform_arch = target_platform['Os'], target_platform['Arch']
object_path = f"{compilation_job_name}/output/model-{platform_os}_{platform_arch}.tar.gz"
print(object_path)

TorchVision-ResNet18-Neo-2021-08-05-01-02-01-874/output/model-LINUX_X86_64.tar.gz


In [13]:
# Pull the compiled artifact down to a local tarball.
neo_compiled_model = 'compiled-model.tar.gz'
s3_client.download_file(Bucket=bucket, Key=object_path, Filename=neo_compiled_model)

In [14]:
!mkdir model
!tar zfxv compiled-model.tar.gz -C model/
!zip compiled-model.zip model/*
s3_client.upload_file('compiled-model.zip', bucket, '{}/output/model-{}_{}.zip'.format(compilation_job_name, target_platform['Os'], target_platform['Arch']))



compiled.so
dlr.h
compiled_model.json
compiled.params
compiled.meta
libdlr.so
manifest
  adding: model/compiled.meta (deflated 66%)
  adding: model/compiled_model.json (deflated 93%)
  adding: model/compiled.params (deflated 7%)
  adding: model/compiled.so (deflated 78%)
  adding: model/dlr.h (deflated 83%)
  adding: model/libdlr.so (deflated 60%)
  adding: model/manifest (deflated 45%)
