Skip to content

Commit

Permalink
fixed gathering of somavs if number of cells < MPI::Size (#323)
Browse files Browse the repository at this point in the history
  • Loading branch information
espenhgn committed Mar 22, 2021
1 parent aa0a3ba commit 9eb9fba
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 20 deletions.
40 changes: 30 additions & 10 deletions examples/example_network/example_network.py
Expand Up @@ -315,15 +315,25 @@ def draw_lineplot(
if RANK == 0:
somavs = []
for i, name in enumerate(population_names):
somavs_pop = None # avoid undeclared variable
for j, cell in enumerate(network.populations[name].cells):
if j == 0:
somavs_pop = cell.somav
else:
somavs_pop = np.vstack((somavs_pop, cell.somav))
if RANK == 0:
for j in range(1, SIZE):
somavs_pop = np.vstack((somavs_pop,
COMM.recv(source=j, tag=15)))
recv = COMM.recv(source=j, tag=15)
if somavs_pop is None:
if recv is not None:
somavs_pop = recv
else:
continue
else:
if recv is not None:
somavs_pop = np.vstack((somavs_pop, recv))
if somavs_pop.ndim == 1:
somavs_pop = somavs_pop.reshape((1, -1))
somavs.append(somavs_pop)
else:
COMM.send(somavs_pop, dest=0, tag=15)
Expand All @@ -342,7 +352,7 @@ def draw_lineplot(
for spt, gid in zip(spts, gids):
t = np.r_[t, spt]
g = np.r_[g, np.zeros(spt.size) + gid]
ax.plot(t[t >= 200], g[t >= 200], '.', label=name)
ax.plot(t[t >= 200], g[t >= 200], '.', ms=3, label=name)
ax.legend(loc=1)
remove_axis_junk(ax, lines=['right', 'top'])
ax.set_xlabel('t (ms)')
Expand All @@ -354,11 +364,16 @@ def draw_lineplot(

# somatic potentials
fig = plt.figure()
gs = GridSpec(5, 1)
ax = fig.add_subplot(gs[:4])
gs = GridSpec(4, 1)
ax = fig.add_subplot(gs[:2])
if somavs[0].shape[0] > 10:
somavs_pop = ss.decimate(somavs[0][:10], q=16, axis=-1,
zero_phase=True)
else:
somavs_pop = ss.decimate(somavs[0], q=16, axis=-1,
zero_phase=True)
draw_lineplot(ax,
ss.decimate(somavs[0][::4], q=16, axis=-1,
zero_phase=True),
somavs_pop,
dt=network.dt * 16,
T=(200, 1200),
scaling_factor=1.,
Expand All @@ -376,10 +391,15 @@ def draw_lineplot(
ax.set_title('somatic potentials')
ax.set_xlabel('')

ax = fig.add_subplot(gs[4])
ax = fig.add_subplot(gs[2:])
if somavs[1].shape[0] > 10:
somavs_pop = ss.decimate(somavs[1][:10], q=16, axis=-1,
zero_phase=True)
else:
somavs_pop = ss.decimate(somavs[1], q=16, axis=-1,
zero_phase=True)
draw_lineplot(ax,
ss.decimate(somavs[1][::4], q=16, axis=-1,
zero_phase=True),
somavs_pop,
dt=network.dt * 16,
T=(200, 1200),
scaling_factor=1.,
Expand Down
40 changes: 30 additions & 10 deletions examples/example_network/example_network_to_file.py
Expand Up @@ -319,15 +319,25 @@ def draw_lineplot(
if RANK == 0:
somavs = []
for i, name in enumerate(population_names):
somavs_pop = None # avoid undeclared variable
for j, cell in enumerate(network.populations[name].cells):
if j == 0:
somavs_pop = cell.somav
else:
somavs_pop = np.vstack((somavs_pop, cell.somav))
if RANK == 0:
for j in range(1, SIZE):
somavs_pop = np.vstack((somavs_pop,
COMM.recv(source=j, tag=15)))
recv = COMM.recv(source=j, tag=15)
if somavs_pop is None:
if recv is not None:
somavs_pop = recv
else:
continue
else:
if recv is not None:
somavs_pop = np.vstack((somavs_pop, recv))
if somavs_pop.ndim == 1:
somavs_pop = somavs_pop.reshape((1, -1))
somavs.append(somavs_pop)
else:
COMM.send(somavs_pop, dest=0, tag=15)
Expand All @@ -346,7 +356,7 @@ def draw_lineplot(
for spt, gid in zip(spts, gids):
t = np.r_[t, spt]
g = np.r_[g, np.zeros(spt.size) + gid]
ax.plot(t[t >= 200], g[t >= 200], '.', label=name)
ax.plot(t[t >= 200], g[t >= 200], '.', ms=3, label=name)
ax.legend(loc=1)
remove_axis_junk(ax, lines=['right', 'top'])
ax.set_xlabel('t (ms)')
Expand All @@ -358,11 +368,16 @@ def draw_lineplot(

# somatic potentials
fig = plt.figure()
gs = GridSpec(5, 1)
ax = fig.add_subplot(gs[:4])
gs = GridSpec(4, 1)
ax = fig.add_subplot(gs[:2])
if somavs[0].shape[0] > 10:
somavs_pop = ss.decimate(somavs[0][:10], q=16, axis=-1,
zero_phase=True)
else:
somavs_pop = ss.decimate(somavs[0], q=16, axis=-1,
zero_phase=True)
draw_lineplot(ax,
ss.decimate(somavs[0][::4], q=16, axis=-1,
zero_phase=True),
somavs_pop,
dt=network.dt * 16,
T=(200, 1200),
scaling_factor=1.,
Expand All @@ -380,10 +395,15 @@ def draw_lineplot(
ax.set_title('somatic potentials')
ax.set_xlabel('')

ax = fig.add_subplot(gs[4])
ax = fig.add_subplot(gs[2:])
if somavs[1].shape[0] > 10:
somavs_pop = ss.decimate(somavs[1][:10], q=16, axis=-1,
zero_phase=True)
else:
somavs_pop = ss.decimate(somavs[1], q=16, axis=-1,
zero_phase=True)
draw_lineplot(ax,
ss.decimate(somavs[1][::4], q=16, axis=-1,
zero_phase=True),
somavs_pop,
dt=network.dt * 16,
T=(200, 1200),
scaling_factor=1.,
Expand Down

0 comments on commit 9eb9fba

Please sign in to comment.