In [None]:
import Optimisers
import Optimisers: Adam, OptimiserChain, WeightDecay
import ParameterSchedulers
import ParameterSchedulers: Sequence, Triangle, Shifted, Stateful

# Learning rate schedules

In [None]:
"""
    schedule_learning_rate!(opt, lr_schedule, weight_decay)

Pull the next value from `lr_schedule` with `ParameterSchedulers.next!`, convert it to
`Float32`, and push it into `opt` through `Optimisers.adjust!`.

Two hyperparameters are adjusted: `eta` is set to the new learning rate and `gamma` is
set to the learning rate multiplied by `weight_decay`.
NOTE(review): presumably `eta` belongs to `Adam` and `gamma` to `WeightDecay` in the
optimiser chain being used; confirm against the caller's optimiser setup.

Returns whatever `Optimisers.adjust!` returns.
"""
function schedule_learning_rate!(opt, lr_schedule, weight_decay)
    lr = ParameterSchedulers.next!(lr_schedule) |> Float32
    # The decay coefficient follows the learning rate, so it is recomputed on each call.
    scaled_decay = lr * weight_decay
    return Optimisers.adjust!(opt; eta = lr, gamma = scaled_decay)
end;

In [None]:
"""
    LinearWarmupSchedule(lr, iters, warmup_perc)

Build a `Stateful` two-phase schedule spanning `iters` steps.

The number of steps in the first phase is `iters * warmup_perc` rounded to an integer;
the second phase gets the remaining steps. Each phase is a `Triangle` between `0.0f0`
and `lr` whose period is twice the phase length. The second `Triangle` is wrapped in
`Shifted` with an offset equal to the phase length.
NOTE(review): the intent appears to be a linear ramp up to `lr` followed by a linear
ramp back down; confirm the exact endpoints against the ParameterSchedulers version in
use.
"""
function LinearWarmupSchedule(lr, iters, warmup_perc)
    warmup_steps = round(Int, iters * warmup_perc)
    remaining_steps = iters - warmup_steps

    # Phase 1: only the first half-period of this triangle is consumed.
    ramp_up = Triangle(λ0 = 0.0f0, λ1 = lr, period = 2 * warmup_steps)

    # Phase 2: the triangle is shifted by half a period before being consumed.
    ramp_down = Shifted(
        Triangle(λ0 = 0.0f0, λ1 = lr, period = 2 * remaining_steps),
        remaining_steps,
    )

    phases = Sequence(ramp_up => warmup_steps, ramp_down => remaining_steps)
    return Stateful(phases)
end

"""
    get_lr_schedule(config; num_epochs = nothing, peak_learning_rate = nothing,
                    warmup_perc = 0.06)

Build the learning-rate schedule described by `config` using `LinearWarmupSchedule`.

# Arguments
- `config`: indexable by string keys. `"iters_per_epoch"` and `"batch_size"` are always
  read; `"peak_learning_rate"` and `"num_epochs"` are read only when the matching
  keyword is `nothing`.

# Keywords
- `num_epochs`: overrides `config["num_epochs"]` when not `nothing`.
- `peak_learning_rate`: overrides `config["peak_learning_rate"]` when not `nothing`.
  The override is used as given, whereas the config value is converted to `Float32`.
- `warmup_perc`: fraction of the total steps passed to `LinearWarmupSchedule` as the
  warm-up share. Defaults to `0.06`, the value that was previously hard-coded.

# Returns
The schedule object produced by `LinearWarmupSchedule`.
"""
function get_lr_schedule(
    config;
    num_epochs = nothing,
    peak_learning_rate = nothing,
    warmup_perc = 0.06,
)
    # Config keys are looked up lazily so a caller supplying an override does not need
    # the corresponding entry in `config`.
    lr = if isnothing(peak_learning_rate)
        Float32(config["peak_learning_rate"])
    else
        peak_learning_rate
    end
    epochs = isnothing(num_epochs) ? config["num_epochs"] : num_epochs

    # Total number of schedule steps: epochs * iters_per_epoch / batch_size, rounded.
    total_steps = epochs * config["iters_per_epoch"] / config["batch_size"]
    max_batches = round(Int, total_steps)

    return LinearWarmupSchedule(lr, max_batches, warmup_perc)
end;

In [None]:
"""
    current(iter::ParameterSchedulers.Stateful)

Evaluate the schedule wrapped by `iter` at the step stored in `iter.state` and return
the result. No field of `iter` is assigned here.
"""
function current(iter::ParameterSchedulers.Stateful)
    schedule, step = iter.schedule, iter.state
    return schedule(step)
end;