In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from nengo.solvers import LstsqL2
from nengo_extras.plot_spikes import plot_spikes

import pystorm
from pystorm.hal import HAL
from pystorm.hal.net_builder import NetBuilder
from pystorm.hal.run_control import RunControl

In [None]:
# Network parameters and the HAL handle shared by the cells below.
#
# Open question carried over from the original notes: what should the gain
# and bias bits be set to?
X = 16  # pool dimensions; passed as create_single_pool_net(Y, X, ...)
Y = 16
NNEURON = X*Y  # total neuron count of the X-by-Y pool
DIM = 1  # dimension argument given to create_default_yx_taps
FMAX = 1000  # scales the linspace(-1, 1) training sweep of input rates

DOWNSTREAM_NS = 10000  # first argument of hal.set_time_resolution; presumably nanoseconds (per the _NS suffix) -- confirm

hal = pystorm.hal.HAL()

In [None]:
# Build the tap-point matrix for the pool and save it to disk.
net_builder = NetBuilder(hal)

# Calibration entry ("synapse", "high_bias_magnitude"); presumably marks
# synapses that should not be used as tap points -- confirm against the HAL docs.
bad_syn = hal.get_calibration("synapse", "high_bias_magnitude")
# The tap grid is requested at half the neuron-grid size in each dimension.
SX = X // 2
SY = Y // 2
tap_matrix_syn = net_builder.create_default_yx_taps(SY, SX, DIM, bad_syn)
# Convert the synapse-indexed taps to neuron-indexed taps for the pool.
tap_matrix = net_builder.syn_taps_to_nrn_taps(tap_matrix_syn)
np.savetxt("tap_matrix.txt", tap_matrix)

In [None]:
# Create a single-pool network from stored gain divisors and biases, then map
# it onto the hardware.
# NOTE(review): gain_divisors.txt and biases.txt are read here but are not
# written by any cell visible in this notebook -- they must already exist.
gain_divs = np.loadtxt("gain_divisors.txt", dtype=int)
biases = np.loadtxt("biases.txt", dtype=int)
# Decoders start at zero; they are reassigned after the decoder fit further down.
net = net_builder.create_single_pool_net(
    Y, X, tap_matrix, biases=biases, gain_divs=gain_divs, decoders=np.zeros((1, Y*X)))
run_control = RunControl(hal, net)

hal.map(net)

Collect data for training

In [None]:
# Sweep the input rate over FMAX * [-1, 1] and record raw spikes, one bin per
# sweep point.
bin_size = 0.5 # seconds
bin_size_ns = int(bin_size*1E9)
hal.set_time_resolution(DOWNSTREAM_NS, bin_size_ns)

total_train_points = 11
offset_time = 1.  # seconds between "now" and the first sweep point
offset_time_ns = int(offset_time*1E9)

# One extra row is appended that repeats the last sweep value; presumably it
# marks the end of the final bin -- confirm against run_input_sweep.
train_rates = np.zeros((total_train_points+1, 1))
train_rates[:total_train_points,0] = FMAX * np.linspace(-1, 1, total_train_points)
train_rates[-1, 0] = train_rates[-2, 0]

# Sweep timestamps: one per bin, shifted by the offset and by the current HAL
# time read immediately before the run.
train_time_ns = np.arange(total_train_points+1)*bin_size_ns+offset_time_ns
train_time_ns += hal.get_time()

input_vals = {net.input:(train_time_ns, train_rates)}

# Only raw spikes are requested (get_outputs=False); outputs_data is not used below.
outputs_data, spike_data = run_control.run_input_sweep(
    input_vals, get_raw_spikes=True, get_outputs=False)
spikes, bin_times = spike_data

spikes = spikes[net.pool]
# Spike counts per bin divided by the bin length in seconds.
rates = spikes/bin_size
# Drop the extra repeated row added above.
train_rates = train_rates[:-1]

In [None]:
def plot_tuning(inputs, rates):
    """Plot firing rate against input for every neuron, coloured by column group.

    ``rates`` must be 2-D with one column per neuron. With
    ``side = int(sqrt(n_neurons))`` and ``half = side // 2``, the columns are
    walked in ``side`` consecutive groups of ``2 * half``; the first ``half``
    columns of each group are drawn in red and the next ``half`` in blue.
    Presumably this separates the left and right halves of each row of a
    square neuron array -- confirm against the tap-point layout.

    Args:
        inputs: x values, one per row of ``rates``.
        rates: array of shape (n_bins, n_neurons).
    """
    _, neuron_count = rates.shape
    side = int(np.sqrt(neuron_count))
    half = side // 2
    for row in range(side):
        left = 2 * row * half
        middle = left + half
        right = middle + half
        plt.plot(inputs, rates[:, left:middle], 'r')
        plt.plot(inputs, rates[:, middle:right], 'b')
plot_tuning(train_rates, rates)

In [None]:
# fit decoders
# Regularised least-squares fit from measured rates to the target function
# (the training inputs shifted up by FMAX).

solver = LstsqL2(reg=0.1)
target_function = train_rates + FMAX
# Show the shapes of both operands handed to the solver.
for fit_operand in (rates, target_function):
    print(fit_operand.shape)
decoders, info = solver(rates, target_function)
rmse = info['rmses']
print(rmse)
# Decode of the training data using the fitted decoders.
train_decode = rates.dot(decoders)

In [None]:
# Row of the training sweep whose input is closest to zero.  An insertion
# index from np.searchsorted only lands on that row when an exact 0.0 sample
# exists; otherwise it selects the first positive input (or an out-of-range
# index when every input is negative).  argmin over |input| is robust to
# rounding and to sweeps that do not sample zero exactly.
z_idx = np.argmin(np.abs(train_rates[:, 0]))
rates_0 = rates[z_idx]

# Target function versus the decode of the training data.
plt.figure()
plt.plot(train_rates, target_function, label="target function")
plt.plot(train_rates, train_decode, label="training decode")
plt.xlabel("train_rates")
plt.legend(loc="best")

# Distribution of decoder values for neurons that fired at the (near-)zero input.
plt.figure()
plt.hist(decoders[rates_0>0])
#TODO: collect testing data

In [None]:
# Replace the zero decoders the network was created with by the fitted ones
# (transposed to match the layout used at creation), then push the new weights
# through the HAL.
net.decoder_conn.reassign_weights(decoders.T)
hal.remap_weights()

In [None]:
# run tests
#  deliver an input of 0
# A much finer time resolution than the training sweep is used so individual
# output and spike events fall into separate bins.
time_resolution_ns = 100000
hal.set_time_resolution(DOWNSTREAM_NS, time_resolution_ns)

offset_time = 0.3  # seconds between "now" and the first test timestamp
offset_time_ns = int(offset_time*1E9)

test_time = 1  # seconds between the two test timestamps
test_time_ns = int(test_time*1E9)

# Two timestamps, both with an input value of 0.
test_rates = np.zeros((2, 1))
test_times = np.arange(2)*test_time_ns + offset_time_ns
# now_ns is kept: later cells subtract it to express bin times relative to the test start.
now_ns = hal.get_time()
test_times += now_ns

input_vals = {net.input:(test_times, test_rates)}

# Collect both decoded outputs and raw spikes this time.
output_data, spike_data = run_control.run_input_sweep(
    input_vals, get_raw_spikes=True, get_outputs=True)

outputs, output_bin_times = output_data
spikes, spike_bin_times = spike_data
# Keep only column 0 of the output array, leaving a 1-D series of per-bin values.
outputs = outputs[net.output][:, 0]
spikes = spikes[net.pool]

In [None]:
# Debug fixture: uncomment to exercise the clipping logic on synthetic data.
# output_bin_times = np.arange(10)
# outputs = np.arange(10*2).reshape((10, 2))
# spike_bin_times = np.arange(10)[1:-2]
# spikes = np.arange(10*2).reshape((10, 2))[1:-2]
# print(output_bin_times)
# print(spike_bin_times)
# print(outputs)
# print(spikes)

# clip arrays to same time region
# Whichever series starts earlier (or ends later) is trimmed so that both
# cover the same span of bin times.  Each data array is trimmed together with
# its bin-time array so the two stay aligned.
# start of array
if output_bin_times[0] < spike_bin_times[0]:
    # first output index at or after the first spike bin time
    idx = np.searchsorted(output_bin_times, spike_bin_times[0])
    print("clipping {:d} elements from output data start to align with spike data start".format(idx))
    output_bin_times = output_bin_times[idx:]
    outputs = outputs[idx:]
elif spike_bin_times[0] < output_bin_times[0]:
    # first spike index at or after the first output bin time
    idx = np.searchsorted(spike_bin_times, output_bin_times[0])
    print("clipping {:d} elements from spike data start to align with output data start".format(idx))
    spike_bin_times = spike_bin_times[idx:]
    spikes = spikes[idx:]
# end of array
if output_bin_times[-1] > spike_bin_times[-1]:
    # clip output_bin_times; side='right' keeps a bin equal to the last spike bin time
    idx = np.searchsorted(output_bin_times, spike_bin_times[-1], 'right')
    print("clipping {:d} elements from output data end to align with spike data end".format(len(output_bin_times)-idx))
    output_bin_times = output_bin_times[:idx]
    outputs = outputs[:idx]
elif spike_bin_times[-1] > output_bin_times[-1]:
    # clip spike_bin_times; side='right' keeps a bin equal to the last output bin time
    idx = np.searchsorted(spike_bin_times, output_bin_times[-1], 'right')
    # This branch trims the END of the spike data (the message used to say "start").
    print("clipping {:d} elements from spike data end to align with output data end".format(len(spike_bin_times)-idx))
    spike_bin_times = spike_bin_times[:idx]
    spikes = spikes[:idx]
# print(output_bin_times)
# print(spike_bin_times)
# print(outputs)
# print(spikes)

In [None]:
# process output data
# Bin times in seconds, measured from the first test timestamp (now_ns + offset_time_ns).
output_times = (output_bin_times - now_ns - offset_time_ns) / 1E9

# Tally how many bins hold exactly 1, exactly 2, or more than 2 outputs.
total_outputs = np.sum(outputs)
bins_gt0 = np.sum(outputs > 0)
bins_1, bins_2 = (np.sum(outputs == count) for count in (1, 2))
outputs_above_2 = outputs > 2
bins_gt2 = np.sum(outputs_above_2)
bin_vals_gt2 = np.unique(outputs[outputs_above_2])

output_summary = "Collected {:d} non-zero output bins. Sum(outputs) {:d}"
print(output_summary.format(bins_gt0, total_outputs))
output_detail = "Bin stats:1-spike bins: {:d}, 2-spike bins: {:d}, >2-spike bins: {:d} (bin values {})"
print(output_detail.format(bins_1, bins_2, bins_gt2, bin_vals_gt2))

In [None]:
# process spike data
# Tally how many (bin, neuron) entries hold exactly 1, 2, 3, or more than 3 spikes.
total_spikes = np.sum(spikes)
bins_gt0 = np.sum(spikes > 0)
bins_1, bins_2, bins_3 = (np.sum(spikes == count) for count in (1, 2, 3))
spikes_above_3 = spikes > 3
bins_gt3 = np.sum(spikes_above_3)
bin_vals_gt3 = np.unique(spikes[spikes_above_3])

spike_summary = "Collected {:d} non-zero spike bins. Sum(spikes) {:d}"
print(spike_summary.format(bins_gt0, total_spikes))
spike_detail = "Bin stats:1-spike bins: {:d}, 2-spike bins: {:d}, 3-spike bins: {:d}, >3-spike bins: {:d} (bin values {})"
print(spike_detail.format(bins_1, bins_2, bins_3, bins_gt3, bin_vals_gt3))

In [None]:
# make a raster of spikes and outputs
# Column 0 carries the decoded output; the remaining columns carry one neuron each.
raster_rows = spikes.shape[0]
raster_cols = spikes.shape[1] + 1
to_raster = np.zeros((raster_rows, raster_cols), dtype=int)
to_raster[:, 0] = outputs
to_raster[:, 1:] = spikes
# Collapse multi-event bins to a single mark.
multi_event = to_raster > 1
to_raster[multi_event] = 1
plt.subplots(figsize=(16, 12))
plot_spikes(output_times, to_raster)

In [None]:
def lpf(signal, tau, dt):
    """First-order exponential low-pass filter along the first axis.

    Implements ``y[0] = k * x[0]`` and ``y[n] = decay * y[n-1] + k * x[n]``
    with ``decay = exp(-dt / tau)`` and ``k = (1 - decay) / dt``.  The
    recursion indexes the first axis only, so a 2-D input is filtered column
    by column.

    Args:
        signal: numpy array of samples; the first axis is time.
        tau: filter time constant, in the same units as ``dt``.
        dt: spacing between consecutive samples.

    Returns:
        A new float array with the same shape as ``signal``; the input is
        left untouched.
    """
    shrink = np.expm1(-dt / tau)  # exp(-dt/tau) - 1, computed accurately for small dt/tau
    decay = shrink + 1
    increment = -shrink / dt
    filtered = np.zeros(signal.shape)
    filtered += increment * signal
    for step in range(1, len(signal)):
        filtered[step] = filtered[step] + decay * filtered[step - 1]
    return filtered

sample_rate = 10000
tau = 0.01
dt = float(time_resolution_ns)*1E-9  # bin width in seconds

# Filter the accumulator output, with bins above 10 zeroed out first.
valid_outputs = outputs.copy()
valid_outputs[valid_outputs > 10] = 0
filtered_outputs = lpf(valid_outputs, tau, dt)

# Weight each neuron's raw spikes by its decoder, filter, then sum over neurons.
decoded_spikes = spikes * decoders.reshape(-1)
filtered_decoded_spikes = lpf(decoded_spikes, tau, dt)
decode = filtered_decoded_spikes.sum(axis=1)

# Mean over standard deviation of the filtered output, skipping the first 5*tau.
settled_outputs = filtered_outputs[output_times > 5*tau]
mean = settled_outputs.mean()
var = settled_outputs.var()
print(mean/np.sqrt(var))

fig, ax = plt.subplots(figsize=(10,4))
ax.plot(output_times, filtered_outputs, label="filtered accumulator output")
ax.plot(output_times, decode, label="filtered, decode-weighted, raw spikes")
ax.legend(loc="best")
ax.set_xlabel("time")

In [None]:
# One indicator column per output-count class for every bin:
#   column 0 -> exactly 1, column 1 -> exactly 2, column 2 -> more than 2.
#
# `outputs` is already 1-D here (column 0 was selected right after the test
# run), so indexing it with `[:, 0]` raises IndexError.  Accept either layout
# so the cell also works if `outputs` is still 2-D.
output_counts = outputs if outputs.ndim == 1 else outputs[:, 0]
scratch_outputs = np.zeros((len(output_times), 3))
scratch_outputs[output_counts==1, 0] = 1
scratch_outputs[output_counts==2, 1] = 1
scratch_outputs[output_counts>2, 2] = 1
# plot_spikes(output_times, scratch_outputs)
plt.subplots(figsize=(8,6))
plt.plot(output_times, scratch_outputs)
plt.xlim(0, 1.3)

### open ideas

- compare to all weights positive and equal
- check for poissonness of superposed spike trains
- sweep decoder magnitude