Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pythonic job creation #2483

Merged
merged 43 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
d94b23a
WIP: constructed the FedJob.
yhwen Mar 26, 2024
b89520e
WIP: server_app josn export.
yhwen Mar 27, 2024
8613ee5
generate the job app config.
yhwen Mar 28, 2024
1d5fc8f
fully functional pythonic job creation.
yhwen Mar 29, 2024
51eccea
Added simulator_run for pythonic API.
yhwen Mar 29, 2024
62ceea8
reformat.
yhwen Apr 1, 2024
c90ec58
Added filters support for pythonic job creation.
yhwen Apr 1, 2024
08f20af
handled the direct import case in fed_job.
yhwen Apr 1, 2024
67ecfc2
refactor.
yhwen Apr 2, 2024
f43ac26
Added the resource_spec set function for FedJob.
yhwen Apr 3, 2024
71d3c1c
refactored.
yhwen Apr 3, 2024
b709e30
Moved the ClientApp and ServerApp into fed_app.py.
yhwen Apr 3, 2024
39a7fe2
Refactored: removed the _FilterDef class.
yhwen Apr 4, 2024
c6722a5
refactored.
yhwen Apr 4, 2024
5311fc7
Rename job config classes (#3)
holgerroth Apr 5, 2024
9915974
Enable obj in the constructor as paramenter.
yhwen Apr 8, 2024
105e226
Added support for the launcher script.
yhwen Apr 8, 2024
cdf5cc5
refactored.
yhwen Apr 9, 2024
0a30979
reformat.
yhwen Apr 10, 2024
22febb8
Update the comment.
yhwen Apr 10, 2024
4f7e2dd
re-arrange the package location.
yhwen Apr 10, 2024
3c85cc9
Added add_ext_script() for BaseAppConfig.
yhwen Apr 10, 2024
196fc17
codestyle fix.
yhwen Apr 11, 2024
03c6eed
Removed the client-api-pt example.
yhwen Apr 11, 2024
7869697
removed no used import.
yhwen Apr 11, 2024
1152511
fixed the in_time_accumulate_weighted_aggregator_test.py
yhwen Apr 11, 2024
f13a9ec
Added Enum parameter support.
yhwen Apr 11, 2024
8ba8e21
Added docstring.
yhwen Apr 11, 2024
f7223bd
Merge branch 'main' into pythonic_job_creation
yhwen Apr 11, 2024
9e531c0
Added ability to handle parameters from base class.
yhwen Apr 12, 2024
4d68baa
Move the parameter data format conversion to the START_RUN event for …
yhwen Apr 12, 2024
81605ad
Added params_exchange_format for PTInProcessClientAPIExecutor.
yhwen Apr 12, 2024
94074ef
codestyle fix.
yhwen Apr 12, 2024
3858177
Fixed a custom code folder structure issue.
yhwen Apr 12, 2024
628d8d1
work for sub-folder custom files.
yhwen Apr 13, 2024
00bd538
backed to handle parameters from base classes.
yhwen Apr 14, 2024
14825a1
Support folder structure job config.
yhwen Apr 14, 2024
2849705
Added support for flat folder from '.XXX' import.
yhwen Apr 14, 2024
d5bd639
codestyle fix.
yhwen Apr 14, 2024
9c1218e
refactored and add docstring.
yhwen Apr 15, 2024
998aa9c
Merge branch 'main' into pythonic_job_creation
YuanTingHsieh Apr 15, 2024
28c1b6a
Merge branch 'main' into pythonic_job_creation
chesterxgchen Apr 17, 2024
0c4d2a7
Address some of the PR reviews.
yhwen Apr 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions examples/advanced/job_config/hello-pt/add_shareable_parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
yhwen marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.apis.filter import Filter
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable


class AddShareable(Filter):
def process(self, shareable: Shareable, fl_ctx: FLContext) -> Shareable:
print(f"{fl_ctx.get_identity_name()} ---- AddShareable Filter ----")
yhwen marked this conversation as resolved.
Show resolved Hide resolved

return shareable
200 changes: 200 additions & 0 deletions examples/advanced/job_config/hello-pt/cifar10trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
yhwen marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os.path

import torch
from pt_constants import PTConstants
from simple_network import SimpleNetwork
from torch import nn
from torch.optim import SGD
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor

from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReservedKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.abstract.model import make_model_learnable, model_learnable_to_dxo
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_opt.pt.model_persistence_format_manager import PTModelPersistenceFormatManager


class Cifar10Trainer(Executor):
def __init__(
self,
data_path="~/data",
lr=0.01,
epochs=5,
train_task_name=AppConstants.TASK_TRAIN,
submit_model_task_name=AppConstants.TASK_SUBMIT_MODEL,
exclude_vars=None,
pre_train_task_name=AppConstants.TASK_GET_WEIGHTS,
):
"""Cifar10 Trainer handles train and submit_model tasks. During train_task, it trains a
simple network on CIFAR10 dataset. For submit_model task, it sends the locally trained model
(if present) to the server.

Args:
lr (float, optional): Learning rate. Defaults to 0.01
epochs (int, optional): Epochs. Defaults to 5
train_task_name (str, optional): Task name for train task. Defaults to "train".
submit_model_task_name (str, optional): Task name for submit model. Defaults to "submit_model".
exclude_vars (list): List of variables to exclude during model loading.
pre_train_task_name: Task name for pre train task, i.e., sending initial model weights.
"""
super().__init__()

self._lr = lr
self._epochs = epochs
self._train_task_name = train_task_name
self._pre_train_task_name = pre_train_task_name
self._submit_model_task_name = submit_model_task_name
self._exclude_vars = exclude_vars

# Training setup
self.model = SimpleNetwork()
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.model.to(self.device)
self.loss = nn.CrossEntropyLoss()
self.optimizer = SGD(self.model.parameters(), lr=lr, momentum=0.9)

# Create Cifar10 dataset for training.
transforms = Compose(
[
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
self._train_dataset = CIFAR10(root=data_path, transform=transforms, download=True, train=True)
self._train_loader = DataLoader(self._train_dataset, batch_size=4, shuffle=True)
self._n_iterations = len(self._train_loader)

# Setup the persistence manager to save PT model.
# The default training configuration is used by persistence manager
# in case no initial model is found.
self._default_train_conf = {"train": {"model": type(self.model).__name__}}
self.persistence_manager = PTModelPersistenceFormatManager(
data=self.model.state_dict(), default_train_conf=self._default_train_conf
)

def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
try:
if task_name == self._pre_train_task_name:
# Get the new state dict and send as weights
return self._get_model_weights()
elif task_name == self._train_task_name:
# Get model weights
try:
dxo = from_shareable(shareable)
except:
self.log_error(fl_ctx, "Unable to extract dxo from shareable.")
return make_reply(ReturnCode.BAD_TASK_DATA)

# Ensure data kind is weights.
if not dxo.data_kind == DataKind.WEIGHTS:
self.log_error(fl_ctx, f"data_kind expected WEIGHTS but got {dxo.data_kind} instead.")
return make_reply(ReturnCode.BAD_TASK_DATA)

# Convert weights to tensor. Run training
torch_weights = {k: torch.as_tensor(v) for k, v in dxo.data.items()}
self._local_train(fl_ctx, torch_weights, abort_signal)

# Check the abort_signal after training.
# local_train returns early if abort_signal is triggered.
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)

# Save the local model after training.
self._save_local_model(fl_ctx)

# Get the new state dict and send as weights
return self._get_model_weights()
elif task_name == self._submit_model_task_name:
# Load local model
ml = self._load_local_model(fl_ctx)

# Get the model parameters and create dxo from it
dxo = model_learnable_to_dxo(ml)
return dxo.to_shareable()
else:
return make_reply(ReturnCode.TASK_UNKNOWN)
except Exception as e:
self.log_exception(fl_ctx, f"Exception in simple trainer: {e}.")
return make_reply(ReturnCode.EXECUTION_EXCEPTION)

def _get_model_weights(self) -> Shareable:
# Get the new state dict and send as weights
weights = {k: v.cpu().numpy() for k, v in self.model.state_dict().items()}

outgoing_dxo = DXO(
data_kind=DataKind.WEIGHTS, data=weights, meta={MetaKey.NUM_STEPS_CURRENT_ROUND: self._n_iterations}
)
return outgoing_dxo.to_shareable()

def _local_train(self, fl_ctx, weights, abort_signal):
# Set the model weights
self.model.load_state_dict(state_dict=weights)

# Basic training
self.model.train()
for epoch in range(self._epochs):
running_loss = 0.0
for i, batch in enumerate(self._train_loader):
if abort_signal.triggered:
# If abort_signal is triggered, we simply return.
# The outside function will check it again and decide steps to take.
return

images, labels = batch[0].to(self.device), batch[1].to(self.device)
self.optimizer.zero_grad()

predictions = self.model(images)
cost = self.loss(predictions, labels)
cost.backward()
self.optimizer.step()

running_loss += cost.cpu().detach().numpy() / images.size()[0]
if i % 3000 == 0:
self.log_info(
fl_ctx, f"Epoch: {epoch}/{self._epochs}, Iteration: {i}, " f"Loss: {running_loss/3000}"
)
running_loss = 0.0

def _save_local_model(self, fl_ctx: FLContext):
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_prop(ReservedKey.RUN_NUM))
models_dir = os.path.join(run_dir, PTConstants.PTModelsDir)
if not os.path.exists(models_dir):
os.makedirs(models_dir)
model_path = os.path.join(models_dir, PTConstants.PTLocalModelName)

ml = make_model_learnable(self.model.state_dict(), {})
self.persistence_manager.update(ml)
torch.save(self.persistence_manager.to_persistence_dict(), model_path)

def _load_local_model(self, fl_ctx: FLContext):
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_prop(ReservedKey.RUN_NUM))
models_dir = os.path.join(run_dir, PTConstants.PTModelsDir)
if not os.path.exists(models_dir):
return None
model_path = os.path.join(models_dir, PTConstants.PTLocalModelName)

self.persistence_manager = PTModelPersistenceFormatManager(
data=torch.load(model_path), default_train_conf=self._default_train_conf
)
ml = self.persistence_manager.to_model_learnable(exclude_vars=self._exclude_vars)
return ml
112 changes: 112 additions & 0 deletions examples/advanced/job_config/hello-pt/cifar10validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from simple_network import SimpleNetwork
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor

from nvflare.apis.dxo import DXO, DataKind, from_shareable
from nvflare.apis.executor import Executor
from nvflare.apis.fl_constant import ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
from nvflare.app_common.app_constant import AppConstants


class Cifar10Validator(Executor):
def __init__(self, data_path="~/data", validate_task_name=AppConstants.TASK_VALIDATION):
super().__init__()

self._validate_task_name = validate_task_name

# Setup the model
self.model = SimpleNetwork()
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
self.model.to(self.device)

# Preparing the dataset for testing.
transforms = Compose(
[
ToTensor(),
Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
)
test_data = CIFAR10(root=data_path, train=False, transform=transforms)
self._test_loader = DataLoader(test_data, batch_size=4, shuffle=False)

def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
if task_name == self._validate_task_name:
model_owner = "?"
try:
try:
dxo = from_shareable(shareable)
except:
self.log_error(fl_ctx, "Error in extracting dxo from shareable.")
return make_reply(ReturnCode.BAD_TASK_DATA)

# Ensure data_kind is weights.
if not dxo.data_kind == DataKind.WEIGHTS:
self.log_exception(fl_ctx, f"DXO is of type {dxo.data_kind} but expected type WEIGHTS.")
return make_reply(ReturnCode.BAD_TASK_DATA)

# Extract weights and ensure they are tensor.
model_owner = shareable.get_header(AppConstants.MODEL_OWNER, "?")
weights = {k: torch.as_tensor(v, device=self.device) for k, v in dxo.data.items()}

# Get validation accuracy
val_accuracy = self._validate(weights, abort_signal)
if abort_signal.triggered:
return make_reply(ReturnCode.TASK_ABORTED)

self.log_info(
fl_ctx,
f"Accuracy when validating {model_owner}'s model on"
f" {fl_ctx.get_identity_name()}"
f"s data: {val_accuracy}",
)

dxo = DXO(data_kind=DataKind.METRICS, data={"val_acc": val_accuracy})
return dxo.to_shareable()
except:
self.log_exception(fl_ctx, f"Exception in validating model from {model_owner}")
return make_reply(ReturnCode.EXECUTION_EXCEPTION)
else:
return make_reply(ReturnCode.TASK_UNKNOWN)

def _validate(self, weights, abort_signal):
self.model.load_state_dict(weights)

self.model.eval()

correct = 0
total = 0
with torch.no_grad():
for i, (images, labels) in enumerate(self._test_loader):
if abort_signal.triggered:
return 0

images, labels = images.to(self.device), labels.to(self.device)
output = self.model(images)

_, pred_label = torch.max(output, 1)

correct += (pred_label == labels).sum().item()
total += images.size()[0]

metric = correct / float(total)

return metric
Loading
Loading