-
Notifications
You must be signed in to change notification settings - Fork 155
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a761a1b
commit 9e7d82d
Showing
20 changed files
with
1,718 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import csv | ||
import os | ||
|
||
from src.kmeans_assembler import KMeansAssembler | ||
from src.split_csv import distribute_header_file, split_csv | ||
|
||
from nvflare import FedJob, ScriptExecutor | ||
from nvflare.app_common.aggregators.collect_and_assemble_aggregator import CollectAndAssembleAggregator | ||
from nvflare.app_common.shareablegenerators.full_model_shareable_generator import FullModelShareableGenerator | ||
from nvflare.app_common.workflows.scatter_and_gather import ScatterAndGather | ||
from nvflare.app_opt.sklearn.joblib_model_param_persistor import JoblibModelParamPersistor | ||
from nvflare.client.config import ExchangeFormat | ||
|
||
|
||
def split_higgs(input_data_path, input_header_path, output_dir, site_num, sample_rate, site_name_prefix="site-"):
    """Split the input csv for the sites and distribute the header file.

    Thin wrapper that forwards its arguments, positionally, first to
    ``split_csv`` and then to ``distribute_header_file``.

    Args:
        input_data_path: path of the csv data file, passed to ``split_csv``.
        input_header_path: path of the header csv, passed to ``distribute_header_file``.
        output_dir: output directory handed to both helpers.
        site_num: number of parts/sites handed to both helpers.
        sample_rate: sample rate passed to ``split_csv``.
        site_name_prefix: site name prefix handed to both helpers (default ``"site-"``).
    """
    split_csv(input_data_path, output_dir, site_num, site_name_prefix, sample_rate)
    distribute_header_file(input_header_path, output_dir, site_num, site_name_prefix)
|
||
|
||
if __name__ == "__main__":
    # Script entry point: prepare the HIGGS data, then build, export and run
    # (in the simulator) a federated k-Means job using the FedJob API.
    n_clients = 3
    num_rounds = 2
    train_script = "src/kmeans_fl.py"
    data_input_dir = "/tmp/nvflare/higgs/data"
    data_output_dir = "/tmp/nvflare/higgs/split_data"

    # Download data (skipped when the zip archive is already present).
    # NOTE: the exit status of these shell commands is not checked.
    os.makedirs(data_input_dir, exist_ok=True)
    higgs_zip_file = os.path.join(data_input_dir, "higgs.zip")
    if not os.path.exists(higgs_zip_file):
        os.system(
            f"curl -o {higgs_zip_file} https://archive.ics.uci.edu/static/public/280/higgs.zip"
        )  # This might take a while. The file is 2.8 GB.
        os.system(f"unzip -d {data_input_dir} {higgs_zip_file}")
        os.system(
            f"gunzip -c {os.path.join(data_input_dir, 'HIGGS.csv.gz')} > {os.path.join(data_input_dir, 'higgs.csv')}"
        )

    # Generate the csv header file from the list of column names.
    features = [
        "label",
        "lepton_pt",
        "lepton_eta",
        "lepton_phi",
        "missing_energy_magnitude",
        "missing_energy_phi",
        "jet_1_pt",
        "jet_1_eta",
        "jet_1_phi",
        "jet_1_b_tag",
        "jet_2_pt",
        "jet_2_eta",
        "jet_2_phi",
        "jet_2_b_tag",
        "jet_3_pt",
        "jet_3_eta",
        "jet_3_phi",
        "jet_3_b_tag",
        "jet_4_pt",
        "jet_4_eta",
        "jet_4_phi",
        "jet_4_b_tag",
        "m_jj",
        "m_jjj",
        "m_lv",
        "m_jlv",
        "m_bb",
        "m_wbb",
        "m_wwbb",
    ]

    # Specify the file path
    file_path = os.path.join(data_input_dir, "headers.csv")

    with open(file_path, "w", newline="") as file:
        csv_writer = csv.writer(file)
        csv_writer.writerow(features)

    print(f"features written to {file_path}")

    # Split the data
    split_higgs(
        input_data_path=os.path.join(data_input_dir, "higgs.csv"),
        input_header_path=os.path.join(data_input_dir, "headers.csv"),
        output_dir=data_output_dir,
        site_num=n_clients,
        sample_rate=0.3,
    )

    # Create the federated learning job
    job = FedJob(name="kmeans")

    # The controller refers to its helper components by id; each component
    # added to the server below is registered under the matching id.
    controller = ScatterAndGather(
        min_clients=n_clients,
        num_rounds=num_rounds,
        wait_time_after_min_received=0,
        aggregator_id="aggregator",
        persistor_id="persistor",  # TODO: Allow adding python objects rather than ids
        shareable_generator_id="shareable_generator",
        train_task_name="train",  # Client will start training once received such task.
        train_timeout=0,
    )
    job.to(controller, "server")

    # For kmeans with sklearn, we need a custom persistor.
    # JoblibModelParamPersistor is a persistor which saves/reads the model to/from file with JobLib format.
    persistor = JoblibModelParamPersistor(initial_params={"n_clusters": 2})
    # When assigning the persistor to the server, we need to specify the id that's expected
    # by the ScatterAndGather controller (persistor_id above).
    job.to(persistor, "server", id="persistor")

    # Similarly, ScatterAndGather expects a "shareable_generator" which we need to assign to the server.
    job.to(FullModelShareableGenerator(), "server", id="shareable_generator")

    # ScatterAndGather also expects an "aggregator" which we define here.
    # The actual aggregation function is defined by an "assembler" to specify how to handle the collected updates.
    aggregator = CollectAndAssembleAggregator(
        assembler_id="kmeans_assembler"
    )  # TODO: Allow adding KMeansAssembler() directly
    job.to(aggregator, "server", id="aggregator")

    # This is the assembler designed for k-Means algorithm.
    # As CollectAndAssembleAggregator expects an assembler_id, we need to specify it here.
    job.to(KMeansAssembler(), "server", id="kmeans_assembler")

    # Add clients
    for i in range(n_clients):
        executor = ScriptExecutor(
            task_script_path=train_script,
            task_script_args=f"--data_root_dir {data_output_dir}",
            params_exchange_format=ExchangeFormat.RAW,  # kmeans requires raw values only rather than PyTorch Tensors (the default)
        )
        job.to(executor, f"site-{i+1}", gpu=0)  # HIGGS data splitter assumes site names start from 1

    job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from src.lit_net import LitNet | ||
|
||
from nvflare import FedAvg, FedJob, ScriptExecutor | ||
|
||
if __name__ == "__main__":
    # Script entry point: assemble a FedAvg job that trains the LitNet model
    # with the Lightning client script, export it, and run it in the simulator.
    n_clients = 2
    num_rounds = 2
    train_script = "src/cifar10_lightning_fl.py"

    job = FedJob(name="cifar10_fedavg_lightning")

    # Server side: the FedAvg controller workflow, then the initial global model.
    job.to(FedAvg(min_clients=n_clients, num_rounds=num_rounds), "server")
    job.to(LitNet(), "server")

    # Client side: one script executor per site, named site-0, site-1, ...
    for site_idx in range(n_clients):
        site_name = f"site-{site_idx}"
        # Script arguments are left empty; per-site arguments could look like
        # f"--batch_size 32 --data_path /tmp/data/site-{site_idx}"
        script_runner = ScriptExecutor(task_script_path=train_script, task_script_args="")
        job.to(script_runner, site_name, gpu=0)

    job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from src.net import Net | ||
|
||
from nvflare import FedAvg, FedJob, ScriptExecutor | ||
|
||
if __name__ == "__main__":
    # Script entry point: assemble a FedAvg job that trains the Net model
    # with the CIFAR-10 client script, export it, and run it in the simulator.
    n_clients = 2
    num_rounds = 2
    train_script = "src/cifar10_fl.py"

    job = FedJob(name="cifar10_fedavg")

    # Server side: the FedAvg controller workflow, then the initial global model.
    job.to(FedAvg(min_clients=n_clients, num_rounds=num_rounds), "server")
    job.to(Net(), "server")

    # Client side: one script executor per site, named site-0, site-1, ...
    for site_idx in range(n_clients):
        site_name = f"site-{site_idx}"
        # Script arguments are left empty; per-site arguments could look like
        # f"--batch_size 32 --data_path /tmp/data/site-{site_idx}"
        script_runner = ScriptExecutor(task_script_path=train_script, task_script_args="")
        job.to(script_runner, site_name, gpu=0)

    job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir")
47 changes: 47 additions & 0 deletions
47
examples/hello-world/job_api/pt/client_api_pt_cyclic_cc.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from src.net import Net | ||
|
||
from nvflare import FedJob, ScriptExecutor | ||
from nvflare.app_common.ccwf import CyclicClientController, CyclicServerController | ||
from nvflare.app_common.ccwf.comps.simple_model_shareable_generator import SimpleModelShareableGenerator | ||
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor | ||
|
||
if __name__ == "__main__":
    # Script entry point: assemble a cyclic-workflow job for the Net model,
    # export it, and run it in the simulator.
    n_clients = 2
    num_rounds = 3
    train_script = "src/cifar10_fl.py"

    job = FedJob(name="cifar10_cyclic")

    # Server side: the cyclic server controller.
    job.to(CyclicServerController(num_rounds=num_rounds, max_status_report_interval=300), "server")

    for site_idx in range(n_clients):
        site_name = f"site-{site_idx}"

        # Training script executor for this site. Script arguments are left empty;
        # per-site arguments could look like
        # f"--batch_size 32 --data_path /tmp/data/site-{site_idx}"
        script_runner = ScriptExecutor(task_script_path=train_script, task_script_args="")
        job.to(script_runner, site_name, gpu=0)

        # Client-side controller for the cyclic workflow, bound to the "cyclic_*" tasks.
        client_controller = CyclicClientController()
        job.to(client_controller, site_name, tasks=["cyclic_*"])

        # Each client also gets its own model persistor and shareable generator,
        # registered under the ids "persistor" and "shareable_generator".
        job.to(PTFileModelPersistor(model=Net()), site_name, id="persistor")
        job.to(SimpleModelShareableGenerator(), site_name, id="shareable_generator")

    job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir")
46 changes: 46 additions & 0 deletions
46
examples/hello-world/job_api/pt/client_api_pt_dp_filter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from src.net import Net | ||
|
||
from nvflare import FedAvg, FedJob, FilterType, ScriptExecutor | ||
from nvflare.app_common.filters.percentile_privacy import PercentilePrivacy | ||
|
||
if __name__ == "__main__":
    # Script entry point: assemble a FedAvg job for the Net model in which every
    # client applies a PercentilePrivacy filter to its "train" task results,
    # then export the job and run it in the simulator.
    n_clients = 2
    num_rounds = 2
    train_script = "src/cifar10_fl.py"

    job = FedJob(name="cifar10_fedavg_privacy")

    # Server side: the FedAvg controller workflow, then the initial global model.
    job.to(FedAvg(min_clients=n_clients, num_rounds=num_rounds), "server")
    job.to(Net(), "server")

    for site_idx in range(n_clients):
        site_name = f"site-{site_idx}"

        # Training script executor handling the "train" task on this site.
        script_runner = ScriptExecutor(task_script_path=train_script, task_script_args="")
        job.to(script_runner, site_name, tasks=["train"], gpu=0)

        # Privacy filter attached to the results of the "train" task.
        privacy_filter = PercentilePrivacy(percentile=10, gamma=0.01)
        job.to(privacy_filter, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)

    job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir")
Oops, something went wrong.