
Conversation


@emersodb emersodb commented Oct 6, 2023

PR Type

[Feature | Fix | Documentation | Other() ]

Short Description

This PR addresses two tickets: Ticket 1, Ticket 2

The first ticket is fairly mundane but important. It establishes a code documentation style for our repository: the Google docstring format. The structure can be generated automatically when using VSCode, which is quite convenient. I've added the guidelines to the README and updated the docstrings in a number of areas of the code as examples and as a step towards more formal documentation of the library.
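For reference, here's a minimal sketch of the Google docstring format in question (the function itself is hypothetical, just to show the sections):

```python
def scale_weights(weights: list[float], factor: float) -> list[float]:
    """Scale each weight by a constant factor.

    Args:
        weights (list[float]): The weights to be scaled.
        factor (float): Multiplicative factor applied to each weight.

    Returns:
        list[float]: The scaled weights.

    Raises:
        ValueError: If factor is negative.
    """
    if factor < 0:
        raise ValueError("factor must be non-negative")
    return [w * factor for w in weights]
```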

The second ticket is a quality-of-life ticket mainly addressing the approaches that exchange only a portion of the model weights (FENDA, APFL, Dynamic Weight Exchange, FedBN). On the first pass, when the client models are being initialized by the server (i.e. the first time set_parameters is called), the clients use a FullParameterExchanger to set all weights in the network. This means that the server sends all weights of the model to be trained, regardless of the strategy, to sync all of the different clients' weights. Thereafter, the configured parameter exchangers are used, potentially exchanging only sub-weights. As a result, the server need not be aware of how the parameter exchange is done under the hood when setting initial weights.
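To make the flow concrete, here is a rough sketch of the idea (class and method names are illustrative, not the exact fl4health implementation):

```python
from typing import List

import numpy as np
import torch
import torch.nn as nn

NDArrays = List[np.ndarray]


class FirstPassFullExchangeClient:
    """Sketch only: full exchange on the first set_parameters call, layer-subset
    exchange on every call after that."""

    def __init__(self, model: nn.Module, layers_to_exchange: List[str]) -> None:
        self.model = model
        self.layers_to_exchange = layers_to_exchange  # state dict keys exchanged after round 1
        self.initialized = False

    def set_parameters(self, parameters: NDArrays) -> None:
        if not self.initialized:
            # First pass: the server sends every weight, so load the full state dict
            # regardless of which strategy/exchanger is configured.
            full_state = {
                name: torch.from_numpy(array)
                for name, array in zip(self.model.state_dict().keys(), parameters)
            }
            self.model.load_state_dict(full_state, strict=True)
            self.initialized = True
        else:
            # Later passes: only the configured subset of layers is exchanged.
            partial_state = {
                name: torch.from_numpy(array)
                for name, array in zip(self.layers_to_exchange, parameters)
            }
            self.model.load_state_dict(partial_state, strict=False)
```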

During the work for these two tickets, I cleaned up the way SCAFFOLD warm start happens and also fixed a small bug in the SCAFFOLD control variate calculations when they are done on CPU (note: this bug doesn't affect our SCAFFOLD experiments for the paper because they were run on GPU). As another QoL change for SCAFFOLD, and also to avoid inconsistencies in the initialization, the user can now either provide initial values for the control variates, in which case the client control variates are automatically set to match them, or simply provide the model to the SCAFFOLD strategy, in which case the control variates are initialized to zero everywhere.
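Roughly, the initialization choice for the control variates looks like this (a sketch with hypothetical names, not the exact SCAFFOLD strategy API):

```python
from typing import List, Optional

import numpy as np
import torch.nn as nn

NDArrays = List[np.ndarray]


def initial_control_variates_from(
    model: Optional[nn.Module] = None,
    initial_control_variates: Optional[NDArrays] = None,
) -> NDArrays:
    # Explicit values take precedence; the clients are then initialized to match them.
    if initial_control_variates is not None:
        return initial_control_variates
    assert model is not None, "Provide either a model or explicit control variates"
    # Otherwise, default to zeros shaped like the model's trainable parameters.
    return [
        np.zeros_like(p.detach().cpu().numpy()) for p in model.parameters() if p.requires_grad
    ]
```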

Finally, we had a few places where class name casing didn't follow the standard convention, so I just renamed them (mostly APFL... to Apfl...).

Tests Added

Added a set of tests covering the flow where the first pass sets all model parameters and subsequent passes use the appropriate parameter exchanger.

Also re-ran all of the examples to confirm that they still work as expected.

…nes to readme.md, adding in some preliminary comments in the new format. Adding code that assumes that the first set_parameters is used to initialize the whole model, regardless of the type of parameter exchanger used. Simplifying the warm start implementation in SCAFFOLD a bit, also allowing for default initial control variates to zero if provided a model. You can also pass custom initial control variates as well and we enforce that these are the initial control variates for the server and the clients. Finally, fixing a bug in basic client where local steps was not calculated correctly
…eters are set the first time but are expected to be subsets thereafter where appropriate. Also renaming some of the APFL objects to follow the correct capitalization scheme. Finally, adding a sanity check to fedprox exchange of the proximal weight.
**/datasets/news_classification/partitioned_datasets/**
**/datasets/mnist_data/**
**/datasets/MNIST/**
**/datasets/agnews_data/**
Collaborator Author

Just formally ignoring the agnews dataset used for the dynamic layer exchanger (if you download it as agnews_data)

from examples.models.cnn_model import MnistNetWithBnAndFrozen
from fl4health.clients.apfl_client import ApflClient
from fl4health.model_bases.apfl_base import APFLModule
from fl4health.model_bases.apfl_base import ApflModule
Collaborator Author

The standard casing (at least as I'm told, even for acronyms) is not to use all caps. So I just changed it here and in a bunch of other places to adhere to that convention.

type=str,
help="Path to configuration file.",
default="examples/basic_example/config.yaml",
default="examples/federated_eval_example/config.yaml",
Collaborator Author

Just a typo fix in the default arg value

config["local_epochs"],
config["batch_size"],
config["n_server_rounds"],
config["adaptive_proximal_weight"],
Collaborator Author

These weren't used anywhere in the config because they are part of the strategy. So I removed them.

# Run further local training after the federated learning has finished
client.train_by_epochs(2)
local_epochs_to_perform = 2
log(INFO, f"Beginning {local_epochs_to_perform} local epochs of training on each client")
Collaborator Author

Just adding a logging statement for clarity


def forward(self, x: torch.Tensor) -> torch.Tensor:
outputs = torch.sigmoid(self.linear(x)).squeeze()
outputs = torch.sigmoid(self.linear(x)).reshape(-1)
Collaborator Author

If the batch size is 1 (which can happen on the last batch of a dataloader when len(dataset) mod batch_size == 1), then squeeze collapses the output to a 0-d tensor and causes problems in the BCE loss calculation. Using reshape rather than squeeze avoids this.
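For concreteness, a tiny reproduction of the shape issue (made-up tensors):

```python
import torch

# A batch of size 1, e.g. the final batch of a dataloader.
logits = torch.randn(1, 1)
targets = torch.ones(1)

squeezed = torch.sigmoid(logits).squeeze()    # shape: torch.Size([]) -- 0-d tensor
reshaped = torch.sigmoid(logits).reshape(-1)  # shape: torch.Size([1])

loss_fn = torch.nn.BCELoss()
print(loss_fn(reshaped, targets))  # fine: input and target shapes match
# loss_fn(squeezed, targets) -> shape mismatch between the 0-d input and the
# 1-d target, which BCELoss complains about.
```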

initial_parameters=initial_parameters,
initial_control_variates=initial_control_variates,
initial_parameters=get_initial_model_parameters(model),
model=model,
Collaborator Author

Model is used by the SCAFFOLD strategy to automatically initialize the control variates to zero.

if local_epochs is not None:
_, metrics = self.train_by_epochs(local_epochs, current_server_round)
local_steps = self.num_train_samples * local_epochs # total steps over training round
local_steps = len(self.train_loader) * local_epochs # total steps over training round
Collaborator Author

@jewelltaylor: Correct me if I'm wrong, but I believe the LHS was a bug that popped up somewhere? The number of steps is the number of batches in the loader times the number of epochs, rather than the number of samples times the number of epochs, I think? So I fixed it here.
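A quick sanity check of the intended quantity (made-up numbers):

```python
# With, say, 10 batches per epoch and 2 local epochs, the client performs 20
# optimizer steps in the round -- not 10 * batch_size * 2 "steps".
batches_per_epoch = 10  # len(self.train_loader)
local_epochs = 2
local_steps = batches_per_epoch * local_epochs
assert local_steps == 20
```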

Contributor

@jewelltaylor jewelltaylor Oct 16, 2023

Yeah, it looks like I accidentally computed the total number of samples seen across all the epochs in a given training round instead of the number of steps. Good catch!

self.client_control_variates_updates: Optional[NDArrays] = None # delta_c_i in paper
self.server_control_variates: Optional[NDArrays] = None # c in paper
self.optimizer: torch.optim.SGD # Scaffold require vanilla SGD as optimizer
self.server_model_state: Optional[NDArrays] = None # model state from server
Collaborator Author

@emersodb emersodb Oct 6, 2023

This wasn't used anywhere, so I killed it off. If I missed something please let me know.

Contributor

Good call, dug around and also came to the conclusion this isn't being used anywhere.

# Note that these are weights the require a gradient, because they are used to compute control variates
self.server_model_weights = [
model_params.cpu().detach().numpy()
model_params.cpu().detach().clone().numpy()
Collaborator Author

If we don't call clone here and the model_params are already on CPU, there is a memory optimization where this numpy array will share its underlying memory with the original tensor. This is bad, because it means that when the model parameters are updated during training, the values in self.server_model_weights are updated as well, which we really don't want.
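A small demonstration of the sharing behaviour on CPU (a stand-in tensor, not the actual model weights):

```python
import torch

param = torch.zeros(3)                   # stand-in for a CPU model parameter
shared = param.detach().numpy()          # shares memory with `param`
copied = param.detach().clone().numpy()  # independent copy

param.add_(1.0)                          # simulate an in-place training update
print(shared)  # [1. 1. 1.] -- silently tracks the update
print(copied)  # [0. 0. 0.] -- unaffected
```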

Collaborator Author

Note: This does not cause the same problems if the model is on GPU, as model_params.cpu() makes a copy of the tensor on cpu, distinct from the tensor being trained on the GPU.

Contributor

Very good catch! I was not aware of this at all but will be sure to keep this in mind in the future.

# still holds.
if self.client_control_variates is None:
self.client_control_variates = [np.zeros_like(weight) for weight in self.server_control_variates]
self.client_control_variates = copy.deepcopy(self.server_control_variates)
Collaborator Author

This used to set the client control variates to zeros. If the user explicitly sets the initial server_control_variates to something else, say all ones, that would violate the paper's assumption that, initially, server_control_variates = average(client_control_variates). So here we're initializing the client control variates to match the server ones.
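A tiny illustration of the invariant (made-up values):

```python
import copy

import numpy as np

# If the server control variates start at something non-zero (here, all ones),
# deep-copying them on each client keeps c = average(c_i) true at initialization;
# zero-initializing the clients would not.
server_control_variates = [np.ones((2, 2)), np.ones(3)]
client_control_variates = copy.deepcopy(server_control_variates)

for c, c_i in zip(server_control_variates, client_control_variates):
    assert np.array_equal(c, c_i)
```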

Contributor

Nice that we can correctly handle init of control variates in all cases. Thanks for the update :)

# y_i
client_model_weights = [val.cpu().detach().numpy() for val in self.model.parameters() if val.requires_grad]
client_model_weights = [
val.cpu().detach().clone().numpy() for val in self.model.parameters() if val.requires_grad
Collaborator Author

This is a similar issue to what was happening above with self.server_model_weights if we're working exclusively on CPU. I don't think this causes any issues, as it happens after training, but still doing the clone to be safe.

extra_fedprox_variable = float(packed_parameters[split_size:][0])
# The packed contents should have length 1
packed_contents = packed_parameters[split_size:]
assert len(packed_contents) == 1
Collaborator Author

Just adding an extra sanity check in case this catches something for us in the future

self.warm_start = warm_start

def initialize_paramameters(self, timeout: Optional[float]) -> None:
def _get_initial_parameters(self, timeout: Optional[float]) -> Parameters:
Collaborator Author

Rather than having to worry about _get_initial_parameters being called after the custom initialize_paramameters function for warm-start, we just override the parent _get_initial_parameters so that we have control over when the warm start work happens.
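Roughly the shape of the change, with a dummy base class standing in for the real server just to show the ordering:

```python
from typing import Optional


class BaseServerSketch:
    # Stand-in for the parent server class; pretend this gathers the initial
    # parameters from a randomly selected client.
    def _get_initial_parameters(self, timeout: Optional[float]) -> list:
        return []


class WarmStartServerSketch(BaseServerSketch):
    def __init__(self, warm_start: bool) -> None:
        self.warm_start = warm_start

    def _get_initial_parameters(self, timeout: Optional[float]) -> list:
        initial_parameters = super()._get_initial_parameters(timeout)
        if self.warm_start:
            # Warm-start work goes here, so it always runs immediately after the
            # parent hook and never races with a separate initialization method.
            pass
        return initial_parameters
```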

numpy
opacus
pandas
portalocker
Collaborator Author

When I ran the dynamic layer exchanger example, my install was missing this library.



class TestFLServer(FlServer):
class DummyFLServer(FlServer):
Collaborator Author

This was generating a warning in our test collection because it was prefixed with the word Test. Just renaming it to suppress the warning.

run: |
pip install --upgrade pip
pip install -r requirements.txt
pip install $(grep -v '^torchdata\|^torchtext\|^torcheval' requirements.txt)
Collaborator Author

Skipping installing some packages that are only needed for the examples

run: |
pip install --upgrade pip
pip install -r requirements.txt
pip install $(grep -v '^torchdata\|^torchtext\|^torcheval' requirements.txt)
Collaborator Author

Skipping installing some packages that are only needed for the examples

pip install $(grep -v '^torchdata\|^torchtext\|^torcheval' requirements.txt)
pre-commit run --all-files
- name: Clearing pip for space
run: pip uninstall -y -r requirements.txt
Collaborator Author

For whatever reason, we started running out of disk space when saving the pip cache on the GitHub runner. This uninstalls the installed packages to free up a bit of space and allow the caching to complete.

@emersodb emersodb requested a review from zxj-c October 10, 2023 19:12
…this, for example, to the batches in training as well.
Collaborator

@fatemetkl fatemetkl left a comment

Everything looks good to me!

@emersodb emersodb mentioned this pull request Oct 12, 2023

super().set_parameters(server_model_state, config)

# Note that these are weights the require a gradient, because they are used to compute control variates
Contributor

Perhaps this should read "Note that these are weights that require a gradient ..." instead. "The" doesn't seem to make sense in this context.

Collaborator Author

Good catch. Fixed this up.

Contributor

@jewelltaylor jewelltaylor left a comment

Looked through the changes that pertained to Scaffold and everything looks great! One small comment about an unclear comment you had in the ScaffoldClient set_parameters method

@emersodb emersodb merged commit 79e0f5d into main Oct 16, 2023
@emersodb emersodb deleted the dbe/setting_doc_string_format branch October 16, 2023 18:10