In [None]:
# Identifier of this model; it is used further down to build the output directory
# and is passed to the training notebook as its `name` parameter.
name = "all/Transformer/v0";

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb")
@nbinclude("Reference/Data.ipynb");

In [None]:
import Flux
import HDF5
import JSON
import Random
import StatsBase: mean, sample

# Featurization

In [None]:
"""
    featurize(sentence, config, rng, training)

Convenience method: unpack the tokenization settings stored in `config` and forward
them to the keyword-only `featurize` method.

For every medium in `config["explicit_baseline"]` one entry is drawn with
`rand(rng, ...)` from the candidates stored under that medium, and the resulting
dictionary is passed on as `explicit_baseline`.
"""
function featurize(sentence::Vector{wordtype}, config, rng, training::Bool)
    sampled_baseline = Dict(
        medium => rand(rng, candidates) for
        (medium, candidates) in config["explicit_baseline"]
    )
    return featurize(;
        sentence = sentence,
        max_seq_len = config["max_sequence_length"],
        vocab_sizes = config["base_vocab_sizes"],
        pad_tokens = config["pad_tokens"],
        cls_tokens = config["cls_tokens"],
        mask_tokens = config["mask_tokens"],
        empty_tokens = config["empty_tokens"],
        explicit_baseline = sampled_baseline,
        rng = rng,
        training = training,
    )
end;

In [None]:
# Build one masked-token example from a single sentence.
#
# Keyword arguments:
#   sentence               - the words to featurize
#   max_seq_len            - length of every positions/labels/weights vector
#   vocab_sizes            - per-wordtype vocabulary sizes (read via `extract` and by index)
#   pad_tokens, cls_tokens - forwarded to `get_token_ids`
#   mask_tokens            - per-wordtype value written into masked positions
#   empty_tokens           - forwarded to `demean_explicit_ratings!`
#   explicit_baseline      - per-medium baseline parameters; the keys "weights", "λ"
#                            and "a" are read here
#   rng                    - random number generator used for subsetting, mask
#                            selection and token corruption
#   training               - when false the corruption draw is replaced by fixed values
#                            and the weights are normalised per user
#
# Returns the tuple `(tokens, positions, labels, weights)`. `positions`, `labels` and
# `weights` are named tuples indexed by medium (`anime`/`manga`) and then by task
# (`item`/`rating`); each holds a vector of length `seq_len` that stays zero except at
# positions selected for masking.
function featurize(;
    sentence::Vector{wordtype},
    max_seq_len,
    vocab_sizes,
    pad_tokens,
    cls_tokens,
    mask_tokens,
    empty_tokens,
    explicit_baseline,
    rng,
    training,
)
    # Every example uses the maximum sequence length; `subset_sentence` (defined
    # elsewhere) presumably trims the sentence to at most that many words — confirm.
    seq_len = max_seq_len
    sentence =
        subset_sentence(sentence, seq_len; recent = false, keep_first = false, rng = rng)

    # Tokenize the sentence as a batch of one and flatten each wordtype's ids to a vector.
    tokens =
        vec.(
            get_token_ids(
                [sentence],
                seq_len,
                extract(vocab_sizes, :position),
                pad_tokens,
                cls_tokens,
            ),
        )

    # Per-medium, per-user running sums used later to demean ratings.
    # NOTE(review): `demean` and `demean_item_weights` are only defined inside this
    # guard, but both are used unconditionally further down, so passing
    # `explicit_baseline = nothing` would fail — confirm that case is never intended.
    if !isnothing(explicit_baseline)
        demean = (
            anime = (
                rating = Dict{Int32,Float32}(),
                count = Dict{Int32,Int32}(),
                weight = Dict{Int32,Float32}(),
            ),
            manga = (
                rating = Dict{Int32,Float32}(),
                count = Dict{Int32,Int32}(),
                weight = Dict{Int32,Float32}(),
            ),
        )
        demean_item_weights = (
            anime = explicit_baseline[:anime]["weights"],
            manga = explicit_baseline[:manga]["weights"],
        )
    end

    # Prediction targets, all zero-initialised. For a masked position:
    #   positions - the item id that was at that position
    #   labels    - 1 for the item task, the rating value for the rating task
    #   weights   - 1
    #   userids   - the user id at that position (internal bookkeeping, not returned)
    positions = (
        anime = (item = zeros(Int32, seq_len), rating = zeros(Int32, seq_len)),
        manga = (item = zeros(Int32, seq_len), rating = zeros(Int32, seq_len)),
    )
    weights = (
        anime = (item = zeros(Float32, seq_len), rating = zeros(Float32, seq_len)),
        manga = (item = zeros(Float32, seq_len), rating = zeros(Float32, seq_len)),
    )
    labels = (
        anime = (item = zeros(Float32, seq_len), rating = zeros(Float32, seq_len)),
        manga = (item = zeros(Float32, seq_len), rating = zeros(Float32, seq_len)),
    )
    userids = (
        anime = (item = zeros(Int32, seq_len), rating = zeros(Int32, seq_len)),
        manga = (item = zeros(Int32, seq_len), rating = zeros(Int32, seq_len)),
    )
    for i::Int32 = 1:seq_len
        # Classify position i. It counts as an anime/manga item when the item token
        # lies within the base vocabulary and the status token is not plan_to_watch;
        # it counts as rated when the rating token is strictly below the rating
        # vocabulary size.
        has_anime =
            (extract(tokens, :anime)[i] <= extract(vocab_sizes, :anime)) &&
            (extract(tokens, :status)[i] != get_status(:plan_to_watch))
        has_manga =
            (extract(tokens, :manga)[i] <= extract(vocab_sizes, :manga)) &&
            (extract(tokens, :status)[i] != get_status(:plan_to_watch))
        has_rating = extract(tokens, :rating)[i] < extract(vocab_sizes, :rating)
        # NOTE(review): `medium` is only assigned when has_anime or has_manga holds,
        # yet it is read below whenever has_rating holds — presumably a rated position
        # always carries an item; confirm.
        if has_anime
            medium = :anime
        elseif has_manga
            medium = :manga
        end
        # Select this position for masking with probability 0.15. The draw is made for
        # every position, including those that carry neither an item nor a rating.
        should_mask = rand(rng) < 0.15

        # Unmasked rated positions feed the per-user running sums that
        # `demean_explicit_ratings!` later turns into a user baseline.
        if !should_mask && has_rating
            u = extract(tokens, :user)[i]
            a = extract(tokens, medium)[i]
            # First time this user is seen for this medium: start the sums at zero.
            if u ∉ keys(demean[medium][:rating])
                demean[medium][:rating][u] = 0
                demean[medium][:count][u] = 0
                demean[medium][:weight][u] = 0
            end
            # Weight = per-item weight times `powerlawdecay` of one minus the cast
            # timestamp, using the fifth λ parameter.
            weight =
                demean_item_weights[medium][a] * powerlawdecay(
                    1 .- cast_universal_timestamp(
                        extract(tokens, :timestamp)[i],
                        String(medium),
                    ),
                    explicit_baseline[medium]["λ"][5],
                )
            # Accumulate the weighted difference between the rating and the item's
            # baseline value, together with the count and the total weight.
            demean[medium][:rating][u] +=
                weight * (extract(tokens, :rating)[i] - explicit_baseline[medium]["a"][a])
            demean[medium][:count][u] += 1
            demean[medium][:weight][u] += weight
        end

        # Nothing more to do unless the position was selected and has something to
        # predict (an item or a rating).
        if !(should_mask && (has_anime || has_manga || has_rating))
            continue
        end
        # Record the original tokens as prediction targets before they are overwritten.
        if has_anime || has_manga
            positions[medium][:item][i] = extract(tokens, medium)[i]
            labels[medium][:item][i] = 1
            weights[medium][:item][i] = 1
            userids[medium][:item][i] = extract(tokens, :user)[i]
        end
        if has_rating
            positions[medium][:rating][i] = extract(tokens, medium)[i]
            labels[medium][:rating][i] = extract(tokens, :rating)[i]
            weights[medium][:rating][i] = 1
            userids[medium][:rating][i] = extract(tokens, :user)[i]
        end

        # bert masking
        # Wordtypes in `item_allowed_info` are corrupted probabilistically in the
        # second loop. The anime, manga and user wordtypes are skipped by the first
        # loop; of those, only the active medium (which is also in the allowed list)
        # is touched afterwards. Every remaining wordtype is always replaced by its
        # mask token.
        item_allowed_info = get_wordtype_index.([medium, :rating, :timestamp, :position])
        item_skip_info = get_wordtype_index.([:anime, :manga, :user])
        for j = 1:length(tokens)
            if j in item_allowed_info || j in item_skip_info
                continue
            end
            tokens[j][i] = mask_tokens[j]
        end
        for j in item_allowed_info
            # cutoffs = (p, q): r <= p masks the token, p < r <= q keeps it unchanged,
            # r > q replaces it with a random value. Outside training r is fixed, so
            # the item, rating and timestamp are always masked (r = 0.0) and the
            # position token is always kept (r = 0.7).
            if j in get_wordtype_index.([medium, :rating])
                cutoffs = (0.8, 0.9)
                r = training ? rand(rng) : 0.0
            elseif j == get_wordtype_index(:timestamp)
                cutoffs = (0.45, 0.9)
                r = training ? rand(rng) : 0.0
            elseif j == get_wordtype_index(:position)
                cutoffs = (0.45, 0.9)
                r = training ? rand(rng) : 0.7
            else
                @assert false
            end
            if r <= cutoffs[1]
                tokens[j][i] = mask_tokens[j]
            elseif r <= cutoffs[2]
                nothing
            else
                # Random replacement: a uniform integer in 1:vocab size for Int32
                # vocabularies, a uniform value in [0, vocab size) for Float32 tokens.
                # NOTE(review): the first branch inspects `vocab_sizes[j]` while the
                # second inspects `tokens[j]` — confirm the asymmetry is intended.
                if eltype(vocab_sizes[j]) == Int32
                    tokens[j][i] = rand(rng, 1:vocab_sizes[j])
                elseif eltype(tokens[j]) == Float32
                    tokens[j][i] = rand(rng) * vocab_sizes[j]
                else
                    @assert false
                end
            end
        end
    end

    # Subtract user and item baselines from the remaining rating tokens and from the
    # recorded rating labels, separately for each medium.
    for medium in [:anime, :manga]
        demean_explicit_ratings!(
            tokens = tokens,
            medium = medium,
            demean = demean[medium],
            explicit_baseline = explicit_baseline[medium],
            vocab_sizes = vocab_sizes,
            cls_tokens = cls_tokens,
            empty_tokens = empty_tokens,
            positions = positions[medium],
            labels = labels[medium],
            userids = userids[medium],
        )
    end

    # Outside training, rescale the weights through `weight_by_user!` using the user
    # ids recorded at masked positions.
    if !training
        for x in [:anime, :manga]
            for y in [:item, :rating]
                weight_by_user!(weights[x][y], userids[x][y])
            end
        end
    end

    tokens, positions, labels, weights
end;

In [None]:
"""
    demean_explicit_ratings!(; tokens, medium, demean, explicit_baseline, vocab_sizes,
                             cls_tokens, empty_tokens, positions, labels, userids)

Subtract a user baseline plus an item baseline, in place, from

* every rating token of `medium` whose value is below the rating vocabulary size and
  whose item token is not the empty token, and
* every entry of `labels[:rating]` whose recorded user id is non-zero.

The user baseline is computed per user from the running sums in `demean`
(`:rating`, `:count`, `:weight`) and is pulled towards the mean of
`explicit_baseline["u"]`; users without sums fall back to that mean. The item baseline is
`explicit_baseline["a"][a]` when `a` is a valid key, otherwise the mean of
`explicit_baseline["a"]`. `cls_tokens` is accepted but not used.
"""
function demean_explicit_ratings!(;
    tokens,
    medium,
    demean,
    explicit_baseline,
    vocab_sizes,
    cls_tokens,
    empty_tokens,
    positions,
    labels,
    userids,
)
    item_baselines = explicit_baseline["a"]
    μ_user = mean(explicit_baseline["u"])
    μ_item = mean(item_baselines)

    # Regularised per-user baseline built from the accumulated rating sums.
    user_to_baseline = Dict{Int32,Float32}()
    for u in keys(demean[:rating])
        user_weight = powerdecay(demean[:count][u], log(explicit_baseline["λ"][3]))
        numerator = demean[:rating][u] * user_weight + μ_user * explicit_baseline["λ"][1]
        denominator = demean[:weight][u] * user_weight + explicit_baseline["λ"][1]
        user_to_baseline[u] = numerator / denominator
    end
    user_bias(u) = get(user_to_baseline, u, μ_user)
    item_bias(a) = a in keys(item_baselines) ? item_baselines[a] : μ_item

    item_tokens = extract(tokens, medium)
    rating_tokens = extract(tokens, :rating)
    user_tokens = extract(tokens, :user)
    empty_item = extract(empty_tokens, medium)
    rating_vocab = extract(vocab_sizes, :rating)

    # Demean the rating tokens that are still present in the input.
    for i = 1:size(item_tokens)[1]
        item_tokens[i] == empty_item && continue
        if rating_tokens[i] < rating_vocab
            rating_tokens[i] -= user_bias(user_tokens[i]) + item_bias(item_tokens[i])
        end
    end

    # Demean the rating labels recorded for masked positions.
    rating_labels = labels[:rating]
    for i = 1:length(rating_labels)
        userids[:rating][i] == 0 && continue
        rating_labels[i] -=
            user_bias(userids[:rating][i]) + item_bias(positions[:rating][i])
    end
    return nothing
end;

In [None]:
"""
    weight_by_user!(weights, userids)

Divide every non-zero entry of `weights`, in place, by the number of times the user id at
the same index occurs in `userids`. Returns `weights`.
"""
function weight_by_user!(weights, userids)
    occurrences = Dict(uid => 0 for uid in userids)
    for uid in userids
        occurrences[uid] += 1
    end
    for idx = 1:length(userids)
        weights[idx] == 0 && continue
        weights[idx] /= occurrences[userids[idx]]
    end
    return weights
end;

# Data collection

In [None]:
"""
    shuffle_training_data(rng, sentences, max_sequence_length, max_document_length)

Visit `sentences` in a random order, pass each one through `subset_sentence` with
`max_document_length`, concatenate the results into a single token stream and cut that
stream into consecutive chunks of `max_sequence_length` tokens. The final chunk may be
shorter.

Returns a `Vector` with the same element type as `sentences`. An empty input yields an
empty result.
"""
function shuffle_training_data(rng, sentences, max_sequence_length, max_document_length)
    S = eltype(sentences)
    # An empty input has no first sentence to inspect, so fall back to the element
    # type declared by the container instead of indexing into it.
    W = isempty(sentences) ? eltype(S) : eltype(sentences[1])
    order = Random.shuffle(rng, 1:length(sentences))

    # concatenate all tokens
    tokens = Vector{W}()
    for i in order
        document = subset_sentence(
            sentences[i],
            max_document_length;
            recent = false,
            keep_first = false,
            rng = rng,
        )
        append!(tokens, document)
    end

    # partition tokens into fixed-size sentences
    batched_sentences = Vector{S}()
    sentence = Vector{W}()
    for token in tokens
        push!(sentence, token)
        if length(sentence) == max_sequence_length
            push!(batched_sentences, sentence)
            sentence = Vector{W}()
        end
    end
    # keep the trailing partial sentence, if any
    if !isempty(sentence)
        push!(batched_sentences, sentence)
    end
    return batched_sentences
end;

In [None]:
"""
    get_training_data(media, include_ptw, cls_tokens, empty_tokens)

Collect the sentences of every task in `ALL_TASKS` by calling the task-specific
`get_training_data` method, keep only the values of each result, and concatenate them
into one vector.
"""
function get_training_data(media, include_ptw, cls_tokens, empty_tokens)
    per_task = Vector{Vector{wordtype}}[
        collect(
            values(
                get_training_data(
                    ALL_TASKS[t],
                    media,
                    include_ptw,
                    cls_tokens,
                    empty_tokens,
                ),
            ),
        ) for t = 1:length(ALL_TASKS)
    ]
    return vcat(per_task...)
end;

In [None]:
"""
    get_sentences(rng, training_config)

Load all sentences described by `training_config`, shuffle them with `rng`, and split
them into a training part (the first 99%, rounded) and a validation part (the rest).

Returns `(training, validation)`.
"""
function get_sentences(rng, training_config)
    corpus = get_training_data(
        training_config["media"],
        training_config["include_ptw_impressions"],
        training_config["cls_tokens"],
        training_config["empty_tokens"],
    )
    Random.shuffle!(rng, corpus)
    n_training = round(Int, 0.99 * length(corpus))
    return corpus[1:n_training], corpus[(n_training + 1):end]
end;

# Configuration

In [None]:
"""
    set_rngs(seed)

Create a `Random.Xoshiro` generator from `seed`, draw one `UInt64` from it to seed the
default random number generator, and return the new generator.
"""
function set_rngs(seed)
    generator = Random.Xoshiro(seed)
    default_seed = rand(generator, UInt64)
    Random.seed!(default_seed)
    return generator
end;

In [None]:
"""
    create_training_config()

Assemble the configuration dictionary for feature generation.

The dictionary holds the base vocabulary sizes, the special-token values derived from
them (cls = base + 1, pad = base + 2, mask = base + 3, empty = base + 4), the data
settings, and, for each medium, one set of explicit-baseline parameters per task in
`ALL_TASKS`. Each parameter set is extended with a `"weights"` entry computed by
`powerdecay` from the per-item explicit counts and the fourth λ parameter.
"""
function create_training_config()
    media = ["anime", "manga"]
    # NOTE(review): the entry order presumably matches the wordtype order used by
    # `extract` — confirm against its definition.
    base_vocab_sizes = (
        Int32(num_items("anime")),
        Int32(num_items("manga")),
        Float32(11),
        Float32(1),
        Int32(5),
        Float32(1),
        Int32(maximum(num_users(x) for x in media)),
        Int32(512), # todo increase
    )
    # Special tokens are placed directly above each base vocabulary.
    shifted(k) = base_vocab_sizes .+ Int32(k)

    config = Dict(
        # tokenization
        "base_vocab_sizes" => base_vocab_sizes,
        "cls_tokens" => shifted(1),
        "pad_tokens" => shifted(2),
        "mask_tokens" => shifted(3),
        "empty_tokens" => shifted(4),
        "vocab_sizes" => shifted(4),
        # data
        "max_document_length" => Inf,
        "include_ptw_impressions" => true,
        "explicit_baseline" => Dict(
            Symbol(x) =>
                [read_params("$x/$t/ExplicitUserItemBiases") for t in ALL_TASKS] for
            x in ["anime", "manga"]
        ),
        "media" => media,
        "chunk_size" => 2^16,
        # model
        "max_sequence_length" => extract(base_vocab_sizes, :position),
    )
    @assert config["max_document_length"] >= config["max_sequence_length"]

    # Attach the per-item weights to every (medium, task) parameter set.
    for (medium, per_task_params) in config["explicit_baseline"]
        for params in per_task_params
            item_counts = get_counts(
                "training",
                "all",
                "explicit",
                String(medium);
                by_item = true,
                per_rating = false,
            )
            params["weights"] = powerdecay(item_counts, log(params["λ"][4]))
            @assert length(params) == 4
            @assert length(params["λ"]) == 5
        end
    end
    return config
end;

In [None]:
"""
    set_epoch_size!(training_config, training_sentences, validation_sentences)

Store corpus statistics in `training_config`:

* `"tokens_per_epoch"`: the sum of the training sentence lengths, each capped at
  `training_config["max_document_length"]`,
* `"num_training_sentences"` and `"num_validation_sentences"`: the sentence counts.

The training sentence and token counts are also logged with `@info`.
"""
function set_epoch_size!(training_config, training_sentences, validation_sentences)
    sentence_lengths = length.(training_sentences)
    capped_lengths = min.(sentence_lengths, training_config["max_document_length"])
    num_tokens = sum(capped_lengths)

    @info "Number of training sentences: $(length(training_sentences))"
    @info "Number of training tokens: $(num_tokens)"

    training_config["tokens_per_epoch"] = Int(num_tokens)
    training_config["num_training_sentences"] = length(training_sentences)
    training_config["num_validation_sentences"] = length(validation_sentences)
end;

In [None]:
"""
    setup_training(config, outdir)

Prepare `outdir` for a fresh run: create it if necessary, delete every regular file
directly inside it (subdirectories are left alone), and write `config` as JSON to
`config.json`.

The JSON is first written to `config.json~` and then moved into place, so `config.json`
only appears once it is complete. Returns the result of the final `mv`.
"""
function setup_training(config, outdir)
    # `mkpath` also creates missing parent directories (the output path is nested)
    # and does nothing when the directory already exists.
    mkpath(outdir)
    for path in readdir(outdir, join = true)
        if isfile(path)
            rm(path)
        end
    end
    fn = joinpath(outdir, "config.json")
    open(fn * "~", "w") do f
        write(f, JSON.json(config))
    end
    return mv(fn * "~", fn)
end;

# Disk I/O

In [None]:
"""
    save_features(sentences, config, rng, training, outfile)

Featurize `sentences` and write the batched result to the HDF5 file `outfile`.

When `training` is true the sentences are first re-chunked by `shuffle_training_data`.
Every sentence is then passed through `featurize`, and the per-sentence outputs are
stacked with `Flux.batch` into one dataset per wordtype plus
`positions_*`, `labels_*` and `weights_*` datasets for each medium/task pair.
"""
function save_features(sentences, config, rng, training, outfile)
    if training
        sentences = shuffle_training_data(
            rng,
            sentences,
            config["max_sequence_length"],
            config["max_document_length"],
        )
    end

    features = [featurize(s, config, rng, training) for s in sentences]

    collate = Flux.batch
    d = Dict{String,AbstractArray}()

    # Slot 1 of each feature tuple holds the tokens, one vector per wordtype.
    embed_names = [
        "anime",
        "manga",
        "rating",
        "timestamp",
        "status",
        "completion",
        "user",
        "position",
    ]
    for (i, field) in Iterators.enumerate(embed_names)
        d[field] = collate([x[1][i] for x in features])
    end

    # Slots 2-4 hold the positions, labels and weights, keyed by medium and task.
    for (slot, prefix) in ((2, "positions"), (3, "labels"), (4, "weights"))
        for medium in ["anime", "manga"], task in ["item", "rating"]
            d["$(prefix)_$(medium)_$(task)"] =
                collate([x[slot][Symbol(medium)][Symbol(task)] for x in features])
        end
    end

    HDF5.h5open(outfile, "w") do file
        for (k, v) in d
            write(file, k, v)
        end
    end
end;

In [None]:
"""
    spawn_feature_workers(sentences, workers, config, rng, training, outdir)

Split `sentences` into `workers` contiguous batches and run one producer task per batch.

Worker `i` repeatedly shuffles its batch and, chunk by chunk (`config["chunk_size"]`
sentences each), writes features to `<outdir>/<stem>.<i>.h5`, where `<stem>` is
`training` or `validation`. It then writes the chunk number to a sibling file with the
extra suffix `.complete`. A worker waits while its output file still exists, and
produces the next chunk only after the file has disappeared.

All workers stop once the file `<outdir>/finished` exists; the function returns
`nothing` after every worker has stopped.
"""
function spawn_feature_workers(sentences, workers, config, rng, training, outdir)
    chunk_size = config["chunk_size"]
    finished = joinpath(outdir, "finished")
    stem = training ? "training" : "validation"
    # One independent generator per worker, all derived from the caller's generator
    # before any task starts.
    rngs = [Random.Xoshiro(rand(rng, UInt64)) for _ = 1:workers]
    @sync for (i, batch) in Iterators.enumerate(
        Iterators.partition(sentences, div(length(sentences), workers, RoundUp)),
    )
        Threads.@spawn begin
            # Use a name that does not exist in the enclosing function: assigning to
            # `rng` here would rebind the single captured argument shared by all
            # tasks, letting workers end up with each other's generator.
            worker_rng = rngs[i]
            fn = joinpath(outdir, "$stem.$i.h5")
            # Re-check the stop file on every pass so that the task actually ends;
            # the `break` below only leaves the inner loop.
            while !isfile(finished)
                Random.shuffle!(worker_rng, batch)
                for (j, chunk) in
                    Iterators.enumerate(Iterators.partition(batch, chunk_size))
                    GC.gc()
                    # Wait until the previous file is gone or we are told to stop.
                    while isfile(fn) && !isfile(finished)
                        sleep(1)
                    end
                    if isfile(finished)
                        break
                    end
                    save_features(chunk, config, worker_rng, training, fn)
                    open("$fn.complete", "w") do f
                        write(f, "$j")
                    end
                end
            end
        end
    end
    return nothing
end;

# State

In [None]:
# NOTE(review): `config_checkpoint`, `config_epoch` and `reset_lr_schedule` are not read
# anywhere in the visible part of this notebook; presumably they are consumed
# elsewhere — confirm.
config_checkpoint = nothing
config_epoch = nothing
reset_lr_schedule = true
# Seeded generator for data shuffling and featurization; `set_rngs` also seeds the
# default generator from it.
rng = set_rngs(20221221)
config = create_training_config()
# Number of feature-producing tasks started per dataset split.
num_workers = 8;

In [None]:
# Load and split the corpus, then record its size statistics in `config`.
@info "loading data"
training_sentences, validation_sentences = get_sentences(rng, config)
set_epoch_size!(config, training_sentences, validation_sentences);

In [None]:
# Output directory for this run. `setup_training` removes the files already in it and
# writes the configuration (including the worker count) to `config.json`.
outdir = get_data_path(joinpath("alphas", name, "training"));
config["num_workers"] = num_workers
setup_training(config, outdir);

In [None]:
# Produce training feature files in the background. `errormonitor` logs the exception
# if the producer fails; without it the task would die silently.
training_feature_task = Threads.@spawn spawn_feature_workers(
    training_sentences,
    num_workers,
    config,
    rng,
    true,
    outdir,
)
errormonitor(training_feature_task);

In [None]:
# Produce validation feature files in the background. `errormonitor` logs the exception
# if the producer fails; without it the task would die silently.
validation_feature_task = Threads.@spawn spawn_feature_workers(
    validation_sentences,
    num_workers,
    config,
    rng,
    false,
    outdir,
)
errormonitor(validation_feature_task);

In [None]:
# Execute the training notebook from the current working directory with papermill,
# writing the executed copy into `outdir` and passing `name` as a notebook parameter.
# `run` blocks until the command exits.
source_nb = "$(pwd())/PretrainPytorch.ipynb"
dest_nb = "$(outdir)/PretrainPytorch.ipynb"
cmd = `papermill $source_nb $dest_nb --no-progress-bar -p name $name`
run(cmd)