Skip to content

Commit

Permalink
Merge pull request #457 from BindsNET/hananel
Browse files Browse the repository at this point in the history
Fix plotting issue with batch_eth_mnist and code black formater
  • Loading branch information
Hananel-Hazan committed Feb 15, 2021
2 parents b231f50 + df96c4c commit 6f521a6
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 10 deletions.
4 changes: 3 additions & 1 deletion examples/mnist/SOM_LM-SNNs.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@
accuracy = {"all": [], "proportion": []}

# Voltage recording for excitatory and inhibitory layers.
som_voltage_monitor = Monitor(network.layers["Y"], ["v"], time=int(time / dt), device=device)
som_voltage_monitor = Monitor(
network.layers["Y"], ["v"], time=int(time / dt), device=device
)
network.add_monitor(som_voltage_monitor, name="som_voltage")

# Set up monitors for spikes and voltages
Expand Down
17 changes: 10 additions & 7 deletions examples/mnist/batch_eth_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,25 +133,27 @@
accuracy = {"all": [], "proportion": []}

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=int(time / dt), device=device)
inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=int(time / dt), device=device)
exc_voltage_monitor = Monitor(
network.layers["Ae"], ["v"], time=int(time / dt), device=device
)
inh_voltage_monitor = Monitor(
network.layers["Ai"], ["v"], time=int(time / dt), device=device
)
network.add_monitor(exc_voltage_monitor, name="exc_voltage")
network.add_monitor(inh_voltage_monitor, name="inh_voltage")

# Set up monitors for spikes and voltages
spikes = {}
for layer in set(network.layers):
spikes[layer] = Monitor(
network.layers[layer], state_vars=["s"], time=int(time / dt),
device=device
network.layers[layer], state_vars=["s"], time=int(time / dt), device=device
)
network.add_monitor(spikes[layer], name="%s_spikes" % layer)

voltages = {}
for layer in set(network.layers) - {"X"}:
voltages[layer] = Monitor(
network.layers[layer], state_vars=["v"], time=int(time / dt),
device=device
network.layers[layer], state_vars=["v"], time=int(time / dt), device=device
)
network.add_monitor(voltages[layer], name="%s_voltages" % layer)

Expand Down Expand Up @@ -271,6 +273,7 @@
if plot:
image = batch["image"][:, 0].view(28, 28)
inpt = inputs["X"][:, 0].view(time, 784).sum(0).view(28, 28)
lable = batch["label"][0]
input_exc_weights = network.connections[("X", "Ae")].w
square_weights = get_square_weights(
input_exc_weights.view(784, n_neurons), n_sqrt, 28
Expand All @@ -281,7 +284,7 @@
}
voltages = {"Ae": exc_voltages, "Ai": inh_voltages}
inpt_axes, inpt_ims = plot_input(
image, inpt, label=labels[step], axes=inpt_axes, ims=inpt_ims
image, inpt, label=lable, axes=inpt_axes, ims=inpt_ims
)
spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes)
weights_im = plot_weights(square_weights, im=weights_im)
Expand Down
8 changes: 6 additions & 2 deletions examples/mnist/eth_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,12 @@
accuracy = {"all": [], "proportion": []}

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=int(time / dt), device=device)
inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=int(time / dt), device=device)
exc_voltage_monitor = Monitor(
network.layers["Ae"], ["v"], time=int(time / dt), device=device
)
inh_voltage_monitor = Monitor(
network.layers["Ai"], ["v"], time=int(time / dt), device=device
)
network.add_monitor(exc_voltage_monitor, name="exc_voltage")
network.add_monitor(inh_voltage_monitor, name="inh_voltage")

Expand Down

0 comments on commit 6f521a6

Please sign in to comment.