# Signal Generator and IAF Encoding Demo

In this demo, the user will:
1. Generate an input signal $u(t)$
2. Encode this signal as spikes using an Integrate-and-Fire neuron

## Library imports
Python version check.
This notebook was tested in Anaconda running Python 3.7.4.

In [224]:
from platform import python_version
# Report the interpreter version this notebook is running under.
print(f"{python_version()}")

3.7.4


In [323]:
# These libraries should be included with anaconda
import ipywidgets as widgets
import matplotlib.pyplot as plt
from ipywidgets import interact, interactive_output, interactive, Layout

import numpy as np
np.random.seed(20204020)

from numpy import sin, pi, cumsum, sinc, transpose, diag
from numpy.matlib import repmat

In [226]:
# We will use ipympl as our plotting interface
try:
    import ipympl
except ImportError:
    !pip install ipympl # run terminal command to install ipympl
    import ipympl

## Helper Classes
These container classes make it easier to reuse code across demos.

### class: signal
The `signal` class is a container for handy information about the input to and the output of an encoder, such as its:
1. time stamps $t$
2. time step $dt$
3. input amplitudes $u(t)$
4. integrated input $\int_{0}^{t} u(\tau)\,d\tau$
5. encoder output $v(t)$

In [227]:
class signal:
    """Container for an encoder's input and output traces.

    Attributes
    ----------
    dt : float
        Sampling step in seconds.
    t : numpy.ndarray
        Sample times ``0, dt, 2*dt, ...`` strictly below ``duration``.
    u : numpy.ndarray
        Input amplitudes: a 1 Hz unit sine when ``waveform == 'sin'``,
        otherwise all zeros.
    intu : numpy.ndarray
        Running integral of ``u``; zeros here (placeholder, not integrated).
    v : numpy.ndarray
        Encoder output; zeros here (placeholder).
    """

    def __init__(self, duration=1, step=1e-5, waveform='sin'):
        self.dt = step
        self.t = np.arange(0, duration, step)
        # Three independent zero-filled traces shaped like the time axis.
        self.u, self.intu, self.v = (np.zeros_like(self.t) for _ in range(3))
        if waveform == 'sin':
            self.u = sin(2 * pi * self.t)

### class: encoder
Holds the parameters and outputs for a given encoder.

In [233]:
class encoder:
    """Holds the parameters and outputs of one spiking encoder.

    Only the integrate-and-fire model (``model='iaf'``) is implemented.

    Attributes
    ----------
    d : float
        Firing threshold, initially 1e-3.
    tk : list
        Spike times; starts as the placeholder ``[0]``.
    """

    def __init__(self, model='iaf'):
        if model == 'iaf':
            # Stored per instance so separate encoders never share (or
            # overwrite) each other's threshold and spike-time list.
            self.d = 1e-3    # threshold
            self.tk = [0]    # placeholder for spike times
        else:
            # More encoder models can be added here.
            raise ValueError(f"unsupported encoder model: {model!r}")

## Global variables
Note that these variables are modified by both the signal generator and the encoder.
After modifying the signal, re-run the encoder cell to update the plots.

In [234]:
# Shared state: the input signal (in1) and the encoder applied to it (nm1).
in1, nm1 = signal(), encoder()

# Interactive: signal generator
The user can generate an input signal to be encoded. So far options include:
1. sine
2. step
3. random band-limited signal (a sum of sincs)

Directions: run the cell below and then change the input parameters to visualize the corresponding signal.

In [334]:
%matplotlib widget

fig, sg_ax = plt.subplots(figsize=(8, 4))
sg_ax.set_ylim([-2, 2])
line, = sg_ax.plot(in1.t, in1.u, lw=2)
sg_ax.set_xlabel('Time [s]')
sg_ax.set_ylabel('Amplitude')
sg_ax.grid(True)

def plotter(wave, ampl, freq, time, startTime):
    """Regenerate the global input signal ``in1`` and redraw the generator plot.

    Called by ``interactive_output`` whenever a generator widget changes.

    Parameters
    ----------
    wave : str
        Waveform name: 'sin', 'step' or 'sincs'.
    ampl : float
        Peak amplitude of the generated signal.
    freq : int
        Frequency in Hz for 'sin'; bandwidth in Hz for 'sincs'.
    time : float
        Signal duration in seconds; non-positive values fall back to one sample.
    startTime : float
        Onset time in seconds of the 'step' waveform (ignored otherwise).

    Returns
    -------
    The module-level figure ``fig``.

    Raises
    ------
    ValueError
        If ``wave`` is not one of the supported waveforms.

    Side effects: overwrites ``in1.t``, ``in1.u``, ``in1.intu`` and ``in1.v``,
    and updates ``line`` and ``sg_ax``.
    """
    # Validate first so an unknown waveform cannot leave in1 half-updated.
    if wave not in ('sin', 'step', 'sincs'):
        raise ValueError(f"unknown waveform: {wave!r}")

    if time <= 0:
        time = in1.dt

    in1.t = np.arange(0, time, in1.dt)

    if wave == 'sin':
        # NOTE: the data could be computed in a signal generator rather than the plotter.
        in1.u = ampl*sin(2*pi*freq*in1.t)
        sg_ax.set_title(r'${}{}(2\pi.{}t)$'.format(ampl, wave, freq))

    elif wave == 'step':
        in1.u = np.zeros_like(in1.t)
        in1.u[in1.t > startTime] = ampl
        sg_ax.set_title(wave)

    else:  # 'sincs': random samples interpolated with sincs of bandwidth freq
        W = 2*pi*freq  # bandwidth in rad/s
        in1.u = np.zeros_like(in1.t)
        if W > 0:  # freq == 0 would make the sample spacing infinite
            Ts = pi/W  # Nyquist-rate sample spacing, i.e. 1/(2*freq)
            Ns = int((in1.t[-1] - in1.t[0])/Ts)  # number of stimulus samples
            s = np.random.rand(Ns) - 0.5         # uniform in [-0.5, 0.5)
            for k in range(Ns):
                in1.u += s[k]*sinc((W/pi)*(in1.t - k*Ts)) / (W/pi)

        peak = np.max(np.abs(in1.u))
        if peak > 0:  # an all-zero sum cannot be normalised (0/0 -> NaN)
            in1.u /= peak  # normalize
            in1.u *= ampl  # and scale

        sg_ax.set_title(r'$\omega = {:.2f} rad/s [{} Hz]$'.format(W, freq))

    line.set_data(in1.t, in1.u)

    # Matplotlib warns on identical lower/upper limits; pad degenerate ranges.
    t_end = in1.t[-1] if in1.t[-1] > in1.t[0] else in1.t[0] + in1.dt
    sg_ax.set_xlim([in1.t[0], t_end])
    y_lim = abs(ampl) if ampl else 1.0
    sg_ax.set_ylim([-y_lim, y_lim])

    # Resize the other traces to match the new time axis. The running integral
    # uses the same convention as encode_iaf: 0 at t=0, then steps of dt*u[n].
    in1.intu = np.zeros_like(in1.u)
    in1.intu[1:] = cumsum(in1.u[1:] * in1.dt)
    in1.v = np.zeros_like(in1.t)

    return fig

# Signal-generator controls, declared in the order they appear in the UI.
waveSelector = widgets.Dropdown(
    options=['sin', 'step', 'sincs'],
    value='sincs',
    description='Waveform:',
    disabled=False,
)
timeSlider = widgets.FloatSlider(description='Duration (s)', min=0, max=2, value=.1)
freqSlider = widgets.IntSlider(description='Freq (Hz)', min=0, max=100, value=40)
amplSlider = widgets.FloatSlider(description='Amplitude (V)', min=0, max=1, value=0.5)
startSlider = widgets.FloatSlider(description='Step @(s)', min=0, max=2, value=0.1)

ui = widgets.VBox([waveSelector, timeSlider, freqSlider, amplSlider, startSlider])

# Each plotter() keyword argument is driven by one widget.
out = widgets.interactive_output(plotter, dict(
    wave=waveSelector,
    ampl=amplSlider,
    freq=freqSlider,
    time=timeSlider,
    startTime=startSlider,
))

display(out, ui)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Output()

VBox(children=(Dropdown(description='Waveform:', index=2, options=('sin', 'step', 'sincs'), value='sincs'), Fl…

## Helper functions for plotting
Here we define the IAF encoding routine and the plotting function used by the encoding demo.

In [341]:
def encode_iaf(change):
    """Run the integrate-and-fire encoder over the global input ``in1``.

    Registered as an observer of the threshold slider; ``change['new']`` is
    the new threshold. Stores it in ``nm1.d``, replaces ``nm1.tk`` with the
    new spike times (``dt * sample_index``), and fills ``in1.v`` (integrator
    state) and ``in1.intu`` (running integral of u) in place.

    Both traces start at 0 and grow by ``dt * u[n]`` per sample. Whenever v
    exceeds the threshold, it is reset to 0 and a spike time is recorded.
    """
    nm1.d = change['new']
    spikes = nm1.tk = []

    threshold = nm1.d
    dt = in1.dt
    u, v, intu = in1.u, in1.v, in1.intu  # written in place below

    for idx in range(len(in1.t)):
        if idx:
            increment = dt * u[idx]
            v[idx] = v[idx - 1] + increment
            intu[idx] = intu[idx - 1] + increment
        else:
            v[idx] = 0
            intu[idx] = 0

        if v[idx] > threshold:
            v[idx] = 0  # reset to zero (an alternative is to subtract the threshold)
            spikes.append(dt * idx)

def plot_t_transform(time, threshold):
    """Redraw the encoding figure, revealing the traces up to ``time``.

    Driven by ``interactive_output`` from the interval and threshold sliders.

    Parameters
    ----------
    time : float
        End of the displayed interval in seconds.
    threshold : float
        Not read here. It is passed by ``interactive_output`` so the figure
        redraws when the threshold changes; the threshold line drawn is
        ``nm1.d``, which the ``encode_iaf`` observer updates.

    Returns
    -------
    The module-level figure ``fig``.
    """
    # Number of samples to show: round() guards against float truncation
    # (0.09999/1e-5 evaluates to 9998.999...), and +1 includes the sample at
    # `time` itself. Clip in case the slider predates a shorter signal.
    n = int(round(time / in1.dt)) + 1
    n = max(0, min(n, len(in1.t)))

    line_v.set_data(in1.t[:n], in1.v[:n])
    line_intu.set_data(in1.t[:n], in1.intu[:n])
    line_d.set_data(in1.t, np.full_like(in1.t, nm1.d))

    # Only the spikes that fall inside the displayed interval.
    tk_to_plot = np.asarray(nm1.tk, dtype=float)
    tk_to_plot = tk_to_plot[tk_to_plot < n * in1.dt]
    # TODO: draw a spike raster (vertical ticks) instead of point markers.
    line_tk.set_data(tk_to_plot, np.zeros_like(tk_to_plot))

    # Matplotlib warns on identical lower/upper limits; pad degenerate ranges.
    t_start, t_end = in1.t[0], in1.t[-1]
    if t_end <= t_start:
        t_end = t_start + in1.dt
    ax[0].set_xlim([t_start, t_end])
    ax[1].set_xlim([t_start, t_end])

    u_min, u_max = np.min(in1.u), np.max(in1.u)
    if u_max <= u_min:  # constant input
        u_min, u_max = u_min - 1, u_max + 1
    ax[0].set_ylim([u_min, u_max])

    int_peak = np.max(np.abs(in1.intu))
    if int_peak <= 0:  # integral identically zero
        int_peak = 1
    ax2.set_ylim([-int_peak, int_peak])

    ax[1].set_ylim([-1, 1])
    return fig

# Interactive: sliding through the encoding
The user can change the threshold of the encoder and then slide through time to see the integration and the resulting spikes.

Directions: re-run this cell once an updated signal is available.

In [345]:
%matplotlib widget

# Create placeholder plots
fig, ax = plt.subplots(2, 1, # rows and columns
                       gridspec_kw={'height_ratios': [3, 1]}, # comparative sizes
                       figsize=(6, 4) # fig size in inches
                      ) 

# Plot the user generated input signal u
line_u, = ax[0].plot(in1.t, in1.u, 
                   label='$u(t)$',
                   lw=2)
ax[0].legend(loc='upper left')
ax[0].spines['right'].set_color('#1f77b4')
ax[0].tick_params(axis='y', colors='#1f77b4')

# Plot integrated u, threshold, and thresolded output
ax2 = ax[0].twinx()
line_intu, = ax2.plot(in1.t, np.zeros_like(in1.intu), 
                      label='$\int_{0}^{t} u(t) dt$',
                      lw=1, linestyle='dashed', color='c')
line_d, = ax2.plot(in1.t, nm1.d * np.ones_like(in1.t), 
                   label='threshold',
                   color='c', lw=2)
line_v, = ax2.plot(in1.t, np.zeros_like(in1.v), 
                   label='$v(t)$',
                   color='orange', lw=1)
ax2.legend(loc='upper right')
ax2.spines['right'].set_color('c')
ax2.tick_params(axis='y', colors='c')

# Plot spike times
line_tk, = ax[1].plot(nm1.tk, np.ones_like(nm1.tk),
                      #'x-', # use for plot raster
                      'x',
                      label='$t_k$')
ax[1].legend(loc='upper right')

# Slider to set integraton
timeSlider = widgets.FloatSlider(
    description = 'Interval t [s]', min=in1.t[0], max=in1.t[-1], step=in1.dt, value=in1.t[-1],
    layout=Layout(width='60%'))

threshSlider = widgets.FloatSlider(
    description = 'Threshold', 
    min = 0, max=max(abs(in1.intu)), step=in1.dt, 
    value=max(in1.intu)/2,
    readout_format='.4f',
    continuous_update=False, layout=Layout(width='60%'))

# Make threshold changes by the user re-run encoding
threshSlider.observe(encode_iaf, names='value')

ui = widgets.VBox([
    timeSlider, 
    threshSlider])

out = widgets.interactive_output(plot_t_transform, {
    'time': timeSlider,
    'threshold': threshSlider,
})

display(out, ui)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Output()

VBox(children=(FloatSlider(value=0.09999000000000001, description='Interval t [s]', layout=Layout(width='60%')…