In [3]:
import torch
import open_clip
from torchvision import transforms
import numpy as np
import os
from PIL import Image
from utils.load_data import load_dataset
from tqdm import tqdm
import time

start_time = time.time()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")

# Load the pre-trained CLIP model.
# Other (architecture, pretrained tag) pairs left by the author as options:
#   'ViT-B-16' with "laion400m_e32"
#   'ViT-B-32' with "openai"
victim = 'ViT-B-16-quickgelu'
pretrained = "openai"
model, _, transform = open_clip.create_model_and_transforms(
    victim,
    pretrained=pretrained,
)
model = model.to(device)
tokenizer = open_clip.get_tokenizer(victim)
# Alternative loader, commented out by the author:
# model, preprocess = clip.load(victim, device=device)

# Load the cross-modal dataset; names noted by the author: pascal, wikipedia.
dataset = 'pascal'
batch_size = 16
dataloaders = load_dataset(dataset, batch_size)
train_loader, test_loader = dataloaders['train'], dataloaders['test']

def patch_initialization(patch_type='rectangle', noise_percentage=0.03,
                         image_size=(3, 224, 224)):
    """Create a randomly initialised square patch.

    Args:
        patch_type: Shape of the patch. Only ``'rectangle'`` is implemented.
        noise_percentage: Fraction of the image area (height * width) that
            the patch should cover.
        image_size: ``(channels, height, width)`` of the target image.

    Returns:
        A ``numpy`` array of shape ``(channels, side, side)`` filled with
        uniform samples from ``[0, 1)``, where
        ``side = int(sqrt(noise_percentage * height * width))``.

    Raises:
        ValueError: If ``patch_type`` is not supported. (Previously this
            surfaced as an ``UnboundLocalError`` on the ``return``.)
    """
    if patch_type != 'rectangle':
        raise ValueError(
            f"Unsupported patch_type {patch_type!r}; expected 'rectangle'."
        )
    # Side length of the square whose area is the requested image fraction.
    mask_length = int((noise_percentage * image_size[1] * image_size[2])**0.5)
    return np.random.rand(image_size[0], mask_length, mask_length)
    
# adv_mask
def mask_generation(patch, image_size=(3, 224, 224), margin=14):
    """Place ``patch`` near the bottom-right corner of an empty image.

    Args:
        patch: Array of shape ``(channels, patch_h, patch_w)``.
        image_size: ``(channels, height, width)`` of the canvas.
        margin: Gap in pixels kept between the patch and the bottom and
            right image borders.

    Returns:
        ``(mask, applied_patch, x_location, y_location)`` where
        ``applied_patch`` is a zero canvas with ``patch`` pasted in, ``mask``
        is 1.0 over the pasted region and 0.0 elsewhere, and the two
        locations are the row and column offsets of the patch's top-left
        corner.

    Raises:
        ValueError: If the patch plus margin does not fit in the image.
    """
    patch_h, patch_w = patch.shape[1], patch.shape[2]
    x_location = image_size[1] - margin - patch_h
    # The column offset is measured against the image width, not its height.
    y_location = image_size[2] - margin - patch_w
    if x_location < 0 or y_location < 0:
        raise ValueError(
            f"Patch of size {patch_h}x{patch_w} with margin {margin} does not "
            f"fit in an image of size {image_size[1]}x{image_size[2]}."
        )

    rows = slice(x_location, x_location + patch_h)
    cols = slice(y_location, y_location + patch_w)

    applied_patch = np.zeros(image_size)
    applied_patch[:, rows, cols] = patch

    # Build the mask from the patch location rather than from non-zero patch
    # values, so patch entries that happen to be exactly 0 stay inside it.
    mask = np.zeros(image_size)
    mask[:, rows, cols] = 1.0
    return mask, applied_patch, x_location, y_location

# Create the initial random patch, paste it onto an empty canvas, and convert
# the resulting numpy mask / canvas into torch tensors.
patch = patch_initialization()
mask, applied_patch, x, y = mask_generation(patch)
applied_patch, mask = torch.from_numpy(applied_patch), torch.from_numpy(mask)

def clamp_patch(patch):
    """Clamp ``patch`` to the value range of a normalized image.

    The bounds are what pixel values 0 and 1 map to under
    ``(pixel - mean) / std`` with the per-channel statistics below: the
    smallest lower bound and the largest upper bound across the three
    channels are applied to the whole tensor.

    Args:
        patch: Tensor to clamp.

    Returns:
        The clamped tensor (``torch.clamp`` result; the input is not
        modified in place).
    """
    channel_mean = np.array((0.485, 0.456, 0.406))
    channel_std = np.array((0.229, 0.224, 0.225))
    # Normalized values of an all-zero and an all-one pixel, per channel.
    lower_bounds = (np.array([0, 0, 0]) - channel_mean) / channel_std
    upper_bounds = (np.array([1, 1, 1]) - channel_mean) / channel_std
    return torch.clamp(patch, min=lower_bounds.min(), max=upper_bounds.max())

start_time_2 = time.time()

from pathlib import Path
#uap_root = os.path.join('output', 'uap', 'gan_patch', 'ViT-B-16-quickgelu', str(dataset),str(0.03))
uap_root = os.path.join('output', 'uap', 'gan_patch', "ViT-B16", str(dataset),str(0.03))
# Checkpoints whose file name ends in "20.pt". Sorting makes the choice
# deterministic when several files match (directory listing order is
# arbitrary).
uap_candidates = sorted(
    ckpt for ckpt in Path(uap_root).iterdir() if ckpt.name.endswith("20.pt")
)
if not uap_candidates:
    raise FileNotFoundError(
        f"No checkpoint ending in '20.pt' found in {uap_root}"
    )
uap_path = uap_candidates[0]
# map_location lets a checkpoint saved on a GPU load on a CPU-only host.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; newer PyTorch releases accept weights_only=True for this.
uap = torch.load(uap_path, map_location=device)
print(uap_path)

# Re-create the random patch and its mask / canvas as torch tensors.
patch = patch_initialization()
NumberOfTriggers = 128 # [16,32,128,256,512]
mask_list = []
applied_patch_list = []
mask, applied_patch, x, y = mask_generation(patch)
applied_patch = torch.from_numpy(applied_patch)
mask = torch.from_numpy(mask)

def unnormalize(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo a per-channel ``(x - mean) / std`` normalization in place.

    Args:
        tensor: Tensor whose first dimension is iterated together with
            ``mean`` and ``std`` (one entry per channel). It is modified
            IN PLACE, so pass a clone to keep the original values.
        mean: Per-channel means used by the normalization.
        std: Per-channel standard deviations used by the normalization.

    Returns:
        The same ``tensor`` object, now holding ``x * std + mean``.
    """
    # Iterating a tensor yields views, so the in-place ops write through to
    # ``tensor`` itself. zip() stops at the shortest of the three inputs.
    for channel, channel_mean, channel_std in zip(tensor, mean, std):
        channel.mul_(channel_std).add_(channel_mean)
    return tensor

def save_image_as_jpg(pil_image, output_path):
    """Save a PIL image to ``output_path`` as a JPEG and print the path.

    Images whose mode the JPEG writer cannot encode (for example ``RGBA``,
    ``LA`` or palette ``P`` images, with or without transparency) are
    converted to ``RGB`` first; the other modes are saved unchanged.

    Args:
        pil_image: The ``PIL.Image.Image`` to save.
        output_path: Destination file path.
    """
    # Modes Pillow's JPEG plugin can write directly.
    jpeg_writable_modes = ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr')
    if pil_image.mode not in jpeg_writable_modes:
        pil_image = pil_image.convert('RGB')
    pil_image.save(output_path, 'JPEG')
    print(f"Image saved as JPEG at: {output_path}")

def get_basic_image_from_dataset(dataset, output_path):
    """Save the first image of the first batch of ``dataset`` as a JPEG.

    Args:
        dataset: Iterable of batches whose first element is the image batch.
            The images are assumed to be normalized with the default
            statistics of ``unnormalize`` -- TODO confirm against the loader.
        output_path: Where the JPEG is written.

    Nothing is written when ``dataset`` yields no batches.
    """
    # Pull exactly one batch; the previous loop fetched a second batch only to
    # break out of it.
    first_batch = next(iter(dataset), None)
    if first_batch is None:
        return
    x_batch = first_batch[0]

    with torch.no_grad():
        # Work on a clone because unnormalize() modifies its argument in place.
        first_image_tensor = unnormalize(x_batch[0].squeeze(0).clone())
        # ToPILImage scales floats by 255 and casts to uint8, so values
        # outside [0, 1] would wrap around instead of saturating.
        first_image_tensor.clamp_(0, 1)
        first_image_pil = transforms.ToPILImage()(first_image_tensor)

    save_image_as_jpg(first_image_pil, output_path)

# Save one reference image from the test set and reload it through PIL.
output_path = 'output/image_basic_pascal.jpg'
get_basic_image_from_dataset(test_loader,output_path)
img_basic = img = Image.open(output_path)

# Preprocess the reference image with the model's transform (batch of one).
image_input_A = transform(img_basic).unsqueeze(0).to(device)

# Output folders for the patched ("noisy") and the untouched ("clean") images.
output_dir = f'noisy_images_{dataset}_{victim}_Time'
os.makedirs(output_dir, exist_ok=True)
output_dir_clean = f'clean_images_{dataset}_{victim}'
os.makedirs(output_dir_clean, exist_ok=True)

import torch.nn.functional as F

# Tensor.to() returns a new tensor, so the result has to be re-assigned.
uap = uap.to(device)
total_correct, total_p_correct, total_fr, step = 0., 0., 0., 0.

save_count = 0
index = 0
to_pil = transforms.ToPILImage()
with torch.no_grad():
    # Loop-invariant operands of the blend, as float32 CPU tensors (the images
    # are written to disk, so the composition is done on the CPU).
    mask_cpu = mask.float().cpu()
    uap_cpu = uap.float().cpu()
    for counter, (x_batch, text_batch, y_batch, ids) in enumerate(test_loader):
        if save_count >= NumberOfTriggers:
            break
        # Flatten any extra leading axes into the batch axis. Unlike squeeze(),
        # this keeps the batch axis when a batch holds a single sample.
        data = x_batch.reshape(-1, *x_batch.shape[-3:]).float().cpu()

        # Paste the patch: patch pixels inside the mask, original pixels
        # outside it (mask and patch broadcast over the batch axis).
        image_adv = mask_cpu * uap_cpu + (1 - mask_cpu) * data

        # Iterate over the samples actually present; the last batch may hold
        # fewer than batch_size images.
        for img_tensor, adv_image_tensor in zip(data, image_adv):
            if save_count >= NumberOfTriggers:
                break
            index += 1

            # unnormalize() works in place, hence the clones. Clamp to [0, 1]
            # because ToPILImage casts float * 255 to uint8, which would wrap
            # out-of-range values instead of saturating them.
            img_tensor = unnormalize(img_tensor.clone()).clamp_(0, 1)
            original_image_pil = to_pil(img_tensor)
            path1 = os.path.join(output_dir_clean, f'clean_image_{index:04d}.jpg')
            save_image_as_jpg(original_image_pil,path1)

            adv_image_tensor = unnormalize(adv_image_tensor.clone()).clamp_(0, 1)
            adv_imagee_pil = to_pil(adv_image_tensor)
            path2 = os.path.join(output_dir, f'noisy_image_{index:04d}.jpg')
            save_image_as_jpg(adv_imagee_pil,path2)
            save_count += 1

end_time = time.time()
# Elapsed time since the checkpoint-loading step (start_time_2), not since the
# start of the script.
total_time = end_time - start_time_2
# total_time = end_time - start_time

print("total_time = ", total_time)

  uap = torch.load(uap_path)


output/uap/gan_patch/ViT-B16/pascal/0.03/uap_gan_98.23_20.pt
Image saved as JPEG at: output/image_basic_pascal.jpg
Image saved as JPEG at: clean_images_pascal_ViT-B-16-quickgelu/clean_image_0001.jpg
Image saved as JPEG at: noisy_images_pascal_ViT-B-16-quickgelu_Time/noisy_image_0001.jpg
Image saved as JPEG at: clean_images_pascal_ViT-B-16-quickgelu/clean_image_0002.jpg
Image saved as JPEG at: noisy_images_pascal_ViT-B-16-quickgelu_Time/noisy_image_0002.jpg
Image saved as JPEG at: clean_images_pascal_ViT-B-16-quickgelu/clean_image_0003.jpg
Image saved as JPEG at: noisy_images_pascal_ViT-B-16-quickgelu_Time/noisy_image_0003.jpg
Image saved as JPEG at: clean_images_pascal_ViT-B-16-quickgelu/clean_image_0004.jpg
Image saved as JPEG at: noisy_images_pascal_ViT-B-16-quickgelu_Time/noisy_image_0004.jpg
Image saved as JPEG at: clean_images_pascal_ViT-B-16-quickgelu/clean_image_0005.jpg
Image saved as JPEG at: noisy_images_pascal_ViT-B-16-quickgelu_Time/noisy_image_0005.jpg
Image saved as JPEG 