In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')

In [301]:
from pruning.tracer import GroupTracer
from torchvision.models import resnet18, ResNet18_Weights
from pruning.utils import fuse_batchnorm, run_model
from pruning.feature_merging.importance import GroupNormImportance, FPGMImportance
import torch
import copy
import torch_pruning

model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).eval()
fuse_batchnorm(model)
original_model = copy.deepcopy(model)
input_example = torch.rand((1, 3, 224, 224))
output_initial = run_model(model, input_example)

In [296]:
(run_model(model, input_example).argmax() - output_initial.argmax()).abs()

tensor(0)

In [258]:
ignored_dims = {
    'fc.weight': [0],
    'fc.bias': [0]
}

In [259]:
groups = GroupTracer(model, ignored_dims).build_groups(input_example)

In [260]:
imp_fn = FPGMImportance()

In [261]:
imp = imp_fn(groups[0])

conv1.weight
conv1.bias
layer1.0.conv1.weight
layer1.0.conv2.weight
layer1.0.conv2.bias
layer1.1.conv1.weight
layer1.1.conv2.weight
layer1.1.conv2.bias
layer2.0.conv1.weight
layer2.0.downsample.0.weight


In [11]:
from torch_pruning.pruner.importance import GroupNormImportance as GroupNormImportance_manual, FPGMImportance as FPGMImportance_manual

imp_fn_manual = FPGMImportance_manual()
imp_fn_manual(gr)

tensor([0.8322, 0.9324, 0.7254, 1.1597, 0.8198, 1.3088, 1.1292, 1.2364, 1.1167,
        0.9771, 0.7395, 1.4061, 1.7708, 0.8550, 0.8431, 0.8932, 0.8896, 1.2863,
        1.0927, 0.7480, 1.0713, 0.8311, 0.8710, 0.8811, 0.8360, 0.7671, 0.7554,
        0.8392, 1.4318, 0.7714, 0.8671, 0.8371, 1.1592, 0.9691, 0.7886, 0.7745,
        1.1983, 0.8831, 1.2548, 0.9161, 1.5906, 0.7302, 0.7808, 0.9070, 0.7817,
        0.9813, 1.0839, 0.7902, 0.9429, 1.7601, 1.2450, 0.8054, 1.6630, 0.7980,
        1.0298, 0.7596, 0.7984, 0.7913, 0.9340, 0.8354, 0.7648, 1.0479, 1.0205,
        1.6925])

In [10]:
DG = torch_pruning.DependencyGraph()
DG.build_dependency(model, example_inputs=input_example)
ignored_layers = ['fc']

groups_tp = DG.get_all_groups(ignored_layers=[model.get_submodule(l) for l in ignored_layers])

idx_ = 3
for i, gr in enumerate(groups_tp):
    print(i)
    print(gr)
    print(i, idx_)
    if i == idx_:
        break

    # if i == idx:
    #     print(c)
    #     break

0

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on layer4.1.conv2 (Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))) => prune_out_channels on layer4.1.conv2 (Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), #idxs=512
[1] prune_out_channels on layer4.1.conv2 (Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))) => prune_out_channels on _ElementWiseOp_4(AddBackward0), #idxs=512
[2] prune_out_channels on _ElementWiseOp_4(AddBackward0) => prune_out_channels on _ElementWiseOp_5(ReluBackward0), #idxs=512
[3] prune_out_channels on _ElementWiseOp_4(AddBackward0) => prune_out_channels on _ElementWiseOp_3(ReluBackward0), #idxs=512
[4] prune_out_channels on _ElementWiseOp_3(ReluBackward0) => prune_out_channels on _ElementWiseOp_2(MeanBackward1), #idxs=512
[5] prune_out_channels on _ElementWiseOp_2(MeanBackward1) => prune_out_channels on _Reshape_0(), #idxs=512
[6] prune_out_

In [262]:
example_weights = torch.zeros((32, 64))
important_indices = imp.argsort(descending=True)


In [None]:
example_weights[:, ]

In [268]:
isinstance(dummy_important_indices, torch.Tensor)

False

In [273]:
import torch
import torch.nn as nn

class FeatureMergingMLP(nn.Module):
    def __init__(self, input_channels, output_channels, important_indices=None):
        super(FeatureMergingMLP, self).__init__()
        self.output_channels = output_channels
        self.fc = nn.Linear(input_channels, self.output_channels, bias=False)
        self.important_indices = important_indices
        if isinstance(self.important_indices, (torch.Tensor, list)):
            self.important_indices = self.important_indices[:output_channels]
        # TODO: CHECK HOW IT WORKS IN TP ()
        self._initialize_weights()


    def _initialize_weights(self):
        if isinstance(self.important_indices, (torch.Tensor, list)):
            weight_init = torch.zeros_like(self.fc.weight)
            for idx, preserve_idx in enumerate(self.important_indices):
                weight_init[idx, preserve_idx] = 1.0
        else:
            weight_init = torch.ones_like(self.fc.weight)
        
        self.fc.weight.data.copy_(weight_init)

    def forward(self, x):
        return self.fc(x)

class FeatureMerging(nn.Module):
    def __init__(self, size, new_size, dim, important_indices=None):
        super(FeatureMerging, self).__init__()
        self.size = size
        self.new_size = new_size
        self.dim = dim
        self.mlp = FeatureMergingMLP(size, new_size, important_indices)

    def forward(self, x):
        reshaped_x = (
            x.transpose(self.dim, 0)  # transpose to make pruning dim the first dimension
            .reshape(self.size, -1)  # flatten all dimensions except the pruning idm
            .permute(1, 0)           # switch the first and second dimensions to fit MLP
        )

        processed_weights = self.mlp(reshaped_x)
        final_shape = list(x.shape)
        final_shape[self.dim] = self.new_size
        return processed_weights.permute(1, 0).reshape(final_shape)


size = 64
new_size = 32
dim = 0

old_tensor = torch.randint(10, (size, 512, 3, 3))
dummy_important_indices = [0, 2, 1]
feature_merging = FeatureMerging(size, new_size, dim, dummy_important_indices)

assert torch.allclose(
    old_tensor[dummy_important_indices].float(), 
    feature_merging(old_tensor.float())[:len(dummy_important_indices)]
)

In [299]:
for group in groups:
    size = group.size
    print(group)
    break

conv1.weight: 0
conv1.bias: 0
layer1.0.conv1.weight: 1
layer1.0.conv2.weight: 0
layer1.0.conv2.bias: 0
layer1.1.conv1.weight: 1
layer1.1.conv2.weight: 0
layer1.1.conv2.bias: 0
layer2.0.conv1.weight: 1
layer2.0.downsample.0.weight: 1



In [303]:
import torch.nn.utils.parametrize as parametrize
import copy

model = copy.deepcopy(original_model)

new_size = 64
for param in group.params:
    module_name, wb = param['name'].rsplit('.', 1)
    submodule = model.get_submodule(module_name)
    important_indices = important_indices[:new_size]
    if wb == 'weight':
        parametrize.register_parametrization(
            submodule, 
            "weight", 
            FeatureMerging(group.size, new_size, param['dim']), 
            unsafe=True
        )
    elif wb == 'bias':
        submodule.bias.data = submodule.bias.data[important_indices]

In [304]:
(run_model(model, input_example).argmax() - output_initial.argmax()).abs()

tensor(428)

In [240]:
check = FeatureMerging(group.size, new_size, 0, important_indices)

In [281]:
original_weights = model.conv1.parametrizations.weight.original

In [280]:
important_indices

tensor([12, 49, 63, 52, 40, 28, 11,  5, 17, 38, 50,  7, 36,  3, 32,  6,  8, 18,
        46, 20, 61, 54, 62, 45,  9, 33, 48, 58,  1, 39, 43, 15, 16, 37, 23, 22,
        30, 13, 14, 27, 31, 24, 59,  0, 21,  4, 51, 56, 53, 57, 47, 34, 44, 42,
        35, 29, 25, 60, 55, 26, 19, 10, 41,  2])

In [283]:
assert torch.allclose(
    original_weights[important_indices[:new_size]].float(), 
    model.conv1.weight
)

In [242]:
original_weights.shape

torch.Size([64, 3, 7, 7])

In [243]:
new_w = check(original_weights)
new_w.shape

torch.Size([147, 64])
torch.Size([64, 3, 7, 7]) [32, 3, 7, 7]
torch.Size([147, 32])


torch.Size([32, 3, 7, 7])

In [251]:
assert torch.allclose(
    original_weights[important_indices],
    new_w[:len(important_indices)]
)

In [252]:
assert torch.allclose(
    original_weights[0],
    new_w[0]
)

In [249]:
new_w[:len(important_indices)].shape

torch.Size([3, 3, 7, 7])

In [183]:
important_indices

[0, 2, 5]

In [173]:
for name, p in model.conv1.named_parameters():
    print(name)

bias
parametrizations.weight.original
parametrizations.weight.0.mlp.fc.weight
parametrizations.weight.1.mlp.fc.weight
parametrizations.weight.2.mlp.fc.weight


In [25]:
ch1, ch2, k1, k2 = model.conv1.weight.shape

In [13]:
model.conv1.weight.shape

torch.Size([64, 3, 7, 7])

In [24]:
check_linear = torch.nn.Linear(64, 32)
reshaped_weight = model.conv1.weight.data.view(64, -1).T

In [14]:
example_pruning = check_linear(reshaped_weight)

In [26]:
check_linear.weight.shape

torch.Size([32, 64])

conv1.weight: 0
conv1.bias: 0
layer1.0.conv1.weight: 1
layer1.0.conv2.weight: 0
layer1.0.conv2.bias: 0
layer1.1.conv1.weight: 1
layer1.1.conv2.weight: 0
layer1.1.conv2.bias: 0
layer2.0.conv1.weight: 1
layer2.0.downsample.0.weight: 1



In [5]:
import os
import torchvision

root = os.path.expanduser("~") + "/datasets/"
batch_size = 64
augmentation = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ]
)

preprocess = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    ]
)

train_dataset = torchvision.datasets.CIFAR10(
    root=root, train=True, download=True, transform=augmentation
)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)

val_dataset = torchvision.datasets.CIFAR10(
    root=root, train=False, download=True, transform=preprocess
)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size, shuffle=False)
loaders = {"train": train_dataloader, "valid": val_dataloader}


Files already downloaded and verified
Files already downloaded and verified


In [6]:
import torch_integral

continuous_dims = torch_integral.standard_continuous_dims(model)
# 'fc.weight': [0, 1],
# 'fc.bias': [0]
# 'conv1.weight': [0, 1]
continuous_dims['conv1.weight'] = [0]
continuous_dims['fc.weight'] = [1]
continuous_dims.pop('fc.bias', None)

[0]

In [7]:
integral_model = torch_integral.IntegralWrapper(
    init_from_discrete=True, verbose=True
)(model, input_example, continuous_dims)

Rearranging of group 0
variation before permutation: 87.49198150634766
variation after permutation: 50.8090705871582
Rearranging of group 1
variation before permutation: 32.05695343017578
variation after permutation: 8.883305549621582
Rearranging of group 2
variation before permutation: 41.3999137878418
variation after permutation: 10.537375450134277
Rearranging of group 3
variation before permutation: 34.390079498291016
variation after permutation: 10.334465980529785
Rearranging of group 4
variation before permutation: 141.0234375
variation after permutation: 72.79061126708984
Rearranging of group 5
variation before permutation: 42.25988006591797
variation after permutation: 12.976730346679688
Rearranging of group 6
variation before permutation: 73.67745208740234
variation after permutation: 16.51487159729004
Rearranging of group 7
variation before permutation: 170.888916015625
variation after permutation: 84.26046752929688
Rearranging of group 8
variation before permutation: 86.69864

ModuleNotFoundError: No module named 'utils'

In [8]:
from catalyst import dl
cross_entropy = torch.nn.CrossEntropyLoss()
log_dir = "./logs/cifar"
runner = dl.SupervisedRunner(
    input_key="features", output_key="logits", target_key="targets", loss_key="loss"
)
callbacks = [
    dl.AccuracyCallback(
        input_key="logits", target_key="targets", topk=(1,), num_classes=10
    ),
    dl.SchedulerCallback(mode="batch", loader_key="train", metric_key="loss"),
]
loggers = []
epochs = 10

In [9]:
sparsity = 0.5
for group in integral_model.groups:
    if "operator" not in group.operations:
        initial_size = group.size
        new_size = int(initial_size * sparsity)
        group.reset_grid(
            torch_integral.TrainableGrid1D(new_size)
        )

print("compression: ", integral_model.eval().calculate_compression())

compression:  0.4701610103869056


In [11]:
with torch_integral.grid_tuning(integral_model, False, True):
    opt = torch.optim.Adam(integral_model.parameters(), lr=1e-3, weight_decay=0)
    epoch_len = len(train_dataloader)
    sched = torch.optim.lr_scheduler.MultiStepLR(
        opt, [epoch_len * 2, epoch_len * 5, epoch_len * 6, epoch_len * 8], gamma=0.33
    )
    runner.train(
        model=integral_model,
        criterion=cross_entropy,
        optimizer=opt,
        scheduler=sched,
        loaders=loaders,
        num_epochs=epochs,
        callbacks=callbacks,
        loggers=loggers,
        logdir=log_dir,
        valid_loader="valid",
        valid_metric="loss",
        verbose=True,
    )

1/10 * Epoch (train):   0%|          | 0/782 [00:00<?, ?it/s]

Keyboard Interrupt


KeyboardInterrupt: 

: 

In [52]:
# for group in groups:
#     for param in group.params:
#         name = param['name'].replace('.weight', '').replace('.bias', '')
#         submodule = model.get_submodule(name)
#         if 'weight' in param['name']:
#             weight = submodule.weight.data
#         elif 'bias' in param['name']:
#             weight = submodule.bias.data
#         print(name, param['dim'], weight.shape == param['value'].shape)

In [17]:
hasattr(torch.nn, target_layer.__class__.__name__)

True

In [22]:
for gr in groups:
    break

In [37]:
import numpy as np
np.arange(0, 512)

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
       104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
       117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
       130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142,
       143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
       156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168,
       169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 18

In [35]:
gr._group

[(prune_out_channels on layer4.1.conv2 (Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))) => prune_out_channels on layer4.1.conv2 (Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177,

In [27]:
for group in :
    # print(group)
    for i, (dep, idxs) in enumerate(group):
        trigger = dep.trigger
        handler = dep.handler
        source_layer = dep.source.module
        target_layer = dep.target.module
        if hasattr(torch.nn, target_layer.__class__.__name__):
            print(dep.target.name.split()[0], handler.__name__)
        
        if hasattr(torch.nn, source_layer.__class__.__name__):
            print(dep.source.name.split()[0], handler.__name__)
            
    print('-' * 80)

layer4.1.conv2 prune_out_channels
layer4.1.conv2 prune_out_channels
layer4.1.conv2 prune_out_channels
fc prune_in_channels
fc prune_out_channels
layer4.1.conv1 prune_in_channels
layer4.0.conv2 prune_out_channels
layer4.0.downsample.0 prune_out_channels
--------------------------------------------------------------------------------
layer3.1.conv2 prune_out_channels
layer3.1.conv2 prune_out_channels
layer3.1.conv2 prune_out_channels
layer4.0.downsample.0 prune_in_channels
layer4.0.conv1 prune_in_channels
layer3.1.conv1 prune_in_channels
layer3.0.conv2 prune_out_channels
layer3.0.downsample.0 prune_out_channels
--------------------------------------------------------------------------------
layer2.1.conv2 prune_out_channels
layer2.1.conv2 prune_out_channels
layer2.1.conv2 prune_out_channels
layer3.0.downsample.0 prune_in_channels
layer3.0.conv1 prune_in_channels
layer2.1.conv1 prune_in_channels
layer2.0.conv2 prune_out_channels
layer2.0.downsample.0 prune_out_channels
-------------------

In [71]:
for i, (dep, idxs) in enumerate(group):
    trigger = dep.trigger
    handler = dep.handler
    source_layer = dep.source.module
    target_layer = dep.target.module
    print(dep)
    print("For Dep: ", dep)
    print("Handler: ", handler.__name__)
    print(" > Source Layer: ", source_layer)
    print(" > Target Layer: ", target_layer)
    print("")

prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)) => prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False))
For Dep:  prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)) => prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False))
Handler:  prune_out_channels
 > Source Layer:  Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
 > Target Layer:  Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)

prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)) => prune_out_channels on layer4.0.downsample.1 (BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
For Dep:  prune_out_channels on layer4.0.downsample.0 (Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bia

In [None]:
torch.fx.Interpret

In [14]:
for group in groups:
    print(group)
    print(len(group.params))

conv1.weight: 0
layer1.0.conv1.weight: 1
layer1.0.conv2.weight: 0
layer1.1.conv1.weight: 1
layer1.1.conv2.weight: 0
layer2.0.conv1.weight: 1
layer2.0.downsample.0.weight: 1

7
conv1.weight: 2

1
conv1.weight: 3

1
bn1.weight: 0

1
bn1.bias: 0

1
layer1.0.conv1.weight: 0
layer1.0.conv2.weight: 1

2
layer1.0.conv1.weight: 2

1
layer1.0.conv1.weight: 3

1
layer1.0.bn1.weight: 0

1
layer1.0.bn1.bias: 0

1
layer1.0.conv2.weight: 2

1
layer1.0.conv2.weight: 3

1
layer1.0.bn2.weight: 0

1
layer1.0.bn2.bias: 0

1
layer1.1.conv1.weight: 0
layer1.1.conv2.weight: 1

2
layer1.1.conv1.weight: 2

1
layer1.1.conv1.weight: 3

1
layer1.1.bn1.weight: 0

1
layer1.1.bn1.bias: 0

1
layer1.1.conv2.weight: 2

1
layer1.1.conv2.weight: 3

1
layer1.1.bn2.weight: 0

1
layer1.1.bn2.bias: 0

1
layer2.0.conv1.weight: 0
layer2.0.conv2.weight: 1

2
layer2.0.conv1.weight: 2

1
layer2.0.conv1.weight: 3

1
layer2.0.bn1.weight: 0

1
layer2.0.bn1.bias: 0

1
layer2.0.conv2.weight: 2

1
layer2.0.conv2.weight: 3

1
layer2.0.