In [None]:
import NBInclude: @nbinclude
@nbinclude("./Data.ipynb");

In [None]:
import JSON

In [None]:
# Run configuration consumed by the `save_dataset` call at the end of this file.
# All four are empty-string placeholders here.
# NOTE(review): presumably overwritten by the notebook runner before execution — confirm.
version = ""
dataset = ""
medium = ""
metric = "";

In [None]:
"""
    sparse(x::RatingsDataset, c::Symbol, dataset::String, medium::String)

Assemble field `c` of `x` into a sparse matrix with one row per item id and one
column per user id. The shape is `num_items(medium)` by `num_users(dataset)`.
Entries that share an (item, user) pair are merged by the default combine rule
of `SparseArrays.sparse`.
"""
function sparse(x::RatingsDataset, c::Symbol, dataset::String, medium::String)
    rows = x.itemid
    cols = x.userid
    vals = getfield(x, c)
    n_rows = num_items(medium)
    n_cols = num_users(dataset)
    return SparseArrays.sparse(rows, cols, vals, n_rows, n_cols)
end

"""
    get_dataset(dataset, split, medium, metric, epoch)

Load one split through `get_split` and convert it with `as_metric`.

- `dataset` must be "training" or "streaming" and `split` must be "train" or
  "test_input"; both are checked with `@assert`.
- For the streaming train split the rows are divided by `input_output_split`
  and `epoch` ("input" or "output") selects which part is returned; any other
  value raises a `KeyError`. In every other case `epoch` is ignored.
- When `metric == "rating"`, the returned `metric` column is adjusted in place
  by subtracting `(user bias + item bias) * β`, with the parameters read from
  the baseline path. User biases are derived from the input part (streaming
  train) or from the whole split (otherwise).
"""
function get_dataset(
    dataset::String,
    split::String,
    medium::String,
    metric::String,
    epoch::String,
)
    @info "loading dataset $dataset $split $medium $metric"
    @assert dataset in ["training", "streaming"] dataset
    @assert split in ["train", "test_input"]
    columns = [:userid, :itemid, :rating, :status, :medium, :updated_at, :update_order]
    df = get_split(dataset, split, medium, columns)

    is_streaming_train = dataset == "streaming" && split == "train"
    if !is_streaming_train
        df = as_metric(df, metric)
        resid = df
    else
        # Cutoff is the dataset's :max_ts minus seven days, in get_timestamp units.
        cutoff = get_timestamp(dataset, :max_ts) - get_timestamp(Dates.Day(7))
        # NOTE(review): the meaning of `5` and `true` is defined by
        # input_output_split, which is not visible here — confirm.
        input, output = input_output_split(df, cutoff, 5, true)
        input = as_metric(input, metric)
        output = as_metric(output, metric)
        by_epoch = Dict("input" => input, "output" => output)
        df = by_epoch[epoch]
        # Biases are always estimated from the input part, even for "output".
        resid = input
    end

    if metric == "rating"
        baseline_dataset = if dataset == "training"
            "training"
        else
            "streaming"
        end
        params = read_params("baseline/v1/$baseline_dataset/$medium/rating")
        user_biases = get_user_biases(resid, params)
        item_biases = params["a"]
        β = params["β"]
        Threads.@threads for k in eachindex(df.metric)
            # Users without an entry in user_biases contribute a bias of 0.
            user_bias = get(user_biases, df.userid[k], 0)
            item_bias = item_biases[df.itemid[k]]
            df.metric[k] -= (user_bias + item_bias) * β
        end
    end
    return df
end;

In [None]:
"""
    get_epoch_inputs(dataset::String, version::String)

Return the sparse input matrix for `dataset`, building and caching it on first
use.

The cache lives at `alphas/bagofwords/<version>/<dataset>/inputs.h5` under
`get_data_path`. When the file is missing, one sparse block is built per
(metric, medium) pair from the "input" epoch of the train split and the blocks
are stacked vertically: all mediums for "rating" first, then all mediums for
"watch". The matrix is then always read back from the HDF5 file.

The cache is written to a temporary file in the same directory and renamed
into place, so a failed or interrupted write never leaves a partial
`inputs.h5` that a later call would mistake for a valid cache.
"""
function get_epoch_inputs(dataset::String, version::String)
    @info "loading inputs $dataset"
    fn = get_data_path("alphas/bagofwords/$version/$dataset/inputs.h5")
    if !isfile(fn)
        # save inputs to disk
        mkpath(dirname(fn))
        X = vcat(
            [
                sparse(
                    get_dataset(dataset, "train", medium, metric, "input"),
                    :metric,
                    dataset,
                    medium,
                ) for metric in ["rating", "watch"] for medium in ALL_MEDIUMS
            ]...,
        )
        d = Dict{String,Any}()
        record_sparse_array!(d, "inputs", X)
        # Per-process temp name next to the target so the rename stays on one
        # filesystem; the final file only appears once it is complete.
        tmp = "$fn.$(getpid()).tmp"
        try
            HDF5.h5open(tmp, "w") do file
                for (k, v) in d
                    file[k] = v
                end
            end
            mv(tmp, fn; force = true)
        catch
            # Do not leave a half-written temp file behind on failure.
            rm(tmp; force = true)
            rethrow()
        end
    end
    return HDF5.h5open(fn, "r") do f
        g(x) = read(f[x])
        # Rebuild from the stored coordinate triplets and the stored size.
        SparseArrays.sparse(
            g("inputs_i"),
            g("inputs_j"),
            g("inputs_v"),
            g("inputs_size")...,
        )
    end
end;

In [None]:
"""
    get_counts(df, col)

For every entry of field `col` in `df`, return how many times that entry's
value occurs in the field. The result is an `Int32` array aligned with the
field.
"""
function get_counts(df, col)
    column = getfield(df, col)
    occurrences = StatsBase.countmap(column)
    return Int32[occurrences[value] for value in column]
end

"""
    get_weights(df, λ_wu, λ_wa, λ_wt)

Compute one weight per row of `df`:

    (user count)^λ_wu * (item count)^λ_wa * λ_wt^(1 - updated_at)

The counts are the per-row occurrence counts of the row's `userid` and
`itemid` from `get_counts`. The result vector has element type
`typeof(λ_wt)`.
"""
function get_weights(df, λ_wu, λ_wa, λ_wt)
    user_counts = get_counts(df, :userid)
    item_counts = get_counts(df, :itemid)
    weights = Vector{typeof(λ_wt)}(undef, length(user_counts))
    Threads.@threads for k in eachindex(weights)
        popularity = (user_counts[k]^λ_wu) * (item_counts[k]^λ_wa)
        # NOTE(review): presumably a recency decay over a normalized
        # timestamp — confirm the range of updated_at.
        recency = λ_wt^(1 - df.updated_at[k])
        weights[k] = popularity * recency
    end
    return weights
end

"""
    get_epoch_outputs(dataset, medium, metric, λ)

Load the "output" epoch of the train split and return `(Y, W)`: the sparse
target matrix built from the `metric` field and a sparse weight matrix of the
same layout.

For the "training" dataset the weights come from `get_weights`, using the
first three entries of `λ` with the third passed through `sigmoid`. For any
other dataset each row is weighted by one over the number of rows of its user
and `λ` is unused. The weights overwrite `df.updated_at` in place before the
weight matrix is built from that field.
"""
function get_epoch_outputs(
    dataset::String,
    medium::String,
    metric::String,
    λ::Vector{Float32},
)
    @info "loading outputs $dataset $medium $metric"
    df = get_dataset(dataset, "train", medium, metric, "output")
    row_weights = if dataset == "training"
        λ_wu, λ_wa, λ_wt = λ
        get_weights(df, λ_wu, λ_wa, sigmoid(λ_wt))
    else
        Float32[1.0 / n for n in get_counts(df, :userid)]
    end
    # updated_at is reused as storage for the per-row weights.
    df.updated_at .= row_weights
    Y = sparse(df, :metric, dataset, medium)
    W = sparse(df, :updated_at, dataset, medium)
    return (Y, W)
end;

In [None]:
"""
    save_dataset(dataset, medium, metric, version)

Build the features for one (dataset, medium, metric) combination and write
them under `alphas/bagofwords/<version>/<dataset>/<medium>/<metric>`:

- `config.json` with the per-medium input sizes, the index of `medium` within
  `ALL_MEDIUMS`, and the metric name;
- `train.h5` and `test.h5`, written by `save_features`.

Only users whose weight column has a nonzero sum are kept. They are shuffled
and split into train and test, with a test fraction of 0.01 for "training"
and 0.1 for "streaming"; any other `dataset` raises a `KeyError`.
"""
function save_dataset(dataset, medium, metric, version)
    # NOTE(review): presumably seeds the RNG used by Random.shuffle below — confirm.
    seed_rng!("Train/Baseline/Train/$dataset/$medium/$metric")
    name = "bagofwords/$version/$dataset/$medium/$metric"
    outdir = get_data_path(joinpath("alphas", name))
    isdir(outdir) || mkpath(outdir)

    input_sizes = num_items.(ALL_MEDIUMS)
    output_index = findfirst(==(medium), ALL_MEDIUMS)
    config = Dict(
        "input_sizes" => input_sizes,
        "output_index" => output_index,
        "metric" => metric,
    )
    open("$outdir/config.json", "w") do io
        write(io, JSON.json(config))
    end

    X = get_epoch_inputs(dataset, version)
    # The decay base satisfies decay^horizon == 0.5; it is passed as a logit
    # because get_epoch_outputs applies `sigmoid` to the third entry of λ.
    horizon = get_timestamp(Dates.Day(365))
    decay = exp(log(0.5) / horizon)
    λ_wt = log(decay / (1 - decay))
    Y, W = get_epoch_outputs(dataset, medium, metric, Float32[0, 0, λ_wt])

    has_outputs(u) = sum(W[:, u]) != 0
    userids = Random.shuffle(filter(has_outputs, 1:size(X, 2)))
    test_frac = Dict("training" => 0.01, "streaming" => 0.1)[dataset]
    n_train = round(Int, length(userids) * (1 - test_frac))
    train_userids = userids[1:n_train]
    test_userids = userids[n_train+1:end]

    # Slice the same user columns out of all three matrices and persist them.
    write_split(ids, path) = save_features(X[:, ids], Y[:, ids], W[:, ids], ids, path)
    write_split(train_userids, "$outdir/train.h5")
    return write_split(test_userids, "$outdir/test.h5")
end;

In [None]:
# Build and write the config and train/test feature files for the run
# configured by the `dataset`, `medium`, `metric` and `version` globals.
save_dataset(dataset, medium, metric, version);