# Utilities for Optimization

> Optimization utilities: early-stopping helpers (`EarlyStopper`, `EarlyStopping`) and a `ReduceLROnPlateau` learning-rate scheduler.

In [None]:
#| default_exp optimizer.utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
from functools import partial
from torch.optim import Optimizer

In [None]:
#| export
# Source - https://stackoverflow.com/a
# Posted by isle_of_gods, modified by community. See post 'Timeline' for change history
# Retrieved 2025-11-15, License - CC BY-SA 4.0

class EarlyStopper:
    """Minimal patience-based early stopping on a validation loss.

    Call `early_stop` once per validation pass. It returns True once the loss
    has exceeded the best (lowest) value seen so far by more than `min_delta`
    on `patience` calls since the last new best. Those calls need not be
    consecutive: losses inside the tolerance band neither reset nor advance
    the count.
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0                         # "worse" calls since the last new best
        self.min_validation_loss = float('inf')  # lowest loss observed so far

    def early_stop(self, validation_loss):
        """Record `validation_loss` and return True if training should stop."""
        if validation_loss < self.min_validation_loss:
            # New best: remember it and restart the patience count.
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False
        # Only a loss clearly above the best (beyond min_delta) counts as a strike.
        worse = validation_loss > (self.min_validation_loss + self.min_delta)
        if worse:
            self.counter += 1
            return self.counter >= self.patience
        return False


In [None]:
#| export
class EarlyStopping: # pylint: disable=R0902
    """Signal when a monitored metric has stopped improving.

    Feed one metric value per epoch to `step`. `stop` then reports whether more
    than `patience` consecutive epochs have passed without a significant
    improvement over the best value recorded so far.

    Args:
        mode (str): 'min' if the metric should go down, 'max' if it should
            go up. Default: 'min'.
        patience (int): non-improving epochs to tolerate. With
            `patience = 2` the first two such epochs are ignored and stopping
            is signalled on the third if there is still no improvement.
            Default: 10.
        threshold (float): how large a change must be to count as a new
            optimum, so that only significant changes matter. Default: 1e-4.
        threshold_mode (str): 'rel' or 'abs'. In 'rel' mode a value must beat
            best * (1 - threshold) in 'min' mode or best * (1 + threshold) in
            'max' mode. In 'abs' mode it must beat best - threshold ('min')
            or best + threshold ('max'). Default: 'rel'.
    """

    def __init__(self, mode='min', patience=10, threshold=1e-4, threshold_mode='rel'):
        # User-supplied configuration.
        self.patience = patience
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        # Tracking state: placeholders that _init_is_better/_reset fill in below.
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # starting "best": +inf for 'min', -inf for 'max'
        self.is_better = None
        self.last_epoch = -1
        self._init_is_better(threshold_mode=threshold_mode, mode=mode,
                             threshold=threshold)
        self._reset()


In [None]:
#| export
@patch
def _reset(self: EarlyStopping):
    """Reset `best` to the mode's worst value and clear `num_bad_epochs`."""
    # EarlyStopping has no cooldown, so unlike ReduceLROnPlateau._reset there
    # is no cooldown counter to clear here.
    self.best = self.mode_worse
    self.num_bad_epochs = 0

@patch
def step(self: EarlyStopping, metrics, epoch=None):
    """Record one epoch's metric value and update the bad-epoch count.

    Args:
        metrics: the monitored value for this epoch.
        epoch: explicit epoch index; when omitted the internal counter
            `last_epoch` is advanced by one.
    """
    self.last_epoch = self.last_epoch + 1 if epoch is None else epoch

    if not self.is_better(metrics, self.best):
        self.num_bad_epochs += 1
        return
    # Significant improvement: adopt it as the new reference value.
    self.best = metrics
    self.num_bad_epochs = 0

@patch(as_prop=True)
def stop(self: EarlyStopping):
    """Whether to stop training: True once `num_bad_epochs` exceeds `patience`."""
    # `as_prop=True` installs this as a real property on EarlyStopping.
    # Stacking `@property` above `@patch` does not work: `patch` attaches the
    # plain function, and `@property` only wraps whatever `patch` returns.
    # `es.stop` therefore evaluated to a bound method, which is always truthy.
    return self.num_bad_epochs > self.patience


@patch
def _cmp(self: EarlyStopping, mode, threshold_mode, threshold, a, best): # pylint: disable=R0913, R0201
    """Return whether `a` improves on `best` by `threshold` under `mode`."""
    if mode == 'min':
        if threshold_mode == 'rel':
            return a < best * (1. - threshold)
        if threshold_mode == 'abs':
            return a < best - threshold
    elif mode == 'max' and threshold_mode == 'rel':
        return a > best * (threshold + 1.)
    # 'max' with 'abs', and any combination not matched above.
    return a > best + threshold

@patch
def _init_is_better(self: EarlyStopping, mode, threshold, threshold_mode):
    """Validate mode settings, set `mode_worse`, and build `is_better(a, best)`.

    Raises:
        ValueError: if `mode` or `threshold_mode` is not a recognised value.
    """
    if mode not in {'min', 'max'}:
        raise ValueError('mode ' + mode + ' is unknown!')
    if threshold_mode not in {'rel', 'abs'}:
        raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

    # The starting "best": any finite value improves on it.
    self.mode_worse = float('inf') if mode == 'min' else -float('inf')
    # Freeze the configuration so callers only supply (a, best).
    self.is_better = partial(self._cmp, mode, threshold_mode, threshold)

@patch
def state_dict(self: EarlyStopping):
    """Return a copy of the instance attributes, minus the `is_better` partial."""
    # is_better is rebuilt by load_state_dict, so it is not saved.
    state = dict(self.__dict__)
    state.pop('is_better', None)
    return state

@patch
def load_state_dict(self: EarlyStopping, state_dict):
    """Restore attributes from `state_dict`, then rebuild the `is_better` comparator."""
    vars(self).update(state_dict)
    # state_dict() omits is_better, so recreate it from the restored settings.
    self._init_is_better(mode=self.mode, threshold=self.threshold,
                         threshold_mode=self.threshold_mode)



In [None]:
#| export
class ReduceLROnPlateau: # pylint: disable=R0902
    """Lower the learning rate when a monitored metric stops improving.

    Training often benefits from cutting the learning rate by a factor of
    2-10 once progress stalls. Each call to `step` feeds in one metric value.
    Once more than `patience` consecutive calls pass without improvement,
    every param group's lr is multiplied by `factor`.

    Args:
        optimizer (Optimizer): the optimizer whose param groups are adjusted.
        mode (str): 'min' (the metric should decrease) or 'max' (it should
            increase). Default: 'min'.
        factor (float): multiplier applied on reduction, new_lr = lr * factor.
            Must be < 1.0. Default: 0.1.
        patience (int): non-improving epochs tolerated before a reduction.
            With `patience = 2` the first two are ignored and the lr is cut
            after the third if there is still no improvement. Default: 10.
        verbose (bool): print a line to stdout for every reduction.
            Default: ``False``.
        threshold (float): minimum change that counts as an improvement.
            Default: 1e-4.
        threshold_mode (str): 'rel' or 'abs'. With 'rel', a value must beat
            best * (1 - threshold) in 'min' mode or best * (1 + threshold) in
            'max' mode. With 'abs', it must beat best - threshold ('min') or
            best + threshold ('max'). Default: 'rel'.
        cooldown (int): epochs to wait after a reduction before bad epochs
            are counted again. Default: 0.
        min_lr (float or list): lower bound on the lr. Pass one scalar for all
            param groups or a list/tuple with one value per group. Default: 0.
        eps (float): reductions whose size does not exceed `eps` are
            skipped. Default: 1e-8.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        ...     train(...)
        ...     val_loss = validate(...)
        ...     # call step only after validation has produced the metric
        ...     scheduler.step(val_loss)
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10, # pylint: disable=R0913
                 verbose=False, threshold=1e-4, threshold_mode='rel',
                 cooldown=0, min_lr=0, eps=1e-8):
        # Validate arguments in order: factor, optimizer type, min_lr length.
        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        n_groups = len(optimizer.param_groups)
        if not isinstance(min_lr, (list, tuple)):
            # A scalar bound applies to every param group.
            self.min_lrs = [min_lr] * n_groups
        elif len(min_lr) == n_groups:
            self.min_lrs = list(min_lr)
        else:
            raise ValueError("expected {} min_lrs, got {}".format(
                n_groups, len(min_lr)))

        # Remaining configuration and tracking state (filled in below).
        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # starting "best": +inf for 'min', -inf for 'max'
        self.is_better = None
        self.eps = eps
        self.last_epoch = -1
        self._init_is_better(threshold_mode=threshold_mode, mode=mode,
                             threshold=threshold)
        self._reset()


In [None]:
#| export
@patch
def _reset(self: ReduceLROnPlateau):
    """Restore `best` to the mode's worst value and zero both counters."""
    self.best = self.mode_worse
    # Chained assignment binds left to right: cooldown_counter, then num_bad_epochs.
    self.cooldown_counter = self.num_bad_epochs = 0

@patch
def step(self: ReduceLROnPlateau, metrics, epoch=None):
    """Feed one metric value and reduce the lr once patience is exhausted.

    Args:
        metrics: the monitored value for this epoch.
        epoch: explicit epoch index; when omitted `last_epoch` advances by one.
    """
    if epoch is None:
        epoch = self.last_epoch + 1
    self.last_epoch = epoch

    if self.is_better(metrics, self.best):
        self.best, self.num_bad_epochs = metrics, 0
    else:
        self.num_bad_epochs += 1

    if self.in_cooldown:
        # Bad epochs seen during cooldown do not count toward patience.
        self.cooldown_counter -= 1
        self.num_bad_epochs = 0

    if self.num_bad_epochs > self.patience:
        self._reduce_lr(epoch)
        # Start a fresh cooldown window and patience count after reducing.
        self.cooldown_counter = self.cooldown
        self.num_bad_epochs = 0

@patch
def _reduce_lr(self: ReduceLROnPlateau, epoch):
    """Scale each param group's lr by `factor`, floored at that group's `min_lrs` entry.

    A group is left untouched when the decrease would not exceed `eps`.
    When `verbose` is set, each applied change is printed together with `epoch`.
    """
    for i, param_group in enumerate(self.optimizer.param_groups):
        old_lr = float(param_group['lr'])
        new_lr = max(old_lr * self.factor, self.min_lrs[i])
        if old_lr - new_lr > self.eps:
            param_group['lr'] = new_lr
            if self.verbose:
                # Integer epochs keep the 5-wide layout. `{:5d}` raises
                # ValueError on floats (e.g. fractional epochs given to
                # `step`), so those are shown with two decimals instead.
                epoch_str = (f'{epoch:.2f}' if isinstance(epoch, float)
                             else f'{epoch:5d}')
                print('Epoch {}: reducing learning rate'
                      ' of group {} to {:.4e}.'.format(epoch_str, i, new_lr))

@patch(as_prop=True)
def in_cooldown(self: ReduceLROnPlateau):
    """Whether the scheduler is still inside its post-reduction cooldown window."""
    # Must be a real property because `step` tests `if self.in_cooldown:`.
    # Stacking `@property` above `@patch` left this a plain method. Its bound
    # form is always truthy, so every step decremented the cooldown and zeroed
    # num_bad_epochs, and the lr was never reduced.
    return self.cooldown_counter > 0

@patch
def _cmp(self: ReduceLROnPlateau, mode, threshold_mode, threshold, a, best): # pylint: disable=R0913,R0201
    """Return whether `a` improves on `best` by `threshold` under `mode`."""
    if mode == 'min':
        if threshold_mode == 'rel':
            return a < best * (1. - threshold)
        if threshold_mode == 'abs':
            return a < best - threshold
    elif mode == 'max' and threshold_mode == 'rel':
        return a > best * (threshold + 1.)
    # 'max' with 'abs', and any combination not matched above.
    return a > best + threshold

@patch
def _init_is_better(self: ReduceLROnPlateau, mode, threshold, threshold_mode):
    """Validate mode settings, set `mode_worse`, and build `is_better(a, best)`.

    Raises:
        ValueError: if `mode` or `threshold_mode` is not a recognised value.
    """
    if mode not in {'min', 'max'}:
        raise ValueError('mode ' + mode + ' is unknown!')
    if threshold_mode not in {'rel', 'abs'}:
        raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

    # The starting "best": any finite value improves on it.
    self.mode_worse = float('inf') if mode == 'min' else -float('inf')
    # Freeze the configuration so callers only supply (a, best).
    self.is_better = partial(self._cmp, mode, threshold_mode, threshold)

@patch
def state_dict(self: ReduceLROnPlateau):
    """Return the scheduler attributes, excluding the optimizer and `is_better`."""
    # The optimizer is owned by the caller; is_better is rebuilt on load.
    excluded = ('optimizer', 'is_better')
    return {name: val for name, val in vars(self).items() if name not in excluded}

@patch
def load_state_dict(self: ReduceLROnPlateau, state_dict):
    """Restore attributes from `state_dict`, then rebuild the `is_better` comparator."""
    vars(self).update(state_dict)
    # state_dict() omits is_better, so recreate it from the restored settings.
    self._init_is_better(mode=self.mode, threshold=self.threshold,
                         threshold_mode=self.threshold_mode)


In [None]:
#| hide
import nbdev
nbdev.nbdev_export()  # export this notebook's `#| export` cells via nbdev