In [1]:
# Reload edited modules (e.g. the local `conf` and `dafm` packages imported below) automatically
# before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

import duckdb
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
import sqlalchemy as sa
from great_tables import GT

from conf import conf
from dafm import plots

# Plot styling. `font_scale` is also reused below to size the heatmap figures.
font_scale = 7
sns.set_theme(style='whitegrid', font_scale=font_scale, palette=sns.color_palette('Set2'))

In [3]:
# Attach the run database to DuckDB's default connection and make it the default catalog, so later
# queries can name its tables/views (`Conf`, `paper_*`, `sweep_*`, ...) without qualification.
# NOTE(review): presumably a SQLite file, judging by the extension.
duckdb.sql("""
attach '../runs.sqlite';
use runs;
""")

In [4]:
# SQLAlchemy engine and ORM session from the project config, with a transaction opened up front.
# NOTE(review): the queries in the cells shown here go through DuckDB, not this session.
engine = conf.get_engine()
session = conf.sa.orm.Session(engine)
session.begin();  # trailing `;` suppresses the SessionTransaction repr in the cell output

<sqlalchemy.orm.session.SessionTransaction at 0x75cc2259a900>

# Run queries

### Datasets

In [5]:
# Comma-joined dataset columns shared by the dataset queries below. The trailing comma is kept
# (DuckDB accepts a trailing comma before `from`).
dataset_cols = ','.join([
    'id',
    'dataset_name',
    'dataset_name_latex_short',
    'time_step_count',
    'time_step_count_drop_first',
    'observe_every_n_time_steps',
]) + ','

##### Generative method comparison tuning

In [6]:
# Generative-method comparison: systems whose runs are tuned (NS 64x64 is currently commented out).
# `topk_hyperparameter_filename` names the CSV written in the Top-k section.
topk_hyperparameter_filename = 'topk_hyperparameters'
dataset_rows = duckdb.sql(f"""
    select {dataset_cols} from paper_lorenz96
    union
    select {dataset_cols} from paper_kuramoto_sivashinsky
--  union
--  select {dataset_cols} from paper_navier_stokes_dim_64
    union
    select {dataset_cols} from paper_navier_stokes_dim_256
""")
# Sanity check: exactly one dataset row per selected system.
dataset_multiple = 3
assert len(dataset_rows) == dataset_multiple
# Sampling step counts T to include. Later SQL refers to this relation by its Python name.
sampling_time_step_counts = duckdb.sql("select * from sweep_sampling_time_step_count")

##### Classical comparison tuning

In [None]:
# Classical-comparison alternative to the cell above (run one or the other; this one overrides
# its variables): KS with `state_dimension = 512` and NS with `state_dimension = 3*64*64`,
# and generative models restricted to T = 100.
topk_hyperparameter_filename = 'topk_hyperparameters_classical_comparison'
dataset_rows = duckdb.sql(f"""
    select {dataset_cols} from paper_kuramoto_sivashinsky_classical_comparison where state_dimension = 512
    union
    select {dataset_cols} from paper_navier_stokes_classical_comparison where state_dimension = 3*64*64
""")
# Sanity check: exactly one dataset row per selected system.
dataset_multiple = 2
assert len(dataset_rows) == dataset_multiple
sampling_time_step_counts = duckdb.sql("select 100")

### Models

In [7]:
# Comma-joined model columns shared by the model queries below. The trailing comma lets each query
# append one more select item (e.g. `sampling_time_step_count`) directly after `{model_cols}`.
model_cols = ','.join([
    'id',
    'model_name',
    'hyperparameter1',
    'hyperparameter1_name',
    'hyperparameter1_name_latex',
    'hyperparameter2',
    'hyperparameter2_name',
    'hyperparameter2_name_latex',
]) + ','

##### Generative method comparison tuning

In [8]:
# Generative models restricted to the swept hyperparameter values and the selected T values.
model_rows = duckdb.sql(rf"""
    select * from (
        select {model_cols} sampling_time_step_count
        from ensf
        where true
        and epsilon_alpha in (select * from sweep_ensf_epsilon_alpha)
        and epsilon_beta in (select * from sweep_ensf_epsilon_beta)
        union
        select {model_cols} sampling_time_step_count
        from enff_ot
        where true
        and sigma_min in (select * from sweep_enff_sigma_min)
        and lambda in (select * from sweep_enff_lambda)
        union
        select {model_cols} sampling_time_step_count
        from enff_f2p
        where true
        and sigma_min in (select * from sweep_enff_sigma_min)
        and lambda in (select * from sweep_enff_lambda)
    )
    where sampling_time_step_count in (select * from sampling_time_step_counts)
""")
# Expected number of model configurations: each model contributes one per point of its grid
# (ensf: epsilon_alpha x epsilon_beta; enff_ot and enff_f2p: sigma_min x lambda).
model_multiple = sum(
    duckdb.sql(grid_size_query).fetchone()[0]
    for grid_size_query in (
        "select count(*) from sweep_ensf_epsilon_alpha cross join sweep_ensf_epsilon_beta",  # ensf
        "select count(*) from sweep_enff_sigma_min cross join sweep_enff_lambda",  # enff_ot
        "select count(*) from sweep_enff_sigma_min cross join sweep_enff_lambda",  # enff_f2p
    )
)
sampling_time_step_count_multiple = duckdb.sql('select count(*) from sampling_time_step_counts').fetchone()[0]
# Expected run count: datasets x training seeds x model configurations x T values.
multiple = (
    dataset_multiple
    * duckdb.sql("select count(*) from rng_seed_train").fetchone()[0]
    * model_multiple
    * sampling_time_step_count_multiple
)

##### Classical comparison tuning

In [None]:
# Classical baselines. They have no sampling steps, hence the null `sampling_time_step_count`.
model_classical_rows = duckdb.sql(rf"""
    select {model_cols} null as sampling_time_step_count
    from bpf
    where true
    and inflation in (select * from sweep_classical_inflation)
    union
    select {model_cols} null as sampling_time_step_count
    from enkf_po
    where true
    and inflation in (select * from sweep_classical_inflation)
    and localization in (select * from sweep_classical_localization)
    union
    select {model_cols} null as sampling_time_step_count
    from ienkf_po
    where true
    and inflation in (select * from sweep_classical_inflation)
    union
    select {model_cols} null as sampling_time_step_count
    from esrf
    where true
    and inflation in (select * from sweep_classical_inflation)
    and localization in (select * from sweep_classical_localization)
    union
    select {model_cols} null as sampling_time_step_count
    from letkf
    where true
    and inflation in (select * from sweep_classical_inflation)
    and localization in (select * from sweep_classical_localization)
""")
# One configuration per grid point, in the same order as the unions above.
model_classical_multiple = sum(
    duckdb.sql(grid_size_query).fetchone()[0]
    for grid_size_query in (
        "select count(*) from sweep_classical_inflation",  # bpf
        "select count(*) from sweep_classical_inflation cross join sweep_classical_localization",  # enkf_po
        "select count(*) from sweep_classical_inflation",  # ienkf_po
        "select count(*) from sweep_classical_inflation cross join sweep_classical_localization",  # esrf
        "select count(*) from sweep_classical_inflation cross join sweep_classical_localization",  # letkf
    )
)
sampling_time_step_count_classical_multiple = 1
# NOTE(review): relies on `model_multiple` / `sampling_time_step_count_multiple` from the generative
# model cell. The run-join cell below only includes classical runs once its `model_classical_rows`
# union is uncommented; otherwise its count assertion fails.
multiple = (
    dataset_multiple
    * duckdb.sql("select count(*) from rng_seed_train").fetchone()[0]
    * (
        model_multiple * sampling_time_step_count_multiple
        + model_classical_multiple * sampling_time_step_count_classical_multiple
    )
)

### General

In [9]:
# One row per run (`Conf` entry) with its seed, dataset row and model row, restricted to training
# seeds. Classical models are only included if the commented-out union is enabled.
rows = duckdb.sql("""
    select alt_id, rng_seed, dataset_rows.*, all_model_rows.*
    from Conf
    join dataset_rows on Conf.Dataset = dataset_rows.id
    join (
        select * from model_rows
      --union
      --select * from model_classical_rows
    ) as all_model_rows on Conf.Model = all_model_rows.id
    where true
    and rng_seed in (select * from rng_seed_train)
""")
# The number of runs must equal the full cross product computed in the model cells.
assert len(rows) == multiple, f'{len(rows) = } != {multiple}'

In [10]:
# Run counts per dataset / model / T, to check that each sweep is complete. The commented-out
# `having` terms are optional expected-count filters.
duckdb.sql("""
    select
        dataset_name,
        model_name,
        sampling_time_step_count,
        count(*) as num_rows,
    from rows
    group by
        dataset_name,
        model_name,
        sampling_time_step_count,
    having (
        true
      --model_name = 'EnSF' and num_rows = 10 * 8
      --or
      --model_name != 'EnSF' and num_rows <= 13 * 5
    )
    order by dataset_name, model_name desc, sampling_time_step_count
""")

┌─────────────────────┬────────────┬──────────────────────────┬──────────┐
│    dataset_name     │ model_name │ sampling_time_step_count │ num_rows │
│       varchar       │  varchar   │          int64           │  int64   │
├─────────────────────┼────────────┼──────────────────────────┼──────────┤
│ KuramotoSivashinsky │ EnSF       │                        5 │       80 │
│ KuramotoSivashinsky │ EnSF       │                       10 │       80 │
│ KuramotoSivashinsky │ EnSF       │                       20 │       80 │
│ KuramotoSivashinsky │ EnSF       │                       50 │       80 │
│ KuramotoSivashinsky │ EnSF       │                      100 │       80 │
│ KuramotoSivashinsky │ EnFF-OT    │                        5 │       65 │
│ KuramotoSivashinsky │ EnFF-OT    │                       10 │       65 │
│ KuramotoSivashinsky │ EnFF-OT    │                       20 │       65 │
│ KuramotoSivashinsky │ EnFF-OT    │                       50 │       65 │
│ KuramotoSivashinsky │ E

In [11]:
# Resolve each run's metrics file via the project helper. The result is presumably a table with
# `path` and boolean `exists` columns (TODO confirm). Store the list of existing paths in a DuckDB
# session variable, which `read_csv` in the next cell reads.
logged_metrics_file_paths = plots.get_logged_metrics_file_paths(rows)
duckdb.sql("""
set variable dataset_metrics_filepaths = (
    select list(path) from logged_metrics_file_paths where exists
)
""")

In [12]:
# Evaluate each run on its last `observation_steps_back` analysis (observation) steps.
observation_steps_back = 50
# Per-step metrics of every run, joined with the run's dataset/model metadata. The run id
# (`alt_id`) is the name of the directory containing the metrics file. The `where` clause keeps
# only analysis steps (per the modulo test) within the final `observation_steps_back` windows.
logged_metrics = duckdb.sql(f"""
    select rows.*, logs.*,
    from (
        select split(filename, '/')[-2] as alt_id, step, time_s, crps, rmse,
        from read_csv(getvariable(dataset_metrics_filepaths), filename=true, union_by_name=true)
    ) as logs
    join rows on rows.alt_id = logs.alt_id
    where true
    and (logs.step - time_step_count_drop_first - 1) % observe_every_n_time_steps == 0 -- include only analysis time steps
    and step > time_step_count - observe_every_n_time_steps * {observation_steps_back}
""")
logged_metrics.show(max_width=125)

┌──────────┬────────────┬───────┬─────────────────────┬───┬──────────────────────┬────────────────────┬─────────────────────┐
│  alt_id  │  rng_seed  │  id   │    dataset_name     │ … │        time_s        │        crps        │        rmse         │
│ varchar  │   int64    │ int64 │       varchar       │   │        double        │       double       │       double        │
├──────────┼────────────┼───────┼─────────────────────┼───┼──────────────────────┼────────────────────┼─────────────────────┤
│ 5gip7gv0 │ 2376999025 │    50 │ KuramotoSivashinsky │ … │ 0.030874179999997864 │  7.089921293743939 │ 0.23455222081430618 │
│ 5gip7gv0 │ 2376999025 │    50 │ KuramotoSivashinsky │ … │ 0.029918566999995733 │  7.177826929114071 │ 0.23664704221334718 │
│ 5gip7gv0 │ 2376999025 │    50 │ KuramotoSivashinsky │ … │  0.03107601700000373 │ 7.3595981917605515 │  0.2423642313643367 │
│ 5gip7gv0 │ 2376999025 │    50 │ KuramotoSivashinsky │ … │  0.03136959099999359 │   6.93293557025534 │ 0.229011209084

In [13]:
# Sanity check: which datasets have any metrics in the evaluation window.
duckdb.sql("""
select distinct dataset_name from logged_metrics
""")

┌─────────────────────┐
│    dataset_name     │
│       varchar       │
├─────────────────────┤
│ Lorenz96Bao2024EnSF │
│ NavierStokesDim256  │
│ KuramotoSivashinsky │
└─────────────────────┘

In [14]:
# Per dataset, the largest number of evaluated steps logged by any single run. Runs that logged
# fewer steps are treated as unfinished in the next cell.
required_observation_step_count = duckdb.sql(f"""
    select
        dataset_name,
        max(observation_step_count) as required_observation_step_count,
    from (
        select
            dataset_name,
            count(*) as observation_step_count,
        from logged_metrics
        group by alt_id, dataset_name
    )
    group by dataset_name
""")
required_observation_step_count

┌─────────────────────┬─────────────────────────────────┐
│    dataset_name     │ required_observation_step_count │
│       varchar       │              int64              │
├─────────────────────┼─────────────────────────────────┤
│ KuramotoSivashinsky │                              50 │
│ NavierStokesDim256  │                              50 │
│ Lorenz96Bao2024EnSF │                              50 │
└─────────────────────┴─────────────────────────────────┘

In [15]:
# Group-by columns, with a trailing comma so more select items can follow them.
failed_before_finish_cols = """
    alt_id,
    dataset_name,
    model_name,
"""
# Runs that logged fewer evaluated steps than the per-dataset maximum, i.e. presumably stopped early.
failed_before_finish = duckdb.sql(f"""
    select
        observation_steps_back.*,
    from (
        select
            {failed_before_finish_cols}
            count(*) as observation_step_count,
        from logged_metrics
        group by {failed_before_finish_cols}
    ) as observation_steps_back
    join required_observation_step_count
    on observation_steps_back.dataset_name = required_observation_step_count.dataset_name
    and observation_steps_back.observation_step_count < required_observation_step_count.required_observation_step_count
    order by observation_steps_back.dataset_name, model_name desc, observation_step_count
""")
failed_before_finish

┌──────────┬─────────────────────┬────────────┬────────────────────────┐
│  alt_id  │    dataset_name     │ model_name │ observation_step_count │
│ varchar  │       varchar       │  varchar   │         int64          │
├──────────┼─────────────────────┼────────────┼────────────────────────┤
│ r8u0bwft │ Lorenz96Bao2024EnSF │ EnSF       │                     10 │
│ nt92emwo │ Lorenz96Bao2024EnSF │ EnSF       │                     30 │
│ hjru879w │ Lorenz96Bao2024EnSF │ EnFF-OT    │                     10 │
│ zc0fdoha │ NavierStokesDim256  │ EnSF       │                      2 │
│ 18n5tjx1 │ NavierStokesDim256  │ EnSF       │                      4 │
│ zr9rskpr │ NavierStokesDim256  │ EnSF       │                      7 │
│ 0bp4bsfy │ NavierStokesDim256  │ EnSF       │                      7 │
│ jv51gocs │ NavierStokesDim256  │ EnSF       │                     11 │
│ ouxuxk85 │ NavierStokesDim256  │ EnSF       │                     17 │
│ 37ts0aqx │ NavierStokesDim256  │ EnSF       │    

In [16]:
# Drop unfinished runs and materialise the result as a polars DataFrame.
# NOTE(review): this rebinds `logged_metrics`. Re-running earlier cells that read it (other than
# the one that builds it) will operate on the already-filtered frame.
logged_metrics = duckdb.sql("""
    select
        *
    from logged_metrics
    where alt_id not in (select alt_id from failed_before_finish)
""").pl()

FloatProgress(value=0.0, layout=Layout(width='auto'), style=ProgressStyle(bar_color='black'))

In [17]:
# Runs with no remaining metrics, counted per dataset / model / T. These either logged nothing in
# the evaluation window or were dropped above as unfinished.
duckdb.sql("""
    select
        dataset_name,
        model_name,
        sampling_time_step_count,
        count(*) as 'Num. runs failed',
    from rows
    where alt_id not in (select alt_id from logged_metrics)
    group by dataset_name, model_name, sampling_time_step_count
    order by dataset_name, model_name, sampling_time_step_count
""")

┌─────────────────────┬────────────┬──────────────────────────┬──────────────────┐
│    dataset_name     │ model_name │ sampling_time_step_count │ Num. runs failed │
│       varchar       │  varchar   │          int64           │      int64       │
├─────────────────────┼────────────┼──────────────────────────┼──────────────────┤
│ KuramotoSivashinsky │ EnFF-F2P   │                        5 │               30 │
│ KuramotoSivashinsky │ EnFF-F2P   │                       10 │               10 │
│ KuramotoSivashinsky │ EnFF-OT    │                        5 │               28 │
│ KuramotoSivashinsky │ EnFF-OT    │                       10 │                2 │
│ KuramotoSivashinsky │ EnSF       │                        5 │               75 │
│ Lorenz96Bao2024EnSF │ EnFF-F2P   │                        5 │               26 │
│ Lorenz96Bao2024EnSF │ EnFF-OT    │                        5 │               20 │
│ Lorenz96Bao2024EnSF │ EnSF       │                        5 │               61 │
│ Na

### Top-k hyperparameters

In [18]:
# Key of one hyperparameter configuration, with a trailing comma so aggregates can follow it.
group_by = """
    dataset_name,
    dataset_name_latex_short,
    model_name,
    sampling_time_step_count,
    hyperparameter1,
    hyperparameter1_name,
    hyperparameter1_name_latex,
    hyperparameter2,
    hyperparameter2_name,
    hyperparameter2_name_latex,
"""
# Mean `time_s`, RMSE and CRPS of each configuration over all its seeds and evaluated steps.
logged_metrics_means = duckdb.sql(f"""
    select
        {group_by}
        mean(time_s) as time_s_mean,
        mean(rmse) as rmse,
        mean(crps) as crps,
    from logged_metrics
    group by
        {group_by}
""")
logged_metrics_means.show(max_width=125)

┌─────────────────────┬──────────────────────┬───┬──────────────────────┬──────────────────────┬────────────────────┐
│    dataset_name     │ dataset_name_latex…  │ … │     time_s_mean      │         rmse         │        crps        │
│       varchar       │       varchar        │   │        double        │        double        │       double       │
├─────────────────────┼──────────────────────┼───┼──────────────────────┼──────────────────────┼────────────────────┤
│ Lorenz96Bao2024EnSF │ Lorenz '96           │ … │   0.6574258482599993 │   1.1205621147155762 │ 1120.4936181640626 │
│ Lorenz96Bao2024EnSF │ Lorenz '96           │ … │   2.5859891654199973 │   0.8354133546352387 │  607.6280810546875 │
│ KuramotoSivashinsky │ KS                   │ … │  0.09884062616000051 │   1.8084643314456104 │  57.86870681160109 │
│ NavierStokesDim256  │ NS ($256 \times 25…  │ … │  0.46809859621999295 │   0.1295273420214653 │ 57.418344039916995 │
│ KuramotoSivashinsky │ KS                   │ … │  0.24

In [19]:
# Rank hyperparameter configurations by mean RMSE within each (dataset, model, T) group.
# `k` is the dense RMSE rank, so ties share a rank. `tie_breaker` orders configurations within
# one rank by CRPS and then the hyperparameter values, so (k = 1, tie_breaker = 1) is a
# deterministic pick. Ordering by `k` inside a partition on `k` left it arbitrary.
ranked_by_rmse = duckdb.sql("""
    select
        *,
        row_number() over (
            partition by dataset_name, model_name, sampling_time_step_count, k
            order by crps, hyperparameter1, hyperparameter2
        ) as tie_breaker
    from (
        select
            dataset_name,
            dataset_name_latex_short,
            model_name,
            sampling_time_step_count,
            hyperparameter1,
            hyperparameter1_name,
            hyperparameter1_name_latex,
            hyperparameter2,
            hyperparameter2_name,
            hyperparameter2_name_latex,
            rmse,
            crps,
            time_s_mean,
            dense_rank() over (
                partition by dataset_name, model_name, sampling_time_step_count
                order by rmse
            ) as k,
        from logged_metrics_means
    )
    order by dataset_name, model_name, sampling_time_step_count, k, tie_breaker
""")
ranked_by_rmse.show(max_width=125)

┌─────────────────────┬──────────────────────┬────────────┬───┬──────────────────────┬───────┬─────────────┐
│    dataset_name     │ dataset_name_latex…  │ model_name │ … │     time_s_mean      │   k   │ tie_breaker │
│       varchar       │       varchar        │  varchar   │   │        double        │ int64 │    int64    │
├─────────────────────┼──────────────────────┼────────────┼───┼──────────────────────┼───────┼─────────────┤
│ KuramotoSivashinsky │ KS                   │ EnFF-F2P   │ … │ 0.008697827060000378 │     1 │           1 │
│ KuramotoSivashinsky │ KS                   │ EnFF-F2P   │ … │ 0.007716459259999766 │     2 │           1 │
│ KuramotoSivashinsky │ KS                   │ EnFF-F2P   │ … │ 0.007475149399999168 │     3 │           1 │
│ KuramotoSivashinsky │ KS                   │ EnFF-F2P   │ … │ 0.008469668259999707 │     4 │           1 │
│ KuramotoSivashinsky │ KS                   │ EnFF-F2P   │ … │ 0.008049477119999721 │     5 │           1 │
│ KuramotoSivashins

In [20]:
# Export the full ranking (without the LaTeX label columns) to
# ../sweeps/<topk_hyperparameter_filename>.csv, relative to the working directory.
duckdb.sql(f"""
    copy (
        select
            dataset_name,
            model_name,
            sampling_time_step_count,
            hyperparameter1,
            hyperparameter1_name,
            hyperparameter2,
            hyperparameter2_name,
            rmse,
            crps,
            time_s_mean,
            k,
            tie_breaker,
        from ranked_by_rmse
    ) to '../sweeps/{topk_hyperparameter_filename}.csv'
""")

In [21]:
# Best configuration per system and model (rank 1, first tie-breaker). Reported at T = 5, except
# NavierStokesDim256, which uses T = 10. Columns are renamed for the LaTeX table below.
top1_hyperparameters = duckdb.sql("""
    select
        dataset_name,
        dataset_name_latex_short as System,
        model_name,
        sampling_time_step_count as '$T$',
        hyperparameter1,
        hyperparameter1_name_latex,
        hyperparameter2,
        hyperparameter2_name_latex,
    from ranked_by_rmse
    where true
    and k = 1 and tie_breaker = 1
    and (
      --true
        dataset_name != 'NavierStokesDim256' and sampling_time_step_count = 5
        or
        dataset_name = 'NavierStokesDim256' and sampling_time_step_count = 10
    )
    order by sampling_time_step_count, dataset_name, model_name desc
""").pl()
top1_hyperparameters

dataset_name,System,model_name,$T$,hyperparameter1,hyperparameter1_name_latex,hyperparameter2,hyperparameter2_name_latex
str,str,str,i64,f64,str,f64,str
"""KuramotoSivashinsky""","""KS""","""EnSF""",5,1.0,"""$\epsilon_{\alpha}$""",0.275,"""$\epsilon_{\beta}$"""
"""KuramotoSivashinsky""","""KS""","""EnFF-OT""",5,1e-05,"""$\sigma_{\min}$""",0.05,"""$\lambda$"""
"""KuramotoSivashinsky""","""KS""","""EnFF-F2P""",5,0.001,"""$\sigma_{\min}$""",0.005,"""$\lambda$"""
"""Lorenz96Bao2024EnSF""","""Lorenz '96""","""EnSF""",5,1.0,"""$\epsilon_{\alpha}$""",0.275,"""$\epsilon_{\beta}$"""
"""Lorenz96Bao2024EnSF""","""Lorenz '96""","""EnFF-OT""",5,0.1,"""$\sigma_{\min}$""",0.05,"""$\lambda$"""
"""Lorenz96Bao2024EnSF""","""Lorenz '96""","""EnFF-F2P""",5,0.0001,"""$\sigma_{\min}$""",0.05,"""$\lambda$"""
"""NavierStokesDim256""","""NS ($256 \times 256$)""","""EnSF""",10,1.0,"""$\epsilon_{\alpha}$""",0.275,"""$\epsilon_{\beta}$"""
"""NavierStokesDim256""","""NS ($256 \times 256$)""","""EnFF-OT""",10,0.01,"""$\sigma_{\min}$""",0.005,"""$\lambda$"""
"""NavierStokesDim256""","""NS ($256 \times 256$)""","""EnFF-F2P""",10,0.001,"""$\sigma_{\min}$""",0.001,"""$\lambda$"""


In [22]:
# Wide table: one row per system and, for each model, one column per hyperparameter.
top1_hyperparameters_pivot = top1_hyperparameters.pivot(
    on=['model_name', 'hyperparameter1_name_latex', 'hyperparameter2_name_latex'],
    index=['System'],
    values=['hyperparameter1', 'hyperparameter2'],
    separator='|',
)
gt = GT(top1_hyperparameters_pivot)
gt = gt.tab_spanner(label=' ', columns=['System'])
# `maintain_order=True` makes the spanner order deterministic; polars' `unique()` is unordered by default.
for model_name, hyperparameter1_name_latex, hyperparameter2_name_latex in (
    top1_hyperparameters[['model_name', 'hyperparameter1_name_latex', 'hyperparameter2_name_latex']]
    .unique(maintain_order=True)
    .iter_rows()
):
    # A pivot on several `on` columns names its outputs `<value>|{"<on 1>","<on 2>","<on 3>"}`.
    value_string = f'{{"{model_name}","{hyperparameter1_name_latex}","{hyperparameter2_name_latex}"}}'
    cols = {
        f'hyperparameter1|{value_string}': hyperparameter1_name_latex,
        f'hyperparameter2|{value_string}': hyperparameter2_name_latex,
    }
    gt = gt.tab_spanner(label=model_name, columns=list(cols)).cols_label(**cols)
    if 'sigma' in hyperparameter1_name_latex:
        # sigma_min values span orders of magnitude; show them in scientific notation.
        gt = gt.fmt_scientific(columns=list(cols)[0], decimals=0)
gt

Unnamed: 0_level_0,EnSF,EnSF,EnFF-OT,EnFF-OT,EnFF-F2P,EnFF-F2P
System,$\epsilon_{\alpha}$,$\epsilon_{\beta}$,$\sigma_{\min}$,$\lambda$,$\sigma_{\min}$,$\lambda$
KS,1.0,0.275,1 × 10−5,0.05,1 × 10−3,0.005
Lorenz '96,1.0,0.275,1 × 10−1,0.05,1 × 10−4,0.05
NS ($256 \times 256$),1.0,0.275,1 × 10−2,0.005,1 × 10−3,0.001


In [23]:
# GT escapes LaTeX special characters. Undo that so the embedded math renders, use \varepsilon,
# replace 'None' with '--' and shorten long labels. Substitutions are applied in this order.
table_latex = gt.as_latex()
for escaped, replacement in (
    (r'\{', r'{'),
    (r'\}', r'}'),
    (r'\_', r'_'),
    (r'\$', r'$'),
    (r'\\times', r'\times'),
    (r'\\epsilon', r'\varepsilon'),
    (r'\\alpha', r'\alpha'),
    (r'\\beta', r'\beta'),
    (r'\\sigma', r'\sigma'),
    (r'\\min', r'\min'),
    (r'\\lambda', r'\lambda'),
    ('None', '--'),
    ('Inflation', 'Infl.'),
    ('Localization', 'Loc.'),
):
    table_latex = table_latex.replace(escaped, replacement)
# Scientific notation from fmt_scientific: write 1e-1 .. 1e-3 as plain decimals (0.1, 0.01, 0.001)
# and drop the redundant "1 x" mantissa from the remaining powers of ten.
for i in range(1, 4):
    table_latex = table_latex.replace(rf'1 $\times$ 10\textsuperscript{{-{i}}}', '0.' + '0' * (i - 1) + '1')
table_latex = table_latex.replace(r'1 $\times$ ', '')
print(table_latex)

\begin{table}[!t]


\fontsize{12.0pt}{14.4pt}\selectfont

\begin{tabular*}{\linewidth}{@{\extracolsep{\fill}}lrrrrrr}
\toprule
\multicolumn{1}{c}{ } & \multicolumn{2}{c}{EnSF} & \multicolumn{2}{c}{EnFF-OT} & \multicolumn{2}{c}{EnFF-F2P} \\ 
\cmidrule(lr){1-1} \cmidrule(lr){2-3} \cmidrule(lr){4-5} \cmidrule(lr){6-7}
System & $\varepsilon_{\alpha}$ & $\varepsilon_{\beta}$ & $\sigma_{\min}$ & $\lambda$ & $\sigma_{\min}$ & $\lambda$ \\ 
\midrule\addlinespace[2.5pt]
KS & 1.0 & 0.275 & 10\textsuperscript{-5} & 0.05 & 0.001 & 0.005 \\
Lorenz '96 & 1.0 & 0.275 & 0.1 & 0.05 & 10\textsuperscript{-4} & 0.05 \\
NS ($256 \times 256$) & 1.0 & 0.275 & 0.01 & 0.005 & 0.001 & 0.001 \\
\bottomrule
\end{tabular*}

\end{table}



### Heatmap

In [None]:
# Full hyperparameter grid per generative model (EnFF-OT and EnFF-F2P share one grid). It is
# right-joined in the heatmap tables so that grid points without metrics show up as NaN cells.
hyperparameter_grid = duckdb.sql("""
    select
        'EnSF' as model_name,
        epsilon_alpha as hyperparameter1,
        epsilon_beta as hyperparameter2,
    from sweep_ensf_epsilon_alpha
    cross join sweep_ensf_epsilon_beta
    union
    select
        model_name,
        sigma_min as hyperparameter1,
        lambda as hyperparameter2,
    from sweep_enff_sigma_min
    cross join sweep_enff_lambda
    cross join (values ('EnFF-OT'), ('EnFF-F2P')) as t(model_name)
""")

In [None]:
# One heatmap is drawn per distinct (dataset, model, T) combination present in `rows`.
grid_search_info = duckdb.sql("""
    select distinct
        dataset_name,
        model_name,
        sampling_time_step_count,
        hyperparameter1_name_latex,
        hyperparameter2_name_latex,
    from rows
    order by dataset_name, model_name desc, sampling_time_step_count
""")
grid_search_info

In [None]:
def get_logged_metrics_pivot_table(
    dataset_name,
    model_name,
    sampling_time_step_count,
    hyperparameter1_name_latex,
    hyperparameter2_name_latex,
):
    """Tabulate mean RMSE over the hyperparameter grid of one (dataset, model, T) combination.

    Reads the notebook-level ``logged_metrics`` and ``hyperparameter_grid`` tables, which DuckDB
    resolves from Python variables of those names. Every grid point of ``model_name`` appears in
    the table; grid points with no logged metrics are NaN. Each cell is the mean ``rmse`` over all
    matching rows (seeds and evaluated steps).

    Parameters
    ----------
    dataset_name : str
        Dataset whose runs are tabulated.
    model_name : str
        Model name as stored in the tables, e.g. ``'EnSF'``, ``'EnFF-OT'``, ``'EnFF-F2P'``.
    sampling_time_step_count : int
        Number of sampling steps ``T``.
    hyperparameter1_name_latex, hyperparameter2_name_latex : str
        Axis names for the two hyperparameters.

    Returns
    -------
    pandas.DataFrame
        For EnSF, rows are hyperparameter 1 and columns are hyperparameter 2. Otherwise, rows are
        hyperparameter 2 and columns are hyperparameter 1, formatted like ``'1e-05'``.
    """
    def sql_literal(value):
        # Quote as a SQL string literal. repr() would emit a double-quoted identifier for
        # strings that contain a single quote.
        return "'" + str(value).replace("'", "''") + "'"

    # Filter on the dataset as well; otherwise metrics of every dataset are averaged into one table.
    query = f"""
        select
            hyperparameter_grid_filtered.hyperparameter1,
            hyperparameter_grid_filtered.hyperparameter2,
            rmse,
        from (
            select
                hyperparameter1,
                hyperparameter2,
                rmse,
            from logged_metrics
            where true
            and dataset_name = {sql_literal(dataset_name)}
            and model_name = {sql_literal(model_name)}
            and sampling_time_step_count = {sampling_time_step_count}
        ) as logged_metrics_filtered
        right join (
            select
                *
            from hyperparameter_grid
            where true
            and model_name = {sql_literal(model_name)}
        ) as hyperparameter_grid_filtered on (
            logged_metrics_filtered.hyperparameter1 = hyperparameter_grid_filtered.hyperparameter1
            and
            logged_metrics_filtered.hyperparameter2 = hyperparameter_grid_filtered.hyperparameter2
        )
        order by hyperparameter_grid_filtered.hyperparameter1, hyperparameter_grid_filtered.hyperparameter2
    """
    logged_metrics_pivot = (
        duckdb.sql(query).pl()
        .select('hyperparameter1', 'hyperparameter2', 'rmse')
        # Several rows (seeds x evaluated steps) map to each grid cell; average them.
        .pivot(on='hyperparameter1', index='hyperparameter2', aggregate_function='mean')
        .to_pandas()
        .set_index('hyperparameter2')
    )
    logged_metrics_pivot.index = logged_metrics_pivot.index.rename(hyperparameter2_name_latex)
    if model_name == 'EnSF':
        logged_metrics_pivot.columns = logged_metrics_pivot.columns.rename(hyperparameter1_name_latex)
        logged_metrics_pivot = logged_metrics_pivot.T
    else:
        # sigma_min spans orders of magnitude: label its columns in compact scientific notation.
        logged_metrics_pivot.columns = logged_metrics_pivot.columns.map(lambda x: f'{float(x):.0e}').rename(hyperparameter1_name_latex)
    return logged_metrics_pivot

In [None]:
# Draw and save one RMSE heatmap per (dataset, model, T), outlining the three best cells.
for (
    dataset_name,
    model_name,
    sampling_time_step_count,
    hyperparameter1_name_latex,
    hyperparameter2_name_latex,
) in grid_search_info.fetchall():
    logged_metrics_pivot = get_logged_metrics_pivot_table(
        dataset_name,
        model_name,
        sampling_time_step_count,
        hyperparameter1_name_latex,
        hyperparameter2_name_latex,
    )
    # (row, col) positions of the three lowest-RMSE cells. NaN cells (no data) are filled with 1e6
    # so they sort last.
    top_idx = list(zip(*np.unravel_index(logged_metrics_pivot.fillna(1e6).to_numpy().ravel().argsort()[:3], logged_metrics_pivot.shape)))
    
    # Per-cell figure-size multipliers, tuned separately for the (transposed) EnSF table.
    if model_name == 'EnSF':
        wscale = .8 * font_scale#+ 1
        hscale = 1 * font_scale#+ 4
    else:
        wscale = .3 * font_scale
        hscale = 1.7 * font_scale
    # Heatmap on top, thin horizontal colour bar below.
    # NOTE(review): width scales with the row count (shape[0]) and height with the column count
    # (shape[1]), the reverse of the heatmap layout; presumably offset by the tuned scales — confirm.
    fig, (ax, ax_cbar) = plt.subplots(2, 1, height_ratios=[20, 1], figsize=(wscale * logged_metrics_pivot.shape[0], hscale * logged_metrics_pivot.shape[1]))
    # rmse_max = 5
    # Fixed colour range [0, 1.5] so heatmaps share a scale.
    (
        sns.heatmap(
            data=logged_metrics_pivot,
            linewidths=1,
            annot=True,
            fmt='.3f',
            cmap='tab20b',
            # cmap=cmap,
            vmin=.0,
            vmax=1.5,
            ax=ax,
            cbar_ax=ax_cbar,
            cbar_kws=dict(
                orientation='horizontal',
            ),
        )
    )
    # Rectangle takes (x, y) = (col, row) in heatmap data coordinates, so swap each index pair.
    top_idx_rev = [tuple(reversed(idx)) for idx in top_idx]
    for idx in top_idx_rev:
        ax.add_patch(matplotlib.patches.Rectangle(idx, 1, 1, fill=False, edgecolor='red', lw=20))
    # ax_cbar.remove()
    # ax.invert_xaxis()
    ax.invert_yaxis()
    
    # Saved to ./<dataset_name>/Tune_<model>_T<T>.pdf relative to the working directory.
    file_dir = Path(dataset_name)
    file_dir.mkdir(exist_ok=True)
    fig.savefig(file_dir/f'Tune_{model_name}_T{sampling_time_step_count}.pdf', format='pdf', bbox_inches='tight', pad_inches=.06)
    # NOTE(review): stops after the first combination. Remove to render every heatmap, and then
    # consider plt.close(fig) per iteration to bound memory.
    break
# Show the last table that was plotted.
logged_metrics_pivot