In [None]:
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import pandas as pd
from tbparse import SummaryReader
import os

In [None]:
from settings import TrainingSettings, DatasetSettings, EvaluationSettings, CNNSettings, CNNMambaSettings, CNNAttentionSettings, CRNNSettings, asdict

def get_columns_with_type(typ, get_only_model_settings=False) -> list[str]:
    """Collect the names of settings fields whose default value has exactly type ``typ``.

    Every settings class is instantiated, converted with ``asdict`` and scanned
    for values whose concrete type is ``typ``. The comparison is an identity
    check on the type, so subclasses do not match.

    Args:
        typ: The exact Python type to look for.
        get_only_model_settings: When true, skip the training, dataset and
            evaluation settings and inspect only the model settings classes.

    Returns:
        The matching field names without duplicates, in no particular order.
    """
    instances = []
    if not get_only_model_settings:
        # General settings are built without constructor arguments.
        instances.extend(cls() for cls in (TrainingSettings, DatasetSettings, EvaluationSettings))
    # Model settings are built with two positional constructor arguments.
    instances.extend(cls(3, 84) for cls in (CNNSettings, CNNAttentionSettings, CNNMambaSettings, CRNNSettings))

    matching = {
        name
        for instance in instances
        for name, value in asdict(instance).items()
        if type(value) is typ
    }
    return list(matching)

In [None]:
# Column names grouped by the Python type of their default setting value.
# The cells below use these lists to cast the logged hyperparameters.
integer_values = list({*get_columns_with_type(int), "early_stopping"})
boolean_values = get_columns_with_type(bool)
float_values = get_columns_with_type(float)
string_values = [*get_columns_with_type(str), "dir_name"]
# string_values.remove("backbone")
# "dropout" must not be cast as an integer column; remove() raises ValueError
# if it was not reported as one.
integer_values.remove("dropout")

In [None]:
integer_values

# Convert hyperparameters and per-step scalars

In [None]:
# Reset every polars display option to its default. The public `pl.Config`
# entry point is used (as in the later sections of this notebook) and the
# return value is discarded so the cell produces no output.
_ = pl.Config.restore_defaults()

In [None]:
# Directory containing the TensorBoard event files to convert.
log_dir = "runs/CRNN/Vogl/Validation"
# Destination directory for the hparams/scores parquet files written below.
output_dir = "processed/CRNN/Vogl/Validation"

# Create the destination up front; an already existing directory is accepted.
os.makedirs(output_dir, exist_ok=True)

# The extra "dir_name" column records which run directory each row came from.
logs = SummaryReader(log_dir, pivot=True, extra_columns={'dir_name'})

In [None]:
# convert types
# Cast each known hyperparameter column to the type of its default setting value.
params = logs.hparams
# Keep only the first occurrence of any duplicated column name.
params = params.loc[:,~params.columns.duplicated()].copy()
# Replace the literal string "None" with a real missing value.
params[params == "None"] = None
# Only columns that are actually present in these logs are cast.
int_cols = list(set(integer_values) & set(params.columns.values))
# Nullable integer dtype so that missing values survive the cast.
params[int_cols] = params[int_cols].astype(pd.Int64Dtype(), errors="ignore")
bool_cols = list(set(boolean_values) & set(params.columns.values))
# NOTE(review): astype(bool) is truthiness-based, so any non-empty string
# (including "False") becomes True — confirm the raw values are real booleans.
params[bool_cols] = params[bool_cols].astype(bool)
float_cols = list(set(float_values) & set(params.columns.values))
params[float_cols] = params[float_cols].astype(np.float64)
string_cols = list(set(string_values) & set(params.columns.values))
params[string_cols] = params[string_cols].astype(pd.StringDtype())
params = pd.DataFrame(params)

In [None]:
# Per-step scalar values of every run.
# broken_columns = ["F-Score/Avg/Test/RBMA", "F-Score/Sum/Test/RBMA", "Loss/Test/RBMA"]
# for col in broken_columns:
#     scalars = scalars[scalars[col].apply(lambda x: isinstance(x, float))].copy(deep=True)
# scalars[broken_columns] = scalars[broken_columns].astype(np.float64)
# Rows whose (dir_name, step) pair occurs more than once are dropped entirely
# (keep=False). A new frame is built instead of mutating the reader's frame in
# place, so re-running the cell always starts from the same input.
scalars = logs.scalars.drop_duplicates(subset=["dir_name", "step"], keep=False)

In [None]:
logs.tensors

In [None]:
params = pl.from_pandas(params, nan_to_null=True)

In [None]:
scores = pl.from_pandas(scalars, nan_to_null=True)

In [None]:
# Attach the logged "F-Score" value to each hyperparameter row via its dir_name.
hparams = params.join(
    scores.select(["F-Score", "dir_name"]),
    on="dir_name",
    how="inner",
)
# Reduce dir_name to its first path component and move it to the last column.
hparams = hparams.select(
    pl.exclude("dir_name"),
    pl.col("dir_name").str.split("/").list.first(),
)

In [None]:
# convert time
# Keep the first two "_"-separated pieces of the run directory name and parse
# them as a month/day/time stamp; non-strict parsing yields null on failure.
hparams = hparams.with_columns(
    pl.col("dir_name")
    .str.split("_")
    .list.slice(0, 2)
    .list.join("_")
    .str.to_datetime("%b%d_%H-%M-%S", strict=False, ambiguous="earliest")
    .alias("start_time")
)
# The format string contains no year, so the timestamp is rebuilt from its
# parts with the year fixed to 2025.
hparams = hparams.with_columns(
    pl.datetime(
        2025,
        pl.col("start_time").dt.month(),
        pl.col("start_time").dt.day(),
        pl.col("start_time").dt.hour(),
        pl.col("start_time").dt.minute(),
        pl.col("start_time").dt.second(),
        time_unit="ns",
    ).alias("start_time")
)

In [None]:
scalars

In [None]:
# Persist the typed hyperparameters (one row per run) and the per-step scalars
# next to each other in output_dir.
hparams.write_parquet(f"{output_dir}/hparams.parquet")
scores.write_parquet(f"{output_dir}/scores.parquet")

# Process plots

In [None]:
logs_no_pivot = SummaryReader(log_dir, pivot=False, extra_columns={'dir_name'})

In [None]:
# Convert raw tensor data
tensors = logs_no_pivot.tensors
# All distinct tags, minus those ending in "Best_Thresholds".
columns = tensors["tag"].unique().astype(str)
columns = columns[~np.char.endswith(columns, "Best_Thresholds")]
# Remember the shape of the first tensor of the first remaining tag.
# NOTE(review): the reshape in the next cell assumes every kept tensor has this shape — confirm.
tensor_shapes = np.array(tensors.loc[tensors["tag"] == columns[0]].iloc[0]["value"].shape)
# Flatten every tensor to 1-D before handing the frame to polars.
tensors["value"] = tensors["value"].apply(lambda x: x.flatten())

In [None]:
# Move the flattened tensors to polars, forcing string dtypes for tag and dir_name.
tensors = pl.from_pandas(tensors, nan_to_null=True, schema_overrides={"tag": pl.String, "dir_name": pl.String}, include_index=True)
# Keep only the tags selected in the previous cell.
tensors = tensors.filter(pl.col("tag").is_in(columns))
num_rows = tensors.select(pl.len()).to_series()[0]
# Restore the original tensor shape, with the row count as the leading dimension.
tensors = tensors.with_columns(pl.col("value").reshape(tuple([num_rows, *tensor_shapes])).alias("value"))

In [None]:
# One row per (dir_name, step) with one column per tag. Entries whose
# (dir_name, step, tag) key occurs more than once are dropped entirely
# (keep="none") before pivoting.
tensors = (
    tensors
    .unique(subset=["dir_name", "step", "tag"], keep="none")
    .pivot(values=["value"], on=["tag"], index=["dir_name", "step"])
    .sort("dir_name", "step")
)

In [None]:
# NOTE(review): this path is hard-coded and does not use output_dir like the
# hparams/scores files above — confirm the intended destination.
tensors.write_parquet("./processed/BA_fixed/tensors.parquet", compression="zstd", compression_level=22)

In [None]:
# Convert plots
images = logs_no_pivot.images
# Shape of the first image; used later to reshape all images.
# NOTE(review): assumes every logged image has the same shape — confirm.
size = np.array(images.iloc[0]["value"].shape)
# Flatten every image to 1-D before handing the data to polars.
images["value"] = images["value"].apply(lambda x: x.flatten())

In [None]:
# dumb stuff for datasets that barely fit into memory
import gc

# Pull the pixel data out of the frame, keeping only real numpy arrays.
# NOTE(review): if any entry is not an ndarray it is skipped here, and `data`
# no longer lines up row-by-row with `images` — confirm this cannot happen.
data = [data for data in images["value"].to_numpy() if type(data) is np.ndarray]
# Drop the heavy column, then replace the frame by a deep copy of the remainder
# and collect garbage before the pixel data is stacked.
images = images.drop("value", axis=1)
images_buf = images.copy(deep=True)
del images
gc.collect()
# Stack the flattened images into a single array.
data = np.array(data)
images = images_buf
del images_buf

In [None]:
# plots = pl.read_parquet("./processed/images.parquet")
# Convert the image metadata (everything except the pixel data) to polars ...
plots = pl.from_pandas(images, nan_to_null=True, schema_overrides={"tag": pl.String, "dir_name": pl.String}, include_index=True)
# ... and re-attach the stacked pixel data as the "value" column.
plots = plots.with_columns(pl.Series(name="value", values=data))

In [None]:
num_rows = plots.select(pl.len()).to_series()[0]

In [None]:
plots = plots.with_columns(pl.col("value").reshape(tuple([num_rows, *size])).alias("value"))

In [None]:
plots.write_parquet("./processed/BA_fixed/plots.parquet", compression="zstd")

In [None]:
from PIL import Image
from io import BytesIO
# compress images for easier handling

plots = pl.scan_parquet("./processed/BA_fixed/plots.parquet")

def convert_np_to_png(array: np.ndarray) -> bytes:
    """Encode an RGB image array as optimised PNG bytes.

    Args:
        array: Image data, expected as (height, width, 3). It is converted to
            ``uint8`` first: "RGB" mode stores one byte per channel, and
            forcing that mode on wider dtypes would reinterpret the raw buffer.

    Returns:
        The content of the PNG file.
    """
    pixels = np.asarray(array, dtype=np.uint8)
    with BytesIO() as file_buf:
        img = Image.fromarray(pixels, mode="RGB")
        img.save(file_buf, format="PNG", optimize=True)
        return file_buf.getvalue()

In [None]:
plots.select(pl.len()).collect()

In [None]:
# Manual round-trip check on the first stored image.
# Column index 3 is read as the image data here — presumably the "value" column; verify.
test_img = plots.head(1).collect()[0, 3]
png_img = convert_np_to_png(np.array(test_img))
# Store the PNG bytes in a polars frame and read them back.
as_polars = pl.DataFrame({"img": [png_img]})
png_img = as_polars[0, "img"]
reversed_image = Image.open(BytesIO(png_img))
reversed_image.show()

# Same check through map_elements, the route used for the full conversion below.
test_img_map = plots.head(1).collect().with_columns(pl.col("value").map_elements(lambda x: convert_np_to_png(np.array(x, dtype=np.uint8)), return_dtype=pl.Binary))[0, 3]
reversed_image = Image.open(BytesIO(test_img_map))
reversed_image.show()

In [None]:
# very slow
# Re-encode every raw image as PNG bytes and stream the result straight to disk.
(
    plots
    .with_columns(
        pl.col("value").map_elements(
            lambda x: convert_np_to_png(np.array(x, dtype=np.uint8)),
            return_dtype=pl.Binary,
            strategy="threading",
        )
    )
    .sink_parquet("./processed/BA_fixed/plots_png.parquet", compression_level=22)
)


In [None]:
# NOTE(review): the previous cell writes to ./processed/BA_fixed/plots_png.parquet
# while this one reads ./processed/plots_png.parquet — confirm which path is intended.
plots = pl.scan_parquet("./processed/plots_png.parquet")

In [None]:
# One row per (dir_name, step) with one PNG column per tag. Entries whose
# (dir_name, step, tag) key occurs more than once are dropped entirely.
plots = (
    plots
    .unique(subset=["dir_name", "step", "tag"], keep="none")
    .collect()
    .pivot(values=["value"], on=["tag"], index=["dir_name", "step"])
    .sort("dir_name", "step")
)

In [None]:
# Persist the pivoted PNG table, then drop the name so the frame can be freed.
plots.write_parquet("./processed/BA_fixed/plots_png_pivot.parquet", compression_level=22)
del plots

# Inspect data

In [None]:
# Lazily open the processed files for inspection.
# NOTE(review): these paths point at ./processed/ directly, whereas the cells
# above write to output_dir and ./processed/BA_fixed/ — confirm the files were
# moved or adjust the paths.
hparams = pl.scan_parquet("./processed/hparams.parquet")
scalars = pl.scan_parquet("./processed/scores.parquet")
plots = pl.scan_parquet("./processed/plots_png_pivot.parquet")

In [None]:
hparams.collect()

In [None]:
pl.Config.set_tbl_cols(100)
pl.Config.set_tbl_rows(100)


def get_model_settings(model_type: str) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Split the hyperparameters of one model family into varying and constant columns.

    Reads the global ``hparams`` frame, which may be eager or lazy.

    Args:
        model_type: Pattern matched against the ``model_settings`` column with
            ``str.contains``.

    Returns:
        A pair ``(diff, identical)``. ``diff`` holds the columns with more than
        one distinct value, plus ``dir_name`` and ``F-Score`` whenever present,
        sorted by ``F-Score`` in descending order. ``identical`` holds a single
        row with the columns that share one value across all matching runs,
        excluding ``dir_name`` and ``F-Score``.
    """
    # `.lazy()` accepts both DataFrame and LazyFrame; collecting yields the
    # eager frame that `.row()` requires.
    model = hparams.lazy().filter(pl.col("model_settings").str.contains(model_type)).collect()

    # Keep only the columns that contain no null at all.
    complete = model.select(pl.all().is_not_null().all()).row(0)
    model = model.select([name for name, keep in zip(model.columns, complete) if keep])

    distinct_counts = model.select(pl.all().n_unique()).row(0)
    # Identifier and ranking columns always belong to `diff`, so a single
    # matching run still exposes its dir_name and score there.
    identifiers = ("dir_name", "F-Score")
    varying = [
        name
        for name, count in zip(model.columns, distinct_counts)
        if count > 1 or name in identifiers
    ]
    constant = [
        name
        for name, count in zip(model.columns, distinct_counts)
        if count == 1 and name not in identifiers
    ]

    diff = model.select(varying)
    if "F-Score" in diff.columns:
        diff = diff.sort("F-Score", descending=True)
    identical = model.select(constant).limit(1)

    return diff, identical


def get_history(name: str) -> pl.DataFrame:
    """Return the per-step scalars of one run joined with its plot images.

    Reads the global ``scores`` and ``plots`` frames, each of which may be
    eager or lazy.

    Args:
        name: Value of ``dir_name`` identifying the run.

    Returns:
        An eager frame with one row per step present in both inputs, ordered by
        ``step`` so that the last row is the latest step.
    """
    # `.lazy()` works on both DataFrame and LazyFrame, so a mix of eager and
    # lazy inputs can be joined.
    run_scores = scores.lazy().filter(pl.col("dir_name") == name)
    run_plots = plots.lazy().filter(pl.col("dir_name") == name)
    return run_scores.join(run_plots, on="step", how="inner").sort("step").collect()

In [None]:
unique, identical = get_model_settings("mamba")
print(identical)
unique

In [None]:
unique, identical = get_model_settings("attention")
print(identical)
unique

In [None]:
unique, identical = get_model_settings("cnn")
print(identical)
unique

In [None]:
unique, identical = get_model_settings("crnn")
print(identical)
unique

In [None]:
from io import BytesIO

from PIL import Image

# First row of the last `unique` frame computed above.
best = unique.select("dir_name").row(0)[0]
history = get_history(best)
tags = ["Validation/PR-Curve/" ,"Test/RBMA_full/PR-Curve/", "Test/MDB_full/PR-Curve/", "Validation/Threshold-Curve/", "Test/RBMA_full/Threshold-Curve/", "Test/MDB_full/Threshold-Curve/"]
curves = []
for tag in tags:
    # Last non-null entry stored for this tag.
    curve = history.select(pl.col(tag)).filter(pl.all().is_not_null()).row(-1)[0]
    # The pivoted plot file stores PNG-encoded bytes; decode them so that
    # imshow receives pixel data.
    if isinstance(curve, (bytes, bytearray)):
        curve = np.asarray(Image.open(BytesIO(curve)))
    curves.append(curve)
# Show Curves in a grid
fig, axs = plt.subplots(2, 3, figsize=(15, 10))
for ax, curve in zip(axs.flat, curves):
    ax.imshow(curve)
    ax.axis("off")
    # reduce spacing
    ax.margins(0)
    ax.axis("tight")


In [None]:
def get_model_dict(unique: pl.DataFrame, identical: pl.DataFrame) -> list[dict]:
    """Build one complete settings dict per row of ``unique``.

    Args:
        unique: Frame with one row per run, holding the values that differ.
        identical: Frame whose first row holds the values shared by all runs.

    Returns:
        One dict per row of ``unique``, each extended with the shared values.
        Shared values win if a key occurs in both frames.
    """
    # Convert the shared row once rather than once per run; an empty
    # `identical` frame simply contributes nothing.
    shared_rows = identical.to_dicts()
    shared = shared_rows[0] if shared_rows else {}
    return [row | shared for row in unique.to_dicts()]

In [None]:
get_model_dict(unique, identical)[0]

In [None]:
# Get Value
# In notebook order `params` has already been converted with pl.from_pandas, so
# the pandas `.loc` accessor is unavailable. Accept either kind of frame and
# query it with polars; `best` is matched as a literal substring.
_params = params if isinstance(params, pl.DataFrame) else pl.from_pandas(params, nan_to_null=True)
_params.filter(pl.col("dir_name").str.contains(best, literal=True)).select("full_length_test")

# Import data from optuna

In [None]:
import polars as pl
import numpy as np
import optuna
import ipywidgets as widgets
import polars.selectors as cs
from glob import glob

_ = pl.Config.restore_defaults()

In [None]:
# Lazily open the processed logs that the optuna trials are matched against.
hparams = pl.scan_parquet("processed/BA_fixed/hparams.parquet")
scalars = pl.scan_parquet("processed/BA_fixed/scores.parquet")
tensors = pl.scan_parquet("processed/BA_fixed/tensors.parquet")

In [None]:
import os

# Filled in by the widget callbacks below; None until a selection was made.
storage: optuna.storages.BaseStorage = None
study: optuna.Study = None

db_files = glob("./optuna/*.db")
conn_strings = ["sqlite:///" + db for db in db_files]

# Remote storages come from the environment so that credentials never live in
# the notebook, e.g.
#   export OPTUNA_STORAGE_URL="postgresql://<user>:<password>@<host>:5432"
# The password that used to be hard-coded here should be treated as leaked.
remote_url = os.environ.get("OPTUNA_STORAGE_URL")
databases = conn_strings + ([remote_url] if remote_url else [])

db_dropdown = widgets.Dropdown(options=databases, description="Database: ")

def set_storage(connection_string: str):
    """Connect to the selected database and keep the handle in the global ``storage``."""
    global storage
    if not connection_string:
        # Nothing to select: no local .db files and no remote URL configured.
        storage = None
        return
    storage = optuna.storages.RDBStorage(
        url=connection_string,
        engine_kwargs={"pool_pre_ping": True, "pool_recycle": 3600, "pool_timeout": 3600},
        heartbeat_interval=60,
        grace_period=3600,
    )

widgets.interact(set_storage, connection_string=db_dropdown)

study_dropdown = widgets.Dropdown(options=[], description="Study: ")

def update_options(*args):
    """Refresh the study dropdown with the studies found in the current storage."""
    if storage is None:
        study_dropdown.options = []
        return
    studies = storage.get_all_studies()
    study_dropdown.options = [stdy.study_name for stdy in studies]
db_dropdown.observe(update_options, "value")
# Observers fire on changes only, so fill the list once for the initial selection.
update_options()

def select_study(selected_study):
    """Load the selected study into the global ``study``."""
    global study
    if storage is None or selected_study is None:
        study = None
        return
    study = optuna.load_study(storage=storage, study_name=selected_study)

_ = widgets.interact(select_study, selected_study=study_dropdown)

In [None]:
# Load all trials of the selected study and keep the completed ones, ordered by
# the column(s) whose name contains "value", best first.
study_data = study.trials_dataframe()
study_data = pl.from_pandas(study_data).lazy()
study_data = study_data.filter(pl.col("state") == "COMPLETE").sort(cs.contains("value"), descending=True)
num_trials = study_data.select(pl.len()).collect()[0, 0]
print(num_trials)

# Schema of the trial parameters with the "params_" prefix stripped.
optuna_param_list = study_data.select(cs.contains("params_").name.map(lambda s: s.removeprefix("params_"))).collect_schema()

study_data = (
    study_data
    # Strip the "params_" prefix; parameter columns come first.
    .select(cs.contains("params_").name.map(lambda s: s.removeprefix("params_")), ~cs.contains("params_"))
    # Drop system attributes and strip the "user_attrs_" prefix from user attributes.
    .select(~(cs.contains("user") | cs.contains("system_attrs")), cs.contains("user").name.map(lambda x: x.removeprefix("user_attrs_")))
    # Prefix every column containing "f_score" with "optuna_".
    .select(~cs.contains("f_score"), cs.contains("f_score").name.prefix("optuna_"))
    .drop("state", "number")
    .with_columns(pl.lit(study.study_name).alias("study_name"))
)
# NOTE(review): integer_values is not defined in this section; it comes from the
# first section of the notebook.
study_data = study_data.with_columns(pl.col(list(set(integer_values) & set(study_data.collect_schema().names()))).cast(pl.Int64))
logs_param_list = hparams.collect_schema()
# Parameters that exist both in the optuna trials and in the logged hyperparameters.
matched_params = [key for key, dtype in optuna_param_list.items() if key in logs_param_list.keys()]

# matched_params.remove("expansion_factor")
# matched_params.remove("hidden_units")
print(logs_param_list)
matched_params

In [None]:
hparams.collect()

In [None]:
study_data.collect()

In [None]:
from itertools import combinations

# Candidate join keys: every non-empty subset of the shared parameters, largest
# subsets first.
res = reversed([list(combinations(matched_params, r)) for r in range(1, len(matched_params) + 1)])
res = [list(sublist) for g in res for sublist in g]

# String hyperparameters (except dir_name) are lower-cased before comparing
# them with the optuna values.
hparams_lowercase = hparams.with_columns(cs.by_dtype(pl.String).exclude("dir_name").str.to_lowercase())

full_matched_params = None

# Find the first subset on which every trial is matched by at least one log.
for comb in res:
    joined = hparams_lowercase.join(study_data, on=comb, how="inner", nulls_equal=True)
    unique_matches = joined.unique(subset="datetime_start").select(pl.len()).collect()[0, 0]
    if unique_matches == num_trials:
        full_matched_params = comb
        print(full_matched_params)
        break

# Fail with a clear message instead of a TypeError on `None` further down.
if full_matched_params is None:
    raise ValueError("No matching subset was found. Are logs for all the experiments present?")

# full_matched_params = comb
mismatched_params = [param for param in matched_params if param not in full_matched_params]

# Pair every log with every trial, then keep the pairs whose log started within
# 30 minutes after the trial start. A cross join takes no join keys, so none
# are passed.
joined = hparams_lowercase.join(study_data, how="cross").collect()
joined = joined[[s.name for s in joined if not (s.null_count() == joined.height)]]
joined = joined.filter((abs((pl.col("start_time") - pl.col("datetime_start"))) < pl.duration(minutes=30)) & (pl.col("start_time") > pl.col("datetime_start")))
# Where the log side of a mismatched parameter is gone, fall back to the optuna
# side by stripping its "_right" suffix.
useful_cols = [col + "_right" for col in mismatched_params if col not in joined.columns]
if len(useful_cols) > 0:
    joined = joined.select(cs.exclude(useful_cols), pl.col(useful_cols).name.map(lambda s: s.removesuffix("_right")))
joined = joined.select(~cs.contains("_right")) # parameters that are mismatched are most likely due to a param being generated but not assigned
joined = joined[[s.name for s in joined if not (s.null_count() == joined.height)]]
joined

# sub = joined.select("dir_name", "datetime_start", cs.contains(*mismatched_params))
# sorted_mis = sorted(sub.columns)
# sorted_mis.remove("datetime_start")
# sorted_mis.remove("dir_name")
# duplicated_dates = sub.group_by("datetime_start").agg(pl.len().alias("count")).filter(pl.col("count") > 1).select("datetime_start").to_series().to_list()
# sub.select("dir_name", "datetime_start", *sorted_mis).sort("datetime_start").filter(pl.col("datetime_start").is_in(duplicated_dates))

In [None]:
# joined = hparams.with_columns(cs.by_dtype(pl.String).exclude("dir_name").str.to_lowercase()).with_columns(pl.duration(minutes=2).alias("tolerance")).join_where(study_data, abs((pl.col("start_time") - pl.col("datetime_start"))) < pl.col("tolerance"))
# Trials whose start time does not appear in `joined`, i.e. trials without a matched log.
study_data.filter(~(pl.col("datetime_start").is_in(joined.select(pl.col("datetime_start")).to_series()))).collect()

In [None]:
import pathlib

# Append the matched trials of this study to the collection file, creating the
# file on first use.
if not pathlib.Path("./processed/optuna.parquet").exists():
    joined.write_parquet("./processed/optuna.parquet")
else:
    optuna_collection = pl.read_parquet("./processed/optuna.parquet")
    optuna_cols = set(optuna_collection.columns)
    optuna_schema = optuna_collection.schema
    joined_cols = set(joined.columns)
    joined_schema = joined.schema
    # missing_cols: present in `joined` but not in the stored collection.
    missing_cols = joined_cols - optuna_cols
    # missing_schema: present in the stored collection but not in `joined`.
    missing_schema = optuna_cols - joined_cols
    print(missing_cols)
    print(missing_schema)
    # Add the missing columns on either side as typed all-null columns so that
    # both frames end up with the same set of columns.
    for col in missing_cols:
        optuna_collection = optuna_collection.with_columns(pl.lit(None).alias(col).cast(joined.schema[col]))
    for col in missing_schema:
        joined = joined.with_columns(pl.lit(None).alias(col).cast(optuna_collection.schema[col]))
    # Align the column order with the stored collection before stacking.
    joined = joined.select(pl.col(optuna_collection.columns))
    # NOTE(review): no de-duplication happens here, so re-running this cell
    # appends the same rows again; the next cell de-duplicates by dir_name.
    out = pl.concat([optuna_collection, joined], how="vertical_relaxed")
    out.write_parquet("./processed/optuna.parquet")

    # optuna_collection.write_parquet("./processed/optuna.parquet")



In [None]:
# Clean up the collection file in place: drop columns that are entirely null
# and keep a single row per dir_name.
optuna_collection = pl.read_parquet("./processed/optuna.parquet")
optuna_collection = optuna_collection[[s.name for s in optuna_collection if not (s.null_count() == optuna_collection.height)]]
# keep="any" does not guarantee which of several duplicate rows survives.
optuna_collection.unique(subset="dir_name", keep="any").write_parquet("./processed/optuna.parquet", compression="zstd")

In [None]:
hparams.filter((pl.col("dataset_version") == "M") | (pl.col("dataset_version") == "L")).sort("F-Score", descending=True).collect()

In [None]:
# Value columns of each input, i.e. everything except the join keys.
scalar_cols = set(scalars.collect_schema().names()) - {"dir_name", "step"}
# NOTE(review): `plots` is not loaded in this section's input cell; it must
# still be defined from an earlier section.
plots_cols = set(plots.collect_schema().names()) - {"dir_name", "step"}
# data points that get lost in the join are most likely due to ^C exit
joined_plots = scalars.join(plots, on=["dir_name", "step"], how="inner").sort(["dir_name", "step"])
joined_plots.sink_parquet("./processed/plots_with_scalars.parquet", compression="zstd", compression_level=22)



# View Data

In [None]:
from functools import lru_cache

import polars as pl
import numpy as np
import optuna
import ipywidgets as widgets
import polars.selectors as cs
import seaborn as sns
import matplotlib.pyplot as plt
from PIL import Image
import io

In [None]:
pl.Config.set_tbl_hide_column_data_types(True)
pl.Config.set_tbl_hide_dataframe_shape(True)
pl.Config.set_tbl_cols(17)
pl.Config.set_tbl_rows(20)

def drop_columns_that_are_all_null(_df: pl.DataFrame) -> pl.DataFrame:
    """Return ``_df`` restricted to the columns that hold at least one non-null value."""
    n_rows = _df.height
    kept = []
    for series in _df:
        # A column is dropped when its null count equals the number of rows.
        if series.null_count() != n_rows:
            kept.append(series.name)
    return _df[kept]

assert integer_values
def sort_columns(_df: pl.DataFrame) -> pl.DataFrame:
    """Reorder the columns of ``_df`` into a fixed, readable layout.

    The order is: ``dir_name``, ``F-Score``, ``flops``, ``params``; then all
    columns containing "core" and all columns containing "oss" that are not
    known settings (each group sorted by name); then the known settings sorted
    by name; finally every remaining column in its original order. Names that
    do not exist in ``_df`` are skipped.

    Reads the global lists ``integer_values``, ``boolean_values``,
    ``string_values`` and ``float_values`` to decide what is a known setting.

    Args:
        _df: Frame whose columns should be reordered.

    Returns:
        A frame with the same columns as ``_df`` in the order described above.
    """
    parameters = set(integer_values) | set(boolean_values) | set(string_values) | set(float_values)
    present = set(_df.columns)
    scores = sorted(set(_df.select(cs.contains("core")).columns) - parameters)
    losses = sorted(set(_df.select(cs.contains("oss")).columns) - parameters)
    known_parameters = sorted(parameters & present)
    # The remaining columns are appended in frame order rather than set order,
    # so the result is the same on every run.
    preferred = ["dir_name", "F-Score", "flops", "params", *scores, *losses, *known_parameters, *_df.columns]

    # dict.fromkeys removes duplicates while keeping each first occurrence.
    ordered = [col for col in dict.fromkeys(preferred) if col in present]
    return _df.select(ordered)

In [None]:
# Lazily open the inputs of the viewer: matched optuna/log hyperparameters,
# pivoted PNG plots and per-step scalars.
hparams = pl.scan_parquet("./processed/optuna.parquet")
plots = pl.scan_parquet("./processed/plots_png_pivot.parquet")
scalars = pl.scan_parquet("./processed/scores.parquet")

In [None]:
# Filled by set_model_settings once a model type has been selected.
model_settings: pl.DataFrame = None
# select model type
# Distinct values of the model_settings column populate the dropdown.
model_types = hparams.select(pl.col("model_settings").unique()).collect().to_series().to_list()
model_selector = widgets.Dropdown(options=model_types, description="Model Type: ")

def set_model_settings(model_type: str):
    """Rebuild the global ``model_settings`` table for one model type.

    For every run the best value of each score column (maximum) and each loss
    column (minimum) is computed together with the step at which it occurred
    (``<column>_step``). These are joined onto the hyperparameters of the runs
    whose ``model_settings`` equals ``model_type``. The result is sorted by
    ``F-Score`` in descending order, stripped of all-null columns and put
    into display order.

    Args:
        model_type: Exact value of the ``model_settings`` column to keep.
    """
    global model_settings
    score_columns = scalars.select(cs.contains("core")).collect_schema().names()
    loss_columns = scalars.select(cs.contains("oss")).collect_schema().names()

    # Step at which each score peaks / each loss bottoms out, per run.
    best_score_steps = [
        pl.col("step").get(pl.col(name).arg_max()).over("dir_name").alias(name + "_step")
        for name in score_columns
    ]
    best_loss_steps = [
        pl.col("step").get(pl.col(name).arg_min()).over("dir_name").alias(name + "_step")
        for name in loss_columns
    ]

    best_values = scalars.select(
        pl.col("dir_name"),
        cs.contains("core").max().over("dir_name"),
        *best_score_steps,
        cs.contains("oss").min().over("dir_name"),
        *best_loss_steps,
    ).unique("dir_name")

    runs = hparams.filter(pl.col("model_settings") == model_type).collect()
    runs = (
        runs.lazy()
        .join(best_values, on=pl.col("dir_name"))
        .sort("F-Score", descending=True)
        .collect()
    )
    model_settings = sort_columns(drop_columns_that_are_all_null(runs))


# set_model_settings("crnn")


_ = widgets.interact(set_model_settings, model_type=model_selector)

In [None]:
model_settings.sort("dir_name")

In [None]:
# Number of leading rows of model_settings offered in the run dropdown.
selectable_range = 20


# dir_name of the first `selectable_range` rows of model_settings.
run_options = model_settings.head(selectable_range).select("dir_name").to_series().to_list()
# NOTE(review): run_options[0] raises IndexError when model_settings has no rows.
run_selector = widgets.Dropdown(options=run_options, value=run_options[0])

# Name of the run currently being plotted; plot_run overwrites it.
run = run_options[0]

@lru_cache
def get_run(name: str):
    """Collect the summary tables and the loss / F-score histories of one run.

    Reads the globals ``model_settings`` and ``scalars``.
    NOTE(review): the cache is keyed on ``name`` only, so a cached result is
    returned even if those globals have changed since the first call.

    Args:
        name: Value of ``dir_name`` identifying the run.

    Returns:
        A tuple ``(rows, losses, f_scores)``: ``rows`` is a list of four
        one-run frames (overview, best scores, steps of the best scores, model
        parameters); ``losses`` and ``f_scores`` are long-format frames with
        one row per step and curve.
    """
    settings = model_settings.filter(pl.col("dir_name") == name)
    settings = drop_columns_that_are_all_null(settings)
    # Model-settings fields (any of int/float/str/bool) that are present for this run.
    model_params = sorted(set([param for typ in [int, float, str, bool] for param in get_columns_with_type(typ, True)]) & set(settings.columns))
    rows = [
        # First four columns plus the dataset version as a short overview.
        settings.select(pl.col(*settings.columns[:4], "dataset_version")),
        # Best value of every tagged score column ...
        settings.select(cs.contains("/") & cs.contains("core") & ~cs.contains("_step")),
        # ... and the step at which it was reached.
        settings.select(cs.contains("/") & cs.contains("core") & cs.contains("_step")),
        settings.select(pl.col(model_params)),
    ]

    values = scalars.lazy().filter(pl.col("dir_name") == name).sort("step")
    losses = values.select(pl.col("step"), cs.by_dtype(pl.Float64) & cs.contains("oss")).collect()
    losses = (
        drop_columns_that_are_all_null(losses)
            .lazy()
            # Long format: one row per (step, loss column).
            .unpivot(cs.contains("oss"), index="step", variable_name="tag")
            # Split the column name on "/" into up to three named parts.
            .with_columns(pl.col("tag").str.split("/").list.to_struct(n_field_strategy="max_width", fields=["score", "split", "tag"]))
            .unnest("tag")
            .drop("score")
            # Parts missing from shorter names are filled with the split name.
            .with_columns(pl.col("split", "tag").fill_null(pl.col("split")))
            .collect()
    )
    f_scores = values.select("step", cs.by_dtype(pl.Float64) & cs.contains("core")).collect()

    f_scores = (
        drop_columns_that_are_all_null(f_scores)
            .lazy()
            # Long format: one row per (step, score column).
            .unpivot(cs.contains("core"), index="step", variable_name="tag")
            # Split the column name on "/" into up to four named parts.
            .with_columns(pl.col("tag").str.split("/").list.to_struct(n_field_strategy="max_width", fields=["score", "type", "split", "tag"]))
            .unnest("tag")
            .drop("score")
            # Parts missing from shorter names are filled with the split name.
            .with_columns(pl.col("type", "split", "tag").fill_null(pl.col("split")))
            .collect()
    )
    return rows, losses, f_scores

def plot_run(name: str):
    """Print the summary tables of a run and plot its loss and F-score curves.

    Also records ``name`` in the global ``run`` and updates the upper bound of
    the step slider to the last step logged for that run.

    Args:
        name: Value of ``dir_name`` identifying the run.
    """
    global run
    run = name
    last_step = scalars.filter(pl.col("dir_name") == run).select(pl.col("step").max()).collect().to_series()[0]
    step_selector.max = last_step

    rows, losses, f_scores = get_run(name)
    for row in rows:
        print(row)

    fig, (loss_ax, score_ax) = plt.subplots(figsize=(15, 6), ncols=2)

    # Left panel: losses on a logarithmic axis.
    sns.lineplot(data=losses, x="step", y="value", hue="tag", ax=loss_ax, style="split")
    loss_ax.set_title(f"Losses for {name}")
    loss_ax.set_xlabel("Step")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_yscale('log')

    # Right panel: F-scores on a fixed 0..1 axis with ticks every 0.1.
    sns.lineplot(data=f_scores, x="step", y="value", hue="tag", ax=score_ax, style="type")
    score_ax.set_title(f"F-Scores for {name}")
    score_ax.set_xlabel("Step")
    score_ax.set_ylabel("F-Score")
    score_ax.set_ylim(0, 1)
    score_ax.set_yticks(np.arange(0, 1.1, 0.1))

    plt.show()


# plot_run(run)


def update_run_options(*args):
    """Refill the run dropdown from the first rows of the current ``model_settings``."""
    leading_runs = model_settings.head(selectable_range).select("dir_name")
    run_selector.options = leading_runs.to_series().to_list()
model_selector.observe(update_run_options, "value")

@lru_cache
def get_plots(run, step):
    """Return the PNG images of ``run`` as of ``step``.

    Missing images are forward-filled along the step axis, so the row for
    ``step`` shows the most recent image of every tag. Only binary columns are
    returned, and columns without any image are dropped.

    Args:
        run: Value of ``dir_name`` identifying the run.
        step: Step whose row should be returned.
    """
    run_plots = plots.lazy().filter(pl.col("dir_name") == run).sort("step")
    filled = run_plots.select(pl.all().fill_null(strategy="forward"))
    at_step = (
        filled
        .filter(pl.col("step") == step)
        .select(cs.by_dtype(pl.Binary))
        .collect()
    )
    return drop_columns_that_are_all_null(at_step)

def plot_step(step: int):
    """Show the PR and threshold curve images of the selected run at ``step``.

    The figure has one row per PR-curve image and two columns: PR curve on the
    left, threshold curve on the right. Cells without an image stay blank.
    Nothing is drawn when no PR-curve image is available.

    Args:
        step: Step whose images should be shown.
    """
    step_plots = get_plots(run_selector.value, step)
    pr_plots = step_plots.select(cs.contains("PR-Curve"))
    threshold_plots = step_plots.select(cs.contains("Threshold"))
    to_plot = [pr_plots, threshold_plots]
    n_tags = pr_plots.shape[1]
    if n_tags == 0 or step_plots.height == 0:
        return
    # squeeze=False always yields a 2-D (rows, columns) axes array, so a single
    # row needs no special handling.
    _fig, _axs = plt.subplots(figsize=(10, 12), ncols=2, nrows=n_tags, squeeze=False)
    for _i, ax_row in enumerate(_axs):
        for j, _ax in enumerate(ax_row):
            # One image family may have fewer columns than the other; leave
            # the cell empty rather than indexing past the end.
            png = to_plot[j][0, _i] if _i < to_plot[j].shape[1] else None
            if png is None:
                _ax.axis("off")
                continue
            with io.BytesIO(png) as png_file:
                img = np.asarray(Image.open(png_file, formats=["PNG"]))
            _ax.imshow(img)
            _ax.axis("off")
            # reduce spacing
            _ax.margins(0)
            # _ax.axis("tight")

    plt.show()

# plot_step(0)

# Last step logged for the initially selected run; upper bound of the slider.
max_step = (
    scalars
    .filter(pl.col("dir_name") == run)
    .select(pl.col("step").max())
    .collect()
    .to_series()[0]
)
step_selector = widgets.IntSlider(min=0, max=max_step)

def update_max_steps(*args):
    """Set the slider's upper bound to the last step logged for the global ``run``."""
    last_step = scalars.filter(pl.col("dir_name") == run).select(pl.col("step").max())
    step_selector.max = last_step.collect().to_series()[0]
model_selector.observe(update_max_steps, "value")
run_selector.observe(update_max_steps, "value")

def plot_all(name: str, step: int):
    """Print the overview of the selectable runs, then plot one run and one step.

    Args:
        name: Value of ``dir_name`` identifying the run to plot.
        step: Step whose curve images should be shown.
    """
    overview_columns = [*model_settings.columns[:4], "dataset_version"]
    print(model_settings.head(selectable_range).select(pl.col(*overview_columns)))
    # print(model_settings.head(selectable_range).select(cs.contains("Test")))
    plot_run(name)
    plot_step(step)

# plot_all(run , 0)

# Wire both selectors to plot_all; the last child of the interactive container
# is its output area.
run_plot = widgets.interactive(plot_all, name=run_selector, step=step_selector)
output = run_plot.children[-1]
# Fixed output height for the rendered figures.
output.layout.height = '2400px'

run_plot


# Evaluate hyperparameter validation

In [1]:
import polars as pl
import polars.selectors as cs
import matplotlib.pyplot as plt
import seaborn as sns
from settings import TrainingSettings, DatasetSettings, EvaluationSettings, CNNSettings, CNNMambaSettings, CNNAttentionSettings, CRNNSettings, asdict, Config

In [13]:
def get_columns_with_type(typ, get_only_model_settings=False) -> list[str]:
    """Collect the names of settings fields whose default value has exactly type ``typ``.

    Every settings class is instantiated, converted with ``asdict`` and scanned
    for values whose concrete type is ``typ``. The comparison is an identity
    check on the type, so subclasses do not match.

    Args:
        typ: The exact Python type to look for.
        get_only_model_settings: When true, skip the training, dataset and
            evaluation settings and inspect only the model settings classes.

    Returns:
        The matching field names without duplicates, in no particular order.
    """
    instances = []
    if not get_only_model_settings:
        # General settings are built without constructor arguments.
        instances.extend(cls() for cls in (TrainingSettings, DatasetSettings, EvaluationSettings))
    # Model settings are built with two positional constructor arguments.
    instances.extend(cls(3, 84) for cls in (CNNSettings, CNNAttentionSettings, CNNMambaSettings, CRNNSettings))

    matching = {
        name
        for instance in instances
        for name, value in asdict(instance).items()
        if type(value) is typ
    }
    return list(matching)

def drop_columns_that_are_all_null(_df: pl.DataFrame) -> pl.DataFrame:
    """Return ``_df`` restricted to the columns that hold at least one non-null value."""
    n_rows = _df.height
    kept = []
    for series in _df:
        # A column is dropped when its null count equals the number of rows.
        if series.null_count() != n_rows:
            kept.append(series.name)
    return _df[kept]

def get_settings(param_row: pl.DataFrame):
    """Turn one hyperparameter row into a flat settings dict and verify it round-trips.

    Columns without a value are dropped and only known settings fields are
    kept. Fields that do not apply to the row's model type are removed. The
    dict is then loaded with ``Config.from_flat_dict`` and every entry is
    compared with the resulting configuration.

    Args:
        param_row: Frame whose first row holds the hyperparameters of one run.

    Returns:
        The flat settings dict.

    Raises:
        AssertionError: If a value differs after loading it into ``Config``.
    """
    settings = drop_columns_that_are_all_null(param_row)
    known_names = {
        name
        for typ in (int, float, str, bool)
        for name in get_columns_with_type(typ, False)
    }
    model_params = sorted(known_names & set(settings.columns))
    _dict = settings.select(pl.col(model_params)).to_dicts()[0]

    # Remove fields that do not apply to this model type.
    model_kind = _dict["model_settings"]
    if model_kind == "crnn":
        _dict.pop("hidden_units", None)
    if model_kind in ["crnn", "mamba", "mamba_fast"]:
        _dict.pop("use_relative_pos", None)
    if model_kind.startswith("mamba"):
        _dict.pop("expansion_factor", None)

    config = Config.from_flat_dict(_dict)
    reversed_settings = {}
    for part in (config.training, config.evaluation, config.dataset):
        reversed_settings.update(asdict(part))
    if config.model is not None:
        reversed_settings.update(asdict(config.model))

    # These fields are compared through their string form only.
    compared_as_text = ("activation", "mapping", "splits", "test_sets")
    for key, item in _dict.items():
        if key in compared_as_text:
            assert str(item) == str(reversed_settings[key])
            # print(f"{item} == {reversed_settings[key]}: Please check manually if {key} is equal")
            continue
        assert item == reversed_settings[key], f"Key {key} is mismatched: {item}({type(item)}) != {reversed_settings[key]}({type(reversed_settings[key])})"
    return _dict

In [3]:
output_dir = "processed/CRNN/Params validation"

In [4]:
# One row per run with its hyperparameters and F-Score.
hparams = pl.read_parquet(f"{output_dir}/hparams.parquet")
# Keep only the rows without an "F-Score" entry, then drop that column.
scores = pl.read_parquet(f"{output_dir}/scores.parquet").filter(pl.col("F-Score").is_null()).drop("F-Score")

In [12]:
hparams

activation,batch_size,beats,beta_1,beta_2,causal,center,channel_multiplication,classifier_dim,dataset_version,decoupled_weight_decay,detect_tolerance,down_sample_factor,dropout,early_stopping,ema,epochs,epsilon,eval_set,fft_size,flux,frame_length,frame_overlap,full_length_test,hop_size,ignore_beats,label_lead_in,label_lead_out,learning_rate,mapping,mel_max,mel_min,min_save_score,min_test_score,model_settings,n_mels,normalize,num_channels,num_conv_layers,num_rnn_layers,num_workers,onset_cooldown,pad_annotations,pad_mode,pad_value,peak_max_range,peak_mean_range,positive_weight,power,pr_points,rnn_units,sample_rate,scheduler,seed,segment_type,splits,test_batch_size,test_sets,time_shift,train_set,weight_decay,F-Score,dir_name,start_time
str,i64,bool,f64,f64,bool,bool,i64,i64,str,bool,f64,i64,f64,str,bool,i64,f64,str,i64,bool,f64,f64,bool,i64,bool,f64,f64,f64,str,f64,f64,f64,f64,str,i64,bool,i64,i64,i64,i64,f64,bool,str,f64,i64,i64,f64,i64,i64,i64,i64,bool,f64,str,str,i64,str,f64,str,f64,f64,str,datetime[ns]
"""ReLU""",37,false,0.9,0.999,true,true,3,31,"""S""",true,0.025,4,0.35,"""5""",false,30,1.0000e-8,"""A2MD""",2048,true,2.3,1.1,true,441,true,0.25,0.1,0.003421,"""Three class standard""",20000.0,20.0,0.7,0.6,"""crnn""",84,false,16,2,3,16,0.021,true,"""constant""",0.25,2,2,1.0,1,100,61,44100,false,0.0,"""frame""","""[0.75, 0.25, 0.0]""",1,"""('RBMA', 'MDB')""",0.035,"""a2md_train""",6.0639e-7,0.68648,"""Jun30_16-53-32_marclie-desktop""",2025-06-30 16:53:32
"""ReLU""",7,false,0.9,0.999,true,true,1,32,"""S""",true,0.025,1,0.3,"""5""",false,30,1.0000e-8,"""A2MD""",2048,true,3.9,1.1,true,441,true,0.25,0.1,0.00362,"""Three class standard""",20000.0,20.0,0.7,0.6,"""crnn""",96,false,32,0,2,16,0.021,true,"""constant""",0.25,2,2,1.0,1,100,23,44100,false,1.0,"""frame""","""[0.75, 0.25, 0.0]""",1,"""('RBMA', 'MDB')""",0.035,"""a2md_train""",1.1095e-15,0.655962,"""Jun30_16-58-38_marclie-desktop""",2025-06-30 16:58:38
"""ELU""",6,false,0.9,0.999,true,true,1,33,"""S""",true,0.025,4,0.35,"""5""",false,30,1.0000e-8,"""A2MD""",2048,true,3.5,1.2,true,441,true,0.25,0.1,0.000779,"""Three class standard""",20000.0,20.0,0.7,0.6,"""crnn""",84,false,16,0,4,16,0.021,true,"""constant""",0.25,2,2,1.0,1,100,70,44100,false,2.0,"""frame""","""[0.75, 0.25, 0.0]""",1,"""('RBMA', 'MDB')""",0.035,"""a2md_train""",2.0121e-7,0.72142,"""Jun30_17-02-41_marclie-desktop""",2025-06-30 17:02:41
"""ELU""",12,false,0.9,0.999,true,true,1,45,"""S""",true,0.025,4,0.5,"""5""",false,30,1.0000e-8,"""A2MD""",2048,true,4.1,1.2,true,441,true,0.25,0.1,0.004734,"""Three class standard""",20000.0,20.0,0.7,0.6,"""crnn""",84,false,16,1,3,16,0.021,true,"""constant""",0.25,2,2,1.0,1,100,55,44100,false,3.0,"""frame""","""[0.75, 0.25, 0.0]""",1,"""('RBMA', 'MDB')""",0.035,"""a2md_train""",8.0411e-8,0.724782,"""Jun30_17-05-45_marclie-desktop""",2025-06-30 17:05:45
"""ELU""",8,false,0.9,0.999,true,true,1,19,"""S""",true,0.025,3,0.3,"""5""",false,30,1.0000e-8,"""A2MD""",2048,true,3.5,1.3,true,441,true,0.25,0.1,0.000469,"""Three class standard""",20000.0,20.0,0.7,0.6,"""crnn""",84,false,24,1,1,16,0.021,true,"""constant""",0.25,2,2,1.0,1,100,100,44100,false,4.0,"""frame""","""[0.75, 0.25, 0.0]""",1,"""('RBMA', 'MDB')""",0.035,"""a2md_train""",2.9620e-8,0.716706,"""Jun30_17-08-24_marclie-desktop""",2025-06-30 17:08:24
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""ELU""",8,false,0.9,0.999,true,true,1,46,"""S""",true,0.025,4,0.45,"""5""",false,30,1.0000e-8,"""A2MD""",2048,true,4.9,1.5,true,441,true,0.25,0.1,0.004061,"""Three class standard""",20000.0,20.0,0.7,0.6,"""crnn""",84,false,16,1,3,16,0.021,true,"""constant""",0.25,2,2,1.0,1,100,109,44100,false,45.0,"""frame""","""[0.75, 0.25, 0.0]""",1,"""('RBMA', 'MDB')""",0.035,"""a2md_train""",5.9861e-9,0.666329,"""Jun30_19-58-26_marclie-desktop""",2025-06-30 19:58:26
"""SELU""",4,false,0.9,0.999,true,true,2,16,"""S""",true,0.025,4,0.0,"""5""",false,30,1.0000e-8,"""A2MD""",2048,true,7.0,1.4,true,441,true,0.25,0.1,0.000677,"""Three class standard""",20000.0,20.0,0.7,0.6,"""crnn""",64,false,24,0,5,16,0.021,true,"""constant""",0.25,2,2,1.0,1,100,114,44100,false,46.0,"""frame""","""[0.75, 0.25, 0.0]""",1,"""('RBMA', 'MDB')""",0.035,"""a2md_train""",3.3012e-12,0.729093,"""Jun30_20-02-54_marclie-desktop""",2025-06-30 20:02:54
"""ELU""",5,false,0.9,0.999,true,true,1,38,"""S""",true,0.025,4,0.3,"""5""",false,30,1.0000e-8,"""A2MD""",2048,true,2.2,1.0,true,441,true,0.25,0.1,0.000397,"""Three class standard""",20000.0,20.0,0.7,0.6,"""crnn""",128,false,24,0,3,16,0.021,true,"""constant""",0.25,2,2,1.0,1,100,94,44100,false,47.0,"""frame""","""[0.75, 0.25, 0.0]""",1,"""('RBMA', 'MDB')""",0.035,"""a2md_train""",1.1468e-8,0.701014,"""Jun30_20-07-04_marclie-desktop""",2025-06-30 20:07:04
"""ELU""",9,false,0.9,0.999,true,true,1,35,"""S""",true,0.025,3,0.2,"""5""",false,30,1.0000e-8,"""A2MD""",2048,true,4.2,1.7,true,441,true,0.25,0.1,0.000533,"""Three class standard""",20000.0,20.0,0.7,0.6,"""crnn""",84,false,32,0,2,16,0.021,true,"""constant""",0.25,2,2,1.0,1,100,64,44100,false,48.0,"""frame""","""[0.75, 0.25, 0.0]""",1,"""('RBMA', 'MDB')""",0.035,"""a2md_train""",9.5536e-10,0.736903,"""Jun30_20-10-49_marclie-desktop""",2025-06-30 20:10:49


In [14]:
# Runs that share the same learning_rate form one group; groups are numbered in
# order of their first dir_name.
group_index = (
    hparams
    .sort("dir_name")
    .group_by("learning_rate", maintain_order=True)
    .agg("dir_name")
    .drop("learning_rate")
    .with_row_index("group")
    .explode("dir_name")
)
params = hparams.join(group_index, on="dir_name")
# Tag every per-step score row with the group of its run.
grouped = scores.with_columns(
    pl.col("dir_name")
    .replace_strict(group_index[:, "dir_name"], group_index[:, "group"], return_dtype=pl.Int16)
    .alias("group")
)
# Best value per run first, then mean and standard deviation across each group.
aggregated = (
    grouped
    .group_by("dir_name", "group")
    .agg(pl.col("step").max(), cs.contains("core").max())
    .group_by("group")
    .agg(pl.exclude("dir_name").mean(), pl.exclude("dir_name").std().name.suffix("_std"))
)
# The three groups with the highest mean validation F-score.
top_group = aggregated.sort("F-Score/Sum/Validation", descending=True)[:3, "group"].to_list()
top_group

[3, 8, 9]

In [16]:
# NOTE(review): the result of this filter is neither assigned nor displayed.
grouped.filter(pl.col("group").is_in(top_group))
# Hyperparameters of the top groups without the per-run columns, de-duplicated.
selected_params = params.filter(pl.col("group").is_in(top_group)).drop("dir_name", "start_time", "F-Score", "seed").unique()
# One validated settings dict per top group, in ranking order.
final_params = [get_settings(selected_params.filter(pl.col("group") == group)) for group in top_group]
print(final_params)

[{'activation': 'ELU', 'batch_size': 12, 'beats': False, 'beta_1': 0.9, 'beta_2': 0.999, 'causal': True, 'center': True, 'channel_multiplication': 1, 'classifier_dim': 45, 'dataset_version': 'S', 'decoupled_weight_decay': True, 'detect_tolerance': 0.025, 'down_sample_factor': 4, 'dropout': 0.5, 'early_stopping': '5', 'ema': False, 'epochs': 30, 'epsilon': 1e-08, 'eval_set': 'A2MD', 'fft_size': 2048, 'flux': True, 'frame_length': 4.1, 'frame_overlap': 1.2, 'full_length_test': True, 'hop_size': 441, 'ignore_beats': True, 'label_lead_in': 0.25, 'label_lead_out': 0.1, 'learning_rate': 0.004733562104434663, 'mapping': 'Three class standard', 'mel_max': 20000.0, 'mel_min': 20.0, 'min_save_score': 0.7, 'min_test_score': 0.6, 'model_settings': 'crnn', 'n_mels': 84, 'normalize': False, 'num_channels': 16, 'num_conv_layers': 1, 'num_rnn_layers': 3, 'num_workers': 16, 'onset_cooldown': 0.021, 'pad_annotations': True, 'pad_mode': 'constant', 'pad_value': 0.25, 'peak_max_range': 2, 'peak_mean_range

# Get Settings of runs of interest

In [None]:
from settings import Config
import polars as pl
from copy import deepcopy

In [None]:
# Lazily scan the processed hyperparameter table (presumably the Optuna study export,
# judging by the file name). Rows are only materialised when `.collect()` is called.
hparams = pl.scan_parquet("./processed/optuna.parquet")

In [None]:
# Human-readable label -> run directory name. The directory name is matched against
# the `dir_name` column of `hparams` when the settings are looked up.
runs_of_interest = {
    "Attention best": "Feb14_00-01-28_marclie-desktop",
    "Attention faster": "Feb15_15-25-56_marclie-desktop",
    "Attention no conv": "Feb16_11-25-37_marclie-desktop",
    "CRNN best": "Feb04_02-45-50_seppel-liemarce",
    "CRNN small": "Feb01_17-55-46_seppel-liemarce",
    "CRNN no conv": "Feb12_22-38-38_seppel-liemarce",
    "Mamba best": "Feb25_20-22-44_seppel-liemarce",
    "Mamba fast": "Feb22_15-27-14_marclie-desktop",
    "Mamba no conv": "Feb26_17-41-56_seppel-liemarce",
}

In [None]:
# Reset polars' global configuration to its defaults; the return value is deliberately discarded.
_ = pl.Config.restore_defaults()

def drop_columns_that_are_all_null(_df: pl.DataFrame) -> pl.DataFrame:
    """Return ``_df`` restricted to the columns that are not null in every row.

    A column is dropped when its null count equals the frame height, so for a frame
    with zero rows every column is dropped.
    """
    n_rows = _df.height
    kept_names = []
    for column in _df:
        if column.null_count() != n_rows:
            kept_names.append(column.name)
    return _df[kept_names]

def get_settings(dir_name: str):
    """Return the flat settings dict of one run and check that it round-trips through ``Config``.

    The rows of the module-level ``hparams`` frame whose ``dir_name`` column equals
    ``dir_name`` are collected, columns that are entirely null are discarded, and only
    columns that are also int/float/str/bool settings fields (as reported by
    ``get_columns_with_type``) are kept. If several rows match, the first one is used.

    Args:
        dir_name: Value of the ``dir_name`` column identifying the run.

    Returns:
        dict mapping setting name to value. ``activation`` is replaced by its
        capitalised spelling (e.g. ``"relu"`` -> ``"ReLU"``) and a few keys are removed
        depending on ``model_settings``.

    Raises:
        ValueError: If no row of ``hparams`` has the requested ``dir_name``.
        KeyError: If the stored activation is not one of the known lower-case names.
        AssertionError: If a value differs from what ``Config.from_flat_dict`` reconstructs.
    """
    run_rows = hparams.filter(pl.col("dir_name") == dir_name).collect()
    if run_rows.height == 0:
        # Fail with a message that names the run instead of an opaque error further down.
        raise ValueError(f"No run with dir_name {dir_name!r} found in hparams")
    settings = drop_columns_that_are_all_null(run_rows)
    model_params = sorted(set([param for typ in [int, float, str, bool] for param in get_columns_with_type(typ, False)]) & set(settings.columns))
    _dict = settings.select(pl.col(model_params)).to_dicts()[0]
    activation_map = {
        "relu": "ReLU",
        "selu": "SELU",
        "silu": "SiLU",
        "elu": "ELU",
    }
    _dict["activation"] = activation_map[_dict["activation"]]
    # Remove keys per model type -- presumably the matching model settings class does
    # not accept them; TODO confirm against the settings module.
    if _dict["model_settings"] == "crnn":
        _dict.pop("hidden_units", None)
    if _dict["model_settings"] in ["crnn", "mamba", "mamba_fast"]:
        _dict.pop("use_relative_pos", None)
    if _dict["model_settings"].startswith("mamba"):
        _dict.pop("expansion_factor", None)
    config = Config.from_flat_dict(_dict)
    reversed_settings = {
        **asdict(config.training),
        **asdict(config.evaluation),
        **asdict(config.dataset),
        **asdict(config.model),
    }
    for key, item in _dict.items():
        if key in ["activation", "mapping", "splits", "test_sets"]:
            # These keys are only compared through their string representation.
            assert str(item) == str(reversed_settings[key]), f"Key {key} is mismatched: {item} != {reversed_settings[key]}"
            continue
        assert item == reversed_settings[key], f"Key {key} is mismatched: {item}({type(item)}) != {reversed_settings[key]}({type(reversed_settings[key])})"
    return _dict


In [None]:
# Resolve every run of interest to its validated flat settings dict, keyed by the run's label.
settings_of_interest = {name: get_settings(run) for name, run in runs_of_interest.items()}

In [None]:
# Values forced onto every final experiment, regardless of what the tuned run used.
final_overrides = {
    "pr_points": 1000,
    "min_save_score": 0.0,
    "min_test_score": 0.0,
    "dataset_version": "M",
    "epochs": 40,
    "test_sets": "('RBMA', 'MDB')",
    "eval_set": "A2MD",
    "scheduler": False,
    "time_shift": 0.015,
    "pad_value": 0.5,
    "beats": False,
    "causal": True,
    "early_stopping": None,
    "fft_size": 1024,
    "ema": True,
}

# Work on a deep copy so that settings_of_interest keeps the originally tuned values.
final_experiment_params = deepcopy(settings_of_interest)
for name, settings in final_experiment_params.items():
    settings.update(final_overrides)
    # num_workers is removed from the final settings (raises KeyError if it is absent).
    settings.pop("num_workers")

# Compare the final configurations key by key.
# A parameter counts as identical only if every configuration defines it with an equal
# value. Parameters that differ, or that exist in only some configurations (e.g.
# model-specific ones), are listed as different, so the result does not depend on
# which run is taken as the reference.
param_names = list(dict.fromkeys(
    key for settings in final_experiment_params.values() for key in settings
))
identical_params = {}
different_params = []
for key in param_names:
    values = [settings[key] for settings in final_experiment_params.values() if key in settings]
    defined_everywhere = len(values) == len(final_experiment_params)
    if defined_everywhere and all(value == values[0] for value in values[1:]):
        identical_params[key] = values[0]
    else:
        different_params.append(key)

print(identical_params)
print(different_params)

# For every differing parameter, list the value used by each configuration that defines it.
for param in different_params:
    print(f"-------{param}---------")
    for name, settings in final_experiment_params.items():
        if param in settings:
            print(name, settings[param])

print("\n\n\n")
final_experiment_params

# Test saved checkpoints

In [1]:
import torch
from glob import glob
import os
from model import EnsembleModel
from settings import Config, asdict, TrainingSettings, DatasetSettings, EvaluationSettings
from dataclasses import asdict as dataclass_asdict
from model.cnn import CNN
from model.cnnA import CNNAttention
from model.cnnM2 import CNNMambaFast
from model.CRNN import CRNN, CRNN_Vogl
from dataset.RBMA13 import RBMA13
from dataset.MDB_Drums import MDBDrums
from dataset.A2MD import A2MD
from dataset import get_dataloader
from main import evaluate

In [2]:
# Directory holding the checkpoints (*.pt) of the models that form the ensemble.
folder = "models/L 5 fold Mamba"
# glob returns matches in filesystem-dependent order; sort them so that "the first
# model" (models[0]) refers to the same checkpoint on every run and machine.
files = sorted(glob(os.path.join(folder, "*.pt")))

In [3]:
def get_model(file: str):
    """Rebuild a model from a checkpoint file and load its weights.

    The checkpoint is read with ``torch.load`` onto the CPU. It must contain the keys
    ``training_settings``, ``dataset_settings`` and ``model`` and, for every model type
    except ``"vogl"``, ``model_settings``.

    Args:
        file: Path of the checkpoint file.

    Returns:
        Tuple ``(model, dataset_settings)``: the model with its state dict loaded and
        the dataset settings stored in the checkpoint.

    Raises:
        ValueError: If the stored model type is not one of the supported names.
    """
    # NOTE(review): torch.load unpickles the file -- only use it on trusted checkpoints.
    checkpoint = torch.load(file, map_location="cpu")
    training_settings = TrainingSettings.from_flat_dict(checkpoint["training_settings"])
    dataset_settings = DatasetSettings.from_flat_dict(checkpoint["dataset_settings"])
    n_classes = dataset_settings.annotation_settings.n_classes
    n_mels = dataset_settings.audio_settings.n_mels

    # Model types that are constructed from their stored model settings.
    configurable_models = (
        ("cnn", CNN),
        ("cnn_attention", CNNAttention),
        ("mamba_fast", CNNMambaFast),
        ("crnn", CRNN),
    )
    model_kind = training_settings.model_settings
    model_cls = next((cls for kind, cls in configurable_models if model_kind == kind), None)

    if model_cls is not None:
        settings_cls = training_settings.get_model_settings_class()
        model_settings = settings_cls.from_flat_dict(checkpoint["model_settings"])
        model = model_cls(**dataclass_asdict(model_settings), n_classes=n_classes, n_mels=n_mels)
    elif model_kind == "vogl":
        # The Vogl CRNN takes no stored model settings and is always built causal.
        model = CRNN_Vogl(n_classes=n_classes, n_mels=n_mels, causal=True)
    else:
        raise ValueError(f"Unknown model type: {training_settings.model_settings}")

    model.load_state_dict(checkpoint["model"])
    return model, dataset_settings


In [4]:
# An empty glob result (e.g. a wrong folder) would otherwise surface as a bare IndexError below.
if not files:
    raise FileNotFoundError(f"No '*.pt' checkpoints found in {folder!r}")

# get_model returns (model, dataset_settings) pairs; keep them under their own name so
# that `models` only ever holds the models themselves.
loaded = [get_model(file) for file in files]
# The dataset settings of the first checkpoint are used for the evaluation datasets;
# presumably all checkpoints of the ensemble share them -- TODO confirm.
dataset_settings = loaded[0][1]
models = [model.to("cpu") for model, _ in loaded]
ensemble = EnsembleModel(models)

In [5]:
# Adjust the settings taken from the first checkpoint in place: clear k_folds
# (presumably disabling k-fold splitting -- TODO confirm) and select dataset version "L".
dataset_settings.k_folds = None
dataset_settings.dataset_version = "L"

# Build the three evaluation datasets, unsegmented and in non-training mode.
# NOTE(review): RBMA13 receives `splits=None` while MDBDrums and A2MD receive
# `split=None` -- confirm that the keyword names match each class's signature.
rbma = RBMA13("data/rbma_13", dataset_settings, segment=False, splits=None, is_train=False, use_dataloader=True)
mdb = MDBDrums("data/MDB Drums", dataset_settings, segment=False, split=None, is_train=False, use_dataloader=True)
a2md = A2MD("data/a2md_public", dataset_settings, segment=False, split=None, is_train=False, use_dataloader=True)

  drums = np.loadtxt(os.path.join(root, file), delimiter="\t", )
  drums = np.loadtxt(os.path.join(root, file), delimiter="\t", )


In [6]:
# Batch size used for the RBMA and MDB loaders.
# NOTE(review): the A2MD loader below hard-codes batch_size=3 instead of using this
# variable -- confirm that the difference is intentional.
batch_size = 1

# Evaluation loaders (is_train=False) for the three datasets, each with 16 workers.
rbma_loader = get_dataloader(rbma, batch_size=batch_size, num_workers=16, is_train=False)
mdb_loader = get_dataloader(mdb, batch_size=batch_size, num_workers=16, is_train=False)
a2md_loader = get_dataloader(a2md, batch_size=3, num_workers=16, is_train=False)

In [7]:
# Use the first CUDA device when available, otherwise fall back to the CPU, and move
# the ensemble there.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ensemble = ensemble.to(device)

In [8]:
# Evaluate the full ensemble on the A2MD loader with 1000 PR points and a detection
# tolerance of 0.05.
# NOTE(review): the leading 0 and the trailing None are positional arguments of
# main.evaluate whose meaning is not visible in this notebook -- confirm against its signature.
evaluate(
    0,
    ensemble,
    a2md_loader,
    torch.nn.BCEWithLogitsLoss(reduction="none"),
    device,
    EvaluationSettings(pr_points=1000, detect_tolerance=0.05, ),
    None,
)

Evaluation:   0%|          | 0/521 [00:00<?, ?mini-batch/s]W0805 21:46:50.472000 28448 torch/_inductor/utils.py:1137] [0/0] Not enough SMs to use max_autotune_gemm mode
Evaluation: 100%|██████████| 521/521 [01:57<00:00,  4.44mini-batch/s]
100%|██████████| 1000/1000 [01:21<00:00, 12.28it/s]s]]

(0.08453381948187347,
 0.8771509528160095,
 0.8763518333435059,
 tensor([0.0860, 0.1730, 0.1170]))

In [12]:
# Evaluate only the first loaded checkpoint (not the ensemble) on the RBMA13 loader.
# Scores recorded for this dataset (presumably the second value of the returned tuple):
#   vogl:  0.6689 first model, 0.6757 ensemble
#   mamba: 0.6326 first model, 0.6753 ensemble
evaluate(
    0,
    models[0].to(device),
    rbma_loader,
    torch.nn.BCEWithLogitsLoss(reduction="none"),
    device,
    EvaluationSettings(pr_points=1000, detect_tolerance=0.05, ),
    None,
    # Optional fixed thresholds, currently disabled:
    # thresholds=[0.0900, 0.1500, 0.1000]
)

Evaluation: 100%|██████████| 28/28 [00:22<00:00,  1.24mini-batch/s]
100%|██████████| 1000/1000 [00:00<00:00, 16036.02it/s]

(0.09868937724136881,
 0.6326425671577454,
 0.6455109715461731,
 tensor([0.1180, 0.1790, 0.0840]))

In [13]:
# Evaluate only the first loaded checkpoint (not the ensemble) on the MDB Drums loader.
# Scores recorded for this dataset (presumably the second value of the returned tuple):
#   vogl:  0.7329 ensemble, 0.7154 first model
#   mamba: 0.7230 ensemble, 0.6988 first model
evaluate(
    0,
    models[0].to(device),
    mdb_loader,
    torch.nn.BCEWithLogitsLoss(reduction="none"),
    device,
    EvaluationSettings(pr_points=1000, detect_tolerance=0.05, ),
    None,
    # Optional fixed thresholds, currently disabled:
    # thresholds=[0.0900, 0.1500, 0.1000]
)

Evaluation: 100%|██████████| 23/23 [00:09<00:00,  2.47mini-batch/s]
100%|██████████| 1000/1000 [00:00<00:00, 8624.38it/s]]

(0.10995919467962306,
 0.6988956332206726,
 0.7007308602333069,
 tensor([0.1090, 0.1150, 0.1440]))