# Pretrains a transformer encoder model on watch histories

In [None]:
# Run identifier; checkpoints are written under "$name/checkpoints/<epoch>".
name = "all/Transformer/v4";

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb")
@nbinclude("Reference/CUDA.ipynb")
@nbinclude("Reference/Include.ipynb");

In [None]:
import Flux
import Flux: cpu, gpu, LayerNorm, logsoftmax
import Optimisers
import Optimisers: Adam, OptimiserChain, WeightDecay
import ParameterSchedulers
import ParameterSchedulers: Sequence, Triangle, Shifted, Stateful
import Random
import StatsBase: mean, sample

# Structs

In [None]:
"""
    Trainer(model, opt, lr_schedule, training_config, model_config, rng)

Immutable bundle of the state a training run carries around. Every field is
untyped (`Any`):

- `model`: the network being trained
- `opt`: optimiser state for `model`
- `lr_schedule`: stateful learning-rate schedule, advanced once per update
- `training_config`: dictionary of tokenisation / training / data settings
- `model_config`: dictionary of architecture settings used to build `model`
- `rng`: random number generator used for shuffling and masking
"""
struct Trainer
    model
    opt
    lr_schedule
    training_config
    model_config
    rng
end;

# Tokenize training data

In [None]:
"""
    get_training_data(include_ptw, cls_tokens)

Load the training data of every task in `ALL_TASKS` (one task per loop
iteration, run under `Threads.@threads`) and return all of their sentences
concatenated into a single vector.

Only the first thread to claim the atomic flag is allowed to draw a progress
bar, so the bars of concurrent tasks do not interleave.
"""
function get_training_data(include_ptw, cls_tokens)
    num_tasks = length(ALL_TASKS)
    per_task = Vector{Vector{Vector{word_type}}}(undef, num_tasks)
    # 0 means "unclaimed"; the first thread to swap in its id owns the progress bar
    progress_owner = Threads.Atomic{Int}(0)
    Threads.@threads for task_idx = 1:num_tasks
        tid = Threads.threadid()
        Threads.atomic_cas!(progress_owner, 0, tid)
        task_data = get_training_data(
            ALL_TASKS[task_idx],
            include_ptw,
            cls_tokens;
            show_progress_bar = progress_owner[] == tid,
        )
        # keep only the sentences; the keys of the per-task mapping are dropped
        per_task[task_idx] = [task_data[key] for key in keys(task_data)]
    end
    vcat(per_task...)
end;

In [None]:
"""
    prune(sentences, invalid_word_fn)

Return a new (untyped) vector holding a filtered copy of every sentence:
words for which `invalid_word_fn(word)` is `true` are removed, and sentences
left with no words are dropped entirely. The input is not modified.
"""
function prune(sentences, invalid_word_fn)
    kept_sentences = []
    @showprogress for original in sentences
        remaining = filter(!invalid_word_fn, original)
        isempty(remaining) || push!(kept_sentences, remaining)
    end
    kept_sentences
end;

# Create minibatches

In [None]:
# Build one masked-language-model minibatch from `sentences`.
#
# Keyword arguments
#   max_seq_len            upper bound on the padded sequence length
#   vocab_sizes            per-field base vocabulary sizes (indexed like `tokens`)
#   pad_tokens, cls_tokens, mask_tokens
#                          per-field special token values
#   user_weighted_training when false, training batches use a weight of 1 per prediction
#   rng                    random number generator for tokenisation and masking
#   training               true for training batches, false for evaluation batches
#
# Returns the tuple
#   (tokens, attention_mask, batch_positions, item_positions, labels, processed_weights)
# where the last four entries are 2-tuples: slot 1 is for the item prediction
# task, slot 2 for the rating prediction task.
#
# Field indices used below, as consumed by `masklm_losses` and this function:
# 1 = item, 2 = rating, 3 = timestamp, 4 = status, 5 = completion,
# 6 = user id, 7 = position.
# NOTE(review): `tokens` is indexed as tokens[field][position, sequence];
# presumably a tuple of seq_len x batch matrices from `get_token_ids` -- confirm.
function get_batch(
    sentences;
    max_seq_len,
    vocab_sizes,
    pad_tokens,
    cls_tokens,
    mask_tokens,
    user_weighted_training,
    rng,
    training,
)
    # dynamically pad to the largest sequence length
    seq_len = min(maximum(length.(sentences)), max_seq_len)

    # get tokenized sentences
    tokens =
        get_token_ids(sentences, seq_len, vocab_sizes[7], pad_tokens, cls_tokens; rng = rng)

    # don't attend across sequences: positions i and j of sequence k may attend
    # to each other only if they carry the same user id (field 6) and neither
    # one is padding in the item field
    attention_mask = zeros(Bool, (seq_len, seq_len, length(sentences)))
    Threads.@threads for i = 1:seq_len
        for j = 1:seq_len
            for k = 1:length(sentences)
                if (tokens[6][i, k] == tokens[6][j, k]) &&
                   (tokens[1][i, k] != pad_tokens[1]) &&
                   (tokens[1][j, k] != pad_tokens[1])
                    attention_mask[i, j, k] = 1
                end
            end
        end
    end

    # apply BERT masking
    # batch_positions: (position, sequence) of every masked token
    # item_positions:  (item id, running index of the masked token within its task)
    # labels:          rating targets (no explicit labels are stored for the item task)
    # userids:         user id of every masked token, used for the loss weights
    batch_positions = (Tuple{Int32,Int32}[], Tuple{Int32,Int32}[])
    item_positions = (Tuple{Int32,Int32}[], Tuple{Int32,Int32}[])
    labels = (nothing, Float32[])
    userids = (Int32[], Int32[])
    for b::Int32 = 1:length(sentences)
        for i::Int32 = 1:seq_len
            # a rating value below vocab_sizes[2] is treated as an explicit
            # rating; only those positions contribute to the rating task.
            # This is read before the rating token is (possibly) masked below.
            if tokens[2][i, b] .< vocab_sizes[2]
                explicit_rating = true
            else
                explicit_rating = false
            end

            # randomly mask 15% of tokens; positions whose item token is above
            # vocab_sizes[1] (i.e. not a regular item id) are never selected
            if (rand(rng) < 0.15) && (tokens[1][i, b] .<= vocab_sizes[1])
                push!(batch_positions[1], (i, b))
                push!(
                    item_positions[1],
                    (tokens[1][i, b], Int32(length(item_positions[1]) + 1)),
                )
                push!(userids[1], tokens[6][i, b])
                if explicit_rating
                    push!(batch_positions[2], (i, b))
                    push!(
                        item_positions[2],
                        (tokens[1][i, b], Int32(length(item_positions[2]) + 1)),
                    )
                    push!(labels[2], tokens[2][i, b])
                    push!(userids[2], tokens[6][i, b])
                end

                # fields outside this list are always replaced by their mask
                # token; fields inside it get the randomised treatment below
                item_allowed_info = [1, 2, 3, 7]
                for j = 1:length(tokens)
                    if j in item_allowed_info
                        continue
                    end
                    tokens[j][i, b] = mask_tokens[j]
                end
                for j in item_allowed_info
                    # r <= cutoffs[1]: mask; r <= cutoffs[2]: keep the original
                    # value; otherwise: replace with a random value.
                    # When not training, r is fixed, so fields 1-3 are always
                    # masked (r = 0.0) and field 7 is always kept (r = 0.7).
                    if j in [1, 2]
                        cutoffs = (0.8, 0.9)
                        r = training ? rand(rng) : 0.0
                    elseif j == 3
                        cutoffs = (0.45, 0.9)
                        r = training ? rand(rng) : 0.0
                    elseif j == 7
                        cutoffs = (0.45, 0.9)
                        r = training ? rand(rng) : 0.7
                    else
                        @assert false
                    end
                    if r <= cutoffs[1]
                        tokens[j][i, b] = mask_tokens[j]
                    elseif r <= cutoffs[2]
                        nothing
                    else
                        # NOTE(review): the first branch inspects the type of
                        # vocab_sizes[j] while the second inspects the element
                        # type of tokens[j]; the asymmetry looks unintentional.
                        if eltype(vocab_sizes[j]) == Int32
                            tokens[j][i, b] = rand(rng, 1:vocab_sizes[j])
                        elseif eltype(tokens[j]) == Float32
                            tokens[j][i, b] = rand(rng) * vocab_sizes[j]
                        else
                            @assert false
                        end
                    end
                end
            end
        end
    end
    # weight each masked token by 1 / (number of masked tokens of the same user
    # in this batch); plain training (not user weighted) overrides this with 1.
    # Evaluation batches always keep the per-user weights.
    processed_weights = map(x -> uids_to_weights(x), userids)
    if training && !user_weighted_training
        processed_weights[1] .= 1
        processed_weights[2] .= 1
    end

    tokens, attention_mask, batch_positions, item_positions, labels, processed_weights
end;

In [None]:
"""
    uids_to_weights(uids)

Return a `Vector{Float32}` with one weight per entry of `uids`, where each
weight is `1 / (number of times that id occurs in uids)`. The weights of any
single id therefore sum to one.
"""
function uids_to_weights(uids)
    occurrences = Dict{eltype(uids),Int}()
    for uid in uids
        occurrences[uid] = get(occurrences, uid, 0) + 1
    end
    Float32[1 / occurrences[uid] for uid in uids]
end;

In [None]:
"""
    get_batch(sentences, training::Bool, t::Trainer)

Convenience method: build a batch for `sentences` using the tokenisation
settings stored in `t.training_config` and the trainer's random number
generator.
"""
function get_batch(sentences, training::Bool, t::Trainer)
    cfg = t.training_config
    get_batch(
        sentences;
        max_seq_len = cfg["max_sequence_length"],
        vocab_sizes = cfg["base_vocab_sizes"],
        pad_tokens = cfg["pad_tokens"],
        cls_tokens = cfg["cls_tokens"],
        mask_tokens = cfg["mask_tokens"],
        user_weighted_training = cfg["user_weighted_training"],
        rng = t.rng,
        training = training,
    )
end;

In [None]:
"""
    shuffle_training_data(rng, sentences, line_by_line, max_sequence_length,
                          max_document_length)

Return the training sentences for one epoch in a fresh random order.

- `line_by_line == true`: the sentences are returned as-is, only permuted.
- `line_by_line == false`: every sentence is first cut down with
  `subset_sentence(sentence, max_document_length; recent = false, rng = rng)`,
  all resulting words are concatenated in the shuffled order, and the stream
  is re-cut into chunks of exactly `max_sequence_length` words (the final
  chunk may be shorter). Chunks may therefore span several original sentences.

An empty `sentences` yields an empty result in both modes.
"""
function shuffle_training_data(
    rng,
    sentences,
    line_by_line,
    max_sequence_length,
    max_document_length,
)
    order = Random.shuffle(rng, 1:length(sentences))
    if line_by_line
        return sentences[order]
    end
    S = eltype(sentences)
    # fall back to the declared element type so that empty input does not index
    # into `sentences`
    W = isempty(sentences) ? eltype(S) : eltype(first(sentences))

    # concatenate all tokens
    tokens = Vector{W}()
    @showprogress for i in order
        append!(
            tokens,
            subset_sentence(sentences[i], max_document_length; recent = false, rng = rng),
        )
    end

    # partition tokens into minibatches
    batched_sentences = Vector{S}()
    sentence = Vector{W}()
    @showprogress for token in tokens
        push!(sentence, token)
        if length(sentence) == max_sequence_length
            push!(batched_sentences, sentence)
            sentence = Vector{W}()
        end
    end
    # keep the trailing, shorter-than-full chunk
    if !isempty(sentence)
        push!(batched_sentences, sentence)
    end
    batched_sentences
end;

In [None]:
"""
    device(batch)

Move a batch produced by `get_batch` to the GPU via `gpu`. The sixth token
field is not transferred; its slot is filled with `nothing`. The returned
tuple has the same six-entry layout as the input.
"""
function device(batch)
    tokens, attention_mask, batch_positions, item_positions, labels, weights = batch
    # field 6 stays off the device; the other six token fields are transferred
    device_tokens = ntuple(j -> j == 6 ? nothing : gpu(tokens[j]), 7)
    (
        device_tokens,
        gpu(attention_mask),
        gpu.(batch_positions),
        gpu.(item_positions),
        gpu.(labels),
        gpu.(weights),
    )
end

# Lets `device_free!` broadcast `CUDA.unsafe_free!` over tuples that contain
# `nothing` placeholders (e.g. the token slot that `device` leaves empty).
# NOTE(review): this adds a method to a function owned by CUDA for a Base type
# (type piracy); a local no-op helper would be a safer alternative.
CUDA.unsafe_free!(::Nothing) = nothing

"""
    device_free!(batch)

Eagerly release the GPU buffers held by a batch returned from `device`.
Does nothing when CUDA is not functional. Of the labels entry only the second
element (the rating targets) is freed.
"""
function device_free!(batch)
    CUDA.functional() || return
    tokens, attention_mask, batch_positions, item_positions, labels, weights = batch
    CUDA.unsafe_free!.(tokens)
    CUDA.unsafe_free!(attention_mask)
    CUDA.unsafe_free!.(batch_positions)
    CUDA.unsafe_free!.(item_positions)
    CUDA.unsafe_free!(labels[2])
    CUDA.unsafe_free!.(weights)
end;

# Create model

In [None]:
"""
    create_bert(config)

Assemble the full model from a model configuration dictionary: a composite
input embedding, a `Bert` encoder stack and two prediction heads (item and
rating), wrapped in a `TransformerModel`.

Layers are constructed in a fixed order (encoder, embeddings, item head,
rating head) so that parameter initialisation is reproducible for a given
random seed.
"""
function create_bert(config)
    hidden = config["hidden_size"]
    vocab = config["vocab_sizes"]
    act = config["hidden_act"]
    drop = config["dropout"]

    encoder = Bert(;
        hidden_size = hidden,
        num_attention_heads = config["num_attention_heads"],
        intermediate_size = config["intermediate_size"],
        num_layers = config["num_hidden_layers"],
        activation_fn = act,
        dropout = drop,
        attention_dropout = config["attention_dropout"],
    )

    # keyword arguments are evaluated left to right, i.e. in the order listed
    embedding = CompositeEmbedding(
        item = DiscreteEmbed(hidden, vocab[1]),
        rating = ContinuousEmbed(hidden),
        timestamp = ContinuousEmbed(hidden),
        status = DiscreteEmbed(hidden, vocab[4]),
        completion = ContinuousEmbed(hidden),
        position = DiscreteEmbed(hidden, vocab[7]),
        postprocessor = Chain(LayerNorm(hidden), Dropout(drop)),
    )

    # the item head only holds a transform and a bias; `masklm_losses` projects
    # onto the item vocabulary with the (transposed) item embedding matrix
    item_head = (
        transform = Chain(Dense(hidden, hidden, act), LayerNorm(hidden)),
        output_bias = BiasLayer(vocab[1]),
    )
    # the rating head emits one value per entry of the item vocabulary
    rating_head = (
        transform = Chain(
            Dense(hidden, hidden, act),
            LayerNorm(hidden),
            Dense(hidden, vocab[1]),
        ),
    )

    TransformerModel(embedding, encoder, (item = item_head, rating = rating_head))
end;

# Loss metrics

In [None]:
"""
    masklm_losses(model, batch)

Run `model` on a batch from `get_batch` and return `(item_loss, rating_loss)`:

- `item_loss`: weighted mean negative log-likelihood of the true item at every
  masked position, with logits computed against the transposed item embedding
  matrix plus the item head's output bias.
- `rating_loss`: weighted mean squared error between the predicted and the
  recorded rating at every masked position that has a rating label.

A task with no masked positions in the batch contributes `0.0f0`.
"""
function masklm_losses(model, batch)
    tokens, attention_mask, batch_positions, item_positions, labels, weights = batch
    embedded = model.embed(
        item = tokens[1],
        rating = tokens[2],
        timestamp = tokens[3],
        status = tokens[4],
        completion = tokens[5],
        position = tokens[7],
    )
    hidden = model.transformers(embedded, attention_mask)

    item_loss = if length(item_positions[1]) > 0
        item_head = model.classifier.item
        masked_hidden = item_head.transform(gather(hidden, batch_positions[1]))
        item_logits =
            transpose(model.embed.embeddings.item.embedding) * masked_hidden .+
            item_head.output_bias.b
        item_logprobs = logsoftmax(item_logits)
        -sum(gather(item_logprobs, item_positions[1]) .* weights[1]) / sum(weights[1])
    else
        0.0f0
    end

    rating_loss = if length(item_positions[2]) > 0
        rating_scores = model.classifier.rating.transform(gather(hidden, batch_positions[2]))
        residuals = gather(rating_scores, item_positions[2]) - labels[2]
        sum(residuals .^ 2 .* weights[2]) / sum(weights[2])
    else
        0.0f0
    end

    item_loss, rating_loss
end;

In [None]:
"""
    evaluate_metrics(sentences, t::Trainer)

Compute the item cross-entropy and rating MSE of `t.model` over `sentences`,
processed in evaluation-mode minibatches of `minibatch_size` sentences. Each
minibatch loss is weighted by the sum of its per-prediction weights.

Note that `sentences` is shuffled in place with `t.rng`.

Returns a dictionary with the keys `"Item Crossentropy Loss"` and
`"Rating MSE Loss"`.
"""
function evaluate_metrics(sentences, t::Trainer)
    loss_sums = [0.0, 0.0]
    weight_sums = [0.0, 0.0]
    Random.shuffle!(t.rng, sentences)
    chunk_size = t.training_config["minibatch_size"]
    @showprogress for chunk in collect(Iterators.partition(sentences, chunk_size))
        batch = device(get_batch(chunk, false, t))
        batch_weights = sum.(batch[6])
        weight_sums .+= batch_weights
        loss_sums .+= masklm_losses(t.model, batch) .* batch_weights
        device_free!(batch)
    end
    item_loss, rating_loss = loss_sums ./ weight_sums
    Dict("Item Crossentropy Loss" => item_loss, "Rating MSE Loss" => rating_loss)
end;

# Training

In [None]:
"""
    schedule_learning_rate!(opt, lr_schedule)

Advance `lr_schedule` by one step and write the new learning rate into the
optimiser state `opt`. The `gamma` hyperparameter is kept at 1% of the
learning rate, matching the ratio used when the optimiser is first set up.
"""
function schedule_learning_rate!(opt, lr_schedule)
    step_lr = Float32(ParameterSchedulers.next!(lr_schedule))
    step_gamma = step_lr * 1f-2
    Optimisers.adjust!(opt; eta = step_lr, gamma = step_gamma)
end;

In [None]:
"""
    train_epoch!(sentences, t::Trainer)

Train `t.model` for one pass over `sentences` and return the mean of the
per-minibatch losses.

The shuffled data is split into batches of `batch_size` sentences. Each batch
is processed as one or more minibatches of `minibatch_size` sentences whose
gradients are summed and then divided by the number of minibatches (gradient
accumulation) before a single optimiser update. The learning-rate schedule is
advanced once per batch.
"""
function train_epoch!(sentences, t::Trainer)
    cfg = t.training_config
    shuffled = shuffle_training_data(
        t.rng,
        sentences,
        cfg["line_by_line"],
        cfg["max_sequence_length"],
        cfg["max_document_length"],
    )
    step_losses = Any[]
    @showprogress for step_sentences in
                      collect(Iterators.partition(shuffled, cfg["batch_size"]))
        microbatches =
            collect(Iterators.partition(step_sentences, cfg["minibatch_size"]))
        schedule_learning_rate!(t.opt, t.lr_schedule)
        accumulated = nothing
        for microbatch in microbatches
            batch = device(get_batch(microbatch, true, t))
            loss, grads = Flux.withgradient(m -> sum(masklm_losses(m, batch)), t.model)
            accumulated = tuplesum(accumulated, first(grads))
            push!(step_losses, loss)
            device_free!(batch)
        end
        averaged = tupledivide(accumulated, length(microbatches))
        Flux.update!(t.opt, t.model, averaged)
    end
    mean(step_losses)
end;

In [None]:
"""
    checkpoint(sentences, t::Trainer, training_loss, epoch, name)

Evaluate the trainer on `sentences`, add `training_loss` to the resulting
metrics and persist the complete trainer state (model and optimiser moved to
the CPU, schedule, configs, rng, epoch, metrics) under
`"\$name/checkpoints/\$epoch"`.
"""
function checkpoint(sentences, t::Trainer, training_loss, epoch, name)
    @info "evaluating metrics"
    metrics = evaluate_metrics(sentences, t)
    metrics["Training Loss"] = training_loss
    snapshot = Dict(
        "m" => cpu(t.model),
        "opt" => cpu(t.opt),
        "lr_schedule" => t.lr_schedule,
        "epoch" => epoch,
        "metrics" => metrics,
        "training_config" => t.training_config,
        "model_config" => t.model_config,
        "rng" => t.rng,
    )
    write_params(snapshot, "$name/checkpoints/$epoch")
    @info "saving model after $epoch epochs with metrics $metrics"
end;

# Configuration

In [None]:
"""
    set_rngs(seed)

Create a `Xoshiro` generator from `seed`, use it to derive seeds for the
global default RNG and, when CUDA is functional, for both CUDA default RNGs,
and return the generator.
"""
function set_rngs(seed)
    master = Random.Xoshiro(seed)
    Random.seed!(rand(master, UInt64))
    CUDA.functional() || return master
    # draw order matters for reproducibility: native CUDA RNG first, then CURAND
    native_seed = rand(master, UInt64)
    Random.seed!(CUDA.default_rng(), native_seed)
    curand_seed = rand(master, UInt64)
    Random.seed!(CUDA.CURAND.default_rng(), curand_seed)
    master
end;

In [None]:
"""
    get_sentences(rng, training_config)

Load all sentences, shuffle them with `rng` and split them 99% / 1% into a
training set and a validation set. The validation sentences are passed through
`prune` with `is_ptw`.

Side effect: stores `"iters_per_epoch"` in `training_config` -- the number of
training sentences when training line by line, otherwise the number of
`max_sequence_length`-sized chunks needed for all training words after capping
each sentence at `max_document_length` words.

Returns `(training, validation)`.
"""
function get_sentences(rng, training_config)
    sentences = get_training_data(
        training_config["include_ptw_impressions"],
        training_config["cls_tokens"],
    )
    Random.shuffle!(rng, sentences)
    num_training = round(Int, 0.99 * length(sentences))
    training = sentences[1:num_training]
    validation = prune(sentences[num_training+1:end], is_ptw)
    training_config["iters_per_epoch"] = if training_config["line_by_line"]
        length(training)
    else
        capped_lengths = min.(length.(training), training_config["max_document_length"])
        ceil(Int, sum(capped_lengths) / training_config["max_sequence_length"])
    end
    training, validation
end;

In [None]:
"""
    create_training_config()

Return the dictionary of tokenisation, training, data and model settings for
this run. The special tokens (cls, pad, mask, sep) of every field are placed
directly after that field's base vocabulary, so the full vocabulary is the
base size plus four.
"""
function create_training_config()
    # one entry per token field; Int32 entries are discrete vocabularies,
    # Float32 entries are upper bounds of continuous fields
    base_vocab_sizes = (
        Int32(num_items()),
        Float32(11),
        Float32(1),
        Int32(5),
        Float32(1),
        Int32(num_users()),
        Int32(512), # todo increase
    )
    with_offset(k) = base_vocab_sizes .+ Int32(k)

    d = Dict(
        # tokenization
        "base_vocab_sizes" => base_vocab_sizes,
        "cls_tokens" => with_offset(1),
        "pad_tokens" => with_offset(2),
        "mask_tokens" => with_offset(3),
        "sep_tokens" => with_offset(4),
        "vocab_sizes" => with_offset(4),
        # training
        "minibatch_size" => 16, # TODO
        "batch_size" => 16, # TODO
        "num_epochs" => 8, # TODO
        "user_weighted_training" => false,
        "peak_learning_rate" => 1f-3,
        # data
        "line_by_line" => false,
        "max_document_length" => 512 * 2, # TODO
        "include_ptw_impressions" => false,
        # model
        "num_layers" => 4, # TODO
        "hidden_size" => 512, # TODO
        "max_sequence_length" => 512,
    )

    # sanity checks: documents must be able to fill a sequence, and a batch
    # must split into a whole number of minibatches
    @assert d["max_document_length"] >= d["max_sequence_length"]
    @assert d["batch_size"] >= d["minibatch_size"]
    @assert (d["batch_size"] % d["minibatch_size"]) == 0
    d
end;

In [None]:
"""
    LinearWarmupSchedule(lr, iters, warmup_perc)

Build a stateful schedule over `iters` steps made of two triangle-wave
segments with peak `lr`: the first `round(iters * warmup_perc)` steps use the
start of a triangle whose period is twice the warmup length, and the remaining
steps use a triangle whose period is twice the remaining length, shifted by
the remaining length.
"""
function LinearWarmupSchedule(lr, iters, warmup_perc)
    warmup_steps = round(Int, iters * warmup_perc)
    decay_steps = iters - warmup_steps
    warmup = Triangle(λ0 = 0.0f0, λ1 = lr, period = 2 * warmup_steps)
    decay = Shifted(Triangle(λ0 = 0.0f0, λ1 = lr, period = 2 * decay_steps), decay_steps)
    Stateful(Sequence(warmup => warmup_steps, decay => decay_steps))
end

"""
    get_lr_schedule(config)

Create the learning-rate schedule for a whole run: a `LinearWarmupSchedule`
peaking at `config["peak_learning_rate"]`, spanning
`num_epochs * iters_per_epoch / batch_size` optimiser updates (rounded), with
6% of them used for warmup.
"""
function get_lr_schedule(config)
    peak = Float32(config["peak_learning_rate"])
    total_updates =
        round(Int, config["num_epochs"] * config["iters_per_epoch"] / config["batch_size"])
    LinearWarmupSchedule(peak, total_updates, 0.06)
end;

In [None]:
"""
    create_model_config(training_config)

Derive the architecture settings from the training configuration, following
the recipe in Section 5 of [Well-Read Students Learn Better: On the Importance
of Pre-training Compact Models](https://arxiv.org/pdf/1908.08962.pdf): one
attention head per 64 hidden units and a feed-forward width of four times the
hidden size.

Throws an `InexactError` if `hidden_size` is not a multiple of 64.
"""
function create_model_config(training_config)
    hidden = training_config["hidden_size"]
    Dict(
        # sizes copied from the training configuration
        "hidden_size" => hidden,
        "num_hidden_layers" => training_config["num_layers"],
        "max_sequence_length" => training_config["max_sequence_length"],
        "vocab_sizes" => training_config["vocab_sizes"],
        # sizes derived from the hidden size
        "num_attention_heads" => Int(hidden / 64),
        "intermediate_size" => hidden * 4,
        # fixed hyperparameters
        "hidden_act" => gelu,
        "dropout" => 0.1,
        "attention_dropout" => 0.1,
    )
end;

In [None]:
"""
    load_from_checkpoint(training_config, outdir::String, epoch::Integer,
                         reset_lr_schedule, rng)

Restore a `Trainer` from the parameters stored at `"\$outdir/\$epoch"` and
return `(trainer, epoch)`.

- The stored training configuration wins: if it differs from the
  `training_config` argument, a message is logged and the stored one is used.
- `reset_lr_schedule == true` builds a fresh schedule from the training
  configuration actually in effect; otherwise the stored schedule is resumed.
- The `rng` argument is replaced by the generator saved in the checkpoint.
"""
function load_from_checkpoint(
    training_config,
    outdir::String,
    epoch::Integer,
    reset_lr_schedule,
    rng,
)
    params = read_params("$outdir/$epoch")
    model = params["m"] |> gpu
    if training_config != params["training_config"]
        @info "training config differs from stored params"
        training_config = params["training_config"]
    end
    model_config = params["model_config"]
    opt = params["opt"] |> gpu
    if reset_lr_schedule
        # use the config in effect for this trainer rather than a global
        # variable, so that the schedule matches the trainer's own settings
        lr_schedule = get_lr_schedule(training_config)
    else
        lr_schedule = params["lr_schedule"]
    end
    rng = params["rng"]
    Trainer(model, opt, lr_schedule, training_config, model_config, rng), epoch
end

"""
    load_from_checkpoint(training_config, ::Nothing, ::Nothing,
                         reset_lr_schedule, rng)

Start from scratch (no checkpoint given): build a new model on the GPU, an
Adam optimiser with weight decay set to 1% of the peak learning rate, and a
fresh learning-rate schedule, all derived from `training_config`.
`reset_lr_schedule` is ignored. Returns `(trainer, 0)`.
"""
function load_from_checkpoint(training_config, ::Nothing, ::Nothing, reset_lr_schedule, rng)
    model_config = create_model_config(training_config)
    model = create_bert(model_config) |> gpu
    # read settings from the argument, not from a global variable
    lr = Float32(training_config["peak_learning_rate"])
    opt = Optimisers.setup(
        OptimiserChain(Adam(lr, (0.9f0, 0.999f0)), WeightDecay(lr * 1f-2)),
        model,
    )
    lr_schedule = get_lr_schedule(training_config)
    Trainer(model, opt, lr_schedule, training_config, model_config, rng), 0
end;

# Actually Train Model!

In [None]:
# Run settings. With both `config_checkpoint` and `config_epoch` set to
# `nothing`, `load_from_checkpoint` builds a fresh model instead of restoring one.
config_checkpoint = nothing
config_epoch = nothing
# only consulted when restoring from a checkpoint
reset_lr_schedule = true
# seeds the global (and, if available, CUDA) RNGs and returns the run's own RNG
config_rng = set_rngs(20221221)
config = create_training_config();

In [None]:
# Load and split the data; this also stores "iters_per_epoch" in `config`.
training_sentences, validation_sentences = get_sentences(config_rng, config);

In [None]:
# Build (or restore) the trainer; `starting_epoch` is 0 for a fresh run.
trainer, starting_epoch = load_from_checkpoint(
    config,
    config_checkpoint,
    config_epoch,
    reset_lr_schedule,
    config_rng,
);

In [None]:
# Log parameter counts for the whole model and for each of its three parts.
@info "Training model with $(sum(length, Flux.params(trainer.model))) total parameters"
@info "Embedding parameters: $(sum(length, Flux.params(trainer.model.embed)))"
@info "Transformer parameters: $(sum(length, Flux.params(trainer.model.transformers)))"
@info "Classifier parameters: $(sum(length, Flux.params(trainer.model.classifier)))"

In [None]:
# Record baseline validation metrics before any training; the training loss is
# not known yet, so `Inf` is stored in its place.
checkpoint(validation_sentences, trainer, Inf, starting_epoch, name)

In [None]:
# Train for the configured number of epochs, checkpointing after every epoch.
# Epoch numbers continue from `starting_epoch`.
for epoch_offset in 1:trainer.training_config["num_epochs"]
    epoch_loss = train_epoch!(training_sentences, trainer)
    epoch_number = starting_epoch + epoch_offset
    checkpoint(validation_sentences, trainer, epoch_loss, epoch_number, name)
end;