# Write the splits in a Julia-optimized (JLD2) format
* In addition, a negative split is sampled. This set consists of (user, item) pairs that do not appear in any of the explicit, implicit, or ptw splits seen so far (i.e. items the user did not interact with).
* For each negative pair, the item is drawn uniformly at random 10% of the time; the other 90% of the time it is drawn with probability proportional to item popularity.

In [None]:
using JLD2

import CSV
import DataFrames: DataFrame
import JupyterFormatter: enable_autoformat
import ProgressMeter: @showprogress
import StatsBase: sample, wsample

In [None]:
enable_autoformat();

## RatingsDataset

In [None]:
"""
    RatingsDataset

Column-oriented storage for one ratings split.

For a dataset built by `get_dataset`, entry `k` of every vector comes from row `k`
of the split CSV, and `user`/`item` are shifted to 1-based ids. `get_negative_dataset`
fills only `user` and `item` and leaves every other vector empty.
"""
struct RatingsDataset
    user::Vector{Int32} # 1-based user id (CSV `username` + 1)
    item::Vector{Int32} # 1-based item id (CSV `animeid` + 1)
    rating::Vector{Float32} # CSV `score`
    timestamp::Vector{Float32} # CSV `timestamp`
    user_timestamp::Vector{Float32} # CSV `user_rel_timestamp`
    item_timestamp::Vector{Float32} # CSV `item_rel_timestamp`
    status::Vector{Int32} # CSV `status`
    completion::Vector{Float32} # CSV `completion`
    rewatch::Vector{Int32} # CSV `rewatch`
    source::Vector{Int32} # CSV `source`
    order::Vector{Int32} # CSV `order`
end;

In [None]:
"""
    get_dataset(file)

Parse the ratings CSV at `file` into a `RatingsDataset`.

The `username` and `animeid` columns are shifted up by one because Julia indexing is
1-based; every other column is passed through in `RatingsDataset` field order.
"""
function get_dataset(file)
    frame = DataFrame(CSV.File(file))
    # julia is 1 indexed, so shift both id columns
    users = frame.username .+ 1
    items = frame.animeid .+ 1
    # Remaining CSV columns, listed in the order of the fields after `user` and `item`.
    passthrough = (
        :score,
        :timestamp,
        :user_rel_timestamp,
        :item_rel_timestamp,
        :status,
        :completion,
        :rewatch,
        :source,
        :order,
    )
    return RatingsDataset(users, items, (getproperty(frame, col) for col in passthrough)...)
end;

In [None]:
"""
    get_data_path(x)

Return the path of `x` inside the project data directory, `../../data`, which is
resolved relative to the current working directory.
"""
function get_data_path(x)
    "../../data/$x"
end

"""
    _uid_encoding_value(key)

Return the integer stored under `key` in `processed_data/uid_encoding.csv`, plus one.

The file is read line by line as `key,value` pairs, and the key is matched by name
rather than by line position. Adding one matches the 1-based ids used in this notebook.
Throws an `ErrorException` when `key` is missing. Unlike `@assert`, this check cannot be
disabled.
"""
function _uid_encoding_value(key)
    path = get_data_path("processed_data/uid_encoding.csv")
    open(path) do io
        for line in eachline(io)
            fields = split(line, ',')
            if length(fields) >= 2 && strip(fields[1]) == key
                # `parse` tolerates surrounding whitespace (e.g. a stray '\r')
                return parse(Int, fields[2]) + 1
            end
        end
        error("key `$key` not found in $path")
    end
end

"""
    num_users()

Return the largest 1-based user id (`max_userid` from `uid_encoding.csv`, plus one).
"""
num_users() = _uid_encoding_value("max_userid")

"""
    num_items()

Return the largest 1-based item id (`max_itemid` from `uid_encoding.csv`, plus one).
"""
num_items() = _uid_encoding_value("max_itemid");

## Negative Splits

In [None]:
"""
    get_user_item_pairs(split_idx)

Return the set of `(user, item)` pairs that appear in any explicit, implicit, or ptw
dataset of the first `split_idx` splits (1 = training, 2 = + validation, 3 = + test).

Each dataset is the `RatingsDataset` stored under `"dataset"` in
`splits/<content>_<split>.jld2`. Its ids are 1-based, as written by `get_dataset`.
"""
function get_user_item_pairs(split_idx)
    user_item_pairs = Set{Tuple{Int32,Int32}}()
    # `split_name` (not `split`) so Base.split is not shadowed
    for split_name in ["training", "validation", "test"][1:split_idx]
        for content in ["explicit", "implicit", "ptw"]
            file = get_data_path("splits/$(content)_$(split_name).jld2")
            dataset = JLD2.load(file, "dataset")
            # eachindex over both columns also throws if their lengths disagree
            @showprogress for k in eachindex(dataset.user, dataset.item)
                push!(user_item_pairs, (dataset.user[k], dataset.item[k]))
            end
        end
    end
    return user_item_pairs
end;

In [None]:
"""
    get_popularity(split_idx)

Return an `Int32` vector of length `num_items()`. Entry `i` counts how many rows
reference item `i` across the explicit, implicit, and ptw datasets of the first
`split_idx` splits (1 = training, 2 = + validation, 3 = + test).
"""
function get_popularity(split_idx)
    counts = zeros(Int32, num_items())
    # `split_name` (not `split`) so Base.split is not shadowed
    for split_name in ["training", "validation", "test"][1:split_idx]
        for content in ["explicit", "implicit", "ptw"]
            file = get_data_path("splits/$(content)_$(split_name).jld2")
            dataset = JLD2.load(file, "dataset")
            @showprogress for item in dataset.item
                counts[item] += 1
            end
        end
    end
    return counts
end;

In [None]:
"""
    save_negative_dataset_csv(split_idx, file, samples; popularity_fraction = 0.9)

Sample `samples` negative `(user, item)` pairs and write them to the CSV `file`.
The file has header `username,animeid`, and ids are written 0-based (the 1-based ids
minus one).

A pair is negative when it does not occur in any dataset of the first `split_idx`
splits (see `get_user_item_pairs`) and has not been emitted already. Users are visited
round-robin over `1:num_users()`; a user whose drawn pair collides is skipped in favour
of the next one. With probability `popularity_fraction` the item is drawn proportionally
to its popularity (`get_popularity`); otherwise it is drawn uniformly from all items.
"""
function save_negative_dataset_csv(split_idx, file, samples; popularity_fraction = 0.9)
    # 1-based pairs that must not be emitted; grows as negatives are written
    user_item_pairs = get_user_item_pairs(split_idx)
    popularity = get_popularity(split_idx)
    max_userid = num_users()
    n_items = length(popularity)
    # Cumulative popularity makes each weighted draw a binary search instead of an
    # O(n_items) pass. Accumulate in Int64 so large totals cannot overflow Int32.
    cdf = cumsum(Int64.(popularity))
    total = isempty(cdf) ? zero(Int64) : last(cdf)
    # do-block guarantees the file is closed even if sampling throws
    open(file, "w") do io
        write(io, "username,animeid\n")
        nextuser = 1
        @showprogress for _ = 1:samples
            while true
                user = nextuser
                # cycle through every 1-based user id 1..max_userid (never 0)
                nextuser = nextuser % max_userid + 1
                item = if total > 0 && rand() < popularity_fraction
                    searchsortedfirst(cdf, rand(1:total))
                else
                    rand(1:n_items)
                end
                if (user, item) ∉ user_item_pairs
                    # record in the same 1-based form used for the membership test
                    push!(user_item_pairs, (user, item))
                    write(io, "$(user - 1),$(item - 1)\n")
                    break
                end
            end
        end
    end
end;

In [None]:
"""
    get_negative_dataset(file)

Read a negative-sample CSV (columns `username`, `animeid`) into a `RatingsDataset`.

Ids are shifted to 1-based. Every other field is an empty, correctly typed vector,
since the negative CSV carries no other columns.
"""
function get_negative_dataset(file)
    df = DataFrame(CSV.File(file))
    return RatingsDataset(
        df.username .+ 1, # julia is 1 indexed
        df.animeid .+ 1, # julia is 1 indexed
        Float32[], # rating
        Float32[], # timestamp
        Float32[], # user_timestamp
        Float32[], # item_timestamp
        Int32[], # status
        Float32[], # completion
        Int32[], # rewatch
        Int32[], # source
        Int32[], # order
    )
end;

## Save splits

In [None]:
# Convert each `splits/<content>_<split>.csv` into a sibling `.jld2` file that holds
# a `RatingsDataset` under the key "dataset", timing both the parse and the save.
for content in ["explicit", "implicit", "ptw"], split_name in ["training", "validation", "test"]
    stem = get_data_path("splits/$(content)_$(split_name)")
    # `dataset` must keep this name: `jldsave(...; dataset)` uses it as the JLD2 key
    @time dataset = get_dataset("$stem.csv")
    @time jldsave("$stem.jld2"; dataset)
end;

In [None]:
# For each cumulative split (training; + validation; + test), sample 100 negatives per
# user into `splits/negative_<split>.csv`, then reload it and save it as `.jld2`.
for (split_idx, split_name) in enumerate(["training", "validation", "test"])
    stem = get_data_path("splits/negative_$(split_name)")
    csv_file = "$stem.csv"
    save_negative_dataset_csv(split_idx, csv_file, num_users() * 100)
    # `dataset` must keep this name: `jldsave(...; dataset)` uses it as the JLD2 key
    dataset = get_negative_dataset(csv_file)
    jldsave("$stem.jld2"; dataset)
end;