Skip to content

Commit

Permalink
Improved wrapping (#38)
Browse files Browse the repository at this point in the history
* Added init_subclass that registers a wrapper in the PythonHocInterpreter

* Improved the PythonHocInterpreter's HocObject wrapping. closes #28
  • Loading branch information
Helveg committed Oct 11, 2020
1 parent 5280d3d commit e74ac06
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 20 deletions.
1 change: 1 addition & 0 deletions patch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def p(self):
global _p
if _p is None:
_p = PythonHocModule.PythonHocInterpreter()
PythonHocModule.PythonHocInterpreter._process_registration_queue()
return _p

def connection(self, source, target, strict=True):
Expand Down
65 changes: 45 additions & 20 deletions patch/interpreter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .objects import PythonHocObject, NetCon, PointProcess, VecStim, Section, IClamp, SectionRef
from .objects import PythonHocObject, NetCon, PointProcess, VecStim, Section, IClamp, SectionRef, _get_obj_registration_queue
from .core import (
transform,
transform_netcon,
Expand All @@ -8,44 +8,67 @@
)
from .exceptions import *
from .error_handler import catch_hoc_error, CatchNetCon, CatchSectionAccess, _suppress_nrn
from functools import wraps


class PythonHocInterpreter:
def __init__(self):
from neuron import h

self.__dict__["_PythonHocInterpreter__h"] = h
# Wrapping should occur around all calls to functions that share a name with
# child classes of the PythonHocObject like h.Section, h.NetStim, h.NetCon
self.__object_classes = PythonHocObject.__subclasses__().copy()
self.__requires_wrapping = [cls.__name__ for cls in self.__object_classes]
self.__loaded_extensions = []
self.load_file("stdrun.hoc")
self.runtime = 0

@classmethod
def _process_registration_queue(cls):
"""
Most PythonHocObject classes (all those provided by Patch for sure) are created
before the PythonHocInterpreter class is available. Yet they require the class to
combine the original pointer from ``h.<object>`` (e.g. ``h.Section``) with a
function that defers to their constructor so that you can call ``p.Section()``
and create a PythonHocObject wrapped around the underlying ``h`` pointer.
This function is called right after the PythonHocInterpreter class is created so
that PythonHocObjects can place themselves in a queue and have themselves
registered into the class right after it's ready.
"""
for hoc_object_class in _get_obj_registration_queue():
cls.register_hoc_object(hoc_object_class)

@classmethod
def register_hoc_object(interpreter_class, hoc_object_class):
# We shouldn't use multiple copies of h in case of monkey patches but since we
# need only native functions that return a hoc object this is fine.
from neuron import h

if hoc_object_class.__name__ in interpreter_class.__dict__:
# The function call was overridden in the interpreter and should not be destroyed.
return
hoc_object_name = hoc_object_class.__name__
# If the original interpreter doesn't have a function with the same name we can't
# simplify the constructor of the PythonHocObject and shouldn't wrap it.
if hasattr(h, hoc_object_name):
# Wrap it in the interpreter with a call to the underlying `h` to obtain a pointer
# and use that to make our PythonHocObject
factory = getattr(h, hoc_object_name)
@wraps(hoc_object_class.__init__)
def wrapper(interpreter_instance, *args, **kwargs):
hoc_ptr = factory(*args, **kwargs)
return hoc_object_class(interpreter_instance, hoc_ptr)

setattr(PythonHocInterpreter, hoc_object_class.__name__, wrapper)

def __getattr__(self, attr_name):
# Get the missing attribute from h, if it requires wrapping return a wrapped
# object instead.
attr = getattr(self.__h, attr_name)
if attr_name in self.__requires_wrapping:
return self.wrap(attr, attr_name)
else:
return attr
# Get the missing attribute from h
return getattr(self.__h, attr_name)

def __setattr__(self, attr, value):
if hasattr(self.__h, attr):
setattr(self.__h, attr, value)
else:
self.__dict__[attr] = value

def wrap(self, factory, name):
def wrapper(*args, **kwargs):
obj = factory(*args, **kwargs)
cls = next((c for c in self.__object_classes if c.__name__ == name), None)
return cls(self, obj)

return wrapper

def NetCon(self, source, target, *args, **kwargs):
nrn_source = transform_netcon(source)
nrn_target = transform_netcon(target)
Expand Down Expand Up @@ -337,3 +360,5 @@ def _broadcast(self, data, root=0):
raise BroadcastError(
"Root node did not transmit. Look for root node error."
) from None

PythonHocInterpreter._process_registration_queue()
17 changes: 17 additions & 0 deletions patch/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,20 @@
from .error_handler import catch_hoc_error, CatchRecord


_registration_queue = []


class PythonHocObject:
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
try:
from .interpreter import PythonHocInterpreter
except ImportError:
_registration_queue.append(cls)
return

PythonHocInterpreter.register_hoc_object(cls)

def __init__(self, interpreter, ptr):
# Initialize ourselves with a reference to our own "pointer"
# and prepare a list for other references.
Expand Down Expand Up @@ -309,3 +322,7 @@ def stimulate(self, pattern=None, weight=0.04, delay=0.0, **kwargs):
stimulus = self._interpreter.VecStim(pattern=pattern)
self._interpreter.NetCon(stimulus, self, weight=weight, delay=delay)
return stimulus


def _get_obj_registration_queue():
return _registration_queue
19 changes: 19 additions & 0 deletions tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,25 @@
from patch.exceptions import *


class TestPatchRegistration(_shared.NeuronTestCase):
"""
Check that the registration of PythonHocObjects works. (Will almost never be relevant
since most actual HocObjects will be covered by Patch and use the registration queue
rather than immediate registration; and any class names that don't correspond to an
actual ``h.<name>`` function don't create a wrapper)
"""

def test_registration(self):
from patch import p

# Create a new PythonHocObject, no wrapper will be added as it does not exist in h
class NewHocObject(patch.objects.PythonHocObject):
pass

# Nothing to test, but the import inside ``PythonHocObject.__init_subclass__``
# should complete and the call to ``PythonHocInterpreter.register_hoc_object``
# should be covered in test coverage results.

class TestPatch(_shared.NeuronTestCase):
"""
Check Patch basics like object wrapping and the standard interface.
Expand Down

0 comments on commit e74ac06

Please sign in to comment.