In [1]:
# @title
# Import core libraries
import tensorflow as tf  # Main deep learning framework
from tensorflow import keras  # High-level neural network API
import numpy as np  # Numerical computing library
import matplotlib.pyplot as plt  # Plotting library for visualization

# for evaluation purposes, importing scikit.metric
from sklearn.metrics import confusion_matrix

In [2]:
# new cell 2 for LIF neuron
class LIFNeuron:
    """
    Leaky Integrate-and-Fire neuron implementation.

    This class demonstrates the LIF dynamics before incorporating into layers.
    Paper reference: Appendix A, Equations A.1-A.4
    """

    def __init__(self, tau=20.0, dt=0.25, threshold=0.4, t_ref=1.0):
        """
        Args:
            tau: Membrane time constant (ms) - controls leak rate
            dt: Time step (ms) - discretization step
            threshold: Spike threshold (h_th)
            t_ref: Refractory period (ms) - time after spike when neuron can't fire

        Paper settings (Section B):
            - tau = 20ms
            - dt = 0.25ms
            - threshold = 0.4
            - t_ref = 1ms
        """
        self.tau = tau # the leak rate
        self.dt = dt # time step
        self.threshold = threshold # voltage threshold
        self.t_ref = t_ref # time after the neuron cant fire

        # Calculate leak factor (1 - dt/tau)
        self.alpha = 1.0 - (dt / tau)
        # Calculate input scaling (dt/tau)
        self.beta = dt / tau

        # State variables
        self.v = 0.0  # Membrane potential
        self.ref_count = 0  # Refractory counter (in time steps)

    def step(self, input_current):
        """
        Single time step of LIF dynamics.

        Args:
            input_current: Weighted sum of inputs at this time step

        Returns:
            spike: 1 if neuron fired, 0 otherwise
        """
        spike = 0

        # Check if in refractory period
        if self.ref_count > 0:
            self.ref_count -= 1
            self.v = 0.0  # Keep voltage at 0 during refractory
            return spike

        # Update membrane potential (Equation A.2)
        self.v = self.alpha * self.v + self.beta * input_current

        # Check for spike (Equation A.3)
        if self.v >= self.threshold:
            spike = 1
            self.v = 0.0  # Reset (Equation A.4)
            # Enter refractory period
            self.ref_count = int(self.t_ref / self.dt)
            # we want to wait 1 ms, so we divide the amount of time by the time steps
            # in this particular instance of 1 ms, we would have to wait 4 time steps
            # before this neuron can fire again

        return spike

In [None]:
#initializing fixed random feedback matrices

In [3]:
# @title
def compute_weight_stats(num_neurons, v_mean=8.0, v_second_moment=164.0, alpha=0.066):
    """
    Compute weight initialization statistics.
    Paper reference: Appendix C, Equations A.7 and A.8

    Args:
        num_neurons: Number of neurons in layer (N)
        v_mean: Mean input value (v̄)
        v_second_moment: Second moment of input (v̄̄)
        alpha: Constant (0.066)

    Returns:
        w_mean: Mean weight value (W̄_n)
        w_std: Standard deviation (σ_{W_n})
    """
    # Equation A.7: W̄_n = (v̄ - 0.8) / (α · N · v̄)
    w_mean = (v_mean - 0.8) / (alpha * num_neurons * v_mean)

    # Equation A.8: W̄̄_n for computing standard deviation
    numerator = (v_second_moment +
                 alpha**2 * (num_neurons - num_neurons**2) * w_mean**2 * v_mean**2 -
                 1.6 * alpha * num_neurons * v_mean * w_mean -
                 0.64)
    denominator = alpha**2 * num_neurons * v_second_moment

    w_second_moment = numerator / denominator

    # Calculate standard deviation from second moment
    w_std = np.sqrt(w_second_moment - w_mean**2)

    return w_mean, w_std


def initialize_feedback_matrix(output_shape, input_shape,
                                w_mean_next, w_std_next,
                                gamma=0.0338,
                                num_downstream_layers=1):
    """
    Initialize fixed random feedback matrix B.
    Paper reference: Appendix B, Equation A.5

    Args:
        output_shape: Number of neurons this layer projects to
        input_shape: Number of neurons in this layer
        w_mean_next: Mean of weights in next layer
        w_std_next: Std of weights in next layer
        gamma: Scale factor (0.0338 in paper)
        num_downstream_layers: Number of layers between this and output (D)

    Returns:
        B: Fixed random feedback matrix [input_shape, output_shape]
    """
    # Generate random matrix with paper's specific distribution
    # B_n = γ · [W̄_{n+1} + 2√3 · σ_{W_{n+1}} · (rand - 0.5)]

    rand_values = np.random.uniform(0, 1, size=(input_shape, output_shape))

    # Apply paper's formula
    B = w_mean_next + 2 * np.sqrt(3) * w_std_next * (rand_values - 0.5)

    # Apply scaling factor
    B = gamma * B

    # Note: Paper mentions product over downstream layers (∏)
    # For simplicity with 2-layer network, we use single multiplication
    # For deeper networks, you'd multiply B matrices from all downstream layers

    return B.astype(np.float32)




In [6]:
class DFA_LIFLayer(keras.layers.Layer):
    """LIF layer with Direct Feedback Alignment (DFA) training."""

    def __init__(self, units, output_size,
                 tau=20.0, dt=0.25, threshold=0.4, t_ref=1.0,
                 use_dfa=True, gamma=0.0338, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.output_size = output_size
        self.tau = tau
        self.dt = dt
        self.threshold = threshold
        self.t_ref = t_ref
        self.use_dfa = use_dfa
        self.gamma = gamma

        self.alpha = 1.0 - (dt / tau)
        self.beta = dt / tau
        self.ref_steps = int(t_ref / dt)

    def build(self, input_shape):
        input_dim = input_shape[-1]

        w_mean, w_std = compute_weight_stats(self.units)
        w_init = keras.initializers.RandomUniform(
            minval=w_mean - np.sqrt(3) * w_std,
            maxval=w_mean + np.sqrt(3) * w_std
        )

        self.w = self.add_weight(
            name='weights',
            shape=(input_dim, self.units),
            initializer=w_init,
            trainable=True
        )

        self.b = self.add_weight(
            name='bias',
            shape=(self.units,),
            initializer=keras.initializers.Constant(0.8),
            trainable=True
        )

        if self.use_dfa:
            w_mean_next, w_std_next = compute_weight_stats(self.output_size)
            B_init = initialize_feedback_matrix(
                output_shape=self.output_size,
                input_shape=self.units,
                w_mean_next=w_mean_next,
                w_std_next=w_std_next,
                gamma=self.gamma
            )

            self.B = self.add_weight(
                name='feedback_matrix',
                shape=(self.units, self.output_size),
                initializer=keras.initializers.Constant(B_init),
                trainable=False
            )

    def call(self, inputs): #<- removed the training parameter
        """Forward pass: LIF dynamics over time"""
        batch_size = tf.shape(inputs)[0]
        time_steps_int = inputs.shape[1]

        #fix for potential issue perhaps
        if time_steps_int is None:
            # Handle dynamic shape
            time_steps_tensor = tf.shape(inputs)[1]
            try:
                time_steps_int = int(time_steps_tensor.numpy())
            except (AttributeError, TypeError):
                time_steps_int = 100  # Default fallback

        v = tf.zeros((batch_size, self.units))
        ref_count = tf.zeros((batch_size, self.units))

        spikes_array = tf.TensorArray(
            dtype=tf.float32,
            size=time_steps_int,
            element_shape=(None, self.units)
        )

        for t in range(time_steps_int):
            x_t = inputs[:, t, :]
            i_in = tf.matmul(x_t, self.w) + self.b

            not_refractory = tf.cast(ref_count <= 0, tf.float32)
            v = v * self.alpha + self.beta * i_in
            v = v * not_refractory

            spikes = tf.cast(v >= self.threshold, tf.float32)
            v = v * (1.0 - spikes)

            ref_count = ref_count - 1.0
            ref_count = tf.where(
                spikes > 0,
                tf.ones_like(ref_count) * self.ref_steps,
                ref_count
            )
            ref_count = tf.maximum(ref_count, 0.0)
            spikes_array = spikes_array.write(t, spikes)

        output_spikes = tf.transpose(spikes_array.stack(), [1, 0, 2])
        return output_spikes


In [7]:
# @title
# math function
def surrogate_gradient_exact(a, h_th=0.4, t_ref=1.0, tau=20.0):
    """
    Exact surrogate gradient from paper (Appendix D, Equation A.9).

    This is the smooth approximation of the Dirac delta function
    that represents the derivative of the LIF spike function.

    Args:
        a: Pre-activation values (membrane potential)
        h_th: Spike threshold
        t_ref: Refractory period (ms)
        tau: Membrane time constant (ms)

    Returns:
        Surrogate gradient f'(a)
    """
    # Only compute gradient where a > threshold
    # (derivative is 0 when a ≤ threshold)

    # Compute ratio a / (a - h_th)
    # Add small epsilon to avoid division by zero
    eps = 1e-8
    ratio = a / (a - h_th + eps)

    # Compute numerator: h_th · t_ref · τ / [a · (a - h_th)]
    numerator = h_th * t_ref * tau / (a * (a - h_th) + eps)

    # Compute denominator: [t_ref + τ · log(ratio)]²
    log_term = tf.math.log(ratio + eps)
    denominator = (t_ref + tau * log_term) ** 2 + eps

    # Compute gradient
    grad = numerator / denominator

    # Only non-zero where a > threshold
    grad = tf.where(a > h_th, grad, tf.zeros_like(grad))

    return grad

#math function
def surrogate_gradient_fast_sigmoid(a, threshold=0.4, alpha=10.0):
    """
    Alternative: Fast sigmoid surrogate gradient.

    This is simpler and commonly used in SNN training.
    Formula: 1 / (1 + |alpha * (a - threshold)|)²

    Args:
        a: Pre-activation values
        threshold: Spike threshold
        alpha: Steepness parameter

    Returns:
        Surrogate gradient
    """
    shifted = a - threshold
    grad = 1.0 / (1.0 + tf.abs(alpha * shifted)) ** 2
    return grad


# Visualize both surrogate gradients
a_values = np.linspace(-2, 2, 1000)
threshold = 0.4

# Compute gradients
grad_exact = surrogate_gradient_exact(
    tf.constant(a_values, dtype=tf.float32),
    h_th=threshold
).numpy()

grad_sigmoid = surrogate_gradient_fast_sigmoid(
    tf.constant(a_values, dtype=tf.float32),
    threshold=threshold
).numpy()


In [8]:
import tensorflow as tf

class DFATrainer:
    """Custom trainer implementing Direct Feedback Alignment."""

    def __init__(self, model, learning_rate=0.001, use_exact_gradient=True):
        self.model = model
        self.base_lr = learning_rate
        self.use_exact_gradient = use_exact_gradient

        self.dfa_layers = [
            layer for layer in model.layers if hasattr(layer, 'B')
        ]

        self.loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
        self.train_loss = None
        self.train_acc = None
        self._init_metrics()

    def _init_metrics(self):
        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.train_acc = tf.keras.metrics.CategoricalAccuracy(name='train_acc')

    def _reset_metrics(self):
        self._init_metrics()

    def compute_spike_rate(self, spikes):
        return tf.reduce_sum(spikes, axis=1)

    def get_layer_learning_rate(self, layer):
        input_dim = layer.w.shape[0]
        return self.base_lr

    def train_step(self, inputs, targets):
      """Single DFA training step"""
      """Single DFA training step"""
      batch_size = tf.shape(inputs)[0]
      time_steps = tf.shape(inputs)[1]
      time_steps_float = tf.cast(time_steps, tf.float32)
      batch_size_float = tf.cast(batch_size, tf.float32)

      # =============================================
      # FORWARD PASS
      # =============================================
      activations = {}
      current_input = inputs

      for layer in self.model.layers:
        if isinstance(layer, keras.layers.InputLayer):
            continue
        activations[layer.name] = current_input
        current_input = layer(current_input)

      output_spikes = current_input
      output_rates = self.compute_spike_rate(output_spikes)
      output_probs = tf.nn.softmax(output_rates)
      loss = self.loss_fn(targets, output_probs)

      # =============================================
      # COMPUTE GLOBAL ERROR (same for all DFA)
      # =============================================
      e_global = output_probs - targets  # [batch, 10]
      e_global_time = tf.expand_dims(e_global, axis=1)
      e_global_time = tf.tile(e_global_time, [1, time_steps, 1])  # [batch, time, 10]

      # =============================================
      # TRAIN OUTPUT LAYER (Direct error - this IS part of DFA!)
      # =============================================
      output_layer = self.model.get_layer('output_layer')
      output_input = activations['output_layer']  # Hidden layer's spikes

      print(f"\n--- DEBUG ---")
      print(f"Output layer spike rate: {tf.reduce_mean(output_spikes).numpy():.6f}")
      print(f"Output rates (sum over time): {output_rates[0].numpy()}")  # First sample
      print(f"Output probs: {output_probs[0].numpy()}")  # First sample
      print(f"Target: {targets[0]}")
      print(f"Loss: {loss.numpy():.4f}")

      # Surrogate gradient at output layer
      a_out = tf.matmul(output_input, output_layer.w) + output_layer.b
      f_prime_out = 1.0 / (1.0 + tf.abs(10.0 * (a_out - output_layer.threshold))) ** 2
      e_out = e_global_time * f_prime_out

      # Vectorized gradient computation
      grad_w_out = tf.einsum('bti,btj->ij', output_input, e_out) / (time_steps_float * batch_size_float)
      grad_b_out = tf.reduce_sum(tf.reduce_sum(e_out, axis=1), axis=0) / (time_steps_float * batch_size_float)

      output_layer.w.assign_sub(self.base_lr * grad_w_out)
      output_layer.b.assign_sub(self.base_lr * grad_b_out)

      # =============================================
      # TRAIN HIDDEN LAYERS (DFA - error through fixed B matrix)
      # This is where DFA differs from backprop!
      # =============================================
      for layer in self.dfa_layers:
          B = layer.B  # Fixed random feedback matrix (NOT learned)

          # PROJECT error through B (this is the DFA step!)
          # Backprop would use: e_projected = e @ W_output^T
          # DFA uses:           e_projected = e @ B^T
          e_projected = tf.matmul(e_global_time, B, transpose_b=True)

          layer_input = activations[layer.name]

          # Surrogate gradient
          a = tf.matmul(layer_input, layer.w) + layer.b
          f_prime = 1.0 / (1.0 + tf.abs(10.0 * (a - layer.threshold))) ** 2
          e_layer = e_projected * f_prime

          # Vectorized gradient computation
          grad_w = tf.einsum('bti,btj->ij', layer_input, e_layer) / (time_steps_float * batch_size_float)
          grad_b = tf.reduce_sum(tf.reduce_sum(e_layer, axis=1), axis=0) / (time_steps_float * batch_size_float)

          layer.w.assign_sub(self.base_lr * grad_w)
          layer.b.assign_sub(self.base_lr * grad_b)

      self.train_loss.update_state(loss)
      self.train_acc.update_state(targets, output_probs)

      return {
          'loss': self.train_loss.result(),
          'accuracy': self.train_acc.result()
      }

    def fit(self, train_data, epochs, verbose=1):
        """Train using DFA"""

        history = {'loss': [], 'accuracy': []}

        for epoch in range(epochs):
            self._reset_metrics()

            if isinstance(train_data, tuple):
                x_train, y_train = train_data
                batch_size = 32
                num_batches = len(x_train) // batch_size
                for batch_idx in range(num_batches):
                    start_idx = batch_idx * batch_size
                    end_idx = start_idx + batch_size


                    x_batch = x_train[start_idx:end_idx]
                    y_batch = y_train[start_idx:end_idx]



                    metrics = self.train_step(x_batch, y_batch)

            loss = float(metrics['loss'])
            acc = float(metrics['accuracy'])
            history['loss'].append(loss)
            history['accuracy'].append(acc)


            if verbose:
                print(f"Epoch {epoch+1}/{epochs} - loss: {loss:.4f} - accuracy: {acc:.4f}")


        return history

print("✓ FIXED DFATrainer loaded")

✓ FIXED DFATrainer loaded


In [None]:
# @title
# Load MNIST data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Flatten
x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)

# One-hot encode labels
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

print(f"Training samples: {len(x_train)}")
print(f"Test samples: {len(x_test)}")
print(f"Input shape: {x_train.shape[1]}")
print(f"Output classes: {y_train.shape[1]}")

# Paper's configuration
TIME_STEPS = 25  # Reduced from 100 for faster training
HIDDEN_SIZE = 1000  # Paper uses 1000 hidden neurons
OUTPUT_SIZE = 10
INPUT_SIZE = 784

# INCREASED: Use larger dataset
TRAIN_SAMPLES = 30000  # Increased from 1000 to 10000
x_train_small = x_train[:TRAIN_SAMPLES]
y_train_small = y_train[:TRAIN_SAMPLES]

# Expand to time dimension (paper uses "direct mapping")
# Just repeat the input for each time step
x_train_spikes = np.tile(
    x_train_small[:, np.newaxis, :],  # Add time dim
    (1, TIME_STEPS, 1)  # Repeat TIME_STEPS times
)

print(f"\nInput shape with time: {x_train_spikes.shape}")
print(f"  [batch={x_train_spikes.shape[0]}, "
      f"time={x_train_spikes.shape[1]}, "
      f"features={x_train_spikes.shape[2]}]")

# Build DFA-SNN model
print("\n" + "="*60)
print("BUILDING DFA-SNN MODEL")
print("="*60)

inputs = keras.Input(shape=(TIME_STEPS, INPUT_SIZE))

# Hidden layer with DFA
hidden = DFA_LIFLayer(
    units=HIDDEN_SIZE,
    output_size=OUTPUT_SIZE,  # For B matrix initialization
    tau=20.0,
    dt=0.25,
    threshold=0.4,
    t_ref=1.0, #< making this smaller
    use_dfa=True,
    gamma=0.0338,
    name='hidden_layer'
)(inputs)

# Output layer (also DFA, but projects to itself)
outputs = DFA_LIFLayer(
    units=OUTPUT_SIZE,
    output_size=OUTPUT_SIZE,
    tau=20.0,
    dt=0.25,
    threshold=0.4,
    t_ref=1.0,
    use_dfa=False,  # Output layer doesn't need B
    name='output_layer'
)(hidden)

model = keras.Model(inputs=inputs, outputs=outputs)

model.summary()

# Create DFA trainer
print("\n" + "="*60)
print("INITIALIZING DFA TRAINER")
print("="*60)

trainer = DFATrainer(
    model=model,
    learning_rate=0.1,  # INCREASED: from 0.001 to 0.5 and since it divides by 784
    use_exact_gradient=True  # Use paper's exact f'
)

print(f"\nDFA layers found: {len(trainer.dfa_layers)}")
for layer in trainer.dfa_layers:
    lr = trainer.get_layer_learning_rate(layer)
    print(f"  {layer.name}:")
    print(f"    Units: {layer.units}")
    print(f"    B shape: {layer.B.shape}")
    print(f"    Learning rate: {lr:.6f}")

# Train with DFA
print("\n" + "="*60)
print("TRAINING WITH DFA")
print("="*60)
print("\nNote: This is using DIRECT FEEDBACK ALIGNMENT")
print("  - Errors bypass layer-by-layer propagation")
print("  - Each layer receives error directly from output")
print("  - Feedback matrices B are FIXED (not trained)\n")

history = trainer.fit(
    train_data=(x_train_spikes, y_train_small),
    epochs=20,
    verbose=1
)

# Plot results
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history['loss'], 'b-', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss (DFA-SNN)')
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(history['accuracy'], 'g-', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training Accuracy (DFA-SNN)')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n✓ DFA-SNN training complete!")
print(f"\nFinal metrics:")
print(f"  Loss: {history['loss'][-1]:.4f}")
print(f"  Accuracy: {history['accuracy'][-1]:.4f}")

print(f"\nPaper's reported performance (MNIST):")
print(f"  DFA-SNNs: ~96.75% (best), ~92.09% (average)")
print(f"  aDFA-SNNs: ~98.01% (best), ~97.91% (average)")

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 0us/step
Training samples: 60000
Test samples: 10000
Input shape: 784
Output classes: 10

Input shape with time: (30000, 25, 784)
  [batch=30000, time=25, features=784]

BUILDING DFA-SNN MODEL


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
--- DEBUG ---
Output layer spike rate: 0.165250
Output rates (sum over time): [4. 4. 4. 4. 4. 5. 4. 4. 4. 4.]
Output probs: [0.08533675 0.08533675 0.08533675 0.08533675 0.08533675 0.23196931
 0.08533675 0.08533675 0.08533675 0.08533675]
Target: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
Loss: 2.3882

--- DEBUG ---
Output layer spike rate: 0.163125
Output rates (sum over time): [4. 4. 4. 4. 4. 5. 4. 4. 4. 4.]
Output probs: [0.08533675 0.08533675 0.08533675 0.08533675 0.08533675 0.23196931
 0.08533675 0.08533675 0.08533675 0.08533675]
Target: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
Loss: 2.3872

--- DEBUG ---
Output layer spike rate: 0.162500
Output rates (sum over time): [4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
Output probs: [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
Target: [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
Loss: 2.2722

--- DEBUG ---
Output layer spike rate: 0.161875
Output rates (sum over time): [4. 4. 4. 4. 4. 4. 4. 4. 4. 4.]
Output probs: [0.1 0.1 

In [None]:
# @title
# Load MNIST data
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Flatten
x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)

# One-hot encode labels
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

print(f"Training samples: {len(x_train)}")
print(f"Test samples: {len(x_test)}")
print(f"Input shape: {x_train.shape[1]}")
print(f"Output classes: {y_train.shape[1]}")

# Paper's configuration
TIME_STEPS = 100  # Paper uses 100ms total (simplified from 400 steps)
HIDDEN_SIZE = 1000  # Paper uses 1000 hidden neurons
OUTPUT_SIZE = 10
INPUT_SIZE = 784

# For demo, use smaller dataset
TRAIN_SAMPLES = 1000  # Use 1000 samples for quick demo
x_train_small = x_train[:TRAIN_SAMPLES]
y_train_small = y_train[:TRAIN_SAMPLES]

# Expand to time dimension (paper uses "direct mapping")
# Just repeat the input for each time step
x_train_spikes = np.tile(
    x_train_small[:, np.newaxis, :],  # Add time dim
    (1, TIME_STEPS, 1)  # Repeat TIME_STEPS times
)

print(f"\nInput shape with time: {x_train_spikes.shape}")
print(f"  [batch={x_train_spikes.shape[0]}, "
      f"time={x_train_spikes.shape[1]}, "
      f"features={x_train_spikes.shape[2]}]")

# Build DFA-SNN model
print("\n" + "="*60)
print("BUILDING DFA-SNN MODEL")
print("="*60)

inputs = keras.Input(shape=(TIME_STEPS, INPUT_SIZE))

# Hidden layer with DFA
hidden = DFA_LIFLayer(
    units=HIDDEN_SIZE,
    output_size=OUTPUT_SIZE,  # For B matrix initialization
    tau=20.0,
    dt=0.25,
    threshold=0.4,
    t_ref=1.0,
    use_dfa=True,
    gamma=0.0338,
    name='hidden_layer'
)(inputs)

# Output layer (also DFA, but projects to itself)
outputs = DFA_LIFLayer(
    units=OUTPUT_SIZE,
    output_size=OUTPUT_SIZE,
    tau=20.0,
    dt=0.25,
    threshold=0.4,
    t_ref=1.0,
    use_dfa=False,  # Output layer doesn't need B
    name='output_layer'
)(hidden)

model = keras.Model(inputs=inputs, outputs=outputs)

model.summary()

# Create DFA trainer
print("\n" + "="*60)
print("INITIALIZING DFA TRAINER")
print("="*60)

trainer = DFATrainer(
    model=model,
    learning_rate=0.001,  # Base LR (will be adjusted per layer)
    use_exact_gradient=True  # Use paper's exact f'
)

print(f"\nDFA layers found: {len(trainer.dfa_layers)}")
for layer in trainer.dfa_layers:
    lr = trainer.get_layer_learning_rate(layer)
    print(f"  {layer.name}:")
    print(f"    Units: {layer.units}")
    print(f"    B shape: {layer.B.shape}")
    print(f"    Learning rate: {lr:.6f}")

# Train with DFA
print("\n" + "="*60)
print("TRAINING WITH DFA")
print("="*60)
print("\nNote: This is using DIRECT FEEDBACK ALIGNMENT")
print("  - Errors bypass layer-by-layer propagation")
print("  - Each layer receives error directly from output")
print("  - Feedback matrices B are FIXED (not trained)\n")

history = trainer.fit(
    train_data=(x_train_spikes, y_train_small),
    epochs=20,
    verbose=1
)

# Plot results
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history['loss'], 'b-', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss (DFA-SNN)')
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(history['accuracy'], 'g-', linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training Accuracy (DFA-SNN)')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n✓ DFA-SNN training complete!")
print(f"\nFinal metrics:")
print(f"  Loss: {history['loss'][-1]:.4f}")
print(f"  Accuracy: {history['accuracy'][-1]:.4f}")



print(f"\nPaper's reported performance (MNIST):")
print(f"  DFA-SNNs: ~96.75% (best), ~92.09% (average)")
print(f"  aDFA-SNNs: ~98.01% (best), ~97.91% (average)")