In [None]:
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt

from LeakPro.leakpro.attacks.mia_attacks.lira import lira_vectorized
from LeakPro.leakpro.attacks.mia_attacks.rmia import rmia_vectorised, rmia_get_gtlprobs

import plotly.io as pio
import os
import optuna
import numpy as np

from optuna.storages import JournalStorage
from optuna.storages.journal import JournalFileBackend

from src.save_load import loadTargetSignals, loadShadowModelSignals, loadFbdStudy, savePlot
from src.utils import calculate_roc, rescale_logits

In [None]:
# ------------------------- #
#    Load Target Signals    #
# ------------------------- #
# Folder identifying the target model whose saved signals are loaded below.
target_folder = "wideresnet-d971f7dcdc"
(
    logits,
    in_mask,
    resc_logits,
    gtl_probs,
    metadata,
    metadata_pkl,
) = loadTargetSignals(target_folder)

# Sanity check: peek at the leading entries and at the array shapes.
print(f"resc_logits first 10: {resc_logits[:10]}, in_mask first 10: {in_mask[:10]}")
print(f"resc_logits shape: {resc_logits.shape}, in_mask shape: {in_mask.shape}")

In [None]:
# ------------------------------- #
#    Load Shadow Model Signals    #
# ------------------------------- #
# Selection flags passed to loadShadowModelSignals: only the rescaled logits,
# gtl_probs and membership masks are requested (the LiRA/RMIA cell uses these).
# NOTE(review): entries set to False are presumably returned as None — confirm in src.save_load.
what_to_load = dict(
    logits=False,
    resc_logits=True,
    gtl_probs=True,
    in_mask=True,
    metadata_pkl=False,
)
(
    sm_logits,
    sm_resc_logits,
    sm_gtl_probs,
    sm_in_mask,
    sm_metadata_pkl,
    missing_indices,
) = loadShadowModelSignals(target_folder, what_to_load)

In [None]:
# ------------------------- #
#    Calc and Plot ROC      #
# ------------------------- #
# Score every target sample with online LiRA (on rescaled logits) and online RMIA
# (on gtl_probs), using the shadow models' signals and in/out masks as reference.
lira_scores = lira_vectorized(resc_logits, sm_resc_logits, sm_in_mask, "carlini", online=True)
rmia_scores = rmia_vectorised(gtl_probs, sm_gtl_probs, sm_in_mask, online=True, use_gpu_if_available=True)

# NOTE(review): calculate_roc is unpacked as (tpr, fpr) here — confirm this matches the
# return order in src.utils (sklearn's roc_curve, for comparison, returns fpr first).
lira_tpr, lira_fpr = calculate_roc(lira_scores, in_mask, clip=True)
rmia_tpr, rmia_fpr = calculate_roc(rmia_scores, in_mask, clip=True)
print(f"lira_scores first 10: {lira_scores[:10]}")
print(f"lira_scores shape: {lira_scores.shape}")
# Target model's test accuracy; not used in this cell.
baseline_accuracy = metadata_pkl.test_result.accuracy

# ------- LiRA vs RMIA ROC ------- #
fig, ax = plt.subplots(figsize=(12, 5))
# Random-guessing diagonal (TPR == FPR), sampled log-uniformly over [1e-5, 1] so it
# is drawn smoothly on the log-log axes below.
random_fpr = np.logspace(-5, 0, 500)
random_tpr = random_fpr.copy()

ax.plot(random_fpr, random_tpr, "--", color="red", alpha=0.7, label="Random Guessing")
ax.plot(lira_fpr, lira_tpr, color="cornflowerblue", label="LiRA")
ax.plot(rmia_fpr, rmia_tpr, color="orange", label="RMIA")

# Log-log axes make the low-FPR region of the curves visible.
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlim(1e-5, 1)
ax.set_ylim(1e-5, 1)
ax.set_xlabel("FPR")
ax.set_ylabel("TPR")
ax.grid(True, alpha=0.3)
ax.set_title("Baseline ROC")
ax.legend(loc="lower right")
savePlot(fig, "ROC_baseline", target_folder, "target")

In [None]:
# --------------------------------- #
#    View best values in journal    #
# --------------------------------- #
# Optuna journal written by the hyperparameter search.
study_path = os.path.join("study", "cifar10-baseline-2-c3b5ae7b140f363a", "journal.log")
storage = JournalStorage(JournalFileBackend(study_path))

# Show every study stored in this journal with its trial count and best value.
# StudySummary.best_trial is None for a study with no completed trials, so guard it
# rather than crashing with AttributeError on `.value`.
study_summaries = optuna.get_all_study_summaries(storage=storage)
for summary in study_summaries:
    best_value = summary.best_trial.value if summary.best_trial is not None else None
    print(summary.study_name, summary.n_trials, best_value)

# Load the study and report its best trial.
# Note: study.best_trial raises ValueError if no trial in the study has completed.
study_name = "cifar10-baseline-2"
study = optuna.load_study(storage=storage, study_name=study_name)

best_trial = study.best_trial
print("Best trial:")
print(f"Value: {best_trial.value}")
print("Params:")
for key, value in best_trial.params.items():
    print(f"  {key}: {value}")