Skip to content

Commit

Permalink
Bugfix and improvements for connect_* functions (#145)
Browse files Browse the repository at this point in the history
* fix connector calls (**kwargs)
* fix mutlithreading/mpi reload issue
* improve attribute generation support
* correct all_to_all edge type to int
  • Loading branch information
Silmathoron committed Dec 4, 2020
1 parent 5da1d8a commit 9b8f523
Show file tree
Hide file tree
Showing 10 changed files with 141 additions and 57 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ doc/.build/
doc/_build/
doc/examples/sp_graph
doc/examples/.ipynb_checkpoints/
doc/gallery/graph_properties/
doc/gallery/graph_structure/

# generated rst
doc/modules/examples
Expand Down
2 changes: 1 addition & 1 deletion nngt/generation/cconnect.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _all_to_all(cnp.ndarray[size_t, ndim=1] source_ids,
def _total_degree_list(cnp.ndarray[int64, ndim=1] source_ids,
cnp.ndarray[int64, ndim=1] target_ids,
cnp.ndarray[int64, ndim=1] degree_list,
bool directed=True, bool multigraph=False):
bool directed=True, bool multigraph=False, **kwargs):
''' Called from _from_degree_list '''
cdef:
size_t num_source = len(source_ids)
Expand Down
15 changes: 8 additions & 7 deletions nngt/generation/connect_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _all_to_all(source_ids, target_ids, directed=True, multigraph=False,
edges[current_edges:next_enum, 1] = target_ids
current_edges = next_enum
else:
edges = np.empty((num_edges, 2))
edges = np.empty((num_edges, 2), dtype=int)
edges[:, 0] = np.repeat(source_ids, num_targets)
edges[:, 1] = np.tile(target_ids, num_sources)

Expand All @@ -111,7 +111,7 @@ def _all_to_all(source_ids, target_ids, directed=True, multigraph=False,


def _total_degree_list(source_ids, target_ids, degree_list, directed=True,
multigraph=False):
multigraph=False, **kwargs):
''' Called from _from_degree_list '''
degree_list = np.array(degree_list, dtype=int)

Expand Down Expand Up @@ -532,7 +532,8 @@ def _erdos_renyi(source_ids, target_ids, density=None, edges=None,
return ia_edges


def _price_scale_free(ids, m, c, gamma, reciprocity, directed, multigraph):
def _price_scale_free(ids, m, c, gamma, reciprocity, directed, multigraph,
**kwargs):
'''
Generate a Price network.
'''
Expand Down Expand Up @@ -595,8 +596,8 @@ def _price_scale_free(ids, m, c, gamma, reciprocity, directed, multigraph):
return edges


def _circular(source_ids, target_ids, coord_nb, reciprocity, directed,
reciprocity_choice="random"):
def _circular(source_ids, target_ids, coord_nb, reciprocity=1, directed=True,
reciprocity_choice="random", **kwargs):
'''
Circular graph.
Expand Down Expand Up @@ -625,7 +626,7 @@ def _circular(source_ids, target_ids, coord_nb, reciprocity, directed,


def _circular_directed_recip(node_ids, coord_nb, reciprocity,
reciprocity_choice="random"):
reciprocity_choice="random", **kwargs):
''' Circular graph with given reciprocity '''
nodes = len(node_ids)
edges = int(0.5*nodes*coord_nb*(1 + reciprocity))
Expand Down Expand Up @@ -698,7 +699,7 @@ def _circular_directed_recip(node_ids, coord_nb, reciprocity,
return np.array([sources, targets], dtype=int).T


def _circular_full(node_ids, coord_nb, directed):
def _circular_full(node_ids, coord_nb, directed, **kwargs):
''' Create a circular graph with all possible edges '''
nodes = len(node_ids)

Expand Down
15 changes: 13 additions & 2 deletions nngt/generation/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from nngt.generation import graph_connectivity as gc
from nngt.lib import is_iterable, nonstring_container
from nngt.lib.test_functions import deprecated
from nngt.lib.rng_tools import _generate_random


__all__ = [
Expand Down Expand Up @@ -118,9 +119,19 @@ def connect_nodes(network, sources, targets, graph_model, density=None,
attr = {}

if 'weights' in kwargs:
attr['weight'] = kwargs['weights']
ww = kwargs['weights']

if isinstance(ww, dict):
attr['weight'] = _generate_random(len(elist), ww)
else:
attr['weight'] = ww
if 'delays' in kwargs:
attr['delay'] = kwargs['delays']
dd = kwargs['delays']

if isinstance(ww, dict):
attr['delay'] = _generate_random(len(elist), dd)
else:
attr['delay'] = dd
if network.is_spatial() and distance:
attr['distance'] = distance

Expand Down
68 changes: 34 additions & 34 deletions nngt/generation/graph_connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,21 @@
on_master_process)


__all__ = [
'all_to_all',
'circular',
'distance_rule',
'erdos_renyi',
'fixed_degree',
'from_degree_list',
'gaussian_degree',
'newman_watts',
'random_scale_free',
'price_scale_free',
'watts_strogatz',
]


# do default import

from .connect_algorithms import *
Expand Down Expand Up @@ -96,21 +111,6 @@
raise e


__all__ = [
'all_to_all',
'circular',
'distance_rule',
'erdos_renyi',
'fixed_degree',
'from_degree_list',
'gaussian_degree',
'newman_watts',
'random_scale_free',
'price_scale_free',
'watts_strogatz',
]


# ----------------------------- #
# Specific degree distributions #
# ----------------------------- #
Expand Down Expand Up @@ -149,7 +149,7 @@ def all_to_all(nodes=0, weighted=True, directed=True, multigraph=False,
Note
----
`nodes` is required unless `population` is provided.
`nodes` is required unless `population` is provided.
Returns
-------
Expand Down Expand Up @@ -269,7 +269,7 @@ def fixed_degree(degree, degree_type='in', nodes=0, reciprocity=-1.,
The type of the fixed degree, among ``'in'``, ``'out'`` or ``'total'``.
@todo
`'total'` not implemented yet.
`'total'` not implemented yet.
nodes : int, optional (default: None)
The number of nodes in the graph.
Expand Down Expand Up @@ -299,9 +299,9 @@ def fixed_degree(degree, degree_type='in', nodes=0, reciprocity=-1.,
Note
----
`nodes` is required unless `from_graph` or `population` is provided.
If an `from_graph` is provided, all preexistant edges in the
object will be deleted before the new connectivity is implemented.
`nodes` is required unless `from_graph` or `population` is provided.
If an `from_graph` is provided, all preexistant edges in the
object will be deleted before the new connectivity is implemented.
Returns
-------
Expand Down Expand Up @@ -352,7 +352,7 @@ def gaussian_degree(avg, std, degree_type='in', nodes=0, reciprocity=-1.,
avg : float
The value of the average degree.
std : float
The standard deviation of the Gaussian distribution.
The standard deviation of the Gaussian distribution.
degree_type : str, optional (default: 'in')
The type of the fixed degree, among 'in', 'out' or 'total' (or the
full version: 'in-degree'...)
Expand Down Expand Up @@ -390,9 +390,9 @@ def gaussian_degree(avg, std, degree_type='in', nodes=0, reciprocity=-1.,
Note
----
`nodes` is required unless `from_graph` or `population` is provided.
If an `from_graph` is provided, all preexistant edges in the object
will be deleted before the new connectivity is implemented.
`nodes` is required unless `from_graph` or `population` is provided.
If an `from_graph` is provided, all preexistant edges in the object
will be deleted before the new connectivity is implemented.
"""
# set node number and library graph
graph_gd = from_graph
Expand Down Expand Up @@ -477,9 +477,9 @@ def erdos_renyi(density=None, nodes=0, edges=None, avg_deg=None,
Note
----
`nodes` is required unless `from_graph` or `population` is provided.
If an `from_graph` is provided, all preexistant edges in the
object will be deleted before the new connectivity is implemented.
`nodes` is required unless `from_graph` or `population` is provided.
If an `from_graph` is provided, all preexistant edges in the
object will be deleted before the new connectivity is implemented.
"""
# set node number and library graph
graph_er = from_graph
Expand Down Expand Up @@ -563,11 +563,11 @@ def random_scale_free(in_exp, out_exp, nodes=0, density=None, edges=None,
Note
----
As reciprocity increases, requested values of `in_exp` and `out_exp`
will be less and less respected as the distribution will converge to a
common exponent :math:`\gamma = (\gamma_i + \gamma_o) / 2`.
Parameter `nodes` is required unless `from_graph` or `population` is
provided.
As reciprocity increases, requested values of `in_exp` and `out_exp`
will be less and less respected as the distribution will converge to a
common exponent :math:`\gamma = (\gamma_i + \gamma_o) / 2`.
Parameter `nodes` is required unless `from_graph` or `population` is
provided.
"""
# set node number and library graph
graph_rsf = from_graph
Expand Down Expand Up @@ -858,7 +858,7 @@ def newman_watts(coord_nb, proba_shortcut=None, reciprocity_circular=1.,
Note
----
`nodes` is required unless `from_graph` or `population` is provided.
`nodes` is required unless `from_graph` or `population` is provided.
"""
if multigraph:
raise ValueError("`multigraph` is not supported for Watts-Strogatz.")
Expand Down Expand Up @@ -954,7 +954,7 @@ def watts_strogatz(coord_nb, proba_shortcut=None, reciprocity_circular=1.,
Note
----
`nodes` is required unless `from_graph` or `population` is provided.
`nodes` is required unless `from_graph` or `population` is provided.
"""
if multigraph:
raise ValueError("`multigraph` is not supported for Newman-Watts.")
Expand Down
10 changes: 4 additions & 6 deletions nngt/lib/nngt_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,10 +355,9 @@ def _post_update_parallelism(new_config, old_gl, old_msd, old_mt, old_mpi):
new_multithreading = new_config.get("multithreading", old_mt)

if new_multithreading != old_mt:
# connectors.py and rewiring.py use directly
# nngt.generation.graph_connectivity so they always access the reloaded
# version and we don't need to reload them
reload(sys.modules["nngt"].generation.graph_connectivity)
reload(sys.modules["nngt"].generation.connectors)
reload(sys.modules["nngt"].generation.rewiring)

# if multithreading loading failed, set omp back to 1
if not nngt._config['multithreading']:
Expand Down Expand Up @@ -389,10 +388,9 @@ def _post_update_parallelism(new_config, old_gl, old_msd, old_mt, old_mpi):

# reload for mpi
if new_config.get('mpi', old_mpi) != old_mpi:
# connectors.py and rewiring.py use directly
# nngt.generation.graph_connectivity so they always access the reloaded
# version and we don't need to reload them
reload(sys.modules["nngt"].generation.graph_connectivity)
reload(sys.modules["nngt"].generation.connectors)
reload(sys.modules["nngt"].generation.rewiring)

# set graph-tool config
_set_gt_config(old_gl, new_config)
Expand Down
28 changes: 24 additions & 4 deletions nngt/lib/rng_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,33 @@ def seed(msd=None, seeds=None):
# ----------------------------- #

def _generate_random(number, instructions):
name = instructions[0]
if name in di_dfunc:
return di_dfunc[name](None, None, number, *instructions[1:])
else:
name = "not defined"

if isinstance(instructions, dict):
name = instructions["distribution"]

instructions = {
k: v for k, v in instructions.items() if k != "distribution"
}

if name in di_dfunc:
return di_dfunc[name](None, None, number, **instructions)

raise NotImplementedError(
"Unknown distribution: '{}'. Supported distributions " \
"are {}".format(name, ", ".join(di_dfunc.keys())))
elif nonstring_container(instructions):
name = instructions[0]

if name in di_dfunc:
return di_dfunc[name](None, None, number, *instructions[1:])

raise NotImplementedError(
"Unknown distribution: '{}'. Supported distributions " \
"are {}".format(name, ", ".join(di_dfunc.keys())))

raise NotImplementedError(
"Unknown instructions: '{}'".format(instructions))


def _eprop_distribution(graph, distrib_type, matrix=False, elist=None,
Expand Down
2 changes: 1 addition & 1 deletion nngt/plot/chord_diag
1 change: 0 additions & 1 deletion testing/test_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,6 @@ def test_attributes_are_copied():

assert np.all(np.isclose(vv, g.node_attributes["ntest"]))
assert not np.all(np.isclose(vv, ntest))



# ---------- #
Expand Down
Loading

0 comments on commit 9b8f523

Please sign in to comment.