In [1]:
import os
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor

from watermark_anything.data.metrics import msg_predict_inference
from notebooks.inference_utils import (
    load_model_from_checkpoint, 
    default_transform, 
    msg2str
)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model from a JSON config plus a checkpoint file, move it to
# `device`, and put it in evaluation mode (`.eval()` affects layers such as
# dropout/batch-norm; it does NOT disable gradient tracking).
exp_dir = "checkpoints"
json_path = os.path.join(exp_dir, "params.json")
ckpt_path = os.path.join(exp_dir, 'wam_mit.pth')
wam = load_model_from_checkpoint(json_path, ckpt_path).to(device).eval()

# Configuration: watermarked images are read from <output_dir>/watermarked.
output_dir = "outputs_H"
watermark_dir = os.path.join(output_dir, "watermarked")

# List all watermarked images (files ending in "_wm.png").
# os.listdir returns names in arbitrary, filesystem-dependent order; sorting
# makes the per-image detection output below reproducible across runs/machines.
wm_files = sorted(f for f in os.listdir(watermark_dir) if f.endswith('_wm.png'))

print(f"Found {len(wm_files)} watermarked images in {watermark_dir}")
print("-" * 80)

making attention of type 'vanilla' with 64 in_channels
Working with z of shape (1, 68, 32, 32) = 69632 dimensions.
making attention of type 'vanilla' with 64 in_channels
Model loaded successfully from checkpoints/wam_mit.pth
{'embedder_config': 'configs/embedder.yaml', 'augmentation_config': 'configs/all_augs_multi_wm.yaml', 'extractor_config': 'configs/extractor.yaml', 'attenuation_config': 'configs/attenuation.yaml', 'embedder_model': 'vae_small', 'extractor_model': 'sam_base', 'nbits': 32, 'img_size': 256, 'img_size_extractor': 256, 'attenuation': 'jnd_1_3_blue', 'scaling_w': 2.0, 'scaling_w_schedule': None, 'scaling_i': 1.0, 'roll_probability': 0.2, 'multiple_w': 1.0, 'nb_wm_eval': 5, 'optimizer': 'AdamW,lr=1e-4', 'optimizer_d': None, 'scheduler': 'CosineLRScheduler,lr_min=1e-6,t_initial=100,warmup_lr_init=1e-6,warmup_t=5', 'epochs': 200, 'batch_size': 8, 'batch_size_eval': 16, 'temperature': 1.0, 'workers': 8, 'to_freeze_embedder': None, 'lambda_w': 1.0, 'lambda_w2': 6.0, 'lambda_

In [2]:
# Decode the hidden message from every watermarked image listed in cell 1.
# Gradients are never needed for detection, so run under torch.no_grad() to
# avoid recording autograd state for each image (harmless if wam.detect
# already disables gradients internally).
with torch.no_grad():
    for wm_file in wm_files:
        wm_path = os.path.join(watermark_dir, wm_file)

        # Load the watermarked image as RGB; the context manager guarantees the
        # underlying file handle is closed (convert() returns an independent image).
        with Image.open(wm_path) as pil_img:
            img = pil_img.convert("RGB")
        img_tensor = default_transform(img).unsqueeze(0).to(device)  # add batch dim of 1

        # Detect watermark: channel 0 of the prediction is a logit turned into a
        # per-pixel watermark-mask probability via sigmoid; the remaining
        # channels hold the per-bit predictions.
        preds = wam.detect(img_tensor)["preds"]
        mask_preds = torch.sigmoid(preds[:, 0, :, :])
        bit_preds = preds[:, 1:, :, :]

        # Predict the message from the bit predictions and the predicted mask.
        pred_message = msg_predict_inference(bit_preds, mask_preds).cpu().float()
        # NOTE: this is the *maximum* per-pixel mask probability, so a single
        # confidently-detected pixel is enough to report 1.0; it is not an
        # estimate of how much of the image is watermarked.
        confidence = torch.max(mask_preds).item()

        # Print result (batch size is 1, hence pred_message[0]).
        print(f"{wm_file}: {msg2str(pred_message[0])} (confidence: {confidence:.4f})")



0_08cc23c13b79af4d3852e78a8af8ced_original.png_wm.png: 01000100010000101110101111111100 (confidence: 1.0000)
0_126b1334283b521949e0684339c5389b_original.png_wm.png: 01000100010000101110101111111100 (confidence: 1.0000)
0_1416aabf6d59fd4e348473adc87838f_original.png_wm.png: 01000100010000101110101111111100 (confidence: 1.0000)
0_14cbceb82b5212b6d2b15b1c4387f2_original.png_wm.png: 01000100010000101110101111111100 (confidence: 1.0000)
0_16d7b535248b20815a2510276a592777_original.png_wm.png: 01000100010000101110101111111100 (confidence: 1.0000)
0_1746d6bd31e130db53a929b27975d44e_original.png_wm.png: 01000100010000101110101111111100 (confidence: 1.0000)
0_18805193994c3bc8d1adbdb318f2272_original.png_wm.png: 01000100010000101110101111111100 (confidence: 1.0000)
0_1ab8ea5ecaf859f061e099597d72b5ee_original.png_wm.png: 01000100010000101110101111111100 (confidence: 1.0000)
0_1bb84be82a6f5ac75d1a836e5cd13859_original.png_wm.png: 01000100010000101110101111111100 (confidence: 1.0000)
0_1bd4f858f492e