Skip to content

Commit

Permalink
Improve extract_input.from_diameter.model
Browse files Browse the repository at this point in the history
Change-Id: I0408d1fa87bfd092467e690cafbe52d979117d21
  • Loading branch information
adrien-berchet committed Jun 7, 2021
1 parent f98671b commit e324779
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 19 deletions.
54 changes: 54 additions & 0 deletions tests/test_extract_input.py
Expand Up @@ -103,6 +103,60 @@ def test_diameter_extract():
assert_raises(TNSError, extract_input.from_diameter.model,
load_neurons(os.path.join(_PATH, 'simple.swc')))

# Test on Population
res = extract_input.from_diameter.model(POPUL)
assert_equal(set(res.keys()), {'axon', 'basal', 'apical'})
expected = {
'basal': {
'Rall_ratio': 1.5,
'siblings_ratio': 1.0,
'taper': [
0.003361, 0.009487, 0.009931, 0.016477,
0.023878, 0.024852, 0.027809, 0.027975
],
'term': [0.3] * 8,
'trunk': [0.6 , 0.6 , 0.72, 0.84, 1.2 , 1.5 , 1.8 , 2.4],
'trunk_taper': [
0, 3.036411e-02, 3.053287e-02, 5.059035e-02,
1.168936e-01, 1.172027e-01, 0.15, 2.121002e-01
]
},
'apical': {
'Rall_ratio': 1.5,
'siblings_ratio': 1.0,
'taper': [
0.010331, 0.02135 , 0.02264 , 0.033914,
0.035313, 0.041116, 0.055751, 0.056211
],
'term': [0.3] * 8,
'trunk': [1.57, 7.51],
'trunk_taper': [0.05324615, 0.65223652]
},
'axon': {
'Rall_ratio': 1.5,
'siblings_ratio': 1.0,
'taper': [
0.04079 , 0.055286, 0.092382, 0.099524,
0.11986 , 0.140346, 0.214172, 0.407058
],
'term': [0.12] * 8,
'trunk': [2.1, 3.0],
'trunk_taper': [0.0435508, 0.0717109]
}
}

for neurite_type in ['basal', 'apical', 'axon']:
for key in expected[neurite_type].keys():
try:
assert_equal(res[neurite_type].keys(), expected[neurite_type].keys())
if key in ['taper', 'term', 'trunk', 'trunk_taper']:
tested = sorted(res[neurite_type][key])[:8]
else:
tested = res[neurite_type][key]
assert_array_almost_equal(tested, expected[neurite_type][key])
except AssertionError:
raise AssertionError(f"Failed for res[{neurite_type}][{key}]")


def test_distributions():
filename = os.path.join(_PATH, 'bio/')
Expand Down
35 changes: 17 additions & 18 deletions tns/extract_input/from_diameter.py
@@ -1,12 +1,10 @@
''' Module to extract morphometrics about diameters of cells.'''

from collections import defaultdict
from itertools import chain

import numpy as np
from neurom import get
from neurom.core.neuron import Section, iter_neurites
from neurom.core.population import Population
from neurom.core.types import tree_type_checker as is_type
from neurom.morphmath import segment_length, segment_radius

from tns.morphio_utils import NEUROM_TYPE_TO_STR
Expand Down Expand Up @@ -57,29 +55,30 @@ def model(input_object):
"""
values = {}

if isinstance(input_object, Population):
# If the population only contains the filenames instead of the actual morphologies,
# we have to ensure they are all loaded before working on them.
input_object = Population(list(input_object.neurons))
tapers = defaultdict(list)
trunk_tapers = defaultdict(list)
term_diams = defaultdict(list)
trunk_diams = defaultdict(list)

for neurite in iter_neurites(input_object):
tapers[neurite.type].append(section_taper(neurite))
trunk_tapers[neurite.type].append(section_trunk_taper(neurite))
term_diams[neurite.type].append(terminal_diam(neurite))
trunk_diams[neurite.type].append(2. * np.max(get('segment_radii', neurite)))

for neurite_type in set(tree.type for tree in input_object.neurites):
for neurite_type in tapers:
key = NEUROM_TYPE_TO_STR[neurite_type]

neurites = list(iter_neurites(input_object, filt=is_type(neurite_type)))
taper_c = np.array(list(chain(*tapers[neurite_type])))
trunk_taper = np.array(trunk_tapers[neurite_type])

taper = [section_taper(tree) for tree in neurites]
trunk_taper = np.array([section_trunk_taper(tree) for tree in neurites])
taper_c = np.array(list(chain(*taper)))
# Keep only positive, non-zero taper rates
taper_c = taper_c[np.where(taper_c > 0.00001)[0]]
trunk_taper = trunk_taper[np.where(trunk_taper >= 0.0)[0]]
term_diam = [terminal_diam(tree) for tree in neurites]
trunk_diam = [2. * np.max(get('segment_radii', tree)) for tree in neurites]

key = NEUROM_TYPE_TO_STR[neurite_type]

values[key] = {"taper": taper_c,
"term": list(chain(*term_diam)),
"trunk": trunk_diam,
"term": list(chain(*term_diams[neurite_type])),
"trunk": trunk_diams[neurite_type],
"trunk_taper": trunk_taper}

_check(values[key])
Expand Down
2 changes: 1 addition & 1 deletion tns/version.py
@@ -1,2 +1,2 @@
""" tns version """
VERSION = "2.4.4"
VERSION = "2.4.5.dev0"

0 comments on commit e324779

Please sign in to comment.