Skip to content

Commit

Permalink
Merge pull request #557 from JeremCab/master
Browse files Browse the repository at this point in the history
Update examples/mnist/batch_eth_mnist.py
  • Loading branch information
Hananel-Hazan committed Jun 29, 2022
2 parents bc26497 + 206e6c5 commit 2f70f4c
Showing 1 changed file with 24 additions and 17 deletions.
41 changes: 24 additions & 17 deletions examples/mnist/batch_eth_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
parser.add_argument("--n_test", type=int, default=10000)
parser.add_argument("--n_train", type=int, default=60000)
parser.add_argument("--n_workers", type=int, default=-1)
parser.add_argument("--update_steps", type=int, default=256)
parser.add_argument("--n_updates", type=int, default=10)
parser.add_argument("--exc", type=float, default=22.5)
parser.add_argument("--inh", type=float, default=120)
parser.add_argument("--theta_plus", type=float, default=0.05)
Expand All @@ -44,7 +44,7 @@
parser.add_argument("--test", dest="train", action="store_false")
parser.add_argument("--plot", dest="plot", action="store_true")
parser.add_argument("--gpu", dest="gpu", action="store_true")
parser.set_defaults(plot=True, gpu=True)
parser.set_defaults(plot=False, gpu=True)

args = parser.parse_args()

Expand All @@ -55,7 +55,7 @@
n_test = args.n_test
n_train = args.n_train
n_workers = args.n_workers
update_steps = args.update_steps
n_updates = args.n_updates
exc = args.exc
inh = args.inh
theta_plus = args.theta_plus
Expand All @@ -67,6 +67,7 @@
plot = args.plot
gpu = args.gpu

update_steps = int(n_train / batch_size / n_updates)
update_interval = update_steps * batch_size

device = "cpu"
Expand Down Expand Up @@ -162,14 +163,14 @@
spike_record = torch.zeros((update_interval, int(time / dt), n_neurons), device=device)

# Train the network.
print("\nBegin training.\n")
print("\nBegin training...")
start = t()

for epoch in range(n_epochs):
labels = []

if epoch % progress_interval == 0:
print("\n Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start))
print("\nProgress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start))
start = t()

# Create a dataloader to iterate and batch data
Expand All @@ -183,13 +184,10 @@

pbar_training = tqdm(total=n_train)
for step, batch in enumerate(train_dataloader):
if step > n_train:
if step * batch_size > n_train:
break
# Get next input sample.
inputs = {"X": batch["encoded_image"]}
if gpu:
inputs = {k: v.cuda() for k, v in inputs.items()}

# Assign labels to excitatory neurons.
if step % update_steps == 0 and step > 0:
# Convert the array of labels into a tensor
label_tensor = torch.tensor(labels, device=device)
Expand Down Expand Up @@ -245,6 +243,12 @@

labels = []

# Get next input sample.
inputs = {"X": batch["encoded_image"]}
if gpu:
inputs = {k: v.cuda() for k, v in inputs.items()}

# Remember labels.
labels.extend(batch["label"].tolist())

# Run the network on the input.
Expand Down Expand Up @@ -293,9 +297,10 @@

network.reset_state_variables() # Reset state variables.
pbar_training.update(batch_size)
pbar_training.close()

print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Training complete.\n")
print("\nTraining complete.\n")

# Load MNIST data.
test_dataset = MNIST(
Expand All @@ -322,13 +327,15 @@
accuracy = {"all": 0, "proportion": 0}

# Train the network.
print("\nBegin testing\n")
print("\nBegin testing...\n")
network.train(mode=False)
start = t()

pbar = tqdm(total=n_test)
for step, batch in enumerate(test_dataset):
if step > n_test:
pbar.set_description_str("Test progress: ")

for step, batch in enumerate(test_dataloader):
if step * batch_size > n_test:
break
# Get next input sample.
inputs = {"X": batch["encoded_image"]}
Expand Down Expand Up @@ -362,11 +369,11 @@
)

network.reset_state_variables() # Reset state variables.
pbar.set_description_str("Test progress: ")
pbar.update()
pbar.update(batch_size)
pbar.close()

print("\nAll activity accuracy: %.2f" % (accuracy["all"] / n_test))
print("Proportion weighting accuracy: %.2f \n" % (accuracy["proportion"] / n_test))

print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Testing complete.\n")
print("\nTesting complete.\n")

0 comments on commit 2f70f4c

Please sign in to comment.