# Write the splits in a Julia-optimized format
* In addition, a negative split is sampled. This set consists of (user, item) pairs where the item does not appear on the user's list in any split, i.e. items the user has not watched or read.

In [None]:
using JLD2

import CSV
import DataFrames: DataFrame
import JupyterFormatter: enable_autoformat
import Memoize: @memoize
import ProgressMeter: @showprogress
import Random
import StatsBase: wsample

In [None]:
# Enable JupyterFormatter's autoformatting for this notebook session.
# The trailing `;` suppresses the cell's displayed return value.
enable_autoformat();

In [None]:
# Fix the global RNG: the `wsample` calls used for negative sampling below draw
# from the global RNG, so this makes the sampled negatives reproducible.
Random.seed!(20_231_112);

# Save Splits

In [1]:
"""
    RatingsDataset

Column-oriented storage for a ratings split. Each field holds one column; the
populated columns have one entry per row, so index `i` refers to the same row
across all of them.

Integer-coded columns are `Vector{Int32}`, all others `Vector{Float32}`.
Construct it with keyword arguments (the constructor is generated by `@kwdef`).
"""
@kwdef struct RatingsDataset
    # Int32 identifier columns
    source::Vector{Int32}
    medium::Vector{Int32}
    userid::Vector{Int32}
    mediaid::Vector{Int32}
    status::Vector{Int32}
    rating::Vector{Float32}
    update_order::Vector{Int32}
    # the *_at columns are stored as Float32
    updated_at::Vector{Float32}
    created_at::Vector{Float32}
    started_at::Vector{Float32}
    finished_at::Vector{Float32}
    progress::Vector{Float32}
    repeat_count::Vector{Int32}
    priority::Vector{Int32}
    sentiment::Vector{Int32}
    sentiment_score::Vector{Float32}
    owned::Vector{Float32}
end;

In [None]:
# Convert 0-based ids (as stored in the CSV files) to Julia's 1-based indices.
# Adding `one(v)` per element instead of the literal `1` keeps the element type
# (Int32 ids stay Int32 rather than being widened to Int64); `one(missing)` is
# `missing`, so columns containing `missing` still work. The broadcast is fused.
to_julia_index(x) = x .+ one.(x);

In [None]:
"""
    get_dataset(file)

Load the ratings CSV at `file` into a `RatingsDataset`.

Columns listed in `intcols` are parsed as `Int32`; every other column is parsed
as `Float32`. `userid` and `mediaid` are shifted from the 0-based ids stored on
disk to Julia's 1-based indices via `to_julia_index`.
"""
function get_dataset(file)
    # Any column not listed here is parsed as Float32.
    intcols = Set([
        :source, :medium, :userid, :mediaid, :status,
        :update_order, :repeat_count, :priority, :sentiment,
    ])
    coltype(_, name) = name in intcols ? Int32 : Float32
    df = DataFrame(CSV.File(file; types = coltype))
    return RatingsDataset(
        source = df.source,
        medium = df.medium,
        userid = df.userid |> to_julia_index,
        mediaid = df.mediaid |> to_julia_index,
        status = df.status,
        rating = df.rating,
        update_order = df.update_order,
        updated_at = df.updated_at,
        created_at = df.created_at,
        started_at = df.started_at,
        finished_at = df.finished_at,
        progress = df.progress,
        repeat_count = df.repeat_count,
        priority = df.priority,
        sentiment = df.sentiment,
        sentiment_score = df.sentiment_score,
        owned = df.owned,
    )
end;

In [None]:
# Resolve `x` against the project's data directory, given relative to the
# notebook's working directory.
get_data_path(x) = string("../../data/", x)

# Number of users, i.e. the number of rows in processed_data/username_to_uid.csv.
# `@memoize` must be on the same line as `function`: a macro call without
# parentheses ends at the newline, so a bare `@memoize` line is a zero-argument
# call (an error) and would not memoize the definition below it.
@memoize function num_users()
    df = DataFrame(CSV.File(get_data_path("processed_data/username_to_uid.csv")))
    return length(df.uid)
end

# Number of items of `medium`, i.e. the number of rows in
# processed_data/$(medium)_to_uid.csv; cached per `medium`.
# `@memoize` must be on the same line as `function` (a macro call ends at the
# newline, so a bare `@memoize` line would not apply to this definition).
@memoize function num_items(medium)
    df = DataFrame(CSV.File(get_data_path("processed_data/$(medium)_to_uid.csv")))
    return length(df.uid)
end;

In [None]:
# Media types and split names used to build the file names under
# ../../data/splits/.
ALL_MEDIUMS = String["manga", "anime"]
ALL_SPLITS = String["training", "validation", "test"]

In [None]:
# Convert every (medium, split) CSV into a JLD2 file that stores the parsed
# `RatingsDataset` under the key "dataset".
for medium in ALL_MEDIUMS
    for split in ALL_SPLITS
        stem = get_data_path("splits/$split")
        dataset = @time get_dataset("$stem.user_$(medium)_list.csv")
        jldsave("$stem.$medium.jld2"; dataset)
    end
end

# Save Negative Splits

In [None]:
"""
    get_items_per_user(medium)

Return a `Dict` mapping each user id to the `Set` of media ids on that user's
list, accumulated over every split in `ALL_SPLITS` for `medium`.

Reads the `RatingsDataset` stored under "dataset" in each
`splits/\$split.\$medium.jld2` file, so ids are the 1-based `userid`/`mediaid`
values stored there.
"""
function get_items_per_user(medium)
    items_per_user = Dict{Int32,Set{Int32}}()
    for split in ALL_SPLITS
        file = get_data_path("splits/$split.$medium.jld2")
        dataset = JLD2.load(file, "dataset")
        users = dataset.userid
        items = dataset.mediaid
        @showprogress for i in eachindex(users, items)
            # create the user's set on first sight, then record the item
            push!(get!(Set{Int32}, items_per_user, users[i]), items[i])
        end
    end
    return items_per_user
end;

In [None]:
"""
    get_nontrivial_users(medium)

Return the `Set` of (1-based) user ids that appear in the test split of `medium`
(`splits/test.\$medium.jld2`).
"""
function get_nontrivial_users(medium)
    file = get_data_path("splits/test.$medium.jld2")
    dataset = JLD2.load(file, "dataset")
    return Set(dataset.userid)
end;

In [None]:
"""
    save_negative_dataset_csv(file, medium, samples_per_user)

Write negative samples for `medium` to the CSV at `file`.

For every user returned by `get_nontrivial_users(medium)`, draw up to
`samples_per_user` distinct items uniformly at random from the items that do
not appear on that user's list in any split. Each row is
`medium,userid,mediaid`, with ids converted back to 0-based. Users whose list
already contains every item get no rows.
"""
function save_negative_dataset_csv(file, medium, samples_per_user)
    items_per_user = get_items_per_user(medium)
    valid_users = collect(get_nontrivial_users(medium))
    n_items = num_items(medium)
    candidates = 1:n_items
    weights = Vector{Float64}(undef, n_items)  # reused across users
    open(file, "w") do io
        # Column names must match what `get_negative_dataset` reads.
        write(io, "medium,userid,mediaid\n")
        @showprogress for user in valid_users
            # TODO test mixed negative sampling
            # uniform sampling: unseen items share one weight, seen items get 0
            fill!(weights, 1 / n_items)
            for i in items_per_user[user]
                weights[i] = 0
            end
            num_samples = min(samples_per_user, count(!iszero, weights))
            if num_samples > 0
                for item in wsample(candidates, weights, num_samples; replace = false)
                    # NOTE(review): `medium` is written as passed in (a name such
                    # as "anime"), while `RatingsDataset.medium` is Vector{Int32};
                    # confirm the integer medium encoding and write that instead.
                    write(io, "$medium,$(user-1),$(item-1)\n")
                end
            end
        end
    end
    return nothing
end;

In [None]:
"""
    get_negative_dataset(file)

Load a negative-sample CSV with columns `medium`, `userid`, `mediaid` into a
`RatingsDataset`. Only `medium`, `userid` and `mediaid` are populated (ids are
shifted to 1-based indices); every other column is an empty, correctly typed
vector.
"""
function get_negative_dataset(file)
    df = DataFrame(CSV.File(file))
    # NOTE(review): `save_negative_dataset_csv` appears to write the medium *name*
    # (e.g. "anime") into the `medium` column, but `RatingsDataset.medium` is a
    # Vector{Int32}, so converting a string column here will fail. Confirm the
    # integer medium encoding used by the split CSVs.
    return RatingsDataset(
        source = Int32[],
        medium = df.medium,
        userid = df.userid |> to_julia_index,
        mediaid = df.mediaid |> to_julia_index,
        status = Int32[],
        rating = Float32[],
        update_order = Int32[],
        updated_at = Float32[],
        created_at = Float32[],
        started_at = Float32[],
        finished_at = Float32[],
        progress = Float32[],
        repeat_count = Int32[],
        priority = Int32[],
        sentiment = Int32[],
        sentiment_score = Float32[],
        owned = Float32[],
    )
end;

In [None]:
# For each medium, sample negatives for the users in the test split, write them
# to CSV, then store them as a `RatingsDataset` in JLD2 format under "dataset".
negative_samples_per_user = 1000
for medium in ALL_MEDIUMS
    stem = get_data_path("splits/negative.$medium")
    save_negative_dataset_csv("$stem.csv", medium, negative_samples_per_user)
    dataset = get_negative_dataset("$stem.csv")
    jldsave("$stem.jld2"; dataset)
end;