# Pretraining
* Trains a BERT-style transformer model on watch histories
* A watch history is a time-sorted sequence of interactions, where each interaction consists of a user $u$, an item $i$, a timestamp $t$, and optional metadata
* The model predicts 1) the probability that $u$ will watch $i$ at $t$, and 2) the rating that $u$ will give to $i$ at $t$

In [None]:
# Model name: used to build the shard directory `alphas/<name>/training` and passed
# to Pytorch.py as `--outdir`.
name = "all/Transformer/v2";

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb")
@nbinclude("Data.ipynb");

In [None]:
import HDF5
import JSON
import MLUtils
import Random
import StatsBase: mean, sample
import ThreadPinning

In [None]:
# Pin Julia's threads to CPU cores; the feature workers spawned below run as tasks on them.
ThreadPinning.pinthreads(:cores)

# Featurization

In [None]:
"""
    featurize(sentence, config, rng, training)

Convenience method: unpack the tokenization settings from a training `config`
(as built by `create_training_config`) and forward them to the keyword method.
Note that the keyword `vocab_sizes` receives the *base* vocabulary sizes, i.e.
without the special-token offsets.
"""
function featurize(sentence::Vector{wordtype}, config, rng, training::Bool)
    return featurize(;
        sentence,
        max_seq_len = config["max_sequence_length"],
        vocab_sizes = config["base_vocab_sizes"],
        pad_tokens = config["pad_tokens"],
        cls_tokens = config["cls_tokens"],
        mask_tokens = config["mask_tokens"],
        empty_tokens = config["empty_tokens"],
        causal = config["causal"],
        rng,
        training,
    )
end;

In [None]:
# Keyword-argument method that does the actual featurization of one watch history.
#
# Returns `(tokens, positions, labels, weights)`:
# * `tokens`: the per-field token vectors produced by `get_token_ids` (one vector per
#   field, presumably each of length `max_seq_len`); in non-causal mode some entries
#   are overwritten in place by the masking below.
# * `positions`, `labels`, `weights`: nested NamedTuples indexed as
#   `x[medium][task]` with medium in (:anime, :manga) and task in (:item, :rating);
#   each leaf is a length-`max_seq_len` vector that is zero except at slots that
#   carry a prediction target. At such a slot `positions` holds the item token to
#   score, `labels` holds 1 (item task) or the rating token (rating task), and
#   `weights` holds 1 (rescaled by `weight_by_user!` when `training` is false).
#
# Modes:
# * `causal = true`: next-token targets. The target read from token `i` is stored
#   at slot `i-1`, and only when tokens `i-1` and `i` have the same user token. No
#   input tokens are masked.
# * `causal = false`: BERT-style targets. Each slot is drawn for masking with
#   probability 0.15; a drawn slot that has an item and/or a rating stores its
#   target(s) at the same slot and has its input tokens corrupted.
#
# `empty_tokens` is accepted but not used in this body.
#
# NOTE(review): `medium` is only assigned when `has_anime` or `has_manga` holds.
# A token with `has_rating` but neither (e.g. a plan-to-watch entry whose rating
# token is below the rating vocab size) would reach `positions[medium]` with
# `medium` unassigned and throw an UndefVarError; confirm such tokens cannot occur.
function featurize(;
    sentence::Vector{wordtype},
    max_seq_len,
    vocab_sizes,
    pad_tokens,
    cls_tokens,
    mask_tokens,
    empty_tokens,
    causal::Bool,
    rng,
    training,
)
    # always featurize at the full maximum length; shorter histories are presumably
    # padded with `pad_tokens` by `get_token_ids`
    seq_len = max_seq_len
    # NOTE(review): presumably trims the history to at most `seq_len` tokens; see subset_sentence
    sentence = subset_sentence(sentence, seq_len; recent = false, rng = rng)

    # tokenize the (single) sentence; `vec.` flattens each per-field array into a vector
    tokens =
        vec.(
            get_token_ids(
                [sentence],
                seq_len,
                extract(vocab_sizes, :position),
                pad_tokens,
                cls_tokens,
            ),
        )

    positions = (
        anime = (item = zeros(Int32, seq_len), rating = zeros(Int32, seq_len)),
        manga = (item = zeros(Int32, seq_len), rating = zeros(Int32, seq_len)),
    )
    weights = (
        anime = (item = zeros(Float32, seq_len), rating = zeros(Float32, seq_len)),
        manga = (item = zeros(Float32, seq_len), rating = zeros(Float32, seq_len)),
    )
    labels = (
        anime = (item = zeros(Float32, seq_len), rating = zeros(Float32, seq_len)),
        manga = (item = zeros(Float32, seq_len), rating = zeros(Float32, seq_len)),
    )
    # user token of each recorded target; only consumed by weight_by_user! below
    userids = (
        anime = (item = zeros(Int32, seq_len), rating = zeros(Int32, seq_len)),
        manga = (item = zeros(Int32, seq_len), rating = zeros(Int32, seq_len)),
    )
    for i::Int32 = 1:seq_len
        # classify slot i; plan-to-watch entries never produce item targets, and only
        # rating tokens below the rating vocab size count as rating targets
        has_anime =
            (extract(tokens, :anime)[i] <= extract(vocab_sizes, :anime)) &&
            (extract(tokens, :status)[i] != get_status(:plan_to_watch))
        has_manga =
            (extract(tokens, :manga)[i] <= extract(vocab_sizes, :manga)) &&
            (extract(tokens, :status)[i] != get_status(:plan_to_watch))
        has_rating = extract(tokens, :rating)[i] < extract(vocab_sizes, :rating)
        # medium of this slot's item (anime takes precedence)
        if has_anime
            medium = :anime
        elseif has_manga
            medium = :manga
        end

        if causal
            # the first slot, and the first slot of each new user, has no preceding
            # slot of the same user to be predicted from
            if i == 1 || (extract(tokens, :user)[i] != extract(tokens, :user)[i-1])
                continue
            end
            # store token i's targets at slot i-1 (predict the next interaction)
            if has_anime || has_manga
                positions[medium][:item][i-1] = extract(tokens, medium)[i]
                labels[medium][:item][i-1] = 1
                weights[medium][:item][i-1] = 1
                userids[medium][:item][i-1] = extract(tokens, :user)[i]
            end
            if has_rating
                positions[medium][:rating][i-1] = extract(tokens, medium)[i]
                labels[medium][:rating][i-1] = extract(tokens, :rating)[i]
                weights[medium][:rating][i-1] = 1
                userids[medium][:rating][i-1] = extract(tokens, :user)[i]
            end
            # causal mode never masks input tokens
            continue
        end

        # each slot is drawn for masking independently with probability 0.15
        should_mask = rand(rng) < 0.15

        # record targets before we mask the tokens out; skip slots that were not
        # drawn or that have neither an item nor a rating
        if !(should_mask && (has_anime || has_manga || has_rating))
            continue
        end
        if has_anime || has_manga
            positions[medium][:item][i] = extract(tokens, medium)[i]
            labels[medium][:item][i] = 1
            weights[medium][:item][i] = 1
            userids[medium][:item][i] = extract(tokens, :user)[i]
        end
        if has_rating
            positions[medium][:rating][i] = extract(tokens, medium)[i]
            labels[medium][:rating][i] = extract(tokens, :rating)[i]
            weights[medium][:rating][i] = 1
            userids[medium][:rating][i] = extract(tokens, :user)[i]
        end

        # bert masking: fields in `item_allowed_info` are corrupted probabilistically
        # (second loop); the anime, manga and user fields in `item_skip_info` are left
        # untouched, except the current medium, which is also in `item_allowed_info`;
        # every remaining field is always replaced by its mask token (first loop)
        item_allowed_info = get_wordtype_index.([medium, :rating, :timestamp, :position])
        item_skip_info = get_wordtype_index.([:anime, :manga, :user])
        for j = 1:length(tokens)
            if j in item_allowed_info || j in item_skip_info
                continue
            end
            tokens[j][i] = mask_tokens[j]
        end
        for j in item_allowed_info
            # r <= cutoffs[1]: mask; cutoffs[1] < r <= cutoffs[2]: keep; else random token.
            # item/rating: 80/10/10, timestamp/position: 45/45/10. Outside training r is
            # fixed, so item, rating and timestamp are always masked and position is kept.
            if j in get_wordtype_index.([medium, :rating])
                cutoffs = (0.8, 0.9)
                r = training ? rand(rng) : 0.0
            elseif j == get_wordtype_index(:timestamp)
                cutoffs = (0.45, 0.9)
                r = training ? rand(rng) : 0.0
            elseif j == get_wordtype_index(:position)
                cutoffs = (0.45, 0.9)
                r = training ? rand(rng) : 0.7
            else
                @assert false
            end
            if r <= cutoffs[1]
                tokens[j][i] = mask_tokens[j]
            elseif r <= cutoffs[2]
                nothing
            else
                # NOTE(review): the first test inspects `vocab_sizes[j]` but the second
                # inspects `tokens[j]`; presumably both mean "integer-valued field" vs
                # "float-valued field" — confirm
                if eltype(vocab_sizes[j]) == Int32
                    tokens[j][i] = rand(rng, 1:vocab_sizes[j])
                elseif eltype(tokens[j]) == Float32
                    tokens[j][i] = rand(rng) * vocab_sizes[j]
                else
                    @assert false
                end
            end
        end
    end

    # evaluation only: divide each nonzero weight by the number of slots sharing its
    # user id, so each user contributes total weight 1 per medium/task (assuming real
    # user ids are nonzero; slots without a target keep user id 0)
    if !training
        for x in [:anime, :manga]
            for y in [:item, :rating]
                weight_by_user!(weights[x][y], userids[x][y])
            end
        end
    end

    tokens, positions, labels, weights
end;

In [None]:
"""
    weight_by_user!(weights, userids)

In place, divide every nonzero entry of `weights` by the number of times the
corresponding entry of `userids` occurs in `userids` (zero-weight slots are
counted too). Zero weights are left unchanged. Returns `weights`.
"""
function weight_by_user!(weights, userids)
    # occurrences of each user id across the whole vector
    occurrences = Dict{eltype(userids),Int}()
    for uid in userids
        occurrences[uid] = get(occurrences, uid, 0) + 1
    end
    for (k, uid) in enumerate(userids)
        iszero(weights[k]) && continue
        weights[k] /= occurrences[uid]
    end
    return weights
end;

# Data collection

In [None]:
"""
    shuffle_training_data(rng, sentences, max_sequence_length, max_document_length)

Visit `sentences` in a random order (drawn from `rng`), pass each one through
`subset_sentence` with `max_document_length` (presumably keeping at most that many
tokens), and repack the concatenated token stream into chunks of
`max_sequence_length` tokens.

Every returned chunk has exactly `max_sequence_length` tokens except possibly the
last, which holds the remainder. Chunks may span sentence boundaries, because the
buffer is not reset between sentences. Returns an empty vector when `sentences` is
empty.
"""
function shuffle_training_data(rng, sentences, max_sequence_length, max_document_length)
    order = Random.shuffle(rng, 1:length(sentences))
    S = eltype(sentences)
    # token type; fall back to the container's element type so that an empty input
    # returns an empty result instead of throwing a BoundsError on `sentences[1]`
    W = isempty(sentences) ? eltype(S) : eltype(first(sentences))

    # partition tokens into minibatches
    batched_sentences = Vector{S}()
    sentence = Vector{W}()
    for i in order
        subset =
            subset_sentence(sentences[i], max_document_length; recent = false, rng = rng)
        for token in subset
            push!(sentence, token)
            if length(sentence) == max_sequence_length
                push!(batched_sentences, sentence)
                sentence = Vector{W}()
            end
        end
    end
    # keep the trailing partial chunk
    if length(sentence) > 0
        push!(batched_sentences, sentence)
    end
    batched_sentences
end;

In [None]:
"""
    get_training_data(media, include_ptw, cls_tokens, empty_tokens, causal)

Collect the sentences of every task in `ALL_TASKS` (via the per-task
`get_training_data` method) into one flat vector, in task order.
"""
function get_training_data(media, include_ptw, cls_tokens, empty_tokens, causal::Bool)
    # one entry per task, each holding that task's sentences
    per_task = Vector{Vector{Vector{wordtype}}}(undef, length(ALL_TASKS))
    for (k, task) in enumerate(ALL_TASKS)
        task_data =
            get_training_data(task, media, include_ptw, cls_tokens, empty_tokens, causal)
        per_task[k] = collect(values(task_data))
    end
    return vcat(per_task...)
end;

In [None]:
"""
    get_sentences(rng, training_config)

Load all sentences selected by `training_config`, shuffle them in place with `rng`,
and split them into `(training, validation)` with a 99% / 1% ratio.
"""
function get_sentences(rng, training_config)
    cfg = training_config
    pool = get_training_data(
        cfg["media"],
        cfg["include_ptw_impressions"],
        cfg["cls_tokens"],
        cfg["empty_tokens"],
        cfg["causal"],
    )
    Random.shuffle!(rng, pool)
    n_train = round(Int, 0.99 * length(pool))
    return pool[begin:n_train], pool[n_train+1:end]
end;

# Configuration

In [None]:
"""
    set_rngs(seed)

Create a `Xoshiro` generator from `seed`, reseed Julia's global RNG with the
generator's first `UInt64` draw, and return the generator.
"""
function set_rngs(seed)
    local_rng = Random.Xoshiro(seed)
    global_seed = rand(local_rng, UInt64)
    Random.seed!(global_seed)
    return local_rng
end;

In [None]:
"""
    create_training_config()

Build the `Dict` of featurization and model settings for this run. The notebook
writes it to `config.json` via `setup_training`, presumably for Pytorch.py.

Special tokens sit directly above each field's base vocabulary:
cls = base + 1, pad = base + 2, mask = base + 3, empty = base + 4, and
`vocab_sizes` = base + 4.
"""
function create_training_config()
    # one base vocabulary size per input field; presumably in the same field order as
    # `embed_names` in `save_features` (anime, manga, rating, timestamp, status,
    # completion, user, position)
    vocab = (
        Int32(num_items("anime")),
        Int32(num_items("manga")),
        Float32(11),
        Float32(1),
        Int32(5),
        Float32(1),
        Int32(num_users()),
        Int32(1024),
    )
    offset_by(k) = vocab .+ Int32(k)
    config = Dict(
        # tokenization
        "base_vocab_sizes" => vocab,
        "cls_tokens" => offset_by(1),
        "pad_tokens" => offset_by(2),
        "mask_tokens" => offset_by(3),
        "empty_tokens" => offset_by(4),
        "vocab_sizes" => offset_by(4),
        # data
        "max_document_length" => Inf,
        "include_ptw_impressions" => true,
        "media" => ["anime", "manga"],
        "chunk_size" => 2^14,
        "num_training_shards" => 24,
        "num_validation_shards" => 8,
        "num_dataloader_workers" => 1,
        # model: the sequence length equals the size of the position vocabulary
        "max_sequence_length" => extract(vocab, :position),
        "mode" => "pretrain",
        "causal" => true,
    )
    @assert config["max_document_length"] >= config["max_sequence_length"]
    return config
end;

In [None]:
"""
    set_epoch_size!(config, training_sentences, validation_sentences)

Log dataset sizes and record epoch sizes in `config`.

Sets `config["training_epoch_size"]` to the number of `max_sequence_length`-token
chunks needed to hold every training token (each sentence capped at
`max_document_length` tokens, matching `shuffle_training_data`). Sets
`config["validation_epoch_size"]` to the number of validation sentences. Both are
stored as integers. Returns the validation epoch size.
"""
function set_epoch_size!(config, training_sentences, validation_sentences)
    num_tokens = sum(min.(length.(training_sentences), config["max_document_length"]))
    @info "Number of training tokens: $(num_tokens)"
    @info "Number of training sentences: $(length(training_sentences))"
    @info "Number of validation sentences: $(length(validation_sentences))"
    # `max_document_length` may be `Inf`, which makes `num_tokens` a Float64; convert so
    # the epoch size is an integer count (serialized as `1234`, not `1234.0`)
    config["training_epoch_size"] = Int(cld(num_tokens, config["max_sequence_length"]))
    config["validation_epoch_size"] = length(validation_sentences)
end;

In [None]:
# Location for this run's scratch files; currently just the data path for `x`.
get_temp_path(x) = get_data_path(x);

In [None]:
"""
    setup_training(config, outdir)

Prepare `outdir` for a new run and write `config` as JSON to `config.json` in the
parent directory of `outdir`.

`outdir` is created if needed, and every regular file directly inside it is
deleted (subdirectories are left alone). The JSON is first written to
`config.json~` and then moved over `config.json`. Presumably this is so readers
never see a partially written file. Returns the path of `config.json`.
"""
function setup_training(config, outdir)
    isdir(outdir) || mkpath(outdir)
    foreach(rm, filter(isfile, readdir(outdir; join = true)))
    dest = joinpath(outdir, "..", "config.json")
    staging = dest * "~"
    open(staging, "w") do io
        write(io, JSON.json(config))
    end
    return mv(staging, dest; force = true)
end;

# Disk I/O

In [None]:
"""
    save_features(sentences, config, rng, training, filename)

Featurize `sentences` and write the result to the HDF5 file `filename`, opened in
`"w"` mode.

When `training` is true, the sentences are first reshuffled and repacked into
`max_sequence_length`-token chunks by `shuffle_training_data`. Each token field and
each `positions_*`, `labels_*` and `weights_*` target (anime/manga × item/rating) is
stacked across sequences with `MLUtils.batch` and written as its own dataset. A
one-element `causal` flag is written as well.
"""
function save_features(sentences, config, rng, training, filename)
    inputs = if training
        shuffle_training_data(
            rng,
            sentences,
            config["max_sequence_length"],
            config["max_document_length"],
        )
    else
        sentences
    end
    features = Any[featurize(s, config, rng, training) for s in inputs]

    collate = MLUtils.batch
    datasets = Dict{String,AbstractArray}("causal" => [config["causal"]])
    # i-th name labels the i-th token field returned by `featurize`
    embed_names = (
        "anime",
        "manga",
        "rating",
        "timestamp",
        "status",
        "completion",
        "user",
        "position",
    )
    for (k, field) in enumerate(embed_names)
        datasets[field] = collate([f[1][k] for f in features])
    end
    # (dataset prefix, slot in the `featurize` result tuple)
    for (prefix, slot) in (("positions", 2), ("labels", 3), ("weights", 4))
        for medium in ("anime", "manga"), task in ("item", "rating")
            datasets["$(prefix)_$(medium)_$(task)"] =
                collate([f[slot][Symbol(medium)][Symbol(task)] for f in features])
        end
    end

    HDF5.h5open(filename, "w") do file
        for (key, value) in datasets
            write(file, key, value)
        end
    end
end;

In [None]:
"""
    advance!(filename)

Check whether every consumer has acknowledged the shard `filename`, and if so
delete it.

Acknowledgements are files in the same directory whose names contain
`basename(filename) * ".read"`. The last two dot-separated fields of each name are
parsed as `world_size` and `num_workers`, and all acknowledgements must agree on
them. Once the number of acknowledgements equals `world_size * num_workers`, the
shard, its `.complete` sentinel and every acknowledgement file are removed, and
`true` is returned (the next shard may be written). Otherwise, including when no
acknowledgement exists yet, returns `false`.
"""
function advance!(filename)
    dir = dirname(filename)
    marker = basename(filename) * ".read"
    acks = filter(f -> occursin(marker, basename(f)), readdir(dir))
    isempty(acks) && return false
    layouts = Set(split(f, ".")[end-1:end] for f in acks)
    @assert length(layouts) == 1
    world_size, num_workers = parse.(Int, first(layouts))
    done = length(acks) == world_size * num_workers
    if done
        rm("$filename.complete")
        rm(filename)
        foreach(f -> rm(joinpath(dir, f)), acks)
    end
    return done
end;

In [None]:
"""
    spawn_feature_workers(sentences, config, rng, training, outdir)

Start background tasks that keep featurizing `sentences` into HDF5 shards in
`outdir`.

`sentences` is split into at most `config["num_training_shards"]` (or
`config["num_validation_shards"]` when `training` is false) contiguous parts, one
task per part. Task `i` uses its own `Xoshiro` generator seeded from `rng`,
reshuffles its part, and for each chunk of `config["chunk_size"]` sentences:

1. writes the features to `<stem>.<i>.h5` in `outdir` (`stem` is `"training"` or
   `"validation"`), then writes the chunk's 1-based index to
   `<stem>.<i>.h5.complete`;
2. polls once per second until `advance!` reports that the shard was consumed
   (which deletes it), then continues with the next chunk.

A task stops as soon as `outdir` no longer exists, so deleting `outdir` is the stop
signal. The tasks are not returned, so an exception raised inside one is not
surfaced to the caller.

`Iterators.partition` yields views for arrays, so each task shuffles its slice of
`sentences` in place: the caller's vector gets reordered.

Throws an `ArgumentError` if `sentences` is empty.
"""
function spawn_feature_workers(sentences, config, rng, training, outdir)
    chunk_size = config["chunk_size"]
    workers = training ? config["num_training_shards"] : config["num_validation_shards"]
    stem = training ? "training" : "validation"
    # draw every per-task seed up front, sequentially, from the caller's rng
    rngs = [Random.Xoshiro(rand(rng, UInt64)) for _ = 1:workers]
    isempty(sentences) &&
        throw(ArgumentError("spawn_feature_workers: no $stem sentences to featurize"))
    part_size = div(length(sentences), workers, RoundUp)
    for (i, batch) in Iterators.enumerate(Iterators.partition(sentences, part_size))
        Threads.@spawn begin
            # a fresh name is essential: assigning to `rng` here would rebind the
            # captured `rng` argument, which is shared by every spawned task
            local worker_rng = rngs[i]
            while isdir(outdir)
                Random.shuffle!(worker_rng, batch)
                for (j, chunk) in
                    Iterators.enumerate(Iterators.partition(batch, chunk_size))
                    filename = joinpath(outdir, "$stem.$i.h5")
                    save_features(chunk, config, worker_rng, training, filename)
                    open("$filename.complete", "w") do f
                        write(f, "$j")
                    end
                    if i == 1
                        GC.gc()
                    end
                    while isdir(outdir) && !advance!(filename)
                        sleep(1)
                    end
                    # outdir was removed: stop this task for good (a plain `break`
                    # would only restart the outer loop)
                    isdir(outdir) || return
                end
            end
        end
    end
end;

# State

In [None]:
# NOTE(review): `config_checkpoint` and `config_epoch` are not referenced in the visible
# code of this notebook; presumably read elsewhere (e.g. by an included notebook) — verify.
config_checkpoint = nothing
config_epoch = nothing
# reseeds the global RNG and returns the RNG used for all shuffling/featurization below
rng = set_rngs(20221221)
config = create_training_config();

In [None]:
@info "loading data"
# shuffled 99% / 1% train/validation split; epoch sizes are then recorded in `config`
training_sentences, validation_sentences = get_sentences(rng, config)
set_epoch_size!(config, training_sentences, validation_sentences);

In [None]:
# shard directory; setup_training deletes leftover files in it and writes config.json
# one level up
outdir = get_temp_path(joinpath("alphas", name, "training"))
setup_training(config, outdir);

In [None]:
# background tasks that keep writing training shards into `outdir`
spawn_feature_workers(training_sentences, config, rng, true, outdir);

In [None]:
# background tasks that keep writing validation shards into `outdir`
spawn_feature_workers(validation_sentences, config, rng, false, outdir);

In [None]:
# train in PyTorch; `run` blocks until the Python process exits
run(`python3 Pytorch.py --outdir $name --epochs 64`)

In [None]:
# delete the shard directory; the feature workers poll `isdir(outdir)` and treat its
# disappearance as the signal to stop
rm(outdir, recursive = true)