In [None]:
from ultralytics import YOLOWorld
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch

import warnings

warnings.filterwarnings(action="ignore")
warnings.simplefilter(action="ignore")

In [None]:
# Load the YOLO-World checkpoint and place it on the GPU when one is present,
# otherwise fall back to the CPU.
yolo = YOLOWorld(model='../models/yolov8s-world.pt').to(
    torch.device(device="cuda" if torch.cuda.is_available() else "cpu")
)

# Positions of the layers of interest inside `yolo.model.model`.
key_layer_idx = dict(
    # module 2-9 same as yolov8
    backbone_c2f1=2,
    backbone_c2f2=4,
    backbone_c2f3=6,
    backbone_c2f4=8,
    backbone_sppf=9,
    # module changed
    neck_c2f1=15,
    neck_c2f2=19,
    neck_c2f3=22,
    detect_head=23,
)

# Resolve every symbolic layer name to the actual module object.
layers = {name: yolo.model.model[pos] for name, pos in key_layer_idx.items()}
detect_head = layers['detect_head']

提取输出 (Extract the raw outputs of the detect head)

In [None]:
import os
from PIL import Image
from DyFilterAttack.analyzer.utils import SaveFeatures
# Hooks must be attached before the forward pass so that the prediction below
# is the run whose intermediate outputs get recorded.
# NOTE(review): SaveFeatures is a project helper; the keys read later
# ('detect_head.cv2.0', 'detect_head.dfl', ...) suggest it stores one entry per
# submodule under `parent_path` -- confirm against its implementation.
save_feats = SaveFeatures()
save_feats.register_hooks(module=detect_head, parent_path='detect_head', verbose=True)

# Run a single prediction, then collect whatever the hooks captured.
img_path = '../testset/bus.jpg'
results = yolo.predict(img_path)
detect_head_raw_feats = save_feats.get_features()

# Head dimensions read from the detect head itself.
nl = detect_head.nl
nc = detect_head.nc
reg_max = detect_head.reg_max

# Per-anchor output width: class channels plus 4 * reg_max box channels.
# It must agree with the value the head reports.
no = 4 * reg_max + nc
assert no == detect_head.no

In [None]:
# Print and plot the detection result of the first image (B=0).
result = results[0]
for det in result.boxes:
    xmin, ymin, xmax, ymax = det.xyxy[0]
    conf = det.conf  # Confidence
    cls = det.cls  # Class ID
    # The class id comes back as a float; cast it so the `names` lookup uses
    # an integer key explicitly rather than relying on 5.0 == 5 hashing.
    class_name = result.names[int(cls[0].item())]
    print(f"bbox: {xmin}, {ymin}, {xmax}, {ymax}, conf: {conf}, class: {class_name}")

# The channel axis of the plotted array is reversed before handing it to PIL
# (presumably BGR -> RGB, as the plot is produced in OpenCV channel order).
image = Image.fromarray(result.plot()[:, :, ::-1])
image.show()

# Make sure the output directory exists; on a fresh checkout './result' may be
# missing and Image.save() would raise FileNotFoundError.
# `os` is imported in the preceding cell.
save_path = './result/bus_result.jpg'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
image.save(save_path)

In [None]:
# process1
# text -> (B, nc, embed_dim)
# image -> (B, embed_dim, H, W)
# cv4 contrast(image, text) -> (B, nc, H, W)
# cv2(image) -> (B, reg_max * 4, H, W)
# cat_result -> (B, nc + reg_max * 4, H, W) -> (B, no, H, W)
# x[i] -> cat_result[i] (i = 0, ..., nl - 1)

# One captured tensor per detection level for the box branch (cv2) and the
# class/contrastive branch (cv4).
cv2_raw_feats = [detect_head_raw_feats[f'detect_head.cv2.{i}'] for i in range(nl)]
cv4_raw_feats = [detect_head_raw_feats[f'detect_head.cv4.{i}'] for i in range(nl)]

# Report the shape of every level; iterate over the real number of levels
# instead of assuming there are exactly three.
for feat_name, level_feats in (('cv2_raw_feats', cv2_raw_feats), ('cv4_raw_feats', cv4_raw_feats)):
    for level, level_feat in enumerate(level_feats):
        print(f'{feat_name} {level}: {level_feat.size()}')

# process2 (_inference)
# flat(x[i]) -> (B, no, H * W)
# cat(x) -> (B, no, H0 * W0 + H1 * W1 + H2 * W2)
# split(x) -> bbox(B, 4 * reg_max, H0 * W0 + H1 * W1 + H2 * W2), cls(logit)(B, nc, H0 * W0 + H1 * W1 + H2 * W2)
# decode(bbox) -> dbox
# logit(cls) -> sigmoid(cls)
# y -> cat(dbox, cls)

# Box channels first, class channels second, concatenated per level.
cat_raw_feats = [torch.cat((cv2_raw_feats[i], cv4_raw_feats[i]), 1) for i in range(nl)]
# Flatten the spatial grid of each level and join all levels along the anchor axis.
flatten_raw_feats = torch.cat([cat_raw_feat.view(cat_raw_feats[0].shape[0], no, -1) for cat_raw_feat in cat_raw_feats], 2)
raw_box = flatten_raw_feats[:, : reg_max * 4]
raw_cls = flatten_raw_feats[:, reg_max * 4 :]

# Decode boxes from the captured DFL output using the head's own anchors and
# strides, which the head holds after the forward pass above.
dfl_feats = detect_head_raw_feats['detect_head.dfl']
dbox = detect_head.decode_bboxes(dfl_feats, detect_head.anchors.unsqueeze(0)) * detect_head.strides
# ! Attention: we need cls(logit) as y_det
logit_cls = raw_cls
sigmoid_cls = logit_cls.sigmoid()

print(f'cat_raw_feat: {cat_raw_feats[0].size()}')           # (B, no, H0, W0) -- still 4-D, level 0 only
print(f'flatten_raw_feats: {flatten_raw_feats.size()}')     # (B, no, H0 * W0 + H1 * W1 + H2 * W2)
print(f'raw_box: {raw_box.size()}')                         # (B, 4 * reg_max, H0 * W0 + H1 * W1 + H2 * W2)
print(f'raw_cls: {raw_cls.size()}')                         # (B, nc, H0 * W0 + H1 * W1 + H2 * W2)
print(f'dbox: {dbox.size()}')                               # (B, 4,  H0 * W0 + H1 * W1 + H2 * W2) -- TODO confirm against printed output
print(f'logit_cls: {logit_cls.size()}')                     # (B, nc, H0 * W0 + H1 * W1 + H2 * W2)
print(f'sigmoid_cls: {sigmoid_cls.size()}')                 # (B, nc, H0 * W0 + H1 * W1 + H2 * W2)

In [None]:
# process3 (construct y_det_orig and y_det_target)
# Re-run NMS on the reconstructed predictions to learn which anchor columns
# survive, then read the raw class logits at exactly those columns.
from ultralytics.utils import ops
import numpy as np
predictor = yolo.predictor
preds = torch.cat([dbox, sigmoid_cls], 1)  # (B, 4+nc, N)

# Same arguments the predictor itself would use, plus return_idxs=True so the
# indices of the kept anchors are returned alongside the detections.
detections, keep_idxs = ops.non_max_suppression(
    preds,
    predictor.args.conf,
    predictor.args.iou,
    predictor.args.classes,
    predictor.args.agnostic_nms,
    predictor.args.max_det,
    nc=0 if predictor.args.task == "detect" else len(predictor.model.names),
    end2end=getattr(predictor.model, "end2end", False),
    rotated=predictor.args.task == "obb",
    return_idxs=True,
)

num_nms_output = [idx.numel() for idx in keep_idxs]
max_out = max(num_nms_output)

# y_det[b, :, j] holds the class logits of the j-th kept anchor of image b.
# Images with fewer than max_out detections stay zero-padded on the right.
y_det = raw_cls.new_zeros(raw_cls.shape[0], raw_cls.shape[1], max_out)
for b, idx in enumerate(keep_idxs):
    if idx.numel() > 0:
        # Index image b only: slicing the whole batch here would mix images
        # and fails to broadcast as soon as the batch size exceeds one.
        y_det[b, :, :idx.numel()] = raw_cls[b][:, idx]

# Highest-scoring class per detection.  gather() along the class axis picks
# y_det[b, first_max_cls_idx[b, j], j] for every (b, j), independent of B.
first_max_cls_idx = torch.argmax(y_det, dim=1)  # (B, max_out)
y_det_orig = y_det.gather(1, first_max_cls_idx.unsqueeze(1)).squeeze(1)  # (B, max_out)

# Runner-up class per detection, used as the target logit.
_, topk_indices = torch.topk(y_det, 2, dim=1)
second_max_cls_idx = topk_indices[:, 1]  # (B, max_out)
y_det_target = y_det.gather(1, second_max_cls_idx.unsqueeze(1)).squeeze(1)  # (B, max_out)

print("shapes:")
print(f'y_det_orig       {y_det_orig.size()}')
print(f'y_det_target     {y_det_target.size()}')
print("indexs:")
print(f'y_det_orig       {first_max_cls_idx}')
print(f'y_det_target     {second_max_cls_idx}')
print("logits:")
print(f'y_det_orig       {y_det_orig}')
print(f'y_det_target     {y_det_target}')
print("sigmoid:")
print(f'y_det_orig       {y_det_orig.sigmoid()}')
print(f'y_det_target     {y_det_target.sigmoid()}')