In [None]:
import subprocess
import os

# Source /etc/network_turbo inside a bash subshell and capture every
# environment line that contains the substring "proxy".
result = subprocess.run(
    'bash -c "source /etc/network_turbo && env | grep proxy"',
    shell=True,
    capture_output=True,
    text=True,
)
output = result.stdout

# Copy each captured NAME=value pair into this process's environment.
# Lines without '=' are ignored; the value keeps any further '=' characters
# because the split happens only at the first one.
os.environ.update(
    entry.split('=', 1)
    for entry in output.splitlines()
    if '=' in entry
)

In [None]:
import argparse
import copy
from tqdm import tqdm
from statistics import mean, stdev
from sklearn import metrics

import torch

from tree_ring.inverse_stable_diffusion import InversableStableDiffusionPipeline
from diffusers import DPMSolverMultistepScheduler
import open_clip
from tree_ring.optim_utils import *
from tree_ring.io_utils import *
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Run configuration, expressed as argparse options so the notebook shares
# its option names with a command-line script.  Each table entry is
# (flag, keyword arguments for parser.add_argument); registration order
# follows the table.
parser = argparse.ArgumentParser(description='diffusion watermark')

option_table = [
    # --- run / generation settings ---
    ('--run_name', dict(default='no_attack')),  # commented-out alternative: 'test'
    ('--dataset', dict(default='./Stable-Diffusion-Prompts')),
    ('--start', dict(default=0, type=int)),
    ('--end', dict(default=1000, type=int)),  # commented-out alternative: 10
    ('--image_length', dict(default=512, type=int)),
    ('--model_id', dict(default='stabilityai/stable-diffusion-2-1-base')),
    ('--with_tracking', dict(action='store_true')),
    ('--num_images', dict(default=1, type=int)),
    ('--guidance_scale', dict(default=7.5, type=float)),
    ('--num_inference_steps', dict(default=50, type=int)),
    ('--test_num_inference_steps', dict(default=None, type=int)),
    ('--reference_model', dict(default=None)),
    ('--reference_model_pretrain', dict(default=None)),
    ('--max_num_log_image', dict(default=100, type=int)),
    ('--gen_seed', dict(default=0, type=int)),

    # --- watermark ---
    ('--w_seed', dict(default=999999, type=int)),
    ('--w_channel', dict(default=3, type=int)),  # commented-out alternative: 0
    ('--w_pattern', dict(default='ring')),  # commented-out alternative: 'rand'
    ('--w_mask_shape', dict(default='circle')),
    ('--w_radius', dict(default=10, type=int)),
    ('--w_measurement', dict(default='l1_complex')),
    ('--w_injection', dict(default='complex')),
    ('--w_pattern_const', dict(default=0, type=float)),

    # --- image distortion (all disabled by default) ---
    ('--r_degree', dict(default=None, type=float)),
    ('--jpeg_ratio', dict(default=None, type=int)),
    ('--crop_scale', dict(default=None, type=float)),
    ('--crop_ratio', dict(default=None, type=float)),
    ('--gaussian_blur_r', dict(default=None, type=int)),
    ('--gaussian_std', dict(default=None, type=float)),
    ('--brightness_factor', dict(default=None, type=float)),

    # --- input frame folders (without / with watermark) ---
    ('--orig_img_no_w_path', dict(default='../datas/videos/text2video-zero/16frames_uniform/', type=str)),
    ('--orig_img_w_path', dict(default='../datas/videos/text2video-zero/16frames_uniform_w/', type=str)),
]

for flag, spec in option_table:
    parser.add_argument(flag, **spec)

# Parse an explicit empty argument list rather than sys.argv, so every
# option takes the default declared above.
args = parser.parse_args([])

# When no separate step count is given for the test pass, reuse the
# generation step count.
if args.test_num_inference_steps is None:
    args.test_num_inference_steps = args.num_inference_steps

In [None]:
# Load the DPM-Solver multistep scheduler from the model's 'scheduler' subfolder.
scheduler = DPMSolverMultistepScheduler.from_pretrained(args.model_id, subfolder='scheduler')

# Build the invertible Stable Diffusion pipeline from the 'fp16' revision with
# float16 weights, then move it onto the selected device.
pipeline_options = dict(
    scheduler=scheduler,
    torch_dtype=torch.float16,
    revision='fp16',
)
pipe = InversableStableDiffusionPipeline.from_pretrained(args.model_id, **pipeline_options).to(device)

In [None]:
# At detection time the original prompt is assumed to be unknown, so the
# empty prompt is embedded instead.
tester_prompt = ''
text_embeddings = pipe.get_text_embedding(tester_prompt)

# Ground-truth watermark patch: read the stored key from disk, then cast it
# to complex32.
stored_key = torch.load('./tree_ring/key.pt')
gt_patch = stored_key.to(dtype=torch.complex32)

# Random latents are handed to get_watermarking_mask together with the run
# options and the device.
# NOTE(review): presumably only their shape matters for the mask - confirm.
init_latents_w = pipe.get_random_latents()
watermarking_mask = get_watermarking_mask(init_latents_w, args, device)

In [None]:
# Video folders under the non-watermarked and watermarked frame roots.
# NOTE(review): folders_w is not read anywhere in the code visible here; the
# loop assumes every video/frame under the non-watermarked root also exists
# under the watermarked root with the same name.
folders_no_w = os.listdir(args.orig_img_no_w_path)
folders_w = os.listdir(args.orig_img_w_path)

# Keyword arguments shared by both forward_diffusion calls (guidance_scale
# is fixed at 1, step count comes from the test-time setting).
inversion_kwargs = dict(
    text_embeddings=text_embeddings,
    guidance_scale=1,
    num_inference_steps=args.test_num_inference_steps,
)
# Image tensors are cast to the dtype of the text embeddings.
target_dtype = text_embeddings.dtype

for video in folders_no_w:
    video_dir_no_w = os.path.join(args.orig_img_no_w_path, video)
    video_dir_w = os.path.join(args.orig_img_w_path, video)
    frames = os.listdir(video_dir_no_w)
    for frame in frames:
        # Same relative frame name in both trees.
        frame_path_no_w = os.path.join(video_dir_no_w, frame)
        frame_path_w = os.path.join(video_dir_w, frame)

        # Open both frames, then turn each into a single-item batch on the device.
        img_no_w, img_w = Image.open(frame_path_no_w), Image.open(frame_path_w)
        img_no_w = transform_img(img_no_w).unsqueeze(0).to(target_dtype).to(device)
        img_w = transform_img(img_w).unsqueeze(0).to(target_dtype).to(device)

        # Encode the images to latents (sample=False).
        image_latents_no_w = pipe.get_image_latents(img_no_w, sample=False)
        image_latents_w = pipe.get_image_latents(img_w, sample=False)

        # Run forward_diffusion on each latent with the shared settings.
        reversed_latents_no_w = pipe.forward_diffusion(latents=image_latents_no_w, **inversion_kwargs)
        reversed_latents_w = pipe.forward_diffusion(latents=image_latents_w, **inversion_kwargs)

        # Score both reversed latents against the ground-truth patch.
        # NOTE(review): in the code visible here the two metrics are
        # overwritten on every iteration and never stored.
        no_w_metric, w_metric = eval_watermark(
            reversed_latents_no_w,
            reversed_latents_w,
            watermarking_mask,
            gt_patch,
            args,
        )