## Loihi Parameter Tuning Dashboard

This tool lets you change some of the main parameters on Loihi via an interactive dashboard. It is built on top of the Brian2Loihi Emulator (https://github.com/sagacitysite/brian2_loihi). Each time a parameter changes, a simulation is run and the plots are updated.

### Imports

In [1]:
import math
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from brian2 import *
prefs.codegen.target = 'numpy'  # use the Python fallback
from brian2_loihi import *

### Parameters

In [2]:
# INITIAL PARAMETERS
# The sliders are initially set to these values and can then be changed by the user.
# Each *_0 constant is passed as `value=` to the matching widget in design_tool's signature.
runtime_0 = 100               # number of timesteps to run simulation (slider range 1..200)
ref_time_0 = 2                # refractory period (slider range 0..10)
threshold_mantissa_0 = 400    # spike threshold mantissa (threshold = mantissa * 2**6)
decay_v_0 = 1024              # voltage decay (slider range 0..4096)
decay_I_0 = 1024              # synaptic input decay (slider range 0..4096)
weight_mantissa_0 = 128       # weight mantissa (slider range 0..255)
weight_exponent_0 = 0         # weight exponent (slider range -8..7)
n_bits_0 = 8                  # number of bits used to store weight mantissa (one of 1,2,3,4,5,6,8)

### Create design tool

In [3]:
def design_tool(runtime = widgets.IntSlider(min=1, max=200, step=1, value=runtime_0, continuous_update=False, style={'description_width': 'initial'}),
                decay_v = widgets.IntSlider(min=0, max=4096, step=1, value=decay_v_0, continuous_update=False, style={'description_width': 'initial'}),
                decay_I = widgets.IntSlider(min=0, max=4096, step=1, value=decay_I_0, continuous_update=False, style={'description_width': 'initial'}),
                weight_mantissa = widgets.IntSlider(min=0, max=255, step=1, value=weight_mantissa_0, continuous_update=False, style={'description_width': 'initial'}),
                weight_exponent = widgets.IntSlider(min=-8, max=7, step=1, value=weight_exponent_0, continuous_update=False, style={'description_width': 'initial'}),
                n_bits = widgets.Dropdown(options=[1,2,3,4,5,6,8], value=n_bits_0, description='numWeightBits', disabled=False, style={'description_width': 'initial'}),
                threshold_mantissa = widgets.IntSlider(min=0, max=131071, step=1, value=threshold_mantissa_0, continuous_update=False, style={'description_width': 'initial'}),
                ref_time = widgets.IntSlider(min=0, max=10, step=1, value=ref_time_0, continuous_update=False, style={'description_width': 'initial'}),
                scale = widgets.Dropdown(options=['linear', 'log'], value='linear', description='scale', disabled=False, style={'description_width': 'initial'})):
    """Simulate one input spike into one Loihi neuron and plot the result.

    A network with a single spike generator (one spike at time 1), a single
    neuron and a single excitatory synapse is built from the given values,
    run for ``runtime`` steps, and drawn as a 2x2 figure: synaptic input,
    voltage (with threshold and spike markers), a statistics table, and the
    mapping from weight mantissa to actual weight.

    The defaults are ipywidgets controls so that ``widgets.interact`` can
    build the dashboard from the signature; when called, each parameter
    holds the plain value selected in its control.

    Parameters
    ----------
    runtime : number of timesteps to run the simulation.
    decay_v : voltage decay of the neuron.
    decay_I : synaptic input decay of the neuron.
    weight_mantissa : weight mantissa assigned to the synapse (0..255).
    weight_exponent : weight exponent of the synapse.
    n_bits : number of bits used to store the weight mantissa.
    threshold_mantissa : threshold mantissa; threshold = mantissa * 2**6.
    ref_time : refractory period of the neuron.
    scale : y-axis scale of the two time-series plots ('linear' or 'log').
    """

    # FUTURE ADDITIONS
    # sign mode
    is_mixed = 0  # for now only excitatory weight

    # GENERATE NETWORK FROM SLIDER VALUES -------------------------------

    # INPUT: a single source emitting one spike at time 1
    _input = LoihiSpikeGeneratorGroup(1,
                                      indices=[0],
                                      times=[1])

    # NEURON
    neuron = LoihiNeuronGroup(N=1,
                              refractory=ref_time,
                              threshold_v_mant=threshold_mantissa,
                              decay_v=decay_v,
                              decay_I=decay_I)
    # SYNAPSE
    synapse = LoihiSynapses(_input,
                            neuron,
                            w_exp=weight_exponent,
                            sign_mode=synapse_sign_mode.EXCITATORY,
                            num_weight_bits=n_bits)
    synapse.connect()
    synapse.w = weight_mantissa

    # MONITORS
    state_mon_v = LoihiStateMonitor(neuron, 'v')
    state_mon_I = LoihiStateMonitor(neuron, 'I')
    spike_mon = LoihiSpikeMonitor(neuron)

    # CREATE AND STORE NETWORK
    net = LoihiNetwork(neuron,
                       _input,
                       synapse,
                       state_mon_I,
                       state_mon_v,
                       spike_mon)

    # RUN SIMULATION ----------------------------------------------------
    net.run(runtime)

    # GET DATA ----------------------------------------------------------

    # synaptic input
    times = state_mon_I.t / ms
    synaptic_input = state_mon_I.I[0]
    max_synaptic_input = np.max(synaptic_input)
    total_current = np.sum(synaptic_input).round(2)

    # voltage
    voltage = state_mon_v.v[0]

    # threshold
    threshold = threshold_mantissa * 2**6

    # spike times, materialised once and reused for statistics and plotting
    spike_times = list(spike_mon.t)

    # compute values for table
    if not spike_times:
        max_voltage = np.max(state_mon_v.v).round(2)
        if max_voltage > 0:
            n_inputs_to_spike = (threshold/max_voltage).round(2)
        else:
            n_inputs_to_spike = "inf"
        # The threshold slider starts at 0; avoid a 0/0 division in that case.
        if threshold > 0:
            perc_thresh = (max_voltage/threshold).round(2)
        else:
            perc_thresh = "n/a"
        n_spikes = 0
        firing_rate = 0
    else:
        max_voltage = threshold
        n_inputs_to_spike = 1
        perc_thresh = 1
        n_spikes = len(spike_times)
        firing_rate = np.round(1000/runtime * n_spikes, 2)

    # compute possible weights
    numLsbBits = 8 - n_bits - is_mixed
    min_possible_weight_mant = 0
    max_possible_weight_mant = 255
    possible_weight_mantissas = np.arange(min_possible_weight_mant,
                                          max_possible_weight_mant + 1)

    # Shift weight mantissa: drop the numLsbBits least significant bits
    w_shifted = np.floor((possible_weight_mantissas / 2**numLsbBits)).astype(int) * 2**numLsbBits
    # Scale weight with weight exponent
    w_scaled = w_shifted * 2 ** (6.0 + weight_exponent)
    # Shift scaled weight: truncate to a multiple of 2**6
    w_scaled_shifted = (w_scaled / 2**6).astype(int) * 2**6
    # Apply 21 bit limit
    w_values = np.clip(w_scaled_shifted, -2097088, 2097088)
    max_weight = np.max(w_values)
    current_weight = w_values[weight_mantissa]

    # PLOT ----------------------------------------------------------------
    fig, ax = plt.subplots(2, 2,
                           gridspec_kw={
                               'width_ratios': [2, 1],
                               'height_ratios': [1, 1]},
                           figsize=(15,7))
    ax_input, ax_table = ax[0][0], ax[0][1]
    ax_voltage, ax_weights = ax[1][0], ax[1][1]

    # SYNAPTIC INPUT
    ax_input.set_yscale(scale)
    ax_input.plot(times, synaptic_input, color="black")
    # lower limit is slightly above 0 so the same limits work on a log axis
    ax_input.set_ylim(0.001, 1.1*max_synaptic_input)
    ax_input.set_ylabel('Synaptic input (I)')

    # VOLTAGE
    ax_voltage.set_yscale(scale)
    ax_voltage.plot(times, voltage, color="black")
    ax_voltage.set_ylim(0.001, 1.1*threshold)
    ax_voltage.set_xlabel('Time step')
    ax_voltage.set_ylabel('Voltage (v)')
    h_thr = ax_voltage.axhline(y=threshold, linewidth=1, ls='--', color='gray')

    # NOTE(review): spike times are passed to axvline as returned by the spike
    # monitor, while the x-axis uses state_mon_I.t / ms -- confirm both are in
    # the same unit.
    spike_lines = [ax_voltage.axvline(t, ls='--', color='firebrick', lw=2)
                   for t in spike_times]

    if spike_lines:
        ax_voltage.legend([h_thr, spike_lines[-1]], ['Threshold', 'Spike'], loc='upper right', facecolor='white', framealpha=1)
    else:
        ax_voltage.legend([h_thr], ['Threshold'], loc='upper right', facecolor='white', framealpha=1)

    # TABLE
    table_data = np.array([max_voltage, perc_thresh, n_inputs_to_spike, total_current, n_spikes, firing_rate]).reshape(6,1)
    rowlabel=("max v", "max v / v_th", "num. inputs to v_th", "total synaptic input", "num. spikes", 'num. spikes / sec')
    collabel=("Statistics", "")
    table_1 = ax_table.table(cellText=table_data,
                             colLabels=collabel,
                             rowLabels=rowlabel,
                             loc='center right')
    table_1.auto_set_font_size(False)
    table_1.set_fontsize(14)
    table_1.scale(0.5, 2)
    table_1.auto_set_column_width(col=list(range(len(collabel))))
    ax_table.axis('tight')
    ax_table.axis('off')

    # WEIGHTS
    # guide lines from the axes to the currently selected weight
    ax_weights.axvline(weight_mantissa,
                       ymin=0,
                       ymax=current_weight/max_weight if max_weight > 0 else 1,
                       linestyle='--', linewidth=0.5, color='gray')
    ax_weights.axhline(current_weight,
                       xmin=0,
                       xmax=weight_mantissa/max_possible_weight_mant,
                       linestyle='--', linewidth=0.5, color='gray')
    # Property names must be lowercase: Matplotlib >= 3.5 rejects 'Marker'/'MarkerSize'.
    ax_weights.plot(w_values, 'k', marker='.', markersize=0.01)
    h_weight = ax_weights.scatter(weight_mantissa, current_weight, color='k')
    if max_weight > 0:
        ax_weights.set_ylim(0, max_weight)
    else:
        # all weights are 0: use a small symmetric range so the line is visible
        ax_weights.set_ylim(-0.1, 0.1)
    ax_weights.set_xlim(min_possible_weight_mant, max_possible_weight_mant)
    ax_weights.set_xlabel('Weight Mantissa ($w_{mant}$)')
    # raw string: '\c' is not a valid escape sequence in a normal string literal
    ax_weights.set_ylabel(r'Actual Weight ($w_{mant} \cdot 2^{6 + w_{exp}}$)')
    ax_weights.legend([h_weight], ['Current weight'], loc='lower right', facecolor='white', framealpha=1)


### Launch

In [4]:
# Build the dashboard: interact() creates one control per design_tool argument
# (taken from the widget defaults in its signature) and calls design_tool again
# each time a control changes, which re-runs the simulation and redraws the plots.
wgt = widgets.interact(design_tool)

interactive(children=(IntSlider(value=100, continuous_update=False, description='runtime', max=200, min=1, sty…