# Pretraining
* Trains a bag-of-words model on user data

In [None]:
# Run parameters. They are empty placeholders here; their values are interpolated into
# the alpha `name` and passed to the data-loading helpers further down.
# NOTE(review): presumably overwritten by whatever executes this notebook — confirm.
task = ""
content = ""
medium = ""

In [None]:
# Identifier of this alpha; also used as the output sub-directory under "alphas".
name = join((medium, task, "BagOfWords", content, "v1"), "/");

In [None]:
import NBInclude: @nbinclude
@nbinclude("../Alpha.ipynb");

In [None]:
import HDF5
import JSON
import SparseArrays: AbstractSparseArray, findnz, sparse

# Data

In [None]:
"""
    explicit_inputs(task, medium, residual_alphas)

Return the sparse matrix of explicit training ratings for `medium`, with the ratings
read from `residual_alphas` subtracted element-wise.
"""
function explicit_inputs(task::String, medium::String, residual_alphas::Vector{String})
    observed =
        get_split("training", task, "explicit", medium; fields = [:user, :item, :rating])
    baseline = read_alpha(residual_alphas, "training", task, "explicit", medium, false)
    residuals = RatingsDataset(
        user = observed.user,
        item = observed.item,
        rating = observed.rating .- baseline.rating,
        medium = medium,
    )
    return sparse(residuals)
end;

"""
    implicit_inputs(task, medium)

Return the implicit training split for `medium` as a sparse matrix.
"""
function implicit_inputs(task::String, medium::String)
    return sparse(
        get_split("training", task, "implicit", medium; fields = [:user, :item, :rating]),
    )
end;

"""
    get_epoch_inputs(task, residual_alphas)

Vertically concatenate the explicit inputs of every medium in `ALL_MEDIUMS` followed by
the implicit inputs of every medium. `residual_alphas` must hold one alpha name per
medium, in `ALL_MEDIUMS` order.
"""
function get_epoch_inputs(task::String, residual_alphas::Vector{String})
    @assert length(residual_alphas) == length(ALL_MEDIUMS)
    # NOTE(review): the container is deliberately a Vector{Any}; a concretely typed
    # vector may select a different `reduce(vcat, …)` method — confirm before tightening.
    blocks = Any[]
    for (i, m) in enumerate(ALL_MEDIUMS)
        # `i:i` keeps the single alpha wrapped in a Vector{String}
        push!(blocks, explicit_inputs(task, m, residual_alphas[i:i]))
    end
    append!(blocks, Any[implicit_inputs(task, m) for m in ALL_MEDIUMS])
    return reduce(vcat, blocks)
end;

In [None]:
"""
    get_residualization_alphas(content)

Return the alpha names to subtract from the labels: the explicit user/item bias alpha
when `content == "explicit"`, otherwise an empty `String` vector. The alpha name is built
from the global `medium` and `task`.
"""
function get_residualization_alphas(content)
    return content == "explicit" ? ["$medium/$task/ExplicitUserItemBiases"] : String[]
end

"""
    get_epoch_labels(split, task, content, medium)

Return the sparse label matrix for the given split. For explicit content the
residualization alphas are read for the same split and subtracted from the ratings.
"""
function get_epoch_labels(split, task, content, medium)
    ratings = get_split(split, task, content, medium; fields = [:user, :item, :rating])
    labels = sparse(ratings)
    content == "explicit" || return labels
    # NOTE(review): get_residualization_alphas builds its alpha name from the global
    # `medium`/`task`, not from this function's arguments.
    baseline = read_alpha(
        get_residualization_alphas(content),
        split,
        task,
        content,
        medium,
        false,
    )
    return labels - sparse(baseline)
end;

In [None]:
"""
    get_epoch_weights(split, task, content, medium,
                      user_weight_decay, item_weight_decay, temporal_weight_decay)

Return a sparse matrix holding one weight per (user, item) entry of the split.

For the `"training"` split the weight is the product of a per-user count decay, a
per-item count decay and a decay on `1 - timestamp` (negative timestamps are clamped to
zero first). For every other split only the per-user count decay with the `"inverse"`
weighting scheme is applied and the three decay arguments are ignored.
"""
function get_epoch_weights(
    split::String,
    task::String,
    content::String,
    medium::String,
    user_weight_decay::Real,
    item_weight_decay::Real,
    temporal_weight_decay::Real,
)
    if split == "training"
        user_term = powerdecay(get_counts(split, task, content, medium), user_weight_decay)
        item_term = powerdecay(
            get_counts(split, task, content, medium; by_item = true),
            item_weight_decay,
        )
        timestamps = get_split(split, task, content, medium; fields = [:timestamp]).timestamp
        age = (1 .- max.(timestamps, 0.0f0)) ./ year_in_timestamp_units()
        time_term = powerlawdecay(age, temporal_weight_decay)
        weights = user_term .* item_term .* time_term
    else
        weights = powerdecay(
            get_counts(split, task, content, medium),
            weighting_scheme("inverse"),
        )
    end
    entries = get_split(split, task, content, medium; fields = [:user, :item])
    weighted = RatingsDataset(
        user = entries.user,
        item = entries.item,
        rating = weights,
        medium = medium,
    )
    return sparse(weighted)
end;

# Disk I/O

In [None]:
"""
    create_training_config(medium, content)

Build the configuration dictionary that `setup_training` serializes to JSON: model
dimensions derived from `ALL_MEDIUMS`, fixed training hyperparameters, and the number of
data shards.
"""
function create_training_config(medium, content)
    input_sizes = num_items.(ALL_MEDIUMS)
    # position of `medium` within ALL_MEDIUMS (`nothing` if it is not listed)
    output_index = findfirst(==(medium), ALL_MEDIUMS)
    return Dict(
        # model
        "input_sizes" => input_sizes,
        "output_size_index" => output_index,
        "content" => content,
        # training
        # NOTE(review): hard-coded constants, presumably tuned offline — confirm origin.
        "user_weight_decay" => -0.26133174,
        "item_weight_decay" => 0.2260387,
        "temporal_weight_decay" => 0.67891073,
        "mask_rate" => 0.25762135,
        # data
        "num_shards" => 8,
    )
end;

In [None]:
"""
    setup_training(config, outdir)

Prepare `outdir` for a fresh export: create it if needed, delete the regular files
directly inside it, and write `config` as JSON to `config.json` in the parent directory.
"""
function setup_training(config, outdir)
    isdir(outdir) || mkpath(outdir)
    for entry in readdir(outdir, join = true)
        isfile(entry) && rm(entry)
    end
    target = joinpath(outdir, "..", "config.json")
    # write to a sibling "~" file first, then move it over the final name
    scratch = target * "~"
    write(scratch, JSON.json(config))
    return mv(scratch, target, force = true)
end;

In [None]:
"""
    save_features(X, Y, W, users, epoch_size, filename)

Write one shard to the HDF5 file `filename`: the sparse matrices `X` ("inputs"),
`Y` ("labels") and `W` ("weights") in coordinate form, the `users` vector, `epoch_size`,
and `valid_users` — the column indices of `W` whose column sum is positive.
"""
function save_features(X, Y, W, users, epoch_size, filename)
    payload = Dict{String,Any}()
    data = [X, Y, W]
    for (label, matrix) in zip(["inputs", "labels", "weights"], data)
        record_sparse_array!(payload, label, matrix)
    end
    payload["users"] = users
    payload["epoch_size"] = epoch_size
    counts = sum(W, dims = 1)
    payload["valid_users"] = filter(u -> counts[u] > 0, 1:length(counts))
    HDF5.h5open(filename, "w") do file
        for (key, value) in payload
            write(file, key, value)
        end
    end
end

"""
    record_sparse_array!(d::Dict, name::String, x::AbstractSparseArray)

Store the coordinate representation of the sparse matrix `x` in `d` under four keys:
`name * "_i"` (row indices), `name * "_j"` (column indices), `name * "_v"` (stored
values) and `name * "_size"` (the two dimensions as a vector).
"""
function record_sparse_array!(d::Dict, name::String, x::AbstractSparseArray)
    # `findnz` is imported by name: `import SparseArrays: …` does not bind the module
    # itself, so the qualified form `SparseArrays.findnz` is not guaranteed to resolve.
    rows, cols, vals = findnz(x)
    d[name*"_i"] = rows
    d[name*"_j"] = cols
    d[name*"_v"] = vals
    d[name*"_size"] = [size(x)[1], size(x)[2]]
end;

# Run

In [None]:
# Training configuration for this medium/content pair; `save_split` below reads it as a
# global, both to serialize it and to look up the weight-decay settings and shard count.
config = create_training_config(medium, content);

In [None]:
"""
    save_split(split)

Build the input, label and weight matrices for `split` and write them as sharded HDF5
files `data.<i>.h5` under the alpha's output directory for that split.

Inputs always come from the training data. For `split == "inference"` the labels come
from the `"test"` split, the weights are the sum of the `"implicit"` and `"negative"`
test weights, and a single shard is written. For every other split, labels and weights
come from `split` itself and `config["num_shards"]` shards are written.

Reads the globals `name`, `task`, `content`, `medium` and `config`.
"""
function save_split(split)
    @info "loading $split data"
    outdir = get_data_path(joinpath("alphas", name, split))
    setup_training(config, outdir)
    X = get_epoch_inputs(task, ["$x/$task/ExplicitUserItemBiases" for x in ALL_MEDIUMS])
    decays = (
        config["user_weight_decay"],
        config["item_weight_decay"],
        config["temporal_weight_decay"],
    )
    if split == "inference"
        Y = get_epoch_labels("test", task, content, medium)
        W = 0
        # the loop variable is named differently from the global `content` used above
        for weight_content in ["implicit", "negative"]
            W = get_epoch_weights("test", task, weight_content, medium, decays...) .+ W
        end
        num_shards = 1
    else
        Y = get_epoch_labels(split, task, content, medium)
        W = get_epoch_weights(split, task, content, medium, decays...)
        num_shards = config["num_shards"]
    end
    shards =
        collect(Iterators.partition(1:num_users(), div(num_users(), num_shards, RoundUp)))
    # number of columns with any positive weight; identical for every shard, so it is
    # computed once instead of once per shard
    epoch_size = sum(sum(W, dims = 1) .> 0)
    for (i, shard) in enumerate(shards)
        save_features(
            X[:, shard],
            Y[:, shard],
            W[:, shard],
            collect(shard),
            epoch_size,
            "$outdir/data.$i.h5",
        )
    end
end;

In [None]:
# Export every split; `save_split` is called purely for its side effects.
foreach(save_split, ["training", "validation", "test", "inference"]);

In [None]:
# Trigger a garbage collection before the Python subprocesses are launched.
# NOTE(review): presumably to release the matrices built by `save_split` — confirm.
GC.gc()

In [None]:
# Run the Python script once per mode, in this order; `run` throws if a process fails.
foreach(["pretrain", "finetune", "inference"]) do mode
    run(`python3 Pytorch.py --outdir $name --mode $mode`)
end

In [None]:
# Delete the exported shard directories for these three splits.
foreach(["training", "validation", "test"]) do split
    rm(get_data_path(joinpath("alphas", name, split)), recursive = true)
end

# Save

In [None]:
# Load the prediction matrix and its user ids from predictions.h5.
# NOTE(review): presumably written by the Pytorch.py "inference" run — confirm.
# The do-block form closes the HDF5 file even if one of the reads throws.
predictions, users = HDF5.h5open(
    get_data_path(joinpath("alphas", name, "predictions.h5")),
    "r",
) do file
    (read(file["predictions"]), read(file["users"]))
end;

In [None]:
# Map each user id to its position in `users`, which `model` uses as the column of
# `predictions`. Building the Dict from pairs gives it concrete key/value types instead
# of Dict{Any,Any}; for a repeated id the last position wins, as with the manual loop.
user_to_index = Dict(u => i for (i, u) in enumerate(users));

In [None]:
# note that we only record predictions for test users
"""
    model(users, items, predictions, user_to_index)

Return a `Float32` vector with one rating per `(users[k], items[k])` pair, looked up as
`predictions[item, column]` where `column = user_to_index[user]`. A user missing from
`user_to_index` raises a `KeyError`.
"""
function model(users, items, predictions, user_to_index)
    n = length(users)
    ratings = zeros(Float32, n)
    @showprogress for k in eachindex(ratings)
        column = user_to_index[users[k]]
        ratings[k] = predictions[items[k], column]
    end
    return ratings
end;

In [None]:
# Write this alpha for the "test" split only; the do-block is passed to `write_alpha`
# as its first (function) argument.
write_alpha(
    medium,
    name;
    splits = ["test"],
    task = task,
    log = true,
    log_task = task,
    log_content = content,
    log_alphas = get_residualization_alphas(content),
    log_splits = ["test"],
) do users, items
    model(users, items, predictions, user_to_index)
end