# Write the splits in a Julia optimized format
* In addition, a negative split is sampled. This set consists of (user, item) pairs that the user did not watch.
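* Intended design: 10% of the negative split is sampled uniformly at random and the other 90% is weighted by item popularity. (Note: the sampling code in this notebook currently draws every item uniformly at random; the popularity weighting is not implemented here.)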

In [None]:
using JLD2

import CSV
import DataFrames: DataFrame
import JupyterFormatter: enable_autoformat
import ProgressMeter: @showprogress
import Random
import StatsBase: sample, wsample

In [None]:
# Calls JupyterFormatter's enable_autoformat (presumably auto-formats notebook cells).
enable_autoformat();

In [None]:
# Seed the global RNG with a fixed value before any sampling is done.
Random.seed!(20230406);

In [None]:
# Medium name. It is interpolated into input/output file names and into the
# "$(media)id" column name used by the functions below.
# NOTE(review): empty here -- presumably supplied when the notebook is executed; confirm.
media = ""

In [None]:
# Task names; each one is looped over when writing the splits and negative splits.
ALL_TASKS = ["temporal", "temporal_causal"]

# Save Splits

In [None]:
# Column-oriented container for one split of (user, item) interaction records.
# get_dataset fills every field from the CSV columns noted below;
# get_negative_dataset fills only `user` and `item` and leaves the other vectors empty.
struct RatingsDataset
    user::Vector{Int32} # from the `username` column, shifted by +1 (1-indexed)
    item::Vector{Int32} # from the "$(media)id" column, shifted by +1 (1-indexed)
    rating::Vector{Float32} # from the `score` column
    timestamp::Vector{Float32} # from the `timestamp` column
    status::Vector{Int32} # from the `status` column
    completion::Vector{Float32} # from the `completion` column
    source::Vector{Int32} # from the `source` column
    medium::String # value of the global `media` when the dataset was built
end;

In [None]:
# Shift zero-based ids to Julia's one-based convention by adding 1 elementwise.
function to_julia_index(x)
    return broadcast(+, x, 1)
end;

In [None]:
# Read the split CSV at `file` and pack its columns into a RatingsDataset.
#
# The `username`, "$(media)id", `status` and `source` columns are parsed as Int32;
# every other column is parsed as Float32. User and item ids are shifted by +1
# so they can be used as Julia (1-based) indices.
function get_dataset(file)
    item_column = Symbol("$(media)id")
    integer_columns = Set([:username, item_column, :status, :source])
    column_type(_, name) = name in integer_columns ? Int32 : Float32
    table = DataFrame(CSV.File(file, types = column_type))

    users = to_julia_index(table.username)
    items = to_julia_index(table[:, item_column])
    return RatingsDataset(
        users,
        items,
        table.score,
        table.timestamp,
        table.status,
        table.completion,
        table.source,
        media,
    )
end;

In [None]:
# Prefix `x` with the relative location of the data directory.
get_data_path(x) = string("../../data/", x)

# Return the number of users: the `max_userid` value stored on the first line of
# the "$(media)_uid_encoding.csv" file, plus one.
function num_users()
    path = get_data_path("processed_data/$(media)_uid_encoding.csv")
    contents = read(path, String)
    first_line = split(contents, '\n')[1]
    key_value = split(first_line, ',')
    @assert key_value[1] == "max_userid"
    return parse(Int, key_value[2]) + 1
end

# Return the number of items: the `max_itemid` value stored on the second line of
# the "$(media)_uid_encoding.csv" file, plus one.
function num_items()
    path = get_data_path("processed_data/$(media)_uid_encoding.csv")
    contents = read(path, String)
    second_line = split(contents, '\n')[2]
    key_value = split(second_line, ',')
    @assert key_value[1] == "max_itemid"
    return parse(Int, key_value[2]) + 1
end;

In [None]:
# Convert every (content, task, split) CSV into a JLD2 file holding a RatingsDataset
# under the key "dataset". Both the load and the save are timed.
for content in ["explicit", "implicit", "ptw"],
    task in ALL_TASKS,
    part in ["training", "validation", "test"]

    prefix = "../../data/splits/$content.$task.$part"
    csv_path = "$prefix.user_$(media)_list.csv"
    jld_path = "$prefix.$media.jld2"
    @time dataset = get_dataset(csv_path)
    @time jldsave(jld_path; dataset)
end;

# Save Negative Splits

In [None]:
# Collect every (user, item) pair found in the training split and in `split`, across
# the explicit, implicit and ptw datasets of `task`.
#
# Ids are shifted back by -1 (0-indexed) before being stored in the returned set.
function get_user_item_pairs(split, task)
    seen = Set{Tuple{Int32,Int32}}()
    for part in ["training", split], content in ["explicit", "implicit", "ptw"]
        path = "../../data/splits/$content.$task.$part.$media.jld2"
        data = JLD2.load(path, "dataset")
        @showprogress for row in eachindex(data.user)
            push!(seen, (data.user[row] - 1, data.item[row] - 1))
        end
    end
    return seen
end;

In [None]:
# Return the set of user ids (1-indexed, as stored in the datasets) that appear in
# `split` of `task` in any of the explicit, implicit or ptw datasets.
#
# The set is typed as Int32 to match RatingsDataset.user and is grown in place,
# instead of building a fresh untyped set for every dataset.
function get_users(split, task)
    users = Set{Int32}()
    for content in ("explicit", "implicit", "ptw")
        file = "../../data/splits/$content.$task.$split.$media.jld2"
        dataset = JLD2.load(file, "dataset")
        union!(users, dataset.user)
    end
    return users
end;

In [None]:
# Sample negative (user, item) pairs for `split`/`task` and write them to `file` as a
# CSV with the header "username,$(media)id". Ids are written 0-indexed.
#
# Arguments:
# - split, task: forwarded to get_user_item_pairs and get_users to select the data.
# - file: output path; it is created or truncated.
# - samples_per_user: the number of rows written is
#   samples_per_user * (number of users returned by get_users).
#
# Users are visited round-robin and items are drawn uniformly from 0:M-1. A rejected
# draw also advances to the next user, so the per-user count is only approximately
# samples_per_user. Each accepted pair is added to the seen set, so no pair is
# written twice.
# NOTE(review): the rejection loop never terminates if every (user, item) combination
# for the valid users is already in the seen set.
function save_negative_dataset_csv(split, task, file, samples_per_user)
    # Load all inputs before opening `file`, so a failed load cannot leave a
    # truncated output file behind.
    user_item_pairs = get_user_item_pairs(split, task)
    valid_users = collect(get_users(split, task)) .- 1 # back to 0-indexed ids
    num_valid = length(valid_users)
    M = num_items()
    user_idx = 1

    io = open(file, "w")
    try
        write(io, "username,$(media)id\n")
        @showprogress for _ in 1:(samples_per_user * num_valid)
            local user, item
            # Draw candidates until one is not already known.
            while true
                user = valid_users[user_idx]
                user_idx = (user_idx % num_valid) + 1
                item = sample(0:M-1)
                (user, item) in user_item_pairs || break
            end
            push!(user_item_pairs, (user, item))
            write(io, "$(user),$(item)\n")
        end
    finally
        # Close the handle even if sampling or writing throws.
        close(io)
    end
    return nothing
end;

In [None]:
# Load a negative-sample CSV (columns `username` and "$(media)id") into a
# RatingsDataset. Only `user` and `item` are populated (shifted by +1 to be
# 1-indexed); all other vector fields are left empty.
function get_negative_dataset(file)
    table = DataFrame(CSV.File(file))
    item_column = Symbol("$(media)id")
    users = to_julia_index(table.username)
    items = to_julia_index(table[:, item_column])
    return RatingsDataset(
        users,
        items,
        Float32[],
        Float32[],
        Int32[],
        Float32[],
        Int32[],
        media,
    )
end;

In [None]:
# For each evaluation split and task, sample the negative pairs into a CSV, reload
# them as a RatingsDataset and store it in a JLD2 file under the key "dataset".
# Validation uses 100 samples per user, test uses 1000.
for (part, samples_per_user) in [("validation", 100), ("test", 1000)]
    for task in ALL_TASKS
        prefix = "../../data/splits/negative.$task.$part.$media"
        csv_path = "$prefix.csv"
        save_negative_dataset_csv(part, task, csv_path, samples_per_user)
        dataset = get_negative_dataset(csv_path)
        jldsave("$prefix.jld2"; dataset)
    end
end;