In [1]:
import math
import time
import pandas as pd
import tensorflow as tf

from circleguard import KeylessCircleguard, ReplayDir
from slider import Library

from utils.replay_processing import get_embeddings, get_judgments, get_beatmap_context

In [2]:
# Input locations, relative to this notebook's working directory.
INDEX_PATH = "../data/indices/base_index.csv"     # full replay index (read with pandas)
BEATMAP_PATH = "../data/beatmaps"                 # root of the slider beatmap Library
REPLAY_PATH = "../data/replays/osr"               # directory handed to circleguard's ReplayDir
DB_PATH = "../data/replays/osr/.circleguard.db"   # circleguard database file (db_path)
SLIDER_PATH = BEATMAP_PATH                        # circleguard's slider_dir: same beatmap folder

In [3]:
# Full replay index, plus the local beatmap library later passed to build_tf_example.
index_df = pd.read_csv(INDEX_PATH, low_memory=False)
beatmap_library = Library(BEATMAP_PATH)

num_replays = len(index_df)
print(f"Num. replays in index: {num_replays}")
num_beatmaps = len(beatmap_library.ids)
print(f"Num. of beatmaps in library: {num_beatmaps}")

index_df.head()

Num. replays in index: 381570
Num. of beatmaps in library: 55087


Unnamed: 0.1,Unnamed: 0,replayHash,beatmapHash,summary,date,playerName,modsReadable,mods,performance-IsFail,performance-Accuracy,...,beatmap-Circles,beatmap-Sliders,beatmap-Spinners,beatmapPlay-BPMMin,beatmapPlay-BPMMax,beatmapPlay-HP,beatmapPlay-OD,beatmapPlay-AR,beatmapPlay-CS,osrReplayUrl
0,0,857623324645f59599fc4e9c1c7e1130,db4cdde15984869de346686ceb6bc1a5,[7.44 ⭐] My Angel Shori | Imperial Circus Dead...,2022-11-11T08:03:56.656728,My Angel Shori,,0,False,0.954368,...,1438,442,1,134.0,235.0,5.0,9.2,9.7,4.5,https://dl.issou.best/ordr/replays/cfc43370ca0...
1,1,b1675645282756f18e0c3f142f290d94,04129d1b26d6bf3f35cb00a04f4a8c88,[6.55 ⭐] AlexTheProtoTTV | katagiri - Buta Mus...,2022-11-11T04:37:52.235160,AlexTheProtoTTV,,0,False,0.857687,...,832,392,1,210.0,210.0,4.5,9.2,9.8,3.8,https://dl.issou.best/ordr/replays/e35f6f3e3b1...
2,2,6170131dcd9ad32d4c0287422b256423,8e3e4e77e6498a994bfa7505e735155d,[6.37 ⭐] Asio_ | Qrispy Joybox feat. mao - Col...,2022-11-11T08:01:02.823288,Asio_,HDDT,72,False,0.916667,...,189,118,1,228.0,228.0,5.2,10.083333,10.333333,3.8,https://dl.issou.best/ordr/replays/f214715d45d...
3,3,d3fc501dadc1fabd60f847432cc3de84,27f9c5f8496bfdab8623e3be6cf73f25,[6.3 ⭐] sotarks fan123 | DragonForce - Valley ...,2022-11-11T07:46:29.042768,sotarks fan123,,0,True,0.948815,...,1486,660,2,71.0,200.0,6.0,8.5,9.2,4.0,https://dl.issou.best/ordr/replays/cf3f6778f86...
4,4,8e62a939e4c97fe22cbd90e90d338a53,1e8f966c7a8f992cb2f3f5bab7d55925,[6.34 ⭐] XxSzymenxx | DragonForce - Through th...,2022-03-11T10:40:06.741672,XxSzymenxx,,0,True,0.731128,...,1534,587,5,170.0,200.0,6.2,9.0,9.5,4.0,https://dl.issou.best/ordr/replays/b90ed42a04b...


In [4]:
# Writer options shared by every shard: records are ZLIB-compressed.
tf_options = tf.io.TFRecordOptions(compression_type="ZLIB")

def build_tf_example(replay, bm_library, cg):
    """Serialize one replay into a ``tf.train.Example``.

    The example carries four features:

    * ``embeddings``: float list, the output of ``get_embeddings`` flattened.
    * ``judgments``: int64 list, the output of ``get_judgments`` flattened.
    * ``beatmap_context``: float list, the output of ``get_beatmap_context``.
    * ``replay_length``: single int64, ``len(embeddings)`` (its first-axis
      length). A reader presumably uses it to un-flatten ``embeddings``;
      TODO confirm against the reading pipeline.

    Args:
        replay: a replay already loaded by ``cg``.
        bm_library: the slider beatmap ``Library``.
        cg: the circleguard instance, used to compute judgments.

    Returns:
        A ``tf.train.Example`` ready for ``SerializeToString()``.
    """
    def as_float_feature(values):
        # Wrap an iterable of floats as a FloatList feature.
        return tf.train.Feature(float_list=tf.train.FloatList(value=values))

    def as_int64_feature(values):
        # Wrap an iterable of ints as an Int64List feature.
        return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

    # Assumes embeddings/judgments are array-like with .flatten() (presumably numpy).
    embeddings = get_embeddings(replay, bm_library)
    judgments = get_judgments(replay, cg)
    context = get_beatmap_context(replay, bm_library)

    feature_map = {
        "embeddings": as_float_feature(embeddings.flatten()),
        "judgments": as_int64_feature(judgments.flatten()),
        "beatmap_context": as_float_feature(context),
        "replay_length": as_int64_feature([len(embeddings)]),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))

In [5]:
# Replays per output shard (one .tfrecord file per batch).
BATCH_SIZE = 2048

# This is the *test* split. The name `valid_index_df` is kept for compatibility.
# `replay_idx` values are positions into the replay directory listing
# (write_batch uses them to index `replay_dir.replays`).
valid_index_df = pd.read_csv("../data/indices/test_index.csv")
replay_indices = valid_index_df["replay_idx"]

start_batch = 0
# Ceiling division. A final partial batch is still included. The previous
# `floor(n / BATCH_SIZE) + 1` added a trailing *empty* batch whenever n was an
# exact multiple of BATCH_SIZE (and one batch for an empty index). That batch
# would still create an output file containing no records.
end_batch = (len(replay_indices) + BATCH_SIZE - 1) // BATCH_SIZE

def write_batch(batch_idx):
    
    start_time = time.perf_counter()

    cg = KeylessCircleguard( 
        db_path = DB_PATH,
        slider_dir = SLIDER_PATH,
        cache = True
    )
    
    replay_dir = ReplayDir(REPLAY_PATH)
    cg.load_info(replay_dir)

    start_replay_idx = batch_idx * BATCH_SIZE
    end_replay_idx = min( (batch_idx + 1) * BATCH_SIZE, len(replay_indices))
    replay_idx_set = set( replay_indices[start_replay_idx : end_replay_idx] )
    
    replay_dir.replays = [ replay_dir.replays[idx] for idx in replay_idx_set ]

    try:
        cg.load(replay_dir) # expensive
    except Exception as e:
        print(f"Bad batch: {batch_idx}\n{e}")
        return

    writer = tf.io.TFRecordWriter(
        path = f"../data/records/test/test_{batch_idx}.tfrecord",
        options = tf_options,
    )

    try:
        for replay in replay_dir:
            mod_str = str(replay.mods)
            if "AT" not in mod_str and "RL" not in mod_str and "AP" not in mod_str and "V2" not in mod_str: # oops forgot about these in the validation
                tf_example = build_tf_example( replay, beatmap_library, cg )
                writer.write(tf_example.SerializeToString())
    except Exception as e:
        print(f"Bad batch: {batch_idx}\n{e}")
        writer.close()
        return

    writer.close()

    stop_time = time.perf_counter()
    print(f"Completed batch {batch_idx} in {stop_time - start_time:0.4f} seconds.")
    
    return

In [6]:
# Write every batch in order. write_batch prints its own timing or failure message.
batch_range = range(start_batch, end_batch)
for batch_idx in batch_range:
    write_batch(batch_idx)

Completed batch 0 in 492.1009 seconds.
Completed batch 1 in 506.4502 seconds.
Completed batch 2 in 535.1151 seconds.
Completed batch 3 in 556.6976 seconds.
Completed batch 4 in 536.1156 seconds.
Completed batch 5 in 557.0286 seconds.
Completed batch 6 in 527.2597 seconds.
Completed batch 7 in 501.6619 seconds.
Completed batch 8 in 518.7133 seconds.
Completed batch 9 in 521.6275 seconds.
Completed batch 10 in 536.6495 seconds.
Completed batch 11 in 568.5472 seconds.
Completed batch 12 in 399.2265 seconds.
