It is based on Flax's ImageNet classification example.
- GPU backend
- TPU backend
- Google Colab notebook
- Install dependencies
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install flax
pip install ml_collections clu
pip install tensorflow tensorflow_datasets tensorboard
pip install tf-models-official
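A quick way to confirm the installs succeeded is to check that each package's top-level module can be found. The sketch below is a hypothetical helper, not part of the repository; the mapping from pip package names to import names (for example, `tf-models-official` importing as `official`) is an assumption.

```python
import importlib.util
from typing import Iterable, List

# Assumed top-level import names for the pip packages installed above
# (e.g. tf-models-official is imported as `official`).
REQUIRED_MODULES = [
    "jax", "flax", "ml_collections", "clu",
    "tensorflow", "tensorflow_datasets", "tensorboard", "official",
]


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the module names that cannot be found on the current sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    print("missing modules:", missing if missing else "none")
```

This only checks that the modules are importable; it does not verify that JAX can see a TPU or GPU.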
git clone --depth 1 https://github.com/NobuoTsukamoto/jax_examples.git
cd jax_examples
export PYTHONPATH=`pwd`/common:$PYTHONPATH
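The `export` line makes modules in the repository's `common/` directory importable. As a rough in-process equivalent (a hypothetical sketch, useful for example in a notebook where exporting an environment variable is inconvenient), the directory can be added to `sys.path` directly:

```python
import os
import sys


def add_common_to_path(repo_root: str) -> str:
    """Prepend <repo_root>/common to sys.path, approximating the PYTHONPATH export.

    Note: entries from PYTHONPATH are placed just after the script's own
    directory, so the ordering here differs slightly from the shell version.
    """
    common_dir = os.path.join(os.path.abspath(repo_root), "common")
    if common_dir not in sys.path:  # avoid duplicate entries on repeated calls
        sys.path.insert(0, common_dir)
    return common_dir


if __name__ == "__main__":
    # Assumes the current directory is the repository root (jax_examples).
    print(add_common_to_path(os.getcwd()))
```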
cd classification/
python main.py \
    --task train \
    --config configs/<config_file> \
    --workdir <full path to workdir>
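When launching training from a script, the same command can be assembled as an argument list. The helper below is a hypothetical sketch; the config file name `resnet50.py` is only an illustrative placeholder, and the absolute-path check simply enforces the "full path" requirement stated above.

```python
import os
import shlex
from typing import List


def build_train_command(config_file: str, workdir: str) -> List[str]:
    """Build argv for the training command (hypothetical helper)."""
    if not os.path.isabs(workdir):
        raise ValueError(f"workdir must be a full (absolute) path, got {workdir!r}")
    return [
        "python", "main.py",
        "--task", "train",
        "--config", os.path.join("configs", config_file),
        "--workdir", workdir,
    ]


if __name__ == "__main__":
    cmd = build_train_command("resnet50.py", "/tmp/resnet50_workdir")
    print(shlex.join(cmd))
    # To actually run it from classification/:
    # import subprocess; subprocess.run(cmd, check=True)
```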
imagenet2012
| Model | Backend | Config | Top-1 accuracy | Epochs |
|---|---|---|---|---|
| ResNet50 | TPU v2-8 | config | 76.3% | 100 |
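Top-1 accuracy is the fraction of evaluation images whose highest-scoring class matches the ground-truth label. The pure-Python sketch below illustrates the metric; it is not the repository's evaluation code (in JAX the same quantity is commonly written as `jnp.mean(jnp.argmax(logits, -1) == labels)`).

```python
from typing import Sequence


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of examples whose argmax class equals the label."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    if not labels:
        raise ValueError("need at least one example")
    correct = 0
    for row, label in zip(logits, labels):
        # Index of the largest logit; ties resolve to the first index, like argmax.
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)


if __name__ == "__main__":
    logits = [[0.1, 0.9, 0.0], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5], [0.6, 0.3, 0.1]]
    labels = [1, 0, 1, 2]
    print(top1_accuracy(logits, labels))  # 2 of 4 correct -> 0.5
```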
python main.py --task summarize --config configs/<MODEL>.py