In [1]:
!git clone https://github.com/franklinnwren/ViT-pytorch.git
%cd ViT-pytorch/
!pip install ml-collections

Cloning into 'ViT-pytorch'...
remote: Enumerating objects: 187, done.[K
remote: Counting objects: 100% (57/57), done.[K
remote: Compressing objects: 100% (30/30), done.[K
remote: Total 187 (delta 42), reused 27 (delta 27), pack-reused 130[K
Receiving objects: 100% (187/187), 21.31 MiB | 16.26 MiB/s, done.
Resolving deltas: 100% (95/95), done.
/content/ViT-pytorch
Collecting ml-collections
  Downloading ml_collections-0.1.1.tar.gz (77 kB)
[K     |████████████████████████████████| 77 kB 3.9 MB/s 
Building wheels for collected packages: ml-collections
  Building wheel for ml-collections (setup.py) ... [?25l[?25hdone
  Created wheel for ml-collections: filename=ml_collections-0.1.1-py3-none-any.whl size=94524 sha256=7a474a1b174c7b6366fad57f71a271bf02c93cfc8817d06d7ca2fef9228ffc11
  Stored in directory: /root/.cache/pip/wheels/b7/da/64/33c926a1b10ff19791081b705879561b715a8341a856a3bbd2
Successfully built ml-collections
Installing collected packages: ml-collections
Successfully instal

In [5]:
!pip install einops

Collecting einops
  Downloading einops-0.4.0-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.0


In [6]:
#if running on a local divice, comment these lines
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
from __future__ import absolute_import, division, print_function

import logging
import argparse
import os
import random
import numpy as np

from datetime import timedelta

import torch
import torch.distributed as dist

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from models.modeling import VisionTransformer, CONFIGS
from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
from utils.data_utils import get_loader
from utils.dist_util import get_world_size

parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--name", default="cifar10",
                    help="Name of this run. Used for monitoring.")
parser.add_argument("--dataset", choices=["cifar10", "cifar100"], default="cifar10",
                    help="Which downstream task.")
parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16",
                                              "ViT-L_32", "ViT-H_14", "R50-ViT-B_16"],
                    default="ViT-B_16",
                    help="Which variant to use.")
parser.add_argument("--pretrained_dir", type=str, default="ViT-B_16.npz",
                    help="Where to search for pretrained ViT models.")
parser.add_argument("--pretrained_model", type=str, default="../drive/MyDrive/cifar10-100_500_checkpoint_prune.bin")
parser.add_argument("--output_dir", default="output", type=str,
                    help="The output directory where checkpoints will be written.")

parser.add_argument("--img_size", default=224, type=int,
                    help="Resolution size")
parser.add_argument("--train_batch_size", default=512, type=int,
                    help="Total batch size for training.")
parser.add_argument("--eval_batch_size", default=64, type=int,
                    help="Total batch size for eval.")
parser.add_argument("--eval_every", default=100, type=int,
                    help="Run prediction on validation set every so many steps."
                          "Will always run one evaluation at the end of training.")

parser.add_argument("--learning_rate", default=3e-2, type=float,
                    help="The initial learning rate for SGD.")
parser.add_argument("--weight_decay", default=0, type=float,
                    help="Weight deay if we apply some.")
parser.add_argument("--num_steps", default=10000, type=int,
                    help="Total number of training epochs to perform.")
parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
                    help="How to decay the learning rate.")
parser.add_argument("--warmup_steps", default=500, type=int,
                    help="Step of training to perform learning rate warmup for.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
                    help="Max gradient norm.")

parser.add_argument("--local_rank", type=int, default=-1,
                    help="local_rank for distributed training on gpus")
parser.add_argument('--seed', type=int, default=42,
                    help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                    help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--fp16', action='store_true',
                    help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O2',
                    help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                          "See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument('--loss_scale', type=float, default=0,
                    help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                          "0 (default value): dynamic loss scaling.\n"
                          "Positive power of 2: static loss scaling value.\n")
args = parser.parse_args("")

# Setup CUDA, GPU & distributed training
if args.local_rank == -1:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    torch.distributed.init_process_group(backend='nccl',
                                          timeout=timedelta(minutes=60))
    args.n_gpu = 1

args.name = 'cifar10'
args.device = device

config = CONFIGS[args.model_type]
num_classes = 10 if args.dataset == "cifar10" else 100
model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes)
model.load_state_dict(torch.load(args.pretrained_model))
model.to(args.device)
train_loader, test_loader = get_loader(args)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [9]:
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

total = 0
for name, m in model.named_modules():
    if 'select' in name:
        print(m)
        total += m.indexes.data.shape[0]

print("pause")

bn = torch.zeros(total)
index = 0
for name, m in model.named_modules():
    if 'select' in name:
        print(m)
        size = m.indexes.data.shape[0]
        bn[index:(index+size)] = m.indexes.data.abs().clone()
        index += size

channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
pause
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()
channel_selection()


In [10]:
percent = 0.1
y, i = torch.sort(bn)
thre_index = int(total * percent)
thre = y[thre_index]

# print(thre)

pruned = 0
cfg = []
cfg_mask = []
for k, (name, m) in enumerate(model.named_modules()):
    if 'select' in name:
        # print(k)
        # print(m)
        if k not in [14,31,48,65,82,99,116,133,150,167,184,201]:
            weight_copy = m.indexes.data.abs().clone()
            mask = weight_copy.gt(thre).float().cuda()
            thre_ = thre.clone()
            while (torch.sum(mask) % 8 != 0): # heads
                thre_ = thre_ - 0.001
                mask = weight_copy.gt(thre_).float().cuda()
        else:
            weight_copy = m.indexes.data.abs().clone()
            mask = weight_copy.gt(thre).float().cuda()
        pruned = pruned + mask.shape[0] - torch.sum(mask)
        m.indexes.data.mul_(mask)
        cfg.append(int(torch.sum(mask)))
        cfg_mask.append(mask.clone())
        print('layer index: {:d} \t total channel: {:d} \t remaining channel: {:d}'.
            format(k, mask.shape[0], int(torch.sum(mask))))

layer index: 14 	 total channel: 3072 	 remaining channel: 2699
layer index: 19 	 total channel: 768 	 remaining channel: 448
layer index: 31 	 total channel: 3072 	 remaining channel: 2807
layer index: 36 	 total channel: 768 	 remaining channel: 664
layer index: 48 	 total channel: 3072 	 remaining channel: 2872
layer index: 53 	 total channel: 768 	 remaining channel: 432
layer index: 65 	 total channel: 3072 	 remaining channel: 2925
layer index: 70 	 total channel: 768 	 remaining channel: 488
layer index: 82 	 total channel: 3072 	 remaining channel: 2937
layer index: 87 	 total channel: 768 	 remaining channel: 568
layer index: 99 	 total channel: 3072 	 remaining channel: 2966
layer index: 104 	 total channel: 768 	 remaining channel: 528
layer index: 116 	 total channel: 3072 	 remaining channel: 2991
layer index: 121 	 total channel: 768 	 remaining channel: 672
layer index: 133 	 total channel: 3072 	 remaining channel: 3017
layer index: 138 	 total channel: 768 	 remaining 

In [11]:
def simple_accuracy(preds, labels):
    return (preds == labels).mean()

def test(model, test_loader):
  model.eval()
  all_preds, all_label = [], []
  epoch_iterator = tqdm(test_loader,
                        desc="Validating... (loss=X.X)",
                        bar_format="{l_bar}{r_bar}",
                        dynamic_ncols=True,
                        disable=args.local_rank not in [-1, 0])
  loss_fct = torch.nn.CrossEntropyLoss()
  for step, batch in enumerate(epoch_iterator):
      batch = tuple(t.to(args.device) for t in batch)
      x, y = batch
      with torch.no_grad():
          logits = model(x)[0]

          eval_loss = loss_fct(logits, y)
          # eval_losses.update(eval_loss.item())

          preds = torch.argmax(logits, dim=-1)

      if len(all_preds) == 0:
          all_preds.append(preds.detach().cpu().numpy())
          all_label.append(y.detach().cpu().numpy())
      else:
          all_preds[0] = np.append(
              all_preds[0], preds.detach().cpu().numpy(), axis=0
          )
          all_label[0] = np.append(
              all_label[0], y.detach().cpu().numpy(), axis=0
          )
      # epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)

  all_preds, all_label = all_preds[0], all_label[0]
  accuracy = simple_accuracy(all_preds, all_label)
  print("Valid Accuracy: %2.5f" % accuracy)

In [12]:
test(model, test_loader)

Validating... (loss=X.X): 100%|| 157/157 [00:35<00:00,  4.44it/s]

Valid Accuracy: 0.88360



