Dbe/setting doc string format #64
Conversation
…nes to readme.md, adding in some preliminary comments in the new format. Adding code that assumes that the first set_parameters call is used to initialize the whole model, regardless of the type of parameter exchanger used. Simplifying the warm start implementation in SCAFFOLD a bit, also allowing default initial control variates of zero if a model is provided. You can also pass custom initial control variates, and we enforce that these are the initial control variates for both the server and the clients. Finally, fixing a bug in the basic client where local steps were not calculated correctly
…eters are set the first time but are expected to be subsets thereafter where appropriate. Also renaming some of the APFL objects to follow the correct capitalization scheme. Finally, adding a sanity check to the fedprox exchange of the proximal weight.
**/datasets/news_classification/partitioned_datasets/
**/datasets/mnist_data/
**/datasets/MNIST/
**/datasets/agnews_data/
Just formally ignoring the agnews dataset used for the dynamic layer exchanger (if you download it as agnews_data)
from examples.models.cnn_model import MnistNetWithBnAndFrozen
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import APFLModule
from fl4health.model_bases.apfl_base import ApflModule
The standard casing (at least as I'm told, even for acronyms) is not to use all caps. So I just changed it here and in a bunch of other places to adhere to that rule.
type=str,
help="Path to configuration file.",
default="examples/basic_example/config.yaml",
default="examples/federated_eval_example/config.yaml",
Just a typo fix in the default arg value
| config["local_epochs"], | ||
| config["batch_size"], | ||
| config["n_server_rounds"], | ||
| config["adaptive_proximal_weight"], |
These weren't used anywhere in the config because they are part of the strategy. So I removed them.
# Run further local training after the federated learning has finished
client.train_by_epochs(2)
local_epochs_to_perform = 2
log(INFO, f"Beginning {local_epochs_to_perform} local epochs of training on each client")
Just adding a logging statement for clarity
def forward(self, x: torch.Tensor) -> torch.Tensor:
    outputs = torch.sigmoid(self.linear(x)).squeeze()
    outputs = torch.sigmoid(self.linear(x)).reshape(-1)
If the batch size is 1 (which can happen on the last batch of a dataloader when len(dataset) mod batch_size = 1), then squeeze will cause problems in the BCE loss calculation. Using reshape rather than squeeze avoids this.
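A minimal, standalone sketch of the shape issue (the linear layer here is illustrative, not the repo's model):

```python
import torch

linear = torch.nn.Linear(4, 1)
x = torch.randn(1, 4)  # a trailing batch of size 1

# squeeze() removes *all* singleton dimensions, leaving a 0-d scalar tensor
print(torch.sigmoid(linear(x)).squeeze().shape)    # torch.Size([])

# reshape(-1) always produces a 1-d tensor, matching a target of shape (1,)
print(torch.sigmoid(linear(x)).reshape(-1).shape)  # torch.Size([1])

# BCELoss expects the input and target shapes to line up, so the 0-d output
# from squeeze() can break the loss computation while reshape(-1) does not.
```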
initial_parameters=initial_parameters,
initial_control_variates=initial_control_variates,
initial_parameters=get_initial_model_parameters(model),
model=model,
Model is used by the SCAFFOLD strategy to automatically initialize the control variates to zero.
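A rough sketch of that default (the helper name is hypothetical, not the strategy's actual API):

```python
from typing import List

import numpy as np
import torch


def zero_control_variates_from_model(model: torch.nn.Module) -> List[np.ndarray]:
    # One zero array per parameter tensor, shaped like the model's weights
    return [np.zeros_like(p.cpu().detach().numpy()) for p in model.parameters()]
```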
if local_epochs is not None:
    _, metrics = self.train_by_epochs(local_epochs, current_server_round)
    local_steps = self.num_train_samples * local_epochs  # total steps over training round
    local_steps = len(self.train_loader) * local_epochs  # total steps over training round
@jewelltaylor: Correct me if I'm wrong, but I believe the original line was a bug that popped up somewhere? The number of steps is the number of batches in the loader times the number of epochs, rather than the number of samples times the number of epochs, I think? So I fixed it here.
Yeah, it looks like I accidentally computed the total number of samples seen across all the epochs in a given training round instead of the number of steps. Good catch!
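For concreteness, with made-up numbers:

```python
# Illustrative values only
num_train_samples = 1000
batch_size = 32
local_epochs = 2

batches_per_epoch = -(-num_train_samples // batch_size)  # ceil(1000 / 32) = 32 = len(train_loader)

local_steps = batches_per_epoch * local_epochs    # 64 optimizer steps per round (correct)
samples_seen = num_train_samples * local_epochs   # 2000, what the old line computed
```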
self.client_control_variates_updates: Optional[NDArrays] = None  # delta_c_i in paper
self.server_control_variates: Optional[NDArrays] = None  # c in paper
self.optimizer: torch.optim.SGD  # Scaffold require vanilla SGD as optimizer
self.server_model_state: Optional[NDArrays] = None  # model state from server
This wasn't used anywhere, so I killed it off. If I missed something please let me know.
Good call, dug around and also came to the conclusion this isn't being used anywhere.
# Note that these are weights the require a gradient, because they are used to compute control variates
self.server_model_weights = [
    model_params.cpu().detach().numpy()
    model_params.cpu().detach().clone().numpy()
If we don't call clone here and the model_params are already on cpu, there is a memory optimization where this numpy array will share the underlying memory of the parameters with the original tensor. This is bad, because it means that when the model parameters are updated in training, the values in self.server_model_weights are also updated, which we really don't want.
Note: This does not cause the same problems if the model is on GPU, as model_params.cpu() makes a copy of the tensor on cpu, distinct from the tensor being trained on the GPU.
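A minimal repro of the aliasing, independent of SCAFFOLD:

```python
import torch

params = torch.zeros(3)                     # a CPU tensor standing in for model parameters
aliased = params.detach().numpy()           # shares memory with the tensor
snapshot = params.detach().clone().numpy()  # independent copy

params += 1.0                               # simulate a training update

print(aliased)   # [1. 1. 1.] -- silently tracked the update
print(snapshot)  # [0. 0. 0.] -- preserved the original values
```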
Very good catch! I was not aware of this at all but will be sure to keep this in mind in the future.
# still holds.
if self.client_control_variates is None:
    self.client_control_variates = [np.zeros_like(weight) for weight in self.server_control_variates]
    self.client_control_variates = copy.deepcopy(self.server_control_variates)
This used to set the client control variates to zeros. If the user explicitly sets the initial server_control_variates to something else, say all ones, that would violate the paper's assumption that, initially, self.server_control_variates = average(client_control_variates). So here we're initializing the client control variates to match the server ones.
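In other words (sketch with hypothetical values):

```python
import copy

import numpy as np

# Suppose the user supplied non-zero initial server control variates
server_control_variates = [np.ones((2, 2)), np.ones(3)]

# Old behaviour: zeros, which breaks c = average(c_i) whenever c != 0
# client_control_variates = [np.zeros_like(w) for w in server_control_variates]

# New behaviour: start the client variates equal to the server's, so the
# invariant holds for any user-chosen initialization
client_control_variates = copy.deepcopy(server_control_variates)
```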
Nice that we can correctly handle init of control variates in all cases. Thanks for the update :)
# y_i
client_model_weights = [val.cpu().detach().numpy() for val in self.model.parameters() if val.requires_grad]
client_model_weights = [
    val.cpu().detach().clone().numpy() for val in self.model.parameters() if val.requires_grad
This is a similar issue to what was happening above with self.server_model_weights if we're working exclusively on CPU. I don't think this causes any issues, as it happens after training, but still doing the clone to be safe.
extra_fedprox_variable = float(packed_parameters[split_size:][0])
# The packed contents should have length 1
packed_contents = packed_parameters[split_size:]
assert len(packed_contents) == 1
Just adding an extra sanity check in case this catches something for us in the future
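A small sketch of the check with a hypothetical packed payload:

```python
import numpy as np

# Hypothetical payload: model weight arrays followed by a single extra scalar
packed_parameters = [np.zeros((2, 2)), np.zeros(3), np.array(0.1)]
split_size = 2  # number of model weight arrays

model_weights = packed_parameters[:split_size]
packed_contents = packed_parameters[split_size:]
assert len(packed_contents) == 1  # anything else means packing and unpacking have drifted
extra_fedprox_variable = float(packed_contents[0])
```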
self.warm_start = warm_start

def initialize_paramameters(self, timeout: Optional[float]) -> None:
def _get_initial_parameters(self, timeout: Optional[float]) -> Parameters:
Rather than having to worry about _get_initial_parameters being called after the custom initialize_paramameters function for warm-start, we just override the parent _get_initial_parameters so that we have control over when the warm start work happens.
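Roughly the shape of that override (the base class and warm-start hook below are stand-ins, not the repo's actual classes):

```python
from typing import List, Optional


class ServerBaseStub:
    """Stand-in for the Flower server base class."""

    def _get_initial_parameters(self, timeout: Optional[float]) -> List:
        return []  # placeholder for fetching/constructing the initial parameters


class ScaffoldServerSketch(ServerBaseStub):
    def _get_initial_parameters(self, timeout: Optional[float]) -> List:
        # Let the parent produce the initial parameters as usual ...
        initial_parameters = super()._get_initial_parameters(timeout)
        # ... then run warm-start at a point we fully control
        self._warm_start(initial_parameters)  # hypothetical warm-start hook
        return initial_parameters

    def _warm_start(self, initial_parameters: List) -> None:
        pass  # warm-start logic would live here
```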
numpy
opacus
pandas
portalocker
When I ran the dynamic layer exchanger example, my install was missing this library.
…me necessary packages for tests and code checks
class TestFLServer(FlServer):
class DummyFLServer(FlServer):
This was generating a warning in our test collection because the class name started with the word Test. Just renaming it to suppress the warning.
…clean up step after the tests are done.
…github builds even if I ignore it in the pip install, which makes no sense. It seems like it gets installed somewhere else, so whatever.
run: |
  pip install --upgrade pip
  pip install -r requirements.txt
  pip install $(grep -v '^torchdata\|^torchtext\|^torcheval' requirements.txt)
Skipping installing some packages that are only needed for the examples
run: |
  pip install --upgrade pip
  pip install -r requirements.txt
  pip install $(grep -v '^torchdata\|^torchtext\|^torcheval' requirements.txt)
Skipping installing some packages that are only needed for the examples
pip install $(grep -v '^torchdata\|^torchtext\|^torcheval' requirements.txt)
pre-commit run --all-files
- name: Clearing pip for space
  run: pip uninstall -y -r requirements.txt
For whatever reason, we started running out of disk space saving the pip cache on the GitHub runner. This removes the site-packages to save a bit of space and allow the caching to complete.
…this, for example, to the batches in training as well.
fatemetkl left a comment
Everything looks good to me!
fl4health/clients/scaffold_client.py (outdated)
super().set_parameters(server_model_state, config)

# Note that these are weights the require a gradient, because they are used to compute control variates
"Note that these are weights that require a gradient ..." instead, perhaps. "The" doesn't seem to make sense in this context.
Good catch. Fixed this up.
jewelltaylor left a comment
Looked through the changes that pertained to Scaffold and everything looks great! One small comment about an unclear comment you had in the ScaffoldClient set_parameters method.
PR Type
[Feature | Fix | Documentation | Other() ]
Short Description
This PR addresses two tickets: Ticket 1, Ticket 2
The first ticket is fairly mundane but important. It implements a particular code documentation style for our repository. The style that we'll be using is the Google docstring format. The structure can be automatically generated if using VSCode, which is quite convenient. I've added the guidelines to the ReadMe and updated the doc strings in a bunch of areas of the code as examples and a step towards more formal documentation of the library.
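For reference, a generic example of the Google docstring format (the function is illustrative, not from fl4health):

```python
def scale_weights(weights: list, factor: float) -> list:
    """Scale each weight by a constant factor.

    Args:
        weights (list): The weights to scale.
        factor (float): The multiplicative factor applied to each weight.

    Returns:
        list: A new list containing the scaled weights.

    Raises:
        ValueError: If factor is negative.
    """
    if factor < 0:
        raise ValueError("factor must be non-negative")
    return [w * factor for w in weights]
```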
The second ticket is a quality of life ticket mainly addressing the approaches that exchange only a portion of the model weights (FENDA, APFL, Dynamic Weight Exchange, FedBN). On the first pass, when the client models are being initialized by the server (i.e. the first time set_parameters is called), the clients use a FullParameterExchanger to set all weights in the network. This means that the server sends all weights of the model to be trained, regardless of the strategy, so that the different clients' weights are synced. Thereafter, the configured parameter exchangers are used, potentially exchanging only a subset of the weights. As a result, the server need not be aware of how parameter exchange is done under the hood when setting the initial weights.
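Conceptually, the client-side logic looks something like this (class and method names are illustrative, not the repo's actual API):

```python
class PartialExchangeClientSketch:
    def __init__(self, full_exchanger, partial_exchanger) -> None:
        self.full_exchanger = full_exchanger        # exchanges every weight
        self.partial_exchanger = partial_exchanger  # e.g. a FENDA/APFL/FedBN exchanger
        self.initialized = False

    def set_parameters(self, parameters) -> None:
        if not self.initialized:
            # First pass: the server ships the full model so every client
            # starts from identical weights, regardless of strategy
            self.full_exchanger.pull_parameters(parameters)
            self.initialized = True
        else:
            # Subsequent passes: only the configured subset is exchanged
            self.partial_exchanger.pull_parameters(parameters)
```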
During the work for these two tickets, I cleaned up the way SCAFFOLD warm-start happens and also fixed a small bug in the SCAFFOLD control variate calculations when they are done on CPU (note: this bug doesn't affect our SCAFFOLD experiments for the paper because those were run on GPU). As another QoL change for SCAFFOLD, and also to avoid inconsistencies in initialization, the user can now provide initial values for the control variates, and the client control variates are automatically set to these; alternatively, the user can simply provide the model to the SCAFFOLD strategy and the control variates are set to zero everywhere.
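Illustrative usage of the two initialization paths (argument names follow the diff above; other details may differ):

```python
import numpy as np
import torch

model = torch.nn.Linear(4, 2)  # toy model for illustration

# Option 1: pass the model and let the strategy default the control variates to zero
# strategy = Scaffold(initial_parameters=get_initial_model_parameters(model), model=model)

# Option 2: pass explicit control variates, which become the initial values for
# the server and all clients alike
custom_control_variates = [np.zeros_like(p.detach().cpu().numpy()) for p in model.parameters()]
# strategy = Scaffold(
#     initial_parameters=get_initial_model_parameters(model),
#     initial_control_variates=custom_control_variates,
# )
```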
Finally, we had a few places where the class name casing didn't follow the standard convention, so I just replaced them (mostly APFL... to Apfl...).
Tests Added
Added a set of tests covering the new flow: on the first pass, all model parameters are set; on subsequent passes, the appropriate parameter exchanger is used.
Also re-ran all of the examples to confirm that they still work as expected.