In [None]:
import sys
from pathlib import Path

# Make the project root (two directories above this notebook's working
# directory) importable so the `models.*` imports resolve. The path is
# resolved to an absolute one so a later os.chdir() cannot break imports, and
# the insert is skipped when already present so re-running this cell does not
# keep growing sys.path.
PROJECT_ROOT = str(Path('../../').resolve())
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt

from models.preprocessing import load_preprocessed
from models.snn.factory import build_snn
from models.snn.train import train_model, evaluate_model
from models.snn.tuning import tune_hyperparameters

In [None]:
splits, class_names = load_preprocessed('../../data/processed/')

# Each entry of `splits` unpacks into an (X, y) pair; pull out all three
# partitions in one assignment.
(X_train, y_train), (X_val, y_val), (X_test, y_test) = (
    splits['train'],
    splits['val'],
    splits['test'],
)

In [None]:
# Output directory for this run, relative to the notebook's working directory:
# the tuned-hyperparameter file is written here and it is passed to
# train_model as its checkpoint_dir.
checkpoint_dir = Path('../..') / 'checkpoints' / 'snn'

In [None]:
# Search hyperparameters via the project's Ray Tune wrapper.
# NOTE(review): trials are presumably scored on the validation split passed in
# below — confirm the objective metric in models/snn/tuning.py.
best_config, analysis = tune_hyperparameters(
    X_train=X_train,
    y_train=y_train,
    X_val=X_val,
    y_val=y_val,
    n_features=X_train.shape[1],
    n_classes=len(class_names),
    max_epochs=5,    # short per-trial budget; the full training run uses more epochs
    num_samples=20,  # presumably the number of sampled configurations (trials)
    project_name="snn_ray_tune",
)

# Coerce the tuned values to plain Python types up front, so the report below
# shows exactly what is persisted and used for training. This also means a
# non-float learning rate cannot break the `:.2e` format.
best_hidden = int(best_config['n_neurons_hidden'])
best_syn = float(best_config['synapse'])
best_lr = float(best_config['learning_rate'])
best_batch_sz = int(best_config['batch_size'])

print("\nBest hyperparameters found by Ray Tune:\n")
print(f"  Hidden neurons: {best_hidden}")
print(f"  Synapse time constant: {best_syn}")
print(f"  Learning rate: {best_lr:.2e}")
print(f"  Batch size: {best_batch_sz}")

# Persist best hyperparameters to an env-style file (one KEY=value per line)
# for later use. Nothing earlier in this notebook creates checkpoint_dir, so
# create it here; otherwise open("w") raises FileNotFoundError when the
# directory does not exist yet.
env_path = checkpoint_dir / "best_hparams.env"
checkpoint_dir.mkdir(parents=True, exist_ok=True)

print(f"\nSaving best hyperparameters to {env_path} ...")
with env_path.open("w", encoding="utf-8") as f:
    f.write(f"N_NEURONS_HIDDEN={best_hidden}\n")
    f.write(f"SYNAPSE={best_syn}\n")
    f.write(f"LEARNING_RATE={best_lr}\n")
    f.write(f"BATCH_SIZE={best_batch_sz}\n")
print("Best hyperparameters saved.\n")

In [None]:
# Instantiate the SNN from the tuned hidden-layer size and synapse constant.
# Input width comes from the feature matrix, output width from the class list.
print("Building SNN architecture with best hyperparameters...")
net, inp, p_out = build_snn(
    n_features=X_train.shape[1],
    n_classes=len(class_names),
    n_neurons_hidden=best_hidden,
    synapse=best_syn,
)
print(
    f"Network built with {X_train.shape[1]} input features and {len(class_names)} output classes",
    f"Hidden layer: {best_hidden} neurons, Synapse: {best_syn}",
    sep="\n",
)

In [None]:
# Full training run with the tuned learning rate and batch size.
# NOTE(review): epochs=20 is presumably an upper bound when early stopping is
# on; the early_stopping_* arguments presumably configure a min-delta/patience
# rule inside train_model — confirm which metric it monitors.
print(
    "Starting full training with best hyperparameters...",
    f"Learning rate: {best_lr:.2e}, Batch size: {best_batch_sz}\n",
    sep="\n",
)

history, sim = train_model(
    # network handles returned by build_snn
    net=net,
    inp=inp,
    p_out=p_out,
    # data
    X_train=X_train,
    y_train=y_train,
    X_val=X_val,
    y_val=y_val,
    # optimisation schedule
    epochs=20,
    batch_size=best_batch_sz,
    learning_rate=best_lr,
    checkpoint_dir=checkpoint_dir,
    # early stopping
    use_early_stopping=True,
    early_stopping_min_delta=0.01,
    early_stopping_patience=5,
)

In [None]:
# Plot every metric recorded during training: one subplot per base metric,
# overlaying the training curve and its 'val_'-prefixed validation curve.
# NOTE(review): assumes history.history is a Keras-style dict mapping
# metric name -> per-epoch values — confirm against train_model.

# Base metric names: strip only a *leading* 'val_'. str.replace would also
# rewrite names that merely contain 'val_' (e.g. 'interval_loss' -> 'interloss').
metrics = sorted({
    key[len('val_'):] if key.startswith('val_') else key
    for key in history.history
})
n_metrics = len(metrics)

if n_metrics == 0:
    # plt.subplots(0, n_cols) would raise, so report instead of plotting.
    print("No metrics recorded in history; nothing to plot.")
else:
    # Lay the subplots out on a grid two columns wide.
    n_cols = 2
    n_rows = (n_metrics + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D array of Axes, so flattening works for
    # every grid size, including a single metric (a 1x2 grid).
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 4 * n_rows), squeeze=False)
    axes = axes.flatten()

    for ax, metric in zip(axes, metrics):
        # Training curve, if the unprefixed key exists.
        if metric in history.history:
            ax.plot(history.history[metric], label=f'Training {metric}', linewidth=2)

        # Validation curve, if a 'val_' counterpart exists.
        val_metric = f'val_{metric}'
        if val_metric in history.history:
            ax.plot(history.history[val_metric], label=f'Validation {metric}', linewidth=2)

        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric.capitalize())
        ax.set_title(f'{metric.capitalize()} over Epochs')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Hide the empty trailing grid cell left when the metric count is odd.
    for ax in axes[n_metrics:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

# Print a numeric summary of every recorded series.
print("\nAvailable metrics in history:")
for key in sorted(history.history):
    values = history.history[key]
    if len(values) == 0:
        # min()/max() raise ValueError on an empty sequence.
        print(f"  {key}: (no values recorded)")
        continue
    print(f"  {key}: min={min(values):.4f}, max={max(values):.4f}, final={values[-1]:.4f}")

In [None]:
# Score the trained simulator on the held-out test split, then release it.
# NOTE(review): after sim.close() this cell presumably cannot be re-run
# without first re-running the training cell to obtain a fresh `sim`.
print("\nEvaluating on test set...")
test_loss = evaluate_model(sim, inp, p_out, X_test, y_test)
print("Test Loss: {:.4f}".format(test_loss))

sim.close()