# Pretrains a transformer encoder model on watch histories

In [None]:
# output prefix for this run; checkpoints are written to "$name/checkpoints/<epoch>"
name = "all/Transformer/v19";

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb")
@nbinclude("Reference/CUDA.ipynb")
@nbinclude("Reference/Include.ipynb");

In [None]:
import Flux
import Flux: cpu, gpu, LayerNorm, logsoftmax
import Random
import StatsBase: mean, sample

# Structs

In [None]:
"""
    Trainer

Everything needed to train and checkpoint the model in one place: the model, the
optimiser state (`opt`), the separate weight-decay state (`weightdecay`), the
learning-rate schedule, the training and model configuration dictionaries, and the RNG
used for shuffling, masking and baseline sampling.
"""
struct Trainer
    model            # the TransformerModel being trained
    opt              # optimiser state for the gradient step
    weightdecay      # optimiser state for the decoupled weight-decay step
    lr_schedule
    training_config
    model_config
    rng
end;

# Tokenize training data

In [None]:
"""
    get_training_data(media, include_ptw, cls_tokens, empty_tokens)

Gather training sentences for every task in `ALL_TASKS` into one flat vector.

For each task this calls the per-task method
`get_training_data(task, media, include_ptw, cls_tokens, empty_tokens)` and keeps the
values of the returned collection. Presumably each value is one user's sentence; confirm
against the per-task method.
"""
function get_training_data(media, include_ptw, cls_tokens, empty_tokens)
    per_task = Vector{Vector{Vector{wordtype}}}(undef, length(ALL_TASKS))
    for (slot, task) in enumerate(ALL_TASKS)
        task_data = get_training_data(task, media, include_ptw, cls_tokens, empty_tokens)
        per_task[slot] = collect(values(task_data))
    end
    vcat(per_task...)
end;

# Create minibatches

In [None]:
"""
    get_batch(sentences, training::Bool, t::Trainer)

Build a minibatch from `sentences` using the settings in `t.training_config`.

For every medium in `training_config["explicit_baseline"]`, one of the stored
baselines (one per task) is drawn at random with `t.rng`. The ratings are demeaned
against that baseline.
"""
function get_batch(sentences, training::Bool, t::Trainer)
    cfg = t.training_config
    sampled_baseline = Dict(
        medium => rand(t.rng, candidates) for (medium, candidates) in cfg["explicit_baseline"]
    )
    get_batch(;
        sentences = sentences,
        max_seq_len = cfg["max_sequence_length"],
        vocab_sizes = cfg["base_vocab_sizes"],
        pad_tokens = cfg["pad_tokens"],
        cls_tokens = cfg["cls_tokens"],
        mask_tokens = cfg["mask_tokens"],
        empty_tokens = cfg["empty_tokens"],
        user_weighted_training = cfg["user_weighted_training"],
        explicit_baseline = sampled_baseline,
        rng = t.rng,
        training = training,
    )
end;

In [None]:
"""
    get_batch(; sentences, max_seq_len, vocab_sizes, pad_tokens, cls_tokens, mask_tokens,
              empty_tokens, user_weighted_training, explicit_baseline, rng, training)

Build one masked-language-model minibatch from `sentences`.

1. Each sentence is cut to `seq_len = min(longest sentence, max_seq_len)` with
   `subset_sentence`, then tokenized with `get_token_ids`.
2. `attention_mask[i, j, k]` is `true` only when positions `i` and `j` of sentence `k`
   carry the same user token and neither carries the anime padding token.
3. Each position is selected with probability 0.15. A selected position becomes a
   prediction target if it has an item and/or an explicit rating. Items are anime or
   manga, excluding plan-to-watch entries. Targets are then corrupted BERT-style.
4. Explicit ratings, both inputs and targets, are demeaned by
   `demean_explicit_ratings!`. The user biases come from the ratings at positions that
   were not selected.
5. Per-target weights come from `uids_to_weights`. They are all ones when
   `training && !user_weighted_training`.

Returns `(tokens, attention_mask, batch_positions, item_positions, labels, weights)`.
The last four are keyed by medium (`:anime`, `:manga`) and then by objective (`:item`,
`:rating`):
- `batch_positions`: the `(position, sentence)` of each target.
- `item_positions`: `(item id, running target index)`, used to gather the output for the
  target item.
- `labels[medium][:rating]`: demeaned target ratings. The matching item ids are in
  `labels[medium][:item]`.
- `weights`: one weight per target.

`rng` is passed to `subset_sentence` and is also used for selection and corruption.

NOTE(review): `demean` and `demean_item_weights` are only defined when
`explicit_baseline !== nothing`, but they are used unconditionally. Passing `nothing`
would therefore throw an `UndefVarError`.
NOTE(review): `medium` is only assigned at positions that have an anime or manga item.
A position with a rating but no item would hit an undefined `medium`. Presumably this
cannot occur for real data; confirm.
"""
function get_batch(;
    sentences,
    max_seq_len,
    vocab_sizes,
    pad_tokens,
    cls_tokens,
    mask_tokens,
    empty_tokens,
    user_weighted_training,
    explicit_baseline,
    rng,
    training,
)
    # Dynamically pad to the largest sequence length, capped at `max_seq_len`.
    # NOTE(review): subset_sentence presumably shortens each sentence to at most
    # `seq_len` tokens, keeping the leading tokens when not training; confirm.
    seq_len = min(maximum(length.(sentences)), max_seq_len)
    sentences = [
        subset_sentence(s, seq_len; recent = false, keep_first = !training, rng = rng)
        for s in sentences
    ]

    # get tokenized sentences
    tokens = get_token_ids(
        sentences,
        seq_len,
        extract(vocab_sizes, :position),
        pad_tokens,
        cls_tokens,
    )

    # Don't attend across sequences. A row may hold several users' histories
    # (shuffle_training_data packs training rows), so position i attends to j only if
    # both have the same user token and neither is anime padding.
    # Threads split the `i` range, so each mask entry is written by exactly one thread.
    attention_mask = zeros(Bool, (seq_len, seq_len, length(sentences)))
    Threads.@threads for i = 1:seq_len
        for j = 1:seq_len
            for k = 1:length(sentences)
                if (extract(tokens, :user)[i, k] == extract(tokens, :user)[j, k]) &&
                   (extract(tokens, :anime)[i, k] != extract(pad_tokens, :anime)) &&
                   (extract(tokens, :anime)[j, k] != extract(pad_tokens, :anime))
                    attention_mask[i, j, k] = 1
                end
            end
        end
    end

    # Demean ratings. Set up per-user accumulators (weighted residual sum, count and
    # weight sum), filled in the loop below. Also build per-item weights from
    # training-split item counts (presumably explicit-rating counts; get_counts is
    # defined elsewhere).
    if !isnothing(explicit_baseline)
        demean = (
            anime = (
                rating = Dict{Int32,Float32}(),
                count = Dict{Int32,Int32}(),
                weight = Dict{Int32,Float32}(),
            ),
            manga = (
                rating = Dict{Int32,Float32}(),
                count = Dict{Int32,Int32}(),
                weight = Dict{Int32,Float32}(),
            ),
        )
        demean_item_weights = (
            anime = powerdecay(
                get_counts(
                    "training",
                    "all",
                    "explicit",
                    "anime";
                    by_item = true,
                    per_rating = false,
                ),
                log(explicit_baseline[:anime]["λ"][4]),
            ),
            manga = powerdecay(
                get_counts(
                    "training",
                    "all",
                    "explicit",
                    "manga";
                    by_item = true,
                    per_rating = false,
                ),
                log(explicit_baseline[:manga]["λ"][4]),
            ),
        )
    end

    # per-target outputs, keyed by medium and objective
    batch_positions = (
        anime = (item = Tuple{Int32,Int32}[], rating = Tuple{Int32,Int32}[]),
        manga = (item = Tuple{Int32,Int32}[], rating = Tuple{Int32,Int32}[]),
    )
    item_positions = (
        anime = (item = Tuple{Int32,Int32}[], rating = Tuple{Int32,Int32}[]),
        manga = (item = Tuple{Int32,Int32}[], rating = Tuple{Int32,Int32}[]),
    )
    labels = (
        anime = (item = Int32[], rating = Float32[]),
        manga = (item = Int32[], rating = Float32[]),
    )
    userids = (
        anime = (item = Int32[], rating = Int32[]),
        manga = (item = Int32[], rating = Int32[]),
    )
    for b::Int32 = 1:length(sentences)
        for i::Int32 = 1:seq_len
            # Randomly select ~15% of positions. A selected position only becomes a
            # target if it has an item and/or an explicit rating. Special tokens
            # (cls/pad/mask/empty) are numbered above the base vocab sizes, so the
            # `<=` / `<` comparisons test for a real item / rating.
            has_anime =
                (extract(tokens, :anime)[i, b] <= extract(vocab_sizes, :anime)) &&
                (extract(tokens, :status)[i, b] != get_status(:plan_to_watch))
            has_manga =
                (extract(tokens, :manga)[i, b] <= extract(vocab_sizes, :manga)) &&
                (extract(tokens, :status)[i, b] != get_status(:plan_to_watch))
            has_rating = extract(tokens, :rating)[i, b] < extract(vocab_sizes, :rating)
            # NOTE(review): `medium` stays unassigned when the position has no item
            if has_anime
                medium = :anime
            elseif has_manga
                medium = :manga
            end
            should_mask = rand(rng) < 0.15

            # Prepare to demean ratings by accumulating this user's weighted residual
            # (rating minus the item baseline). Only non-target ratings are used,
            # presumably so targets don't leak into their own baseline.
            if !should_mask && has_rating
                u = extract(tokens, :user)[i, b]
                a = extract(tokens, medium)[i, b]
                if u ∉ keys(demean[medium][:rating])
                    demean[medium][:rating][u] = 0
                    demean[medium][:count][u] = 0
                    demean[medium][:weight][u] = 0
                end
                # Item weight times a decay of (1 - timestamp).
                # NOTE(review): assumes cast_universal_timestamp maps into [0, 1] with
                # larger values being more recent; confirm.
                weight =
                    demean_item_weights[medium][a] * powerlawdecay(
                        1 .- cast_universal_timestamp(
                            extract(tokens, :timestamp)[i, b],
                            String(medium),
                        ),
                        explicit_baseline[medium]["λ"][5],
                    )
                demean[medium][:rating][u] +=
                    weight *
                    (extract(tokens, :rating)[i, b] - explicit_baseline[medium]["a"][a])
                demean[medium][:count][u] += 1
                demean[medium][:weight][u] += weight
            end

            # record tokens before we mask them out; positions that are not targets stop here
            if !(should_mask && (has_anime || has_manga || has_rating))
                continue
            end
            if has_anime || has_manga
                # item target: (position, sentence), plus (item id, running target index)
                push!(batch_positions[medium][:item], (i, b))
                push!(
                    item_positions[medium][:item],
                    (
                        extract(tokens, medium)[i, b],
                        Int32(length(item_positions[medium][:item]) + 1),
                    ),
                )
                push!(userids[medium][:item], extract(tokens, :user)[i, b])
            end
            if has_rating
                # rating target: the raw rating here, demeaned after the loop
                push!(batch_positions[medium][:rating], (i, b))
                push!(
                    item_positions[medium][:rating],
                    (
                        extract(tokens, medium)[i, b],
                        Int32(length(item_positions[medium][:rating]) + 1),
                    ),
                )
                push!(labels[medium][:rating], extract(tokens, :rating)[i, b])
                push!(labels[medium][:item], extract(tokens, medium)[i, b])
                push!(userids[medium][:rating], extract(tokens, :user)[i, b])
            end

            # BERT masking. Every word type other than the item, rating, timestamp and
            # position (and the anime/manga/user ids) is always replaced by its mask token.
            item_allowed_info =
                get_wordtype_index.([medium, :rating, :timestamp, :position])
            item_skip_info = get_wordtype_index.([:anime, :manga, :user])
            for j = 1:length(tokens)
                if j in item_allowed_info || j in item_skip_info
                    continue
                end
                tokens[j][i, b] = mask_tokens[j]
            end
            # Corruption rates:
            # - item/rating: mask 80%, keep 10%, random 10%.
            # - timestamp/position: mask 45%, keep 45%, random 10%.
            # At evaluation time (r fixed) the item, rating and timestamp are always
            # masked and the position is always kept.
            for j in item_allowed_info
                if j in get_wordtype_index.([medium, :rating])
                    cutoffs = (0.8, 0.9)
                    r = training ? rand(rng) : 0.0
                elseif j == get_wordtype_index(:timestamp)
                    cutoffs = (0.45, 0.9)
                    r = training ? rand(rng) : 0.0
                elseif j == get_wordtype_index(:position)
                    cutoffs = (0.45, 0.9)
                    r = training ? rand(rng) : 0.7
                else
                    @assert false
                end
                if r <= cutoffs[1]
                    tokens[j][i, b] = mask_tokens[j]
                elseif r <= cutoffs[2]
                    nothing
                else
                    # Replace with a random value: a uniform id for discrete word types,
                    # or a uniform value in [0, vocab size) for continuous ones.
                    if eltype(vocab_sizes[j]) == Int32
                        tokens[j][i, b] = rand(rng, 1:vocab_sizes[j])
                    elseif eltype(tokens[j]) == Float32
                        tokens[j][i, b] = rand(rng) * vocab_sizes[j]
                    else
                        @assert false
                    end
                end
            end
        end
    end

    # demean ratings: subtract user + item biases from rating inputs and targets
    for medium in [:anime, :manga]
        demean_explicit_ratings!(
            tokens = tokens,
            medium = medium,
            demean = demean[medium],
            explicit_baseline = explicit_baseline[medium],
            vocab_sizes = vocab_sizes,
            cls_tokens = cls_tokens,
            empty_tokens = empty_tokens,
            labels = labels[medium],
            userids = userids[medium],
        )
    end

    # Get weights. Each target is weighted 1 / (number of targets from the same user).
    # Uniform weights are used when training without user weighting.
    processed_weights = (
        anime = map(x -> uids_to_weights(x), userids[:anime]),
        manga = map(x -> uids_to_weights(x), userids[:manga]),
    )
    if training && !user_weighted_training
        processed_weights[:anime][:item] .= 1
        processed_weights[:anime][:rating] .= 1
        processed_weights[:manga][:item] .= 1
        processed_weights[:manga][:rating] .= 1
    end

    tokens, attention_mask, batch_positions, item_positions, labels, processed_weights
end;

In [None]:
"""
    demean_explicit_ratings!(; tokens, medium, demean, explicit_baseline, vocab_sizes,
                             cls_tokens, empty_tokens, labels, userids)

Subtract a user bias plus an item bias from every explicit rating of `medium`. This is
done in place, both for the rating tokens in `tokens` and for the targets in
`labels[:rating]`.

For user `u`, let `S`, `W` and `n` be the accumulated weighted residual sum, weight sum
and count from `demean`. Let `w = powerdecay(n, log(λ[3]))` and
`μ_user = mean(explicit_baseline["u"])`. The user bias is then the shrunk estimate

    (S * w + μ_user * λ[1]) / (W * w + λ[1])

Users with no statistics fall back to `μ_user`. Items with no entry in
`explicit_baseline["a"]` fall back to its mean. Positions whose `medium` token is the
empty token are skipped, as are ratings that are not below the rating vocab size.
`cls_tokens` is accepted but not used.
"""
function demean_explicit_ratings!(;
    tokens,
    medium,
    demean,
    explicit_baseline,
    vocab_sizes,
    cls_tokens,
    empty_tokens,
    labels,
    userids,
)
    λ = explicit_baseline["λ"]
    item_baselines = explicit_baseline["a"]
    μ_user = mean(explicit_baseline["u"])
    μ_item = mean(item_baselines)

    user_bias = Dict{Int32,Float32}()
    for (u, residual_sum) in demean[:rating]
        shrink = powerdecay(demean[:count][u], log(λ[3]))
        user_bias[u] =
            (residual_sum * shrink + μ_user * λ[1]) / (demean[:weight][u] * shrink + λ[1])
    end
    total_bias(u, a) = get(user_bias, u, μ_user) + get(item_baselines, a, μ_item)

    items = extract(tokens, medium)
    ratings = extract(tokens, :rating)
    users = extract(tokens, :user)
    empty_item = extract(empty_tokens, medium)
    rating_limit = extract(vocab_sizes, :rating)
    for b in axes(items, 2), i in axes(items, 1)
        items[i, b] == empty_item && continue
        ratings[i, b] < rating_limit || continue
        ratings[i, b] -= total_bias(users[i, b], items[i, b])
    end

    rating_labels = labels[:rating]
    for k in eachindex(rating_labels)
        rating_labels[k] -= total_bias(userids[:rating][k], labels[:item][k])
    end
end;

In [None]:
"""
    uids_to_weights(uids)

Return one `Float32` weight per entry of `uids`. The weight is `1 / (number of times
that id occurs in uids)`, so every distinct user contributes a total weight of 1.
"""
function uids_to_weights(uids)
    occurrences = Dict{eltype(uids),Int}()
    for uid in uids
        occurrences[uid] = get(occurrences, uid, 0) + 1
    end
    Float32[1 / occurrences[uid] for uid in uids]
end;

# Manipulate minibatches

In [None]:
"""
    shuffle_training_data(rng, sentences, max_sequence_length, max_document_length)

Pack `sentences` into training rows of fixed length.

1. Visit the sentences in a random order drawn from `rng`.
2. Cut each one to at most `max_document_length` tokens with `subset_sentence`
   (`keep_first = true`).
3. Concatenate all the resulting tokens.
4. Split the concatenation into consecutive rows of `max_sequence_length` tokens. Only
   the last row may be shorter.

A row can therefore hold the end of one history followed by the start of another.

Returns a vector with the same element type as `sentences`. It is empty when
`sentences` is empty.
"""
function shuffle_training_data(rng, sentences, max_sequence_length, max_document_length)
    S = eltype(sentences)
    isempty(sentences) && return Vector{S}()
    W = eltype(first(sentences))

    # concatenate all tokens, visiting sentences in a random order
    tokens = Vector{W}()
    for i in Random.shuffle(rng, 1:length(sentences))
        sentence = subset_sentence(
            sentences[i],
            max_document_length;
            recent = false,
            keep_first = true,
            rng = rng,
        )
        append!(tokens, sentence)
    end

    # partition tokens into rows of `max_sequence_length`; the final row may be shorter
    batched_sentences = Vector{S}()
    for row in Iterators.partition(tokens, max_sequence_length)
        push!(batched_sentences, Vector{W}(row))
    end
    batched_sentences
end;

In [None]:
"""
    device(x::NamedTuple)

Move every field of `x` to the GPU with `gpu`, keeping the field names.
"""
device(x::NamedTuple) = map(gpu, x)

"""
    device(batch)

Move a batch from `get_batch` to the GPU. The token arrays and the attention mask are
moved directly. The four NamedTuple parts (positions, item positions, labels, weights)
are moved field by field.
"""
function device(batch)
    tokens, attention_mask = batch[1], batch[2]
    (
        gpu.(tokens),
        gpu(attention_mask),
        device(batch[3]),
        device(batch[4]),
        device(batch[5]),
        device(batch[6]),
    )
end

# NOTE(review): this adds a method to a CUDA.jl function for a type CUDA does not own
# (type piracy). It makes freeing a `nothing` value a no-op.
CUDA.unsafe_free!(::Nothing) = nothing
device_free!(x) = CUDA.unsafe_free!(x)
device_free!(x::NamedTuple) = foreach(device_free!, x)

"""
    device_free!(batch::Tuple)

Eagerly release the GPU buffers held by a batch returned from `device`. Does nothing
when CUDA is not functional.
"""
function device_free!(batch::Tuple)
    CUDA.functional() || return
    CUDA.unsafe_free!.(batch[1])
    CUDA.unsafe_free!(batch[2])
    foreach(device_free!, batch[3:6])
end;

# Create model

In [None]:
"""
    create_bert(config)

Build the `TransformerModel` described by a model config (see `create_model_config`).

The model has three parts:
- A `Bert` encoder.
- A composite embedding: discrete embeddings for anime, manga, status and position;
  continuous embeddings for rating, timestamp and completion; then LayerNorm and dropout.
- Per-medium `Dense` heads for item and rating prediction, each with one output per item
  of that medium.

Layers are constructed in a fixed order, which matters if their initialisation draws
from a shared RNG.
"""
function create_bert(config)
    h = config["hidden_size"]
    vocab = config["vocab_sizes"]
    bert = Bert(;
        hidden_size = h,
        num_attention_heads = config["num_attention_heads"],
        intermediate_size = config["intermediate_size"],
        num_layers = config["num_hidden_layers"],
        activation_fn = config["hidden_act"],
        dropout = config["dropout"],
        attention_dropout = config["attention_dropout"],
    )

    discrete(field) = DiscreteEmbed(h, extract(vocab, field))
    emb = CompositeEmbedding(
        anime = discrete(:anime),
        manga = discrete(:manga),
        rating = ContinuousEmbed(h),
        timestamp = ContinuousEmbed(h),
        status = discrete(:status),
        completion = ContinuousEmbed(h),
        position = discrete(:position),
        postprocessor = Chain(LayerNorm(h), Dropout(config["dropout"])),
    )

    heads(field) = (
        item = Dense(h, extract(vocab, field)),
        rating = Dense(h, extract(vocab, field)),
    )
    clf = (anime = heads(:anime), manga = heads(:manga))

    TransformerModel(emb, bert, clf)
end;

# Loss metrics

In [None]:
"""
    masklm_item_loss(head, X, batch_pos, item_pos, w)

Weighted masked-item cross-entropy.

Applies `head` to the hidden states gathered at `batch_pos`, takes `logsoftmax`, and
gathers each target item's log-probability via `item_pos`. Returns the negative
`w`-weighted mean of those log-probabilities, or `0.0f0` when there are no targets.
"""
function masklm_item_loss(head, X, batch_pos, item_pos, w)
    isempty(item_pos) && return 0.0f0
    logprobs = logsoftmax(head(gather(X, batch_pos)))
    return -(w' * gather(logprobs, item_pos)) / sum(w)
end

"""
    masklm_rating_loss(head, X, batch_pos, item_pos, y, w)

Weighted masked-rating squared error.

Applies `head` to the hidden states gathered at `batch_pos` and gathers each target
item's prediction via `item_pos`. Returns the `w`-weighted mean of `(prediction - y)^2`,
or `0.0f0` when there are no targets.
"""
function masklm_rating_loss(head, X, batch_pos, item_pos, y, w)
    isempty(item_pos) && return 0.0f0
    preds = gather(head(gather(X, batch_pos)), item_pos)
    return (w' * (preds - y) .^ 2) / sum(w)
end

"""
    masklm_losses(model, batch)

Run `model` on a device batch from `get_batch` and return the four masked-LM losses
`(anime_item, anime_rating, manga_item, manga_rating)`. Item losses are cross-entropy;
rating losses are squared error. Each loss is `0.0f0` when the batch has no targets of
that kind.
"""
function masklm_losses(model, batch)
    tokens, attention_mask, batch_positions, item_positions, labels, weights = batch
    hidden = model.embed(
        anime = extract(tokens, :anime),
        manga = extract(tokens, :manga),
        rating = extract(tokens, :rating),
        timestamp = extract(tokens, :timestamp),
        status = extract(tokens, :status),
        completion = extract(tokens, :completion),
        position = extract(tokens, :position),
    )
    hidden = model.transformers(hidden, attention_mask)
    heads = model.classifier

    anime_item_loss = masklm_item_loss(
        heads.anime.item,
        hidden,
        batch_positions[:anime][:item],
        item_positions[:anime][:item],
        weights[:anime][:item],
    )
    anime_rating_loss = masklm_rating_loss(
        heads.anime.rating,
        hidden,
        batch_positions[:anime][:rating],
        item_positions[:anime][:rating],
        labels[:anime][:rating],
        weights[:anime][:rating],
    )
    manga_item_loss = masklm_item_loss(
        heads.manga.item,
        hidden,
        batch_positions[:manga][:item],
        item_positions[:manga][:item],
        weights[:manga][:item],
    )
    manga_rating_loss = masklm_rating_loss(
        heads.manga.rating,
        hidden,
        batch_positions[:manga][:rating],
        item_positions[:manga][:rating],
        labels[:manga][:rating],
        weights[:manga][:rating],
    )

    anime_item_loss, anime_rating_loss, manga_item_loss, manga_rating_loss
end;

In [None]:
"""
    evaluate_metrics(sentences, t::Trainer)

Compute the four masked-LM losses over `sentences`, building batches with
`get_batch(..., false, t)`.

Each batch's losses are scaled by that batch's total target weight. The result is
therefore a weight-averaged loss per objective, returned as a `Dict` keyed by metric
name. `sentences` is shuffled in place with `t.rng`.
"""
function evaluate_metrics(sentences, t::Trainer)
    metric_names = [
        "Anime Crossentropy Loss",
        "Anime Rating Loss",
        "Manga Crossentropy Loss",
        "Manga Rating Loss",
    ]
    weighted_loss = zeros(Float32, 4)
    total_weight = zeros(Float32, 4)
    Random.shuffle!(t.rng, sentences)
    batch_size = t.training_config["batch_size"]
    @showprogress for chunk in collect(Iterators.partition(sentences, batch_size))
        batch = device(get_batch(chunk, false, t))
        # weight order matches the loss order: anime item, anime rating, manga item, manga rating
        w = [sum(batch[6][m][k]) for m in (:anime, :manga) for k in (:item, :rating)]
        weighted_loss .+= masklm_losses(t.model, batch) .* w
        total_weight .+= w
        device_free!(batch)
    end
    Dict(metric_names .=> weighted_loss ./ total_weight)
end;

# Training

In [None]:
"""
    train_epoch!(sentences, t::Trainer)

Run one training pass over `sentences` and return the mean minibatch loss.

1. Shuffle `sentences` in place with `t.rng`.
2. Split them into roughly `sentence_chunks` (128) groups.
3. Pack each group into rows with `shuffle_training_data`.
4. Cut the rows into minibatches of `batch_size`.

For each minibatch:
- `schedule_learning_rate!` is called first (presumably to set the current learning
  rate and weight decay).
- The summed masked-LM loss is differentiated.
- Both the optimiser update and the weight-decay update are applied.
"""
function train_epoch!(sentences, t::Trainer)
    Random.shuffle!(t.rng, sentences)
    losses = Float32[]
    sentence_chunks = 128
    # `Iterators.partition` rejects a chunk size of 0, and `div` yields 0 whenever there
    # are fewer than `sentence_chunks` sentences.
    chunk_size = max(1, div(length(sentences), sentence_chunks))
    @showprogress for sentence_batch in Iterators.partition(sentences, chunk_size)
        sentence_batch = shuffle_training_data(
            t.rng,
            sentence_batch,
            t.training_config["max_sequence_length"],
            t.training_config["max_document_length"],
        )
        minibatches =
            collect(Iterators.partition(sentence_batch, t.training_config["batch_size"]))
        for minibatch in minibatches
            schedule_learning_rate!(
                t.opt,
                t.weightdecay,
                t.lr_schedule,
                t.training_config["weight_decay"],
            )
            batch = get_batch(minibatch, true, t) |> device
            tloss, grads = Flux.withgradient(t.model) do m
                sum(masklm_losses(m, batch))
            end
            push!(losses, tloss)
            # release the batch's GPU buffers before the update (presumably to cut peak memory)
            batch |> device_free!
            Flux.update!(t.opt, t.model, grads[1])
            Flux.update!(t.weightdecay, t.model, grads[1])
        end
    end
    mean(losses)
end;

In [None]:
"""
    checkpoint(sentences, t::Trainer, training_loss, epoch, name)

Evaluate the model on `sentences`, then write a checkpoint to
`"<name>/checkpoints/<epoch>"`. The checkpoint holds the model, optimiser states, LR
schedule, metrics, configs and RNG, all moved to the CPU where applicable.

"Validation Loss" is the sum of the four evaluation losses. It is computed before
"Training Loss" is added to the metrics.
"""
function checkpoint(sentences, t::Trainer, training_loss, epoch, name)
    @info "evaluating metrics"
    metrics = evaluate_metrics(sentences, t)
    metrics["Validation Loss"] = sum(values(metrics))
    metrics["Training Loss"] = training_loss
    state = Dict(
        "m" => cpu(t.model),
        "opt" => cpu(t.opt),
        "weightdecay" => cpu(t.weightdecay),
        "lr_schedule" => t.lr_schedule,
        "epoch" => epoch,
        "metrics" => metrics,
        "training_config" => t.training_config,
        "model_config" => t.model_config,
        "rng" => t.rng,
    )
    write_params(state, "$name/checkpoints/$epoch")
    @info "saving model after $epoch epochs with metrics $metrics"
end;

# Configuration

In [None]:
"""
    set_rngs(seed)

Create a `Random.Xoshiro(seed)` stream and use it to draw seeds for the global RNG.
When CUDA is functional, it also seeds `CUDA.default_rng()` and
`CUDA.CURAND.default_rng()`. Returns the Xoshiro stream.
"""
function set_rngs(seed)
    master = Random.Xoshiro(seed)
    Random.seed!(rand(master, UInt64))
    CUDA.functional() || return master
    Random.seed!(CUDA.default_rng(), rand(master, UInt64))
    Random.seed!(CUDA.CURAND.default_rng(), rand(master, UInt64))
    master
end;

In [None]:
"""
    get_sentences(rng, training_config)

Load all training sentences, shuffle them with `rng`, and split them into
`(training, validation)`. The first 99% go to training and the remainder to validation.
Both halves are fresh copies.
"""
function get_sentences(rng, training_config)
    sentences = get_training_data(
        training_config["media"],
        training_config["include_ptw_impressions"],
        training_config["cls_tokens"],
        training_config["empty_tokens"],
    )
    Random.shuffle!(rng, sentences)
    cutoff = round(Int, 0.99 * length(sentences))
    return sentences[1:cutoff], sentences[cutoff+1:end]
end;

"""
    set_epoch_size!(training_config, training_sentences)

Store `training_config["iters_per_epoch"]` and return it. The value is the number of
`max_sequence_length`-token rows needed to hold every training sentence, with each
sentence counted as at most `max_document_length` tokens.
"""
function set_epoch_size!(training_config, training_sentences)
    per_doc_cap = training_config["max_document_length"]
    num_tokens = sum(min.(length.(training_sentences), per_doc_cap))
    @info "Number of training sentences: $(length(training_sentences))"
    @info "Number of training tokens: $(num_tokens)"
    row_length = training_config["max_sequence_length"]
    training_config["iters_per_epoch"] = ceil(Int, num_tokens / row_length)
end;

In [None]:
"""
    create_training_config()

Build the training configuration `Dict`. It covers:
- Token vocab sizes, plus the special tokens (cls/pad/mask/empty) numbered just above
  them.
- Optimisation settings.
- Data settings, including one explicit-rating baseline per task for each medium.
- Model size.

Each baseline is tagged with its task index under the key `"task"`.
"""
function create_training_config()
    media = ["anime", "manga"]
    # NOTE(review): the field order is presumably (anime, manga, rating, timestamp,
    # status, completion, user, position), matching how `extract` is used; confirm.
    base_vocab_sizes = (
        Int32(num_items("anime")),
        Int32(num_items("manga")),
        Float32(11),
        Float32(1),
        Int32(5),
        Float32(1),
        Int32(maximum(num_users(m) for m in media)),
        Int32(512),
    )
    shifted(offset) = base_vocab_sizes .+ Int32(offset)
    baselines = Dict(
        Symbol(m) =>
            [read_params("$m/$task/ExplicitUserItemBiases") for task in ALL_TASKS] for
        m in ["anime", "manga"]
    )
    d = Dict(
        # tokenization
        "base_vocab_sizes" => base_vocab_sizes,
        "cls_tokens" => shifted(1),
        "pad_tokens" => shifted(2),
        "mask_tokens" => shifted(3),
        "empty_tokens" => shifted(4),
        "vocab_sizes" => shifted(4),
        # training
        "batch_size" => 16,
        "user_weighted_training" => false,
        "peak_learning_rate" => 3f-4,
        "weight_decay" => 1f-2,
        "num_epochs" => 64,
        # data
        "max_document_length" => Inf,
        "include_ptw_impressions" => true,
        "explicit_baseline" => baselines,
        "media" => media,
        # model
        "num_layers" => 4,
        "hidden_size" => 512,
        "max_sequence_length" => extract(base_vocab_sizes, :position),
    )
    @assert d["max_document_length"] >= d["max_sequence_length"]
    # we embed the baseline we're residualizing against in the cls token's status field
    for v in values(d["explicit_baseline"])
        @assert length(v) <= extract(d["base_vocab_sizes"], :status)
        for i in eachindex(v)
            v[i]["task"] = i
            @assert length(v[i]) == 4
            @assert length(v[i]["λ"]) == 5
        end
    end
    d
end;

In [None]:
"""
    create_model_config(training_config)

Derive the encoder hyperparameters from the training config:
- one attention head per 64 hidden units;
- a feed-forward width of 4× the hidden size;
- GELU activations;
- 0.1 dropout.
"""
function create_model_config(training_config)
    hidden = training_config["hidden_size"]
    # follows the recipe in Section 5 of [Well-Read Students Learn Better: On the
    # Importance of Pre-training Compact Models](https://arxiv.org/pdf/1908.08962.pdf)
    Dict(
        "num_hidden_layers" => training_config["num_layers"],
        "hidden_size" => hidden,
        "num_attention_heads" => Int(hidden / 64),
        "intermediate_size" => hidden * 4,
        "hidden_act" => gelu,
        "dropout" => 0.1,
        "attention_dropout" => 0.1,
        "max_sequence_length" => training_config["max_sequence_length"],
        "vocab_sizes" => training_config["vocab_sizes"],
    )
end;

In [None]:
"""
    load_from_checkpoint(training_config, outdir::String, epoch::Integer, reset_lr_schedule, rng)

Resume training from the checkpoint stored at `"<outdir>/<epoch>"`. Returns
`(trainer, epoch)`.

The stored training config replaces `training_config` when the two differ. The `rng`
argument is replaced by the RNG saved in the checkpoint. When `reset_lr_schedule` is
true, a fresh schedule is built from the caller-supplied `training_config` instead of
restoring the saved one.
"""
function load_from_checkpoint(
    training_config,
    outdir::String,
    epoch::Integer,
    reset_lr_schedule,
    rng,
)
    # Keep the caller's config: a reset LR schedule is built from it. This method
    # previously read the global `config` here, which is what this notebook passes in.
    requested_config = training_config
    params = read_params("$outdir/$epoch")
    model = params["m"] |> gpu
    if training_config != params["training_config"]
        @info "training config differs from stored params"
        training_config = params["training_config"]
    end
    model_config = params["model_config"]
    opt = params["opt"] |> gpu
    weightdecay = params["weightdecay"] |> gpu
    if reset_lr_schedule
        lr_schedule = get_lr_schedule(requested_config)
    else
        lr_schedule = params["lr_schedule"]
    end
    rng = params["rng"]
    Trainer(model, opt, weightdecay, lr_schedule, training_config, model_config, rng), epoch
end

"""
    load_from_checkpoint(training_config, ::Nothing, ::Nothing, reset_lr_schedule, rng)

Start training from scratch and return `(trainer, 0)`. The model, the Adam optimiser,
the decoupled weight decay (scaled by the peak learning rate) and the LR schedule are
all built from `training_config`. `reset_lr_schedule` is not used by this method.
"""
function load_from_checkpoint(training_config, ::Nothing, ::Nothing, reset_lr_schedule, rng)
    model_config = create_model_config(training_config)
    model = create_bert(model_config) |> gpu
    # Previously these read the global `config` rather than the argument.
    lr = Float32(training_config["peak_learning_rate"])
    wd = Float32(training_config["weight_decay"])
    opt = Optimisers.setup(Adam(lr, (0.9f0, 0.999f0)), model)
    weightdecay = Optimisers.setup(PureWeightDecay(lr * wd), model)
    initialize_weight_decay!(weightdecay, model)
    lr_schedule = get_lr_schedule(training_config)
    Trainer(model, opt, weightdecay, lr_schedule, training_config, model_config, rng), 0
end;

# Actually Train Model!

In [None]:
# To resume, set config_checkpoint to a directory containing "<epoch>" checkpoints
# (checkpoint() writes to "$name/checkpoints") and config_epoch to the epoch to load.
config_checkpoint = nothing
config_epoch = nothing
# only consulted when resuming: rebuild the LR schedule instead of restoring the saved one
reset_lr_schedule = true
# seeds the global (and CUDA) RNGs and returns a dedicated Xoshiro stream
config_rng = set_rngs(20221221)
config = create_training_config();

In [None]:
# shuffled 99% / 1% train / validation split
training_sentences, validation_sentences = get_sentences(config_rng, config);

In [None]:
# Record config["iters_per_epoch"] before building the trainer (presumably read by
# get_lr_schedule).
set_epoch_size!(config, training_sentences)
# Start from scratch when config_checkpoint and config_epoch are `nothing`; otherwise
# load "<config_checkpoint>/<config_epoch>".
trainer, starting_epoch = load_from_checkpoint(
    config,
    config_checkpoint,
    config_epoch,
    reset_lr_schedule,
    config_rng,
)
# log the epoch count and per-component parameter counts
@info "Num epochs: $(config["num_epochs"])"
@info "Training model with $(sum(length, Flux.params(trainer.model))) total parameters"
@info "Embedding parameters: $(sum(length, Flux.params(trainer.model.embed)))"
@info "Transformer parameters: $(sum(length, Flux.params(trainer.model.transformers)))"
@info "Classifier parameters: $(sum(length, Flux.params(trainer.model.classifier)))"

In [None]:
# placeholder; overwritten by every epoch
training_loss = Inf
for i = 1:trainer.training_config["num_epochs"]
    # force a full garbage collection between epochs
    GC.gc()
    training_loss = train_epoch!(training_sentences, trainer)
    # Checkpoints are numbered after `starting_epoch`.
    # NOTE(review): when resuming, this runs a further `num_epochs` epochs rather than
    # only the remaining ones; confirm this is intended.
    checkpoint(validation_sentences, trainer, training_loss, starting_epoch + i, name)
end;