# Write the splits in a Julia optimized format

In [1]:
using CSV
using DataFrames
using JLD2
using JupyterFormatter
using ProgressMeter

In [2]:
enable_autoformat();

In [3]:
struct RatingsDataset
    user::Vector{Int32}
    item::Vector{Int32}
    rating::Vector{Float32}
end;

In [4]:
function get_dataset(file)
    df = DataFrame(CSV.File(file))
    df.username .+= 1 # julia is 1 indexed
    df.anime_id .+= 1
    df.my_score = float(df.my_score)
    RatingsDataset(df.username, df.anime_id, df.my_score)
end

get_split(split) = get_dataset("../../data/splits/$(split).csv");

## Load splits

In [5]:
@time training = get_split("training");
@time validation = get_split("validation");
@time test = get_split("test");

 21.050536 seconds (9.57 M allocations: 13.413 GiB, 3.83% gc time, 0.58% compilation time)
  1.179054 seconds (4.82 k allocations: 685.104 MiB, 69.92% gc time)
  0.195462 seconds (2.47 k allocations: 684.776 MiB)


## Load implicit splits

In [6]:
@time implicit = get_dataset("../../data/processed_data/user_implicit_lists.csv")

  6.642972 seconds (2.24 k allocations: 14.748 GiB, 7.73% gc time)


RatingsDataset(Int32[851625, 851625, 851625, 851625, 851625, 851625, 851625, 851625, 851625, 851625  …  369188, 369188, 369188, 369188, 369188, 369188, 369188, 369188, 369188, 369188], Int32[11528, 498, 805, 41, 12807, 7670, 101, 3963, 116, 6464  …  14194, 3, 6082, 774, 14254, 7169, 15584, 9017, 14423, 369], Float32[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])

In [7]:
function generate_index(split)
    indices = Set()
    @showprogress for i = 1:length(split.rating)
        push!(indices, (split.user[i], split.item[i]))
    end
    indices
end
out_of_sample_indices = union(generate_index(validation), generate_index(test));

insample_mask = fill(true, length(implicit.rating))
@time Threads.@threads for i = 1:length(implicit.rating)
    insample_mask[i] = (implicit.user[i], implicit.item[i]) ∉ out_of_sample_indices
end

implicit_training = RatingsDataset(
    implicit.user[insample_mask],
    implicit.item[insample_mask],
    implicit.rating[insample_mask],
);

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:08[39m
[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:14[39m


 23.285416 seconds (2.10 G allocations: 38.211 GiB, 48.46% gc time, 0.14% compilation time)


In [8]:
file = "../../data/splits/splits.jld2";
@time jldsave(file; training, validation, test, implicit, implicit_training);

 65.107023 seconds (13.34 M allocations: 738.510 MiB, 3.79% gc time, 9.67% compilation time)
