In [1]:
# This notebook generates an image whose per-class prompt weights are refined by ELBO
# Inputs: random seed, number of inference timesteps, ELBO strength, text prompt (a list of class phrases)
# Output: the generated image

## Initialization

In [None]:
# Select the GPU device this process may use
import os

# Expose only GPU 0 to this process; CUDA reads this variable when it initializes.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
# Packages, project helpers, and tensor type aliases
import warnings
from typing import List, Optional

import torch
import torch.nn.functional as F
from diffusers import AutoPipelineForText2Image, StableDiffusion3Pipeline
import random
import numpy as np

from utils.attention_control import AttnProcessor, JointAttnProcessor
from utils.diffusion import Diffusion
from utils.img2text import Img2Text
from utils.parse_args import parse_args

# Shorthand type aliases for tensor-valued annotations.
T = torch.Tensor  # a single tensor
TL = List[torch.Tensor]  # a list of tensors
TN = Optional[torch.Tensor]  # a tensor or None

In [None]:
# Load the config, the image-to-text helper and the text-to-image pipeline.
warnings.filterwarnings("ignore")
config = parse_args(
    "segmentation",
    False,
    ["diffusion.variant=stable-diffusion-v1-5/stable-diffusion-v1-5"],
)
# NOTE(review): `img2text` is not referenced by any other cell shown here;
# confirm it is still needed before paying for its load.
img2text = Img2Text(config)

# Only the literal "fp16" selects half precision; any other value means float32.
if config.diffusion.dtype == "fp16":
    diffusion_dtype = torch.float16
else:
    diffusion_dtype = torch.float32

pipe = AutoPipelineForText2Image.from_pretrained(
    config.diffusion.variant,
    torch_dtype=diffusion_dtype,
    use_safetensors=True,
    cache_dir=config.model_dir,
    device_map=config.diffusion.device_map,
)

# Install the project's attention processors so attention can be hooked.
# SD3 pipelines denoise with `pipe.transformer` (joint attention); others with `pipe.unet`.
if not isinstance(pipe, StableDiffusion3Pipeline):
    pipe.unet.set_attn_processor(AttnProcessor())
else:
    pipe.transformer.set_attn_processor(JointAttnProcessor())

diffusion = Diffusion(pipe)

## Input

In [None]:
# Scalar knobs for the generation run.
random_seed = 4309  # seeds the Python, NumPy and torch RNGs in the generation cell
infer_timesteps = 500  # number of denoising steps, strided evenly over the training schedule
elbo_strength = 1.3  # per-class weight is elbo_strength ** (score normalized to [0, 1])

# One phrase per class; the full prompt is their comma-separated, weighted concatenation.
# NOTE(review): "an golden cat" / "a artistic ..." are ungrammatical; left unchanged
# because editing the prompt changes the image produced for this seed.
classes = [
    "a bookshelf stacked with books",
    "an golden cat",
    "sofa",
    "a potted plant in the corner",
    "a artistic landscape photo hanging on the wall",
]

## Generation

In [72]:
# Reproducibility: prefer deterministic cuDNN kernels over autotuned ones ...
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ... and seed the Python, NumPy and torch RNGs with the same value.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

def generate_prompt(classes, weight):
    """Build a weighted prompt string from class phrases.

    Each class ``c`` paired with weight ``w`` is rendered as ``"(c)w"`` and the
    terms are joined with ``", "``, e.g.
    ``generate_prompt(["a cat", "sofa"], [1, 1.3]) == "(a cat)1, (sofa)1.3"``.
    The ``(text)weight`` syntax is presumably interpreted by
    ``diffusion.encode_prompt`` as a per-term weight -- TODO confirm.

    Args:
        classes: class phrases, in prompt order.
        weight: one weight per class, rendered with ``str()``.

    Returns:
        The comma-separated weighted prompt ("" when ``classes`` is empty).

    Raises:
        ValueError: if ``weight`` does not have exactly one entry per class.
    """
    if len(weight) != len(classes):
        raise ValueError(f"expected {len(classes)} weights, got {len(weight)}")
    return ", ".join(f"({c}){w}" for c, w in zip(classes, weight))


with torch.inference_mode():
    # Initial weights (all 1.0); replaced at every denoising step below.
    class_weight = [1.0] * len(classes)
    # Per-class embeddings do not depend on the timestep, so encode them once.
    class_emb = [diffusion.encode_prompt(f"{c}") for c in classes]

    # 1. prepare latent (presumably random noise, made reproducible by the seeds above)
    latent = diffusion.prepare_latent()

    # 2. prepare text embeddings. The (empty) negative prompt and the unweighted
    #    full prompt are both constant across steps, so encode them once here.
    negative_prompt = ""
    neg_text_emb = diffusion.encode_prompt(negative_prompt)
    unweighted_prompt = generate_prompt(classes, [1] * len(classes))
    unweighted_text_emb = diffusion.encode_prompt(unweighted_prompt)

    # 3. prepare timesteps: evenly strided, descending from the last training timestep.
    if infer_timesteps <= 0:
        raise ValueError(f"infer_timesteps must be positive, got {infer_timesteps}")
    train_timesteps = diffusion.scheduler.config.num_train_timesteps
    # Clamp the stride to >= 1: asking for more steps than the schedule has would
    # otherwise give a zero stride and range() would raise.
    step_ratio = max(1, train_timesteps // infer_timesteps)
    timesteps = list(range(train_timesteps - 1, 0, -step_ratio))

    # 4. reverse diffusion process
    for t in timesteps:
        # 4.1. unweighted model prediction for the whole prompt
        model_pred_all = diffusion.get_model_prediction(
            [latent], [t], [unweighted_text_emb]
        )
        # 4.2. per-class score ("elbo"): MSE between the whole-prompt prediction and
        #      the prediction conditioned on that class alone. Upcast to float32 in
        #      case the pipeline runs in float16, so normalization is not half precision.
        elbo = []
        for emb in class_emb:
            model_pred_c = diffusion.get_model_prediction([latent], [t], [emb])
            elbo.append(F.mse_loss(model_pred_all, model_pred_c, reduction="mean"))
        elbo = torch.stack(elbo).float()
        # 4.3. min-max normalize to [0, 1]. When every score is equal (always the case
        #      with a single class) the range is 0 and 0/0 would turn every weight into
        #      NaN, so fall back to 0, i.e. weight 1.0 for every class.
        elbo_min = elbo.min()
        elbo_range = elbo.max() - elbo_min
        if elbo_range > 0:
            elbo = (elbo - elbo_min) / elbo_range
        else:
            elbo = torch.zeros_like(elbo)
        # weight = elbo_strength ** score, i.e. in [1, elbo_strength] when strength > 1.
        # Round in Python: torch.round(decimals=2) on a float16/32 tensor still yields
        # values such as 1.2999999523 after .tolist(), which would leak into the prompt.
        class_weight = [round(w, 2) for w in torch.pow(elbo_strength, elbo).tolist()]
        # 4.4. weighted prompt -> conditional/unconditional predictions -> CFG -> step
        prompt = generate_prompt(classes, class_weight)
        pos_text_emb = diffusion.encode_prompt(prompt)
        model_pred_cond, model_pred_uncond = diffusion.get_model_prediction(
            [latent, latent], [t, t], [pos_text_emb, neg_text_emb]
        ).chunk(2)
        model_pred = diffusion.classifier_free_guidance(
            model_pred_uncond, model_pred_cond
        )
        latent = diffusion.step(latent, t, max(0, t - step_ratio), model_pred)

    # 5. decode latent to image and save it, creating the output folder if missing
    img = diffusion.decode_latent(latent)[0]
    output_path = "./tmp/example.png"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    img.save(output_path)