diff --git a/examples/mnist/batch_eth_mnist.py b/examples/mnist/batch_eth_mnist.py index 103692e6..c11bb119 100644 --- a/examples/mnist/batch_eth_mnist.py +++ b/examples/mnist/batch_eth_mnist.py @@ -166,7 +166,7 @@ labels = [] if epoch % progress_interval == 0: - print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start)) + print("\n Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start)) start = t() # Create a dataloader to iterate and batch data @@ -178,7 +178,8 @@ pin_memory=gpu, ) - for step, batch in enumerate(tqdm(train_dataloader)): + pbar_training = tqdm(total=n_train) + for step, batch in enumerate(train_dataloader): if step > n_train: break # Get next input sample. @@ -286,9 +287,7 @@ plt.pause(1e-8) network.reset_state_variables() # Reset state variables. - - if step % update_steps == 0 and step > 0: - break + pbar_training.update() print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start)) print("Training complete.\n")