Skip to content

Commit

Permalink
First working two-compartment example (with pyNN.neuron)
Browse files Browse the repository at this point in the history
  • Loading branch information
apdavison committed Apr 27, 2017
1 parent dd87886 commit 675a577
Show file tree
Hide file tree
Showing 15 changed files with 207 additions and 183 deletions.
49 changes: 24 additions & 25 deletions examples/mc/current_injection_mc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,23 @@
distal=P(x=18.8, y=0, z=0, diameter=18.8),
name="soma")
dend = Segment(proximal=P(x=0, y=0, z=0, diameter=2),
distal=P(x=-200, y=0, z=0, diameter=2),
distal=P(x=-500, y=0, z=0, diameter=2),
name="dendrite",
parent=soma)

# need to specify nseg for dendrite

cell_class = sim.MultiCompartmentNeuron
cell_class.label = "ExampleMultiCompartmentNeuron"
cell_class.insert(pas=sim.PassiveLeak, sections=('soma', )) ###'dendrite')) # or cell_class.whole_cell.pas = sim.PassiveLeak
cell_class.insert(pas=sim.PassiveLeak, sections=('soma', 'dendrite')) # or cell_class.whole_cell.pas = sim.PassiveLeak
cell_class.soma.insert(na=sim.NaChannel) # or cell_class.soma.na = sim.NaChannel
cell_class.soma.insert(kdr=sim.KdrChannel) # or cell_class.soma.kdr = sim.KdrChannel

cell_type = cell_class(morphology=Morphology(segments=(soma, )), ###dend)),
cell_type = cell_class(morphology=Morphology(segments=(soma, dend)),
cm=1.0, #*uF_per_cm2, # allow to set per segment
Ra=123.0, #*ohm_cm) # allow to set per segment?
#pas={"conductance_density": 0.0005, #*S_per_cm2)
# "e_rev":-54.3},
Ra=500.0, #*ohm_cm) # allow to set per segment?
pas={"conductance_density": 0.0003, #*S_per_cm2)
"e_rev":-54.3},
na={"conductance_density": 0.120,
"e_rev": 50.0},
kdr={"conductance_density": 0.036,
Expand All @@ -52,24 +52,23 @@

# === Create a population with two cells ====================================

cells = sim.Population(2, cell_type, initial_values={'v': [-64.0, -65.0]}) #*mV})
cells = sim.Population(2, cell_type, initial_values={'v': [-60.0, -70.0]}) #*mV})

# === Inject current into the soma of cell #0 and the dendrite of cell #1 ===

#step_current = sim.DCSource(amplitude=0.5*nA, start=50.0*ms, stop=400.0*ms)
step_current = sim.DCSource(amplitude=0.1, start=50.0, stop=150.0)
step_current.inject_into(cells[0:1], section="soma")
###step_current.inject_into(cells[1:2], section="dendrite")
step_current.inject_into(cells[1:2], section="dendrite")


# cells[0] --> ID - 1 cell
# cells[0:1] --> PopulationView containing 1 cell

# === Record from both compartments of both cells ===========================

#cells.record('spikes')
cells.record(['na.m', 'kdr.n'], sections=['soma'])
cells.record('v', sections=['soma']) ###, 'dendrite'])
cells.record('spikes')
cells.record(['na.m', 'na.h', 'kdr.n'], locations=['soma'])
cells.record('v', locations=['soma', 'dendrite'])

# === Run the simulation =====================================================

Expand All @@ -86,28 +85,28 @@
# The segment contains one AnalogSignal per compartment and per recorded variable
# and one SpikeTrain per neuron

print("Spike times: {}".format(data.spiketrains))

Figure(
Panel(data.filter(name='soma.v')[0],
ylabel="Membrane potential, soma (mV)",
#yticks=True, ylim=(-66*mV, -48*mV)),
yticks=True), #, ylim=(-66, -48)),
# Panel(data.filter(name='dendrite.v')[0],
# ylabel="Membrane potential, dendrite (mV)",
# #yticks=True, ylim=(-66*mV, -48*mV)),
# yticks=True), #, ylim=(-66, -48)),
yticks=True, ylim=(-80, 40)),
Panel(data.filter(name='dendrite.v')[0],
ylabel="Membrane potential, dendrite (mV)",
yticks=True, ylim=(-70, -45)),
Panel(data.filter(name='soma.na.m')[0],
ylabel="m, soma",
yticks=True), #, ylim=(0, 1)),
# Panel(data.filter(name='soma.kdr.n')[0],
# xticks=True, xlabel="Time (ms)",
# ylabel="n, soma",
# yticks=True), #, ylim=(0, 1)),
yticks=True, ylim=(0, 1)),
Panel(data.filter(name='soma.na.h')[0],
xticks=True, xlabel="Time (ms)",
ylabel="h, soma",
yticks=True, ylim=(0, 1)),
Panel(data.filter(name='soma.kdr.n')[0],
ylabel="n, soma",
xticks=True, xlabel="Time (ms)",
yticks=True), # , ylim=(0, 1)),
yticks=True, ylim=(0, 1)),
title="Responses of two-compartment neurons to current injection",
annotations="Simulated with %s" % options.simulator.upper()
).save("current_injection_mc.png")

#sim.end()
sim.end()
38 changes: 10 additions & 28 deletions pyNN/common/populations.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from pyNN import random, recording, errors, standardmodels, core, space, descriptions
from pyNN.models import BaseCellType
from pyNN.parameters import ParameterSpace, LazyArray, simplify as simplify_parameter_array
from pyNN.recording import files, Variable
from pyNN.recording import files

deprecated = core.deprecated
logger = logging.getLogger("PyNN")
Expand Down Expand Up @@ -407,28 +407,11 @@ def initialize(self, **initial_values):
def find_units(self, variable):
return self.celltype.units[variable.name]

def can_record(self, variable):
def can_record(self, variable, location=None):
"""Determine whether `variable` can be recorded from this population."""
return self.celltype.can_record(variable)

def _resolve_variables(self, variables, sections):
resolved_variables = []
if sections is None:
for var_path in variables:
if "." in var_path:
parts = var_path.split(".")
section = parts[0]
var_name = ".".join(parts[1:])
resolved_variables.append(Variable(section=section, name=var_name))
else:
resolved_variables.append(Variable(None, var_path))
else:
for section in sections:
for var_name in variables:
resolved_variables.append(Variable(section=section, name=var_name))
return resolved_variables
return self.celltype.can_record(variable, location)

def record(self, variables, to_file=None, sampling_interval=None, sections=None):
def record(self, variables, to_file=None, sampling_interval=None, locations=None):
"""
Record the specified variable or variables for all cells in the
Population or view.
Expand All @@ -448,12 +431,11 @@ def record(self, variables, to_file=None, sampling_interval=None, sections=None)
# recording will be reset for the entire population, not just the view
self.recorder.reset()
else:
variables = self._resolve_variables(variables, sections)
logger.debug("%s.record('%s')", self.label, variables)
if self._record_filter is None:
self.recorder.record(variables, self.all_cells, sampling_interval)
self.recorder.record(variables, self.all_cells, sampling_interval, locations)
else:
self.recorder.record(variables, self._record_filter, sampling_interval)
self.recorder.record(variables, self._record_filter, sampling_interval, locations)
if isinstance(to_file, basestring):
self.recorder.file = to_file

Expand All @@ -471,7 +453,7 @@ def record_gsyn(self, to_file=True):
"""
self.record(['gsyn_exc', 'gsyn_inh'], to_file)

def write_data(self, io, variables='all', gather=True, clear=False, annotations=None):
def write_data(self, io, variables='all', gather=True, clear=False, annotations=None, locations=None):
"""
Write recorded data to file, using one of the file formats supported by
Neo.
Expand All @@ -496,9 +478,9 @@ def write_data(self, io, variables='all', gather=True, clear=False, annotations=
"""
logger.debug("Population %s is writing %s to %s [gather=%s, clear=%s]" % (self.label, variables, io, gather, clear))
self.recorder.write(variables, io, gather, self._record_filter, clear=clear,
annotations=annotations)
annotations=annotations, locations=locations)

def get_data(self, variables='all', gather=True, clear=False):
def get_data(self, variables='all', gather=True, clear=False, locations=None):
"""
Return a Neo `Block` containing the data (spikes, state variables)
recorded from the Population.
Expand All @@ -514,7 +496,7 @@ def get_data(self, variables='all', gather=True, clear=False):
If `clear` is True, recorded data will be deleted from the `Population`.
"""
return self.recorder.get(variables, gather, self._record_filter, clear)
return self.recorder.get(variables, gather, self._record_filter, clear, locations=locations)

@deprecated("write_data(file, 'spikes')")
def printSpikes(self, file, gather=True, compatible_output=True):
Expand Down
2 changes: 1 addition & 1 deletion pyNN/mock/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _get_all_signals(self, variable, ids, clear=False):
def _local_count(self, variable, filter_ids=None):
N = {}
if variable == 'spikes':
for id in self.filter_recorded(variable, filter_ids):
for id in self.filter_recorded(recording.Variable('spikes', None), filter_ids):
N[int(id)] = 2
else:
raise Exception("Only implemented for spikes")
Expand Down
7 changes: 5 additions & 2 deletions pyNN/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,11 @@ class BaseCellType(BaseModelType):
conductance_based = True # override for cells with current-based synapses
injectable = True # override for spike sources

def can_record(self, variable):
return variable in self.recordable
def can_record(self, variable, location=None):
if location is None:
return variable in self.recordable
else:
return False


class BaseIonChannelModel(BaseModelType):
Expand Down
14 changes: 8 additions & 6 deletions pyNN/nest/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def _set_status(obj, parameters):
"""Wrapper around nest.SetStatus() to add a more informative error message."""
try:
nest.SetStatus(obj, parameters)
except nest.hl_api.NESTError as e:
raise nest.hl_api.NESTError("%s. Parameter dictionary was: %s" % (e, parameters))
except nest.NESTError as e:
raise nest.NESTError("%s. Parameter dictionary was: %s" % (e, parameters))


class RecordingDevice(object):
Expand Down Expand Up @@ -389,11 +389,13 @@ def _record(self, variable, new_ids, sampling_interval=None):
(http://www.nest-initiative.org/index.php/Analog_recording_with_multimeter, 14/11/11)
we record all analog variables for all requested cells.
"""
if variable == 'spikes':
if variable.location is not None:
raise ValueError("Recording from specific cell locations is not supported for NEST")
if variable.name == 'spikes':
self._spike_detector.add_ids(new_ids)
else:
self.sampling_interval = sampling_interval
self._multimeter.add_variable(variable)
self._multimeter.add_variable(variable.name)
self._multimeter.add_ids(new_ids)

def _get_sampling_interval(self):
Expand All @@ -419,7 +421,7 @@ def _get_spiketimes(self, id):
return self._spike_detector.get_spiketimes([id])[id] # hugely inefficient - to be optimized later

def _get_all_signals(self, variable, ids, clear=False):
data = self._multimeter.get_data(variable, ids, clear=clear)
data = self._multimeter.get_data(variable.name, ids, clear=clear)
if len(ids) > 0:
return numpy.vstack([data[i] for i in ids]).T
else:
Expand All @@ -445,7 +447,7 @@ def _local_count(self, variable, filter_ids):
# for id, l, r in zip(idx, left, right):
# N[id] = r-l
#return N
return self._spike_detector.get_spike_counts(self.filter_recorded('spikes', filter_ids))
return self._spike_detector.get_spike_counts(self.filter_recorded(recording.Variable('spikes', None), filter_ids))

def _clear_simulator(self):
"""
Expand Down
76 changes: 75 additions & 1 deletion pyNN/neuron/cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def guess_units(variable):

class NativeCellType(BaseCellType):

def can_record(self, variable):
def can_record(self, variable, location=None):
# crude check, could be improved
return bool(recordable_pattern.match(variable))

Expand Down Expand Up @@ -585,3 +585,77 @@ def clear_past_spikes(self):
end = self._spike_times.indwhere(">", h.t)
if end > 0:
self._spike_times.remove(0, end - 1) # range is inclusive


PROXIMAL = 0
DISTAL = 1


class NeuronTemplate(object): # move to ../cells.py

def __init__(self, morphology, cm, Ra, **ion_channel_parameters):
self.traces = {}
self.recording_time = False
self.spike_source = None
self.spike_times = h.Vector(0)

# create morphology
self.sections = {}
unresolved_connections = []
for segment in morphology.segments:
section = nrn.Section()
section.L = segment.length
section(PROXIMAL).diam = segment.proximal.diameter
section(DISTAL).diam = segment.distal.diameter
section.nseg = 1
section.cm = cm
section.Ra = Ra
if segment.parent:
connection_point = DISTAL # should generalize
if segment.parent.name in self.sections:
section.connect(self.sections[segment.parent.name], DISTAL, PROXIMAL)
else:
unresolved_connections.append((segment.name, segment.parent.name))
self.sections[segment.name] = section
for section_name, parent_name in unresolved_connections:
self.sections[section_name].connect(self.sections[parent_name], DISTAL, PROXIMAL)

# insert ion channels
for name, ion_channel in self.ion_channels.items():
parameters = ion_channel_parameters[name]
sections = ion_channel["sections"]
mechanism_name = ion_channel["mechanism"].model
for section_name in sections:
section = self.sections[section_name]
section.insert(mechanism_name)
#mechanism = getattr(section(0.5), mechanism_name)
for param_name, value in parameters.items():
setattr(section, param_name, value)
print(name, section_name, mechanism_name, param_name, value)
# temporary hack - we're not using the leak conductance from the hh mechanism,
# so set the conductance to zero
if mechanism_name == "hh":
setattr(section, "gl_hh", 0.0)

# set source section
if self.spike_source:
self.source_section = self.sections[self.spike_source]
elif "axon_initial_segment" in self.sections:
self.source_section = self.sections["axon_initial_segment"]
elif "soma" in self.sections:
self.source_section = self.sections["soma"]
else:
raise Exception("Source section for action potential not defined")
self.source = self.source_section(0.5)._ref_v
self.rec = h.NetCon(self.source, None, sec=self.source_section)

def memb_init(self):
for state_var in ('v',):
initial_value = getattr(self, '{0}_init'.format(state_var))
assert initial_value is not None
if state_var == 'v':
for section in self.sections.values():
for seg in section:
seg.v = initial_value
else:
raise NotImplementedError()
9 changes: 5 additions & 4 deletions pyNN/neuron/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _record(self, variable, new_ids, sampling_interval=None):
self._record_state_variable(id._cell, variable)

def _record_state_variable(self, cell, variable):
if variable.section is None:
if variable.location is None:
if hasattr(cell, 'recordable') and variable in cell.recordable:
hoc_var = cell.recordable[variable]
elif variable.name == 'v':
Expand All @@ -44,9 +44,10 @@ def _record_state_variable(self, cell, variable):
hoc_var = getattr(source, "_ref_%s" % var_name)
else:
if variable.name == 'v':
hoc_var = cell.source_section(0.5)._ref_v # or use "seg.v"?
source = cell.sections[variable.location](0.5)
hoc_var = source._ref_v
else:
source = cell.sections[variable.section](0.5)
source = cell.sections[variable.location](0.5)
ion_channel, var_name = variable.name.split(".")
mechanism_name, hoc_var_name = self.population.celltype.ion_channels[ion_channel]["mechanism"].variable_translations[var_name]
mechanism = getattr(source, mechanism_name)
Expand Down Expand Up @@ -121,7 +122,7 @@ def _get_all_signals(self, variable, ids, clear=False):
def _local_count(self, variable, filter_ids=None):
N = {}
if variable == 'spikes':
for id in self.filter_recorded(variable, filter_ids):
for id in self.filter_recorded(recording.Variable('spikes', None), filter_ids):
N[int(id)] = id._cell.spike_times.size()
else:
raise Exception("Only implemented for spikes")
Expand Down

0 comments on commit 675a577

Please sign in to comment.