In [None]:
import math
import os

import numpy as np
import h5py
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib import colormaps
from matplotlib.colors import ListedColormap
from scipy.optimize import curve_fit
from scipy.special import lambertw

In [None]:
# Input: HDF5 file holding all pre-computed data used by the figures below
filename_plot_data = 'data/plot_data.h5'
# Output folder for the PDF figures; file names are appended by string concatenation,
# so the trailing '/' is required
plots_folder_name = 'plots/'

# Set flag to save figures to plots_folder_name
save_fig = True

# savefig() does not create missing folders, so make sure the output folder exists up front
if save_fig:
    os.makedirs(plots_folder_name, exist_ok=True)

# Load data

In [None]:
# Selects the n_RF axis label of the t_resp plots below (hat-n_RF unless 'n_RF_tar_scan')
n_RF_reference = 'n_RF_tar_scan' # Can be 'n_RF_tar_scan', 'n_ff_syn_mostSignificant_scan' or 'n_ff_syn_mostSignificant_scan_old'


""" Loading the data from the HDF5 file for plotting and analysis. """


# Every dataset is read completely with [:] so the arrays remain valid after the file is closed.
# Naming used throughout the notebook:
#   "vanilla" = '*_excitatory' groups (net interactions only, plotted with colors['net'])
#   "lagged"  = '*_balanced' groups (with lagged inhibition, plotted with colors['full'])
with h5py.File(filename_plot_data, 'r') as f:
    # Load 1D vanilla data: measured response times and receptive-field sizes nRF
    g = f['1D_excitatory']
    tresp_data_1D_vanilla = g['tresp_net'][:]
    nff_data_1D_vanilla = g['nRF'][:]

    # Load 1D lagged data (tresp_full has several columns; the figures use column 0).
    # n_lag * dt presumably is the inhibitory lag tau_lag -- TODO confirm.
    g = f['1D_balanced']
    tresp_data_1D_lagged_values = g['tresp_full'][:]
    nff_data_1D_lagged = g['nRF'][:]
    dt_1D_lagged = g.attrs['dt']
    n_lag_1D_lagged = g.attrs['n_lag']

    # Load 2D linear mixed vanilla data
    g = f['2D_linear_mixed_excitatory']
    tresp_data_2D_linear_vanilla = g['tresp_net'][:]
    nff_data_2D_linear_vanilla = g['nRF'][:]

    # Load 2D linear mixed lagged data; x_target_scan holds the target patterns
    # (index 0 is drawn in the 2D lin.MS schematic)
    g = f['2D_linear_mixed_balanced']
    tresp_data_2D_linear_lagged_values = g['tresp_full'][:]
    nff_data_2D_linear_lagged = g['nRF'][:]
    x_target_scan = g['x_target_scan'][:]
    n_lag_2D_linear_lagged = g.attrs['n_lag']

    # Load 2D nonlinear mixed vanilla data
    g = f['2D_excitatory']
    tresp_data_2D_nonlinear_vanilla = g['tresp_net'][:]
    nff_data_2D_nonlinear_vanilla = g['nRF'][:]

    # Load 2D nonlinear mixed lagged data
    # NOTE(review): the other '*_balanced' groups read 'tresp_full' but this one reads
    # 'tresp_net' -- confirm this is the intended dataset.
    g = f['2D_balanced']
    tresp_data_2D_nonlinear_lagged_values = g['tresp_net'][:]
    nff_data_2D_nonlinear_lagged = g['nRF'][:]
    x_target_scan_2D = g['x_target_scan'][:]
    n_lag_2D_nonlinear_lagged = g.attrs['n_lag']

    # Load activity propagation data. 'activity' only contains every
    # n_store_every_tsteps-th simulation step (to reduce file size); the loss curves are
    # indexed with plain dt steps in the figures.
    g = f['activity_propagation']
    states_actProp_vanilla = g['activity'][:]
    loss_curve_vanilla_actProp = g['loss_curve'][:]
    loss_curve_lagged_actProp = g['loss_curve_balanced'][:]
    n_store_every_tsteps = g.attrs['n_store_every_tsteps']
    tau_actProp = g.attrs['tau']
    dt_actProp = g.attrs['dt']
    t_resp_net_actProp = g.attrs['t_resp_net']
    t_resp_full_actProp = g.attrs['t_resp_full']
    # Measured counterpart of t_resp_full (both are marked in the Fig. 6 loss panel)
    t_resp_full_experimental_actProp = g.attrs['t_resp_bal_experimental']

    # Load SFA (spike frequency adaptation) data: best response time per nRF
    g = f['SFA']
    t_respBest_data_SFAtf = g['t_resp_SFA_best'][:]
    n_RF_vals_SFAtf = g['nRF'][:]
    tau_SFA = g.attrs['tau_SFA']

    # Load loss evolution data around critical EI-balance
    # (network-simulation loss curves; they are normalized to their first entry in Fig. 5)
    g = f['loss_evolution_around_critical']

    loss_curve_critical = g['loss_curve_critical'][:]
    loss_curve_overdamped = g['loss_curve_overdamped'][:]
    loss_curve_underdamped = g['loss_curve_underdamped'][:]
    dt_around_critical = g.attrs['dt']
    tau_around_critical = g.attrs['tau']

In [None]:
# Analytic expressions and analysis functions

def get_tresp_nff(d, tau=1):
    """Exact response time of the 1D net-interaction model for receptive-field half-width d.

    Parameters
    ----------
    d : float or np.ndarray
        Half-width of the receptive field. d = 0 is not supported (division by zero).
    tau : float
        Neuronal time constant; the response time is returned in the same units.

    Returns
    -------
    (t_resp, n_ff) : tuple
        t_resp = tau / (1 - w_rec_sum), where w_rec_sum = 2 / (gamma + 1/gamma) and
        gamma = exp(-1/d). n_ff = 2 d + 1 is the matching receptive-field size.
    """
    decay = np.exp(-1 / d)
    summed_recurrent_weight = 2 / (decay + decay ** -1)
    response_time = tau / (1 - summed_recurrent_weight)
    receptive_field_size = 2 * d + 1
    return response_time, receptive_field_size

def get_tresp_appr(d, tau=1):
    """Closed-form approximation tau * (1 + 2 d^2) of the 1D net response time.

    The exact expression (get_tresp_nff) is tau / (1 - 1/cosh(1/d)), which grows like
    2 d^2 tau for large d. This quadratic approximation captures that scaling
    (t_resp ~ n_RF^2 / 2 with n_RF = 2 d + 1).

    Parameters
    ----------
    d : float or np.ndarray
        Receptive-field half-width. d = 0 is allowed and gives tau.
    tau : float
        Neuronal time constant.

    Returns
    -------
    float or np.ndarray
        Approximate response time, same shape as d.
    """
    # No exp(-1/d) is needed here. Computing it anyway (as before) raised ZeroDivisionError
    # for d = 0, although the formula itself is finite there.
    return tau * (1 + 2*d**2)

def analyze_state_dynamics(states, thr_lines_at, dt, tau, x_speed_ref=None, n_steps_use=None):
    """First threshold-crossing time of every unit, for several relative thresholds.

    Parameters
    ----------
    states : np.ndarray, shape (n_steps, n_units)
        Activity trace; rows are time steps.
    thr_lines_at : iterable of float
        Thresholds as fractions of the reference activity.
    dt, tau : float
        Time step and time constant; crossing indices are converted with dt / tau.
    x_speed_ref : np.ndarray, shape (n_units,), optional
        Reference activity per unit. Defaults to the last row of the *full* `states`.
    n_steps_use : int, optional
        Only the first n_steps_use rows are searched (default: all).

    Returns
    -------
    np.ndarray, shape (len(thr_lines_at), n_units)
        Index of the first row where states > thr * reference, times dt / tau.
        Units that never cross get 0 (np.argmax of an all-False column).
    """
    data = states.copy()
    total_steps, _ = data.shape
    reference = data[-1] if x_speed_ref is None else x_speed_ref
    search_until = total_steps if n_steps_use is None else n_steps_use
    window = data[:search_until]

    crossings = [
        np.argmax(window > thr * reference[np.newaxis, :], axis=0) * dt / tau  # in units of tau
        for thr in thr_lines_at
    ]
    return np.array(crossings)

In [None]:
# --- 1D theory curves ---------------------------------------------------------
# Half-width d of the 1D receptive field, from n_RF = 2 d + 1
d_vals_1D_vanilla = (nff_data_1D_vanilla - 1) / 2

# Exact net-interaction prediction (response time and matching n_RF)
tresp_data_1D_vanilla_theory, nff_data_1D_vanilla_theory = get_tresp_nff(d_vals_1D_vanilla)
# The lagged model is evaluated on the same n_RF grid
nff_data_1D_lagged_theory = nff_data_1D_vanilla_theory

# Quadratic approximation of the net response time ...
tresp_appr_data_1D_vanilla_theory = get_tresp_appr(d_vals_1D_vanilla)
# ... and the lagged estimate sqrt(t_resp_net * tau_lag / 2) with tau_lag = n_lag * dt
tresp_appr_data_1D_lagged_theory = np.sqrt(
    tresp_appr_data_1D_vanilla_theory * n_lag_1D_lagged * dt_1D_lagged / 2
)

# Manual plots

In [None]:
# ===== Design: font sizes, color palette and activity colormaps shared by all figures =====

SMALL_SIZE = 12
MEDIUM_SIZE = 18
BIGGER_SIZE = 20

plt.rcParams.update({
    'font.size': SMALL_SIZE,            # default text size
    'axes.titlesize': SMALL_SIZE + 4,   # axes titles
    'axes.labelsize': MEDIUM_SIZE,      # x and y labels
    'xtick.labelsize': SMALL_SIZE,      # tick labels
    'ytick.labelsize': SMALL_SIZE,
    'legend.fontsize': SMALL_SIZE,      # legends
    'figure.titlesize': BIGGER_SIZE,    # figure (sup)titles
})

# Base palette: blue/red taken from the diverging RdBu_r map, the rest as RGBA hex strings
cmap_base = colormaps['RdBu_r']
colors = {
    'blue': cmap_base(0.1),
    'red': cmap_base(0.9),
    'orange': '#ff7f0eff',
    'brown': '#c54700ff',
    'black': '#000000ff',
    'gray': '#7f7f7fff',
}

# Activity colormap: OrRd restricted to [v_min, v_max]. The source map is first resampled to
# n_samples ~ 256 / (v_max - v_min) entries so that roughly 256 distinct colors remain in the
# retained range, which is then re-sampled at 256 points.
cmap_name = 'OrRd'
cmap_activity_ = colormaps[cmap_name]
v_max = 1.
v_min = 0.025
n_samples = int(256/(v_max-v_min))
print(f"n_samples = {n_samples}")
cmap_activity_vals = cmap_activity_.resampled(n_samples)
cmap_activity = ListedColormap(cmap_activity_vals(np.linspace(v_min, v_max, 256)))
# Coarse 20-level version, used for the activity-propagation image
cmap_activity_discrete = cmap_activity_.resampled(20)

# Semantic colors derived from the base palette
colors.update({
    'x_in_inactive': '#d3d3d3',        # Inactive input cells
    'FF_conns': colors['blue'],        # Feedforward connections
    'REC_conns': colors['black'],      # Recurrent connections
    'REC_input': colors['black'],      # Recurrent inputs
    'net': colors['blue'],             # Net interactions only
    'full': colors['red'],             # Full interactions
    'SFA': colors['orange'],           # Spike frequency adaptation
    'expl_inh': colors['orange'],      # Exploratory inhibition
})

In [None]:
def get_tresp_and_wrisc(wrns, tau, tau_lag, verbose=True):
    """Critical inhibitory weight sum and the resulting response time for lagged inhibition.

    With a = tau_lag / t_resp_net, the dimensionless critical weight x = (tau_lag/tau) * w_rec_I_sum
    is the principal (k=0) Lambert-W solution of x * exp(x) = -exp(-1 - a). For that x, the
    argument x * exp(a + x) of the roots in get_lt_ot_LambertW equals -1/e, i.e. the k=0 and
    k=-1 roots coincide (critical damping).

    Parameters
    ----------
    wrns : float
        Net recurrent weight sum (< 1 for a finite t_resp_net = tau / (1 - wrns)).
    tau, tau_lag : float
        Neuronal time constant and inhibitory lag.
    verbose : bool
        Print the (complex, dimensionless) Lambert-W solution x.

    Returns
    -------
    t_resp_full : float
        Response time at critical balance, 1 / (1/t_resp_net + x/tau_lag + 1/tau_lag).
    wris_crit : float
        Critical inhibitory weight sum in simulation units, (tau/tau_lag) * Re(x).
    t_dec_I : float
        tau / (1 + Re(x)).
    """
    t_resp_net = tau/(1-wrns)
    x_crit = lambertw(-np.exp(-1-tau_lag/t_resp_net), k=0)
    t_resp_full = 1 / (1/t_resp_net + x_crit/tau_lag + 1/tau_lag)
    t_dec_I = tau/(1+x_crit)

    if verbose:
        print(f"wris_crit for wrns={wrns}, tau={tau} and tau_lag={tau_lag} ")
        print(f"is {x_crit}")

    wris_crit = (tau/tau_lag)*np.real(x_crit)
    return np.real(t_resp_full), wris_crit, np.real(t_dec_I)


## Fig. 3) Response at different times and 1D activity propagation

In [None]:
hide_spines = True

use_relative_activity = True   # express activity as % of each neuron's steady state
neuron_range = range(50, 155)  # neurons shown in both panels

fig_3, axes = plt.subplots(1, 2, figsize=(10, 4))
fig_3.subplots_adjust(left=0.08, bottom=0.15, right=0.99, top=0.99, wspace=0.4, hspace=0.4)

# states_actProp_vanilla only holds every n_store_every_tsteps-th time step (to reduce file
# size), so one stored frame corresponds to this much time in units of tau.
frame_to_tau = n_store_every_tsteps*dt_actProp/tau_actProp

# ----- axes[0]: response at different times -----
ax = axes[0]
times = np.array([1., 10, 50, 200])[::-1]   # In units of tau
times *= tau_actProp/dt_actProp             # In units of dt (time steps)
times = times / n_store_every_tsteps        # In units of stored frames
times = times.astype(int)

x_steady_max = np.max(states_actProp_vanilla[-1,:])
# Plot target (final state)
ax.plot(neuron_range, states_actProp_vanilla[-1, neuron_range] / x_steady_max, c='black', ls='solid', label='Target')
for i_time, time in enumerate(times):
    state = states_actProp_vanilla[time, neuron_range] / x_steady_max
    d_state = len(state)
    # `time` is a stored-frame index: convert it back to simulation time for the label
    # (using dt/tau alone under-reported the time by a factor n_store_every_tsteps).
    ax.plot(neuron_range, state, color=colormaps['Greys'](0.2+0.8*state[d_state//2]), label=f't={time*frame_to_tau:.0f}'+r'$\tau$', zorder=-i_time)
ax.set_xlabel('Neuron index')
ax.set_ylabel('Response')
ax.set_xlim([neuron_range[0], neuron_range[-1]])
ax.legend(loc='upper right', frameon=True, fontsize=SMALL_SIZE)
# Remove frame (top and right)
if hide_spines:
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)


# ----- axes[1]: activity propagation (neurons x time) -----
ax = axes[1]
plot_limit_vanilla = int(3*(t_resp_net_actProp/dt_actProp)/n_store_every_tsteps)  # 3 t_resp_net, in stored frames
n_steps = plot_limit_vanilla
states = states_actProp_vanilla[:n_steps, neuron_range]

d_state = len(states[0])
# Divide by 1 % of the steady state -> activity in percent of the steady state
x_relative_to_ = states_actProp_vanilla[-1, neuron_range]/100 if use_relative_activity else np.ones(d_state)
# 1D vanilla, time axis in units of tau
t_grid = np.arange(n_steps)*(dt_actProp/tau_actProp)*n_store_every_tsteps

# Use imshow instead of pcolormesh
pc = ax.imshow(states.T/x_relative_to_[:,np.newaxis], cmap=cmap_activity_discrete, vmin=0, vmax=100, aspect='auto', extent=[0, t_grid[-1], neuron_range[-1], neuron_range[0]])

ax.set_xlabel(r'Time ($\tau$)')
ax.set_ylabel('Neuron index')
label_ = r'Response (%)' if use_relative_activity else 'Activity'
cbar = plt.colorbar(pc, ax=ax)
cbar.set_label(label_)
# Half-activation front: first time each neuron exceeds 50 % of its steady state.
# analyze_state_dynamics returns frame index * dt/tau, hence the extra n_store_every_tsteps factor.
crossing_times = analyze_state_dynamics(states_actProp_vanilla[:,neuron_range], [0.5,], dt_actProp, tau_actProp)
i_fastest = np.argmin(crossing_times)
times_crossing_up = crossing_times[0,i_fastest:]*n_store_every_tsteps
times_crossing_down = crossing_times[0,:i_fastest+1]*n_store_every_tsteps
ax.plot(times_crossing_up, np.arange(d_state-i_fastest)+i_fastest+neuron_range[0], c='red', ls='dashed')
ax.plot(times_crossing_down, np.arange(i_fastest+1)+neuron_range[0], c='red', ls='dashed')
ax.set_xlim([0, t_grid[-1]])

if save_fig:
    file_store_to = plots_folder_name + "activityPropagation.pdf"
    print("\nSave fig to {}".format(file_store_to))
    fig_3.savefig(file_store_to)

plt.show()

## Fig. 4) Loss curve & $\tau_\mathrm{resp}$ vs $n_\mathrm{RF}$

In [None]:
# Tick positions/labels and label padding for this figure
x_ticks = [4, 10, 40, 100]
y_ticks_0 = [1, np.exp(-1), np.exp(-2), np.exp(-3)]
y_ticks_labels_0 = ['1', r'$e^{-1}$', r'$e^{-2}$', r'$e^{-3}$']
y_ticks_1 = [1, 10, 100, 1_000]
y_ticks_2 = [1, 10]
x_labelpad = -5
y_labelpad = [-2, 0]
hide_spines = True

# (not used in this cell)
use_relative_activity = True
neuron_range = range(50, 155)

fig_4, axes = plt.subplots(1, 2, figsize=(10, 4))
fig_4.subplots_adjust(left=0.08, bottom=0.15, right=0.99, top=0.99, wspace=0.4, hspace=0.4)


""" axes[0] """
ax = axes[0] # loss curve
# 1D vanilla (net): plot 3 t_resp_net worth of loss-curve steps, normalized to the initial loss
plot_limit_vanilla = int(3*t_resp_net_actProp/dt_actProp)
# plot_limit_lagged = int(3*t_resp_full_actProp/dt_actProp)
n_steps = plot_limit_vanilla
t_grid = np.arange(n_steps)*dt_actProp/tau_actProp
ax.plot(t_grid, loss_curve_vanilla_actProp[:n_steps]/loss_curve_vanilla_actProp[0], c=colors['net'])
# Mark t_resp_net where the normalized loss should have dropped to 1/e.
# NOTE(review): t_grid is in units of tau, while t_resp_net_actProp is treated as being in dt
# units above (3*t_resp/dt steps); both agree only if tau_actProp == 1 -- confirm.
ax.axvline(t_resp_net_actProp, c='gray', ls='solid', alpha=0.5)
ax.axhline(np.exp(-1), c='gray', ls='solid', alpha=0.5)
ax.scatter([t_resp_net_actProp,], [np.exp(-1)], facecolors='none', edgecolors=colors['net'])
ax.set_xlabel(r'Time ($\tau$)')
ax.set_ylabel('Loss')
ax.set_yscale('log')
ax.set_yticks(y_ticks_0, labels=y_ticks_labels_0)
ax.minorticks_off()
ax.set_xlim([0, 210])
ax.set_ylim([0.9*np.exp(-2), 1.1])




""" axes[1] """
# x-axis label depends on which n_RF definition the data refers to
n_RF_name = r'$n_{\mathrm{RF}}$' if n_RF_reference=='n_RF_tar_scan' else r'$\hat{n}_{\mathrm{RF}}$'
pows_show = [0]  # (not used in this cell)
ax = axes[1] # t_resp vs. n_ff
ax.set_title(r"$\sim n_\mathrm{RF}^2$")
# 1D vanilla: measured points and the quadratic approximation (dotted)
ax.plot(nff_data_1D_vanilla, tresp_data_1D_vanilla, ls='none', marker='o', c=colors['net'], markerfacecolor='none', label=r'$\tau_{\mathrm{resp}}^{\mathrm{1D},\mathrm{net}}$')
ax.plot(nff_data_1D_vanilla_theory, tresp_appr_data_1D_vanilla_theory, ls=':', c=colors['net'])
# (1D SFA data is shown in Fig. 6, not here)
ax.set_xlabel(n_RF_name)
ax.set_ylabel(r'$\tau_{\mathrm{resp}}$') 
ax.xaxis.labelpad = x_labelpad
ax.yaxis.labelpad = y_labelpad[0]
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xticks(x_ticks, labels=x_ticks)
ax.set_yticks(y_ticks_1, labels=y_ticks_1)
if hide_spines:
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)


# Inset (figure coordinates): t_resp / n_RF, which grows linearly for t_resp ~ n_RF^2
left, bottom, width, height = [0.785, 0.25, 0.2, 0.3]
ax_inset = fig_4.add_axes([left, bottom, width, height])
# 1D vanilla
ax_inset.plot(nff_data_1D_vanilla, tresp_data_1D_vanilla/nff_data_1D_vanilla, ls='none', marker='o', c=colors['net'], markerfacecolor='none', label=r'$\tau_{\mathrm{resp}}^{\mathrm{1D},\mathrm{net}}$')
ax_inset.plot(nff_data_1D_vanilla_theory, tresp_appr_data_1D_vanilla_theory/nff_data_1D_vanilla_theory, ls=':', c=colors['net'])

ax_inset.set_ylabel(r'$\tau_{\mathrm{resp}}/$'+n_RF_name)
ax_inset.xaxis.labelpad = x_labelpad
ax_inset.yaxis.labelpad = y_labelpad[1]
ax_inset.set_xscale('log')
ax_inset.set_yscale('log')
ax_inset.set_xticks(x_ticks[:-1], labels=x_ticks[:-1])
ax_inset.set_yticks(y_ticks_2, labels=y_ticks_2)
# Hide top/right spines on all panels, including the inset
if hide_spines:
    for ax in [axes[0], axes[1], ax_inset]:
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)



if save_fig:
    file_store_to = plots_folder_name + "tresp_1D_net.pdf"
    print("\nSave fig to {}".format(file_store_to))
    fig_4.savefig(file_store_to)

plt.show()

## Fig. 5) Loss evolution around the critical E/I balance: simulations and analytic solutions

In [None]:
# Parameters of the analytic loss-evolution scan (time in units of tau)
tau, tau_lag = 1., 0.1
t_dec_net = 100                      # decay time of the net interaction
n_vals = 1110                        # resolution of the inhibitory-weight scan

# Dimensionless inhibitory weight sum (tau_lag/tau) * w_rec_I_sum, scanned across the critical value
wris_vals = np.linspace(-0.9, -1.01, n_vals)
t_dec_I_vals = 1. / (1 + wris_vals)

# Net recurrent weight sum chosen so that t_resp_net = tau * t_dec_net
wrns = 1 - 1/t_dec_net
t_resp_net = tau / (1 - wrns)
c1 = tau_lag / t_resp_net            # ratio of inhibitory lag to net response time

def get_lt_ot_LambertW(wrns, wris, wris_crit, tau, tau_lag):
    """Decay rates and oscillation frequency of the lagged loss dynamics via Lambert W.

    With a = tau_lag / t_resp_net and b = wris, the roots
    mu_k = a + b - W_k(b * exp(a + b)) solve mu = a + b - b * exp(mu), where
    mu = lambda*tau_lag + i*omega*tau_lag. The k=0 branch gives the slow root and the
    k=-1 branch the fast one.

    Parameters
    ----------
    wrns : float
        Net recurrent weight sum; t_resp_net = tau / (1 - wrns).
    wris : float
        Dimensionless inhibitory weight sum, (tau_lag/tau) * w_rec_I_sum.
    wris_crit : any
        Unused; kept for call compatibility.
    tau, tau_lag : float
        Neuronal time constant and inhibitory lag.

    Returns
    -------
    (lt_slow, lt_fast, ot) : tuple of float
        lambda*tau_lag of the slow (k=0) and fast (k=-1) roots, and |omega|*tau_lag.

    Raises
    ------
    AssertionError
        If the two branches disagree on |omega|. For negative b they are either both real
        or complex conjugates.
    """
    t_resp_net = tau/(1-wrns)
    shift = tau_lag/t_resp_net + wris
    z = wris*np.exp(shift)

    mu_slow = shift - lambertw(z, k=0, tol=1e-8)
    mu_fast = shift - lambertw(z, k=-1, tol=1e-8)
    lt_slow, lt_fast = np.real(mu_slow), np.real(mu_fast)
    abs_slow, abs_fast = np.abs(np.imag(mu_slow)), np.abs(np.imag(mu_fast))

    assert np.isclose(abs_slow, abs_fast, rtol=1e-05, atol=1e-08), f"np.abs(ot_slow) != np.abs(ot_fast) ({abs_slow} != {abs_fast}; diff = {abs_slow-abs_fast} = {np.abs(abs_slow-abs_fast)/abs_slow*100:.3f}%)"

    return lt_slow, lt_fast, abs_slow

# get_tresp_and_wrisc(wrns, tau, tau_lag, verbose=True) is defined once, in an earlier cell
# (before Fig. 3). An identical copy used to be re-defined here; keeping a single definition
# prevents the two copies from silently drifting apart (the later one would shadow the first).

# Critical inhibitory weight sum (in simulation units, i.e. (tau/tau_lag) * dimensionless)
t_resp_full_crit, wris_crit, t_dec_I_crit = get_tresp_and_wrisc(wrns, tau, tau_lag)

# Slow/fast decay rates and oscillation frequency along the scan of the dimensionless weight.
# get_lt_ot_LambertW asserts on scalar |omega| agreement, so it is evaluated point by point.
# (The unused *_old arrays and c2 of the previous version were dropped.)
wris_data = np.zeros(n_vals)
t_dec_I_data = np.zeros(n_vals)
lt_slow_data = np.zeros(n_vals)
lt_fast_data = np.zeros(n_vals)
ot_data = np.zeros(n_vals)
for i, (wris, t_dec_I_val) in enumerate(zip(wris_vals, t_dec_I_vals)):
    wris_data[i] = wris
    t_dec_I_data[i] = t_dec_I_val
    lt_slow_data[i], lt_fast_data[i], ot_data[i] = get_lt_ot_LambertW(wrns, wris, wris_crit, tau, tau_lag)

# Corresponding time scales (same units as tau_lag): slow/fast decay times and oscillation
# period (inf where omega = 0, i.e. no oscillation)
t_resp_slow_data = tau_lag / lt_slow_data
t_resp_fast_data = tau_lag / lt_fast_data
t_osci_data = 2*np.pi * tau_lag / ot_data

In [None]:
# Explicit-Euler simulation of the loss dynamics with lagged inhibition:
#   tau dL/dt = -(1 - wrns) L(t) - w_I_sum (L(t) - L(t - tau_lag))
# for three inhibitory weight sums around the critical value wris_crit.
tau = 1.
tau_lag = 0.1
dt = 0.0001
n_lag = math.ceil(tau_lag/dt)  # lag in time steps
t_dec_net = 100

n_steps = int(150*tau/dt)
initialization = 'constant' # Can be 'constant' or 'expDecay'
if initialization not in ('constant', 'expDecay'):
    # Otherwise L_data would silently be undefined (or left over from a previous run)
    raise ValueError(f"initialization must be 'constant' or 'expDecay', got {initialization!r}")

wrns = 1 - 1/t_dec_net
t_resp_net = tau/(1-wrns)

# Time axis, identical for all runs: index n_lag is t = 0; indices 0..n_lag hold the history on [-tau_lag, 0]
t_data = (np.arange(n_lag+1+n_steps)-n_lag)*dt

L_data_scan = []

# Order: [over-damped (above critical), critical, under-damped (below critical)]
eps = 0.02 * (tau/tau_lag)
w_I_sum_vals = [wris_crit+eps, wris_crit, wris_crit-eps]
for w_I_sum in w_I_sum_vals:
    print(f'w_net_sum = {wrns}, w_I_sum = {w_I_sum}')

    if initialization=='expDecay':
        # History follows the slow exponential mode L ~ exp(-t/t_resp). The Lambert-W root
        # (k=0 branch, same formula as in get_lt_ot_LambertW) is written in terms of the
        # dimensionless weight (tau_lag/tau)*w_I_sum; w_I_sum itself is in simulation units.
        # Passing it unscaled gave a negative t_resp and an exponentially growing history.
        wris_dimless = (tau_lag/tau)*w_I_sum
        shift = tau_lag/t_resp_net + wris_dimless
        lt_slow = np.real(shift - lambertw(wris_dimless*np.exp(shift), k=0, tol=1e-8))
        t_resp = tau_lag/lt_slow
        L_data = 1 * np.exp(-(np.arange(2*n_lag+1+n_steps)-n_lag)*(dt/tau)/t_resp)
        L_data = np.copy(L_data[n_lag:])
    else:  # 'constant'
        L_data = 1 * np.ones(n_lag+1+n_steps)

    # DL_data[t-1] = L(t-1) - L(t-1-n_lag) is always written right before it is read,
    # so its initial content is irrelevant.
    DL_data = np.zeros(n_lag+1+n_steps)
    for t in range(n_lag+1, n_lag+1+n_steps):
        DL_data[t-1] = L_data[t-1] - L_data[t-1-n_lag]
        L_deriv = - (1-wrns)*L_data[t-1] - w_I_sum*DL_data[t-1]
        L_data[t] = L_data[t-1] + (dt/tau)*L_deriv

    L_data_scan.append(L_data)

In [None]:
"""
Set minima in L_data for the oscillating (under-damped) solution manually to 1e-30.
Slide a window of fixed size over |L_data| and set the entry in the middle of the
window to 1e-30 whenever it is the minimum of that window. Presumably this makes the
zero crossings of the oscillating loss show up as deep dips on the log-scaled plot
instead of being limited by the sampling resolution.
"""

L_data_osci = np.copy(L_data_scan[2])   # index 2 = w_I_sum below critical (under-damped run)
window_size = 10
di = window_size//2
# +1 so that the last complete window [len - window_size, len) is checked as well
for t in range(len(L_data_osci)-window_size+1):
    if di==np.argmin(np.abs(L_data_osci[t:t+window_size])):
        L_data_osci[t+di] = 1e-30


plt.plot(np.abs(L_data_scan[2]), label="before")
plt.plot(np.abs(L_data_osci),  ls='dotted', label="after")
plt.yscale('log')
plt.legend()

L_data_scan[2] = L_data_osci

In [None]:
# Model parameters for the loss-evolution figure (same values as in the simulation cell)
tau, tau_lag, dt = 1., 0.1, 0.0001
n_lag = math.ceil(tau_lag / dt)   # number of time steps spanned by the inhibitory lag
t_dec_net = 100


def brighten(color_name, amount=20, outtype='hex'):
    """Brighten a color by adding amount/256 to each RGB channel (clipped to [0, 1]).

    Parameters
    ----------
    color_name : str or sequence of float
        Any matplotlib color spec: a named color ('blue'), a hex string, or an RGB(A)
        tuple/array. Alpha is dropped, and the color is quantized to 8 bit per channel
        (via its hex representation) before brightening.
    amount : float
        Brightening amount on a 0-256 scale; negative values darken.
    outtype : {'hex', 'tuple'}
        Return a hex string ('#rrggbb') or an (r, g, b) tuple of floats.

    Raises
    ------
    ValueError
        If color_name is not a valid color or outtype is unknown.
    """
    if outtype not in ('hex', 'tuple'):
        # Previously an unknown outtype silently returned None
        raise ValueError(f"outtype must be 'hex' or 'tuple', got {outtype!r}")
    # to_hex accepts every matplotlib color spec and raises ValueError for invalid ones.
    # Previously only named colors and tuples worked; anything else (e.g. a hex string or a
    # numpy array) crashed with KeyError/UnboundLocalError.
    color_hex = mcolors.to_hex(color_name)
    color_rgb = np.array(mcolors.to_rgb(color_hex))
    color_rgb_bright = np.clip(color_rgb + amount/256, 0, 1)
    if outtype == 'tuple':
        return tuple(color_rgb_bright)
    return mcolors.to_hex(color_rgb_bright)
    
def darken(color_name, amount=20):
    """Darken a color by subtracting amount (0-256 scale) from each RGB channel; returns a hex string."""
    return brighten(color_name, amount=-amount)

colors_ = {
    'overdamped' : cmap_activity(0.4),
    'critical' : brighten('blue', 0, outtype='tuple'),
    'underdamped' : 'teal',
    'real' : 'black',
    'imag' : 'red' ,
    'approx_ratio' : 'red',
    'exact' :  'black',
    'approx' : 'gray'
}

# Critical inhibitory weight sum (simulation units) for the parameters above; w_I_sum_vals of
# the simulation cell were built around the same value, which the classification below uses.
# (Font sizes are set once in the Design cell.)
wrns = 1 - 1/t_dec_net
t_resp_full, wris_crit, t_dec_I = get_tresp_and_wrisc(wrns, tau, tau_lag)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
fig.subplots_adjust(left=0.07, bottom=0.16, right=0.99, top=0.99, wspace=0.35, hspace=0.4)

# Colors follow the order of w_I_sum_vals: [critical + eps, critical, critical - eps]
run_colors = [colors_['overdamped'], colors_['critical'], colors_['underdamped']]
network_curves = {
    'overdamped': loss_curve_overdamped,
    'critical': loss_curve_critical,
    'underdamped': loss_curve_underdamped,
}

# ----- axes[0]: loss curves (network simulation vs. simulated analytical model) -----
ax = axes[0]
for i, w_I_sum in enumerate(w_I_sum_vals):
    # Normalized copies: dividing in place would silently rescale L_data_scan and the loaded curves
    L_data = L_data_scan[i] / L_data_scan[i][0]

    if w_I_sum==wris_crit:
        regime = 'critical'
    elif w_I_sum>wris_crit:
        regime = 'overdamped'
    else:
        regime = 'underdamped'
    L_data_network = network_curves[regime] / network_curves[regime][0]
    t_data_network = np.arange(len(L_data_network))*dt_around_critical/tau_around_critical

    label = r'$w_\mathrm{rec}^{I,c}$'+("{:+.2f}".format(w_I_sum-wris_crit)+r" $\cdot(\tau_{{lag}}/\tau$)" if w_I_sum!=wris_crit else "")
    # Loss curve from the network simulation (brightened color)
    ax.plot(t_data_network, np.abs(L_data_network), c=brighten(run_colors[i], amount=150 if i==1 else 20), ls='solid', label=label)
    # Loss curve from the simulated analytical model (darkened, dashed)
    ax.plot(t_data, np.abs(L_data), c=darken(run_colors[i]), ls='dashed', label=label)

ax.legend(loc=1)
ax.set_yscale('log')
ax.set_xlabel(r'$t (\tau)$')
ax.set_ylabel(r'Loss')
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
# Remove minor y ticks
ax.yaxis.set_minor_locator(plt.NullLocator())
ax.set_xlim(0, 75)
ax.set_ylim(1e-9, 1.1)
ax.set_yticks([1e-8, 1e-6, 1e-4, 1e-2, 1])


# ----- axes[1]: decay rates (left axis) and oscillation frequency (right axis) -----
ax = axes[1]
ax.set_xlabel(r'$\frac{\tau_\mathrm{lag}}{\tau} w_\mathrm{sum}^\mathrm{rec,I}$')
ax.set_ylabel(r'$\lambda\tau$')
ax.axhline(0., c='gray', alpha=0.5)
ax.plot(wris_data, (tau/tau_lag)*lt_fast_data, c='gray')
ax.plot(wris_data, (tau/tau_lag)*lt_slow_data, c=colors_['real'], label=r'$\lambda\tau$')
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['left'].set_color(colors_['real'])
ax.tick_params(axis='y', colors=colors_['real'])
ax.invert_xaxis()
ax.set_xlim((-0.9, -1.01))
ax.set_ylim((-0.2,2))
ax.set_yticks((0, 0.5, 1, 1.5, 2))

ax_omega = ax.twinx()
ax_omega.set_ylabel(r'$\omega\tau$', labelpad=-2)
ax_omega.spines['right'].set_color(colors_['imag'])
ax_omega.yaxis.label.set_color(colors_['imag'])
ax_omega.tick_params(axis='y', colors=colors_['imag'])
ax_omega.axhline(0, c='gray', alpha=0.5)
ax_omega.plot(wris_data, (tau/tau_lag)*ot_data, c=colors_['imag'], label=r'$\omega \tau$')
ax_omega.plot(wris_data, -(tau/tau_lag)*ot_data, c=colors_['imag'])
ax_omega.spines['top'].set_visible(False)
# Mark the three simulated runs. w_I_sum_vals are in simulation units ((tau/tau_lag) times the
# dimensionless x axis), so convert back with the same tau they were built with.
for i, wris in enumerate(w_I_sum_vals):
    ax_omega.axvline((tau_lag/tau) * wris, c=run_colors[i], ls='dashed')

if save_fig:
    file_store_to = plots_folder_name + "theoLossEvolution.pdf"
    print("\nSave fig to {}".format(file_store_to))
    fig.savefig(file_store_to)

plt.show()

## Fig. 6) $\tau_\mathrm{resp}$ vs $n_\mathrm{RF}$ (1D and 2D lin.MS), 2D lin.MS schematic & loss curve (1D lagged)

In [None]:
x_ticks = [4, 10, 40, 100]
y_ticks_0 = [1, np.exp(-1), np.exp(-2), np.exp(-3)]
y_ticks_labels_0 = ['1', r'$e^{-1}$', r'$e^{-2}$', r'$e^{-3}$']
y_ticks_1 = [1, 10, 100, 1_000]
y_ticks_2 = [1, 10]
x_labelpad = -5
y_labelpad = [-2, 0]
hide_spines = True

# True: three panels (t_resp | 2D lin.MS schematic | loss curve); False: t_resp | loss curve
show_schematics = True

use_relative_activity = True
neuron_range = range(50, 155)

n_panels = 3 if show_schematics else 2
fig_5, axes = plt.subplots(1, n_panels, figsize=(5*n_panels, 4))
fig_5.subplots_adjust(left=0.08, bottom=0.15, right=0.99, top=0.99, wspace=0.4, hspace=0.4)
# Panel roles: t_resp always first, loss curve always last. The loss curve used to be
# hard-wired to axes[2], which crashed for show_schematics = False.
ax_tresp = axes[0]
ax_loss = axes[-1]

# Fit data: power law t_resp = a * n_RF^p
def monomial_fit(x, a, b, p):
    return a*x**p + b
def monomial_fit_no_offset(x, a, p):
    return a*x**p
fit_func = monomial_fit_no_offset

# Fit SFA data
popt_SFAtf, pcov = curve_fit(monomial_fit_no_offset, n_RF_vals_SFAtf, t_respBest_data_SFAtf, p0=[1, 1])
print(f'SFA fit: tresp = {popt_SFAtf[0]:.5f}*n_ff^{popt_SFAtf[1]:.5f}')

# ----- t_resp vs n_RF (1D net, 1D lagged, 1D SFA, 2D lin.MS) -----
pows_show = [0]
ax = ax_tresp
n_RF_name = r'$n_{\mathrm{RF}}$' if n_RF_reference=='n_RF_tar_scan' else r'$\hat{n}_{\mathrm{RF}}$'
# 1D vanilla
ax.plot(nff_data_1D_vanilla, tresp_data_1D_vanilla, ls='none', marker='o', c=colors['net'], markerfacecolor='none', label=r'$\tau_{\mathrm{resp}}^{\mathrm{net},\mathrm{1D}}$')
ax.plot(nff_data_1D_vanilla_theory, tresp_appr_data_1D_vanilla_theory, ls=':', c=colors['net'])
ax.set_xlabel(n_RF_name)
ax.set_ylabel(r'$\tau_{\mathrm{resp}}$')
ax.xaxis.labelpad = x_labelpad
ax.yaxis.labelpad = y_labelpad[0]
# NOTE(review): legend() only collects the series plotted so far, so it lists only the 1D net
# series; move it below the remaining series if all labelled entries should appear.
ax.legend()
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xticks(x_ticks, labels=x_ticks)
ax.set_yticks(y_ticks_1, labels=y_ticks_1)
if hide_spines:
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
# 1D lagged
ax.plot(nff_data_1D_lagged, tresp_data_1D_lagged_values[:,0], ls='none', marker='o', c=colors['full'], markerfacecolor='none', label=r'$\tau_{\mathrm{resp}}^{\mathrm{1D}}$')
ax.plot(nff_data_1D_lagged_theory, tresp_appr_data_1D_lagged_theory, ls=':', c=colors['full'])

# 1D SFA
ax.plot(n_RF_vals_SFAtf, t_respBest_data_SFAtf, ls='none', marker='o', c=colors['SFA'], markerfacecolor='none', label=r'$\tau_{\mathrm{resp}}^{\mathrm{1D},\mathrm{SFA}}$')
ax.plot(n_RF_vals_SFAtf, fit_func(n_RF_vals_SFAtf, *popt_SFAtf), c=colors['SFA'], ls=':')

# 2D lin.MS: theory curves are the 1D ones with n_RF doubled (t_resp x2 net, x sqrt(2) lagged)
ax.plot(nff_data_2D_linear_vanilla, tresp_data_2D_linear_vanilla, ls='none', marker='.', c=colors['net'], label=r'$\tau_{\mathrm{resp}}^{\mathrm{net},\mathrm{2D\ lin.MS}}$')
ax.plot(2*nff_data_1D_vanilla_theory, 2*tresp_appr_data_1D_vanilla_theory, ls=':', c=colors['net'])
ax.plot(nff_data_2D_linear_lagged, tresp_data_2D_linear_lagged_values[:,0], ls='none', marker='.', c=colors['full'], label=r'$\tau_{\mathrm{resp}}^{\mathrm{2D\ lin.MS}}$')
ax.plot(2*nff_data_1D_lagged_theory, np.sqrt(2)*tresp_appr_data_1D_lagged_theory, ls=':', c=colors['full'])


# ----- 2D lin.MS schematic (only when show_schematics) -----
if show_schematics:
    ax = axes[1]

    d_show_0 = 4
    d_show_1 = 4
    xy_offset = 2.

    marker_width = 0.8 * (2/3)*340 / (2*d_show_0+1 + xy_offset)  # 340 = 100%
    marker_size = marker_width**2
    lw_syn = 2

    N_2D = 200

    # Central (2 d_show_0 + 1) x (2 d_show_1 + 1) patch of the first target pattern
    x_target_plot = x_target_scan[0][N_2D//2-d_show_0:N_2D//2+d_show_0+1, N_2D//2-d_show_1:N_2D//2+d_show_1+1]
    x_target_values_ = x_target_plot.flatten()
    # Get x and y coordinates of x_target_plot_values
    x_target_plot_x_ = np.repeat(np.arange(x_target_plot.shape[0]), x_target_plot.shape[1])
    x_target_plot_y_ = np.tile(np.arange(x_target_plot.shape[1]), x_target_plot.shape[0])

    mask = np.ones_like(x_target_plot)
    x_target_values = x_target_values_[mask.flatten()==1]
    x_target_plot_x = x_target_plot_x_[mask.flatten()==1]
    x_target_plot_y = x_target_plot_y_[mask.flatten()==1]

    # Plot synaptic connections
    i_end = 2*d_show_0+0.5
    j_end = 2*d_show_1+0.5
    for i in range(2*d_show_0+1):
        j_start = np.argmax(mask[i, :]==1)-0.5
        ax.plot([j_start, j_end], [i, i], color='k', lw=lw_syn, zorder=-10)
    for j in range(2*d_show_1+1):
        i_start = np.argmax(mask[:, j]==1)-0.5
        ax.plot([j, j], [i_start, i_end], color='k', lw=lw_syn, zorder=-10)

    ax.scatter(x_target_plot_y, x_target_plot_x, c=x_target_values, cmap=cmap_activity, marker='s', s=marker_size, vmin=0, vmax=1)
    ax.set_aspect('equal', 'box')

    # Plot inputs: one row and one column of input cells, with only the central one active
    x_grid_input = np.arange(2*d_show_1+1)
    y_grid_input = np.arange(2*d_show_0+1)

    ax.scatter(x_grid_input, -xy_offset*np.ones_like(x_grid_input), color=colors['x_in_inactive'], marker='s', s=marker_size)
    ax.scatter(d_show_1, -xy_offset, color=cmap_activity(1e10), marker='s', s=marker_size)
    ax.scatter(-xy_offset*np.ones_like(y_grid_input), y_grid_input, color=colors['x_in_inactive'], marker='s', s=marker_size)
    ax.scatter(-xy_offset, d_show_0, color=cmap_activity(1e10), marker='s', s=marker_size)

    # Make ticks etc. invisible
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
    ax.set_frame_on(False)


# ----- Loss curve of the 1D lagged network (3 t_resp_full worth of steps) -----
ax = ax_loss
plot_limit_lagged = int(3*t_resp_full_actProp/dt_actProp)
n_steps = plot_limit_lagged
t_grid = np.arange(n_steps)*dt_actProp/tau_actProp
ax.plot(t_grid, np.exp(-t_grid/t_resp_full_actProp), c=colors['full'], ls='dotted')
ax.plot(t_grid, loss_curve_lagged_actProp[:n_steps]/loss_curve_lagged_actProp[0], c=colors['full'])
ax.axvline(t_resp_full_experimental_actProp, c='gray', ls='solid', alpha=0.5)
ax.axvline(t_resp_full_actProp, c='black', ls='dotted', alpha=0.5)
ax.axhline(np.exp(-1), c='gray', ls='solid', alpha=0.5)
ax.scatter(t_resp_full_experimental_actProp, np.exp(-1), facecolors='none', edgecolors=colors['full'])
ax.set_xlabel(r'Time ($\tau$)')
ax.set_ylabel('Loss')
ax.set_yscale('log')
ax.set_yticks(y_ticks_0, labels=y_ticks_labels_0)
ax.minorticks_off()
ax.set_xlim([0, t_grid[-1]])
ax.set_ylim([0.85*np.exp(-2), 1.1])
print(f"t_resp_net_actProp: {t_resp_net_actProp}")


# Same style as the other figures: no top/right spines on the data panels (the loss panel was
# previously skipped because the list still pointed at axes[0], axes[1]; the schematic has no frame).
if hide_spines:
    for ax in [ax_tresp, ax_loss]:
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)

if save_fig:
    file_store_to = plots_folder_name + "tresp_1D2Dlin.pdf"
    print("\nSave fig to {}".format(file_store_to))
    fig_5.savefig(file_store_to)

plt.show()

## Fig. 7) 2D nonlinear mixed selectivity (nonlin. MS): schematic & $\tau_\mathrm{resp}$ vs $\hat{n}_\mathrm{RF}$

In [None]:
# Fit data: power law t_resp = a * n_RF^p
def monomial_fit(x, a, b, p):
    return a*x**p + b
def monomial_fit_no_offset(x, a, p):
    return a*x**p
fit_func = monomial_fit_no_offset

# Plot settings also used by the earlier figure cells (same values). They are set explicitly
# here so this cell no longer depends on which figure cell happened to run before it.
hide_spines = True
y_labelpad = [-2, 0]
y_ticks_1 = [1, 10, 100, 1_000]

# Fit vanilla data
popt_vanilla, pcov = curve_fit(monomial_fit_no_offset, nff_data_2D_nonlinear_vanilla, tresp_data_2D_nonlinear_vanilla, p0=[1, 1])
print(f'Vanilla fit: tresp = {popt_vanilla[0]:.5f}*n_ff^{popt_vanilla[1]:.5f}')
# Fit lagged data
popt_lagged, pcov = curve_fit(monomial_fit_no_offset, nff_data_2D_nonlinear_lagged, tresp_data_2D_nonlinear_lagged_values[:,0], p0=[1, 1])
print(f'Lagged fit:  tresp = {popt_lagged[0]:.5f}*n_ff^{popt_lagged[1]:.5f}')


fig_6, axes = plt.subplots(1, 2, figsize=(10, 4))
fig_6.subplots_adjust(left=0.08, bottom=0.15, right=0.99, top=0.99, wspace=0.4, hspace=0.4)


# ----- axes[0]: schematic of the 2D nonlinear mixed-selectivity network -----
ax = axes[0]

d_show_0 = 4
d_show_1 = 4
xy_offset = 2.75 # 2.625

marker_width = 0.825 *251 / (2*d_show_0+1+xy_offset)  # 251 = 100%
marker_size = marker_width**2
lw_syn = 2

N_2D = 200

print(f"np.shape(x_target_scan_2D) = {np.shape(x_target_scan_2D)}")

tar_show = min(5, len(x_target_scan_2D)-1)
print(f"{tar_show=}")
# Central (2 d_show_0 + 1) x (2 d_show_1 + 1) patch of target pattern tar_show
x_target_plot = x_target_scan_2D[tar_show][N_2D//2-d_show_0:N_2D//2+d_show_0+1, N_2D//2-d_show_1:N_2D//2+d_show_1+1]
x_target_values_ = x_target_plot.flatten()
# Get x and y coordinates of x_target_plot_values
x_target_plot_x_ = np.repeat(np.arange(x_target_plot.shape[0]), x_target_plot.shape[1])
x_target_plot_y_ = np.tile(np.arange(x_target_plot.shape[1]), x_target_plot.shape[0])

# Hide a staircase-shaped corner (last rows, first columns) of the feature grid
mask = np.ones_like(x_target_plot)
mask[-1, :5] = 0
mask[-2, :4] = 0
mask[-3, :4] = 0
mask[-4, :3] = 0
mask[-5, :1] = 0
x_target_values = x_target_values_[mask.flatten()==1]
x_target_plot_x = x_target_plot_x_[mask.flatten()==1]
x_target_plot_y = x_target_plot_y_[mask.flatten()==1]

# Plot input neurons (full grid, shifted), with only the central one active
x_offset = -xy_offset
y_offset = 1.25 # xy_offset
ax.scatter(x_target_plot_x_+x_offset, x_target_plot_y_+y_offset, c=colors['x_in_inactive'], marker='s', s=marker_size, zorder=-20)
ax.scatter([d_show_0+x_offset], [d_show_1+y_offset], color=cmap_activity(1e10), marker='s', s=marker_size, zorder=-15)

# Plot synaptic connections, ending at the last unmasked feature neuron of each row/column
i_start = -0.5
j_end = 2*d_show_0+0.5
for i in range(2*d_show_0+1): # From left to right
    j_start = np.argmax(mask[i, :]==1)-0.5
    ax.plot([j_start, j_end], [i, i], color='k', lw=lw_syn, zorder=-10)
for j in range(2*d_show_1+1): # From top to bottom
    i_end = len(mask)-np.argmax(mask[:, j][::-1]==1)-0.5
    ax.plot([j, j], [i_start, i_end], color='k', lw=lw_syn, zorder=-10)


# Plot feature neurons
ax.scatter(x_target_plot_y, x_target_plot_x, c=x_target_values, cmap=cmap_activity, vmin=0., vmax=1., marker='s', s=marker_size)
ax.set_aspect('equal', 'box')

# Make ticks etc. invisible
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
ax.set_frame_on(False)


# ----- axes[1]: t_resp vs. n_RF with power-law fits -----
ax = axes[1]

x_ticks = [40, 100, 400, 1_000, 4_000]
y_ticks_2 = [0.01, 0.1, 1]  # NOTE(review): unused; the y ticks below use y_ticks_1
x_labelpad = 0
pows_show = [0]

# 2D nonlinear mixed vanilla
ax.plot(nff_data_2D_nonlinear_vanilla, tresp_data_2D_nonlinear_vanilla, c=colors['net'], ls='', marker='.', label=r'$\tau_{\mathrm{resp}}$')
ax.plot(nff_data_2D_nonlinear_vanilla, fit_func(nff_data_2D_nonlinear_vanilla, *popt_vanilla), c=colors['net'], ls=':')
# 2D nonlinear mixed lagged (columns i > 0 would be shown as differences to column i-1)
for i in pows_show:
    if i==0:
        ax.plot(nff_data_2D_nonlinear_lagged, tresp_data_2D_nonlinear_lagged_values[:,0], 'r.', label=r'$\tau_\mathrm{resp}^{\mathrm{bal}}$')
    else:
        ax.plot(nff_data_2D_nonlinear_lagged, tresp_data_2D_nonlinear_lagged_values[:,i]-tresp_data_2D_nonlinear_lagged_values[:,i-1], 'r.', alpha=1-(i+1)/(2*4))
ax.plot(nff_data_2D_nonlinear_lagged, fit_func(nff_data_2D_nonlinear_lagged, *popt_lagged), 'r:')

ax.set_xlabel(r'$\hat{n}_{\mathrm{RF}}$')
ax.set_ylabel(r'$\tau_{\mathrm{resp}}^\mathrm{2D}$')
ax.xaxis.labelpad = x_labelpad
ax.yaxis.labelpad = y_labelpad[0]
ax.legend()
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_xticks(x_ticks, labels=x_ticks)
ax.set_yticks(y_ticks_1, labels=y_ticks_1)
if hide_spines:
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)


if save_fig:
    file_store_to = plots_folder_name + "tresp_2D_nonlinear.pdf"
    print("\nSave fig to {}".format(file_store_to))
    fig_6.savefig(file_store_to)

plt.show()