# Generate Splits
* We create datasets for two tasks: random item prediction and next item prediction
* Users are deterministically sharded between the tasks by user id (user id modulo the number of tasks), and each task
  has its own training, validation and test split
* Each split is further partitioned by interaction type: whether the user rated the item, watched the item,
  or put the item on their plan-to-watch list; rows whose item never appears in a training split are set aside as invalid

In [None]:
import math
import os
import random

import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
# Directory holding the processed inputs (user_{media}_list.csv, {media}_processing_encodings.csv);
# the {media}_uid_encoding.csv produced at the end of this notebook is also written here.
source_dir = "../../data/processed_data"

In [None]:
# Destination for the sorted interaction list and every task/split/content file generated below.
outdir = "../../data/splits"
os.makedirs(outdir, exist_ok=True)

In [None]:
# Fix the RNG so the per-row random.random() draws in generate_splits are reproducible across runs.
random.seed(20220128)

In [None]:
# Media-type prefix used in every file name (f"user_{media}_list.csv", ...) and in the item-id
# column name (f"{media}id"). Left empty here — presumably injected by notebook parameterization
# before running; TODO confirm, since an empty value yields names like "user__list.csv" and "id".
media = ""

# Sort each user's interactions by timestamp

In [None]:
def shard_by_user(file, num_shards):
    """Split a CSV into ``num_shards`` files keyed on its integer ``username`` column.

    A data row goes to ``f"{file}.shard.{int(username) % num_shards}"``, and every
    shard receives a copy of the header line. Rows are split with a plain
    ``split(",")``, so quoted fields that contain commas are not supported.
    """
    shard_files = []
    try:
        for shard_idx in range(num_shards):
            shard_files.append(open(f"{file}.shard.{shard_idx}", "w"))
        with open(file, "r") as source:
            rows = iter(tqdm(source))
            header_line = next(rows, None)
            if header_line is None:
                # Empty input: the shards were created but stay empty.
                return
            user_idx = header_line.strip().split(",").index("username")
            for shard in shard_files:
                shard.write(header_line)
            for row in rows:
                username = row.strip().split(",")[user_idx]
                shard_files[int(username) % num_shards].write(row)
    finally:
        # Close everything that was opened, even if a later open() or a parse failed.
        for shard in shard_files:
            shard.close()


def temporal_sort(input_fn, output_fn):
    """Sort interactions by (username, timestamp) and add a per-user recency rank.

    The appended ``order`` column counts down within each user: the user's most
    recent interaction gets 1, the one before it 2, and so on up to the user's
    interaction count for the earliest one.

    Args:
        input_fn: CSV path with at least ``username`` and ``timestamp`` columns.
        output_fn: CSV path the sorted frame is written to (index omitted).
    """
    df = pd.read_csv(input_fn)
    df = df.sort_values(by=["username", "timestamp"]).reset_index(drop=True)
    # cumcount(ascending=False) numbers each user's rows n-1 .. 0 in row order
    # (oldest -> newest after the sort), so +1 gives n .. 1 with 1 = most recent.
    # Vectorized: avoids a Python-level lambda per user and a scratch column.
    df["order"] = df.groupby("username").cumcount(ascending=False) + 1
    df.to_csv(output_fn, index=False)


def sharded_temporal_sort(input_fn, output_fn, num_shards=16):
    """Per-user temporal sort of a CSV too large to sort in one pandas frame.

    The input is split by user with ``shard_by_user``, and each shard is sorted with
    ``temporal_sort``. The sorted shards are then concatenated into ``output_fn``,
    keeping only the first shard's header. All intermediate shard files are deleted.

    Rows are grouped per user and time-ordered within each user. The output is not
    globally sorted by username, because shards are appended in shard order.
    """
    shard_by_user(input_fn, num_shards)
    for shard in tqdm(range(num_shards)):
        shard_in = f"{input_fn}.shard.{shard}"
        temporal_sort(shard_in, f"{output_fn}.shard.{shard}")
        os.remove(shard_in)
    with open(output_fn, "w") as merged:
        for shard in tqdm(range(num_shards)):
            shard_out = f"{output_fn}.shard.{shard}"
            with open(shard_out, "r") as part:
                header_line = part.readline()
                if shard == 0:
                    # Only the first shard contributes its header.
                    merged.write(header_line)
                merged.writelines(part)
            os.remove(shard_out)

In [None]:
# Produce outdir/user_{media}_list_sorted.csv: each user's interactions in timestamp order,
# annotated with a per-user recency rank in the "order" column.
sharded_temporal_sort(
    input_fn=os.path.join(source_dir, f"user_{media}_list.csv"),
    output_fn=os.path.join(outdir, f"user_{media}_list_sorted.csv"),
)

# Determine which task the user is assigned to

In [None]:
# Task names. generate_splits expects one training_params entry per task, in this order.
ALL_TASKS = ["temporal", "temporal_causal"]


def get_assignment(tasks, worker):
    """Deterministically pick ``tasks[worker mod len(tasks)]``."""
    _, slot = divmod(worker, len(tasks))
    return tasks[slot]


def get_split(userid):
    """Return the holdout split ("training", "validation" or "test") for a user id."""
    return get_assignment(tasks=["training", "validation", "test"], worker=userid)


def get_task(userid):
    """Return the task from ``ALL_TASKS`` that a user id is assigned to."""
    return get_assignment(tasks=ALL_TASKS, worker=userid)

# Generate task splits!

In [None]:
def get_temporal_percentage(test_months, encodings_fn=None):
    """Return the training cutoff that leaves the last ``test_months`` months for holdout.

    Reads ``min_timestamp`` and ``max_timestamp`` (in that order, one
    ``name,value`` pair per line) from the encodings file. It returns
    ``1 - test_months * month / (max_timestamp - min_timestamp)``, i.e. one minus
    the fraction of the timestamp range covered by ``test_months`` months. A month
    is taken as 2.628e6 seconds, so the stored timestamps are presumably in seconds.

    NOTE(review): generate_splits compares raw row timestamps against this value,
    which presumes the sorted list stores timestamps min-max normalized to [0, 1].
    TODO confirm against the processing step.

    Args:
        test_months: Length of the holdout window, in months (may be fractional).
        encodings_fn: Path to the encodings CSV. Defaults to
            ``{source_dir}/{media}_processing_encodings.csv``.

    Raises:
        ValueError: if the file does not start with the expected two lines, or if
            ``max_timestamp <= min_timestamp``.
    """
    if encodings_fn is None:
        encodings_fn = os.path.join(source_dir, f"{media}_processing_encodings.csv")

    with open(encodings_fn, "r") as in_file:

        def parse_line(field):
            line = in_file.readline()
            fields = line.strip().split(",")
            # Explicit checks instead of assert, so validation survives `python -O`.
            if len(fields) != 2 or fields[0] != field:
                raise ValueError(
                    f"expected a '{field},<int>' line in {encodings_fn}, got {line!r}"
                )
            return int(fields[1])

        min_timestamp = parse_line("min_timestamp")
        max_timestamp = parse_line("max_timestamp")

    span = max_timestamp - min_timestamp
    if span <= 0:
        raise ValueError(
            f"max_timestamp ({max_timestamp}) must exceed min_timestamp ({min_timestamp})"
        )
    seconds_in_month = 2.628e6  # average month: 365 / 12 days * 86400 s
    month = seconds_in_month / span
    return 1 - month * test_months

In [None]:
def generate_splits(input_fn, output_fn, training_params):
    """Route each interaction row to a task-specific training/validation/test file.

    Each row belongs to the task chosen by ``get_task(userid)``. A row stays in
    that task's training file when it satisfies the task's training rule.
    Otherwise it goes to the split chosen by ``get_split(userid)``, which may
    itself be "training". Output files are named
    ``f"{task}.{split}.{output_fn}"`` inside ``outdir``, and each starts with a
    copy of the input header line.

    Args:
        input_fn: File name (relative to ``outdir``) of the user-sorted list
            written by ``sharded_temporal_sort``. It must have ``timestamp``,
            ``username`` and ``order`` columns.
        output_fn: Suffix for the output file names.
        training_params: One entry per task in ``ALL_TASKS``, in the same order:
            * "temporal": a cutoff; rows with ``timestamp < cutoff`` are training.
            * "temporal_causal": ``(cutoff, k)``; rows with
              ``timestamp < cutoff`` or ``order > k`` are training. ``order`` is 1
              for a user's most recent interaction, so only the user's ``k`` most
              recent post-cutoff rows are held out.
            * "random": a probability compared against a uniform draw.

    Raises:
        ValueError: if ``training_params`` does not have one entry per task, or a
            task/split name is unknown.
    """
    tasks = ALL_TASKS
    if len(training_params) != len(tasks):
        raise ValueError(
            f"expected {len(tasks)} training_params (one per task in {tasks}), "
            f"got {len(training_params)}"
        )

    split_names = ["training", "validation", "test"]
    # outputs[split][i] is the open file for tasks[i]. The files are closed in
    # `finally`, so none leak if opening or parsing fails partway through.
    outputs = {split: [] for split in split_names}
    try:
        for split in split_names:
            for task in tasks:
                outputs[split].append(
                    open(os.path.join(outdir, f"{task}.{split}.{output_fn}"), "w")
                )

        with open(os.path.join(outdir, input_fn), "r") as in_file:
            header = False
            for line in tqdm(in_file):
                fields = line.strip().split(",")
                if not header:
                    header = True
                    timestamp_col = fields.index("timestamp")
                    user_col = fields.index("username")
                    order_col = fields.index("order")
                    for handles in outputs.values():
                        for handle in handles:
                            handle.write(line)
                    continue

                timestamp = float(fields[timestamp_col])
                userid = int(fields[user_col])
                order = int(fields[order_col])
                # Drawn for every row, even when unused, so the RNG stream (and
                # thus any "random" task split) stays reproducible.
                rand = random.random()

                task = get_task(userid)
                task_idx = tasks.index(task)
                params = training_params[task_idx]
                if task == "random":
                    train = rand < params
                elif task == "temporal":
                    train = timestamp < params
                elif task == "temporal_causal":
                    train = timestamp < params[0] or order > params[1]
                else:
                    raise ValueError(f"unknown task {task!r}")

                split = "training" if train else get_split(userid)
                if split not in outputs:
                    raise ValueError(f"unknown split {split!r}")
                outputs[split][task_idx].write(line)
    finally:
        for handles in outputs.values():
            for handle in handles:
                handle.close()

In [None]:
# temporal: hold out interactions from the last 1.5 months.
# temporal_causal: hold out interactions that are BOTH within the last 1.5 months
# AND among the user's 5 most recent interactions.
generate_splits(
    input_fn=f"user_{media}_list_sorted.csv",
    output_fn=f"user_{media}_list.csv",
    training_params=[
        get_temporal_percentage(1.5),
        (get_temporal_percentage(1.5), 5),
    ],
)

# Generate content splits!

In [None]:
def subset(input_fn, output_fn, condition):
    """Copy the header line of ``input_fn`` plus every later line where ``condition`` holds.

    ``condition`` receives each raw line, including its trailing newline. Matching
    lines are written to ``output_fn`` unchanged.
    """
    with open(input_fn, "r") as src, open(output_fn, "w") as dst:
        lines = iter(tqdm(src))
        first = next(lines, None)
        if first is None:
            return
        dst.write(first)
        dst.writelines(row for row in lines if condition(row))

In [None]:
def generate_content_splits(fn, valid_itemids):
    """Partition ``outdir/fn`` by interaction type into four mutually exclusive files.

    Each data line lands in exactly one output, checked in this order:
      * ``invalid.{fn}``: the ``{media}id`` is not in ``valid_itemids``.
      * ``ptw.{fn}``: ``status == "1"`` (the plan-to-watch bucket).
      * ``implicit.{fn}``: ``score`` parses to 0.
      * ``explicit.{fn}``: everything else (a nonzero score).
    Every output starts with the input's header line.

    The input is read once, and each line is split once. Previously the file was
    re-read for each of the four outputs.

    Args:
        fn: File name inside ``outdir``.
        valid_itemids: Collection of integer item ids considered valid.
    """
    input_path = os.path.join(outdir, fn)
    with open(input_path, "r") as in_file:
        header = in_file.readline()
        columns = header.strip().split(",")
        score_col = columns.index("score")
        status_col = columns.index("status")
        item_col = columns.index(f"{media}id")

        def classify(line):
            # Same short-circuit order as the original per-file predicates, so
            # e.g. a ptw row's score is never parsed.
            fields = line.strip().split(",")
            if int(fields[item_col]) not in valid_itemids:
                return "invalid"
            if fields[status_col] == "1":
                return "ptw"
            if float(fields[score_col]) == 0:
                return "implicit"
            return "explicit"

        out_files = {}
        try:
            for name in ("invalid", "ptw", "implicit", "explicit"):
                out_files[name] = open(os.path.join(outdir, f"{name}.{fn}"), "w")
                out_files[name].write(header)
            for line in tqdm(in_file):
                out_files[classify(line)].write(line)
        finally:
            for out_file in out_files.values():
                out_file.close()

In [None]:
def get_max_ids(input_fn):
    """Return ``(max_userid, max_itemid)`` over the data rows of a CSV.

    The ``username`` and ``{media}id`` columns are parsed as integers. ``(-1, -1)``
    is returned when the file has no data rows.
    """
    top_user, top_item = -1, -1
    with open(input_fn, "r") as src:
        rows = iter(tqdm(src))
        first = next(rows, None)
        if first is None:
            return top_user, top_item
        columns = first.strip().split(",")
        user_idx = columns.index("username")
        item_idx = columns.index(f"{media}id")
        for row in rows:
            cells = row.strip().split(",")
            top_user = max(top_user, int(cells[user_idx]))
            top_item = max(top_item, int(cells[item_idx]))
    return top_user, top_item

In [None]:
def get_item_ids(input_fns):
    """Collect the set of integer ``{media}id`` values found across the given CSVs.

    Each file's own header locates its item and ``timestamp`` columns. Rows whose
    timestamp is (close to) -1 are treated as corrupted and skipped. Their item is
    only included if it also appears in some other, valid row.
    """
    item_ids = set()
    for path in input_fns:
        with open(path, "r") as src:
            rows = iter(tqdm(src))
            first = next(rows, None)
            if first is None:
                continue
            columns = first.strip().split(",")
            item_idx = columns.index(f"{media}id")
            ts_idx = columns.index("timestamp")
            for row in rows:
                cells = row.strip().split(",")
                if not math.isclose(float(cells[ts_idx]), -1):
                    item_ids.add(int(cells[item_idx]))
    return item_ids

In [None]:
# Id ranges come from the full sorted list. Valid items are those seen in any
# task's training split.
max_userid, max_itemid = get_max_ids(
    input_fn=os.path.join(outdir, f"user_{media}_list_sorted.csv")
)
valid_itemids = get_item_ids(
    input_fns=[
        os.path.join(outdir, f"{task}.training.user_{media}_list.csv")
        for task in ALL_TASKS
    ]
)

# One "name,value" pair per line: the two max ids, then one line per valid item id.
with open(os.path.join(source_dir, f"{media}_uid_encoding.csv"), "w") as out_file:
    out_file.write(f"max_userid,{max_userid}\nmax_itemid,{max_itemid}\n")
    for uid in valid_itemids:
        out_file.write(f"valid_itemid,{uid}\n")

In [None]:
# Split every task/split file into invalid / ptw / implicit / explicit subsets.
for task in ALL_TASKS:
    for split in ("training", "validation", "test"):
        generate_content_splits(
            fn=f"{task}.{split}.user_{media}_list.csv",
            valid_itemids=valid_itemids,
        )