In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import pandas as pd
from numba import jit
import models as md
from ipywidgets import interact, interactive, fixed, interact_manual, Button, HBox, VBox
from IPython.display import display
import ipywidgets as widgets

# Feedforward inhibition model

In [2]:
def plot_trial(exc_scale=9, inh_scale=9, r_m=10, tau_inh=10, rho_null=2, speed_idx=2):
    """Simulate one looming-stimulus trial of the feedforward-inhibition model and plot it.

    Arguments arrive in slider-friendly units and are rescaled below before
    being passed to ``md.jit_ffi_model``.

    Parameters
    ----------
    exc_scale : float
        Excitatory gain; the transformed stimulus is multiplied by ``exc_scale*1e-11``.
    inh_scale : float
        Inhibitory gain; passed to the model as ``rho_scale = inh_scale*1e6``.
    r_m : float
        Membrane resistance in MOhm (converted to Ohm below).
    tau_inh : float
        Inhibitory time constant, divided by 1000 before use (presumably
        ms -> s, since ``tau_m`` and ``dt`` below look like seconds).
    rho_null : float
        Baseline inhibitory activity, divided by 1000 before use.
    speed_idx : int
        Index (0-5) into ``LV_vals`` selecting the approach speed.

    Prints the approach speed and the stimulus angle / distance at the first
    spike (or a notice if the trial produced no spike), and draws a 5-panel figure.
    """
    tau_m = 0.023
    e_l = -0.079
    r_m = r_m*1e6  # MOhm -> Ohm
    v_t = -0.061
    init_vm_std = 0.000
    vt_std = 0.000

    rho_null = rho_null/1000
    tau_inh = tau_inh/1000
    rho_scale = inh_scale*1e6

    dt = 0.0001
    total_time = 5
    init_period = 2

    # Noise is currently disabled (std = 0) but kept so it can be switched back on.
    noise_std_exc = 0*1e-3
    noise_std_inh = 0*1e-3
    n_timepoints = int((total_time+init_period)/dt)

    # generate looming stimulus
    LV_vals = np.array([0.19, 0.38, 0.56, 0.74, 0.93, 1.11])
    stim_size = 10
    speeds = 1/(LV_vals/stim_size)
    speed = speeds[speed_idx]
    cutoff_angle = 180
    print('Approach speed: ' + str(speed))

    m = 5.0
    b = 0

    t, stims, tstims, dists, t_to_collision, transformed_stim_to_collision = md.transform_stim(stim_size, speed,
                                                                                               total_time, dt, m,
                                                                                               b, init_period,
                                                                                               cutoff_angle)

    stimulus = tstims*exc_scale*1e-11
    # Per-step noise amplitude scales with sqrt(dt).
    sigma_exc = noise_std_exc * np.sqrt(dt)
    sigma_inh = noise_std_inh * np.sqrt(dt)
    noise_exc = np.random.normal(loc=0.0, scale=sigma_exc, size=n_timepoints)
    noise_inh = np.random.normal(loc=0.0, scale=sigma_inh, size=n_timepoints)
    time, v_m, spks, spk_idc, rho_inh = md.jit_ffi_model(tau_m, e_l, r_m, stimulus, noise_exc, noise_inh, v_t, dt,
                                                      total_time, init_vm_std, vt_std, rho_null, tau_inh, rho_scale,
                                                      init_period)
    has_spike = len(spks) > 0

    fig, axes = plt.subplots(5, 1, figsize=(6, 12))
    axes[0].plot(time, stims)
    axes[0].set_title(r'stimulus angle [$ \degree $]')
    axes[1].plot(time, rho_inh)
    axes[1].set_title('inhibitory population activity')
    axes[2].plot(time, stimulus*r_m)
    axes[2].set_title('stimulus*r_m')
    axes[3].plot(time, stimulus*r_m - rho_inh)
    # Threshold line at v_t - e_l (= 0.018, the value formerly hardcoded);
    # presumably the input level at which the steady-state potential e_l + input reaches v_t.
    axes[3].hlines(v_t - e_l, time[0], time[-1], 'k')
    axes[3].set_title('effective input (stimulus*r_m - inhibition)')
    axes[4].plot(time, v_m)
    axes[4].set_title('membrane potential')
    if has_spike:
        axes[4].plot(spks, np.ones(len(spks))*v_t, 'r*')
    plt.subplots_adjust(hspace=0.5)

    if has_spike:
        first_spike_idx = spk_idc[0]
        print('Response angle at first spike: ' + str(stims[first_spike_idx]) + ' degree')
        print('Distance at first spike: ' + str(dists[first_spike_idx]) + ' mm')
    else:
        # Previously this fell back to index 0 and reported the initial angle as a response.
        print('No spike elicited in this trial')

In [3]:
# Parameter sliders for plot_trial; continuous_update=False re-runs the
# simulation only when a slider is released, not while it is dragged.
exc_scale_slider = widgets.FloatSlider(min=1, max=200, step=1, value=33, continuous_update=False)
inh_scale_slider = widgets.FloatSlider(min=1, max=20, step=0.2, value=9.6, continuous_update=False)
rm_slider = widgets.FloatSlider(min=1, max=20, step=0.2, value=10, continuous_update=False)
# step=1 so the default of 10 lies on the slider grid (min=1 with step=2 only reaches odd values).
tau_inh_slider = widgets.IntSlider(min=1, max=50, step=1, value=10, continuous_update=False)
rho_null_slider = widgets.FloatSlider(min=0, max=50, step=1, value=10, continuous_update=False)
speed_idx_slider = widgets.IntSlider(min=0, max=5, step=1, value=2, continuous_update=False)
plotgroup = interactive(plot_trial, exc_scale=exc_scale_slider, inh_scale=inh_scale_slider, r_m=rm_slider,
                        tau_inh=tau_inh_slider, rho_null=rho_null_slider, speed_idx=speed_idx_slider)

# Refresh button: re-runs plot_trial with the current slider values.
button = widgets.Button(description='Refresh')
def on_button_clicked(b, group=plotgroup):
    # `group` is bound when this cell runs. A plain global lookup of `plotgroup`
    # would pick up the widget a later cell rebinds under the same name, so this
    # button would refresh the wrong panel.
    group.update()
button.on_click(on_button_clicked)

# combine sliders and button
allgroups = HBox(children=[plotgroup, button])

In [4]:
# Display the plot_trial sliders and Refresh button built in the previous cell.
allgroups

A Jupyter Widget

# Analyzing the effects of parameters of the inhibitory population

In [8]:
# Preuss et al. (2006) measurements overlaid on the simulated response properties.
PREUSS2006_LV = np.array([0.075, 0.036, 0.02, 0.044, 0.055, 0.11, 0.03])
PREUSS2006_LATENCY = np.array([0.1, 0.19, 0.22, 0.24, 0.3, 0.42, 0.7])
PREUSS2006_THETA = np.array([28, 24, 14, 21, 19, 22, 16])

# Values returned by md.calc_response_ffi, in return order, paired with the
# number of decimals each one is rounded to.
_RESPONSE_ROUNDING = [('resp_angle', 1), ('resp_dist', 1), ('resp_time', 3), ('lv', 2),
                      ('stim_size', 1), ('speed', 1), ('resp_time_coll', 3)]


def _simulate_responses(params, nruns):
    """Call md.calc_response_ffi(params) `nruns` times; return the rounded results as a DataFrame."""
    columns = [col for col, _ in _RESPONSE_ROUNDING]
    records = []
    for _ in range(nruns):
        values = tuple(md.calc_response_ffi(params))
        # Keep the strictness of the original 7-way tuple unpacking.
        if len(values) != len(_RESPONSE_ROUNDING):
            raise ValueError('calc_response_ffi returned %d values, expected %d'
                             % (len(values), len(_RESPONSE_ROUNDING)))
        records.append({col: np.round(value, decimals=decimals)
                        for (col, decimals), value in zip(_RESPONSE_ROUNDING, values)})
    return pd.DataFrame(records, columns=columns)


def _plot_response_props(df, init_period):
    """Draw the 2x2 grid of simulated response properties with Preuss et al. (2006) data overlaid."""
    # NOTE: sns.set changes the global matplotlib rcParams, so figures drawn
    # later in the notebook also use the 'poster' context.
    sns.set('poster')

    fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(16, 16))

    # response distance vs. L/V
    sns.regplot(x='resp_dist', y='lv', data=df, fit_reg=False, ax=axes[0, 0])
    axes[0, 0].set_ylim([0, 1.3])
    axes[0, 0].set_xlim([0, 50])

    # response time vs. L/V; measured latencies are shifted by init_period,
    # presumably to align them with the simulation's time axis
    sns.regplot(x='resp_time', y='lv', data=df, fit_reg=False, ax=axes[0, 1])
    axes[0, 1].plot(PREUSS2006_LATENCY + init_period, PREUSS2006_LV, 'k.', ms=20)

    # response angle vs. L/V
    sns.regplot(x='lv', y='resp_angle', data=df, fit_reg=False, ax=axes[1, 0])
    axes[1, 0].plot(PREUSS2006_LV, PREUSS2006_THETA, 'k.', ms=20)
    axes[1, 0].set_ylim([0, 180])
    axes[1, 0].set_xlim([0, 1.3])

    # response time relative to collision vs. L/V
    sns.regplot(x='resp_time_coll', y='lv', data=df, fit_reg=False, ax=axes[1, 1])
    axes[1, 1].set_ylim([0, 1.3])
    axes[1, 1].set_xlim([-5, 0])
    return fig


def plot_response_props(exc_scale, inh_scale, vt_std, rho_null, rho_null_std, tau_inh, cutoff_angle, exc_noise, m):
    """Run the FFI model on 250 looming trials and plot the resulting response properties.

    Builds the parameter dict for ``md.calc_response_ffi``, calls it 250 times
    (each call presumably draws its own stimulus within the lv/l bounds below;
    TODO confirm in models.py), and plots the rounded results.

    Parameters
    ----------
    exc_scale : float
        Excitatory gain, passed through unchanged.
    inh_scale : float
        Inhibitory gain, passed as ``rho_scale = inh_scale*1e6``.
    vt_std : float
        Threshold jitter, divided by 1000 (presumably mV -> V).
    rho_null : float
        Baseline inhibition, passed unscaled. NOTE(review): plot_trial divides
        rho_null by 1000; confirm whether calc_response_ffi rescales it.
    rho_null_std : float
        Spread of the baseline inhibition, passed unscaled.
    tau_inh : float
        Inhibitory time constant, divided by 1000 (presumably ms -> s).
    cutoff_angle : float
        Passed through as ``params['cutoff_angle']``.
    exc_noise : float
        Excitatory noise std, divided by 1000.
    m : float
        Passed through as ``params['m']``.
    """
    params = {'tau_m': 0.023,
              'e_l': -0.079,
              'r_m': 10*1e6,  # 10 MOhm expressed in Ohm
              'v_t': -0.061,
              'init_vm_std': 0.0,
              'vt_std': vt_std/1000,
              'rho_null': rho_null,
              'rho_null_std': rho_null_std,
              'tau_inh': tau_inh/1000,
              'rho_scale': inh_scale*1e6,
              'exc_scale': exc_scale,
              'dt': 0.0005,
              'total_time': 10,
              'init_period': 2,
              'cutoff_angle': cutoff_angle,
              'noise_std_exc': exc_noise/1000,
              'noise_std_inh': 0*1e-3,
              'm': m,
              'b': 0,
              'lv_min': 0.02,
              'lv_max': 1.2,
              'l_min': 6,
              'l_max': 45,
              'init_distance': 160}
    nruns = 250

    df = _simulate_responses(params, nruns)
    _plot_response_props(df, params['init_period'])

In [9]:
# Parameter sliders for plot_response_props; continuous_update=False avoids
# re-running the 250-trial simulation while a slider is being dragged.
exc_scale_slider = widgets.FloatSlider(min=1, max=200, step=1, value=33, continuous_update=False)
inh_scale_slider = widgets.FloatSlider(min=1, max=20, step=0.2, value=9.6, continuous_update=False)
vt_std_slider = widgets.FloatSlider(min=0, max=5, step=1, value=1, continuous_update=False)
rho_null_slider = widgets.FloatSlider(min=0, max=10, step=0.5, value=2, continuous_update=False)
rho_null_std_slider = widgets.FloatSlider(min=0, max=5, step=0.1, value=0.2, continuous_update=False)
# step=0.05 so the default of 1 lies on the slider grid (min=0.05 with step=0.5 gives 0.05, 0.55, 1.05, ...).
tau_inh_slider = widgets.FloatSlider(min=0.05, max=25, step=0.05, value=1, continuous_update=False)
cutoff_slider = widgets.FloatSlider(min=120, max=180, step=10, value=150, continuous_update=False)
exc_noise_slider = widgets.FloatSlider(min=0, max=20, step=1, value=5, continuous_update=False)
m_slider = widgets.FloatSlider(min=1, max=6, step=0.5, value=3, continuous_update=False)

plotgroup = interactive(plot_response_props, exc_scale=exc_scale_slider, inh_scale=inh_scale_slider,
                        vt_std=vt_std_slider, rho_null=rho_null_slider, rho_null_std=rho_null_std_slider,
                        tau_inh=tau_inh_slider, cutoff_angle=cutoff_slider, exc_noise=exc_noise_slider, m=m_slider)

# Refresh button: re-runs plot_response_props with the current slider values
# (useful because each run draws new noise).
button = widgets.Button(description='Refresh')
def on_button_clicked(b, group=plotgroup):
    # `group` is bound when this cell runs. A global lookup of `plotgroup` would
    # follow any later rebinding of that name and refresh the wrong panel.
    group.update()
button.on_click(on_button_clicked)

# combine sliders and button
allgroups = HBox(children=[plotgroup, button])

In [10]:
# Display the plot_response_props sliders and Refresh button built in the previous cell.
allgroups

A Jupyter Widget