Skip to content

Commit

Permalink
Adapt validation to the SONATA original repo examples (#81)
Browse files Browse the repository at this point in the history
* guarantee string in attrs

* more reliable condition to start multiple group check

* consider `node_types_file` when validating nodes
  • Loading branch information
asanin-epfl committed Jul 21, 2020
1 parent e029b6f commit 639ac58
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 41 deletions.
95 changes: 59 additions & 36 deletions bluepysnap/circuit_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import itertools as it

import numpy as np
import pandas as pd
import click
from pathlib2 import Path
import h5py
Expand Down Expand Up @@ -206,6 +207,38 @@ def _get_population_groups(population_h5):
if isinstance(population_h5[name], h5py.Group) and name.isdigit()]


def _nodes_group_to_dataframe(group, types_file, population):
    """Transform an HDF5 nodes population group into a pandas DataFrame.

    Args:
        group: HDF5 nodes group (a digit-named child of the population,
            e.g. ``/nodes/<population>/0``)
        types_file: path to the 'node_types_file' of a Sonata config, or None
        population: HDF5 nodes population that contains ``group``

    Returns:
        pd.DataFrame: one row per node of this group, with a column per group
        attribute; merged with the node types CSV when ``types_file`` is given.
    """
    df = pd.DataFrame(population['node_type_id'], columns=['type_id'])
    # Number of nodes in the whole population. Use len(df), not df.size:
    # df.size counts cells (rows * columns) and only equals the row count
    # here because the frame happens to have exactly one column.
    size = len(df)
    df['id'] = population['node_id'] if 'node_id' in population else np.arange(size)
    df['group_id'] = population['node_group_id'] if 'node_group_id' in population else 0
    df['group_index'] = population['node_group_index'] \
        if 'node_group_index' in population else np.arange(size)
    # keep only the rows that belong to this group; group names are digits
    df = df[df['group_id'] == int(str(_get_group_name(group)))]
    for k, v in group.items():
        # '@library' is a sub-Group, not a Dataset, so it is skipped here
        if isinstance(v, h5py.Dataset):
            # NOTE(review): positional assignment — assumes this group's rows
            # appear in group_index order and match the dataset length; confirm
            df[k] = v[:]
    if '@library' in group:
        # '@library' datasets are enumerations: the column currently holds
        # integer indices into the library values; resolve them in place
        for k, v in group['@library'].items():
            if isinstance(v, h5py.Dataset):
                df[k] = v[:][df[k].to_numpy(dtype=int)]
    if types_file is None:
        return df
    # first column of a node types CSV is the node type id, whatever its header
    types = pd.read_csv(types_file, sep=r'\s+')
    types.rename(columns={types.columns[0]: 'type_id'}, inplace=True)
    return pd.merge(df, types, on='type_id', how='left')


def _get_group_size(group_h5):
"""Gets size of an edges or nodes group."""
for name in group_h5:
Expand Down Expand Up @@ -236,10 +269,11 @@ def _check_multi_groups(group_id_h5, group_index_h5, population):
return []


def _check_bio_nodes_group(group, config):
def _check_bio_nodes_group(group_df, group, config):
"""Checks biophysical nodes group for errors.
Args:
group_df (pd.DataFrame): nodes group as a dataframe
group (h5py.Group): nodes group in nodes .h5 file
config (dict): resolved bluepysnap config
Expand All @@ -250,20 +284,21 @@ def _check_bio_nodes_group(group, config):
def _check_rotations():
"""Checks for proper rotation fields."""
angle_fields = {'rotation_angle_xaxis', 'rotation_angle_yaxis', 'rotation_angle_zaxis'}
has_angle_fields = len(angle_fields - set(group)) < len(angle_fields)
has_rotation_fields = 'orientation' in group or has_angle_fields
has_angle_fields = len(angle_fields - group_attrs) < len(angle_fields)
has_rotation_fields = 'orientation' in group_attrs or has_angle_fields
if not has_rotation_fields:
errors.append(Error(Error.WARNING, 'Group {} of {} has no rotation fields'.
format(group_name, group.file.filename)))
if not has_angle_fields:
bbp_orient_fields = {'orientation_w', 'orientation_x', 'orientation_y', 'orientation_z'}
if 0 < len(bbp_orient_fields - set(group)) < len(bbp_orient_fields):
if 0 < len(bbp_orient_fields - group_attrs) < len(bbp_orient_fields):
errors.append(BbpError(Error.WARNING, 'Group {} of {} has no rotation fields'.
format(group_name, group.file.filename)))

errors = []
group_attrs = set(group_df.columns)
group_name = _get_group_name(group, parents=1)
missing_fields = sorted({'morphology', 'x', 'y', 'z'} - set(group))
missing_fields = sorted({'morphology', 'x', 'y', 'z'} - group_attrs)
if missing_fields:
errors.append(fatal('Group {} of {} misses biophysical fields: {}'.
format(group_name, group.file.filename, missing_fields)))
Expand All @@ -273,52 +308,37 @@ def _check_rotations():
errors += _check_components_dir('biophysical_neuron_models_dir', components)
if errors:
return errors
morph_files = group['morphology'] if _get_h5_data(group, '@library/morphology') is None \
else group['@library/morphology']
errors += _check_files(
'morphology: {}[{}]'.format(group_name, group.file.filename),
(Path(components['morphologies_dir'], m + '.swc') for m in morph_files),
(Path(components['morphologies_dir'], m + '.swc') for m in group_df['morphology']),
Error.WARNING)
bio_files = group['model_template'] if _get_h5_data(group, '@library/model_template') is None \
else group['@library/model_template']
bio_path = Path(components['biophysical_neuron_models_dir'])
errors += _check_files(
'model_template: {}[{}]'.format(group_name, group.file.filename),
(bio_path / _get_model_template_file(m) for m in bio_files),
(bio_path / _get_model_template_file(m) for m in group_df['model_template']),
Error.WARNING)
return errors


def _is_biophysical(group):
"""Check if a group contains biophysical nodes."""
if group['model_type'][0] == 'biophysical':
return True
if "@library/model_type" in group:
model_type_int = group['model_type'][0]
model_type = group["@library/model_type"][model_type_int]
if six.ensure_str(model_type) == 'biophysical':
return True
return False


def _check_nodes_group(group, config):
def _check_nodes_group(group_df, group, config):
"""Validates nodes group in nodes population.
Args:
group_df (pd.DataFrame): nodes group in nodes .h5 file
group (h5py.Group): nodes group in nodes .h5 file
config (dict): resolved bluepysnap config
Returns:
list: List of errors, empty if no errors
"""
REQUIRED_GROUP_NAMES = ['model_type', 'model_template']
missing_fields = sorted(set(REQUIRED_GROUP_NAMES) - set(group))
missing_fields = sorted(set(REQUIRED_GROUP_NAMES) - set(group_df.columns.tolist()))
if missing_fields:
return [fatal('Group {} of {} misses required fields: {}'
.format(_get_group_name(group, parents=1), group.file.filename,
missing_fields))]
elif _is_biophysical(group):
return _check_bio_nodes_group(group, config)
elif group_df['model_type'][0] == 'biophysical':
return _check_bio_nodes_group(group_df, group, config)
return []


Expand All @@ -335,25 +355,28 @@ def _check_nodes_population(nodes_dict, config):
required_datasets = ['node_type_id']
errors = []
nodes_file = nodes_dict.get('nodes_file')
node_types_file = nodes_dict.get('node_types_file', None)
with h5py.File(nodes_file, 'r') as h5f:
nodes = _get_h5_data(h5f, 'nodes')
if not nodes or len(nodes) == 0:
errors.append(fatal('No "nodes" in {}.'.format(nodes_file)))
return errors
return [fatal('No "nodes" in {}.'.format(nodes_file))]
for population_name in nodes:
population = nodes[population_name]
groups = _get_population_groups(population)
if len(groups) > 1:
required_datasets += ['node_group_id', 'node_group_index']
missing_datasets = sorted(set(required_datasets) - set(population))
if missing_datasets:
errors.append(fatal('Population {} of {} misses datasets {}'.
format(population_name, nodes_file, missing_datasets)))
elif 'node_group_id' in population:
errors += _check_multi_groups(
population['node_group_id'], population['node_group_index'], population)
return [fatal('Population {} of {} misses datasets {}'.
format(population_name, nodes_file, missing_datasets))]
if len(groups) > 1:
m_errors = _check_multi_groups(population['node_group_id'],
population['node_group_index'], population)
if len(m_errors) > 0:
return m_errors
for group in groups:
errors += _check_nodes_group(group, config)
group_df = _nodes_group_to_dataframe(group, node_types_file, population)
errors += _check_nodes_group(group_df, group, config)
return errors


Expand Down Expand Up @@ -417,7 +440,7 @@ def _check_edges_node_ids(nodes_ds, nodes):
list: List of errors, empty if no errors
"""
errors = []
node_population_name = nodes_ds.attrs['node_population']
node_population_name = six.ensure_str(nodes_ds.attrs['node_population'])
nodes_dict = _find_nodes_population(node_population_name, nodes)
if not nodes_dict:
errors.append(fatal('No node population for "{}"'.format(nodes_ds.name)))
Expand Down
2 changes: 1 addition & 1 deletion tests/data/circuit_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"nodes": [
{
"nodes_file": "$NETWORK_DIR/nodes.h5",
"node_types_file": null
"node_types_file": "$NETWORK_DIR/node_types.csv"
}
],
"edges": [
Expand Down
2 changes: 2 additions & 0 deletions tests/data/node_types.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
node_type_id model_processing
1 perisomatic
Binary file modified tests/data/nodes.h5
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/test_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_all():
circuit.config['networks']['nodes'][0] ==
{
'nodes_file': str(TEST_DATA_DIR / 'nodes.h5'),
'node_types_file': None,
'node_types_file': str(TEST_DATA_DIR / 'node_types.csv'),
}
)
assert isinstance(circuit.nodes, dict)
Expand Down
22 changes: 19 additions & 3 deletions tests/test_circuit_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def test_ok_circuit():
errors = test_module.validate(str(TEST_DATA_DIR / 'circuit_config.json'))
assert errors == []

with copy_circuit() as (_, config_copy_path):
with edit_config(config_copy_path) as config:
config['networks']['nodes'][0]['node_types_file'] = None
errors = test_module.validate(str(config_copy_path))
assert errors == []


def test_no_config_components():
with copy_circuit() as (_, config_copy_path):
Expand Down Expand Up @@ -163,6 +169,17 @@ def test_no_required_node_multi_group_datasets():
format(nodes_file, [ds]))]


def test_nodes_multi_group_wrong_group_id():
with copy_circuit() as (circuit_copy_path, config_copy_path):
nodes_file = circuit_copy_path / 'nodes.h5'
with h5py.File(nodes_file, 'r+') as h5f:
h5f.copy('nodes/default/0', 'nodes/default/1')
h5f['nodes/default/node_group_id'][-1] = 2
errors = test_module.validate(str(config_copy_path))
assert errors == [Error(Error.FATAL, 'Population /nodes/default of {} misses group(s): {}'.
format(nodes_file, {2}))]


def test_no_required_node_group_datasets():
required_datasets = ['model_template', 'model_type']
with copy_circuit() as (circuit_copy_path, config_copy_path):
Expand Down Expand Up @@ -230,8 +247,9 @@ def test_no_rotation_bbp_node_group_datasets():
nodes_file = circuit_copy_path / 'nodes.h5'
with h5py.File(nodes_file, 'r+') as h5f:
for ds in angle_datasets:
shape = h5f['nodes/default/0/' + ds].shape
del h5f['nodes/default/0/' + ds]
h5f['nodes/default/0/orientation_w'] = 0
h5f['nodes/default/0/'].create_dataset('orientation_w', shape, fillvalue=0)
errors = test_module.validate(str(config_copy_path), bbp_check=True)
assert errors == [
Error(Error.WARNING, 'Group default/0 of {} has no rotation fields'.format(nodes_file)),
Expand Down Expand Up @@ -538,8 +556,6 @@ def test_no_edge_all_node_ids():
del h5f['nodes/default/0']
errors = test_module.validate(str(config_copy_path))
assert errors == [
Error(Error.FATAL, 'Population /nodes/default of {} misses group(s): {}'.
format(nodes_file, {0})),
Error(Error.FATAL,
'/edges/default/source_node_id does not have node ids in its node population'),
Error(Error.FATAL,
Expand Down

0 comments on commit 639ac58

Please sign in to comment.