In [4]:
import os
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image, draw_bounding_boxes
from torchvision.ops import box_convert, nms
from torch.nn.functional import interpolate
from PIL import Image
import numpy as np
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
from inria_dataloader import get_inria_dataloader
from tmm import load_detr_r50, NestedTensor

# -----------------------
# Config
# -----------------------
ROOT = "/opt/data/private/BlackBox"
SAVE_DIR = os.path.join(ROOT, "save", "attack")
os.makedirs(SAVE_DIR, exist_ok=True)

# Patch params
PATCH_SIDE = 300        # reference side length (pixels) used when resizing the patch
MIN_PATCH_PX = 16       # lower bound on the pasted patch side
SCORE_THRESH = 0.5      # detections with a person score above this get a patch
IOU_NMS_THRESH = 0.5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model input dimensions (used to scale normalized boxes to pixels)
MODEL_INPUT_H, MODEL_INPUT_W = 640, 640

# Load trained patch
PATCH_PATH = os.path.join(ROOT, "save", "demo", "final_patch.pt")  # Change path as needed
# Load onto the CPU first so a patch saved from a CUDA run can still be read
# on a CPU-only machine; it is moved to DEVICE right after.
# NOTE(review): torch.load unpickles the file - only load patch files you trust.
patch = torch.load(PATCH_PATH, map_location="cpu")  # [3,H,W] or [1,3,H,W]
if patch.ndimension() == 3:
    patch = patch.unsqueeze(0)  # add the batch dimension -> [1,3,H,W]
patch = patch.to(DEVICE)

# -----------------------
# Helpers
# -----------------------
def detr_boxes_to_xyxy_pixel(pred_boxes):
    """
    Convert DETR-style boxes to corner format in pixel coordinates.

    pred_boxes: [Q,4] cx,cy,w,h (normalized 0..1 or absolute)
    returns [Q,4] xyxy in pixel coords (CPU tensor)

    Boxes are treated as normalized when their largest value is <= 1.01 and
    are then scaled by MODEL_INPUT_W / MODEL_INPUT_H. An empty input yields an
    empty [0,4] tensor instead of raising.
    """
    pb = pred_boxes.clone()
    if pb.numel() == 0:
        # Tensor.max() raises on an empty tensor, so short-circuit here.
        return pb.reshape(-1, 4).cpu()
    if pb.max() <= 1.01:
        pb[:, 0] = pb[:, 0] * MODEL_INPUT_W
        pb[:, 1] = pb[:, 1] * MODEL_INPUT_H
        pb[:, 2] = pb[:, 2] * MODEL_INPUT_W
        pb[:, 3] = pb[:, 3] * MODEL_INPUT_H
    xyxy = box_convert(pb, in_fmt='cxcywh', out_fmt='xyxy')
    return xyxy.cpu()

def eot_transform_patch(patch_tensor: torch.Tensor):
    """
    Apply one random Expectation-over-Transformation style jitter to the patch.

    patch_tensor: [1,3,H,W] on DEVICE
    returns the jittered patch [1,3,h2,w2], values clamped to [0, 1]

    Steps, in order: resize to a square of side PATCH_SIDE * U(0.9, 1.1),
    rotate by U(-10, 10) degrees (zero fill), then multiply by two
    independent gains drawn from U(0.9, 1.1).
    """
    # The random draws stay interleaved with the ops exactly as before, so the
    # NumPy RNG stream is consumed in the same order.
    size_jitter = float(np.random.uniform(0.9, 1.1))
    side = max(1, int(round(PATCH_SIDE * size_jitter)))
    resized = interpolate(
        patch_tensor,
        size=(side, side),
        mode='bilinear',
        align_corners=False,
    )

    rot_deg = float(np.random.uniform(-10, 10))
    rotated = TF.affine(
        resized,
        angle=rot_deg,
        translate=[0, 0],
        scale=1.0,
        shear=[0.0, 0.0],
        interpolation=InterpolationMode.BILINEAR,
        fill=0,
    )

    brightness_gain = float(np.random.uniform(0.9, 1.1))
    contrast_gain = float(np.random.uniform(0.9, 1.1))
    # Both factors are plain multiplicative gains; clamp keeps a valid image range.
    return torch.clamp((rotated * contrast_gain) * brightness_gain, 0.0, 1.0)

def paste_patch_via_mask(base_img: torch.Tensor, patch_tensor: torch.Tensor, center_xy: tuple):
    """
    Composite a patch onto an image, centred at a pixel location.

    base_img: [3,H,W] float tensor
    patch_tensor: [1,3,ph,pw] or [3,ph,pw] tensor on the same device
    center_xy: (cx, cy) pixel coordinates; rounded to the nearest integer
    returns a new [3,H,W] tensor; base_img itself is not modified

    The part of the patch that falls outside the image is cropped away. If
    nothing overlaps the image, a copy of base_img is returned.
    Raises ValueError for any other patch shape.
    """
    if patch_tensor.dim() == 3:
        tile = patch_tensor
    elif patch_tensor.dim() == 4 and patch_tensor.shape[0] == 1:
        tile = patch_tensor[0]
    else:
        raise ValueError("invalid patch shape")

    _, tile_h, tile_w = tile.shape
    img_h, img_w = base_img.shape[1], base_img.shape[2]

    # Top-left corner of the (unclipped) patch footprint in image coordinates.
    left = int(round(center_xy[0])) - tile_w // 2
    top = int(round(center_xy[1])) - tile_h // 2

    # Clip the footprint to the image bounds.
    x_lo, y_lo = max(left, 0), max(top, 0)
    x_hi, y_hi = min(left + tile_w, img_w), min(top + tile_h, img_h)
    if x_hi <= x_lo or y_hi <= y_lo:
        return base_img.clone()

    # Matching window inside the patch (offset by however much was clipped).
    visible = tile[:, y_lo - top:y_hi - top, x_lo - left:x_hi - left]

    # Binary mask of the pasted region and the patch placed on a blank canvas.
    region = torch.zeros_like(base_img)
    region[:, y_lo:y_hi, x_lo:x_hi] = 1.0
    canvas = torch.zeros_like(base_img)
    canvas[:, y_lo:y_hi, x_lo:x_hi] = visible

    # Blend out-of-place so base_img is left untouched.
    return base_img * (1.0 - region) + canvas * region

# -----------------------
# Data and model init
# -----------------------
dataloader = get_inria_dataloader(os.path.join(ROOT, "data", "INRIAPerson"), split="Test", batch_size=1, num_workers=4, disable_random_aug=True)
print("Test dataset size:", len(dataloader.dataset))

# Load DETR-R50 model; it is used for inference only, so freeze its weights.
model = load_detr_r50().to(DEVICE)
model.eval()
for p in model.parameters():
    p.requires_grad = False

# -----------------------
# Attack loop
# -----------------------
for img_idx, (imgs, targets) in enumerate(dataloader):
    imgs = imgs.to(DEVICE).clamp(0, 1)  # shape [B,3,H,W]; only element 0 is used below

    # 1) Get detections from clean image using DETR
    with torch.no_grad():
        try:
            det_out = model(imgs)
        except Exception:
            # Best-effort fallback: retry with the batch wrapped in a NestedTensor.
            det_out = model(NestedTensor(imgs))

    logits = det_out['pred_logits'][0]  # [Q,C]
    boxes = det_out['pred_boxes'][0]    # [Q,4]
    probs = torch.softmax(logits, dim=-1)
    cls_scores = probs[..., 1]  # column 1 is treated as the person class
    # nonzero() already yields an empty index tensor when nothing passes.
    keep_idx = (cls_scores > SCORE_THRESH).nonzero(as_tuple=False).squeeze(1)

    # Convert kept boxes to pixel xyxy. The converter is skipped when nothing
    # was kept so an image without detections does not abort the run.
    if keep_idx.numel() > 0:
        sel_xyxy = detr_boxes_to_xyxy_pixel(boxes[keep_idx].detach().cpu())  # CPU tensor in pixels
    else:
        sel_xyxy = torch.empty((0, 4), dtype=torch.float32)

    patched = imgs.clone()

    for box in sel_xyxy:
        xmin, ymin, xmax, ymax = box.tolist()
        box_w = max(int(xmax - xmin), 1)
        box_h = max(int(ymax - ymin), 1)
        short = min(box_w, box_h)
        # Patch side follows the box's short side, limited to 0.5x..2x of PATCH_SIDE.
        scale = float(np.clip(short / PATCH_SIDE, 0.5, 2.0))
        side = max(MIN_PATCH_PX, int(round(PATCH_SIDE * scale)))

        # Apply EoT transforms to patch, then resize to [1,3,side,side]
        patch_to_paste = eot_transform_patch(patch)
        patch_resized = interpolate(patch_to_paste, size=(side, side), mode='bilinear', align_corners=False)
        cx = (xmin + xmax) / 2.0
        cy = (ymin + ymax) / 2.0

        # Fusion (non-inplace inside the helper)
        patched[0] = paste_patch_via_mask(patched[0], patch_resized, center_xy=(cx, cy))

    # Save the patched image under a per-image name; a single fixed name would
    # be overwritten on every iteration and keep only the last image.
    save_image(patched[0].detach().cpu(), os.path.join(SAVE_DIR, f"patched_{img_idx:04d}.png"))

    # Optionally, run the model on patched image and compute metrics (AP drop, etc.)
    # (not implemented here, but you can compare detection results on patched vs clean)


  patch = torch.load(PATCH_PATH)  # [3,H,W] or [1,3,H,W]
Using cache found in /root/.cache/torch/hub/facebookresearch_detr_main


Test dataset size: 288


# 进一步优化：


In [5]:
# attack.py
import os
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image, draw_bounding_boxes
from torchvision.ops import box_convert, nms
from torch.nn.functional import interpolate
import numpy as np
from inria_dataloader import get_inria_dataloader
from tmm import load_detr_r50, NestedTensor

# -----------------------
# Config
# -----------------------
# Root of the experiment tree; patched images are written to ROOT/save/attack.
ROOT = "/opt/data/private/BlackBox"
SAVE_DIR = os.path.join(ROOT, "save", "attack")
os.makedirs(SAVE_DIR, exist_ok=True)

# Patch params
PATCH_SIDE = 300          # original global patch side (training-time reference)
PATCH_RATIO = 0.15        # new: patch side = PATCH_RATIO * short side of the target box
MIN_PATCH_PX = 16         # lower bound for the pasted patch side and for kept box sides
SCORE_THRESH = 0.5        # class-score threshold for keeping a detection
IOU_NMS_THRESH = 0.5      # IoU threshold passed to nms

# Model / dataset
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_INPUT_H, MODEL_INPUT_W = 640, 640  # used to scale normalized boxes to pixels
TARGET_CLASS_IDX = 1   # person class index (consistent with training)
FALLBACK_TO_TOP = True          # if nothing clears SCORE_THRESH, consider the top-scoring query
FALLBACK_SCORE_THRESH = 0.2     # minimum score for that fallback detection

# Data loader
BATCH_SIZE = 1
NUM_WORKERS = 4

# Patch path (adjust if needed)
PATCH_PATH = os.path.join(ROOT, "save", "demo", "final_patch.pt")

# -----------------------
# Load patch (safe)
# -----------------------
# Fail early with a clear message when the patch file is missing.
if not os.path.exists(PATCH_PATH):
    raise FileNotFoundError(f"Patch file not found: {PATCH_PATH}")

patch = torch.load(PATCH_PATH, map_location='cpu')  # load to cpu first
if patch.ndimension() == 3:
    patch = patch.unsqueeze(0)  # [1,3,H,W]
patch = patch.float().to(DEVICE)  # move to device

# Ensure patch in 0..1
patch = patch.clamp(0.0, 1.0)

# -----------------------
# Helpers
# -----------------------
def detr_boxes_to_xyxy_pixel(pred_boxes):
    """
    Convert DETR-style boxes to corner format in pixel coordinates.

    pred_boxes: [Q,4] cx,cy,w,h (normalized 0..1 or absolute)
    returns [Q,4] xyxy in pixel coords (CPU tensor)

    An empty input yields an empty [0,4] tensor instead of raising.
    """
    pb = pred_boxes.clone()
    if pb.numel() == 0:
        # Tensor.max() raises on an empty tensor, so short-circuit here.
        return pb.reshape(-1, 4).cpu()
    # if normalized (max <= 1.01), scale to model input dims
    if pb.max() <= 1.01:
        pb[:, 0] = pb[:, 0] * MODEL_INPUT_W
        pb[:, 1] = pb[:, 1] * MODEL_INPUT_H
        pb[:, 2] = pb[:, 2] * MODEL_INPUT_W
        pb[:, 3] = pb[:, 3] * MODEL_INPUT_H
    xyxy = box_convert(pb, in_fmt='cxcywh', out_fmt='xyxy')
    return xyxy.cpu()

def paste_patch_via_mask(base_img: torch.Tensor, patch_tensor: torch.Tensor, center_xy: tuple):
    """
    Composite patch_tensor onto base_img, centred at center_xy.

    base_img: [3,H,W] tensor
    patch_tensor: [1,3,ph,pw] or [3,ph,pw]
    center_xy: (cx, cy) in pixel coords; rounded to the nearest integer
    returns a new [3,H,W] tensor; base_img itself is not modified

    Patch pixels that fall outside the image are cropped. When the patch does
    not overlap the image at all, a copy of base_img is returned.
    Raises ValueError for any other patch shape.
    """
    if patch_tensor.dim() == 3:
        tile = patch_tensor
    elif patch_tensor.dim() == 4 and patch_tensor.shape[0] == 1:
        tile = patch_tensor[0]
    else:
        raise ValueError("invalid patch shape")

    _, tile_h, tile_w = tile.shape
    img_h, img_w = base_img.shape[1], base_img.shape[2]

    # Top-left corner of the unclipped patch footprint.
    left = int(round(center_xy[0])) - tile_w // 2
    top = int(round(center_xy[1])) - tile_h // 2

    # Footprint clipped to the image bounds.
    x_lo, y_lo = max(left, 0), max(top, 0)
    x_hi, y_hi = min(left + tile_w, img_w), min(top + tile_h, img_h)
    if x_hi <= x_lo or y_hi <= y_lo:
        return base_img.clone()

    # Corresponding window inside the patch.
    visible = tile[:, y_lo - top:y_hi - top, x_lo - left:x_hi - left]

    region = torch.zeros_like(base_img)
    region[:, y_lo:y_hi, x_lo:x_hi] = 1.0
    canvas = torch.zeros_like(base_img)
    canvas[:, y_lo:y_hi, x_lo:x_hi] = visible

    # Out-of-place blend: image outside the region, patch inside it.
    return base_img * (1.0 - region) + canvas * region

def detect_on_image(model, img_tensor):
    """
    Run the detector once (without gradients) and unpack the first batch item.

    model: callable returning a dict with 'pred_logits' and 'pred_boxes'
    img_tensor: image batch, shape [1,3,H,W]

    Returns a dict with:
      - 'logits': [Q,C] raw class logits
      - 'boxes':  [Q,4] boxes exactly as produced by the model
      - 'probs':  [Q,C] softmax of the logits over the class axis
    """
    with torch.no_grad():
        try:
            raw = model(img_tensor)
        except Exception:
            # Fallback: retry with the batch wrapped in a NestedTensor.
            raw = model(NestedTensor(img_tensor))

    first_logits = raw['pred_logits'][0]
    first_boxes = raw['pred_boxes'][0]
    return {
        'logits': first_logits,
        'boxes': first_boxes,
        'probs': first_logits.softmax(dim=-1),
    }

# -----------------------
# Data and model init
# -----------------------
dataloader = get_inria_dataloader(os.path.join(ROOT, "data", "INRIAPerson"),
                                  split="Test", batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, disable_random_aug=True)
print("Test dataset size:", len(dataloader.dataset))

# Load DETR-R50 model (white-box); inference only, so freeze its weights.
model = load_detr_r50().to(DEVICE)
model.eval()
for p in model.parameters():
    p.requires_grad = False


def _select_target_detections(det_out):
    """
    Reduce one detect_on_image() result to the final target-class detections.

    Pipeline (identical for clean and patched images):
      1. keep queries whose TARGET_CLASS_IDX probability exceeds SCORE_THRESH;
         if none do and FALLBACK_TO_TOP is set, keep the single best query
         provided it reaches FALLBACK_SCORE_THRESH;
      2. convert boxes to pixel xyxy;
      3. drop boxes with width or height below MIN_PATCH_PX;
      4. run NMS with IOU_NMS_THRESH.

    Returns (xyxy [K,4], scores [K]) as CPU tensors; both are empty when
    nothing survives.
    """
    empty_xyxy = torch.empty((0, 4), dtype=torch.float32)
    empty_scores = torch.empty((0,), dtype=torch.float32)

    cls_scores = det_out['probs'][..., TARGET_CLASS_IDX]  # [Q]
    keep_idx = (cls_scores > SCORE_THRESH).nonzero(as_tuple=False).squeeze(1)
    if keep_idx.numel() == 0 and FALLBACK_TO_TOP:
        top_score, top_idx = torch.max(cls_scores, dim=0)
        if top_score.item() >= FALLBACK_SCORE_THRESH:
            keep_idx = top_idx.unsqueeze(0)
    if keep_idx.numel() == 0:
        return empty_xyxy, empty_scores

    scores = cls_scores[keep_idx].detach().cpu()
    xyxy = detr_boxes_to_xyxy_pixel(det_out['boxes'][keep_idx].detach().cpu())

    # filter by minimum box side
    widths = xyxy[:, 2] - xyxy[:, 0]
    heights = xyxy[:, 3] - xyxy[:, 1]
    large_mask = (widths >= MIN_PATCH_PX) & (heights >= MIN_PATCH_PX)
    if not large_mask.any():
        return empty_xyxy, empty_scores
    xyxy, scores = xyxy[large_mask], scores[large_mask]

    # Both tensors are already on the CPU, so a single nms call is enough;
    # any error it raises is a real one and should surface.
    keep_nms = nms(xyxy, scores, IOU_NMS_THRESH)
    return xyxy[keep_nms], scores[keep_nms]


# -----------------------
# Attack evaluation loop
# -----------------------
total_images = 0
sum_clean_counts = 0
sum_patched_counts = 0
sum_clean_conf = 0.0
sum_patched_conf = 0.0

for idx, (imgs, targets) in enumerate(dataloader):
    total_images += 1
    imgs = imgs.to(DEVICE).clamp(0, 1)  # [B=1,3,H,W]; batch size 1 assumed

    # --- 1) detect on clean image and record clean stats
    sel_xyxy, sel_scores = _select_target_detections(detect_on_image(model, imgs))
    clean_count = sel_xyxy.shape[0]
    avg_clean_conf = float(sel_scores.mean().item()) if sel_scores.numel() > 0 else 0.0
    sum_clean_counts += clean_count
    sum_clean_conf += avg_clean_conf

    # --- 2) build patched image (no EoT, no rotation: fixed paste)
    patched = imgs.clone()
    for xmin, ymin, xmax, ymax in sel_xyxy.tolist():
        box_w = max(int(xmax - xmin), 1)
        box_h = max(int(ymax - ymin), 1)
        short = min(box_w, box_h)
        # patch side = PATCH_RATIO * short side, never below MIN_PATCH_PX or 1 px
        side = max(1, MIN_PATCH_PX, int(round(PATCH_RATIO * short)))

        # `patch` already lives on DEVICE, so the resized copy does too.
        patch_resized = interpolate(patch, size=(side, side), mode='bilinear', align_corners=False)

        cx = (xmin + xmax) / 2.0
        cy = (ymin + ymax) / 2.0

        # paste (non-inplace inside the helper)
        patched[0] = paste_patch_via_mask(patched[0], patch_resized, center_xy=(cx, cy))

    # save patched image (unique filename)
    save_path = os.path.join(SAVE_DIR, f"patched_{idx:04d}.png")
    save_image(patched[0].detach().cpu(), save_path)

    # --- 3) detect on patched image and gather stats
    # NOTE(review): the top-score fallback is also applied here, so a patched
    # image can still count one detection scoring between FALLBACK_SCORE_THRESH
    # and SCORE_THRESH - confirm this is the intended metric.
    p_sel_xyxy, p_sel_scores = _select_target_detections(detect_on_image(model, patched))
    patched_count = p_sel_xyxy.shape[0]
    avg_patched_conf = float(p_sel_scores.mean().item()) if p_sel_scores.numel() > 0 else 0.0
    sum_patched_counts += patched_count
    sum_patched_conf += avg_patched_conf

    # --- 4) print per-image summary
    print(f"[{idx+1}/{len(dataloader)}] saved {save_path} | clean_count={clean_count} avg_conf={avg_clean_conf:.3f} | patched_count={patched_count} avg_conf={avg_patched_conf:.3f}")

# end loop: summarize (per-image averages; images with no detection contribute 0)
avg_clean = sum_clean_counts / max(1, total_images)
avg_patched = sum_patched_counts / max(1, total_images)
drop_abs = avg_clean - avg_patched
drop_rel = (drop_abs / avg_clean * 100.0) if avg_clean > 0 else 0.0
avg_clean_conf = sum_clean_conf / max(1, total_images)
avg_patched_conf = sum_patched_conf / max(1, total_images)

print("\nAttack summary:")
print(f"  total_images = {total_images}")
print(f"  avg detections (clean)  = {avg_clean:.3f}")
print(f"  avg detections (patched)= {avg_patched:.3f}")
print(f"  absolute drop            = {drop_abs:.3f}")
print(f"  relative drop (%)        = {drop_rel:.2f}%")
print(f"  avg conf (clean)        = {avg_clean_conf:.3f}")
print(f"  avg conf (patched)      = {avg_patched_conf:.3f}")

# done


  patch = torch.load(PATCH_PATH, map_location='cpu')  # load to cpu first
Using cache found in /root/.cache/torch/hub/facebookresearch_detr_main


Test dataset size: 288
[1/288] saved /opt/data/private/BlackBox/save/attack/patched_0000.png | clean_count=2 avg_conf=0.994 | patched_count=2 avg_conf=0.992
[2/288] saved /opt/data/private/BlackBox/save/attack/patched_0001.png | clean_count=1 avg_conf=0.999 | patched_count=1 avg_conf=0.998
[3/288] saved /opt/data/private/BlackBox/save/attack/patched_0002.png | clean_count=10 avg_conf=0.878 | patched_count=8 avg_conf=0.847
[4/288] saved /opt/data/private/BlackBox/save/attack/patched_0003.png | clean_count=7 avg_conf=0.970 | patched_count=8 avg_conf=0.892
[5/288] saved /opt/data/private/BlackBox/save/attack/patched_0004.png | clean_count=1 avg_conf=0.978 | patched_count=1 avg_conf=0.971
[6/288] saved /opt/data/private/BlackBox/save/attack/patched_0005.png | clean_count=1 avg_conf=0.999 | patched_count=1 avg_conf=0.999
[7/288] saved /opt/data/private/BlackBox/save/attack/patched_0006.png | clean_count=1 avg_conf=0.999 | patched_count=1 avg_conf=0.999
[8/288] saved /opt/data/private/BlackB

# 添加返回recall功能：

In [None]:
# attack_with_recall.py
import os
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision.ops import box_convert, nms
from torch.nn.functional import interpolate
import numpy as np
from inria_dataloader import get_inria_dataloader
from tmm import load_detr_r50, NestedTensor

# -----------------------
# Config
# -----------------------
# Root of the experiment tree; SAVE_DIR (ROOT/save/attack) is created if missing.
ROOT = "/opt/data/private/BlackBox"
SAVE_DIR = os.path.join(ROOT, "save", "attack")
os.makedirs(SAVE_DIR, exist_ok=True)

# Patch params
PATCH_SIDE = 300
PATCH_RATIO = 0.15
MIN_PATCH_PX = 16
SCORE_THRESH = 0.5
IOU_NMS_THRESH = 0.5
IOU_MATCH_THRESH = 0.5  # IoU threshold used when computing recall

# Model / dataset
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_INPUT_H, MODEL_INPUT_W = 640, 640  # used to scale normalized boxes to pixels
TARGET_CLASS_IDX = 1   # person class index (consistent with training)
FALLBACK_TO_TOP = True
FALLBACK_SCORE_THRESH = 0.2

# Data loader
BATCH_SIZE = 1
NUM_WORKERS = 4

# Patch path (adjust if needed)
PATCH_PATH = os.path.join(ROOT, "save", "demo", "final_patch.pt")

# -----------------------
# Load patch
# -----------------------
# Fail early with a clear message when the patch file is missing.
if not os.path.exists(PATCH_PATH):
    raise FileNotFoundError(f"Patch file not found: {PATCH_PATH}")

# Load on the CPU first, add a batch dimension to a 3-D patch, then move the
# float patch to DEVICE and clamp it to [0, 1].
patch = torch.load(PATCH_PATH, map_location='cpu')
if patch.ndimension() == 3:
    patch = patch.unsqueeze(0)
patch = patch.float().to(DEVICE).clamp(0.0, 1.0)

# -----------------------
# Helpers
# -----------------------
def detr_boxes_to_xyxy_pixel(pred_boxes):
    """
    Convert [Q,4] cx,cy,w,h boxes to [Q,4] xyxy pixel boxes on the CPU.

    Boxes whose largest value is <= 1.01 are treated as normalized and scaled
    by MODEL_INPUT_W / MODEL_INPUT_H first. An empty input yields an empty
    [0,4] tensor instead of raising.
    """
    pb = pred_boxes.clone()
    if pb.numel() == 0:
        # Tensor.max() raises on an empty tensor, so short-circuit here.
        return pb.reshape(-1, 4).cpu()
    if pb.max() <= 1.01:
        pb[:, 0] = pb[:, 0] * MODEL_INPUT_W
        pb[:, 1] = pb[:, 1] * MODEL_INPUT_H
        pb[:, 2] = pb[:, 2] * MODEL_INPUT_W
        pb[:, 3] = pb[:, 3] * MODEL_INPUT_H
    xyxy = box_convert(pb, in_fmt='cxcywh', out_fmt='xyxy')
    return xyxy.cpu()

def paste_patch_via_mask(base_img: torch.Tensor, patch_tensor: torch.Tensor, center_xy: tuple):
    """
    Composite patch_tensor onto base_img, centred at center_xy.

    base_img: [3,H,W] tensor
    patch_tensor: [1,3,ph,pw] or [3,ph,pw]
    center_xy: (cx, cy) in pixel coords; rounded to the nearest integer
    returns a new [3,H,W] tensor; base_img itself is not modified

    Patch pixels outside the image are cropped; with no overlap at all a copy
    of base_img is returned. Raises ValueError for any other patch shape.
    """
    if patch_tensor.dim() == 3:
        tile = patch_tensor
    elif patch_tensor.dim() == 4 and patch_tensor.shape[0] == 1:
        tile = patch_tensor[0]
    else:
        raise ValueError("invalid patch shape")

    _, tile_h, tile_w = tile.shape
    img_h, img_w = base_img.shape[1], base_img.shape[2]

    # Top-left corner of the unclipped patch footprint.
    left = int(round(center_xy[0])) - tile_w // 2
    top = int(round(center_xy[1])) - tile_h // 2

    # Footprint clipped to the image bounds.
    x_lo, y_lo = max(left, 0), max(top, 0)
    x_hi, y_hi = min(left + tile_w, img_w), min(top + tile_h, img_h)
    if x_hi <= x_lo or y_hi <= y_lo:
        return base_img.clone()

    # Corresponding window inside the patch.
    visible = tile[:, y_lo - top:y_hi - top, x_lo - left:x_hi - left]

    region = torch.zeros_like(base_img)
    region[:, y_lo:y_hi, x_lo:x_hi] = 1.0
    canvas = torch.zeros_like(base_img)
    canvas[:, y_lo:y_hi, x_lo:x_hi] = visible

    # Out-of-place blend: image outside the region, patch inside it.
    return base_img * (1.0 - region) + canvas * region

def detect_on_image(model, img_tensor):
    """
    Run the detector once (without gradients) and unpack the first batch item.

    Returns a dict with 'logits' ([Q,C] raw logits), 'boxes' ([Q,4] boxes as
    produced by the model) and 'probs' (softmax of the logits over the last axis).
    """
    with torch.no_grad():
        try:
            raw = model(img_tensor)
        except Exception:
            # Fallback: retry with the batch wrapped in a NestedTensor.
            raw = model(NestedTensor(img_tensor))

    first_logits = raw['pred_logits'][0]
    first_boxes = raw['pred_boxes'][0]
    return {
        'logits': first_logits,
        'boxes': first_boxes,
        'probs': first_logits.softmax(dim=-1),
    }

def iou_xyxy(box_a, box_b):
    """
    Intersection-over-union of one box against a set of boxes.

    box_a: (4,) xyxy
    box_b: (N,4) xyxy
    returns an array of N IoU values

    A tiny epsilon (1e-12) is added to the union to avoid division by zero.
    """
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1 = box_b[:, 0], box_b[:, 1]
    bx2, by2 = box_b[:, 2], box_b[:, 3]

    # Overlap rectangle; negative extents (no overlap) are clipped to zero.
    overlap_w = np.maximum(0.0, np.minimum(ax2, bx2) - np.maximum(ax1, bx1))
    overlap_h = np.maximum(0.0, np.minimum(ay2, by2) - np.maximum(ay1, by1))
    overlap = overlap_w * overlap_h

    area_a = max(0.0, (ax2 - ax1)) * max(0.0, (ay2 - ay1))
    area_b = (bx2 - bx1) * (by2 - by1)
    return overlap / (area_a + area_b - overlap + 1e-12)

def _gt_boxes_to_xyxy_pixel(boxes, assume_pixel_xyxy=False):
    """Return `boxes` as a CPU (K,4) xyxy pixel tensor, or None if its type/shape is unusable."""
    if not isinstance(boxes, torch.Tensor):
        return None
    boxes = boxes.detach().cpu()
    if boxes.numel() == 0:
        # No annotations: Tensor.max() would raise, so return an empty (0,4) set.
        return boxes.new_zeros((0, 4))
    if boxes.dim() == 3 and boxes.shape[0] == 1:
        # Collated single-image batch [1,K,4] -> [K,4].
        boxes = boxes[0]
    elif boxes.dim() == 1 and boxes.shape[0] == 4:
        # A single bare box -> [1,4].
        boxes = boxes.unsqueeze(0)
    if boxes.dim() != 2 or boxes.shape[-1] != 4:
        return None
    if assume_pixel_xyxy or boxes.max() > 1.01:
        # Values above 1 are taken to be pixel xyxy already.
        return boxes
    # Values in 0..1 are ambiguous (cxcywh vs xyxy); as before, DETR-style
    # normalized cxcywh is assumed and scaled by the model input size.
    scale = boxes.new_tensor([MODEL_INPUT_W, MODEL_INPUT_H, MODEL_INPUT_W, MODEL_INPUT_H])
    return box_convert(boxes * scale, in_fmt='cxcywh', out_fmt='xyxy').cpu()


def extract_gt_boxes_from_targets(targets):
    """
    Try several common target layouts and extract the ground-truth boxes of
    the first (only) image in the batch.

    Returns a CPU tensor (K,4) in xyxy pixel coordinates, or None when the
    layout is not recognised. K may be 0 for an image without annotations.

    Supported layouts:
      - list/tuple whose first element is a dict with a 'boxes' tensor
      - list/tuple whose first element is a boxes tensor (returned as-is,
        assumed to be pixel xyxy already)
      - dict with a 'boxes' tensor (a leading batch dimension of 1 is dropped)
      - a boxes tensor passed directly (treated like the list-of-tensors case)

    For dict inputs, boxes whose largest value is <= 1.01 are assumed to be
    normalized cxcywh and are converted with MODEL_INPUT_W / MODEL_INPUT_H;
    anything larger is assumed to be pixel xyxy.
    """
    if targets is None:
        return None

    if isinstance(targets, torch.Tensor):
        return _gt_boxes_to_xyxy_pixel(targets, assume_pixel_xyxy=True)

    if isinstance(targets, (list, tuple)):
        if len(targets) == 0:
            return None
        first = targets[0]
        if isinstance(first, dict):
            return _gt_boxes_to_xyxy_pixel(first.get('boxes'))
        if isinstance(first, torch.Tensor):
            return _gt_boxes_to_xyxy_pixel(first, assume_pixel_xyxy=True)
        return None

    if isinstance(targets, dict):
        return _gt_boxes_to_xyxy_pixel(targets.get('boxes'))

    # Unknown format
    return None

def compute_recall(gt_boxes_np, pred_boxes_np, iou_thresh=0.5):
    """
    Count how many ground-truth boxes are recovered by the predictions.

    gt_boxes_np: (G,4) numpy xyxy, or None
    pred_boxes_np: (P,4) numpy xyxy, or None
    iou_thresh: minimum IoU for a prediction to count as a hit on a GT box
    returns (tp, G) as Python ints; (0, 0) when there are no GT boxes

    Greedy one-to-one matching: predictions are visited in the order given,
    and each one claims the *not yet matched* GT box with which it has the
    highest IoU, provided that IoU >= iou_thresh. A GT box can be claimed at
    most once.
    """
    if gt_boxes_np is None or gt_boxes_np.shape[0] == 0:
        return 0, 0

    G = gt_boxes_np.shape[0]
    if pred_boxes_np is None or pred_boxes_np.shape[0] == 0:
        return 0, G

    matched_gt = np.zeros(G, dtype=bool)
    tp = 0

    for pb in pred_boxes_np:
        # IoU of this prediction against every GT box (array of length G).
        ious = np.asarray(iou_xyxy(pb, gt_boxes_np), dtype=float)
        # Hide GT boxes that were already claimed so the argmax picks the best
        # *free* one; otherwise a prediction whose top GT is taken would be
        # discarded even though it overlaps another unmatched GT box.
        candidates = np.where(matched_gt, -np.inf, ious)
        best_idx = int(np.argmax(candidates))
        if candidates[best_idx] >= iou_thresh:
            tp += 1
            matched_gt[best_idx] = True
            if tp == G:
                # Every GT box is matched; further predictions cannot add TPs.
                break

    return int(tp), int(G)

# -----------------------
# Data and model init
# -----------------------
# Test split of INRIAPerson, loaded with random augmentation switched off.
inria_root = os.path.join(ROOT, "data", "INRIAPerson")
dataloader = get_inria_dataloader(
    inria_root,
    split="Test",
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    disable_random_aug=True,
)
print("Test dataset size:", len(dataloader.dataset))

# Detector is used for inference only: eval mode, all weights frozen.
model = load_detr_r50().to(DEVICE)
model.eval()
for param in model.parameters():
    param.requires_grad_(False)

# -----------------------
# Attack evaluation loop with recall
# -----------------------
def _select_target_detections(det_out):
    """
    Reduce a raw detector output to the final target-class detections.

    det_out: mapping with 'probs' and 'boxes' entries, as returned by
             detect_on_image (presumably [Q,C] and [Q,4] for a single image
             — TODO confirm against detect_on_image)
    returns (xyxy, scores): CPU tensors of shape [K,4] and [K]

    Steps, applied identically to the clean and the patched image:
      1. keep queries whose TARGET_CLASS_IDX probability exceeds SCORE_THRESH;
         if none do and FALLBACK_TO_TOP is set, keep the single best query
         when its score is at least FALLBACK_SCORE_THRESH
      2. convert the kept boxes to pixel xyxy
      3. drop boxes narrower or shorter than MIN_PATCH_PX
      4. run NMS with IOU_NMS_THRESH
    """
    no_boxes = torch.empty((0, 4), dtype=torch.float32)
    no_scores = torch.empty((0,), dtype=torch.float32)

    cls_scores = det_out['probs'][..., TARGET_CLASS_IDX]
    above = cls_scores > SCORE_THRESH
    if above.any():
        keep_idx = above.nonzero(as_tuple=False).squeeze(1)
    else:
        keep_idx = torch.tensor([], dtype=torch.long, device=cls_scores.device)
        if FALLBACK_TO_TOP:
            top_score, top_idx = torch.max(cls_scores, dim=0)
            if top_score.item() >= FALLBACK_SCORE_THRESH:
                keep_idx = top_idx.unsqueeze(0)

    if keep_idx.numel() == 0:
        return no_boxes, no_scores

    scores = cls_scores[keep_idx].detach().cpu()
    xyxy = detr_boxes_to_xyxy_pixel(det_out['boxes'][keep_idx].detach().cpu())

    widths = xyxy[:, 2] - xyxy[:, 0]
    heights = xyxy[:, 3] - xyxy[:, 1]
    large_mask = (widths >= MIN_PATCH_PX) & (heights >= MIN_PATCH_PX)
    if not large_mask.any():
        return no_boxes, no_scores

    xyxy = xyxy[large_mask]
    scores = scores[large_mask]
    # Both tensors are already on the CPU here, so a "retry on CPU" fallback
    # would only repeat the same call; let a genuine NMS error surface.
    keep_nms = nms(xyxy, scores, IOU_NMS_THRESH)
    return xyxy[keep_nms], scores[keep_nms]


total_images = 0
sum_clean_counts = 0
sum_patched_counts = 0
sum_clean_conf = 0.0
sum_patched_conf = 0.0

total_gt_boxes = 0
total_tp_clean = 0
total_tp_patched = 0
skipped_recall_images = 0

for idx, (imgs, targets) in enumerate(dataloader):
    total_images += 1
    imgs = imgs.to(DEVICE).clamp(0, 1)
    # --- extract GT boxes (single-image batch assumed)
    gt_boxes = extract_gt_boxes_from_targets(targets)  # CPU tensor xyxy or None
    skipped = gt_boxes is None
    if skipped:
        # cannot compute recall for this image
        gt_boxes_np = None
        skipped_recall_images += 1
    else:
        gt_boxes_np = gt_boxes.numpy()

    # --- detect clean
    clean_out = detect_on_image(model, imgs)
    sel_xyxy, sel_scores = _select_target_detections(clean_out)

    clean_count = sel_xyxy.shape[0]
    avg_clean_conf = float(sel_scores.mean().item()) if sel_scores.numel() > 0 else 0.0
    sum_clean_counts += clean_count
    sum_clean_conf += avg_clean_conf

    # compute recall for clean (if gt available)
    if not skipped:
        clean_np = sel_xyxy.numpy() if sel_xyxy.numel() > 0 else np.empty((0, 4))
        tp_c, G = compute_recall(gt_boxes_np, clean_np, IOU_MATCH_THRESH)
        total_gt_boxes += G
        total_tp_clean += tp_c
        recall_clean_img = (tp_c / G) if G > 0 else 0.0
    else:
        recall_clean_img = None

    # --- build patched image (no EoT): one patch centred on every clean detection
    patched = imgs.clone()
    for box in sel_xyxy:
        xmin, ymin, xmax, ymax = box.tolist()
        box_w = max(int(xmax - xmin), 1)
        box_h = max(int(ymax - ymin), 1)
        short = min(box_w, box_h)
        # Patch side scales with the shorter box edge but never drops below
        # MIN_PATCH_PX (and never below one pixel).
        side = max(1, MIN_PATCH_PX, int(round(PATCH_RATIO * short)))
        patch_resized = interpolate(patch, size=(side, side), mode='bilinear', align_corners=False)
        cx = (xmin + xmax) / 2.0
        cy = (ymin + ymax) / 2.0
        patched[0] = paste_patch_via_mask(patched[0], patch_resized.to(DEVICE), center_xy=(cx, cy))

    save_path = os.path.join(SAVE_DIR, f"patched_{idx:04d}.png")
    save_image(patched[0].detach().cpu(), save_path)

    # --- detect patched
    patched_out = detect_on_image(model, patched)
    p_sel_xyxy, p_sel_scores = _select_target_detections(patched_out)

    patched_count = p_sel_xyxy.shape[0]
    avg_patched_conf = float(p_sel_scores.mean().item()) if p_sel_scores.numel() > 0 else 0.0
    sum_patched_counts += patched_count
    sum_patched_conf += avg_patched_conf

    # compute recall for patched (if gt available)
    if not skipped:
        patched_np = p_sel_xyxy.numpy() if p_sel_xyxy.numel() > 0 else np.empty((0, 4))
        # The GT count is the same as for the clean pass, and total_gt_boxes
        # was already incremented there, so only the TP count is used here.
        tp_p, _ = compute_recall(gt_boxes_np, patched_np, IOU_MATCH_THRESH)
        total_tp_patched += tp_p
        recall_patched_img = (tp_p / G) if G > 0 else 0.0
    else:
        recall_patched_img = None

    # per-image print
    recall_str = f"{recall_clean_img:.3f}" if recall_clean_img is not None else "N/A"
    recall_p_str = f"{recall_patched_img:.3f}" if recall_patched_img is not None else "N/A"
    print(f"[{idx+1}/{len(dataloader)}] saved {save_path} | clean={clean_count} conf={avg_clean_conf:.3f} rec={recall_str} | patched={patched_count} conf={avg_patched_conf:.3f} rec={recall_p_str}")

# End loop: final summary
# Per-image averages; the divisor is clamped to 1 so an empty run reports zeros.
n_images = max(1, total_images)
avg_clean = sum_clean_counts / n_images
avg_patched = sum_patched_counts / n_images
avg_clean_conf = sum_clean_conf / n_images
avg_patched_conf = sum_patched_conf / n_images

# Drop in detection count caused by the patch, absolute and relative to clean.
drop_abs = avg_clean - avg_patched
if avg_clean > 0:
    drop_rel = drop_abs / avg_clean * 100.0
else:
    drop_rel = 0.0

# recall summary: stays None when no GT box was ever available
overall_recall_clean = None
overall_recall_patched = None
if total_gt_boxes > 0:
    overall_recall_clean = total_tp_clean / total_gt_boxes
    overall_recall_patched = total_tp_patched / total_gt_boxes

summary_lines = [
    "\nAttack summary:",
    f"  total_images = {total_images}",
    f"  avg detections (clean)  = {avg_clean:.3f}",
    f"  avg detections (patched)= {avg_patched:.3f}",
    f"  absolute drop            = {drop_abs:.3f}",
    f"  relative drop (%)        = {drop_rel:.2f}%",
    f"  avg conf (clean)        = {avg_clean_conf:.3f}",
    f"  avg conf (patched)      = {avg_patched_conf:.3f}",
]
for summary_line in summary_lines:
    print(summary_line)

if overall_recall_clean is None:
    print("\n  recall could not be computed: GT boxes unavailable for the dataset items (extracted targets were None).")
    print(f"  skipped_recall_images = {skipped_recall_images}")
else:
    print(f"\n  recall (clean)   = {overall_recall_clean:.4f}  (TP={total_tp_clean}, GT={total_gt_boxes})")
    print(f"  recall (patched) = {overall_recall_patched:.4f}  (TP={total_tp_patched}, GT={total_gt_boxes})")
