In [None]:
import json
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math 
from tqdm import tqdm
from sklearn.metrics import roc_curve, roc_auc_score

from src.utils import print_yaml, rescale_logits, calculate_tau, plot_with_band, plot_bootstrap_band, pick_weighted_models, calculate_group_roc_pooled, plot_roc, plot_pareto, calculate_group_roc_pooled
from src.save_load import savePlot, loadFbdStudy, loadTargetSignals, loadShadowModelSignals, saveVisData, loadVisData
from LeakPro.leakpro.attacks.mia_attacks.lira import lira_vectorized
from LeakPro.leakpro.attacks.mia_attacks.rmia import rmia_vectorised, rmia_get_gtlprobs
from src.save_load import savePlot

In [None]:
# -------------------------- #
#   Load fbd study results   #
# -------------------------- #
# Name of the study to visualise. The third dash-separated token of the name
# ("fbd" here) selects the "<type>_studies" folder the study is stored in.
study_name = "cifar10-resnet-fbd-00c3fdac98"
study_type = study_name.split("-")[2]
study_path = os.path.join("study", study_type + "_studies")

# If `path` is removed it will attempt to load the study from "study/..."
# No placeholder values are assigned beforehand: if the load fails, the names
# stay undefined instead of silently holding empty lists of the wrong type.
# NOTE(review): `labels` is later passed to rescale_logits / rmia_get_gtlprobs
# together with logits — presumably the per-sample class labels; confirm.
metadata, fbd_trial_results, gtl_probs, resc_logits, labels = loadFbdStudy(
    study_name, metadata=True, gtl=True, logits=True, path=study_path
)
target_folder = metadata["study"]["target_folder"]

In [None]:
# -------------------------------- #
#   Load baseline target signals   #
# -------------------------------- #
# The baseline target's signals are stored in the folder named by the study metadata.
target_folder = metadata["study"]["target_folder"]

(
    baseline_logits,
    baseline_inmask,
    baseline_resc_logits,
    baseline_gtl_probs,
    baseline_metadata,
    baseline_metadata_pkl,
) = loadTargetSignals(target_folder)

# Derived signals can only be rebuilt when the raw logits were loaded.
raw_logits_available = baseline_logits is not None

# Rescaled logits: computed from the raw logits only when they were not loaded.
if raw_logits_available and baseline_resc_logits is None:
    baseline_resc_logits = rescale_logits(baseline_logits, labels)
print(f"Baseline resc_logits: {baseline_resc_logits[:10]}, shape: {baseline_resc_logits.shape}")

# GTL probabilities: computed from the raw logits only when they were not loaded.
if raw_logits_available and baseline_gtl_probs is None:
    baseline_gtl_probs = rmia_get_gtlprobs(baseline_logits, labels)
print(f"Target gtl_probs: {baseline_gtl_probs[:10]}, shape: {baseline_gtl_probs.shape}")

# Test accuracy of the baseline target, read from its pickled metadata.
baseline_accuracy = baseline_metadata_pkl.test_result.accuracy
print(f"Baseline accuracy: {baseline_accuracy}")

In [None]:
# ----------------------------- #
#   Load shadow model signals   #
# ----------------------------- #
shadow_logits, rescaled_shadow_logits, shadow_gtl_probs, shadow_inmask, sm_metadata, missing_indices = loadShadowModelSignals(target_folder)

# rescaled_shadow_logits and shadow_gtl_probs will be False if they are not loaded.
# Detect that sentinel with identity checks only: evaluating `not <ndarray>` raises
# "truth value of an array ... is ambiguous" as soon as one of the two signals was
# loaded successfully and the other was not. None is accepted as a sentinel as well.
need_gtl_probs = shadow_gtl_probs is None or shadow_gtl_probs is False
need_resc_logits = rescaled_shadow_logits is None or rescaled_shadow_logits is False

# Rescale and calc gtl for shadow models (only the signals that are missing)
if need_gtl_probs or need_resc_logits:
    N, M, C = shadow_logits.shape
    shadow_gtl_probs_list = []
    rescaled_shadow_logits_list = []
    for m in tqdm(range(M), desc="shadow models"):
        model_logits = shadow_logits[:, m, :]  # shape (N, C)

        if need_gtl_probs:
            shadow_gtl_probs_list.append(rmia_get_gtlprobs(model_logits, labels))
        if need_resc_logits:
            rescaled_shadow_logits_list.append(rescale_logits(model_logits, labels))

    # One summary line instead of one line per shadow model.
    print(f"{len(shadow_gtl_probs_list)} shadow gtl probs calculated and {len(rescaled_shadow_logits_list)} rescaled logits calculated")

    # Stack the per-model results along axis 1 so the model axis matches shadow_logits.
    if need_gtl_probs:
        shadow_gtl_probs = np.stack(shadow_gtl_probs_list, axis=1)  # shape = (N, M)
    if need_resc_logits:
        rescaled_shadow_logits = np.stack(rescaled_shadow_logits_list, axis=1)

print(f"shadow gtl probs shape: {shadow_gtl_probs.shape}, shadow resc logits shape: {rescaled_shadow_logits.shape}")

In [None]:
# ----------------------------- #
#   Audit the baseline target   #
# ----------------------------- #
# Audit the baseline target with LiRA
lira_scores = lira_vectorized(baseline_resc_logits, rescaled_shadow_logits, shadow_inmask, "carlini", online=True)
print(f"lira_scores first 10: {lira_scores[:10]}")
# Audit the baseline target with RMIA
rmia_scores = rmia_vectorised(baseline_gtl_probs, shadow_gtl_probs, shadow_inmask, online=True, use_gpu_if_available=True)
print(f"rmia_scores first 10: {rmia_scores[:10]}")

# ------------------------------------- #
#   Audit or Load the weighted target   #
# ------------------------------------- #
# Previously saved scores are reused; only missing ones are recomputed and saved.
data = loadVisData(study_name, study_path=study_path)  # If study_path is removed it will attempt to load the study from "study/..."

# Audit all weighted models with RMIA
weighted_rmia_scores = data.get('w_rmia_scores')
if weighted_rmia_scores is None:
    # Stack once so the in-memory value is the same array that gets saved.
    weighted_rmia_scores = np.stack(
        [
            rmia_vectorised(model_gtl_probs, shadow_gtl_probs, shadow_inmask, online=True, use_gpu_if_available=True)
            for model_gtl_probs in tqdm(gtl_probs, desc="gtl_probs")
        ],
        axis=0,
    )
    saveVisData(weighted_rmia_scores, "w_rmia_scores", study_name, path=study_path)

# Audit all weighted models with LiRA
weighted_lira_scores = data.get('w_lira_scores')
if weighted_lira_scores is None:
    # The loop variable must not be called `resc_logits`: that would replace the
    # study-wide array with the last model's logits and break any re-run of this cell.
    weighted_lira_scores = np.stack(
        [
            lira_vectorized(model_resc_logits, rescaled_shadow_logits, shadow_inmask, "carlini", online=True)
            for model_resc_logits in tqdm(resc_logits, desc="resc_logits")
        ],
        axis=0,
    )
    saveVisData(weighted_lira_scores, "w_lira_scores", study_name, path=study_path)

print(f"rmia_scores_weighted count: {len(weighted_rmia_scores)}")
print(f"lira_scores_weighted count: {len(weighted_lira_scores)}")

In [None]:
# ----------------- #
#   Calculate tau   #
# ----------------- #
# False-positive rates at which calculate_tau is evaluated; the suffixes
# _1/_2/_3 used below refer to fpr1/fpr2/fpr3 respectively.
fpr1 = 0.1
fpr2 = 0.01
fpr3 = 0.001
fprs = (fpr1, fpr2, fpr3)

# Baseline: one tau per attack and FPR
tau_baseline_rmia_1, tau_baseline_rmia_2, tau_baseline_rmia_3 = [
    calculate_tau(rmia_scores, baseline_inmask, fpr) for fpr in fprs
]
tau_baseline_lira_1, tau_baseline_lira_2, tau_baseline_lira_3 = [
    calculate_tau(lira_scores, baseline_inmask, fpr) for fpr in fprs
]

print(f"baseline tau rmia_1: {tau_baseline_rmia_1} at fpr: {fpr1}")
print(f"baseline tau lira_1: {tau_baseline_lira_1} at fpr: {fpr1}")

# Weighted models: one list per attack and FPR, with one entry per weighted model.
# (The "weigted" spelling is kept because later cells read these names.)
weigted_taus_rmia_1, weigted_taus_rmia_2, weigted_taus_rmia_3 = [], [], []
weigted_taus_lira_1, weigted_taus_lira_2, weigted_taus_lira_3 = [], [], []

# --- RMIA ---
rmia_tau_lists = (weigted_taus_rmia_1, weigted_taus_rmia_2, weigted_taus_rmia_3)
for w_rmia_score in tqdm(weighted_rmia_scores, desc="w_rmia_scores"):
    for taus, fpr in zip(rmia_tau_lists, fprs):
        taus.append(calculate_tau(w_rmia_score, baseline_inmask, fpr))
print(f"n rmia taus: {weigted_taus_rmia_1[:5]}, {weigted_taus_rmia_2[:5]}, {weigted_taus_rmia_3[:5]}")

# --- LIRA ---
lira_tau_lists = (weigted_taus_lira_1, weigted_taus_lira_2, weigted_taus_lira_3)
for w_lira_score in tqdm(weighted_lira_scores, desc="w_lira_scores"):
    for taus, fpr in zip(lira_tau_lists, fprs):
        taus.append(calculate_tau(w_lira_score, baseline_inmask, fpr))
print(f"n lira taus: {weigted_taus_lira_1[:5]}, {weigted_taus_lira_2[:5]}, {weigted_taus_lira_3[:5]}")

# Study outputs: per-trial values recorded by the FbD study itself
accuracies = [res.accuracy for res in fbd_trial_results]
noises = [res.noise for res in fbd_trial_results]
centralities = [res.centrality for res in fbd_trial_results]
temperatures = [res.temperature for res in fbd_trial_results]
tau_rmia = [res.tau for res in fbd_trial_results]   # tau in this context is log(tauc_fbd@0.1/tauc_ref@0.1)
print(f"study tau: {tau_rmia[:5]}")

In [None]:
# ----------------- #
#   Scatter plots   #
# ----------------- #
# 2x2 grid: the study objective (top-left) and tau at the three FPR levels,
# each plotted against accuracy.
fig, axes = plt.subplots(2, 2, figsize=(12, 5))

# --- Top-left: study objective (FbD) --- #
ax = axes[0, 0]
ax.scatter(tau_rmia, accuracies, label="FbD study", color="purple")
ax.set_xlabel("Study Objective: log(tauc_fbd@0.1 / tauc_ref@0.1)")
ax.set_ylabel("Accuracy")
ax.grid(True, alpha=0.3)
ax.legend(loc="lower right")

# --- Remaining panels: RMIA + LIRA tau at one FPR each --- #
# (axis, x-label, FbD RMIA taus, FbD LiRA taus, baseline RMIA tau, baseline LiRA tau)
tau_panels = (
    (axes[0, 1], "τ@0.1FPR", weigted_taus_rmia_1, weigted_taus_lira_1, tau_baseline_rmia_1, tau_baseline_lira_1),
    (axes[1, 0], "τ@0.01FPR", weigted_taus_rmia_2, weigted_taus_lira_2, tau_baseline_rmia_2, tau_baseline_lira_2),
    (axes[1, 1], "τ@0.001FPR", weigted_taus_rmia_3, weigted_taus_lira_3, tau_baseline_rmia_3, tau_baseline_lira_3),
)
for ax, x_label, fbd_rmia, fbd_lira, base_rmia, base_lira in tau_panels:
    ax.scatter(fbd_rmia, accuracies, label="RMIA on FbD", color="cornflowerblue", marker='o')
    ax.scatter(fbd_lira, accuracies, label="LIRA on FbD", color="orange", marker='x')
    ax.scatter(base_rmia, baseline_accuracy, label="RMIA on Baseline", color="red", marker='o')
    ax.scatter(base_lira, baseline_accuracy, label="LIRA on Baseline", color="green", marker='x')
    ax.set_xlabel(x_label)
    ax.set_ylabel("Accuracy")
    ax.grid(True, alpha=0.3)
    ax.legend(loc="lower right")

plt.tight_layout()
plt.show()

savePlot(fig, "Scatter_plots", study_name, study_path)

In [None]:
# -------------------- #
#   Create Dataframe   #
# -------------------- #
# One row per weighted model / trial; `study` is a 1-based running number and
# `model_idx` the 0-based row position.

# RMIA / LiRA taus at different FPRs
tau_columns = {
    "tau_0.1_rmia": weigted_taus_rmia_1,
    "tau_0.1_lira": weigted_taus_lira_1,
    "tau_0.01_rmia": weigted_taus_rmia_2,
    "tau_0.01_lira": weigted_taus_lira_2,
    "tau_0.001_rmia": weigted_taus_rmia_3,
    "tau_0.001_lira": weigted_taus_lira_3,
}

# Hyperparameters
hyperparameter_columns = {
    "noise": noises,
    "centrality": centralities,
    "temperature": temperatures,
}

df = (
    pd.DataFrame({
        "accuracy": accuracies,
        "study": np.arange(1, len(noises) + 1),
        **tau_columns,
        **hyperparameter_columns,
    })
    .reset_index(drop=True)
    .assign(model_idx=lambda frame: frame.index)
)
df.head(10)

In [None]:
# --------------------------- #
#   Create FbD Param Groups   #
# --------------------------- #
# Group the trials by their (noise, centrality, temperature) setting.
# The keys are rounded to 3 decimals BEFORE grouping: rounding only afterwards
# would leave settings that differ by floating-point noise in separate groups
# and then collapse them onto duplicate index entries, which makes a later
# `fbd_grps.loc[(noise, centrality, temperature)]` return several rows.
param_cols = ["noise", "centrality", "temperature"]
group_keys = [df[col].round(3).rename(col) for col in param_cols]

fbd_grps = (
    df.groupby(group_keys)
    .agg(
        accuracy_mean=("accuracy", "mean"),
        accuracy_std=("accuracy", "std"),

        # Statistics of the un-rounded hyperparameter values inside each group
        noise_mean=("noise", "mean"),
        noise_std=("noise", "std"),

        centrality_mean=("centrality", "mean"),
        centrality_std=("centrality", "std"),

        temperature_mean=("temperature", "mean"),
        temperature_std=("temperature", "std"),
        # Row positions (df["model_idx"]) of the group's members
        model_indices=("model_idx", list),
        n=("study", "size"),
    )
)

print(fbd_grps.shape)
# Keep only settings that were evaluated at least `group_threshold` times.
group_threshold = 7
fbd_grps = fbd_grps[fbd_grps["n"]>=group_threshold]
print(fbd_grps.shape)
fbd_grps

In [None]:
# ---------------------------------- #
#   Calculate FbD Param groups ROC   #
# ---------------------------------- #
# Calculate the baseline curves
baseline_fpr_curve_lira, baseline_tpr_curve_lira, _ = roc_curve(baseline_inmask, lira_scores)
baseline_fpr_curve_rmia, baseline_tpr_curve_rmia, _ = roc_curve(baseline_inmask, rmia_scores)
# ------------------------- Select groups ------------------------- #
# (noise, centrality, temperature) settings to compare. Each must be present
# in fbd_grps (i.e. have survived the group-size filter) or .loc raises KeyError.
selected_params = [
    (0.000, 0.40, 0.15),
    (0.000, 0.40, 0.00),
    (0.030, 0.40, 0.00),
]
groups = [fbd_grps.loc[params] for params in selected_params]
# ---------------------- Collect group scores ---------------------- #
# LiRA lists are built from the LiRA scores and RMIA lists from the RMIA scores
# (the two sources were previously swapped).
grp_lira_scores_list = [[weighted_lira_scores[i] for i in grp["model_indices"]] for grp in groups]
grp_rmia_scores_list = [[weighted_rmia_scores[i] for i in grp["model_indices"]] for grp in groups]
# ------------------- Calculate pooled group ROCs ------------------- #
# LiRA & RMIA scores
lira_fpr_curves, lira_tpr_curves = [], []
rmia_fpr_curves, rmia_tpr_curves = [], []
for l_scores, r_scores in zip(grp_lira_scores_list, grp_rmia_scores_list):
    l_fpr, l_tpr = calculate_group_roc_pooled(l_scores, baseline_inmask)
    lira_fpr_curves.append(l_fpr)
    lira_tpr_curves.append(l_tpr)

    r_fpr, r_tpr = calculate_group_roc_pooled(r_scores, baseline_inmask)
    rmia_fpr_curves.append(r_fpr)
    rmia_tpr_curves.append(r_tpr)

# --------------- Create Plots --------------- #
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Legend entries get their own name: `labels` holds the values returned by
# loadFbdStudy that the signal-loading cells pass to rescale_logits /
# rmia_get_gtlprobs, so it must not be overwritten here.
roc_labels = []
for grp in groups:
    noise, centrality, temperature = grp.name
    roc_labels.append(f"FbD: σ={noise:.3f}, c={centrality:.1f}, t={temperature:.2f}, ā={grp['accuracy_mean']:.2f}")

# --------------- LiRA Plot --------------- #
plot_roc(ax=axes[0], fpr_curves=lira_fpr_curves, tpr_curves=lira_tpr_curves, labels=roc_labels,
         b_fpr=baseline_fpr_curve_lira, b_tpr=baseline_tpr_curve_lira, b_label=f"Baseline: a={baseline_accuracy:.2f}",
         title="LiRA ROC Curves")
# --------------- RMIA Plot --------------- #
# The RMIA panel is compared against the RMIA baseline curve (was the LiRA one).
plot_roc(ax=axes[1], fpr_curves=rmia_fpr_curves, tpr_curves=rmia_tpr_curves, labels=roc_labels,
         b_fpr=baseline_fpr_curve_rmia, b_tpr=baseline_tpr_curve_rmia, b_label=f"Baseline: a={baseline_accuracy:.2f}",
         title="RMIA ROC Curves")
plt.show()

group_title = "roc_fbd_groups"
plot_path = os.path.join(os.path.join(study_path, study_name), "plot")
save_path = os.path.join(plot_path, group_title + ".tex")
# Make sure the target folder exists before the .tex file is opened for writing.
os.makedirs(plot_path, exist_ok=True)

fbd_grps_rounded = fbd_grps.round({
    "accuracy_mean": 3,
    "accuracy_std": 4,
    "noise_mean": 4,
    "noise_std": 4,
    "centrality_mean": 3,
    "centrality_std": 3,
    "temperature_mean": 3,
    "temperature_std": 3
})
# Convert to LaTeX table
# The column spec is derived from the frame: one left-aligned column per index
# level and one centred column per data column. A fixed spec with fewer columns
# than the table has makes LaTeX fail with "Extra alignment tab".
latex_column_format = "l" * fbd_grps_rounded.index.nlevels + "c" * len(fbd_grps_rounded.columns)
latex_table = fbd_grps_rounded.to_latex(
    index=True,             # keeps (noise, centrality, temperature) as the leading columns
    caption="FbD summary per hyperparameter group (noise, centrality, temperature)",
    label="tab:fbd_summary",
    float_format="%.3f",    # ensures consistent decimal places
    column_format=latex_column_format
)
# Save to file
with open(save_path, "w") as f:
    f.write(latex_table)

savePlot(fig, group_title, study_name, study_path)

In [None]:
# ------------------------------------ #
#   Create & Calculate Pareto Groups   #
# ------------------------------------ #
def _summarise_param_groups(trials, fpr_tag):
    """Aggregate the trials per (noise, centrality, temperature) setting.

    Args:
        trials: DataFrame with the columns "accuracy", "study", the three
            hyperparameter columns and "tau_<fpr_tag>_rmia" / "tau_<fpr_tag>_lira".
        fpr_tag: FPR part of the tau column names, e.g. "0.1", "0.01" or "0.001".

    Returns:
        DataFrame with one row per setting holding mean/std of the accuracy,
        mean/std/median of the RMIA and LiRA tau, and the group size `n`.
    """
    rmia_col = f"tau_{fpr_tag}_rmia"
    lira_col = f"tau_{fpr_tag}_lira"
    return (
        trials.groupby(["noise", "centrality", "temperature"])
        .agg(
            accuracy_mean=("accuracy", "mean"),
            accuracy_std=("accuracy", "std"),

            rmia_mean=(rmia_col, "mean"),
            rmia_std=(rmia_col, "std"),
            rmia_median=(rmia_col, "median"),

            lira_mean=(lira_col, "mean"),
            lira_std=(lira_col, "std"),
            lira_median=(lira_col, "median"),

            n=("study", "size")
        )
        .reset_index()
    )


def _pareto_front(param_groups, x_col, y_col="accuracy_mean"):
    """Return the non-dominated rows of `param_groups`, sorted by `x_col`.

    Row j dominates row i if j has lower/equal x and higher/equal y, and is
    strictly better in at least one of the two objectives.
    """
    xy = param_groups[[x_col, y_col]].to_numpy()
    x, y = xy[:, 0], xy[:, 1]
    # Pairwise comparison: axis 0 is the candidate i, axis 1 the potential dominator j.
    no_worse = (x[None, :] <= x[:, None]) & (y[None, :] >= y[:, None])
    strictly_better = (x[None, :] < x[:, None]) | (y[None, :] > y[:, None])
    is_dominated = (no_worse & strictly_better).any(axis=1)
    return param_groups.loc[~is_dominated].sort_values(x_col)


# Minimum number of trials a setting needs to take part in the Pareto analysis.
pareto_min_group_size = 3

group_001 = _summarise_param_groups(df, "0.001")
group_01 = _summarise_param_groups(df, "0.01")
group_1 = _summarise_param_groups(df, "0.1")

# Ordered by FPR: 0.1, 0.01, 0.001
grps = [group_1, group_01, group_001]
cut_group = [grp[grp["n"] >= pareto_min_group_size] for grp in grps]

# For every FPR level: [LiRA frontier, RMIA frontier]
paretos = [
    [_pareto_front(c_grp, "lira_mean"), _pareto_front(c_grp, "rmia_mean")]
    for c_grp in cut_group
]

In [None]:
# --------------------------- #
#  Group Pareto Visualization #
# --------------------------- #
# Left panel: groups / frontiers at FPR 0.1, right panel: at FPR 0.01
# (cut_group and paretos are ordered 0.1, 0.01, 0.001).
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
plot_pareto(ax=axes[0], cut_group=cut_group[0], paretos=paretos[0], grp_fpr=0.1)
plot_pareto(ax=axes[1], cut_group=cut_group[1], paretos=paretos[1], grp_fpr=0.01)
# `fig` from plt.subplots is used directly, and the figure is shown with
# plt.show() like the other plotting cells (Figure.show() warns on non-GUI backends).
plt.show()
savePlot(fig, "pareto_frontiers", study_name, study_path)