Skip to content

Commit

Permalink
Plots (#39)
Browse files Browse the repository at this point in the history
* Add plotting module 
* Add 'plot' extra in setup.py
* Add frame and spike report filter classes
* Add the property_dtypes function to nodes and edges
* Fix the `node_id` query (now works like a "or" for the ids.)
* Added new data for tests with different attributes for node populations
  • Loading branch information
tomdele committed May 15, 2020
1 parent a8f75e4 commit 0a04a19
Show file tree
Hide file tree
Showing 20 changed files with 717 additions and 35 deletions.
373 changes: 373 additions & 0 deletions bluepysnap/_plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,373 @@
# Copyright (c) 2020, EPFL/Blue Brain Project

# This file is part of BlueBrain SNAP library <https://github.com/BlueBrain/snap>

# This library is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License version 3.0 as published
# by the Free Software Foundation.

# This library is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.

# You should have received a copy of the GNU Lesser General Public License
# along with this library; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
"""Plotting module for the different snap objects."""
import logging
import numpy as np
import pandas as pd

from bluepysnap.exceptions import BluepySnapError
from bluepysnap.sonata_constants import Node
from bluepysnap.utils import roundrobin

# Module-level logger, named after this module.
L = logging.getLogger(__name__)


def _get_pyplot():
try:
import matplotlib.pyplot as plt
except ImportError as e:
msg = (
"Bluepysnap requirements are not installed.\n"
"Please pip install as follows:\n"
" pip install bluepysnap[plots] --upgrade"
)
raise ImportError(str(e) + "\n\n" + msg)
return plt


def spikes_firing_rate_histogram(filtered_report, time_binsize=None, ax=None):  # pragma: no cover
    """Spike firing rate histogram.

    This plot shows the number of nodes firing during a range of time.

    Args:
        filtered_report: filtered spike report. Its ``report`` table is read (spike times as
            index, 'ids' and 'population' columns) as well as its ``spike_report``.
        time_binsize(None/int/float): bin size (milliseconds). If None, a binning heuristic is
            used to create an histogram with ~100 spikes per bin in average.
        ax(matplotlib.Axis): matplotlib Axis to draw on (if not specified, pyplot.gca() is used).

    Returns:
        matplotlib.Axis: Axis containing firing rate histogram.

    Raises:
        BluepySnapError: if ``time_binsize`` is not strictly positive or if the report
            contains no spike.

    Notes:
        If no axis is provided through the ax=ax keyword argument,
        then a default layout is set using pyplot.gca().
    """
    # pylint: disable=too-many-locals
    plt = _get_pyplot()
    if time_binsize is not None and time_binsize <= 0:
        raise BluepySnapError("Invalid time_binsize = {}. Should be > 0.".format(time_binsize))

    spike_report = filtered_report.spike_report
    times = filtered_report.report.index

    if len(times) == 0:
        raise BluepySnapError("No data to display. You should check your "
                              "'group' query: {}.".format(spike_report.group))

    # number of distinct (node id, population) pairs that fired at least once
    node_count = filtered_report.report[['ids', 'population']].drop_duplicates().shape[0]

    time_start = np.min(times)
    time_stop = np.max(times)

    if time_binsize is None:
        # heuristic for a nice bin size (~100 spikes per bin on average)
        time_binsize = min(50.0, (time_stop - time_start) / ((len(times) / 100.) + 1.))
        if time_binsize <= 0:
            # All spikes share the same timestamp, so the heuristic degenerates to 0 and
            # cannot be used as a step. Fall back to an (arbitrary) 1 ms bin.
            time_binsize = 1.0

    if time_stop > time_start:
        bins = np.append(np.arange(time_start, time_stop, time_binsize), time_stop)
    else:
        # Zero time span: np.arange would return no edge at all, so build the single bin
        # holding every spike explicitly.
        bins = np.array([time_start, time_start + time_binsize])

    hist, bin_edges = np.histogram(times, bins=bins)
    # spikes per bin -> spikes per node and per second (bin size is in ms)
    freq = 1.0 * hist / node_count / (0.001 * time_binsize)

    if ax is None:
        ax = plt.gca()
        ax.set_xlabel('Time [ms]')
        ax.set_ylabel('PSTH [Hz]')

    # use the middle of the bins instead of the start of the bin
    ax.plot(0.5 * (bin_edges[1:] + bin_edges[:-1]), freq, label="PSTH", drawstyle='steps-mid')
    return ax


def spike_raster(filtered_report, y_axis="node_id", ax=None):  # pragma: no cover
    """Spike raster plot.

    Shows a global overview of the circuit's firing nodes. The y axis can project either the
    node_ids or any properties present in the different node populations.

    Args:
        filtered_report: filtered spike report. Its ``report`` table is read (spike times as
            index, 'ids' and 'population' columns) as well as its ``spike_report``.
        y_axis (None/str): The property to display on the y axis. None and "node_id" both
            display the node ids, the populations being stacked one above the other.
        ax(matplotlib.Axis): matplotlib Axis to draw on (if not specified, pyplot.gca() is used).

    Returns:
        matplotlib.Axis: Axis containing Spikes raster plot.

    Notes:
        If no axis is provided through the ax=ax keyword argument,
        then a default layout is set using pyplot.gca().
        Populations for which the property cannot be retrieved are skipped.
    """
    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
    plt = _get_pyplot()

    spike_report = filtered_report.spike_report
    population_names = spike_report.population_names
    report = filtered_report.report

    # The documented contract is "None is node_ids" while the signature default is "node_id":
    # accept both spellings.
    use_node_ids = y_axis is None or y_axis == "node_id"

    node_id_offset = 0
    pop_separators = []
    categorical_values = set()
    ymin = np.inf
    ymax = -np.inf

    if use_node_ids:
        # Node ids are not looked up in property_dtypes (NOTE(review): "node_id" is presumably
        # not a node property). A float series can hold the NaN placeholders.
        dtype = float
    else:
        # NOTE(review): the dtype is taken from the first population only.
        dtype = spike_report[population_names[0]].nodes.property_dtypes[y_axis]
        if pd.api.types.is_categorical_dtype(dtype):
            # this is to prevent the problems when concatenating categoricals with unknown
            # categories
            dtype = str

    data = pd.Series(index=report.index, dtype=dtype)
    for population in population_names:
        nodes = spike_report[population].nodes
        mask = report["population"] == population
        if use_node_ids:
            # shift the ids so that populations do not overlap on the y axis
            data.loc[mask] = report.loc[mask, "ids"].to_numpy() + node_id_offset
            node_id_offset += nodes.size
            pop_separators.append(node_id_offset)
            continue

        try:
            ys = nodes.get(properties=y_axis)
        except BluepySnapError:
            # the property cannot be retrieved for this population: nothing to plot for it
            continue
        ids = report.loc[mask, "ids"].to_numpy()
        # astype is used to avoid problems with the categorical
        data[mask] = ys[ids].astype(dtype).to_numpy()
        if pd.api.types.is_categorical_dtype(nodes.property_dtypes[y_axis]):
            categorical_values.update(nodes.property_values(y_axis))
        else:
            # reuse the values already loaded instead of querying the nodes again
            ymin = min(ymin, ys.min())
            ymax = max(ymax, ys.max())

    data = data[data.notna()]
    if ax is None:
        ax = plt.gca()
        ax.xaxis.grid()
        ax.set_xlabel("Time [ms]")
        ax.tick_params(axis='y', which='both', length=0)
        ax.set_xlim(spike_report.time_start, spike_report.time_stop)
        if use_node_ids:
            ax.set_ylim(0, node_id_offset)
            ax.set_ylabel("nodes")
        else:
            if data.empty:
                # no spike to display: keep matplotlib's default y axis
                pass
            elif np.issubdtype(type(data.iloc[0]), np.number):
                # automatically expanded by plt if ymin == ymax
                ax.set_ylim(ymin, ymax)
            else:
                labels = sorted(categorical_values)
                ax.set_yticks(np.arange(len(labels)))
                ax.set_yticklabels(labels)
                if len(labels) > 1:
                    ax.set_ylim(-0.5, len(labels) - 0.5)
            ax.set_ylabel("{}".format(y_axis))

    ax.scatter(data.index.to_numpy(), data.to_numpy(), s=10, marker='|')
    # separators are only meaningful when several populations are stacked
    if len(pop_separators) > 1:
        for separator in pop_separators:
            ax.axhline(y=separator, color='black', lw=2)
    return ax


def spikes_isi(filtered_report, use_frequency=False, binsize=None, ax=None):  # pragma: no cover
    # pylint: disable=too-many-locals
    """Interspike interval histogram.

    This plot shows the binned time/frequency interval between two consecutive spikes of the
    same node.

    Args:
        filtered_report: filtered spike report. Its ``report`` table is read (spike times as
            index, 'ids' and 'population' columns) as well as its ``spike_report``.
        use_frequency(bool): use inverse interspike interval times (Hz)
        binsize(None/int/float): bin size in milliseconds or Hz. If None is used the binning is
            delegated to matplotlib and is done automatically.
        ax(matplotlib.Axis): matplotlib Axis to draw on (if not specified, pyplot.gca() is used).

    Returns:
        matplotlib.Axis: axis containing the interspike interval histogram.

    Raises:
        BluepySnapError: if ``binsize`` is not strictly positive or if there is no interval
            to display.

    Notes:
        If no axis is provided through the ax=ax keyword argument,
        then a default layout is set using pyplot.gca().
    """
    plt = _get_pyplot()
    if binsize is not None and binsize <= 0:
        raise BluepySnapError("Invalid binsize = {}. Should be > 0.".format(binsize))

    gb = filtered_report.report.groupby(["ids", "population"])
    # one array of consecutive spike-time differences per (node id, population)
    intervals = [np.diff(node_spikes.index.to_numpy()) for _, node_spikes in gb]
    # np.concatenate refuses an empty list: an empty report must reach the check below
    values = np.concatenate(intervals) if intervals else np.array([])

    if use_frequency:
        values = values[values > 0]  # filter out zero intervals
        values = 1000.0 / values

    # checked after the frequency filtering, which can also remove every value
    if len(values) == 0:
        raise BluepySnapError("No data to display. You should check your "
                              "'group' query: {}.".format(filtered_report.spike_report.group))

    if binsize is None:
        bins = 'auto'
    else:
        # Enough bins for the last edge to reach the largest value (at least one bin),
        # otherwise the largest values would be left out of the histogram.
        bin_count = max(1, int(np.ceil(np.max(values) / binsize)))
        bins = np.arange(bin_count + 1) * binsize

    if ax is None:
        ax = plt.gca()
        if use_frequency:
            ax.set_xlabel('Frequency [Hz]')
        else:
            ax.set_xlabel('Interspike interval [ms]')
        ax.set_ylabel('Bin weight')

    ax.hist(values, bins=bins, edgecolor='black', density=True)
    return ax


def spikes_firing_animation(filtered_report, x_axis=Node.X, y_axis=Node.Y,
                            dt=20, ax=None):  # pragma: no cover
    # pylint: disable=too-many-locals,too-many-arguments
    """Simple animation of simulation spikes.

    Each frame of the animation represents the spiking nodes during a period of dt ms
    in a coordinate system corresponding to the x, y or z axis of the circuit.

    Args:
        filtered_report: filtered spike report. Its ``report`` table is read (spike times as
            index, 'ids' and 'population' columns) as well as its ``spike_report``.
        x_axis (str): Node enum that will determine the animation x_axis
        y_axis (str): Node enum that will determine the animation y_axis
        dt (int) : the time bin size of each frame in the video in ms
        ax(matplotlib.Axis): matplotlib Axis to draw on (if not specified, pyplot.gca()
            and plt.figure() are used).

    Returns :
        (matplotlib.animation.FuncAnimation, matplotlib.Axis): the matplotlib animation object and
        the corresponding axis.

    Raises:
        BluepySnapError: if an axis is not one of Node.X, Node.Y, Node.Z, if ``dt`` is not
            strictly positive or if there is no spike with a known position to display.

    Notes:
        From scripts:
        >>> import matplotlib.pyplot as plt
        >>> from bluepysnap import Simulation
        >>> report = Simulation("config.json").spikes["my_population"]
        >>> anim, ax = report.firing_animation()
        >>> plt.show()
        >>> # to save the animation : do not plt.show() and just anim.save('my_movie.mp4')
        From notebooks:
        >>> from IPython.display import HTML
        >>> from bluepysnap import Simulation
        >>> report = Simulation("config.json").spikes["my_population"]
        >>> anim, ax = report.firing_animation()
        >>> HTML(anim.to_html5_video())
    """
    plt = _get_pyplot()
    from matplotlib.animation import FuncAnimation

    def _check_axis(axis):
        """Verifies axes values."""
        axes = {Node.X, Node.Y, Node.Z}
        if axis not in axes:
            raise BluepySnapError('{} is not a valid axis'.format(axis))

    _check_axis(x_axis)
    _check_axis(y_axis)
    if dt <= 0:
        # dt is used as a divisor when computing the frame numbers
        raise BluepySnapError("Invalid dt = {}. Should be > 0.".format(dt))

    spike_report = filtered_report.spike_report
    population_names = spike_report.population_names
    report = filtered_report.report

    data = pd.DataFrame(index=report.index, columns=[x_axis, y_axis], dtype=np.float32)

    for population in population_names:
        spikes = spike_report[population]
        pop_mask = report["population"] == population

        ids = report.loc[pop_mask, "ids"].to_numpy()
        try:
            values = spikes.nodes.get(properties=[x_axis, y_axis]).loc[ids].to_numpy()
            data.loc[pop_mask, [x_axis, y_axis]] = values
        except BluepySnapError:
            # positions cannot be retrieved for this population: its rows stay NaN
            continue

    # Boolean-frame indexing (data[data.notna()]) only masks cells and keeps every row;
    # dropna really removes the spikes without a position.
    data = data.dropna()

    if data.empty:
        raise BluepySnapError("No data to display. You should check your "
                              "'group' query: {}.".format(spike_report.group))

    if ax is None:
        fig = plt.figure()
        ax = plt.gca()
        ax.set_title('time = {}ms'.format(np.min(data.index)))
        x_limits = [data[x_axis].min(), data[x_axis].max()]
        y_limits = [data[y_axis].min(), data[y_axis].max()]
        ax.set_xlim(*x_limits)
        ax.set_ylim(*y_limits)
        # raw strings: "\m" is not a valid escape sequence in a regular string literal
        ax.set_xlabel(r'{} $\mu$m'.format(x_axis))
        ax.set_ylabel(r'{} $\mu$m'.format(y_axis))
    else:
        fig = ax.figure

    dots = ax.plot([], [], '.k')

    def update_animation(frame):
        """Update the animation plots and axes."""
        ax.set_title('time = ' + str(frame * dt) + ' ms')
        # half-open window [frame * dt, (frame + 1) * dt): a spike belongs to a single frame
        mask = (data.index >= frame * dt) & (data.index < (frame + 1) * dt)
        positions = data.loc[mask, [x_axis, y_axis]].values
        x = positions[:, 0]
        y = positions[:, 1]
        dots[0].set_data(x, y)
        return dots

    # min/max rather than first/last so that the index does not need to be sorted; the last
    # frame is included, otherwise the final spikes would never be displayed.
    first_frame = int(data.index.min() / dt)
    last_frame = int(data.index.max() / dt)
    frames = list(range(first_frame, last_frame + 1))
    anim = FuncAnimation(fig, update_animation, frames=frames)
    return anim, ax


def frame_trace(filtered_report, plot_type='mean', ax=None):  # pragma: no cover
    """Returns a plot displaying the voltage of a node or a compartment as a function of time.

    Args:
        filtered_report: filtered frame report. Its ``report`` table, its ``frame_report``
            (for the data and time units) and its ``t_start``/``t_stop`` are read.
        plot_type (str): string either `all` or `mean`. `all` will plot the first 15 traces from
            the group. `mean` will plot the mean value of the node
        ax: A plot axis object that will be updated (if not specified, pyplot.gca() is used
            and the labels and x limits are set).

    Returns:
        matplotlib.Axis: axis containing the soma's traces.

    Raises:
        BluepySnapError: if ``plot_type`` is neither 'mean' nor 'all'.
    """
    # pylint: disable=too-many-locals

    plt = _get_pyplot()

    # Validate up front so that an invalid plot_type does not leave a half-configured axis.
    if plot_type not in ("mean", "all"):
        raise BluepySnapError(
            "Unknown plot_type {}. Should be 'mean' or 'all'.".format(plot_type))

    if ax is None:
        ax = plt.gca()
        data_units = filtered_report.frame_report.data_units
        if plot_type == "mean":
            ax.set_ylabel('Avg volt. [{}]'.format(data_units))
        else:
            ax.set_ylabel('Voltage [{}]'.format(data_units))
        ax.set_xlabel("Time [{}]".format(filtered_report.frame_report.time_units))
        ax.set_xlim([filtered_report.t_start, filtered_report.t_stop])

    if plot_type == "mean":
        ax.plot(filtered_report.report.T.mean())
    else:
        max_per_pop = 15
        levels = filtered_report.report.columns.levels
        # full slice on every column level but the last one, which is truncated
        slicer = tuple(slice(None) if i != len(levels) - 1 else slice(None, max_per_pop)
                       for i in range(len(levels)))
        data = filtered_report.report.loc[:, slicer].T
        # create [[(pop1, id1), (pop1, id2),...], [(pop2, id1), (pop2, id2),...]]
        indexes = [[(pop, idx) for idx in data.loc[pop].index] for pop in levels[0]]
        # try to keep the maximum of ids from each population
        kept_ids = list(roundrobin(*indexes))[:max_per_pop]
        for _, row in data.loc[kept_ids].iterrows():
            ax.plot(row)
    return ax
9 changes: 9 additions & 0 deletions bluepysnap/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,15 @@ def property_names(self):
"""Set of available edge properties."""
return self._property_names | self._dynamics_params_names

@cached_property
def property_dtypes(self):
    """Dtype of each available edge property.

    The dtypes are read from the table returned by ``self.properties`` when queried with
    ``[0]`` (presumably the first edge id) and all the names from ``self.property_names``.

    Returns:
        pandas.Series: series indexed by field name with the corresponding dtype as value.
    """
    names = list(self.property_names)
    sample = self.properties([0], names)
    dtypes = sample.dtypes
    return dtypes.sort_index()

def container_property_names(self, container):
"""Lists the ConstContainer properties shared with the EdgePopulation.
Expand Down

0 comments on commit 0a04a19

Please sign in to comment.