# Adversarial Patch Attack & Defence (SAC-style)

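End-to-end notebook: load GTSRB, train a universal targeted adversarial patch, and build a patch-segmentation dataset for a SAC-style defence.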

## 1. Imports & Globals

In [1]:

import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms


## 2. Load Dataset

In [2]:

# Directory where torchvision stores the GTSRB download.
DATA_ROOT = "./data_gtsrb"
# Every image is resized to an IMG_SIZE x IMG_SIZE square.
IMG_SIZE = 224
BATCH_SIZE = 8
NUM_WORKERS = 2
# Use the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Per-channel (RGB) ImageNet statistics used by the Normalize step below.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

# Pipeline: resize -> float tensor in [0, 1] -> per-channel normalisation.
tfms = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
)

# Train split first, then test split (same order as the downloads happen).
train_set, test_set = (
    datasets.GTSRB(DATA_ROOT, split=split_name, download=True, transform=tfms)
    for split_name in ("train", "test")
)


## 3. Patch Trainer

In [3]:

class UniversalTargetedPatchTrainer:
    """Optimise one universal patch that pushes any image towards ``target_class``.

    The patch lives in the same normalised space as the model inputs and is
    kept inside the per-channel range that corresponds to real pixel
    intensities in [0, 1] (derived from ``mean``/``std``). The learned patch
    is therefore always a realisable image.

    Note: ``model`` is used in whatever mode the caller left it in. Call
    ``model.eval()`` beforehand so that dropout is off and BatchNorm running
    statistics are not updated by patched inputs.

    Args:
        model: classifier mapping a (B, C, H, W) batch to (B, num_classes) logits.
        target_class: class index every patched image should be classified as.
        patch_size: side length, in pixels, of the square patch.
        lr: Adam learning rate for the patch.
        init_patch: optional starting patch of shape (C, patch_size, patch_size)
            in the normalised input space. It is detached, copied and clipped
            to the valid range, so passing another trainer's ``patch`` is safe.
            If ``None``, the patch is initialised with ``torch.rand``.
        mean: per-channel normalisation mean; defaults to ``imagenet_mean``.
        std: per-channel normalisation std; defaults to ``imagenet_std``.
    """

    def __init__(self, model, target_class, patch_size, lr, init_patch=None,
                 mean=None, std=None):
        self.model = model
        self.target_class = target_class
        self.patch_size = patch_size
        self.device = DEVICE

        mean_t = torch.as_tensor(
            imagenet_mean if mean is None else mean, device=self.device
        ).view(-1, 1, 1)
        std_t = torch.as_tensor(
            imagenet_std if std is None else std, device=self.device
        ).view(-1, 1, 1)
        # Normalised values of pixel intensity 0 and 1, per channel: (C, 1, 1).
        self.pixel_min = (0.0 - mean_t) / std_t
        self.pixel_max = (1.0 - mean_t) / std_t

        if init_patch is None:
            self.patch = torch.rand(
                mean_t.shape[0], patch_size, patch_size, device=self.device
            )
        else:
            # detach() first: cloning a tensor that requires grad yields a
            # non-leaf tensor, which torch.optim refuses to optimise.
            self.patch = init_patch.detach().to(self.device).clone()
        self._clamp_patch_()
        self.patch.requires_grad_(True)

        self.optimizer = optim.Adam([self.patch], lr=lr)
        self.criterion = nn.CrossEntropyLoss()

    def _clamp_patch_(self):
        """Clip the patch in place to the valid per-channel pixel range."""
        with torch.no_grad():
            self.patch.clamp_(min=self.pixel_min, max=self.pixel_max)

    def apply_patch(self, images, random_pos=True):
        """Return a copy of ``images`` with the patch pasted into every image.

        Args:
            images: (B, C, H, W) batch, expected on the patch's device.
            random_pos: if True (default), each image gets an independent,
                uniformly random location. If False, the patch is centred.

        Returns:
            The patched batch. Only the patch region differs from ``images``;
            clean pixels are left untouched. Gradients flow back to
            ``self.patch``.

        Raises:
            ValueError: if the patch does not fit inside the images.
        """
        B, _, H, W = images.shape
        ps = self.patch_size
        if ps > H or ps > W:
            raise ValueError(f"patch_size={ps} does not fit in {H}x{W} images")

        # Differentiable clip, computed once per batch; it is a no-op while the
        # patch is kept in range by __init__ and train_on_batch.
        patch = torch.clamp(self.patch, min=self.pixel_min, max=self.pixel_max)
        patched = images.clone()
        for i in range(B):
            if random_pos:
                x = random.randint(0, W - ps)
                y = random.randint(0, H - ps)
            else:
                x, y = (W - ps) // 2, (H - ps) // 2
            patched[i, :, y:y + ps, x:x + ps] = patch
        return patched

    def train_on_batch(self, images):
        """Take one Adam step on the patch using a batch of clean images.

        Returns:
            The targeted cross-entropy loss of the batch, computed before the
            step, as a Python float.
        """
        images = images.to(self.device)
        outputs = self.model(self.apply_patch(images))
        targets = torch.full((images.size(0),), self.target_class,
                             device=self.device, dtype=torch.long)
        loss = self.criterion(outputs, targets)
        # Differentiate w.r.t. the patch only: loss.backward() would also
        # accumulate gradients into every model parameter, and nothing here
        # ever clears them.
        (grad,) = torch.autograd.grad(loss, self.patch)
        self.patch.grad = grad
        self.optimizer.step()
        self._clamp_patch_()
        return loss.item()


## 4. Patch Segmentation Dataset

In [4]:

class PatchSegmentationDataset(Dataset):
    """Wrap a dataset so each sample carries a pasted patch and its location mask.

    Every ``__getitem__`` call pastes the patch at a fresh, uniformly random
    location, so repeated reads of the same index give different samples.

    Args:
        base_dataset: indexable dataset returning ``(image, label)`` pairs, where
            ``image`` is a (C, H, W) tensor. The label is discarded.
        patch: (C, patch_size, patch_size) tensor in the same space as the
            images. A detached CPU snapshot is stored, so a trainer's live,
            grad-requiring (possibly CUDA) patch can be passed directly. Later
            optimiser steps on that patch do not affect this dataset.
        patch_size: side length of the square patch, in pixels.
    """

    def __init__(self, base_dataset, patch, patch_size):
        self.base_dataset = base_dataset
        # Without detach(), every sample would require grad, and DataLoader
        # workers refuse to pickle non-leaf tensors that require grad. A CPU
        # copy also avoids touching CUDA inside forked worker processes.
        self.patch = patch.detach().cpu().clone()
        self.patch_size = patch_size

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return ``(patched_image, mask)`` for sample ``idx``.

        ``mask`` is an (H, W) float tensor: 1 inside the pasted patch, 0
        elsewhere. The base image is cloned, so the base dataset is not
        modified. ``random.randint`` raises ValueError if the patch is larger
        than the image.
        """
        img, _ = self.base_dataset[idx]
        img = img.clone()
        _, H, W = img.shape
        ps = self.patch_size
        x = random.randint(0, W - ps)
        y = random.randint(0, H - ps)
        img[:, y:y + ps, x:x + ps] = self.patch
        mask = torch.zeros((H, W))
        mask[y:y + ps, x:x + ps] = 1
        return img, mask


## 5. Visualization Helpers

In [6]:

def show_patch(patch):
    """Plot a channels-first (3, H, W) normalised patch as an RGB image.

    The patch is detached and moved to the CPU. The ImageNet normalisation is
    undone (x * std + mean) and the result is clipped to [0, 1] before it is
    shown without axes.
    """
    std = torch.tensor(imagenet_std)
    mean = torch.tensor(imagenet_mean)
    hwc = patch.detach().cpu().permute(1, 2, 0)  # channels last for imshow
    rgb = torch.clamp(hwc * std + mean, 0, 1)
    plt.imshow(rgb)
    plt.axis('off')
    plt.show()
