-
Notifications
You must be signed in to change notification settings - Fork 86
/
utils.py
executable file
·105 lines (87 loc) · 3.67 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Utils for parsing PBA augmentation schedules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import ast
import collections
import tensorflow as tf
import json
# Record of a single PBT exploit/clone event as logged by search.py: which
# trial was overwritten (target), which trial it was cloned from, the epoch
# count of each at clone time, and the config before/after the clone.
PbtUpdate = collections.namedtuple('PbtUpdate', [
    'target_trial_name', 'clone_trial_name', 'target_trial_epochs',
    'clone_trial_epochs', 'old_config', 'new_config'
])


def parse_log(file_path, epochs):
    """Parses augmentation policy schedule from log file.

    Args:
        file_path: Path to policy generated by running search.py.
        epochs: The number of epochs search was run for.

    Returns:
        A list of [num_epochs, policy] pairs whose num_epochs values sum to
        `epochs`: the initial policy run for the first clone's epoch count,
        then each cloned config run until the next clone event.

    Raises:
        ValueError: if a log line is neither valid JSON nor a Python literal.
    """
    # Fix: close the file deterministically (original leaked the handle).
    with open(file_path, "r") as f:
        raw_lines = f.readlines()
    raw_policy = []
    for line in raw_lines:
        try:
            parsed_line = json.loads(line)
        except ValueError:
            # Fix: narrowed from a bare `except:`. Older logs are Python
            # reprs (single quotes), which json rejects but literal_eval
            # accepts; json.JSONDecodeError subclasses ValueError.
            parsed_line = ast.literal_eval(line)
        raw_policy.append(parsed_line)
    # Deprecated use case has policy as list instead of dict config.
    for r in raw_policy:
        for i in [4, 5]:
            if isinstance(r[i], list):
                r[i] = {"hp_policy": r[i]}
    raw_policy = [PbtUpdate(*r) for r in raw_policy]
    policy = []
    # Sometimes files have extra lines in the beginning: keep only the final
    # suffix in which each line's target trial is the next line's clone source.
    to_truncate = None
    for i in range(len(raw_policy) - 1):
        if raw_policy[i][0] != raw_policy[i + 1][1]:
            to_truncate = i
    if to_truncate is not None:
        raw_policy = raw_policy[to_truncate + 1:]
    # Initial policy for trial_to_clone_epochs.
    policy.append([raw_policy[0][3], raw_policy[0][4]["hp_policy"]])
    current = raw_policy[0][3]
    for i in range(len(raw_policy) - 1):
        # End at next line's trial epoch, start from this clone epoch.
        this_iter = raw_policy[i + 1][3] - raw_policy[i][3]
        assert this_iter >= 0, (i, raw_policy[i + 1][3], raw_policy[i][3])
        assert raw_policy[i][0] == raw_policy[i + 1][1], (i, raw_policy[i][0],
                                                          raw_policy[i + 1][1])
        policy.append([this_iter, raw_policy[i][5]["hp_policy"]])
        current += this_iter
    # Last cloned trial policy is run for (end - clone iter of last logged line)
    policy.append([epochs - raw_policy[-1][3], raw_policy[-1][5]["hp_policy"]])
    current += epochs - raw_policy[-1][3]
    assert epochs == sum([p[0] for p in policy])
    return policy
def parse_log_schedule(file_path, epochs, multiplier=1):
    """Expands a parsed policy schedule into one policy per epoch.

    Args:
        file_path: Path to policy generated by running search.py.
        epochs: The number of epochs search was run for.
        multiplier: Multiplier on number of epochs for each policy in the
            schedule.

    Returns:
        List of length epochs * multiplier, where index i contains the policy
        to use at epoch i.
    """
    parsed_policy = parse_log(file_path, epochs)
    schedule = []
    for num_iters, pol in parsed_policy:
        tf.logging.debug("iters {} by multiplier {} result: {}".format(
            num_iters, multiplier, num_iters * multiplier))
        # Repeat this policy for its (scaled) share of epochs.
        schedule.extend([pol] * int(num_iters * multiplier))
    emitted = len(schedule)
    shortfall = int(epochs * multiplier) - emitted
    if shortfall > 0:
        # Rounding in the per-policy scaling can leave the schedule short;
        # pad to full length with the last policy.
        tf.logging.info("len: {}, remaining: {}".format(
            emitted, epochs * multiplier))
        schedule.extend([parsed_policy[-1][1]] * shortfall)
    tf.logging.info("final len {}".format(len(schedule)))
    return schedule
if __name__ == "__main__":
    # Smoke-check: parse a bundled schedule and dump each entry.
    for entry in parse_log('schedules/rsvhn_16_wrn.txt', 160):
        print(entry)