
Commit

fixup finetune problem
Summary: support finetuning from another model with a different number of classes, and simplify the calling convention (#325)

close #325

L1aoXingyu committed Nov 6, 2020
1 parent f496193 commit 7e9a477
Showing 3 changed files with 114 additions and 41 deletions.
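
With this change, finetuning from a model trained with a different number of classes no longer needs a dedicated --finetune flag: pointing MODEL.WEIGHTS at the pretrained checkpoint is enough, and any parameter whose shape does not match the new model (typically the classifier head) is skipped while loading. A rough sketch of the simplified invocation, with config and checkpoint paths as placeholders:

    python tools/train_net.py --config-file configs/Market1501/bagtricks_R50.yml \
        MODEL.WEIGHTS /path/to/pretrained_model.pth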
15 changes: 7 additions & 8 deletions fastreid/engine/defaults.py
@@ -44,11 +44,6 @@ def default_argument_parser():
"""
parser = argparse.ArgumentParser(description="fastreid Training")
parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
parser.add_argument(
"--finetune",
action="store_true",
help="whether to attempt to finetune from the trained model",
)
parser.add_argument(
"--resume",
action="store_true",
@@ -244,8 +239,13 @@ def __init__(self, cfg):

def resume_or_load(self, resume=True):
"""
If `resume==True`, and last checkpoint exists, resume from it.
Otherwise, load a model specified by the config.
If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
a `last_checkpoint` file), resume from the file. Resuming means loading all
available states (e.g. optimizer and scheduler) and updating the iteration counter
from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
Otherwise, this is considered an independent training run. The method will load model
weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
from iteration 0.
Args:
resume (bool): whether to do resume or not
"""
@@ -468,7 +468,6 @@ def auto_scale_hyperparams(cfg, data_loader):
because some hyper-param, such as MAX_ITER, means training epochs rather than iters,
so we need to convert specific hyper-param to training iterations.
"""

cfg = cfg.clone()
frozen = cfg.is_frozen()
cfg.defrost()
139 changes: 107 additions & 32 deletions fastreid/utils/checkpoint.py
@@ -1,12 +1,12 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import collections
import copy
import logging
import os
from collections import defaultdict
from typing import Any
from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable

import numpy as np
import torch
@@ -17,6 +17,23 @@
from fastreid.utils.file_io import PathManager


class _IncompatibleKeys(
NamedTuple(
# pyre-fixme[10]: Name `IncompatibleKeys` is used but not defined.
"IncompatibleKeys",
[
("missing_keys", List[str]),
("unexpected_keys", List[str]),
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
("incorrect_shapes", List[Tuple]),
],
)
):
pass


class Checkpointer(object):
"""
A checkpointer that can save/load model as well as extra checkpointable
@@ -50,7 +67,9 @@ def __init__(
self.save_dir = save_dir
self.save_to_disk = save_to_disk

def save(self, name: str, **kwargs: dict):
self.path_manager = PathManager

def save(self, name: str, **kwargs: Dict[str, str]):
"""
Dump model and checkpointables to a file.
Args:
@@ -74,13 +93,15 @@ def save(self, name: str, **kwargs: dict):
torch.save(data, f)
self.tag_last_checkpoint(basename)

def load(self, path: str):
def load(self, path: str, checkpointables: Optional[List[str]] = None) -> object:
"""
Load from the given checkpoint. When path points to network file, this
function has to be called on all ranks.
Args:
path (str): path or url to the checkpoint. If empty, will not load
anything.
checkpointables (list): List of checkpointable names to load. If not
specified (None), will load all the possible checkpointables.
Returns:
dict:
extra data loaded from the checkpoint that has not been
@@ -89,21 +110,25 @@ def load(self, path: str):
"""
if not path:
# no checkpoint provided
self.logger.info(
"No checkpoint found. Training model from scratch"
)
self.logger.info("No checkpoint found. Training model from scratch")
return {}
self.logger.info("Loading checkpoint from {}".format(path))
if not os.path.isfile(path):
path = PathManager.get_local_path(path)
path = self.path_manager.get_local_path(path)
assert os.path.isfile(path), "Checkpoint {} not found!".format(path)

checkpoint = self._load_file(path)
self._load_model(checkpoint)
for key, obj in self.checkpointables.items():
if key in checkpoint:
incompatible = self._load_model(checkpoint)
if (
incompatible is not None
): # handle some existing subclasses that return None
self._log_incompatible_keys(incompatible)

for key in self.checkpointables if checkpointables is None else checkpointables:
if key in checkpoint: # pyre-ignore
self.logger.info("Loading {} from {}".format(key, path))
obj.load_state_dict(checkpoint.pop(key))
obj = self.checkpointables[key]
obj.load_state_dict(checkpoint.pop(key)) # pyre-ignore

# return any further checkpoint data
return checkpoint
@@ -158,7 +183,9 @@ def resume_or_load(self, path: str, *, resume: bool = True):
"""
if resume and self.has_checkpoint():
path = self.get_checkpoint_file()
return self.load(path)
return self.load(path)
else:
return self.load(path, checkpointables=[])

def tag_last_checkpoint(self, last_filename_basename: str):
"""
@@ -199,26 +226,40 @@ def _load_model(self, checkpoint: Any):

# work around https://github.com/pytorch/pytorch/issues/24139
model_state_dict = self.model.state_dict()
incorrect_shapes = []
for k in list(checkpoint_state_dict.keys()):
if k in model_state_dict:
shape_model = tuple(model_state_dict[k].shape)
shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
if shape_model != shape_checkpoint:
self.logger.warning(
"'{}' has shape {} in the checkpoint but {} in the "
"model! Skipped.".format(
k, shape_checkpoint, shape_model
)
)
incorrect_shapes.append((k, shape_checkpoint, shape_model))
checkpoint_state_dict.pop(k)

incompatible = self.model.load_state_dict(
checkpoint_state_dict, strict=False
incompatible = self.model.load_state_dict(checkpoint_state_dict, strict=False)
return _IncompatibleKeys(
missing_keys=incompatible.missing_keys,
unexpected_keys=incompatible.unexpected_keys,
incorrect_shapes=incorrect_shapes,
)

def _log_incompatible_keys(self, incompatible: _IncompatibleKeys) -> None:
"""
Log information about the incompatible keys returned by ``_load_model``.
"""
for k, shape_checkpoint, shape_model in incompatible.incorrect_shapes:
self.logger.warning(
"Skip loading parameter '{}' to the model due to incompatible "
"shapes: {} in the checkpoint but {} in the "
"model! You might want to double check if this is expected.".format(
k, shape_checkpoint, shape_model
)
)
if incompatible.missing_keys:
self.logger.info(
get_missing_parameters_message(incompatible.missing_keys)
missing_keys = _filter_reused_missing_keys(
self.model, incompatible.missing_keys
)
if missing_keys:
self.logger.info(get_missing_parameters_message(missing_keys))
if incompatible.unexpected_keys:
self.logger.info(
get_unexpected_parameters_message(incompatible.unexpected_keys)
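
Taken together, _load_model and _log_incompatible_keys are what allow finetuning across a different number of classes: a parameter whose checkpoint shape disagrees with the model is dropped and reported rather than raising an error. A hedged sketch of the effect (the model builder, checkpoint path and parameter name below are illustrative, not taken from this commit):

    model = build_model(cfg)  # assuming fastreid's model builder; new head, e.g. 500 classes
    Checkpointer(model).load("logs/market1501/model_final.pth")  # head trained with 751 classes
    # expected log: "Skip loading parameter 'heads.classifier.weight' to the model due to
    # incompatible shapes: ..." -- the remaining weights are loaded as usual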
@@ -297,7 +338,27 @@ def save(self, name: str, **kwargs: Any):
self.checkpointer.save(name, **kwargs)


def get_missing_parameters_message(keys: list):
def _filter_reused_missing_keys(model: nn.Module, keys: List[str]) -> List[str]:
"""
Filter "missing keys" to not include keys that have been loaded with another name.
"""
keyset = set(keys)
param_to_names = defaultdict(set)  # param -> names that point to it
for module_prefix, module in _named_modules_with_dup(model):
for name, param in list(module.named_parameters(recurse=False)) + list(
module.named_buffers(recurse=False) # pyre-ignore
):
full_name = (module_prefix + "." if module_prefix else "") + name
param_to_names[param].add(full_name)
for names in param_to_names.values():
# if one name appears missing but its alias exists, then this
# name is not considered missing
if any(n in keyset for n in names) and not all(n in keyset for n in names):
[keyset.remove(n) for n in names if n in keyset]
return list(keyset)
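
A small self-contained illustration of the filtering above (module names are invented): when one tensor is registered under two names and only one of those names shows up as missing, the key is dropped because the weight was in fact loaded under its alias.

    import torch.nn as nn

    net = nn.Module()
    shared = nn.Linear(4, 4)
    net.head = shared
    net.tail = shared  # the same Linear registered under a second name

    # 'tail.*' looks missing, but its alias 'head.*' was loaded, so nothing is reported
    print(_filter_reused_missing_keys(net, ["tail.weight", "tail.bias"]))  # -> []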


def get_missing_parameters_message(keys: List[str]) -> str:
"""
Get a logging-friendly message to report parameter names (keys) that are in
the model but not found in a checkpoint.
@@ -307,14 +368,14 @@ def get_missing_parameters_message(keys: list):
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "Some model parameters are not in the checkpoint:\n"
msg = "Some model parameters or buffers are not found in the checkpoint:\n"
msg += "\n".join(
" " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
)
return msg


def get_unexpected_parameters_message(keys: list):
def get_unexpected_parameters_message(keys: List[str]) -> str:
"""
Get a logging-friendly message to report parameter names (keys) that are in
the checkpoint but not found in the model.
@@ -324,15 +385,14 @@ def get_unexpected_parameters_message(keys: list):
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "The checkpoint contains parameters not used by the model:\n"
msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
msg += "\n".join(
" " + colored(k + _group_to_str(v), "magenta")
for k, v in groups.items()
" " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
)
return msg


def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
"""
Strip the prefix in metadata, if any.
Args:
@@ -349,7 +409,7 @@ def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):

# also strip the prefix in metadata, if any..
try:
metadata = state_dict._metadata
metadata = state_dict._metadata # pyre-ignore
except AttributeError:
pass
else:
Expand All @@ -365,7 +425,7 @@ def _strip_prefix_if_present(state_dict: collections.OrderedDict, prefix: str):
metadata[newkey] = metadata.pop(key)


def _group_checkpoint_keys(keys: list):
def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
"""
Group keys based on common prefixes. A prefix is the string up to the final
"." in each key.
@@ -386,7 +446,7 @@ def _group_checkpoint_keys(keys: list):
return groups


def _group_to_str(group: list):
def _group_to_str(group: List[str]) -> str:
"""
Format a group of parameter name suffixes into a loggable string.
Args:
@@ -401,3 +461,18 @@ def _group_to_str(group: list):
return "." + group[0]

return ".{" + ", ".join(group) + "}"


def _named_modules_with_dup(
model: nn.Module, prefix: str = ""
) -> Iterable[Tuple[str, nn.Module]]:
"""
The same as `model.named_modules()`, except that it includes
duplicated modules that have more than one name.
"""
yield prefix, model
for name, module in model._modules.items(): # pyre-ignore
if module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
yield from _named_modules_with_dup(module, submodule_prefix)
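
Unlike nn.Module.named_modules(), which de-duplicates shared submodules, this generator yields every name a module is reachable under, which is exactly what _filter_reused_missing_keys needs. A tiny sketch with invented names:

    import torch.nn as nn

    bn = nn.BatchNorm2d(8)
    net = nn.Module()
    net.branch_a = bn
    net.branch_b = bn  # the same BatchNorm reused

    print([n for n, _ in net.named_modules()])           # ['', 'branch_a']
    print([n for n, _ in _named_modules_with_dup(net)])  # ['', 'branch_a', 'branch_b']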
1 change: 0 additions & 1 deletion tools/train_net.py
@@ -40,7 +40,6 @@ def main(args):
return res

trainer = DefaultTrainer(cfg)
if args.finetune: Checkpointer(trainer.model).load(cfg.MODEL.WEIGHTS) # load trained model to funetune

trainer.resume_or_load(resume=args.resume)
return trainer.train()
