In [2]:
from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

print(torch.__version__)
# should be 1.8.0


model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()

transform = transforms.Compose([
    transforms.Resize(256, interpolation=3),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
img = transform(img)[None,]
out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item())

1.13.1+cu117


Downloading: "https://github.com/facebookresearch/deit/zipball/main" to /home/alexander/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /home/alexander/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth
100%|██████████| 330M/330M [00:16<00:00, 21.4MB/s] 


269


To use the model on mobile, we first need to script the model.

In [3]:
model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("./models/fbdeit_scripted.pt")

Using cache found in /home/alexander/.cache/torch/hub/facebookresearch_deit_main


In [4]:
# Use 'fbgemm' for server inference and 'qnnpack' for mobile inference
backend = "fbgemm" # replaced with qnnpack causing much worse inference speed for quantized model on this notebook
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

quantized_model = torch.quantization.quantize_dynamic(model, qconfig_spec={torch.nn.Linear}, dtype=torch.qint8)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save("./models/fbdeit_scripted_quantized.pt")



In [5]:
out = scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# The same output 269 should be printed

269


Optimization

In [6]:
from torch.utils.mobile_optimizer import optimize_for_mobile
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("./models/fbdeit_optimized_scripted_quantized.pt")

In [7]:
out = optimized_scripted_quantized_model(img)
clsidx = torch.argmax(out)
print(clsidx.item())
# Again, the same output 269 should be printed

269


In [9]:
optimized_scripted_quantized_model._save_for_lite_interpreter("./models/fbdeit_optimized_scripted_quantized_lite.ptl")
ptl = torch.jit.load("./models/fbdeit_optimized_scripted_quantized_lite.ptl")

In [10]:
with torch.autograd.profiler.profile(use_cuda=False) as prof1:
    out = model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof2:
    out = scripted_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof3:
    out = scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof4:
    out = optimized_scripted_quantized_model(img)
with torch.autograd.profiler.profile(use_cuda=False) as prof5:
    out = ptl(img)

print("original model: {:.2f}ms".format(prof1.self_cpu_time_total/1000))
print("scripted model: {:.2f}ms".format(prof2.self_cpu_time_total/1000))
print("scripted & quantized model: {:.2f}ms".format(prof3.self_cpu_time_total/1000))
print("scripted & quantized & optimized model: {:.2f}ms".format(prof4.self_cpu_time_total/1000))
print("lite model: {:.2f}ms".format(prof5.self_cpu_time_total/1000))

STAGE:2023-01-19 11:12:10 1148887:1148887 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-01-19 11:12:10 1148887:1148887 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-01-19 11:12:10 1148887:1148887 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-01-19 11:12:10 1148887:1148887 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-01-19 11:12:10 1148887:1148887 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-01-19 11:12:10 1148887:1148887 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-01-19 11:12:10 1148887:1148887 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-01-19 11:12:10 1148887:1148887 ActivityProfilerController.cpp:300] Completed Stage: Collection
STAGE:2023-01-19 11:12:10 1148887:1148887 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-01-19 11:12:10 1148887:1148887 ActivityProfilerCon

original model: 235.94ms
scripted model: 189.81ms
scripted & quantized model: 70.28ms
scripted & quantized & optimized model: 80.79ms
lite model: 84.56ms
