In [None]:
import math
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, InterpolationMode
import cv2
import numpy as np
import torch
import clip
from torchvision import transforms
import csv
from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
import random
from PIL import Image
import os
import matplotlib.pyplot as plt
import open_clip


def load_csv_data(csv_path):
    """Read a CSV file into a list of row dicts, splitting the ``entities`` column.

    Args:
        csv_path: Path of a CSV file whose header contains an ``entities``
            column holding a comma-separated list.

    Returns:
        A list with one dict per data row (as produced by ``csv.DictReader``).
        In every row the ``entities`` value is replaced by a list of the
        comma-separated items with surrounding whitespace stripped.

    Raises:
        KeyError: if the file has no ``entities`` column.
    """
    rows = []
    # newline='' lets the csv module do its own line-ending handling, so quoted
    # fields that contain line breaks are returned unmodified.
    with open(csv_path, mode='r', newline='') as csv_file:
        for row in csv.DictReader(csv_file):
            row['entities'] = [entity.strip() for entity in row['entities'].split(',')]
            rows.append(row)
    return rows

# Target device and CLIP checkpoint shared by every function below.
# NOTE(review): the GPU index is hard-coded; adjust it for the local machine.
device = "cuda:3"
clipmodel, preprocess = clip.load("ViT-L/14@336px", device=device)

# Pre-processing is split into two stages so a perturbation can be applied to
# the un-normalised image tensor between them:
#   processor_before: resize to 336x336 (bicubic) and convert to a tensor
#   processor_after:  normalise with the OpenAI CLIP channel statistics
mean, std = OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
processor_before = transforms.Compose([
    transforms.Resize(size=(336, 336), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
])
processor_after = transforms.Compose([
    transforms.Normalize(mean, std),
])

# Input resolution and patch-embedding kernel size of the visual encoder,
# used to lay out the positional-embedding grid.
clip_inres = clipmodel.visual.input_resolution
clip_ksize = clipmodel.visual.conv1.kernel_size
def attention_layer(q, k, v, num_heads=1):
    """Multi-head scaled dot-product attention on sequence-first tensors.

    Args:
        q: Query tensor, unpacked as ``(tgt_len, bsz, embed_dim)``.
        k: Key tensor; viewed as ``(-1, bsz * num_heads, head_dim)``.
        v: Value tensor; viewed as ``(-1, bsz * num_heads, head_dim)``.
        num_heads: Number of heads ``embed_dim`` is split into.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` has shape
        ``(tgt_len, bsz, embed_dim)`` and ``attn_weights`` has shape
        ``(bsz, tgt_len, src_len)``, averaged over the heads.
    """
    tgt_len, bsz, embed_dim = q.shape
    head_dim = embed_dim // num_heads
    batch_heads = bsz * num_heads

    def split_heads(tensor, length):
        # (length, bsz, embed_dim) -> (bsz * num_heads, length, head_dim)
        return tensor.contiguous().view(length, batch_heads, head_dim).transpose(0, 1)

    # Queries are pre-scaled by 1/sqrt(head_dim) before the dot product.
    queries = split_heads(q * (float(head_dim) ** -0.5), tgt_len)
    keys = split_heads(k, -1)
    values = split_heads(v, -1)

    weights = F.softmax(torch.bmm(queries, keys.transpose(1, 2)), dim=-1)
    per_head_output = torch.bmm(weights, values)
    assert list(per_head_output.size()) == [batch_heads, tgt_len, head_dim]

    # Merge the heads back into the embedding dimension.
    attn_output = per_head_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    # Average the attention maps over the head dimension.
    head_averaged = weights.view(bsz, num_heads, tgt_len, -1).sum(dim=1) / num_heads
    return attn_output, head_averaged
    
def clip_encode_dense(x, n):
    """Run the CLIP visual encoder and expose the attention internals of its last ``n`` blocks.

    The first ``len(resblocks) - n`` transformer blocks are executed as-is; the
    last ``n`` blocks are unrolled by hand so that their intermediate tensors
    can be collected. Uses the module-level ``clipmodel``, ``clip_inres`` and
    ``clip_ksize``.

    Args:
        x: Batch of normalised images; cast to half precision before the
            patch-embedding convolution.
        n: Number of trailing transformer blocks to unroll. The blocks are
            selected with ``resblocks[:-n]`` / ``resblocks[-n:]``, so ``n == 0``
            unrolls every block, exactly like ``n == len(resblocks)``.

    Returns:
        A tuple ``(x, x_in_list, vs, qs, ks, attns, atten_outs, (feah, feaw))``:
        ``x`` is the ``ln_post``-normalised, projected token sequence (NLD);
        ``x_in_list`` holds the ``ln_1`` output fed to each unrolled attention;
        ``vs`` / ``qs`` / ``ks`` hold the value / query / key projections;
        ``attns`` the head-averaged attention weights; ``atten_outs`` the
        attention outputs before ``out_proj``; ``(feah, feaw)`` is the patch
        grid size produced by ``conv1``.
    """
    visual = clipmodel.visual
    blocks = visual.transformer.resblocks
    vision_width = visual.transformer.width
    vision_heads = vision_width // 64
    # print("[vision_width and vision_heads]:", vision_width, vision_heads) #[vision_width and vision_heads]: 1024 16

    # modified from CLIP
    x = x.half()
    x = visual.conv1(x)
    feah, feaw = x.shape[-2:]

    # Flatten the patch grid into a token sequence: (N, C, H, W) -> (N, H*W, C).
    x = x.reshape(x.shape[0], x.shape[1], -1)
    x = x.permute(0, 2, 1)
    class_embedding = visual.class_embedding.to(x.dtype)
    # Prepend the class embedding (broadcast over the batch) to the patch tokens.
    x = torch.cat([class_embedding + torch.zeros(x.shape[0], 1, x.shape[-1]).to(x), x], dim=1)

    # Resample the patch positional embeddings to the actual feature-map size.
    pos_embedding = visual.positional_embedding.to(x.dtype)
    tok_pos, img_pos = pos_embedding[:1, :], pos_embedding[1:, :]
    pos_h = clip_inres // clip_ksize[0]
    pos_w = clip_inres // clip_ksize[1]
    assert img_pos.size(0) == (pos_h * pos_w), f"the size of pos_embedding ({img_pos.size(0)}) does not match resolution shape pos_h ({pos_h}) * pos_w ({pos_w})"
    img_pos = img_pos.reshape(1, pos_h, pos_w, img_pos.shape[1]).permute(0, 3, 1, 2)

    img_pos = torch.nn.functional.interpolate(img_pos, size=(feah, feaw), mode='bicubic', align_corners=False)
    img_pos = img_pos.reshape(1, img_pos.shape[1], -1).permute(0, 2, 1)
    pos_embedding = torch.cat((tok_pos[None, ...], img_pos), dim=1)
    x = x + pos_embedding

    x = visual.ln_pre(x)

    x = x.permute(1, 0, 2)  # NLD -> LND

    # Blocks that are not inspected run through their regular forward pass.
    x = torch.nn.Sequential(*blocks[:-n])(x)

    attns = []
    atten_outs = []
    vs = []
    qs = []
    ks = []
    x_in_list = []

    # Hand-unrolled residual attention blocks so q/k/v and the attention maps
    # can be captured. The per-block ln_post/proj projection that used to be
    # computed here was never returned, so it is no longer evaluated.
    for TR in blocks[-n:]:
        x_in = x
        x = TR.ln_1(x_in)
        x_in_list.append(x)
        q, k, v = F.linear(x, TR.attn.in_proj_weight, TR.attn.in_proj_bias).chunk(3, dim=-1)
        attn_output, attn = attention_layer(q, k, v, vision_heads)
        attns.append(attn)
        atten_outs.append(attn_output)
        vs.append(v)
        qs.append(q)
        ks.append(k)

        x_after_attn = F.linear(attn_output, TR.attn.out_proj.weight, TR.attn.out_proj.bias)
        x = x_after_attn + x_in
        x = x + TR.mlp(TR.ln_2(x))

    x = x.permute(1, 0, 2)  # LND -> NLD
    x = visual.ln_post(x)
    x = x @ visual.proj
    return x, x_in_list, vs, qs, ks, attns, atten_outs, (feah, feaw)

def save_image(image, output_path):
    """Write an image tensor to ``output_path`` as an 8-bit image through PIL.

    The tensor is squeezed on dim 0, moved to the CPU, detached and transposed
    with axes ``(1, 2, 0)``; its values are then multiplied by 255, clipped to
    ``[0, 255]`` and cast to ``uint8`` before saving.
    """
    pixels = image.squeeze(0).cpu().detach().numpy()
    pixels = pixels.transpose(1, 2, 0)
    pixels_u8 = np.clip(pixels * 255, 0, 255).astype(np.uint8)
    Image.fromarray(pixels_u8).save(output_path)

def plot_adv_ori(input_img, output_adv):
    """Show the original image, the adversarial image and their amplified difference.

    Args:
        input_img: Original image tensor; squeezed on dim 0 and permuted with
            ``(1, 2, 0)`` before display.
        output_adv: Adversarial image tensor, handled the same way.

    The third panel shows ``|original - adversarial| * 3`` min-max normalised
    to ``[0, 1]``. When the two images are identical the difference has zero
    range; an all-zero map is shown instead of dividing by zero.
    """
    def _as_display_array(img):
        # detach() so tensors that still require grad can be converted to NumPy.
        return img.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()

    original_img = _as_display_array(input_img)
    adversarial_img = _as_display_array(output_adv)

    difference = np.abs(original_img - adversarial_img) * 3

    diff_min = difference.min()
    diff_range = difference.max() - diff_min
    if diff_range > 0:
        difference_normalized = (difference - diff_min) / diff_range
    else:
        # Identical images: avoid 0/0 and show an empty difference map.
        difference_normalized = np.zeros_like(difference)

    # NOTE: matplotlib ignores ``cmap`` when the array passed to imshow is RGB(A).
    panels = (
        (original_img, 'Original Image', {}),
        (adversarial_img, 'Adversarial Image', {}),
        (difference_normalized, 'Difference (Original vs Adversarial)', {'cmap': 'hot'}),
    )

    plt.figure(figsize=(15, 5))
    for position, (pixels, title, imshow_kwargs) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(pixels, **imshow_kwargs)
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.show()


def get_text_feat(text, model, device):
    """Encode text prompts with ``model`` and return L2-normalised features.

    Args:
        text: Prompt(s) passed to the open_clip ``ViT-L-14-336`` tokenizer.
        model: Object exposing ``encode_text``.
        device: Device the token ids are moved to before encoding.

    Returns:
        The normalised text embeddings with an extra leading dimension
        (``unsqueeze(0)``), i.e. ``[1, num_prompt, 768]`` per the original note.
    """
    tokenize = open_clip.get_tokenizer(model_name='ViT-L-14-336')
    token_ids = tokenize(text).to(device)
    normalized = F.normalize(model.encode_text(token_ids), dim=-1)
    return normalized.unsqueeze(0)

def get_mask_and_feat(model, text_feat, ori_vs, ori_ks, if_plot = False):
    """Build per-layer value features and per-prompt localisation masks.

    Args:
        model: CLIP model; ``model.visual`` supplies the transformer blocks,
            ``ln_post`` and ``proj``.
        text_feat: Text features whose second-to-last dimension indexes the
            prompts (``[1, num_prompt, dim]`` as produced by ``get_text_feat``).
        ori_vs: Value projections of the 24 transformer blocks (sequence-first).
        ori_ks: Key projections of the transformer blocks; index 8 is used to
            build the key-key attention that refines the last block's values.
        if_plot: When True, display the 24x24 similarity map and the resized
            binary mask for every prompt.

    Returns:
        ``(ori_feat_list, mask_list, xmask_list)``: the normalised projected
        value features of each of the 24 blocks, and, per prompt, a binary mask
        (similarity above the mid-point of its min and max, nearest-resized to
        336x336) and the similarity map itself (bilinearly resized to 336x336).

    The number of prompts is taken from ``text_feat`` rather than from a
    notebook-level ``text`` variable, so the function does not depend on
    hidden global state.
    """
    visual = model.visual
    blocks = visual.transformer.resblocks

    def _project(tokens, block):
        # out_proj -> (L, N, D) to (N, L, D) -> ln_post -> joint-space projection -> L2 norm
        projected = F.linear(tokens, block.attn.out_proj.weight, block.attn.out_proj.bias)
        projected = projected.permute(1, 0, 2)
        projected = visual.ln_post(projected)
        projected = projected @ visual.proj
        return F.normalize(projected, dim=-1)  # [1, num_patch, 768]

    # get ori_feat_list
    ori_feat_list = [_project(ori_vs[j], blocks[j]) for j in range(0, 24)]

    # get mask: repeatedly propagate the last block's values through the
    # key-key attention of block 8 (20 rounds, 16 heads), then project them.
    xx = ori_vs[-1]
    for _round in range(20):
        xx, _weights = attention_layer(ori_ks[8], ori_ks[8], xx, 16)
    ori_feat = _project(xx, blocks[-1])

    # Similarity of every patch token (class token dropped) to every prompt.
    img_txt_matching = ori_feat[:, 1:, :] @ text_feat.transpose(-1, -2)
    num_prompts = text_feat.shape[-2]
    patch_scores = img_txt_matching.reshape(-1, num_prompts)

    mask_list = []
    xmask_list = []
    for i in range(num_prompts):
        x = patch_scores[:, i]
        # Tensor reductions instead of Python's builtin min/max, which iterate
        # the tensor element by element.
        threshold = (x.min() + x.max()) / 2

        # gen mask
        xmask = x.reshape(24, 24)
        mask = (xmask > threshold).float()

        xmask_resized = F.interpolate(
            xmask.unsqueeze(0).unsqueeze(0),
            size=(336, 336),
            mode='bilinear',
            align_corners=False
        ).squeeze()
        xmask_list.append(xmask_resized)

        mask_resized = F.interpolate(
            mask.unsqueeze(0).unsqueeze(0),
            size=(336, 336),
            mode="nearest"
        ).squeeze()
        mask_list.append(mask_resized)

        # plot mask
        if if_plot:
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
            im1 = ax1.imshow(x.reshape(24, 24).detach().cpu().numpy())
            ax1.set_title('Original Image (24x24)')
            ax1.axis('off')
            fig.colorbar(im1, ax=ax1, fraction=0.046, pad=0.04)
            im2 = ax2.imshow(mask_resized.cpu().numpy(), cmap='gray', vmin=0, vmax=1)
            ax2.set_title('Resized Mask (336x336)')
            ax2.axis('off')
            fig.colorbar(im2, ax=ax2, fraction=0.046, pad=0.04)
            plt.tight_layout()
            plt.show()

    return ori_feat_list, mask_list, xmask_list
        
def random_crop_within_bounds(X, mask, xmask, min_scale=0.75, max_scale=1.0):
    """Randomly crop ``X`` and ``mask`` so the crop still contains the salient region.

    The salient region is the bounding box of the top 15% values of ``xmask``.
    If that yields nothing, the bounding box of the non-zero entries of
    ``mask`` is used; if that is empty too, the whole image is treated as the
    region (previously this case raised ``IndexError``).

    Args:
        X: Image tensor indexed as ``(channels, height, width)``.
        mask: 2-D mask cropped with the same window as ``X``.
        xmask: 2-D score map used to locate the salient region.
        min_scale: Lower bound of the random crop size as a fraction of the
            image size.
        max_scale: Upper bound of the random crop size as a fraction of the
            image size.

    Returns:
        ``(cropped_X, cropped_mask, (crop_top, crop_left, crop_height,
        crop_width), (min_height, min_width))``. All geometry values are plain
        Python ints; ``min_height`` / ``min_width`` are the size of the salient
        bounding box.
    """
    img_height, img_width = X.size(1), X.size(2)

    def _bounding_box(region):
        # Returns (min_row, max_row, min_col, max_col) of the non-zero entries
        # as Python ints, or None when the region has no active entry.
        if region.numel() == 0:
            return None
        active = region != 0
        row_idx = torch.where(active.any(dim=1))[0]
        col_idx = torch.where(active.any(dim=0))[0]
        if row_idx.numel() == 0 or col_idx.numel() == 0:
            return None
        return int(row_idx[0]), int(row_idx[-1]), int(col_idx[0]), int(col_idx[-1])

    # Keep the top 15% of the score map (at least one element).
    if xmask.numel() > 0:
        k = max(1, int(0.15 * xmask.numel()))
        threshold = torch.topk(xmask.flatten(), k).values[-1]
        important_mask = xmask >= threshold
    else:
        important_mask = torch.zeros_like(xmask, dtype=torch.bool)

    bounds = _bounding_box(important_mask)
    if bounds is None:
        bounds = _bounding_box(mask)
    if bounds is None:
        # Nothing to protect (e.g. NaN scores and an empty mask): fall back to
        # the full image instead of indexing into an empty tensor.
        bounds = (0, img_height - 1, 0, img_width - 1)
    min_row, max_row, min_col, max_col = bounds

    min_height = max_row - min_row + 1
    min_width = max_col - min_col + 1

    # Random crop size, capped by the image and never smaller than the region.
    crop_height = int(img_height * random.uniform(min_scale, max_scale))
    crop_width = int(img_width * random.uniform(min_scale, max_scale))
    crop_height = max(min(crop_height, img_height), min_height)
    crop_width = max(min(crop_width, img_width), min_width)

    # Valid top-left offsets: the crop must start at or before the region,
    # end at or after it, and stay inside the image.
    top_range_max = max(0, min(min_row, img_height - crop_height))
    left_range_max = max(0, min(min_col, img_width - crop_width))
    top_range_min = max(0, min_row - (crop_height - min_height))
    left_range_min = max(0, min_col - (crop_width - min_width))

    if top_range_min == top_range_max and left_range_min == left_range_max:
        # Only one placement is possible; skip the RNG draws as before.
        crop_top = top_range_min
        crop_left = left_range_min
    else:
        crop_top = random.randint(top_range_min, top_range_max)
        crop_left = random.randint(left_range_min, left_range_max)

    cropped_X = X[:,
                  crop_top:crop_top + crop_height,
                  crop_left:crop_left + crop_width]

    cropped_mask = mask[crop_top:crop_top + crop_height,
                        crop_left:crop_left + crop_width]

    return cropped_X, cropped_mask, (crop_top, crop_left, crop_height, crop_width), (min_height, min_width)


def pgd_attack(model, processor, image, text, epsilon, alpha, num_iter, device):
    """Iterative sign-gradient attack that steers CLIP patch features between two prompts.

    Each iteration crops the current adversarial image around the salient
    region, resizes the crop to 336x336, and takes a signed-gradient ascent
    step on a loss that increases the masked cosine similarity of the patch
    features (value path of blocks 20-23) to ``text[1]`` and decreases it to
    ``text[0]``. The step is mapped back into the crop window and the total
    perturbation is clipped to the ``epsilon`` ball around the original image.

    Args:
        model: CLIP model (text tower and ``model.visual``).
        processor: Callable applied to an image tensor before the encoder
            (e.g. normalisation).
        image: Un-normalised image tensor in ``[0, 1]``, indexed as
            ``(channels, height, width)``. The masks are produced at 336x336 —
            presumably the image must have that size too; verify against caller.
        text: Sequence of at least two prompts: ``text[0]`` is pushed away
            from, ``text[1]`` is pulled towards.
        epsilon: Maximum absolute per-pixel perturbation.
        alpha: Step size; also the range of the uniform random initialisation.
        num_iter: Number of attack iterations.
        device: Device used for the model forward/backward passes.

    Returns:
        ``(ori_image, adv_image)``: a detached copy of the input and the
        adversarial image (clamped to ``[0, 1]`` after every iteration).
    """
    # The text features are constants of the attack. Building them without
    # autograd means backward() never traverses the text encoder, which is
    # what used to force retain_graph=True on every iteration.
    with torch.no_grad():
        text_feat = get_text_feat(text, model, device)

    ori_image = image.clone().detach()  # same shape as `image`
    with torch.no_grad():
        _, _, ori_vs, _, ori_ks, _, _, _ = clip_encode_dense(
            processor(ori_image).unsqueeze(0).to(device), n=24)

    # Random start inside the [-alpha, alpha] box around the clean image.
    adv_image = image.clone().detach() + torch.from_numpy(
        np.random.uniform(-alpha, alpha, image.shape)).float()

    # Values used to localise the object; refreshed from each adversarial forward pass.
    adv_vs_cash = ori_vs
    for _ in range(num_iter):

        with torch.no_grad():
            _, mask_list, xmask_list = get_mask_and_feat(model, text_feat, adv_vs_cash, ori_ks, False)

        loss = 0

        input_mask = mask_list[0]
        input_xmask = xmask_list[0]

        cropped_x, cropped_mask, (crop_top, crop_left, crop_height, crop_width), _ = \
            random_crop_within_bounds(adv_image, input_mask, input_xmask)

        # Resize the crop to the encoder resolution and the mask to the 24x24 patch grid.
        x = F.interpolate(cropped_x.unsqueeze(0), size=(336, 336), mode='bilinear', align_corners=False).squeeze(0)
        mask = F.interpolate(cropped_mask.unsqueeze(0).unsqueeze(0), size=(24, 24), mode='nearest').squeeze().flatten()

        x = processor(x.to('cpu')).to(device).requires_grad_(True)
        _, _, adv_vs, _, _, _, _, _ = clip_encode_dense(x.unsqueeze(0), n=24)

        for j in range(20, 24):
            xx = adv_vs[j]
            TR = model.visual.transformer.resblocks[j]
            xx = F.linear(xx, TR.attn.out_proj.weight, TR.attn.out_proj.bias)
            xx = xx.permute(1, 0, 2)
            xx = model.visual.ln_post(xx)
            xx = xx @ model.visual.proj
            adv_feat_V = F.normalize(xx, dim=-1)
            patch_feats = adv_feat_V[:, 1:, :]  # drop the class token

            # semantic loss: pull masked patches towards text[1] ...
            semantic_loss = F.cosine_similarity(text_feat[:, 1, :].unsqueeze(0), patch_feats, dim=-1).squeeze()
            semantic_loss = semantic_loss * mask
            loss += torch.sum(semantic_loss)
            # ... and push them away from text[0].
            semantic_loss = - F.cosine_similarity(text_feat[:, 0, :].unsqueeze(0), patch_feats, dim=-1).squeeze()
            semantic_loss = semantic_loss * mask
            loss += torch.sum(semantic_loss)

        # Keep only detached values for the next mask computation so this
        # iteration's autograd graph can be released after backward().
        adv_vs_cash = [v.detach() for v in adv_vs]

        model.visual.zero_grad()
        loss.backward()

        with torch.no_grad():

            # Signed-gradient step, resized back to the crop window and placed
            # into a zero canvas of the full image size.
            original_noise = torch.zeros_like(adv_image)
            crop_noise = F.interpolate((alpha * x.grad.sign()).unsqueeze(0),
                                       size=(crop_height, crop_width),
                                       mode='bilinear',
                                       align_corners=False).squeeze(0)

            original_noise[:, crop_top:crop_top + crop_height, crop_left:crop_left + crop_width] = crop_noise
            adv_image = adv_image.to(device) + original_noise.to(device)
            # Project back into the epsilon ball and the valid pixel range.
            delta = torch.clamp(adv_image - ori_image.to(device), -epsilon, epsilon)
            adv_image = torch.clamp(ori_image.to(device) + delta, min=0, max=1).detach_()

    return ori_image, adv_image

In [None]:
import json
import time

json_path = 'coco300_object_descriptions_main.json'
image_dir = 'coco_dataset/val2017'
target_path = 'coco300_target_obj_sim.json'

# Directory the adversarial images are written to; an empty string means the
# current working directory.
adv_output_dir = ""

with open(json_path, 'r') as f:
    all_data = json.load(f)

with open(target_path, 'r') as f:
    tar_data = json.load(f)

# os.makedirs("") raises FileNotFoundError, so only create a directory when
# one is actually configured (and do it once, before the expensive attack).
if adv_output_dir:
    os.makedirs(adv_output_dir, exist_ok=True)

for i in range(0, 1):
    # Get text and image
    k = i
    print(k)
    item = all_data[k]
    image_path = item['image']
    captions = item['caption']

    # Entries look like "<object name>: <caption>". partition() keeps the whole
    # string as the name when no colon is present, instead of dropping its
    # last character as slicing with find() == -1 would.
    replacement = tar_data[k]['replace'][0]
    target, _sep, target_caption = replacement.partition(':')
    target = target.strip()
    target_caption = target_caption.strip()

    path_image = os.path.join(image_dir, image_path)
    img = Image.open(path_image).convert("RGB")

    name, _sep, caption = captions[0].partition(':')
    name = name.strip()
    caption = caption.strip()
    # [prompt to move away from, prompt to move towards]
    text = [name, target]

    image = processor_before(img)
    start_time = time.time()

    ori_image, adv_image = pgd_attack(model=clipmodel, processor=processor_after, image=image, text=text,
                                      epsilon=16/255, alpha=3/255, num_iter=200, device=device)

    elapsed_time = time.time() - start_time
    print(f"Time: {elapsed_time:.2f}s")

    # Save images
    # NOTE(review): the output keeps the source file name, so a ".jpg" input is
    # re-encoded by PIL as JPEG — confirm lossy compression is acceptable here.
    adv_output_path = os.path.join(adv_output_dir, f"{os.path.basename(path_image)}")
    save_image(adv_image, adv_output_path)
