# Pretrains a transformer encoder model on watch histories

In [None]:
# Media type this run operates on; interpolated into `name` and passed to the data
# helpers. NOTE(review): empty here — presumably overridden by a parameter cell
# before running; confirm.
medium = ""

In [None]:
# Output prefix for this model; `checkpoint` writes to "$name/checkpoints/<epoch>".
name = "$medium/all/Transformer/v3";

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb")
@nbinclude("Reference/CUDA.ipynb")
@nbinclude("Reference/Include.ipynb");

In [None]:
import Flux
import Flux: cpu, gpu, LayerNorm, logsoftmax
import Random
import StatsBase: mean, sample

# Structs

In [None]:
"""
    Trainer(model, opt, lr_schedule, training_config, model_config, rng)

Everything needed to train, evaluate and checkpoint the model, bundled together.
All fields are untyped.

- `model`: the network being trained
- `opt`: optimiser state passed to `Flux.update!`
- `lr_schedule`: schedule consumed by `schedule_learning_rate!`
- `training_config`: `Dict` of training/tokenization settings
- `model_config`: `Dict` of architecture settings
- `rng`: random number generator used for shuffling and masking
"""
struct Trainer
    model
    opt
    lr_schedule
    training_config
    model_config
    rng
end;

# Tokenize training data

In [None]:
"""
    get_training_data(include_ptw, cls_tokens)

Collect the training sentences of `medium` for every task in `ALL_TASKS` and
concatenate them into a single vector: tasks in `ALL_TASKS` order, and within a task
in the iteration order of the per-task container's keys.
"""
function get_training_data(include_ptw, cls_tokens)
    per_task = Vector{Vector{Vector{wordtype}}}()
    for task in ALL_TASKS
        task_data = get_training_data(task, medium, include_ptw, cls_tokens)
        # push! converts the comprehension to Vector{Vector{wordtype}}
        push!(per_task, [task_data[key] for key in keys(task_data)])
    end
    vcat(per_task...)
end;

In [None]:
"""
    prune(sentences, invalid_word_fn)

Remove every word for which `invalid_word_fn(word)` is true, then discard sentences
left empty. Returns an untyped (`Vector{Any}`) list of the surviving sentences; each
survivor is a `Vector` with its original sentence's element type.
"""
function prune(sentences, invalid_word_fn)
    kept = Any[]
    @showprogress for sentence in sentences
        survivors = eltype(sentence)[word for word in sentence if !invalid_word_fn(word)]
        isempty(survivors) || push!(kept, survivors)
    end
    kept
end;

# Create minibatches

In [None]:
"""
    get_batch(sentences, training::Bool, t::Trainer)

Build a minibatch from `sentences` using the tokenization and masking settings stored
in `t.training_config`, drawing randomness from `t.rng`.
"""
function get_batch(sentences, training::Bool, t::Trainer)
    cfg = t.training_config
    return get_batch(
        sentences;
        max_seq_len = cfg["max_sequence_length"],
        vocab_sizes = cfg["base_vocab_sizes"],
        pad_tokens = cfg["pad_tokens"],
        cls_tokens = cfg["cls_tokens"],
        mask_tokens = cfg["mask_tokens"],
        user_weighted_training = cfg["user_weighted_training"],
        explicit_baseline = cfg["explicit_baseline"],
        rng = t.rng,
        training,
    )
end;

In [None]:
# TODO break up further (attention mask and corruption are already helpers)
# TODO support causal training
"""
    get_batch(sentences; max_seq_len, vocab_sizes, pad_tokens, cls_tokens, mask_tokens,
              user_weighted_training, explicit_baseline, rng, training)

Tokenize `sentences` into one padded minibatch and apply BERT-style masking.

`vocab_sizes` are the base vocabulary sizes (before special tokens). A slot whose
item is a real item (`<= extract(vocab_sizes, :item)`) or whose rating is a real
rating (`< extract(vocab_sizes, :rating)`) becomes a prediction target with
probability 0.15. Targets are recorded, then corrupted by `_bert_corrupt!`.

When `explicit_baseline` is not `nothing`, explicit ratings are demeaned by
`demean_explicit_ratings!`. This covers both the rating tokens left in the input and
the rating labels. User biases are estimated from this batch's non-target explicit
ratings. When `explicit_baseline` is `nothing`, ratings are left untouched.

Returns `(tokens, attention_mask, batch_positions, item_positions, labels, weights)`.
Each weight is `1 / (number of targets of that kind from the same user)`. The
exception is training with `user_weighted_training == false`, where every weight is 1.
"""
function get_batch(
    sentences;
    max_seq_len,
    vocab_sizes,
    pad_tokens,
    cls_tokens,
    mask_tokens,
    user_weighted_training,
    explicit_baseline,
    rng,
    training,
)
    # dynamically pad to the largest sequence length
    seq_len = min(maximum(length.(sentences)), max_seq_len)

    # get tokenized sentences
    tokens = get_token_ids(
        sentences,
        seq_len,
        extract(vocab_sizes, :position),
        pad_tokens,
        cls_tokens;
        rng = rng,
    )

    # don't attend across sequences
    attention_mask =
        _same_user_attention_mask(tokens, pad_tokens, seq_len, length(sentences))

    # Per-user running sums for demeaning. These are only defined (and only used)
    # when a baseline exists.
    use_baseline = !isnothing(explicit_baseline)
    if use_baseline
        demean = (
            rating = Dict{Int32,Float32}(),
            count = Dict{Int32,Int32}(),
            weight = Dict{Int32,Float32}(),
        )
        demean_item_weights = powerdecay(
            get_counts(
                "training",
                "all",
                "explicit",
                medium;
                by_item = true,
                per_rating = false,
            ),
            log(explicit_baseline["λ"][4]),
        )
    end

    # Each slot (i, b) is read before it is corrupted and never revisited, so
    # hoisting these lookups out of the loop is safe.
    items = extract(tokens, :item)
    ratings = extract(tokens, :rating)
    users = extract(tokens, :user)
    timestamps = extract(tokens, :timestamp)
    item_vocab = extract(vocab_sizes, :item)
    rating_vocab = extract(vocab_sizes, :rating)

    batch_positions = (item = Tuple{Int32,Int32}[], rating = Tuple{Int32,Int32}[])
    item_positions = (item = Tuple{Int32,Int32}[], rating = Tuple{Int32,Int32}[])
    # labels.item holds the item id of each *rating* target (used for its item bias)
    labels = (item = Int32[], rating = Float32[])
    userids = (item = Int32[], rating = Int32[])
    for b::Int32 = 1:length(sentences)
        for i::Int32 = 1:seq_len
            # randomly mask 15% of non-trivial tokens
            has_implicit_rating = items[i, b] <= item_vocab
            has_explicit_rating = ratings[i, b] < rating_vocab
            should_mask = rand(rng) < 0.15

            # pre-demean ratings: accumulate unmasked explicit ratings per user
            if !should_mask && use_baseline && has_explicit_rating
                u = users[i, b]
                a = items[i, b]
                if !haskey(demean.rating, u)
                    demean.rating[u] = 0
                    demean.count[u] = 0
                    demean.weight[u] = 0
                end
                weight =
                    demean_item_weights[a] *
                    powerlawdecay(1 .- timestamps[i, b], explicit_baseline["λ"][5])
                demean.rating[u] += weight * (ratings[i, b] - explicit_baseline["a"][a])
                demean.count[u] += 1
                demean.weight[u] += weight
            end

            # record tokens before we mask them out
            if !(should_mask && (has_implicit_rating || has_explicit_rating))
                continue
            end
            if has_implicit_rating
                push!(batch_positions.item, (i, b))
                push!(
                    item_positions.item,
                    (items[i, b], Int32(length(item_positions.item) + 1)),
                )
                push!(userids.item, users[i, b])
            end
            if has_explicit_rating
                push!(batch_positions.rating, (i, b))
                push!(
                    item_positions.rating,
                    (items[i, b], Int32(length(item_positions.rating) + 1)),
                )
                push!(labels.rating, ratings[i, b])
                push!(labels.item, items[i, b])
                push!(userids.rating, users[i, b])
            end

            # bert masking
            _bert_corrupt!(tokens, i, b, mask_tokens, vocab_sizes, rng, training)
        end
    end

    # demean ratings (skipped entirely when there is no baseline; `demean` would be
    # undefined otherwise)
    if use_baseline
        demean_explicit_ratings!(
            tokens = tokens,
            demean = demean,
            explicit_baseline = explicit_baseline,
            vocab_sizes = vocab_sizes,
            labels = labels,
            userids = userids,
        )
    end

    # get weights
    processed_weights = map(uids_to_weights, userids)
    if training && !user_weighted_training
        processed_weights.item .= 1
        processed_weights.rating .= 1
    end

    tokens, attention_mask, batch_positions, item_positions, labels, processed_weights
end;

# Boolean (seq_len, seq_len, batch_size) attention mask. Position i may attend to
# position j of sequence k only when both slots have the same user token and neither
# item is the pad token. Each thread owns whole k-slices, so writes never overlap.
# `i` is the innermost loop to follow Julia's column-major layout.
function _same_user_attention_mask(tokens, pad_tokens, seq_len, batch_size)
    users = extract(tokens, :user)
    items = extract(tokens, :item)
    pad_item = extract(pad_tokens, :item)
    mask = zeros(Bool, (seq_len, seq_len, batch_size))
    Threads.@threads for k = 1:batch_size
        for j = 1:seq_len
            items[j, k] == pad_item && continue
            for i = 1:seq_len
                mask[i, j, k] = users[i, k] == users[j, k] && items[i, k] != pad_item
            end
        end
    end
    mask
end

# BERT-style corruption of slot (i, b), applied after the slot is recorded as a target.
# - Fields other than item/rating/timestamp/position are always replaced by their
#   mask token.
# - For those four fields a draw `r` decides: mask token if r <= cutoffs[1], leave
#   unchanged if r <= cutoffs[2], otherwise a random value (rand(rng, 1:size) for Int32
#   sizes, rand(rng) * size for Float32 tokens).
# - Outside training `r` is fixed: item/rating/timestamp are always masked and
#   position is always kept.
function _bert_corrupt!(tokens, i, b, mask_tokens, vocab_sizes, rng, training)
    item_allowed_info = get_wordtype_index.((:item, :rating, :timestamp, :position))
    for j = 1:length(tokens)
        if j in item_allowed_info
            continue
        end
        tokens[j][i, b] = mask_tokens[j]
    end
    for j in item_allowed_info
        if j in get_wordtype_index.((:item, :rating))
            cutoffs = (0.8, 0.9)
            r = training ? rand(rng) : 0.0
        elseif j == get_wordtype_index(:timestamp)
            cutoffs = (0.45, 0.9)
            r = training ? rand(rng) : 0.0
        elseif j == get_wordtype_index(:position)
            cutoffs = (0.45, 0.9)
            r = training ? rand(rng) : 0.7
        else
            @assert false
        end
        if r <= cutoffs[1]
            tokens[j][i, b] = mask_tokens[j]
        elseif r <= cutoffs[2]
            nothing
        else
            if eltype(vocab_sizes[j]) == Int32
                tokens[j][i, b] = rand(rng, 1:vocab_sizes[j])
            elseif eltype(tokens[j]) == Float32
                tokens[j][i, b] = rand(rng) * vocab_sizes[j]
            else
                @assert false
            end
        end
    end
    nothing
end;

In [None]:
"""
    demean_explicit_ratings!(; tokens, demean, explicit_baseline, vocab_sizes, labels, userids)

Subtract `user bias + item bias` from every explicit rating, in place. This applies
both to rating tokens still below `extract(vocab_sizes, :rating)` and to
`labels[:rating]`.

User bias is computed from the per-user sums in `demean`. It is the weighted mean
residual shrunk toward `mean(explicit_baseline["u"])` with prior strength
`explicit_baseline["λ"][1]`, where the user's sums are scaled by
`powerdecay(count, log(λ[3]))`. Users absent from `demean` get the mean user bias.

Item bias is `explicit_baseline["a"][a]` when `a` is a valid key, otherwise the mean
item bias.
"""
function demean_explicit_ratings!(;
    tokens,
    demean,
    explicit_baseline,
    vocab_sizes,
    labels,
    userids,
)
    λ = explicit_baseline["λ"]
    item_biases = explicit_baseline["a"]
    μ_user = mean(explicit_baseline["u"])
    μ_item = mean(item_biases)

    user_to_baseline = Dict{Int32,Float32}()
    for (u, residual_sum) in demean[:rating]
        shrink = powerdecay(demean[:count][u], log(λ[3]))
        user_to_baseline[u] =
            (residual_sum * shrink + μ_user * λ[1]) /
            (demean[:weight][u] * shrink + λ[1])
    end
    user_bias(u) = get(user_to_baseline, u, μ_user)
    item_bias(a) = a in keys(item_biases) ? item_biases[a] : μ_item

    # demean the explicit ratings that remain visible in the input
    ratings = extract(tokens, :rating)
    items = extract(tokens, :item)
    users = extract(tokens, :user)
    rating_vocab = extract(vocab_sizes, :rating)
    for b in axes(items, 2), i in axes(items, 1)
        if ratings[i, b] < rating_vocab
            ratings[i, b] -= user_bias(users[i, b]) + item_bias(items[i, b])
        end
    end

    # demean the rating targets
    rating_labels = labels[:rating]
    for k in eachindex(rating_labels)
        rating_labels[k] -= user_bias(userids[:rating][k]) + item_bias(labels[:item][k])
    end
    return nothing
end;

In [None]:
"""
    uids_to_weights(uids)

Return a `Vector{Float32}` where entry `k` is `1 / (number of occurrences of uids[k])`.
Each distinct user therefore contributes a total weight of 1.
"""
function uids_to_weights(uids)
    occurrences = Dict{eltype(uids),Int}()
    for uid in uids
        occurrences[uid] = get(occurrences, uid, 0) + 1
    end
    Float32[1 / occurrences[uid] for uid in uids]
end;

# Manipulate minibatches

In [None]:
"""
    shuffle_training_data(rng, sentences, line_by_line, max_sequence_length, max_document_length)

Return the training sentences for one epoch in a fresh random order drawn from `rng`.

If `line_by_line` is true, each sentence is kept whole and only reordered.

Otherwise the sentences are packed:
1. Each sentence, taken in shuffled order, is passed through
   `subset_sentence(s, max_document_length; recent = false, rng)`.
2. All resulting tokens are concatenated.
3. The stream is re-chunked into sequences of `max_sequence_length` tokens; the last
   chunk may be shorter.

Packed chunks may therefore span sentence boundaries. An empty `sentences` yields an
empty result in both modes.
"""
function shuffle_training_data(
    rng,
    sentences,
    line_by_line,
    max_sequence_length,
    max_document_length,
)
    order = Random.shuffle(rng, 1:length(sentences))
    # The packing path needs a first sentence to learn the token type; with no
    # sentences there is nothing to pack.
    if line_by_line || isempty(sentences)
        return sentences[order]
    end
    S = eltype(sentences)
    W = eltype(sentences[1])

    # concatenate all tokens
    tokens = Vector{W}()
    @showprogress for i in order
        sentence =
            subset_sentence(sentences[i], max_document_length; recent = false, rng = rng)
        append!(tokens, sentence)
    end

    # partition tokens into fixed-length sequences (the last one may be shorter)
    batched_sentences = Vector{S}()
    @showprogress for chunk in Iterators.partition(tokens, max_sequence_length)
        push!(batched_sentences, collect(chunk))
    end
    batched_sentences
end;

In [None]:
"""
    device(x::NamedTuple)

Move every field of `x` to the GPU with `gpu`, keeping the field names.
"""
device(x::NamedTuple) = map(gpu, x)

"""
    device(batch)

Move a `get_batch` result to the GPU. The token arrays are moved element-wise, the
attention mask directly, and the four NamedTuples (positions, item positions, labels,
weights) field by field.
"""
function device(batch)
    tokens, attention_mask = batch[1], batch[2]
    return (gpu.(tokens), gpu(attention_mask), map(device, batch[3:6])...)
end

# Make `CUDA.unsafe_free!` a no-op on `nothing` so `device_free!` tolerates empty fields.
# NOTE(review): this adds a method to CUDA's function for a Base type (type piracy);
# kept as-is because other code may rely on it.
CUDA.unsafe_free!(::Nothing) = nothing

"""
    device_free!(x::NamedTuple)

Release the GPU memory of every field of `x` via `CUDA.unsafe_free!`.
"""
device_free!(x::NamedTuple) = foreach(CUDA.unsafe_free!, x)

"""
    device_free!(batch)

Eagerly release the GPU memory of a batch produced by `device`. This is a no-op when
CUDA is not functional.
"""
function device_free!(batch)
    CUDA.functional() || return nothing
    tokens, attention_mask = batch[1], batch[2]
    foreach(CUDA.unsafe_free!, tokens)
    CUDA.unsafe_free!(attention_mask)
    foreach(device_free!, batch[3:6])
    return nothing
end;

# Create model

In [None]:
"""
    create_bert(config)

Build the `TransformerModel` described by the model config `config`. It has three
parts:

- A `CompositeEmbedding` of the item, rating, timestamp, status, completion and
  position fields, followed by LayerNorm and Dropout.
- A `Bert` encoder stack.
- Two prediction heads:
  - `item`: a Dense+LayerNorm transform plus an output bias over the item vocabulary.
  - `rating`: a Dense+LayerNorm+Dense transform producing one output per item.
"""
function create_bert(config)
    h = config["hidden_size"]
    act = config["hidden_act"]
    p_drop = config["dropout"]
    vocab(field) = extract(config["vocab_sizes"], field)

    # Layers are constructed in the same order as the original code, so random
    # initialization consumes RNG draws in the same sequence. Keyword arguments are
    # evaluated left to right.
    bert = Bert(
        hidden_size = h,
        num_attention_heads = config["num_attention_heads"],
        intermediate_size = config["intermediate_size"],
        num_layers = config["num_hidden_layers"],
        activation_fn = act,
        dropout = p_drop,
        attention_dropout = config["attention_dropout"],
    )

    emb = CompositeEmbedding(
        item = DiscreteEmbed(h, vocab(:item)),
        rating = ContinuousEmbed(h),
        timestamp = ContinuousEmbed(h),
        status = DiscreteEmbed(h, vocab(:status)),
        completion = ContinuousEmbed(h),
        position = DiscreteEmbed(h, vocab(:position)),
        postprocessor = Chain(LayerNorm(h), Dropout(p_drop)),
    )

    item_head = (
        transform = Chain(Dense(h, h, act), LayerNorm(h)),
        output_bias = BiasLayer(vocab(:item)),
    )
    rating_head = (transform = Chain(Dense(h, h, act), LayerNorm(h), Dense(h, vocab(:item))),)

    TransformerModel(emb, bert, (item = item_head, rating = rating_head))
end;

# Loss metrics

In [None]:
"""
    masklm_losses(model, batch)

Return `(item_loss, rating_loss)` for a batch from `get_batch` that has been moved with
`device`.

- `item_loss`: the weighted mean negative log-likelihood of the target items. Logits
  are computed against the transposed item embedding matrix (tied input/output
  embeddings) plus the item head's output bias.
- `rating_loss`: the weighted mean squared error between the rating head's output and
  `labels[:rating]`. `gather` with `item_positions[:rating]`, which holds
  (item id, target index) pairs, presumably picks each target item's output.

Each loss is `0.0f0` when the batch has no targets of that kind.
"""
function masklm_losses(model, batch)
    tokens, attention_mask, batch_positions, item_positions, labels, weights = batch
    embedded = model.embed(
        item = extract(tokens, :item),
        rating = extract(tokens, :rating),
        timestamp = extract(tokens, :timestamp),
        status = extract(tokens, :status),
        completion = extract(tokens, :completion),
        position = extract(tokens, :position),
    )
    hidden = model.transformers(embedded, attention_mask)

    item_loss = if isempty(item_positions[:item])
        0.0f0
    else
        item_head = model.classifier.item
        item_logits =
            transpose(model.embed.embeddings.item.embedding) *
            item_head.transform(gather(hidden, batch_positions[:item])) .+
            item_head.output_bias.b
        item_logprobs = logsoftmax(item_logits)
        -sum(gather(item_logprobs, item_positions[:item]) .* weights[:item]) /
        sum(weights[:item])
    end

    rating_loss = if isempty(item_positions[:rating])
        0.0f0
    else
        rating_out =
            model.classifier.rating.transform(gather(hidden, batch_positions[:rating]))
        residual = gather(rating_out, item_positions[:rating]) - labels[:rating]
        sum(residual .^ 2 .* weights[:rating]) / sum(weights[:rating])
    end

    item_loss, rating_loss
end;

In [None]:
"""
    evaluate_metrics(sentences, t::Trainer)

Compute validation losses over `sentences` without updating the model.

The sentences are shuffled with `t.rng` and split into minibatches of
`minibatch_size`. Each minibatch is built with `get_batch(...; training = false)`.
Each minibatch's losses are weighted by its total target weight.

Returns a `Dict` with "Item Crossentropy Loss" and "Rating MSE Loss". A metric is NaN
if no batch had any target of that kind.

Note that this advances `t.rng`, but the caller's `sentences` are not reordered.
"""
function evaluate_metrics(sentences, t::Trainer)
    sumtotals = [0.0, 0.0]
    weights = [0.0, 0.0]
    # Shuffle a copy: this function has no `!` and must not reorder the caller's data.
    # `shuffle` consumes exactly the same draws from `t.rng` as `shuffle!` would.
    shuffled = Random.shuffle(t.rng, sentences)
    sentence_batches =
        collect(Iterators.partition(shuffled, t.training_config["minibatch_size"]))
    @showprogress for sbatch in sentence_batches
        batch = get_batch(sbatch, false, t) |> device
        w = sum.([batch[6][:item], batch[6][:rating]])
        weights .+= w
        sumtotals .+= masklm_losses(t.model, batch) .* w
        device_free!(batch)
    end
    totals = sumtotals ./ weights
    Dict("Item Crossentropy Loss" => totals[1], "Rating MSE Loss" => totals[2])
end;

# Training

In [None]:
"""
    train_epoch!(sentences, t::Trainer)

Run one training pass over `sentences` and return the mean of the per-minibatch
losses. The pass updates `t.model` and `t.opt` in place.

1. Sentences are reordered (and packed unless `line_by_line`) with
   `shuffle_training_data` using `t.rng`.
2. They are grouped into batches of `batch_size`, and each batch is split into
   minibatches of `minibatch_size`.
3. Minibatch gradients are combined with `tuplesum`, divided by the minibatch count
   via `tupledivide`, and applied in one `Flux.update!` per batch.
"""
function train_epoch!(sentences, t::Trainer)
    cfg = t.training_config
    epoch_sentences = shuffle_training_data(
        t.rng,
        sentences,
        cfg["line_by_line"],
        cfg["max_sequence_length"],
        cfg["max_document_length"],
    )
    batches = collect(Iterators.partition(epoch_sentences, cfg["batch_size"]))
    losses = []
    @showprogress for batch_sentences in batches
        chunks = collect(Iterators.partition(batch_sentences, cfg["minibatch_size"]))
        # called once per batch, i.e. once per optimiser step
        schedule_learning_rate!(t.opt, t.lr_schedule, cfg["weight_decay"])
        grad_sum = nothing
        for chunk in chunks
            batch = device(get_batch(chunk, true, t))
            loss, grads = Flux.withgradient(m -> sum(masklm_losses(m, batch)), t.model)
            grad_sum = tuplesum(grad_sum, grads[1])
            push!(losses, loss)
            device_free!(batch)
        end
        Flux.update!(t.opt, t.model, tupledivide(grad_sum, length(chunks)))
    end
    mean(losses)
end;

In [None]:
"""
    checkpoint(sentences, t::Trainer, training_loss, epoch, name)

Evaluate `t` on `sentences` and add `training_loss` to the metrics. Then save the
model and optimiser state (moved to the CPU), the LR schedule, epoch, metrics, both
configs and the RNG with `write_params` under the name/checkpoints/epoch path, and
log the metrics.
"""
function checkpoint(sentences, t::Trainer, training_loss, epoch, name)
    @info "evaluating metrics"
    metrics = evaluate_metrics(sentences, t)
    metrics["Training Loss"] = training_loss
    snapshot = Dict(
        "m" => cpu(t.model),
        "opt" => cpu(t.opt),
        "lr_schedule" => t.lr_schedule,
        "epoch" => epoch,
        "metrics" => metrics,
        "training_config" => t.training_config,
        "model_config" => t.model_config,
        "rng" => t.rng,
    )
    write_params(snapshot, "$name/checkpoints/$epoch")
    @info "saving model after $epoch epochs with metrics $metrics"
end;

# Configuration

In [None]:
"""
    set_rngs(seed)

Create a master `Xoshiro(seed)` and use it to reseed the global RNG. When CUDA is
functional, it then also reseeds CUDA's default RNG and the CURAND default RNG, in
that order. Returns the master RNG.
"""
function set_rngs(seed)
    master = Random.Xoshiro(seed)
    Random.seed!(rand(master, UInt64))
    CUDA.functional() || return master
    Random.seed!(CUDA.default_rng(), rand(master, UInt64))
    Random.seed!(CUDA.CURAND.default_rng(), rand(master, UInt64))
    return master
end;

In [None]:
"""
    get_sentences(rng, training_config)

Load all training sentences, shuffle them in place with `rng`, and split them 99% /
1%. Returns `(training, validation)`; the validation split has words matching
`is_ptw` removed via `prune`.
"""
function get_sentences(rng, training_config)
    all_sentences = get_training_data(
        training_config["include_ptw_impressions"],
        training_config["cls_tokens"],
    )
    Random.shuffle!(rng, all_sentences)
    n_train = round(Int, 0.99 * length(all_sentences))
    train_split = all_sentences[1:n_train]
    validation_split = prune(all_sentences[n_train+1:end], is_ptw)
    return train_split, validation_split
end;

"""
    set_epoch_size!(training_config, training_sentences)

Store in `training_config["iters_per_epoch"]` the number of sequences one epoch
produces, and return that number.

- Line-by-line mode: one sequence per sentence.
- Otherwise: each sentence's length is capped at `max_document_length`, the capped
  lengths are summed, and the total is divided into `max_sequence_length` chunks,
  rounding up.
"""
function set_epoch_size!(training_config, training_sentences)
    training_config["iters_per_epoch"] = if training_config["line_by_line"]
        length(training_sentences)
    else
        doc_cap = training_config["max_document_length"]
        packed_tokens = sum(min.(length.(training_sentences), doc_cap))
        Int(ceil(packed_tokens / training_config["max_sequence_length"]))
    end
end;

In [None]:
"""
    create_training_config()

Return the training configuration `Dict` for `medium`. It covers tokenization sizes
and special tokens, optimisation settings, data options, the explicit-rating baseline
loaded with `read_params`, and model size.
"""
function create_training_config()
    # Base vocabulary size per token field. Int32 entries are discrete vocabularies;
    # Float32 entries scale continuous features (get_batch draws random replacements
    # as rand() * size).
    # NOTE(review): the field order presumably follows `wordtype`
    # (item, rating, timestamp, status, completion, user, position) — confirm.
    base_vocab_sizes = (
        Int32(num_items(medium)),
        Float32(11),
        Float32(1),
        Int32(5),
        Float32(1),
        Int32(num_users(medium)),
        Int32(512), # todo increase
    )
    # special tokens are numbered just past each base vocabulary
    shifted(offset) = base_vocab_sizes .+ Int32(offset)
    d = Dict(
        # tokenization
        "base_vocab_sizes" => base_vocab_sizes,
        "cls_tokens" => shifted(1),
        "pad_tokens" => shifted(2),
        "mask_tokens" => shifted(3),
        "sep_tokens" => shifted(4),
        "vocab_sizes" => shifted(4),
        # training
        "minibatch_size" => 16, # TODO
        "batch_size" => 16, # TODO
        "num_epochs" => 16, # TODO
        "user_weighted_training" => false,
        "peak_learning_rate" => 1f-3,
        "weight_decay" => 1f-2,
        # data
        "line_by_line" => false,
        "max_document_length" => 512 * 2, # TODO
        "include_ptw_impressions" => false,
        "explicit_baseline" =>
            read_params("$medium/$(ALL_TASKS[1])/ExplicitUserItemBiases"),
        # model
        "num_layers" => 4, # TODO
        "hidden_size" => 512, # TODO
        "max_sequence_length" => 512,
    )
    # consistency checks on the settings above
    @assert d["max_document_length"] >= d["max_sequence_length"]
    @assert d["batch_size"] >= d["minibatch_size"]
    @assert (d["batch_size"] % d["minibatch_size"]) == 0
    # demean_explicit_ratings! reads "u", "a" and "λ" (five entries) from the baseline
    @assert length(d["explicit_baseline"]) == 3
    @assert length(d["explicit_baseline"]["λ"]) == 5
    return d
end;

In [None]:
"""
    create_model_config(training_config)

Derive the architecture config used by `create_bert` from `training_config`. It uses
64-dimensional attention heads, a feed-forward size of 4x hidden, GELU, and 0.1
dropout everywhere.
"""
function create_model_config(training_config)
    # follows the recipe in Section 5 of [Well-Read Students Learn Better: On the
    # Importance of Pre-training Compact Models](https://arxiv.org/pdf/1908.08962.pdf)
    hidden = training_config["hidden_size"]
    return Dict(
        "attention_dropout" => 0.1,
        "hidden_act" => gelu,
        "num_hidden_layers" => training_config["num_layers"],
        "hidden_size" => hidden,
        "max_sequence_length" => training_config["max_sequence_length"],
        "vocab_sizes" => training_config["vocab_sizes"],
        # throws InexactError if hidden is not a multiple of 64
        "num_attention_heads" => Int(hidden / 64),
        "dropout" => 0.1,
        "intermediate_size" => hidden * 4,
    )
end;

In [None]:
# Build a Trainer plus the epoch it starts from.
#
# Resume method (outdir::String, epoch::Integer):
# - Restores the model, optimiser state, configs and RNG saved at "$outdir/$epoch".
#   `checkpoint` writes to "$name/checkpoints/$epoch", so `outdir` is expected to be
#   "$name/checkpoints".
# - If the stored training config differs from `training_config`, the stored one is
#   used.
# - With `reset_lr_schedule`, a fresh schedule is built from the caller's
#   `training_config`; otherwise the stored schedule is reused.
# - The `rng` argument is ignored; the saved RNG stream is resumed.
function load_from_checkpoint(
    training_config,
    outdir::String,
    epoch::Integer,
    reset_lr_schedule,
    rng,
)
    requested_config = training_config
    params = read_params("$outdir/$epoch")
    model = params["m"] |> gpu
    if training_config != params["training_config"]
        @info "training config differs from stored params"
        training_config = params["training_config"]
    end
    model_config = params["model_config"]
    opt = params["opt"] |> gpu
    if reset_lr_schedule
        # previously read the global `config`; use the argument instead
        lr_schedule = get_lr_schedule(requested_config)
    else
        lr_schedule = params["lr_schedule"]
    end
    rng = params["rng"]
    Trainer(model, opt, lr_schedule, training_config, model_config, rng), epoch
end

# Fresh-start method (no checkpoint):
# - Builds a new model from `training_config`.
# - Uses Adam (β = (0.9, 0.999)) chained with WeightDecay(lr * weight_decay).
# - Builds a new LR schedule.
# - Returns starting epoch 0.
# All settings come from the `training_config` argument, not a global.
function load_from_checkpoint(training_config, ::Nothing, ::Nothing, reset_lr_schedule, rng)
    model_config = create_model_config(training_config)
    model = create_bert(model_config) |> gpu
    lr = Float32(training_config["peak_learning_rate"])
    wd = Float32(training_config["weight_decay"])
    opt = Optimisers.setup(
        OptimiserChain(Adam(lr, (0.9f0, 0.999f0)), WeightDecay(lr * wd)),
        model,
    )
    lr_schedule = get_lr_schedule(training_config)
    Trainer(model, opt, lr_schedule, training_config, model_config, rng), 0
end;

# Actually Train Model!

In [None]:
# To resume, set `config_checkpoint` to the checkpoint directory and `config_epoch` to
# the epoch. `load_from_checkpoint` reads "$outdir/$epoch", while `checkpoint` writes
# "$name/checkpoints/$epoch". `nothing`/`nothing` starts from scratch.
config_checkpoint = nothing
config_epoch = nothing
reset_lr_schedule = true
# seeds the global (and CUDA) RNGs; `config_rng` drives shuffling, splitting and masking
config_rng = set_rngs(20221221)
config = create_training_config();

In [None]:
# 99% / 1% train/validation split of the shuffled sentences
training_sentences, validation_sentences = get_sentences(config_rng, config);

In [None]:
# mutates `config`, adding "iters_per_epoch"
set_epoch_size!(config, training_sentences);

In [None]:
trainer, starting_epoch = load_from_checkpoint(
    config,
    config_checkpoint,
    config_epoch,
    reset_lr_schedule,
    config_rng,
);

In [None]:
@info "Training model with $(sum(length, Flux.params(trainer.model))) total parameters"
@info "Embedding parameters: $(sum(length, Flux.params(trainer.model.embed)))"
@info "Transformer parameters: $(sum(length, Flux.params(trainer.model.transformers)))"
@info "Classifier parameters: $(sum(length, Flux.params(trainer.model.classifier)))"

In [None]:
# Record metrics before any training; the training loss is reported as Inf.
# NOTE(review): when resuming, this rewrites the checkpoint for `starting_epoch` —
# confirm that is intended.
checkpoint(validation_sentences, trainer, Inf, starting_epoch, name)

In [None]:
# Trains `num_epochs` more epochs beyond `starting_epoch`, checkpointing after each.
for i = 1:trainer.training_config["num_epochs"]
    training_loss = train_epoch!(training_sentences, trainer)
    checkpoint(validation_sentences, trainer, training_loss, starting_epoch + i, name)
end;