In [None]:
# Notebook tooling: turn on JupyterFormatter's autoformat hook
# (presumably reformats cells as they are executed — see the package docs).
import JupyterFormatter
JupyterFormatter.enable_autoformat();

In [None]:
# Pull in the definitions from the sibling notebook Data.ipynb. Helpers used
# below that are not defined in this notebook (e.g. `load_datasets`, `subset`,
# `get_data_path`) presumably come from there — confirm when refactoring.
import NBInclude: @nbinclude
@nbinclude("Data.ipynb");

In [None]:
# NOTE: the names are spelled "OUPUT" (sic). They are referenced under exactly
# these names by later cells, so any rename must update those uses as well.

# Passed to `create_splits!` for the test outputs
# (presumably the per-user cap on output items — confirm against `create_splits!`).
const MAX_OUPUT_ITEMS = 5
# Length, in days, of the window used to derive `ts_cutoff` for `filter_recent!`.
const MAX_OUPUT_DAYS = 14

In [None]:
"""
    get_update_map(dataset::String)

Return the userid map of `dataset` extended with an `updated_at` column that holds,
for every user, the largest `updated_at` value found across all datasets returned by
`load_datasets(dataset)`.

The map is read from `processed_data/<dataset>.userid_map.csv`; its `userid` column is
parsed to `Int32`. Users without any rows get `updated_at == 0.0`, and timestamps are
floored at `0.0`.
"""
function get_update_map(dataset::String)
    dfs = load_datasets(dataset)
    update_maps = Vector{Dict{Int32,Float64}}(undef, length(dfs))
    @showprogress Threads.@threads for i in eachindex(dfs)
        # Each iteration fills its own dictionary, so no locking is required.
        update_map = Dict{Int32,Float64}()
        df = dfs[i]
        for (u, t) in zip(df.userid, df.updated_at)
            update_map[u] = max(get(update_map, u, 0.0), t)
        end
        update_maps[i] = update_map
    end
    # Combine with `max` so a user present in several datasets keeps the latest
    # timestamp. A plain `merge` would keep whichever dataset happened to come last,
    # and splatting into `merge` fails when there are no datasets at all.
    updates = Dict{Int32,Float64}()
    for update_map in update_maps
        mergewith!(max, updates, update_map)
    end
    userids = read_csv(get_data_path("processed_data/$dataset.userid_map.csv"))
    userids[!, :userid] = parse.(Int32, userids.userid)
    # A Float64 default keeps the column concretely typed (no Int/Float64 mix).
    userids[!, :updated_at] = [get(updates, i, 0.0) for i in userids.userid]
    return userids
end;

In [None]:
"""
    get_valid_users()

Build, persist and return the userid map of the "causal" dataset.

The "streaming" and "test" update maps are right-joined on `(source, username)`, so
every test user is kept and streaming columns may be `missing`. Joined columns carry
the suffix `_s` (streaming) or `_t` (test). Only users whose test `updated_at` is
strictly greater than their streaming `updated_at` (treated as 0 when missing) remain.
The survivors are numbered `1:n` in a new `userid` column and the table is written to
`processed_data/causal.userid_map.csv`.
"""
function get_valid_users()
    streaming_users = get_update_map("streaming")
    test_users = get_update_map("test")
    joined = DataFrames.rightjoin(
        streaming_users,
        test_users;
        on = [:source, :username],
        renamecols = "_s" => "_t",
    )
    # Row predicate: the test snapshot must be newer than the streaming snapshot.
    has_new_activity(row) = row.updated_at_t .> coalesce(row.updated_at_s, 0)
    valid = filter(has_new_activity, joined)
    valid[!, :userid] = 1:DataFrames.nrow(valid)
    CSV.write(get_data_path("processed_data/causal.userid_map.csv"), valid)
    return valid
end;

In [None]:
"""
    filter_users(dataset, users)

Load the datasets of `dataset` (`"streaming"` or `"test"`), rewrite their `userid`
values to the ids in `users.userid`, and drop every row whose user is not in `users`.

`users` must provide the columns `userid_s`, `userid_t` and `userid`; rows whose
source id is `missing` are skipped when building the translation table. Any other
`dataset` value raises a `KeyError`. Returns the vector of filtered datasets.
"""
function filter_users(dataset, users)
    source_col = Dict("streaming" => users.userid_s, "test" => users.userid_t)
    # Translation table: id in the source dataset => id in the causal dataset.
    userid_map = Dict{Int32,Int32}(
        k => v for (k, v) in zip(source_col[dataset], users.userid) if !ismissing(k)
    )
    dfs = load_datasets(dataset)
    @showprogress Threads.@threads for i in eachindex(dfs)
        df = dfs[i]
        # Unknown users are mapped to the sentinel 0 and removed right after.
        df.userid .= get.(Ref(userid_map), df.userid, 0)
        dfs[i] = subset(df, df.userid .!= 0)
    end
    return dfs
end;

In [None]:
"""
    filter_causal!(dfs, users)

Replace each dataset in `dfs` by the subset of rows whose `updated_at` is strictly
greater than the owning user's streaming timestamp (`users.updated_at_s`).

Users whose streaming timestamp is `missing`, or who are not listed in `users`, are
compared against 0.
"""
function filter_causal!(dfs, users)
    streaming_map = Dict{Int32,Float64}()
    for (uid, ts) in zip(users.userid, users.updated_at_s)
        ismissing(ts) && continue
        streaming_map[uid] = ts
    end
    # True when a row was updated after the user's streaming snapshot.
    is_after_snapshot(ts, uid) = ts > get(streaming_map, uid, 0)
    @showprogress Threads.@threads for idx in eachindex(dfs)
        data = dfs[idx]
        keep = [is_after_snapshot(ts, uid) for (ts, uid) in zip(data.updated_at, data.userid)]
        dfs[idx] = subset(data, keep)
    end
end;

In [None]:
"""
    filter_duplicates!(input_dfs, output_dfs)

Remove from every dataset in `output_dfs` the rows whose `(userid, itemid)` pair
already occurs in `input_dfs` with a status other than `STATUS_MAP["planned"]`.

`input_dfs` is only read; the entries of `output_dfs` are replaced in place.
"""
function filter_duplicates!(input_dfs, output_dfs)
    all_dfs = vcat(input_dfs, output_dfs)
    userids = union(map(df -> Set(df.userid), all_dfs)...)
    # One (initially empty) set per user, so lookups below never miss a key.
    seen_items = Dict(u => Set() for u in userids)
    @showprogress for df in input_dfs
        for (status, uid, item) in zip(df.status, df.userid, df.itemid)
            # Planned items do not count as already seen.
            status == STATUS_MAP["planned"] && continue
            push!(seen_items[uid], item)
        end
    end
    @showprogress Threads.@threads for idx in eachindex(output_dfs)
        data = output_dfs[idx]
        keep = [!(item in seen_items[uid]) for (uid, item) in zip(data.userid, data.itemid)]
        output_dfs[idx] = subset(data, keep)
    end
end;

In [None]:
"""
    filter_recent!(dfs, ts_cutoff)

Keep, in every dataset of `dfs`, only the rows of users whose earliest `updated_at`
within at least one dataset is greater than or equal to `ts_cutoff`. The entries of
`dfs` are replaced in place.
"""
function filter_recent!(dfs, ts_cutoff)
    # Earliest timestamp per user, computed independently for each dataset so the
    # threaded loop never shares a dictionary between iterations.
    update_maps = [Dict{Int32,Float64}() for _ in eachindex(dfs)]
    @showprogress Threads.@threads for i in eachindex(dfs)
        df = dfs[i]
        update_map = update_maps[i]
        for (u, t) in zip(df.userid, df.updated_at)
            update_map[u] = min(get(update_map, u, Inf), t)
        end
    end
    # NOTE(review): the per-dataset minima are not combined, so a user qualifies as
    # soon as ANY single dataset has its earliest row at or after the cutoff —
    # confirm this is the intended rule.
    valid_userids = Set{Int32}()
    for update_map in update_maps, (u, t) in update_map
        if t >= ts_cutoff
            push!(valid_userids, u)
        end
    end
    @showprogress Threads.@threads for i in eachindex(dfs)
        # The 1-tuple stops broadcasting from iterating over the set itself.
        dfs[i] = subset(dfs[i], dfs[i].userid .∈ (valid_userids,))
    end
end;

In [None]:
# Seed the random number generator through the project's `seed_rng!` helper, keyed
# by this notebook's name (presumably for reproducible runs — helper defined elsewhere).
seed_rng!("Preprocess/ImportLists/CausalData");

In [None]:
# Users whose "test" snapshot is newer than their "streaming" snapshot. This call
# also writes processed_data/causal.userid_map.csv as a side effect.
users = get_valid_users();

In [None]:
# The causal dataset has no training split: `train_dfs` stays empty.
train_dfs = Vector{RatingsDataset}()
# Inputs come from the "streaming" data, outputs from the "test" data; both are
# restricted to `users` and re-keyed to the causal user ids.
test_input_dfs = filter_users("streaming", users);
test_output_dfs = filter_users("test", users);

In [None]:
# Keep only output rows updated strictly after the user's streaming timestamp.
filter_causal!(test_output_dfs, users);

In [None]:
# Drop output rows for items the user already has in the inputs with a status
# other than "planned".
filter_duplicates!(test_input_dfs, test_output_dfs);

In [None]:
# Only the second return value of `create_splits!` is kept as the new outputs.
# NOTE(review): presumably this caps each user at MAX_OUPUT_ITEMS output items;
# the meaning of the `0` and `false` arguments is not visible here — confirm
# against the definition of `create_splits!`.
_, test_output_dfs = create_splits!(test_output_dfs, 0, MAX_OUPUT_ITEMS, false);

In [None]:
# Cutoff = latest valid "streaming" timestamp minus a MAX_OUPUT_DAYS-day window.
# The window is converted to seconds and divided by (MAX_TS - MIN_TS), i.e. it is
# expressed as a fraction of that span (presumably because timestamps are stored
# normalized to that range — confirm against Data.ipynb).
ts_cutoff =
    get_max_valid_ts("streaming") - (24 * 60 * 60 * MAX_OUPUT_DAYS) / (MAX_TS - MIN_TS)
# Keep only users whose earliest remaining output row is at or after the cutoff.
filter_recent!(test_output_dfs, ts_cutoff);

In [None]:
# `filter_input!` is defined outside this notebook; presumably it restricts the
# inputs to users that still have output rows — TODO confirm.
filter_input!(test_input_dfs, test_output_dfs);

In [None]:
# `relabel_userids!` is defined outside this notebook; presumably it renumbers the
# remaining users consistently across all three splits of "causal" — TODO confirm.
relabel_userids!(train_dfs, test_input_dfs, test_output_dfs, "causal");

In [None]:
# Hand the three splits to the project's `save_dataset` helper under the name "causal".
save_dataset(train_dfs, test_input_dfs, test_output_dfs, "causal");