# Pretrains a transformer encoder model on watch histories

In [None]:
# Identifier of this training run; it is interpolated into checkpoint paths and passed
# to `checkpoint` as the output directory.
name = "all/Transformer/test";

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb");

In [None]:
import CUDA
import Flux
import Flux: Chain, Dense, Dropout, LayerNorm, cpu, gelu, gpu, logsoftmax
import Flux.NNlib: gather
import Functors: @functor
import Optimisers
import Optimisers: OptimiserChain, Adam, WeightDecay
import ParameterSchedulers
import ParameterSchedulers: Sequence, Triangle, Shifted, Stateful
import Random
import StatsBase: mean, sample
import Transformers: Bert, Positionwise, TransformerModel
import Transformers.Basic: CompositeEmbedding, Embed, PositionEmbedding

# Tokenize training data

In [None]:
"""
    encode_word(item, rating, timestamp, status, completion, user)

Pack one watch-history entry into a 6-field token
`(item, rating, timestamp, status, completion, user)` whose fields are all `Int32`.

- `rating` is rounded to the nearest integer and shifted up by one.
- `timestamp == -1` maps to bucket `1` (presumably the "unknown timestamp" sentinel —
  confirm against the data loader). Any other timestamp is converted with
  `timestamp_to_date` and bucketed by calendar quarter, counted from 2004.
- `completion` is multiplied by 10, rounded and shifted up by one.
- `item`, `status` and `user` are only converted to `Int32`.
"""
function encode_word(item, rating, timestamp, status, completion, user)
    if timestamp == -1
        ts = 1
    else
        date = timestamp_to_date(timestamp)
        year = Dates.year(date) - 2004
        # Four 3-month seasons per year. The previous `div(month - 1, 4) + 1` only ever
        # produced seasons 1-3, so one in four timestamp buckets was never used even
        # though `year * 4` reserves four slots per year.
        season = Dates.quarterofyear(date)
        ts = 1 + year * 4 + season
    end
    r = round(Int32, rating) + 1
    c = round(Int32, 10 * completion) + 1
    word = (item, r, ts, status, c, user)
    convert.(Int32, word)
end

"""
    get_training_data(task, include_ptw; show_progress_bar = false)

Build one sentence (vector of encoded words) per user for `task`.

The "explicit" and "implicit" training splits (plus "ptw" when `include_ptw` is true) are
loaded with `get_raw_split`, concatenated, and walked in ascending timestamp order so each
user's words are appended chronologically. Returns the sentences as a vector; the order
of users follows the dictionary's iteration order.
"""
function get_training_data(task, include_ptw; show_progress_bar = false)
    # Load a single content split; every split other than "explicit" gets rating 11.
    function load_split(content)
        split = get_raw_split("training", task, content)
        if content != "explicit"
            split.rating .= 11
        end
        split
    end

    contents = include_ptw ? ["explicit", "implicit", "ptw"] : ["explicit", "implicit"]
    by_user = Dict{Int32,Vector{NTuple{6,Int32}}}()
    df = reduce(cat, [load_split(content) for content in contents])
    chronological = sortperm(df.timestamp)
    progress = ProgressMeter.Progress(
        length(chronological);
        enabled = show_progress_bar,
        showspeed = true,
    )
    for row in chronological
        user = df.user[row]
        # start a new sentence the first time a user is seen
        sentence = get!(() -> NTuple{6,Int32}[], by_user, user)
        push!(
            sentence,
            encode_word(
                df.item[row],
                df.rating[row],
                df.timestamp[row],
                df.status[row],
                df.completion[row],
                user,
            ),
        )
        ProgressMeter.next!(progress)
    end
    ProgressMeter.finish!(progress)
    collect(values(by_user))
end;

"""
    get_training_data(include_ptw)

Load the training sentences of every task in `ALL_TASKS` on separate threads and
concatenate them into a single vector of sentences.
"""
function get_training_data(include_ptw)
    per_task = Vector{Vector{Vector{NTuple{6,Int32}}}}(undef, length(ALL_TASKS))
    # Only one thread prints progress bars: the first thread to reach the
    # compare-and-swap claims ownership (0 means "unclaimed").
    # NOTE(review): this relies on `Threads.threadid()` staying fixed for the duration of
    # an iteration; it only affects which progress bars are shown, not the data.
    progress_owner = Threads.Atomic{Int}(0)
    Threads.@threads for t = 1:length(ALL_TASKS)
        Threads.atomic_cas!(progress_owner, 0, Threads.threadid())
        per_task[t] = get_training_data(
            ALL_TASKS[t],
            include_ptw;
            show_progress_bar = progress_owner[] == Threads.threadid(),
        )
    end
    vcat(per_task...)
end;

In [None]:
"""
    pad_sentence(sentence, max_seq_length, cls_tokens, pad_tokens; rng)

Turn a sentence (a vector of multi-field words) into one column of length
`max_seq_length` per field.

Each column starts with that field's entry of `cls_tokens`, is followed by the sentence's
values for the field, and is filled with the field's entry of `pad_tokens` up to
`max_seq_length`. A sentence with more than `max_seq_length - 1` words is replaced by a
contiguous window of that size whose start is drawn from `rng`.
"""
function pad_sentence(sentence, max_seq_length, cls_tokens, pad_tokens; rng)
    columns = fill.(pad_tokens, max_seq_length)
    for (field, column) in enumerate(columns)
        column[1] = cls_tokens[field]
    end
    capacity = max_seq_length - 1
    overflow = length(sentence) - capacity
    if overflow > 0
        # keep a random contiguous window of `capacity` words
        start = rand(rng, 1:overflow+1)
        sentence = sentence[start:start+capacity-1]
    end
    for (pos, word) in enumerate(sentence)
        for (field, column) in enumerate(columns)
            column[pos+1] = word[field]
        end
    end
    columns
end;

In [None]:
"""
    get_token_ids(sentences, max_seq_length, cls_tokens, pad_tokens; rng)

Pad every sentence with `pad_sentence` and return one matrix per word field (one per
entry of `cls_tokens`), in which column `k` holds the padded tokens of sentence `k`.
"""
function get_token_ids(sentences, max_seq_length, cls_tokens, pad_tokens; rng)
    padded = map(sentences) do sentence
        pad_sentence(sentence, max_seq_length, cls_tokens, pad_tokens; rng = rng)
    end
    ntuple(length(cls_tokens)) do field
        # stack this field's column of every sentence side by side
        hcat((columns[field] for columns in padded)...)
    end
end;

In [None]:
# Sanity-checks tokenized sentences against `vocab_sizes`.
#
# Asserts that every field of every word lies in `1:vocab_sizes[field]`, then logs the
# vocab sizes, the minimum and maximum value observed per field, and the fraction of each
# field's vocabulary that actually occurs (coverage).
function validate_tokenization(sentences, vocab_sizes)
    # one set of observed values per field, sharded per thread so the threaded loop
    # below does not push into a shared set
    # NOTE(review): shards are indexed by `Threads.threadid()` and sized by
    # `Threads.nthreads()`; this assumes an iteration never changes thread and that
    # thread ids never exceed `nthreads()` — confirm for the Julia version in use.
    sharded_vocab_values =
        [[Set() for _ = 1:length(vocab_sizes)] for t = 1:Threads.nthreads()]
    @tprogress Threads.@threads for i = 1:length(sentences)
        for word in sentences[i]
            # on failure the offending word is reported as the assertion message
            @assert all((word .>= 1) .&& (word .<= vocab_sizes)) word
            for j = 1:length(word)
                push!(sharded_vocab_values[Threads.threadid()][j], word[j])
            end
        end
    end

    # merge the per-thread shards into one set per field
    vocab_values = [Set() for _ = 1:length(vocab_sizes)]
    @showprogress for t = 1:Threads.nthreads()
        for i = 1:length(vocab_sizes)
            union!(vocab_values[i], sharded_vocab_values[t][i])
        end
    end

    # fraction of each field's vocabulary that was observed at least once
    coverage = [length(vocab_values[i]) / vocab_sizes[i] for i = 1:length(vocab_sizes)]
    @info "Vocab values $(vocab_sizes)"
    @info "Minimum observed vocab values $(minimum.(vocab_values))"
    @info "Maximum observed vocab values $(maximum.(vocab_values))"
    @info "Coverage $coverage"
end;

In [None]:
"""
    prune(sentences, invalid_word_fn)

Return a copy of `sentences` in which every word for which `invalid_word_fn(word)` is
true has been removed. Sentences that end up empty are dropped.

The result is a concretely typed vector of word vectors (element type taken from
`sentences`) rather than a `Vector{Any}`, so downstream code stays type-stable.
"""
function prune(sentences, invalid_word_fn)
    Word = eltype(eltype(sentences))
    pruned_sentences = Vector{Word}[]
    @showprogress for sentence in sentences
        # `filter` always allocates a fresh vector, so the input is never aliased
        kept = filter(word -> !invalid_word_fn(word), sentence)
        if !isempty(kept)
            push!(pruned_sentences, kept)
        end
    end
    pruned_sentences
end;

# Create minibatches

In [None]:
"""
    get_batch(sentences; max_seq_len, vocab_sizes, cls_tokens, pad_tokens, mask_tokens,
              rng, training)

Tokenize `sentences` into a padded minibatch and apply BERT-style masking.

Returns `(tokens, attention_mask, masked_token_positions, labels, weights)`:

- `tokens`: one matrix per word field, padded to the longest sentence in the batch plus
  one CLS slot, capped at `max_seq_len`.
- `attention_mask`: `1 x seq_len x batch` array of `Float32`, zero where the item token
  is the pad token.
- `masked_token_positions`: `(position, batch)` index pairs of the masked items and of
  the masked ratings.
- `labels`: `(item, running index)` pairs for the item task and a `1 x n` row of rating
  targets for the rating task.
- `weights`: per-label weights from `uids_to_weights`; overwritten with ones when
  `training` is true and user weighting is disabled.

Each non-pad position is independently selected for item masking and for rating masking
with probability 0.15; a position selected for both is only used for the item task.
During training a selected token is replaced by the mask token 80% of the time, by a
random token 10% of the time and left unchanged otherwise; during evaluation it is
always replaced by the mask token. All randomness is drawn from `rng`.

NOTE(review): the user-weighting switch is read from the global `training_config`, not
from an argument.
"""
function get_batch(
    sentences;
    max_seq_len,
    vocab_sizes,
    cls_tokens,
    pad_tokens,
    mask_tokens,
    rng,
    training,
)
    # dynamically pad to the largest sequence length
    seq_len = min(maximum(length.(sentences)) + 1, max_seq_len)

    # get tokenized sentences
    tokens = get_token_ids(sentences, seq_len, cls_tokens, pad_tokens; rng = rng)

    # don't attend to padding tokens
    attention_mask = reshape(
        convert.(Float32, tokens[1] .!= pad_tokens[1]),
        (1, seq_len, length(sentences)),
    )

    # apply BERT masking
    masked_token_positions = (Tuple{Int32,Int32}[], Tuple{Int32,Int32}[])
    labels = (Tuple{Int32,Int32}[], Float32[])
    userids = (Int32[], Int32[])
    for b::Int32 = 1:length(sentences)
        for i::Int32 = 1:seq_len
            # Visit every non-pad position. Pad tokens can also appear in the middle of
            # a sequence (as document separators), so looping over the first
            # `count(non-pad)` positions would never mask the trailing tokens.
            if attention_mask[1, i, b] == 0
                continue
            end
            mask_item = rand(rng) < 0.15
            mask_rating = rand(rng) < 0.15

            # CLS and other special tokens are above the base vocab and are never masked
            if mask_item && (tokens[1][i, b] <= vocab_sizes[1])
                push!(labels[1], (tokens[1][i, b], Int32(length(labels[1]) + 1)))
                push!(userids[1], tokens[6][i, b])
                for j in [2, 4, 5]
                    # when predicting masked items, dont use rating, status or completion metadata
                    tokens[j][i, b] = mask_tokens[j]
                end
                r = training ? rand(rng) : 0.0
                if r < 0.8
                    tokens[1][i, b] = mask_tokens[1]
                elseif r < 0.9
                    # draw from `rng` (not the global RNG) so batches are reproducible
                    tokens[1][i, b] = rand(rng, 1:vocab_sizes[1])
                end
                push!(masked_token_positions[1], (i, b))
            end

            if mask_rating && !mask_item && (tokens[2][i, b] < vocab_sizes[2])
                # only try to predict explicit ratings
                push!(labels[2], tokens[2][i, b])
                push!(userids[2], tokens[6][i, b])
                for j in [4, 5]
                    # when predicting masked ratings, dont use status or completion metadata
                    tokens[j][i, b] = mask_tokens[j]
                end
                r = training ? rand(rng) : 0.0
                if r < 0.8
                    tokens[2][i, b] = mask_tokens[2]
                elseif r < 0.9
                    tokens[2][i, b] = rand(rng, 1:vocab_sizes[2])
                end
                push!(masked_token_positions[2], (i, b))
            end
        end
    end
    # rating labels and weights are stored as 1 x n rows
    processed_labels = (labels[1], convert.(Float32, collect(labels[2]')))
    processed_weights = (uids_to_weights(userids[1]), collect(uids_to_weights(userids[2])'))
    if training && !training_config["user_weighted_training"]
        processed_weights[1] .= 1
        processed_weights[2] .= 1
    end

    tokens, attention_mask, masked_token_positions, processed_labels, processed_weights
end;

In [None]:
"""
    uids_to_weights(uids)

Return a `Vector{Float32}` with one weight per entry of `uids`, equal to the reciprocal
of how many times that entry's user id occurs in `uids`. The weights of each distinct
user therefore sum to one.
"""
function uids_to_weights(uids)
    occurrences = Dict{eltype(uids),Int}()
    for uid in uids
        occurrences[uid] = get(occurrences, uid, 0) + 1
    end
    Float32[1 / occurrences[uids[k]] for k = 1:length(uids)]
end;

In [None]:
"""
    get_batch(sentences, training_config, rng, training)

Convenience method that reads the batching settings (sequence length, vocab sizes and
special tokens) from `training_config` and forwards to the keyword-argument `get_batch`.
"""
function get_batch(sentences, training_config, rng, training)
    get_batch(
        sentences;
        rng = rng,
        training = training,
        max_seq_len = training_config["max_sequence_length"],
        vocab_sizes = training_config["base_vocab_sizes"],
        cls_tokens = training_config["cls_tokens"],
        pad_tokens = training_config["pad_tokens"],
        mask_tokens = training_config["mask_tokens"],
    )
end;

In [None]:
"""
    shuffle_training_data(rng, sentences, line_by_line, max_sequence_length,
                          max_document_length, pad_tokens)

Produce the training sequences for one epoch.

With `line_by_line` the sentences are returned in a random order. Otherwise the shuffled
sentences are packed: each one is cut to a random contiguous window of at most
`max_document_length - 1` words, the results are concatenated with a `pad_tokens` word
between consecutive sentences, and the stream is split into chunks of
`max_sequence_length - 1` words (the last chunk may be shorter).
"""
function shuffle_training_data(
    rng,
    sentences,
    line_by_line,
    max_sequence_length,
    max_document_length,
    pad_tokens,
)
    order = Random.shuffle(rng, 1:length(sentences))
    line_by_line && return sentences[order]

    chunk_len = max_sequence_length - 1 # leave room for CLS token
    doc_len = max_document_length - 1
    S = eltype(sentences)
    W = eltype(sentences[1])

    # concatenate all tokens
    stream = Vector{W}()
    final_doc = order[end]
    @showprogress for doc in order
        words = sentences[doc]
        excess = length(words) - doc_len
        if excess > 0
            # subsample a random contiguous window from an over-long document
            start = rand(rng, 1:excess+1)
            words = words[start:start+doc_len-1]
        end
        append!(stream, words)
        # separate documents with a pad word, but do not end the stream with one
        if doc != final_doc
            push!(stream, pad_tokens)
        end
    end

    # partition tokens into fixed-size sequences
    chunks = Vector{S}()
    current = Vector{W}()
    @showprogress for word in stream
        push!(current, word)
        if length(current) == chunk_len
            push!(chunks, current)
            current = Vector{W}()
        end
    end
    isempty(current) || push!(chunks, current)
    chunks
end;

In [None]:
"""
    device(batch)

Apply `gpu` to every array of a batch produced by `get_batch`, preserving the 5-element
layout `(tokens, attention_mask, masked_token_positions, labels, weights)`.
"""
function device(batch)
    tokens, attention_mask, positions, labels, weights = batch
    # the attention mask is a single array; the other entries are tuples of arrays
    gpu.(tokens), gpu(attention_mask), gpu.(positions), gpu.(labels), gpu.(weights)
end

"""
    device_free!(batch)

Eagerly release the GPU memory held by a batch returned from `device`. Does nothing when
CUDA is not functional.

NOTE(review): `batch[3]` and `batch[4][1]` are not freed here — confirm whether that is
intentional.
"""
function device_free!(batch)
    if CUDA.functional()
        CUDA.unsafe_free!.(batch[1])
        CUDA.unsafe_free!(batch[2])
        CUDA.unsafe_free!(batch[4][2])
        CUDA.unsafe_free!.(batch[5])
    end
end;

# Create model

In [None]:
# A layer that adds a 1-D vector to the input
"""
    BiasLayer(b)
    BiasLayer(n::Integer; init = zeros)

A layer that adds the bias `b` to its input by broadcasting. The integer constructor
creates a `Float32` bias of length `n` using `init` (zeros by default).

The field type is a type parameter (instead of `Any`) so accesses to `b` are type-stable.
"""
struct BiasLayer{B}
    b::B
end
BiasLayer(n::Integer; init = zeros) = BiasLayer(init(Float32, n))
(m::BiasLayer)(x) = x .+ m.b
@functor BiasLayer;

In [None]:
# Let a position embedding be called with an integer token array: only the first
# dimension of the array is used, as the number of positions to embed.
function (pe::PositionEmbedding)(x::AbstractArray{X}) where {X<:Integer}
    pe(size(x, 1))
end;

In [None]:
"""
    create_bert(config)

Build the model from a configuration dictionary: a composite embedding (item, rating,
timestamp, status, completion and position embeddings followed by layer norm and
dropout), a `Bert` encoder, and two heads — an item head (dense transform, layer norm
and an output bias over the item vocabulary) and a single-output rating head.

The keys read from `config` are `hidden_size`, `num_attention_heads`,
`intermediate_size`, `num_hidden_layers`, `hidden_act`, `hidden_dropout_prob`,
`attention_probs_dropout_prob`, `vocab_sizes` and `max_sequence_length`.
"""
function create_bert(config)
    hidden_size = config["hidden_size"]
    vocab_sizes = config["vocab_sizes"]

    bert = Bert(
        hidden_size,
        config["num_attention_heads"],
        config["intermediate_size"],
        config["num_hidden_layers"];
        act = config["hidden_act"],
        pdrop = config["hidden_dropout_prob"],
        attn_pdrop = config["attention_probs_dropout_prob"],
    )

    create_embedding(vocab_size) = Embed(Int(hidden_size), Int(vocab_size))
    item_emb = create_embedding(vocab_sizes[1])
    rating_emb = create_embedding(vocab_sizes[2])
    timestamp_emb = create_embedding(vocab_sizes[3])
    status_emb = create_embedding(vocab_sizes[4])
    completion_emb = create_embedding(vocab_sizes[5])

    position_emb =
        PositionEmbedding(hidden_size, config["max_sequence_length"]; trainable = true)

    emb_post = Positionwise(LayerNorm(hidden_size), Dropout(config["hidden_dropout_prob"]))

    emb = CompositeEmbedding(
        item = item_emb,
        rating = rating_emb,
        timestamp = timestamp_emb,
        status = status_emb,
        # The completion embedding used to be created but left out of the composite, so
        # the completion tokens passed to the embedding had no effect on the model.
        completion = completion_emb,
        position = position_emb,
        postprocessor = emb_post,
    )

    item_cls = (
        transform = Chain(
            Dense(hidden_size, hidden_size, config["hidden_act"]),
            LayerNorm(hidden_size),
        ),
        output_bias = BiasLayer(vocab_sizes[1]),
    )
    rating_cls = Dense(hidden_size, 1)
    clf = (item = item_cls, rating = rating_cls)

    TransformerModel(emb, bert, clf)
end;

# Loss metrics

In [None]:
"""
    masklm_losses(model, batch)

Compute the two masked-token losses for a batch laid out as
`(tokens, attention_mask, masked_token_positions, labels, weights)`.

Returns `(item_loss, rating_loss)`:

- `item_loss` is the weighted negative log-likelihood of the masked items. The logits
  are the transposed item embedding matrix times the transformed hidden states, plus
  the output bias.
- `rating_loss` is the weighted squared error of the rating head on the masked ratings.

A task with no masked labels in the batch contributes a loss of `0.0f0`.
"""
function masklm_losses(model, batch)
    tokens, attention_mask, positions, targets, weights = batch
    embedded = model.embed(
        item = tokens[1],
        rating = tokens[2],
        timestamp = tokens[3],
        status = tokens[4],
        completion = tokens[5],
        position = tokens[1],
    )
    hidden = model.transformers(embedded, attention_mask)

    if isempty(targets[1])
        item_loss = 0.0f0
    else
        item_head = model.classifier.item
        masked_states = item_head.transform(gather(hidden, positions[1]))
        # the output projection reuses the item embedding matrix
        logits =
            transpose(model.embed.embeddings.item.embedding) * masked_states .+
            item_head.output_bias.b
        log_probs = logsoftmax(logits)
        item_loss = -sum(gather(log_probs, targets[1]) .* weights[1]) / sum(weights[1])
    end

    if isempty(targets[2])
        rating_loss = 0.0f0
    else
        rating_pred = model.classifier.rating(gather(hidden, positions[2]))
        rating_loss = sum((rating_pred - targets[2]) .^ 2 .* weights[2]) / sum(weights[2])
    end

    item_loss, rating_loss
end;

In [None]:
"""
    evaluate_metrics(model, sentences, training_config; rng)

Evaluate `model` on `sentences` with evaluation-mode masking and return a dictionary
with the item cross-entropy loss and the rating MSE loss, each averaged over
minibatches of `training_config["batch_size"]` sentences.

Note that `sentences` is shuffled in place.
"""
function evaluate_metrics(model, sentences, training_config; rng)
    Random.shuffle!(rng, sentences)
    sentence_batches =
        collect(Iterators.partition(sentences, training_config["batch_size"]))
    loss_sums = zeros(2)
    @showprogress for sbatch in sentence_batches
        batch = device(get_batch(sbatch, training_config, rng, false))
        loss_sums .+= masklm_losses(model, batch)
        device_free!(batch)
    end
    # unweighted mean of the per-batch losses
    mean_losses = loss_sums ./ length(sentence_batches)
    Dict("Item Crossentropy Loss" => mean_losses[1], "Rating MSE Loss" => mean_losses[2])
end;

# Training

In [None]:
"""
    Trainer(model, opt, lr_schedule, training_config, model_config, rng)

Bundles everything needed to train and checkpoint a model: the model, its optimiser
state, the learning-rate schedule, the training and model configurations, and the
random number generator.

Field types are type parameters (instead of `Any`) so code that receives a `Trainer`
is specialized on the concrete types of its contents.
"""
struct Trainer{M,O,S,TC,MC,R}
    model::M
    opt::O
    lr_schedule::S
    training_config::TC
    model_config::MC
    rng::R
end;

In [None]:
"""
    schedule_learning_rate!(opt, lr_schedule)

Advance `lr_schedule` by one step and adjust the optimiser state `opt` in place: `eta`
is set to the new learning rate and `gamma` to one hundredth of it.
"""
function schedule_learning_rate!(opt, lr_schedule)
    next_lr = ParameterSchedulers.next!(lr_schedule)
    lr = Float32(next_lr)
    scaled_gamma = lr * 1f-2
    Optimisers.adjust!(opt; eta = lr, gamma = scaled_gamma)
end;

In [None]:
# `tuplesum(a, b)` adds two gradient-like structures element by element.
# `nothing` acts as the additive identity, tuples and named tuples are summed
# recursively (named tuples by the field names of the first argument), and anything
# else is added with broadcasting.
tuplesum(::Nothing, ::Nothing) = nothing
tuplesum(::Nothing, rhs) = rhs
tuplesum(lhs, ::Nothing) = lhs
tuplesum(lhs, rhs) = lhs .+ rhs
tuplesum(lhs::Tuple, rhs::Tuple) = ntuple(k -> tuplesum(lhs[k], rhs[k]), length(lhs))
function tuplesum(lhs::NamedTuple{names}, rhs::NamedTuple) where {names}
    NamedTuple{names}(map(k -> tuplesum(lhs[k], rhs[k]), names))
end;

In [None]:
"""
    train_epoch!(t::Trainer, sentences)

Run one training epoch over `sentences`.

The sentences are reshuffled (and packed, unless training line by line) and split into
groups of `gradient_accumulation_batch_size`. For each group the learning rate is
advanced one step, the gradients of the summed item and rating losses are accumulated
over minibatches of `batch_size`, and one optimiser update is applied.

All randomness comes from the trainer's own `t.rng`, so a trainer restored from a
checkpoint continues with the restored random state.
"""
function train_epoch!(t::Trainer, sentences)
    sentences = shuffle_training_data(
        t.rng,
        sentences,
        t.training_config["line_by_line"],
        t.training_config["max_sequence_length"],
        t.training_config["max_document_length"],
        t.training_config["pad_tokens"],
    )
    sentence_batches = collect(
        Iterators.partition(
            sentences,
            t.training_config["gradient_accumulation_batch_size"],
        ),
    )
    @showprogress for sbatch in sentence_batches
        minibatches = collect(Iterators.partition(sbatch, t.training_config["batch_size"]))
        schedule_learning_rate!(t.opt, t.lr_schedule)
        total_grads = nothing
        for minibatch in minibatches
            # use the trainer's rng rather than whatever global `rng` happens to exist
            batch = get_batch(minibatch, t.training_config, t.rng, true) |> device
            grads = Flux.gradient(t.model) do m
                sum(masklm_losses(m, batch))
            end
            total_grads = tuplesum(total_grads, grads)
            batch |> device_free!
        end
        # `Flux.gradient` returns a 1-tuple; its only entry is the gradient w.r.t. the model
        Flux.update!(t.opt, t.model, total_grads[1])
    end
end;

In [None]:
"""
    checkpoint(t::Trainer, sentences, epoch, outdir)

Evaluate the model on `sentences` and save a checkpoint under
`outdir/checkpoints/epoch` with `write_params`. The checkpoint holds the model and
optimiser state (moved to the CPU), the learning-rate schedule, the epoch number, the
evaluation metrics, both configurations and the random number generator.
"""
function checkpoint(t::Trainer, sentences, epoch, outdir)
    @info "evaluating metrics"
    metrics = evaluate_metrics(t.model, sentences, t.training_config; rng = t.rng)
    write_params(
        Dict(
            "m" => t.model |> cpu,
            "opt" => t.opt |> cpu,
            "lr_schedule" => t.lr_schedule,
            "epoch" => epoch,
            "metrics" => metrics,
            "training_config" => t.training_config,
            "model_config" => t.model_config,
            "rng" => t.rng,
        ),
        # honour the `outdir` argument instead of silently using the global `name`
        "$outdir/checkpoints/$epoch",
    )
    @info "saving model after $epoch epochs with metrics $metrics"
end;

# Configuration

In [None]:
"""
    set_rngs(seed)

Create a `Random.Xoshiro` generator from `seed`, use it to derive seeds for the global
RNG and (when CUDA is functional) for the two CUDA default generators, and return it.
"""
function set_rngs(seed)
    rng = Random.Xoshiro(seed)
    global_seed = rand(rng, UInt64)
    Random.seed!(global_seed)
    if CUDA.functional()
        # each GPU generator gets its own seed, drawn in a fixed order
        native_rng = CUDA.default_rng()
        Random.seed!(native_rng, rand(rng, UInt64))
        curand_rng = CUDA.CURAND.default_rng()
        Random.seed!(curand_rng, rand(rng, UInt64))
    end
    rng
end;

In [None]:
"""
    get_sentences(rng, training_config)

Load all training sentences, shuffle them with `rng`, and split them 95% / 5% into
training and validation sets. Words whose status field equals 1 are pruned from the
validation sentences.

As a side effect, `training_config["iters_per_epoch"]` is set to the number of training
sequences per epoch: the exact sentence count in line-by-line mode, otherwise an
estimate of the number of packed sequences.
"""
function get_sentences(rng, training_config)
    sentences = get_training_data(training_config["include_ptw_impressions"])
    Random.shuffle!(rng, sentences)
    num_training = round(Int, 0.95 * length(sentences))
    training = sentences[1:num_training]
    validation = prune(sentences[num_training+1:end], word -> word[4] == 1)
    training_config["iters_per_epoch"] = if training_config["line_by_line"]
        length(training)
    else
        # this is an approximation: capped document lengths plus one separator
        # between consecutive documents
        capped_lengths = min.(length.(training), training_config["max_document_length"])
        total_tokens = sum(capped_lengths) + length(training) - 1
        ceil(Int, total_tokens / (training_config["max_sequence_length"] - 1))
    end
    training, validation
end;

In [None]:
"""
    create_training_config()

Return the training configuration dictionary.

`base_vocab_sizes` holds the size of each word field's vocabulary, in the order
(item, rating, timestamp, status, completion, user). The CLS, pad, mask and sep tokens
of each field are the four ids directly above its base vocabulary, so `vocab_sizes` is
`base_vocab_sizes .+ 4`. The remaining entries are batching, epoch and data-mode
settings.
"""
function create_training_config()
    item_count = num_items()
    # four timestamp buckets per year from 2004 through the current year, plus one
    timestamp_buckets = (Dates.value(Dates.Year(Dates.today())) - 2004 + 1) * 4 + 1
    base_vocab_sizes = Int32.((item_count, 12, timestamp_buckets, 5, 11, num_users()))
    above_vocab(offset) = base_vocab_sizes .+ Int32(offset)
    Dict(
        "base_vocab_sizes" => base_vocab_sizes,
        "cls_tokens" => above_vocab(1),
        "pad_tokens" => above_vocab(2),
        "mask_tokens" => above_vocab(3),
        "sep_tokens" => above_vocab(4),
        "vocab_sizes" => above_vocab(4),
        "batch_size" => 16,
        "gradient_accumulation_batch_size" => 16,
        "max_sequence_length" => 512,
        "num_epochs" => 8,
        "include_ptw_impressions" => false,
        "line_by_line" => false,
        "max_document_length" => 512, # controls document subsampling when not in line by line mode
        "user_weighted_training" => false,
    )
end;

In [None]:
"""
    LinearWarmupSchedule(lr, iters)

Build a stateful learning-rate schedule over `iters` steps from two triangle waves with
peak `lr`: the first 6% of the steps use the start of a triangle wave (warm-up), the
remaining steps use a triangle wave shifted by half its period.
"""
function LinearWarmupSchedule(lr, iters)
    warmup_steps = round(Int, iters * 0.06)
    decay_steps = iters - warmup_steps
    warmup = Triangle(λ0 = 0.0f0, λ1 = lr, period = 2 * warmup_steps)
    decay = Shifted(
        Triangle(λ0 = 0.0f0, λ1 = lr, period = 2 * decay_steps),
        decay_steps,
    )
    Stateful(Sequence(warmup => warmup_steps, decay => decay_steps))
end;

In [None]:
"""
    create_model_config(layers, hidden_size, training_config)

Return the model configuration for an encoder with `layers` hidden layers of width
`hidden_size`. The number of attention heads is `hidden_size / 64` (which must be an
integer) and the intermediate size is `4 * hidden_size`; the sequence length and vocab
sizes are copied from `training_config`.
"""
function create_model_config(layers, hidden_size, training_config)
    # follows the recipe in Section 5 of [Well-Read Students Learn Better: On the
    # Importance of Pre-training Compact Models](https://arxiv.org/pdf/1908.08962.pdf)
    num_heads = Int(hidden_size / 64)
    feedforward_size = hidden_size * 4
    Dict(
        "num_hidden_layers" => layers,
        "hidden_size" => hidden_size,
        "num_attention_heads" => num_heads,
        "intermediate_size" => feedforward_size,
        "hidden_act" => gelu,
        "hidden_dropout_prob" => 0.1,
        "attention_probs_dropout_prob" => 0.1,
        "max_sequence_length" => training_config["max_sequence_length"],
        "vocab_sizes" => training_config["vocab_sizes"],
    )
end;

In [None]:
"""
    load_from_checkpoint(training_config, epoch::Integer, rng)

Restore a `Trainer` from the checkpoint saved after `epoch` and return it together with
the next epoch number. The saved training configuration must equal `training_config`.
The random number generator stored in the checkpoint is used; the `rng` argument is
ignored.
"""
function load_from_checkpoint(training_config, epoch::Integer, rng)
    saved = read_params("$name/checkpoints/$epoch")
    model = gpu(saved["m"])
    @assert training_config == saved["training_config"]
    opt = gpu(saved["opt"])
    trainer = Trainer(
        model,
        opt,
        saved["lr_schedule"],
        training_config,
        saved["model_config"],
        saved["rng"],
    )
    trainer, epoch + 1
end

"""
    load_from_checkpoint(training_config, ::Nothing, rng)

Create a fresh `Trainer` (no checkpoint to resume from) and return it with starting
epoch 1. The model has 4 layers of hidden size 512, the optimiser is Adam chained with
weight decay, and the learning-rate schedule spans the number of optimiser steps
implied by `num_epochs`, `iters_per_epoch` and `gradient_accumulation_batch_size`.
"""
function load_from_checkpoint(training_config, ::Nothing, rng)
    model_config = create_model_config(4, 512, training_config)
    model = gpu(create_bert(model_config))
    optimiser = OptimiserChain(Adam(1f-4, (0.9f0, 0.999f0)), WeightDecay(1f-6))
    opt = Optimisers.setup(optimiser, model)
    # one schedule step per gradient-accumulation batch
    total_sequences = training_config["num_epochs"] * training_config["iters_per_epoch"]
    max_batches =
        round(Int, total_sequences / training_config["gradient_accumulation_batch_size"])
    lr_schedule = LinearWarmupSchedule(1f-4, max_batches)
    Trainer(model, opt, lr_schedule, training_config, model_config, rng), 1
end;

# Actually Train Model!

In [None]:
# Seed all random number generators from one fixed seed and build the training config.
rng = set_rngs(20221221)
training_config = create_training_config();

In [None]:
# Load and split the data (this also sets training_config["iters_per_epoch"]), then
# check that every token of both splits lies inside its base vocabulary.
training_sentences, validation_sentences = get_sentences(rng, training_config)
validate_tokenization(
    [training_sentences; validation_sentences],
    training_config["base_vocab_sizes"],
)

In [None]:
# Passing `nothing` as the epoch creates a fresh trainer instead of resuming from a checkpoint.
trainer, starting_epoch = load_from_checkpoint(training_config, nothing, rng);

In [None]:
# Log parameter counts (total and per model component) and both configurations.
@info "Training model with $(sum(length, Flux.params(trainer.model))) total parameters"
@info "Embedding parameters: $(sum(length, Flux.params(trainer.model.embed)))"
@info "Transformer parameters: $(sum(length, Flux.params(trainer.model.transformers)))"
@info "Classifier parameters: $(sum(length, Flux.params(trainer.model.classifier)))"
@info "Training config $(trainer.training_config)"
@info "Model config $(trainer.model_config)"

In [None]:
# Train and checkpoint once per epoch. The learning-rate schedule is sized for
# training_config["num_epochs"] epochs, so stop there instead of at a hard-coded 100,
# which would keep stepping the schedule past its planned length.
for epoch = starting_epoch:training_config["num_epochs"]
    train_epoch!(trainer, training_sentences)
    checkpoint(trainer, validation_sentences, epoch, name)
end;