# Pretraining
* Trains a BERT-style transformer model on watch histories

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb")
@nbinclude("Data.ipynb");

In [None]:
# Experiment identifier: routes logging output here and is reused below to build the
# on-disk data path for the training shards.
name = "all/Transformer/v1"
set_logging_outdir(name);

In [None]:
import HDF5
import JSON
import MLUtils
import Random
import StatsBase: mean, sample
import ThreadPinning

In [None]:
# Pin Julia threads to CPU cores (ThreadPinning.jl `:cores` strategy) before any
# feature workers are spawned.
ThreadPinning.pinthreads(:cores)

# Featurization

In [None]:
# Config-driven entry point: read the sequence length and the special-token tables from
# a training config (as built by `create_training_config`) and delegate to the
# keyword-argument method.
function featurize(sentence::Vector{wordtype}, config, rng)
    settings = (
        max_seq_len = config["max_sequence_length"],
        vocab_sizes = config["base_vocab_sizes"],
        pad_tokens = config["pad_tokens"],
        cls_tokens = config["cls_tokens"],
        mask_tokens = config["mask_tokens"],
    )
    return featurize(; sentence = sentence, rng = rng, settings...)
end;

In [None]:
# Zero-filled length-`len` vectors of element type `T`, one per (medium, metric) pair.
_zeros_per_target(T, len) =
    Dict(m => Dict(metric => zeros(T, len) for metric in ALL_METRICS) for m in ALL_MEDIUMS)

# Record a `metric` prediction target for item `mediaid` of `medium` at position `i`.
function _set_target!(positions, labels, weights, medium, metric, i, mediaid, label)
    positions[medium][metric][i] = mediaid
    labels[medium][metric][i] = label
    weights[medium][metric][i] = 1
    return nothing
end

# Uniformly random replacement value for an input field whose largest valid value is
# `max_value` (inclusive, matching the `<= vocab_sizes[...]` validity checks below).
_random_token(rng, max_value::Integer) = rand(rng, zero(max_value):max_value)
_random_token(rng, max_value::AbstractFloat) = rand(rng) * max_value

"""
    featurize(; sentence, max_seq_len, vocab_sizes, pad_tokens, cls_tokens, mask_tokens, rng)

Build one BERT-style masked-prediction example from `sentence`.

The sentence is reduced with `subset_sentence(sentence, max_seq_len; recent = false)` and
tokenized with `get_token_ids`. Every position that carries a rating and/or status
signal is selected with probability 0.15. For a selected position:
- the applicable targets (`rating`, `watch`, `plantowatch`, `drop`) are recorded under
  the position's medium (`"manga"` or `"anime"`, decided by its item id);
- its input fields (all except `userid` and `position`) are replaced by `mask_tokens`
  (80%), by uniformly random valid values (10%), or left unchanged (10%).

`vocab_sizes` holds the largest valid value of each input field (inclusive).

Returns `(tokens, positions, labels, weights)`:
- `tokens`: the token fields from `get_token_ids`, masked in place;
- `positions[medium][metric]`: per-position item index within `medium`;
- `labels[medium][metric]`: per-position target value;
- `weights[medium][metric]`: 1 where a target was recorded, 0 elsewhere.
"""
function featurize(;
    sentence::Vector{wordtype},
    max_seq_len,
    vocab_sizes,
    pad_tokens,
    cls_tokens,
    mask_tokens,
    rng,
)
    sentence = subset_sentence(sentence, max_seq_len; recent = false, rng = rng)
    tokens = get_token_ids(sentence, max_seq_len, pad_tokens, cls_tokens)

    positions = _zeros_per_target(Int32, max_seq_len)
    weights = _zeros_per_target(Float32, max_seq_len)
    labels = _zeros_per_target(Float32, max_seq_len)

    # Loop invariants. Masking only overwrites position `i` after that position has been
    # read, so reading the fields up front sees the same values as per-iteration reads.
    itemids = extract(tokens, :itemid)
    ratings = extract(tokens, :rating)
    statuses = extract(tokens, :status)
    num_manga = num_items("manga")
    num_anime = num_items("anime")
    max_rating = vocab_sizes[get_wordtype_index(:rating)]
    max_status = vocab_sizes[get_wordtype_index(:status)]
    status_none = get_status(:none)
    status_plan_to_watch = get_status(:plan_to_watch)
    status_dropped = get_status(:dropped)
    # fields that BERT masking never corrupts
    unmasked_fields = get_wordtype_index.([:userid, :position])

    for i::Int32 = 1:max_seq_len
        itemid = itemids[i]
        status = statuses[i]
        # item ids are laid out as [manga..., anime...]
        is_manga = itemid < num_manga
        is_anime = num_manga <= itemid < num_manga + num_anime
        has_rating = 0 < ratings[i] <= max_rating
        has_watch = status_plan_to_watch < status <= max_status
        has_plantowatch = status == status_plan_to_watch
        has_drop = status_none < status <= max_status
        if !(has_rating || has_watch || has_plantowatch || has_drop)
            continue  # nothing to predict at this position
        end
        @assert xor(is_manga, is_anime)
        medium, mediaid = is_manga ? ("manga", itemid) : ("anime", itemid - num_manga)

        # mask and make predictions for 15% of tokens
        if rand(rng) > 0.15
            continue
        end
        if has_rating
            _set_target!(positions, labels, weights, medium, "rating", i, mediaid,
                ratings[i])
        end
        if has_watch
            _set_target!(positions, labels, weights, medium, "watch", i, mediaid, 1)
        end
        if has_plantowatch
            _set_target!(positions, labels, weights, medium, "plantowatch", i, mediaid, 1)
        end
        if has_drop
            _set_target!(positions, labels, weights, medium, "drop", i, mediaid,
                status <= status_dropped)
        end

        # BERT masking: 80% mask token, 10% random valid value, 10% unchanged
        r = rand(rng)
        if r <= 0.8
            for j = 1:length(tokens)
                if j ∉ unmasked_fields
                    tokens[j][i] = mask_tokens[j]
                end
            end
        elseif r <= 0.9
            for j = 1:length(tokens)
                if j ∉ unmasked_fields
                    # sample from the full inclusive range 0:max (the old 0:max-1
                    # never produced the largest valid value)
                    tokens[j][i] = _random_token(rng, vocab_sizes[j])
                end
            end
        end
    end

    return tokens, positions, labels, weights
end;

# Data collection

In [None]:
"""
    shuffle_training_data(rng, sentences, max_sequence_length, max_document_length)

Pack the tokens of `sentences`, visited in a random order drawn from `rng`, into
consecutive chunks of exactly `max_sequence_length` tokens. Only the last chunk may be
shorter. A chunk can contain tokens from several sentences. Each sentence is first
passed through `subset_sentence(…, max_document_length; recent = false)`.

Returns a `Vector{eltype(sentences)}`. It is empty when `sentences` is empty.
"""
function shuffle_training_data(rng, sentences, max_sequence_length, max_document_length)
    order = Random.shuffle(rng, 1:length(sentences))
    S = eltype(sentences)
    # `eltype(first(sentences))` below would throw a BoundsError on empty input
    isempty(sentences) && return Vector{S}()
    W = eltype(first(sentences))

    # partition tokens into fixed-length sequences
    batched_sentences = Vector{S}()
    sentence = Vector{W}()
    for i in order
        subset =
            subset_sentence(sentences[i], max_document_length; recent = false, rng = rng)
        for token in subset
            push!(sentence, token)
            if length(sentence) == max_sequence_length
                push!(batched_sentences, sentence)
                sentence = Vector{W}()
            end
        end
    end
    isempty(sentence) || push!(batched_sentences, sentence)
    return batched_sentences
end;

In [None]:
"""
    get_sentences(rng, training_config)

Load all sentences (the values returned by `get_training_data`), shuffle them with
`rng`, and split them into `(training, validation)` at the 99% mark.
"""
function get_sentences(rng, training_config)
    all_sentences = collect(values(get_training_data(training_config["cls_tokens"])))
    Random.shuffle!(rng, all_sentences)
    num_training = round(Int, 0.99 * length(all_sentences)) # TODO switch to 0.999 for prod
    return all_sentences[1:num_training], all_sentences[num_training+1:end]
end;

# Configuration

In [None]:
"""
    set_rngs(seed)

Create a `Xoshiro` RNG from `seed` and use its first `UInt64` draw to reseed the global
RNG. Returns the seeded `Xoshiro`, already advanced past that draw.
"""
function set_rngs(seed)
    local_rng = Random.Xoshiro(seed)
    global_seed = rand(local_rng, UInt64)
    Random.seed!(global_seed)
    return local_rng
end;

In [None]:
"""
    create_training_config()

Build the pretraining configuration dictionary: special-token tables, data sharding
parameters, and model settings.
"""
function create_training_config()
    max_sequence_length = 1024
    # Upper bound of each input field (featurize treats these as inclusive maxima),
    # in field order: itemid, rating, updated_at, status, position, userid
    base_vocab_sizes = (
        Int32(sum(num_items.(ALL_MEDIUMS)) - 1),  # itemid
        Float32(10),                              # rating
        Float32(1),                               # updated_at
        Int32(get_status(:rewatching)),           # status
        Int32(max_sequence_length - 1),           # position
        Int32(num_users() - 1),                   # userid
    )
    # special tokens sit just past each field's valid range
    cls_tokens = base_vocab_sizes .+ Int32(1)
    pad_tokens = base_vocab_sizes .+ Int32(2)
    mask_tokens = base_vocab_sizes .+ Int32(3)
    vocab_sizes = base_vocab_sizes .+ Int32(4)

    config = Dict(
        # tokenization
        "base_vocab_sizes" => base_vocab_sizes,
        "cls_tokens" => cls_tokens,
        "pad_tokens" => pad_tokens,
        "mask_tokens" => mask_tokens,
        "vocab_sizes" => vocab_sizes,
        "vocab_types" => ("int", "float", "float", "int", "int", "none"),
        "media_sizes" => Dict(m => num_items(m) for m in ALL_MEDIUMS),
        # data
        "max_document_length" => Inf, # TODO experiment with subsampling
        "chunk_size" => 2^14,
        "num_training_shards" => 24,
        "num_validation_shards" => 8,
        # model
        "max_sequence_length" => max_sequence_length,
        "mode" => "pretrain",
    )
    @assert config["max_document_length"] >= config["max_sequence_length"]
    @assert length(config["vocab_sizes"]) == length(config["vocab_types"])
    return config
end;

In [None]:
"""
    set_epoch_size!(config, training_sentences, validation_sentences)

Count the tokens in each split, capping each sentence at `config["max_document_length"]`.
Log the counts, then store the number of `max_sequence_length` sequences per epoch
(rounded up) as `config["training_epoch_size"]` and `config["validation_epoch_size"]`.
Returns the validation epoch size.
"""
function set_epoch_size!(config, training_sentences, validation_sentences)
    doc_len = config["max_document_length"]
    seq_len = config["max_sequence_length"]
    ntokens(sentences) = Int64(sum(min.(length.(sentences), doc_len)))

    train_tokens = ntokens(training_sentences)
    valid_tokens = ntokens(validation_sentences)
    @info "Number of training tokens: $(train_tokens)"
    @info "Number of training sentences: $(length(training_sentences))"
    @info "Number of validation tokens: $(valid_tokens)"
    @info "Number of validation sentences: $(length(validation_sentences))"

    config["training_epoch_size"] = div(train_tokens, seq_len, RoundUp)
    config["validation_epoch_size"] = div(valid_tokens, seq_len, RoundUp)
end;

In [None]:
"""
    setup_training(config, outdir)

Ensure `outdir` exists and delete every regular file directly inside it. Subdirectories
are left alone. Then write `config` as JSON to `config.json` in the parent directory of
`outdir`. Returns the number of bytes written.
"""
function setup_training(config, outdir)
    isdir(outdir) || mkpath(outdir)
    foreach(rm, filter(isfile, readdir(outdir; join = true)))
    config_path = joinpath(outdir, "..", "config.json")
    return open(config_path, "w") do io
        write(io, JSON.json(config))
    end
end;

# Disk I/O

In [None]:
"""
    save_features(sentences, config, rng, filename)

Pack `sentences` into `config["max_sequence_length"]`-token sequences
(`shuffle_training_data`), featurize each sequence, and write the collated arrays to the
HDF5 file `filename`, which is opened in `"w"` mode and so overwritten.

Datasets written:
- one per input field: `itemid`, `rating`, `updated_at`, `status`, `position`, `userid`;
- `positions_<medium>_<metric>`, `labels_<medium>_<metric>`, `weights_<medium>_<metric>`
  for every medium in `ALL_MEDIUMS` and metric in `ALL_METRICS`.

Each dataset stacks the per-sequence arrays with `MLUtils.batch`.
"""
function save_features(sentences, config, rng, filename)
    sequences = shuffle_training_data(
        rng,
        sentences,
        config["max_sequence_length"],
        config["max_document_length"],
    )
    # Typed comprehension rather than an untyped `[]`. It is evaluated in order, so
    # `rng` is consumed exactly as by a sequential push! loop.
    features = [featurize(x, config, rng) for x in sequences]

    d = Dict{String,AbstractArray}()
    collate = MLUtils.batch

    # same order as the token fields (itemid, rating, updated_at, status, position, userid)
    embed_names = ["itemid", "rating", "updated_at", "status", "position", "userid"]
    for (i, field) in enumerate(embed_names)
        d[field] = collate([x[1][i] for x in features])
    end
    for medium in ALL_MEDIUMS
        for metric in ALL_METRICS
            stem = "$(medium)_$(metric)"
            d["positions_$stem"] = collate([x[2][medium][metric] for x in features])
            d["labels_$stem"] = collate([x[3][medium][metric] for x in features])
            d["weights_$stem"] = collate([x[4][medium][metric] for x in features])
        end
    end
    HDF5.h5open(filename, "w") do file
        for (k, v) in d
            write(file, k, v)
        end
    end
end;

In [None]:
"""
    advance!(filename)

Check whether the shard `filename` has been consumed, and if so delete it so the next
shard can be written.

Read markers are the files in the same directory whose name contains
`basename(filename) * ".read"`. The text after their last `.` is taken as the world
size, and all markers must agree on it. When the number of markers equals that world
size, this deletes `filename.complete`, `filename`, and every marker, and returns `true`.
Otherwise it returns `false`.
"""
function advance!(filename)
    dir = dirname(filename)
    marker = basename(filename) * ".read"
    read_markers = filter(f -> occursin(marker, basename(f)), readdir(dir))
    isempty(read_markers) && return false

    world_sizes = Set(last(split(f, ".")) for f in read_markers)
    @assert length(world_sizes) == 1
    world_size = parse(Int, first(world_sizes))
    length(read_markers) == world_size || return false

    rm("$filename.complete")
    rm(filename)
    foreach(f -> rm(joinpath(dir, f)), read_markers)
    return true
end;

In [None]:
# Split `xs` into at most `k` contiguous, non-empty views whose lengths differ by at most
# one. Exactly `k` views are returned whenever `length(xs) >= k`.
function _partition_evenly(xs, k)
    n = length(xs)
    bounds = [div(j * n, k) for j = 0:k]
    return [view(xs, bounds[j]+1:bounds[j+1]) for j = 1:k if bounds[j+1] > bounds[j]]
end

# Body of one feature worker. It loops forever: reshuffle `batch` in place, and for each
# `chunk_size` chunk write `filename`, publish `filename.complete` (containing the chunk
# number), then wait until `advance!` reports the shard consumed. It returns once
# `outdir` no longer exists. Errors are logged before rethrowing, because a failure in a
# spawned task is otherwise silent.
function _feature_worker_loop(batch, config, rng, chunk_size, outdir, filename, collect_garbage)
    try
        while true
            Random.shuffle!(rng, batch)
            for (j, chunk) in enumerate(Iterators.partition(batch, chunk_size))
                save_features(chunk, config, rng, filename)
                open("$filename.complete", "w") do f
                    write(f, "$j")
                end
                # only one worker forces a full GC (presumably to bound memory without
                # stalling every worker)
                if collect_garbage
                    GC.gc()
                end
                while isdir(outdir) && !advance!(filename)
                    sleep(1)
                end
                if !isdir(outdir)
                    return
                end
            end
        end
    catch e
        @error "feature worker for $filename failed" exception = (e, catch_backtrace())
        rethrow()
    end
end

"""
    spawn_feature_workers(sentences, config, rng, training, outdir)

Spawn one background task per shard. The shard count is `config["num_training_shards"]`
if `training` is true, else `config["num_validation_shards"]`. Each task repeatedly
writes feature files to `joinpath(outdir, "\$stem.\$i.h5")`, where `stem` is `"training"`
or `"validation"` and `i` is the shard number.

`sentences` is split into contiguous, near-equal slices, one per shard. Each worker
reshuffles its slice in place, so the caller's `sentences` is reordered. Every worker
gets its own `Xoshiro` seeded from `rng`. Returns without waiting for the workers.
"""
function spawn_feature_workers(sentences, config, rng, training, outdir)
    chunk_size = config["chunk_size"]
    workers = training ? config["num_training_shards"] : config["num_validation_shards"]
    stem = training ? "training" : "validation"
    worker_rngs = [Random.Xoshiro(rand(rng, UInt64)) for _ = 1:workers]
    # Balanced split: ceil-sized partitions could produce fewer than `workers` shards,
    # and the notebook waits for one `.complete` marker per configured shard.
    batches = _partition_evenly(sentences, workers)
    if length(batches) < workers
        @warn "only $(length(batches)) of $workers $stem shards have data"
    end
    for (i, batch) in enumerate(batches)
        filename = joinpath(outdir, "$stem.$i.h5")
        # Each worker's RNG is a function argument of the loop, never a captured
        # variable, so tasks cannot overwrite each other's RNG.
        Threads.@spawn _feature_worker_loop(
            batch, config, worker_rngs[i], chunk_size, outdir, filename, i == 1)
    end
    return nothing
end;

# State

In [None]:
# Notebook-level state.
# NOTE(review): config_checkpoint / config_epoch are not read in this section;
# presumably used by later cells — confirm.
config_checkpoint = nothing
config_epoch = nothing
# set_rngs also reseeds the global RNG; the returned rng drives data splitting and workers
rng = set_rngs(20221221)
config = create_training_config();

In [None]:
@info "loading data"
# Shuffle and split all sentences 99% / 1%, then record per-epoch sequence counts in config.
training_sentences, validation_sentences = get_sentences(rng, config)
set_epoch_size!(config, training_sentences, validation_sentences);

In [None]:
# Shard directory. setup_training deletes files left in it and writes config.json
# to its parent directory.
outdir = get_data_path(joinpath("alphas", name, "training"))
setup_training(config, outdir);

In [None]:
# Background tasks that keep writing training shards into outdir.
spawn_feature_workers(training_sentences, config, rng, true, outdir);

In [None]:
# Background tasks that keep writing validation shards into outdir.
spawn_feature_workers(validation_sentences, config, rng, false, outdir);

In [None]:
# Wait for the workers to begin writing: poll once per second until the number of
# `.complete` markers in outdir reaches the total number of configured shards.
while count(endswith(".complete"), readdir(outdir)) <
      config["num_training_shards"] + config["num_validation_shards"]
    sleep(1)
end