In [3]:
import pandas as pd
from sklearn.model_selection import train_test_split

from circleguard import KeylessCircleguard, ReplayDir
from slider import Library

from utils.replay_processing import filter_replay, validate_replay

In [None]:
# Seed for reproducibility; the (commented-out) train/val/test split cell passes it
# to train_test_split as `random_state`.
RANDOM_SEED = 12
# Number of replays per batch in the (commented-out) batch-validation loop.
BATCH_SIZE = 1024

In [None]:
# Filesystem layout, relative to the notebook's working directory.
# Paths stay plain "/"-separated strings with the same values as before; entries
# that live under (or alias) another path are derived from it so they cannot
# silently drift apart when one of them is edited.
DATA_DIR = "../data"

INDEX_PATH = f"{DATA_DIR}/indices/base_index.csv"  # CSV read into index_df
BEATMAP_PATH = f"{DATA_DIR}/beatmaps"              # directory handed to slider's Library
REPLAY_PATH = f"{DATA_DIR}/replays/osr"            # directory handed to ReplayDir
DB_PATH = f"{REPLAY_PATH}/.circleguard.db"         # circleguard db_path, kept inside REPLAY_PATH
SLIDER_PATH = BEATMAP_PATH                         # circleguard slider_dir; same directory as BEATMAP_PATH

In [None]:
# Load the replay index CSV and open the beatmap library, then report their sizes.

# low_memory = False: per the pandas docs, parse the file in one go instead of in
# internal chunks, which avoids mixed-dtype inference on columns of a large CSV.
index_df = pd.read_csv(INDEX_PATH, low_memory = False)
# slider Library rooted at BEATMAP_PATH; build_valid_index_df takes such a library
# as its `beatmap_library` argument for md5 lookups.
beatmap_library = Library(BEATMAP_PATH)

print(f"Num. replays in index: {len(index_df)}")
print(f"Num. of beatmaps in library: {len(beatmap_library.ids)}")
# Last expression of the cell: the first rows of the index become the cell output.
index_df.head()

In [4]:
def get_clean_mod_str(replay):
    """Return a canonical mod label for ``replay``.

    Only EZ, HD, DT and HR are considered, each detected by substring search in
    ``str(replay.mods)``; NC is reported as DT. Detected mods are concatenated in
    the fixed order EZ, HD, DT, HR. If none of them is present the label is
    ``"NM"``.
    """
    raw_mods = str(replay.mods)

    # (label, substrings that trigger it), listed in output order.
    tracked_mods = (
        ("EZ", ("EZ",)),
        ("HD", ("HD",)),
        ("DT", ("DT", "NC")),
        ("HR", ("HR",)),
    )

    label = "".join(
        name
        for name, triggers in tracked_mods
        if any(trigger in raw_mods for trigger in triggers)
    )
    return label or "NM"


def build_valid_index_df(replay_dir, beatmap_library):
    """Collect the replays of ``replay_dir`` that pass both replay checks.

    For every replay, its beatmap is fetched with
    ``beatmap_library.lookup_by_md5(replay.beatmap_hash)``. The replay is kept
    only when ``filter_replay`` and ``validate_replay`` both return 0 for that
    (replay, beatmap) pair.

    Returns:
        A DataFrame with one row per kept replay and the columns
        ``replay_idx`` (position of the replay within ``replay_dir``),
        ``mods`` (label from ``get_clean_mod_str``) and
        ``path`` (``str(replay.path)``).
    """
    column_names = ["replay_idx", "mods", "path"]
    kept_rows = []

    for position, replay in enumerate(replay_dir):
        beatmap = beatmap_library.lookup_by_md5(replay.beatmap_hash)

        # validate_replay is only consulted once filter_replay has accepted the replay.
        if filter_replay(replay, beatmap) != 0:
            continue
        if validate_replay(replay, beatmap) != 0:
            continue

        kept_rows.append([position, get_clean_mod_str(replay), str(replay.path)])

    return pd.DataFrame(kept_rows, columns=column_names)

In [None]:
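# NOTE: this cell is currently commented out. It processes the replay directory in
# batches of BATCH_SIZE (the `cg.load` call is marked expensive) and writes one
# "valid_index_df_{batch_idx}.csv" per batch; a failing batch is reported and skipped.
# Before re-enabling it, add `import math` to the imports cell: `math.floor` is used
# below, but `math` is not among the notebook's imports.

# START_BATCH = 0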

# for batch_idx in range( START_BATCH, math.floor(len(index_df) / BATCH_SIZE) + 1 ):

#     cg = KeylessCircleguard( 
#         db_path = DB_PATH,
#         slider_dir = SLIDER_PATH,
#         cache = True
#     )
    
#     replay_dir = ReplayDir(REPLAY_PATH)
#     cg.load_info(replay_dir)

#     start_idx = batch_idx * BATCH_SIZE
#     end_idx = min( (batch_idx + 1) * BATCH_SIZE, len(replay_dir))

#     replay_dir.replays = replay_dir.replays[start_idx : end_idx]

#     try:
#         cg.load(replay_dir) # expensive
#         valid_replay_df = build_valid_index_df(replay_dir, beatmap_library)
#     except Exception as e:
#         print(f"Bad batch: {batch_idx}\n{e}")
#         continue

#     valid_replay_df["replay_idx"] = valid_replay_df["replay_idx"].apply(lambda idx: idx + start_idx)
#     valid_replay_df.to_csv(f"valid_index_df_{batch_idx}.csv", index = False)

In [None]:
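# NOTE: this cell is currently commented out. It splits the valid-replay index into
# train/val/test, stratified on "mods": 10% is held out as test, then 1/9 of the
# remaining 90% as validation, i.e. an 80/10/10 split.
# It reads a single "data/valid_index.csv", whereas the batch cell writes per-batch
# "valid_index_df_{batch_idx}.csv" files -- presumably these must be concatenated
# into that file first (TODO confirm). The path is also "data/...", unlike the
# "../data/..." constants defined at the top of the notebook.

# valid_index_df = pd.read_csv("data/valid_index.csv")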

# X_train, X_test, _, _ = train_test_split(valid_index_df, valid_index_df["mods"], test_size = 0.1, random_state = RANDOM_SEED, stratify = valid_index_df["mods"])
# X_train, X_val, _, _  = train_test_split(X_train, X_train["mods"], test_size = 1/9, random_state = RANDOM_SEED, stratify = X_train["mods"])

# X_train.to_csv("train_index.csv", index = False)
# X_val.to_csv("val_index.csv", index = False)
# X_test.to_csv("test_index.csv", index = False)