In [None]:
import subprocess
import os
from utils.misc import torch_to_pil
import gc
# Experiment configuration: checkpoint/data paths plus the CLI-style argument list
# that is parsed further down with `parse_args(input_args)`.
env = os.environ.copy()
env['MODEL_PATH'] = '../../dreambooth-outputs//wm_concept_implementation_person/checkpoint-49'
env['CLASS_DIR'] = 'data/class-person'
env['DREAMBOOTH_OUTPUT_DIR'] = '../../dreambooth-outputs/wm_concept_implementation_person/'
env['INSTANCE_DIR_Train'] = "CelebA-HQ/44/set_A"
env['INSTANCE_DIR_Adversarial'] = "CelebA-HQ/44/set_B"
env['WM_INSTANCE_DIR'] = "CelebA-HQ/44/set_A_W"
input_args = [
    '--pretrained_model_name_or_path', f"{env['MODEL_PATH']}",
    '--enable_xformers_memory_efficient_attention',
    '--instance_data_dir_for_train', f"{env['INSTANCE_DIR_Train']}",
    '--wm_instance_data_dir_for_adversarial', f"{env['WM_INSTANCE_DIR']}",
    '--instance_data_dir_for_adversarial', f"{env['INSTANCE_DIR_Adversarial']}",
    '--class_data_dir', f"{env['CLASS_DIR']}",
    '--output_dir', f"{env['DREAMBOOTH_OUTPUT_DIR']}",
    '--with_prior_preservation',
    '--prior_loss_weight', '1.0',
    '--instance_prompt', 'a photo of <concept1> person',
    '--class_prompt', 'a photo of person',
    # Exact option name (the parser defines `--inference_prompts`); do not rely on
    # argparse's prefix abbreviation.
    '--inference_prompts', 'a photo of <concept1> person',
    '--resolution', '512',
    '--train_batch_size', '1',
    '--gradient_accumulation_steps', '1',
    '--learning_rate', '5e-6',
    '--lr_scheduler', 'constant',
    '--lr_warmup_steps', '0',
    '--num_class_images', '200',
    '--max_train_steps', '500',
    '--checkpointing_iterations', '50',
    '--center_crop',
    '--mixed_precision', 'bf16',
    '--prior_generation_precision', 'bf16',
    '--sample_batch_size', '8',
    '--gradient_checkpointing',
    # The trailing comma matters: without it Python concatenates this literal with the
    # next one into a single unknown flag that parse_known_args silently drops.
    '--use_8bit_adam',
    '--max_f_train_steps', '4',
    '--max_adv_train_steps', '1',
]

In [None]:
import argparse
import copy
import hashlib
import itertools
import logging
import os
from pathlib import Path

import datasets
import diffusers
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel

logger = get_logger(__name__)

# Preprocessing applied in `infer` to generated PIL images before they go to the
# watermark decoder: bilinear resize of the shorter side to 512 px, PIL -> float
# tensor in [0, 1], then per-channel (x - 0.5) / 0.5, i.e. values in [-1, 1].
image_transforms = transforms.Compose([
    transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
def infer(checkpoint_path, prompts=None, n_img=16, bs=8, n_steps=100, guidance_scale=7.5,sec_decoder=None,msg_val=None):
    """Generate images from a checkpoint and optionally check the embedded watermark.

    For each prompt, ``n_img`` images are sampled in batches of at most ``bs`` and saved as
    ``{checkpoint_path}/dreambooth/{normalized_prompt}/{batch}_{idx}.png``. When
    ``sec_decoder`` is given, every saved image is also preprocessed with the module-level
    ``image_transforms``, displayed, passed through the decoder, and the bit accuracy
    against ``msg_val`` is printed.

    Args:
        checkpoint_path: directory loadable by ``StableDiffusionPipeline.from_pretrained``.
        prompts: a prompt string or an iterable of prompt strings.
        n_img: number of images per prompt; a final partial batch is generated when it
            is not a multiple of ``bs``.
        bs: maximum batch size passed to the pipeline (must be >= 1).
        n_steps: number of denoising steps.
        guidance_scale: classifier-free guidance scale.
        sec_decoder: optional watermark decoder; its output is argmax-ed over dim 2
            (presumably logits of shape (1, n_bits, 2) -- TODO confirm).
        msg_val: expected message bits (tensor); required when ``sec_decoder`` is given.
            Accuracy is normalized by ``msg_val.numel()``.

    Raises:
        ValueError: if ``prompts`` is None, ``bs < 1``, or ``sec_decoder`` is given
            without ``msg_val``.
    """
    import gc

    if prompts is None:
        raise ValueError("`prompts` must be a prompt string or a list of prompt strings.")
    if isinstance(prompts, str):
        # A bare string would otherwise be iterated character by character.
        prompts = [prompts]
    if bs < 1:
        raise ValueError(f"`bs` must be >= 1, got {bs}.")
    if sec_decoder is not None and msg_val is None:
        raise ValueError("`msg_val` is required when `sec_decoder` is given.")

    if sec_decoder is not None:
        # Loop-invariant: the decoder's dtype/device do not change during inference.
        decoder_param = next(sec_decoder.parameters())
        decoder_dtype, decoder_device = decoder_param.dtype, decoder_param.device
        n_bits = msg_val.numel()

    pipe = None
    try:
        with torch.no_grad():
            pipe = StableDiffusionPipeline.from_pretrained(
                checkpoint_path, safety_checker=None, torch_dtype=torch.float16
            ).to("cuda")
            pipe.enable_xformers_memory_efficient_attention()
            torch.cuda.empty_cache()
            for prompt in prompts:
                print(prompt)
                norm_prompt = prompt.lower().replace(",", "").replace(" ", "_")
                out_path = f"{checkpoint_path}/dreambooth/{norm_prompt}"
                os.makedirs(out_path, exist_ok=True)
                for i, start in enumerate(range(0, n_img, bs)):
                    # Last batch may be smaller so exactly n_img images are produced.
                    cur_bs = min(bs, n_img - start)
                    images = pipe(
                        [prompt] * cur_bs,
                        num_inference_steps=n_steps,
                        guidance_scale=guidance_scale,
                    ).images
                    for idx, image in enumerate(images):
                        image.save(f"{out_path}/{i}_{idx}.png")
                        if sec_decoder is None:
                            continue

                        image_tensor = image_transforms(image).to(dtype=decoder_dtype, device=decoder_device)
                        print(image_tensor.shape)
                        watermarked_image_pil = torch_to_pil(image_tensor)[0]
                        watermarked_image_pil.show()
                        decoded_msg = sec_decoder(image_tensor.unsqueeze(0))
                        decoded_msg = torch.argmax(decoded_msg, dim=2)
                        # Bit accuracy = 1 - (Hamming distance / message length).
                        acc = 1 - torch.abs(decoded_msg - msg_val).sum().float() / n_bits
                        print(f"acc {acc}")
                    torch.cuda.empty_cache()
    finally:
        # Release GPU memory even if generation fails part-way.
        del pipe
        torch.cuda.empty_cache()
        gc.collect()

class DreamBoothDatasetFromTensor(Dataset):
    """Just like DreamBoothDataset, but take instance_images_tensor instead of path.

    Instance images come from an already-preprocessed tensor indexed along dim 0.
    Optional class (prior-preservation) images are read from ``class_data_root`` and
    preprocessed on the fly. ``len()`` is the larger of the two image counts; indices
    wrap around with modulo, so every image of both sets is visited.

    Each example is a dict with ``instance_images`` and ``instance_prompt_ids``, plus
    ``class_images`` and ``class_prompt_ids`` when ``class_data_root`` is given.
    """

    def __init__(
        self,
        instance_images_tensor,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_images_tensor = instance_images_tensor
        self.num_instance_images = len(self.instance_images_tensor)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            # Sorted for a deterministic order (iterdir order is filesystem-dependent);
            # sub-directories are skipped because they cannot be opened as images.
            self.class_images_path = sorted(p for p in self.class_data_root.iterdir() if p.is_file())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def _tokenize(self, prompt):
        """Return ``input_ids`` for ``prompt``, padded/truncated to the tokenizer's max length."""
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    def __getitem__(self, index):
        example = {}
        example["instance_images"] = self.instance_images_tensor[index % self.num_instance_images]
        example["instance_prompt_ids"] = self._tokenize(self.instance_prompt)

        if self.class_data_root:
            class_path = self.class_images_path[index % self.num_class_images]
            # The context manager closes the file handle that Image.open keeps open
            # lazily; the transform runs inside it so pixel data is loaded first.
            with Image.open(class_path) as class_image:
                if class_image.mode != "RGB":
                    class_image = class_image.convert("RGB")
                example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self._tokenize(self.class_prompt)

        return example


class DreamBoothDatasetFromPriorTensor(Dataset):
    """DreamBooth dataset over tensors that also yields a watermarked copy of each instance.

    Like ``DreamBoothDatasetFromTensor``, but every example additionally carries
    ``wm_instance_images``, taken from ``wm_data_tensor`` at the same wrapped index as the
    instance image. This assumes both tensors list the same subjects in the same order
    (TODO confirm against the caller).

    Optional class (prior-preservation) images are read from ``class_data_root`` and
    preprocessed on the fly. ``len()`` is the larger of the instance and class image counts.
    """

    def __init__(
        self,
        instance_images_tensor,
        wm_data_tensor,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        # Watermarked counterparts of the instance images (previously this mistakenly
        # stored `instance_images_tensor`, silently ignoring `wm_data_tensor`).
        self.wm_instance_images_tensor = wm_data_tensor
        self.num_wm_instance_images = len(self.wm_instance_images_tensor)

        self.instance_images_tensor = instance_images_tensor
        self.num_instance_images = len(self.instance_images_tensor)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            # Deterministic order; sub-directories cannot be opened as images.
            self.class_images_path = sorted(p for p in self.class_data_root.iterdir() if p.is_file())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def _tokenize(self, prompt):
        """Return ``input_ids`` for ``prompt``, padded/truncated to the tokenizer's max length."""
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    def __getitem__(self, index):
        example = {}
        instance_image = self.instance_images_tensor[index % self.num_instance_images]
        wm_instance_image = self.wm_instance_images_tensor[index % self.num_wm_instance_images]
        example["wm_instance_images"] = wm_instance_image
        example["instance_images"] = instance_image
        example["instance_prompt_ids"] = self._tokenize(self.instance_prompt)

        if self.class_data_root:
            class_path = self.class_images_path[index % self.num_class_images]
            # Close the lazily-opened file handle once the pixels are transformed.
            with Image.open(class_path) as class_image:
                if class_image.mode != "RGB":
                    class_image = class_image.convert("RGB")
                example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self._tokenize(self.class_prompt)

        return example

def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Resolve the text-encoder class named in a checkpoint's ``text_encoder`` config.

    Reads ``architectures[0]`` from the ``text_encoder`` subfolder of
    ``pretrained_model_name_or_path`` (at ``revision``) and lazily imports the matching class.

    Returns:
        ``transformers.CLIPTextModel`` or diffusers' ``RobertaSeriesModelWithTransformation``.

    Raises:
        ValueError: for any other architecture name.
    """
    architecture = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    ).architectures[0]

    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel as encoder_cls
    elif architecture == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
            RobertaSeriesModelWithTransformation as encoder_cls,
        )
    else:
        raise ValueError(f"{architecture} is not supported.")
    return encoder_cls


def parse_args(input_args=None):
    """Build the experiment's CLI parser and parse ``input_args``.

    Args:
        input_args: list of argument strings; when None, argparse reads ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with every recognized option.

    Unrecognized arguments do not raise (``parse_known_args`` is used), but they now trigger
    a ``UserWarning`` listing them, so typos such as two flags fused by a missing comma are
    visible. Note that argparse also accepts unambiguous prefixes of option names.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--inference_prompts",
        type=str,
        default=None,
        help="The prompt used to generate images at inference.",
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help=(
            "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
            " float32 precision."
        ),
    )
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir_for_train",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--instance_data_dir_for_adversarial",
        type=str,
        default=None,
        required=True,
        help="A folder containing the images to add adversarial noise",
    )
    parser.add_argument(
        "--wm_instance_data_dir_for_adversarial",
        type=str,
        default=None,
        required=True,
        help="A folder containing the images to add adversarial noise",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        required=True,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss.",
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="text-inversion-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--train_text_encoder",
        action="store_true",
        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for sampling images.",
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=20,
        help="Total number of training steps to perform.",
    )
    parser.add_argument(
        "--max_f_train_steps",
        type=int,
        default=10,
        help="Total number of sub-steps to train surogate model.",
    )
    parser.add_argument(
        "--max_adv_train_steps",
        type=int,
        default=10,
        help="Total number of sub-steps to train adversarial noise.",
    )
    parser.add_argument(
        "--checkpointing_iterations",
        type=int,
        default=5,
        help=("Save a checkpoint of the training state every X iterations."),
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes.",
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="fp16",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU.  Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention",
        action="store_true",
        help="Whether or not to use xformers.",
    )
    parser.add_argument(
        "--pgd_alpha",
        type=float,
        default=0.1 / 255,
        help="The step size for pgd.",
    )
    parser.add_argument(
        "--pgd_eps",
        type=float,
        default=0.0001,
        help="The noise budget for pgd.",
    )
    parser.add_argument(
        "--target_image_path",
        default=None,
        help="target image for attacking",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )

    if input_args is not None:
        args, unknown = parser.parse_known_args(input_args)
    else:
        args, unknown = parser.parse_known_args()

    if unknown:
        import warnings

        # parse_known_args drops anything it does not recognize; surface it so that
        # typos (e.g. two flags fused by a missing comma) do not go unnoticed.
        warnings.warn(f"parse_args ignored unrecognized arguments: {unknown}", stacklevel=2)

    return args







In [None]:
class PromptDataset(Dataset):
    """Yield the same prompt ``num_samples`` times, each tagged with its index.

    Used to prepare the prompts for generating class images on multiple GPUs; the index
    lets each generated image be associated with its sample slot.
    """

    def __init__(self, prompt, num_samples):
        self.prompt, self.num_samples = prompt, num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        return {"prompt": self.prompt, "index": index}


def load_data(data_dir, size=512, center_crop=True) -> torch.Tensor:
    """Load every image file in ``data_dir`` into one normalized tensor.

    Files are ordered by integer stem (``0.png``, ``1.png``, ``10.png`` ...), so index ``i``
    lines up across directories that share a numeric naming scheme. Files whose stem is
    not an integer come after the numeric ones, in lexicographic order. Sub-directories
    and hidden entries (names starting with ``.``) are ignored.

    Each image is converted to RGB, resized so its shorter side is ``size``, cropped to
    ``size x size`` (center crop, or random crop when ``center_crop`` is False), and
    normalized from [0, 1] to [-1, 1].

    Returns:
        Tensor of shape (num_images, 3, size, size).

    Raises:
        ValueError: if ``data_dir`` contains no eligible files.
    """
    def _sort_key(path):
        # Numeric stems first, in numeric order; everything else after, by name.
        try:
            return (0, int(path.stem), path.name)
        except ValueError:
            return (1, 0, path.name)

    paths = sorted(
        (p for p in Path(data_dir).iterdir() if p.is_file() and not p.name.startswith(".")),
        key=_sort_key,
    )
    if not paths:
        raise ValueError(f"No image files found in {str(data_dir)!r}")

    image_transforms = transforms.Compose(
        [
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    images = []
    for path in paths:
        # `with` closes the file handle Image.open would otherwise leave open.
        with Image.open(path) as img:
            images.append(image_transforms(img.convert("RGB")))
    return torch.stack(images)


def train_one_epoch(
    args,
    model,
    tokenizer,
    noise_scheduler,
    vae,
    text_encoder,
    data_tensor: torch.Tensor,
    num_steps=20,
):
    """Fine-tune a copy of the UNet on ``data_tensor`` for ``num_steps`` DreamBooth steps.

    ``model[0]`` is deep-copied, so the caller's UNet is not modified; only the copy's
    parameters are optimized (AdamW, or bitsandbytes AdamW8bit with ``--use_8bit_adam``).
    Each step uses one instance image from ``data_tensor`` plus, when class images are
    available, one class image. With ``args.with_prior_preservation`` the loss is
    instance MSE + ``args.prior_loss_weight`` * prior MSE; otherwise a single MSE over
    the whole batch.

    Side effects: ``vae`` and ``text_encoder`` are moved to CUDA in bfloat16.

    Returns:
        ``[unet, logs]``: the fine-tuned copy (float32, on CUDA) and
        ``{"loss": <last step loss>}`` (NaN when ``num_steps`` is 0).

    Raises:
        ValueError: if ``args.with_prior_preservation`` is set but no class images exist.
        RuntimeError: from ``clip_grad_norm_`` when gradients are non-finite.
    """
    unet = copy.deepcopy(model[0])
    unet.requires_grad_(True)
    # A list, not an iterator: the optimizer consumes the iterable it is given, and an
    # exhausted `itertools.chain` would make the gradient clipping below a silent no-op.
    params_to_optimize = list(unet.parameters())

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
            optimizer_class = bnb.optim.AdamW8bit
        except ImportError:
            print("bitsandbytes not installed. Falling back to default Adam.")
            optimizer_class = torch.optim.AdamW
    else:
        optimizer_class = torch.optim.AdamW
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    train_dataset = DreamBoothDatasetFromTensor(
        data_tensor,
        args.instance_prompt,
        tokenizer,
        args.class_data_dir,
        args.class_prompt,
        args.resolution,
        args.center_crop,
    )

    weight_dtype = torch.bfloat16
    device = torch.device("cuda")

    # Encoders run in bf16; the trainable UNet copy stays in fp32.
    vae.to(device, dtype=weight_dtype)
    text_encoder.to(device, dtype=weight_dtype)
    unet.to(device, dtype=torch.float32)
    unet.train()

    loss = None
    for step in range(num_steps):
        step_data = train_dataset[step % len(train_dataset)]

        # Batch = [instance image] (+ [class image] when class data is available).
        batch_images = [step_data["instance_images"]]
        batch_ids = [step_data["instance_prompt_ids"]]
        if "class_images" in step_data:
            batch_images.append(step_data["class_images"])
            batch_ids.append(step_data["class_prompt_ids"])
        elif args.with_prior_preservation:
            raise ValueError("--with_prior_preservation requires --class_data_dir with class images.")
        pixel_values = torch.stack(batch_images).to(device, dtype=weight_dtype)
        input_ids = torch.cat(batch_ids, dim=0).to(device)

        # Only the UNet is optimized, so no autograd graph is needed through the encoders.
        with torch.no_grad():
            latents = vae.encode(pixel_values).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            # Sample a random timestep for each image
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Forward diffusion: add noise according to each timestep's magnitude.
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Text embedding for conditioning
            encoder_hidden_states = text_encoder(input_ids)[0]

        # Predict the noise residual
        model_pred = unet(noisy_latents.to(unet.dtype), timesteps, encoder_hidden_states.to(unet.dtype)).sample

        # Target depends on the scheduler's prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        if args.with_prior_preservation:
            # First half of the batch is the instance image, second half the class image.
            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
            target, target_prior = torch.chunk(target, 2, dim=0)

            instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
            loss = instance_loss + args.prior_loss_weight * prior_loss
        else:
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        print(f'loss={loss}')
        loss.backward()
        # error_if_nonfinite=True: fail loudly on NaN/Inf gradients instead of stepping.
        torch.nn.utils.clip_grad_norm_(params_to_optimize, args.max_grad_norm, error_if_nonfinite=True)
        optimizer.step()
        optimizer.zero_grad()

    torch.cuda.empty_cache()
    logs = {"loss": loss.detach().item() if loss is not None else float("nan")}

    return [unet, logs]


def perturbation_modualtion(
    args,
    model,
    tokenizer,
    noise_scheduler,
    vae,
    text_encoder,
    data_tensor: torch.Tensor,
    original_images: torch.Tensor,
    num_steps: int,
    wm_tensor: torch.Tensor,
    sec_decoder,
    secret_path: str = './secret_48.npy',
    msg_lr: float = 1e-7,
):
    """Return new perturbed data.

    Starting from ``data_tensor``, each of ``num_steps`` steps does:

    1. A PGD ascent step on the DreamBooth denoising loss (step ``args.pgd_alpha``,
       sign of the gradient), projected into an L-inf ball of radius ``args.pgd_eps``
       around ``wm_tensor``.
    2. An Adam step (learning rate ``msg_lr``) that lowers the BCE loss between
       ``sec_decoder``'s logits and the one-hot secret bits loaded from ``secret_path``,
       so the watermark message stays decodable.

    The Adam update is applied after the projection, so the result can lie slightly
    (on the order of ``msg_lr`` per value) outside the eps-ball.

    Args:
        model: sequence whose first element is the UNet. It is moved to CUDA in
            bfloat16 in place, which is a side effect visible to the caller.
        data_tensor: images to perturb (batch along dim 0).
        original_images: unused; kept for call compatibility.
        wm_tensor: centre of the projection ball. It must be on the same device as
            ``data_tensor``.
        sec_decoder: watermark decoder module; its output is compared against one-hot
            labels with 2 classes per bit.
        secret_path: ``.npy`` file holding the secret message bits.
        msg_lr: Adam learning rate for the message-preservation step.

    Returns:
        float32 tensor of perturbed images, on ``data_tensor``'s device, with
        ``requires_grad`` disabled.
    """
    msg_val = torch.tensor(np.load(secret_path)).to('cuda').long()  # one_hot needs int64
    unet = model[0]
    weight_dtype = torch.bfloat16
    device = torch.device("cuda")

    unet.to(device, dtype=weight_dtype)
    decoder_param = next(sec_decoder.parameters())

    # One leaf tensor, always updated in place. The Adam optimizer holds a reference to
    # it, so rebinding the name to a fresh tensor would leave the optimizer stepping a
    # stale copy whose updates never reach the returned images.
    perturbed_images = data_tensor.detach().clone().to(torch.float32).requires_grad_(True)
    optimizer = torch.optim.Adam([perturbed_images], lr=msg_lr)

    # The prompt does not change across steps: tokenize once.
    input_ids = tokenizer(
        args.instance_prompt,
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids.repeat(len(data_tensor), 1).to(device)

    alpha = args.pgd_alpha
    eps = args.pgd_eps

    for step in range(num_steps):
        # ---- 1) PGD ascent on the denoising loss ---------------------------------
        perturbed_images.grad = None  # no leftover gradient from the message step
        latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()
        # Forward diffusion: add noise according to each timestep's magnitude.
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Text conditioning: no gradient w.r.t. the images flows through it.
        with torch.no_grad():
            encoder_hidden_states = text_encoder(input_ids)[0]

        # Predict the noise residual
        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Target depends on the scheduler's prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        unet.zero_grad()
        loss0 = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        loss0.backward()

        with torch.no_grad():
            adv_images = perturbed_images + alpha * perturbed_images.grad.sign()
            eta = torch.clamp(adv_images - wm_tensor, min=-eps, max=+eps)
            perturbed_images.copy_(wm_tensor + eta)

        # ---- 2) Adam step keeping the watermark message decodable ----------------
        optimizer.zero_grad()
        decoded_msg = sec_decoder(perturbed_images.to(device=decoder_param.device, dtype=decoder_param.dtype))
        labels = F.one_hot(msg_val.unsqueeze(0).expand(decoded_msg.shape[0], -1), num_classes=2).float()
        msgloss = F.binary_cross_entropy_with_logits(decoded_msg, labels.to(decoded_msg.device))
        loss = msgloss
        loss.backward()
        optimizer.step()

        print(f"PGD loss - step {step}, loss: {loss.detach().item()},loss0: {loss0.detach().item()},msg: {msgloss.detach().item()}")

    perturbed_images.requires_grad_(False)
    perturbed_images.grad = None
    torch.cuda.empty_cache()
    return perturbed_images

# def perturbation_modualtion(
#     args,
#     model,
#     tokenizer,
#     noise_scheduler,
#     vae,
#     text_encoder,
#     data_tensor: torch.Tensor,
#     original_images: torch.Tensor,
#     target_tensor: torch.Tensor,
#     num_steps: int,
    
# ):
#     """Return new perturbed data"""

#     unet= model[0]
#     weight_dtype = torch.bfloat16
#     device = torch.device("cuda")

#     vae.to(device, dtype=weight_dtype)
#     text_encoder.to(device, dtype=weight_dtype)
#     unet.to(device, dtype=weight_dtype)

#     perturbed_images = data_tensor.detach().clone()
#     perturbed_images.to(torch.float32).requires_grad_(True)

#     input_ids = tokenizer(
#         args.instance_prompt,
#         truncation=True,
#         padding="max_length",
#         max_length=tokenizer.model_max_length,
#         return_tensors="pt",
#     ).input_ids.repeat(len(data_tensor), 1)

#     for step in (range(num_steps)):
#         perturbed_images.requires_grad = True
#         latents = vae.encode(perturbed_images.to(device, dtype=weight_dtype)).latent_dist.sample()
#         latents = latents * vae.config.scaling_factor

#         # Sample noise that we'll add to the latents
#         noise = torch.randn_like(latents)
#         bsz = latents.shape[0]
#         # Sample a random timestep for each image
#         timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
#         timesteps = timesteps.long()
#         # Add noise to the latents according to the noise magnitude at each timestep
#         # (this is the forward diffusion process)
#         noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

#         # Get the text embedding for conditioning
#         encoder_hidden_states = text_encoder(input_ids.to(device))[0]

#         # Predict the noise residual
#         model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

#         # Get the target for loss depending on the prediction type
#         if noise_scheduler.config.prediction_type == "epsilon":
#             target = noise
#         elif noise_scheduler.config.prediction_type == "v_prediction":
#             target = noise_scheduler.get_velocity(latents, noise, timesteps)
#         else:
#             raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

#         unet.zero_grad()
#         text_encoder.zero_grad()
#         loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

#         loss.backward()

#         alpha = args.pgd_alpha
#         eps = args.pgd_eps

#         adv_images = perturbed_images + alpha * perturbed_images.grad.sign()
#         eta = torch.clamp(adv_images - original_images, min=-eps, max=+eps)
#         perturbed_images = torch.clamp(original_images + eta, min=-1, max=+1).detach_()

#     return perturbed_images

from accelerate.utils import ProjectConfiguration, set_seed

# Parse the notebook's argument list (`input_args`) with the project's parse_args,
# then set up Accelerate, logging verbosity and (optionally) the global seed.
args = parse_args(input_args)
logging_dir = Path(args.output_dir, args.logging_dir)

accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

# NOTE: within this cell only the class-image dataloader is passed through
# accelerator.prepare(); the models below are placed on devices manually.
accelerator = Accelerator(
    mixed_precision=args.mixed_precision,
    log_with=args.report_to,
    project_config=accelerator_project_config,
)

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
# Only the local main process emits library warnings/info; other processes log errors only.
if accelerator.is_local_main_process:
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()
else:
    datasets.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()

if args.seed is not None:
    set_seed(args.seed)

# Generate class images if prior preservation is enabled.
# Files already present in class_data_dir count toward args.num_class_images; only the
# shortfall is sampled with the pretrained pipeline.
if args.with_prior_preservation:
    class_images_dir = Path(args.class_data_dir)
    if not class_images_dir.exists():
        class_images_dir.mkdir(parents=True)
    cur_class_images = len(list(class_images_dir.iterdir()))

    if cur_class_images < args.num_class_images:
        # Sampling dtype: fp16 on CUDA / fp32 otherwise, overridden by --mixed_precision.
        # NOTE(review): --prior_generation_precision (passed in input_args) is not consulted
        # here; confirm that following --mixed_precision is intended.
        torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
        if args.mixed_precision == "fp32":
            torch_dtype = torch.float32
        elif args.mixed_precision == "fp16":
            torch_dtype = torch.float16
        elif args.mixed_precision == "bf16":
            torch_dtype = torch.bfloat16
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            torch_dtype=torch_dtype,
            safety_checker=None,
            revision=args.revision,
        )
        pipeline.set_progress_bar_config(disable=True)

        num_new_images = args.num_class_images - cur_class_images
        logger.info(f"Number of class images to sample: {num_new_images}.")

        sample_dataset = PromptDataset(args.class_prompt, num_new_images)
        sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

        sample_dataloader = accelerator.prepare(sample_dataloader)
        pipeline.to(accelerator.device)

        for example in tqdm(
            sample_dataloader,
            desc="Generating class images",
            disable=not accelerator.is_local_main_process,
        ):
            images = pipeline(example["prompt"]).images

            for i, image in enumerate(images):
                # File name: (sample index offset by the pre-existing file count)-(SHA-1 of pixel bytes).jpg
                hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                image.save(image_filename)

        # Release the sampling pipeline before loading the training models.
        del pipeline
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

# Load scheduler and models
# The text encoder and UNet are not moved to a device in this cell; the VAE is moved to CUDA.
text_encoder = text_encoder_cls.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="text_encoder",
    revision=args.revision,
)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)

tokenizer = AutoTokenizer.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="tokenizer",
    revision=args.revision,
    use_fast=False,
)

noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

vae = AutoencoderKL.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
).cuda()

# The VAE is always frozen.
vae.requires_grad_(False)

# The text encoder is frozen unless --train_text_encoder is set.
if not args.train_text_encoder:
    text_encoder.requires_grad_(False)

# Optionally allow TF32 for CUDA matmuls.
if args.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True

# Image sets loaded as tensors (they support .clone() / len() below):
#   clean_data        <- instance_data_dir_for_train
#   perturbed_data    <- instance_data_dir_for_adversarial (later replaced by the perturbed images)
#   perturbed_wm_data <- wm_instance_data_dir_for_adversarial
clean_data = load_data(
    args.instance_data_dir_for_train,
    size=args.resolution,
    center_crop=args.center_crop,
)
perturbed_data = load_data(
    args.instance_data_dir_for_adversarial,
    size=args.resolution,
    center_crop=args.center_crop,
)
perturbed_wm_data = load_data(
    args.wm_instance_data_dir_for_adversarial,
    size=args.resolution,
    center_crop=args.center_crop,
)
# Unmodified copy of the adversarial set, taken before any perturbation is applied.
original_data = perturbed_data.clone()
original_data.requires_grad_(False)

if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available. Make sure it is installed correctly")

if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

# Optional target image: resized, mapped from [0, 255] to [-1, 1], encoded to VAE latents
# (scaled by the VAE scaling factor, cast to bfloat16) and repeated once per adversarial image.
target_latent_tensor = None
if args.target_image_path is not None:
    target_image_path = Path(args.target_image_path)
    assert target_image_path.is_file(), f"Target image path {target_image_path} does not exist"

    target_image = Image.open(target_image_path).convert("RGB").resize((args.resolution, args.resolution))
    target_image = np.array(target_image)[None].transpose(0, 3, 1, 2)

    target_image_tensor = torch.from_numpy(target_image).to("cuda", dtype=torch.float32) / 127.5 - 1.0
    target_latent_tensor = (
        vae.encode(target_image_tensor).latent_dist.sample().to(dtype=torch.bfloat16) * vae.config.scaling_factor
    )
    target_latent_tensor = target_latent_tensor.repeat(len(perturbed_data), 1, 1, 1).cuda()

In [None]:
# Load the frozen watermark secret decoder and the reference secret it should recover.
# output_size=48 and the file name suggest a 48-bit secret — presumably matches the
# encoder used to watermark set_A_W; confirm.
weight_dtype = torch.float16
from utils.models import SecretEncoder, SecretDecoder

sec_decoder = SecretDecoder(output_size=48).to(accelerator.device, dtype=weight_dtype)
# NOTE(security): torch.load unpickles arbitrary objects — only load trusted checkpoints.
# map_location keeps loading working if the checkpoint was saved on a different device.
models = torch.load('./pretrained_latentwm.pth', map_location=accelerator.device)
sec_decoder.load_state_dict(models['sec_decoder'])
sec_decoder.requires_grad_(False)
# Pretrained decoder is only used for inference: disable dropout / use running norm stats.
sec_decoder.eval()
# Same device as the decoder (was hard-coded to 'cuda').
msg_val = torch.tensor(np.load('./secret_48.npy')).to(accelerator.device)

In [None]:
def train_one_prior_epoch(
    args,
    model,
    tokenizer,
    noise_scheduler,
    vae,
    text_encoder,
    data_tensor: torch.Tensor,
    wm_data_tensor: torch.Tensor,
    num_steps=20,
):
    """Fine-tune a copy of the UNet so its prediction on a watermarked image matches a frozen teacher.

    Each step draws one sample from ``DreamBoothDatasetFromPriorTensor`` (keys
    ``wm_instance_images``, ``instance_images``, ``class_images``, ``instance_prompt_ids``,
    ``class_prompt_ids``). The trainable student sees the batch
    ``[watermarked instance image, class image]`` with ``[instance prompt, class prompt]``.
    The frozen teacher (the incoming UNet, run under ``torch.no_grad``) sees the
    instance image with the same noise, timestep and instance prompt. The loss is::

        mse(student_instance_pred, teacher_pred)
            + args.prior_loss_weight * mse(student_class_pred, class_target)   # if with_prior_preservation

    Args:
        args: parsed training arguments (learning rate, Adam hyper-parameters, prompts,
            class data dir, resolution, ...).
        model: list whose first element is the UNet to start from. That element is
            removed from the list so the caller's reference is dropped and the memory
            can be reclaimed once this function returns.
        tokenizer, noise_scheduler, vae, text_encoder: DreamBooth components. NOTE:
            ``vae`` and ``text_encoder`` are moved to CUDA and cast to bfloat16 in place.
        data_tensor: instance images (presumably the source of ``instance_images``).
        wm_data_tensor: watermarked counterparts (presumably ``wm_instance_images``).
        num_steps: number of optimisation steps.

    Returns:
        ``[unet, logs]`` where ``unet`` is the fine-tuned student and ``logs`` is
        ``{"loss": <last step's total loss>}`` (NaN when ``num_steps`` is 0).
    """
    # Frozen teacher: the incoming UNet itself (only used under no_grad, never modified).
    # Popping it instead of deep-copying and then deleting avoids an extra full UNet copy.
    unet_ori = model.pop(0)
    # Trainable student: an independent copy of the teacher.
    unet = copy.deepcopy(unet_ori)
    torch.cuda.empty_cache()

    unet.requires_grad_(True)

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
            optimizer_class = bnb.optim.AdamW8bit
        except ImportError:
            print("bitsandbytes not installed. Falling back to default Adam.")
            optimizer_class = torch.optim.AdamW
    else:
        optimizer_class = torch.optim.AdamW
    optimizer = optimizer_class(
        unet.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    train_dataset = DreamBoothDatasetFromPriorTensor(
        data_tensor,
        wm_data_tensor,
        args.instance_prompt,
        tokenizer,
        args.class_data_dir,
        args.class_prompt,
        args.resolution,
        args.center_crop,
    )

    weight_dtype = torch.bfloat16
    device = torch.device("cuda")

    vae.to(device, dtype=weight_dtype)
    text_encoder.to(device, dtype=weight_dtype)
    unet.to(device, dtype=torch.float32)
    unet.train()

    last_loss = None
    for step in range(num_steps):
        step_data = train_dataset[step % len(train_dataset)]

        # Student batch: [watermarked instance image, class image].
        pixel_values = torch.stack([step_data["wm_instance_images"], step_data["class_images"]]).to(
            device, dtype=weight_dtype
        )
        # Teacher input: the (non-watermarked) instance image as a batch of one.
        ori_values = step_data["instance_images"].unsqueeze(0).to(device, dtype=weight_dtype)
        # Row 0 = instance prompt, row 1 = class prompt.
        input_ids = torch.cat([step_data["instance_prompt_ids"], step_data["class_prompt_ids"]], dim=0).to(device)

        # The VAE is frozen; no_grad only avoids keeping encoder activations around.
        # Encode order is unchanged so RNG consumption matches the previous implementation.
        with torch.no_grad():
            latents = vae.encode(pixel_values).latent_dist.sample() * vae.config.scaling_factor
            latents_ori = vae.encode(ori_values).latent_dist.sample() * vae.config.scaling_factor

        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
        ).long()

        # Forward diffusion. The teacher reuses the student's instance noise/timestep (index 0)
        # so the two predictions are directly comparable.
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        noisy_latents_ori = noise_scheduler.add_noise(latents_ori, noise[:1], timesteps[:1])

        encoder_hidden_states = text_encoder(input_ids)[0]

        # Timesteps are passed as integers to both UNets (the UNet embeds them as floats itself).
        model_pred = unet(
            noisy_latents.to(unet.dtype), timesteps, encoder_hidden_states.to(unet.dtype)
        ).sample
        with torch.no_grad():
            # Only the instance half of the teacher output is used, so run the teacher on it alone.
            teacher_pred = unet_ori(
                noisy_latents_ori.to(unet_ori.dtype),
                timesteps[:1],
                encoder_hidden_states[:1].to(unet_ori.dtype),
            ).sample

        # Target for the class (prior) half, depending on the prediction type.
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        # The batch is always [instance, class], so split it regardless of the prior flag
        # (previously `prior_loss` was undefined when prior preservation was off).
        instance_pred, prior_pred = torch.chunk(model_pred, 2, dim=0)
        _, prior_target = torch.chunk(target, 2, dim=0)

        # Pull the student's watermarked-image prediction toward the teacher's prediction.
        loss = F.mse_loss(instance_pred.float(), teacher_pred.float(), reduction="mean")
        if args.with_prior_preservation:
            prior_loss = F.mse_loss(prior_pred.float(), prior_target.float(), reduction="mean")
            loss = loss + args.prior_loss_weight * prior_loss

        accelerator.backward(loss)
        if accelerator.sync_gradients:
            params_to_clip = (
                itertools.chain(unet.parameters(), text_encoder.parameters())
                if args.train_text_encoder
                else unet.parameters()
            )
            accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()

        last_loss = loss.detach()
        torch.cuda.empty_cache()

    del unet_ori
    torch.cuda.empty_cache()
    logs = {"loss": last_loss.item() if last_loss is not None else float("nan")}

    return [unet, logs]

In [None]:
# Outer training loop.
#   * Iterations i < NUM_SURROGATE_WARMUP_ITERS: train the UNet held in `f` on the
#     unperturbed adversarial-set images (original_data) and keep the result.
#   * Later iterations: train a throw-away surrogate from `f`, update `perturbed_data`
#     against it with perturbation_modualtion, then fine-tune `f` with
#     train_one_prior_epoch on the perturbed + watermarked images.
# Every args.checkpointing_iterations iterations the perturbed images and a full
# pipeline checkpoint are written under args.output_dir, and `infer` is run on it.
from tqdm.auto import tqdm
import gc

# Iterations that only train the surrogate before any perturbation update.
NUM_SURROGATE_WARMUP_ITERS = 20


def _release_memory():
    """Drop cached CUDA blocks, then run the Python garbage collector."""
    torch.cuda.empty_cache()
    gc.collect()


def _save_perturbed_images(images, img_names, step):
    """Write each image in `images` as `<step>_noise_<name>` under output_dir/noise-ckpt/<step>.

    Assumes pixel values in [-1, 1]; they are mapped back to uint8 via *127.5 + 128 and a clamp.
    """
    save_folder = f"{args.output_dir}/noise-ckpt/{step}"
    os.makedirs(save_folder, exist_ok=True)
    for img_pixel, img_name in zip(images.detach(), img_names):
        save_path = os.path.join(save_folder, f"{step}_noise_{img_name}")
        Image.fromarray(
            (img_pixel * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy()
        ).save(save_path)
    print(f"Saved noise at step {step} to {save_folder}")


def _save_and_evaluate_checkpoint(unet_model, checkpoint_name):
    """Save the base pipeline with `unet_model` swapped in, then run `infer` on the saved copy."""
    save_path = os.path.join(args.output_dir, checkpoint_name)
    ckpt_pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=accelerator.unwrap_model(unet_model),
        revision=args.revision,
    )
    ckpt_pipeline.save_pretrained(save_path)
    del ckpt_pipeline

    torch.cuda.empty_cache()

    prompts = args.inference_prompts.split(";")
    infer(save_path, prompts, n_img=4, bs=4, n_steps=100, sec_decoder=sec_decoder, msg_val=msg_val)
    logger.info(f"Saved state to {save_path}")


# File names of the adversarial set, listed once (the directory does not change during training).
# NOTE(review): relies on iterdir() yielding the same order load_data used — verify.
img_names = [instance_path.name for instance_path in Path(args.instance_data_dir_for_adversarial).iterdir()]

f = [unet]
progress_bar = tqdm(
    range(0, args.max_train_steps),
    disable=not accelerator.is_local_main_process,
)
progress_bar.set_description("Steps")
for i in range(args.max_train_steps):
    # Surrogate UNet trained from `f` on the unperturbed adversarial-set images.
    f_sur = train_one_epoch(
        args,
        f,
        tokenizer,
        noise_scheduler,
        vae,
        text_encoder,
        original_data,
        args.max_f_train_steps,
    )
    f_sur[0].requires_grad_(False)
    _release_memory()

    if i < NUM_SURROGATE_WARMUP_ITERS:
        f = f_sur
    else:
        perturbed_data = perturbation_modualtion(
            args,
            f_sur,
            tokenizer,
            noise_scheduler,
            vae,
            text_encoder,
            perturbed_data,
            original_data,
            args.max_adv_train_steps,
            wm_tensor=perturbed_wm_data,
            sec_decoder=sec_decoder
        )
        # The surrogate is only needed for the perturbation update.
        del f_sur[0]
        _release_memory()

        f = train_one_prior_epoch(
            args,
            f,
            tokenizer,
            noise_scheduler,
            vae,
            text_encoder,
            perturbed_data,
            perturbed_wm_data,
            args.max_f_train_steps,
        )
        f[0].requires_grad_(False)
        _release_memory()

    progress_bar.set_postfix(**f[1])
    progress_bar.update(1)

    if (i + 1) % args.checkpointing_iterations == 0:
        _save_perturbed_images(perturbed_data, img_names, i + 1)

        if accelerator.is_main_process:
            # Checkpoint folders use the 0-based iteration index (e.g. checkpoint-49 after 50
            # iterations), unlike the 1-based noise-ckpt folders; kept for compatibility with
            # existing outputs referenced elsewhere (see MODEL_PATH).
            _save_and_evaluate_checkpoint(f[0], f"checkpoint-{i}")

progress_bar.close()


