Skip to content

Commit

Permalink
Improve pixelcnn
Browse files Browse the repository at this point in the history
  • Loading branch information
juliuskunze committed Nov 29, 2019
1 parent ce6e2b6 commit d82a5c8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/pixelcnn.py
Expand Up @@ -12,7 +12,7 @@
from jaxnet import parametrized, Parameter, Dropout, parameter, save
from jaxnet.optimizers import Adam

image_dtype = np.uint32 # is supported on TPU unlike np.uint8
image_dtype = np.uint8


def _l2_normalize(arr, axis):
Expand Down Expand Up @@ -299,7 +299,7 @@ def main(batch_size=32, nr_filters=8, epochs=10, step_size=.001, decay_rate=.999
f"train loss {train_loss:.3f}, "
f"test loss {test_loss:.3f} ")

save(opt.get_parameters(state), model_path)
save(opt.get_parameters(state), model_path)


if __name__ == '__main__':
Expand Down

0 comments on commit d82a5c8

Please sign in to comment.