Skip to content

Commit

Permalink
Merge pull request #492 from cearlUmass/master
Browse files Browse the repository at this point in the history
Update reservoir.py
  • Loading branch information
Hananel-Hazan committed Jun 13, 2021
2 parents e47ae91 + fecc12e commit eac5af1
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions examples/mnist/reservoir.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
torch.set_num_threads(os.cpu_count() - 1)
print("Running on Device = ", device)

# Create simple Torch NN
network = Network(dt=dt)
inpt = Input(784, shape=(1, 28, 28))
network.add_layer(inpt, name="I")
Expand All @@ -84,6 +85,7 @@
network.add_connection(C1, source="I", target="O")
network.add_connection(C2, source="O", target="O")

# Monitors for visualizing activity
spikes = {}
for l in network.layers:
spikes[l] = Monitor(network.layers[l], ["s"], time=time, device=device)
Expand All @@ -101,7 +103,7 @@
dataset = MNIST(
PoissonEncoder(time=time, dt=dt),
None,
root=os.path.join("..", "..", "data", "MNIST"),
root=os.path.join("..", "data", "MNIST"),
download=True,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
Expand All @@ -123,19 +125,26 @@
)

# Run training data on reservoir computer and store (spikes per neuron, label) per example.
# Note: Because this is a reservoir network, no adjustments of neuron parameters occurs in this phase.
n_iters = examples
training_pairs = []
pbar = tqdm(enumerate(dataloader))
for (i, dataPoint) in pbar:
if i > n_iters:
break

# Extract & resize the MNIST samples image data for training
# int(time / dt) -> length of spike train
# 28 x 28 -> size of sample
datum = dataPoint["encoded_image"].view(int(time / dt), 1, 1, 28, 28).to(device)
label = dataPoint["label"]
pbar.set_description_str("Train progress: (%d / %d)" % (i, n_iters))

# Run network on sample image
network.run(inputs={"I": datum}, time=time, input_time_dim=1)
training_pairs.append([spikes["O"].get("s").sum(0), label])

# Plot spiking activity using monitors
if plot:

inpt_axes, inpt_ims = plot_input(
Expand Down Expand Up @@ -165,6 +174,7 @@


# Define logistic regression model using PyTorch.
# These neurons will take the reservoirs output as its input, and be trained to classify the images.
class NN(nn.Module):
def __init__(self, input_size, num_classes):
super(NN, self).__init__()
Expand All @@ -189,14 +199,26 @@ def forward(self, x):
pbar = tqdm(enumerate(range(n_epochs)))
for epoch, _ in pbar:
avg_loss = 0

# Extract spike outputs from reservoir for a training sample
# i -> Loop index
# s -> Reservoir output spikes
# l -> Image label
for i, (s, l) in enumerate(training_pairs):
# Forward + Backward + Optimize

# Reset gradients to 0
optimizer.zero_grad()

# Run spikes through logistic regression model
outputs = model(s)

# Calculate MSE
label = torch.zeros(1, 1, 10).float().to(device)
label[0, 0, l] = 1.0
loss = criterion(outputs.view(1, 1, -1), label)
avg_loss += loss.data

# Optimize parameters
loss.backward()
optimizer.step()

Expand All @@ -205,17 +227,19 @@ def forward(self, x):
% (epoch + 1, n_epochs, avg_loss / len(training_pairs))
)

# Run same simulation on reservoir with testing data instead of training data
# (see training section for intuition)
n_iters = examples
test_pairs = []
pbar = tqdm(enumerate(dataloader))
for (i, dataPoint) in pbar:
if i > n_iters:
break
datum = dataPoint["encoded_image"].view(time, 1, 1, 28, 28).to(device)
datum = dataPoint["encoded_image"].view(int(time / dt), 1, 1, 28, 28).to(device)
label = dataPoint["label"]
pbar.set_description_str("Testing progress: (%d / %d)" % (i, n_iters))

network.run(inputs={"I": datum}, time=250, input_time_dim=1)
network.run(inputs={"I": datum}, time=time, input_time_dim=1)
test_pairs.append([spikes["O"].get("s").sum(0), label])

if plot:
Expand All @@ -227,12 +251,12 @@ def forward(self, x):
ims=inpt_ims,
)
spike_ims, spike_axes = plot_spikes(
{layer: spikes[layer].get("s").view(-1, 250) for layer in spikes},
{layer: spikes[layer].get("s").view(time, -1) for layer in spikes},
axes=spike_axes,
ims=spike_ims,
)
voltage_ims, voltage_axes = plot_voltages(
{layer: voltages[layer].get("v").view(-1, 250) for layer in voltages},
{layer: voltages[layer].get("v").view(time, -1) for layer in voltages},
ims=voltage_ims,
axes=voltage_axes,
)
Expand All @@ -244,7 +268,7 @@ def forward(self, x):
plt.pause(1e-8)
network.reset_state_variables()

# Test the Model
# Test model with previously trained logistic regression classifier
correct, total = 0, 0
for s, label in test_pairs:
outputs = model(s)
Expand Down

0 comments on commit eac5af1

Please sign in to comment.