In [1]:
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import pandas as pd
from tbparse import SummaryReader
from settings import TrainingSettings, DatasetSettings, EvaluationSettings, CNNSettings, CNNMambaSettings, CNNAttentionSettings, asdict

In [2]:
def get_columns_with_type(typ) -> list[str]:
    out = []
    classes = [TrainingSettings, DatasetSettings, EvaluationSettings]
    for cls in classes:
        settings = cls()
        dic = asdict(settings)
        for name, value in dic.items():
            if type(value) is typ:
                out.append(name)
    classes = [CNNSettings, CNNAttentionSettings, CNNMambaSettings]
    for cls in classes:
        settings = cls(3, 84)
        dic = asdict(settings)
        for name, value in dic.items():
            if type(value) is typ:
                out.append(name)
    return list(set(out))

In [3]:
integer_values = list(set(get_columns_with_type(int) + ["early_stopping"]))
boolean_values = get_columns_with_type(bool)
float_values = get_columns_with_type(float)
string_values = get_columns_with_type(str) + ["dir_name", "mapping"]

In [4]:
string_values

['splits',
 'segment_type',
 'dataset_version',
 'activation',
 'train_set',
 'pad_mode',
 'model_settings',
 'dir_name',
 'mapping']

In [5]:
logs = SummaryReader("runs/", pivot=True, extra_columns={'dir_name'})
logs_no_pivot = SummaryReader("runs/", pivot=False, extra_columns={'dir_name'})

In [11]:
# convert types
params = logs.hparams
params[params == "None"] = None
params[integer_values] = params[integer_values].astype(pd.Int64Dtype())
params[boolean_values] = params[boolean_values].astype(bool)
params[float_values] = params[float_values].astype(np.float64)
params[string_values] = params[string_values].astype(pd.StringDtype())
params = pd.DataFrame(params)

params = pl.from_pandas(params, nan_to_null=True)
scores = pl.from_pandas(logs.scalars, nan_to_null=True)
hparams = params.join(scores.select(pl.col("F-Score", "dir_name")), on='dir_name', how='inner')
hparams = hparams.select(pl.all().exclude("dir_name"), pl.col("dir_name").str.split("/").list.first())

# Convert plots
images = logs_no_pivot.images
size = np.array(images.iloc[0]["value"].shape)
images["value"] = images["value"].apply(lambda x: x.flatten())
plots = pl.from_pandas(images, nan_to_null=True)
num_rows = plots.shape[0]
plots = plots.select(pl.col("dir_name"), pl.col("step"), pl.col("tag"),
                     pl.col("value").reshape(tuple([num_rows, *size]), pl.Array).alias("value"))
plots = plots.pivot(values=["value"], columns=["tag"], index=["step", "dir_name"])

In [12]:
pl.Config.set_tbl_cols(100)
pl.Config.set_tbl_rows(100)


def get_model_settings(model_type: str) -> (pl.DataFrame, pl.DataFrame):
    model = hparams.filter(pl.col("model_settings").str.contains(model_type))
    non_null = model.select(pl.all().is_not_null().all()).row(0)
    model = model[:, non_null]
    different = model.select(pl.all().n_unique() > 1).row(0)
    diff = model[:, different].sort("F-Score", descending=True)
    iden = model.select(pl.all().n_unique() == 1).row(0)
    identical = model[:, iden].select(pl.all().exclude("dir_name", "F-Score")).limit(1)

    return diff, identical


def get_history(name: str) -> pl.DataFrame:
    data = scores.filter(pl.col("dir_name") == name)
    prs = plots.filter(pl.col("dir_name") == name)
    data = data.join(prs, on="step", how="inner")
    return data

In [13]:
unique, identical = get_model_settings("mamba")
print(identical)
unique

shape: (1, 27)
┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
│ act ┆ bea ┆ cau ┆ cen ┆ dat ┆ det ┆ ema ┆ fft ┆ hop ┆ ign ┆ mel ┆ mel ┆ min ┆ n_c ┆ n_m ┆ nor ┆ pad ┆ pad ┆ pad ┆ pea ┆ pea ┆ sam ┆ sch ┆ spl ┆ tes ┆ tra ┆ use │
│ iva ┆ ts  ┆ sal ┆ ter ┆ ase ┆ ect ┆ --- ┆ _si ┆ _si ┆ ore ┆ _ma ┆ _mi ┆ _sa ┆ las ┆ els ┆ mal ┆ _an ┆ _mo ┆ _va ┆ k_m ┆ k_m ┆ ple ┆ edu ┆ its ┆ t_b ┆ in_ ┆ _re │
│ tio ┆ --- ┆ --- ┆ --- ┆ t_v ┆ _to ┆ boo ┆ ze  ┆ ze  ┆ _be ┆ x   ┆ n   ┆ ve_ ┆ ses ┆ --- ┆ ize ┆ not ┆ de  ┆ lue ┆ ax_ ┆ ean ┆ _ra ┆ ler ┆ --- ┆ atc ┆ set ┆ lat │
│ n   ┆ boo ┆ boo ┆ boo ┆ ers ┆ ler ┆ l   ┆ --- ┆ --- ┆ ats ┆ --- ┆ --- ┆ sco ┆ --- ┆ i64 ┆ --- ┆ ati ┆ --- ┆ --- ┆ ran ┆ _ra ┆ te  ┆ --- ┆ str ┆ h_s ┆ --- ┆ ive │
│ --- ┆ l   ┆ l   ┆ l   ┆ ion ┆ anc ┆     ┆ i64 ┆ i64 ┆ --- ┆ f64 ┆ f64 ┆ re  ┆ i64 ┆     ┆ boo ┆ ons ┆ str ┆ f64 ┆ ge  ┆ nge ┆ --- ┆ boo ┆     ┆ ize ┆ str ┆ _po │
│

batch_size,dropout,epochs,flux,learning_rate,mapping,min_test_score,model_settings,n_layers,num_channels,num_workers,onset_cooldown,time_shift,weight_decay,F-Score,dir_name
i64,f64,i64,bool,f64,str,f64,str,i64,i64,i64,f64,f64,f64,f64,str
64,0.1,30,False,0.0001,"""Three class standard""",0.48,"""mamba_fast""",5,32,16,0.021,0.015,1e-05,0.500146,"""Jun16_11-35-01_marclie-desktop"""
4,0.1,20,False,0.0001,"""THREE_CLASS_STANDARD""",0.54,"""mamba""",3,16,64,0.021,0.015,1e-05,0.455403,"""Jun10_17-58-15_seppel-liemarce"""
16,0.1,20,True,0.0001,"""Three class standard""",0.48,"""mamba_fast""",5,32,64,0.021,0.02,1e-05,0.434372,"""Jun19_13-16-07_seppel-liemarce"""
16,0.1,20,True,0.0001,"""Three class standard""",0.48,"""mamba_fast""",5,32,64,0.021,0.02,1e-05,0.422229,"""Jun26_10-28-36_seppel-liemarce"""
16,0.1,20,False,0.0001,"""Three class standard""",0.48,"""mamba_fast""",3,16,64,0.021,0.02,1e-05,0.139301,"""Jun25_07-53-16_seppel-liemarce"""
32,0.1,30,False,0.0001,"""Three class standard""",0.48,"""mamba_fast""",16,16,16,0.021,0.015,1e-05,0.125544,"""Jun21_18-48-11_marclie-desktop"""
24,0.1,20,True,0.0001,"""Three class standard""",0.48,"""mamba_fast""",5,16,64,0.021,0.025,1e-05,0.116948,"""Jun23_17-34-02_seppel-liemarce"""
24,0.1,20,True,0.0001,"""Three class standard""",0.48,"""mamba_fast""",10,16,64,0.021,0.025,1e-05,0.052223,"""Jun22_16-25-42_seppel-liemarce"""
32,0.3,30,False,5e-05,"""THREE_CLASS_STANDARD""",0.54,"""mamba""",3,16,64,0.02,0.015,0.0,0.04605,"""Jun08_13-22-15_seppel-liemarce"""
16,0.1,20,True,0.0001,"""Three class standard""",0.48,"""mamba_fast""",5,32,64,0.021,0.02,1e-05,0.045436,"""Jun25_14-54-59_seppel-liemarce"""


In [15]:
unique, identical = get_model_settings("attention")
print(identical)
unique

shape: (1, 25)
┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
│ act ┆ bea ┆ cau ┆ cen ┆ det ┆ dro ┆ ema ┆ fft ┆ hop ┆ ign ┆ mel ┆ mel ┆ min ┆ mod ┆ n_c ┆ n_m ┆ nor ┆ pad ┆ pad ┆ pea ┆ pea ┆ sam ┆ sch ┆ spl ┆ tra │
│ iva ┆ ts  ┆ sal ┆ ter ┆ ect ┆ pou ┆ --- ┆ _si ┆ _si ┆ ore ┆ _ma ┆ _mi ┆ _sa ┆ el_ ┆ las ┆ els ┆ mal ┆ _mo ┆ _va ┆ k_m ┆ k_m ┆ ple ┆ edu ┆ its ┆ in_ │
│ tio ┆ --- ┆ --- ┆ --- ┆ _to ┆ t   ┆ boo ┆ ze  ┆ ze  ┆ _be ┆ x   ┆ n   ┆ ve_ ┆ set ┆ ses ┆ --- ┆ ize ┆ de  ┆ lue ┆ ax_ ┆ ean ┆ _ra ┆ ler ┆ --- ┆ set │
│ n   ┆ boo ┆ boo ┆ boo ┆ ler ┆ --- ┆ l   ┆ --- ┆ --- ┆ ats ┆ --- ┆ --- ┆ sco ┆ tin ┆ --- ┆ i64 ┆ --- ┆ --- ┆ --- ┆ ran ┆ _ra ┆ te  ┆ --- ┆ str ┆ --- │
│ --- ┆ l   ┆ l   ┆ l   ┆ anc ┆ f64 ┆     ┆ i64 ┆ i64 ┆ --- ┆ f64 ┆ f64 ┆ re  ┆ gs  ┆ i64 ┆     ┆ boo ┆ str ┆ f64 ┆ ge  ┆ nge ┆ --- ┆ boo ┆     ┆ str │
│ str ┆     ┆     ┆     ┆ e   ┆     ┆     ┆     ┆     ┆ boo ┆     ┆     ┆

batch_size,context_size,dataset_version,epochs,flux,learning_rate,mapping,min_test_score,num_attention_blocks,num_channels,num_heads,num_workers,onset_cooldown,pad_annotations,test_batch_size,time_shift,use_relative_pos,F-Score,dir_name
i64,i64,str,i64,bool,f64,str,f64,i64,i64,i64,i64,f64,bool,i64,f64,bool,f64,str
16,200,"""M""",20,True,0.0001,"""Three class standard""",0.54,5,24,8,16,0.021,True,1,0.015,False,0.493905,"""Jun15_14-16-35_marclie-desktop"""
16,200,"""L""",20,True,0.0001,"""THREE_CLASS_STANDARD""",0.54,5,24,8,64,0.021,True,1,0.015,False,0.492189,"""Jun13_10-57-29_seppel-liemarce"""
16,200,"""L""",20,True,0.0001,"""THREE_CLASS_STANDARD""",0.54,5,32,8,64,0.021,True,1,0.015,False,0.482672,"""Jun14_09-33-24_seppel-liemarce"""
4,200,"""L""",20,False,0.0001,"""THREE_CLASS_STANDARD""",0.54,5,16,8,64,0.021,True,1,0.015,False,0.458652,"""Jun11_13-54-21_seppel-liemarce"""
16,200,"""M""",20,True,0.0001,"""Three class standard""",0.48,4,24,8,16,0.021,True,1,0.015,False,0.452022,"""Jun15_18-05-40_marclie-desktop"""
16,200,"""L""",20,True,0.0001,"""Three class standard""",0.48,5,24,8,64,0.021,True,1,0.02,False,0.393078,"""Jun17_15-04-11_seppel-liemarce"""
4,200,"""S""",20,False,0.0001,"""THREE_CLASS_STANDARD""",0.54,5,32,8,64,0.021,True,1,0.015,False,0.248427,"""Jun11_06-49-13_seppel-liemarce"""
16,200,"""M""",20,True,0.0001,"""Three class standard""",0.48,5,24,8,16,0.021,True,1,0.015,False,0.248362,"""Jun15_20-26-24_marclie-desktop"""
8,200,"""L""",20,True,0.0001,"""THREE_CLASS_STANDARD""",0.54,5,24,8,64,0.021,True,1,0.015,False,0.233853,"""Jun12_07-38-49_seppel-liemarce"""
64,50,"""M""",30,False,1e-05,"""THREE_CLASS_STANDARD""",0.54,2,16,4,16,0.02,False,4,0.015,True,0.230188,"""Jun06_19-48-30_marclie-desktop"""


In [16]:
unique, identical = get_model_settings("cnn")
print(identical)
unique

shape: (1, 24)
┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐
│ act ┆ bea ┆ cau ┆ cen ┆ det ┆ dro ┆ ema ┆ fft ┆ hop ┆ ign ┆ mel ┆ mel ┆ min ┆ n_c ┆ n_m ┆ nor ┆ pad ┆ pad ┆ pea ┆ pea ┆ sam ┆ sch ┆ spl ┆ tra │
│ iva ┆ ts  ┆ sal ┆ ter ┆ ect ┆ pou ┆ --- ┆ _si ┆ _si ┆ ore ┆ _ma ┆ _mi ┆ _sa ┆ las ┆ els ┆ mal ┆ _mo ┆ _va ┆ k_m ┆ k_m ┆ ple ┆ edu ┆ its ┆ in_ │
│ tio ┆ --- ┆ --- ┆ --- ┆ _to ┆ t   ┆ boo ┆ ze  ┆ ze  ┆ _be ┆ x   ┆ n   ┆ ve_ ┆ ses ┆ --- ┆ ize ┆ de  ┆ lue ┆ ax_ ┆ ean ┆ _ra ┆ ler ┆ --- ┆ set │
│ n   ┆ boo ┆ boo ┆ boo ┆ ler ┆ --- ┆ l   ┆ --- ┆ --- ┆ ats ┆ --- ┆ --- ┆ sco ┆ --- ┆ i64 ┆ --- ┆ --- ┆ --- ┆ ran ┆ _ra ┆ te  ┆ --- ┆ str ┆ --- │
│ --- ┆ l   ┆ l   ┆ l   ┆ anc ┆ f64 ┆     ┆ i64 ┆ i64 ┆ --- ┆ f64 ┆ f64 ┆ re  ┆ i64 ┆     ┆ boo ┆ str ┆ f64 ┆ ge  ┆ nge ┆ --- ┆ boo ┆     ┆ str │
│ str ┆     ┆     ┆     ┆ e   ┆     ┆     ┆     ┆     ┆ boo ┆     ┆     ┆ --- ┆     ┆     ┆ l   ┆     ┆     ┆

batch_size,dataset_version,epochs,flux,learning_rate,mapping,min_test_score,model_settings,num_channels,num_workers,onset_cooldown,pad_annotations,test_batch_size,time_shift,use_relative_pos,F-Score,dir_name
i64,str,i64,bool,f64,str,f64,str,i64,i64,f64,bool,i64,f64,bool,f64,str
16,"""M""",20,True,0.0001,"""Three class standard""",0.54,"""cnn_attention""",24,16,0.021,True,1,0.015,False,0.493905,"""Jun15_14-16-35_marclie-desktop"""
16,"""L""",20,True,0.0001,"""THREE_CLASS_STANDARD""",0.54,"""cnn_attention""",24,64,0.021,True,1,0.015,False,0.492189,"""Jun13_10-57-29_seppel-liemarce"""
512,"""L""",30,True,0.0001,"""THREE_CLASS_STANDARD""",0.54,"""cnn""",16,64,0.02,False,10,0.015,True,0.487235,"""Jun04_19-26-40_seppel-liemarce"""
16,"""L""",20,True,0.0001,"""THREE_CLASS_STANDARD""",0.54,"""cnn_attention""",32,64,0.021,True,1,0.015,False,0.482672,"""Jun14_09-33-24_seppel-liemarce"""
512,"""M""",30,True,0.0001,"""Three class standard""",0.48,"""cnn""",16,16,0.021,True,1,0.015,True,0.46571,"""Jun16_19-27-05_marclie-desktop"""
512,"""L""",20,True,0.0001,"""THREE_CLASS_STANDARD""",0.54,"""cnn""",16,64,0.02,True,1,0.015,True,0.462143,"""Jun10_09-14-54_seppel-liemarce"""
4,"""L""",20,False,0.0001,"""THREE_CLASS_STANDARD""",0.54,"""cnn_attention""",16,64,0.021,True,1,0.015,False,0.458652,"""Jun11_13-54-21_seppel-liemarce"""
16,"""M""",20,True,0.0001,"""Three class standard""",0.48,"""cnn_attention""",24,16,0.021,True,1,0.015,False,0.452022,"""Jun15_18-05-40_marclie-desktop"""
16,"""L""",20,True,0.0001,"""Three class standard""",0.48,"""cnn_attention""",24,64,0.021,True,1,0.02,False,0.393078,"""Jun17_15-04-11_seppel-liemarce"""
512,"""L""",20,True,0.0001,"""THREE_CLASS_STANDARD""",0.54,"""cnn""",16,64,0.02,True,1,0.025,True,0.267671,"""Jun09_16-46-27_seppel-liemarce"""


In [None]:
best = unique.select("dir_name").row(0)[0]
history = get_history(best)
pr_curve = history.select(pl.col("Validation/PR-Curve/")).row(-1)[0]

In [None]:
plt.imshow(pr_curve)
plt.axis("off")

[String,
 Int64,
 Boolean,
 Boolean,
 Boolean,
 Int64,
 Int64,
 Int64,
 Int64,
 String,
 Float64,
 Int64,
 Float64,
 Int64,
 Boolean,
 Int64,
 Int64,
 Int64,
 Int64,
 Boolean,
 Float64,
 Float64,
 Int64,
 Boolean,
 Float64,
 Float64,
 Float64,
 Float64,
 Float64,
 String,
 Float64,
 Float64,
 Float64,
 Float64,
 String,
 Int64,
 Int64,
 Int64,
 Boolean,
 Int64,
 Int64,
 Int64,
 Int64,
 Int64,
 Float64,
 Boolean,
 String,
 Float64,
 Int64,
 Int64,
 Float64,
 Int64,
 Int64,
 Boolean,
 String,
 String,
 Int64,
 Float64,
 String,
 Boolean,
 Float64,
 String]