In [None]:
import Optimisers
import Optimisers: Adam, OptimiserChain, WeightDecay
import ParameterSchedulers
import ParameterSchedulers: Sequence, Triangle, Shifted, Stateful

# Learning rate schedules

In [None]:
"""
    schedule_learning_rate!(opt, lr_schedule, weight_decay)

Advance `lr_schedule` by one step and write the resulting learning rate into `opt`.

The value produced by `ParameterSchedulers.next!` is converted to `Float32` and set as
`eta`; `gamma` is set to `eta * weight_decay`, so the weight-decay strength follows the
learning-rate schedule. Returns whatever `Optimisers.adjust!` returns.
"""
function schedule_learning_rate!(opt, lr_schedule, weight_decay)
    η = Float32(ParameterSchedulers.next!(lr_schedule))
    # Decay is scaled by the scheduled rate rather than held constant.
    return Optimisers.adjust!(opt; eta = η, gamma = η * weight_decay)
end;

In [None]:
"""
    LinearWarmupSchedule(lr, iters, warmup_perc)

Build a `Stateful` two-phase schedule spanning `iters` steps with peak value `lr`.

The first `round(iters * warmup_perc)` steps use a `Triangle` wave between `0` and `lr`
(period twice the warmup length); the remaining steps use a second `Triangle` wave,
`Shifted` by its own length. NOTE(review): the intent appears to be a linear ramp up to
`lr` followed by a linear decay back towards 0 — confirm against the `Triangle`/`Shifted`
definitions in the installed ParameterSchedulers version.
"""
function LinearWarmupSchedule(lr, iters, warmup_perc)
    n_warmup = round(Int, iters * warmup_perc)
    n_decay = iters - n_warmup

    # Rising half of a triangle wave covering the warmup steps.
    warmup = Triangle(λ0 = 0.0f0, λ1 = lr, period = 2 * n_warmup)
    # Triangle wave offset by `n_decay` steps, presumably so it starts from the
    # peak and falls off over the remaining steps.
    decay = Shifted(Triangle(λ0 = 0.0f0, λ1 = lr, period = 2 * n_decay), n_decay)

    phases = Sequence(warmup => n_warmup, decay => n_decay)
    return Stateful(phases)
end

"""
    get_lr_schedule(config; num_epochs = nothing, peak_learning_rate = nothing,
                    warmup_perc = 0.06)

Build the learning-rate schedule for a training run from `config`.

`config["peak_learning_rate"]` and `config["num_epochs"]` are read only when the matching
keyword is `nothing`, so an override works even if the key is absent from `config`.
`config["iters_per_epoch"]` and `config["batch_size"]` are always read. The schedule
length is `round(Int, num_epochs * iters_per_epoch / batch_size)`, and `warmup_perc` of
it is passed to `LinearWarmupSchedule` as the warmup fraction (default `0.06`, the
previously hard-coded value).

The peak learning rate is always converted to `Float32`, whether it comes from `config`
or from the `peak_learning_rate` keyword.
"""
function get_lr_schedule(
    config;
    num_epochs = nothing,
    peak_learning_rate = nothing,
    warmup_perc = 0.06,
)
    # Conditional (not `something`) so config keys are only touched when needed.
    peak = isnothing(peak_learning_rate) ? config["peak_learning_rate"] : peak_learning_rate
    lr = Float32(peak)
    epochs = isnothing(num_epochs) ? config["num_epochs"] : num_epochs
    max_batches = round(Int, epochs * config["iters_per_epoch"] / config["batch_size"])
    return LinearWarmupSchedule(lr, max_batches, warmup_perc)
end;

In [None]:
"""
    current(iter::ParameterSchedulers.Stateful)

Evaluate `iter.schedule` at `iter.state`; only reads those two fields, so the iterator is
not advanced.

NOTE(review): if `ParameterSchedulers.next!` increments `state` after evaluating the
schedule, this is the value the *next* `next!` call will return, not the one most
recently returned — confirm against the installed version.
"""
current(iter::ParameterSchedulers.Stateful) = iter.schedule(iter.state);

# Gradient accumulation

In [None]:
"""
    tuplesum(a, b)

Recursively add two gradient structures built from `NamedTuple`s, `Tuple`s, `nothing`
and leaf values (leaves are combined with `+`).

- `NamedTuple`s are matched by field name, so `b` may list its fields in another order;
  the result uses `a`'s field order.
- `Tuple`s are combined elementwise.
- `nothing` on either side acts as an empty accumulator: the other argument is returned
  unchanged (`tuplesum(nothing, nothing) === nothing`).
"""
function tuplesum end

function tuplesum(a::NamedTuple, b::NamedTuple)
    names = keys(a)
    # Index `b` by name rather than position so field order need not match.
    return NamedTuple{names}(map(k -> tuplesum(a[k], b[k]), names))
end
# `map` over tuples keeps the result type inferrable, unlike `Tuple(generator)`.
tuplesum(a::Tuple, b::Tuple) = map(tuplesum, a, b)
tuplesum(::Nothing, b) = b
tuplesum(a, ::Nothing) = a
# Needed to resolve the ambiguity between the two `nothing` methods above.
tuplesum(::Nothing, ::Nothing) = nothing
tuplesum(a, b) = a + b;

"""
    tupledivide(a, d)

Recursively divide every leaf of a gradient structure by `d` (leaves use `./`).

`Tuple`s and `NamedTuple`s are mapped over, keeping their shape and field names;
`nothing` entries stay `nothing`.
"""
function tupledivide end

# `map` preserves Tuple/NamedTuple structure and keeps the result type inferrable,
# unlike rebuilding from a generator.
tupledivide(a::Union{Tuple,NamedTuple}, d) = map(x -> tupledivide(x, d), a)
tupledivide(::Nothing, d) = nothing
tupledivide(a, d) = a ./ d;

# Optimisers

In [None]:
"""
    WeightDecayNobias(gamma)

Weight-decay rule for Optimisers.jl that adds `gamma * x` to the gradient only for
parameters accepted by `should_weight_decay`. The intent is to apply weight decay,
but not on bias terms, embedding layers, layer norms, etc.; all other gradients pass
through unchanged.
"""
struct WeightDecayNobias{T} <: Optimisers.AbstractRule
    gamma::T
end

# Stateless rule: nothing to initialise per parameter.
Optimisers.init(o::WeightDecayNobias, x::AbstractArray) = nothing

function Optimisers.apply!(o::WeightDecayNobias, state, x, dx)
    if !should_weight_decay(o, x)
        # Excluded parameter: forward the gradient untouched.
        return state, dx
    end
    # Match the parameter's element type so a Float64 gamma (e.g. `lr * weight_decay`
    # with a Float64 `weight_decay`) does not promote Float32 gradients; this mirrors
    # what `Optimisers.WeightDecay` does with `T(o.gamma)`.
    γ = convert(eltype(x), o.gamma)
    return state, Optimisers.@lazy dx + γ * x
end
"""
    should_weight_decay(o::WeightDecayNobias, x)

Heuristically decide whether parameter `x` should receive weight decay.

Only axes longer than 1 are considered. `x` is decayed when it has at least two such
axes, unless it is 2-D with one odd-length axis and one power-of-two-length axis, which
is treated as an embedding matrix (presumably vocab size × power-of-two width — a
heuristic, verify for your model). Vectors and scalars (bias terms, norm scales) are
never decayed. `o` is not inspected.
"""
function should_weight_decay(o::WeightDecayNobias, x)
    dims = size(x)
    nontrivial(d) = d > 1  # singleton axes are ignored by every test below

    n_nontrivial = count(nontrivial, dims)
    has_odd = any(d -> nontrivial(d) && isodd(d), dims)
    has_pow2 = any(d -> nontrivial(d) && ispow2(d), dims)
    looks_like_embedding = has_pow2 && has_odd && length(dims) == 2

    return !looks_like_embedding && n_nontrivial > 1
end;