# Steal Now and Attack Later

This notebook demonstrates how to use the Adversarial Robustness Toolbox (ART) to launch the Steal Now and Attack Later (SNAL) attack [1].

The core idea of this attack is to first collect objects detected by an arbitrary model, and then, in a second step, append valid patches to the target image while weakening the impact of unimportant pixels.


[1] Steal Now and Attack Later: Evaluating Robustness of Object Detection against Black-box Adversarial Attacks (https://arxiv.org/abs/2404.15881)

In [None]:
import logging
from typing import Any
import sys

import numpy as np
import torch
import torchvision

# Module/notebook-level logger that prints INFO and above to stdout.
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)
# Guard against re-execution (e.g. re-running this notebook cell): without it
# every run attaches another handler and each message is printed repeatedly.
if not logger.handlers:
    logger.addHandler(logging.StreamHandler(sys.stdout))

In [None]:
#%% Download the target image from the MS COCO dataset
from io import BytesIO

import requests
from PIL import Image
def load_image(img_path, img_size=(640, 640), unsqueeze=True):
    """Load an image as an RGB float32 tensor scaled to ``[0, 1]``.

    Args:
        img_path: Anything accepted by ``PIL.Image.open`` (path or file object).
        img_size: Target ``(height, width)``. PIL's ``resize`` expects
            ``(width, height)``, so the order is swapped below.
        unsqueeze: If True, prepend a batch dimension.

    Returns:
        Tensor of shape ``(1, 3, H, W)`` if ``unsqueeze`` else ``(3, H, W)``.
    """
    # The context manager releases the underlying file handle as soon as the
    # pixels are decoded; ``convert`` returns a new image that outlives it.
    with Image.open(img_path) as src:
        img = src.convert("RGB")
    img = img.resize((img_size[1], img_size[0]))

    # HWC uint8 -> contiguous CHW float32 in [0, 1]
    chw = np.ascontiguousarray(np.array(img).transpose((2, 0, 1)))
    img_tensor = torch.from_numpy(chw).to(torch.float32) / 255.0

    if unsqueeze:
        img_tensor = img_tensor.unsqueeze(0)
    return img_tensor

def save_img(img_tensor, file_name):
    """Save a single ``(C, H, W)`` image tensor to ``file_name``.

    A leading batch dimension of size 1 (e.g. the output of ``load_image``
    with ``unsqueeze=True``) is squeezed away automatically.

    Raises:
        NotImplementedError: if ``img_tensor`` is 4-D with a batch size other
            than 1 (saving several images at once is not supported).
    """
    if len(img_tensor.shape) == 4:
        if img_tensor.shape[0] != 1:
            raise NotImplementedError(
                "save_img expects a (C, H, W) tensor or a batch of size 1, "
                f"got shape {tuple(img_tensor.shape)}"
            )
        img_tensor = img_tensor[0]

    img = torchvision.transforms.ToPILImage()(img_tensor)
    img.save(file_name)

TARGET = 'https://farm2.staticflickr.com/1065/705706084_39a7f28fc9_z.jpg' # val2017/000000552842.jpg

# Fail fast on network problems instead of hanging forever or trying to decode
# an HTTP error page as an image.
response = requests.get(TARGET, timeout=60)
response.raise_for_status()

# Reuse load_image so the target gets exactly the same RGB conversion, resize
# and [0, 1] scaling as every other image in this notebook (also copes with
# non-RGB sources, which the inline transpose could not).
img_pt = load_image(BytesIO(response.content), img_size=(640, 640), unsqueeze=False)  # (3, 640, 640)
save_img(img_pt, 'target.png')
x_pt = load_image('target.png')  # (1, 3, 640, 640), re-read from the saved PNG

In [None]:
#%% Download YOLO model
# If ultralytics is not found, please run the command: `pip install ultralytics`
from ultralytics import YOLO

class TYOLOv8:
    """Query-counting wrapper around an Ultralytics YOLO detector.

    Every batch sent through :meth:`forwad` / :meth:`inference` and every file
    evaluated with :meth:`eval_img` is added to ``bcount``, so the number of
    black-box queries issued by an attack can be reported with :meth:`print`.

    Args:
        model_name: Model name / weights path passed to ``YOLO``.
        output_folder: Initial output folder. NOTE: it is overwritten by
            :meth:`create_folder` with the predictor's ``save_dir``.
        img_dim: Side length of the square dummy image used in
            :meth:`create_folder`.
    """

    def __init__(self,
                 model_name: str,
                 output_folder: str,
                 img_dim: int) -> None:
        self.bcount = 0  # number of images queried so far
        self.img_dim = img_dim

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.amp = False  # mixed-precision autocast is disabled
        self.output_folder = output_folder

        self.model_name = model_name
        self.set_model(model_name)
        self.create_folder()

    def set_model(self, model_name) -> None:
        """(Re)load the underlying ``YOLO`` model."""
        self.model = YOLO(model_name)

    def print(self):
        """Log the model name and the total number of queries."""
        logger.info(f"model name: {self.model_name}")
        logger.info(f"total query: {self.bcount}")

    def count_reset(self) -> None:
        """Reset the query counter to zero."""
        self.bcount = 0

    def count(self, x: torch.Tensor) -> None:
        """Add the batch size (``x.shape[0]``) to the query counter."""
        self.bcount = self.bcount + x.shape[0]

    def create_folder(self) -> None:
        """Run one saved prediction on an all-zero image so Ultralytics creates
        its output directory, then store that directory in ``output_folder``.

        This dummy prediction is intentionally not counted in ``bcount``.
        """
        dummy_img = torch.zeros([1, 3, self.img_dim, self.img_dim], device=self.device)
        self.model(dummy_img, save=True)
        self.output_folder = self.model.predictor.save_dir

    def transform(self, x: torch.Tensor) -> torch.Tensor:
        """Pre-processing hook applied before the model; identity here."""
        return x

    def inference(self, x: torch.Tensor) -> list:
        """Return the ``boxes.xyxy`` tensor of each per-image result for ``x``."""
        return [result.boxes.xyxy for result in self.forwad(x)]

    def forwad(self, x: torch.Tensor) -> Any:
        """Run the model on ``x`` without gradients and count the queries.

        The misspelled name is kept for backward compatibility; :meth:`forward`
        is an alias.
        """
        self.count(x)
        # torch.autocast replaces the deprecated torch.cuda.amp.autocast and
        # picks the autocast backend from the device actually in use.
        with torch.no_grad(), torch.autocast(
            device_type=self.device.type, enabled=self.amp
        ):
            output = self.model(self.transform(x))
        return output

    def forward(self, x: torch.Tensor) -> Any:
        """Correctly spelled alias of :meth:`forwad`."""
        return self.forwad(x)

    def eval_img(self, f_name):
        """Detect objects in the image file ``f_name``, save the annotated
        result, log and return the number of detected boxes (one query)."""
        self.bcount = self.bcount + 1
        results = self.model(f_name, save=True)
        num_boxes = results[0].boxes.xyxy.shape[0]
        logger.info(f'*** total boxes in the eval_img: {num_boxes} ***')
        return num_boxes

# Build the query-counting detector and report how many objects it finds on
# the clean (unattacked) target image.
model = TYOLOv8(model_name='yolov8m', output_folder='./', img_dim=640)
model.eval_img('target.png')

In [None]:
#%% Collect patches from a set of images

import glob

from torchvision import transforms
from torchvision.datasets.vision import VisionDataset

from art.attacks.evasion import collect_patches_from_images
class CustomDatasetFolder(VisionDataset):
    """Unlabelled dataset over the image files directly inside ``root``.

    Args:
        root: Directory to scan (non-recursive).
        transform: Optional callable applied to each loaded PIL image.
        extensions: File extensions to include, matched case-insensitively.
            Defaults to ``(".jpg",)``.
    """

    def __init__(self, root, transform=None, extensions=(".jpg",)):
        super().__init__(root)
        self.transform = transform
        wanted = tuple(ext.lower() for ext in extensions)
        # Sorted because glob order is arbitrary; this keeps indexing (and
        # un-shuffled iteration) reproducible across file systems.
        self.samples = sorted(
            path for path in glob.glob(f"{root}/*")
            if path.lower().endswith(wanted)
        )

    def __getitem__(self, index):
        """Load sample ``index`` as RGB and apply ``transform`` if set."""
        sample = self._loader(self.samples[index])
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        return len(self.samples)

    def _loader(self, path):
        """Open ``path`` as an RGB PIL image, releasing the file handle."""
        with Image.open(path) as img:
            return img.convert("RGB")

ROOT_MSCOCO = '/dataset/val2017/'
# Random crops, AutoAugment and flips presumably diversify the harvested
# patches across images.
img_dataset = CustomDatasetFolder(
    ROOT_MSCOCO,
    transforms.Compose([
        transforms.RandomResizedCrop((640, 640)),
        transforms.AutoAugment(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]),
)
img_loader = torch.utils.data.DataLoader(img_dataset, batch_size=1, shuffle=True)

candidates_list = []
TILE_SIZE = 64  # NOTE(review): not used in this cell — confirm whether it is still needed
MAX_IMGS = 1000  # upper bound on the number of images patches are collected from
for img_count, x in enumerate(img_loader):
    # Check before processing so that exactly MAX_IMGS images are used.
    if img_count >= MAX_IMGS:
        break

    candidates, _ = collect_patches_from_images(model, x)
    print(f'Number of objects are detected: {len(candidates[0])}')
    # In-place extend avoids re-copying the whole accumulated list each step.
    candidates_list.extend(candidates[0])

print(len(candidates_list))

In [None]:
from art.attacks.evasion import SNAL

attack = SNAL(model,
              eps=16.0 / 255.0,
              max_iter=400,
              num_grid=10)
attack.set_candidates(candidates_list)
x_adv = attack.generate(img_pt[None, :].numpy())  # add a batch dim of 1

# First adversarial image: CHW -> HWC, then float (presumably in [0, 1], like
# the input) -> uint8. Clip so out-of-range values cannot wrap around in the
# uint8 cast, and round instead of truncating so the perturbation is not
# systematically biased towards darker values.
adv_np = np.transpose(x_adv[0], (1, 2, 0))
adv_np = np.clip(np.rint(adv_np * 255.0), 0, 255).astype(np.uint8)

adv_path = f'{model.output_folder}/output.png'
Image.fromarray(adv_np).save(adv_path)
model.eval_img(adv_path)
model.print()