# Notebook description

Decoders $d$ are found by optimization: they minimize the squared error between the desired function $f$ and the decoder-weighted tuning curves $Ad$ (where $A$ holds the measured spike rates, one column per neuron), plus penalties on the L2 and L1 norms of the decode weights themselves.

$$\arg\min_d \|f-Ad\|_2^2 + \lambda_{L2}\|d\|_2 + \lambda_{L1}\|d\|_1$$

 - Generate realistic tuning curves from hardware
 - Specify an exemplary target function to decode (e.g. $f(x)=f_{max}(x+1)$)
 - Sweep the space of $\lambda_{L2}$ and $\lambda_{L1}$, finding a decoder set for each $(\lambda_{L1}, \lambda_{L2})$ pair
 - Examine the decode validation error, input frequency (i.e., energy), number of nonzero decode weights, and SNR at zero input for each $(\lambda_{L1}, \lambda_{L2})$ pair

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from nengo_brainstorm.solvers import CVXSolver

import pystorm
from pystorm.hal import HAL
from pystorm.hal.net_builder import NetBuilder
from pystorm.hal.run_control import RunControl
from pystorm.hal.data_utils import lpf, bin_to_spk_times, bins_to_rates

from multiprocessing import Pool

In [None]:
# set parameters for network
X = 16  # neuron array width (used as array_width when plotting)
Y = 8  # neuron array height (used as array_height when plotting)
NNEURON = X * Y
DIM = 1  # dimensionality passed to tap creation
FMAX = 1000  # input sweep spans [-FMAX, FMAX]
# Time resolutions handed to hal.set_time_resolution during data collection
DOWNSTREAM_NS = 10000
UPSTREAM_NS = 100000
hal = HAL()

In [None]:
net_builder = NetBuilder(hal)


def build_taps(net_builder):
    """Create the neuron tap matrix and save a copy to ``tap_matrix.txt``.

    The bad-synapse map returned by ``net_builder.determine_bad_syns()`` is
    cropped to its top-left ``(Y // 2, X // 2)`` region. It is then passed to
    ``create_default_yx_taps``, and the resulting synapse taps are converted to
    neuron taps.

    Returns the neuron tap matrix.
    """
    syn_rows, syn_cols = Y // 2, X // 2
    all_bad_syns, _ = net_builder.determine_bad_syns()
    cropped_bad_syns = all_bad_syns[:syn_rows, :syn_cols]
    syn_taps = net_builder.create_default_yx_taps(
        syn_rows, syn_cols, DIM, cropped_bad_syns)
    nrn_taps = net_builder.syn_taps_to_nrn_taps(syn_taps)
    np.savetxt("tap_matrix.txt", nrn_taps)
    return nrn_taps


tap_matrix = build_taps(net_builder)

In [None]:
def build_net(net_builder, tap_matrix, d_matrix=None):
    """Create a single-pool network from calibration files on disk.

    Biases and gain divisors are loaded (as ints) from ``biases.txt`` and
    ``gain_divisors.txt`` in the working directory.

    Parameters
    ----------
    net_builder : NetBuilder
        Builder used to create the network.
    tap_matrix : array
        Neuron tap matrix (e.g. the output of ``build_taps``).
    d_matrix : array, optional
        Decoder matrix for the network. Defaults to ``np.eye(Y*X)`` when not
        given.

    Returns
    -------
    The network returned by ``net_builder.create_single_pool_net``.
    """
    gain_divs = np.loadtxt("gain_divisors.txt", dtype=int)
    biases = np.loadtxt("biases.txt", dtype=int)
    # Only fall back to the identity when no decoders were supplied; a
    # caller-provided decoder matrix must not be silently discarded.
    if d_matrix is None:
        d_matrix = np.eye(Y * X)
    return net_builder.create_single_pool_net(
        Y, X, tap_matrix, biases=biases, gain_divs=gain_divs, decoders=d_matrix)


tuning_net = build_net(net_builder, tap_matrix, np.eye(Y*X))
run_control = RunControl(hal, tuning_net)
hal.map(tuning_net)

In [None]:
def collect_tuning_data(net, hal, fmax, run_control):
    """Drive the network with an input-rate sweep and measure the output rates.

    The input steps through 21 evenly spaced rates in ``[-fmax, fmax]``, one
    per 1-second bin. The final rate is held for one extra bin.

    Returns
    -------
    input_rates : ndarray, shape (21, 1)
        The swept input rates (without the extra hold bin).
    spike_rates :
        Output rates from ``bins_to_rates`` for the network's outputs, with the
        last output dimension dropped.
    """
    n_points = 21
    bin_size_s = 1.0
    bin_ns = int(bin_size_s * 1E9)

    hal.set_time_resolution(DOWNSTREAM_NS, UPSTREAM_NS)

    sweep = fmax * np.linspace(-1, 1, n_points)
    # Repeat the last value so the sweep has one more time point than samples
    stim = np.append(sweep, sweep[-1]).reshape(-1, 1)
    sweep_times_ns = np.arange(n_points + 1) * bin_ns

    (raw_outputs, output_times), _ = run_control.run_input_sweep(
        {net.input: (sweep_times_ns, stim)},
        get_raw_spikes=False, get_outputs=True)
    # The last output dimension is reserved for the decode; keep the rest
    neuron_outputs = raw_outputs[net.output][:, :-1]
    spike_rates = bins_to_rates(
        neuron_outputs, output_times, sweep_times_ns, init_discard_frac=0.5)
    return stim[:-1], spike_rates


input_rates, spike_rates = collect_tuning_data(tuning_net, hal, FMAX, run_control)
# split into training (even samples) and validation (odd samples) data sets
train_input_rates, valid_input_rates = input_rates[::2], input_rates[1::2]
train_spike_rates, valid_spike_rates = spike_rates[::2], spike_rates[1::2]


In [None]:
def plot_tuning(inputs, spike_rates, array_width, array_height):
    """Plot tuning curves in a new figure, one trace per column of spike_rates.

    Columns are taken row by row: for each of the ``array_height`` rows of
    ``array_width`` columns, the left half of the row is drawn in red and the
    right half in blue.
    """
    half = array_width // 2
    plt.figure()
    for row in range(array_height):
        row_start = row * array_width
        mid = row_start + half
        plt.plot(inputs, spike_rates[:, row_start:mid], 'r')
        plt.plot(inputs, spike_rates[:, mid:mid + half], 'b')


plot_tuning(train_input_rates, train_spike_rates, X, Y)
plt.title("training")
plot_tuning(valid_input_rates, valid_spike_rates, X, Y)
plt.title("validation")

In [None]:
def fit_decoders(rates, target_function, l1, l2):
    """Solve for regularized decode weights and clip them to [-1, 1].

    Parameters
    ----------
    rates : array
        Spike rates, one row per input sample and one column per neuron.
    target_function : array
        Desired decode output, one row per input sample.
    l1, l2 : float
        L1 and L2 regularization strengths passed to ``CVXSolver``.

    Returns
    -------
    decoders : array
        Decode weights after clipping to [-1, 1].
    info : mapping
        Solver info. ``info['rmses']`` is recomputed for the clipped weights,
        so the reported training error describes the decoders actually returned.
    """
    solver = CVXSolver(reg=l2, reg_l1=l1)
    decoders, info = solver(rates, target_function)
    decoders = decoders.clip(-1, 1)
    # The solver's RMSE describes the unclipped solution. Recompute it so the
    # training error is consistent with the validation error, which is
    # evaluated with the clipped weights.
    residual = target_function - np.dot(rates, decoders)
    info['rmses'] = np.sqrt(np.mean(residual**2, axis=0))
    return decoders, info


# Target f(x) = FMAX*(x+1): the inputs sweep FMAX*x for x in [-1, 1]
train_target_function = train_input_rates + FMAX
valid_target_function = valid_input_rates + FMAX

In [None]:
# Set up L1 L2 space to sweep
N_L1 = 2
N_L2 = 2
L1_vals = np.linspace(0, 0.1, N_L1)
L2_vals = np.linspace(0, 0.1, N_L2)

# meshgrid returns arrays of shape (N_L2, N_L1). Flattening them row-major makes
# L1 vary fastest, so per-point results map back onto this (N_L2, N_L1) layout.
L1_grid, L2_grid = np.meshgrid(L1_vals, L2_vals)
L1L2_pts = np.column_stack((L1_grid.ravel(), L2_grid.ravel()))

# One SNR value per grid point, in the same (N_L2, N_L1) layout as the grids
# (not (N_L1, N_L2), which mismatches whenever N_L1 != N_L2).
snr = np.zeros(L1_grid.shape)

In [None]:
# Compute decoders: one fit per (L1, L2) regularization pair, in L1L2_pts order
fits = [fit_decoders(train_spike_rates, train_target_function, l1, l2)
        for l1, l2 in L1L2_pts]
decoders = [dec for dec, _ in fits]
rmse_train = [info['rmses'] for _, info in fits]

In [None]:
# Compute stats
# Decode weights with magnitude at or below this threshold count as zero
nz_dw_threshold = 1/(8192*2)
print(nz_dw_threshold)


def _decode_stats(dweights, rates, target, threshold):
    """Return (nonzero weight count, summed mean rate of those neurons, RMSE)."""
    keep = np.abs(dweights.flatten()) > threshold
    n_nonzero = np.sum(keep)
    rate_total = np.sum(np.mean(rates[:, keep], axis=0))
    error = target - np.dot(rates, dweights)
    rmse = np.sqrt(np.mean(error**2))
    return n_nonzero, rate_total, rmse


stats = [_decode_stats(dw, valid_spike_rates, valid_target_function, nz_dw_threshold)
         for dw in decoders]
nz_dw = [s[0] for s in stats]
f_in = [s[1] for s in stats]
rmse_valid = [s[2] for s in stats]
print(nz_dw)
print(f_in)
print(rmse_valid)

In [None]:
# Plot results
# Per-point results are in L1L2_pts order (L1 varies fastest), so reshape them
# into the (N_L2, N_L1) layout of L1_grid / L2_grid.
grid_shape = (N_L2, N_L1)
rmse_train = np.reshape(rmse_train, grid_shape)
# Keep validation error in its own variable; it must not overwrite rmse_train
rmse_valid = np.reshape(rmse_valid, grid_shape)
f_in = np.reshape(f_in, grid_shape)
nz_dw = np.reshape(nz_dw, grid_shape)

fig_rmse, axs_rmse = plt.subplots(ncols=2, figsize=(12, 4))
ax_rmse_train, ax_rmse_valid = axs_rmse
ax_rmse_train.contour(L1_grid, L2_grid, rmse_train)
ax_rmse_valid.contour(L1_grid, L2_grid, rmse_valid)

fig_energy, axs_energy = plt.subplots(ncols=2, figsize=(12, 4))
ax_fin, ax_nzdw = axs_energy
ax_fin.contour(L1_grid, L2_grid, f_in)
ax_nzdw.contour(L1_grid, L2_grid, nz_dw)

for ax, title in ((ax_rmse_train, "training RMSE"),
                  (ax_rmse_valid, "validation RMSE"),
                  (ax_fin, "input frequency of nonzero-weight neurons"),
                  (ax_nzdw, "nonzero decode weights")):
    ax.set_title(title)
    ax.set_xlabel(r"$\lambda_{L1}$")
    ax.set_ylabel(r"$\lambda_{L2}$")

In [None]:
# Collect SNR data


In [None]:
# def plot_training_fit(train_input_rates, target_function, spike_rates, decoders):
#     train_decode = np.dot(spike_rates, decoders)
#     plt.figure()
#     plt.plot(train_input_rates, target_function, label="target function")
#     plt.plot(train_input_rates, train_decode, label="decoded function")
#     plt.legend(loc="best")

#     z_idx = np.searchsorted(train_input_rates[:, 0], 0) # input 0
#     rates_0 = spike_rates[z_idx] # spike rates at input 0
#     plt.figure()
#     plt.hist(decoders[rates_0>0], density=True, bins=40)

# plot_training_fit(tinfo.train_input_rates, tinfo.target_function, tinfo.spike_rates, tinfo.decoders.flatten())    