/
average.py
169 lines (136 loc) · 5.73 KB
/
average.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
# is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""
Average parameters from multiple model checkpoints. Checkpoints can be either
specified manually or automatically chosen according to one of several
strategies. The default strategy of simply selecting the top-scoring N points
works well in practice.
"""
import argparse
import itertools
import os
from typing import Dict, Iterable, Tuple, List
import mxnet as mx
import sockeye.constants as C
import sockeye.utils
import sockeye.arguments
from sockeye.log import setup_main_logger, log_sockeye_version
from sockeye.utils import check_condition
# Module-level logger shared by all functions in this file; set up for console output only
# (console=True, file_logging=False).
logger = setup_main_logger(__name__, console=True, file_logging=False)
def average(param_paths: Iterable[str]) -> Dict[str, mx.nd.NDArray]:
    """
    Averages parameters from a list of .params file paths.

    Every file is loaded via sockeye.utils.load_params, which yields a pair of
    (arg_params, aux_params) dictionaries. All files must define exactly the same
    parameter names; each parameter is then averaged across files with
    sockeye.utils.average_arrays.

    :param param_paths: Paths to parameter files. Must contain at least one path.
    :return: Averaged parameter dictionary. Keys are the original parameter names
             prefixed with "arg:" or "aux:" respectively.
    """
    all_arg_params = []
    all_aux_params = []
    for path in param_paths:
        logger.info("Loading parameters from '%s'", path)
        arg_params, aux_params = sockeye.utils.load_params(path)
        all_arg_params.append(arg_params)
        all_aux_params.append(aux_params)

    logger.info("%d models loaded", len(all_arg_params))

    # Without this guard an empty input fails further down with a bare IndexError on
    # all_arg_params[0]; report the real problem instead.
    # NOTE(review): relies on check_condition raising when its condition is false, as the
    # name-consistency checks below already do -- confirm in sockeye.utils.
    check_condition(len(all_arg_params) > 0, "No parameter files given, nothing to average")
    check_condition(all(all_arg_params[0].keys() == p.keys() for p in all_arg_params),
                    "arg_param names do not match across models")
    check_condition(all(all_aux_params[0].keys() == p.keys() for p in all_aux_params),
                    "aux_param names do not match across models")

    avg_params = {}
    # arg params are handled first, then aux params; the prefix marks which kind a key belongs to.
    for prefix, param_dicts in (("arg:", all_arg_params), ("aux:", all_aux_params)):
        for name in param_dicts[0]:
            arrays = [params[name] for params in param_dicts]
            avg_params[prefix + name] = sockeye.utils.average_arrays(arrays)

    return avg_params
def find_checkpoints(model_path: str, size=4, strategy="best", maximize=False, metric: str = C.PERPLEXITY) \
        -> Iterable[str]:
    """
    Selects checkpoints from the model's metrics file according to a strategy and
    returns the paths of their parameter files.

    :param model_path: Path to model.
    :param size: Number of checkpoints to combine.
    :param strategy: Combination strategy, one of "best", "last" or "lifespan".
    :param maximize: Whether the value of the metric should be maximized.
    :param metric: Metric according to which checkpoints are selected. Corresponds to columns in model/metrics file.
    :return: List of paths corresponding to chosen checkpoints.
    :raises RuntimeError: If the strategy name is not one of the known options.
    """
    points = sockeye.utils.read_metrics_points(os.path.join(model_path, C.METRICS_NAME),
                                               model_path,
                                               metric=metric)

    strategies = (
        # N best scoring points
        ("best", _strategy_best),
        # N sequential points ending with overall best
        ("last", _strategy_last),
        # Track lifespan of every "new best" point;
        # points dominated by a previous better point have lifespan 0
        ("lifespan", _strategy_lifespan),
    )
    for strategy_name, select in strategies:
        if strategy == strategy_name:
            chosen = select(points, size, maximize)
            break
    else:
        # No strategy name matched.
        raise RuntimeError("Unknown strategy, options: best last lifespan")

    # Assemble paths for params files corresponding to chosen checkpoints.
    # The last element of a point is always the checkpoint id.
    params_paths = []
    for point in chosen:
        params_paths.append(os.path.join(model_path, C.PARAMS_NAME % point[-1]))

    # Report
    logger.info("Found: " + ", ".join(map(str, chosen)))

    return params_paths
def _strategy_best(points, size, maximize):
top_n = sorted(points, reverse=maximize)[:size]
return top_n
def _strategy_last(points, size, maximize):
best = max if maximize else min
after_top = points.index(best(points)) + 1
top_n = points[max(0, after_top - size):after_top]
return top_n
def _strategy_lifespan(points, size, maximize):
top_n = []
cur_best = points[0]
cur_lifespan = 0
for point in points[1:]:
better = point > cur_best if maximize else point < cur_best
if better:
top_n.append(list(itertools.chain([cur_lifespan], cur_best)))
cur_best = point
cur_lifespan = 0
else:
top_n.append(list(itertools.chain([0], point)))
cur_lifespan += 1
top_n.append(list(itertools.chain([cur_lifespan], cur_best)))
# Sort by lifespan, then by val
top_n = sorted(
top_n,
key=lambda point: [point[0], point[1] if maximize else -point[1]],
reverse=True)[:size]
return top_n
def main():
    """
    Commandline interface to average parameters.

    Parses the arguments registered by sockeye.arguments.add_average_args, averages the
    selected parameter files and saves the result to the requested output file.
    """
    log_sockeye_version(logger)

    parser = argparse.ArgumentParser(description="Averages parameters from multiple models.")
    sockeye.arguments.add_average_args(parser)
    args = parser.parse_args()

    # More than one input: the inputs are averaged directly as parameter files.
    # Otherwise the single input is passed to find_checkpoints as a model path and the
    # checkpoints to average are chosen from it.
    param_paths = args.inputs
    if len(args.inputs) <= 1:
        param_paths = find_checkpoints(args.inputs[0], args.n, args.strategy,
                                       args.max, args.metric)
    avg_params = average(param_paths)

    mx.nd.save(args.output, avg_params)
    logger.info("Averaged parameters written to '%s'", args.output)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()