# Pretraining
* Trains a bag-of-words model on user data

In [None]:
# Notebook parameters. These empty defaults are presumably overwritten by an
# external runner before execution — TODO confirm how they are injected.
#   medium: the medium being modeled; used to build `name` and passed to loaders
#   metric: target metric; the code below handles "rating", "watch", "plantowatch"
#   mode:   "train" runs the training/evaluation cells; any other value is passed
#           to save_splits and the process exits afterwards
medium = ""
metric = ""
mode = "";

In [None]:
# Execute the shared Alpha notebook in this session. It presumably defines the
# helpers used below (get_split, get_data_path, training_test_split, ...) —
# assumption, verify against ../Alpha.ipynb.
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb");

In [None]:
import H5Zblosc
import HDF5
import JSON
import SparseArrays: AbstractSparseArray, sparse

In [None]:
# Every artifact of this alpha lives under "<medium>/BagOfWords/<version>/<metric>".
version = "v1"
basepath = string(medium, "/BagOfWords/", version)
name = string(basepath, "/", metric)
set_logging_outdir(name);

# Data

In [None]:
# get_rating_beta(alpha_name): return the "β" entry of the parameters stored for
# alpha `alpha_name` (read via read_params). Memoized, so each name is read once.
@memoize function get_rating_beta(alpha_name)
    read_params(alpha_name, false)["β"]
end;

In [None]:
"""
    get_inputs(medium, metric, holdout)

Load the "training" split for `medium`/`metric` and return it as a sparse matrix
via `sparse(df, medium)`.

For `metric == "rating"`, the label is replaced by `metric - β * alpha`, where
`alpha` is the Baseline rating alpha column and `β` its stored coefficient.
When `holdout` is true, only the first part returned by `training_test_split`
is kept.
"""
function get_inputs(medium::String, metric::String, holdout::Bool)
    @info "loading $medium $metric inputs"
    fields = [:userid, :itemid, :metric, :update_order, :updated_at]
    if metric == "rating"
        alpha = "$medium/Baseline/rating"
        β = get_rating_beta(alpha)
        # passing the alpha name presumably populates `df.alpha` — verify in get_split
        df = get_split("training", metric, medium, fields, alpha)
        # fused in-place update: no full-length temporaries on a large dataset
        df.metric .-= df.alpha .* β
    else
        df = get_split("training", metric, medium, fields)
    end
    GC.gc()
    if holdout
        df, _ = training_test_split(df)
    end
    return sparse(df, medium)
end;

"""
    get_epoch_inputs_unmemoized(holdout)

Build the input matrix from scratch: vertically stack `get_inputs` for every
medium under the "rating" metric, followed by every medium under "watch".
"""
function get_epoch_inputs_unmemoized(holdout)
    blocks = []
    for m in ["rating", "watch"]
        for med in ALL_MEDIUMS
            push!(blocks, get_inputs(med, m, holdout))
        end
    end
    @info "loaded inputs"
    return vcat(blocks...)
end;

"""
    get_epoch_inputs(holdout)

Return the stacked input matrix from `get_epoch_inputs_unmemoized(holdout)`,
cached on disk as an HDF5 file keyed by `version` and `holdout`.

On a cache miss, the matrix is computed and written to a temporary file. The file
is then moved into place, so an interrupted run cannot leave a truncated cache
file that later runs would silently read.
"""
function get_epoch_inputs(holdout)
    fn = get_data_path("alphas/all/BagOfWords/$version/inputs.$holdout.h5")
    if isfile(fn)
        return HDF5.h5open(fn, "r") do f
            g(x) = read(f[x])
            sparse(g("inputs_i"), g("inputs_j"), g("inputs_v"), g("inputs_size")...)
        end
    end
    mkpath(dirname(fn))
    X = get_epoch_inputs_unmemoized(holdout)
    d = Dict{String,Any}()
    record_sparse_array!(d, "inputs", X)
    tmp = fn * ".tmp"
    HDF5.h5open(tmp, "w") do file
        for (k, v) in d
            file[k] = v
        end
    end
    # a rename within the same directory replaces `fn` in one step
    mv(tmp, fn; force = true)
    return X
end

"""
    record_sparse_array!(d, name, x)

Store sparse matrix `x` in `d` in coordinate form, using the keys
`"<name>_i"`, `"<name>_j"` and `"<name>_v"` (from `findnz`) and `"<name>_size"`.

NOTE(review): the recorded size is `[size(x, 1), num_users()]`, not `size(x, 2)`,
even when `x` is a column slice (as in `save_features`). Presumably the reader
expects the full user dimension — confirm before changing.
"""
function record_sparse_array!(d::Dict, name::String, x::AbstractSparseArray)
    rows, cols, vals = SparseArrays.findnz(x)
    d["$(name)_i"] = rows
    d["$(name)_j"] = cols
    d["$(name)_v"] = vals
    d["$(name)_size"] = [size(x, 1), num_users()]
end;

In [None]:
"""
    get_epoch_labels(split, metric, medium)

Return the labels for `split` ∈ {"pretrain", "finetune", "test"} as a sparse
matrix via `sparse(df, medium)`.

- "pretrain" reads the "training" split and keeps the first part returned by
  `training_test_split`.
- "finetune" reads the "training" split and keeps the second part.
- "test" reads the "test" split as-is.

For `metric == "rating"`, the label is replaced by `metric - β * alpha` using the
Baseline rating alpha, matching `get_inputs`.

Throws `ArgumentError` for any other `split`.
"""
function get_epoch_labels(split, metric, medium)
    @info "loading labels $split"
    if split in ("pretrain", "finetune")
        tsplit = "training"
    elseif split == "test"
        tsplit = "test"
    else
        throw(ArgumentError("unknown split $(repr(split))"))
    end
    fields = [:userid, :itemid, :metric, :update_order, :updated_at]
    if metric == "rating"
        alpha = "$medium/Baseline/rating"
        β = get_rating_beta(alpha)
        df = get_split(tsplit, metric, medium, fields, alpha)
        # fused in-place update: no full-length temporaries
        df.metric .-= df.alpha .* β
    else
        df = get_split(tsplit, metric, medium, fields)
    end
    # `split` was validated above; "test" keeps the whole split
    if split == "pretrain"
        df, _ = training_test_split(df)
    elseif split == "finetune"
        _, df = training_test_split(df)
    end
    return sparse(df, medium)
end;

In [None]:
"""
    get_epoch_weights(split, metric, medium, λ_wu, λ_wa, λ_wt)

Return per-row weights for `split` as a sparse matrix built by
`sparse(df, medium)`, with the weights stored in the `metric` field.

- "pretrain": the first part of `training_test_split` on the "training" split.
  Each row's weight is `λ_wt^((1 - updated_at) / days_in_timestamp_units(365))`.
  It is then multiplied by `powerdecay(w, λ_wu)` and `powerdecay(w, λ_wa)`, where
  `w` comes from `get_counts` over the user and item columns respectively
  (presumably per-row occurrence counts — verify in get_counts).
- "finetune": the second part of `training_test_split`; the weight is
  `powerdecay(get_counts(userid), -1)`.
- "test": the "test" split, weighted like "finetune".

Throws `ArgumentError` for any other `split`.
"""
function get_epoch_weights(
    split::String,
    metric::String,
    medium::String,
    λ_wu::Real,
    λ_wa::Real,
    λ_wt::Real,
)
    @info "loading weights $split"
    GC.gc()
    fields = [:userid, :itemid, :update_order, :updated_at]
    if split == "pretrain"
        df = get_split("training", metric, medium, fields)
        df, _ = training_test_split(df)
        # `weights` aliases `df.updated_at` and is overwritten in place to avoid
        # another full-length allocation. Iteration i reads index i before writing
        # it, so the aliasing is safe here.
        # NOTE(review): if get_split/training_test_split hand out shared arrays,
        # this mutates them — confirm they return fresh vectors.
        weights = df.updated_at
        year_units = days_in_timestamp_units(365)  # loop invariant
        @showprogress for i = 1:length(weights)
            weights[i] = λ_wt^((1 - df.updated_at[i]) / year_units)
        end
        # drop the columns no longer needed so their memory can be reclaimed
        df = @set df.update_order = []
        df = @set df.updated_at = []
        for (c, λ) in zip([:userid, :itemid], [λ_wu, λ_wa])
            w = get_counts(getfield(df, c))
            @showprogress for i = 1:length(weights)
                weights[i] *= powerdecay(w[i], λ)
            end
        end
    elseif split == "finetune"
        df = get_split("training", metric, medium, fields)
        _, df = training_test_split(df)
        weights = powerdecay(get_counts(df.userid), -1.0f0)
    elseif split == "test"
        df = get_split("test", metric, medium, fields)
        weights = powerdecay(get_counts(df.userid), -1.0f0)
    else
        throw(ArgumentError("unknown split $(repr(split))"))
    end
    df = @set df.metric = weights
    GC.gc()
    return sparse(df, medium)
end;

# Disk I/O

In [None]:
"""
    create_training_config(medium, metric)

Return a fresh `Dict{String,Any}` with the model, training and data settings
for `medium`/`metric`. This is the config that save_splits writes to `config.json`.
"""
function create_training_config(medium, metric)
    model_settings = (
        "input_sizes" => num_items.(ALL_MEDIUMS),
        "output_size_index" => findfirst(==(medium), ALL_MEDIUMS),
        "metric" => metric,
    )
    training_settings = (
        "user_weight_decay" => 0.0f0,
        "item_weight_decay" => 0.0f0,
        "temporal_weight_decay" => 0.5f0,
        "mask_rate" => 0.25,
    )
    data_settings = ("num_shards" => 8,)
    return Dict{String,Any}(model_settings..., training_settings..., data_settings...)
end;

In [None]:
"""
    setup_split(config, outdir)

Ensure the directory `outdir` exists, then delete every regular file directly
inside it. Subdirectories are left untouched, and `config` is currently unused.
"""
function setup_split(config, outdir)
    mkpath(outdir)
    stale = filter(isfile, readdir(outdir; join = true))
    foreach(rm, stale)
end;

In [None]:
"""
    save_features(X, Y, W, epoch_size, users, valid_users, filename)

Write the HDF5 file `filename`, overwriting any existing file. It contains:
- inputs `X`, labels `Y` and weights `W`, each stored via `record_sparse_array!`;
- `epoch_size`, `users` and `valid_users`.

Every dataset is written with the `blosc = 1` option.
"""
function save_features(X, Y, W, epoch_size, users, valid_users, filename)
    payload = Dict{String,Any}()
    for (label, arr) in zip(["inputs", "labels", "weights"], [X, Y, W])
        record_sparse_array!(payload, label, arr)
    end
    payload["epoch_size"] = epoch_size
    payload["users"] = users
    payload["valid_users"] = valid_users
    HDF5.h5open(filename, "w") do io
        for (key, value) in payload
            io[key, blosc = 1] = value
        end
    end
end;

# Run

In [None]:
# Write one `data.<i>.h5` file per user chunk, slicing the user columns of X/Y/W.
function _save_chunks(outdir, chunks, X, Y, W, epoch_size, users, valid_users)
    @showprogress for i = 1:length(chunks)
        cols = chunks[i]
        save_features(
            X[:, cols],
            Y[:, cols],
            W[:, cols],
            epoch_size,
            users[cols],
            valid_users,
            "$outdir/data.$i.h5",
        )
    end
end

"""
    save_split(split, config)

Write the sharded HDF5 files for `split` into `alphas/<name>/<split>`, removing
any files already there first. Users are partitioned into
`config["num_shards"]` contiguous column chunks, and each chunk is written to
`data.<i>.h5` by `save_features`.

- "inference": inputs use all training data, and labels/weights are empty
  placeholders. `valid_users` are the users appearing in the "test" or
  "negative" raw splits of `medium`.
- "pretrain" / "finetune" / "test": inputs, labels and weights come from the
  `get_epoch_*` loaders. `valid_users` are users with positive total weight, and
  their count is stored in `config["epoch_size_<split>"]`.

Throws `ArgumentError` for any other `split`.
"""
function save_split(split, config)
    @info "loading $split data"
    outdir = get_data_path(joinpath("alphas", name, split))
    setup_split(config, outdir)
    users = collect(0:num_users()-1)
    chunks = collect(
        Iterators.partition(1:num_users(), div(num_users(), config["num_shards"], RoundUp)),
    )
    if split == "inference"
        X = get_epoch_inputs(false)
        GC.gc()
        Y = sparse(RatingsDataset(), medium) # unused
        W = sparse(RatingsDataset(), medium) # unused
        valid_users = Set{Int32}()
        for s in ["test", "negative"]
            # the raw split does not depend on the metric, so load it once per split
            df = get_raw_split(s, medium, [:userid], nothing)
            valid_users = union(valid_users, Set(df.userid))
        end
        valid_users = sort(collect(valid_users))
        epoch_size = length(valid_users)
    elseif split in ["pretrain", "finetune", "test"]
        X = get_epoch_inputs(split != "test")
        GC.gc()
        Y = get_epoch_labels(split, metric, medium)
        W = get_epoch_weights(
            split,
            metric,
            medium,
            config["user_weight_decay"],
            config["item_weight_decay"],
            config["temporal_weight_decay"],
        )
        valid_users = users[vec(sum(W, dims = 1) .> 0)]
        epoch_size = length(valid_users)
        config["epoch_size_$(split)"] = epoch_size
    else
        throw(ArgumentError("unknown split $(repr(split))"))
    end
    _save_chunks(outdir, chunks, X, Y, W, epoch_size, users, valid_users)
    @info "done $split data"
end

"""
    save_splits(mode)

Write the dataset splits for `mode`, then save the training config to
`alphas/<name>/config.json`.

In research mode:
- "training_dataset" writes "pretrain" using a fresh `create_training_config`.
- "test_dataset" reads the existing `config.json` and writes "finetune", "test"
  and "inference".

The config, including any `epoch_size_<split>` entries added by `save_split`, is
written back afterwards.

Throws `ArgumentError` for an unknown `mode`. Errors outside research mode, which
is not implemented yet.
"""
function save_splits(mode)
    config_fn = get_data_path(joinpath("alphas", name, "config.json"))
    if get_settings()["mode"] != "research"
        error("save_splits is only implemented for research mode") # TODO
    end
    if mode == "training_dataset"
        splits = ["pretrain"]
        config = create_training_config(medium, metric)
    elseif mode == "test_dataset"
        splits = ["finetune", "test", "inference"]
        config = JSON.parsefile(config_fn)
    else
        throw(ArgumentError("unknown mode $(repr(mode))"))
    end
    for split in splits
        GC.gc()
        save_split(split, config)
    end
    open(config_fn, "w") do f
        write(f, JSON.json(config))
    end
end;

In [None]:
# Dataset-building modes write the requested splits and stop the process here.
# The training and evaluation cells below therefore only run when mode == "train".
if !(mode == "train")
    save_splits(mode)
    exit()
end

In [None]:
# Train by running Pytorch.py once per stage, in order.
get_settings()["mode"] == "research" ||
    error("only research mode is supported") # TODO: other settings modes
modes = ["pretrain", "finetune", "inference"]
for mode in modes
    # Cmd interpolation passes each value as one argument; no shell is involved
    run(`python3 Pytorch.py --outdir $name --mode $mode`)
end

In [None]:
# Remove the per-split shard directories written by save_split.
foreach(["pretrain", "finetune", "test", "inference"]) do split
    rm(get_data_path(joinpath("alphas", name, split)), recursive = true)
end

# Save

In [None]:
# Load the model outputs (presumably written by the Pytorch.py "inference" run).
# The do-block closes the file even if a read throws.
predictions, users = HDF5.h5open(
    get_data_path(joinpath("alphas", name, "predictions.h5")),
    "r",
) do f
    (read(f["predictions"]), read(f["users"]))
end;

In [None]:
# Map each user id to its column in `predictions`. If an id repeats, its last
# position wins. Concrete key/value types keep the lookups in `model` type-stable.
user_to_index = Dict{eltype(users),Int}(u => i for (i, u) in enumerate(users));

In [None]:
# For watch-like metrics, zero out the items each test-split user already has in
# the training split. Then renormalize every user's prediction column to sum to 1.
if metric in ["watch", "plantowatch"]
    df = get_raw_split("training", medium, [:userid, :itemid], nothing)
    # distinct name: do not clobber the `users` loaded from predictions.h5
    test_users = Set(get_raw_split("test", medium, [:userid], nothing).userid)
    df = filter(df, df.userid .∈ (test_users,))
    df = @set df.metric = ones(Float32, length(df.userid))
    seen = sparse(df, medium)
    for (u, index) in user_to_index
        # user ids appear to be 0-based, matrix columns are 1-based
        predictions[seen[:, u+1].nzind, index] .= 0
        # NOTE(review): if all of a user's mass was on seen items, this divides by 0
        predictions[:, index] ./= sum(predictions[:, index])
    end
else
    seen = nothing
end

In [None]:
"""
    model(users, items, predictions, user_to_index)

Return a `Float32` vector with one score per (user, item) pair. Entry `k` is
`predictions[items[k] + 1, user_to_index[users[k]]]`. Asserts that every user is
a key of `user_to_index`.
"""
function model(users, items, predictions, user_to_index)
    scores = zeros(Float32, length(users))
    @showprogress for k = 1:length(scores)
        uid = users[k]
        @assert uid in keys(user_to_index)
        col = user_to_index[uid]
        row = items[k] + 1  # item ids are offset by one relative to prediction rows
        scores[k] = predictions[row, col]
    end
    return scores
end;

In [None]:
# Export the looked-up predictions as this alpha for the "test" and "negative" splits.
let predict = (us, its) -> model(us, its, predictions, user_to_index)
    write_alpha(predict, medium, name, ["test", "negative"])
end

In [None]:
# Report the loss on the test split. For ratings, the labels were built with the
# Baseline rating alpha subtracted, so that alpha is passed alongside this one.
for split in ["test"]
    alphas = metric == "rating" ? ["$medium/Baseline/rating", name] : [name]
    val = compute_loss(metric, medium, alphas, split)
    @info "$split loss = $val"
end