In [1]:
import torch
import timm
import sys
sys.path.append('../')
from timm.models.vision_transformer import VisionTransformer
from fvcore.nn import FlopCountAnalysis, flop_count_table, flop_count_str
#from beit import BeitTeacher
#from tokenrank_vit import TokenRankVisionTransformer
#from tokenrank_beit import TokenRankBeit
from PIL import Image
from transformers import ViTFeatureExtractor
import requests
%load_ext autoreload
%autoreload 2

from vit import VisionTransformerDiffPruning

In [2]:
# Fetch one COCO val2017 image and preprocess it into the example input tensor
# used by the FLOP-counting cells below.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# The timeout keeps the cell from hanging forever on a stalled connection.
# raise_for_status() surfaces HTTP errors instead of letting PIL choke on an
# error page. The `with` block closes the HTTP response, and .convert("RGB")
# forces PIL to read the pixels while the stream is still open; it also
# guarantees 3 channels.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(response.raw).convert("RGB")
# NOTE(review): recent `transformers` releases deprecate ViTFeatureExtractor in
# favour of ViTImageProcessor; consider switching (check installed version).
feature_extractor = ViTFeatureExtractor()
# Presumably a (1, 3, 224, 224) float tensor with the default preprocessing.
# The 197-token pos_embed in the outputs below is consistent with 224/16 = 14x14 patches + cls.
inputs = feature_extractor(images=image, return_tensors="pt")['pixel_values']

In [5]:
# Testing DEIT-tiny: build DynamicViT (DeiT-tiny backbone: embed_dim=192,
# 3 heads), load its trained weights, and count FLOPs on the example input.
# NOTE(review): this cell is repeated almost verbatim for each model size; a
# helper taking (embed_dim, num_heads, model_path, base_rate) would remove the copy-paste.
base_rate = 0.9
model_path = "logs/dynamic-vit_deit-tiny-0.9/checkpoint_best.pth"

# Block indices passed as pruning_loc (exact semantics live in vit.py — confirm).
PRUNING_LOC = [3, 6, 9]
# One keep ratio per pruning location, compounding at each stage. Deriving the
# list from PRUNING_LOC keeps the two the same length if the locations change;
# for three locations this equals [base_rate, base_rate ** 2, base_rate ** 3].
KEEP_RATE = [base_rate ** stage for stage in range(1, len(PRUNING_LOC) + 1)]

model = VisionTransformerDiffPruning(
    patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
    pruning_loc=PRUNING_LOC, token_ratio=KEEP_RATE,
)

# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which may
# reject training checkpoints that pickle non-tensor objects. If loading fails,
# pass weights_only=False, but only for trusted files.
checkpoint = torch.load(model_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model = model.eval()

# NOTE(review): fvcore's documented convention counts one fused multiply-add as
# one "flop" (i.e. these are MACs) — confirm for the installed version.
flop = FlopCountAnalysis(model, inputs)
print(flop_count_table(flop, max_depth=4))
print(flop_count_str(flop))
print(flop.total())

## diff vit pruning method
| module                           | #parameters or shape   | #flops      |
|:---------------------------------|:-----------------------|:------------|
| model                            | 5.9M                   | 1.096G      |
|  cls_token                       |  (1, 1, 192)           |             |
|  pos_embed                       |  (1, 197, 192)         |             |
|  patch_embed.proj                |  0.148M                |  28.901M    |
|   patch_embed.proj.weight        |   (192, 3, 16, 16)     |             |
|   patch_embed.proj.bias          |   (192,)               |             |
|  blocks                          |  5.338M                |  1.034G     |
|   blocks.0                       |   0.445M               |   0.102G    |
|    blocks.0.norm1                |    0.384K              |    0.189M   |
|     blocks.0.norm1.weight        |     (192,)             |             |
|     blocks.0.norm1.bias          |     (192,)             |

In [6]:
# Testing DEIT-256: build DynamicViT (embed_dim=256, 4 heads), load its
# released weights, and count FLOPs on the example input.
# NOTE(review): this cell is repeated almost verbatim for each model size; a
# helper taking (embed_dim, num_heads, model_path, base_rate) would remove the copy-paste.
base_rate = 0.7
model_path = "models/dynamic-vit_256_r0.7.pth"

# Block indices passed as pruning_loc (exact semantics live in vit.py — confirm).
PRUNING_LOC = [3, 6, 9]
# One keep ratio per pruning location, compounding at each stage. Deriving the
# list from PRUNING_LOC keeps the two the same length if the locations change;
# for three locations this equals [base_rate, base_rate ** 2, base_rate ** 3].
KEEP_RATE = [base_rate ** stage for stage in range(1, len(PRUNING_LOC) + 1)]

model = VisionTransformerDiffPruning(
    patch_size=16, embed_dim=256, depth=12, num_heads=4, mlp_ratio=4, qkv_bias=True,
    pruning_loc=PRUNING_LOC, token_ratio=KEEP_RATE,
)

# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which may
# reject checkpoints that pickle non-tensor objects. If loading fails, pass
# weights_only=False, but only for trusted files.
checkpoint = torch.load(model_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model = model.eval()

# NOTE(review): fvcore's documented convention counts one fused multiply-add as
# one "flop" (i.e. these are MACs) — confirm for the installed version.
flop = FlopCountAnalysis(model, inputs)
print(flop_count_table(flop, max_depth=4))
print(flop_count_str(flop))
print(flop.total())

## diff vit pruning method
| module                           | #parameters or shape   | #flops      |
|:---------------------------------|:-----------------------|:------------|
| model                            | 10.305M                | 1.379G      |
|  cls_token                       |  (1, 1, 256)           |             |
|  pos_embed                       |  (1, 197, 256)         |             |
|  patch_embed.proj                |  0.197M                |  38.535M    |
|   patch_embed.proj.weight        |   (256, 3, 16, 16)     |             |
|   patch_embed.proj.bias          |   (256,)               |             |
|  blocks                          |  9.477M                |  1.294G     |
|   blocks.0                       |   0.79M                |   0.175G    |
|    blocks.0.norm1                |    0.512K              |    0.252M   |
|     blocks.0.norm1.weight        |     (256,)             |             |
|     blocks.0.norm1.bias          |     (256,)             |

In [7]:
# Testing DEIT-small: build DynamicViT (DeiT-small backbone: embed_dim=384,
# 6 heads), load its released weights, and count FLOPs on the example input.
# NOTE(review): this cell is repeated almost verbatim for each model size; a
# helper taking (embed_dim, num_heads, model_path, base_rate) would remove the copy-paste.
base_rate = 0.7
model_path = "models/dynamic-vit_384_r0.7.pth"

# Block indices passed as pruning_loc (exact semantics live in vit.py — confirm).
PRUNING_LOC = [3, 6, 9]
# One keep ratio per pruning location, compounding at each stage. Deriving the
# list from PRUNING_LOC keeps the two the same length if the locations change;
# for three locations this equals [base_rate, base_rate ** 2, base_rate ** 3].
KEEP_RATE = [base_rate ** stage for stage in range(1, len(PRUNING_LOC) + 1)]

model = VisionTransformerDiffPruning(
    patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
    pruning_loc=PRUNING_LOC, token_ratio=KEEP_RATE,
)

# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which may
# reject checkpoints that pickle non-tensor objects. If loading fails, pass
# weights_only=False, but only for trusted files.
checkpoint = torch.load(model_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model = model.eval()

# NOTE(review): fvcore's documented convention counts one fused multiply-add as
# one "flop" (i.e. these are MACs) — confirm for the installed version.
flop = FlopCountAnalysis(model, inputs)
print(flop_count_table(flop, max_depth=4))
print(flop_count_str(flop))
print(flop.total())

## diff vit pruning method
| module                           | #parameters or shape   | #flops      |
|:---------------------------------|:-----------------------|:------------|
| model                            | 22.774M                | 2.988G      |
|  cls_token                       |  (1, 1, 384)           |             |
|  pos_embed                       |  (1, 197, 384)         |             |
|  patch_embed.proj                |  0.295M                |  57.803M    |
|   patch_embed.proj.weight        |   (384, 3, 16, 16)     |             |
|   patch_embed.proj.bias          |   (384,)               |             |
|  blocks                          |  21.294M               |  2.826G     |
|   blocks.0                       |   1.774M               |   0.379G    |
|    blocks.0.norm1                |    0.768K              |    0.378M   |
|     blocks.0.norm1.weight        |     (384,)             |             |
|     blocks.0.norm1.bias          |     (384,)             |