## Membership Inference Attacks with LiRA

### Import all necessary libraries

In [1]:
seed = 420

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['PYTHONHASHSEED'] = str(seed)
os.environ['MPLCONFIGDIR'] = os.getcwd()+'/.cache/configs/'
os.environ['TORCH_HOME'] = os.getcwd()+'/.cache/torch'
os.environ['TRANSFORMERS_CACHE'] = os.getcwd()+'/.cache/huggingface/hub/'

import warnings
warnings.simplefilter(action='ignore', category=Warning)

import absl.logging
absl.logging.set_verbosity(absl.logging.ERROR)

import numpy as np
np.random.seed(seed)

import random
random.seed(seed)

import gc
import logging
import time
import copy
from typing import Optional, Tuple, List

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, default_collate, ConcatDataset, Subset
from torch.optim.lr_scheduler import SequentialLR, CosineAnnealingLR
from torchvision import datasets as dset
from torchvision.transforms import v2
from datasets import load_from_disk
from medmnist import DermaMNIST, BloodMNIST
from sklearn.metrics import roc_curve, auc
from tqdm import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import pandas as pd
from thop import profile
from itertools import product

# Seed PyTorch's random number generators.
# NOTE(review): PyTorch is seeded with 42 while numpy, random and PYTHONHASHSEED
# use `seed` (420) — confirm that the mismatch is intentional.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

In [2]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

### Choose a dataset and insert its OFA Network Config (available in the table below) and its number of classes

| Dataset     | MBNv3 network_config                                                                                                       |
|-------------|----------------------------------------------------------------------------------------------------------------------------|
| cifar10     | 10_4-7-6_4-7-6_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_3-7-4_3-7-4_3-7-4_0-0-0_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4 |
| cifar100    | 10_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_3-7-4_3-7-4_3-7-4_0-0-0_4-7-4_4-7-4_4-7-6_4-7-4 |
| pets        | 10_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4 |
| tiny        | 10_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-6_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-6 |
| cars        | 10_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-3_4-7-4_4-7-4_4-7-4_4-7-4_4-5-6_4-5-4 |
| food        | 10_4-7-4_4-7-4_4-7-6_4-7-4_4-7-4_4-7-4_4-7-6_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4 |
| dermamnist  | 10_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-6 |
| bloodmnist  | 10_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4 |
| stl         | 10_4-7-4_4-7-4_4-7-6_4-7-4_4-7-4_4-7-4_4-7-6_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4 |
| imagenette  | 10_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_3-7-4_3-7-4_3-7-4_0-0-0_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4 |
| caltech101  | 10_4-7-4_4-7-6_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4 |
| imagewoof   | 10_4-7-4_4-7-4_4-7-4_4-7-6_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4 |

In [None]:
# Dataset under attack (one of the names listed in the table above)
dataset_name = "imagenette"

# How many classes that dataset has
num_classes = 10

# OFA (Once-for-All) network configuration string for that dataset
ofa_config = "10_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_3-7-4_3-7-4_3-7-4_0-0-0_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4_4-7-4"

# Echo the selection so it can be checked at a glance
print("\n".join([
    "Parameters defined:",
    f"  Dataset Name:\t\t{dataset_name}",
    f"  Number of Classes:\t{num_classes}",
]))

### Set the paths to the dataset and the output directory

In [None]:
# Root directory under which all shadow models are stored
main_path = "../storage/shadow_models/"

# Per-dataset output directory inside the root
output_path = f"{main_path}{dataset_name}"

# Sub-directory holding the models used by LiRA
models_path = f"{output_path}/lira/"

# Directory that contains the dataset files
dataset_path = f"../datasets/{dataset_name}"

# Echo every path so the configuration can be checked at a glance
print("\n".join([
    "Configured paths:",
    f"  Main Path:    {main_path}",
    f"  Output Path:  {output_path}",
    f"  Models Path:  {models_path}",
    f"  Dataset Path: {dataset_path}",
]))

### Set the number of shadow models and the challenge size

In [None]:
# Number of shadow models to train (256 is the value recommended in the literature)
num_shadows = 256

# Number of samples drawn from each set during the attack (1000 is sufficient)
challenge_size = 1000

# Print current parameters for verification
print(f"Parameters defined:\n"
      f"  Number of Shadow Models:\t{num_shadows}\n"
      f"  Number of Challenge Samples:\t{challenge_size}")

### Class definition of some datasets

In [6]:
class OxfordPetsDataset(Dataset):
    """Pets dataset backed by a single serialized file per split.

    The ``train`` split is read from ``<root>/train85.pth``; any other split
    value reads ``<root>/test15.pth``. The loaded object is indexed and
    iterated as a sequence of ``(image, label)`` pairs whose labels expose
    ``.item()``.
    """

    def __init__(self, root='../datasets/pets', split='train', transform=None):
        self.split = split
        self.transform = transform
        file_name = 'train85.pth' if split == 'train' else 'test15.pth'
        self.data_dir = root + '/' + file_name
        self.data = torch.load(self.data_dir)
        # Distinct label values in ascending order -> 0, 1, 2, ..., 36
        self.classes = sorted({label.item() for _, label in self.data})

    def __len__(self):
        """Number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)``, with ``transform`` applied to the image when set."""
        img, label = self.data[index]
        return (self.transform(img) if self.transform else img), label


class TinyImageNetDataset(Dataset):
    """Tiny ImageNet wrapper around a dataset saved to disk.

    The ``train`` split is loaded from ``<root>/train``; any other split value
    loads ``<root>/valid``. Each record is read through its ``'image'`` and
    ``'label'`` keys, and images are converted to RGB before the optional
    transform runs.
    """

    def __init__(self, root='../datasets/tiny', split='train', transform=None):
        self.split = split
        self.transform = transform
        subdir = 'train' if split == 'train' else 'valid'
        self.data_dir = root + '/' + subdir
        self.data = load_from_disk(self.data_dir)
        self.classes = self.data.features['label'].names
        # Map every class name to its position in ``classes``
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))

    def __len__(self):
        """Number of records in the loaded split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the record at ``index``."""
        record = self.data[int(index)]
        img, label = record['image'], record['label']

        # Force three channels so every sample has the same layout
        img = img if img.mode == 'RGB' else img.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, label

class Caltech101Dataset(Dataset):
    """Caltech-101 wrapper around a dataset saved to disk.

    The ``train`` split is loaded from ``<root>/train``; any other split value
    loads ``<root>/test``. Each record is read through its ``'image'`` and
    ``'label'`` keys, and images are converted to RGB before the optional
    transform runs.
    """

    def __init__(self, root='../datasets/caltech101', split='train', transform=None):
        self.split = split
        self.transform = transform
        subdir = 'train' if split == 'train' else 'test'
        self.data_dir = root + '/' + subdir
        self.data = load_from_disk(self.data_dir)
        self.classes = self.data.features['label'].names
        # Map every class name to its position in ``classes``
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))

    def __len__(self):
        """Number of records in the loaded split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the record at ``index``."""
        record = self.data[int(index)]
        img, label = record['image'], record['label']

        # Force three channels so every sample has the same layout
        img = img if img.mode == 'RGB' else img.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, label


class ImagewoofDataset(Dataset):
    """Imagewoof wrapper around a dataset saved to disk.

    The ``train`` split is loaded from ``<root>/train``; any other split value
    loads ``<root>/validation``. Each record is read through its ``'image'``
    and ``'label'`` keys, and images are converted to RGB before the optional
    transform runs.
    """

    def __init__(self, root='../datasets/imagewoof', split='train', transform=None):
        self.split = split
        self.transform = transform
        subdir = 'train' if split == 'train' else 'validation'
        self.data_dir = root + '/' + subdir
        self.data = load_from_disk(self.data_dir)
        self.classes = self.data.features['label'].names
        # Map every class name to its position in ``classes``
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))

    def __len__(self):
        """Number of records in the loaded split."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the record at ``index``."""
        record = self.data[int(index)]
        img, label = record['image'], record['label']

        # Force three channels so every sample has the same layout
        img = img if img.mode == 'RGB' else img.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, label

### Definition of some functions

In [7]:
def get_transform(dataset_name, image_size, huge_augment, horizontal_flip, random_crop, random_erasing, train=True):
    """Compose the preprocessing pipeline for one dataset.

    Every pipeline converts to tensor, resizes to ``image_size`` (bicubic) and
    normalizes with the dataset's channel statistics. When ``train`` is true
    the optional augmentations are inserted in this order: horizontal flip,
    the policy named by ``huge_augment`` ('trivial_augment', 'auto_augment',
    'rand_augment' or 'aug_mix'; anything else adds nothing), padded random
    crop, and, after normalization, random erasing.

    Raises:
        TypeError: if ``dataset_name`` has no known statistics.
    """
    # (mean, std) per channel for each supported dataset
    channel_stats = {
        'cifar10': ([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
        'cifar100': ([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
        'pets': ([0.4717, 0.4499, 0.3837], [0.2726, 0.2634, 0.2794]),
        'cars': ([0.4708, 0.4602, 0.4550], [0.2892, 0.2882, 0.2968]),
        'food': ([0.5450, 0.4435, 0.3436], [0.2695, 0.2719, 0.2766]),
        'tiny': ([0.4805, 0.4483, 0.3978], [0.2177, 0.2138, 0.2136]),
        'dermamnist': ([0.7632, 0.5381, 0.5615], [0.0872, 0.1204, 0.1360]),
        'bloodmnist': ([0.7961, 0.6596, 0.6964], [0.2139, 0.2464, 0.0903]),
        'stl': ([0.4467, 0.4398, 0.4066], [0.2185, 0.2159, 0.2183]),
        'imagenette': ([0.4625, 0.4580, 0.4295], [0.2351, 0.2287, 0.2372]),
        'caltech101': ([0.5418, 0.5209, 0.4857], [0.2389, 0.2378, 0.2376]),
        'imagewoof': ([0.4861, 0.4560, 0.3938], [0.2207, 0.2145, 0.2166]),
    }
    try:
        mean, std = channel_stats[dataset_name]
    except (KeyError, TypeError):
        raise TypeError(f"Unknown dataset: {dataset_name}.") from None

    # Steps shared by the train and test pipelines
    pipeline = [
        v2.ToTensor(),
        v2.Resize(image_size, interpolation=Image.BICUBIC, antialias=True),
    ]

    if train:
        if horizontal_flip:
            pipeline.append(v2.RandomHorizontalFlip())

        # Automatic augmentation policies selectable by name
        policies = {
            'trivial_augment': v2.TrivialAugmentWide,
            'auto_augment': v2.AutoAugment,
            'rand_augment': v2.RandAugment,
            'aug_mix': v2.AugMix,
        }
        policy_cls = policies.get(huge_augment)
        if policy_cls is not None:
            pipeline.append(policy_cls())

        if random_crop:
            side = image_size[0]
            # Pad by 10% of the side, then crop back to the original size
            pipeline.append(v2.RandomCrop(side, padding=int(side * 0.10)))

    # Normalization is applied to the test set as well
    pipeline.append(v2.Normalize(mean=mean, std=std))

    # Random erasing runs after normalization
    if train and random_erasing:
        pipeline.append(v2.RandomErasing(value='random'))

    return v2.Compose(pipeline)

def extract_minimal_ofa_network(num_classes, verbose=0):
    """Return the OFA MobileNetV3 (w1.0) sub-network selected with d=1, e=0, ks=3.

    The pretrained super-network is fetched through ``torch.hub``; the active
    sub-network is extracted with its weights preserved and its classifier is
    replaced by a fresh linear layer with ``num_classes`` outputs.
    Progress messages are printed when ``verbose`` is greater than zero.
    """
    chatty = verbose > 0

    if chatty:
        print('OFA model loading...')

    super_net = torch.hub.load(
        'mit-han-lab/once-for-all',
        'ofa_supernet_mbv3_w10',
        pretrained=True,
        verbose=False,
    )
    super_net = super_net.eval()
    super_net.set_active_subnet(d=1, e=0, ks=3)
    model = super_net.get_active_subnet(preserve_weight=True)

    # Swap the pretrained head for one sized to the target task
    head_inputs = model.classifier.linear.in_features
    model.classifier = nn.Sequential(nn.Linear(head_inputs, num_classes))

    if chatty:
        print('OFA model loaded!\n')

    return model

def extract_ofa_network(config, num_classes, verbose=0):
    """Build the OFA MobileNetV3 sub-network described by ``config``.

    ``config`` is an underscore-separated string: a leading width tag ('10'
    selects the w1.0 super-network, anything else the w1.2 one) followed by
    one ``depth-kernel-expand`` triple per block. The depth is read from every
    fourth triple; a kernel or expand value of '0' is replaced by 3.
    The extracted sub-network keeps the pretrained weights and gets a new
    linear classifier with ``num_classes`` outputs. Progress messages are
    printed when ``verbose`` is greater than zero.
    """
    chatty = verbose > 0

    if chatty:
        print('OFA model loading...')

    width_tag, *block_specs = config.split("_")
    triples = [spec.split("-") for spec in block_specs]

    # Depth comes from the first triple of each group of four blocks
    d = [int(fields[0]) for fields in triples[::4]]
    # '0' marks a skipped block; fall back to 3 for kernel size and expand ratio
    k = [3 if fields[1] == '0' else int(fields[1]) for fields in triples]
    e = [3 if fields[2] == '0' else int(fields[2]) for fields in triples]

    if width_tag == '10':
        super_net_name = 'ofa_supernet_mbv3_w10'
    else:
        super_net_name = 'ofa_supernet_mbv3_w12'

    super_net = torch.hub.load('mit-han-lab/once-for-all', super_net_name, pretrained=True, verbose=False)
    super_net = super_net.eval()
    super_net.set_active_subnet(d=d, e=e, ks=k)
    model = super_net.get_active_subnet(preserve_weight=True)

    # Swap the pretrained head for one sized to the target task
    head_inputs = model.classifier.linear.in_features
    model.classifier = nn.Sequential(nn.Linear(head_inputs, num_classes))

    if chatty:
        print('OFA model loaded!\n')

    return model

def learning_rate_scheduling(optimizer, scheduler_name=None, warmup=False, warmup_epochs=0, epochs=100):
    """Build the epoch-level learning-rate schedule for ``optimizer``.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        scheduler_name: ``'CosineAnnealingLR'`` enables cosine annealing over
            ``epochs - warmup_epochs`` epochs. ``None`` disables scheduling.
        warmup: Whether to prepend a linear warm-up phase.
        warmup_epochs: Length of the warm-up phase; values <= 0 disable it.
        epochs: Total number of training epochs.

    Returns:
        A ``SequentialLR`` chaining the enabled phases, or ``None`` when
        ``scheduler_name`` is ``None`` or no phase is enabled.
    """
    # Imported here because the file-level import only brings in
    # SequentialLR and CosineAnnealingLR.
    from torch.optim.lr_scheduler import LambdaLR

    if scheduler_name is None:
        return None

    schedulers = []
    milestones = []

    # A zero-length warm-up is no warm-up (and would divide by zero below).
    if warmup and warmup_epochs > 0:
        def warmup_factor(epoch):
            # Linear ramp; the cap only matters when no scheduler follows the
            # warm-up, where it keeps the factor from growing past 1.
            return min(1.0, (epoch / warmup_epochs) + 1e-5)

        schedulers.append(LambdaLR(optimizer, lr_lambda=warmup_factor))
        milestones.append(warmup_epochs)

    if scheduler_name == 'CosineAnnealingLR':
        # `verbose` is deprecated in recent PyTorch releases, so it is not passed.
        schedulers.append(
            CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs, eta_min=0, last_epoch=-1)
        )

    if not schedulers:
        return None

    # SequentialLR needs exactly len(schedulers) - 1 milestones, so drop the
    # warm-up milestone when nothing follows the warm-up phase.
    return SequentialLR(optimizer, schedulers=schedulers, milestones=milestones[:len(schedulers) - 1])

def train_epoch(
    model, 
    train_dataloader,
    dataset_name, 
    criterion, 
    optimizer, 
    scaler, 
    mixed_precision, 
    scheduler,
    device,
    cutmix_or_mixup=None
):
    """Train ``model`` for one epoch and step the scheduler once at the end.

    Args:
        model: Network to optimize (switched to train mode).
        train_dataloader: Iterable of ``(inputs, labels)`` batches.
        dataset_name: Used to squeeze ``[batch, 1]`` labels for
            'dermamnist' / 'bloodmnist'.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updated after every batch.
        scaler: Gradient scaler; only used when ``mixed_precision`` is true.
        mixed_precision: Enables CUDA autocast and the gradient scaler.
        scheduler: Optional LR scheduler, stepped once per epoch.
        device: Device the batches are moved to.
        cutmix_or_mixup: ``True`` when labels are per-class score matrices
            (CutMix/MixUp), ``False`` for class indices. ``None`` (default)
            detects it per batch: 2-D floating-point labels are treated as
            soft labels.

    Returns:
        ``(average_loss, accuracy, nan_encountered)``. When a NaN loss is hit
        the epoch is aborted and ``(0, 0, True)`` is returned.
    """
    model.train()
    total_loss, correct_predictions, total_samples = 0.0, 0, 0
    
    for inputs, labels in train_dataloader:
        inputs, labels = inputs.float().to(device), labels.to(device)
        optimizer.zero_grad()

        soft_labels = cutmix_or_mixup
        if soft_labels is None:
            soft_labels = labels.dim() == 2 and labels.is_floating_point()

        # Squeeze (if not already done in collate_fn) for dermamnist or bloodmnist: from shape [batch_size, 1] to shape [batch_size]
        if not soft_labels and dataset_name in ['dermamnist', 'bloodmnist'] and labels.dim() == 2:
            labels = labels.squeeze(1)
        
        # Only the forward pass and the loss run under autocast; the backward
        # pass and the optimizer step are kept outside of it.
        with torch.cuda.amp.autocast(enabled=mixed_precision):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        if torch.isnan(loss).any():
            print("==> Encountered NaN value in loss.")
            return 0, 0, True

        if mixed_precision:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        
        total_loss += loss.item()
        _, int_predictions = torch.max(outputs, 1)
        if soft_labels:
            # Reduce soft labels to their dominant class for the accuracy count
            _, labels = torch.max(labels, 1)
        correct_predictions += (int_predictions == labels).sum().item()
        total_samples += labels.size(0)
    
    if scheduler:
        scheduler.step()
    
    # max(..., 1) keeps an empty dataloader from dividing by zero
    average_loss = total_loss / max(len(train_dataloader), 1)
    accuracy = correct_predictions / max(total_samples, 1)
    return average_loss, accuracy, False

def test_epoch(
    model,
    test_dataloader,
    dataset_name,
    criterion,
    device
):
    model.eval()
    total_loss, correct_predictions, total_samples = 0.0, 0, 0
    
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.float().to(device), labels.to(device)

            # Squeeze only for dermamnist or bloodmnist: from shape [batch_size, 1] to shape [batch_size]
            if dataset_name in ['dermamnist', 'bloodmnist'] and labels.dim() == 2:
                labels = labels.squeeze(1)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            _, int_predictions = torch.max(outputs, 1)
            correct_predictions += (int_predictions == labels).sum().item()
            total_samples += labels.size(0)
    
    average_loss = total_loss / len(test_dataloader)
    accuracy = correct_predictions / total_samples
    return average_loss, accuracy


def get_stratified_splits(dataset, labels, train_pct=0.7, val_pct=0.15, test_pct=0.15, seed=42):
    """Split ``dataset`` into class-stratified train / validation / test subsets.

    NumPy's global random state is reseeded with ``seed``. For every distinct
    label the matching indices are shuffled; the first ``int(n * train_pct)``
    go to train, the next ``int(n * val_pct)`` to validation and whatever is
    left to test.

    Returns:
        ``(train_set, val_set, test_set)`` as ``Subset`` objects; an entry is
        ``None`` when its percentage is not greater than zero.
    """
    np.random.seed(seed)
    assert abs(train_pct + val_pct + test_pct - 1.0) < 1e-5, "Percentages must sum to 1"

    distinct_labels = np.unique(labels)
    label_array = np.array(labels)
    train_idx, val_idx, test_idx = [], [], []

    for cls in distinct_labels:
        members = np.where(label_array == cls)[0]
        train_end = int(len(members) * train_pct)
        val_end = train_end + int(len(members) * val_pct)

        np.random.shuffle(members)

        train_idx.extend(members[:train_end])
        val_idx.extend(members[train_end:val_end])
        test_idx.extend(members[val_end:])

    def as_subset(indices, pct):
        # A split with a non-positive share is reported as None
        return Subset(dataset, indices) if pct > 0 else None

    return as_subset(train_idx, train_pct), as_subset(val_idx, val_pct), as_subset(test_idx, test_pct)

### Definition of the shadow models training function

In [8]:
def _shadow_pair_dirs(main_path, dataset_name, index):
    """Return the (IN, OUT) model directories of shadow pair ``index``."""
    lira_dir = os.path.join(main_path, dataset_name, 'lira')
    return (
        os.path.join(lira_dir, f'shadow_model_in_{index}'),
        os.path.join(lira_dir, f'shadow_model_out_{index}'),
    )


def _load_shadow_pool(dataset_name, dataset_path, transform):
    """Load the test/validation split of ``dataset_name`` that the shadow models are carved from."""
    if dataset_name == 'cifar10':
        return dset.CIFAR10(root=dataset_path, train=False, transform=transform, download=True)
    if dataset_name == 'cifar100':
        return dset.CIFAR100(root=dataset_path, train=False, transform=transform, download=True)
    if dataset_name == 'pets':
        return OxfordPetsDataset(root=dataset_path, split='test', transform=transform)
    if dataset_name == 'cars':
        return dset.StanfordCars(root=dataset_path, split='test', transform=transform, download=False)
    if dataset_name == 'food':
        return dset.Food101(root=dataset_path, split='test', transform=transform, download=True)
    if dataset_name == 'tiny':
        return TinyImageNetDataset(root=dataset_path, split='valid', transform=transform)
    if dataset_name == 'dermamnist':
        return DermaMNIST(root=dataset_path, split='test', size=224, as_rgb=True, transform=transform, download=True)
    if dataset_name == 'bloodmnist':
        return BloodMNIST(root=dataset_path, split='test', size=224, as_rgb=True, transform=transform, download=True)
    if dataset_name == 'stl':
        return dset.STL10(root=dataset_path, split='test', transform=transform, download=True)
    if dataset_name == 'imagenette':
        # for this dataset, download=True returns an error if the dataset is already downloaded
        return dset.Imagenette(root=dataset_path, split='val', transform=transform, download=False)
    if dataset_name == 'caltech101':
        return Caltech101Dataset(root=dataset_path, split='test', transform=transform)
    if dataset_name == 'imagewoof':
        return ImagewoofDataset(root=dataset_path, split='validation', transform=transform)
    raise TypeError(f"Unknown dataset: {dataset_name}.")


def _train_and_save_shadow_model(
    model_dir,
    train_dataloader,
    val_dataloader,
    dataset_name,
    num_classes,
    epochs,
    learning_rate,
    weight_decay,
    mixed_precision,
    device,
    scheduler_name,
    patience,
    label_smoothing,
    cutmix_or_mixup,
    start_time_script
):
    """Train one shadow model with early stopping and save the best checkpoint to ``model_dir/model.pt``."""
    os.makedirs(model_dir, exist_ok=True)

    # Create the classifier model
    print("Creating the classifier...")
    model = extract_minimal_ofa_network(num_classes)
    model = model.to(device)
    print("Classifier created!\n")

    # Create the scaler, the loss function, the optimizer and the learning rate scheduler
    scaler = torch.cuda.amp.GradScaler()
    criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = learning_rate_scheduling(optimizer, scheduler_name, epochs=epochs)

    best_accuracy = 0.0
    best_model = None
    no_improvement_epochs = 0

    print("The training loop starts now!\n")
    for epoch in range(epochs):
        start_time = time.time()

        train_loss, train_accuracy, nan_encountered = train_epoch(
            model=model,
            train_dataloader=train_dataloader,
            dataset_name=dataset_name,
            criterion=criterion,
            optimizer=optimizer,
            scaler=scaler,
            mixed_precision=mixed_precision,
            scheduler=scheduler,
            device=device,
            cutmix_or_mixup=cutmix_or_mixup
        )

        if nan_encountered:
            # Retry the epoch in full precision after a NaN loss
            train_loss, train_accuracy, nan_encountered = train_epoch(
                model=model,
                train_dataloader=train_dataloader,
                dataset_name=dataset_name,
                criterion=criterion,
                optimizer=optimizer,
                scaler=scaler,
                mixed_precision=False,
                scheduler=scheduler,
                device=device,
                cutmix_or_mixup=cutmix_or_mixup
            )

        val_loss, val_accuracy = test_epoch(
            model=model,
            test_dataloader=val_dataloader,
            dataset_name=dataset_name,
            criterion=criterion,
            device=device
        )

        scheduler_last_lr = scheduler.get_last_lr()[0] if scheduler is not None else learning_rate
        epoch_time = time.time() - start_time

        print(f"Epoch {epoch+1}/{epochs} - Time: {epoch_time:.2f}s - "
              f"Train Loss: {train_loss:.4f} - Train Acc: {train_accuracy:.4f} - "
              f"Val Loss: {val_loss:.4f} - Val Acc: {val_accuracy:.4f} - "
              f"LR: {scheduler_last_lr:.6f}")

        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            best_model = copy.deepcopy(model)
            no_improvement_epochs = 0
        else:
            no_improvement_epochs += 1
            if no_improvement_epochs >= patience:
                print(f"Early stopping at epoch {epoch+1} due to no improvement in validation accuracy for {patience} consecutive epochs.")
                break

    # If validation accuracy never rose above zero, keep the final weights
    # instead of writing None to disk.
    if best_model is None:
        best_model = model

    torch.save(best_model, os.path.join(model_dir, 'model.pt'))

    # Report the time elapsed since build_shadow_models started
    elapsed_time_script = time.time() - start_time_script
    hours, minutes, seconds = int(elapsed_time_script // 3600), int((elapsed_time_script % 3600) // 60), int(elapsed_time_script % 60)
    print(f"Execution completed in {hours}h {minutes}m {seconds}s.\n")

    # Drop the references first so the cached GPU memory can actually be released
    del model, best_model
    gc.collect()
    torch.cuda.empty_cache()


def build_shadow_models(
    dataset_name,
    num_shadows=256,
    image_size=(224, 224),
    epochs=50,
    batch_size=96,
    learning_rate=0.001,
    weight_decay=0.00005,
    mixed_precision=True,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    scheduler_name='CosineAnnealingLR',
    patience=10,
    horizontal_flip=True,
    huge_augment='aug_mix',
    random_crop=True,
    random_erasing=False,
    cutmix=False,
    mixup=True,
    label_smoothing=0.1,
    main_path=main_path,
    dataset_path=dataset_path
):
    """Train every missing IN/OUT shadow-model pair for ``dataset_name``.

    ``num_shadows // 2`` pairs are expected under
    ``<main_path>/<dataset_name>/lira/shadow_model_{in,out}_<i>/model.pt``;
    only pairs with a missing file are (re)trained. For pair ``i`` the
    dataset's test/validation split is divided with a stratified 50/10/40
    split seeded by ``seed + i``. The IN model trains on the 50% part, the OUT
    model on a stratified 50% of that part, and both are validated on the
    10% part. Training data uses the augmented pipeline and validation data
    the evaluation pipeline.

    Args:
        dataset_name: Name of the dataset to load.
        num_shadows: Total number of shadow models (IN + OUT).
        image_size: Target image size passed to ``get_transform``.
        epochs, batch_size, learning_rate, weight_decay: Training settings.
        mixed_precision: Enables autocast / gradient scaling during training.
        device: Device used for training and evaluation.
        scheduler_name: Scheduler name passed to ``learning_rate_scheduling``.
        patience: Epochs without validation improvement before early stopping.
        horizontal_flip, huge_augment, random_crop, random_erasing:
            Augmentation switches passed to ``get_transform``.
        cutmix, mixup: Enable batch-level CutMix / MixUp in the train loaders.
        label_smoothing: Label smoothing of the cross-entropy loss.
        main_path: Root directory for the saved models.
        dataset_path: Directory containing the dataset files.

    Raises:
        TypeError: if ``dataset_name`` is not supported.
    """
    # Write on the log file the training parameters
    print("Training Parameters:\n")
    print(f"seed: {seed}")
    for param_name, param_value in locals().items(): # locals() returns a dictionary with the current local variables
        print(f"{param_name}: {param_value}")
    print("---------------------------------------------------------\n")

    # Record the start time of the script
    start_time_script = time.time()

    train_transformations = get_transform(dataset_name, image_size, huge_augment, horizontal_flip, random_crop, random_erasing, train=True)
    eval_transformations = get_transform(dataset_name, image_size, huge_augment=None, horizontal_flip=False, random_crop=False, random_erasing=False, train=False)

    data = _load_shadow_pool(dataset_name, dataset_path, eval_transformations)

    # Get the number of classes of the dataset
    if dataset_name == 'dermamnist': # since the medmnist datasets don't have the "classes" attribute
        num_classes = 7
    elif dataset_name == 'bloodmnist':
        num_classes = 8
    else:
        num_classes = len(data.classes)

    # Subsets share their underlying dataset object, so assigning `.transform`
    # through one of them changes it for all. Use two shallow copies of the
    # dataset instead: one per pipeline.
    train_view = copy.copy(data)
    train_view.transform = train_transformations
    eval_view = copy.copy(data)
    eval_view.transform = eval_transformations

    n_splits = num_shadows // 2

    # First, check all existing models
    existing_pairs = []
    missing_pairs = []

    for i in range(n_splits):
        model_in_path, model_out_path = _shadow_pair_dirs(main_path, dataset_name, i)
        model_in_exists = os.path.exists(os.path.join(model_in_path, 'model.pt'))
        model_out_exists = os.path.exists(os.path.join(model_out_path, 'model.pt'))

        if model_in_exists and model_out_exists:
            existing_pairs.append(i)
        else:
            missing_pairs.append(i)

    if existing_pairs:
        print(f"Found {len(existing_pairs)} existing complete model pairs")
    if missing_pairs:
        print(f"Need to train {len(missing_pairs)} model pairs")

    # The labels do not depend on the pair, so read them once (and only when
    # there is something to train).
    labels_in = [y for _, y in data] if missing_pairs else []

    # Batch-level CutMix / MixUp is applied inside the collate function
    cutmix_or_mixup = bool(cutmix or mixup)
    collate_fn = None
    if cutmix_or_mixup:
        if cutmix and mixup:
            advanced_transform = v2.RandomChoice([v2.CutMix(num_classes=num_classes), v2.MixUp(num_classes=num_classes)])
        elif cutmix:
            advanced_transform = v2.CutMix(num_classes=num_classes)
        else:
            advanced_transform = v2.MixUp(num_classes=num_classes)

        def collate_fn(batch):
            images, labels = default_collate(batch)
            if dataset_name in ['dermamnist', 'bloodmnist'] and labels.dim() == 2:
                labels = labels.squeeze(1)
            return advanced_transform(images, labels)

    loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True, prefetch_factor=2)
    training_kwargs = dict(
        dataset_name=dataset_name,
        num_classes=num_classes,
        epochs=epochs,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        mixed_precision=mixed_precision,
        device=device,
        scheduler_name=scheduler_name,
        patience=patience,
        label_smoothing=label_smoothing,
        cutmix_or_mixup=cutmix_or_mixup,
        start_time_script=start_time_script,
    )

    # Process only missing pairs
    for i in tqdm(missing_pairs, desc="Training shadow models"):
        model_in_path, model_out_path = _shadow_pair_dirs(main_path, dataset_name, i)

        # Prepare datasets: the split only provides indices, which are then
        # bound to the view carrying the right transform.
        split_train, split_val, _ = get_stratified_splits(data, labels_in, train_pct=0.5, val_pct=0.1, test_pct=0.4, seed=seed+i)
        train_in = Subset(train_view, split_train.indices)
        val_in = Subset(eval_view, split_val.indices)

        # Labels of train_in, in train_in order, without reloading any image
        labels_out = [labels_in[idx] for idx in train_in.indices]
        train_out, _, _ = get_stratified_splits(train_in, labels_out, train_pct=0.5, val_pct=0.1, test_pct=0.4, seed=seed+i)

        print("Creating Train and Test Dataloaders...")
        train_dataloader_in = DataLoader(train_in, shuffle=True, collate_fn=collate_fn, **loader_kwargs)
        train_dataloader_out = DataLoader(train_out, shuffle=True, collate_fn=collate_fn, **loader_kwargs)
        val_dataloader_in = DataLoader(val_in, shuffle=False, **loader_kwargs)
        print(f"Train and Test Dataloaders created!\n")

        # Train IN model if needed
        if not os.path.exists(os.path.join(model_in_path, 'model.pt')):
            _train_and_save_shadow_model(model_in_path, train_dataloader_in, val_dataloader_in, **training_kwargs)

        # Train OUT model if needed
        if not os.path.exists(os.path.join(model_out_path, 'model.pt')):
            _train_and_save_shadow_model(model_out_path, train_dataloader_out, val_dataloader_in, **training_kwargs)

    print("Shadow model computation complete")

### Set the batch size for the training of the shadow models (which is done only once per dataset) and train!

Note: The function automatically detects which shadow models have already been trained and skips them, so there is no need to comment out the following line.

In [None]:
# Train the shadow model pairs that are still missing for the selected dataset.
# NOTE(review): batch_size=1928 is roughly 20x the function's default of 96 —
# confirm this value is intended.
build_shadow_models(dataset_name=dataset_name, num_shadows=num_shadows, batch_size=1928)

### Definition of LiRA and related functions

In [10]:
def stable_logit(probs, y_true, eps=1e-10):
    """
    φstable = log(f(x)_y) - log(Σ_{y'≠y} f(x)_{y'})

    Args:
        probs: array-like of shape (n_samples, n_classes) holding class
            probabilities.
        y_true: integer class labels, either flat (n_samples,) or as a column
            (n_samples, 1); any shape that flattens to n_samples is accepted.
        eps: lower clipping bound for the probabilities (upper bound is
            1 - eps) and floor for the summed probability of the other classes.

    Returns:
        Array of shape (n_samples, 1) with the stable logit of the true class.
    """
    # Flatten the labels unconditionally instead of keying on a module-level
    # dataset name: column-shaped labels would otherwise broadcast into an
    # (n_samples, n_samples) fancy index.
    y_true = np.asarray(y_true).reshape(-1).astype(int)
    # Cast before clipping so that 1 - eps is representable.
    probs = np.clip(np.asarray(probs, dtype=np.float64), eps, 1 - eps)
    n_samples, n_classes = probs.shape
    rows = np.arange(n_samples)

    log_f_y = np.log(probs[rows, y_true])[:, np.newaxis]

    # Boolean mask selecting every class except the true one for each sample.
    mask = np.ones_like(probs, dtype=bool)
    mask[rows, y_true] = False
    probs_others = probs[mask].reshape(n_samples, n_classes - 1)
    sum_others = np.maximum(probs_others.sum(axis=1, keepdims=True), eps)

    return log_f_y - np.log(sum_others)

def unstable_logit(probs, y_true, eps=1e-10):
    """
    φunstable = log(f(x)_y) - log(1 - f(x)_y)

    Args:
        probs: array-like of shape (n_samples, n_classes) holding class
            probabilities.
        y_true: integer class labels, either flat (n_samples,) or as a column
            (n_samples, 1); any shape that flattens to n_samples is accepted.
        eps: probabilities are clipped to [eps, 1 - eps] before the logs.

    Returns:
        Array of shape (n_samples, 1) with the logit of the true-class
        probability.
    """
    # Flatten the labels unconditionally instead of keying on a module-level
    # dataset name, so column-shaped labels from any dataset index correctly.
    y_true = np.asarray(y_true).reshape(-1).astype(int)
    # Cast to float64 BEFORE clipping: in float32, 1 - 1e-10 rounds to exactly
    # 1.0, which would make log(1 - f_y) = -inf for fully confident predictions.
    probs = np.clip(np.asarray(probs, dtype=np.float64), eps, 1 - eps)
    f_y = probs[np.arange(len(probs)), y_true][:, np.newaxis]
    return np.log(f_y) - np.log(1 - f_y)

def logits_to_score(logits, y_true):
    """
    Compute score directly from pre-softmax logits:
    z(x)_y - max_{y'≠y} z(x)_{y'}

    Args:
        logits: array-like of shape (n_samples, n_classes) with pre-softmax
            outputs. The input is not modified.
        y_true: integer class labels, either flat (n_samples,) or as a column
            (n_samples, 1); any shape that flattens to n_samples is accepted.

    Returns:
        Array of shape (n_samples, 1): margin between the true-class logit and
        the largest logit among the other classes.
    """
    # Flatten the labels unconditionally instead of keying on a module-level
    # dataset name, so column-shaped labels from any dataset index correctly.
    y_true = np.asarray(y_true).reshape(-1).astype(int)
    logits = np.asarray(logits)
    if not np.issubdtype(logits.dtype, np.floating):
        # -inf (used below to mask the true class) needs a floating dtype.
        logits = logits.astype(np.float64)

    n_samples, _ = logits.shape
    rows = np.arange(n_samples)
    z_y = logits[rows, y_true][:, np.newaxis]

    # Work on a copy so the caller's array keeps its true-class logits.
    logits_excl_true = logits.copy()
    logits_excl_true[rows, y_true] = -np.inf
    max_z_others = np.max(logits_excl_true, axis=1, keepdims=True)
    return z_y - max_z_others

def gaussian_likelihood_ratio(x, mu_in, sigma_in, mu_out, sigma_out):
    """
    Log-likelihood ratio log N(x; mu_in, sigma_in) - log N(x; mu_out, sigma_out).

    Both standard deviations are floored at 1e-10 to avoid division by zero.
    The per-element log ratios are summed over axis 1, so ``x`` must be at
    least two-dimensional; the result has one value per row of ``x``.
    """
    floor = 1e-10
    sigma_in = np.maximum(sigma_in, floor)
    sigma_out = np.maximum(sigma_out, floor)

    # The four terms of the difference of the two Gaussian log-densities.
    log_norm_in = -0.5 * np.log(2 * np.pi * sigma_in**2)
    quad_in = 0.5 * ((x - mu_in)**2) / sigma_in**2
    log_norm_out = 0.5 * np.log(2 * np.pi * sigma_out**2)
    quad_out = 0.5 * ((x - mu_out)**2) / sigma_out**2

    per_element = log_norm_in - quad_in + log_norm_out + quad_out
    return np.sum(per_element, axis=1)

def one_sided_test(x, mu_out, sigma_out):
    """
    One-sided test for the offline variant (equation 4 from the paper).

    Evaluates, in log space, Λ = 1 - Pr[Z > x] with Z ~ N(mu_out, sigma_out²),
    i.e. the Gaussian CDF of the observed score under the OUT distribution.
    The statistic is monotonically increasing in ``x``: the further a score
    lies ABOVE what non-member shadow models produce, the more likely the
    example is a member. (A symmetric quadratic such as -(x - mu)²/σ² would
    instead rank such scores as the least likely members.)

    Args:
        x: observed scores, at least two-dimensional; log-CDF values are
            summed over axis 1.
        mu_out: mean(s) of the OUT score distribution, broadcastable to ``x``.
        sigma_out: standard deviation(s) of the OUT distribution,
            broadcastable to ``x``; floored at 1e-10.

    Returns:
        Array with one log-CDF score per row of ``x``.
    """
    # Local import: keeps the dependency next to its only use. log_ndtr is the
    # numerically stable log of the standard normal CDF.
    from scipy.special import log_ndtr

    eps = 1e-10
    sigma_out = np.maximum(sigma_out, eps)
    z = (np.asarray(x, dtype=np.float64) - mu_out) / sigma_out
    return np.sum(log_ndtr(z), axis=1)

class LiRA:
    """
    Likelihood Ratio Attack (LiRA) for membership inference.

    Shadow-model scores are modelled as Gaussians. The mean is always
    estimated per challenge example; only the variance is either shared by all
    examples (``use_global_variance=True``) or estimated per example.
    """

    def __init__(
        self,
        n_shadows=256,
        attack_type='online',
        score_type='logit_stable',
        use_global_variance=True,
        debug=False
    ):
        """
        Initialize LiRA attack

        Args:
            n_shadows: number of shadow models (default 256)
            attack_type: 'online' or 'offline' (default 'online')
            score_type: 'logit_stable', 'logit_unstable' or 'logits' (default 'logit_stable')
            use_global_variance: if True use global variance for all examples (default True)
            debug: if True enable debug prints (default False)
        """
        self.n_shadows = n_shadows
        self.attack_type = attack_type
        self.score_type = score_type
        self.use_global_variance = use_global_variance
        self.debug = debug

    @staticmethod
    def _as_probabilities(preds, atol=1e-4):
        """Return ``preds`` as class probabilities, applying a softmax over the
        last axis when the finite rows are not already non-negative and
        summing to 1 (within ``atol``)."""
        preds = np.asarray(preds, dtype=np.float64)
        finite_rows = preds[np.isfinite(preds).all(axis=-1)]
        already_normalized = finite_rows.size == 0 or (
            np.all(finite_rows >= 0)
            and np.allclose(finite_rows.sum(axis=-1), 1.0, atol=atol)
        )
        if already_normalized:
            return preds
        # Shift by the row maximum for a numerically stable softmax.
        exp = np.exp(preds - np.max(preds, axis=-1, keepdims=True))
        return exp / np.sum(exp, axis=-1, keepdims=True)

    def compute_scores(self, preds, y_true):
        """
        Map model outputs to per-example scores according to ``score_type``.

        The probability-based scores ('logit_stable', 'logit_unstable') first
        pass ``preds`` through ``_as_probabilities`` so that raw logits are not
        clipped into [0, 1] as if they were probabilities.
        NOTE(review): this assumes un-normalized inputs are pre-softmax logits
        — confirm against what the attacked/shadow models actually emit.

        Raises:
            ValueError: if ``score_type`` is not one of the supported values.
        """
        if self.score_type == 'logit_stable':
            return stable_logit(self._as_probabilities(preds), y_true)
        elif self.score_type == 'logit_unstable':
            return unstable_logit(self._as_probabilities(preds), y_true)
        elif self.score_type == 'logits':
            return logits_to_score(preds, y_true)
        else:
            raise ValueError(f"Score type {self.score_type} not supported")

    def _estimate_gaussian(self, scores):
        """Estimate Gaussian parameters from shadow scores stacked along axis 0.

        Returns ``(mu, var)``: ``mu`` is the per-example mean over shadow
        models; ``var`` is a single scalar when ``use_global_variance`` is set,
        otherwise the per-example variance. Variances are floored at the
        smallest positive float so later divisions stay finite.
        """
        tiny = np.finfo(float).tiny
        # Per-example mean in both modes: a single global mean would reduce
        # the attack to a function of the target score alone.
        mu = np.mean(scores, axis=0)
        if self.use_global_variance:
            var = np.var(scores.astype(np.float64), axis=(0, 1), ddof=1)
            var = np.maximum(np.nanmean(var), tiny)
        else:
            var = np.maximum(np.var(scores, axis=0), tiny)
        return mu, var

    def fit_predict(self, shadow_preds_in, shadow_preds_out, target_preds, y_challenge):
        """
        Fit the IN/OUT Gaussians on the shadow predictions and score the target.

        Args:
            shadow_preds_in: predictions of the IN shadow models, stacked along
                axis 0. Only used (and only converted) for the online attack.
            shadow_preds_out: predictions of the OUT shadow models, stacked
                along axis 0.
            target_preds: predictions of the attacked model on the challenge
                examples.
            y_challenge: labels of the challenge examples.

        Returns:
            ``(lr, stats, mask)``: the membership scores with NaN entries
            removed, a dict with the fitted ``mu_*``/``var_*`` parameters, and
            the boolean mask (over all challenge examples) of the entries kept.
        """
        shadow_preds_out = shadow_preds_out.astype(np.float64)
        target_preds = target_preds.astype(np.float64)
        y_challenge = y_challenge.astype(int)

        # 1. Compute scores for shadow OUT and for the target model
        scores_out = np.array([self.compute_scores(p, y_challenge) for p in shadow_preds_out])
        target_scores = self.compute_scores(target_preds, y_challenge)

        if self.debug:
            print("Scores OUT shape:", scores_out.shape)
            print("Scores OUT dtype:", scores_out.dtype)
            print("Scores OUT sample values:", scores_out[0:1, 0:5])
            print("Target scores shape:", target_scores.shape)
            print("Target scores dtype:", target_scores.dtype)
            print("Target scores sample values:", target_scores[0:5])

        # 2. Estimate Gaussian parameters for OUT
        mu_out, var_out = self._estimate_gaussian(scores_out)

        if self.debug:
            print("mu_out:", mu_out)
            print("var_out:", var_out)

        if self.attack_type == 'online':
            # 3a. Online attack: also fit the IN distribution
            shadow_preds_in = shadow_preds_in.astype(np.float64)
            scores_in = np.array([self.compute_scores(p, y_challenge) for p in shadow_preds_in])

            if self.debug:
                print("Scores IN shape:", scores_in.shape)
                print("Scores IN dtype:", scores_in.dtype)
                print("Scores IN sample values:", scores_in[0:1, 0:5])

            mu_in, var_in = self._estimate_gaussian(scores_in)

            if self.debug:
                print("mu_in:", mu_in)
                print("var_in:", var_in)

            # Compute likelihood ratio
            lr = gaussian_likelihood_ratio(
                target_scores,
                mu_in, np.sqrt(var_in),
                mu_out, np.sqrt(var_out)
            )
        else:
            # 3b. Offline attack: only the OUT distribution is used
            lr = one_sided_test(
                target_scores,
                mu_out, np.sqrt(var_out)
            )

        if self.debug:
            print("Likelihood ratios (before clipping):", lr)

        # Handle extreme numerical values; NaN entries are dropped and the
        # mask is returned so the caller can filter the membership labels.
        lr = np.clip(lr, -1e10, 1e10)
        mask = ~np.isnan(lr)
        lr = lr[mask]

        if self.debug:
            print("Likelihood ratios (after clipping):", lr)
            print("Number of valid likelihood ratios:", len(lr))

        stats = {
            'mu_out': mu_out,
            'var_out': var_out
        }

        if self.attack_type == 'online':
            stats.update({
                'mu_in': mu_in,
                'var_in': var_in
            })

        return lr, stats, mask

    @staticmethod
    def _tpr_at_fpr(fpr, tpr, fpr_threshold):
        """Return the TPR of the last ROC point whose FPR is <= ``fpr_threshold``
        (0.0 if there is none). ``fpr`` must be sorted in ascending order."""
        idx = np.searchsorted(fpr, fpr_threshold, side='right') - 1
        if idx < 0:
            return 0.0
        return tpr[idx]

    def evaluate(self, likelihood_ratios, y_membership):
        """
        Compute evaluation metrics

        Returns a dict with the ROC curve ('fpr', 'tpr'), its 'auc' and
        'tpr_at_fpr', the TPR achieved without exceeding an FPR of 0.001,
        0.01 and 0.1 respectively.
        """
        fpr, tpr, _ = roc_curve(y_membership, likelihood_ratios)
        auc_score = auc(fpr, tpr)

        # TPR at specific FPR: use the last operating point that stays within
        # the FPR budget. Taking the first point at or above the threshold
        # would report a TPR obtained at a higher FPR than requested.
        fpr_thresholds = [0.001, 0.01, 0.1]
        tpr_at_fpr = {
            fpr_threshold: self._tpr_at_fpr(fpr, tpr, fpr_threshold)
            for fpr_threshold in fpr_thresholds
        }

        return {
            'fpr': fpr,
            'tpr': tpr,
            'auc': auc_score,
            'tpr_at_fpr': tpr_at_fpr
        }

### Challenge data preparation

In [11]:
def prepare_challenge_data(data_in, data_out, challenge_size=100, seed=42):
    """Build the challenge set by sampling ``challenge_size`` items, without
    replacement, from each of ``data_in`` and ``data_out``.

    Both datasets are indexed as ``data[i][0]`` (input) and ``data[i][1]``
    (label). Each index draw uses its own ``RandomState(seed)``.

    Returns:
        ``(X_challenge, y_challenge, y_membership)`` where the first
        ``challenge_size`` entries come from ``data_in`` (membership 1) and the
        remaining ones from ``data_out`` (membership 0).
    """
    member_idx = np.random.RandomState(seed).choice(len(data_in), challenge_size, replace=False)
    non_member_idx = np.random.RandomState(seed).choice(len(data_out), challenge_size, replace=False)

    def _gather(field):
        # Only the sampled items are read from the datasets.
        from_in = [data_in[i][field] for i in member_idx]
        from_out = [data_out[i][field] for i in non_member_idx]
        return np.concatenate([from_in, from_out])

    X_challenge = _gather(0)
    y_challenge = _gather(1)

    y_membership = np.concatenate([np.ones(challenge_size), np.zeros(challenge_size)])

    return X_challenge, y_challenge, y_membership

In [12]:
eval_transformations = get_transform(dataset_name, (224,224), huge_augment=None, horizontal_flip=False, random_crop=False, random_erasing=False, train=False)

# dataset name -> (lazy class lookup, train-only kwargs, test-only kwargs,
# kwargs shared by both splits). The class is resolved through a lambda so
# that only the selected dataset's class name is ever looked up.
_MEDMNIST_KWARGS = {'size': 224, 'as_rgb': True, 'download': True}
_DATASET_TABLE = {
    'cifar10': (lambda: dset.CIFAR10, {'train': True}, {'train': False}, {'download': True}),
    'cifar100': (lambda: dset.CIFAR100, {'train': True}, {'train': False}, {'download': True}),
    'pets': (lambda: OxfordPetsDataset, {'split': 'train'}, {'split': 'test'}, {}),
    'cars': (lambda: dset.StanfordCars, {'split': 'train'}, {'split': 'test'}, {'download': False}),
    'food': (lambda: dset.Food101, {'split': 'train'}, {'split': 'test'}, {'download': True}),
    'tiny': (lambda: TinyImageNetDataset, {'split': 'train'}, {'split': 'valid'}, {}),
    'dermamnist': (lambda: DermaMNIST, {'split': 'train'}, {'split': 'test'}, _MEDMNIST_KWARGS),
    'bloodmnist': (lambda: BloodMNIST, {'split': 'train'}, {'split': 'test'}, _MEDMNIST_KWARGS),
    'stl': (lambda: dset.STL10, {'split': 'train'}, {'split': 'test'}, {'download': True}),
    'imagenette': (lambda: dset.Imagenette, {'split': 'train'}, {'split': 'val'}, {'download': False}),
    'caltech101': (lambda: Caltech101Dataset, {'split': 'train'}, {'split': 'test'}, {}),
    'imagewoof': (lambda: ImagewoofDataset, {'split': 'train'}, {'split': 'validation'}, {}),
}

if dataset_name not in _DATASET_TABLE:
    raise TypeError(f"Unknown dataset: {dataset_name}.")

_resolve_cls, _train_kwargs, _test_kwargs, _shared_kwargs = _DATASET_TABLE[dataset_name]
_dataset_cls = _resolve_cls()

# Both splits use the evaluation transform.
data_train = _dataset_cls(root=dataset_path, transform=eval_transformations, **_train_kwargs, **_shared_kwargs)
data_test = _dataset_cls(root=dataset_path, transform=eval_transformations, **_test_kwargs, **_shared_kwargs)

In [13]:
# Sample the challenge set: the training split supplies the "in" examples
# (membership 1) and the test split the "out" examples (membership 0).
X_challenge, y_challenge, y_membership = prepare_challenge_data(
    data_in=data_train,    
    data_out=data_test,    
    challenge_size=challenge_size,   
    seed=seed
)

### Set the path to the model to attack

In [None]:
# Define the path of the model to attack (a checkpoint file; it is loaded
# further below with torch.load and passed to load_state_dict).
model_to_attack = "../storage/trained_synthetic_classifiers/20250201_0051.pth"
                                
# Print current parameters for verification
print(f"Parameters defined:\n"
      f"  Model to Attack:\t{model_to_attack}\n")

### Load the model to attack and the challenge data

In [None]:
# Rebuild the architecture and restore the weights of the model under attack.
target_model = extract_ofa_network(ofa_config, num_classes)
# map_location remaps the checkpoint's tensors onto the selected device, so a
# checkpoint saved on a GPU can still be loaded when `device` is 'cpu'.
target_model.load_state_dict(torch.load(model_to_attack, map_location=device))
target_model.to(device)
target_model.eval()

# Run the target model over the challenge examples in batches, without
# gradient tracking, and collect its raw outputs on the CPU.
target_preds = []
with torch.no_grad():
    for images in DataLoader(X_challenge, batch_size=512):
        pred = target_model(images.to(device)).cpu()
        target_preds.append(pred)

target_preds = torch.cat(target_preds, dim=0).numpy()
target_preds.shape

### Get the shadows' predictions

In [None]:
def get_shadow_predictions(models_path, dataset, N_SHADOW_MODELS, device='cuda'):
    """
    Run every shadow model on ``dataset`` and collect its outputs.

    Models are read from ``{models_path}/shadow_model_{in|out}_{i}/model.pt``
    for ``i`` in ``range(N_SHADOW_MODELS // 2)``.

    Args:
        models_path: directory containing the shadow model folders.
        dataset: inputs to evaluate; iterated with a DataLoader in batches of
            512, each batch being passed to the model as-is.
        N_SHADOW_MODELS: total number of shadow models (half IN, half OUT).
        device: device used for inference; a 'cuda' request falls back to
            'cpu' when CUDA is not available.

    Returns:
        Tuple ``(preds_in, preds_out)`` of arrays stacking the per-model
        outputs along axis 0.
    """
    if str(device).startswith('cuda') and not torch.cuda.is_available():
        # Same fallback as the notebook-level `device` selection.
        device = 'cpu'

    shadow_preds = {'in': [], 'out': []}
    # No shuffling, so a single loader can be reused for every model.
    loader = DataLoader(dataset, batch_size=512)

    for i in tqdm(range(N_SHADOW_MODELS // 2)):
        for model_type in ['in', 'out']:
            # The checkpoints are whole pickled modules (not state dicts), so
            # weights_only=False is required on PyTorch >= 2.6, where the
            # default became True. NOTE: this unpickles arbitrary code; only
            # load shadow models produced by this notebook.
            model = torch.load(
                f'{models_path}/shadow_model_{model_type}_{i}/model.pt',
                map_location=device,
                weights_only=False,
            )
            model.to(device)
            model.eval()

            predictions = []
            with torch.no_grad():
                for images in loader:
                    pred = model(images.to(device)).cpu()
                    predictions.append(pred)

            shadow_preds[model_type].append(torch.cat(predictions).numpy())
            # Free the model before loading the next one.
            del model
            torch.cuda.empty_cache()

    return np.array(shadow_preds['in']), np.array(shadow_preds['out'])

# Evaluate all shadow models on the challenge inputs (device left at the function's default).
shadow_preds_in, shadow_preds_out = get_shadow_predictions(models_path, X_challenge, num_shadows)

### Test all the combinations and plot the results!

In [None]:
# A shadow index is discarded when either its IN or its OUT prediction array
# contains at least one NaN (reduction over axes 1 and 2).
mask_in = np.any(np.isnan(shadow_preds_in), axis=(1, 2))
mask_out = np.any(np.isnan(shadow_preds_out), axis=(1, 2))
mask_combined = np.logical_or(mask_in, mask_out)
mask_keep = np.logical_not(mask_combined)

# The same mask is applied to both arrays so IN and OUT stay aligned.
filtered_shadow_preds_in, filtered_shadow_preds_out = (
    shadow_preds_in[mask_keep],
    shadow_preds_out[mask_keep],
)
filtered_shadow_preds_in.shape, filtered_shadow_preds_out.shape

In [None]:
def test_all_combinations(N_SHADOW_MODELS, filtered_shadow_preds_in, filtered_shadow_preds_out, 
                         target_preds, y_challenge, y_membership):
    """
    Run LiRA for every (attack type, score type, variance mode) combination.

    Each combination is fitted and evaluated, its AUC is printed, the ROC
    curve of the best-AUC combination is plotted (linear and log-log axes) and
    a summary table sorted by AUC is printed.

    Returns:
        ``(best_result, results_df)``: the result dict with the highest AUC
        and a DataFrame with one row per combination.
    """

    def _draw_roc(ax, curve, title, log_axes):
        # ROC curve plus the chance diagonal on the given axes.
        ax.plot(curve['fpr'], curve['tpr'],
                color='tab:blue', lw=2)
        ax.plot([0, 1], [0, 1], 'k--', lw=1)
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title(title)
        if log_axes:
            ax.set_xscale('log')
            ax.set_yscale('log')
        ax.grid(True, alpha=0.3)

    # Search space, iterated in the same order as itertools.product yields it
    attack_types = ['offline', 'online']
    score_types = ['logit_stable', 'logit_unstable', 'logits']
    variance_options = [True, False]

    results = []

    for attack_type, score_type, use_global_variance in product(attack_types, score_types, variance_options):
        attack = LiRA(
            n_shadows=N_SHADOW_MODELS,
            attack_type=attack_type,
            score_type=score_type,
            use_global_variance=use_global_variance
        )

        # fit_predict also returns the mask of challenge entries it kept;
        # the membership labels are filtered with it before evaluation.
        likelihood_ratios, stats, mask = attack.fit_predict(
            shadow_preds_in=filtered_shadow_preds_in,
            shadow_preds_out=filtered_shadow_preds_out,
            target_preds=target_preds,
            y_challenge=y_challenge
        )
        metrics = attack.evaluate(likelihood_ratios, y_membership[mask])

        results.append({
            'attack_type': attack_type,
            'score_type': score_type,
            'use_global_variance': use_global_variance,
            'auc': metrics['auc'],
            'metrics': metrics
        })

        print(f"Combinazione testata - Attack: {attack_type}, Score: {score_type}, "
              f"Global Variance: {use_global_variance}, AUC: {metrics['auc']:.4f}")

    # Highest AUC wins (first one on ties)
    best_result = max(results, key=lambda x: x['auc'])

    print("\nMigliore combinazione trovata:")
    print(f"Attack Type: {best_result['attack_type']}")
    print(f"Score Type: {best_result['score_type']}")
    print(f"Global Variance: {best_result['use_global_variance']}")
    print(f"AUC: {best_result['auc']:.4f}")

    # Side-by-side ROC plots of the best result: linear scale and log-log scale
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    best_curve = best_result['metrics']
    _draw_roc(ax1, best_curve, f'ROC Curve (AUC = {best_result["auc"]:.3f})', log_axes=False)
    _draw_roc(ax2, best_curve, 'ROC Curve (Log Scale)', log_axes=True)

    plt.tight_layout()
    plt.show()

    # Tabular summary of every combination
    results_df = pd.DataFrame({
        'Attack Type': [r['attack_type'] for r in results],
        'Score Type': [r['score_type'] for r in results],
        'Global Variance': [r['use_global_variance'] for r in results],
        'AUC': [r['auc'] for r in results],
    })

    print("\nTutti i risultati ordinati per AUC:")
    print(results_df.sort_values('AUC', ascending=False))

    return best_result, results_df

In [None]:
# Evaluate every attack configuration on the NaN-filtered shadow predictions
# and keep both the best configuration and the full results table.
best_result, results_df = test_all_combinations(
    num_shadows,
    filtered_shadow_preds_in,
    filtered_shadow_preds_out,
    target_preds,
    y_challenge,
    y_membership
)