In [1]:
from typing import Mapping

import pandas as pd
import matplotlib.pyplot as plt
from brian2 import *
import sbi.utils
import sbi.analysis
import sbi.inference
import torch

from brian2modelfitting import *

In [2]:
def calc_prior(parameter_names, **params):
    for param in parameter_names:
        if param not in params:
            raise TypeError(f'"Bounds must be set for parameter {param}')
    prior_min = []
    prior_max = []
    for name in parameter_names:
        prior_min.append(min(params[name]).item())
        prior_max.append(max(params[name]).item())
    prior = sbi.utils.torchutils.BoxUniform(low=torch.as_tensor(prior_min),
                                            high=torch.as_tensor(prior_max))
    return prior

In [3]:
class Inferencer(object):
    def __init__(self, dt, model, input, output, method=None, threshold=None,
                 reset=None, refractory=None, param_init=None):
        self.dt = dt
        
        if isinstance(model, str):
            self.model = Equations(model)
        else:
            raise TypeError('Equations must be appropriately formatted.')
        
        if not isinstance(input, Mapping):
            raise TypeError('`input` argument must be a dictionary mapping'
                            ' the name of the input variable and `input`.')
        if len(input) > 1:
            raise NotImplementedError('Only a single input is supported.')
        self.input_var = list(input.keys())[0]
        self.input = input[self.input_var]
        
        if not isinstance(output, Mapping):
            raise TypeError('`output` argument must be a dictionary mapping'
                            ' the name of the output variable and `output`')
        if len(output) > 1:
            raise NotImplementedError('Only a single output is supported')
        self.output_var = list(output.keys())[0]
        self.output = output[self.output_var]
        
        self.param_names = self.model.parameter_names
        self.method = method
        self.threshold = threshold
        self.reset = reset
        self.refractory = refractory
        self.param_init = param_init
        
        n_steps = self.input.size
        self.sim_time = self.dt * n_steps

    def generate(self, n_samples, **params):
        try:
            n_samples = int(n_samples)
        except ValueError as e:
            print(e)
        for param in params:
            if param not in self.param_names:
                raise ValueError(f'Parameter {param} must be defined as a'
                                 ' model\'s parameter')
        prior = calc_prior(self.param_names, **params)
        theta = prior.sample((n_samples, ))
        theta = np.atleast_2d(theta.numpy())
        G = NeuronGroup(N=n_samples,
                        model=self.model,
                        method=self.method,
                        threshold=self.threshold,
                        reset=self.reset,
                        refractory=self.refractory,
                        dt=self.dt)
        if self.param_init:
            for pinit_key, pinit_val in self.param_init.items():
                G.__setattr__(pinit_key, pinit_val)
        K = TimedArray(self.input.transpose(), dt=self.dt)
        G.add_attribute(self.input_var)
        G.__setattr__(self.input_var, 'K(t)')
        for param_idx, param_name in enumerate(self.param_names):
            param_dimension = get_dimensions(params[param_name])
            param_unit = get_unit(param_dimension)
            G.__setattr__(param_name, theta[:, param_idx] * param_unit)
        M = StateMonitor(G, self.output_var, record=True, dt=self.dt)
        run(self.sim_time)
        
    def train(self):
        pass
    
    def sample(self):
        pass

In [4]:
df_inp_traces = pd.read_csv('data/input_traces_hh.csv')
df_out_traces = pd.read_csv('data/output_traces_hh.csv')

inp_traces = df_inp_traces.to_numpy()
inp_traces = inp_traces[0, 1:]

out_traces = df_out_traces.to_numpy()
out_traces = out_traces[0, 1:]

In [5]:
start_scope()
gleak = 10*nS
Eleak = -70*mV
VT = -60.0*mV
C = 200*pF
ENa = 53*mV
EK = -107*mV
eqs = '''
    dVm/dt = -(gNa*m**3*h*(Vm - ENa) + gK*n**4*(Vm - EK) + gleak*(Vm - Eleak) - I) / C : volt
    dm/dt = alpham*(1-m) - betam*m : 1
    dn/dt = alphan*(1-n) - betan*n : 1
    dh/dt = alphah*(1-h) - betah*h : 1
    alpham = (-0.32/mV) * (Vm - VT - 13.*mV) / (exp((-(Vm - VT - 13.*mV))/(4.*mV)) - 1)/ms : Hz
    betam = (0.28/mV) * (Vm - VT - 40.*mV) / (exp((Vm - VT - 40.*mV)/(5.*mV)) - 1)/ms : Hz
    alphah = 0.128 * exp(-(Vm - VT - 17.*mV) / (18.*mV))/ms : Hz
    betah = 4/(1 + exp((-(Vm - VT - 40.*mV)) / (5.*mV)))/ms : Hz
    alphan = (-0.032/mV) * (Vm - VT - 15.*mV) / (exp((-(Vm - VT - 15.*mV)) / (5.*mV)) - 1)/ms : Hz
    betan = 0.5*exp(-(Vm - VT - 10.*mV) / (40.*mV))/ms : Hz
    # The parameters to fit
    gNa : siemens (constant)
    gK : siemens (constant)
    '''

inferencer = Inferencer(dt=0.5 * ms, model=eqs,
                        input={'I': inp_traces*amp},
                        output={'Vm': out_traces*volt},
                        method='exponential_euler',
                        threshold='m>0.5',
                        refractory='m>0.5',
                        param_init={'Vm': 'Eleak',
                                    'm': '1/(1 + betam/alpham)',
                                    'h': '1/(1 + betah/alphah)',
                                    'n': '1/(1 + betan/alphan)'})

inferencer.generate(10, gNa=[0.5*uS, 80.*uS], gK=[1e-4*uS, 15.*uS])

BrianObjectException: Error encountered with object named "neurongroup".
Object was created here (most recent call only, full details in debug log):
  File "<ipython-input-3-ab60428ddd13>", line 49, in generate
    G = NeuronGroup(N=n_samples,

An error occurred when preparing an object. (See above for original error message and traceback.)