In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
from pathlib import Path
import re
import seaborn as sns
import matplotlib.pyplot as plt

# Parse Data

In [None]:
# Load per-run CNN runtimes from the generalization experiment output.
# The "forward" and "peaks" columns are combined into a total time further down.
CNN_RUNTIME_CSV = Path("data/out/generalization/e7ivqipk_runtime.csv")
df_cnn = pd.read_csv(CNN_RUNTIME_CSV)
df_cnn.head()

In [None]:
# Parse runtime data for the LMCO filters: one time_<idx>.txt per simulation,
# with the filter name, scale and simulation index encoded in the path.
#
# example: "data/sim_data/glmb/glmb_3.0k_10_600_0621_noclip_20markov/35/time_35.txt"
#   -> filter="glmb", scale=3, sim_idx=35
#
# The run-name prefix is matched lazily ([\w]+?) so `scale` is the FIRST
# "_<digits>" group after it. A greedy prefix backtracks to the LAST such group
# (e.g. the "20" in "_20markov") whenever the scale token contains no ".".
regex = re.compile(r"data/sim_data/(\w+)/[\w]+?_(\d+).*/(\d+)/time_\d+\.txt")


def parse_time_path(path_str: str) -> tuple:
    """Extract ``(filter, scale, sim_idx)`` from a posix-style timing-file path.

    Only the leading integer digits of the scale token are kept
    (``"3.0k"`` -> ``3``).

    Args:
        path_str: Path of a ``time_<idx>.txt`` file, using "/" separators.

    Returns:
        Tuple ``(filter_name: str, scale: int, sim_idx: int)``.

    Raises:
        ValueError: If ``path_str`` does not match the expected layout.
    """
    match = regex.search(path_str)
    if match is None:
        raise ValueError(f"Could not parse path: {path_str}")
    filter_name, scale, sim_idx = match.groups()
    return filter_name, int(scale), int(sim_idx)


data = []
# sorted() so the row order does not depend on filesystem iteration order
for path in sorted(Path("data/sim_data").glob("**/time_*.txt")):
    filter_name, scale, sim_idx = parse_time_path(path.as_posix())
    data.append(
        {
            "filter": filter_name,
            "scale": scale,
            "sim_idx": sim_idx,
            # NOTE(review): presumably the file stores milliseconds and /1000
            # converts to seconds — confirm units match the CNN csv.
            "time": float(path.read_text().strip()) / 1000,
        }
    )
# Explicit columns keep the schema (and the merge cell's column selection)
# valid even when no timing files are found.
df_lmco = pd.DataFrame(data, columns=["filter", "scale", "sim_idx", "time"])
df_lmco.head()

In [None]:
# Merge CNN and LMCO runtimes into one long table (filter / scale / time).
# The CNN total time is the sum of its "forward" and "peaks" columns.
runtime_columns = ["filter", "scale", "time"]
cnn_runtime = df_cnn.assign(
    filter="cnn",
    time=lambda frame: frame["forward"] + frame["peaks"],
)
df_runtime = pd.concat(
    [cnn_runtime[runtime_columns], df_lmco[runtime_columns]],
    ignore_index=True,
)

# Results

In [None]:
# Summary stats: mean/std of total runtime and of runtime normalized by scale**2.
# NOTE(review): the "per_km2" name assumes `scale` is a side length in km — confirm.
df_runtime = df_runtime.assign(
    time_per_km2=lambda frame: frame["time"] / frame["scale"].pow(2)
)
runtime_summary = df_runtime.groupby(["scale", "filter"]).agg(
    {"time": ["mean", "std"], "time_per_km2": ["mean", "std"]}
)
runtime_summary.round(2)

In [None]:
# Ratio of each filter's mean runtime to the CNN's mean runtime at the same
# scale, i.e. how many times faster the CNN is (the "cnn" entry is 1 by construction).
# NOTE(review): restricted to small scales, presumably where LMCO runs exist — confirm.
MAX_RATIO_SCALE = 3
cnn_means = df_runtime.query("filter == 'cnn'").groupby("scale")["time"].mean()
(
    df_runtime.query("scale <= @MAX_RATIO_SCALE")
    .groupby(["scale", "filter"])["time"]
    .mean()
    # Broadcast the per-scale CNN mean across the "scale" level of the
    # (scale, filter) index explicitly instead of relying on implicit
    # name-based alignment inside a groupby-transform.
    .div(cnn_means, level="scale")
    # The Series is named "time"; the old rename targeted "time_per_km2",
    # a column that never existed, so the label was silently never applied.
    .rename("Speedup with CNN")
    .to_frame()
    .round(0)
    .transpose()
)

In [None]:
# Summary plots: one strip plot per runtime metric, each with a log-scale y-axis.
plot_specs = [
    ("time", "Total Runtime"),
    ("time_per_km2", "Normalized Runtime"),
]
for metric, plot_title in plot_specs:
    grid = sns.catplot(data=df_runtime, x="scale", y=metric, hue="filter")
    # FacetGrid.set applies to every subplot Axes (a single one here)
    grid.set(yscale="log", title=plot_title)

plt.show()