# Composite-Adv Demonstration
This notebook demonstrates, step by step, how to launch our composite adversarial attack (CAA). We use the CIFAR-10 dataset for the demonstration; other datasets can be evaluated in the same way.

![CAA Flow](figures/CAA_Flow.png)

## I. Install `composite-adv` Package

In [None]:
import warnings
warnings.filterwarnings('ignore')
!pip install -q git+https://github.com/IBM/composite-adv.git
!pip install -q --upgrade --no-cache-dir gdown # gdown downloads the pre-trained models from Google Drive.

## II. Dataset

In [None]:
from composite_adv.utilities import make_dataloader
data_loader = make_dataloader(dataset_path="./data/", dataset_name="cifar10", batch_size=256)

## III. Select Models

In [None]:
# Download Pre-trained model
# download_gdrive('Google File ID', 'Saving Path')
from composite_adv.utilities import download_gdrive
download_gdrive('1109eOxG5sSIxCwe_BUKRViMrSF20Ac4c', 'cifar10-resnet_50-gat_fs.pt')

In [None]:
# load model
from composite_adv.utilities import make_model
model = make_model(arch="resnet50", # GAT supports two architectures: ['resnet50','wideresnet']
                   dataset_name="cifar10", # GAT supports three datasets: ['cifar10','svhn','imagenet']
                   checkpoint_path="cifar10-resnet_50-gat_fs.pt")

# Load Madry's model (https://github.com/MadryLab/robustness)
# from composite_adv.utilities import make_madry_model
# model = make_madry_model(arch="resnet50",
#                          dataset_name="cifar10",
#                          checkpoint_path="")


# Load TRADES model (https://github.com/yaodongyu/TRADES)
# from composite_adv.utilities import make_trades_model
# model = make_trades_model(arch="wideresnet",
#                           dataset_name="cifar10",
#                           checkpoint_path="")


# Send to GPU
import torch
if not torch.cuda.is_available():
    print('using CPU, this will be slow')
else:
    model.cuda()

## IV. Evaluate Clean Accuracy

In [None]:
from composite_adv.utilities import robustness_evaluate
from composite_adv.attacks import NoAttack

attack = NoAttack()
robustness_evaluate(model, attack, data_loader)

## V. Evaluate Robust Accuracy

**CAA Configuration**
1. Attack Pool Selection. For simplicity, we use the following abbreviations to specify each attack type:
   `0`: Hue, `1`: Saturation, `2`: Rotation, `3`: Brightness, `4`: Contrast, `5`: $\ell_\infty$

2. Attack Ordering Specification. We provide three ordering options: `'fixed'`, `'random'`, and `'scheduled'`.
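To make the integer codes and the ordering options concrete, here is a small stand-alone sketch. It is illustrative only: `ATTACK_NAMES` and `resolve_order` are hypothetical helpers, not part of `composite_adv`. The `'scheduled'` option is deliberately not reproduced, because that order is computed by `CompositeAttack` itself.

```python
import random

# Hypothetical lookup table mirroring the abbreviations listed above.
ATTACK_NAMES = {
    0: "Hue",
    1: "Saturation",
    2: "Rotation",
    3: "Brightness",
    4: "Contrast",
    5: "L-infinity",
}


def resolve_order(enabled_attack, order_schedule, seed=None):
    """Return attack names in the order they would be applied (illustrative only).

    Only 'fixed' and 'random' are sketched; 'scheduled' is determined
    inside CompositeAttack and is not reproduced here.
    """
    unknown = [code for code in enabled_attack if code not in ATTACK_NAMES]
    if unknown:
        raise ValueError(f"unknown attack codes: {unknown}")
    names = [ATTACK_NAMES[code] for code in enabled_attack]
    if order_schedule == "fixed":
        # Apply the attacks exactly in the order they were listed.
        return names
    if order_schedule == "random":
        # Draw a permutation of the enabled attacks (seeded for reproducibility).
        rng = random.Random(seed)
        shuffled = names[:]
        rng.shuffle(shuffled)
        return shuffled
    if order_schedule == "scheduled":
        raise NotImplementedError("the scheduled order is computed inside CompositeAttack")
    raise ValueError(f"unknown order_schedule: {order_schedule!r}")


print(resolve_order((0, 1, 2), "fixed"))
# ['Hue', 'Saturation', 'Rotation']
```

With `'random'`, each call (or each seed) may yield a different permutation of the same set of attacks, whereas `'fixed'` always follows the tuple passed as `enabled_attack`.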

**Setup**
```python
# Specify Attack

from composite_adv.attacks import CompositeAttack
# Three Attacks (Hue->Saturation->Rotation; Fixed Order)
attack = CompositeAttack(model, dataset="cifar10", enabled_attack=(0,1,2), order_schedule="fixed")
# Semantic Attacks; Random Order
attack = CompositeAttack(model, dataset="cifar10", enabled_attack=(0,1,2,3,4), order_schedule="random")
# Full Attacks; Scheduled Order
attack = CompositeAttack(model, dataset="cifar10", enabled_attack=(0,1,2,3,4,5), order_schedule="scheduled") 

# Model Evaluation
from composite_adv.utilities import robustness_evaluate
robustness_evaluate(model, attack, data_loader)
```

In [None]:
from composite_adv.attacks import CompositeAttack
# Full Attacks; Scheduled Order
attack = CompositeAttack(model, dataset="cifar10", enabled_attack=(0,1,2,3,4,5), order_schedule="scheduled")

from composite_adv.utilities import robustness_evaluate
robust_accuracy, attack_success_rate = robustness_evaluate(model, attack, data_loader)
print("Robust Accuracy:", robust_accuracy)
print("Attack Success Rate:", attack_success_rate)
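The two numbers printed above can be related with a small pure-Python sketch. It assumes one common convention, which may differ from the exact definition inside `robustness_evaluate`: robust accuracy is the share of all samples still classified correctly after the attack, and the attack success rate is the share of originally correct samples that the attack flips to a wrong label. `robust_metrics` is a hypothetical helper and the labels are toy values.

```python
def robust_metrics(labels, clean_preds, adv_preds):
    """Return (robust accuracy, attack success rate) in percent.

    Assumed convention (may differ from composite_adv's robustness_evaluate):
    - robust accuracy: correct predictions after the attack / all samples;
    - attack success rate: clean-correct samples misclassified after the
      attack / clean-correct samples.
    """
    if not (len(labels) == len(clean_preds) == len(adv_preds)):
        raise ValueError("all inputs must have the same length")
    if not labels:
        raise ValueError("need at least one sample")
    robust_correct = sum(a == y for y, a in zip(labels, adv_preds))
    # Keep only the samples the model got right before the attack.
    clean_correct = [(y, a) for y, c, a in zip(labels, clean_preds, adv_preds) if c == y]
    flipped = sum(a != y for y, a in clean_correct)
    robust_acc = 100.0 * robust_correct / len(labels)
    asr = 100.0 * flipped / len(clean_correct) if clean_correct else 0.0
    return robust_acc, asr


labels = [3, 8, 8, 0]
clean_preds = [3, 8, 1, 0]  # 3 of 4 correct before the attack
adv_preds = [3, 2, 1, 5]    # only the first sample survives the attack
print(robust_metrics(labels, clean_preds, adv_preds))
```

Here robust accuracy is 25.0% (1 of 4) and the attack success rate is about 66.7% (2 of the 3 clean-correct samples are flipped), which shows why the two metrics need not add up to 100%.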

## VI. Visualize CAA Examples

In [None]:
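from composite_adv.attacks import CompositeAttack
import matplotlib.pyplot as plt
import numpy as np
import torchvision
# Full Attacks; Scheduled Order (same configuration as in Section V)
attack = CompositeAttack(model, dataset="cifar10", enabled_attack=(0,1,2,3,4,5), order_schedule="scheduled")

def imgshow(img):
    # Convert a (C, H, W) CPU tensor into an (H, W, C) array, the layout plt.imshow expects.
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')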
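`imgshow` calls `np.transpose(npimg, (1, 2, 0))` because PyTorch stores images channel-first as (C, H, W), while `plt.imshow` expects channel-last (H, W, C) for RGB data. The pure-Python sketch below performs the same axis permutation on nested lists, purely to illustrate what the transpose does; `chw_to_hwc` is a hypothetical helper, not part of any library.

```python
def chw_to_hwc(img):
    """Reorder a nested-list image from [C][H][W] to [H][W][C]."""
    channels = len(img)
    height = len(img[0])
    width = len(img[0][0])
    # Pixel (h, w) collects its value from every channel c.
    return [
        [[img[c][h][w] for c in range(channels)] for w in range(width)]
        for h in range(height)
    ]


# A 3-channel, 1x2 image: channel c holds the values for both pixels.
print(chw_to_hwc([[[1, 2]], [[3, 4]], [[5, 6]]]))
# [[[1, 3, 5], [2, 4, 6]]]
```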

In [None]:
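inputs, labels = next(iter(data_loader))
# Use the model's device (GPU if it was moved there, otherwise CPU) instead of assuming CUDA.
device = next(model.parameters()).device
ori_images, ori_labels = inputs[:5].to(device), labels[:5].to(device)
adv_images = attack(ori_images, ori_labels)
ori_grid = torchvision.utils.make_grid(ori_images.cpu(), nrow=5, padding=1)
# detach() so that .numpy() in imgshow works even if the attack output tracks gradients.
adv_grid = torchvision.utils.make_grid(adv_images.detach().cpu(), nrow=5, padding=1)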
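CIFAR-10 images are 3×32×32, so with `nrow=5` and `padding=1` the five images form a single row. The sketch below reproduces the layout arithmetic of `torchvision.utils.make_grid` for a batch of 3-channel images to predict the size of `ori_grid` and `adv_grid`. `make_grid_shape` is a hypothetical helper, not a torchvision function, and it ignores make_grid's conversion of 1-channel batches to 3 channels.

```python
import math


def make_grid_shape(num_images, channels, height, width, nrow=8, padding=2):
    """Predict the (C, H, W) shape make_grid produces for an (N, C, H, W) batch.

    Assumes a 3-channel batch; tiles are laid out in rows of `nrow` with
    `padding` pixels between tiles and around the border.
    """
    if num_images < 1:
        raise ValueError("need at least one image")
    cols = min(nrow, num_images)              # tiles per row
    rows = math.ceil(num_images / cols)       # number of rows
    cell_h, cell_w = height + padding, width + padding
    return (channels, cell_h * rows + padding, cell_w * cols + padding)


print(make_grid_shape(5, 3, 32, 32, nrow=5, padding=1))
# (3, 34, 166)
```

So each grid passed to `imgshow` is a single strip of 34×166 pixels.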

In [None]:
imgshow(ori_grid)

In [None]:
imgshow(adv_grid)