
Add MedGemma fine-tuning example with QLoRA#4359

Merged
holgerroth merged 31 commits into NVIDIA:main from holgerroth:codex/medgemma-upstream-pr
Apr 7, 2026

Conversation

@holgerroth
Collaborator

Summary

  • add a new advanced MedGemma example modeled on the advanced Qwen3-VL example and adapted from the official MedGemma notebook
  • include data prep, QLoRA fine-tuning, inference, and before/after evaluation utilities plus updated documentation
  • add focused unit coverage for MedGemma data utilities

Testing

  • python3 -m compileall examples/advanced/medgemma tests/unit_test/examples/advanced/medgemma
  • pytest tests/unit_test/examples/advanced/medgemma/data_utils_test.py -q

@holgerroth holgerroth changed the title from "Add advanced MedGemma fine-tuning example" to "Add MedGemma fine-tuning example with LoRA" on Mar 25, 2026
@greptile-apps
Contributor

greptile-apps Bot commented Mar 25, 2026

Greptile Summary

This PR adds a new federated MedGemma QLoRA fine-tuning example, closely modelled on the Qwen3-VL example and the official Google Health notebook. All six issues raised in the previous review rounds have been addressed: the empty-dataset guard, zip-slip sanitisation, the broken ALT_TISSUE_LABELS f-string, the hardcoded num_train_epochs, device_map derivation in inference, and the peft>=0.18.0 version pin.

Confidence Score: 5/5

Safe to merge — all previously raised P0/P1 issues are resolved and the one remaining note is a P2 best-practice suggestion for GPU memory management in the evaluation script.

Six significant issues from prior review rounds are fully addressed (empty-dataset guard, zip-slip sanitisation, broken alt-label f-string, hardcoded num_train_epochs, device_map derivation, peft version pin). The only remaining finding is a P2 suggestion to explicitly flush CUDA cache between the two sequential model loads in run_evaluation.py, which affects convenience on constrained hardware but not correctness. Per the confidence guidance, all P2 findings yield a score of 5.
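The device_map fix called out above can be illustrated with a small sketch. The helper name get_inference_device_map comes from the review summary; its exact signature and behavior in inference_utils.py are assumptions here, not the PR's actual code:

```python
def get_inference_device_map(device: str):
    """Hypothetical re-creation of the fix: derive a transformers-style
    device_map from a --device CLI argument instead of hardcoding it."""
    if device == "auto":
        return "auto"  # let accelerate decide placement
    if device.startswith("cuda"):
        return {"": device}  # pin the whole model to one GPU
    return {"": "cpu"}

print(get_inference_device_map("cuda:1"))  # → {'': 'cuda:1'}
```

Passing the result straight to from_pretrained(..., device_map=...) is the usual pattern this kind of helper supports.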

run_evaluation.py — minor GPU memory hygiene between two sequential model loads.
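The suggested hygiene is the standard drop-and-collect pattern. A dependency-free sketch with a stand-in object (on an actual GPU run you would additionally call torch.cuda.empty_cache() after collection; FakeModel and the variable names are illustrative, not from the PR):

```python
import gc
import weakref

class FakeModel:
    """Stand-in for the loaded base model (avoids a torch dependency here)."""
    pass

base_model = FakeModel()
probe = weakref.ref(base_model)

# Pattern suggested for run_evaluation.py between the two sequential loads:
del base_model              # drop the last reference to the base model
gc.collect()                # reclaim it, including any reference cycles
# torch.cuda.empty_cache() # on CUDA, also return cached blocks to the driver

assert probe() is None      # the base model is truly gone before load #2
```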

Important Files Changed

| Filename | Overview |
| --- | --- |
| examples/advanced/medgemma/client.py | Federated training client with QLoRA; empty-dataset guard, per-round adapter exchange, and memory cleanup all look correct. |
| examples/advanced/medgemma/data_utils.py | Data helpers, including the corrected _to_alt_tissue_label (replacing the broken f-string), label parsing, and client-shard splitting, look correct. |
| examples/advanced/medgemma/inference_utils.py | Device map now correctly derived from the --device argument via get_inference_device_map(); model-loading branches look clean. |
| examples/advanced/medgemma/job.py | --num_train_epochs now wired through _build_train_args, fixing the previously hardcoded value of 1; GPU and timeout configuration looks correct. |
| examples/advanced/medgemma/model.py | LoRA config, adapter-state helpers, and MedGemmaLoRAModel look correct; ensure_weight_tying=True is now guarded by peft>=0.18.0 in requirements. |
| examples/advanced/medgemma/download_data.py | Zip Slip guard via resolve_extraction_path() correctly validates each member before extraction; partial-download cleanup is solid. |
| examples/advanced/medgemma/run_evaluation.py | Sequential two-model evaluation uses PIL context managers correctly; the base model is not explicitly freed before loading the fine-tuned model, risking OOM on constrained GPUs. |
| examples/advanced/medgemma/run_inference.py | Now uses a with Image.open(...) as image_file: context manager, addressing the previously noted file-handle leak. |
| examples/advanced/medgemma/prepare_data.py | Data preparation script correctly delegates to collect_image_records and split_records_for_clients; JSON output looks clean. |
| examples/advanced/medgemma/requirements.txt | All version pins now present, including peft>=0.18.0 and transformers>=4.57.1; looks complete for the example's dependencies. |
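The Zip Slip guard noted for download_data.py follows a standard pattern: resolve each archive member against the extraction root and reject anything that escapes it. A minimal sketch (the name resolve_extraction_path is taken from the table above, but this body is an assumption, not the PR's code):

```python
import os

def resolve_extraction_path(base_dir: str, member_name: str) -> str:
    """Reject any archive member whose resolved path escapes base_dir
    (the classic Zip Slip attack via names like '../../etc/passwd')."""
    base = os.path.realpath(base_dir)
    target = os.path.realpath(os.path.join(base, member_name))
    if os.path.commonpath([base, target]) != base:
        raise ValueError(f"unsafe archive member: {member_name!r}")
    return target
```

Calling this for every member before zipfile extraction blocks traversal names while leaving ordinary relative paths untouched.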

Sequence Diagram

sequenceDiagram
    participant J as job.py (SimEnv)
    participant S as NVFlare Server (FedAvgRecipe)
    participant C as client.py (SFTTrainer)
    participant M as model.py (MedGemmaLoRAModel)

    J->>S: recipe.execute(env)
    S->>M: MedGemmaLoRAModel.state_dict() → initial LoRA adapter weights
    loop Each FL Round
        S->>C: flare.send(FLModel with adapter params)
        C->>C: apply_adapter_state(model, params)
        C->>C: SFTTrainer.train() — QLoRA fine-tuning on local data
        C->>C: get_adapter_state_dict(model) → updated LoRA weights
        C->>S: flare.send(FLModel with updated params + metrics)
        S->>S: FedAvg aggregation of adapter weights
    end
    S->>J: run.get_result() → FL_global_model.pt
    J-->>J: inference / evaluation via run_inference.py / run_evaluation.py
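The adapter-only exchange in the loop above boils down to filtering the model's state dict to LoRA tensors before sending, and writing the aggregate back on receipt. An illustrative sketch on plain dicts (the real helpers in model.py operate on a peft-wrapped torch model, so the shapes here are assumptions):

```python
def get_adapter_state_dict(state_dict: dict) -> dict:
    """Keep only LoRA adapter entries; base-model weights never leave
    the client (illustrative key filter, not the PR's implementation)."""
    return {k: v for k, v in state_dict.items() if "lora_" in k}

def apply_adapter_state(state_dict: dict, adapter: dict) -> None:
    """Write the server's aggregated adapter values back in place."""
    state_dict.update(adapter)

full = {
    "model.layers.0.q_proj.weight": "base",
    "model.layers.0.q_proj.lora_A.weight": "adapter-A",
    "model.layers.0.q_proj.lora_B.weight": "adapter-B",
}
print(sorted(get_adapter_state_dict(full)))
# → ['model.layers.0.q_proj.lora_A.weight', 'model.layers.0.q_proj.lora_B.weight']
```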

Reviews (16): Last reviewed commit: "Merge branch 'main' into codex/medgemma-..."

Comment thread examples/advanced/medgemma/client.py
Comment thread examples/advanced/medgemma/run_inference.py Outdated
Comment thread examples/advanced/medgemma/download_data.py
Comment thread examples/advanced/medgemma/job.py Outdated
@holgerroth holgerroth changed the title from "Add MedGemma fine-tuning example with LoRA" to "Add MedGemma fine-tuning example with QLoRA" on Mar 25, 2026
Comment thread examples/advanced/medgemma/inference_utils.py
Comment thread examples/advanced/medgemma/data_utils.py
Comment thread examples/advanced/medgemma/model.py
@holgerroth
Collaborator Author

/build

@holgerroth
Collaborator Author

/build

@vijaygovindaraja
Contributor

The adapter-only exchange pattern here is the correct architecture for clinical FL: base model weights stay local, which is critical for HIPAA-compliant deployments where institutional data governance prevents any model artifact that touched patient data from leaving the site. This should be stated explicitly in the README rather than left implicit.

The Zip Slip mitigation in download_data.py is necessary given external dataset sources. Good.

One substantive concern: FedAvg over LoRA adapters is not neutral with respect to rank. When you average low-rank matrices across heterogeneous clients, the effective rank of the aggregated adapter can collapse below the configured lora_r, particularly under non-IID data distributions common in multi-site clinical imaging. This is a known failure mode (see Cho et al., "Heterogeneous LoRA Aggregation" literature). At minimum, the documentation should note this risk and suggest monitoring validation loss divergence as a diagnostic. Ideally, a future iteration would support rank-adaptive aggregation strategies.

Otherwise, well-structured and properly tested.
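The rank concern can be made concrete with a toy numpy example: parameter-wise FedAvg averages the A and B factors separately, producing an update of rank at most r, while the true average of the clients' updates B_i A_i generically has rank up to 2r, so the aggregate cannot represent it exactly. Dimensions below are illustrative, not the example's defaults:

```python
import numpy as np

rng = np.random.default_rng(0)
d, r = 8, 2  # toy hidden size and LoRA rank

# Two clients' LoRA factors after local training
B1, A1 = rng.standard_normal((d, r)), rng.standard_normal((r, d))
B2, A2 = rng.standard_normal((d, r)), rng.standard_normal((r, d))

true_avg = (B1 @ A1 + B2 @ A2) / 2          # average of the actual updates
naive = ((B1 + B2) / 2) @ ((A1 + A2) / 2)   # FedAvg applied to the factors

print(np.linalg.matrix_rank(true_avg))  # 4: ranks add, up to 2r
print(np.linalg.matrix_rank(naive))     # 2: capped at r, cross terms lost
print(np.allclose(true_avg, naive))     # False
```

The discrepancy grows with the number of clients and with how dissimilar their local updates are, which is why non-IID clinical data makes it worse.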

Collaborator

@YuanTingHsieh YuanTingHsieh left a comment

added one comment

Comment thread examples/advanced/medgemma/install_requirements.sh Outdated
@holgerroth
Collaborator Author

/build

@holgerroth
Collaborator Author

/build

@holgerroth
Collaborator Author

/build

@holgerroth holgerroth enabled auto-merge (squash) April 7, 2026 21:35
@holgerroth holgerroth merged commit 80a2abc into NVIDIA:main Apr 7, 2026
29 checks passed
@holgerroth holgerroth deleted the codex/medgemma-upstream-pr branch April 7, 2026 22:01