# Detector: ModelArchitecture

- author: gh
- last update: 2023-10-30


## Models

In [1]:
import torchvision

# Build both backbones from their torchvision registry names and switch them to
# inference mode (`.eval()` returns the module itself) before wrapping them.
resnet, vit = (
    torchvision.models.get_model(arch_name).eval()
    for arch_name in ("resnet18", "vit_b_16")
)

  warn(


## ModelArchitecture

### Construct an instance

In [2]:
from torch.fx import symbolic_trace
from open_xai.detector import ModelArchitecture

# Wrap each model; the wrapper exposes the symbolically traced graph as
# `.traced_model`. (The former `ModelArchitecture.from_model(model)` constructor
# has been removed — pass the model through the `model=` keyword instead.)
ma_resnet, ma_vit = (ModelArchitecture(model=m) for m in (resnet, vit))

### Attributes

**model**: the input model

**traced_model**: the model traced by `torch.fx.symbolic_trace`

In [5]:
# Print the traced FX graph of ResNet-18 as a table (opcode / name / target / args / kwargs).
ma_resnet.traced_model.graph.print_tabular()

opcode         name                   target                                                      args                                   kwargs
-------------  ---------------------  ----------------------------------------------------------  -------------------------------------  --------
placeholder    x                      x                                                           ()                                     {}
call_module    conv1                  conv1                                                       (x,)                                   {}
call_module    bn1                    bn1                                                         (conv1,)                               {}
call_module    relu                   relu                                                        (bn1,)                                 {}
call_module    maxpool                maxpool                                                     (relu,)                                {}
call_modul

### Methods

**list_nodes() -> List[NodeInfo]**

List all nodes that make up the model architecture, as `NodeInfo` instances.

In [6]:
# Show only the first five entries; `list_nodes()` returns every node of the traced graph.
ma_resnet.list_nodes()[:5]

[NodeInfo(opcode='placeholder', name='x', target='x'),
 NodeInfo(opcode='call_module', name='conv1', target='conv1'),
 NodeInfo(opcode='call_module', name='bn1', target='bn1'),
 NodeInfo(opcode='call_module', name='relu', target='relu'),
 NodeInfo(opcode='call_module', name='maxpool', target='maxpool')]

**get_node(name: str) -> NodeInfo**

Get a node by name.

In [7]:
# Look up a single node by its graph name.
ma_resnet.get_node(name="conv1")

NodeInfo(opcode='call_module', name='conv1', target='conv1')

**find_node(filter_func: Callable, base_node: Optional[NodeInfo]=None, all: bool=False) -> NodeInfo or List[NodeInfo]**

Find the node satisfying `filter_func`, searching from `base_node` in a binary-search manner.

- `filter_func(n: NodeInfo) -> bool`: a callable that takes a `NodeInfo` and returns a bool
- `base_node: NodeInfo`: the node where the search starts (note: the example below passes it as `root=`; confirm the actual parameter name)
- `all: bool`: if True, find all nodes satisfying the condition and return them as a list

In [8]:
from torch import nn


def conv_filter(node):
    """Return True when the node's operator is an ``nn.Conv2d`` module."""
    return isinstance(node.operator, nn.Conv2d)


# Without `all=True`, `find_node` returns a single node: the first conv module.
conv_node = ma_resnet.find_node(conv_filter)
print(conv_node)

NodeInfo(opcode='call_module', name='conv1', target='conv1')


In [9]:
# Find all conv modules in the model architecture (only the first five are shown).
ma_resnet.find_node(conv_filter, all=True)[:5]

[NodeInfo(opcode='call_module', name='conv1', target='conv1'),
 NodeInfo(opcode='call_module', name='layer1_0_conv1', target='layer1.0.conv1'),
 NodeInfo(opcode='call_module', name='layer1_0_conv2', target='layer1.0.conv2'),
 NodeInfo(opcode='call_module', name='layer1_1_conv1', target='layer1.1.conv1'),
 NodeInfo(opcode='call_module', name='layer1_1_conv2', target='layer1.1.conv2')]

In [10]:
# Custom filter_func: match nodes that call the built-in addition function.
# Import from the public `operator` module rather than the private `_operator`
# C module; `operator.add` is the very same function object.
from operator import add


def add_filter(n):
    """Return True when the node's operator is exactly `operator.add`."""
    return n.operator is add


# Find every node where an addition occurs (first five shown).
ma_resnet.find_node(add_filter, all=True)[:5]

[NodeInfo(opcode='call_function', name='add', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_1', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_2', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_3', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_4', target='operator.add')]

In [11]:
def multiuser_filter(n):
    """Return True for nodes whose output is consumed by more than one user node."""
    return len(n.users) > 1


# Find every node whose result fans out to several consumers (first five shown).
# In this ResNet these are presumably the tensors that also feed a skip connection.
ma_resnet.find_node(multiuser_filter, all=True)[:5]

[NodeInfo(opcode='call_module', name='maxpool', target='maxpool'),
 NodeInfo(opcode='call_module', name='layer1_0_relu_1', target='layer1.0.relu'),
 NodeInfo(opcode='call_module', name='layer1_1_relu_1', target='layer1.1.relu'),
 NodeInfo(opcode='call_module', name='layer2_0_relu_1', target='layer2.0.relu'),
 NodeInfo(opcode='call_module', name='layer2_1_relu_1', target='layer2.1.relu')]

In [12]:
# Find the CAM target layer using `find_node`.

first_conv_node = ma_resnet.find_node(conv_filter)
print("first_conv_node:", first_conv_node)


def pooling_filter(node):
    """Return True for `call_module` nodes whose module class lives in `torch.nn.modules.pooling`."""
    return (
        node.opcode == "call_module"
        and node.operator.__module__ == "torch.nn.modules.pooling"
    )


# Initialise up front so a failed search shows up as an explicit None instead of
# a NameError (or a stale value left over from an earlier run of this cell).
cam_target_node = None

if first_conv_node:
    pool_nodes = ma_resnet.find_node(
        pooling_filter,  # find nodes satisfying the pooling filter
        root=first_conv_node,  # the search starts from the first conv node
        # NOTE(review): the `find_node` signature documented above names this
        # argument `base_node`; confirm which keyword the implementation accepts.
        all=True,
    )
    # Guard against an empty result so `[-1]` cannot raise IndexError.
    if pool_nodes:
        last_pool_node = pool_nodes[-1]  # the final pooling layer
        print("last_pool_node:", last_pool_node)
        # The CAM target is the node right before the last pooling node
        # (its predecessor in graph order, via `NodeInfo.prev`).
        cam_target_node = last_pool_node.prev

print("cam_target_node:", cam_target_node)

first_conv_node: NodeInfo(opcode='call_module', name='conv1', target='conv1')
last_pool_node: NodeInfo(opcode='call_module', name='avgpool', target='avgpool')
cam_target_node: NodeInfo(opcode='call_module', name='layer4_1_relu_1', target='layer4.1.relu')


In [13]:
# Check the module stack: the (qualified name -> nn.Module class) chain stored in
# the node's `meta["nn_module_stack"]`.
cam_target_node.meta

{'nn_module_stack': OrderedDict([('layer4',
               torch.nn.modules.container.Sequential),
              ('layer4.1', torchvision.models.resnet.BasicBlock),
              ('layer4.1.relu', torch.nn.modules.activation.ReLU)])}

**replace_node(node: NodeInfo, new_node: NodeInfo) -> self**

Replace a node with a new node. The current version supports only:

- replacement of a function node (`node.opcode == "call_function"`) with a module node (`new_node.opcode == "call_module"`)

In [14]:
# Replace function nodes with module nodes for the addition operations.

# In torchvision's ResNet the additions are traced as function calls
# (`call_function` nodes targeting `operator.add`), not as modules.
add_nodes = ma_resnet.find_node(add_filter, all=True)
add_nodes[:5]

[NodeInfo(opcode='call_function', name='add', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_1', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_2', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_3', target='operator.add'),
 NodeInfo(opcode='call_function', name='add_4', target='operator.add')]

In [15]:
from torch import nn
from open_xai.detector._core import NodeInfo


class Add(nn.Module):
    """Binary addition wrapped as a module so it can back a `call_module` node."""

    def forward(self, a, b):
        return a + b


# Swap every `operator.add` function node for a node backed by its own `Add` instance.
# `name`/`target` are passed as None; NOTE(review): presumably `replace_node`
# assigns them — confirm.
for add_node in add_nodes:
    ma_resnet.replace_node(
        add_node,
        NodeInfo(opcode="call_module", name=None, target=None, _operator=Add()),
    )

In [16]:
# Re-print the graph to check that the add nodes changed from call_function to call_module.
ma_resnet.traced_model.graph.print_tabular()

opcode         name                   target                                                      args                                   kwargs
-------------  ---------------------  ----------------------------------------------------------  -------------------------------------  --------
placeholder    x                      x                                                           ()                                     {}
call_module    conv1                  conv1                                                       (x,)                                   {}
call_module    bn1                    bn1                                                         (conv1,)                               {}
call_module    relu                   relu                                                        (bn1,)                                 {}
call_module    maxpool                maxpool                                                     (relu,)                                {}
call_modul

## NodeInfo

`NodeInfo` is a dataclass that refers to a `node: torch.fx.Node` in the graph. It behaves like the referenced node by cloning its main attributes, and it adds several properties that help with detection.

In [17]:
# Grab the `NodeInfo` for the first convolution; the cells below explore its fields.
example_node = ma_resnet.get_node("conv1")
example_node

NodeInfo(opcode='call_module', name='conv1', target='conv1')

In [18]:
# The referred (original) `torch.fx` node, kept in the private `_node` field.
example_node._node

conv1

In [19]:
# The wrapped object is a plain `torch.fx` Node ...
type(example_node._node)

torch.fx.node.Node

In [20]:
# ... while the wrapper itself is open_xai's `NodeInfo` dataclass.
type(example_node)

open_xai.detector._core.NodeInfo

### Cloned attributes from `torch.fx.Node`

Please see [docs](https://pytorch.org/docs/stable/fx.html#torch.fx.Node) for `torch.fx.Node`.

In [21]:
# clone of n.op (exposed under the name `opcode`)
example_node.opcode

'call_module'

In [22]:
# clone of n.name (the unique node name within the graph)
example_node.name

'conv1'

In [23]:
# clone of n.target (for a call_module node: the dotted attribute path of the submodule)
example_node.target

'conv1'

In [24]:
# clone of n.meta (e.g. the recorded `nn_module_stack`)
example_node.meta

{'nn_module_stack': OrderedDict([('conv1', torch.nn.modules.conv.Conv2d)])}

In [25]:
# clone of n.args, with input nodes wrapped as `NodeInfo`
example_node.args

(NodeInfo(opcode='placeholder', name='x', target='x'),)

In [26]:
# clone of n.kwargs
example_node.kwargs

{}

In [27]:
# clone of n.users, returned as a tuple of `NodeInfo`
example_node.users

(NodeInfo(opcode='call_module', name='bn1', target='bn1'),)

In [28]:
# clone of n.next (the following node in graph order)
example_node.next

NodeInfo(opcode='call_module', name='bn1', target='bn1')

In [29]:
# clone of n.prev (the preceding node in graph order)
example_node.prev

NodeInfo(opcode='placeholder', name='x', target='x')

### Properties helping detection

#### `operator`

In [30]:
# Directly get the matched operator from the node (here the `nn.Conv2d` module itself).
example_node.operator

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [31]:
# The operator is the very same object as the submodule in the original model:
# resolve the node's dotted `target` path with `nn.Module.get_submodule`
# instead of walking the attributes by hand.
operator_from_model = resnet.get_submodule(example_node.target)

operator_from_model is example_node.operator

True

#### `owning_module -> Optional[Tuple[str, nn.Module]]`

Get the name and the module that own the node.

In [33]:
# If node.opcode == "call_module", its owning module is the operator itself;
# the result is a (name, module) tuple.
example_node.owning_module

('conv1',
 Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False))

### Converting to other data structures

In [36]:
# Convert to a plain dictionary of the main fields (opcode, name, target).
example_node.to_dict()

{'opcode': 'call_module', 'name': 'conv1', 'target': 'conv1'}

## Example: GradCAM

In [34]:
import torch
from captum.attr import LayerGradCam

# Seed the RNG so the random example input is repeatable across runs.
# NOTE(review): `resnet` was built via `get_model` without a `weights=` argument,
# so its weights are presumably random too — seed before building it for full
# reproducibility.
torch.manual_seed(0)
inputs = torch.randn(1, 3, 224, 224)

# Explain the model's own top-1 prediction; this forward pass needs no gradients.
with torch.no_grad():
    target = resnet(inputs).argmax(1).item()

# `find_cam_target_module` returns a (name, module) pair; the module is the layer
# that GradCAM attributes to.
cam_tgt_nm, cam_tgt_mod = ma_resnet.find_cam_target_module()
explainer = LayerGradCam(resnet, layer=cam_tgt_mod)
attrs = explainer.attribute(inputs, target=target)
attrs

tensor([[[[0.0561, 0.0550, 0.0563, 0.0557, 0.0578, 0.0479, 0.0391],
          [0.0893, 0.0940, 0.0961, 0.1078, 0.1153, 0.0864, 0.0563],
          [0.0915, 0.1082, 0.1221, 0.1191, 0.1151, 0.0974, 0.0728],
          [0.0953, 0.1204, 0.1017, 0.0955, 0.0985, 0.0807, 0.0548],
          [0.0856, 0.1008, 0.1039, 0.1026, 0.1004, 0.0846, 0.0578],
          [0.0855, 0.0901, 0.1019, 0.1059, 0.0977, 0.0781, 0.0526],
          [0.0545, 0.0746, 0.0666, 0.0717, 0.0683, 0.0714, 0.0493]]]],
       grad_fn=<SumBackward1>)