Skip to content

Commit

Permalink
Merge pull request #773 from apdavison/enhance-translation
Browse files Browse the repository at this point in the history
ok, I thought this was a standalone PR, but some Brian2 tests will fail until the matching changes are made. Merging anyway, essentially for documentation purposes.
  • Loading branch information
apdavison committed Feb 3, 2023
2 parents 97ba231 + cb5a48e commit ddfd58f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 15 deletions.
34 changes: 22 additions & 12 deletions pyNN/standardmodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@
# ==============================================================================


def build_scaling_functions(pynn_name, sim_name, scale_factor):
def f(**p):
return p[pynn_name] * scale_factor
def g(**p):
return p[sim_name] / scale_factor
return f, g


def build_translations(*translation_list):
"""
Build a translation dictionary from a list of translations/transformations.
Expand All @@ -49,16 +57,19 @@ def build_translations(*translation_list):
if len(item) == 2: # no transformation
f = pynn_name
g = sim_name
type_ = "simple"
elif len(item) == 3: # simple multiplicative factor
scale_factor = item[2]
f = "float(%g)*%s" % (scale_factor, pynn_name)
g = "%s/float(%g)" % (sim_name, scale_factor)
f, g = build_scaling_functions(pynn_name, sim_name, scale_factor)
type_ = "scaled"
elif len(item) == 4: # more complex transformation
f = item[2]
g = item[3]
type_ = "computed"
translations[pynn_name] = {'translated_name': sim_name,
'forward_transform': f,
'reverse_transform': g}
'reverse_transform': g,
'type': type_}
return translations


Expand Down Expand Up @@ -133,21 +144,19 @@ def simple_parameters(self):
"""Return a list of parameters for which there is a one-to-one
correspondance between standard and native parameter values."""
return [name for name in self.translations
if self.translations[name]['forward_transform'] == name]
if self.translations[name]['type'] == "simple"]

def scaled_parameters(self):
"""Return a list of parameters for which there is a unit change between
standard and native parameter values."""
def scaling(trans):
return (not callable(trans)) and ("float" in trans)
return [name for name in self.translations
if scaling(self.translations[name]['forward_transform'])]
if self.translations[name]['type'] == "scaled"]

def computed_parameters(self):
"""Return a list of parameters whose values must be computed from
more than one other parameter."""
return [name for name in self.translations
if name not in self.simple_parameters() + self.scaled_parameters()]
if self.translations[name]['type'] == "computed"]

def computed_parameters_include(self, parameter_names):
return any(name in self.computed_parameters() for name in parameter_names)
Expand Down Expand Up @@ -203,12 +212,13 @@ def __getattr__(self, name):
"e.g. source.amplitude = 0.5, or use 'set_parameters()' " \
"e.g. source.set_parameters(amplitude=0.5)"
raise AttributeError(err_msg)

try:
val = self.__getattribute__(name)
except AttributeError:
val = self.get_parameters()[name]
except KeyError:
try:
val = self.get_parameters()[name]
except KeyError:
val = self.__getattribute__(name)
except AttributeError:
raise errors.NonExistentParameterError(name,
self.__class__.__name__,
self.get_parameter_names())
Expand Down
8 changes: 5 additions & 3 deletions test/unittests/test_standardmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ def test_build_translations():
('c', 'C', 'c + a', 'C - A')
)
assert set(t.keys()) == set(['a', 'b', 'c'])
assert set(t['a'].keys()) == set(['translated_name', 'forward_transform', 'reverse_transform'])
assert set(t['a'].keys()) == set(['translated_name', 'forward_transform', 'reverse_transform', 'type'])
assert t['a']['translated_name'] == 'A'
assert t['a']['forward_transform'] == 'a'
assert t['a']['reverse_transform'] == 'A'
assert t['b']['translated_name'] == 'B'
assert t['b']['forward_transform'] == 'float(1000)*b'
assert t['b']['reverse_transform'] == 'B/float(1000)'
assert callable(t['b']['forward_transform'])
assert t['b']['forward_transform'](b=7) == 7000
assert callable(t['b']['reverse_transform'])
assert t['b']['reverse_transform'](B=7000) == 7
assert t['c']['translated_name'] == 'C'
assert t['c']['forward_transform'] == 'c + a'
assert t['c']['reverse_transform'] == 'C - A'
Expand Down

0 comments on commit ddfd58f

Please sign in to comment.