# Pretrain Dataset
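* Saves pretraining data to disk: in `"map"` mode, splits the training data into shuffled training/validation sentences; in `"reduce"` mode, tokenizes and masks those sentences and writes one HDF5 file per batch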

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb")
@nbinclude("Data.ipynb");

In [None]:
# Experiment name. It is passed to `set_logging_outdir` and is also used as the
# subdirectory under "alphas" where the sentence splits and tokenized batches are stored.
const name = "all/Transformer/maskv4"
set_logging_outdir(name);

In [None]:
import H5Zblosc
import HDF5
import JSON
import MLUtils
import Random
import StatsBase: mean, sample

In [None]:
# the dataset is too big to load into memory, so we process it in parts
# `partition` / `num_partitions` are forwarded to `get_training_data` ("map" mode) and
# select the per-partition files used in "reduce" mode.
partition = 0
num_partitions = 1
# number of epochs of tokenized batches to write in "reduce" mode
num_epochs = 1
# "map": build and save the training/validation sentence split for this partition
# "reduce": tokenize the saved split into HDF5 batches (see the last cell)
mode = "map"

# Data

In [None]:
"""
    get_sentences(cls_tokens, partition)

Load the training data for `partition`, order it by key, shuffle it, and split it
into a `"training"` part (first ~99%) and a `"validation"` part (the remainder).
"""
function get_sentences(cls_tokens, partition)
    data = get_training_data(cls_tokens, partition, nothing)
    # sort by key first so the shuffle below is reproducible for a given seed
    ordered = [data[k] for k in sort!(collect(keys(data)))]
    Random.shuffle!(ordered)
    n_train = round(Int, 0.99 * length(ordered)) # TODO switch to 0.999 for prod
    return Dict(
        "training" => ordered[1:n_train],
        "validation" => ordered[n_train+1:end],
    )
end;

In [None]:
"""
    tokenize(sentence, config, training)

Positional convenience wrapper: pull the tokenization settings out of `config`
and forward them to the keyword form of `tokenize`.
"""
function tokenize(sentence::Vector{wordtype}, config, training)
    seq_len = config[:max_sequence_length]
    return tokenize(;
        sentence = sentence,
        max_seq_len = seq_len,
        vocab_sizes = config[:base_vocab_sizes],
        pad_tokens = config[:pad_tokens],
        cls_tokens = config[:cls_tokens],
        mask_tokens = config[:mask_tokens],
        training = training,
    )
end;

In [None]:
"""
    tokenize(; sentence, max_seq_len, vocab_sizes, pad_tokens, cls_tokens, mask_tokens, training)

Convert `sentence` into per-field token arrays of length `max_seq_len` (via
`get_token_ids`) and build masked-prediction targets.

For every position whose status/rating is a real value (not a special token),
a prediction is made with probability 0.15. At such a position:
- targets are recorded for each applicable metric ("rating", "watch",
  "plantowatch", "drop") of the item's medium;
- every token field except `updated_at`, `position` and `userid` is replaced
  by the corresponding mask token (mutating `tokens`).

Returns `(tokens, positions, labels, weights)`. Each of the last three is a
`Dict(medium => Dict(metric => Vector))` of length `max_seq_len`; untouched
positions are zero. When `training` is false, weights are reweighted per user
with `weight_by_user!`.
"""
function tokenize(;
    sentence::Vector{wordtype},
    max_seq_len,
    vocab_sizes,
    pad_tokens,
    cls_tokens,
    mask_tokens,
    training,
)
    tokens = get_token_ids(sentence, max_seq_len, pad_tokens, cls_tokens)

    # One zero-initialised vector of length `max_seq_len` per (medium, metric) pair.
    new_targets(T) = Dict(
        x => Dict(y => zeros(T, max_seq_len) for y in ALL_METRICS) for x in ALL_MEDIUMS
    )
    positions = new_targets(Int32)
    weights = new_targets(Float32)
    labels = new_targets(Float32)
    userids = new_targets(Int32)

    # Record one prediction target for (`medium`, `metric`) at sequence position `i`.
    function add_target!(medium, metric, i, mediaid, label, userid)
        positions[medium][metric][i] = mediaid
        labels[medium][metric][i] = label
        weights[medium][metric][i] = 1
        userids[medium][metric][i] = userid
        return nothing
    end

    # Loop-invariant lookups, hoisted so they are not recomputed (and, for
    # `keep_fields`, reallocated) at every position.
    cls_mangaid = extract(cls_tokens, :mangaid)
    cls_animeid = extract(cls_tokens, :animeid)
    max_rating = vocab_sizes[get_wordtype_index(:rating)]
    max_status = vocab_sizes[get_wordtype_index(:status)]
    status_plantowatch = get_status(:plan_to_watch)
    status_none = get_status(:none)
    status_dropped = get_status(:dropped)
    keep_fields = get_wordtype_index.([:updated_at, :position, :userid])
    # Per-field token arrays. Each iteration only reads index `i` before masking
    # index `i`, so reading these up front is equivalent to re-extracting.
    mangaids = extract(tokens, :mangaid)
    animeids = extract(tokens, :animeid)
    ratings = extract(tokens, :rating)
    statuses = extract(tokens, :status)
    tok_userids = extract(tokens, :userid)

    for i::Int32 = 1:max_seq_len
        status = statuses[i]
        # values above the base vocab size are special (cls/pad/mask) tokens
        has_rating = 0 < ratings[i] <= max_rating
        has_watch = status_plantowatch < status <= max_status
        has_plantowatch = status == status_plantowatch
        has_drop = status_none < status <= max_status
        if !(has_rating || has_watch || has_plantowatch || has_drop)
            continue
        end
        is_manga = mangaids[i] != cls_mangaid
        is_anime = animeids[i] != cls_animeid
        @assert xor(is_manga, is_anime)
        medium, mediaid = is_manga ? ("manga", mangaids[i]) : ("anime", animeids[i])

        # make predictions for 15% of tokens
        if rand() < 0.15
            userid = tok_userids[i]
            has_rating && add_target!(medium, "rating", i, mediaid, ratings[i], userid)
            has_watch && add_target!(medium, "watch", i, mediaid, 1, userid)
            has_plantowatch && add_target!(medium, "plantowatch", i, mediaid, 1, userid)
            # drop label: 1 if the status is at or below `dropped`, else 0
            has_drop &&
                add_target!(medium, "drop", i, mediaid, status <= status_dropped, userid)
            # hide everything the model is asked to predict at this position
            for j = 1:length(tokens)
                j in keep_fields && continue
                tokens[j][i] = mask_tokens[j]
            end
        end
    end

    if !training
        for x in ALL_MEDIUMS, y in ALL_METRICS
            weight_by_user!(weights[x][y], userids[x][y])
        end
    end

    return tokens, positions, labels, weights
end;

"""
    weight_by_user!(weights, userids)

Divide each nonzero `weights[i]` by the number of nonzero-weight entries that
belong to the same user `userids[i]`. After this, each user's targets share a
total weight equal to their original per-target weight. Zero weights (unused
slots) are left untouched and are NOT counted, because unused slots carry a
placeholder userid of 0 that would otherwise collide with a real user 0.
Mutates and returns `weights`.
"""
function weight_by_user!(weights, userids)
    counts = Dict{eltype(userids),Int}()
    for (w, uid) in zip(weights, userids)
        if w != 0
            counts[uid] = get(counts, uid, 0) + 1
        end
    end
    for i in eachindex(weights, userids)
        if weights[i] != 0
            weights[i] /= counts[userids[i]]
        end
    end
    return weights
end;

# Epochs

In [None]:
"""
    hdf5_writer(c::Channel)

Consume `(data, filename)` pairs from `c`. Each `data` dict is written to its own
HDF5 file (one dataset per key, blosc-compressed). The loop returns normally once
`c` is closed and drained. A write failure is logged and rethrown, so it is not
lost silently inside a spawned task.
"""
function hdf5_writer(c::Channel)
    for (d, fn) in c
        try
            HDF5.h5open(fn, "w") do f
                for (k, v) in d
                    f[k, blosc = 1] = v
                end
            end
        catch e
            @error "hdf5_writer failed to write $fn" exception = (e, catch_backtrace())
            rethrow()
        end
    end
end;

In [None]:
"""
    save_tokens(sentences, config, filename, writer, training)

Tokenize every sentence of a batch in parallel. The per-field tokens and the
per-(medium, metric) positions/labels/weights are then stacked along a new batch
dimension with `MLUtils.batch`, and the resulting `Dict` is queued for writing to
`filename` on the `writer` channel.
"""
function save_tokens(sentences, config, filename, writer, training)
    n = length(sentences)
    results = Vector{Any}(undef, n)
    Threads.@threads for k = 1:n
        results[k] = tokenize(sentences[k], config, training)
    end
    batchify = MLUtils.batch
    out = Dict{String,AbstractArray}()
    for (idx, field) in enumerate(config.vocab_names)
        out[field] = batchify([r[1][idx] for r in results])
    end
    # slots 2/3/4 of each tokenize result are positions/labels/weights
    for medium in ALL_MEDIUMS, metric in ALL_METRICS
        suffix = "$(medium)_$(metric)"
        for (prefix, slot) in (("positions", 2), ("labels", 3), ("weights", 4))
            out["$(prefix)_$suffix"] = batchify([r[slot][medium][metric] for r in results])
        end
    end
    put!(writer, (out, filename))
end;

In [None]:
"""
    get_batch(sentences, state, batch_size, max_sequence_length, max_document_length)

Pack words from consecutive sentences, each passed through `subset_sentence`, into
rows of exactly `max_sequence_length` words. Rows may span sentence boundaries.

Returns `(rows, next_state)`:
- When `batch_size` rows are filled, `next_state` records the sentence index and
  its unconsumed words so the next call resumes mid-sentence.
- When the input is exhausted, the final (possibly short, possibly empty) row is
  appended and `next_state.finished == true`.
"""
function get_batch(sentences, state, batch_size, max_sequence_length, max_document_length)
    W = wordtype
    rows = Vector{Vector{W}}()
    current = Vector{W}()
    pending = state.resuming
    for idx = state.index:length(sentences)
        if pending
            # finish the sentence that the previous call stopped in the middle of
            words = state.sentence
            pending = false
        else
            words = subset_sentence(sentences[idx], max_document_length; recent = false)
        end
        for pos = 1:length(words)
            push!(current, words[pos])
            length(current) == max_sequence_length || continue
            push!(rows, current)
            current = Vector{W}()
            if length(rows) == batch_size
                next_state = (
                    resuming = true,
                    finished = false,
                    index = idx,
                    sentence = words[pos+1:end],
                    batch = state.batch + 1,
                )
                return rows, next_state
            end
        end
    end
    push!(rows, current)
    return rows, (finished = true, batch = state.batch + 1)
end;

In [None]:
"""
    save_epoch(sentences, config, epoch, outdir, split, writer)

Shuffle `sentences` in place, pack them into batches with `get_batch`, and queue
each batch for writing as `outdir/split/epoch/<batch>.h5`. At the end, check that
the number of tokens and rows written matches `config.<split>_epoch_tokens` and
`config.<split>_epoch_size`.
"""
function save_epoch(sentences, config, epoch, outdir, split, writer)
    outdir = joinpath(outdir, split, "$epoch")
    mkpath(outdir)
    Random.shuffle!(sentences) # new order every epoch; mutates the caller's vector
    state = (resuming = false, finished = false, index = 1, batch = 0)
    num_tokens = 0
    num_sentences = 0
    expected_num_tokens = getfield(config, Symbol("$(split)_epoch_tokens"))
    expected_num_sentences = getfield(config, Symbol("$(split)_epoch_size"))

    p = ProgressMeter.Progress(expected_num_tokens)
    while !state.finished
        batch, state = get_batch(
            sentences,
            state,
            config.batch_size,
            config.max_sequence_length,
            config.max_document_length,
        )
        # `get_batch` always appends its trailing row, which is empty when the token
        # count is an exact multiple of max_sequence_length (or the split is empty).
        # Drop it so it is neither written nor counted; an all-empty final batch is
        # skipped entirely.
        filter!(!isempty, batch)
        if isempty(batch)
            continue
        end
        save_tokens(
            batch,
            config,
            joinpath(outdir, "$(state.batch-1).h5"),
            writer,
            split == "training",
        )
        num_tokens += sum(length.(batch))
        num_sentences += length(batch)
        ProgressMeter.update!(p, num_tokens)
    end
    ProgressMeter.finish!(p)
    @assert (num_tokens == expected_num_tokens) &&
            (num_sentences == expected_num_sentences) "$split epoch $epoch: wrote " *
                                                      "$num_tokens tokens / $num_sentences rows, expected " *
                                                      "$expected_num_tokens / $expected_num_sentences"
end;

# Configs

In [None]:
"""
    create_training_config()

Build the pretraining configuration as a `NamedTuple`. Each token field has a base
vocab size. The special tokens are defined relative to it: cls = base + 1,
pad = base + 2, mask = base + 3, and the full vocab size is base + 4.
The epoch sizes start at -1 and are filled in later.
"""
function create_training_config()
    seq_len = 1024
    # (name, encoding, base vocab size) per token field, in wordtype order
    fields = (
        ("mangaid", "int", num_items("manga") - 1),
        ("animeid", "int", num_items("anime") - 1),
        ("rating", "float", Float32(10)),
        ("updated_at", "float", Float32(1)),
        ("status", "int", Int32(get_status(:rewatching))),
        ("source", "int", Int32(3)),
        ("created_at", "float", Float32(1)),
        ("started_at", "float", Float32(1)),
        ("finished_at", "float", Float32(1)),
        ("progress", "float", Float32(1)),
        ("repeat_count", "float", Float32(1)), # 1 - 1 / (repeat_count + 1)
        ("priority", "float", Float32(1)), # 1 - 1 / (priority + 1)
        ("sentiment", "int", Int32(3)),
        ("sentiment_score", "float", Float32(1)),
        ("position", "int", Int32(seq_len - 1)),
        ("userid", "none", Int32(num_users() - 1)),
    )
    base_sizes = map(last, fields)
    shifted(offset) = convert(wordtype, base_sizes .+ offset)
    cfg = (
        # tokenization
        base_vocab_sizes = convert(wordtype, base_sizes),
        cls_tokens = shifted(1),
        pad_tokens = shifted(2),
        mask_tokens = shifted(3),
        vocab_sizes = shifted(4),
        vocab_types = map(f -> f[2], fields),
        vocab_names = [f[1] for f in fields],
        media_sizes = Dict(m => num_items(m) for m in ALL_MEDIUMS),
        # data
        max_document_length = Inf, # TODO experiment with subsampling
        batch_size = 2^11,
        training_epoch_size = -1,
        training_epoch_tokens = -1,
        validation_epoch_size = -1,
        validation_epoch_tokens = -1,
        # model
        max_sequence_length = seq_len,
        mode = "pretrain",
    )
    @assert cfg[:max_document_length] >= cfg[:max_sequence_length]
    @assert length(cfg[:vocab_sizes]) == length(cfg[:vocab_types])
    @assert length(cfg[:vocab_sizes]) == length(cfg[:vocab_names])
    return cfg
end;

In [None]:
"""
    set_epoch_size(config, sentences)

Return a copy of `config` with `<split>_epoch_tokens` and `<split>_epoch_size` set
for the "training" and "validation" splits of `sentences`.
- tokens: total words after capping each sentence at `max_document_length`;
- size: number of rows of `max_sequence_length` needed to hold them (rounded up).
"""
function set_epoch_size(config, sentences)
    cap = config[:max_document_length]
    for part in ["training", "validation"]
        docs = sentences[part]
        n_tokens = Int64(sum(min.(length.(docs), cap)))
        n_rows = div(n_tokens, config[:max_sequence_length], RoundUp)
        @info "Number of $part tokens: $(n_tokens)"
        @info "Number of $part sentences: $(length(docs))"
        config = @set config[Symbol("$(part)_epoch_tokens")] = n_tokens
        config = @set config[Symbol("$(part)_epoch_size")] = n_rows
    end
    return config
end;

In [None]:
"""
    setup_training(config, outdir)

Create `outdir`, write `config` to `outdir/config.json`, and make sure the
`training` and `validation` subdirectories exist and are empty.
"""
function setup_training(config, outdir)
    mkpath(outdir)
    open(joinpath(outdir, "config.json"), "w") do io
        write(io, JSON.json(config))
    end
    for part in ("training", "validation")
        dir = joinpath(outdir, part)
        mkpath(dir)
        # remove leftovers from previous runs
        foreach(p -> rm(p; recursive = true), readdir(dir; join = true))
    end
end;

In [None]:
"""
    get_writer()

Return the process-wide (memoized) channel of `(data, filename)` pairs. The
channel is consumed by `hdf5_writer` on a background task and buffers up to
8 pending writes.
"""
@memoize function get_writer()
    writer = Channel(8)
    task = Threads.@spawn hdf5_writer(writer)
    # Tie the channel to the writer task: if `hdf5_writer` throws, the channel is
    # closed with that error, so producers blocked in `put!` fail loudly instead
    # of deadlocking once the buffer fills.
    bind(writer, task)
    writer
end;

In [None]:
"""
    save_epochs(partition, num_epochs)

"Reduce" step for one partition:
1. load the saved sentence split;
2. compute the epoch sizes;
3. reset the output directory;
4. write `num_epochs` epochs of validation and training batches through the
   shared HDF5 writer.
The RNG is seeded with `partition` for reproducibility.
"""
function save_epochs(partition::Int, num_epochs::Int)
    Random.seed!(partition)
    config = create_training_config()
    data_dir = joinpath("alphas", name)
    sentences = JLD2.load(get_data_path(joinpath(data_dir, "sentences.$partition.jld2")))
    config = set_epoch_size(config, sentences)
    outdir = get_data_path(joinpath(data_dir, "$partition"))
    setup_training(config, outdir)
    writer = get_writer()
    for epoch = 0:(num_epochs-1), part in ("validation", "training")
        save_epoch(sentences[part], config, epoch, outdir, part, writer)
    end
end;

In [None]:
if mode == "map"
    # Map step: build this partition's shuffled training/validation split and cache it.
    const config = create_training_config()
    JLD2.save(
        get_data_path(joinpath("alphas", name, "sentences.$partition.jld2")),
        get_sentences(config[:cls_tokens], (partition, num_partitions));
        compress = true,
    )
elseif mode == "reduce"
    # Reduce step: tokenize the cached split into per-batch HDF5 files.
    @time save_epochs(partition, num_epochs)
    # The HDF5 writer runs on a background task. NOTE(review): a fixed sleep is a
    # heuristic, not a guarantee that every queued write has finished.
    sleep(60) # wait for final epoch to finish writing
else
    error("unknown mode $(repr(mode)); expected \"map\" or \"reduce\"")
end