<a href="https://colab.research.google.com/github/JoshuaNalla/SNN-with-IFA/blob/main/DFA_SNN_version3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# ============================================================================
# CELL 1: IMPORTS AND POISSON ENCODING
# ============================================================================

# Import core libraries
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# for evaluation purposes, importing scikit.metric
from sklearn.metrics import confusion_matrix

def poisson_encoding(x, time_steps, max_rate=0.8):
    """
    Turn static intensities into Bernoulli ("Poisson") spike trains.

    At every time step each feature fires independently with probability
    ``x * max_rate``; one fresh ``torch.rand_like(x)`` draw is made per step.

    Args:
        x: [batch, features] static input (pixel intensities, expected in [0, 1])
        time_steps: number of simulation steps to generate
        max_rate: firing probability per step for an intensity of 1.0

    Returns:
        [batch, time_steps, features] float tensor of 0/1 spikes on x's device
    """
    n_samples, n_features = x.shape[0], x.shape[1]

    # The per-feature firing probability does not change over time.
    firing_prob = x * max_rate

    trains = torch.zeros(n_samples, time_steps, n_features, device=x.device)
    for step in range(time_steps):
        trains[:, step, :] = (torch.rand_like(x) < firing_prob).float()

    return trains

In [2]:
# ============================================================================
# CELL 2: UPDATED SURROGATE GRADIENTS
# ============================================================================

def surrogate_gradient_exact(a, h_th=0.1, t_ref=1.0, tau=20.0):
    """
    Exact surrogate gradient from paper (Appendix D, Equation A.9), PyTorch version.

    Element-wise it evaluates

        h_th * t_ref * tau / [a * (a - h_th)]
        ---------------------------------------------
        [t_ref + tau * log(a / (a - h_th))] ** 2

    where a > h_th, and returns 0 everywhere else. Small 1e-8 terms guard the
    divisions and the log. Entries with a <= h_th can be NaN/inf before
    masking (e.g. log of a negative ratio), but they are zeroed in the result.

    Args:
        a: tensor of input currents (any shape)
        h_th: firing threshold
        t_ref: refractory period
        tau: membrane time constant

    Returns:
        tensor with the same shape as ``a``
    """
    tiny = 1e-8
    above = a - h_th

    log_ratio = torch.log(a / (above + tiny) + tiny)
    rate_denominator = (t_ref + tau * log_ratio) ** 2 + tiny
    scale = h_th * t_ref * tau / (a * above + tiny)

    grad = scale / rate_denominator
    # Sub-threshold (or NaN) inputs contribute no gradient.
    return grad.masked_fill(~(a > h_th), 0.0)


def surrogate_gradient_fast_sigmoid(a, threshold=0.1, alpha=10):
    """
    Fast-sigmoid surrogate derivative (PyTorch version):

        1 / (1 + |alpha * (a - threshold)|) ** 2

    It equals 1 at ``a == threshold`` and decays quadratically away from it.
    ``alpha`` sets the steepness (raised from 5 to 10 for sharper gradients).
    """
    distance = torch.abs(alpha * (a - threshold))
    return 1.0 / (1.0 + distance) ** 2



In [3]:
# ============================================================================
# CELL 3: DFA LIF LAYER (unchanged but included for completeness)
# ============================================================================

class DFA_LIFLayer(nn.Module):
    """
    Leaky integrate-and-fire (LIF) layer that emits a spike train over time.

    Each step integrates the input current with an explicit Euler update
    ``v <- (1 - dt/tau) * v + (dt/tau) * (x_t @ w + b)``. A neuron fires when
    ``v >= threshold`` unless it is refractory. Firing resets ``v`` to 0 and
    blocks further spikes for the next ``ref_steps - 1`` steps. The membrane
    keeps integrating while a neuron is refractory.

    When ``use_dfa`` is set, a fixed random feedback matrix ``B`` of shape
    [units, output_size] is drawn from N(0, gamma^2). It is registered as a
    buffer, so it is not a trainable parameter.
    """
    def __init__(self, in_features, units, output_size,
                 tau=20.0, dt=0.25, threshold=0.1, t_ref=1.0,
                 use_dfa=True, gamma=0.0338):
        super().__init__()
        # Layer geometry
        self.in_features, self.units, self.output_size = in_features, units, output_size

        # Neuron dynamics and DFA configuration
        self.tau, self.dt, self.threshold, self.t_ref = tau, dt, threshold, t_ref
        self.use_dfa, self.gamma = use_dfa, gamma

        # Trainable weights (Xavier-uniform) and a constant 0.2 bias
        weight = torch.empty(in_features, units)
        nn.init.xavier_uniform_(weight)
        self.w = nn.Parameter(weight)
        self.b = nn.Parameter(torch.full((units,), 0.2))

        # Refractory period expressed in simulation steps
        self.ref_steps = int(round(self.t_ref / self.dt))

        if use_dfa:
            self.register_buffer("B", torch.randn(units, output_size) * self.gamma)

    def forward(self, inputs):
        """
        inputs: [batch, time, in_features]
        returns spikes: [batch, time, units]
        """
        batch_size, _num_steps, _ = inputs.shape

        decay = 1.0 - (self.dt / self.tau)
        gain = self.dt / self.tau

        membrane = torch.zeros(batch_size, self.units, device=inputs.device)
        refractory = torch.zeros(batch_size, self.units, device=inputs.device)

        frames = []
        for x_t in inputs.unbind(dim=1):
            current = (x_t @ self.w) + self.b
            membrane = decay * membrane + gain * current

            # Fire on threshold crossing, except for neurons still refractory
            fired_mask = (membrane >= self.threshold) & ~(refractory > 0)
            fired = fired_mask.float()

            # Reset the potential of neurons that fired
            membrane = membrane * (1.0 - fired)

            # Arm the refractory counter on fire, then count down one step
            refractory = (refractory.masked_fill(fired_mask, float(self.ref_steps)) - 1.0).clamp(min=0.0)

            frames.append(fired)

        return torch.stack(frames, dim=1)

In [4]:
# ============================================================================
# CELL 4: UPDATED DFA-SNN MODEL WITH BETTER THRESHOLDS
# ============================================================================

class DFASNN(nn.Module):
    """
    Feed-forward spiking network: three DFA-trained LIF hidden layers followed
    by a LIF output layer.

    Hidden layers are built with ``use_dfa=True``, so each owns a fixed random
    feedback matrix ``B``. The output layer is built with ``use_dfa=False``
    (no ``B``).

    Args:
        time_steps: number of simulation steps. Stored on the module for
            reference only; ``forward`` reads the step count from its input.
        input_size, hidden_size1, hidden_size2, hidden_size3, output_size:
            layer widths.
        thresholds: firing thresholds for (hidden1, hidden2, hidden3, output).
            The defaults decrease with depth for the hidden layers and are
            highest for the output layer.
        tau, dt, t_ref: LIF time constant, step size and refractory period,
            shared by all layers.
        gamma: scale of the hidden layers' random feedback matrices.
    """
    def __init__(self, time_steps, input_size, hidden_size1, hidden_size2, hidden_size3, output_size,
                 thresholds=(0.05, 0.02, 0.01, 0.15), tau=20.0, dt=0.25, t_ref=1.0, gamma=0.0338):
        super().__init__()
        self.time_steps = time_steps

        # Unpacking also validates that exactly four thresholds were given.
        th_h1, th_h2, th_h3, th_out = thresholds
        lif_kwargs = dict(output_size=output_size, tau=tau, dt=dt, t_ref=t_ref)

        # Layers are constructed in the same order as before so weight/B
        # initialisation consumes the RNG identically.
        self.hidden_layer1 = DFA_LIFLayer(
            in_features=input_size, units=hidden_size1,
            threshold=th_h1, use_dfa=True, gamma=gamma, **lif_kwargs
        )
        self.hidden_layer2 = DFA_LIFLayer(
            in_features=hidden_size1, units=hidden_size2,
            threshold=th_h2, use_dfa=True, gamma=gamma, **lif_kwargs
        )
        self.hidden_layer3 = DFA_LIFLayer(
            in_features=hidden_size2, units=hidden_size3,
            threshold=th_h3, use_dfa=True, gamma=gamma, **lif_kwargs
        )

        # Output layer: no feedback matrix (the DFA trainer updates it with
        # the output error directly).
        self.output_layer = DFA_LIFLayer(
            in_features=hidden_size3, units=output_size,
            threshold=th_out, use_dfa=False, **lif_kwargs
        )

    def forward(self, x_time):
        """
        x_time: [batch, time, input_size]
        returns:
          h1_spikes: [batch, time, hidden_size1]
          h2_spikes: [batch, time, hidden_size2]
          h3_spikes: [batch, time, hidden_size3]
          y_spikes:  [batch, time, output_size]
        """
        h1_spikes = self.hidden_layer1(x_time)
        h2_spikes = self.hidden_layer2(h1_spikes)
        h3_spikes = self.hidden_layer3(h2_spikes)
        y_spikes = self.output_layer(h3_spikes)
        return h1_spikes, h2_spikes, h3_spikes, y_spikes



In [5]:
# ============================================================================
# CELL 5: UPDATED DFA TRAINER WITH LEARNING RATE DECAY
# ============================================================================

class DFATrainerTorch:
    """
    Manual Direct Feedback Alignment (DFA) trainer for a DFASNN model.

    No autograd is used. After a forward pass, the output error
    ``e = softmax(normalized_rates / 0.5) - y`` is broadcast over time.

    - The output layer is updated with ``e`` directly.
    - Each hidden layer instead receives ``e`` projected through its fixed
      random feedback matrix ``B``.

    In every layer the error is gated by a surrogate derivative evaluated at
    the layer's input current ``a = input @ w + b``, and a plain SGD step is
    applied to ``w`` and ``b``.

    Args:
        model: module with ``hidden_layer1..3`` (each holding a ``B`` buffer)
            and ``output_layer``; its forward returns
            ``(h1_spikes, h2_spikes, h3_spikes, y_spikes)``.
        learning_rate: initial SGD step size.
        use_exact_gradient: use ``surrogate_gradient_exact`` instead of
            ``surrogate_gradient_fast_sigmoid`` (alpha=10).
        lr_decay: multiplicative factor applied by ``step_lr``.
        decay_every: number of ``step_lr`` calls between decays.
    """
    def __init__(self, model, learning_rate=0.1, use_exact_gradient=False,
                 lr_decay=0.95, decay_every=10):
        self.model = model
        self.lr = learning_rate
        self.initial_lr = learning_rate
        self.lr_decay = lr_decay
        self.decay_every = decay_every
        self.use_exact_gradient = use_exact_gradient
        self.epoch_counter = 0

    @staticmethod
    def compute_spike_rate(spikes):
        """Sum spikes over the time axis: [batch, time, units] -> [batch, units]."""
        return spikes.sum(dim=1)

    def step_lr(self):
        """Decay learning rate periodically"""
        self.epoch_counter += 1
        if self.epoch_counter % self.decay_every == 0:
            self.lr *= self.lr_decay
            print(f"  → Learning rate decayed to {self.lr:.6f}")

    def _surrogate(self, a, layer):
        """Surrogate derivative of `layer`'s spiking nonlinearity at input current `a`."""
        if self.use_exact_gradient:
            return surrogate_gradient_exact(
                a, h_th=layer.threshold, t_ref=layer.t_ref, tau=layer.tau
            )
        return surrogate_gradient_fast_sigmoid(a, threshold=layer.threshold, alpha=10.0)

    def _dfa_layer_update(self, layer, layer_input, error, denom):
        """Gate `error` [b,t,units] by the surrogate derivative and take one SGD step on layer.w / layer.b."""
        a = (layer_input @ layer.w) + layer.b
        delta = error * self._surrogate(a, layer)

        # Average over batch and time (denom = batch * time).
        grad_w = torch.einsum("bti,btj->ij", layer_input, delta) / denom
        grad_b = delta.sum(dim=(0, 1)) / denom

        layer.w -= self.lr * grad_w
        layer.b -= self.lr * grad_b

    def train_step(self, x_time, y_onehot):
        """
        Run one forward pass and one DFA weight update on a mini-batch.

        Args:
            x_time:   [b, t, input_size] input spike trains
            y_onehot: [b, num_classes] one-hot targets

        Returns:
            (loss_float, acc_float), both computed from the forward pass that
            precedes the weight update.
        """
        self.model.train()

        # Follow the model's device rather than a notebook-global `device`
        # (which is only defined in a later cell).
        target_device = next(self.model.parameters()).device
        x_time = x_time.to(target_device)
        y_onehot = y_onehot.to(target_device)

        # All updates are manual. Without no_grad the LIF forward would record
        # an autograd graph through every membrane-potential step for nothing.
        with torch.no_grad():
            h1_spikes, h2_spikes, h3_spikes, y_spikes = self.model(x_time)
            out_rates = self.compute_spike_rate(y_spikes)

            # Normalize spike counts per sample, then apply temperature 0.5.
            out_rates_norm = out_rates / (out_rates.sum(dim=1, keepdim=True) + 1e-9)
            out_probs = F.softmax(out_rates_norm / 0.5, dim=1)

            eps = 1e-9
            loss = -(y_onehot * torch.log(out_probs + eps)).sum(dim=1).mean()

            # Global error, broadcast to every time step.
            bsz, T, _ = y_spikes.shape
            denom = float(bsz * T)
            e_time = (out_probs - y_onehot).unsqueeze(1).repeat(1, T, 1)

            model = self.model

            # 1) Output layer: uses the real output error.
            self._dfa_layer_update(model.output_layer, h3_spikes, e_time, denom)

            # 2-4) Hidden layers: the error reaches each layer through its own
            # fixed random matrix B ([units, output_size]).
            for layer, layer_input in (
                (model.hidden_layer3, h2_spikes),
                (model.hidden_layer2, h1_spikes),
                (model.hidden_layer1, x_time),
            ):
                e_proj = torch.einsum("bto,ho->bth", e_time, layer.B)
                self._dfa_layer_update(layer, layer_input, e_proj, denom)

            preds = torch.argmax(out_probs, dim=1)
            true = torch.argmax(y_onehot, dim=1)
            acc = (preds == true).float().mean().item()

        return loss.item(), acc



In [None]:
# ============================================================================
# CELL 6: TRAINING SETUP WITH ALL FIXES
# ============================================================================

import numpy as np
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# ----------------------------
# Reproducibility: seed the RNGs used below (Poisson encoding, weight init,
# DFA feedback matrices, train/val split) and in the training loop (shuffling).
# ----------------------------
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# ----------------------------
# Load MNIST data (PyTorch)
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))
])

train_ds_full = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_ds_full  = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

# Convert to tensors straight from the stored uint8 images. The values match
# ToTensor() + flatten exactly (uint8 -> float32 / 255), but this avoids 70k
# per-sample PIL conversions through __getitem__.
x_train = train_ds_full.data.reshape(len(train_ds_full), -1).float().div(255.0)
y_train_int = train_ds_full.targets.clone().long()

x_test  = test_ds_full.data.reshape(len(test_ds_full), -1).float().div(255.0)
y_test_int  = test_ds_full.targets.clone().long()

# One-hot encode labels
y_train = F.one_hot(y_train_int, num_classes=10).float()
y_test  = F.one_hot(y_test_int,  num_classes=10).float()

print(f"Training samples: {len(x_train)}")
print(f"Test samples: {len(x_test)}")

# ----------------------------
# Hyperparams
# ----------------------------
TIME_STEPS = 25
HIDDEN_SIZE1 = 1000
HIDDEN_SIZE2 = 250
HIDDEN_SIZE3 = 100
OUTPUT_SIZE = 10
INPUT_SIZE = 784

TRAIN_SAMPLES = 60000
VAL_FRAC = 0.10

EPOCHS = 50
BATCH_SIZE = 128
EVAL_BATCH = 256

assert x_train.shape[1] == INPUT_SIZE and x_test.shape[1] == INPUT_SIZE
assert 0 < TRAIN_SAMPLES <= x_train.shape[0]

# ----------------------------
# CRITICAL FIX: Use Poisson encoding instead of static repetition
# ----------------------------
print("\n" + "="*60)
print("ENCODING INPUTS WITH POISSON PROCESS")
print("="*60)

x_train_small = x_train[:TRAIN_SAMPLES]
y_train_small = y_train[:TRAIN_SAMPLES]

# Poisson encoding with max_rate=0.8.
# NOTE: at 60000 x 25 x 784 float32 this tensor is ~4.7 GB, and the
# train/val indexing below makes a further copy of it.
print("Encoding training data...")
x_train_spikes = poisson_encoding(x_train_small, TIME_STEPS, max_rate=0.8)

print(f"Input shape with time: {tuple(x_train_spikes.shape)}")
print(f"  [batch={x_train_spikes.shape[0]}, time={x_train_spikes.shape[1]}, features={x_train_spikes.shape[2]}]")

# ----------------------------
# Build DFA-SNN model
# ----------------------------
print("\n" + "="*60)
print("BUILDING DFA-SNN MODEL")
print("="*60)

model = DFASNN(
    time_steps=TIME_STEPS,
    input_size=INPUT_SIZE,
    hidden_size1=HIDDEN_SIZE1,
    hidden_size2=HIDDEN_SIZE2,
    hidden_size3=HIDDEN_SIZE3,
    output_size=OUTPUT_SIZE
).to(device)

# Bias initialization: the same 0.5 for all hidden layers, 1.0 for the output layer
with torch.no_grad():
    model.hidden_layer1.b.fill_(0.5)
    model.hidden_layer2.b.fill_(0.5)
    model.hidden_layer3.b.fill_(0.5)
    model.output_layer.b.fill_(1.0)

print(model)

# ----------------------------
# Create DFA trainer with LR decay
# ----------------------------
print("\n" + "="*60)
print("INITIALIZING DFA TRAINER")
print("="*60)

trainer = DFATrainerTorch(
    model=model,
    learning_rate=0.1,
    use_exact_gradient=False,  # Use fast sigmoid
    lr_decay=0.95,
    decay_every=10
)

print(f"\nInitial learning rate: {trainer.lr:.6f}")
print(f"LR decay factor: {trainer.lr_decay}")
print(f"Decay every: {trainer.decay_every} epochs")
print(f"Using fast sigmoid surrogate (alpha=10)")

# ----------------------------
# Train/Val split
# ----------------------------
N = x_train_spikes.shape[0]
perm = torch.randperm(N)

val_n = int(N * VAL_FRAC)
val_idx = perm[:val_n]
train_idx = perm[val_n:]

x_tr = x_train_spikes[train_idx]
y_tr = y_train_small[train_idx]
x_val = x_train_spikes[val_idx]
y_val = y_train_small[val_idx]

print(f"\nTrain samples: {x_tr.shape[0]}")
print(f"Val samples:   {x_val.shape[0]}")

# ----------------------------
# Prepare test set with Poisson encoding
# ----------------------------
print("Encoding test data...")
x_test_spikes = poisson_encoding(x_test, TIME_STEPS, max_rate=0.8)
print(f"Test spikes shape: {tuple(x_test_spikes.shape)}")

# ----------------------------
# Evaluation helper
# ----------------------------
@torch.no_grad()
def evaluate(model, x_time, y_onehot, batch_size=256):
    """
    Evaluate a spiking model on pre-encoded spike trains.

    The loss and prediction rule match DFATrainerTorch.train_step: spike counts
    are normalized per sample, passed through softmax with temperature 0.5, and
    scored with cross-entropy.

    Args:
        model: module whose forward returns a tuple whose last element is the
            output spike tensor [batch, time, classes] (DFASNN returns
            (h1, h2, h3, y)).
        x_time: [N, time, features] input spike trains.
        y_onehot: [N, classes] one-hot targets.
        batch_size: evaluation mini-batch size.

    Returns:
        (mean_loss, accuracy) over all N samples. Batches are weighted by their
        size, so a short final batch is not over-counted.

    Raises:
        ValueError: if x_time contains no samples.
    """
    n = x_time.shape[0]
    if n == 0:
        raise ValueError("evaluate() needs at least one sample")

    model.eval()
    # Follow the model's device instead of a notebook-global `device`.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else x_time.device

    eps = 1e-9
    total_loss = 0.0
    total_correct = 0.0

    for i in range(0, n, batch_size):
        xb = x_time[i:i+batch_size].to(target_device)
        yb = y_onehot[i:i+batch_size].to(target_device)

        y_spikes = model(xb)[-1]

        out_rates = y_spikes.sum(dim=1)  # spike count per output unit
        out_rates_norm = out_rates / (out_rates.sum(dim=1, keepdim=True) + eps)
        out_probs = F.softmax(out_rates_norm / 0.5, dim=1)

        per_sample_loss = -(yb * torch.log(out_probs + eps)).sum(dim=1)
        total_loss += per_sample_loss.sum().item()

        preds = torch.argmax(out_probs, dim=1)
        true  = torch.argmax(yb, dim=1)
        total_correct += (preds == true).sum().item()

    return total_loss / n, total_correct / n

# ----------------------------
# Training loop with val/test each epoch
# ----------------------------
print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)

history = {
    "train_loss": [], "train_acc": [],
    "val_loss": [], "val_acc": [],
    "test_loss": [], "test_acc": []
}

best_test_acc = 0.0
# Picking the "best" epoch by test accuracy leaks the test set into model
# selection, so also track the epoch chosen by validation accuracy and the
# test accuracy reached at that epoch (the unbiased number to report).
best_val_acc = float("-inf")
best_val_epoch = 0
test_acc_at_best_val = 0.0

for ep in range(EPOCHS):
    # --- train one epoch ---
    model.train()
    perm = torch.randperm(x_tr.shape[0])

    ep_loss = 0.0
    ep_acc = 0.0
    n_seen = 0

    for i in range(0, x_tr.shape[0], BATCH_SIZE):
        idx = perm[i:i+BATCH_SIZE]
        loss, acc = trainer.train_step(x_tr[idx], y_tr[idx])
        # Weight by batch size so the short final batch is not over-counted.
        bs = idx.numel()
        ep_loss += loss * bs
        ep_acc += acc * bs
        n_seen += bs

    train_loss = ep_loss / n_seen
    train_acc  = ep_acc / n_seen

    # --- eval ---
    val_loss, val_acc = evaluate(model, x_val, y_val, batch_size=EVAL_BATCH)
    test_loss, test_acc = evaluate(model, x_test_spikes, y_test, batch_size=EVAL_BATCH)

    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)
    history["test_loss"].append(test_loss)
    history["test_acc"].append(test_acc)

    # Track best
    if test_acc > best_test_acc:
        best_test_acc = test_acc
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_val_epoch = ep + 1
        test_acc_at_best_val = test_acc

    print(f"Epoch {ep+1:3d}/{EPOCHS} | "
          f"train loss {train_loss:.4f} acc {train_acc:.4f} | "
          f"val loss {val_loss:.4f} acc {val_acc:.4f} | "
          f"test loss {test_loss:.4f} acc {test_acc:.4f}")

    # Decay learning rate
    trainer.step_lr()

print(f"\n{'='*60}")
print(f"TRAINING COMPLETE - Best test accuracy: {best_test_acc:.4f} ({best_test_acc*100:.2f}%)")
print(f"Test accuracy at best-validation epoch {best_val_epoch}: "
      f"{test_acc_at_best_val:.4f} (val acc {best_val_acc:.4f})")
print(f"{'='*60}")

# ----------------------------
# Plots: Accuracy + Loss
# ----------------------------
# Use the recorded history length so the axis matches even if training stops early.
epochs = np.arange(1, len(history["test_acc"]) + 1)

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(14, 5))

for key, label, alpha in (("train_acc", "train acc", 0.7),
                          ("val_acc", "val acc", 0.7),
                          ("test_acc", "test acc", 0.9)):
    ax_acc.plot(epochs, history[key], linewidth=2, label=label, alpha=alpha)
ax_acc.axhline(y=best_test_acc, color='r', linestyle='--', alpha=0.5, label=f'best test: {best_test_acc:.4f}')
ax_acc.set_xlabel("Epoch", fontsize=12)
ax_acc.set_ylabel("Accuracy", fontsize=12)
ax_acc.set_title("Accuracy vs Epochs (DFA-SNN with Poisson Encoding)", fontsize=13)
ax_acc.grid(True, alpha=0.3)
ax_acc.set_ylim(0, 1.0)
ax_acc.legend()

for key, label, alpha in (("train_loss", "train loss", 0.7),
                          ("val_loss", "val loss", 0.7),
                          ("test_loss", "test loss", 0.9)):
    ax_loss.plot(epochs, history[key], linewidth=2, label=label, alpha=alpha)
ax_loss.set_xlabel("Epoch", fontsize=12)
ax_loss.set_ylabel("Loss", fontsize=12)
ax_loss.set_title("Loss vs Epochs (DFA-SNN with Poisson Encoding)", fontsize=13)
ax_loss.grid(True, alpha=0.3)
ax_loss.legend()

fig.tight_layout()
plt.show()


Using device: cuda


100%|██████████| 9.91M/9.91M [00:00<00:00, 45.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.13MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 10.5MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.99MB/s]


Training samples: 60000
Test samples: 10000

ENCODING INPUTS WITH POISSON PROCESS
Encoding training data...
Input shape with time: (60000, 25, 784)
  [batch=60000, time=25, features=784]

BUILDING DFA-SNN MODEL
DFASNN(
  (hidden_layer1): DFA_LIFLayer()
  (hidden_layer2): DFA_LIFLayer()
  (hidden_layer3): DFA_LIFLayer()
  (output_layer): DFA_LIFLayer()
)

INITIALIZING DFA TRAINER

Initial learning rate: 0.100000
LR decay factor: 0.95
Decay every: 10 epochs
Using fast sigmoid surrogate (alpha=10)

Train samples: 54000
Val samples:   6000
Encoding test data...
Test spikes shape: (10000, 25, 784)

STARTING TRAINING
Epoch   1/50 | train loss 2.3024 acc 0.0992 | val loss 2.3034 acc 0.0938 | test loss 2.3030 acc 0.0972
Epoch   2/50 | train loss 2.3025 acc 0.1049 | val loss 2.3019 acc 0.1168 | test loss 2.3019 acc 0.1138
Epoch   3/50 | train loss 2.3034 acc 0.1103 | val loss 2.3024 acc 0.1168 | test loss 2.3024 acc 0.1138
Epoch   4/50 | train loss 2.3034 acc 0.1106 | val loss 2.3024 acc 0.1168