fl4health/clients/basic_client.py (14 additions, 13 deletions)
@@ -45,7 +45,7 @@ def __init__(
checkpointer: Optional[ClientCheckpointModule] = None,
metrics_reporter: Optional[MetricsReporter] = None,
progress_bar: bool = False,
intermediate_checkpoint_dir: Optional[Path] = None,
intermediate_client_state_dir: Optional[Path] = None,
client_name: Optional[str] = None,
) -> None:
"""
@@ -68,7 +68,7 @@ def __init__(
progress_bar (bool): Whether or not to display a progress bar
during client training and validation. Uses tqdm. Defaults to
False
intermediate_checkpoint_dir (Optional[Path]): An optional path to store per round
intermediate_client_state_dir (Optional[Path]): An optional path to store per round
checkpoints.
client_name (str): An optional client name that uniquely identifies a client.
If not passed, a hash is randomly generated.
@@ -83,9 +83,9 @@ def __init__(

self.per_round_checkpointer: Union[None, PerRoundCheckpointer]

if intermediate_checkpoint_dir is not None:
if intermediate_client_state_dir is not None:
self.per_round_checkpointer = PerRoundCheckpointer(
intermediate_checkpoint_dir, Path(f"client_{self.client_name}.pt")
intermediate_client_state_dir, Path(f"client_{self.client_name}.pt")
)
else:
self.per_round_checkpointer = None
@@ -269,7 +269,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict
# If per_round_checkpointer is not None and a checkpoint exists, load it and set attributes.
# The model is not updated because FL restarts from the most recent FL round (the pre-empted round is redone)
if self.per_round_checkpointer is not None and self.per_round_checkpointer.checkpoint_exists():
self.load_checkpoint()
self.load_client_state()

self.metrics_reporter.add_to_metrics_at_round(
current_server_round,
@@ -310,7 +310,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict
# After local client training has finished, checkpoint client state
# if per_round_checkpointer is not None
if self.per_round_checkpointer is not None:
self.save_checkpoint()
self.save_client_state()

# FitRes should contain local parameters, number of examples on client, and a dictionary holding metrics
# calculation results.
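
A reduced sketch of the checkpoint handling inside fit(), based only on the calls and comments visible in this diff; the weight loading, local training and metrics plumbing are elided, and the FitRes return is only described in a comment:

def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]:
    # On a restart after a pre-emption, restore client state before redoing the interrupted round
    if self.per_round_checkpointer is not None and self.per_round_checkpointer.checkpoint_exists():
        self.load_client_state()

    # ... update the model from `parameters`, run local training, collect metrics ...

    # After local training, persist client state so a later restart can resume from this round
    if self.per_round_checkpointer is not None:
        self.save_client_state()

    # ... return (updated parameters, number of local training examples, metrics dict) ...
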
@@ -1277,7 +1277,7 @@ def transform_gradients(self, losses: TrainingLosses) -> None:
"""
pass

def save_checkpoint(self) -> None:
def save_client_state(self) -> None:
"""
Saves checkpoint dict consisting of client name, total steps, lr schedulers,
metrics reporter and optimizers state. Method can be overridden to augment saved checkpointed state.
@@ -1290,14 +1290,14 @@ def save_checkpoint(self) -> None:
"total_steps": self.total_steps,
"client_name": self.client_name,
"metrics_reporter": self.metrics_reporter,
"optimizers_state": {key: opt.state_dict()["state"] for key, opt in self.optimizers.items()},
"optimizers_state": {key: optimizer.state_dict()["state"] for key, optimizer in self.optimizers.items()},
}

self.per_round_checkpointer.save_checkpoint(ckpt)

log(INFO, f"Saving client state to checkpoint at {self.per_round_checkpointer.checkpoint_path}")

def load_checkpoint(self) -> None:
def load_client_state(self) -> None:
"""
Load checkpoint dict consisting of client name, total steps, lr schedulers, metrics
reporter and optimizers state. Method can be overridden to augment loaded checkpointed state.
@@ -1319,10 +1319,11 @@ def load_checkpoint(self) -> None:
# Optimizer is updated in setup_client to reference model weights from server
# Thus, only optimizer state (per parameter values such as momentum)
# should be loaded
for opt, state in zip(self.optimizers.values(), ckpt["optimizers_state"].values()):
opt_state_dict = opt.state_dict()
opt_state_dict["state"] = state
opt.load_state_dict(opt_state_dict)
for key, optimizer in self.optimizers.items():
optimizer_state = ckpt["optimizers_state"][key]
optimizer_state_dict = optimizer.state_dict()
optimizer_state_dict["state"] = optimizer_state
optimizer.load_state_dict(optimizer_state_dict)

# Schedulers initialized in setup_client to reference correct optimizers
# Here we load in all other aspects of the scheduler state
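
Since save_client_state and load_client_state are documented override points, a subclass can persist extra per-round state. A rough sketch, under the assumptions that the keys shown above are the ones actually saved and that PerRoundCheckpointer exposes a matching load_checkpoint(); best_val_loss is a purely illustrative attribute such a subclass might maintain:

from fl4health.clients.basic_client import BasicClient

class StatefulClient(BasicClient):
    def save_client_state(self) -> None:
        assert self.per_round_checkpointer is not None
        ckpt = {
            "total_steps": self.total_steps,
            "client_name": self.client_name,
            "metrics_reporter": self.metrics_reporter,
            "optimizers_state": {key: optimizer.state_dict()["state"] for key, optimizer in self.optimizers.items()},
            # The parent also stores LR scheduler state (collapsed in this view); include it here as well.
            "best_val_loss": self.best_val_loss,  # illustrative extra state
        }
        self.per_round_checkpointer.save_checkpoint(ckpt)

    def load_client_state(self) -> None:
        super().load_client_state()
        assert self.per_round_checkpointer is not None
        ckpt = self.per_round_checkpointer.load_checkpoint()  # assumed accessor, mirroring save_checkpoint
        self.best_val_loss = ckpt.get("best_val_loss", float("inf"))
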

fl4health/clients/nnunet_client.py (4 additions, 3 deletions)
@@ -71,7 +71,7 @@ def __init__(
verbose: bool = True,
metrics: Optional[Sequence[Metric]] = None,
progress_bar: bool = False,
intermediate_checkpoint_dir: Optional[Path] = None,
intermediate_client_state_dir: Optional[Path] = None,
loss_meter_type: LossMeterType = LossMeterType.AVERAGE,
checkpointer: Optional[ClientCheckpointModule] = None,
metrics_reporter: Optional[MetricsReporter] = None,
@@ -127,7 +127,7 @@ def __init__(
on the labels and predictions of the client model. Defaults to [].
progress_bar (bool, optional): Whether or not to print a progress bar to
stdout for training. Defaults to False
intermediate_checkpoint_dir (Optional[Path]): An optional path to store per round
intermediate_client_state_dir (Optional[Path]): An optional path to store per round
checkpoints.
loss_meter_type (LossMeterType, optional): Type of meter used to
track and compute the losses over each batch. Defaults to
@@ -150,7 +150,7 @@ def __init__(
checkpointer=checkpointer, # self.checkpointer
metrics_reporter=metrics_reporter, # self.metrics_reporter
progress_bar=progress_bar,
intermediate_checkpoint_dir=intermediate_checkpoint_dir,
intermediate_client_state_dir=intermediate_client_state_dir,
)

# Some nnunet client specific attributes
@@ -276,6 +276,7 @@ def get_lr_scheduler(self, optimizer_key: str, config: Config) -> _LRScheduler:

# Create and return LR Scheduler Wrapper for the PolyLRScheduler so that it is
# compatible with Torch LRScheduler
# Create and return LR Scheduler. This is nnunet default for version 2.5.1
return PolyLRSchedulerWrapper(
self.optimizers[optimizer_key], initial_lr=self.nnunet_trainer.initial_lr, max_steps=total_steps
)

fl4health/server/base_server.py (14 additions, 18 deletions)
@@ -7,7 +7,7 @@
import torch.nn as nn
from flwr.common import EvaluateRes, Parameters
from flwr.common.logger import log
from flwr.common.parameter import ndarrays_to_parameters, parameters_to_ndarrays
from flwr.common.parameter import parameters_to_ndarrays
from flwr.common.typing import Code, GetParametersIns, Scalar
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy
@@ -22,14 +22,10 @@
from fl4health.server.polling import poll_clients
from fl4health.strategies.strategy_with_poll import StrategyWithPolling
from fl4health.utils.metrics import TEST_LOSS_KEY, TEST_NUM_EXAMPLES_KEY, TestMetricPrefix
from fl4health.utils.parameter_extraction import get_all_model_parameters
from fl4health.utils.random import generate_hash


def get_initial_model_parameters(client_model: nn.Module) -> Parameters:
# Initializing the model parameters on the server side.
return ndarrays_to_parameters([val.cpu().numpy() for _, val in client_model.state_dict().items()])
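
The local helper above is removed in favour of the shared get_all_model_parameters utility, now imported from fl4health.utils.parameter_extraction. Presumably it behaves like the deleted code; a sketch mirroring it:

import torch.nn as nn
from flwr.common import Parameters
from flwr.common.parameter import ndarrays_to_parameters

def get_all_model_parameters(model: nn.Module) -> Parameters:
    # Pack every tensor in the model's state_dict into a flwr Parameters object
    return ndarrays_to_parameters([val.cpu().numpy() for _, val in model.state_dict().items()])
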


class FlServer(Server):
def __init__(
self,
@@ -337,13 +333,13 @@ def __init__(
strategy: Optional[Strategy] = None,
checkpointer: Optional[Union[TorchCheckpointer, Sequence[TorchCheckpointer]]] = None,
metrics_reporter: Optional[MetricsReporter] = None,
intermediate_checkpoint_dir: Optional[Path] = None,
intermediate_server_state_dir: Optional[Path] = None,
server_name: Optional[str] = None,
) -> None:
"""
This is a standard FL server but equipped with the assumption that the parameter exchanger is capable of
hydrating the provided server model fully such that it can be checkpointed. For custom checkpointing
functionality, one need only override _hydrate_model_for_checkpointing. If intermediate_checkpoint_dir
functionality, one need only override _hydrate_model_for_checkpointing. If intermediate_server_state_dir
is not None, performs per round checkpointing.


@@ -366,7 +362,7 @@ def __init__(
sequence to checkpoint based on multiple criteria. Defaults to
None.
metrics_reporter (Optional[MetricsReporter], optional): A metrics reporter instance to record the metrics
intermediate_checkpoint_dir (Path): A directory to store and load checkpoints from for the server
intermediate_server_state_dir (Path): A directory to store and load checkpoints from for the server
during an FL experiment.
server_name (Optional[str]): An optional string name to uniquely identify server.
"""
@@ -379,9 +375,9 @@

self.per_round_checkpointer: Union[None, PerRoundCheckpointer]

if intermediate_checkpoint_dir is not None:
if intermediate_server_state_dir is not None:
self.per_round_checkpointer = PerRoundCheckpointer(
intermediate_checkpoint_dir, Path(f"{self.server_name}.ckpt")
intermediate_server_state_dir, Path(f"{self.server_name}.ckpt")
)
else:
self.per_round_checkpointer = None
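
As a usage reference, constructing this checkpointing server might look roughly like the following. The class name FlServerWithCheckpointing and the exact keyword set are assumptions pieced together from this diff and the call sites in research/picai/fedavg/server.py and the smoke-test server below; imports of the fl4health classes are elided:

from pathlib import Path
from flwr.server.client_manager import SimpleClientManager

server = FlServerWithCheckpointing(                    # class name assumed
    client_manager=SimpleClientManager(),
    model=model,                                       # the global model to hydrate and checkpoint
    parameter_exchanger=FullParameterExchanger(),      # import elided; see research/picai/fedavg/server.py
    strategy=strategy,
    intermediate_server_state_dir=Path("artifacts/server_state"),
    server_name="fl_server",
)
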
@@ -397,7 +393,7 @@ def _hydrate_model_for_checkpointing(self) -> nn.Module:
def fit(self, num_rounds: int, timeout: Optional[float]) -> Tuple[History, float]:
"""
Overrides method in parent class to call custom fit_with_per_round_checkpointing if an
intermediate_checkpoint_dir is provided. Otherwise calls standard fit method.
intermediate_server_state_dir is provided. Otherwise calls standard fit method.

Args:
num_rounds (int): The number of rounds to perform federated learning.
@@ -426,7 +422,7 @@ def fit_with_per_epoch_checkpointing(self, num_rounds: int, timeout: Optional[fl
"""
Runs federated learning for a number of rounds. Heavily based on the fit method from the base
server provided by flower (flwr.server.server.Server) except that it is resilient to pre-emptions.
It accomplishes this by checkpointing the server state each round. In the case of pre-emption,
It accomplishes this by checkpointing the server state each round. In the case of pre-emption,
when the server is restarted it will load from the most recent checkpoint.

Args:
@@ -445,7 +441,7 @@

# if checkpoint exists, update history, server round and model accordingly
if self.per_round_checkpointer.checkpoint_exists():
self.load_checkpoint()
self.load_server_state()
else:
log(INFO, "Initializing server state")
self.parameters = self._get_initial_parameters(server_round=1, timeout=timeout)
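
Between the initialization shown above and the per-round checkpoint save shown below sits a round loop roughly like the following (a condensed sketch; the real method follows flwr's Server.fit closely, and start_round is an illustrative name for the round restored by load_server_state):

for server_round in range(start_round, num_rounds + 1):
    # Standard flwr round of client training followed by evaluation
    self.fit_round(server_round=server_round, timeout=timeout)
    self.evaluate_round(server_round=server_round, timeout=timeout)
    # Checkpoint after training and testing so a pre-emption costs at most one round
    self._hydrate_model_for_checkpointing()
    self.save_server_state()
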
@@ -506,15 +502,15 @@ def fit_with_per_epoch_checkpointing(self, num_rounds: int, timeout: Optional[fl

# Save checkpoint after training and testing
self._hydrate_model_for_checkpointing()
self.save_checkpoint()
self.save_server_state()

# Bookkeeping
end_time = timeit.default_timer()
elapsed_time = end_time - start_time
log(INFO, "FL finished in %s", elapsed_time)
return self.history, elapsed_time

def save_checkpoint(self) -> None:
def save_server_state(self) -> None:
"""
Save server checkpoint consisting of model, history, server round, metrics reporter and server name.
This method can be overridden to add any necessary state to the checkpoint.
@@ -534,7 +530,7 @@ def save_checkpoint(self) -> None:

log(INFO, f"Saving server state to checkpoint at {self.per_round_checkpointer.checkpoint_path}")

def load_checkpoint(self) -> None:
def load_server_state(self) -> None:
"""
Load server checkpoint consisting of model, history, server name, current round and metrics reporter.
The method can be overridden to add any necessary state when loading the checkpoint.
@@ -555,7 +551,7 @@ def load_checkpoint(self) -> None:
self.server_name = ckpt["server_name"]
self.metrics_reporter = ckpt["metrics_reporter"]
self.history = ckpt["history"]
self.parameters = get_initial_model_parameters(ckpt["model"])
self.parameters = get_all_model_parameters(ckpt["model"])
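
The dictionary written by save_server_state is collapsed in this view; based on its docstring and the fields restored here, it presumably looks roughly like the following (the model attribute and the round-counter key are assumptions):

ckpt = {
    "model": self.server_model,            # assumed attribute, hydrated via _hydrate_model_for_checkpointing
    "current_round": self.current_round,   # assumed name for the round counter restored on resume
    "server_name": self.server_name,
    "metrics_reporter": self.metrics_reporter,
    "history": self.history,
}
self.per_round_checkpointer.save_checkpoint(ckpt)
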


class FlServerWithInitializer(FlServer):

fl4health/utils/nnunet_utils.py (2 additions, 0 deletions)
@@ -423,6 +423,8 @@ def __init__(
self._step_count: int
super().__init__(optimizer, -1, False)

# mypy incorrectly infers get_lr returns a float
# Documented issue https://github.com/pytorch/pytorch/issues/100804
@no_type_check
def get_lr(self) -> Sequence[float]:
curr_step = min(self._step_count, self.max_steps)
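
For context, the class annotated here wraps nnU-Net's polynomial LR decay so it behaves like a torch LRScheduler. A minimal sketch of such a wrapper, assuming the usual (1 - step / max_steps) ** 0.9 schedule used by nnU-Net; the real PolyLRSchedulerWrapper in nnunet_utils.py may differ in details:

from typing import Sequence, no_type_check
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

class PolyLRSchedulerWrapper(_LRScheduler):
    def __init__(self, optimizer: Optimizer, initial_lr: float, max_steps: int, exponent: float = 0.9) -> None:
        self.initial_lr = initial_lr
        self.max_steps = max_steps
        self.exponent = exponent
        self._step_count: int
        super().__init__(optimizer, -1, False)

    @no_type_check
    def get_lr(self) -> Sequence[float]:
        # Clamp so the learning rate stops decaying once max_steps is reached
        curr_step = min(self._step_count, self.max_steps)
        new_lr = self.initial_lr * (1 - curr_step / self.max_steps) ** self.exponent
        return [new_lr] * len(self.optimizer.param_groups)
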

research/picai/fedavg/client.py (3 additions, 3 deletions)
@@ -41,7 +41,7 @@ def __init__(
checkpointer: Optional[ClientCheckpointModule] = None,
metrics_reporter: Optional[MetricsReporter] = None,
progress_bar: bool = False,
intermediate_checkpoint_dir: Optional[Path] = None,
intermediate_client_state_dir: Optional[Path] = None,
overviews_dir: Path = Path("./"),
data_partition: Optional[int] = None,
) -> None:
@@ -53,7 +53,7 @@
checkpointer=checkpointer,
metrics_reporter=metrics_reporter,
progress_bar=progress_bar,
intermediate_checkpoint_dir=intermediate_checkpoint_dir,
intermediate_client_state_dir=intermediate_client_state_dir,
)

self.data_partition = data_partition
@@ -155,7 +155,7 @@ def get_optimizer(self, config: Config) -> Optimizer:
data_path=Path(args.base_dir),
metrics=metrics,
device=DEVICE,
intermediate_checkpoint_dir=args.artifact_dir,
intermediate_client_state_dir=args.artifact_dir,
overviews_dir=args.overviews_dir,
data_partition=args.data_partition,
)

research/picai/fedavg/server.py (1 addition, 1 deletion)
@@ -68,7 +68,7 @@ def main(config: Dict[str, Any], server_address: str, n_clients: int) -> None:
model=model,
parameter_exchanger=FullParameterExchanger(),
strategy=strategy,
intermediate_checkpoint_dir=args.artifact_dir,
intermediate_server_state_dir=args.artifact_dir,
)

fl.server.start_server(

tests/smoke_tests/load_from_checkpoint_example/client.py (4 additions, 4 deletions)
@@ -31,7 +31,7 @@ def __init__(
checkpointer: Optional[ClientCheckpointModule] = None,
metrics_reporter: Optional[MetricsReporter] = None,
progress_bar: bool = False,
intermediate_checkpoint_dir: Optional[Path] = None,
intermediate_client_state_dir: Optional[Path] = None,
client_name: Optional[str] = None,
seed: int = 42,
) -> None:
@@ -43,7 +43,7 @@
checkpointer,
metrics_reporter,
progress_bar,
intermediate_checkpoint_dir,
intermediate_client_state_dir,
client_name,
)
self.seed = seed
@@ -76,7 +76,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict
parser = argparse.ArgumentParser(description="FL Client Main")
parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset")
parser.add_argument(
"--intermediate_checkpoint_dir",
"--intermediate_client_state_dir",
action="store",
type=str,
help="Path to intermediate checkpoint directory.",
@@ -107,7 +107,7 @@ def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict
data_path,
[Accuracy("accuracy")],
DEVICE,
intermediate_checkpoint_dir=args.intermediate_checkpoint_dir,
intermediate_client_state_dir=args.intermediate_client_state_dir,
client_name=args.client_name,
seed=args.seed,
)

tests/smoke_tests/load_from_checkpoint_example/server.py (4 additions, 4 deletions)
@@ -32,7 +32,7 @@ def fit_config(
}


def main(config: Dict[str, Any], intermediate_checkpoint_dir: str, server_name: str) -> None:
def main(config: Dict[str, Any], intermediate_server_state_dir: str, server_name: str) -> None:
# This function will be used to produce a config that is sent to each client to initialize their own environment
fit_config_fn = partial(
fit_config,
@@ -71,7 +71,7 @@ def main(config: Dict[str, Any], intermediate_checkpoint_dir: str, server_name:
None,
strategy,
checkpointers,
intermediate_checkpoint_dir=Path(intermediate_checkpoint_dir),
intermediate_server_state_dir=Path(intermediate_server_state_dir),
server_name=server_name,
)

@@ -95,7 +95,7 @@ def main(config: Dict[str, Any], intermediate_checkpoint_dir: str, server_name:
default="tests/smoke_tests/load_from_checkpoint_example/config.yaml",
)
parser.add_argument(
"--intermediate_checkpoint_dir",
"--intermediate_server_state_dir",
action="store",
type=str,
help="Path to intermediate checkpoint directory.",
@@ -120,4 +120,4 @@
# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

main(config, args.intermediate_checkpoint_dir, args.server_name)
main(config, args.intermediate_server_state_dir, args.server_name)

tests/smoke_tests/run_smoke_test.py (3 additions, 3 deletions)
@@ -346,7 +346,7 @@ async def run_fault_tolerance_smoke_test(
server_python_path,
"--config_path",
partial_config_path,
"--intermediate_checkpoint_dir",
"--intermediate_server_state_dir",
intermediate_checkpoint_dir,
"--server_name",
server_name,
@@ -356,7 +356,7 @@
server_python_path,
"--config_path",
config_path,
"--intermediate_checkpoint_dir",
"--intermediate_server_state_dir",
intermediate_checkpoint_dir,
"--server_name",
server_name,
@@ -366,7 +366,7 @@
client_python_path,
"--dataset_path",
dataset_path,
"--intermediate_checkpoint_dir",
"--intermediate_client_state_dir",
intermediate_checkpoint_dir,
]
if seed is not None: