In [1]:
import os
import torch
from PIL import Image
import torch.nn.functional as F

from watermark_anything.data.metrics import msg_predict_inference
from notebooks.inference_utils import (
    load_model_from_checkpoint, 
    default_transform, 
    msg2str,
    multiwm_dbscan
)

# Run on GPU when available; the model and every input tensor are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained WAM model (hyper-parameters from params.json, weights
# from the .pth checkpoint) and switch it to eval mode for inference.
exp_dir = "checkpoints"
json_path = os.path.join(exp_dir, "params.json")
ckpt_path = os.path.join(exp_dir, "wam_mit.pth")
wam = load_model_from_checkpoint(json_path, ckpt_path).to(device).eval()

# DBSCAN parameters forwarded to `multiwm_dbscan` in the detection loop.
# NOTE(review): exact semantics depend on `multiwm_dbscan`; presumably
# `epsilon` is the neighbourhood radius and `min_samples` the minimum number
# of pixels needed to form a message cluster — confirm in inference_utils.
epsilon = 1
min_samples = 700

# Configuration: where the multi-watermarked images were written.
output_dir = "outputs_H"
multi_watermark_dir = os.path.join(output_dir, "multi_watermarked")

# List all multi-watermarked images. `os.listdir` returns entries in an
# arbitrary, filesystem-dependent order, so sort the names to make the
# processing order (and the printed report) reproducible across runs/machines.
wm_files = sorted(
    f for f in os.listdir(multi_watermark_dir) if f.endswith("_multi_wm.png")
)

print(f"Found {len(wm_files)} multi-watermarked images in {multi_watermark_dir}")
print(f"DBSCAN parameters: epsilon={epsilon}, min_samples={min_samples}")
print("-" * 80)

making attention of type 'vanilla' with 64 in_channels
Working with z of shape (1, 68, 32, 32) = 69632 dimensions.
making attention of type 'vanilla' with 64 in_channels
Model loaded successfully from checkpoints/wam_mit.pth
{'embedder_config': 'configs/embedder.yaml', 'augmentation_config': 'configs/all_augs_multi_wm.yaml', 'extractor_config': 'configs/extractor.yaml', 'attenuation_config': 'configs/attenuation.yaml', 'embedder_model': 'vae_small', 'extractor_model': 'sam_base', 'nbits': 32, 'img_size': 256, 'img_size_extractor': 256, 'attenuation': 'jnd_1_3_blue', 'scaling_w': 2.0, 'scaling_w_schedule': None, 'scaling_i': 1.0, 'roll_probability': 0.2, 'multiple_w': 1.0, 'nb_wm_eval': 5, 'optimizer': 'AdamW,lr=1e-4', 'optimizer_d': None, 'scheduler': 'CosineLRScheduler,lr_min=1e-6,t_initial=100,warmup_lr_init=1e-6,warmup_t=5', 'epochs': 200, 'batch_size': 8, 'batch_size_eval': 16, 'temperature': 1.0, 'workers': 8, 'to_freeze_embedder': None, 'lambda_w': 1.0, 'lambda_w2': 6.0, 'lambda_

In [2]:
# Decode every multi-watermarked image: predict per-pixel mask and bit logits,
# then cluster the bit predictions with DBSCAN so that each cluster centroid
# is reported as one recovered message.
for wm_file in wm_files:
    wm_path = os.path.join(multi_watermark_dir, wm_file)

    # Load the watermarked image. `convert` returns a new in-memory image, so
    # the context manager can close the underlying file handle immediately
    # instead of leaking it until garbage collection.
    with Image.open(wm_path) as img_file:
        img = img_file.convert("RGB")
    img_tensor = default_transform(img).unsqueeze(0).to(device)

    # Detect watermarks. This is pure inference, so disable autograd: no
    # computation graph is built and activation memory is released per image.
    with torch.no_grad():
        preds = wam.detect(img_tensor)["preds"]
    # Channel 0 is treated as mask logits (squashed to probabilities) and the
    # remaining channels as per-bit logits, per the slicing below.
    mask_preds = torch.sigmoid(preds[:, 0, :, :])
    bit_preds = preds[:, 1:, :, :]

    # Use DBSCAN to find multiple watermarks; `centroids` maps a cluster id to
    # that cluster's message tensor (it is consumed via `.values()` below).
    centroids, positions = multiwm_dbscan(bit_preds, mask_preds, epsilon=epsilon, min_samples=min_samples)

    print(f"{wm_file}: Found {len(centroids)} messages")

    if centroids:
        centroids_pt = torch.stack(list(centroids.values()))
        for i, centroid in enumerate(centroids_pt):
            print(f"  Message {i+1}: {msg2str(centroid)}")
    else:
        print("  No messages detected")

    print("-" * 40)



0_1bb84be82a6f5ac75d1a836e5cd13859_original.png_multi_wm.png: Found 2 messages
  Message 1: 11101000001111101101010110000000
  Message 2: 01101111010111010101001011111111
----------------------------------------
0_2858a371363f604153e86fb1c074bf_original.png_multi_wm.png: Found 2 messages
  Message 1: 11101000001111101101010110000000
  Message 2: 01101111010111010101001011111111
----------------------------------------
0_29fe5241becd92fa7665256089c0de_original.png_multi_wm.png: Found 2 messages
  Message 1: 11101000001111101101010110000000
  Message 2: 01101111010111010101001011111111
----------------------------------------
0_355e2452d8c7a92d33aad487227f626_original.png_multi_wm.png: Found 2 messages
  Message 1: 11101000001111101101010110000000
  Message 2: 01101111010111010101001011111111
----------------------------------------
0_71c43518ed80a997f3ba6fdab62a752_original.png_multi_wm.png: Found 2 messages
  Message 1: 11101000001111101101010110000000
  Message 2: 01101111010111010101