Marcos-Tonari-Diaz/UNET_JAX

UNET JAX

Implementation of the U-Net convolutional network using JAX, Flax, and Optax

Each branch contains a different implementation of the training loop, so the variants can be compared easily

Dataset: isbi2012

Keras code used for comparison: keras-unet

To run, execute python unet_jax/train_unet.py

Branches with support for batch size = N (mini-batch)

  • unet_jax
  • unet_jax_jit
  • unet_vmap
  • unet_vmap_jit
  • unet_pmap
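
The core idea these branches explore can be sketched as a jitted mini-batch training step, where jax.vmap maps a per-sample loss over the batch axis and jax.jit compiles the whole update. This is a minimal illustration, not the repository's actual code: the U-Net is replaced by a toy linear model, and all names (predict, train_step, etc.) are hypothetical.

```python
import jax
import jax.numpy as jnp

# Toy stand-in for the U-Net: a single linear layer.
def predict(params, x):
    return jnp.dot(x, params["w"]) + params["b"]

# Mean-squared-error loss for one sample.
def per_sample_loss(params, x, y):
    return jnp.mean((predict(params, x) - y) ** 2)

# vmap maps the per-sample loss over the leading (batch) axis;
# in_axes=(None, 0, 0) shares params across samples.
def batched_loss(params, xs, ys):
    losses = jax.vmap(per_sample_loss, in_axes=(None, 0, 0))(params, xs, ys)
    return jnp.mean(losses)

# jit compiles the full gradient step with XLA.
@jax.jit
def train_step(params, xs, ys, lr=0.1):
    loss, grads = jax.value_and_grad(batched_loss)(params, xs, ys)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss
```

The pmap branch follows the same pattern but replaces the outer transform with jax.pmap to shard the batch across devices.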

Branches with support for batch size = 1 (SGD)

  • unet_jax_batchsize1
  • unet_jit_batchsize1

These experiments were run on the Santos Dumont supercomputer in mid-2022.

SGD performance comparison

(figure: batchsize1)

Mini-batch gradient descent performance comparison

(figure: batchsize4)

Segmentation Result

From left to right: original image; target; logits; final prediction

(figure: seg_result)