Conversation

@sanaAyrml
Collaborator

PR Type

[Feature]

Short Description

Added the MOON method to the existing set of methods.
Added clients/moon_client.py and model_bases/moon_base.py to the library. Also added experiment samples to the examples and flamby folders.

Tests Added

Added clients/test_moon_client.py to tests.

    checkpointer=checkpointer,
)
self.initial_tensors: List[torch.Tensor]
self.contrastive_weight: float = 10
Contributor

We will probably want the ability to configure the hyperparameters (temperature, contrastive weight, and len_old_models_buffer) to take advantage of your otherwise general code. We can keep the defaults that you specify, but check whether the config sent by the server contains these hyperparameters; if it does, set them accordingly before the FL round begins.
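
As an illustration, a minimal sketch of this pattern, assuming a Flower-style Config dict reaches the client; the method name and the config keys here are hypothetical, not part of the PR:

```python
from typing import Any, Dict

# Hypothetical sketch: override the client-side defaults with any
# hyperparameters the server includes in its config for the round.
def setup_moon_hyperparameters(self, config: Dict[str, Any]) -> None:
    self.contrastive_weight = float(config.get("contrastive_weight", self.contrastive_weight))
    self.temperature = float(config.get("temperature", self.temperature))
    self.len_old_models_buffer = int(config.get("len_old_models_buffer", self.len_old_models_buffer))
```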

)
self.initial_tensors: List[torch.Tensor]
self.contrastive_weight: float = 10
self.temprature: float = 0.5
Contributor

nit: temperature

assert isinstance(self.model, MoonModel)

# Save the parameters of the old local model
old_model = copy.deepcopy(self.model)
Contributor

Probably also want to set params.requires_grad = False for all of old_model's params, since we are not optimizing it.
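
A minimal sketch of the suggested freeze, assuming old_model is the deepcopy shown above (the same pattern would apply to the global_model copy discussed further down):

```python
import copy

# Illustrative fragment, not the PR's final code: snapshot the current local
# model and freeze it so autograd and the optimizer ignore its parameters.
old_model = copy.deepcopy(self.model)
old_model.eval()  # also fixes dropout/batch-norm behavior in forward passes
for param in old_model.parameters():
    param.requires_grad = False
```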

output = super().set_parameters(parameters, config)

# Save the parameters of the global model
self.global_model = copy.deepcopy(self.model)
Contributor

@jewelltaylor Oct 26, 2023

and set params.requires_grad = False for all global_model params since we are not optimizing.

def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]:
    preds, self.features, _ = self.model(input)
    self.features = self.features.view(len(self.features), -1)
    self.old_features_list = []
Contributor

@jewelltaylor Oct 26, 2023

Some comments indicating that we are storing the features of the old models, as well as the global model, for subsequent contrastive loss computation would be helpful, to make clear that this is a side effect of the predict method.

    preds, self.features, _ = self.model(input)
    self.features = self.features.view(len(self.features), -1)
    self.old_features_list = []
    for old_model in self.old_models_list:
Contributor

In order to avoid this side effect (setting the global_features and old_features_list attributes in the predict method), we can stack the list of old_features into a tensor and include it in the preds dict. We would also include global_features in the preds dict. Then get_contrastive_loss would accept preds. @emersodb do you think this is a worthwhile refactor, or is it better this way?
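
A rough sketch of what the proposed refactor could look like; the dict keys, the assumed (preds, features, _) model output signature, and a non-empty old_models_list are all assumptions, not the PR's final code:

```python
from typing import Dict

import torch


def predict(self, input: torch.Tensor) -> Dict[str, torch.Tensor]:
    # Forward pass through the current local model.
    preds, features, _ = self.model(input)
    features = features.view(len(features), -1)

    # The frozen global and old models are never optimized, so their
    # features can be computed without tracking gradients.
    with torch.no_grad():
        _, global_features, _ = self.global_model(input)
        old_features = [
            old_model(input)[1].view(len(input), -1)
            for old_model in self.old_models_list
        ]

    return {
        "prediction": preds,
        "features": features,
        "global_features": global_features.view(len(input), -1),
        # Old-model features stacked into one tensor, as suggested above;
        # assumes old_models_list is non-empty.
        "old_features": torch.stack(old_features),
    }
```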

Collaborator Author

I agree this can be cleaner. I will apply it in the refactor now.

Collaborator Author

@sanaAyrml Oct 27, 2023

Oh, but no: I run into errors with the metrics file, as it has an assertion at line 242. I tentatively changed it; let me know if it breaks anything else.

Contributor

Good catch! I think the way you did it is perfect. We should be able to have predictions that we don't have to run metrics on. I am going to update the MetricMeterManager in a subsequent PR so it's easier for users to specify which predictions they want evaluated. This is a good first step, thanks Sana :)

@@ -0,0 +1,48 @@
# Moon Federated Learning Example
This example provides an example of training a Moon type model on a non-IID subset of the MNIST data. The FL server
expects two clients to be spun up (i.e. it will wait until two clients report in before starting training). Each client
Contributor

nit: According to the config, we expect three clients to be spun up.


def get_contrastive_loss(self) -> torch.Tensor:
    assert len(self.features) == len(self.global_features)
    posi = self.cos_sim(self.features, self.global_features)
Contributor

It would be good to add comments to clarify how the contrastive loss is computed. I see you are stacking the similarities against the global model and subsequently the old models column-wise, then setting the labels to 0, indicating that the similarity for the positive pair should be the highest when compared to the negative similarities in the other columns. This may not be clear to people unfamiliar with how MOON computes the contrastive loss.
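
For readers unfamiliar with MOON, here is a hedged sketch of the computation being described (Li et al., 2021). Attribute names follow the snippets above, with temperature as the corrected spelling; this is illustrative, not the PR's exact code:

```python
import torch
import torch.nn as nn


def get_contrastive_loss(self) -> torch.Tensor:
    cos_sim = nn.CosineSimilarity(dim=-1)
    # Positive pair: current local features vs. the global model's features.
    posi = cos_sim(self.features, self.global_features)  # shape: (batch_size,)
    logits = posi.reshape(-1, 1)
    # Negative pairs: current features vs. each old local model's features,
    # appended column-wise so each row reads [positive, neg_1, ..., neg_k].
    for old_features in self.old_features_list:
        nega = cos_sim(self.features, old_features)
        logits = torch.cat((logits, nega.reshape(-1, 1)), dim=1)
    logits /= self.temperature
    # Label 0 marks column 0 (the positive pair) as the "correct class", so
    # cross-entropy pushes the positive similarity above every negative one.
    labels = torch.zeros(self.features.size(0), dtype=torch.long, device=logits.device)
    return nn.CrossEntropyLoss()(logits, labels)
```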

Contributor

@jewelltaylor left a comment

LGTM! Thanks Sana :)

@jewelltaylor merged commit 0479620 into main on Oct 30, 2023
@jewelltaylor deleted the Moon branch on October 30, 2023 at 21:35
@jewelltaylor mentioned this pull request on Nov 17, 2023