In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm
from torchvision import datasets, transforms
from spikingjelly.activation_based import neuron, functional, surrogate, layer

# Spiking network for MNIST: one bias-free fully-connected layer followed by LIF
# spiking neurons. There is no convolution despite the `conv_fc` attribute name;
# the name is kept because renaming it would change the model's state_dict keys.
class MNIST(nn.Module):
    """Single-layer SNN classifier simulated for ``T`` time steps.

    The same static input is presented at every time step, the whole sequence is
    run through the layers in spikingjelly multi-step mode, and the output is
    averaged over time.

    Args:
        T: Number of simulation time steps; must be >= 1.
        use_cupy: If True, switch the spikingjelly modules to the ``'cupy'``
            backend via ``functional.set_backend``.

    Raises:
        ValueError: If ``T`` is smaller than 1.

    NOTE(review): spikingjelly neurons keep membrane state between calls; callers
    presumably run ``functional.reset_net(model)`` after each batch — confirm in
    the training loop.
    """

    def __init__(self, T=10, use_cupy=False):
        super().__init__()
        if T < 1:
            # A mean over an empty time axis would silently yield NaN outputs.
            raise ValueError(f"T must be a positive integer, got {T}")
        self.T = T

        self.conv_fc = nn.Sequential(
            layer.Flatten(),
            layer.Linear(28 * 28, 10, bias=False),
            neuron.LIFNode(tau=2.0, surrogate_function=surrogate.ATan())
        )

        # Multi-step mode: each layer receives the full [T, N, ...] sequence at once.
        functional.set_step_mode(self, step_mode='m')

        if use_cupy:
            functional.set_backend(self, backend='cupy')

    def forward(self, x: torch.Tensor):
        """Run the network for ``self.T`` steps on a static input batch.

        Args:
            x: Input batch with the batch dimension first, e.g. ``[N, C, H, W]``
                or ``[N, H, W]``; the non-batch dimensions must flatten to
                ``28 * 28`` features to match the Linear layer.

        Returns:
            Tensor of shape ``[N, 10]``: the LIF layer's output averaged over the
            ``T`` time steps (the per-class firing rate).
        """
        # Prepend a time axis and copy the input T times: [N, ...] -> [T, N, ...].
        # repeat_interleave on the new size-1 axis works for any input rank,
        # unlike a hard-coded repeat(T, 1, 1, 1, 1) that assumes 4-D input.
        x_seq = x.unsqueeze(0).repeat_interleave(self.T, dim=0)
        x_seq = self.conv_fc(x_seq)
        fr = x_seq.mean(0)
        return fr


In [10]:
# import torch
# # Simple module for demonstration
# class MyModule(torch.nn.Module):
#     def __init__(self) -> None:
#         super().__init__()
#         self.param = torch.nn.Parameter(torch.rand(3, 4))
#         self.linear = torch.nn.Linear(4, 5)

#     def forward(self, x):
#         return self.linear(x + self.param).clamp(min=0.0, max=1.0)

# module = MyModule()

# from torch.fx import symbolic_trace
# # Symbolic tracing frontend - captures the semantics of the module
# symbolic_traced : torch.fx.GraphModule = symbolic_trace(module)

# # High-level intermediate representation (IR) - Graph representation
# print(symbolic_traced.graph)
# print(symbolic_traced.code)

In [11]:
from nnviz import inspection, drawing

# Build the SNN and ask nnviz for its computation graph.
# NOTE(review): this cell raises the torch.fx RuntimeError shown in its output.
# The failure presumably comes from the spikingjelly multi-step layers/neurons
# (Python control flow on tensor shapes that symbolic tracing cannot follow);
# confirm before relying on nnviz for this model.
model = MNIST()
inspector = inspection.TorchFxInspector()
# No example tensors are passed to nnviz. Named `example_inputs` instead of
# `input` so the builtin `input()` is not shadowed in the notebook namespace.
example_inputs = None
graph = inspector.inspect(model, inputs=example_inputs)

print(graph)

RuntimeError: You provided a model that cannot be traced with torch.fx and you may need to apply some changes to your source code. See the following link for more information: https://pytorch.org/docs/stable/fx.html#tracing

Also keep in mind that dynamic control flow is currently not supported by torch.fx (e.g. if statements, while loops, etc.) and that you may need to manually convert your model to a static graph before tracing it. The easiest way to trace dynamic models is using the torchscript compiler, which is currently not supported by nnviz. If you really need to plot a dynamic model and you have some good ideas on how to implement it, please open an issue on GitHub and we will discuss it. Be warned that this is not an easy task and that it may require a lot of work.