Federated server for improving a mobile image classification model
This is a proof-of-concept (POC) federated server project. It trains a CNN image classification model and takes part in a federated learning cycle with an Android mobile client.
I restructured it into a Spring Boot Maven multi-module project and added some features to bring it closer to production use.
These include:
- an improved CNN model trainer built with dl4j, along with multiple test datasets with prebuilt hyperparameters
- a flexible image data preprocessing module that lets you train with your own dataset
- a high-concurrency REST API built with a rate limiter and a cache mechanism
Note: for various reasons, this is a mixed Java and Kotlin project.
- Spring Boot with Java and Kotlin
- Deeplearning4j (model training)
- GeoLite (restricts uploads by region)
- Guava (rate limiter)
- Redis (cache for model upload and download)
- MongoDB (stores model parameter update records; federated cycle control is still TODO)
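To illustrate the kind of rate limiting mentioned in the stack above, here is a minimal token-bucket sketch in plain Java. It is a stand-in written for illustration, not the project's code: the project lists Guava for this role, and the class name `TokenBucket`, its constructor parameters, and the injectable clock are all assumptions.

```java
import java.util.function.LongSupplier;

// Hypothetical stand-in for illustration; the project itself lists Guava's rate limiter.
final class TokenBucket {
    private final double capacity;       // maximum burst size, in permits
    private final double refillPerNano;  // permits added per nanosecond
    private final LongSupplier clock;    // nanosecond time source, injectable for tests
    private double tokens;               // permits currently available
    private long lastRefill;             // clock reading at the last refill

    TokenBucket(double permitsPerSecond, double capacity, LongSupplier clock) {
        this.capacity = capacity;
        this.refillPerNano = permitsPerSecond / 1_000_000_000.0;
        this.clock = clock;
        this.tokens = capacity;          // start with a full bucket
        this.lastRefill = clock.getAsLong();
    }

    /** Takes one permit if one is available; never blocks. */
    synchronized boolean tryAcquire() {
        long now = clock.getAsLong();
        // Refill in proportion to elapsed time, capped at the bucket capacity.
        tokens = Math.min(capacity, tokens + (now - lastRefill) * refillPerNano);
        lastRefill = now;
        if (tokens >= 1.0) {
            tokens -= 1.0;
            return true;
        }
        return false;
    }

    public static void main(String[] args) {
        // 5 permits per second with a burst of 2; both values are arbitrary examples.
        TokenBucket bucket = new TokenBucket(5.0, 2.0, System::nanoTime);
        for (int i = 0; i < 4; i++) {
            System.out.println("request " + i + " allowed: " + bucket.tryAcquire());
        }
    }
}
```

A request handler would call `tryAcquire()` first and reject the upload when it returns false. Guava's `RateLimiter` offers comparable `create` and `tryAcquire` calls, so a sketch like this is only useful for understanding the behaviour.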
- JDK up to 8
- Redis on port 6379
- MongoDB on port 27017, with the connection string
mongodb://admin:123456@localhost:27017/spring_boot_mongo_app?authSource=admin&readPreference=primary&appname=MongoDB%20Compass&ssl=false
- IntelliJ IDEA or Eclipse
- Android Studio with physical Android phones
- the server and the phones must be on the same Wi-Fi network
- train an initial model with an image dataset
- deploy the initial model to the Android client
- set up and start the federated server
- capture and label images on the client, then run on-device training
- the client automatically uploads the trained model parameters to the server
- once the number of uploaded parameter sets reaches the server's default setting, the server runs federated averaging and produces a new federated model
- manually or automatically (TODO) deploy the improved model to the client
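The federated averaging step in the cycle above can be sketched as an element-wise mean of the uploaded parameter vectors. This is a minimal sketch under stated assumptions, not the project's implementation: it assumes each client's parameters have been flattened into a `double[]` of equal length and that all clients count equally. The class `FederatedAveraging` and the method `average` are hypothetical names.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper for illustration; names are not taken from this project.
final class FederatedAveraging {

    /** Returns the element-wise mean of equally sized parameter vectors. */
    static double[] average(List<double[]> updates) {
        if (updates.isEmpty()) {
            throw new IllegalArgumentException("no updates to average");
        }
        int n = updates.get(0).length;
        double[] mean = new double[n];
        for (double[] update : updates) {
            if (update.length != n) {
                throw new IllegalArgumentException("parameter vectors differ in length");
            }
            for (int i = 0; i < n; i++) {
                mean[i] += update[i];        // accumulate the sum per parameter
            }
        }
        for (int i = 0; i < n; i++) {
            mean[i] /= updates.size();       // divide once to get the mean
        }
        return mean;
    }

    public static void main(String[] args) {
        // Two made-up client updates with two parameters each.
        double[] merged = average(Arrays.asList(
                new double[]{1.0, 2.0},
                new double[]{3.0, 6.0}));
        System.out.println(Arrays.toString(merged)); // prints [2.0, 4.0]
    }
}
```

A weighted variant would multiply each update by its share of the training samples before summing; whether this server weights updates that way is not stated here.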
TODO
Clone this project and the Android client from this branch.
- adjust the dataset directory path in
model/src/main/kotlin/com/steven/model/bo/TrainParams.kt
- use
model/src/main/kotlin/com/steven/model/Main.kt
to train the initial model. Pass the program arguments in the IDE run configuration: arg[0] is train and arg[1] is the destination directory where the model is saved.
Example:
train E:/workspace/phModelNew
- once training finishes, rename the model file to model.zip and put it in the initial model directory defined by model_dir in
server/src/main/resources/application.properties
model_dir = E:/workspace/phModelNew
- make sure the model_dir directory exists, then start the local server from
server/src/main/kotlin/com/steven/server/ServerApplication.kt
- visit the following endpoints to make sure all necessary documents are initialized
http://localhost:8080/service/federatedservice/available
http://localhost:8080/service/federatedservice/currentRound
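Instead of opening the two endpoints in a browser, the check can be scripted. The sketch below uses only the JDK's `HttpURLConnection` and assumes both endpoints answer plain GET requests on `localhost:8080`, as the URLs above suggest. The class `EndpointCheck` and its methods are hypothetical, and the response bodies are not inspected because their format is not documented here.

```java
import java.io.IOException;
import java.net.HttpURLConnection;
import java.net.URL;

// Hypothetical smoke check for illustration; not part of this project.
final class EndpointCheck {
    static final String BASE = "http://localhost:8080/service/federatedservice";

    /** Builds the full URL for an endpoint name such as "available". */
    static String endpoint(String name) {
        return BASE + "/" + name;
    }

    /** Sends a GET request and returns the HTTP status code. */
    static int status(String url) throws IOException {
        HttpURLConnection conn = (HttpURLConnection) new URL(url).openConnection();
        conn.setRequestMethod("GET");
        conn.setConnectTimeout(3000);   // fail fast if the server is not running
        conn.setReadTimeout(3000);
        try {
            return conn.getResponseCode();
        } finally {
            conn.disconnect();
        }
    }

    public static void main(String[] args) {
        for (String name : new String[]{"available", "currentRound"}) {
            String url = endpoint(name);
            try {
                System.out.println(url + " -> HTTP " + status(url));
            } catch (IOException e) {
                System.out.println(url + " -> unreachable: " + e.getMessage());
            }
        }
    }
}
```

If either line reports the server as unreachable, confirm that ServerApplication is running and that Redis and MongoDB are up before moving on to the client.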
- in dev\PhotoLabeller\app\gradle.properties, change the IPv6 address in parameter_server_url to your own machine's
- copy the same initial model file into the PhotoLabeller Android project under
PhotoLabeller\app\src\main\assets
- four places must be modified to match the model's file name and its labels
photolabeller.di.MainAppModule private final val modelFileName
photolabeller.trainer.ClientImageLoader public final fun createDataSet() val label
photolabeller.labeller.MainFragment.Companion public final val label
photolabeller.config.SharedConfig public final val labels
- run the app branch
- take some pictures and manually correct each predicted label to the actual class
- run on-device training; the app automatically uploads the updated model parameters to the server
- if the default min_updates threshold is reached, the server runs the federated averaging algorithm and produces a new federated model
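The threshold behaviour in the last step can be sketched as a small round controller that buffers client updates and closes the round once enough have arrived. This is only a sketch under assumptions: parameters are treated as flat `double[]` vectors, updates are averaged with equal weight, and the names `RoundController`, `submit`, `currentRound`, and `model` are hypothetical. Only the idea of a `min_updates` threshold comes from the text above.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch; class and method names are not taken from this project.
final class RoundController {
    private final int minUpdates;                          // plays the role of min_updates
    private final List<double[]> pending = new ArrayList<>();
    private double[] globalModel;
    private int round = 0;

    RoundController(double[] initialModel, int minUpdates) {
        if (minUpdates < 1) {
            throw new IllegalArgumentException("minUpdates must be at least 1");
        }
        this.globalModel = initialModel.clone();
        this.minUpdates = minUpdates;
    }

    /** Stores one client update; returns true if this update closed the round. */
    synchronized boolean submit(double[] update) {
        if (update.length != globalModel.length) {
            throw new IllegalArgumentException("update has the wrong length");
        }
        pending.add(update.clone());
        if (pending.size() < minUpdates) {
            return false;                                  // keep waiting for more clients
        }
        // Threshold reached: average the buffered updates into a new global model.
        double[] mean = new double[globalModel.length];
        for (double[] u : pending) {
            for (int i = 0; i < mean.length; i++) {
                mean[i] += u[i] / pending.size();
            }
        }
        globalModel = mean;
        pending.clear();
        round++;
        return true;
    }

    synchronized int currentRound() { return round; }

    synchronized double[] model() { return globalModel.clone(); }

    public static void main(String[] args) {
        RoundController controller = new RoundController(new double[]{0.0}, 2);
        System.out.println(controller.submit(new double[]{1.0})); // prints false
        System.out.println(controller.submit(new double[]{3.0})); // prints true
        System.out.println(controller.currentRound());            // prints 1
    }
}
```

Keeping the buffer and the round counter behind synchronized methods is one simple way to stay correct when several phones upload at once; a real server would more likely persist both, for example in the MongoDB collections listed earlier.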
This project is inspired by PhotoLabellerServer; the core package is used under the MIT license.
You can also check out my contribution, a working PhotoLabeller project that uses the official dl4j CIFAR-10 dataset, from the following branches.