In [20]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from collections import defaultdict
from IPython.display import display

# ---------- Helper Functions ----------

def load_frames(pickle_path):
    """Read every object pickled back-to-back in *pickle_path*, in file order.

    The file is read as a sequence of independent ``pickle.dump`` records.
    Reading stops at end of file (``EOFError``). If the unpickler raises
    ``pickle.UnpicklingError`` (corrupt data; a truncated final record may
    surface as this or as ``EOFError``), reading also stops and the records
    read so far are returned, but a ``RuntimeWarning`` is emitted so the
    partial load is not silent.

    Security: ``pickle.load`` can execute arbitrary code; only use this on
    result files you produced yourself.
    """
    import warnings

    frames = []
    with open(pickle_path, "rb") as fh:
        while True:
            try:
                frames.append(pickle.load(fh))
            except EOFError:
                break  # clean end of the record stream
            except pickle.UnpicklingError as err:
                warnings.warn(
                    f"Stopped reading {pickle_path!s} after {len(frames)} frame(s): {err}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                break
    return frames

def process_fedavg_frames(frames):
    """Flatten raw result records into uniform dicts for the dashboard.

    Every output dict carries ``privacy`` and ``attack`` (copied, so a record
    missing either raises ``KeyError``), ``param_val`` and one entry per metric
    key; metrics absent from the record default to a fresh empty list.

    ``param_val`` comes from ``noise_STD`` when present, otherwise ``p``,
    otherwise ``0``.
    """
    metric_keys = (
        'FAR', 'MDR', 'Error',
        'train_loss', 'train_accuracy',
        'test_loss', 'test_accuracy',
    )

    def pick_param(raw):
        for candidate in ('noise_STD', 'p'):
            if candidate in raw:
                return raw[candidate]
        return 0

    result = []
    for raw in frames:
        record = {
            'privacy': raw['privacy'],
            'attack': raw['attack'],
            'param_val': pick_param(raw),
        }
        record.update({name: raw.get(name, []) for name in metric_keys})
        result.append(record)
    return result

def group_by_param(frames, key='param_val'):
    """Bucket *frames* by the value each one holds under *key*.

    Returns a ``defaultdict(list)`` mapping every distinct value to the list
    of frames carrying it, preserving the input order inside each bucket.
    """
    buckets = defaultdict(list)
    for record in frames:
        bucket_key = record[key]
        buckets[bucket_key].append(record)
    return buckets

def get_eval_rounds(total_rounds, num_evals):
    """Return up to *num_evals* evaluation round indices below *total_rounds*.

    Candidate rounds are scanned from 0 upward; round 0 is always taken, and
    a later round is taken once it is at least a minimum gap past the last
    taken round. The gap depends on the *candidate* round: 5 below 50, 10
    below 200, and 20 from 200 on.

    NOTE(review): because the gap is chosen from the candidate rather than
    the last taken round, the schedule runs 0, 5, ..., 45, 55, 65, ..., 195,
    215, 235, ... — confirm this matches the training loop's evaluation rounds.
    """
    def min_gap(candidate):
        if candidate < 50:
            return 5
        return 10 if candidate < 200 else 20

    picked = []
    candidate = 0
    while len(picked) < num_evals and candidate < total_rounds:
        if not picked or candidate - picked[-1] >= min_gap(candidate):
            picked.append(candidate)
        candidate += 1
    return picked[:num_evals]

def format_sigma_label(val):
    """Render the noise variance for a legend label as a LaTeX math string.

    The plots prefix this label with ``$\\sigma^2$`` while the records store
    the noise as a standard deviation (``noise_STD``), so the value shown is
    ``val**2``. (NOTE(review): assumes *val* is always a standard deviation —
    confirm for every privacy mode using this label.)

    * ``0``                       -> ``$0$``
    * variance an exact power of ten -> ``$10^{k}$`` (e.g. 0.1 -> ``$10^{-2}$``)
    * anything else               -> the variance with 2 significant digits

    A variance is never negative, so no sign is emitted for negative *val*.
    """
    if val == 0:
        return "$0$"
    # log10 of the variance, computed without squaring first so tiny values
    # do not underflow to 0 before the log is taken.
    exponent = 2 * np.log10(abs(val))
    if np.isfinite(exponent):
        nearest = int(round(exponent))
        if np.isclose(exponent, nearest, atol=1e-6):
            return f"$10^{{{nearest}}}$"
    return f"${float(val) ** 2:.2g}$"

def format_p_label(val):
    """Render an SMPC parameter *val* as a LaTeX math string for legends.

    ``0`` becomes ``$0$``; the special value ``2**61 - 1`` is spelled out
    as ``$2^{61}-1$`` instead of its long decimal form; any other value is
    wrapped verbatim as ``$<val>$``.
    """
    mersenne_61 = 2**61 - 1  # the Mersenne prime M61
    if val == 0:
        return "$0$"
    if val == mersenne_61:
        return r"$2^{61}-1$"
    return f"${val}$"

def compute_markevery(idx, series_len, nmarkers=8):
    """Pick the point indices at which line number *idx* draws markers.

    For a series no longer than *nmarkers*, every index is returned (rotated
    to start at ``2*idx mod series_len``; the rotation only changes order,
    every point is still marked). For longer series, *nmarkers* evenly
    spaced indices are returned, shifted by a phase that grows with *idx* so
    different lines place their markers at staggered positions.
    """
    count = min(nmarkers, series_len)
    if series_len > count:
        offset = int(round(idx * series_len / (2 * count)))
        base = np.linspace(0, series_len - 1, count, dtype=int)
        return ((base + offset) % series_len).tolist()
    start = (idx * 2) % series_len if series_len > 1 else 0
    return list(range(start, series_len)) + list(range(start))

# ---------- Dashboard Grid Plotting with Save Button & Legends ----------

def _dashboard_xy(metric, fdict, rnd):
    """Return ``(x, y)`` to draw *metric* of one frame, or ``None`` if nothing to draw."""
    series = fdict.get(metric, [])
    if not isinstance(series, (list, np.ndarray)) or len(series) == 0:
        return None

    if metric in ('FAR', 'MDR', 'Error'):
        # Nested data holds one inner series per round: pick round `rnd`.
        if isinstance(series[0], (list, np.ndarray)):
            if rnd < 0 or rnd >= len(series):
                return None
            y = np.array(series[rnd], dtype=float)
        else:
            y = np.array(series, dtype=float)
        if metric in ('FAR', 'MDR'):
            # Rates: clip stray values into [0, 1].
            y = np.clip(y, 0.0, 1.0)
        return np.arange(len(y)), y

    y = np.array(series, dtype=float)
    if metric in ('train_loss', 'train_accuracy'):
        return np.arange(len(y)), y

    # Test metrics: place them on the get_eval_rounds schedule, bounded by
    # len(FAR) when FAR is present (NOTE(review): assumes FAR has one entry
    # per round — confirm). len() rather than truthiness: FAR may be an ndarray.
    far = fdict.get('FAR')
    has_far = isinstance(far, (list, np.ndarray)) and len(far) > 0
    x = get_eval_rounds(len(far) if has_far else float('inf'), len(y))
    if len(x) < len(y):
        # The horizon cannot hold every evaluation: continue the same schedule
        # past it instead of stacking the surplus points on one x value.
        x = get_eval_rounds(float('inf'), len(y))
    return x, y


def plot_dashboard_grid(frames, attack, rnd, save_path=None):
    """Draw a (metric x privacy) grid of line plots for one attack and round.

    One column per distinct ``privacy`` in *frames*; one row per metric (MDR,
    FAR, consensus error, train/test loss and accuracy). Each panel has one
    line per distinct ``param_val`` (only the first frame of each group is
    plotted; ``param_val == 0`` is skipped for privacy 2 / SMPC).

    Parameters
    ----------
    frames : list of dict
        Records shaped like the output of ``process_fedavg_frames``.
    attack :
        Only frames whose ``'attack'`` equals this value are drawn.
    rnd : int
        Selects one inner series from nested FAR/MDR/Error data; flat
        FAR/MDR/Error series ignore it.
    save_path : str or path-like, optional
        If given, the figure is saved there at 300 dpi and closed; otherwise
        ``plt.show()`` is called.
    """
    privacies = sorted({f['privacy'] for f in frames})
    privacy_titles = {1: "DP-Privacy", 2: "SMPC-Privacy", 3: "SP-Privacy"}
    pretty_param = {1: r'$\sigma^2$', 2: r'$p$', 3: r'$\sigma^2$'}
    all_metrics = [
        ('MDR',            'MDR',             'PDMM iteration'),
        ('FAR',            'FAR',             'PDMM iteration'),
        ('Error',          'Consensus Error', 'PDMM iteration'),
        ('train_loss',     'Train Loss',      'Round'),
        ('train_accuracy', 'Train Accuracy',  'Round'),
        ('test_loss',      'Test Loss',       'Round'),
        ('test_accuracy',  'Test Accuracy',   'Round'),
    ]
    nrows = len(all_metrics)
    ncols = len(privacies)
    fig, axs = plt.subplots(nrows, ncols, figsize=(6*ncols, 3*nrows), squeeze=False)
    markers = ['o', 's', 'd', '^', 'v', '<', '>', 'p', '*', 'h', '+', 'x']
    # Per column: legend label -> first line drawn with it (first-seen order).
    # Collected from every row, so a parameter whose MDR series is missing
    # still gets a legend entry.
    legend_entries = [{} for _ in range(ncols)]

    for col, privacy in enumerate(privacies):
        subset = [f for f in frames if f['privacy'] == privacy and f['attack'] == attack]
        if not subset:
            for row in range(nrows):
                axs[row, col].set_visible(False)
            continue
        groups = sorted(group_by_param(subset, 'param_val').items())
        symbol = pretty_param.get(privacy, r'$\sigma^2$')
        for row, (metric, ylabel, xlabel) in enumerate(all_metrics):
            ax = axs[row, col]
            drew_any = False
            for idx, (param_val, flist) in enumerate(groups):
                # p = 0 is not plotted for SMPC-Privacy.
                if privacy == 2 and param_val == 0:
                    continue
                xy = _dashboard_xy(metric, flist[0], rnd)
                if xy is None:
                    continue
                x, y = xy
                value_text = format_p_label(param_val) if privacy == 2 else format_sigma_label(param_val)
                label = f"{symbol}={value_text}"
                handle = ax.plot(
                    x, y, label=label,
                    # Colour keyed on the parameter index (like the marker) so a
                    # parameter looks the same in every row and matches the legend.
                    color=f"C{idx % 10}",
                    marker=markers[idx % len(markers)], markersize=5,
                    markevery=compute_markevery(idx, len(y), nmarkers=8),
                    linewidth=2,
                )[0]
                legend_entries[col].setdefault(label, handle)
                drew_any = True
            if metric == 'Error' and drew_any:
                ax.set_yscale('log')
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.set_title(f"{ylabel} ({privacy_titles.get(privacy, privacy)})", fontsize=13)
            ax.grid(True)
            if metric in ('FAR', 'MDR'):
                ax.set_ylim(-0.05, 1.05)

    fig.suptitle(f"Federated Metrics Dashboard | Attack={attack} | Round={rnd}", fontsize=20)
    # Lay out the subplot grid before adding the legend axes: tight_layout only
    # manages subplot axes and warns about (then ignores) free-floating ones.
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    # One legend per column, each in its own axes placed above the grid.
    top = 1.01 + 0.11 * (nrows / 8)  # slight extra space for more rows
    width = 0.9 / ncols
    for col, privacy in enumerate(privacies):
        entries = legend_entries[col]
        if not entries:
            continue
        leg_ax = fig.add_axes([(col + 0.05) / ncols, top, width, 0.08])
        leg_ax.axis('off')
        leg_ax.legend(
            list(entries.values()), list(entries.keys()),
            loc='upper center', fontsize=13, ncol=max(1, len(entries) // 2),
            title=privacy_titles.get(privacy, f"Privacy={privacy}"), title_fontsize=14,
            frameon=False, bbox_to_anchor=(0.5, 0.5), borderaxespad=0,
        )

    if save_path:
        fig.savefig(save_path, bbox_inches='tight', dpi=300)
        print(f"Saved dashboard as {save_path}")
        plt.close(fig)
    else:
        plt.show()

# ---------- Widget with Save Button ----------

def fedavg_dashboard_widget(pickle_path):
    """Load results from *pickle_path* and display an interactive dashboard.

    Shows an attack dropdown, a round slider and a "Save Dashboard as PDF"
    button. Each control change redraws ``plot_dashboard_grid`` into an
    Output widget. The button writes ``dashboard_attack<attack>_round<rnd>.pdf``
    to the working directory.
    """
    frames = process_fedavg_frames(load_frames(pickle_path))
    if not frames:
        print("No valid frames found in the file.")
        return
    attacks = sorted({fb['attack'] for fb in frames})

    def _n_rounds(far):
        # Nested FAR (one inner series per round) -> that many rounds; a flat
        # or empty series counts as one. len() rather than truthiness because
        # FAR may be a NumPy array.
        if far is None or len(far) == 0:
            return 1
        return len(far) if isinstance(far[0], (list, np.ndarray)) else 1

    # Span the longest run in the file, not just frames[0]; the plotter skips
    # series for which the selected round is out of range.
    rounds = max(_n_rounds(f['FAR']) for f in frames)

    round_slider = widgets.IntSlider(min=0, max=rounds-1, description='Round')
    attack_dropdown = widgets.Dropdown(options=attacks, description='Attack')
    save_button = widgets.Button(description="Save Dashboard as PDF", button_style="success")
    output = widgets.Output()

    def plot_and_display(attack, rnd):
        output.clear_output(wait=True)
        with output:
            plot_dashboard_grid(frames, attack, rnd)

    def save_dashboard_clicked(_button):
        attack = attack_dropdown.value
        rnd = round_slider.value
        filename = f"dashboard_attack{attack}_round{rnd}.pdf"
        # Run inside the Output widget so the confirmation message (and any
        # traceback) is visible; output printed by widget callbacks may
        # otherwise not appear in the notebook.
        with output:
            plot_dashboard_grid(frames, attack, rnd, save_path=filename)

    def _redraw(_change):
        plot_and_display(attack_dropdown.value, round_slider.value)

    attack_dropdown.observe(_redraw, names="value")
    round_slider.observe(_redraw, names="value")
    save_button.on_click(save_dashboard_clicked)

    ui = widgets.HBox([attack_dropdown, round_slider, save_button])
    display(ui, output)
    plot_and_display(attack_dropdown.value, round_slider.value)

# --- Usage in notebook cell ---
# fedavg_dashboard_widget("resultsFedAVG.pkl")




In [21]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from ipywidgets import interact, Dropdown, IntSlider, Button, HBox, Output
from collections import defaultdict
from IPython.display import display
import pandas as pd
import ipywidgets as widgets

def load_frames(pickle_path):
    """Read every object pickled back-to-back in *pickle_path*, in file order.

    The file is read as a sequence of independent ``pickle.dump`` records.
    Reading stops at end of file (``EOFError``). If the unpickler raises
    ``pickle.UnpicklingError`` (corrupt data; a truncated final record may
    surface as this or as ``EOFError``), reading also stops and the records
    read so far are returned, but a ``RuntimeWarning`` is emitted so the
    partial load is not silent.

    Security: ``pickle.load`` can execute arbitrary code; only use this on
    result files you produced yourself.
    """
    import warnings

    frames = []
    with open(pickle_path, "rb") as fh:
        while True:
            try:
                frames.append(pickle.load(fh))
            except EOFError:
                break  # clean end of the record stream
            except pickle.UnpicklingError as err:
                warnings.warn(
                    f"Stopped reading {pickle_path!s} after {len(frames)} frame(s): {err}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                break
    return frames

def process_fedavg_frames(frames):
    """Normalise raw result records into flat dicts used by the plotting code.

    Each output dict has ``privacy`` and ``attack`` (copied; a record missing
    either raises ``KeyError``), ``param_val``, and one entry per metric key
    (``FAR``, ``MDR``, ``Error``, train/test loss and accuracy). A missing
    metric becomes ``[]``.

    ``param_val`` is the record's ``noise_STD`` if present, else its ``p``,
    else ``0``. Previously a record carrying neither key raised ``KeyError``.
    ``Error`` is included so every record has the same keys regardless of
    which plotting path consumes it.
    """
    metric_keys = (
        'FAR', 'MDR', 'Error',
        'train_loss', 'train_accuracy',
        'test_loss', 'test_accuracy',
    )
    processed = []
    for f in frames:
        if 'noise_STD' in f:
            param_val = f['noise_STD']
        elif 'p' in f:
            param_val = f['p']
        else:
            param_val = 0
        record = {'privacy': f['privacy'], 'attack': f['attack'], 'param_val': param_val}
        record.update((key, f.get(key, [])) for key in metric_keys)
        processed.append(record)
    return processed

def group_by_param(frames, key='param_val'):
    """Bucket *frames* by the value each one holds under *key*.

    Returns a ``defaultdict(list)`` mapping every distinct value to the list
    of frames carrying it, preserving the input order inside each bucket.
    """
    buckets = defaultdict(list)
    for record in frames:
        bucket_key = record[key]
        buckets[bucket_key].append(record)
    return buckets

def get_eval_rounds(total_rounds, num_evals):
    """Return up to *num_evals* evaluation round indices below *total_rounds*.

    Candidate rounds are scanned from 0 upward; round 0 is always taken, and
    a later round is taken once it is at least a minimum gap past the last
    taken round. The gap depends on the *candidate* round: 5 below 50, 10
    below 200, and 20 from 200 on.

    NOTE(review): because the gap is chosen from the candidate rather than
    the last taken round, the schedule runs 0, 5, ..., 45, 55, 65, ..., 195,
    215, 235, ... — confirm this matches the training loop's evaluation rounds.
    """
    def min_gap(candidate):
        if candidate < 50:
            return 5
        return 10 if candidate < 200 else 20

    picked = []
    candidate = 0
    while len(picked) < num_evals and candidate < total_rounds:
        if not picked or candidate - picked[-1] >= min_gap(candidate):
            picked.append(candidate)
        candidate += 1
    return picked[:num_evals]

def format_sigma_label(val):
    """Render the noise variance for a legend label as a LaTeX math string.

    The plots prefix this label with ``$\\sigma^2$`` while the records store
    the noise as a standard deviation (``noise_STD``), so the value shown is
    ``val**2``. The old code squared only in the power-of-ten branch and
    printed the raw standard deviation otherwise. (NOTE(review): assumes
    *val* is always a standard deviation — confirm for every privacy mode
    using this label.)

    * ``0``                          -> ``$0$``
    * variance an exact power of ten -> ``$10^{k}$`` (e.g. 0.1 -> ``$10^{-2}$``)
    * anything else                  -> the variance with 2 significant digits

    A variance is never negative, so no sign is emitted for negative *val*.
    """
    if val == 0:
        return "$0$"
    # log10 of the variance, computed without squaring first so tiny values
    # do not underflow to 0 before the log is taken.
    exponent = 2 * np.log10(abs(val))
    if np.isfinite(exponent):
        nearest = int(round(exponent))
        if np.isclose(exponent, nearest, atol=1e-6):
            return f"$10^{{{nearest}}}$"
    return f"${float(val) ** 2:.2g}$"

def format_p_label(val):
    """Render an SMPC parameter *val* as a LaTeX math string for legends.

    ``0`` becomes ``$0$``; the special value ``2**61 - 1`` is spelled out
    as ``$2^{61}-1$`` instead of its long decimal form; any other value is
    wrapped verbatim as ``$<val>$``.
    """
    mersenne_61 = 2**61 - 1  # the Mersenne prime M61
    if val == 0:
        return "$0$"
    if val == mersenne_61:
        return r"$2^{61}-1$"
    return f"${val}$"

def compute_markevery(idx, series_len, nmarkers=8):
    """Pick the point indices at which line number *idx* draws markers.

    For a series no longer than *nmarkers*, every index is returned (rotated
    to start at ``2*idx mod series_len``; the rotation only changes order,
    every point is still marked). For longer series, *nmarkers* evenly
    spaced indices are returned, shifted by a phase that grows with *idx* so
    different lines place their markers at staggered positions.
    """
    count = min(nmarkers, series_len)
    if series_len > count:
        offset = int(round(idx * series_len / (2 * count)))
        base = np.linspace(0, series_len - 1, count, dtype=int)
        return ((base + offset) % series_len).tolist()
    start = (idx * 2) % series_len if series_len > 1 else 0
    return list(range(start, series_len)) + list(range(start))

def _grid_series_len(value):
    """Length of a stored metric series, treating ``None`` as empty."""
    return 0 if value is None else len(value)


def _grid_test_x(fdict, n_evals):
    """Round numbers at which the *n_evals* test-metric values were recorded."""
    # Horizon: longest of the per-round series (NOTE(review): assumes FAR
    # has one entry per round when present — confirm).
    horizon = max(_grid_series_len(fdict.get(k)) for k in ('train_loss', 'train_accuracy', 'FAR'))
    x = get_eval_rounds(horizon, n_evals)
    if len(x) < n_evals:
        # Horizon unknown (0) or too short for every evaluation: continue the
        # same schedule past it so x and y keep equal length (ax.plot would
        # raise on the mismatch).
        x = get_eval_rounds(float('inf'), n_evals)
    return x


def grid_plot_all_metrics(frames, attack, rnd, fig_dpi=150):
    """Build a (metric x privacy) grid of train/test curves for one attack.

    Rows: train loss, train accuracy, test loss, test accuracy. Columns: one
    per distinct ``privacy`` in *frames*. Each panel has one line per
    distinct ``param_val`` (first frame of each group only; ``param_val == 0``
    is skipped for privacy 2 / SMPC). A legend per column is placed in its
    own axes above the grid.

    Parameters
    ----------
    frames : list of dict
        Records shaped like the output of ``process_fedavg_frames``.
    attack :
        Only frames whose ``'attack'`` equals this value are drawn.
    rnd : int
        Not read by this function; kept for signature compatibility.
    fig_dpi : int
        Figure resolution passed to ``plt.subplots``.

    Returns
    -------
    (fig, axs) with ``axs`` always 2-D, shape ``(4, n_privacies)``.
    """
    # Font settings
    title_fontsize = 20
    label_fontsize = 17
    tick_fontsize = 15
    legend_fontsize = 18
    legend_title_fontsize = 18

    metrics_rows = [
        ('train_loss', 'Round', 'Train Loss'),
        ('train_accuracy', 'Round', 'Train Accuracy'),
        ('test_loss', 'Round', 'Test Loss'),
        ('test_accuracy', 'Round', 'Test Accuracy'),
    ]
    privacies = sorted({fb['privacy'] for fb in frames})
    privacy_titles = {
        1: "DP-Privacy",
        2: "SMPC-Privacy",
        3: "SP-Privacy"
    }
    n_cols = len(privacies)
    n_rows = len(metrics_rows)
    markers = ['o', 's', 'd', '^', 'v', '<', '>', 'p', '*', 'h', '+', 'x']

    # squeeze=False keeps axs 2-D even with a single privacy column. The old
    # np.atleast_2d fix-up turned a (4,) column into shape (1, 4), so
    # axs[row, col] raised IndexError for row >= 1.
    fig, axs = plt.subplots(
        n_rows, n_cols, figsize=(5.5*n_cols, 3.3*n_rows), dpi=fig_dpi,
        constrained_layout=True, squeeze=False,
    )
    # Per column: legend label -> first line drawn with it, collected from
    # every row so parameters missing from the first row are still listed.
    legend_entries = [{} for _ in range(n_cols)]

    for col, privacy in enumerate(privacies):
        subset = [f for f in frames if f['privacy'] == privacy and f['attack'] == attack]
        if not subset:
            for row in range(n_rows):
                axs[row, col].set_visible(False)
            continue
        groups = sorted(group_by_param(subset, 'param_val').items())
        pretty_param = r'$p$' if privacy == 2 else r'$\sigma^2$'

        for row, (metric, xlab, ylab) in enumerate(metrics_rows):
            ax = axs[row, col]
            for idx, (param_val, flist) in enumerate(groups):
                # p = 0 is not plotted for SMPC-Privacy.
                if privacy == 2 and param_val == 0:
                    continue
                fdict = flist[0]
                raw = fdict.get(metric)
                if raw is None:
                    continue
                series = np.array(raw, dtype=float)
                if series.size == 0:
                    continue
                if metric in ('train_loss', 'train_accuracy'):
                    x = np.arange(len(series))
                else:  # test_loss, test_accuracy: recorded on evaluation rounds only
                    x = _grid_test_x(fdict, len(series))
                value_text = format_p_label(param_val) if privacy == 2 else format_sigma_label(param_val)
                label = f"{pretty_param}={value_text}"
                handle = ax.plot(
                    x, series, label=label,
                    # Colour keyed on the parameter index (like the marker) so a
                    # parameter looks the same in every row and matches the legend.
                    color=f"C{idx % 10}",
                    marker=markers[idx % len(markers)], markersize=7,
                    markevery=compute_markevery(idx, len(series), nmarkers=8),
                    linewidth=2.2,
                )[0]
                legend_entries[col].setdefault(label, handle)
            ax.set_xlabel(xlab, fontsize=label_fontsize)
            ax.set_ylabel(ylab, fontsize=label_fontsize)
            ax.set_title(f"{ylab} - {privacy_titles.get(privacy, privacy)}", fontsize=title_fontsize)
            ax.tick_params(axis='both', labelsize=tick_fontsize)
            ax.grid(True)

    # One legend per column, each in its own axes placed above the grid.
    for col, privacy in enumerate(privacies):
        entries = legend_entries[col]
        if not entries:
            continue
        leg_ax = fig.add_axes([(col + 0.05) / n_cols, 1.05, 0.9 / n_cols, 0.08])
        leg_ax.axis('off')
        leg_ax.legend(
            list(entries.values()), list(entries.keys()),
            loc='upper center', fontsize=legend_fontsize, ncol=max(1, len(entries) // 2),
            title=privacy_titles.get(privacy, f"Privacy={privacy}"),
            title_fontsize=legend_title_fontsize, frameon=False,
            bbox_to_anchor=(0.5, 1), borderaxespad=0,
        )
    return fig, axs

def save_grid_as_pdf(frames, attack, rnd, filename=None):
    """Render the metrics grid off-screen at 300 dpi and write it to *filename*.

    *filename* defaults to ``fedavg_grid_attack-<attack>_round-<rnd>.pdf``.
    Interactive mode is switched off while rendering and then restored to
    whatever it was before; the old code forced ``plt.ion()`` afterwards,
    and left the figure open and interactive mode off if plotting raised.
    """
    if filename is None:
        filename = f"fedavg_grid_attack-{attack}_round-{rnd}.pdf"
    was_interactive = plt.isinteractive()
    plt.ioff()
    fig = None
    try:
        fig, _ = grid_plot_all_metrics(frames, attack, rnd, fig_dpi=300)
        fig.savefig(filename, bbox_inches='tight', dpi=300)
    finally:
        if fig is not None:
            plt.close(fig)
        if was_interactive:
            plt.ion()
    print(f"Saved figure as {filename}")

def grid_plot_all_metrics_and_show(frames, attack, rnd, fig_dpi=150):
    """Build the metrics grid with ``grid_plot_all_metrics`` and call ``plt.show()``."""
    grid_plot_all_metrics(frames, attack, rnd, fig_dpi=fig_dpi)
    plt.show()

def extract_final_metrics_table(frames, attack, round_idx):
    """Tabulate the last recorded value of each train/test metric.

    Produces one row per (privacy, param_val) group among the frames whose
    ``'attack'`` equals *attack*, reading only the first frame of each group.
    Columns are ``Privacy``, ``Param`` (LaTeX label) and the four metrics.
    A metric with no data yields ``None``.

    NOTE(review): *round_idx* is not read; the last element of every series
    is used. Confirm whether a per-round lookup was intended.
    """
    metric_names = ['train_loss', 'train_accuracy', 'test_loss', 'test_accuracy']
    names = {1: "DP-Privacy", 2: "SMPC-Privacy", 3: "SP-Privacy"}

    def last_value(fdict, metric):
        values = np.array(fdict.get(metric, []), dtype=float)
        return values[-1] if values.size else None

    records = []
    for privacy in sorted({fb['privacy'] for fb in frames}):
        matching = [f for f in frames if f['privacy'] == privacy and f['attack'] == attack]
        if not matching:
            continue
        for param_val, flist in sorted(group_by_param(matching, 'param_val').items()):
            first = flist[0]
            param_text = format_p_label(param_val) if privacy == 2 else format_sigma_label(param_val)
            record = {"Privacy": names.get(privacy, str(privacy)), "Param": param_text}
            record.update((metric, last_value(first, metric)) for metric in metric_names)
            records.append(record)
    return pd.DataFrame(records, columns=["Privacy", "Param"] + metric_names)

def fedavg_grid_plotter_widget_with_save(pickle_path):
    """Load results from *pickle_path* and display the metrics-grid widget.

    Shows an attack dropdown, a round slider and a "Save as PDF" button. The
    grid is redrawn via ``ipywidgets.interactive_output`` on every control
    change. Save messages (and any error raised while saving) go to a status
    Output area below the plot.
    """
    frames = process_fedavg_frames(load_frames(pickle_path))
    if not frames:
        print("No valid frames found in the file.")
        return

    def _rounds_in(far):
        # len() rather than truthiness: FAR may be a NumPy array, whose truth
        # value is ambiguous when it has more than one element.
        return len(far) if far is not None and len(far) > 0 else 1

    attacks = sorted({fb['attack'] for fb in frames})
    n_rounds = max(_rounds_in(f['FAR']) for f in frames)

    attack_dropdown = Dropdown(options=attacks, description='Attack')
    round_slider = IntSlider(min=0, max=n_rounds-1, value=0, description='Round')
    save_button = Button(description="Save as PDF")
    status = Output()

    def plot_and_show(attack, rnd):
        # interactive_output captures this call's output and clears it on
        # each change, so no separate Output widget is needed for the plot.
        grid_plot_all_metrics_and_show(frames, attack, rnd)

    def save_button_clicked(_button):
        # Run inside an Output widget so the confirmation (or a traceback)
        # is visible; output printed by widget callbacks may otherwise not
        # appear in the notebook.
        with status:
            status.clear_output(wait=True)
            save_grid_as_pdf(frames, attack_dropdown.value, round_slider.value)

    save_button.on_click(save_button_clicked)

    ui = HBox([attack_dropdown, round_slider, save_button])
    plot_output = widgets.interactive_output(plot_and_show, {'attack': attack_dropdown, 'rnd': round_slider})
    display(ui, plot_output, status)

# Example usage in a notebook cell: build and display the grid widget for
# "resultsFedAVG.pkl" (a relative path, resolved against the current working
# directory). This runs whenever the cell is executed.
fedavg_grid_plotter_widget_with_save("resultsFedAVG.pkl")


HBox(children=(Dropdown(description='Attack', options=(0, 1, 3, 4, 5, 6), value=0), IntSlider(value=0, descrip…

Output()

Output()