In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from nengo_brainstorm.solvers import CVXSolver
from nengo_extras.plot_spikes import plot_spikes

import pystorm
from pystorm.hal import HAL
from pystorm.hal.net_builder import NetBuilder
from pystorm.hal.run_control import RunControl
from pystorm.hal.data_utils import lpf, bin_to_spk_times

In [None]:
def get_snr_gamma(lamtau_out, k):
    """SNR of the synaptically filtered gamma process.

    In terms of the output lambda * tau.

    Parameters
    ----------
    lamtau_out : float or array_like
        Output rate times filter time constant (lambda_out * tau).
        Non-positive entries yield an SNR of 0.
    k : float
        Gamma-process order; the input lambda * tau is ``k`` times the
        output lambda * tau.

    Returns
    -------
    numpy.ndarray
        Float array of SNRs, at least 1-D (a scalar input gives a
        length-1 array).
    """
    # Coerce to a float array so integer inputs are not truncated when written
    # into ``snr`` and so scalars of any numeric type are handled uniformly.
    lamtau_out = np.atleast_1d(np.asarray(lamtau_out, dtype=float))
    snr = np.zeros_like(lamtau_out)
    idx = lamtau_out > 0
    x = lamtau_out[idx] * k  # input lambda * tau for the positive entries
    a = np.sqrt(2 * x)
    # ((1+x)**k + x**k) / ((1+x)**k - x**k) == coth(u/2) with u = k*log(1 + 1/x).
    # The direct powers overflow to inf/inf = nan once (1+x)**k exceeds float range.
    b = 1.0 / np.tanh(0.5 * k * np.log1p(1.0 / x))
    c = 2 * x / k
    snr[idx] = a / np.sqrt(k * (b - c))
    return snr

def get_snr_periodic(lamtau):
    """SNR of a synaptically filtered periodic spike train.

    Parameters
    ----------
    lamtau : float or array_like
        Rate times filter time constant (lambda * tau). Non-positive entries
        yield an SNR of 0.

    Returns
    -------
    numpy.ndarray
        Float array of SNRs, at least 1-D (a scalar input gives a length-1
        array).
    """
    # Coerce to float so integer inputs are not truncated in ``snr`` and plain
    # scalars can be boolean-indexed like arrays.
    lamtau = np.atleast_1d(np.asarray(lamtau, dtype=float))
    snr = np.zeros_like(lamtau)
    idx = lamtau > 0
    inv = 1.0 / lamtau[idx]
    # -expm1(-inv) == 1 - exp(-inv), without cancellation when lamtau is large.
    inv_snr_sq = 0.5 * inv * (1 + np.exp(-inv)) / (-np.expm1(-inv)) - 1
    snr[idx] = 1.0 / np.sqrt(inv_snr_sq)
    return snr

In [None]:
# Network parameters
X = 16                 # neuron array width
Y = 8                  # neuron array height
NNEURON = X*Y
DIM = 1
FMAX = 1000            # peak |input rate| for decoder training
DOWNSTREAM_NS = 10000  # downstream time resolution (ns)

hal = HAL()

In [None]:
net_builder = NetBuilder(hal)

def build_taps(net_builder):
    """Build the default tap-point matrix for the pool and save it.

    Taps are chosen on a synapse grid of half the neuron array's height and
    width (presumably one synapse per 2x2 neuron tile), avoiding synapses that
    ``determine_bad_syns`` flags. They are then converted to neuron
    resolution. The neuron-level matrix is written to ``tap_matrix.txt`` and
    returned.
    """
    syn_rows, syn_cols = Y // 2, X // 2
    bad_syn_full, _ = net_builder.determine_bad_syns()
    bad_syn = bad_syn_full[:syn_rows, :syn_cols]
    syn_taps = net_builder.create_default_yx_taps(syn_rows, syn_cols, DIM, bad_syn)
    nrn_taps = net_builder.syn_taps_to_nrn_taps(syn_taps)
    np.savetxt("tap_matrix.txt", nrn_taps)
    return nrn_taps

tap_matrix = build_taps(net_builder)

In [None]:
def build_net(net_builder, tap_matrix):
    """Create the single-pool network from the saved calibration files.

    Gain divisors and biases are read from ``gain_divisors.txt`` and
    ``biases.txt``. The decoder matrix has Y*X+1 output rows. The first Y*X
    rows are an identity, so each neuron's spikes appear on their own output.
    The final row is all zeros; presumably it becomes
    ``net.decoder_conn.weights[-1]``, which ``assign_decoders`` overwrites.
    """
    gain_divs = np.loadtxt("gain_divisors.txt", dtype=int)
    biases = np.loadtxt("biases.txt", dtype=int)

    n_neurons = Y * X
    decoders = np.vstack([np.eye(n_neurons), np.zeros((1, n_neurons))])
    return net_builder.create_single_pool_net(
        Y, X, tap_matrix, biases=biases, gain_divs=gain_divs, decoders=decoders)

def assign_decoders(net, decoders, hal):
    """Write ``decoders`` into the last row of the decoder connection's weights.

    The weight array is modified in place and handed back through
    ``reassign_weights``. The HAL is then asked to remap weights (presumably
    pushing the change to hardware). ``decoders`` may be a scalar, which is
    broadcast across the row.
    """
    conn = net.decoder_conn
    conn.weights[-1, :] = decoders
    conn.reassign_weights(conn.weights)
    hal.remap_weights()

# Build the network, wrap it in a RunControl for running input sweeps, and map it through the HAL.
net = build_net(net_builder, tap_matrix)
run_control = RunControl(hal, net)
hal.map(net)

# Train Decoders

- Check that tuning curves look reasonable
- Check that accumulator spikes match raw spikes via identity decode matrix

In [None]:
class TrainInfo:
    """Container for decoder-training data.

    Attributes
    ----------
    train_input_rates : input rates presented during training, one row per point
    spike_rates : measured neuron spike rates, one row per training point

    Later cells attach further attributes (``target_function``, ``decoders``).
    """
    def __init__(self, train_input_rates, spike_rates):
        self.train_input_rates, self.spike_rates = train_input_rates, spike_rates

def collect_train_data(net, hal, fmax, run_control):
    """Sweep the input rate over [-fmax, fmax] and record per-neuron spike rates.

    Eleven evenly spaced input rates are presented, one per 0.5 s bin. The
    first and last rates are each repeated once as padding against slop in
    trials. The first output bin is discarded.

    Returns
    -------
    TrainInfo
        ``train_input_rates``: the 11 unpadded rates, shape (11, 1).
        ``spike_rates``: identity-decoded neuron counts per bin divided by the
        bin length.
    """
    bin_size = 0.5  # seconds
    bin_size_ns = int(bin_size*1E9)
    hal.set_time_resolution(DOWNSTREAM_NS, bin_size_ns)

    n_points = 11
    sweep = fmax * np.linspace(-1, 1, n_points)[:, np.newaxis]
    # Repeat the end points once on each side (padding for slop in trials).
    padded_rates = np.pad(sweep, ((1, 1), (0, 0)), mode="edge")
    bin_starts_ns = np.arange(padded_rates.shape[0]) * bin_size_ns

    (outputs, _output_times), _ = run_control.run_input_sweep(
        {net.input: (bin_starts_ns, padded_rates)}, get_raw_spikes=False, get_outputs=True)
    # Skip the first bin and drop the last output column (the trained decode),
    # keeping only the identity-decoded per-neuron counts.
    neuron_counts = outputs[net.output][1:, :-1]

    return TrainInfo(padded_rates[1:-1], neuron_counts / bin_size)
tinfo = collect_train_data(net, hal, FMAX, run_control)

In [None]:
def plot_tuning(inputs, spike_rates, array_width, array_height):
    """Plot tuning curves: left half of each array row in red, right half in blue.

    Columns of ``spike_rates`` are indexed row-major: row ``r`` of the array
    spans columns ``r*array_width`` to ``(r+1)*array_width``. For an odd
    ``array_width``, the extra column is drawn with the right (blue) half.
    """
    half_width = array_width//2
    for row in range(array_height):
        start_l = row*array_width
        start_r = start_l + half_width
        # Slice to the end of the row; ``start_r + half_width`` would drop the
        # last column of every row when array_width is odd.
        end = start_l + array_width
        plt.plot(inputs, spike_rates[:, start_l:start_r], 'r')
        plt.plot(inputs, spike_rates[:, start_r:end], 'b')
    plt.xlabel("input rate")
    plt.ylabel("spike rate")
plot_tuning(tinfo.train_input_rates, tinfo.spike_rates, X, Y)

In [None]:
def fit_decoders(rates, target_function):
    """Fit decoders mapping ``rates`` to ``target_function`` and clip them to [-1, 1].

    Uses ``CVXSolver(reg=0.1, reg_l1=0.1)`` and prints the solver's reported
    ``rmses``. NOTE(review): that RMSE comes from the solver, presumably
    before clipping — confirm.
    """
    solver = CVXSolver(reg=0.1, reg_l1=0.1)
    raw_decoders, solve_info = solver(rates, target_function)
    clipped = raw_decoders.clip(-1, 1)
    print(solve_info['rmses'])
    return clipped

# Target: the input rate shifted by FMAX, so it is non-negative over [-FMAX, FMAX].
tinfo.target_function = tinfo.train_input_rates + FMAX
decoders = fit_decoders(tinfo.spike_rates, tinfo.target_function)
tinfo.decoders = decoders

In [None]:
def plot_training_fit(train_input_rates, target_function, spike_rates, decoders):
    """Plot the decoded training output against its target.

    Also plots a histogram of the decoders of neurons that fire at zero input.
    """
    decoded = spike_rates @ decoders
    plt.figure()
    for curve, name in ((target_function, "target function"), (decoded, "decoded function")):
        plt.plot(train_input_rates, curve, label=name)
    plt.legend(loc="best")

    zero_row = np.searchsorted(train_input_rates[:, 0], 0)  # training row with input 0
    active_at_zero = spike_rates[zero_row] > 0
    plt.figure()
    plt.hist(decoders[active_at_zero], density=True, bins=40)

flat_decoders = tinfo.decoders.flatten()
plot_training_fit(tinfo.train_input_rates, tinfo.target_function, tinfo.spike_rates, flat_decoders)
print(np.count_nonzero(np.abs(tinfo.decoders) < 1E-2))  # number of near-zero decoders
assign_decoders(net, flat_decoders, hal)

# Test Decoders

In [None]:
#  deliver an input of 0
def run_test(hal, run_control, bin_size=0.0001, network=None, downstream_ns=None):
    """Run a single zero-input test trial and return the binned outputs.

    Parameters
    ----------
    hal : object providing ``set_time_resolution``.
    run_control : object providing ``run_input_sweep``.
    bin_size : float
        Upstream (output) bin size in seconds.
    network : optional
        Network whose ``input`` is driven and whose ``output`` is read.
        Defaults to the module-level ``net``. Pass it explicitly so the trial
        runs on the same network whose decoders were just assigned.
    downstream_ns : int, optional
        Downstream time resolution in ns. Defaults to ``DOWNSTREAM_NS``.

    Returns
    -------
    decode : last output column per bin (the trained decode row)
    spikes : remaining output columns per bin (identity-decoded neurons)
    bin_times : bin times in seconds, shifted so the first bin is at 0
    """
    if network is None:
        network = net
    if downstream_ns is None:
        downstream_ns = DOWNSTREAM_NS
    bin_size_ns = int(bin_size*1E9)
    hal.set_time_resolution(downstream_ns, bin_size_ns)

    test_time = 1  # seconds
    test_time_ns = int(test_time*1E9)

    # Rate 0 at t=0 and at t=test_time (presumably holds zero input for the trial).
    input_rates = np.zeros((2, 1))
    input_times = np.arange(2)*test_time_ns

    input_vals = {network.input: (input_times, input_rates)}
    output_data, _ = run_control.run_input_sweep(
        input_vals, get_raw_spikes=False, get_outputs=True)

    outputs, bin_times_ns = output_data
    outputs = outputs[network.output]
    decode = outputs[:, -1]
    spikes = outputs[:, :-1]

    bin_times = bin_times_ns * 1E-9
    bin_times -= bin_times[0]
    return decode, spikes, bin_times
decode, spikes, bin_times = run_test(hal, run_control)
# Spikes can accumulate in the first bin between traffic activation and the
# experiment; discard that bin if it averages more than two spikes per neuron.
if spikes[0].sum() > 2 * spikes.shape[1]:
    spikes[0] = 0

In [None]:
# process output data
def check_decode(decode):
    """Print summary statistics of the accumulator (decode) output bins.

    Reports the number of non-zero bins, the total output count, and how many
    bins hold exactly 1, exactly 2, or more than 2 outputs (with the distinct
    values of the latter).
    """
    nonzero_bins = np.sum(decode > 0)
    total_outputs = np.sum(decode)
    one_bins = np.sum(decode == 1)
    two_bins = np.sum(decode == 2)
    many_mask = decode > 2
    many_bins = np.sum(many_mask)
    many_values = np.unique(decode[many_mask])
    # "{}" for the sum: same text as "{:d}" for integer counts, but does not
    # raise ValueError when ``decode`` is a float array.
    print("Collected {:d} non-zero output bins. Sum(outputs) = {}".format(nonzero_bins, total_outputs))
    print("Bin stats:1-spike bins: {:d}, 2-spike bins: {:d}, >2-spike bins: {:d} (bin values {})".format(
        one_bins, two_bins, many_bins, many_values))
check_decode(decode)  # summarize the accumulator output of the zero-input trial

In [None]:
# Raster of the identity-decoded neuron spikes, with the accumulator output as the last column.
to_raster = np.column_stack([spikes, decode]).astype(np.float64)
# A raster marks presence only, so cap multi-count bins at 1.
to_raster = np.minimum(to_raster, 1)
plt.subplots(figsize=(16, 12))
plot_spikes(bin_times, to_raster)

In [None]:
dt = bin_times[1] - bin_times[0]
tau = 0.01

valid_decode = decode.copy()
valid_decode[decode>10] = 0
filtered_decode = lpf(valid_decode, tau, dt)

filtered_spikes = lpf(spikes, tau, dt)
decoded_spikes = spikes*decoders[:, 0]
filtered_decoded_spikes = lpf(decoded_spikes, tau, dt)
decode = np.sum(filtered_decoded_spikes, axis=1)

In [None]:
mean = np.mean(filtered_decode[bin_times>5*tau])
var = np.var(filtered_decode[bin_times>5*tau])
print(mean/np.sqrt(var))

plt.subplots(figsize=(10,4))
plt.plot(bin_times, filtered_decode, label="filtered accumulator output")
plt.plot(bin_times, decode, label="filtered, decode-weighted, raw spikes via identity matrix")
plt.legend(loc="best")
plt.xlabel("time")

plt.subplots(figsize=(10,4))
for idx in range(X*Y):
    plt.plot(bin_times, filtered_spikes[:, idx])

# Fixed, All Positive Decode Weights

- compare against decoding with all weights positive and equal
- check whether the superposed spike trains are Poisson
- sweep the decode-weight magnitude

In [None]:
class DData:
    """Per-trial record for the decode-weight and fmax sweeps.

    Holds the accumulator output (``decode``), the identity-decoded neuron
    spikes (``spikes``), the bin times, and the decoder weights used
    (``dweights``). Analysis cells attach further attributes later.
    """
    def __init__(self, decode, spikes, bin_times, dweights):
        self.decode, self.spikes = decode, spikes
        self.bin_times, self.dweights = bin_times, dweights
    
def dw_sweep_collect_data(dweights, net, run_control, hal, labels=None):
    """Collect data from experiments that sweep across dweights.

    For each decoder setting: assign it, run a zero-input trial with 0.1 ms
    bins, print decode statistics, then zero out "sticky bitted" bins (counts
    above 100000) in both the decode and the spike outputs.

    Returns a dict mapping each label (the dweight itself when ``labels`` is
    None) to a ``DData`` record.
    """
    if labels is None:
        labels = dweights
    else:
        assert len(dweights) == len(labels)

    sticky_threshold = 100000
    dw_data = {}
    for dw, label in zip(dweights, labels):
        assign_decoders(net, dw, hal)
        decode, spikes, bin_times = run_test(hal, run_control, bin_size=0.0001)
        check_decode(decode)

        sticky_decode = decode > sticky_threshold
        if sticky_decode.any():
            print("zeroing out {} sticky bitted decode bins".format(np.sum(sticky_decode)))
            decode[sticky_decode] = 0

        sticky_spikes = spikes > sticky_threshold
        if sticky_spikes.any():
            print("zeroing out {} sticky bitted spike bins".format(np.sum(sticky_spikes)))
            spikes[sticky_spikes] = 0

        dw_data[label] = DData(decode=decode, spikes=spikes, bin_times=bin_times, dweights=dw)
    return dw_data

In [None]:
hal.set_time_resolution(downstream_ns=10000, upstream_ns=100000)
# Uniform decode weights to sweep (each trial assigns the scalar to every neuron).
dweights = [1, 1/2, 1/4, 1/8, 1/16, 1/32, 1/64]
ddata = dw_sweep_collect_data(dweights, net, run_control, hal)
# Clean up sticky-bitted spikes if present.
for record in ddata.values():
    record.spikes[record.spikes > 100000] = 0

In [None]:
TAU = 0.01
def dw_sweep_analyze_data(ddata, tau=None):
    """Compute rate, SNR, and ISI statistics for each entry of a dweight sweep.

    Parameters
    ----------
    ddata : dict
        Maps each dweight label to a ``DData`` record.
    tau : float, optional
        Filter time constant in seconds. Defaults to the module-level ``TAU``
        at call time.

    Returns
    -------
    dict
        The same ``ddata``. Each record is annotated in place with
        ``filtered_decode``, ``fin``, ``fout``, ``mean``, ``snr``,
        ``dspk_times``, ``isi`` and ``isi_cv``.
    """
    if tau is None:
        tau = TAU
    for dw, rec in ddata.items():
        dt = rec.bin_times[1] - rec.bin_times[0]
        filtered_decode = lpf(rec.decode, tau, dt)
        idx = rec.bin_times > 5*tau  # skip the filter's start-up transient
        # n counted bins span n*dt seconds; (last - first) is one bin short.
        duration = np.count_nonzero(idx) * dt
        fin = np.sum(rec.spikes[idx]) / duration
        fout = np.sum(rec.decode[idx]) / duration
        mean = np.mean(filtered_decode[idx])
        var = np.var(filtered_decode[idx])
        snr = mean/np.sqrt(var)
        print("dw {} fin {:.0f} fout {:.0f} mean {:.0f} var {:.0f} snr {:.2f}".format(
            dw, fin, fout, mean, var, snr))
        dspk_times = bin_to_spk_times(rec.decode, rec.bin_times)
        isi = np.diff(dspk_times)
        # Fewer than two output spikes leaves no intervals: report NaN explicitly.
        isi_cv = np.std(isi) / np.mean(isi) if isi.size else np.nan

        rec.filtered_decode = filtered_decode
        rec.fin, rec.fout = (fin, fout)
        rec.mean, rec.snr = (mean, snr)
        rec.dspk_times = dspk_times
        rec.isi = isi
        rec.isi_cv = isi_cv
    return ddata
ddata = dw_sweep_analyze_data(ddata)

In [None]:
def dw_sweep_plot_data(ddata, tau=None):
    """Plot SNR, ISI CV, ISI histograms, filtered outputs and rates for a dweight sweep.

    Parameters
    ----------
    ddata : dict
        Maps each dweight to an analyzed record. ``1/dw`` is used as the
        gamma order for the theoretical SNR.
    tau : float, optional
        Filter time constant for the theoretical curves. Defaults to the
        module-level ``TAU`` at call time; it should match the tau used in
        the analysis.
    """
    if tau is None:
        tau = TAU
    n_entries = len(ddata)
    fig_stats, axs_stats = plt.subplots(ncols=2, nrows=2, figsize=(16, 12))
    ax_snr, ax_cv = axs_stats[0]
    ax_snr_check = axs_stats[1, 0]
    fig_exp, axs_exp = plt.subplots(ncols=2, figsize=(14,4))
    ax_filt, ax_f = axs_exp
    # squeeze=False keeps a 2-D axes array even for a single-entry sweep, so
    # axs_hist[idx] is always an Axes.
    fig_hist, axs_hist = plt.subplots(
        nrows=n_entries, figsize=(14, 2*n_entries), sharex=True, squeeze=False)
    axs_hist = axs_hist[:, 0]

    fout = np.zeros(n_entries)
    fin = np.zeros(n_entries)
    snr = np.zeros(n_entries)
    snr_th = np.zeros(n_entries)
    isi_cv = np.zeros(n_entries)
    for idx, dw in enumerate(ddata):
        rec = ddata[dw]
        color = ax_filt.plot(rec.bin_times, rec.filtered_decode)[0].get_color()
        ax_filt.axhline(rec.mean, color="k", alpha=0.5)
        fin[idx] = rec.fin
        fout[idx] = rec.fout
        snr[idx] = rec.snr
        # get_snr_gamma returns a length-1 array for a scalar input; assigning an
        # ndim>0 array into a scalar slot is deprecated since NumPy 1.25.
        snr_th[idx] = get_snr_gamma(rec.fout*tau, 1/dw).item()
        isi_cv[idx] = rec.isi_cv
        axs_hist[idx].hist(rec.isi, bins=50, cumulative=False, density=True, histtype="step", color=color)

    ax_filt.set_title('do means look reasonable?')
    ax_filt.set_xlabel("time")
    ax_filt.set_ylabel("filtered decode")
    axs_hist[0].set_title("ISI Distributions")
    axs_hist[-1].set_xlabel("ISIs")
    fends = np.array([fout.min(), fout.max()])
    snr_poi = np.sqrt(2*fends*tau)
    snr_per_high_lt_appx = np.sqrt(12)*fends*tau
    snr_per = get_snr_periodic(fout*tau)

    ax_snr.loglog(fout, snr, 'o', label="observed snr")
    ax_snr.loglog(fout, snr_th, '-o', label="theoretical gamma snr")
    ax_snr.loglog(fends, snr_poi, label="theoretical poisson snr")
    per_color = ax_snr.loglog(fout, snr_per, label="theoretical periodic snr")[0].get_color()
    ax_snr.loglog(fends, snr_per_high_lt_appx, color=per_color, alpha=0.2)
    ax_snr.legend(loc="best")
    ax_snr.set_xlabel("f_out")
    ax_snr.set_ylabel("SNR")
    ax_snr.grid(which="both")

    rel_snr = np.abs(snr-snr_th)/snr_th
    ax_snr_check.semilogx(fout, rel_snr, '-o')
    ax_snr_check.grid()
    ax_snr_check.set_xlabel("f_out")
    ax_snr_check.set_ylabel("|SNR - SNR_theory| / SNR_theory")

    ax_cv.semilogx(fout, isi_cv, '-o')
    ax_cv.set_ylim([0, 1])
    ax_cv.set_xlabel("f_out")
    ax_cv.set_ylabel("CV")
    ax_cv.grid(which="both")

    ax_f.plot(fin, '-o', label="f_in")
    ax_f.plot(fout, '-o', label="f_out")
    ax_f.legend(loc="best")
    ax_f.set_title("is f_in constant across trials?")
    ax_f.set_xlabel("trial")

dw_sweep_plot_data(ddata)

# Trained Decoders

- What should f_in be?
- What should the decode weight be?

In [None]:
# Build one decoder set per target offset fmax (from 5000 down to 100).
fmaxes = np.linspace(5000, 100, 5)
decoder_sets = []
for fmax in fmaxes:
    target_function = tinfo.train_input_rates + fmax
    fitted = fit_decoders(tinfo.spike_rates, target_function).flatten()
    decoder_sets.append(fitted)
    print(np.count_nonzero(fitted == 0))  # number of exactly-zero decoders

In [None]:
# Run one zero-input trial per fitted decoder set, keyed by its target fmax.
# NOTE: this rebinds `ddata`, replacing the dweight-sweep results collected above.
ddata = dw_sweep_collect_data(decoder_sets, net, run_control, hal, labels=fmaxes)

In [None]:
TAU = 0.001
def fmax_sweep_analyze_data(ddata, tau=None):
    """Compute rate, SNR, and ISI statistics for each entry of an fmax sweep.

    Parameters
    ----------
    ddata : dict
        Maps each target fmax to a ``DData`` record.
    tau : float, optional
        Filter time constant in seconds. Defaults to the module-level ``TAU``
        at call time.

    Returns
    -------
    dict
        The same ``ddata``. Each record is annotated in place with
        ``filtered_decode``, ``fin``, ``fout``, ``mean``, ``snr``,
        ``dspk_times``, ``isi`` and ``isi_cv``.
    """
    if tau is None:
        tau = TAU
    for fmax, rec in ddata.items():
        dt = rec.bin_times[1] - rec.bin_times[0]
        filtered_decode = lpf(rec.decode, tau, dt)
        idx = rec.bin_times > 5*tau  # skip the filter's start-up transient
        # n counted bins span n*dt seconds; (last - first) is one bin short.
        duration = np.count_nonzero(idx) * dt
        fin = np.sum(rec.spikes[idx]) / duration
        fout = np.sum(rec.decode[idx]) / duration
        mean = np.mean(filtered_decode[idx])
        var = np.var(filtered_decode[idx])
        snr = mean/np.sqrt(var)
        print("fmax {} fin {:.0f} fout {:.0f} mean {:.0f} var {:.0f} snr {:.2f}".format(
            fmax, fin, fout, mean, var, snr))
        dspk_times = bin_to_spk_times(rec.decode, rec.bin_times)
        isi = np.diff(dspk_times)
        # Fewer than two output spikes leaves no intervals: report NaN explicitly.
        isi_cv = np.std(isi) / np.mean(isi) if isi.size else np.nan

        rec.filtered_decode = filtered_decode
        rec.fin, rec.fout = (fin, fout)
        rec.mean, rec.snr = (mean, snr)
        rec.dspk_times = dspk_times
        rec.isi = isi
        rec.isi_cv = isi_cv
    return ddata

ddata = fmax_sweep_analyze_data(ddata)

In [None]:
def fmax_sweep_plot_data(ddata, tau=None):
    """Plot SNR, ISI CV, ISI and decode-weight histograms, filtered outputs and rates for an fmax sweep.

    Parameters
    ----------
    ddata : dict
        Maps each target fmax to an analyzed record.
    tau : float, optional
        Filter time constant for the theoretical curves. Defaults to the
        module-level ``TAU`` at call time; it should match the tau used in
        the analysis.
    """
    if tau is None:
        tau = TAU
    n_entries = len(ddata)
    fig_stats, axs_stats = plt.subplots(ncols=2, nrows=2, figsize=(16, 12))
    ax_snr, ax_cv = axs_stats[0]
    ax_snr_check = axs_stats[1, 0]
    fig_exp, axs_exp = plt.subplots(ncols=2, figsize=(14,4))
    ax_filt, ax_f = axs_exp
    # squeeze=False keeps 2-D axes arrays even for a single-entry sweep, so
    # per-entry indexing always yields an Axes.
    fig_hist, axs_hist = plt.subplots(
        nrows=n_entries, figsize=(14, 2*n_entries), sharex=True, squeeze=False)
    fig_dw, axs_dw = plt.subplots(
        nrows=n_entries, figsize=(14, 2*n_entries), sharex=True, squeeze=False)
    axs_hist, axs_dw = axs_hist[:, 0], axs_dw[:, 0]

    fout = np.zeros(n_entries)
    fin = np.zeros(n_entries)
    snr = np.zeros(n_entries)
    snr_th = np.zeros(n_entries)
    isi_cv = np.zeros(n_entries)
    fmax_tgt = np.zeros(n_entries)
    for idx, fmax in enumerate(ddata):
        rec = ddata[fmax]
        color = ax_filt.plot(rec.bin_times, rec.filtered_decode)[0].get_color()
        ax_filt.axhline(rec.mean, color="k", alpha=0.5)
        fin[idx] = rec.fin
        fout[idx] = rec.fout
        snr[idx] = rec.snr
        fmax_tgt[idx] = fmax
        # Trained decoders vary per neuron, so the observed fout/fin ratio is
        # used as an effective weight (gamma order 1/ratio).
        effective_dw = rec.fout / rec.fin
        # get_snr_gamma returns a length-1 array for a scalar input; assigning an
        # ndim>0 array into a scalar slot is deprecated since NumPy 1.25.
        snr_th[idx] = get_snr_gamma(rec.fout*tau, 1/effective_dw).item()
        isi_cv[idx] = rec.isi_cv
        label = "target fmax={}".format(fmax)
        axs_hist[idx].hist(
            rec.isi, bins=50, cumulative=False, density=True, histtype="step", color=color,
            label=label)
        axs_hist[idx].legend(loc="upper right")
        axs_dw[idx].hist(
            rec.dweights, bins=50, cumulative=False, density=True, histtype="step", color=color,
            label=label)
        axs_dw[idx].legend(loc="upper right")
    ax_filt.set_title('do means look reasonable?')
    ax_filt.set_xlabel("time")
    ax_filt.set_ylabel("filtered decode")
    axs_hist[0].set_title("ISI Distributions")
    axs_hist[-1].set_xlabel("ISIs")
    axs_dw[0].set_title("Decode Weight Distributions")
    axs_dw[-1].set_xlabel("Decode Weights")
    axs_dw[-1].set_xlim([-1.1, 1.1])
    fends = np.array([fout.min(), fout.max()])
    snr_poi = np.sqrt(2*fends*tau)
    snr_per_high_lt_appx = np.sqrt(12)*fends*tau
    snr_per = get_snr_periodic(fout*tau)

    ax_snr.loglog(fout, snr, 'o', label="observed snr")
    ax_snr.loglog(fout, snr_th, '-o', label="theoretical gamma snr")
    ax_snr.loglog(fends, snr_poi, label="theoretical poisson snr")
    per_color = ax_snr.loglog(fout, snr_per, label="theoretical periodic snr")[0].get_color()
    ax_snr.loglog(fends, snr_per_high_lt_appx, color=per_color, alpha=0.2)
    ax_snr.legend(loc="best")
    ax_snr.set_xlabel("f_out")
    ax_snr.set_ylabel("SNR")
    ax_snr.grid(which="both")

    rel_snr = np.abs(snr-snr_th)/snr_th
    ax_snr_check.semilogx(fout, rel_snr, '-o')
    ax_snr_check.grid()
    ax_snr_check.set_xlabel("f_out")
    ax_snr_check.set_ylabel("|SNR - SNR_theory| / SNR_theory")

    ax_cv.semilogx(fout, isi_cv, '-o')
    ax_cv.set_ylim([0, 1])
    ax_cv.set_xlabel("f_out")
    ax_cv.set_ylabel("CV")
    ax_cv.grid(which="both")

    ax_f.plot(fmax_tgt, fin, '-o', label="f_in")
    ax_f.plot(fmax_tgt, fout, '-o', label="f_out")
    ax_f.legend(loc="best")
    ax_f.set_title("is f_in constant across trials?")
    ax_f.set_xlabel("target fmax")

fmax_sweep_plot_data(ddata)

# Scratchspace

In [None]:
def th_snr_sweep(tau=None):
    """Plot theoretical gamma-process SNR versus output rate for several input rates.

    For each input rate ``fin`` and decode weight ``dw``, the output rate is
    ``fin*dw`` and the gamma order is ``1/dw``. Reference lines show the
    high-rate periodic approximation ``sqrt(12)*f*tau`` and the Poisson SNR
    ``sqrt(2*f*tau)``.

    tau : filter time constant in seconds. Defaults to the module-level
    ``TAU`` at call time.
    """
    if tau is None:
        tau = TAU
    fins = np.linspace(1000, 10000, 5)
    dweights = np.linspace(0.015, 1, 100)
    fouts = np.zeros((len(dweights), len(fins)))
    routs = np.zeros_like(fouts)
    plt.figure()
    for fidx, fin in enumerate(fins):
        for didx, dw in enumerate(dweights):
            fouts[didx, fidx] = fin*dw
            # .item(): get_snr_gamma returns a length-1 array for a scalar input.
            routs[didx, fidx] = get_snr_gamma(fouts[didx, fidx]*tau, 1/dw).item()
        plt.loglog(fouts[:, fidx], routs[:, fidx], label="gamma, f_in={:.0f}".format(fin))
    fends = np.array([fouts.min(), fouts.max()])
    r_uni = np.sqrt(12)*fends*tau
    r_poi = np.sqrt(2*fends*tau)
    plt.loglog(fends, r_uni, label="periodic (high lambda*tau approx.)")
    plt.loglog(fends, r_poi, label="poisson")
    plt.ylim([plt.ylim()[0], 1.1*routs.max()])
    plt.xlabel("f_out")
    plt.ylabel("SNR")
    plt.legend(loc="best")
th_snr_sweep()