Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fobs auto register #2567

Merged
merged 23 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,9 @@ class ConfigVarName:
# server: wait this long since job schedule time before starting to check dead/disconnected clients
DEAD_CLIENT_CHECK_LEAD_TIME = "dead_client_check_lead_time"

# customized nvflare decomposers module name
DECOMPOSER_MODULE = "nvflare_decomposers"


class SystemVarName:
"""
Expand Down
4 changes: 4 additions & 0 deletions nvflare/client/flare_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from nvflare.fuel.utils.pipe.cell_pipe import CellPipe
from nvflare.fuel.utils.pipe.pipe import Message, Mode, Pipe
from nvflare.fuel.utils.pipe.pipe_handler import PipeHandler
from nvflare.private.fed.utils.fed_utils import register_ext_decomposers


class FlareAgentException(Exception):
Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
metric_channel_name: str = PipeChannelName.METRIC,
close_pipe: bool = True,
close_metric_pipe: bool = True,
decomposer_module: str = None,
yhwen marked this conversation as resolved.
Show resolved Hide resolved
):
"""Constructor of Flare Agent.

Expand Down Expand Up @@ -102,6 +104,8 @@ def __init__(
"""
flare_decomposers.register()
common_decomposers.register()
if decomposer_module:
register_ext_decomposers(decomposer_module)

self.logger = logging.getLogger(self.__class__.__name__)
self.pipe = pipe
Expand Down
4 changes: 4 additions & 0 deletions nvflare/client/ipc/ipc_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from nvflare.fuel.f3.cellnet.utils import make_reply
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.private.fed.utils.fed_utils import register_ext_decomposers

_SSL_ROOT_CERT = "rootCA.pem"
_SHORT_SLEEP_TIME = 0.2
Expand All @@ -43,6 +44,7 @@ def __init__(
flare_site_connection_timeout=60.0,
flare_site_heartbeat_timeout=None,
resend_result_interval=2.0,
decomposer_module=None,
):
"""Constructor of Flare Agent. The agent is responsible for communicating with the Flare Client Job cell (CJ)
to get task and to submit task result.
Expand Down Expand Up @@ -110,6 +112,8 @@ def __init__(
cb=self._msg_received,
)
numpy_decomposers.register()
if decomposer_module:
register_ext_decomposers(decomposer_module)

def start(self):
"""Start the agent. This method must be called to enable CJ/Agent communication.
Expand Down
44 changes: 38 additions & 6 deletions nvflare/fuel/utils/fobs/fobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import inspect
import logging
import os
import sys
from enum import Enum
from os.path import dirname, join
from typing import Any, BinaryIO, Dict, Type, TypeVar, Union
Expand Down Expand Up @@ -42,6 +43,7 @@

FOBS_TYPE = "__fobs_type__"
FOBS_DATA = "__fobs_data__"
FOBS_DECOMPOSER_DIR = "nvflare_decomposers"
yhwen marked this conversation as resolved.
Show resolved Hide resolved
MAX_CONTENT_LEN = 128
MSGPACK_TYPES = (None, bool, int, float, str, bytes, bytearray, memoryview, list, dict)
T = TypeVar("T")
Expand Down Expand Up @@ -189,12 +191,42 @@ def register_folder(folder: str, package: str):
for module in os.listdir(folder):
if module != "__init__.py" and module[-3:] == ".py":
decomposers = package + "." + module[:-3]
imported = importlib.import_module(decomposers, __package__)
for _, cls_obj in inspect.getmembers(imported, inspect.isclass):
spec = inspect.getfullargspec(cls_obj.__init__)
# classes who are abstract or take extra args in __init__ can't be auto-registered
if issubclass(cls_obj, Decomposer) and not inspect.isabstract(cls_obj) and len(spec.args) == 1:
register(cls_obj)
try:
imported = importlib.import_module(decomposers, __package__)
for _, cls_obj in inspect.getmembers(imported, inspect.isclass):
spec = inspect.getfullargspec(cls_obj.__init__)
# classes who are abstract or take extra args in __init__ can't be auto-registered
if issubclass(cls_obj, Decomposer) and not inspect.isabstract(cls_obj) and len(spec.args) == 1:
register(cls_obj)
except (ModuleNotFoundError, RuntimeError) as e:
log.debug(
f"Try to import module {decomposers}, but failed: {secure_format_exception(e)}. "
f"Can't use name in config to refer to classes in module: {decomposers}."
)
pass


def register_custom_folder(folder: str):
    """Import every Python module found under ``folder`` and register its decomposers.

    If ``folder`` is a directory that is not yet on ``sys.path`` it is appended, so the
    modules beneath it can be imported by dotted name. The folder is then walked
    recursively; each ``.py`` file is imported and every class in it that is a concrete
    ``Decomposer`` subclass with an ``__init__`` taking only ``self`` is registered.
    Decomposers whose ``__init__`` needs extra arguments are skipped with a warning.

    Args:
        folder: root folder to scan for decomposer modules.
    """
    if os.path.isdir(folder) and folder not in sys.path:
        sys.path.append(folder)

    for root, _, files in os.walk(folder):
        # Dotted package prefix for the modules in this directory. relpath() returns "."
        # for the scanned root itself (stripped down to "") and an OS path such as "a/b"
        # for nested folders; the path separators must be turned into dots, otherwise
        # import_module() is handed a name like "a/b.mod" and fails.
        sub_folder = os.path.relpath(root, folder).strip(".").replace(os.sep, ".")

        for filename in files:
            if not filename.endswith(".py"):
                continue

            module = filename[:-3]
            if sub_folder:
                module = sub_folder + "." + module

            imported = importlib.import_module(module)
            for _, cls_obj in inspect.getmembers(imported, inspect.isclass):
                if not issubclass(cls_obj, Decomposer) or inspect.isabstract(cls_obj):
                    continue

                spec = inspect.getfullargspec(cls_obj.__init__)
                if len(spec.args) == 1:
                    register(cls_obj)
                else:
                    # Auto-registration instantiates without arguments, so a constructor
                    # that needs more than ``self`` can't be handled here.
                    log.warning(f"Invalid decomposer, __init__ with extra arguments: {module}")


def _register_decomposers():
Expand Down
24 changes: 21 additions & 3 deletions nvflare/private/fed/app/client/client_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,32 @@
import time

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLContextKey, JobConstants, SiteType, WorkspaceConstants
from nvflare.apis.fl_constant import (
ConfigVarName,
FLContextKey,
JobConstants,
SiteType,
SystemConfigs,
WorkspaceConstants,
)
from nvflare.apis.workspace import Workspace
from nvflare.fuel.common.excepts import ConfigError
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.private.defs import AppFolderConstants
from nvflare.private.fed.app.fl_conf import FLClientStarterConfiger, create_privacy_manager
from nvflare.private.fed.app.utils import component_security_check, version_check
from nvflare.private.fed.client.admin import FedAdminAgent
from nvflare.private.fed.client.client_engine import ClientEngine
from nvflare.private.fed.client.client_status import ClientStatus
from nvflare.private.fed.client.fed_client import FederatedClient
from nvflare.private.fed.utils.fed_utils import add_logfile_handler, fobs_initialize, security_init
from nvflare.private.fed.utils.fed_utils import (
add_logfile_handler,
fobs_initialize,
register_ext_decomposers,
security_init,
)
from nvflare.private.privacy_manager import PrivacyService
from nvflare.security.logging import secure_format_exception

Expand Down Expand Up @@ -66,7 +79,7 @@ def main(args):

try:
os.chdir(args.workspace)
fobs_initialize()
fobs_initialize(workspace)

conf = FLClientStarterConfiger(
workspace=workspace,
Expand All @@ -75,6 +88,11 @@ def main(args):
)
conf.configure()

decomposer_module = ConfigService.get_str_var(
yhwen marked this conversation as resolved.
Show resolved Hide resolved
name=ConfigVarName.DECOMPOSER_MODULE, conf=SystemConfigs.RESOURCES_CONF
)
register_ext_decomposers(decomposer_module)

log_file = workspace.get_log_file_path()
add_logfile_handler(log_file)

Expand Down
4 changes: 2 additions & 2 deletions nvflare/private/fed/app/client/sub_worker_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,11 @@ def stop(self):
def main(args):
workspace = Workspace(args.workspace, args.client_name)
app_custom_folder = workspace.get_client_custom_dir()
if os.path.isdir(app_custom_folder):
if os.path.isdir(app_custom_folder) and app_custom_folder not in sys.path:
sys.path.append(app_custom_folder)
configure_logging(workspace)

fobs_initialize()
fobs_initialize(workspace=workspace, job_id=args.job_id)
yhwen marked this conversation as resolved.
Show resolved Hide resolved

SecurityContentService.initialize(content_folder=workspace.get_startup_kit_dir())

Expand Down
11 changes: 9 additions & 2 deletions nvflare/private/fed/app/client/worker_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
import sys
import threading

from nvflare.apis.fl_constant import FLContextKey, JobConstants
from nvflare.apis.fl_constant import ConfigVarName, FLContextKey, JobConstants, SystemConfigs
from nvflare.apis.overseer_spec import SP
from nvflare.apis.workspace import Workspace
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.sec.audit import AuditService
from nvflare.fuel.sec.security_content_service import SecurityContentService
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.private.defs import EngineConstant
from nvflare.private.fed.app.fl_conf import FLClientStarterConfiger
from nvflare.private.fed.app.utils import monitor_parent_process
Expand All @@ -36,6 +37,7 @@
add_logfile_handler,
create_stats_pool_files_for_job,
fobs_initialize,
register_ext_decomposers,
set_stats_pool_config_for_job,
)
from nvflare.security.logging import secure_format_exception
Expand Down Expand Up @@ -69,7 +71,7 @@ def main(args):
if os.path.exists(restart_file):
os.remove(restart_file)

fobs_initialize()
fobs_initialize(workspace=workspace, job_id=args.job_id)
# Initialize audit service since the job execution will need it!
audit_file_name = workspace.get_audit_file_path()
AuditService.initialize(audit_file_name)
Expand All @@ -94,6 +96,11 @@ def main(args):
)
conf.configure()

decomposer_module = ConfigService.get_str_var(
name=ConfigVarName.DECOMPOSER_MODULE, conf=SystemConfigs.RESOURCES_CONF
)
register_ext_decomposers(decomposer_module)

log_file = workspace.get_app_log_file_path(args.job_id)
add_logfile_handler(log_file)
logger = logging.getLogger("worker_process")
Expand Down
4 changes: 2 additions & 2 deletions nvflare/private/fed/app/fl_conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(self, workspace: Workspace, args, kv_list=None):
kv_list: key value pair list
"""
site_custom_folder = workspace.get_site_custom_dir()
if os.path.isdir(site_custom_folder):
if os.path.isdir(site_custom_folder) and site_custom_folder not in sys.path:
sys.path.append(site_custom_folder)

self.args = args
Expand Down Expand Up @@ -219,7 +219,7 @@ def __init__(self, workspace: Workspace, args, kv_list=None):
kv_list: key value pair list
"""
site_custom_folder = workspace.get_site_custom_dir()
if os.path.isdir(site_custom_folder):
if os.path.isdir(site_custom_folder) and site_custom_folder not in sys.path:
sys.path.append(site_custom_folder)

self.args = args
Expand Down
11 changes: 9 additions & 2 deletions nvflare/private/fed/app/server/runner_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
import sys
import threading

from nvflare.apis.fl_constant import JobConstants
from nvflare.apis.fl_constant import ConfigVarName, JobConstants, SystemConfigs
from nvflare.apis.workspace import Workspace
from nvflare.fuel.common.excepts import ConfigError
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.sec.audit import AuditService
from nvflare.fuel.sec.security_content_service import SecurityContentService
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.private.defs import AppFolderConstants
from nvflare.private.fed.app.fl_conf import FLServerStarterConfiger
from nvflare.private.fed.app.utils import monitor_parent_process
Expand All @@ -36,6 +37,7 @@
add_logfile_handler,
create_stats_pool_files_for_job,
fobs_initialize,
register_ext_decomposers,
set_stats_pool_config_for_job,
)
from nvflare.security.logging import secure_format_exception, secure_log_traceback
Expand Down Expand Up @@ -64,7 +66,7 @@ def main(args):

try:
os.chdir(args.workspace)
fobs_initialize()
fobs_initialize(workspace=workspace, job_id=args.job_id)

SecurityContentService.initialize(content_folder=workspace.get_startup_kit_dir())

Expand Down Expand Up @@ -97,6 +99,11 @@ def main(args):
deployer = conf.deployer
secure_train = conf.cmd_vars.get("secure_train", False)

decomposer_module = ConfigService.get_str_var(
name=ConfigVarName.DECOMPOSER_MODULE, conf=SystemConfigs.RESOURCES_CONF
)
register_ext_decomposers(decomposer_module)

try:
# create the FL server
server_config, server = deployer.create_fl_server(args, secure_train=secure_train)
Expand Down
18 changes: 15 additions & 3 deletions nvflare/private/fed/app/server/server_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,22 @@
import sys
import time

from nvflare.apis.fl_constant import JobConstants, SiteType, WorkspaceConstants
from nvflare.apis.fl_constant import ConfigVarName, JobConstants, SiteType, SystemConfigs, WorkspaceConstants
from nvflare.apis.workspace import Workspace
from nvflare.fuel.common.excepts import ConfigError
from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm
from nvflare.fuel.utils.argument_utils import parse_vars
from nvflare.fuel.utils.config_service import ConfigService
from nvflare.private.defs import AppFolderConstants
from nvflare.private.fed.app.fl_conf import FLServerStarterConfiger, create_privacy_manager
from nvflare.private.fed.app.utils import create_admin_server, version_check
from nvflare.private.fed.server.server_status import ServerStatus
from nvflare.private.fed.utils.fed_utils import add_logfile_handler, fobs_initialize, security_init
from nvflare.private.fed.utils.fed_utils import (
add_logfile_handler,
fobs_initialize,
register_ext_decomposers,
security_init,
)
from nvflare.private.privacy_manager import PrivacyService
from nvflare.security.logging import secure_format_exception

Expand Down Expand Up @@ -64,13 +70,14 @@ def main(args):
try:
os.chdir(args.workspace)

fobs_initialize()
fobs_initialize(workspace)

conf = FLServerStarterConfiger(
workspace=workspace,
args=args,
kv_list=args.set,
)

log_level = os.environ.get("FL_LOG_LEVEL", "")
numeric_level = getattr(logging, log_level.upper(), None)
if isinstance(numeric_level, int):
Expand All @@ -82,6 +89,11 @@ def main(args):
logger.critical("loglevel critical enabled")
conf.configure()

decomposer_module = ConfigService.get_str_var(
yhwen marked this conversation as resolved.
Show resolved Hide resolved
name=ConfigVarName.DECOMPOSER_MODULE, conf=SystemConfigs.RESOURCES_CONF
)
register_ext_decomposers(decomposer_module)

log_file = workspace.get_log_file_path()
add_logfile_handler(log_file)

Expand Down