# Pretrains a transformer encoder model on watch histories

In [None]:
# Run identifier; used below as the path prefix for "$name/checkpoints/$epoch".
name = "all/Transformer/small";

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb");

In [None]:
using Flux
using Transformers
using Transformers.Basic
import CUDA
import Random
import StatsBase: mean, sample

# Tokenize training data

In [None]:
"""
    encode_word(item, rating, timestamp, status)

Encode one interaction as an `NTuple{4,Int32}` word `(item, rating, timestamp, status)`.

- `rating` is rounded to the nearest integer and shifted by one.
- `timestamp == -1` is mapped to bucket `1`; any other timestamp is bucketed as
  `round(timestamp / year_in_timestamp_units()) + 2`.
- `item` and `status` are passed through unchanged, converted to `Int32`.
"""
function encode_word(item, rating, timestamp, status)
    time_bucket = if timestamp == -1
        1
    else
        Int32(round(timestamp / year_in_timestamp_units())) + 2
    end
    rating_token = Int32(round(rating)) + 1
    return map(field -> convert(Int32, field), (item, rating_token, time_bucket, status))
end

"""
    get_training_data()

Build one token sequence ("sentence") per user from the training split of every task in
`ALL_TASKS`.

For each task the explicit and implicit splits are concatenated (via `cat`); implicit rows
have their rating overwritten with 11 before encoding. Each row is encoded with
`encode_word` and appended to its user's sentence in ascending timestamp order within
that task.

Returns a `Vector` of `Vector{NTuple{4,Int32}}`, one per user, in `Dict` iteration order.
"""
function get_training_data()
    function load_split(task, content)
        frame = get_raw_split("training", task, content)
        if content == "implicit"
            frame.rating .= 11
        end
        return frame
    end

    histories = Dict{Int32,Vector{NTuple{4,Int32}}}()
    for task in ALL_TASKS
        frame = cat(load_split(task, "explicit"), load_split(task, "implicit"))
        @showprogress for row in sortperm(frame.timestamp)
            history = get!(() -> NTuple{4,Int32}[], histories, frame.user[row])
            word = encode_word(
                frame.item[row],
                frame.rating[row],
                frame.timestamp[row],
                frame.status[row],
            )
            push!(history, word)
        end
    end
    return collect(values(histories))
end;

In [None]:
"""
    pad_sentence(sentence, max_seq_length, cls_tokens, pad_tokens; rng)

Lay out `sentence` (a vector of token tuples) as one length-`max_seq_length` vector per
token field, returned in the same container shape as `pad_tokens`.

Slot 1 of field `j` holds `cls_tokens[j]`. It is followed by the sentence's field-`j`
values, and the remaining slots are `pad_tokens[j]`. A sentence longer than
`max_seq_length - 1` is cut to a contiguous window whose start is drawn uniformly from
`rng`.
"""
function pad_sentence(sentence, max_seq_length, cls_tokens, pad_tokens; rng)
    fields = map(pad -> fill(pad, max_seq_length), pad_tokens)
    for j in eachindex(fields)
        fields[j][1] = cls_tokens[j]
    end

    capacity = max_seq_length - 1
    n_words = length(sentence)
    if n_words > capacity
        # keep a random contiguous window of `capacity` words
        start = rand(rng, 1:n_words-capacity+1)
        sentence = sentence[start:start+capacity-1]
    end

    for (pos, word) in enumerate(sentence)
        for j in eachindex(fields)
            fields[j][pos+1] = word[j]
        end
    end
    return fields
end;

In [None]:
"""
    get_token_ids(sentences, max_seq_length, cls_tokens, pad_tokens; rng)

Pad every sentence with `pad_sentence` and stack the results field by field.

Returns a tuple with one matrix per token field (`length(cls_tokens)` of them). Column `b`
of each matrix is the padded field of sentence `b`.
"""
function get_token_ids(sentences, max_seq_length, cls_tokens, pad_tokens; rng)
    padded = map(sentences) do sentence
        pad_sentence(sentence, max_seq_length, cls_tokens, pad_tokens; rng = rng)
    end
    return ntuple(j -> hcat(getindex.(padded, j)...), length(cls_tokens))
end;

In [None]:
"""
    get_batch(sentences; max_seq_len, vocab_sizes, cls_tokens, pad_tokens, mask_tokens,
              rng, training, implicit_rating_token = cls_tokens[2] - 1)

Build one masked-language-model batch from `sentences`.

Sentences are padded to the longest sentence in the batch plus one slot for the CLS
token, capped at `max_seq_len`. Every non-CLS, non-pad position is then considered for
BERT-style masking. All random decisions are drawn from `rng`:

- With probability 0.15 the item is masked. Its label is recorded, and its rating and
  status tokens are replaced by the mask token so they cannot reveal the item.
- Independently, with probability 0.15 the rating is masked. This only happens if the
  item was not masked and the rating is explicit (not `implicit_rating_token`). Its
  status token is masked as well.

When `training` is true, a masked token is replaced by the mask token 80% of the time,
by a uniformly random token from `1:vocab_sizes[j]` 10% of the time, and left unchanged
10% of the time. When `training` is false it is always replaced by the mask token.

# Keywords
- `implicit_rating_token`: rating token of implicit interactions, which are never used as
  rating targets. The default `cls_tokens[2] - 1` assumes the special tokens are numbered
  directly after the base rating vocabulary, with implicit ratings encoding to the last
  base rating token. Override this if that layout changes.

# Returns
`(tokens, attention_mask, masked_token_positions, labels)` where
- `tokens` is the tuple of token matrices from `get_token_ids`, modified in place by
  masking;
- `attention_mask` is a `(1, seq_len, length(sentences))` `Float32` array that is 1 where
  the item token is not padding;
- `masked_token_positions` is `(item_positions, rating_positions)`, each a vector of
  `(position, batch_index)` tuples;
- `labels` is `(item_labels, rating_labels)`. Item labels are `(item_token, k)` tuples,
  where `k` is the running index of the masked item. Rating labels are a `1 × n`
  `Float32` matrix of the original rating tokens.
"""
function get_batch(
    sentences;
    max_seq_len,
    vocab_sizes,
    cls_tokens,
    pad_tokens,
    mask_tokens,
    rng,
    training,
    implicit_rating_token = cls_tokens[2] - 1,
)
    # dynamically pad to the largest sequence length (+1 for the CLS token)
    seq_len = min(maximum(length.(sentences)) + 1, max_seq_len)

    # get tokenized sentences
    tokens = get_token_ids(sentences, seq_len, cls_tokens, pad_tokens; rng = rng)

    # don't attend to padding tokens
    attention_mask = reshape(
        convert.(Float32, tokens[1] .!= pad_tokens[1]),
        (1, seq_len, length(sentences)),
    )

    # apply BERT masking
    masked_token_positions = ([], [])
    labels = ([], [])
    for b = 1:length(sentences)
        # non-pad positions of column b are 1:n_tokens (position 1 is CLS)
        n_tokens = Int(sum(attention_mask[:, :, b]))
        for i = 2:n_tokens
            mask_item = rand(rng) < 0.15
            mask_rating = rand(rng) < 0.15

            if mask_item
                push!(labels[1], (tokens[1][i, b], length(labels[1]) + 1))
                for j in [2, 4]
                    # when predicting masked items, dont use rating or status metadata
                    tokens[j][i, b] = mask_tokens[j]
                end
                r = training ? rand(rng) : 0.0
                if r < 0.8
                    tokens[1][i, b] = mask_tokens[1]
                elseif r < 0.9
                    # draw from the caller's rng so batches are reproducible
                    tokens[1][i, b] = rand(rng, 1:vocab_sizes[1])
                end
                push!(masked_token_positions[1], (i, b))
            end

            # only try to predict explicit ratings; implicit interactions carry a
            # placeholder rating token that is not a real regression target
            if mask_rating && !mask_item && (tokens[2][i, b] != implicit_rating_token)
                push!(labels[2], tokens[2][i, b])
                for j in [4]
                    # when predicting masked ratings, dont use status metadata
                    tokens[j][i, b] = mask_tokens[j]
                end
                r = training ? rand(rng) : 0.0
                if r < 0.8
                    tokens[2][i, b] = mask_tokens[2]
                elseif r < 0.9
                    tokens[2][i, b] = rand(rng, 1:vocab_sizes[2])
                end
                push!(masked_token_positions[2], (i, b))
            end
        end
    end
    processed_labels = (labels[1], convert.(Float32, collect(labels[2]')))

    return tokens, attention_mask, masked_token_positions, processed_labels
end;

In [None]:
"""
    device(batch)

Move a `get_batch` result with `gpu`. The token, position and label tuples are moved
element by element; the attention mask is moved as a whole.
"""
function device(batch)
    tokens, attention_mask, positions, labels = batch
    return map(gpu, tokens), gpu(attention_mask), map(gpu, positions), map(gpu, labels)
end

"""
    device_free!(batch)

Eagerly release, with `CUDA.unsafe_free!`, the token matrices, the attention mask and the
rating labels of a batch. Does nothing when CUDA is not functional.
"""
function device_free!(batch)
    CUDA.functional() || return
    tokens, attention_mask, _, labels = batch
    foreach(CUDA.unsafe_free!, tokens)
    CUDA.unsafe_free!(attention_mask)
    CUDA.unsafe_free!(labels[2])
end;

# Create model

In [None]:
"""
    BiasLayer(b)
    BiasLayer(n::Integer; init = zeros)

Learnable additive bias: calling the layer on `x` returns `x .+ b`. The integer
constructor allocates `init(Float32, n)`, which by default is a zero vector of length `n`.
"""
struct BiasLayer
    b
end

function BiasLayer(n::Integer; init = zeros)
    return BiasLayer(init(Float32, n))
end

function (layer::BiasLayer)(x)
    return x .+ layer.b
end

Flux.@functor BiasLayer

In [None]:
"""
    create_bert(config)

Assemble the masked-language model described by `config`, a `Dict` as produced by
`create_model_config`. It has three parts:
- a `Bert` encoder;
- a `CompositeEmbedding` over item, rating, timestamp, status and trainable position
  embeddings, followed by LayerNorm and dropout;
- a classifier `NamedTuple` with an item head (Dense + LayerNorm transform plus an
  output bias) and a scalar rating head.
"""
function create_bert(config)
    hidden = config["hidden_size"]
    act = config["hidden_act"]
    dropout = config["hidden_dropout_prob"]
    vocab = config["vocab_sizes"]

    # Layers are built in a fixed order. If their initialisers draw from the global RNG,
    # reordering them would change the initial weights.
    encoder = Bert(
        hidden,
        config["num_attention_heads"],
        config["intermediate_size"],
        config["num_hidden_layers"];
        act = act,
        pdrop = dropout,
        attn_pdrop = config["attention_probs_dropout_prob"],
    )

    item_emb, rating_emb, timestamp_emb, status_emb =
        [Embed(hidden, vocab[k]) for k in 1:4]
    position_emb = PositionEmbedding(hidden, config["max_sequence_length"]; trainable = true)
    emb_post = Positionwise(LayerNorm(hidden), Dropout(dropout))
    embedding = CompositeEmbedding(
        item = item_emb,
        rating = rating_emb,
        timestamp = timestamp_emb,
        status = status_emb,
        position = position_emb,
        postprocessor = emb_post,
    )

    item_transform = Chain(Dense(hidden, hidden, act), LayerNorm(hidden))
    item_head = (transform = item_transform, output_bias = BiasLayer(vocab[1]))
    rating_head = Dense(hidden, 1)

    return TransformerModel(embedding, encoder, (item = item_head, rating = rating_head))
end;

# Loss metrics and training utils

In [None]:
"""
    masklm_losses(model, batch)

Compute the two masked-language-model losses for a batch from `get_batch` (after `device`).

Returns `(item_loss, rating_loss)`:
- `item_loss` is the mean negative log-likelihood of the masked items. Logits are the
  transposed item embedding matrix times the item head's transform, plus the output bias.
- `rating_loss` is the mean squared error of the rating head on the masked ratings.

Either loss is `0.0f0` when the batch has no masked positions of that kind.
"""
function masklm_losses(model, batch)
    tokens, attention_mask, positions, labels = batch
    item_labels, rating_labels = labels

    embedded = model.embed(
        item = tokens[1],
        rating = tokens[2],
        timestamp = tokens[3],
        status = tokens[4],
        position = tokens[1],
    )
    hidden = model.transformers(embedded, attention_mask)

    item_loss = if isempty(item_labels)
        0.0f0
    else
        head = model.classifier.item
        # output projection is tied to the item input embedding
        item_weights = transpose(model.embed.embeddings.item.embedding)
        logits = item_weights * head.transform(gather(hidden, positions[1])) .+ head.output_bias.b
        -mean(gather(logsoftmax(logits), item_labels))
    end

    rating_loss = if isempty(rating_labels)
        0.0f0
    else
        rating_pred = model.classifier.rating(gather(hidden, positions[2]))
        mean((rating_pred - rating_labels) .^ 2)
    end

    return item_loss, rating_loss
end;

In [None]:
"""
    evaluate_metrics(model, sentences, training_config; rng)

Estimate the masked-LM losses of `model` on `sentences` without updating it.

Steps:
1. Shuffle `sentences` in place with `rng`.
2. Split them into batches of `training_config["batch_size"]`.
3. Build each batch with `get_batch(...; training = false)`.

Returns a `Dict` with the mean of the per-batch losses; every batch is weighted equally.
"""
function evaluate_metrics(model, sentences, training_config; rng)
    Random.shuffle!(rng, sentences)
    batches = collect(Iterators.partition(sentences, training_config["batch_size"]))
    loss_sums = [0.0, 0.0]
    @showprogress for chunk in batches
        batch = device(
            get_batch(
                chunk;
                max_seq_len = training_config["max_sequence_length"],
                vocab_sizes = training_config["vocab_sizes"],
                cls_tokens = training_config["cls_tokens"],
                pad_tokens = training_config["pad_tokens"],
                mask_tokens = training_config["mask_tokens"],
                rng = rng,
                training = false,
            ),
        )
        item_loss, rating_loss = masklm_losses(model, batch)
        loss_sums[1] += item_loss
        loss_sums[2] += rating_loss
        device_free!(batch)
    end
    mean_item, mean_rating = loss_sums ./ length(batches)
    return Dict("Item Crossentropy Loss" => mean_item, "Rating MSE Loss" => mean_rating)
end;

In [None]:
"""
    train_epoch!(model, opt, sentences, training_config; rng)

Run one pass of masked-LM training over `sentences`, updating `model`'s parameters in
place with `opt`.

`sentences` is shuffled in place with `rng` and processed in batches of
`training_config["batch_size"]`, each built with `get_batch(...; training = true)`. The
objective is the sum of the item and rating losses.
"""
function train_epoch!(model, opt, sentences, training_config; rng)
    model_params = Flux.params(model)
    Random.shuffle!(rng, sentences)
    batch_size = training_config["batch_size"]
    @showprogress for chunk in collect(Iterators.partition(sentences, batch_size))
        batch = device(
            get_batch(
                chunk;
                max_seq_len = training_config["max_sequence_length"],
                vocab_sizes = training_config["vocab_sizes"],
                cls_tokens = training_config["cls_tokens"],
                pad_tokens = training_config["pad_tokens"],
                mask_tokens = training_config["mask_tokens"],
                rng = rng,
                training = true,
            ),
        )
        grads = Flux.gradient(() -> sum(masklm_losses(model, batch)), model_params)
        Flux.Optimise.update!(opt, model_params, grads)
        device_free!(batch)
    end
end;

In [None]:
"""
    checkpoint(model, opt, sentences, training_config, model_config, epoch; rng, outdir)

Evaluate `model` on `sentences` with `evaluate_metrics`, then save a checkpoint under
`"\$outdir/checkpoints/\$epoch"` with `write_params`.

The saved `Dict` contains:
- a CPU copy of the model (`"m"`);
- the optimiser (`"opt"`);
- the epoch number;
- the evaluation metrics;
- both configs.

Evaluation draws from `rng`.
"""
function checkpoint(
    model,
    opt,
    sentences,
    training_config,
    model_config,
    epoch;
    rng,
    outdir,
)
    @info "evaluating metrics"
    metrics = evaluate_metrics(model, sentences, training_config; rng = rng)
    write_params(
        Dict(
            "m" => cpu(model),
            "opt" => opt,
            "epoch" => epoch,
            "metrics" => metrics,
            "training_config" => training_config,
            "model_config" => model_config,
        ),
        # use the caller-supplied directory rather than the global `name`
        "$outdir/checkpoints/$epoch",
    )
    @info "saving model after $epoch epochs with metrics $metrics"
end;

# Train Model

In [None]:
"""
    set_rngs(seed)

Create a `Random.Xoshiro(seed)` generator and use it to derive fresh seeds for Julia's
global RNG. When CUDA is functional, it also seeds `CUDA.default_rng()` and
`CUDA.CURAND.default_rng()`. Returns the `Xoshiro` generator for explicit use.
"""
function set_rngs(seed)
    master = Random.Xoshiro(seed)
    Random.seed!(rand(master, UInt64))
    if !CUDA.functional()
        return master
    end
    Random.seed!(CUDA.default_rng(), rand(master, UInt64))
    Random.seed!(CUDA.CURAND.default_rng(), rand(master, UInt64))
    return master
end

rng = set_rngs(20221221);

In [None]:
# Build per-user sentences, shuffle them with `rng`, and hold out the last ~5% for
# validation.
sentences = get_training_data()
Random.shuffle!(rng, sentences)
cutoff = round(Int, 0.95 * length(sentences))
training_sentences, validation_sentences = sentences[1:cutoff], sentences[cutoff+1:end];

In [None]:
# Base vocabulary size of each token field, in word order (item, rating, timestamp,
# status). Ratings use 12 tokens: `encode_word` stores rating + 1, and implicit rows
# carry rating 11. Timestamp buckets follow `encode_word`'s
# `round(t / year_in_timestamp_units()) + 2` scheme.
base_vocab_sizes = map(
    n -> convert(Int32, n),
    (num_items(), 12, ceil(Int, 1 / year_in_timestamp_units()) + 2, 5),
)
# Special tokens are numbered directly after each base vocabulary:
# CLS = +1, PAD = +2, MASK = +3, SEP = +4 (also the full vocabulary size).
training_config = Dict(
    "base_vocab_sizes" => base_vocab_sizes,
    "cls_tokens" => base_vocab_sizes .+ 1,
    "pad_tokens" => base_vocab_sizes .+ 2,
    "mask_tokens" => base_vocab_sizes .+ 3,
    "sep_tokens" => base_vocab_sizes .+ 4,
    "vocab_sizes" => base_vocab_sizes .+ 4,
    "batch_size" => 128,
    "max_sequence_length" => 512,
);

In [None]:
# Check that every encoded word lies inside its field's base vocabulary
# (1 <= token <= base_vocab_sizes[j]), so no data token collides with the special
# CLS/PAD/MASK/SEP ids numbered after it. Uses an explicit throw instead of `@assert`
# because assertions may be disabled at some optimisation levels, and this check exists
# purely to validate the data.
@tprogress Threads.@threads for i = 1:length(sentences)
    for word in sentences[i]
        if !all(1 .<= word .<= base_vocab_sizes)
            throw(DomainError(word, "token outside base vocabulary $base_vocab_sizes"))
        end
    end
end

In [None]:
"""
    create_model_config(layers, hidden_size, training_config)

Return the hyper-parameter `Dict` consumed by `create_bert`. It follows the recipe in
Section 5 of [Well-Read Students Learn Better: On the Importance of Pre-training Compact
Models](https://arxiv.org/pdf/1908.08962.pdf):
- one attention head per 64 hidden units;
- a feed-forward width of `4 * hidden_size`;
- GELU activations;
- 0.1 dropout.

Sequence length and vocabulary sizes are copied from `training_config`. Throws an
`InexactError` if `hidden_size` is not a multiple of 64.
"""
function create_model_config(layers, hidden_size, training_config)
    dropout = 0.1
    return Dict(
        "attention_probs_dropout_prob" => dropout,
        "hidden_act" => gelu,
        "num_hidden_layers" => layers,
        "hidden_size" => hidden_size,
        "max_sequence_length" => training_config["max_sequence_length"],
        "vocab_sizes" => training_config["vocab_sizes"],
        "num_attention_heads" => Int(hidden_size / 64),
        "hidden_dropout_prob" => dropout,
        "intermediate_size" => 4 * hidden_size,
    )
end;

In [None]:
"""
    load_from_checkpoint(nothing)
    load_from_checkpoint(epoch::Integer)

Return `(model, optimiser, model_config, start_epoch)`.

- With `nothing`: build a fresh 4-layer, 512-wide model (moved with `gpu`) and an
  `ADAMW` optimiser, starting at epoch 0.
- With an epoch number: restore the model, optimiser and model config saved under
  `"\$name/checkpoints/\$epoch"`.
"""
function load_from_checkpoint(::Nothing)
    # todo schedule learning rate warmup and decay
    # NOTE(review): some Flux versions already scale ADAMW's `decay` by the learning
    # rate; if so, passing `1e-4 * 0.01` applies the LR twice — confirm against the
    # installed Flux version.
    optimiser = ADAMW(1e-4, (0.9, 0.999), 1e-4 * 0.01)
    config = create_model_config(4, 512, training_config)
    model = gpu(create_bert(config))
    return model, optimiser, config, 0
end

function load_from_checkpoint(epoch::Integer)
    saved = read_params("$name/checkpoints/$epoch")
    model = gpu(saved["m"])
    # NOTE(review): optimiser state may be keyed by the saved model's parameter arrays;
    # `gpu` creates new arrays, so moment estimates may not carry over — confirm.
    return model, saved["opt"], saved["model_config"], epoch
end

ryouko, opt, model_config, start_epoch = load_from_checkpoint(nothing);

In [None]:
let n_params = sum(length, Flux.params(ryouko))
    @info "Training model with $n_params parameters"
end

In [None]:
# Checkpoint `k` stores the model after `k` epochs of training. Save (and evaluate) the
# starting state first, then train and checkpoint each epoch. This way the weights from
# the final training epoch are written out instead of being trained and then discarded.
checkpoint(
    ryouko,
    opt,
    validation_sentences,
    training_config,
    model_config,
    start_epoch;
    rng = rng,
    outdir = name,
)
for epoch = start_epoch+1:100
    train_epoch!(ryouko, opt, training_sentences, training_config; rng = rng)
    checkpoint(
        ryouko,
        opt,
        validation_sentences,
        training_config,
        model_config,
        epoch;
        rng = rng,
        outdir = name,
    )
end