# Init

In [27]:
import torch
import timm, tome
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

In [28]:
imagenet_data = datasets.CIFAR100('/tmp', train = False, download = True,
                                  transform=transforms.Compose([
                                      transforms.Resize(256),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                                  ]))
subset = torch.utils.data.Subset(imagenet_data, [i for i in range(1000)])

Files already downloaded and verified


In [29]:
data_loader = DataLoader(subset, batch_size=1, shuffle=False)

In [30]:
# Hàm đánh giá mô hình
def evaluate(model, data_loader, flag):
    total = 0
    correct = 0
    count = 0
    with torch.no_grad():
        for images, labels in tqdm(data_loader):
            images, labels = images.cuda(), labels.cuda()
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            count = count + 1
            if flag and count == 1:
                return 'debug'
    return 100 * correct / total

# Default Model

In [31]:
model = timm.create_model("deit_tiny_patch16_224.fb_in1k", pretrained=True, num_classes= 100)
model.load_state_dict(torch.load('checkpoints/deit_tiny_patch16_224.fb_in1k_cifar100.bin'))

<All keys matched successfully>

In [32]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

print('Done')

Done


In [33]:
evaluate(model, data_loader, 0)

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:37<00:00, 26.44it/s]


74.3

# Tome Model

In [34]:
r = [4, 8, 12, 16, 20]
source_types = ['x_atnn', 'metric', 'x+x_attn', 'x_output']
concat_score = []
for source_type in source_types:
    for i in r:
        tome.patch.timm(model)
        model.r = i
        #model.method = ['pruned', 'pruned', 'pruned', 'pruned', 'pruned', 'pruned', 'tofu', 'tofu', 'tofu', 'tofu', 'tofu', 'tofu']
        model.source_type = source_type
        concat_score.append({'r' : i, 
                             'input_source' : source_type,
                             'score' : evaluate(model, data_loader, 0)})

100%|██████████| 1000/1000 [00:39<00:00, 25.05it/s]
100%|██████████| 1000/1000 [00:40<00:00, 24.64it/s]
100%|██████████| 1000/1000 [00:40<00:00, 24.66it/s]
100%|██████████| 1000/1000 [00:41<00:00, 24.14it/s]
100%|██████████| 1000/1000 [00:40<00:00, 24.90it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.18it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.33it/s]
100%|██████████| 1000/1000 [00:48<00:00, 20.44it/s]
100%|██████████| 1000/1000 [00:48<00:00, 20.47it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.40it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.02it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.04it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.08it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.29it/s]
100%|██████████| 1000/1000 [00:48<00:00, 20.63it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.35it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.35it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.40it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.29it/s]
100%|███████

In [35]:
import pandas as pd 
score = pd.DataFrame(concat_score)

In [36]:
score

Unnamed: 0,r,input_source,score
0,4,x_atnn,74.3
1,8,x_atnn,74.3
2,12,x_atnn,74.3
3,16,x_atnn,74.3
4,20,x_atnn,74.3
5,4,metric,74.0
6,8,metric,73.7
7,12,metric,73.4
8,16,metric,72.8
9,20,metric,69.0


In [39]:
model = timm.create_model("deit3_large_patch16_224", pretrained=True, num_classes= 100)

In [40]:
model

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): 