In [15]:
import math
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import itertools
import random
import scipy.integrate as integrate
import scipy.stats as stats
import torch
import time
import pandas as pd

from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi
from sbi.utils.get_nn_models import posterior_nn
from sbi import utils as utils
from sbi import analysis as analysis
from sbi.inference.base import infer

In [None]:
# SBI: target firing rate for dIN/cIN/MN = 20Hz; for aINs = 10Hz
# parameters to find: delta_g_ex, tau_ex for exc connections: 
# MN -> MN/dIN/aIN/cIN; dIN -> MN/dIN/aIN/cIN
# delta_g_inh, tau_inh for inh connections:
# aIN -> MN/dIN/aIN/cIN; cIN -> MN/dIN/aIN/cIN
# external current I ?? may be different for dINs and rest
# try with time constants being same for all exc/inh connections i.e. just find 1 tau_ex and 1 tau_inh
# gap juntion conductance

In [18]:
columns = ['Time (s)', 'Voltage (mV)']

mn_filename = 'data/motoneuron_1429_f5.txt'
mn_df = pd.read_csv(mn_filename,header=None,names=columns,skiprows=0,delim_whitespace=True)
mn_trace = mn_df['Voltage (mV)'].values.tolist()[:5000]
ain_filename = 'data/aIN_1311_f109.txt'
ain_df = pd.read_csv(ain_filename,header=None,names=columns,skiprows=0,delim_whitespace=True)
ain_trace = ain_df['Voltage (mV)'].values.tolist()[:5000]
cin_filename = 'data/cIN_v120_f4.txt'
cin_df = pd.read_csv(cin_filename,header=None,names=columns,skiprows=0,delim_whitespace=True)
cin_trace_int = cin_df['Voltage (mV)'].values.tolist()[:5000]
cin_trace = [i*100 for i in cin_trace_int]
din_filename = 'data/dIN_1399_f12.txt'
din_df = pd.read_csv(din_filename,header=None,names=columns,skiprows=0,delim_whitespace=True)
din_trace = din_df['Voltage (mV)'].values.tolist()[:5000]

In [None]:
time = 5000
delta_t = 0.1
V_thr = 0

din_b = 0
din_a = 12
din_tau_w = 30
din_cm = 200

cm = 200
g_L = 10
E_L = -70
E_ex = 0
E_inh = -70
g_ex = 0
g_inh = 0
V_T = -50
V_thr = 0
V_reset = -58
mn_b = 60
mn_a = 0
mn_tau_w = 30

slope_f = 2
n_mns = 2
n_ains = 2
n_cins = 2
n_dins = 2
g_gj = 0.2

In [19]:
class Innerloop_AdExSim():
    
    def __init__(self, inner_loop_params):
        self.time = inner_loop_params['time']
        self.delta_t = inner_loop_params['delta_t']
        self.V_thr = inner_loop_params['V_thr']

        self.din_b = inner_loop_params['din_b']
        self.din_a = inner_loop_params['din_a']
        self.din_tau_w = inner_loop_params['din_tau_w']
        self.din_cm = inner_loop_params['din_cm']

        self.cm = inner_loop_params['cm']
        self.g_L = inner_loop_params['g_L']
        self.E_L = inner_loop_params['E_L']
        self.E_ex = inner_loop_params['E_ex']
        self.E_inh = inner_loop_params['E_inh']
        self.g_ex = inner_loop_params['g_ex']
        self.g_inh = inner_loop_params['g_inh']
        self.V_T = inner_loop_params['V_T']
        self.V_thr = inner_loop_params['V_thr']
        self.V_reset = inner_loop_params['V_reset']
        self.mn_b = inner_loop_params['mn_b']
        self.mn_a = inner_loop_params['mn_a']
        self.mn_tau_w = inner_loop_params['mn_tau_w']

        self.slope_f = inner_loop_params['slope_f']
        self.n_mns = inner_loop_params['n_mns']
        self.n_ains = inner_loop_params['n_ains']
        self.n_cins = inner_loop_params['n_cins']
        self.n_dins = inner_loop_params['n_dins']
        self.g_gj = inner_loop_params['g_gj']
        
    def reset(self):
        # Voltage vector: size = number of MNs x 1; stored voltage matrix: size = number of recorded neurons x number of time steps. 
        n_record = 2
        n_timesteps = len(np.arange(0,self.time,self.delta_t))
        self.mn_voltage_vec = np.zeros((self.n_mns,1))
        self.mn_voltage_vec[:,:] += self.E_L
        self.mn_voltage_record = np.zeros((n_record,n_timesteps))
        # aINs
        self.ain_voltage_vec = np.zeros((self.n_ains,1))
        self.ain_voltage_vec[:,:] += self.E_L
        self.ain_voltage_record = np.zeros((n_record,n_timesteps))
        # cINs
        self.cin_voltage_vec = np.zeros((self.n_cins,1))
        self.cin_voltage_vec[:,:] += self.E_L
        self.cin_voltage_record = np.zeros((n_record,n_timesteps))
        # dINs
        self.din_voltage_vec = np.zeros((self.n_dins,1))
        self.din_voltage_vec[:,:] += self.E_L
        self.din_voltage_record = np.zeros((n_record,n_timesteps))
        self.din_current_record = np.zeros((n_record,n_timesteps))

        # Spike times list of lists: number of lists = number of MNs. 
        self.mn_spike_times = [[] for m in range(self.n_mns)]
        self.ain_spike_times = [[] for m in range(self.n_ains)]
        self.cin_spike_times = [[] for m in range(self.n_cins)]
        self.din_spike_times = [[] for m in range(self.n_dins)]

        # dIN gap junctions
        self.gap_junctions = np.zeros((self.n_dins,self.n_dins))
        for r in range(self.n_dins):
            for c in range(self.n_dins):
                if r < self.n_dins/2 and c < self.n_dins/2 or r >= self.n_dins/2 and c >= self.n_dins/2:
                    if abs(r-c) < 4 and r != c:
                        self.gap_junctions[r,c] += 1

        # MN -> MN synapses.
        self.mn_synapses = np.zeros((self.n_mns,self.n_mns))
        for r in range(self.n_mns):
            for c in range(self.n_mns):
                if r < self.n_mns/2 and c < self.n_mns/2 or r >= self.n_mns/2 and c >= self.n_mns/2:
                    if random.random() < 0.05 and r != c:
                        self.mn_synapses[r,c] += 1

        # MN -> dIN synapses.
        self.mn_din_synapses = np.zeros((self.n_mns,self.n_dins))
        for r in range(self.n_mns):
            for c in range(self.n_dins):
                if r < self.n_mns/2 and c < self.n_dins/2 or r >= self.n_mns/2 and c >= self.n_dins/2:
                    if random.random() < 0.05 and r != c:
                        self.mn_din_synapses[r,c] += 1

        # MN -> aIN synapses.
        self.mn_ain_synapses = np.zeros((self.n_mns,self.n_ains))
        for r in range(self.n_mns):
            for c in range(self.n_ains):
                if r < self.n_mns/2 and c < self.n_ains/2 or r >= self.n_mns/2 and c >= self.n_ains/2:
                    if random.random() < 0.05 and r != c:
                        self.mn_ain_synapses[r,c] += 1

        # MN -> cIN synapses.
        self.mn_cin_synapses = np.zeros((self.n_mns,self.n_cins))
        for r in range(self.n_mns):
            for c in range(self.n_cins):
                if r < self.n_mns/2 and c < self.n_cins/2 or r >= self.n_mns/2 and c >= self.n_cins/2:
                    if random.random() < 0.05 and r != c:
                        self.mn_cin_synapses[r,c] += 1   

        # aIN -> MN synapses.
        self.ain_mn_synapses = np.zeros((self.n_ains,self.n_mns))
        for r in range(self.n_ains):
            for c in range(self.n_mns):
                if r < self.n_ains/2 and c < self.n_mns/2 or r >= self.n_ains/2 and c >= self.n_mns/2:
                    if random.random() < 0.1:
                        self.ain_mn_synapses[r,c] += 1

        # aIN -> aIN synapses.
        self.ain_synapses = np.zeros((self.n_ains,self.n_ains))
        for r in range(self.n_ains):
            for c in range(self.n_ains):
                if r < self.n_ains/2 and c < self.n_ains/2 or r >= self.n_ains/2 and c >= self.n_ains/2:
                    if random.random() < 0.1 and r != c:
                        self.ain_synapses[r,c] += 1

        # aIN -> cIN synapses.
        self.ain_cin_synapses = np.zeros((self.n_ains,self.n_cins))
        for r in range(self.n_ains):
            for c in range(self.n_cins):
                if r < self.n_ains/2 and c < self.n_cins/2 or r >= self.n_ains/2 and c >= self.n_cins/2:
                    if random.random() < 0.1:
                        self.ain_cin_synapses[r,c] += 1

        # aIN -> dIN synapses.
        self.ain_din_synapses = np.zeros((self.n_ains,self.n_dins))
        for r in range(self.n_ains):
            for c in range(self.n_dins):
                if r < self.n_ains/2 and c < self.n_dins/2 or r >= self.n_ains/2 and c >= self.n_dins/2:
                    if random.random() < 0.1:
                        self.ain_din_synapses[r,c] += 1

        # cIN -> aIN synapses.
        self.cin_ain_synapses = np.zeros((self.n_cins,self.n_ains))
        for r in range(self.n_cins):
            for c in range(self.n_ains):
                if r < self.n_cins/2 and c >= self.n_ains/2 or r >= self.n_cins/2 and c < self.n_ains/2:
                    if random.random() < 0.1:
                        self.cin_ain_synapses[r,c] += 1

        # cIN -> MN synapses.
        self.cin_mn_synapses = np.zeros((self.n_cins,self.n_mns))
        for r in range(self.n_cins):
            for c in range(self.n_mns):
                if r < self.n_cins/2 and c >= self.n_mns/2 or r >= self.n_cins/2 and c < self.n_mns/2:
                    if random.random() < 0.1:
                        self.cin_mn_synapses[r,c] += 1

        # cIN -> cIN synapses.
        self.cin_synapses = np.zeros((self.n_cins,self.n_cins))
        for r in range(self.n_cins):
            for c in range(self.n_cins):
                if r < self.n_cins/2 and c >= self.n_cins/2 or r >= self.n_cins/2 and c < self.n_cins/2:
                    if random.random() < 0.1:
                        self.cin_synapses[r,c] += 1

        # cIN -> dIN synapses.
        self.cin_din_synapses = np.zeros((self.n_cins,self.n_dins))
        for r in range(self.n_cins):
            for c in range(self.n_dins):
                if r < self.n_cins/2 and c >= self.n_dins/2 or r >= self.n_cins/2 and c < self.n_dins/2:
                    if random.random() < 0.1:
                        self.cin_din_synapses[r,c] += 1

        # dIN -> dIN synapses.
        self.din_synapses = np.zeros((self.n_dins,self.n_dins))
        for r in range(self.n_dins):
            for c in range(self.n_dins):
                if r < self.n_dins/2 and c < self.n_dins/2 or r >= self.n_dins/2 and c >= self.n_dins/2:
                    if random.random() < 0.1 and r != c:
                        self.din_synapses[r,c] += 1   

        # dIN -> MN synapses.
        self.din_mn_synapses = np.zeros((self.n_dins,self.n_mns))
        for r in range(self.n_dins):
            for c in range(self.n_mns):
                if r < self.n_dins/2 and c < self.n_mns/2 or r >= self.n_dins/2 and c >= self.n_mns/2:
                    if random.random() < 0.1:
                        self.din_mn_synapses[r,c] += 1

        # dIN -> aIN synapses.
        self.din_ain_synapses = np.zeros((self.n_dins,self.n_ains))
        for r in range(self.n_dins):
            for c in range(self.n_ains):
                if r < self.n_dins/2 and c < self.n_ains/2 or r >= self.n_dins/2 and c >= self.n_ains/2:
                    if random.random() < 0.1:
                        self.din_ain_synapses[r,c] += 1

        # dIN -> cIN synapses.
        self.din_cin_synapses = np.zeros((self.n_dins,self.n_cins))
        for r in range(self.n_dins):
            for c in range(self.n_cins):
                if r < self.n_dins/2 and c < self.n_cins/2 or r >= self.n_dins/2 and c >= self.n_cins/2:
                    if random.random() < 0.1:
                        self.din_cin_synapses[r,c] += 1

        self.mn_g_ex = np.zeros((self.n_mns,1))
        self.mn_g_inh = np.zeros((self.n_mns,1))
        self.mn_w = np.zeros((self.n_mns,1))
        self.ain_g_ex = np.zeros((self.n_ains,1))
        self.ain_g_inh = np.zeros((self.n_ains,1))
        self.ain_w = np.zeros((self.n_ains,1))
        self.cin_g_ex = np.zeros((self.n_cins,1))
        self.cin_g_inh = np.zeros((self.n_cins,1))
        self.cin_w = np.zeros((self.n_cins,1))
        self.din_g_ex = np.zeros((self.n_dins,1))
        self.din_g_inh = np.zeros((self.n_dins,1))
        self.din_w = np.zeros((self.n_dins,1))
        self.din_ws = []
        
    # function to take in vector of params (size 17) and run the AdEx model

    def run_AdEx_model(self, params):
        # run simulation: right side starts with slightly greater input than left side (first 25ms) so activity 
        # starts on that side
        delta_g_ex_mn_mn = params[0]
        delta_g_ex_mn_ain = params[1]
        delta_g_ex_mn_cin = params[2]
        delta_g_ex_mn_din = params[3]
        delta_g_ex_din_mn = params[4]
        delta_g_ex_din_ain = params[5]
        delta_g_ex_din_cin = params[6]
        delta_g_ex_din_din = params[7]
        delta_g_inh_ain_mn = params[8]
        delta_g_inh_ain_ain = params[9]
        delta_g_inh_ain_cin = params[10]
        delta_g_inh_ain_din = params[11]
        delta_g_inh_cin_mn = params[12]
        delta_g_inh_cin_ain = params[13]
        delta_g_inh_cin_cin = params[14]
        delta_g_inh_cin_din = params[15]
        tau_ex = params[16]
        tau_inh = params[17]
        g_gj = params[18]
        for idx,t in enumerate(np.arange(0,self.time,self.delta_t)):
            if t > 25:
                I_L = 200
                I_R = 200
                I_L_din = 400
                I_R_din = 400
            else:
                I_L = 150
                I_R = 200
                I_L_din = 350
                I_R_din = 400
            i_gj = np.zeros((1,self.n_dins))
            for mn in range(self.n_mns):
                if self.mn_voltage_vec[mn,-1] >= self.V_thr:
                    self.mn_voltage_vec[mn,-1] = self.V_reset
            for ain in range(self.n_ains):
                if self.ain_voltage_vec[ain,-1] >= self.V_thr:
                    self.ain_voltage_vec[ain,-1] = self.V_reset
            for cin in range(self.n_cins):
                if self.cin_voltage_vec[cin,-1] >= self.V_thr:
                    self.cin_voltage_vec[cin,-1] = self.V_reset
            for din in range(self.n_dins):
                if self.din_voltage_vec[din,-1] >= self.V_thr:
                    self.din_voltage_vec[din,-1] = self.V_reset
            # dIN gap junctions
            for row in range(self.n_dins):
                for c in range(self.n_dins):
                    i_gj[:,row] += self.gap_junctions[row,c]*(self.din_voltage_vec[row,-1] - self.din_voltage_vec[c,-1])*g_gj
            i_gj = i_gj.reshape(self.n_dins,1)
            i_gj = np.zeros((self.n_dins,1))
            # AMPA synapses from MNs to CPG
            for mn in range(self.n_mns):
                spikes = self.mn_spike_times[self.mn]
                for spike in spikes:
                    if spike + 1 == t:
                        self.mn_g_ex += self.mn_synapses[mn,-1]*delta_g_ex_mn_mn
                        self.din_g_ex += self.mn_din_synapses[mn, -1]*delta_g_ex_mn_din
                        self.ain_g_ex += self.mn_ain_synapses[mn, -1]*delta_g_ex_mn_ain
                        self.cin_g_ex += self.mn_cin_synapses[mn, -1]*delta_g_ex_mn_cin
            # dIN -> MN/aIN/cIN/dIN excitatory synapses
            for din in range(self.n_dins):
                spikes = self.din_spike_times[din]
                for spike in spikes:
                    if spike + 1 == t:
                        self.din_g_ex += self.din_synapses[din,-1]*delta_g_ex_din_din
                        self.mn_g_ex += self.din_mn_synapses[din,-1]*delta_g_ex_din_mn
                        self.ain_g_ex += self.din_ain_synapses[din,-1]*delta_g_ex_din_ain
                        self.cin_g_ex += self.din_cin_synapses[din,-1]*delta_g_ex_din_cin
            # aIN -> MN/aIN/cIN/dIN inhibitory synapses
            for ain in range(self.n_ains):
                spikes = self.ain_spike_times[ain]
                for spike in spikes:
                    if spike + 1 == t:
                        self.mn_g_inh += self.ain_mn_synapses[ain,-1]*delta_g_inh_ain_mn
                        self.ain_g_inh += self.ain_synapses[ain,-1]*delta_g_inh_ain_ain
                        self.cin_g_inh += self.ain_cin_synapses[ain,-1]*delta_g_inh_ain_cin
                        self.din_g_inh += self.ain_din_synapses[ain,-1]*delta_g_inh_ain_din
            # cIN -> aIN/MN/dIN/cIN synapses
            for cin in range(self.n_cins):
                spikes = self.cin_spike_times[cin]
                for spike in spikes:
                    if spike + 1 == t:
                        self.mn_g_inh += self.cin_mn_synapses[cin,-1]*delta_g_inh_cin_mn
                        self.ain_g_inh += self.cin_ain_synapses[cin,-1]*delta_g_inh_cin_ain
                        self.cin_g_inh += self.cin_synapses[cin,-1]*delta_g_inh_cin_cin
                        self.din_g_inh += self.cin_din_synapses[cin,-1]*delta_g_inh_cin_din
                        
            self.mn_voltage_vec[:int(self.n_mns/2),-1] += self.delta_t*((-self.g_L*(self.mn_voltage_vec[:int(self.n_mns/2),-1] \
                                             - self.E_L) + self.g_L * self.slope_f * np.exp((self.mn_voltage_vec[:int(self.n_mns/2),-1] \
                                             - self.V_T)/self.slope_f) - self.mn_g_ex[:int(self.n_mns/2),-1] * \
                                             (self.mn_voltage_vec[:int(self.n_mns/2),-1] - self.E_ex) - self.mn_g_inh[:int(self.n_mns/2),-1] \
                                             * (self.mn_voltage_vec[:int(self.n_mns/2),-1] - self.E_inh) + I_L - self.mn_w[:int(self.n_mns/2),-1]) / self.cm)
            self.mn_voltage_vec[int(self.n_mns/2):,-1] += self.delta_t*((-self.g_L*(self.mn_voltage_vec[int(self.n_mns/2),-1] \
                                             - self.E_L) + self.g_L * self.slope_f * np.exp((self.mn_voltage_vec[int(self.n_mns/2):,-1] \
                                             - self.V_T)/self.slope_f) - self.mn_g_ex[int(self.n_mns/2):,-1] * 
                                             (self.mn_voltage_vec[int(self.n_mns/2):,-1] - self.E_ex) - self.mn_g_inh[int(self.n_mns/2):,-1] \
                                             * (self.mn_voltage_vec[int(self.n_mns/2):,-1] - self.E_inh) + I_R - self.mn_w[int(self.n_mns/2):,-1]) / self.cm)
            
            self.ain_voltage_vec[:int(self.n_ains/2),-1] += self.delta_t*((-self.g_L*(self.ain_voltage_vec[:int(self.n_ains/2),-1] \
                                                             - self.E_L) + self.g_L * self.slope_f * np.exp((self.ain_voltage_vec[:int(self.n_ains/2),-1] \
                                                             - self.V_T)/self.slope_f) - self.ain_g_ex[:int(self.n_ains/2),-1] * \
                                                             (self.ain_voltage_vec[:int(self.n_ains/2),-1] - self.E_ex) \
                                                             - self.ain_g_inh[:int(self.n_ains/2),-1] * (self.ain_voltage_vec[:int(self.n_ains/2),-1] \
                                                             - self.E_inh) + I_L - self.ain_w[:int(self.n_ains/2),-1]) / self.cm)
            self.ain_voltage_vec[int(self.n_ains/2):,-1] += self.delta_t*((-self.g_L*(self.ain_voltage_vec[int(self.n_ains/2):,-1] \
                                                            - self.E_L) + self.g_L * self.slope_f * np.exp((self.ain_voltage_vec[int(self.n_ains/2):,-1] \
                                                            - self.V_T)/self.slope_f) - self.ain_g_ex[int(self.n_ains/2):,-1] * \
                                                            (self.ain_voltage_vec[int(self.n_ains/2):,-1] - self.E_ex) - self.ain_g_inh[int(self.n_ains/2):,-1] \
                                                            * (self.ain_voltage_vec[int(self.n_ains/2):,-1] - self.E_inh) + I_R - self.ain_w[int(self.n_ains/2):,-1]) / self.cm)
            
            self.cin_voltage_vec[:int(self.n_cins/2),-1] += self.delta_t*((-self.g_L*(self.cin_voltage_vec[:int(n_cins/2),-1] \
                                                            - self.E_L) + self.g_L * self.slope_f * np.exp((self.cin_voltage_vec[:int(self.n_cins/2),-1] \
                                                            - self.V_T)/self.slope_f) - self.cin_g_ex[:int(self.n_cins/2),-1] * \
                                                            (self.cin_voltage_vec[:int(self.n_cins/2),-1] - self.E_ex) - self.cin_g_inh[:int(self.n_cins/2),-1] \
                                                            * (self.cin_voltage_vec[:int(self.n_cins/2),-1] - self.E_inh) + I_L \
                                                            - self.cin_w[:int(self.n_cins/2),-1]) / self.cm)
            self.cin_voltage_vec[int(self.n_cins/2):,-1] += self.delta_t*((-self.g_L*(self.cin_voltage_vec[int(self.n_cins/2):,-1] \
                                                            - self.E_L) + self.g_L * self.slope_f * np.exp((self.cin_voltage_vec[int(self.n_cins/2):,-1] \
                                                            - self.V_T)/self.slope_f) - self.cin_g_ex[int(self.n_cins/2):,-1] \
                                                            * (self.cin_voltage_vec[int(self.n_cins/2):,-1] - self.E_ex) - self.cin_g_inh[int(self.n_cins/2):,-1] \
                                                            * (self.cin_voltage_vec[int(self.n_cins/2):,-1] - self.E_inh) + I_R - self.cin_w[int(self.n_cins/2):,-1]) / self.cm)

            self.din_voltage_vec[:int(self.n_dins/2),-1] += self.delta_t*((-self.g_L*(self.din_voltage_vec[:int(self.n_dins/2),-1] \
                                                            - self.E_L) + self.g_L * self.slope_f * \
                                                            np.exp((self.din_voltage_vec[:int(self.n_dins/2),-1] - self.V_T)/self.slope_f) \
                                                            - self.din_g_ex[:int(self.n_dins/2),-1] * (self.din_voltage_vec[:int(self.n_dins/2),-1] \
                                                            - self.E_ex) - self.din_g_inh[:int(self.n_dins/2),-1] * (self.din_voltage_vec[:int(self.n_dins/2),-1] \
                                                            - self.E_inh) - i_gj[:int(self.n_dins/2),-1] + I_L_din - self.din_w[:int(self.n_dins/2),-1]) / self.din_cm)
            self.din_voltage_vec[int(self.n_dins/2):,-1] += self.delta_t*((-self.g_L*(self.din_voltage_vec[int(self.n_dins/2):,-1] \
                                                            - self.E_L) + self.g_L * self.slope_f * \
                                                            np.exp((self.din_voltage_vec[int(self.n_dins/2):,-1] - self.V_T)/self.slope_f) \
                                                            - self.din_g_ex[int(self.n_dins/2):,-1] * (self.din_voltage_vec[int(self.n_dins/2):,-1] \
                                                            - self.E_ex) - self.din_g_inh[int(self.n_dins/2):,-1] * (self.din_voltage_vec[int(self.n_dins/2):,-1] \
                                                            - self.E_inh) - i_gj[int(self.n_dins/2):,-1] + I_R_din - self.din_w[int(self.n_dins/2):,-1]) / self.din_cm)
           
            self.mn_w += self.delta_t*((self.mn_a * (self.mn_voltage_vec - self.E_L) - self.mn_w) / self.mn_tau_w)
            self.ain_w += self.delta_t*((self.mn_a * (self.ain_voltage_vec - self.E_L) - self.ain_w) / self.mn_tau_w)
            self.cin_w += self.delta_t*((self.mn_a * (self.cin_voltage_vec - self.E_L) - self.cin_w) / self.mn_tau_w)
            self.din_w += self.delta_t*((self.din_a * (self.din_voltage_vec - self.E_L) - self.din_w) / self.din_tau_w)
            self.mn_g_ex += self.delta_t*(-self.mn_g_ex/tau_ex)
            self.mn_g_inh += self.delta_t*(-self.mn_g_inh/tau_inh)
            self.ain_g_ex += self.delta_t*(-self.ain_g_ex/tau_ex)
            self.ain_g_inh += self.delta_t*(-self.ain_g_inh/tau_inh)
            self.cin_g_ex += self.delta_t*(-self.cin_g_ex/tau_ex)
            self.cin_g_inh += self.delta_t*(-self.cin_g_inh/tau_inh)
            self.din_g_ex += self.delta_t*(-self.din_g_ex/tau_ex)
            self.din_g_inh += self.delta_t*(-self.din_g_inh/tau_inh)
            for mn in range(self.n_mns):
                if self.mn_voltage_vec[mn,-1] >= self.V_thr:
                    self.mn_voltage_vec[mn,-1] = self.V_thr
                    self.mn_spike_times[mn].append(t)
                    self.mn_w[mn,-1] += self.mn_b
            self.mn_voltage_record[:,idx] += self.mn_voltage_vec[[0,int(self.n_mns/2)],:].squeeze()
            for self.ain in range(self.n_ains):
                if self.ain_voltage_vec[ain,-1] >= self.V_thr:
                    self.ain_voltage_vec[ain,-1] = self.V_thr
                    self.ain_spike_times[ain].append(t)
                    self.ain_w[ain,-1] += self.mn_b
            self.ain_voltage_record[:,idx] += self.ain_voltage_vec[[0,int(self.n_ains/2)],:].squeeze()
            for cin in range(self.n_cins):
                if self.cin_voltage_vec[cin,-1] >= self.V_thr:
                    self.cin_voltage_vec[cin,-1] = self.V_thr
                    self.cin_spike_times[cin].append(t)
                    self.cin_w[cin,-1] += self.mn_b
            self.cin_voltage_record[:,idx] += self.cin_voltage_vec[[0,int(self.n_cins/2)],:].squeeze()
            for din in range(self.n_dins):
                if self.din_voltage_vec[din,-1] >= self.V_thr:
                    self.din_voltage_vec[din,-1] = self.V_thr
                    self.din_spike_times[din].append(t)
                    self.din_w[din,-1] += din_b
            self.din_voltage_record[:,idx] += self.din_voltage_vec[[0,int(self.n_dins/2)],:].squeeze()

        return self.mn_voltage_record, self.ain_voltage_record, self.cin_voltage_record, self.din_voltage_record
    
    # calculate the summary statistics: average firing rate and l-r phase difference for MNs, aINs, cINs and dINs

    def calculate_summary_statistics(self, mn_x, ain_x, cin_x, din_x):
        fr_mns = np.zeros(np.shape(mn_x)[0])
        time_spike1_mns = np.zeros(np.shape(mn_x)[0])
        fr_ains = np.zeros(np.shape(ain_x)[0])
        time_spike1_ains = np.zeros(np.shape(ain_x)[0])
        fr_cins = np.zeros(np.shape(cin_x)[0])
        time_spike1_cins = np.zeros(np.shape(cin_x)[0])
        fr_dins = np.zeros(np.shape(din_x)[0])
        time_spike1_dins = np.zeros(np.shape(din_x)[0])
        for mn in range(np.shape(mn_x)[0]):
            fr_mns[mn] += len(np.where(mn_x[mn,:] >= self.V_thr))/(self.time/1000)
            time_spike1_mn[mn] += np.where(mn_x[mn,:] >= self.V_thr)[0][0]*self.delta_t
        for ain in range(np.shape(ain_x)[0]):
            fr_ains[ain] += len(np.where(ain_x[ain,:] >= self.V_thr))/(self.time/1000)
            time_spike1_ains[ain] += np.where(ain_x[ain,:] >= self.V_thr)[0][0]*self.delta_t
        for cin in range(np.shape(cin_x)[0]):
            fr_cins[cin] += len(np.where(cin_x[cin,:] >= self.V_thr))/(self.time/1000)
            time_spike1_cins[cin] += np.where(cin_x[cin,:] >= self.V_thr)[0][0]*self.delta_t
        for din in range(np.shape(din_x)[0]):
            fr_dins[din] += len(np.where(din_x[din,:] >= self.V_thr))/(self.time/1000)
            time_spike1_dins[din] += np.where(din_x[din,:] >= self.V_thr)[0][0]*self.delta_t
        spike1_diff_mn = np.abs(time_spike1_mns[0] - time_spike1_mns[1])
        av_fr_mn = (fr_mns[0] + fr_mns[1]) / 2
        av_isi_mn = 1000 * (1 / av_fr_mn)
        phase_diff_mn = spike1_diff_mn / av_isi_mn
        spike1_diff_ain = np.abs(time_spike1_ains[0] - time_spike1_ains[1])
        av_fr_ain = (fr_ains[0] + fr_ains[1]) / 2
        av_isi_ain = 1000 * (1 / av_fr_ain)
        phase_diff_ain = spike1_diff_ain / av_isi_ain
        spike1_diff_cin = np.abs(time_spike1_cins[0] - time_spike1_cins[1])
        av_fr_cin = (fr_cins[0] + fr_cins[1]) / 2
        av_isi_cin = 1000 * (1 / av_fr_cin)
        phase_diff_cin = spike1_diff_cin / av_isi_cin
        spike1_diff_din = np.abs(time_spike1_dins[0] - time_spike1_dins[1])
        av_fr_din = (fr_dins[0] + fr_dins[1]) / 2
        av_isi_din = 1000 * (1 / av_fr_din)
        phase_diff_din = spike1_diff_din / av_isi_din
        sum_stats = np.array((av_fr_mn, phase_diff_mn, av_fr_ain, phase_diff_ain, av_fr_cin, phase_diff_cin, av_fr_din, phase_diff_din))
        return sum_stats

In [24]:
# simulation wrapper to take in the parameters, run an AdEx simulation and calculate the summary statistics

def simulation_wrapper(params):
    obs = run_AdEx_model(params, time, delta_t, V_thr)
    summstats = torch.as_tensor(calculate_summary_statistics(obs, delta_t,V_thr,t_on))
    return summstats

In [25]:
# experimental parameters: firing rate for MNs, dINs, cINs = 20Hz while for aINs = 10Hz
# phase difference between left and right side for each population should be 0.5

exp_stats = np.array((20, 0.5, 10, 0.5, 20, 0.5, 20, 0.5))

In [26]:
# generate priors

prior_min = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
prior_max = [50, 50, 50, 50, 50, 50, 50, 50, 100, 100, 100, 100, 100, 100, 100, 100, 20, 10, 1]
prior =utils.torchutils.BoxUniform(
        low=torch.as_tensor(prior_min), high=torch.as_tensor(prior_max))

In [27]:
simulator, prior = prepare_for_sbi(simulation_wrapper,prior)
inference = SNPE(prior=prior)

UnboundLocalError: local variable 'mn_g_ex' referenced before assignment

In [None]:
# single round SBI
theta, x = simulate_for_sbi(simulator, prior, num_simulations=500, num_workers=4)
density_estimator = inference.append_simulations(theta,x,proposal=prior).train()
posterior = inference.build_posterior(density_estimator)
posterior_sample = posterior.sample((1,), x=exp_stats).numpy()

fig = plt.figure(figsize=(7, 5))

# plot observations
mn_obs = mn_trace
ain_obs = ain_trace
cin_obs = cin_trace
din_obs = din_trace
plt.plot(np.arange(0,time,delta_t), mn_obs, lw=2, label="MN observation")
plt.plot(np.arange(0,time,delta_t), ain_obs, lw=2, label="aIN observation")
plt.plot(np.arange(0,time,delta_t), cin_obs, lw=2, label="cIN observation")
plt.plot(np.arange(0,time,delta_t), din_obs, lw=2, label="dIN observation")

# simulate and plot samples
x_mn, x_ain, x_cin, x_din = run_AdEx_model(posterior_sample[0], time, delta_t, V_thr)
# sum_stats = calculate_summary_statistics(x_mn, x_ain, x_cin, x_din, delta_t, time, V_thr)
plt.plot(np.arange(0,time,delta_t), x_mn, "--", lw=2, label="MN posterior sample")
plt.plot(np.arange(0,time,delta_t), x_ain, "--", lw=2, label="aIN posterior sample")
plt.plot(np.arange(0,time,delta_t), x_cin, "--", lw=2, label="cIN posterior sample")
plt.plot(np.arange(0,time,delta_t), x_din, "--", lw=2, label="dIN posterior sample")

plt.xlabel("time (ms)")
plt.ylabel("voltage (mV)")
plt.savefig('Figures/SBI_CPG_500sims.png')

ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles[::-1], labels[::-1], bbox_to_anchor=(1.3, 1), loc="upper right")