# Generate Splits
* We split the dataset into explicit interactions, implicit interactions, and plan-to-watch interactions
* Each of the above splits is further separated into a training, validation, and test split
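* If a user has at least 20 interactions, then all but their most recent 10 interactions will be in the training split
* If a user has fewer than 20 interactions, then the earliest half of their interactions will be in the training split (when the count is odd, the middle interaction goes to training with probability 1/2)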
* Each remaining interaction is randomly assigned to either the validation or the test set, with equal probability

In [None]:
import os
import random
import zlib

import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
# Directory of preprocessed inputs: user_anime_list.csv is read from here and
# uid_encoding.csv is written here (path is relative to the working directory).
source_dir = "../../data/processed_data"

In [None]:
# Destination for every generated split file; created if it does not exist yet.
outdir = "../../data/splits"
os.makedirs(outdir, exist_ok=True)

In [None]:
# Seed Python's global `random` generator. It supplies the tie-breaking draws in
# temporal_sort and the random assignments in temporal_split and random_split.
random.seed(20220128)

In [None]:
def shard_by_user(file, num_shards):
    """Split a CSV file into ``num_shards`` shard files keyed on the first column.

    Every shard receives a copy of the header line. Each data row is routed by
    the text of its first comma-separated field, so all rows sharing that value
    land in the same shard file ``f"{file}.shard.{i}"``.

    Shard assignment uses ``zlib.crc32`` instead of the built-in ``hash``:
    ``hash`` on ``str`` is salted per interpreter process (PYTHONHASHSEED).
    With ``hash``, the shard layout would differ between runs, and any seeded
    processing of the shards would no longer be reproducible.

    Args:
        file: path of the input CSV; its first line is treated as the header.
        num_shards: number of shard files to write.
    """
    outfiles = []
    try:
        for i in range(num_shards):
            outfiles.append(open(f"{file}.shard.{i}", "w"))
        with open(file, "r") as in_file:
            header = False
            for line in tqdm(in_file):
                if not header:
                    header = True
                    for f in outfiles:
                        f.write(line)
                    continue
                user = line.strip().split(",")[0]
                # Deterministic across processes, unlike hash(str).
                shard = zlib.crc32(user.encode("utf-8")) % num_shards
                outfiles[shard].write(line)
    finally:
        # Close whatever shard files were successfully opened, even on error.
        for f in outfiles:
            f.close()

In [None]:
def temporal_sort(input_fn, output_fn):
    """Sort a user interaction CSV chronologically per user and number the rows.

    Rows are ordered by ``username`` and then ``timestamp``. Rows tying on both
    are ordered by a uniform draw from the global ``random`` generator (one draw
    per row, taken in file order). An ``order`` column is appended before the
    result is written without an index. It holds each row's 1-based position
    within its user's rows.

    Args:
        input_fn: CSV containing at least ``username`` and ``timestamp`` columns.
        output_fn: destination CSV path.
    """
    frame = pd.read_csv(input_fn)
    tiebreak = [random.random() for _ in range(len(frame))]
    frame = (
        frame.assign(rng=tiebreak)
        .sort_values(by=["username", "timestamp", "rng"])
        .reset_index(drop=True)
        .drop(columns="rng")
        .assign(count=1)
    )
    # A running sum of the constant 1 inside each user group is the 1-based rank.
    frame["order"] = frame.groupby("username")["count"].cumsum()
    frame.drop(columns="count").to_csv(output_fn, index=False)

In [None]:
def sharded_temporal_sort(input_fn, output_fn, num_shards=16):
    """Temporally sort a large CSV by processing it one user shard at a time.

    Steps:
      1. ``shard_by_user`` splits the input into shards by its first column
         (presumably the user column), so a user's rows stay together.
      2. Each shard is sorted with ``temporal_sort``.
      3. The sorted shards are concatenated, in shard order, into ``output_fn``.
         Only shard 0 contributes the header.
    Every intermediate shard file is deleted once it has been consumed.
    """
    shard_by_user(input_fn, num_shards)
    for idx in tqdm(range(num_shards)):
        unsorted_shard = f"{input_fn}.shard.{idx}"
        temporal_sort(unsorted_shard, f"{output_fn}.shard.{idx}")
        os.remove(unsorted_shard)

    with open(output_fn, "w") as merged:
        for idx in tqdm(range(num_shards)):
            sorted_shard = f"{output_fn}.shard.{idx}"
            with open(sorted_shard, "r") as part:
                header_line = part.readline()
                # Every shard repeats the header; keep only the first copy.
                if idx == 0:
                    merged.write(header_line)
                merged.writelines(part)
            os.remove(sorted_shard)

In [None]:
def get_user_counts(input_fn):
    """Count data rows per user in a CSV, keyed by the first column's text.

    The first line is treated as a header and skipped.

    Returns:
        dict mapping each first-column string to the number of rows carrying it,
        in order of first appearance.
    """
    counts = {}
    with open(input_fn, "r") as in_file:
        for line_no, line in enumerate(tqdm(in_file)):
            if line_no == 0:
                continue
            key = line.strip().split(",")[0]
            counts[key] = counts.get(key, 0) + 1
    return counts

In [None]:
def temporal_split(input_fn, train_fn, test_fn, test_samples_per_user, p_rampup):
    """Split a CSV into a training file and a held-out file, per user.

    A row's position is the number of earlier rows in the file from the same
    user (first column). Rows are assumed to be chronological within each user,
    e.g. as written by the temporal sort.

    * Users with ``count * (1 - p_rampup) >= test_samples_per_user`` keep all
      but their last ``test_samples_per_user`` rows in training.
    * Other users send the row at position ``k`` to training iff
      ``k + U < count * p_rampup``, where ``U`` is a fresh ``random.random()``
      draw. This keeps roughly the first ``count * p_rampup`` rows and rounds
      the boundary row stochastically.

    The header line is copied to both outputs.
    """
    totals = get_user_counts(input_fn)
    position = dict.fromkeys(totals, 0)
    with open(input_fn, "r") as in_file, open(train_fn, "w") as training, open(
        test_fn, "w"
    ) as test:
        for line_no, line in enumerate(tqdm(in_file)):
            if line_no == 0:
                training.write(line)
                test.write(line)
                continue
            user = line.strip().split(",")[0]
            total = totals[user]
            seen = position[user]
            has_full_test_window = total * (1 - p_rampup) >= test_samples_per_user
            if has_full_test_window:
                keep_in_training = seen + test_samples_per_user < total
            else:
                keep_in_training = seen + random.random() < total * p_rampup
            destination = training if keep_in_training else test
            destination.write(line)
            position[user] = seen + 1

In [None]:
def random_split(input_fn, train_fn, test_fn, p_training):
    """Randomly route each data row of a CSV to one of two output files.

    The header line is copied to both outputs. Each later row consumes one draw
    from the global ``random`` generator. It goes to ``train_fn`` when the draw
    is below ``p_training``, and to ``test_fn`` otherwise.
    """
    with open(input_fn, "r") as in_file, open(train_fn, "w") as training, open(
        test_fn, "w"
    ) as test:
        for line_no, line in enumerate(tqdm(in_file)):
            if line_no == 0:
                training.write(line)
                test.write(line)
                continue
            destination = training if random.random() < p_training else test
            destination.write(line)

In [None]:
def subset(input_fn, output_fn, condition):
    """Copy the header and every matching data row of a CSV into a new file.

    The input file is only read; it is NOT modified. Data rows for which
    ``condition`` is truthy are written to ``output_fn`` in their original
    order. Non-matching rows are simply skipped.

    Args:
        input_fn: CSV to scan; its first line is copied to the output unchanged.
        output_fn: destination path (overwritten).
        condition: callable receiving the raw data line (including its trailing
            newline) and returning truthy to keep the line.
    """
    with open(input_fn, "r") as in_file, open(output_fn, "w") as out_file:
        header = False
        for line in tqdm(in_file):
            if not header:
                header = True
                out_file.write(line)
                continue
            if condition(line):
                out_file.write(line)

## Construct training/validation/test splits

In [None]:
def generate_temporal_splits():
    """Build the temporally sorted interaction file and its three split files.

    Reads ``user_anime_list.csv`` from ``source_dir``. Writes under ``outdir``:
      * ``user_anime_list_sorted.csv`` with all rows, temporally sorted;
      * ``user_anime_list_sorted_training.csv``;
      * ``user_anime_list_sorted_validation.csv``;
      * ``user_anime_list_sorted_test.csv``.
    ``temporal_split`` is called with 10 held-out rows per user and
    ``p_rampup=0.5``. Its held-out rows are then divided by ``random_split``
    with ``p_training=0.5``. The intermediate ``_valtest`` file is deleted.
    """
    stem = "user_anime_list_sorted"

    def split_path(suffix=""):
        return os.path.join(outdir, f"{stem}{suffix}.csv")

    sharded_temporal_sort(
        os.path.join(source_dir, "user_anime_list.csv"), split_path()
    )
    temporal_split(
        split_path(), split_path("_training"), split_path("_valtest"), 10, 0.5
    )
    random_split(
        split_path("_valtest"), split_path("_validation"), split_path("_test"), 0.5
    )
    os.remove(split_path("_valtest"))

In [None]:
def generate_content_splits(split):
    """Partition one temporal split file into plan-to-watch / implicit / explicit files.

    Reads ``user_anime_list_sorted_{split}.csv`` from ``outdir`` and writes three
    files there, each starting with the same header:
      * ``ptw_{split}.csv``      -- rows whose ``status`` field is exactly "1";
      * ``implicit_{split}.csv`` -- other rows whose ``score`` parses to 0;
      * ``explicit_{split}.csv`` -- other rows whose ``score`` parses to non-zero.
    Column positions come from the header line.
    """
    source = os.path.join(outdir, f"user_anime_list_sorted_{split}.csv")
    with open(source) as handle:
        columns = handle.readline().strip().split(",")
    score_col = columns.index("score")
    status_col = columns.index("status")

    def is_ptw(line):
        return line.strip().split(",")[status_col] == "1"

    def is_implicit(line):
        cells = line.strip().split(",")
        return float(cells[score_col]) == 0 and cells[status_col] != "1"

    def is_explicit(line):
        cells = line.strip().split(",")
        return float(cells[score_col]) != 0 and cells[status_col] != "1"

    for prefix, keep in (
        ("ptw", is_ptw),
        ("implicit", is_implicit),
        ("explicit", is_explicit),
    ):
        subset(source, os.path.join(outdir, f"{prefix}_{split}.csv"), keep)

In [None]:
def get_max_ids(input_fn):
    """Return the largest integers found in the first two columns of a CSV.

    The header line is skipped. Returns ``(max_userid, max_itemid)``; each is
    -1 when the file has no data rows.
    """
    max_userid, max_itemid = -1, -1
    with open(input_fn, "r") as in_file:
        for line_no, line in enumerate(tqdm(in_file)):
            if line_no == 0:
                continue
            fields = line.strip().split(",")
            raw_user, raw_item = fields[0], fields[1]
            max_userid = max(max_userid, int(raw_user))
            max_itemid = max(max_itemid, int(raw_item))
    return max_userid, max_itemid

## Write splits 

In [None]:
# Sort all interactions and write the training / validation / test split files.
generate_temporal_splits()

In [None]:
# Divide each temporal split into plan-to-watch, implicit and explicit files.
for split in ["training", "validation", "test"]:
    generate_content_splits(split)

In [None]:
# Record the largest user and item ids seen in the sorted interaction file as
# key,value lines in uid_encoding.csv under source_dir.
max_userid, max_itemid = get_max_ids(
    os.path.join(outdir, "user_anime_list_sorted.csv")
)
with open(os.path.join(source_dir, "uid_encoding.csv"), "w") as out_file:
    out_file.writelines(
        [
            f"max_userid,{max_userid}\n",
            f"max_itemid,{max_itemid}\n",
        ]
    )