This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
module.py
45 lines (37 loc) · 1.87 KB
/
module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from typing import List, Optional, Tuple
import torch
from allennlp.nn.util import (
_check_incompatible_keys,
_IncompatibleKeys,
StateDictType,
load_state_dict_distributed,
)
class Module(torch.nn.Module):
    """
    This is just `torch.nn.Module` with some extra functionality.

    Both state-dict loading entry points below load with strictness disabled, give subclasses a
    chance to adjust the reported key lists via
    [`_post_load_state_dict`](#_post_load_state_dict), and only then apply the caller's
    `strict` setting to whatever keys remain.
    """

    def _post_load_state_dict(
        self, missing_keys: List[str], unexpected_keys: List[str]
    ) -> Tuple[List[str], List[str]]:
        """
        Subclasses can override this and potentially modify `missing_keys` or `unexpected_keys`.

        The default implementation returns both lists unchanged.

        # Parameters

        missing_keys : `List[str]`
            Keys this module expected but that were absent from the loaded state dict.
        unexpected_keys : `List[str]`
            Keys present in the loaded state dict that this module did not expect.

        # Returns

        `Tuple[List[str], List[str]]`
            The (possibly modified) `missing_keys` and `unexpected_keys`, in that order.
        """
        return missing_keys, unexpected_keys

    def load_state_dict(self, state_dict: StateDictType, strict: bool = True) -> _IncompatibleKeys:
        """
        Same as [`torch.nn.Module.load_state_dict()`]
        (https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict)
        except we also run the [`_post_load_state_dict`](#_post_load_state_dict) method before returning,
        which can be implemented by subclasses to customize the behavior.

        # Parameters

        state_dict : `StateDictType`
            The state dict to load into this module.
        strict : `bool`, optional (default = `True`)
            Whether to enforce that no missing or unexpected keys remain *after*
            `_post_load_state_dict` has run. The enforcement itself is delegated to
            `_check_incompatible_keys`.

        # Returns

        `_IncompatibleKeys`
            The missing and unexpected keys as returned by `_post_load_state_dict`.
        """
        # Load non-strictly so keys the subclass hook chooses to discard cannot trigger an
        # error before the hook has had a chance to run.
        missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False)  # type: ignore[arg-type]
        missing_keys, unexpected_keys = self._post_load_state_dict(missing_keys, unexpected_keys)
        # Apply the caller's strictness only to the keys that survived the hook.
        _check_incompatible_keys(self, missing_keys, unexpected_keys, strict)
        return _IncompatibleKeys(missing_keys, unexpected_keys)

    def load_state_dict_distributed(
        self, state_dict: Optional[StateDictType], strict: bool = True
    ) -> _IncompatibleKeys:
        """
        Same as [`load_state_dict()`](#load_state_dict), but loads through
        `allennlp.nn.util.load_state_dict_distributed`, which (per this signature) also accepts
        `state_dict=None` -- presumably for processes that do not hold the weights themselves;
        see that utility for the exact contract.

        The [`_post_load_state_dict`](#_post_load_state_dict) hook is run before `strict` is
        enforced, exactly as in `load_state_dict()`.

        # Parameters

        state_dict : `Optional[StateDictType]`
            The state dict to load, or `None`.
        strict : `bool`, optional (default = `True`)
            Whether to enforce that no missing or unexpected keys remain *after*
            `_post_load_state_dict` has run. The enforcement itself is delegated to
            `_check_incompatible_keys`.

        # Returns

        `_IncompatibleKeys`
            The missing and unexpected keys as returned by `_post_load_state_dict`.
        """
        # NOTE: this bare name resolves to the module-level function imported from
        # `allennlp.nn.util`, not to this method -- class attributes are not in a method's scope.
        # Pass `strict=False` so the utility cannot enforce strictness before
        # `_post_load_state_dict` adjusts the key lists; strictness is applied below instead.
        missing_keys, unexpected_keys = load_state_dict_distributed(self, state_dict, strict=False)
        missing_keys, unexpected_keys = self._post_load_state_dict(missing_keys, unexpected_keys)
        # Apply the caller's strictness only to the keys that survived the hook.
        _check_incompatible_keys(self, missing_keys, unexpected_keys, strict)
        return _IncompatibleKeys(missing_keys, unexpected_keys)