# Draw

In [None]:
# Version used to plot figure 4 (time-based x-axes are truncated to the fastest run)

import os, pickle, glob
import numpy as np
import matplotlib.pyplot as plt

# ────────────────────────────────────────────────────────────────────────────
# 0) X-axis mode: "epoch" | "bp_total" | "exbp_total" | "time" | "gpu_seconds"
# ────────────────────────────────────────────────────────────────────────────
# Selects which per-run series is used as the x-axis of every figure.
# pick_x_series maps each mode to a key of the loaded run dicts:
#   "time"        -> "wall_seconds"
#   "bp_total"    -> "cum_bp_eq_total"
#   "exbp_total"  -> "cum_ex_bp_total"
#   "gpu_seconds" -> "gpu_seconds" (cumulated, with a leading 0)
#   "epoch"       -> no key; the point index is used
# The value is also embedded in the output file name (<metric>_<X_MODE>.pdf).
# To switch modes, uncomment one alternative; the last assignment wins.
X_MODE = "time"
#X_MODE = "bp_total"
#X_MODE = "exbp_total"
#X_MODE = "epoch"

#X_MODE = "gpu_seconds"

# ────────────────────────────────────────────────────────────────────────────
# 1) CURVE SPECS – add pattern-only placeholders for new methods
# ────────────────────────────────────────────────────────────────────────────

# Directory prefix prepended to every pattern below (note the trailing slash,
# because the path is built by plain string concatenation).
folder = "_release_data/"
#folder = "saved_data/"


# One dict per plotted curve:
#   "pattern"   - glob pattern handed to load_many; all matching pickles are
#                 averaged into a single curve
#   "label"     - legend text
#   "color", "marker", "linestyle" - passed straight to ax.plot
# Only "pattern" is mandatory; ensure_style fills in any missing style key.
# The percentage in each comment is presumably the final test accuracy of that
# method -- TODO confirm against the experiment logs.
CURVES = [
    # 1) EF21-SGDM (74.66%)
    {"pattern": folder + "resnet18_cifar10_EF21_SGDM_topk-0,1_lr-0,1_eta-0,1_p-None_q-None.pickle",
     "label":"EF21-SGDM","color":"orange","marker":"v","linestyle":"-"},
    # 2) ECONTROL (76.38%)
    {"pattern": folder + "resnet18_cifar10_ECONTROL_topk-0,1_lr-1,0_eta-0,1_p-None_q-None.pickle",
     "label":"ECONTROL","color":"yellowgreen","marker":"p","linestyle":"-"},
    # 3) EF21 (77.50%)
    {"pattern": folder + "resnet18_cifar10_EF21_topk-0,1_lr-1,0_eta-None_p-None_q-None.pickle",
     "label":"EF21","color":"blue","marker":"o","linestyle":"-"},
    # 4) EF21-MVR_2b (79.20%)
    {"pattern": folder + "resnet18_cifar10_EF21_MVR_NORM_topk-0,1_lr-0,1_eta-1,0_p-None_q-0,67.pickle",
     "label":r"$\|\text{EF21-MVR}\|$","color":"#8c564b","marker":"<","linestyle":"-"},
    # 5) EF21-RHM (81.48%)
    {"pattern": folder + "resnet18_cifar10_EF21_RHM_NORM_topk-0,1_lr-0,1_eta-1,0_p-None_q-0,67.pickle",
     "label":r"$\|\text{EF21-RHM}\|$","color":"#e377c2","marker":">","linestyle":"-"},
    # 6) ||EF21-SGDM|| (NORM) (82.66%)
    {"pattern": folder + "resnet18_cifar10_EF21_SGDM_NORM_topk-0,1_lr-0,1_eta-1,0_p-None_q-0,5.pickle",
     "label":r"$\|\text{EF21-SGDM}\|$","color":"darkgreen","marker":"s","linestyle":"-"},
    # 7) ||EF21-IGT|| (83.28%)
    {"pattern": folder + "resnet18_cifar10_EF21_IGT_NORM_topk-0,1_lr-0,1_eta-1,0_p-None_q-0,57.pickle",
     "label":r"$\|\text{EF21-IGT}\|$","color":"#9467bd","marker":"^","linestyle":"-"},
    # 8) ||EF21-HM|| (SOM_XT2) (84.32%)
    {"pattern": folder + "resnet18_cifar10_EF21_HM_NORM_topk-0,1_lr-0,1_eta-1,0_p-None_q-0,67.pickle",
     "label":r"$\|\text{EF21-HM}\|$","color":"red","marker":"*","linestyle":"-"},
]




# ────────────────────────────────────────────────────────────────────────────
# 2) GLOBAL STYLE
# ────────────────────────────────────────────────────────────────────────────
# Shared sizes: one font size for every text element, one marker size per curve.
ARG = {"label_fontsize": 20, "marker_size": 15}

# rcParams entries that all receive the shared label font size.
_FONT_SIZE_KEYS = (
    "xtick.labelsize",
    "ytick.labelsize",
    "legend.fontsize",
    "axes.titlesize",
    "axes.labelsize",
)

plt.rcParams.update({
    "lines.linewidth": 2,
    **{rc_key: ARG["label_fontsize"] for rc_key in _FONT_SIZE_KEYS},
    "figure.figsize": [10, 8],
    # Font type 42 selects TrueType embedding in the PDF/PS backends.
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    # Labels are rendered with matplotlib's built-in mathtext, not external LaTeX.
    "text.usetex": False,
    "font.family": "serif",
})

# ────────────────────────────────────────────────────────────────────────────
# 3) HELPERS
# ────────────────────────────────────────────────────────────────────────────
# Fallback palettes used by ensure_style for curve specs that omit a style key.
# AUTO_COLORS is the colour list of the active property cycle ([] if it has none).
_active_cycle = plt.rcParams["axes.prop_cycle"].by_key()
AUTO_COLORS = _active_cycle.get("color", [])
AUTO_MARKERS = list("os^vD*PX<>phH8")

def load_many(pattern):
    """Load every pickle file matching a glob pattern.

    Args:
        pattern: Glob pattern (or literal path) passed to ``glob.glob``.

    Returns:
        List with one unpickled object per matching file, ordered by file
        name; an empty list when nothing matches.
    """
    # SECURITY: pickle.load can execute arbitrary code stored in the file.
    # Only point this at result files you produced yourself.
    runs = []
    # glob.glob yields matches in arbitrary, filesystem-dependent order; sort
    # so the run order (and the floating-point averages built from it) is
    # reproducible across machines.
    for fname in sorted(glob.glob(pattern)):
        with open(fname, "rb") as f:
            runs.append(pickle.load(f))
    return runs

def mean_curve(runs, key):
    """Average the series stored under ``key`` across all runs.

    Runs that do not contain ``key`` are ignored. The remaining series are
    cut to the length of the shortest one before averaging element-wise.

    Returns:
        A NumPy array with the point-wise mean, or ``None`` when there are no
        runs or none of them contains ``key``.
    """
    series = [run.get(key) for run in (runs or []) if key in run]
    if not series:
        return None

    shortest = min(map(len, series))
    trimmed = [np.asarray(values[:shortest]) for values in series]
    return np.mean(trimmed, axis=0)

def auto_label_from_pattern(pat):
    """Derive a legend label from a result-file pattern.

    File names look like ``<model>_<dataset>_<METHOD...>_<hp>-<value>_...``,
    e.g. ``resnet18_cifar10_EF21_SGDM_NORM_topk-0,1_lr-0,1.pickle``. The label
    is the method name: every "_"-separated token from the first one starting
    with "EF21" up to (not including) the first hyperparameter token, which is
    recognised by its "-". Tokens are joined with "-", giving "EF21-SGDM-NORM"
    in the example. Taking only the first token would label every EF21 variant
    "EF21" and make the legend ambiguous.

    Args:
        pat: File path or glob pattern of a curve spec.

    Returns:
        The method label, or the base file name without the ".pickle" suffix
        when no token starts with "EF21".
    """
    base = os.path.basename(pat)
    suffix = ".pickle"
    if base.endswith(suffix):
        # Strip the extension only as a suffix, never from the middle of a name.
        base = base[: -len(suffix)]

    tokens = base.split("_")
    start = next((i for i, tok in enumerate(tokens) if tok.startswith("EF21")), None)
    if start is None:
        return base

    # Glob metacharacters are not part of the method name.
    name_parts = [tokens[start].strip("*?[]")]
    for tok in tokens[start + 1:]:
        cleaned = tok.strip("*?[]")
        if not cleaned or "-" in cleaned:
            # Reached a wildcard-only token or a "<name>-<value>" hyperparameter.
            break
        name_parts.append(cleaned)
    return "-".join(name_parts)

def ensure_style(spec, idx):
    """Fill in any missing style keys of a curve spec, in place.

    Args:
        spec: Curve dict; must contain "pattern". Keys "label", "color",
            "marker" and "linestyle" are added only when absent.
        idx: Position of the curve, used to cycle through the automatic
            colour and marker palettes.

    Returns:
        None; ``spec`` is mutated.
    """
    if "label" not in spec:
        spec["label"] = auto_label_from_pattern(spec["pattern"])
    if "color" not in spec:
        # AUTO_COLORS is [] when the active prop_cycle defines no colours, and
        # idx % 0 would raise ZeroDivisionError; fall back to a fixed palette
        # of matplotlib's named "tab:" colours.
        fallback = ("tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple",
                    "tab:brown", "tab:pink", "tab:gray", "tab:olive", "tab:cyan")
        palette = AUTO_COLORS if AUTO_COLORS else fallback
        spec["color"] = palette[idx % len(palette)]
    if "marker" not in spec:
        spec["marker"] = AUTO_MARKERS[idx % len(AUTO_MARKERS)]
    if "linestyle" not in spec:
        spec["linestyle"] = "-"

def _cum_gpu_seconds(arr):
    """Turn a per-epoch series into a cumulative one with a leading 0.

    The extra 0.0 at the front lines the series up with the epoch-0 snapshot.
    An empty input is returned as an empty float array (no padding).
    """
    per_epoch = np.asarray(arr, dtype=float)
    if per_epoch.size == 0:
        return per_epoch
    return np.concatenate([[0.0], np.cumsum(per_epoch)])


def pick_x_series(runs, x_mode):
    """
    Build the averaged x-axis series for a group of runs.

    Returns ``(x_mean, xlabel)``. ``x_mean`` is ``None`` (with the "Epoch"
    label) when ``x_mode`` is "epoch" or unknown, or when no run carries the
    required key. For "gpu_seconds" the per-run values are cumulated and
    padded with a leading 0; every other mode reads its key as-is.
    """
    # x_mode -> (key in the run dict, axis label)
    axis_table = {
        "gpu_seconds": ("gpu_seconds", "Cumulative GPU seconds (per-epoch CUDA timing)"),
        "bp_total": ("cum_bp_eq_total", "Cumulative backprop equivalents"),
        "exbp_total": ("cum_ex_bp_total", "Cumulative examples × backprops"),
        "time": ("wall_seconds", "Cumulative wall-clock seconds"),
    }
    epoch_fallback = (None, "Epoch")

    if x_mode not in axis_table:
        return epoch_fallback
    run_key, axis_label = axis_table[x_mode]

    if x_mode == "gpu_seconds":
        to_series = _cum_gpu_seconds
    else:
        def to_series(values):
            return np.asarray(values, dtype=float)

    per_run = [to_series(run[run_key]) for run in runs if run_key in run]
    if not per_run:
        return epoch_fallback

    # Runs may differ in length; average only over the common prefix.
    shortest = min(len(series) for series in per_run)
    averaged = np.mean([series[:shortest] for series in per_run], axis=0)
    return averaged, axis_label

def markers_on(n_pts, curve_idx, n_curves):
    """Return the point indices at which a curve should draw its markers.

    About ten markers are placed per curve (every ``n_pts // 10`` points), and
    each curve starts at a different offset inside the first interval so that
    markers of different curves do not sit on top of each other.

    Args:
        n_pts: Number of points in the curve.
        curve_idx: Position of this curve among the plotted curves.
        n_curves: Total number of plotted curves.

    Returns:
        1-D integer array of indices, suitable for ``markevery``; empty when
        ``n_pts`` is 0.
    """
    freq = max(1, n_pts // 10)
    # Multiply before dividing: (freq // n_curves) * curve_idx collapses to 0
    # for every curve whenever freq < n_curves, which stacks all markers.
    # The modulo keeps the first marker inside the first interval.
    shift = ((freq * curve_idx) // max(1, n_curves)) % freq
    return np.arange(shift, n_pts, freq)

# ────────────────────────────────────────────────────────────────────────────
# 4) PLOT  — truncate x-axis to the fastest run when using time-based axes
# ────────────────────────────────────────────────────────────────────────────
# metric key in the run dicts -> (figure title, y-axis label)
metrics = {
    "train_loss": ("Train loss",  "Train loss"),
    "test_loss":  ("Test loss",   "Test loss"),
    "test_acc":   ("Test accuracy", "Accuracy"),
}

TIME_BASED = {"time", "gpu_seconds"}  # we'll truncate to the fastest run for these

# Load every curve's runs and x-series once. Neither depends on the metric, so
# re-reading the pickles inside the metric loop would only repeat the same I/O.
loaded = []  # tuples: (spec, runs, x_series or None, x_label)
for idx, spec in enumerate(CURVES):
    ensure_style(spec, idx)
    runs = load_many(spec["pattern"])
    if not runs:
        print(f"⚠  no runs match '{spec['pattern']}' – skipping")
        continue

    x_series, x_label = pick_x_series(runs, X_MODE)
    if x_series is None and X_MODE != "epoch":
        # Without this notice the curve would silently be drawn against the
        # epoch index while the other curves use the requested unit.
        print(f"⚠  runs for '{spec['pattern']}' lack x-axis data for '{X_MODE}' – using epoch index")
    loaded.append((spec, runs, x_series, x_label))

for mkey, (title, ylabel) in metrics.items():
    # First pass: prepare (x, y) for every curve without plotting
    prepared = []  # list of dicts: {"spec":..., "x":..., "y":..., "xlabel":...}
    for spec, runs, x_series, x_label in loaded:
        y = mean_curve(runs, mkey)
        if y is None:
            print(f"⚠  runs for '{spec['pattern']}' lack '{mkey}' – skipping")
            continue

        # choose x-axis (and keep lengths aligned)
        if x_series is None:
            x = np.arange(0, len(y))  # epoch 0..N-1 to match the epoch=0 snapshot
            xlabel = "Epoch"
        else:
            x, xlabel = x_series, x_label

        # align lengths now; a second trim for the common cutoff happens below
        min_len = min(len(x), len(y))
        if min_len == 0:
            # Nothing to draw, and x[-1] in the cutoff computation would raise.
            print(f"⚠  runs for '{spec['pattern']}' have no '{mkey}' points – skipping")
            continue

        prepared.append({"spec": spec, "x": x[:min_len], "y": y[:min_len], "xlabel": xlabel})

    if not prepared:
        continue

    # If using a time-based x-axis, truncate all curves to the fastest (smallest) final x
    if X_MODE in TIME_BASED:
        common_xmax = min(item["x"][-1] for item in prepared)
    else:
        common_xmax = None

    fig, ax = plt.subplots()

    # Second pass: plot (with optional truncation)
    for idx, item in enumerate(prepared):
        spec, x, y = item["spec"], item["x"], item["y"]

        if common_xmax is not None:
            mask = x <= common_xmax + 1e-12  # include point exactly at cutoff
            x = x[mask]
            y = y[mask]

        ax.plot(
            x, y,
            label=spec["label"],
            color=spec["color"],
            linestyle=spec["linestyle"],
            marker=spec["marker"],
            markevery=markers_on(len(y), idx, len(prepared)),
            markersize=ARG["marker_size"],
            markerfacecolor=spec["color"],
            markeredgecolor="black",
        )

    if mkey in {"train_loss", "test_loss"}:
        ax.set_yscale("log")

    ax.set_title(title)
    ax.set_xlabel(prepared[0]["xlabel"])  # same for all
    ax.set_ylabel(ylabel)
    ax.grid(True, which="both", ls="--", alpha=0.4)
    ax.legend()

    # lock the x-axis to the common cutoff for time-based modes; a zero-width
    # range (single point at x=0) is left to matplotlib's autoscaling
    if common_xmax is not None and common_xmax > 0:
        ax.set_xlim(0, common_xmax)

    fig.tight_layout()
    pdf_name = f"{mkey}_{X_MODE}.pdf"
    fig.savefig(pdf_name, bbox_inches="tight")
    print(f"✓  saved {pdf_name}")
    plt.show()



In [None]:
# version to plot figures 1,2,3


import os, pickle, glob
import numpy as np
import matplotlib.pyplot as plt

# ────────────────────────────────────────────────────────────────────────────
# 0) X-axis mode: "epoch" | "bp_total" | "exbp_total" | "time" | "gpu_seconds"
# ────────────────────────────────────────────────────────────────────────────
# Selects which per-run series is used as the x-axis of every figure.
# pick_x_series maps each mode to a key of the loaded run dicts:
#   "time"        -> "wall_seconds"
#   "bp_total"    -> "cum_bp_eq_total"
#   "exbp_total"  -> "cum_ex_bp_total"
#   "gpu_seconds" -> "gpu_seconds" (cumulated, with a leading 0)
#   "epoch"       -> no key; the point index is used
# The value is also embedded in the output file name (<metric>_<X_MODE>.pdf).
# To switch modes, uncomment one alternative; the last assignment wins.
X_MODE = "time"
#X_MODE = "bp_total"
#X_MODE = "exbp_total"
#X_MODE = "epoch"

#X_MODE = "gpu_seconds"

# ────────────────────────────────────────────────────────────────────────────
# 1) CURVE SPECS – add pattern-only placeholders for new methods
# ────────────────────────────────────────────────────────────────────────────

# Directory prefix prepended to every pattern below (note the trailing slash,
# because the path is built by plain string concatenation).
folder = "_release_data/"
#folder = "saved_data/"


# One dict per plotted curve:
#   "pattern"   - glob pattern handed to load_many; all matching pickles are
#                 averaged into a single curve
#   "label"     - legend text
#   "color", "marker", "linestyle" - passed straight to ax.plot
# Only "pattern" is mandatory; ensure_style fills in any missing style key.
# The percentage in each comment is presumably the final test accuracy of that
# method -- TODO confirm against the experiment logs.
CURVES = [
    # 1) EF21-SGDM (74.66%)
    {"pattern": folder + "resnet18_cifar10_EF21_SGDM_topk-0,1_lr-0,1_eta-0,1_p-None_q-None.pickle",
     "label":"EF21-SGDM","color":"orange","marker":"v","linestyle":"-"},
    # 2) ECONTROL (76.38%)
    {"pattern": folder + "resnet18_cifar10_ECONTROL_topk-0,1_lr-1,0_eta-0,1_p-None_q-None.pickle",
     "label":"ECONTROL","color":"yellowgreen","marker":"p","linestyle":"-"},
    # 3) EF21 (77.50%)
    {"pattern": folder + "resnet18_cifar10_EF21_topk-0,1_lr-1,0_eta-None_p-None_q-None.pickle",
     "label":"EF21","color":"blue","marker":"o","linestyle":"-"},
    # 4) EF21-MVR_2b (79.20%)
    {"pattern": folder + "resnet18_cifar10_EF21_MVR_NORM_topk-0,1_lr-0,1_eta-1,0_p-None_q-0,67.pickle",
     "label":r"$\|\text{EF21-MVR}\|$","color":"#8c564b","marker":"<","linestyle":"-"},
    # 5) EF21-RHM (81.48%)
    {"pattern": folder + "resnet18_cifar10_EF21_RHM_NORM_topk-0,1_lr-0,1_eta-1,0_p-None_q-0,67.pickle",
     "label":r"$\|\text{EF21-RHM}\|$","color":"#e377c2","marker":">","linestyle":"-"},
    # 6) ||EF21-SGDM|| (NORM) (82.66%)
    {"pattern": folder + "resnet18_cifar10_EF21_SGDM_NORM_topk-0,1_lr-0,1_eta-1,0_p-None_q-0,5.pickle",
     "label":r"$\|\text{EF21-SGDM}\|$","color":"darkgreen","marker":"s","linestyle":"-"},
    # 7) ||EF21-IGT|| (83.28%)
    {"pattern": folder + "resnet18_cifar10_EF21_IGT_NORM_topk-0,1_lr-0,1_eta-1,0_p-None_q-0,57.pickle",
     "label":r"$\|\text{EF21-IGT}\|$","color":"#9467bd","marker":"^","linestyle":"-"},
    # 8) ||EF21-HM|| (SOM_XT2) (84.32%)
    {"pattern": folder + "resnet18_cifar10_EF21_HM_NORM_topk-0,1_lr-0,1_eta-1,0_p-None_q-0,67.pickle",
     "label":r"$\|\text{EF21-HM}\|$","color":"red","marker":"*","linestyle":"-"},
]




# ────────────────────────────────────────────────────────────────────────────
# 2) GLOBAL STYLE
# ────────────────────────────────────────────────────────────────────────────
# Shared sizes: one font size for every text element, one marker size per curve.
ARG = {"label_fontsize": 20, "marker_size": 15}

# rcParams entries that all receive the shared label font size.
_FONT_SIZE_KEYS = (
    "xtick.labelsize",
    "ytick.labelsize",
    "legend.fontsize",
    "axes.titlesize",
    "axes.labelsize",
)

plt.rcParams.update({
    "lines.linewidth": 2,
    **{rc_key: ARG["label_fontsize"] for rc_key in _FONT_SIZE_KEYS},
    "figure.figsize": [10, 8],
    # Font type 42 selects TrueType embedding in the PDF/PS backends.
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    # Labels are rendered with matplotlib's built-in mathtext, not external LaTeX.
    "text.usetex": False,
    "font.family": "serif",
})

# ────────────────────────────────────────────────────────────────────────────
# 3) HELPERS
# ────────────────────────────────────────────────────────────────────────────
# Fallback palettes used by ensure_style for curve specs that omit a style key.
# AUTO_COLORS is the colour list of the active property cycle ([] if it has none).
_active_cycle = plt.rcParams["axes.prop_cycle"].by_key()
AUTO_COLORS = _active_cycle.get("color", [])
AUTO_MARKERS = list("os^vD*PX<>phH8")

def load_many(pattern):
    """Load every pickle file matching a glob pattern.

    Args:
        pattern: Glob pattern (or literal path) passed to ``glob.glob``.

    Returns:
        List with one unpickled object per matching file, ordered by file
        name; an empty list when nothing matches.
    """
    # SECURITY: pickle.load can execute arbitrary code stored in the file.
    # Only point this at result files you produced yourself.
    runs = []
    # glob.glob yields matches in arbitrary, filesystem-dependent order; sort
    # so the run order (and the floating-point averages built from it) is
    # reproducible across machines.
    for fname in sorted(glob.glob(pattern)):
        with open(fname, "rb") as f:
            runs.append(pickle.load(f))
    return runs

def mean_curve(runs, key):
    """Average the series stored under ``key`` across all runs.

    Runs that do not contain ``key`` are ignored. The remaining series are
    cut to the length of the shortest one before averaging element-wise.

    Returns:
        A NumPy array with the point-wise mean, or ``None`` when there are no
        runs or none of them contains ``key``.
    """
    series = [run.get(key) for run in (runs or []) if key in run]
    if not series:
        return None

    shortest = min(map(len, series))
    trimmed = [np.asarray(values[:shortest]) for values in series]
    return np.mean(trimmed, axis=0)

def auto_label_from_pattern(pat):
    """Derive a legend label from a result-file pattern.

    File names look like ``<model>_<dataset>_<METHOD...>_<hp>-<value>_...``,
    e.g. ``resnet18_cifar10_EF21_SGDM_NORM_topk-0,1_lr-0,1.pickle``. The label
    is the method name: every "_"-separated token from the first one starting
    with "EF21" up to (not including) the first hyperparameter token, which is
    recognised by its "-". Tokens are joined with "-", giving "EF21-SGDM-NORM"
    in the example. Taking only the first token would label every EF21 variant
    "EF21" and make the legend ambiguous.

    Args:
        pat: File path or glob pattern of a curve spec.

    Returns:
        The method label, or the base file name without the ".pickle" suffix
        when no token starts with "EF21".
    """
    base = os.path.basename(pat)
    suffix = ".pickle"
    if base.endswith(suffix):
        # Strip the extension only as a suffix, never from the middle of a name.
        base = base[: -len(suffix)]

    tokens = base.split("_")
    start = next((i for i, tok in enumerate(tokens) if tok.startswith("EF21")), None)
    if start is None:
        return base

    # Glob metacharacters are not part of the method name.
    name_parts = [tokens[start].strip("*?[]")]
    for tok in tokens[start + 1:]:
        cleaned = tok.strip("*?[]")
        if not cleaned or "-" in cleaned:
            # Reached a wildcard-only token or a "<name>-<value>" hyperparameter.
            break
        name_parts.append(cleaned)
    return "-".join(name_parts)

def ensure_style(spec, idx):
    """Fill in any missing style keys of a curve spec, in place.

    Args:
        spec: Curve dict; must contain "pattern". Keys "label", "color",
            "marker" and "linestyle" are added only when absent.
        idx: Position of the curve, used to cycle through the automatic
            colour and marker palettes.

    Returns:
        None; ``spec`` is mutated.
    """
    if "label" not in spec:
        spec["label"] = auto_label_from_pattern(spec["pattern"])
    if "color" not in spec:
        # AUTO_COLORS is [] when the active prop_cycle defines no colours, and
        # idx % 0 would raise ZeroDivisionError; fall back to a fixed palette
        # of matplotlib's named "tab:" colours.
        fallback = ("tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple",
                    "tab:brown", "tab:pink", "tab:gray", "tab:olive", "tab:cyan")
        palette = AUTO_COLORS if AUTO_COLORS else fallback
        spec["color"] = palette[idx % len(palette)]
    if "marker" not in spec:
        spec["marker"] = AUTO_MARKERS[idx % len(AUTO_MARKERS)]
    if "linestyle" not in spec:
        spec["linestyle"] = "-"

def _cum_gpu_seconds(arr):
    """Turn a per-epoch series into a cumulative one with a leading 0.

    The extra 0.0 at the front lines the series up with the epoch-0 snapshot.
    An empty input is returned as an empty float array (no padding).
    """
    per_epoch = np.asarray(arr, dtype=float)
    if per_epoch.size == 0:
        return per_epoch
    return np.concatenate([[0.0], np.cumsum(per_epoch)])


def pick_x_series(runs, x_mode):
    """
    Build the averaged x-axis series for a group of runs.

    Returns ``(x_mean, xlabel)``. ``x_mean`` is ``None`` (with the "Epoch"
    label) when ``x_mode`` is "epoch" or unknown, or when no run carries the
    required key. For "gpu_seconds" the per-run values are cumulated and
    padded with a leading 0; every other mode reads its key as-is.
    """
    # x_mode -> (key in the run dict, axis label)
    axis_table = {
        "gpu_seconds": ("gpu_seconds", "Cumulative GPU seconds (per-epoch CUDA timing)"),
        "bp_total": ("cum_bp_eq_total", "Cumulative backprop equivalents"),
        "exbp_total": ("cum_ex_bp_total", "Cumulative examples × backprops"),
        "time": ("wall_seconds", "Cumulative wall-clock seconds"),
    }
    epoch_fallback = (None, "Epoch")

    if x_mode not in axis_table:
        return epoch_fallback
    run_key, axis_label = axis_table[x_mode]

    if x_mode == "gpu_seconds":
        to_series = _cum_gpu_seconds
    else:
        def to_series(values):
            return np.asarray(values, dtype=float)

    per_run = [to_series(run[run_key]) for run in runs if run_key in run]
    if not per_run:
        return epoch_fallback

    # Runs may differ in length; average only over the common prefix.
    shortest = min(len(series) for series in per_run)
    averaged = np.mean([series[:shortest] for series in per_run], axis=0)
    return averaged, axis_label

def markers_on(n_pts, curve_idx, n_curves):
    """Return the point indices at which a curve should draw its markers.

    About ten markers are placed per curve (every ``n_pts // 10`` points), and
    each curve starts at a different offset inside the first interval so that
    markers of different curves do not sit on top of each other.

    Args:
        n_pts: Number of points in the curve.
        curve_idx: Position of this curve among the plotted curves.
        n_curves: Total number of plotted curves.

    Returns:
        1-D integer array of indices, suitable for ``markevery``; empty when
        ``n_pts`` is 0.
    """
    freq = max(1, n_pts // 10)
    # Multiply before dividing: (freq // n_curves) * curve_idx collapses to 0
    # for every curve whenever freq < n_curves, which stacks all markers.
    # The modulo keeps the first marker inside the first interval.
    shift = ((freq * curve_idx) // max(1, n_curves)) % freq
    return np.arange(shift, n_pts, freq)

# ────────────────────────────────────────────────────────────────────────────
# 4) PLOT
# ────────────────────────────────────────────────────────────────────────────
# metric key in the run dicts -> (figure title, y-axis label)
metrics = {
    "train_loss": ("Train loss",  "Train loss"),
    "test_loss":  ("Test loss",   "Test loss"),
    "test_acc":   ("Test accuracy", "Accuracy"),
}

# Load every curve's runs and x-series once. Neither depends on the metric, so
# re-reading the pickles inside the metric loop would only repeat the same I/O.
loaded = []  # tuples: (curve index in CURVES, spec, runs, x_series or None, x_label)
for idx, spec in enumerate(CURVES):
    ensure_style(spec, idx)
    runs = load_many(spec["pattern"])
    if not runs:
        print(f"⚠  no runs match '{spec['pattern']}' – skipping")
        continue
    x_series, x_label = pick_x_series(runs, X_MODE)
    loaded.append((idx, spec, runs, x_series, x_label))

for mkey, (title, ylabel) in metrics.items():
    fig, ax = plt.subplots()

    # x-axis label of the last curve actually drawn; stays None when every
    # curve was skipped, so a stale or undefined label is never used below.
    axis_label = None

    for idx, spec, runs, x_series, x_label in loaded:
        y = mean_curve(runs, mkey)
        if y is None:
            print(f"⚠  runs for '{spec['pattern']}' lack '{mkey}' – skipping")
            continue

        # choose x-axis (and keep lengths aligned)
        if x_series is None:
            x = np.arange(0, len(y))  # epoch 0..N-1 to match the epoch=0 snapshot
            xlabel = "Epoch"
        else:
            x, xlabel = x_series, x_label
        min_len = min(len(x), len(y))
        x = x[:min_len]
        y = y[:min_len]

        ax.plot(
            x, y,
            label=spec["label"],
            color=spec["color"],
            linestyle=spec["linestyle"],
            marker=spec["marker"],
            markevery=markers_on(len(y), idx, len(CURVES)),
            markersize=ARG["marker_size"],
            markerfacecolor=spec["color"],
            markeredgecolor="black",
        )
        axis_label = xlabel

    if axis_label is None:
        # Nothing was drawn for this metric: do not save an empty figure.
        plt.close(fig)
        continue

    if mkey in {"train_loss", "test_loss"}:
        ax.set_yscale("log")

    ax.set_title(title)
    ax.set_xlabel(axis_label)
    ax.set_ylabel(ylabel)
    ax.grid(True, which="both", ls="--", alpha=0.4)
    ax.legend()
    fig.tight_layout()

    pdf_name = f"{mkey}_{X_MODE}.pdf"
    fig.savefig(pdf_name, bbox_inches="tight")
    print(f"✓  saved {pdf_name}")
    plt.show()
