Official implementation of FedGAT: Generative Autoregressive Transformers for Model-Agnostic Federated MRI Reconstruction (https://arxiv.org/abs/2502.04521)


FedGAT
Generative Autoregressive Transformers for Model-Agnostic Federated MRI Reconstruction

Valiyeh A. Nezhad (1,2) · Gokberk Elmas (1,2) · Bilal Kabas (1,2) · Fuat Arslan (1,2) · Tolga Çukur (1,2)



(1) UMRAM, (2) Bilkent University


Official PyTorch implementation of FedGAT, a model-agnostic federated learning (FL) technique for MRI reconstruction based on generative autoregressive transformers. Conventional FL requires all sites to share the same model architecture; FedGAT instead lets sites with different reconstruction models collaborate by decentralizing the training of a global generative prior. This prior captures the distribution of multi-site MRI data through autoregressive prediction across spatial scales, guided by a site-specific prompt. Each site then trains its own reconstruction model on a hybrid dataset that combines its local samples with synthetic samples generated by the prior. Comprehensive experiments show that FedGAT outperforms state-of-the-art FL baselines in both within-site and across-site reconstruction while preserving privacy.
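As a rough illustration of the hybrid-dataset idea, the sketch below pools a site's local samples with synthetic ones and tags each item with its origin. The function name `build_hybrid_dataset` and the `synthetic_ratio` parameter are hypothetical; how FedGAT actually sizes or samples the synthetic portion is not specified here.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def build_hybrid_dataset(
    local: Sequence[T],
    synthetic: Sequence[T],
    synthetic_ratio: float = 0.5,
    seed: int = 0,
) -> List[Tuple[str, T]]:
    """Combine local and synthetic samples into one shuffled training list.

    Hypothetical helper: `synthetic_ratio` is the target fraction of the
    final list that comes from synthetic samples (0 <= ratio < 1).
    """
    if not 0.0 <= synthetic_ratio < 1.0:
        raise ValueError("synthetic_ratio must be in [0, 1)")
    # Synthetic count needed so that synthetic / total == synthetic_ratio,
    # capped by how many synthetic samples are available.
    n_syn = round(len(local) * synthetic_ratio / (1.0 - synthetic_ratio))
    n_syn = min(n_syn, len(synthetic))
    rng = random.Random(seed)
    picked = rng.sample(list(synthetic), n_syn)
    # Tag every sample with its source so the mix can be inspected later.
    hybrid = [("local", x) for x in local] + [("synthetic", x) for x in picked]
    rng.shuffle(hybrid)
    return hybrid
```

For example, `build_hybrid_dataset(["img0", "img1"], ["syn0", "syn1", "syn2"], synthetic_ratio=0.5)` yields four tagged items, two from each source. In a PyTorch pipeline the same mixing would more naturally be done over `Dataset` objects rather than Python lists.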


⚙️ Installation

# Clone repo
git clone https://github.com/icon-lab/FedGAT.git
cd FedGAT

# Create and activate conda environment
conda env create -f environment.yml
conda activate fedgat

📚 Data Preparation

Expected dataset structure:

data/
├── Site_0/
│   ├── train/
│   │   └── data/
│   └── val/
│       └── data/
├── Site_1/
│   ├── train/
│   │   └── data/
│   └── val/
│       └── data/
├── Site_2/
│   ├── train/
│   │   └── data/
│   └── val/
│       └── data/

For each site, the data/ subfolders under train/ and val/ hold that site's MRI images (e.g., .png files).
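Before training, it can help to confirm that the data folder matches the layout above. The following is a minimal sketch, assuming images are stored as .png files directly inside each Site_*/train/data/ and Site_*/val/data/ folder; `count_site_images` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path
from typing import Dict, Tuple


def count_site_images(root: str, pattern: str = "*.png") -> Dict[Tuple[str, str], int]:
    """Count images per (site, split) under root/Site_*/{train,val}/data/.

    Hypothetical helper: raises FileNotFoundError if no Site_* folder exists
    or if a site is missing its train/data/ or val/data/ folder.
    """
    sites = sorted(p for p in Path(root).glob("Site_*") if p.is_dir())
    if not sites:
        raise FileNotFoundError(f"no Site_* folders found under {root!r}")
    counts: Dict[Tuple[str, str], int] = {}
    for site in sites:
        for split in ("train", "val"):
            folder = site / split / "data"
            if not folder.is_dir():
                raise FileNotFoundError(f"missing folder: {folder}")
            # Count only regular files that match the image pattern.
            counts[(site.name, split)] = sum(1 for f in folder.glob(pattern) if f.is_file())
    return counts
```

Calling `count_site_images("data")` returns a dictionary such as `{("Site_0", "train"): ..., ("Site_0", "val"): ..., ...}`, which makes empty or missing splits easy to spot.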


🏋️ Training

Basic Training Command

torchrun --nproc_per_node=1 train.py \
    --case='multicoil' \
    --depth=16 \
    --bs=16 \
    --ep=500 \
    --fp16=1 \
    --alng=1e-3 \
    --wpe=0.1 \
    --client_num=3 \
    --comm_round=1
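The --client_num and --comm_round flags set how many sites take part and how many times their updates are merged. To illustrate what a communication round means, here is a generic FedAvg-style sketch in which plain dictionaries stand in for model weights. FedAvg (sample-size-weighted parameter averaging) is a swapped-in, standard aggregation rule used only for illustration; it is not necessarily the exact aggregation FedGAT uses, and `fedavg`, `run_rounds`, and `local_update` are hypothetical names.

```python
from typing import Callable, Dict, List

# Stand-in for model weights; real code would use tensors, e.g. from model.state_dict().
Weights = Dict[str, float]


def fedavg(client_weights: List[Weights], sizes: List[int]) -> Weights:
    """Average client weights, weighting each site by its number of samples."""
    total = sum(sizes)
    keys = client_weights[0].keys()
    return {k: sum(w[k] * n for w, n in zip(client_weights, sizes)) / total for k in keys}


def run_rounds(
    global_w: Weights,
    local_update: Callable[[int, Weights], Weights],
    sizes: List[int],
    comm_round: int,
) -> Weights:
    """Each round, every site trains from a copy of the current global
    weights, then the server replaces them with the weighted average."""
    for _ in range(comm_round):
        updates = [local_update(site, dict(global_w)) for site in range(len(sizes))]
        global_w = fedavg(updates, sizes)
    return global_w
```

In this sketch, `len(sizes)` plays the role of --client_num and `comm_round=1` means the sites' updates are merged exactly once.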

Training Parameters

Parameter      Description                                         Default
--case         Dataset type ('singlecoil' or 'multicoil')          -
--client_num   Number of federated sites                           3
--comm_round   Number of communication rounds                      1
--depth        Model depth                                         16
--bs           Batch size                                          16
--ep           Number of epochs                                    500
--fp16         Mixed-precision training (1 = enabled)              1
--alng         AdaLN gamma                                         1e-3
--wpe          Final learning-rate ratio at the end of training    0.1
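To make the --wpe setting concrete, the sketch below shows one way a final learning-rate ratio can shape a schedule. It assumes a plain linear decay from the peak learning rate to wpe × peak over training with no warmup; the schedule actually implemented in train.py may differ, and `lr_at` is a hypothetical function.

```python
def lr_at(step: int, total_steps: int, peak_lr: float, wpe: float = 0.1) -> float:
    """Learning rate at `step` under an assumed linear decay.

    Starts at peak_lr at step 0 and reaches wpe * peak_lr at the last step
    (step == total_steps - 1). Steps outside the range are clamped.
    """
    if total_steps <= 1:
        return peak_lr * wpe
    progress = min(max(step / (total_steps - 1), 0.0), 1.0)
    return peak_lr * (1.0 - (1.0 - wpe) * progress)
```

With the default wpe=0.1, the learning rate ends training at one tenth of its peak value under this assumed schedule.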

FedGAT will create a fedGAT_output/ directory to store all checkpoints and logs. You can monitor training by:

  • Inspecting fedGAT_output/log.txt and fedGAT_output/stdout.txt

If a run is interrupted, re-run the same training command; FedGAT automatically resumes from the latest fedGAT_output/ckpt*.pth checkpoint (see utils/misc.py, lines 344–357).
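To check which checkpoint a restarted run is likely to pick up, here is a minimal sketch that selects the most recently modified file matching fedGAT_output/ckpt*.pth. Choosing by modification time is an assumption; the actual selection logic lives in utils/misc.py and may differ, and `latest_checkpoint` is a hypothetical helper.

```python
import glob
import os
from typing import Optional


def latest_checkpoint(out_dir: str = "fedGAT_output", pattern: str = "ckpt*.pth") -> Optional[str]:
    """Return the path of the newest checkpoint in out_dir, or None if none exist.

    "Newest" here means the latest file modification time (an assumption).
    """
    candidates = glob.glob(os.path.join(out_dir, pattern))
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)
```

The returned path can then be opened with `torch.load(path, map_location="cpu")` to inspect its contents; the exact keys stored in FedGAT checkpoints are not documented here.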

📖 Citation

You are welcome to use, modify, and distribute this code. We kindly request that you acknowledge this repository and cite our paper appropriately.

@article{nezhad2025generative,
  title={Generative Autoregressive Transformers for Model-Agnostic Federated MRI Reconstruction},
  author={Nezhad, Valiyeh A and Elmas, Gokberk and Kabas, Bilal and Arslan, Fuat and {\c{C}}ukur, Tolga},
  journal={arXiv preprint arXiv:2502.04521},
  year={2025}
}

🙏 Acknowledgments

This repository uses code from the following projects:


Copyright © 2025, ICON Lab.