In [1]:
import torch
import torchvision
from simplify import simplify
from collections import OrderedDict
import torch.nn.utils.prune as prune
import copy
import numpy as np

In [2]:
model = torchvision.models.vgg11()
in_features = model.classifier[6].in_features
model.classifier[6] = torch.nn.Linear(in_features, out_features=10, bias=True)

starting_dict = copy.deepcopy(model.state_dict())

In [3]:
parameters_to_prune = (
            ("features.0", 'weight'),
            ("features.3", 'weight'),
            ("features.6", 'weight'),
            ("features.8", 'weight'),
            ("features.11", 'weight'),
            ("features.13", 'weight'),
            ("features.16", 'weight'),
            ("features.18", 'weight'),
            ("classifier.0", 'weight'),
            ("classifier.3", 'weight'))

mask_dict = OrderedDict()
for name, attr in parameters_to_prune:
    layer_name, n = name.split(".")
    prune.ln_structured(getattr(model, layer_name)[int(n)], name=attr,
                        amount=0.5, n=2, dim=0) #pp_prune[f"{layer_name}.{n}.{attr}"]
    mask_dict[name] = model.state_dict()[name+".weight_mask"]
    prune.remove(getattr(model, layer_name)[int(n)], name=attr)

dummy_input = torch.zeros(1, 3, 32, 32)
simplify(model, dummy_input)

simplified_dict = copy.deepcopy(model.state_dict())

print()





In [4]:
starting_dict.keys()

odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'features.6.weight', 'features.6.bias', 'features.8.weight', 'features.8.bias', 'features.11.weight', 'features.11.bias', 'features.13.weight', 'features.13.bias', 'features.16.weight', 'features.16.bias', 'features.18.weight', 'features.18.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias'])

In [38]:
mask_dict.keys()

odict_keys(['features.0', 'features.3', 'features.6', 'features.8', 'features.11', 'features.13', 'features.16', 'features.18', 'classifier.0', 'classifier.3'])

In [5]:
present_rows = [[0,1,2]]

for k in mask_dict.keys():
    tmp = []
    for idx, e in enumerate(mask_dict[k]):
        if e.sum() != 0:
            tmp.append(idx)
    present_rows.append(tmp)

In [43]:
original_keys = list(starting_dict.keys())
original_keys = [k.replace(".bias",'') for k in original_keys]
original_keys = [k.replace(".weight",'') for k in original_keys]
original_keys = list(dict.fromkeys(original_keys))
original_keys

['features.0',
 'features.3',
 'features.6',
 'features.8',
 'features.11',
 'features.13',
 'features.16',
 'features.18',
 'classifier.0',
 'classifier.3',
 'classifier.6']

In [6]:
len(present_rows)

11

In [7]:
reconstruced_dict = OrderedDict()

for idx_k, k in enumerate(mask_dict.keys()): 
    dim = mask_dict[k].shape
    reconstructed_w = torch.zeros(dim)
    reconstructed_b = torch.zeros(dim[0])

    if "features" in k:

        
        for idx_r,r in enumerate(present_rows[idx_k+1]):
            reconstructed_b[r] = model.state_dict()[k+".bias"][idx_r]
            
            for idx_c, c in enumerate(present_rows[idx_k]):

                reconstructed_w[r,c] = model.state_dict()[k+".weight"][idx_r,idx_c]
        reconstruced_dict[k+".weight"] = reconstructed_w
        reconstruced_dict[k+".bias"] = reconstructed_b
    elif "classifier" in k:

        a = torch.load(f"nonzero_indices/input/{k}.pt")
        b = torch.load(f"nonzero_indices/output/{k}.pt")

        mask = torch.zeros(dim)

        mask[b,:] += torch.ones((int(dim[0]/2), dim[1]))
        mask[:,a] +=  torch.ones((dim[0], int(dim[1]/2)))
        mask = mask == 2

        reconstructed_b[b] = model.state_dict()[k+".bias"]
        reconstruced_dict[k+".bias"] = reconstructed_b

        reconstructed_w[mask] = torch.reshape(model.state_dict()[k+".weight"],(-1,))
        reconstruced_dict[k+".weight"] = reconstructed_w

k = "classifier.6"

reconstruced_dict[k+".bias"] = model.state_dict()["classifier.6.bias"]

dim = starting_dict["classifier.6.weight"].shape
reconstructed_w = torch.zeros(dim)

a = torch.load(f"nonzero_indices/input/{k}.pt")

mask = torch.zeros(dim)

mask[:,a] +=  torch.ones((dim[0], int(dim[1]/2)))
mask = mask == 1


reconstructed_w[mask] = torch.reshape(model.state_dict()[k+".weight"],(-1,))
reconstruced_dict[k+".weight"] = reconstructed_w


In [8]:
model = torchvision.models.vgg11()
in_features = model.classifier[6].in_features
model.classifier[6] = torch.nn.Linear(in_features, out_features=10, bias=True)
model.load_state_dict(reconstruced_dict)


<All keys matched successfully>

In [9]:
for name, attr in parameters_to_prune:
    layer_name, n = name.split(".")
    prune.ln_structured(getattr(model, layer_name)[int(n)], name=attr,
                        amount=0.5, n=2, dim=0) #pp_prune[f"{layer_name}.{n}.{attr}"]
    prune.remove(getattr(model, layer_name)[int(n)], name=attr)

dummy_input = torch.zeros(1, 3, 32, 32)
simplify(model, dummy_input)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   

In [10]:
for k,v in simplified_dict.items():
    print(torch.equal(v, model.state_dict()[k]))

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True


In [11]:
nonzero_input = {k:torch.load(f"nonzero_indices/input/{k.replace('.weight','')}.pt") for k in starting_dict.keys() if "weight" in k}
nonzero_output = {k: torch.load(f"nonzero_indices/output/{k.replace('.weight','')}.pt") for k in starting_dict.keys() if "weight" in k}


In [12]:
starting_dict.keys()

odict_keys(['features.0.weight', 'features.0.bias', 'features.3.weight', 'features.3.bias', 'features.6.weight', 'features.6.bias', 'features.8.weight', 'features.8.bias', 'features.11.weight', 'features.11.bias', 'features.13.weight', 'features.13.bias', 'features.16.weight', 'features.16.bias', 'features.18.weight', 'features.18.bias', 'classifier.0.weight', 'classifier.0.bias', 'classifier.3.weight', 'classifier.3.bias', 'classifier.6.weight', 'classifier.6.bias'])

In [13]:
for k,v in nonzero_input.items():
    print(k)
    print(v.count_nonzero())


features.0.weight
tensor(3)
features.3.weight
tensor(32)
features.6.weight
tensor(64)
features.8.weight
tensor(128)
features.11.weight
tensor(128)
features.13.weight
tensor(256)
features.16.weight
tensor(256)
features.18.weight
tensor(256)
classifier.0.weight
tensor(12544)
classifier.3.weight
tensor(2048)
classifier.6.weight
tensor(2048)


In [14]:
for k,v in nonzero_output.items():
    #(k)
    print(v.count_nonzero())

tensor(32)
tensor(64)
tensor(128)
tensor(128)
tensor(256)
tensor(256)
tensor(256)
tensor(256)
tensor(2048)
tensor(2048)
tensor(10)


In [None]:
nonzero_output = {k:torch.load(f"nonzero_indices/output/{k.replace('.weight','')}.pt") for k in starting_dict.keys() if "weight" in k}


pp = [int(v.count_nonzero()) for v in nonzero_output.values()]
pp.insert(0, 3)
pp = pp[:len(p)]
pp

In [16]:
len(p)

11

In [17]:
present_rows

[[0, 1, 2],
 [0,
  1,
  2,
  3,
  4,
  6,
  8,
  9,
  10,
  17,
  18,
  20,
  21,
  22,
  23,
  25,
  27,
  30,
  32,
  33,
  36,
  37,
  40,
  41,
  42,
  43,
  44,
  47,
  50,
  51,
  56,
  62],
 [2,
  3,
  4,
  7,
  10,
  11,
  12,
  13,
  14,
  16,
  18,
  21,
  22,
  24,
  27,
  29,
  30,
  31,
  35,
  37,
  38,
  43,
  44,
  45,
  47,
  53,
  55,
  56,
  58,
  59,
  60,
  62,
  65,
  68,
  70,
  75,
  76,
  77,
  78,
  79,
  80,
  82,
  83,
  84,
  87,
  88,
  90,
  92,
  94,
  95,
  100,
  101,
  102,
  105,
  107,
  108,
  111,
  114,
  115,
  117,
  118,
  119,
  123,
  127],
 [2,
  3,
  8,
  9,
  11,
  14,
  15,
  16,
  19,
  22,
  23,
  25,
  26,
  27,
  29,
  31,
  32,
  34,
  35,
  39,
  43,
  46,
  48,
  50,
  53,
  54,
  56,
  57,
  58,
  59,
  61,
  62,
  63,
  66,
  67,
  68,
  69,
  72,
  74,
  75,
  76,
  77,
  80,
  81,
  82,
  83,
  87,
  89,
  92,
  93,
  95,
  97,
  99,
  100,
  101,
  102,
  104,
  105,
  106,
  107,
  109,
  110,
  111,
  112,
  114,
  115,
  1

In [35]:
for i in range(11):
    print(present_rows[i] == pp[i])

True
True
True
True
True
True
True
True
True
True
True


In [34]:
pp = [v.nonzero(as_tuple=True)[0].tolist() for v in nonzero_output.values()]
pp.insert(0, [0,1,2])