Skip to content

Commit

Permalink
Job CLI bug fix: make sure we use the existing meta.conf from job_tem…
Browse files Browse the repository at this point in the history
…plate (#1980)

* bug fix. make sure we use the existing meta.conf from job_template first before we use the default one.

* change job_template_dir to job_templates_dir

* remove un-used code

* remove un-used code

* rename job_template_dir to job_templates_dir
  • Loading branch information
chesterxgchen committed Sep 8, 2023
1 parent f72cc22 commit 23ee751
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 53 deletions.
4 changes: 2 additions & 2 deletions nvflare/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def def_config_parser(sub_cmd):
"-pw", "--poc_workspace_dir", type=str, nargs="?", default=None, help="POC workspace location"
)
config_parser.add_argument(
"-jt", "--job_template_dir", type=str, nargs="?", default=None, help="job template location"
"-jt", "--job_templates_dir", type=str, nargs="?", default=None, help="job templates location"
)
config_parser.add_argument("-debug", "--debug", action="store_true", help="debug is on")
return {cmd: config_parser}
Expand All @@ -118,7 +118,7 @@ def handle_config_cmd(args):

nvflare_config = create_startup_kit_config(nvflare_config, args.startup_kit_dir)
nvflare_config = create_poc_workspace_config(nvflare_config, args.poc_workspace_dir)
nvflare_config = create_job_template_config(nvflare_config, args.job_template_dir)
nvflare_config = create_job_template_config(nvflare_config, args.job_templates_dir)

save_config(nvflare_config, config_file_path)

Expand Down
5 changes: 0 additions & 5 deletions nvflare/tool/job/config/config_exchange.conf

This file was deleted.

56 changes: 25 additions & 31 deletions nvflare/tool/job/job_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
JOB_TEMPLATE_CONF,
)
from nvflare.utils.cli_utils import (
find_job_template_location,
find_job_templates_location,
get_curr_dir,
get_hidden_nvflare_dir,
get_startup_kit_dir,
Expand All @@ -72,12 +72,12 @@ def find_filename_basename(f: str):
return basename


def build_job_template_indices(job_template_dir: str) -> ConfigTree:
def build_job_template_indices(job_templates_dir: str) -> ConfigTree:
conf = CF.parse_string("{ templates = {} }")
config_file_base_names = CONFIG_FILE_BASE_NAME_WO_EXTS
template_conf = conf.get("templates")
keys = JOB_INFO_KEYS
for root, dirs, files in os.walk(job_template_dir):
for root, dirs, files in os.walk(job_templates_dir):
config_files = [f for f in files if find_filename_basename(f) in config_file_base_names]
if len(config_files) > 0:
info_conf = get_template_info_config(root)
Expand All @@ -104,8 +104,8 @@ def get_template_info_config(template_dir):
def create_job(cmd_args):
try:
prepare_job_folder(cmd_args)
job_template_dir = find_job_template_location()
template_index_conf = build_job_template_indices(job_template_dir)
job_templates_dir = find_job_templates_location()
template_index_conf = build_job_template_indices(job_templates_dir)
job_folder = cmd_args.job_folder
config_dir = get_config_dir(job_folder)

Expand All @@ -120,9 +120,9 @@ def create_job(cmd_args):

target_template_name = cmd_args.template
check_template_exists(target_template_name, template_index_conf)
src = os.path.join(job_template_dir, target_template_name)
src = os.path.join(job_templates_dir, target_template_name)
copy_tree(src=src, dst=config_dir)
prepare_meta_config(cmd_args)
prepare_meta_config(cmd_args, src)
remove_extra_file(config_dir)
variable_values = prepare_job_config(cmd_args)
display_template_variables(job_folder, variable_values)
Expand Down Expand Up @@ -214,13 +214,13 @@ def display_template_variables(job_folder, variable_values):

def list_templates(cmd_args):
try:
job_template_dir = find_job_template_location(cmd_args.job_template_dir)
job_template_dir = os.path.abspath(job_template_dir)
template_index_conf = build_job_template_indices(job_template_dir)
job_templates_dir = find_job_templates_location(cmd_args.job_templates_dir)
job_templates_dir = os.path.abspath(job_templates_dir)
template_index_conf = build_job_template_indices(job_templates_dir)
display_available_templates(template_index_conf)

if job_template_dir:
update_job_template_dir(job_template_dir)
if job_templates_dir:
update_job_templates_dir(job_templates_dir)

except ValueError as e:
print(f"\nUnable to handle command: {CMD_LIST_TEMPLATES} due to: {e} \n")
Expand All @@ -231,11 +231,11 @@ def list_templates(cmd_args):
sub_cmd_parser.print_help()


def update_job_template_dir(job_template_dir: str):
def update_job_templates_dir(job_templates_dir: str):
hidden_nvflare_dir = get_hidden_nvflare_dir()
file_path = os.path.join(hidden_nvflare_dir, CONFIG_CONF)
config = CF.parse_file(file_path)
config.put(f"{JOB_TEMPLATE}.path", job_template_dir)
config.put(f"{JOB_TEMPLATE}.path", job_templates_dir)
save_config(config, file_path)


Expand Down Expand Up @@ -394,7 +394,7 @@ def define_list_templates_parser(job_subparser):
show_jobs_parser = job_subparser.add_parser("list_templates", help="show available job templates")
show_jobs_parser.add_argument(
"-d",
"--job_template_dir",
"--job_templates_dir",
type=str,
nargs="?",
default=None,
Expand Down Expand Up @@ -529,16 +529,8 @@ def save_merged_configs(merged_conf, tmp_job_dir):
save_config(root_index.value, dst_path)


def prepare_model_exchange_config(job_folder: str, force: bool):
dst_path = dst_config_path(job_folder, "config_exchange.conf")
if os.path.isfile(dst_path) and not force:
return
def prepare_meta_config(cmd_args, target_template_dir):

dst_config = load_src_config_template("config_exchange.conf")
save_config(dst_config, dst_path)


def prepare_meta_config(cmd_args):
job_folder = cmd_args.job_folder
job_folder = job_folder[:-1] if job_folder.endswith("/") else job_folder

Expand All @@ -551,15 +543,17 @@ def prepare_meta_config(cmd_args):
dst_path = meta_path
break

src_meta_path = os.path.join(target_template_dir, "meta.conf")
if not os.path.isfile(src_meta_path):
dst_config = load_default_config_template("meta.conf")
else:
dst_config = CF.parse_file(src_meta_path)

# Use existing meta.conf if user already defined it.
if not dst_path:
dst_config = load_src_config_template("meta.conf")
if not dst_path or (dst_path and cmd_args.force):
dst_config.put("name", app_name)
dst_path = os.path.join(job_folder, "meta.conf")
else:
dst_config = CF.from_dict(ConfigFactory.load_config(dst_path).to_dict())

save_config(dst_config, dst_path)
save_config(dst_config, dst_path)

# clean up
config_dir = get_config_dir(job_folder)
Expand All @@ -569,7 +563,7 @@ def prepare_meta_config(cmd_args):
os.remove(meta_path)


def load_src_config_template(config_file_name: str):
def load_default_config_template(config_file_name: str):
file_dir = os.path.dirname(__file__)
# src config here is always pyhocon
config_template = CF.parse_file(os.path.join(file_dir, f"config/{config_file_name}"))
Expand Down
30 changes: 15 additions & 15 deletions nvflare/utils/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,22 +126,22 @@ def create_poc_workspace_config(nvflare_config: ConfigTree, poc_workspace_dir: O
return conf.with_fallback(nvflare_config)


def create_job_template_config(nvflare_config: ConfigTree, job_template_dir: Optional[str] = None) -> ConfigTree:
def create_job_template_config(nvflare_config: ConfigTree, job_templates_dir: Optional[str] = None) -> ConfigTree:
"""
Args:
job_template_dir: specified job template directory
job_templates_dir: specified job template directory
nvflare_config (ConfigTree): The existing nvflare configuration.
Returns:
ConfigTree: The merged configuration tree.
"""
if job_template_dir is None:
if job_templates_dir is None:
return nvflare_config

job_template_dir = os.path.abspath(job_template_dir)
job_templates_dir = os.path.abspath(job_templates_dir)
conf_str = f"""
job_template {{
path = {job_template_dir}
path = {job_templates_dir}
}}
"""
conf: ConfigTree = CF.parse_string(conf_str)
Expand Down Expand Up @@ -178,31 +178,31 @@ def check_startup_dir(startup_kit_dir):
)


def find_job_template_location(job_template_dir: Optional[str] = None):
def check_job_template_dir(job_temp_dir: str):
def find_job_templates_location(job_templates_dir: Optional[str] = None):
def check_job_templates_dir(job_temp_dir: str):
if job_temp_dir:
if not os.path.isdir(job_temp_dir):
raise ValueError(f"Invalid job template directory {job_temp_dir}")

if job_template_dir is None:
if job_templates_dir is None:
nvflare_home = os.environ.get("NVFLARE_HOME", None)
if nvflare_home:
job_template_dir = os.path.join(nvflare_home, JOB_TEMPLATES)
job_templates_dir = os.path.join(nvflare_home, JOB_TEMPLATES)

if job_template_dir is None:
if job_templates_dir is None:
nvflare_config = load_hidden_config()
job_template_dir = nvflare_config.get_string("job_template.path", None) if nvflare_config else None
job_templates_dir = nvflare_config.get_string("job_template.path", None) if nvflare_config else None

if job_template_dir:
check_job_template_dir(job_template_dir)
if job_templates_dir:
check_job_templates_dir(job_templates_dir)

if not job_template_dir:
if not job_templates_dir:
raise ValueError(
"Required job_template directory is not specified. "
"Please check ~/.nvflare/config.conf or set env variable NVFLARE_HOME "
)

return job_template_dir
return job_templates_dir


def get_curr_dir():
Expand Down

0 comments on commit 23ee751

Please sign in to comment.