# DeiT (Data-efficient Image Transformers)

Convolutional Neural Networks (CNNs) have been the main models for image classification since deep learning took off in 2012, but CNNs typically require hundreds of millions of images for training to achieve state-of-the-art (SOTA) results. DeiT is a vision transformer model that requires far less data and computing resources for training, yet competes with the leading CNNs at image classification. This is made possible by two key components of DeiT:

 - Data augmentation that simulates training on a much larger dataset
 - Native distillation that allows the transformer network to learn from a CNN’s output.

[**Repo**](https://github.com/facebookresearch/deit)

[**Paper**](https://arxiv.org/abs/2012.12877)

In [1]:
!pip install timm

Collecting timm
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
[?25l[K     |▊                               | 10 kB 25.7 MB/s eta 0:00:01[K     |█▌                              | 20 kB 33.6 MB/s eta 0:00:01[K     |██▎                             | 30 kB 16.3 MB/s eta 0:00:01[K     |███                             | 40 kB 11.6 MB/s eta 0:00:01[K     |███▉                            | 51 kB 7.5 MB/s eta 0:00:01[K     |████▋                           | 61 kB 7.5 MB/s eta 0:00:01[K     |█████▎                          | 71 kB 7.8 MB/s eta 0:00:01[K     |██████                          | 81 kB 8.8 MB/s eta 0:00:01[K     |██████▉                         | 92 kB 8.0 MB/s eta 0:00:01[K     |███████▋                        | 102 kB 7.0 MB/s eta 0:00:01[K     |████████▍                       | 112 kB 7.0 MB/s eta 0:00:01[K     |█████████▏                      | 122 kB 7.0 MB/s eta 0:00:01[K     |█████████▉                      | 133 kB 7.0 MB/s eta 0:00:01[K     

In [2]:
import requests
import timm
import torch
import torchvision.transforms as T
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torch.utils.mobile_optimizer import optimize_for_mobile

In [5]:
# Fetch the pretrained DeiT-base checkpoint (patch size 16, 224x224 input) through torch.hub.
model = torch.hub.load(
    'facebookresearch/deit:main',
    'deit_base_patch16_224',
    pretrained=True,
)
# Put the network in evaluation mode; as the cell's last expression this also
# displays the module tree.
model.eval()

Using cache found in /root/.cache/torch/hub/facebookresearch_deit_main
Downloading: "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth" to /root/.cache/torch/hub/checkpoints/deit_base_patch16_224-b5f2ef4d.pth


  0%|          | 0.00/330M [00:00<?, ?B/s]

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU()
        (drop1): Dropout(p=0.0, inplace=False)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (drop2): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((768,),

In [8]:
# Evaluation-time preprocessing: resize, center-crop to 224x224, convert to a
# tensor, then normalize with the ImageNet default mean/std from timm.
transform = T.Compose([
    # InterpolationMode.BICUBIC is the enum form of the legacy integer code 3;
    # passing the enum avoids the "should be of type InterpolationMode" warning.
    T.Resize(256, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [9]:
# Download a sample image, preprocess it, and classify it with the eager model.
IMAGE_URL = "https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png"
response = requests.get(IMAGE_URL, stream=True, timeout=30)
response.raise_for_status()  # fail loudly rather than handing PIL an HTTP error body

# The model's patch embedding takes 3 input channels, so force RGB in case the
# PNG is decoded with an alpha channel or a palette.
pil_img = Image.open(response.raw).convert("RGB")
img = transform(pil_img)[None,]  # prepend a batch dimension

with torch.no_grad():  # inference only: no autograd graph is needed
    out = model(img)
clsidx = torch.argmax(out)
print(clsidx.item()) # 269 is timber wolf, grey wolf, gray wolf, Canis lupus

269


In [10]:
# Mobile deployment needs a TorchScript module, so script the eager model first.
model = torch.hub.load(
    'facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True
).eval()  # eval() returns the module itself, so the call can be chained
scripted_model = torch.jit.script(model)
scripted_model.save('fbdeit_scripted.pt')

Using cache found in /root/.cache/torch/hub/facebookresearch_deit_main


In [11]:
# Quantization: significantly smaller model for a small loss in accuracy.
# Backend choice: 'fbgemm' is for server inference, 'qnnpack' for mobile inference.
# 'fbgemm' is kept here because switching to 'qnnpack' made the quantized model
# run much slower on this notebook.
backend = "fbgemm"
model.qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

# Dynamically quantize only the Linear layers to int8, then script and save the result.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
)
scripted_quantized_model = torch.jit.script(quantized_model)
scripted_quantized_model.save('fbdeit_scripted_quantized.pt')

  reduce_range will be deprecated in a future release of PyTorch."


In [17]:
# List the *.pt files in the working directory with human-readable sizes.
!ls -lh *.pt

-rw-r--r-- 1 root root 331M Feb  6 14:18 fbdeit_scripted.pt
-rw-r--r-- 1 root root  86M Feb  6 14:21 fbdeit_scripted_quantized.pt


In [18]:
# Sanity check: the scripted + quantized model should print the same class index, 269.
out = scripted_quantized_model(img)
clsidx = out.argmax()
print(clsidx.item())

269


In [19]:
# Optimizing DeiT
# Pass the scripted, quantized model through `optimize_for_mobile` and save the result.
# (`optimize_for_mobile` is imported with the other dependencies in the import cell.)
optimized_scripted_quantized_model = optimize_for_mobile(scripted_quantized_model)
optimized_scripted_quantized_model.save("fbdeit_optimized_scripted_quantized.pt")

In [20]:
# Sanity check: the mobile-optimized model should again print class index 269.
out = optimized_scripted_quantized_model(img)
clsidx = out.argmax()
print(clsidx.item())

269


  return forward_call(*input, **kwargs)


In [21]:
# List the *.pt files again, now including the mobile-optimized model.
!ls -lh *.pt

-rw-r--r-- 1 root root  86M Feb  6 14:24 fbdeit_optimized_scripted_quantized.pt
-rw-r--r-- 1 root root 331M Feb  6 14:18 fbdeit_scripted.pt
-rw-r--r-- 1 root root  86M Feb  6 14:21 fbdeit_scripted_quantized.pt


In [22]:
# Save the optimized model in the lite-interpreter format, then load that file back as `ptl`.
lite_path = "fbdeit_optimized_scripted_quantized_lite.ptl"
optimized_scripted_quantized_model._save_for_lite_interpreter(lite_path)
ptl = torch.jit.load(lite_path)

In [23]:
def profile_forward(module, inputs):
    """Run one forward pass of `module` on `inputs` under the autograd profiler.

    Args:
        module: callable model (eager or scripted) to execute.
        inputs: the input passed to `module` as its single argument.

    Returns:
        (prof, output): the finished profiler object and the model's output.
    """
    # CPU-only profiling is the profiler's default, so no CUDA flag is passed.
    with torch.autograd.profiler.profile() as prof:
        output = module(inputs)
    return prof, output


# (label, model) pairs, in the order they are profiled and reported.
labelled_models = [
    ("original model", model),
    ("scripted model", scripted_model),
    ("scripted & quantized model", scripted_quantized_model),
    ("scripted & quantized & optimized model", optimized_scripted_quantized_model),
    ("lite model", ptl),
]

# Profile every variant first, then report, so printing never interleaves with a run.
profiles = []
for _label, candidate in labelled_models:
    prof, out = profile_forward(candidate, img)
    profiles.append(prof)

# Keep the per-model profiler names available for any later inspection.
prof1, prof2, prof3, prof4, prof5 = profiles

for (label, _candidate), prof in zip(labelled_models, profiles):
    # self_cpu_time_total is divided by 1000 to report the time in milliseconds.
    print("{}: {:.2f}ms".format(label, prof.self_cpu_time_total/1000))

original model: 713.52ms
scripted model: 731.18ms
scripted & quantized model: 460.61ms
scripted & quantized & optimized model: 426.97ms
lite model: 434.54ms


In [24]:
# List every saved model file (*.pt and *.ptl) with human-readable sizes.
!ls -lh *.pt*

-rw-r--r-- 1 root root 167M Feb  6 14:25 fbdeit_optimized_scripted_quantized_lite.ptl
-rw-r--r-- 1 root root  86M Feb  6 14:24 fbdeit_optimized_scripted_quantized.pt
-rw-r--r-- 1 root root 331M Feb  6 14:18 fbdeit_scripted.pt
-rw-r--r-- 1 root root  86M Feb  6 14:21 fbdeit_scripted_quantized.pt
