In [1]:
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

from circleguard import KeylessCircleguard, ReplayDir
from circleguard.judgment import JudgmentType
from slider import Library
from slider.beatmap import Circle, Slider

# Notebook-wide display configuration: ggplot styling for matplotlib figures and
# a cap of 100 rows / 100 columns before pandas truncates a displayed frame.
plt.style.use("ggplot")
pd.set_option("display.max_columns", 100)
pd.set_option("display.max_rows", 100)

## Parse Notes

In [3]:
def compute_hitobject_angle(prev_obj, curr_obj, next_obj):
    """Return the angle in radians at ``curr_obj`` between its two neighbours.

    The angle is measured between the vector ``curr_obj -> prev_obj`` and the
    vector ``curr_obj -> next_obj``, so it lies in ``[0, pi]``.

    Parameters
    ----------
    prev_obj, curr_obj, next_obj
        Objects exposing ``position.x`` and ``position.y``.

    Returns
    -------
    float
        The angle in radians. ``np.pi`` is returned when either neighbour
        coincides with ``curr_obj`` (vector length not above ``EPSILON``),
        because the angle is undefined in that case.
    """
    EPSILON = 1e-06

    prev_dx = prev_obj.position.x - curr_obj.position.x
    prev_dy = prev_obj.position.y - curr_obj.position.y
    next_dx = next_obj.position.x - curr_obj.position.x
    next_dy = next_obj.position.y - curr_obj.position.y

    prev_norm = (prev_dx ** 2 + prev_dy ** 2) ** 0.5
    next_norm = (next_dx ** 2 + next_dy ** 2) ** 0.5

    res = np.pi

    if prev_norm > EPSILON and next_norm > EPSILON:
        cos = (prev_dx * next_dx + prev_dy * next_dy) / (prev_norm * next_norm)
        # Clamp the cosine, not the angle: rounding can push |cos| slightly
        # above 1, which would make arccos return NaN.
        res = np.arccos(np.clip(cos, -1.0, 1.0))

    return res


    
def get_hitobject_embedding(prev_obj, curr_obj, next_obj):
    """Build the 16-value feature vector describing one hit object.

    Parameters
    ----------
    prev_obj
        The preceding hit object, or ``None`` when ``curr_obj`` is the first.
    curr_obj
        The hit object being described.
    next_obj
        The following hit object, or ``None`` when ``curr_obj`` is the last.

    Returns
    -------
    np.ndarray
        One-dimensional array of length 16, in this order:
        ``x, y``,
        ``in_x_offset, in_y_offset, in_distance, in_timedelta``,
        ``out_x_offset, out_y_offset, out_distance, out_timedelta``,
        ``angle``,
        ``is_slider, slider_duration, slider_length, slider_num_ticks,
        slider_num_beats``.
        Time values are in milliseconds. A missing neighbour yields zero
        offsets/distance and a time delta of 5000 ms. The angle defaults to
        ``np.pi`` unless both neighbours exist.
    """
    x_position = curr_obj.position.x
    y_position = curr_obj.position.y

    # Defaults used when a neighbour is missing.
    in_x_offset = 0.0
    in_y_offset = 0.0
    in_distance = 0.0
    in_timedelta = 5000.0  # in ms

    out_x_offset = 0.0
    out_y_offset = 0.0
    out_distance = 0.0
    out_timedelta = 5000.0  # in ms

    angle = np.pi

    # Slider-only features stay at zero for every other object type.
    is_slider = 0.0
    slider_duration = 0.0
    slider_length = 0.0
    slider_num_ticks = 0.0
    slider_num_beats = 0.0

    if prev_obj is not None:
        in_x_offset = curr_obj.position.x - prev_obj.position.x
        in_y_offset = curr_obj.position.y - prev_obj.position.y
        in_distance = (in_x_offset ** 2 + in_y_offset ** 2) ** 0.5
        # total_seconds() covers the whole interval; `.microseconds` is only
        # the sub-second component and silently drops whole seconds.
        in_timedelta = (curr_obj.time - prev_obj.time).total_seconds() * 1000

    if next_obj is not None:
        out_x_offset = next_obj.position.x - curr_obj.position.x
        out_y_offset = next_obj.position.y - curr_obj.position.y
        out_distance = (out_x_offset ** 2 + out_y_offset ** 2) ** 0.5
        out_timedelta = (next_obj.time - curr_obj.time).total_seconds() * 1000

    if prev_obj is not None and next_obj is not None:
        angle = compute_hitobject_angle(prev_obj, curr_obj, next_obj)

    if type(curr_obj) is Slider:
        is_slider = 1.0
        slider_duration = (curr_obj.end_time - curr_obj.time).total_seconds() * 1000
        slider_length = curr_obj.length
        slider_num_ticks = curr_obj.ticks
        slider_num_beats = curr_obj.num_beats

    return np.array([
        x_position, y_position,
        in_x_offset, in_y_offset, in_distance, in_timedelta,
        out_x_offset, out_y_offset, out_distance, out_timedelta,
        angle,
        is_slider, slider_duration, slider_length, slider_num_ticks, slider_num_beats
    ])



def filter_out_spinners(objects):
    """Return a new list holding only the ``Slider`` and ``Circle`` objects.

    The match is on the exact type, so any object that is neither exactly a
    ``Slider`` nor exactly a ``Circle`` is dropped. Input order is preserved.
    """
    kept = []
    for obj in objects:
        if type(obj) in (Slider, Circle):
            kept.append(obj)
    return kept



def get_embeddings(replay, max_length = 2048):
    """Build the per-object feature matrix for the beatmap a replay was played on.

    The beatmap is looked up in the module-level ``beatmap_library`` by the
    replay's ``beatmap_hash``; only sliders and circles are embedded.

    Parameters
    ----------
    replay
        Object exposing ``beatmap_hash``.
    max_length : int
        Maximum number of hit objects to embed; later objects are ignored.

    Returns
    -------
    np.ndarray
        Array of shape ``(min(n_objects, max_length), 16)``. The first row has
        no predecessor and the last row has no successor (this also applies to
        the last row of a truncated sequence). Maps with zero or one object
        yield an array with zero or one row instead of raising.
    """
    beatmap = beatmap_library.lookup_by_md5(replay.beatmap_hash)
    hitobjects = filter_out_spinners(beatmap.hit_objects())

    res_len = min(len(hitobjects), max_length)
    res = np.zeros((res_len, 16))

    # A single loop covers the first, middle and last rows, so sequences of
    # length 0 and 1 no longer index past the ends of `hitobjects`.
    for emb_idx in range(res_len):
        prev_obj = hitobjects[emb_idx - 1] if emb_idx > 0 else None
        next_obj = hitobjects[emb_idx + 1] if emb_idx + 1 < res_len else None
        res[emb_idx, :] = get_hitobject_embedding(prev_obj, hitobjects[emb_idx], next_obj)

    return res

## Parse Judgments

In [4]:
def sort_judgments(judgments):
    """Return a new list of judgments ordered by ``judgment.hitobject.t``.

    The input is not modified; entries with equal ``t`` keep their relative order.
    """
    def hit_time(judgment):
        return judgment.hitobject.t

    ordered = list(judgments)
    ordered.sort(key = hit_time)
    return ordered
    
def encode_judgment(judgment):
    """One-hot encode ``judgment.type`` as ``[Hit300, Hit100, Hit50, Miss]``.

    Returns a length-4 integer array; every entry is 0 when the type matches
    none of the four categories.
    """
    category_order = (
        JudgmentType.Hit300,
        JudgmentType.Hit100,
        JudgmentType.Hit50,
        JudgmentType.Miss,
    )
    observed = judgment.type
    return np.array([1 if observed == category else 0 for category in category_order])

def get_judgments(replay, max_length = 2048):
    """Return the one-hot judgment matrix for a replay.

    Judgments come from the module-level ``cg`` instance, are ordered by hit
    object time, and are truncated to ``max_length`` rows.

    Returns
    -------
    np.ndarray
        Float array of shape ``(min(n_judgments, max_length), 4)``.
    """
    ordered = sort_judgments(cg.judgments(replay))
    row_count = min(len(ordered), max_length)
    encoded = np.zeros((row_count, 4))

    for row, judgment in enumerate(ordered[:row_count]):
        encoded[row, :] = encode_judgment(judgment)

    return encoded

## Parse Beatmap Info

In [5]:
def get_beatmap_consts(replay):
    """Return the 8 per-replay constants describing the beatmap and active mods.

    Layout: ``[ar, od, cs, stack_leniency, EZ, HD, DT-or-NC, HR]``. The four mod
    entries are 1.0 when the mod's short name occurs in ``str(replay.mods)``
    and 0.0 otherwise. The beatmap is looked up in the module-level
    ``beatmap_library`` by ``replay.beatmap_hash``.
    """
    beatmap = beatmap_library.lookup_by_md5(replay.beatmap_hash)
    mod_str = str(replay.mods)

    def mod_flag(*short_names):
        # 1.0 when any of the given short names appears in the mod string.
        return 1.0 if any(name in mod_str for name in short_names) else 0.0

    difficulty = [
        beatmap.ar(),
        beatmap.od(),
        beatmap.cs(),
        beatmap.stack_leniency,
    ]
    mod_flags = [
        mod_flag("EZ"),
        mod_flag("HD"),
        mod_flag("DT", "NC"),
        mod_flag("HR"),
    ]

    return np.array(difficulty + mod_flags)

## Validate Replays

In [6]:
def validate_replay_length(replay, beatmap_objects, replay_judgments):
    """Check that the replay has one judgment per non-spinner beatmap object.

    Returns a bit mask: bit 1 is set when the judgment count differs from the
    number of sliders/circles in the beatmap, bit 2 when the number of hit
    objects taken from those judgments differs. 0 means the lengths agree.
    """
    expected_count = len(filter_out_spinners(beatmap_objects))
    judged_objects = [judgment.hitobject for judgment in replay_judgments]

    flags = 0

    # Beatmap objects vs. replay judgments.
    if expected_count != len(replay_judgments):
        flags |= 1

    # Beatmap objects vs. hit objects referenced by the judgments.
    if expected_count != len(judged_objects):
        flags |= 2

    return flags



def validate_replay_hitcounts(replay, beatmap_objects, replay_judgments):
    """Compare the replay's recorded hit counts with the computed judgments.

    The per-category difference (recorded minus computed) is summed over the
    300/100/50/miss categories; that total must equal the number of beatmap
    objects removed by ``filter_out_spinners``.

    Returns 4 on mismatch and 0 otherwise.
    """
    computed_counts = np.sum([encode_judgment(j) for j in replay_judgments], axis = 0)
    recorded_counts = np.array([
        replay.count_300,
        replay.count_100,
        replay.count_50,
        replay.count_miss,
    ])
    surplus = recorded_counts - computed_counts

    filtered_out = len(beatmap_objects) - len(filter_out_spinners(beatmap_objects))

    if sum(surplus) != filtered_out:
        return 4
    return 0



def validate_replay_objects(replay, beatmap_objects, replay_judgments):
    """Check that each judgment's hit object matches the beatmap object at the same index.

    Parameters
    ----------
    replay
        Object exposing ``mods``.
    beatmap_objects
        All hit objects of the beatmap; spinners are filtered out before comparing.
    replay_judgments
        Judgments in beatmap order; each exposes ``hitobject`` with ``time``,
        ``x`` and ``y``.

    Returns
    -------
    int
        Bit mask: 8 for a time mismatch, 16 for an x mismatch, 32 for a y
        mismatch. 0 means every compared pair agrees.

    Notes
    -----
    Only the overlapping prefix of the two sequences is compared, so a length
    mismatch no longer raises ``IndexError`` here; it is reported separately
    by ``validate_replay_length``.
    Position checks are skipped when the mod string contains "HR" or "EZ"
    (presumably because those mods alter object placement; TODO confirm).
    """
    EPSILON = 1e-06

    res = 0

    beatmap_objects_no_spinners = filter_out_spinners(beatmap_objects)

    # The mod string does not change per object, so compute it once.
    mod_str = str(replay.mods)
    check_positions = "HR" not in mod_str and "EZ" not in mod_str

    for judgment, beatmap_obj in zip(replay_judgments, beatmap_objects_no_spinners):

        replay_obj = judgment.hitobject

        # Beatmap times are timedeltas; replay hit object times are in ms.
        if abs(replay_obj.time - beatmap_obj.time.total_seconds() * 1000) >= EPSILON:
            res |= 8

        if check_positions:

            # abs() so deviations in either direction are detected.
            if abs(replay_obj.x - beatmap_obj.position.x) >= EPSILON:
                res |= 16

            if abs(replay_obj.y - beatmap_obj.position.y) >= EPSILON:
                res |= 32

    return res
            



def validate_replay(replay, beatmap):
    """Run every replay/beatmap consistency check and OR their bit masks together.

    Judgments are computed with the module-level ``cg`` instance and sorted by
    hit object time before being handed to the checks. A result of 0 means
    every check passed.
    """
    hit_objects = beatmap.hit_objects()
    ordered_judgments = sort_judgments(cg.judgments(replay))

    checks = (
        validate_replay_length,
        validate_replay_hitcounts,
        validate_replay_objects,
    )

    flags = 0
    for check in checks:
        flags |= check(replay, hit_objects, ordered_judgments)

    return flags


## Filter Replays

In [7]:
def filter_replay(replay, beatmap):
    """Return a bit mask of the reasons a replay is excluded from the dataset.

    Bits:
        1 - the mod string contains "V2", "FL" or "HT"
        2 - more than 128 misses
        4 - beatmap max combo below 128 or above 8192
        8 - beatmap minimum BPM below 32 or maximum BPM above 512

    A result of 0 means the replay passes every filter.
    """
    mods = str(replay.mods)
    combo = beatmap.max_combo

    checks = (
        (1, any(tag in mods for tag in ("V2", "FL", "HT"))),
        (2, replay.count_miss > 128),
        (4, combo < 128 or combo > 8192),
        (8, beatmap.bpm_min() < 32 or beatmap.bpm_max() > 512),
    )

    flags = 0
    for bit, triggered in checks:
        if triggered:
            flags |= bit

    return flags
    

## Write to File

In [8]:
def get_clean_mod_str(replay):
    """Reduce a replay's mods to a canonical label built from HD, DT and HR.

    "NC" is reported as "DT"; every other mod is ignored. The parts always
    appear in the order HD, DT, HR, and "NM" is returned when none of them is
    present.
    """
    mods = str(replay.mods)

    parts = []
    if "HD" in mods:
        parts.append("HD")
    if "DT" in mods or "NC" in mods:
        parts.append("DT")
    if "HR" in mods:
        parts.append("HR")

    return "".join(parts) or "NM"


def build_valid_index_df(replay_dir):
    """List the replays in ``replay_dir`` that pass both filtering and validation.

    Each replay's beatmap is looked up in the module-level ``beatmap_library``.
    Validation only runs for replays that already passed ``filter_replay``.

    Returns
    -------
    pd.DataFrame
        Columns ``replay_idx`` (position of the replay within ``replay_dir``)
        and ``mods`` (label from ``get_clean_mod_str``), one row per accepted
        replay.
    """
    accepted = []

    for position, replay in enumerate(replay_dir):
        beatmap = beatmap_library.lookup_by_md5(replay.beatmap_hash)

        if filter_replay(replay, beatmap) != 0:
            continue
        if validate_replay(replay, beatmap) != 0:
            continue

        accepted.append([position, get_clean_mod_str(replay)])

    return pd.DataFrame(accepted, columns = ["replay_idx", "mods"])


def write_to_file(replays, file_suffix, max_length = 2048):
    """Encode a set of replays into three dense arrays and save them as ``.npy`` files.

    For each replay the hit object embeddings and judgments are zero-padded to
    ``max_length`` rows. Both are padded by the amount computed from the
    embedding length.

    Files written to the current working directory:
        ``embeddings_{file_suffix}.npy``       shape (n_replays, max_length, 16)
        ``judgments_{file_suffix}.npy``        shape (n_replays, max_length, 4)
        ``beatmaps_consts_{file_suffix}.npy``  shape (n_replays, 8)

    Returns the three arrays in that same order.
    """
    replay_count = len(replays)

    embedding_np = np.zeros((replay_count, max_length, 16))
    judgment_np = np.zeros((replay_count, max_length, 4))
    beatmap_const_np = np.zeros((replay_count, 8))

    for row, replay in enumerate(replays):
        replay_embeddings = get_embeddings(replay, max_length)
        replay_judgments = get_judgments(replay, max_length)
        replay_consts = get_beatmap_consts(replay)

        missing_rows = max_length - len(replay_embeddings)
        if missing_rows > 0:
            replay_embeddings = np.vstack((replay_embeddings, np.zeros((missing_rows, 16))))
            replay_judgments = np.vstack((replay_judgments, np.zeros((missing_rows, 4))))

        embedding_np[row] = replay_embeddings
        judgment_np[row] = replay_judgments
        beatmap_const_np[row] = replay_consts

    outputs = (
        (f"embeddings_{file_suffix}.npy", embedding_np),
        (f"judgments_{file_suffix}.npy", judgment_np),
        (f"beatmaps_consts_{file_suffix}.npy", beatmap_const_np),
    )
    for path, array in outputs:
        with open(path, 'wb') as f:
            np.save(f, array)

    return embedding_np, judgment_np, beatmap_const_np
        

## Driver Code

In [9]:
# Input locations, relative to the notebook's working directory.
INDEX_PATH = "data/index.csv"               # CSV replay index read with pd.read_csv
BEATMAP_PATH = "data/beatmaps"              # directory handed to slider's Library
REPLAY_PATH = "data/replays/osr"            # directory handed to circleguard's ReplayDir
DB_PATH = "data/replays/osr/.circleguard.db"  # passed to KeylessCircleguard as db_path
SLIDER_PATH = "data/beatmaps"               # passed to KeylessCircleguard as slider_dir (same directory as BEATMAP_PATH)

In [10]:
RANDOM_SEED = 12    # random_state for both train_test_split calls in the driver loop
BATCH_SIZE = 16384  # number of replays sliced from the replay directory per batch

In [11]:
# index and beatmaps
# `beatmap_library` is read as a global by get_embeddings, get_beatmap_consts
# and build_valid_index_df, so this cell must run before any of them is called.

index_df = pd.read_csv(INDEX_PATH, low_memory = False)
beatmap_library = Library(BEATMAP_PATH)

# Quick sanity check of how much data is available, then preview the index.
print(f"Num. replays in index: {len(index_df)}")
print(f"Num. of beatmaps in library: {len(beatmap_library.ids)}")
index_df.head()

Num. replays in index: 381570
Num. of beatmaps in library: 55079


Unnamed: 0.1,Unnamed: 0,replayHash,beatmapHash,summary,date,playerName,modsReadable,mods,performance-IsFail,performance-Accuracy,performance-Score,performance-300s,performance-100s,performance-50s,performance-Misses,performance-Geki,performance-Katu,performance-MaxCombo,performance-IsFC,beatmap-Artist,beatmap-Title,beatmap-Version,beatmap-Id,beatmap-SetId,beatmap-BPMMin,beatmap-BPMMax,beatmap-HP,beatmap-OD,beatmap-AR,beatmap-CS,beatmap-MaxCombo,beatmap-HitObjects,beatmap-Circles,beatmap-Sliders,beatmap-Spinners,beatmapPlay-BPMMin,beatmapPlay-BPMMax,beatmapPlay-HP,beatmapPlay-OD,beatmapPlay-AR,beatmapPlay-CS,osrReplayUrl
0,0,857623324645f59599fc4e9c1c7e1130,db4cdde15984869de346686ceb6bc1a5,[7.44 ⭐] My Angel Shori | Imperial Circus Dead...,2022-11-11T08:03:56.656728,My Angel Shori,,0,False,0.954368,28363840,1771,72,1,37,422,50,950,False,Imperial Circus Dead Decadence,Hyakki Yakou -Pandemonic Night Parade-,Youkai,3638479.0,1776682.0,134,235,5.0,9.2,9.7,4.5,2515,1881,1438,442,1,134.0,235.0,5.0,9.2,9.7,4.5,https://dl.issou.best/ordr/replays/cfc43370ca0...
1,1,b1675645282756f18e0c3f142f290d94,04129d1b26d6bf3f35cb00a04f4a8c88,[6.55 ⭐] AlexTheProtoTTV | katagiri - Buta Mus...,2022-11-11T04:37:52.235160,AlexTheProtoTTV,,0,False,0.857687,2363246,991,174,10,50,274,103,194,False,katagiri,Buta Musou,AR 9.8,3777132.0,1647421.0,210,210,4.5,9.2,9.8,3.8,1660,1225,832,392,1,210.0,210.0,4.5,9.2,9.8,3.8,https://dl.issou.best/ordr/replays/e35f6f3e3b1...
2,2,6170131dcd9ad32d4c0287422b256423,8e3e4e77e6498a994bfa7505e735155d,[6.37 ⭐] Asio_ | Qrispy Joybox feat. mao - Col...,2022-11-11T08:01:02.823288,Asio_,HDDT,72,False,0.916667,777289,272,31,0,5,57,20,136,False,Qrispy Joybox feat. mao,Colorful Minutes,Beautiful Time,2584022.0,1242911.0,152,152,5.2,8.5,9.0,3.8,438,308,189,118,1,228.0,228.0,5.2,10.083333,10.333333,3.8,https://dl.issou.best/ordr/replays/f214715d45d...
3,3,d3fc501dadc1fabd60f847432cc3de84,27f9c5f8496bfdab8623e3be6cf73f25,[6.3 ⭐] sotarks fan123 | DragonForce - Valley ...,2022-11-11T07:46:29.042768,sotarks fan123,,0,True,0.948815,45264320,1111,68,1,15,182,42,1485,False,DragonForce,Valley of the Damned,Apocalypse,675734.0,67565.0,71,200,6.0,8.5,9.2,4.0,3089,2148,1486,660,2,71.0,200.0,6.0,8.5,9.2,4.0,https://dl.issou.best/ordr/replays/cf3f6778f86...
4,4,8e62a939e4c97fe22cbd90e90d338a53,1e8f966c7a8f992cb2f3f5bab7d55925,[6.34 ⭐] XxSzymenxx | DragonForce - Through th...,2022-03-11T10:40:06.741672,XxSzymenxx,,0,True,0.731128,486150,254,85,30,24,29,19,99,False,DragonForce,Through the Fire and Flames,Myth,1001682.0,382400.0,170,200,6.2,9.0,9.5,4.0,3220,2126,1534,587,5,170.0,200.0,6.2,9.0,9.5,4.0,https://dl.issou.best/ordr/replays/b90ed42a04b...


In [17]:
# Process the replay directory in slices of BATCH_SIZE replays. For each slice:
# load the replays, keep those that pass filtering and validation, split them
# into train/val/test stratified by mod label, and write each split to .npy files.
for batch_idx in range( math.floor(len(index_df) / BATCH_SIZE) + 1 ):

    # `cg` is deliberately a module-level name: get_judgments() and
    # validate_replay() read it as a global. A fresh instance and a fresh
    # ReplayDir are created for every batch.
    cg = KeylessCircleguard( 
        db_path = DB_PATH,
        slider_dir = SLIDER_PATH,
        cache = True
    )
    
    replay_dir = ReplayDir(REPLAY_PATH)
    cg.load_info(replay_dir)

    start_idx = batch_idx * BATCH_SIZE
    end_idx = min( (batch_idx + 1) * BATCH_SIZE, len(replay_dir))

    # The batch count comes from the index while slicing uses the replay
    # directory, so the final batch(es) can be empty. Stop instead of letting
    # train_test_split fail on an empty frame.
    if start_idx >= end_idx:
        break

    replay_dir.replays = replay_dir.replays[start_idx : end_idx] 
    cg.load(replay_dir) # expensive

    valid_replay_df = build_valid_index_df(replay_dir)

    # Nothing in this batch survived filtering/validation: nothing to split or write.
    if valid_replay_df.empty:
        continue

    # 10% test first, then 1/9 of the remaining 90% as validation -> 80/10/10.
    # NOTE(review): stratified splitting raises ValueError when a mod label has
    # too few replays in the batch; consider merging rare labels if that occurs.
    X_train, X_test, _, _ = train_test_split(valid_replay_df, valid_replay_df["mods"], test_size = 0.1, random_state = RANDOM_SEED, stratify = valid_replay_df["mods"])
    X_train, X_val, _, _  = train_test_split(X_train, X_train["mods"], test_size = 1/9, random_state = RANDOM_SEED, stratify = X_train["mods"])

    train_replays = [replay_dir[idx] for idx in X_train["replay_idx"].values]
    val_replays = [replay_dir[idx] for idx in X_val["replay_idx"].values]
    test_replays = [replay_dir[idx] for idx in X_test["replay_idx"].values]

    write_to_file(train_replays, f"{batch_idx}_train")
    write_to_file(val_replays, f"{batch_idx}_val")
    write_to_file(test_replays, f"{batch_idx}_test")