utils.py
import math
from pathlib import Path

import numpy as np
import torch

import optuna
from optuna.integration.pytorch_lightning import _check_pytorch_lightning_availability
from pytorch_lightning.callbacks import EarlyStopping

from .logger import logger


def agg_dict(outputs):
    """Take the mean of each key over a list of dicts of scalar tensors."""
    keys = outputs[0].keys()
    return {
        k: torch.stack([x[k] for x in outputs if k in x]).mean().cpu().item()
        for k in keys
    }


def agg_logs(outputs):
    """Aggregate a list of output dicts into a single dict of means.

    Values may be nested one level deep (e.g. a 'log' sub dict); every
    aggregated key is prefixed with 'agg_'.

    Example:
        outputs = [
            {'val_loss': 0.7206, 'log': {'val_loss': 0.7206, 'val_loss_p': 0.7206}},
            {'val_loss': 0.7047, 'log': {'val_loss': 0.7047, 'val_loss_p': 0.7047}},
        ]
        -> {'agg_val_loss': 0.7126500010490417,
            'log': {'agg_val_loss': 0.7126500010490417,
                    'agg_val_loss_p': 0.7126500010490417}}
    """
    if isinstance(outputs, dict):
        outputs = [outputs]

    aggs = {}
    if len(outputs) > 0:
        for j in outputs[0].keys():
            if isinstance(outputs[0][j], dict):
                # Take the mean of each key in the sub dict
                keys = outputs[0][j].keys()
                aggs[j] = {
                    'agg_' + k: torch.stack([x[j][k] for x in outputs if k in x[j]])
                    .mean()
                    .cpu()
                    .item()
                    for k in keys
                }
            else:
                # Take the mean of scalar values
                aggs['agg_' + j] = (
                    torch.stack([x[j] for x in outputs if j in x]).mean().cpu().item()
                )
    return aggs
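

# A minimal usage sketch (not part of the original module). Note that agg_logs
# expects the dict values to be torch tensors, since they are combined with
# torch.stack().
def _example_agg_logs():
    outputs = [
        {"val_loss": torch.tensor(0.7), "log": {"val_loss": torch.tensor(0.7)}},
        {"val_loss": torch.tensor(0.5), "log": {"val_loss": torch.tensor(0.5)}},
    ]
    # -> {'agg_val_loss': 0.6..., 'log': {'agg_val_loss': 0.6...}}
    return agg_logs(outputs)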


def round_values(d):
    """Round float values in a (possibly nested) dict to 2 significant figures."""

    def _round(v):
        if isinstance(v, float):
            return float(f"{v:.2g}")
        elif isinstance(v, dict):
            return round_values(v)
        else:
            return v

    return {k: _round(v) for k, v in d.items()}
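

# A small usage sketch (not part of the original module) showing the rounding,
# including nested dicts; non-float values pass through unchanged.
def _example_round_values():
    d = {"lr": 0.0012345, "epochs": 10, "metrics": {"val_loss": 0.71265}}
    # -> {'lr': 0.0012, 'epochs': 10, 'metrics': {'val_loss': 0.71}}
    return round_values(d)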


def init_random_seed(seed):
    # https://pytorch.org/docs/stable/notes/randomness.html
    np.random.seed(seed)
    torch.random.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


class PyTorchLightningPruningCallback(EarlyStopping):
    """Optuna PyTorch Lightning callback to prune unpromising trials.

    Example:
        Add a pruning callback which observes validation accuracy.

        .. code::

            trainer = pytorch_lightning.Trainer(
                early_stop_callback=PyTorchLightningPruningCallback(trial, monitor='avg_val_acc'))

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g., ``val_loss`` or
            ``val_acc``. The metrics are obtained from the dictionaries returned by e.g.
            ``pytorch_lightning.LightningModule.training_step`` or
            ``pytorch_lightning.LightningModule.validation_end``, so the names depend on
            how those dictionaries are formatted.
    """

    def __init__(self, trial, monitor, **kwargs):
        # type: (optuna.trial.Trial, str) -> None
        super().__init__(monitor, **kwargs)
        _check_pytorch_lightning_availability()
        self._trial = trial
        self._monitor = monitor

    def on_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        logs = trainer.callback_metrics or {}
        current_score = logs.get(self._monitor)
        if current_score is None:
            return
        self._trial.report(current_score, step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.exceptions.TrialPruned(message)
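

# A hedged sketch (not part of the original module) of wiring this callback into an
# Optuna study. It assumes the legacy pytorch_lightning Trainer API that this
# callback targets (the ``early_stop_callback`` argument) and a hypothetical
# user-defined LightningModule named ``MyModule``:
#
#     def objective(trial):
#         model = MyModule(lr=trial.suggest_loguniform("lr", 1e-5, 1e-2))
#         trainer = pytorch_lightning.Trainer(
#             max_epochs=10,
#             early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="val_loss"),
#         )
#         trainer.fit(model)
#         return trainer.callback_metrics["val_loss"]
#
#     study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
#     study.optimize(objective, n_trials=20)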


class ObjectDict(dict):
    """
    Easy way to represent (hyper)parameters.

    https://stackoverflow.com/a/50613966/221742
    """

    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __getstate__(self):
        return self

    def __setstate__(self, state):
        self.update(state)

    def copy(self, **extra_params):
        return ObjectDict(**self, **extra_params)

    @property
    def __dict__(self):
        return dict(self)
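

# A brief usage sketch (not part of the original module): keys are readable and
# writable as attributes, and copy() returns a new ObjectDict with extra entries.
def _example_object_dict():
    hparams = ObjectDict(lr=1e-3, hidden_size_power=5)
    hparams.batch_size = 32               # attribute-style assignment
    extended = hparams.copy(dropout=0.1)  # copy, adding a new entry
    return hparams.lr, extended["dropout"]  # (0.001, 0.1)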


def hparams_power(hparams):
    """Expand hyperparameters that should be searched in powers of 2.

    Any hyperparameter whose name ends in "_power" gets a companion entry,
    without the suffix, set to 2 ** value.
    """
    hparams_old = hparams.copy()
    for k in hparams_old.keys():
        if k.endswith("_power"):
            k_new = k.replace("_power", "")
            hparams[k_new] = int(2 ** hparams[k])
    logger.debug("hparams %s", hparams)
    return hparams
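

# A short usage sketch (not part of the original module): the '*_power' entry is
# kept and a companion key with the power-of-two value is added.
def _example_hparams_power():
    hparams = ObjectDict(hidden_size_power=5, lr=1e-3)
    # -> {'hidden_size_power': 5, 'lr': 0.001, 'hidden_size': 32}
    return hparams_power(hparams)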


def log_prob_sigma(value, loc, log_scale):
    """A slightly more stable (not yet confirmed) normal log prob, taking log_scale
    (the log of the standard deviation) instead of scale.

    Modified from https://github.com/pytorch/pytorch/blob/2431eac7c011afe42d4c22b8b3f46dedae65e7c0/torch/distributions/normal.py#L65
    """
    var = torch.exp(log_scale * 2)
    return (
        -((value - loc) ** 2) / (2 * var) - log_scale - math.log(math.sqrt(2 * math.pi))
    )
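

# A quick sanity-check sketch (not part of the original module): with
# log_scale = log(sigma), this should agree with torch.distributions.Normal.
def _example_log_prob_sigma():
    value, loc, scale = torch.tensor(0.3), torch.tensor(0.0), torch.tensor(2.0)
    ours = log_prob_sigma(value, loc, scale.log())
    reference = torch.distributions.Normal(loc, scale).log_prob(value)
    return torch.allclose(ours, reference)  # expected True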


def kl_loss_var(prior_mu, log_var_prior, post_mu, log_var_post):
    """
    Analytical KL divergence between two Gaussians, KL(posterior || prior), taking
    log_variance instead of scale (variance = scale ** 2) for more stable gradients.

    For a version using scale see
    https://github.com/pytorch/pytorch/blob/master/torch/distributions/kl.py#L398
    """
    var_ratio_log = log_var_post - log_var_prior
    kl_div = 0.5 * (
        var_ratio_log.exp()
        + (post_mu - prior_mu) ** 2 / log_var_prior.exp()
        - 1.0
        - var_ratio_log
    )
    return kl_div
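

# A quick sanity-check sketch (not part of the original module): with
# log_var = 2 * log(scale), this should agree with torch.distributions.kl_divergence
# for KL(posterior || prior).
def _example_kl_loss_var():
    prior_mu, prior_scale = torch.tensor(0.0), torch.tensor(1.0)
    post_mu, post_scale = torch.tensor(0.5), torch.tensor(0.7)
    ours = kl_loss_var(prior_mu, 2 * prior_scale.log(), post_mu, 2 * post_scale.log())
    reference = torch.distributions.kl_divergence(
        torch.distributions.Normal(post_mu, post_scale),
        torch.distributions.Normal(prior_mu, prior_scale),
    )
    return torch.allclose(ours, reference)  # expected True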