In [None]:
import torch
import open_clip
from utils.load_data import load_dataset
import os
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from torch import nn, optim
import numpy as np
from PIL import Image
from tqdm import tqdm  # 进度条显示
from torchvision import transforms
import pandas as pd
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import time

start_time = time.time()

USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda") if USE_CUDA else torch.device("cpu")

# Load the pre-trained OpenCLIP model named by `victim` with the `pretrained` weights tag.
victim = 'ViT-B-16-quickgelu'
pretrained = "openai"
model, _, transform = open_clip.create_model_and_transforms(victim, pretrained=pretrained)
# nn.Module.to() returns the module itself, so moving and switching to eval mode can be chained.
model = model.to(device).eval()
tokenizer = open_clip.get_tokenizer(victim)


# from torchvision import models
# class StudentImageEncoder(nn.Module):
#     def __init__(self, backbone='mobilenet_v3_small', embed_dim=512):
#         super().__init__()
#         if backbone == 'mobilenet_v3_small':
#             self.backbone = models.mobilenet_v3_small(pretrained=True)
#             in_features = self.backbone.classifier[-1].in_features
#             self.backbone.classifier[-1] = nn.Linear(in_features, embed_dim)
#         elif backbone == 'efficientnet_b0':
#             self.backbone = models.efficientnet_b0(pretrained=True)
#             in_features = self.backbone.classifier[-1].in_features
#             self.backbone.classifier[-1] = nn.Linear(in_features, embed_dim)
#         elif backbone == 'resnet18':
#             self.backbone = models.resnet18(pretrained=True)
#             self.backbone.fc = nn.Linear(512, embed_dim)
#         else:
#             raise ValueError("Unsupported backbone")
#         self.embed_dim = embed_dim

#     def forward(self, x):
#         return self.backbone(x)
    
# student = StudentImageEncoder(backbone='mobilenet_v3_small').train().to(device)
# path_st = "student_flickr30k_clip.pth"
# student.load_state_dict(torch.load(path_st, map_location=device))
# student.eval()

from transformers import CLIPProcessor, CLIPModel

# Local TinyCLIP checkpoint directory. Model and processor are loaded from the same
# place, so keep the path in one constant instead of duplicating the literal.
TINYCLIP_PATH = '/root/autodl-tmp/AdvCLIP/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M'
model_KD = CLIPModel.from_pretrained(TINYCLIP_PATH)
processor_KD = CLIPProcessor.from_pretrained(TINYCLIP_PATH)

# This script only runs inference with model_KD. Setting eval mode explicitly keeps
# that intent visible and independent of library defaults.
model_KD = model_KD.to(device).eval()


class SimpleMLP(nn.Module):
    """Two-layer perceptron: Linear(input_dim -> 256) -> ReLU -> Linear(256 -> output_dim).

    The submodule layout (``fc.0`` and ``fc.2`` inside an ``nn.Sequential`` named ``fc``)
    must stay as-is, because this script loads a saved state_dict into the model.
    """

    def __init__(self, input_dim=512, output_dim=512):
        super().__init__()
        hidden = 256
        layers = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        out = self.fc(x)
        return out

class ContrastiveLoss(nn.Module):
    """Symmetric cross-entropy contrastive loss over a batch of paired embeddings.

    Row ``k`` of ``image_embeddings`` is the positive match for row ``k`` of
    ``text_embeddings``. All other rows in the batch act as negatives. The loss is the
    mean of the image->text and text->image cross-entropies over cosine-similarity
    logits scaled by ``1 / temperature``. Inputs are expected to be 2-D ``(N, D)``.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.cosine_similarity = nn.CosineSimilarity(dim=2)

    def forward(self, image_embeddings, text_embeddings):
        # (N, 1, D) against (1, N, D) broadcasts to an (N, N) similarity matrix.
        pairwise = self.cosine_similarity(image_embeddings[:, None], text_embeddings[None, :])
        logits = pairwise / self.temperature
        targets = torch.arange(len(image_embeddings), device=image_embeddings.device)
        image_to_text = nn.functional.cross_entropy(logits, targets)
        text_to_image = nn.functional.cross_entropy(logits.t(), targets)
        return (image_to_text + text_to_image) / 2
    
# load cross-modal dataset
dataset = 'pascal'
batch_size = 16
dataloaders = load_dataset(dataset, batch_size)
# train_loader = dataloaders['train']
test_loader = dataloaders['test']

# Load the trained SimpleMLP weights; the model is only used for evaluation here.
# (SimpleMLP has no dropout/batch-norm, and eval() below fixes the mode, so no .train() call.)
mlp_model = SimpleMLP().to(device)
path = "output/Module/mlp_pascal_ViT-B-16-quickgelu_200_100.pth"
mlp_model.load_state_dict(torch.load(path, map_location=device))
mlp_model.eval()

from pathlib import Path
uap_root = os.path.join('output', 'uap', 'gan_patch', "ViT-B16", str(dataset), str(0.03))
# os.listdir order is arbitrary: sort so the same checkpoint is chosen on every run,
# and fail with a clear message instead of a bare IndexError when nothing matches.
uap_candidates = sorted(
    Path(uap_root) / ckpt for ckpt in os.listdir(uap_root) if ckpt.endswith("20.pt")
)
if not uap_candidates:
    raise FileNotFoundError(f"No UAP checkpoint ending in '20.pt' found under {uap_root}")
uap_path = uap_candidates[0]
# map_location='cpu' lets a checkpoint saved from a GPU run load on CPU-only machines.
uap = torch.load(uap_path, map_location="cpu")


def patch_initialization(patch_type='rectangle', noise_percentage=0.03, image_size=(3, 224, 224)):
    """Create a random square patch covering ``noise_percentage`` of the image area.

    Args:
        patch_type: Patch shape. Only ``'rectangle'`` is supported; it yields a square patch.
        noise_percentage: Fraction of ``H * W`` the patch should cover.
        image_size: ``(C, H, W)`` of the images the patch is meant for.

    Returns:
        np.ndarray of shape ``(C, L, L)`` with values drawn uniformly from [0, 1), where
        ``L = int(sqrt(noise_percentage * H * W))``.

    Raises:
        ValueError: If ``patch_type`` is not supported.
    """
    if patch_type != 'rectangle':
        raise ValueError(f"Unsupported patch_type: {patch_type!r}")
    channels, height, width = image_size
    mask_length = int((noise_percentage * height * width) ** 0.5)
    return np.random.rand(channels, mask_length, mask_length)

def mask_generation(patch, image_size=(3, 224, 224), margin=14):
    """Paste ``patch`` near the bottom-right corner of a blank canvas and build its mask.

    Args:
        patch: Array of shape ``(C, h, w)``.
        image_size: ``(C, H, W)`` of the canvas.
        margin: Pixels left between the patch and the bottom and right edges.

    Returns:
        Tuple ``(mask, applied_patch, x_location, y_location)``:
        - ``mask`` is a float64 array of ``image_size`` that is 1.0 inside the patch
          region and 0.0 elsewhere;
        - ``applied_patch`` is a float64 zero canvas with ``patch`` pasted in;
        - ``x_location`` and ``y_location`` are the row (axis 1) and column (axis 2) of
          the patch's top-left corner.

    Raises:
        ValueError: If the patch plus margin does not fit inside the canvas.
    """
    patch_h, patch_w = patch.shape[1], patch.shape[2]
    x_location = image_size[1] - margin - patch_h
    # Column offset is computed from the canvas width (axis 2), not its height.
    y_location = image_size[2] - margin - patch_w
    if x_location < 0 or y_location < 0:
        # Negative offsets would silently wrap around via Python slicing.
        raise ValueError(
            f"Patch of size {patch_h}x{patch_w} with margin {margin} does not fit in {image_size}"
        )
    rows = slice(x_location, x_location + patch_h)
    cols = slice(y_location, y_location + patch_w)

    applied_patch = np.zeros(image_size)
    applied_patch[:, rows, cols] = patch

    # Build the mask from the region itself, not from pixel values: a patch pixel that
    # is exactly 0 would otherwise be left out of the mask.
    mask = np.zeros(image_size)
    mask[:, rows, cols] = 1.0
    return mask, applied_patch, x_location, y_location

patch = patch_initialization()
mask_np, applied_patch_np, x, y = mask_generation(patch)
# Wrap the NumPy arrays as torch tensors (torch.from_numpy keeps the float64 dtype
# and shares memory with the source arrays).
mask = torch.from_numpy(mask_np)
applied_patch = torch.from_numpy(applied_patch_np)

def cal_sim(vector_0, vector_1):
    """Compute row-wise cosine similarity and Euclidean (p=2) pairwise distance.

    :param vector_0: tensor whose rows are compared along dim 1
    :param vector_1: tensor broadcastable against ``vector_0``
    :return: ``(cos_sim, pair_dis)``. Cosine similarity uses eps=1e-6; the distance
             uses ``torch.nn.PairwiseDistance`` defaults (eps=1e-6, keepdim=False).
    """
    functional = torch.nn.functional
    cos_sim = functional.cosine_similarity(vector_0, vector_1, dim=1, eps=1e-6)
    pair_dis = functional.pairwise_distance(vector_0, vector_1, p=2)
    return cos_sim, pair_dis

start_time_2 = time.time()

import torch.nn.functional as F
Size_Trigger = 128
# batch_size = 1


def _l2_normalize(features):
    """Return ``features`` scaled to unit L2 norm along the last dimension."""
    return features / features.norm(dim=-1, keepdim=True)


def evaluate_Verify(model, clip, Trigger_mat, test_loader, device):
    """Compare clean vs. patched image/text similarity before and after ``Trigger_mat``.

    For the first ``Size_Trigger // batch_size`` batches of ``test_loader``:
    1. Paste the module-level ``uap`` into the region selected by the module-level ``mask``.
    2. Encode the clean images, the patched images and the texts with ``model``, then
       L2-normalise the features.
    3. Record cosine similarity / pairwise distance (via ``cal_sim``) of clean and
       patched image features against the text features.
    4. Pass the normalised features through ``Trigger_mat``, re-normalise, and record
       the patched-image similarities again.

    Args:
        model: Encoder exposing ``get_image_features`` / ``get_text_features``
            (this script passes the Hugging Face TinyCLIP ``model_KD``).
        clip: Unused; kept so existing callers keep working.
        Trigger_mat: Module applied to the normalised features (this script passes the MLP).
        test_loader: Iterable yielding ``(images, texts, inds, IDs)`` batches.
        device: torch device the encoders live on.

    Returns:
        Six lists, each holding one per-sample list per processed batch:
        clean cos-sim, clean pair-dist, patched cos-sim, patched pair-dist (raw
        features), then patched cos-sim and patched pair-dist after ``Trigger_mat``.
    """
    list_origin_cos_sim = []
    list_origin_pair_dis = []
    list_Trigger_cos_sim = []
    list_Trigger_pair_dis = []
    list_cos_sim_Trigger = []
    list_pair_dis_Trigger = []

    # Integer form of the original float test `i > Size_Trigger / batch_size - 1`
    # (same stopping batch); avoids shadowing the builtin `round`.
    num_batches = Size_Trigger // batch_size

    with torch.no_grad():
        # Loop-invariant: convert the mask and UAP once and keep them on `device`,
        # instead of re-casting and round-tripping through the CPU every batch.
        mask_f = mask.float().to(device)
        uap_f = uap.float().to(device)

        for i, (batch_images, batch_texts, _inds, _ids) in enumerate(test_loader):
            if i >= num_batches:
                break
            # NOTE(review): .squeeze() drops every size-1 dim, including the batch axis
            # for a single-sample batch; confirm the loader's tensor shapes.
            batch_images = batch_images.squeeze().to(device)
            batch_texts_tok = batch_texts.squeeze().to(device)

            # UAP inside the mask region, clean image elsewhere.
            p_data = mask_f * uap_f + (1 - mask_f) * batch_images.float()

            # Features are normalised before Trigger_mat is applied, exactly as the
            # previous in-place normalisation through aliases did.
            image_features = _l2_normalize(model.get_image_features(batch_images))
            image_features_T = _l2_normalize(model.get_image_features(p_data))
            text_features = _l2_normalize(model.get_text_features(batch_texts_tok))

            origin_cos_sim, origin_pair_dis = cal_sim(image_features, text_features)
            Trigger_cos_sim, Trigger_pair_dis = cal_sim(image_features_T, text_features)

            list_origin_cos_sim.append(origin_cos_sim.cpu().tolist())
            list_origin_pair_dis.append(origin_pair_dis.cpu().tolist())
            list_Trigger_cos_sim.append(Trigger_cos_sim.cpu().tolist())
            list_Trigger_pair_dis.append(Trigger_pair_dis.cpu().tolist())

            print(f'CLIP Model = {victim}')
            print(f'Trigger = {uap_path}')
            print(f'Module = {path}')
            print("after CLIP")
            print("Origin: cos similarity: %lf, pair distance: %lf" % (float(origin_cos_sim.mean()), float(origin_pair_dis.mean())))
            print("Trigger_mat: cos similarity: %lf, pair distance: %lf" % (float(Trigger_cos_sim.mean()), float(Trigger_pair_dis.mean())))

            image_features = _l2_normalize(Trigger_mat(image_features))
            image_features_T = _l2_normalize(Trigger_mat(image_features_T))
            text_features = _l2_normalize(Trigger_mat(text_features))

            cos_sim_origin, pair_dis_origin = cal_sim(image_features, text_features)
            cos_sim_Trigger, pair_dis_Trigger = cal_sim(image_features_T, text_features)

            list_cos_sim_Trigger.append(cos_sim_Trigger.cpu().tolist())
            list_pair_dis_Trigger.append(pair_dis_Trigger.cpu().tolist())

            print("after Module")
            print("Origin: cos similarity: %lf, pair distance: %lf" % (float(cos_sim_origin.mean()), float(pair_dis_origin.mean())))
            print("Trigger_mat: cos similarity: %lf, pair distance: %lf" % (float(cos_sim_Trigger.mean()), float(pair_dis_Trigger.mean())))
            print("\n")

    return (
        list_origin_cos_sim,
        list_origin_pair_dis,
        list_Trigger_cos_sim,
        list_Trigger_pair_dis,
        list_cos_sim_Trigger,
        list_pair_dis_Trigger,
    )

results = evaluate_Verify(
    model=model_KD,
    clip=model,
    Trigger_mat=mlp_model,
    test_loader=test_loader,
    device=device,
)
list_A0, list_B0, list_A1, list_B1, list_A2, list_B2 = results

# Wall-clock time of the evaluation phase only (measured from start_time_2).
end_time = time.time()
total_time = end_time - start_time_2

print("total_time = ", total_time)