Moon #66
fl4health/clients/moon_client.py
Outdated
| checkpointer=checkpointer, | ||
| ) | ||
| self.initial_tensors: List[torch.Tensor] | ||
| self.contrastive_weight: float = 10 |
We will probably want the ability to configure the hyperparameters (temperature, contrastive weight and len_old_models_buffer) to take advantage of your otherwise general code. We can keep the defaults that you specify, but check whether the config sent by the server contains these hyperparameters and, if it does, set them accordingly before the FL round begins.
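A minimal sketch of what this could look like, assuming the server config is a plain dictionary. The key names and the default buffer length are hypothetical; only the defaults of 10 and 0.5 come from the diff above.

```python
from typing import Any, Dict, Tuple

# Defaults for contrastive weight and temperature mirror the values in the diff.
DEFAULT_CONTRASTIVE_WEIGHT = 10.0
DEFAULT_TEMPERATURE = 0.5
# Assumed default: the diff does not show a value for the buffer length.
DEFAULT_LEN_OLD_MODELS_BUFFER = 1


def read_moon_hyperparameters(config: Dict[str, Any]) -> Tuple[float, float, int]:
    """Return (temperature, contrastive_weight, len_old_models_buffer).

    Values present in the server config override the client defaults.
    The key names used here are hypothetical.
    """
    temperature = float(config.get("temperature", DEFAULT_TEMPERATURE))
    contrastive_weight = float(config.get("contrastive_weight", DEFAULT_CONTRASTIVE_WEIGHT))
    len_old_models_buffer = int(config.get("len_old_models_buffer", DEFAULT_LEN_OLD_MODELS_BUFFER))
    return temperature, contrastive_weight, len_old_models_buffer
```

The client could call something like this at the start of each round and assign the results to its attributes before training begins.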
fl4health/clients/moon_client.py
Outdated
| ) | ||
| self.initial_tensors: List[torch.Tensor] | ||
| self.contrastive_weight: float = 10 | ||
| self.temprature: float = 0.5 |
nit: temperature
| assert isinstance(self.model, MoonModel) | ||
|
|
||
| # Save the parameters of the old local model | ||
| old_model = copy.deepcopy(self.model) |
We probably also want to set param.requires_grad = False for every parameter of the old models, since we are not optimizing them.
| output = super().set_parameters(parameters, config) | ||
|
|
||
| # Save the parameters of the global model | ||
| self.global_model = copy.deepcopy(self.model) |
and set params.requires_grad = False for all global_model params since we are not optimizing.
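One way to cover both the old-model and global-model cases is a small helper that copies a model and freezes the copy. This is only a sketch; `frozen_copy` is a hypothetical name, not something in the PR.

```python
import copy

import torch.nn as nn


def frozen_copy(model: nn.Module) -> nn.Module:
    """Deep-copy a model and disable gradients on every parameter of the copy.

    The original model is left untouched and remains trainable.
    """
    clone = copy.deepcopy(model)
    for param in clone.parameters():
        param.requires_grad = False
    return clone
```

Both `old_model = copy.deepcopy(self.model)` and `self.global_model = copy.deepcopy(self.model)` could then go through the same helper.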
fl4health/clients/moon_client.py
Outdated
| def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]: | ||
| preds, self.features, _ = self.model(input) | ||
| self.features = self.features.view(len(self.features), -1) | ||
| self.old_features_list = [] |
It would help to add comments noting that we also store the features of the old models and of the global model for the subsequent contrastive loss computation, so it is clear that this is a side effect of the predict method.
fl4health/clients/moon_client.py
Outdated
| preds, self.features, _ = self.model(input) | ||
| self.features = self.features.view(len(self.features), -1) | ||
| self.old_features_list = [] | ||
| for old_model in self.old_models_list: |
To avoid this side effect (setting the global_features and old_features_list attributes in the predict method), we can stack the list of old features into a tensor and include it in the preds dict. We would also include global_features in the preds dict. Then get_contrastive_loss would accept preds. @emersodb do you think this is a worthwhile refactor, or is it better this way?
I agree this can be cleaner. I will apply it in a refactor now.
Oh, but no: I run into errors in the metrics file because of the assertion on line 242. I tentatively changed it; let me know if it breaks anything else.
Good catch! I think the way you did it is perfect. We should be able to have predictions that we don't have to run metrics on. I am going to update the MetricMeterManager in a subsequent PR so it's easier for users to specify which predictions they want evaluated. This is a good first step, thanks Sana :)
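For reference, the refactor discussed in this thread might look roughly like the sketch below. `ToyMoonModel` is a hypothetical stand-in that returns `(preds, features)` (the real model in the diff returns three values), the dict keys are assumed names, and at least one old model is assumed to exist.

```python
from typing import Dict, List

import torch
import torch.nn as nn


class ToyMoonModel(nn.Module):
    """Hypothetical stand-in for a MOON-style model returning (preds, features)."""

    def __init__(self) -> None:
        super().__init__()
        self.base = nn.Linear(4, 3)
        self.head = nn.Linear(3, 2)

    def forward(self, x: torch.Tensor):
        features = self.base(x)
        return self.head(features), features


def predict(
    model: nn.Module,
    global_model: nn.Module,
    old_models: List[nn.Module],
    x: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    """Return predictions plus every feature tensor the contrastive loss needs."""
    preds, features = model(x)
    features = features.view(len(features), -1)
    # The global and old models are not optimized, so skip gradient tracking.
    with torch.no_grad():
        global_features = global_model(x)[1].view(len(x), -1)
        # Stack per-model features into shape (n_old_models, batch, dim).
        old_features = torch.stack([old(x)[1].view(len(x), -1) for old in old_models])
    return {
        "prediction": preds,
        "features": features,
        "global_features": global_features,
        "old_features": old_features,
    }
```

With this shape, only the `"prediction"` entry would need to be passed to the metrics, which is the case the assertion change above allows for.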
examples/moon_example/README.md
Outdated
| @@ -0,0 +1,48 @@ | |||
| # Moon Federated Learning Example | |||
| This example provides an example of training a Moon type model on a non-IID subset of the MNIST data. The FL server | |||
| expects two clients to be spun up (i.e. it will wait until two clients report in before starting training). Each client | |||
nit: According to the config, we expect three clients to be spun up.
fl4health/clients/moon_client.py
Outdated
|
|
||
| def get_contrastive_loss(self) -> torch.Tensor: | ||
| assert len(self.features) == len(self.global_features) | ||
| posi = self.cos_sim(self.features, self.global_features) |
It would be good to add comments to clarify how the contrastive loss is computed. I see you are stacking the predictions of global model and subsequently old models column wise. Then setting labels to 0, indicating the similarity for positive pairs should be the highest when compared to the negative similarities in other columns. May not be clear to people not familiar with how MOON computes the contrastive loss.
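As an illustration of the computation described here, a commented standalone version could read as follows. It is a sketch of the general MOON-style loss, not the PR's exact code; the function name, argument layout, and the `(n_old, batch, dim)` shape for the old-model features are assumptions.

```python
import torch
import torch.nn.functional as F


def moon_contrastive_loss(
    features: torch.Tensor,         # (batch, dim), from the current local model
    global_features: torch.Tensor,  # (batch, dim), from the global model
    old_features: torch.Tensor,     # (n_old, batch, dim), from old local models
    temperature: float = 0.5,
) -> torch.Tensor:
    assert len(features) == len(global_features)
    # Positive pair: the local features should resemble the global model's features.
    positive = F.cosine_similarity(features, global_features, dim=-1)  # (batch,)
    # Negative pairs: the local features should differ from each old model's features.
    negatives = F.cosine_similarity(
        features.unsqueeze(0).expand_as(old_features), old_features, dim=-1
    )  # (n_old, batch)
    # Column 0 holds the positive similarity; the remaining columns hold the negatives.
    logits = torch.cat([positive.unsqueeze(1), negatives.t()], dim=1) / temperature
    # Label 0 for every sample: cross entropy then pushes the positive column
    # to have the highest similarity in each row.
    labels = torch.zeros(features.size(0), dtype=torch.long, device=features.device)
    return F.cross_entropy(logits, labels)
```

The loss is small when the local features align with the global model and point away from the old models, and large in the opposite case.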
jewelltaylor
left a comment
LGTM! Thanks Sana :)
PR Type
[Feature]
Short Description
Added Moon method to the existing methods.
Added `clients/moon_client.py` and `model_bases/moon_base.py` to the library. Also added an experiment sample to the examples and flamby folders.
Tests Added
Added `clients/test_moon_client.py` to tests.