In [None]:
!git clone https://github.com/Jessieweneedtocook/SPV.git

In [None]:
%cd SPV

In [None]:
!pip install .

In [None]:
!pip install datasets diffusers transformers accelerate

In [None]:
import torch.nn.functional as F
import torch.nn as nn
import torch
import timm
class StabilityPredictor(nn.Module):
    """Predict a one-channel map with values in (0, 1) at a fixed 256x256 size.

    A pretrained timm ConvNeXt-Tiny backbone (``features_only=True``) provides
    feature maps; only the last one is used. A small convolutional head
    reduces it to one channel, applies a sigmoid, and the result is resized
    bilinearly to 256x256 regardless of the input resolution.
    """

    def __init__(self):
        super().__init__()
        # features_only=True makes the backbone return a list of feature maps.
        self.encoder = timm.create_model("convnext_tiny", pretrained=True, features_only=True)
        # Head: 768 -> 256 -> 64 -> 1 channels with a single 2x upsampling step.
        # The layer order fixes the state_dict keys (decoder.0, decoder.3, ...).
        self.decoder = nn.Sequential(
            nn.Conv2d(768, 256, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return the predicted map for image batch ``x`` as a [B, 1, 256, 256] tensor."""
        x = self.encoder(x)[-1]              # last feature map; must have 768 channels for the head (presumably [B, 768, 8, 8] for a 256x256 input - TODO confirm)
        x = self.decoder(x)                  # one channel, spatial size doubled by the Upsample layer
        x = F.interpolate(x, size=(256, 256), mode='bilinear', align_corners=False)
        return x

In [None]:
# Mount Google Drive; the checkpoints used below live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the predictor and restore its weights from Drive; map_location lets a
# checkpoint saved on another device load onto the one selected above.
model = StabilityPredictor().to(device)
model.load_state_dict(torch.load("/content/drive/MyDrive/stability_predictor.pth", map_location=device))
# Inference mode: the predictor is only used to produce masks in this notebook.
model.eval()

In [None]:
from vine.src.vine_turbo import VINE_Turbo, VAE_encode, VAE_decode, initialize_unet_no_lora, initialize_vae_no_lora
from vine.src.stega_encoder_decoder import ConditionAdaptor, CustomConvNeXt
from vine.src.model import make_1step_sched
from accelerate.utils import set_seed
import torch, os, gc
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Resume fine-tuning: rebuild the watermarker around the already loaded
# stability predictor (`model`) and restore watermarker + decoder from Drive.
watermarker = VINE_Turbo(
    ckpt_path=None,
    device=device.type,  # still the string 'cuda' on a GPU runtime, but follows the detected device
    stability_predictor=model,
    tensor_six=False
)
# map_location keeps the restore working when the checkpoint was written on a
# different device than the one available now.
watermarker.load_state_dict(
    torch.load("/content/drive/MyDrive/wm_finetuned2_epochRB2.pth", map_location=device)
)
watermarker.to(device)
watermarker.train()
for p in watermarker.parameters():
    p.requires_grad = True

# The stability predictor is the only part of the watermarker kept frozen.
watermarker.stability_predictor.eval()
for p in watermarker.stability_predictor.parameters():
    p.requires_grad = False


decoder = CustomConvNeXt(secret_size=100).to(device)
# The replacement head is created after the .to(device) call above, so it has
# to be moved explicitly to live on the same device as the rest of the decoder.
decoder.convnext.classifier = nn.Sequential(
    nn.Flatten(1),
    nn.Linear(1024, 100, bias=True)
).to(device)
decoder.load_state_dict(
    torch.load("/content/drive/MyDrive/decoder_finetuned2_epochRB2.pth", map_location=device)
)
decoder.train()
for p in decoder.parameters():
    p.requires_grad = True


# Parameters handed to the optimizer; frozen ones never receive gradients.
params = list(watermarker.parameters()) + list(decoder.parameters())

# Message reflects what this cell actually unfreezes (the whole watermarker
# except the stability predictor, plus the decoder).
print("\n✅ Fine-tuning: Watermarker and Decoder will be trained. StabilityPredictor is frozen.")


In [None]:
from vine.src.vine_turbo import VINE_Turbo, VAE_encode, VAE_decode, initialize_unet_no_lora, initialize_vae_no_lora
from vine.src.stega_encoder_decoder import ConditionAdaptor, CustomConvNeXt
from vine.src.model import make_1step_sched
from accelerate.utils import set_seed
import torch, os, gc
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: this cell rebinds `watermarker`, `decoder` and `params`. If the
# resume-from-checkpoint cell was executed before it, that state is replaced.

# watermarker = VINE_Turbo.from_pretrained('Shilin-LU/VINE-B-Enc',
#     stability_predictor=model,  # your pretrained stability model
#     tensor_six=True  # unless you're using PretrainedConditionAdaptor
# )
# watermarker.load_state_dict(torch.load("/content/drive/MyDrive/wm_finetuned_epochF.pth"))

watermarker = VINE_Turbo(
    ckpt_path=None,
    device=device.type,  # still the string 'cuda' on a GPU runtime, but follows the detected device
    stability_predictor=model,
    tensor_six=False
)
watermarker.to(device)
# Start from a fully frozen watermarker, then unfreeze selected sub-modules.
watermarker.eval()
for p in watermarker.parameters():
    p.requires_grad = False

# Fresh, trainable condition adaptor.
watermarker.sec_encoder = ConditionAdaptor().to(device)
for p in watermarker.sec_encoder.parameters():
    p.requires_grad = True

# The VAE encoder is trained as well.
watermarker.vae_a2b.encoder.train()
for p in watermarker.vae_a2b.encoder.parameters():
    p.requires_grad = True


decoder = CustomConvNeXt(secret_size=100).to(device)
# The replacement head is created after the .to(device) call above, so it has
# to be moved explicitly to live on the same device as the rest of the decoder.
decoder.convnext.classifier = nn.Sequential(
    nn.Flatten(1),
    nn.Linear(1024, 100, bias=True)
).to(device)
decoder.train()
for p in decoder.parameters():
    p.requires_grad = True


# Only the modules unfrozen above are optimized.
params = list(decoder.parameters()) + list(watermarker.sec_encoder.parameters()) + list(watermarker.vae_a2b.encoder.parameters())
# Message lists the VAE encoder too, since it is unfrozen above.
print("\n✅ Stage 1: ConditionAdaptor, VAE encoder and Decoder will be trained. Rest of VAE, UNet, StabilityPredictor are frozen.")

In [None]:
!pip install -q timm opencv-python matplotlib Pillow
!pip install datasets

In [None]:
import os
import requests
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import timm

In [None]:
def save_img(img, index, convert_mask=False, out_dir="/content/data"):
    """Resize an image to 512x512 RGB and save it as ``img_<index>.jpg``.

    Args:
        img: PIL image to store.
        index: Integer used for the zero-padded (5 digit) file name.
        convert_mask: Unused; kept so existing callers keep working.
        out_dir: Directory that receives the file. Defaults to the absolute
            ``/content/data`` that the surrounding cells create, list and move
            files from. The previous relative ``data/`` path depended on the
            working directory, which the notebook changes with ``%cd SPV``.
    """
    resize_size = (512, 512)
    img = img.convert("RGB").resize(resize_size, Image.BILINEAR)
    # Make the helper usable on its own, even before the directory exists.
    os.makedirs(out_dir, exist_ok=True)
    img_path = os.path.join(out_dir, f"img_{index:05d}.jpg")

    img.save(img_path)

In [None]:
from itertools import islice
from datasets import load_dataset
import os
from PIL import Image
from tqdm import tqdm

# Stream the PIPE training split and materialise its first 5000 records in memory.
dataset_stream = load_dataset("paint-by-inpaint/PIPE", split="train", streaming=True)
dataset = list(islice(dataset_stream, 5000))

In [None]:
# Create the output directory, then store each record's target image via save_img
# using the record's position as the file index.
os.makedirs("/content/data", exist_ok=True)
for i, entry in enumerate(tqdm(dataset)):
    save_img(entry["target_img"], i)

In [None]:
# Drop the in-memory list of records now that the images are on disk.
del dataset

In [None]:
# Open a new stream and keep records 5000..9999 (the first 5000 are skipped by islice).
pipe_dataset_stream = load_dataset("paint-by-inpaint/PIPE", split="train", streaming=True)
pipe_dataset = list(islice(pipe_dataset_stream, 5000, 10000))

In [None]:
# Offset the file index by the number of entries already in /content/data,
# presumably so the second chunk continues the numbering of the first.
counter = len(os.listdir("/content/data"))
for i, entry in enumerate(tqdm(pipe_dataset)):
    save_img(entry["target_img"], counter + i)

In [None]:
# Drop the second in-memory list of records.
del pipe_dataset

In [None]:
from datasets import load_dataset
import os
import pandas as pd
from PIL import Image
from tqdm import tqdm
from itertools import islice

# Stream InstructPix2Pix samples, store each original image as a JPEG in the
# ImageFolder class directory and record its edit prompt in a CSV.
dataset_stream = load_dataset("timbrooks/instructpix2pix-clip-filtered", split="train", streaming=True)

os.makedirs("/content/data/images/class0", exist_ok=True)

csv_path = "/content/data/edit_prompts.csv"
records = []
num_samples = 10000

# islice has no length, so pass total= to give tqdm a real progress bar.
for i, item in enumerate(tqdm(islice(dataset_stream, num_samples), total=num_samples)):
    image = item["original_image"]
    prompt = item["edit_prompt"]
    filename = f"img_{i:05d}.jpg"
    # JPEG cannot store alpha or palette modes; normalise to RGB so a stray
    # non-RGB sample does not abort the whole download.
    image.convert("RGB").save(f"/content/data/images/class0/{filename}")
    records.append({"filename": filename, "edit_prompt": prompt})

# Row order of the CSV matches the image index used in the file names.
df = pd.DataFrame(records)
df.to_csv(csv_path, index=False)

del dataset_stream

In [None]:
# Ensure the ImageFolder class directory exists, then move every top-level
# .jpg found directly in /content/data into it.
# NOTE(review): the instructpix2pix cell writes img_00000.jpg ... into class0 and
# save_img uses the same img_<index>.jpg naming, so same-named files moved here
# would overwrite them while edit_prompts.csv stays unchanged - confirm intended.
!mkdir -p /content/data/images/class0
!find /content/data -maxdepth 1 -iname "*.jpg" -exec mv {} /content/data/images/class0/ \;

In [None]:
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Images are only forced to RGB and converted to tensors; no resizing happens here.
transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.ToTensor(),
])

epochs = 3
batch_size = 8

# ImageFolder reads class sub-directories below the root (here images/class0).
train_dataset = datasets.ImageFolder(root='/content/data/images', transform=transform)
# shuffle=False keeps batches in dataset order; maybe_edit_batch derives the
# prompt row from the batch position, which presumably relies on that order.
loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Length of the secret vector produced by generate_random_watermark.
secret_dim = 100


def generate_random_watermark(batch_size):
    """Return a (batch_size, secret_dim) tensor on `device` with values in [0.05, 0.95)."""
    uniform = torch.rand(batch_size, secret_dim)
    # Rescale [0, 1) so the values stay away from the extremes 0 and 1.
    return uniform.to(device) * 0.9 + 0.05


mse_loss = nn.MSELoss()
# One optimizer over whatever `params` the setup cell collected.
optimizer = torch.optim.AdamW(params=params, lr=1e-5)

In [None]:

# Active distortion ranges: name -> (min, max). random_distort_batch picks one
# name at random and draws its strength uniformly from this interval.
distortion_strength_paras = {
    'brightness':   (0.9, 2.5),
    'contrast':     (0.9, 2.5),
    'saturation':   (0.9, 2.5),
    'blurring':     (0.0, 15.0),
    'motion_blur':  (1, 10),
    'zoom_blur':    (1.0, 4.0),
    'pixelation':   (1, 20),
    'noise':        (0.0, 2)
}
# #2
# distortion_strength_paras = {
#     'brightness':   (0.3, 1.3),
#     'contrast':     (0.3, 1.3),
#     'saturation':   (0.3, 1.3),
#     'blurring':     (0.0, 5.0),
#     'motion_blur':  (1, 6),
#     'zoom_blur':    (1.0, 2.0),
#     'pixelation':   (1, 8),
#     'noise':        (0.0, 0.1),
#     'compression':  (100, 30)
# }
# #3
# distortion_strength_paras = {
#     'brightness':   (0.3, 1.3),
#     'contrast':     (0.3, 1.3),
#     'saturation':   (0.3, 1.3),
#     'blurring':     (0.0, 5.0),
#     'motion_blur':  (1, 6),
#     'zoom_blur':    (1.0, 2.0),
#     'pixelation':   (1, 8),
#     'noise':        (0.0, 0.1),
#     'compression':  (100, 30)
# }

In [None]:
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF

def apply_single_distortion(image, distortion_type, strength=None):
    """Apply one named distortion to a single [C, H, W] image tensor.

    This is the tensor-only variant (torch / torchvision ops). A later cell
    defines a PIL-based function with the same name; whichever cell was run
    last is the one in effect.

    Args:
        image: Image tensor; the colour and contrast branches clamp to [0, 1].
        distortion_type: Key of ``distortion_strength_paras``.
        strength: Numeric strength, interpreted per distortion.

    Returns:
        The distorted image tensor.

    Raises:
        AssertionError: ``distortion_type`` is not a configured key.
        ValueError: The key is configured but has no implementation here.
    """
    assert distortion_type in distortion_strength_paras.keys(), f"Unsupported distortion: {distortion_type}"

    if distortion_type == "brightness":
        return torch.clamp(image * strength, 0.0, 1.0)

    if distortion_type == "contrast":
        # Scale the deviation from each channel's spatial mean.
        channel_mean = image.mean(dim=(1, 2), keepdim=True)
        return torch.clamp((image - channel_mean) * strength + channel_mean, 0.0, 1.0)

    if distortion_type == "saturation":
        # Scale the deviation from the per-pixel mean over channels.
        channel_avg = image.mean(dim=0, keepdim=True)
        return torch.clamp((image - channel_avg) * strength + channel_avg, 0.0, 1.0)

    if distortion_type == "blurring":
        if strength > 0:
            odd_kernel = int(strength * 2) | 1  # "| 1" forces an odd kernel size
            return TF.gaussian_blur(image, kernel_size=odd_kernel, sigma=strength)
        return image

    if distortion_type == "motion_blur":
        # Implemented as a Gaussian blur with a smaller sigma.
        if strength > 0:
            odd_kernel = int(strength) | 1
            return TF.gaussian_blur(image, kernel_size=odd_kernel, sigma=strength/2)
        return image

    if distortion_type == "zoom_blur":
        _, height, width = image.shape
        enlarged = F.interpolate(image.unsqueeze(0), scale_factor=strength, mode='bilinear', align_corners=False)[0]
        enlarged = TF.center_crop(enlarged, (height, width))
        # Equal blend of the original and the centre-cropped zoomed copy.
        return 0.5 * image + 0.5 * enlarged

    if distortion_type == "pixelation":
        block = max(1, int(strength))
        _, height, width = image.shape
        # Downsample then upsample with nearest-neighbour to get blocky pixels.
        coarse = F.interpolate(image.unsqueeze(0), size=(height // block, width // block), mode='nearest')
        return F.interpolate(coarse, size=(height, width), mode='nearest')[0]

    if distortion_type == "noise":
        noisy = image + torch.randn_like(image) * strength
        return torch.clamp(noisy, 0.0, 1.0)

    raise ValueError(f"Unsupported distortion: {distortion_type}")


In [None]:
import random
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
import torch

def random_distort_batch(batch):
    """Distort every image of a [B, C, H, W] batch with one random distortion each.

    For each image a distortion name is drawn from ``distortion_strength_paras``
    and its strength is sampled uniformly from the configured (min, max) range.
    The result is clamped to [0, 1] and the images are stacked back into a batch.
    """
    num_images, _, _, _ = batch.shape

    def _distort_one(single):
        # Draw the distortion first, then its strength (same RNG order per image).
        chosen = random.choice(list(distortion_strength_paras.keys()))
        low, high = distortion_strength_paras[chosen]
        amount = random.uniform(low, high)
        distorted = apply_single_distortion(single, chosen, amount)
        return torch.clamp(distorted, 0.0, 1.0)

    return torch.stack([_distort_one(batch[idx]) for idx in range(num_images)])


In [None]:
import torch
from PIL import Image, ImageFilter, ImageEnhance
import torchvision.transforms.functional as TF
import io
import numpy as np
from vine.w_bench_utils.distortion.utils import to_tensor, to_pil

def apply_single_distortion(image, distortion_type, strength=None):
    """Apply one named distortion to a PIL image and return it as RGB.

    This is the PIL-based variant; it replaces any earlier function of the
    same name once its cell has been run.

    Args:
        image: PIL image to distort.
        distortion_type: Key of ``distortion_strength_paras``.
        strength: Numeric strength, interpreted per distortion (for
            ``compression`` it is the JPEG quality).

    Returns:
        The distorted PIL image converted to RGB.

    Raises:
        AssertionError: ``distortion_type`` is not a configured key.
        ValueError: The key is configured but has no implementation here.
    """
    assert distortion_type in distortion_strength_paras.keys(), f"Unsupported distortion: {distortion_type}"

    if distortion_type == "brightness":
        enhancer = ImageEnhance.Brightness(image)
        image = enhancer.enhance(strength)

    elif distortion_type == "contrast":
        enhancer = ImageEnhance.Contrast(image)
        image = enhancer.enhance(strength)

    elif distortion_type == "saturation":
        enhancer = ImageEnhance.Color(image)
        image = enhancer.enhance(strength)

    elif distortion_type == "blurring":
        image = image.filter(ImageFilter.GaussianBlur(radius=strength))

    elif distortion_type == "motion_blur":
        # Approximated with a box blur; the kernel size is rounded up to odd.
        k = int(strength)
        k = k + 1 if k % 2 == 0 else k
        image = image.filter(ImageFilter.BoxBlur(k // 2))

    elif distortion_type == "zoom_blur":
        zoom = strength
        w, h = image.size
        zoomed = image.resize((int(w * zoom), int(h * zoom)), Image.BILINEAR)
        # Centre-crop the enlarged copy back to the original size and blend 50/50.
        left = (zoomed.width - w) // 2
        top = (zoomed.height - h) // 2
        zoomed = zoomed.crop((left, top, left + w, top + h))
        image = Image.blend(image, zoomed, alpha=0.5)

    elif distortion_type == "pixelation":
        factor = max(1, int(strength))
        small = image.resize((image.width // factor, image.height // factor), Image.NEAREST)
        image = small.resize(image.size, Image.NEAREST)

    elif distortion_type == "noise":
        # Import the batch helpers under private local names. The module-level
        # names `to_tensor` / `to_pil` are generic and can be rebound by other
        # notebook cells; a single-argument replacement would make the
        # `norm_type=` calls below raise TypeError.
        from vine.w_bench_utils.distortion.utils import to_tensor as _vine_to_tensor
        from vine.w_bench_utils.distortion.utils import to_pil as _vine_to_pil

        tensor = _vine_to_tensor([image], norm_type=None)
        noise = torch.randn(tensor.size()) * strength
        image = _vine_to_pil((tensor + noise).clamp(0, 1), norm_type=None)[0]

    elif distortion_type == "compression":
        # Round-trip through an in-memory JPEG at the requested quality.
        buffer = io.BytesIO()
        image.save(buffer, format="JPEG", quality=int(strength))
        image = Image.open(buffer)

    else:
        raise ValueError(f"Unsupported distortion: {distortion_type}")

    return image.convert("RGB")

In [None]:
import random

def random_distort_batch(batch):
    """Distort every image of a [B, C, H, W] batch through the PIL pipeline.

    Each image is moved to the CPU, converted to PIL, distorted with one
    randomly chosen entry of ``distortion_strength_paras`` (strength sampled
    uniformly from its range) and converted back with ``TF.to_tensor``.
    The stacked result is built from those new tensors.
    NOTE(review): the PIL round-trip presumably detaches the result from
    autograd - confirm this is acceptable where gradients are expected.
    """
    num_images, _, _, _ = batch.shape
    distorted_tensors = []

    for idx in range(num_images):
        pil_image = TF.to_pil_image(batch[idx].cpu())

        # Draw the distortion first, then its strength (same RNG order per image).
        chosen = random.choice(list(distortion_strength_paras.keys()))
        low, high = distortion_strength_paras[chosen]
        amount = random.uniform(low, high)

        pil_image = apply_single_distortion(pil_image, chosen, amount)
        distorted_tensors.append(TF.to_tensor(pil_image))

    return torch.stack(distorted_tensors)

In [None]:
import pandas as pd
import torch
import random
from PIL import Image
from diffusers import StableDiffusionInstructPix2PixPipeline
import torchvision.transforms.functional as TF


# Edit prompts; row i belongs to the image with index i in the dataset order.
prompt_df = pd.read_csv("/content/data/edit_prompts.csv")

# InstructPix2Pix in half precision on the GPU, with the safety checker disabled.
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")
pipe.safety_checker = None


# Private converters. They deliberately do NOT reuse the module-level names
# `to_pil` / `to_tensor`: rebinding those would replace the helpers imported
# from vine.w_bench_utils.distortion.utils, which the PIL noise distortion
# calls with a `norm_type=` keyword.
def _edit_tensor_to_pil(img_tensor):
    """Convert one [C, H, W] tensor to a PIL image (via the CPU)."""
    return TF.to_pil_image(img_tensor.cpu())


def _edit_pil_to_tensor(pil_img, target_device):
    """Convert a PIL image to a tensor on ``target_device``."""
    return TF.to_tensor(pil_img).to(target_device)


def maybe_edit_batch(batch, batch_idx):
    """Run roughly half of the images of a batch through InstructPix2Pix.

    Args:
        batch: Image batch tensor; iterated image by image.
        batch_idx: Position of the batch in the loader. The prompt row for
            image ``i`` is ``batch_idx * batch.size(0) + i``, so batches must
            arrive in dataset order.
            NOTE(review): a smaller final batch would shift this index - confirm
            the dataset size is a multiple of the batch size.

    Returns:
        A stacked batch in which each image is either the edited version or,
        when not selected, out of prompts, or the edit failed, the original.
    """
    edited_batch = []
    batch_size = batch.size(0)

    for i, img_tensor in enumerate(batch):
        img_idx = batch_idx * batch_size + i

        # No prompt available for this image: pass it through unchanged.
        if img_idx >= len(prompt_df):
            edited_batch.append(img_tensor)
            continue

        prompt = prompt_df.iloc[img_idx]["edit_prompt"]

        # Edit with probability 0.5, otherwise keep the original image.
        if random.random() >= 0.5:
            edited_batch.append(img_tensor)
            continue

        pil_img = _edit_tensor_to_pil(img_tensor)
        guidance = random.uniform(7.0, 9.0)

        try:
            edited = pipe(
                prompt=prompt,
                image=pil_img,
                num_inference_steps=20,
                image_guidance_scale=1.5,
                guidance_scale=guidance
            ).images[0]
            # Put the result on the same device as the rest of the batch so
            # torch.stack does not depend on a hard-coded "cuda".
            edited_batch.append(_edit_pil_to_tensor(edited, img_tensor.device))
        except Exception as e:
            # Best effort: a failed edit falls back to the unedited image.
            print(f"❌ Edit failed on idx {img_idx}: {e}")
            edited_batch.append(img_tensor)

    return torch.stack(edited_batch)


In [None]:
# Release cached GPU memory and run the Python garbage collector before training.
torch.cuda.empty_cache()
gc.collect()

In [None]:
!pip install lpips

In [None]:
import lpips
# Perceptual (LPIPS, AlexNet backbone) metric used as a fixed loss term:
# eval mode and no gradients for its own parameters.
lpips_loss_fn = lpips.LPIPS(net='alex').to(device)
lpips_loss_fn.eval()
lpips_loss_fn.requires_grad_(False)

In [None]:
def bit_accuracy(pred_logits, target_bits):
    """Return the fraction of predicted bits that equal the target bits.

    Logits are passed through a sigmoid and thresholded at 0.5 (strictly
    greater counts as 1). The result is a Python float in [0, 1].
    """
    hard_bits = torch.sigmoid(pred_logits).gt(0.5).float()
    num_matching = hard_bits.eq(target_bits).float().sum()
    return (num_matching / torch.numel(hard_bits)).item()

In [None]:
from torch.cuda.amp import autocast, GradScaler
from transformers import AutoTokenizer, CLIPTextModel
import os
import gc
from tqdm import tqdm
import torchvision.transforms.functional as TF
import random

# NOTE(review): earlier cells have presumably already initialised CUDA, in which
# case this allocator setting may not take effect - confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

gc.collect()
torch.cuda.empty_cache()

scaler = GradScaler()

for models in [decoder, watermarker]:
    models.to(device)


for epoch in range(epochs):
    total_loss = total_secret_loss = total_recon_loss = total_embed_mask_loss = total_lpips_loss = total_bit_acc = 0
    # Batches that reached the optimizer step. Skipped (non-finite) batches are
    # excluded so the epoch averages are computed over what was actually trained.
    num_batches_done = 0

    # enumerate() keeps batch_ind equal to the loader position even when a batch
    # is skipped below, so maybe_edit_batch keeps looking up the right prompt rows.
    for batch_ind, (imgs, _) in enumerate(tqdm(loader, desc=f"Epoch {epoch+1}")):
        imgs = imgs.to(device)
        B = imgs.size(0)

        # Random binary message, one 100-bit secret per image.
        watermark = torch.randint(0, 2, (B, 100), device=device).float()

        resized_imgs = F.interpolate(imgs, size=(256, 256), mode='bicubic', align_corners=False)

        with torch.no_grad():
            stability_mask = model(resized_imgs)  # shape: [B, 1, H, W]

        # The watermarker works on [-1, 1] images.
        resized_img = resized_imgs * 2.0 - 1.0

        input_image = imgs

        input_image = 2.0 * input_image - 1.0

        # Binarise the stability map: the top half of the scores per image become 1.
        mask = torch.nan_to_num(stability_mask.clone().float(), nan=0.0, posinf=1.0, neginf=0.0)
        B, _, H, W = mask.shape
        flat_mask = mask.view(B, -1)
        k = (H * W) // 2
        topk = torch.topk(flat_mask, k=k, largest=True, dim=1).indices
        binary_mask = torch.zeros_like(flat_mask)
        binary_mask.scatter_(1, topk, 1.0)
        mask = binary_mask.view(B, 1, H, W)

        optimizer.zero_grad()

        with autocast():

            x_out_decoded = watermarker(resized_img, secret=watermark)

            x_out_decoded = torch.nan_to_num(x_out_decoded, nan=0.0, posinf=1.0, neginf=0.0).clamp(-1.0, 1.0)

            # Watermark residual at 256x256, upsampled and added to the
            # full-resolution input (the 512x512 size is fixed here).
            residual_256 = x_out_decoded - resized_img

            residual_512 = F.interpolate(residual_256, size=(512, 512), mode='bicubic', align_corners=False)

            encoded_image = residual_512 + input_image

            # Back to [0, 1] for distortions, decoding and the image losses.
            encoded_image = encoded_image * 0.5 + 0.5
            encoded_image = torch.clamp(encoded_image, min=0.0, max=1.0)

            # 90 % classic distortions, 10 % InstructPix2Pix edits.
            if random.random() < 0.9:
                encoded_image_distort = random_distort_batch(encoded_image).to(device)
            else:
                encoded_image_distort = maybe_edit_batch(encoded_image, batch_ind).to(device)

            encoded_image_256 = F.interpolate(encoded_image_distort, size=(256, 256), mode='bicubic', align_corners=False)

            pred_watermark = decoder(encoded_image_256)

            bit_acc = bit_accuracy(pred_watermark.detach(), watermark)

            x_sec = watermarker.sec_encoder(watermark, resized_img, stability_mask)

            # Mean absolute change introduced by the condition adaptor inside the masked region.
            delta = torch.abs(x_sec - resized_img)
            embed_mask_loss = torch.mean(delta * mask)

            # Both tensors are in [0, 1]; normalize=True lets LPIPS rescale them
            # to the [-1, 1] range it expects. The former "(x + 1) / 2" mapping
            # squeezed them into [0.5, 1] instead.
            lpips_loss = lpips_loss_fn(encoded_image, imgs, normalize=True).mean()

            recon_loss = F.mse_loss(encoded_image, imgs)
            secret_loss = F.binary_cross_entropy_with_logits(pred_watermark, watermark)

            loss = secret_loss * 2 + embed_mask_loss * 1.5 + recon_loss * 2  + lpips_loss * 1.5

        if torch.isnan(loss) or torch.isinf(loss):
            print("NaN detected, skipping batch.")
            continue

        scaler.scale(loss).backward()
        # Unscale before clipping so max_norm applies to the real gradient norm.
        scaler.unscale_(optimizer)
        for model_to_clip in [decoder, watermarker]:
            torch.nn.utils.clip_grad_norm_(model_to_clip.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        total_secret_loss += secret_loss.item()
        total_recon_loss += recon_loss.item()
        total_embed_mask_loss += embed_mask_loss.item()
        total_lpips_loss += lpips_loss.item()
        # Accumulated together with the losses so skipped batches are not counted.
        total_bit_acc += bit_acc
        num_batches_done += 1

    # Average over the batches that were actually trained on (at least 1 to
    # avoid a division by zero when every batch was skipped).
    denom = max(1, num_batches_done)
    avg_loss = total_loss / denom
    avg_secret = total_secret_loss / denom
    avg_recon = total_recon_loss / denom
    avg_lpips = total_lpips_loss / denom
    avg_embed = total_embed_mask_loss / denom
    avg_bit_acc = total_bit_acc / denom

    print(f"Epoch [{epoch+1}/{epochs}] | Loss: {avg_loss:.4f} | Secret: {avg_secret:.4f} | EmbedMask: {avg_embed:.4f} | Recon: {avg_recon:.4f} | LPIPS: {avg_lpips:.4f} | BitAcc: {avg_bit_acc:.4f}")

# Checkpoints are written once, after the last epoch.
torch.save(decoder.state_dict(), f"/content/drive/MyDrive/decoder_finetuned2_epoch{epoch+1}.pth")
torch.save(watermarker.state_dict(), f"/content/drive/MyDrive/wm_finetuned2_epoch{epoch+1}.pth")

In [None]:
import os
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms


image_path = "/content/data/images/class0"
image_files = sorted([f for f in os.listdir(image_path) if f.lower().endswith((".jpg", ".png"))])
sample_index = 567  # position of the example image in the sorted file list
img = Image.open(os.path.join(image_path, image_files[sample_index])).convert("RGB")

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])
img_tensor = transform(img).unsqueeze(0).to(device)  # shape: [1, 3, 256, 256], values in [0, 1]
# The training loop feeds sec_encoder images scaled to [-1, 1]; use the same
# range here so the visualisation shows what the trained module really does.
img_norm = img_tensor * 2.0 - 1.0


# NOTE(review): training uses binary {0, 1} secrets, this helper yields
# continuous values in [0.05, 0.95) - confirm which one is wanted here.
watermark = generate_random_watermark(1)  # shape: [1, 100]


with torch.no_grad():
    stability_mask = model(img_tensor)  # shape: [1, 1, 256, 256]


# Binarise: the top half of the predictor scores become 1.
flat_mask = stability_mask.view(1, -1)
k = (256 * 256) // 2
_, topk_indices = torch.topk(flat_mask, k=k, largest=True, dim=1)
binary_mask = torch.zeros_like(flat_mask)
binary_mask.scatter_(1, topk_indices, 1.0)
mask = binary_mask.view(1, 1, 256, 256)  # 1 = stable

with torch.no_grad():
    x_sec = watermarker.sec_encoder(watermark, img_norm, stability_mask)
# Map back to [0, 1] so the difference against img_tensor is taken in one range.
x_sec = (x_sec + 1.0) / 2.0


def visualize_watermark_embedding(img, x_sec, mask):
    """Plot the image, the per-pixel |img - x_sec| heatmap and the inverted mask.

    Args:
        img: Image tensor, squeezed to [3, H, W] for display (clamped to [0, 1]).
        x_sec: Tensor of the same shape and value range as ``img``.
        mask: Binary mask; the third panel shows ``1 - mask``.
            NOTE(review): the panel title says "1 = Stable" although the mask
            is inverted before plotting - confirm which is intended.
    """
    image_cpu = img.squeeze().detach().cpu()
    encoded_cpu = x_sec.squeeze().detach().cpu()
    # Mean absolute difference over the colour channels.
    heatmap = torch.abs(image_cpu - encoded_cpu).mean(dim=0)
    inverted_mask = (1 - mask).squeeze().detach().cpu()

    plt.figure(figsize=(15, 4))

    panels = [
        (image_cpu.permute(1, 2, 0).clamp(0, 1), None, "Original Image"),
        (heatmap, 'gnuplot', "Embedding Heatmap (Δ from x_sec)"),
        (inverted_mask, "gray", "Stability Mask (1 = Stable)"),
    ]
    for position, (data, cmap, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        if cmap is None:
            plt.imshow(data)
        else:
            plt.imshow(data, cmap=cmap)
        plt.title(title)
        plt.axis("off")

    plt.tight_layout()
    plt.show()


visualize_watermark_embedding(img_tensor, x_sec, mask)


In [None]:
with torch.no_grad():
    # The training loop feeds the watermarker images scaled to [-1, 1];
    # img_tensor is in [0, 1], so rescale it for the forward pass.
    x_out = watermarker(img_tensor * 2.0 - 1.0, secret=watermark)
# Clamp like the training loop does and map back to [0, 1] so the output is
# comparable with img_tensor.
x_out = (x_out.clamp(-1.0, 1.0) + 1.0) / 2.0
delta = torch.abs(img_tensor - x_out).mean(dim=1, keepdim=True)  # [1, 1, H, W]


def visualize_embedding_map(original, watermarked, mask):
    """Plot the image, the |original - watermarked| heatmap and the mask.

    Args:
        original: Image tensor, squeezed to [3, H, W] for display (clamped to [0, 1]).
        watermarked: Tensor of the same shape and value range as ``original``.
        mask: Binary mask shown as-is in the third panel.
    """
    original_cpu = original.squeeze().detach().cpu()
    watermarked_cpu = watermarked.squeeze().detach().cpu()
    # Mean absolute difference over the colour channels -> [H, W].
    heatmap = torch.abs(original_cpu - watermarked_cpu).mean(dim=0)
    mask_cpu = mask.squeeze().detach().cpu()

    plt.figure(figsize=(15, 4))

    panels = [
        (original_cpu.permute(1, 2, 0).clamp(0, 1), None, "Original Image"),
        (heatmap, "hot", "Watermark Δ Heatmap"),
        (mask_cpu, "gray", "Stability Mask (1 = Stable)"),
    ]
    for position, (data, cmap, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        if cmap is None:
            plt.imshow(data)
        else:
            plt.imshow(data, cmap=cmap)
        plt.title(title)
        plt.axis("off")

    plt.tight_layout()
    plt.show()


visualize_embedding_map(img_tensor, x_out, mask)