In [None]:
# Enable IPython's autoreload so edits to imported project modules (e.g. opf.plot.*)
# are picked up before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import seaborn.objects as so
from pathlib import Path

from opf.plot.data import (
    load_run_metadata,
    test_or_load_runs,
    case_summary,
    model_summary,
    case_names,
    select_dual_models,
)
from opf.plot.plot import set_theme_paper, so_theme

# Destination for every figure/table this notebook exports (created if missing).
figure_dir = Path("../figures/thesis")
figure_dir.mkdir(parents=True, exist_ok=True)

# Apply the project's paper plotting theme (opf.plot.plot.set_theme_paper).
set_theme_paper()

In [None]:
# Load run metadata and the per-run results from test_or_load_runs, then keep
# only the three dual-model variants compared in this notebook.
run_dict = load_run_metadata()
df = select_dual_models(
    test_or_load_runs(run_dict),
    ["Dual-S+", "Dual-P+", "Dual-H+"],
)

In [None]:
# Per-case summary table: two decimals, "--" for missing values, centered headers.
df_summary = case_summary(df)
centered_headers = [{"selector": "th", "props": [("text-align", "center")]}]
display(
    df_summary.style
    .format("{:.2f}", na_rep="--")
    .set_table_styles(centered_headers)
)

In [None]:
# Per-model summary table. Reuse the frame computed here instead of calling
# model_summary a second time, and center the headers like the per-case table.
df_model_summary = model_summary(df)
df_model_summary.style.format("{:.2f}", na_rep="--").set_table_styles(
    [{"selector": "th", "props": [("text-align", "center")]}]
)

In [None]:
from opf.plot.data import constraint_breakdown
from opf.plot.plot import constraint_breakdown_latex

# Constraint breakdown table: exported as LaTeX into figure_dir and shown inline.
df_pivoted = constraint_breakdown(df)

constraint_breakdown_string = constraint_breakdown_latex(df_pivoted)
# Write with an explicit encoding: Path.write_text otherwise uses the platform's
# locale encoding, which can fail or differ for non-ASCII LaTeX content.
(figure_dir / "constraint_breakdown.tex").write_text(
    constraint_breakdown_string, encoding="utf-8"
)

display(df_pivoted.style.format("{:.2f}", na_rep="--"))

In [None]:
# Per-case strip plots: one facet per metric, jittered points per model.
# Facet columns follow the order of `plotted_metrics` (the order melt emits them),
# so the log-scaled panel is located by its position in that list.
plotted_metrics = ["optimality_gap", "test_normal/inequality/error_max"]
log_scaled_panel = plotted_metrics.index("test_normal/inequality/error_max")
for case_name in case_names:
    case_rows = df.query("case_name == @case_name")
    data = case_rows.melt(id_vars="model_name", value_vars=plotted_metrics)

    fig = plt.figure(figsize=(7.0, 3))
    fig.suptitle(case_name)
    theme = so_theme() | {"axes.grid.which": "both"}
    p = (
        so.Plot(data, x="value", y="model_name")
        .add(so.Dot(pointsize=1), so.Jitter(y=0.5))
        .facet(col="variable")
        .theme(theme)
        .share(x=False)
        .on(fig)
        .plot()
    )
    fig.axes[log_scaled_panel].set_xscale("log")
    # Close the figure so the inline backend does not render it again;
    # display(p) is what shows it.
    plt.close(fig)
    display(p)

In [None]:
from opf.plot.plot import plot_tradeoff

# Trade-off plot with max=False; the next cell repeats it with max=True.
p = plot_tradeoff(df, max=False)
display(p)

In [None]:
# Same trade-off plot, this time with max=True.
p = plot_tradeoff(df, max=True)
display(p)

In [None]:
import pandas as pd
import seaborn.objects as so
import wandb

# Learning-rate search check: for each dual learning rate tried in the W&B runs
# tagged `search-<dual_kind>`, plot the minimum logged "val/invariant".
dual_kind = "pointwise"
lr_column = f"lr_dual_{dual_kind}"

api = wandb.Api()
runs = api.runs(
    "damowerko-academic/opf",
    filters={"tags": {"$in": [f"search-{dual_kind}"]}},
)

# Map run id -> the dual learning rate that run was configured with.
lr = {run.id: run.config[lr_column] for run in runs}
# Fail early with a clear message rather than an opaque merge/groupby error.
assert lr, f"No W&B runs found with tag search-{dual_kind}"
run_data = runs.histories(samples=1000, keys=["val/invariant"], format="pandas")

data = (
    pd.merge(
        run_data,
        pd.DataFrame(list(lr.items()), columns=["run_id", lr_column]),
        on="run_id",
    )
    # Reduce only the metric: a blanket .min() also aggregates run_id/_step and
    # can raise on mixed-type object columns.
    .groupby(lr_column, as_index=False)["val/invariant"]
    .min()
    .sort_values(lr_column)
)

(
    so.Plot(data, x=lr_column, y="val/invariant")
    .add(so.Dot())
    .scale(x="log")
    .label(title=f"Best val/invariant per {lr_column}")
)