In [1]:
import os
import numpy as np
from PIL import Image

import random
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
import torch.optim as optim

from torchvision import transforms
# Use the GPU this notebook was pinned to (cuda:3) when that index exists; fall back to the
# default CUDA device on machines with fewer GPUs, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:3" if torch.cuda.device_count() > 3 else "cuda")
else:
    device = torch.device("cpu")

from diffusers import StableDiffusionInpaintPipeline, AutoencoderKL

# Evaluation preprocessing for PIL images: shorter side resized to 256 px, central 224x224
# crop, then conversion to a float tensor. No mean/std normalisation is applied.
val_transforms = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]
)

def load_img(path, transforms=None):
    """Load an image file as a batched tensor on the global ``device``.

    Args:
        path: Image file readable by PIL; it is converted to RGB.
        transforms: Callable mapping a PIL image to a ``[C, H', W']`` tensor (e.g. a torchvision
            ``Compose``). Note that this parameter shadows the ``transforms`` module inside the
            function. When ``None``, the RGB image is converted without resizing to float32
            scaled by 1/255 (the same result ``ToTensor`` gives for 8-bit RGB input).

    Returns:
        Tensor with a leading batch dimension of 1, moved to ``device``.
    """
    # The context manager closes the file handle; convert() returns a new, fully loaded image.
    with Image.open(path) as pil_img:
        img = pil_img.convert("RGB")
    if transforms is None:
        # np.array copies into a writable uint8 [H, W, 3] buffer, then HWC -> CHW in [0, 1].
        tensor = torch.from_numpy(np.array(img)).permute(2, 0, 1).float().div(255.0)
    else:
        tensor = transforms(img)
    return tensor.unsqueeze(0).to(device)

def norm_tensor(tensor):
    """Min-max scale ``tensor`` into [0, 1].

    Args:
        tensor: Any non-empty tensor.

    Returns:
        ``(tensor_norm, min_val, max_val)``. ``min_val``/``max_val`` are detached 0-dim tensors
        that ``denorm_tensor`` uses to undo the scaling. A constant tensor (zero range) maps
        to all zeros instead of NaN.
    """
    t = tensor.clone().detach()

    min_val = t.min()
    max_val = t.max()
    value_range = max_val - min_val

    if value_range == 0:
        # Avoid 0/0 -> NaN; (tensor - min) is already all zeros and keeps dtype/autograd.
        tensor_norm = tensor - min_val
    else:
        tensor_norm = (tensor - min_val) / value_range

    print(f"Tensor normalized: min={tensor_norm.min()}, max={tensor_norm.max()}")

    return tensor_norm, min_val, max_val

def denorm_tensor(tensor, original_min=None, original_max=None):
    """Invert ``norm_tensor``: map [0, 1] values back to ``[original_min, original_max]``.

    Args:
        tensor: Tensor in normalised space; it is cloned and detached first.
        original_min: Lower bound returned by ``norm_tensor``. ``None`` means 0.0.
        original_max: Upper bound returned by ``norm_tensor``. ``None`` means 1.0.

    Returns:
        Detached tensor ``t * (max - min) + min``. With both bounds omitted, this is a copy of ``tensor``.
    """
    t = tensor.clone().detach()
    # Previously None defaults raised a TypeError; treat them as the identity range instead.
    if original_min is None:
        original_min = 0.0
    if original_max is None:
        original_max = 1.0

    return t * (original_max - original_min) + original_min

def create_random_mask(img_pt, num_masks=1, mask_percentage=0.1, max_attempts=100):
    """Build ``num_masks`` binary masks, each containing one axis-aligned rectangle of ones.

    Args:
        img_pt: Reference 4-D tensor; only its spatial size (last two dims), dtype and device are used.
        num_masks: Number of masks (first dimension of the result). Rectangles of later masks
            are placed so they do not overlap rectangles of earlier masks.
        mask_percentage: Target fraction of the H*W pixels covered by each rectangle.
            ``>= 0.999`` returns all-ones masks. If the target area rounds down to 0 pixels,
            all-zero masks are returned.
        max_attempts: Random placement tries per mask before falling back to a centred square.

    Returns:
        Tensor ``[num_masks, 1, H, W]`` with ``img_pt``'s dtype, on ``img_pt``'s device.
    """
    _, _, height, width = img_pt.shape

    if mask_percentage >= 0.999:
        # Full mask for the entire image.
        return torch.ones((num_masks, 1, height, width), dtype=img_pt.dtype).to(img_pt.device)

    masks = torch.zeros((num_masks, 1, height, width), dtype=img_pt.dtype)
    mask_area = int(height * width * mask_percentage)
    if mask_area <= 0:
        # Nothing to cover; random.randint(1, 0) below would otherwise raise ValueError.
        return masks.to(img_pt.device)

    # Loop-invariant: widest rectangle tried. Because height = area // width, width <= height.
    max_dim = int(mask_area ** 0.5)

    for ii in range(num_masks):
        placed = False
        attempts = 0
        while not placed and attempts < max_attempts:
            attempts += 1

            mask_width = random.randint(1, max_dim)
            mask_height = mask_area // mask_width

            # Reject overly elongated rectangles and ones that do not fit in the image.
            aspect_ratio = mask_width / mask_height if mask_height != 0 else 0
            if not (0.25 <= aspect_ratio <= 4):
                continue
            if mask_height > height or mask_width > width:
                continue

            x_start = random.randint(0, width - mask_width)
            y_start = random.randint(0, height - mask_height)
            rows = slice(y_start, y_start + mask_height)
            cols = slice(x_start, x_start + mask_width)

            # Reject placements overlapping any previously placed mask.
            if any(torch.sum(masks[jj, :, rows, cols]) > 0 for jj in range(ii)):
                continue

            masks[ii, :, rows, cols] = 1
            placed = True

        if not placed:
            # Fallback: a centred square of side ~sqrt(mask_area), clipped to the image, so it
            # covers roughly the requested area (the old half-side of sqrt(area / 2) gave ~2x).
            print(f"Warning: Failed to place mask {ii}, using fallback.")
            side_h = min(height, max_dim)
            side_w = min(width, max_dim)
            top = (height - side_h) // 2
            left = (width - side_w) // 2
            masks[ii, :, top:top + side_h, left:left + side_w] = 1

    return masks.to(img_pt.device)

In [2]:
class Params:
    """Hyperparameters and configuration settings for FreqMark."""
    def __init__(self):
        # --- System & Paths ---
        # Prefer the GPU this experiment was pinned to (cuda:3) when that index exists,
        # otherwise the default CUDA device, otherwise the CPU.
        if torch.cuda.is_available():
            self.device = torch.device('cuda:3' if torch.cuda.device_count() > 3 else 'cuda')
        else:
            self.device = torch.device('cpu')
        self.image_path = '/mnt/nas5/suhyeon/datasets/DIV2K_train_HR/0002.png'
        # Saved direction vector(s) loaded by FreqMark (torch.load unpickles: trusted files only).
        self.direction_vectors_path = './sensitive_vec.pt'

        # --- Model Configurations ---
        self.vae_model_name = "stabilityai/stable-diffusion-2-1"
        self.vae_subfolder = "vae"
        self.dino_model_repo = 'facebookresearch/dinov2'
        self.dino_model_name = 'dinov2_vits14'

        # --- Image Size Parameters ---
        self.vae_image_size = 512   # square side length fed to the VAE
        self.dino_image_size = 224  # square side length fed to DINOv2
        self.transform = transforms.Compose([
            # transforms.Resize(256),
            # transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])

        # --- FreqMark Core Parameters ---
        self.message_bits = 48
        self.feature_dim = 384
        self.margin = 1.0
        self.grid_size = 16
        self.num_patches = self.grid_size * self.grid_size
        # Area fraction of the random rectangle used as the training mask in
        # FreqMark.embed_watermark (it reads args.mask_percentage; previously undefined here).
        # NOTE(review): 0.3 mirrors the proportion_masked used at evaluation time — confirm.
        self.mask_percentage = 0.3

        # --- Optimization Parameters ---
        self.lr = 2.0
        self.steps = 400
        self.lambda_p = 0.05  # weight of the PSNR loss term
        self.lambda_i = 0.25  # weight of the (simplified) LPIPS loss term

        # --- Robustness Parameters ---
        self.eps1_std = 0.25  # latent noise std
        self.eps2_std = 0.06  # pixel noise std

        # --- Demo/Evaluation Parameters ---
        self.batch_size = 4
        self.num_test_images = 1

In [3]:
class FreqMark:
    """Latent-frequency watermarking with patch-wise localisation through a frozen DINOv2 encoder.

    A trainable perturbation is added to the 2-D FFT of a VAE latent and optimised so that
    DINOv2 patch features, projected onto ``direction_vectors``, reproduce a random spatial mask.
    """

    def __init__(self, args):
        self.args = args

        # Model identifiers come from ``args`` when present; the fallbacks are the original values.
        vae_name = getattr(args, "vae_model_name", "stabilityai/stable-diffusion-2-1")
        vae_subfolder = getattr(args, "vae_subfolder", "vae")
        dino_repo = getattr(args, "dino_model_repo", "facebookresearch/dinov2")
        dino_name = getattr(args, "dino_model_name", "dinov2_vits14")

        self.vae = AutoencoderKL.from_pretrained(vae_name, subfolder=vae_subfolder).to(self.args.device)
        self.image_encoder = torch.hub.load(dino_repo, dino_name).to(self.args.device)

        # Both networks are fixed feature extractors: inference mode and frozen weights.
        for module in (self.vae, self.image_encoder):
            module.eval()
            module.requires_grad_(False)

        # Pre-defined direction vectors. map_location avoids failures when the file was saved
        # on a device that does not exist here.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        vec_path = getattr(args, "direction_vectors_path", "./sensitive_vec.pt")
        self.direction_vectors = torch.load(vec_path, map_location=self.args.device).to(self.args.device)

        self.mu = self.args.margin  # hinge-loss margin used by _message_loss

        # Robustness noise levels: fill in defaults only when missing, never overwrite caller config.
        for name, default in (("eps1_std", 0.25), ("eps2_std", 0.06)):  # latent / pixel noise
            if not hasattr(self.args, name):
                setattr(self.args, name, default)

    def _init_direction_vectors(self) -> torch.Tensor:
        """Return one one-hot direction vector ``[1, feature_dim]`` with the last feature set to 1.

        Not called within this class; ``__init__`` loads the vectors from disk instead.
        """
        vectors = torch.zeros(1, self.args.feature_dim)
        vectors[0, -1] = 1.0
        return vectors.to(self.args.device)

    def _create_random_mask(self, img_pt, num_masks=1, mask_percentage=0.1, max_attempts=100):
        """Random rectangular masks ``[num_masks, 1, H, W]``.

        Delegates to the module-level ``create_random_mask`` so there is a single implementation.
        """
        return create_random_mask(img_pt, num_masks=num_masks,
                                  mask_percentage=mask_percentage, max_attempts=max_attempts)

    def vae_recon(self, image: torch.Tensor, iter: int):
        """Run ``image`` through VAE encode/decode ``iter`` times (at least once).

        The image is mapped with ``2x - 1`` before each encode and the decoder output with
        ``(x + 1) / 2`` after each decode.
        """
        reconstructed = image
        for _ in range(max(1, iter)):
            latent = self.vae.encode(2 * reconstructed - 1).latent_dist.sample()
            reconstructed = (self.vae.decode(latent).sample + 1) / 2
        return reconstructed

    def embed_watermark(self, original: torch.Tensor, img_size: int, mask_percentage=None) -> torch.Tensor:
        """Optimise a latent-frequency perturbation so DINOv2 patch logits localise the watermark.

        Args:
            original: Image batch ``[B, 3, H, W]``. It is mapped with ``2x - 1`` before VAE
                encoding, so presumably in [0, 1].
            img_size: Square side at which the mask, PSNR and LPIPS losses are evaluated.
            mask_percentage: Area fraction of the random training mask. Defaults to
                ``args.mask_percentage`` if defined, else 0.3. The mask is inverted with
                probability 0.5 at every step.

        Returns:
            Detached watermarked image, decoded from the perturbed latent and mapped with
            ``(x + 1) / 2``.
        """
        device = self.args.device
        vae_size = (self.args.vae_image_size, self.args.vae_image_size)
        loss_size = (img_size, img_size)
        if mask_percentage is None:
            mask_percentage = getattr(self.args, "mask_percentage", 0.3)

        original = original.to(device)
        # Loop-invariant references. Separate names keep the mask generated at VAE resolution on
        # every step (the old code reused `image`, so later steps sampled masks at img_size).
        image_vae = F.interpolate(original, size=vae_size, mode="bilinear", align_corners=False)
        image_loss = F.interpolate(original, size=loss_size, mode="bilinear", align_corners=False)

        # Step 1: encode the image to the VAE latent.
        latent = self.vae.encode(2 * image_vae - 1).latent_dist.sample()

        # Step 2: transform to the frequency domain over the last two (spatial) dims.
        latent_fft = torch.fft.fft2(latent, dim=(-2, -1))

        # Step 3: complex-valued trainable perturbation in frequency space.
        delta_m = torch.zeros_like(latent_fft, requires_grad=True)
        optimizer = optim.Adam([delta_m], lr=self.args.lr)

        for step in range(self.args.steps):
            optimizer.zero_grad()

            mask = self._create_random_mask(image_vae, num_masks=1, mask_percentage=mask_percentage).to(device)
            if random.random() < 0.5:
                mask = 1 - mask

            # Perturb in frequency space, return to the spatial latent, decode to pixels.
            perturbed_latent = torch.fft.ifft2(latent_fft + delta_m, dim=(-2, -1)).real
            watermarked_image = (self.vae.decode(perturbed_latent).sample + 1) / 2

            # Watermarked pixels inside the mask, original pixels outside.
            masked = watermarked_image * mask + (1 - mask) * image_vae

            # NOTE(review): eps1_std / eps2_std robustness augmentations are not applied here.
            mask_small = F.interpolate(mask, size=loss_size, mode="bilinear", align_corners=False)
            masked_small = F.interpolate(masked, size=loss_size, mode="bilinear", align_corners=False)
            watermarked_small = F.interpolate(watermarked_image, size=loss_size, mode="bilinear", align_corners=False)

            loss_m = self._mask_loss(masked_small, mask_small)
            loss_psnr = self._psnr_loss(watermarked_small, image_loss)
            loss_lpips = self._lpips_loss(watermarked_small, image_loss)

            # Combined loss (Equation 10 from paper), without the message-loss terms.
            total_loss = (loss_m +
                          self.args.lambda_p * loss_psnr +
                          self.args.lambda_i * loss_lpips)

            total_loss.backward()
            optimizer.step()

            if step % 100 == 0:
                psnr_val = self._compute_psnr(watermarked_small, image_loss)
                print(f"Step {step}, Loss: {total_loss.item():.4f}, PSNR: {psnr_val:.2f}")

        # Final watermarked image; no graph is needed since the result is detached.
        with torch.no_grad():
            final_latent = torch.fft.ifft2(latent_fft + delta_m, dim=(-2, -1)).real
            final_watermarked = (self.vae.decode(final_latent).sample + 1) / 2

        return final_watermarked.detach()

    def _patch_logit_map(self, image: torch.Tensor) -> torch.Tensor:
        """Project DINOv2 patch tokens onto the direction vector and upsample to a dense logit map.

        Returns ``[B, 1, dino_image_size, dino_image_size]``. Assumes a single direction vector
        and a square patch grid, as the ``view`` below requires.
        """
        size = (self.args.dino_image_size, self.args.dino_image_size)
        image_for_dino = F.interpolate(image, size=size, mode="bilinear", align_corners=False)
        features = self.image_encoder.get_intermediate_layers(image_for_dino, n=1)[0]  # [B, Num_Patches, Feature_Dim]
        dot_products = torch.matmul(features, self.direction_vectors.T)  # [B, Num_Patches, 1]
        batch = dot_products.shape[0]
        side = int(dot_products.shape[1] ** 0.5)
        # One channel per image ([B, 1, side, side]); unsqueeze(0) would fold the batch into channels.
        grid = dot_products.view(batch, 1, side, side)
        return F.interpolate(grid, size=size, mode="bilinear", align_corners=False)

    def decode_watermark(self, watermarked_image: torch.Tensor) -> torch.Tensor:
        """Predict a per-pixel watermark logit map with the frozen image encoder.

        Args:
            watermarked_image: Image batch ``[B, 3, H, W]``; resized to ``dino_image_size``.

        Returns:
            Logits ``[B, 1, dino_image_size, dino_image_size]``. ``_mask_loss`` trains them with
            BCE-with-logits, so ``torch.sigmoid`` gives per-pixel probabilities.
        """
        watermarked_image = watermarked_image.to(self.args.device)
        with torch.no_grad():
            return self._patch_logit_map(watermarked_image)

    def _message_loss(self, watermarked_image: torch.Tensor, message: torch.Tensor) -> torch.Tensor:
        """Hinge loss for message embedding (Equation 7): mean of ``max(0, mu - projection * message)``.

        Not used by ``embed_watermark``. NOTE(review): ``message`` must broadcast against
        ``image_encoder(x) @ direction_vectors.T`` — confirm shapes before re-enabling.
        """
        features = self.image_encoder(watermarked_image)
        dot_products = torch.matmul(features, self.direction_vectors.T)

        # Hinge loss with margin
        projections = dot_products * message
        loss = torch.clamp(self.mu - projections, min=0).mean()

        return loss

    def _mask_loss(self, watermarked_image: torch.Tensor, gt_mask: torch.Tensor) -> torch.Tensor:
        """BCE-with-logits between the patch logit map and a ground-truth localisation mask.

        Args:
            watermarked_image: Image batch ``[B, 3, H, W]``; resized to ``dino_image_size``.
            gt_mask: Target with 1 where the watermark should be detected. It is resized to the
                logit map's spatial size and broadcast over the batch when needed.
        """
        logits = self._patch_logit_map(watermarked_image)
        if gt_mask.shape[-2:] != logits.shape[-2:]:
            gt_mask = F.interpolate(gt_mask, size=tuple(logits.shape[-2:]), mode="bilinear", align_corners=False)
        # binary_cross_entropy_with_logits requires identical input/target sizes.
        gt_mask = gt_mask.expand_as(logits)
        return F.binary_cross_entropy_with_logits(logits, gt_mask)

    def _psnr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Negative PSNR loss (Equation 5); 1e-8 keeps log10 finite for identical images."""
        mse = F.mse_loss(pred, target)
        psnr = -10 * torch.log10(mse + 1e-8)
        return -psnr  # Negative for minimization

    def _lpips_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Simplified LPIPS stand-in: MSE between BT.601 luma of ``pred`` and ``target``."""
        pred_gray = 0.299 * pred[:, 0] + 0.587 * pred[:, 1] + 0.114 * pred[:, 2]
        target_gray = 0.299 * target[:, 0] + 0.587 * target[:, 1] + 0.114 * target[:, 2]
        return F.mse_loss(pred_gray, target_gray)

    def _compute_psnr(self, pred: torch.Tensor, target: torch.Tensor) -> float:
        """PSNR in dB for a peak value of 1.0; returns 100.0 for identical inputs."""
        mse = F.mse_loss(pred, target).item()
        if mse == 0:
            return 100.0
        return 20 * np.log10(1.0 / np.sqrt(mse))

    def compute_bit_accuracy(self, original_message: torch.Tensor,
                             decoded_message: torch.Tensor) -> float:
        """Fraction of elementwise-equal entries between the two messages."""
        matches = (original_message == decoded_message).float()
        return matches.mean().item()

In [4]:
# Evaluation inputs: the watermarked image loaded below, the RNG seed, and the fraction of
# pixels covered by the inpainting mask.
img_path, seed, proportion_masked = (
    "/mnt/nas5/suhyeon/projects/freq-loc/sensitive_vec/0002.png",
    42,
    0.3,
)

In [25]:
# Inpainting model used as the local-editing attack.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "sd-legacy/stable-diffusion-inpainting",
    # torch_dtype=torch.float16,
    cache_dir='/mnt/nas5/suhyeon/caches'
).to(device)

args = Params()
freqmark = FreqMark(args=args)

# Seed every RNG in play: create_random_mask draws from Python's `random`, and the pipeline
# draws from `generator`. Previously only torch was seeded, so the mask changed on every run.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
generator = torch.Generator(device=device).manual_seed(seed)
to_tensor = transforms.ToTensor()

vae_size = (args.vae_image_size, args.vae_image_size)
watermarked = load_img(img_path, transforms=args.transform)
# NOTE(review): `original` is centre-cropped by val_transforms while `watermarked` is only resized,
# so pixel-wise comparisons between them may be misaligned — confirm both share the same framing.
original = load_img(args.image_path, transforms=val_transforms)

original = F.interpolate(original, size=vae_size, mode="bilinear", align_corners=False)
watermarked = F.interpolate(watermarked, size=vae_size, mode="bilinear", align_corners=False)
mask = create_random_mask(watermarked, num_masks=1, mask_percentage=proportion_masked)

# Inpaint a [0, 1]-stretched copy, then undo the stretch.
img_norm, min_norm, max_norm = norm_tensor(watermarked)
img_edit_pil = pipe(prompt="", image=img_norm, mask_image=mask, generator=generator).images[0]
img_edit = to_tensor(img_edit_pil).unsqueeze(0).to(device)
if img_edit.shape[-2:] != watermarked.shape[-2:]:
    # Guard: the compositing below needs the pipeline output at the input resolution.
    img_edit = F.interpolate(img_edit, size=tuple(watermarked.shape[-2:]), mode="bilinear", align_corners=False)

img_edit = denorm_tensor(img_edit, min_norm, max_norm)  # [1, 3, H, W]
# Only the masked region comes from the inpainting model; other pixels stay untouched.
img_edit = img_edit * mask + watermarked * (1 - mask)

img_edit = F.interpolate(img_edit, size=(args.dino_image_size, args.dino_image_size), mode="bilinear", align_corners=False)
decoded_batch = freqmark.decode_watermark(img_edit)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

An error occurred while trying to fetch /mnt/nas5/suhyeon/caches/models--sd-legacy--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /mnt/nas5/suhyeon/caches/models--sd-legacy--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
An error occurred while trying to fetch /mnt/nas5/suhyeon/caches/models--sd-legacy--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/unet: Error no file named diffusion_pytorch_model.safetensors found in directory /mnt/nas5/suhyeon/caches/models--sd-legacy--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/unet.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main


Tensor normalized: min=0.0, max=1.0


  0%|          | 0/50 [00:00<?, ?it/s]

In [26]:
# Bring the remaining tensors to DINOv2 resolution so they line up with `img_edit` and
# `decoded_batch`, which are already at args.dino_image_size.
dino_size = (args.dino_image_size, args.dino_image_size)
original = F.interpolate(original, size=dino_size, mode="bilinear", align_corners=False)
watermarked = F.interpolate(watermarked, size=dino_size, mode="bilinear", align_corners=False)
mask = F.interpolate(mask, size=dino_size, mode="bilinear", align_corners=False)

DIFF_GAIN = 10  # amplifies small residuals so they stay visible once save_image clamps to [0, 1]

save_image(original, "eval_original.png")
save_image(watermarked, "eval_watermarked.png")
save_image(img_edit, "eval_edited_w_mask.png")
save_image(decoded_batch, "eval_localized.png")  # raw logits, clamped to [0, 1] by save_image
# decode_watermark returns logits (trained with BCE-with-logits), so also save the sigmoid
# probability map, where 0.5 marks the decision boundary.
save_image(torch.sigmoid(decoded_batch), "eval_localized_prob.png")
save_image(1 - mask, "eval_mask.png")
save_image(torch.abs(img_edit - original) * DIFF_GAIN, "eval_edit-ori.png")
save_image(torch.abs(watermarked - original) * DIFF_GAIN, "eval_wm-ori.png")