In [6]:
import qai_hub as hub
import torch
import torchvision
import torchviz
import torchsummary
import requests
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [7]:
resNet18Torch = torchvision.models.resnet18(pretrained=True)
mobileNetV3Torch = torchvision.models.mobilenet_v3_small(pretrained=True)
mobileNetV2Torch = torchvision.models.mobilenet_v2(pretrained=True)




In [8]:
summary = torchsummary.summary(mobileNetV2Torch, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 112, 112]             864
       BatchNorm2d-2         [-1, 32, 112, 112]              64
             ReLU6-3         [-1, 32, 112, 112]               0
            Conv2d-4         [-1, 32, 112, 112]             288
       BatchNorm2d-5         [-1, 32, 112, 112]              64
             ReLU6-6         [-1, 32, 112, 112]               0
            Conv2d-7         [-1, 16, 112, 112]             512
       BatchNorm2d-8         [-1, 16, 112, 112]              32
  InvertedResidual-9         [-1, 16, 112, 112]               0
           Conv2d-10         [-1, 96, 112, 112]           1,536
      BatchNorm2d-11         [-1, 96, 112, 112]             192
            ReLU6-12         [-1, 96, 112, 112]               0
           Conv2d-13           [-1, 96, 56, 56]             864
      BatchNorm2d-14           [-1, 96,

In [9]:
summaryResNet = torchsummary.summary(resNet18Torch, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64,

In [10]:
mobileNetV3Torch.eval()
mobileNetV2Torch.eval()
resNet18Torch.eval()

workingModel = mobileNetV2Torch

# Step 1: Trace model
input_shape = (1, 3, 224, 224)
example_input = torch.rand(input_shape)
traced_torch_model = torch.jit.trace(workingModel, example_input)

torch.jit.save(traced_torch_model, 'workingModel.pt')


In [11]:

# Step 2: Compile model
compile_job = hub.submit_compile_job(
    model=traced_torch_model,
    device=hub.Device("Snapdragon 8 Elite QRD"),
    input_specs=dict(image=input_shape),
)


Uploading tmpdb5ix3ih.pt


100%|[34m██████████[0m| 14.0M/14.0M [00:02<00:00, 5.65MB/s]


Scheduled compile job (jgjdvqyxg) successfully. To see the status and results:
    https://app.aihub.qualcomm.com/jobs/jgjdvqyxg/



In [12]:

# Step 3: Profile on cloud-hosted device
target_model = compile_job.get_target_model()

print(type(target_model))


Waiting for compile job (jgjdvqyxg) completion. Type Ctrl+C to stop waiting at any time.
    ✅ SUCCESS                          
<class 'qai_hub.client.Model'>


In [13]:


profile_job = hub.submit_profile_job(
    model=target_model,
    device=hub.Device("Snapdragon 8 Elite QRD"),
)


Scheduled profile job (jpeodyx1g) successfully. To see the status and results:
    https://app.aihub.qualcomm.com/jobs/jpeodyx1g/



In [14]:

# Step 4: Run inference on cloud-hosted device
sample_image_url = (
    "https://qaihub-public-assets.s3.us-west-2.amazonaws.com/apidoc/input_image1.jpg"
)
response = requests.get(sample_image_url, stream=True)
response.raw.decode_content = True
image = Image.open(response.raw).resize((224, 224))
input_array = np.expand_dims(
    np.transpose(np.array(image, dtype=np.float32) / 255.0, (2, 0, 1)), axis=0
)


In [15]:

# Run inference using the on-device model on the input image
inference_job = hub.submit_inference_job(
    model=target_model,
    device=hub.Device("Snapdragon 8 Elite QRD"),
    inputs=dict(image=[input_array]),
)
on_device_output = inference_job.download_output_data()


Uploading dataset: 154kB [00:01, 112kB/s]                    <?, ?B/s]


Scheduled inference job (jgz23nykg) successfully. To see the status and results:
    https://app.aihub.qualcomm.com/jobs/jgz23nykg/

Waiting for inference job (jgz23nykg) completion. Type Ctrl+C to stop waiting at any time.
    ✅ SUCCESS                          


tmp2n_q64g5.h5: 100%|[34m██████████[0m| 14.4k/14.4k [00:00<?, ?B/s]


In [16]:

# Step 5: Post-processing the on-device output
output_name = list(on_device_output.keys())[0]
out = on_device_output[output_name][0]
on_device_probabilities = np.exp(out) / np.sum(np.exp(out), axis=1)


In [17]:

# Read the class labels for imagenet
sample_classes = "https://qaihub-public-assets.s3.us-west-2.amazonaws.com/apidoc/imagenet_classes.txt"
response = requests.get(sample_classes, stream=True)
response.raw.decode_content = True
categories = [str(s.strip()) for s in response.raw]

# Print top five predictions for the on-device model
print("Top-5 On-Device predictions:")
top5_classes = np.argsort(on_device_probabilities[0], axis=0)[-5:]
for c in reversed(top5_classes):
    print(f"{c} {categories[c]:20s} {on_device_probabilities[0][c]:>6.1%}")

# Step 6: Download model
target_model = compile_job.get_target_model()
target_model.download("mobilenet_v2.tflite")


Top-5 On-Device predictions:
968 b'cup'                76.4%
504 b'coffee mug'         13.9%
967 b'espresso'            4.1%
849 b'teapot'              1.1%
725 b'pitcher'             1.1%


mobilenet_v2.tflite: 100%|[34m██████████[0m| 13.3M/13.3M [00:01<00:00, 10.7MB/s]

Downloaded model to mobilenet_v2.tflite





'mobilenet_v2.tflite'