# Finetuning
* Finetunes a transformer model to predict recent watches
* A recent watch is defined as an interaction that occurred within the past $D$ days and is one of the $N$ most recent interactions for that user

In [None]:
# Run parameters. `medium` and `task` are empty placeholders here — presumably filled in
# by whatever launches this notebook; TODO confirm.
medium = ""
task = ""
# Base model identifier; the following cell expands it into the pretrain/finetune names.
name = "Transformer/v1"

In [None]:
# Name of the pretrained model; its config.json is loaded by `create_training_config`
# and it is passed to the trainer via `--initialize`.
pretrain_name = "all/$name"
# Rebind `name` to the finetuned model's name. This must run after `pretrain_name` is
# derived, because both are built from the original value of `name`.
name = "$medium/$task/$name";

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb")
@nbinclude("Data.ipynb");

In [None]:
import HDF5
import JSON
import MLUtils
import Random
import SparseArrays: AbstractSparseArray, sparse, spzeros
import StatsBase: mean, sample
import ThreadPinning

In [None]:
# Pin Julia threads using ThreadPinning's `:cores` strategy before any workers are spawned
ThreadPinning.pinthreads(:cores)

# Featurization

In [None]:
"""
    featurize(sentences, labels, weights, medium, user, config, training::Bool)

Gather the inputs for `user` and forward them, together with the relevant `config`
entries, to the keyword-argument `featurize` method.

If `user` has no entry in `sentences`, a one-word sentence is built from the configured
cls tokens with the `:user` field replaced by `user`. Each element of `labels` and
`weights` is indexed as `[:, user]` to select that user's column.
"""
function featurize(sentences, labels, weights, medium, user, config, training::Bool)
    sentence = if haskey(sentences, user)
        # work on a copy — presumably so later edits don't touch the stored sentence
        copy(sentences[user])
    else
        wordtype[replace(config["cls_tokens"], :user, user)]
    end
    user_column = x -> x[:, user]
    return featurize(;
        sentence = sentence,
        labels = map(user_column, labels),
        weights = map(user_column, weights),
        medium = medium,
        user = user,
        max_seq_len = config["max_sequence_length"],
        vocab_sizes = config["base_vocab_sizes"],
        pad_tokens = config["pad_tokens"],
        cls_tokens = config["cls_tokens"],
        mask_tokens = config["mask_tokens"],
        empty_tokens = config["empty_tokens"],
        causal = config["causal"],
        training = training,
    )
end;

In [None]:
# Build the per-medium target containers used for both labels and weights.
# Every medium gets empty sparse `item`/`rating` vectors of length `num_items(x)`;
# only the entry for `medium` is filled, from `values[1]` (item) and `values[2]` (rating).
function _featurize_targets(values, medium)
    targets = Dict(
        x => (
            item = spzeros(Float32, num_items(x)),
            rating = spzeros(Float32, num_items(x)),
        ) for x in ["anime", "manga"]
    )
    targets[medium][:item] .= values[1]
    targets[medium][:rating] .= values[2]
    return targets
end

"""
    featurize(; sentence, labels, weights, medium, user, max_seq_len, vocab_sizes,
              pad_tokens, cls_tokens, mask_tokens, empty_tokens, causal, training)

Turn one user's sentence into model inputs and prediction targets.

The sentence is cut down to at most `max_seq_len - 1` words with `subset_sentence`
(`recent = true`). For causal models the `:timestamp` field of the last word is set to 1;
otherwise a mask word (timestamp 1, position `length(sentence)`, the given `user`) is
appended. `empty_tokens` and `training` are accepted but not used by this method.

# Returns
A tuple `(tokens, positions, featurized_labels, featurized_weights)`:
- `tokens`: the outputs of `get_token_ids` for this single sentence, each flattened
  with `vec`.
- `positions`: one-element vector holding the length of the final sentence.
- `featurized_labels`, `featurized_weights`: dictionaries keyed by `"anime"`/`"manga"`
  holding `(item, rating)` sparse `Float32` vectors; only the `medium` entry is non-empty.

# Throws
- `ArgumentError` if the model is not causal and the notebook-level `task` is not
  `"temporal_causal"`.
"""
function featurize(;
    sentence::Vector{wordtype},
    labels,
    weights,
    medium,
    user,
    max_seq_len,
    vocab_sizes,
    pad_tokens,
    cls_tokens,
    mask_tokens,
    empty_tokens,
    causal,
    training,
)
    # leave one slot free so a mask word can be appended in the non-causal case
    sentence = subset_sentence(
        sentence,
        min(length(sentence), max_seq_len - 1);
        recent = true,
        rng = nothing,
    )

    if causal
        sentence[end] = replace(sentence[end], :timestamp, 1)
    else
        # `task` is the notebook-level global, not a parameter of this method
        if task != "temporal_causal"
            throw(ArgumentError("unsupported task for non-causal finetuning: $task"))
        end
        # add mask token
        masked_word = replace(mask_tokens, :timestamp, 1)
        masked_word = replace(masked_word, :position, length(sentence))
        masked_word = replace(masked_word, :user, user)
        push!(sentence, masked_word)
    end

    # get tokenized sentences
    tokens =
        vec.(
            get_token_ids(
                [sentence],
                max_seq_len,
                extract(vocab_sizes, :position),
                pad_tokens,
                cls_tokens,
            ),
        )
    positions = [length(sentence)]

    featurized_labels = _featurize_targets(labels, medium)
    featurized_weights = _featurize_targets(weights, medium)

    return tokens, positions, featurized_labels, featurized_weights
end;

# Data collection

In [None]:
"""
    get_labels(task, content, medium)

Return an items-by-users sparse matrix of ratings built from the concatenated
validation and test splits for the given `task`, `content` and `medium`.
"""
function get_labels(task, content, medium)
    validation = get_split("validation", task, content, medium)
    test = get_split("test", task, content, medium)
    df = cat(validation, test)
    return sparse(df.item, df.user, df.rating, num_items(medium), num_users())
end

"""
    get_labels(task)

Return the label matrices for the notebook-level `medium`, ordered as
`[implicit, explicit]`.
"""
function get_labels(task)
    return map(content -> get_labels(task, content, medium), ["implicit", "explicit"])
end;

In [None]:
"""
    get_weights(task, content, medium)

Return an items-by-users sparse matrix of per-interaction weights for the concatenated
validation and test splits. Weights come from `powerdecay` applied to each split's
counts with the `"inverse"` weighting scheme, stacked in the same validation-then-test
order as the rows of the concatenated splits.
"""
function get_weights(task, content, medium)
    validation = get_split("validation", task, content, medium)
    test = get_split("test", task, content, medium)
    df = cat(validation, test)

    per_split = map(("validation", "test")) do split
        powerdecay(get_counts(split, task, content, medium), weighting_scheme("inverse"))
    end
    w = vcat(per_split...)

    return sparse(df.item, df.user, w, num_items(medium), num_users())
end

"""
    get_weights(task)

Return the weight matrices for the notebook-level `medium`, ordered as
`[implicit, explicit]`.
"""
function get_weights(task)
    return map(content -> get_weights(task, content, medium), ["implicit", "explicit"])
end;

In [None]:
"""
    get_users(task, medium)

Return `(training, test)` vectors of distinct user ids.

Training users are those appearing in the implicit validation split; test users are
those appearing in the test split of any content type in `ALL_CONTENTS`.
"""
function get_users(task, medium)
    training = collect(Set(get_split("validation", task, "implicit", medium).user))
    test_sets = [
        Set(get_split("test", task, content, medium; fields = [:user]).user) for
        content in ALL_CONTENTS
    ]
    test = collect(union(test_sets...))
    return training, test
end

"""
    get_users(task)

Return `get_users(task, medium)` for the notebook-level `medium`.
"""
get_users(task) = get_users(task, medium);

In [None]:
"""
    get_sentences(config, task, users)

Load sentences for `users` by calling `get_training_data` with the media, impression,
token and causality settings taken from `config`, as a single chunk.
"""
function get_sentences(config, task, users)
    settings = (
        config["media"],
        config["include_ptw_impressions"],
        config["cls_tokens"],
        config["empty_tokens"],
        config["causal"],
    )
    return get_training_data(task, settings...; chunks = 1, users = users)
end;

# Configuration

In [None]:
"""
    set_rngs(seed)

Create a `Xoshiro` generator from `seed`, reseed the global RNG with the first `UInt64`
drawn from it, and return the generator.
"""
function set_rngs(seed)
    generator = Random.Xoshiro(seed)
    global_seed = rand(generator, UInt64)
    Random.seed!(global_seed)
    return generator
end;

In [None]:
"""
    create_training_config(pretrain_name)

Load the pretrained model's `config.json` from its `alphas` data directory and return
it with `"mode"` set to `"finetune"`.
"""
function create_training_config(pretrain_name)
    config_path = joinpath(get_data_path("alphas/$pretrain_name"), "config.json")
    config = open(JSON.parse, config_path)
    config["mode"] = "finetune"
    return config
end;

In [None]:
"""
    set_epoch_size!(config, training_users, validation_users)

Log the dataset sizes and store the number of training and validation users in
`config["training_epoch_size"]` and `config["validation_epoch_size"]`.
Returns the number of validation users.
"""
function set_epoch_size!(config, training_users, validation_users)
    num_training = length(training_users)
    num_validation = length(validation_users)
    # upper bound: one full-length sentence per training user
    num_tokens = num_training * config["max_sequence_length"]
    @info "Number of training tokens: $(num_tokens)"
    @info "Number of training sentences: $(length(training_users))"
    @info "Number of validation sentences: $(length(validation_users))"
    config["training_epoch_size"] = num_training
    config["validation_epoch_size"] = num_validation
    return num_validation
end;

In [None]:
"""
    setup_training(config, outdir)

Prepare `outdir` for a fresh training run.

Creates `outdir` (including any missing parent directories), deletes every regular file
directly inside it (subdirectories are left alone), and writes `config` as JSON to
`config.json` in the parent directory of `outdir`.
"""
function setup_training(config, outdir)
    # mkpath is a no-op when the directory exists and, unlike mkdir, also creates
    # missing parents such as the per-model directory on a first run
    mkpath(outdir)
    for x in readdir(outdir, join = true)
        if isfile(x)
            rm(x)
        end
    end
    fn = joinpath(outdir, "..", "config.json")
    open(fn, "w") do f
        write(f, JSON.json(config))
    end
end;

# Disk I/O

In [None]:
"""
    save_features(sentences, labels, weights, users, config, training, filename)

Featurize every user in `users` (for the notebook-level `medium`) and write the batched
result to the HDF5 file `filename`.

The file holds one dataset per embedding field, a `positions` dataset, a `causal` flag,
and the label/weight sparse arrays in coordinate form as written by
`record_sparse_array!`.
"""
function save_features(sentences, labels, weights, users, config, training, filename)
    features = [
        featurize(sentences, labels, weights, medium, u, config, training) for u in users
    ]

    d = Dict{String,AbstractArray}()
    d["causal"] = [config["causal"]]
    # order must match the order of the token arrays returned by featurize
    embed_names = [
        "anime",
        "manga",
        "rating",
        "timestamp",
        "status",
        "completion",
        "user",
        "position",
    ]
    for (i, field) in Iterators.enumerate(embed_names)
        d[field] = MLUtils.batch([f[1][i] for f in features])
    end
    d["positions"] = MLUtils.batch([f[2] for f in features])

    # slot 3 of each feature tuple holds the labels, slot 4 the weights
    for m in ["anime", "manga"], t in ["item", "rating"]
        record_sparse_array!(
            d,
            "labels_$(m)_$(t)",
            MLUtils.batch([f[3][m][Symbol(t)] for f in features]),
            extract(config["vocab_sizes"], Symbol(m)),
        )
    end
    for m in ["anime", "manga"], t in ["item", "rating"]
        record_sparse_array!(
            d,
            "weights_$(m)_$(t)",
            MLUtils.batch([f[4][m][Symbol(t)] for f in features]),
            extract(config["vocab_sizes"], Symbol(m)),
        )
    end

    HDF5.h5open(filename, "w") do file
        for (key, value) in d
            write(file, key, value)
        end
    end
end

"""
    record_sparse_array!(d, name, x, vocab_size)

Store the sparse matrix `x` in `d` in coordinate form under the keys `name * "_i"`
(row indices), `name * "_j"` (column indices), `name * "_v"` (values) and
`name * "_size"` (`[vocab_size, number of columns of x]`).
"""
function record_sparse_array!(d::Dict, name::String, x::AbstractSparseArray, vocab_size)
    rows, cols, vals = SparseArrays.findnz(x)
    for (suffix, component) in (("_i", rows), ("_j", cols), ("_v", vals))
        d[name * suffix] = component
    end
    # the recorded row count is the vocabulary size, not size(x, 1)
    d[name * "_size"] = [vocab_size, size(x, 2)]
end;

In [None]:
"""
    advance!(filename)

Check whether every reader has finished with the shard at `filename` and, if so, delete
it so the next shard can be written.

Readers are detected through marker files in the same directory whose names contain
`basename(filename) * ".read"`; the last two dot-separated components of each marker
name are parsed as `world_size` and `num_workers`. When the number of markers equals
`world_size * num_workers`, the shard, its `.complete` file and all markers are removed
and `true` is returned. Otherwise nothing is deleted and `false` is returned.
"""
function advance!(filename)
    outdir = dirname(filename)
    marker = basename(filename) * ".read"
    receipts = filter(x -> occursin(marker, basename(x)), readdir(outdir))
    isempty(receipts) && return false

    # all markers must agree on the reader layout
    layouts = Set(split(x, ".")[end-1:end] for x in receipts)
    @assert length(layouts) == 1
    world_size, num_workers = parse.(Int, first(layouts))
    if length(receipts) != world_size * num_workers
        return false
    end

    rm("$filename.complete")
    rm(filename)
    foreach(x -> rm(joinpath(outdir, x)), receipts)
    return true
end;

In [None]:
"""
    spawn_feature_workers(sentences, labels, weights, users, config, rng, training, outdir)

Start background tasks that keep the training (or validation) shards in `outdir`
populated.

`users` is split into at most `config["num_training_shards"]` (or
`config["num_validation_shards"]`) contiguous batches, one task per batch. Each task
shuffles its batch in place, writes it chunk by chunk (`config["chunk_size"]` users per
chunk) with `save_features`, and after each chunk waits until `advance!` reports that
the shard was consumed. Tasks stop once `outdir` no longer exists. The function itself
returns immediately.
"""
function spawn_feature_workers(
    sentences,
    labels,
    weights,
    users,
    config,
    rng,
    training,
    outdir,
)
    # each worker writes to "$outdir/$stem.$i.h5" in a hot loop;
    # whenever that file disappears, the worker populates it with a new chunk
    chunk_size = config["chunk_size"]
    workers = training ? config["num_training_shards"] : config["num_validation_shards"]
    stem = training ? "training" : "validation"
    # derive one independent generator per worker from the caller's rng
    rngs = [Random.Xoshiro(rand(rng, UInt64)) for _ = 1:workers]
    for (i, batch) in Iterators.enumerate(
        Iterators.partition(users, div(length(users), workers, RoundUp)),
    )
        Threads.@spawn begin
            # use a task-local name: assigning to `rng` here would rebind the enclosing
            # function's `rng`, which every spawned task shares
            worker_rng = rngs[i]
            filename = joinpath(outdir, "$stem.$i.h5")
            # stop for good once the output directory has been removed
            while isdir(outdir)
                Random.shuffle!(worker_rng, batch)
                for (j, chunk) in
                    Iterators.enumerate(Iterators.partition(batch, chunk_size))
                    save_features(
                        sentences,
                        labels,
                        weights,
                        chunk,
                        config,
                        training,
                        filename,
                    )
                    # marker file telling readers the shard is fully written
                    open("$filename.complete", "w") do f
                        write(f, "$j")
                    end
                    GC.gc()
                    while isdir(outdir) && !advance!(filename)
                        sleep(1)
                    end
                    if !isdir(outdir)
                        break
                    end
                end
            end
        end
    end
end;

# Train model

In [None]:
# NOTE(review): `config_checkpoint`, `config_epoch` and `reset_lr_schedule` are not read
# by any other code in this notebook — possibly consumed by an included notebook; confirm.
config_checkpoint = nothing
config_epoch = nothing
reset_lr_schedule = true
# Seeds the global RNG as a side effect and returns the generator used for the workers
rng = set_rngs(20221221)
# Start from the pretrained model's config, switched to finetune mode
config = create_training_config(pretrain_name);

In [None]:
@info "loading data"
# "training" users come from the validation split and "test" users from the test split
# (see `get_users`)
training_users, test_users = get_users(task)
# sentences are only loaded for users that will actually be featurized
sentences = get_sentences(config, task, Set(vcat(training_users, test_users)))
labels = get_labels(task)
weights = get_weights(task)
# the test users are passed as the validation set for epoch sizing
set_epoch_size!(config, training_users, test_users);

In [None]:
# Directory for the training shards. `setup_training` deletes any files already in it
# and writes config.json one level above it.
outdir = get_data_path(joinpath("alphas", name, "training"))
setup_training(config, outdir);

In [None]:
# Persist the user ids of both splits to users.h5 in `outdir`
HDF5.h5open(joinpath(outdir, "users.h5"), "w") do file
    write(file, "training", training_users)
    write(file, "test", test_users)
end

In [None]:
# Start the background tasks that write the "training" shards (training = true)
spawn_feature_workers(sentences, labels, weights, training_users, config, rng, true, outdir);

In [None]:
# Start the background tasks that write the "validation" shards from the test users
# (training = false)
spawn_feature_workers(sentences, labels, weights, test_users, config, rng, false, outdir);

In [None]:
# Run the PyTorch trainer as a blocking subprocess, passing the finetuned model name as
# the output directory, the pretrained model to initialize from, and an epoch count of 8
run(`python3 Pytorch.py --outdir $name --initialize $pretrain_name --epochs 8`)

# Save predictions

In [None]:
# Load the user embeddings and the item/rating output heads for `medium`.
# The do-block form closes the HDF5 file even if one of the reads throws.
embeddings, users, item_weight, item_bias, rating_weight, rating_bias =
    HDF5.h5open(get_data_path(joinpath("alphas", name, "embeddings.h5")), "r") do file
        (
            read(file["embedding"]),
            read(file["users"]),
            # the weight matrices are adjointed so they can left-multiply an embedding column
            read(file["$(medium)_item_weight"])',
            read(file["$(medium)_item_bias"]),
            read(file["$(medium)_rating_weight"])',
            read(file["$(medium)_rating_bias"]),
        )
    end

# Map each user id to its column index in `embeddings`
user_to_index = Dict(u => i for (i, u) in Iterators.enumerate(users));

In [None]:
# note that we only record predictions for test users
"""
    model(users, items, user_cache)

Return a `Float32` vector with one prediction per `(users[i], items[i])` pair.

Predictions are looked up as `user_cache[u][items[i]]`, where `u` is the user's index in
the notebook-level `user_to_index`. Users missing from `user_to_index` keep a
prediction of zero.
"""
function model(users, items, user_cache)
    ratings = zeros(Float32, length(users))
    @showprogress for i in eachindex(ratings)
        idx = get(user_to_index, users[i], nothing)
        isnothing(idx) && continue
        ratings[i] = user_cache[idx][items[i]]
    end
    return ratings
end;

In [None]:
# Implicit predictions: for every embedded user, apply the item head to the user's
# embedding column and normalize the result with softmax.
item_cache = Dict()
@showprogress for u in values(user_to_index)
    e = item_weight * embeddings[:, u] + item_bias
    item_cache[u] = softmax(e)
end
# Write the predictions under "<name>/implicit"; the log_* keywords are forwarded to
# `write_alpha` (defined elsewhere) — presumably to evaluate on the test split; TODO confirm.
write_alpha(
    (users, items) -> model(users, items, item_cache),
    medium,
    joinpath(name, "implicit");
    task = task,
    log = true,
    log_task = task,
    log_content = "implicit",
    log_alphas = String[],
    log_splits = ["test"],
)
# drop the reference so the cache can be garbage collected
item_cache = nothing;

In [None]:
# Explicit predictions: for every embedded user, apply the rating head to the user's
# embedding column (no normalization).
rating_cache = Dict()
@showprogress for u in values(user_to_index)
    e = rating_weight * embeddings[:, u] + rating_bias
    rating_cache[u] = e
end
# NOTE(review): `user_cache` is assigned here but never read in this notebook (the
# `model` function takes its own `user_cache` argument) — looks like a leftover; confirm.
user_cache = Dict()
# Write the predictions under "<name>/explicit"; the log_* keywords are forwarded to
# `write_alpha` (defined elsewhere) — presumably to evaluate on the test split; TODO confirm.
write_alpha(
    (users, items) -> model(users, items, rating_cache),
    medium,
    joinpath(name, "explicit");
    task = task,
    log = true,
    log_task = task,
    log_content = "explicit",
    log_alphas = String[],
    log_splits = ["test"],
)
# drop the reference so the cache can be garbage collected
rating_cache = nothing;

In [None]:
# Delete the shard directory. The feature workers check `isdir(outdir)` in their loops,
# so removing it is also what tells them to stop waiting for readers.
rm(outdir, recursive = true)