# Generate Splits
* We split the dataset by interaction type: whether the user rated the item, watched the item,
  or put the item on their plan-to-watch list
* Each of the above splits is further separated into a training, validation, and test split
* The training split consists of all data for half the users,
  and all data except the most recent three months for the other half
* The validation and test splits are a random partition of the remaining data
* Any items that are not present in the training set are removed from the validation and test sets

In [None]:
import math
import os
import random

import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
# Directory holding the preprocessed inputs read by this notebook
# (user_anime_list.csv and processing_encodings.csv).
source_dir = "../../data/processed_data"

In [None]:
# Directory that receives the sorted list and every generated split file;
# it is created up front if it does not exist yet.
outdir = "../../data/splits"
os.makedirs(outdir, exist_ok=True)

In [None]:
# Fixed seed so the random.random() draws made while splitting are
# reproducible from run to run.
random.seed(20220128)

In [None]:
def shard_by_user(file, num_shards):
    """Partition a CSV into ``num_shards`` files keyed by user id.

    Shard ``i`` is written next to the input as ``{file}.shard.{i}``.  The
    header row is copied into every shard; each data row goes to the shard
    numbered ``int(username) % num_shards``, so all rows of one user land in
    the same shard.  The ``username`` column must therefore hold integers.
    """
    shard_files = []
    try:
        for i in range(num_shards):
            shard_files.append(open(f"{file}.shard.{i}", "w"))
        with open(file, "r") as source:
            # Stays None until the header row has been seen and parsed.
            user_col = None
            for row in tqdm(source):
                if user_col is None:
                    user_col = row.strip().split(",").index("username")
                    for handle in shard_files:
                        handle.write(row)
                else:
                    user = row.strip().split(",")[user_col]
                    shard_files[int(user) % num_shards].write(row)
    finally:
        # Close whatever was opened, even if opening or writing failed midway.
        for handle in shard_files:
            handle.close()

In [None]:
def temporal_sort(input_fn, output_fn):
    """Load a CSV, order its rows by user and then by timestamp, and save it.

    The whole file is read into memory with pandas, sorted ascending on
    ``username`` followed by ``timestamp``, and written to ``output_fn``
    without the index column.
    """
    frame = pd.read_csv(input_fn)
    frame.sort_values(by=["username", "timestamp"], inplace=True)
    frame.reset_index(drop=True, inplace=True)
    frame.to_csv(output_fn, index=False)

In [None]:
def sharded_temporal_sort(input_fn, output_fn, num_shards=16):
    """Sort a CSV by (username, timestamp) one shard at a time.

    The input is first partitioned by user with ``shard_by_user``; every
    shard is then sorted on its own with ``temporal_sort`` and the sorted
    shards are concatenated, in shard order, into ``output_fn``.  Only the
    header of shard 0 is kept.  All intermediate shard files are deleted.

    Because shards are appended one after another, rows are ordered by user
    and time within each shard, not across the whole output file.
    """
    shard_by_user(input_fn, num_shards)

    # Phase 1: sort each unsorted shard and drop it once its sorted twin exists.
    for i in tqdm(range(num_shards)):
        unsorted_fn = f"{input_fn}.shard.{i}"
        temporal_sort(unsorted_fn, f"{output_fn}.shard.{i}")
        os.remove(unsorted_fn)

    # Phase 2: stitch the sorted shards together under a single header.
    with open(output_fn, "w") as merged:
        for i in tqdm(range(num_shards)):
            fn = f"{output_fn}.shard.{i}"
            with open(fn, "r") as part:
                first_line = part.readline()
                if i == 0:
                    merged.write(first_line)
                for row in part:
                    merged.write(row)
            os.remove(fn)

In [None]:
def combined_split(input_fn, train_fn, test_fn, training_param, split_type):
    """Route every data row of a CSV into a training file or a held-out file.

    The header row of ``input_fn`` is copied to both outputs.  How a data row
    is routed depends on ``split_type``:

    * ``"random"``: the row goes to training when a fresh ``random.random()``
      draw is below ``training_param`` (i.e. ``training_param`` is the
      probability of landing in the training file).
    * ``"temporal"``: the row goes to training when its ``timestamp`` is
      below ``training_param`` or when its ``username`` (an integer user id)
      is even.  Users with an even id therefore contribute all of their rows
      to training, the others only the rows older than the cutoff.

    Args:
        input_fn: path of the CSV to split; needs a ``timestamp`` and a
            ``username`` column.
        train_fn: path of the training output (overwritten).
        test_fn: path of the held-out output (overwritten).
        training_param: probability threshold ("random") or timestamp
            cutoff ("temporal").
        split_type: either ``"random"`` or ``"temporal"``.

    Raises:
        ValueError: if ``split_type`` is not one of the supported values, or
            if a required column / numeric field cannot be parsed.
    """
    # Reject a bad mode before the output files are opened (and truncated).
    if split_type not in ("random", "temporal"):
        raise ValueError(f"unknown split_type: {split_type!r}")

    with open(input_fn, "r") as in_file, open(train_fn, "w") as training, open(
        test_fn, "w"
    ) as test:
        header = False
        for line in tqdm(in_file):
            if not header:
                header = True
                columns = line.strip().split(",")
                timestamp_col = columns.index("timestamp")
                user_col = columns.index("username")
                training.write(line)
                test.write(line)
                continue

            # One draw per row in BOTH modes: the temporal pass must consume
            # the seeded RNG stream exactly as before so that a later random
            # split sees the same generator state.
            rand = random.random()

            if split_type == "random":
                for_training = rand < training_param
            else:
                fields = line.strip().split(",")
                timestamp = float(fields[timestamp_col])
                # The user id lives in the username column; it was previously
                # read from the timestamp column by mistake.
                userid = int(fields[user_col])
                for_training = (timestamp < training_param) or (userid % 2 == 0)

            if for_training:
                training.write(line)
            else:
                test.write(line)

In [None]:
def random_split(input_fn, train_fn, test_fn, p_training):
    """Split ``input_fn`` row by row, sending each row to ``train_fn`` with
    probability ``p_training`` and to ``test_fn`` otherwise."""
    combined_split(
        input_fn,
        train_fn,
        test_fn,
        training_param=p_training,
        split_type="random",
    )


def temporal_split(input_fn, train_fn, test_fn, test_months):
    """Hold out the most recent ``test_months`` months via a temporal split.

    The first two lines of ``processing_encodings.csv`` (in ``source_dir``)
    must be ``min_timestamp,<int>`` and ``max_timestamp,<int>``, in that
    order.  Their difference gives the time span of the data in seconds, from
    which the length of one month is expressed as a fraction of that span.
    The resulting cutoff ``1 - month * test_months`` is handed to
    ``combined_split`` in ``"temporal"`` mode.
    """
    bounds = []
    with open(os.path.join(source_dir, "processing_encodings.csv"), "r") as in_file:
        for field in ("min_timestamp", "max_timestamp"):
            fields = in_file.readline().split(",")
            assert len(fields) == 2
            assert fields[0] == field
            bounds.append(int(fields[1]))
    min_timestamp, max_timestamp = bounds

    seconds_in_month = 2.628e6
    # One month as a fraction of the full [min_timestamp, max_timestamp] span.
    month = seconds_in_month / (max_timestamp - min_timestamp)
    cutoff = 1 - month * test_months
    combined_split(input_fn, train_fn, test_fn, cutoff, "temporal")

In [None]:
def subset(input_fn, output_fn, condition):
    """Filter the data rows of ``input_fn`` into ``output_fn``.

    The first line (the header) is always copied and is never passed to
    ``condition``.  Every later line is written only when ``condition(line)``
    is truthy; the line is handed over raw, including its trailing newline.
    """
    with open(input_fn, "r") as in_file, open(output_fn, "w") as out_file:
        for position, line in enumerate(tqdm(in_file)):
            # Short-circuit keeps the header away from the predicate.
            if position == 0 or condition(line):
                out_file.write(line)
## Construct training/validation/test splits

In [None]:
def generate_temporal_splits():
    """Build the sorted list and its training/validation/test files in ``outdir``.

    Steps: sort ``user_anime_list.csv`` from ``source_dir``; carve off the
    held-out rows with ``temporal_split``; divide those rows 50/50 at random
    into validation and test; finally delete the intermediate ``_valtest``
    file.
    """
    outfn = "user_anime_list_sorted"
    sorted_path = os.path.join(outdir, f"{outfn}.csv")
    training_path = os.path.join(outdir, f"{outfn}_training.csv")
    valtest_path = os.path.join(outdir, f"{outfn}_valtest.csv")
    validation_path = os.path.join(outdir, f"{outfn}_validation.csv")
    test_path = os.path.join(outdir, f"{outfn}_test.csv")

    sharded_temporal_sort(
        os.path.join(source_dir, "user_anime_list.csv"),
        sorted_path,
    )
    # use the last 3 months as the validation/test sets
    temporal_split(sorted_path, training_path, valtest_path, 3.0)
    random_split(valtest_path, validation_path, test_path, 0.5)
    os.remove(valtest_path)

In [None]:
def generate_content_splits(split, valid_itemids):
    """Break one temporal split into per-interaction-type files.

    Reads ``user_anime_list_sorted_{split}.csv`` from ``outdir`` four times
    and writes four disjoint files next to it:

    * ``invalid_{split}.csv``  - rows whose ``animeid`` is not in
      ``valid_itemids``
    * ``ptw_{split}.csv``      - valid rows whose ``status`` is ``"1"``
    * ``implicit_{split}.csv`` - remaining valid rows whose ``score`` is 0
    * ``explicit_{split}.csv`` - remaining valid rows with a non-zero score
    """
    source_path = os.path.join(outdir, f"user_anime_list_sorted_{split}.csv")
    with open(source_path) as f:
        header = f.readline().strip().split(",")
    score_col = header.index("score")
    status_col = header.index("status")
    item_col = header.index("animeid")

    def cell(row, col):
        # Raw text of one column of a CSV line.
        return row.strip().split(",")[col]

    def invalid(row):
        return int(cell(row, item_col)) not in valid_itemids

    def is_ptw(row):
        return cell(row, status_col) == "1"

    def is_implicit(row):
        return float(cell(row, score_col)) == 0

    def valid_ptw(row):
        return not invalid(row) and is_ptw(row)

    def valid_implicit(row):
        return not invalid(row) and not is_ptw(row) and is_implicit(row)

    def valid_explicit(row):
        return not invalid(row) and not is_ptw(row) and not is_implicit(row)

    # Output file -> predicate, processed in this order.
    targets = [
        (os.path.join(outdir, f"invalid_{split}.csv"), invalid),
        (os.path.join(outdir, f"ptw_{split}.csv"), valid_ptw),
        (os.path.join(outdir, f"implicit_{split}.csv"), valid_implicit),
        (os.path.join(outdir, f"explicit_{split}.csv"), valid_explicit),
    ]
    for target_path, keep in targets:
        subset(source_path, target_path, keep)

In [None]:
def get_max_ids(input_fn):
    """Return ``(max_userid, max_itemid)`` found in a CSV.

    User ids are read from the ``username`` column and item ids from the
    ``animeid`` column; both are parsed as integers.  Each value defaults to
    -1 when the file has no data rows.
    """
    largest_user = -1
    largest_item = -1
    with open(input_fn, "r") as in_file:
        # Column positions stay None until the header row has been parsed.
        user_col = item_col = None
        for line in tqdm(in_file):
            fields = line.strip().split(",")
            if user_col is None:
                user_col = fields.index("username")
                item_col = fields.index("animeid")
                continue
            raw_user, raw_item = fields[user_col], fields[item_col]
            userid, itemid = int(raw_user), int(raw_item)
            if userid > largest_user:
                largest_user = userid
            if itemid > largest_item:
                largest_item = itemid
    return largest_user, largest_item

In [None]:
def get_item_ids(input_fn):
    """Collect the set of ``animeid`` values that occur in a CSV.

    Rows whose ``timestamp`` is (approximately) -1 are treated as having a
    corrupted timestamp and do not contribute their item id.
    """
    item_ids = set()
    with open(input_fn, "r") as in_file:
        # Column positions stay None until the header row has been parsed.
        item_col = timestamp_col = None
        for line in tqdm(in_file):
            fields = line.strip().split(",")
            if item_col is None:
                item_col = fields.index("animeid")
                timestamp_col = fields.index("timestamp")
                continue
            # only rows with a usable timestamp count
            if not math.isclose(float(fields[timestamp_col]), -1):
                item_ids.add(int(fields[item_col]))
    return item_ids

## Write splits 

In [None]:
# Sort the raw interaction list and write the training/validation/test
# files into outdir.
generate_temporal_splits()

In [None]:
# The largest user/item ids are taken from the full sorted list, while the
# set of valid item ids is collected from the training split only.
max_userid, max_itemid = get_max_ids(os.path.join(outdir, "user_anime_list_sorted.csv"))
valid_itemids = get_item_ids(
    os.path.join(outdir, "user_anime_list_sorted_training.csv")
)

# Persist the id metadata as "key,value" lines, one valid_itemid line per
# item. Note that this file is written to source_dir, not outdir.
with open(os.path.join(source_dir, "uid_encoding.csv"), "w") as out_file:
    out_file.write(f"max_userid,{max_userid}\n")
    out_file.write(f"max_itemid,{max_itemid}\n")
    for uid in valid_itemids:
        out_file.write(f"valid_itemid,{uid}\n")

In [None]:
# Break each temporal split into its invalid / ptw / implicit / explicit
# files, using the item ids collected from the training split.
for split in ["training", "validation", "test"]:
    generate_content_splits(split, valid_itemids)