# Split the data into training/validation/test sets
* For each user with at least 100 items, every entry is independently assigned at random so that, in expectation, 10 items land in the validation set and 10 items land in the test set
* If the user has fewer than 100 items, then we randomly assign their entries with an 80/10/10 training/validation/test split

In [1]:
import os
import random

from tqdm import tqdm

In [2]:
# Directory the input CSV is read from (relative to the working directory).
source_dir = "../../data/processed_data"

In [3]:
# Directory the training/validation/test CSVs are written to (relative to the
# working directory).
outdir = "../../data/splits"
# exist_ok=True makes re-running the cell a no-op and avoids the
# check-then-create race; makedirs also creates missing parent directories.
os.makedirs(outdir, exist_ok=True)

In [4]:
def get_input_file(file):
    """Open ``file`` from ``source_dir`` for reading as UTF-8 text.

    The encoding is pinned to UTF-8 so the result does not depend on the
    platform's locale and agrees with ``get_output_file``.

    Args:
        file: File name relative to ``source_dir``.

    Returns:
        An open text-mode file object; the caller is responsible for closing
        it (e.g. by using it in a ``with`` statement).
    """
    return open(os.path.join(source_dir, file), "r", encoding="utf-8")


def get_output_file(file):
    """Open ``file`` in ``outdir`` for writing as UTF-8 text.

    Any existing file with the same name is truncated. The encoding is pinned
    to UTF-8 so lines read through ``get_input_file`` are written back
    unchanged regardless of the platform's locale.

    Args:
        file: File name relative to ``outdir``.

    Returns:
        An open text-mode file object; the caller is responsible for closing
        it (e.g. by using it in a ``with`` statement).
    """
    return open(os.path.join(outdir, file), "w", encoding="utf-8")

In [5]:
import pandas as pd
# Load the whole list file into memory; it is used below to count the number
# of rows per user.
# NOTE(review): only the "username" column is read in the visible cells --
# passing usecols=["username"] would cut memory use if nothing else needs df.
df = pd.read_csv(os.path.join(source_dir, "user_anime_lists.csv"))

In [6]:
# Map each value of the "username" column to the number of rows it has in df.
item_counts = df.groupby("username").size().to_dict()

In [7]:
def get_split_ratio(count):
    """Return the (training, validation, test) probabilities for one user.

    Users with fewer than 100 entries get a fixed 80/10/10 split. For larger
    users the validation and test probabilities are each ``10 / count`` --
    an expected 10 held-out entries per set -- and training gets the rest.

    Args:
        count: Number of entries the user has.

    Returns:
        A 3-tuple ``(p_training, p_validation, p_test)``.
    """
    # Small users: fixed proportions.
    if count < 100:
        return (0.8, 0.1, 0.1)

    # Large users: scale the held-out share down so each held-out set
    # receives about 10 entries in expectation.
    holdout_share = 10 / count
    return (1 - 2 * holdout_share, holdout_share, holdout_share)

In [8]:
# Fix the RNG seed so the random split assignment below is reproducible.
random.seed(20220128)

In [9]:
# Stream the list file once and route every data row to exactly one of the
# three split files, so the full dataset never has to be held as text in memory.
with get_input_file("user_anime_lists.csv") as in_file, get_output_file(
    "training.csv"
) as training, get_output_file("validation.csv") as validation, get_output_file(
    "test.csv"
) as test:
    # Copy the header row to every split up front instead of testing a flag on
    # each of the data rows. An empty input yields "" and nothing is written.
    header_line = in_file.readline()
    if header_line:
        for split_file in (training, validation, test):
            split_file.write(header_line)

    for line in tqdm(in_file):
        # Only the first field (the user) is needed, so split at the first
        # comma rather than on every comma in the row.
        user = int(line.split(",", 1)[0])
        p_training, p_validation, _ = get_split_ratio(item_counts[user])
        # Exactly one draw per data row, in file order: this keeps the split
        # reproducible under the seed set above.
        sample = random.random()
        if sample < p_training:
            training.write(line)
        elif sample < p_training + p_validation:
            validation.write(line)
        else:
            test.write(line)

168626003it [03:45, 746724.13it/s]
