This repository is the official implementation of the NeurIPS 2022 paper "FedSR: A Simple and Effective Domain Generalization Method for Federated Learning".
Please consider citing our paper:
@inproceedings{
nguyen2022fedsr,
title={FedSR: A Simple and Effective Domain Generalization Method for Federated Learning},
author={A. Tuan Nguyen and Philip Torr and Ser-Nam Lim},
booktitle={Advances in Neural Information Processing Systems},
editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
year={2022},
url={https://openreview.net/forum?id=mrt90D00aQX}
}
The code for the domain generalization (DG) datasets is adapted from DomainBed (https://github.com/facebookresearch/DomainBed).
Requirements: Python 3, PyTorch 1.7.0 or higher, torchvision 0.8.0 or higher.
Currently, the implementation runs as a distributed system with N GPUs, where N equals the number of domains, to mimic a real-world federated system. Therefore, the code cannot run if you have fewer than N GPUs. I will consider adding support for this in the future.
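Because the launcher needs one GPU per domain, a quick pre-flight check can save a failed launch. The sketch below is a hypothetical helper, not part of this repository: it only counts the device IDs listed in CUDA_VISIBLE_DEVICES and does not verify that those GPUs physically exist (with PyTorch installed, torch.cuda.device_count() reports what CUDA actually sees).

```python
import os
from typing import List, Mapping, Optional


def visible_gpu_ids(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the device IDs listed in CUDA_VISIBLE_DEVICES (hypothetical helper)."""
    env = os.environ if env is None else env
    if "CUDA_VISIBLE_DEVICES" not in env:
        # Without querying CUDA we cannot count devices, so require an explicit list.
        raise RuntimeError("Set CUDA_VISIBLE_DEVICES to one GPU per domain, e.g. 0,1,2,3")
    return [d.strip() for d in env["CUDA_VISIBLE_DEVICES"].split(",") if d.strip()]


def check_gpus_for_domains(num_domains: int,
                           env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Fail early if fewer GPUs are listed than there are domains."""
    ids = visible_gpu_ids(env)
    if len(ids) < num_domains:
        raise RuntimeError(
            f"{num_domains} domains need {num_domains} GPUs, "
            f"but only {len(ids)} are listed: {ids}"
        )
    # Return exactly one device ID per domain.
    return ids[:num_domains]


if __name__ == "__main__":
    # PACS has 4 domains, so 4 GPUs must be listed.
    print(check_gpus_for_domains(4, {"CUDA_VISIBLE_DEVICES": "0,1,2,3"}))
```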
For example, to run the PACS experiment with domain 0 as the target domain:
cd src
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --master_port $RANDOM -m \
main --dataset PACS --test_env 0 --method [method] --total_iters 5000 --optim SGD \
--back_bone resnet18 --train_split 0.9 --z_dim 512 --L2R_coeff 0.01 --CMI_coeff 0.001 \
--num_samples 20 --seed [seed] --rounds_per_eval 10 --E 5 --batchsize 64 --lr 0.01 \
--weight_decay 0.0005 --dataset_folder [data_dir] --experiment_path [experiment_path] \
--save_checkpoint True --distributed True --world_size 4
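When sweeping over both methods and the three seeds, it can help to generate the command above programmatically. The following is a sketch with a hypothetical helper (build_command is not part of the repository): it copies the PACS flags verbatim and, in place of the shell's $RANDOM, draws a master port from 1024 to 32767 to avoid privileged ports. It only builds the environment and argument list; to actually launch a run from the repository root you could pass them to subprocess.run(argv, cwd="src", env={**os.environ, **env}).

```python
import random
import shlex
from typing import Dict, List, Sequence, Tuple


def build_command(method: str, seed: int, data_dir: str, experiment_path: str,
                  test_env: int = 0,
                  gpus: Sequence[int] = (0, 1, 2, 3)) -> Tuple[Dict[str, str], List[str]]:
    """Hypothetical helper reproducing the PACS launch command as an argv list."""
    n = len(gpus)  # one process (and GPU) per domain
    port = random.randint(1024, 32767)  # stand-in for $RANDOM, skipping ports < 1024
    env = {"CUDA_VISIBLE_DEVICES": ",".join(str(g) for g in gpus)}
    argv = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={n}", "--master_port", str(port), "-m",
        "main", "--dataset", "PACS", "--test_env", str(test_env),
        "--method", method, "--total_iters", "5000", "--optim", "SGD",
        "--back_bone", "resnet18", "--train_split", "0.9", "--z_dim", "512",
        "--L2R_coeff", "0.01", "--CMI_coeff", "0.001",
        "--num_samples", "20", "--seed", str(seed), "--rounds_per_eval", "10",
        "--E", "5", "--batchsize", "64", "--lr", "0.01",
        "--weight_decay", "0.0005", "--dataset_folder", data_dir,
        "--experiment_path", experiment_path,
        "--save_checkpoint", "True", "--distributed", "True", "--world_size", str(n),
    ]
    return env, argv


if __name__ == "__main__":
    # Print one shell command per (method, seed) pair; the paths are placeholders.
    for method in ("FedSR", "FedL2R"):
        for seed in (0, 1, 2):
            env, argv = build_command(method, seed, "/path/to/data",
                                      f"/path/to/experiments/{method}_seed{seed}")
            print(f"CUDA_VISIBLE_DEVICES={env['CUDA_VISIBLE_DEVICES']}", shlex.join(argv))
```

Keeping the command as a list rather than a single string avoids quoting problems when paths contain spaces; shlex.join (Python 3.8+) is used only for display.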
Where:
- [method] is either FedSR or FedL2R (FedL2R is the variant that uses a deterministic representation and only the L2R regularizer).
- [seed] is the random seed (0, 1, 2).
- [data_dir] is the path to your data directory (/path/to/your/data/directory).
- [experiment_path] is the folder where checkpoints and other outputs are saved (/path/to/experiment/folder).
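To make the L2R regularizer concrete: it penalizes the size of the learned representation z and is weighted by --L2R_coeff. The framework-free sketch below assumes the penalty is the batch mean of the squared L2 norm ||z||_2^2 added to the task loss; l2r_penalty and total_loss are hypothetical names, and the repository's training code is the authority on the exact form (for example, whether the norm is squared). The CMI term used by FedSR (weighted by --CMI_coeff) relies on the probabilistic representation and is not shown.

```python
from typing import Sequence


def l2r_penalty(z_batch: Sequence[Sequence[float]], coeff: float = 0.01) -> float:
    """coeff * batch mean of ||z||_2^2 (assumed form of the L2R regularizer)."""
    if len(z_batch) == 0:
        raise ValueError("z_batch must contain at least one representation")
    squared_norms = [sum(v * v for v in z) for z in z_batch]
    return coeff * sum(squared_norms) / len(squared_norms)


def total_loss(task_loss: float, z_batch: Sequence[Sequence[float]],
               l2r_coeff: float = 0.01) -> float:
    """Task loss (e.g. cross-entropy) plus the weighted L2R term."""
    return task_loss + l2r_penalty(z_batch, l2r_coeff)


if __name__ == "__main__":
    z = [[3.0, 4.0], [0.0, 0.0]]  # squared norms 25 and 0, mean 12.5
    print(l2r_penalty(z, coeff=0.01))  # 0.01 * 12.5
```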
For OfficeHome and DomainNet, change --dataset accordingly, set --back_bone to resnet50 and --z_dim to 2048, and set --nproc_per_node and --world_size to the number of domains (4 for OfficeHome, 6 for DomainNet), listing the same number of GPUs in CUDA_VISIBLE_DEVICES. Also, set --L2R_coeff and --CMI_coeff to the hyperparameters stated in the paper.
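The dataset-dependent flags can be kept in one place. This is a minimal sketch, assuming the --dataset values are the dataset names used in this README; DATASET_FLAGS and dataset_overrides are hypothetical. The PACS coefficients are copied from the example command, while the OfficeHome and DomainNet coefficients are deliberately left as None because they must be taken from the paper, so the function refuses to build those configurations until you supply them.

```python
from typing import Dict, Optional

# Values from this README; None marks a coefficient to take from the paper.
DATASET_FLAGS = {
    "PACS": {"num_domains": 4, "back_bone": "resnet18", "z_dim": 512,
             "L2R_coeff": 0.01, "CMI_coeff": 0.001},
    "OfficeHome": {"num_domains": 4, "back_bone": "resnet50", "z_dim": 2048,
                   "L2R_coeff": None, "CMI_coeff": None},
    "DomainNet": {"num_domains": 6, "back_bone": "resnet50", "z_dim": 2048,
                  "L2R_coeff": None, "CMI_coeff": None},
}


def dataset_overrides(dataset: str,
                      l2r_coeff: Optional[float] = None,
                      cmi_coeff: Optional[float] = None) -> Dict[str, str]:
    """Return the dataset-dependent flags and GPU list (hypothetical helper)."""
    cfg = dict(DATASET_FLAGS[dataset])
    if l2r_coeff is not None:
        cfg["L2R_coeff"] = l2r_coeff
    if cmi_coeff is not None:
        cfg["CMI_coeff"] = cmi_coeff
    missing = [k for k in ("L2R_coeff", "CMI_coeff") if cfg[k] is None]
    if missing:
        raise ValueError(f"{dataset}: pass {missing} using the values reported in the paper")
    n = cfg["num_domains"]
    return {
        "CUDA_VISIBLE_DEVICES": ",".join(str(i) for i in range(n)),  # one GPU per domain
        "--nproc_per_node": str(n),
        "--world_size": str(n),
        "--dataset": dataset,
        "--back_bone": cfg["back_bone"],
        "--z_dim": str(cfg["z_dim"]),
        "--L2R_coeff": str(cfg["L2R_coeff"]),
        "--CMI_coeff": str(cfg["CMI_coeff"]),
    }


if __name__ == "__main__":
    print(dataset_overrides("PACS"))
```

Failing loudly on a missing coefficient is a deliberate choice: silently reusing the PACS values for the larger datasets would produce runs that do not match the paper's setup.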