Black formatting with default line length.
Dan Saunders committed Sep 1, 2020
1 parent dbad95c · commit 3ca676b
Showing 38 changed files with 304 additions and 869 deletions.
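For context, Black's default maximum line length is 88 characters, which is why so many of the hunks below collapse a short, manually wrapped signature or call back onto a single line. A minimal sketch of the same transformation through Black's Python API (illustrative only, not part of the commit; assumes a recent black release where format_str and Mode are public):

    import black

    # One of the signatures touched below, in its pre-commit, manually wrapped form.
    src = (
        "def plot_obs(\n"
        '    self, obs: torch.Tensor, tag: str = "obs", step: int = None\n'
        ") -> None:\n"
        "    pass\n"
    )

    # black.Mode() defaults to line_length=88; the whole signature fits within
    # that budget, so Black rejoins it onto one line, as in the diff below.
    print(black.format_str(src, mode=black.Mode()))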
16 changes: 4 additions & 12 deletions bindsnet/analysis/pipeline_analysis.py
@@ -28,9 +28,7 @@ def finalize_step(self) -> None:
pass

@abstractmethod
def plot_obs(
self, obs: torch.Tensor, tag: str = "obs", step: int = None
) -> None:
def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> None:
# language=rst
"""
Pulls the observation from PyTorch and sets up for Matplotlib
@@ -139,9 +137,7 @@ def __init__(self, **kwargs) -> None:
plt.ion()
self.plots = {}

def plot_obs(
self, obs: torch.Tensor, tag: str = "obs", step: int = None
) -> None:
def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> None:
# language=rst
"""
Pulls the observation off of torch and sets up for Matplotlib
@@ -260,9 +256,7 @@ def plot_voltages(
"""
if tag not in self.plots:
self.plots[tag] = plot_voltages(
voltage_record,
plot_type=self.volts_type,
thresholds=thresholds,
voltage_record, plot_type=self.volts_type, thresholds=thresholds,
)
else:
v_im, v_ax = self.plots[tag]
@@ -321,9 +315,7 @@ def finalize_step(self) -> None:
"""
pass

def plot_obs(
self, obs: torch.Tensor, tag: str = "obs", step: int = None
) -> None:
def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> None:
# language=rst
"""
Pulls the observation off of torch and sets up for Matplotlib
83 changes: 28 additions & 55 deletions bindsnet/analysis/plotting.py
@@ -115,8 +115,7 @@ def plot_spikes(
for i, datum in enumerate(spikes.items()):
spikes = (
datum[1][
time[0] : time[1],
n_neurons[datum[0]][0] : n_neurons[datum[0]][1],
time[0] : time[1], n_neurons[datum[0]][0] : n_neurons[datum[0]][1],
]
.detach()
.clone()
@@ -144,19 +143,14 @@ def plot_spikes(
ax.set_aspect("auto")

plt.setp(
axes,
xticks=[],
yticks=[],
xlabel="Simulation time",
ylabel="Neuron index",
axes, xticks=[], yticks=[], xlabel="Simulation time", ylabel="Neuron index",
)
plt.tight_layout()
else:
for i, datum in enumerate(spikes.items()):
spikes = (
datum[1][
time[0] : time[1],
n_neurons[datum[0]][0] : n_neurons[datum[0]][1],
time[0] : time[1], n_neurons[datum[0]][0] : n_neurons[datum[0]][1],
]
.detach()
.clone()
@@ -375,16 +369,11 @@ def plot_assignments(

if classes is None:
color = plt.get_cmap("RdBu", 11)
im = ax.matshow(
locals_assignments, cmap=color, vmin=-1.5, vmax=9.5
)
im = ax.matshow(locals_assignments, cmap=color, vmin=-1.5, vmax=9.5)
else:
color = plt.get_cmap("RdBu", len(classes) + 1)
im = ax.matshow(
locals_assignments,
cmap=color,
vmin=-1.5,
vmax=len(classes) - 0.5,
locals_assignments, cmap=color, vmin=-1.5, vmax=len(classes) - 0.5,
)

div = make_axes_locatable(ax)
@@ -509,14 +498,12 @@ def plot_voltages(
)
)

if thresholds is not None and thresholds[
v[0]
].size() == torch.Size([]):
if thresholds is not None and thresholds[v[0]].size() == torch.Size(
[]
):
ims.append(
axes.axhline(
y=thresholds[v[0]].item(),
c="r",
linestyle="--",
y=thresholds[v[0]].item(), c="r", linestyle="--",
)
)
else:
@@ -540,10 +527,7 @@
time[0],
time[1],
)
plt.title(
"%s voltages for neurons (%d - %d) from t = %d to %d "
% args
)
plt.title("%s voltages for neurons (%d - %d) from t = %d to %d " % args)
plt.xlabel("Time (ms)")

if plot_type == "line":
@@ -566,14 +550,12 @@
]
)
)
if thresholds is not None and thresholds[
v[0]
].size() == torch.Size([]):
if thresholds is not None and thresholds[v[0]].size() == torch.Size(
[]
):
ims.append(
axes[i].axhline(
y=thresholds[v[0]].item(),
c="r",
linestyle="--",
y=thresholds[v[0]].item(), c="r", linestyle="--",
)
)
else:
@@ -597,8 +579,7 @@
time[1],
)
axes[i].set_title(
"%s voltages for neurons (%d - %d) from t = %d to %d "
% args
"%s voltages for neurons (%d - %d) from t = %d to %d " % args
)

for ax in axes:
@@ -621,23 +602,19 @@
v[1]
.cpu()
.numpy()[
n_neurons[v[0]][0] : n_neurons[v[0]][1],
time[0] : time[1],
n_neurons[v[0]][0] : n_neurons[v[0]][1], time[0] : time[1],
]
)
if thresholds is not None and thresholds[
v[0]
].size() == torch.Size([]):
axes.axhline(
y=thresholds[v[0]].item(), c="r", linestyle="--"
)
if thresholds is not None and thresholds[v[0]].size() == torch.Size(
[]
):
axes.axhline(y=thresholds[v[0]].item(), c="r", linestyle="--")
else:
axes.matshow(
v[1]
.cpu()
.numpy()[
n_neurons[v[0]][0] : n_neurons[v[0]][1],
time[0] : time[1],
n_neurons[v[0]][0] : n_neurons[v[0]][1], time[0] : time[1],
]
.T,
cmap=cmap,
@@ -650,8 +627,7 @@
time[1],
)
axes.set_title(
"%s voltages for neurons (%d - %d) from t = %d to %d "
% args
"%s voltages for neurons (%d - %d) from t = %d to %d " % args
)
axes.set_aspect("auto")

@@ -664,13 +640,12 @@
v[1]
.cpu()
.numpy()[
n_neurons[v[0]][0] : n_neurons[v[0]][1],
time[0] : time[1],
n_neurons[v[0]][0] : n_neurons[v[0]][1], time[0] : time[1],
]
)
if thresholds is not None and thresholds[
v[0]
].size() == torch.Size([]):
if thresholds is not None and thresholds[v[0]].size() == torch.Size(
[]
):
axes[i].axhline(
y=thresholds[v[0]].item(), c="r", linestyle="--"
)
Expand All @@ -679,8 +654,7 @@ def plot_voltages(
v[1]
.cpu()
.numpy()[
n_neurons[v[0]][0] : n_neurons[v[0]][1],
time[0] : time[1],
n_neurons[v[0]][0] : n_neurons[v[0]][1], time[0] : time[1],
]
.T,
cmap=cmap,
Expand All @@ -693,8 +667,7 @@ def plot_voltages(
time[1],
)
axes[i].set_title(
"%s voltages for neurons (%d - %d) from t = %d to %d "
% args
"%s voltages for neurons (%d - %d) from t = %d to %d " % args
)

for ax in axes:
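One detail worth noting in the plot_voltages hunks above: the condition thresholds[v[0]].size() == torch.Size([]) (only rewrapped by this commit, not changed) tests whether a layer's threshold is a zero-dimensional, i.e. scalar, tensor; only a single shared threshold can be drawn as one horizontal line. A small sketch, with made-up threshold values:

    import torch

    shared = torch.tensor(-52.0)            # one threshold for the whole layer
    per_neuron = torch.full((100,), -52.0)  # a separate threshold per neuron

    print(shared.size() == torch.Size([]))      # True  -> draw a single axhline
    print(per_neuron.size() == torch.Size([]))  # False -> no shared line drawn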
23 changes: 6 additions & 17 deletions bindsnet/analysis/visualization.py
@@ -18,18 +18,14 @@ def plot_weights_movie(ws: np.ndarray, sample_every: int = 1) -> None:

# Obtain samples from the weights for every example.
for i in range(ws.shape[0]):
sub_sampled_weight = ws[
i, :, :, range(0, ws[i].shape[2], sample_every)
]
sub_sampled_weight = ws[i, :, :, range(0, ws[i].shape[2], sample_every)]
weights.append(sub_sampled_weight)
else:
weights = np.concatenate(weights, axis=0)

# Initialize plot.
fig = plt.figure()
im = plt.imshow(
weights[0, :, :], cmap="hot_r", animated=True, vmin=0, vmax=1
)
im = plt.imshow(weights[0, :, :], cmap="hot_r", animated=True, vmin=0, vmax=1)
plt.axis("off")
plt.colorbar(im)

@@ -68,9 +64,7 @@ def plot_spike_trains_for_example(
plt.figure()

if top_k is None and indices is None: # Plot all neurons' spiking activity
spike_per_neuron = [
np.argwhere(i == 1).flatten() for i in spikes[n_ex, :, :]
]
spike_per_neuron = [np.argwhere(i == 1).flatten() for i in spikes[n_ex, :, :]]
plt.title("Spiking activity for all %d neurons" % spikes.shape[1])

elif top_k is None: # Plot based on indices parameter
@@ -82,12 +76,9 @@
elif indices is None: # Plot based on top_k parameter
assert top_k is not None
# Obtain the top k neurons that fired the most
top_k_loc = np.argsort(np.sum(spikes[n_ex, :, :], axis=1), axis=0)[
::-1
]
top_k_loc = np.argsort(np.sum(spikes[n_ex, :, :], axis=1), axis=0)[::-1]
spike_per_neuron = [
np.argwhere(i == 1).flatten()
for i in spikes[n_ex, top_k_loc[0:top_k], :]
np.argwhere(i == 1).flatten() for i in spikes[n_ex, top_k_loc[0:top_k], :]
]
plt.title("Spiking activity for top %d neurons" % top_k)

@@ -133,9 +124,7 @@ def plot_voltage(
plt.plot(voltage[n_ex, n_neuron, timer])
plt.xlabel("Simulation Time")
plt.ylabel("Voltage")
plt.title(
"Membrane voltage of neuron %d for example %d" % (n_neuron, n_ex + 1)
)
plt.title("Membrane voltage of neuron %d for example %d" % (n_neuron, n_ex + 1))
locs, labels = plt.xticks()
locs = range(int(locs[1]), int(locs[-1]), 10)
plt.xticks(locs, time_ticks)
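As an aside on the top_k branch rewrapped above: it ranks neurons by total spike count with a descending argsort and keeps only the most active ones. The same selection in isolation, with hypothetical shapes (examples x neurons x timesteps):

    import numpy as np

    spikes = np.random.randint(0, 2, size=(5, 50, 100))  # (examples, neurons, time)
    n_ex, top_k = 0, 3

    # Total spikes per neuron, argsorted ascending, then reversed: most active first.
    top_k_loc = np.argsort(np.sum(spikes[n_ex, :, :], axis=1), axis=0)[::-1]
    spike_per_neuron = [
        np.argwhere(i == 1).flatten() for i in spikes[n_ex, top_k_loc[0:top_k], :]
    ]
    print(top_k_loc[:top_k], [len(s) for s in spike_per_neuron])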
44 changes: 13 additions & 31 deletions bindsnet/conversion/conversion.py
@@ -126,26 +126,22 @@ def set_requires_grad(module, value):

if isinstance(module2, nn.ReLU):
if prev_module is not None:
scale_factor = np.percentile(
activations.cpu(), percentile
)
scale_factor = np.percentile(activations.cpu(), percentile)

prev_module.weight *= prev_factor / scale_factor
prev_module.bias /= scale_factor

prev_factor = scale_factor

elif isinstance(module2, nn.Linear) or isinstance(
module2, nn.Conv2d
):
elif isinstance(module2, nn.Linear) or isinstance(module2, nn.Conv2d):
prev_module = module2

if isinstance(module2, nn.Linear):
if prev_module is not None:
scale_factor = np.percentile(activations.cpu(), percentile)
prev_module.weight *= prev_factor / scale_factor
prev_module.bias /= scale_factor
prev_factor = scale_factor
if prev_module is not None:
scale_factor = np.percentile(activations.cpu(), percentile)
prev_module.weight *= prev_factor / scale_factor
prev_module.bias /= scale_factor
prev_factor = scale_factor

else:
activations = all_activations[name]
@@ -158,9 +154,7 @@ def set_requires_grad(module, value):

prev_factor = scale_factor

elif isinstance(module, nn.Linear) or isinstance(
module, nn.Conv2d
):
elif isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
prev_module = module

return ann
@@ -187,9 +181,7 @@ def _ann_to_snn_helper(prev, current, node_type, last=False, **kwargs):
sum_input=last,
**kwargs,
)
bias = (
current.bias if current.bias is not None else torch.zeros(layer.n)
)
bias = current.bias if current.bias is not None else torch.zeros(layer.n)
connection = topology.Connection(
source=prev, target=layer, w=current.weight.t(), b=bias
)
Expand All @@ -209,11 +201,7 @@ def _ann_to_snn_helper(prev, current, node_type, last=False, **kwargs):
layer = node_type(
shape=shape, reset=0, thresh=1, refrac=0, sum_input=last, **kwargs
)
bias = (
current.bias
if current.bias is not None
else torch.zeros(layer.shape[1])
)
bias = current.bias if current.bias is not None else torch.zeros(layer.shape[1])
connection = topology.Conv2dConnection(
source=prev,
target=layer,
@@ -259,9 +247,7 @@ def _ann_to_snn_helper(prev, current, node_type, last=False, **kwargs):
]
)

connection = PermuteConnection(
source=prev, target=layer, dims=current.dims
)
connection = PermuteConnection(source=prev, target=layer, dims=current.dims)

elif isinstance(current, nn.ConstantPad2d):
layer = PassThroughNodes(
@@ -317,9 +303,7 @@ def ann_to_snn(
if data is None:
import warnings

warnings.warn(
"Data is None. Weights will not be scaled.", RuntimeWarning
)
warnings.warn("Data is None. Weights will not be scaled.", RuntimeWarning)
else:
ann = data_based_normalization(
ann=ann, data=data.detach(), percentile=percentile
@@ -342,9 +326,7 @@
prev = input_layer
while i < len(children) - 1:
current, nxt = children[i : i + 2]
layer, connection = _ann_to_snn_helper(
prev, current, node_type, **kwargs
)
layer, connection = _ann_to_snn_helper(prev, current, node_type, **kwargs)

i += 1

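Finally, the data_based_normalization hunks above all rewrap the same percentile-based weight scaling. Isolated from the diff, the per-layer update looks roughly like this (the activation tensor, the layer shapes, and the 99.9 percentile are hypothetical stand-ins):

    import numpy as np
    import torch

    activations = torch.rand(10_000)   # hypothetical recorded ReLU activations
    percentile, prev_factor = 99.9, 1.0

    # Rescale the layer so the chosen activation percentile maps to ~1, and
    # carry the factor forward to the next Linear/Conv2d, as in the hunks above.
    scale_factor = np.percentile(activations.cpu(), percentile)
    weight = torch.rand(64, 64)
    bias = torch.zeros(64)
    weight *= prev_factor / scale_factor
    bias /= scale_factor
    prev_factor = scale_factor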
