In [1]:
import torch
import pytorch_cifar_models
import torch.utils.data as data
import torchvision.transforms as T
from torchvision.datasets import CIFAR10, CIFAR100
from torchmetrics.classification import Accuracy
from common import *
from collections import OrderedDict
# import onnx

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
delta = 0.005
i = 0.005
max_sparsity = 0.24
mode = 'inf'
model_name = 'cifar100_vgg19_bn'

In [3]:
def get_val_transforms(mean, std):
    return T.Compose([
        T.ToTensor(),
        T.Normalize(mean, std)
    ])
    
# CIFAR-10
# val_set = CIFAR10('./data', train=False, download=False, transform=get_val_transforms(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]))
# train_set = CIFAR10('./data', train=True, download=False)

# CIFAR-100
val_set = CIFAR100('./data', train=False, download=False, transform=get_val_transforms(mean=[0.5070, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761]))
train_set = CIFAR100('./data', train=True, download=False)

val_loader = data.DataLoader(val_set, batch_size=256, shuffle=False)
val_loader

<torch.utils.data.dataloader.DataLoader at 0x7fe87c139bb0>

In [4]:
dt = Profile()
with dt:
    model = getattr(pytorch_cifar_models, model_name)(pretrained=True)
# model.eval()
dt.t * 1E3

311.8481636047363

In [5]:
print(model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256

In [6]:
old_total = get_num_parameters(model)

In [7]:
dependencies = get_dependency_graph(model, model_name)
residual_dependencies = get_residual_dependency(model, model_name)
parameters_to_prune = get_parameters_to_prune(model, model_name)
name_to_module = get_name_to_module(model)

# global_smallest_filter(parameters_to_prune, 0.2, mode)
while i < max_sparsity:
    print(f'Sparsity={i}')
    global_smallest_filter(parameters_to_prune, i, mode)
    for key in dependencies:
        mod = name_to_module[key]
        mod_dep = name_to_module[dependencies[key]]
        prune_kernel2(mod, mod_dep)
    for key in residual_dependencies:
        mod = name_to_module[key] 
        mod_dep = name_to_module[residual_dependencies[key]]
        prune_residual_filter(mod, mod_dep)
    i += delta
print(get_num_pruned_parameters(parameters_to_prune))
print(get_num_parameters(model))
# print(validate(val_loader, model))
# (30.656659364700317, 3.065665936470032, tensor(0.6995), tensor(0.8420))

Sparsity=0.005
Sparsity=0.01


  torch.range(0, module.weight.shape[0] - 1).to(module.weight.device), # filter idx


Sparsity=0.015
Sparsity=0.02
Sparsity=0.025
Sparsity=0.030000000000000002
Sparsity=0.035
Sparsity=0.04
Sparsity=0.045
Sparsity=0.049999999999999996
Sparsity=0.05499999999999999
Sparsity=0.05999999999999999
Sparsity=0.06499999999999999
Sparsity=0.06999999999999999
Sparsity=0.075
Sparsity=0.08
Sparsity=0.085
Sparsity=0.09000000000000001
Sparsity=0.09500000000000001
Sparsity=0.10000000000000002
Sparsity=0.10500000000000002
Sparsity=0.11000000000000003
Sparsity=0.11500000000000003
Sparsity=0.12000000000000004
Sparsity=0.12500000000000003
Sparsity=0.13000000000000003
Sparsity=0.13500000000000004
Sparsity=0.14000000000000004
Sparsity=0.14500000000000005
Sparsity=0.15000000000000005
Sparsity=0.15500000000000005
Sparsity=0.16000000000000006
Sparsity=0.16500000000000006
Sparsity=0.17000000000000007
Sparsity=0.17500000000000007
Sparsity=0.18000000000000008
Sparsity=0.18500000000000008
Sparsity=0.19000000000000009
Sparsity=0.1950000000000001
Sparsity=0.2000000000000001
Sparsity=0.2050000000000001

In [8]:
def fuse_conv_and_bn(conv, bn):
    # Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # Prepare filters
    # print(f'before {fusedconv.weight.is_leaf}')
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.data = (torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
    # print(f'after {fusedconv.weight.is_leaf}')

    # Prepare spatial bias
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.data = (torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
    
    # Prune fused layer
    prune.custom_from_mask(fusedconv, 'weight', mask=conv.weight_mask.data)
    bias_mask = getattr(conv, 'bias_mask', torch.ones_like(fusedconv.bias))
    prune.custom_from_mask(fusedconv, 'bias', mask=bias_mask)
    # prune.remove(fusedconv, 'weight')
    # prune.remove(fusedconv, 'bias')

    return fusedconv

In [9]:
for i in range(len(model.features)):
    if isinstance(model.features[i], nn.Conv2d) and isinstance(model.features[i + 1], nn.BatchNorm2d):
        model.features[i] = fuse_conv_and_bn(model.features[i], model.features[i + 1])
        model.features[i].requires_grad = False
        model.features[i + 1] = nn.Identity()

In [10]:
# model.eval()

In [11]:
get_num_pruned_parameters(parameters_to_prune) / get_num_parameters(model)

tensor(0.3488)

In [12]:
model.eval()
# validate(val_loader, model)
# (30.52274179458618, 3.0522741794586183, tensor(0.7004), tensor(0.8506))

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Identity()
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): Identity()
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): Identity()
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): Identity()
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): Identity()
    (16): ReLU(inplace=True)
    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): Identity()
    (19): ReLU(inplace=True)
    (20): Conv2d(256, 256, kernel_size=(3, 3), 

In [13]:
i_s = []
for i in range(len(model.features)):
    if isinstance(model.features[i], nn.Conv2d):
        i_s.append(i)
        
profiler = Profile()
indexing_time = Profile()
with profiler:
    for i in i_s:
        conv = model.features[i]
        with indexing_time:
            kernel_indices = get_kernel_indices(conv, pruned=False)
            filter_indices = get_filter_indices(conv, pruned=False)
        
        weight = conv.weight.data
        bias = conv.bias.data
        stride = conv.stride
        padding = conv.padding
        kernel_size = conv.kernel_size
        
        model.features[i] = nn.Conv2d(kernel_indices.size()[0], filter_indices.size()[0], kernel_size, stride, padding)
        model.features[i].weight.data = weight[filter_indices, :, :, :][:, kernel_indices, :, :]
        model.features[i].bias.data = bias[filter_indices]
    
    # i = i_s[-1]
    # conv = model.features[i]
    # kernel_indices = get_kernel_indices(conv, pruned=False)
    # filter_indices = get_filter_indices(conv, pruned=False)

    # weight = conv.weight.data
    # bias = conv.bias.data
    # stride = conv.stride
    # padding = conv.padding
    # kernel_size = conv.kernel_size

    # model.features[i] = nn.Conv2d(kernel_indices.size()[0], conv.out_channels, kernel_size, stride, padding)
    # model.features[i].weight.data = weight[:, kernel_indices, :, :]
    # model.features[i].bias.data = bias
    
    fc = model.classifier[0]
    
    weight = fc.weight.data
    bias = fc.bias.data
    out_features = fc.out_features
    
    model.classifier[0] = nn.Linear(filter_indices.size()[0], fc.out_features, bias=True)
    model.classifier[0].weight.data = weight[:, filter_indices]
    model.classifier[0].bias.data = bias
(profiler.t - indexing_time.t) * 1E3    

171.5080738067627

In [14]:
model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Identity()
    (2): ReLU(inplace=True)
    (3): Conv2d(62, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): Identity()
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): Identity()
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): Identity()
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): Identity()
    (16): ReLU(inplace=True)
    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): Identity()
    (19): ReLU(inplace=True)
    (20): Conv2d(256, 256, kernel_size=(3, 3), 

In [15]:
import copy

features = nn.Sequential()
for name, module in model.features.named_modules():
    if isinstance(module, nn.Conv2d) or isinstance(module, nn.MaxPool2d) or isinstance(module, nn.ReLU):
        features.add_module(name, copy.deepcopy(module))
        
classifier = copy.deepcopy(model.classifier)

new_model = nn.Sequential(features, nn.Flatten(), classifier)

In [31]:
new_model

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 62, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU(inplace=True)
    (3): Conv2d(62, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (16): ReLU(inplace=True)
    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): ReLU(inplace=True)
    (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), p

In [25]:
new_model.eval()
validate(val_loader, new_model)

100%|██████████| 40/40 [00:38<00:00,  1.04it/s]


(35.002033710479736, 3.5002033710479736, tensor(0.7092), tensor(0.8642))

In [27]:
dummy_input = torch.rand((1, 3, 32, 32))
input_names = ['input']
output_names = ['output']
onnx_file = f'{model_name}_reduced.onnx'
torch.onnx.export(model, dummy_input, onnx_file, input_names=input_names, output_names=output_names,
                    dynamic_axes={
                        'input': {0: 'batch_size'}
                    })

[W NNPACK.cpp:51] Could not initialize NNPACK! Reason: Unsupported hardware.


In [30]:
import onnxruntime as ort 
onnx_model = ort.InferenceSession(f'{model_name}_reduced.onnx', providers=['CPUExecutionProvider'])
validate_onnx(val_loader, onnx_model)

100%|██████████| 40/40 [00:45<00:00,  1.14s/it]


(42.48503255844116, 4.248503255844117, tensor(0.7092), tensor(0.8642))

In [29]:
onnx_model_orig = ort.InferenceSession(f'{model_name}.onnx', providers=['CPUExecutionProvider'])
validate_onnx(val_loader, onnx_model)

100%|██████████| 40/40 [00:47<00:00,  1.19s/it]


(44.22174143791199, 4.4221741437911986, tensor(0.7092), tensor(0.8642))

In [24]:
validate(val_loader, model)

100%|██████████| 40/40 [00:40<00:00,  1.01s/it]


(37.082294940948486, 3.708229494094849, tensor(0.7092), tensor(0.8642))

In [26]:
model_orig = getattr(pytorch_cifar_models, model_name)(pretrained=True)
validate(val_loader, model_orig)

100%|██████████| 40/40 [00:38<00:00,  1.03it/s]


(35.42549657821655, 3.542549657821655, tensor(0.7043), tensor(0.8642))

In [16]:
# validate(val_loader, model)

In [17]:
new_total = get_num_parameters(model)

In [18]:
1.0 - (new_total / old_total)

0.3512970403071919

In [19]:
new_total

13371068

In [20]:
old_total

20612004