In [12]:
import numpy as np

# NumPy 호환성 패치 (NumPy 1.20+ 대응)
if not hasattr(np, 'int'):
    np.int = int
if not hasattr(np, 'float'):
    np.float = float
if not hasattr(np, 'bool'):
    np.bool = bool
if not hasattr(np, 'complex'):
    np.complex = complex
if not hasattr(np, 'object'):
    np.object = object
    
import spikingjelly
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from spikingjelly.activation_based import surrogate, neuron
import spikingjelly.activation_based.neuron as neuron

try:
    from spikingjelly.cuda_kernel.auto_cuda import cfunction
except BaseException as e:
    import logging
    logging.info(f"spikingjelly.cuda_kernel import failed: {e}")
    cfunction = None

try:
    from spikingjelly.triton_kernel.torch2triton import compile_triton_code_str
except BaseException as e:
    import logging

    logging.info(f"spikingjelly.activation_based.surrogate: {e}")
    compile_triton_code_str = None

tab4_str = "\t\t\t\t"  # used for aligning code
curly_bracket_l = "{"
curly_bracket_r = "}"

In [2]:
def check_manual_grad(primitive_function, spiking_function, *args, **kwargs):
    """
    :param primitive_function: 梯度替代函数的原函数
    :type primitive_function: callable
    :param spiking_function: 梯度替代函数
    :type spiking_function: callable

    梯度替代函数的反向传播一般是手写的，可以用此函数去检查手写梯度是否正确。

    此函数检查梯度替代函数spiking_function的反向传播，与原函数primitive_function的反向传播结果是否一致。“一致”被定义为，两者的误差不超过eps。

    示例代码：

    .. code-block:: python

        def s2nn_apply(x, alpha, beta):
            return surrogate.s2nn.apply(x, alpha, beta)


        surrogate.check_manual_grad(
            surrogate.S2NN.primitive_function, s2nn_apply, alpha=4.0, beta=1.0
        )
    """
    x = torch.arange(-2, 2, 32 / 8192)
    # x = torch.as_tensor([-1., 0., 1.])
    x.requires_grad_(True)
    primitive_function(x, *args, **kwargs).sum().backward()
    x_grad_auto = x.grad.clone()
    x.grad.zero_()
    spiking_function(x, *args, **kwargs).sum().backward()
    x_grad_manual = x.grad.clone()
    print("auto   grad", x_grad_auto)
    print("manual grad", x_grad_manual)
    abs_error = (x_grad_manual - x_grad_auto).abs()
    idx = abs_error.argmax()
    print("max error", abs_error[idx], "occurs at")
    print(f"x[{idx}] = {x[idx]}")
    print("auto   grad", x_grad_auto[idx])
    print("manual grad", x_grad_manual[idx])

In [3]:
def check_cuda_grad(neu, surrogate_function, device, *args, **kwargs):
    # check_cuda_grad(neuron.IFNode, surrogate.S2NN, device='cuda:1', alpha=4., beta=1.)
    for dtype in [torch.float, torch.half]:
        print(dtype)
        net = neu(surrogate_function=surrogate_function(*args, **kwargs), step_mode="m")
        net.to(device)
        x = torch.arange(-2, 2, 32 / 8192, device=device, dtype=dtype)
        x.requires_grad_(True)
        net.backend = "torch"
        net(x.unsqueeze(0)).sum().backward()
        x_grad_py = x.grad.clone()
        x.grad.zero_()
        net.reset()
        net.backend = "cupy"
        net(x.unsqueeze(0)).sum().backward()

        x_grad_cp = x.grad.clone()
        # print('python grad', x_grad_py)
        # print('cupy   grad', x_grad_cp)
        abs_error = (x_grad_cp - x_grad_py).abs()
        idx = abs_error.argmax()
        print("max error", abs_error[idx], "occurs at")
        print(f"x[{idx}] = {x[idx]}")
        print("python grad", x_grad_py[idx])
        print("cupy   grad", x_grad_cp[idx])

In [4]:
check_cuda_grad(neuron.IFNode, surrogate.S2NN, device='cuda:1', alpha=4., beta=1.)

torch.float32
max error tensor(nan, device='cuda:1') occurs at
x[512] = 0.0
python grad tensor(0.0707, device='cuda:1')
cupy   grad tensor(nan, device='cuda:1')
torch.float16
max error tensor(nan, device='cuda:1', dtype=torch.float16) occurs at
x[512] = 0.0
python grad tensor(0.0707, device='cuda:1', dtype=torch.float16)
cupy   grad tensor(nan, device='cuda:1', dtype=torch.float16)


In [5]:
# step function : x가 0보다 크면 spike
def heaviside(x: torch.Tensor):
    return (x>=0).to(x)

In [13]:
def plot_surrogate_function(surrogate_function):
    import matplotlib.pyplot as plt
    import scienceplots

    plt.style.use(["science", "muted", "grid"])
    fig = plt.figure(dpi=200)
    x = torch.arange(-2.5, 2.5, 0.001)
    plt.plot(x.data, heaviside(x), label="Heaviside", linestyle="-.")

    surrogate_function.set_spiking_mode(False)
    y = surrogate_function(x)
    plt.plot(x.data, y.data, label="Primitive")

    surrogate_function.set_spiking_mode(True)
    x.requires_grad_(True)
    y = surrogate_function(x)
    z = y.sum()
    z.backward()
    plt.plot(x.data, x.grad, label="Gradient")

    plt.xlim(-2, 2)
    plt.legend()
    plt.title(f"{surrogate_function.__class__.__name__} surrogate function")
    plt.xlabel("Input")
    plt.ylabel("Output")
    plt.grid(linestyle="--")
    plt.savefig(f"./{surrogate_function.__class__.__name__}.pdf", bbox_inches="tight")
    plt.show()

In [14]:
my_surrogate = surrogate.Sigmoid(alpha=4.0)
plot_surrogate_function(my_surrogate)

RuntimeError: Failed to process string with tex because latex could not be found

Error in callback <function _draw_all_if_interactive at 0x7f91c0fecfe0> (for post_execute), with arguments args (),kwargs {}:


RuntimeError: Failed to process string with tex because latex could not be found

RuntimeError: Failed to process string with tex because latex could not be found

<Figure size 700x525 with 1 Axes>

In [9]:
import matplotlib.pyplot as plt

In [10]:
surrogates = {
    'Sigmoid (alpha=4)': surrogate.Sigmoid(alpha=4.0),
    'ATan (alpha=2)': surrogate.ATan(alpha=2.0),
    'Piecewise Leaky ReLU': surrogate.PiecewiseLeakyReLU(w=1.0)
}

# 2. x값 준비 (전압이라고 생각하세요)
x = torch.arange(-2, 2, 0.01)
x.requires_grad = True # 미분 추적 켜기

plt.figure(figsize=(10, 5))

for name, func in surrogates.items():
    if x.grad is not None:
        x.grad.zero_() # 이전 기울기 초기화
    
    # 순전파: 사실 이건 다 똑같은 계단 함수(0 or 1)입니다.
    y = func(x)
    
    # 역전파: 여기서 각자 다른 '가짜 기울기'가 계산됩니다.
    y.sum().backward()
    
    # 그리기
    plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label=name)

plt.title("Surrogate Gradients Shapes")
plt.xlabel("Membrane Potential (x)")
plt.ylabel("Gradient (dx)")
plt.legend()
plt.grid()
plt.show()

RuntimeError: Failed to process string with tex because latex could not be found

<Figure size 1000x500 with 1 Axes>