# 0. Libraries

In [1]:
import torch
from torchvision.models.vgg import vgg16
from torchvision.models.feature_extraction import create_feature_extractor

import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import numpy as np

# paper reference: "https://arxiv.org/abs/1906.01493"
# calculation reference: "https://www.baeldung.com/cs/pca"


  from .autonotebook import tqdm as notebook_tqdm


# 1. Functions

In [2]:
def compute_PCA(feature, threshold = 0.95, status_print = False):

    total_channel = feature.shape[1]

    activations = (feature.data).cpu().numpy()
    # print('shape of activations are:',activations.shape)
    a=activations.swapaxes(1,2).swapaxes(2,3)
    a_shape=a.shape
    # print('reshaped ativations are of shape',a.shape)
    # raw_input()

    pca = PCA() #number of components should be equal to the number of filters
    pca.fit(a.reshape(a_shape[0]*a_shape[1]*a_shape[2],a_shape[3]))
    a_trans=pca.transform(a.reshape(a_shape[0]*a_shape[1]*a_shape[2],a_shape[3]))
    # print('explained variance ratio is:',pca.explained_variance_ratio_)
    # raw_input()
    cumsum = np.cumsum(pca.explained_variance_ratio_)
    d = np.argmax(cumsum >= threshold)
    
    # print(cumsum.shape)
    # print("pca: ", pca.components_.shape)
    
    # importance_ratio = d/total_channel

    if status_print:
        print('need at least {} filter(s) out of {} components to exceed threshold'.format(d, total_channel))
    
    return d


# 2. define model & extraction point

In [3]:
model = vgg16(pretrained = True)
model.eval()

return_nodes = {
    "features.1": "layer1",
    "features.3": "layer2",
    "features.6": "layer3",
    "features.8": "layer4",
    "features.11": "layer5",
    "features.13": "layer6",
    "features.15": "layer7",
    "features.18": "layer8",
    "features.20": "layer9",
    "features.22": "layer10",
    "features.25": "layer11",
    "features.27": "layer12",
    "features.29": "layer13",
    "features.4": "layer14",
}

extractor_model = create_feature_extractor(model, return_nodes=return_nodes)



  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "


# 3. evaluate PCA per layer

In [4]:
x = torch.rand(1, 3, 224, 224)


intermediate_outputs = extractor_model(x)

# more important layer carries smaller pca_score
#pca_score가 작을수록 더 중요한 레이어 입니다

for k in intermediate_outputs.keys():
    print(k)
    feature = intermediate_outputs[k]
    pca_score = compute_PCA(feature, 0.80, status_print = True)
    # print(pca_score)
    # break
    



layer1
need at least 5 filter(s) out of 64 components to exceed threshold
layer2
need at least 12 filter(s) out of 64 components to exceed threshold
layer14
need at least 14 filter(s) out of 64 components to exceed threshold
layer3
need at least 25 filter(s) out of 128 components to exceed threshold
layer4
need at least 32 filter(s) out of 128 components to exceed threshold
layer5
need at least 51 filter(s) out of 256 components to exceed threshold
layer6
need at least 47 filter(s) out of 256 components to exceed threshold
layer7
need at least 32 filter(s) out of 256 components to exceed threshold
layer8
need at least 33 filter(s) out of 512 components to exceed threshold
layer9
need at least 21 filter(s) out of 512 components to exceed threshold
layer10
need at least 14 filter(s) out of 512 components to exceed threshold
layer11
need at least 9 filter(s) out of 512 components to exceed threshold
layer12
need at least 9 filter(s) out of 512 components to exceed threshold
layer13
need a

# 4. manual model(VGG) formulation

In [5]:
import torch.nn as nn
import torch

VGG_types = {
    'VGG11' : [64, 'M', 128, 'M', 256, 256, 'M', 512,512, 'M',512,512,'M'],
    'VGG13' : [64,64, 'M', 128, 128, 'M', 256, 256, 'M', 512,512, 'M', 512,512,'M'],
    'VGG16' : [64,64, 'M', 128, 128, 'M', 256, 256,256, 'M', 512,512,512, 'M',512,512,512,'M'],
    'VGG19' : [64,64, 'M', 128, 128, 'M', 256, 256,256,256, 'M', 512,512,512,512, 'M',512,512,512,512,'M']
}

class VGGnet(torch.nn.Module):
    def __init__(self, model, in_channels=3, num_classes=10, init_weights=True):
        super(VGGnet,self).__init__()
        
        self.quant = torch.ao.quantization.QuantStub()	# 입력을 양자화 하는 QuantStub()
        
        
        self.in_channels = in_channels

        # create conv_layers corresponding to VGG type
        self.conv_layers = self.create_conv_layers(VGG_types[model])

        self.fcs = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        self.dequant = torch.ao.quantization.DeQuantStub() # 출력을 역양자화 하는 DeQuantStub()
        # weight initialization
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.quant(x)
        x = self.conv_layers(x)
        x = self.dequant(x)
        x = x.view(-1, 512 * 7 * 7)
        x = self.fcs(x)
        return x

    # defint weight initialization function
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
    
    # define a function to create conv layer taken the key of VGG_type dict 
    def create_conv_layers(self, architecture):
        layers = []
        in_channels = self.in_channels # 3

        for x in architecture:
            if type(x) == int: # int means conv layer
                out_channels = x

                layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                     kernel_size=(3,3), stride=(1,1), padding=(1,1)),
                           nn.BatchNorm2d(x),
                           nn.ReLU()]
                in_channels = x
            elif x == 'M':
                layers += [nn.MaxPool2d(kernel_size=(2,2), stride=(2,2))]

        
        return nn.Sequential(*layers)

# define device
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print(device)

# creat VGGnet object
model = VGGnet('VGG16', in_channels=3, num_classes=10, init_weights=True)
print(model)
model.eval()
# model.conv_layers[0].qconfig = torch.ao.quantization.get_default_qconfig('x86')
# model_fp32_fused = torch.ao.quantization.fuse_modules(model, [['conv_layers']])
# model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)
# m = nn.quantized.Conv2d(512, 512, (3, 3), stride=(1, 1), padding=(1, 1))
# model.conv_layers[27] = m

# model.fcs[0].quantization

# x = torch.rand(1, 3, 224, 224)
# model(x)

VGGnet(
  (quant): QuantStub()
  (conv_layers): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU()
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU()
    (13): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 25

VGGnet(
  (quant): QuantStub()
  (conv_layers): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU()
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU()
    (13): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 25

# 5. sample quantization

In [6]:
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()

# model must be set to eval mode for static quantization logic to work
model_fp32.eval()

print(model_fp32)

model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')


model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])


model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)

input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)


model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

res = model_int8(input_fp32)

M(
  (quant): QuantStub()
  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  (relu): ReLU()
  (dequant): DeQuantStub()
)


In [7]:
model_fp32_prepared

M(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-2.575468063354492, max_val=1.9080582857131958)
  )
  (conv): ConvReLU2d(
    (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
    (activation_post_process): MinMaxObserver(min_val=0.0, max_val=1.9459654092788696)
  )
  (relu): Identity()
  (dequant): DeQuantStub()
)

In [8]:
model_fp32_prepared

M(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-2.575468063354492, max_val=1.9080582857131958)
  )
  (conv): ConvReLU2d(
    (0): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
    (activation_post_process): MinMaxObserver(min_val=0.0, max_val=1.9459654092788696)
  )
  (relu): Identity()
  (dequant): DeQuantStub()
)

In [9]:
model_fp32

M(
  (quant): QuantStub()
  (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  (relu): ReLU()
  (dequant): DeQuantStub()
)