Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Development #30

Merged
merged 11 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,4 @@ bench_test_old.py
bench_test.py
test2.py
examples_new/val_tau.py
cadex.py
186 changes: 113 additions & 73 deletions dendrify/compartment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
"""
This module defines the classes for different types of compartments in a neuron
model.

The `Compartment` class is a base class that provides the basic functionality for
a single compartment. It handles all differential equations and parameters needed
to describe a single compartment and any currents passing through it.

The `Soma` and `Dendrite` classes inherit from the `Compartment` class and represent
specific types of compartments.

Classes:
Compartment: Represents a single compartment in a neuron model.
Soma: Represents the somatic compartment in a neuron model.
Dendrite: Represents a dendritic compartment in a neuron model.
"""

from __future__ import annotations

import pprint as pp
Expand All @@ -6,7 +23,7 @@
import numpy as np
from brian2 import defaultclock
from brian2.core.functions import timestep
from brian2.units import Quantity, ms, mV, pA
from brian2.units import Quantity, ms, pA

from .ephysproperties import EphysProperties
from .equations import library
Expand Down Expand Up @@ -125,7 +142,7 @@ def _add_equations(self, model: str):
if model in library:
self._equations = library[model].format('_'+self.name)
else:
logger.warning(("The model you provided is not found. The default "
logger.warning(("The model you provided is not found. The default "
"'passive' membrane model will be used instead."))
self._equations = library['passive'].format('_'+self.name)

Expand Down Expand Up @@ -165,12 +182,13 @@ def connect(self,
if self.name == other.name:
raise ValueError(
"Cannot connect compartments with the same name.\n")
if (self.dimensionless or other.dimensionless) and type(g) == str:
if (self.dimensionless or other.dimensionless) and isinstance(g, str):
raise DimensionlessCompartmentError(
("Cannot automatically calculate the coupling \nconductance of "
"dimensionless compartments. To resolve this error, perform\n"
"one of the following:\n\n"
f"1. Provide [length, diameter, r_axial] for both '{self.name}'"
f"1. Provide [length, diameter, r_axial] for both '{
self.name}'"
f" and '{other.name}'.\n\n"
f"2. Turn both compartment into dimensionless by providing only"
" values for \n [cm_abs, gl_abs] and then connect them using "
Expand All @@ -179,23 +197,23 @@ def connect(self,
)

# Current from Comp2 -> Comp1
I_forward = 'I_{1}_{0} = (V_{1}-V_{0}) * g_{1}_{0} :amp'.format(
forward_current = 'I_{1}_{0} = (V_{1}-V_{0}) * g_{1}_{0} :amp'.format(
self.name, other.name)
# Current from Comp1 -> Comp2
I_backward = 'I_{0}_{1} = (V_{0}-V_{1}) * g_{0}_{1} :amp'.format(
self.name, other.name)
backward_current = 'I_{0}_{1} = (V_{0}-V_{1}) * g_{0}_{1} :amp'.format(
self.name, other.name)

# Add them to their respective compartments:
self._equations += '\n'+I_forward
other._equations += '\n'+I_backward
self._equations += '\n'+forward_current
other._equations += '\n'+backward_current

# Include them to the I variable (I_ext -> Inj + new_current):
self_change = f'= I_ext_{self.name}'
other_change = f'= I_ext_{other.name}'
self._equations = self._equations.replace(
self_change, self_change + ' + ' + I_forward.split('=')[0])
self_change, self_change + ' + ' + forward_current.split('=')[0])
other._equations = other._equations.replace(
other_change, other_change + ' + ' + I_backward.split('=')[0])
other_change, other_change + ' + ' + backward_current.split('=')[0])

# add them to connected comps
if not self._connections:
Expand Down Expand Up @@ -342,11 +360,11 @@ def noise(self, tau: Quantity = 20*ms, sigma: Quantity = 1*pA,
mean : :class:`~brian2.units.fundamentalunits.Quantity`, optional
Mean of the Gaussian noise, by default ``0*pA``
"""
I_noise_name = f'I_noise_{self.name}'
noise_current = f'I_noise_{self.name}'

if I_noise_name in self.equations:
if noise_current in self.equations:
raise DuplicateEquationsError(
f"The equations of '{I_noise_name}' have already been "
f"The equations of '{noise_current}' have already been "
f"added to '{self.name}'. \nYou might be seeing this error if "
"you are using Jupyter/iPython "
"which store variable values \nin memory. Try cleaning all "
Expand All @@ -359,7 +377,7 @@ def noise(self, tau: Quantity = 20*ms, sigma: Quantity = 1*pA,
to_change = f'= I_ext_{self.name}'
self._equations = self._equations.replace(
to_change,
f'{to_change} + {I_noise_name}'
f'{to_change} + {noise_current}'
)
self._equations += '\n'+noise_eqs

Expand Down Expand Up @@ -462,10 +480,21 @@ def _g_couples(self) -> Union[dict, None]:
return d_out

@staticmethod
def g_norm_factor(trise: Quantity, tdecay: Quantity):
tpeak = (tdecay*trise / (tdecay-trise)) * np.log(tdecay/trise)
factor = (((tdecay*trise) / (tdecay-trise))
* (-np.exp(-tpeak/trise) + np.exp(-tpeak/tdecay))
def g_norm_factor(t_rise: Quantity, t_decay: Quantity):
"""
Calculates the normalization factor for synaptic conductance with
t_rise and t_decay kinetics.

Parameters:
t_rise (Quantity): The rise time of the function.
t_decay (Quantity): The decay time of the function.

Returns:
float: The normalization factor for the g function.
"""
t_peak = (t_decay*t_rise / (t_decay-t_rise)) * np.log(t_decay/t_rise)
factor = (((t_decay*t_rise) / (t_decay-t_rise))
* (-np.exp(-t_peak/t_rise) + np.exp(-t_peak/t_decay))
/ ms)
return 1/factor

Expand All @@ -478,12 +507,14 @@ def dimensionless(self) -> bool:
-------
bool
"""
return True if self._ephys_object._dimensionless else False
return bool(self._ephys_object._dimensionless)


class Soma(Compartment):
"""
A class that automatically generates and handles all differential equations
A class representing a somatic compartment in a neuron model.

This class automatically generates and handles all differential equations
and parameters needed to describe a somatic compartment and any currents
(synaptic, dendritic, noise) passing through it.

Expand Down Expand Up @@ -739,8 +770,8 @@ def dspikes(self, name: str,

# The following code creates all necessary equations for dspikes:
comp = self.name
ID = f"{name}_{comp}"
event_name = f"spike_{ID}"
event_id = f"{name}_{comp}"
event_name = f"spike_{event_id}"

if self._events:
# Check if this event already exists
Expand All @@ -760,52 +791,59 @@ def dspikes(self, name: str,
else:
self._events = {}

dspike_currents = f"I_rise_{ID} + I_fall_{ID}"
dspike_currents = f"I_rise_{event_id} + I_fall_{event_id}"

# Both currents take into account the reversal potential of Na/K
I_rise_eqs = f"I_rise_{ID} = g_rise_{ID} * (E_rise_{name}-V_{comp}) :amp"
I_fall_eqs = f"I_fall_{ID} = g_fall_{ID} * (E_fall_{name}-V_{comp}) :amp"
current_rise_eqs = f"I_rise_{event_id} = g_rise_{
event_id} * (E_rise_{name}-V_{comp}) :amp"
current_fall_eqs = f"I_fall_{event_id} = g_fall_{
event_id} * (E_fall_{name}-V_{comp}) :amp"

# Ion conductances
g_rise_eqs = (
f"g_rise_{ID} = "
f"g_rise_max_{ID} * "
f"int(t_in_timesteps <= spiketime_{ID} + duration_rise_{ID}) * "
f"gate_{ID} "
f"g_rise_{event_id} = "
f"g_rise_max_{event_id} * "
f"int(t_in_timesteps <= spiketime_{
event_id} + duration_rise_{event_id}) * "
f"gate_{event_id} "
":siemens"
)
g_fall_eqs = (
f"g_fall_{ID} = "
f"g_fall_max_{ID} * "
f"int(t_in_timesteps <= spiketime_{ID} + offset_fall_{ID} + duration_fall_{ID}) * "
f"int(t_in_timesteps >= spiketime_{ID} + offset_fall_{ID}) * "
f"gate_{ID} "
f"g_fall_{event_id} = "
f"g_fall_max_{event_id} * "
f"int(t_in_timesteps <= spiketime_{
event_id} + offset_fall_{event_id} + duration_fall_{event_id}) * "
f"int(t_in_timesteps >= spiketime_{
event_id} + offset_fall_{event_id}) * "
f"gate_{event_id} "
":siemens"
)
spiketime = f'spiketime_{ID} :1' # in units of timestep
gate = f'gate_{ID} :1' # zero or one
spiketime = f'spiketime_{event_id} :1' # in units of timestep
gate = f'gate_{event_id} :1' # zero or one

# Add equations to a compartment
to_replace = f'= I_ext_{comp}'
self._equations = self._equations.replace(
to_replace,
f'{to_replace} + {dspike_currents}'
)
self._equations += '\n'.join(['', I_rise_eqs, I_fall_eqs,
self._equations += '\n'.join(['', current_rise_eqs, current_fall_eqs,
g_rise_eqs, g_fall_eqs,
spiketime, gate]
)

# Create and add custom dspike event
event_name = f"spike_{ID}"
condition = (f"V_{comp} >= Vth_{ID} and "
f"t_in_timesteps >= spiketime_{ID} + refractory_{ID} * gate_{ID}"
event_name = f"spike_{event_id}"
condition = (f"V_{comp} >= Vth_{event_id} and "
f"t_in_timesteps >= spiketime_{
event_id} + refractory_{event_id} * gate_{event_id}"
)

self._events[event_name] = condition

# Specify what is going to happen inside run_on_event()
action = {f"spike_{ID}": f"spiketime_{ID} = t_in_timesteps; gate_{ID} = 1"}
action = {f"spike_{event_id}": f"spiketime_{
event_id} = t_in_timesteps; gate_{event_id} = 1"}
if not self._event_actions:
self._event_actions = action
else:
Expand All @@ -817,28 +855,30 @@ def dspikes(self, name: str,

dt = defaultclock.dt

params = [threshold,
g_rise,
g_fall,
self._ionic_param(reversal_rise),
self._ionic_param(reversal_fall),
self._timestep(duration_rise, dt),
self._timestep(duration_fall, dt),
self._timestep(offset_fall, dt),
self._timestep(refractory, dt)]

vars = [f"Vth_{ID}",
f"g_rise_max_{ID}",
f"g_fall_max_{ID}",
f"E_rise_{name}",
f"E_fall_{name}",
f"duration_rise_{ID}",
f"duration_fall_{ID}",
f"offset_fall_{ID}",
f"refractory_{ID}"]

d = dict(zip(vars, params))
self._dspike_params[ID] = d
params = [
threshold,
g_rise,
g_fall,
self._ionic_param(reversal_rise),
self._ionic_param(reversal_fall),
self._timestep(duration_rise, dt),
self._timestep(duration_fall, dt),
self._timestep(offset_fall, dt),
self._timestep(refractory, dt)]

variables = [
f"Vth_{event_id}",
f"g_rise_max_{event_id}",
f"g_fall_max_{event_id}",
f"E_rise_{name}",
f"E_fall_{name}",
f"duration_rise_{event_id}",
f"duration_fall_{event_id}",
f"offset_fall_{event_id}",
f"refractory_{event_id}"]

d = dict(zip(variables, params))
self._dspike_params[event_id] = d

def _timestep(self,
param: Union[Quantity, None], dt
Expand All @@ -847,26 +887,26 @@ def _timestep(self,
return None
if isinstance(param, Quantity):
return timestep(param, dt)
else:
raise ValueError(
f"Please provide a valid time parameter for '{self.name}'."
)
raise ValueError(
f"Please provide a valid time parameter for '{self.name}'."
)

def _ionic_param(self,
param: Union[str, Quantity, None],
) -> Union[Quantity, None]:
DEFAULT_PARAMS = EphysProperties.DEFAULT_PARAMS
valid_params = {k: v for k, v in DEFAULT_PARAMS.items() if k[0] == 'E'}
default_params = EphysProperties.DEFAULT_PARAMS
valid_params = {k: v for k, v in default_params.items() if k[0] == 'E'}
if not param:
return None
if isinstance(param, Quantity):
return param
elif isinstance(param, str):
if isinstance(param, str):
try:
return DEFAULT_PARAMS[param]
return default_params[param]
except KeyError:
raise ValueError(
f"Please provide a valid ionic parameter for '{self.name}'."
f"Please provide a valid ionic parameter for '{
self.name}'."
" Available options:\n"
f"{pp.pformat(valid_params)}"
)
Expand Down