# 0. Optuna

In [None]:
# Load the saved Optuna study for the simulated-data experiment and report its best trial.

import optuna
save_name = "sqlite:///dat/sim_optuna_results_2025-05-15.db"
study     = optuna.load_study(study_name="sim_optuna", storage=save_name)
best      = study.best_params

# Print the best hyperparameters
print("Best hyperparameters:", best)
print("\nBest accuracy:", study.best_value)

# One line per tuned hyperparameter, with values aligned at column 22.
# The f-string below does not nest quote characters inside its braces, so the cell
# also parses on Python < 3.12 (nested same-type quotes require PEP 701 / Python 3.12+).
hyperparameter_names = ["fstep", "sigma", "sparsity", "spectral_radius", "base_geometric_ratio"]
print()
for name in hyperparameter_names:
    print(f"{name + ':':<22}{best[name]}")

---
# 1. Simulated neural rhythms

## (Figure 1A): Example traces of simulated neural events

In [None]:
import pickle
from ReservoirNetwork_utils import plot_example_simulated_traces

# Pull the training split out of the stored simulated-data results.
load_name      = "dat/SIM_accuracy_results_reservoir_2025-05-15"
with open(load_name + '.pkl', 'rb') as f:
    results = pickle.load(f)
X_train, y_train = results['X_train'], results['y_train']

# Plot example traces for each label.
f = plot_example_simulated_traces(X_train, y_train)
#f.savefig("./PDFs/Figure-1A.pdf", bbox_inches='tight')

## (Figure 1B): Example spectra of reservoir nodes

In [None]:
from ReservoirNetwork import ReservoirNetwork
from ReservoirNetwork_utils import plot_analytic_spectrum

# Build a reservoir (Fs=1000, presumably the sampling rate in Hz) and fetch its history weights.
res_net         = ReservoirNetwork(Fs=1000)
history_weights, frange = res_net.generate_history_weights()
w_t_minus_1, w_t_minus_2 = (history_weights[key] for key in ("w_t_minus_1", "w_t_minus_2"))

# Plot the analytic spectrum of the nodes from the two history-weight arrays
# (node_step=25 presumably plots every 25th node).
f = plot_analytic_spectrum(w_t_minus_1, w_t_minus_2, node_step=25)
#f.savefig("./PDFs/Figure-1B.pdf", bbox_inches='tight')

## (Figure 1C): Example traces of noise-driven reservoir

In [None]:
import numpy as np
from   ReservoirNetwork import ReservoirNetwork
from ReservoirNetwork_utils import plot_state_dynamics

# Drive a fresh reservoir with an all-zero input so the recorded states show its response
# without any external signal (the heading calls this "noise-driven" — presumably noise
# internal to ReservoirNetwork; confirm there).
res_net                    = ReservoirNetwork(Fs=1000)
history_weights, frange    = res_net.generate_history_weights()
n_samples                  = 1000
input_time_series          = np.zeros(n_samples)
states, amplitudes, phases = res_net.collect_states(input_time_series)

# Plot node states over time, offset vertically by plot_spacing.
f = plot_state_dynamics(states, node_step=25, plot_spacing=0.2, frange=frange)
#f.savefig("./PDFs/Figure-1C.pdf", bbox_inches='tight')

## (Figure 1D): Average confusion matrix 

In [None]:
import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt
import pickle

load_name      = "dat/SIM_accuracy_results_reservoir_2025-05-15"
with open(load_name + '.pkl', 'rb') as f:
    results = pickle.load(f)
accuracy       = results['accuracy']
# Stack of per-iteration confusion matrices; axis 0 indexes iterations (it is averaged below).
confuse_matrix = np.asarray(results['confuse_matrix'], dtype=float)

# Print accuracy metrics
print("Accuracy")
print(f"Mean: {np.mean(accuracy):.3f}")
print(f"STD : {np.std(accuracy):.3f}")

# Normalize each iteration's matrix by its own row sums (the number of true examples of
# each class), then average over iterations. Dividing by a fixed X_test.shape[0]/4 would
# assume exactly four perfectly balanced classes in every test fold. Row normalization
# does not, and it matches the normalization used for the MNIST and SCD matrices.
row_sums = confuse_matrix.sum(axis=2, keepdims=True)
normalized_confuse_matrix = np.divide(
    confuse_matrix,
    row_sums,
    out=np.zeros_like(confuse_matrix),
    where=row_sums > 0,  # a class absent from a fold yields a zero row instead of NaN
)
avg_confuse_matrix = normalized_confuse_matrix.mean(axis=0)

# Define class labels (adjust if needed)
classes = ["Spike Ripple", "Spike", "Ripple", "Background"]

f = plt.figure(figsize=(5, 5))
plt.imshow(avg_confuse_matrix, interpolation='nearest', cmap=plt.cm.Blues)
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45, ha='right')
plt.yticks(tick_marks, classes)

# Add text annotations to each cell; use white text on dark (above half-max) cells.
thresh = avg_confuse_matrix.max() / 2.
for i in range(avg_confuse_matrix.shape[0]):
    for j in range(avg_confuse_matrix.shape[1]):
        plt.text(j, i, f"{avg_confuse_matrix[i, j]:.2f}", horizontalalignment="center",
                 color="white" if avg_confuse_matrix[i, j] > thresh else "black")

plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.tight_layout()
plt.show()
#f.savefig("./PDFs/Figure-1D.pdf", bbox_inches='tight')

## (Figure 1E): Results for different step sizes

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle

# Default-reservoir results first, then the fstep sweep (fstep_2 ... fstep_20).
load_names = ["dat/SIM_accuracy_results_reservoir_2025-05-15"] + [
    f"dat/SIM_accuracy_results_fstep_{i}_date_2025-05-16" for i in range(2, 21)
]

num_files = len(load_names)
mean_accuracy = np.zeros(num_files)
std_accuracy  = np.zeros(num_files)
n_nodes       = np.zeros(num_files)   # number of feature columns used by the classifier
f_step        = np.zeros(num_files)

# Loop over each file, load the results, and compute accuracy stats.
for k, load_name in enumerate(load_names):
    with open(load_name + '.pkl', 'rb') as f:
        results = pickle.load(f)
    accuracy = results['accuracy']
    mean_accuracy[k] = np.mean(accuracy)
    std_accuracy[k]  = np.std(accuracy)
    n_nodes[k]       = np.shape(results['X_train_features'])[1]
    f_step[k]        = results['res_net'].fstep

# One evenly spaced x position per file; the real fstep values go on the tick labels.
x_positions = np.arange(1, num_files + 1)

# Create plot
fig, ax = plt.subplots(figsize=(8, 5))
ax.errorbar(
    x_positions,
    mean_accuracy,
    yerr=std_accuracy,
    fmt='o',
    capsize=5,
    markersize=8,
    color='blue',
    ecolor='black',
    linestyle='None'
)

# Bottom x-axis labels (frequency step)
bottom_labels = [f"{int(fs)} Hz" for fs in f_step]
ax.set_xticks(x_positions)
ax.set_xticklabels(bottom_labels, rotation=45)
ax.set_xlabel('Frequency step')

# Top x-axis: number of features for each configuration, aligned with the bottom axis.
ax_top = ax.twiny()
ax_top.set_xticks(x_positions)
top_labels = [f"n={int(n)}" for n in n_nodes]
ax_top.set_xticklabels(top_labels, rotation=45)
ax_top.set_xlim(ax.get_xlim())
ax_top.set_xlabel('Number of features')

# Label and formatting
ax.set_ylabel('Accuracy')
ax.grid(True)
plt.tight_layout()
plt.show()
#fig.savefig("./PDFs/Figure-1E.pdf", bbox_inches='tight')

# Tabulate the actual step size and feature count for each row, rather than the bare
# plot index, so the printout can be read without cross-referencing the figure.
np.set_printoptions(precision=3, suppress=True)
print("Columns: fstep (Hz), n features, mean accuracy, std accuracy")
print(np.column_stack([f_step, n_nodes, mean_accuracy, std_accuracy]))

## (Printout): Comparison with an alternative classifier based on the power spectrum

In [None]:
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import pickle

# Reservoir (RRN) classifier results.
load_name      = "dat/SIM_accuracy_results_reservoir_2025-05-15"
with open(load_name + '.pkl', 'rb') as f:
    results = pickle.load(f)
accuracy_RRN       = results['accuracy']
correct_counts_RRN = results['correct_counts']
X_test             = results['X_test']
K                  = X_test.shape[0]/4

# Alternative classifier based on FFT power.
load_name      = "dat/SIM_accuracy_results_fft_power_2025-05-15"
with open(load_name + '.pkl', 'rb') as f:
    results = pickle.load(f)
accuracy_pow       = results['accuracy']
correct_counts_pow = results['correct_counts']


# Print accuracy metrics
print("Accuracy RRN")
print(f"Mean: {np.mean(accuracy_RRN):.3f}")
print(f"STD : {np.std(accuracy_RRN):.3f}")
print("Count :", np.size(accuracy_RRN))

print("\nAccuracy Power")
print(f"Mean: {np.mean(accuracy_pow):.3f}")
print(f"STD : {np.std(accuracy_pow):.3f}")
print("Count :", np.size(accuracy_pow))

# Two-sample t-test
# NOTE(review): ttest_ind treats the two accuracy samples as independent with equal
# variances. If both classifiers were scored on the same cross-validation splits, a
# paired test (stats.ttest_rel) may be more appropriate — confirm how splits were made.
t_statistic, p_value = stats.ttest_ind(accuracy_RRN, accuracy_pow)
print("\nAccuracy RRN vs Power")
print(f"T-statistic: {t_statistic:.1f}")
print("P-value:", p_value)

## (Figure 1F): Plot example average responses

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
from ReservoirNetwork       import ReservoirNetwork
from ReservoirNetwork_utils import plot_scaled_reservoir_responses_by_label

# Load the fitted reservoir, its feature scaler and the test-set features saved with the
# simulated-data results. The pickled res_net is used directly, so no fresh network is
# constructed here.
load_name      = "dat/SIM_accuracy_results_reservoir_2025-05-15"
with open(load_name + '.pkl', 'rb') as f:
    results = pickle.load(f)
X_test_features = results['X_test_features']
y_test = results['y_test']
scaler = results['scaler']
res_net = results['res_net']

# Average (scaled) reservoir response for each label.
fig, ax = plot_scaled_reservoir_responses_by_label(X_test_features, y_test, res_net, scaler)
#fig.savefig("./PDFs/Figure-1F.pdf", bbox_inches='tight')

---

# 2. In vivo neural rhythms

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle

load_name      = "dat/INVIVO_accuracy_results_2025-05-17"
with open(load_name + '.pkl', 'rb') as f:
    results = pickle.load(f)
accuracy, sensitivity, specificity, PPV, NPV = (
    results[key] for key in ('accuracy', 'sensitivity', 'specificity', 'PPV', 'NPV')
)

# Print the mean and population standard deviation (np.std default, ddof=0) of each
# classification metric, in the order they are reported.
metric_report = (
    ("Sensitivity", sensitivity),
    ("Specificity", specificity),
    ("PPV", PPV),
    ("NPV", NPV),
    ("Accuracy", accuracy),
)
for label, values in metric_report:
    print(f"\n{label}")
    print(f"Mean: {np.mean(values):.3f}")
    print(f"STD : {np.std(values):.3f}")

---
# 3. MNIST

## (Figure 2A): Counts and plot of example scanline traces

In [None]:
from   keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

# Load the MNIST data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(len(X_train), -1).astype(np.float32)  # Flatten images from (N, 28, 28) to (N, 784)
X_train /= 255.0                                                  # Normalize pixel values to range [0, 1]

# Count the labels for training and testing subsets
train_counts = dict(Counter(y_train))
test_counts = dict(Counter(y_test))

# Use one format for header, data rows and total so the columns line up
# (the data rows previously had stray comma separators the header lacked).
row_format = "{:<10}{:<20}{:<20}"
print(row_format.format("Digit", "Training Count", "Testing Count"))
print("-" * 50)

# Print counts for each digit from 0 to 9
for digit in range(10):
    train_count = train_counts.get(digit, 0)
    test_count = test_counts.get(digit, 0)
    print(row_format.format(digit, train_count, test_count))

# Compute totals for training and testing
total_train = sum(train_counts.values())
total_test = sum(test_counts.values())

print(row_format.format("Total", total_train, total_test))

# Training-set indices to plot.
indices = [1, 3, 5]

# Left column: the 28x28 image; right column: the same image as a 784-sample scanline.
f, axes = plt.subplots(nrows=3, ncols=2, figsize=(10, 5))
for i, k in enumerate(indices):

    # Plot the image
    axes[i, 0].imshow(X_train[k].reshape(28,28), cmap='gray_r')
    axes[i, 0].axis('off')  # Remove x and y axes from the image.

    # Plot the line
    axes[i, 1].plot(X_train[k], 'k')
    axes[i, 1].set_xlim([0, 784])
    axes[i, 1].set_ylim([0, 1])
    axes[i, 1].spines['top'].set_visible(False)
    axes[i, 1].spines['right'].set_visible(False)

plt.show()
#f.savefig("./PDFs/Figure-2A.pdf", bbox_inches='tight')

## (Figure 2B): Average confusion matrix

In [None]:
import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt
import pickle

load_name      = "dat/MNIST_accuracy_results_2025-05-20"
with open(load_name + '.pkl', 'rb') as f:
    results = pickle.load(f)
accuracy, confuse_matrix, X_train_features = (
    results['accuracy'], results['confuse_matrix'], results['X_train_features']
)

# Report the classifier's feature count (columns of the training feature matrix).
print(f"Number of features: {np.shape(X_train_features)[1]}")

# Accuracy summary, three decimal places.
print("Accuracy")
print(f"Mean: {np.mean(accuracy):.3f}")
print(f"STD : {np.std(accuracy):.3f}")

# Turn each iteration's counts into row proportions (each true-label row sums to 1),
# then average those proportions over the iteration axis.
row_sums = confuse_matrix.sum(axis=2, keepdims=True)
normalized_confuse_matrix = confuse_matrix / row_sums
avg_confuse_matrix = normalized_confuse_matrix.mean(axis=0)

# Digit labels 0-9.
classes = list(range(10))

f = plt.figure(figsize=(8, 6))
plt.imshow(avg_confuse_matrix, interpolation='nearest', cmap=plt.cm.Blues)
plt.title("Average Confusion Matrix")
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes)
plt.yticks(tick_marks, classes)

# Write each cell's value on the plot; white text on cells above half the maximum.
thresh = avg_confuse_matrix.max() / 2.
for (i, j), value in np.ndenumerate(avg_confuse_matrix):
    plt.text(j, i, f"{value:.2f}", horizontalalignment="center",
             color="white" if value > thresh else "black")

plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.tight_layout()
plt.show()
#f.savefig("./PDFs/Figure-2B.pdf", bbox_inches='tight')

## (Table 3): Accuracy versus step size

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle

# Default MNIST results first, then the fstep sweep (fstep_2 ... fstep_10).
load_names = ["dat/MNIST_accuracy_results_2025-05-20"] + [
    f"dat/MNIST_accuracy_results_fstep_{i}_date_2025-05-20" for i in range(2, 11)
]

num_files = len(load_names)
mean_accuracy = np.zeros(num_files)
std_accuracy  = np.zeros(num_files)
n_nodes       = np.zeros(num_files)   # number of feature columns used by the classifier
f_step        = np.zeros(num_files)

# Loop over each file, load the results, and compute accuracy stats.
for k, load_name in enumerate(load_names):
    with open(load_name + '.pkl', 'rb') as f:
        results = pickle.load(f)
    accuracy = results['accuracy']
    mean_accuracy[k] = np.mean(accuracy)
    std_accuracy[k]  = np.std(accuracy)
    n_nodes[k]       = np.shape(results['X_train_features'])[1]
    f_step[k]        = results['res_net'].fstep

# One evenly spaced x position per file; the real fstep values go on the tick labels.
x_positions = np.arange(1, num_files + 1)

# Create plot
fig, ax = plt.subplots(figsize=(8, 5))
ax.errorbar(
    x_positions,
    mean_accuracy,
    yerr=std_accuracy,
    fmt='o',
    capsize=5,
    markersize=8,
    color='blue',
    ecolor='black',
    linestyle='None'
)

# Bottom x-axis labels (frequency step)
bottom_labels = [f"{int(fs)} Hz" for fs in f_step]
ax.set_xticks(x_positions)
ax.set_xticklabels(bottom_labels, rotation=45)
ax.set_xlabel('Frequency step')

# Top x-axis: number of features for each configuration, aligned with the bottom axis.
ax_top = ax.twiny()
ax_top.set_xticks(x_positions)
top_labels = [f"n={int(n)}" for n in n_nodes]
ax_top.set_xticklabels(top_labels, rotation=45)
ax_top.set_xlim(ax.get_xlim())
ax_top.set_xlabel('Number of features')

# Label and formatting
ax.set_ylabel('Accuracy')
ax.grid(True)
plt.tight_layout()
plt.show()
# Own output name: this is the MNIST sweep, and saving to Figure-1E.pdf would overwrite
# the simulated-data figure.
#fig.savefig("./PDFs/Table-3.pdf", bbox_inches='tight')

# Print out results with the actual step size and feature count for each row.
np.set_printoptions(precision=3, suppress=True)
print("Columns: fstep (Hz), n features, mean accuracy, std accuracy")
print(np.column_stack([f_step, n_nodes, mean_accuracy, std_accuracy]))

## (Figure 2C): Example average responses

In [None]:
# How does it respond to the test set?
import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt
import pickle
from collections import Counter
from ReservoirNetwork_utils import plot_scaled_reservoir_responses_by_label

load_name      = "dat/MNIST_accuracy_results_2025-05-20"
with open(load_name + '.pkl', 'rb') as f:
    results = pickle.load(f)
X_test_features  = results['X_test_features']
y_test           = results['y_test']
res_net          = results['res_net']
clf              = results['clf']
# The scaler is the 'scaler' step of the saved classifier (presumably a scikit-learn Pipeline).
scaler           = clf.named_steps['scaler']

# Per-digit counts in the test set, then a one-line summary of their spread.
counts = [np.sum(y_test == digit) for digit in range(10)]
print(f"{'Digit':<10}{'Count':<20}")
for digit, cnt in enumerate(counts):
    print(f"{digit:<10}{cnt:<20}")

mean_count, std_count = np.mean(counts), np.std(counts)

print("-"*30)
print(f"{'Mean count ':<10}{mean_count:.3f}")
print(f"{'Std dev ':<10}{std_count:.3f}")

fig, ax = plot_scaled_reservoir_responses_by_label(X_test_features, y_test, res_net, scaler)
#fig.savefig("./PDFs/Figure-2C.pdf", bbox_inches='tight')

---
# 4. Speech Commands Dataset (SCD)

## (Figure 3A): Average confusion matrix

In [None]:
import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt
import pickle
# NOTE(review): os, glob and librosa are not referenced in this cell; if no other cell
# relies on them they can be dropped, which also removes the librosa dependency.
import os
import glob
import librosa

load_name      = "dat/SDDS_accuracy_results_2025-05-29"
with open(load_name + '.pkl', 'rb') as f:
    results = pickle.load(f)
accuracy, confuse_matrix = results['accuracy'], results['confuse_matrix']

# Accuracy summary, three decimal places.
print("Accuracy")
print(f"Mean: {np.mean(accuracy):.3f}")
print(f"STD : {np.std(accuracy):.3f}")

# Divide every cell by the total of its true-label row (per iteration), then take the
# element-wise mean over the iteration axis.
row_sums = confuse_matrix.sum(axis=2, keepdims=True)
normalized_confuse_matrix = confuse_matrix / row_sums
avg_confuse_matrix = normalized_confuse_matrix.mean(axis=0)

# Digit labels 0-9, tick positions, and the colour-switch threshold for annotations.
classes = list(range(10))
tick_marks = np.arange(len(classes))
thresh = avg_confuse_matrix.max() / 2.

f = plt.figure(figsize=(8, 6))
plt.imshow(avg_confuse_matrix, interpolation='nearest', cmap=plt.cm.Blues)
plt.title("Average Confusion Matrix")
plt.colorbar()
plt.xticks(tick_marks, classes)
plt.yticks(tick_marks, classes)

# Annotate every cell, row by row, with its averaged proportion.
n_rows, n_cols = avg_confuse_matrix.shape
for i, j in np.ndindex(n_rows, n_cols):
    cell_value = avg_confuse_matrix[i, j]
    plt.text(j, i, f"{cell_value:.2f}", horizontalalignment="center",
             color="white" if cell_value > thresh else "black")

plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.tight_layout()
plt.show()
#f.savefig("./PDFs/Figure-3A.pdf", bbox_inches='tight')

## (Figure 3B): Example average response

In [None]:
import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt
import pickle
from ReservoirNetwork_utils import plot_scaled_reservoir_responses_by_label

load_name      = "dat/SDDS_accuracy_results_2025-05-29"
with open(load_name + '.pkl', 'rb') as f:
    results = pickle.load(f)
X_test_features, y_test, res_net, scaler = (
    results['X_test_features'], results['y_test'], results['res_net'], results['scaler']
)

# Per-digit test-set counts, followed by their mean and population standard deviation.
print(f"{'Digit':<10}{'Count':<20}")
counts = []
for digit in range(10):
    cnt = np.sum(y_test == digit)
    print(f"{digit:<10}{cnt:<20}")
    counts.append(cnt)

print(f"Mean {np.mean(counts):.1f}, STD {np.std(counts):.1f}")

fig, ax = plot_scaled_reservoir_responses_by_label(X_test_features, y_test, res_net, scaler)
#fig.savefig("./PDFs/Figure-3B.pdf", bbox_inches='tight')