# Prune Media Lists
* Drop any duplicate rows
* Drop any users with 3 or fewer item interactions
* Drop any unparseable rows

In [None]:
import gc
import os
from functools import cache

import numpy as np
import pandas as pd
import yaml
from tqdm import tqdm

In [None]:
# Partition index; interpolated into the user-list and error file names below.
part = 0

In [None]:
# Directory holding the media/user-list CSVs; error files are written here too.
outdir = "../../data/raw_data"

In [None]:
# Column names of the user-list CSVs; filled in by process() from the first
# header line it reads.
HEADER_FIELDS = []
# Media types whose user lists are pruned.
ALL_MEDIUMS = ["manga", "anime"]

In [None]:
@cache
def get_col_id(name):
    """Return the zero-based position of column ``name`` in ``HEADER_FIELDS``.

    Results are memoized per column name, so the header is only searched
    once for each name. Raises ``ValueError`` (from ``list.index``) when the
    name is not present in the header.
    """
    column_names = HEADER_FIELDS
    return column_names.index(name)

In [None]:
def process(media, remove_line, error_file):
    """Filter ``user_{media}_list.{part}.csv`` in place, one line at a time.

    The first line is treated as the CSV header: it is copied through
    unchanged and its comma-separated fields are stored in the global
    ``HEADER_FIELDS``. Every later line is handed to
    ``remove_line(media, line)``; lines for which it returns a truthy value
    are written to ``error_file`` (created inside ``outdir``), all other
    lines are kept. Kept lines go to a temporary ``.csv~`` file which
    replaces the source file after the whole input has been read.

    Args:
        media: Medium name used to build the list file name.
        remove_line: Callable ``(media, line) -> bool``; truthy means drop.
        error_file: File name (relative to ``outdir``) for dropped lines.

    Raises:
        AssertionError: If the header differs from one recorded earlier.
        Exception: Anything raised while filtering a line; the offending
            line is printed before the exception propagates.
    """
    global HEADER_FIELDS
    source = os.path.join(outdir, f"user_{media}_list.{part}.csv")
    dest = os.path.join(outdir, f"user_{media}_list.{part}.csv~")
    error_path = os.path.join(outdir, error_file)
    with open(source, "r") as in_file, open(dest, "w") as out_file, open(
        error_path, "w"
    ) as err_file:
        for line_number, line in enumerate(tqdm(in_file)):
            if line_number == 0:
                # Header line: record the column names and pass it through.
                fields = line.strip().split(",")
                if HEADER_FIELDS:
                    # All list files must share one schema because
                    # get_col_id() caches column positions.
                    assert HEADER_FIELDS == fields
                HEADER_FIELDS = fields
                out_file.write(line)
                continue
            try:
                if remove_line(media, line):
                    err_file.write(line)
                else:
                    out_file.write(line)
            except Exception:
                # Show the line that triggered the failure, then re-raise
                # the original exception untouched.
                print(line)
                raise
    # os.replace overwrites the existing source file on every platform,
    # whereas os.rename fails on Windows when the destination exists.
    os.replace(dest, source)

In [None]:
def remove_unmatched_titles(media, line, valid_titles):
    """Report whether ``line`` references a title absent from ``valid_titles``.

    The ``mediaid`` column of the comma-separated ``line`` is parsed as an
    integer and looked up in ``valid_titles``. ``media`` is not used here.

    Returns:
        True if the row should be dropped (unknown title), False otherwise.
    """
    cells = line.strip().split(",")
    media_id = int(cells[get_col_id("mediaid")])
    is_known = media_id in valid_titles
    return not is_known

In [None]:
def remove_duplicates(media, line, user_to_uid, seen_items):
    """Report whether the (user, item) pair on ``line`` was already seen.

    Users are mapped to compact ``np.int32`` ids through ``user_to_uid``
    (new users get the next unused id), and the ``mediaid`` column is
    converted to ``np.int32``. The resulting pair is recorded in
    ``seen_items`` the first time it appears. ``media`` is not used here.

    Returns:
        False for the first occurrence of a pair, True for any repeat.
    """
    cells = line.strip().split(",")
    user_idx = get_col_id("userid")
    item_idx = get_col_id("mediaid")

    raw_user = cells[user_idx]
    try:
        uid = user_to_uid[raw_user]
    except KeyError:
        # First time this user appears: assign the next sequential id.
        uid = user_to_uid[raw_user] = np.int32(len(user_to_uid))

    pair = (uid, np.int32(cells[item_idx]))
    already_seen = pair in seen_items
    if not already_seen:
        seen_items.add(pair)
    return already_seen

In [None]:
def count_col(media, line, counts, key_to_uid, name):
    """Tally one occurrence of the value in column ``name`` of ``line``.

    ``counts`` maps each column value to the number of lines seen with it.
    The first time a value appears it is also given an entry in
    ``key_to_uid`` equal to that dict's size at the time. ``media`` is not
    used here.

    Returns:
        Always False, so that ``process`` keeps every line during counting.
    """
    value = line.strip().split(",")[get_col_id(name)]
    first_time = value not in counts
    if first_time:
        key_to_uid[value] = len(key_to_uid)
    counts[value] = counts.get(value, 0) + 1
    return False


def remove_sparse_col(media, line, counts, key_to_uid, name, N):
    """Report whether the value in column ``name`` occurs fewer than ``N`` times.

    The occurrence count is read from ``counts``; a value missing from
    ``counts`` raises ``KeyError``. ``media`` and ``key_to_uid`` are not
    used here.

    Returns:
        True if the row should be dropped (count below ``N``), else False.
    """
    cells = line.strip().split(",")
    value = cells[get_col_id(name)]
    occurrences = counts[value]
    return occurrences < N

In [None]:
def get_settings():
    """Load and merge the project settings from the environment directory.

    ``default_settings.yml`` is read first and ``private_settings.yml``
    second, so keys in the private file override the defaults.

    Returns:
        A dict with the merged settings.
    """
    settings = {}
    for name in ("default_settings", "private_settings"):
        with open(f"../../environment/{name}.yml", "r") as f:
            loaded = yaml.safe_load(f)
        # safe_load yields None for an empty file; skip it rather than
        # failing on ``dict |= None``.
        if loaded is not None:
            settings |= loaded
    return settings

In [None]:
def prune_media():
    """Drop list rows whose ``mediaid`` is not in the medium's title CSV.

    For each medium, the set of valid ids is read from the ``{media}_id``
    column of ``{outdir}/{media}.csv``; rejected rows are written to
    ``prune.{media}.unmatched.{part}.csv``.
    """
    for media in ALL_MEDIUMS:
        catalog = pd.read_csv(f"{outdir}/{media}.csv")
        valid_titles = set(catalog[f"{media}_id"])

        def is_unmatched(media, line):
            return remove_unmatched_titles(media, line, valid_titles)

        process(media, is_unmatched, f"prune.{media}.unmatched.{part}.csv")

In [None]:
def prune_duplicates():
    """Drop repeated (user, item) rows from every medium's list file.

    The user-id mapping is shared across media, while the set of seen
    pairs starts empty for each medium. Rejected rows are written to
    ``prune.{media}.duplicates.{part}.csv``.
    """
    user_to_uid = {}
    for media in ALL_MEDIUMS:
        seen_items = set()

        def is_duplicate(media, line):
            return remove_duplicates(media, line, user_to_uid, seen_items)

        process(media, is_duplicate, f"prune.{media}.duplicates.{part}.csv")

In [None]:
def prune_sparse():
    """Drop rows belonging to users, then items, with too few interactions.

    Users are counted across all media together and compared against the
    ``min_user_interactions`` setting. Items are counted separately for each
    medium and compared against ``min_item_interactions``. Each pass first
    counts occurrences (error file ``prune.{media}.empty.{part}.csv``) and
    then removes rows below the threshold (error file
    ``prune.{media}.sparse.{part}.csv``).
    """
    thresholds = (
        ("userid", "min_user_interactions"),
        ("mediaid", "min_item_interactions"),
    )
    for column, setting in thresholds:
        N = get_settings()[setting]
        if column == "userid":
            # One shared tally spanning every medium.
            media_groups = [list(ALL_MEDIUMS)]
        else:
            # A separate tally for each medium.
            media_groups = [[medium] for medium in ALL_MEDIUMS]

        for group in media_groups:
            counts = {}
            key_to_uid = {}

            def tally(media, line):
                return count_col(media, line, counts, key_to_uid, column)

            def is_sparse(media, line):
                return remove_sparse_col(media, line, counts, key_to_uid, column, N)

            for media in group:
                process(media, tally, f"prune.{media}.empty.{part}.csv")
            for media in group:
                process(media, is_sparse, f"prune.{media}.sparse.{part}.csv")

In [None]:
# Remove rows whose mediaid is missing from the medium's title CSV.
prune_media()

In [None]:
# Remove repeated (user, item) rows.
prune_duplicates()

In [None]:
# Remove rows for users and items below the configured interaction minimums.
prune_sparse()