In [None]:
import JupyterFormatter
JupyterFormatter.enable_autoformat();

In [None]:
import NBInclude: @nbinclude
@nbinclude("Data.ipynb");

In [None]:
# Passed to `filter_count!` as `cutoff_items`: per user and per medium, at most this
# many of the last rows (in (userid, updated_at, update_order) order) are kept.
const CUTOFF_ITEMS = 5
# Passed to `filter_recent!` as `cutoff_days`: the recency window, in days, measured
# back from the max valid test timestamp.
const CUTOFF_DAYS = 14

In [None]:
"""
    get_update_map(fn)

Return a `Dict{Int32,Float64}` mapping each user id to the largest `updated_at` value
seen for that user across every shard returned by `load_datasets(fn)`.

Values are floored at `0` (a user whose timestamps are all negative maps to `0.0`).
Shards are scanned in parallel into per-shard dicts, which are then combined with `max`,
so a user who appears in several shards keeps their overall latest timestamp.
An empty shard list yields an empty dict.
"""
function get_update_map(fn)
    dfs = load_datasets(fn)
    # One dict per shard: each threaded iteration writes only to its own dict.
    update_maps = [Dict{Int32,Float64}() for _ in eachindex(dfs)]
    @showprogress Threads.@threads for i in eachindex(dfs)
        df = dfs[i]
        update_map = update_maps[i]
        for (u, t) in zip(df.userid, df.updated_at)
            update_map[u] = max(get(update_map, u, 0.0), t)
        end
    end
    # `merge` would keep only the last shard's value for a shared user; take the max.
    merged = Dict{Int32,Float64}()
    for update_map in update_maps
        mergewith!(max, merged, update_map)
    end
    return merged
end;

In [None]:
"""
    get_valid_users()

Return a `Set{Int32}` of the user ids present in the streaming data whose latest
`updated_at` in the test data is strictly greater than their latest `updated_at` in the
streaming data. A user absent from the test data is compared using a default of `0`.
"""
function get_valid_users()
    latest_test = get_update_map("test_data")
    latest_stream = get_update_map("streaming_data")
    keep = Set{Int32}()
    @showprogress for uid in keys(latest_stream)
        has_newer_test = get(latest_test, uid, 0) > latest_stream[uid]
        has_newer_test && push!(keep, uid)
    end
    return keep
end;

In [None]:
"""
    filter_users(fn, valid_userids)

Load the shards for `fn` via `load_datasets` and, in parallel, replace each shard with
the subset of rows whose `userid` is contained in `valid_userids`. Returns the vector
of filtered shards.
"""
function filter_users(fn, valid_userids)
    shards = load_datasets(fn)
    @showprogress Threads.@threads for k in eachindex(shards)
        shard = shards[k]
        keep = in.(shard.userid, Ref(valid_userids))
        shards[k] = subset(shard, keep)
    end
    return shards
end;

In [None]:
"""
    filter_time!(dfs)

In place, keep only the rows of each shard in `dfs` whose `updated_at` is strictly
greater than that user's latest `updated_at` in the streaming data (as computed by
`get_update_map("streaming_data")`, which reloads that data).

Throws a `KeyError` if a shard contains a user that is absent from the streaming data.
"""
function filter_time!(dfs)
    latest_stream = get_update_map("streaming_data")
    @showprogress Threads.@threads for k in eachindex(dfs)
        shard = dfs[k]
        newer = map(zip(shard.updated_at, shard.userid)) do (ts, uid)
            ts > latest_stream[uid]
        end
        dfs[k] = subset(shard, newer)
    end
    return nothing
end;

In [None]:
"""
    filter_duplicates!(train_dfs, test_dfs)

In place, drop every row of `test_dfs` whose `(userid, itemid)` pair already appears in
`train_dfs` with a status other than `STATUS_MAP["planned"]`.

Test users that have no qualifying train rows keep all their test rows. This also
covers users entirely absent from `train_dfs`, which previously raised a `KeyError`.
"""
function filter_duplicates!(train_dfs, test_dfs)
    # Nothing was seen in train, so nothing in test can be a duplicate.
    isempty(train_dfs) && return nothing
    planned = STATUS_MAP["planned"]
    # Concrete container types instead of `Dict()`/`Set()` (which are `Any`-typed).
    U = mapreduce(df -> eltype(df.userid), promote_type, train_dfs)
    I = mapreduce(df -> eltype(df.itemid), promote_type, train_dfs)
    seen_items = Dict{U,Set{I}}()
    for df in train_dfs
        for (s, u, a) in zip(df.status, df.userid, df.itemid)
            if s != planned
                push!(get!(Set{I}, seen_items, u), a)
            end
        end
    end
    # Shared read-only fallback; `seen_items` is only read (never written) below,
    # so concurrent lookups from the threaded loop are safe.
    no_items = Set{I}()
    @showprogress Threads.@threads for i in eachindex(test_dfs)
        df = test_dfs[i]
        mask = [a ∉ get(seen_items, u, no_items) for (u, a) in zip(df.userid, df.itemid)]
        test_dfs[i] = subset(df, mask)
    end
    return nothing
end;

In [None]:
"""
    filter_count!(dfs, cutoff_items)

In place, sort each shard in `dfs` by `(userid, updated_at, update_order)` and keep,
for every user and every medium, only the last `cutoff_items` rows in that order.

NOTE(review): `medium` is used as `medium + 1` to index a counter vector of length
`length(MEDIUM_MAP)`, so it presumably holds 0-based medium codes — confirm.
"""
function filter_count!(dfs, cutoff_items)
    @showprogress Threads.@threads for t in eachindex(dfs)
        raw = dfs[t]
        sortkeys = collect(zip(raw.userid, raw.updated_at, raw.update_order))
        shard = subset(raw, sortperm(sortkeys))
        n = length(shard.userid)
        kept = fill(false, n)
        per_medium = zeros(Int, length(MEDIUM_MAP))
        current_user = nothing
        # Walk backwards so, within each user's run, the latest rows are counted first.
        for row in n:-1:1
            uid = shard.userid[row]
            if uid != current_user
                current_user = uid
                fill!(per_medium, 0)
            end
            slot = shard.medium[row] + 1
            per_medium[slot] < cutoff_items || continue
            per_medium[slot] += 1
            kept[row] = true
        end
        dfs[t] = subset(shard, kept)
    end
    return nothing
end;

In [None]:
"""
    filter_recent!(dfs, max_valid_ts, cutoff_days)

In place, keep in every shard of `dfs` only the users whose earliest `updated_at`
across *all* shards is `>=` `max_valid_ts` minus a `cutoff_days`-day window.

NOTE(review): the window is converted as `seconds / (MAX_TS - MIN_TS)`, so `updated_at`
and `max_valid_ts` are presumably timestamps normalized by that span — confirm.
"""
function filter_recent!(dfs, max_valid_ts, cutoff_days)
    # Per-shard earliest update per user; each threaded iteration owns its dict.
    first_maps = [Dict{Int32,Float64}() for _ in eachindex(dfs)]
    @showprogress Threads.@threads for i in eachindex(dfs)
        df = dfs[i]
        first_map = first_maps[i]
        for (u, t) in zip(df.userid, df.updated_at)
            first_map[u] = min(get(first_map, u, Inf), t)
        end
    end
    # Combine with `min` so a user split across shards is judged by their overall
    # earliest update, not by whichever single shard happens to look recent.
    earliest = Dict{Int32,Float64}()
    for first_map in first_maps
        mergewith!(min, earliest, first_map)
    end
    ts_cutoff = max_valid_ts - (24 * 60 * 60 * cutoff_days) / (MAX_TS - MIN_TS)
    valid_userids = Set{Int32}(u for (u, t) in earliest if t >= ts_cutoff)
    @showprogress Threads.@threads for i in eachindex(dfs)
        dfs[i] = subset(dfs[i], dfs[i].userid .∈ (valid_userids,))
    end
    return nothing
end;

In [None]:
"""
    filter_sparse!(train_dfs, test_dfs)

In place, keep only the rows of each shard in `train_dfs` whose `userid` occurs in at
least one shard of `test_dfs`. If `test_dfs` is empty, every train row is dropped.
"""
function filter_sparse!(train_dfs, test_dfs)
    # Accumulate ids in place rather than splatting one temporary Set per shard into
    # `union` (which also has no zero-argument method, so empty `test_dfs` errored).
    T = mapreduce(df -> eltype(df.userid), promote_type, test_dfs; init = Int32)
    valid_userids = Set{T}()
    for df in test_dfs
        union!(valid_userids, df.userid)
    end
    @showprogress Threads.@threads for i in eachindex(train_dfs)
        df = train_dfs[i]
        train_dfs[i] = subset(df, df.userid .∈ (valid_userids,))
    end
    return nothing
end;

In [None]:
# Select users whose latest test-data update is later than their latest streaming-data
# update, then load the streaming (train) and test data restricted to those users.
valid_users = get_valid_users();
train_dfs = filter_users("streaming_data", valid_users);
test_dfs = filter_users("test_data", valid_users);

In [None]:
# Drop test rows that are not strictly newer than the user's latest streaming update.
filter_time!(test_dfs);

In [None]:
# NOTE(review): `relabel_userids!` comes from Data.ipynb and presumably remaps user ids
# in place. Train and test are relabeled in separate calls, and later cells join them
# on `userid` — verify the mapping is identical across both calls.
relabel_userids!(train_dfs)
relabel_userids!(test_dfs);

In [None]:
# Remove test rows for items the user already has in train with a non-"planned" status.
filter_duplicates!(train_dfs, test_dfs)

In [None]:
# Per user and medium, keep at most CUTOFF_ITEMS of the last test rows.
filter_count!(test_dfs, CUTOFF_ITEMS)

In [None]:
# Keep only test users whose earliest test update falls within CUTOFF_DAYS of the
# max valid test timestamp (`get_max_valid_ts` is presumably defined in Data.ipynb).
filter_recent!(test_dfs, get_max_valid_ts("test"), CUTOFF_DAYS)

In [None]:
# Restrict training data to users that survived all test-side filters.
filter_sparse!(train_dfs, test_dfs)

In [None]:
# Persist the train/test splits under the name "causal"
# (`save_splits` is presumably defined in Data.ipynb).
save_splits(train_dfs, test_dfs, "causal");