In [None]:
import sys
from pathlib import Path
import time

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from scipy import stats as spstats
import statsmodels.api as sm

# pytorch related stuff
import torch 

# SBI related stuff
import sbi
from sbi.inference import SNPE
from sbi import utils as sutils

# modules
from metrics import HH_metrics as metrics

In [None]:
# make the modules under ./HH/ (hh_cython, utils — imported in the next cell) importable;
# the guard keeps re-runs of this cell from appending the same entry again
if './HH/' not in sys.path:
    sys.path.append('./HH/')

In [None]:
import hh_cython
import utils

In [None]:
# integrator exported by the compiled hh_cython module (by its name, a forward-Euler
# scheme); HH.sim_time calls it through this module-level alias
solver = hh_cython.forwardeuler

# Utilities: saving and loading posterior samples

In [None]:
itr = 1  # experiment iteration; selects the results sub-folder
path = "results/HH/" + str(itr)
(Path(path) / "samples").mkdir(parents=True, exist_ok=True)

In [None]:
def save_sample(tensor, name):
    """Write ``tensor`` to ``<path>/samples/<name>.pt``, creating the folder if missing."""
    sample_dir = Path(f"{path}/samples")
    sample_dir.mkdir(parents=True, exist_ok=True)
    torch.save(tensor, f"{path}/samples/{name}.pt")

def load_sample(name):
    """Read back the object stored at ``<path>/samples/<name>.pt`` (see ``save_sample``)."""
    return torch.load(f"{path}/samples/{name}.pt")

# Simulator (Hodgkin–Huxley model, 8 parameters)

In [None]:
# problem dimensions
dim_theta = 8  # number of HH parameters (length of prior_min / prior_max)
dim_x = 13     # number of summary statistics per simulation

# seed torch's and numpy's global RNGs
seed = 247
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
class HH:
    """Stochastic Hodgkin-Huxley neuron integrated by the compiled ``hh_cython`` solver.

    Parameters
    ----------
    init : float
        Initial membrane voltage, written to ``V[0]`` of every simulated trace.
    params : array-like
        Model parameters, stored as a numpy array and handed to
        ``hh_cython.setparams`` before each simulation.
    seed : int, optional
        If given, seeds ``hh_cython`` and the noise RNG. If omitted, the noise
        RNG is seeded from numpy's global RNG, so ``np.random.seed(...)`` makes
        simulations reproducible (previously it was seeded from OS entropy).
    """

    def __init__(self, init, params, seed=None):
        self.state = init
        self.params = np.asarray(params)

        self.seed = seed
        if seed is not None:
            hh_cython.seed(seed)
            self.rng = np.random.RandomState(seed=seed)
        else:
            # derive the seed from the global numpy RNG so a notebook-level
            # np.random.seed(...) governs the noise drawn in sim_time
            self.rng = np.random.RandomState(seed=np.random.randint(2**31 - 1))

    def sim_time(self, dt, t, I, fineness=1, max_n_steps=float('inf')):
        """Simulate the model on the time grid ``t`` driven by the current ``I``.

        ``fineness`` and ``max_n_steps`` are accepted for interface
        compatibility but are not used.

        Returns
        -------
        numpy.ndarray
            Voltage trace as a column array of shape ``(len(t), 1)``.
        """
        hh_cython.setparams(self.params)
        tstep = float(dt)

        # the compiled solver works on float64 buffers: cast the inputs explicitly
        t = t.astype(np.float64)
        I = I.astype(np.float64)

        # V is filled in place by the solver, starting from the stored initial voltage
        V = np.zeros_like(t)
        V[0] = self.state
        # further per-variable state buffers passed to the solver (presumably the
        # gating/auxiliary variables — confirm in hh_cython)
        m, n, h, p, q, r, u = (np.zeros_like(t) for _ in range(7))
        # one standard-normal noise draw per time point
        r_mat = self.rng.randn(len(t)).astype(np.float64)

        solver(t, I, V, m, n, h, p, q, r, u, tstep, r_mat)

        return np.array(V).reshape(-1, 1)

In [None]:
# reference ("true") parameter vector and its labels from the HH helper module;
# reduced_model=False presumably selects the full parameter set — TODO confirm in utils.obs_params
true_params, labels_params = utils.obs_params(reduced_model=False)

In [None]:
def syn_current(duration=120, dt=0.01, t_on=10, curr_level=5e-4, seed=None):
    """Build a step input current that is switched on between ``t_on`` and ``duration - t_on``.

    ``seed`` is accepted but not used.

    Returns
    -------
    tuple
        ``(I, t_on, t_off, dt, t, A_soma)``: ``t = np.arange(0, duration + dt, dt)``,
        ``I`` equal to ``curr_level / A_soma`` inside the window and 0 elsewhere,
        ``t_off = duration - t_on``, and ``A_soma`` the soma area in cm2.
    """
    t_off = duration - t_on
    t = np.arange(0, duration + dt, dt)

    A_soma = np.pi * ((70.0 * 1e-4) ** 2)  # cm2
    first, last = (int(np.round(edge / dt)) for edge in (t_on, t_off))

    I = np.zeros_like(t)
    I[first:last] = curr_level / A_soma  # muA/cm2
    return I, t_on, t_off, dt, t, A_soma

In [None]:
# default stimulus protocol; calculate_summary_statistics relies on the module-level
# t_on / t_off defined here for its stimulus window
I, t_on, t_off, dt, t, A_soma = syn_current()

In [None]:
# injected step current on its time grid (explicit axes + labels; ';' hides the repr)
fig, ax = plt.subplots()
ax.plot(t, I)
ax.set(xlabel="time", ylabel="I (muA/cm2)", title="Injected step current");

In [None]:
def HHSimulator(init, params, dt, t, I):
    """Run one HH simulation for a single parameter vector.

    ``params`` is reshaped to a single row before being handed to ``HH``; the
    voltage trace is returned as produced by ``HH.sim_time``.
    """
    model = HH(init, params.reshape(1, -1))
    return model.sim_time(dt, t, I)

In [None]:
init = -70  # initial membrane voltage (V[0] of the trace)
V = HHSimulator(init, true_params, dt, t, I)
# sim_time returns a column vector with one sample per time point
assert V.shape == (len(t), 1)
V.shape

In [None]:
# simulated voltage at the reference parameters, against time
fig, ax = plt.subplots()
ax.plot(t, V)
ax.set(xlabel="time", ylabel="V", title="HH voltage trace at true_params");

In [None]:
def calculate_summary_statistics(x):
    """Compute the summary-statistic vector of one simulated HH voltage trace.

    Parameters
    ----------
    x : dict
        Needs ``'data'`` (1-D voltage trace), ``'time'`` (time points aligned
        with ``'data'``) and ``'dt'`` (time step), as returned by ``run_HH_model``.
        Optional ``'t_on'`` / ``'t_off'`` give the stimulus window; when absent
        the module-level ``t_on`` / ``t_off`` (from ``syn_current()``) are used.

    Returns
    -------
    numpy.ndarray or None
        13 values, in order:
        - number of spikes inside the stimulus window;
        - 5 normalised autocorrelations of the stimulated trace (lags of 1..5
          time units);
        - resting potential (mean before ``t_on``);
        - its std over samples ``0.9*t_on/dt .. t_on/dt``;
        - mean stimulated voltage;
        - its variance;
        - its standardised 3rd, 4th and 5th moments.
        ``None`` if the pieces cannot be concatenated.
    """
    n_xcorr = 5
    n_mom = 5
    n_summary = 13

    # stimulus window: values carried by the trace win over the notebook globals
    # (the globals are only looked up when the keys are missing)
    stim_on = x['t_on'] if 't_on' in x else t_on
    stim_off = x['t_off'] if 't_off' in x else t_off

    data = np.asarray(x['data'])
    t = np.asarray(x['time'])
    dt = x['dt']

    # --- spike count ---
    # clamp everything below -10 and every falling sample to -10; the only
    # negative slopes left are then right at the spike peaks
    v = np.array(data)  # copy: the trace itself must not be modified
    v[v < -10] = -10
    v[np.where(np.diff(v) < 0)] = -10
    spike_times = t[np.where(np.diff(v) < 0)]
    spike_times_stim = spike_times[(spike_times > stim_on) & (spike_times < stim_off)]
    if spike_times_stim.shape[0] > 0:
        # drop detections within 0.5 time units of the previous detection
        spike_times_stim = spike_times_stim[np.append(1, np.diff(spike_times_stim)) > 0.5]

    # --- resting potential ---
    rest_pot = np.mean(data[t < stim_on])
    rest_pot_std = np.std(data[int(.9 * stim_on / dt):int(stim_on / dt)])

    # --- stimulated segment: autocorrelations and moments ---
    stim_v = data[(t > stim_on) & (t < stim_off)]
    x_on_off = stim_v - np.mean(stim_v)
    x_corr_val = np.dot(x_on_off, x_on_off)

    # lags of 1/dt, 2/dt, ... samples, zero-padded at the end
    xcorr_steps = np.linspace(1. / dt, n_xcorr * 1. / dt, n_xcorr).astype(int)
    x_corr_full = np.zeros(n_xcorr)
    for ii, lag in enumerate(xcorr_steps):
        shifted = np.concatenate((x_on_off[lag:], np.zeros(lag)))
        x_corr_full[ii] = np.dot(x_on_off, shifted)
    x_corr1 = x_corr_full / x_corr_val

    # integer moment orders (scipy.stats.moment expects integral orders);
    # order 2 stays raw (variance), orders 3..n_mom are divided by std**order
    orders = np.arange(2, n_mom + 1)
    std_pw = np.concatenate((np.ones(1), np.power(np.std(stim_v), orders[1:])))
    moments = spstats.moment(stim_v, orders) / std_pw

    try:
        sum_stats_vec = np.concatenate((
            np.array([spike_times_stim.shape[0]]),
            x_corr1,
            np.array([rest_pot, rest_pot_std, np.mean(stim_v)]),
            moments,
        ))
    except (ValueError, TypeError):
        # narrowed from a bare ``except:``, which also swallowed KeyboardInterrupt
        return None

    return sum_stats_vec[:n_summary]

In [None]:
def run_HH_model(params):
    """Simulate the HH model once for ``params`` under the default step-current protocol.

    Returns
    -------
    dict
        ``data``: voltage trace (1-D), ``time``: time points, ``dt``: time step,
        ``I``: injected current (1-D), and ``t_on`` / ``t_off``: the stimulus
        window. The window keys are included so downstream statistics need not
        rely on notebook globals.
    """
    I, t_on, t_off, dt, _, _ = syn_current()
    # rebuild the time grid from len(I) so it matches the current sample-for-sample
    t = np.arange(len(I)) * dt
    V0 = -70  # initial membrane voltage
    states = HHSimulator(V0, params, dt, t, I)
    return dict(data=states.reshape(-1), time=t, dt=dt, I=I.reshape(-1),
                t_on=t_on, t_off=t_off)

In [None]:
def simulator(params):
    """Map a batch of parameter vectors to a batch of summary-statistic vectors.

    Parameters
    ----------
    params : torch.Tensor or numpy.ndarray
        2-D, one parameter vector per row.

    Returns
    -------
    torch.Tensor
        One row of statistics per parameter row; the dtype follows
        ``calculate_summary_statistics``. A row whose statistics could not be
        computed is filled with NaN instead of aborting the whole batch.
    """
    rows = []
    for i in range(params.shape[0]):
        stats = calculate_summary_statistics(run_HH_model(params[i, :]))
        if stats is None:
            # torch.as_tensor(None) would raise and lose the whole batch; a NaN row
            # (dim_x statistics) keeps x aligned with params and can be filtered later
            stats = np.full(dim_x, np.nan)
        rows.append(torch.as_tensor(stats))
    return torch.stack(rows)

In [None]:
# define prior: independent uniform bounds per HH parameter (index-aligned lists)
prior_min = [.5,1e-4,0.05,0.035,3e2,30.0,0.05,35.0]
prior_max = [80.,15.,0.15,0.105,9e2,90.0,0.15,105.0]
# one (low, high) pair per inferred parameter, each with low < high
assert len(prior_min) == len(prior_max) == dim_theta
assert all(lo < hi for lo, hi in zip(prior_min, prior_max))
prior = sutils.torchutils.BoxUniform(low=torch.as_tensor(prior_min), high=torch.as_tensor(prior_max))

In [None]:
# simulate once at the reference parameters for a visual sanity check
# NOTE(review): the name `results` is reused later for the metrics tensor; a distinct name would be clearer
results = run_HH_model(true_params)

In [None]:
# voltage trace of the reference run, plotted against its own time grid
fig, ax = plt.subplots()
ax.plot(results["time"], results["data"])
ax.set(xlabel="time", ylabel="V", title="run_HH_model(true_params)");

# Experiments: sequential NPE with the real simulator vs. with a learned surrogate

In [None]:
# inspect the reference parameter vector
true_params

In [None]:
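# Provenance: the observation files loaded in the next cell were presumably generated
# once with the code below (kept for reference; the next cell reads them from ./HH/).
# observation_trace = run_HH_model(true_params)
# x_obs = calculate_summary_statistics(observation_trace)
# torch.save(observation_trace, "observation_trace.pkl")
# torch.save(x_obs, "x_obs.pkl")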

In [None]:
def _load_trusted(fname):
    """torch.load for the notebook's own pickled observation files (trusted input only)."""
    try:
        # torch >= 2.6 defaults to weights_only=True, which refuses arbitrary pickled
        # objects (presumably numpy arrays here, see the generating code above)
        return torch.load(fname, weights_only=False)
    except TypeError:
        # older torch without the weights_only argument
        return torch.load(fname)


observation_trace = _load_trusted("./HH/observation_trace.pkl")
x_obs = _load_trusted("./HH/x_obs.pkl")

In [None]:
# observed summary statistics; every posterior below is sampled conditioned on this x
x_obs

In [None]:
# (low, high) bounds for each of the 13 summary statistics, in the order produced by
# calculate_summary_statistics; they define the support of the surrogate's prior
x_min, x_max = (torch.tensor(column) for column in zip(*[
    (0.0, 35.0),        # number of spikes in the stimulus window
    (-1, 1),            # autocorrelation, lag 1
    (-1, 1),            # autocorrelation, lag 2
    (-1, 1),            # autocorrelation, lag 3
    (-1, 1),            # autocorrelation, lag 4
    (-1, 1),            # autocorrelation, lag 5
    (-1.2e+02, 7.0),    # resting potential
    (1.0e-03, 7.0e+01), # resting-potential std
    (-1.0e+02, 1.2e+01),  # mean voltage during stimulation
    (1.0e-03, 1.75e+03),  # variance during stimulation
    (-1.0e+01, 1.5e+01),  # standardised 3rd moment
    (2.5e+00, 1.75e+02),  # standardised 4th moment
    (-8.0e+02, 2.25e+03), # standardised 5th moment
]))

In [None]:
# define sur-prior: uniform box over the summary statistics (bounds x_min / x_max), used as
# the "prior" of the surrogate model, where x and theta swap roles
sur_prior = sutils.torchutils.BoxUniform(low=x_min, high=x_max) # non informative

In [None]:
torch.manual_seed(seed)
num_runs = 2          # sequential rounds per inference run
num_simulations = 10  # independent repetitions per budget
samples_len = [1000, 2500, 5000, 15000, 25000, 50000]  # total simulation budgets
sl = len(samples_len)
# results[repetition, budget, method, metric]; NaN marks entries not computed yet
# (4 methods: reg/sur x mdn/nsf; 4 values per call to metrics — assumed from the shape)
results = torch.full((num_simulations, sl, 4, 4), float('nan'))

In [None]:
# simulations per sequential round (n in the experiment loop) for every total budget
[budget // num_runs for budget in samples_len]

In [None]:
torch.manual_seed(seed)


def _emulate(surrogate, theta, template):
    """Draw one surrogate sample of the summary statistics for each row of ``theta``.

    ``template`` only fixes the shape/dtype of the returned tensor.
    """
    x_emulated = torch.zeros_like(template)
    for row, theta_row in enumerate(theta):
        x_emulated[row] = surrogate.sample((1,), x=theta_row, show_progress_bars=False)
    return x_emulated


def _train_regular(estimator, n):
    """Sequential NPE using the real simulator in each of ``num_runs`` rounds of ``n`` draws.

    Returns the posterior built after the last round.
    """
    proposal = prior
    inference = SNPE(prior, density_estimator=estimator)
    for _ in range(num_runs):
        theta = proposal.sample((n,))
        x_sim = simulator(theta).to(torch.float32)
        density_estimator = inference.append_simulations(theta, x_sim, proposal).train()
        posterior = inference.build_posterior(density_estimator)
        proposal = posterior.set_default_x(x_obs)
    return posterior


def _train_with_surrogate(estimator, n, keep_check=False):
    """Sequential NPE where only the first round calls the real simulator.

    Each round draws ``n * num_runs`` parameters from the current proposal.
    - Round 0 simulates them and fits a surrogate for x given theta: an SNPE run
      over ``sur_prior`` with the roles of x and theta switched.
    - Later rounds emulate x with that surrogate.

    Returns
    -------
    tuple
        ``(posterior, check)``: ``check`` is ``(x_real, x_emulated)`` for the
        round-0 parameters when ``keep_check`` is true, otherwise ``None``.
    """
    proposal = prior
    inference = SNPE(prior, density_estimator=estimator)
    check = None
    for k in range(num_runs):
        theta = proposal.sample((n * num_runs,))
        if k == 0:
            x_sim = simulator(theta).to(torch.float32)
            # train surrogate: x and theta switch roles
            surrogate_inference = SNPE(sur_prior, density_estimator=estimator)
            surrogate_estimator = surrogate_inference.append_simulations(theta=x_sim, x=theta).train()
            surrogate = surrogate_inference.build_posterior(surrogate_estimator)
            if keep_check:
                # simulator vs surrogate on the same parameters, kept for plotting
                check = (x_sim, _emulate(surrogate, theta, x_sim))
        else:
            x_sim = _emulate(surrogate, theta, x_sim)
        density_estimator = inference.append_simulations(theta, x_sim, proposal).train()
        posterior = inference.build_posterior(density_estimator)
        proposal = posterior.set_default_x(x_obs)
    return posterior, check


# (results slot, sample-file prefix, density estimator, use surrogate) — run in this order
EXPERIMENTS = [
    (0, "reg_mdn", "mdn", False),
    (1, "sur_mdn", "mdn", True),
    (2, "reg_nsf", "nsf", False),
    (3, "sur_nsf", "nsf", True),
]
# (estimator, n) -> (x_real, x_emulated), recorded on the first repetition (j == 0)
surrogate_checks = {}

for i in range(sl):
    n = samples_len[i] // num_runs  # simulations per round
    for j in range(num_simulations):
        for slot, tag, estimator, use_surrogate in EXPERIMENTS:
            if use_surrogate:
                posterior, check = _train_with_surrogate(estimator, n, keep_check=(j == 0))
                if check is not None:
                    surrogate_checks[(estimator, n)] = check
            else:
                posterior = _train_regular(estimator, n)
            sample_post = posterior.sample((1000,), x=x_obs)
            save_sample(sample_post, f"{tag}_{n}_{j}")
            results[j, i, slot, :] = metrics(sample_post, true_params)
        # checkpoint after every repetition so partial results survive interruptions
        torch.save(results, f'{path}/results.pkl')
