In [None]:
# Column of 3 big plots: Qubits | Qutrits | Hybrid

import os, glob
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.ticker import (
    LogFormatterMathtext, LogLocator, FixedLocator
)

# -------------------- Paths --------------------
# Inputs: homogeneous (qubit / qutrit) runs and hybrid runs.
CSV_QS = "qubits_qutrits_full_basis_runs.csv"
CSV_HT = "hybrid_full_basis_runs.csv"

# Outputs: the same figure is written once as PNG and once as PDF.
OUT_PNG = "Partition_hist.png"
OUT_PDF = "Partition_hist.pdf"

# -------------------- Global styling  --------------------
# FONT is the base size; every other text size below is an offset from it.
FONT = 36
INFO_FONT_QS = max(FONT - 6, 18)   # info-box text in the qubit/qutrit panels
INFO_FONT_HET = max(FONT - 6, 18)  # info-box text in the hybrid panel

# Font embedding and resolution.
plt.rcParams.update({
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    "figure.dpi": 150,
    "savefig.dpi": 600,
})

# Text sizes.
plt.rcParams.update({
    "font.size": FONT,
    "axes.labelsize": FONT + 2,
    "axes.titlesize": FONT + 4,
    "xtick.labelsize": FONT - 2,
    "ytick.labelsize": FONT - 2,
    "legend.fontsize": FONT - 6,
})

# Frame and tick geometry; the default grid is off because the panels draw
# their own.
plt.rcParams.update({
    "axes.linewidth": 2.8,
    "xtick.major.width": 2.4,
    "ytick.major.width": 2.4,
    "xtick.minor.width": 1.6,
    "ytick.minor.width": 1.6,
    "xtick.major.size": 9,
    "ytick.major.size": 9,
    "xtick.minor.size": 5,
    "ytick.minor.size": 5,
    "axes.grid": False,
})

# -------------------- Load & sanitize --------------------
# Qubit/qutrit runs. Numeric columns are coerced, so unparseable cells
# become NaN instead of raising.
df_qs = pd.read_csv(CSV_QS)
has_row_type = "row_type" in df_qs.columns

for col in ["d","N","total_ops","colors","times","colors_std","times_std","seed","n_seeds"]:
    if col in df_qs.columns:
        df_qs[col] = pd.to_numeric(df_qs[col], errors="coerce")

df_qs["system"] = df_qs["system"].astype(str)
df_qs["method"] = df_qs["method"].astype(str)
# Nullable integers keep rows with a missing d/N alive until the dropna below.
if "d" in df_qs.columns:
    df_qs["d"] = df_qs["d"].astype("Int64")
if "N" in df_qs.columns:
    df_qs["N"] = df_qs["N"].astype("Int64")

if has_row_type:
    # The file already carries pre-aggregated "summary" rows: use them as-is.
    df_qs_sum = df_qs[df_qs["row_type"] == "summary"].copy()
    for std_col in ("colors_std", "times_std"):
        # A missing spread column means "no spread recorded": create it as
        # zeros rather than calling .fillna on a scalar default.
        if std_col in df_qs_sum.columns:
            df_qs_sum[std_col] = df_qs_sum[std_col].fillna(0.0)
        else:
            df_qs_sum[std_col] = 0.0
else:
    # Only raw rows are available: aggregate per (system, d, N, method).
    df_qs_sum = (df_qs.groupby(["system","d","N","method"], as_index=False)
                      .agg(colors=("colors","mean"),
                           colors_std=("colors","std"),
                           times=("times","mean"),
                           times_std=("times","std"),
                           total_ops=("total_ops","max")))

df_qs_sum = df_qs_sum.dropna(subset=["system","d","N","method","colors"])
df_qs_sum["d"] = df_qs_sum["d"].astype(int)
df_qs_sum["N"] = df_qs_sum["N"].astype(int)
# Optional counters become plain ints only when every value is present; a
# partially missing column stays float instead of raising on the NaN cast.
for col in ("total_ops", "n_seeds"):
    if col in df_qs_sum.columns and df_qs_sum[col].notna().all():
        df_qs_sum[col] = df_qs_sum[col].astype(int)

# Hybrid (RAW): keep only "raw" rows when the file labels any row that way.
df_ht = pd.read_csv(CSV_HT)
if "row_type" in df_ht.columns and (df_ht["row_type"] == "raw").any():
    df_ht = df_ht[df_ht["row_type"] == "raw"].copy()

for c in ["n_qubits","n_qutrits","M_used","total_ops","seed","colors","times","times_std"]:
    if c in df_ht.columns:
        df_ht[c] = pd.to_numeric(df_ht[c], errors="coerce")

df_ht = df_ht.dropna(subset=["n_qubits","n_qutrits","method","colors"])
if not df_ht.empty:
    df_ht["n_qubits"]  = df_ht["n_qubits"].astype(int)
    df_ht["n_qutrits"] = df_ht["n_qutrits"].astype(int)
    # errors="ignore" leaves the column unchanged when it cannot be cast.
    if "M_used" in df_ht:
        df_ht["M_used"] = df_ht["M_used"].astype(int, errors="ignore")
    if "total_ops" in df_ht:
        df_ht["total_ops"] = df_ht["total_ops"].astype(int, errors="ignore")
    df_ht["method"] = df_ht["method"].astype(str)

# -------------------- Methods & colors --------------------
# Every method name seen in either data set gets one fixed colour so that the
# three panels and the shared legend agree.
methods_qs = sorted(df_qs_sum["method"].unique().tolist())
methods_ht = [] if df_ht.empty else sorted(df_ht["method"].unique().tolist())
methods_all = sorted(set(methods_qs).union(methods_ht))

palette = [
    "#1f77b4", "#d62728", "#2ca02c", "#9467bd", "#ff7f0e",
    "#8c564b", "#17becf", "#e377c2", "#7f7f7f", "#bcbd22",
]
# Colours are assigned in alphabetical method order and wrap around when
# there are more methods than palette entries.
style = {
    name: {"color": palette[idx % len(palette)]}
    for idx, name in enumerate(sorted(methods_all))
}

def ilp_last(method_list):
    """Return the methods with every "ilp" entry (case-insensitive) moved to the end.

    The relative order inside each of the two groups is kept.
    """
    others, ilp_entries = [], []
    for name in method_list:
        bucket = ilp_entries if name.lower() == "ilp" else others
        bucket.append(name)
    return others + ilp_entries

# -------------------- Info helpers --------------------
def fmt_int(x):
    """Format ``x`` as an integer with thousands separators.

    Values that cannot be converted with ``int()`` (non-numeric strings,
    ``None``, NaN, infinities) are returned as ``str(x)`` instead.
    """
    try:
        return f"{int(x):,}"
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / unsupported type; ValueError: bad string or NaN;
        # OverflowError: infinity.
        return str(x)

def total_ops_for(d_val, Ns, sub_df=None):
    """Return the total operator count for each N in ``Ns``.

    Parameters
    ----------
    d_val : local dimension used by the fallback formula ``(d*d)**N - 1``.
    Ns : iterable of system sizes.
    sub_df : optional DataFrame with columns ``"N"`` and ``"total_ops"``.
        When given, the largest recorded ``total_ops`` per N is used.

    Returns
    -------
    list with one entry per element of ``Ns``. Any N without a recorded
    (non-NaN) value uses the formula, including N values whose recorded
    entries are all NaN.
    """
    def formula(n):
        return (d_val * d_val) ** n - 1

    if sub_df is not None and "total_ops" in sub_df.columns:
        # Drop N-groups whose maximum is NaN so the int conversion cannot fail.
        per_n = sub_df.groupby("N")["total_ops"].max().dropna()
        if not per_n.empty:
            recorded = {int(n): int(total) for n, total in per_n.items()}
            return [int(recorded.get(n, formula(n))) for n in Ns]
    return [formula(n) for n in Ns]

def make_info_text_qubits(Ns, totals):
    """Build the "Total ops" info-box text for the qubit panel.

    Entries with N < 4 go on the first line and entries with N >= 4 on the
    second. When only one of the two groups is non-empty, the text has a
    single entry line (no blank line is emitted).
    """
    entries = [(int(n), f"N={int(n)}→{fmt_int(t)}") for n, t in zip(Ns, totals)]
    small = ", ".join(text for n, text in entries if n < 4)
    large = ", ".join(text for n, text in entries if n >= 4)
    if small and large:
        # The trailing comma shows that the list continues on the next line.
        body = small + ",\n" + large
    else:
        body = small or large
    return "Total ops\n" + body

def make_info_text_generic(Ns, totals):
    """Build the "Total ops" info-box text for a non-qubit panel.

    Up to six entries are written on one line; more are split into two lines,
    the first holding the larger half and ending with a comma.
    """
    labels = [f"N={int(n)}→{fmt_int(t)}" for n, t in zip(Ns, totals)]
    header = "Total ops\n"
    if len(labels) <= 6:
        return header + ", ".join(labels)
    split_at = (len(labels) + 1) // 2
    first_row = ", ".join(labels[:split_at])
    second_row = ", ".join(labels[split_at:])
    separator = "," if second_row else ""
    return header + first_row + separator + "\n" + second_row

def make_info_text_hetero_two_lines(pairs, totals):
    """Build the "Total ops" info-box text for the hybrid panel.

    ``pairs`` holds ``(n_qubits, n_qutrits)`` tuples and ``totals`` the
    matching counts (``None`` is shown as "n/a"). Two or more entries are
    split over two lines, the first ending with a comma.
    """
    labels = []
    for (nq, nt), tot in zip(pairs, totals):
        shown = "n/a" if tot is None else fmt_int(tot)
        labels.append(f"{nq}q ⊗ {nt}t→{shown}")

    header = "Total ops\n"
    if not labels:
        return header
    if len(labels) == 1:
        return header + labels[0]

    cut = (len(labels) + 1) // 2
    top_row = ", ".join(labels[:cut])
    bottom_row = ", ".join(labels[cut:])
    return header + top_row + ",\n" + bottom_row

# -------------------- Runtime axis (fixed log) --------------------
# The right-hand runtime axis uses these fixed limits and tick positions
# (see configure_right_log_axis).
RUNTIME_YMIN, RUNTIME_YMAX = 1e-4, 1e2
RUNTIME_TICKS = [1e-4, 1e-2, 1e0, 1e2]
RUNTIME_FMT = LogFormatterMathtext(base=10)

# Label offsets: x-position passed to yaxis.set_label_coords for the left
# and right y-axis labels.
LABEL_OFFSETS = dict(left=-0.15, right=1.16)
TICK_PAD = 14     # tick-label padding passed to tick_params
XLABEL_PAD = 18   # labelpad passed to set_xlabel

# ===== Uniform thin log grid =====
GRID_LW   = 0.45
GRID_ALPH = 0.75
GRID_COL  = "0.65"
def draw_uniform_log_grid_left(ax):
    """Draw identical dotted horizontal grid lines at major and minor y ticks.

    The grid is placed above the bars (``set_axisbelow(False)``). A
    ``LogLocator`` is installed for the major ticks unless one is already
    present; the minor locator is always replaced.
    """
    ax.set_axisbelow(False)  # grid above bars
    ax.grid(False)           # reset any grid that was configured earlier

    y_axis = ax.yaxis
    if not isinstance(y_axis.get_major_locator(), LogLocator):
        y_axis.set_major_locator(LogLocator(base=10.0))
    y_axis.set_minor_locator(LogLocator(base=10.0, subs=np.arange(2, 10)*0.1))

    # Same line style for both tick levels, hence the "uniform" grid.
    line_kw = dict(axis="y", linestyle=":", linewidth=GRID_LW,
                   alpha=GRID_ALPH, color=GRID_COL, zorder=10)
    for level in ("major", "minor"):
        ax.grid(True, which=level, **line_kw)

def configure_right_log_axis(ax2, show_label=True):
    """Configure a twin axis as the fixed log-scale runtime axis.

    Sets the log scale, the fixed limits ``RUNTIME_YMIN``..``RUNTIME_YMAX``,
    the fixed major ticks ``RUNTIME_TICKS``, minor ticks, spine widths and
    tick padding, and switches the grid off.

    Parameters
    ----------
    ax2 : the twin Axes to configure.
    show_label : when True the axis is labelled "Runtime (s)"; otherwise the
        label is cleared.
    """
    ax2.set_yscale("log", base=10)
    ax2.set_ylim(RUNTIME_YMIN, RUNTIME_YMAX)
    ax2.yaxis.set_major_locator(FixedLocator(RUNTIME_TICKS))
    # set_major_formatter binds the formatter to this axis, so each axis gets
    # its own instance instead of sharing the module-level RUNTIME_FMT
    # (the last axis configured would otherwise own the shared one).
    ax2.yaxis.set_major_formatter(LogFormatterMathtext(base=10))
    ax2.yaxis.set_minor_locator(LogLocator(base=10.0, subs=np.arange(2, 10)*0.1))
    ax2.tick_params(axis='y', which='both', right=True, labelright=True, pad=TICK_PAD, labelsize=FONT-2)
    for spine in ax2.spines.values():
        spine.set_linewidth(2.6)
    ax2.grid(False)

    if not show_label:
        ax2.set_ylabel(None)
        return
    ax2.set_ylabel("Runtime (s)")
    ax2.yaxis.set_label_coords(LABEL_OFFSETS["right"], 0.5)
    ax2.yaxis.label.set_fontsize(FONT + 4)

def _safe_pos(vals):
    v = np.array(vals, dtype=float)
    v = np.where(np.isfinite(v) & (v > 0), v, np.nan)
    return v

# ===== Spacing & widths =====
GROUP_SPACING = 2.8
WIDTH_PART_MAX = 0.22
GAP_PART_MIN   = 0.16
GAP_PART_REL   = 0.64
TIME_W_CAP     = 1.08

def _calc_offsets_and_widths(M):
    width_part = min(WIDTH_PART_MAX, 0.85 / max(1, M))
    gap_part   = max(GAP_PART_MIN, GAP_PART_REL * width_part)
    delta      = width_part + gap_part
    eps = 0.06
    width_time = min(delta * (1.0 - eps), TIME_W_CAP)
    offsets = (np.arange(M) - (M-1)/2.0) * delta
    return width_part, gap_part, offsets, width_time, delta

# -------------------- Panels --------------------
def _style_axes_common(ax, add_left_label: bool, left_label_text: str):
    """Apply the shared left-axis look: log y-scale, spine widths, tick padding.

    When ``add_left_label`` is true the y-axis is labelled with
    ``left_label_text`` at the shared left offset; otherwise the label is
    cleared.
    """
    ax.set_yscale("log")
    for spine in ax.spines.values():
        spine.set_linewidth(2.8)
    for axis_name in ("x", "y"):
        ax.tick_params(axis=axis_name, which='both', pad=TICK_PAD)

    if not add_left_label:
        ax.set_ylabel(None)
        return
    ax.set_ylabel(left_label_text)
    ax.yaxis.set_label_coords(LABEL_OFFSETS["left"], 0.5)
    ax.yaxis.label.set_fontsize(FONT + 4)

def plot_qs_panel(ax, sub_sum, title, x_label, add_y_label=False, draw_mub_segments=False, show_time_ylabel=True):
    """Draw one homogeneous-system panel (qubits or qutrits).

    Narrow opaque bars on the left log axis show the mean number of
    partitions per method and N; wide translucent bars on a twin right axis
    show the mean runtime. An info box lists the total operator count per N.

    Parameters
    ----------
    ax : Axes that receives the partition bars; a twin axis is created on it.
    sub_sum : DataFrame of summary rows. Reads the columns "method", "N",
        "d", "colors", "colors_std", "times", "times_std" and, if present,
        "total_ops".
    title, x_label : panel title and x-axis label.
    add_y_label : label the left axis when True.
    draw_mub_segments : draw a dotted segment at ``2**N + 1`` over each group.
    show_time_ylabel : label the right (runtime) axis when True.
    """
    if sub_sum.empty:
        ax.set_title(title, pad=16, fontweight="bold", fontsize=FONT+2)
        ax.text(0.5, 0.5, "No data", ha="center", va="center")
        ax.axis("off")
        return

    methods = ilp_last(sorted(sub_sum["method"].unique().tolist()))
    Ns = sorted(sub_sum["N"].unique().astype(int).tolist())
    x = np.arange(len(Ns), dtype=float) * GROUP_SPACING
    width_part, _gap_part, offsets, width_time, _delta = _calc_offsets_and_widths(len(methods))

    # Aggregate stats: one row per N for each method, aligned to Ns so that
    # an N the method never ran on shows up as a missing value.
    per_m_colors, per_m_times = {}, {}
    for m in methods:
        g = (sub_sum[sub_sum["method"] == m]
             .groupby("N")
             .agg(mu_c=("colors", "mean"),
                  sd_c=("colors_std", "mean"),
                  mu_t=("times", "mean"),
                  sd_t=("times_std", "mean")))
        g.index = g.index.astype(int)
        g = g.reindex(Ns)

        # Missing means stay NaN (no bar); missing spreads count as zero.
        mu_c = g["mu_c"].to_numpy(dtype=float, na_value=np.nan)
        sd_c = g["sd_c"].fillna(0.0).to_numpy(dtype=float)
        mu_t = g["mu_t"].to_numpy(dtype=float, na_value=np.nan)
        sd_t = g["sd_t"].fillna(0.0).to_numpy(dtype=float)

        per_m_colors[m] = (mu_c, sd_c)
        # A log axis cannot show non-positive values, so they become NaN.
        per_m_times[m] = (_safe_pos(mu_t), _safe_pos(sd_t))

    # Right axis (fixed limits): translucent runtime bars behind the
    # partition bars.
    ax_time = ax.twinx()
    for i, m in enumerate(methods):
        mu_t, sd_t = per_m_times[m]
        col = style.get(m, {"color": "C0"})["color"]
        pos = x + offsets[i]
        ax_time.bar(pos, mu_t, width=width_time, color=col, edgecolor="none", alpha=0.28, zorder=2)
        ax_time.errorbar(pos, mu_t, yerr=sd_t, fmt="none", ecolor=col, elinewidth=1.6, capsize=5.0,
                         capthick=1.4, alpha=0.28, zorder=3)
    configure_right_log_axis(ax_time, show_label=show_time_ylabel)

    # Partition bars on top
    for i, m in enumerate(methods):
        mu, sd = per_m_colors[m]
        col = style.get(m, {"color": "C0"})["color"]
        pos = x + offsets[i]
        ax.bar(pos, mu, width=width_part, color=col, edgecolor="none", linewidth=0.0, alpha=0.98, zorder=4, label=m)
        ax.errorbar(pos, mu, yerr=sd, fmt="none", ecolor="black", elinewidth=2.4, capsize=7.0, capthick=2.0, zorder=5)

    # X-limits include full time bar width
    min_off_time = float(np.min(offsets) - width_time*0.56)
    max_off_time = float(np.max(offsets) + width_time*0.56)
    ax.set_xlim(x[0] + min_off_time - 0.30, x[-1] + max_off_time + 0.30)

    # Optional MUB segments (qubits only): one dotted line at 2^N + 1
    # spanning each group of partition bars.
    if draw_mub_segments:
        group_lefts  = x + (np.min(offsets) - width_part*0.55)
        group_rights = x + (np.max(offsets) + width_part*0.55)
        for n, xl, xr in zip(Ns, group_lefts, group_rights):
            mub = (2**int(n)) + 1
            ax.hlines(y=mub, xmin=xl, xmax=xr, linestyles=":", linewidth=3.0, colors="black", zorder=6)

    # Axes cosmetics
    ax.set_title(title, pad=16, fontweight="bold", fontsize=FONT)  # smaller heading
    ax.set_xticks(x, [str(n) for n in Ns])
    ax.set_xlabel(x_label, labelpad=XLABEL_PAD)
    _style_axes_common(ax, add_y_label, "No. of Operator partitions")

    # LEFT grid (uniform thin)
    draw_uniform_log_grid_left(ax)

    # Info box: total operator count per N, top-left corner of the axes.
    d_val = int(sub_sum["d"].iloc[0])
    totals = total_ops_for(d_val, Ns, sub_sum if "total_ops" in sub_sum.columns else None)
    info_text = make_info_text_qubits(Ns, totals) if d_val == 2 else make_info_text_generic(Ns, totals)
    ax.text(0.02, 0.98, info_text,
            transform=ax.transAxes, ha="left", va="top",
            fontsize=INFO_FONT_QS,
            bbox=dict(boxstyle="round,pad=0.42", facecolor=(1,1,1,0.0), edgecolor="none"),
            zorder=7)

def _positive_mean_std(values):
    """Return (mean, population std) of the finite, strictly positive entries; (nan, 0.0) if there are none."""
    vals = np.asarray(values, dtype=float)
    vals = vals[np.isfinite(vals) & (vals > 0)]
    if vals.size == 0:
        return np.nan, 0.0
    return float(np.mean(vals)), float(np.std(vals))


def plot_hetero_panel(ax, df_ht_raw, title, add_y_label=False):
    """Draw the hybrid (qubit ⊗ qutrit) panel from raw per-run rows.

    For every (n_qubits, n_qutrits) case and method, the mean and spread of
    "colors" (left log axis, opaque bars) and of "times" (twin right axis,
    translucent bars) are computed from the finite positive values. An info
    box lists the largest recorded "total_ops" per case, or "n/a".

    Parameters
    ----------
    ax : Axes that receives the partition bars; a twin axis is created on it.
    df_ht_raw : DataFrame with columns "n_qubits", "n_qutrits", "method",
        "colors" and optionally "times" and "total_ops".
    title : panel title.
    add_y_label : label the left axis when True.
    """
    if df_ht_raw.empty:
        ax.set_title(title, pad=16, fontweight="bold", fontsize=FONT+2)
        ax.text(0.5, 0.5, "No heterogeneous data", ha="center", va="center")
        ax.axis("off")
        return

    methods = ilp_last(sorted(df_ht_raw["method"].unique().tolist()))
    cases = sorted(df_ht_raw[["n_qubits","n_qutrits"]].drop_duplicates().itertuples(index=False, name=None))
    x = np.arange(len(cases), dtype=float) * GROUP_SPACING
    width_part, _gap_part, offsets, width_time, _delta = _calc_offsets_and_widths(len(methods))

    # Stats: one pass over the cases collects, per method, the (mean, std)
    # pairs in case order, plus the info-box total for the case.
    has_times = "times" in df_ht_raw.columns
    has_totals = "total_ops" in df_ht_raw.columns
    stats_c = {m: [] for m in methods}
    stats_t = {m: [] for m in methods}
    totals = []
    for nq, nt in cases:
        sub = df_ht_raw[(df_ht_raw["n_qubits"] == nq) & (df_ht_raw["n_qutrits"] == nt)]
        for m in methods:
            rows = sub.loc[sub["method"] == m]
            stats_c[m].append(_positive_mean_std(rows["colors"].to_numpy(float)))
            if has_times:
                stats_t[m].append(_positive_mean_std(rows["times"].to_numpy(float)))
            else:
                stats_t[m].append((np.nan, 0.0))

        total = None
        if has_totals and sub["total_ops"].notna().any():
            try:
                total = int(sub["total_ops"].max())
            except (TypeError, ValueError, OverflowError):
                total = None  # shown as "n/a" in the info box
        totals.append(total)

    # Right axis (fixed): translucent runtime bars.
    ax_time = ax.twinx()
    for i, m in enumerate(methods):
        col = style.get(m, {"color": "C0"})["color"]
        # Non-positive values cannot be drawn on a log axis and become NaN.
        mu_t = _safe_pos([mean for mean, _ in stats_t[m]])
        sd_t = _safe_pos([std for _, std in stats_t[m]])
        pos = x + offsets[i]
        ax_time.bar(pos, mu_t, width=width_time, color=col, edgecolor="none", alpha=0.28, zorder=2)
        ax_time.errorbar(pos, mu_t, yerr=sd_t, fmt="none", ecolor=col, elinewidth=1.6, capsize=5.0,
                         capthick=1.4, alpha=0.28, zorder=3)
    configure_right_log_axis(ax_time, show_label=True)

    # Partition bars; ymax tracks the tallest bar + error bar for the y-limit.
    ymax = 0.0
    for i, m in enumerate(methods):
        col = style.get(m, {"color": "C0"})["color"]
        mu = np.array([mean for mean, _ in stats_c[m]], dtype=float)
        sd = np.array([std for _, std in stats_c[m]], dtype=float)
        pos = x + offsets[i]
        ax.bar(pos, mu, width=width_part, color=col, edgecolor="none", linewidth=0.0, alpha=0.98, zorder=4, label=m)
        ax.errorbar(pos, mu, yerr=sd, fmt="none", ecolor="black", elinewidth=2.4, capsize=7.0, capthick=2.0, zorder=5)
        val_max = np.nanmax(mu + np.nan_to_num(sd))
        if np.isfinite(val_max):
            ymax = max(ymax, float(val_max))

    # X-limits include the full runtime-bar width.
    min_off_time = float(np.min(offsets) - width_time*0.56)
    max_off_time = float(np.max(offsets) + width_time*0.56)
    ax.set_xlim(x[0] + min_off_time - 0.30, x[-1] + max_off_time + 0.30)

    ax.set_title(title, pad=16, fontweight="bold", fontsize=FONT)  # smaller heading
    # Shared left-axis styling (log scale, spines, tick padding, y label).
    _style_axes_common(ax, add_y_label, "No. of Operator partitions")
    ax.set_ylim(0.7, ymax * 1.85 if ymax > 0 else 2.0)
    ax.set_xticks(x, [f"{nq}q ⊗ {nt}t" for (nq, nt) in cases])
    ax.set_xlabel("Qubit ⊗ Qutrit", labelpad=XLABEL_PAD)

    # LEFT grid (uniform thin)
    draw_uniform_log_grid_left(ax)

    # Info box
    ax.text(0.02, 0.98, make_info_text_hetero_two_lines(cases, totals),
            transform=ax.transAxes, ha="left", va="top",
            fontsize=INFO_FONT_HET,
            bbox=dict(boxstyle="round,pad=0.40", facecolor=(1,1,1,0.0), edgecolor="none"),
            zorder=7)

# -------------------- Build figure (column) --------------------
fig = plt.figure(figsize=(15.5, 34.0))
gs = fig.add_gridspec(nrows=3, ncols=1, height_ratios=[1.0, 1.0, 1.0], hspace=0.38)

ax1 = fig.add_subplot(gs[0, 0])  # (a) Qubits
ax2 = fig.add_subplot(gs[1, 0])  # (b) Qutrits
ax3 = fig.add_subplot(gs[2, 0])  # (c) Hybrid

# Subsets
sub_qb_sum = df_qs_sum[(df_qs_sum["system"]=="qubit")  & (df_qs_sum["d"]==2)].copy()
sub_qt_sum = df_qs_sum[(df_qs_sum["system"]=="qutrit") & (df_qs_sum["d"]==3)].copy()

# Panels
plot_qs_panel(ax1, sub_qb_sum, title="(a) Qubit system",
              x_label="No. of Qubits (N)", add_y_label=True, draw_mub_segments=True,  show_time_ylabel=True)
plot_qs_panel(ax2, sub_qt_sum, title="(b) Qutrit system",
              x_label="No. of Qutrits (N)", add_y_label=True, draw_mub_segments=False, show_time_ylabel=True)
plot_hetero_panel(ax3, df_ht, title="(c) Hybrid system",
                  add_y_label=True)

# Global legend at bottom: every method that appears in any panel, in the
# same order the panels use (alphabetical, ILP last).
methods_in_fig = ilp_last(sorted(
    set(sub_qb_sum["method"].unique().tolist() if not sub_qb_sum.empty else [])
    | set(sub_qt_sum["method"].unique().tolist() if not sub_qt_sum.empty else [])
    | set(df_ht["method"].unique().tolist() if not df_ht.empty else [])
))

# --- Only change: rename "Spectral" method(s) as "SC" in legend label ---
legend_patches = [
    Patch(
        facecolor=style.get(m, {"color": "C0"})["color"],
        edgecolor="none",
        label=("SC" if m.lower().startswith("spectral") else m)
    )
    for m in methods_in_fig
]

# At least one legend column even when no method is present.
leg = fig.legend(handles=legend_patches, loc="lower center", bbox_to_anchor=(0.55, -0.02),
                 ncol=max(1, min(len(methods_in_fig), 6)), frameon=True, fontsize=FONT-6,
                 columnspacing=1.6, handlelength=2.2, borderpad=0.6)
leg.get_frame().set_linewidth(1.0)

# Operate on this figure explicitly rather than on pyplot's current figure.
fig.subplots_adjust(left=0.19, right=0.9, top=0.985, bottom=0.07)

fig.savefig(OUT_PNG, dpi=600, bbox_inches="tight")
fig.savefig(OUT_PDF, dpi=600, bbox_inches="tight")
plt.close(fig)

print("Saved:", OUT_PNG, OUT_PDF)


In [None]:

import os, glob
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

# -------------------- Uniform typography --------------------
# All text uses one size; only the axis labels are larger.
plt.rcParams.update({
    **{key: 18.0 for key in (
        "font.size",
        "axes.titlesize",
        "xtick.labelsize",
        "ytick.labelsize",
        "legend.fontsize",
        "figure.titlesize",
    )},
    "axes.labelsize": 20.0,
    "axes.titlepad": 6.0,
    "axes.labelpad": 9.0,
    "xtick.major.pad": 4.0,
    "ytick.major.pad": 4.0,
})

# ---------- Line/marker style ----------
LINE_W = 1.4        # partition / ratio curves
LINE_W_TIME = 1.6   # timing curves
LINE_W_REF = 1.2    # reference lines
MSIZE = 9.5         # marker size
CAPSIZE = 3.2       # error-bar cap size

# -------------------- Load CSV --------------------
CSV = "gnn_multi_models_runs.csv"

df = pd.read_csv(CSV)

# Normalize numerics/strings: numeric columns are coerced (bad cells become
# NaN), label columns are cast to str. Columns absent from the file are
# skipped.
num_cols = [
    "train_seconds", "target_N", "target_M_used", "total_ops", "colors", "time",
    "colors_gnn", "colors_ilp", "time_gnn", "time_ilp",
    "ratio_gnn_over_ilp", "ratio_mean", "ratio_std", "ratio_stderr",
]
for c in num_cols:
    if c not in df.columns:
        continue
    df[c] = pd.to_numeric(df[c], errors="coerce")
for c in ["row_type", "scenario", "trained_on", "model_path"]:
    if c not in df.columns:
        continue
    df[c] = df[c].astype(str)

# model_id is the file name of the model path; non-string entries pass
# through unchanged, and the column is empty when no path is recorded.
if "model_path" not in df.columns:
    df["model_id"] = ""
else:
    df["model_id"] = df["model_path"].apply(lambda s: os.path.basename(s) if isinstance(s, str) else s)

# -------------------- Helpers --------------------
def _box_axes(ax, lw=1.2):
    """Give every spine of ``ax`` the line width ``lw`` and a black colour."""
    for spine in ax.spines.values():
        spine.set_linewidth(lw)
        spine.set_color("black")

def _scen_label(s: str) -> str:
    toks = [t for t in str(s).split("_") if t.isdigit()]
    if not toks:
        return s
    nums = ",".join(str(int(t)) for t in toks)
    return f"Train on qubit {nums}" if len(toks) == 1 else f"Train on qubits {nums}"

def set_sparse_xticks(ax, values, max_ticks=8, integer=True):
    """Put x ticks on the distinct finite ``values``, thinned to roughly ``max_ticks``.

    Values are cast to int (or float when ``integer`` is False), de-duplicated
    and sorted. If more than ``max_ticks`` remain, every ``step``-th one is
    kept, starting from the smallest. Nothing happens when no finite value
    exists.
    """
    cast = int if integer else float
    distinct = {cast(v) for v in values if np.isfinite(v)}
    if not distinct:
        return
    ticks = np.array(sorted(distinct))
    step = 1 if ticks.size <= max_ticks else int(np.ceil(ticks.size / max_ticks))
    ax.set_xticks(ticks[::step])

def safe_set_ylim_log(ax, data, pad_top_factor=1.8, pad_bottom_factor=0.88, floor=1e-12):
    """Switch ``ax`` to a log y-scale with limits padded around the positive data.

    Only finite, strictly positive entries of ``data`` are considered. The
    lower limit is ``min * pad_bottom_factor`` but never below ``floor``; the
    upper limit is ``max * pad_top_factor``, or ten times the lower limit if
    that would not exceed it. Without usable data the limits are
    ``(floor, 10.0)``.
    """
    usable = [v for v in data if np.isfinite(v) and v > 0.0]
    if usable:
        arr = np.asarray(usable)
        lo = max(np.min(arr) * pad_bottom_factor, floor)
        hi = np.max(arr) * pad_top_factor
        if hi <= lo:
            hi = lo * 10.0
    else:
        lo, hi = floor, 10.0
    ax.set_yscale("log")
    ax.set_ylim(lo, hi)

# Marker shapes, cycled per model in panel (b).
MODEL_MARKERS = ["o", "^", "s", "D", "P", "X", "v", "<", ">"]

# -------------------- Data splits --------------------
# Rows are routed to the panels by their row_type (and scenario for targets).
train_df     = df.loc[df["row_type"] == "train"].copy()
targets_df   = df.loc[(df["row_type"] == "infer") & (df["scenario"] == "targets")].copy()
subset_df    = df.loc[df["row_type"] == "subset_ratio"].copy()
subset_stats = df.loc[df["row_type"] == "subset_ratio_stats"].copy()

# -------------------- 3×1 column --------------------
fig, (ax1, ax2, axC) = plt.subplots(nrows=3, ncols=1, figsize=(8, 16.2))
fig.subplots_adjust(left=0.09, right=0.96, bottom=0.06, top=0.94, hspace=0.44)

# ==================== (a) Training bars ====================
# One bar per training row (log y-axis), coloured by scenario.
if not train_df.empty:
    train_df = train_df.sort_values(["scenario","model_id"]).reset_index(drop=True)
    times  = train_df["train_seconds"].to_numpy(float)
    scen   = train_df["scenario"].tolist()
    labels = [f"Model {k+1}" for k in range(len(scen))]

    # Scenario colours follow the default colour cycle in order of first
    # appearance.
    color_cycle = plt.rcParams['axes.prop_cycle'].by_key().get('color', [])
    uniq_scen = list(dict.fromkeys(scen))
    scen_color = {s: color_cycle[i % len(color_cycle)] for i, s in enumerate(uniq_scen)}

    for i, (t, s) in enumerate(zip(times, scen)):
        # Heights are floored at 1e-12 so the log axis never gets a zero.
        ax1.bar(i, max(t, 1e-12), color=scen_color[s], edgecolor="black", linewidth=0.8, zorder=2)

    # X ticks: label every bar, or about ten evenly spaced bars when there
    # are more than ten.
    step = int(np.ceil(len(labels)/10)) if len(labels) > 10 else 1
    idxs = list(range(0, len(labels), step))
    ax1.set_xticks(idxs)
    ax1.set_xticklabels([labels[i] for i in idxs], rotation=0, ha="center")

    ax1.set_ylabel("Training time (s)")
    ax1.grid(True, axis="y", which="both", linestyle=":", linewidth=0.9, zorder=0)
    safe_set_ylim_log(ax1, times, pad_top_factor=2.6, pad_bottom_factor=0.9, floor=1e-12)
    ax1.margins(x=0.005, y=0.01)

    # Legend: one patch per scenario.
    patches = [Patch(facecolor=scen_color[s], edgecolor="black", label=_scen_label(s)) for s in uniq_scen]
    ax1.legend(handles=patches, loc="upper left",
               frameon=False, ncol=1,
               borderaxespad=0.25, borderpad=0.25,
               handlelength=1.0, handletextpad=0.4, labelspacing=0.25)
    _box_axes(ax1, lw=1.2)
else:
    ax1.text(0.5, 0.5, "No training data", ha="center", va="center")
    ax1.axis("off")

# ==================== (b) Targets: Partitions (left, log) & Time (right) vs N ====================
# Per model: mean partition count on the left log axis (solid blue) and mean
# inference time on a twin linear axis (dashed red), both against target_N.
# Models are told apart by marker shape only.
if targets_df.empty:
    ax2.text(0.5, 0.5, "No targets data", ha="center", va="center"); ax2.axis("off")
else:
    ax2R = ax2.twinx()

    # Models in order of first appearance, renamed "Model 1", "Model 2", ...
    uniq_models = list(dict.fromkeys(targets_df["model_id"].tolist()))
    model_name_map = {m: f"Model {i+1}" for i, m in enumerate(uniq_models)}

    # Accumulators filled by the per-model loop and used for limits/ticks below.
    handles_models = []
    xtick_vals = set()
    left_vals_all = []
    right_vals_all = []

    for mi, m in enumerate(uniq_models):
        sub = targets_df[targets_df["model_id"]==m].copy()
        # Mean and sample std of partitions and time for each target_N.
        gg = (sub.groupby("target_N", as_index=False)
                    .agg(colors_mean=("colors","mean"),
                         colors_std=("colors","std"),
                         time_mean=("time","mean"),
                         time_std=("time","std"))
                    .sort_values("target_N"))
        Ns = gg["target_N"].dropna().astype(int).to_numpy()
        xtick_vals.update(Ns.tolist())
        mk = MODEL_MARKERS[mi % len(MODEL_MARKERS)]

        left_vals_all.extend(gg["colors_mean"].to_numpy(float))
        right_vals_all.extend(gg["time_mean"].to_numpy(float))

        # A missing std (e.g. a single run) is drawn as a zero-length error bar.
        ax2.errorbar(Ns, gg["colors_mean"], yerr=gg["colors_std"].fillna(0),
                     fmt=f"-{mk}", color="tab:blue", linewidth=LINE_W, capsize=CAPSIZE,
                     markersize=MSIZE)
        ax2R.errorbar(Ns, gg["time_mean"], yerr=gg["time_std"].fillna(0),
                      fmt=f"--{mk}", color="tab:red", linewidth=LINE_W_TIME, capsize=CAPSIZE,
                      markersize=MSIZE)

        # Marker-only proxy handle for the models legend.
        handles_models.append(Line2D([0],[0], color="black", marker=mk, linestyle="None",
                                     label=model_name_map[m], markersize=MSIZE-1))

    # MUB: 2^N + 1
    # Reference curve on the left axis; its values also feed the y-limits.
    all_Ns = sorted(list(xtick_vals))
    if all_Ns:
        mub = (2**np.array(all_Ns) + 1)
        ax2.plot(all_Ns, mub, linestyle=":", linewidth=LINE_W_REF, color="black", label=r"Optimal: $2^N + 1$")
        left_vals_all.extend(mub.tolist())

    # Left axis: padded log limits. Right axis: linear, from 0 to 8 % above
    # the largest mean time.
    safe_set_ylim_log(ax2, left_vals_all, pad_top_factor=1.7, pad_bottom_factor=0.9)
    if len(right_vals_all) > 0 and np.isfinite(right_vals_all).any():
        hi_r = np.nanmax(right_vals_all)
        ax2R.set_ylim(0, hi_r*1.08 if np.isfinite(hi_r) and hi_r>0 else 1.0)

    ax2.set_xlabel("No. of Qubits (N)")
    ax2.set_ylabel("No. of Partitions")
    ax2R.set_ylabel("Inference time (s)")
    set_sparse_xticks(ax2, all_Ns, max_ticks=8)
    ax2.grid(True, which="both", linestyle=":", linewidth=0.9, zorder=0)
    ax2.margins(x=0.005, y=0.01)
    _box_axes(ax2, lw=1.2); _box_axes(ax2R, lw=1.2)

    # Metric legend (TOP-LEFT, lines ONLY — no markers)
    metric_handles = [
        Line2D([0],[0], color="tab:blue", linestyle="-",  label="Partitions (left)", linewidth=LINE_W),
        Line2D([0],[0], color="tab:red",  linestyle="--", label="Time (right)",     linewidth=LINE_W_TIME),
        Line2D([0],[0], color="black",     linestyle=":",  label=r"Optimal: $2^N + 1$",  linewidth=LINE_W_REF),
    ]
    leg_metric = ax2.legend(handles=metric_handles,
                            loc="upper left", frameon=False, ncol=1,
                            borderaxespad=0.25, borderpad=0.25,
                            handlelength=2.2, handletextpad=0.6, labelspacing=0.25)
    # Re-add the first legend as an artist so the second ax2.legend() call
    # below does not replace it.
    ax2.add_artist(leg_metric)

    # Models legend
    ax2.legend(handles=handles_models,
               loc="lower right", frameon=False, ncol=1,
               borderaxespad=0.25, borderpad=0.25,
               handlelength=1.0, handletextpad=0.4, labelspacing=0.25)

# ==================== (c) Ratio vs M (LEFT) + MOVED GNN time vs M (RIGHT) ====================
# Left axis: mean GNN/ILP partition ratio against the number of operators M,
# with standard-error bars. Right (twin, linear) axis: mean GNN time against M.
if subset_df.empty and subset_stats.empty:
    axC.text(0.5, 0.5, "No subset/ratio data", ha="center", va="center"); axC.axis("off")
else:
    # Ratio data (left axis)
    # Prefer the pre-computed stats rows; otherwise derive mean and standard
    # error from the individual subset rows.
    if not subset_stats.empty and {"target_M_used","ratio_mean","ratio_stderr"}.issubset(subset_stats.columns):
        rss = subset_stats.sort_values("target_M_used")
        Ms_ratio   = rss["target_M_used"].to_numpy(float)
        ratio_mean = rss["ratio_mean"].to_numpy(float)
        ratio_err  = rss["ratio_stderr"].fillna(0).to_numpy(float)
    else:
        comp = (subset_df.groupby("target_M_used")["ratio_gnn_over_ilp"]
                        .agg(['mean','std','count']).reset_index())
        # stderr = std / sqrt(count); count is clipped to >= 1 to avoid dividing by zero.
        comp["stderr"] = comp["std"] / np.sqrt(comp["count"].clip(lower=1))
        comp = comp.sort_values("target_M_used")
        Ms_ratio   = comp["target_M_used"].to_numpy(float)
        ratio_mean = comp["mean"].to_numpy(float)
        ratio_err  = comp["stderr"].fillna(0).to_numpy(float)

    l_ratio = axC.errorbar(Ms_ratio, ratio_mean, yerr=ratio_err, fmt="-o",
                           linewidth=LINE_W, capsize=CAPSIZE, color="tab:green",
                           markersize=MSIZE, label="GNN / ILP (partitions)")
    # Reference line at ratio 1.0.
    axC.axhline(1.0, linestyle="--", linewidth=LINE_W_REF, color="0.4")

    
    # GNN time (right axis); drawn only when the subset rows carry time_gnn.
    axCR = axC.twinx()
    if not subset_df.empty and {"target_M_used","time_gnn"}.issubset(subset_df.columns):
        agg = (subset_df.groupby("target_M_used", as_index=False)
                        .agg(time_gnn_mean=("time_gnn","mean"),
                             time_gnn_std =("time_gnn","std"))
                        .sort_values("target_M_used"))
        Ms_time = agg["target_M_used"].to_numpy(float)
        t_mean  = agg["time_gnn_mean"].to_numpy(float)
        t_std   = agg["time_gnn_std"].fillna(0).to_numpy(float)
        l_time = axCR.errorbar(Ms_time, t_mean, yerr=t_std,
                               fmt="--s", color="tab:red",
                               linewidth=LINE_W_TIME, capsize=CAPSIZE,
                               markersize=MSIZE, label="GNN time (right)")
        # Right axis limits (linear)
        if np.isfinite(t_mean).any():
            hi = np.nanmax(t_mean)
            axCR.set_ylim(0, hi*1.08 if hi>0 else 1.0)

    axC.set_xlabel("No. of Operators (M)")
    axC.set_ylabel("GNN / ILP (partitions)")
    axCR.set_ylabel("Inference time (s)")
    # Ticks are placed on the M values of the ratio curve (kept as floats).
    set_sparse_xticks(axC, Ms_ratio, max_ticks=8, integer=False)
    axC.grid(True, linestyle=":", linewidth=0.9, alpha=0.9)
    axC.margins(x=0.005, y=0.01)
    _box_axes(axC, lw=1.2); _box_axes(axCR, lw=1.2)

    # Legend for (c): combine ratio (left) and time (right)
    # Built from proxy handles, so both entries are listed even when the
    # time curve above was not drawn.
    handles_c, labels_c = [], []
    handles_c.append(Line2D([0],[0], color="tab:green", linestyle="-", marker="o",
                            linewidth=LINE_W, markersize=MSIZE, label="GNN / ILP (partitions)"))
    handles_c.append(Line2D([0],[0], color="tab:red", linestyle="--", marker="s",
                            linewidth=LINE_W_TIME, markersize=MSIZE, label="GNN time (right)"))
    axC.legend(handles=handles_c, loc="upper left", frameon=False,
               borderaxespad=0.25, borderpad=0.25,
               handlelength=1.6, handletextpad=0.4, labelspacing=0.25)

# -------------------- Titles OUTSIDE each axes, centered --------------------
def add_outside_title(ax, text, y_pad=0.012, fontsize=19):
    """Draw a bold caption centred above ``ax``, outside the axes area.

    The caption is added to the module-level ``fig`` with ``fig.text``; its
    position is derived from ``ax.get_position()``.

    Parameters
    ----------
    ax : axes whose bounding box anchors the caption.
    text : caption string.
    y_pad : vertical gap added above the top edge of the axes box.
    fontsize : caption font size.
    """
    box = ax.get_position()
    mid_x = (box.x0 + box.x1) / 2.0   # horizontal centre of the axes box
    top_y = box.y1 + y_pad            # a little above the top edge
    style = dict(ha='center', va='bottom', fontsize=fontsize, fontweight='bold')
    fig.text(mid_x, top_y, text, **style)

# Caption each of the three panels, in (a), (b), (c) order.
for _panel_ax, _caption in (
    (ax1, "(a) Training time per model"),
    (ax2, "(b) Inference: partitions & time vs N"),
    (axC, "(c) Partition ratio vs M (left) & GNN time (right)"),
):
    add_outside_title(_panel_ax, _caption)

# -------------------- Save combined 3×1 figure --------------------
# The raster copy is written at 600 dpi, the PDF as vector output; the figure
# is closed once both files are on disk.
OUT_PNG, OUT_PDF = "GNN_3x1_column_v1.png", "GNN_3x1_column_v1.pdf"
fig.savefig(OUT_PNG, bbox_inches="tight", dpi=600)
fig.savefig(OUT_PDF, bbox_inches="tight")
plt.close(fig)

print("Saved:", OUT_PNG, OUT_PDF)


In [None]:
# GRID 3×3 — Fidelity vs #settings (LINES + SHADED BANDS)


import os, re, glob
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, FormatStrFormatter

# ---------- figure styling ----------
# Global matplotlib defaults for the 3×3 fidelity grid. These are applied to
# the process-wide rcParams, so they also affect any figure created later.
plt.rcParams.update({
    # Font type 42 for PDF/PS output (TrueType embedding per the matplotlib docs).
    "pdf.fonttype": 42, "ps.fonttype": 42,
    "font.size": 17,          # base font size
    "axes.titlesize": 20,     # panel titles
    "axes.labelsize": 22,     # x/y axis labels
    "xtick.labelsize": 18,
    "ytick.labelsize": 18,
    "legend.fontsize": 20,    # shared legend
    "axes.linewidth": 1.6,
    "xtick.major.width": 1.4,
    "ytick.major.width": 1.4,
    # Ticks point into the plotting area, with minor ticks shown on both axes.
    "xtick.direction": "in",
    "ytick.direction": "in",
    "xtick.minor.visible": True,
    "ytick.minor.visible": True,
})

# -------------------- Locate CSVs --------------------
# Candidate inputs: two well-known file names plus glob matches, all relative
# to the current working directory.
candidates = [
    "QST_scheduling.csv",
    "standard_QST.csv",
    *sorted(glob.glob("*GRID2x3*FIXEDMARKERS*DSATUR*SC*.csv")),
    *sorted(glob.glob("*standard*QST*.csv")),
    *sorted(glob.glob("*Standard*QST*.csv")),
]

# Keep only files that exist, in first-seen order, and drop repeats. The
# literal "standard_QST.csv" is also matched by the "*standard*QST*.csv" glob,
# so without de-duplication the same file would be read and concatenated twice.
csvs = []
_seen_paths = set()
for p in candidates:
    if not os.path.exists(p):
        continue
    _path_key = os.path.normcase(os.path.realpath(p))
    if _path_key in _seen_paths:
        continue
    _seen_paths.add(_path_key)
    csvs.append(p)
if not csvs:
    # Only the working directory is searched, so the message says exactly that.
    raise FileNotFoundError(
        "No CSVs found. Place your main CSV and StandardQST CSV next to this script "
        "(in the current working directory)."
    )

# Load every candidate that carries the columns the plots need; unreadable or
# incomplete files are reported and skipped (best effort), not fatal.
_required_cols = {"method", "system", "K", "fidelity_mean"}
dfs = []
for p in csvs:
    try:
        dfp = pd.read_csv(p)
    except Exception as e:
        print(f"[WARN] Could not read {p}: {e}")
        continue
    if not _required_cols.issubset(dfp.columns):
        print(f"[WARN] Skipping {p}: missing required columns "
              f"{sorted(_required_cols - set(dfp.columns))}")
        continue
    dfp["__source__"] = os.path.basename(p)  # remember which file each row came from
    dfs.append(dfp)
if not dfs:
    raise RuntimeError("No readable CSV with required columns: method, system, K, fidelity_mean.")

# Columns missing from some files are filled with NaN; original column order is kept.
df = pd.concat(dfs, ignore_index=True, sort=False)

# -------------------- Normalize --------------------
def norm_text(x):
    """Return ``x`` converted to ``str`` with surrounding whitespace removed."""
    return str(x).strip()


def norm_system(s):
    """Map a free-form system name onto a canonical label.

    The input is stripped, lower-cased and has its spaces removed. Known
    spellings collapse to ``"qubit"`` or ``"qutrit"``; anything containing
    ``"hybrid"`` becomes ``"hybrid"``; every other value is returned in its
    lower-cased, space-free form.
    """
    key = norm_text(s).lower().replace(" ", "")
    if key in {"q", "qubit", "qubits"}:
        return "qubit"
    if key in {"t", "qutrit", "qutrits"}:
        return "qutrit"
    return "hybrid" if "hybrid" in key else norm_text(key)

# Text columns: coerce to stripped strings (only those actually present).
for c in ("method", "system", "hybrid_label"):
    if c not in df.columns:
        continue
    df[c] = df[c].map(norm_text)

# Numeric columns: unparsable entries become NaN rather than raising.
for c in ("N", "K", "fidelity_mean", "fidelity_std", "setting_size", "n_qubits", "n_qutrits"):
    if c not in df.columns:
        continue
    df[c] = pd.to_numeric(df[c], errors="coerce")

# Collapse the various spellings of the system name to canonical labels.
if "system" in df.columns:
    df["system"] = df["system"].map(norm_system)

def normalize_hybrid_label(row):
    """Return a canonical ``"<a>q ⊗ <b>t"`` label for a hybrid-system row.

    ``row`` only needs a ``.get`` method (a Series or a dict). The text in
    ``hybrid_label`` is lower-cased, stripped of spaces and underscores, and
    the separators ``tensor``, ``*`` and ``x`` are rewritten to ``⊗``. A
    ``<digits>q ... ⊗ ... <digits>t`` pattern is tried first, then the same
    pattern without the separator. If neither matches, the label is built from
    ``n_qubits`` / ``n_qutrits`` when both are non-null; otherwise the stripped
    original label is returned unchanged.
    """
    raw = str(row.get("hybrid_label", "")).strip()

    compact = raw.lower().replace(" ", "").replace("_", "")
    for separator in ("tensor", "*", "x"):
        compact = compact.replace(separator, "⊗")

    # Strict pattern (with the tensor sign) takes priority over the loose one.
    for pattern in (r"(\d+)\s*q.*?⊗.*?(\d+)\s*t", r"(\d+)\s*q.*?(\d+)\s*t"):
        hit = re.search(pattern, compact)
        if hit:
            return f"{int(hit.group(1))}q ⊗ {int(hit.group(2))}t"

    n_q = row.get("n_qubits", np.nan)
    n_t = row.get("n_qutrits", np.nan)
    if pd.notna(n_q) and pd.notna(n_t):
        return f"{int(n_q)}q ⊗ {int(n_t)}t"
    return raw

# Rewrite the label of every hybrid row into its canonical form.
if "system" in df.columns:
    mask_h = df["system"].eq("hybrid")
    if mask_h.any():
        df.loc[mask_h, "hybrid_label"] = df.loc[mask_h].apply(
            normalize_hybrid_label, axis=1
        )

# Rows lacking any of the plotting essentials are discarded.
required = {"method", "system", "K", "fidelity_mean"}
df = df.dropna(subset=list(required))

# -------------------- Method order (canonical + extras) --------------------
# Canonical methods come first (in this fixed order, when present in the data);
# every other method found in the data follows in order of first appearance.
preferred = ["DSATUR", "SC", "RLF", "GNN", "ILP"]
_methods_in_data = df["method"].dropna().unique()
_methods_in_data_set = set(_methods_in_data)
present = [m for m in preferred if m in _methods_in_data_set]
extras = [m for m in _methods_in_data if m not in present]  # includes StandardQST, SoTA, etc.
methods_order = present + list(dict.fromkeys(extras))

# Baseline/standard aliases 
baseline_aliases = [
    "SOTA", "SoTA", "Standard", "StdQST",
    "StandardQST", "Baseline", "Ref", "Reference",
]


def is_baseline_method(name: str) -> bool:
    """Return True when ``name`` contains any baseline alias.

    Matching is a case-insensitive substring test, so an alias anywhere inside
    the method name counts (for example ``"ref"`` inside ``"Preferred"``).
    """
    lowered = name.lower()
    for alias in baseline_aliases:
        if alias.lower() in lowered:
            return True
    return False

# Consistent colors across panels
# Colours come from the active property cycle; "C0".."C11" are the fallback
# when the cycle defines no 'color' key.
base_colors = plt.rcParams['axes.prop_cycle'].by_key().get('color', [f"C{i}" for i in range(12)])
# One colour per method, assigned by position in methods_order and wrapping
# around when there are more methods than colours.
method_color = {m: base_colors[i % len(base_colors)] for i, m in enumerate(methods_order)}

# ----- Legend label overrides -----
# Maps a method name to the text shown for it in the legend; methods not
# listed here are labelled with their own name.
LEGEND_LABEL_OVERRIDE = {
    "StandardQST": "StandardQST (Qubits only)"
}

# Force only the LINE color for certain methods
# (the shaded band keeps the colour from method_color).
LINE_COLOR_OVERRIDE = {
    "StandardQST": "k",   # black line
}

# -------------------- Tick helper: evenly spaced integer ticks --------------------
def set_even_integer_ticks(ax, xvals, nticks=5):
    """
    Choose a constant integer step and place ticks on that grid so they are
    visually evenly spaced.

    Parameters
    ----------
    ax : axes whose x-limits, x-ticks and x tick formatter are set.
    xvals : integer-like x values; the limits become ``[min, max]`` of these.
        An empty input leaves ``ax`` untouched.
    nticks : approximate number of ticks wanted across the span.

    The step is the "nice" value (1, 2, 3, 5, 10, 20, 25, 50, 100, 200, 250,
    500, 1000, ...) closest to ``span / (nticks - 1)``. The table is extended
    by decades as needed so that wide ranges do not end up with an unbounded
    number of ticks; for rough steps up to 150 the choice is the same as with
    the fixed 1..100 table.
    """
    xvals = np.asarray(xvals, dtype=int)
    if xvals.size == 0:
        return
    xmin, xmax = int(np.min(xvals)), int(np.max(xvals))
    if xmin == xmax:
        # Degenerate range: a single tick at the only value.
        ax.set_xlim(xmin, xmax)
        ax.set_xticks([xmin])
        ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))
        return

    span = xmax - xmin
    rough = max(1, int(round(span / max(1, nticks - 1))))

    # Base table, grown decade by decade until it brackets the rough step.
    nice_steps = [1, 2, 3, 5, 10, 20, 25, 50, 100]
    decade = 100
    while nice_steps[-1] < rough:
        nice_steps.extend([2 * decade, 5 * decade // 2, 5 * decade, 10 * decade])
        decade *= 10
    nice_steps = np.array(nice_steps)
    step = int(nice_steps[np.argmin(np.abs(nice_steps - rough))])

    # lock ticks to the integer grid defined by "step"
    t0 = int(np.ceil(xmin / step) * step)
    t1 = int(np.floor(xmax / step) * step)
    if t1 < t0:  # edge case when span < step
        t0 = xmin
        t1 = xmax
        step = max(1, span)  # 1 tick interval across full span

    ticks = list(range(t0, t1 + 1, step))

    ax.set_xlim(xmin, xmax)
    ax.set_xticks(ticks)
    ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))

# -------------------- Build series (DEDUP by K; aggregate; sanitize) --------------------
def build_series(df_panel):
    """
    Build clean series: unique, increasing K per method with aggregated stats.

    For each method in the module-level ``methods_order`` that has rows in
    ``df_panel``, rows without ``K`` or ``fidelity_mean`` are dropped and the
    remainder is collapsed to one row per ``K``:
    - fidelity_mean: average across duplicates
    - fidelity_std: RMS across duplicates, ignoring NaN; NaN results and a
      missing ``fidelity_std`` column both end up as 0
    - setting_size: first non-null value per K (only if the column exists)

    Returns
    -------
    (series, Ks)
        ``series`` is a list of ``(method, K, mean, std, sizes)`` tuples with
        ``K`` an int array sorted ascending and ``sizes`` either an array or
        ``None`` when no size information is available. ``Ks`` is the sorted
        union of all K values (an empty int array when nothing was built).
    """
    def _rms_ignoring_nan(values):
        # Same value as sqrt(nanmean(x**2)), but an all-NaN (or empty) group
        # yields NaN directly instead of triggering numpy's empty-slice warning.
        squared = np.square(np.asarray(values, dtype=float))
        if np.isnan(squared).all():
            return float("nan")
        return float(np.sqrt(np.nanmean(squared)))

    series, allK = [], []
    for m in methods_order:
        sub = df_panel[df_panel["method"] == m]
        if sub.empty:
            continue

        sub = sub.dropna(subset=["K", "fidelity_mean"])

        # fidelity_std is optional in the input files; without this the named
        # aggregation below would raise KeyError.
        if "fidelity_std" not in sub.columns:
            sub = sub.assign(fidelity_std=np.nan)

        # Group only by K (single panel => N/hybrid_label are constant)
        agg = sub.groupby(["K"], as_index=False).agg(
            fidelity_mean=("fidelity_mean", "mean"),
            fidelity_std=("fidelity_std", _rms_ignoring_nan),
        )

        # Attach setting_size (choose first non-null per K)
        if "setting_size" in sub.columns:
            ss = sub.groupby("K")["setting_size"].agg(
                lambda s: s.dropna().iloc[0] if len(s.dropna()) else np.nan
            )
            agg = agg.merge(ss.reset_index(), on="K", how="left")

        agg = agg.sort_values("K")
        K  = agg["K"].to_numpy(int)
        mu = agg["fidelity_mean"].to_numpy(float)
        sd = np.nan_to_num(agg["fidelity_std"].to_numpy(float), nan=0.0)

        sz = None
        if "setting_size" in agg.columns:
            sz = agg["setting_size"].to_numpy()
            if np.all(np.isnan(sz)):
                sz = None
            else:
                try:
                    sz = sz.astype(int)
                except (TypeError, ValueError):
                    # Leave the sizes in their original dtype if they cannot
                    # be converted to integers.
                    pass

        series.append((m, K, mu, sd, sz))
        allK.append(K)

    Ks = np.unique(np.concatenate(allK)) if allK else np.array([], dtype=int)
    return series, Ks

# ==================== INSET: larger, clearer + exact main-x alignment ====================
# Inset placement boxes passed to Axes.inset_axes as 4-tuples.
# Presumably (x0, y0, width, height) in parent-axes fractions — confirm against
# the matplotlib inset_axes documentation.
INSET_BOX_DEFAULT = (0.54, 0.06, 0.44, 0.50)   # used for non-qutrit panels
INSET_BOX_QUTRIT  = (0.54, 0.06, 0.44, 0.50)   # currently identical to the default box

# Styling applied to every inset.
INSET_LINEWIDTH = 1.8   # width of the size curves
INSET_TICKSIZE  = 13    # y tick label size
INSET_LABELSIZE = 14    # y axis label size
INSET_SPINE_W   = 1.3   # inset frame line width

def add_inset(ax_main, inset_triplets, inset_qutrit=False, label_y=False, inset_box=None):
    """
    Add an inset to ``ax_main`` showing the sizes of each series against K.

    Parameters
    ----------
    ax_main : parent axes; its x-limits and x-ticks must already be final
        because they are copied into the inset.
    inset_triplets : list of (line_handle, K_array, sizes_array). Entries whose
        sizes are ``None`` or empty are skipped; each curve takes the colour of
        its line handle.
    inset_qutrit : choose ``INSET_BOX_QUTRIT`` instead of ``INSET_BOX_DEFAULT``
        when ``inset_box`` is not given.
    label_y : add the "No. of ops" y-label to the inset.
    inset_box : explicit placement box, overriding the two defaults.

    The inset:
      - plots sizes vs the SAME K values as the main panel
      - copies x-lims and x-ticks from the main axis
      - hides x tick labels, but keeps the x-grid to match the main panel
    """
    if inset_box is None:
        inset_box = INSET_BOX_QUTRIT if inset_qutrit else INSET_BOX_DEFAULT

    ax_ins = ax_main.inset_axes(inset_box, facecolor='white', zorder=5)

    # Plot series and track y-extent (never below 1 so the axis stays sensible)
    max_y = 1
    for ln, Kvals, sizes in inset_triplets:
        if sizes is None or len(sizes) == 0:
            continue
        # If sizes length doesn't match K, fall back to index to avoid crash
        if Kvals is None or len(Kvals) != len(sizes):
            x = np.arange(1, len(sizes) + 1, dtype=int)
        else:
            x = np.asarray(Kvals, dtype=float)

        ax_ins.plot(x, sizes, linewidth=INSET_LINEWIDTH, color=ln.get_color(), alpha=1.0)

        # Use the largest finite size only: int(np.max(...)) would raise on NaN.
        as_float = np.asarray(sizes, dtype=float)
        finite = as_float[np.isfinite(as_float)]
        if finite.size:
            max_y = max(max_y, int(finite.max()))

    # Copy x-range and ticks from main axis for perfect alignment
    xmin, xmax = ax_main.get_xlim()
    ax_ins.set_xlim(xmin, xmax)

    main_ticks = ax_main.get_xticks()
    ax_ins.set_xticks(main_ticks)

    # Grid: same x major grid style as main
    ax_ins.grid(True, axis='x', which='major', alpha=0.30, linestyle="--", linewidth=1.0)

    # Hide x numbers in inset; keep ticks
    ax_ins.set_xticklabels([])
    ax_ins.tick_params(axis='x', which='both', labelbottom=False)

    # Y ticks: coarse and readable (max_y >= 1 here, so the top is always positive)
    ax_ins.set_ylim(0, max_y * 1.08)
    ax_ins.yaxis.set_major_locator(MaxNLocator(integer=True, nbins=4))
    ax_ins.tick_params(axis='y', which='major', labelsize=INSET_TICKSIZE, length=3.4, pad=8)

    if label_y:
        ax_ins.set_ylabel("No. of ops", fontsize=INSET_LABELSIZE, labelpad=8)

    for sp in ax_ins.spines.values():
        sp.set_linewidth(INSET_SPINE_W)
        sp.set_alpha(0.95)

# -------------------- Plot each panel (LINEAR y; lines + fixed bands) --------------------
def plot_panel(ax, df_panel, title, panel_key, add_ylabel=False, add_xlabel=False,
               inset_qutrit=False, inset_label_y=False):
    """Draw one fidelity-vs-K panel (mean lines, ±std bands, size inset) on ``ax``.

    Parameters
    ----------
    ax : axes to draw on.
    df_panel : rows for this panel; passed to ``build_series``.
    title : panel title text.
    panel_key : accepted but not used inside this function.
    add_ylabel, add_xlabel : whether to label the y / x axis.
    inset_qutrit, inset_label_y : forwarded to ``add_inset``.

    Returns
    -------
    list of ``(line, K, sizes)`` tuples, one per plotted method, or ``[]``
    when ``build_series`` produced nothing.
    """
    series, Ks = build_series(df_panel)
    if not series:
        # Nothing to plot: leave a titled, tick-less frame.
        ax.set_title(title, fontsize=20, fontweight="bold", pad=10)  
        ax.set_xlim(0,1); ax.set_ylim(0.8,1.02)
        ax.set_xticks([]); ax.set_yticks([])
        for sp in ax.spines.values(): sp.set_linewidth(1.4)
        return []

    line_handles = []
    for (m, K, mu, sd, sz) in series:
        band_col = method_color.get(m, None)                      # band colour from the shared map
        line_col = LINE_COLOR_OVERRIDE.get(m, band_col)           # line may be overridden per method

        # Sanitize shaded band bounds
        lo = mu - sd
        hi = mu + sd
        lo, hi = np.minimum(lo, hi), np.maximum(lo, hi)  # keep lo <= hi even for negative sd
        lo = np.clip(lo, 0.0, 1.05)
        hi = np.clip(hi, 0.0, 1.05)

        # Band behind line
        ax.fill_between(K, lo, hi, color=band_col, alpha=0.18, linewidth=0, zorder=1)

        # Line on top (override color if specified)
        (ln,) = ax.plot(K, mu, linewidth=1.9, color=line_col, label=m, zorder=2)
        line_handles.append((ln, K, sz))  # keep K and sizes for the inset

    # LINEAR y-scale and grid
    ax.set_ylim(0.80, 1.02)
    # x-limits and ticks are fixed here, before add_inset copies them below.
    if Ks.size: set_even_integer_ticks(ax, Ks, nticks=5)
    ax.grid(True, alpha=0.30, linestyle="--", linewidth=1.0)

    if add_ylabel: ax.set_ylabel("Fidelity", labelpad=14)
    if add_xlabel: ax.set_xlabel("No. of measurement settings", labelpad=12)
    ax.set_title(title, fontsize=20, fontweight="bold", pad=10) 

    for sp in ax.spines.values(): sp.set_linewidth(1.4)

    # Inset of "# of ops"
    # Methods recognised as baselines are left out of the inset.
    inset_triplets = []
    for (ln, K, sz) in line_handles:
        m = ln.get_label() or "unknown"
        if is_baseline_method(m):
            continue
        inset_triplets.append((ln, K, sz))
    add_inset(ax, inset_triplets, inset_qutrit=inset_qutrit, label_y=inset_label_y)

    ax.tick_params(axis='x', labelsize=18, pad=7)
    ax.tick_params(axis='y', labelsize=18, pad=7)
    for lbl in ax.get_xticklabels(): lbl.set_ha('center')

    return line_handles

# -------------------- Prepare subsets & 3×3 layout --------------------
# Independent copies of the rows belonging to each system type.
qubit_df, qutrit_df, hybrid_df = (
    df[df["system"] == _system].copy() for _system in ("qubit", "qutrit", "hybrid")
)


def sel_qubits(n):
    """Rows of ``qubit_df`` whose ``N`` equals ``n``."""
    return qubit_df[qubit_df["N"] == n]


def sel_qutrits(n):
    """Rows of ``qutrit_df`` whose ``N`` equals ``n``."""
    return qutrit_df[qutrit_df["N"] == n]


def sel_hybrid(lbl):
    """Rows of ``hybrid_df`` whose ``hybrid_label`` equals ``lbl``."""
    return hybrid_df[hybrid_df["hybrid_label"] == lbl]


# Panel order, row by row: (system kind, N or hybrid label).
layout = [
    ("qubit", 2), ("qubit", 3), ("qubit", 4),                 # Row 1
    ("qubit", 5), ("qutrit", 2), ("qutrit", 3),               # Row 2
    ("hybrid", "1q ⊗ 1t"), ("hybrid", "1q ⊗ 2t"), ("hybrid", "2q ⊗ 1t"),  # Row 3
]
letters = list("abcdefghi")


def _panel_rows_and_title(kind, key):
    """Return (rows, base title) for one layout entry; non-qubit/qutrit kinds are hybrid."""
    if kind == "qubit":
        return sel_qubits(key), f"{key} Qubits"
    if kind == "qutrit":
        return sel_qutrits(key), f"{key} Qutrits"
    return sel_hybrid(key), key


# One (kind, key, rows, "(letter) title") record per panel.
panels = []
for letter, (kind, key) in zip(letters, layout):
    sub, base_title = _panel_rows_and_title(kind, key)
    panels.append((kind, key, sub, f"({letter}) {base_title}"))

# -------------------- Figure --------------------
fig_w, fig_h = 20.0, 14.5
fig, axes = plt.subplots(nrows=3, ncols=3, figsize=(fig_w, fig_h), squeeze=False)

# Methods that actually appear in at least one drawn panel (drives the legend).
global_methods_present = set()
for idx, ax in enumerate(axes.flat):
    kind, key, sub, title = panels[idx]
    r, c = divmod(idx, 3)   # grid row / column of this panel

    if sub.empty:
        # Placeholder frame for panels without any rows.
        ax.set_title(title, fontsize=20, fontweight="bold", pad=10)
        ax.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=16)
        ax.set_xticks([])
        ax.set_yticks([])
        for sp in ax.spines.values():
            sp.set_linewidth(1.4)
        continue

    plot_panel(
        ax, sub, title, (kind, key),
        add_ylabel=(c == 0),                     # first column only
        add_xlabel=(r == 2),                     # bottom row only
        inset_qutrit=(r == 1 and c in (1, 2)),   # qutrit panels
        inset_label_y=(c == 0),                  # all first-column insets
    )
    global_methods_present.update(sub["method"].dropna().unique())

# -------------------- Global legend (lines only) --------------------
# Proxy artists for the legend are built as stand-alone Line2D objects.
# Creating them with plt.plot([], []) would attach empty lines to whichever
# axes happens to be current (the last panel) purely as a side effect.
from matplotlib.lines import Line2D

# Legend entries follow methods_order, restricted to methods that were drawn.
legend_labels = [m for m in methods_order if m in global_methods_present]
legend_handles = [
    Line2D([], [], label=m, linewidth=2.0,   # line only, no marker
           color=LINE_COLOR_OVERRIDE.get(m, method_color.get(m)))
    for m in legend_labels
]
# Display text, with per-method overrides applied.
legend_texts = [LEGEND_LABEL_OVERRIDE.get(m, m) for m in legend_labels]

ncol = min(6, max(1, len(legend_labels)))
fig.legend(
    handles=legend_handles, labels=legend_texts,
    loc="lower center", bbox_to_anchor=(0.5, -0.01),
    ncol=ncol, frameon=True, fancybox=True, framealpha=0.98,
    borderpad=0.45, labelspacing=0.50, handlelength=3.2, columnspacing=1.2,
    prop={"size": 20}
)

# -------------------- Layout & Save --------------------
# Automatic layout first, then explicit margins and gaps; the explicit values
# leave room at the bottom for the shared legend.
fig.tight_layout()
fig.subplots_adjust(
    left=0.07, right=0.995, top=0.97, bottom=0.13,
    wspace=0.28, hspace=0.36,
)

# Vector PDF plus a 600 dpi PNG of the same figure.
OUT_PDF = "GRID_3x3_Fidelity_AllMethods_EXACT_WithStandardQST_v93_LINEAR_lines_MATCHED_INSETS.pdf"
OUT_PNG = "GRID_3x3_Fidelity_AllMethods_EXACT_WithStandardQST_v93_LINEAR_lines_MATCHED_INSETS.png"
fig.savefig(OUT_PDF, bbox_inches="tight")
fig.savefig(OUT_PNG, bbox_inches="tight", dpi=600)
print("Saved:", OUT_PDF, OUT_PNG)


In [None]:
# Dual-axis plot: compact layout, vertical legend (inside bottom-center),
# colored y-axes/spines/grids, and a black annotation on the first blue datapoint.
# - Uses ONLY the provided CSV files.
# - Shows the figure inline and also saves PDF/PNG.

import os
import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# ---------- Paths (edit if needed) ----------
CSV_2Q = "tradeoff_N2.csv"
CSV_3Q = "tradeoff_3q_depth_vs_K.csv"

# ---------- Strict file checks ----------
# Fail immediately, naming the first input that is missing.
for _csv_path in (CSV_2Q, CSV_3Q):
    if not os.path.exists(_csv_path):
        raise FileNotFoundError(f"Missing CSV: {_csv_path}")

# ---------- Aesthetics ----------
mpl.rcParams.update({
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.labelsize": 11,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 9,
    "axes.linewidth": 1.0,
    "xtick.major.width": 0.9,
    "ytick.major.width": 0.9,
    "xtick.minor.width": 0.7,
    "ytick.minor.width": 0.7,
    "xtick.direction": "in",
    "ytick.direction": "in",
    "xtick.minor.visible": True,
    "ytick.minor.visible": True,
})

# ---------- Column helpers ----------
def find_col(df, names):
    """Return the name of the first column of ``df`` matching one of ``names``.

    Matching is case-insensitive and done in two passes over ``names`` (in
    order): first an exact name match, then a substring match against each
    column. Raises ``KeyError`` when nothing matches.
    """
    by_lower = {col.lower(): col for col in df.columns}

    # Pass 1: exact (case-insensitive) name.
    exact = next((by_lower[n.lower()] for n in names if n.lower() in by_lower), None)
    if exact is not None:
        return exact

    # Pass 2: candidate name contained anywhere in a column name.
    for wanted in names:
        needle = wanted.lower()
        for col in df.columns:
            if needle in col.lower():
                return col

    raise KeyError(f"Could not find any of {names} in columns: {list(df.columns)}")

def clean(df, xcol, ycol):
    """Return the two columns as numbers, without NaN rows, sorted by ``xcol``.

    Values that cannot be parsed as numbers become NaN and their rows are
    dropped; ``df`` itself is not modified.
    """
    numeric = df[[xcol, ycol]].copy()
    for col in (xcol, ycol):
        numeric[col] = pd.to_numeric(numeric[col], errors="coerce")
    numeric = numeric.dropna()
    return numeric.sort_values(xcol)

# ---------- Load & prepare ----------
df2, df3 = pd.read_csv(CSV_2Q), pd.read_csv(CSV_3Q)

# Resolve the x (circuit count) and y (depth) columns of each file; candidate
# names are tried in the listed order.
x2 = find_col(df2, ("B_used", "K_settings", "K", "num_circuits", "partitions"))
y2 = find_col(df2, ("max_CNOTs_per_qubit", "max_per_qubit_2Q_depth", "twoq_depth", "depth"))
x3 = find_col(df3, ("K_settings", "B_used", "K", "num_circuits", "partitions"))
y3 = find_col(df3, ("max_per_qubit_2Q_depth", "max_CNOTs_per_qubit", "twoq_depth", "depth"))

# One point per x value: the minimum y among rows sharing that x.
a2 = (
    clean(df2, x2, y2)
    .groupby(x2, as_index=False)[y2]
    .min()
    .sort_values(x2)
)
a3 = (
    clean(df3, x3, y3)
    .groupby(x3, as_index=False)[y3]
    .min()
    .sort_values(x3)
)

# ---------- Colors ----------
# The left axis (2-qubit data) is drawn in blue, the right axis (3-qubit data) in red.
left_color, right_color = "tab:blue", "tab:red"

# ---------- Figure ----------
fig, ax_left = plt.subplots(figsize=(5.2, 3.6))
fig.subplots_adjust(left=0.11, right=0.89, bottom=0.20, top=0.90)

# Left Y (2-qubit)
line2, = ax_left.plot(a2[x2], a2[y2], ls="solid", marker="o", lw=1.0, ms=3.4,
                      color=left_color, label="Max depth of circuit (2 qubits)")
ax_left.set_ylabel("Circuit Depth (No of 2-Qubit gates)", color=left_color, labelpad=8)
ax_left.tick_params(axis="y", colors=left_color, pad=3)
ax_left.spines["left"].set_color(left_color)

# X axis
ax_left.set_xlabel("No of Quantum Circuits (Unitaries)", labelpad=8)
ax_left.set_title("QST Measurement Cost Trade-off", pad=8, fontsize=10)

# Left Y limits
# Pad by 12% of the data range (at least 0.2); a flat series gets ±0.5.
y2_min, y2_max = a2[y2].min(), a2[y2].max()
y2_pad = max(0.2, (y2_max - y2_min) * 0.12) if y2_max > y2_min else 0.5
ax_left.set_ylim(y2_min - y2_pad, y2_max + y2_pad)
# Restrict major ticks to integer positions.
ax_left.yaxis.get_major_locator().set_params(integer=True)

# Right Y (3-qubit)
# twinx() gives a second y-axis sharing the same x-axis.
ax_right = ax_left.twinx()
line3, = ax_right.plot(a3[x3], a3[y3], ls="solid", marker="o", lw=1.0, ms=3.4,
                       color=right_color, label="Max depth of circuit (3 qubits)")
ax_right.set_ylabel("Circuit Depth (No of 2-Qubit gates)", color=right_color, labelpad=8)
ax_right.tick_params(axis="y", colors=right_color, pad=3)
ax_right.spines["right"].set_color(right_color)

# Right Y limits
# Same scheme as the left axis, with a minimum pad of 0.4.
y3_min, y3_max = a3[y3].min(), a3[y3].max()
y3_pad = max(0.4, (y3_max - y3_min) * 0.12) if y3_max > y3_min else 0.5
ax_right.set_ylim(y3_min - y3_pad, y3_max + y3_pad)
ax_right.yaxis.get_major_locator().set_params(integer=True)

# X limits
# Cover both data sets, padded by 5% of the combined range (at least 1).
x_min = min(a2[x2].min(), a3[x3].min())
x_max = max(a2[x2].max(), a3[x3].max())
x_pad = max(1, int(round((x_max - x_min) * 0.05))) if (x_max - x_min) > 0 else 1
ax_left.set_xlim(x_min - x_pad, x_max + x_pad)

# Colored horizontal grids
# Each y-axis gets its own faint horizontal grid in the matching colour.
ax_left.grid(False); ax_right.grid(False)
ax_left.yaxis.grid(True, which="major", color=left_color,  alpha=0.10, ls="--")
ax_right.yaxis.grid(True, which="major", color=right_color, alpha=0.10, ls="--")

# Spine widths
for sp in ax_left.spines.values(): sp.set_linewidth(1.0)
for sp in ax_right.spines.values(): sp.set_linewidth(1.0)

# Annotation (first blue datapoint)
# a2 is sorted by x, so this is the 2-qubit point with the smallest x value.
first_x, first_y = a2[x2].iloc[0], a2[y2].iloc[0]
ax_left.annotate(
    "Data Point of the Circuit",
    xy=(first_x, first_y), xytext=(10, -20), textcoords="offset points",
    color="black", fontsize=9, ha="left", va="bottom",
    arrowprops=dict(arrowstyle="->", color="black", lw=0.9, shrinkA=0, shrinkB=0),
    bbox=dict(boxstyle="round,pad=0.30", fc="white", ec="black", alpha=0.80)
)

# Legend (stacked, inside bottom-center)
# Both lines are passed explicitly so that a single legend on the left axes
# also lists the line drawn on the right axes.
lines = [line2, line3]; labels = [l.get_label() for l in lines]
leg = ax_left.legend(lines, labels, frameon=True, fancybox=True, framealpha=0.92,
                     ncol=1, loc="lower center", bbox_to_anchor=(0.5, 0.06),
                     borderaxespad=0.4, handlelength=1.6, handletextpad=0.5, labelspacing=0.35)
leg.get_frame().set_edgecolor("0.7"); leg.get_frame().set_linewidth(0.8); leg.get_frame().set_facecolor("white")
# Legend text takes the colour of the axis it refers to.
for text, col in zip(leg.get_texts(), [left_color, right_color]): text.set_color(col)

# Save + show inline
# NOTE(review): earlier cells select the non-interactive "Agg" backend; if that
# is still active in this session, plt.show() may not display anything — confirm.
out_pdf = "QST_tradeoff.pdf"
out_png = "QST_tradeoff.png"
fig.savefig(out_pdf, bbox_inches="tight")
fig.savefig(out_png, dpi=600, bbox_inches="tight")
plt.show()
print(f"Saved: {out_pdf}, {out_png}")
