In [3]:
import argparse
import wandb
import copy
from tqdm import tqdm
from statistics import mean, stdev
from sklearn import metrics

import torch

from inverse_stable_diffusion import InversableStableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler
import open_clip
from optim_utils import *
from io_utils import *


def save_intermediate_results(obj: torch.Tensor, index_of_step: int, filenname: str, dirpath: str = "OUT/intermediate_results", suffix: str = '.png'):
    """
    Save intermediate results of Tree-Ring watermarking as grayscale images.

    Only the first batch entry (obj[0]) is written, one image per channel.
    Each channel is min-max normalised to 0-255 independently of the others
    and saved as
        <dirpath>/<filenname.format(index_of_step)>_<channel index><suffix>
    e.g. OUT/1___init_latents_no_w_0.png

    Though B should be 1 here, because somehow Tree-Ring throws errors for B > 1

    @param obj: torch.Tensor of shape (B, C, H, W). Complex-valued tensors
                (e.g. an FFT of the latents) are reduced to their magnitude
                before normalisation.
    @param index_of_step: number substituted into the `{}` placeholder of `filenname`
    @param filenname: format string for the file stem (the spelling is kept
                      because callers pass it by keyword)
    @param dirpath: output directory, created if it does not exist
    @param suffix: file extension appended to every file name
    """
    os.makedirs(dirpath, exist_ok=True)

    sample = obj[0].detach().cpu()  # (C, H, W)
    if torch.is_complex(sample):
        # A complex value has no natural 0-255 mapping; casting it to uint8
        # would silently drop the imaginary part, so visualise the magnitude.
        sample = sample.abs()
    # Uniform float32 so half-precision / integer inputs normalise the same way.
    sample = sample.to(torch.float32)

    stem = filenname.format(index_of_step)
    for i, channel in enumerate(sample):
        lowest = channel.min()
        span = channel.max() - lowest
        if span > 0:
            scaled = (channel - lowest) / span * 255  # Normalize to 0-255
        else:
            # Constant channel: (x - min) / 0 would produce NaNs, write black instead.
            scaled = torch.zeros_like(channel)
        img = Image.fromarray(scaled.numpy().astype('uint8'), 'L')
        img.save(os.path.join(dirpath, stem + f'_{i}' + suffix))  # e.g. 1___init_latents_no_w_0.png

In [5]:
def _build_parser():
    """Assemble the command-line interface of the watermark experiment.

    Options are declared as (flag, default, type) rows and registered in
    declaration order; a type of None leaves the value as the raw string.
    """
    cli = argparse.ArgumentParser(description='diffusion watermark')

    def _register(rows):
        for flag, default, caster in rows:
            cli.add_argument(flag, default=default, type=caster)

    # run / generation settings
    _register((
        ('--run_name', 'test', None),
        ('--dataset', 'Gustavosta/Stable-Diffusion-Prompts', None),
        ('--start', 0, int),
        ('--end', 10, int),
        ('--image_length', 512, int),
        ('--model_id', 'stabilityai/stable-diffusion-2-1-base', None),
    ))
    # the only boolean switch; registered here to keep the original option order
    cli.add_argument('--with_tracking', action='store_true')
    _register((
        ('--num_images', 1, int),
        ('--guidance_scale', 7.5, float),
        ('--num_inference_steps', 50, int),
        ('--test_num_inference_steps', None, int),
        ('--reference_model', None, None),
        ('--reference_model_pretrain', None, None),
        ('--max_num_log_image', 100, int),
        ('--gen_seed', 0, int),
    ))

    # watermark
    _register((
        ('--w_seed', 999999, int),
        ('--w_channel', 0, int),
        ('--w_pattern', 'rand', None),
        ('--w_mask_shape', 'circle', None),
        ('--w_radius', 10, int),
        ('--w_measurement', 'l1_complex', None),
        ('--w_injection', 'complex', None),
        ('--w_pattern_const', 0, float),
    ))

    # for image distortion
    _register((
        ('--r_degree', None, float),
        ('--jpeg_ratio', None, int),
        ('--crop_scale', None, float),
        ('--crop_ratio', None, float),
        ('--gaussian_blur_r', None, int),
        ('--gaussian_std', None, float),
        ('--brightness_factor', None, float),
        ('--rand_aug', 0, int),
    ))
    return cli


parser = _build_parser()

# parse_known_args tolerates arguments this parser does not declare;
# they end up in `unknown` instead of aborting the run.
args, unknown = parser.parse_known_args()

# Detection uses the generation step count unless one is given explicitly.
args.test_num_inference_steps = (
    args.num_inference_steps
    if args.test_num_inference_steps is None
    else args.test_num_inference_steps
)

In [None]:
# Optional experiment tracking: without --with_tracking no wandb run is started
# and `table` stays None.
if args.with_tracking:
    wandb.init(project='diffusion_watermark', name=args.run_name, tags=['tree_ring_watermark'])
    wandb.config.update(args)
    table = wandb.Table(columns=['gen_no_w', 'no_w_clip_score', 'gen_w', 'w_clip_score', 'prompt', 'no_w_metric', 'w_metric'])
else:
    table = None

# load diffusion model (GPU when one is visible, otherwise CPU)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

scheduler = DPMSolverMultistepScheduler.from_pretrained(args.model_id, subfolder='scheduler')
pipe = InversableStableDiffusionPipeline.from_pretrained(
    args.model_id,
    scheduler=scheduler,
    # weights are requested in float32 here; torch.float16 is the alternative
    torch_dtype=torch.float32,
    revision='fp16',
).to(device)

# reference model: only loaded when one is requested on the command line
if args.reference_model is not None:
    ref_model, _, ref_clip_preprocess = open_clip.create_model_and_transforms(
        args.reference_model,
        pretrained=args.reference_model_pretrain,
        device=device,
    )
    ref_tokenizer = open_clip.get_tokenizer(args.reference_model)

# dataset
dataset, prompt_key = get_dataset(args)

# assume at the detection time, the original prompt is unknown,
# so the inversion is conditioned on the empty prompt
tester_prompt = ''
text_embeddings = pipe.get_text_embedding(tester_prompt)

# ground-truth patch
gt_patch = get_watermarking_pattern(pipe, args, device)

# accumulators filled by the per-prompt loop
results, clip_scores, clip_scores_w, no_w_metrics, w_metrics = [], [], [], [], []

# Per-prompt pipeline: generate an image without and with the watermark from the
# same initial latents, optionally distort both, invert them back to latents and
# score each against the ground-truth patch.
#
# NOTE(review): every debug dump below is written to a fixed path under OUT/, so
# each iteration overwrites the files of the previous one; after the loop only
# the last sample's intermediates remain on disk.
for i in tqdm(range(args.start, args.end)):
    # seed derived from the sample index, offset by --gen_seed
    seed = i + args.gen_seed

    current_prompt = dataset[i][prompt_key]

    ### generation
    # generation without watermarking
    set_random_seed(seed)
    init_latents_no_w = pipe.get_random_latents()
    # and-mill: debug dump, step 1 (this call also creates the OUT directory that
    # the plain `.save("OUT/...")` calls further down rely on)
    save_intermediate_results(init_latents_no_w, index_of_step=1, dirpath='OUT', filenname='{}___init_latents_no_w')
    outputs_no_w = pipe(
        current_prompt,
        num_images_per_prompt=args.num_images,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_inference_steps,
        height=args.image_length,
        width=args.image_length,
        latents=init_latents_no_w,
        )
    orig_image_no_w = outputs_no_w.images[0]
    # and-mill: debug dump, step 2
    orig_image_no_w.save("OUT/2___orig_image_no_w.png")

    # generation with watermarking
    # Start from an independent copy of the very same latents. (The former
    # `init_latents_no_w is None` fallback was unreachable: the latents have
    # already been indexed by the step-1 dump above.)
    init_latents_w = copy.deepcopy(init_latents_no_w)
    # and-mill: debug dump, step 3
    save_intermediate_results(init_latents_w, index_of_step=3, dirpath='OUT', filenname='{}___init_latents_w_BEFORE_RING')

    # get watermarking mask
    watermarking_mask = get_watermarking_mask(init_latents_w, args, device)

    # inject watermark
    # and-mill: return_fft=True additionally returns a second tensor, which is
    # dumped below as "AFTER_RING_FFT"
    init_latents_w, init_latens_w_fft = inject_watermark(init_latents_w, watermarking_mask, gt_patch, args,
                                                         return_fft=True
                                                         )
    # and-mill: debug dump, step 4 (latents and their FFT)
    save_intermediate_results(init_latents_w, index_of_step=4, dirpath='OUT', filenname='{}___init_latents_w_AFTER_RING')
    save_intermediate_results(init_latens_w_fft, index_of_step=4, dirpath='OUT', filenname='{}___init_latents_w_AFTER_RING_FFT')

    outputs_w = pipe(
        current_prompt,
        num_images_per_prompt=args.num_images,
        guidance_scale=args.guidance_scale,
        num_inference_steps=args.num_inference_steps,
        height=args.image_length,
        width=args.image_length,
        latents=init_latents_w,
        )
    orig_image_w = outputs_w.images[0]
    # and-mill: debug dump, step 5
    orig_image_w.save("OUT/5___orig_image_w.png")

    ### test watermark
    # distortion
    orig_image_no_w_auged, orig_image_w_auged = image_distortion(orig_image_no_w, orig_image_w, seed, args)
    # and-mill: debug dump, steps 6 and 7
    orig_image_no_w_auged.save("OUT/6___orig_image_no_w_auged.png")
    orig_image_w_auged.save("OUT/7___orig_image_w_auged.png")

    # reverse img without watermarking
    img_no_w = transform_img(orig_image_no_w_auged).unsqueeze(0).to(text_embeddings.dtype).to(device)
    # and-mill: debug dump, step 8
    # NOTE(review): this writes the same PIL image as step 6, not the
    # transformed tensor `img_no_w` that the file name suggests.
    orig_image_no_w_auged.save("OUT/8___img_no_w.png")
    image_latents_no_w = pipe.get_image_latents(img_no_w, sample=False)

    reversed_latents_no_w = pipe.forward_diffusion(
        latents=image_latents_no_w,
        text_embeddings=text_embeddings,
        guidance_scale=1,
        num_inference_steps=args.test_num_inference_steps,
    )

    # reverse img with watermarking
    img_w = transform_img(orig_image_w_auged).unsqueeze(0).to(text_embeddings.dtype).to(device)
    image_latents_w = pipe.get_image_latents(img_w, sample=False)
    # and-mill: debug dump, step 9
    save_intermediate_results(image_latents_w, index_of_step=9, dirpath='OUT', filenname='{}___image_latents_w')

    reversed_latents_w = pipe.forward_diffusion(
        latents=image_latents_w,
        text_embeddings=text_embeddings,
        guidance_scale=1,
        num_inference_steps=args.test_num_inference_steps,
    )
    # and-mill: debug dump, step 10
    save_intermediate_results(reversed_latents_w, index_of_step=10, dirpath='OUT', filenname='{}___reversed_latents_w')

    # eval
    no_w_metric, w_metric = eval_watermark(reversed_latents_no_w, reversed_latents_w, watermarking_mask, gt_patch, args)

    # similarity of each generated image to its prompt; 0 when no reference model is configured
    if args.reference_model is not None:
        sims = measure_similarity([orig_image_no_w, orig_image_w], current_prompt, ref_model, ref_clip_preprocess, ref_tokenizer, device)
        w_no_sim = sims[0].item()
        w_sim = sims[1].item()
    else:
        w_no_sim = 0
        w_sim = 0

    results.append({
        'no_w_metric': no_w_metric, 'w_metric': w_metric, 'w_no_sim': w_no_sim, 'w_sim': w_sim,
    })

    # Stored negated; presumably the metric is a distance (lower = closer to the
    # patch) so that a higher score means "watermarked" for the ROC — TODO confirm.
    no_w_metrics.append(-no_w_metric)
    w_metrics.append(-w_metric)

    # Collected unconditionally: the summary after the loop prints the mean of
    # these lists even when wandb tracking is off, and mean() of an empty list
    # raises StatisticsError.
    clip_scores.append(w_no_sim)
    clip_scores_w.append(w_sim)

    if args.with_tracking:
        if (args.reference_model is not None) and (i < args.max_num_log_image):
            # log images when we use reference_model
            table.add_data(wandb.Image(orig_image_no_w), w_no_sim, wandb.Image(orig_image_w), w_sim, current_prompt, no_w_metric, w_metric)
        else:
            table.add_data(None, w_no_sim, None, w_sim, current_prompt, no_w_metric, w_metric)

# roc
# Scores of the non-watermarked images are labelled 0, those of the watermarked images 1.
preds = no_w_metrics +  w_metrics
t_labels = [0] * len(no_w_metrics) + [1] * len(w_metrics)

fpr, tpr, thresholds = metrics.roc_curve(t_labels, preds, pos_label=1)
auc = metrics.auc(fpr, tpr)
# best balanced accuracy over all thresholds: 1 - (FPR + FNR) / 2 == (TNR + TPR) / 2
acc = np.max(1 - (fpr + (1 - tpr))/2)
# TPR at the last ROC point whose FPR is still below 1%
low = tpr[np.where(fpr<.01)[0][-1]]

# statistics.mean needs at least one sample and statistics.stdev at least two;
# report NaN instead of raising when the run was too short to provide them.
clip_score_mean = mean(clip_scores) if clip_scores else float('nan')
clip_score_std = stdev(clip_scores) if len(clip_scores) > 1 else float('nan')
w_clip_score_mean = mean(clip_scores_w) if clip_scores_w else float('nan')
w_clip_score_std = stdev(clip_scores_w) if len(clip_scores_w) > 1 else float('nan')

if args.with_tracking:
    wandb.log({'Table': table})
    wandb.log({'clip_score_mean': clip_score_mean, 'clip_score_std': clip_score_std,
               'w_clip_score_mean': w_clip_score_mean, 'w_clip_score_std': w_clip_score_std,
               'auc': auc, 'acc':acc, 'TPR@1%FPR': low})

print(f'clip_score_mean: {clip_score_mean}')
print(f'w_clip_score_mean: {w_clip_score_mean}')
print(f'auc: {auc}, acc: {acc}, TPR@1%FPR: {low}')