# Finetuning
* Finetunes a transformer model to predict recent watches

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb")
@nbinclude("Data.ipynb");

In [None]:
import H5Zblosc
import HDF5
import JSON
import MLUtils
import NNlib: sigmoid
import Random
import SparseArrays: AbstractSparseArray, sparse, spzeros
import StatsBase: mean, sample

In [None]:
# Run parameters. Both are empty placeholders here — presumably overwritten by whatever
# launches this notebook before the cells below run (TODO confirm).
# `medium` is compared against "manga" / "anime" in `get_size` and is interpolated into
# the model name below.
medium = "";
# `mode` selects the branch executed at the end of the file: "dataset" or "train".
mode = "";

In [None]:
# Version tag shared by the pretrained and the finetuned model names.
const version = "v1"
# Name of the pretrained model: used to locate its config.json and passed to the
# training script via --initialize.
const pretrain_name = "all/Transformer/$version"
# Name of this finetuned model. Interpolates the global `medium` at the moment this
# cell runs, so `medium` must already hold its final value here.
const name = "$medium/Transformer/$version"
set_logging_outdir(name);

# Data

In [None]:
"""
    get_size(vocab_sizes, medium)

Return the vocabulary size for `medium`: `vocab_sizes[1]` for `"manga"` and
`vocab_sizes[2]` for `"anime"`.

# Throws
- `ArgumentError` if `medium` is neither `"manga"` nor `"anime"`.
"""
function get_size(vocab_sizes, medium)
    if medium == "manga"
        return vocab_sizes[1]
    elseif medium == "anime"
        return vocab_sizes[2]
    end
    # Falling through used to return `nothing`, which only surfaced later as an
    # unrelated error when the value was used as an array dimension.
    throw(ArgumentError("unknown medium: $(repr(medium))"))
end

"""
    sparse(x::RatingsDataset, medium::String, vocab_sizes::Vector)

Build a sparse matrix from `x` with row indices `x.itemid .+ 1`, column indices
`x.userid .+ 1` and values `x.metric`. The matrix has
`get_size(vocab_sizes, medium)` rows and `num_users()` columns.
"""
function SparseArrays.sparse(x::RatingsDataset, medium::String, vocab_sizes::Vector)
    # ids are shifted by one to form 1-based row/column indices
    rows = x.itemid .+ 1
    cols = x.userid .+ 1
    values = x.metric
    num_rows = get_size(vocab_sizes, medium)
    num_cols = num_users()
    return SparseArrays.sparse(rows, cols, values, num_rows, num_cols)
end;

In [None]:
"""
    get_training_split(metric, medium, fields)

Load the "training" split for `metric`/`medium` with `fields` plus `:updated_at` and
`:update_order`, pass it through `training_test_split`, and return the second of the
two results with its `updated_at` and `update_order` fields replaced by empty arrays.
"""
function get_training_split(metric, medium, fields)
    # the two extra columns are requested for `training_test_split` and emptied afterwards
    requested = vcat(fields, [:updated_at, :update_order])
    full_split = get_split("training", metric, medium, requested)
    _, kept = training_test_split(full_split)
    kept = @set kept.updated_at = []
    kept = @set kept.update_order = []
    return kept
end

"""
    get_labels(split, metric, medium, vocab_sizes)

Return the sparse label matrix for `metric` on the given `split`.

For `"training"` the data comes from `get_training_split`; for `"test"` it comes from
`get_split("test", ...)`. The result is built with the three-argument `sparse` method
defined in this file.

# Throws
- `ArgumentError` if `split` is neither `"training"` nor `"test"`.
"""
function get_labels(split, metric, medium, vocab_sizes)
    fields = [:userid, :itemid, :metric]
    df = if split == "training"
        get_training_split(metric, medium, fields)
    elseif split == "test"
        get_split("test", metric, medium, fields)
    else
        # explicit error instead of `@assert false`: assertions are not meant for
        # argument validation and give no hint about the offending value
        throw(ArgumentError("unknown split: $(repr(split))"))
    end
    return sparse(df, medium, vocab_sizes)
end

"""
    get_weights(split, metric, medium, vocab_sizes)

Return the sparse weight matrix for `metric` on the given `split`.

Loads the same rows as `get_labels`, then replaces the `metric` field with
`powerdecay(get_counts(df.userid), -1.0f0)` before building the sparse matrix.

# Throws
- `ArgumentError` if `split` is neither `"training"` nor `"test"`.
"""
function get_weights(split, metric, medium, vocab_sizes)
    fields = [:userid, :itemid, :metric]
    df = if split == "training"
        get_training_split(metric, medium, fields)
    elseif split == "test"
        get_split("test", metric, medium, fields)
    else
        # explicit error instead of `@assert false`: assertions are not meant for
        # argument validation and give no hint about the offending value
        throw(ArgumentError("unknown split: $(repr(split))"))
    end
    # presumably weights each row by the inverse of its user's row count — confirm
    # against the definitions of `powerdecay` / `get_counts`
    df = @set df.metric = powerdecay(get_counts(df.userid), -1.0f0)
    return sparse(df, medium, vocab_sizes)
end

"""
    get_users(split, medium)

Return the distinct user ids found in `split` for `medium`, as a vector.

For `"training"` the raw training split is first passed through `training_test_split`
and the second result is used; any other `split` value is forwarded to `get_raw_split`
unchanged.
"""
function get_users(split, medium)
    if split != "training"
        df = get_raw_split(split, medium, [:userid], nothing)
    else
        columns = [:userid, :updated_at, :update_order]
        raw = get_raw_split("training", medium, columns, nothing)
        _, df = training_test_split(raw)
    end
    # deduplicate through a Set; the resulting order is whatever the Set iterates in
    return collect(Set(df.userid))
end;

In [None]:
"""
    tokenize(sentences, labels, weights, medium, userid, config)

Assemble the inputs for one user and forward them to the keyword form of `tokenize`.

The sentence is a copy of `sentences[userid]` when that key exists; otherwise it is a
one-word sentence holding `config[:cls_tokens]` with its `:userid` field replaced.
Each matrix in `labels` and `weights` is reduced to column `userid + 1`.
"""
function tokenize(sentences, labels, weights, medium, userid, config)
    sentence = if userid in keys(sentences)
        # copy so the keyword method can push onto it without touching `sentences`
        copy(sentences[userid])
    else
        wordtype[replace(config[:cls_tokens], :userid, userid)]
    end
    column = userid + 1
    user_slice(matrices) = map(x -> x[:, column], matrices)
    return tokenize(;
        sentence = sentence,
        labels = user_slice(labels),
        weights = user_slice(weights),
        medium = medium,
        userid = userid,
        max_seq_len = config[:max_sequence_length],
        vocab_sizes = config[:vocab_sizes],
        pad_tokens = config[:pad_tokens],
        cls_tokens = config[:cls_tokens],
        mask_tokens = config[:mask_tokens],
    )
end;

In [None]:
"""
    tokenize(; sentence, labels, weights, medium, userid, max_seq_len, vocab_sizes,
               pad_tokens, cls_tokens, mask_tokens)

Turn one user's sentence into model inputs and targets.

The sentence is cut to at most `max_seq_len - 1` words via `subset_sentence` (with
`recent = true`), and a copy of `mask_tokens` with its `:updated_at`, `:position` and
`:userid` fields replaced is appended. Returns a 4-tuple:

- the token ids from `get_token_ids`,
- a one-element vector holding `length(sentence) - 1` measured after the append,
- a nested `Dict` (medium, then metric) of sparse `Float32` label vectors,
- a nested `Dict` of the same shape holding the weights.

Only the entries for `medium` are filled from `labels` / `weights`; every other medium
keeps all-zero vectors. `cls_tokens` is accepted but not used here.
"""
function tokenize(;
    sentence::Vector{wordtype},
    labels,
    weights,
    medium,
    userid,
    max_seq_len,
    vocab_sizes,
    pad_tokens,
    cls_tokens,
    mask_tokens,
)
    # get inputs
    keep = min(length(sentence), max_seq_len - 1)
    sentence = subset_sentence(sentence, keep; recent = true)
    # `:position` is computed before the masked word is appended
    overrides = (:updated_at => 1, :position => length(sentence) - 1, :userid => userid)
    masked_word = foldl((word, (field, value)) -> replace(word, field, value), overrides;
        init = mask_tokens)
    push!(sentence, masked_word)
    tokens = get_token_ids(sentence, max_seq_len, pad_tokens, false)

    # get outputs
    positions = [length(sentence) - 1]
    empty_targets() = Dict(
        x => Dict(y => spzeros(Float32, get_size(vocab_sizes, x)) for y in ALL_METRICS)
        for x in ALL_MEDIUMS
    )
    tokenized_labels = empty_targets()
    tokenized_weights = empty_targets()
    for (i, metric) in enumerate(ALL_METRICS)
        tokenized_labels[medium][metric] .= labels[i]
        tokenized_weights[medium][metric] .= weights[i]
    end
    return tokens, positions, tokenized_labels, tokenized_weights
end;

# Epochs

In [None]:
"""
    record_sparse_array!(d::Dict, name::String, x::AbstractSparseArray)

Store `x` in `d` in coordinate form under four keys prefixed with `name`:
`name_i` and `name_j` hold the row and column indices shifted to start at zero,
`name_v` holds the stored values, and `name_size` holds the dimensions of `x`.
"""
function record_sparse_array!(d::Dict, name::String, x::AbstractSparseArray)
    rows, cols, vals = SparseArrays.findnz(x)
    d[string(name, "_i")] = rows .- 1
    d[string(name, "_j")] = cols .- 1
    d[string(name, "_v")] = vals
    d[string(name, "_size")] = collect(size(x))
end;

In [None]:
"""
    save_tokens(sentences, labels, weights, users, config, filename)

Tokenize every user in `users` and write the batched result to the HDF5 file
`filename`.

The file contains one dataset per entry of `config.vocab_names`, a `positions` dataset,
and, for every medium/metric pair, the coordinate-form datasets written by
`record_sparse_array!` under the `labels_...` and `weights_...` prefixes.
"""
function save_tokens(sentences, labels, weights, users, config, filename)
    tokens = Any[nothing for _ = 1:length(users)]
    # NOTE: `medium` here is the notebook-level global; this function has no parameter
    # of that name and the loop variable further down is local to its own loop.
    Threads.@threads for i in eachindex(tokens)
        tokens[i] = tokenize(sentences, labels, weights, medium, users[i], config)
    end

    d = Dict{String,AbstractArray}()
    # batch one component, selected by `pick`, across all tokenized users
    stacked(pick) = MLUtils.batch([pick(x) for x in tokens])
    for (i, name) in Iterators.enumerate(config.vocab_names)
        d[name] = stacked(x -> x[1][i])
    end
    d["positions"] = stacked(x -> x[2])
    for medium in ALL_MEDIUMS, metric in ALL_METRICS
        record_sparse_array!(
            d,
            "labels_$(medium)_$(metric)",
            stacked(x -> x[3][medium][metric]),
        )
        record_sparse_array!(
            d,
            "weights_$(medium)_$(metric)",
            stacked(x -> x[4][medium][metric]),
        )
    end
    HDF5.h5open(filename, "w") do f
        for (key, value) in d
            f[key, blosc = 3] = value
        end
    end
end;

In [None]:
"""
    save_epoch(sentences, labels, weights, users, config, epoch, outdir, split)

Write one epoch of tokenized batches for `split` below `outdir/split/epoch`.

`users` is shuffled in place, partitioned into chunks of `config.batch_size`, and each
chunk is written by `save_tokens` to a file named after its zero-based index. Asserts
that the number of users written equals the `<split>_epoch_size` field of `config`.
"""
function save_epoch(sentences, labels, weights, users, config, epoch, outdir, split)
    epoch_dir = joinpath(outdir, split, "$epoch")
    mkpath(epoch_dir)
    Random.shuffle!(users)  # mutates the caller's vector
    num_sentences = 0
    expected_num_sentences = getfield(config, Symbol("$(split)_epoch_size"))
    batches = collect(Iterators.enumerate(Iterators.partition(users, config.batch_size)))
    # the progress bar is only shown for the first epoch
    @showprogress enabled = epoch == 0 for (i, batch) in batches
        num_sentences += length(batch)
        target = joinpath(epoch_dir, "$(i-1).h5")
        save_tokens(sentences, labels, weights, batch, config, target)
    end
    @assert num_sentences == expected_num_sentences
end;

# Configs

In [None]:
"""
    create_training_config(pretrain_name, medium)

Read the `config.json` of the pretrained model `pretrain_name` and return it as a
`NamedTuple`, with `mode` set to `"finetune"`, `medium` set to the given value, and the
training/validation epoch size and token-count entries reset to `nothing`.
"""
function create_training_config(pretrain_name, medium)
    file = joinpath(get_data_path("alphas/$pretrain_name/0"), "config.json")
    d = open(JSON.parse, file)
    d["mode"] = "finetune"
    d["medium"] = medium
    for split in ["training", "validation"]
        d["$(split)_epoch_size"] = nothing
        d["$(split)_epoch_tokens"] = nothing
    end
    # string keys become field names
    return (; (Symbol(key) => value for (key, value) in d)...)
end;

In [None]:
"""
    set_epoch_size(config, users)

Return a copy of `config` whose `training_epoch_size` and `validation_epoch_size` are
the lengths of `users[1]` and `users[2]`. Both counts are also logged.
"""
function set_epoch_size(config, users)
    foreach(zip(["training", "validation"], users)) do (name, u)
        @info "Number of $name sentences: $(length(u))"
    end
    sizes = (
        training_epoch_size = length(users[1]),
        validation_epoch_size = length(users[2]),
    )
    return merge(config, sizes)
end;

In [None]:
"""
    setup_training(config, outdir)

Prepare `outdir` for a fresh run: create it, write `config` to `config.json` as JSON,
and make sure the `training` and `validation` subdirectories exist and are empty.
"""
function setup_training(config, outdir)
    mkpath(outdir)
    open(joinpath(outdir, "config.json"), "w") do io
        write(io, JSON.json(config))
    end
    for split in ["training", "validation"]
        split_dir = joinpath(outdir, split)
        mkpath(split_dir)
        # remove everything left over from a previous run
        foreach(entry -> rm(entry, recursive = true), readdir(split_dir, join = true))
    end
end;

In [None]:
# Generate `num_epochs` epochs of finetuning data (training and validation) under
# alphas/<name>/0, starting from the config of the pretrained model `pretrain_name`.
# Reads the notebook-level globals `medium` and `name`.
function save_epochs(num_epochs, pretrain_name)
    # fixed seed so repeated runs start from the same RNG state
    Random.seed!(20221221)
    config = create_training_config(pretrain_name, medium)
    @info "loading data"
    # users[1] are the training users, users[2] the test users
    users = get_users.(["training", "test"], (medium,))
    # Two sentence sets over the union of both user lists: the first is built with
    # holdout = true, the second with holdout = false. The zip below pairs them with
    # the "training" and "validation" outputs respectively.
    sentences = [
        get_training_data(
            config[:cls_tokens],
            config[:max_sequence_length],
            nothing,
            vcat(users...),
            holdout,
        ) for holdout in [true, false]
    ]
    config = set_epoch_size(config, users)
    # labels[k][j] / weights[k][j]: split k ("training", then "test"), metric j
    labels = [
        [get_labels(s, m, medium, config.vocab_sizes) for m in ALL_METRICS] for
        s in ["training", "test"]
    ]
    weights = [
        [get_weights(s, m, medium, config.vocab_sizes) for m in ALL_METRICS] for
        s in ["training", "test"]
    ]
    outdir = get_data_path(joinpath("alphas", name, "0"))
    # writes config.json and empties the training/validation directories
    setup_training(config, outdir)

    for epoch = 0:num_epochs-1
        # the "test" users, labels and weights are written under the "validation" directory
        for (s, l, w, u, t) in
            zip(sentences, labels, weights, users, ["training", "validation"])
            save_epoch(s, l, w, u, config, epoch, outdir, t)
        end
    end
end;

In [None]:
"""
    copy_epochs(num_epochs, num_source_epochs)

Extend the epochs on disk from `num_source_epochs` to `num_epochs` by copying.

For both the `training` and `validation` splits, epoch directory `i` (for
`i = num_source_epochs, ..., num_epochs - 1`) is created and filled with copies of the
files from source epoch `i % num_source_epochs`. Reads the notebook-level global `name`.
"""
function copy_epochs(num_epochs, num_source_epochs)
    for s in ["training", "validation"]
        getpath(epoch) = get_data_path(joinpath("alphas", name, "0", s, "$epoch"))
        for i = num_source_epochs:num_epochs-1
            # Derive the source from `i` alone. The previous running counter was shared
            # across splits, so "validation" could start from a non-zero source epoch
            # whenever the number of copied epochs was not a multiple of
            # `num_source_epochs`.
            src = getpath(i % num_source_epochs)
            dst = getpath(i)
            mkdir(dst)
            for file in readdir(src)
                cp(joinpath(src, file), joinpath(dst, file))
            end
        end
    end
end;

# Saving

In [None]:
"""
    model(users, items, cache)

Look up one prediction per `(users[i], items[i])` pair and return them as a
`Vector{Float32}`. `cache[user]` is indexed at `item + 1`.
"""
function model(users, items, cache)
    preds = zeros(Float32, length(users))
    @showprogress for idx in eachindex(preds)
        user, item = users[idx], items[idx]
        preds[idx] = cache[user][item+1]
    end
    return preds
end

"""
    get_cache(metric::String, embeddings, user_to_index, seen)

Build a `Dict` mapping each user in `user_to_index` to that user's prediction vector,
taken from the matching column of `embeddings` and post-processed by `metric`:

- `"watch"` / `"plantowatch"`: exponentiate, zero the entries listed as non-zero in
  column `user + 1` of `seen`, then renormalise to sum to one;
- `"drop"`: apply `sigmoid` elementwise;
- anything else: use the column unchanged.
"""
function get_cache(metric::String, embeddings, user_to_index, seen)
    normalized = metric == "watch" || metric == "plantowatch"
    cache = Dict()
    @showprogress for (user, index) in user_to_index
        scores = embeddings[:, index]
        if normalized
            scores = exp.(scores) # the model saves log-softmax values
            scores[seen[:, user+1].nzind] .= 0 # zero out watched items
            scores = scores ./ sum(scores)
        elseif metric == "drop"
            scores = sigmoid.(scores)
        end
        cache[user] = scores
    end
    return cache
end;

"""
    get_cache(metric::String, medium::String)

Load every `embeddings.<shard>.h5` file of the current model (shards 0, 1, ... until the
first missing file) and return a `Dict` from user id to that user's processed
prediction vector for `medium`/`metric`.

For `"watch"` and `"plantowatch"` a sparse matrix of the test users' training items is
built first and handed to the four-argument `get_cache` method; for other metrics
`nothing` is passed instead. Reads the notebook-level global `name`.
"""
function get_cache(metric::String, medium::String)
    if metric in ["watch", "plantowatch"]
        df = get_raw_split("training", medium, [:userid, :itemid], nothing)
        test_users = Set(get_users("test", medium))
        df = filter(df, df.userid .∈ (test_users,))
        df = @set df.metric = ones(Float32, length(df.userid))
        # NOTE(review): this two-argument `sparse(df, medium)` is not the method defined
        # in this file (that one also takes `vocab_sizes`); presumably it comes from an
        # included notebook — confirm.
        seen = sparse(df, medium)
    else
        seen = nothing
    end

    cache = Dict()
    shard = 0
    while true
        fn = get_data_path(joinpath("alphas", name, "embeddings.$shard.h5"))
        shard += 1
        if !isfile(fn)
            break
        end
        # The do-block closes the file even if a read throws; the previous explicit
        # `close` was skipped on error.
        users, preds = HDF5.h5open(fn, "r") do file
            (read(file["users"]), read(file["$(medium)_$(metric)"]))
        end
        user_to_index = Dict()
        for (i, u) in Iterators.enumerate(users)
            user_to_index[u] = i
        end
        # merge in place rather than rebuilding the whole dict for every shard
        merge!(cache, get_cache(metric, preds, user_to_index, seen))
    end
    return cache
end;

In [None]:
"""
    save_alphas()

For every metric in `ALL_METRICS`, build the prediction cache and write it out through
`write_alpha` under `<name>/<metric>` for the `"test"` and `"negative"` splits.
Reads the notebook-level globals `medium` and `name`.
"""
function save_alphas()
    for metric in ALL_METRICS
        cache = get_cache(metric, medium)
        # the do-block is the prediction function passed as first argument to `write_alpha`
        write_alpha(medium, "$name/$metric", ["test", "negative"]) do users, items
            model(users, items, cache)
        end
    end
end;

In [None]:
"""
    log_alphas()

Log the loss returned by `compute_loss` for every metric in `ALL_METRICS` on the
`"test"` split. Reads the notebook-level globals `medium` and `name`.
"""
function log_alphas()
    for metric in ALL_METRICS, split in ["test"]
        val = compute_loss(metric, medium, ["$name/$metric"], split)
        @info "$metric $split loss = $val"
    end
end;

In [None]:
"""
    cleanup()

Delete the intermediate artifacts of the current model: the `0` directory under
`alphas/<name>` (recursively) and every consecutive `embeddings.<shard>.h5` file
starting at shard 0. Reads the notebook-level global `name`.
"""
function cleanup()
    model_dir = get_data_path(joinpath("alphas", name))
    rm(joinpath(model_dir, "0"); recursive = true)
    shard_file(shard) = get_data_path(joinpath("alphas", name, "embeddings.$shard.h5"))
    index = 0
    fn = shard_file(index)
    # stop at the first shard index that has no file
    while isfile(fn)
        rm(fn)
        index += 1
        fn = shard_file(index)
    end
end;

# Train model

In [None]:
# Entry point: `mode` selects which stage of the pipeline runs.
if mode == "dataset"
    # write 4 epochs from data, then fill epochs 4..15 with copies of them
    save_epochs(4, pretrain_name)
    copy_epochs(16, 4)
elseif mode == "train"
    # train via the external script, then export and score the predictions and
    # delete the intermediate files
    run(`python3 Pytorch.py --outdir $name --initialize $pretrain_name`)
    save_alphas()
    log_alphas()
    cleanup()
else
    # fail with a message naming the bad value instead of a bare `@assert false`
    throw(ArgumentError("unknown mode $(repr(mode)); expected \"dataset\" or \"train\""))
end