In [1]:
# This script is used to generate and visualize masks for a given image
# input: image path, class names, threshold, text_prompt, text_template, timesteps
# output: source img, loss, cross att, heatmap, mask

## Initialization

In [None]:
# Restrict this process to GPU index 0. The variable is written here, in a
# cell that runs before the cell importing torch.
import os

os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
# package and helper functions
import warnings
from typing import List, Optional

import torch
import torch.nn.functional as F
from diffusers import AutoPipelineForText2Image
import random
import numpy as np

from utils.attention_control import (
    AttnProcessor,
)
from utils.diffusion import Diffusion
from utils.img2text import Img2Text
from utils.parse_args import parse_args

# Shorthand type aliases for tensor-valued annotations.
T = torch.Tensor  # a single tensor
TN = Optional[torch.Tensor]  # a tensor or None
TL = List[torch.Tensor]  # a list of tensors

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Load the run configuration, the image-to-text helper and the diffusion pipeline.
# NOTE: this filter silences every warning category for the rest of the session.
warnings.filterwarnings("ignore")

config = parse_args("segmentation", False, [])
img2text = Img2Text(config)

# Half precision only when the config explicitly asks for "fp16";
# any other value falls back to float32.
if config.diffusion.dtype == "fp16":
    diffusion_dtype = torch.float16
else:
    diffusion_dtype = torch.float32

pipe = AutoPipelineForText2Image.from_pretrained(
    config.diffusion.variant,
    cache_dir=config.model_dir,
    device_map=config.diffusion.device_map,
    torch_dtype=diffusion_dtype,
    use_safetensors=True,
)

# Swap in the project's attention processor on the UNet (used for attention
# hooks) before wrapping the pipeline in the Diffusion helper.
pipe.unet.set_attn_processor(AttnProcessor())
diffusion = Diffusion(pipe)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]Taking `'Attention' object has no attribute 'key'` while using `accelerate.load_checkpoint_and_dispatch` to mean models/models--runwayml--stable-diffusion-v1-5/snapshots/f03de327dd89b501a01da37fc5240cf4fdba85a1/vae was saved with deprecated attention block weight names. We will load it with the deprecated attention block names and convert them on the fly to the new attention block format. Please re-save the model after this conversion, so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint, please also re-upload it or open a PR on the original repository.
Loading pipeline components...: 100%|██████████| 7/7 [00:04<00:00,  1.71it/s]


In [72]:
# Seed every RNG used by this notebook from a single constant so that the
# three generators cannot drift apart when the seed is changed.
SEED = 4309
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Disable cuDNN autotuning and request deterministic cuDNN algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

with torch.inference_mode():

    def generate_prompt(classes, weight):
        """Build a comma-separated prompt of ``(class)weight`` terms.

        Each entry of ``classes`` is paired with the entry of ``weight`` at
        the same index and rendered as ``(<class>)<weight>``; the terms are
        joined with ``", "``. An empty ``classes`` yields an empty string.

        Raises:
            IndexError: if ``weight`` has fewer entries than ``classes``.
        """
        terms = [f"({name}){weight[idx]}" for idx, name in enumerate(classes)]
        return ", ".join(terms)

    # Text fragments describing what the generated scene should contain.
    classes = [
        "a bookshelf stacked with books",
        "an golden cat",
        "sofa",
        "a potted plant in the corner",
        "a artistic landscape photo hanging on the wall"
    ]
    # Placeholder; recomputed at every denoising step below.
    class_weight = [1.0] * len(classes)
    # One text embedding per class, index-aligned with `classes`.
    class_emb = [diffusion.encode_prompt(f"{c}") for c in classes]

    # 1. prepare latent
    channel = diffusion.unet.config.in_channels
    height = diffusion.unet.config.sample_size
    width = diffusion.unet.config.sample_size
    latent = torch.randn(
        1,
        channel,
        height,
        width,
        device=diffusion.unet.device,
        dtype=diffusion.unet.dtype,
    )

    # 2. prepare text embedding
    negtive_prompt = ""
    neg_text_emb = diffusion.encode_prompt(negtive_prompt)

    # 3. prepare timesteps
    num_inference_steps = 500
    train_timesteps = diffusion.scheduler.config.num_train_timesteps
    # Clamp the stride to 1: with fewer train timesteps than requested steps
    # the integer division gives 0, and range() raises on a zero step.
    step_ratio = max(1, train_timesteps // num_inference_steps)
    timesteps = list(range(train_timesteps - 1, 0, -step_ratio))

    # Base of the per-class weight: weight = weight_base ** normalized_score.
    # NOTE(review): with a base of 1 every weight evaluates to 1.0, so the
    # "weighted" prompt is identical to the unweighted one. Confirm whether a
    # different base was intended before relying on the weighting.
    weight_base = 1

    # The unweighted prompt depends on neither the timestep nor the latent,
    # so it is built and encoded once instead of on every iteration
    # (assumes encode_prompt is deterministic for a fixed prompt).
    unweighted_prompt = generate_prompt(classes, [1] * len(classes))
    unweighted_text_emb = diffusion.encode_prompt(unweighted_prompt)

    # 4. reverse diffusion process
    for t in timesteps:
        # 4.1. eps prediction conditioned on the joint, unweighted prompt
        eps_pred_cond = diffusion.get_eps_prediction(
            [latent], [t], [unweighted_text_emb]
        )
        # 4.2. per-class score: MSE between the joint-prompt prediction and
        # the prediction conditioned on that class alone. Iterating over the
        # embeddings directly avoids a linear `classes.index` lookup, which
        # would also pick the wrong entry if two class strings were equal.
        elbo = torch.stack(
            [
                F.mse_loss(
                    eps_pred_cond,
                    diffusion.get_eps_prediction([latent], [t], [emb]),
                    reduction="mean",
                )
                for emb in class_emb
            ]
        )
        # 4.3. min-max normalize the scores to [0, 1]. When all scores are
        # equal (always the case for a single class) the range is zero and
        # the division would produce NaN, so fall back to zeros.
        elbo_span = elbo.max() - elbo.min()
        if elbo_span > 0:
            elbo = (elbo - elbo.min()) / elbo_span
        else:
            elbo = torch.zeros_like(elbo)
        class_weight = torch.round(
            torch.pow(weight_base, elbo), decimals=2
        ).tolist()
        # 4.4. build the weighted prompt and predict eps for the conditional
        # and unconditional branches in a single batched call
        prompt = generate_prompt(classes, class_weight)
        pos_text_emb = diffusion.encode_prompt(prompt)
        eps_pred_cond, eps_pred_uncond = diffusion.get_eps_prediction(
            [latent, latent], [t, t], [pos_text_emb, neg_text_emb]
        ).chunk(2)
        eps_pred = diffusion.classifier_free_guidance(eps_pred_uncond, eps_pred_cond)
        # Step to the next timestep, clamped at 0 for the final iteration.
        latent = diffusion.step(latent, t, max(0, t - step_ratio), eps_pred)

    # 5. decode latent to image
    img = diffusion.decode_latent(latent)[0]
    # Create the output directory first so the save does not fail on a fresh
    # checkout (`os` is imported in the GPU-setup cell at the top).
    os.makedirs("./tmp", exist_ok=True)
    img.save("./tmp/example.png")