# transformer_module.py

import logging
import os
import re
import warnings
from os import PathLike
from typing import TYPE_CHECKING, Optional, Dict, Union, List, Any, TypeVar, Type

import torch
import torch.distributed as dist

from allennlp.common.util import is_distributed, is_global_primary
from allennlp.nn.util import StateDictType, read_state_dict, load_state_dict_distributed
if TYPE_CHECKING:
    from transformers.configuration_utils import PretrainedConfig

logger = logging.getLogger(__name__)

_T = TypeVar("_T", bound="TransformerModule")


class TransformerModule(torch.nn.Module):
"""
Base class to help with generalized loading of pretrained weights.
Subclasses should override `_from_config()` if you want to instantiate them with
`from_pretrained_module()`.
"""
_pretrained_mapping: Dict[str, str] = {}
"""
An optional mapping for each class that determines any differences in the module
names between the class modules and the HuggingFace model's modules.
Keys correspond to HuggingFace submodule names, values correspond to submodules names of this module.
"""
_pretrained_relevant_module: Optional[Union[str, List[str]]] = None
"""
An optional string or list of strings which contains the expected name of the module
in the HuggingFace pretrained model. It can be a list to account for different names in different
models. The search is carried out in the order of the list.
"""
_pretrained_ignore: Optional[List[str]] = None
"""
An optional list of regular expressions that define which weights to ignore from a pretrained state_dict.
"""
_pretrained_allow_missing: Optional[List[str]] = None
"""
An optional list of regular expressions that specifies which weights are allowed to be missing
from a pretrained state dictionary.
"""
    @classmethod
    def _get_mapping(
        cls,
        mapping: Optional[Dict[str, str]] = None,
    ):
        """
        Returns the mapping to be used, based on the optional `mapping` overrides
        and the class's default mapping.
        """
        combined_mapping = {}
        combined_mapping.update(cls._pretrained_mapping)
        if mapping is not None:
            combined_mapping.update(mapping)
        return combined_mapping

    def _get_mapped_state_dict(
        self,
        state_dict: StateDictType,
        mapping: Optional[Dict[str, str]] = None,
    ) -> StateDictType:
        """
        Recursively map keys in a HuggingFace `state_dict` to the corresponding keys
        for this module and all submodules.
        """
        return _get_mapped_state_dict(self, state_dict, mapping=mapping)

    @classmethod
    def _get_relevant_submodule_state(
        cls,
        state_dict: StateDictType,
        relevant_module: Optional[Union[str, List[str]]] = None,
    ) -> StateDictType:
        """
        Returns the relevant part of the `state_dict`.
        """
        relevant_modules: Optional[List[str]] = None
        if relevant_module:
            relevant_modules = (
                [relevant_module] if isinstance(relevant_module, str) else relevant_module
            )
        elif isinstance(cls._pretrained_relevant_module, str):
            relevant_modules = [cls._pretrained_relevant_module]
        elif isinstance(cls._pretrained_relevant_module, list):
            relevant_modules = cls._pretrained_relevant_module

        if relevant_modules:
            found = False
            for module_name in relevant_modules:
                relevant_keys = set(
                    key for key in state_dict.keys() if key.startswith(module_name + ".")
                )
                if relevant_keys:
                    # Only keep elements of the state dict that correspond to the relevant module.
                    state_dict = {
                        key.replace(module_name + ".", "", 1): value
                        for key, value in state_dict.items()
                        if key in relevant_keys
                    }
                    found = True
                    break
            if not found:
                warnings.warn(
                    f"{relevant_modules} was not found at top level of state_dict!", UserWarning
                )
        return state_dict

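    # Hedged example of what this filtering does (the keys here are made up):
    # with `relevant_module="encoder"`, a state_dict such as
    # {"encoder.layer.0.weight": ..., "pooler.dense.weight": ...} is reduced to
    # {"layer.0.weight": ...}, stripping the "encoder." prefix and dropping the rest.
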
    @classmethod
    def _get_pretrained_state_dict(
        cls,
        model_name: str,
        weights_path: Optional[Union[str, PathLike]] = None,
        relevant_module: Optional[Union[str, List[str]]] = None,
        ignore: Optional[List[str]] = None,
    ) -> StateDictType:
        """
        Get a HuggingFace pretrained `state_dict` corresponding to this module.
        """
        if weights_path is None:
            from transformers.file_utils import WEIGHTS_NAME

            # First see if we can find the weights locally.
            if os.path.isdir(model_name):
                local_weights_path = os.path.join(model_name, WEIGHTS_NAME)
                if os.path.isfile(local_weights_path):
                    logger.info("Found weights at local path %s", local_weights_path)
                    weights_path = local_weights_path

            # If we haven't found them locally, we assume the model ID corresponds
            # to a model on the HuggingFace Hub.
            if weights_path is None:
                from allennlp.common.file_utils import cached_path

                weights_path = cached_path(f"hf://{model_name}/{WEIGHTS_NAME}")

        # Now load the state dict.
        logger.info("Reading state dict from %s", weights_path)
        state_dict = read_state_dict(
            weights_path,
            ignore=ignore if ignore is not None else cls._pretrained_ignore,
            strict=False,
        )

        # Keep just the relevant module; remove everything else.
        state_dict = cls._get_relevant_submodule_state(state_dict, relevant_module=relevant_module)
        return state_dict

    @classmethod
    def _from_config(cls: Type[_T], config: "PretrainedConfig", **kwargs) -> _T:
        """
        Instantiate this module from a HuggingFace config. Subclasses should override
        this method if they are meant to be instantiated with `from_pretrained_module()`.
        """
        raise NotImplementedError

    def _post_load_pretrained_state_dict_hook(
        self, missing_keys: List[str], unexpected_keys: List[str]
    ) -> None:
        """
        Subclasses can override this method to modify `missing_keys` or `unexpected_keys`
        after loading a pretrained state dictionary.
        """
        pass

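    # A hedged sketch (hypothetical override, not from this file) of how the hook
    # can be used, e.g. to discard an unexpected buffer the module recomputes itself:
    #
    #     def _post_load_pretrained_state_dict_hook(self, missing_keys, unexpected_keys):
    #         if "embeddings.position_ids" in unexpected_keys:
    #             unexpected_keys.remove("embeddings.position_ids")
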
    @classmethod
    def from_pretrained_module(
        cls: Type[_T],
        model_name: str,
        *,
        load_weights: bool = True,
        weights_path: Optional[Union[str, PathLike]] = None,
        auto_config_kwargs: Optional[Dict[str, Any]] = None,
        mapping: Optional[Dict[str, str]] = None,
        relevant_module: Optional[Union[str, List[str]]] = None,
        ignore: Optional[List[str]] = None,
        allow_missing: Optional[List[str]] = None,
        strict: bool = True,
        **kwargs,
    ) -> _T:
        """
        Initialize this module from a corresponding model on HuggingFace.

        !!! Note
            This method is only available for subclasses that implement `_from_config()`.
            Otherwise a `NotImplementedError` will be raised.

        # Parameters

        model_name : `str`
            The model identifier or path.
        load_weights : `bool`, optional (default = `True`)
            Whether to download and load the pretrained weights. If `False`, the
            weights are left uninitialized.
        weights_path : `Optional[Union[str, PathLike]]`, optional (default = `None`)
            When `load_weights` is `True`, this can be set to override the weights file.
            Otherwise the default weights from the pretrained model are used.
        auto_config_kwargs : `Optional[Dict[str, Any]]`, optional (default = `None`)
            Optional keyword arguments to pass to `transformers.AutoConfig.from_pretrained()`
            to load the pretrained model's configuration file.
        mapping : `Optional[Dict[str, str]]`, optional (default = `None`)
            Optional mapping that determines any differences in the submodule names
            between this module and the pretrained model from HuggingFace.
            If not given, the class's default is used: `cls._pretrained_mapping`.
        relevant_module : `Optional[Union[str, List[str]]]`, optional (default = `None`)
            An optional submodule (or list of candidate submodules) of the HuggingFace
            module to initialize weights from.
            This is only relevant when `load_weights` is `True`.
            If not given, the class's default is used: `cls._pretrained_relevant_module`.
        ignore : `Optional[List[str]]`, optional (default = `None`)
            An optional list of regular expressions that define which weights to ignore
            from a pretrained state_dict.
            This is only relevant when `load_weights` is `True`.
            If not specified, the class's default is used: `cls._pretrained_ignore`.
        allow_missing : `Optional[List[str]]`, optional (default = `None`)
            An optional list of regular expressions that specify which weights are allowed
            to be missing from the pretrained state dictionary.
            This is only relevant when `load_weights` is `True`.
            If not specified, the class's default is used: `cls._pretrained_allow_missing`.
        strict : `bool`, optional (default = `True`)
            Whether to load the `state_dict` in "strict" mode. This only applies
            when `load_weights` is `True`.
        **kwargs : `Any`
            Keyword arguments to pass to `cls._from_config()` when instantiating the module.
        """  # noqa: E501
        from transformers import AutoConfig

        config = AutoConfig.from_pretrained(model_name, **(auto_config_kwargs or {}))
        model = cls._from_config(config, **kwargs)

        if load_weights:
            state_dict: Optional[StateDictType] = None
            if is_global_primary():
                # Load the pretrained HuggingFace state_dict.
                pretrained_state_dict = cls._get_pretrained_state_dict(
                    model_name,
                    weights_path=weights_path,
                    relevant_module=relevant_module,
                    ignore=ignore,
                )
                # Now map keys from the HuggingFace state_dict to the corresponding keys from
                # this class. This is called recursively on each submodule of the current module.
                state_dict = model._get_mapped_state_dict(pretrained_state_dict, mapping=mapping)

            missing_keys: List[str]
            unexpected_keys: List[str]
            error_msgs: List[str] = []

            if not is_distributed():
                assert state_dict is not None
                logger.info("Loading state_dict into module")
                missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
            else:
                # We're in distributed training. `state_dict` is `None` for every process
                # except the global primary.
                # Synchronize here, since non-primary processes have to wait for the primary
                # to load the state_dict into memory.
                dist.barrier()
                # Now load the state dict into the model.
                logger.info("Loading state_dict into module (MEMORY_EFFICIENT strategy)")
                missing_keys, unexpected_keys = load_state_dict_distributed(
                    model, state_dict, strict=False
                )

            # Run the post-load hook.
            model._post_load_pretrained_state_dict_hook(missing_keys, unexpected_keys)

            # Exclude any keys in `missing_keys` that match the `allow_missing`
            # regular expressions.
            if allow_missing is None:
                allow_missing = cls._pretrained_allow_missing
            if allow_missing:
                missing_keys = [
                    k for k in missing_keys if not any(re.match(p, k) for p in allow_missing)
                ]

            if missing_keys:
                error_msgs.append(
                    "Missing key(s) in state_dict: {}".format(
                        ", ".join(f'"{k}"' for k in missing_keys)
                    )
                )
            if unexpected_keys:
                error_msgs.append(
                    "Unexpected key(s) in state_dict: {}".format(
                        ", ".join(f'"{k}"' for k in unexpected_keys)
                    )
                )

            if error_msgs and strict:
                raise RuntimeError(
                    "Error(s) in loading state_dict for {}:\n\t{}".format(
                        cls.__name__, "\n\t".join(error_msgs)
                    )
                )

            # If there were error messages but we're not loading in 'strict' mode,
            # just issue warnings from the logger.
            for msg in error_msgs:
                logger.warning(msg)

        return model
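
# A minimal, hedged usage sketch (hypothetical subclass; `config.vocab_size` and
# `config.hidden_size` are standard `transformers` config fields, but the class
# itself is made up for illustration). Implementing `_from_config()` is what makes
# `from_pretrained_module()` available:
#
#     class MyEmbeddings(TransformerModule):
#         _pretrained_relevant_module = ["embeddings", "bert.embeddings"]
#
#         def __init__(self, vocab_size: int, hidden_size: int) -> None:
#             super().__init__()
#             self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
#
#         @classmethod
#         def _from_config(cls, config: "PretrainedConfig", **kwargs) -> "MyEmbeddings":
#             return cls(config.vocab_size, config.hidden_size, **kwargs)
#
#     embeddings = MyEmbeddings.from_pretrained_module("bert-base-uncased")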


def _get_mapped_state_dict(
    module: torch.nn.Module,
    state_dict: StateDictType,
    mapping: Optional[Dict[str, str]] = None,
) -> StateDictType:
    # First fix all top-level keys according to `combined_mapping`.
    combined_mapping = (
        module._get_mapping(mapping) if isinstance(module, TransformerModule) else {}
    )
    for hf_key, cls_key in sorted(
        # Sort by most specific key first.
        combined_mapping.items(),
        key=lambda x: x[0].count("."),
        reverse=True,
    ):
        relevant_keys = set(
            key for key in state_dict.keys() if (key == hf_key or key.startswith(hf_key + "."))
        )
        for key in relevant_keys:
            new_key = key.replace(hf_key, cls_key, 1)
            # We have to be careful not to overwrite an entry that we might have updated
            # on a previous iteration of this loop due to having a more specific key.
            if new_key not in state_dict:
                state_dict[new_key] = state_dict.pop(key)

    # Now loop through the submodules, calling this function on each submodule.
    for name, submodule in module.named_children():
        # Pull out the part of the state_dict corresponding to just this submodule.
        relevant_keys = set(key for key in state_dict.keys() if key.startswith(name + "."))
        module_state_dict = {
            key.replace(name + ".", "", 1): state_dict.pop(key) for key in relevant_keys
        }
        # Recursively call this function from the submodule to map this part
        # of the state_dict.
        module_state_dict = _get_mapped_state_dict(submodule, module_state_dict)
        # And then update the full state_dict.
        for key, value in module_state_dict.items():
            state_dict[name + "." + key] = value

    return state_dict
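

# A self-contained, hedged demo (not part of the original module) of how
# `_get_mapped_state_dict` renames keys. `TinyBlock` and its mapping are
# hypothetical, made up purely for illustration.
if __name__ == "__main__":

    class TinyBlock(TransformerModule):
        # Pretend the HuggingFace checkpoint calls this submodule "dense"
        # while this module calls it "linear".
        _pretrained_mapping = {"dense": "linear"}

        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(2, 2)

    block = TinyBlock()
    hf_style = {"dense.weight": torch.zeros(2, 2), "dense.bias": torch.zeros(2)}
    mapped = block._get_mapped_state_dict(hf_style)
    # Keys have been renamed according to `_pretrained_mapping`.
    assert set(mapped) == {"linear.weight", "linear.bias"}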