<a href="https://colab.research.google.com/github/Shreya-Mendi/XAI/blob/Colab/Adversarial/patch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Adversarial patch notebook — overview

This notebook trains an adversarial patch that, when pasted onto images, causes a target model to predict the *pineapple* class (ImageNet class 953).
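The notebook makes "causing the model to predict pineapple" precise with a hinge-style margin loss (the `margin_loss` function in the training cell). The NumPy sketch below mirrors that loss on a toy batch so its behaviour can be checked by hand. The logits are made-up numbers, and `margin_loss_np` is a hypothetical helper name that is not part of the notebook.

```python
import numpy as np

def margin_loss_np(logits, target_idx, margin=0.0):
    """Mean over the batch of max(0, best non-target logit - target logit + margin)."""
    logits = np.asarray(logits, dtype=np.float64)
    target_logit = logits[:, target_idx]
    # Mask the target column with -inf so it can never be the "best other" class.
    others = logits.copy()
    others[:, target_idx] = -np.inf
    best_other = others.max(axis=1)
    return np.maximum(best_other - target_logit + margin, 0.0).mean()

# Toy batch: 2 images, 4 classes; class 2 plays the role of the target (953 in the notebook).
toy_logits = np.array([
    [1.0, 3.0, 2.0, 0.5],   # target is 1.0 below the best other class -> loss 1.0
    [0.0, 1.0, 4.0, 2.0],   # target already wins by 2.0 -> loss 0.0
])
loss = margin_loss_np(toy_logits, target_idx=2)
print(loss)  # 0.5
```

With the default `margin=0.0`, an image stops contributing to the loss as soon as the target is its top-1 class. A positive margin keeps pushing until the target leads by at least that amount.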

Much of the adversarial-attack code was adapted from the course GitHub repository provided in class: https://github.com/AIPI-590-XAI/Duke-AI-XAI/blob/main/adversarial-ai-example-notebooks/adversarial_attacks_patches.ipynb

High-level structure (each step maps to a section of the code below):

1. Install dependencies (Colab-ready): package installs for torch, torchvision and imaging libraries.
2. Download dataset and pretrained resources: TinyImageNet archive and saved patches/checkpoints.
3. Load model & dataset: load a pretrained ResNet-34, set to eval mode, and prepare a DataLoader for TinyImageNet images.
4. Patch definition & helpers: create the strawberry mask, conversion helpers, and the differentiable placement function `place_patch_batch_tensor` which composites the patch into batches of images while allowing gradients to flow to the patch parameters.
5. Training loop: optimize a small patch (tensor parameter) with a margin loss that encourages the patched images to be classified as the target class.
6. Saving & evaluation: save the final patch image, evaluate targeted success rate on sample batches, and create upload/print-ready canvases.
7. Visualization: show top-5 predictions for patched examples to inspect behavior.
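Step 4's compositing can be sketched for a single image in NumPy. The patch and its mask are resized to a random scale, an offset is chosen so the patch stays fully inside the image, and the result is blended as `patch * alpha + background * (1 - alpha)`. This mirrors the per-image logic of `place_patch_batch_tensor`, but it is only an illustration with simplifying assumptions: it uses nearest-neighbour resizing instead of `F.interpolate`, skips the random rotation, and is not differentiable. The helper names `resize_nearest` and `place_patch_np` are hypothetical.

```python
import numpy as np

def resize_nearest(arr, new_size):
    """Nearest-neighbour resize of a (C, P, P) array to (C, new_size, new_size)."""
    p = arr.shape[1]
    idx = (np.arange(new_size) * p) // new_size  # source index for each output row/column
    return arr[:, idx][:, :, idx]

def place_patch_np(img, patch, alpha, rng, min_scale=0.7, max_scale=1.0):
    """Alpha-composite `patch` onto `img` at a random scale and an in-bounds position.

    img: (C, H, W); patch: (C, P, P); alpha: (1, P, P) with values in [0, 1].
    """
    _, h, w = img.shape
    p = patch.shape[1]
    new_p = max(1, int(p * rng.uniform(min_scale, max_scale)))
    patch_r = resize_nearest(patch, new_p)
    alpha_r = np.clip(resize_nearest(alpha, new_p), 0.0, 1.0)
    # Offsets are drawn so the whole resized patch stays inside the image.
    x = rng.integers(0, max(0, w - new_p) + 1)
    y = rng.integers(0, max(0, h - new_p) + 1)
    out = img.copy()
    region = out[:, y:y + new_p, x:x + new_p]
    out[:, y:y + new_p, x:x + new_p] = patch_r * alpha_r + region * (1.0 - alpha_r)
    return out

rng = np.random.default_rng(0)
img = np.zeros((3, 64, 64))
patch = np.ones((3, 16, 16))
alpha = np.ones((1, 16, 16))
patched = place_patch_np(img, patch, alpha, rng)
print(patched[0].sum())  # equals new_p**2 for this all-ones patch and fully opaque mask
```

The notebook writes the same blend with PyTorch ops (`F.interpolate`, `F.grid_sample`, and tensor arithmetic). That is what lets gradients flow back to the patch parameters.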

Notes and cautions:

- Everything is kept in a single cell. When the code was split across cells, runs were interrupted mid-training and the kernel crashed in VS Code; training only ran to completion when everything was in one cell.
- The training loop can be compute-intensive; run on a GPU for reasonable runtimes.
- The patch's strength depends on the number of epochs, the learning rate, and the patch size; tune these with care.
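To make these knobs concrete, here is the arithmetic for the settings this notebook uses: 5,000 training images as reported in the output, `BATCH_SIZE = 128`, `PATCH_EPOCHS = 8`, and `PATCH_SIZE = 96`. The 224x224 input size is an assumption taken from the canvas cell.

```python
# Back-of-the-envelope numbers for the settings used in this notebook.
# Assumption: 224x224 input images (the size the canvas cell also uses).
dataset_len = 5000      # "Dataset length: 5000" in the training output
batch_size = 128        # BATCH_SIZE
patch_epochs = 8        # PATCH_EPOCHS
patch_size = 96         # PATCH_SIZE
image_size = 224

steps_per_epoch = dataset_len // batch_size   # drop_last=True discards the remainder
total_steps = steps_per_epoch * patch_epochs  # optimizer updates over the whole run

# Fraction of the image covered by the patch's bounding square at the training
# scales 0.7 and 1.0 (the strawberry mask itself covers only part of that square).
cover_min = int(patch_size * 0.7) ** 2 / image_size ** 2
cover_max = patch_size ** 2 / image_size ** 2

print(steps_per_epoch, total_steps)           # 39 312
print(f"{cover_min:.1%} to {cover_max:.1%}")  # 8.9% to 18.4%
```

Raising `PATCH_EPOCHS` scales the number of optimizer steps linearly. Raising `PATCH_SIZE` grows the covered area quadratically.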


In [None]:
# Strawberry -> Pineapple Adversarial Patch

# This creates a new patch trained on the same TinyImageNet subset the course uses and saves
# a strawberry-shaped patch that targets ImageNet class 953 (pineapple).


# 1) Install dependencies

# !pip install -q --upgrade pip
# !pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu118
# !pip install -q pillow numpy tqdm matplotlib requests scikit-image opencv-python-headless


# 2) Download TinyImageNet & pretrained resources (same as tutorial in the course github)

import os
import zipfile
from urllib.error import HTTPError
import urllib.request
from tqdm.notebook import tqdm

DATASET_PATH = "../data"
CHECKPOINT_PATH = "../saved_models/tutorial10"
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial10/"
os.makedirs(DATASET_PATH, exist_ok=True)
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

files_to_get = [(DATASET_PATH, "TinyImageNet.zip"), (CHECKPOINT_PATH, "patches.zip")]
for d, fname in files_to_get:
    fpath = os.path.join(d, fname)
    if not os.path.isfile(fpath):
        url = base_url + fname
        print("Downloading", url)
        try:
            urllib.request.urlretrieve(url, fpath)
        except HTTPError as e:
            print("Download failed:", e)
        if fname.endswith('.zip') and os.path.isfile(fpath):
            print("Unzipping", fpath)
            with zipfile.ZipFile(fpath, 'r') as z:
                z.extractall(d)

# 3) Load model, dataset, and helpers

import json
import torch
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from pathlib import Path

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# Load ResNet34 pretrained
model = torchvision.models.resnet34(weights='IMAGENET1K_V1')
model = model.to(device)
model.eval()
for p in model.parameters(): p.requires_grad = False

# ImageNet normalization
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]
plain_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)
])

# Load TinyImageNet dataset
imagenet_path = os.path.join(DATASET_PATH, 'TinyImageNet')
if not os.path.isdir(imagenet_path):
    raise FileNotFoundError(f"TinyImageNet not found at {imagenet_path}. Check downloads and rerun.")

dataset = ImageFolder(root=imagenet_path, transform=plain_transforms)
print('Dataset length:', len(dataset))

# DataLoader for training
BATCH_SIZE = 128
train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, drop_last=True)

# Load label names if available in TinyImageNet folder
label_names_path = os.path.join(imagenet_path, 'label_list.json')
if os.path.isfile(label_names_path):
    with open(label_names_path, 'r') as f:
        label_names = json.load(f)
else:
    # fallback: standard ImageNet class list from the pytorch/hub repository
    # (plain urllib instead of the IPython-only `!wget`, so this also runs as a script)
    labels_txt = 'imagenet_classes.txt'
    if not os.path.isfile(labels_txt):
        urllib.request.urlretrieve(
            'https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt', labels_txt)
    with open(labels_txt, 'r') as f:
        label_names = [l.strip() for l in f.readlines()]
print('Loaded', len(label_names), 'labels')


# 4) Training: strawberry-shaped adversarial patch targeting pineapple (953)

# Adjust PATCH_EPOCHS and BATCH_SIZE as needed if computationally heavy

import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image, ImageDraw, ImageFilter
from tqdm.notebook import tqdm

# USER PARAMS
TARGET_CLASS = 953   # pineapple
PATCH_SIZE = 96      # 96x96 strong printable patch
PATCH_EPOCHS = 8     # increase for stronger patch
LR = 0.03
CLIP_NORM = 0.1

# strawberry mask
def make_strawberry_mask(sz):
    w,h = sz,sz
    mask = Image.new('L', (w,h), 0)
    d = ImageDraw.Draw(mask)
    d.ellipse([w*0.12, h*0.28, w*0.88, h*0.92], fill=255)
    d.polygon([(w*0.5,h*0.05),(w*0.12,h*0.35),(w*0.88,h*0.35)], fill=255)
    return mask.filter(ImageFilter.GaussianBlur(radius=1))

mask_pil = make_strawberry_mask(PATCH_SIZE)
mask_alpha = transforms.ToTensor()(mask_pil).to(device)

TENSOR_MEANS = torch.tensor(NORM_MEAN, device=device)[:,None,None]
TENSOR_STD = torch.tensor(NORM_STD, device=device)[:,None,None]

# mapping functions
def patch_forward_norm(p):
    return (torch.tanh(p) + 1.0 - 2.0 * TENSOR_MEANS) / (2.0 * TENSOR_STD)

def patch_pixels(p):
    return (torch.tanh(p).detach().cpu().numpy().transpose(1,2,0) + 1.0) / 2.0


# Differentiable place_patch_batch_tensor using affine transforms on GPU
import math

def place_patch_batch_tensor(imgs_norm, patch_param, mask_alpha, min_scale=0.6, max_scale=1.0, max_angle=12):
    """
    Differentiable placement of patch into a batch of normalized images.
    Uses F.interpolate and F.affine_grid + F.grid_sample so gradients flow to patch_param.
    imgs_norm: (N,3,H,W) normalized tensors on device
    patch_param: (3,P,P) parameter on device (requires_grad=True)
    mask_alpha: (1,P,P) mask tensor on device (values 0..1)
    """
    device = imgs_norm.device
    N, C, H, W = imgs_norm.shape
    out = imgs_norm.clone()
    P = patch_param.shape[1]

    # normalized patch in model input space (differentiable)
    patch_norm_full = patch_forward_norm(patch_param)  # (3,P,P) on device
    alpha_full = mask_alpha  # (1,P,P) on device

    for i in range(N):
        scale = float(np.random.uniform(min_scale, max_scale))
        angle = float(np.random.uniform(-max_angle, max_angle))
        newP = max(1, int(P * scale))

        # crop/resize the patch and alpha to newP via interpolate (diff)
        patch_resized = F.interpolate(patch_norm_full.unsqueeze(0), size=(newP, newP),
                                     mode='bilinear', align_corners=False).squeeze(0)  # (3,newP,newP)
        alpha_resized = F.interpolate(alpha_full.unsqueeze(0), size=(newP, newP),
                                      mode='bilinear', align_corners=False).squeeze(0)  # (1,newP,newP)

        # If angle != 0, rotate using affine grid (differentiable)
        if abs(angle) > 1e-3:
            # Build affine transform matrix for rotation about center.
            theta = math.radians(angle)
            cos_t = math.cos(theta)
            sin_t = math.sin(theta)
            # affine_grid maps each output location to the input location it samples from,
            # so this matrix rotates the patch content by -angle rather than +angle; since
            # angle is drawn symmetrically around 0, the direction does not matter here.
            M = torch.tensor([[cos_t, -sin_t, 0.0],
                              [sin_t,  cos_t, 0.0]], dtype=torch.float32, device=device)  # (2,3)

            # grid_sample uses normalized coords; we want to rotate patch about its center.
            # Create grid for the small patch size
            grid = F.affine_grid(M.unsqueeze(0), size=(1, C, newP, newP), align_corners=False)  # (1,newP,newP,2)
            # apply to patch and alpha
            patch_resized = F.grid_sample(patch_resized.unsqueeze(0), grid, mode='bilinear', padding_mode='zeros', align_corners=False).squeeze(0)
            alpha_resized = F.grid_sample(alpha_resized.unsqueeze(0), grid, mode='bilinear', padding_mode='zeros', align_corners=False).squeeze(0)

        # choose position ensuring full in-bounds placement
        max_x = max(0, W - newP)
        max_y = max(0, H - newP)
        x = np.random.randint(0, max_x+1) if max_x>0 else 0
        y = np.random.randint(0, max_y+1) if max_y>0 else 0
        x1, y1, x2, y2 = x, y, x + newP, y + newP

        # region in out
        out_region = out[i, :, y1:y2, x1:x2]
        # sanity check sizes
        if out_region.shape[1] != newP or out_region.shape[2] != newP:
            continue

        # composite: out = patch * alpha + out_region * (1-alpha)
        alpha_resized = alpha_resized.clamp(0.0, 1.0)  # (1,newP,newP)
        out[i, :, y1:y2, x1:x2] = patch_resized * alpha_resized + out_region * (1.0 - alpha_resized)

    return out

# optimizer & patch
patch_param = nn.Parameter(torch.zeros(3, PATCH_SIZE, PATCH_SIZE, device=device))
with torch.no_grad():
    patch_param.normal_(mean=0.0, std=0.25)
optimizer = torch.optim.Adam([patch_param], lr=LR)

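# margin loss: mean of max(0, best non-target logit - target logit + margin)
def margin_loss(logits, target_idx, margin=0.0):
    # cast to float32 (the logits are float16 under CUDA autocast) and mask the
    # target with -inf, which is valid in any floating-point dtype
    logits = logits.float()
    tlog = logits[:, target_idx]
    other = logits.clone()
    other[:, target_idx] = float('-inf')  # exclude the target from the "best other" max
    max_other, _ = other.max(dim=1)
    return torch.clamp(max_other - tlog + margin, min=0.0).mean()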

# try torch.compile for speed (compilation itself happens lazily on the first forward pass)
try:
    model = torch.compile(model)
    print('Compiled model')
except Exception:
    pass

# training loop
scaler = torch.amp.GradScaler('cuda', enabled=(device.type=='cuda'))
print('Training patch...')
for epoch in range(PATCH_EPOCHS):
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{PATCH_EPOCHS}')
    for imgs, labels in pbar:
        imgs = imgs.to(device)
        patched = place_patch_batch_tensor(imgs.clone(), patch_param, mask_alpha, min_scale=0.7, max_scale=1.0, max_angle=12)
        target_labels = torch.full((patched.size(0),), TARGET_CLASS, dtype=torch.long, device=device)
        optimizer.zero_grad()
        with torch.autocast(device_type=device.type, enabled=(device.type=='cuda')):
            logits = model(patched)
            loss = margin_loss(logits, TARGET_CLASS)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_([patch_param], CLIP_NORM)
        scaler.step(optimizer)
        scaler.update()
        with torch.no_grad():
            patch_param.clamp_(-3.0, 3.0)
        with torch.no_grad():
            preds = logits.argmax(dim=1)
            batch_rate = (preds == TARGET_CLASS).float().mean().item()
        pbar.set_postfix(loss=float(loss.item()), batch_rate=float(batch_rate))

# save patch image
final_pixels = patch_pixels(patch_param)
pil_patch = Image.fromarray((final_pixels*255).astype(np.uint8))
alpha = mask_pil.resize(pil_patch.size, Image.BILINEAR)
pil_patch.putalpha(alpha)
outname = f'strawberry_patch_pineapple_{PATCH_SIZE}px.png'
pil_patch.save(outname)
print('Saved patch to', outname)

# quick eval function

def eval_patch_on_loader(patch_param, loader, n_batches=50, trials_per_image=4):
    total=0; successes=0
    with torch.no_grad():
        for i,(imgs,labels) in enumerate(loader):
            imgs = imgs.to(device)
            for _ in range(trials_per_image):
                patched = place_patch_batch_tensor(imgs.clone(), patch_param, mask_alpha, min_scale=0.75, max_scale=1.0, max_angle=0)
                logits = model(patched)
                preds = logits.argmax(dim=1)
                mask_non_target = (labels.to(device) != TARGET_CLASS)
                successes += ((preds == TARGET_CLASS) & mask_non_target).sum().item()
                total += mask_non_target.sum().item()
            if i >= n_batches-1: break
    return successes / max(1, total)

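# note: this reuses train_loader, i.e. the images the patch was optimized on,
# so the printed rate is an in-sample estimate
print('Estimating success rate...')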
est = eval_patch_on_loader(patch_param, train_loader, n_batches=50, trials_per_image=4)
print(f'Estimated targeted success: {est*100:.2f}%')


Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial10/TinyImageNet.zip
Unzipping ../data/TinyImageNet.zip
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial10/patches.zip
Unzipping ../saved_models/tutorial10/patches.zip
Device: cpu
Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


100%|██████████| 83.3M/83.3M [00:00<00:00, 166MB/s]


Dataset length: 5000
Loaded 1000 labels
Compiled model
Training patch...


  scaler = torch.cuda.amp.GradScaler(enabled=(device.type=='cuda'))


Epoch 1/8:   0%|          | 0/39 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast(enabled=(device.type=='cuda')):
W1101 23:00:15.762000 393 torch/utils/cpp_extension.py:118] [0/0] No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
  with torch.cuda.amp.autocast(enabled=(device.type=='cuda')):
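The patch is stored as an unconstrained tensor and passed through tanh, so every pixel stays in [0, 1] no matter how far the optimizer moves the parameters. The NumPy sketch below mirrors the two mapping functions from the training cell (`patch_forward_norm` and `patch_pixels`). It checks that the single-expression normalization equals "convert to [0, 1] pixels, then apply the ImageNet mean/std". The `_np` names are hypothetical, and the random parameters are arbitrary test values.

```python
import numpy as np

# ImageNet statistics, shaped (3, 1, 1) to broadcast over a (3, P, P) patch.
NORM_MEAN = np.array([0.485, 0.456, 0.406])[:, None, None]
NORM_STD = np.array([0.229, 0.224, 0.225])[:, None, None]

def patch_pixels_np(p):
    """Unconstrained (3, P, P) parameters -> pixel values in [0, 1]."""
    return (np.tanh(p) + 1.0) / 2.0

def patch_forward_norm_np(p):
    """Same expression as patch_forward_norm: pixels already ImageNet-normalized."""
    return (np.tanh(p) + 1.0 - 2.0 * NORM_MEAN) / (2.0 * NORM_STD)

rng = np.random.default_rng(0)
param = rng.normal(scale=3.0, size=(3, 8, 8))  # even large values stay bounded
pixels = patch_pixels_np(param)
two_step = (pixels - NORM_MEAN) / NORM_STD      # normalize the [0, 1] pixels explicitly

print(pixels.min() >= 0.0, pixels.max() <= 1.0)                 # True True
print(np.allclose(patch_forward_norm_np(param), two_step))      # True
```

The training loop also clamps the raw parameters to [-3, 3], so in practice each pixel is limited to about (tanh(-3) + 1) / 2 ≈ 0.0025 through 0.9975.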


In [None]:

from PIL import Image
import os

# original saved patch
orig_patch_path = outname  # 'strawberry_patch_pineapple_96px.png' from your script
assert os.path.isfile(orig_patch_path), f"Patch not found: {orig_patch_path}"

# load original patch
orig = Image.open(orig_patch_path).convert("RGBA")

# create 224x224 canvas and paste original patch at native size (centered)
CANVAS_SIZE = 224
canvas_white = Image.new("RGBA", (CANVAS_SIZE, CANVAS_SIZE), (255,255,255,255))
x = (CANVAS_SIZE - orig.width) // 2
y = (CANVAS_SIZE - orig.height) // 2
canvas_white.paste(orig, (x,y), mask=orig)  # preserves the patch pixels, no scaling

# convert to RGB (drops alpha) and save clean PNG and high-quality JPEG variants
upload_png = "patch_on_224_white.png"
upload_jpg_q95 = "patch_on_224_white_q95.jpg"

canvas_white_rgb = canvas_white.convert("RGB")
canvas_white_rgb.save(upload_png, format="PNG")
canvas_white_rgb.save(upload_jpg_q95, format="JPEG", quality=95)

print("Saved upload-safe variants (no resizing of patch):")
print(" -", os.path.abspath(upload_png))
print(" -", os.path.abspath(upload_jpg_q95))

canvas_black = Image.new("RGBA", (CANVAS_SIZE, CANVAS_SIZE), (0,0,0,255))
canvas_black.paste(orig, (x,y), mask=orig)
canvas_black.convert("RGB").save("patch_on_224_black_q95.jpg", format="JPEG", quality=95)
print(" -", os.path.abspath("patch_on_224_black_q95.jpg"))


# Print-ready version: a 3 in x 3 in canvas at 300 dpi
print_size_in = 3.0   # inches
dpi = 300
out_px = int(print_size_in * dpi)
# create a large canvas and paste an upscaled copy of the patch so the print is crisp
big_canvas = Image.new("RGBA", (out_px, out_px), (255,255,255,255))
# scale by out_px / 224 so the patch covers the same fraction of the print as it does
# of the 224x224 canvas (it does not fill the print); skip the resize for native pixels only
scale = out_px / CANVAS_SIZE
patched_resized = orig.resize((int(orig.width*scale), int(orig.height*scale)), Image.LANCZOS)
bx = (out_px - patched_resized.width)//2
by = (out_px - patched_resized.height)//2
big_canvas.paste(patched_resized, (bx,by), mask=patched_resized)
big_out = "patch_print_3in_300dpi.png"
big_canvas.convert("RGB").save(big_out, dpi=(dpi,dpi))
print("Saved print-ready PNG:", os.path.abspath(big_out))


Saved upload-safe variants (no resizing of patch):
 - /Users/shreyamendi/XAI/Adversarial/patch_on_224_white.png
 - /Users/shreyamendi/XAI/Adversarial/patch_on_224_white_q95.jpg
 - /Users/shreyamendi/XAI/Adversarial/patch_on_224_black_q95.jpg
Saved print-ready PNG: /Users/shreyamendi/XAI/Adversarial/patch_print_3in_300dpi.png
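The print canvas keeps the patch at the same relative size it has on the 224x224 upload canvas, so the patch does not fill the print. Here is the arithmetic with the cell's default values (3 inches, 300 dpi) and the 96-pixel patch:

```python
# Print-size arithmetic used by the canvas cell (default values from this notebook).
print_size_in = 3.0
dpi = 300
canvas_size = 224   # CANVAS_SIZE
patch_px = 96       # PATCH_SIZE

out_px = int(print_size_in * dpi)            # side of the print canvas in pixels
scale = out_px / canvas_size                 # same factor the cell uses
printed_patch_px = int(patch_px * scale)     # side of the resized patch
printed_patch_in = printed_patch_px / dpi    # physical size on paper
offset = (out_px - printed_patch_px) // 2    # centering offset (bx == by)

print(out_px, printed_patch_px, offset)      # 900 385 257
print(f"{printed_patch_in:.2f} in")          # 1.28 in
```

A 3-inch print therefore carries a patch about 1.28 inches across. Scaling by `out_px / patch_px` instead would make the patch span the full 3 inches.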


In [None]:
# show top-5 preds for some examples so we know what's happening
import torch, numpy as np
from torchvision import transforms

def show_top5_examples(patch_param, loader, n=5):
    model.eval()
    cnt = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs = imgs[:8].to(device)  # slice first so only 8 images are copied to the device
            patched = place_patch_batch_tensor(imgs.clone(), patch_param, mask_alpha, min_scale=0.8, max_scale=1.0, max_angle=5)
            logits = model(patched)
            probs = torch.softmax(logits, dim=1)
            top5 = probs.topk(5, dim=1)
            for i in range(patched.shape[0]):
                print("Example", cnt+i)
                print("Top-5 indices:", top5.indices[i].cpu().tolist())
                print("Top-5 names:", [label_names[idx] for idx in top5.indices[i].cpu().tolist()])
                print("Top-5 probs:", top5.values[i].cpu().numpy())
                if cnt + i >= n - 1: break
            cnt += patched.shape[0]
            if cnt >= n: break

show_top5_examples(patch_param, train_loader, n=6)


Example 0
Top-5 indices: [953, 956, 363, 87, 954]
Top-5 names: ['pineapple', 'custard apple', 'armadillo', 'African grey', 'banana']
Top-5 probs: [9.9888402e-01 2.5211522e-04 2.0851233e-04 1.2242001e-04 9.8147786e-05]
Example 1
Top-5 indices: [953, 956, 131, 135, 954]
Top-5 names: ['pineapple', 'custard apple', 'little blue heron', 'limpkin', 'banana']
Top-5 probs: [9.9913824e-01 2.5549458e-04 1.8953078e-04 1.1735251e-04 8.2710525e-05]
Example 2
Top-5 indices: [953, 956, 411, 721, 529]
Top-5 names: ['pineapple', 'custard apple', 'apron', 'pillow', 'diaper']
Top-5 probs: [0.929177   0.02895793 0.01474138 0.00280569 0.00250522]
Example 3
Top-5 indices: [953, 954, 998, 987, 956]
Top-5 names: ['pineapple', 'banana', 'ear', 'corn', 'custard apple']
Top-5 probs: [9.9997890e-01 8.3316745e-06 3.1029329e-06 3.0556976e-06 2.6305511e-06]
Example 4
Top-5 indices: [953, 663, 698, 483, 956]
Top-5 names: ['pineapple', 'monastery', 'palace', 'castle', 'custard apple']
Top-5 probs: [0.993385   0.001639
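The top-5 listing is a qualitative check. The number printed by `eval_patch_on_loader` is a targeted success rate. For several random placements per image, it counts the patched images predicted as the target and skips images whose true label is already the target. The NumPy sketch below reproduces that bookkeeping on made-up predictions; `targeted_success_rate` is a hypothetical helper name.

```python
import numpy as np

TARGET_CLASS = 953  # pineapple

def targeted_success_rate(pred_trials, labels, target=TARGET_CLASS):
    """pred_trials: (T, N) predicted classes for T random placements of N images.

    Returns the fraction of (trial, image) pairs predicted as `target`,
    counting only images whose true label is not already `target`.
    """
    pred_trials = np.asarray(pred_trials)
    labels = np.asarray(labels)
    non_target = labels != target                # (N,) images eligible for counting
    hits = (pred_trials == target) & non_target  # broadcasts over the T trials
    total = non_target.sum() * pred_trials.shape[0]
    return hits.sum() / max(1, total)

# 2 placements x 4 images; the last image's true label is the target, so it is ignored.
labels = [1, 2, 3, TARGET_CLASS]
preds = [[953, 953, 7, 953],
         [953, 5, 953, 953]]
print(targeted_success_rate(preds, labels))  # 0.6666666666666666
```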



**Acknowledgement / citation**

Portions of the notebook code, including the training and image-conversion steps, were generated with the assistance of ChatGPT at 6:30 pm on Nov 1.
