In [None]:
import Flux
import NeuralAttentionlib
import Optimisers
import Optimisers: Adam
import ParameterSchedulers
import ParameterSchedulers: Sequence, Triangle, Shifted, Stateful
import Transformers

# Learning rate schedules

In [None]:
"""
    schedule_learning_rate!(opt, weightdecay, lr_schedule, weightdecay_frac)

Advance `lr_schedule` by one step with `ParameterSchedulers.next!`, convert the result to
`Float32`, and write it into the two optimiser states:

- `opt` gets `eta` set to the new learning rate;
- `weightdecay` gets `gamma` set to the new learning rate times `weightdecay_frac`.

Returns whatever the final `Optimisers.adjust!` call returns.
"""
function schedule_learning_rate!(opt, weightdecay, lr_schedule, weightdecay_frac)
    step_lr = Float32(ParameterSchedulers.next!(lr_schedule))
    # The decay strength follows the learning rate, scaled by a fixed fraction.
    decay_strength = step_lr * weightdecay_frac
    Optimisers.adjust!(opt; eta = step_lr)
    return Optimisers.adjust!(weightdecay; gamma = decay_strength)
end;

In [None]:
"""
    LinearWarmupSchedule(lr, iters, warmup_perc)

Build a `Stateful` schedule over `iters` steps made of two `Triangle` segments chained
with `Sequence`:

1. a `Triangle` between `0` and `lr` with period `2 * warmup_steps`, used for
   `warmup_steps = round(iters * warmup_perc)` steps;
2. a `Triangle` between `0` and `lr` with period `2 * remaining_steps`, wrapped in
   `Shifted` with offset `remaining_steps`, used for the `remaining_steps` left over.

NOTE(review): presumably segment 1 is the rising half of its triangle wave and segment 2
the falling half (the offset skips the rise), which gives a linear warm-up followed by a
linear decay. Confirm against the installed ParameterSchedulers version.
"""
function LinearWarmupSchedule(lr, iters, warmup_perc)
    # TODO cosine annealing
    n_warmup = round(Int, iters * warmup_perc)
    n_rest = iters - n_warmup
    warmup_leg = Triangle(λ0 = 0.0f0, λ1 = lr, period = 2 * n_warmup)
    rest_wave = Triangle(λ0 = 0.0f0, λ1 = lr, period = 2 * n_rest)
    rest_leg = Shifted(rest_wave, n_rest)
    return Stateful(Sequence(warmup_leg => n_warmup, rest_leg => n_rest))
end

"""
    get_lr_schedule(config; num_epochs = nothing, peak_learning_rate = nothing,
                    warmup_perc = 0.06)

Build the learning-rate schedule for a training run by calling `LinearWarmupSchedule`.

# Arguments
- `config`: container indexed by string keys. `"iters_per_epoch"` and `"batch_size"` are
  always read; `"peak_learning_rate"` and `"num_epochs"` are read only when the matching
  keyword is left as `nothing`.

# Keywords
- `num_epochs`: overrides `config["num_epochs"]` when given.
- `peak_learning_rate`: overrides `config["peak_learning_rate"]` when given. The value
  taken from `config` is converted to `Float32`; an explicit override is used as passed.
- `warmup_perc`: fraction of the total step count handed to `LinearWarmupSchedule` as its
  warm-up share. Defaults to `0.06`, the value that used to be hard-coded here.

# Returns
The schedule produced by `LinearWarmupSchedule(lr, max_batches, warmup_perc)`, where
`max_batches = ceil(num_epochs * iters_per_epoch / batch_size)`.
"""
function get_lr_schedule(
    config;
    num_epochs = nothing,
    peak_learning_rate = nothing,
    warmup_perc = 0.06,
)
    # Config keys are only looked up when no override was supplied, so a config that
    # lacks them remains usable with explicit keywords.
    lr = if isnothing(peak_learning_rate)
        Float32(config["peak_learning_rate"])
    else
        peak_learning_rate
    end
    epochs = isnothing(num_epochs) ? config["num_epochs"] : num_epochs
    max_batches = ceil(Int, epochs * config["iters_per_epoch"] / config["batch_size"])
    return LinearWarmupSchedule(lr, max_batches, warmup_perc)
end;

In [None]:
"""
    current(iter::ParameterSchedulers.Stateful)

Evaluate `iter.schedule` at `iter.state` and return the result. Only fields are read, so
the iterator itself is not advanced.
"""
current(iter::ParameterSchedulers.Stateful) = iter.schedule(iter.state);

# Weight decay

In [None]:
"""
    PureWeightDecay(gamma = 5f-4)

Optimiser rule that performs weight decay only: the update it returns is `gamma * x`, and
the incoming gradient `dx` is ignored. The rule keeps no per-parameter state.

The intent is to leave embeddings, biases and layer norms undecayed. The rule itself does
not do that exclusion; `initialize_weight_decay!` has to be called on the optimiser state
once it has been set up.
"""
struct PureWeightDecay{T} <: Optimisers.AbstractRule
    gamma::T
end

PureWeightDecay() = PureWeightDecay(5f-4)

# Stateless rule: there is nothing to initialise per array.
function Optimisers.init(rule::PureWeightDecay, x::AbstractArray)
    return nothing
end

# State is passed through untouched; the update is the scaled parameter itself.
Optimisers.apply!(rule::PureWeightDecay, state, x, dx) = (state, rule.gamma * x);

In [None]:
"""
    initialize_weight_decay!(opt, m)

Walk the optimiser state `opt` alongside the model `m` and freeze every entry that
`disable_weight_decay` names for the corresponding layer.

At each level the fields returned by `disable_weight_decay(m)` are frozen with
`Optimisers.freeze!`. Every other field of `m` that is also present in `opt` is visited
recursively. Recursion stops when `opt` is an `Optimisers.Leaf` or has length zero.
Returns `nothing`.
"""
function initialize_weight_decay!(opt, m)
    # The Leaf test must come first: `length` is only queried on non-leaf nodes.
    (opt isa Optimisers.Leaf || length(opt) == 0) && return
    exempt = disable_weight_decay(m)
    foreach(name -> Optimisers.freeze!(opt[name]), exempt)
    opt_fields = fieldnames(typeof(opt))
    for name in fieldnames(typeof(m))
        # Frozen fields are finished; fields missing from the state have no subtree.
        (name in exempt || name ∉ opt_fields) && continue
        initialize_weight_decay!(opt[name], getfield(m, name))
    end
end;

In [None]:
# layers with special weightdecay semantics: each method returns the names of the fields
# that `initialize_weight_decay!` freezes (i.e. excludes from weight decay) for that layer
disable_weight_decay(x::Flux.Dense) = [:bias]
disable_weight_decay(x::Transformers.Layers.Dense) = [:b]
disable_weight_decay(x::Transformers.Layers.LayerNorm) = [:α, :β]
disable_weight_decay(x::Flux.LayerNorm) = [:diag]

# as a safety check, whitelist layers: a type without its own method below fails loudly
# instead of silently getting a default. The AssertionError is thrown explicitly rather
# than through `@assert`, which may be disabled at some optimization levels; exception
# type and message (the type name) are the same as `@assert false typeof(x)` produced.
function disable_weight_decay(x)
    throw(AssertionError(string(typeof(x))))
end

# container / wrapper types: nothing is frozen at this level, their fields are recursed into
disable_weight_decay(x::Function) = Symbol[]
disable_weight_decay(x::NamedTuple) = Symbol[]
disable_weight_decay(x::Tuple) = Symbol[]
disable_weight_decay(x::Flux.Chain) = Symbol[]
disable_weight_decay(x::Transformers.Transformer) = Symbol[]
disable_weight_decay(x::Transformers.Layers.PreNormTransformerBlock) = Symbol[]
disable_weight_decay(x::Transformers.Layers.PreNormResidual) = Symbol[]
disable_weight_decay(x::Transformers.Layers.DropoutLayer) = Symbol[]
disable_weight_decay(x::Transformers.Layers.SelfAttention) = Symbol[]
disable_weight_decay(x::NeuralAttentionlib.MultiheadQKVAttenOp) = Symbol[]
disable_weight_decay(x::Transformers.Layers.NSplit) = Symbol[]
disable_weight_decay(x::Transformers.Layers.Chain) = Symbol[]
disable_weight_decay(x::Transformers.Dropout) = Symbol[];