In [61]:
from typing import Mapping
import numbers

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import sbi.utils
import sbi.analysis
import sbi.inference
import torch

from brian2 import (get_device, get_local_namespace, Quantity, NeuronGroup, Network, StateMonitor, second, ms, volt, mV, amp, nA, siemens, uS, nS, farad, pF)
from brian2.input import TimedArray
from brian2.core.functions import Function
from brian2.equations.equations import Equations
from brian2.devices import set_device, reset_device, device
from brian2.devices.cpp_standalone.device import CPPStandaloneDevice
from brian2.units.fundamentalunits import DIMENSIONLESS, get_dimensions


from brian2modelfitting.simulator import RuntimeSimulator, CPPStandaloneSimulator

In [2]:
def setup_fit():
    """Return a fresh model-fitting simulator matching the active Brian2 device.

    If the active device is the C++ standalone device and it has already been
    run, it is reinitialised and re-activated with its previous build options
    before the simulator is created.

    Returns
    -------
    CPPStandaloneSimulator or RuntimeSimulator
        A new simulator instance for the active device.

    Raises
    ------
    NotImplementedError
        If the active device is neither ``CPPStandaloneDevice`` nor
        ``RuntimeDevice``.
    """
    # Map device class names to simulator *classes*; only the one that is
    # actually needed gets instantiated.
    simulator_classes = {'CPPStandaloneDevice': CPPStandaloneSimulator,
                         'RuntimeDevice': RuntimeSimulator}

    current_device = get_device()
    if isinstance(current_device, CPPStandaloneDevice) and current_device.has_been_run:
        # Copy the options before reinit() so they can be re-applied afterwards.
        build_options = dict(current_device.build_options)
        current_device.reinit()
        current_device.activate(**build_options)

    device_name = get_device().__class__.__name__
    try:
        simulator_class = simulator_classes[device_name]
    except KeyError:
        raise NotImplementedError(
            f'Device {device_name!r} is not supported; only '
            f'{sorted(simulator_classes)} can be used.') from None
    return simulator_class()

In [3]:
def get_full_namespace(additional_namespace, level=0):
    """Build a Brian2 namespace from a caller's variables plus extra entries.

    The namespace returned by Brian2's ``get_local_namespace`` for the frame
    selected by ``level`` is filtered down to ``Quantity`` objects, plain
    numbers and Brian2 ``Function`` objects. ``additional_namespace`` is then
    merged on top, so its entries win on name clashes.

    Parameters
    ----------
    additional_namespace : Mapping
        Extra names to include, e.g. ``TimedArray`` inputs.
    level : int, optional
        Brian2-style frame level: 0 selects the direct caller of this
        function; each intermediate function call adds 1.

    Returns
    -------
    dict
        The combined namespace.
    """
    keep_types = (Quantity, numbers.Number, Function)
    # Called directly in this function's frame so that the frame depth
    # expected by ``level + 1`` is preserved.
    caller_ns = get_local_namespace(level=level + 1)
    namespace = dict(
        (name, obj) for name, obj in caller_ns.items() if isinstance(obj, keep_types)
    )
    namespace.update(additional_namespace)
    return namespace

In [32]:
def get_param_dic(params, param_names, n_samples):
    """Split a parameter matrix into one array per parameter name.

    Parameters
    ----------
    params : array_like
        Parameter values, one column per name in ``param_names``; typically
        of shape ``(n_samples, n_params)``. A 1-D input is treated as a
        single parameter set that is broadcast to all samples.
    param_names : iterable of str
        Names for the columns, in column order. Extra names or columns are
        silently ignored (``zip`` semantics).
    n_samples : int
        Length of each output array.

    Returns
    -------
    dict
        ``{name: float array of shape (n_samples,)}``.
    """
    columns = np.array(params).T
    return {name: np.ones((n_samples, )) * column
            for name, column in zip(param_names, columns)}

In [4]:
def calc_prior(parameter_names, **params):
    """Build a uniform box prior over the given model parameters.

    Parameters
    ----------
    parameter_names : iterable of str
        Parameters to include, in the order of the prior's dimensions.
    **params
        ``name=bounds`` for every entry of ``parameter_names``. ``bounds`` is
        a sequence whose minimum and maximum become the prior's lower and
        upper limits. It may hold plain numbers or Brian2 quantities;
        quantities are converted to their plain numerical values.

    Returns
    -------
    sbi.utils.torchutils.BoxUniform
        Prior with one dimension per entry of ``parameter_names``.

    Raises
    ------
    TypeError
        If bounds are missing for one of ``parameter_names``.
    """
    for param in parameter_names:
        if param not in params:
            raise TypeError(f'Bounds must be set for parameter {param}')

    prior_min = []
    prior_max = []
    for name in parameter_names:
        bounds = params[name]
        # Builtin min/max keep Brian2's unit checks (mixed units raise).
        # float(np.asarray(...)) works for both plain Python numbers (which
        # have no ``.item()``) and 0-d Quantities.
        prior_min.append(float(np.asarray(min(bounds))))
        prior_max.append(float(np.asarray(max(bounds))))

    prior = sbi.utils.torchutils.BoxUniform(low=torch.as_tensor(prior_min),
                                            high=torch.as_tensor(prior_max))
    return prior

In [45]:
class Inferencer(object):
    """Simulation-based inference front end for a single-input,
    single-output Brian2 neuron model.

    Parameters
    ----------
    dt : Quantity
        Simulation time step. It is also used as the sampling step of the
        ``input`` and ``output`` traces (both are wrapped in ``TimedArray``
        objects with this ``dt``).
    model : str or Equations
        Model equations. An equation ``<input name> = input_var(t) : <unit>``
        is appended automatically to feed the input trace into the model.
    input : Mapping
        Single-entry mapping ``{input variable name: input trace}``.
    output : Mapping
        Single-entry mapping ``{output variable name: output trace}``.
    method, threshold, reset, refractory : optional
        Forwarded to ``NeuronGroup``.
    param_init : dict, optional
        Initial values for model variables or parameters. They are forwarded
        to the simulator's ``initialize``.
    """

    def __init__(self, dt, model, input, output, method=None, threshold=None,
                 reset=None, refractory=None, param_init=None):
        self.dt = dt

        # Accept either an equation string or an already-built Equations.
        if isinstance(model, Equations):
            self.model = model
        elif isinstance(model, str):
            self.model = Equations(model)
        else:
            raise TypeError('Equations must be appropriately formatted.')

        if not isinstance(input, Mapping):
            raise TypeError('`input` argument must be a dictionary mapping'
                            ' the name of the input variable and `input`.')
        if len(input) > 1:
            raise NotImplementedError('Only a single input is supported.')
        self.input_var = list(input.keys())[0]
        self.input = input[self.input_var]

        if not isinstance(output, Mapping):
            raise TypeError('`output` argument must be a dictionary mapping'
                            ' the name of the output variable and `output`')
        if len(output) > 1:
            raise NotImplementedError('Only a single output is supported')
        self.output_var = list(output.keys())[0]
        self.output = output[self.output_var]

        # Express the input's physical dimensions as a Brian2 unit string.
        input_dim = get_dimensions(self.input)
        input_dim = '1' if input_dim is DIMENSIONLESS else repr(input_dim)
        input_eqs = f'{self.input_var} = input_var(t) : {input_dim}'
        # Explicit ``a = a + b`` (not ``+=``) so that a user-supplied
        # Equations object is never modified in place.
        self.model = self.model + input_eqs

        # ``input_var`` in the equation above is resolved to this TimedArray
        # via the namespace built in ``setup_simulator``.
        self.input_traces = TimedArray(self.input.transpose(), dt=self.dt)

        n_steps = self.input.size
        self.sim_time = self.dt * n_steps

        if not param_init:
            param_init = {}
        for param in param_init:
            if not (param in self.model.diff_eq_names or
                    param in self.model.parameter_names):
                raise ValueError(f'{param} is not a model variable or a'
                                 ' parameter in the model')
        self.param_init = param_init

        # Sorted, so that the column order of the sampled parameter matrix
        # is reproducible across Python processes. Iteration order of a set
        # of strings changes with per-process hash randomization.
        self.param_names = sorted(self.model.parameter_names)
        self.method = method
        self.threshold = threshold
        self.reset = reset
        self.refractory = refractory

    def setup_simulator(self, n_samples, output_var, param_init, level=1):
        """Build the network of ``n_samples`` neurons and initialize a simulator.

        Parameters
        ----------
        n_samples : int
            Number of neurons, one per parameter set.
        output_var : str
            Variable recorded by the ``StateMonitor``.
        param_init : dict
            Initial values passed to ``simulator.initialize``.
        level : int, optional
            Brian2-style frame level for resolving external constants.
            0 selects the direct caller of this method, and each intermediate
            call adds 1.

        Returns
        -------
        The simulator returned by ``setup_fit``, already initialized.
        """
        simulator = setup_fit()

        namespace = get_full_namespace({'input_var': self.input_traces},
                                       level=level+1)
        namespace['output_var'] = TimedArray(self.output.transpose(), dt=self.dt)

        kwargs = {}
        if self.method is not None:
            kwargs['method'] = self.method
        # Shared integer variable set by the simulator at run time.
        model = self.model + Equations('iteration : integer (constant, shared)')
        neurons = NeuronGroup(N=n_samples,
                              model=model,
                              threshold=self.threshold,
                              reset=self.reset,
                              refractory=self.refractory,
                              dt=self.dt,
                              namespace=namespace,
                              name='neurons',
                              **kwargs)
        network = Network(neurons)
        network.add(StateMonitor(source=neurons, variables=output_var,
                                 record=True, dt=self.dt, name='statemonitor'))
        simulator.initialize(network, param_init)
        return simulator

    def generate(self, n_samples, level=0, **params):
        """Sample parameters from a uniform prior and simulate the model.

        Parameters
        ----------
        n_samples : int-like
            Number of parameter sets to draw, which is also the number of
            simulated neurons.
        level : int, optional
            Brian2-style frame level for resolving external constants used by
            the model. 0 selects the direct caller of ``generate``, and each
            wrapper function in between adds 1.
        **params
            ``name=[low, high]`` bounds for every model parameter.

        Returns
        -------
        numpy.ndarray
            Sampled parameters of shape ``(n_samples, n_params)``. Columns are
            ordered like ``self.param_names``. The array is also stored as
            ``self.theta``.

        Raises
        ------
        ValueError
            If ``n_samples`` is not a positive integer, or a bound is given
            for a name that is not a model parameter.
        TypeError
            If bounds are missing for a model parameter (from ``calc_prior``).
        """
        try:
            n_samples = int(n_samples)
        except (TypeError, ValueError) as e:
            raise ValueError(
                f'`n_samples` must be an integer, got {n_samples!r}') from e
        if n_samples < 1:
            raise ValueError(f'`n_samples` must be positive, got {n_samples}')

        for param in params:
            if param not in self.param_names:
                raise ValueError(f'Parameter {param} must be defined as a'
                                 ' model\'s parameter')
        prior = calc_prior(self.param_names, **params)
        theta = np.atleast_2d(prior.sample((n_samples, )).numpy())

        # +1 accounts for this method's own frame, so ``level`` refers to
        # generate's caller (the default level=0 gives the previous level=1).
        self.simulator = self.setup_simulator(n_samples=n_samples,
                                              output_var=self.output_var,
                                              param_init=self.param_init,
                                              level=level + 1)
        d_param = get_param_dic(theta, self.param_names, n_samples)
        self.simulator.run(self.sim_time, d_param, self.param_names, 0)

        # Keep the parameters so they can be paired with the simulated traces.
        self.theta = theta
        return theta

    def train(self):
        """Placeholder for training the inference network; currently a no-op."""
        pass

    def sample(self):
        """Placeholder for sampling from the posterior; currently a no-op."""
        pass

In [62]:
# Seed torch's global RNG so the prior samples drawn in ``generate`` are
# reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

df_inp_traces = pd.read_csv('data/input_traces_hh.csv')
df_out_traces = pd.read_csv('data/output_traces_hh.csv')

# Use only the first row (trace) of each file and drop its first column.
# NOTE(review): the first column is presumably a saved index column; confirm
# against the CSV layout.
inp_traces = df_inp_traces.to_numpy()[0, 1:]
out_traces = df_out_traces.to_numpy()[0, 1:]

# Model constants. The equations refer to them by name, and they are looked
# up from this namespace when the simulator is set up inside ``generate``.
gleak = 10 * nS
Eleak = -70 * mV
VT = -60.0 * mV
C = 200 * pF
ENa = 53 * mV
EK = -107 * mV
eqs = '''
    dVm/dt = -(gNa*m**3*h*(Vm - ENa) + gK*n**4*(Vm - EK) + gleak*(Vm - Eleak) - I) / C : volt
    dm/dt = alpham*(1-m) - betam*m : 1
    dn/dt = alphan*(1-n) - betan*n : 1
    dh/dt = alphah*(1-h) - betah*h : 1
    alpham = (-0.32/mV) * (Vm - VT - 13.*mV) / (exp((-(Vm - VT - 13.*mV))/(4.*mV)) - 1)/ms : Hz
    betam = (0.28/mV) * (Vm - VT - 40.*mV) / (exp((Vm - VT - 40.*mV)/(5.*mV)) - 1)/ms : Hz
    alphah = 0.128 * exp(-(Vm - VT - 17.*mV) / (18.*mV))/ms : Hz
    betah = 4/(1 + exp((-(Vm - VT - 40.*mV)) / (5.*mV)))/ms : Hz
    alphan = (-0.032/mV) * (Vm - VT - 15.*mV) / (exp((-(Vm - VT - 15.*mV)) / (5.*mV)) - 1)/ms : Hz
    betan = 0.5*exp(-(Vm - VT - 10.*mV) / (40.*mV))/ms : Hz
    gNa : siemens (constant)
    gK : siemens (constant)
'''

# NOTE(review): assumes the CSV traces are stored in amp / volt (SI units);
# confirm against how the data files were produced.
inferencer = Inferencer(dt=0.5*ms, model=eqs,
                        input={'I': inp_traces*amp},
                        output={'Vm': out_traces*volt},
                        method='exponential_euler',
                        threshold='m>0.5',
                        refractory='m>0.5',
                        param_init={'Vm': 'Eleak',
                                    'm': '1/(1 + betam/alpham)',
                                    'h': '1/(1 + betah/alphah)',
                                    'n': '1/(1 + betan/alphan)'})

inferencer.generate(10, gNa=[0.5*uS, 80.*uS], gK=[1e-4*uS, 15.*uS])