# Generate Splits
* We split the dataset into explicit interactions, implicit interactions, and plan-to-watch interactions
* Each of the above splits is further separated into a training, validation, and test split
* The training/validation/test splits are temporally separated per user, and are in an approximately 90/5/5 ratio
* In addition, a negative split is sampled. This set consists of (user, item) pairs that the user did not watch.

In [None]:
import os
import random

import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
# Directory holding the processed input data; user_anime_list.csv is read
# from here and uid_encoding.csv is written here.
source_dir = "../../data/processed_data"

In [None]:
# Directory that every generated split file is written to; it is created
# (including missing parents) if it does not exist yet.
outdir= "../../data/splits"
os.makedirs(outdir, exist_ok=True)

In [None]:
# Seed Python's global RNG so that every `random.*` call made by the split
# functions produces the same sequence on each run.
random.seed(20220128)

In [None]:
def temporal_sort(input_fn, output_fn, seed=20220128):
    """Write a copy of a CSV file with its rows ordered by timestamp.

    Rows that share a timestamp are placed in a seeded pseudorandom order, so
    ties do not simply inherit the order of the input file.

    Args:
        input_fn: path of the CSV file to read; it must contain a
            ``timestamp`` column.
        output_fn: path of the sorted CSV file to write (the index is not
            written).
        seed: seed of the tie-breaking shuffle. A fixed value keeps the
            output reproducible from run to run.
    """
    user_anime_lists = pd.read_csv(input_fn)
    # Shuffle every row, then sort *stably* by timestamp: rows with equal
    # timestamps keep their shuffled relative order. This is a true uniform
    # permutation (the former multiplicative sequence `c * p % n` repeated
    # keys and could collapse to zeros), it is vectorized, and it needs no
    # temporary helper column that might clash with an input column.
    shuffled = user_anime_lists.sample(frac=1, random_state=seed)
    ordered = shuffled.sort_values(by="timestamp", kind="mergesort")
    ordered.to_csv(output_fn, index=False)

In [None]:
def get_user_counts(input_fn):
    """Count the data rows of a CSV file per user.

    The first line of the file is treated as a header and skipped. For every
    other line the user is the first comma-separated field of the stripped
    line.

    Returns:
        A dict mapping each user (kept as a string) to the number of rows
        that start with that user.
    """
    counts = {}
    with open(input_fn, "r") as in_file:
        for line_number, row in enumerate(tqdm(in_file)):
            if line_number == 0:
                # header line carries no interaction
                continue
            user = row.strip().split(",")[0]
            counts[user] = counts.get(user, 0) + 1
    return counts

In [None]:
# splits the input file into a training set and a test set
def temporal_split(input_fn, train_fn, test_fn, p_training):
    """Split a CSV file per user into a leading part and a trailing part.

    For each user, rows are visited in file order. A row goes to `train_fn`
    while the number of that user's rows already seen, plus a random draw in
    [0, 1), is below `p_training` times the user's total row count; all other
    rows go to `test_fn`. The random draw only matters for the single row at
    the boundary, so each user contributes roughly a `p_training` fraction of
    their earliest rows to the training file.

    The header line is copied to both output files. The global `random`
    module is used, with exactly one draw per data row.
    """
    totals = get_user_counts(input_fn)
    processed = dict.fromkeys(totals, 0)
    with open(input_fn, "r") as in_file:
        with open(train_fn, "w") as training:
            with open(test_fn, "w") as test:
                for line_number, row in enumerate(tqdm(in_file)):
                    if line_number == 0:
                        # both outputs start with the header
                        training.write(row)
                        test.write(row)
                        continue
                    user = row.strip().split(",")[0]
                    position = processed[user] + random.random()
                    cutoff = totals[user] * p_training
                    destination = training if position < cutoff else test
                    destination.write(row)
                    processed[user] += 1

In [None]:
def random_split(input_fn, train_fn, test_fn, p_training):
    """Randomly distribute the data rows of a CSV file over two files.

    Each data row is written to `train_fn` with probability `p_training` and
    to `test_fn` otherwise, using one draw from the global `random` module per
    row. The header line is copied to both output files.
    """
    with open(input_fn, "r") as in_file:
        with open(train_fn, "w") as training:
            with open(test_fn, "w") as test:
                for line_number, row in enumerate(tqdm(in_file)):
                    if line_number == 0:
                        # both outputs start with the header
                        training.write(row)
                        test.write(row)
                        continue
                    destination = (
                        training if random.random() < p_training else test
                    )
                    destination.write(row)

In [None]:
# scans every line of the input file. If the line satisfies the
# condition, then it is copied to the output file. The header line
# is always copied, and the input file itself is left unchanged
def subset(input_fn, output_fn, condition):
    """Copy the header and all matching rows of a file to another file.

    Args:
        input_fn: path of the file to read; it is not modified.
        output_fn: path of the file to write.
        condition: callable taking one raw line (including its newline) and
            returning a truthy value if the line should be kept. It is never
            called on the header line.
    """
    with open(input_fn, "r") as in_file, open(output_fn, "w") as out_file:
        for line_number, row in enumerate(tqdm(in_file)):
            # the header (line 0) is kept unconditionally
            if line_number == 0 or condition(row):
                out_file.write(row)

## Construct training/validation/test splits

In [None]:
def generate_temporal_splits():
    """Create the sorted interaction file and its three temporal splits.

    Steps, all inside `outdir`:
      1. sort `user_anime_list.csv` from `source_dir` by time,
      2. send roughly the first 90% of each user's rows to the training file
         and the rest to a temporary "valtest" file,
      3. divide the valtest rows at random, 50/50, between the validation
         and test files,
      4. delete the temporary valtest file.
    """
    feature = "user_anime_list_sorted"

    def split_path(suffix):
        # path of one of the generated files, e.g. suffix "_training"
        return os.path.join(outdir, f"{feature}{suffix}.csv")

    sorted_fn = split_path("")
    valtest_fn = split_path("_valtest")

    temporal_sort(os.path.join(source_dir, "user_anime_list.csv"), sorted_fn)
    temporal_split(sorted_fn, split_path("_training"), valtest_fn, 0.9)
    random_split(
        valtest_fn, split_path("_validation"), split_path("_test"), 0.5
    )
    os.remove(valtest_fn)

In [None]:
def generate_content_splits(split):
    """Divide one temporal split into ptw, implicit and explicit files.

    Reads `user_anime_list_sorted_{split}.csv` from `outdir` and writes
    `ptw_{split}.csv`, `implicit_{split}.csv` and `explicit_{split}.csv`
    next to it. A row is classified from its comma-separated fields:

      * ptw:      field 4 equals "1"
      * implicit: field 2 parses to 0 and field 4 is not "1"
      * explicit: field 2 parses to a non-zero number and field 4 is not "1"

    NOTE(review): field 2 is presumably the score and field 4 the
    plan-to-watch flag; confirm against the input schema.
    """

    def is_ptw(line):
        return line.strip().split(",")[4] == "1"

    def is_implicit(line):
        fields = line.strip().split(",")
        return float(fields[2]) == 0 and fields[4] != "1"

    def is_explicit(line):
        fields = line.strip().split(",")
        return float(fields[2]) != 0 and fields[4] != "1"

    source_fn = os.path.join(outdir, f"user_anime_list_sorted_{split}.csv")
    selectors = (
        ("ptw", is_ptw),
        ("implicit", is_implicit),
        ("explicit", is_explicit),
    )
    for content, keep in selectors:
        subset(source_fn, os.path.join(outdir, f"{content}_{split}.csv"), keep)

## Construct negative splits

In [None]:
# returns a dict of user -> set of item ids
def get_user_item_pairs(input_fns):
    """Collect the items each user appears with across several CSV files.

    The first line of every file is skipped as a header. In each remaining
    line the first field is read as the integer user id and the second as
    the integer item id.

    Returns:
        A dict mapping user id -> set of item ids. It is empty when
        `input_fns` is empty.
    """
    user_items = {}
    for input_fn in input_fns:
        with open(input_fn, "r") as in_file:
            for line_number, row in enumerate(tqdm(in_file)):
                if line_number == 0:
                    # header line
                    continue
                fields = row.strip().split(",")
                user_text, item_text = fields[0], fields[1]
                userid = int(user_text)
                itemid = int(item_text)
                user_items.setdefault(userid, set()).add(itemid)
    return user_items

In [None]:
def get_max_ids(input_fn):
    """Find the largest user id and the largest item id in a CSV file.

    The first line is skipped as a header. In each remaining line the first
    field is read as the integer user id and the second as the integer item
    id.

    Returns:
        A `(max_userid, max_itemid)` tuple; both are -1 when the file has no
        data rows.
    """
    max_userid = max_itemid = -1
    with open(input_fn, "r") as in_file:
        for line_number, row in enumerate(tqdm(in_file)):
            if line_number == 0:
                # header line
                continue
            fields = row.strip().split(",")
            user_text, item_text = fields[0], fields[1]
            userid, itemid = int(user_text), int(item_text)
            if userid > max_userid:
                max_userid = userid
            if itemid > max_itemid:
                max_itemid = itemid
    return max_userid, max_itemid

In [None]:
# writes out a split of `sample_size` (user, item) pairs that
# are sampled uniformly at random from (user, item) pairs where
# the user has not seen the item
def sample_negative_set(
    user_item_pairs, max_userid, max_itemid, sample_size, output_fn
):
    """Write `sample_size` distinct unseen (user, item) pairs to a CSV file.

    Pairs are drawn by rejection sampling: a user in `[0, max_userid]` and an
    item in `[0, max_itemid]` are drawn from the global `random` module and
    the pair is rejected if it is already present in `user_item_pairs`.

    Args:
        user_item_pairs: dict mapping user id -> set of item ids to exclude.
            It is updated in place with every sampled pair, which is what
            prevents the same pair from being written twice.
        max_userid: largest user id that may be sampled (inclusive).
        max_itemid: largest item id that may be sampled (inclusive).
        sample_size: number of pairs to write; nothing but the header is
            written when it is not positive.
        output_fn: path of the CSV file to write, with header `user,item`.

    Raises:
        ValueError: if fewer than `sample_size` unseen pairs exist, in which
            case the sampling loop could never finish. The check runs before
            the output file is opened.
    """
    if sample_size > 0:
        capacity = max(0, max_userid + 1) * max(0, max_itemid + 1)
        # Cheap upper bound on the excluded pairs first; the exact in-range
        # count is only computed when that bound suggests a problem, so the
        # common case stays O(number of users).
        excluded = sum(len(items) for items in user_item_pairs.values())
        if sample_size > capacity - excluded:
            excluded = sum(
                1
                for user, items in user_item_pairs.items()
                if 0 <= user <= max_userid
                for item in items
                if 0 <= item <= max_itemid
            )
            if sample_size > capacity - excluded:
                raise ValueError(
                    f"cannot sample {sample_size} negative pairs: only "
                    f"{capacity - excluded} unseen (user, item) pairs exist"
                )

    with open(output_fn, "w") as out_file:
        out_file.write("user,item\n")
        with tqdm(total=sample_size) as pbar:
            remaining = sample_size
            while remaining > 0:
                user = random.randint(0, max_userid)
                item = random.randint(0, max_itemid)
                seen_items = user_item_pairs.get(user)
                if seen_items is not None and item in seen_items:
                    continue
                # remember the pair so it cannot be sampled again
                user_item_pairs.setdefault(user, set()).add(item)
                out_file.write(f"{user},{item}\n")
                remaining -= 1
                pbar.update(1)

In [None]:
def generate_negative_splits():
    """Sample a negative file for each of the training/validation/test splits.

    For every split, `negative_{split}.csv` is written to `outdir` with
    `max_userid * 100` sampled (user, item) pairs. The largest ids come from
    `user_anime_list_sorted.csv`.

    The pairs excluded for a split are the interactions of that split and of
    every earlier one. The split's own file has to be included: otherwise
    the training negatives would be drawn with nothing excluded, and every
    split's negatives could contain that split's own positives. Later splits
    are left out so no future information leaks into an earlier split.
    """
    max_userid, max_itemid = get_max_ids(
        os.path.join(outdir, "user_anime_list_sorted.csv")
    )
    splits = ["training", "validation", "test"]
    num_samples = max_userid * 100
    for i, split in enumerate(splits):
        known_fns = [
            os.path.join(outdir, f"user_anime_list_sorted_{x}.csv")
            for x in splits[: i + 1]
        ]
        user_item_pairs = get_user_item_pairs(known_fns)
        sample_negative_set(
            user_item_pairs,
            max_userid,
            max_itemid,
            num_samples,
            os.path.join(outdir, f"negative_{split}.csv"),
        )

## Write splits 

In [None]:
# Sort the interactions by time and write the training/validation/test files.
generate_temporal_splits()

In [None]:
# Break each temporal split into its ptw / implicit / explicit files.
for split in ["training", "validation", "test"]:
    generate_content_splits(split)

In [None]:
# Sample the negative (user, item) pairs for every split.
generate_negative_splits()

In [None]:
# Record the largest user id and item id of the sorted interaction file as
# two `name,value` lines in uid_encoding.csv inside `source_dir`.
max_userid, max_itemid = get_max_ids(
    os.path.join(outdir, "user_anime_list_sorted.csv")
)
with open(os.path.join(source_dir, "uid_encoding.csv"), "w") as out_file:
    out_file.writelines(
        (f"max_userid,{max_userid}\n", f"max_itemid,{max_itemid}\n")
    )