# Generate Splits
* Create training, validation and test splits

In [None]:
import math
import os

import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
# Directory holding the processed input CSVs read by this notebook
# (the user list, timestamps.csv and knowledge_cutoff.csv).
source_dir = "../../data/processed_data"

In [None]:
# Directory that receives the sorted list and the per-split output files.
outdir = "../../data/splits"
# Create it up front; exist_ok makes re-running the notebook harmless.
os.makedirs(outdir, exist_ok=True)

In [None]:
# Media name interpolated into the `user_{media}_list*.csv` file names below.
# NOTE(review): empty here — presumably filled in per run by whatever
# parameterises this notebook; confirm before running as-is.
media = ""

# Sort items by timestamp

In [None]:
def shard_by_user(file, num_shards):
    """Split a CSV into ``num_shards`` files so each user lands in one shard.

    Shard ``i`` is written to ``f"{file}.shard.{i}"``. The first line of
    ``file`` is treated as the header: it is copied to every shard and is
    used to locate the ``userid`` column. Every later line goes to shard
    ``int(userid) % num_shards``, unchanged.

    Lines are split on plain commas (no CSV quoting is interpreted).
    """
    try:
        outfiles = []
        for shard_idx in range(num_shards):
            outfiles.append(open(f"{file}.shard.{shard_idx}", "w"))
        with open(file) as in_file:
            # None until the header row has been consumed.
            user_col = None
            for line in tqdm(in_file):
                columns = line.strip().split(",")
                if user_col is None:
                    user_col = columns.index("userid")
                    # Every shard must be a self-describing CSV.
                    for shard_file in outfiles:
                        shard_file.write(line)
                    continue
                target = int(columns[user_col]) % num_shards
                outfiles[target].write(line)
    finally:
        # Close whatever was opened, even if sharding failed part-way.
        for shard_file in outfiles:
            shard_file.close()


def temporal_sort(input_fn, output_fn):
    """Sort a user-list CSV per user in update order and number the rows.

    Reads ``input_fn`` with every column as a string, casts ``userid`` and
    ``update_order`` to int and ``updated_at`` to float, and sorts rows by
    ``(userid, update_order, updated_at)``. The result is written to
    ``output_fn`` (no index column) with two columns appended:

    * ``unit``  -- constant 1; not needed for the computation any more but
      kept so the output schema stays the same as before.
    * ``order`` -- reverse position of the row within its user: the user's
      last row gets 1, the first row gets the user's row count.

    Args:
        input_fn: path of the CSV to read; must contain the three key columns.
        output_fn: path of the CSV to write.
    """
    df = pd.read_csv(input_fn, dtype=str)
    df = df.astype({"userid": int, "update_order": int, "updated_at": float})
    df = df.sort_values(by=["userid", "update_order", "updated_at"]).reset_index(
        drop=True
    )
    df["unit"] = 1
    # cumcount is index-aligned and vectorised, so the result does not depend
    # on how groupby.apply happens to concatenate its per-group outputs.
    df["order"] = df.groupby("userid").cumcount(ascending=False) + 1
    df.to_csv(output_fn, index=False)


def sharded_temporal_sort(input_fn, output_fn, num_shards=16):
    """Run ``temporal_sort`` shard by shard and stitch the results together.

    ``input_fn`` is first split by user into ``num_shards`` shard files, each
    shard is sorted on its own, and the sorted shards are concatenated in
    shard order into ``output_fn`` with a single header row. All intermediate
    ``.shard.N`` files are deleted as soon as they have been consumed.
    """
    shard_by_user(input_fn, num_shards)

    # Sort each shard independently, dropping the unsorted shard afterwards.
    for shard_idx in tqdm(range(num_shards)):
        unsorted_shard = f"{input_fn}.shard.{shard_idx}"
        temporal_sort(unsorted_shard, f"{output_fn}.shard.{shard_idx}")
        os.remove(unsorted_shard)

    # Concatenate the sorted shards into the final file.
    with open(output_fn, "w") as outfile:
        for shard_idx in tqdm(range(num_shards)):
            sorted_shard = f"{output_fn}.shard.{shard_idx}"
            with open(sorted_shard) as infile:
                for line_no, line in enumerate(infile):
                    # Line 0 of every shard is its header; only the first
                    # shard's copy makes it into the merged output.
                    if line_no > 0 or shard_idx == 0:
                        outfile.write(line)
            os.remove(sorted_shard)

In [None]:
# Sort the processed user list (read from source_dir) and write the sorted
# copy into outdir, where generate_splits later reads it.
sharded_temporal_sort(
    os.path.join(source_dir, f"user_{media}_list.csv"),
    os.path.join(outdir, f"user_{media}_list_sorted.csv"),
)

# Generate splits!

In [None]:
# Split names, indexed by ``userid % len(ALL_SPLITS)``.
ALL_SPLITS = ["training", "validation", "test"]


def get_split(userid):
    """Return the split name a user is assigned to, based on ``userid`` modulo
    the number of splits."""
    bucket = userid % len(ALL_SPLITS)
    return ALL_SPLITS[bucket]

In [None]:
def get_cutoff(days):
    """Return the timestamp threshold lying ``days`` days before the knowledge cutoff.

    Reads ``min_timestamp`` and ``max_timestamp`` (ints, first two lines of
    ``timestamps.csv``) and ``knowledge_cutoff`` (float, first line of
    ``knowledge_cutoff.csv``) from ``source_dir``. ``days`` is converted to
    seconds and divided by ``max_timestamp - min_timestamp`` before being
    subtracted from the knowledge cutoff.
    NOTE(review): this implies ``knowledge_cutoff`` and ``updated_at`` are
    expressed as a fraction of that timestamp range -- confirm against the
    code that produces the processed data.

    Args:
        days: number of days before the knowledge cutoff.

    Returns:
        The cutoff as a float, in the same units as ``knowledge_cutoff``.

    Raises:
        ValueError: if a metadata line is not of the form ``<field>,<value>``
            with the expected field name, or the value cannot be converted.
    """

    def parse_line(file, field, convert=int):
        # Each metadata line is expected to look like "<field>,<value>".
        line = file.readline()
        fields = line.strip().split(",")
        # Explicit check instead of assert: asserts vanish under `python -O`.
        if len(fields) != 2 or fields[0] != field:
            raise ValueError(f"expected a '{field},<value>' line, got {line!r}")
        return convert(fields[1])

    with open(os.path.join(source_dir, "timestamps.csv")) as f:
        min_timestamp = parse_line(f, "min_timestamp")
        max_timestamp = parse_line(f, "max_timestamp")

    with open(os.path.join(source_dir, "knowledge_cutoff.csv")) as f:
        knowledge_cutoff = parse_line(f, "knowledge_cutoff", float)

    seconds_in_day = 24 * 60 * 60
    return knowledge_cutoff - days * seconds_in_day / (max_timestamp - min_timestamp)

In [None]:
def generate_splits(input_fn, output_fn, params):
    """Distribute the rows of a sorted user-list CSV over the split files.

    Reads ``input_fn`` from ``outdir`` and writes one file per entry of
    ``ALL_SPLITS`` named ``f"{split}.{output_fn}"`` (also in ``outdir``).
    The header line is copied to every output file. A data row goes to the
    training file when its ``updated_at`` is below the cutoff or its
    ``order`` exceeds the interaction limit; otherwise it goes to the file
    of the split returned by ``get_split`` for the row's ``userid``.

    Lines are split on plain commas (no CSV quoting is interpreted).

    Args:
        input_fn: file name (relative to ``outdir``) of the sorted CSV; must
            have ``updated_at``, ``userid`` and ``order`` columns.
        output_fn: suffix used to build the per-split output file names.
        params: sequence ``(cutoff, num_interactions)``.
    """
    # Loop-invariant, so read them once instead of once per row.
    cutoff = params[0]
    num_interactions = params[1]

    files = {}
    try:
        for split in ALL_SPLITS:
            files[split] = open(os.path.join(outdir, f"{split}.{output_fn}"), "w")
        with open(os.path.join(outdir, input_fn)) as f:
            header = False
            for line in tqdm(f):
                fields = line.strip().split(",")
                if not header:
                    header = True
                    timestamp_col = fields.index("updated_at")
                    user_col = fields.index("userid")
                    order_col = fields.index("order")
                    for g in files.values():
                        g.write(line)
                    continue

                userid = int(fields[user_col])
                timestamp = float(fields[timestamp_col])
                order = int(fields[order_col])
                train = (timestamp < cutoff) or (order > num_interactions)

                if train:
                    files["training"].write(line)
                else:
                    files[get_split(userid)].write(line)
    finally:
        # Always flush and close the outputs, even if a row fails to parse.
        for g in files.values():
            g.close()

In [None]:
# An interaction goes into the test split if the user is in the test split
# AND the interaction is one of the user's N most recent watches
# AND the interaction occurred less than M days before the knowledge cutoff.
# N is num_interactions and M is num_days below.
num_days = 7
num_interactions = 5
generate_splits(
    f"user_{media}_list_sorted.csv",
    f"user_{media}_list.csv",
    (get_cutoff(num_days), num_interactions),
)