# CIFAR examples

This notebook trains a ResNet-18 target model on CIFAR-10 or CIFAR-100 and then audits it for membership-inference leakage with LeakPro. To switch between the two datasets, update the `dataset` field in `train_config.yaml` and the `data_path` field in `audit.yaml` accordingly.


In [1]:
import os
import sys
import yaml

# Make the repository root (three directories above the current working
# directory) importable so the `examples.mia.cifar.utils` imports below resolve.
# NOTE: this assumes the notebook is started from its own directory.
project_root = os.path.abspath(os.path.join(os.getcwd(), "../../.."))
# Append only once so re-running this cell does not keep growing sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)


In [2]:

from examples.mia.cifar.utils.cifar_data_preparation import get_cifar_dataloader
from examples.mia.cifar.utils.cifar_model_preparation import ResNet18, create_trained_model_and_metadata


# Load the training configuration (dataset choice, data directory, ...).
# Read explicitly as UTF-8 rather than relying on the platform's locale encoding.
with open("train_config.yaml", "r", encoding="utf-8") as file:
    train_config = yaml.safe_load(file)

# Directory holding the CIFAR data, resolved relative to the current working
# directory. The dataloaders themselves are built in the next cell.
path = os.path.join(os.getcwd(), train_config["data"]["data_dir"])

In [3]:
# Build the CIFAR train/test dataloaders from the data directory and the training config.
(train_loader, test_loader) = get_cifar_dataloader(path, train_config)

In [None]:
# Train the target model and collect per-epoch metrics for the plots below.

# Output directory for the trained target model. Presumably written by
# create_trained_model_and_metadata; the audit later loads ./target.
os.makedirs("target", exist_ok=True)

# Number of output classes for each supported dataset.
NUM_CLASSES_BY_DATASET = {"cifar10": 10, "cifar100": 100}
dataset_name = train_config["data"]["dataset"]
if dataset_name not in NUM_CLASSES_BY_DATASET:
    raise ValueError(
        f"Invalid dataset name {dataset_name!r}; "
        f"expected one of {sorted(NUM_CLASSES_BY_DATASET)}"
    )
num_classes = NUM_CLASSES_BY_DATASET[dataset_name]

model = ResNet18(num_classes=num_classes)
train_acc, train_loss, test_acc, test_loss = create_trained_model_and_metadata(
    model, train_loader, test_loader, train_config
)

In [None]:
import matplotlib.pyplot as plt

# Plot per-epoch training/test accuracy (left) and loss (right) side by side.
# Each panel gets ~5 inches of width so titles and legends do not collide.
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(10, 4))

acc_ax.plot(train_acc, label="Train Accuracy")
acc_ax.plot(test_acc, label="Test Accuracy")
acc_ax.set(xlabel="Epoch", ylabel="Accuracy", title="Accuracy over Epochs")
acc_ax.legend()

loss_ax.plot(train_loss, label="Train Loss")
loss_ax.plot(test_loss, label="Test Loss")
loss_ax.set(xlabel="Epoch", ylabel="Loss", title="Loss over Epochs")
loss_ax.legend()

fig.tight_layout()
plt.show()

In [2]:
from cifar_handler import CifarInputHandler

from leakpro import LeakPro

# Path to the audit configuration consumed by LeakPro.
config_path = "audit.yaml"

# Build the LeakPro auditor from the CIFAR input handler and the audit config,
# then run the configured attacks. `return_results=True` hands the result
# objects back for the reporting cell; `use_optuna=True` presumably enables
# Optuna-based tuning of the attacks (confirm against LeakPro docs).
leakpro = LeakPro(CifarInputHandler, config_path)
mia_results_optuna = leakpro.run_audit(use_optuna=True, return_results=True)

2025-02-24 22:02:59,800 INFO     Target model blueprint created from ResNet18 in ./utils/cifar_model_preparation.py.
2025-02-24 22:02:59,802 INFO     Loaded target model metadata from ./target/model_metadata.pkl
2025-02-24 22:03:00,038 INFO     Loaded target model from ./target
2025-02-24 22:03:02,156 INFO     Loaded population dataset from ./data/cifar10.pkl
2025-02-24 22:03:02,157 INFO     Loaded population dataset from ./data/cifar10.pkl
2025-02-24 22:03:02,158 INFO     Image extension initialized.
  from .autonotebook import tqdm as notebook_tqdm
2025-02-24 22:03:02,690 INFO     MIA attack factory loaded.
2025-02-24 22:03:02,691 INFO     Creating shadow model handler singleton
2025-02-24 22:03:02,693 INFO     Creating distillation model handler singleton
2025-02-24 22:03:02,694 INFO     Configuring the RMIA attack
2025-02-24 22:03:02,695 INFO     Added attack: rmia
2025-02-24 22:03:02,696 INFO     Preparing attack: rmia
2025-02-24 22:03:02,696 INFO     Preparing shadow models for R

## Generate report

In [3]:
# Import and initialize ReportHandler
from leakpro.reporting.report_handler import ReportHandler

# Results and the compiled report are written under this directory.
report_handler = ReportHandler(report_dir="./leakpro_output/results")

# Save each attack's results together with the configuration it ran with.
for attack_result in mia_results_optuna:
    report_handler.save_results(
        attack_name=attack_result.attack_name,
        result_data=attack_result,
        config=attack_result.configs,
    )

# Create the report by compiling the LaTeX text into a PDF.
report_handler.create_report()

2025-02-24 22:03:36,894 INFO     Initializing report handler...
2025-02-24 22:03:36,895 INFO     report_dir set to: ./leakpro_output/results
2025-02-24 22:03:36,895 INFO     Saving results for rmia
2025-02-24 22:03:47,058 INFO     No results of type GIAResults found.
2025-02-24 22:03:47,060 INFO     No results of type SinglingOutResults found.
2025-02-24 22:03:47,060 INFO     No results of type InferenceResults found.
2025-02-24 22:03:47,060 INFO     No results of type LinkabilityResults found.
2025-02-24 22:03:52,199 INFO     PDF compiled


<Figure size 640x480 with 0 Axes>