# Finetuning
* Finetunes a transformer model to predict recent watches
* A recent watch is defined as an interaction that occurred within the past $D$ days and is one of the $N$ most recent interactions for that user

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb")
@nbinclude("Data.ipynb");

In [None]:
import HDF5
import JSON
import MLUtils
import NNlib: sigmoid
import Random
import SparseArrays: AbstractSparseArray, sparse, spzeros
import StatsBase: mean, sample
import ThreadPinning

In [None]:
# NOTE(review): empty placeholder — presumably replaced with the target medium
# before this notebook is run; confirm against whatever generates this notebook.
medium = "";

In [None]:
# `pretrain_name` identifies the pretrained model: its config.json is loaded by
# create_training_config and it is passed to Pytorch.py via `--initialize`.
# `name` identifies this finetuned model: its config, training shards and
# embeddings live under "alphas/$name".
version = "v2layer"
pretrain_name = "all/Transformer/$version"
name = "$medium/Transformer/$version"
set_logging_outdir(name);

In [None]:
# Pin Julia threads using ThreadPinning's `:cores` strategy before any feature
# worker tasks are spawned.
ThreadPinning.pinthreads(:cores)

# Configuration

In [None]:
# Loads the (userid, itemid, metric) columns of the validation and test splits
# for `metric`, concatenates them, and converts the result with
# `sparse(df, medium)`.
function get_labels(metric, medium)
    validation = get_split("validation", metric, medium, [:userid, :itemid, :metric])
    test = get_split("test", metric, medium, [:userid, :itemid, :metric])
    combined = cat(validation, test)
    return sparse(combined, medium)
end

# Builds per-interaction weights for `metric`: the validation and test splits
# are concatenated and the `metric` field is set to
# `powerdecay(get_counts(userid), -1)` before converting with
# `sparse(df, medium)`.
# NOTE(review): presumably this down-weights users with many interactions;
# confirm against the definitions of `get_counts` and `powerdecay`.
function get_weights(metric, medium)
    validation = get_split("validation", metric, medium, [:userid, :itemid])
    test = get_split("test", metric, medium, [:userid, :itemid])
    df = cat(validation, test)
    per_interaction = powerdecay(get_counts(df.userid), -1.0f0)
    df = @set df.metric = per_interaction
    return sparse(df, medium)
end

# Returns the distinct user ids present in the raw `split` for `medium`.
# The order is whatever iterating a `Set` yields, not the order in the split.
function get_users(split, medium)
    userids = get_raw_split(split, medium, [:userid], nothing).userid
    return collect(Set(userids))
end;

In [None]:
# Creates a Xoshiro generator from `seed`, uses its first UInt64 draw to seed
# the global RNG, and returns the (already advanced) generator.
function set_rngs(seed)
    generator = Random.Xoshiro(seed)
    global_seed = rand(generator, UInt64)
    Random.seed!(global_seed)
    return generator
end;

In [None]:
# Loads the pretrained model's config.json (under "alphas/$pretrain_name") and
# returns it with "mode" set to "finetune" and "medium" set to `medium`.
function create_training_config(pretrain_name, medium)
    path = joinpath(get_data_path("alphas/$pretrain_name"), "config.json")
    settings = open(JSON.parse, path)
    settings["mode"] = "finetune"
    settings["medium"] = medium
    return settings
end;

In [None]:
# Logs the number of training and validation users and stores those counts in
# `config` under "training_epoch_size" and "validation_epoch_size".
# Returns the validation count.
function set_epoch_size!(config, training_users, validation_users)
    n_training, n_validation = length(training_users), length(validation_users)
    @info "Number of training sentences: $(length(training_users))"
    @info "Number of validation sentences: $(length(validation_users))"
    config["training_epoch_size"] = n_training
    config["validation_epoch_size"] = n_validation
    return n_validation
end;

In [None]:
# Prepares `outdir` for a fresh run: creates it if needed, deletes every regular
# file directly inside it (subdirectories are left alone), and writes `config`
# as JSON to "config.json" in the parent directory of `outdir`.
function setup_training(config, outdir)
    # mkpath is a no-op when the directory already exists
    mkpath(outdir)
    stale = filter(isfile, readdir(outdir, join = true))
    foreach(rm, stale)
    fn = joinpath(outdir, "..", "config.json")
    return open(fn, "w") do io
        write(io, JSON.json(config))
    end
end;

# Disk I/O

In [None]:
# Convenience method: looks up (or creates) the sentence for `userid`, slices
# that user's column out of every label/weight array, and forwards everything
# to the keyword method of `featurize` together with the relevant config fields.
function featurize(sentences, labels, weights, medium, userid, config)
    # users without history get a sentence holding only their cls token
    sentence = if haskey(sentences, userid)
        copy(sentences[userid])
    else
        wordtype[replace(config["cls_tokens"], :userid, userid)]
    end
    # userids are 0-based while columns are 1-based
    column = userid + 1
    user_labels = map(x -> x[:, column], labels)
    user_weights = map(x -> x[:, column], weights)
    return featurize(;
        sentence = sentence,
        labels = user_labels,
        weights = user_weights,
        medium = medium,
        userid = userid,
        max_seq_len = config["max_sequence_length"],
        vocab_sizes = config["base_vocab_sizes"],
        pad_tokens = config["pad_tokens"],
        cls_tokens = config["cls_tokens"],
        mask_tokens = config["mask_tokens"],
    )
end;

In [None]:
# Builds one finetuning example for a single user.
#
# Keywords:
#   sentence     - the user's interaction history (not mutated by the caller's
#                  copy only if the caller passes a copy; a mask word is pushed
#                  onto the subset returned by `subset_sentence`)
#   labels       - per-metric label vectors for this user, in ALL_METRICS order
#   weights      - per-metric weight vectors for this user, in ALL_METRICS order
#   medium       - the medium whose label/weight slots are filled
#   userid       - written into the mask word
#   max_seq_len  - maximum number of tokens, including the appended mask word
#   vocab_sizes  - accepted for interface compatibility; unused here
#   pad_tokens, cls_tokens - forwarded to `get_token_ids`
#   mask_tokens  - template word used for the appended mask position
#
# Returns `(tokens, positions, featurized_labels, featurized_weights)` where the
# last two are `medium => metric => sparse vector` dictionaries that are all
# zero except for the entries of `medium`.
function featurize(;
    sentence::Vector{wordtype},
    labels,
    weights,
    medium,
    userid,
    max_seq_len,
    vocab_sizes,
    pad_tokens,
    cls_tokens,
    mask_tokens,
)
    # get inputs: keep the most recent words, leaving one slot for the mask word
    sentence = subset_sentence(
        sentence,
        min(length(sentence), max_seq_len - 1);
        recent = true,
        rng = nothing,
    )
    # TODO: also set the mask word's :updated_at field
    #       (replace(mask_tokens, :updated_at, 1))
    # Use the `mask_tokens` keyword rather than reaching for a global `config`,
    # so this method only depends on its arguments.
    masked_word = replace(mask_tokens, :position, length(sentence) - 1)
    masked_word = replace(masked_word, :userid, userid)
    push!(sentence, masked_word)
    tokens = get_token_ids(sentence, max_seq_len, pad_tokens, cls_tokens)

    # get outputs: the only predicted position is the mask word (0-based index)
    positions = [length(sentence) - 1]
    empty_targets() = Dict(
        x => Dict(y => spzeros(Float32, num_items(x)) for y in ALL_METRICS) for
        x in ALL_MEDIUMS
    )
    featurized_labels = empty_targets()
    featurized_weights = empty_targets()
    for (i, metric) in enumerate(ALL_METRICS)
        featurized_labels[medium][metric] .= labels[i]
        featurized_weights[medium][metric] .= weights[i]
    end
    return tokens, positions, featurized_labels, featurized_weights
end;

In [None]:
# Featurizes every user in `users` and writes the batched result to the HDF5
# file `filename` (overwriting it).
#
# Datasets written:
#   one per embedding name ("itemid", "rating", ...), batched over users
#   "positions"
#   "labels_<medium>_<metric>_{i,j,v,size}" and
#   "weights_<medium>_<metric>_{i,j,v,size}" in coordinate form
#   (see `record_sparse_array!`)
#
# The medium being finetuned is read from `config["medium"]` instead of the
# notebook-level global `medium`, so the function depends only on its arguments.
function save_features(sentences, labels, weights, users, config, filename)
    target_medium = config["medium"]
    features = [
        featurize(sentences, labels, weights, target_medium, x, config) for x in users
    ]

    d = Dict{String,AbstractArray}()
    collate = MLUtils.batch
    # order must match the order of the token arrays returned by featurize
    embed_names = ["itemid", "rating", "updated_at", "status", "position", "userid"]
    for (i, name) in Iterators.enumerate(embed_names)
        d[name] = collate([x[1][i] for x in features])
    end
    d["positions"] = collate([x[2] for x in features])
    for medium in ALL_MEDIUMS
        for metric in ALL_METRICS
            record_sparse_array!(
                d,
                "labels_$(medium)_$(metric)",
                collate([x[3][medium][metric] for x in features]),
            )
            record_sparse_array!(
                d,
                "weights_$(medium)_$(metric)",
                collate([x[4][medium][metric] for x in features]),
            )
        end
    end

    HDF5.h5open(filename, "w") do file
        for (k, v) in d
            write(file, k, v)
        end
    end
end

# Stores the sparse array `x` in `d` in coordinate form under four keys derived
# from `name`: "<name>_i" / "<name>_j" hold the row / column indices shifted to
# 0-based, "<name>_v" the stored values and "<name>_size" the dimensions.
function record_sparse_array!(d::Dict, name::String, x::AbstractSparseArray)
    rows, cols, vals = SparseArrays.findnz(x)
    d["$(name)_i"] = rows .- 1
    d["$(name)_j"] = cols .- 1
    d["$(name)_v"] = vals
    dims = collect(size(x))
    d["$(name)_size"] = dims
    return dims
end;

In [None]:
# Checks whether the shard at `filename` may be replaced by the next one.
#
# Looks in the shard's directory for marker files whose name contains
# "<basename>.read"; the text after the last "." of each marker is parsed as the
# world size. When the number of markers equals the world size, the shard, its
# ".complete" marker and all read markers are deleted and `true` is returned;
# otherwise nothing is touched and `false` is returned.
function advance!(filename)
    outdir = dirname(filename)
    marker = basename(filename) * ".read"
    acks = filter(x -> occursin(marker, basename(x)), readdir(outdir))
    if isempty(acks)
        return false
    end
    # every marker must agree on the world size
    world_sizes = Set(split(x, ".")[end] for x in acks)
    @assert length(world_sizes) == 1
    world_size = parse(Int, first(world_sizes))
    if length(acks) != world_size
        return false
    end
    rm("$filename.complete")
    rm(filename)
    foreach(x -> rm(joinpath(outdir, x)), acks)
    return true
end;

In [None]:
# Spawns one background task per shard that keeps producing feature files.
#
# `users` is split into contiguous batches of ceil(length(users) / workers)
# users, where `workers` is config["num_training_shards"] when `training` is
# true and config["num_validation_shards"] otherwise. The task for batch `i`
# loops forever: it shuffles its batch in place, then for each chunk of
# config["chunk_size"] users writes "$outdir/$stem.$i.h5" followed by a
# "$outdir/$stem.$i.h5.complete" marker, and polls `advance!` once per second
# until the shard has been released. A task exits once `outdir` no longer exists.
# The function itself returns immediately without waiting for the tasks.
function spawn_feature_workers(
    sentences,
    labels,
    weights,
    users,
    config,
    rng,
    training,
    outdir,
)
    # writes data to "$outdir/$stem.$i.h5" in a loop
    chunk_size = config["chunk_size"]
    workers = training ? config["num_training_shards"] : config["num_validation_shards"]
    stem = training ? "training" : "validation"
    # give every worker its own generator, seeded from the caller's rng
    rngs = [Random.Xoshiro(rand(rng, UInt64)) for _ = 1:workers]
    for (i, batch) in Iterators.enumerate(
        Iterators.partition(users, div(length(users), workers, RoundUp)),
    )
        Threads.@spawn begin
            rng = rngs[i]
            while true
                # new chunk composition on every pass over the batch
                Random.shuffle!(rng, batch)
                for (j, chunk) in
                    Iterators.enumerate(Iterators.partition(batch, chunk_size))
                    filename = joinpath(outdir, "$stem.$i.h5")
                    save_features(sentences, labels, weights, chunk, config, filename)
                    # the marker is written only after the shard is fully saved;
                    # its content is the chunk index within the current pass
                    open("$filename.complete", "w") do f
                        write(f, "$j")
                    end
                    # NOTE(review): only the first worker forces a collection —
                    # presumably to bound memory without every task calling GC; confirm.
                    if i == 1
                        GC.gc()
                    end
                    # block until advance! deletes the shard
                    # (presumably after the trainer has written its ".read" markers — confirm)
                    while isdir(outdir) && !advance!(filename)
                        sleep(1)
                    end
                    # removing outdir is the shutdown signal for this task
                    if !isdir(outdir)
                        return
                    end
                end
            end
        end
    end
end;

# Train model

In [None]:
# NOTE(review): config_checkpoint and config_epoch are not referenced anywhere
# else in this notebook — presumably read by included code; confirm.
config_checkpoint = nothing
config_epoch = nothing
# Seeds the global RNG and returns the generator later handed to the feature workers.
rng = set_rngs(20221221)
# Start from the pretrained model's config, switched to finetune mode for this medium.
config = create_training_config(pretrain_name, medium);

In [None]:
# Finetuning trains on users of the "validation" split and validates on users
# of the "test" split.
training_users, test_users = get_users.(["validation", "test"], (medium,))
# NOTE(review): the meaning of the literal `1` argument is not visible here —
# confirm against get_training_data.
sentences = get_training_data(config["cls_tokens"], 1, vcat(training_users, test_users));

In [None]:
# One label array per metric, in ALL_METRICS order (featurize indexes them by position).
labels = get_labels.(ALL_METRICS, (medium,));

In [None]:
# One weight array per metric, in ALL_METRICS order (featurize indexes them by position).
weights = get_weights.(ALL_METRICS, (medium,));

In [None]:
# Epoch sizes are the user counts; the "test" users act as the validation set here.
set_epoch_size!(config, training_users, test_users);

In [None]:
# Shards are exchanged through "alphas/$name/training"; setup_training clears
# stale files there and writes the config next to that directory.
outdir = get_data_path(joinpath("alphas", name, "training"))
setup_training(config, outdir);

In [None]:
# Record which users belong to each side of the split in "users.h5" inside outdir.
HDF5.h5open(joinpath(outdir, "users.h5"), "w") do file
    write(file, "training", training_users)
    write(file, "test", test_users)
end

In [None]:
# Start the background tasks that write "training.<i>.h5" shards (returns immediately).
spawn_feature_workers(sentences, labels, weights, training_users, config, rng, true, outdir);

In [None]:
# Start the background tasks that write "validation.<i>.h5" shards (returns immediately).
spawn_feature_workers(sentences, labels, weights, test_users, config, rng, false, outdir);

In [None]:
# wait for workers to begin writing: block until outdir contains one
# ".complete" marker per configured training and validation shard
# NOTE(review): spawn_feature_workers can start fewer tasks than configured
# shards (partition size is rounded up, e.g. 5 users over 4 shards gives 3
# batches), in which case this loop would never finish — confirm shard counts.
while sum(endswith.(readdir(outdir), (".complete",))) <
      config["num_training_shards"] + config["num_validation_shards"]
    sleep(1)
end

In [None]:
# Run the trainer script as a blocking subprocess, initialised from the pretrained model.
# NOTE(review): presumably Pytorch.py consumes the shards in outdir and writes
# "embeddings.h5" under "alphas/$name" — confirm against the script.
run(`python3 Pytorch.py --outdir $name --initialize $pretrain_name --epochs 8`)

# Save predictions

In [None]:
# Load the user embeddings and the per-metric output weights and biases for
# this medium from "embeddings.h5". The do-block form guarantees the HDF5 handle
# is closed even if one of the reads throws.
embeddings, users, weights, biases =
    HDF5.h5open(get_data_path(joinpath("alphas", name, "embeddings.h5")), "r") do file
        (
            read(file["embedding"]),
            read(file["users"]),
            # stored weights are transposed so they can left-multiply `embeddings`
            Dict(x => read(file["$(medium)_$(x)_weight"])' for x in ALL_METRICS),
            Dict(x => read(file["$(medium)_$(x)_bias"]) for x in ALL_METRICS),
        )
    end

# Maps each user id to its 1-based position in `users`; for duplicated ids the
# last position wins.
user_to_index = Dict(u => i for (i, u) in Iterators.enumerate(users));

In [None]:
# Looks up one prediction per (user, item) pair: each user id is mapped through
# the global `user_to_index`, and the item id (0-based) selects the entry of
# that user's cached score vector. Returns a Float32 vector.
function model(users, items, user_cache)
    preds = fill(0.0f0, length(users))
    @showprogress for k in eachindex(preds)
        scores = user_cache[user_to_index[users[k]]]
        preds[k] = scores[items[k] + 1]
    end
    return preds
end

# Computes every user's score vector for `metric` from the global `embeddings`
# and returns a dictionary keyed by the user's index (the values of
# `user_to_index`). Scores pass through `softmax` for "watch"/"plantowatch",
# through `sigmoid` for "drop", and are left untouched for any other metric.
function get_cache(metric, user_to_index, weights, biases)
    use_softmax = metric in ("watch", "plantowatch")
    use_sigmoid = metric == "drop"
    E = weights[metric] * embeddings .+ biases[metric]
    cache = Dict()
    @showprogress for u in values(user_to_index)
        scores = E[:, u]
        cache[u] = use_softmax ? softmax(scores) : (use_sigmoid ? sigmoid(scores) : scores)
    end
    return cache
end;

In [None]:
# For every metric, precompute the per-user score vectors once and let
# write_alpha query them through `model`; results are stored under
# "$name/$metric" for the "test" and "negative" splits.
@showprogress for metric in ALL_METRICS
    cache = get_cache(metric, user_to_index, weights, biases)
    write_alpha(
        (users, items) -> model(users, items, cache),
        medium,
        "$name/$metric",
        ["test", "negative"],
    )
end;

In [None]:
# Log the loss of each metric's saved alpha on the "test" split.
for metric in ALL_METRICS, split in ["test"]
    alpha = "$name/$metric"
    val = compute_loss(metric, medium, [alpha], split)
    @info "$metric $split loss = $val"
end