In [1]:
import torch
import pytorch_cifar_models
import torch.utils.data as data
import torchvision.transforms as T
from torchvision.datasets import CIFAR10, CIFAR100
from torchmetrics.classification import Accuracy
from common import *
from collections import OrderedDict
# import onnx

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
delta = 0.005
i = 0.005
max_sparsity = 0.24
mode = 'inf'
model_name = 'cifar100_vgg19_bn'

In [3]:
def get_val_transforms(mean, std):
    return T.Compose([
        T.ToTensor(),
        T.Normalize(mean, std)
    ])
    
# CIFAR-10
# val_set = CIFAR10('./data', train=False, download=False, transform=get_val_transforms(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]))
# train_set = CIFAR10('./data', train=True, download=False)

# CIFAR-100
val_set = CIFAR100('./data', train=False, download=False, transform=get_val_transforms(mean=[0.5070, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761]))
train_set = CIFAR100('./data', train=True, download=False)

val_loader = data.DataLoader(val_set, batch_size=256, shuffle=False)
val_loader

<torch.utils.data.dataloader.DataLoader at 0x7f41982c8250>

In [4]:
dt = Profile()
with dt:
    model = getattr(pytorch_cifar_models, model_name)(pretrained=True)
# model.eval()
dt.t * 1E3

1083.9486122131348

In [5]:
print(model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256

In [81]:
old_total = get_num_parameters(model)

In [82]:
dependencies = get_dependency_graph(model, model_name)
residual_dependencies = get_residual_dependency(model, model_name)
parameters_to_prune = get_parameters_to_prune(model, model_name)
name_to_module = get_name_to_module(model)

# for mod, param in parameters_to_prune:
#     prune.ln_structured(mod, param, 0.2, float('inf'), 0)
# for key in dependencies:
#     mod = name_to_module[key]
#     mod_dep = name_to_module[dependencies[key]]
#     prune_kernel2(mod, mod_dep)
# for key in residual_dependencies:
#     mod = name_to_module[key] 
#     mod_dep = name_to_module[residual_dependencies[key]]
#     prune_residual_filter(mod, mod_dep)

# global_smallest_filter(parameters_to_prune, 0.2, mode)
while i < max_sparsity:
    print(f'Sparsity={i}')
    global_smallest_filter(parameters_to_prune, i, mode)
    for key in dependencies:
        mod = name_to_module[key]
        mod_dep = name_to_module[dependencies[key]]
        prune_kernel2(mod, mod_dep)
    for key in residual_dependencies:
        mod = name_to_module[key] 
        mod_dep = name_to_module[residual_dependencies[key]]
        prune_residual_filter(mod, mod_dep)
    i += delta
print(get_num_pruned_parameters(parameters_to_prune))
print(get_num_parameters(model))
# print(validate(val_loader, model))
# (30.656659364700317, 3.065665936470032, tensor(0.6995), tensor(0.8420))

Sparsity=0.005
Sparsity=0.01
Sparsity=0.015
Sparsity=0.02
Sparsity=0.025
Sparsity=0.030000000000000002
Sparsity=0.035
Sparsity=0.04
Sparsity=0.045
Sparsity=0.049999999999999996
Sparsity=0.05499999999999999
Sparsity=0.05999999999999999
Sparsity=0.06499999999999999
Sparsity=0.06999999999999999
Sparsity=0.075
Sparsity=0.08
Sparsity=0.085
Sparsity=0.09000000000000001
Sparsity=0.09500000000000001
Sparsity=0.10000000000000002
Sparsity=0.10500000000000002
Sparsity=0.11000000000000003
Sparsity=0.11500000000000003
Sparsity=0.12000000000000004
Sparsity=0.12500000000000003
Sparsity=0.13000000000000003
Sparsity=0.13500000000000004
Sparsity=0.14000000000000004
Sparsity=0.14500000000000005
Sparsity=0.15000000000000005
Sparsity=0.15500000000000005
Sparsity=0.16000000000000006
Sparsity=0.16500000000000006
Sparsity=0.17000000000000007
Sparsity=0.17500000000000007
Sparsity=0.18000000000000008
Sparsity=0.18500000000000008
Sparsity=0.19000000000000009
Sparsity=0.1950000000000001
Sparsity=0.200000000000000

In [83]:
def fuse_conv_and_bn(conv, bn):
    # Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    fusedconv = nn.Conv2d(conv.in_channels,
                          conv.out_channels,
                          kernel_size=conv.kernel_size,
                          stride=conv.stride,
                          padding=conv.padding,
                          dilation=conv.dilation,
                          groups=conv.groups,
                          bias=True).requires_grad_(False).to(conv.weight.device)

    # Prepare filters
    # print(f'before {fusedconv.weight.is_leaf}')
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.data = (torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
    # print(f'after {fusedconv.weight.is_leaf}')

    # Prepare spatial bias
    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.data = (torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
    
    # Prune fused layer
    prune.custom_from_mask(fusedconv, 'weight', mask=conv.weight_mask.data)
    bias_mask = getattr(conv, 'bias_mask', torch.ones_like(fusedconv.bias))
    prune.custom_from_mask(fusedconv, 'bias', mask=bias_mask)
    # prune.remove(fusedconv, 'weight')
    # prune.remove(fusedconv, 'bias')

    return fusedconv

In [84]:
for i in range(len(model.features)):
    if isinstance(model.features[i], nn.Conv2d) and isinstance(model.features[i + 1], nn.BatchNorm2d):
        model.features[i] = fuse_conv_and_bn(model.features[i], model.features[i + 1])
        model.features[i].requires_grad = False
        model.features[i + 1] = nn.Identity()

In [85]:
# model.eval()

In [86]:
get_num_pruned_parameters(parameters_to_prune) / get_num_parameters(model)

tensor(0.3488)

In [87]:
model.eval()
# validate(val_loader, model)
# (30.52274179458618, 3.0522741794586183, tensor(0.7004), tensor(0.8506))

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Identity()
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): Identity()
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): Identity()
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): Identity()
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): Identity()
    (16): ReLU(inplace=True)
    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): Identity()
    (19): ReLU(inplace=True)
    (20): Conv2d(256, 256, kernel_size=(3, 3), 

In [88]:
# Pruned architecture
pruned_architecture = {}

# Original architecture
original_architecture = {}

# Pruned architecture in tensor format
pruned_architecture_pt = torch.tensor([], dtype=torch.int32)

# Original architecture in tensor format
original_architecture_pt = torch.tensor([], dtype=torch.int32)

# Iterate over all conv layers
for name, module in model.named_modules():
    if not isinstance(module, nn.Conv2d):
        continue
    
    # Get the filter and kernel indices
    conv = module
    kernel_indices = get_kernel_indices(conv, pruned=False)
    filter_indices = get_filter_indices(conv, pruned=False)
    
    # Get the params
    weight = conv.weight.data
    bias = conv.bias.data
    stride = conv.stride
    padding = conv.padding
    kernel_size = conv.kernel_size
    
    # Create the new module
    new_mod = nn.Conv2d(kernel_indices.size()[0], filter_indices.size()[0], kernel_size, stride, padding)
    new_mod.weight.data = weight[filter_indices, :, :, :][:, kernel_indices, :, :]
    new_mod.bias.data = bias[filter_indices]
    
    # Store the original architecture of the layer
    layer = [
        f'model.{name}',
        module.in_channels,
        module.out_channels, 
        module.kernel_size[0],
        module.stride[0],
        module.padding[0],
        None,
        None,
    ]
    original_architecture[f'model.{name}'] = layer
    
    # Store the pruned architecture of the layer
    layer = [
        f'model.{name}',
        new_mod.in_channels,
        new_mod.out_channels,
        new_mod.kernel_size[0],
        new_mod.stride[0],
        new_mod.padding[0],
        filter_indices.tolist(),
        kernel_indices.tolist(),                
    ]
    pruned_architecture[f'model.{name}'] = layer 

In [89]:
# Initialize the names list
pruned_layer_names = []

# Convert to a torch friendly format
for i, p in enumerate(list(pruned_architecture.values())):
    # Insert the layer name into the names list
    pruned_layer_names.append(p[0])
    
    # Convert layer information into tensor
    args = torch.tensor([
        i,
        p[1], # in_channels
        p[2], # out_channels
        p[3], # kernel_size
        p[4], # stride
        p[5], # padding
        len(p[6]) if p[6] != None else 0, # Size of kernel indices list
        len(p[7]) if p[7]  != None else 0, # Size of filter indices list        
    ], dtype=torch.long)
    
    # Get the filter indices tensor
    kernel_indices = torch.tensor(p[6], dtype=torch.long) if p[6] != None else torch.tensor([], dtype=torch.long)
    
    # Get the kernel indices tensor
    filter_indices = torch.tensor(p[7], dtype=torch.long) if p[7] != None else torch.tensor([], dtype=torch.long)
    
    # Add the layer tensor to the architecture tensor
    p_pt = torch.hstack([args, kernel_indices, filter_indices])
    pruned_architecture_pt = torch.hstack([pruned_architecture_pt, p_pt])

# Initialize the names list
original_layer_names = []

# Convert to a torch friendly format
for i, p in enumerate(list(original_architecture.values())):
    # Insert the layer name into the names list
    original_layer_names.append(p[0])
    
    # Convert layer information into tensor
    args = torch.tensor([
        i,
        p[1], # in_channels
        p[2], # out_channels
        p[3], # kernel_size
        p[4], # stride
        p[5], # padding
        len(p[6]) if p[6] != None else 0, # Size of kernel indices list
        len(p[7]) if p[7]  != None else 0, # Size of filter indices list        
    ], dtype=torch.long)
    
    # Get the filter indices tensor
    kernel_indices = torch.tensor(p[6], dtype=torch.long) if p[6] != None else torch.tensor([], dtype=torch.long)
    
    # Get the kernel indices tensor
    filter_indices = torch.tensor(p[7], dtype=torch.long) if p[7] != None else torch.tensor([], dtype=torch.long)
    
    # Add the layer tensor to the architecture tensor
    p_pt = torch.hstack([args, kernel_indices, filter_indices])
    original_architecture_pt = torch.hstack([original_architecture_pt, p_pt])

In [90]:
torch.save({
    'data': pruned_architecture_pt,
    'names': pruned_layer_names
}, f'{model_name}_prun.pt')

torch.save({
    'data': original_architecture_pt,
    'names': original_layer_names
}, f'{model_name}_orig.pt')

In [13]:
i_s = []
for i in range(len(model.features)):
    if isinstance(model.features[i], nn.Conv2d):
        i_s.append(i)
        
profiler = Profile()
indexing_time = Profile()
with profiler:
    for i in i_s:
        conv = model.features[i]
        with indexing_time:
            kernel_indices = get_kernel_indices(conv, pruned=False)
            filter_indices = get_filter_indices(conv, pruned=False)
        
        weight = conv.weight.data
        bias = conv.bias.data
        stride = conv.stride
        padding = conv.padding
        kernel_size = conv.kernel_size
        
        model.features[i] = nn.Conv2d(kernel_indices.size()[0], filter_indices.size()[0], kernel_size, stride, padding)
        model.features[i].weight.data = weight[filter_indices, :, :, :][:, kernel_indices, :, :]
        model.features[i].bias.data = bias[filter_indices]
    
    # i = i_s[-1]
    # conv = model.features[i]
    # kernel_indices = get_kernel_indices(conv, pruned=False)
    # filter_indices = get_filter_indices(conv, pruned=False)

    # weight = conv.weight.data
    # bias = conv.bias.data
    # stride = conv.stride
    # padding = conv.padding
    # kernel_size = conv.kernel_size

    # model.features[i] = nn.Conv2d(kernel_indices.size()[0], conv.out_channels, kernel_size, stride, padding)
    # model.features[i].weight.data = weight[:, kernel_indices, :, :]
    # model.features[i].bias.data = bias
    
    fc = model.classifier[0]
    
    weight = fc.weight.data
    bias = fc.bias.data
    out_features = fc.out_features
    
    model.classifier[0] = nn.Linear(filter_indices.size()[0], fc.out_features, bias=True)
    model.classifier[0].weight.data = weight[:, filter_indices]
    model.classifier[0].bias.data = bias
(profiler.t - indexing_time.t) * 1E3    

61.02156639099121

In [14]:
model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Identity()
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(51, 102, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): Identity()
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(102, 205, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): Identity()
    (10): ReLU(inplace=True)
    (11): Conv2d(205, 205, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): Identity()
    (13): ReLU(inplace=True)
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Conv2d(205, 410, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (16): Identity()
    (17): ReLU(inplace=True)
    (18): Conv2d(410, 410, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): 

In [15]:
import copy

In [16]:
features = nn.Sequential()
for name, module in model.features.named_modules():
    if isinstance(module, nn.Conv2d) or isinstance(module, nn.MaxPool2d) or isinstance(module, nn.ReLU):
        features.add_module(name, copy.deepcopy(module))
        
classifier = copy.deepcopy(model.classifier)

reduced_model = nn.Sequential(features, nn.Flatten(), classifier)

model_orig = getattr(pytorch_cifar_models, model_name)(pretrained=True)
fuse_batchnorms(model_orig, model_name, prune_=False)
features = nn.Sequential()
for name, module in model_orig.features.named_modules():
    if isinstance(module, nn.Conv2d) or isinstance(module, nn.MaxPool2d) or isinstance(module, nn.ReLU):
        features.add_module(name, copy.deepcopy(module))
        
classifier = copy.deepcopy(model_orig.classifier)

new_model = nn.Sequential(features, nn.Flatten(), classifier)

In [17]:
new_model.eval()

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): ReLU(inplace=True)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace=True)
    (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False

In [18]:
reduced_model.eval()

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 51, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(51, 102, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(102, 205, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): ReLU(inplace=True)
    (11): Conv2d(205, 205, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Conv2d(205, 410, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace=True)
    (18): Conv2d(410, 410, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False

In [19]:
validate(val_loader, new_model)

100%|██████████| 40/40 [00:17<00:00,  2.26it/s]


(14.377016067504883, 1.4377016067504882, tensor(0.7078), tensor(0.8887))

In [20]:
validate(val_loader, reduced_model)

100%|██████████| 40/40 [00:16<00:00,  2.46it/s]


(12.876551628112793, 1.2876551628112793, tensor(0.3943), tensor(0.6476))

In [21]:
get_num_parameters(new_model), get_num_parameters(reduced_model)

(9797092, 6436351)

In [22]:
dummy_input = torch.rand((1, 3, 32, 32))
input_names = ['input']
output_names = ['output']
onnx_file = f'{model_name}_reduced.onnx'
torch.onnx.export(reduced_model, dummy_input, onnx_file, input_names=input_names, output_names=output_names,
                    dynamic_axes={
                        'input': {0: 'batch_size'}
                    })

dummy_input = torch.rand((1, 3, 32, 32))
input_names = ['input']
output_names = ['output']
onnx_file = f'{model_name}_new.onnx'
torch.onnx.export(new_model, dummy_input, onnx_file, input_names=input_names, output_names=output_names,
                    dynamic_axes={
                        'input': {0: 'batch_size'}
                    })

[W NNPACK.cpp:51] Could not initialize NNPACK! Reason: Unsupported hardware.


In [23]:
import onnxruntime as ort 

In [24]:
onnx_model_new = ort.InferenceSession(f'{model_name}_new.onnx', providers=['CPUExecutionProvider'])
validate_onnx(val_loader, onnx_model_new)

100%|██████████| 40/40 [00:22<00:00,  1.76it/s]


(19.49018168449402, 1.949018168449402, tensor(0.7078), tensor(0.8887))

In [25]:
onnx_model_reduced = ort.InferenceSession(f'{model_name}_reduced.onnx', providers=['CPUExecutionProvider'])
validate_onnx(val_loader, onnx_model_reduced)

100%|██████████| 40/40 [00:25<00:00,  1.56it/s]


(22.427286624908447, 2.2427286624908445, tensor(0.3943), tensor(0.6476))

In [26]:
validate(val_loader, model)

100%|██████████| 40/40 [00:16<00:00,  2.48it/s]


(12.925283193588257, 1.2925283193588257, tensor(0.3943), tensor(0.6476))

In [33]:
model_orig = getattr(pytorch_cifar_models, model_name)(pretrained=True)
model.eval()
validate(val_loader, model_orig)

100%|██████████| 40/40 [00:17<00:00,  2.23it/s]


(14.637847423553467, 1.4637847423553467, tensor(0.6564), tensor(0.8347))

In [28]:
# validate(val_loader, model)

In [29]:
new_total = get_num_parameters(model)

In [30]:
1.0 - (new_total / old_total)

0.34340342088973164

In [31]:
new_total

6436351

In [32]:
old_total

9802596