In [55]:
import torch
from PIL import Image
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
import warnings
import torchvision.transforms as transforms

In [None]:
# Silence every UserWarning for the rest of the kernel session.
warnings.filterwarnings("ignore", category=UserWarning)

# Cache populated by initialize_model(): holds the model and its eval-time transform.
model_dict = dict()

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def initialize_model():
    """Build the ViT-H-14 backbone once and cache it in ``model_dict``.

    Does nothing when ``model_dict`` is already populated. Otherwise stores
    the created model under ``'model'`` and the validation transform under
    ``'preprocess_val'``; the training transform is discarded.
    """
    if model_dict:
        # Already built by an earlier call.
        return

    built_model, _train_transform, val_transform = create_model_and_transforms(
        'ViT-H-14',
        'laion2B-s32B-b79K',
        precision='amp',
        device=device,
        jit=False,
        force_quick_gelu=False,
        force_custom_text=False,
        force_patch_dropout=False,
        force_image_size=None,
        pretrained_image=False,
        image_mean=None,
        image_std=None,
        light_augmentation=True,
        aug_cfg={},
        output_dict=True,
        with_score_predictor=False,
        with_region_predictor=False,
    )
    model_dict.update(model=built_model, preprocess_val=val_transform)
        
initialize_model()
model = model_dict['model']
preprocess_val = model_dict['preprocess_val']

# Replace the LAION-pretrained weights with the HPS v2 fine-tuned checkpoint.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load("HPS_v2_compressed.pt", map_location=device)
model.load_state_dict(checkpoint['state_dict'])
# load_state_dict copies the tensors into the model; drop the dict so its
# second copy of the weights does not stay resident on `device`.
del checkpoint
tokenizer = get_tokenizer('ViT-H-14')
model = model.to(device)
# The model is only used for scoring and for input-gradient attacks:
# make sure it is in eval mode, and freeze the weights so backward passes
# do not allocate gradients for every parameter of the ViT-H backbone.
model.eval()
model.requires_grad_(False)

In [None]:
def preprocess_data(file_path: str, prompt: str):
    """Load an image and tokenize a prompt for HPS scoring.

    Args:
        file_path: path to the image file on disk.
        prompt: text prompt the image is scored against.

    Returns:
        (image, text): the image after ``preprocess_val`` with a leading batch
        dimension of 1, and the tokenized prompt; both moved to ``device``.
    """
    # Context manager closes the file handle once the transform has read the pixels.
    with Image.open(file_path) as pil_image:
        image = preprocess_val(pil_image).unsqueeze(0)
    image = image.to(device=device, non_blocking=True)
    text = tokenizer([prompt]).to(device=device, non_blocking=True)
    return image, text

In [None]:
def calculateHPS(image, text, hps_model=None):
    """Compute, print and return the HPS score of each (image, prompt) pair.

    Args:
        image: preprocessed image batch (e.g. from ``preprocess_data``).
        text: tokenized prompt batch, one prompt per image.
        hps_model: model to score with; defaults to the notebook-level ``model``.
            It must return a dict with ``image_features``, ``text_features``
            and ``logit_scale``.

    Returns:
        numpy array with one score per pair: the diagonal of
        ``logit_scale * text_features @ image_features.T``.
    """
    if hps_model is None:
        hps_model = model
    with torch.no_grad():
        print("Calculating HPS...")
        # Mixed precision only on CUDA tensors; a hard-coded autocast("cuda")
        # emits a warning (and does nothing) when running on the CPU.
        use_amp = image.device.type == "cuda"
        with torch.autocast(device_type=image.device.type, enabled=use_amp):
            # Same scoring formula as HPS v2's train.py.
            output = hps_model(image, text)
            image_features = output["image_features"]
            text_features = output["text_features"]
            logit_scale = output["logit_scale"]
            logits_per_text = logit_scale * text_features @ image_features.T
            hps_score = torch.diagonal(logits_per_text).cpu().numpy()

    print(hps_score)
    # Return the scores so callers can compare clean vs. adversarial results.
    return hps_score
        

In [None]:

# Adapted from the torchattacks PGD implementation, modified to lower the HPS score.
class PGD:
    """Untargeted L-inf PGD attack that lowers the HPS score of images.

    Runs ``steps`` iterations of signed-gradient descent on the HPS score of
    each (image, prompt) pair. The perturbation is kept inside an L-inf ball
    of radius ``eps``, and the image is kept inside the valid [0, 1] pixel range.

    ``eps`` and ``alpha`` are given in pixel units ([0, 1] scale). The input
    images are mean/std-normalized, so both budgets are divided by the
    per-channel std before use.

    Args:
        model: callable as ``model(images, texts)``, returning a dict with
            ``image_features``, ``text_features`` and ``logit_scale``. It must
            expose ``model.visual.image_mean`` / ``model.visual.image_std``.
        eps: maximum L-inf perturbation, in pixel units.
        alpha: step size per iteration, in pixel units.
        steps: number of PGD iterations.
    """

    def __init__(self, model, eps=8 / 255, alpha=2 / 255, steps=10):
        self.model = model
        self.eps = eps
        self.alpha = alpha
        self.steps = steps

    def forward(self, images, texts):
        """Return adversarial copies of ``images`` with lowered HPS scores.

        Args:
            images: normalized image batch of shape (N, C, H, W).
            texts: tokenized prompts, one per image.

        Returns:
            A detached tensor with the same shape as ``images``.
        """
        images = images.clone().detach()
        dev, dtype = images.device, images.dtype

        mean = torch.as_tensor(self.model.visual.image_mean, dtype=dtype, device=dev).view(1, -1, 1, 1)
        std = torch.as_tensor(self.model.visual.image_std, dtype=dtype, device=dev).view(1, -1, 1, 1)

        # Convert pixel-space budgets to normalized space (x_norm = (x - mean) / std).
        # Both eps AND alpha must be scaled, otherwise the step is ~1/std too small.
        eps_norm = self.eps / std
        alpha_norm = self.alpha / std
        # Valid pixel range [0, 1], expressed in normalized units.
        min_norm = (0 - mean) / std
        max_norm = (1 - mean) / std

        use_amp = dev.type == "cuda"
        image_adv = images.clone()

        for _ in range(self.steps):
            image_adv.requires_grad_(True)
            # enable_grad: the attack needs gradients even if called under torch.no_grad().
            with torch.enable_grad():
                with torch.autocast(device_type=dev.type, enabled=use_amp):
                    # Same scoring formula as HPS v2's train.py.
                    output = self.model(image_adv, texts)
                    image_features = output["image_features"]
                    text_features = output["text_features"]
                    logit_scale = output["logit_scale"]
                    logits_per_text = logit_scale * text_features @ image_features.T
                    # Minimize the sum of the per-pair scores. This equals the single
                    # score for a batch of one and attacks every image in larger batches.
                    loss = torch.diagonal(logits_per_text).sum()
                # Gradient w.r.t. the input only: nothing is accumulated into
                # the model parameters' .grad.
                # NOTE(review): under fp16 autocast, tiny gradients can underflow to 0.
                (grad,) = torch.autograd.grad(loss, image_adv)

            with torch.no_grad():
                # Descend on the score, project back into the eps ball, then
                # clip to the valid pixel range.
                image_adv = image_adv - alpha_norm * torch.sign(grad)
                delta = torch.max(torch.min(image_adv - images, eps_norm), -eps_norm)
                image_adv = torch.max(torch.min(images + delta, max_norm), min_norm)
            image_adv = image_adv.detach()

        return image_adv

In [None]:
# Score the clean (non-attacked) image as a baseline.
filepath = "cat.png"
image, text = preprocess_data(filepath, "cat")
calculateHPS(image, text)

# Save the preprocessed image to disk. `image` is mean/std-normalized, while
# ToPILImage expects floats in [0, 1], so undo the normalization first.
# (Done inline because `denormalize` is only defined in a later cell.)
to_pil = transforms.ToPILImage()
clip_mean = torch.tensor(model.visual.image_mean).view(3, 1, 1)
clip_std = torch.tensor(model.visual.image_std).view(3, 1, 1)
img_pixels = (image[0].detach().cpu().float() * clip_std + clip_mean).clamp(0, 1)
img_pil = to_pil(img_pixels)
img_pil.save(f"processed_{filepath.split('/')[-1]}")

In [None]:
# Run 10 PGD steps against the clean image/prompt pair.
attack = PGD(model, eps=8 / 255, alpha=1 / 255, steps=10)
adv_image = attack.forward(image, text)
n_adv = len(adv_image)
print(f'Created {n_adv} adversarial images')

In [None]:
# Score the adversarial image against the same prompt, for comparison with the clean score.
calculateHPS(image=adv_image, text=text)

In [None]:
# Helper (originally generated with ChatGPT) that reverses the CLIP normalization, turning a preprocessed tensor back into a viewable image.
def denormalize(image, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)):
    """Undo CLIP-style mean/std normalization and clip to the displayable range.

    Args:
        image (torch.Tensor): normalized image, shape (C, H, W) or (N, C, H, W)
            with 3 channels.
        mean (sequence of float): per-channel means used when normalizing.
        std (sequence of float): per-channel standard deviations used when normalizing.

    Returns:
        torch.Tensor: ``image * std + mean``, clamped to [0, 1].
    """
    # A (3, 1, 1) view broadcasts against the channel axis of both
    # (C, H, W) and (N, C, H, W) inputs.
    channel_mean = torch.tensor(mean).to(image.device).view(3, 1, 1)
    channel_std = torch.tensor(std).to(image.device).view(3, 1, 1)

    # Invert x_norm = (x - mean) / std, then clip to [0, 1].
    restored = image * channel_std + channel_mean
    return restored.clamp(0, 1)


In [None]:
# Undo the CLIP normalization so the adversarial image can be viewed/saved.
denorm_adv_image = denormalize(adv_image, mean=model.visual.image_mean, std=model.visual.image_std)
# Save the first image of the batch. `.squeeze()` would leave a 4-D tensor
# for batch sizes > 1, which ToPILImage rejects.
img_pil = to_pil(denorm_adv_image[0].detach().cpu())
# NOTE: saving as an 8-bit PNG quantizes the pixels, which can weaken or remove
# the perturbation — re-score the saved file to confirm the attack survives.
img_pil.save('adv_image.png')