Add MR-MTL Method #125
Conversation
@@ -1,4 +1,4 @@
-# FedProx Federated Learning Example
+# Ditto Federated Learning Example
I think this was left over from a copy-paste, so I fixed it.
 # Shutdown the client gracefully
 client.shutdown()
+client.metrics_reporter.dump()
@lotif: Do you think we should just add this client.metrics_reporter.dump() call to shutdown so that it always happens?
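For illustration, a minimal sketch of what folding the dump into shutdown could look like, assuming the shutdown method and metrics_reporter attribute shown in the diff above; the class name and constructor here are hypothetical stand-ins, not the repo's actual client class:

```python
class Client:  # hypothetical stand-in; the real client class and reporter API live in this repo
    def __init__(self, metrics_reporter=None):
        self.metrics_reporter = metrics_reporter

    def shutdown(self) -> None:
        # Persist collected metrics as part of a graceful shutdown so callers
        # do not need to remember to call metrics_reporter.dump() separately.
        if self.metrics_reporter is not None:
            self.metrics_reporter.dump()
        # ... any existing shutdown logic (closing resources, etc.) would follow here
```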
 device: torch.device,
 loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
 checkpointer: Optional[TorchCheckpointer] = None,
+lam: float = 1.0,
This isn't part of the original implementation (and shouldn't be in this PR), but it would be interesting to see what effect adapting this value, either via the adaptive FedProx approach or the generalization gap of FedDG-GA, would have on the method. A similar idea applies to the Ditto parameter.
I am confused by this comment; which part do you mean exactly?
Apologies, this was just me recording a thought about something to investigate in the future. That is, would MR-MTL or Ditto benefit from an adaptive implementation of the FedProx-like parameter?
TL;DR: No need to do anything. I was just thinking about potential future experimentation.
Overall, I think the implementation looks good. Just a few small comments before we can merge.
emersodb left a comment
Good to go for me.
PR Type
[Feature]
Short Description
Clickup Ticket(s): Link
This is the implementation of MR-MTL: On Privacy and Personalization in Cross-Silo Federated Learning.
The method is closely related to FedProx and Ditto. At the start of each client training round, we do not overwrite the local model weights with the global model weights; instead, we save the initial global weights, computed by averaging the client model weights at the end of the previous round, and constrain the local weights to stay close to them during training. This mean-regularized training is implemented by adding a penalty term to the loss function that penalizes the distance between the local model weights and those initial global weights.
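For concreteness, a minimal sketch of the mean-regularized penalty described above; the helper name and the exact weighting convention (lam / 2 times the squared l2 distance) are illustrative assumptions, not necessarily the precise form used in this PR:

```python
import torch
import torch.nn as nn


def mr_penalty(local_model: nn.Module, initial_global_weights: list[torch.Tensor], lam: float) -> torch.Tensor:
    # Squared l2 distance between the current local weights and the frozen
    # initial global weights saved at the start of the round.
    distance = sum(
        torch.sum((local_param - global_param.detach()) ** 2)
        for local_param, global_param in zip(local_model.parameters(), initial_global_weights)
    )
    return (lam / 2.0) * distance


# Hypothetical usage inside a training step:
# total_loss = criterion(model(x), y) + mr_penalty(model, initial_global_weights, lam=1.0)
```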
Tests Added
Three tests are added to tests/clients/test_mr_mtl_client.py covering: setting the global weights (we do not set them on the local model, but only save them for computing the MR loss), forming the MR loss, and computing the total loss.
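As a rough illustration of the kind of check the MR-loss test performs, here is a sketch that reuses the hypothetical mr_penalty helper from the sketch above; the real tests exercise the client class itself, so the construction details here are assumptions:

```python
import torch
import torch.nn as nn


def test_mr_penalty_matches_manual_computation() -> None:
    # Tiny linear model with known weights so the penalty can be computed by hand.
    model = nn.Linear(2, 1, bias=False)
    with torch.no_grad():
        model.weight.copy_(torch.tensor([[1.0, 2.0]]))
    initial_global_weights = [torch.tensor([[0.0, 0.0]])]

    # With lam = 2.0, penalty = (2 / 2) * (1^2 + 2^2) = 5.0
    penalty = mr_penalty(model, initial_global_weights, lam=2.0)
    assert torch.isclose(penalty, torch.tensor(5.0))
```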