In [10]:
import os
import numpy as np
from PIL import Image

import random
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
import torch.optim as optim

from torchvision import transforms
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from diffusers import StableDiffusionInpaintPipeline, AutoencoderKL
import timm
import lpips
from helper import load_images_from_path, norm_imagenet, denorm_imagenet

val_transforms = transforms.Compose([
    transforms.Resize((256,256)),
    # transforms.CenterCrop(224),
    transforms.ToTensor(),
    # normalize_img,
])

def load_img(path, transforms=None):
    img = Image.open(path).convert("RGB")
    img = transforms(img).unsqueeze(0).to(device)
    return img

def norm_tensor(tensor):
    t = tensor.clone().detach()
    
    min_val = t.min()
    max_val = t.max()

    tensor_norm = (tensor - min_val) / (max_val - min_val)

    print(f"Tensor normalized: min={tensor_norm.min()}, max={tensor_norm.max()}")
    
    return tensor_norm, min_val, max_val

def denorm_tensor(tensor, original_min=None, original_max=None):
    t = tensor.clone().detach()

    return t * (original_max - original_min) + original_min

def create_random_mask(img_pt, num_masks=1, mask_percentage=0.1, max_attempts=100):
    _, _, height, width = img_pt.shape
    mask_area = int(height * width * mask_percentage)
    masks = torch.zeros((num_masks, 1, height, width), dtype=img_pt.dtype)

    if mask_percentage >= 0.999:
        # Full mask for entire image
        return torch.ones((num_masks, 1, height, width), dtype=img_pt.dtype).to(img_pt.device)

    for ii in range(num_masks):
        placed = False
        attempts = 0
        while not placed and attempts < max_attempts:
            attempts += 1

            max_dim = int(mask_area ** 0.5)
            mask_width = random.randint(1, max_dim)
            mask_height = mask_area // mask_width

            # Allow broader aspect ratios for larger masks
            aspect_ratio = mask_width / mask_height if mask_height != 0 else 0
            if 0.25 <= aspect_ratio <= 4:  # Looser ratio constraint
                if mask_height <= height and mask_width <= width:
                    x_start = random.randint(0, width - mask_width)
                    y_start = random.randint(0, height - mask_height)
                    overlap = False
                    for jj in range(ii):
                        if torch.sum(masks[jj, :, y_start:y_start + mask_height, x_start:x_start + mask_width]) > 0:
                            overlap = True
                            break
                    if not overlap:
                        masks[ii, :, y_start:y_start + mask_height, x_start:x_start + mask_width] = 1
                        placed = True

        if not placed:
            # Fallback: just fill a central region if all attempts fail
            print(f"Warning: Failed to place mask {ii}, using fallback.")
            center_h = height // 2
            center_w = width // 2
            half_area = int((mask_area // 2) ** 0.5)
            h_half = min(center_h, half_area)
            w_half = min(center_w, half_area)
            masks[ii, :, center_h - h_half:center_h + h_half, center_w - w_half:center_w + w_half] = 1

    return masks.to(img_pt.device)

In [11]:
class Params:
    """Hyperparameters and configuration settings for FreqLoc."""
    def __init__(self):
        # --- System & Paths ---
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.train_datasets = '/mnt/nas5/suhyeon/datasets/valAGE-Set'
        self.image_path = '/mnt/nas5/suhyeon/datasets/valAGE-Set/0088.png'
        self.exp_name = 'baseline'
        self.output_dir = f'/mnt/nas5/suhyeon/projects/freq-loc/{self.exp_name}'

        # --- Model Configurations ---
        self.vae_model_name = "stabilityai/stable-diffusion-2-1"
        self.vae_subfolder = "vae"
        
        # --- Image Size Parameters ---
        self.vae_image_size = 512
        self.image_size = 256
        self.transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
        ])

        # --- FreqLoc Core Parameters ---
        self.message_bits = 48
        self.margin = 1.0
        self.grid_size = 28
        self.mask_percentage = 0.3
        self.num_masks = 1
        self.seed = 42
        self.num_inference_steps = 100
        self.guidance_scale = 7.5

        # --- Optimization Parameters ---
        self.lr = 2.0
        self.steps = 500
        self.lambda_p = 0.0025 #0.025
        self.lambda_i = 0.005 #0.005
        self.feat_layer = 1

        # --- Robustness Parameters --- 
        self.eps0_std = [0.0, 0.8] # Latent noise
        
        # --- Demo/Evaluation Parameters ---
        self.batch_size = 1
        self.num_test_images = 1

        self.feature_dim = None
        if self.feat_layer == 0:
            self.feature_dim = 96
        elif self.feat_layer == 1:
            self.feature_dim = 192
        elif self.feat_layer == 2:
            self.feature_dim = 384
        elif self.feat_layer == 3:
            self.feature_dim = 768

In [12]:
class FreqLoc:
    def __init__(self, args):
        self.args = args

        # Initialize networks
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "sd-legacy/stable-diffusion-inpainting",
            # torch_dtype=torch.float16,
            cache_dir='/mnt/nas5/suhyeon/caches'
        ).to(self.args.device)
        self.image_encoder = timm.create_model(
            'convnext_small.dinov3_lvd1689m',
            pretrained=True,
            features_only=True
        ).to(self.args.device)

        for param in self.image_encoder.parameters():
            param.requires_grad = False

        self.pipe.vae.requires_grad_(False)
        self.pipe.unet.requires_grad_(False)
        self.pipe.text_encoder.requires_grad_(False)
        self.pipe.vae.eval()
        self.pipe.unet.eval()
        self.pipe.text_encoder.eval()
        
        # self.direction_vectors = torch.load('/mnt/nas5/suhyeon/projects/freq-loc/random_vec.pt').to(self.args.device)
        self.direction_vectors = torch.load(f'/mnt/nas5/suhyeon/projects/freq-loc/random_vec_univ_{self.args.feature_dim}.pt').to(self.args.device)
        # self.direction_vectors = self.generate_universal_vectors(self.args.feature_dim)
        # torch.save(self.direction_vectors, f'/mnt/nas5/suhyeon/projects/freq-loc/random_vec_univ_{self.args.feature_dim}.pt')
        self.num_patches = (self.args.image_size // 14) ** 2

        self.loss_fn_vgg = lpips.LPIPS(net='alex').to(self.args.device)
        self.loss_fn_vgg.eval()

    def generate_universal_vectors(self, feature_dim):
        """
        어떤 Feature가 들어와도 DC 성분(크기)을 무시하고 
        방향만 검출할 수 있는 Universal Vector 생성
        """
        # 1. 랜덤 생성
        vecs = torch.randn(1, feature_dim)
        
        # 2. [핵심] Zero-Mean Centering (평균 제거)
        # 각 벡터(row)의 평균을 계산해서 뺌 -> 합이 0이 됨
        vecs = vecs - vecs.mean(dim=1, keepdim=True)
        
        # 3. Sign Quantization (강건성 향상)
        # 0인 경우를 방지하기 위해 아주 작은 noise 추가 후 sign
        vecs = torch.sign(vecs + 1e-6)
        
        # 4. L2 Normalization
        vecs = vecs / torch.norm(vecs, p=2, dim=1, keepdim=True)
        
        return vecs.to(self.args.device)

    def _create_random_mask(self, img_pt, num_masks=1, mask_percentage=0.1, max_attempts=100):
        _, _, height, width = img_pt.shape
        mask_area = int(height * width * mask_percentage)
        masks = torch.zeros((num_masks, 1, height, width), dtype=img_pt.dtype)

        if mask_percentage >= 0.999:
            # Full mask for entire image
            return torch.ones((num_masks, 1, height, width), dtype=img_pt.dtype).to(img_pt.device)

        for ii in range(num_masks):
            placed = False
            attempts = 0
            while not placed and attempts < max_attempts:
                attempts += 1

                max_dim = int(mask_area ** 0.5)
                mask_width = random.randint(1, max_dim)
                mask_height = mask_area // mask_width

                # Allow broader aspect ratios for larger masks
                aspect_ratio = mask_width / mask_height if mask_height != 0 else 0
                if 0.25 <= aspect_ratio <= 4:  # Looser ratio constraint
                    if mask_height <= height and mask_width <= width:
                        x_start = random.randint(0, width - mask_width)
                        y_start = random.randint(0, height - mask_height)
                        overlap = False
                        for jj in range(ii):
                            if torch.sum(masks[jj, :, y_start:y_start + mask_height, x_start:x_start + mask_width]) > 0:
                                overlap = True
                                break
                        if not overlap:
                            masks[ii, :, y_start:y_start + mask_height, x_start:x_start + mask_width] = 1
                            placed = True

            if not placed:
                # Fallback: just fill a central region if all attempts fail
                print(f"Warning: Failed to place mask {ii}, using fallback.")
                center_h = height // 2
                center_w = width // 2
                half_area = int((mask_area // 2) ** 0.5)
                h_half = min(center_h, half_area)
                w_half = min(center_w, half_area)
                masks[ii, :, center_h - h_half:center_h + h_half, center_w - w_half:center_w + w_half] = 1

        return masks.to(img_pt.device)

    def embed_watermark(self, original: torch.Tensor, img_size: int) -> torch.Tensor:
        """
        Embed watermark in image using latent frequency space optimization
        
        Args:
            image: Input image tensor [B, C, H, W]
            message: Binary message {-1, 1} [B, message_bits]
        
        Returns:
            Watermarked image tensor
        """

        original = original.to(self.args.device)
        # message = message.to(self.device)
        
        # Step 1: Encode image to latent space
        image = F.interpolate(original, size=(self.args.vae_image_size, self.args.vae_image_size), mode="bilinear", align_corners=False)
        latent = self.pipe.vae.encode(2*image-1).latent_dist.sample() # [-1, 1], [B,4,64,64]
        
        # Step 2: Transform to frequency domain
        # latent_fft = torch.fft.fft2(latent, dim=(-2, -1))
        
        # Step 3: Initialize perturbation (trainable parameter)
        # delta_m = torch.zeros_like(latent_fft, requires_grad=True)
        delta_m = torch.zeros_like(latent, requires_grad=True)
        optimizer = optim.Adam([delta_m], lr=self.args.lr)

        # input = F.interpolate(original, size=(self.args.vae_image_size, self.args.vae_image_size), mode="bilinear", align_corners=False)
        # adaptive_weight = self._get_feature_weight(input, min_weight=0.3)

        # Training loop
        for step in range(self.args.steps):
        # for step in tqdm(range(self.args.steps), desc="Embedding Watermark"):
            optimizer.zero_grad()

            mask = self._create_random_mask(image, num_masks=1, mask_percentage=self.args.mask_percentage)
            mask = mask.to(self.args.device)
            target_mask = mask * 2 - 1 # Convert to {-1, 1}

            if random.random() < 0.5:
                mask = 1 - mask

            image = F.interpolate(original, size=(self.args.vae_image_size, self.args.vae_image_size), mode="bilinear", align_corners=False)
            mask = F.interpolate(mask, size=(self.args.vae_image_size, self.args.vae_image_size), mode="bilinear", align_corners=False)
            
            # perturbed_fft = latent_fft + delta_m
            # perturbed_latent = torch.fft.ifft2(perturbed_fft, dim=(-2, -1)).real
            perturbed_latent = latent + delta_m

            watermarked_image = self.pipe.vae.decode(perturbed_latent).sample
            watermarked_image = (watermarked_image + 1) / 2
            
            masked = watermarked_image * mask + (1 - mask) * image

            # uniform noise
            latent_mask = F.interpolate(mask, size=(64, 64), mode="bilinear", align_corners=False)
            
            std_val_0 = random.uniform(self.args.eps0_std[0], self.args.eps0_std[1])
            eps0 = torch.randn_like(perturbed_latent) * std_val_0

            perturbed_latent_1 = (perturbed_latent + eps0)*latent_mask + perturbed_latent*(1-latent_mask)

            watermarked_image_1 = self.pipe.vae.decode(perturbed_latent_1).sample
            masked_1 = (watermarked_image_1 + 1) / 2
            masked_1 = masked_1 * mask + (1 - mask) * image

            # Compute losses
            image = F.interpolate(original, size=(img_size, img_size), mode="bilinear", align_corners=False)
            mask = F.interpolate(mask, size=(img_size, img_size), mode="bilinear", align_corners=False)
            target_mask = F.interpolate(target_mask, size=(img_size, img_size), mode="bilinear", align_corners=False)
            masked = F.interpolate(masked, size=(img_size, img_size), mode="bilinear", align_corners=False)
            masked_1 = F.interpolate(masked_1, size=(img_size, img_size), mode="bilinear", align_corners=False)

            watermarked_image = F.interpolate(watermarked_image, size=(img_size, img_size), mode="bilinear", align_corners=False)
            watermarked_image_1 = F.interpolate(watermarked_image_1, size=(img_size, img_size), mode="bilinear", align_corners=False)
            
            watermarked_image = norm_imagenet(watermarked_image)
            masked = norm_imagenet(masked)
            masked_1 = norm_imagenet(masked_1)

            epsilon = 1e-6

            features = self.image_encoder(watermarked_image)[self.args.feat_layer]
            B, C, H, W = features.shape
            features = features.permute(0, 2, 3, 1).view(B, H * W, C)
            features_norm = features / (torch.norm(features, p=2, dim=-1, keepdim=True) + epsilon)
            cosine_similarity = torch.matmul(features_norm, self.direction_vectors.T)

            features = self.image_encoder(watermarked_image_1)[self.args.feat_layer]
            features = features.permute(0, 2, 3, 1).view(B, H * W, C)
            features_norm = features / (torch.norm(features, p=2, dim=-1, keepdim=True) + epsilon)
            cosine_similarity_1 = torch.matmul(features_norm, self.direction_vectors.T)

            B = cosine_similarity.shape[0]
            H = W = int(cosine_similarity.shape[1] ** 0.5)

            target_cosine = 0.15 # 1.0 - 1.5
            loss_m = torch.mean(F.relu(target_cosine - cosine_similarity))
            loss_m1 = torch.mean(F.relu(target_cosine - cosine_similarity_1))

            loss_f = self._dice_loss(cosine_similarity, mask)
            loss_f1 = self._dice_loss(cosine_similarity_1, mask)

            watermarked_image = denorm_imagenet(watermarked_image)
            masked = denorm_imagenet(masked)
            masked_1 = denorm_imagenet(masked_1)

            loss_psnr = self._psnr_loss(watermarked_image, image)
            loss_lpips = self._lpips_loss(watermarked_image, image)

            clean_weight = 1.0
            noisy_weight = 1.0
            
            total_loss = clean_weight * (loss_m) + \
                         noisy_weight * (loss_m1) + \
                         0.2 * loss_f + 0.2 * loss_f1 + \
                         self.args.lambda_p * loss_psnr + \
                         self.args.lambda_i * loss_lpips
            
            total_loss.backward()
            optimizer.step()

            if step == 0 or (step+1) % 100 == 0:
                psnr_val = self._compute_psnr(watermarked_image.detach(), image.detach())
                print(f"Step {step+1}, Loss: {total_loss.item():.4f}, PSNR: {psnr_val:.2f}")
                print(f"Mask Loss: {loss_m.item():.4f}") #, DICE Loss: {loss_d.item():.4f}")
                print(f"Mask1 Loss: {(loss_m1).item():.4f}") #, DICE1 Loss: {loss_d1.item():.4f}")
                print(f"Focal Loss: {loss_f.item():.4f}, Focal1 Loss: {loss_f1.item():.4f}")
                print(f"PSNR Loss: {loss_psnr.item():.4f}, LPIPS Loss: {loss_lpips.item():.4f}")

        # Final watermarked image
        # wm_fft = latent_fft + delta_m
        # wm_latent = torch.fft.ifft2(wm_fft, dim=(-2, -1)).real
        wm_latent = latent + delta_m
        rec_wm = self.pipe.vae.decode(wm_latent).sample
        rec_wm = (rec_wm + 1) / 2

        rec_clean = self.pipe.vae.decode(latent).sample
        rec_clean = (rec_clean + 1) / 2

        pixel_delta = rec_wm - rec_clean

        final_images = torch.clamp(rec_clean + 1.0 * pixel_delta, 0, 1)
        
        return final_images.detach(), pixel_delta.detach()
        
    def decode_watermark(self, watermarked_image: torch.Tensor) -> torch.Tensor:
        watermarked_image = watermarked_image.to(self.args.device)
        
        with torch.no_grad():
            watermarked_image = norm_imagenet(watermarked_image) 
            features = self.image_encoder(watermarked_image)[self.args.feat_layer]
            B, C, H, W = features.shape
            features = features.permute(0, 2, 3, 1).view(B, H * W, C)
            # dot_products = torch.matmul(features, self.direction_vectors.T) # [1, 256, 384]*[1, 384, 256] -> [1, 256, 1]

            epsilon = 1e-6
            features_norm = features / (torch.norm(features, p=2, dim=-1, keepdim=True) + epsilon)
            direction_norm = self.direction_vectors / (torch.norm(self.direction_vectors, p=2, dim=-1, keepdim=True) + epsilon)
            dot_products = torch.matmul(features_norm, direction_norm.T)

            B = dot_products.shape[0]
            H = W = int(dot_products.shape[1] ** 0.5)
            grid = dot_products.view(B, H, W).unsqueeze(0) # [1, 256, 1] -> [1, 1, 16, 16]
            grid = F.interpolate(grid, size=self.args.image_size, mode='bilinear', align_corners=False)

            # threshold = 0.1
            # binary_prediction = (grid >= threshold).float()
            temperature = 5.0  # (5.0 ~ 10.0 사이의 값으로 실험 필요)
            scaled_grid = grid * temperature
            confidence_map = torch.sigmoid(scaled_grid)
            binary_prediction = (confidence_map >= 0.5).float()

        return binary_prediction
    
    def _psnr_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Negative PSNR loss (Equation 5)"""
        mse = F.mse_loss(pred, target)
        psnr = -10 * torch.log10(mse + 1e-8)
        return -psnr  # Negative for minimization
    
    def _lpips_loss(self, pred, target):
        pred_norm = pred * 2 - 1 # [-1, 1]
        target_norm = target * 2 - 1 # [-1, 1]
        return self.loss_fn_vgg(pred_norm, target_norm).mean()
    
    def _compute_psnr(self, pred: torch.Tensor, target: torch.Tensor) -> float:
        """Compute PSNR between images"""
        mse = F.mse_loss(pred, target).item()
        if mse == 0:
            return 100.0
        return 20 * np.log10(1.0 / np.sqrt(mse))
    
    def _dice_loss(self, cos_sim, gt_mask, smooth=1e-5):
        temperature = 5
        B = cos_sim.shape[0]
        H = W = int(cos_sim.shape[1] ** 0.5)
        grid = cos_sim.view(B, H, W).unsqueeze(0)
        grid = F.interpolate(grid, size=self.args.image_size, mode='bilinear', align_corners=False) # [B, Num_Patches, Feature_Dim]*[B, Feature_Dim, 1] = [B, Num_Patches, 1]
        pred = torch.sigmoid(grid * temperature) # Logits to probabilities
       
        # Flatten label and prediction tensors
        pred = pred.view(-1)
        target = gt_mask.view(-1)
        
        intersection = (pred * target).sum()
        dice_coeff = (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)
        
        return 1 - dice_coeff


In [13]:
def compute_psnr(a, b):
    mse = F.mse_loss(a, b).item()
    if mse == 0:
        return 100.0
    return 20 * torch.log10(1.0 / torch.sqrt(torch.tensor(mse)))

def calculate_iou(pred_mask, gt_mask):
    # Ensure masks are binary
    # pred_mask_bin = (pred_mask < 0).float()
    pred_mask_bin = torch.sigmoid(pred_mask)
    pred_mask_bin = (pred_mask_bin > 0.65).float() # Thresholding at 0.65
    gt_mask_bin = (gt_mask > 0).float() # Ground truth might not be 0/1

    save_image(pred_mask, "pred.png")
    save_image(pred_mask_bin, "pred_bin.png")
    save_image(gt_mask_bin, "gt.png")
    save_image(pred_mask_bin * gt_mask_bin, "intersection.png")
    save_image(pred_mask_bin + gt_mask_bin, "union.png")

    # Intersection and Union
    intersection = (pred_mask_bin * gt_mask_bin).sum()
    union = (pred_mask_bin + gt_mask_bin).sum() - intersection

    iou = intersection / (union + 1e-6) # Add epsilon to avoid division by zero
    return iou.item()

In [14]:
# img_path = "/mnt/nas5/suhyeon/projects/freq-loc/secret_code/0002.png"
# img_path = "/mnt/nas5/suhyeon/projects/freq-loc/secret_code/analysis_dist_wm_step400.png"
img_path = "/mnt/nas5/suhyeon/projects/freq-loc/baseline/20251120-152728/watermarked/0088.png"
seed = 45
proportion_masked = 0.3
trials = 5

In [15]:
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "sd-legacy/stable-diffusion-inpainting",
    # torch_dtype=torch.float16,
    cache_dir='/mnt/nas5/suhyeon/caches'
).to(device)

args = Params()
freqmark = FreqLoc(args=args)

# secret_key = torch.load('./learned_directional_vector.pt')
# freqmark.direction_vectors = torch.tensor(secret_key).to(args.device)
# print(freqmark.direction_vectors)

torch.manual_seed(seed)
generator = torch.Generator(device=device).manual_seed(seed)
to_tensor = transforms.ToTensor()

watermarked = load_img(img_path, transforms=args.transform)
original = load_img('/mnt/nas5/suhyeon/datasets/valAGE-Set/0088.png', transforms=val_transforms)

original = F.interpolate(original, size=(512, 512), mode="bilinear", align_corners=False)
watermarked = F.interpolate(watermarked, size=(512, 512), mode="bilinear", align_corners=False)

psnrs = []
ious = []
logits = []

for _ in range(trials):
    mask = create_random_mask(watermarked, num_masks=1, mask_percentage=proportion_masked)

    img_norm, min_norm, max_norm = norm_tensor(watermarked)
    img_edit_pil = pipe(prompt="", image=img_norm, mask_image=mask, generator=generator).images[0]
    img_edit = to_tensor(img_edit_pil)
    img_edit = img_edit.unsqueeze(0).to(device)

    img_edit = denorm_tensor(img_edit, min_norm, max_norm)  # [1, 3, H, W]
    # img_edit = img_edit * mask + watermarked * (1-mask)

    img_edit = F.interpolate(img_edit, size=(args.image_size, args.image_size), mode="bilinear", align_corners=False)
    decoded_batch = freqmark.decode_watermark(img_edit)

    save_image(img_edit, "edited.png")
    original = F.interpolate(original, size=(args.image_size, args.image_size), mode="bilinear", align_corners=False)
    watermarked_224 = F.interpolate(watermarked, size=(args.image_size, args.image_size), mode="bilinear", align_corners=False)
    mask_224 = F.interpolate(mask, size=(args.image_size, args.image_size), mode="bilinear", align_corners=False)
    psnrs.append(compute_psnr(watermarked_224, original))
    ious.append(calculate_iou(decoded_batch, 1-mask_224))
    logits.append(decoded_batch)
 
    print(f"PSNR: {psnrs[-1]:.2f}, IoU: {ious[-1]:.4f}")

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]An error occurred while trying to fetch /mnt/nas5/suhyeon/caches/models--sd-legacy--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/vae: Error no file named diffusion_pytorch_model.safetensors found in directory /mnt/nas5/suhyeon/caches/models--sd-legacy--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/vae.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loading pipeline components...:  29%|██▊       | 2/7 [00:00<00:00,  8.42it/s]An error occurred while trying to fetch /mnt/nas5/suhyeon/caches/models--sd-legacy--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/unet: Error no file named diffusion_pytorch_model.safetensors found in directory /mnt/nas5/suhyeon/caches/models--sd-legacy--stable-diffusion-inpainting/snapshots/8a4288a76071f7280aedbdb3253bdb9e9d5d84bb/unet.
Defaulting to unsafe 

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: /opt/conda/envs/stableguard/lib/python3.12/site-packages/lpips/weights/v0.1/alex.pth
Tensor normalized: min=0.0, max=1.0


100%|██████████| 50/50 [00:08<00:00,  6.01it/s]


PSNR: 30.63, IoU: 0.8810
Tensor normalized: min=0.0, max=1.0


100%|██████████| 50/50 [00:08<00:00,  6.00it/s]


PSNR: 30.63, IoU: 0.8518
Tensor normalized: min=0.0, max=1.0


100%|██████████| 50/50 [00:08<00:00,  5.99it/s]


PSNR: 30.63, IoU: 0.8593
Tensor normalized: min=0.0, max=1.0


100%|██████████| 50/50 [00:08<00:00,  5.98it/s]


PSNR: 30.63, IoU: 0.8281
Tensor normalized: min=0.0, max=1.0


100%|██████████| 50/50 [00:08<00:00,  5.97it/s]


PSNR: 30.63, IoU: 0.8928


In [16]:
# args = Params()
# freqmark = FreqMark(args=args)

# # secret_key = torch.load('./learned_directional_vector.pt')
# # freqmark.direction_vectors = torch.tensor(secret_key).to(args.device)
# # print(freqmark.direction_vectors)

# torch.manual_seed(seed)
# generator = torch.Generator(device=device).manual_seed(seed)
# to_tensor = transforms.ToTensor()

# watermarked = load_img(img_path, transforms=args.transform)
# original = load_img('/mnt/nas5/suhyeon/datasets/DIV2K_train_HR/0002.png', transforms=val_transforms)

# original = F.interpolate(original, size=(512, 512), mode="bilinear", align_corners=False)
# watermarked = F.interpolate(watermarked, size=(512, 512), mode="bilinear", align_corners=False)

# psnrs = []
# ious = []
# logits = []

# for _ in range(trials):
#     mask = create_random_mask(watermarked, num_masks=1, mask_percentage=proportion_masked)

#     delta = torch.load('delta_m.pt').to(device)

#     original_norm = torch.linalg.norm(delta)

#     # 2. '구조 없는' 워터마크 생성
#     delta_m_random = torch.randn_like(delta)
#     random_norm = torch.linalg.norm(delta_m_random)
#     delta_m_random = delta_m_random * (original_norm / random_norm) # 세기를 동일하게 맞춤

#     # 3. '세기만 약한' 워터마크 생성
#     delta_m_scaled = delta * 0.5
    
#     results = {}

#     latent = freqmark.vae.encode(2*original-1).latent_dist.sample()
#     latent_fft = torch.fft.fft2(latent, dim=(-2, -1))

#     for name, delta in [("Optimized", delta), 
#                         ("Random", delta_m_random), 
#                         ("Scaled", delta_m_scaled)]:
        
#         final_fft = latent_fft + delta
#         final_latent = torch.fft.ifft2(final_fft, dim=(-2, -1)).real
#         watermarked_image = (freqmark.vae.decode(final_latent).sample + 1) / 2
        
#         watermarked_image = F.interpolate(watermarked_image, size=(args.dino_image_size, args.dino_image_size), mode="bilinear", align_corners=False)
#         logits = freqmark.decode_watermark(watermarked_image) 
        
#         # 워터마크가 있어야 할 영역(gt_mask=1)의 평균 logit 점수 계산
#         # avg_logit = (logits * gt_mask_resized).sum() / gt_mask_resized.sum()
#         # results[name] = avg_logit.item()
#         save_image(logits, f"logits_{name}.png")

In [17]:
print(f"## Average on {trials} trials ##")
print(f"PSNR (imperceptibility): {np.mean(psnrs):.2f} dB")
print(f"IoU (localization accuracy): {np.mean(ious):.4f}")

## Average on 5 trials ##
PSNR (imperceptibility): 30.63 dB
IoU (localization accuracy): 0.8626


In [18]:
original = F.interpolate(original, size=(224, 224), mode="bilinear", align_corners=False)
watermarked = F.interpolate(watermarked, size=(224, 224), mode="bilinear", align_corners=False)
mask = F.interpolate(mask, size=(224, 224), mode="bilinear", align_corners=False)
save_image(original, "eval_original.png")
save_image(watermarked, "eval_watermarked.png")
save_image(img_edit, "eval_edited_w_mask.png")
save_image(decoded_batch, "eval_localized.png")
save_image(1-mask, "eval_mask.png")
save_image(torch.abs(img_edit-original)*10, "eval_edit-ori.png")
save_image(torch.abs(watermarked-original)*10, "eval_wm-ori.png")

RuntimeError: The size of tensor a (256) must match the size of tensor b (224) at non-singleton dimension 3

In [None]:
# total_logits = torch.cat(logits, dim=0)
# # torch.save(total_logits, "logits_w_l1_loss.pt")
# total_logits = total_logits.cpu().numpy().flatten()
# print(f"Mean: {total_logits.mean():.2f}, Std: {total_logits.std():.2f}, Min: {total_logits.min():.2f}, Max: {total_logits.max():.2f}")

In [None]:
sig = torch.sigmoid(torch.cat(logits, dim=0)).cpu().numpy().flatten()

In [None]:
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 6))
plt.hist(sig, bins=50, alpha=0.7)#, label='A: w/ L1 loss')
plt.title('Logit Distribution Comparison')
plt.xlabel('Logit Value')
plt.ylabel('Frequency')
plt.legend()
plt.grid(True)
plt.savefig('logits_comparison.png')
# print("\nSaved logit distribution histogram to 'logit_histogram.png'")

In [None]:
# # logits_a = torch.load("logits_wo_loss.pt").cpu().numpy().flatten()
# logits_a = torch.load("logits_wo_loss.pt").cpu().numpy().flatten()
# logits_b = torch.load("logits_w_l1_loss.pt").cpu().numpy().flatten()
# logits_c = total_logits
# print(f"[A: w/o Add. Loss]Mean: {logits_a.mean():.2f}, Std: {logits_a.std():.2f}, Min: {logits_a.min():.2f}, Max: {logits_a.max():.2f}")
# print(f"[B: w/ L1 Loss] Mean: {logits_b.mean():.2f}, Std: {logits_b.std():.2f}, Min: {logits_b.min():.2f}, Max: {logits_b.max():.2f}")
# print(f"[B: w/ L1 Loss (Dual)] Mean: {logits_c.mean():.2f}, Std: {logits_c.std():.2f}, Min: {logits_c.min():.2f}, Max: {logits_c.max():.2f}")

# import matplotlib.pyplot as plt
# plt.figure(figsize=(10, 6))
# plt.hist(logits_a, bins=50, alpha=0.4, label='A: w/o L')
# plt.hist(logits_b, bins=50, alpha=0.4, label='B: w/ L')
# plt.hist(logits_c, bins=50, alpha=0.4, label='B: w/ L (Dual)')
# plt.title('Logit Distribution Comparison')
# plt.xlabel('Logit Value')
# plt.ylabel('Frequency')
# plt.legend()
# plt.grid(True)
# plt.savefig('logits_comparison.png')


In [None]:
# sig_a = torch.sigmoid(torch.load("logits_wo_loss.pt")).cpu().numpy().flatten()
# sig_b = torch.sigmoid(torch.load("logits_w_l1_loss.pt")).cpu().numpy().flatten()
# sig_c = torch.sigmoid(torch.load("logits_w_l1_loss_dual.pt")).cpu().numpy().flatten()
# print(f"[A: w/o Add. Loss]Mean: {sig_a.mean():.2f}, Std: {sig_a.std():.2f}, Min: {sig_a.min():.2f}, Max: {sig_a.max():.2f}")
# print(f"[B: w/ L1 Loss] Mean: {sig_b.mean():.2f}, Std: {sig_b.std():.2f}, Min: {sig_b.min():.2f}, Max: {sig_b.max():.2f}")
# print(f"[C: w/ L1 Loss (Dual)] Mean: {sig_c.mean():.2f}, Std: {sig_c.std():.2f}, Min: {sig_c.min():.2f}, Max: {sig_c.max():.2f}")

<!--  -->

In [None]:
# import matplotlib.pyplot as plt
# plt.figure(figsize=(10, 6))
# plt.hist(sig_a, bins=50, alpha=0.7, label='A: w/o Add. Loss')
# plt.hist(sig_b, bins=50, alpha=0.7, label='B: w/ L1 Loss')
# plt.hist(sig_c, bins=50, alpha=0.7, label='B: w/ L1 Loss (Dual)')
# plt.title('Logit Distribution Comparison')
# plt.xlabel('Logit Value')
# plt.ylabel('Frequency')
# plt.legend()
# plt.grid(True)
# plt.savefig('logits_comparison.png')
# # print("\nSaved logit distribution histogram to 'logit_histogram.png'")

In [None]:
# 