In [None]:
# var 1

def find_appropriate_lr(model: "Learner", lr_diff: int = 15, loss_threshold: float = .05,
                        adjust_value: float = 1, plot: bool = False) -> float:
    """Run the LR finder on *model* and pick a learning rate before the loss spike.

    Walks backwards from the end of the recorded loss curve, comparing the
    loss gradient at two points ``lr_diff`` recorded steps apart. While the
    two gradients differ by more than ``loss_threshold`` the window keeps
    sliding left; the learning rate at the left edge of the last window that
    still exceeded the threshold is selected and scaled by ``adjust_value``.

    Args:
        model: Object exposing ``lr_find()`` and, afterwards,
            ``recorder.losses`` and ``recorder.lrs``.
        lr_diff: Spacing, in recorded steps, between the two compared
            gradients. Must be smaller than the number of recorded losses.
        loss_threshold: Absolute gradient difference above which the search
            keeps moving left.
        adjust_value: Multiplier applied to the selected learning rate.
        plot: If True, plot the loss gradients and the loss-vs-log10(lr)
            curve, each with the selected point marked in red.

    Returns:
        The selected learning rate multiplied by ``adjust_value``.

    Raises:
        ValueError: If ``lr_diff`` is not smaller than the number of recorded
            losses, or if the very first window (at the end of the curve) is
            already within ``loss_threshold`` so no learning rate is selected.
    """
    # Run the Learning Rate Finder
    model.lr_find()

    # Get loss values and their corresponding gradients, and get lr values
    losses = np.array(model.recorder.losses)
    if lr_diff >= len(losses):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError(
            f"lr_diff ({lr_diff}) must be smaller than the number of "
            f"recorded losses ({len(losses)})"
        )
    loss_grad = np.gradient(losses)
    lrs = model.recorder.lrs

    # Slide a window [l_idx, r_idx] of width lr_diff (negative indices, i.e.
    # counted from the end of the curve) to the left while the gradients at
    # its two edges still differ by more than the threshold.
    r_idx = -1
    l_idx = r_idx - lr_diff
    min_idx = None  # negative index of the selected lr; None until one is found
    while l_idx >= -len(losses) and abs(loss_grad[r_idx] - loss_grad[l_idx]) > loss_threshold:
        min_idx = l_idx
        r_idx -= 1
        l_idx -= 1

    if min_idx is None:
        # Previously this surfaced as an UnboundLocalError on `local_min_lr`.
        raise ValueError(
            "No learning rate selected: the loss-gradient difference at the end "
            f"of the curve is already <= loss_threshold ({loss_threshold}); "
            "try a smaller loss_threshold or a different lr_diff."
        )

    local_min_lr = lrs[min_idx]
    lr_to_use = local_min_lr * adjust_value

    if plot:
        # Gradient of the loss per recorded step. Mark the step actually selected
        # (min_idx), not the post-loop l_idx, which is one further left and may
        # be out of range when the loop reached the start of the curve.
        plt.plot(loss_grad)
        plt.plot(len(losses) + min_idx, loss_grad[min_idx], markersize=10, marker='o', color='red')
        plt.ylabel("Loss Gradient")
        plt.xlabel("Index of LRs")
        plt.show()

        plt.plot(np.log10(lrs), losses)
        plt.ylabel("Loss")
        plt.xlabel("Log 10 Transform of Learning Rate")
        # Interpolate so the marker lands on the curve even after adjust_value scaling.
        loss_coord = np.interp(np.log10(lr_to_use), np.log10(lrs), losses)
        plt.plot(np.log10(lr_to_use), loss_coord, markersize=10, marker='o', color='red')
        plt.show()

    return lr_to_use

In [None]:
# var 2

def find_lr(losses, lrs, *, losses_skipped=5, trailing_losses_skipped=5, plot=True):
    """Suggest a learning rate from the longest decreasing stretch of a loss curve.

    After trimming noisy points at both ends, computes a longest strictly
    decreasing subsequence of ``losses`` and treats the span it ends on as the
    "valley". The suggested learning rate sits about halfway into that span
    (1/3 + 1/6 of its length past its start).

    Args:
        losses: Recorded losses, in the same order as ``lrs``.
        lrs: Learning rates matching ``losses``.
        losses_skipped: Number of leading points to drop before searching.
        trailing_losses_skipped: Number of trailing points to drop; 0 keeps
            the tail intact.
        plot: If True (the default, matching previous behavior), show the
            loss-vs-lr curve on a log x-axis with the suggestion marked.

    Returns:
        The suggested learning rate (an element of the trimmed ``lrs``).

    Raises:
        ValueError: If a skip count is negative, or no losses remain after
            trimming.
    """
    if losses_skipped < 0 or trailing_losses_skipped < 0:
        raise ValueError("losses_skipped and trailing_losses_skipped must be non-negative")

    # Use an explicit stop index: a `[:-0]` slice would be empty, so
    # trailing_losses_skipped=0 must mean "keep the tail", not "drop everything".
    losses = losses[losses_skipped:max(len(losses) - trailing_losses_skipped, 0)]
    lrs = lrs[losses_skipped:max(len(lrs) - trailing_losses_skipped, 0)]

    n = len(losses)
    if n == 0:
        raise ValueError(
            f"No losses left after skipping {losses_skipped} leading and "
            f"{trailing_losses_skipped} trailing points"
        )

    # Longest decreasing subsequence: lds[i] is the length of the longest
    # strictly decreasing subsequence ending at index i.
    lds = [1] * n
    max_start = 0
    max_end = 0
    for i in range(1, n):
        for j in range(i):
            if losses[i] < losses[j] and lds[i] < lds[j] + 1:
                lds[i] = lds[j] + 1
        # lds[i] can exceed the running maximum by at most 1, so checking once
        # after the inner loop gives the same result as checking per j.
        if lds[max_end] < lds[i]:
            max_end = i
            max_start = max_end - lds[max_end]

    sections = (max_end - max_start) / 3
    # Pick something midway, or 2/3rd of the way to be more aggressive.
    final_index = max_start + int(sections) + int(sections / 2)
    # max_start is -1 when the valley begins at index 0; for a two-point valley
    # that makes final_index -1, which would silently wrap to the last (highest) lr.
    final_index = max(final_index, 0)

    if plot:
        # Imported lazily so the suggestion can be computed without matplotlib.
        import matplotlib.pyplot as plt

        _, ax = plt.subplots(1, 1)
        ax.plot(lrs, losses)
        ax.set_ylabel("Loss")
        ax.set_xlabel("Learning Rate")
        ax.set_xscale('log')
        ax.xaxis.set_major_formatter(plt.FormatStrFormatter('%.0e'))
        ax.plot(
            lrs[final_index],
            losses[final_index],
            markersize=10,
            marker='o',
            color='red'
        )
        plt.show()

    return lrs[final_index]

In [None]:
# var 3

# Unfreeze the learner, run the LR finder, and plot the recorded curve with
# the suggestion marker enabled.
# NOTE(review): relies on a `learn` object defined in an earlier cell.
learn.unfreeze()
learn.lr_find()
learn.recorder.plot(suggestion=True)

# Read the suggested LR from the recorder; presumably `min_grad_lr` is filled in
# by plot(suggestion=True) above — confirm for the installed fastai version.
lr = learn.recorder.min_grad_lr
# Bare expression so the notebook cell displays the chosen value.
lr