In [15]:
!git clone https://github.com/jeonsworld/ViT-pytorch.git
%cd ViT-pytorch/
!pip install ml-collections

Cloning into 'ViT-pytorch'...
remote: Enumerating objects: 170, done.[K
remote: Counting objects: 100% (40/40), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 170 (delta 32), reused 27 (delta 27), pack-reused 130[K
Receiving objects: 100% (170/170), 21.31 MiB | 18.92 MiB/s, done.
Resolving deltas: 100% (85/85), done.
/content/ViT-pytorch/ViT-pytorch/ViT-pytorch


In [2]:
#if running on a local divice, comment these lines
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
from __future__ import absolute_import, division, print_function

import logging
import argparse
import os
import random
import numpy as np

from datetime import timedelta

import torch
import torch.distributed as dist

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from utils.scheduler import WarmupLinearSchedule, WarmupCosineSchedule
from utils.data_utils import get_loader
from utils.dist_util import get_world_size

parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--name", default="cifar100",
                    help="Name of this run. Used for monitoring.")
parser.add_argument("--dataset", choices=["cifar10", "cifar100"], default="cifar100",
                    help="Which downstream task.")
parser.add_argument("--model_type", choices=["ViT-B_16", "ViT-B_32", "ViT-L_16",
                                              "ViT-L_32", "ViT-H_14", "R50-ViT-B_16"],
                    default="ViT-B_16",
                    help="Which variant to use.")
parser.add_argument("--pretrained_dir", type=str, default="checkpoint/ViT-B_16.npz",
                    help="Where to search for pretrained ViT models.")
parser.add_argument("--pretrained_model", type=str, default="cifar100_checkpoint.bin")
parser.add_argument("--output_dir", default="output", type=str,
                    help="The output directory where checkpoints will be written.")

parser.add_argument("--img_size", default=224, type=int,
                    help="Resolution size")
parser.add_argument("--train_batch_size", default=512, type=int,
                    help="Total batch size for training.")
parser.add_argument("--eval_batch_size", default=64, type=int,
                    help="Total batch size for eval.")
parser.add_argument("--eval_every", default=100, type=int,
                    help="Run prediction on validation set every so many steps."
                          "Will always run one evaluation at the end of training.")

parser.add_argument("--learning_rate", default=3e-2, type=float,
                    help="The initial learning rate for SGD.")
parser.add_argument("--weight_decay", default=0, type=float,
                    help="Weight deay if we apply some.")
parser.add_argument("--num_steps", default=10000, type=int,
                    help="Total number of training epochs to perform.")
parser.add_argument("--decay_type", choices=["cosine", "linear"], default="cosine",
                    help="How to decay the learning rate.")
parser.add_argument("--warmup_steps", default=500, type=int,
                    help="Step of training to perform learning rate warmup for.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
                    help="Max gradient norm.")

parser.add_argument("--local_rank", type=int, default=-1,
                    help="local_rank for distributed training on gpus")
parser.add_argument('--seed', type=int, default=42,
                    help="random seed for initialization")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                    help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument('--fp16', action='store_true',
                    help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O2',
                    help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                          "See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument('--loss_scale', type=float, default=0,
                    help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                          "0 (default value): dynamic loss scaling.\n"
                          "Positive power of 2: static loss scaling value.\n")
args = parser.parse_args("")

# Setup CUDA, GPU & distributed training
if args.local_rank == -1:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.n_gpu = torch.cuda.device_count()
else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    torch.distributed.init_process_group(backend='nccl',
                                          timeout=timedelta(minutes=60))
    args.n_gpu = 1

args.name = 'cifar100'
args.device = device

In [4]:
import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from torch.quantization import QuantStub, DeQuantStub
from scipy import ndimage

import models.configs as configs
from models.modeling import *

# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs

from models.modeling_resnet import ResNetV2


logger = logging.getLogger(__name__)


ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Attention(nn.Module):
    def __init__(self, config, vis, is_quant=False):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.is_quant = is_quant
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        if self.is_quant:
          query_layer = self.dequant(query_layer)
          key_layer = self.dequant(key_layer)
          value_layer = self.dequant(value_layer)
        
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        if self.is_quant:
          context_layer = self.quant(context_layer)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights


class Mlp(nn.Module):
    def __init__(self, config, is_quant=False):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()
        self.is_quant = is_quant
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        if self.is_quant:
          x = self.dequant(x)
        x = self.act_fn(x)
        if self.is_quant:
          x = self.quant(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.
    """
    def __init__(self, config, img_size, in_channels=3, is_quant=False):
        super(Embeddings, self).__init__()
        self.hybrid = None
        self.is_quant = is_quant
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                         width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)

        if self.hybrid:
            x = self.hybrid_model(x)
        x = self.patch_embeddings(x)
        if self.is_quant:
          x = self.dequant(x)
        x = x.flatten(2)
        x = x.transpose(-1, -2)
        x = torch.cat((cls_tokens, x), dim=1)

        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        if self.is_quant:
          embeddings = self.quant(embeddings)
        return embeddings


class Block(nn.Module):
    def __init__(self, config, vis, is_quant=False):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config, is_quant=is_quant)
        self.attn = Attention(config, vis, is_quant=is_quant)

        self.is_quant = is_quant
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        if self.is_quant:
          x = self.dequant(x)
          h = self.dequant(h)
        x = x + h
        if self.is_quant:
          x = self.quant(x)

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        if self.is_quant:
          x = self.dequant(x)
          h = self.dequant(h)
        x = x + h
        if self.is_quant:
          x = self.quant(x)
        return x, weights


class Encoder(nn.Module):
    def __init__(self, config, vis, is_quant=False):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis, is_quant=is_quant)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights


class Transformer(nn.Module):
    def __init__(self, config, img_size, vis, is_quant=False):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size, is_quant=is_quant)
        self.encoder = Encoder(config, vis, is_quant=is_quant)

    def forward(self, input_ids):
        embedding_output = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)
        return encoded, attn_weights


class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False, is_quant=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier

        self.transformer = Transformer(config, img_size, vis, is_quant=is_quant)
        self.head = Linear(config.hidden_size, num_classes)

        self.is_quant = is_quant
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x, labels=None):
        if self.is_quant:
          x = self.quant(x)
        x, attn_weights = self.transformer(x)
        logits = self.head(x[:, 0])

        if self.is_quant:
          logits = self.dequant(logits)
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            return loss
        else:
            return logits, attn_weights


CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'testing': configs.get_testing(),
}


In [5]:
def simple_accuracy(preds, labels):
    return (preds == labels).mean()

def test(model, test_loader):
  model.eval()
  all_preds, all_label = [], []
  epoch_iterator = tqdm(test_loader,
                        desc="Validating... (loss=X.X)",
                        bar_format="{l_bar}{r_bar}",
                        dynamic_ncols=True,
                        disable=args.local_rank not in [-1, 0])
  loss_fct = torch.nn.CrossEntropyLoss()
  for step, batch in enumerate(epoch_iterator):
      batch = tuple(t.to(args.device) for t in batch)
      x, y = batch
      with torch.no_grad():
          logits = model(x)[0]

          eval_loss = loss_fct(logits, y)
          # eval_losses.update(eval_loss.item())

          preds = torch.argmax(logits, dim=-1)
      # assert False
      if len(all_preds) == 0:
          all_preds.append(preds.detach().cpu().numpy())
          all_label.append(y.detach().cpu().numpy())
      else:
          all_preds[0] = np.append(
              all_preds[0], preds.detach().cpu().numpy(), axis=0
          )
          all_label[0] = np.append(
              all_label[0], y.detach().cpu().numpy(), axis=0
          )
      # epoch_iterator.set_description("Validating... (loss=%2.5f)" % eval_losses.val)

  all_preds, all_label = all_preds[0], all_label[0]
  accuracy = simple_accuracy(all_preds, all_label)
  print("Valid Accuracy: %2.5f" % accuracy)

### Original Model

In [6]:
config = CONFIGS[args.model_type]
num_classes = 10 if args.dataset == "cifar10" else 100
model = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes)
checkpoint = torch.load(args.pretrained_model, map_location=args.device)
model.load_state_dict(checkpoint)
model.to(args.device)
train_loader, test_loader = get_loader(args)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified


In [21]:
test(model, test_loader)

Validating... (loss=X.X): 100%|| 157/157 [00:33<00:00,  4.65it/s]

Valid Accuracy: 0.92870





### Post Training Dynamic Quantization

In [7]:
import torch.quantization
quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
test(quantized_model, test_loader)

Validating... (loss=X.X): 100%|| 157/157 [21:22<00:00,  8.17s/it]

Valid Accuracy: 0.90870





### Post Training Static Quantization

In [8]:
model_quant = VisionTransformer(config, args.img_size, zero_head=True, num_classes=num_classes, is_quant=True)
checkpoint = torch.load(args.pretrained_model, map_location=args.device)
model_quant.load_state_dict(checkpoint)
model_quant.to(args.device)

VisionTransformer(
  (transformer): Transformer(
    (embeddings): Embeddings(
      (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      (dropout): Dropout(p=0.1, inplace=False)
      (quant): QuantStub()
      (dequant): DeQuantStub()
    )
    (encoder): Encoder(
      (layer): ModuleList(
        (0): Block(
          (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (ffn): Mlp(
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (quant): QuantStub()
            (dequant): DeQuantStub()
          )
          (attn): Attention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Li

In [9]:
static_quantized_model = model_quant
static_quantized_model.qconfig = torch.quantization.default_qconfig

static_quantized_model_prepared = torch.quantization.prepare(static_quantized_model)
static_quantized_model_prepared.eval()
epoch_iterator = tqdm(test_loader,
                      bar_format="{l_bar}{r_bar}",
                      dynamic_ncols=True,
                      disable=args.local_rank not in [-1, 0])

for step, batch in enumerate(epoch_iterator):
    batch = tuple(t.to(args.device) for t in batch)
    x, y = batch
    static_quantized_model_prepared(x)


model_prepared_int8 = torch.quantization.convert(static_quantized_model_prepared)

  reduce_range will be deprecated in a future release of PyTorch."
100%|| 157/157 [32:12<00:00, 12.31s/it]


In [10]:
test(model_prepared_int8, test_loader)

Validating... (loss=X.X): 100%|| 157/157 [23:44<00:00,  9.08s/it]

Valid Accuracy: 0.00980





### Convert Model Param

In [13]:
new_checkpoint = {}
qkv_weight = None
qkv_bias = None

new_checkpoint['cls_token'] = checkpoint['transformer.embeddings.cls_token']
new_checkpoint['pos_embed'] = checkpoint['transformer.embeddings.position_embeddings']
new_checkpoint['patch_embed.proj.weight'] = checkpoint['transformer.embeddings.patch_embeddings.weight']
new_checkpoint['patch_embed.proj.bias'] = checkpoint['transformer.embeddings.patch_embeddings.bias']

for key in checkpoint.keys():
  if 'transformer.encoder.layer' in key:
    new_key = key.replace('transformer.encoder.layer', 'blocks')
    if 'attn.query' in new_key:
      if 'weight' in new_key:
        qkv_weight = checkpoint[key]
      else:
        qkv_bias = checkpoint[key]
    elif 'attn.key' in new_key:
      if 'weight' in new_key:
        qkv_weight = torch.cat((qkv_weight, checkpoint[key]), dim=0)
      else:
        qkv_bias = torch.cat((qkv_bias, checkpoint[key]), dim=0)
    elif 'attn.value' in new_key:
      if 'weight' in new_key:
        qkv_weight = torch.cat((qkv_weight, checkpoint[key]), dim=0)
        new_key = new_key.replace('attn.value', 'attn.qkv')
        new_checkpoint[new_key] = qkv_weight
        qkv_weight = None
      else:
        qkv_bias = torch.cat((qkv_bias, checkpoint[key]), dim=0)
        new_key = new_key.replace('attn.value', 'attn.qkv')
        new_checkpoint[new_key] = qkv_bias
        qkv_bias = None
    else:
      new_key = new_key.replace('attn.out.', 'attn.proj.')
      new_key = new_key.replace('attention_norm', 'norm1')
      new_key = new_key.replace('ffn_norm', 'norm2')
      new_key = new_key.replace('ffn', 'mlp')
      new_checkpoint[new_key] = checkpoint[key]

new_checkpoint['norm.weight'] = checkpoint['transformer.encoder.encoder_norm.weight']
new_checkpoint['norm.bias'] = checkpoint['transformer.encoder.encoder_norm.bias']
new_checkpoint['head.weight'] = checkpoint['head.weight']
new_checkpoint['head.bias'] = checkpoint['head.bias']

model_to_save = new_checkpoint
checkpoint_dir = os.path.join(args.output_dir, "new_checkpoint.bin")
torch.save(model_to_save, checkpoint_dir)
logger.info("Saved model checkpoint to [DIR: %s]", args.output_dir)
