In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple

class SRC(nn.Module):
    """
    Spike-based Recurrent Cell (SRC) layer.

    Each neuron carries three states: a leaky input integrator ``i``, a fast
    hidden state ``h`` and a slow hidden state ``hs``. The emitted output at
    every timestep is ``ReLU(h)``.

    Defaults set to values:
      alpha_init = 0.99
      rho = 3.0
      r = 2.0
      rs = -7.0
      bh_init = -6.0 (clamped to <= bh_max = -4.0 during forward)
      z = 0.0
      zhyp_s = 0.9
      zdep_s = 0.0

    Input shapes:
      - seq-first: (seq_len, batch, input_size) (default)
      - if batch_first=True: (batch, seq_len, input_size)
    Output:
      - s_out_seq: (seq_len, batch, hidden_size) -- always seq-first; the
        output is NOT transposed back when batch_first=True
      - final states: (h_T, hs_T, i_T) each (batch, hidden_size)
    """
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        batch_first: bool = False,
        # Table 1 defaults:
        rho: float = 3.0,
        r: float = 2.0,
        rs: float = -7.0,
        z: float = 0.0,
        zhyp_s: float = 0.9,
        zdep_s: float = 0.0,
        bh_init: float = -6.0,
        bh_max: float = -4.0,
        alpha_init: float = 0.99,
        learnable_alpha: bool = False,
        learnable_bh: bool = True,
        device=None,
        dtype=None
    ):
        """
        Args:
            input_size: number of input channels.
            hidden_size: number of SRC neurons.
            batch_first: if True, ``forward`` expects (batch, seq_len, input_size).
            rho: saturation scale of the input current, x = rho * tanh(i / rho).
            r: feedback weight applied to the fast state h.
            rs: feedback weight applied to the slow state hs.
            z: mixing coefficient between the previous h and the candidate h.
            zhyp_s: slow-state mixing coefficient used while h < 0.5.
            zdep_s: slow-state mixing coefficient used while h >= 0.5.
            bh_init: initial value of the per-neuron bias.
            bh_max: upper bound applied to the bias on every forward pass.
            alpha_init: initial leak coefficient of the input integrator.
            learnable_alpha: if True, alpha is a per-neuron parameter stored
                as a logit and squashed through a sigmoid.
            learnable_bh: if True, the bias is a parameter, otherwise a buffer.
            device, dtype: forwarded to every tensor created here.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_first = batch_first

        # constants and buffers (per-neuron vectors for broadcasting)
        self.register_buffer("_r", torch.full((hidden_size,), float(r), **factory_kwargs))
        self.register_buffer("_rs", torch.full((hidden_size,), float(rs), **factory_kwargs))
        self.register_buffer("_z", torch.full((hidden_size,), float(z), **factory_kwargs))

        # zs constants (scalars)
        self._zhyp_s = float(zhyp_s)
        self._zdep_s = float(zdep_s)

        self.rho = float(rho)
        self.bh_max = float(bh_max)

        # input projection weights Ws: shape (hidden_size, input_size)
        self.Ws = nn.Parameter(torch.empty(hidden_size, input_size, **factory_kwargs))

        # bias bh per neuron (learnable optionally)
        if learnable_bh:
            self.bh = nn.Parameter(torch.full((hidden_size,), float(bh_init), **factory_kwargs))
        else:
            self.register_buffer("bh", torch.full((hidden_size,), float(bh_init), **factory_kwargs))

        # alpha: leaky integrator coefficient (learnable optional)
        self.learnable_alpha = bool(learnable_alpha)
        if self.learnable_alpha:
            # parameterize via logit so sigmoid(alpha_raw) in (0,1); the clamp
            # keeps the logit finite for alpha_init of exactly 0 or 1
            alpha_raw_init = torch.logit(torch.tensor(alpha_init, **factory_kwargs).clamp(1e-6, 1 - 1e-6))
            self.alpha_raw = nn.Parameter(torch.full((hidden_size,), float(alpha_raw_init), **factory_kwargs))
        else:
            self.register_buffer("_alpha_scalar", torch.tensor(float(alpha_init), **factory_kwargs))

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the input projection ``Ws``; other tensors keep their constructor values."""
        nn.init.kaiming_uniform_(self.Ws, a=5**0.5)

    @property
    def alpha(self) -> torch.Tensor:
        """Return alpha in (0,1) shape (hidden_size,)"""
        if self.learnable_alpha:
            return torch.sigmoid(self.alpha_raw)
        else:
            return self._alpha_scalar.expand(self.hidden_size)

    def _compute_zs(self, h: torch.Tensor) -> torch.Tensor:
        """
        zs[h] = zhyp_s + (zdep_s - zhyp_s) * H(h - 0.5)
        h shape: (batch, hidden_size)
        returns zs shape: (batch, hidden_size)
        """
        step = (h >= 0.5).to(h.dtype)
        zs_val = self._zhyp_s + (self._zdep_s - self._zhyp_s) * step
        return zs_val

    def _init_state(
        self,
        hx: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
        batch_size: int,
        device,
        dtype,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Unpack ``hx`` into (h, hs, i), replacing every missing entry with zeros."""
        h, hs, i = (None, None, None) if hx is None else hx
        if h is None:
            h = torch.zeros(batch_size, self.hidden_size, device=device, dtype=dtype)
        if hs is None:
            hs = torch.zeros_like(h)
        if i is None:
            i = torch.zeros_like(h)
        return h, hs, i

    def _run(
        self,
        s_in: torch.Tensor,
        hx: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Run the recurrence over a seq-first input (seq_len, batch, input_size).

        Shared by ``forward`` and ``step`` so that the batch_first transpose is
        applied in exactly one place.
        """
        seq_len, batch_size, in_size = s_in.shape
        assert in_size == self.input_size, f"input_size mismatch: got {in_size}, expected {self.input_size}"

        device = s_in.device
        dtype = s_in.dtype

        h, hs, i = self._init_state(hx, batch_size, device, dtype)

        # Loop-invariant per-neuron vectors, reshaped once to (1, hidden_size)
        # so they broadcast over the batch dimension.
        r = self._r.to(device=device, dtype=dtype).unsqueeze(0)
        rs = self._rs.to(device=device, dtype=dtype).unsqueeze(0)
        z = self._z.to(device=device, dtype=dtype).unsqueeze(0)
        alpha = self.alpha.to(device=device, dtype=dtype).unsqueeze(0)
        bh = torch.clamp(self.bh, max=self.bh_max).to(device=device, dtype=dtype).unsqueeze(0)

        outputs = []
        for s_in_t in s_in.unbind(0):  # each (batch, input_size)
            # i[t] = alpha * i[t-1] + Ws * s_in[t]
            i = alpha * i + F.linear(s_in_t, self.Ws)

            # x[t] = rho * tanh(i / rho)
            x = self.rho * torch.tanh(i / self.rho)

            # candidate hidden state
            h_cand = torch.tanh(x + h * r + hs * rs + bh)

            # h[t] = z * h_prev + (1-z) * h_cand (z default 0 -> h = h_cand)
            h_new = z * h + (1.0 - z) * h_cand

            # zs is evaluated on the freshly updated h ...
            zs = self._compute_zs(h_new)  # (batch, hidden_size)

            # ... while the slow state mixes in the PREVIOUS h, so hs must be
            # updated before h is overwritten.
            hs = zs * hs + (1.0 - zs) * h
            h = h_new

            # output spikes: ReLU(h[t])
            outputs.append(F.relu(h))

        if outputs:
            s_out_seq = torch.stack(outputs, dim=0)  # (seq_len, batch, hidden_size)
        else:
            # zero-length sequence: stacking an empty list would raise
            s_out_seq = torch.zeros((0, batch_size, self.hidden_size), device=device, dtype=dtype)
        return s_out_seq, (h, hs, i)

    def forward(
        self,
        s_in: torch.Tensor,
        hx: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        s_in: input pulses, shape (seq_len, batch, input_size) or (batch, seq_len, input_size) if batch_first
        hx: optional tuple (h0, hs0, i0); any entry may be None to start from zeros

        Returns (s_out_seq, (h_T, hs_T, i_T)); s_out_seq is seq-first
        (seq_len, batch, hidden_size) regardless of batch_first. A zero-length
        sequence yields an empty s_out_seq and the unchanged initial states.
        """
        if self.batch_first:
            s_in = s_in.transpose(0, 1)  # -> (seq, batch, in)
        return self._run(s_in, hx)

    def step(
        self,
        s_in_t: torch.Tensor,
        hx: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Single-step convenience wrapper. s_in_t shape: (batch, input_size)

        Returns (s_out_t, (h, hs, i)) with s_out_t of shape (batch, hidden_size).
        """
        # Go straight to the seq-first core: routing through forward() would
        # apply the batch_first transpose to the length-1 time axis and treat
        # the batch as the sequence.
        out_seq, states = self._run(s_in_t.unsqueeze(0), hx)
        return out_seq[0], states


In [4]:
# Smoke test: push a short, very sparse random spike train through one SRC
# layer and report the output shape and the number of strictly positive outputs.
batch, seq, inp, hid = 2, 20, 5, 16
cell = SRC(input_size=inp, hidden_size=hid, batch_first=False)
s_in = torch.rand(seq, batch, inp).gt(0.97).float()  # very sparse pulses
s_out, final_states = cell(s_in)
h_T, hs_T, i_T = final_states
print(s_out.shape)  # (seq, batch, hid)
print("spiking timesteps:", s_out.gt(0).float().sum().item())


torch.Size([20, 2, 16])
spiking timesteps: 0.0


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# ---- SRC cell (paste the previous SRC class here) ----
# Make sure the SRC class code is already defined in your environment.

# ---- 1. Generate synthetic spiking dataset ----
def generate_spike_data(seq_len=20, batch_size=64, input_size=5):
    """
    Build one random batch of sparse binary spike trains with binary labels.

    Every input entry is 1.0 where a uniform draw exceeds 0.9 and 0.0
    otherwise. A sample is labelled 1.0 when more than one of its first three
    input channels fires at the last timestep, else 0.0.

    Returns:
        X: float tensor of shape (seq_len, batch_size, input_size)
        y: float tensor of shape (batch_size,)
    """
    uniform_draws = torch.rand(seq_len, batch_size, input_size)
    X = uniform_draws.gt(0.9).float()

    # count spikes over the first three channels at the final timestep
    final_step_count = X[-1, :, :3].sum(dim=1)
    y = final_step_count.gt(1).float()
    return X, y

# ---- 2. Define a small network with SRC ----
class SRCNetwork(nn.Module):
    """One SRC layer followed by a linear readout of its last-timestep output."""

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.src = SRC(input_size=input_size, hidden_size=hidden_size)
        self.readout = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map x of shape (seq_len, batch, input_size) to readout values of shape (batch, output_size)."""
        # the SRC layer returns (outputs, final_states); only the outputs are used
        spike_seq = self.src(x)[0]  # (seq, batch, hidden)
        # the readout only sees the final timestep
        return self.readout(spike_seq[-1])

# ---- 3. Surrogate gradient helper ----
# We'll use a simple straight-through approximation for ReLU spikes
class SurrogateSpike(torch.autograd.Function):
    """
    Heaviside spike nonlinearity with a surrogate gradient.

    Forward emits 1.0 where the input is strictly positive and 0.0 elsewhere.
    Backward substitutes the (almost everywhere zero) true derivative with
    exp(-|input|), which peaks at the threshold and decays away from it.

    Use as ``SurrogateSpike.apply(x)``.
    """

    @staticmethod
    def forward(ctx, input):
        # backward needs the pre-threshold values to shape the surrogate gradient
        ctx.save_for_backward(input)
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (saved_input,) = ctx.saved_tensors
        # surrogate gradient: exponential decay with distance from the threshold
        return grad_output * torch.exp(-saved_input.abs())

# Replace F.relu with SurrogateSpike if you want gradients through spikes
# For simplicity, we'll continue with normal ReLU in this toy example

# ---- 4. Training loop ----
# Hyperparameters
input_size = 5
hidden_size = 16
output_size = 1  # binary classification
seq_len = 20
batch_size = 64
epochs = 30
lr = 1e-2

device = "cuda" if torch.cuda.is_available() else "cpu"

# model, loss, optimizer
model = SRCNetwork(input_size, hidden_size, output_size).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# training: every epoch is a single optimizer step on a freshly sampled batch
for epoch in range(epochs):
    X, y = generate_spike_data(seq_len, batch_size, input_size)
    X, y = X.to(device), y.to(device)

    optimizer.zero_grad()
    logits = model(X)  # (batch, 1)
    # Drop only the trailing output dim: a bare squeeze() would also collapse
    # the batch dim when batch_size == 1 and no longer match y's shape.
    flat_logits = logits.squeeze(-1)  # (batch,)
    loss = criterion(flat_logits, y)
    loss.backward()
    optimizer.step()

    # accuracy on the same batch, measured with the pre-update logits
    with torch.no_grad():
        pred = (torch.sigmoid(flat_logits) > 0.5).float()
        acc = (pred == y).float().mean().item()

    print(f"Epoch {epoch+1:02d}, Loss: {loss.item():.4f}, Acc: {acc:.4f}")

