In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from pathlib import Path
import numpy as np
from scipy.interpolate import UnivariateSpline
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Extra functions for lr_find
class ParameterModule(nn.Module):
    "Register a lone parameter 'p' in a module"
    def __init__(self, p:nn.Parameter):
        super().__init__()
        self.val = p
    def forward(self, x): return x

def children_and_parameters(m:nn.Module):
    "Return the children of `m` and its direct parameters not registered in modules."
    children = list(m.children())
    children_p = sum([[id(p) for p in c.parameters()] for c in m.children()],[])
    for p in m.parameters():
        if id(p) not in children_p: children.append(ParameterModule(p))
    return children

flatten_model = lambda m: sum(map(flatten_model,children_and_parameters(m)),[]) if len(list(m.children())) else [m]

# lr_range
def lr_range(model, lr):
    """
    Build differential learning rate from lr
    Arguments:
        lr :- float or slice
        num_layer :- number of layers with requires_grad=True
    Returns:
        Depending upon lr
    """
    if not isinstance(lr, slice): 
        return lr
    
    num_layer = [nn.Sequential(*flatten_model(model))]
    if lr.start: 
        mult = lr.stop / lr.start
        step = mult**(1/(num_layer-1))
        res = np.array([lr.start*(step**i) for i in range(num_layer)])
    else:
        res = [lr.stop/10]*(num_layer-1) + [lr.stop]
    
    return np.array(res)

In [3]:
# lr_find
def lr_find(model:nn.Module, data, start_lr=1e-7, end_lr=10, num_it=100, stop_div:bool=True, wd:float=None):
    """
    Arguments:
        model ->
        start_lr -> lr at which cyclic lr should start
        end_lr -> lr at which cyclic lr shoudl end
        num_it -> number of batches you want to run
        stop_div -> if loss diverges, stop
        wd ->
        data -> your train data_loader (of class torch.utils.data.DataLoader)
        
        If I have 100 images and I make a batch of 10 images, so num_batches=100/10 = 10
    """
    start_lr = lr_range(model, start_lr)
    start_lr = np.array(start_lr) if isinstance(start_lr, (tuple, list)) else start_lr
    
    end_lr = lr_range(end_lr)
    end_lr = np.array(end_lr) if isinstance(end_lr, (tuple, list)) else end_lr
    
    cb = LRFinder(model, start_lr, end_lr, num_it, stop_div)
    epochs = int(np.ceil(num_it/len(data)))

In [19]:
def annealing_no(start, end, pct:float):
    "No annealing, always return `start`."
    return start
def annealing_linear(start, end, pct:float):
    "Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."
    return start + pct * (end-start)
def annealing_exp(start, end, pct:float):
    "Exponentially anneal from `start` to `end` as pct goes from 0.0 to 1.0."
    return start * (end/start) ** pct
def annealing_cos(start, end, pct:float):
    "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
    cos_out = np.cos(np.pi * pct) + 1
    return end + (start-end)/2 * cos_out
def do_annealing_poly(start, end, pct:float, degree):
    return end + (start-end) * (1-pct)**degree

class Stepper():
    "Used to \"step\" from start, end ('vals') over 'n_iter' iterations on a schedule"
    def __init__(self, vals, n_iter:int, func=None):
        self.start, self.end = (vals[0], vals[1]) if isinstance(vals, tuple) else (vals,0)
        self.n_iter = max(1, n_iter)
        if func is None:
            self.func = annealing_linear if isinstance(vals, tuple) else annealing_no
        else:
            self.func = func
        self.n = 0
    
    def step(self):
        "Return next value along annealed schedule"
        self.n += 1
        return self.func(self.start, self.end, self.n/self.n_iter)
        
    @property
    def is_done(self)->bool:
        "Return 'True' if schedule completed"
        return self.n >= self.n_iter
        
# def save_model_dict(path, name, model:nn.Module, optimizer, with_opt:bool=True):
#     """
#     Arguments:-
#         path -> Where you want to save the files
#                 from pathlib import Path
#                 path = Path('/home/kushaj/Desktop')
#         name -> Name of the file
#         with_opt -> if True then optimizer state is also saved
#     """
#     path = path/f'{name}.pth' 
#     if not with_opt:
#         state = model.state_dict()
#     else:
#         state = {'model': model.state_dict(),
#                  'opt': optimizer.state_dict()}
#     torch.save(state, path)
        
def LRFinder(model, 
             data, 
             loss_fn, 
             opt, 
             wd, 
             start_lr:float=1e-7, 
             end_lr:float=10, 
             num_it:int=100, 
             stop_div:bool=True):
    sched = Stepper((start_lr, end_lr), num_it, annealing_exp)
    # save model_dict
    model_state = model.state_dict()
    opt_state = opt.state_dict()
    
    opt.lr = sched.start
    stop = False
    best_loss = 0.
    flag = False
    losses = []
    iteration = 0
    lrs = []
    moms = []
    
    while True:
        for dat in data:
            # Batch begin
            lrs.append(opt.lr)
            moms.append(opt.mom)
            
            inputs, labels = dat
            opt.zero_grad()

            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            # For adamW
            for group in opt.param_groups():
                for param in group['params']:
                    param.data = param.data.add(-wd * group['lr'], param.data)
                    
            opt.step()
            
            losses.append(loss.item())
            opt.lr = sched.step()
            if iteration == 0 or loss < best_loss:
                best_loss = loss
            iteration += 1
            
            if sched.is_done or (stop_div and (loss > 4*best_loss or torch.isnan(smooth_loss))):
                flag = True
                break
        
        if flag:
            break
    
    # Load state back
    model.load_state_dict(model_state)
    opt.load_state_dict(opt_state)
    
    print('LR Finder is complete')
    
    return losses, lrs

def plot_lr_finder(losses, 
                   lrs, 
                   skip_start:int=10, 
                   skip_end:int=5, 
                   suggestion:bool=False, 
                   return_fig:bool=None, 
                   smoothen_by_spline:bool=True):
    lrs = lrs[skip_start:-skip_end] if skip_end > 0 else lrs[skip_start:]
    losses = losses[skip_start:-skip_end] if skip_end > 0 else losses[skip_start:]
    if smoothen_by_spline:
        xs = np.arange(len(losses))
        spl = UnivariateSpline(xs, losses)
        losses = spl(xs)
        
    fig, ax = plt.subplots(1, 1)
    ax.plot(lrs, losses)
    ax.set_ylabel("Loss")
    ax.set_xlabel("Learning Rate")
    ax.set_xscale('log')
    ax.xaxis.set_major_formatter(plt.FormatStrFormatter('%.0e'))
    if suggestion:
        try:
            mg = (np.gradient(np.array(losses))).argmin()
        except:
            print("Failed to compute the gradients, there might not be enough points.")
            return
        print("Min numerical gradient: {lrs[mg]:.2E}")
        ax.plot(lrs[mg], losses[mg], markersize=10, marker='o', color='red')
    
    if return_fig is not None:
        return fig

In [35]:
path = '../Data/cifar10/train/'
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.ImageFolder(path, transform=transform)

data = torch.utils.data.DataLoader(trainset, batch_size=512, shuffle=True, num_workers=8)

In [17]:
scipy.interpolate

AttributeError: module 'scipy' has no attribute 'interpolate'