diff --git a/examples/hello-world/job_api/pt/client_api_kmeans.py b/examples/hello-world/job_api/pt/client_api_kmeans.py new file mode 100644 index 0000000000..a8515c94d9 --- /dev/null +++ b/examples/hello-world/job_api/pt/client_api_kmeans.py @@ -0,0 +1,156 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import os + +from src.kmeans_assembler import KMeansAssembler +from src.split_csv import distribute_header_file, split_csv + +from nvflare import FedJob, ScriptExecutor +from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator +from nvflare.app_common.shareablegenerators.full_model_shareable_generator import FullModelShareableGenerator +from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather +from nvflare.app_opt.sklearn.joblib_model_param_persistor import JoblibModelParamPersistor +from nvflare.client.config import ExchangeFormat + + +def split_higgs(input_data_path, input_header_path, output_dir, site_num, sample_rate, site_name_prefix="site-"): + input_file = input_data_path + output_directory = output_dir + num_parts = site_num + site_name_prefix = site_name_prefix + sample_rate = sample_rate + split_csv(input_file, output_directory, num_parts, site_name_prefix, sample_rate) + distribute_header_file(input_header_path, output_directory, num_parts, site_name_prefix) + + +if __name__ == "__main__": + n_clients = 3 + num_rounds = 2 + train_script = "src/kmeans_fl.py" + data_input_dir = "/tmp/nvflare/higgs/data" + data_output_dir = "/tmp/nvflare/higgs/split_data" + + # Download data + os.makedirs(data_input_dir, exist_ok=True) + higgs_zip_file = os.path.join(data_input_dir, "higgs.zip") + if not os.path.exists(higgs_zip_file): + os.system( + f"curl -o {higgs_zip_file} https://archive.ics.uci.edu/static/public/280/higgs.zip" + ) # This might take a while. The file is 2.8 GB. 
+ os.system(f"unzip -d {data_input_dir} {higgs_zip_file}")
+ os.system(
+ f"gunzip -c {os.path.join(data_input_dir, 'HIGGS.csv.gz')} > {os.path.join(data_input_dir, 'higgs.csv')}"
+ )
+
+ # Generate the CSV header file
+ # Column names: the label followed by the feature names
+ features = [
+ "label",
+ "lepton_pt",
+ "lepton_eta",
+ "lepton_phi",
+ "missing_energy_magnitude",
+ "missing_energy_phi",
+ "jet_1_pt",
+ "jet_1_eta",
+ "jet_1_phi",
+ "jet_1_b_tag",
+ "jet_2_pt",
+ "jet_2_eta",
+ "jet_2_phi",
+ "jet_2_b_tag",
+ "jet_3_pt",
+ "jet_3_eta",
+ "jet_3_phi",
+ "jet_3_b_tag",
+ "jet_4_pt",
+ "jet_4_eta",
+ "jet_4_phi",
+ "jet_4_b_tag",
+ "m_jj",
+ "m_jjj",
+ "m_lv",
+ "m_jlv",
+ "m_bb",
+ "m_wbb",
+ "m_wwbb",
+ ]
+
+ # Specify the file path
+ file_path = os.path.join(data_input_dir, "headers.csv")
+
+ with open(file_path, "w", newline="") as file:
+ csv_writer = csv.writer(file)
+ csv_writer.writerow(features)
+
+ print(f"features written to {file_path}")
+
+ # Split the data
+ split_higgs(
+ input_data_path=os.path.join(data_input_dir, "higgs.csv"),
+ input_header_path=os.path.join(data_input_dir, "headers.csv"),
+ output_dir=data_output_dir,
+ site_num=n_clients,
+ sample_rate=0.3,
+ )
+
+ # Create the federated learning job
+ job = FedJob(name="kmeans")
+
+ controller = ScatterAndGather(
+ min_clients=n_clients,
+ num_rounds=num_rounds,
+ wait_time_after_min_received=0,
+ aggregator_id="aggregator",
+ persistor_id="persistor",  # TODO: Allow adding python objects rather than ids
+ shareable_generator_id="shareable_generator",
+ train_task_name="train",  # Clients start training once they receive this task.
+ train_timeout=0,
+ )
+ job.to(controller, "server")
+
+ # For k-means with scikit-learn, we need a custom persistor.
+ # JoblibModelParamPersistor is a persistor that saves/reads the model to/from a file in joblib format.
+ persistor = JoblibModelParamPersistor(initial_params={"n_clusters": 2})
+ # When assigning the persistor to the server, we need to specify the id that's expected
+ # by the ScatterAndGather controller.
+ job.to(persistor, "server")
+
+ # Similarly, ScatterAndGather expects a "shareable_generator", which we also assign to the server.
+ job.to(FullModelShareableGenerator(), "server")
+
+ # ScatterAndGather also expects an "aggregator", which we define here.
+ # The actual aggregation function is defined by an "assembler" that specifies how to handle the collected updates.
+ aggregator = CollectAndAssembleAggregator(
+ assembler_id="kmeans_assembler"
+ )  # TODO: Allow adding KMeansAssembler() directly
+ job.to(aggregator, "server")
+
+ # This is the assembler designed for the k-Means algorithm.
+ # As CollectAndAssembleAggregator expects an assembler_id, we need to specify it here.
+ job.to(KMeansAssembler(), "server", id="kmeans_assembler") + + # Add clients + for i in range(n_clients): + executor = ScriptExecutor( + task_script_path=train_script, + task_script_args=f"--data_root_dir {data_output_dir}", + params_exchange_format=ExchangeFormat.RAW, # kmeans requires raw values only rather than PyTorch Tensors (the default) + ) + job.to(executor, f"site-{i+1}", gpu=0) # HIGGs data splitter assumes site names start from 1 + + job.export_job("/tmp/nvflare/jobs/job_config") + job.simulator_run("/tmp/nvflare/jobs/workdir") diff --git a/examples/hello-world/job_api/pt/client_api_lightning.py b/examples/hello-world/job_api/pt/client_api_lightning.py new file mode 100644 index 0000000000..d6c75415ee --- /dev/null +++ b/examples/hello-world/job_api/pt/client_api_lightning.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from src.lit_net import LitNet + +from nvflare import FedAvg, FedJob, ScriptExecutor + +if __name__ == "__main__": + n_clients = 2 + num_rounds = 2 + train_script = "src/cifar10_lightning_fl.py" + + job = FedJob(name="cifar10_fedavg_lightning") + + # Define the controller workflow and send to server + controller = FedAvg( + min_clients=n_clients, + num_rounds=num_rounds, + ) + job.to(controller, "server") + + # Define the initial global model and send to server + job.to(LitNet(), "server") + + # Add clients + for i in range(n_clients): + executor = ScriptExecutor( + task_script_path=train_script, task_script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}" + ) + job.to(executor, f"site-{i}", gpu=0) + + job.export_job("/tmp/nvflare/jobs/job_config") + job.simulator_run("/tmp/nvflare/jobs/workdir") diff --git a/examples/hello-world/job_api/pt/client_api_pt.py b/examples/hello-world/job_api/pt/client_api_pt.py new file mode 100644 index 0000000000..cddabd7178 --- /dev/null +++ b/examples/hello-world/job_api/pt/client_api_pt.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
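+
+# High-level sketch of the steps this script performs (a reading aid; the runnable code follows):
+#
+#   job = FedJob(name="cifar10_fedavg")             # create the job
+#   job.to(FedAvg(...), "server")                   # server-side workflow
+#   job.to(Net(), "server")                         # initial global model
+#   job.to(ScriptExecutor(...), "site-0", gpu=0)    # client-side training script
+#   job.simulator_run("/tmp/nvflare/jobs/workdir")  # run everything in the FL simulator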
+
+from src.net import Net
+
+from nvflare import FedAvg, FedJob, ScriptExecutor
+
+if __name__ == "__main__":
+ n_clients = 2
+ num_rounds = 2
+ train_script = "src/cifar10_fl.py"
+
+ job = FedJob(name="cifar10_fedavg")
+
+ # Define the controller workflow and send to server
+ controller = FedAvg(
+ min_clients=n_clients,
+ num_rounds=num_rounds,
+ )
+ job.to(controller, "server")
+
+ # Define the initial global model and send to server
+ job.to(Net(), "server")
+
+ # Add clients
+ for i in range(n_clients):
+ executor = ScriptExecutor(
+ task_script_path=train_script, task_script_args=""  # f"--batch_size 32 --data_path /tmp/data/site-{i}"
+ )
+ job.to(executor, f"site-{i}", gpu=0)
+
+ job.export_job("/tmp/nvflare/jobs/job_config")
+ job.simulator_run("/tmp/nvflare/jobs/workdir")
diff --git a/examples/hello-world/job_api/pt/client_api_pt_cyclic_cc.py b/examples/hello-world/job_api/pt/client_api_pt_cyclic_cc.py
new file mode 100644
index 0000000000..79fa92bbf3
--- /dev/null
+++ b/examples/hello-world/job_api/pt/client_api_pt_cyclic_cc.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from src.net import Net
+
+from nvflare import FedJob, ScriptExecutor
+from nvflare.app_common.ccwf import CyclicClientController, CyclicServerController
+from nvflare.app_common.ccwf.comps.simple_model_shareable_generator import SimpleModelShareableGenerator
+from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
+
+if __name__ == "__main__":
+ n_clients = 2
+ num_rounds = 3
+ train_script = "src/cifar10_fl.py"
+
+ job = FedJob(name="cifar10_cyclic")
+
+ controller = CyclicServerController(num_rounds=num_rounds, max_status_report_interval=300)
+ job.to(controller, "server")
+
+ for i in range(n_clients):
+ executor = ScriptExecutor(
+ task_script_path=train_script, task_script_args=""  # f"--batch_size 32 --data_path /tmp/data/site-{i}"
+ )
+ job.to(executor, f"site-{i}", gpu=0)
+
+ # Add client-side controller for cyclic workflow
+ executor = CyclicClientController()
+ job.to(executor, f"site-{i}", tasks=["cyclic_*"])
+
+ # In the cyclic workflow, each client uses a model persistor and shareable_generator
+ job.to(PTFileModelPersistor(model=Net()), f"site-{i}", id="persistor")
+ job.to(SimpleModelShareableGenerator(), f"site-{i}", id="shareable_generator")
+
+ job.export_job("/tmp/nvflare/jobs/job_config")
+ job.simulator_run("/tmp/nvflare/jobs/workdir")
diff --git a/examples/hello-world/job_api/pt/client_api_pt_dp_filter.py b/examples/hello-world/job_api/pt/client_api_pt_dp_filter.py
new file mode 100644
index 0000000000..92d84931cc
--- /dev/null
+++ b/examples/hello-world/job_api/pt/client_api_pt_dp_filter.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from src.net import Net + +from nvflare import FedAvg, FedJob, FilterType, ScriptExecutor +from nvflare.app_common.filters.percentile_privacy import PercentilePrivacy + +if __name__ == "__main__": + n_clients = 2 + num_rounds = 2 + train_script = "src/cifar10_fl.py" + + job = FedJob(name="cifar10_fedavg_privacy") + + # Define the controller workflow and send to server + controller = FedAvg( + min_clients=n_clients, + num_rounds=num_rounds, + ) + job.to(controller, "server") + + # Define the initial global model and send to server + job.to(Net(), "server") + + for i in range(n_clients): + executor = ScriptExecutor(task_script_path=train_script, task_script_args="") + job.to(executor, f"site-{i}", tasks=["train"], gpu=0) + + # add privacy filter. + pp_filter = PercentilePrivacy(percentile=10, gamma=0.01) + job.to(pp_filter, f"site-{i}", tasks=["train"], filter_type=FilterType.TASK_RESULT) + + job.export_job("/tmp/nvflare/jobs/job_config") + job.simulator_run("/tmp/nvflare/jobs/workdir") diff --git a/examples/hello-world/job_api/pt/client_api_pt_swarm.py b/examples/hello-world/job_api/pt/client_api_pt_swarm.py new file mode 100644 index 0000000000..6a00b39514 --- /dev/null +++ b/examples/hello-world/job_api/pt/client_api_pt_swarm.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
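+
+# High-level sketch of the swarm-learning setup built below (the runnable code follows):
+#   server : SwarmServerController + CrossSiteEvalServerController + the initial Net() model
+#   clients: ScriptExecutor for the train/validate/submit_model tasks,
+#            SwarmClientController ("swarm_*" tasks) and CrossSiteEvalClientController ("cse_*" tasks),
+#            plus a local aggregator, persistor, and shareable generator,
+#            since in swarm learning aggregation happens on the clients rather than on the server.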
+ +from src.net import Net + +from nvflare import FedJob, ScriptExecutor +from nvflare.apis.dxo import DataKind +from nvflare.app_common.aggregators.intime_accumulate_model_aggregator import InTimeAccumulateWeightedAggregator +from nvflare.app_common.ccwf import ( + CrossSiteEvalClientController, + CrossSiteEvalServerController, + SwarmClientController, + SwarmServerController, +) +from nvflare.app_common.ccwf.comps.simple_model_shareable_generator import SimpleModelShareableGenerator +from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor + +if __name__ == "__main__": + n_clients = 2 + num_rounds = 3 + train_script = "src/train_eval_submit.py" + + job = FedJob(name="cifar10_swarm") + + controller = SwarmServerController( + num_rounds=num_rounds, + ) + job.to(controller, "server") + controller = CrossSiteEvalServerController(eval_task_timeout=300) + job.to(controller, "server") + + # Define the initial server model + job.to(Net(), "server") + + for i in range(n_clients): + executor = ScriptExecutor(task_script_path=train_script) + job.to(executor, f"site-{i}", gpu=0, tasks=["train", "validate", "submit_model"]) + + client_controller = SwarmClientController() + job.to(client_controller, f"site-{i}", tasks=["swarm_*"]) + + client_controller = CrossSiteEvalClientController() + job.to(client_controller, f"site-{i}", tasks=["cse_*"]) + + # In swarm learning, each client acts also as an aggregator + aggregator = InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS) + job.to(aggregator, f"site-{i}") + + # In swarm learning, each client uses a model persistor and shareable_generator + job.to(PTFileModelPersistor(model=Net()), f"site-{i}") + job.to(SimpleModelShareableGenerator(), f"site-{i}") + + job.export_job("/tmp/nvflare/jobs/job_config") + job.simulator_run("/tmp/nvflare/jobs/workdir") diff --git a/examples/hello-world/job_api/pt/model_learner_xsite_val.py b/examples/hello-world/job_api/pt/model_learner_xsite_val.py new file mode 100644 index 0000000000..b2b700679a --- /dev/null +++ b/examples/hello-world/job_api/pt/model_learner_xsite_val.py @@ -0,0 +1,64 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import sys + +sys.path.insert(0, os.path.join(os.getcwd(), "..", "..", "..", "advanced", "cifar10")) + +from pt.learners.cifar10_model_learner import CIFAR10ModelLearner +from pt.networks.cifar10_nets import ModerateCNN +from pt.utils.cifar10_data_splitter import Cifar10DataSplitter +from pt.utils.cifar10_data_utils import load_cifar10_data + +from nvflare import FedAvg, FedJob +from nvflare.app_common.executors.model_learner_executor import ModelLearnerExecutor +from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval + +if __name__ == "__main__": + n_clients = 2 + num_rounds = 2 + aggregation_epochs = 4 + alpha = 0.1 + train_split_root = f"/tmp/cifar10_splits/clients{n_clients}_alpha{alpha}" # avoid overwriting results + + job = FedJob(name="cifar10_fedavg") + + ctrl1 = FedAvg( + min_clients=n_clients, + num_rounds=num_rounds, + ) + ctrl2 = CrossSiteModelEval() + + load_cifar10_data() # preload CIFAR10 data + data_splitter = Cifar10DataSplitter( + split_dir=train_split_root, + num_sites=n_clients, + alpha=alpha, + ) + + job.to(ctrl1, "server") + job.to(ctrl2, "server") + job.to(data_splitter, "server") + + # Define the initial global model and send to server + job.to(ModerateCNN(), "server") + + for i in range(n_clients): + learner = CIFAR10ModelLearner(train_idx_root=train_split_root, aggregation_epochs=aggregation_epochs, lr=0.01) + executor = ModelLearnerExecutor(learner_id=learner) + job.to(executor, f"site-{i+1}", gpu=0) # data splitter assumes client names start from 1 + + job.export_job("/tmp/nvflare/jobs/job_config") + job.simulator_run("/tmp/nvflare/jobs/workdir") diff --git a/examples/hello-world/job_api/pt/src/__init__.py b/examples/hello-world/job_api/pt/src/__init__.py new file mode 100644 index 0000000000..4fc50543f1 --- /dev/null +++ b/examples/hello-world/job_api/pt/src/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/hello-world/job_api/pt/src/cifar10_fl.py b/examples/hello-world/job_api/pt/src/cifar10_fl.py new file mode 100644 index 0000000000..b32e1bdc32 --- /dev/null +++ b/examples/hello-world/job_api/pt/src/cifar10_fl.py @@ -0,0 +1,136 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
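+
+# The Client API pattern used in this script, in a nutshell (the numbered comments below mark the same steps):
+#
+#   flare.init()                                 # (2) initialize the client API
+#   while flare.is_running():
+#       input_model = flare.receive()            # (3) receive the global FLModel
+#       net.load_state_dict(input_model.params)  # (4) load the global weights
+#       ...local training and evaluation...
+#       flare.send(output_model)                 # (8) send the updated FLModel back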
+ +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms +from src.net import Net + +# (1) import nvflare client API +import nvflare.client as flare + +# (optional) metrics +from nvflare.client.tracking import SummaryWriter + +# (optional) set a fix place so we don't need to download everytime +DATASET_PATH = "/tmp/nvflare/data" +# If available, we use GPU to speed things up. +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +def main(): + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + batch_size = 4 + epochs = 2 + + trainset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=True, download=True, transform=transform) + trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2) + + testset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, download=True, transform=transform) + testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2) + + net = Net() + + # (2) initializes NVFlare client API + flare.init() + + summary_writer = SummaryWriter() + while flare.is_running(): + # (3) receives FLModel from NVFlare + input_model = flare.receive() + print(f"current_round={input_model.current_round}") + + # (4) loads model from NVFlare + net.load_state_dict(input_model.params) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + + # (optional) use GPU to speed things up + net.to(DEVICE) + # (optional) calculate total steps + steps = epochs * len(trainloader) + for epoch in range(epochs): # loop over the dataset multiple times + + running_loss = 0.0 + for i, data in enumerate(trainloader, 0): + # get the inputs; data is a list of [inputs, labels] + # (optional) use GPU to speed things up + inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE) + + # zero the parameter gradients + optimizer.zero_grad() + + # forward + backward + optimize + outputs = net(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + # print statistics + running_loss += loss.item() + if i % 2000 == 1999: # print every 2000 mini-batches + print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}") + global_step = input_model.current_round * steps + epoch * len(trainloader) + i + + summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step) + running_loss = 0.0 + + print("Finished Training") + + PATH = "./cifar_net.pth" + torch.save(net.state_dict(), PATH) + + # (5) wraps evaluation logic into a method to re-use for + # evaluation on both trained and received model + def evaluate(input_weights): + net = Net() + net.load_state_dict(input_weights) + # (optional) use GPU to speed things up + net.to(DEVICE) + + correct = 0 + total = 0 + # since we're not training, we don't need to calculate the gradients for our outputs + with torch.no_grad(): + for data in testloader: + # (optional) use GPU to speed things up + images, labels = data[0].to(DEVICE), data[1].to(DEVICE) + # calculate outputs by running images through the network + outputs = net(images) + # the class with the highest energy is what we choose as prediction + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + + print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %") + return 100 * correct // 
total + + # (6) evaluate on received model for model selection + accuracy = evaluate(input_model.params) + # (7) construct trained FL model + output_model = flare.FLModel( + params=net.cpu().state_dict(), + metrics={"accuracy": accuracy}, + meta={"NUM_STEPS_CURRENT_ROUND": steps}, + ) + # (8) send model back to NVFlare + flare.send(output_model) + + +if __name__ == "__main__": + main() diff --git a/examples/hello-world/job_api/pt/src/cifar10_lightning_fl.py b/examples/hello-world/job_api/pt/src/cifar10_lightning_fl.py new file mode 100644 index 0000000000..9b26db523d --- /dev/null +++ b/examples/hello-world/job_api/pt/src/cifar10_lightning_fl.py @@ -0,0 +1,108 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torchvision +import torchvision.transforms as transforms +from pytorch_lightning import LightningDataModule, Trainer, seed_everything +from src.lit_net import LitNet +from torch.utils.data import DataLoader, random_split + +# (1) import nvflare lightning client API +import nvflare.client.lightning as flare + +seed_everything(7) + + +DATASET_PATH = "/tmp/nvflare/data" +BATCH_SIZE = 4 + +transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) + + +class CIFAR10DataModule(LightningDataModule): + def __init__(self, data_dir: str = DATASET_PATH, batch_size: int = BATCH_SIZE): + super().__init__() + self.data_dir = data_dir + self.batch_size = batch_size + + def prepare_data(self): + torchvision.datasets.CIFAR10(root=self.data_dir, train=True, download=True, transform=transform) + torchvision.datasets.CIFAR10(root=self.data_dir, train=False, download=True, transform=transform) + + def setup(self, stage: str): + # Assign train/val datasets for use in dataloaders + if stage == "fit" or stage == "validate": + cifar_full = torchvision.datasets.CIFAR10( + root=self.data_dir, train=True, download=False, transform=transform + ) + self.cifar_train, self.cifar_val = random_split(cifar_full, [0.8, 0.2]) + + # Assign test dataset for use in dataloader(s) + if stage == "test" or stage == "predict": + self.cifar_test = torchvision.datasets.CIFAR10( + root=self.data_dir, train=False, download=False, transform=transform + ) + + def train_dataloader(self): + return DataLoader(self.cifar_train, batch_size=self.batch_size) + + def val_dataloader(self): + return DataLoader(self.cifar_val, batch_size=self.batch_size) + + def test_dataloader(self): + return DataLoader(self.cifar_test, batch_size=self.batch_size) + + def predict_dataloader(self): + return DataLoader(self.cifar_test, batch_size=self.batch_size) + + +def main(): + model = LitNet() + cifar10_dm = CIFAR10DataModule() + if torch.cuda.is_available(): + trainer = Trainer(max_epochs=1, accelerator="gpu", devices=1 if torch.cuda.is_available() else None) + else: + trainer = Trainer(max_epochs=1, devices=None) + + # (2) patch the lightning trainer + flare.patch(trainer) + + while flare.is_running(): + # (3) receives 
FLModel from NVFlare + # Note that we don't need to pass this input_model to trainer + # because after flare.patch the trainer.fit/validate will get the + # global model internally + input_model = flare.receive() + print(f"\n[Current Round={input_model.current_round}, Site = {flare.get_site_name()}]\n") + + # (4) evaluate the current global model to allow server-side model selection + print("--- validate global model ---") + trainer.validate(model, datamodule=cifar10_dm) + + # perform local training starting with the received global model + print("--- train new model ---") + trainer.fit(model, datamodule=cifar10_dm) + + # test local model + print("--- test new model ---") + trainer.test(ckpt_path="best", datamodule=cifar10_dm) + + # get predictions + print("--- prediction with new best model ---") + trainer.predict(ckpt_path="best", datamodule=cifar10_dm) + + +if __name__ == "__main__": + main() diff --git a/examples/hello-world/job_api/pt/src/kmeans_assembler.py b/examples/hello-world/job_api/pt/src/kmeans_assembler.py new file mode 100644 index 0000000000..23e6fdc62e --- /dev/null +++ b/examples/hello-world/job_api/pt/src/kmeans_assembler.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
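+
+# Sketch of the assembler contract implemented below (as wired up by CollectAndAssembleAggregator in the job script):
+#   get_model_params(dxo) -> {"center": ndarray[n_clusters, n_features], "count": ndarray[n_clusters]}
+#       extracts the per-client payload from each submitted DXO
+#   assemble(data, fl_ctx) -> DXO
+#       round 0: runs kmeans++ over the collected client centers to pick the initial global centers
+#       later rounds: mini-batch k-means style update, weighting each client's centers by its counts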
+ +from typing import Dict + +import numpy as np +from sklearn.cluster import KMeans + +from nvflare.apis.dxo import DXO, DataKind +from nvflare.apis.fl_context import FLContext +from nvflare.app_common.aggregators.assembler import Assembler +from nvflare.app_common.app_constant import AppConstants + + +class KMeansAssembler(Assembler): + def __init__(self): + super().__init__(data_kind=DataKind.WEIGHTS) + # Aggregator needs to keep record of historical + # center and count information for mini-batch kmeans + self.center = None + self.count = None + self.n_cluster = 0 + + def get_model_params(self, dxo: DXO): + data = dxo.data + return {"center": data["center"], "count": data["count"]} + + def assemble(self, data: Dict[str, dict], fl_ctx: FLContext) -> DXO: + current_round = fl_ctx.get_prop(AppConstants.CURRENT_ROUND) + if current_round == 0: + # First round, collect the information regarding n_feature and n_cluster + # Initialize the aggregated center and count to all zero + client_0 = list(self.collection.keys())[0] + self.n_cluster = self.collection[client_0]["center"].shape[0] + n_feature = self.collection[client_0]["center"].shape[1] + self.center = np.zeros([self.n_cluster, n_feature]) + self.count = np.zeros([self.n_cluster]) + # perform one round of KMeans over the submitted centers + # to be used as the original center points + # no count for this round + center_collect = [] + for _, record in self.collection.items(): + center_collect.append(record["center"]) + centers = np.concatenate(center_collect) + kmeans_center_initial = KMeans(n_clusters=self.n_cluster) + kmeans_center_initial.fit(centers) + self.center = kmeans_center_initial.cluster_centers_ + else: + # Mini-batch k-Means step to assemble the received centers + for center_idx in range(self.n_cluster): + centers_global_rescale = self.center[center_idx] * self.count[center_idx] + # Aggregate center, add new center to previous estimate, weighted by counts + for _, record in self.collection.items(): + centers_global_rescale += record["center"][center_idx] * record["count"][center_idx] + self.count[center_idx] += record["count"][center_idx] + # Rescale to compute mean of all points (old and new combined) + alpha = 1 / self.count[center_idx] + centers_global_rescale *= alpha + # Update the global center + self.center[center_idx] = centers_global_rescale + params = {"center": self.center} + dxo = DXO(data_kind=self.expected_data_kind, data=params) + + return dxo diff --git a/examples/hello-world/job_api/pt/src/kmeans_fl.py b/examples/hello-world/job_api/pt/src/kmeans_fl.py new file mode 100644 index 0000000000..8df5acab2a --- /dev/null +++ b/examples/hello-world/job_api/pt/src/kmeans_fl.py @@ -0,0 +1,182 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import csv +from typing import Dict, List, Tuple + +import pandas as pd +from sklearn.cluster import KMeans, MiniBatchKMeans, kmeans_plusplus +from sklearn.metrics import homogeneity_score +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler + +# (1) import nvflare client API +from nvflare import client as flare + + +def to_dataset_tuple(data: dict): + dataset_tuples = {} + for dataset_name, dataset in data.items(): + dataset_tuples[dataset_name] = _to_data_tuple(dataset) + return dataset_tuples + + +def _to_data_tuple(data): + data_num = data.shape[0] + # split to feature and label + x = data.iloc[:, 1:] + y = data.iloc[:, 0] + return x.to_numpy(), y.to_numpy(), data_num + + +def load_features(feature_data_path: str) -> List: + try: + features = [] + with open(feature_data_path, "r") as file: + # Create a CSV reader object + csv_reader = csv.reader(file) + line_list = next(csv_reader) + features = line_list + return features + except Exception as e: + raise Exception(f"Load header for path'{feature_data_path} failed! {e}") + + +def load_data( + data_path: str, data_features: List, random_state: int, test_size: float, skip_rows=None +) -> Dict[str, pd.DataFrame]: + try: + df: pd.DataFrame = pd.read_csv( + data_path, names=data_features, sep=r"\s*,\s*", engine="python", na_values="?", skiprows=skip_rows + ) + + train, test = train_test_split(df, test_size=test_size, random_state=random_state) + + return {"train": train, "test": test} + + except Exception as e: + raise Exception(f"Load data for path '{data_path}' failed! {e}") + + +def transform_data(data: Dict[str, Tuple]) -> Dict[str, Tuple]: + # Standardize features by removing the mean and scaling to unit variance + scaler = StandardScaler() + scaled_datasets = {} + for dataset_name, (x_data, y_data, data_num) in data.items(): + x_scaled = scaler.fit_transform(x_data) + scaled_datasets[dataset_name] = (x_scaled, y_data, data_num) + return scaled_datasets + + +def main(): + parser = define_args_parser() + args = parser.parse_args() + data_root_dir = args.data_root_dir + random_state = args.random_state + test_size = args.test_size + skip_rows = args.skip_rows + + # (2) initializes NVFlare client API + flare.init() + + site_name = flare.get_site_name() + feature_data_path = f"{data_root_dir}/{site_name}_header.csv" + features = load_features(feature_data_path) + n_features = len(features) - 1 # remove label + + data_path = f"{data_root_dir}/{site_name}.csv" + data = load_data( + data_path=data_path, data_features=features, random_state=random_state, test_size=test_size, skip_rows=skip_rows + ) + + data = to_dataset_tuple(data) + dataset = transform_data(data) + x_train, y_train, train_size = dataset["train"] + x_test, y_test, test_size = dataset["test"] + + model = None + n_clusters = 0 + while flare.is_running(): + # (3) receives FLModel from NVFlare + input_model = flare.receive() + global_params = input_model.params + curr_round = input_model.current_round + + print(f"current_round={curr_round}") + if curr_round == 0: + # (4) first round, initialize centers with kmeans++ + n_clusters = global_params["n_clusters"] + center_local, _ = kmeans_plusplus(x_train, n_clusters=n_clusters, random_state=random_state) + params = {"center": center_local, "count": None} + homo = 0.0 + else: + # (5) following rounds, starting from global centers + center_global = global_params["center"] + model = MiniBatchKMeans( + n_clusters=n_clusters, + batch_size=train_size, + max_iter=1, + 
init=center_global,
+ n_init=1,
+ reassignment_ratio=0,
+ random_state=random_state,
+ )
+ # train model
+ model.fit(x_train)
+ center_local = model.cluster_centers_
+ count_local = model._counts
+ params = {"center": center_local, "count": count_local}
+
+ # (6) evaluate global center
+ model_eval = KMeans(n_clusters=n_clusters, init=center_global, n_init=1)
+ model_eval.fit(center_global)
+ homo = evaluate_model(x_test, model_eval, y_test)
+ # Print the results
+ print(f"{site_name}: global model homogeneity_score: {homo:.4f}")
+
+ # (7) construct trained FL model
+ metrics = {"accuracy": homo}
+ output_model = flare.FLModel(params=params, metrics=metrics)
+
+ # (8) send model back to NVFlare
+ flare.send(output_model)
+
+
+def evaluate_model(x_test, model, y_test):
+ # Make predictions on the testing set
+ y_pred = model.predict(x_test)
+
+ # Evaluate the model
+ homo = homogeneity_score(y_test, y_pred)
+ return homo
+
+
+def define_args_parser():
+ parser = argparse.ArgumentParser(description="federated k-means with scikit-learn")
+ parser.add_argument("--data_root_dir", type=str, help="root directory path to the csv data files")
+ parser.add_argument("--random_state", type=int, default=0, help="random state")
+ parser.add_argument("--test_size", type=float, default=0.2, help="test set ratio, default 0.2 (20%)")
+ parser.add_argument(
+ "--skip_rows",
+ type=str,
+ default=None,
+ help="""If skip_rows=N, the first N rows will be skipped;
+ if skip_rows=[0, 1, 4], the rows at indices 0, 1, and 4 will be skipped.""",
+ )
+ return parser
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/hello-world/job_api/pt/src/lit_net.py b/examples/hello-world/job_api/pt/src/lit_net.py
new file mode 100644
index 0000000000..01df5eedde
--- /dev/null
+++ b/examples/hello-world/job_api/pt/src/lit_net.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from typing import Any + +import torch.nn as nn +import torch.optim as optim +from pytorch_lightning import LightningModule +from src.net import Net +from torchmetrics import Accuracy + +NUM_CLASSES = 10 +criterion = nn.CrossEntropyLoss() + + +class LitNet(LightningModule): + def __init__(self): + super().__init__() + self.save_hyperparameters() + self.model = Net() + self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES) + self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES) + # (optional) pass additional information via self.__fl_meta__ + self.__fl_meta__ = {} + + def forward(self, x): + out = self.model(x) + return out + + def training_step(self, batch, batch_idx): + x, labels = batch + outputs = self(x) + loss = criterion(outputs, labels) + self.train_acc(outputs, labels) + self.log("train_loss", loss) + self.log("train_acc", self.train_acc, on_step=True, on_epoch=False) + return loss + + def evaluate(self, batch, stage=None): + x, labels = batch + outputs = self(x) + loss = criterion(outputs, labels) + self.valid_acc(outputs, labels) + + if stage: + self.log(f"{stage}_loss", loss) + self.log(f"{stage}_acc", self.valid_acc, on_step=True, on_epoch=True) + return outputs + + def validation_step(self, batch, batch_idx): + self.evaluate(batch, "val") + + def test_step(self, batch, batch_idx): + self.evaluate(batch, "test") + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + return self.evaluate(batch) + + def configure_optimizers(self): + optimizer = optim.SGD(self.parameters(), lr=0.001, momentum=0.9) + return {"optimizer": optimizer} diff --git a/examples/hello-world/job_api/pt/src/net.py b/examples/hello-world/job_api/pt/src/net.py new file mode 100644 index 0000000000..031f84f432 --- /dev/null +++ b/examples/hello-world/job_api/pt/src/net.py @@ -0,0 +1,37 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.pool(F.relu(self.conv1(x))) + x = self.pool(F.relu(self.conv2(x))) + x = torch.flatten(x, 1) # flatten all dimensions except batch + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x diff --git a/examples/hello-world/job_api/pt/src/split_csv.py b/examples/hello-world/job_api/pt/src/split_csv.py new file mode 100644 index 0000000000..c6eab5992d --- /dev/null +++ b/examples/hello-world/job_api/pt/src/split_csv.py @@ -0,0 +1,86 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os +import shutil + +import pandas as pd + + +def load_data(input_file_path) -> pd.DataFrame: + # Read the CSV file into a pandas DataFrame + return pd.read_csv(input_file_path, header=None) + + +def split_csv(input_file_path, output_dir, num_parts, part_name, sample_rate): + df = load_data(input_file_path) + + # Calculate the number of rows per part + total_size = int(len(df) * sample_rate) + rows_per_part = total_size // num_parts + + # Create the output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Split the DataFrame into N parts + for i in range(num_parts): + start_index = i * rows_per_part + end_index = (i + 1) * rows_per_part if i < num_parts - 1 else total_size + print(f"{part_name}{i + 1}=", f"{start_index=}", f"{end_index=}") + part_df = df.iloc[start_index:end_index] + + # Save each part to a separate CSV file + output_file = os.path.join(output_dir, f"{part_name}{i + 1}.csv") + part_df.to_csv(output_file, header=False, index=False) + + +def distribute_header_file(input_header_file: str, output_dir: str, num_parts: int, part_name: str): + source_file = input_header_file + + # Split the DataFrame into N parts + for i in range(num_parts): + output_file = os.path.join(output_dir, f"{part_name}{i + 1}_header.csv") + shutil.copy(source_file, output_file) + print(f"File copied to {output_file}") + + +def define_args_parser(): + parser = argparse.ArgumentParser(description="csv data split") + parser.add_argument("--input_data_path", type=str, help="input path to csv data file") + parser.add_argument("--input_header_path", type=str, help="input path to csv header file") + parser.add_argument("--site_num", type=int, help="Total number of sites or clients") + parser.add_argument("--site_name_prefix", type=str, default="site-", help="Site name prefix") + parser.add_argument("--output_dir", type=str, default="/tmp/nvflare/dataset/output", help="Output directory") + parser.add_argument( + "--sample_rate", type=float, default="1.0", help="percent of the data will be used. default 1.0 for 100%" + ) + return parser + + +def main(): + parser = define_args_parser() + args = parser.parse_args() + input_file = args.input_data_path + output_directory = args.output_dir + num_parts = args.site_num + site_name_prefix = args.site_name_prefix + sample_rate = args.sample_rate + split_csv(input_file, output_directory, num_parts, site_name_prefix, sample_rate) + distribute_header_file(args.input_header_path, output_directory, num_parts, site_name_prefix) + + +if __name__ == "__main__": + main() diff --git a/examples/hello-world/job_api/pt/src/train_eval_submit.py b/examples/hello-world/job_api/pt/src/train_eval_submit.py new file mode 100644 index 0000000000..11d72024cb --- /dev/null +++ b/examples/hello-world/job_api/pt/src/train_eval_submit.py @@ -0,0 +1,188 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torchvision
+import torchvision.transforms as transforms
+from src.net import Net
+
+# (1) import nvflare client API
+import nvflare.client as flare
+from nvflare.app_common.app_constant import ModelName
+
+# (optional) set a fixed location so we don't need to download the data every time
+CIFAR10_ROOT = "/tmp/nvflare/data/cifar10"
+# (optional) we use the GPU to speed things up.
+# if you want to use the CPU, change DEVICE to "cpu"
+DEVICE = "cuda:0"
+
+
+def define_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--dataset_path", type=str, default=CIFAR10_ROOT, nargs="?")
+ parser.add_argument("--batch_size", type=int, default=4, nargs="?")
+ parser.add_argument("--num_workers", type=int, default=1, nargs="?")
+ parser.add_argument("--local_epochs", type=int, default=2, nargs="?")
+ parser.add_argument("--model_path", type=str, default=f"{CIFAR10_ROOT}/cifar_net.pth", nargs="?")
+ return parser.parse_args()
+
+
+def main():
+ # define local parameters
+ args = define_parser()
+
+ dataset_path = args.dataset_path
+ batch_size = args.batch_size
+ num_workers = args.num_workers
+ local_epochs = args.local_epochs
+ model_path = args.model_path
+
+ transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
+ trainset = torchvision.datasets.CIFAR10(root=dataset_path, train=True, download=True, transform=transform)
+ trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+ testset = torchvision.datasets.CIFAR10(root=dataset_path, train=False, download=True, transform=transform)
+ testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+ net = Net()
+ best_accuracy = 0.0
+
+ # wraps the evaluation logic into a method so it can be reused to
+ # evaluate both the locally trained model and the received model
+ def evaluate(input_weights):
+ net = Net()
+ net.load_state_dict(input_weights)
+ # (optional) use GPU to speed things up
+ net.to(DEVICE)
+
+ correct = 0
+ total = 0
+ # since we're not training, we don't need to calculate the gradients for our outputs
+ with torch.no_grad():
+ for data in testloader:
+ # (optional) use GPU to speed things up
+ images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
+ # calculate outputs by running images through the network
+ outputs = net(images)
+ # the class with the highest energy is what we choose as prediction
+ _, predicted = torch.max(outputs.data, 1)
+ total += labels.size(0)
+ correct += (predicted == labels).sum().item()
+
+ return 100 * correct // total
+
+ # (2) initialize NVFlare client API
+ flare.init()
+
+ # (3) run continuously when launch_once=true
+ while flare.is_running():
+
+ # (4) receive FLModel from NVFlare
+ input_model = flare.receive()
+ client_id = flare.get_site_name()
+
+ # Based on the received "task" we will do different things:
+ # for the "train" task (flare.is_train()) we use the received model to do training and/or evaluation
+ # and send back the updated model and/or evaluation metrics,
if the "train_with_evaluation" is specified as True + # in the config_fed_client we will need to do evaluation and include the evaluation metrics + # for "evaluate" task (flare.is_evaluate()) we use the received model to do evaluation + # and send back the evaluation metrics + # for "submit_model" task (flare.is_submit_model()) we just need to send back the local model + # (5) performing train task on received model + if flare.is_train(): + print(f"({client_id}) current_round={input_model.current_round}, total_rounds={input_model.total_rounds}") + + # (5.1) loads model from NVFlare + net.load_state_dict(input_model.params) + + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) + + # (optional) use GPU to speed things up + net.to(DEVICE) + # (optional) calculate total steps + steps = local_epochs * len(trainloader) + for epoch in range(local_epochs): # loop over the dataset multiple times + + running_loss = 0.0 + for i, data in enumerate(trainloader, 0): + # get the inputs; data is a list of [inputs, labels] + # (optional) use GPU to speed things up + inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE) + + # zero the parameter gradients + optimizer.zero_grad() + + # forward + backward + optimize + outputs = net(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + # print statistics + running_loss += loss.item() + if i % 2000 == 1999: # print every 2000 mini-batches + print(f"({client_id}) [{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}") + running_loss = 0.0 + + print(f"({client_id}) Finished Training") + + # (5.2) evaluation on local trained model to save best model + local_accuracy = evaluate(net.state_dict()) + print(f"({client_id}) Evaluating local trained model. Accuracy on the 10000 test images: {local_accuracy}") + if local_accuracy > best_accuracy: + best_accuracy = local_accuracy + torch.save(net.state_dict(), model_path) + + # (5.3) evaluate on received model for model selection + accuracy = evaluate(input_model.params) + print( + f"({client_id}) Evaluating received model for model selection. 
Accuracy on the 10000 test images: {accuracy}" + ) + + # (5.4) construct trained FL model + output_model = flare.FLModel( + params=net.cpu().state_dict(), + metrics={"accuracy": accuracy}, + meta={"NUM_STEPS_CURRENT_ROUND": steps}, + ) + + # (5.5) send model back to NVFlare + flare.send(output_model) + + # (6) performing evaluate task on received model + elif flare.is_evaluate(): + accuracy = evaluate(input_model.params) + flare.send(flare.FLModel(metrics={"accuracy": accuracy})) + + # (7) performing submit_model task to obtain best local model + elif flare.is_submit_model(): + model_name = input_model.meta["submit_model_name"] + if model_name == ModelName.BEST_MODEL: + try: + weights = torch.load(model_path) + net = Net() + net.load_state_dict(weights) + flare.send(flare.FLModel(params=net.cpu().state_dict())) + except Exception as e: + raise ValueError("Unable to load best model") from e + else: + raise ValueError(f"Unknown model_type: {model_name}") + + +if __name__ == "__main__": + main() diff --git a/nvflare/__init__.py b/nvflare/__init__.py index 1ecdb822bc..1ca104a5aa 100644 --- a/nvflare/__init__.py +++ b/nvflare/__init__.py @@ -19,3 +19,7 @@ # https://github.com/microsoft/pylance-release/issues/856 from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner as SimulatorRunner +from nvflare.app_common.executors.script_executor import ScriptExecutor +from nvflare.app_common.workflows.fedavg import FedAvg +from nvflare.fed_job import FedJob, FilterType + diff --git a/nvflare/app_common/executors/model_learner_executor.py b/nvflare/app_common/executors/model_learner_executor.py index f22bcd0a1c..13a088c1bf 100644 --- a/nvflare/app_common/executors/model_learner_executor.py +++ b/nvflare/app_common/executors/model_learner_executor.py @@ -84,7 +84,11 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): def _create_learner(self, fl_ctx: FLContext): engine = fl_ctx.get_engine() - self.learner = engine.get_component(self.learner_id) + if isinstance(self.learner_id, str): + self.learner = engine.get_component(self.learner_id) + else: + self.learner = self.learner_id + if self.learner: self.learner_name = self.learner.__class__.__name__ diff --git a/nvflare/app_common/executors/script_executor.py b/nvflare/app_common/executors/script_executor.py new file mode 100644 index 0000000000..02105953c9 --- /dev/null +++ b/nvflare/app_common/executors/script_executor.py @@ -0,0 +1,72 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
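+
+# Typical usage, following the examples under examples/hello-world/job_api/pt (a minimal sketch):
+#
+#   executor = ScriptExecutor(task_script_path="src/cifar10_fl.py", task_script_args="")
+#   job.to(executor, "site-0", gpu=0)
+#
+# ScriptExecutor wraps InProcessClientAPIExecutor, defaulting to ExchangeFormat.PYTORCH and
+# registering the PyTorch tensor decomposer and params converters when that format is used.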
+ +from typing import Optional + +from nvflare.app_common.app_constant import AppConstants +from nvflare.app_common.executors.in_process_client_api_executor import InProcessClientAPIExecutor +from nvflare.app_opt.pt.decomposers import TensorDecomposer +from nvflare.app_opt.pt.params_converter import NumpyToPTParamsConverter, PTToNumpyParamsConverter +from nvflare.client.config import ExchangeFormat, TransferType +from nvflare.fuel.utils import fobs + + +class ScriptExecutor(InProcessClientAPIExecutor): + def __init__( + self, + task_script_path: str, + task_script_args: str = "", + task_wait_time: Optional[float] = None, + result_pull_interval: float = 0.5, + log_pull_interval: Optional[float] = None, + params_transfer_type: TransferType = TransferType.FULL, + from_nvflare_converter_id: Optional[str] = None, + to_nvflare_converter_id: Optional[str] = None, + train_with_evaluation: bool = True, + train_task_name: str = "train", + evaluate_task_name: str = "evaluate", + submit_model_task_name: str = "submit_model", + params_exchange_format=ExchangeFormat.PYTORCH, + ): + """Wrapper around InProcessClientAPIExecutor for different params_exchange_format. Currently defaulting to `params_exchange_format=ExchangeFormat.PYTORCH`. + + Args: + """ + super(ScriptExecutor, self).__init__( + task_script_path=task_script_path, + task_script_args=task_script_args, + task_wait_time=task_wait_time, + result_pull_interval=result_pull_interval, + train_with_evaluation=train_with_evaluation, + train_task_name=train_task_name, + evaluate_task_name=evaluate_task_name, + submit_model_task_name=submit_model_task_name, + from_nvflare_converter_id=from_nvflare_converter_id, + to_nvflare_converter_id=to_nvflare_converter_id, + params_exchange_format=params_exchange_format, + params_transfer_type=params_transfer_type, + log_pull_interval=log_pull_interval, + ) + if params_exchange_format == ExchangeFormat.PYTORCH: + fobs.register(TensorDecomposer) + + if self._from_nvflare_converter is None: + self._from_nvflare_converter = NumpyToPTParamsConverter( + [AppConstants.TASK_TRAIN, AppConstants.TASK_VALIDATION] + ) + if self._to_nvflare_converter is None: + self._to_nvflare_converter = PTToNumpyParamsConverter( + [AppConstants.TASK_TRAIN, AppConstants.TASK_SUBMIT_MODEL] + ) + # TODO: support other params_exchange_format diff --git a/nvflare/fed_job.py b/nvflare/fed_job.py new file mode 100644 index 0000000000..4581805e06 --- /dev/null +++ b/nvflare/fed_job.py @@ -0,0 +1,274 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List + +import torch.nn as nn # TODO: How to handle pytorch dependency? 
+
+from nvflare.apis.executor import Executor
+from nvflare.apis.filter import Filter
+from nvflare.apis.impl.controller import Controller
+from nvflare.app_common.abstract.aggregator import Aggregator
+from nvflare.app_common.abstract.learnable_persistor import LearnablePersistor
+from nvflare.app_common.abstract.shareable_generator import ShareableGenerator
+from nvflare.app_common.executors.script_executor import ScriptExecutor
+from nvflare.app_common.widgets.external_configurator import ExternalConfigurator
+from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
+from nvflare.app_common.widgets.metric_relay import MetricRelay
+from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
+from nvflare.app_opt.pt import PTFileModelPersistor
+from nvflare.app_opt.pt.file_model_locator import PTFileModelLocator
+from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
+from nvflare.fuel.utils.constants import Mode
+from nvflare.fuel.utils.pipe.file_pipe import FilePipe
+from nvflare.job_config.fed_app_config import ClientAppConfig, FedAppConfig, ServerAppConfig
+from nvflare.job_config.fed_job_config import FedJobConfig
+
+
+class FilterType:
+    TASK_RESULT = "_TASK_RESULT_FILTER_TYPE_"
+    TASK_DATA = "_TASK_DATA_FILTER_TYPE_"
+
+
+class FedApp:
+    def __init__(self):
+        """FedApp handles `ClientAppConfig` and `ServerAppConfig` and allows setting task result or task data filters."""
+        self.app = None  # Union[ClientAppConfig, ServerAppConfig]
+        self._used_ids = []
+
+    def get_app_config(self):
+        return self.app
+
+    def add_task_result_filter(self, tasks: List[str], task_filter: Filter):
+        self.app.add_task_result_filter(tasks, task_filter)
+
+    def add_task_data_filter(self, tasks: List[str], task_filter: Filter):
+        self.app.add_task_data_filter(tasks, task_filter)
+
+    def add_component(self, component, id=None):
+        if id is None:
+            id = "component"
+        self.app.add_component(self._check_id(id), component)
+
+    def set_persistor(self, model: nn.Module):  # TODO: support other persistors
+        component = PTFileModelPersistor(model=model)
+        self.app.add_component("persistor", component)
+
+        component = PTFileModelLocator(pt_persistor_id="persistor")
+        self.app.add_component("model_locator", component)
+
+    def _check_id(self, id: str = "") -> str:
+        """Return `id` unchanged if it is unused; otherwise append an increasing numeric suffix until it is unique."""
+        if id not in self._used_ids:
+            self._used_ids.append(id)
+        else:
+            cnt = 1
+            _id = f"{id}_{cnt}"
+            while _id in self._used_ids:
+                cnt += 1
+                _id = f"{id}_{cnt}"
+            id = _id
+            self._used_ids.append(id)
+        return id
+
+
+class FedJob:
+    def __init__(self, name="fed_job", min_clients=1, mandatory_clients=None, key_metric="accuracy") -> None:
+        """FedJob allows users to generate job configurations in a Pythonic way.
+        The `to()` routine allows users to send different components to either the server or clients.
+
+        Args:
+            name: the name of the NVFlare job
+            min_clients: the minimum number of clients for the job
+            mandatory_clients: mandatory clients to run the job (optional)
+            key_metric: Metric used to determine if the model is globally best.
+                If metrics are a `dict`, `key_metric` selects the metric used for global model selection.
+                Defaults to "accuracy".
+ """ + self.job_name = name + self.key_metric = key_metric + self.clients = [] + self.job: FedJobConfig = FedJobConfig( + job_name=self.job_name, min_clients=min_clients, mandatory_clients=mandatory_clients + ) + self._deploy_map = {} + self._deployed = False + self._gpus = {} + + def to( + self, obj: Any, target: str, tasks: List[str] = None, gpu: int = None, filter_type: FilterType = None, id=None + ): + """assign an `obj` to a target (server or clients). + The obj will be given a default `id` if non is provided based on its type. + + Returns: + + """ + if isinstance(obj, Controller): + if target not in self._deploy_map: + self._deploy_map[target] = ControllerApp(key_metric=self.key_metric) + self._deploy_map[target].add_controller(obj, id) + elif isinstance(obj, Executor): + if target not in self._deploy_map: + if isinstance(obj, ScriptExecutor): + external_scripts = [obj._task_script_path] + else: + external_scripts = None + self._deploy_map[target] = ExecutorApp(external_scripts=external_scripts) + self.clients.append(target) + if gpu is not None: + if target not in self._gpus: # GPU can only be selected once per client. + self._gpus[target] = str(gpu) + else: + print(f"{target} already set to use GPU {self._gpus[target]}. Ignoring gpu={gpu}.") + self._deploy_map[target].add_executor(obj, tasks=tasks) + else: # handle objects that are not Controller or Executor type + if target not in self._deploy_map: + raise ValueError( + f"{target} doesn't have a `Controller` or `Executor`. Deploy one first before adding components!" + ) + + if isinstance(obj, nn.Module): # if model, set a persistor + self._deploy_map[target].set_persistor(obj) + elif isinstance(obj, Filter): # handle filters + if filter_type == FilterType.TASK_RESULT: + self._deploy_map[target].add_task_result_filter(tasks, obj) + elif filter_type == FilterType.TASK_DATA: + self._deploy_map[target].add_task_data_filter(tasks, obj) + else: + raise ValueError( + f"Provided a filter for {target} without specifying valid `filter_type`. Select from `FilterType.TASK_RESULT` or `FilterType.TASK_DATA`." 
+ ) + else: # handle other types + if id is None: # handle built-in types and set ids + if isinstance(obj, Aggregator): + id = "aggregator" + elif isinstance(obj, LearnablePersistor): + id = "persistor" + elif isinstance(obj, ShareableGenerator): + id = "shareable_generator" + self._deploy_map[target].add_component(obj, id) + + def _deploy(self, app: FedApp, target: str): + if not isinstance(app, FedApp): + raise ValueError(f"App needs to be of type `FedApp` but was type {type(app)}") + + client_server_config = app.get_app_config() + if isinstance(client_server_config, ClientAppConfig): + app_config = FedAppConfig(server_app=None, client_app=client_server_config) + app_name = f"app_{target}" + elif isinstance(client_server_config, ServerAppConfig): + app_config = FedAppConfig(server_app=client_server_config, client_app=None) + app_name = "app_server" + else: + raise ValueError( + f"App needs to be of type `ClientAppConfig` or `ServerAppConfig` but was type {type(client_server_config)}" + ) + + self.job.add_fed_app(app_name, app_config) + self.job.set_site_app(target, app_name) + + def _run_deploy(self): + if not self._deployed: + for target in self._deploy_map: + self._deploy(self._deploy_map[target], target) + + self._deployed = True + + def export_job(self, job_root): + self._run_deploy() + self.job.generate_job_config(job_root) + + def simulator_run(self, workspace, threads: int = None): + self._run_deploy() + + n_clients = len(self.clients) + if threads is None: + threads = n_clients + + self.job.simulator_run( + workspace, + clients=",".join(self.clients), + n_clients=n_clients, + threads=threads, + gpu=",".join([self._gpus[client] for client in self._gpus.keys()]), + ) + + +class ExecutorApp(FedApp): + def __init__(self, external_scripts: List = None): + """Wrapper around `ClientAppConfig`. + + Args: + external_scripts: List of external scripts that need to be deployed to the client. Defaults to None. + """ + super().__init__() + self.external_scripts = external_scripts + self._create_client_app() + + def add_executor(self, executor, tasks=None): + if tasks is None: + tasks = ["*"] # Add executor for any task by default + self.app.add_executor(tasks, executor) + + def _create_client_app(self): + self.app = ClientAppConfig() + + component = FilePipe( # TODO: support CellPipe, causes type error for passing secure_mode = "{SECURE_MODE}" + mode=Mode.PASSIVE, + root_path="{WORKSPACE}/{JOB_ID}/{SITE_NAME}", + ) + self.app.add_component("metrics_pipe", component) + + component = MetricRelay(pipe_id="metrics_pipe", event_type="fed.analytix_log_stats", read_interval=0.1) + self.app.add_component("metric_relay", component) + + component = ExternalConfigurator(component_ids=["metric_relay"]) + self.app.add_component("config_preparer", component) + + if self.external_scripts is not None: + for _script in self.external_scripts: + self.app.add_ext_script(_script) + + +class ControllerApp(FedApp): + """Wrapper around `ServerAppConfig`. 
+
+    Args:
+        key_metric: metric used by the `IntimeModelSelector` to select the globally best model. Defaults to "accuracy".
+    """
+
+    def __init__(self, key_metric="accuracy"):
+        super().__init__()
+        self.key_metric = key_metric
+        self._create_server_app()
+
+    def add_controller(self, controller, id=None):
+        if id is None:
+            id = "controller"
+        self.app.add_workflow(self._check_id(id), controller)
+
+    def _create_server_app(self):
+        self.app: ServerAppConfig = ServerAppConfig()
+
+        component = ValidationJsonGenerator()
+        self.app.add_component("json_generator", component)
+
+        if self.key_metric:
+            component = IntimeModelSelector(key_metric=self.key_metric)
+            self.app.add_component("model_selector", component)
+
+        # TODO: make different tracking receivers configurable
+        component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"])
+        self.app.add_component("receiver", component)
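Taken together, the pieces in this diff let a federated job be assembled and run entirely from Python. The sketch below is a minimal, illustrative usage of `FedJob`, `ScriptExecutor`, and the newly exported `FedAvg` controller, not part of the patch itself: the job name, output paths, the training script `src/train.py`, and the `FedAvg` arguments (`num_clients`, `num_rounds`) are placeholders and should be checked against the actual controller signature.

from nvflare import FedAvg, FedJob, ScriptExecutor

if __name__ == "__main__":
    n_clients = 2

    # The job name becomes the name of the exported/simulated job.
    job = FedJob(name="hello-fedavg", min_clients=n_clients)

    # Server side: the controller is added as a workflow of the server app.
    job.to(FedAvg(num_clients=n_clients, num_rounds=2), "server")

    # Client side: a ScriptExecutor wraps the Client API training script for each site;
    # the script is also deployed as an external script of that client app.
    for i in range(n_clients):
        executor = ScriptExecutor(task_script_path="src/train.py", task_script_args="")
        job.to(executor, f"site-{i + 1}", gpu=0)

    # Either export the job configuration to disk or run it directly in the simulator.
    job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir")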