In [None]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from imblearn.pipeline import Pipeline
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import StandardScaler
from opacus import PrivacyEngine
import os
import shap
import torch.nn.functional as F
from collections import OrderedDict

# --- Test-case configuration: every artifact below is written to output_dir
# with FILE_SUFFIX appended, so runs of different test cases do not collide.
output_dir = "test_cases/"
FILE_SUFFIX = "_DP_TC3_SGD"
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory '{output_dir}' is ready.")
print(f"Running Test Case 3 (Optimizer = SGD) with FILE_SUFFIX='{FILE_SUFFIX}'")

# --- Load the dataset; drop identifier columns and split off the 3-class target.
df = pd.read_csv("../datasets/diabetic_data.csv")
target_col = "readmitted"
X = df.drop(columns=["encounter_id", "patient_nbr", target_col])
y = df[target_col]
# errors='ignore': any of these columns that is absent is skipped instead of raising.
X.drop(columns=['diag_1', 'diag_2', 'diag_3', 'medical_specialty', 'citoglipton', 'glimepiride-pioglitazone'], inplace=True, errors='ignore')

# --- Label-encode every object column so the resamplers see numeric input; the
# fitted encoders are kept to map the codes back after resampling.
categorical_cols = X.select_dtypes(include=["object"]).columns.tolist()
for col in categorical_cols: X[col] = X[col].astype(str)
X_encoded = X.copy()
encoders = {}
for col in categorical_cols:
    le = LabelEncoder()
    X_encoded[col] = le.fit_transform(X_encoded[col])
    encoders[col] = le
print("Original class distribution:")
print(y.value_counts())
# --- Rebalance: undersample 'NO' down to, and SMOTE-oversample '<30' up to,
# target_size_after_undersampling rows; '>30' is not named in either strategy.
# NOTE(review): 27432 looks tied to the original class counts printed above — confirm.
target_size_after_undersampling = 27432
under_strategy = {'NO': target_size_after_undersampling}
over_strategy = {"<30": target_size_after_undersampling}
under = RandomUnderSampler(sampling_strategy=under_strategy, random_state=42)
over = SMOTE(sampling_strategy=over_strategy, random_state=42, k_neighbors=5)
pipeline = Pipeline([("under", under), ("over", over)])
# NOTE(review): resampling runs on the full dataset *before* train_test_split below,
# so synthetic '<30' rows in the test split are interpolated from rows that may sit
# in the training split. That inflates test accuracy and blurs the member/non-member
# boundary the MIA later relies on. Resampling only the training split would avoid it.
X_resampled_num, y_resampled = pipeline.fit_resample(X_encoded, y)
X_resampled_decoded = X_resampled_num.copy()
# Map label codes back to category strings. NOTE(review): SMOTE interpolates the
# integer codes, and astype(int) truncates them, so synthetic rows get categories
# that are artifacts of the arbitrary code order (SMOTENC handles this properly).
for col, le in encoders.items():
    X_resampled_decoded[col] = le.inverse_transform(X_resampled_num[col].astype(int))
X_resampled_ohe = pd.get_dummies(X_resampled_decoded, drop_first=True)
# class_names is ordered to match the integer codes in target_mapping (0, 1, 2).
target_mapping = {'<30': 0, '>30': 1, 'NO': 2}
class_names = ['<30', '>30', 'NO']
y_resampled_encoded = y_resampled.map(target_mapping)
feature_names = X_resampled_ohe.columns.tolist()

print("\nNew class distribution (encoded):")
print(y_resampled_encoded.value_counts())

# --- Stratified 80/20 split; the scaler is fit on the training split only.
X_train, X_test, y_train, y_test = train_test_split(
    X_resampled_ohe, y_resampled_encoded, test_size=0.2, random_state=42, stratify=y_resampled_encoded
)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
X_train_tensor = torch.tensor(X_train_scaled, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train.values, dtype=torch.long)
X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test.values, dtype=torch.long)
BATCH_SIZE = 128
train_ds = TensorDataset(X_train_tensor, y_train_tensor)
# This loader is later rebound by privacy_engine.make_private for DP training.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_ds = TensorDataset(X_test_tensor, y_test_tensor)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_dim = X_train_tensor.shape[1]
num_classes = len(y_resampled_encoded.unique())

class MulticlassNN(nn.Module):
    """Two-hidden-layer MLP classifier (64 -> 32 units, ReLU + Dropout(0.5)).

    The Linear layers sit at indices 0, 3 and 6 of ``self.net``. Saved
    state dicts (``net.0.*``, ``net.3.*``, ``net.6.*``) depend on that layout,
    so the layer order must not change.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        widths = [input_dim, 64, 32]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.5)]
        layers.append(nn.Linear(widths[-1], num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for the batch ``x``."""
        logits = self.net(x)
        return logits

print("\n--- Training DP-SGD Target Model (Opacus) ---")
# Privacy / optimisation settings handed to Opacus make_private below.
target_delta = 1e-5
max_grad_norm = 1.0
NOISE_MULTIPLIER = 0.9
dp_model = MulticlassNN(input_dim, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(dp_model.parameters(), lr=1e-3)

epochs = 50
privacy_engine = PrivacyEngine()
# make_private returns wrapped versions of the model, optimizer and loader. All
# three names are rebound so the loop below trains with the private versions.
dp_model, optimizer, train_loader = privacy_engine.make_private(
    module=dp_model,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=NOISE_MULTIPLIER,
    max_grad_norm=max_grad_norm,
)

# NaN means "epsilon unknown". A failed accounting step must never surface as a
# plausible-looking budget (the old 99.0 sentinel was reported as a real epsilon).
final_epsilon = float("nan")
for epoch in range(1, epochs + 1):
    dp_model.train()
    total_loss = 0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        outputs = dp_model(xb)
        loss = criterion(outputs, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    try:
        epsilon = privacy_engine.get_epsilon(target_delta)
    except Exception as exc:  # best-effort: keep training, but report the failure
        print(f"WARNING: privacy accounting failed at epoch {epoch}: {exc!r}")
        epsilon = float("nan")
    final_epsilon = epsilon

    if epoch % 10 == 0 or epoch == epochs:
        print(f"Epoch {epoch}/{epochs}, Loss: {total_loss/len(train_loader):.4f}, Epsilon: {epsilon:.2f} (Delta={target_delta:.0e})")

print("\n--- Evaluating DP-SGD Target Model Utility ---")
# The Opacus wrapper keeps the original network as `_module`; that plain module
# is what gets evaluated here and saved below.
eval_model = dp_model._module
eval_model.eval()

all_preds, all_true = [], []
with torch.no_grad():
    for batch_x, batch_y in test_loader:
        batch_scores = eval_model(batch_x.to(device))
        all_preds.extend(batch_scores.argmax(dim=1).cpu().numpy())
        all_true.extend(batch_y.cpu().numpy())

accuracy = accuracy_score(all_true, all_preds)
report = classification_report(all_true, all_preds, target_names=class_names)
cm = confusion_matrix(all_true, all_preds)

rule = "---" * 15
print("\n" + rule)
print("### Model Evaluation Metrics (DP-SGD Teacher) ###")
print(f"\n**Accuracy:** {accuracy:.4f}")
print("\n**Confusion Matrix:**")
print(cm)
print("\n**Classification Report (Precision, Recall, F1-Score):**")
print(report)
print("\n**Privacy Metrics:**")
print(f"  - Noise Multiplier: {NOISE_MULTIPLIER}")
print(f"  - Final Epsilon:    {final_epsilon:.4f}")
print(f"  - Delta:            {target_delta}")
print(rule + "\n")

# Persist the unwrapped weights (keys net.0.*, net.3.*, net.6.*).
DP_MODEL_PATH = os.path.join(output_dir, f"dp_target_model{FILE_SUFFIX}.pth")
torch.save(eval_model.state_dict(), DP_MODEL_PATH)
print(f"\nTrained DP model saved to {DP_MODEL_PATH}")

print("\nStep 5: Generating SHAP explanations for DP-SGD model.")

# Explain a fresh copy of the trained weights rather than the Opacus wrapper
# itself (presumably so the wrapper's per-sample-gradient machinery cannot
# interfere with DeepExplainer).
dp_model_unwrapped = dp_model._module
clean_model = MulticlassNN(input_dim, num_classes).to(device)
clean_model.load_state_dict(dp_model_unwrapped.state_dict())
clean_model.eval()

background_size = min(100, len(X_train_tensor))
background_data = X_train_tensor[np.random.choice(len(X_train_tensor), background_size, replace=False)].to(device)
X_explain = X_test_tensor[:20].to(device)

explainer = shap.DeepExplainer(clean_model, background_data)
shap_values_list = explainer.shap_values(X_explain)

# Normalise to (n_samples, n_features, n_classes). Depending on the SHAP
# version, multi-output explanations come back either as a list with one
# (n_samples, n_features) array per class, or as one array with the class axis
# last. Averaging the list layout over axes (0, 2) would collapse the feature
# axis and leave per-sample values instead of per-feature importances.
if isinstance(shap_values_list, list):
    shap_array = np.stack([np.asarray(v) for v in shap_values_list], axis=-1)
else:
    shap_array = np.asarray(shap_values_list)
mean_abs_shap = np.abs(shap_array).mean(axis=(0, 2))
if mean_abs_shap.shape[0] != len(feature_names):
    raise ValueError(
        f"SHAP produced {mean_abs_shap.shape[0]} importances for {len(feature_names)} features"
    )
feature_importance = dict(zip(feature_names, mean_abs_shap))

print("\n### SHAP Feature Importance (DP-SGD Teacher) ###")
print("\nAverage absolute SHAP values (feature importance for DP-SGD Model):")
sorted_importance = sorted(feature_importance.items(), key=lambda item: item[1], reverse=True)
for feature, value in sorted_importance[:10]:
    print(f"{feature}: {value:.4f}")

print("\nSaving SHAP and attack results.")

shap_df = pd.DataFrame({
    'feature': feature_names,
    'mean_abs_shap': mean_abs_shap
}).sort_values('mean_abs_shap', ascending=False)

shap_csv_path = os.path.join(output_dir, f'shap_feature_importance_dpsgd{FILE_SUFFIX}.csv')
shap_df.to_csv(shap_csv_path, index=False)
print(f"Saved SHAP feature importance to '{shap_csv_path}'.")

print("\n--- Starting MIA Attack on DP-SGD Model ---")

DP_MODEL_PATH = os.path.join(output_dir, f"dp_target_model{FILE_SUFFIX}.pth")
# Reload the persisted teacher so the attack targets exactly what was saved.
# Previously a load failure was only printed: the freshly constructed (randomly
# initialised) network stayed bound to `dp_target_model` and the MIA silently ran
# against it. Stop instead, matching the evaluation step's error handling.
dp_target_model = MulticlassNN(input_dim, num_classes).to(device)
try:
    dp_target_model.load_state_dict(torch.load(DP_MODEL_PATH, map_location=device, weights_only=True))
except FileNotFoundError as e:
    print(f"ERROR: Model file not found at {DP_MODEL_PATH}. Run the DP-SGD training script first.")
    raise SystemExit("Stopping script due to missing file.") from e
except Exception as e:
    print(f"Error loading model: {e}")
    raise SystemExit("Stopping script due to unexpected error.") from e
dp_target_model.eval()
print(f"\nSuccessfully loaded model for MIA from {DP_MODEL_PATH}")

class AttackNN_AllLogits(nn.Module):
    """Binary membership-attack MLP (32 -> 16 hidden units, ReLU).

    Takes a target model's full logit vector and returns a single raw logit;
    positive values mean "member".
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = [input_dim, 32, 16]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return one membership logit per row of ``x`` (shape ``(N, 1)``)."""
        membership_logit = self.net(x)
        return membership_logit

def create_attack_dataset_all_logits(model, train_tensor, test_tensor):
    """Build a membership-inference dataset from a model's raw output logits.

    Rows of ``train_tensor`` are treated as members (label 1.0) and rows of
    ``test_tensor`` as non-members (label 0.0); members come first. The model
    is switched to eval mode and queried without gradients on the global
    ``device``, and outputs are moved back to the CPU.

    Returns:
        tuple: ``(attack_X, attack_y)``, the concatenated logits and the
        matching float membership labels.
    """
    model.eval()
    with torch.no_grad():
        member_logits = model(train_tensor.to(device)).cpu()
        nonmember_logits = model(test_tensor.to(device)).cpu()
    membership = torch.cat((
        torch.ones(member_logits.shape[0]),
        torch.zeros(nonmember_logits.shape[0]),
    ))
    return torch.cat((member_logits, nonmember_logits)), membership

def run_mia_trial(attack_X, attack_y, random_seed):
    """Train one attack classifier on model outputs and return its membership advantage.

    The attack data is split 70/30, stratified by membership label, using
    ``random_seed``. The same seed also drives the attack model's weight
    initialisation and minibatch shuffling, so a trial is reproducible from
    its seed alone. The caller's global RNG state is left untouched.

    Args:
        attack_X: float tensor of target-model outputs, one row per record.
        attack_y: float tensor of membership labels (1.0 = member, 0.0 = non-member).
        random_seed: seed for the split, the initialisation and the shuffling.

    Returns:
        float: membership advantage ``TPR - FPR`` on the held-out 30%, where
        TPR is the fraction of members predicted as members and FPR is the
        fraction of non-members predicted as members.
    """
    attack_X_train, attack_X_test, attack_y_train, attack_y_test = train_test_split(
        attack_X, attack_y, test_size=0.3, random_state=random_seed, stratify=attack_y
    )
    # A dedicated generator seeds the shuffling without touching global RNG state.
    shuffle_gen = torch.Generator().manual_seed(random_seed)
    attack_train_loader = DataLoader(
        TensorDataset(attack_X_train, attack_y_train),
        batch_size=64,
        shuffle=True,
        generator=shuffle_gen,
    )

    # Class-balance BCE: pos_weight = (#non-members / #members) in the training split.
    train_labels = attack_y_train.numpy()
    n0 = np.sum(train_labels == 0)
    n1 = np.sum(train_labels == 1)
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(n0 / n1, dtype=torch.float32).to(device))

    # Seed weight initialisation inside fork_rng so the global CPU/CUDA RNG
    # streams are restored afterwards.
    with torch.random.fork_rng(devices=list(range(torch.cuda.device_count()))):
        torch.manual_seed(random_seed)
        attack_model = AttackNN_AllLogits(attack_X.shape[1])
    attack_model = attack_model.to(device)
    optimizer = optim.Adam(attack_model.parameters(), lr=1e-3)

    for _ in range(50):
        attack_model.train()
        for xb, yb in attack_train_loader:
            xb, yb = xb.to(device), yb.to(device).unsqueeze(1)
            optimizer.zero_grad()
            loss = criterion(attack_model(xb), yb)
            loss.backward()
            optimizer.step()

    attack_model.eval()
    with torch.no_grad():
        test_logits = attack_model(attack_X_test.to(device)).squeeze(1)
    predicted_member = (test_logits > 0.0).cpu().numpy()
    is_member = attack_y_test.cpu().numpy() == 1
    # Both groups are non-empty because the split is stratified; guard anyway.
    tpr = float(predicted_member[is_member].mean()) if is_member.any() else 0.0
    fpr = float(predicted_member[~is_member].mean()) if (~is_member).any() else 0.0
    return tpr - fpr

print("Step 1: Creating attack dataset for the DP-SGD Model.")
attack_X_dp, attack_y_dp = create_attack_dataset_all_logits(dp_target_model, X_train_tensor, X_test_tensor)

# Repeat the attack with consecutive seeds and report mean / std of the advantage.
num_trials = 10
first_seed = 42
print(f"\nStep 2: Running {num_trials} MIA trials on the DP-SGD Model...")
all_advantages = []
for trial_no, seed in enumerate(range(first_seed, first_seed + num_trials), start=1):
    advantage = run_mia_trial(attack_X_dp, attack_y_dp, seed)
    print(f"  Trial {trial_no}/{num_trials} (Seed: {seed}) -> MIA Advantage: {advantage:.4f}")
    all_advantages.append(advantage)

mean_advantage, std_advantage = np.mean(all_advantages), np.std(all_advantages)
banner = "---" * 10
print("\n" + banner)
print("Final Robust MIA Results for DP-SGD Model")
print(f"  Mean MIA Advantage: {mean_advantage:.4f}")
print(f"  Std Dev of MIA Advantage: {std_advantage:.4f}")
print(banner + "\n")

mia_save_path = os.path.join(output_dir, f"dpsgd_mia_advantage_robust{FILE_SUFFIX}.npy")
np.save(mia_save_path, np.array([mean_advantage, std_advantage]))
print(f"Successfully executed MIA. Mean and Std Dev saved to '{mia_save_path}'.")

print("\n--- Starting Knowledge Distillation Target Generation ---")
class TeacherNN(nn.Module):
    """Same architecture as MulticlassNN, with named layers, that also exposes features.

    ``forward`` returns ``(logits, features)``, where ``features`` is the
    post-ReLU output of ``layer1`` (64 units), taken before dropout. The
    attribute names ``layer1``..``layer3`` are the keys used when remapping a
    MulticlassNN state dict, so they must not change.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, 64)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.layer2 = nn.Linear(64, 32)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)
        self.layer3 = nn.Linear(32, num_classes)

    def forward(self, x):
        hidden = self.relu1(self.layer1(x))
        second = self.relu2(self.layer2(self.dropout1(hidden)))
        logits = self.layer3(self.dropout2(second))
        return logits, hidden

DP_MODEL_PATH = os.path.join(output_dir, f"dp_target_model{FILE_SUFFIX}.pth")
print(f"Step 2: Loading trained Teacher Model from {DP_MODEL_PATH}")
teacher_model = TeacherNN(input_dim, num_classes).to(device)
state_dict = torch.load(DP_MODEL_PATH, map_location=device, weights_only=True)

# MulticlassNN stores its Linear layers at net.0 / net.3 / net.6; TeacherNN
# names the same layers layer1 / layer2 / layer3. Keys absent from the saved
# dict are skipped here (load_state_dict below is strict and will flag them).
layer_renames = {"net.0": "layer1", "net.3": "layer2", "net.6": "layer3"}
new_state_dict = OrderedDict(
    (f"{new_prefix}.{param}", state_dict[f"{old_prefix}.{param}"])
    for old_prefix, new_prefix in layer_renames.items()
    for param in ("weight", "bias")
    if f"{old_prefix}.{param}" in state_dict
)
teacher_model.load_state_dict(new_state_dict)
teacher_model.eval()

print("Step 3: Generating privacy-preserving Logits and Feature Targets.")
with torch.no_grad():
    teacher_logits_t, teacher_features_t = teacher_model(X_train_tensor.to(device))
Y_soft_logits = teacher_logits_t.cpu().numpy()
Y_teacher_features = teacher_features_t.cpu().numpy()

# NOTE(review): besides the teacher outputs, this archive stores the raw training
# features and labels (X_train / Y_train), so the file itself is not DP-protected.
kd_data_path = os.path.join(output_dir, f'kd_mc_targets_with_features{FILE_SUFFIX}.npz')
np.savez(kd_data_path,
         X_train=X_train_scaled,
         Y_train=y_train.values,
         Y_soft_logits=Y_soft_logits,
         Y_teacher_features=Y_teacher_features,
         X_test=X_test_scaled,
         Y_test=y_test.values,
         input_dim=np.array([input_dim]),
         num_classes=np.array([num_classes])
)
print("\nTargets and features generated successfully.")
print(f"  - Saved to: {kd_data_path}")
print(f"  - Y_soft_logits shape: {Y_soft_logits.shape}")
print(f"  - Y_teacher_features shape: {Y_teacher_features.shape}")

print("\n--- Starting KD Student Model Training ---")

# Distillation hyper-parameters: ALPHA weighs soft vs. hard loss, TEMPERATURE
# softens both distributions, BETA_FEATURE_LOSS scales feature matching.
ALPHA = 0.7
TEMPERATURE = 2.0
BETA_FEATURE_LOSS = 0.5
LEARNING_RATE = 1e-3
EPOCHS = 50

print("Step 1: Loading data, logits, and feature targets.")
kd_data_path = os.path.join(output_dir, f'kd_mc_targets_with_features{FILE_SUFFIX}.npz')
try:
    # `with` closes the .npz archive. Indexing an NpzFile reads each array into
    # memory, so the loaded values remain valid after the file is closed.
    with np.load(kd_data_path) as kd_data:
        X_train_kd = kd_data['X_train']
        Y_train_kd = kd_data['Y_train']
        Y_soft_logits_kd = kd_data['Y_soft_logits']
        Y_teacher_features_kd = kd_data['Y_teacher_features']
        X_test_kd = kd_data['X_test']
        Y_test_kd = kd_data['Y_test']
        # Plain ints for the nn.Linear constructors downstream.
        INPUT_DIM_KD = int(kd_data['input_dim'][0])
        NUM_CLASSES_KD = int(kd_data['num_classes'][0])
except FileNotFoundError as e:
    # Previously this only printed and fell through to a NameError further down.
    print(f"ERROR: {kd_data_path} not found. Please run the target generation script first.")
    raise SystemExit("Stopping script due to missing file.") from e
print(f"Successfully loaded data from {kd_data_path}")

class StudentNN(nn.Module):
    """Single-hidden-layer (64-unit) student that also exposes its hidden features.

    ``forward`` returns ``(logits, features)``; ``features`` is the post-ReLU
    output of ``layer1``, with the same width as the teacher's feature output.
    The attribute names must match StudentNN_Eval for state-dict reuse.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, 64)
        self.relu1 = nn.ReLU()
        self.layer2 = nn.Linear(64, num_classes)

    def forward(self, x):
        hidden = self.relu1(self.layer1(x))
        return self.layer2(hidden), hidden

student_model = StudentNN(INPUT_DIM_KD, NUM_CLASSES_KD).to(device)

# Training rows carry (features, true label, teacher logits, teacher features).
kd_train_tensors = (
    torch.tensor(X_train_kd, dtype=torch.float32),
    torch.tensor(Y_train_kd, dtype=torch.long),
    torch.tensor(Y_soft_logits_kd, dtype=torch.float32),
    torch.tensor(Y_teacher_features_kd, dtype=torch.float32),
)
train_ds_kd = TensorDataset(*kd_train_tensors)
train_loader_kd = DataLoader(train_ds_kd, batch_size=BATCH_SIZE, shuffle=True)

test_ds_kd = TensorDataset(
    torch.tensor(X_test_kd, dtype=torch.float32),
    torch.tensor(Y_test_kd, dtype=torch.long),
)
test_loader_kd = DataLoader(test_ds_kd, batch_size=BATCH_SIZE)

def combined_loss_fn(student_logits, student_features, y_true, teacher_logits, teacher_features,
                     *, alpha=None, temperature=None, beta=None):
    """Knowledge-distillation loss with an added feature-matching term.

    loss = alpha * T**2 * KL(softmax(teacher/T) || softmax(student/T))
           + (1 - alpha) * CE(student, y_true)
           + beta * MSE(student_features, teacher_features)

    The T**2 factor is the usual Hinton-style rescaling of the soft term.

    Args:
        student_logits, teacher_logits: ``(N, C)`` raw logits.
        student_features, teacher_features: same-shaped feature tensors.
        y_true: ``(N,)`` integer class labels.
        alpha, temperature, beta: optional overrides; ``None`` falls back to
            the module constants ALPHA, TEMPERATURE and BETA_FEATURE_LOSS.

    Returns:
        Scalar loss tensor.
    """
    alpha = ALPHA if alpha is None else alpha
    temperature = TEMPERATURE if temperature is None else temperature
    beta = BETA_FEATURE_LOSS if beta is None else beta

    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction='batchmean',
    ) * (temperature ** 2)
    hard_loss = F.cross_entropy(student_logits, y_true)
    distillation_loss = alpha * soft_loss + (1.0 - alpha) * hard_loss
    feature_loss = F.mse_loss(student_features, teacher_features)
    return distillation_loss + beta * feature_loss
    
print(f"Step 2: Training new Student Model with Feature Matching...")
optimizer = optim.Adam(student_model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    # --- one pass over the distillation data ---
    student_model.train()
    total_loss = 0.0
    for batch in train_loader_kd:
        xb, y_true_b, teacher_logits_b, teacher_features_b = (t.to(device) for t in batch)
        optimizer.zero_grad()
        student_logits, student_features = student_model(xb)
        loss = combined_loss_fn(student_logits, student_features, y_true_b, teacher_logits_b, teacher_features_b)
        loss.backward()
        optimizer.step()
        # Weight by batch size so dividing by the dataset size gives a per-sample mean.
        total_loss += loss.item() * xb.size(0)

    # --- test accuracy after every epoch (printed every 10th and at the end) ---
    student_model.eval()
    all_preds_kd, all_true_kd = [], []
    with torch.no_grad():
        for xb_test, Y_test_batch in test_loader_kd:
            outputs, _ = student_model(xb_test.to(device))
            all_preds_kd.extend(outputs.argmax(dim=1).cpu().numpy())
            all_true_kd.extend(Y_test_batch.numpy())
    test_accuracy_kd = accuracy_score(all_true_kd, all_preds_kd)

    if (epoch + 1) % 10 == 0:
        print(f"  Epoch [{epoch+1}/{EPOCHS}] - Train Loss: {total_loss/len(X_train_kd):.4f} - Test Accuracy: {test_accuracy_kd:.4f}")

STUDENT_MODEL_PATH = os.path.join(output_dir, f"kd_mc_student_model{FILE_SUFFIX}.pt")
torch.save(student_model.state_dict(), STUDENT_MODEL_PATH)
print("\nNew KD Student Model (with feature matching) trained and saved.")
print(f"  - Saved to: {STUDENT_MODEL_PATH}")
print(f"Final Test Accuracy: {test_accuracy_kd:.4f}.")

print("\n--- Starting KD Student Model Evaluation ---")

print("Step 1: Loading data and trained Student Model.")

kd_data_path = os.path.join(output_dir, f'kd_mc_targets_with_features{FILE_SUFFIX}.npz')
teacher_shap_csv_path = os.path.join(output_dir, f'shap_feature_importance_dpsgd{FILE_SUFFIX}.csv')
try:
    # `with` closes the .npz archive. Indexed arrays are read into memory first,
    # so they stay valid after the file is closed.
    with np.load(kd_data_path) as kd_data:
        X_train_scaled_eval = kd_data['X_train']
        X_test_scaled_eval = kd_data['X_test']
        Y_test_eval = kd_data['Y_test']
        INPUT_DIM_EVAL = int(kd_data['input_dim'][0])
        NUM_CLASSES_EVAL = int(kd_data['num_classes'][0])

    X_train_tensor_eval = torch.tensor(X_train_scaled_eval, dtype=torch.float32)
    X_test_tensor_eval = torch.tensor(X_test_scaled_eval, dtype=torch.float32)
    Y_test_tensor_eval = torch.tensor(Y_test_eval, dtype=torch.long)

    # Feature names come from the teacher's SHAP CSV (its 'feature' column).
    feature_names_df = pd.read_csv(teacher_shap_csv_path)
    feature_names_eval = feature_names_df['feature'].tolist()
    print("Successfully loaded all data and SHAP features.")

except FileNotFoundError as e:
    print(f"ERROR: A required file is missing. Please run previous scripts.\n{e}")
    raise SystemExit("Stopping script due to missing file.") from e
except Exception as e:
    print(f"An unexpected error occurred: {e}")
    raise SystemExit("Stopping script due to unexpected error.") from e

class StudentNN_Eval(nn.Module):
    """StudentNN with the same layers but a logits-only ``forward``.

    Single-output models are what DeepExplainer and the attack-dataset builder
    consume. Attribute names match StudentNN so its saved state dict loads
    directly.
    """

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, 64)
        self.relu1 = nn.ReLU()
        self.layer2 = nn.Linear(64, num_classes)

    def forward(self, x):
        return self.layer2(self.relu1(self.layer1(x)))

STUDENT_MODEL_PATH = os.path.join(output_dir, f"kd_mc_student_model{FILE_SUFFIX}.pt")
student_model_eval = StudentNN_Eval(INPUT_DIM_EVAL, NUM_CLASSES_EVAL).to(device)
student_model_eval.load_state_dict(torch.load(STUDENT_MODEL_PATH, map_location=device, weights_only=True))
student_model_eval.eval()
print(f"Successfully loaded student model from {STUDENT_MODEL_PATH}")

# --- Utility: test-set accuracy / report / confusion matrix ---
print("\nStep 2: Checking Utility Advantage (Test Accuracy).")
test_ds_eval = TensorDataset(X_test_tensor_eval, Y_test_tensor_eval)
test_loader_eval = DataLoader(test_ds_eval, batch_size=128)
all_preds_eval, all_true_eval = [], []
with torch.no_grad():
    for xb_test, Y_test_batch in test_loader_eval:
        outputs = student_model_eval(xb_test.to(device))
        _, predicted = torch.max(outputs.data, 1)
        all_preds_eval.extend(predicted.cpu().numpy())
        all_true_eval.extend(Y_test_batch.numpy())

accuracy_eval = accuracy_score(all_true_eval, all_preds_eval)
report_eval = classification_report(all_true_eval, all_preds_eval, target_names=class_names)
cm_eval = confusion_matrix(all_true_eval, all_preds_eval)

print("\n" + "---" * 15)
print("### Model Evaluation Metrics (KD Student) ###")
print(f"\n**Accuracy:** {accuracy_eval:.4f}")
print("\n**Confusion Matrix:**")
print(cm_eval)
print("\n**Classification Report (Precision, Recall, F1-Score):**")
print(report_eval)
print("\n**Privacy Metrics:**")
# The student is trained with plain Adam on the training features and true
# labels (train_loader_kd carries X_train / Y_train alongside the teacher
# targets). The teacher's (epsilon, delta) guarantee therefore does NOT transfer
# to the student; the previous "Inherited from Teacher" wording was incorrect.
print(f"  - DP-SGD Teacher only (Epsilon: {final_epsilon:.4f}, Delta: {target_delta})")
print("  - NOT inherited by the student: it was trained non-privately on the training features and labels")
print("---" * 15 + "\n")

# --- Privacy (empirical): membership-inference attack on the student ---
print("\nStep 3: Checking Privacy Advantage (Membership Inference Attack).")
attack_X_student, attack_y_student = create_attack_dataset_all_logits(student_model_eval, X_train_tensor_eval, X_test_tensor_eval)
num_trials = 10
all_advantages = []
print(f"  Running {num_trials} MIA trials on the KD Student Model...")
for i in range(num_trials):
    seed = 42 + i
    advantage = run_mia_trial(attack_X_student, attack_y_student, seed)
    print(f"    Trial {i+1}/{num_trials} (Seed: {seed}) -> MIA Advantage: {advantage:.4f}")
    all_advantages.append(advantage)
mean_advantage = np.mean(all_advantages)
std_advantage = np.std(all_advantages)
print("\n  Final Robust MIA Results for KD Student Model")
print(f"    Mean MIA Advantage: {mean_advantage:.4f}")
print(f"    Std Dev of MIA Advantage: {std_advantage:.4f}")

# --- Explainability: SHAP importances for the student ---
# NOTE(review): the teacher's SHAP run used 100 background rows and 20 explained
# rows; this one uses 200 and 100, so the two rankings are not computed on
# comparable samples.
print("\nStep 4: Checking Explainability Advantage (SHAP Stability).")
background_size = min(200, len(X_train_tensor_eval))
background_data = X_train_tensor_eval[np.random.choice(len(X_train_tensor_eval), background_size, replace=False)].to(device)
X_explain = X_test_tensor_eval[:100].to(device)

explainer_student = shap.DeepExplainer(student_model_eval, background_data)
shap_values_list_student = explainer_student.shap_values(X_explain)

# Normalise to (n_samples, n_features, n_classes): depending on the SHAP version,
# multi-output results are a list of per-class (n_samples, n_features) arrays or
# one array with the class axis last. Averaging the list layout over (0, 2)
# would collapse the feature axis instead of the sample axis.
if isinstance(shap_values_list_student, list):
    shap_array_student = np.stack([np.asarray(v) for v in shap_values_list_student], axis=-1)
else:
    shap_array_student = np.asarray(shap_values_list_student)
mean_abs_shap_student = np.abs(shap_array_student).mean(axis=(0, 2))
if mean_abs_shap_student.shape[0] != len(feature_names_eval):
    raise ValueError(
        f"SHAP produced {mean_abs_shap_student.shape[0]} importances for {len(feature_names_eval)} features"
    )

student_shap_df = pd.DataFrame({
    'feature': feature_names_eval,
    'mean_abs_shap': mean_abs_shap_student
})

student_shap_csv_path = os.path.join(output_dir, f'kd_mc_student_shap_importance{FILE_SUFFIX}.csv')
student_shap_df.to_csv(student_shap_csv_path, index=False)
print(f"\nSaved Student SHAP feature importance to '{student_shap_csv_path}'.")

teacher_shap_df = pd.read_csv(teacher_shap_csv_path)

print("\n### SHAP Feature Importance (KD Student) ###")
print("\nAverage absolute SHAP values (feature importance for KD Student Model):")
student_importance = dict(zip(feature_names_eval, mean_abs_shap_student))
sorted_student_importance = sorted(student_importance.items(), key=lambda item: item[1], reverse=True)
for feature, value in sorted_student_importance[:10]:
    print(f"{feature}: {value:.4f}")

print(f"\n--- Comparison: Top features from DP-SGD Teacher ---")
teacher_top = teacher_shap_df.sort_values('mean_abs_shap', ascending=False)
print(teacher_top[['feature', 'mean_abs_shap']].head(5).to_string(index=False))

print("\nAll DP-SGD and KD scripts executed.")