In [None]:
# Notebook-level settings. `medium` and `task` are read as globals by several
# functions below (e.g. the one-argument `get_labels`/`get_weights`/`get_users`
# methods and `featurize`); `name` is combined with them in the next cell.
medium = "anime"
task = "temporal"
name = "Transformer/v0"

In [None]:
# `pretrain_name` is the path under "alphas/" whose config.json is loaded by
# `create_training_config`. `name` is then rebound to the medium/task-qualified
# path under which this run's training directory is created.
pretrain_name = "all/$name"
name = "$medium/$task/$name";

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb")
@nbinclude("Data.ipynb");

In [None]:
import Flux
import HDF5
import JSON
import Random
import SparseArrays: AbstractSparseArray, findnz, sparse, spzeros
import StatsBase: mean, sample

# Featurization

In [None]:
"""
    featurize(sentences, labels, weights, medium, user, config, training::Bool)

Gather the per-user inputs and forward them to the keyword method of `featurize`.

The user's sentence is looked up in `sentences`; a user without an entry gets an
empty sentence of the collection's value type. Column `user` is sliced out of every
matrix in `labels` and `weights`, and the tokenization settings are read from `config`.
"""
function featurize(sentences, labels, weights, medium, user, config, training::Bool)
    user_sentence = if user in keys(sentences)
        sentences[user]
    else
        eltype(values(sentences))()
    end
    user_column(matrix) = matrix[:, user]
    return featurize(;
        sentence = user_sentence,
        labels = map(user_column, labels),
        weights = map(user_column, weights),
        medium = medium,
        user = user,
        max_seq_len = config["max_sequence_length"],
        vocab_sizes = config["base_vocab_sizes"],
        pad_tokens = config["pad_tokens"],
        cls_tokens = config["cls_tokens"],
        mask_tokens = config["mask_tokens"],
        empty_tokens = config["empty_tokens"],
        training = training,
    )
end;

In [None]:
"""
    featurize(; sentence, labels, weights, medium, user, max_seq_len, vocab_sizes,
              pad_tokens, cls_tokens, mask_tokens, empty_tokens, training)

Build the tokenized inputs and the sparse targets for a single user.

The sentence is cut down with `subset_sentence` to at most `max_seq_len - 1` words,
a mask word is appended, and the result is tokenized with `get_token_ids`.

Returns `(tokens, positions, featurized_labels, featurized_weights)`:
- `tokens`: the outputs of `get_token_ids`, each flattened with `vec`.
- `positions`: one-element vector holding the index of the appended mask word.
- `featurized_labels` / `featurized_weights`: dictionaries keyed by "anime" and
  "manga" whose values are `(item, rating)` named tuples of sparse `Float32`
  vectors. Only the entry for `medium` is filled (from `labels[1]`/`labels[2]`
  and `weights[1]`/`weights[2]`); the other stays all-zero.

The notebook-level global `task` selects how the mask word is built; any value
other than "temporal" or "temporal_causal" raises an error. `empty_tokens` and
`training` are accepted but not used by this method.
"""
function featurize(;
    sentence::Vector{wordtype},
    labels,
    weights,
    medium,
    user,
    max_seq_len,
    vocab_sizes,
    pad_tokens,
    cls_tokens,
    mask_tokens,
    empty_tokens,
    training,
)
    # Keep one slot free for the mask word that is appended below.
    sentence = subset_sentence(
        sentence,
        min(length(sentence), max_seq_len - 1);
        recent = true,
        keep_first = false,
        rng = nothing,
    )

    # Build the mask word according to the global `task`.
    if task == "temporal"
        masked_word = replace(mask_tokens, :timestamp, 1)
    elseif task == "temporal_causal"
        masked_word = replace(mask_tokens, :timestamp, 1)
        # NOTE(review): this previously read `length(s)`, but no `s` is defined in
        # this function; the truncated sentence length is the presumed intent.
        masked_word = replace(masked_word, :position, length(sentence))
    else
        error("unsupported task: $task")
    end
    masked_word = replace(masked_word, :user, user)
    push!(sentence, masked_word)
    masked_pos = length(sentence)

    # Tokenize the single sentence; each returned array is flattened to a vector.
    tokens =
        vec.(
            get_token_ids(
                [sentence],
                max_seq_len,
                extract(vocab_sizes, :position),
                pad_tokens,
                cls_tokens,
            ),
        )
    positions = [masked_pos]

    # One all-zero (item, rating) pair of sparse vectors per medium.
    empty_targets() = Dict(
        x => (
            item = spzeros(Float32, num_items(x)),
            rating = spzeros(Float32, num_items(x)),
        ) for x in ["anime", "manga"]
    )

    featurized_labels = empty_targets()
    featurized_labels[medium][:item] .= labels[1]
    featurized_labels[medium][:rating] .= labels[2]

    featurized_weights = empty_targets()
    featurized_weights[medium][:item] .= weights[1]
    featurized_weights[medium][:rating] .= weights[2]

    return tokens, positions, featurized_labels, featurized_weights
end;

# Data collection

In [None]:
"""
    get_labels(task, content, medium)

Return an items-by-users sparse matrix of ratings taken from the concatenated
"validation" and "test" splits.

When `content == "explicit"`, the per-user (`"u"`) and per-item (`"a"`) entries of
the `ExplicitUserItemBiases` parameters are subtracted from every rating first.
"""
function get_labels(task, content, medium)
    validation_split = get_split("validation", task, content, medium)
    test_split = get_split("test", task, content, medium)
    df = cat(validation_split, test_split)
    if content == "explicit"
        baseline = read_params("$medium/$task/ExplicitUserItemBiases")
        user_bias = baseline["u"]
        item_bias = baseline["a"]
        for k in eachindex(df.rating)
            df.rating[k] -= user_bias[df.user[k]] + item_bias[df.item[k]]
        end
    end
    return sparse(df.item, df.user, df.rating, num_items(medium), num_users(medium))
end

"""
    get_labels(task)

Return the label matrices for the global `medium`, in the order
`["implicit", "explicit"]`.
"""
function get_labels(task)
    return map(["implicit", "explicit"]) do content
        get_labels(task, content, medium)
    end
end;

In [None]:
"""
    get_weights(task, content, medium)

Return an items-by-users sparse matrix of per-entry weights for the concatenated
"validation" and "test" splits.

The weights of each split are produced by `powerdecay` applied to that split's
`get_counts` output with `weighting_scheme("inverse")`, and are concatenated in the
same validation-then-test order as the rows.
"""
function get_weights(task, content, medium)
    validation_split = get_split("validation", task, content, medium)
    test_split = get_split("test", task, content, medium)
    df = cat(validation_split, test_split)

    split_weights(which) =
        powerdecay(get_counts(which, task, content, medium), weighting_scheme("inverse"))
    w = vcat(split_weights("validation"), split_weights("test"))

    return sparse(df.item, df.user, w, num_items(medium), num_users(medium))
end

"""
    get_weights(task)

Return the weight matrices for the global `medium`, in the order
`["implicit", "explicit"]`.
"""
function get_weights(task)
    contents = ["implicit", "explicit"]
    return map(content -> get_weights(task, content, medium), contents)
end;

In [None]:
"""
    get_users(task, content, medium)

Return `(training, test)`: the distinct users of the "validation" split (used as
the training set here) and the distinct users of the "test" split.
"""
function get_users(task, content, medium)
    distinct_users(which) = collect(Set(get_split(which, task, content, medium).user))
    from_validation = distinct_users("validation")
    from_test = distinct_users("test")
    return from_validation, from_test
end

"""
    get_users(task)

Return the `(training, test)` users of the "implicit" content for the global `medium`.
"""
get_users(task) = get_users(task, "implicit", medium);

In [None]:
"""
    get_sentences(config, task)

Load the `ExplicitUserItemBiases` parameters for "anime" and "manga" and pass them,
together with the relevant `config` entries, to `get_training_data`.
"""
function get_sentences(config, task)
    media = ["anime", "manga"]
    bias_params = Dict(
        medium => read_params("$medium/$task/ExplicitUserItemBiases") for medium in media
    )
    return get_training_data(
        task,
        config["media"],
        config["include_ptw_impressions"],
        config["cls_tokens"],
        config["empty_tokens"];
        explicit_baseline = bias_params,
    )
end;

# Configuration

In [None]:
"""
    set_rngs(seed)

Create a `Random.Xoshiro` generator from `seed`, reseed the default RNG with one
`UInt64` drawn from it, and return the generator.
"""
function set_rngs(seed)
    generator = Random.Xoshiro(seed)
    default_seed = rand(generator, UInt64)
    Random.seed!(default_seed)
    return generator
end;

In [None]:
"""
    create_training_config(pretrain_name)

Parse the `config.json` stored under the `alphas/<pretrain_name>` data path and
return it with `"mode"` set to `"finetune"`.
"""
function create_training_config(pretrain_name)
    config_path = joinpath(get_data_path("alphas/$pretrain_name"), "config.json")
    settings = open(JSON.parse, config_path)
    settings["mode"] = "finetune"
    return settings
end;

In [None]:
"""
    set_epoch_size!(config, training_users, validation_users)

Log the dataset sizes and store the number of training and validation users in
`config["training_epoch_size"]` and `config["validation_epoch_size"]`.

Returns the number of validation users.
"""
function set_epoch_size!(config, training_users, validation_users)
    n_training = length(training_users)
    n_validation = length(validation_users)
    num_tokens = n_training * config["max_sequence_length"]
    @info "Number of training tokens: $(num_tokens)"
    @info "Number of training sentences: $(length(training_users))"
    @info "Number of validation sentences: $(length(validation_users))"
    config["training_epoch_size"] = n_training
    config["validation_epoch_size"] = n_validation
    return n_validation
end;

In [None]:
"""
    setup_training(config, outdir)

Prepare `outdir` for a fresh run and persist `config` next to it.

- Creates `outdir` (including missing parent directories) if it does not exist.
- Removes every regular file directly inside `outdir`; subdirectories are kept.
- Writes `config` as JSON to `config.json` in the parent of `outdir`. The data is
  written to a temporary `config.json~` first and then moved over the target, so a
  reader does not see a partially written file.

Returns the result of the final `mv`.
"""
function setup_training(config, outdir)
    # mkpath is a no-op for an existing directory and, unlike mkdir, also creates
    # missing parents, so a first run with a new model name does not fail.
    mkpath(outdir)
    for entry in readdir(outdir, join = true)
        if isfile(entry)
            rm(entry)
        end
    end
    fn = joinpath(outdir, "..", "config.json")
    tmp = fn * "~"
    open(tmp, "w") do f
        write(f, JSON.json(config))
    end
    return mv(tmp, fn, force = true)
end;

# Disk I/O

In [None]:
"""
    save_features(sentences, labels, weights, users, config, training, filename)

Featurize every user in `users` (for the global `medium`), batch the results with
`Flux.batch`, and write them to the HDF5 file `filename`.

Datasets written:
- one per embedding name ("anime", "manga", "rating", "timestamp", "status",
  "completion", "user", "position"), taken in that order from the token tuple
  returned by `featurize`;
- "positions";
- for each medium in ("anime", "manga") and each target in ("item", "rating"), the
  `_i`, `_j`, `_v` and `_size` datasets produced by `record_sparse_array!` under the
  `labels_<medium>_<target>` and `weights_<medium>_<target>` prefixes.
"""
function save_features(sentences, labels, weights, users, config, training, filename)
    features = Any[
        featurize(sentences, labels, weights, medium, u, config, training) for u in users
    ]

    d = Dict{String,AbstractArray}()
    collate = Flux.batch
    embed_names = [
        "anime",
        "manga",
        "rating",
        "timestamp",
        "status",
        "completion",
        "user",
        "position",
    ]
    for (slot, name) in enumerate(embed_names)
        d[name] = collate([f[1][slot] for f in features])
    end
    d["positions"] = collate([f[2] for f in features])

    # Labels (tuple slot 3) and weights (tuple slot 4) share the same layout.
    for medium in ["anime", "manga"], task in ["item", "rating"]
        target = Symbol(task)
        record_sparse_array!(
            d,
            "labels_$(medium)_$(task)",
            collate([f[3][medium][target] for f in features]),
            extract(config["vocab_sizes"], Symbol(medium)),
        )
        record_sparse_array!(
            d,
            "weights_$(medium)_$(task)",
            collate([f[4][medium][target] for f in features]),
            extract(config["vocab_sizes"], Symbol(medium)),
        )
    end

    HDF5.h5open(filename, "w") do file
        for (key, array) in d
            write(file, key, array)
        end
    end
end

"""
    record_sparse_array!(d::Dict, name::String, x::AbstractSparseArray, vocab_size)

Store the sparse matrix `x` in `d` in coordinate form.

Four entries are written, all prefixed with `name`: `_i` (row indices), `_j`
(column indices), `_v` (stored values) and `_size`, which holds
`[vocab_size, size(x, 2)]`. Note that the recorded row count is the supplied
`vocab_size`, not `size(x, 1)`.
"""
function record_sparse_array!(d::Dict, name::String, x::AbstractSparseArray, vocab_size)
    # `findnz` is imported by name: the file's selective SparseArrays import does
    # not bring the module name itself into scope.
    rows, cols, vals = findnz(x)
    d[name*"_i"] = rows
    d[name*"_j"] = cols
    d[name*"_v"] = vals
    d[name*"_size"] = [vocab_size, size(x, 2)]
end;

In [None]:
"""
    spawn_feature_workers(sentences, labels, weights, users, config, rng, training, outdir)

Run background feature writers until `outdir` disappears.

`users` is split into at most `config["num_workers"]` contiguous batches and one
task is spawned per batch. Each task repeatedly shuffles its batch in place (this
reorders the caller's `users`) and walks through it in chunks of
`config["chunk_size"]`. For every chunk it:

1. waits (polling once per second) until its output file
   `<stem>.<worker>.h5` in `outdir` no longer exists, where `<stem>` is
   "training" or "validation" depending on `training`;
2. writes the chunk with `save_features`;
3. writes the chunk index to `<stem>.<worker>.h5.complete`.

Whoever reads the files is expected to delete them to request the next chunk.
Removing `outdir` makes every task finish, after which this function returns.
Each task uses its own RNG, seeded from `rng` before any task starts.
"""
function spawn_feature_workers(
    sentences,
    labels,
    weights,
    users,
    config,
    rng,
    training,
    outdir,
)
    chunk_size = config["chunk_size"]
    workers = config["num_workers"]
    stem = training ? "training" : "validation"
    rngs = [Random.Xoshiro(rand(rng, UInt64)) for _ = 1:workers]
    # A partition length of 0 (empty `users`) would throw; with at least 1 an
    # empty `users` simply spawns no tasks.
    batch_size = max(1, div(length(users), workers, RoundUp))
    @sync for (i, batch) in enumerate(Iterators.partition(users, batch_size))
        Threads.@spawn begin
            # Must be a fresh local: assigning to `rng` here would rebind the
            # enclosing function's variable, which every task shares, so tasks
            # could end up using each other's generator concurrently.
            local worker_rng = rngs[i]
            local filename = joinpath(outdir, "$stem.$i.h5")
            # Stop once `outdir` is gone instead of spinning forever.
            while isdir(outdir)
                Random.shuffle!(worker_rng, batch)
                for (j, chunk) in enumerate(Iterators.partition(batch, chunk_size))
                    GC.gc()
                    # Wait for the previous file to be deleted by its reader.
                    while isfile(filename) && isdir(outdir)
                        sleep(1)
                    end
                    if !isdir(outdir)
                        break
                    end
                    save_features(
                        sentences,
                        labels,
                        weights,
                        chunk,
                        config,
                        training,
                        filename,
                    )
                    # Marker file signals that the data file is fully written.
                    open("$filename.complete", "w") do f
                        write(f, "$j")
                    end
                end
            end
        end
    end
end;

# State

In [None]:
# Notebook state. `config_checkpoint`, `config_epoch` and `reset_lr_schedule` are not
# read anywhere in this notebook as shown; presumably they are consumed by included
# code — TODO confirm.
config_checkpoint = nothing
config_epoch = nothing
reset_lr_schedule = true
# Fixed seed: `set_rngs` also reseeds the default RNG from this generator.
rng = set_rngs(20221221)
# Start from the pretrained model's config, switched to "finetune" mode.
config = create_training_config(pretrain_name);

In [None]:
# Load everything the feature workers need. Note that `training_users` are the
# users of the "validation" split and `test_users` those of the "test" split
# (see `get_users`); the latter are used as the validation set for epoch sizing.
@info "loading data"
sentences = get_sentences(config, task)
labels = get_labels(task)
weights = get_weights(task)
training_users, test_users = get_users(task)
set_epoch_size!(config, training_users, test_users);

In [None]:
# Directory the feature workers write into. `setup_training` removes any regular
# files already in it and writes `config` (including the worker count set here) to
# config.json in its parent directory.
outdir = get_data_path(joinpath("alphas", name, "training"))
config["num_workers"] = 4
setup_training(config, outdir);

In [None]:
# Start the training-set feature writers in the background (training = true, so
# files are named with the "training" stem).
# NOTE(review): the same `rng` object is also handed to the validation workers
# spawned in the next cell; both background tasks draw seeds from it — confirm the
# cells are never run close enough together for those draws to overlap.
Threads.@spawn spawn_feature_workers(
    sentences,
    labels,
    weights,
    training_users,
    config,
    rng,
    true,
    outdir,
);

In [None]:
# Start the validation-set feature writers in the background (training = false, so
# files are named with the "validation" stem); they are fed the test-split users.
# NOTE(review): this passes the same `rng` object as the training workers spawned
# in the previous cell — confirm the two tasks never draw from it concurrently.
Threads.@spawn spawn_feature_workers(
    sentences,
    labels,
    weights,
    test_users,
    config,
    rng,
    false,
    outdir,
);

In [None]:
# run(`python3 Pytorch.py --outdir $name --model_checkpoint $pretrain_name`)

In [None]:
# rm(outdir, recursive=true)