Skip to content

Commit

Permalink
1. Replaced Pickle with FOBS (#746)
Browse files Browse the repository at this point in the history
2. Added decomposers for all classes to be serialized
3. Changed all examples to use FOBS
4. Broke all recursion in the serialized objects.
  • Loading branch information
nvidianz committed Aug 8, 2022
1 parent 695c076 commit 6cde16f
Show file tree
Hide file tree
Showing 39 changed files with 1,089 additions and 95 deletions.
2 changes: 1 addition & 1 deletion examples/hello-cyclic/config/config_fed_server.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"id": "persistor",
"path": "tf2_model_persistor.TF2ModelPersistor",
"args": {
"save_name": "tf2weights.pickle"
"save_name": "tf2weights.fobs"
}
},
{
Expand Down
6 changes: 3 additions & 3 deletions examples/hello-cyclic/custom/tf2_model_persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import os
import pickle
import json

import tensorflow as tf
Expand All @@ -22,6 +21,7 @@
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.model import ModelLearnable
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.fuel.utils import fobs
from tf2_net import Net
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.abstract.model import make_model_learnable
Expand Down Expand Up @@ -85,7 +85,7 @@ def load_model(self, fl_ctx: FLContext) -> ModelLearnable:
if os.path.exists(self._pkl_save_path):
self.logger.info(f"Loading server weights")
with open(self._pkl_save_path, "rb") as f:
model_learnable = pickle.load(f)
model_learnable = fobs.load(f)
else:
self.logger.info(f"Initializing server model")
network = Net()
Expand All @@ -111,4 +111,4 @@ def save_model(self, model_learnable: ModelLearnable, fl_ctx: FLContext):
model_learnable_info = {k: str(type(v)) for k, v in model_learnable.items()}
self.logger.info(f"Saving aggregated server weights: \n {model_learnable_info}")
with open(self._pkl_save_path, "wb") as f:
pickle.dump(model_learnable, f)
fobs.dump(model_learnable, f)
2 changes: 1 addition & 1 deletion examples/hello-tf2/config/config_fed_server.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"id": "persistor",
"path": "tf2_model_persistor.TF2ModelPersistor",
"args": {
"save_name": "tf2weights.pickle"
"save_name": "tf2weights.fobs"
}
},
{
Expand Down
6 changes: 3 additions & 3 deletions examples/hello-tf2/custom/tf2_model_persistor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import os
import pickle
import json

import tensorflow as tf
Expand All @@ -22,6 +21,7 @@
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.abstract.model import ModelLearnable
from nvflare.app_common.abstract.model_persistor import ModelPersistor
from nvflare.fuel.utils import fobs
from tf2_net import Net
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.abstract.model import make_model_learnable
Expand Down Expand Up @@ -85,7 +85,7 @@ def load_model(self, fl_ctx: FLContext) -> ModelLearnable:
if os.path.exists(self._pkl_save_path):
self.logger.info(f"Loading server weights")
with open(self._pkl_save_path, "rb") as f:
model_learnable = pickle.load(f)
model_learnable = fobs.load(f)
else:
self.logger.info(f"Initializing server model")
network = Net()
Expand All @@ -111,4 +111,4 @@ def save_model(self, model_learnable: ModelLearnable, fl_ctx: FLContext):
model_learnable_info = {k: str(type(v)) for k, v in model_learnable.items()}
self.logger.info(f"Saving aggregated server weights: \n {model_learnable_info}")
with open(self._pkl_save_path, "wb") as f:
pickle.dump(model_learnable, f)
fobs.dump(model_learnable, f)
8 changes: 4 additions & 4 deletions nvflare/apis/dxo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# limitations under the License.

import copy
import pickle
from typing import List

from nvflare.apis.shareable import ReservedHeaderKey, Shareable
from nvflare.fuel.utils import fobs


class DataKind(object):
Expand Down Expand Up @@ -114,7 +114,7 @@ def to_bytes(self) -> bytes:
object serialized in bytes.
"""
return pickle.dumps(self)
return fobs.dumps(self)

def validate(self) -> str:
if self.data is None:
Expand Down Expand Up @@ -166,10 +166,10 @@ def from_bytes(data: bytes) -> DXO:
data: a bytes object
Returns:
an object loaded by pickle from data
an object loaded by FOBS from data
"""
x = pickle.loads(data)
x = fobs.loads(data)
if isinstance(x, DXO):
return x
else:
Expand Down
1 change: 1 addition & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class ReservedKey(object):
EVENT_SCOPE = "__event_scope__"
RUN_ABORT_SIGNAL = "__run_abort_signal__"
SHAREABLE = "__shareable__"
SHARED_FL_CONTEXT = "__shared_fl_context__"
ARGS = "__args__"
WORKSPACE_OBJECT = "__workspace_object__"
RANK_NUMBER = "__rank_number__"
Expand Down
12 changes: 6 additions & 6 deletions nvflare/apis/impl/job_def_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import datetime
import os
import pathlib
import pickle
import shutil
import tempfile
import time
Expand All @@ -29,6 +28,7 @@
from nvflare.apis.server_engine_spec import ServerEngineSpec
from nvflare.apis.storage import StorageException, StorageSpec
from nvflare.fuel.hci.zip_utils import unzip_all_from_bytes, zip_directory_to_bytes
from nvflare.fuel.utils import fobs


class _JobFilter(ABC):
Expand Down Expand Up @@ -108,7 +108,7 @@ def create(self, meta: dict, uploaded_content: bytes, fl_ctx: FLContext) -> Dict
# write it to the store
stored_data = {JobDataKey.JOB_DATA.value: uploaded_content, JobDataKey.WORKSPACE_DATA.value: None}
store = self._get_job_store(fl_ctx)
store.create_object(self.job_uri(jid), pickle.dumps(stored_data), meta, overwrite_existing=True)
store.create_object(self.job_uri(jid), fobs.dumps(stored_data), meta, overwrite_existing=True)
return meta

def delete(self, jid: str, fl_ctx: FLContext):
Expand Down Expand Up @@ -185,12 +185,12 @@ def get_content(self, jid: str, fl_ctx: FLContext) -> Optional[bytes]:
stored_data = store.get_data(self.job_uri(jid))
except StorageException:
return None
return pickle.loads(stored_data).get(JobDataKey.JOB_DATA.value)
return fobs.loads(stored_data).get(JobDataKey.JOB_DATA.value)

def get_job_data(self, jid: str, fl_ctx: FLContext) -> dict:
store = self._get_job_store(fl_ctx)
stored_data = store.get_data(self.job_uri(jid))
return pickle.loads(stored_data)
return fobs.loads(stored_data)

def set_status(self, jid: str, status: RunStatus, fl_ctx: FLContext):
meta = {JobMetaKey.STATUS.value: status.value}
Expand Down Expand Up @@ -259,6 +259,6 @@ def set_approval(
def save_workspace(self, jid: str, data: bytes, fl_ctx: FLContext):
store = self._get_job_store(fl_ctx)
stored_data = store.get_data(self.job_uri(jid))
job_data = pickle.loads(stored_data)
job_data = fobs.loads(stored_data)
job_data[JobDataKey.WORKSPACE_DATA.value] = data
store.update_data(self.job_uri(jid), pickle.dumps(job_data))
store.update_data(self.job_uri(jid), fobs.dumps(job_data))
9 changes: 4 additions & 5 deletions nvflare/apis/shareable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle

from ..fuel.utils import fobs
from .fl_constant import ReservedKey, ReturnCode


Expand Down Expand Up @@ -113,7 +112,7 @@ def to_bytes(self) -> bytes:
object serialized in bytes.
"""
return pickle.dumps(self)
return fobs.dumps(self)

@classmethod
def from_bytes(cls, data: bytes):
Expand All @@ -123,10 +122,10 @@ def from_bytes(cls, data: bytes):
data: a bytes object
Returns:
an object loaded by pickle from data
an object loaded by FOBS from data
"""
return pickle.loads(data)
return fobs.loads(data)


# some convenience functions
Expand Down
13 changes: 13 additions & 0 deletions nvflare/apis/utils/decomposers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
168 changes: 168 additions & 0 deletions nvflare/apis/utils/decomposers/flare_decomposers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decomposers for objects used by NVFlare itself
This module contains all the decomposers used to run NVFlare.
The decomposers are registered at server/client startup.
"""
import os
from argparse import Namespace
from typing import Any

from nvflare.apis.analytix import AnalyticsDataType
from nvflare.apis.client import Client
from nvflare.apis.dxo import DXO
from nvflare.apis.fl_context import FLContext
from nvflare.apis.fl_snapshot import RunSnapshot
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.apis.workspace import Workspace
from nvflare.fuel.utils import fobs
from nvflare.fuel.utils.fobs.decomposer import Decomposer


class ShareableDecomposer(Decomposer):
    """FOBS decomposer for Shareable: serialized as a plain shallow dict copy."""

    @staticmethod
    def supported_type():
        return Shareable

    def decompose(self, target: Shareable) -> Any:
        # Shareable is dict-like; a shallow copy captures all of its entries.
        return target.copy()

    def recompose(self, data: Any) -> Shareable:
        restored = Shareable()
        # Assign key-by-key so any Shareable __setitem__ logic still runs.
        for key in data:
            restored[key] = data[key]
        return restored


class ContextDecomposer(Decomposer):
    """FOBS decomposer for FLContext: captures the model and props attributes."""

    @staticmethod
    def supported_type():
        return FLContext

    def decompose(self, target: FLContext) -> Any:
        # Only these two attributes are carried across serialization.
        return [target.model, target.props]

    def recompose(self, data: Any) -> FLContext:
        model, props = data
        context = FLContext()
        context.model = model
        context.props = props
        return context


class DxoDecomposer(Decomposer):
    """FOBS decomposer for DXO: serialized as [data_kind, data, meta]."""

    @staticmethod
    def supported_type():
        return DXO

    def decompose(self, target: DXO) -> Any:
        return [target.data_kind, target.data, target.meta]

    def recompose(self, data: Any) -> DXO:
        # Positional order matches decompose(): (data_kind, data, meta).
        return DXO(*data)


class ClientDecomposer(Decomposer):
    """FOBS decomposer for Client: name/token go through the constructor, the rest by attribute."""

    @staticmethod
    def supported_type():
        return Client

    def decompose(self, target: Client) -> Any:
        return [target.name, target.token, target.last_connect_time, target.props]

    def recompose(self, data: Any) -> Client:
        name, token, last_connect_time, props = data
        restored = Client(name, token)
        restored.last_connect_time = last_connect_time
        restored.props = props
        return restored


class RunSnapshotDecomposer(Decomposer):
    """FOBS decomposer for RunSnapshot: serialized as [component_states, completed, job_id]."""

    @staticmethod
    def supported_type():
        return RunSnapshot

    def decompose(self, target: RunSnapshot) -> Any:
        return [target.component_states, target.completed, target.job_id]

    def recompose(self, data: Any) -> RunSnapshot:
        component_states, completed, job_id = data
        # The constructor takes only the job id; the rest is set by attribute.
        restored = RunSnapshot(job_id)
        restored.component_states = component_states
        restored.completed = completed
        return restored


class WorkspaceDecomposer(Decomposer):
    """FOBS decomposer for Workspace: serialized as [root_dir, name, config_folder]."""

    @staticmethod
    def supported_type():
        return Workspace

    def decompose(self, target: Workspace) -> Any:
        return [target.root_dir, target.name, target.config_folder]

    def recompose(self, data: Any) -> Workspace:
        root_dir, name, config_folder = data
        return Workspace(root_dir, name, config_folder)


class SignalDecomposer(Decomposer):
    """FOBS decomposer for Signal: serialized as [value, trigger_time, triggered]."""

    @staticmethod
    def supported_type():
        return Signal

    def decompose(self, target: Signal) -> Any:
        return [target.value, target.trigger_time, target.triggered]

    def recompose(self, data: Any) -> Signal:
        value, trigger_time, triggered = data
        restored = Signal()
        restored.value = value
        restored.trigger_time = trigger_time
        restored.triggered = triggered
        return restored


class AnalyticsDataTypeDecomposer(Decomposer):
    """FOBS decomposer for the AnalyticsDataType enum: serialized by member name."""

    @staticmethod
    def supported_type():
        return AnalyticsDataType

    def decompose(self, target: AnalyticsDataType) -> Any:
        # Enum members round-trip losslessly by their name string.
        return target.name

    def recompose(self, data: Any) -> AnalyticsDataType:
        # Name-based lookup: Enum[name] returns the matching member.
        return AnalyticsDataType[data]


class NamespaceDecomposer(Decomposer):
    """FOBS decomposer for argparse.Namespace: serialized as its attribute dict."""

    @staticmethod
    def supported_type():
        return Namespace

    def decompose(self, target: Namespace) -> Any:
        # vars() exposes the namespace's attributes as a dict.
        return vars(target)

    def recompose(self, data: Any) -> Namespace:
        # Namespace(**attrs) is the exact inverse of vars().
        return Namespace(**data)


def register():
    """Register every decomposer in this package with FOBS (idempotent)."""
    if not register.registered:
        # Scan this package's directory and register all decomposers found there.
        fobs.register_folder(os.path.dirname(__file__), __package__)
        register.registered = True


# One-time guard stored on the function object itself.
register.registered = False
Loading

0 comments on commit 6cde16f

Please sign in to comment.