
Add heterogeneous-rank HLoRA enhancement to MedGemma example #4424

Merged
holgerroth merged 19 commits into NVIDIA:main from holgerroth:codex/medgemma-hlora
Apr 10, 2026

Add heterogeneous-rank HLoRA enhancement to MedGemma example#4424
holgerroth merged 19 commits intoNVIDIA:mainfrom
holgerroth:codex/medgemma-hlora

Conversation

@holgerroth (Collaborator) commented Apr 10, 2026

Summary

This draft PR extends the advanced MedGemma example with a heterogeneous-rank HLoRA workflow on top of the merged federated fine-tuning example.

Changes introduced:

  • add fixed-global-rank custom LoRA aggregators for both the paper baseline (naive) and HLoRA (hlora)
  • keep the stock FedAvg controller path while making the client rank-aware: each site truncates the server LoRA bank to its local rank before local training
  • support heterogeneous local ranks from job.py, with distinct ranks used by default for the 3-client example
  • update data preparation to default to a more heterogeneous label-skewed client split and emit a client label-distribution SVG for visualization
  • add per-site runtime and peak CUDA memory logging in the client
  • add --finetune_only to run_evaluation.py for faster repeated checkpoint comparisons
  • update the MedGemma README with the heterogeneous-rank workflow and observed comparison results
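
The rank-aware client step above (truncating the server LoRA bank to the local rank) can be sketched as follows. This is an illustrative sketch, not the PR's actual code: it assumes PEFT's tensor layout (`lora_A` is `(rank, in_features)`, `lora_B` is `(out_features, rank)`), and `truncate_lora_bank` is a hypothetical stand-in for the PR's `truncate_global_bank_for_site`.

```python
# Illustrative sketch of truncating a global LoRA bank down to a site's
# local rank before training. Assumes PEFT's layout:
#   lora_A: (rank, in_features), lora_B: (out_features, rank).
# truncate_lora_bank is hypothetical; the PR's helper is
# truncate_global_bank_for_site and may differ in detail.
import numpy as np

def truncate_lora_bank(state: dict, r_local: int) -> dict:
    out = {}
    for key, tensor in state.items():
        if "lora_A" in key:
            out[key] = tensor[:r_local, :]   # keep the first r_local rows
        elif "lora_B" in key:
            out[key] = tensor[:, :r_local]   # keep the first r_local columns
        else:
            out[key] = tensor                # non-LoRA tensors pass through
    return out

bank = {
    "layer.lora_A.weight": np.zeros((16, 64)),   # global rank 16
    "layer.lora_B.weight": np.zeros((128, 16)),
}
site_bank = truncate_lora_bank(bank, r_local=4)  # e.g. the rank-4 site
```

Slicing leading rows of `lora_A` and leading columns of `lora_B` keeps the factor pair shape-consistent, so the truncated adapter drops straight into a lower-rank PEFT model.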

Observed comparison

Using the heterogeneous-rank 3-client layout (ranks 4, 8, 16) on CRC-VAL-HE-7K:

| Run | Naive | HLoRA | Delta |
| --- | --- | --- | --- |
| default seed | 0.8955 (6430/7180) | 0.9414 (6759/7180) | +0.0458 |
| alternate seed | 0.8961 (6434/7180) | 0.9366 (6725/7180) | +0.0405 |

Notes

  • This PR intentionally keeps the standard controller path and introduces the HLoRA logic through custom aggregators plus a rank-aware client.
  • The custom naive aggregator remains necessary in the heterogeneous-rank setting because plain built-in FedAvg tensor averaging is not shape-safe once sites use different LoRA ranks.
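
To see why plain tensor averaging breaks and what a fixed-global-rank "naive" aggregator has to do instead, here is a hedged NumPy sketch. Zero-padding each site's factors up to the global rank is one common way to make weighted factor averaging shape-safe; whether the PR's NaiveMaxRankAggregator uses exactly this scheme is not shown in this thread, and `pad_to_rank`/`naive_aggregate` are illustrative names.

```python
# Hedged sketch: weighted LoRA factor averaging at a fixed global rank.
# Zero-padding each site's (A, B) up to r_global makes averaging well-defined
# across heterogeneous ranks; the PR's NaiveMaxRankAggregator may differ.
import numpy as np

def pad_to_rank(A, B, r_global):
    """Zero-pad lora_A (rank, in) and lora_B (out, rank) up to r_global."""
    r = A.shape[0]
    A_pad = np.zeros((r_global, A.shape[1]))
    B_pad = np.zeros((B.shape[0], r_global))
    A_pad[:r] = A
    B_pad[:, :r] = B
    return A_pad, B_pad

def naive_aggregate(factors, weights, r_global):
    """factors: list of (A_i, B_i); weights: e.g. num_examples per site."""
    w = np.asarray(weights, dtype=float)
    if w.sum() == 0:
        raise RuntimeError("all aggregation weights are zero")
    w = w / w.sum()
    A_acc = np.zeros((r_global, factors[0][0].shape[1]))
    B_acc = np.zeros((factors[0][1].shape[0], r_global))
    for (A, B), wi in zip(factors, w):
        A_pad, B_pad = pad_to_rank(A, B, r_global)
        A_acc += wi * A_pad
        B_acc += wi * B_pad
    return A_acc, B_acc

# sites with ranks 4 and 8, aggregated into a rank-16 global bank
A_g, B_g = naive_aggregate(
    [(np.ones((4, 32)), np.ones((64, 4))), (np.ones((8, 32)), np.ones((64, 8)))],
    weights=[100, 300],
    r_global=16,
)
```

Without the padding step, `0.25 * A_4 + 0.75 * A_8` would fail outright on the shape mismatch between `(4, 32)` and `(8, 32)`, which is the "not shape-safe" point above.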

@holgerroth (Collaborator, Author) commented:

@greptileai review this PR

@greptile-apps (Bot, Contributor) commented Apr 10, 2026

Greptile Summary

This PR extends the MedGemma federated fine-tuning example with heterogeneous-rank HLoRA: it adds NaiveMaxRankAggregator and HLoRAMaxRankAggregator custom aggregators, a per-client LoRA rank truncation step on the client, heterogeneous data-splitting in prepare_data.py, and batched inference in the evaluator. The math in _factorize_hlora_update (compact QR+SVD instead of materialising the full dense update) is correct, and the truncate_global_bank_for_site slicing of A/B factors is shape-safe for heterogeneous ranks.
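
The compact QR+SVD route mentioned above can be illustrated with a from-scratch NumPy sketch. This is the general technique, not the PR's `_factorize_hlora_update`: `hlora_aggregate` is a hypothetical name, and for simplicity the sketch omits padding the result back up to the fixed global rank when the combined client rank is smaller.

```python
# Hedged sketch of heterogeneous-rank LoRA aggregation via compact QR+SVD,
# i.e. projecting sum_i w_i * (B_i @ A_i) onto a bounded rank without
# materialising the dense (out x in) update.
import numpy as np

def hlora_aggregate(factors, weights, r_global):
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()
    # Block concatenation: B_cat @ A_cat == sum_i w_i * (B_i @ A_i)
    B_cat = np.concatenate([wi * B for (A, B), wi in zip(factors, w)], axis=1)
    A_cat = np.concatenate([A for (A, B) in factors], axis=0)
    Q, R = np.linalg.qr(B_cat)                 # reduced QR, Q: (out, sum_r)
    U, S, Vt = np.linalg.svd(R @ A_cat, full_matrices=False)
    r = min(r_global, len(S))                  # truncation rank
    sqrt_s = np.sqrt(S[:r])
    B_new = (Q @ U[:, :r]) * sqrt_s            # (out, r); singular values split
    A_new = sqrt_s[:, None] * Vt[:r]           # (r, in);  evenly across A and B
    return A_new, B_new

rng = np.random.default_rng(0)
facs = [
    (rng.normal(size=(2, 32)), rng.normal(size=(12, 2))),  # site with rank 2
    (rng.normal(size=(3, 32)), rng.normal(size=(12, 3))),  # site with rank 3
]
A_new, B_new = hlora_aggregate(facs, weights=[1.0, 3.0], r_global=16)
```

When the summed client rank exceeds `r_global`, the truncated SVD yields the best rank-`r_global` approximation of the weighted update (Eckart-Young); when it is smaller, as here, the reconstruction is exact.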

Confidence Score: 5/5

Safe to merge — all remaining findings are P2 style/cleanup suggestions with no correctness or data-integrity impact.

The HLoRA math (compact QR+SVD factorization), the rank-truncation slicing, and the aggregator's zero-weight guard are all correct. Prior P0/P1 concerns (missing os import, silent empty-model broadcast, unreachable RuntimeError) have been addressed in earlier commits. The three remaining comments are minor inconsistencies that do not affect runtime behavior.

No files require special attention; data_utils.py and run_evaluation.py have minor P2 style notes.

Important Files Changed

| Filename | Overview |
| --- | --- |
| examples/advanced/medgemma/custom_aggregators.py | New file — NaiveMaxRankAggregator and HLoRAMaxRankAggregator implement fixed-global-rank LoRA aggregation; math is correct, zero-weight guard raises RuntimeError as required. |
| examples/advanced/medgemma/lora_utils.py | New file — LoRA key detection, base-key extraction, uniform rank-map builder, and global-bank truncation; slicing is correct for both lora_A (row-truncation) and lora_B (column-truncation). |
| examples/advanced/medgemma/client.py | Adds --lora_rank arg, rank-aware global-bank truncation before training, detailed CUDA timing/memory logging, and explicit ParamsType.FULL on send; logic is sound. |
| examples/advanced/medgemma/job.py | Adds --lora_aggregation, --global_lora_rank, --site_lora_ranks args; _build_default_site_lora_ranks produces geometrically spaced ranks with monotonicity enforcement; aggregator injection is correct. |
| examples/advanced/medgemma/data_utils.py | Adds heterogeneous non-IID split strategy with label-preference weighted assignment; shard-size vs validation-size check is placed after the assignment loop — a misplaced guard discussed in a comment below. |
| examples/advanced/medgemma/inference_utils.py | Refactored to support batched generation; left-padding is correctly set, prompt_length stripping is safe, use_cache enabled for inference. |
| examples/advanced/medgemma/model.py | Adds configurable lora_rank to create_lora_config/create_peft_medgemma_model/MedGemmaLoRAModel; infer_uniform_lora_rank_from_state_dict validates that the state dict has a single consistent rank. |
| examples/advanced/medgemma/utils.py | New file — consolidates shared helpers (abs_path, free_memory, CUDA stats) previously duplicated across client.py and run_evaluation.py. |
| examples/advanced/medgemma/run_evaluation.py | Adds --batch_size and --finetune_only; still defines a local _abs_path function that is now redundant with utils.abs_path. |
| examples/advanced/medgemma/prepare_data.py | Adds --split_strategy, --dominant_fraction, --plot_path args; SVG generation and summary printing work correctly for both split strategies. |
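
The geometric rank spacing described for job.py can be sketched as below. `default_site_lora_ranks` is a hypothetical stand-in for `_build_default_site_lora_ranks`, whose exact spacing rule is not shown in this thread; this version simply halves the rank per step down from the global rank and enforces monotonicity.

```python
# Hypothetical sketch of geometrically spaced per-site default LoRA ranks
# (halving per step down from the global rank); the actual
# _build_default_site_lora_ranks may use a different spacing rule.
def default_site_lora_ranks(n_sites: int, global_rank: int) -> list:
    ranks = [max(1, round(global_rank / 2 ** (n_sites - 1 - i))) for i in range(n_sites)]
    # enforce non-decreasing ranks, capped at the global rank
    for i in range(1, n_sites):
        ranks[i] = min(global_rank, max(ranks[i], ranks[i - 1]))
    return ranks
```

Under these assumptions the example's 3-client layout falls out directly: `default_site_lora_ranks(3, 16)` gives `[4, 8, 16]`.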

Sequence Diagram

```mermaid
sequenceDiagram
    participant S as Server (FedAvgRecipe)
    participant A as NaiveMaxRankAggregator / HLoRAMaxRankAggregator
    participant C1 as site-1 (rank=4)
    participant C2 as site-2 (rank=8)
    participant C3 as site-3 (rank=16)

    S->>C1: Global LoRA bank (rank=16)
    S->>C2: Global LoRA bank (rank=16)
    S->>C3: Global LoRA bank (rank=16)

    Note over C1: truncate_global_bank_for_site → rank=4
    Note over C2: truncate_global_bank_for_site → rank=8
    Note over C3: no truncation needed

    C1->>C1: SFTTrainer (local rank=4)
    C2->>C2: SFTTrainer (local rank=8)
    C3->>C3: SFTTrainer (local rank=16)

    C1->>A: FLModel FULL (rank-4 A/B tensors, num_examples weight)
    C2->>A: FLModel FULL (rank-8 A/B tensors, num_examples weight)
    C3->>A: FLModel FULL (rank-16 A/B tensors, num_examples weight)

    Note over A: Naive: weighted factor avg into rank-16 bank
    Note over A: HLoRA: cat → QR → SVD → project to rank-16

    A->>S: Aggregated FLModel (rank=16)
    S->>C1: Next round global bank (rank=16)
    S->>C2: Next round global bank (rank=16)
    S->>C3: Next round global bank (rank=16)
```

Reviews (4). Last reviewed commit: "Merge branch 'main' into codex/medgemma-..."

  • Comment thread: examples/advanced/medgemma/client.py
  • Comment thread: examples/advanced/medgemma/custom_aggregators.py
  • Comment thread: examples/advanced/medgemma/data_utils.py
@holgerroth (Collaborator, Author) commented:

@greptileai, review again to see if the issues were addressed.

@holgerroth (Collaborator, Author) commented:

@greptileai, review the latest version.

@holgerroth holgerroth marked this pull request as ready for review April 10, 2026 14:01
@holgerroth (Collaborator, Author) commented:

/build

@pcnudde (Collaborator) left a comment:

Nice results

@holgerroth holgerroth merged commit cc69b06 into NVIDIA:main Apr 10, 2026
53 of 54 checks passed
@holgerroth holgerroth deleted the codex/medgemma-hlora branch April 10, 2026 18:18
