# Detector: ModelArchitecture

## Models

In [1]:
import torch
import torchvision

# No pretrained weights are requested below, so the parameters are presumably
# randomly initialised. Seed the RNG first so that the models, and any random
# tensors drawn later in the notebook, are repeatable under Restart & Run All.
torch.manual_seed(0)

resnet = torchvision.models.get_model("resnet18").eval()
vit = torchvision.models.get_model("vit_b_16").eval()

## ModelArchitecture

### Construct an instance

In [2]:
from torch.fx import symbolic_trace
from open_xai.detector import ModelArchitecture

# Trace the model with torch.fx to obtain its computation graph, then wrap
# that graph in a ModelArchitecture.
traced_resnet = symbolic_trace(resnet)
ma_resnet = ModelArchitecture(graph=traced_resnet.graph)

  torch.has_cuda,
  torch.has_cudnn,
  torch.has_mps,
  torch.has_mkldnn,


**from_model**: a classmethod that constructs the instance directly from a model

In [3]:

# Build one ModelArchitecture per model directly from the module;
# no explicit symbolic_trace call is needed in this cell.
ma_resnet = ModelArchitecture.from_model(resnet)
ma_vit = ModelArchitecture.from_model(vit)

### Methods

**list_nodes() -> List[NodeInfo]**

List all nodes that make up the model architecture as `NodeInfo` instances.

In [4]:
ma_resnet.list_nodes()[:5]

[NodeInfo(opcode='placeholder', name='x', target='x'),
 NodeInfo(opcode='call_module', name='conv1', target='conv1'),
 NodeInfo(opcode='call_module', name='bn1', target='bn1'),
 NodeInfo(opcode='call_module', name='relu', target='relu'),
 NodeInfo(opcode='call_module', name='maxpool', target='maxpool')]

**get_node(name: str) -> NodeInfo**

Get a node by name.

In [5]:
ma_resnet.get_node(name="conv1")

NodeInfo(opcode='call_module', name='conv1', target='conv1')

**find_node(filter_func: Callable, base_node: Optional[NodeInfo]=None, all: bool=False) -> NodeInfo or List[NodeInfo]**

Find a node satisfying `filter_func`, searching from `base_node` in a binary-search manner.

- `filter_func(n: NodeInfo) -> bool`: a callable that takes a `NodeInfo` and returns a bool
- `base_node: NodeInfo`: the node from which the search starts
- `all: bool`: if True, find all nodes satisfying the condition and return them as a list

In [6]:
from open_xai.detector.filters import conv_filter

# find the first conv module (with the default all=False a single NodeInfo is returned)
ma_resnet.find_node(conv_filter)

NodeInfo(opcode='call_module', name='conv1', target='conv1')

In [7]:
# find all conv modules in the model architecture (all=True returns a list)
ma_resnet.find_node(conv_filter, all=True)

[NodeInfo(opcode='call_module', name='conv1', target='conv1'),
 NodeInfo(opcode='call_module', name='layer1_0_conv1', target='layer1.0.conv1'),
 NodeInfo(opcode='call_module', name='layer1_0_conv2', target='layer1.0.conv2'),
 NodeInfo(opcode='call_module', name='layer1_1_conv1', target='layer1.1.conv1'),
 NodeInfo(opcode='call_module', name='layer1_1_conv2', target='layer1.1.conv2'),
 NodeInfo(opcode='call_module', name='layer2_0_conv1', target='layer2.0.conv1'),
 NodeInfo(opcode='call_module', name='layer2_0_conv2', target='layer2.0.conv2'),
 NodeInfo(opcode='call_module', name='layer2_0_downsample_0', target='layer2.0.downsample.0'),
 NodeInfo(opcode='call_module', name='layer2_1_conv1', target='layer2.1.conv1'),
 NodeInfo(opcode='call_module', name='layer2_1_conv2', target='layer2.1.conv2'),
 NodeInfo(opcode='call_module', name='layer3_0_conv1', target='layer3.0.conv1'),
 NodeInfo(opcode='call_module', name='layer3_0_conv2', target='layer3.0.conv2'),
 NodeInfo(opcode='call_module', 

In [30]:
# custom filter_func
# Import from the public `operator` module rather than the private `_operator`
# implementation module; both expose the same built-in `add` object.
from operator import add


def conn_filter(n):
    """Return True when the node's operator is the built-in `add` function."""
    return n.operator is add


# find all nodes where an addition occurs
ma_resnet.find_node(conn_filter, all=True)

[NodeInfo(opcode='call_function', name='add', target=<built-in function add>),
 NodeInfo(opcode='call_function', name='add_1', target=<built-in function add>),
 NodeInfo(opcode='call_function', name='add_2', target=<built-in function add>),
 NodeInfo(opcode='call_function', name='add_3', target=<built-in function add>),
 NodeInfo(opcode='call_function', name='add_4', target=<built-in function add>),
 NodeInfo(opcode='call_function', name='add_5', target=<built-in function add>),
 NodeInfo(opcode='call_function', name='add_6', target=<built-in function add>),
 NodeInfo(opcode='call_function', name='add_7', target=<built-in function add>)]

In [9]:
# find every node whose output is consumed by exactly two users ("cloning")
def clone_filter(n):
    """Return True when the node has exactly two users."""
    user_count = len(n.users)
    return user_count == 2


ma_resnet.find_node(clone_filter, all=True)

[NodeInfo(opcode='call_module', name='maxpool', target='maxpool'),
 NodeInfo(opcode='call_module', name='layer1_0_relu_1', target='layer1.0.relu'),
 NodeInfo(opcode='call_module', name='layer1_1_relu_1', target='layer1.1.relu'),
 NodeInfo(opcode='call_module', name='layer2_0_relu_1', target='layer2.0.relu'),
 NodeInfo(opcode='call_module', name='layer2_1_relu_1', target='layer2.1.relu'),
 NodeInfo(opcode='call_module', name='layer3_0_relu_1', target='layer3.0.relu'),
 NodeInfo(opcode='call_module', name='layer3_1_relu_1', target='layer3.1.relu'),
 NodeInfo(opcode='call_module', name='layer4_0_relu_1', target='layer4.0.relu')]

**find_cam_target_node() -> NodeInfo**

Find the CAM-targetable node. Internally, it uses the `find_node` method with a predefined `filter_func` to detect CAM-targetable nodes.

In [10]:
# CAM-target node in resnet; kept in `node` because the cells below inspect it
node = ma_resnet.find_cam_target_node()
node

NodeInfo(opcode='call_module', name='layer4_1_relu_1', target='layer4.1.relu')

In [11]:
# No CAM target in vit: the notebook shows no output for this call
# (presumably it returns None when nothing matches; TODO confirm)
ma_vit.find_cam_target_node()

## NodeInfo

`NodeInfo` is a dataclass that refers to a `node: torch.fx.Node` in the graph. It behaves like the node it refers to by cloning that node's main attributes. In addition, it provides several properties that help with detection.

In [12]:
node

NodeInfo(opcode='call_module', name='layer4_1_relu_1', target='layer4.1.relu')

In [13]:
# referring (original) node
node._node

layer4_1_relu_1

In [14]:
type(node._node)

torch.fx.node.Node

In [15]:
type(node)

open_xai.detector._core.NodeInfo

### Cloned attributes from `torch.fx.Node`

Please see [docs](https://pytorch.org/docs/stable/fx.html#torch.fx.Node) for `torch.fx.Node`.

In [16]:
# clone of n.op
node.opcode

'call_module'

In [17]:
# clone of n.name
node.name

'layer4_1_relu_1'

In [18]:
# clone of n.target (accessible name if n.opcode == "call_module")
node.target

'layer4.1.relu'

In [19]:
# clone of n.args
node.args

(NodeInfo(opcode='call_function', name='add_7', target=<built-in function add>),)

In [20]:
# clone of n.kwargs
node.kwargs

{}

In [21]:
# clone of n.users
node.users

(NodeInfo(opcode='call_module', name='avgpool', target='avgpool'),)

In [22]:
# clone of n.next
node.next

NodeInfo(opcode='call_module', name='avgpool', target='avgpool')

In [23]:
# clone of n.prev
node.prev

NodeInfo(opcode='call_function', name='add_7', target=<built-in function add>)

### Properties helping detection

In [24]:
# directly get the matched operator from the node
node.operator

ReLU(inplace=True)

In [25]:
# The operator exposed by the node is the very same object as the submodule
# stored on the model. `get_submodule` resolves the dotted target path
# (e.g. "layer4.1.relu") without a hand-written getattr loop, so no loop
# variable leaks into the notebook namespace.
operator_from_model = resnet.get_submodule(node.target)

operator_from_model is node.operator

True

In [26]:
# detection example using node.operator

def pool_filter(n):
    """Return True for call_module nodes whose operator is defined in torch.nn.modules.pooling."""
    # Only call_module nodes are considered; bail out early for everything else.
    if n.opcode != "call_module":
        return False
    return n.operator.__module__ == "torch.nn.modules.pooling"


ma_resnet.find_node(pool_filter, all=True)

[NodeInfo(opcode='call_module', name='maxpool', target='maxpool'),
 NodeInfo(opcode='call_module', name='avgpool', target='avgpool')]

### Converting structure

In [27]:
# to dictionary
node.to_dict()

{'opcode': 'call_module', 'name': 'layer4_1_relu_1', 'target': 'layer4.1.relu'}

In [28]:
# [TODO] to json
# NOTE(review): no output is shown for this cell in the notebook; confirm
# that `to_json` is implemented and what it returns.
node.to_json()

## Example: GradCAM

In [29]:
import torch
from captum.attr import LayerGradCam

inputs = torch.randn(1, 3, 224, 224)

# Only the predicted class index is needed from this forward pass, so run it
# without autograd tracking; LayerGradCam performs its own forward/backward.
with torch.no_grad():
    target = resnet(inputs).argmax(1).item()

# The layer to explain is taken directly from the detected CAM-target node.
explainer = LayerGradCam(resnet, layer=node.operator)  # here
attrs = explainer.attribute(inputs, target=target)
attrs

tensor([[[[0.0545, 0.0553, 0.0552, 0.0601, 0.0620, 0.0594, 0.0419],
          [0.0821, 0.0789, 0.0940, 0.1166, 0.1117, 0.1091, 0.0835],
          [0.0818, 0.0944, 0.0925, 0.1173, 0.1222, 0.1291, 0.0880],
          [0.0750, 0.0921, 0.0998, 0.1225, 0.1283, 0.1159, 0.0807],
          [0.0835, 0.1015, 0.1190, 0.1267, 0.1299, 0.1091, 0.0859],
          [0.0726, 0.1072, 0.0970, 0.1198, 0.1027, 0.1030, 0.0893],
          [0.0446, 0.0804, 0.0749, 0.0937, 0.0806, 0.0742, 0.0696]]]],
       grad_fn=<SumBackward1>)