# Write the splits in a Julia-optimized (JLD2) format
* In addition, a negative split is sampled. This set consists of (user, item) pairs where the user has not watched or read the item.

In [None]:
using JLD2

import CSV
import DataFrames: DataFrame
import JupyterFormatter: enable_autoformat
import Memoize: @memoize
import ProgressMeter: @showprogress
import Random
import StatsBase: wsample

In [None]:
# Enable automatic formatting of notebook cells (from JupyterFormatter).
enable_autoformat();

In [None]:
# Seed the global RNG. The negative sampling below calls `wsample` without an
# explicit RNG, so this makes the sampled negatives repeatable (for a fixed
# Julia / package version).
Random.seed!(20231112);

# Save Splits

In [None]:
"""
    split_save(file, values)

Write every `key => value` entry of `values` to its own compressed JLD2 file
named `<file>.<key>.jld2`, stored under the dataset name `key`. Keeping one file
per column lets readers load the columns with multiple threads.
"""
function split_save(file, values)
    foreach(pairs(values)) do (col, arr)
        JLD2.save("$file.$col.jld2", Dict(col => arr); compress = true)
    end
end;

In [None]:
"""
    get_dataset(file)

Read the list CSV at `file` into a `Dict` mapping column name to a typed vector.

Columns listed in `data` below are parsed with the element type of their vector;
any other column is parsed as `Float32` (and not kept). In the result,
`mediaid` is renamed to `itemid` and `backward_order` to `update_order`, while
`forward_order` is dropped.
"""
function get_dataset(file)
    data = Dict(
        "source" => Int32[],
        "medium" => Int32[],
        "userid" => Int32[],
        "mediaid" => Int32[],
        "status" => Int32[],
        "rating" => Float32[],
        "forward_order" => Int32[],
        "backward_order" => Int32[],
        "updated_at" => Float32[],
        "created_at" => Float32[],
        "started_at" => Float32[],
        "finished_at" => Float32[],
        "progress" => Float32[],
        "repeat_count" => Int32[],
        "priority" => Float32[],
        "sentiment" => Int32[],
        "sentiment_score" => Float32[],
        "owned" => Float32[],
    )
    # CSV calls this with (column index, column name); unknown columns -> Float32
    type_parser =
        (_, name) -> haskey(data, String(name)) ? eltype(data[String(name)]) : Float32
    chunks = try
        # stream large files in chunks to reduce peak memory usage
        CSV.Chunks(file, types = type_parser, ntasks = 1024)
    catch
        # small files crash CSV.Chunks, so fall back to a single full read;
        # presumably a genuinely unreadable file still raises from CSV.File here
        [CSV.File(file, types = type_parser)]
    end
    @showprogress for f in chunks
        df = DataFrame(f)
        for (k, col) in data
            # `df[!, k]` borrows the column instead of copying it; `append!`
            # already copies the values into `col`, so a second copy is wasted memory
            append!(col, df[!, k])
        end
        df = nothing
        GC.gc()
    end

    # rename columns to the names used downstream
    data["itemid"] = pop!(data, "mediaid")
    data["update_order"] = pop!(data, "backward_order")
    delete!(data, "forward_order")
    return data
end;

In [None]:
# Resolve `x` relative to the repository's data directory.
get_data_path(x) = "../../data/$x"

# Number of rows in the username -> uid mapping table (cached after first call).
@memoize function num_users()
    path = get_data_path("processed_data/username_to_uid.csv")
    uid_table = DataFrame(CSV.File(path))
    return length(uid_table.uid)
end

# Number of rows in the `<medium>_to_uid.csv` mapping table (cached per medium).
@memoize function num_items(medium)
    path = get_data_path("processed_data/$(medium)_to_uid.csv")
    uid_table = DataFrame(CSV.File(path))
    return length(uid_table.uid)
end;

In [None]:
# Media types and split names processed by this notebook, in processing order.
ALL_MEDIUMS = String["manga", "anime"]
ALL_SPLITS = String["training", "validation", "test"];

In [None]:
# Convert each (medium, split) list CSV into per-column JLD2 files.
for medium in ALL_MEDIUMS, split in ALL_SPLITS
    stem = get_data_path("splits/$split")
    dataset = get_dataset("$stem.user_$(medium)_list.csv")
    split_save("$stem.$medium", dataset)
    # drop the reference so the collection below can reclaim the columns
    dataset = nothing
    GC.gc()
end;

# Save Negative Splits

In [None]:
function get_test(medium, col)
    # each column of the test split is stored in its own JLD2 file
    return JLD2.load("../../data/splits/test.$medium.$col.jld2", col)
end;

In [None]:
"""
    get_user_to_items(medium)

Return a `Dict{Int32,Set{Int32}}` mapping each user id to the set of item ids
recorded for that user in any of `ALL_SPLITS` for `medium`.
"""
function get_user_to_items(medium)
    user_to_items = Dict{Int32,Set{Int32}}()
    for split in ALL_SPLITS
        # Load *this* split's columns (written by `split_save` as
        # `<split>.<medium>.<col>.jld2`), not just the test split, so that items
        # seen in any split are included.
        stem = get_data_path("splits/$split.$medium")
        users = JLD2.load("$stem.userid.jld2", "userid")
        items = JLD2.load("$stem.itemid.jld2", "itemid")
        @showprogress for i in eachindex(users, items)
            push!(get!(Set{Int32}, user_to_items, users[i]), items[i])
        end
    end
    return user_to_items
end;

In [None]:
"""
    save_negative_dataset_csv(medium, samples_per_user)

For every distinct user in the test split of `medium`, sample up to
`samples_per_user` distinct item ids (0-based) uniformly without replacement
from the items not recorded for that user by `get_user_to_items`.

The sampled pairs are written to `negative.<medium>.csv` and to per-column JLD2
files via `split_save`.
"""
function save_negative_dataset_csv(medium, samples_per_user)
    user_to_items = get_user_to_items(medium)
    valid_users = collect(Set(get_test(medium, "userid")))
    n_items = num_items(medium)
    candidates = 0:n_items-1
    negative_users = Int32[]
    negative_items = Int32[]
    stem = "../../data/splits/negative.$medium"
    # the do-block guarantees the CSV is flushed and closed, even on error
    open("$stem.csv", "w") do io
        # NOTE(review): the first column holds numeric user ids, not usernames
        write(io, "username,itemid\n")
        weights = Vector{Float64}(undef, n_items)  # reused across users
        @showprogress for user in valid_users
            fill!(weights, 1)
            # item ids are 0-based while `weights` is 1-based
            for i in user_to_items[user]
                weights[i+1] = 0
            end
            num_samples = min(samples_per_user, Int(sum(weights)))
            # nothing left to sample for this user (all items already seen)
            num_samples > 0 || continue
            for item in wsample(candidates, weights, num_samples; replace = false)
                push!(negative_users, user)
                push!(negative_items, item)
                write(io, "$(user),$(item)\n")
            end
        end
    end
    split_save("$stem", Dict("userid" => negative_users, "itemid" => negative_items))
end;

In [None]:
# Build the negative split for every medium, with up to 10,000 samples per user.
let samples_per_user = 10_000
    foreach(m -> save_negative_dataset_csv(m, samples_per_user), ALL_MEDIUMS)
end;