# Advanced example

Some special modules do not have a pre-defined Node in `UniP`. For such a module, we can write a `CustomNode`.

Let's take `DCN` (deformable convolution) as an example (although it has already been added to `unip/core/node.py`):

In [11]:
import torch.nn as nn
from unip.core.node import InOutNode, CustomNode
from unip.utils.prune_ops import *

class dcnNode(InOutNode, CustomNode):
    """Custom pruning node for a deformable-convolution (DCN) module.

    The wrapped module is expected to expose three sub-convolutions:
    ``offset_conv``, ``modulator_conv`` and ``regular_conv``.
    """

    def __init__(self, name: str, module: nn.Module, grad) -> None:
        """Record channel counts and the parameters of ``regular_conv``.

        Args:
            name: Node name, forwarded to the parent constructor.
            module: The DCN module wrapped by this node.
            grad: Forwarded unchanged to the parent constructor.
        """
        super().__init__(name, module, grad)
        regular = module.regular_conv
        # Input width is read from the offset branch, output width from the
        # regular convolution.
        self.in_channels = module.offset_conv.in_channels
        self.out_channels = regular.out_channels
        # Only the regular convolution's weight (and bias, when present) are
        # collected in ``param``.
        self.param = [regular.weight.data]
        if regular.bias is not None:
            self.param.append(regular.bias.data)

    def prune(self):
        """Prune the three sub-convolutions using the node's ``prune_idx``.

        All three convolutions are pruned along ``DIM_IN``; only
        ``regular_conv`` is additionally pruned along ``DIM_OUT``.
        """
        dcn = self.module
        # Sizes are read before any pruning takes place.
        in_size = dcn.offset_conv.weight.shape[DIM_IN]
        out_size = dcn.regular_conv.weight.shape[DIM_OUT]
        self.saved_idx[IDX_IN] = get_saved_idx(self.prune_idx[IDX_IN], in_size)
        self.saved_idx[IDX_OUT] = get_saved_idx(self.prune_idx[IDX_OUT], out_size)

        # NOTE(review): presumably all three convolutions consume the same
        # input tensor, hence the shared input indices -- confirm against the
        # DCN implementation.
        for conv in (dcn.offset_conv, dcn.modulator_conv, dcn.regular_conv):
            prune_conv(conv, self.saved_idx[IDX_IN], DIM_IN)
        prune_conv(dcn.regular_conv, self.saved_idx[IDX_OUT], DIM_OUT)

    def get_attr(self):
        """Return the sub-convolutions' tensors and channel counts.

        Returns:
            dict: Keys are dotted attribute paths relative to the module.
            A ``*.bias.data`` entry is ``None`` when that convolution has
            no bias.
        """

        def bias_data(conv):
            # A convolution without a bias stores ``None`` there.
            if conv.bias is None:
                return None
            return conv.bias.data

        offset = self.module.offset_conv
        modulator = self.module.modulator_conv
        regular = self.module.regular_conv

        attrs = {}
        attrs["offset_conv.weight.data"] = offset.weight.data
        attrs["offset_conv.bias.data"] = bias_data(offset)
        attrs["offset_conv.in_channels"] = offset.in_channels
        attrs["modulator_conv.weight.data"] = modulator.weight.data
        attrs["modulator_conv.bias.data"] = bias_data(modulator)
        attrs["modulator_conv.in_channels"] = modulator.in_channels
        attrs["regular_conv.weight.data"] = regular.weight.data
        attrs["regular_conv.bias.data"] = bias_data(regular)
        attrs["regular_conv.in_channels"] = regular.in_channels
        attrs["regular_conv.out_channels"] = regular.out_channels
        return attrs

Then, add an `igtype2nodetype` dict that maps this module type to the custom node class:

In [12]:
import sys
# Add the tests directory (relative to the current working directory) to the
# module search path; presumably this is where the `model` package imported
# below lives -- verify against the repository layout.
sys.path.append("../../tests/")
from model.backbone.conv_utils.dcn import DeformableConv2d
# Map the module type to the custom node class that should handle it.
igtype2nodetype = {DeformableConv2d: dcnNode}

In [13]:
import torch
from model.radarnet import RCNet
from unip.core.pruner import BasePruner
from unip.utils.evaluation import cal_flops

# Build the model and a random example input.
# NOTE(review): the input is created with requires_grad=True, presumably
# because the pruner needs gradients with respect to it -- confirm.
model = RCNet(in_channels=3)
example_input = torch.randn(1, 3, 320, 320, requires_grad=True)

# Record the FLOPs and parameter count of the unpruned model.
print("original model:")
flops_ori, params_ori = cal_flops(model, example_input)

# Define the pruner, passing the custom module-type -> node-type mapping so
# that DeformableConv2d modules are handled by dcnNode.
BP = BasePruner(
    model,
    example_input,
    "UniformRatio",
    algo_args={"score_fn": "weight_sum_l1_out"}, 
    igtype2nodetype=igtype2nodetype,
)
# NOTE(review): 0.3 is presumably the pruning ratio used by the
# "UniformRatio" algorithm -- confirm against the BasePruner documentation.
BP.prune(0.3)

# Record the FLOPs and parameter count again, after pruning, for comparison.
print("pruned model:")
flops_pruned, params_pruned = cal_flops(model, example_input)

original model:
('184.413M', '40.887K')
pruned model:
('157.885M', '32.123K')
