Skip to content

Commit

Permalink
Merge pull request #484 from BindsNET/hananel
Browse files Browse the repository at this point in the history
Encoding -> single
  • Loading branch information
Hananel-Hazan committed May 4, 2021
2 parents 555cab0 + 26596e3 commit ea593ad
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 103 deletions.
58 changes: 10 additions & 48 deletions bindsnet/analysis/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ def plot_spikes(
for i, datum in enumerate(spikes.items()):
spikes = (
datum[1][
time[0] : time[1],
n_neurons[datum[0]][0] : n_neurons[datum[0]][1],
time[0] : time[1], n_neurons[datum[0]][0] : n_neurons[datum[0]][1]
]
.detach()
.clone()
Expand Down Expand Up @@ -144,19 +143,13 @@ def plot_spikes(
for ax in axes:
ax.set_aspect("auto")

plt.setp(
axes,
xticks=[],
xlabel="Simulation time",
ylabel="Neuron index",
)
plt.setp(axes, xticks=[], xlabel="Simulation time", ylabel="Neuron index")
plt.tight_layout()
else:
for i, datum in enumerate(spikes.items()):
spikes = (
datum[1][
time[0] : time[1],
n_neurons[datum[0]][0] : n_neurons[datum[0]][1],
time[0] : time[1], n_neurons[datum[0]][0] : n_neurons[datum[0]][1]
]
.detach()
.clone()
Expand Down Expand Up @@ -424,10 +417,7 @@ def plot_assignments(
else:
color = plt.get_cmap("RdBu", len(classes) + 1)
im = ax.matshow(
locals_assignments,
cmap=color,
vmin=-1.5,
vmax=len(classes) - 0.5,
locals_assignments, cmap=color, vmin=-1.5, vmax=len(classes) - 0.5
)

div = make_axes_locatable(ax)
Expand Down Expand Up @@ -616,9 +606,7 @@ def plot_voltages(
):
ims.append(
axes.axhline(
y=thresholds[v[0]].item(),
c="r",
linestyle="--",
y=thresholds[v[0]].item(), c="r", linestyle="--"
)
)
else:
Expand All @@ -635,13 +623,7 @@ def plot_voltages(
)
)

args = (
v[0],
n_neurons[v[0]][0],
n_neurons[v[0]][1],
time[0],
time[1],
)
args = (v[0], n_neurons[v[0]][0], n_neurons[v[0]][1], time[0], time[1])
plt.title("%s voltages for neurons (%d - %d) from t = %d to %d " % args)
plt.xlabel("Time (ms)")

Expand Down Expand Up @@ -670,9 +652,7 @@ def plot_voltages(
):
ims.append(
axes[i].axhline(
y=thresholds[v[0]].item(),
c="r",
linestyle="--",
y=thresholds[v[0]].item(), c="r", linestyle="--"
)
)
else:
Expand All @@ -688,13 +668,7 @@ def plot_voltages(
cmap=cmap,
)
)
args = (
v[0],
n_neurons[v[0]][0],
n_neurons[v[0]][1],
time[0],
time[1],
)
args = (v[0], n_neurons[v[0]][0], n_neurons[v[0]][1], time[0], time[1])
axes[i].set_title(
"%s voltages for neurons (%d - %d) from t = %d to %d " % args
)
Expand Down Expand Up @@ -736,13 +710,7 @@ def plot_voltages(
.T,
cmap=cmap,
)
args = (
v[0],
n_neurons[v[0]][0],
n_neurons[v[0]][1],
time[0],
time[1],
)
args = (v[0], n_neurons[v[0]][0], n_neurons[v[0]][1], time[0], time[1])
axes.set_title(
"%s voltages for neurons (%d - %d) from t = %d to %d " % args
)
Expand Down Expand Up @@ -776,13 +744,7 @@ def plot_voltages(
.T,
cmap=cmap,
)
args = (
v[0],
n_neurons[v[0]][0],
n_neurons[v[0]][1],
time[0],
time[1],
)
args = (v[0], n_neurons[v[0]][0], n_neurons[v[0]][1], time[0], time[1])
axes[i].set_title(
"%s voltages for neurons (%d - %d) from t = %d to %d " % args
)
Expand Down
5 changes: 1 addition & 4 deletions bindsnet/datasets/alov300.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,7 @@ def get_sample(self, idx):
bbox_curr_gt = BoundingBox(currbb[0], currbb[1], currbb[2], currbb[3])
bbox_gt_recentered = BoundingBox(0, 0, 0, 0)
bbox_gt_recentered = bbox_curr_gt.recenter(
rand_search_location,
edge_spacing_x,
edge_spacing_y,
bbox_gt_recentered,
rand_search_location, edge_spacing_x, edge_spacing_y, bbox_gt_recentered
)
curr_sample["image"] = rand_search_region
curr_sample["bb"] = bbox_gt_recentered.get_bb_list()
Expand Down
26 changes: 8 additions & 18 deletions bindsnet/datasets/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,7 @@ def shift_crop_training_sample(sample, bb_params):
bbox_curr_gt = BoundingBox(currbb[0], currbb[1], currbb[2], currbb[3])
bbox_gt_recentered = BoundingBox(0, 0, 0, 0)
bbox_gt_recentered = bbox_curr_gt.recenter(
rand_search_location,
edge_spacing_x,
edge_spacing_y,
bbox_gt_recentered,
rand_search_location, edge_spacing_x, edge_spacing_y, bbox_gt_recentered
)
output_sample["image"] = rand_search_region
output_sample["bb"] = bbox_gt_recentered.get_bb_list()
Expand All @@ -155,12 +152,9 @@ def crop_sample(sample):
opts = {}
image, bb = sample["image"], sample["bb"]
orig_bbox = BoundingBox(bb[0], bb[1], bb[2], bb[3])
(
output_image,
pad_image_location,
edge_spacing_x,
edge_spacing_y,
) = cropPadImage(orig_bbox, image)
(output_image, pad_image_location, edge_spacing_x, edge_spacing_y) = cropPadImage(
orig_bbox, image
)
new_bbox = BoundingBox(0, 0, 0, 0)
new_bbox = new_bbox.recenter(
pad_image_location, edge_spacing_x, edge_spacing_y, new_bbox
Expand Down Expand Up @@ -198,8 +192,7 @@ def cropPadImage(bbox_tight, image):
output_height = max(math.ceil(bbox_tight.compute_output_height()), roi_height)
if image.ndim > 2:
output_image = np.zeros(
(int(output_height), int(output_width), image.shape[2]),
dtype=image.dtype,
(int(output_height), int(output_width), image.shape[2]), dtype=image.dtype
)
else:
output_image = np.zeros(
Expand Down Expand Up @@ -392,8 +385,7 @@ def shift(
):
if shift_motion_model:
width_scale_factor = max(
min_scale,
min(max_scale, sample_exp_two_sides(lambda_scale_frac)),
min_scale, min(max_scale, sample_exp_two_sides(lambda_scale_frac))
)
else:
rand_num = sample_rand_uniform()
Expand All @@ -410,8 +402,7 @@ def shift(
):
if shift_motion_model:
height_scale_factor = max(
min_scale,
min(max_scale, sample_exp_two_sides(lambda_scale_frac)),
min_scale, min(max_scale, sample_exp_two_sides(lambda_scale_frac))
)
else:
rand_num = sample_rand_uniform()
Expand Down Expand Up @@ -464,8 +455,7 @@ def shift(
new_y_temp = center_y + rand_num * (2 * new_height) - new_height

new_center_y = min(
image.shape[0] - new_height / 2,
max(new_height / 2, new_y_temp),
image.shape[0] - new_height / 2, max(new_height / 2, new_y_temp)
)
first_time_y = False
num_tries_y = num_tries_y + 1
Expand Down
9 changes: 4 additions & 5 deletions bindsnet/encoding/encodings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import Optional

import torch
import numpy as np


def single(
Expand All @@ -27,10 +26,10 @@ def single(
"""
time = int(time / dt)
shape = list(datum.shape)
datum = np.copy(datum)
quantile = np.quantile(datum, 1 - sparsity)
s = np.zeros([time, *shape], device=device)
s[0] = np.where(datum > quantile, np.ones(shape), np.zeros(shape))
datum = torch.tensor(datum)
quantile = torch.quantile(datum, 1 - sparsity)
s = torch.zeros([time, *shape], device=device)
s[0] = torch.where(datum > quantile, torch.ones(shape), torch.zeros(shape))
return torch.Tensor(s).byte()


Expand Down
4 changes: 1 addition & 3 deletions bindsnet/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,9 +405,7 @@ def __init__(
w = w / w.max()
w = (w * self.max_inhib) + self.start_inhib
recurrent_output_conn = Connection(
source=self.layers["Y"],
target=self.layers["Y"],
w=w,
source=self.layers["Y"], target=self.layers["Y"], w=w
)
self.add_connection(recurrent_output_conn, source="Y", target="Y")

Expand Down
3 changes: 1 addition & 2 deletions bindsnet/network/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,7 @@ def __init__(

self.w = Parameter(w, requires_grad=False)
self.b = Parameter(
kwargs.get("b", torch.zeros(self.out_channels)),
requires_grad=False,
kwargs.get("b", torch.zeros(self.out_channels)), requires_grad=False
)

def compute(self, s: torch.Tensor) -> torch.Tensor:
Expand Down
10 changes: 2 additions & 8 deletions examples/mnist/batch_eth_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@
from bindsnet import ROOT_DIR
from bindsnet.datasets import MNIST, DataLoader
from bindsnet.encoding import PoissonEncoder
from bindsnet.evaluation import (
all_activity,
proportion_weighting,
assign_labels,
)
from bindsnet.evaluation import all_activity, proportion_weighting, assign_labels
from bindsnet.models import DiehlAndCook2015
from bindsnet.network.monitors import Monitor
from bindsnet.utils import get_square_weights, get_square_assignments
Expand Down Expand Up @@ -201,9 +197,7 @@

# Get network predictions.
all_activity_pred = all_activity(
spikes=spike_record,
assignments=assignments,
n_labels=n_classes,
spikes=spike_record, assignments=assignments, n_labels=n_classes
)
proportion_pred = proportion_weighting(
spikes=spike_record,
Expand Down
18 changes: 3 additions & 15 deletions examples/mnist/supervised_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@
from bindsnet.models import DiehlAndCook2015
from bindsnet.network.monitors import Monitor
from bindsnet.utils import get_square_assignments, get_square_weights
from bindsnet.evaluation import (
all_activity,
proportion_weighting,
assign_labels,
)
from bindsnet.evaluation import all_activity, proportion_weighting, assign_labels
from bindsnet.analysis.plotting import (
plot_input,
plot_assignments,
Expand Down Expand Up @@ -183,11 +179,7 @@

print(
"\nAll activity accuracy: %.2f (last), %.2f (average), %.2f (best)"
% (
accuracy["all"][-1],
np.mean(accuracy["all"]),
np.max(accuracy["all"]),
)
% (accuracy["all"][-1], np.mean(accuracy["all"]), np.max(accuracy["all"]))
)
print(
"Proportion weighting accuracy: %.2f (last), %.2f (average), %.2f (best)\n"
Expand Down Expand Up @@ -233,11 +225,7 @@
voltages = {"Ae": exc_voltages, "Ai": inh_voltages}

inpt_axes, inpt_ims = plot_input(
image.sum(1).view(28, 28),
inpt,
label=label,
axes=inpt_axes,
ims=inpt_ims,
image.sum(1).view(28, 28), inpt, label=label, axes=inpt_axes, ims=inpt_ims
)
spike_ims, spike_axes = plot_spikes(
{layer: spikes[layer].get("s").view(time, 1, -1) for layer in spikes},
Expand Down

0 comments on commit ea593ad

Please sign in to comment.