In [1]:
import argparse
from tqdm.auto import tqdm
import torch
import matplotlib.pyplot as plt
from utils import *
import random
import numpy as np
import math
import os 
import scipy
import torch.nn as nn
from modified_stable_diffusion import ModifiedStableDiffusionPipeline
import PIL
from PIL import Image, ImageFilter,ImageEnhance
import commpy.utilities as util
import cv2
from bm3d import bm3d_rgb

def build_arg_parser():
    """Build the command-line parser for the watermark generation script.

    Every distortion option (brightness, noise, contrast, hue, blur, JPEG
    compression, resize, BM3D) defaults to ``None``, which the generation
    loop treats as "do not apply this distortion".

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description='diffusion watermark')
    parser.add_argument('--w_seed', default=0, type=int)
    # Other dataset names used previously:
    # 'Gustavosta/Stable-Diffusion-Prompts', 'stablediffusionDB'.
    parser.add_argument('--dataset', default='coco')
    # Alternative checkpoint used previously: '../stable-diffusion-v1-4'.
    parser.add_argument('--model_path', default='../stable-diffusion-2-1-base')
    parser.add_argument('--image_length', default=512, type=int)
    parser.add_argument('--secret_length', default=48, type=int)
    parser.add_argument('--num_inference_steps', default=25, type=int)
    parser.add_argument('--guidancescale', default=5, type=float)
    parser.add_argument('--reverse_inference_steps', default=25, type=int)
    # Alternative encoder weights used previously:
    # './encoder_decoder_pretrain/model48bit.pth', './model48bit_finetuned.pth'.
    parser.add_argument('--model', default='./model48bit_finetuned_backup.pth', type=str)
    # The historical spelling '--birghtness' (and the attribute
    # ``args.birghtness``) is kept for backward compatibility. The correctly
    # spelt alias is accepted too, because parse_known_args() would otherwise
    # silently discard '--brightness' and run without the distortion.
    parser.add_argument('--birghtness', '--brightness', dest='birghtness',
                        default=None, type=float, choices=[1, 2, 3, 4, 5])
    parser.add_argument('--noise', default=None, type=float, choices=[0.01, 0.05])
    parser.add_argument('--contrast', default=None, type=float, choices=[1, 2, 3, 4, 5])
    parser.add_argument('--hue', default=None, type=float, choices=[0.25, 2])
    parser.add_argument('--blur', default=None, type=int, choices=[1, 3, 5])
    parser.add_argument('--jpegcompression', default=None, type=int, choices=[40, 50])
    parser.add_argument('--resize', default=None, type=float, choices=[0.4, 0.8])
    parser.add_argument('--bm3d', default=None, type=float, choices=[10, 20])
    return parser


if __name__ == '__main__':
    parser = build_arg_parser()
    # parse_known_args() tolerates arguments this parser does not define
    # (e.g. the ones a notebook kernel injects into sys.argv).
    args = parser.parse_known_args()[0]
    # The device mask must be in the environment *before* the first CUDA
    # query; setting it after torch.cuda.is_available() comes too late to
    # take effect. setdefault() leaves a mask chosen by the caller untouched.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.set_printoptions(sci_mode=False, profile='full')
    # Prompts are truncated to this many characters before generation.
    maxlength = 150
    
# dataset
# get_dataset / promptdataset are not imported by name in this file;
# presumably they come from `from utils import *` - verify.
dataset, prompt_key = get_dataset(args)
dataset = promptdataset(dataset, prompt_key)

# model
# Stable Diffusion pipeline loaded in half precision with a DPM-Solver
# multistep scheduler read from the checkpoint's 'scheduler' subfolder.
scheduler = DPMSolverMultistepScheduler.from_pretrained(args.model_path, subfolder='scheduler')
pipe = ModifiedStableDiffusionPipeline.from_pretrained(
    args.model_path,
    scheduler=scheduler,
    torch_dtype=torch.float16,
    revision='fp16',
)
pipe = pipe.to(device)

# diffusetrace
from encoder_decoder_pretrain.watermark_model import *
encoder = Watermark(secret_length=args.secret_length).to(device)
if args.model is not None:
    # map_location remaps the stored tensors onto the device chosen at
    # start-up, so a checkpoint saved on a GPU also loads on a CPU-only host.
    encoder.load_state_dict(torch.load(args.model, map_location=device))
# Inference only: switch the encoder to evaluation mode.
encoder.eval()

# Random binary message of `secret_length` bits carried by every image below.
secret = np.random.choice([0, 1], size=(args.secret_length))
Metric, Loss, Error_correct = [], [], []
Secret = []
Secret.append(secret)

# Create the output directory up front so neither img1.save() nor np.save()
# fails with FileNotFoundError on a fresh checkout.
os.makedirs('./exp/img', exist_ok=True)

# The message tensor does not change between iterations, so build it once:
# (secret_length,) -> (1, secret_length, 1, 1) -> (1, secret_length, 64, 64).
secret_tmp = torch.Tensor(secret).unsqueeze(-1).unsqueeze(-1).unsqueeze(0).to(device)
secret_tmp = secret_tmp.expand(-1, -1, 64, 64)

for t in tqdm(range(4)):
    # The earlier version also drew `get_random_latents(pipe, args)` here and
    # cloned it, but both values were overwritten before use, so that dead
    # work is dropped.

    # Encode the message and draw the initial latents with the
    # reparameterisation step (eps * std + mean). No gradients are needed
    # for generation, so autograd is switched off.
    with torch.no_grad():
        matrix1, mean, logvar = encoder(secret_tmp)
        mean = mean.reshape(-1, 4, 64, 64)
        logvar = logvar.reshape(-1, 4, 64, 64)
        eps = torch.randn_like(logvar)
        std = torch.exp(logvar / 2)
        matrix = eps * std + mean
    # The pipeline was loaded with torch_dtype=torch.float16.
    init_latents = matrix.half()

    # random.randrange(n) yields 0..n-1. The previous
    # random.randint(1, len(dataset)) is inclusive at both ends, so it could
    # index one past the end and could never select the first prompt.
    prompt = dataset[random.randrange(len(dataset))][0:maxlength]
    print(f"current prompt: {prompt}")
    img1 = pipe(
        prompt=prompt,
        num_inference_steps=args.num_inference_steps,
        latents=init_latents,
        guidance_scale=args.guidancescale,
    ).images[0]

    # Optional distortions; each one runs only when its CLI option was given.
    if args.birghtness is not None:
        img1 = transforms.ColorJitter(brightness=args.birghtness)(img1)
    if args.noise is not None:
        pixels = np.array(img1, dtype=np.uint8).astype(np.float32)
        # Keep the noise in floating point. Casting N(0, 1) samples to uint8
        # before scaling truncates them toward zero and mangles the negative
        # half, so the result was not Gaussian noise at all.
        # NOTE(review): --noise (0.01 / 0.05) is interpreted here as a
        # standard deviation relative to the full 0-255 intensity range -
        # confirm against the intended evaluation protocol.
        g_noise = np.random.randn(*pixels.shape) * args.noise * 255
        noisy_array = np.clip(pixels + g_noise, 0, 255).astype(np.uint8)
        img1 = Image.fromarray(noisy_array)
    if args.contrast is not None:
        enhancer = ImageEnhance.Contrast(img1)
        factor = args.contrast
        img1 = enhancer.enhance(factor)
    if args.hue is not None:
        # NOTE(review): ImageEnhance.Color adjusts colour balance/saturation
        # rather than rotating hue - confirm this matches the option's intent.
        enhancer = ImageEnhance.Color(img1)
        factor = args.hue
        img1 = enhancer.enhance(factor)
    if args.jpegcompression is not None:
        img1 = compress_jpeg_to_pil(img1, args.jpegcompression)
    if args.blur is not None:
        img1 = Image.fromarray(cv2.GaussianBlur(np.array(img1), (args.blur, args.blur), 1))
    if args.resize is not None:
        side = int(args.image_length * args.resize)
        img1 = img1.resize((side, side), PIL.Image.BICUBIC)
    if args.bm3d is not None:
        rgb_array = np.array(img1)
        denoised = bm3d_rgb(rgb_array, sigma_psd=args.bm3d)
        # Clip before the cast: values outside [0, 255] would otherwise wrap
        # around when converted to uint8.
        img1 = Image.fromarray(np.clip(denoised, 0, 255).astype(np.uint8))

    img1.save(f'./exp/img/{t}.png')

# Persist the embedded message next to the generated images.
np.save('exp/secret.npy', secret)

    

    


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

Watermark(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(48, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(96, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (mean_linear): Linear(in_features=4096, out_features=16384, bias=True)
  (var_linear): Linear(in_feat

  0%|          | 0/4 [00:00<?, ?it/s]

current prompt: Elderly women debark a bus at a station. 
current prompt: Man and woman in the kitchen of a white house.
current prompt: A woman is sitting with a suitcase on some train tracks.
current prompt: Men looking at something near a boat and a truck with the hood raised.
