# MNIST SNN Classification Accuracy

This notebook demonstrates the end-to-end flow of the Spiking Neural Network (SNN):
1. **Rate Encoding**: Converting images to Poisson spike trains.
2. **Crossbar MVM**: Computing currents using the ReRAM conductance matrix.
3. **LIF Neurons**: Integration and spiking at the output layer.
4. **Training Results**: Evaluating test accuracy with the trained weights (falling back to random non-negative weights when no trained-weights file is found).

In [None]:
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# Ensure project root is available
sys.path.insert(0, str(Path().resolve().parent))

from src.crossbar import IdealCrossbar
from src.snn import PoissonEncoder, SNNNetwork, SNNTrainer
from src.utils.mnist_loader import get_mnist_path, load_mnist_from_path
from src.utils.metrics import compute_accuracy
from src.utils.weight_io import load_weights
from src.utils.visualization import plot_firing_raster

# Seed NumPy's global RNG so that anything drawing from it (e.g. the
# random-weight fallback further down) is reproducible across runs.
np.random.seed(42)

# Trained weights are looked up relative to the parent of the current working
# directory (the same root the sys.path insertion above relies on).
WEIGHTS_PATH = Path().resolve().parent.joinpath("experiments", "trained_weights.npy")

# 1. Load Data
print("Loading MNIST samples...")
mnist_path = get_mnist_path()
# The test split is capped at max_test=100 samples (presumably to keep the
# per-sample inference loop below fast).
X_train, y_train, X_test, y_test, p = load_mnist_from_path(
    mnist_path,
    max_test=100,
)

# 2. Setup Inference
# Input size matches a flattened 28x28 image; one output neuron per class.
N_IN, N_OUT = 784, 10
TIMESTEPS = 50  # number of simulation steps per forward pass

cb = IdealCrossbar(N_IN, N_OUT)

if WEIGHTS_PATH.exists():
    print("Loading trained weights...")
    W = load_weights(str(WEIGHTS_PATH))
    # Fail early with a clear message when the saved weights do not match the
    # crossbar dimensions (e.g. a stale file from a different network size),
    # rather than failing or misbehaving somewhere inside set_conductance.
    if np.shape(W) != (N_IN, N_OUT):
        raise ValueError(
            f"Trained weights at {WEIGHTS_PATH} have shape {np.shape(W)}, "
            f"expected ({N_IN}, {N_OUT})."
        )
else:
    print("Trained weights not found. Using random...")
    # Small Gaussian weights clipped at zero so every entry is non-negative
    # (presumably because conductances cannot be negative — confirm).
    W = np.maximum(np.random.randn(N_IN, N_OUT) * 0.01, 0)

cb.set_conductance(W)
# The crossbar's run method is handed to the network as its compute callback
# (presumably the matrix-vector multiply used each timestep).
snn = SNNNetwork(N_IN, N_OUT, cb.run, timesteps=TIMESTEPS)

# 3. Run Inference on Test Set
# Each sample is run with its own fixed seed (42 + index) so results are
# reproducible. The second value returned by snn.forward (the per-output
# totals) is used directly as that sample's logits row.
logits = np.zeros((len(X_test), N_OUT))
for i, sample in enumerate(X_test):
    _, logits[i] = snn.forward(sample, seed=42 + i)

acc = compute_accuracy(logits, y_test)
print(f"\nTest Accuracy: {acc*100:.2f}%")

# 4. Visualization (Sample Digit)
idx = 0
true_label = y_test[idx]
pred_label = np.argmax(logits[idx])

# Left panel: the input digit, titled with its true and predicted labels.
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.imshow(X_test[idx].reshape(28, 28), cmap='gray')
plt.title(f"True: {true_label} | Pred: {pred_label}")
plt.axis('off')

# Re-run this sample with the same seed used during inference (42 + index) so
# the raster shows the same spikes that produced the prediction in the title
# (assuming snn.forward is deterministic for a given seed).
sample_spikes, _ = snn.forward(X_test[idx], seed=42 + idx)

# Select the right-hand panel BEFORE drawing the raster; creating it afterwards
# would only add an empty axes on top of whatever figure is current.
# NOTE(review): if plot_firing_raster always opens its own figure, this panel
# stays empty — confirm whether it can draw into a given axes.
plt.subplot(1, 2, 2)
fig = plot_firing_raster(sample_spikes, title="Output Layer Firing Raster", ylabel="Neuron (Class)")
plt.show()