# Split the data into training/validation/test/negative sets
* The explicit and implicit rating data are each split into a training set and a test set, using a 90%/10% ratio
* The training set is further split into a training set and a validation set, using the same 90/10 ratio
* The userids and itemids for the validation and test splits are clipped such that the maximum userid and the maximum itemid occur within the training set (this will simplify array indexing later on).
* A negative set is sampled uniformly at random. This set consists of (user, item) pairs that the user did not watch.

In [1]:
import os
import random

import pandas as pd
from tqdm import tqdm

In [2]:
# directory holding the input `user_explicit_lists.csv` and
# `user_implicit_lists.csv` files that the splits are built from
source_dir = "../../data/processed_data"

In [3]:
# directory that receives every generated split file
outdir = "../../data/splits"
# makedirs with exist_ok avoids the check-then-create race and also creates
# any missing parent directory
os.makedirs(outdir, exist_ok=True)

In [4]:
# fixed seed so that the random train/test assignment and the negative
# sampling below are reproducible from run to run
random.seed(20220128)

In [5]:
def split(input_fn, train_fn, test_fn, p_training=0.9):
    """Randomly partition the rows of `input_fn` into two files.

    The first line of `input_fn` is treated as a header and copied to both
    outputs. Every following line is written to `train_fn` with probability
    `p_training` and to `test_fn` otherwise. Draws come from the module-level
    `random` generator, so the result depends on its current state.

    Args:
        input_fn: path of the file to read.
        train_fn: path of the training file to (over)write.
        test_fn: path of the test file to (over)write.
        p_training: probability that a row is assigned to the training file.
    """
    with open(input_fn, "r") as source, open(train_fn, "w") as train_out, open(
        test_fn, "w"
    ) as test_out:
        for row_number, row in enumerate(tqdm(source)):
            if row_number == 0:
                # header row: both outputs need it
                train_out.write(row)
                test_out.write(row)
                continue
            destination = train_out if random.random() < p_training else test_out
            destination.write(row)

In [6]:
def clip(input_fn, output_fn, condition):
    """Move every data row of `input_fn` that fails `condition` to `output_fn`.

    The first line of `input_fn` is treated as a header: it always stays in
    `input_fn` and is never passed to `condition` or written to `output_fn`.
    Each remaining line is kept in `input_fn` when `condition(line)` is truthy
    and is otherwise appended to `output_fn`.

    Args:
        input_fn: path of the file to filter; it is rewritten in place.
        output_fn: path of the file that receives the rejected rows. It is
            opened in append mode, so its existing content is preserved.
        condition: callable taking a raw line (including its newline) and
            returning a truthy value for rows that should stay in `input_fn`.
    """
    tmp_fn = input_fn + "~"
    with open(input_fn, "r") as in_file, open(tmp_fn, "w") as kept, open(
        output_fn, "a"
    ) as moved:
        for index, line in enumerate(tqdm(in_file)):
            if index == 0 or condition(line):
                kept.write(line)
            else:
                moved.write(line)
    # Swap the filtered copy in only after all handles are closed and flushed.
    # os.replace overwrites an existing target on every platform, whereas
    # os.rename raises on Windows when the target exists.
    os.replace(tmp_fn, input_fn)

In [7]:
def get_max_ids(input_fns):
    """Return the largest user id and item id found across `input_fns`.

    Each file is read as comma-separated text whose first line is a header
    and whose remaining lines have exactly three fields; the first two are
    parsed as integers (user id, item id).

    Args:
        input_fns: iterable of file paths to scan.

    Returns:
        A `(max_userid, max_itemid)` tuple; both are -1 when no data row
        was read.
    """
    largest_user = -1
    largest_item = -1
    for path in input_fns:
        with open(path, "r") as handle:
            for row_number, row in enumerate(tqdm(handle)):
                if row_number == 0:
                    # skip the header row
                    continue
                user_field, item_field, _ = row.split(",")
                largest_user = max(largest_user, int(user_field))
                largest_item = max(largest_item, int(item_field))
    return largest_user, largest_item

In [8]:
def get_user_item_pairs(input_fns):
    """Collect, for every user, the set of item ids listed in `input_fns`.

    Each file is read as comma-separated text whose first line is a header
    and whose remaining lines have exactly three fields; the first two are
    parsed as integers (user id, item id).

    Args:
        input_fns: iterable of file paths to scan.

    Returns:
        A dict mapping each user id to the set of item ids seen for it.
    """
    items_by_user = {}
    for path in input_fns:
        with open(path, "r") as handle:
            for row_number, row in enumerate(tqdm(handle)):
                if row_number == 0:
                    # skip the header row
                    continue
                user_field, item_field, _ = row.split(",")
                items_by_user.setdefault(int(user_field), set()).add(
                    int(item_field)
                )
    return items_by_user

In [9]:
def length(input_fn):
    """Return the number of data rows in `input_fn`, not counting the header.

    The first line is treated as a header; an empty file yields 0.
    """
    with open(input_fn, "r") as handle:
        total_lines = sum(1 for _ in tqdm(handle))
    # drop the header line, but never report a negative count for an empty file
    return max(total_lines - 1, 0)

In [10]:
def sample_negative_set(
    user_item_pairs, max_userid, max_itemid, sample_size, output_fn
):
    """Write `sample_size` distinct unseen (user, item) pairs to `output_fn`.

    Users are drawn uniformly from `0..max_userid` and items from
    `0..max_itemid`, both inclusive. A draw is rejected when the item is
    already recorded for that user in `user_item_pairs`. Accepted pairs are
    written as `user,item,0` rows beneath a `user,item,rating` header.

    Note that `user_item_pairs` is mutated: every accepted pair is added to
    it so that the same pair is never emitted twice. The loop only ends once
    `sample_size` pairs have been accepted, so it does not terminate if
    fewer unseen pairs than that exist.

    Args:
        user_item_pairs: dict mapping user id to the set of item ids that
            must not be sampled for that user.
        max_userid: largest user id that may be sampled.
        max_itemid: largest item id that may be sampled.
        sample_size: number of pairs to write.
        output_fn: path of the file to (over)write.
    """
    with open(output_fn, "w") as out_file:
        out_file.write("user,item,rating\n")
        with tqdm(total=sample_size) as pbar:
            remaining = sample_size
            while remaining > 0:
                # random.randint includes both endpoints, so the upper bound
                # is the maximum id itself; subtracting one would make the
                # largest user and item impossible to sample
                user = random.randint(0, max_userid)
                item = random.randint(0, max_itemid)
                seen_items = user_item_pairs.setdefault(user, set())
                if item in seen_items:
                    continue
                seen_items.add(item)
                out_file.write(f"{user},{item},0\n")
                remaining -= 1
                pbar.update(1)

## Construct training/validation/test splits

In [11]:
def generate_splits(feature):
    """Create the training, validation and test files for one feature.

    The source list for `feature` is split into a train+validation file and
    a test file; the former is split again into training and validation
    files and then deleted. Finally, every validation or test row whose user
    id or item id exceeds the maximum found in the training file is moved
    into the training file.

    Args:
        feature: either "explicit" or "implicit".
    """
    assert feature in {"explicit", "implicit"}

    def split_path(name):
        # location of one of this feature's files inside the output directory
        return os.path.join(outdir, f"{feature}_{name}.csv")

    source_fn = os.path.join(source_dir, f"user_{feature}_lists.csv")
    trainval_fn = split_path("trainval")
    training_fn = split_path("training")
    validation_fn = split_path("validation")
    test_fn = split_path("test")

    split(source_fn, trainval_fn, test_fn)
    split(trainval_fn, training_fn, validation_fn)
    os.remove(trainval_fn)

    # the limits are taken before any row is moved into the training file
    user_limit, item_limit = get_max_ids([training_fn])

    def within_training_ids(row):
        user_field, item_field, _ = row.split(",")
        if int(user_field) > user_limit:
            return False
        return int(item_field) <= item_limit

    # rows that fail the check are appended to the training file
    for held_out_fn in (validation_fn, test_fn):
        clip(held_out_fn, training_fn, within_training_ids)

In [12]:
# The call order matters for reproducibility: both calls consume the
# module-level `random` state seeded earlier in the notebook, so reordering
# them changes which rows land in each split.
generate_splits("explicit")
generate_splits("implicit")

168626003it [01:15, 2238288.04it/s]
151763170it [01:09, 2173577.91it/s]
136583943it [02:20, 973885.12it/s] 
15179228it [00:13, 1111835.63it/s]
16862834it [00:15, 1057356.98it/s]
61471672it [00:27, 2260328.50it/s]
55322163it [00:25, 2182538.43it/s]
49787458it [00:52, 954731.22it/s] 
5534706it [00:05, 1058907.45it/s]
6149510it [00:05, 1057751.87it/s]


## Construct the negative test split

In [13]:
# Every (user, item) pair present in either source list is excluded from
# negative sampling, regardless of which split it ended up in.
user_items = get_user_item_pairs(
    [
        os.path.join(source_dir, f"user_{feature}_lists.csv")
        for feature in ("explicit", "implicit")
    ]
)
# The sampling range is bounded by the largest ids in the training splits.
max_userid, max_itemid = get_max_ids(
    [
        os.path.join(outdir, f"{feature}_training.csv")
        for feature in ("explicit", "implicit")
    ]
)
# One negative sample is drawn per row across the two test splits.
num_negative_samples = sum(
    map(
        length,
        [
            os.path.join(outdir, f"{feature}_test.csv")
            for feature in ("explicit", "implicit")
        ],
    )
)
sample_negative_set(
    user_item_pairs=user_items,
    max_userid=max_userid,
    max_itemid=max_itemid,
    sample_size=num_negative_samples,
    output_fn=os.path.join(outdir, "negative.csv"),
)

168626003it [04:29, 626164.80it/s] 
61471672it [03:06, 330120.87it/s]
136583943it [02:21, 964650.75it/s]
49787458it [00:51, 965288.95it/s]
16862834it [00:04, 3826675.19it/s]
6149510it [00:01, 3825311.77it/s]
100%|████████████████████████████| 23012342/23012342 [07:02<00:00, 54411.42it/s]
