# Prune Media Lists
* Drop any duplicate rows
* Drop any users with 3 or fewer item interactions (counted across all media combined)
* Drop any rows whose media id does not match a known title in the parsed title table

In [None]:
import gc
import os

import numpy as np
import pandas as pd
from tqdm import tqdm

In [None]:
# Directory holding the per-user media list CSVs and the title tables.
# Relative path, so it is resolved against the notebook's working directory.
outdir = "../../data/raw_data"

In [None]:
# Column names parsed from the header row of the most recently processed list
# file. process() overwrites this; the per-line filter functions call
# HEADER_FIELDS.index(...) on it to locate columns by name.
HEADER_FIELDS = []

In [None]:
def process(media, remove_line, error_file):
    """Stream-filter ``user_{media}_list.csv`` (in ``outdir``) in place.

    The header row is always kept, and its comma-separated column names are
    stored in the module-level ``HEADER_FIELDS``. This lets ``remove_line``
    locate columns by name.

    Every other line is passed to ``remove_line(media, line)``:
    - lines for which it returns a truthy value are written to ``error_file``;
    - all other lines are kept.

    Output goes to a temporary ``user_{media}_list.csv~`` file. That file
    replaces the source only after the whole input was processed successfully.

    Args:
        media: Media type; selects ``user_{media}_list.csv``.
        remove_line: Predicate ``(media, line) -> bool``; truthy drops the line.
        error_file: File name, relative to ``outdir``, that receives dropped lines.

    Raises:
        Whatever ``remove_line`` raises. The offending line is printed first.
        On any failure the source file is left untouched and the temporary
        file is removed.
    """
    global HEADER_FIELDS
    source = os.path.join(outdir, f"user_{media}_list.csv")
    dest = os.path.join(outdir, f"user_{media}_list.csv~")
    error_path = os.path.join(outdir, error_file)
    try:
        with open(source, "r") as in_file, open(dest, "w") as out_file, open(
            error_path, "w"
        ) as err_file:
            header = in_file.readline()
            if header:
                HEADER_FIELDS = header.strip().split(",")
                out_file.write(header)
            for line in tqdm(in_file):
                try:
                    drop = remove_line(media, line)
                except Exception:
                    # Show which line broke the filter before propagating.
                    print(line)
                    raise
                if drop:
                    err_file.write(line)
                else:
                    out_file.write(line)
    except BaseException:
        # Don't leave a half-written temp file next to the untouched source.
        # Best effort: the temp file may not exist if opening the source failed.
        try:
            os.remove(dest)
        except OSError:
            pass
        raise
    # os.replace overwrites an existing destination on every platform,
    # whereas os.rename fails on Windows when the destination exists.
    os.replace(dest, source)

In [None]:
def remove_unmatched_titles(media, line):
    """Return True (drop the row) when its ``mediaid`` is not a known title.

    ``media`` is unused; it is accepted only to match the ``process`` callback
    signature. The function reads two module-level names:
    ``HEADER_FIELDS`` to find the column, and ``valid_titles`` for the set of
    accepted ids.
    """
    media_idx = HEADER_FIELDS.index("mediaid")
    media_id = int(line.strip().split(",")[media_idx])
    return media_id not in valid_titles

In [None]:
def remove_duplicates(media, line, partition):
    """Flag repeated (user, mediaid) rows for users in the current shard.

    Users receive sequential ``np.int32`` ids in the global ``user_to_uid``
    the first time they are seen.

    Only users whose id satisfies ``uid % partition[1] == partition[0]`` are
    checked; rows of every other user are kept. For checked users, the first
    row of each (uid, mediaid) pair is kept and recorded in the global
    ``seen_items``. Later rows with the same pair return True (drop).

    Args:
        media: Unused; matches the ``process`` callback signature.
        line: Raw CSV line.
        partition: ``(shard_index, shard_count)``.
    """
    cols = line.strip().split(",")
    uid_pos = HEADER_FIELDS.index("userid")
    mid_pos = HEADER_FIELDS.index("mediaid")
    name = cols[uid_pos]
    uid = user_to_uid.get(name)
    if uid is None:
        # Next sequential id (len is taken before the insertion).
        uid = user_to_uid[name] = np.int32(len(user_to_uid))
    shard_index, shard_count = partition[0], partition[1]
    if uid % shard_count != shard_index:
        return False
    pair = (uid, np.int32(cols[mid_pos]))
    if pair in seen_items:
        return True
    seen_items.add(pair)
    return False

In [None]:
def count_users(media, line):
    """Tally one interaction for the row's user; never drops the row.

    Increments the global ``user_counts[userid]``. The first time a user
    appears, they are also assigned the next sequential id in the global
    ``user_to_uid``. ``media`` is unused.

    Always returns False so ``process`` keeps every line.
    """
    name = line.strip().split(",")[HEADER_FIELDS.index("userid")]
    if name not in user_counts:
        user_to_uid[name] = len(user_to_uid)
    user_counts[name] = user_counts.get(name, 0) + 1
    return False

In [None]:
def remove_sparse_users(media, line, N=4):
    """Return True (drop) when the row's user has fewer than ``N`` interactions.

    Looks the user up in the global ``user_counts``, which ``count_users``
    must have filled beforehand; an unseen user raises ``KeyError``. With
    the default ``N=4``, users with 3 or fewer rows are dropped.
    ``media`` is unused.
    """
    user_idx = HEADER_FIELDS.index("userid")
    interactions = user_counts[line.strip().split(",")[user_idx]]
    return interactions < N

In [None]:
# Media types to prune. Each has a user_{media}_list.csv and a {media}.csv
# title table (with a {media}_id column) in outdir.
ALL_MEDIUMS = ["manga", "anime"]

In [None]:
# Drop list rows whose media id is not in that medium's title table.
# Only the id column is loaded, so the rest of the (potentially wide) title
# table never has to fit in memory.
for media in ALL_MEDIUMS:
    id_col = f"{media}_id"
    title_path = os.path.join(outdir, f"{media}.csv")
    # valid_titles is read as a global by remove_unmatched_titles.
    valid_titles = set(pd.read_csv(title_path, usecols=[id_col])[id_col])
    process(media, remove_unmatched_titles, f"prune.{media}.unmatched.csv")

In [None]:
# Drop duplicate (user, media) rows. Users are split by their sequential id
# into num_partitions passes, so only one shard's seen pairs are held in memory
# at a time. user_to_uid is shared across all media, so a user keeps the same
# id (and shard) in every file.
user_to_uid = {}
num_partitions = 2  # shard data to reduce memory pressure
for media in ALL_MEDIUMS:
    for partition in range(num_partitions):
        seen_items = set()
        gc.collect()
        # Bind this pass's shard eagerly (default argument) rather than
        # closing over the loop variable.
        shard = (partition, num_partitions)
        remove_duplicates_fn = lambda media, line, shard=shard: remove_duplicates(
            media, line, shard
        )
        process(media, remove_duplicates_fn, f"prune.{media}.duplicates.{partition}.csv")
# Release the final shard's seen pairs so they don't linger into later cells.
seen_items = set()
gc.collect()

In [None]:
# Two passes over every medium.
# 1. Tally each user's interactions across all media. The count pass never
#    drops a row, so its "empty" error files stay empty.
# 2. Drop rows belonging to users with fewer than four interactions in total.
user_counts, user_to_uid = {}, {}
for pass_fn, err_name in (
    (count_users, "prune.{}.empty.csv"),
    (remove_sparse_users, "prune.{}.sparse.csv"),
):
    for media in ALL_MEDIUMS:
        process(media, pass_fn, err_name.format(media))