diff --git a/federated_learning/breast_density_challenge/.dockerignore b/federated_learning/breast_density_challenge/.dockerignore
new file mode 100644
index 0000000000..6029ce1d95
--- /dev/null
+++ b/federated_learning/breast_density_challenge/.dockerignore
@@ -0,0 +1,3 @@
+# Ignore the following files/folders during docker build
+
+__pycache__/
diff --git a/federated_learning/breast_density_challenge/.gitignore b/federated_learning/breast_density_challenge/.gitignore
new file mode 100644
index 0000000000..6f15968018
--- /dev/null
+++ b/federated_learning/breast_density_challenge/.gitignore
@@ -0,0 +1,12 @@
+# IDE
+.idea/
+
+# artifacts
+poc/
+*.pyc
+result_*
+*.pth
+logs
+
+# example data
+*preprocessed*
diff --git a/federated_learning/breast_density_challenge/Dockerfile b/federated_learning/breast_density_challenge/Dockerfile
new file mode 100644
index 0000000000..e9d1ce585d
--- /dev/null
+++ b/federated_learning/breast_density_challenge/Dockerfile
@@ -0,0 +1,36 @@
+# use python base image
+FROM python:3.8.10
+ENV DEBIAN_FRONTEND noninteractive
+
+# specify the server FQDN as commandline argument
+ARG server_fqdn
+RUN echo "Setting up FL workspace with FQDN: ${server_fqdn}"
+
+# add your code to container
+COPY code /code
+
+# add code to path
+ENV PYTHONPATH=${PYTHONPATH}:"/code"
+
+# install dependencies
+# RUN python -m pip install --upgrade pip
+RUN pip3 install tensorboard sklearn torchvision
+RUN pip3 install monai==0.8.1
+RUN pip3 install nvflare==2.0.16
+
+# mount nvflare from source
+#RUN pip install tenseal
+#WORKDIR /code
+#RUN git clone https://github.com/NVIDIA/NVFlare.git
+#ENV PYTHONPATH=${PYTHONPATH}:"/code/NVFlare"
+
+# download pretrained weights
+ENV TORCH_HOME=/opt/torch
+RUN python3 /code/pt/utils/download_model.py --model_url=https://download.pytorch.org/models/resnet18-f37072fd.pth
+
+# prepare FL workspace
+WORKDIR /code
+RUN sed -i "s|{SERVER_FQDN}|${server_fqdn}|g" fl_project.yml
+RUN python3 -m nvflare.lighter.provision -p fl_project.yml
+RUN cp -r workspace/fl_project/prod_00 fl_workspace
+RUN mv fl_workspace/${server_fqdn} fl_workspace/server
diff --git a/federated_learning/breast_density_challenge/README.md b/federated_learning/breast_density_challenge/README.md
new file mode 100644
index 0000000000..f7506bd439
--- /dev/null
+++ b/federated_learning/breast_density_challenge/README.md
@@ -0,0 +1,176 @@
+## MammoFL_MICCAI2022
+
+Reference implementation for the
+[ACR-NVIDIA-NCI Breast Density FL challenge](http://BreastDensityFL.acr.org).
+
+Held in conjunction with [MICCAI 2022](https://conferences.miccai.org/2022/en/).
+
+
+------------------------------------------------
+## 1. Run Training using [NVFlare](https://github.com/NVIDIA/NVFlare) reference implementation
+
+We provide a minimal example of how to implement Federated Averaging using [NVFlare 2.0](https://github.com/NVIDIA/NVFlare) and [MONAI](https://monai.io/) to train
+a breast density prediction model with ResNet18.
+
+### 1.1 Download example data
+Follow the steps described in [./data/README.md](./data/README.md) to download an example breast density mammography dataset.
+Note, the data used in the actual challenge will be different. We do, however, follow the same preprocessing steps and
+use the same four BI-RADS breast density classes for prediction. See [./code/pt/utils/preprocess_dicomdir.py](./code/pt/utils/preprocess_dicomdir.py) for details.
+
+We provide a set of random data splits.
+Please download them using
+```
+python3 ./code/pt/utils/download_datalists_and_predictions.py
+```
+After download, they will be available as `./data/dataset_blinded_site-*.json`, which follow the same format as the data lists
+used in the challenge.
+Please do not modify the data list filenames in the configs, as they will be the same during the challenge.
+
+Note, the location of the dataset and data lists will be given by the system.
+Do not change the locations given in [config_fed_client.json](./code/configs/mammo_fedavg/config/config_fed_client.json):
+```
+  "DATASET_ROOT": "/data/preprocessed",
+  "DATALIST_PREFIX": "/data/dataset_blinded_",
+```
+
+### 1.2 Build container
+The argument specifies the FQDN (Fully Qualified Domain Name) of the FL server. Use `localhost` when simulating FL on your machine.
+```
+./build_docker.sh localhost
+```
+Note, all code and pretrained models need to be included in the docker image.
+The virtual machines running the containers will not have public internet access during training.
+For an example, see the `download_model.py` script used to download the ImageNet pretrained weights.
+
+The Dockerfile will be submitted using the [MedICI platform](https://www.medici-challenges.org).
+For detailed instructions, see the [challenge website](http://BreastDensityFL.acr.org).
+
+### 1.3 Run server and client containers, and start training
+Run all commands at once using the script below. Note that this will also create separate logs under `./logs`.
+```
+./run_all_fl.sh
+```
+Note, the GPU index to use for each client is specified inside `run_all_fl.sh`.
+See the individual `run_docker_site-*.sh` commands described below.
+Note, the server script will automatically kill all running containers used in this example,
+and final results will be placed under `./result_server`.
+
+(Optional) Run each command in a separate terminal to get site-specific printouts in separate windows.
+
+The argument for each shell script specifies the GPU index to be used.
+```
+./run_docker_server.sh
+./run_docker_site-1.sh 0
+./run_docker_site-2.sh 1
+./run_docker_site-3.sh 0
+```
+
+### 1.4 (Optional) Visualize training using TensorBoard
+After training has completed, the training curves can be visualized using
+```
+tensorboard --logdir=./result_server
+```
+A visualization of the global accuracy and [Kappa](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.cohen_kappa_score.html) validation scores for each site with the provided example data is shown below.
+The current setup runs on a machine with two NVIDIA GPUs with 12GB memory each.
+The runtime for this experiment is about 45 minutes.
+You can adjust the argument to the `run_docker_site-*.sh` scripts to specify different
+GPU indices if needed in your environment.
+
+![](./figs/example_data_val_global_acc_kappa.png)
+
+### 1.5 (Optional) Kill all containers
+If you didn't use `run_all_fl.sh`, all containers can be killed by running
+```
+docker kill server site-1 site-2 site-3
+```
+
+
+------------------------------------------------
+## 2. Modify the FL algorithm
+
+You can modify and extend the provided example code under [./code/pt](./code/pt).
+
+You could use other components available in [NVFlare](https://github.com/NVIDIA/NVFlare)
+or enhance the training pipeline using your custom code or features of other libraries.
+
+See the [NVFlare examples](https://github.com/NVIDIA/NVFlare/tree/main/examples) for features that could be utilized in this challenge.
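+
+As a starting point, a custom learner can subclass the provided `MammoLearner` and override
+selected pieces while reusing the rest. The sketch below is a minimal, hypothetical example
+(the class name and the switch to an Adam optimizer are illustrative, not part of the reference code):
+```
+# my_learner.py - minimal sketch, assuming it is placed next to mammo_learner.py
+import torch.optim as optim
+
+from pt.learners.mammo_learner import MammoLearner
+
+
+class MyMammoLearner(MammoLearner):
+    def initialize(self, parts: dict, fl_ctx):
+        # reuse the reference setup (model, transforms, data loaders)
+        super().initialize(parts, fl_ctx)
+        # example customization: replace SGD with Adam
+        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
+```
+To use such a class, its module path would replace `pt.learners.mammo_learner.MammoLearner`
+in the `learner` component of [config_fed_client.json](./code/configs/mammo_fedavg/config/config_fed_client.json).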
+
+### 2.1 Debugging the learning algorithm
+
+The example NVFlare `Learner` class is implemented at [./code/pt/learners/mammo_learner.py](./code/pt/learners/mammo_learner.py).
+You can debug the file using the `MockClientEngine` as shown in the script by running
+```
+python3 code/pt/learners/mammo_learner.py
+```
+Furthermore, you can test it inside the container by first running
+```
+./run_docker_debug.sh
+```
+Note, set `inside_container = True` to reflect the changed filepaths inside the container.
+
+
+------------------------------------------------
+## 3. Bring your own FL framework
+If you would like to use your own FL framework to participate in the challenge,
+please modify the Dockerfile accordingly to include all the dependencies.
+
+Your container needs to provide the following scripts, which start the server, start the clients, and finalize the server.
+They will be executed by the system in the following order.
+
+### 3.1 Start server
+```
+/code/start_server.sh
+```
+
+### 3.2 Start each client (in parallel)
+```
+/code/start_site-1.sh
+/code/start_site-2.sh
+/code/start_site-3.sh
+```
+
+### 3.3 Finalize the server
+```
+/code/finalize_server.sh
+```
+For an example of how the challenge system will execute these commands, see the provided `run_docker*.sh` scripts.
+
+### 3.4 Communication
+The communication channels for FL will be restricted to the ports specified in [fl_project.yml](./code/fl_project.yml).
+Your FL framework will also need to use these ports for its communication.
+
+### 3.5 Results
+Results will need to be written to `/result/predictions.json`.
+Please follow the format produced by the reference implementation at [./result_server_example/predictions.json](./result_server_example/predictions.json)
+(available after running `python3 ./code/pt/utils/download_datalists_and_predictions.py`).
+The code is expected to return a JSON file containing at least a list of image names and the prediction probabilities for each breast density class
+for the global model (which should be named `SRV_best_FL_global_model.pt`).
+```
+{
+    "site-1": {
+        "SRV_best_FL_global_model.pt": {
+            ...
+            "test_probs": [{
+                "image": "Calc-Test_P_00643_LEFT_MLO.npy",
+                "probs": [0.005602597258985043, 0.7612965703010559, 0.23040543496608734, 0.0026953918859362602]
+            }, {
+            ...
+    },
+    "site-2": {
+        "SRV_best_FL_global_model.pt": {
+            ...
+            "test_probs": [{
+                "image": "Calc-Test_P_00643_LEFT_MLO.npy",
+                "probs": [0.005602597258985043, 0.7612965703010559, 0.23040543496608734, 0.0026953918859362602]
+            }, {
+            ...
+    },
+    "site-3": {
+        "SRV_best_FL_global_model.pt": {
+            ...
+            "test_probs": [{
+                "image": "Calc-Test_P_00643_LEFT_MLO.npy",
+                "probs": [0.005602597258985043, 0.7612965703010559, 0.23040543496608734, 0.0026953918859362602]
+            }, {
+            ...
+    }
+}
+```
diff --git a/federated_learning/breast_density_challenge/build_docker.sh b/federated_learning/breast_density_challenge/build_docker.sh
new file mode 100755
index 0000000000..467d587a6e
--- /dev/null
+++ b/federated_learning/breast_density_challenge/build_docker.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env bash
+
+#SERVER_FQDN="localhost"
+SERVER_FQDN=$1
+
+if test -z "${SERVER_FQDN}"
+then
+  echo "Usage: ./build_docker.sh [SERVER_FQDN], e.g. ./build_docker.sh localhost"
+  exit 1
+fi
+
+NEW_IMAGE=monai-nvflare:latest
+
+export DOCKER_BUILDKIT=0 # show command outputs
+docker build --network=host -t ${NEW_IMAGE} --build-arg server_fqdn=${SERVER_FQDN} -f Dockerfile .
diff --git a/federated_learning/breast_density_challenge/code/configs/mammo_fedavg/config/config_fed_client.json b/federated_learning/breast_density_challenge/code/configs/mammo_fedavg/config/config_fed_client.json new file mode 100644 index 0000000000..3a2729d717 --- /dev/null +++ b/federated_learning/breast_density_challenge/code/configs/mammo_fedavg/config/config_fed_client.json @@ -0,0 +1,51 @@ +{ + "format_version": 2, + + "DATASET_ROOT": "/data/preprocessed", + "DATALIST_PREFIX": "/data/dataset_blinded_", + + "executors": [ + { + "tasks": [ + "train", "submit_model", "validate" + ], + "executor": { + "id": "Executor", + "path": "nvflare.app_common.executors.learner_executor.LearnerExecutor", + "args": { + "learner_id": "learner" + } + } + } + ], + + "task_result_filters": [ + ], + "task_data_filters": [ + ], + + "components": [ + { + "id": "learner", + "path": "pt.learners.mammo_learner.MammoLearner", + "args": { + "dataset_root": "{DATASET_ROOT}", + "datalist_prefix": "{DATALIST_PREFIX}", + "aggregation_epochs": 1, + "lr": 2e-3, + "batch_size": 64, + "val_frac": 0.1 + } + }, + { + "id": "analytic_sender", + "name": "AnalyticsSender", + "args": {} + }, + { + "id": "event_to_fed", + "name": "ConvertToFedEvent", + "args": {"events_to_convert": ["analytix_log_stats"], "fed_event_prefix": "fed."} + } + ] +} diff --git a/federated_learning/breast_density_challenge/code/configs/mammo_fedavg/config/config_fed_server.json b/federated_learning/breast_density_challenge/code/configs/mammo_fedavg/config/config_fed_server.json new file mode 100644 index 0000000000..37f0e84abc --- /dev/null +++ b/federated_learning/breast_density_challenge/code/configs/mammo_fedavg/config/config_fed_server.json @@ -0,0 +1,88 @@ +{ + "format_version": 2, + + "min_clients": 3, + "num_rounds": 100, + + "server": { + "heart_beat_timeout": 600 + }, + "task_data_filters": [], + "task_result_filters": [], + "components": [ + { + "id": "persistor", + "name": "PTFileModelPersistor", + "args": { + "model": { + "path": "monai.networks.nets.TorchVisionFCModel", + "args": { + "model_name": "resnet18", + "n_classes": 4, + "use_conv": false, + "pretrained": true, + "pool": null + } + } + } + }, + { + "id": "shareable_generator", + "name": "FullModelShareableGenerator", + "args": {} + }, + { + "id": "aggregator", + "name": "InTimeAccumulateWeightedAggregator", + "args": {} + }, + { + "id": "model_selector", + "name": "IntimeModelSelectionHandler", + "args": {} + }, + { + "id": "model_locator", + "name": "PTFileModelLocator", + "args": { + "pt_persistor_id": "persistor" + } + }, + { + "id": "json_generator", + "name": "ValidationJsonGenerator", + "args": {} + }, + { + "id": "tb_analytics_receive", + "name": "TBAnalyticsReceiver", + "args": {"events": ["fed.analytix_log_stats"]} + } + ], + "workflows": [ + { + "id": "scatter_gather_ctl", + "name": "ScatterAndGather", + "args": { + "min_clients" : "{min_clients}", + "num_rounds" : "{num_rounds}", + "start_round": 0, + "wait_time_after_min_received": 10, + "aggregator_id": "aggregator", + "persistor_id": "persistor", + "shareable_generator_id": "shareable_generator", + "train_task_name": "train", + "train_timeout": 0 + } + }, + { + "id": "global_model_eval", + "name": "GlobalModelEval", + "args": { + "model_locator_id": "model_locator", + "validation_timeout": 6000, + "cleanup_models": true + } + } + ] +} diff --git a/federated_learning/breast_density_challenge/code/finalize_server.sh b/federated_learning/breast_density_challenge/code/finalize_server.sh new file mode 100755 index 
0000000000..2570a65f30
--- /dev/null
+++ b/federated_learning/breast_density_challenge/code/finalize_server.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+SERVER="server"
+echo "FINALIZING ${SERVER}"
+cp -r ./fl_workspace/${SERVER}/run_1 /result/.
+cp ./fl_workspace/${SERVER}/*.txt /result/.
+cp ./fl_workspace/*_log.txt /result/.
+cp ./fl_workspace/${SERVER}/run_1/cross_site_val/cross_val_results.json /result/predictions.json # only file required for leaderboard computation
+# TODO: might need some more standardization of the result folder
diff --git a/federated_learning/breast_density_challenge/code/fl_project.yml b/federated_learning/breast_density_challenge/code/fl_project.yml
new file mode 100644
index 0000000000..466bd3f9bf
--- /dev/null
+++ b/federated_learning/breast_density_challenge/code/fl_project.yml
@@ -0,0 +1,60 @@
+api_version: 2
+name: fl_project
+description: NVFlare sample project yaml file
+
+participants:
+  # {SERVER_FQDN} is replaced with the server's FQDN during docker build (see Dockerfile)
+  - name: {SERVER_FQDN}
+    type: server
+    org: nvflare
+    fed_learn_port: 8002
+    admin_port: 8003
+  - name: site-1
+    type: client
+    org: nvflare
+  - name: site-2
+    type: client
+    org: nvflare
+  - name: site-3
+    type: client
+    org: nvflare
+  - name: admin@nvflare.com
+    type: admin
+    org: nvflare
+    roles:
+      - super
+
+# The same methods in all builders are called in the order defined in the builders section
+builders:
+  - path: nvflare.lighter.impl.workspace.WorkspaceBuilder
+    args:
+      template_file: master_template.yml
+  - path: nvflare.lighter.impl.template.TemplateBuilder
+  - path: nvflare.lighter.impl.static_file.StaticFileBuilder
+    args:
+      # config_folder can be set to inform NVFlare where to get configuration
+      config_folder: config
+      # when docker_image is set to a docker image name, docker.sh will be generated on server/client/admin
+      # docker_image:
+  - path: nvflare.lighter.impl.auth_policy.AuthPolicyBuilder
+    args:
+      orgs:
+        nvflare:
+          - relaxed
+      roles:
+        super: super user of system
+      groups:
+        relaxed:
+          desc: org group with relaxed policies
+          rules:
+            allow_byoc: true
+            allow_custom_datalist: true
+      disabled: false
+  - path: nvflare.lighter.impl.cert.CertBuilder
+  - path: nvflare.lighter.impl.he.HEBuilder
+    args:
+      poly_modulus_degree: 8192
+      coeff_mod_bit_sizes: [60, 40, 40]
+      scale_bits: 40
+      scheme: CKKS
+  - path: nvflare.lighter.impl.signature.SignatureBuilder
diff --git a/federated_learning/breast_density_challenge/code/pt/learners/mammo_learner.py b/federated_learning/breast_density_challenge/code/pt/learners/mammo_learner.py
new file mode 100644
index 0000000000..eca65f5553
--- /dev/null
+++ b/federated_learning/breast_density_challenge/code/pt/learners/mammo_learner.py
@@ -0,0 +1,696 @@
+# Copyright 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
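+
+# NOTE: This module implements the NVFlare `Learner` driven by the `LearnerExecutor`
+# configured in config_fed_client.json: `initialize()` builds the model, transforms,
+# and data loaders; `train()` runs local epochs and returns a weight difference;
+# `validate()` evaluates received models; `get_model_for_validation()` returns the
+# best local checkpoint for cross-site validation.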
+
+import json
+import os
+
+import numpy as np
+import torch
+import torch.optim as optim
+from monai.data import CacheDataset, DataLoader
+from monai.networks.nets import TorchVisionFCModel
+from monai.transforms import (
+    Compose,
+    EnsureTyped,
+    LoadImaged,
+    RandFlipd,
+    RandGaussianNoised,
+    RandGaussianSmoothd,
+    RandRotated,
+    RandScaleIntensityd,
+    RandShiftIntensityd,
+    RandZoomd,
+    Transposed,
+)
+from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable
+from nvflare.apis.fl_constant import FLContextKey, ReturnCode
+from nvflare.apis.fl_context import FLContext, FLContextManager
+from nvflare.apis.shareable import ReservedHeaderKey, Shareable, make_reply
+from nvflare.apis.signal import Signal
+from nvflare.app_common.abstract.learner_spec import Learner
+from nvflare.app_common.app_constant import AppConstants, ModelName, ValidateType
+from sklearn.metrics import cohen_kappa_score
+from torch.utils.tensorboard import SummaryWriter
+
+
+def load_datalist(filename, data_list_key="train", base_dir=""):
+    with open(filename, "r") as f:
+        data = json.load(f)
+
+    data_list = data[data_list_key]
+    for d in data_list:
+        d["image"] = os.path.join(base_dir, d["image"])
+
+    return data_list
+
+
+class MammoLearner(Learner):
+    def __init__(
+        self,
+        dataset_root: str = None,
+        datalist_prefix: str = None,
+        aggregation_epochs: int = 1,
+        train_task_name: str = AppConstants.TASK_TRAIN,
+        submit_model_task_name: str = AppConstants.TASK_SUBMIT_MODEL,
+        lr: float = 1e-2,
+        batch_size: int = 32,
+        val_freq: int = 1,
+        val_frac: float = 0.1,
+        analytic_sender_id: str = "analytic_sender",
+    ):
+        """Simple breast density trainer.
+
+        Args:
+            dataset_root: directory with breast density mammography data.
+            datalist_prefix: prefix of the JSON data list files.
+            aggregation_epochs: the number of training epochs for a round. Defaults to 1.
+            train_task_name: name of the task to train the model.
+            submit_model_task_name: name of the task to submit the best local model.
+            lr: local learning rate. Float number. Defaults to 1e-2.
+            batch_size: batch size for the data loaders. Defaults to 32.
+            val_freq: int. How often to validate during local training.
+            val_frac: float. Fraction of training set to reserve for validation/model selection.
+            analytic_sender_id: id of `AnalyticsSender` if configured as a client component. If configured, TensorBoard events will be fired. Defaults to "analytic_sender".
+        Returns:
+            a Shareable with the updated local model after running `execute()`
+            or the best local model depending on the specified task.
+ """ + super().__init__() + # trainer init happens at the very beginning, only the basic info regarding the trainer is set here + # the actual run has not started at this point + self.dataset_root = dataset_root + self.datalist_prefix = datalist_prefix + self.aggregation_epochs = aggregation_epochs + self.train_task_name = train_task_name + self.lr = lr + self.batch_size = batch_size + self.val_freq = val_freq + self.submit_model_task_name = submit_model_task_name + self.best_metric = 0.0 + self.val_frac = val_frac + self.analytic_sender_id = analytic_sender_id + + # Epoch counter + self.epoch_of_start_time = 0 + self.epoch_global = 0 + + if not isinstance(self.val_freq, int): + raise ValueError(f"Expected `val_freq` but got type {type(self.val_freq)}") + + # The following objects will be build in `initialize()` + self.app_root = None + self.client_id = None + self.local_model_file = None + self.best_local_model_file = None + self.writer = None + self.device = None + self.model = None + self.optimizer = None + self.criterion = None + self.transform_train = None + self.transform_valid = None + self.transform_test = None + self.train_dataset = None + self.train_loader = None + self.valid_dataset = None + self.valid_loader = None + self.test_dataset = None + self.test_loader = None + + def initialize(self, parts: dict, fl_ctx: FLContext): + # when the run starts, this is where the actual settings get initialized for trainer + + # Set the paths according to fl_ctx + self.app_root = fl_ctx.get_prop(FLContextKey.APP_ROOT) + fl_args = fl_ctx.get_prop(FLContextKey.ARGS) + self.client_id = fl_ctx.get_identity_name() + self.log_info( + fl_ctx, + f"Client {self.client_id} initialized at \n {self.app_root} \n with args: {fl_args}", + ) + + self.local_model_file = os.path.join(self.app_root, "local_model.pt") + self.best_local_model_file = os.path.join(self.app_root, "best_local_model.pt") + + # Select local TensorBoard writer or event-based writer for streaming + self.writer = parts.get( + self.analytic_sender_id + ) # user configured config_fed_client.json for streaming + if not self.writer: # use local TensorBoard writer only + self.writer = SummaryWriter(self.app_root) + + # set the training-related parameters + # can be replaced by a config-style block + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.model = TorchVisionFCModel( + "resnet18", n_classes=4, use_conv=False, pretrained=False, pool=None + ) # pretrained is used only on server + self.model = self.model.to(self.device) + self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr, momentum=0.9) + self.criterion = torch.nn.CrossEntropyLoss() + self.criterion = self.criterion.to(self.device) + + self.transform_train = Compose( + [ + LoadImaged(keys=["image"]), + Transposed(keys=["image"], indices=[2, 0, 1]), # make channels-first + RandRotated( + keys=["image"], range_x=np.pi / 12, prob=0.5, keep_size=True + ), + RandFlipd(keys=["image"], spatial_axis=0, prob=0.5), + RandFlipd(keys=["image"], spatial_axis=1, prob=0.5), + RandZoomd( + keys=["image"], min_zoom=0.9, max_zoom=1.1, prob=0.5, keep_size=True + ), + RandGaussianSmoothd( + keys=["image"], + sigma_x=(0.5, 1.15), + sigma_y=(0.5, 1.15), + sigma_z=(0.5, 1.15), + prob=0.15, + ), + RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.5), + RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5), + RandGaussianNoised(keys=["image"], std=0.01, prob=0.15), + EnsureTyped(keys=["image", "label"]), + ] + ) + self.transform_valid = Compose( + [ + 
+                LoadImaged(keys=["image"]),
+                Transposed(keys=["image"], indices=[2, 0, 1]),  # make channels-first
+                EnsureTyped(keys=["image", "label"]),
+            ]
+        )
+        self.transform_test = Compose(
+            [
+                LoadImaged(keys=["image"]),
+                Transposed(keys=["image"], indices=[2, 0, 1]),  # make channels-first
+                EnsureTyped(keys=["image"]),  # Testing set won't have labels
+            ]
+        )
+
+        # Note, do not change this syntax. The data list filename is given by the system.
+        datalist_file = self.datalist_prefix + self.client_id + ".json"
+        if not os.path.isfile(datalist_file):
+            self.log_critical(fl_ctx, f"{datalist_file} does not exist!")
+
+        # Set dataset
+        train_datalist = load_datalist(
+            datalist_file,
+            data_list_key="train",  # do not change this key name
+            base_dir=self.dataset_root,
+        )
+
+        # Validation set can be created from training set.
+        if self.val_frac >= 1.0:
+            raise ValueError(
+                f"`val_frac` was {self.val_frac}. Cannot use the whole training set for validation, use 0 < `val_frac` < 1."
+            )
+        elif self.val_frac > 0:
+            np.random.seed(0)
+            # sample without replacement to avoid duplicate validation entries
+            val_indices = np.random.permutation(len(train_datalist))[
+                : int(self.val_frac * len(train_datalist))
+            ]
+            val_datalist = [train_datalist[i] for i in val_indices]
+            train_indices = list(set(np.arange(len(train_datalist))) - set(val_indices))
+            train_datalist = [
+                train_datalist[i] for i in train_indices
+            ]  # remove validation entries from training
+            assert (len(np.intersect1d(val_indices, train_indices))) == 0
+            self.log_info(
+                fl_ctx,
+                f"Reserved {len(val_indices)} entries for validation during training.",
+            )
+        else:
+            val_datalist = []
+
+        test_datalist = load_datalist(
+            datalist_file,
+            data_list_key="test",  # do not change this key name
+            base_dir=self.dataset_root,
+        )
+
+        num_workers = 1
+        cache_rate = 1.0
+        self.train_dataset = CacheDataset(
+            data=train_datalist,
+            transform=self.transform_train,
+            cache_rate=cache_rate,
+            num_workers=num_workers,
+        )
+        self.train_loader = DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=num_workers,
+        )
+        self.log_info(fl_ctx, f"Training set: {len(train_datalist)} entries")
+
+        if len(val_datalist) > 0:
+            self.valid_dataset = CacheDataset(
+                data=val_datalist,
+                transform=self.transform_valid,
+                cache_rate=cache_rate,
+                num_workers=num_workers,
+            )
+            self.valid_loader = DataLoader(
+                self.valid_dataset,
+                batch_size=self.batch_size,
+                shuffle=False,
+                num_workers=num_workers,
+            )
+            self.log_info(fl_ctx, f"Validation set: {len(val_datalist)} entries")
+        else:
+            self.valid_dataset = None
+            self.valid_loader = None
+            self.log_info(fl_ctx, "Use no validation set")
+
+        # evaluation on testing is required
+        self.test_dataset = CacheDataset(
+            data=test_datalist,
+            transform=self.transform_test,
+            cache_rate=cache_rate,
+            num_workers=num_workers,
+        )
+        self.test_loader = DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=num_workers,
+        )
+        self.log_info(fl_ctx, f"Testing set: {len(test_datalist)} entries")
+
+        self.log_info(fl_ctx, f"Finished initializing {self.client_id}")
+
+    def finalize(self, fl_ctx: FLContext):
+        # collect threads, close files here
+        pass
+
+    def local_train(
+        self, fl_ctx, train_loader, abort_signal: Signal, val_freq: int = 0
+    ):
+        for epoch in range(self.aggregation_epochs):
+            if abort_signal.triggered:
+                return
+            self.model.train()
+            epoch_len = len(train_loader)
+            self.epoch_global = self.epoch_of_start_time + epoch
+            self.log_info(
+                fl_ctx,
+                f"Local epoch {self.client_id}: {epoch + 1}/{self.aggregation_epochs} (lr={self.lr})",
+            )
+            avg_loss = 0.0
+            for i, batch_data in enumerate(train_loader):
+                if abort_signal.triggered:
+                    return
+                inputs, labels = (
+                    batch_data["image"].to(self.device),
+                    batch_data["label"].to(self.device),
+                )
+                # zero the parameter gradients
+                self.optimizer.zero_grad()
+                # forward + backward + optimize
+                outputs = self.model(inputs)
+                loss = self.criterion(outputs, labels)
+
+                loss.backward()
+                self.optimizer.step()
+                current_step = epoch_len * self.epoch_global + i
+                avg_loss += loss.item()
+                self.writer.add_scalar(
+                    "train_loss", avg_loss / (i + 1), current_step
+                )  # log the running average over the batches seen so far
+            if self.valid_loader is not None and val_freq > 0 and epoch % val_freq == 0:
+                acc, kappa = self.local_valid(
+                    self.valid_loader,
+                    abort_signal,
+                    tb_id="val_acc_local_model",
+                    fl_ctx=fl_ctx,
+                )
+                if kappa > self.best_metric:
+                    self.best_metric = kappa
+                    self.save_model(is_best=True)
+
+    def save_model(self, is_best=False):
+        # save model
+        model_weights = self.model.state_dict()
+        save_dict = {"model_weights": model_weights, "epoch": self.epoch_global}
+        if is_best:
+            save_dict.update({"best_acc": self.best_metric})
+            torch.save(save_dict, self.best_local_model_file)
+        else:
+            torch.save(save_dict, self.local_model_file)
+
+    def train(
+        self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal
+    ) -> Shareable:
+        # Check abort signal
+        if abort_signal.triggered:
+            return make_reply(ReturnCode.TASK_ABORTED)
+
+        # get round information
+        current_round = shareable.get_header(AppConstants.CURRENT_ROUND)
+        total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS)
+        self.log_info(
+            fl_ctx, f"Current/Total Round: {current_round + 1}/{total_rounds}"
+        )
+        self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}")
+
+        # update local model weights with received weights
+        dxo = from_shareable(shareable)
+        global_weights = dxo.data
+
+        # Before loading weights, tensors might need to be reshaped to support HE for secure aggregation.
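+        # With HE enabled, decrypted weight arrays can arrive flattened, so each received
+        # tensor is reshaped to the local parameter shape before loading; the reshaped
+        # global weights are also kept so the weight difference (the FedAvg update
+        # returned to the server) can be computed after local training.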
+ local_var_dict = self.model.state_dict() + model_keys = global_weights.keys() + for var_name in local_var_dict: + if var_name in model_keys: + weights = global_weights[var_name] + try: + # reshape global weights to compute difference later on + global_weights[var_name] = np.reshape( + weights, local_var_dict[var_name].shape + ) + # update the local dict + local_var_dict[var_name] = torch.as_tensor(global_weights[var_name]) + except Exception as e: + raise ValueError( + "Convert weight from {} failed with error: {}".format( + var_name, str(e) + ) + ) + self.model.load_state_dict(local_var_dict) + + # local steps + epoch_len = len(self.train_loader) + self.log_info(fl_ctx, f"Local steps per epoch: {epoch_len}") + + # local train + self.local_train( + fl_ctx=fl_ctx, + train_loader=self.train_loader, + abort_signal=abort_signal, + val_freq=self.val_freq, + ) + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + self.epoch_of_start_time += self.aggregation_epochs + + # perform valid after local train + acc, kappa = self.local_valid( + self.valid_loader, abort_signal, tb_id="val_local_model", fl_ctx=fl_ctx + ) + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + self.log_info(fl_ctx, f"val_acc_local_model: {acc:.4f}") + + # save model + self.save_model(is_best=False) + if kappa > self.best_metric: + self.best_metric = kappa + self.save_model(is_best=True) + + # compute delta model, global model has the primary key set + local_weights = self.model.state_dict() + model_diff = {} + for name in global_weights: + if name not in local_weights: + continue + model_diff[name] = local_weights[name].cpu().numpy() - global_weights[name] + if np.any(np.isnan(model_diff[name])): + self.system_panic(f"{name} weights became NaN...", fl_ctx) + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + # build the shareable + dxo = DXO(data_kind=DataKind.WEIGHT_DIFF, data=model_diff) + dxo.set_meta_prop(MetaKey.NUM_STEPS_CURRENT_ROUND, epoch_len) + + self.log_info(fl_ctx, "Local epochs finished. Returning shareable") + return dxo.to_shareable() + + def get_model_for_validation(self, model_name: str, fl_ctx: FLContext) -> Shareable: + # Retrieve the best local model saved during training. + if model_name == ModelName.BEST_MODEL: + model_data = None + try: + # load model to cpu as server might or might not have a GPU + model_data = torch.load(self.best_local_model_file, map_location="cpu") + except Exception as e: + self.log_error(fl_ctx, f"Unable to load best model: {e}") + + # Create DXO and shareable from model data. + if model_data: + dxo = DXO(data_kind=DataKind.WEIGHTS, data=model_data["model_weights"]) + return dxo.to_shareable() + else: + # Set return code. + self.log_error( + fl_ctx, + f"best local model not found at {self.best_local_model_file}.", + ) + return make_reply(ReturnCode.EXECUTION_RESULT_ERROR) + else: + raise ValueError( + f"Unknown model_type: {model_name}" + ) # Raised errors are caught in LearnerExecutor class. 
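+
+    # local_valid runs inference over a given loader: with `return_probs_only=True` it
+    # returns a JSON-serializable list of per-image class probabilities (used for the
+    # blinded test set), otherwise an (accuracy, linear-weighted kappa) tuple; it
+    # returns None if no loader is configured or the abort signal is triggered.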
+ + def local_valid( + self, + valid_loader, + abort_signal: Signal, + tb_id=None, + return_probs_only=False, + fl_ctx=None, + ): + if not valid_loader: + return None + self.model.eval() + return_probs = [] + labels = [] + pred_labels = [] + with torch.no_grad(): + correct, total = 0, 0 + for i, batch_data in enumerate(valid_loader): + if abort_signal.triggered: + return None + inputs = batch_data["image"].to(self.device) + outputs = torch.softmax(self.model(inputs), dim=1) + probs = outputs.detach().cpu().numpy() + # make json serializable + for _img_file, _probs in zip( + batch_data["image_meta_dict"]["filename_or_obj"], probs + ): + return_probs.append( + { + "image": os.path.basename(_img_file), + "probs": [float(p) for p in _probs], + } + ) + if not return_probs_only: + _, _pred_label = torch.max(outputs.data, 1) + _labels = batch_data["label"].to(self.device) + total += inputs.data.size()[0] + correct += (_pred_label == _labels.data).sum().item() + labels.extend(_labels.detach().cpu().numpy()) + pred_labels.extend(_pred_label.detach().cpu().numpy()) + if return_probs_only: + return return_probs # create a list of image names and probs + else: + acc = correct / float(total) + assert len(labels) == total + assert len(pred_labels) == total + kappa = cohen_kappa_score(labels, pred_labels, weights="linear") + if tb_id: + self.writer.add_scalar(tb_id + "_acc", acc, self.epoch_global) + self.writer.add_scalar(tb_id + "_kappa", kappa, self.epoch_global) + return acc, kappa + + def validate( + self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal + ) -> Shareable: + # Check abort signal + if abort_signal.triggered: + return make_reply(ReturnCode.TASK_ABORTED) + + # get validation information + self.log_info(fl_ctx, f"Client identity: {fl_ctx.get_identity_name()}") + model_owner = shareable.get(ReservedHeaderKey.HEADERS).get( + AppConstants.MODEL_OWNER + ) + if model_owner: + self.log_info( + fl_ctx, + f"Evaluating model from {model_owner} on {fl_ctx.get_identity_name()}", + ) + else: + model_owner = "global_model" # evaluating global model during training + + # update local model weights with received weights + dxo = from_shareable(shareable) + global_weights = dxo.data + + # Before loading weights, tensors might need to be reshaped to support HE for secure aggregation. + local_var_dict = self.model.state_dict() + model_keys = global_weights.keys() + n_loaded = 0 + for var_name in local_var_dict: + if var_name in model_keys: + weights = torch.as_tensor(global_weights[var_name], device=self.device) + try: + # update the local dict + local_var_dict[var_name] = torch.as_tensor( + torch.reshape(weights, local_var_dict[var_name].shape) + ) + n_loaded += 1 + except Exception as e: + raise ValueError( + "Convert weight from {} failed with error: {}".format( + var_name, str(e) + ) + ) + self.model.load_state_dict(local_var_dict) + if n_loaded == 0: + raise ValueError( + f"No weights loaded for validation! 
Received weight dict is {global_weights}"
+            )
+
+        validate_type = shareable.get_header(AppConstants.VALIDATE_TYPE)
+        if validate_type == ValidateType.BEFORE_TRAIN_VALIDATE:
+            try:
+                # validate the received global model before local training
+                global_acc, global_kappa = self.local_valid(
+                    self.valid_loader,
+                    abort_signal,
+                    tb_id="val_global_model",
+                    fl_ctx=fl_ctx,
+                )
+                if abort_signal.triggered:
+                    return make_reply(ReturnCode.TASK_ABORTED)
+                self.log_info(
+                    fl_ctx, f"val_acc_global_model ({model_owner}): {global_acc}"
+                )
+
+                return DXO(
+                    data_kind=DataKind.METRICS,
+                    data={MetaKey.INITIAL_METRICS: global_acc},
+                    meta={},
+                ).to_shareable()
+            except Exception as e:
+                raise ValueError(f"BEFORE_TRAIN_VALIDATE failed: {e}")
+        elif validate_type == ValidateType.MODEL_VALIDATE:
+            try:
+                # evaluate on the training set
+                train_acc, train_kappa = self.local_valid(
+                    self.train_loader, abort_signal
+                )
+                if abort_signal.triggered:
+                    return make_reply(ReturnCode.TASK_ABORTED)
+                self.log_info(fl_ctx, f"training acc ({model_owner}): {train_acc}")
+
+                val_acc, val_kappa = self.local_valid(self.valid_loader, abort_signal)
+
+                # inference on the blinded testing set
+                test_probs = self.local_valid(
+                    self.test_loader, abort_signal, return_probs_only=True
+                )
+                if abort_signal.triggered:
+                    return make_reply(ReturnCode.TASK_ABORTED)
+                self.log_info(fl_ctx, f"validation acc ({model_owner}): {val_acc}")
+
+                self.log_info(fl_ctx, "Evaluation finished. Returning shareable")
+
+                val_results = {
+                    "train_accuracy": train_acc,
+                    "train_kappa": train_kappa,
+                    "val_accuracy": val_acc,
+                    "val_kappa": val_kappa,
+                    "test_probs": test_probs,
+                }
+
+                metric_dxo = DXO(data_kind=DataKind.METRICS, data=val_results)
+                return metric_dxo.to_shareable()
+            except Exception as e:
+                raise ValueError(f"MODEL_VALIDATE failed: {e}")
+        else:
+            return make_reply(ReturnCode.VALIDATE_TYPE_UNKNOWN)
+
+
+# To test your Learner
+
+class MockClientEngine:
+    def __init__(self, run_num=0):
+        self.fl_ctx_mgr = FLContextManager(
+            engine=self,
+            identity_name="site-1",
+            run_num=run_num,
+            public_stickers={},
+            private_stickers={},
+        )
+
+    def new_context(self):
+        return self.fl_ctx_mgr.new_context()
+
+    def fire_event(self, event_type: str, fl_ctx: FLContext):
+        pass
+
+
+if __name__ == "__main__":
+    inside_container = True
+    if inside_container:
+        debug_dataset_root = "/data/preprocessed"
+        debug_datalist_prefix = "/data/dataset_blinded_phase1_"
+    else:
+        # assumes the script is run from the repo root, e.g. using `python3 code/pt/learners/mammo_learner.py`
+        debug_dataset_root = "./data/preprocessed"
+        debug_datalist_prefix = "./data/dataset_blinded_phase1_"
+
+    print("Testing MammoLearner...")
+    learner = MammoLearner(
+        dataset_root=debug_dataset_root,
+        datalist_prefix=debug_datalist_prefix,
+        aggregation_epochs=1,
+        lr=1e-2,
+    )
+    engine = MockClientEngine()
+    fl_ctx = engine.fl_ctx_mgr.new_context()
+    fl_ctx.set_prop(FLContextKey.APP_ROOT, "/tmp/debug")
+
+    print("test initialize...")
+    learner.initialize(parts={}, fl_ctx=fl_ctx)
+
+    print("test train...")
+    learner.local_train(
+        fl_ctx=fl_ctx,
+        train_loader=learner.train_loader,
+        abort_signal=Signal(),
+        val_freq=1,
+    )
+
+    print("test valid...")
+    acc, kappa = learner.local_valid(
+        valid_loader=learner.valid_loader,
+        abort_signal=Signal(),
+        tb_id="val_debug",
+        fl_ctx=fl_ctx,
+    )
+    print("debug acc", acc)
+    print("debug kappa", kappa)
+
+    print("test inference on the testing set...")
+    test_probs = learner.local_valid(
+        valid_loader=learner.test_loader, abort_signal=Signal(), return_probs_only=True
+    )
+    print("test_probs", test_probs)
+
+    print("finished testing.")
+
+    # you can check the result for one epoch and validation on TensorBoard using
+    # `tensorboard --logdir=/tmp/debug`
diff --git a/federated_learning/breast_density_challenge/code/pt/utils/download_datalists_and_predictions.py b/federated_learning/breast_density_challenge/code/pt/utils/download_datalists_and_predictions.py
new file mode 100644
index 0000000000..41721b7a15
--- /dev/null
+++ b/federated_learning/breast_density_challenge/code/pt/utils/download_datalists_and_predictions.py
@@ -0,0 +1,16 @@
+# Copyright 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from monai.apps.utils import download_url, download_and_extract
+
+
+download_and_extract(url="https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/dataset_lists.zip", output_dir="./data")
+download_url(url="https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/predictions.json", filepath="./result_server_example/predictions.json")
diff --git a/federated_learning/breast_density_challenge/code/pt/utils/download_model.py b/federated_learning/breast_density_challenge/code/pt/utils/download_model.py
new file mode 100644
index 0000000000..4ea50e70cf
--- /dev/null
+++ b/federated_learning/breast_density_challenge/code/pt/utils/download_model.py
@@ -0,0 +1,25 @@
+# Copyright 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
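+
+# NOTE: This script is executed at docker-build time (see Dockerfile) so the
+# ImageNet-pretrained ResNet18 weights are cached under TORCH_HOME inside the
+# image; the challenge VMs have no public internet access during training, so
+# all weights must ship with the container.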
+
+import argparse
+
+from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--model_url",
+    type=str,
+    default="https://download.pytorch.org/models/resnet18-f37072fd.pth",
+)
+args = parser.parse_args()
+
+# triggers the download; the weights are cached under TORCH_HOME
+model = load_state_dict_from_url(args.model_url)
diff --git a/federated_learning/breast_density_challenge/code/pt/utils/preprocess_dicom.py b/federated_learning/breast_density_challenge/code/pt/utils/preprocess_dicom.py
new file mode 100644
index 0000000000..752835f58a
--- /dev/null
+++ b/federated_learning/breast_density_challenge/code/pt/utils/preprocess_dicom.py
@@ -0,0 +1,67 @@
+# Copyright 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import cv2
+import numpy as np
+import pydicom
+import skimage.io
+
+
+def dicom_preprocess(dicom_file, save_prefix):
+    try:
+        # Read needed dicom tags
+        ds = pydicom.dcmread(dicom_file)  # , stop_before_pixels=True)
+        try:
+            code = ds.ViewCodeSequence[0].ViewModifierCodeSequence[0].CodeMeaning
+        except BaseException:
+            code = None
+
+        # Filter image
+        dc_tags = f"BS={ds.BitsStored};PI={ds.PhotometricInterpretation};Modality={ds.Modality};PatientOrientation={ds.PatientOrientation};Code={code}"
+        if ds.PatientOrientation == "MLO" or ds.PatientOrientation == "CC":
+            curr_img = ds.pixel_array
+            curr_img = np.squeeze(curr_img).T.astype(np.float32)
+
+            # Can be modified as well to handle other bit and monochrome combinations
+            if (ds.BitsStored == 16) and "2" in ds.PhotometricInterpretation:
+                curr_img = curr_img / 65535.0
+            else:
+                raise ValueError(dicom_file + " - unsupported dicom tags: " + dc_tags)
+
+            # Resize and replicate into 3 channels
+            curr_img = cv2.resize(curr_img, (224, 224))
+            curr_img = np.concatenate(
+                (
+                    curr_img[:, :, np.newaxis],
+                    curr_img[:, :, np.newaxis],
+                    curr_img[:, :, np.newaxis],
+                ),
+                axis=-1,
+            )
+            # Save output file
+            assert curr_img.min() >= 0 and curr_img.max() <= 1.0
+
+            os.makedirs(os.path.dirname(save_prefix), exist_ok=True)
+            np.save(save_prefix + ".npy", curr_img.astype(np.float32))
+            skimage.io.imsave(
+                save_prefix + ".png", (255 * curr_img / curr_img.max()).astype(np.uint8)
+            )
+        else:
+            raise ValueError(
+                "Error: " + dicom_file + " - not a valid image file: " + dc_tags
+            )
+    except BaseException as e:
+        print(f"[WARNING] Reading {dicom_file} failed with Exception: {e}")
+        return False, f"{dicom_file} failed"
+
+    return True, dc_tags
diff --git a/federated_learning/breast_density_challenge/code/pt/utils/preprocess_dicomdir.py b/federated_learning/breast_density_challenge/code/pt/utils/preprocess_dicomdir.py
new file mode 100644
index 0000000000..66c76da592
--- /dev/null
+++ b/federated_learning/breast_density_challenge/code/pt/utils/preprocess_dicomdir.py
@@ -0,0 +1,304 @@
+# Copyright 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import json +import os +import random + +import numpy as np +import pandas as pd +from preprocess_dicom import dicom_preprocess +from sklearn.model_selection import GroupKFold + +# density labels +# 1 - fatty +# 2 - scattered fibroglandular density +# 3 - heterogeneously dense +# 4 - extremely dense + + +def preprocess(dicom_root, out_path, ids, images, densities, process_image=True): + data_list = [] + dc_tags = [] + saved_filenames = [] + assert len(ids) == len(images) == len(densities) + for i, (id, image, density) in enumerate(zip(ids, images, densities)): + if (i + 1) % 200 == 0: + print(f"processing {i+1} of {len(ids)}...") + dir_name = image.split(os.path.sep)[0] + img_file = glob.glob( + os.path.join(dicom_root, dir_name, "**", "*.dcm"), recursive=True + ) + assert len(img_file) == 1, f"No unique dicom image found for {dir_name}!" + save_prefix = os.path.join(out_path, dir_name) + if process_image: + _success, _dc_tags = dicom_preprocess(img_file[0], save_prefix) + else: + if os.path.isfile(save_prefix + ".npy"): + _success = True + else: + _success = False + _dc_tags = [] + if _success and density >= 1: # label can be 0 sometimes, excluding those cases + dc_tags.append(_dc_tags) + data_list.append( + { + "patient_id": id, + "image": dir_name + ".npy", + "label": int(density - 1), + } + ) + saved_filenames.append(dir_name + ".npy") + return data_list, dc_tags, saved_filenames + + +def write_datalist(save_datalist_file, data_set): + os.makedirs(os.path.dirname(save_datalist_file), exist_ok=True) + with open(save_datalist_file, "w") as f: + json.dump(data_set, f, indent=4) + print(f"Data list saved at {save_datalist_file}") + + +def get_indices(all_ids, search_ids): + indices = [] + for _id in search_ids: + _indices = np.where(all_ids == _id) + indices.extend(_indices[0].tolist()) + return indices + + +def main(): + process_image = True # set False if dicoms have already been preprocessed + + out_path = "./data/preprocessed" # YOUR DEST FOLDER SHOULD BE WRITTEN HERE + out_dataset_prefix = "./data/dataset" + + # Input folders + label_root = "/media/hroth/Elements/NVIDIA/Data/CBIS-DDSM/" + dicom_root = "/media/hroth/Elements/NVIDIA/Data/CBIS-DDSM/DICOM/manifest-ZkhPvrLo5216730872708713142/CBIS-DDSM" + n_clients = 3 + + """ Run preprocessing """ + + """ 1. 
Load the label data """ + random.seed(0) + + label_files = [ + os.path.join(label_root, "mass_case_description_train_set.csv"), + os.path.join(label_root, "calc_case_description_train_set.csv"), + os.path.join(label_root, "mass_case_description_test_set.csv"), + os.path.join(label_root, "calc_case_description_test_set.csv"), + ] + + breast_densities = [] + patients_ids = [] + image_file_path = [] + + # read annotations + for label_file in label_files: + print(f"add {label_file}") + label_data = pd.read_csv(label_file) + unique_images, unique_indices = np.unique( + label_data["image file path"], return_index=True + ) + print( + f"including {len(unique_images)} unique images of {len(label_data['image file path'])} image entries" + ) + + try: + breast_densities.extend(label_data["breast_density"][unique_indices]) + except BaseException: + breast_densities.extend(label_data["breast density"][unique_indices]) + patients_ids.extend(label_data["patient_id"][unique_indices]) + image_file_path.extend(label_data["image file path"][unique_indices]) + + assert len(breast_densities) == len(patients_ids) == len(image_file_path), ( + f"Mismatch between label data, breast_densities: " + f"{len(breast_densities)}, patients_ids: {len(patients_ids)}, image_file_path: {len(image_file_path)}" + ) + print(f"Read {len(image_file_path)} data entries.") + + """ 2. Split the data """ + + # shuffle data + label_data = list(zip(breast_densities, patients_ids, image_file_path)) + random.shuffle(label_data) + breast_densities, patients_ids, image_file_path = zip(*label_data) + + # Split data + breast_densities = np.array(breast_densities) + patients_ids = np.array(patients_ids) + image_file_path = np.array(image_file_path) + + unique_patient_ids = np.unique(patients_ids) + n_patients = len(unique_patient_ids) + print(f"Found {n_patients} patients.") + + # generate splits using roughly the same ratios as for challenge data: + n_train_challenge = 60_000 + n_val_challenge = 6_500 + n_test_challenge = 40_000 + test_ratio = n_test_challenge / ( + n_train_challenge + n_val_challenge + n_test_challenge + ) + val_ratio = n_val_challenge / ( + n_val_challenge + n_test_challenge + ) # test cases will be removed at this point + + # use groups to avoid patient overlaps + # test split + n_splits = int(np.ceil(len(image_file_path) / (len(image_file_path) * test_ratio))) + print( + f"Splitting into {n_splits} folds for test split. (Only the first fold is used.)" + ) + group_kfold = GroupKFold(n_splits=n_splits) + for train_val_index, test_index in group_kfold.split( + image_file_path, breast_densities, groups=patients_ids + ): + break # just use first fold + test_images = image_file_path[test_index] + test_patients_ids = patients_ids[test_index] + test_densities = breast_densities[test_index] + + # train/val splits + train_val_images = image_file_path[train_val_index] + train_val_patients_ids = patients_ids[train_val_index] + train_val_densities = breast_densities[train_val_index] + + n_splits = int(np.ceil(len(image_file_path) / (len(image_file_path) * val_ratio))) + print( + f"Splitting into {n_splits} folds for train/val splits. 
(Only the first fold is used.)" + ) + group_kfold = GroupKFold(n_splits=n_splits) + for train_index, val_index in group_kfold.split( + train_val_images, train_val_densities, groups=train_val_patients_ids + ): + break # just use first fold + + train_images = train_val_images[train_index] + train_patients_ids = train_val_patients_ids[train_index] + train_densities = train_val_densities[train_index] + + val_images = train_val_images[val_index] + val_patients_ids = train_val_patients_ids[val_index] + val_densities = train_val_densities[val_index] + + # check that there is no patient overlap + assert ( + len(np.intersect1d(train_patients_ids, val_patients_ids)) == 0 + ), "Overlapping patients in train and validation!" + assert ( + len(np.intersect1d(train_patients_ids, test_patients_ids)) == 0 + ), "Overlapping patients in train and test!" + assert ( + len(np.intersect1d(val_patients_ids, test_patients_ids)) == 0 + ), "Overlapping patients in validation and test!" + + n_total = len(train_images) + len(val_images) + len(test_images) + print(20 * "-") + print(f"Train : {len(train_images)} ({100*len(train_images)/n_total:.2f}%)") + print(f"Val : {len(val_images)} ({100*len(val_images)/n_total:.2f}%)") + print(f"Test : {len(test_images)} ({100*len(test_images)/n_total:.2f}%)") + print(20 * "-") + print(f"Total : {n_total}") + assert n_total == len(image_file_path), ( + f"mismatch between total split images ({n_total})" + f" and length of all images {len(image_file_path)}!" + ) + + """ split train/validation dataset for n_clients """ + # Split and avoid patient overlap + unique_train_patients_ids = np.unique(train_patients_ids) + split_train_patients_ids = np.array_split(unique_train_patients_ids, n_clients) + + unique_val_patients_ids = np.unique(val_patients_ids) + split_val_patients_ids = np.array_split(unique_val_patients_ids, n_clients) + + unique_test_patients_ids = np.unique(test_patients_ids) + split_test_patients_ids = np.array_split(unique_test_patients_ids, n_clients) + + """ 3. 
Preprocess the images """
+    dc_tags = []
+    saved_filenames = []
+    for c in range(n_clients):
+        site_name = f"site-{c+1}"
+        print(f"Preprocessing training set of client {site_name}")
+        _curr_patient_ids = split_train_patients_ids[c]
+        _curr_indices = get_indices(train_patients_ids, _curr_patient_ids)
+        train_list, _dc_tags, _saved_filenames = preprocess(
+            dicom_root,
+            out_path,
+            train_patients_ids[_curr_indices],
+            train_images[_curr_indices],
+            train_densities[_curr_indices],
+            process_image=process_image,
+        )
+        print(
+            f"Converted {len(train_list)} of {len(_curr_indices)} training images"
+        )
+        dc_tags.extend(_dc_tags)
+        saved_filenames.extend(_saved_filenames)
+
+        print(f"Preprocessing validation set of client {site_name}")
+        _curr_patient_ids = split_val_patients_ids[c]
+        _curr_indices = get_indices(val_patients_ids, _curr_patient_ids)
+        val_list, _dc_tags, _saved_filenames = preprocess(
+            dicom_root,
+            out_path,
+            val_patients_ids[_curr_indices],
+            val_images[_curr_indices],
+            val_densities[_curr_indices],
+            process_image=process_image,
+        )
+        print(f"Converted {len(val_list)} of {len(_curr_indices)} validation images")
+        dc_tags.extend(_dc_tags)
+        saved_filenames.extend(_saved_filenames)
+
+        print(f"Preprocessing testing set of client {site_name}")
+        _curr_patient_ids = split_test_patients_ids[c]
+        _curr_indices = get_indices(test_patients_ids, _curr_patient_ids)
+        test_list, _dc_tags, _saved_filenames = preprocess(
+            dicom_root,
+            out_path,
+            test_patients_ids[_curr_indices],
+            test_images[_curr_indices],
+            test_densities[_curr_indices],
+            process_image=process_image,
+        )
+        print(f"Converted {len(test_list)} of {len(_curr_indices)} testing images")
+        dc_tags.extend(_dc_tags)
+        saved_filenames.extend(_saved_filenames)
+
+        data_set = {
+            "train": train_list,  # will stay the same for both phases
+            "test1": val_list,  # like phase 1 leaderboard
+            "test2": test_list,  # like phase 2 - final leaderboard
+        }
+        write_datalist(f"{out_dataset_prefix}_{site_name}.json", data_set)
+
+    print(50 * "=")
+    print(
+        f"Successfully converted a total of {len(saved_filenames)} of {len(image_file_path)} images."
+    )
+
+    # check that there were no duplicated files
+    assert len(saved_filenames) == len(
+        np.unique(saved_filenames)
+    ), f"Not all generated files ({len(saved_filenames)}) are unique ({len(np.unique(saved_filenames))})!"
+
+    print(f"Data lists saved with prefix {out_dataset_prefix}")
+    print(50 * "=")
+    print("Processed unique DICOM tags", np.unique(dc_tags))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/federated_learning/breast_density_challenge/code/pt/utils/strip_testing_labels.py b/federated_learning/breast_density_challenge/code/pt/utils/strip_testing_labels.py
new file mode 100644
index 0000000000..e891ed8d9b
--- /dev/null
+++ b/federated_learning/breast_density_challenge/code/pt/utils/strip_testing_labels.py
@@ -0,0 +1,59 @@
+# Copyright 2022 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
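+
+# NOTE: Strips the ground-truth `label` field from the held-out test splits to create
+# the blinded data lists shipped to the clients; the training split is kept unchanged.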
+
+import json
+import os
+
+
+def strip_and_split(dataset_filename, strip_set):
+    with open(dataset_filename, "r") as f:
+        data = json.load(f)
+
+    # remove labels
+    for x in data[strip_set]:
+        x.pop("label")
+    new_data = {
+        "train": data["train"],  # keep the same train set in both cases
+        "test": data[strip_set],
+    }
+    print(f"removed {len(data[strip_set])} labels from `{strip_set}`")
+    return new_data
+
+
+def main():
+    datalist_rootdir = "../../../data"
+    for client_id in ["site-1", "site-2", "site-3"]:
+        print(f"processing {client_id}")
+        new_datalist1 = strip_and_split(
+            os.path.join(datalist_rootdir, f"./dataset_{client_id}.json"),
+            strip_set="test1",
+        )
+        new_datalist2 = strip_and_split(
+            os.path.join(datalist_rootdir, f"./dataset_{client_id}.json"),
+            strip_set="test2",
+        )
+        with open(
+            os.path.join(
+                datalist_rootdir, f"./dataset_blinded_{client_id}.json"
+            ),
+            "w",
+        ) as f:
+            json.dump(new_datalist1, f, indent=4)
+        with open(
+            os.path.join(
+                datalist_rootdir, f"./dataset_blinded_phase2_{client_id}.json"
+            ),
+            "w",
+        ) as f:
+            json.dump(new_datalist2, f, indent=4)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/federated_learning/breast_density_challenge/code/run_fl.py b/federated_learning/breast_density_challenge/code/run_fl.py
new file mode 100644
index 0000000000..36faa2fd39
--- /dev/null
+++ b/federated_learning/breast_density_challenge/code/run_fl.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2021, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
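+
+# NOTE: Admin-API driver used on the server side: it waits until `min_clients` are
+# connected, deploys and runs the specified app, copies the client logs back next
+# to the admin directory, and finally shuts down all FL components.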
+
+import argparse
+import os
+import time
+
+from nvflare.fuel.hci.client.fl_admin_api_runner import FLAdminAPIRunner, api_command_wrapper, wait_until_clients_greater_than_cb
+from nvflare.fuel.hci.client.fl_admin_api_spec import TargetType
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--run_number", type=int, default=100, help="FL run number to start at.")
+    parser.add_argument("--admin_dir", type=str, default="./admin/", help="Path to admin directory.")
+    parser.add_argument("--username", type=str, default="admin@nvflare.com", help="Admin username")
+    parser.add_argument("--app", type=str, default="cifar10_fedavg", help="App to be deployed")
+    parser.add_argument("--port", type=int, default=8003, help="The admin server port")
+    parser.add_argument("--poc", action='store_true', help="Whether admin uses POC mode.")
+    parser.add_argument("--min_clients", type=int, default=8, help="Minimum number of clients.")
+    args = parser.parse_args()
+
+    host = ""
+    port = args.port
+
+    assert os.path.isdir(args.admin_dir), f"admin directory does not exist at {args.admin_dir}"
+
+    # Set up certificate names and admin folders
+    upload_dir = os.path.join(args.admin_dir, "transfer")
+    if not os.path.isdir(upload_dir):
+        os.makedirs(upload_dir)
+    download_dir = os.path.join(args.admin_dir, "download")
+    if not os.path.isdir(download_dir):
+        os.makedirs(download_dir)
+
+    run_number = args.run_number
+
+    # Initialize the runner
+    runner = FLAdminAPIRunner(
+        host=host,
+        port=port,
+        username=args.username,
+        admin_dir=args.admin_dir,
+        poc=args.poc,
+        debug=False,
+    )
+
+    # Run
+    start = time.time()
+    # Wait for clients to be connected
+    print(f"WAITING FOR {args.min_clients} CLIENTS TO CONNECT...")
+    api_command_wrapper(
+        runner.api.wait_until_server_status(
+            callback=wait_until_clients_greater_than_cb, min_clients=args.min_clients
+        )
+    )
+    print("MAKING SURE CLIENTS ARE READY...")
+    time.sleep(30)  # make sure clients are ready
+
+    # Run Training
+    print("RUN TRAINING...")
+    runner.run(run_number, args.app, restart_all_first=False, shutdown_on_error=False, shutdown_at_end=False,
+               timeout=None, min_clients=args.min_clients)
+    print("Total training time", time.time() - start)
+
+    # Move client logs to server
+    print("GET CLIENT LOGS")
+    for client_id in ["site-1", "site-2", "site-3"]:
+        result = runner.api.cat_target(target=client_id, file="log.txt")
+        if result["status"] == "SUCCESS":
+            if "message" in result["details"]:
+                log = result["details"]["message"]
+                client_log_file = os.path.join(args.admin_dir, "..", f"{client_id}_log.txt")
+                with open(client_log_file, "w") as f:
+                    f.write(log)
+                print(f"Wrote {client_id}'s log to {client_log_file}")
+
+    print("SHUTDOWN ALL...")
+    api_command_wrapper(runner.api.shutdown(TargetType.ALL))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/federated_learning/breast_density_challenge/code/run_fl.sh b/federated_learning/breast_density_challenge/code/run_fl.sh
new file mode 100755
index 0000000000..31333d3b2e
--- /dev/null
+++ b/federated_learning/breast_density_challenge/code/run_fl.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+# add current folder to PYTHONPATH
+export PYTHONPATH="${PYTHONPATH}:${PWD}"
+echo "PYTHONPATH is ${PYTHONPATH}"
+export PYTHONUNBUFFERED=1
+
+algorithms_dir="${PWD}/configs"
+workspace="fl_workspace"
+admin_username="admin@nvflare.com"
+site_pre="site-"
+
+n_clients=$1
+config=$2
+run=$3
+
+if test -z "${n_clients}" || test -z "${config}" || test -z "${run}"
+then
+    echo "Usage: ./run_fl.sh [n_clients] [config] [run], e.g. ./run_fl.sh 3 mammo_fedavg 1"
+    exit 1
+fi
+
+# start training
+echo "STARTING TRAINING"
+python3 ./run_fl.py --port=8003 --admin_dir="./${workspace}/${admin_username}" \
+    --username="${admin_username}" --run_number="${run}" --app="${algorithms_dir}/${config}" --min_clients="${n_clients}"
+
+# sleep so the FL system can shut down and a new run can be started automatically
+sleep 30
+echo "TRAINING ENDED"
diff --git a/federated_learning/breast_density_challenge/code/start_server.sh b/federated_learning/breast_density_challenge/code/start_server.sh
new file mode 100755
index 0000000000..89b9792936
--- /dev/null
+++ b/federated_learning/breast_density_challenge/code/start_server.sh
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+SERVER="server"
+echo "STARTING ${SERVER}"
+./fl_workspace/${SERVER}/startup/start.sh; sleep 30s # TODO: Is there a better way than sleep?
+./run_fl.sh 3 mammo_fedavg 1
diff --git a/federated_learning/breast_density_challenge/code/start_site-1.sh b/federated_learning/breast_density_challenge/code/start_site-1.sh
new file mode 100755
index 0000000000..2f2e06a75b
--- /dev/null
+++ b/federated_learning/breast_density_challenge/code/start_site-1.sh
@@ -0,0 +1,4 @@
+#!/usr/bin/env bash
+CLIENT_NAME="site-1"
+echo "STARTING ${CLIENT_NAME}"
+./fl_workspace/${CLIENT_NAME}/startup/start.sh
diff --git a/federated_learning/breast_density_challenge/code/start_site-2.sh b/federated_learning/breast_density_challenge/code/start_site-2.sh
new file mode 100755
index 0000000000..86ed4fdeb8
--- /dev/null
+++ b/federated_learning/breast_density_challenge/code/start_site-2.sh
@@ -0,0 +1,4 @@
+#!/usr/bin/env bash
+CLIENT_NAME="site-2"
+echo "STARTING ${CLIENT_NAME}"
+./fl_workspace/${CLIENT_NAME}/startup/start.sh
diff --git a/federated_learning/breast_density_challenge/code/start_site-3.sh b/federated_learning/breast_density_challenge/code/start_site-3.sh
new file mode 100755
index 0000000000..9a5a3992c5
--- /dev/null
+++ b/federated_learning/breast_density_challenge/code/start_site-3.sh
@@ -0,0 +1,4 @@
+#!/usr/bin/env bash
+CLIENT_NAME="site-3"
+echo "STARTING ${CLIENT_NAME}"
+./fl_workspace/${CLIENT_NAME}/startup/start.sh
diff --git a/federated_learning/breast_density_challenge/data/README.md b/federated_learning/breast_density_challenge/data/README.md
new file mode 100644
index 0000000000..f06a23217e
--- /dev/null
+++ b/federated_learning/breast_density_challenge/data/README.md
@@ -0,0 +1,19 @@
+## Example breast density data
+
+Download the example data from https://drive.google.com/file/d/1Fd9GLUIzbZrl4FrzI3Huzul__C8wwzyx/view?usp=sharing
+and extract it here.
+
+## Data source
+This example data is based on [CBIS-DDSM](https://wiki.cancerimagingarchive.net/display/Public/CBIS-DDSM) from [TCIA](https://wiki.cancerimagingarchive.net/) [1].
+
+We preprocessed all files using `code/pt/utils/preprocess_dicomdir.py` and generated train/val splits for each client
+and a separate testing split.
+
+For more details on this example data, see [2,3].
+
+## References
+[1] Clark K, Vendt B, Smith K, Freymann J, Kirby J, Koppel P, Moore S, Phillips S, Maffitt D, Pringle M, Tarbox L, Prior F. The Cancer Imaging Archive (TCIA): Maintaining and Operating a Public Information Repository, Journal of Digital Imaging, Volume 26, Number 6, December, 2013, pp 1045-1057. DOI: https://doi.org/10.1007/s10278-013-9622-7
+
+[2] Rebecca Sawyer Lee, Francisco Gimenez, Assaf Hoogi, Daniel Rubin (2016). Curated Breast Imaging Subset of DDSM [Dataset]. The Cancer Imaging Archive. DOI: https://doi.org/10.7937/K9/TCIA.2016.7O02S9CY
+
+[3] Rebecca Sawyer Lee, Francisco Gimenez, Assaf Hoogi, Kanae Kawai Miyake, Mia Gorovoy & Daniel L. Rubin. (2017) A curated mammography data set for use in computer-aided detection and diagnosis research. Scientific Data, Volume 4, Article number: 170177. DOI: https://doi.org/10.1038/sdata.2017.177
diff --git a/federated_learning/breast_density_challenge/figs/example_data_val_global_acc_kappa.png b/federated_learning/breast_density_challenge/figs/example_data_val_global_acc_kappa.png
new file mode 100644
index 0000000000..1a28809cef
Binary files /dev/null and b/federated_learning/breast_density_challenge/figs/example_data_val_global_acc_kappa.png differ
diff --git a/federated_learning/breast_density_challenge/result_server/.gitkeep b/federated_learning/breast_density_challenge/result_server/.gitkeep
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/federated_learning/breast_density_challenge/run_all_fl.sh b/federated_learning/breast_density_challenge/run_all_fl.sh
new file mode 100755
index 0000000000..604d992b53
--- /dev/null
+++ b/federated_learning/breast_density_challenge/run_all_fl.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+
+mkdir -p logs
+./run_docker_server.sh 2>&1 | tee logs/server_log.txt &
+sleep 30s
+./run_docker_site-1.sh 0 2>&1 | tee logs/site-1_log.txt &
+./run_docker_site-2.sh 1 2>&1 | tee logs/site-2_log.txt &
+./run_docker_site-3.sh 0 2>&1 | tee logs/site-3_log.txt
diff --git a/federated_learning/breast_density_challenge/run_docker_debug.sh b/federated_learning/breast_density_challenge/run_docker_debug.sh
new file mode 100755
index 0000000000..bf4d542f8c
--- /dev/null
+++ b/federated_learning/breast_density_challenge/run_docker_debug.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env bash
+DOCKER_IMAGE=monai-nvflare:latest
+
+GPU=$1
+CLIENT_NAME="site-1"
+
+DATA_DIR="${PWD}/data"
+
+# interactive session
+#COMMAND="/bin/bash"
+# test learner
+COMMAND="python3 pt/learners/mammo_learner.py"
+
+echo "Starting $DOCKER_IMAGE with GPU=${GPU}"
+docker run -it \
+--gpus="device=${GPU}" --network=host --ipc=host --rm --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 \
+--name="${CLIENT_NAME}_debug" \
+-e NVIDIA_VISIBLE_DEVICES="${GPU}" \
+-v "${DATA_DIR}":/data:ro \
+-w /code \
+${DOCKER_IMAGE} /bin/bash -c "${COMMAND}"
diff --git a/federated_learning/breast_density_challenge/run_docker_server.sh b/federated_learning/breast_density_challenge/run_docker_server.sh
new file mode 100755
index 0000000000..9cedd95e17
--- /dev/null
+++ b/federated_learning/breast_density_challenge/run_docker_server.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env bash
+DOCKER_IMAGE=monai-nvflare:latest
+
+OUT_DIR="${PWD}/result_server"
+SERVER="server"
+
+GPU=$1
+
+COMMAND="/code/start_server.sh; /code/finalize_server.sh"
+
+echo "Starting $DOCKER_IMAGE with GPU=${GPU}"
+docker run \
+--gpus="device=${GPU}" --network=host --ipc=host --rm --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 \
+--name=${SERVER} \
+-e NVIDIA_VISIBLE_DEVICES="${GPU}" \
+-v "${OUT_DIR}":/result \
+-w /code \
+${DOCKER_IMAGE} /bin/bash -c "${COMMAND}"
+
+# kill client containers
+docker kill site-1 site-2 site-3
diff --git a/federated_learning/breast_density_challenge/run_docker_site-1.sh b/federated_learning/breast_density_challenge/run_docker_site-1.sh
new file mode 100755
index 0000000000..b5afcad11c
--- /dev/null
+++ b/federated_learning/breast_density_challenge/run_docker_site-1.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+DOCKER_IMAGE=monai-nvflare:latest
+
+GPU=$1
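+# Note: the first argument selects the GPU index for this client container;
+# --network=host below assumes server and clients share one machine
+# (the workspace was provisioned with FQDN `localhost`).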
+CLIENT_NAME="site-1" + +DATA_DIR="${PWD}/data" + +COMMAND="/code/start_${CLIENT_NAME}.sh; tail -f /dev/null" + +echo "Starting $DOCKER_IMAGE with GPU=${GPU}" +docker run \ +--gpus="device=${GPU}" --network=host --ipc=host --rm --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 \ +--name="${CLIENT_NAME}" \ +-e NVIDIA_VISIBLE_DEVICES="${GPU}" \ +-v "${DATA_DIR}":/data:ro \ +-w /code \ +${DOCKER_IMAGE} /bin/bash -c "${COMMAND}" diff --git a/federated_learning/breast_density_challenge/run_docker_site-2.sh b/federated_learning/breast_density_challenge/run_docker_site-2.sh new file mode 100755 index 0000000000..7268baf088 --- /dev/null +++ b/federated_learning/breast_density_challenge/run_docker_site-2.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +DOCKER_IMAGE=monai-nvflare:latest + +GPU=$1 +CLIENT_NAME="site-2" + +DATA_DIR="${PWD}/data" + +COMMAND="/code/start_${CLIENT_NAME}.sh; tail -f /dev/null" + +echo "Starting $DOCKER_IMAGE with GPU=${GPU}" +docker run \ +--gpus="device=${GPU}" --network=host --ipc=host --rm --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 \ +--name="${CLIENT_NAME}" \ +-e NVIDIA_VISIBLE_DEVICES="${GPU}" \ +-v "${DATA_DIR}":/data:ro \ +-w /code \ +${DOCKER_IMAGE} /bin/bash -c "${COMMAND}" diff --git a/federated_learning/breast_density_challenge/run_docker_site-3.sh b/federated_learning/breast_density_challenge/run_docker_site-3.sh new file mode 100755 index 0000000000..976724b28f --- /dev/null +++ b/federated_learning/breast_density_challenge/run_docker_site-3.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +DOCKER_IMAGE=monai-nvflare:latest + +GPU=$1 +CLIENT_NAME="site-3" + +DATA_DIR="${PWD}/data" + +COMMAND="/code/start_${CLIENT_NAME}.sh; tail -f /dev/null" + +echo "Starting $DOCKER_IMAGE with GPU=${GPU}" +docker run \ +--gpus="device=${GPU}" --network=host --ipc=host --rm --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 \ +--name="${CLIENT_NAME}" \ +-e NVIDIA_VISIBLE_DEVICES="${GPU}" \ +-v "${DATA_DIR}":/data:ro \ +-w /code \ +${DOCKER_IMAGE} /bin/bash -c "${COMMAND}"