-
Notifications
You must be signed in to change notification settings - Fork 1
/
transformer_callbacks.py
46 lines (37 loc) · 1.7 KB
/
transformer_callbacks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# transformer_callbacks.py
import tensorflow as tf
class TransformerCallbacks(tf.keras.callbacks.Callback):
    """
    Monitor validation loss during training, checkpoint the best model, and
    stop early after ``patience`` consecutive epochs without improvement.

    Args:
        config: Configuration object exposing ``checkpoint_filepath`` and
            ``patience``.

    Attributes:
        checkpoint_filepath: Filepath where the best model is saved.
        patience: Number of consecutive non-improving epochs tolerated
            before training is stopped (0/None disables early stopping).
        best_loss: Lowest validation loss observed so far.
        wait: Count of consecutive epochs without improvement.
    """
    def __init__(self, config):
        super().__init__()
        self.checkpoint_filepath = config.checkpoint_filepath
        self.patience = config.patience
        # Any real loss beats +inf, so the first epoch always checkpoints.
        self.best_loss = float('inf')
        # Track the non-improvement streak separately instead of destructively
        # decrementing self.patience, which would (a) clobber the configured
        # value and (b) count *total* rather than *consecutive* bad epochs.
        self.wait = 0
    def on_epoch_end(self, epoch, logs=None):
        """
        Check validation loss at the end of each epoch.

        Saves the model when the loss improves; otherwise counts the epoch
        against ``patience`` and stops training once it is exhausted.

        Args:
            epoch: Zero-based index of the epoch that just finished.
            logs: Dict of metrics for this epoch; expected to contain
                'val_loss' (i.e. the model was fit with validation data).
        """
        # Guard both a missing dict and a missing key: comparing None with a
        # float would raise TypeError in the original code.
        val_loss = (logs or {}).get('val_loss')
        if val_loss is None:
            return  # no validation data this epoch -- nothing to compare
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.wait = 0  # improvement resets the patience countdown
            self.model.save(self.checkpoint_filepath)
            print('The best model has been saved at epoch #{}'.format(epoch))
        else:
            self.wait += 1
            if self.patience and self.wait >= self.patience:
                self.model.stop_training = True
                print('Training stopped. No improvement after {} epochs.'.format(epoch))