Commit: fix ckpt issue

senwu committed Dec 17, 2019
1 parent 053a55c commit a535dd8
Showing 3 changed files with 161 additions and 6 deletions.
src/emmental/meta.py (6 changes: 4 additions & 2 deletions)
```diff
@@ -156,7 +156,7 @@ def update_config(
         """
 
         if config != {}:
-            Meta.config = merge(Meta.config, config)
+            Meta.config = merge(Meta.config, config, specical_keys="checkpoint_metric")
             logger.info("Updating Emmental config from user provided config.")
 
         if path is not None:
@@ -166,7 +166,9 @@ def update_config(
                 potential_path = os.path.join(current_dir, filename)
                 if os.path.exists(potential_path):
                     with open(potential_path, "r") as f:
-                        Meta.config = merge(Meta.config, yaml.load(f))
+                        Meta.config = merge(
+                            Meta.config, yaml.load(f), specical_keys="checkpoint_metric"
+                        )
                     logger.info(f"Updating Emmental config from {potential_path}.")
                     break
 
```
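Both call sites now pass `specical_keys="checkpoint_metric"` so that a user-supplied `checkpoint_metric` replaces the default entry outright rather than being merged into it: `checkpoint_metric` is a mapping from metric name to min/max mode, and recursive merging would leave the default metric sitting next to the override. A minimal sketch of an override through `Meta.update_config`, assuming `Meta` has already been initialized; the metric name here is illustrative:

```python
from emmental.meta import Meta

# Hypothetical override; with plain recursive merging the default
# checkpoint_metric entry would have survived alongside this one.
Meta.update_config(
    config={
        "logging_config": {
            "checkpointer_config": {
                "checkpoint_metric": {"model/valid/all/accuracy": "max"}
            }
        }
    }
)
```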
src/emmental/utils/utils.py (15 changes: 12 additions & 3 deletions)
```diff
@@ -205,12 +205,16 @@ def array_to_numpy(
     return array
 
 
-def merge(x: Dict[str, Any], y: Dict[str, Any]) -> Dict[str, Any]:
+def merge(
+    x: Dict[str, Any], y: Dict[str, Any], specical_keys: Union[str, List[str]] = None
+) -> Dict[str, Any]:
     r"""Merge two nested dictionaries. Overwrite values in x with values in y.
 
     Args:
       x(dict): The original dict.
       y(dict): The new dict.
+      specical_keys(str or list of str): The special keys to replace
+        instead of merging, defaults to None.
 
     Returns:
       dict: The updated dict.
@@ -222,13 +226,18 @@ def merge(x: Dict[str, Any], y: Dict[str, Any]) -> Dict[str, Any]:
     if y is None:
         return x
 
+    if isinstance(specical_keys, str):
+        specical_keys = [specical_keys]
+
     merged = {**x, **y}
 
     xkeys = x.keys()
 
     for key in xkeys:
-        if isinstance(x[key], dict) and key in y:
-            merged[key] = merge(x[key], y[key])
+        if specical_keys is not None and key in specical_keys and key in y:
+            merged[key] = y[key]
+        elif isinstance(x[key], dict) and key in y:
+            merged[key] = merge(x[key], y[key], specical_keys)
 
     return merged
```

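A minimal sketch of the new `merge` semantics; the nested dicts below are illustrative stand-ins, not the full Emmental defaults:

```python
from emmental.utils.utils import merge

defaults = {
    "checkpointer_config": {"checkpoint_metric": {"model/train/all/loss": "min"}}
}
override = {
    "checkpointer_config": {"checkpoint_metric": {"model/valid/all/accuracy": "max"}}
}

# Plain recursive merging keeps both metric entries:
assert merge(defaults, override) == {
    "checkpointer_config": {
        "checkpoint_metric": {
            "model/train/all/loss": "min",
            "model/valid/all/accuracy": "max",
        }
    }
}

# Naming checkpoint_metric as a special key replaces its value wholesale:
assert merge(defaults, override, specical_keys="checkpoint_metric") == override
```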
tests/utils/test_parse_args.py (146 changes: 145 additions & 1 deletion)
```diff
@@ -149,8 +149,152 @@ def test_parse_args(caplog):
     args = parser.parse_args([])
     config1 = parse_args_to_config(args)
     config2 = emmental.Meta.config
-    # import pdb; pdb.set_trace()
 
     del config2["learner_config"]["global_evaluation_metric_dict"]
     assert config1 == config2
 
     shutil.rmtree(dirpath)
+
+
+def test_checkpoint_metric(caplog):
+    """Unit test of parsing checkpoint metric"""
+
+    caplog.set_level(logging.INFO)
+
+    # Test different checkpoint_metric
+    dirpath = "temp_parse_args"
+    Meta.reset()
+    emmental.init(
+        log_dir=dirpath,
+        config={
+            "logging_config": {
+                "checkpointer_config": {
+                    "checkpoint_metric": {"model/valid/all/accuracy": "max"}
+                }
+            }
+        },
+    )
+
+    assert emmental.Meta.config == {
+        "meta_config": {"seed": None, "verbose": True, "log_path": "logs"},
+        "data_config": {"min_data_len": 0, "max_data_len": 0},
+        "model_config": {"model_path": None, "device": 0, "dataparallel": True},
+        "learner_config": {
+            "fp16": False,
+            "n_epochs": 1,
+            "train_split": ["train"],
+            "valid_split": ["valid"],
+            "test_split": ["test"],
+            "ignore_index": None,
+            "global_evaluation_metric_dict": None,
+            "optimizer_config": {
+                "optimizer": "adam",
+                "lr": 0.001,
+                "l2": 0.0,
+                "grad_clip": None,
+                "asgd_config": {"lambd": 0.0001, "alpha": 0.75, "t0": 1000000.0},
+                "adadelta_config": {"rho": 0.9, "eps": 1e-06},
+                "adagrad_config": {
+                    "lr_decay": 0,
+                    "initial_accumulator_value": 0,
+                    "eps": 1e-10,
+                },
+                "adam_config": {"betas": (0.9, 0.999), "amsgrad": False, "eps": 1e-08},
+                "adamw_config": {"betas": (0.9, 0.999), "amsgrad": False, "eps": 1e-08},
+                "adamax_config": {"betas": (0.9, 0.999), "eps": 1e-08},
+                "lbfgs_config": {
+                    "max_iter": 20,
+                    "max_eval": None,
+                    "tolerance_grad": 1e-07,
+                    "tolerance_change": 1e-09,
+                    "history_size": 100,
+                    "line_search_fn": None,
+                },
+                "rms_prop_config": {
+                    "alpha": 0.99,
+                    "eps": 1e-08,
+                    "momentum": 0,
+                    "centered": False,
+                },
+                "r_prop_config": {"etas": (0.5, 1.2), "step_sizes": (1e-06, 50)},
+                "sgd_config": {"momentum": 0, "dampening": 0, "nesterov": False},
+                "sparse_adam_config": {"betas": (0.9, 0.999), "eps": 1e-08},
+                "bert_adam_config": {"betas": (0.9, 0.999), "eps": 1e-08},
+            },
+            "lr_scheduler_config": {
+                "lr_scheduler": None,
+                "lr_scheduler_step_unit": "batch",
+                "lr_scheduler_step_freq": 1,
+                "warmup_steps": None,
+                "warmup_unit": "batch",
+                "warmup_percentage": None,
+                "min_lr": 0.0,
+                "exponential_config": {"gamma": 0.9},
+                "plateau_config": {
+                    "metric": "model/train/all/loss",
+                    "mode": "min",
+                    "factor": 0.1,
+                    "patience": 10,
+                    "threshold": 0.0001,
+                    "threshold_mode": "rel",
+                    "cooldown": 0,
+                    "eps": 1e-08,
+                },
+                "step_config": {"step_size": 1, "gamma": 0.1, "last_epoch": -1},
+                "multi_step_config": {
+                    "milestones": [1000],
+                    "gamma": 0.1,
+                    "last_epoch": -1,
+                },
+                "cyclic_config": {
+                    "base_lr": 0.001,
+                    "base_momentum": 0.8,
+                    "cycle_momentum": True,
+                    "gamma": 1.0,
+                    "last_epoch": -1,
+                    "max_lr": 0.1,
+                    "max_momentum": 0.9,
+                    "mode": "triangular",
+                    "scale_fn": None,
+                    "scale_mode": "cycle",
+                    "step_size_down": None,
+                    "step_size_up": 2000,
+                },
+                "one_cycle_config": {
+                    "anneal_strategy": "cos",
+                    "base_momentum": 0.85,
+                    "cycle_momentum": True,
+                    "div_factor": 25.0,
+                    "final_div_factor": 10000.0,
+                    "last_epoch": -1,
+                    "max_lr": 0.1,
+                    "max_momentum": 0.95,
+                    "pct_start": 0.3,
+                },
+                "cosine_annealing_config": {"last_epoch": -1},
+            },
+            "task_scheduler_config": {
+                "task_scheduler": "round_robin",
+                "sequential_scheduler_config": {"fillup": False},
+                "round_robin_scheduler_config": {"fillup": False},
+                "mixed_scheduler_config": {"fillup": False},
+            },
+        },
+        "logging_config": {
+            "counter_unit": "epoch",
+            "evaluation_freq": 1,
+            "writer_config": {"writer": "tensorboard", "verbose": True},
+            "checkpointing": False,
+            "checkpointer_config": {
+                "checkpoint_path": None,
+                "checkpoint_freq": 1,
+                "checkpoint_metric": {"model/valid/all/accuracy": "max"},
+                "checkpoint_task_metrics": None,
+                "checkpoint_runway": 0,
+                "clear_intermediate_checkpoints": True,
+                "clear_all_checkpoints": False,
+            },
+        },
+    }
+
+    shutil.rmtree(dirpath)
```
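Condensed, the behavior the new test pins down (assuming an installed emmental; the scratch directory matches the one in the test):

```python
import shutil

import emmental
from emmental.meta import Meta

Meta.reset()
emmental.init(
    log_dir="temp_parse_args",
    config={
        "logging_config": {
            "checkpointer_config": {
                "checkpoint_metric": {"model/valid/all/accuracy": "max"}
            }
        }
    },
)

# The user-supplied metric fully replaces the default instead of merging in.
ckpt = emmental.Meta.config["logging_config"]["checkpointer_config"]
assert ckpt["checkpoint_metric"] == {"model/valid/all/accuracy": "max"}

shutil.rmtree("temp_parse_args")
```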
