# Detector: ModelArchitecture

- author: gh
- last update: 2023-10-30


## Models

In [1]:
import torchvision

resnet = torchvision.models.get_model("resnet18").eval()
vit = torchvision.models.get_model("vit_b_16").eval()

## ModelArchitecture

### Construct an instance

In [2]:
from torch.fx import symbolic_trace
from open_xai.detector import ModelArchitecture

traced_resnet = symbolic_trace(resnet)
ma_resnet = ModelArchitecture(graph=traced_resnet.graph)

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


**from_model** a classmethod directly constructing the instance from a model

In [3]:

ma_resnet = ModelArchitecture.from_model(resnet)
ma_vit = ModelArchitecture.from_model(vit)

### Methods

**list_nodes() -> List[NodeInfo]**

List all nodes consisting of the model architecture as `NodeInfo` class.

In [4]:
ma_resnet.list_nodes()[:5]

[NodeInfo(opcode='placeholder', name='x', target='x'),
 NodeInfo(opcode='call_module', name='conv1', target='conv1'),
 NodeInfo(opcode='call_module', name='bn1', target='bn1'),
 NodeInfo(opcode='call_module', name='relu', target='relu'),
 NodeInfo(opcode='call_module', name='maxpool', target='maxpool')]

**get_node(name: str) -> NodeInfo**

Get a node by name.

In [5]:
ma_resnet.get_node(name="conv1")

NodeInfo(opcode='call_module', name='conv1', target='conv1')

**find_node(filter_func: Callable, base_node: Optional[NodeInfo]=None, all: bool=False) -> NodeInfo or List[NodeInfo]**

Find node satisfying `filter_func` from the `base_node` in binary search manner.

- `filter_func(n: NodeInfo) -> bool`: a callable returning bool from a `node: NodeInfo`
- `base_node: NodeInfo`: a node where the searching will be started
- `all: bool`: if True, find all nodes satisfying condition and return a list of them

In [6]:
from open_xai.detector.filters import conv_filter

# find the first conv module
ma_resnet.find_node(conv_filter)

NodeInfo(opcode='call_module', name='conv1', target='conv1')

In [7]:
# find all conv module from the model architecture
ma_resnet.find_node(conv_filter, all=True)

[NodeInfo(opcode='call_module', name='conv1', target='conv1'),
 NodeInfo(opcode='call_module', name='layer1_0_conv1', target='layer1.0.conv1'),
 NodeInfo(opcode='call_module', name='layer1_0_conv2', target='layer1.0.conv2'),
 NodeInfo(opcode='call_module', name='layer1_1_conv1', target='layer1.1.conv1'),
 NodeInfo(opcode='call_module', name='layer1_1_conv2', target='layer1.1.conv2'),
 NodeInfo(opcode='call_module', name='layer2_0_conv1', target='layer2.0.conv1'),
 NodeInfo(opcode='call_module', name='layer2_0_conv2', target='layer2.0.conv2'),
 NodeInfo(opcode='call_module', name='layer2_0_downsample_0', target='layer2.0.downsample.0'),
 NodeInfo(opcode='call_module', name='layer2_1_conv1', target='layer2.1.conv1'),
 NodeInfo(opcode='call_module', name='layer2_1_conv2', target='layer2.1.conv2'),
 NodeInfo(opcode='call_module', name='layer3_0_conv1', target='layer3.0.conv1'),
 NodeInfo(opcode='call_module', name='layer3_0_conv2', target='layer3.0.conv2'),
 NodeInfo(opcode='call_module', 

In [8]:
# custom filter_func
from _operator import add

# find all where addition occurs
conn_filter = lambda n: n.operator is add
ma_resnet.find_node(conn_filter, all=True)

[NodeInfo(opcode='call_function', name='add', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_1', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_2', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_3', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_4', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_5', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_6', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_7', target='operator.add')]

In [9]:
# find all where cloning occurs
clone_filter = lambda n: len(n.users) == 2
ma_resnet.find_node(clone_filter, all=True)

[NodeInfo(opcode='call_module', name='maxpool', target='maxpool'),
 NodeInfo(opcode='call_module', name='layer1_0_relu_1', target='layer1.0.relu'),
 NodeInfo(opcode='call_module', name='layer1_1_relu_1', target='layer1.1.relu'),
 NodeInfo(opcode='call_module', name='layer2_0_relu_1', target='layer2.0.relu'),
 NodeInfo(opcode='call_module', name='layer2_1_relu_1', target='layer2.1.relu'),
 NodeInfo(opcode='call_module', name='layer3_0_relu_1', target='layer3.0.relu'),
 NodeInfo(opcode='call_module', name='layer3_1_relu_1', target='layer3.1.relu'),
 NodeInfo(opcode='call_module', name='layer4_0_relu_1', target='layer4.0.relu')]

**find_cam_target_module() -> Optional[Tuple[str, nn.Module]]**

Find CAM-target candidate name and module object. Internally, it uses `find_node` method with predefined `filter_func` to detect the candidates, and finally get `owning_module` of the  node.

In [10]:
# CAM-target node in resnet
ma_resnet.find_cam_target_module()

('layer4.1.relu', ReLU(inplace=True))

If there is no module satisfying the predefined conditions, output `None`.

In [11]:
# No CAM target in vit
ma_vit.find_cam_target_module()

## NodeInfo

`NodeInfo` is a dataclass referred to a `node: torch.fx.Node` in the graph. It is working like the referring node by cloning its main attributions. Also, several properties helping detection were added.

In [12]:
example_node = ma_resnet.get_node("conv1")
example_node

NodeInfo(opcode='call_module', name='conv1', target='conv1')

In [13]:
# referring (original) node
example_node._node

conv1

In [14]:
type(example_node._node)

torch.fx.node.Node

In [15]:
type(example_node)

open_xai.detector._core.NodeInfo

### Cloned attributions from `torch.fx.Node` 

Please see [docs](https://pytorch.org/docs/stable/fx.html#torch.fx.Node) for `torch.fx.Node`.

In [16]:
# clone of n.op
example_node.opcode

'call_module'

In [17]:
# clone of n.name
example_node.name

'conv1'

In [18]:
# clone of n.target (accessible name if n.opcode == "call_module")
example_node.target

'conv1'

In [19]:
# clone of n.meta
example_node.meta

{'nn_module_stack': OrderedDict([('conv1', torch.nn.modules.conv.Conv2d)])}

In [20]:
# clone of n.args
example_node.args

(NodeInfo(opcode='placeholder', name='x', target='x'),)

In [21]:
# clone of n.kwargs
example_node.kwargs

{}

In [22]:
# clone of n.users
example_node.users

(NodeInfo(opcode='call_module', name='bn1', target='bn1'),)

In [23]:
# clone of n.next
example_node.next

NodeInfo(opcode='call_module', name='bn1', target='bn1')

In [24]:
# clone of n.prev
example_node.prev

NodeInfo(opcode='placeholder', name='x', target='x')

### Properties helping detection

#### `operator`

In [25]:
# directly get the matched operator from the node
example_node.operator

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [26]:
# definitely same as one from model
operator_from_model = resnet
for t in example_node.target.split("."):
    operator_from_model = getattr(operator_from_model, t)

operator_from_model is example_node.operator

True

In [27]:
# detection example using node.operator

pool_filter = lambda n: (
    n.opcode == "call_module"
    and n.operator.__module__ == "torch.nn.modules.pooling"
)
ma_resnet.find_node(pool_filter, all=True)

[NodeInfo(opcode='call_module', name='maxpool', target='maxpool'),
 NodeInfo(opcode='call_module', name='avgpool', target='avgpool')]

#### `owning_module -> Optional[Tuple[str, nn.Module]]`

Get a module and name owning the node

In [28]:
# If node.opcode == "call_module", its owning module is the operator itself.
example_node.owning_module

('conv1',
 Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False))

In [29]:
print(example_node.target == example_node.owning_module[0])
print(example_node.operator == example_node.owning_module[1])

True
True


In [30]:
# If node.opcode == "call_function", it shows which module is containing the operator function.
func_node = ma_resnet.get_node("add")
func_node.owning_module

('layer1.0',
 Module(
   (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
   (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
 ))

In [31]:
# (node.target, node.operator) is totally different from `owning_module`
print(func_node.target)
print(func_node.operator)

operator.add
<built-in function add>


### Converting data structure

In [32]:
# to dictionary
example_node.to_dict()

{'opcode': 'call_module', 'name': 'conv1', 'target': 'conv1'}

In [33]:
func_node.to_dict()

{'opcode': 'call_function', 'name': 'add', 'target': 'operator.add'}

## Example: GradCAM

In [34]:
import torch
from captum.attr import LayerGradCam

inputs = torch.randn(1, 3, 224, 224)

target = resnet(inputs).argmax(1).item()
cam_tgt_nm, cam_tgt_mod = ma_resnet.find_cam_target_module()
explainer = LayerGradCam(resnet, layer=cam_tgt_mod) # here
attrs = explainer.attribute(inputs, target=target)
attrs

tensor([[[[0.0561, 0.0550, 0.0563, 0.0557, 0.0578, 0.0479, 0.0391],
          [0.0893, 0.0940, 0.0961, 0.1078, 0.1153, 0.0864, 0.0563],
          [0.0915, 0.1082, 0.1221, 0.1191, 0.1151, 0.0974, 0.0728],
          [0.0953, 0.1204, 0.1017, 0.0955, 0.0985, 0.0807, 0.0548],
          [0.0856, 0.1008, 0.1039, 0.1026, 0.1004, 0.0846, 0.0578],
          [0.0855, 0.0901, 0.1019, 0.1059, 0.0977, 0.0781, 0.0526],
          [0.0545, 0.0746, 0.0666, 0.0717, 0.0683, 0.0714, 0.0493]]]],
       grad_fn=<SumBackward1>)

## Example: LRP

In [35]:
from zennit.composites import EpsilonGammaBox
from zennit.canonizers import SequentialMergeBatchNorm
from zennit.attribution import Gradient

canonizers = [SequentialMergeBatchNorm()]
composite = EpsilonGammaBox(low=-3., high=3., canonizers=canonizers)

with Gradient(model=resnet, composite=composite) as attributor:
    _, attrs = attributor(inputs, torch.eye(1000)[[0]])

attrs

tensor([[[[1.0107e-03, 1.5427e-03, 1.2988e-03,  ..., 3.6572e-05,
           1.3705e-05, 8.6201e-06],
          [1.2498e-03, 1.6644e-03, 1.5210e-03,  ..., 2.9877e-05,
           1.0165e-05, 1.3933e-05],
          [1.3658e-03, 1.3569e-03, 1.6763e-03,  ..., 3.8471e-05,
           1.3477e-05, 1.1588e-05],
          ...,
          [5.0055e-05, 6.1396e-05, 6.7468e-05,  ..., 1.5492e-05,
           1.4921e-05, 1.2674e-05],
          [1.4413e-05, 2.1118e-05, 3.6859e-05,  ..., 1.3347e-05,
           5.0072e-06, 4.1874e-06],
          [1.4605e-05, 2.5482e-05, 2.8178e-05,  ..., 1.1906e-05,
           4.5209e-06, 3.5785e-06]],

         [[1.0090e-03, 1.4876e-03, 1.1762e-03,  ..., 1.8069e-05,
           1.0676e-05, 8.1778e-06],
          [1.3539e-03, 2.9663e-03, 2.0740e-03,  ..., 4.6878e-05,
           1.5746e-05, 1.9380e-05],
          [1.0943e-03, 1.6252e-03, 1.5074e-03,  ..., 4.3538e-05,
           1.0312e-05, 7.5514e-06],
          ...,
          [5.4181e-05, 7.2769e-05, 6.7374e-05,  ..., 1.8367

## Example: RAP

**1. hook is working with `node.operator` if `node.opcode=="call_module"`!**

In [36]:
def test_hook(m, x, y):
    # some rule assignment by register hook
    # if not hasattr(m, "rule"): m.rule = rule
    print(f"{m.__class__.__name__} is hooked!")
    
hook_handles = [example_node.operator.register_forward_hook(test_hook)]
_ = resnet(inputs) # forward pass

# remove hooks
for h in hook_handles:
    h.remove()

Conv2d is hooked!


**[EXPERIMENTAL] To control residual connection, convert `add` function to `Add` module and register hook.**

In [37]:
from _operator import add
from torch import nn
from torch.fx import GraphModule, symbolic_trace

# ./detector/converted_modules.py
class Add(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        return torch.add(a, b)

# ./detector/utils.py or model architecture's method
def convert_add_func_to_mod(ma):
    for n in ma.graph.nodes:
        if n.target is add:
            if not hasattr(ma.graph.owning_module, "_add_converted"):
                # ma.graph.owning_module._add_converted = Add()
                setattr(ma.graph.owning_module, "_add_converted", Add())
            with ma.graph.inserting_after(n):
                new_node = ma.graph.call_module("_add_converted", n.args, n.kwargs)
                n.replace_all_uses_with(new_node)
            ma.graph.erase_node(n)
    ma.graph.lint()
    ma.graph.owning_module.recompile()
    return ma

# register hook
ma_resnet = convert_add_func_to_mod(ma_resnet)

def add_hook(m, x, y):
    print(f"{m} is hooked!")

hook_handles = [
    n.operator.register_forward_hook(add_hook) for n in ma_resnet.list_nodes()
    if n.name == "_add_converted"
]

hook_handles += [
    n.operator.register_forward_hook(test_hook) for n in ma_resnet.list_nodes()
    if n.opcode == "call_module" and n.operator.__module__ == "torch.nn.modules.pooling"
]

print("---hook for converted module works only in the traced model---")
_ = ma_resnet.graph.owning_module(inputs)

print("\n---hook for converted module is not registered in the original model---")
resnet(inputs)
for h in hook_handles:
    h.remove()
# resnet(inputs)

---hook for converted module works only in the traced model---
MaxPool2d is hooked!
Add() is hooked!
Add() is hooked!
Add() is hooked!
Add() is hooked!
Add() is hooked!
Add() is hooked!
Add() is hooked!
Add() is hooked!
AdaptiveAvgPool2d is hooked!

---hook for converted module is not registered in the original model---
MaxPool2d is hooked!
AdaptiveAvgPool2d is hooked!
