# Testing CIN Wrapper

In [None]:

from cin_model import CIN_MODEL
import torch
import numpy as np  

import warnings
warnings.filterwarnings("ignore")

# Instantiate the CIN wrapper under test.
cin = CIN_MODEL()
limit=100

# Draw the random 30-bit messages from a dedicated, seeded generator so that a
# fresh "Restart & Run All" embeds exactly the same messages every time,
# without touching torch's global RNG state.
msg_rng = torch.Generator().manual_seed(0)
msgs= [torch.randint(0,2,(1,30), generator=msg_rng) for _i in range(limit)]

# NOTE(review): 'dataset/temp/0' is presumably a folder of cover images and
# `limit` caps how many are read — confirm against CIN_MODEL._encode_.
img_list='dataset/temp/0'
imgs_model,org_imgs_model = cin._encode_(img_list,msgs,limit=limit)
decoded_msgs = cin._decode_(imgs_model)

# Per-image bit accuracy: fraction of decoded bits equal to the embedded bits.
# Iterate over what was actually decoded rather than `limit`, so the cell does
# not raise IndexError when fewer than `limit` images were encoded.
bitacc = [np.mean((msgs[i].numpy() == decoded_msgs[i]).astype(np.float32)) for i in range(len(decoded_msgs))]
np.mean(bitacc)


1.0

Confirming Published Results on Robustness and Visual Quality

In [None]:

from utilsNew import tensor2PIL
# Convert the watermarked batch first, then the original batch, to PIL images
# (same call order as before); both are reused by the attack loop and the
# PSNR/SSIM computation further down.
encode_image_pil, orig_image_pil = map(tensor2PIL, (imgs_model, org_imgs_model))


from img_attacks import *
from torchvision import transforms

# Attack name -> [callable, keyword arguments]. Every callable is invoked as
# callable(pil_image, **kwargs); "Identity" is the no-attack baseline.
attacks = {
    "Identity": [lambda x, y: x, {"y": None}],
    "JPEG-50": [jpeg_comp, {"quality": 50}],
    "Gauss.Blur-2.0": [gaussian_blur_pil, {"sigma": 2.0}],
    "Dropout-30%": [apply_dropout_pil, {"p": 0.3}],
    "Cropout-30%": [cropout_fn, {"crop_ratio": 0.3}],
    "Crop-3.5%": [blacken_edges, {"ratio": 0.035}],
}

# One shared PIL -> tensor converter for all attacks and images.
to_tensor = transforms.ToTensor()

for attack, (attack_fn, attack_kwargs) in attacks.items():
    # Attack every watermarked image, then convert it back to a tensor with a
    # leading batch dimension, which is the form passed to cin._decode_.
    attacked_imgs = [
        to_tensor(attack_fn(img, **attack_kwargs)).unsqueeze(0)
        for img in encode_image_pil
    ]

    decoded_message_attacked = cin._decode_(attacked_imgs)

    # Per-image fraction of decoded bits that still match the embedded message.
    bit_acc_attacked = [
        np.mean(decoded_message_attacked[i] == msgs[i].detach().cpu().numpy())
        for i in range(len(decoded_message_attacked))
    ]
    print(f"Attack: {attack}, Bit accuracy after attack: {np.mean(bit_acc_attacked):.5f} ± {np.std(bit_acc_attacked):.5f}")

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

# Visual quality of the watermarked images relative to their originals:
# PSNR and SSIM are computed per image pair on the 8-bit-range arrays
# (data_range=255, channels on the last axis) and then averaged.
psnr_per_image, ssim_per_image = [], []
for i in range(len(orig_image_pil)):
    reference = np.array(orig_image_pil[i])
    watermarked = np.array(encode_image_pil[i])
    psnr_per_image.append(psnr(reference, watermarked, data_range=255))
    ssim_per_image.append(ssim(reference, watermarked, channel_axis=-1, data_range=255))

# Kept as single-element lists, matching how the values are indexed below.
psnr_values = [np.mean(psnr_per_image)]
ssim_values = [np.mean(ssim_per_image)]
print(f"PSNR: {psnr_values[0]:.2f}, SSIM: {ssim_values[0]:.4f}")

Attack: Identity, Bit accuracy after attack: 1.00000 ± 0.00000
Attack: JPEG-50, Bit accuracy after attack: 0.94667 ± 0.05812
Attack: Gauss.Blur-2.0, Bit accuracy after attack: 0.96400 ± 0.03580
Attack: Dropout-30%, Bit accuracy after attack: 0.98167 ± 0.04907
Attack: Cropout-30%, Bit accuracy after attack: 1.00000 ± 0.00000
Attack: Crop-3.5%, Bit accuracy after attack: 1.00000 ± 0.00000
PSNR: 35.09, SSIM: 0.9276


Comment: Approximately matches the published results.

# Testing WmForger Wrapper

In [None]:

from videoseal.wmforger.optimize_image import get_artifact_discriminator
from torchvision import transforms
import torch
import torch.nn.functional as F
import torch.nn as nn
class WMFORGER_MODEL():
    """Wrapper that uses the WmForger preference model as a watermark-removal attack.

    The model returned by ``get_artifact_discriminator`` is used as a scoring
    function: ``optimize`` perturbs an image to maximise the model output, the
    difference between the input and the optimised image is treated as the
    "stolen" watermark, and ``remove_watermark`` subtracts a scaled copy of it
    from the input.
    """

    def __init__(self, device="cuda", ckt_path="videoseal/wmforger/convnext_pref_model.pth"):
        """Load the preference model checkpoint and put it in eval mode.

        device: device string passed to ``get_artifact_discriminator`` and
            later to ``optimize``.
        ckt_path: checkpoint path passed to ``get_artifact_discriminator``.
        """
        self.device = device
        self.ckt_path = ckt_path
        self.model =  get_artifact_discriminator(ckt_path, device=self.device)
        self.model.eval()

    @staticmethod
    def optimize(img, model, device="cuda:0", num_steps=100, lr=0.05):
        """
        Optimize the image to remove the watermark by doing gradient descent.
        The loss is minus the preference model output,
        i.e. we maximize the preference model output.

            img: image tensor; it is resized to 768x768 before optimisation
            model: callable scoring an image tensor (higher is better)
            device: device on which the optimisation runs
            num_steps: number of SGD steps
            lr: SGD learning rate
        Returns:
            the optimised 768x768 image, clipped to [0,1], detached, on CPU
        """
        transform_image = transforms.Compose([
            transforms.Resize((768, 768)),
        ])
        img = transform_image(img).to(device)
        # zeros_like(img) already lives on `device`; wrapping it directly keeps
        # `param` a leaf tensor, which the optimizer requires. Calling .to() on
        # a Parameter can return a non-leaf copy instead.
        param = torch.nn.Parameter(torch.zeros_like(img))

        optim = torch.optim.SGD([param], lr=lr)
        for _ in range(num_steps):
            optim.zero_grad()
            loss = -model((img + param).clip(0, 1)).mean()
            loss.backward()
            optim.step()

        return (img + param).clip(0, 1).detach().cpu()

    @staticmethod
    def clean_watermark(img_wm, stolen_watermark):
        """Attenuate the stolen watermark where the watermarked image has edges.

        The image is converted to grayscale (0.299/0.587/0.114 weights), Sobel
        gradients are taken with reflect padding, and the squared gradient
        magnitude is clipped to [0,1]. The watermark estimate is then scaled by
        ``1 - sqrt(edge_map)``, so it is zeroed on strong edges and left
        untouched in flat regions.

            img_wm: tensor with shape (N,3,H,W)
            stolen_watermark: tensor broadcastable against a (N,1,H,W) map
        Returns:
            the attenuated watermark estimate
        """
        # Create the kernels on the input's device and dtype so the function
        # also works for CUDA or non-float32 inputs.
        kernel_opts = dict(device=img_wm.device, dtype=img_wm.dtype)
        kernel_x = torch.tensor(
            [[-1., 0., 1.],
             [-2., 0., 2.],
             [-1., 0., 1.]], **kernel_opts
        ).view(1, 1, 3, 3)
        kernel_y = torch.tensor(
            [[1., 2., 1.],
             [0., 0., 0.],
             [-1., -2., -1.]], **kernel_opts
        ).view(1, 1, 3, 3)
        kernel_to_grayscale = torch.tensor(
            [0.299, 0.587, 0.114], **kernel_opts
        ).view(1, 3, 1, 1)

        gray = F.conv2d(img_wm, kernel_to_grayscale)
        # Reflect-pad by one pixel so the 3x3 Sobel filters keep the spatial size.
        gray_padded = F.pad(gray, (1, 1, 1, 1), mode="reflect")
        grad_x = F.conv2d(gray_padded, kernel_x)
        grad_y = F.conv2d(gray_padded, kernel_y)
        # Out-of-place arithmetic: same values as before, no buffers mutated.
        edge_map = (grad_x * grad_x + grad_y * grad_y).clip(0, 1)

        # Reduce gradients in stolen watermark.
        return stolen_watermark * (1 - edge_map.sqrt())

    def remove_watermark(self,imgs,num_steps=100, lr=0.05, clean_wm=False, alpha=2.0):
        '''
        Remove watermark from a batch of images.
            imgs: list of torch tensors with shape (1,3,H,W) in range [0,1]
            num_steps: number of optimization steps
            lr: learning rate for optimization
            clean_wm: if True, pass the stolen watermark through
                clean_watermark before subtracting it
            alpha: strength factor for the attack; the stolen watermark is
                multiplied by it before subtraction (default 2.0, the value
                that was previously hard-coded)
        Returns:
            cleaned_imgs: list of torch tensors with shape (1,3,H,W) in range [0,1]
        '''
        cleaned_imgs=[]
        for img_wm in imgs:
            # Work on a detached CPU copy: optimize() returns CPU tensors, and
            # the input must not carry an autograd graph into its backward().
            img_wm = img_wm.detach().to("cpu")
            img_cleaned = self.optimize(img_wm, self.model, device=self.device, num_steps=num_steps, lr=lr)
            # optimize() works at 768x768; resize back to the input resolution
            # before taking the difference.
            stolen_watermark = img_wm - F.interpolate(img_cleaned, size=img_wm.shape[-2:], mode="bilinear", align_corners=True, antialias=False)

            # Optionally clean the stolen watermark
            if clean_wm:
                stolen_watermark = self.clean_watermark(img_wm,stolen_watermark)
            img_removed = img_wm - alpha * stolen_watermark

            cleaned_imgs.append(torch.clip(img_removed,0,1))
        return cleaned_imgs

Without Cleaning the Watermark

In [None]:
# As reported in the paper: run the WmForger removal attack on the CIN
# watermarked images without cleaning the stolen watermark.
wmforger = WMFORGER_MODEL()
cleaned_imgs_cin = wmforger.remove_watermark(
    imgs_model,
    num_steps=500,
    lr=0.05,
    clean_wm=False,
)
decoded_msgs_cleaned = cin._decode_(cleaned_imgs_cin)

# Per-image fraction of decoded bits that still equal the embedded message.
bitacc = []
for i in range(len(imgs_model)):
    bit_matches = msgs[i].numpy() == decoded_msgs_cleaned[i]
    bitacc.append(np.mean(bit_matches.astype(np.float32)))
np.mean(bitacc)


0.0

With Cleaning the Watermark

In [None]:
# Same removal attack as the previous cell, but with the edge-based cleaning
# of the stolen watermark enabled.
cleaned_imgs_cin = wmforger.remove_watermark(
    imgs_model,
    num_steps=500,
    lr=0.05,
    clean_wm=True,
)
decoded_msgs_cleaned = cin._decode_(cleaned_imgs_cin)

# Per-image fraction of decoded bits that still equal the embedded message.
bitacc = []
for i in range(len(imgs_model)):
    bit_matches = msgs[i].numpy() == decoded_msgs_cleaned[i]
    bitacc.append(np.mean(bit_matches.astype(np.float32)))
np.mean(bitacc)


0.10333333