In [1]:
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from torch.onnx import TrainingMode
import torch.nn.utils.prune as prune

import onnx
import numpy as np

import torchslimkit as tsk

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()

In [3]:
model = resnet50(weights=weights)

In [4]:
onnx_graph = tsk.generate_onnx_graph(model, preprocess=preprocess)
# onnx_graph

In [5]:
graph = tsk.generate_graph(onnx_graph=onnx_graph)

In [6]:
graph

{'conv1': node_type=Conv, parent=, child_nodes=bn1,
 'bn1': node_type=BatchNormalization, parent=conv1, child_nodes=Relu_2,
 'Relu_2': node_type=Relu, parent=bn1, child_nodes=MaxPool_3,
 'MaxPool_3': node_type=MaxPool, parent=Relu_2, child_nodes=layer1.0.conv1, layer1.0.downsample.0,
 'layer1.0.conv1': node_type=Conv, parent=MaxPool_3, child_nodes=layer1.0.bn1,
 'layer1.0.bn1': node_type=BatchNormalization, parent=layer1.0.conv1, child_nodes=Relu_6,
 'Relu_6': node_type=Relu, parent=layer1.0.bn1, child_nodes=layer1.0.conv2,
 'layer1.0.conv2': node_type=Conv, parent=Relu_6, child_nodes=layer1.0.bn2,
 'layer1.0.bn2': node_type=BatchNormalization, parent=layer1.0.conv2, child_nodes=Relu_9,
 'Relu_9': node_type=Relu, parent=layer1.0.bn2, child_nodes=layer1.0.conv3,
 'layer1.0.conv3': node_type=Conv, parent=Relu_9, child_nodes=layer1.0.bn3,
 'layer1.0.bn3': node_type=BatchNormalization, parent=layer1.0.conv3, child_nodes=onnx::Add_350,
 'layer1.0.downsample.0': node_type=Conv, parent=MaxPoo

In [7]:
for name, mod in model.named_modules():
    if isinstance(mod, nn.Conv2d):
        prune.ln_structured(mod, 'weight', 0.3, float('inf'), dim=0)
        prune.remove(mod, 'weight')
        print(tsk.get_filter_indices(mod))

tensor([ 3,  8, 12, 13, 19, 22, 23, 24, 25, 29, 32, 33, 39, 40, 41, 52, 54, 59,
        61])
tensor([ 1,  5,  9, 12, 14, 19, 20, 25, 27, 31, 32, 37, 41, 49, 54, 56, 60, 61,
        62])
tensor([ 0,  1,  7, 10, 11, 12, 16, 17, 22, 25, 26, 29, 33, 36, 41, 43, 46, 56,
        61])
tensor([  3,   9,  13,  20,  24,  27,  32,  34,  37,  38,  42,  45,  46,  47,
         49,  51,  52,  58,  60,  63,  70,  81,  82,  84,  92,  93,  94,  95,
         97, 100, 102, 109, 112, 114, 118, 119, 120, 121, 124, 126, 127, 128,
        131, 137, 143, 153, 155, 157, 159, 160, 164, 165, 167, 170, 173, 174,
        177, 179, 181, 183, 186, 187, 195, 196, 200, 203, 206, 218, 220, 222,
        226, 232, 236, 239, 240, 242, 251])
tensor([  2,  10,  12,  13,  22,  27,  32,  34,  38,  42,  44,  45,  46,  49,
         51,  55,  57,  58,  60,  72,  73,  74,  76,  81,  84,  90,  92,  93,
         94,  97,  98, 100, 108, 109, 119, 121, 127, 131, 135, 140, 144, 152,
        156, 157, 158, 164, 165, 170, 173, 177, 178, 

In [8]:
# Embed un-pruned filter indices into model
tsk.embed_filter_indices(model)

# Get and aggregate the sibling dependencies
sibling_groups = tsk.aggregate_sibling_groups(graph)

In [19]:
group = sibling_groups[0]

filter_idx_list = []
for mod in group:
    mod = eval(f'model{tsk.name2attr(mod)}')
    filter_idx_list.append(mod.filter_indices)
    
filter_indices = torch.cat(filter_idx_list)
filter_indices = torch.unique(filter_indices)
filter_indices, _ = torch.sort(filter_indices)

for mod in group:
    mod = eval(f'model{tsk.name2attr(mod)}')
    exec(f'mod.filter_indices = filter_indices')

torch.Size([1434])
torch.Size([1434])
torch.Size([1434])
torch.Size([1434])



torch.Size([1976])

In [8]:
node = 'layer1.1.conv1'
tsk.get_all_parents(node, graph)

['layer1.0.conv3', 'layer1.0.downsample.0']

In [11]:
node = 'layer1.2.conv1'
tsk.get_all_parents(node, graph)

['layer1.0.conv3', 'layer1.0.downsample.0', 'layer1.1.conv3']

In [13]:
node = 'input.308'
tsk.get_all_parents(node, graph)

['layer3.0.conv3',
 'layer3.0.downsample.0',
 'layer3.1.conv3',
 'layer3.2.conv3',
 'layer3.3.conv3',
 'layer3.4.conv3',
 'layer3.5.conv3']