Skip to content

Commit

Permalink
Merge eb886f4 into 4650ec0
Browse files Browse the repository at this point in the history
  • Loading branch information
senwu committed Jan 9, 2020
2 parents 4650ec0 + eb886f4 commit 8149fd2
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Unreleased_

Added
^^^^^
* `@senwu`_: Add `checkpoint_all` to control whether to save all checkpoints.
* `@senwu`_: Support `CosineAnnealingLR`, `CyclicLR`, `OneCycleLR`, `ReduceLROnPlateau`
lr scheduler.
* `@senwu`_: Support more unit tests.
Expand Down
1 change: 1 addition & 0 deletions src/emmental/emmental-default-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,6 @@ logging_config:
model/train/all/loss: min # metric_name: mode, where mode in [min, max]
checkpoint_task_metrics: # task_metric_name: mode
checkpoint_runway: 0 # checkpointing runway (no checkpointing before k unit)
checkpoint_all: False # whether to save all checkpoints
clear_intermediate_checkpoints: True # whether to clear intermediate checkpoints
clear_all_checkpoints: False # whether to clear all checkpoints
16 changes: 15 additions & 1 deletion src/emmental/logging/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import os
from shutil import copyfile
from typing import Any, Dict, Set, Union
from typing import Any, Dict, List, Set, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler
Expand Down Expand Up @@ -84,6 +84,13 @@ def __init__(self) -> None:
f"{self.checkpoint_unit}."
)

self.checkpoint_all = Meta.config["logging_config"]["checkpointer_config"][
"checkpoint_all"
]
logger.info(f"Checkpointing all checkpoints: {self.checkpoint_all}.")

self.checkpoint_paths: List[str] = []

# Set up checkpoint clear
self.clear_intermediate_checkpoints = Meta.config["logging_config"][
"checkpointer_config"
Expand Down Expand Up @@ -135,6 +142,13 @@ def checkpoint(
f"at {checkpoint_path}."
)

if self.checkpoint_all is False:
for path in self.checkpoint_paths:
if os.path.exists(path):
os.remove(path)

self.checkpoint_paths.append(checkpoint_path)

if not set(self.checkpoint_all_metrics.keys()).isdisjoint(
set(metric_dict.keys())
):
Expand Down
8 changes: 8 additions & 0 deletions src/emmental/utils/parse_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,13 @@ def parse_args(parser: Optional[ArgumentParser] = None) -> ArgumentParser:
help="Checkpointing runway (no checkpointing before k checkpointing unit)",
)

logging_config.add_argument(
"--checkpoint_all",
type=str2bool,
default=False,
help="Whether to checkpoint all checkpoints",
)

logging_config.add_argument(
"--clear_intermediate_checkpoints",
type=str2bool,
Expand Down Expand Up @@ -960,6 +967,7 @@ def parse_args_to_config(args: Namespace) -> Dict[str, Any]:
"checkpoint_metric": args.checkpoint_metric,
"checkpoint_task_metrics": args.checkpoint_task_metrics,
"checkpoint_runway": args.checkpoint_runway,
"checkpoint_all": args.checkpoint_all,
"clear_intermediate_checkpoints": args.clear_intermediate_checkpoints,
"clear_all_checkpoints": args.clear_all_checkpoints,
},
Expand Down
1 change: 1 addition & 0 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def test_e2e(caplog):
"checkpoint_metric": {"model/all/train/loss": "min"},
"checkpoint_task_metrics": None,
"checkpoint_runway": 0,
"checkpoint_all": False,
"clear_intermediate_checkpoints": True,
"clear_all_checkpoints": False,
},
Expand Down
4 changes: 3 additions & 1 deletion tests/utils/test_parse_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_parse_args(caplog):
caplog.set_level(logging.INFO)

parser = parse_args()
args = parser.parse_args(["--seed", "0"])
args = parser.parse_args(["--seed", "0", "--checkpoint_all", "True"])
assert args.seed == 0

config = parse_args_to_config(args)
Expand Down Expand Up @@ -134,6 +134,7 @@ def test_parse_args(caplog):
"checkpoint_metric": {"model/train/all/loss": "min"},
"checkpoint_task_metrics": None,
"checkpoint_runway": 0,
"checkpoint_all": True,
"clear_intermediate_checkpoints": True,
"clear_all_checkpoints": False,
},
Expand Down Expand Up @@ -291,6 +292,7 @@ def test_checkpoint_metric(caplog):
"checkpoint_metric": {"model/valid/all/accuracy": "max"},
"checkpoint_task_metrics": None,
"checkpoint_runway": 0,
"checkpoint_all": False,
"clear_intermediate_checkpoints": True,
"clear_all_checkpoints": False,
},
Expand Down

0 comments on commit 8149fd2

Please sign in to comment.