-
Notifications
You must be signed in to change notification settings - Fork 6
/
stopping.py
601 lines (504 loc) · 21.4 KB
/
stopping.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
"""
Module containing the classes related to the stopping algorithm
In this module there are four Classes:
- FitState: this class contains the information of the fit
for a given point in history
- FitHistory: this class contains the information necessary
in order to reset the state of the fit to the point
in which the history was saved.
i.e., a list of FitStates
- Stopping: this class monitors the chi2 of the validation
and training sets and decides when to stop
- Positivity: Decides whether a given point fulfills the positivity conditions
- Validation: Controls the NNPDF cross-validation algorithm
Note:
There are situations in which the validation set is empty, in those cases
the training set is used as validation set.
This implies several changes in the behaviour of this class as the training chi2 will
now be monitored for stability.
In order to parse the set of loss functions coming from the backend::MetaModel,
the function `parse_losses` relies on the fact that they are all suffixed with `_loss`
the validation case, instead, is suffixed with `val_loss`. In the particular case in
which both training and validation model correspond to the same backend::MetaModel only
the `_loss` suffix can be found. This is taken into account by the class `Stopping`
which will tell `Validation` that no validation set was found and that the training is to
be used instead.
"""
import logging
import numpy as np
# Module-level logger following the standard n3fit convention
log = logging.getLogger(__name__)

# Put a very big number here so that we for sure discard this run
# AND we have a clear marker that something went wrong, not just a bad fit
TERRIBLE_CHI2 = 1e10
# Starting value for the best validation chi2; any reasonable fit should beat it
INITIAL_CHI2 = 1e9

# Pass/veto keys reported per replica by the positivity check
POS_OK = "POS_PASS"
POS_BAD = "POS_VETO"
# Default maximum value allowed for the sum of all positivity losses
THRESHOLD_POS = 1e-6
def parse_ndata(all_data):
    """
    Parses the list of dictionaries received from ModelTrainer
    into a dictionary containing only the name of the experiments
    together with the number of points.

    Returns
    -------
    `tr_ndata`
        dictionary of {'exp' : ndata}
    `vl_ndata`
        dictionary of {'exp' : ndata}
    `pos_set`: list of the names of the positivity sets

    Note: if there is no validation (total number of val points == 0)
    then vl_ndata will point to tr_ndata
    """
    tr_points = {}
    vl_points = {}
    positivity_names = []
    for exp_dict in all_data:
        name = exp_dict["name"]
        # Only experiments flagged with ``count_chi2`` contribute to the chi2
        if exp_dict.get("count_chi2"):
            if ntr := exp_dict["ndata"]:
                tr_points[name] = ntr
            if nvl := exp_dict["ndata_vl"]:
                vl_points[name] = nvl
        # Integrability sets are excluded from the positivity check
        if exp_dict.get("positivity") and not exp_dict.get("integrability"):
            positivity_names.append(name)
    # An empty validation dictionary is signalled to the caller as None
    return tr_points, vl_points or None, positivity_names
def parse_losses(history_object, data, suffix="loss"):
    """
    Receives an object containing the chi2.
    Usually a history object, but it can come in the form of a dictionary.

    It loops over the dictionary and uses the npoints_data dictionary to
    normalize the chi2 and returns a tuple (`total`, `dict_chi2`)

    Parameters
    ----------
    history_object: dict
        A history object dictionary
    data: dict
        dictionary with the name of the experiments to be taken into account
        and the number of datapoints of the experiments
    suffix: str (default: ``loss``)
        suffix of the loss layer, Keras default is _loss

    Returns
    -------
    total_loss: float
        Total value for the loss
    dict_chi2: dict
        dictionary of {'expname' : loss }
    """
    # Accept either a Keras-like object carrying a ``.history`` attribute
    # or a bare dictionary of losses
    loss_dict = getattr(history_object, "history", history_object)
    # In the general case epochs = 1; with more epochs the whole per-epoch
    # array is kept, and this value is only used for printing purposes
    per_exp_chi2 = {}
    npoints_total = 0
    loss_total = 0
    for name, npoints in data.items():
        exp_loss = np.array(loss_dict[f"{name}_{suffix}"])
        per_exp_chi2[name] = exp_loss / npoints
        npoints_total += npoints
        loss_total += exp_loss
    # The total is restricted to the experimental losses: the "loss" entry of
    # the history object would also include positivity sets and regularizers
    loss_total = loss_total / npoints_total
    per_exp_chi2["total"] = loss_total
    return loss_total, per_exp_chi2
class FitState:
    """
    Holds the state of the chi2 during the fit, for all replicas and one epoch

    Note: the training chi2 is computed before the update of the weights
    so it is the chi2 that informed the update corresponding to this state.
    The validation chi2 instead is computed after the update of the weights.

    Parameters
    ----------
    training_info: dict
        all losses for the training model
    validation_info: dict
        all losses for the validation model
    training_loss: float
        total training loss, this can be given if per-exp ``training_info``
        is not available
    """

    # Shared by all instances; filled in by ``FitHistory`` before the first
    # FitState is ever created
    vl_ndata = None
    tr_ndata = None
    vl_suffix = None

    def __init__(self, training_info, validation_info, training_loss=None):
        if self.vl_ndata is None or self.tr_ndata is None or self.vl_suffix is None:
            raise ValueError(
                "FitState cannot be instantiated until vl_ndata, tr_ndata and vl_suffix are filled"
            )
        self._training = training_info
        self.validation = validation_info
        self._parsed = False
        self._vl_chi2 = None  # per-replica validation chi2
        self._tr_chi2 = None  # overall training chi2
        self._vl_dict = None
        self._tr_dict = None
        # Fallback total loss, used when ``training_info`` is not given
        self._training_loss = training_loss

    @property
    def vl_loss(self):
        """Return the total validation loss as it comes from the info dictionaries"""
        return self.validation.get("loss")

    @property
    def tr_loss(self):
        """Return the total training loss as it comes from the info dictionaries"""
        if self._training is None:
            return self._training_loss
        return self._training.get("loss")

    def _parse_chi2(self):
        """
        Parses the chi2 from the losses according to the `tr_ndata` and
        `vl_ndata` dictionaries of {dataset: n_points}.
        The result is cached so the losses are only parsed once per instance.
        """
        if self._parsed:
            return
        if self._training is not None:
            self._tr_chi2, self._tr_dict = parse_losses(self._training, self.tr_ndata)
        if self.validation is not None:
            self._vl_chi2, self._vl_dict = parse_losses(
                self.validation, self.vl_ndata, suffix=self.vl_suffix
            )
        # Bugfix: the flag was never set, so every property access re-parsed
        # the losses; mark as parsed so the cached values above are reused
        self._parsed = True

    @property
    def tr_chi2(self):
        """Overall training chi2"""
        self._parse_chi2()
        return self._tr_chi2

    @property
    def vl_chi2(self):
        """Per-replica validation chi2"""
        self._parse_chi2()
        return self._vl_chi2

    @property
    def all_tr_chi2(self):
        """Dictionary of training chi2 per experiment"""
        self._parse_chi2()
        return self._tr_dict

    @property
    def all_vl_chi2(self):
        """Dictionary of validation chi2 per experiment"""
        self._parse_chi2()
        return self._vl_dict

    def all_tr_chi2_for_replica(self, i_replica):
        """Return the tr chi2 per dataset for a given replica"""
        return {k: np.take(v, i_replica) for k, v in self.all_tr_chi2.items()}

    def all_vl_chi2_for_replica(self, i_replica):
        """Return the vl chi2 per dataset for a given replica"""
        return {k: np.take(v, i_replica) for k, v in self.all_vl_chi2.items()}

    def total_partial_tr_chi2(self):
        """Return the tr chi2 summed over replicas per experiment"""
        return {k: np.sum(v) for k, v in self.all_tr_chi2.items()}

    def total_partial_vl_chi2(self):
        """Return the vl chi2 summed over replicas per experiment"""
        return {k: np.sum(v) for k, v in self.all_vl_chi2.items()}

    def total_tr_chi2(self):
        """Return the total tr chi2 summed over replicas"""
        return np.sum(self.tr_chi2)

    def total_vl_chi2(self):
        """Return the total vl chi2 summed over replicas"""
        return np.sum(self.vl_chi2)

    def __str__(self):
        return f"chi2: tr={self.tr_chi2} vl={self.vl_chi2}"
class FitHistory:
    """
    Keeps a list of FitState items holding the full chi2 history of the fit.

    Parameters
    ----------
    tr_ndata: dict
        dictionary of {dataset: n_points} for the training data
    vl_ndata: dict
        dictionary of {dataset: n_points} for the validation data
    """

    def __init__(self, tr_ndata, vl_ndata):
        # With no validation set, the training set doubles as validation and
        # its losses carry the plain "loss" suffix instead of "val_loss"
        if vl_ndata is None:
            vl_ndata, vl_suffix = tr_ndata, "loss"
        else:
            vl_suffix = "val_loss"
        # Share the per-dataset point counts with every FitState instance
        FitState.tr_ndata = tr_ndata
        FitState.vl_ndata = vl_ndata
        FitState.vl_suffix = vl_suffix
        # One FitState per registered epoch, in order
        self._history = []
        self.final_epoch = None

    def get_state(self, epoch):
        """Return the FitState of the system for a given epoch"""
        try:
            return self._history[epoch]
        except IndexError as e:
            raise ValueError(
                f"Tried to get obtain the state for epoch {epoch} when only {len(self._history)} epochs have been saved"
            ) from e

    def register(self, epoch, fitstate):
        """Store ``fitstate`` for ``epoch`` and mark that epoch as the
        final one should the fit end now
        """
        self.final_epoch = epoch
        self._history.append(fitstate)
class Stopping:
    """
    Driver of the stopping algorithm

    Note, if the total number of points in the validation dictionary is None, it is assumed
    the validation_model actually corresponds to the training model.

    Parameters
    ----------
    validation_model: n3fit.backends.MetaModel
        the model with the validation mask applied
        (and compiled with the validation data and covmat)
    all_data_dicts: dict
        list containing all dictionaries containing all information about
        the experiments/validation/regularizers/etc to be parsed by Stopping
    pdf_model: n3fit.backends.MetaModel
        pdf_model being trained
    threshold_positivity: float
        maximum value allowed for the sum of all positivity losses
    total_epochs: int
        total number of epochs
    stopping_patience: int
        how many epochs to wait for the validation loss to improve
    threshold_chi2: float
        maximum value allowed for chi2
    dont_stop: bool
        dont care about early stopping
    """

    def __init__(
        self,
        validation_model,
        all_data_dicts,
        pdf_model,
        threshold_positivity=THRESHOLD_POS,
        total_epochs=0,
        stopping_patience=7000,
        threshold_chi2=10.0,
        dont_stop=False,
    ):
        self._pdf_model = pdf_model
        # Save the validation object
        self._validation = validation_model
        # Create the History object
        tr_ndata, vl_ndata, pos_sets = parse_ndata(all_data_dicts)
        self._history = FitHistory(tr_ndata, vl_ndata)
        # And the positivity checker
        self._positivity = Positivity(threshold_positivity, pos_sets)
        # Initialize internal variables for the stopping
        self._n_replicas = pdf_model.num_replicas
        self._threshold_chi2 = threshold_chi2
        # Per-replica count of epochs since the last improvement
        self._stopping_degrees = np.zeros(self._n_replicas, dtype=int)
        # Per-replica 0/1 flag: set to 1 once the replica has gone below the
        # chi2 threshold and the patience countdown has started
        self._counts = np.zeros(self._n_replicas, dtype=int)
        self._dont_stop = dont_stop
        self._stop_now = False
        self.stopping_patience = stopping_patience
        self.total_epochs = total_epochs
        # Per-replica bookkeeping of the best state seen so far
        self._stop_epochs = [total_epochs - 1] * self._n_replicas
        self._best_epochs = [None] * self._n_replicas
        self.positivity_statuses = [POS_BAD] * self._n_replicas
        self._best_weights = [None] * self._n_replicas
        self._best_val_chi2s = [INITIAL_CHI2] * self._n_replicas

    @property
    def vl_chi2(self):
        """Current validation chi2"""
        validation_info = self._validation.compute_losses()
        fitstate = FitState(None, validation_info)
        return fitstate.vl_chi2

    @property
    def e_best_chi2(self):
        """Epoch of the best chi2, if there is no best epoch, return last"""
        best_or_last_epochs = [
            best if best is not None else last
            for best, last in zip(self._best_epochs, self._stop_epochs)
        ]
        return best_or_last_epochs

    @property
    def stop_epoch(self):
        """Epoch in which the fit is stopped (-1 if no epoch was registered)"""
        return -1 if self._history.final_epoch is None else self._history.final_epoch + 1

    @property
    def positivity_status(self):
        """Returns POS_PASS if positivity passes or veto if it doesn't
        for each replica"""
        return self.positivity_statuses

    def evaluate_training(self, training_model):
        """Given the training model, evaluates the
        model and parses the chi2 of the training datasets

        Parameters
        ----------
        training_model: n3fit.backends.MetaModel
            an object implementing the evaluate function

        Returns
        -------
        tr_chi2: float
            chi2 of the given ``training_model``
        """
        training_info = training_model.compute_losses()
        fitstate = FitState(training_info, None)
        return fitstate.tr_chi2

    def monitor_chi2(self, training_info, epoch, print_stats=False):
        """
        Function to be called at the end of every epoch.
        Stores the total chi2 of the training set as well as the
        total chi2 of the validation set.
        If the training chi2 is below a certain threshold,
        stores the state of the model which gave the minimum chi2
        as well as the epoch in which occurred
        Returns True if the run seems ok and False if a NaN is found

        Parameters
        ----------
        training_info: dict
            output of a .fit() call, dictionary of the total training loss
            (summed over replicas and experiments)
        epoch: int
            index of the epoch

        Returns
        -------
        pass_ok: bool
            true/false according to the status of the run
        """
        # Step 1. Check whether the fit has NaN'd and stop it if so
        if np.isnan(training_loss := training_info["loss"]):
            log.warning(" > NaN found, stopping activated")
            self.make_stop()
            return False
        # Step 2. Compute the validation metrics
        validation_info = self._validation.compute_losses()
        # Step 3. Register the current point in (the) history
        # and set the current final epoch as the current one
        fitstate = FitState(None, validation_info, training_loss)
        self._history.register(epoch, fitstate)
        if print_stats:
            self.print_current_stats(epoch, fitstate)
        # Step 4. Check whether this is a better fit
        # this means improving vl_chi2 and passing positivity
        # Don't start counting until the chi2 of the validation goes below a certain threshold
        # once we start counting, don't bother anymore
        # (bitwise |/& act elementwise on the per-replica boolean arrays)
        passes = self._counts | (fitstate.vl_chi2 < self._threshold_chi2)
        passes &= fitstate.vl_loss < self._best_val_chi2s
        # And the ones that pass positivity
        passes &= self._positivity(fitstate)
        self._stopping_degrees += self._counts
        # Step 5. loop over the valid indices to check whether the vl improved
        for i_replica in np.where(passes)[0]:
            self._best_epochs[i_replica] = epoch
            # By definition, if we have a ``best_epoch`` then positivity passed
            self.positivity_statuses[i_replica] = POS_OK
            self._best_val_chi2s[i_replica] = self._history.get_state(epoch).vl_loss[i_replica]
            self._best_weights[i_replica] = self._pdf_model.get_replica_weights(i_replica)
            self._stopping_degrees[i_replica] = 0
            self._counts[i_replica] = 1
        # Replicas that were counting and ran out of patience stop here
        stop_replicas = self._counts & (self._stopping_degrees > self.stopping_patience)
        for i_replica in np.where(stop_replicas)[0]:
            self._stop_epochs[i_replica] = epoch
            self._counts[i_replica] = 0
        # By using the stopping degree we only stop when none of the replicas are improving anymore
        if min(self._stopping_degrees) > self.stopping_patience:
            self.make_stop()
        return True

    def make_stop(self):
        """Convenience method to set the stop_now flag
        and reload the history to the point of the best model if any
        """
        self._stop_now = True
        self._restore_best_weights()

    def _restore_best_weights(self):
        # Put back, per replica, the best weights seen during the fit
        # (replicas that never improved keep their current weights)
        for i_replica, weights in enumerate(self._best_weights):
            if weights is not None:
                self._pdf_model.set_replica_weights(weights, i_replica)

    def print_current_stats(self, epoch, fitstate):
        """
        Prints ``fitstate`` validation chi2 for every experiment
        and the current total training loss as well as the validation loss
        after the training step
        """
        epoch_index = epoch + 1
        vl_chi2 = fitstate.total_vl_chi2()
        total_str = f"Epoch {epoch_index}/{self.total_epochs}: loss: {fitstate.tr_loss:.7f}"
        total_str += f"\nValidation loss after training step: {vl_chi2:.7f}."
        # The partial chi2 makes no sense for more than one replica at once:
        if self._n_replicas == 1:
            total_str += "\nValidation chi2s: "
            partial_vl_chi2 = fitstate.total_partial_vl_chi2()
            partials = []
            for experiment, chi2 in partial_vl_chi2.items():
                partials.append(f"{experiment}: {chi2:.3f}")
            total_str += ", ".join(partials)
        log.info(total_str)

    def stop_here(self):
        """Returns the stopping status
        If `dont_stop` is set returns always False (i.e., never stop)
        """
        if self._dont_stop:
            return False
        else:
            return self._stop_now

    def chi2exps_json(self, i_replica=0, log_each=100):
        """
        Returns an apt-for-json dictionary with the status of the fit every `log_each` epochs
        It reports the total training loss and the validation loss broken down by experiment.

        Parameters
        ----------
        i_replica: int
            which replica are we writing the log for
        log_each: int
            every how many epochs to print the log

        Returns
        -------
        json_dict: dict
            a dictionary {epoch: {training_loss, validation_loss, validation_chi2s}}
            to be serialized as `chi2exps.log`
        """
        final_epoch = self._history.final_epoch
        json_dict = {}
        for epoch in range(log_each - 1, final_epoch + 1, log_each):
            fitstate = self._history.get_state(epoch)
            # Get the training and validation losses
            tmp = {"training_loss": fitstate.tr_loss, "validation_loss": fitstate.vl_loss.tolist()}
            # And the validation chi2 broken down by experiment
            tmp["validation_chi2s"] = fitstate.all_vl_chi2_for_replica(i_replica)
            json_dict[epoch + 1] = tmp
        return json_dict
class Positivity:
    """
    Controls the positivity requirements.

    In order to check the positivity passes will check the history of the fitting
    as the fitting included positivity sets.
    If the sum of all positivity sets losses is above a certain value the model is
    not accepted and the training continues.

    Parameters
    ----------
    threshold_positivity: float
        maximum value allowed for the sum of all positivity losses
    positivity_sets: list
        list of positivity datasets
    """

    def __init__(self, threshold, positivity_sets):
        self.threshold = threshold
        self.positivity_sets = positivity_sets

    def check_positivity(self, history_object):
        """
        Loops over the positivity_sets and checks whether the loss of
        every one of them is below the threshold, i.e.,
        history_object[key_loss] < self.threshold

        Returns an array of booleans which are True where positivity passed.

        Parameters
        ----------
        history_object: dict
            dictionary of entries in the form {'name': loss}, output of a MetaModel .fit()
        """
        all_below_threshold = True
        for pos_name in self.positivity_sets:
            all_below_threshold &= history_object[f"{pos_name}_loss"] < self.threshold
        return np.array(all_below_threshold)

    def __call__(self, fitstate):
        """
        Checks whether a given FitState object
        passes the positivity requirement
        """
        return self.check_positivity(fitstate.validation)