Implementation of the U-Net convolutional network using JAX, Flax and Optax.
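The repository builds the network with Flax. To show the U-Net shape without extra dependencies, here is a one-level sketch in plain JAX: a contracting step (conv + 2x2 max pooling), an expanding step (upsampling), and the skip connection that concatenates encoder features back in before the decoder. The function names, channel widths, He-style initialisation and nearest-neighbour upsampling are assumptions for illustration only. The full U-Net stacks several such levels and uses learned up-convolutions.

```python
import jax
import jax.numpy as jnp

def conv(x, w, b):
    # 'SAME' padding keeps the spatial size; NHWC activations, HWIO kernels.
    y = jax.lax.conv_general_dilated(
        x, w, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))
    return y + b

def max_pool2(x):
    # 2x2 max pooling with stride 2: halves H and W.
    return jax.lax.reduce_window(
        x, -jnp.inf, jax.lax.max, (1, 2, 2, 1), (1, 2, 2, 1), "VALID")

def upsample2(x):
    # Nearest-neighbour upsampling: doubles H and W.
    return jnp.repeat(jnp.repeat(x, 2, axis=1), 2, axis=2)

def init_params(key, c_in=1, c=8):
    # Hypothetical parameter layout: (kernel, bias) per layer.
    keys = jax.random.split(key, 4)
    def w(k, kh, ci, co):
        return jax.random.normal(k, (kh, kh, ci, co)) * (2.0 / (kh * kh * ci)) ** 0.5
    return {
        "enc": (w(keys[0], 3, c_in, c), jnp.zeros(c)),
        "mid": (w(keys[1], 3, c, 2 * c), jnp.zeros(2 * c)),
        "dec": (w(keys[2], 3, 3 * c, c), jnp.zeros(c)),
        "out": (w(keys[3], 1, c, 1), jnp.zeros(1)),
    }

def unet_forward(params, x):
    e = jax.nn.relu(conv(x, *params["enc"]))              # (N, H, W, c)
    m = jax.nn.relu(conv(max_pool2(e), *params["mid"]))   # (N, H/2, W/2, 2c)
    u = jnp.concatenate([upsample2(m), e], axis=-1)       # skip connection: (N, H, W, 3c)
    d = jax.nn.relu(conv(u, *params["dec"]))              # (N, H, W, c)
    return conv(d, *params["out"])                        # per-pixel logits (N, H, W, 1)
```

For a batch of grayscale images of shape (N, H, W, 1) with even H and W, `unet_forward` returns one logit per pixel, which is what a binary segmentation loss expects.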
Each branch of this repository implements the training loop differently, so that the variants can be compared easily.
Dataset: isbi2012
Keras code used for comparison: keras-unet
To run, execute: python unet_jax/train_unet.py
Branches:
- unet_jax
- unet_jax_jit
- unet_vamp
- unet_vmap_jit
- unet_pmap
- unet_jax_batchsize1
- unet_jit_batchsize1
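Judging by their names, the branches listed above differ in which JAX transformation wraps the training step: none, jax.jit, jax.vmap, both, or jax.pmap, with the batchsize1 branches updating on one image at a time. The sketch below shows those variants side by side. It is not the repository's code: it swaps the U-Net for a hypothetical per-pixel logistic model and Optax for plain SGD, so that only the transformations differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in model: per-pixel logistic regression, params = (w, b).
def loss_one(params, image, mask):
    # Pixel-wise binary cross-entropy with logits for one (H, W) image.
    w, b = params
    logits = image * w + b
    return jnp.mean(jnp.logaddexp(0.0, logits) - mask * logits)

def sgd(params, grads, lr=0.1):
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Eager, batch size 1: one Python-dispatched update per image.
def train_batchsize1(params, images, masks):
    for image, mask in zip(images, masks):
        params = sgd(params, jax.grad(loss_one)(params, image, mask))
    return params

# jit, batch size 1: same update, compiled once with XLA and reused.
step_one_jit = jax.jit(lambda params, image, mask: sgd(params, jax.grad(loss_one)(params, image, mask)))

# vmap: per-example gradients for the whole batch, averaged into one update.
def batch_step(params, images, masks):
    per_example = jax.vmap(jax.grad(loss_one), in_axes=(None, 0, 0))(params, images, masks)
    grads = jax.tree_util.tree_map(lambda g: g.mean(axis=0), per_example)
    return sgd(params, grads)

# vmap + jit: the batched step compiled as a single XLA computation.
batch_step_jit = jax.jit(batch_step)

# pmap: each device handles one shard of the batch; gradients are averaged
# across devices with pmean so every replica applies the same update.
def _device_step(params, images, masks):
    batch_loss = lambda p: jnp.mean(jax.vmap(loss_one, in_axes=(None, 0, 0))(p, images, masks))
    grads = jax.lax.pmean(jax.grad(batch_loss)(params), axis_name="devices")
    return sgd(params, grads)

pmap_step = jax.pmap(_device_step, axis_name="devices")
```

To call `pmap_step`, the parameters are replicated along a leading device axis and the batch is reshaped to (num_devices, per_device_batch, H, W). With equal-sized shards, averaging the per-device mean gradients gives the same update as `batch_step`, so the vmap and pmap variants should agree up to floating-point rounding.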
These experiments were run on the Santos Dumont supercomputer in mid-2022.