In [None]:
import json
import os
import numpy as np
import matplotlib.pyplot as plt

from src.utils import print_yaml, bootstrap_sampling, interpolate_unique, rescale_logits, calculate_tau
from src.save_load import loadAudit, savePlot, loadFbdStudy, loadTargetSignals, loadShadowModelSignals
from LeakPro.leakpro.attacks.mia_attacks.lira import lira_vectorized
from LeakPro.leakpro.attacks.mia_attacks.rmia import rmia_vectorised, rmia_get_gtlprobs

In [None]:
# -------------------------- #
#   Load fbd study results   #
# -------------------------- #
study_name = "cifar10-resnet-fbd-815409b641"
# NOTE(review): study_path is not used in the cells shown here — confirm a later cell needs it.
study_path = os.path.join("study", study_name)

# A single loadFbdStudy call provides everything this notebook uses from the study:
# the study metadata (indexed as metadata["study"][...] below), the per-trial results,
# the per-trial GTL probabilities and rescaled logits, and the dataset labels.
# (No placeholder initializations are needed: every name is bound by this call.)
metadata, fbd_trial_results, gtl_probs, resc_logits, labels = loadFbdStudy(
    study_name, metadata=True, gtl=True, logits=True
)
target_folder = metadata["study"]["target_folder"]

In [None]:
# -------------------------------- #
#   Load baseline target signals   #
# -------------------------------- #
# Re-derive the folder from the study metadata so this cell does not rely on the
# previous cell's `target_folder` binding.
target_folder = metadata["study"]["target_folder"]
(
    baseline_target_logits,
    baseline_target_inmask,
    baseline_target_metadata,
) = loadTargetSignals(target_folder)

# The baseline target's logits rescaled with the study labels; preview the first 10 rows.
baseline_resc_logits = rescale_logits(baseline_target_logits, labels)
print("Target resc_logits: {}, shape: {}".format(baseline_resc_logits[:10], baseline_resc_logits.shape))

# GTL probabilities (input to the RMIA audit below) for the same logits and labels.
target_gtl_probs = rmia_get_gtlprobs(baseline_target_logits, labels)
print("Target gtl_probs: {}, shape: {}".format(target_gtl_probs[:10], target_gtl_probs.shape))

In [None]:
# ----------------------------- #
#   Load shadow model signals   #
# ----------------------------- #
shadow_logits, shadow_inmask = loadShadowModelSignals(target_folder)

# Calculate the GTL probabilities for every shadow model (no rescaling is done here).
# shadow_logits is unpacked as (N, M, C) and sliced per model along axis 1;
# presumably N samples, M shadow models, C classes — TODO confirm against the loader.
N, M, C = shadow_logits.shape
if M == 0:
    # np.stack on an empty list fails with an opaque error; say what is actually wrong.
    raise ValueError(f"No shadow models found in {target_folder!r}; RMIA needs at least one.")

shadow_gtl_probs_list = [
    rmia_get_gtlprobs(shadow_logits[:, m, :], labels)  # per-model slice, shape (N, C)
    for m in range(M)
]
print(f"{len(shadow_gtl_probs_list)} shadow gtl probs calculated")

shadow_gtl_probs = np.stack(shadow_gtl_probs_list, axis=1)  # shape = (N, M)

In [None]:
# ----------------------------- #
#   Audit the baseline target   #
# ----------------------------- #
# One shared RMIA configuration, so the baseline and weighted audits are scored
# identically and cannot silently drift apart.
RMIA_KWARGS = dict(online=True, use_gpu_if_available=True)

# TODO: audit the baseline target with LiRA as well (lira_vectorized is imported
# but not wired up yet).
# Audit the baseline target with RMIA
rmia_scores = rmia_vectorised(target_gtl_probs, shadow_gtl_probs, shadow_inmask, **RMIA_KWARGS)
print(f"rmia_scores first 10: {rmia_scores[:10]}")

# ----------------------------- #
#   Audit the weighted target   #
# ----------------------------- #
# Audit every weighted model with RMIA, against the same shadow models as the baseline.
weighted_rmia_scores = [
    rmia_vectorised(gtl, shadow_gtl_probs, shadow_inmask, **RMIA_KWARGS)
    for gtl in gtl_probs
]
print(f"rmia_scores_weighted count: {len(weighted_rmia_scores)}")

In [None]:
# ----------------- #
#   Calculate tau   #
# ----------------- #
# False-positive rate at which tau is evaluated for the baseline and every weighted model.
fpr = 0.1

# Baseline
tau_baseline = calculate_tau(rmia_scores, baseline_target_inmask, fpr)
print(f"baseline tau: {tau_baseline} at fpr: {fpr}")

# Weighted: each weighted model is scored against the *baseline* target's membership
# mask. NOTE(review): this assumes the FbD trials share the baseline's training set — confirm.
weighted_taus = []
for w_rmia_score in weighted_rmia_scores:
    tau_weighted = calculate_tau(w_rmia_score, baseline_target_inmask, fpr)
    weighted_taus.append(tau_weighted)
    print(f"weighted tau: {tau_weighted} at fpr: {fpr}")

# Backward-compatible alias for the original (misspelled) name used by later cells.
weigted_taus = weighted_taus

In [None]:
# --------------------------------- #
#   Visualize Model Study Results   #
# --------------------------------- #
# One accuracy value and one study-reported tau per FbD trial.
accuracy = [res.accuracy for res in fbd_trial_results]
tau_rmia = [res.tau for res in fbd_trial_results]   # tau in this context is log(tauc_fbd@0.1/tauc_ref@0.1)

# The right panel pairs weigted_taus[i] with accuracy[i]. This assumes gtl_probs (and
# therefore weigted_taus) is ordered like fbd_trial_results — TODO confirm. Fail early with a
# clear message rather than matplotlib's generic "x and y must be the same size".
if len(weigted_taus) != len(accuracy):
    raise ValueError(
        f"Got {len(weigted_taus)} weighted taus but {len(accuracy)} trial results; "
        "they must align one-to-one to be plotted together."
    )

# Plot
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# --- Left plot: tau as reported by the FbD study --- #
ax = axes[0]
ax.scatter(tau_rmia, accuracy)
ax.set_xlabel("tau@0.1 (log(tauc_fbd@0.1 / tauc_ref@0.1))")
ax.set_ylabel("Accuracy")
ax.set_title("FbD study")
ax.grid(True, alpha=0.3)

# --- Right plot: true tau = log(tpr@fpr / fpr), computed above at `fpr` --- #
ax2 = axes[1]
ax2.scatter(weigted_taus, accuracy, marker='o')
# x holds the taus and y holds accuracy, so the tau definition belongs on the x label.
# The labels come from `fpr` so they always match the rate actually used.
ax2.set_xlabel(f"τ@{fpr} (log(TPR@{fpr} / {fpr}))")
ax2.set_ylabel("Accuracy")
ax2.set_title(f"True τ at FPR={fpr}")
ax2.grid(True, alpha=0.3)
plt.setp(ax2.get_xticklabels(), rotation=45, ha="right")

plt.tight_layout()
plt.show()

In [None]:
# -------------------------- #
#   Visualize LiRA results   #
# -------------------------- #