In [None]:
from torch.optim import Adam
from torch import nn
import torch


class CustomViT(nn.Module):
    """Adapter that turns a ``(logits, attention_list)`` backbone output into a dict.

    The wrapped backbone is expected to return a pair whose second element is an
    indexable sequence of per-block attention maps; only the last entry is kept.
    """

    def __init__(self, model):
        super().__init__()
        # Backbone assumed to return (logits, list of attention weights) — see forward().
        self.model = model

    def forward(self, x):
        """Run the backbone and return ``{'logits': ..., 'attn_weights': <last attention map>}``."""
        logits, attention_maps = self.model(x)
        # Expose only the final attention block.
        return dict(logits=logits, attn_weights=attention_maps[-1])
    
def calculate_attn_loss(outputs, patch_indices):
    """Return the negated mean attention weight at the selected positions.

    ``outputs['attn_weights']`` is indexed with ``patch_indices`` along dimension 2,
    so minimising the returned scalar increases those attention entries.

    NOTE(review): if the weights are laid out as (batch, heads, queries, keys) with a
    leading CLS token, dimension 2 is the *query* axis and patch ``k`` sits at token
    ``k + 1``; confirm this matches the intended objective (attention *towards* the
    perturbed patches).
    """
    selected = outputs['attn_weights'][:, :, patch_indices]
    return selected.mean().neg()

def patch_fool_attack(model, images, labels, num_patches=1, max_iter=250, alpha=0.002, eta=0.2,
                      decay_rate=0.95, decay_step=10, patch_size=None, clip_min=0.0, clip_max=1.0):
    """Perturb a few image patches to mislead a ViT (Patch-Fool-style attack).

    A fixed set of ``num_patches`` distinct patches is drawn at random once per call.
    A perturbation restricted to those patches is then optimised with Adam so that the
    cross-entropy on ``labels`` goes *up* while ``calculate_attn_loss`` goes *down*.
    The perturbation magnitude itself is not bounded; only the resulting pixels are
    kept inside ``[clip_min, clip_max]``.

    Args:
        model: Backbone returning ``(logits, attention_list)``; wrapped in ``CustomViT``.
        images: 4-D batch indexed as (batch, channel, height, width). The spatial size
            is assumed square (only ``images.shape[2]`` is used to build the patch grid).
        labels: Ground-truth class indices for ``images``.
        num_patches: Number of distinct patches to perturb (capped at the grid size).
        max_iter: Number of optimisation steps.
        alpha: Weight of the attention term relative to the cross-entropy term.
        eta: Initial Adam learning rate.
        decay_rate: Multiplicative learning-rate decay factor ...
        decay_step: ... applied every ``decay_step`` iterations.
        patch_size: Patch side in pixels; defaults to ``images.shape[2] // 14`` (a 14x14
            grid, i.e. 16 px for 224x224 inputs), which is the original behaviour.
        clip_min, clip_max: Valid pixel range. The defaults match ``transforms.ToTensor``
            output; pass ``None`` for both to disable clipping.

    Returns:
        The detached adversarial batch, on CUDA if available, otherwise on CPU.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    wrapped = CustomViT(model).to(device)
    images, labels = images.to(device), labels.to(device)

    def _clip(t):
        # Keep pixels in the valid range; no-op when clipping is disabled.
        if clip_min is None and clip_max is None:
            return t
        return torch.clamp(t, min=clip_min, max=clip_max)

    side = images.shape[2]
    if patch_size is None:
        patch_size = side // 14
    grid = side // patch_size
    total_patches = grid ** 2

    # Choose the attacked patches ONCE (distinct indices), so the perturbation is
    # optimised for exactly the region that is returned at the end.
    patch_indices = torch.randperm(total_patches, device=device)[:num_patches]
    mask = torch.zeros_like(images)
    for flat_idx in patch_indices.tolist():
        row, col = divmod(flat_idx, grid)
        mask[:, :, row * patch_size:(row + 1) * patch_size, col * patch_size:(col + 1) * patch_size] = 1

    perturbation = torch.zeros_like(images, requires_grad=True)
    optimizer = Adam([perturbation], lr=eta)
    criterion = nn.CrossEntropyLoss()

    for i in range(max_iter):
        # Step-decay the learning rate every `decay_step` iterations.
        if (i + 1) % decay_step == 0:
            for group in optimizer.param_groups:
                group['lr'] *= decay_rate

        outputs = wrapped(images + perturbation * mask)
        ce_loss = criterion(outputs['logits'], labels)
        attn_loss = calculate_attn_loss(outputs, patch_indices)
        # The optimiser minimises this objective. Negating ce_loss pushes predictions away
        # from the true labels (minimising +ce_loss would make the model MORE correct).
        # alpha * attn_loss raises the attention entries selected by calculate_attn_loss.
        total_loss = -ce_loss + alpha * attn_loss

        # Differentiate w.r.t. the perturbation only, so gradients are not accumulated
        # into the model's parameters on every iteration/call.
        (grad,) = torch.autograd.grad(total_loss, perturbation)
        perturbation.grad = grad
        optimizer.step()

        with torch.no_grad():
            # Project so that images + perturbation stays a valid image.
            perturbation.copy_(_clip(images + perturbation) - images)

    with torch.no_grad():
        return _clip(images + perturbation * mask)



In [None]:
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import timm
from tqdm import tqdm
import torch.nn.functional as F
import shutil  # 导入文件操作库
from timm.models import create_model
from attack_methods import fgsm_attack, pgd_attack
#################
import sys
sys.path.append('/media/ruanjiacheng/新加卷/ecodes/Prompt/CV/GIST_ALL/')
from models import vision_transformer_att
#################



class CustomImageDataset(Dataset):
    """Image dataset described by an annotation file of ``<relative path> <integer label>`` lines.

    Paths are resolved against ``img_dir``. Blank lines are ignored, and the label is taken
    from the last whitespace-separated field, so paths may contain spaces.
    ``__getitem__`` returns ``(image, label, idx)`` and records the image's pre-transform
    PIL size in ``self.original_size[idx]``.
    """

    def __init__(self, annotations_file, img_dir, transform=None):
        self.img_labels = []
        self.img_dir = img_dir
        self.transform = transform
        # Index -> original PIL size (width, height), filled in by __getitem__.
        # NOTE: only visible to the caller when the DataLoader loads in-process (num_workers=0).
        self.original_size = {}
        with open(annotations_file, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                stripped = line.strip()
                if not stripped:
                    continue  # tolerate blank / trailing empty lines
                parts = stripped.rsplit(maxsplit=1)
                if len(parts) != 2:
                    raise ValueError(
                        f'{annotations_file}:{line_no}: expected "<path> <label>", got {stripped!r}')
                path, label = parts
                self.img_labels.append((path, int(label)))

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        rel_path, label = self.img_labels[idx]
        img_path = os.path.join(self.img_dir, rel_path)
        # Close the file handle promptly; convert() yields a new image object that
        # stays usable after the source file is closed.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')
        self.original_size[idx] = image.size  # size before any transform
        if self.transform:
            image = self.transform(image)
        return image, label, idx  # idx lets the caller look up original_size later
    

def load_model_for_dataset(model_name, tuning_mode, tuning_coeff, dataset_name, num_classes,
                           weight_root_dir=None, device=None):
    """Create ``model_name`` and load the checkpoint fine-tuned on one dataset.

    The checkpoint is read from ``<weight_root_dir>/<dataset_name>/model_best.pth.tar``.
    ``tuning_mode`` and ``tuning_coeff`` are forwarded unchanged to ``create_model``.

    Args:
        device: Target device. When ``None`` (default) the model is moved with ``.cuda()``.

    Raises:
        ValueError: if ``weight_root_dir`` is not given. Without it the checkpoint path
            cannot be built, and ``os.path.join(None, ...)`` would fail with an opaque TypeError.
    """
    if weight_root_dir is None:
        raise ValueError('weight_root_dir is required to locate the fine-tuned checkpoint')
    weight_path = os.path.join(weight_root_dir, dataset_name, 'model_best.pth.tar')
    model = create_model(
        model_name,
        pretrained=False,
        num_classes=num_classes,
        scriptable=True,
        checkpoint_path=weight_path,
        tuning_mode=tuning_mode,
        tuning_coeff=tuning_coeff)
    if device is None:
        return model.cuda()
    return model.to(device)


def _run_attack(model, images, labels, attck_method, attack_settings):
    """Apply the attack named ``attck_method`` to one batch and return the adversarial images."""
    if attck_method == 'fgsm':
        return fgsm_attack(model, images, labels,
                           epsilon=attack_settings['eps'])
    if attck_method == 'pgd':
        return pgd_attack(model, images, labels,
                          epsilon=attack_settings['eps'],
                          alpha=attack_settings['alpha'],
                          iters=attack_settings['iters'])
    if attck_method == 'patchfool':
        return patch_fool_attack(model, images, labels,
                                 num_patches=attack_settings['num_patches'],
                                 max_iter=attack_settings['max_iter'],
                                 alpha=attack_settings['alpha'],
                                 eta=attack_settings['eta'],
                                 decay_rate=attack_settings['decay_rate'],
                                 decay_step=attack_settings['decay_step'])
    raise ValueError(f'Unsupported attack method: {attck_method!r}')


def generate_adversarial_samples(source_root_dir, weight_root_dir, data_path_names, data_weights_names,
                                 dataset_classes,
                                 txt_files, adv_file_name, model_name, tuning_mode, tuning_coeff, transform,
                                 attck_method, attack_settings):
    """Attack every image listed in the annotation files and save the adversarial versions.

    For dataset ``data_path_names[i]``, the model fine-tuned under ``data_weights_names[i]``
    (with ``dataset_classes[i]`` outputs) is loaded from ``weight_root_dir``. Each file in
    ``txt_files`` is read from ``<source_root_dir>/<dataset>/``. Every image is attacked with
    ``attck_method`` / ``attack_settings``, resized back to its original size, and written to
    ``<source_root_dir>/<dataset>/images/<adv_file_name>/<basename of image path>``.

    Raises:
        ValueError: if ``attck_method`` is not 'fgsm', 'pgd' or 'patchfool', or if the three
            per-dataset sequences differ in length.
    """
    supported_methods = ('fgsm', 'pgd', 'patchfool')
    if attck_method not in supported_methods:
        # Fail fast, before any model is loaded, instead of hitting an undefined
        # adv_images inside the batch loop.
        raise ValueError(f'Unsupported attack method: {attck_method!r}; expected one of {supported_methods}')
    if not (len(data_path_names) == len(data_weights_names) == len(dataset_classes)):
        raise ValueError('data_path_names, data_weights_names and dataset_classes must have the same length')

    to_pil = transforms.ToPILImage()

    for dataset_name, data_weight_name, num_classes in zip(data_path_names, data_weights_names, dataset_classes):
        print(dataset_name)
        # Load the model fine-tuned for this particular dataset.
        model = load_model_for_dataset(model_name, tuning_mode, tuning_coeff, data_weight_name, num_classes, weight_root_dir)
        model.eval()

        for txt_file in txt_files:
            print(txt_file)
            annotations_file = os.path.join(source_root_dir, dataset_name, txt_file)
            img_dir = os.path.join(source_root_dir, dataset_name)
            save_dir = os.path.join(source_root_dir, dataset_name, 'images', adv_file_name)
            os.makedirs(save_dir, exist_ok=True)

            dataset = CustomImageDataset(
                annotations_file=annotations_file,
                img_dir=img_dir,
                transform=transform
            )
            # num_workers stays at the default (0): CustomImageDataset records original sizes
            # inside __getitem__, which would not reach this process from worker processes.
            dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
            for images, labels, idxs in tqdm(dataloader):
                adv_images = _run_attack(model, images, labels, attck_method, attack_settings)
                # ToPILImage scales floats by 255 and casts to uint8, so clamp to [0, 1]
                # (the ToTensor range) to avoid corrupted pixels from unbounded attacks.
                adv_images = adv_images.detach().cpu().clamp(0.0, 1.0)

                for img, idx in zip(adv_images, idxs):
                    sample_idx = idx.item()
                    original_size = dataset.original_size[sample_idx]
                    img_pil = to_pil(img).resize(original_size)

                    img_path = dataset.img_labels[sample_idx][0]
                    # NOTE(review): only the basename is kept, so images with the same file name
                    # in different sub-folders overwrite each other.
                    save_path = os.path.join(save_dir, os.path.basename(img_path))
                    img_pil.save(save_path)

    print("完成对抗样本的生成和保存。")






In [None]:
import time

start_time = time.time()  # wall-clock start of the whole run

# Root folder containing one sub-folder per VTAB-1k dataset.
source_root_dir = '/media/ruanjiacheng/新加卷/ecodes/Prompt/data/vtab-1k'

# One row per dataset: (data folder name, checkpoint folder name, number of classes).
# Keeping the three values on one row prevents the parallel sequences from drifting apart.
_DATASET_TABLE = (
    ("caltech101", "caltech101", 102),
    ("cifar", "cifar100", 100),
    ("clevr_count", "clevr_count", 8),
    ("clevr_dist", "clevr_dist", 6),
    ("diabetic_retinopathy", "diabetic_retinopathy", 5),
    ("dmlab", "dmlab", 6),
    ("dsprites_loc", "dsprites_loc", 16),
    ("dsprites_ori", "dsprites_ori", 16),
    ("dtd", "dtd", 47),
    ("eurosat", "eurosat", 10),
    ("oxford_flowers102", "flowers102", 102),
    ("kitti", "kitti", 4),
    ("patch_camelyon", "patch_camelyon", 2),
    ("oxford_iiit_pet", "pets", 37),
    ("resisc45", "resisc45", 45),
    ("smallnorb_azi", "smallnorb_azi", 18),
    ("smallnorb_ele", "smallnorb_ele", 9),
    ("sun397", "sun397", 397),
    ("svhn", "svhn", 10),
)
data_path_names, data_weights_names, dataset_classes = (tuple(column) for column in zip(*_DATASET_TABLE))

# Annotation files processed inside every dataset folder.
# Full set: ['test.txt', 'train800.txt', 'train800val200.txt', 'val200.txt']
txt_files = ['test_adv_500.txt']

model_name = 'vit_base_patch16_224_in21k'
# NOTE(review): tuning_mode is a one-element *list*. create_model therefore receives a
# list, and adv_file_name renders as "['linear_prob']_0_adv_patchfool". The checkpoint
# folder below is spelled "[linear_probe]". Confirm whether the string 'linear_prob'
# was intended.
tuning_mode = ['linear_prob']
tuning_coeff = 0
weight_root_dir = '/media/ruanjiacheng/新加卷/ecodes/Prompt/CV/GIST_ALL/outputs_adv/[linear_probe]_0'

# Resize to 224x224 and convert to a tensor; no augmentation or normalisation.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

attck_method = 'patchfool'
attack_settings = dict(
    num_patches=4,
    max_iter=250,
    alpha=0.002,
    eta=0.2,
    decay_rate=0.95,
    decay_step=10,
)

adv_file_name = f'{tuning_mode}_{tuning_coeff}_adv_{attck_method}'

generate_adversarial_samples(
    source_root_dir=source_root_dir,
    weight_root_dir=weight_root_dir,
    data_path_names=data_path_names,
    data_weights_names=data_weights_names,
    dataset_classes=dataset_classes,
    txt_files=txt_files,
    adv_file_name=adv_file_name,
    model_name=model_name,
    tuning_mode=tuning_mode,
    tuning_coeff=tuning_coeff,
    transform=transform,
    attck_method=attck_method,
    attack_settings=attack_settings,
)

end_time = time.time()  # wall-clock end of the whole run
print(f"执行时间：{end_time - start_time} 秒")