# Finetuning
* Finetunes a transformer model to predict recent watches

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb")
@nbinclude("Data.ipynb");

In [None]:
import HDF5
import JSON
import MLUtils
import NNlib: sigmoid
import Random
import SparseArrays: AbstractSparseArray, sparse, spzeros
import StatsBase: mean, sample

In [None]:
# Medium to finetune on; it is interpolated into `name` and passed to the data helpers.
# NOTE(review): empty here — presumably filled in before the notebook is run; confirm.
medium = "";

In [None]:
# Version tag shared by the pretrained and the finetuned model names.
const version = "v1"
# Pretrained model: its config.json is loaded and it is passed to `--initialize`.
const pretrain_name = "all/Transformer/$version"
# Finetuned model name; used for the output data path and the alpha names.
const name = "$medium/Transformer/$version"
# NOTE(review): set_logging_outdir is defined elsewhere; presumably it routes logs
# to a directory derived from `name` — confirm.
set_logging_outdir(name);

# Data

In [None]:
"""
    get_labels(metric, medium)

Load the `userid`, `itemid` and `metric` columns of the "validation" and "test"
splits for `metric`/`medium`, concatenate them with `cat`, and return the result of
`sparse(df, medium)`.
"""
function get_labels(metric, medium)
    validation_df = get_split("validation", metric, medium, [:userid, :itemid, :metric])
    test_df = get_split("test", metric, medium, [:userid, :itemid, :metric])
    return sparse(cat(validation_df, test_df), medium)
end

"""
    get_weights(metric, medium)

Load the `userid` and `itemid` columns of the "validation" and "test" splits,
concatenate them, set the `metric` field to
`powerdecay(get_counts(userid), -1.0f0)`, and return `sparse(df, medium)`.

NOTE(review): with exponent -1 this presumably weights each row by the inverse of
its user's row count — confirm against `powerdecay`/`get_counts`.
"""
function get_weights(metric, medium)
    validation_df = get_split("validation", metric, medium, [:userid, :itemid])
    test_df = get_split("test", metric, medium, [:userid, :itemid])
    combined = cat(validation_df, test_df)
    decayed = powerdecay(get_counts(combined.userid), -1.0f0)
    weighted = @set combined.metric = decayed
    return sparse(weighted, medium)
end

"""
    get_users(split, medium)

Return the distinct `userid` values of the raw `split` for `medium` as a vector.
The order is whatever iterating the intermediate `Set` yields.
"""
function get_users(split, medium)
    userids = get_raw_split(split, medium, [:userid], nothing).userid
    return collect(Set(userids))
end;

In [None]:
"""
    tokenize(sentences, labels, weights, medium, userid, config)

Prepare one user's finetuning example and delegate to the keyword method of
`tokenize`.

The user's sentence is a copy of `sentences[userid]` when present; otherwise it is a
one-word sentence holding the CLS token with its `:userid` field replaced. Column
`userid + 1` of every matrix in `labels` and `weights` is extracted for the user.
"""
function tokenize(sentences, labels, weights, medium, userid, config)
    history = if userid in keys(sentences)
        # copy so that the stored sentence is not mutated downstream
        copy(sentences[userid])
    else
        wordtype[replace(config[:cls_tokens], :userid, userid)]
    end
    # userid + 1: presumably userids are 0-based while columns are 1-based — confirm
    user_column(matrix) = matrix[:, userid+1]
    return tokenize(;
        sentence = history,
        labels = map(user_column, labels),
        weights = map(user_column, weights),
        medium = medium,
        userid = userid,
        max_seq_len = config[:max_sequence_length],
        vocab_sizes = config[:base_vocab_sizes],
        pad_tokens = config[:pad_tokens],
        cls_tokens = config[:cls_tokens],
        mask_tokens = config[:mask_tokens],
    )
end;

In [None]:
"""
    tokenize(; sentence, labels, weights, medium, userid, max_seq_len, vocab_sizes,
             pad_tokens, cls_tokens, mask_tokens)

Turn one user's sentence into model inputs and per-medium/per-metric targets.

The sentence is cut to at most `max_seq_len - 1` words (with `recent = true`), a mask
word carrying the position and `userid` is appended, and the result is converted to
token ids. `labels[i]` and `weights[i]` are copied into the slot for `medium` and
`ALL_METRICS[i]`; every other slot stays an all-zero sparse vector.

Returns `(tokens, positions, tokenized_labels, tokenized_weights)`, where `positions`
is a one-element vector holding the mask word's position.
"""
function tokenize(;
    sentence::Vector{wordtype},
    labels,
    weights,
    medium,
    userid,
    max_seq_len,
    vocab_sizes,
    pad_tokens,
    cls_tokens,
    mask_tokens,
)
    # Inputs: leave one slot free for the mask word.
    keep = min(length(sentence), max_seq_len - 1)
    sentence = subset_sentence(sentence, keep; recent = true)
    # TODO: also set :updated_at on the mask word (replace(mask_tokens, :updated_at, 1))
    positioned = replace(mask_tokens, :position, length(sentence) - 1)
    masked_word = replace(positioned, :userid, userid)
    push!(sentence, masked_word)
    tokens = get_token_ids(sentence, max_seq_len, pad_tokens, cls_tokens)

    # Outputs: position of the mask word after it was appended.
    positions = [length(sentence) - 1]
    zero_targets() = Dict(
        x => Dict(y => spzeros(Float32, num_items(x)) for y in ALL_METRICS) for
        x in ALL_MEDIUMS
    )
    tokenized_labels = zero_targets()
    tokenized_weights = zero_targets()
    for (i, metric) in Iterators.enumerate(ALL_METRICS)
        tokenized_labels[medium][metric] .= labels[i]
        tokenized_weights[medium][metric] .= weights[i]
    end
    return tokens, positions, tokenized_labels, tokenized_weights
end;

# Epochs

In [None]:
"""
    hdf5_writer(c::Channel)

Consume `(d, fn)` pairs from `c` and write each dictionary `d` to the HDF5 file `fn`,
one dataset per key, with deflate level 3.

Blocks while the channel is open and empty, and returns once the channel has been
closed and drained. A failed write is logged with the target filename and then
rethrown, so the failure is visible even when this runs in an unmonitored task.
"""
function hdf5_writer(c::Channel)
    # Iterating a Channel takes items until it is closed and empty, instead of
    # ending the loop with an exception from `take!` on a closed channel.
    for (d, fn) in c
        try
            HDF5.h5open(fn, "w") do f
                for (k, v) in d
                    f[k, deflate = 3] = v
                end
            end
        catch err
            @error "failed to write HDF5 file" fn exception = (err, catch_backtrace())
            rethrow()
        end
    end
end;

In [None]:
"""
    record_sparse_array!(d::Dict, name::String, x::AbstractSparseArray)

Store the sparse array `x` in `d` as four entries keyed by `name` plus a suffix:
`_i` and `_j` hold the row and column indices of the stored entries shifted to
0-based, `_v` holds their values, and `_size` holds the dimensions of `x`.

Returns the size vector (the value of the last assignment).
"""
function record_sparse_array!(d::Dict, name::String, x::AbstractSparseArray)
    rows, cols, vals = SparseArrays.findnz(x)
    # shift indices to 0-based for the consumer of the file
    for (suffix, payload) in (("_i", rows .- 1), ("_j", cols .- 1), ("_v", vals))
        d[name*suffix] = payload
    end
    dims = collect(size(x))
    d[name*"_size"] = dims
    return dims
end;

In [None]:
"""
    save_tokens(sentences, labels, weights, users, config, filename, writer)

Tokenize every user in `users` (in parallel), batch the results, and hand
`(dict, filename)` to the `writer` channel.

The dictionary holds one batched array per embedding name, the batched `positions`,
and the sparse label and weight batches for every medium/metric pair, flattened by
`record_sparse_array!`.
"""
function save_tokens(sentences, labels, weights, users, config, filename, writer)
    tokenized = Vector{Any}(nothing, length(users))
    Threads.@threads for idx in eachindex(tokenized)
        # `medium` here is the notebook-level global, not the loop variable below
        tokenized[idx] = tokenize(sentences, labels, weights, medium, users[idx], config)
    end

    d = Dict{String,AbstractArray}()
    embed_names = ["itemid", "rating", "updated_at", "status", "position", "userid"]
    for (slot, name) in Iterators.enumerate(embed_names)
        d[name] = MLUtils.batch([t[1][slot] for t in tokenized])
    end
    d["positions"] = MLUtils.batch([t[2] for t in tokenized])
    for medium in ALL_MEDIUMS, metric in ALL_METRICS
        label_batch = MLUtils.batch([t[3][medium][metric] for t in tokenized])
        record_sparse_array!(d, "labels_$(medium)_$(metric)", label_batch)
        weight_batch = MLUtils.batch([t[4][medium][metric] for t in tokenized])
        record_sparse_array!(d, "weights_$(medium)_$(metric)", weight_batch)
    end
    return put!(writer, (d, filename))
end;

In [None]:
"""
    save_epoch(sentences, labels, weights, users, config, epoch, outdir, split, writer)

Write one epoch of `split` data under `outdir/split/epoch`.

`users` is shuffled in place, cut into batches of `config.batch_size`, and each batch
is saved through `save_tokens` as `<batch index - 1>.h5`. A progress bar is shown only
for epoch 0. Afterwards the number of processed users is asserted to equal the
`<split>_epoch_size` field of `config`.
"""
function save_epoch(sentences, labels, weights, users, config, epoch, outdir, split, writer)
    outdir = joinpath(outdir, split, "$epoch")
    mkpath(outdir)
    Random.shuffle!(users)
    expected_num_sentences = getfield(config, Symbol("$(split)_epoch_size"))
    # materialized so the progress meter knows the number of batches
    batches = collect(Iterators.enumerate(Iterators.partition(users, config.batch_size)))
    num_sentences = 0
    @showprogress enabled = epoch == 0 for (i, batch) in batches
        num_sentences += length(batch)
        target = joinpath(outdir, "$(i-1).h5")
        save_tokens(sentences, labels, weights, batch, config, target, writer)
    end
    @assert num_sentences == expected_num_sentences
end;

# Configs

In [None]:
"""
    create_training_config(pretrain_name, medium)

Load `config.json` of the pretrained model `pretrain_name`, switch it to finetune
mode for `medium`, reset the training/validation epoch size and token counts to
`nothing`, and return the result as a `NamedTuple` with one field per JSON key.
"""
function create_training_config(pretrain_name, medium)
    file = joinpath(get_data_path("alphas/$pretrain_name"), "config.json")
    d = open(JSON.parse, file)
    d["mode"] = "finetune"
    d["medium"] = medium
    for split in ("training", "validation")
        # placeholders; the sizes are filled in once the user lists are known
        d["$(split)_epoch_size"] = nothing
        d["$(split)_epoch_tokens"] = nothing
    end
    return NamedTuple(Symbol(k) => v for (k, v) in d)
end;

In [None]:
"""
    set_epoch_size(config, training_users, validation_users)

Log the number of training and validation users and return a copy of `config` whose
`training_epoch_size` and `validation_epoch_size` are set to those counts.
"""
function set_epoch_size(config, training_users, validation_users)
    @info "Number of training sentences: $(length(training_users))"
    @info "Number of validation sentences: $(length(validation_users))"
    with_training = @set config.training_epoch_size = length(training_users)
    sized = @set with_training.validation_epoch_size = length(validation_users)
    return sized
end;

In [None]:
"""
    setup_training(config, outdir)

Create `outdir`, write `config` to `outdir/config.json` as JSON, and make sure the
`training` and `validation` subdirectories exist and are empty (any existing entries
are removed recursively).
"""
function setup_training(config, outdir)
    mkpath(outdir)
    write(joinpath(outdir, "config.json"), JSON.json(config))
    for split in ("training", "validation")
        split_dir = joinpath(outdir, split)
        mkpath(split_dir)
        # clear leftovers from a previous run
        foreach(entry -> rm(entry, recursive = true), readdir(split_dir, join = true))
    end
end;

In [None]:
"""
    save_epochs(num_epochs, pretrain_name)

Start background generation of `num_epochs` epochs of finetuning data and return the
producer `Task` without waiting for it.

The users of the "validation" split are used as training users and the users of the
"test" split as validation users. The config of `pretrain_name` is loaded, updated
with the epoch sizes, and written to the output directory for `name`, whose
`training`/`validation` subdirectories are cleared. Two writer tasks then drain a
channel that a producer task fills epoch by epoch (validation first, then training).

The tasks are never waited on, so each one logs and rethrows any exception;
otherwise a failure would go unnoticed and callers polling for output directories
would wait forever.
"""
function save_epochs(num_epochs, pretrain_name)
    Random.seed!(20221221)
    config = create_training_config(pretrain_name, medium)
    @info "loading data"
    training_users, validation_users = get_users.(["validation", "test"], (medium,))
    sentences =
        get_training_data(config[:cls_tokens], vcat(training_users, validation_users))
    config = set_epoch_size(config, training_users, validation_users)
    labels = get_labels.(ALL_METRICS, (medium,))
    weights = get_weights.(ALL_METRICS, (medium,))
    outdir = get_data_path(joinpath("alphas", name))
    setup_training(config, outdir)

    writer_workers = 2
    # bounded buffer so the producer cannot run arbitrarily far ahead of the writers
    writer = Channel(2 * writer_workers)
    for _ = 1:writer_workers
        Threads.@spawn begin
            try
                hdf5_writer(writer)
            catch err
                @error "HDF5 writer task failed" exception = (err, catch_backtrace())
                rethrow()
            end
        end
    end
    Threads.@spawn begin
        try
            for epoch = 0:num_epochs-1
                for (s, t) in
                    zip([validation_users, training_users], ["validation", "training"])
                    save_epoch(
                        sentences,
                        labels,
                        weights,
                        s,
                        config,
                        epoch,
                        outdir,
                        t,
                        writer,
                    )
                end
            end
        catch err
            @error "epoch generation task failed" exception = (err, catch_backtrace())
            rethrow()
        end
    end
end;

# Saving

In [None]:
"""
    model(users, items, user_cache, user_to_index)

Return a `Float32` vector with one prediction per `(users[i], items[i])` pair, read
from the cached per-user vector at index `items[i] + 1`.
"""
function model(users, items, user_cache, user_to_index)
    p = zeros(Float32, length(users))
    @showprogress for i in eachindex(p)
        user_scores = user_cache[user_to_index[users[i]]]
        # items[i] + 1: presumably item ids are 0-based — confirm
        p[i] = user_scores[items[i]+1]
    end
    return p
end

"""
    get_cache(metric, embeddings, user_to_index, weights, biases)

Compute the output-head activations for `metric` for every user index in
`user_to_index` and return them as a dictionary from user index to score vector.

The scores are `weights[metric] * embeddings .+ biases[metric]`, taken column by
column. For "watch" and "plantowatch" the column is passed through `softmax`, for
"drop" through an elementwise `sigmoid`; any other metric keeps the raw values.
"""
function get_cache(metric, embeddings, user_to_index, weights, biases)
    cache = Dict()
    E = weights[metric] * embeddings .+ biases[metric]
    @showprogress for u in values(user_to_index)
        e = E[:, u]
        if metric in ["watch", "plantowatch"]
            e = softmax(e)
        elseif metric == "drop"
            # sigmoid is a scalar activation and must be broadcast over the vector
            e = sigmoid.(e)
        end
        cache[u] = e
    end
    return cache
end;

In [None]:
"""
    save_alphas()

Read the finetuned model's `embeddings.h5` (embeddings, users, and the per-metric
output weights and biases for `medium`), and for every metric in `ALL_METRICS` write
an alpha named `name/metric` for the "test" and "negative" splits.

Predictions come from `model`, backed by the per-user cache built by `get_cache`.
"""
function save_alphas()
    path = get_data_path(joinpath("alphas", name, "embeddings.h5"))
    # The do-block form closes the file even when one of the reads throws.
    embeddings, users, weights, biases = HDF5.h5open(path, "r") do file
        (
            read(file["embedding"]),
            read(file["users"]),
            Dict(x => read(file["$(medium)_$(x)_weight"])' for x in ALL_METRICS),
            Dict(x => read(file["$(medium)_$(x)_bias"]) for x in ALL_METRICS),
        )
    end

    # position of each user in the `users` dataset, used to index embedding columns
    user_to_index = Dict(u => i for (i, u) in Iterators.enumerate(users))

    for metric in ALL_METRICS
        cache = get_cache(metric, embeddings, user_to_index, weights, biases)
        write_alpha(
            (users, items) -> model(users, items, cache, user_to_index),
            medium,
            "$name/$metric",
            ["test", "negative"],
        )
    end
end;

In [None]:
"""
    log_alphas()

Log the "test" split loss of the alpha `name/metric` for every metric in
`ALL_METRICS`.
"""
function log_alphas()
    for metric in ALL_METRICS, split in ["test"]
        val = compute_loss(metric, medium, ["$name/$metric"], split)
        @info "$metric $split loss = $val"
    end
end;

# Train model

In [None]:
# Start generating 16 epochs of data; save_epochs spawns the work in background tasks
# and returns immediately.
save_epochs(16, pretrain_name);
# wait until the first epoch is finished
# (the directory for epoch index 1 is only created after every batch of epoch 0 has
# been handed to the writer channel)
while !isdir(get_data_path(joinpath("alphas", name, "validation", "1")))
    sleep(1)
end

In [None]:
# Run the external training script, passing this model's name as the output directory
# and the pretrained model's name for initialization.
run(`python3 Pytorch.py --outdir $name --initialize $pretrain_name`)

In [None]:
# Write one alpha per metric from the embeddings file of the finetuned model.
save_alphas()

In [None]:
# Log the test-split loss of each alpha written above.
log_alphas()