# Learning rate decay scheduling

This notebook is not specific to Distiller.

When training or fine-tuning a model, you may want to experiment with different learning-rate (LR) decay policies. This notebook plots how several PyTorch LR schedulers change the learning rate from epoch to epoch.

In [None]:
from ipywidgets import widgets, interact
import matplotlib.pyplot as plt
import torch
from torch.optim.lr_scheduler import *

In [None]:
import torchvision

# The LR schedules plotted below do not depend on these weights.
# Move the model to the GPU only when one is present, so the notebook
# also runs on CPU-only machines (an unconditional .cuda() raises there).
model = torchvision.models.alexnet(pretrained=True)
if torch.cuda.is_available():
    model = model.cuda()

In [None]:
@interact(first_epoch=(0,100), last_epoch=(1,100), step_size=(1, 30, 1), gamma=(0, 1, 0.05), lr='0.001', T_max=(1,10),
         enable_steplr=True, 
         enable_explr=True,
         enable_cosinelr=False,
         enable_multisteplr=True)
def draw_schedules(first_epoch=0, last_epoch=50, step_size=3, gamma=0.9, lr=0.001, T_max=1,
                   enable_steplr=True,
                   enable_explr=True,
                   enable_cosinelr=False,
                   enable_multisteplr=True):
    """Plot the learning rate chosen by each enabled scheduler, epoch by epoch.

    Each scheduler is driven by its own freshly created optimizer. A scheduler
    writes the new LR into its optimizer's ``param_groups`` in place, so letting
    several schedulers share one optimizer would make them overwrite each other.

    Args:
        first_epoch: first epoch shown on the plot (inclusive).
        last_epoch: end of the plotted epoch range (exclusive).
        step_size: ``step_size`` passed to StepLR.
        gamma: decay factor passed to ExponentialLR, StepLR and MultiStepLR.
        lr: initial learning rate. The text widget supplies it as a string,
            hence the ``float()`` conversion.
        T_max: ``T_max`` passed to CosineAnnealingLR.
        enable_steplr, enable_explr, enable_cosinelr, enable_multisteplr:
            toggle the corresponding curve.
    """
    lr = float(lr)

    def make_optimizer():
        # The LR schedule is independent of any model weights, so a single
        # dummy parameter is enough; momentum/weight_decay match the original
        # setup and do not influence the LR values.
        param = torch.nn.Parameter(torch.zeros(1))
        return torch.optim.SGD([param], lr, momentum=0.9, weight_decay=0.0001)

    # Scheduler factories, in plotting order; each receives its own optimizer.
    factories = {}
    if enable_explr:
        factories['ExponentialLR'] = lambda opt: torch.optim.lr_scheduler.ExponentialLR(opt, gamma)
    if enable_steplr:
        factories['StepLR'] = lambda opt: torch.optim.lr_scheduler.StepLR(opt, step_size, gamma)
    if enable_cosinelr:
        factories['CosineAnnealingLR'] = lambda opt: torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max)
    if enable_multisteplr:
        factories['MultiStepLR'] = lambda opt: torch.optim.lr_scheduler.MultiStepLR(
            opt, milestones=[30, 80], gamma=gamma)

    epochs = list(range(first_epoch, last_epoch))
    lr_values = {}
    for name, make_scheduler in factories.items():
        optimizer = make_optimizer()
        scheduler = make_scheduler(optimizer)
        values = []
        # Step sequentially from epoch 0 (scheduler.step(epoch) is deprecated)
        # and record the LR the optimizer actually holds at the start of each
        # epoch. Reading param_groups avoids get_lr(), which gives wrong values
        # when called outside of step().
        for epoch in range(last_epoch):
            if epoch >= first_epoch:
                values.append(optimizer.param_groups[0]['lr'])
            optimizer.step()  # no gradients, so a no-op; keeps the real step order
            scheduler.step()
        lr_values[name] = values

    for name, values in lr_values.items():
        plt.plot(epochs, values, label=name)
    plt.ylabel('LR')
    plt.xlabel('epoch')
    plt.title('Learning Rate Schedulers')
    if lr_values:
        plt.legend()
    plt.show()


# References

 1. <div id="pytorch1"></div> **http://pytorch.org/docs/master/_modules/torch/optim/lr_scheduler.html**
 2. <div id="pytorch2"></div> **http://pytorch.org/docs/master/optim.html#how-to-adjust-learning-rate**