# Transfer Learning Experiments

Created by: Jacob A Rose  
Created On: Wednesday Oct 6th, 2021  

Based on the notebook at: https://jarvislabs.ai/blogs/transfer-learning

## Imports & Definitions

In [None]:
import torch
import matplotlib.pyplot as plt
from collections import OrderedDict


# Learning-rate constants used by the scheduling experiments below.
max_lr = 0.3
base_lr = 0.1
optim_lr = 0.5


num_epochs = 40

# Toy "transfer learning" model: a larger backbone feeding a small head.
# nn.Sequential applies its children in insertion order, so the backbone
# (100 -> 2) must come before the head (2 -> 1); with the head first, any
# forward pass fails with a shape mismatch.
model = torch.nn.Sequential(OrderedDict([
    ("backbone", torch.nn.Linear(100, 2)),
    ("head", torch.nn.Linear(2, 1)),
]))

# Discriminative learning rates: the backbone (presumably the pretrained
# part) is updated 10x more slowly than the head.
optimizer = torch.optim.SGD([
    {"params": model.backbone.parameters(), "lr": optim_lr * 0.1, "weight_decay": 0.01},
    {"params": model.head.parameters(), "lr": optim_lr, "weight_decay": 0.01},
])

In [None]:
# Linear learning-rate warmup: ramp from `lr_init` up to `lr_max` over the
# epochs [T_init, T_max], then hold at `lr_max`.
# NOTE: this returns an absolute learning rate (convenient for plotting).
# LambdaLR instead multiplies each group's initial lr by whatever its
# lr_lambda returns, so convert before using it there.

mult_init = 0.05

lr = optim_lr  # peak learning rate, taken from the setup cell
lr_max = lr
lr_init = lr * mult_init
T_max = 5  # epoch at which warmup finishes
T_init = 0  # epoch at which warmup starts


def warmup_lambda(epoch):
    """Return the warmed-up learning rate for `epoch`, clamped to [lr_init, lr_max]."""
    span = T_max - T_init
    # Fraction of warmup completed, clamped so epochs before T_init give
    # lr_init and epochs after T_max give lr_max.
    progress = min(max(epoch - T_init, 0), span) / span
    return lr_init + (lr_max - lr_init) * progress


plt.plot(range(20), list(map(warmup_lambda, range(20))))



In [None]:
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset

from pytorch_lightning import LightningDataModule, LightningModule


class RandomDictDataset(Dataset):
    """Map-style dataset whose samples are dicts ``{"a": x, "b": x + 2}``.

    Args:
        size: dimensionality of each random vector ``x``.
        length: number of samples, all drawn once at construction time.
    """

    def __init__(self, size: int, length: int):
        self.len = length
        self.data = torch.randn(length, size)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        sample = self.data[index]
        return dict(a=sample, b=sample + 2)


class RandomDataset(Dataset):
    """Map-style dataset of ``length`` random vectors of dimension ``size``.

    The tensor is sampled once up front, so indexing is deterministic.
    """

    def __init__(self, size: int, length: int):
        self.len = length
        self.data = torch.randn(length, size)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        row = self.data[index]
        return row


class RandomIterableDataset(IterableDataset):
    """Iterable dataset yielding ``count`` freshly sampled random vectors per pass.

    It defines no ``__len__``, so its length is unknown to consumers.
    """

    def __init__(self, size: int, count: int):
        self.size = size
        self.count = count

    def __iter__(self):
        for _draw in range(self.count):
            sample = torch.randn(self.size)
            yield sample


class RandomIterableDatasetWithLen(IterableDataset):
    """Iterable dataset of random vectors that also reports its length.

    Each pass yields ``len(self)`` (== ``count``) newly sampled vectors.
    """

    def __init__(self, size: int, count: int):
        self.size = size
        self.count = count

    def __len__(self):
        return self.count

    def __iter__(self):
        yield from (torch.randn(self.size) for _ in range(len(self)))


class BoringModel(LightningModule):
    """Minimal LightningModule used to exercise Trainer features.

    It wraps a single ``Linear(32, 2)`` layer whose loss pushes predictions
    towards ones, and serves ``RandomDataset(32, 64)`` for every stage.

    Use as follows:

    - subclass it and override the behavior you want to change::

          class TestModel(BoringModel):
              def training_step(...):
                  ...  # do your own thing

    - or instantiate it and disable individual hooks::

          model = BoringModel()
          model.training_epoch_end = None
    """

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls.
        # `batch` is unused: the target is always a tensor of ones.
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def step(self, x):
        """Run a forward pass on ``x`` and return its MSE against a tensor of ones."""
        x = self(x)
        out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
        return out

    def training_step(self, batch, batch_idx):
        output = self(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        # Pass-through: returns the step outputs unchanged.
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        # The mean is computed and discarded; presumably this only checks that
        # every step output carries a stackable "loss" tensor.
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self(batch)
        loss = self.loss(batch, output)
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        # Result intentionally unused (see training_epoch_end).
        torch.stack([x["x"] for x in outputs]).mean()

    def test_step(self, batch, batch_idx):
        output = self(batch)
        loss = self.loss(batch, output)
        return {"y": loss}

    def test_epoch_end(self, outputs) -> None:
        # Result intentionally unused (see training_epoch_end).
        torch.stack([x["y"] for x in outputs]).mean()

    def configure_optimizers(self):
        # StepLR with step_size=1 multiplies the lr by gamma (default 0.1)
        # on every scheduler step.
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def test_dataloader(self):
        return DataLoader(RandomDataset(32, 64))

    def predict_dataloader(self):
        return DataLoader(RandomDataset(32, 64))


class BoringDataModule(LightningDataModule):
    """Data module that splits one 256-sample ``RandomDataset(32, ...)`` into four stages.

    Samples 0-63 are used for training, 64-127 for validation, 128-191 for
    testing and 192-255 for prediction.

    Args:
        data_dir: stored on the instance; nothing in this class reads it.
    """

    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.non_picklable = None
        self.checkpoint_state: Optional[str] = None
        # Filled by prepare_data(), or lazily by setup() if that never ran here.
        self.random_full = None

    def prepare_data(self):
        self.random_full = RandomDataset(32, 64 * 4)

    def _full_dataset(self):
        """Return the shared 256-sample dataset, creating it if it does not exist yet."""
        # Per the Lightning docs, prepare_data() may run on only one process,
        # and setup() can also be called directly. Without this fallback,
        # setup() would raise AttributeError on such instances.
        if self.random_full is None:
            self.random_full = RandomDataset(32, 64 * 4)
        return self.random_full

    def setup(self, stage: Optional[str] = None):
        full = self._full_dataset()

        if stage == "fit" or stage is None:
            self.random_train = Subset(full, indices=range(64))

        if stage in ("fit", "validate") or stage is None:
            self.random_val = Subset(full, indices=range(64, 64 * 2))

        if stage == "test" or stage is None:
            self.random_test = Subset(full, indices=range(64 * 2, 64 * 3))

        if stage == "predict" or stage is None:
            self.random_predict = Subset(full, indices=range(64 * 3, 64 * 4))

    def train_dataloader(self):
        return DataLoader(self.random_train)

    def val_dataloader(self):
        return DataLoader(self.random_val)

    def test_dataloader(self):
        return DataLoader(self.random_test)

    def predict_dataloader(self):
        return DataLoader(self.random_predict)


In [None]:
# List the test helpers bundled with the installed pytorch_lightning package
# (the Boring* helpers above appear to come from there). The directory is
# resolved from the package itself instead of being hard-coded to one
# machine's conda environment, and plain Python replaces the IPython `!ls`.
import os

import pytorch_lightning

_pl_tests_dir = os.path.join(os.path.dirname(pytorch_lightning.__file__), "tests")
if os.path.isdir(_pl_tests_dir):
    print("\n".join(sorted(os.listdir(_pl_tests_dir))))
else:
    print(f"No tests directory found at {_pl_tests_dir}")

In [None]:
import torch
from pytorch_lightning import Trainer

class WarmupLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linearly warm up the learning rate of every param group, then hold it.

    Constructing the scheduler sets each param group's lr to ``init_lr``.
    Each :meth:`step` call then adds ``(peak_lr - init_lr) / warmup_steps``.
    After ``warmup_steps`` calls the lr equals ``peak_lr``, and later calls
    leave the optimizer untouched.

    Args:
        optimizer (Optimizer): wrapped optimizer.
        warmup_steps (int): number of ``step()`` calls needed to go from
            ``init_lr`` to ``peak_lr``. ``0`` applies ``peak_lr`` immediately.
        peak_lr (float): learning rate at the end of warmup.
        init_lr (float): learning rate applied on construction.
        verbose (bool): print the learning rate as it changes.

    Raises:
        ValueError: if ``warmup_steps`` is negative.
    """

    def __init__(
            self,
            optimizer,
            warmup_steps=6,
            peak_lr=0.1,
            init_lr=0.001,
            verbose=True,
    ) -> None:
        if warmup_steps < 0:
            raise ValueError(f"warmup_steps must be >= 0, got {warmup_steps}")
        self.warmup_steps = warmup_steps
        self.peak_lr = peak_lr
        self.init_lr = init_lr
        self.warmup_rate = (peak_lr - init_lr) / warmup_steps if warmup_steps else 0
        # Stored under a private name because the base class assigns its own
        # `self.verbose` inside __init__ (see the copied _LRScheduler below).
        self._print_updates = verbose
        if verbose:
            print(f"self.warmup_rate={self.warmup_rate}")
        # Number of warmup updates applied so far. The base-class constructor
        # calls step() once, which applies update 0, i.e. init_lr.
        self.update_steps = 0
        self.lr = init_lr
        # Must come last: it calls self.step(), which reads the attributes above.
        super().__init__(optimizer)

    def set_lr(self, optimizer, lr):
        """Write ``lr`` into every param group of ``optimizer``."""
        for pg in optimizer.param_groups:
            pg["lr"] = lr

    def _lr_at(self, update):
        """Return the learning rate prescribed after ``update`` warmup updates."""
        if update >= self.warmup_steps:
            return self.peak_lr
        return self.init_lr + self.warmup_rate * update

    def step(self, val_loss = None):
        """Advance the warmup by one update.

        Args:
            val_loss: ignored; accepted so callers may pass a monitored metric.

        Returns:
            float: the learning rate in effect after this call.
        """
        if self.update_steps <= self.warmup_steps:
            lr = self._lr_at(self.update_steps)
            self.set_lr(self.optimizer, lr)
            self.lr = lr
            if self._print_updates:
                print(f"new_lr={self.lr}")
        elif self._print_updates:
            print(f"No update: self.update_steps:{self.update_steps}, self.warmup_steps:{self.warmup_steps}")
        # Mirror the base-class bookkeeping so get_last_lr() works.
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
        self.update_steps += 1
        return self.lr

class TestModel(BoringModel):
    """BoringModel whose SGD optimizer is driven by a WarmupLRScheduler."""

    def configure_optimizers(self):
        sgd = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        # "interval": "step" asks Lightning to step the scheduler per training
        # step rather than per epoch.
        scheduler_config = dict(scheduler=WarmupLRScheduler(sgd), interval="step")
        return [sgd], [scheduler_config]

# Fit for one short epoch (5 training batches) with the warmup scheduler.
model = TestModel()
trainer = Trainer(limit_train_batches=5, max_epochs=1)
trainer.fit(model)


## lr_scheduler.py

In [None]:
import types
import math
from torch._six import inf
from functools import wraps
import warnings
import weakref
from collections import Counter
from bisect import bisect_right

from torch.optim.optimizer import Optimizer


# Warning text emitted by `_LRScheduler.step()` when it is called with an
# explicit `epoch` argument (the deprecated closed-form path).
EPOCH_DEPRECATION_WARNING = (
    "The epoch parameter in `scheduler.step()` was not necessary and is being "
    "deprecated where possible. Please use `scheduler.step()` to step the "
    "scheduler. During the deprecation, if epoch is different from None, the "
    "closed form is used instead of the new chainable form, where available. "
    "Please open an issue if you are unable to replicate your use case: "
    "https://github.com/pytorch/pytorch/issues/new/choose."
)

class _LRScheduler(object):
    """Base class for epoch-indexed learning-rate schedulers.

    Subclasses implement :meth:`get_lr`, which returns one learning rate per
    optimizer param group. It is usually the "chainable" form, derived from
    the current ``group['lr']``. Subclasses may also implement
    ``_get_closed_form_lr``, which computes the rate directly from
    ``base_lrs`` and ``last_epoch``; :meth:`step` uses it when given an
    explicit ``epoch``.

    Args:
        optimizer (Optimizer): wrapped optimizer.
        last_epoch (int): index of the last epoch. ``-1`` starts a fresh
            schedule. Any other value resumes one and requires every param
            group to already contain ``'initial_lr'``.
        verbose (bool): if ``True``, print every learning-rate update.

    Raises:
        TypeError: if ``optimizer`` is not an :class:`Optimizer`.
        KeyError: when resuming and a param group lacks ``'initial_lr'``.
    """

    def __init__(self, optimizer, last_epoch=-1, verbose=False):

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # Initialize epoch and base learning rates.
        # Fresh start: record each group's current lr as 'initial_lr' (kept if
        # already present). Resume: 'initial_lr' must already be there.
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(method):
            # Wrap the bound `optimizer.step` so that each call increments
            # `optimizer._step_count`; step() below reads that counter.
            if getattr(method, '_with_counter', False):
                # `optimizer.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the optimizer instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._step_count = 0
        self._step_count = 0
        self.verbose = verbose

        # Take the first step immediately. This advances last_epoch by one
        # (from -1 to 0 on a fresh start) and writes the resulting learning
        # rates into the optimizer.
        self.step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self):
        """ Return last computed learning rate by current scheduler.
        """
        return self._last_lr

    def get_lr(self):
        # Compute learning rate using chainable form of the scheduler.
        # Abstract: subclasses must return one lr per optimizer param group.
        raise NotImplementedError

    def print_lr(self, is_verbose, group, lr, epoch=None):
        """Display the current learning rate.

        Args:
            is_verbose (bool): print only when this is true.
            group (int): index of the param group being updated.
            lr (float): its new learning rate.
            epoch (int, optional): included in the message when given.
        """
        if is_verbose:
            if epoch is None:
                print('Adjusting learning rate'
                      ' of group {} to {:.4e}.'.format(group, lr))
            else:
                print('Epoch {:5d}: adjusting learning rate'
                      ' of group {} to {:.4e}.'.format(epoch, group, lr))


    def step(self, epoch=None):
        """Advance the schedule and write the new learning rates into the optimizer.

        Args:
            epoch (int, optional): deprecated. When given, ``last_epoch`` is
                set to it (with a warning) and ``_get_closed_form_lr`` is used
                if the subclass defines it. Otherwise ``last_epoch`` is
                incremented and the chainable :meth:`get_lr` is used.
        """
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_with_counter"):
                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                              "initialization. Please, make sure to call `optimizer.step()` before "
                              "`lr_scheduler.step()`. See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif self.optimizer._step_count < 1:
                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                              "`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this "
                              "will result in PyTorch skipping the first value of the learning rate schedule. "
                              "See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        self._step_count += 1

        # Context manager that flags get_lr() calls made from inside step();
        # subclasses warn when get_lr() is called with the flag unset.
        class _enable_get_lr_call:

            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False

        with _enable_get_lr_call(self):
            if epoch is None:
                self.last_epoch += 1
                values = self.get_lr()
            else:
                warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
                self.last_epoch = epoch
                if hasattr(self, "_get_closed_form_lr"):
                    values = self._get_closed_form_lr()
                else:
                    values = self.get_lr()

        # Apply the new rates group by group (zip stops at the shorter list).
        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group['lr'] = lr
            self.print_lr(self.verbose, i, lr, epoch)

        # Cache what was actually written, for get_last_lr().
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]


class LambdaLR(_LRScheduler):
    """Set each param group's learning rate to its initial lr times ``lr_lambda(epoch)``.

    When ``last_epoch=-1`` the optimizer's current lr is taken as the initial lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): Callable mapping an integer epoch to a
            multiplicative factor, or a list/tuple with one such callable per
            entry of ``optimizer.param_groups``.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: if a list/tuple of lambdas does not have one entry per
            param group.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
        self.optimizer = optimizer
        n_groups = len(optimizer.param_groups)

        if isinstance(lr_lambda, (list, tuple)):
            if len(lr_lambda) != n_groups:
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    n_groups, len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        else:
            # One callable shared by every param group.
            self.lr_lambdas = [lr_lambda] * n_groups
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        """Return the scheduler state as a :class:`dict`.

        Every attribute except the optimizer is included. Under
        ``'lr_lambdas'``, each callable *object* contributes a copy of its
        ``__dict__``; plain functions and lambdas cannot be saved and
        contribute ``None``.

        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
        """
        skipped = ('optimizer', 'lr_lambdas')
        state = {}
        for key, value in self.__dict__.items():
            if key not in skipped:
                state[key] = value
        state['lr_lambdas'] = [
            None if isinstance(fn, types.FunctionType) else fn.__dict__.copy()
            for fn in self.lr_lambdas
        ]
        return state

    def load_state_dict(self, state_dict):
        """Restore scheduler state produced by :meth:`state_dict`.

        ``state_dict`` is temporarily stripped of ``'lr_lambdas'`` and then
        restored, so the caller's dict is left as it was passed in.

        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.

        Args:
            state_dict (dict): scheduler state returned by :meth:`state_dict`.
        """
        saved_lambdas = state_dict.pop('lr_lambdas')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['lr_lambdas'] = saved_lambdas

        for idx, saved in enumerate(saved_lambdas):
            if saved is None:
                continue
            self.lr_lambdas[idx].__dict__.update(saved)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")

        epoch = self.last_epoch
        return [initial * fn(epoch) for fn, initial in zip(self.lr_lambdas, self.base_lrs)]


class MultiplicativeLR(_LRScheduler):
    """On every epoch after the first, multiply each param group's current lr by ``lr_lambda(epoch)``.

    When ``last_epoch=-1`` the optimizer's current lr is taken as the initial lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): Callable mapping an integer epoch to a
            multiplicative factor, or a list/tuple with one such callable per
            entry of ``optimizer.param_groups``.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: if a list/tuple of lambdas does not have one entry per
            param group.

    Example:
        >>> lmbda = lambda epoch: 0.95
        >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
        self.optimizer = optimizer
        n_groups = len(optimizer.param_groups)

        if isinstance(lr_lambda, (list, tuple)):
            if len(lr_lambda) != n_groups:
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    n_groups, len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        else:
            # One callable shared by every param group.
            self.lr_lambdas = [lr_lambda] * n_groups
        super().__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        """Return the scheduler state as a :class:`dict`.

        Every attribute except the optimizer is included. Under
        ``'lr_lambdas'``, each callable *object* contributes a copy of its
        ``__dict__``; plain functions and lambdas cannot be saved and
        contribute ``None``.
        """
        skipped = ('optimizer', 'lr_lambdas')
        state = {}
        for key, value in self.__dict__.items():
            if key not in skipped:
                state[key] = value
        state['lr_lambdas'] = [
            None if isinstance(fn, types.FunctionType) else fn.__dict__.copy()
            for fn in self.lr_lambdas
        ]
        return state

    def load_state_dict(self, state_dict):
        """Restore scheduler state produced by :meth:`state_dict`.

        ``state_dict`` is temporarily stripped of ``'lr_lambdas'`` and then
        restored, so the caller's dict is left as it was passed in.

        Args:
            state_dict (dict): scheduler state returned by :meth:`state_dict`.
        """
        saved_lambdas = state_dict.pop('lr_lambdas')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['lr_lambdas'] = saved_lambdas

        for idx, saved in enumerate(saved_lambdas):
            if saved is None:
                continue
            self.lr_lambdas[idx].__dict__.update(saved)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        if self.last_epoch <= 0:
            # Nothing to multiply yet: keep the current rates.
            return [group['lr'] for group in groups]
        epoch = self.last_epoch
        return [group['lr'] * fn(epoch) for fn, group in zip(self.lr_lambdas, groups)]


class StepLR(_LRScheduler):
    """Multiply each param group's learning rate by ``gamma`` once every ``step_size`` epochs.

    Other changes made to the learning rate from outside this scheduler are
    compounded with the decay. When ``last_epoch=-1`` the optimizer's current
    lr is taken as the initial lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 60
        >>> # lr = 0.0005   if 60 <= epoch < 90
        >>> # ...
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        current = [group['lr'] for group in self.optimizer.param_groups]
        # Decay only on non-zero multiples of step_size.
        decay_now = self.last_epoch != 0 and self.last_epoch % self.step_size == 0
        if not decay_now:
            return current
        return [lr * self.gamma for lr in current]

    def _get_closed_form_lr(self):
        decays = self.last_epoch // self.step_size
        factor = self.gamma ** decays
        return [base_lr * factor for base_lr in self.base_lrs]


class MultiStepLR(_LRScheduler):
    """Multiply each param group's learning rate by ``gamma`` whenever a milestone epoch is reached.

    A milestone listed k times applies ``gamma ** k`` at that epoch. Other
    changes made to the learning rate from outside this scheduler are
    compounded with the decay. When ``last_epoch=-1`` the optimizer's current
    lr is taken as the initial lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 80
        >>> # lr = 0.0005   if epoch >= 80
        >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):
        # Counter keeps how many times each milestone was listed.
        self.milestones = Counter(milestones)
        self.gamma = gamma
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        current = [group['lr'] for group in self.optimizer.param_groups]
        if self.last_epoch in self.milestones:
            factor = self.gamma ** self.milestones[self.last_epoch]
            return [lr * factor for lr in current]
        return current

    def _get_closed_form_lr(self):
        ordered = sorted(self.milestones.elements())
        # Number of milestones (with multiplicity) at or before last_epoch.
        factor = self.gamma ** bisect_right(ordered, self.last_epoch)
        return [base_lr * factor for base_lr in self.base_lrs]


class ConstantLR(_LRScheduler):
    """Scale each param group's learning rate by ``factor`` until epoch ``total_iters``, then undo it.

    Other changes made to the learning rate from outside this scheduler are
    compounded with this scaling. When ``last_epoch=-1`` the optimizer's
    current lr is taken as the initial lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        factor (float): The number we multiply learning rate until the milestone. Default: 1./3.
        total_iters (int): The number of steps that the scheduler decays the learning rate.
            Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: if ``factor`` is outside [0, 1].

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025   if epoch == 0
        >>> # lr = 0.025   if epoch == 1
        >>> # lr = 0.025   if epoch == 2
        >>> # lr = 0.025   if epoch == 3
        >>> # lr = 0.05    if epoch >= 4
        >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False):
        if factor > 1.0 or factor < 0:
            raise ValueError('Constant multiplicative factor expected to be between 0 and 1.')

        self.total_iters = total_iters
        self.factor = factor
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        epoch = self.last_epoch
        current = [group['lr'] for group in self.optimizer.param_groups]
        if epoch == 0:
            # Entering the constant phase: scale down.
            return [lr * self.factor for lr in current]
        if epoch > self.total_iters or epoch != self.total_iters:
            # Inside or after the constant phase: keep the current rate.
            return current
        # epoch == total_iters: undo the initial scaling.
        return [lr * (1.0 / self.factor) for lr in current]

    def _get_closed_form_lr(self):
        reached = self.last_epoch >= self.total_iters
        multiplier = self.factor + reached * (1 - self.factor)
        return [base_lr * multiplier for base_lr in self.base_lrs]


class LinearLR(_LRScheduler):
    """Decays the learning rate of each parameter group by linearly changing small
    multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters.
    Notice that such decay can happen simultaneously with other changes to the learning rate
    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        start_factor (float): The number we multiply learning rate in the first epoch.
            The multiplication factor changes towards end_factor in the following epochs.
            Must be in (0, 1]. Default: 1./3.
        end_factor (float): The number we multiply learning rate at the end of linear changing
            process. Must be in [0, 1]. Default: 1.0.
        total_iters (int): The number of iterations after which the multiplicative factor
            reaches end_factor. Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: if ``start_factor`` is not in (0, 1] or ``end_factor`` is
            not in [0, 1]. A zero ``start_factor`` is rejected because the
            chainable update in :meth:`get_lr` would divide by zero.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025    if epoch == 0
        >>> # lr = 0.03125  if epoch == 1
        >>> # lr = 0.0375   if epoch == 2
        >>> # lr = 0.04375  if epoch == 3
        >>> # lr = 0.05     if epoch >= 4
        >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
                 verbose=False):
        if start_factor > 1.0 or start_factor <= 0:
            raise ValueError('Starting multiplicative factor expected to be greater than 0 and less or equal to 1.')

        if end_factor > 1.0 or end_factor < 0:
            raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')

        self.start_factor = start_factor
        self.end_factor = end_factor
        self.total_iters = total_iters
        super(LinearLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'] * self.start_factor for group in self.optimizer.param_groups]

        if (self.last_epoch > self.total_iters):
            return [group['lr'] for group in self.optimizer.param_groups]

        # Chainable form. With f(e) = start + (end - start) * e / total_iters,
        # the update multiplies the current lr by f(e) / f(e - 1), which
        # simplifies to the expression below. Its denominator equals
        # total_iters * f(e - 1); it is positive because start_factor > 0 and
        # end_factor >= 0.
        return [group['lr'] * (1. + (self.end_factor - self.start_factor) /
                (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor)))
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        return [base_lr * (self.start_factor +
                (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters)
                for base_lr in self.base_lrs]


class ExponentialLR(_LRScheduler):
    """Exponential decay: every epoch multiplies each group's lr by ``gamma``.

    The decay is applied to the optimizer's current learning rates. With
    ``last_epoch=-1`` the optimizer's lr is taken as the initial lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):
        self.gamma = gamma
        super(ExponentialLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        if self.last_epoch != 0:
            # Apply one more factor of gamma on top of the current lr.
            return [group['lr'] * self.gamma for group in groups]
        # Epoch 0 keeps the starting lr unchanged.
        return [group['lr'] for group in groups]

    def _get_closed_form_lr(self):
        # lr_t = base_lr * gamma ** t
        decay = self.gamma ** self.last_epoch
        return [initial_lr * decay for initial_lr in self.base_lrs]


class SequentialLR(_LRScheduler):
    """Run a list of schedulers one after another, switching at fixed milestones.

    ``schedulers[0]`` drives the learning rate until ``milestones[0]``,
    ``schedulers[1]`` until ``milestones[1]``, and so on. When a scheduler
    takes over at a milestone it is restarted with ``step(0)``.

    Args:
        optimizer (Optimizer): Wrapped optimizer. Every scheduler in
            ``schedulers`` must wrap this same optimizer.
        schedulers (list): List of chained schedulers (at least one).
        milestones (list): Sorted list of integers that reflects milestone
            points; must contain exactly ``len(schedulers) - 1`` entries.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): Accepted for signature compatibility; not used.

    Raises:
        ValueError: If no scheduler is given, the schedulers do not all wrap
            ``optimizer``, or the number of milestones is wrong.

    Example:
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.1     if epoch == 0
        >>> # lr = 0.1     if epoch == 1
        >>> # lr = 1.0     if epoch == 2   (scheduler2 restarts from the base lr)
        >>> # lr = 0.9     if epoch == 3
        >>> # lr = 0.81    if epoch == 4
        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
        >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
        if len(schedulers) < 1:
            raise ValueError(
                "Sequential Schedulers expects at least one scheduler, but got no scheduler."
            )
        for scheduler_idx in range(1, len(schedulers)):
            if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
                raise ValueError(
                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                )
        if schedulers[0].optimizer != optimizer:
            raise ValueError(
                "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                "got schedulers wrapping a different optimizer than the one passed in"
            )
        if (len(milestones) != len(schedulers) - 1):
            raise ValueError(
                "Sequential Schedulers expects number of schedulers provided to be one more "
                "than the number of milestone points, but got number of schedulers {} and the "
                "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
            )
        self._schedulers = schedulers
        self._milestones = milestones
        self.last_epoch = last_epoch + 1
        self.optimizer = optimizer

        if last_epoch == -1:
            # Each sub-scheduler already stepped once while it was constructed,
            # so all of their epoch-0 adjustments have been compounded on the
            # optimizer's lr. Restore the initial lrs, rewind every sub-scheduler
            # by one step, and redo the first step for schedulers[0] only.
            for group in self.optimizer.param_groups:
                group['lr'] = group['initial_lr']
            for scheduler in self._schedulers:
                scheduler.last_epoch -= 1
            self._schedulers[0].step()

        # Make get_last_lr() usable before the first step().
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def step(self):
        """Advance one epoch, delegating to the scheduler active for that epoch."""
        self.last_epoch += 1
        idx = bisect_right(self._milestones, self.last_epoch)
        scheduler = self._schedulers[idx]
        if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
            # This scheduler takes over right now: restart it from epoch 0.
            scheduler.step(0)
        else:
            scheduler.step()
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
        state_dict['_schedulers'] = [s.state_dict() for s in self._schedulers]
        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        _schedulers = state_dict.pop('_schedulers')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['_schedulers'] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)


class CosineAnnealingLR(_LRScheduler):
    r"""Anneal each parameter group's learning rate along a cosine curve.

    :math:`\eta_{max}` is the initial lr and :math:`T_{cur}` counts the epochs
    since the last restart in SGDR. The update is computed recursively from
    the current learning rate:

    .. math::
        \begin{aligned}
            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
            & T_{cur} \neq (2k+1)T_{max}; \\
            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
            & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    With ``last_epoch=-1`` the optimizer's lr is used as the initial lr.
    Because the update is recursive, other code may also change the learning
    rate between steps. If this scheduler is the only thing modifying the lr,
    the recursion reduces to the closed form:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    The schedule comes from
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Only the cosine
    annealing part of SGDR is implemented here, not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False):
        self.T_max = T_max
        self.eta_min = eta_min
        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        epoch = self.last_epoch

        if epoch == 0:
            return [group['lr'] for group in groups]

        if (epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            # The previous epoch sat exactly on a cosine trough, where the
            # ratio below would have a zero denominator; climb out by a fixed step.
            rise = 1 - math.cos(math.pi / self.T_max)
            return [group['lr'] + (base_lr - self.eta_min) * rise / 2
                    for base_lr, group in zip(self.base_lrs, groups)]

        ratio = ((1 + math.cos(math.pi * epoch / self.T_max)) /
                 (1 + math.cos(math.pi * (epoch - 1) / self.T_max)))
        return [ratio * (group['lr'] - self.eta_min) + self.eta_min for group in groups]

    def _get_closed_form_lr(self):
        cosine_term = 1 + math.cos(math.pi * self.last_epoch / self.T_max)
        return [self.eta_min + (base_lr - self.eta_min) * cosine_term / 2
                for base_lr in self.base_lrs]


class ChainedScheduler(_LRScheduler):
    """Step several learning rate schedulers together with a single call.

    Every ``step()`` calls ``step()`` on each wrapped scheduler in list order,
    so their (chainable) updates compose on the shared optimizer.

    Args:
        schedulers (list): List of chained schedulers. All of them must wrap
            the same optimizer.

    Raises:
        ValueError: If the schedulers do not all wrap the same optimizer.

    Example:
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.1      if epoch == 0
        >>> # lr = 0.09     if epoch == 1
        >>> # lr = 0.81     if epoch == 2
        >>> # lr = 0.729    if epoch == 3
        >>> # lr = 0.6561   if epoch == 4  (then keeps decaying by 0.9 per epoch)
        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
        >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, schedulers):
        schedulers = list(schedulers)
        for scheduler_idx in range(1, len(schedulers)):
            if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
                raise ValueError(
                    "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                )
        self._schedulers = schedulers
        # Expose the shared optimizer and the current lrs so that the standard
        # get_last_lr() API works on the chain itself (it reads self._last_lr).
        self.optimizer = schedulers[0].optimizer if schedulers else None
        self._last_lr = self._current_lrs()

    def _current_lrs(self):
        """Return the lr of every param group of the shared optimizer ([] if there is none)."""
        if self.optimizer is None:
            return []
        return [group['lr'] for group in self.optimizer.param_groups]

    def step(self):
        """Step every wrapped scheduler once, in list order."""
        for scheduler in self._schedulers:
            scheduler.step()
        self._last_lr = self._current_lrs()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
        state_dict['_schedulers'] = [s.state_dict() for s in self._schedulers]
        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        _schedulers = state_dict.pop('_schedulers')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['_schedulers'] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)


class ReduceLROnPlateau(object):
    """Lower the learning rate once a monitored metric stops improving.

    Training often benefits from shrinking the learning rate by a factor of
    2-10 when progress stalls. Each call to :meth:`step` receives the metric.
    Once more than ``patience`` consecutive epochs pass without improvement,
    every group's learning rate is multiplied by ``factor``, bounded below by
    ``min_lr``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
                 min_lr=0, eps=1e-8, verbose=False):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__))
        self.optimizer = optimizer

        group_count = len(optimizer.param_groups)
        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != group_count:
                raise ValueError("expected {} min_lrs, got {}".format(group_count, len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * group_count

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # worst possible metric value for the chosen mode
        self.eps = eps
        self.last_epoch = 0
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()

    def _reset(self):
        """Forget the best value seen so far and clear both counters."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        """Record one epoch's metric and reduce the lr if it has plateaued.

        Args:
            metrics: The monitored value; anything ``float()`` accepts, e.g. a
                Python number or a zero-dim tensor.
            epoch (int, optional): Deprecated explicit epoch index; when omitted
                the internal counter is advanced by one.
        """
        current = float(metrics)
        if epoch is not None:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
        else:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch

        improved = self.is_better(current, self.best)
        if improved:
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            # Bad epochs observed while cooling down do not count toward patience.
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0

        patience_exhausted = self.num_bad_epochs > self.patience
        if patience_exhausted:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def _reduce_lr(self, epoch):
        """Multiply each group's lr by ``factor`` (floored at its min_lr)."""
        for idx, group in enumerate(self.optimizer.param_groups):
            current_lr = float(group['lr'])
            reduced_lr = max(current_lr * self.factor, self.min_lrs[idx])
            # Changes no larger than eps are ignored.
            if current_lr - reduced_lr > self.eps:
                group['lr'] = reduced_lr
                if self.verbose:
                    print('Epoch {:5d}: reducing learning rate'
                          ' of group {} to {:.4e}.'.format(epoch, idx, reduced_lr))

    @property
    def in_cooldown(self):
        """True while the post-reduction cooldown period is still running."""
        return self.cooldown_counter > 0

    def is_better(self, a, best):
        """Return whether metric ``a`` counts as an improvement over ``best``."""
        minimizing = self.mode == 'min'
        relative = self.threshold_mode == 'rel'

        if minimizing and relative:
            return a < best * (1. - self.threshold)
        if minimizing and self.threshold_mode == 'abs':
            return a < best - self.threshold
        if self.mode == 'max' and relative:
            return a > best * (self.threshold + 1.)
        # mode == 'max' with threshold_mode == 'abs'
        return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Validate the comparison settings and pick the matching worst value."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        self.mode_worse = inf if mode == 'min' else -inf

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode

    def state_dict(self):
        """Return every instance attribute except the optimizer."""
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Restore attributes from :meth:`state_dict` and re-validate the mode settings."""
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)


class CyclicLR(_LRScheduler):
    r"""Cycle each parameter group's learning rate between two bounds (CLR).

    The learning rate oscillates between ``base_lr`` and ``max_lr`` at a
    constant frequency, following `Cyclical Learning Rates for Training Neural
    Networks`_. The amplitude can be rescaled per iteration or per cycle.

    The policy updates the learning rate after every batch, so ``step`` should
    be called once a batch has been used for training.

    Three built-in policies from the paper are available:

    * "triangular": A basic triangular cycle without amplitude scaling.
    * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
    * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
      at each cycle iteration.

    Adapted from the github repo `bckenstler/CLR`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        base_lr (float or list): Initial learning rate which is the
            lower boundary in the cycle for each parameter group.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_lr - base_lr).
            The lr at any cycle is the sum of base_lr
            and some scaling of the amplitude; therefore
            max_lr may not actually be reached depending on
            scaling function.
        step_size_up (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        step_size_down (int): Number of training iterations in the
            decreasing half of a cycle. If step_size_down is None,
            it is set to step_size_up. Default: None
        mode (str): One of {triangular, triangular2, exp_range}.
            Values correspond to policies detailed above.
            If scale_fn is not None, this argument is ignored.
            Default: 'triangular'
        gamma (float): Constant in 'exp_range' scaling function:
            gamma**(cycle iterations)
            Default: 1.0
        scale_fn (function): Custom scaling policy defined by a single
            argument lambda function, where
            0 <= scale_fn(x) <= 1 for all x >= 0.
            If specified, then 'mode' is ignored.
            Default: None
        scale_mode (str): {'cycle', 'iterations'}.
            Defines whether scale_fn is evaluated on
            cycle number or cycle iterations (training
            iterations since start of cycle).
            Default: 'cycle'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.8
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            The momentum at any cycle is the difference of max_momentum
            and some scaling of the amplitude; therefore
            base_momentum may not actually be reached depending on
            scaling function. Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.9
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()


    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
    """

    def __init__(self,
                 optimizer,
                 base_lr,
                 max_lr,
                 step_size_up=2000,
                 step_size_down=None,
                 mode='triangular',
                 gamma=1.,
                 scale_fn=None,
                 scale_mode='cycle',
                 cycle_momentum=True,
                 base_momentum=0.8,
                 max_momentum=0.9,
                 last_epoch=-1,
                 verbose=False):

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(type(optimizer).__name__))
        self.optimizer = optimizer

        fresh_start = last_epoch == -1

        base_lrs = self._format_param('base_lr', optimizer, base_lr)
        if fresh_start:
            # Starting from scratch: put every group at the bottom of the cycle.
            for group, lr in zip(optimizer.param_groups, base_lrs):
                group['lr'] = lr

        self.max_lrs = self._format_param('max_lr', optimizer, max_lr)

        up = float(step_size_up)
        down = up if step_size_down is None else float(step_size_down)
        self.total_size = up + down
        self.step_ratio = up / self.total_size

        if mode not in ('triangular', 'triangular2', 'exp_range') and scale_fn is None:
            raise ValueError('mode is invalid and scale_fn is None')

        self.mode = mode
        self.gamma = gamma

        if scale_fn is not None:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode
        else:
            # Built-in policies fix both the scaling function and what it is fed.
            builtin_policies = {
                'triangular': (self._triangular_scale_fn, 'cycle'),
                'triangular2': (self._triangular2_scale_fn, 'cycle'),
                'exp_range': (self._exp_range_scale_fn, 'iterations'),
            }
            self.scale_fn, self.scale_mode = builtin_policies[self.mode]

        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            if 'momentum' not in optimizer.defaults:
                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')

            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            if fresh_start:
                for group, momentum in zip(optimizer.param_groups, base_momentums):
                    group['momentum'] = momentum
            self.base_momentums = [group['momentum'] for group in optimizer.param_groups]
            self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)

        super(CyclicLR, self).__init__(optimizer, last_epoch, verbose)
        # Use the requested lower bounds as base_lrs, replacing whatever the
        # base constructor stored there.
        self.base_lrs = base_lrs

    def _format_param(self, name, optimizer, param):
        """Broadcast a scalar to one value per param group, or validate a per-group list/tuple."""
        group_count = len(optimizer.param_groups)
        if not isinstance(param, (list, tuple)):
            return [param] * group_count
        if len(param) != group_count:
            raise ValueError("expected {} values for {}, got {}".format(
                group_count, name, len(param)))
        return param

    def _triangular_scale_fn(self, x):
        return 1.0

    def _triangular2_scale_fn(self, x):
        return 1 / 2. ** (x - 1)

    def _exp_range_scale_fn(self, x):
        return self.gamma ** x

    def get_lr(self):
        """Calculates the learning rate at batch index. This function treats
        `self.last_epoch` as the last batch index.

        If `self.cycle_momentum` is ``True``, this function has a side effect of
        updating the optimizer's momentum.
        """

        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        batch_idx = self.last_epoch
        cycle = math.floor(1 + batch_idx / self.total_size)
        # Position inside the current cycle, in [0, 1).
        position = 1. + batch_idx / self.total_size - cycle
        if position > self.step_ratio:
            scale_factor = (position - 1) / (self.step_ratio - 1)
        else:
            scale_factor = position / self.step_ratio

        def amplitude_scale():
            # scale_fn is fed either the cycle number or the raw batch index.
            return self.scale_fn(cycle if self.scale_mode == 'cycle' else batch_idx)

        lrs = []
        for low, high in zip(self.base_lrs, self.max_lrs):
            height = (high - low) * scale_factor
            lrs.append(low + height * amplitude_scale())

        if self.cycle_momentum:
            momentums = []
            for low, high in zip(self.base_momentums, self.max_momentums):
                height = (high - low) * scale_factor
                momentums.append(high - height * amplitude_scale())
            for group, momentum in zip(self.optimizer.param_groups, momentums):
                group['momentum'] = momentum

        return lrs


class CosineAnnealingWarmRestarts(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: If ``T_0`` is not a positive int or ``T_mult`` is not an
            int >= 1.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
        # Check the type before comparing, so that non-numeric arguments (e.g.
        # None) raise the intended ValueError rather than a TypeError from `<=`.
        if not isinstance(T_0, int) or T_0 <= 0:
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if not isinstance(T_mult, int) or T_mult < 1:
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        self.T_0 = T_0
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.T_cur = last_epoch
        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Step could be called after every batch update

        Args:
            epoch (float, optional): Absolute (possibly fractional) epoch to jump
                to. When omitted, ``T_cur`` advances by one and a restart happens
                once it reaches ``T_i``.

        Raises:
            ValueError: If ``epoch`` is negative.

        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
            >>>         scheduler.step(epoch + i / iters)

        This function can be called in an interleaved way.

        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step() # scheduler.step(27), instead of scheduler.step(20)
        """

        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            if epoch < 0:
                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
            if epoch >= self.T_0:
                if self.T_mult == 1:
                    self.T_cur = epoch % self.T_0
                else:
                    # n = number of completed restart periods before `epoch`
                    # (geometric series T_0 * (1 + T_mult + ... + T_mult**(n-1))).
                    n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult))
                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (self.T_mult - 1)
                    self.T_i = self.T_0 * self.T_mult ** (n)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)

        # Mark get_lr() as being called from inside step() (this silences its
        # warning). Use try/finally rather than a context manager whose
        # __exit__ returns a truthy value: that would silently swallow any
        # exception raised while computing or applying the new lrs.
        self._get_lr_called_within_step = True
        try:
            for i, (param_group, lr) in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                param_group['lr'] = lr
                self.print_lr(self.verbose, i, lr, epoch)
        finally:
            self._get_lr_called_within_step = False

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]


class OneCycleLR(_LRScheduler):
    r"""Sets the learning rate of each parameter group according to the
    1cycle learning rate policy. The 1cycle policy anneals the learning
    rate from an initial learning rate to some maximum learning rate and then
    from that maximum learning rate to some minimum learning rate much lower
    than the initial learning rate.
    This policy was initially described in the paper `Super-Convergence:
    Very Fast Training of Neural Networks Using Large Learning Rates`_.

    The 1cycle learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This scheduler is not chainable.

    Note also that the total number of steps in the cycle can be determined in one
    of two ways (listed in order of precedence):

    #. A value for total_steps is explicitly provided.
    #. A number of epochs (epochs) and a number of steps per epoch
       (steps_per_epoch) are provided.
       In this case, the number of total steps is inferred by
       total_steps = epochs * steps_per_epoch

    You must either provide a value for total_steps or provide a value for both
    epochs and steps_per_epoch.

    The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
    claims that "unpublished work has shown even better results by using only two phases". To
    mimic the behaviour of the original paper instead, set ``three_phase=True``.

    If ``pct_start * total_steps`` makes a phase zero steps long (for example
    ``total_steps=10, pct_start=0.1``), that phase is treated as already
    complete and its end value is used, instead of dividing by zero.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group.
        total_steps (int): The total number of steps in the cycle. Note that
            if a value is not provided here, then it must be inferred by providing
            a value for epochs and steps_per_epoch.
            Default: None
        epochs (int): The number of epochs to train for. This is used along
            with steps_per_epoch in order to infer the total number of steps in the cycle
            if a value for total_steps is not provided.
            Default: None
        steps_per_epoch (int): The number of steps per epoch to train for. This is
            used along with epochs in order to infer the total number of steps in the
            cycle if a value for total_steps is not provided.
            Default: None
        pct_start (float): The percentage of the cycle (in number of steps) spent
            increasing the learning rate.
            Default: 0.3
        anneal_strategy (str): {'cos', 'linear'}
            Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
            linear annealing.
            Default: 'cos'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.85
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.95
        div_factor (float): Determines the initial learning rate via
            initial_lr = max_lr/div_factor
            Default: 25
        final_div_factor (float): Determines the minimum learning rate via
            min_lr = initial_lr/final_div_factor
            Default: 1e4
        three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
            learning rate according to 'final_div_factor' instead of modifying the second
            phase (the first two phases will be symmetrical about the step indicated by
            'pct_start').
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        TypeError: If ``optimizer`` is not an :class:`Optimizer`.
        ValueError: If neither a positive int ``total_steps`` nor positive int
            ``epochs`` AND ``steps_per_epoch`` are provided; if ``pct_start``
            is not a float in [0, 1]; if ``anneal_strategy`` is not 'cos' or
            'linear'; if a per-group list has the wrong length; or if
            ``cycle_momentum`` is enabled for an optimizer whose defaults have
            neither 'momentum' nor 'betas'. :meth:`get_lr` also raises it when
            ``last_epoch`` exceeds ``total_steps``.

    Example:
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()


    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
        https://arxiv.org/abs/1708.07120
    """
    def __init__(self,
                 optimizer,
                 max_lr,
                 total_steps=None,
                 epochs=None,
                 steps_per_epoch=None,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 cycle_momentum=True,
                 base_momentum=0.85,
                 max_momentum=0.95,
                 div_factor=25.,
                 final_div_factor=1e4,
                 three_phase=False,
                 last_epoch=-1,
                 verbose=False):

        # Validate optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # Validate total_steps. The isinstance check runs before the `<= 0`
        # comparison so that a missing (None) or non-numeric value produces the
        # documented ValueError instead of a TypeError from the comparison.
        if total_steps is None and epochs is None and steps_per_epoch is None:
            raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")
        elif total_steps is not None:
            if not isinstance(total_steps, int) or total_steps <= 0:
                raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps))
            self.total_steps = total_steps
        else:
            if not isinstance(epochs, int) or epochs <= 0:
                raise ValueError("Expected positive integer epochs, but got {}".format(epochs))
            if not isinstance(steps_per_epoch, int) or steps_per_epoch <= 0:
                raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch))
            self.total_steps = epochs * steps_per_epoch

        # Validate pct_start before it is used to compute the phase boundaries.
        if not isinstance(pct_start, float) or not 0 <= pct_start <= 1:
            raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))

        # Each phase names the param_group keys it anneals between; 'end_step'
        # is the (possibly fractional) last step index belonging to the phase.
        if three_phase:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'max_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'base_momentum',
                },
                {
                    'end_step': float(2 * pct_start * self.total_steps) - 2,
                    'start_lr': 'max_lr',
                    'end_lr': 'initial_lr',
                    'start_momentum': 'base_momentum',
                    'end_momentum': 'max_momentum',
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'min_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'max_momentum',
                },
            ]
        else:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'max_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'base_momentum',
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': 'max_lr',
                    'end_lr': 'min_lr',
                    'start_momentum': 'base_momentum',
                    'end_momentum': 'max_momentum',
                },
            ]

        # Validate anneal_strategy
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
        elif anneal_strategy == 'cos':
            self.anneal_func = self._annealing_cos
        elif anneal_strategy == 'linear':
            self.anneal_func = self._annealing_linear

        # Initialize learning rate variables. When resuming (last_epoch != -1)
        # these keys are expected to already be present in each param group.
        max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
        if last_epoch == -1:
            for idx, group in enumerate(self.optimizer.param_groups):
                group['initial_lr'] = max_lrs[idx] / div_factor
                group['max_lr'] = max_lrs[idx]
                group['min_lr'] = group['initial_lr'] / final_div_factor

        # Initialize momentum variables
        self.cycle_momentum = cycle_momentum
        if self.cycle_momentum:
            if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
            # Adam-style optimizers expose momentum as betas[0].
            self.use_beta1 = 'betas' in self.optimizer.defaults
            max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            if last_epoch == -1:
                for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
                    if self.use_beta1:
                        _, beta2 = group['betas']
                        group['betas'] = (m_momentum, beta2)
                    else:
                        group['momentum'] = m_momentum
                    group['max_momentum'] = m_momentum
                    group['base_momentum'] = b_momentum

        super(OneCycleLR, self).__init__(optimizer, last_epoch, verbose)

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group.

        A list/tuple must have one entry per param group; a scalar is
        broadcast to every group.

        Raises:
            ValueError: If a list/tuple has the wrong length.
        """
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError("expected {} values for {}, got {}".format(
                    len(optimizer.param_groups), name, len(param)))
            return param
        else:
            return [param] * len(optimizer.param_groups)

    def _annealing_cos(self, start, end, pct):
        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _annealing_linear(self, start, end, pct):
        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        return (end - start) * pct + start

    def get_lr(self):
        """Return the learning rate of every param group at ``self.last_epoch``.

        When ``cycle_momentum`` is enabled this also writes the annealed
        momentum into each group (``momentum``, or ``betas[0]`` for
        Adam-style optimizers).

        Raises:
            ValueError: If ``last_epoch`` exceeds ``total_steps``.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        lrs = []
        step_num = self.last_epoch

        if step_num > self.total_steps:
            raise ValueError("Tried to step {} times. The specified number of total steps is {}"
                             .format(step_num + 1, self.total_steps))

        last_phase = len(self._schedule_phases) - 1
        for group in self.optimizer.param_groups:
            start_step = 0
            for i, phase in enumerate(self._schedule_phases):
                end_step = phase['end_step']
                # The final phase absorbs every remaining step.
                if step_num <= end_step or i == last_phase:
                    span = end_step - start_step
                    # A phase collapses to zero length when, e.g.,
                    # pct_start * total_steps == 1; treat it as finished
                    # (pct = 1.0) instead of dividing by zero.
                    pct = (step_num - start_step) / span if span else 1.0
                    computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct)
                    if self.cycle_momentum:
                        computed_momentum = self.anneal_func(group[phase['start_momentum']],
                                                             group[phase['end_momentum']], pct)
                    break
                start_step = phase['end_step']

            lrs.append(computed_lr)
            if self.cycle_momentum:
                if self.use_beta1:
                    _, beta2 = group['betas']
                    group['betas'] = (computed_momentum, beta2)
                else:
                    group['momentum'] = computed_momentum

        return lrs

## WarmupLRScheduler

## lr_scheduler.py

In [None]:
import types
import math
from torch._six import inf
from functools import wraps
import warnings
import weakref
from collections import Counter
from bisect import bisect_right

from torch.optim.optimizer import Optimizer


# Warning text emitted by `_LRScheduler.step` when it is called with an explicit epoch.
EPOCH_DEPRECATION_WARNING = " ".join((
    "The epoch parameter in `scheduler.step()` was not necessary and is being deprecated where possible.",
    "Please use `scheduler.step()` to step the scheduler.",
    "During the deprecation, if epoch is different from None, the closed form is used instead of the new chainable form, where available.",
    "Please open an issue if you are unable to replicate your use case: https://github.com/pytorch/pytorch/issues/new/choose.",
))

class _LRScheduler(object):

    def __init__(self, optimizer, last_epoch=-1, verbose=False):

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(method):
            if getattr(method, '_with_counter', False):
                # `optimizer.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the optimizer instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._step_count = 0
        self._step_count = 0
        self.verbose = verbose

        self.step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self):
        """ Return last computed learning rate by current scheduler.
        """
        return self._last_lr

    def get_lr(self):
        # Compute learning rate using chainable form of the scheduler
        raise NotImplementedError

    def print_lr(self, is_verbose, group, lr, epoch=None):
        """Display the current learning rate.
        """
        if is_verbose:
            if epoch is None:
                print('Adjusting learning rate'
                      ' of group {} to {:.4e}.'.format(group, lr))
            else:
                print('Epoch {:5d}: adjusting learning rate'
                      ' of group {} to {:.4e}.'.format(epoch, group, lr))


    def step(self, epoch=None):
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_with_counter"):
                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                              "initialization. Please, make sure to call `optimizer.step()` before "
                              "`lr_scheduler.step()`. See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif self.optimizer._step_count < 1:
                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                              "`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this "
                              "will result in PyTorch skipping the first value of the learning rate schedule. "
                              "See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        self._step_count += 1

        class _enable_get_lr_call:

            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False

        with _enable_get_lr_call(self):
            if epoch is None:
                self.last_epoch += 1
                values = self.get_lr()
            else:
                warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
                self.last_epoch = epoch
                if hasattr(self, "_get_closed_form_lr"):
                    values = self._get_closed_form_lr()
                else:
                    values = self.get_lr()

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group['lr'] = lr
            self.print_lr(self.verbose, i, lr, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]


class LambdaLR(_LRScheduler):
    """Set each group's learning rate to its initial lr scaled by a function of the epoch.

    At epoch ``e`` the rate of group ``g`` is ``base_lrs[g] * lr_lambdas[g](e)``.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): Either one callable shared by all param
            groups, or a list/tuple with exactly one callable per group. Each
            callable receives the integer epoch and returns a multiplicative
            factor.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: If a list/tuple of lambdas does not match the number of
            param groups.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lambda1 = lambda epoch: epoch // 30
        >>> lambda2 = lambda epoch: 0.95 ** epoch
        >>> scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
        self.optimizer = optimizer
        n_groups = len(optimizer.param_groups)

        if isinstance(lr_lambda, (list, tuple)):
            if len(lr_lambda) != n_groups:
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    n_groups, len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        else:
            # One callable is reused for every param group.
            self.lr_lambdas = [lr_lambda] * n_groups
        super(LambdaLR, self).__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        """Return the scheduler state as a :class:`dict`, excluding the optimizer.

        Lambda entries are stored as a copy of the callable's ``__dict__`` when
        the callable is an object; plain functions and lambdas are stored as
        ``None``.

        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.
        """
        skipped = ('optimizer', 'lr_lambdas')
        state = {name: val for name, val in self.__dict__.items() if name not in skipped}
        state['lr_lambdas'] = [
            None if isinstance(fn, types.FunctionType) else fn.__dict__.copy()
            for fn in self.lr_lambdas
        ]
        return state

    def load_state_dict(self, state_dict):
        """Restore scheduler state produced by :meth:`state_dict`.

        When saving or loading the scheduler, please make sure to also save or load the state of the optimizer.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        saved_lambdas = state_dict.pop('lr_lambdas')
        self.__dict__.update(state_dict)
        # Put the key back so the caller's dict is not left mutated
        # (https://github.com/pytorch/pytorch/issues/32756).
        state_dict['lr_lambdas'] = saved_lambdas

        for idx, saved in enumerate(saved_lambdas):
            if saved is None:
                continue
            self.lr_lambdas[idx].__dict__.update(saved)

    def get_lr(self):
        """Return ``base_lr * lambda(last_epoch)`` for each param group."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        epoch = self.last_epoch
        scaled = []
        for factor_fn, initial in zip(self.lr_lambdas, self.base_lrs):
            scaled.append(initial * factor_fn(epoch))
        return scaled


class MultiplicativeLR(_LRScheduler):
    """Multiply each group's current learning rate by a per-epoch factor.

    On every step with ``last_epoch > 0``, group ``g``'s rate becomes
    ``param_groups[g]['lr'] * lr_lambdas[g](last_epoch)``; otherwise the
    current rates are returned unchanged. When last_epoch=-1, sets initial
    lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A single callable applied to every
            param group, or a list/tuple holding one callable per group. Each
            receives the integer epoch and returns the multiplier.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: If a list/tuple of lambdas does not match the number of
            param groups.

    Example:
        >>> lmbda = lambda epoch: 0.95
        >>> scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda, last_epoch=-1, verbose=False):
        self.optimizer = optimizer
        group_count = len(optimizer.param_groups)

        if isinstance(lr_lambda, (list, tuple)):
            if len(lr_lambda) != group_count:
                raise ValueError("Expected {} lr_lambdas, but got {}".format(
                    group_count, len(lr_lambda)))
            self.lr_lambdas = list(lr_lambda)
        else:
            self.lr_lambdas = [lr_lambda] * group_count
        super(MultiplicativeLR, self).__init__(optimizer, last_epoch, verbose)

    def state_dict(self):
        """Return the scheduler state as a :class:`dict`, excluding the optimizer.

        Each lambda slot holds a copy of the callable's ``__dict__`` if it is a
        callable object, or ``None`` if it is a plain function or lambda.
        """
        state = {}
        for name, val in self.__dict__.items():
            if name not in ('optimizer', 'lr_lambdas'):
                state[name] = val
        saved = []
        for fn in self.lr_lambdas:
            saved.append(None if isinstance(fn, types.FunctionType) else fn.__dict__.copy())
        state['lr_lambdas'] = saved
        return state

    def load_state_dict(self, state_dict):
        """Restore scheduler state produced by :meth:`state_dict`.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        stored = state_dict.pop('lr_lambdas')
        self.__dict__.update(state_dict)
        # Re-insert the popped key so the caller's dict is unchanged
        # (https://github.com/pytorch/pytorch/issues/32756).
        state_dict['lr_lambdas'] = stored

        for idx, attrs in enumerate(stored):
            if attrs is not None:
                self.lr_lambdas[idx].__dict__.update(attrs)

    def get_lr(self):
        """Return the current rates, each multiplied by its lambda once past epoch 0."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        if self.last_epoch <= 0:
            return [g['lr'] for g in groups]
        return [g['lr'] * fn(self.last_epoch) for fn, g in zip(self.lr_lambdas, groups)]


class StepLR(_LRScheduler):
    """Multiply every group's learning rate by ``gamma`` once every ``step_size`` epochs.

    The decay is applied to the current ``param_group['lr']``, so it composes
    with other changes made to the learning rate from outside this scheduler.
    When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Number of epochs between decays.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 60
        >>> # lr = 0.0005   if 60 <= epoch < 90
        >>> # ...
        >>> scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False):
        self.step_size = step_size
        self.gamma = gamma
        super(StepLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Chainable form: decay the current rates only on a non-zero multiple of step_size."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        # Epoch 0 short-circuits, so the modulo is never evaluated there.
        on_boundary = self.last_epoch != 0 and self.last_epoch % self.step_size == 0
        if on_boundary:
            return [g['lr'] * self.gamma for g in groups]
        return [g['lr'] for g in groups]

    def _get_closed_form_lr(self):
        """Closed form: ``base_lr * gamma ** (last_epoch // step_size)`` per group."""
        decay = self.gamma ** (self.last_epoch // self.step_size)
        return [initial * decay for initial in self.base_lrs]


class MultiStepLR(_LRScheduler):
    """Multiply every group's learning rate by ``gamma`` at each milestone epoch.

    Milestones may repeat; a milestone listed ``k`` times applies
    ``gamma ** k`` at that epoch. The decay is applied to the current
    ``param_group['lr']``, so it composes with other changes made from
    outside this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.05     if epoch < 30
        >>> # lr = 0.005    if 30 <= epoch < 80
        >>> # lr = 0.0005   if epoch >= 80
        >>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):
        # Counter keeps the multiplicity of repeated milestones.
        self.milestones = Counter(milestones)
        self.gamma = gamma
        super(MultiStepLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Chainable form: decay the current rates only when last_epoch is a milestone."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        if self.last_epoch in self.milestones:
            factor = self.gamma ** self.milestones[self.last_epoch]
            return [g['lr'] * factor for g in groups]
        return [g['lr'] for g in groups]

    def _get_closed_form_lr(self):
        """Closed form: ``base_lr * gamma ** (milestones reached so far)`` per group."""
        ordered = sorted(self.milestones.elements())
        reached = bisect_right(ordered, self.last_epoch)
        return [initial * self.gamma ** reached for initial in self.base_lrs]


class ConstantLR(_LRScheduler):
    """Scale every parameter group's learning rate by a fixed ``factor`` for the
    first ``total_iters`` epochs, then undo the scaling.

    The scaling is applied on top of whatever lr each group currently holds, so
    it can coexist with changes made to the learning rate outside this
    scheduler. With ``last_epoch=-1`` the optimizer's lr is used as the initial lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        factor (float): Multiplier applied to the learning rate until the
            milestone is reached; must lie in ``[0, 1]``. Default: 1./3.
        total_iters (int): Epoch at which the scaling is undone. Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025   if epoch == 0
        >>> # lr = 0.025   if epoch == 1
        >>> # lr = 0.025   if epoch == 2
        >>> # lr = 0.025   if epoch == 3
        >>> # lr = 0.05    if epoch >= 4
        >>> scheduler = ConstantLR(self.opt, factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, factor=1.0 / 3, total_iters=5, last_epoch=-1, verbose=False):
        # NOTE(review): factor == 0 passes this check, but get_lr() multiplies
        # by 1.0 / factor once last_epoch reaches total_iters.
        if factor < 0 or factor > 1.0:
            raise ValueError('Constant multiplicative factor expected to be between 0 and 1.')

        self.total_iters = total_iters
        self.factor = factor
        super(ConstantLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return the learning rates for the current epoch (recursive form).

        * epoch 0: every group's lr is multiplied by ``factor``;
        * epoch ``total_iters``: every group's lr is multiplied by ``1 / factor``;
        * any other epoch: the current lr is returned unchanged.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        groups = self.optimizer.param_groups
        epoch = self.last_epoch
        if epoch == 0:
            return [group['lr'] * self.factor for group in groups]
        if epoch == self.total_iters:
            return [group['lr'] * (1.0 / self.factor) for group in groups]
        return [group['lr'] for group in groups]

    def _get_closed_form_lr(self):
        """Return ``base_lr * factor`` before ``total_iters`` and ``base_lr`` from then on."""
        done = self.last_epoch >= self.total_iters
        scale = self.factor + done * (1 - self.factor)
        return [base_lr * scale for base_lr in self.base_lrs]


class LinearLR(_LRScheduler):
    """Decays the learning rate of each parameter group by linearly changing small
    multiplicative factor until the number of epoch reaches a pre-defined milestone: total_iters.
    Notice that such decay can happen simultaneously with other changes to the learning rate
    from outside this scheduler. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        start_factor (float): The number we multiply learning rate in the first epoch.
            The multiplication factor changes towards end_factor in the following epochs.
            Must satisfy ``0 < start_factor <= 1``. Default: 1./3.
        end_factor (float): The number we multiply learning rate at the end of linear changing
            process. Must satisfy ``0 <= end_factor <= 1``. Default: 1.0.
        total_iters (int): The number of iterations after which the multiplicative factor
            reaches ``end_factor``. Default: 5.
        last_epoch (int): The index of the last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: If ``start_factor`` is not in ``(0, 1]`` or ``end_factor`` is
            not in ``[0, 1]``.

    Example:
        >>> # Assuming optimizer uses lr = 0.05 for all groups
        >>> # lr = 0.025    if epoch == 0
        >>> # lr = 0.03125  if epoch == 1
        >>> # lr = 0.0375   if epoch == 2
        >>> # lr = 0.04375  if epoch == 3
        >>> # lr = 0.05     if epoch >= 4
        >>> scheduler = LinearLR(self.opt, start_factor=0.5, total_iters=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, start_factor=1.0 / 3, end_factor=1.0, total_iters=5, last_epoch=-1,
                 verbose=False):
        # start_factor == 0 must be rejected: get_lr() divides by
        # total_iters * start_factor + (last_epoch - 1) * (end_factor - start_factor),
        # which is exactly zero on the first step (last_epoch == 1) when start_factor is 0.
        if start_factor > 1.0 or start_factor <= 0:
            raise ValueError('Starting multiplicative factor expected to be greater than 0 '
                             'and less or equal to 1.')

        if end_factor > 1.0 or end_factor < 0:
            raise ValueError('Ending multiplicative factor expected to be between 0 and 1.')

        self.start_factor = start_factor
        self.end_factor = end_factor
        self.total_iters = total_iters
        super(LinearLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return the learning rates for the current epoch (recursive form).

        At epoch 0 each group's lr is scaled by ``start_factor``. For epochs
        ``1..total_iters`` the current lr is multiplied by the ratio
        ``f(e) / f(e - 1)``, where ``f(e) = start_factor + (end_factor - start_factor) * e / total_iters``.
        After ``total_iters`` the lr is returned unchanged.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        if self.last_epoch == 0:
            return [group['lr'] * self.start_factor for group in self.optimizer.param_groups]

        if (self.last_epoch > self.total_iters):
            return [group['lr'] for group in self.optimizer.param_groups]

        return [group['lr'] * (1. + (self.end_factor - self.start_factor) /
                (self.total_iters * self.start_factor + (self.last_epoch - 1) * (self.end_factor - self.start_factor)))
                for group in self.optimizer.param_groups]

    def _get_closed_form_lr(self):
        """Return ``base_lr * f(min(last_epoch, total_iters))`` for each group (see ``get_lr``)."""
        return [base_lr * (self.start_factor +
                (self.end_factor - self.start_factor) * min(self.total_iters, self.last_epoch) / self.total_iters)
                for base_lr in self.base_lrs]


class ExponentialLR(_LRScheduler):
    """Shrink the learning rate of every parameter group by ``gamma`` each epoch.

    With ``last_epoch=-1`` the optimizer's current lr is used as the initial lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor applied to the learning rate on
            every epoch after the first.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):
        self.gamma = gamma
        super(ExponentialLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return each group's lr multiplied by ``gamma``; at epoch 0 the lrs are returned unchanged."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        current = [group['lr'] for group in self.optimizer.param_groups]
        if self.last_epoch == 0:
            return current
        return [lr * self.gamma for lr in current]

    def _get_closed_form_lr(self):
        """Return ``base_lr * gamma ** last_epoch`` for every group."""
        decay = self.gamma ** self.last_epoch
        return [base_lr * decay for base_lr in self.base_lrs]


class SequentialLR(_LRScheduler):
    """Receives the list of schedulers that is expected to be called sequentially during
    optimization process and milestone points that provide exact intervals to reflect
    which scheduler is supposed to be called at a given epoch.

    Args:
        optimizer (Optimizer): Wrapped optimizer. Every scheduler in
            ``schedulers`` must wrap this same optimizer.
        schedulers (list): List of chained schedulers.
        milestones (list): List of integers that reflects milestone points;
            must contain exactly ``len(schedulers) - 1`` entries.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): Accepted for API compatibility; not used by this class.

    Raises:
        ValueError: If the schedulers do not all wrap ``optimizer``, or if the
            number of milestones is not one less than the number of schedulers.

    Example:
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.1     if epoch == 0
        >>> # lr = 0.1     if epoch == 1
        >>> # lr = 0.9     if epoch == 2
        >>> # lr = 0.81    if epoch == 3
        >>> # lr = 0.729   if epoch == 4
        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
        >>> scheduler = SequentialLR(self.opt, schedulers=[scheduler1, scheduler2], milestones=[2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1, verbose=False):
        for scheduler_idx in range(1, len(schedulers)):
            if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
                raise ValueError(
                    "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                )
        # The schedulers all share one optimizer (checked above); it must also be
        # the optimizer this wrapper was given, otherwise `optimizer` is silently ignored.
        if schedulers and schedulers[0].optimizer != optimizer:
            raise ValueError(
                "Sequential Schedulers expects all schedulers to belong to the same optimizer, but "
                "got schedulers at index {} to be different than the optimizer passed in.".format(0)
            )
        if (len(milestones) != len(schedulers) - 1):
            raise ValueError(
                "Sequential Schedulers expects number of schedulers provided to be one more "
                "than the number of milestone points, but got number of schedulers {} and the "
                "number of milestones to be equal to {}".format(len(schedulers), len(milestones))
            )
        self._schedulers = schedulers
        self._milestones = milestones
        self.last_epoch = last_epoch + 1
        self.optimizer = optimizer
        # Record the current lrs so the inherited get_last_lr() works before the first step().
        self._last_lr = [group['lr'] for group in optimizer.param_groups]

    def step(self):
        """Advance one epoch, delegating to the scheduler that owns the new epoch.

        When the new epoch is exactly a milestone, the scheduler that takes over
        is stepped with ``step(0)``; otherwise it is stepped normally.
        """
        self.last_epoch += 1
        idx = bisect_right(self._milestones, self.last_epoch)
        scheduler = self._schedulers[idx]
        if idx > 0 and self._milestones[idx - 1] == self.last_epoch:
            scheduler.step(0)
        else:
            scheduler.step()
        self._last_lr = scheduler.get_last_lr()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
        state_dict['_schedulers'] = [None] * len(self._schedulers)

        for idx, s in enumerate(self._schedulers):
            state_dict['_schedulers'][idx] = s.state_dict()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        _schedulers = state_dict.pop('_schedulers')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['_schedulers'] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)


class CosineAnnealingLR(_LRScheduler):
    r"""Anneal the learning rate of each parameter group along a cosine curve,
    where :math:`\eta_{max}` is the initial lr and :math:`T_{cur}` counts the
    epochs since the last restart in SGDR:

    .. math::
        \begin{aligned}
            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
            & T_{cur} \neq (2k+1)T_{max}; \\
            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
            & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    With ``last_epoch=-1`` the optimizer's lr is taken as the initial lr. The
    schedule is computed recursively from each group's current lr, so other
    code may also modify the learning rate. If only this scheduler touches the
    learning rate, the value at each step is:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    Introduced in `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
    Only the cosine-annealing part of SGDR is implemented here, not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False):
        self.eta_min = eta_min
        self.T_max = T_max
        super(CosineAnnealingLR, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return each group's lr advanced by one step along the cosine curve."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        epoch = self.last_epoch
        groups = self.optimizer.param_groups

        if epoch == 0:
            return [group['lr'] for group in groups]

        if (epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            # The previous epoch was at the bottom of the curve, where the ratio
            # below would have a zero denominator; step up from there additively.
            rise = 1 - math.cos(math.pi / self.T_max)
            return [group['lr'] + (base_lr - self.eta_min) * rise / 2
                    for base_lr, group in zip(self.base_lrs, groups)]

        # Ratio of the cosine curve's height at this epoch to its height at the previous one.
        ratio = ((1 + math.cos(math.pi * epoch / self.T_max)) /
                 (1 + math.cos(math.pi * (epoch - 1) / self.T_max)))
        return [ratio * (group['lr'] - self.eta_min) + self.eta_min for group in groups]

    def _get_closed_form_lr(self):
        """Return the non-recursive cosine value for ``last_epoch`` for every group."""
        wave = 1 + math.cos(math.pi * self.last_epoch / self.T_max)
        return [self.eta_min + (base_lr - self.eta_min) * wave / 2
                for base_lr in self.base_lrs]


class ChainedScheduler(_LRScheduler):
    """Chains a list of learning rate schedulers so a single ``step()`` call
    steps each of them, in list order.

    Args:
        schedulers (list): List of chained schedulers. All must wrap the same
            optimizer.

    Raises:
        ValueError: If the schedulers do not all wrap the same optimizer.

    Example:
        >>> # Assuming optimizer uses lr = 1. for all groups
        >>> # lr = 0.09     if epoch == 0
        >>> # lr = 0.081    if epoch == 1
        >>> # lr = 0.729    if epoch == 2
        >>> # lr = 0.6561   if epoch == 3
        >>> # lr = 0.59049  if epoch >= 4
        >>> scheduler1 = ConstantLR(self.opt, factor=0.1, total_iters=2)
        >>> scheduler2 = ExponentialLR(self.opt, gamma=0.9)
        >>> scheduler = ChainedScheduler([scheduler1, scheduler2])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(self, schedulers):
        for scheduler_idx in range(1, len(schedulers)):
            if (schedulers[scheduler_idx].optimizer != schedulers[0].optimizer):
                raise ValueError(
                    "ChainedScheduler expects all schedulers to belong to the same optimizer, but "
                    "got schedulers at index {} and {} to be different".format(0, scheduler_idx)
                )
        self._schedulers = list(schedulers)
        # All schedulers share one optimizer (checked above). Expose it so this
        # wrapper behaves like the other schedulers; state_dict() already
        # excludes the 'optimizer' key.
        self.optimizer = self._schedulers[0].optimizer if self._schedulers else None
        # Without this, the inherited get_last_lr() raises AttributeError.
        self._last_lr = self._current_lrs()

    def _current_lrs(self):
        """Return the lr currently set on each param group of the shared optimizer ([] if none)."""
        if self.optimizer is None:
            return []
        return [group['lr'] for group in self.optimizer.param_groups]

    def step(self):
        """Step every wrapped scheduler in order, then record the resulting lrs."""
        for scheduler in self._schedulers:
            scheduler.step()
        self._last_lr = self._current_lrs()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The wrapped scheduler states will also be saved.
        """
        state_dict = {key: value for key, value in self.__dict__.items() if key not in ('optimizer', '_schedulers')}
        state_dict['_schedulers'] = [None] * len(self._schedulers)

        for idx, s in enumerate(self._schedulers):
            state_dict['_schedulers'][idx] = s.state_dict()

        return state_dict

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        _schedulers = state_dict.pop('_schedulers')
        self.__dict__.update(state_dict)
        # Restore state_dict keys in order to prevent side effects
        # https://github.com/pytorch/pytorch/issues/32756
        state_dict['_schedulers'] = _schedulers

        for idx, s in enumerate(_schedulers):
            self._schedulers[idx].load_state_dict(s)


class ReduceLROnPlateau(object):
    """Reduce learning rate when a metric has stopped improving.
    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This scheduler reads a metrics
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    The current learning rates are available through ``get_last_lr()``,
    both right after construction and after every ``step()``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        mode (str): One of `min`, `max`. In `min` mode, lr will
            be reduced when the quantity monitored has stopped
            decreasing; in `max` mode it will be reduced when the
            quantity monitored has stopped increasing. Default: 'min'.
        factor (float): Factor by which the learning rate will be
            reduced. new_lr = lr * factor. Default: 0.1.
        patience (int): Number of epochs with no improvement after
            which learning rate will be reduced. For example, if
            `patience = 2`, then we will ignore the first 2 epochs
            with no improvement, and will only decrease the LR after the
            3rd epoch if the loss still hasn't improved then.
            Default: 10.
        threshold (float): Threshold for measuring the new optimum,
            to only focus on significant changes. Default: 1e-4.
        threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
            dynamic_threshold = best * ( 1 + threshold ) in 'max'
            mode or best * ( 1 - threshold ) in `min` mode.
            In `abs` mode, dynamic_threshold = best + threshold in
            `max` mode or best - threshold in `min` mode. Default: 'rel'.
        cooldown (int): Number of epochs to wait before resuming
            normal operation after lr has been reduced. Default: 0.
        min_lr (float or list): A scalar or a list of scalars. A
            lower bound on the learning rate of all param groups
            or each group respectively. Default: 0.
        eps (float): Minimal decay applied to lr. If the difference
            between new and old lr is smaller than eps, the update is
            ignored. Default: 1e-8.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: If ``factor >= 1.0``, if a ``min_lr`` list has the wrong
            length, or if ``mode``/``threshold_mode`` is unknown.
        TypeError: If ``optimizer`` is not an ``Optimizer``.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
        >>> for epoch in range(10):
        >>>     train(...)
        >>>     val_loss = validate(...)
        >>>     # Note that step should be called after validate()
        >>>     scheduler.step(val_loss)
    """

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
                 threshold=1e-4, threshold_mode='rel', cooldown=0,
                 min_lr=0, eps=1e-8, verbose=False):

        if factor >= 1.0:
            raise ValueError('Factor should be < 1.0.')
        self.factor = factor

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0
        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.best = None
        self.num_bad_epochs = None
        self.mode_worse = None  # the worse value for the chosen mode
        self.eps = eps
        self.last_epoch = 0
        self._init_is_better(mode=mode, threshold=threshold,
                             threshold_mode=threshold_mode)
        self._reset()
        # Previously `_last_lr` only existed after the first step(); record the
        # starting lrs so get_last_lr() works immediately after construction.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def get_last_lr(self):
        """Return the per-group learning rates recorded at construction or by the latest ``step()``."""
        return self._last_lr

    def _reset(self):
        """Resets num_bad_epochs counter and cooldown counter."""
        self.best = self.mode_worse
        self.cooldown_counter = 0
        self.num_bad_epochs = 0

    def step(self, metrics, epoch=None):
        """Record one observation of the monitored metric and reduce lrs if it has plateaued.

        Args:
            metrics: The monitored value; anything accepted by ``float()``.
            epoch (int, optional): Deprecated explicit epoch index; a warning is
                issued when given.
        """
        # convert `metrics` to float, in case it's a zero-dim Tensor
        current = float(metrics)
        if epoch is None:
            epoch = self.last_epoch + 1
        else:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
        self.last_epoch = epoch

        if self.is_better(current, self.best):
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

    def _reduce_lr(self, epoch):
        """Multiply each group's lr by ``factor``, clamped to its ``min_lr``; skip changes <= ``eps``."""
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                if self.verbose:
                    print('Epoch {:5d}: reducing learning rate'
                          ' of group {} to {:.4e}.'.format(epoch, i, new_lr))

    @property
    def in_cooldown(self):
        """True while the post-reduction cooldown counter is still positive."""
        return self.cooldown_counter > 0

    def is_better(self, a, best):
        """Return True if ``a`` improves on ``best`` by more than the configured threshold."""
        if self.mode == 'min' and self.threshold_mode == 'rel':
            rel_epsilon = 1. - self.threshold
            return a < best * rel_epsilon

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return a < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            rel_epsilon = self.threshold + 1.
            return a > best * rel_epsilon

        else:  # mode == 'max' and threshold_mode == 'abs':
            return a > best + self.threshold

    def _init_is_better(self, mode, threshold, threshold_mode):
        """Validate ``mode``/``threshold_mode`` and set the worst possible starting value."""
        if mode not in {'min', 'max'}:
            raise ValueError('mode ' + mode + ' is unknown!')
        if threshold_mode not in {'rel', 'abs'}:
            raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')

        if mode == 'min':
            self.mode_worse = inf
        else:  # mode == 'max':
            self.mode_worse = -inf

        self.mode = mode
        self.threshold = threshold
        self.threshold_mode = threshold_mode

    def state_dict(self):
        """Return every attribute except the optimizer."""
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Restore attributes from ``state_dict`` and re-derive ``mode_worse``."""
        self.__dict__.update(state_dict)
        self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)


class CyclicLR(_LRScheduler):
    r"""Cycle each parameter group's learning rate between two boundaries at a
    constant frequency, following the cyclical learning rate (CLR) policy from
    `Cyclical Learning Rates for Training Neural Networks`_. The distance
    between the boundaries can be scaled per iteration or per cycle.

    CLR updates the learning rate after every batch, so `step` should be called
    once a batch has been used for training.

    Three built-in policies from the paper are available:

    * "triangular": A basic triangular cycle without amplitude scaling.
    * "triangular2": A basic triangular cycle that scales initial amplitude by half each cycle.
    * "exp_range": A cycle that scales initial amplitude by :math:`\text{gamma}^{\text{cycle iterations}}`
      at each cycle iteration.

    Adapted from the github repo `bckenstler/CLR`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        base_lr (float or list): Initial learning rate, i.e. the lower
            boundary of the cycle, for each parameter group.
        max_lr (float or list): Upper learning rate boundary of the cycle for
            each parameter group; it sets the cycle amplitude
            (max_lr - base_lr). The lr is base_lr plus a scaled amplitude, so
            max_lr may never actually be reached, depending on the scaling
            function.
        step_size_up (int): Number of training iterations in the
            increasing half of a cycle. Default: 2000
        step_size_down (int): Number of training iterations in the
            decreasing half of a cycle. When None, step_size_up is used.
            Default: None
        mode (str): One of {triangular, triangular2, exp_range}, selecting a
            policy described above. Ignored when scale_fn is given.
            Default: 'triangular'
        gamma (float): Constant in the 'exp_range' scaling function:
            gamma**(cycle iterations)
            Default: 1.0
        scale_fn (function): Custom single-argument scaling policy with
            0 <= scale_fn(x) <= 1 for all x >= 0. Overrides 'mode' when
            given. Default: None
        scale_mode (str): {'cycle', 'iterations'}. Whether scale_fn receives
            the cycle number or the cycle iterations (training iterations
            since the start). Only used with a custom scale_fn.
            Default: 'cycle'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely to
            the learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundary of the cycle
            for each parameter group. Momentum moves inversely to the learning
            rate: at the peak of a cycle momentum is 'base_momentum' while the
            learning rate is 'max_lr'.
            Default: 0.8
        max_momentum (float or list): Upper momentum boundary of the cycle for
            each parameter group; it sets the momentum amplitude
            (max_momentum - base_momentum). Momentum is max_momentum minus a
            scaled amplitude, so base_momentum may never be reached, depending
            on the scaling function. At the start of a cycle momentum is
            'max_momentum' and the learning rate is 'base_lr'.
            Default: 0.9
        last_epoch (int): The index of the last batch, used when resuming
            training. Because `step()` runs once per batch rather than once
            per epoch, this counts *batches* computed, not epochs. With
            last_epoch=-1 the schedule starts from the beginning.
            Default: -1
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Example:
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1)
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()


    .. _Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    .. _bckenstler/CLR: https://github.com/bckenstler/CLR
    """

    def __init__(self,
                 optimizer,
                 base_lr,
                 max_lr,
                 step_size_up=2000,
                 step_size_down=None,
                 mode='triangular',
                 gamma=1.,
                 scale_fn=None,
                 scale_mode='cycle',
                 cycle_momentum=True,
                 base_momentum=0.8,
                 max_momentum=0.9,
                 last_epoch=-1,
                 verbose=False):

        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        fresh_start = last_epoch == -1

        # On a fresh start the optimizer's lrs are overwritten with base_lr, so the
        # base class (constructed below) records base_lr as the initial lr.
        base_lrs = self._format_param('base_lr', optimizer, base_lr)
        if fresh_start:
            for group, lr in zip(optimizer.param_groups, base_lrs):
                group['lr'] = lr

        self.max_lrs = self._format_param('max_lr', optimizer, max_lr)

        up = float(step_size_up)
        down = up if step_size_down is None else float(step_size_down)
        self.total_size = up + down
        self.step_ratio = up / self.total_size

        if scale_fn is None and mode not in ['triangular', 'triangular2', 'exp_range']:
            raise ValueError('mode is invalid and scale_fn is None')

        self.mode = mode
        self.gamma = gamma

        if scale_fn is not None:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode
        else:
            # Built-in policies fix their own scale_mode; the scale_mode argument is ignored.
            builtin_policies = {
                'triangular': (self._triangular_scale_fn, 'cycle'),
                'triangular2': (self._triangular2_scale_fn, 'cycle'),
                'exp_range': (self._exp_range_scale_fn, 'iterations'),
            }
            self.scale_fn, self.scale_mode = builtin_policies[self.mode]

        self.cycle_momentum = cycle_momentum
        if cycle_momentum:
            if 'momentum' not in optimizer.defaults:
                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')

            momentum_floor = self._format_param('base_momentum', optimizer, base_momentum)
            if fresh_start:
                for group, momentum in zip(optimizer.param_groups, momentum_floor):
                    group['momentum'] = momentum
            # Read back from the groups so a resumed run keeps the stored momentum.
            self.base_momentums = [group['momentum'] for group in optimizer.param_groups]
            self.max_momentums = self._format_param('max_momentum', optimizer, max_momentum)

        super(CyclicLR, self).__init__(optimizer, last_epoch, verbose)
        # The base class sets base_lrs from the groups' initial lr; pin it to the requested values.
        self.base_lrs = base_lrs

    def _format_param(self, name, optimizer, param):
        """Return ``param`` as one value per param group: scalars are repeated,
        lists/tuples are returned as-is after a length check."""
        group_count = len(optimizer.param_groups)
        if not isinstance(param, (list, tuple)):
            return [param] * group_count
        if len(param) != group_count:
            raise ValueError("expected {} values for {}, got {}".format(
                group_count, name, len(param)))
        return param

    def _triangular_scale_fn(self, x):
        """Constant amplitude: always 1."""
        return 1.

    def _triangular2_scale_fn(self, x):
        """Halve the amplitude every cycle: 1 / 2**(x - 1)."""
        return 1 / (2. ** (x - 1))

    def _exp_range_scale_fn(self, x):
        """Exponential amplitude decay: gamma**x."""
        return self.gamma**(x)

    def get_lr(self):
        """Compute the learning rates for the batch index held in ``self.last_epoch``.

        When ``self.cycle_momentum`` is ``True`` this also writes the inversely
        cycled momentum into each of the optimizer's param groups.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        cycle = math.floor(1 + self.last_epoch / self.total_size)
        # Fractional position inside the current cycle, in [0, 1).
        x = 1. + self.last_epoch / self.total_size - cycle
        if x <= self.step_ratio:
            scale_factor = x / self.step_ratio
        else:
            scale_factor = (x - 1) / (self.step_ratio - 1)

        scale_arg = cycle if self.scale_mode == 'cycle' else self.last_epoch

        lrs = [base_lr + (max_lr - base_lr) * scale_factor * self.scale_fn(scale_arg)
               for base_lr, max_lr in zip(self.base_lrs, self.max_lrs)]

        if self.cycle_momentum:
            momentums = [top - (top - bottom) * scale_factor * self.scale_fn(scale_arg)
                         for bottom, top in zip(self.base_momentums, self.max_momentums)]
            for group, momentum in zip(self.optimizer.param_groups, momentums):
                group['momentum'] = momentum

        return lrs


class CosineAnnealingWarmRestarts(_LRScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
    is the number of epochs since the last restart and :math:`T_{i}` is the number
    of epochs between two warm restarts in SGDR:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)

    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_0 (int): Number of iterations for the first restart.
        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
        eta_min (float, optional): Minimum learning rate. Default: 0.
        last_epoch (int, optional): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: if ``T_0`` is not a positive int or ``T_mult`` is not an
            int >= 1 (constructor), or if ``step`` gets a negative epoch.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False):
        if T_0 <= 0 or not isinstance(T_0, int):
            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
        if T_mult < 1 or not isinstance(T_mult, int):
            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
        self.T_0 = T_0
        # Length of the current cycle; grows by T_mult at every restart.
        self.T_i = T_0
        self.T_mult = T_mult
        self.eta_min = eta_min
        # Epochs elapsed inside the current cycle.
        self.T_cur = last_epoch
        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return the cosine-annealed lr for each param group at position ``T_cur / T_i``."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        return [self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]

    @staticmethod
    def _restart_position(epoch, T_0, T_mult):
        """Locate an absolute ``epoch`` (>= 0, may be fractional) in the restart schedule.

        Cycle ``k`` (0-based) lasts ``T_0 * T_mult**k`` epochs and starts at
        ``T_0 * (T_mult**k - 1) / (T_mult - 1)`` (or ``k * T_0`` when ``T_mult == 1``).

        Returns:
            ``(T_cur, T_i)``: epochs elapsed in the containing cycle and that cycle's length.
        """
        if T_mult == 1:
            return epoch % T_0, T_0

        def cycle_start(k):
            # Exact integer: T_mult**k - 1 is always divisible by T_mult - 1.
            return T_0 * (T_mult ** k - 1) // (T_mult - 1)

        # Closed-form estimate of the cycle index. math.log can land just below an
        # exact integer (math.log(243, 3) == 4.999...), which would leave a restart
        # epoch at T_cur == T_i of the previous cycle (lr == eta_min) instead of at
        # T_cur == 0 of the new one, so nudge n until
        # cycle_start(n) <= epoch < cycle_start(n + 1).
        n = int(math.log(epoch / T_0 * (T_mult - 1) + 1, T_mult))
        while cycle_start(n + 1) <= epoch:
            n += 1
        while n > 0 and cycle_start(n) > epoch:
            n -= 1
        return epoch - cycle_start(n), T_0 * T_mult ** n

    def step(self, epoch=None):
        """Step could be called after every batch update

        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> iters = len(dataloader)
            >>> for epoch in range(20):
            >>>     for i, sample in enumerate(dataloader):
            >>>         inputs, labels = sample['inputs'], sample['labels']
            >>>         optimizer.zero_grad()
            >>>         outputs = net(inputs)
            >>>         loss = criterion(outputs, labels)
            >>>         loss.backward()
            >>>         optimizer.step()
            >>>         scheduler.step(epoch + i / iters)

        This function can be called in an interleaved way.

        Example:
            >>> scheduler = CosineAnnealingWarmRestarts(optimizer, T_0, T_mult)
            >>> for epoch in range(20):
            >>>     scheduler.step()
            >>> scheduler.step(26)
            >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
        """

        # The very first (constructor-issued) step starts the schedule at epoch 0.
        if epoch is None and self.last_epoch < 0:
            epoch = 0

        if epoch is None:
            # Incremental step: advance one epoch, restarting when the cycle is used up.
            epoch = self.last_epoch + 1
            self.T_cur = self.T_cur + 1
            if self.T_cur >= self.T_i:
                self.T_cur = self.T_cur - self.T_i
                self.T_i = self.T_i * self.T_mult
        else:
            # Absolute step: recompute the position in the schedule from scratch.
            if epoch < 0:
                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
            if epoch >= self.T_0:
                self.T_cur, self.T_i = self._restart_position(epoch, self.T_0, self.T_mult)
            else:
                self.T_i = self.T_0
                self.T_cur = epoch
        self.last_epoch = math.floor(epoch)

        class _enable_get_lr_call:
            # Marks get_lr() as called from step() so it does not emit its warning.

            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False
                return self

        with _enable_get_lr_call(self):
            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
                param_group, lr = data
                param_group['lr'] = lr
                self.print_lr(self.verbose, i, lr, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]


class OneCycleLR(_LRScheduler):
    r"""Sets the learning rate of each parameter group according to the
    1cycle learning rate policy. The 1cycle policy anneals the learning
    rate from an initial learning rate to some maximum learning rate and then
    from that maximum learning rate to some minimum learning rate much lower
    than the initial learning rate.
    This policy was initially described in the paper `Super-Convergence:
    Very Fast Training of Neural Networks Using Large Learning Rates`_.

    The 1cycle learning rate policy changes the learning rate after every batch.
    `step` should be called after a batch has been used for training.

    This scheduler is not chainable.

    Note also that the total number of steps in the cycle can be determined in one
    of two ways (listed in order of precedence):

    #. A value for total_steps is explicitly provided.
    #. A number of epochs (epochs) and a number of steps per epoch
       (steps_per_epoch) are provided.
       In this case, the number of total steps is inferred by
       total_steps = epochs * steps_per_epoch

    You must either provide a value for total_steps or provide a value for both
    epochs and steps_per_epoch.

    The default behaviour of this scheduler follows the fastai implementation of 1cycle, which
    claims that "unpublished work has shown even better results by using only two phases". To
    mimic the behaviour of the original paper instead, set ``three_phase=True``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_lr (float or list): Upper learning rate boundaries in the cycle
            for each parameter group.
        total_steps (int): The total number of steps in the cycle. Note that
            if a value is not provided here, then it must be inferred by providing
            a value for epochs and steps_per_epoch.
            Default: None
        epochs (int): The number of epochs to train for. This is used along
            with steps_per_epoch in order to infer the total number of steps in the cycle
            if a value for total_steps is not provided.
            Default: None
        steps_per_epoch (int): The number of steps per epoch to train for. This is
            used along with epochs in order to infer the total number of steps in the
            cycle if a value for total_steps is not provided.
            Default: None
        pct_start (float): The percentage of the cycle (in number of steps) spent
            increasing the learning rate.
            Default: 0.3
        anneal_strategy (str): {'cos', 'linear'}
            Specifies the annealing strategy: "cos" for cosine annealing, "linear" for
            linear annealing.
            Default: 'cos'
        cycle_momentum (bool): If ``True``, momentum is cycled inversely
            to learning rate between 'base_momentum' and 'max_momentum'.
            Default: True
        base_momentum (float or list): Lower momentum boundaries in the cycle
            for each parameter group. Note that momentum is cycled inversely
            to learning rate; at the peak of a cycle, momentum is
            'base_momentum' and learning rate is 'max_lr'.
            Default: 0.85
        max_momentum (float or list): Upper momentum boundaries in the cycle
            for each parameter group. Functionally,
            it defines the cycle amplitude (max_momentum - base_momentum).
            Note that momentum is cycled inversely
            to learning rate; at the start of a cycle, momentum is 'max_momentum'
            and learning rate is 'base_lr'
            Default: 0.95
        div_factor (float): Determines the initial learning rate via
            initial_lr = max_lr/div_factor
            Default: 25
        final_div_factor (float): Determines the minimum learning rate via
            min_lr = initial_lr/final_div_factor
            Default: 1e4
        three_phase (bool): If ``True``, use a third phase of the schedule to annihilate the
            learning rate according to 'final_div_factor' instead of modifying the second
            phase (the first two phases will be symmetrical about the step indicated by
            'pct_start').
        last_epoch (int): The index of the last batch. This parameter is used when
            resuming a training job. Since `step()` should be invoked after each
            batch instead of after each epoch, this number represents the total
            number of *batches* computed, not the total number of epochs computed.
            When last_epoch=-1, the schedule is started from the beginning.
            Default: -1
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        TypeError: if ``optimizer`` is not an ``Optimizer``.
        ValueError: if neither ``total_steps`` nor both ``epochs`` and
            ``steps_per_epoch`` are given, if any of them is not a positive int,
            if ``pct_start`` is not a float in [0, 1], if ``anneal_strategy`` is
            unknown, if a per-group list has the wrong length, or if
            ``cycle_momentum`` is set for an optimizer without momentum/betas.

    Example:
        >>> data_loader = torch.utils.data.DataLoader(...)
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
        >>> for epoch in range(10):
        >>>     for batch in data_loader:
        >>>         train_batch(...)
        >>>         scheduler.step()


    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
        https://arxiv.org/abs/1708.07120
    """
    def __init__(self,
                 optimizer,
                 max_lr,
                 total_steps=None,
                 epochs=None,
                 steps_per_epoch=None,
                 pct_start=0.3,
                 anneal_strategy='cos',
                 cycle_momentum=True,
                 base_momentum=0.85,
                 max_momentum=0.95,
                 div_factor=25.,
                 final_div_factor=1e4,
                 three_phase=False,
                 last_epoch=-1,
                 verbose=False):

        # Validate optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # Validate total_steps. An explicit total_steps wins; otherwise BOTH epochs
        # and steps_per_epoch are required. Only rejecting the all-None case would
        # let e.g. `epochs=10` alone fall through to a `None <= 0` TypeError.
        if total_steps is not None:
            if total_steps <= 0 or not isinstance(total_steps, int):
                raise ValueError("Expected positive integer total_steps, but got {}".format(total_steps))
            self.total_steps = total_steps
        elif epochs is not None and steps_per_epoch is not None:
            if epochs <= 0 or not isinstance(epochs, int):
                raise ValueError("Expected positive integer epochs, but got {}".format(epochs))
            if steps_per_epoch <= 0 or not isinstance(steps_per_epoch, int):
                raise ValueError("Expected positive integer steps_per_epoch, but got {}".format(steps_per_epoch))
            self.total_steps = epochs * steps_per_epoch
        else:
            raise ValueError("You must define either total_steps OR (epochs AND steps_per_epoch)")

        # Validate pct_start before it is used to place the phase boundaries below.
        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
            raise ValueError("Expected float between 0 and 1 pct_start, but got {}".format(pct_start))

        # Each phase ends (inclusively) at 'end_step'; the *_lr / *_momentum entries
        # name the param_group keys to interpolate between inside that phase.
        if three_phase:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'max_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'base_momentum',
                },
                {
                    'end_step': float(2 * pct_start * self.total_steps) - 2,
                    'start_lr': 'max_lr',
                    'end_lr': 'initial_lr',
                    'start_momentum': 'base_momentum',
                    'end_momentum': 'max_momentum',
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'min_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'max_momentum',
                },
            ]
        else:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    'start_lr': 'initial_lr',
                    'end_lr': 'max_lr',
                    'start_momentum': 'max_momentum',
                    'end_momentum': 'base_momentum',
                },
                {
                    'end_step': self.total_steps - 1,
                    'start_lr': 'max_lr',
                    'end_lr': 'min_lr',
                    'start_momentum': 'base_momentum',
                    'end_momentum': 'max_momentum',
                },
            ]

        # Validate anneal_strategy
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError("anneal_strategy must by one of 'cos' or 'linear', instead got {}".format(anneal_strategy))
        elif anneal_strategy == 'cos':
            self.anneal_func = self._annealing_cos
        elif anneal_strategy == 'linear':
            self.anneal_func = self._annealing_linear

        # Initialize learning rate variables (only on a fresh start; when resuming,
        # the param groups are expected to already carry these keys).
        max_lrs = self._format_param('max_lr', self.optimizer, max_lr)
        if last_epoch == -1:
            for idx, group in enumerate(self.optimizer.param_groups):
                group['initial_lr'] = max_lrs[idx] / div_factor
                group['max_lr'] = max_lrs[idx]
                group['min_lr'] = group['initial_lr'] / final_div_factor

        # Initialize momentum variables
        self.cycle_momentum = cycle_momentum
        if self.cycle_momentum:
            if 'momentum' not in self.optimizer.defaults and 'betas' not in self.optimizer.defaults:
                raise ValueError('optimizer must support momentum with `cycle_momentum` option enabled')
            # Adam-style optimizers expose momentum as betas[0].
            self.use_beta1 = 'betas' in self.optimizer.defaults
            max_momentums = self._format_param('max_momentum', optimizer, max_momentum)
            base_momentums = self._format_param('base_momentum', optimizer, base_momentum)
            if last_epoch == -1:
                for m_momentum, b_momentum, group in zip(max_momentums, base_momentums, optimizer.param_groups):
                    if self.use_beta1:
                        _, beta2 = group['betas']
                        group['betas'] = (m_momentum, beta2)
                    else:
                        group['momentum'] = m_momentum
                    group['max_momentum'] = m_momentum
                    group['base_momentum'] = b_momentum

        super(OneCycleLR, self).__init__(optimizer, last_epoch, verbose)

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group."""
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError("expected {} values for {}, got {}".format(
                    len(optimizer.param_groups), name, len(param)))
            return param
        else:
            return [param] * len(optimizer.param_groups)

    def _annealing_cos(self, start, end, pct):
        "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _annealing_linear(self, start, end, pct):
        "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
        return (end - start) * pct + start

    def get_lr(self):
        """Return the lr of each param group for batch ``self.last_epoch``.

        Side effect: when ``cycle_momentum`` is enabled, also writes the annealed
        momentum (or ``betas[0]``) into each param group.
        """
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.", UserWarning)

        lrs = []
        step_num = self.last_epoch

        if step_num > self.total_steps:
            raise ValueError("Tried to step {} times. The specified number of total steps is {}"
                             .format(step_num + 1, self.total_steps))

        for group in self.optimizer.param_groups:
            start_step = 0
            for i, phase in enumerate(self._schedule_phases):
                end_step = phase['end_step']
                # Use the first phase that has not ended yet; the last phase also
                # absorbs any step beyond its nominal end.
                if step_num <= end_step or i == len(self._schedule_phases) - 1:
                    pct = (step_num - start_step) / (end_step - start_step)
                    computed_lr = self.anneal_func(group[phase['start_lr']], group[phase['end_lr']], pct)
                    if self.cycle_momentum:
                        computed_momentum = self.anneal_func(group[phase['start_momentum']], group[phase['end_momentum']], pct)
                    break
                # The next phase starts where this one ended.
                start_step = phase['end_step']

            lrs.append(computed_lr)
            if self.cycle_momentum:
                if self.use_beta1:
                    _, beta2 = group['betas']
                    group['betas'] = (computed_momentum, beta2)
                else:
                    group['momentum'] = computed_momentum

        return lrs

## WarmupLRScheduler

In [None]:
import math


def linear_warmup_decay(warmup_steps, total_steps, cosine=True, linear=False):
    """Build a ``LambdaLR`` multiplier: linear warmup, then optional decay to 0.

    The returned ``fn(step)`` rises linearly from 0 to 1 over the first
    ``warmup_steps`` steps. After that it either stays at 1.0 (no decay) or
    decays to 0 at ``total_steps`` with a cosine or linear shape, and stays
    at 0 for any later step.

    Args:
        warmup_steps: Number of steps of linear warmup.
        total_steps: Step at which the decayed multiplier reaches 0.
        cosine: Use cosine decay after warmup.
        linear: Use linear decay after warmup.

    Returns:
        A function mapping a step index to a learning-rate multiplier.

    Raises:
        ValueError: if both ``cosine`` and ``linear`` are requested.
    """
    if linear and cosine:
        raise ValueError("`cosine` and `linear` decay are mutually exclusive; enable at most one.")

    def fn(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))

        if not (cosine or linear):
            # no decay
            return 1.0

        progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        # Clamp past total_steps. Otherwise cosine decay climbs back up and linear
        # decay goes negative, producing a negative learning rate.
        progress = min(progress, 1.0)
        if cosine:
            # cosine decay
            return 0.5 * (1.0 + math.cos(math.pi * progress))

        # linear decay
        return 1.0 - progress

    return fn



class DelayedCosineAnnealingWarmRestarts(torch.optim.lr_scheduler.CosineAnnealingWarmRestarts):
    """``CosineAnnealingWarmRestarts`` that stays idle until ``start_epoch``.

    Before ``start_epoch`` this scheduler leaves the optimizer's learning rates
    untouched (in this notebook it is paired with a warmup ``LambdaLR``).
    From ``start_epoch`` on, it runs the usual cosine-with-warm-restarts
    schedule with its epoch counter shifted by ``start_epoch``.

    ``step()`` may be called with no argument (as PyTorch Lightning does). The
    scheduler then counts epochs itself; an explicit ``epoch`` is still accepted.
    """

    def __init__(self,
                 optimizer,
                 T_0,
                 T_mult=2,
                 eta_min=0,
                 start_epoch=0):
        self.start_epoch = start_epoch
        # Absolute epoch index of the latest step(). It starts at -1 so that the
        # step() issued by the base-class constructor maps to epoch 0.
        self._current_epoch = -1
        super().__init__(optimizer=optimizer,
                         T_0=T_0,
                         T_mult=T_mult,
                         eta_min=eta_min)

    def step(self, epoch=None):
        """Advance the schedule.

        Args:
            epoch: Absolute epoch index. When ``None`` (the default), the
                previous epoch index plus one is used.
        """
        if epoch is None:
            epoch = self._current_epoch + 1
        self._current_epoch = epoch

        if epoch >= self.start_epoch:
            super().step(epoch=epoch - self.start_epoch)
        else:
            # Still inside the delay window: keep the current lrs, but record them
            # so get_last_lr() works before the cosine schedule kicks in.
            self._last_lr = [group['lr'] for group in self.optimizer.param_groups]


# lr_sched = linear_warmup_decay(warmup_steps, total_steps, cosine=True, linear=False)
# lr_sched = linear_warmup_decay(warmup_steps, total_steps, cosine=False, linear=True)
# plt.plot(range(40), list(map(lr_sched, range(40))))





class TestModel(BoringModel):
    """Toy Lightning module for visualising per-param-group learning-rate schedules.

    The network is split into a ``backbone`` and a ``head`` so the optimizer can
    give them different learning rates. ``self.lr_hist`` collects the lr of each
    group every time ``training_step_end`` runs.
    """

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(32, 32)
        self.head = torch.nn.Linear(32, 2)
        # Replaces the parent's ``layer``; presumably BoringModel's forward calls
        # ``self.layer`` (TODO confirm), so data flows backbone -> head.
        self.layer = torch.nn.Sequential(OrderedDict({"backbone": self.backbone,
                                                      "head": self.head}))
        # Learning-rate history per param group, one entry per training_step_end call.
        self.lr_hist = {"backbone": [],
                        "head": []}

    def training_step_end(self, training_step_outputs):
        """Record each param group's current lr, then return the outputs unchanged."""
        # Param-group order matches configure_optimizers: [0] backbone, [1] head.
        param_groups = self.trainer.optimizers[0].param_groups
        self.lr_hist["backbone"].append(param_groups[0]["lr"])
        self.lr_hist["head"].append(param_groups[1]["lr"])

        return training_step_outputs

    def configure_optimizers(self):
        """SGD with two param groups plus an epoch-level warmup/cosine LambdaLR.

        The backbone trains at 0.1x the head's lr (0.01 vs 0.1), and both groups
        use weight_decay=0.01. The multiplier warms up linearly for 5 epochs, then
        decays with a cosine to 0 at epoch 20.
        """
        optim_lr = 0.1
        optimizer = torch.optim.SGD([{"params": self.backbone.parameters(), "lr": optim_lr * 0.1, "weight_decay": 0.01},
                                     {"params": self.head.parameters(), "lr": optim_lr, "weight_decay": 0.01}])
        warmup_steps = 5
        total_steps = 20
        lr_lambda = linear_warmup_decay(warmup_steps, total_steps, cosine=True, linear=False)
        warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

        return [optimizer], [{"scheduler": warmup_scheduler, "interval": "epoch"}]

        # Alternative (disabled): warmup followed by delayed cosine warm restarts.
        # min_lr = optim_lr*0.1*0.1
        # decay_scheduler = DelayedCosineAnnealingWarmRestarts(optimizer,
        #                                                      T_0=1,
        #                                                      T_mult=2,
        #                                                      eta_min=min_lr,
        #                                                      start_epoch=warmup_steps)
        # return [optimizer], [{"scheduler":warmup_scheduler, "interval": "epoch"},
        #                      {"scheduler":decay_scheduler, "interval": "epoch"}]

    

In [None]:
from omegaconf import OmegaConf

In [None]:
# Inspection cell: echo the OmegaConf.load function object in the notebook output.
OmegaConf.load

In [None]:

def configure_schedulers(optimizer,
                         config):
    """Build the Lightning scheduler configs described by ``config``.

    Supported ``config.scheduler_type`` values:

    * ``"linear_warmup_cosine_decay"``: a single ``LambdaLR`` driven by
      ``linear_warmup_decay(warmup_steps, total_steps, cosine=True)``.
    * ``"linear_warmup_cosine_decay_w_warm_restarts"``: the same warmup
      scheduler followed by a ``DelayedCosineAnnealingWarmRestarts`` that
      starts at ``start_epoch`` (default: ``warmup_steps``).

    Optional keys read via ``config.get`` (defaults in parentheses):
    ``warmup_steps`` (5), ``total_steps`` (20), ``T_0`` (1), ``T_mult`` (2),
    ``eta_min`` (0), ``start_epoch`` (``warmup_steps``).

    Args:
        optimizer: Optimizer the schedulers wrap.
        config: Object with a ``scheduler_type`` attribute and a ``get`` method
            (presumably an OmegaConf ``DictConfig``; confirm against callers).

    Returns:
        A list of ``{"scheduler": ..., "interval": "epoch"}`` dicts in stepping
        order. Both scheduler types return this same shape, which fits the second
        element of a ``configure_optimizers`` return value ``([optimizer], schedulers)``.

    Raises:
        ConfigurationError: if ``config.scheduler_type`` is not supported.
    """
    scheduler_type = config.scheduler_type
    if scheduler_type not in ("linear_warmup_cosine_decay",
                              "linear_warmup_cosine_decay_w_warm_restarts"):
        # NOTE(review): ConfigurationError is not defined in this part of the
        # notebook; confirm it is imported before this path is exercised.
        raise ConfigurationError(f"Misconfigured Scheduler config:{config}")

    warmup_steps = config.get("warmup_steps", 5)
    total_steps = config.get("total_steps", 20)
    lr_lambda = linear_warmup_decay(warmup_steps, total_steps, cosine=True, linear=False)
    warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    schedulers = [{"scheduler": warmup_scheduler, "interval": "epoch"}]

    if scheduler_type == "linear_warmup_cosine_decay_w_warm_restarts":
        decay_scheduler = DelayedCosineAnnealingWarmRestarts(optimizer,
                                                             T_0=config.get("T_0", 1),
                                                             T_mult=config.get("T_mult", 2),
                                                             eta_min=config.get("eta_min", 0),
                                                             start_epoch=config.get("start_epoch", warmup_steps))
        schedulers.append({"scheduler": decay_scheduler, "interval": "epoch"})

    return schedulers


In [None]:
# Fit TestModel with a single training batch per epoch: the run stays quick and
# the recorded lr history gets one entry per epoch.
num_epochs = 40

model = TestModel()

trainer = Trainer(max_epochs=num_epochs, limit_train_batches=1)
trainer.fit(model)

In [None]:
import seaborn as sns
# sns.set_style("talk")
sns.set_theme(context="talk")

# One x value per recorded step, computed here so this cell does not depend on a
# `t` that is only defined in a later cell.
x_steps = list(range(len(model.lr_hist["backbone"])))

fig, ax = plt.subplots(1, 1, figsize=(16, 12))
ax.plot(x_steps, model.lr_hist["backbone"], label="backbone", alpha=0.5)
ax.plot(x_steps, model.lr_hist["head"], label="head", alpha=0.5)
ax.set_xlabel("epochs")
ax.set_ylabel("lr")
# Without a legend the two curves' labels are never shown.
ax.legend()
plt.suptitle("LR Warmup -> CosineAnnealing w/ WarmRestarts")

plt.savefig("LR Warmup-CosineAnnealing w WarmRestarts.png")

In [None]:
# lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20, eta_min=0.1*optim_lr*0.1)

# lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
#                                                                       T_0=5,
#                                                                       T_mult=2,
#                                                                       eta_min=0.1*optim_lr*0.1)

# Linear warmup for 5 epochs, then cosine decay of the multiplier to 0 at epoch 20
# (the same settings as TestModel.configure_optimizers), applied to every param
# group through LambdaLR.
warmup_lambda = linear_warmup_decay(warmup_steps=5, total_steps=20, cosine=True, linear=False)
lr_scheduler_1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lambda)

# lr_scheduler_1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])

# lr_scheduler_2 = torch.optim.lr_scheduler.CyclicLR(
#     optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=1, step_size_down=3
# )


t = list(range(num_epochs))
lr_hist = {"backbone": [],
           "head": []}

for i in t:
    lr_scheduler_1.step()
    # The module-level optimizer was built with [0] backbone, [1] head.
    lr_hist["backbone"].append(optimizer.param_groups[0]["lr"])
    lr_hist["head"].append(optimizer.param_groups[1]["lr"])


fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.plot(t, lr_hist["backbone"], label="backbone", alpha=0.5)
ax.plot(t, lr_hist["head"], label="head", alpha=0.5)
ax.set_xlabel("epochs")
ax.set_ylabel("lr")
ax.legend()
# plt.xaxis.

In [None]:
for group in optimizer.param_groups:
    # Record the starting lr under 'initial_lr' unless something already set it.
    group.setdefault('initial_lr', group['lr'])
    print({name: value for name, value in group.items() if name != "params"})
base_lrs = [pg['initial_lr'] for pg in optimizer.param_groups]

print(base_lrs)

In [None]:
# Print every hyper-parameter of each param group (the parameter tensors are skipped).
# Only real numbers get the 3-decimal format: param groups can also hold entries
# such as None or tuples, which a `:.3f` format spec cannot render.
for g in optimizer.param_groups:
    for k, v in g.items():
        if k == "params":
            continue
        if isinstance(v, (int, float)) and not isinstance(v, bool):
            print(k, f"{v:.3f}")
        else:
            print(k, v)

In [None]:
# Inspection cell: echo the bound `setdefault` method of `g` (the last param group left over from the previous loop).
g.setdefault

In [None]:
# Inspection cell: show which keys the second param group (index 1) carries.
optimizer.param_groups[1].keys()

In [None]:
from rich import print as pp
import pandas as pd
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import torchvision
import matplotlib.pyplot as plt
import time
import os
from pathlib import Path

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

import logging

# Obtain the logger through logging.getLogger so it is registered in the logging
# hierarchy (and propagates to ancestor handlers) instead of being a detached
# Logger instance that nothing else can look up.
logger = logging.getLogger(__name__)
logger.setLevel('INFO')

from tqdm.auto import tqdm, trange


%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import pytorch_lightning as pl
import timm
import glob
import hydra
from collections import OrderedDict
from typing import *

from lightning_hydra_classifiers.models.transfer import *
from rich import print as pp
from lightning_hydra_classifiers.utils.model_utils import count_parameters, collect_results
from lightning_hydra_classifiers.utils.metric_utils import get_per_class_metrics, get_scalar_metrics
from lightning_hydra_classifiers.models.backbones.backbone import build_model
import pytorch_lightning as pl

pl.seed_everything(42)

from lightning_hydra_classifiers.scripts.multitask.train import MultiTaskDataModule, LitMultiTaskModule, ImagePredictionLogger, train_task,  CIFAR10DataModule, run_multitask_test, load_data_and_model, load_data, resolve_config, configure_callbacks, configure_loggers, configure_trainer
from lightning_hydra_classifiers.data.datasets.common import toPIL
from lightning_hydra_classifiers.utils.etl_utils import ETL
from omegaconf import OmegaConf

from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning_hydra_classifiers.scripts.pretrain import lr_tuner

from lightning_hydra_classifiers.scripts.multitask.train import configure_callbacks, configure_loggers#, configure_trainer

## Developing an early-stopping, multi-stage subclass for fine-tuning

### Trying out the built-in Lightning tests for fine-tuning

In [None]:
# os.chdir(Path(pl.__file__).parent)
# os.getcwd()

In [None]:
# # %%writefile first_test.py

# # See the License for the specific language governing permissions and
# # limitations under the License.
# from collections import OrderedDict
# import matplotlib.pyplot as plt

# import pytest
# import torch
# from torch import nn
# from torch.optim import Optimizer, SGD
# from torch.utils.data import DataLoader

# from pytorch_lightning import LightningModule, seed_everything, Trainer
# from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint
# from pytorch_lightning.callbacks.base import Callback
# from tests.helpers import BoringModel, RandomDataset

# from typing import *

# # class RandomDataset(torch.utils.data.Dataset):
# #     def __init__(self, num_samples=2000, shape=(3,64,64)):
# #         self.num_samples = num_samples
# #         self.shape = shape
# #         self.data = torch.randn(num_samples, *shape)

# #     def __getitem__(self, index):
# #         return self.data[index]

# #     def __len__(self):
# #         return self.num_samples

# # class RandomTupleSupervisedDataset(RandomDataset):
    
# #     def __init__(self, num_classes=1000, num_samples=2000, shape=(3,64,64)):
# #         super().__init__(num_samples, shape)
# #         self.num_classes = num_classes
        
# #         self.targets = torch.randperm(num_classes)[:num_samples]
        
# #     def __getitem__(self, index):
# #         return self.data[index], self.targets[index]
        

# # dataset = RandomTupleSupervisedDataset(1000, 200, (3,128,128))
# # dataset
# # dataset.data.shape

# # class TestBackboneFinetuningCallback(BackboneFinetuning):
# #     def on_train_epoch_start(self, trainer, pl_module):
# #         super().on_train_epoch_start(trainer, pl_module)
# #         epoch = trainer.current_epoch
# #         if self.unfreeze_backbone_at_epoch <= epoch:
# #             optimizer = trainer.optimizers[0]
# #             current_lr = optimizer.param_groups[0]["lr"]
# #             backbone_lr = self.previous_backbone_lr
# #             if epoch < 6:
# #                 assert backbone_lr <= current_lr
# #             else:
# #                 assert backbone_lr == current_lr

# ###############################################################

# # class BackboneFinetuningCallback(pl.callbacks.Callback):

# #         def __init__(self,
    

# print("current dir:", os.getcwd())

# import os
# os.path.abspath

In [None]:
# %%writefile finetuning_callback_test.py

# import numpy as np
# import os
# import pytest
# from typing import *
# import pytorch_lightning as pl
# import torch
# from lightning_hydra_classifiers.models.transfer import *
# from torch.utils.data import DataLoader

# from torch import nn
# # from pytorch_lightning import LightningModule, seed_everything, Trainer
# import logging
# import json
# logging.basicConfig(level=logging.DEBUG)
# logger = logging.Logger(__name__)
# logger.setLevel('INFO')
# pylog = logging.getLogger()


# BN_TYPE = (torch.nn.modules.batchnorm._BatchNorm,)

# def is_bn(layer: nn.Module) -> bool:
#     """ Return True if layer's type is one of the batch norms."""
#     return isinstance(layer, BN_TYPE)

# def grad_check(tensor: torch.Tensor) -> bool:
#     """ Returns True if tensor.requires_grad==True, else False."""
#     return tensor.requires_grad == True


# # os.chdir("/media/data/jacob/GitHub/lightning-hydra-classifiers")#/tests")

# class RandomDataset(torch.utils.data.Dataset):
#     def __init__(self, num_samples=2000, shape=(3,64,64)):
#         self.num_samples = num_samples
#         self.shape = shape
#         self.data = torch.randn(num_samples, *shape)

#     def __getitem__(self, index):
#         return self.data[index]

#     def __len__(self):
#         return self.num_samples

# class RandomTupleSupervisedDataset(RandomDataset):
    
#     def __init__(self, num_classes=1000, num_samples=2000, shape=(3,64,64)):
#         super().__init__(num_samples, shape)
#         self.num_classes = num_classes
        
#         self.targets = torch.randperm(num_classes)[:num_samples]
        
#     def __getitem__(self, index):
#         return self.data[index], self.targets[index]

##############################################

    
    
# class FinetuningLightningCallback(pl.callbacks.Callback):
    
# # class FinetuningLightningPlugin:
#     mode_dict = {"min": torch.lt, "max": torch.gt}
#     order_dict = {"min": "<", "max": ">"}
    
    
#     def __init__(self,
#                  monitor: str="val_loss",
#                  mode: str="min",
#                  patience: int=4):
        
# #         if pl_module.hparams.finetuning_strategy == "finetuning_unfreeze_layers_on_plateau":
#         self.monitor = monitor
#         self.mode = mode
#         self.patience = patience
# #         self.best_metric = 0
#         self.milestone_index = 0
        
# #         self.min_delta *= 1 if self.monitor_op == torch.gt else -1
#         torch_inf = torch.tensor(np.Inf)
#         self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
    
#         self.milestone_logs = []
        
#     def on_fit_start(self,
#                      trainer,
#                      pl_module):
#         self.milestones = pl_module.finetuning_milestones
#         print(f"Setting milestones: {pl_module.finetuning_milestones}")
#         self._finished = False
    
    
#     def finetuning_pretrained_strategy(self,
#                                        trainer: "pl.Trainer",
#                                        pl_module):
#         """
        
        
#         """
#         epoch = trainer.current_epoch
#         logs = trainer.callback_metrics
#         current = logs.get(self.monitor)
        
#         if self.mode == "min":
#             new_best = current < self.best_score
#         elif self.mode == "max":
#             new_best = current > self.best_score
        
#         if self._finished:
#             return
        
#         if new_best:
#             self.best_score = current
#             self.wait_epochs = 0
#             print(f"New best score: {self.monitor}={self.best_score}.")
#         elif self.wait_epochs >= self.patience:
            
#             next_to_unfreeze = self.milestones[self.milestone_index]
#             print(f"Patience of {self.patience} surpassed at epoch: {epoch} unfreezing down to: {next_to_unfreeze}")
            
#             pl_module.unfreeze_backbone_top_layers(unfreeze_down_to=next_to_unfreeze)
#             self.wait_epochs = 0
#             self.milestone_index += 1
#             self.milestone_logs.append({"epoch":epoch,
#                                         "unfreeze_at_layer":next_to_unfreeze,
#                                         "trainable_params":pl_module.get_trainable_parameters(count_params=True),
#                                         "nontrainable_params":pl_module.get_nontrainable_parameters(count_params=True)})
#             if self.milestone_index >= len(self.milestones):
#                 self._finished = True
#         else:
#             self.wait_epochs += 1
    
#     @property
#     def monitor_op(self) -> Callable:
#         return self.mode_dict[self.mode]
    
#     def on_epoch_end(self, trainer, pl_module):
#         """Called when the epoch ends."""

#         self.finetuning_pretrained_strategy(trainer=trainer, pl_module=pl_module)
#         try:
#             pl_module.log("nontrainable_params", pl_module.get_nontrainable_parameters(count_params=True))
#             pl_module.log("trainable_params", pl_module.get_trainable_parameters(count_params=True))
# #             pl_module.logger.summary["milestones"] = self.milestone_logs[-1]
#         except Exception as e:
#             print(e)
#             print(f"logging to wandb didnt work bro")



########################
########################


from lightning_hydra_classifiers.callbacks.finetuning_callbacks import FinetuningLightningCallback

class TestLightningClassifier(LightningClassifier):
    """Minimal LightningClassifier used to exercise finetuning callbacks.

    ``training_step`` / ``validation_step`` ignore the batch: they log a constant
    ``train_loss`` / ``val_loss`` of 1 and return a fresh ``torch.ones(1)`` tensor
    that requires grad, so a fit runs the callback machinery cheaply. Both
    dataloaders serve 50 random ``(3, 64, 64)`` samples from
    ``RandomTupleSupervisedDataset`` in batches of 2.
    """

    def __init__(self,
                 backbone_name='resnet50',
                 pretrained: Union[bool, str]=True,
                 num_classes: int=1000,
                 finetuning_strategy: str="feature_extractor",
                 seed: Optional[int]=None,
                 **kwargs):
        # Pooling, head and optimizer settings are fixed for the test; only the
        # backbone, class count, finetuning strategy and seed are configurable.
        super().__init__(backbone_name=backbone_name,
                         pretrained=pretrained,
                         num_classes=num_classes,
                         pool_type="avgdrop",
                         head_type="linear",
                         hidden_size=None, lr=0.01, backbone_lr_mult=0.1,
                         weight_decay=0.01,
                         finetuning_strategy=finetuning_strategy,
                         # Honour an explicit seed; fall back to the previous fixed value 42.
                         seed=42 if seed is None else seed,
                         **kwargs)
        self._verbose = True

    def training_step(self, batch, batch_idx):
        self.log("train_loss", 1)
        return {"loss": torch.ones(1, requires_grad=True)}

    def validation_step(self, batch, batch_idx):
        self.log("val_loss", 1)
        return {"loss": torch.ones(1, requires_grad=True)}

    def training_step_end(self, outputs):
        # Pass the parent's result through: Lightning uses this hook's return value
        # as the step output, so returning None would discard the loss.
        return super().training_step_end(outputs)

    def print(self, *args, **kwargs):
        """Forward to the builtin ``print`` only while ``self._verbose`` is True."""
        if self._verbose:
            print(*args, **kwargs)

    def train_dataloader(self):
        return DataLoader(RandomTupleSupervisedDataset(num_classes=1000, num_samples=50, shape=(3,64,64)), batch_size=2)

    def val_dataloader(self):
        return DataLoader(RandomTupleSupervisedDataset(num_classes=1000, num_samples=50, shape=(3,64,64)), batch_size=2)


def save_log(log, fp):
    """Write ``log`` to the file at ``fp`` as pretty-printed JSON.

    Keys keep their insertion order (``sort_keys=False``) and nesting is
    indented by 4 spaces. An existing file at ``fp`` is overwritten.

    Args:
        log: JSON-serializable object (e.g. a list of milestone dicts).
        fp: Destination file path.
    """
    # Local import: the file's only other ``import json`` is commented out.
    import json

    # Use a separate name for the handle instead of shadowing the path argument.
    with open(fp, "w", encoding="utf-8") as f:
        json.dump(log, f, indent=4, sort_keys=False)

        
        
        
###############################
###############################


# @pytest.mark.parametrize("finetuning_strategy",
#                         [("feature_extractor",)
#                          "feature_extractor_+_bn.eval()",
#                          "feature_extractor_+_except_bn"])

# @pytest.mark.parametrize("finetuning_strategy, expected_layer_counts",
#     [
#         ("feature_extractor",
#             {"is_training":{'True': 53, 'False': 0, 'Total': 53}, 
#              "requires_grad":{'True': 0, 'False': 53, 'Total': 53}}
#         ),
#         ("feature_extractor_+_bn.eval()",
#             {"is_training":{'True': 0, 'False': 53, 'Total': 53}, 
#              "requires_grad":{'True': 0, 'False': 53, 'Total': 53}}
#         ),
#         ("feature_extractor_+_except_bn",
#             {"is_training":{'True': 53, 'False': 0, 'Total': 53}, 
#              "requires_grad":{'True': 53, 'False': 0, 'Total': 53}}
#         )
#     ]
#                         )
# @pytest.mark.parametrize()
def test_finetuning_callback(tmpdir):
    """Smoke-test FinetuningLightningCallback with the unfreeze-on-plateau strategy.

    Fits ``TestLightningClassifier`` for up to 25 epochs (2 train / 2 val batches
    per epoch) with the callback attached (monitor="val_loss", mode="min",
    patience=4). It then prints the callback's ``milestone_logs`` and logs the
    model's trainable / non-trainable layer and parameter counts. No assertions
    are made yet.

    Args:
        tmpdir: pytest temporary-directory fixture, used as the Trainer's
            ``default_root_dir`` so logs and checkpoints do not land in a
            hard-coded user directory.
    """
    import logging

    from rich import print as pp

    # Root logger; the module-level ``pylog`` only exists in commented-out code.
    pylog = logging.getLogger()

    pl.seed_everything(42)

    finetuning_callback = FinetuningLightningCallback(monitor="val_loss",
                                                      mode="min",
                                                      patience=4)
    model = TestLightningClassifier(finetuning_strategy="finetuning_unfreeze_layers_on_plateau")

    trainer = pl.Trainer(limit_train_batches=2,
                         limit_val_batches=2,
                         default_root_dir=str(tmpdir),
                         log_every_n_steps=1,
                         callbacks=[finetuning_callback],
                         max_epochs=25)
    trainer.fit(model)

    # Re-enable the model's verbose printing for the post-training summaries.
    model._verbose = True

    print("milestone_logs:")
    pp(finetuning_callback.milestone_logs)

    pylog.info(f"count trainable batchnorm layers: {model.count_trainable_batchnorm_layers()}")
    pylog.info(f"count trainable layers: {model.get_trainable_parameters(count_layers=True)}")
    pylog.info(f"count nontrainable layers: {model.get_nontrainable_parameters(count_layers=True)}")
    pylog.info(f"count trainable params: {model.get_trainable_parameters(count_params=True)}")
    pylog.info(f"count nontrainable params: {model.get_nontrainable_parameters(count_params=True)}")

    # TODO: assert per-strategy expected is_training / requires_grad layer counts
    # (see the commented-out parametrize table above) once they are pinned down.

In [None]:
# model = TestLightningClassifier(finetuning_strategy="feature_extractor_+_except_bn")

In [None]:
# class TestLightningClassifier(LightningClassifier):

#     def __init__(self,
#                  backbone_name='resnet50',
#                  pretrained: Union[bool, str]=True,
#                  num_classes: int=1000,
#                  finetuning_strategy: str="feature_extractor",
#                  seed: int=None,
#                  **kwargs):

#         super().__init__(backbone_name=backbone_name,
#                          pretrained=pretrained,
#                          num_classes=num_classes,
#                          pool_type="avgdrop",
#                          head_type="linear",
#                          hidden_size=None, lr=0.01, backbone_lr_mult=0.1,
#                          weight_decay=0.01,
#                          finetuning_strategy=finetuning_strategy,
#                          seed=42,
#                         **kwargs)
#         self._verbose=True

#     def training_step(self, batch, batch_idx):
#         output = super().training_step(batch, batch_idx)
# #             self.print(f"During: self.training_step")
# #             self.count_trainable_batchnorm_layers()
#         self._verbose=False
#         return output

#     def training_step_end(self, outputs):
#         output = super().training_step_end(outputs)

#     def print(self, *args):
#         if self._verbose:
#             print(*args)

# #         def configure_optimizers(self):
# #             optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
# #             lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)
# #             return [optimizer], [lr_scheduler]

#     def train_dataloader(self):
#         return DataLoader(RandomTupleSupervisedDataset(num_classes=1000, num_samples=50, shape=(3,64,64)), batch_size=2)

#     def val_dataloader(self):
#         return DataLoader(RandomTupleSupervisedDataset(num_classes=1000, num_samples=50, shape=(3,64,64)), batch_size=2)




In [None]:

# # finetuning_strategy = "feature_extractor"
# fixtures = ["feature_extractor",
#             "feature_extractor_+_bn.eval()",
#             "feature_extractor_+_except_bn"]

# for finetuning_strategy in fixtures:
#     print(f"strategy: {finetuning_strategy}")

#     model = TestLightningClassifier(finetuning_strategy=finetuning_strategy)
#     training_batch_stats, params_require_grads = model.count_trainable_batchnorm_layers()

In [None]:
# from rich.console import Console
# from rich.markdown import Markdown
# console = Console()
# markdown = Markdown("$$\delta \pi = 3.14159265358979323$$")
# console.print(markdown)
# from IPython.display import display, Math, Latex
# display(Math(r"$$mean=E[x^k]," + "\n" +"Var = Var[x^k]"))
# from IPython.display import display, Math, Latex
# display(Math(r"$$\delta \pi = 3.14159265358979323$$"))

In [None]:
# print("count trainable layers: ", model.count_trainable_layers())

### Next

In [None]:

# class ThresholdBasedFinetuning(pl.callbacks.EarlyStopping):
    
#     mode_dict = {"min": torch.lt, "max": torch.gt}

#     order_dict = {"min": "<", "max": ">"}
    
#     def __init__(self,
#                  monitor: Optional[str] = None,
#                  min_delta: float = 0.0,
#                  patience: int = 3,
#                  verbose: bool = False,
#                  mode: str = "min",
#                  strict: bool = True,
#                  check_finite: bool = True,
#                  stopping_threshold: Optional[float] = None,
#                  divergence_threshold: Optional[float] = None,
#                  check_on_train_epoch_end: Optional[bool] = None):
#         super().__init__()
        
        
        
#     def on_save_checkpoint(
#         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
#     ) -> Dict[str, Any]:
#         return {
#             "wait_count": self.wait_count,
#             "stopped_epoch": self.stopped_epoch,
#             "best_score": self.best_score,
#             "patience": self.patience,
#         }

#     def on_load_checkpoint(
#         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
#     ) -> None:
#         self.wait_count = callback_state["wait_count"]
#         self.stopped_epoch = callback_state["stopped_epoch"]
#         self.best_score = callback_state["best_score"]
#         self.patience = callback_state["patience"]

        
#     def _should_skip_check(self, trainer: "pl.Trainer") -> bool:
#         from pytorch_lightning.trainer.states import TrainerFn

#         return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking

#     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
#         if self._should_skip_check(trainer):
#             return
#         self._run_early_stopping_check(trainer)

        
#     def _run_early_stopping_check(self, trainer: "pl.Trainer") -> None:
#         """Checks whether the early stopping condition is met and if so tells the trainer to stop the training."""
#         logs = trainer.callback_metrics

#         if trainer.fast_dev_run or not self._validate_condition_metric(  # disable early_stopping with fast_dev_run
#             logs
#         ):  # short circuit if metric not present
#             return

#         current = logs.get(self.monitor)
#         should_stop, reason = self._evaluate_stopping_criteria(current)

#         # stop every ddp process if any world process decides to stop
#         should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
#         trainer.should_stop = trainer.should_stop or should_stop
#         if should_stop:
#             self.stopped_epoch = trainer.current_epoch
#         if reason and self.verbose:
#             self._log_info(trainer, reason)

#     def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, Optional[str]]:
#         should_stop = False
#         reason = None
#         if self.check_finite and not torch.isfinite(current):
#             should_stop = True
#             reason = (
#                 f"Monitored metric {self.monitor} = {current} is not finite."
#                 f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
#             )
#         elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
#             should_stop = True
#             reason = (
#                 "Stopping threshold reached:"
#                 f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
#                 " Signaling Trainer to stop."
#             )
#         elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
#             should_stop = True
#             reason = (
#                 "Divergence threshold reached:"
#                 f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
#                 " Signaling Trainer to stop."
#             )
#         elif self.monitor_op(current - self.min_delta, self.best_score.to(current.device)):
#             should_stop = False
#             reason = self._improvement_message(current)
#             self.best_score = current
#             self.wait_count = 0
#         else:
#             self.wait_count += 1
#             if self.wait_count >= self.patience:
#                 should_stop = True
#                 reason = (
#                     f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} records."
#                     f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
#                 )

#         return should_stop, reason

In [None]:
from typing import *
class TestLightningClassifier(LightningClassifier):
    """LightningClassifier that reports batch-norm / trainable-layer counts while training.

    All constructor arguments are forwarded unchanged to ``LightningClassifier``.
    ``count_trainable_batchnorm_layers()`` is called after each training step and
    again in ``training_step_end``. Verbose printing through ``self.print`` is
    enabled only until the first training step completes.
    """

    def __init__(self,
                 backbone_name='resnet50',
                 pretrained: Union[bool, str]=True,
                 num_classes: int=1000,
                 pool_size: int=1,
                 pool_type: str='avg',
                 head_type: str='linear',
                 hidden_size: Optional[int]=512,
                 lr: float=2e-03,
                 backbone_lr_mult: float=0.1,
                 weight_decay: float=0.01,
                 finetuning_strategy: str="finetuning_unfreeze_layers_on_plateau",
                 seed: Optional[int]=None,
                 **kwargs):

        super().__init__(backbone_name=backbone_name,
                         pretrained=pretrained,
                         num_classes=num_classes,
                         pool_size=pool_size,
                         pool_type=pool_type,
                         head_type=head_type,
                         hidden_size=hidden_size,
                         lr=lr,
                         backbone_lr_mult=backbone_lr_mult,
                         weight_decay=weight_decay,
                         finetuning_strategy=finetuning_strategy,
                         seed=seed,
                         **kwargs)
        self._verbose = True

    def print(self, *args, **kwargs):
        """Forward to the builtin ``print`` only while ``self._verbose`` is True."""
        if self._verbose:
            print(*args, **kwargs)

    def training_step(self, batch, batch_idx):
        output = super().training_step(batch, batch_idx)
        self.print(f"During: self.training_step")
        self.count_trainable_batchnorm_layers()
        # Only report verbosely on the first training step.
        self._verbose = False
        return output

    def training_step_end(self, outputs):
        output = super().training_step_end(outputs)
        # NOTE(review): this call looks like a leftover from a commented-out
        # training_epoch_end; it runs after every training step. Kept to preserve behavior.
        self.count_trainable_batchnorm_layers()
        # Pass the parent's result through: Lightning uses this hook's return value
        # as the step output, so returning None would discard the loss.
        return output
    
    
    

In [None]:
# source: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial5/Inception_ResNet_DenseNet.html
from lightning_hydra_classifiers.callbacks.finetuning_callbacks import FinetuningLightningCallback


def test_model_freeze_strategy(config, datamodule, **kwargs):
    """Train (or reload) a TestLightningClassifier with unfreezing-on-plateau, then evaluate it.

    Args:
        config: OmegaConf experiment config. This function reads ``config.model``,
            ``config.trainer`` (max_epochs, resume_from_checkpoint, auto_lr_find),
            ``config.checkpoint_dir``, ``config.lr_tuner_dir``, and the wandb logger
            and per-class-metrics callback sections. It mutates
            ``config.model.finetuning_strategy``, ``config.logger.wandb.group`` and
            ``config.callbacks.log_per_class_metrics_to_wandb.class_names`` in place.
        datamodule: Datamodule providing ``classes``, ``label_encoder`` and
            ``val_dataloader()`` / ``test_dataloader()``.
        **kwargs: Extra keyword arguments forwarded to ``TestLightningClassifier``.

    Returns:
        ``(model, result)``. ``result`` holds "test_acc", "val_acc" (both read
        from the "test_acc" key of ``trainer.test``; the raw outputs are used if
        that key is missing) and "ckpt_path" (the checkpoint the evaluated model
        came from, or "" if none).
    """
    config.model.finetuning_strategy = "finetuning_unfreeze_layers_on_plateau"

    group = f'{config.model.backbone.backbone_name}_{config.data.experiment.experiment_name}'
    config.logger.wandb.group = group
    config.callbacks.log_per_class_metrics_to_wandb.class_names = datamodule.classes

    callbacks = configure_callbacks(config)
    callbacks.append(FinetuningLightningCallback(monitor="val_loss",
                                                 mode="min",
                                                 patience=5))
    logger = configure_loggers(config)

    trainer = pl.Trainer(default_root_dir=config.checkpoint_dir,
                         gpus=1,
                         max_epochs=config.trainer.max_epochs,
                         callbacks=callbacks,
                         logger=logger,
                         resume_from_checkpoint=config.trainer.resume_from_checkpoint,
                         progress_bar_refresh_rate=1)
    trainer.logger._log_graph = True         # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    # If an existing checkpoint is configured, load it and skip training.
    pretrained_filename = config.trainer.resume_from_checkpoint
    if os.path.isfile(str(pretrained_filename)):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        model = TestLightningClassifier.load_from_checkpoint(pretrained_filename) # Automatically loads the model with the saved hyperparameters
        # No training ran, so the checkpoint callback has no best path; report the loaded file.
        ckpt_path = str(pretrained_filename)
    else:
        pl.seed_everything(config.model.seed)
        model = TestLightningClassifier(**config.model, **kwargs)
        model.label_encoder = datamodule.label_encoder

        if config.trainer.auto_lr_find:
            lr_tuner.run_lr_tuner(trainer=trainer,
                                  model=model,
                                  datamodule=datamodule,
                                  config=config,
                                  results_dir=config.lr_tuner_dir,
                                  group="bn_eval_trials")

        trainer.fit(model, datamodule=datamodule)

        checkpoint_callback = trainer.checkpoint_callback
        ckpt_path = checkpoint_callback.best_model_path if checkpoint_callback is not None else ""
        if ckpt_path:
            model = TestLightningClassifier.load_from_checkpoint(ckpt_path) # Load best checkpoint after training
            print(f"Best checkpoint saved to: {ckpt_path}")
        else:
            print("No best checkpoint was recorded; evaluating the model from the final training epoch.")

    # label_encoder is assigned as a plain attribute above; re-attach it in case
    # load_from_checkpoint does not restore it (NOTE(review): confirm whether it is saved).
    model.label_encoder = datamodule.label_encoder

    # Evaluate the selected model on the validation and test splits.
    val_result = trainer.test(model, test_dataloaders=datamodule.val_dataloader(), verbose=False)
    test_result = trainer.test(model, test_dataloaders=datamodule.test_dataloader(), verbose=False)

    try:
        result = {"test_acc": test_result[0]["test_acc"], "val_acc": val_result[0]["test_acc"]}
    except (KeyError, IndexError, TypeError) as e:
        # Missing metric key / empty or unexpected result: fall back to the raw outputs.
        print(e)
        result = {"test_acc": test_result, "val_acc": val_result}

    result["ckpt_path"] = ckpt_path

    return model, result

In [None]:
from lightning_hydra_classifiers.scripts.multitask.train import MultiTaskDataModule, LitMultiTaskModule, ImagePredictionLogger, train_task,  CIFAR10DataModule, run_multitask_test, load_data_and_model, load_data, resolve_config, configure_callbacks, configure_loggers, configure_trainer



# from lightning_hydra_classifiers.scripts.finetune_demo import *
# # from lightning_hydra_classifiers.data.datasets.common import toPIL
# from lightning_hydra_classifiers.utils.etl_utils import ETL
# from omegaconf import OmegaConf
# import os

def get_config_and_load_data(overrides = None,
                             task_id: int = 1,
                             pool_type='avgdrop',
                             finetuning_strategy="feature_extractor_+_bn.eval()",
                             lr=2e-03,
                             dropout_p: float=0.3,
                             max_epochs: int=5):
    """Load the Hydra ``finetune_config`` and one task's datamodule, then attach a model config.

    Args:
        overrides: Optional list of Hydra override strings (e.g. ``"data.batch_size=16"``).
        task_id: Task id passed to ``load_data``.
        pool_type: Pooling type stored in ``config.model.pool_type``.
        finetuning_strategy: Strategy name stored in ``config.model.finetuning_strategy``.
        lr: Learning rate stored in ``config.model.lr``.
        dropout_p: Dropout probability stored in ``config.model.dropout_p``.
        max_epochs: Currently unused. ``trainer.max_epochs`` comes from the Hydra
            config / ``overrides`` instead.

    Returns:
        ``(config, datamodule)``. ``config`` is fully resolved (interpolations
        replaced by values), and its ``model`` section is replaced by the
        settings built below.
    """
    overrides = overrides or []
    config = ETL.load_hydra_config(config_name = "finetune_config",
                                   config_path = "/media/data/jacob/GitHub/lightning-hydra-classifiers/configs",
                                   overrides=overrides)
    # Disable struct mode so keys not present in the schema can be assigned below.
    OmegaConf.set_struct(config, False)

    datamodule = load_data(config,
                           task_id=task_id)

    model_config = OmegaConf.create(dict(
                                    backbone={"backbone_name":config.model.backbone.backbone_name},
                                    backbone_name=config.model.backbone.backbone_name,
                                    pretrained=True,
                                    num_classes=datamodule.num_classes,
                                    pool_type=pool_type,
                                    head_type='linear',
                                    hidden_size=None,
                                    dropout_p=dropout_p,
                                    # Previously hard-coded to 2e-03, silently ignoring ``lr``.
                                    lr=lr,
                                    backbone_lr_mult=0.1,
                                    finetuning_strategy=finetuning_strategy,
                                    weight_decay=0.01,
                                    seed=98))
    config.model = model_config

    # Round-trip through a plain container to resolve every interpolation into a literal value.
    config = OmegaConf.create(OmegaConf.to_container(config, resolve=True))

    return config, datamodule


# model = LightningClassifier(**config.model)
# model = TestLightningClassifier(**config.model)
# model.label_encoder = datamodule.label_encoder

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:


# config = OmegaConf.load("/media/data/jacob/GitHub/lightning-hydra-classifiers/configs/finetune_config.yaml")

In [None]:
import wandb
# NOTE(review): wandb.save() is called with no file pattern, so it selects no files
# to upload; it presumably also needs an active run (wandb.init()) — confirm intent.
wandb.save()

In [None]:
# OmegaConf.to_container(config, resolve=True)

# from lightning_hydra_classifiers.scripts.finetune_demo import *
# # from lightning_hydra_classifiers.data.datasets.common import toPIL
# from lightning_hydra_classifiers.utils.etl_utils import ETL
from omegaconf import OmegaConf
import os

from lightning_hydra_classifiers.scripts.finetune_demo import *

# get_config_and_load_data takes a list of Hydra override strings (or None) plus a
# separate task_id. The previous code passed it an already-loaded config object as
# ``overrides``, which is not a list of override strings.
config, datamodule = get_config_and_load_data(overrides=None, task_id=1)

# Inspect the metadata of one training sample.
datamodule.train_dataset[3].metadata

In [None]:
# Machine-specific absolute path to an earlier checkpoint. Per its file name: resnet50,
# Extant/PNAS -> PNAS transfer (92 classes), epoch 9, val_loss 1.097, val_acc 0.621.
ckpt_path = "/media/data_cifs/projects/prj_fossils/users/jacob/experiments/July2021-Nov2021/experiment_logs/Transfer_Experiments/finetune_trials/Extant-to-PNAS-512-transfer_benchmark-resnet50-Extant-PNAS_to_PNAS-92_classes-res_512-bsz_32-pretrained_imagenet-pool_avgdrop/replicate_1/checkpoints/epoch=09-val_loss=1.097-val_acc=0.621.ckpt"

In [None]:
import time

time.time()

In [None]:
# from lightning_hydra_classifiers.scripts.finetune_demo import *

# kwargs = {"backbone_name":"resnet50",
#           "num_classes":91}

# Build a LightningClassifier from `ckpt_path` with a new 19-class classifier, passing the
# remaining settings from config.model. NOTE(review): presumably this loads the backbone
# weights from the checkpoint and reinitializes the head — confirm in
# init_pretrained_backbone_w_new_classifier.
model = LightningClassifier.init_pretrained_backbone_w_new_classifier(ckpt_path=ckpt_path,
                                                                      new_num_classes=19,
                                                                      **config.model)

In [None]:
# Histogram of all classifier-head weights. Flatten first: matplotlib treats a 2-D
# array as one dataset per column. Move to CPU so this also works for a GPU model.
plt.hist(model.model.head.classifier.weight.detach().cpu().numpy().ravel())

In [None]:
# Scratch: a bare attribute reference. In a notebook it only displays the torch.save
# function object; nothing is saved.
torch.save

In [None]:
# NOTE(review): nn.Module.load_state_dict requires a state_dict argument, so this call
# raises TypeError as written. TODO: pass the intended state dict (e.g. loaded from
# ckpt_path); Lightning checkpoint keys may need their module prefix stripped — verify.
model.model.load_state_dict()

In [None]:
import seaborn as sns



def plot_class_counts(df,
                      ax=None,
                      figsize=(25,10),
                      alpha=0.8,
                      ticklabel_rotation=40,
                      title: Optional[str]=None):
    """Draw a bar plot of per-class counts.

    Args:
        df: Series-like object whose ``index`` holds class labels (x-axis) and whose
            ``values`` hold counts (y-axis), e.g. ``DataFrame.value_counts(col)``.
        ax: Axes to draw on. When None, a new figure/axes of size ``figsize`` is created.
        figsize: Figure size, used only when ``ax`` is None.
        alpha: Bar transparency.
        ticklabel_rotation: Rotation in degrees of the x tick labels.
        title: Optional axes title; applied only if it is a ``str``.

    Returns:
        ``(fig, ax)``: the figure that contains ``ax``, and the axes drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)
    else:
        # Use the figure that owns the given axes; plt.gcf() may be a different figure.
        fig = ax.figure
    # x/y must be passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
    sns.barplot(x=df.index, y=df.values, alpha=alpha, ax=ax)

    # Restyle the existing tick labels in place rather than re-setting them.
    plt.setp(ax.get_xticklabels(),
             rotation=ticklabel_rotation,
             ha="right", fontsize="xx-small")
    if isinstance(title, str):
        ax.set_title(title)
    return fig, ax

In [None]:
# samples_df column whose per-class frequencies are plotted.
y_col = "family"

# One class-count bar plot per data split (train / val / test), titled with the split name.
for subset in ["train", "val", "test"]:

    df = getattr(datamodule, f"{subset}_dataset").samples_df.value_counts(y_col)

    fig, ax = plot_class_counts(df,
                                ax=None,
                                figsize=(25,10),
                                alpha=0.8,
                                ticklabel_rotation=40,
                                title=subset)

In [None]:
%%time


valid_strategies = ("finetuning_unfreeze_layers_on_plateau",)
# pool_types = ("avg", "avgdrop")#, "avgmax", "max", "avgmaxdrop")

# finetuning_strategy="feature_extractor"
# finetuning_strategy="feature_extractor_+_bn.eval()"

# pool_type='avgdrop'
# pool_type='avgmaxdrop'
pool_type="avg"
dropout_p = 0.3





all_results = {}

for strategy in valid_strategies:

    print(f"BEGINNING STRATEGY: {strategy}")
    overrides = ['model/backbone=resnet50',
                 "data=extant_to_pnas",
                 "trainer.max_epochs=1",
                 "trainer.auto_lr_find=true",
                 "trainer.precision=16",
                 "trainer.gpus=[0]",
                 "trainer.resume_from_checkpoint=null",
                 "data.batch_size=16",
                 "logger.wandb.project=finetuning_on_plateau"]

    config, datamodule = get_config_and_load_data(overrides = overrides,
                                                  task_id=1,
                                                  pool_type=pool_type,
                                                  finetuning_strategy=strategy, #"feature_extractor_+_bn.eval()",
                                                  lr=2e-03,
                                                  dropout_p=dropout_p)#,
#                                                   max_epochs=config.trainer.max_epochs)
    ckpt_paths = os.listdir(os.path.join(config.checkpoint_dir))
    if len(ckpt_paths) and os.path.exists(ckpt_paths[-1]):
        print(f"Found {ckpt_paths[-1]}")
        config.resume_from_checkpoint = ckpt_paths[-1]


    model, results = test_model_freeze_strategy(config, datamodule)
    model.cpu()
    del model

    results['model_config'] = OmegaConf.to_container(config.model, resolve=True)
    results['data_config'] = OmegaConf.to_container(config.data, resolve=True)
    
    ETL.config2yaml(results, os.path.join(config.results_dir, "results.yaml"))
    print(f"[SAVED TRIAL RESULTS] Location: {os.path.join(config.results_dir, 'results.yaml')}")
    pp(results)
    
    all_results[strategy] = results

print(f"ALL FINISHED!!! RESULTS:")
pp(all_results)


ETL.config2yaml(all_results, os.path.join(config.root_dir, "results.yaml"))

In [None]:
# test_logs/avg/finetuning_unfreeze_layers_on_plateau-PNAS-19_classes-res_512-bsz_16-resnet50-p
# (A pasted log-directory path; left bare it is not valid Python — "19_classes" is a SyntaxError.)

In [None]:

# Machine-specific path to a checkpoint from the unfreeze-on-plateau test run above. Per its
# file name: resnet50, 19 classes, pool "avg", epoch 0, val_loss 1.102, val_acc 0.475.
ckpt_path = '/media/data/jacob/GitHub/lightning-hydra-classifiers/notebooks/bn_unit_test_logs/avg/finetuning_unfreeze_layers_on_plateau-PNAS-19_classes-res_512-bsz_16-resnet50-pretrained_True-pool_avg/replicate_1/results/checkpoints/epoch=00-val_loss=1.102-val_acc=0.475.ckpt'

## scratch

In [None]:
from lightning_hydra_classifiers.utils.common_utils import *

In [None]:
# Preview every named colour palette seaborn can build, to choose colours for the plots below.
import seaborn as sns
from IPython.display import display

sns.set_theme(style="ticks", color_codes=True)

available_palettes = ['Accent', 'Accent_r', 'Blues', 'Blues_r', 'BrBG', 'BrBG_r', 'BuGn', 'BuGn_r', 'BuPu', 'BuPu_r', 'CMRmap', 'CMRmap_r', 'Dark2', 'Dark2_r', 'GnBu', 'GnBu_r', 'Greens', 'Greens_r', 'Greys', 'Greys_r', 'OrRd', 'OrRd_r', 'Oranges', 'Oranges_r', 'PRGn', 'PRGn_r', 'Paired', 'Paired_r', 'Pastel1', 'Pastel1_r', 'Pastel2', 'Pastel2_r', 'PiYG', 'PiYG_r', 'PuBu', 'PuBuGn', 'PuBuGn_r', 'PuBu_r', 'PuOr', 'PuOr_r', 'PuRd', 'PuRd_r', 'Purples', 'Purples_r', 'RdBu', 'RdBu_r', 'RdGy', 'RdGy_r', 'RdPu', 'RdPu_r', 'RdYlBu', 'RdYlBu_r', 'RdYlGn', 'RdYlGn_r', 'Reds', 'Reds_r', 'Set1', 'Set1_r', 'Set2', 'Set2_r', 'Set3', 'Set3_r', 'Spectral', 'Spectral_r', 'Wistia', 'Wistia_r', 'YlGn', 'YlGnBu', 'YlGnBu_r', 'YlGn_r', 'YlOrBr', 'YlOrBr_r', 'YlOrRd', 'YlOrRd_r', 'afmhot', 'afmhot_r', 'autumn', 'autumn_r', 'binary', 'binary_r', 'bone', 'bone_r', 'brg', 'brg_r', 'bwr', 'bwr_r', 'cividis', 'cividis_r', 'cool', 'cool_r', 'coolwarm', 'coolwarm_r', 'copper', 'copper_r', 'crest', 'crest_r', 'cubehelix', 'cubehelix_r', 'flag', 'flag_r', 'flare', 'flare_r', 'gist_earth', 'gist_earth_r', 'gist_gray', 'gist_gray_r', 'gist_heat', 'gist_heat_r', 'gist_ncar', 'gist_ncar_r', 'gist_rainbow', 'gist_rainbow_r', 'gist_stern', 'gist_stern_r', 'gist_yarg', 'gist_yarg_r', 'gnuplot', 'gnuplot2', 'gnuplot2_r', 'gnuplot_r', 'gray', 'gray_r', 'hot', 'hot_r', 'hsv', 'hsv_r', 'icefire', 'icefire_r', 'inferno', 'inferno_r', 'magma', 'magma_r', 'mako', 'mako_r', 'nipy_spectral', 'nipy_spectral_r', 'ocean', 'ocean_r', 'pink', 'pink_r', 'plasma', 'plasma_r', 'prism', 'prism_r', 'rainbow', 'rainbow_r', 'rocket', 'rocket_r', 'seismic', 'seismic_r', 'spring', 'spring_r', 'summer', 'summer_r', 'tab10', 'tab10_r', 'tab20', 'tab20_r', 'tab20b', 'tab20b_r', 'tab20c', 'tab20c_r', 'terrain', 'terrain_r', 'twilight', 'twilight_r', 'twilight_shifted', 'twilight_shifted_r', 'viridis', 'viridis_r', 'vlag', 'vlag_r', 'winter', 'winter_r']

# Print each palette's name followed by its rendered swatch.
for palette_name in available_palettes:
    print(palette_name)
    display(sns.color_palette(palette_name))

# Tabulate per-class sample counts for every data split of the current task.
# NOTE(review): `pd` is imported in a later cell (`import pandas as pd`); run that first on a fresh kernel.
datamodule.current_task

data_splits=datamodule.current_task

# The train split's class counts are computed first and their result is reused as the `sort_by`
# argument for every split (presumably giving all splits the same class order — confirm in
# compute_class_counts).
train_key = ["train"]
sort_by = compute_class_counts(data_splits[train_key[0]].targets,
                               sort_by="count")

subset_counts_df = {}
for subset, values in data_splits.items():
    print(subset)
    values = compute_class_counts(data_splits[subset].targets,
                                  sort_by=sort_by)
    subset_counts_df[subset] = values

# One column per split (dict keys become columns); rows presumably indexed by class.
subset_counts_df = pd.DataFrame.from_dict(subset_counts_df)

In [None]:
# Scratch: reshape the per-split class counts to long format and experiment with stacked bar plots.
# NOTE(review): this cell uses names defined only later (`data_subset` and `colors` in the next cell;
# `gb` only ever appears in comments), so it raises NameError on a fresh kernel. Keep it as scratch.
# Long format: presumably one row per (split, class) pair with the count in "counts" — confirm names.
df = subset_counts_df.T.stack().reset_index().rename(columns={0:"counts"})#.set_index("subset")

# df.index.name = "subset"
# df.columns.name = "target"


df

subset_counts_df.T#.unstack("target")

# Any one split's rows, used only as a template for the per-class totals table below.
df_like = next(iter(df.groupby("subset")))[1]

# Zeroed per-class totals indexed by target; the plotting loop accumulates into these.
totals = df_like[["target", "counts"]].assign(counts=df_like["counts"]*0).set_index("target")

# totals = df_like.set_index("subset").assign(count=df["count"]*0)
previous_totals = totals.copy()

previous_totals.counts + data_subset.set_index("target").counts

totals#.set_index("target")
previous_totals
# data_subset.set_index("target")

#     totals["count"] = previous_totals["count"] + data_subset["count"]

# totals["count"]
# previous_totals["count"] + 
# NOTE(review): the count column is named "counts" above, so "count" likely raises KeyError.
data_subset["count"]

colors

df

# NOTE(review): `df` has no column level named "subset" at this point — this likely fails; confirm.
df.stack("subset")

# gb = df.reset_index().groupby("subset").unstack()

# Re-index by target; reset_index() also keeps the previous index as an "index" column.
df = df.reset_index().set_index("target")#, "target"))
# gb = df.unstack(("subset","target"))
# NOTE(review): `hue` is a seaborn argument; pandas' DataFrame.plot does not define it — confirm.
df.plot(y="counts", hue="subset", kind="bar", stacked=True)

# gb.columns

gb#.set_index(keys=("subset", "target"))

help(gb.plot)

# Per-class totals across all splits, sorted by count.
final_sum = df.groupby("target").agg(sum).sort_values("counts")

final_sum

data = totals.counts / final_sum

totals

# data
final_sum

In [None]:
# Scratch: overlay one semi-transparent bar chart per split to fake a stacked per-class count plot.
# NOTE(review): depends on `df`, `totals`, `previous_totals` and `final_sum` from the previous cell.
# `previous_totals` is mutated in the loop, so re-running this cell keeps adding to the last run's
# totals — re-run the previous cell first.
colors = sns.color_palette("Set2")
i=0

for subset_name, data_subset in df.groupby("subset"):
    print(subset_name)
    
    # Running (cumulative) per-class totals over the splits seen so far.
    totals.counts = previous_totals.counts + data_subset.set_index("target").counts
    previous_totals.counts = totals.counts
    
    # NOTE(review): `data` is computed but never used below.
    data = totals.counts / final_sum
#     display(totals)
    # Each split is drawn as its cumulative total, in its own colour, on the same axes.
    bar = sns.barplot(data=totals.reset_index(), y="counts",x="target", label=subset_name, color=colors[i], alpha=0.3)#, kind='bar', palette="tab10_r")
#     bar = sns.barplot(data=totals, y="count",x="target", hue="subset", kind='bar', palette="tab10_r")
    i+=1
plt.legend()

# fig, ax = plt.subplots(1, 1, figsize=(16,9))
# bar = sns.catplot(data=df, x="target", y="count", hue="subset", kind='bar', ax=ax, palette="tab10_r")#, multiple='stack')
# bar = sns.catplot(data=df, x="target", y="count", hue="subset", kind='bar', figsize=(16,9), palette="tab10_r")

# for c in colors:
# data_bar_totals = pd.DataFrame.
# dir(pd.DataFrame)


# NOTE(review): the column is named "counts" (not "count") and `kind` is not a sns.barplot
# argument, so this loop likely fails as written.
for subset_name, data_subset in df.groupby("subset"):
    print(subset_name, data_subset)
    bar = sns.barplot(data=data_subset, y="count",x="target", hue="subset", kind='bar', palette="tab10_r")





# cat = sns.catplot(data=df, y="count",x="target", hue="subset", kind='bar', palette="tab10_r", height=5, aspect=3, multiple='stack')
# ax = plt.gca()
# ax.set_xticklabels(ax.get_xticklabels(), fontsize=14, rotation=30, ha="right");





# NOTE(review): `multiple` is a histplot/displot option rather than a catplot one — confirm.
sns.catplot(data=df, multiple='stack',  kind='bar')#, figsize=(16,9), multiple='stack')

# a = df.set_index(df.columns.tolist())
a = df.set_index(["subset", "target"])

a.index

# df.T.index.name


df.index#columns

df.T.reset_index()

df.melt(id_vars=["subset", "target"])

# Seaborn example dataset, presumably used to try out displot.
penguins = sns.load_dataset("penguins")
sns.displot(penguins, x="flipper_length_mm")

penguins

# # pd.DataFrame({"labels": [() for subset, values in data_splits.items()})

# data_splits_cat = []

# for subset, values in data_splits.items():
#     print(subset)
#     data_splits_cat.extend([(subset, v) for v in values])
    
    
# data_splits_cat = pd.DataFrame.from_records(data_splits_cat, columns=["subset","target"])

# data_splits_cat



sns.set_palette("Set2")

# plot_split_distributions is defined elsewhere; it is unpacked here as (fig, axes).
fig, ax = plot_split_distributions(data_splits=datamodule.current_task,
                                   use_one_axis=True,
                                   hist_kwargs={"alpha":0.4,
                                                "multiple":"fill"})
plt.legend()

display(ax[0])

ax[0].legend()

dir(datamodule)

In [None]:
# overrides = ['model/backbone=efficientnet_b3',"data=extant_to_fossil", "trainer.max_epochs=2", "data.batch_size=16", "trainer.precision=16"]
# overrides = ['model/backbone=resnet50',"data=extant_to_pnas", "trainer.max_epochs=20", "data.batch_size=32", "trainer.precision=16", "trainer.gpus=[7]"]
# config = ETL.load_hydra_config(config_name = "config",
#                               config_path = "/media/data/jacob/GitHub/lightning-hydra-classifiers/configs",
#                               overrides=overrides)

# valid_strategies : Tuple[str] = ("feature_extractor",
#                              "feature_extractor_+_bn.eval()",
#                              "feature_extractor_+_except_bn")
            
######################################

## Evaluate multiple strategies

In [None]:
%%time


valid_strategies = ("feature_extractor",
                    "feature_extractor_+_bn.eval()",
                    "feature_extractor_+_except_bn")
pool_types = ("avg", "avgdrop", "avgmax", "max", "avgmaxdrop")

# finetuning_strategy="feature_extractor"
# finetuning_strategy="feature_extractor_+_bn.eval()"

# pool_type='avgdrop'
# pool_type='avgmaxdrop'
pool_type="avg"
dropout_p = 0.3





all_results = {}

for strategy in valid_strategies:

    print(f"BEGINNING STRATEGY: {strategy}")
    overrides = ['model/backbone=resnet50',
                 "data=extant_to_pnas",
                 "trainer.max_epochs=10",
                 "trainer.auto_lr_find=true",
                 "trainer.precision=16",
                 "trainer.gpus=[7]",
                 "trainer.resume_from_checkpoint=null",
                 "logger.wandb.project=bn_global_pool_trials"]
    if strategy in ["feature_extractor_+_except_bn"]:
        overrides.append("data.batch_size=16")
    else:
        overrides.append("data.batch_size=32")


    config, datamodule = get_config_and_load_data(overrides = overrides,
                                                  task_id=1,
                                                  pool_type=pool_type,
                                                  finetuning_strategy=strategy, #"feature_extractor_+_bn.eval()",
                                                  lr=2e-03,
                                                  dropout_p=dropout_p)#,
#                                                   max_epochs=config.trainer.max_epochs)
    ckpt_paths = os.listdir(os.path.join(config.checkpoint_dir))
    if len(ckpt_paths) and os.path.exists(ckpt_paths[-1]):
        print(f"Found {ckpt_paths[-1]}")
        config.resume_from_checkpoint = ckpt_paths[-1]


    model, results = test_model_freeze_strategy(config, datamodule)
    model.cpu()
    del model

    results['model_config'] = OmegaConf.to_container(config.model, resolve=True)
    results['data_config'] = OmegaConf.to_container(config.data, resolve=True)
    
    ETL.config2yaml(results, os.path.join(config.results_dir, "results.yaml"))
    print(f"[SAVED TRIAL RESULTS] Location: {os.path.join(config.results_dir, 'results.yaml')}")
    pp(results)
    
    all_results[strategy] = results

print(f"ALL FINISHED!!! RESULTS:")
pp(all_results)


ETL.config2yaml(all_results, os.path.join(config.root_dir, "results.yaml"))

### Plot strategy results

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Gather val/test accuracy per strategy. Results carry either "val_acc"/"test_acc" keys or
# "val"/"test" keys (the format returned by train_model).
strategies_x = []
val_accs = []
test_accs = []
for strategy_name, strategy_results in all_results.items():
    try:
        val_acc, test_acc = strategy_results['val_acc'], strategy_results['test_acc']
    except KeyError:
        val_acc, test_acc = strategy_results['val'], strategy_results['test']
    # Append only after both values were read, so the three lists always stay aligned.
    strategies_x.append(strategy_name)
    val_accs.append(val_acc)
    test_accs.append(test_acc)

results_df = pd.DataFrame({"strategy": strategies_x,
                           "val_acc": val_accs,
                           "test_acc": test_accs})
print(results_df)

# Grouped bar chart: a val bar and a test bar side by side for each strategy.
bar_width = 0.25
val_pos = np.arange(len(strategies_x))
test_pos = val_pos + bar_width

plt.figure(figsize=(10, 10))
plt.bar(val_pos, val_accs, color='#7f6d5f', width=bar_width, edgecolor='white', label='val')
plt.bar(test_pos, test_accs, color='#557f2d', width=bar_width, edgecolor='white', label='test')

plt.xlabel('finetuning strategy', fontweight='bold')
plt.ylabel("Macro avg acc")
# Bars are centre-aligned on val_pos / test_pos, so the middle of each two-bar group is half a
# bar width to the right of val_pos.
plt.xticks(val_pos + bar_width / 2, strategies_x, rotation=10)
plt.legend()
plt.suptitle(f"{config.model.backbone_name} with global_pool={config.model.pool_type}." + "\n Classifier head trained on PNAS for <=10 epochs")
plt.tight_layout(rect=[0.0, 0.05, 1.0, 0.95])
plot_path = os.path.join(config.root_dir, "final_results_plot.png")
plt.savefig(plot_path)
print(f"Final results saved to", plot_path)
# sns.barplot(data=results_df, x='strategy', y='val_acc')

In [None]:
# Next: add specific tests for the different freeze strategies.

## Training script development
(10-17-21)

In [None]:
# source: https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial5/Inception_ResNet_DenseNet.html

def train_model(model_name, save_name=None, **kwargs):
    """Train (or load) a ``CIFARModule`` and evaluate it on the validation and test loaders.

    Args:
        model_name: architecture name, forwarded to ``CIFARModule(model_name=...)``.
        save_name: optional name for the checkpoint/log directory; defaults to ``model_name``.
        **kwargs: extra hyperparameters forwarded to ``CIFARModule`` when training from scratch.

    Returns:
        ``(model, result)`` where ``result`` is ``{"test": <test acc>, "val": <val acc>}``.
    """
    save_name = model_name if save_name is None else save_name
    run_dir = os.path.join(CHECKPOINT_PATH, save_name)

    # Keep only the weights of the best epoch by val_acc (no optimizer state).
    best_ckpt_cb = ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc")
    trainer = pl.Trainer(
        default_root_dir=run_dir,
        gpus=1 if str(device) == "cuda:0" else 0,  # a single GPU when one is available
        max_epochs=180,
        callbacks=[best_ckpt_cb, LearningRateMonitor("epoch")],  # also log the LR every epoch
        progress_bar_refresh_rate=1,  # raise this if the notebook struggles with progress-bar output
    )
    trainer.logger._log_graph = True          # plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # optional logging argument that is not needed

    # Reuse a previously trained model when its checkpoint exists; otherwise train from scratch.
    pretrained_filename = os.path.join(CHECKPOINT_PATH, save_name + ".ckpt")
    if not os.path.isfile(pretrained_filename):
        pl.seed_everything(42)  # reproducibility
        model = CIFARModule(model_name=model_name, **kwargs)
        trainer.fit(model, train_loader, val_loader)
        # Reload the best checkpoint instead of keeping the final-epoch weights.
        model = CIFARModule.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
    else:
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        model = CIFARModule.load_from_checkpoint(pretrained_filename)

    # Both evaluations go through trainer.test, so the val accuracy is also read from "test_acc".
    val_metrics = trainer.test(model, test_dataloaders=val_loader, verbose=False)[0]
    test_metrics = trainer.test(model, test_dataloaders=test_loader, verbose=False)[0]
    return model, {"test": test_metrics["test_acc"], "val": val_metrics["test_acc"]}

In [None]:
# class LightningClassifier(BaseLightningModule):
#     def __init__(self,
#                  backbone_name='gluon_seresnext50_32x4d',
#                  pretrained: Union[bool, str]=True,
#                  num_classes: int=1000,
#                  pool_size: int=1,
#                  pool_type: str='avg',
#                  head_type: str='linear',
#                  hidden_size: Optional[int]=512,
#                  lr: float=2e-03,
#                  weight_decay: float=0.01,
#                  seed: int=None):
#         super().__init__(seed=seed)
#         self.save_hyperparameters()
        
#         self.model = build_model(backbone_name=backbone_name,
#                                       pretrained=pretrained,
#                                       num_classes=num_classes,
#                                       pool_size=pool_size,
#                                       pool_type=pool_type,
#                                       head_type=head_type,
#                                       hidden_size=hidden_size)
    
#         self.criterion = nn.CrossEntropyLoss()
#         self.metrics = self.init_metrics(stage='all')
    
#     def forward(self,x):
#         return self.model(x)
    
    
#     def get_lr(self, group: str=None):
#         if group is None:
#             return self.hparams.lr
#         if group == "backbone":
#             return self.hparams.lr * 0.1
#         if group == "head":
#             return self.hparams.lr
    
#     def configure_optimizers(self):
#         print(f"self.hparams={self.hparams}")
#         self.optimizer = torch.optim.AdamW([{"params":self.model.backbone.parameters(), "lr":self.get_lr("backbone"), "weight_decay": self.hparams.weight_decay},
#                                             {"params":self.model.head.parameters(), "lr":self.get_lr("head"), "weight_decay": self.hparams.weight_decay}])
# #         self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.config.t_max, eta_min=self.config.min_lr)

#         return {'optimizer': self.optimizer}

In [None]:
# class BnFreeze(Callback):
# source: https://github.com/fastai/fastai/blob/master/fastai/callback/training.py#L55
#     run_after=TrainEvalCallback
#     "Freeze moving average statistics in all non-trainable batchnorm layers."
#     def before_train(self):
#         set_bn_eval(self.model)
        
        
        

# Experiments

For each architecture, three experimental training methods will be evaluated:

Training methods:  
1. Feature Extractor  
2. Feature Extractor + set batchnorm to eval()  
3. Freeze backbone except for all batchnorm layers  



Baseline:  
* Architecture 1. simple pretrained backbone -> `avg_pool` -> linear_classifier  
    
Next Comparisons:  
* Architecture 2. simple pretrained backbone -> `max_pool` -> linear_classifier
* Architecture 3. simple pretrained backbone -> `avgmax_pool` -> linear_classifier  

Later:  
* Architecture 4. simple pretrained backbone -> `avgdrop_pool` -> linear_classifier  
* Architecture 5. simple pretrained backbone -> `maxdrop_pool` -> linear_classifier  
* Architecture 6. simple pretrained backbone -> `avgmaxdrop_pool` -> linear_classifier  

In [None]:
# model = build_model(backbone_name='gluon_seresnext50_32x4d',
#                     pretrained=True,
#                     num_classes=19,
#                     pool_size=1,
#                     pool_type='avgmax',
#                     head_type='linear',
#                     hidden_size=None)
# from torchinfo import summary


#left off here 11 pm

# model = LightningClassifier(backbone_name='gluon_seresnext50_32x4d',
#                             pretrained=True,
#                             num_classes=19,
#                             pool_size=1,
#                             pool_type='avgmax',
#                             head_type='linear',
#                             hidden_size=None,
#                             lr=2e-03,
#                             weight_decay=0.01,
#                             seed=98)

# pp(list(model.get_batchnorm_modules()))
# # pp(list(model.get_conv_modules()))
# pp(list(model.get_linear_modules()))
# # pp(list(model.get_conv_modules()))
# pp(list(model.get_named_modules()))

# model.freeze_backbone(freeze_bn=False)
# model.freeze_backbone(freeze_bn=True)
# summary(model.model)

# print(f"trainable: {len(list(model.get_trainable_parameters()))}")
# print(f"non-trainable: {len(list(model.get_nontrainable_parameters()))}")

# bn = {n:p.requires_grad for n, p in model.get_named_parameters("bn")}

### Mini test: wrap all model hooks, then display and verify for each parameter group the expected `module.training` and `weight_tensor.requires_grad` status

### 1. feature extractor

In [None]:
# Strategy 1 — feature extractor: freeze the backbone, then count what is still trainable.
# NOTE(review): `model` must already exist (its construction above is commented out). This is the same
# freeze_backbone(freeze_bn=True) call used for strategy 2; the trailing "#False)" suggests
# freeze_bn=False was also tried here — confirm which is intended.
model.freeze_backbone(freeze_bn=True) #False)
# model.set_bn_eval()
# summary(model.model)
model.count_trainable_batchnorm_layers()
count_parameters(model.model, verbose=False);

### 2. Feature extractor + BN set to Eval()

In [None]:
# Strategy 2 — feature extractor with the batch-norm layers additionally put into eval() mode.
model.freeze_backbone(freeze_bn=True)
# NOTE(review): the model is passed to its own method; confirm set_bn_eval expects a module argument.
model.set_bn_eval(model)
# summary(model.model)
model.count_trainable_batchnorm_layers()
count_parameters(model.model, verbose=False);

### 3. Freeze backbone except for BN layers

In [None]:
# Strategy 3 — freeze the backbone but leave the batch-norm layers trainable.
# model.unfreeze(model.model)
model.freeze_backbone(freeze_bn=False)

# Unfreeze submodules selected by the "bn" pattern (presumably matched against module names — confirm).
model.unfreeze(model.model,
               filter_pattern="bn")
# summary(model.model)
model.count_trainable_batchnorm_layers()
count_parameters(model.model, verbose=False);

## Load data, model, trainer, callbacks, logger

In [None]:
import pytorch_lightning as pl  # needed for the `pl.Trainer` annotation at the end of this cell
from lightning_hydra_classifiers.scripts.multitask.train import (MultiTaskDataModule, LitMultiTaskModule,
                                                                 ImagePredictionLogger, train_task,
                                                                 CIFAR10DataModule, run_multitask_test,
                                                                 load_data_and_model, load_data,
                                                                 resolve_config, configure_callbacks,
                                                                 configure_loggers, configure_trainer)
from lightning_hydra_classifiers.data.datasets.common import toPIL
from lightning_hydra_classifiers.utils.etl_utils import ETL
from omegaconf import OmegaConf


# ---- Config + data ----
# overrides = ['model/backbone=efficientnet_b3',"data=extant_to_fossil", "trainer.max_epochs=2", "data.batch_size=16", "trainer.precision=16"]
overrides = ['model/backbone=resnet50',"data=extant_to_pnas", "trainer.max_epochs=20", "data.batch_size=32", "trainer.precision=16", "trainer.gpus=[7]"]
config = ETL.load_hydra_config(config_name = "config",
                              config_path = "/media/data/jacob/GitHub/lightning-hydra-classifiers/configs",
                              overrides=overrides)

task_id = 1
pp(config.data)
datamodule = load_data(config,
                       task_id=task_id)

# ---- Model ----
# The fine-tuning strategy is named once and reused for the model config, the experiment name
# and the wandb group, so the three cannot drift apart.
algorithm_name = "feature_extractor"

model_config = OmegaConf.create(dict(
                                backbone_name=config.model.backbone.backbone_name, #'gluon_seresnext50_32x4d',
                                pretrained=True,
                                num_classes=datamodule.num_classes,
                                pool_type='avgdrop',  # earlier runs used 'avg' and 'max'
                                head_type='linear',
                                hidden_size=None,
                                lr=2e-03,
                                backbone_lr_mult=0.1,
                                finetuning_strategy=algorithm_name,
                                weight_decay=0.01,
                                seed=98))

# Replace the hydra model config with the flat config above (the backbone name was read from it first).
config.model = model_config

config.experiment_name = f"{algorithm_name}-PNAS-{datamodule.num_classes}_classes-res_{config.data.image_size}-bsz_{config.data.batch_size}-{config.model.backbone_name}-pretrained_{config.model.pretrained}-pool_{config.model.pool_type}"

experiment_dir = config.experiment_dir
results_dir = config.results_dir
print(results_dir)

model = LightningClassifier(**config.model)
model.label_encoder = datamodule.label_encoder

# ---- Logging, callbacks, trainer ----
group = f'{config.model.backbone_name}__PNAS__experiment_0__{algorithm_name}'
config.logger.wandb.group = group
config.callbacks.log_per_class_metrics_to_wandb.class_names = datamodule.classes


callbacks = configure_callbacks(config)
logger = configure_loggers(config)

trainer: pl.Trainer = configure_trainer(config, callbacks=callbacks, logger=logger)

## Run lr_tune->fit->test

In [None]:
from lightning_hydra_classifiers.scripts.pretrain import lr_tuner

# Run the learning-rate finder; its outputs go to <results_dir>/task_<task_id>/lr_tuner.
lr_tuner_results_dir = os.path.join(results_dir, "task_" + str(task_id), "lr_tuner")
lr_tune_output = lr_tuner.run_lr_tuner(
    trainer=trainer,
    model=model,
    datamodule=datamodule,
    config=config,
    results_dir=lr_tuner_results_dir,
    group=group,
)

## model.fit

# Fit the model using the datamodule's dataloaders.
# NOTE(review): recent PyTorch Lightning versions return None from Trainer.fit, so `hist` may not hold
# any training history — confirm for the installed version.
hist = trainer.fit(model, datamodule=datamodule)

## model.test

# Evaluate on the test set. No model or dataloader is passed, so this presumably reuses the model/datamodule
# (and checkpoint selection) from the preceding fit call — confirm for the installed Lightning version.
test_result = trainer.test()

# Other

In [None]:
import os
import pandas as pd 
import numpy as np
import logging
from sklearn.metrics import classification_report, f1_score

In [None]:
# Scratch: pick up the predictions of an earlier `trainer.predict` run and collate them.
classification_report


# def predict_step(batch, batch_idx=None):
#     out = self.step(batch, batch_idx)
#     if hasattr(batch, "metadata"):
#         if "path" in batch.metadata:
#             out = [*out, batch.metadata["path"]]
#     return out

# self=model
# model.predict_step = predict_step
# test_results = trainer.predict(dataloaders=datamodule.test_dataloader(), return_predictions=True)
# results = collect_results(prediction_results)
# NOTE(review): `test_results` is only assigned in the commented-out line above; it must already exist
# from an earlier run, otherwise the next line raises NameError.
prediction_results = test_results
results = collect_results(prediction_results)
len(prediction_results[0])
# len(results)

def tensors2np(t: Union[torch.Tensor, np.ndarray, list, tuple]) -> np.ndarray:
    """Convert a tensor, an ndarray, or a (nested) list/tuple of them into a numpy array.

    Tensors are detached and moved to CPU before conversion. ndarrays are returned unchanged.
    Lists/tuples are converted element-wise (recursively) and then combined with ``np.asarray``,
    so a list of equally-shaped tensors becomes one array with a new leading axis; elements must
    therefore have matching shapes.

    Raises:
        TypeError: if ``t`` (or any element of a list/tuple) is not a supported type.
    """
    if isinstance(t, torch.Tensor):
        # detach() so tensors that require grad can still be converted.
        return t.detach().cpu().numpy()
    if isinstance(t, np.ndarray):
        return t
    if isinstance(t, (list, tuple)):
        # Previously a list fell through to the TypeError below, because the
        # converted list itself is not an ndarray.
        return np.asarray([tensors2np(item) for item in t])
    raise TypeError(f"type(t)={type(t)} is invalid for function tensors2np" + '\n' + 'tensors2np(t: Union[torch.Tensor, list]) -> np.ndarray:')
        
# Flatten the per-batch predictions into single arrays (plus one flat list of image paths).
# NOTE(review): assumes each item of `prediction_results` is (logits, y_true, y_pred, paths) with
# tensors in the first three slots, as indexed below — confirm against predict_step.
# The accumulators are (re)initialized here so the cell also works when re-run.
y_logit, y_true, y_pred, paths = [], [], [], []
for result in prediction_results:
    y_logit.append(result[0])
    y_true.append(result[1])
    y_pred.append(result[2])
    paths.extend(result[3])

y_logit = torch.cat(y_logit).cpu().numpy()
y_true = torch.cat(y_true).cpu().numpy()
y_pred = torch.cat(y_pred).cpu().numpy()

print(y_logit.shape, y_true.shape, y_pred.shape, len(paths))

In [None]:
# Notebook-level handle on the label encoder attached to the model (model.label_encoder = datamodule.label_encoder).
label_encoder = model.label_encoder




In [None]:
from torch.functional import F



# Reference: ClearML "explicit reporting" guide
# https://clear.ml/docs/latest/docs/guides/reporting/explicit_reporting/
labels = model.label_encoder.classes

# Output path for the test-set predictions CSV of this task.
test_predictions_filepath = os.path.join(results_dir, f"task_{task_id}", "test_predictions.csv")


class ImageInterpretation:
    """Tabulate a trained classifier's per-image predictions on the test set.

    Args:
        model: the trained model (stored for reference; not used by the methods below).
        datamodule: provides ``label_encoder.idx2class`` (class index -> name) and
            ``test_dataloader()``.
        trainer: Lightning trainer used to run ``predict`` on the test dataloader.
        y_col: prefix for the true/predicted label columns, e.g. ``"family_true"``.
    """

    def __init__(self, model, datamodule, trainer, y_col: str='family'):
        self.model = model
        self.dm = datamodule
        self.trainer = trainer
        self.y_col = y_col

    @property
    def decoder(self):
        """Mapping from class index to class name, taken from the datamodule's label encoder."""
        return self.dm.label_encoder.idx2class

    def decode_label(self, y: int):
        """Return the class name for index ``y``, or ``y`` unchanged if it cannot be decoded."""
        try:
            return self.decoder[y]
        except (KeyError, IndexError, TypeError):
            return y

    def log_image_predictions(self,
                              results_path: str=None,
                              sort_by_losses: bool=True,
                              ascending: bool=True) -> pd.DataFrame:
        """Predict on the test dataloader and return (optionally save) a per-image predictions table.

        |xEnt_loss |<y_col>_true |<y_col>_pred |paths |<class>_logit (one column per class)|
        |---       |---          |---          |---   |---                                 |

        Args:
            results_path: if given, the table is written there with ``ETL.df2csv``; if None,
                nothing is saved.
            sort_by_losses: sort rows by per-image cross-entropy loss.
            ascending: sort direction used when ``sort_by_losses`` is True.

        Returns:
            The predictions DataFrame.
        """
        pred_results = self.trainer.predict(dataloaders=self.dm.test_dataloader(), return_predictions=True)
        results = collect_results(pred_results)

        data_df = self._predictions_to_df(results)

        if sort_by_losses:
            data_df = data_df.sort_values("xEnt_loss", ascending=ascending)

        if results_path is not None:
            ETL.df2csv(data_df, results_path)

        return data_df

    def _predictions_to_df(self, results) -> pd.DataFrame:
        """Build one row per image from ``results = (logits, y_true, y_pred, paths)``."""
        logits = np.asarray(results[0])
        y_true = np.asarray(results[1])
        y_pred = results[2]
        paths = results[3]

        # Per-image loss; logits are cast to float32 and targets to int64 for cross_entropy.
        losses = F.cross_entropy(torch.from_numpy(logits.astype("float32")),
                                 torch.from_numpy(y_true.astype("int64")),
                                 reduction="none").numpy()

        # Order class names by index so logit column j is labelled with class j's name.
        labels = [self.decoder[k] for k in sorted(self.decoder)]

        data = {
            "xEnt_loss": losses,
            f"{self.y_col}_true": [self.decode_label(y=y) for y in y_true],
            f"{self.y_col}_pred": [self.decode_label(y=y) for y in y_pred],
            "paths": list(paths),
        }
        for j, label in zip(range(logits.shape[1]), labels):
            data[f"{label}_logit"] = logits[:, j]

        return pd.DataFrame(data)

        
        

#         data = {"xEnt_loss":xEnt_loss,
#                 f"{self.y_col}_true":results[1],
#                 f"{self.y_col}_pred":results[2],
#                 "path":results[3],
#                 "y_logits":np.hsplit(results[0], results[0].shape[1])}    
    
    
    
#         rows = []
#         for i in range(num_results):
#             rows.append({k:v for k, v in zip(columns,
#                                              [data["xEnt_loss"][i], 
#                                               self.decode_label(y=data[f"{self.y_col}_true"][i]),
#                                               self.decode_label(y=data[f"{self.y_col}_pred"][i]),
#                                               data["path"][i],
#                                               *(y[i].item() for y in data["y_logits"])]
#                                             )
#                         })

#         data_df = pd.DataFrame.from_records(rows)
#         ETL.df2csv(data_df, test_predictions_filepath)



In [None]:
# Inspect predictions ordered by the Anacardiaceae logit column.
# NOTE(review): `data_df` is only a local inside ImageInterpretation.log_image_predictions; bind it at
# notebook scope first, e.g. data_df = <ImageInterpretation instance>.log_image_predictions(...).
data_df.sort_values("Anacardiaceae_logit")

In [None]:
# Show the predictions table and the CSV path intended for it.
data_df
test_predictions_filepath

## Generate report

In [None]:
from lightning_hydra_classifiers.utils.report_utils.pandas_embed_images import df_embed_paths2imgs



# Signature of df_embed_paths2imgs, for reference (a pasted signature is not a valid call):
#   df_embed_paths2imgs(df: pd.DataFrame,
#                       file_path: str,
#                       path_col: str="path",
#                       display: bool=False)
# TODO: call it on the predictions table once it is bound at notebook scope, e.g.
#   df_embed_paths2imgs(df=data_df, file_path=<output file>, path_col="paths")
# (ImageInterpretation.log_image_predictions names its image-path column "paths".)

In [None]:
# Quick sanity checks on the collected prediction arrays.
len(paths)

y_logit#.shape

In [None]:
import seaborn as sns

# Heatmap of the raw logits; presumably rows are samples and columns are classes — confirm.
sns.heatmap(y_logit)

In [None]:
# logger = logging.getLogger(__name__)

def generate_report(y_pred,
                    y_true,
                    labels=None,
                    results_dir: str=None,
                    opt=None,
                    pipeline=None):
    """Log a classification report and append a one-row summary to ``results.csv``.

    Parameters
    ----------
    y_pred : array-like
        Predicted labels.
    y_true : array-like
        Ground-truth labels.
    labels : list, optional
        Label subset/order forwarded to ``classification_report`` and ``f1_score``.
    results_dir : str, optional
        Directory holding ``results.csv``. The summary row is appended to that file,
        and the directory and file are created if missing. When None, nothing is written.
    opt : fitted hyperparameter search, optional
        Object exposing ``best_score_``, ``best_params_``, ``best_index_`` and
        ``cv_results_['std_test_score']`` (e.g. an sklearn ``*SearchCV``). When None,
        the ``params`` / ``CV Mean`` / ``CV Std`` columns are left empty.
    pipeline : object with ``named_steps``, optional
        Its first two step names fill the ``Scaling`` and ``Model`` columns. When None,
        those columns are left empty.

    Returns
    -------
    pd.DataFrame
        The one-row summary (also appended to ``results.csv`` when ``results_dir`` is given).
    """
    # Local imports: the notebook's module-level logger line is commented out, so build one here.
    import logging
    import os

    logger = logging.getLogger(__name__)
    logger.info("Generating Evaluation Report:")

    res = classification_report(y_true, y_pred, labels=labels, output_dict=True)
    res = pd.DataFrame(res)

    logger.info("Test report:")
    logger.info('\n \t'+ res.to_string().replace('\n', '\n\t'))

    f1 = f1_score(y_true, y_pred, labels=labels, average='macro')

    # First two pipeline step names, padded with None when no/short pipeline is given.
    steps = list(pipeline.named_steps) if pipeline is not None else []
    steps += [None] * (2 - len(steps))

    if opt is not None:
        cv_mean = opt.best_score_
        cv_std = opt.cv_results_['std_test_score'][opt.best_index_]
        best_params = opt.best_params_
    else:
        cv_mean = cv_std = best_params = None

    summary = pd.DataFrame({"Scaling": [steps[0]],
                            "Model": [steps[1]],
                            "params": [best_params],
                            'CV Mean': [cv_mean],
                            'CV Std': [cv_std],
                            'Test dataset': [f1],
                            })

    if results_dir is not None:
        os.makedirs(results_dir, exist_ok=True)
        csv_path = os.path.join(results_dir, "results.csv")
        if os.path.exists(csv_path):
            # Append to the running log of all experiments.
            summary_all = pd.concat([pd.read_csv(csv_path), summary],
                                    ignore_index=True)
        else:
            summary_all = summary
        summary_all.to_csv(csv_path, index=False)

    return summary

In [None]:


model = LightningClassifier(
    # backbone
    backbone_name='gluon_seresnext50_32x4d',
    pretrained=True,
    # output size
    num_classes=19,
    # pooling + classification head
    pool_size=1,
    pool_type='avgmax',
    head_type='linear',
    hidden_size=None,
    # optimisation / reproducibility
    lr=2e-03,
    weight_decay=0.01,
    seed=98,
)

In [None]:
# import bqplot.pyplot as plt
# from dataclasses import dataclass
# @dataclass
# class LRTunerConfig:
    
#     min_lr: float = 1e-08
#     max_lr: float = 1.0
#     num_training: int = 50
#     mode: str = 'exponential'
#     early_stop_threshold: float = 4.0

# cfg = OmegaConf.structured(LRTunerConfig())

# lr_tuner = trainer.tuner.lr_find(model,
#                                  data,
#                                  **cfg)
# lr_tuner_results = lr_tuner.results
# best_lr = lr_tuner.suggestion()

# suggestion = {"lr": best_lr,
#               "loss":lr_tuner_results['loss'][lr_tuner._optimal_idx]}

# plt.figure()
# fig = lr_tuner.plot(suggest=True)
# lr_tuner_results_dir = os.path.join(results_dir, f"task_{task_id}", "lr_tuner")

# plot_fname = 'lr_tuner_results_loss-vs-lr.png'
# plot_path = Path(lr_tuner_results_dir) / plot_fname
# plt.title(f"Suggested lr={best_lr:.4e} |\n| Searched {lr_tuner.num_training} lr values $\in$ [{lr_tuner.lr_min},{lr_tuner.lr_max}] |\n| bsz = {config.data.batch_size}", style={"fontsize":'small'})
# fig.save_png(filename=str(plot_path))

# fig = plt.figure()
# plt.plot(x=lr_tuner.results['lr'],
#          y=lr_tuner.results['loss'],
#         figure=fig)

## Display available global pool types

### Aside: Verify task_0 and task_1 label maps all agree

In [None]:
# NOTE(review): assumes `data` is currently set up for task 0 when this cell runs.
task_0_labels = data.label_encoder
# Snapshot the mapping now. If `setup` mutates the encoder in place instead of replacing
# it, `task_0_labels` would otherwise show task 1's mapping and the check below would
# pass trivially. The copy is correct either way.
task_0_class2idx = dict(task_0_labels.class2idx)

data.setup(stage='fit', task_id=1)

task_1_labels = data.label_encoder
task_1_class2idx = dict(task_1_labels.class2idx)

print("label|task_0_idx|task_1_idx")
# Collect every disagreement, including task_1 labels missing from task_0 (shown as None),
# instead of stopping at the first one with a bare KeyError/assert.
mismatches = {}
for label, idx in task_1_class2idx.items():
    task_0_idx = task_0_class2idx.get(label)
    print(f"{label}|{task_0_idx}|{idx}")
    if task_0_idx != idx:
        mismatches[label] = (task_0_idx, idx)

if mismatches:
    raise AssertionError(
        f"task_1 label indices disagree with task_0 ((task_0_idx, task_1_idx) per label): {mismatches}")

print("Success, all labels in task_1 have identical integer mappings to their corresponding values in task_0")

In [None]:
class2idx = data.label_encoder.class2idx
family_counts = df.value_counts("family").to_dict()

# Per row: integer class index of its family, and how many rows share that family.
df = df.assign(class_idx=df.family.apply(lambda x: class2idx[x]),
               score=df.family.apply(lambda x: family_counts[x]))

# Exploration below: with Jupyter's default settings only the cell's last bare
# expression is displayed; the intermediate ones are evaluated but not shown.
df#.clear_intent()

# numeric_only=True averages only numeric columns. Without it, pandas >= 2.0 raises
# TypeError on non-numeric columns such as `family`; pandas 1.x silently dropped them.
df.groupby('class_idx').mean(numeric_only=True)

# df.exported.keys()

df.exported['Distribution']

df.compute_metadata()

df

df.clear_intent()

In [None]:
backbone_name = 'gluon_seresnext50_32x4d'
pretrained = True
num_classes = 1000

head_type = 'linear'
hidden_size = 0

pool_types = ["avg", "max", "avgmax"]

# Keyed by pool type. Deliberately NOT named `models`: later cells call
# `models.resnet34(...)`, which would break if `models` were rebound to this dict.
pool_models = OrderedDict()

for pool_type in pool_types:
    pool_models[pool_type] = build_model(backbone_name=backbone_name,
                                         pretrained=pretrained,
                                         num_classes=num_classes,
                                         pool_size=1,
                                         pool_type=pool_type,
                                         head_type=head_type,
                                         hidden_size=hidden_size)
print(f"backbone={backbone_name}|pretrained={pretrained}|num_classes={num_classes}|head_type={head_type}|hidden_size={hidden_size}")
# `pool_model`, not `model`, so the notebook-level `model` is not clobbered.
for pool_type, pool_model in pool_models.items():
    print(f"pool_type={pool_type}")
    pp({k: v.shape for k, v in pool_model.head.named_parameters()})

## Display available head types (TBD)

In [None]:
backbone_name = 'gluon_seresnext50_32x4d'
pretrained = True
num_classes = 1000

head_types = ['linear', 'custom']
# hidden_size per head type. In the commented-out head-building sketch at the end of this
# notebook, a 'custom' head feeds hidden_size into nn.Linear and nn.BatchNorm1d, so it needs
# a positive width. 512 is that sketch's build_model default. 'linear' keeps 0 as before.
# NOTE(review): confirm against the live build_model implementation.
hidden_sizes = {'linear': 0, 'custom': 512}

pool_types = ["avg", "max", "avgmax"]

# Keyed by (head_type, pool_type). Deliberately NOT named `models`: later cells call
# `models.resnet34(...)`, which would break if `models` were rebound to this dict.
head_models = OrderedDict()

# Iterate over every head type (previously only the stale `head_type` was used).
for head_type in head_types:
    hidden_size = hidden_sizes[head_type]
    print(f"backbone={backbone_name}|pretrained={pretrained}|num_classes={num_classes}|head_type={head_type}|hidden_size={hidden_size}")
    for pool_type in pool_types:
        head_model = build_model(backbone_name=backbone_name,
                                 pretrained=pretrained,
                                 num_classes=num_classes,
                                 pool_size=1,
                                 pool_type=pool_type,
                                 head_type=head_type,
                                 hidden_size=hidden_size)
        head_models[(head_type, pool_type)] = head_model
        print(f"pool_type={pool_type}")
        pp({k: v.shape for k, v in head_model.head.named_parameters()})

## Creating fit method

In [None]:
#Borrowed from fastai2 library (set_bn_eval), adapted to tolerate parameter-free BN layers.

bn_types = (torch.nn.modules.batchnorm.BatchNorm1d,
            torch.nn.modules.batchnorm.BatchNorm2d,
            torch.nn.modules.batchnorm.BatchNorm3d)


def set_bn_eval(m: nn.Module) -> None:
    """Put frozen batch-norm layers among ``m``'s descendants into eval mode.

    A descendant that is an instance of ``bn_types`` is switched to ``eval()`` when none of
    its parameters require gradients, so its running statistics stop updating while
    the rest of the model trains. ``m`` itself is not checked, only its descendants.

    A BN layer with no parameters at all (``affine=False``) counts as frozen. The previous
    ``next(l.parameters())`` check raised ``StopIteration`` on such layers. A layer counts
    as frozen only if *all* its parameters are frozen (previously only the first was checked).
    """
    for child in m.children():
        if isinstance(child, bn_types) and not any(p.requires_grad for p in child.parameters()):
            child.eval()
        set_bn_eval(child)

In [None]:
def _run_epoch(model, dl, loss_fn, device, parent, opt=None):
    """One pass over ``dl``: a training step per batch if ``opt`` is given, else evaluation only.

    Returns ``(mean_batch_loss, accuracy)``. The loss is averaged over batches and the
    accuracy over ``len(dl.dataset)`` samples. Both are 0.0 for an empty loader/dataset
    instead of raising ZeroDivisionError.
    """
    total_loss, n_correct, n_batches = 0.0, 0, 0
    for xb, yb in progress_bar(dl, parent=parent):
        xb, yb = xb.to(device), yb.to(device)
        out = model(xb)
        loss = loss_fn(out, yb)
        total_loss += loss.item()
        n_correct += (out.argmax(dim=1) == yb).sum().item()
        n_batches += 1
        if opt is not None:
            opt.zero_grad()
            loss.backward()
            opt.step()
    n_samples = len(dl.dataset)
    mean_loss = total_loss / n_batches if n_batches else 0.0
    accuracy = n_correct / n_samples if n_samples else 0.0
    return mean_loss, accuracy


def fit(epochs,model,train_dl,valid_dl,loss_fn,opt,device=None,bn_eval=False):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Per-epoch metrics are written as a table through ``master_bar``/``progress_bar``.

    Parameters
    ----------
    epochs : int
    model : nn.Module
        Moved to ``device`` in place.
    train_dl, valid_dl : DataLoader
        Yield ``(inputs, class_index_targets)`` batches and expose ``.dataset`` with ``len``.
    loss_fn : callable
        ``loss_fn(logits, targets)`` returning a scalar tensor.
    opt : torch.optim.Optimizer
        Optimizer over ``model``'s parameters. Only stepped during training.
    device : torch.device, optional
        Defaults to CUDA when available, else CPU.
    bn_eval : bool
        If True, call ``set_bn_eval(model)`` after ``model.train()`` each epoch so frozen
        batch-norm layers keep their running statistics.
    """
    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    mb = master_bar(range(epochs))
    mb.write(['epoch','train_loss','valid_loss','trn_acc','val_acc'],table=True)
    model.to(device)

    for i in mb:
        model.train()
        if bn_eval:
            set_bn_eval(model)
        trn_loss, trn_acc = _run_epoch(model, train_dl, loss_fn, device, mb, opt=opt)

        model.eval()
        with torch.no_grad():
            val_loss, val_acc = _run_epoch(model, valid_dl, loss_fn, device, mb)

        mb.write([i,f'{trn_loss:.6f}',f'{val_loss:.6f}',f'{trn_acc:.6f}',f'{val_acc:.6f}'],table=True)
        
    

## Training

In [None]:
# Cross-entropy on the heads' raw outputs (every head defined below ends in nn.Linear);
# this is the loss passed to each fit(...) call in this section.
loss_fn = F.cross_entropy

In [None]:
def freeze(model,bn_freeze=True):
    """Set ``requires_grad = False`` on ``model``'s parameters.

    Parameters
    ----------
    model : nn.Module
        Module whose parameters are frozen in place.
    bn_freeze : bool
        If True, freeze everything. If False, leave batch-norm layers' parameters untouched
        (they keep whatever ``requires_grad`` they had) and freeze everything else.

    Batch-norm layers are detected by module type, not by parameter name. Name matching
    (``'bn' in name``) missed BN layers such as torchvision ResNet's ``downsample.1``
    and froze them even with ``bn_freeze=False``.
    """
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    for module in model.modules():
        if not bn_freeze and isinstance(module, norm_types):
            continue
        # recurse=False: each module handles only its own parameters; children are visited separately.
        for param in module.parameters(recurse=False):
            param.requires_grad = False
            
def unfreeze(model):
    """Make every parameter of ``model`` trainable again (in place; returns None)."""
    model.requires_grad_(True)

def get_model(lrs=(1e-3, 1e-3), bn_freeze=True):
    """Build a fresh ``MyResNet`` with a frozen body and a two-group Adam optimizer.

    Parameters
    ----------
    lrs : sequence of float
        ``lrs[0]`` is the body's learning rate and ``lrs[1]`` the head's. Group 0 = body and
        group 1 = head is the layout ``update_lr`` relies on. A tuple default avoids a
        shared mutable default list.
    bn_freeze : bool
        Forwarded to ``freeze`` for the body.

    Returns
    -------
    (model, opt)
    """
    model = MyResNet()
    freeze(model.body, bn_freeze=bn_freeze)
    opt = optim.Adam([{'params': model.body.parameters(), 'lr': lrs[0]},
                      {'params': model.head.parameters(), 'lr': lrs[1]}])
    return model, opt

def update_lr(lr, opt):
    """Set the head group's learning rate to ``lr`` and the body group's to ``lr / 100``.

    Expects the layout built by ``get_model``: param group 0 = body, group 1 = head.
    """
    body_group = opt.param_groups[0]
    body_group['lr'] = lr / 100
    head_group = opt.param_groups[1]
    head_group['lr'] = lr
    

### Freeze the complete resnet body

# Baseline: the whole body (BN included) is frozen; only the head receives gradients.
lr = 1e-3
model, opt = get_model(lrs=[lr] * 2, bn_freeze=True)
fit(epochs=2, model=model, train_dl=trn_dl, valid_dl=valid_dl, loss_fn=loss_fn, opt=opt)

### Freeze the complete resnet body and place BN layers in eval mode.

# As above, but fit(..., bn_eval=True) also keeps the frozen BN layers in eval mode.
lr = 1e-3
model, opt = get_model(lrs=[lr] * 2, bn_freeze=True)
fit(epochs=2, model=model, train_dl=trn_dl, valid_dl=valid_dl, loss_fn=loss_fn, opt=opt,
    bn_eval=True)

### Freeze the complete resnet body and place BN layers in eval mode and train the body at a lesser learning rate for the second epoch.

In [None]:
# Stage 1: one epoch with the body frozen and its BN layers in eval mode.
lr = 1e-3
model, opt = get_model(lrs=[lr] * 2, bn_freeze=True)
fit(epochs=1, model=model, train_dl=trn_dl, valid_dl=valid_dl, loss_fn=loss_fn, opt=opt,
    bn_eval=True)

In [None]:
# Stage 2: halve the head lr (body gets lr/200 via update_lr), unfreeze everything, train one more epoch.
update_lr(lr / 2, opt)
unfreeze(model)
fit(epochs=1, model=model, train_dl=trn_dl, valid_dl=valid_dl, loss_fn=loss_fn, opt=opt)

### Freeze the ResNet body except for the BN layers

In [None]:
# Body frozen except for its batch-norm layers (bn_freeze=False), two epochs.
model, opt = get_model(lrs=[1e-3] * 2, bn_freeze=False)
fit(epochs=2, model=model, train_dl=trn_dl, valid_dl=valid_dl, loss_fn=loss_fn, opt=opt)

### Freeze the ResNet body except for the BN layers, and try a smaller learning rate for the ResNet body

In [None]:
# Stage 1: BN-trainable frozen body for one epoch.
model, opt = get_model(lrs=[lr] * 2, bn_freeze=False)
fit(epochs=1, model=model, train_dl=trn_dl, valid_dl=valid_dl, loss_fn=loss_fn, opt=opt)
# Stage 2: lower rates (head lr/2, body lr/200), unfreeze the body, one more epoch.
update_lr(lr / 2, opt)
unfreeze(model)
fit(epochs=1, model=model, train_dl=trn_dl, valid_dl=valid_dl, loss_fn=loss_fn, opt=opt)

### Try adjusting the model

In [None]:
class AdaptiveConcatPooling(nn.Module):
    """Global average- and max-pooling concatenated along dim 1.

    For a 4-D input (N, C, H, W) the output is (N, 2*C, 1, 1): the average-pooled
    channels come first, followed by the max-pooled channels.
    """

    def forward(self, x):
        pooled = (F.adaptive_avg_pool2d(x, 1), F.adaptive_max_pool2d(x, 1))
        return torch.cat(pooled, dim=1)

In [None]:
class MyResNet(nn.Module):
    """Pretrained ResNet-34 feature extractor (``body``) plus a small 2-way head (``head``).

    ``body`` holds every child of the torchvision ResNet-34 except the last two (its own
    average pool and fully-connected layer). ``head`` concat-pools to 512*2 features,
    flattens them and maps them to 2 logits.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet34(pretrained=True)
        # Keep everything but the final two children (torchvision's avgpool and fc).
        *feature_layers, _avgpool, _fc = backbone.children()
        self.body = nn.Sequential(*feature_layers)
        self.head = nn.Sequential(
            AdaptiveConcatPooling(),
            Flatten(),
            nn.Linear(512 * 2, 2),
        )

    def forward(self, x):
        return self.head(self.body(x))

In [None]:
# Two-stage run with the concat-pooling head: BN-trainable frozen body, then full unfreeze.
lr = 1e-3
model, opt = get_model(lrs=[lr] * 2, bn_freeze=False)
fit(epochs=1, model=model, train_dl=trn_dl, valid_dl=valid_dl, loss_fn=loss_fn, opt=opt)
update_lr(lr / 2, opt)
unfreeze(model)
fit(epochs=1, model=model, train_dl=trn_dl, valid_dl=valid_dl, loss_fn=loss_fn, opt=opt)

### Increasing the complexity of the model

In [None]:
class MyResNet(nn.Module):
    """Pretrained ResNet-34 body with a deeper, regularised 2-way head.

    Head: concat-pool -> flatten -> BN -> dropout(0.25) -> Linear(nf, nf/2) -> ReLU
    -> BN -> dropout(0.75) -> Linear(nf/2, 2), where nf = 512*2 concat-pooled features.
    Both Linear layers are created without bias.
    """

    def __init__(self):
        super().__init__()
        nf = 512 * 2
        backbone = models.resnet34(pretrained=True)
        # Keep everything but the final two children (torchvision's avgpool and fc).
        *feature_layers, _avgpool, _fc = backbone.children()
        self.body = nn.Sequential(*feature_layers)
        hidden = nf // 2
        head_layers = [
            AdaptiveConcatPooling(),
            Flatten(),
            nn.BatchNorm1d(nf),
            nn.Dropout(p=0.25),
            nn.Linear(nf, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(hidden),
            nn.Dropout(p=0.75),
            nn.Linear(hidden, 2, bias=False),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        features = self.body(x)
        return self.head(features)

In [None]:
# Two-stage run with the deeper head: BN-trainable frozen body, then full unfreeze.
lr = 1e-3
model, opt = get_model(lrs=[lr] * 2, bn_freeze=False)
fit(epochs=1, model=model, train_dl=trn_dl, valid_dl=valid_dl, loss_fn=loss_fn, opt=opt)
update_lr(lr / 2, opt)
unfreeze(model)
fit(epochs=1, model=model, train_dl=trn_dl, valid_dl=valid_dl, loss_fn=loss_fn, opt=opt)

## scratch

In [None]:


                

# ## WIP: display_layer_status
#     @classmethod
#     def display_layer_status(cls,
#                              model: nn.Module,
#                              max_depth: int=3):
#         """
#         Return a formatted display of model's layers alongside relevant training status info.
#         """
#         modules = []
        
# #         for name, module in model.named_modules():
#         for name, module in model.named_children():
# #             print(name, 'max_depth:', max_depth)
#             if name=="": continue
# #             if (max_depth>0) and (len(list(module.named_modules())) > 0):
#             if (max_depth>0) and (len(list(module.named_children())) > 0):
#                 modules.extend(cls.display_layer_status(module, max_depth=max_depth-1))
#                 continue
#             module_out = {"name":name,
#                           "training":module.training,
#                           "type":type(module),
#                           "params":[]}
#             for param_name, param in module.named_parameters():
#                 module_out["params"].append({
#                     "name":param_name,
#                     "type":type(param),
#                     "requires_grad":param.requires_grad,
#                     "shape":param.shape
#                 })
#             print(name)
#             pp(module_out)
#             modules.append(module_out)
#         return modules


In [None]:
# for unfreeze_down_to in reversed(range(0,-8,-1)):
# model.freeze_backbone(freeze_bn=False)
# model.freeze_backbone(freeze_bn=True)
# summary(model.model)
# count_trainable_batchnorm_layers(model)

# for unfreeze_down_to in range(0,-9,-1):
#     print(unfreeze_down_to)
#     print(f"Unfreezing backbone down to layer: {unfreeze_down_to}")
#     model.unfreeze_backbone_top_layers(unfreeze_down_to=unfreeze_down_to)
# #     summary(model.model)
#     model.count_trainable_batchnorm_layers()
# #     count_trainable_batchnorm_layers(model)

#     print(f"trainable parameters: {len(list(model.get_trainable_parameters()))}")
#     print(f"non-trainable parameters: {len(list(model.get_nontrainable_parameters()))}")

# unfreeze_down_to = -2

# model.unfreeze_backbone_top_layers(unfreeze_down_to=unfreeze_down_to)
# summary(model.model)
# count_trainable_batchnorm_layers(model)

# print(f"trainable parameters: {len(list(model.get_trainable_parameters()))}")
# print(f"non-trainable parameters: {len(list(model.get_nontrainable_parameters()))}")

# unfreeze_down_to = -3

# model.unfreeze_backbone_top_layers(unfreeze_down_to=unfreeze_down_to)
# summary(model.model)
# print(f"trainable: {len(list(model.get_trainable_parameters()))}")
# print(f"non-trainable: {len(list(model.get_nontrainable_parameters()))}")

# unfreeze_down_to = -4

# model.unfreeze_backbone_top_layers(unfreeze_down_to=unfreeze_down_to)
# summary(model.model)
# print(f"trainable: {len(list(model.get_trainable_parameters()))}")
# print(f"non-trainable: {len(list(model.get_nontrainable_parameters()))}")

In [None]:


#     model = timm.create_model(model_name=backbone_name, num_classes=1000, pretrained=pretrained)
#     if isinstance(pretrained, str) and pretrained != "imagenet":
#         model = load_model_checkpoint(model, ckpt_path=pretrained)
# #         ckpt_pth = glob.glob(hydra.utils.to_absolute_path(pretrained))
# #         model = model.load_state_dict(torch.load(ckpt_pth[0]))
        
#     body = nn.Sequential(*list(model.children())[:-2])

    
#     feature_size = model.fc.in_features
    
#     head = OrderedDict()
#     global_pool, feature_size = build_global_pool(pool_type=pool_type,
#                                                   pool_size=pool_size,
#                                                   feature_size=feature_size)
#     head["global_pool"] = global_pool
#     head["flatten"] = Flatten()
    
#     classifier_input_feature_size = feature_size*(pool_size*2)        
#     if head_type=='linear':
#         head["classifier"] = nn.Linear(classifier_input_feature_size, num_classes)
#     elif head_type=='custom':
#         head["classifier"] = nn.Sequential(nn.Linear(classifier_input_feature_size, hidden_size),
#                                 nn.RReLU(lower=0.125, upper=0.3333333333333333, inplace=False),
#                                 nn.BatchNorm1d(hidden_size),
#                                 nn.Linear(hidden_size, num_classes))
        
#     head = nn.Sequential(head)


#     model = nn.Sequential(OrderedDict({
#         "body":body,
#         "head":head
#     }))
#     return model

# def build_model(backbone_name='gluon_seresnext50_32x4d',
#                 pretrained: Union[bool, str]=True,
#                 num_classes: int=1000,
#                 pool_size: int=1,
#                 pool_type: str='avg',
#                 head_type: str='linear',
#                 hidden_size: Optional[int]=512):
    
#     try:
#         model = build_timm_custom(backbone_name=backbone_name,
#                                   pretrained=pretrained,
#                                   num_classes=num_classes,
#                                   pool_size=pool_size,
#                                   pool_type=pool_type,
#                                   head_type=hidden_size,
#                                   hidden_size=hidden_size)

#     except:
#         print

        
        
        