In [54]:
import torch
from torch import nn
import numpy as np
from torchvision import transforms as T
import timm
import torchinfo
from PIL import Image
import json

In [2]:
import swin

In [3]:
# Build Swin-T (the default SwinTransformer config is swin_tiny_patch4_window7_224)
# and load the pretrained weights stored under the checkpoint's "model" key.
# nn.Module instances start in training mode and load_state_dict does not change
# that, so switch to eval mode before any inference; otherwise training-only layers
# (e.g. dropout / stochastic depth, if this Swin implementation uses them) stay
# active and predictions become non-deterministic. eval() returns the module itself,
# so it can be chained here and load_state_dict stays the displayed last expression.
swin_transformer = swin.SwinTransformer().eval()
state_dict = torch.load("swin_tiny_patch4_window7_224.pth", map_location="cpu")["model"]
swin_transformer.load_state_dict(state_dict)

<All keys matched successfully>

In [11]:
# Layer-by-layer output shapes and parameter counts for a batch of 64 3x224x224 inputs.
torchinfo.summary(swin_transformer, input_size=(64, 3, 224, 224))

Layer (type:depth-idx)                             Output Shape              Param #
SwinTransformer                                    --                        --
├─ModuleList: 1-1                                  --                        --
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-1                        --                        224,694
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-2                        --                        891,756
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-3                        --                        10,658,952
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-4                        --                        14,183,856
├─PatchEmbed: 1-2                                  [64, 3136, 96]            --
│    └─Co

In [12]:
# Release cached (unused) CUDA allocator memory, if CUDA is present at all.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [17]:
import os
import tempfile


def print_size_of_model(model):
    """Print the serialized size of ``model``'s state dict in megabytes.

    The state dict is written with ``torch.save`` and its on-disk size is
    reported (1 MB = 1e6 bytes), i.e. roughly what a checkpoint would occupy.

    Parameters
    ----------
    model : torch.nn.Module
        Any module exposing ``state_dict()``; dynamically quantized models work too.
    """
    # Use a private throw-away directory instead of a fixed "temp.p" in the CWD:
    # an existing file with that name is never clobbered, and the directory is
    # removed even if torch.save raises. The "temp.p" name is kept so the measured
    # size matches earlier outputs (torch.save's zip archive can embed the name).
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), path)
        size_mb = os.path.getsize(path) / 1e6
    print('Size (MB):', size_mb)

In [18]:
# Baseline: serialized size of the unquantized model.
print_size_of_model(model=swin_transformer)

Size (MB): 114.334469


In [69]:
import torch.quantization

# Quantized kernels come from a backend engine. 'qnnpack' is what this notebook was
# run with, but assigning an engine that is not in supported_engines raises a
# RuntimeError, so fall back to another engine the local PyTorch build provides.
PREFERRED_QENGINES = ("qnnpack", "x86", "fbgemm")
torch.backends.quantized.engine = next(
    (
        engine
        for engine in PREFERRED_QENGINES
        if engine in torch.backends.quantized.supported_engines
    ),
    torch.backends.quantized.engine,
)

# Dynamic quantization: every nn.Linear gets int8 weights, activations are quantized
# on the fly. With the default inplace=False the model is deep-copied, so
# `swin_transformer` itself stays unquantized for the comparison below.
quantized_swin = torch.quantization.quantize_dynamic(
    swin_transformer, {torch.nn.Linear}, dtype=torch.qint8
)

In [70]:
# Serialized size after dynamic int8 quantization of the Linear layers.
print_size_of_model(model=quantized_swin)

Size (MB): 29.793839


In [59]:
# Class-index -> label lookup used to decode argmax predictions below.
# The `with` block closes the file once parsed (a bare open() leaked the handle).
with open("imagenet_labels.json", mode="r", encoding="utf-8") as labels_file:
    labels = json.load(labels_file)

In [22]:
# Eval preprocessing: resize the shorter side to 256, take the central 224x224 crop,
# convert to a [0, 1] tensor and normalise with timm's ImageNet mean/std.
# NOTE(review): T.Resize defaults to bilinear interpolation; confirm this matches
# the pipeline the checkpoint was evaluated with.
RESIZE_SIZE = 256
CROP_SIZE = 224
transform = T.Compose(
    [
        T.Resize(RESIZE_SIZE),
        T.CenterCrop(CROP_SIZE),
        T.ToTensor(),
        T.Normalize(
            mean=timm.data.IMAGENET_DEFAULT_MEAN,
            std=timm.data.IMAGENET_DEFAULT_STD,
        ),
    ]
)

In [31]:
def get_input(filepath, preprocess=None):
    """Load an image file and turn it into a single-image model input batch.

    Parameters
    ----------
    filepath : str or path-like
        Path to an image readable by PIL.
    preprocess : callable, optional
        Transform mapping an RGB ``PIL.Image`` to a ``(C, H, W)`` tensor.
        Defaults to the notebook-level ``transform``.

    Returns
    -------
    torch.Tensor
        The preprocessed image with a leading batch dimension of size 1.
    """
    if preprocess is None:
        preprocess = transform
    # The context manager closes the underlying file handle. convert() returns a
    # new, fully loaded image, so it remains usable after the file is closed.
    with Image.open(filepath) as img:
        rgb_img = img.convert("RGB")
    return preprocess(rgb_img).unsqueeze(0)

In [32]:
# Sample image that is run through both the float and the quantized model below.
CHEETAH_IMG_PATH = "../animals/dataset/acinonyx-jubatus/13.jpg"
cheetah_img = get_input(CHEETAH_IMG_PATH)

In [63]:
%%time
with torch.no_grad():
    out = swin_transformer(cheetah_img)
    pred = out.argmax(dim=1).item()
    print(labels[pred])

cheetah
CPU times: user 234 ms, sys: 35.1 ms, total: 269 ms
Wall time: 99.9 ms


In [64]:
%%time
with torch.no_grad():
    quantized_out = quantized_swin(cheetah_img)
    pred = out.argmax(dim=1).item()
    print(labels[pred])

cheetah
CPU times: user 262 ms, sys: 27.6 ms, total: 289 ms
Wall time: 82.5 ms
