In [None]:
import pickle

# Result files produced by the FedAVG runs (assign one of them to PickleFile):
#   "resultsFedAVG_CIFAR_FULL.pkl"
#   "resultsFedAVG_Fashion_FULL.pkl"
#   "resultsFedAVG_MNIST_FULL.pkl"
PickleFile = "resultsFedAVG_CIFAR_FULL.pkl"

# 1. Read every object that was pickled back-to-back into the file.
results = []
with open(PickleFile, 'rb') as handle:
    while True:
        try:
            entry = pickle.load(handle)
        except EOFError:
            break
        results.append(entry)

# 2. Discard runs that combine attack=1 with privacy=1.
results = [entry for entry in results
           if entry.get('attack') != 1 or entry.get('privacy') != 1]

# 3. One summary line per remaining run.
for position, entry in enumerate(results):
    attack = entry.get('attack')
    privacy = entry.get('privacy')
    # privacy == 2 runs keep their hyper-parameter under 'p'; all others under 'noise_STD'.
    param_name = 'noise_STD' if privacy != 2 else 'p'
    param_val = entry.get(param_name)
    print(f"Entry {position:2d} → attack={attack}, privacy={privacy}, {param_name}={param_val}")


In [None]:
import pickle, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
from collections import defaultdict

# Path of the results file chosen in the first cell (PickleFile); change it here to read another file.
PICKLE_PATH = Path(PickleFile)

def load_all_frames(pkl_path: Path):
    """Return every object pickled back-to-back in ``pkl_path``, in file order.

    Reading stops at end of file; any other unpickling error propagates.
    Security note: ``pickle.load`` can execute arbitrary code, so only use this on
    result files you produced yourself.
    """
    collected = []
    with pkl_path.open("rb") as stream:
        try:
            while True:
                collected.append(pickle.load(stream))
        except EOFError:
            pass
    return collected

frames = load_all_frames(PICKLE_PATH)
print(f"Loaded {len(frames)} result frames ✅")

# Distinct privacy mechanisms and attack ids present in the file, ascending.
privacies = sorted(set(frame["privacy"] for frame in frames))
attacks = sorted(set(frame["attack"] for frame in frames))

print("Available privacies:", privacies)
print("Available attacks:  ", attacks)


In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from ipywidgets import interact, Dropdown, IntSlider, Button, HBox, Output
from collections import defaultdict
from IPython.display import display
import pandas as pd
import ipywidgets as widgets

import matplotlib.pyplot as plt

# Font sizes passed to the matplotlib calls of the plotting helpers below.
title_fontsize        = 20
label_fontsize        = 17
tick_fontsize         = 15
legend_fontsize       = 15
legend_title_fontsize = 18

def load_frames(pickle_path):
    """Read every object pickled back-to-back in ``pickle_path``, in file order.

    Reading stops at end of file. If a record cannot be unpickled (for example, the
    last record was only partially written), the objects read so far are returned.
    A ``RuntimeWarning`` is emitted in that case rather than truncating silently.

    Security note: ``pickle.load`` can execute arbitrary code; only load trusted files.

    Args:
        pickle_path: path (str or path-like) of the results file.

    Returns:
        list of the unpickled objects.
    """
    import warnings

    frames = []
    with open(pickle_path, "rb") as fh:
        while True:
            try:
                frames.append(pickle.load(fh))
            except EOFError:
                break
            except pickle.UnpicklingError as exc:
                # Keep the best-effort behaviour, but make the truncation visible.
                warnings.warn(
                    f"Stopped reading {pickle_path} after {len(frames)} frame(s): {exc}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                break
    return frames

def process_fedavg_frames(frames):
    """Normalise raw FedAVG result frames into flat records with a common key set.

    Every raw frame must contain 'privacy', 'attack' and either 'noise_STD' or 'p'.
    'noise_STD' is preferred when both are present, and the value found becomes
    'param_val'. 'ConvergenceError' is exposed as 'Error'. Metric series that are
    absent default to empty lists.
    """
    def _normalise(raw):
        key = 'p'
        if 'noise_STD' in raw:
            key = 'noise_STD'
        record = {
            'privacy': raw['privacy'],
            'attack': raw['attack'],
            'param_val': raw[key],
            'FAR': raw.get('FAR', []),
            "Error": raw.get("ConvergenceError", []),
        }
        for name in ('MDR', 'train_loss', 'train_accuracy', 'test_loss', 'test_accuracy'):
            record[name] = raw.get(name, [])
        return record

    return [_normalise(raw) for raw in frames]

def group_by_param(frames, key='param_val'):
    """Bucket ``frames`` by ``frame[key]``, keeping input order inside each bucket.

    Returns a ``defaultdict(list)`` mapping each distinct key value to its frames.
    """
    buckets = defaultdict(list)
    for value, frame in ((fr[key], fr) for fr in frames):
        buckets[value].append(frame)
    return buckets

def get_eval_rounds(total_rounds, num_evals):
    """Reconstruct the round indices at which test metrics are assumed to be evaluated.

    Rounds are scanned from 0. A round is recorded when none has been recorded yet, or
    when at least ``interval`` rounds have passed since the last recorded one. The
    interval is 5 before round 50, 10 before round 200, and 20 afterwards. The scan
    stops once ``num_evals`` rounds are collected or ``total_rounds`` is reached.
    """
    def _interval(round_idx):
        if round_idx < 50:
            return 5
        if round_idx < 200:
            return 10
        return 20

    picked = []
    for candidate in range(total_rounds):
        if len(picked) >= num_evals:
            break
        if not picked or candidate - picked[-1] >= _interval(candidate):
            picked.append(candidate)
    return picked[:num_evals]

def format_sigma_label(val):
    """Format the noise standard deviation ``val`` as a mathtext label for sigma^2.

    Callers prefix this with ``$\\sigma^2$=``, so the label shows the variance ``val**2``:
    exact powers of ten appear as ``$10^{k}$``, anything else as a 2-significant-digit
    number. ``0`` gives ``$0$``.

    Previously only exact powers of ten were squared (doubled exponent); other values
    were printed as the raw standard deviation, so e.g. sigma=0.5 was shown as
    sigma^2=0.5. A sign was also attached to what is a variance.
    """
    if val == 0:
        return "$0$"
    # log10(val**2) computed as 2*log10(|val|) to avoid underflow for tiny val.
    log_var = 2 * np.log10(abs(float(val)))
    exponent = int(round(log_var))
    if np.isclose(log_var, exponent, atol=1e-6):
        return f"$10^{{{exponent}}}$"
    return f"${float(val) ** 2:.2g}$"

def format_p_label(val):
    """Render an SMPC parameter ``p`` as mathtext: ``$0$``, ``$2^{61}-1$`` for the Mersenne prime, else ``$<val>$``."""
    return ("$0$" if val == 0
            else r"$2^{61}-1$" if val == 2**61 - 1
            else f"${val}$")

def compute_markevery(idx, series_len, nmarkers=8):
    """Return marker positions (``markevery`` list) for the ``idx``-th curve of a panel.

    Series no longer than ``nmarkers`` get a marker on every point (listed starting at
    ``2*idx mod series_len``). Longer series get ``nmarkers`` evenly spaced positions,
    rotated by a curve-dependent phase so curves in one axes don't share marker spots.
    """
    if series_len <= nmarkers:
        rotation = (idx * 2) % series_len if series_len > 1 else 0
        return [(rotation + k) % series_len for k in range(series_len)]
    phase = int(round(idx * series_len / (2 * nmarkers)))
    evenly_spaced = np.linspace(0, series_len - 1, nmarkers, dtype=int)
    return ((evenly_spaced + phase) % series_len).tolist()

def grid_plot_all_metrics(frames, attack, rnd, fig_dpi=150):
    """Plot a (metric x privacy) grid of train/test curves for one attack.

    Rows are train loss, train accuracy, test loss and test accuracy. Columns are
    the privacy mechanisms present in ``frames``; a column with no runs for
    ``attack`` is hidden. Each panel draws one curve per privacy-parameter value,
    using the first run for that value and skipping SMPC ``p == 0``. Train metrics
    are plotted against their index. Test metrics are plotted against the
    evaluation rounds reconstructed by ``get_eval_rounds``. A legend per column is
    placed above the grid.

    Args:
        frames: processed frames (see ``process_fedavg_frames``).
        attack: attack id whose runs are plotted.
        rnd: unused; kept so the widget and the PDF saver can pass the slider value.
        fig_dpi: figure resolution.

    Returns:
        ``(fig, axs)`` where ``axs`` is always a 2-D array of Axes (rows x columns).
    """
    metrics_rows = [
        ('train_loss', 'Round', 'Train Loss'),
        ('train_accuracy', 'Round', 'Train Accuracy'),
        ('test_loss', 'Round', 'Test Loss'),
        ('test_accuracy', 'Round', 'Test Accuracy'),
    ]
    privacies = sorted({fb['privacy'] for fb in frames})
    privacy_titles = {
        1: "DP-Privacy",
        2: "SMPC-Privacy",
        3: "SP-Privacy"
    }
    n_cols = len(privacies)
    n_rows = len(metrics_rows)
    markers = ['o', 's', 'd', '^', 'v', '<', '>', 'p', '*', 'h', '+', 'x']

    # squeeze=False always yields a 2-D (n_rows, n_cols) array. The former
    # np.atleast_2d() turned the (n_rows,) array of a single-column grid into shape
    # (1, n_rows), so axs[row, 0] raised IndexError for row >= 1.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(5.5*n_cols, 3.3*n_rows), dpi=fig_dpi,
                            constrained_layout=True, squeeze=False)
    column_legends = [[] for _ in range(n_cols)]
    column_labels = [[] for _ in range(n_cols)]
    legend_params = [set() for _ in range(n_cols)]

    for col, privacy in enumerate(privacies):
        subset = [f for f in frames if f['privacy'] == privacy and f['attack'] == attack]
        if not subset:
            for row in range(n_rows):
                axs[row, col].set_visible(False)
            continue
        groups = sorted(group_by_param(subset, 'param_val').items())
        pretty_param = r'$p$' if privacy == 2 else r'$\sigma^2$'

        for row, (metric, xlab, ylab) in enumerate(metrics_rows):
            ax = axs[row, col]
            for idx, (param_val, flist) in enumerate(groups):
                # SMPC with p = 0 is not plotted.
                if privacy == 2 and param_val == 0:
                    continue
                fdict = flist[0]
                series = np.array(fdict.get(metric, []), dtype=float)
                if series.size == 0:
                    continue
                if metric in ('train_loss', 'train_accuracy'):
                    x = np.arange(len(series))
                else:
                    # Test metrics exist only for some rounds; infer the total number of
                    # rounds from the per-round training metrics / FAR.
                    total_rounds = max(
                        len(fdict.get('train_loss', [])),
                        len(fdict.get('train_accuracy', [])),
                        len(fdict.get('FAR', [])),
                    )
                    x = get_eval_rounds(total_rounds, len(series))
                    if len(x) != len(series):
                        # Schedule could not be reconstructed (e.g. no training metrics):
                        # fall back to the evaluation index instead of crashing in ax.plot.
                        x = np.arange(len(series))
                if privacy == 2:
                    label = f"{pretty_param}={format_p_label(param_val)}"
                else:
                    label = f"{pretty_param}={format_sigma_label(param_val)}"
                handle = ax.plot(
                    x, series, label=label, marker=markers[idx % len(markers)], markersize=7,
                    markevery=compute_markevery(idx, len(series), nmarkers=8), linewidth=2.2,
                )[0]
                # One legend entry per parameter value and column, taken from the first row
                # that actually draws it (row 0 unless its series is missing).
                if param_val not in legend_params[col]:
                    legend_params[col].add(param_val)
                    column_legends[col].append(handle)
                    column_labels[col].append(label)
            ax.set_xlabel(xlab, fontsize=label_fontsize)
            ax.set_ylabel(ylab, fontsize=label_fontsize)
            ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
            ax.set_title(f"{ylab} - {privacy_titles.get(privacy, privacy)}", fontsize=title_fontsize)
            ax.grid(True)

    # Per-column legends in figure coordinates, just above the grid.
    for col, privacy in enumerate(privacies):
        handles = column_legends[col]
        labels = column_labels[col]
        if handles:
            left = (col + 0.05) / n_cols
            width = 0.9 / n_cols
            top = 1.05
            leg_ax = fig.add_axes([left, top, width, 0.08])
            leg_ax.axis('off')
            leg_ax.legend(
                handles, labels, loc='upper center', fontsize=legend_fontsize, ncol=max(1, len(labels)//2),
                title=privacy_titles.get(privacy, f"Privacy={privacy}"), title_fontsize=legend_title_fontsize, frameon=False,
                bbox_to_anchor=(0.5, 1), borderaxespad=0
            )
    return fig, axs

def save_grid_as_pdf(frames, attack, rnd, filename=None):
    """Render the metric grid for ``attack`` off-screen and write it to a PDF.

    Interactive mode is switched off while drawing, so the figure does not pop up.
    It is then restored to its previous state rather than forced on. The figure is
    always closed, even if plotting or saving fails.

    Args:
        frames: processed frames.
        attack: attack id to plot.
        rnd: slider round, used only in the default file name.
        filename: output path; defaults to ``fedavg_grid_attack-<attack>_round-<rnd>.pdf``.
    """
    if filename is None:
        filename = f"fedavg_grid_attack-{attack}_round-{rnd}.pdf"
    was_interactive = plt.isinteractive()
    plt.ioff()
    fig = None
    try:
        fig, _ = grid_plot_all_metrics(frames, attack, rnd, fig_dpi=300)
        fig.savefig(filename, bbox_inches='tight', dpi=300)
    finally:
        if fig is not None:
            plt.close(fig)
        if was_interactive:
            plt.ion()
    print(f"Saved figure as {filename}")

def grid_plot_all_metrics_and_show(frames, attack, rnd, fig_dpi=150):
    """Draw the metric grid for ``attack`` and display it via ``plt.show()``; returns None."""
    grid_plot_all_metrics(frames, attack, rnd, fig_dpi=fig_dpi)
    plt.show()

def extract_final_metrics_table(frames, attack, round_idx):
    """Tabulate the last recorded value of each train/test metric for one attack.

    Produces one row per (privacy, parameter value), using the first run for each value.
    The columns are Privacy, Param, train_loss, train_accuracy, test_loss and test_accuracy.
    Empty or missing series give ``None``.
    NOTE(review): ``round_idx`` is currently ignored — the last element of each series
    is always used.
    """
    metric_names = ['train_loss', 'train_accuracy', 'test_loss', 'test_accuracy']
    titles = {1: "DP-Privacy", 2: "SMPC-Privacy", 3: "SP-Privacy"}
    records = []
    for privacy in sorted({fb['privacy'] for fb in frames}):
        matching = [f for f in frames if f['privacy'] == privacy and f['attack'] == attack]
        if not matching:
            continue
        for param_val, runs in sorted(group_by_param(matching, 'param_val').items()):
            first_run = runs[0]
            if privacy == 2:
                param_text = format_p_label(param_val)
            else:
                param_text = format_sigma_label(param_val)
            record = {"Privacy": titles.get(privacy, str(privacy)), "Param": param_text}
            for name in metric_names:
                series = np.array(first_run.get(name, []), dtype=float)
                record[name] = series[-1] if series.size else None
            records.append(record)
    return pd.DataFrame(records, columns=["Privacy", "Param"] + metric_names)

def fedavg_grid_plotter_widget_with_save(pickle_path):
    """Display an interactive grid viewer with an attack dropdown, round slider and PDF button.

    Loads and normalises the frames stored in ``pickle_path``. If the file holds no
    frames, a message is printed and nothing is displayed. The slider range comes
    from the longest FAR series; an empty or missing FAR counts as one round.
    """
    frames = process_fedavg_frames(load_frames(pickle_path))
    if not frames:
        print("No valid frames found in the file.")
        return
    attacks = sorted({fb['attack'] for fb in frames})

    def _round_count(frame):
        # len()/None checks instead of truthiness: FAR may be a NumPy array, whose
        # truth value is ambiguous (ValueError) when it has more than one element.
        far = frame['FAR']
        return max(len(far), 1) if far is not None else 1

    n_rounds = max(_round_count(f) for f in frames)

    attack_dropdown = Dropdown(options=attacks, description='Attack')
    round_slider = IntSlider(min=0, max=n_rounds-1, value=0, description='Round')
    save_button = Button(description="Save as PDF")
    output = Output()

    def plot_and_show(attack, rnd):
        with output:
            output.clear_output(wait=True)
            grid_plot_all_metrics_and_show(frames, attack, rnd)

    def save_button_clicked(_button):
        save_grid_as_pdf(frames, attack_dropdown.value, round_slider.value)

    save_button.on_click(save_button_clicked)

    ui = HBox([attack_dropdown, round_slider, save_button])
    interact_output = widgets.interactive_output(plot_and_show, {'attack': attack_dropdown, 'rnd': round_slider})
    display(ui, interact_output, output)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


# Small primes used to derive a per-curve marker phase.
_marker_primes = [2, 3, 5, 7, 11, 13, 17]

def compute_sparse_markevery(series_len, privacy_idx, param_idx, nmarkers=8):
    """Return sorted indices at which to draw markers on a series of length ``series_len``.

    ``nmarkers`` markers are spread evenly over the whole series. A phase derived
    from ``_marker_primes`` and the (privacy, parameter) indices shifts them, so
    different curves in one axes tend not to mark the same rounds. Series no longer
    than ``nmarkers`` get a marker on every point. An empty series, or
    ``nmarkers <= 0``, yields ``[]``.
    """
    # The former version computed `step` but placed markers at offset + i*prime, which
    # packed them all into the first nmarkers*prime points, and it divided by zero
    # (modulo) for an empty series.
    if series_len <= 0 or nmarkers <= 0:
        return []
    if series_len <= nmarkers:
        return list(range(series_len))
    stride = series_len // nmarkers  # >= 1 here
    prime = _marker_primes[(privacy_idx + param_idx) % len(_marker_primes)]
    phase = (prime + privacy_idx * 3 + param_idx * 5) % stride
    # floor(i*L/n) is strictly increasing for L >= n and phase < L//n keeps the last index < L.
    return [phase + (i * series_len) // nmarkers for i in range(nmarkers)]



def plot_series_with_sparse_markers(ax, series, label, color, linestyle, marker, privacy_idx, param_idx):
    """Draw ``series`` as a line on ``ax`` plus scatter markers at a sparse subset of points.

    ``series`` is indexed with a list of positions, so it must be a NumPy array.
    Returns a ``Line2D`` proxy combining the line and marker style, for legends.
    """
    rounds = np.arange(len(series))
    marked = compute_sparse_markevery(len(series), privacy_idx, param_idx, nmarkers=16)

    ax.plot(rounds, series, label=label, color=color, linestyle=linestyle, linewidth=2)
    ax.scatter(rounds[marked], series[marked], color=color, marker=marker, s=40)

    legend_proxy = Line2D([0], [0], color=color, linestyle=linestyle, marker=marker,
                          linewidth=2, markersize=6, label=label)
    return legend_proxy




# --- Helper: fetch a metric series, collapsing FAR/MDR inner steps ---
def get_series(metric, fdict):
    """Return ``fdict[metric]`` as a float array; missing keys give an empty array.

    FAR and MDR arrays that are 2-D with exactly 300 columns are averaged along
    axis 1 (one value per row). Any other metric or shape is returned unchanged.
    """
    values = np.array(fdict.get(metric, []), dtype=float)
    collapse = metric in ("FAR", "MDR") and values.ndim == 2 and values.shape[1] == 300
    return values.mean(axis=1) if collapse else values

# --- Formatting helpers ---
def sci_label(val):
    """Format a privacy parameter as a matplotlib mathtext label.

    * ``0`` -> ``$0$``; the Mersenne prime ``2**61 - 1`` -> ``$2^{61}-1$``.
    * exact powers of ten -> ``$10^{k}$`` (sign kept for negative values).
    * moderate magnitudes (exponent -2..2) -> the number with 3 significant digits.
    * anything else -> ``$m \\times 10^{k}$`` with a 3-significant-digit mantissa.

    The previous version floored the exponent for every value, so e.g. 0.5 was labelled
    ``10^{-1}`` and ``2**31 - 1`` was labelled ``10^{9}``; negative signs were lost.
    """
    if val == 0:
        return "$0$"
    if val == 2**61 - 1:
        return r"$2^{61}-1$"
    magnitude = abs(float(val))
    sign = "-" if val < 0 else ""
    log_mag = np.log10(magnitude)
    nearest = int(round(log_mag))
    if np.isclose(log_mag, nearest, atol=1e-9):
        return f"${sign}10^{{{nearest}}}$"
    exponent = int(np.floor(log_mag))
    if -2 <= exponent <= 2:
        return f"${float(val):.3g}$"
    mantissa = round(magnitude / 10.0 ** exponent, 2)
    if mantissa >= 10:  # rounding pushed e.g. 9.999 up to 10.0
        mantissa /= 10
        exponent += 1
    return f"${sign}{mantissa:g} \\times 10^{{{exponent}}}$"

def label_for_privacy_param(privacy, param_val):
    """Legend label '<mechanism>, <symbol>=<value>' with the value rendered by ``sci_label``."""
    formatted = sci_label(param_val)
    if privacy == 2:  # SMPC
        return f"SMPC, p={formatted}"
    if privacy == 1:  # DP
        return f"DP, σ={formatted}"
    if privacy == 3:  # SP
        return f"SP, σ={formatted}"
    return f"Privacy={privacy}, param={formatted}"

# --- Colors and styles ---
_privacy_titles = {1: "DP-Privacy", 2: "SMPC-Privacy", 3: "SP-Privacy"}
_linestyles = ['-', '--', ':', '-.']
_markers = ['o', 's', 'd', '^', 'v', '<', '>', 'p', '*', 'h', '+', 'x']

# --- Colors and styles ---
_privacy_titles = {1: "DP-Privacy", 2: "SMPC-Privacy", 3: "SP-Privacy"}

# Fixed color per privacy
_privacy_colors = {1: 'blue', 2: 'red', 3: 'green'}

# Linestyles for multiple params within each privacy
_linestyles_dp_sp = ['-', '--', ':', '-.', (0, (3,1))]  # 5 styles for DP and SP
_linestyles_smpc = ['-', '--']  # 2 styles for SMPC

# Markers for multiple params
_markers_dp_sp = ['o', 's', 'd', '^', 'v']  # 5 markers
_markers_smpc = ['o', 's']  # 2 markers

def _style_for_param(idx, privacy):
    color = _privacy_colors.get(privacy, 'black')
    if privacy in [1, 3]:
        linestyle = _linestyles_dp_sp[idx % len(_linestyles_dp_sp)]
        marker = _markers_dp_sp[idx % len(_markers_dp_sp)]
    elif privacy == 2:
        linestyle = _linestyles_smpc[idx % len(_linestyles_smpc)]
        marker = _markers_smpc[idx % len(_markers_smpc)]
    else:
        linestyle = '-'
        marker = 'o'
    return {"color": color, "linestyle": linestyle, "marker": marker}

# The plotting functions below call the sparse-marker helper as
#   plot_series_with_sparse_markers(ax, series, label, **style, privacy_idx=privacy, param_idx=idx)
# with style = _style_for_param(idx, privacy). The helper returns a Line2D proxy that can be
# appended to a handles list for the legend.



# =====================================================
# 1. Per-privacy plots
# =====================================================
def plot_farmdr_per_privacy(frames, attack, show=True, fig_dpi=150):
    """Build one FAR/MDR figure per privacy mechanism for ``attack``.

    Each figure stacks FAR (top) over MDR (bottom), with one curve per parameter
    value taken from the first run for that value. SMPC runs with ``p == 0`` are
    skipped, as are runs with an empty FAR or MDR series.

    Returns:
        dict mapping the privacy title (or ``"privacy-<id>"``) to its Figure.
    """
    figs = {}
    privacy_ids = sorted({int(f["privacy"]) for f in frames if int(f["attack"]) == attack})
    for privacy in privacy_ids:
        matching = [f for f in frames if int(f["privacy"]) == privacy and int(f["attack"]) == attack]
        if not matching:
            continue

        # Bucket runs by parameter value (SMPC p = 0 excluded).
        runs_by_param = {}
        for run in matching:
            if privacy == 2 and run["param_val"] == 0:
                continue
            runs_by_param.setdefault(run["param_val"], []).append(run)

        fig, (ax_far, ax_mdr) = plt.subplots(2, 1, figsize=(12, 4), dpi=fig_dpi, constrained_layout=True)

        legend_handles = []
        for param_idx, (param_val, runs) in enumerate(sorted(runs_by_param.items())):
            first_run = runs[0]
            style = _style_for_param(param_idx, privacy)
            far = get_series("FAR", first_run)
            mdr = get_series("MDR", first_run)
            if far.size == 0 or mdr.size == 0:
                continue
            label = label_for_privacy_param(privacy, param_val)
            # The helper returns a Line2D legend proxy with this curve's colour/linestyle/marker.
            proxy = plot_series_with_sparse_markers(ax_far, far, label, **style, privacy_idx=privacy, param_idx=param_idx)
            plot_series_with_sparse_markers(ax_mdr, mdr, label, **style, privacy_idx=privacy, param_idx=param_idx)
            legend_handles.append(proxy)

        title = _privacy_titles.get(privacy)
        for ax, metric in ((ax_far, "FAR"), (ax_mdr, "MDR")):
            ax.set_title(f"{metric} — {title}", fontsize=title_fontsize)
            ax.set_xlabel("Round", fontsize=label_fontsize)
            ax.set_ylabel(metric, fontsize=label_fontsize)
            ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
            ax.grid(True)
            ax.legend(handles=legend_handles, fontsize=legend_fontsize, frameon=False)

        if show:
            plt.show()
        figs[_privacy_titles.get(privacy, f"privacy-{privacy}")] = fig
    return figs

# =====================================================
# 2. DP+SP combined, SMPC separate
# =====================================================
def plot_farmdr_dp_sp_combined(frames, attack, show=True, fig_dpi=150):
    """FAR/MDR figures for one attack: DP and SP runs together, SMPC runs separately.

    Each figure stacks FAR (top) over MDR (bottom), with one curve per parameter
    value from the first run for that value. Runs with an empty FAR or MDR series
    are skipped, as are SMPC runs with ``p == 0``.

    Returns:
        dict with keys ``"DP_SP_combined"`` and/or ``"SMPC_only"``; a key is present
        only when the attack has runs for that group.
    """
    def _plot_pair(ax_far, ax_mdr, privacy, style_idx, param_idx, param_val, fdict):
        # Plot FAR and MDR of one run; return its legend proxy, or None if a series is empty.
        series_far = get_series("FAR", fdict)
        series_mdr = get_series("MDR", fdict)
        if series_far.size == 0 or series_mdr.size == 0:
            return None
        style = _style_for_param(style_idx, privacy)
        label = label_for_privacy_param(privacy, param_val)
        handle = plot_series_with_sparse_markers(ax_far, series_far, label, **style,
                                                 privacy_idx=privacy, param_idx=param_idx)
        plot_series_with_sparse_markers(ax_mdr, series_mdr, label, **style,
                                        privacy_idx=privacy, param_idx=param_idx)
        return handle

    def _format_axes(ax_far, ax_mdr, title_suffix, handles):
        # Shared panel styling. Both figures now use title_fontsize; the SMPC panel
        # previously fell back to matplotlib's default title size.
        for ax, metric in [(ax_far, "FAR"), (ax_mdr, "MDR")]:
            ax.set_title(f"{metric} — {title_suffix}", fontsize=title_fontsize)
            ax.set_xlabel("Round", fontsize=label_fontsize)
            ax.set_ylabel(metric, fontsize=label_fontsize)
            ax.grid(True)
            ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)
            ax.legend(handles=handles, fontsize=legend_fontsize, frameon=False)

    figs = {}

    # --- DP + SP together ---
    subset = [f for f in frames if int(f["attack"]) == attack and int(f["privacy"]) in [1, 3]]
    if subset:
        fig, (ax_far, ax_mdr) = plt.subplots(2, 1, figsize=(12, 4), dpi=fig_dpi, constrained_layout=True)
        handles = []
        # Styles keep advancing across DP then SP so all curves in the shared axes differ.
        plotted_idx = 0
        for privacy in [1, 3]:
            groups = group_by_param([f for f in subset if int(f["privacy"]) == privacy], "param_val")
            for idx, (pv, flist) in enumerate(sorted(groups.items())):
                handle = _plot_pair(ax_far, ax_mdr, privacy, plotted_idx, idx, pv, flist[0])
                if handle is None:
                    continue
                handles.append(handle)
                plotted_idx += 1
        _format_axes(ax_far, ax_mdr, "DP+SP combined", handles)
        if show:
            plt.show()
        figs["DP_SP_combined"] = fig

    # --- SMPC separate (p = 0 excluded) ---
    subset = [f for f in frames if int(f["attack"]) == attack and int(f["privacy"]) == 2 and f["param_val"] != 0]
    if subset:
        fig, (ax_far, ax_mdr) = plt.subplots(2, 1, figsize=(12, 4), dpi=fig_dpi, constrained_layout=True)
        handles = []
        for idx, (pv, flist) in enumerate(sorted(group_by_param(subset, "param_val").items())):
            handle = _plot_pair(ax_far, ax_mdr, 2, idx, idx, pv, flist[0])
            if handle is not None:
                handles.append(handle)
        _format_axes(ax_far, ax_mdr, "SMPC only", handles)
        if show:
            plt.show()
        figs["SMPC_only"] = fig

    return figs

# =====================================================
# 3. All-in-one plot
# =====================================================
def plot_farmdr_all_together(frames, attack, show=True, fig_dpi=150):
    """Side-by-side FAR (left) and MDR (right) for every privacy mechanism of one attack.

    There is one curve per (privacy, parameter value), from the first run for that
    value. SMPC runs with ``p == 0`` are skipped, as are runs with an empty FAR or
    MDR series. A single figure-level legend is placed above both panels.

    Returns:
        ``{"all_together": fig}``, or ``{}`` when no frame matches ``attack``.
    """
    attack_frames = [f for f in frames if int(f["attack"]) == attack]
    if not attack_frames:
        return {}

    fig, (ax_far, ax_mdr) = plt.subplots(1, 2, figsize=(14, 5), dpi=fig_dpi, constrained_layout=True)

    legend_handles = []
    style_counter = 0  # advances only for curves actually drawn
    for privacy in sorted({int(f["privacy"]) for f in attack_frames}):
        usable = [f for f in attack_frames if int(f["privacy"]) == privacy]
        if privacy == 2:
            # SMPC runs with p == 0 are excluded.
            usable = [f for f in usable if f["param_val"] != 0]

        for param_idx, (param_val, runs) in enumerate(sorted(group_by_param(usable, "param_val").items())):
            run = runs[0]
            style = _style_for_param(style_counter, privacy)
            far = get_series("FAR", run)
            mdr = get_series("MDR", run)
            if far.size == 0 or mdr.size == 0:
                continue
            label = label_for_privacy_param(privacy, param_val)
            proxy = plot_series_with_sparse_markers(ax_far, far, label, **style,
                                                    privacy_idx=privacy, param_idx=param_idx)
            plot_series_with_sparse_markers(ax_mdr, mdr, label, **style,
                                            privacy_idx=privacy, param_idx=param_idx)
            legend_handles.append(proxy)
            style_counter += 1

    for ax, metric in ((ax_far, "FAR"), (ax_mdr, "MDR")):
        ax.set_title(f"{metric} — all privacies", fontsize=title_fontsize)
        ax.set_xlabel("Round", fontsize=label_fontsize)
        ax.set_ylabel(metric, fontsize=label_fontsize)
        ax.grid(True)
        ax.tick_params(axis='both', which='major', labelsize=tick_fontsize)

    # Figure-level legend above both panels.
    fig.legend(legend_handles, [h.get_label() for h in legend_handles],
               loc='upper center', bbox_to_anchor=(0.5, 1.23),
               ncol=5, fontsize=legend_fontsize, frameon=False)

    if show:
        plt.show()
    return {"all_together": fig}


# =====================================================
# Wrapper: show all variants
# =====================================================
def plot_all_far_mdr_views(frames, attack, show=True, per_privacy=False, dp_sp_combined=False, all_together=True):
    """Build the selected FAR/MDR figure variants for one attack and return them by name.

    By default only the all-privacies view is produced, as before. The per-privacy
    and the DP+SP / SMPC views were previously commented out; they can now be
    enabled with the corresponding flags.

    Returns:
        dict merging the figure dicts of the enabled plotters.
    """
    figs = {}
    if per_privacy:
        figs.update(plot_farmdr_per_privacy(frames, attack, show=show))
    if dp_sp_combined:
        figs.update(plot_farmdr_dp_sp_combined(frames, attack, show=show))
    if all_together:
        figs.update(plot_farmdr_all_together(frames, attack, show=show))
    return figs

# =====================================================
# Save helper
# =====================================================
def save_far_mdr_views(figs, prefix="results"):
    """Write each figure in ``figs`` to ``<prefix>_<name>.pdf`` (tight bbox, 300 dpi) and report it."""
    for view_name in figs:
        target = f"{prefix}_{view_name}.pdf"
        figs[view_name].savefig(target, bbox_inches="tight", dpi=300)
        print(f"Saved {target}")


In [None]:
# Helper: write every figure of a name -> Figure dict to its own PDF.
def save_far_mdr_views_pdf(figs, prefix="results"):
    """Save each figure as ``<prefix>_<name>.pdf`` (tight bbox, 300 dpi) and print the file name."""
    for filename, fig in ((f"{prefix}_{name}.pdf", fig) for name, fig in figs.items()):
        fig.savefig(filename, bbox_inches="tight", dpi=300)
        print(f"Saved {filename}")

import matplotlib.pyplot as plt

# Result files to process; each produces one PDF per figure variant and attack.
pickle_files = [
    "resultsFedAVG_CIFAR_FULL.pkl",
    "resultsFedAVG_Fashion_FULL.pkl",
    "resultsFedAVG_MNIST_FULL.pkl"
]

for pkl_file in pickle_files:
    print(f"Processing {pkl_file}...")
    frames_raw = load_frames(pkl_file)
    frames = process_fedavg_frames(frames_raw)

    # Every attack present in this file gets its own set of figures.
    attacks = sorted({f["attack"] for f in frames})

    for attack in attacks:
        print(f"  Generating plots for attack {attack}...")
        figs = plot_all_far_mdr_views(frames, attack=attack, show=False)  # don't display interactively

        # Unique prefix per file and attack, e.g. resultsFedAVG_MNIST_FULL_attack1
        prefix = f"{pkl_file.replace('.pkl','')}_attack{attack}"
        save_far_mdr_views_pdf(figs, prefix=prefix)

        # Figures built with show=False stay open; close them once saved so memory does
        # not grow across files/attacks (matplotlib warns after 20 open figures).
        for fig in figs.values():
            plt.close(fig)

print("All plots generated and saved as PDF ✅")


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from collections import defaultdict

# --- Colors and styles ---
_privacy_titles = {1: "DP-Privacy", 2: "SMPC-Privacy", 3: "SP-Privacy"}
_privacy_colors = {1: 'blue', 2: 'red', 3: 'green'}
_linestyles_dp_sp = ['-', '--', ':', '-.', (0, (3,1))]
_linestyles_smpc = ['-', '--']
_markers_dp_sp = ['o', 's', 'd', '^', 'v']
_markers_smpc = ['o', 's']

def _style_for_param(idx, privacy):
    color = _privacy_colors.get(privacy, 'black')
    if privacy in [1, 3]:
        linestyle = _linestyles_dp_sp[idx % len(_linestyles_dp_sp)]
        marker = _markers_dp_sp[idx % len(_markers_dp_sp)]
    elif privacy == 2:
        linestyle = _linestyles_smpc[idx % len(_linestyles_smpc)]
        marker = _markers_smpc[idx % len(_markers_smpc)]
    else:
        linestyle = '-'
        marker = 'o'
    return {"color": color, "linestyle": linestyle, "marker": marker}

def compute_sparse_markevery(series_len, privacy_idx, param_idx, nmarkers=16):
    """Return sorted marker indices spread evenly over a series of length ``series_len``.

    ``nmarkers`` positions cover the whole series. A phase derived from a small prime
    and the (privacy, parameter) indices shifts them, so different curves tend not to
    mark the same rounds. Short series (``<= nmarkers``) are marked at every point.
    Empty series, or ``nmarkers <= 0``, give ``[]``.
    """
    # Previously positions were offset + i*prime, which clustered every marker in the
    # first nmarkers*prime points and raised ZeroDivisionError for an empty series.
    if series_len <= 0 or nmarkers <= 0:
        return []
    if series_len <= nmarkers:
        return list(range(series_len))
    primes = [2, 3, 5, 7, 11, 13, 17]
    prime = primes[(privacy_idx + param_idx) % len(primes)]
    stride = series_len // nmarkers  # >= 1 here
    phase = (prime + privacy_idx * 3 + param_idx * 5) % stride
    # floor(i*L/n) is strictly increasing for L >= n; phase < L//n keeps indices < L.
    return [phase + (i * series_len) // nmarkers for i in range(nmarkers)]

def get_pdmm_error_series(fdict):
    """Return the run's PDMM convergence error (key 'Error') as a float array.

    A 2-D error array is averaged over its first axis (``axis=0``). Any other shape
    is returned unchanged, and a missing key yields an empty array.
    NOTE(review): the plot using this labels the x-axis "Round", and an earlier comment
    said this averages "over 300 inner steps". That holds only if axis 0 indexes inner
    steps and axis 1 indexes rounds. get_series() averages FAR/MDR over axis 1
    instead, so confirm the layout of 'ConvergenceError'.
    """
    errors = np.array(fdict.get("Error", []), dtype=float)
    return errors.mean(axis=0) if errors.ndim == 2 else errors

def label_for_privacy_param(privacy, param_val):
    """Legend label '<mechanism>, <symbol>=<value>' with the raw parameter value."""
    prefixes = {2: "SMPC, p=", 1: "DP, σ=", 3: "SP, σ="}
    prefix = prefixes.get(privacy)
    if prefix is None:
        return f"Privacy={privacy}, param={param_val}"
    return f"{prefix}{param_val}"

def plot_pdmm_error_per_privacy(frames, attack, fig_dpi=150, show=True):
    """Plot the averaged PDMM convergence error for one attack, one figure per privacy mechanism.

    There is one curve per parameter value, from the first run for that value,
    drawn as a line plus sparse scatter markers. SMPC runs with ``p == 0`` and runs
    without an 'Error' series are skipped.

    Args:
        frames: processed frames (with 'privacy', 'attack', 'param_val', 'Error').
        attack: attack id to plot.
        fig_dpi: figure resolution.
        show: call ``plt.show()`` after each figure (the previous, unconditional
            behaviour). Pass ``False`` for batch export, like the FAR/MDR plotters;
            the caller should then save and close the returned figures.

    Returns:
        dict mapping the privacy title (or ``"privacy-<id>"``) to its Figure.
    """
    figs = {}
    privacies = sorted({int(f["privacy"]) for f in frames if int(f["attack"]) == attack})
    for privacy in privacies:
        subset = [f for f in frames if int(f["privacy"]) == privacy and int(f["attack"]) == attack]
        if not subset:
            continue

        groups = defaultdict(list)
        for f in subset:
            if privacy == 2 and f["param_val"] == 0:
                continue  # skip SMPC p=0
            groups[f["param_val"]].append(f)

        fig, ax = plt.subplots(figsize=(12, 4), dpi=fig_dpi, constrained_layout=True)
        handles = []

        for idx, (pv, flist) in enumerate(sorted(groups.items())):
            fdict = flist[0]
            series = get_pdmm_error_series(fdict)
            if series.size == 0:
                continue
            style = _style_for_param(idx, privacy)
            label = label_for_privacy_param(privacy, pv)
            # NOTE(review): x is the index of the averaged series; whether that is a round
            # depends on the axis averaged in get_pdmm_error_series — confirm.
            x = np.arange(len(series))
            markevery = compute_sparse_markevery(len(series), privacy, idx)
            ax.plot(x, series, label=label, color=style["color"], linestyle=style["linestyle"], linewidth=2)
            ax.scatter(x[markevery], series[markevery], color=style["color"], marker=style["marker"], s=40)
            handle = Line2D([0], [0], color=style["color"], linestyle=style["linestyle"],
                            marker=style["marker"], linewidth=2, markersize=6, label=label)
            handles.append(handle)

        ax.set_title(f"PDMM Error — {_privacy_titles.get(privacy)}", fontsize=20)
        ax.set_xlabel("Round", fontsize=17)
        ax.set_ylabel("Error", fontsize=17)
        ax.tick_params(axis='both', which='major', labelsize=15)
        ax.grid(True)
        ax.legend(handles=handles, fontsize=15, frameon=False)

        if show:
            plt.show()
        figs[_privacy_titles.get(privacy, f"privacy-{privacy}")] = fig

    return figs
