# Exploring Dependency Groups

In [1]:
import warnings
warnings.filterwarnings('ignore')
import sys, os
sys.path.append(os.path.abspath("../"))

import torch
from torchvision.models import resnet18
import torch_pruning as tp

### Grouping

In this part, we look at the ``DependencyGraph`` module and show how it supports structural pruning. First, let's fetch a pruning group from a ResNet-18.

In [2]:
# 0. prepare your model and example inputs
model = resnet18(pretrained=True).eval()
example_inputs = torch.randn(1,3,224,224)

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].
pruning_idxs = [2, 6, 9]
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs)

### Indexing

In Torch-Pruning, the dependencies of a group are organized as an iterable list. Within a group, the operation requested by the user is the root operation and comes first. For instance, if we prune ``model.conv1``, the first dependency in the group reflects this action: it prunes the output channels of conv1.

In [3]:
print(group[0])

([DEP] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), [2, 6, 9])


Each entry in the group pairs a dependency with the pruning indices, i.e. the channels to be pruned. The indices are zero-based, so here we aim to remove the channels at positions 2, 6 and 9 of conv1. The second entry shows the same indices propagated to bn1.

In [4]:
print("Dep:", group[1][0])
print("Indices:", group[1][1])

Dep: [DEP] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
Indices: [2, 6, 9]
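
As a minimal, library-independent sketch of what these indices mean, the snippet below uses plain Python only. The names `num_channels`, `removed` and `keep_idxs` are illustrative and are not Torch-Pruning attributes. Removing three of conv1's 64 output channels leaves 61:

```python
# Illustrative only: plain Python, no Torch-Pruning calls.
num_channels = 64          # conv1 of ResNet-18 has 64 output channels
pruning_idxs = [2, 6, 9]   # zero-based channel positions to remove
removed = set(pruning_idxs)

# Channels that survive pruning, in their original order.
keep_idxs = [i for i in range(num_channels) if i not in removed]

print(len(keep_idxs))   # 61
print(keep_idxs[:8])    # [0, 1, 3, 4, 5, 7, 8, 10]
```

Every layer coupled to conv1 has to drop the same positions, which is why the same index list appears in each entry of the group.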


Let's look more closely at the concept of a dependency in DepGraph. In DepGraph, a dependency is represented as an edge that connects two nodes, indicating that the two layers are coupled. Each dependency maintains two pruning functions: 1) a trigger function, a pruning operation that breaks the dependency when applied on its own, and 2) a handler function, which repairs the dependency broken by the trigger.

For instance, consider the simple Conv-BN dependency between conv1 and bn1. If we remove an output channel of conv1, we must also prune the corresponding channel of bn1. The following example shows this dependency.

In [5]:
print("Source Node:", group[1][0].source.module)  # .module returns the underlying nn.Module
print("Target Node:", group[1][0].target.module)
print("Trigger Function:", group[1][0].trigger)
print("Handler Function:", group[1][0].handler)

Source Node: Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
Target Node: BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Trigger Function: <bound method ConvPruner.prune_out_channels of <torch_pruning.pruner.function.ConvPruner object at 0x7f5cbe611b20>>
Handler Function: <bound method BatchnormPruner.prune_out_channels of <torch_pruning.pruner.function.BatchnormPruner object at 0x7f5cbe611bb0>>
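
To make the trigger and handler roles concrete, here is a toy model of a dependency in plain Python. It is a sketch under stated assumptions, not Torch-Pruning's implementation: `ToyDependency`, the dict-based "layers" and the shared `prune_out_channels` helper are all hypothetical names introduced for illustration.

```python
from dataclasses import dataclass
from typing import Callable, List


def prune_out_channels(layer: dict, idxs: List[int]) -> dict:
    """Drop the channels at the given zero-based positions, in place."""
    removed = set(idxs)
    layer["channels"] = [c for i, c in enumerate(layer["channels"]) if i not in removed]
    return layer


@dataclass
class ToyDependency:
    """Hypothetical edge source -> target with a trigger and a handler."""
    source: dict
    target: dict
    trigger: Callable   # pruning op on the source that breaks the pairing
    handler: Callable   # pruning op on the target that repairs it

    def __call__(self, idxs: List[int]) -> dict:
        # "Executing" the dependency applies the handler to the target.
        return self.handler(self.target, idxs)


# Toy stand-ins for conv1 and bn1: a name plus the channel ids each one holds.
conv1 = {"name": "conv1", "channels": list(range(64))}
bn1 = {"name": "bn1", "channels": list(range(64))}
dep = ToyDependency(conv1, bn1, trigger=prune_out_channels, handler=prune_out_channels)

# Applying the trigger alone leaves conv1 and bn1 out of sync.
dep.trigger(dep.source, [2, 6, 9])
print(len(conv1["channels"]), len(bn1["channels"]))  # 61 64

# Executing the dependency prunes bn1 with the same indices and restores the match.
dep([2, 6, 9])
print(len(conv1["channels"]), len(bn1["channels"]))  # 61 61
```

In the real library the trigger and handler are bound methods of different pruner objects, as the output above shows; the toy reuses one function only to stay short.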


### Pruning with Dependency

In Torch-Pruning, we can "execute" a dependency to apply its handler function for pruning. Here we execute only the root dependency, which prunes conv1 without updating the layers that depend on it.

In [6]:
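idx = group[0][1]  # pruning indices of the root entry: [2, 6, 9]
dep = group[0][0]  # root dependency: conv1 => conv1
dep(idx)           # applies the handler, i.e. prunes the output channels of conv1 only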

Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

However, if we run a forward pass on this model as usual, an error occurs, because bn1 still expects 64 channels: ``running_mean should contain 61 elements not 64``.

In [7]:
print(model(torch.randn(1,3,224,224)))

RuntimeError: running_mean should contain 61 elements not 64

To address this issue, we should use "group pruning", which removes all coupled parameters in the group from the model together.

In [8]:
model = resnet18(pretrained=True).eval()
example_inputs = torch.randn(1,3,224,224)

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].
pruning_idxs = [2, 6, 9]
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs)

group.prune()

In [9]:
print(model(torch.randn(1,3,224,224)).shape)

torch.Size([1, 1000])
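
The contrast between single-layer pruning and group pruning can be sketched without any library code. The example below is a simplified illustration in plain Python: `prune_channels`, `check_consistent` and `toy_group` are hypothetical names, and the two-layer model stands in for the real group, which for conv1 of ResNet-18 contains more entries than shown here.

```python
def prune_channels(channels, idxs):
    """Return the channel list without the given zero-based positions."""
    removed = set(idxs)
    return [c for i, c in enumerate(channels) if i not in removed]


def check_consistent(layers):
    """Raise if coupled layers disagree on their channel count."""
    sizes = {name: len(ch) for name, ch in layers.items()}
    if len(set(sizes.values())) != 1:
        raise RuntimeError(f"channel mismatch: {sizes}")
    return sizes


# Case 1: prune conv1 alone, as when executing only the root dependency.
layers = {"conv1": list(range(64)), "bn1": list(range(64))}
layers["conv1"] = prune_channels(layers["conv1"], [2, 6, 9])
try:
    check_consistent(layers)
except RuntimeError as err:
    print("single-layer pruning:", err)
# single-layer pruning: channel mismatch: {'conv1': 61, 'bn1': 64}

# Case 2: apply every (layer, indices) entry of the group together.
layers = {"conv1": list(range(64)), "bn1": list(range(64))}
toy_group = [("conv1", [2, 6, 9]), ("bn1", [2, 6, 9])]
for name, idxs in toy_group:
    layers[name] = prune_channels(layers[name], idxs)
print("group pruning:", check_consistent(layers))
# group pruning: {'conv1': 61, 'bn1': 61}
```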
