In [20]:
import torch
import torch.nn as nn
from spikingjelly.activation_based import neuron, functional, monitor

# Monitor

`spikingjelly.activation_based.monitor` defines several commonly used monitors that let users record the data they are interested in. Let us try them out.

## Usage

In [22]:
# Fix the global RNG seed so that the randomly initialised Linear weights (and any
# random tensor drawn afterwards) are the same after every "Restart Kernel -> Run All".
# The value 0 is arbitrary; any fixed integer works.
torch.manual_seed(0)

# Fully connected spiking network 8 -> 4 -> 2. Each Linear layer is followed by an
# IF neuron layer created with step_mode='m', so the network is called once on a
# whole input sequence rather than once per time-step.
net = nn.Sequential(
    nn.Linear(in_features=8, out_features=4, bias=True),
    neuron.IFNode(step_mode='m'),
    nn.Linear(in_features=4, out_features=2, bias=True),
    neuron.IFNode(step_mode='m')
)

In [23]:
# Monitor the outputs of every `neuron.IFNode` layer found inside `net`.
spike_seq_monitor = monitor.OutputMonitor(net, neuron.IFNode)

In [24]:
# Random input sequence shaped [T, N, 8]; 8 matches the in_features of the first Linear layer.
T = 4
N = 1
x_seq = torch.rand(T, N, 8)

# Forward pass only: gradients are not needed, the monitor just has to see the outputs.
with torch.no_grad():
    net(x_seq)

tensor([[[0.0799, 0.3078, 0.5607, 0.1194, 0.7777, 0.1256, 0.3183, 0.0404]],

        [[0.4675, 0.3792, 0.3507, 0.4036, 0.2648, 0.8790, 0.2643, 0.8474]],

        [[0.9221, 0.5476, 0.3989, 0.9733, 0.6206, 0.7059, 0.1174, 0.2876]],

        [[0.0225, 0.9355, 0.6947, 0.9885, 0.1407, 0.7220, 0.8831, 0.0446]]])


In [25]:
# Recorded data: `records` is a list holding the monitored outputs collected so far.
print(f'spike_seq_monitor.records=\n{spike_seq_monitor.records}')


spike_seq_monitor.records=
[tensor([[[0., 0., 0., 0.]],

        [[0., 0., 0., 0.]],

        [[1., 0., 0., 0.]],

        [[0., 0., 0., 1.]]]), tensor([[[0., 0.]],

        [[0., 0.]],

        [[0., 0.]],

        [[0., 0.]]])]


In [28]:
# i-th value: an integer index returns the i-th entry of `records`.
print(f'spike_seq_monitor[0]={spike_seq_monitor[0]}')


spike_seq_monitor[0]=tensor([[[0., 0., 0., 0.]],

        [[0., 0., 0., 0.]],

        [[1., 0., 0., 0.]],

        [[0., 0., 0., 1.]]])


In [32]:
# Names of the monitored layers, printed next to the network they refer to
print(f'net={net}')
print(f'spike_seq_monitor.monitored_layers={spike_seq_monitor.monitored_layers}')

net=Sequential(
  (0): Linear(in_features=8, out_features=4, bias=True)
  (1): IFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
    (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
  )
  (2): Linear(in_features=4, out_features=2, bias=True)
  (3): IFNode(
    v_threshold=1.0, v_reset=0.0, detach_reset=False, step_mode=m, backend=torch
    (surrogate_function): Sigmoid(alpha=4.0, spiking=True)
  )
)
spike_seq_monitor.monitored_layers=['1', '3']


In [33]:
# Records of a single layer: index the monitor with a layer name taken from `monitored_layers`
print(f"spike_seq_monitor['1']={spike_seq_monitor['1']}")

spike_seq_monitor['1']=[]


In [34]:
# Delete the recorded data, then show that both the full record list and the
# per-layer lookup are empty afterwards.
spike_seq_monitor.clear_recorded_data()
print(f'spike_seq_monitor.records={spike_seq_monitor.records}')
print(f"spike_seq_monitor['1']={spike_seq_monitor['1']}")

spike_seq_monitor.records=[]
spike_seq_monitor['1']=[]


## remove hooks


In [35]:
# Remove the hooks this monitor registered on `net`; we are done with this monitor.
# NOTE(review): per the SpikingJelly monitor docs it should record nothing after this — confirm.
spike_seq_monitor.remove_hooks()

If we want to record firing rates instead of raw spikes, we can define a function that computes them from a spike sequence:

In [36]:
def cal_firing_rate(s_seq: torch.Tensor):
    """Return the mean firing rate at every time-step.

    Args:
        s_seq: spike sequence with shape ``[T, N, *]``.

    Returns:
        A tensor with shape ``[T]``; entry ``t`` is the mean of ``s_seq[t]``
        taken over all remaining dimensions.
    """
    # Collapse everything except the leading time dimension: [T, N, *] -> [T, N * ...]
    spikes_per_step = s_seq.flatten(start_dim=1)
    # Averaging the collapsed dimension leaves one value per time-step.
    return spikes_per_step.mean(dim=1)

Then, we can set this function as `function_on_output` to get a firing rates monitor:

In [37]:
# The third argument is `function_on_output`; with `cal_firing_rate` the monitor
# stores firing rates instead of the raw spike tensors.
fr_monitor = monitor.OutputMonitor(net, neuron.IFNode, cal_firing_rate)

`.disable()` pauses the monitor, and `.enable()` resumes it:

In [38]:
with torch.no_grad():
    # While the monitor is disabled, this forward pass adds nothing to `records`.
    fr_monitor.disable()
    net(x_seq)
    # Reset the network between runs so the next forward pass does not start from
    # the state left behind by this one.
    functional.reset_net(net)
    print(f'after call fr_monitor.disable(), fr_monitor.records=\n{fr_monitor.records}')

    # Once re-enabled, the same forward pass is recorded (one firing-rate tensor per IFNode).
    fr_monitor.enable()
    net(x_seq)
    print(f'after call fr_monitor.enable(), fr_monitor.records=\n{fr_monitor.records}')
    functional.reset_net(net)
    # Discard the monitor at the end of the demo. Because the name is deleted,
    # re-running this cell requires re-creating `fr_monitor` first.
    # NOTE(review): presumably deleting the monitor also releases its hooks on `net` — confirm.
    del fr_monitor

after call fr_monitor.disable(), fr_monitor.records=
[]
after call fr_monitor.enable(), fr_monitor.records=
[tensor([0.0000, 0.0000, 0.2500, 0.2500]), tensor([0., 0., 0., 0.])]


## Record attributes
To record the attributes of some modules, e.g., the membrane potential, we can use `spikingjelly.activation_based.monitor.AttributeMonitor`.

`store_v_seq: bool = False` is the default argument in `__init__` of spiking neurons, which means that only `v` at the last time-step is stored, and `v_seq` (the value at every time-step) is not stored. To record all $V[t]$, we set `store_v_seq = True`:

In [39]:
# Switch on `store_v_seq` for every IF neuron layer contained in the network.
for m in (mod for mod in net.modules() if isinstance(mod, neuron.IFNode)):
    m.store_v_seq = True

Then, we use `spikingjelly.activation_based.monitor.AttributeMonitor` to record:

In [40]:
# Record the `v_seq` attribute of every `neuron.IFNode` inside `net`.
# NOTE(review): pre_forward=False presumably means the attribute is read after the
# module's forward pass rather than before it — confirm against the monitor docs.
v_seq_monitor = monitor.AttributeMonitor('v_seq', pre_forward=False, net=net, instance=neuron.IFNode)
with torch.no_grad():
    net(x_seq)
    print(f'v_seq_monitor.records=\n{v_seq_monitor.records}')
    # Reset the network and discard the monitor so later cells start clean.
    functional.reset_net(net)
    del v_seq_monitor

v_seq_monitor.records=
[tensor([[[ 0.3305,  0.0134, -0.1621,  0.1869]],

        [[ 0.5239, -0.2477, -0.5901,  0.2290]],

        [[ 0.0000, -0.3150, -1.3329,  0.5524]],

        [[ 0.4087, -0.5824, -1.7561,  0.0000]]]), tensor([[[-0.0629,  0.3722]],

        [[-0.1257,  0.7443]],

        [[ 0.2858,  0.7879]],

        [[-0.0122,  0.8719]]])]


## Record inputs
To record inputs, we can use `spikingjelly.activation_based.monitor.InputMonitor`, which is similar to `spikingjelly.activation_based.monitor.OutputMonitor`:

In [44]:
# Record the inputs received by every `neuron.IFNode` inside `net`
# (constructed with the same arguments as the OutputMonitor used earlier).
input_monitor = monitor.InputMonitor(net, neuron.IFNode)
with torch.no_grad():
    net(x_seq)
    print(f'input_monitor.records=\n{input_monitor.records}')
    # Reset the network and discard the monitor so later cells start clean.
    functional.reset_net(net)
    del input_monitor

input_monitor.records=
[tensor([[[-0.0884,  0.4653, -0.3975, -0.1535,  0.6720, -0.0904,  0.5490,
           1.0915]],

        [[-0.3488,  0.6251, -1.0215, -0.2039,  0.7428, -0.1111,  0.5008,
           1.1083]],

        [[-0.0718,  0.4112, -0.7085, -0.2981,  0.3660, -0.2687,  0.3207,
           1.2908]],

        [[ 0.1667,  0.2247, -0.3627,  0.0313,  0.3476, -0.0524,  0.3441,
           0.8014]]]), tensor([[[ 0.4228,  0.0509, -0.1642, -0.3007, -0.4748,  0.5256, -0.1688,
           0.2881]],

        [[-0.1293, -0.5687,  0.3038, -0.1025, -0.3947,  0.3971, -0.1093,
           1.0233]],

        [[ 0.4228,  0.0509, -0.1642, -0.3007, -0.4748,  0.5256, -0.1688,
           0.2881]],

        [[ 0.1237,  0.2823, -0.0546, -0.1617, -0.1641,  0.3154,  0.0259,
           0.0240]]]), tensor([[[ 0.3416, -0.2761,  0.1017,  0.0788,  0.1890,  0.2876, -0.1870,
          -0.0421]],

        [[ 0.2122, -0.5372,  0.0986, -0.2290,  0.1446,  0.4215, -0.1969,
          -0.3066]],

        [[ 0.2234, -0.16