In [1]:
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline
from PIL import Image
from typing import Optional, List, Union, Dict, Any
import numpy as np
import os
import io
import lpips
from torchvision import transforms
from collections import defaultdict
import argparse
import sys
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

class WatermarkInjectionPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline that injects a watermark residual while denoising.

    A watermark encoder maps ``secret_input`` to a residual. For every denoising
    iteration whose index lies inside the injection window, a scaled copy of
    that residual is subtracted from the noise prediction before the scheduler
    step is taken.
    """

    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler,
                 safety_checker=None, feature_extractor=None, image_encoder=None,
                 requires_safety_checker=False, wm_encoder=None, wm_decoder=None):
        """Forward the standard components to the parent and keep the watermark models.

        ``wm_encoder`` and ``wm_decoder`` are stored as attributes on the
        instance; the decoder is not used by this class itself.
        """
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
                         safety_checker, feature_extractor, image_encoder, requires_safety_checker)
        self.wm_encoder = wm_encoder
        self.wm_decoder = wm_decoder
        self._inference_mode = True

    def set_inference_mode(self, mode=True):
        """Record the mode flag and switch the UNet to eval (True) or train (False)."""
        self._inference_mode = mode
        if self.unet is not None:
            if mode:
                self.unet.eval()
            else:
                self.unet.train()

    def _get_beta_t(self, t, scheduler):
        """Return the per-timestep scale applied to the watermark residual.

        Despite the name, the value is ``sqrt(a) / sqrt(1 - a)`` where ``a`` is
        the scheduler's cumulative alpha product at timestep ``t``. If the
        scheduler exposes neither ``alphas_cumprod`` nor ``alphas_cumprod_gpu``
        a constant 1.0 is returned.
        """
        if hasattr(scheduler, 'alphas_cumprod'):
            alphas = scheduler.alphas_cumprod
        elif hasattr(scheduler, 'alphas_cumprod_gpu'):
            alphas = scheduler.alphas_cumprod_gpu
        else:
            return torch.tensor(1.0, device=self.device)

        # A timestep may arrive as a float (e.g. 981.0), which is not a valid
        # index; round it to the nearest integer table entry.
        t_value = t.cpu().item() if torch.is_tensor(t) else t
        alpha = alphas[int(round(t_value))]

        alpha = alpha.detach().clone() if torch.is_tensor(alpha) else torch.tensor(alpha, device=self.device)
        return torch.sqrt(alpha) / torch.sqrt(1 - alpha)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        latents: Optional[torch.FloatTensor] = None,
        wm_injection_start_step: int = 20,
        wm_injection_end_step: int = 45,
        wm_weight: float = 1.0,
        secret_input: Optional[torch.Tensor] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        enable_watermark: bool = True,
        **kwargs
    ):
        """Run text-to-image generation, optionally injecting the watermark.

        Args:
            prompt: A single prompt or a list of prompts; the list length is
                the batch size.
            latents: Optional starting latents. When omitted they are created
                with ``prepare_latents`` using ``kwargs['generator']``.
            wm_injection_start_step: First loop index (0-based, inclusive) at
                which the residual is injected.
            wm_injection_end_step: Last loop index (inclusive) at which the
                residual is injected.
            wm_weight: Multiplier applied to the residual.
            secret_input: Input for ``wm_encoder``. Injection is skipped when
                this is None, when no encoder is set, or when
                ``enable_watermark`` is False.
            height: Output height; defaults to
                ``unet.config.sample_size * vae_scale_factor``.
            width: Output width; same default rule as ``height``.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance is used when this is > 1.0.
            enable_watermark: Master switch for the injection.
            **kwargs: Only ``generator`` and ``output_type`` are read.

        Returns:
            dict with ``"images"`` (PIL images, or the latents themselves when
            ``output_type == "latent"``) and ``"latents"`` (final latents).

        Raises:
            ValueError: If the residual batch cannot be tiled evenly onto the
                latent batch.
        """
        self.set_inference_mode(True)

        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        if isinstance(prompt, str):
            batch_size = 1
        else:
            batch_size = len(prompt)

        device = self._execution_device
        use_guidance = guidance_scale > 1.0

        prompt_embeds, neg_embeds = self.encode_prompt(
            prompt, device, 1, use_guidance, negative_prompt=None
        )
        # With guidance the unconditional embeddings come first, matching the
        # (uncond, text) order used when the prediction is chunked below.
        text_embeddings = torch.cat([neg_embeds, prompt_embeds]) if use_guidance else prompt_embeds

        self.scheduler.set_timesteps(num_inference_steps, device=device)

        if latents is None:
            latents = self.prepare_latents(
                batch_size, self.unet.config.in_channels, height, width,
                text_embeddings.dtype, device, kwargs.get('generator'), None
            )
        else:
            latents = latents.to(device=device, dtype=text_embeddings.dtype)

        wm_residual = None
        if enable_watermark and self.wm_encoder is not None and secret_input is not None:
            secret_input = secret_input.to(device=device, dtype=text_embeddings.dtype)
            wm_residual = self.wm_encoder(secret_input)

            latent_batch = latents.shape[0]
            residual_batch = wm_residual.shape[0]
            if residual_batch != latent_batch:
                # Tile by the ratio rather than by the full latent batch:
                # repeating `latent_batch` times is only right for a residual
                # batch of 1 and over-tiles every other mismatch.
                if residual_batch == 0 or latent_batch % residual_batch != 0:
                    raise ValueError(
                        f"Watermark residual batch ({residual_batch}) cannot be tiled "
                        f"to match latent batch ({latent_batch})."
                    )
                wm_residual = wm_residual.repeat(latent_batch // residual_batch, 1, 1, 1)

        extra_kwargs = self.prepare_extra_step_kwargs(kwargs.get('generator'), 0.0)
        num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(self.scheduler.timesteps):
                latent_model_input = torch.cat([latents] * 2) if use_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=text_embeddings
                ).sample

                if use_guidance:
                    uncond, text = noise_pred.chunk(2)
                    noise_pred = uncond + guidance_scale * (text - uncond)

                # The window is expressed in loop indices, not scheduler timesteps.
                if wm_residual is not None and wm_injection_start_step <= i <= wm_injection_end_step:
                    beta_t = self._get_beta_t(t, self.scheduler)
                    noise_pred = noise_pred - beta_t * wm_weight * wm_residual

                latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample

                if i == len(self.scheduler.timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if kwargs.get("output_type") != "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            current_bs = image.shape[0]
            do_denormalize = [True] * current_bs
            image = self.image_processor.postprocess(
                image, output_type="pil", do_denormalize=do_denormalize
            )
        else:
            image = latents

        return {"images": image, "latents": latents}

class DistortionUnit:
    """Applies a named distortion to an image and hands back a PIL image."""

    def __init__(self, device='cuda'):
        """Remember the target device and build the PIL <-> tensor converters."""
        self.device = device
        self.to_tensor = transforms.ToTensor()
        self.to_pil = transforms.ToPILImage()

    def apply_distortion(self, img, method, **kwargs):
        """Distort ``img`` according to ``method``.

        Recognised methods and the keyword options they read:
          * ``'clean'``            - returns the image unchanged.
          * ``'gaussian_noise'``   - ``std`` (default 0.1), scaled by 255.
          * ``'gaussian_blur'``    - ``kernel_size`` (default (5, 5)) and
                                     ``sigma`` (default (1.0, 1.0)).
          * ``'jpeg_compression'`` - ``quality`` (default 50).
        Any other method name round-trips the image through a tensor and a
        [0, 1] clamp without further changes.
        """
        # Only tensors are converted up front; other inputs pass through as-is.
        if isinstance(img, torch.Tensor):
            img = self.to_pil(img.cpu())

        if method == 'clean':
            return img

        if method == 'gaussian_noise':
            rgb = img if img.mode == 'RGB' else img.convert('RGB')
            pixels = np.array(rgb, dtype=np.uint8)
            sigma = kwargs.get('std', 0.1)
            # Noise is drawn from NumPy's global RNG in 0-255 pixel units.
            noise = np.random.randn(*pixels.shape).astype(np.float32) * (sigma * 255)
            noisy = np.clip(pixels.astype(np.float32) + noise, 0, 255)
            return Image.fromarray(noisy.astype(np.uint8))

        as_tensor = self.to_tensor(img).to(self.device)

        if method == 'jpeg_compression':
            # Encode to an in-memory JPEG and decode it again.
            encoded = io.BytesIO()
            self.to_pil(as_tensor.cpu()).save(
                encoded, format="JPEG", quality=kwargs.get('quality', 50)
            )
            encoded.seek(0)
            return Image.open(encoded).convert('RGB')

        if method == 'gaussian_blur':
            blur = transforms.GaussianBlur(
                kernel_size=kwargs.get('kernel_size', (5, 5)),
                sigma=kwargs.get('sigma', (1.0, 1.0)),
            )
            result = blur(as_tensor)
        else:
            result = as_tensor

        return self.to_pil(torch.clamp(result, 0.0, 1.0).cpu())

class ImageQualityEvaluator:
    """Computes PSNR, SSIM and (when available) LPIPS between two images."""

    def __init__(self, device='cuda'):
        """Store the device and try to build the LPIPS (AlexNet) model on it.

        If the model cannot be created, ``lpips_model`` is set to None and the
        reason is printed; ``evaluate`` then reports 0.0 for LPIPS.
        """
        self.device = device
        try:
            self.lpips_model = lpips.LPIPS(net='alex').eval().to(device)
        except Exception as exc:
            # Catch Exception rather than everything so that KeyboardInterrupt
            # and SystemExit still propagate, and say why LPIPS is missing.
            print(f"LPIPS model unavailable ({type(exc).__name__}: {exc}); "
                  f"'lpips' will be reported as 0.0.")
            self.lpips_model = None

    def _to_numpy(self, img):
        """Convert a PIL image, tensor or array-like to a float32 NumPy array.

        Tensors of shape (C, H, W) with C in {1, 3}, and 4-D tensors (first
        batch element only), are moved to channel-last layout and multiplied
        by 255. NOTE(review): that scaling presumes tensor values in [0, 1] -
        confirm against callers. Other inputs are converted without rescaling.
        """
        if isinstance(img, Image.Image):
            img = np.array(img)
        elif isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()
            if img.ndim == 3 and img.shape[0] in [1, 3]:
                img = np.transpose(img, (1, 2, 0)) * 255.0
            elif img.ndim == 4:
                img = np.transpose(img[0], (1, 2, 0)) * 255.0
        if not isinstance(img, np.ndarray):
            img = np.array(img)
        return img.astype(np.float32)

    def _to_tensor(self, img_np):
        """Turn an (H, W, C) array in 0-255 into a (1, C, H, W) tensor in [-1, 1]."""
        tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(self.device)
        return tensor / 127.5 - 1.0

    def evaluate(self, orig, wm):
        """Return ``{'psnr', 'ssim', 'lpips'}`` comparing ``orig`` with ``wm``.

        PSNR and SSIM are computed on 0-255 float arrays. ``'lpips'`` is 0.0
        when the LPIPS model is unavailable, so in that case the value is a
        placeholder and not a measurement.
        """
        img_orig_np = self._to_numpy(orig)
        img_wm_np = self._to_numpy(wm)

        psnr = peak_signal_noise_ratio(img_orig_np, img_wm_np, data_range=255)

        try:
            ssim = structural_similarity(img_orig_np, img_wm_np, data_range=255, channel_axis=2)
        except TypeError:
            # Fallback for scikit-image versions that do not accept `channel_axis`.
            ssim = structural_similarity(img_orig_np, img_wm_np, data_range=255, multichannel=True)

        lpips_val = 0.0
        if self.lpips_model is not None:
            t_orig = self._to_tensor(img_orig_np)
            t_wm = self._to_tensor(img_wm_np)
            with torch.no_grad():
                lpips_val = self.lpips_model(t_orig, t_wm).item()

        return {'psnr': psnr, 'ssim': ssim, 'lpips': lpips_val}

def calculate_accuracy(predicted, target):
    """Return the fraction of positions where both tensors fall on the same side of 0.5.

    Each tensor is binarised with ``> 0.5`` and the share of matching entries
    is returned as a Python float.
    """
    predicted_bits = predicted > 0.5
    target_bits = target > 0.5
    matches = predicted_bits == target_bits
    return matches.float().mean().item()

def parse_args():
    """Build the command-line parser and return the recognised options.

    Unrecognised arguments are ignored rather than rejected, because only the
    first element of ``parse_known_args()`` is returned.
    """
    # (flag, type, default, help) - registered in this order.
    option_table = (
        ("--sd_model_path", str, "../stable-diffusion-v1-5", "Path to Stable Diffusion model"),
        ("--wm_model_path", str, "./ALIEN_Models", "Path to Watermark Encoder/Decoder"),
        ("--output_dir", str, "./output_alien_test", "Directory to save results"),
        ("--prompt", str, None, "Prompt for generation (overrides default list if provided)"),
        ("--seed", int, 1111, "Random seed base"),
        ("--secret_len", int, 48, "Length of watermark secret bits"),
        ("--wm_weight", float, 1, "Strength of watermark injection"),
        ("--start_step", int, 1, "Injection start step"),
        ("--end_step", int, 50, "Injection end step"),
        ("--noise_std", float, 0.1, "Standard deviation for Gaussian Noise attack"),
        ("--jpeg_quality", int, 50, "Quality for JPEG Compression attack"),
        ("--blur_sigma", float, 1.0, "Sigma for Gaussian Blur attack"),
    )

    cli = argparse.ArgumentParser(description="ALIEN Watermark Injection Test")
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)

    recognised, _ignored = cli.parse_known_args()
    return recognised

def main():
    """Generate clean and watermarked images, attack them, and print a report.

    Steps:
      1. Parse options, load the Stable Diffusion pipeline and the watermark
         encoder/decoder (random weights are used if no encoder checkpoint is
         found), and wrap them in ``WatermarkInjectionPipeline``.
      2. For every prompt, render one image without and one with the
         watermark from the same seed, and score the pair (PSNR/SSIM/LPIPS).
      3. For each attack, distort the watermarked image, re-encode it with the
         VAE, decode the secret and record the bit accuracy. The ``clean``
         case decodes straight from the pipeline's final latents.
      4. Print the averaged metrics.

    Initialisation failures are printed with a traceback and end the run
    without raising.
    """
    args = parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"üöÄ Starting ALIEN Test on {device} (Strict Numpy Noise Mode)")
    print(f"üìÇ Output Dir: {args.output_dir}")

    os.makedirs(args.output_dir, exist_ok=True)

    if args.prompt:
        prompts = [args.prompt]
    else:
        prompts = [
            "A cat, soft golden lighting, cinematic bokeh, highly detailed fur, 8k, realistic, studio lighting."
        ]

    try:
        pipeline = StableDiffusionPipeline.from_pretrained(args.sd_model_path, safety_checker=None).to(device)

        try:
            from model import LatentMarkEncoder, LatentMarkDecoder
        except ImportError:
            print("‚ö†Ô∏è 'model' module not found, assuming classes are defined in this script or environment.")
            # The import statement above makes both names local to main(), so
            # after a failed import they would be unbound even if the classes
            # exist at module level. Look them up in the module namespace.
            try:
                LatentMarkEncoder = globals()["LatentMarkEncoder"]
                LatentMarkDecoder = globals()["LatentMarkDecoder"]
            except KeyError as exc:
                raise ImportError(
                    "LatentMarkEncoder/LatentMarkDecoder could not be imported from "
                    "'model' and are not defined in this script."
                ) from exc

        wm_encoder = LatentMarkEncoder(secret_size=args.secret_len, latent_channels=4).to(device)
        wm_decoder = LatentMarkDecoder(latent_channels=4, secret_size=args.secret_len).to(device)

        enc_path = os.path.join(args.wm_model_path, "encoder.pth")
        dec_path = os.path.join(args.wm_model_path, "decoder.pth")

        if os.path.exists(enc_path):
            # NOTE(review): torch.load unpickles the file; only point
            # --wm_model_path at checkpoints from a trusted source.
            wm_encoder.load_state_dict(torch.load(enc_path, map_location='cpu'))
            wm_decoder.load_state_dict(torch.load(dec_path, map_location='cpu'))
            print(f"‚úÖ Loaded Pretrained Watermark Models from {args.wm_model_path}")
        else:
            print(f"‚ö†Ô∏è Warning: Pretrained models not found at {args.wm_model_path}. Using random weights.")

        wm_encoder.eval()
        wm_decoder.eval()

        alien_pipe = WatermarkInjectionPipeline(
            vae=pipeline.vae, text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer,
            unet=pipeline.unet, scheduler=pipeline.scheduler, wm_encoder=wm_encoder, wm_decoder=wm_decoder
        ).to(device)

        evaluator = ImageQualityEvaluator(device)
        distorter = DistortionUnit(device)

    except Exception as e:
        print(f"‚ùå Initialization Error: {e}")
        import traceback
        traceback.print_exc()
        return

    stats = {'psnr': [], 'ssim': [], 'lpips': [], 'acc': defaultdict(list)}

    # (report name, DistortionUnit method, keyword options)
    attacks = [
        ('clean', 'clean', {}),
        ('noise', 'gaussian_noise', {'std': args.noise_std}),
        ('blur', 'gaussian_blur', {'kernel_size': (5, 5), 'sigma': (args.blur_sigma, args.blur_sigma)}),
        ('jpeg', 'jpeg_compression', {'quality': args.jpeg_quality})
    ]

    # One random bit string is shared by every prompt in the run.
    secret_input = torch.randint(0, 2, (1, args.secret_len), dtype=torch.float32, device=device)
    to_tensor = transforms.ToTensor()

    print(f"\nüì¢ Processing {len(prompts)} images...")

    for idx, prompt in enumerate(prompts):
        # Directory suffix keeps its original 1111 offset so existing output
        # folders keep their names.
        img_id = idx + 1111
        save_dir = f"{args.output_dir}/img_{img_id:02d}"
        os.makedirs(save_dir, exist_ok=True)

        seed = args.seed + idx
        generator = torch.Generator(device).manual_seed(seed)

        # Progress is shown as a 1-based position, not the directory id.
        print(f"[{idx + 1}/{len(prompts)}] Generating: {prompt[:50]}...")

        orig_res = alien_pipe(prompt, enable_watermark=False, num_inference_steps=50, generator=generator)
        orig_img = orig_res["images"][0]
        orig_img.save(f"{save_dir}/original.jpg")

        # Re-seed so the watermarked run starts from the same initial noise.
        generator.manual_seed(seed)
        alien_res = alien_pipe(
            prompt, enable_watermark=True, secret_input=secret_input,
            wm_injection_start_step=args.start_step,
            wm_injection_end_step=args.end_step,
            wm_weight=args.wm_weight,
            num_inference_steps=50, generator=generator
        )
        alien_img = alien_res["images"][0]
        alien_latents = alien_res["latents"]
        alien_img.save(f"{save_dir}/alien_wm.jpg")

        metrics = evaluator.evaluate(orig_img, alien_img)
        for k in ['psnr', 'ssim', 'lpips']:
            stats[k].append(metrics[k])

        for name, method, params in attacks:
            if method == 'clean':
                z = alien_latents
            else:
                d_img = distorter.apply_distortion(alien_img, method, **params)
                d_img.save(f"{save_dir}/attack_{name}.jpg")

                with torch.no_grad():
                    img_t = to_tensor(d_img).unsqueeze(0).to(device)
                    img_norm = img_t * 2.0 - 1.0
                    z = pipeline.vae.encode(img_norm).latent_dist.sample() * pipeline.vae.config.scaling_factor

            # Decoding is evaluation only; no autograd graph is needed.
            with torch.no_grad():
                decoded_secret = wm_decoder(z)
            acc = calculate_accuracy(decoded_secret, secret_input)
            stats['acc'][name].append(acc)

    print("\n" + "="*50)
    print(f"üìä ALIEN WATERMARK REPORT (N={len(prompts)})")
    print("="*50)
    print(f"Visual Quality:")
    print(f"  PSNR  : {np.mean(stats['psnr']):.2f}")
    print(f"  SSIM  : {np.mean(stats['ssim']):.4f}")
    print(f"  LPIPS : {np.mean(stats['lpips']):.4f}")

    print("\nRobustness (Bit Accuracy):")
    print("-" * 50)
    print(f"{'Attack':<15} | {'Accuracy':<10}")
    print("-" * 50)
    for name, _, _ in attacks:
        avg_acc = np.mean(stats['acc'][name])
        print(f"{name.upper():<15} | {avg_acc:.4f}")
    print("="*50)
    print(f"‚úÖ Results saved to: {args.output_dir}")

if __name__ == '__main__':
    main()

üöÄ Starting ALIEN Test on cuda (Strict Numpy Noise Mode)
üìÇ Output Dir: ./output_alien_test


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


‚úÖ Loaded Pretrained Watermark Models from ./ALIEN_Models
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]




Loading model from: /root/miniconda3/envs/wm_bench/lib/python3.10/site-packages/lpips/weights/v0.1/alex.pth

üì¢ Processing 1 images...
[1111/1] Generating: A cat, soft golden lighting, cinematic bokeh, high...


  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]


üìä ALIEN WATERMARK REPORT (N=1)
Visual Quality:
  PSNR  : 27.96
  SSIM  : 0.8562
  LPIPS : 0.1192

Robustness (Bit Accuracy):
--------------------------------------------------
Attack          | Accuracy  
--------------------------------------------------
CLEAN           | 1.0000
NOISE           | 1.0000
BLUR            | 1.0000
JPEG            | 1.0000
‚úÖ Results saved to: ./output_alien_test
