Add MedGemma fine-tuning example with QLoRA #4359
Conversation
Greptile Summary

This PR adds a new federated MedGemma QLoRA fine-tuning example, closely modelled on the Qwen3-VL example and the official Google Health notebook. All six issues raised in the previous review rounds have been addressed: the empty-dataset guard, the zip-slip sanitisation, the broken alt-label f-string, the hardcoded num_train_epochs, the device_map derivation, and the peft version pin.

Confidence Score: 5/5

Safe to merge — all previously raised P0/P1 issues are resolved. The only remaining finding is a P2 best-practice suggestion to explicitly flush the CUDA cache between the two sequential model loads in run_evaluation.py, which affects convenience on constrained hardware but not correctness. Per the confidence guidance, all-P2 findings yield a score of 5.

run_evaluation.py — minor GPU memory hygiene between two sequential model loads.

Important Files Changed
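The zip-slip sanitisation the summary refers to can be sketched as a path check before extraction. This is a minimal illustration, not the actual helper in download_data.py; the name `safe_extract` and its signature are assumptions.

```python
import os
import zipfile


def safe_extract(zip_path: str, dest_dir: str) -> None:
    """Extract a zip archive, rejecting entries that would escape
    dest_dir via ".." components or absolute paths (Zip Slip)."""
    dest_dir = os.path.realpath(dest_dir)
    with zipfile.ZipFile(zip_path) as zf:
        for member in zf.namelist():
            # Resolve where the entry would actually land on disk.
            target = os.path.realpath(os.path.join(dest_dir, member))
            if not target.startswith(dest_dir + os.sep):
                raise ValueError(f"Blocked path traversal attempt: {member!r}")
        zf.extractall(dest_dir)
```

The key point is validating every archive member's resolved path before any extraction happens, so a single malicious entry aborts the whole operation.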
Sequence Diagram

```mermaid
sequenceDiagram
    participant J as job.py (SimEnv)
    participant S as NVFlare Server (FedAvgRecipe)
    participant C as client.py (SFTTrainer)
    participant M as model.py (MedGemmaLoRAModel)
    J->>S: recipe.execute(env)
    S->>M: MedGemmaLoRAModel.state_dict() → initial LoRA adapter weights
    loop Each FL Round
        S->>C: flare.send(FLModel with adapter params)
        C->>C: apply_adapter_state(model, params)
        C->>C: SFTTrainer.train() — QLoRA fine-tuning on local data
        C->>C: get_adapter_state_dict(model) → updated LoRA weights
        C->>S: flare.send(FLModel with updated params + metrics)
        S->>S: FedAvg aggregation of adapter weights
    end
    S->>J: run.get_result() → FL_global_model.pt
    J-->>J: inference / evaluation via run_inference.py / run_evaluation.py
```
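The server-side FedAvg step in the loop above amounts to a sample-weighted average over the adapter-only state dicts the clients return. A minimal sketch, using plain Python lists in place of tensors; `fedavg_adapters` is a hypothetical name for illustration, not NVFlare's actual aggregator.

```python
from typing import Dict, List

# Each client state dict maps an adapter parameter name (e.g. a LoRA
# matrix) to its values, flattened here to a list of floats.
def fedavg_adapters(
    client_states: List[Dict[str, List[float]]],
    num_samples: List[int],
) -> Dict[str, List[float]]:
    """Weighted FedAvg: each client's contribution is proportional
    to the number of local training samples it reported."""
    total = sum(num_samples)
    agg: Dict[str, List[float]] = {}
    for key in client_states[0]:
        length = len(client_states[0][key])
        agg[key] = [
            sum(state[key][i] * n / total
                for state, n in zip(client_states, num_samples))
            for i in range(length)
        ]
    return agg
```

Because only adapter parameters appear in these dicts, the frozen 4-bit base weights never leave the clients, which is what makes the exchange pattern suitable for clinical data.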
Reviews (16): Last reviewed commit: "Merge branch 'main' into codex/medgemma-..."
/build

/build
The adapter-only exchange pattern here is the correct architecture for clinical FL: base model weights stay local, and only the LoRA adapter weights cross the wire. The Zip Slip mitigation in download_data.py is necessary given external dataset sources. Good.

One substantive concern: FedAvg over LoRA adapters is not neutral with respect to rank. When you average low-rank matrices across heterogeneous clients, the effective rank of the aggregated adapter can collapse below the configured lora_r, particularly under the non-IID data distributions common in multi-site clinical imaging. This is a known failure mode (see Cho et al. and the broader heterogeneous LoRA aggregation literature). At minimum, the documentation should note this risk and suggest monitoring validation loss divergence as a diagnostic. Ideally, a future iteration would support a rank-aware aggregation scheme.

Otherwise, well-structured and properly tested.
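A related effect to the rank concern above can be illustrated with a small numpy experiment: averaging the LoRA factors per-parameter (as parameter-wise FedAvg does) caps the aggregate update at rank r, whereas the true average of the per-client updates B_i @ A_i can reach rank 2r, so the two disagree. The dimensions, rank, and seed below are arbitrary choices for the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
r, d = 4, 32  # LoRA rank and hidden dimension (illustrative values)

# Two clients' LoRA factors; each client's weight update is B_i @ A_i.
A1, B1 = rng.normal(size=(r, d)), rng.normal(size=(d, r))
A2, B2 = rng.normal(size=(r, d)), rng.normal(size=(d, r))

# Parameter-wise averaging of the factors (what naive FedAvg over the
# adapter state dict computes): the product is still at most rank r.
naive = ((B1 + B2) / 2) @ ((A1 + A2) / 2)

# Averaging the effective weight updates themselves: generically rank 2r.
ideal = (B1 @ A1 + B2 @ A2) / 2

rank_naive = np.linalg.matrix_rank(naive)  # at most r
rank_ideal = np.linalg.matrix_rank(ideal)  # up to 2r
```

So even before any data heterogeneity, the parameter-wise average is not the average of the clients' updates, which is why the documentation note and validation-loss monitoring suggested above are worthwhile.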
YuanTingHsieh
left a comment
added one comment
/build

/build

/build