# Exploring Dependency Groups

In [2]:
import warnings
warnings.filterwarnings('ignore')
import sys, os
sys.path.append(os.path.abspath("../"))

import oneflow as torch
from flowvision.models import resnet18
import oneflow_pruning as tp

### Grouping

In this section, we take a closer look at the `DependencyGraph` module and show how it supports structural pruning.

First, let's get a pruning group from ResNet-18.


In [3]:
# 0. prepare your model and example inputs
model = resnet18(pretrained=True).eval()
example_inputs = torch.randn(1,3,224,224)

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].
pruning_idxs = [2, 6, 9]
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs)

### Indexing

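In Torch-Pruning, the dependencies of a group are organized as an iterable list. Within a group, the operation requested by the user is treated as the root operation.

For example, if we ask to prune `model.conv1`, the first dependency in the group reflects exactly that operation: pruning the output channels of conv1.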


In [4]:
print(group[0])

GroupItem(dep=prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9])


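Each dependency in a group is paired with the pruning indices of the channels to be removed.

Here, we aim to remove the conv1 output channels at indices 2, 6, and 9.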

In [5]:
print("Dep:", group[1][0])
print("Indices:", group[1][1])

Dep: prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))
Indices: [2, 6, 9]
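
The positional access `group[1][0]` and `group[1][1]` works because each entry pairs a dependency with its indices. Below is a minimal sketch of that layout, assuming entries behave like a named tuple with fields `dep` and `idxs` (as the printed `GroupItem(dep=..., idxs=...)` suggests). The `toy_group` list and its string placeholders are illustrative only, not library objects.

```python
from collections import namedtuple

# Hypothetical stand-in for the library's group entry; the real class may differ.
GroupItem = namedtuple("GroupItem", ["dep", "idxs"])

# A toy group: the root operation first, then the dependency it triggers.
toy_group = [
    GroupItem(dep="prune_out_channels on conv1 => prune_out_channels on conv1", idxs=[2, 6, 9]),
    GroupItem(dep="prune_out_channels on conv1 => prune_out_channels on bn1", idxs=[2, 6, 9]),
]

# Positional access, mirroring group[1][0] and group[1][1] ...
dep, idxs = toy_group[1]

# ... and the equivalent attribute access.
same_entry = toy_group[1].dep == dep and toy_group[1].idxs == idxs

for position, item in enumerate(toy_group):
    print(position, item.dep, item.idxs)
```

If the real entries are indeed named tuples, `group[1].dep` and `group[1].idxs` would be a more readable alternative to numeric indexing.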


Let's delve deeper into the concept of a dependency in DepGraph. In DepGraph, a dependency is represented as an edge that connects two nodes, indicating that the two layers are coupled. Each dependency maintains two pruning functions: 1) a trigger function, a pruning operation that breaks the dependency when applied on its own, and 2) a handler function, which repairs the dependency broken by the trigger.

For instance, consider the simple Conv-BN dependency between conv1 and bn1. If we remove an output channel of conv1, the corresponding channel of bn1 must be pruned as well. This dependency is illustrated in the following example.



In [6]:
print("Source Node:", group[1][0].source.module)  # the source nn.Module
print("Target Node:", group[1][0].target.module)  # the target nn.Module
print("Trigger Function:", group[1][0].trigger)
print("Handler Function:", group[1][0].handler)

Source Node: Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
Target Node: BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
Trigger Function: <bound method ConvPruner.prune_out_channels of <oneflow_pruning.pruner.function.ConvPruner object at 0x7f7410066ee0>>
Handler Function: <bound method BatchnormPruner.prune_out_channels of <oneflow_pruning.pruner.function.BatchnormPruner object at 0x7f7410066fd0>>
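
The trigger/handler pairing shown above can be mimicked in plain Python. This is a toy sketch, not the library's implementation: `ToyDependency`, the `channels` dictionary, and the single `prune_out_channels` helper are all hypothetical, and layers are reduced to channel counts.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

# Toy "layers": just a name mapped to a channel count. Not real modules.
channels: Dict[str, int] = {"conv1": 64, "bn1": 64}

def prune_out_channels(layer: str, idxs: List[int]) -> None:
    """Toy pruning function: drop len(idxs) channels from the named layer."""
    channels[layer] -= len(idxs)

@dataclass
class ToyDependency:
    """An edge source -> target with a trigger and a handler (hypothetical class)."""
    source: str
    target: str
    trigger: Callable[[str, List[int]], None]
    handler: Callable[[str, List[int]], None]

    def __call__(self, idxs: List[int]) -> None:
        # Executing the dependency applies the handler to the target layer.
        self.handler(self.target, idxs)

dep = ToyDependency("conv1", "bn1", trigger=prune_out_channels, handler=prune_out_channels)

# Applying only the trigger breaks the conv1/bn1 agreement ...
dep.trigger(dep.source, [2, 6, 9])
broken = channels["conv1"] != channels["bn1"]

# ... and executing the dependency (its handler) repairs it.
dep([2, 6, 9])
repaired = channels["conv1"] == channels["bn1"]
print(channels, broken, repaired)
```

Keeping the trigger and handler as separate callables on the edge lets the same pruning function serve as a trigger in one dependency and as a handler in another.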


### Pruning with Dependency

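In Torch-Pruning, we can "execute" a dependency to apply its handler pruning function. Here we execute only the root dependency, which prunes conv1 alone and leaves its coupled layers unrepaired.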


In [7]:
idx = group[0][1]
dep = group[0][0]
print(f'{idx=}')
dep(idx)

idx=[2, 6, 9]


Conv2d(3, 61, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
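
The drop from 64 to 61 output channels is plain index arithmetic. Here is a minimal sketch with no tensors involved. An actual pruner would presumably slice the weight along its output dimension with the kept indices, but that step is an assumption and is not modeled here.

```python
# Assumed setup: conv1 has 64 output channels; pruning indices are zero-based.
num_out_channels = 64
pruning_idxs = [2, 6, 9]

# Keep every output channel whose index is not listed for pruning.
pruned = set(pruning_idxs)
keep_idxs = [i for i in range(num_out_channels) if i not in pruned]

print(len(keep_idxs))   # 61
print(keep_idxs[:8])    # [0, 1, 3, 4, 5, 7, 8, 10]
```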

However, if we now run a forward pass as usual, it fails with a shape mismatch: bn1 still holds 64 `running_mean` entries, while conv1 now produces only 61 channels.


In [8]:
print(model(torch.randn(1,3,224,224)))

RuntimeError: Check failed: ((64,) == (61,)) 
  File "/home/ci-user/runners/release/_work/oneflow/oneflow/oneflow/core/functional/impl/nn_functor.cpp", line 2634, in operator()
    OpInterpUtil::Dispatch<one::Tensor>( *norm_eval_op_, {x, moving_mean_val, moving_variance_val, gamma_val, beta_val}, attrs)
  File "/home/ci-user/runners/release/_work/oneflow/oneflow/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp", line 144, in Dispatch<oneflow::one::Tensor>
    Dispatch<TensorTuple>(op_expr, inputs, ctx)
  File "/home/ci-user/runners/release/_work/oneflow/oneflow/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp", line 135, in Dispatch<oneflow::one::TensorTuple>
    Dispatch(op_expr, inputs, outputs.get(), ctx)
  File "/home/ci-user/runners/release/_work/oneflow/oneflow/oneflow/core/framework/op_interpreter/op_interpreter.cpp", line 103, in Apply
    internal_->Apply(op_expr, inputs, outputs, ctx)
  File "/home/ci-user/runners/release/_work/oneflow/oneflow/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp", line 83, in NaiveInterpret
    [&]() -> Maybe<const LocalTensorInferResult> { LocalTensorMetaInferArgs ... mut_local_tensor_infer_cache()->GetOrInfer(infer_args)); }()
  File "/home/ci-user/runners/release/_work/oneflow/oneflow/oneflow/core/framework/op_interpreter/eager_local_op_interpreter.cpp", line 86, in operator()
    user_op_expr.mut_local_tensor_infer_cache()->GetOrInfer(infer_args)
  File "/home/ci-user/runners/release/_work/oneflow/oneflow/oneflow/core/framework/local_tensor_infer_cache.cpp", line 209, in GetOrInfer
    Infer(*user_op_expr, infer_args)
  File "/home/ci-user/runners/release/_work/oneflow/oneflow/oneflow/core/framework/local_tensor_infer_cache.cpp", line 178, in Infer
    user_op_expr.InferPhysicalTensorDesc( infer_args.attrs ... ) -> TensorMeta* { return &output_mut_metas.at(i); })
  File "/home/ci-user/runners/release/_work/oneflow/oneflow/oneflow/core/framework/op_expr.cpp", line 580, in InferPhysicalTensorDesc
    physical_tensor_desc_infer_fn_(&infer_ctx)
  File "/home/ci-user/runners/release/_work/oneflow/oneflow/oneflow/user/ops/normalization_op.cpp", line 159, in operator()
    CheckParamTensorDesc("moving_mean")
  File "/home/ci-user/runners/release/_work/oneflow/oneflow/oneflow/user/ops/normalization_op.cpp", line 32, in operator()
    CHECK_EQ_OR_RETURN(tensor_desc.shape(), shape)
Error Type: oneflow.ErrorProto.check_failed_error

To fix this, we should use group pruning, which removes a whole group of coupled parameters from the model at once.

In [9]:
model = resnet18(pretrained=True).eval()
example_inputs = torch.randn(1,3,224,224)

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)

# 2. Select some channels to prune. Here we prune the channels indexed by [2, 6, 9].
pruning_idxs = [2, 6, 9]
group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=pruning_idxs)

group.prune()

In [10]:
print(model(torch.randn(1,3,224,224)).shape)

oneflow.Size([1, 1000])
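
To make the idea of a group more concrete, the sketch below collects one by walking coupling edges outward from a root operation. It is a simplified illustration under stated assumptions: the `edges` table and the layer names `conv_a`, `bn_a`, and `conv_b` are hypothetical, the same indices are reused for every operation, and a real dependency graph may need to remap indices between layers.

```python
from collections import deque
from typing import Dict, List, Tuple

# Hypothetical coupling edges: pruning the key forces pruning each listed neighbor.
edges: Dict[str, List[str]] = {
    "conv_a.out": ["bn_a.out"],
    "bn_a.out": ["conv_b.in"],
    "conv_b.in": [],
}

def collect_group(root: str, idxs: List[int]) -> List[Tuple[str, List[int]]]:
    """Breadth-first walk from the root operation, gathering (operation, idxs) pairs."""
    group, seen, queue = [], {root}, deque([root])
    while queue:
        op = queue.popleft()
        group.append((op, idxs))
        for nxt in edges.get(op, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return group

group = collect_group("conv_a.out", [2, 6, 9])
for op, op_idxs in group:
    print(op, op_idxs)
```

The root operation comes first in the result, matching the ordering described for `group[0]` earlier.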
