In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf

from circleguard import Circleguard, KeylessCircleguard, ReplayDir, User
from slider import Library

from utils.replay_processing import get_embeddings, get_judgments, get_beatmap_context

# Apply the ggplot stylesheet to every matplotlib figure drawn in this notebook.
plt.style.use("ggplot")

# Environment sanity check: first whether this TensorFlow build has CUDA
# support, then how many GPU devices TensorFlow can see from this kernel.
print(tf.test.is_built_with_cuda())
visible_gpus = tf.config.list_physical_devices("GPU")
print("Num GPUs Available: ", len(visible_gpus))

True
Num GPUs Available:  1


In [19]:
# ---------------------------------------------------------------------------
# Feature dimensions
# ---------------------------------------------------------------------------
EMBEDDING_DIM = 16          # values per note embedding
JUDGMENT_DIM = 4            # 300 / 100 / 50 / miss
BEATMAP_CONTEXT_DIM = 8     # values in the per-beatmap context vector
NOTES_PER_EXAMPLE = 8       # notes in one sliding-window example

# ---------------------------------------------------------------------------
# Locations on disk (relative to the notebook)
# ---------------------------------------------------------------------------
INDEX_PATH = "../data/indices/base_index.csv"
BEATMAP_PATH = "../data/beatmaps"
REPLAY_PATH = "../data/live"
DB_PATH = "../data/replays/osr/.circleguard.db"
SLIDER_PATH = "../data/beatmaps"

# ---------------------------------------------------------------------------
# Column names used when assembling the per-note DataFrame
# ---------------------------------------------------------------------------

# One name per beatmap-context value (BEATMAP_CONTEXT_DIM entries).
CONTEXT_LABELS = [
    "ar", "circle_radius", "hp",
    "hitwindow_300", "hitwindow_100", "hitwindow_50",
    "hd", "dt",
]

# One name per note-embedding value (EMBEDDING_DIM entries).
EMBEDDING_LABELS = [
    "x_pos", "y_pos",
    "in_x_offset", "in_y_offset", "in_dist", "in_timedelta",
    "out_x_offset", "out_y_offset", "out_dist", "out_timedelta",
    "angle",
    "is_slider", "slider_duration", "slider_length",
    "slider_num_ticks", "slider_num_beats",
]

# Observed judgment columns; the probability columns below mirror this order.
TRUE_LABELS = ["300", "100", "50", "miss"]

# Predicted-probability columns: "p_300", "p_100", "p_50", "p_miss".
PROB_LABELS = [f"p_{judgment}" for judgment in TRUE_LABELS]

# Same probability names with a "_null" suffix: "p_300_null", ...
NULL_LABELS = [f"{prob}_null" for prob in PROB_LABELS]

In [3]:
# Read the replay index (low_memory=False makes pandas parse the file in one
# pass instead of chunk by chunk) and open the beatmap library.
index_df = pd.read_csv(INDEX_PATH, low_memory=False)
beatmap_library = Library(BEATMAP_PATH)

# Report how much data is available before doing any work.
print(
    f"Num. replays in index: {len(index_df)}",
    f"Num. of beatmaps in library: {len(beatmap_library.ids)}",
    sep="\n",
)

# Last expression of the cell: rich preview of the first index rows.
index_df.head()

Num. replays in index: 381570
Num. of beatmaps in library: 55087


Unnamed: 0.1,Unnamed: 0,replayHash,beatmapHash,summary,date,playerName,modsReadable,mods,performance-IsFail,performance-Accuracy,...,beatmap-Circles,beatmap-Sliders,beatmap-Spinners,beatmapPlay-BPMMin,beatmapPlay-BPMMax,beatmapPlay-HP,beatmapPlay-OD,beatmapPlay-AR,beatmapPlay-CS,osrReplayUrl
0,0,857623324645f59599fc4e9c1c7e1130,db4cdde15984869de346686ceb6bc1a5,[7.44 ⭐] My Angel Shori | Imperial Circus Dead...,2022-11-11T08:03:56.656728,My Angel Shori,,0,False,0.954368,...,1438,442,1,134.0,235.0,5.0,9.2,9.7,4.5,https://dl.issou.best/ordr/replays/cfc43370ca0...
1,1,b1675645282756f18e0c3f142f290d94,04129d1b26d6bf3f35cb00a04f4a8c88,[6.55 ⭐] AlexTheProtoTTV | katagiri - Buta Mus...,2022-11-11T04:37:52.235160,AlexTheProtoTTV,,0,False,0.857687,...,832,392,1,210.0,210.0,4.5,9.2,9.8,3.8,https://dl.issou.best/ordr/replays/e35f6f3e3b1...
2,2,6170131dcd9ad32d4c0287422b256423,8e3e4e77e6498a994bfa7505e735155d,[6.37 ⭐] Asio_ | Qrispy Joybox feat. mao - Col...,2022-11-11T08:01:02.823288,Asio_,HDDT,72,False,0.916667,...,189,118,1,228.0,228.0,5.2,10.083333,10.333333,3.8,https://dl.issou.best/ordr/replays/f214715d45d...
3,3,d3fc501dadc1fabd60f847432cc3de84,27f9c5f8496bfdab8623e3be6cf73f25,[6.3 ⭐] sotarks fan123 | DragonForce - Valley ...,2022-11-11T07:46:29.042768,sotarks fan123,,0,True,0.948815,...,1486,660,2,71.0,200.0,6.0,8.5,9.2,4.0,https://dl.issou.best/ordr/replays/cf3f6778f86...
4,4,8e62a939e4c97fe22cbd90e90d338a53,1e8f966c7a8f992cb2f3f5bab7d55925,[6.34 ⭐] XxSzymenxx | DragonForce - Through th...,2022-03-11T10:40:06.741672,XxSzymenxx,,0,True,0.731128,...,1534,587,5,170.0,200.0,6.2,9.0,9.5,4.0,https://dl.issou.best/ordr/replays/b90ed42a04b...


In [26]:
def _replay_to_np_features(replay, beatmap_library):
    """Build the sliding-window feature matrix for one replay.

    Each row is the flattened beatmap context (BEATMAP_CONTEXT_DIM values)
    followed by the flattened embeddings of NOTES_PER_EXAMPLE consecutive
    notes (NOTES_PER_EXAMPLE * EMBEDDING_DIM values).

    Rows are ordered from the END of the replay backwards: row 0 covers the
    final NOTES_PER_EXAMPLE notes, row 1 the window ending one note earlier,
    and so on. The window covering only the first NOTES_PER_EXAMPLE notes is
    not produced, so the result has ``len(embs) - NOTES_PER_EXAMPLE`` rows.

    Parameters
    ----------
    replay
        Replay object forwarded to ``get_beatmap_context`` / ``get_embeddings``.
    beatmap_library
        Beatmap library forwarded to the same helpers.

    Returns
    -------
    numpy.ndarray or None
        The feature matrix, or ``None`` when the replay has no more than
        NOTES_PER_EXAMPLE embedded notes.
    """
    context_vec = np.reshape(
        get_beatmap_context(replay, beatmap_library),
        (BEATMAP_CONTEXT_DIM,),
    )

    # NOTE(review): indexed below as a 2-D array of (notes, EMBEDDING_DIM).
    embs = get_embeddings(replay, beatmap_library)
    replay_len = len(embs)

    # Not enough notes to form a single example.
    if replay_len <= NOTES_PER_EXAMPLE:
        return None

    window_width = NOTES_PER_EXAMPLE * EMBEDDING_DIM
    num_examples = replay_len - NOTES_PER_EXAMPLE
    features = np.zeros((num_examples, window_width + BEATMAP_CONTEXT_DIM))

    # `stop` is the exclusive end of each window; walk it from the last note
    # down to NOTES_PER_EXAMPLE + 1 so that row 0 is the newest window.
    for row, stop in enumerate(range(replay_len, NOTES_PER_EXAMPLE, -1)):
        window = embs[stop - NOTES_PER_EXAMPLE : stop, :]
        flat_window = np.reshape(window, (window_width,))
        features[row, :] = np.concatenate((context_vec, flat_window))

    return features


def _np_embs_to_notes_df(np_features, judgs, model):
    """Combine features, model predictions and judgments into one DataFrame.

    Parameters
    ----------
    np_features
        Feature matrix whose first BEATMAP_CONTEXT_DIM columns are the beatmap
        context and whose last EMBEDDING_DIM columns are one note embedding;
        ``None`` is passed through.
    judgs
        Array concatenated column-wise after the predictions; it supplies the
        TRUE_LABELS columns and must have one row per feature row.
    model
        Object exposing ``predict_on_batch``; its output supplies the
        PROB_LABELS columns.

    Returns
    -------
    pandas.DataFrame or None
        Columns ``CONTEXT_LABELS + EMBEDDING_LABELS + PROB_LABELS +
        TRUE_LABELS``, with the row order reversed relative to the inputs and
        a fresh 0..n-1 index. ``None`` when ``np_features`` is ``None``.
    """
    if np_features is None:
        return None

    predicted = model.predict_on_batch(np_features)

    # Keep only the context and the LAST note embedding of every window.
    context_part = np_features[:, :BEATMAP_CONTEXT_DIM]
    last_note_part = np_features[:, -EMBEDDING_DIM:]
    table = np.concatenate(
        (context_part, last_note_part, predicted, judgs),
        axis=1,
    )

    column_names = CONTEXT_LABELS + EMBEDDING_LABELS + PROB_LABELS + TRUE_LABELS
    notes_df = pd.DataFrame(table, columns=column_names)

    # Flip the row order, then renumber the index from zero.
    return notes_df.iloc[::-1].reset_index(drop=True)


def get_notes_df( replay, beatmap_library, model, cg ):
    """Produce the per-note DataFrame (features, predictions, judgments) for a replay.

    Parameters
    ----------
    replay
        Replay to analyse.
    beatmap_library
        Beatmap library used to build the features.
    model
        Model passed on to ``_np_embs_to_notes_df`` for prediction.
    cg
        Circleguard instance passed to ``get_judgments``.

    Returns
    -------
    pandas.DataFrame or None
        Whatever ``_np_embs_to_notes_df`` returns; ``None`` when no feature
        matrix could be built for the replay.
    """
    features = _replay_to_np_features(replay, beatmap_library)

    # Drop the first NOTES_PER_EXAMPLE judgments so the count matches the
    # number of feature rows.
    all_judgments = get_judgments(replay, cg)
    trimmed_judgments = all_judgments[NOTES_PER_EXAMPLE:]

    return _np_embs_to_notes_df(features, trimmed_judgments, model)


In [5]:
def _likelihood_lambda(row):
    """Return the negative log of the cumulative predicted probability for one note.

    The cumulative probability sums the predicted probabilities of the
    observed judgment and of every better judgment:

    * observed 300 -> ``p_300``
    * observed 100 -> ``p_300 + p_100``
    * observed 50  -> ``p_300 + p_100 + p_50``
    * otherwise (miss) -> ``1``, which contributes ``0`` after the log

    Parameters
    ----------
    row
        Mapping-like row (e.g. a pandas Series) with the true-label flags
        ``"300"``, ``"100"``, ``"50"`` and the probabilities ``"p_300"``,
        ``"p_100"``, ``"p_50"``.

    Returns
    -------
    float
        ``-log(cumulative probability)``; non-negative whenever the
        probabilities sum to at most 1.
    """
    if row["300"]:
        cumulative = row["p_300"]
    elif row["100"]:
        cumulative = row["p_300"] + row["p_100"]
    elif row["50"]:
        # Fixed: this branch used to add row["50"] (the 0/1 true-label flag)
        # instead of the predicted probability row["p_50"], pushing the sum
        # above 1 and making the returned value negative.
        cumulative = row["p_300"] + row["p_100"] + row["p_50"]
    else:
        cumulative = 1

    return -np.log(cumulative)


def compute_likelihood(notes_df, top_k = 256 ):
    """Sum the ``top_k`` largest per-note negative log values of a replay.

    Parameters
    ----------
    notes_df
        DataFrame with one row per note, carrying the columns read by
        ``_likelihood_lambda``.
    top_k
        Number of largest per-note values to include in the sum.

    Returns
    -------
    float
        Sum of the ``top_k`` largest per-note values (all of them when the
        replay has fewer than ``top_k`` notes).
    """
    per_note = notes_df.apply(_likelihood_lambda, axis = 1)

    # Largest values first, so the slice keeps the least likely notes.
    ranked = sorted(per_note, reverse = True)

    return np.sum(ranked[:top_k])

In [6]:
import os

# API_KEY is not defined in any cell of this notebook, so a fresh
# "Restart & Run All" used to fail here with a NameError. Reuse a key that is
# already defined in the kernel (keeps existing workflows working); otherwise
# read it from the OSU_API_KEY environment variable so the secret never has to
# be written into the notebook.
API_KEY = globals().get("API_KEY") or os.environ.get("OSU_API_KEY")
if not API_KEY:
    raise RuntimeError(
        "No osu! API key available: define API_KEY or set the OSU_API_KEY "
        "environment variable before running this cell."
    )

# Circleguard client used later to compute judgments for each replay.
cg = Circleguard(
    API_KEY,
    db_path = DB_PATH,
    slider_dir = SLIDER_PATH,
    cache = False
)

# Alternative kept from the original cell, for running without an API key:
# cg = KeylessCircleguard(
#     db_path = DB_PATH,
#     slider_dir = SLIDER_PATH
# )

# Load every replay found under REPLAY_PATH.
replay_dir = ReplayDir(REPLAY_PATH)
cg.load(replay_dir)

In [7]:
# Saved Keras model used to predict per-note judgment probabilities.
MODEL_PATH = "../models/naive/naive_final.keras"

# Restore the model from disk.
model = tf.keras.models.load_model(filepath=MODEL_PATH)

In [28]:
# One row per successfully scored replay:
# [idx, username, beatmap_name, mod_str, stars,
#  count_300, count_100, count_50, count_miss, likelihood]
res = []

for idx, replay in enumerate(replay_dir):

    try:
        n_df = get_notes_df(replay, beatmap_library, model, cg)
    except KeyError:
        print(f"cant find beatmap for one of {replay.username}'s replays")
        continue

    # get_notes_df returns None when the replay has too few notes to form a
    # single example; compute_likelihood(None) would raise AttributeError.
    if n_df is None:
        print(f"skipping one of {replay.username}'s replays: too few notes to score")
        continue

    beatmap = replay.beatmap(beatmap_library)

    # Replays whose star rating cannot be computed are skipped. Catch
    # Exception rather than using a bare `except:` so KeyboardInterrupt and
    # SystemExit still stop the loop, and report the skip instead of hiding it.
    try:
        stars = beatmap.stars()
    except Exception as err:
        print(f"cant compute stars for one of {replay.username}'s replays: {err!r}")
        continue

    likelihood = compute_likelihood(n_df)

    res.append([
        idx,
        replay.username,
        beatmap.display_name,
        str(replay.mods),
        stars,
        replay.count_300,
        replay.count_100,
        replay.count_50,
        replay.count_miss,
        likelihood,
    ])

cant find beatmap for one of Accolibed's replays
cant find beatmap for one of chocomint's replays
cant find beatmap for one of sytho's replays


In [29]:
# Column names for the rows collected in `res`, in the order they were appended.
result_columns = [
    "replay_idx",
    "username",
    "beatmap_name",
    "mod_str",
    "nm_stars",
    "count_300",
    "count_100",
    "count_50",
    "count_miss",
    "perplexity",
]

# Tabulate the results and write them next to the notebook, without the index.
results_df = pd.DataFrame(res, columns = result_columns)
results_df.to_csv("top20_naive.csv", index = False)