<a href="https://colab.research.google.com/github/MMathisLab/AROS/blob/main/Notebooks/AROS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Adversarially Robust Out-of-Distribution Detection Using Lyapunov-Stabilized Embeddings

This notebook is designed to replicate and analyze the results presented in Table 1 of the AROS paper, focusing on out-of-distribution detection performance under both attack scenarios and clean evaluation. The dataset configurations involve using CIFAR-10 and CIFAR-100 as in-distribution and out-of-distribution datasets. The notebook is structured to load a pre-trained model as the encoder, followed by generating fake OOD embeddings through sampling. The model is then trained using the designed loss function and evaluated across various OOD detection benchmarks to assess its performance under different conditions.



# Import packages

In [1]:
!pip install git+https://github.com/RobustBench/robustbench.git
!pip install aros-node==0.0.1rc1

Collecting git+https://github.com/RobustBench/robustbench.git
  Cloning https://github.com/RobustBench/robustbench.git to /tmp/pip-req-build-i22rthe1
  Running command git clone --filter=blob:none --quiet https://github.com/RobustBench/robustbench.git /tmp/pip-req-build-i22rthe1
  Resolved https://github.com/RobustBench/robustbench.git to commit 776bc95bb4167827fb102a32ac5aea62e46cfaab
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting autoattack@ git+https://github.com/fra31/auto-attack.git@a39220048b3c9f2cca9a4d3a54604793c68eca7e#egg=autoattack (from robustbench==1.1)
  Using cached autoattack-0.1-py3-none-any.whl
Collecting autoattack@ git+https://github.com/fra31/auto-attack.git@a39220048b3c9f2cca9a4d3a54604793c68eca7e#egg=autoattack (from robustbench->aros-node==0.0.1rc1)
  Using cached autoattack-0.1-py3-none-any.whl


In [2]:
import aros_node
import argparse
import torch
import torch.nn as nn
from tqdm.notebook import tqdm
import numpy as np

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


# Set hyperparameters & dataloaders

In [3]:
def _str2bool(value):
    """Convert a command-line token such as 'true' or '0' into a real bool.

    argparse's `type=bool` calls `bool(token)`, which is True for every
    non-empty string (including 'False'), so an explicit converter is needed.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


# Pre-trained backbone used by default for each supported in-distribution dataset.
DEFAULT_MODEL_NAMES = {
    'cifar10': 'Rebuffi2021Fixing_70_16_cutmix_extra',
    'cifar100': 'Wang2023Better_WRN-70-16',
}

parser = argparse.ArgumentParser(description="Hyperparameters for the script")

# Hyperparameters of the run; edit the defaults below to change the experiment.
parser.add_argument('--fast', type=_str2bool, default=True, help='Toggle between fast and full fake data generation modes')
parser.add_argument('--epoch1', type=int, default=2, help='Number of epochs for stage 1')
parser.add_argument('--epoch2', type=int, default=1, help='Number of epochs for stage 2')
parser.add_argument('--epoch3', type=int, default=2, help='Number of epochs for stage 3')
parser.add_argument('--in_dataset', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='The in-distribution dataset to be used')
parser.add_argument('--threat_model', type=str, default='Linf', help='Adversarial threat model for robust training')
# argparse does not apply `type` to non-string defaults, so the default is written as a float.
parser.add_argument('--noise_std', type=float, default=1.0, help='Standard deviation of noise for generating noisy fake embeddings')
parser.add_argument('--attack_eps', type=float, default=8/255, help='Perturbation bound (epsilon) for PGD attack')
parser.add_argument('--attack_steps', type=int, default=10, help='Number of steps for the PGD attack')
parser.add_argument('--attack_alpha', type=float, default=2.5 * (8/255) / 10, help='Step size (alpha) for each PGD attack iteration')
# The default is resolved after parsing because it depends on --in_dataset.
parser.add_argument('--model_name', type=str, default=None, choices=list(DEFAULT_MODEL_NAMES.values()), help='The pre-trained model to be used for feature extraction')

# An empty argument list keeps every default; the kernel's own sys.argv must not be parsed.
args = parser.parse_args([])
if args.model_name is None:
    args.model_name = DEFAULT_MODEL_NAMES[args.in_dataset]

num_classes = 10 if args.in_dataset == 'cifar10' else 100

trainloader, testloader, test_set, ID_OOD_loader = aros_node.get_loaders(in_dataset=args.in_dataset)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Fake embedding generation

In [4]:
# Load the pre-trained robust model that is used as the feature extractor.
robust_backbone = aros_node.load_model(
    model_name=args.model_name,
    dataset=args.in_dataset,
    threat_model=args.threat_model,
).to(device)

# Swap the last registered child module for an identity so the forward pass
# returns that module's input; the removed module is kept in `last_layer`
# because the training call needs it. (Presumably this is the classifier head.)
last_layer_name, last_layer = list(robust_backbone.named_children())[-1]
setattr(robust_backbone, last_layer_name, nn.Identity())
fake_loader = None

# Total number of fake OOD embeddings to produce (one class worth of samples).
num_fake_samples = len(trainloader.dataset) // num_classes

# Embed the whole training set without tracking gradients.
embedding_batches, label_batches = [], []
with torch.no_grad():
    for imgs, lbls in trainloader:
        imgs = imgs.to(device, non_blocking=True)
        embedding_batches.append(robust_backbone(imgs).cpu())  # move to CPU once per batch
        label_batches.append(lbls)
embeddings = torch.cat(embedding_batches).numpy()
labels = torch.cat(label_batches).numpy()

print("embedding computed...")

if not args.fast:
    # Full mode: fit one full-covariance Gaussian per class and keep only the
    # least likely draws as fake OOD embeddings.
    # NOTE(review): `GaussianMixture` is not imported anywhere in this notebook;
    # presumably it is sklearn.mixture.GaussianMixture - import it before
    # running with fast=False.
    gmm_dict = {}
    for cls in np.unique(labels):
        gmm_dict[cls] = GaussianMixture(n_components=1, covariance_type='full').fit(embeddings[labels == cls])

    print("fake crafing...")

    # The original loop referenced an undefined `num_samples_needed`. Splitting
    # the budget evenly across classes makes the total match the fast mode
    # (num_fake_samples) - TODO confirm this is the intended per-class count.
    num_samples_needed = num_fake_samples // len(gmm_dict)

    fake_data = []
    for cls, gmm in gmm_dict.items():
        kept_batches, num_kept = [], 0
        # Count accepted samples, not batches, so the loop stops as soon as the
        # per-class budget is met.
        while num_kept < num_samples_needed:
            candidates = gmm.sample(100)[0]
            log_likelihood = gmm.score_samples(candidates)
            low_likelihood = candidates[log_likelihood < np.quantile(log_likelihood, 0.001)]
            kept_batches.append(low_likelihood)
            num_kept += len(low_likelihood)
        fake_data.append(np.vstack(kept_batches)[:num_samples_needed])

    fake_data = torch.tensor(np.vstack(fake_data)).float()
    fake_data = aros_node.F.normalize(fake_data, p=2, dim=1)

    # Fake samples get the extra label index `num_classes` (was hard-coded to 10,
    # which collides with a real class when the in-distribution set is CIFAR-100).
    fake_labels = torch.full((fake_data.shape[0],), num_classes)
    fake_loader = aros_node.DataLoader(aros_node.TensorDataset(fake_data, fake_labels), batch_size=128, shuffle=True)
else:
    # Fast mode: perturb the real embeddings with Gaussian noise.
    # NOTE(review): this value is hard-coded and `args.noise_std` (default 1) is
    # not used here - confirm which one is intended.
    noise_std = 0.1  # standard deviation of noise
    clean_embeddings = torch.tensor(embeddings)
    noisy_embeddings = clean_embeddings + noise_std * torch.randn_like(clean_embeddings)

    # L2-normalize each noisy embedding, then keep the first num_fake_samples rows.
    noisy_embeddings = aros_node.F.normalize(noisy_embeddings, p=2, dim=1)[:num_fake_samples]

    # Every fake embedding is labelled with the extra index `num_classes`.
    fake_labels = torch.full((noisy_embeddings.shape[0],), num_classes)
    fake_loader = aros_node.DataLoader(aros_node.TensorDataset(noisy_embeddings, fake_labels), batch_size=128, shuffle=True)



  checkpoint = torch.load(model_path, map_location=torch.device('cpu'))


embedding computed...


# Train and eval

In [None]:
# Train on the real loaders plus the fake OOD embeddings; `last_layer` is the
# module that was removed from the backbone in the previous cell.
final_model = aros_node.stability_loss_function_(
    trainloader,
    testloader,
    robust_backbone,
    num_classes,
    fake_loader,
    last_layer,
    args,
)

# Build the PGD attack from the attack hyperparameters in `args`.
test_attack = aros_node.evaluate.PGD_AUC(
    final_model,
    eps=args.attack_eps,
    steps=args.attack_steps,
    alpha=args.attack_alpha,
    num_classes=num_classes,
)

# Clean evaluation on the combined ID/OOD loader; the return value is not used
# (the metrics appear in the cell output).
aros_node.evaluate.get_clean_AUC(final_model, ID_OOD_loader, device, num_classes)

# Same loader, evaluated under the PGD attack.
adv_auc = aros_node.evaluate.get_auc_adversarial(
    model=final_model,
    test_loader=ID_OOD_loader,
    test_attack=test_attack,
    device=device,
    num_classes=num_classes,
)
print(f"Adv AUC: {adv_auc}")


Sequential(
  (0): DMWideResNet(
    (init_conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (layer): Sequential(
      (0): _BlockGroup(
        (block): Sequential(
          (0): _Block(
            (batchnorm_0): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu_0): Swish()
            (conv_0): Conv2d(16, 256, kernel_size=(3, 3), stride=(1, 1), bias=False)
            (batchnorm_1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu_1): Swish()
            (conv_1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (shortcut): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
          (1): _Block(
            (batchnorm_0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu_0): Swish()
            (conv_0): Conv2d(256, 256, kernel_size

Training ODE block with loss function:   0%|          | 0/500 [00:00<?, ?it/s]

Loss 3.293180465698242
Epoch 1, Learning Rate: 0.005

Epoch: 0

Epoch: 1


  0%|          | 0/2500 [00:00<?, ?batch/s]

FPR95: 0.4375
AUROC is: 0.865082755
AUPR: 0.8353065616255184


  0%|          | 0/2500 [00:00<?, ?batch/s]

# Extra experiments

In [None]:
import os
import zipfile

!pip install wget
import wget

from pathlib import Path
import torchvision
from torchvision import transforms
import tarfile

# Side length (pixels) the OOD images are resized to, and the OOD sets to prepare.
image_size = 32
load_out_names = ["places365", "LSUN", "iSUN"]

if "places365" in load_out_names:
    # Define the directory path and create it if it does not exist
    base_dir = "./datasets/data"
    os.makedirs(base_dir, exist_ok=True)

    # Files expected under base_dir, mapped to the URL each one is fetched from.
    places365_files = {
        "categories_places365.txt": "https://dl.dropboxusercontent.com/s/enr71zpolzi1xzm/categories_places365.txt",
        "places365_val.txt": "https://dl.dropboxusercontent.com/s/gaf1ygpdnkhzyjo/places365_val.txt",
        "val_256.tar": "https://dl.dropboxusercontent.com/s/3pwqsyv33f6if3z/val_256.tar",
    }
    for filename, url in places365_files.items():
        target = os.path.join(base_dir, filename)
        # Skip files that are already present so re-running the cell is cheap.
        if not Path(target).is_file():
            wget.download(url, out=target)

    # Extract val_256.tar if the val_256 directory does not exist
    dest = os.path.join(base_dir, "val_256.tar")
    dest_final = os.path.join(base_dir, "val_256")
    if not Path(dest_final).is_dir():
        # The archive comes from the network: use the "data" extraction filter
        # when this Python version provides it, so members cannot escape base_dir.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        with tarfile.open(dest) as tar:
            tar.extractall(path=base_dir, **extract_kwargs)

    # Load the Places365 validation split from the files prepared above
    places365 = torchvision.datasets.Places365(
        root=base_dir,
        split='val',
        small=True,
        download=False,
        transform=transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor()
        ])
    )






if "LSUN" in load_out_names:
    # Define the base directory and ensure it exists
    base_dir = "./datasets/data"
    os.makedirs(base_dir, exist_ok=True)

    # Archive location and the directory it is unpacked into
    dest = os.path.join(base_dir, "LSUN_resize.tar.gz")
    lsun_root = os.path.join(base_dir, "LSUN_resize")

    # Extraction is keyed on the extracted directory rather than on the archive:
    # previously an archive that was present but never unpacked was skipped
    # forever, and ImageFolder then failed on the missing directory.
    if not Path(lsun_root).is_dir():
        if not Path(dest).is_file():
            wget.download("https://bit.ly/3wA55Wb", out=dest)
        # The archive comes from the network: use the "data" extraction filter
        # when this Python version provides it.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        with tarfile.open(dest) as tar:
            tar.extractall(path=lsun_root, **extract_kwargs)

    # Resize only when the requested size differs from 32
    transform = transforms.ToTensor() if image_size == 32 else transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(image_size)
    ])

    # Load the LSUN dataset
    LSUN = torchvision.datasets.ImageFolder(root=lsun_root, transform=transform)


if "iSUN" in load_out_names:
    # Define the base directory here as well, so this branch does not depend on
    # an earlier branch having run.
    base_dir = "./datasets/data"
    os.makedirs(base_dir, exist_ok=True)

    # Archive location and the directory it is unpacked into
    dest = os.path.join(base_dir, "iSUN.tar.gz")
    isun_root = os.path.join(base_dir, "iSUN")

    # Extraction is keyed on the extracted directory rather than on the archive:
    # previously an archive that was present but never unpacked was skipped
    # forever, and ImageFolder then failed on the missing directory.
    if not Path(isun_root).is_dir():
        if not Path(dest).is_file():
            wget.download("https://bit.ly/3yRMTJe", out=dest)
        # The archive comes from the network: use the "data" extraction filter
        # when this Python version provides it.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        with tarfile.open(dest) as tar:
            tar.extractall(path=isun_root, **extract_kwargs)

    # Resize only when the requested size differs from 32
    transform = transforms.ToTensor() if image_size == 32 else transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(image_size)
    ])

    # Load the iSUN dataset
    iSUN = torchvision.datasets.ImageFolder(root=isun_root, transform=transform)




class LabelChangedDataset(aros_node.Dataset):
    """Dataset wrapper that reports one fixed label for every sample.

    Items of the wrapped dataset must be (image, label) pairs; the image is
    passed through untouched and the original label is discarded.
    """

    def __init__(self, original_dataset, new_label):
        """Store the wrapped dataset and the label to return for each item."""
        self.new_label = new_label
        self.original_dataset = original_dataset

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        wrapped = self.original_dataset
        return len(wrapped)

    def __getitem__(self, idx):
        """Return (image, new_label) for the sample at position `idx`."""
        image, _original_label = self.original_dataset[idx]
        relabelled = (image, self.new_label)
        return relabelled



# Build the SVHN transform explicitly; it previously reused whatever `transform`
# the LSUN/iSUN branches happened to leave behind, which fails when neither ran.
svhn_transform = transforms.ToTensor() if image_size == 32 else transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(image_size)
])

# Download and load the SVHN test set
svhn = torchvision.datasets.SVHN(root='./datasets/data', split='test', download=True, transform=svhn_transform)

# Relabel every OOD sample with the extra class index `num_classes`.
# Only datasets that were actually loaded above are wrapped.
if "iSUN" in load_out_names:
    iSUN = LabelChangedDataset(iSUN, num_classes)

if "LSUN" in load_out_names:
    LSUN = LabelChangedDataset(LSUN, num_classes)

if "places365" in load_out_names:
    places365 = LabelChangedDataset(places365, num_classes)

svhn = LabelChangedDataset(svhn, num_classes)


In [None]:
# Clean OOD evaluation: `test_set` (from get_loaders) followed by relabelled iSUN.
test_dataset_isun = aros_node.ConcatDataset([test_set, iSUN])

testloader_isun = aros_node.DataLoader(test_dataset_isun, shuffle=False, batch_size=64)

# Call through `aros_node.evaluate`, the same path the training cell uses.
aros_node.evaluate.get_clean_AUC(final_model, testloader_isun, device, num_classes)


In [None]:
# Clean OOD evaluation: `test_set` (from get_loaders) followed by relabelled LSUN.
test_dataset_LSUN = aros_node.ConcatDataset([test_set, LSUN])

testloader_LSUN = aros_node.DataLoader(test_dataset_LSUN, shuffle=False, batch_size=64)

# Call through `aros_node.evaluate`, the same path the training cell uses.
aros_node.evaluate.get_clean_AUC(final_model, testloader_LSUN, device, num_classes)

In [None]:
# Clean OOD evaluation: `test_set` (from get_loaders) followed by relabelled Places365.
test_dataset_places365 = aros_node.ConcatDataset([test_set, places365])

testloader_places365 = aros_node.DataLoader(test_dataset_places365, shuffle=False, batch_size=64)

# Call through `aros_node.evaluate`, the same path the training cell uses.
aros_node.evaluate.get_clean_AUC(final_model, testloader_places365, device, num_classes)


In [None]:
# Clean OOD evaluation: `test_set` (from get_loaders) followed by relabelled SVHN.
test_dataset_svhn = aros_node.ConcatDataset([test_set, svhn])

testloader_svhn = aros_node.DataLoader(test_dataset_svhn, shuffle=False, batch_size=64)

# Call through `aros_node.evaluate`, the same path the training cell uses.
aros_node.evaluate.get_clean_AUC(final_model, testloader_svhn, device, num_classes)
