# In [None]:
import argparse
import os

import torch
from diffusers import StableDiffusionPipeline
from torchvision import transforms
from PIL import Image

def load_image(p, size=(512, 512)):
    """Load an image from disk as an RGB PIL image resized to ``size``.

    Args:
        p: Path (or file object) accepted by ``PIL.Image.open``.
        size: Target ``(width, height)`` passed to ``Image.resize``.
            Defaults to ``(512, 512)``, the size the original code hard-coded.

    Returns:
        A new ``PIL.Image.Image`` in mode ``RGB`` with the requested size.
    """
    # Image.open is lazy and keeps the file open; the context manager makes
    # sure the handle is released once the pixel data has been copied out by
    # convert()/resize(), both of which return independent images.
    with Image.open(p) as img:
        return img.convert('RGB').resize(size)

# Command-line configuration for the inference script.
parser = argparse.ArgumentParser(description="Inference")
parser.add_argument(
    "--model_path",
    type=str,
    default='./dreambooth-outputs/concept-n161/checkpoint-1000/',
    help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
    "--img_dir",
    type=str,
    default="./wm_images/images",
    # The original help text was a copy of the --output_dir one; this option
    # is the directory the .png inputs are read from.
    help="The input directory containing the .png images to load",
)
parser.add_argument(
    "--output_dir",
    type=str,
    default="./wm_images/wm_images",
    help="The output directory where predictions are saved",
)

# parse_known_args() returns (namespace, leftover_args); unknown arguments are
# ignored rather than rejected (presumably so the code also runs inside a
# notebook kernel that injects its own argv -- confirm).
args = parser.parse_known_args()[0]

# Collect every .png in the input directory. os.listdir() returns entries in
# arbitrary order, so sort the names to make the order of `images` (and its
# correspondence with `image_files`) reproducible across runs and machines.
image_files = sorted(f for f in os.listdir(args.img_dir) if f.endswith('.png'))
images = [
    load_image(os.path.join(args.img_dir, image_file))
    for image_file in image_files
]

import torch
import sys
import torch.nn.functional as F
sys.path.append("../")
from utils.models import SecretEncoder,SecretDecoder
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    UNet2DConditionModel,
)
from PIL import Image
import random
import numpy as np
def torch_to_pil(images):
    """Convert a tensor of images in the [-1, 1] range to a list of PIL images.

    Args:
        images: Tensor laid out channels-first, either a single image with
            three dimensions or a batch with four. A 3-D input is treated as
            a batch of one.

    Returns:
        A list with one ``PIL.Image.Image`` per batch element. Single-channel
        inputs produce mode ``"L"`` images; other inputs are handed to
        ``Image.fromarray`` unchanged.
    """
    images = images.detach().cpu().float()
    if images.ndim == 3:
        images = images[None, ...]
    # Channels-first -> channels-last, which is what PIL expects.
    images = images.permute(0, 2, 3, 1)
    # Map [-1, 1] to [0, 1], then to 8-bit values (out-of-range values are clamped).
    images = (images + 1) * 0.5
    images = (images * 255).clamp(0, 255).numpy().astype(np.uint8)
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images: drop ONLY the
        # channel axis. A bare squeeze() would also remove a size-1 height or
        # width and hand PIL a 1-D (or 0-D) array.
        pil_images = [Image.fromarray(image[..., 0], mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images

def decode_latents(latents):
    """Run ``latents`` through the module-level ``vae`` decoder.

    The latents are passed to ``vae.decode`` as-is: the
    ``1 / vae.config.scaling_factor`` rescaling step is disabled in this
    script, so callers must supply latents that need no rescaling.

    Returns:
        The ``.sample`` attribute of the object returned by ``vae.decode``.
    """
    decoder_output = vae.decode(latents)
    return decoder_output.sample

def generate_random_mask(batch_size, height=512, width=512, area_ratio_range=(0.2, 0.2)):
    """Create binary masks that each contain one random axis-aligned rectangle.

    For every batch element a rectangle covering roughly ``area_ratio`` of the
    image is placed at a random position; pixels inside it are 1, all others 0.
    The same rectangle is written to all three channels.

    Args:
        batch_size: Number of masks to generate.
        height: Mask height in pixels (must be >= 1).
        width: Mask width in pixels (must be >= 1).
        area_ratio_range: ``(low, high)`` bounds from which the rectangle's
            share of the image area is drawn uniformly. Must satisfy
            ``0 < low <= high <= 1``. The default ``(0.2, 0.2)`` reproduces
            the previously hard-coded fixed ratio of 0.2.

    Returns:
        A float tensor of shape ``(batch_size, 3, height, width)`` with
        values in ``{0, 1}``.

    Raises:
        ValueError: If the size or the area ratio range is invalid.
    """
    if height < 1 or width < 1:
        raise ValueError("height and width must be positive.")
    low, high = area_ratio_range
    if not 0 < low <= high <= 1:
        # A ratio above 1 can never fit and would make the sampling loop spin forever.
        raise ValueError("area_ratio_range must satisfy 0 < low <= high <= 1.")

    mask = torch.zeros(batch_size, 3, height, width)
    total_area = height * width
    for b in range(batch_size):
        area_ratio = random.uniform(low, high)
        # Keep the target within [1, total_area] so a fitting rectangle always exists.
        target_area = min(total_area, max(1, int(total_area * area_ratio)))
        # Rejection-sample a height whose matching width fits in the image.
        # Requiring mask_width >= 1 prevents an empty rectangle when the
        # sampled height exceeds target_area (possible for narrow images).
        while True:
            mask_height = random.randint(1, height)
            mask_width = target_area // mask_height
            if 1 <= mask_width <= width:
                break
        start_x = random.randint(0, width - mask_width)
        start_y = random.randint(0, height - mask_height)
        mask[b, :, start_y:start_y + mask_height, start_x:start_x + mask_width] = 1
    return mask

# Build the secret encoder/decoder pair (48 message bits) on the GPU and
# restore their weights from the checkpoint.
sec_encoder = SecretEncoder(48).cuda()
sec_decoder = SecretDecoder(output_size=48).cuda()
# map_location='cpu' lets the checkpoint load regardless of the device it was
# saved from (e.g. a GPU index that does not exist here); load_state_dict()
# then copies the values into the modules' existing CUDA parameters.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints,
# and consider weights_only=True if the installed torch version supports it.
models = torch.load('./checkpoints_48bit/checkpoints/state_dict_46.pth', map_location='cpu')
sec_encoder.load_state_dict(models['sec_encoder'])
sec_decoder.load_state_dict(models['sec_decoder'])
# Inference only: switch both modules to evaluation mode.
sec_encoder.eval()
sec_decoder.eval()
from torchvision import transforms
# Preprocessing pipeline applied to the PIL input image before it is sent to
# the GPU: resize the shorter side to 512, centre-crop to 512x512, apply a
# random horizontal flip, convert to a tensor and normalise with mean 0.5 and
# std 0.5 (the inverse of the ``(x + 1) * 0.5`` step in ``torch_to_pil``).
# NOTE(review): the random flip makes the input non-deterministic between
# runs -- confirm this is wanted outside of training.
train_transforms = transforms.Compose([
    transforms.Resize(size=512, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(size=512),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

def _embed_and_score(source_image, mask, num_bits=48):
    """Watermark ``source_image`` with a random message and report bit accuracy.

    The image is encoded with the module-level ``vae``, a random ``num_bits``
    binary message is embedded by ``sec_encoder``, the watermarked latent is
    decoded back to an image (which is displayed), and ``sec_decoder`` is run
    on that image multiplied by ``mask``. The fraction of correctly recovered
    bits is printed and returned.

    Args:
        source_image: Batched image tensor to feed to ``vae.encode``.
        mask: Tensor multiplied element-wise with the watermarked image
            before decoding the message.
        num_bits: Length of the random binary message.

    Returns:
        The bit accuracy as a 0-d tensor.
    """
    latents = vae.encode(source_image).latent_dist.sample().detach()
    msg_val = torch.randint(0, 2, (1, num_bits)).cuda()
    watermarked_latent, _ = sec_encoder(latents, msg_val.float())
    watermarked = decode_latents(watermarked_latent)
    torch_to_pil(watermarked)[0].show()

    decoded_msg = sec_decoder(watermarked * mask)
    # assumes the decoder output is (batch, bits, classes) -- TODO confirm
    decoded_msg = torch.argmax(decoded_msg, dim=2)
    acc = 1 - torch.abs(decoded_msg - msg_val).sum().float() / msg_val.numel()
    print(f"acc {acc}")
    return acc


with torch.no_grad():
    # Kept at module level: decode_latents() reads ``vae`` as a global.
    vae = AutoencoderKL.from_pretrained(
        '../stable-diffusion-2-1-base/', subfolder="vae").cuda()

    # Force three channels: an RGBA or grayscale PNG would otherwise not
    # match the 3-channel mask in the multiplication below.
    image = Image.open('../1.png').convert('RGB')
    image = train_transforms(image).cuda()

    # Broadcasting the (3, H, W) image against the (1, 3, H, W) mask yields a
    # batched image in which everything outside the rectangle is zeroed.
    mask_tensor = generate_random_mask(1).cuda()
    masked_image = image * mask_tensor

    # Bounding box of the mask rectangle, taken from its first channel.
    y_indices, x_indices = torch.where(mask_tensor[0][0] > 0)
    if y_indices.numel() == 0:
        raise ValueError("Mask must have at least one non-zero element.")
    y_min, y_max = y_indices.min(), y_indices.max()
    x_min, x_max = x_indices.min(), x_indices.max()

    # Crop the unmasked image to that box and stretch the crop back to 512x512.
    cropped_image = image[:, y_min:y_max + 1, x_min:x_max + 1]
    target_size = (512, 512)
    scaled_image = F.interpolate(
        cropped_image.unsqueeze(0), size=target_size, mode='bilinear', align_corners=False)
    print(scaled_image.shape)

    # Show both inputs before watermarking.
    torch_to_pil(masked_image)[0].show()
    torch_to_pil(scaled_image)[0].show()

    # Experiment 1: watermark the cropped-and-rescaled region.
    _embed_and_score(scaled_image, mask_tensor)
    # Experiment 2: watermark the masked full-size image.
    _embed_and_score(masked_image, mask_tensor)

    # sys.exit() rather than the interactive-only ``exit`` helper from ``site``.
    sys.exit()