In [None]:
import sys; sys.path.insert(0, '../../')  # make the project root importable (for the `models.*` imports)

# Disable GPU to prevent TensorFlow/nengo_dl crashes when CUDA is unavailable.
# Ordering matters: CUDA_VISIBLE_DEVICES is read when CUDA initializes, so it is
# set here *before* `import tensorflow` — do not move the import above this line.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf
# Second safeguard: hide any GPU TensorFlow can still see. TensorFlow only allows
# changing visible devices before they are initialized (i.e. before any op/model is
# created), so this must stay at the very top of the notebook.
tf.config.set_visible_devices([], 'GPU')
print("TensorFlow running on CPU only")

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
from models.snn.factory import build_snn, build_simple_snn
from models.snn.train import train_model, evaluate_model
from models.snn.tuning import tune_hyperparameters
from models.preprocessing import load_preprocessed

In [None]:
# Load the preprocessed dataset and unpack each split into its (features, labels) pair.
splits, class_names = load_preprocessed('../../data/processed/')
(X_train, y_train), (X_val, y_val), (X_test, y_test) = (
    splits['train'],
    splits['val'],
    splits['test'],
)

In [None]:
# Where SNN checkpoints live, and the env file holding the best hyperparameters.
checkpoint_dir = Path('../../checkpoints') / 'snn'
env_path = checkpoint_dir.joinpath("best_hparams.env")

In [None]:
# Load hyperparameters from the env file if it exists (presumably written by the
# tuning step), otherwise use defaults. Defaults are defined once so the "file"
# and "no file" paths cannot drift apart.
DEFAULT_HPARAMS = {
    'N_NEURONS_HIDDEN': 128,
    'LEARNING_RATE': 0.001,
    'BATCH_SIZE': 32,
}


def _read_env_file(path):
    """Parse a simple ``KEY=VALUE`` env file into a dict of strings.

    Blank lines, lines without ``=``, and ``#`` comment lines are skipped; an
    optional leading ``export`` is ignored; whitespace around keys and values is
    stripped; and one pair of matching surrounding quotes is removed from values.

    Args:
        path: ``pathlib.Path`` of the env file to read.

    Returns:
        dict mapping each key to its unconverted string value. If a key appears
        more than once, the last occurrence wins.
    """
    parsed = {}
    with path.open("r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('#') or '=' not in line:
                continue
            if line.startswith('export '):
                line = line[len('export '):]
            key, value = line.split('=', 1)
            # Without stripping, "KEY = 1" would store "KEY " and silently miss the lookup
            key, value = key.strip(), value.strip()
            # Quoted values ('"0.001"') would otherwise make float()/int() raise
            if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
                value = value[1:-1]
            parsed[key] = value
    return parsed


if env_path.exists():
    print(f"Loading hyperparameters from {env_path}")
    hparams = _read_env_file(env_path)
    source_label = "Loaded hyperparameters:"
else:
    print(f"Env file not found at {env_path}, using default hyperparameters")
    hparams = {}
    source_label = "Using default hyperparameters:"

# Keys missing from the file fall back to the defaults; file values arrive as strings.
best_hidden = int(hparams.get('N_NEURONS_HIDDEN', DEFAULT_HPARAMS['N_NEURONS_HIDDEN']))
best_lr = float(hparams.get('LEARNING_RATE', DEFAULT_HPARAMS['LEARNING_RATE']))
best_batch_sz = int(hparams.get('BATCH_SIZE', DEFAULT_HPARAMS['BATCH_SIZE']))

print(source_label)
print(f"  Hidden neurons: {best_hidden}")
print(f"  Learning rate: {best_lr:.2e}")
print(f"  Batch size: {best_batch_sz}")

In [None]:
# Build the network with rate-based (non-spiking, RectifiedLinear) neurons so it can be
# trained by gradient descent. synapse=None: each sample is presented for a single
# timestep, so there is nothing for a synaptic filter to smooth.
print("Building SNN architecture...")
net, inp, p_out = build_simple_snn(
    n_features=X_train.shape[1],
    n_classes=len(class_names),
    n_neurons_hidden=best_hidden,
    synapse=None,
    spiking=False,
)
print(
    f"Network built with {X_train.shape[1]} input features "
    f"and {len(class_names)} output classes"
)
print("Using rate-based neurons (RectifiedLinear) for gradient-based training")

In [None]:
# Train with the hyperparameters selected above: tuned values from the env file,
# or the defaults when that file is absent. Early stopping presumably watches a
# validation metric (TODO confirm in train_model) and ends training after
# `early_stopping_patience` epochs without an improvement of at least
# `early_stopping_min_delta`.
print("Starting full training with best hyperparameters...")
print(f"Learning rate: {best_lr:.2e}, Batch size: {best_batch_sz}\n")

history, sim = train_model(
    net=net, inp=inp, p_out=p_out,
    # data
    X_train=X_train, y_train=y_train,
    X_val=X_val, y_val=y_val,
    # optimization
    epochs=20,  # upper bound; early stopping may end training sooner
    batch_size=best_batch_sz,
    learning_rate=best_lr,
    checkpoint_dir=checkpoint_dir,
    # early stopping
    use_early_stopping=True,
    early_stopping_min_delta=0.001,
    early_stopping_patience=7,
)

In [None]:
# Collect the distinct metric names recorded during training. Only a *leading* 'val_'
# is stripped: str.replace('val_', '') would also mangle names that merely contain
# 'val_' somewhere (e.g. 'interval_loss' -> 'interloss').
metrics = sorted({
    key[len('val_'):] if key.startswith('val_') else key
    for key in history.history
})

# Lay the metrics out on a grid with two plots per row
n_metrics = len(metrics)
n_cols = 2
n_rows = (n_metrics + n_cols - 1) // n_cols

if n_metrics == 0:
    # plt.subplots(0, ...) would raise, so report instead of plotting
    print("No metrics recorded in history; nothing to plot.")
else:
    # squeeze=False always returns a 2-D array of Axes, so flatten() is valid for every
    # grid size. With the default squeeze=True a single-row grid comes back as a 1-D
    # array and would need special-casing.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 4 * n_rows), squeeze=False)
    axes = axes.flatten()

    for ax, metric in zip(axes, metrics):
        # Training curve (absent when only a val_ variant was recorded)
        if metric in history.history:
            ax.plot(history.history[metric], label=f'Training {metric}', linewidth=2)

        # Validation curve
        val_metric = f'val_{metric}'
        if val_metric in history.history:
            ax.plot(history.history[val_metric], label=f'Validation {metric}', linewidth=2)

        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric.capitalize())
        ax.set_title(f'{metric.capitalize()} over Epochs')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Hide the grid cell left over when the metric count is odd
    for ax in axes[n_metrics:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

# Print summary of all metrics
print("\nAvailable metrics in history:")
for key in sorted(history.history):
    values = history.history[key]
    if len(values) == 0:
        # min()/max() would raise on an empty sequence
        print(f"  {key}: (no values recorded)")
        continue
    print(f"  {key}: min={min(values):.4f}, max={max(values):.4f}, final={values[-1]:.4f}")

In [None]:
# Evaluate on test set + confusion matrix
import numpy as np
from scipy.special import softmax

print("\nEvaluating on test set...")
test_results = evaluate_model(sim, inp, p_out, X_test, y_test)

# The accuracy metric's exact key is not fixed (presumably prefixed by the probe name —
# TODO confirm), so take the first key mentioning 'accuracy'. next() with a default
# avoids an IndexError when no accuracy metric was compiled.
accuracy_key = next((k for k in test_results.keys() if 'accuracy' in k), None)

print(f"Test Loss: {test_results['loss']:.4f}")
if accuracy_key is None:
    print(f"Test Accuracy: n/a (no accuracy metric in {sorted(test_results)})")
else:
    print(f"Test Accuracy: {test_results[accuracy_key]:.4f}")

print("\nComputing confusion matrix...")

# Present each test sample as a single timestep: (batch, features) -> (batch, 1, features)
X_test_reshaped = X_test[:, np.newaxis, :]
pred_out = sim.predict({inp: X_test_reshaped})
logits_time = pred_out[p_out]

# Handle both (batch, time, classes) and (batch, classes); with a time axis, use the last step
if logits_time.ndim == 3:
    logits = logits_time[:, -1, :]
elif logits_time.ndim == 2:
    logits = logits_time
else:
    raise ValueError(f"Unexpected model output shape for p_out: {logits_time.shape}")

# Softmax does not change the argmax; probabilities are computed so `probs` is
# available for inspection.
probs = softmax(logits, axis=1)
y_pred = np.argmax(probs, axis=1)

n_classes = len(class_names)
labels = list(range(n_classes))

# Normalize y_test to a flat vector of integer class ids. Supported layouts:
#   (batch,) integer labels, (batch, 1) column of integer labels, (batch, n_classes)
#   one-hot, and a (batch, time, ...) variant of these — from which the last timestep
#   is taken, mirroring how the logits are read above.
y_true_arr = np.asarray(y_test)
if y_true_arr.ndim == 3:
    y_true_arr = y_true_arr[:, -1, :]
if y_true_arr.ndim == 2 and y_true_arr.shape[1] > 1:
    y_true = np.argmax(y_true_arr, axis=1)
else:
    # A trailing axis of size 1 holds integer labels, not one-hot (argmax would give all 0s)
    y_true = y_true_arr.reshape(-1).astype(int)

assert y_true.shape == y_pred.shape, (
    f"label/prediction length mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}"
)

# Prefer sklearn if available (nicer plotting + report), else fallback to numpy
try:
    from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report

    cm = confusion_matrix(y_true, y_pred, labels=labels)

    fig, ax = plt.subplots(figsize=(6, 5))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(ax=ax, cmap="Blues", colorbar=True, values_format="d")
    ax.set_title("Confusion Matrix (Test)")
    plt.tight_layout()
    plt.show()

    print("\nClassification report (Test):")
    print(classification_report(y_true, y_pred, target_names=class_names, digits=4, zero_division=0))

except Exception as e:
    # Deliberate best-effort: any sklearn import/plotting failure falls back to a plain plot
    print(f"(sklearn unavailable or failed: {e})\nFalling back to numpy confusion matrix plot.")

    # Count (true, predicted) pairs, ignoring class ids outside [0, n_classes)
    cm = np.zeros((n_classes, n_classes), dtype=int)
    in_range = (y_true >= 0) & (y_true < n_classes) & (y_pred >= 0) & (y_pred < n_classes)
    np.add.at(cm, (y_true[in_range], y_pred[in_range]), 1)

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(cm, cmap="Blues")
    ax.set_title("Confusion Matrix (Test)")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_xticks(np.arange(n_classes))
    ax.set_yticks(np.arange(n_classes))
    ax.set_xticklabels(class_names, rotation=45, ha="right")
    ax.set_yticklabels(class_names)

    # Annotate counts; white text on the darker half of the colormap keeps them legible
    threshold = cm.max() / 2 if cm.size else 0
    for i in range(n_classes):
        for j in range(n_classes):
            ax.text(
                j, i, str(cm[i, j]), ha="center", va="center",
                color="white" if cm[i, j] > threshold else "black",
            )

    fig.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()

# Release the simulator; `sim` cannot be used after this
sim.close()