In [1]:
%cd hidden

/raid/home/ashhar21137/watermarking3/stable_signature/hidden


In [3]:
import os 
# Expose only physical GPU 4 to this process. This has to happen before torch
# initializes CUDA, which is why it sits in its own cell ahead of the torch import.
os.environ.update(CUDA_VISIBLE_DEVICES='4')

In [4]:
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # GPU when visible, else CPU

# The imports below only resolve when the working directory is the root of the `hidden`
# project; the `%cd hidden` cell at the top of this notebook sets that up.
from models import HiddenEncoder, HiddenDecoder, EncoderWithJND, EncoderDecoder
from attenuations import JND

In [9]:
import cv2 
from collections import defaultdict
import hashlib
from tqdm import tqdm 
from PIL import Image, ImageEnhance
from skimage.util import random_noise

In [10]:
def msg2str(msg):
    """Render an iterable of bits as a '0'/'1' string (truthy element -> '1')."""
    return ''.join(('0', '1')[bool(bit)] for bit in msg)

def str2msg(str):
    """Parse a bit string into a list of bools: '1' -> True, any other character -> False."""
    return [bool(char == '1') for char in str]

In [11]:
def brightness_attack(img_path, out_path, multi=False, brightness=2):
    """Write a brightness-adjusted copy of an image.

    Args:
        img_path: path of the source image.
        out_path: path the attacked image is saved to (format inferred by Pillow).
        multi: when False and ``out_path`` already exists, the call does nothing;
            when True the output is regenerated and overwritten.
        brightness: factor passed to ``ImageEnhance.Brightness.enhance``
            (1.0 leaves the image unchanged, larger values brighten it).
            Defaults to 2, the previously hard-coded value.
    """
    if os.path.exists(out_path) and not multi:
        return

    # Use a context manager so the file handle opened by Image.open is released.
    with Image.open(img_path) as img:
        brightened = ImageEnhance.Brightness(img).enhance(brightness)
    brightened.save(out_path)


def gaussian_noise_attack(img_path, out_path, multi=False, std=0.1):
    """Write a copy of an image with additive Gaussian noise.

    Args:
        img_path: path of the source image (read with OpenCV).
        out_path: path the noisy image is written to with OpenCV.
        multi: when False and ``out_path`` already exists, the call does nothing;
            when True the output is regenerated and overwritten.
        std: standard deviation of the noise on the [0, 1] intensity scale.
            Defaults to 0.1, the previously hard-coded value.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
        OSError: if OpenCV fails to write ``out_path``.

    Note: no seed is passed to ``random_noise``, so each call draws fresh noise.
    """
    if os.path.exists(out_path) and not multi:
        return

    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None, not by raising.
        raise FileNotFoundError(f"could not read image: {img_path}")

    # random_noise works on floats in [0, 1]; `var` is the variance, hence std ** 2.
    noisy_image = random_noise(image / 255.0, mode='gaussian', var=std ** 2)
    # Defensive clip back to [0, 1] before quantizing.
    noisy_image = np.clip(noisy_image, 0, 1)
    # Round rather than truncate: a plain uint8 cast floors values such as 199.9999 to 199,
    # biasing every pixel downward.
    noisy_image = np.round(noisy_image * 255).astype(np.uint8)

    if not cv2.imwrite(out_path, noisy_image):
        raise OSError(f"could not write image: {out_path}")


def rotate_attack(img_path, out_path, multi=False, degree=30, expand=False, size=(512, 512)):
    """Write a rotated, then resized, copy of an image.

    Args:
        img_path: path of the source image.
        out_path: path the attacked image is saved to (format inferred by Pillow).
        multi: when False and ``out_path`` already exists, the call does nothing;
            when True the output is regenerated and overwritten.
        degree: rotation angle in degrees, passed to ``Image.rotate``. Defaults to 30.
        expand: passed to ``Image.rotate``. With False the canvas keeps its original
            size, so the rotated corners are cropped.
        size: final ``(width, height)`` the rotated image is resized to.
            Defaults to (512, 512), the previously hard-coded value.
    """
    if os.path.exists(out_path) and not multi:
        return

    # Use a context manager so the file handle opened by Image.open is released.
    with Image.open(img_path) as img:
        attacked = img.rotate(degree, expand=expand).resize(size)
    attacked.save(out_path)


def jpeg_attack(img_path, out_path, multi=False, quality=50):
    """Re-encode an image as JPEG at a given quality.

    The output is always JPEG-encoded, regardless of the extension of ``out_path``.

    Args:
        img_path: path of the source image.
        out_path: path the JPEG bytes are written to.
        multi: when False and ``out_path`` already exists, the call does nothing;
            when True the output is regenerated and overwritten.
        quality: JPEG quality passed to Pillow's encoder. Defaults to 50, the
            previously hard-coded value.
    """
    if os.path.exists(out_path) and not multi:
        return

    # Use a context manager so the file handle opened by Image.open is released.
    with Image.open(img_path) as img:
        # Pillow's JPEG writer rejects modes such as RGBA, LA and P. Convert those to RGB
        # (dropping any alpha) instead of failing; JPEG-storable modes pass through unchanged.
        if img.mode in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
            encodable = img
        else:
            encodable = img.convert("RGB")
        encodable.save(out_path, "JPEG", quality=quality)


In [12]:
class Params:
    """Configuration for rebuilding the watermark encoder/decoder and attenuation.

    A plain attribute container: each constructor argument is stored, unchanged,
    under an attribute with the same name.
    """

    def __init__(self, encoder_depth:int, encoder_channels:int, decoder_depth:int, decoder_channels:int, num_bits:int,
                attenuation:str, scale_channels:bool, scaling_i:float, scaling_w:float):
        # Network shape (depth = number of blocks) and message length in bits.
        self.encoder_depth, self.encoder_channels = encoder_depth, encoder_channels
        self.decoder_depth, self.decoder_channels = decoder_depth, decoder_channels
        self.num_bits = num_bits
        # Attenuation / scaling settings consumed when building the JND-wrapped encoder.
        self.attenuation = attenuation
        self.scale_channels = scale_channels
        self.scaling_i, self.scaling_w = scaling_i, scaling_w

# ImageNet normalization and its exact inverse: Normalize(mean=-m/s, std=1/s) maps
# x -> x * s + m, undoing NORMALIZE_IMAGENET channel by channel.
NORMALIZE_IMAGENET = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
UNNORMALIZE_IMAGENET = transforms.Normalize(
    mean=[-m / s for m, s in ((0.485, 0.229), (0.456, 0.224), (0.406, 0.225))],
    std=[1 / s for s in (0.229, 0.224, 0.225)],
)
default_transform = transforms.Compose([transforms.ToTensor(), NORMALIZE_IMAGENET])

params = Params(
    encoder_depth=4, encoder_channels=64,
    decoder_depth=8, decoder_channels=64,
    num_bits=48,
    attenuation="jnd", scale_channels=False,
    scaling_i=1, scaling_w=1.5,
)

# Architectures only; the checkpoint weights are loaded into these objects in the next cell.
decoder = HiddenDecoder(num_blocks=params.decoder_depth, num_bits=params.num_bits, channels=params.decoder_channels)
encoder = HiddenEncoder(num_blocks=params.encoder_depth, num_bits=params.num_bits, channels=params.encoder_channels)

# JND attenuation receives the un-normalize transform as its preprocess step.
if params.attenuation == "jnd":
    attenuation = JND(preprocess=UNNORMALIZE_IMAGENET)
else:
    attenuation = None

# NOTE(review): presumably EncoderWithJND keeps a reference to `encoder`, so weights loaded
# into `encoder` afterwards are used by this wrapper — confirm against models.py.
encoder_with_jnd = EncoderWithJND(
    encoder, attenuation, params.scale_channels, params.scaling_i, params.scaling_w
)

In [13]:
ckpt_path = "ckpts/hidden_replicate.pth"

# NOTE(review): torch.load unpickles the file (no weights_only=True) — only load trusted checkpoints.
state_dict = torch.load(ckpt_path, map_location='cpu')['encoder_decoder']

# Strip the 'module.' tag (presumably left by DataParallel/DDP), then route each key to the
# encoder or decoder dict by substring. str.replace removes every occurrence of a tag,
# not just a leading prefix — kept as-is to match the checkpoint's key layout.
encoder_decoder_state_dict = {
    name.replace('module.', ''): value
    for name, value in state_dict.items()
}
encoder_state_dict = {
    name.replace('encoder.', ''): value
    for name, value in encoder_decoder_state_dict.items()
    if 'encoder' in name
}
decoder_state_dict = {
    name.replace('decoder.', ''): value
    for name, value in encoder_decoder_state_dict.items()
    if 'decoder' in name
}

encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)

# Move both networks to the selected device and put them in evaluation mode.
encoder_with_jnd = encoder_with_jnd.to(device).eval()
decoder = decoder.to(device).eval()

In [None]:
# Input/output locations and the results file for the cells that follow.
# NOTE(review): the name says "watermarked" but the path points at `original_images` — confirm which set this is.
watermarked_images = "/raid/home/ashhar21137/watermarking2/original_images"
output_dir = "/raid/home/ashhar21137/watermarking3/stable_sig_watermarked"
results_file = "initial_stable_sig_results.json"

# `params` is rebound to a dict here (presumably what later cells index into).
# Take num_bits from the Params object the encoder/decoder were built with, not a second
# literal 48, so the message length cannot drift from the model. The isinstance check
# keeps this cell safe to re-run once `params` is already a dict.
params = {'num_bits': params.num_bits if isinstance(params, Params) else params['num_bits']}
