In [1]:
import argparse
from tqdm.auto import tqdm
import torch
import matplotlib.pyplot as plt
from utils import *
import random
import numpy as np
import math
import os 
import scipy
import torch.nn as nn
from modified_stable_diffusion import ModifiedStableDiffusionPipeline
import PIL
from PIL import Image, ImageFilter,ImageEnhance
import commpy.utilities as util
import cv2
from bm3d import bm3d_rgb

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='diffusion watermark')
    parser.add_argument('--w_seed', default=0, type=int)
    # parser.add_argument('--dataset', default='Gustavosta/Stable-Diffusion-Prompts')
    parser.add_argument('--dataset', default='coco')
    # parser.add_argument('--dataset', default='stablediffusionDB')
    parser.add_argument('--model_path', default='../stable-diffusion-2-1-base')
    # parser.add_argument('--model_path', default='../stable-diffusion-v1-4')
    parser.add_argument('--image_length', default=512, type=int)
    parser.add_argument('--secret_length', default=48, type=int)
    parser.add_argument('--num_inference_steps', default=25, type=int)
    parser.add_argument('--guidancescale', default=5, type=float)
    parser.add_argument('--reverse_inference_steps', default=25, type=int)
    # parser.add_argument('--model', default='./encoder_decoder_pretrain/model48bit.pth', type=str)
    # parser.add_argument('--model', default='./model48bit_finetuned.pth', type=str)
    parser.add_argument('--model', default='./model48bit_finetuned_backup.pth', type=str)
    parser.add_argument('--birghtness', default=None, type=float,choices=[1,2,3,4,5])
    parser.add_argument('--noise', default=None, type=float,choices=[0.01,0.05])
    parser.add_argument('--contrast', default=None, type=float,choices=[1,2,3,4,5])
    parser.add_argument('--hue', default=None, type=float,choices=[0.25,2])
    parser.add_argument('--blur', default=None, type=int,choices=[1,3,5])
    parser.add_argument('--jpegcompression', default=None, type=int,choices=[40,50])
    parser.add_argument('--resize', default=None, type=float,choices=[0.4,0.8])
    parser.add_argument('--bm3d', default=30, type=float,choices=[10,20])
    args =parser.parse_known_args()[0]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.set_printoptions(sci_mode=False,profile='full')
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    maxlength=150
    
# dataset
dataset, prompt_key = get_dataset(args)
dataset=promptdataset(dataset,prompt_key)

#model
scheduler = DPMSolverMultistepScheduler.from_pretrained(args.model_path, subfolder='scheduler')
pipe = ModifiedStableDiffusionPipeline.from_pretrained(
        args.model_path,
        scheduler=scheduler,
        torch_dtype=torch.float16,
        revision='fp16',
        )
pipe = pipe.to(device)

#diffusetrace
from encoder_decoder_pretrain.watermark_model import *
encoder=Watermark(secret_length=args.secret_length).to(device)
if args.model !=None:
    encoder.load_state_dict(torch.load(args.model))
encoder.eval()

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

Watermark(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(48, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(96, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (mean_linear): Linear(in_features=4096, out_features=16384, bias=True)
  (var_linear): Linear(in_feat

Watermark(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(48, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (1): Sequential(
      (0): Conv2d(48, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (2): Sequential(
      (0): Conv2d(96, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
    )
  )
  (mean_linear): Linear(in_features=4096, out_features=16384, bias=True)
  (var_linear): Linear(in_feat

In [2]:
watermarked_dataset_dir='./exp/img'
watermarked_img_dir = os.path.join(watermarked_dataset_dir, '') # load watermarked images: around 1.0 bit-accuracy
num_images = len(os.listdir(watermarked_img_dir))-1
# watermarked_images = torch.zeros((num_images, 3, args.image_length, args.image_length))
watermarked_images=[]
for i in range(num_images):
            img_path = os.path.join(watermarked_img_dir, '{}.png'.format(i))
            im = Image.open(img_path)
            # im = torch.from_numpy(np.array(im))
            # im= im.permute(2, 0, 1)
            # im = (im/255)*2-1
            # watermarked_images[i] = im
            watermarked_images.append(im)

Secret = np.load('exp/secret.npy')
print(len(Secret))

10


In [3]:
for t in tqdm(range(num_images)):
    reverse_latents=reverse(watermarked_images[t],pipe,args).float()
    reverse_latents = reverse_latents.view(1, -1)
    x = encoder.decoder_projection(reverse_latents)
    x = torch.reshape(x, (-1, *encoder.decoder_input_chw))
    average_tensor1 = torch.from_numpy(Secret[t]).to(device)
    average_tensor2 = torch.round(torch.mean(encoder.decoder(x), dim=(-2, -1)))
    average_tensor3 = torch.mean(encoder.decoder(x), dim=(-2, -1))
    biterror=torch.sum(abs(average_tensor1-average_tensor2))
    mse_loss = torch.nn.MSELoss()
    loss = mse_loss(average_tensor3, average_tensor1)
    # print(int(biterror.detach().cpu().numpy()))
    print(int(biterror.detach().cpu().numpy()))


  0%|          | 0/10 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


1
1
1
3
1
4
0
2
0
1
