[Research] FedUMM: Federated Learning for Unified Multimodal Models #4158
ZiyueXu77 merged 11 commits into NVIDIA:main from
Conversation
Official code for FedUMM, the first federated learning pipeline using NVFlare to train unified multimodal models.
Greptile Summary

This PR adds the FedUMM research example: a federated learning framework for fine-tuning vision-language models (BLIP-VQA and JanusPro) using LoRA adapters with FedAvg, implemented on top of NVFlare. The PR addresses two prior review comments: the empty-dataloader guard now correctly raises an error when a client shard is empty. Key findings:
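The empty-dataloader guard referenced above could look like the following sketch. The helper name and exception type are assumptions for illustration, not the PR's actual code:

```python
def check_dataloader(loader):
    """Fail fast if a client's data shard produced no samples.

    Hypothetical helper: the PR's actual guard may use a different
    name and exception type. A shard produced by skewed partitioning
    can be empty, and raising here is clearer than a crash deep
    inside the training loop.
    """
    dataset = getattr(loader, "dataset", loader)
    if len(dataset) == 0:
        raise ValueError(
            "Dataloader is empty: this client received no samples. "
            "Check the data partitioning (e.g. the Dirichlet alpha)."
        )
    return loader
```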
Confidence Score: 3/5

Not safe to merge: the JanusPro training path will crash at runtime due to a tensor shape mismatch, and the weight_decay discrepancy prevents reproduction of the paper's results. Two P1 findings remain:

1. A definite runtime crash in `januspro_backend.py` due to a labels vs. inputs_embeds sequence-length mismatch.
2. `weight_decay=0.01` (the PyTorch default) used instead of the documented 0.05 across both training entry points.

The BLIP path works correctly, and previously flagged issues (empty dataloader, wrong metric key) have been addressed.

Important Files Changed
- `research/fedumm/src/januspro_backend.py` (label shape bug; will crash)
- `research/fedumm/src/fl_client.py` and `research/fedumm/src/local_train.py` (weight_decay mismatch)
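A minimal sketch of what fixes for the two findings could look like. The function name, the left-padding choice, and the assumption that image embeddings are prepended to the text sequence are illustrative, not confirmed details of the repo:

```python
import torch

IGNORE_INDEX = -100  # positions with this value are skipped by cross-entropy

def align_labels(labels: torch.Tensor, seq_len: int) -> torch.Tensor:
    """Pad labels with IGNORE_INDEX so their length matches inputs_embeds.

    Hypothetical fix: if image embeddings are prepended to the text
    embeddings, the label tensor built from input_ids alone is shorter
    than inputs_embeds, which crashes the loss computation.
    """
    pad = seq_len - labels.size(1)
    if pad <= 0:
        return labels
    filler = torch.full((labels.size(0), pad), IGNORE_INDEX, dtype=labels.dtype)
    return torch.cat([filler, labels], dim=1)

# For the second finding: AdamW defaults to weight_decay=0.01, so the
# documented value must be passed explicitly, e.g.:
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
```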
Sequence Diagram

```mermaid
sequenceDiagram
    participant job as job.py
    participant server as Server (FedAvg + ModelSelector)
    participant client as FL Client (fl_client.py)
    participant backend as VLM Backend
    job->>server: configure FedAvg and IntimeModelSelector
    job->>client: ScriptRunner with script_args
    client->>client: load and shard dataset (Dirichlet or IID)
    client->>backend: build_model_and_processor with LoRA config
    backend-->>client: model and processor
    loop FL Rounds
        server->>client: FLModel with global LoRA weights
        client->>client: load_trainable_params into model
        alt Evaluate round
            client->>backend: evaluate on eval_loader
            backend-->>client: val_accuracy
            client->>server: FLModel with metrics (val_accuracy)
            server->>server: ModelSelector tracks best checkpoint
        else Train round
            loop local_epochs
                client->>backend: train_one_epoch
                backend-->>client: avg loss
            end
            client->>backend: evaluate on eval_loader
            backend-->>client: local_acc
            client->>server: FLModel with LoRA delta and metrics
            server->>server: FedAvg aggregates LoRA weights
        end
    end
```
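The "load and shard dataset (Dirichlet or IID)" step in the diagram is a standard non-IID partitioning recipe. A self-contained sketch under that assumption (the function name is illustrative, not the repo's actual helper):

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients; smaller alpha means more skew.

    For each class, per-client proportions are drawn from a symmetric
    Dirichlet(alpha) distribution and the class's indices are split
    accordingly, so every sample lands on exactly one client.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    shards = [[] for _ in range(num_clients)]
    for cls in np.unique(labels):
        idx = rng.permutation(np.where(labels == cls)[0])
        props = rng.dirichlet([alpha] * num_clients)
        cuts = (np.cumsum(props) * len(idx)).astype(int)[:-1]
        for shard, part in zip(shards, np.split(idx, cuts)):
            shard.extend(part.tolist())
    return shards
```

With `alpha` large (say 100) the shards approach an IID split; with `alpha` small (say 0.1) each client sees mostly one or two classes.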
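The LoRA exchange in the diagram, where the client ships only trainable adapter weights and loads aggregated ones back, can be sketched generically. This is a hedged sketch: `get_trainable_params` and `load_trainable_params` are assumed names, and filtering by `requires_grad` is a generic approach, not necessarily the repo's exact code:

```python
import torch

def get_trainable_params(model: torch.nn.Module) -> dict:
    """Collect only trainable (e.g. LoRA adapter) tensors for the server."""
    return {
        name: param.detach().cpu().clone()
        for name, param in model.named_parameters()
        if param.requires_grad
    }

def load_trainable_params(model: torch.nn.Module, params: dict) -> None:
    """Load aggregated weights back; strict=False leaves the frozen base intact."""
    model.load_state_dict(params, strict=False)
```

Shipping only the adapter tensors keeps each FL round's payload at a small fraction of the full VLM's size, which is the main reason LoRA pairs well with federated fine-tuning.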
Thanks @rollingsu! Overall looks good and aligns well with the paper. Do we have experiments that have explicitly modeled the "missing modality" (as shown in Fig.3)? One further question: could you point me to the part of the code that handles the "shared alignment token is introduced to stabilize cross-client updates and maintain semantic consistency across modalities."?
Please move it from /examples/advanced/ to /research
|
Thanks for the great contribution! I agree, let's put this under /research to increase its visibility. |
|
@rollingsu please provide a PR description as well. |
- Added a check to raise an error if the dataloader is empty.
- Removed the line that sets padding token IDs to -100 in labels.
- Added aiohttp for timeout configuration in dataset loading.
- Removed the Supported Models section and updated expected output values.
ZiyueXu77 left a comment:
Good enough for now; I will polish it further.
/build