Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotations, doc-string re-formatting, code cleanup. #100

Merged
merged 7 commits into from
Jul 25, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,7 @@ results/

# Swap files.
*.swp

# PyCharm project folder.
.idea/

14 changes: 4 additions & 10 deletions bindsnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
from .utils import *
from .models import *
from .network import *
from .analysis import *
from .datasets import *
from .encoding import *
from .pipeline import *
from .learning import *
from .evaluation import *
from .environment import *
__all__ = [
'utils', 'network', 'models', 'analysis', 'datasets', 'encoding', 'pipeline', 'learning', 'evaluation',
'environment'
]
5 changes: 3 additions & 2 deletions bindsnet/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .plotting import *
from .visualization import *
__all__ = [
'plotting', 'visualization'
]
64 changes: 32 additions & 32 deletions bindsnet/analysis/plotting.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt

from matplotlib.image import AxesImage
from typing import Tuple, List, Optional
from mpl_toolkits.axes_grid1 import make_axes_locatable

from ..utils import reshape_locally_connected_weights


plt.ion()

def plot_input(image, inpt, label=None, axes=None, ims=None, figsize=(8, 4)):
'''
Plots a two-dimensional image and its corresponding spike-train representation.

Inputs:

| :code:`image` (:code:`torch.Tensor`): A 2D array of floats depicting an input image.
| :code:`inpt` (:code:`torch.Tensor`): A 2D array of floats depicting an image's spike-train encoding.
| :code:`ims` (:code:`list(matplotlib.image.AxesImage)`): Used for re-drawing the input plots.
| :code:`figsize` (:code:`tuple(int)`): Horizontal, vertical figure size in inches.

Returns:
def plot_input(image: torch.Tensor, inpt: torch.Tensor, label: Optional[int] = None, axes: List['AxesSubplot'] = None,
ims: List[AxesImage] = None,
figsize: Tuple[int, int]=(8, 4)) -> Tuple[List['AxesSubplot'], List[AxesImage]]:
# language=rst
"""
Plots a two-dimensional image and its corresponding spike-train representation.

| (:code:`axes` (:code:`list(matplotlib.image.AxesImage)`): Used for re-drawing the input plots.
| (:code:`ims` (:code:`list(matplotlib.axes.Axes)): Used for re-drawing the input plots.
'''
:param image: A 2D array of floats depicting an input image.
:param inpt: A 2D array of floats depicting an image's spike-train encoding.
:param label: Class label of the input data.
:param axes: Used for re-drawing the input plots.
:param ims: Used for re-drawing the input plots.
:param figsize: Horizontal, vertical figure size in inches.
:return: Tuple of ``(axes, ims)`` used for re-drawing the input plots.
"""
if axes is None:
fig, axes = plt.subplots(1, 2, figsize=figsize)
ims = axes[0].imshow(image, cmap='binary'), axes[1].imshow(inpt, cmap='binary')
Expand All @@ -52,7 +52,7 @@ def plot_input(image, inpt, label=None, axes=None, ims=None, figsize=(8, 4)):


def plot_spikes(network=None, spikes=None, layer_to_monitor={}, layers=[], time={}, n_neurons={}, ims=None, axes=None, figsize=(8, 4.5)):
'''
"""
Plot spikes for any group(s) of neurons. Default behavior will plot everything.

Inputs:
Expand All @@ -76,7 +76,7 @@ def plot_spikes(network=None, spikes=None, layer_to_monitor={}, layers=[], time=
| (:code:`ims` (:code:`list(matplotlib.axes.Axes)): Used for re-drawing the spike plots.
| (:code:`axes` (:code:`list(matplotlib.image.AxesImage)`): Used for re-drawing the spike plots.

'''
"""

assert network is not None or spikes is not None, 'No plotting information'

Expand Down Expand Up @@ -239,7 +239,7 @@ def plot_spikes(network=None, spikes=None, layer_to_monitor={}, layers=[], time=


def plot_weights(weights, wmin=0.0, wmax=1.0, im=None, figsize=(5, 5)):
'''
"""
Plot a connection weight matrix.

Inputs:
Expand All @@ -253,7 +253,7 @@ def plot_weights(weights, wmin=0.0, wmax=1.0, im=None, figsize=(5, 5)):
Returns:

| (:code:`im` (:code:`matplotlib.image.AxesImage`): Used for re-drawing the weights plot.
'''
"""
if not im:
fig, ax = plt.subplots(figsize=figsize)
ax.set_title('Connection weights')
Expand All @@ -274,7 +274,7 @@ def plot_weights(weights, wmin=0.0, wmax=1.0, im=None, figsize=(5, 5)):


def plot_conv2d_weights(weights, wmin=0.0, wmax=1.0, im=None, figsize=(5, 5)):
'''
"""
Plot a connection weight matrix of a Conv2dConnection.

Inputs:
Expand All @@ -288,7 +288,7 @@ def plot_conv2d_weights(weights, wmin=0.0, wmax=1.0, im=None, figsize=(5, 5)):
Returns:

| (:code:`im` (:code:`matplotlib.image.AxesImage`): Used for re-drawing the weights plot.
'''
"""
n_sqrt = int(np.ceil(np.sqrt(weights.size(0))))
height = weights.size(2)
width = weights.size(3)
Expand Down Expand Up @@ -328,7 +328,7 @@ def plot_conv2d_weights(weights, wmin=0.0, wmax=1.0, im=None, figsize=(5, 5)):

def plot_locally_connected_weights(weights, n_filters, kernel_size, conv_size, locations,
input_sqrt, wmin=0.0, wmax=1.0, im=None, figsize=(5, 5)):
'''
"""
Plot a connection weight matrix of a :code:`Connection` with
`locally connected structure <http://yann.lecun.com/exdb/publis/pdf/gregor-nips-11.pdf>_.

Expand All @@ -349,7 +349,7 @@ def plot_locally_connected_weights(weights, n_filters, kernel_size, conv_size, l
Returns:

| (:code:`im` (:code:`matplotlib.image.AxesImage`): Used for re-drawing the weights plot.
'''
"""
reshaped = reshape_locally_connected_weights(weights, n_filters, kernel_size,
conv_size, locations, input_sqrt)

Expand Down Expand Up @@ -381,7 +381,7 @@ def plot_locally_connected_weights(weights, n_filters, kernel_size, conv_size, l


def plot_assignments(assignments, im=None, figsize=(5, 5), classes=None):
'''
"""
Plot the two-dimensional neuron assignments.

Inputs:
Expand All @@ -398,7 +398,7 @@ def plot_assignments(assignments, im=None, figsize=(5, 5), classes=None):

| (:code:`im` (:code:`matplotlib.image.AxesImage`):
Used for re-drawing the assigments plot.
'''
"""
if not im:
fig, ax = plt.subplots(figsize=figsize)
ax.set_title('Categorical assignments')
Expand All @@ -423,7 +423,7 @@ def plot_assignments(assignments, im=None, figsize=(5, 5), classes=None):


def plot_performance(performances, ax=None, figsize=(7, 4)):
'''
"""
Plot training accuracy curves.

Inputs:
Expand All @@ -439,7 +439,7 @@ def plot_performance(performances, ax=None, figsize=(7, 4)):

| (:code:`ax` (:code:`matplotlib.axes.Axes`):
Used for re-drawing the performance plot.
'''
"""
if not ax:
_, ax = plt.subplots(figsize=figsize)
else:
Expand All @@ -458,7 +458,7 @@ def plot_performance(performances, ax=None, figsize=(7, 4)):


def plot_general(monitor=None, ims=None, axes=None, labels=None, parameters=None, figsize=(8,4.5)):
'''
"""
General plotting function for variables being monitored.

Inputs:
Expand All @@ -482,7 +482,7 @@ def plot_general(monitor=None, ims=None, axes=None, labels=None, parameters=None
Used for re-drawing plots.
| (:code:`axes` (:code:`list(matplotlib.image.AxesImage)`):
Used for re-drawing plots.
'''
"""
default = {'xlabel' : 'Simulation time', 'ylabel' : 'Index'}

if monitor is None:
Expand Down Expand Up @@ -567,7 +567,7 @@ def plot_general(monitor=None, ims=None, axes=None, labels=None, parameters=None


def plot_voltages(voltages, ims=None, axes=None, time=None, n_neurons={}, figsize=(8, 4.5)):
'''
"""
Plot voltages for any group(s) of neurons.

Inputs:
Expand All @@ -584,7 +584,7 @@ def plot_voltages(voltages, ims=None, axes=None, time=None, n_neurons={}, figsiz
| (:code:`ims` (:code:`list(matplotlib.axes.Axes)): Used for re-drawing the voltage plots.
| (:code:`axes` (:code:`list(matplotlib.image.AxesImage)`): Used for re-drawing the voltage plots.

'''
"""
n_subplots = len(voltages.keys())

# Confirm only 2 values for time were given
Expand Down
96 changes: 43 additions & 53 deletions bindsnet/analysis/visualization.py
Original file line number Diff line number Diff line change
@@ -1,103 +1,93 @@
import sys
import torch
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.animation as animation

from mpl_toolkits.axes_grid1 import make_axes_locatable
from typing import List, Tuple, Optional


def plot_weights_movie(ws, sample_every=1):
def plot_weights_movie(ws: np.ndarray, sample_every: int=1) -> None:
# language=rst
"""
Create and plot movie of weights (:code:`ws`).

Inputs:
Create and plot movie of weights.

| :code:`ws` (:code:`numpy.array`): Numpy array
of shape :code:`[n_examples, source, target, time]`
| :code:`sample_every` (:code:`int`): Sub-sample using this parameter.
:param ws: Array of shape ``[n_examples, source, target, time]``
:param sample_every: Sub-sample using this parameter.
"""
weights = []

# Obtain samples from the weights for every example
# Obtain samples from the weights for every example.
for i in range(ws.shape[0]):
sub_sampled_weight = ws[i, :, :, range(0, ws[i].shape[2], sample_every)]
weights.append(sub_sampled_weight)
else:
weights = np.concatenate(weights, axis=0)

# Initialize plot
# Initialize plot.
fig = plt.figure()
im = plt.imshow(weights[0, :, :], cmap='hot_r', animated=True, vmin=0, vmax=1)
plt.axis('off'); plt.colorbar(im)

# Update function for the animation
# Update function for the animation.
def update(j):
im.set_data(weights[j, :, :])
return [im]

# Initialize animatino
# Initialize animation.
global ani; ani=0
ani = animation.FuncAnimation(fig, update, frames=weights.shape[-1], interval=1000, blit=True)
plt.show()

def plot_spike_trains_for_example(spikes, n_ex=None, top_k=None, indices=None):
'''


def plot_spike_trains_for_example(spikes: torch.Tensor, n_ex: Optional[int]=None, top_k: Optional[int]=None,
indices: Optional[List[int]]=None) -> None:
# language=rst
"""
Plot spike trains for top-k neurons or for specific indices.

Inputs:

| :code:`spikes` (:code:`torch.Tensor (n_examples, n_neurons, time)`):
Spiking train data for a population of neurons for one example.
| :code:`n_ex` (:code:`int`): Allows user to pick
which example to plot spikes for. Must be >= 0.
| :code:`top_k` (:code:`int`): Plot k neurons that spiked the most for n_ex example.
| :code:`indices` (:code:`list(int)`): Plot specific neurons'
spiking activity instead of top_k. Meant to replace top_k.
'''

assert (n_ex is not None and n_ex >= 0 and n_ex < spikes.shape[0])
:param spikes: Spikes for one simulation run of shape ``(n_examples, n_neurons, time)``.
:param n_ex: Allows user to pick which example to plot spikes for.
:param top_k: Plot k neurons that spiked the most for n_ex example.
:param indices: Plot specific neurons' spiking activity instead of top_k.
"""
assert n_ex is not None and 0 <= n_ex < spikes.shape[0]

plt.figure()

if top_k is None and indices is None: # Plot all neurons' spiking activity
spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, :, :]]
if top_k is None and indices is None: # Plot all neurons' spiking activity
spike_per_neuron = [np.argwhere(i == 1).flatten() for i in spikes[n_ex, :, :]]
plt.title('Spiking activity for all %d neurons'%spikes.shape[1])

elif top_k is None: # Plot based on indices parameter
assert (indices is not None)
spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, indices, :]]
spike_per_neuron = [np.argwhere(i == 1).flatten() for i in spikes[n_ex, indices, :]]

elif indices is None: # Plot based on top_k parameter
assert (top_k is not None)
# Obtain the top k neurons that fired the most
top_k_loc = np.argsort(np.sum(spikes[n_ex,:,:], axis=1), axis=0)[::-1]
spike_per_neuron = [np.argwhere(i==1).flatten() for i in spikes[n_ex, top_k_loc[0:top_k], :]]
plt.title('Spiking activity for top %d neurons'%top_k)
top_k_loc = np.argsort(np.sum(spikes[n_ex, :, :], axis=1), axis=0)[::-1]
spike_per_neuron = [np.argwhere(i == 1).flatten() for i in spikes[n_ex, top_k_loc[0:top_k], :]]
plt.title('Spiking activity for top %d neurons' % top_k)

else:
raise ValueError('One of "top_k" or "indices" or both must be None')

plt.eventplot(spike_per_neuron, linelengths= [0.5]*len(spike_per_neuron))
plt.xlabel('Simulation Time'); plt.ylabel('Neuron index')
plt.show()

def plot_voltage(voltage, n_ex=0, n_neuron=0, time=None, threshold=None):
'''
def plot_voltage(voltage: torch.Tensor, n_ex: int=0, n_neuron: int=0, time: Optional[Tuple[int, int]]=None,
threshold: float=None) -> None:
# language=rst
"""
Plot voltage for a single neuron on a specific example.

Inputs:

| :code:`voltage` (:code:`torch.Tensor` or :code:`numpy.array`):
Tensor or array of shape :code:`[n_examples, n_neurons, time]`.
| :code:`n_ex` (:code:`int`): Allows user
to pick which example to plot voltage for.
| :code:`n_neuron` (:code:`int`): Neuron
index for which to plot voltages for.
| :code:`time` (:code:`tuple(int)`): Plot spiking
activity of neurons between the given range of time.
| :code:`threshold` (:code:`float`): Neuron
spiking threshold. Will be shown on the plot.
'''


:param voltage: Tensor or array of shape ``[n_examples, n_neurons, time]``.
:param n_ex: Allows user to pick which example to plot voltage for.
:param n_neuron: Neuron index for which to plot voltages for.
:param time: Plot spiking activity of neurons between the given range of time.
:param threshold: Neuron spiking threshold.
"""
assert (n_ex >= 0 and n_neuron >= 0)
assert (n_ex < voltage.shape[0] and n_neuron < voltage.shape[1])

Expand Down