In [1]:
# This notebook generates an image from a text prompt using a diffusion model.
# Inputs: random seed, number of inference timesteps, text prompt
# Output: the generated image

## Initialization

In [2]:
# Limit the CUDA devices visible to this process to GPUs 0 and 1.
# This cell runs before the cell that imports torch.
import os

os.environ.update({"CUDA_VISIBLE_DEVICES": "0,1"})

In [None]:
# package and helper functions
import warnings
from typing import List, Optional

import torch
from diffusers import AutoPipelineForText2Image
import random
import numpy as np

from utils.diffusion import Diffusion
from utils.img2text import Img2Text
from utils.parse_args import parse_args

# Shorthand aliases for tensor type annotations.
T = torch.Tensor  # a single tensor
TL = List[torch.Tensor]  # a list of tensors
TN = Optional[torch.Tensor]  # a tensor or None

In [None]:
# Build the run configuration, then load the image-to-text helper and the
# text-to-image diffusion pipeline.
warnings.filterwarnings("ignore")  # silences all warnings from here on
config = parse_args(
    "segmentation",
    False,
    ["diffusion.variant=stabilityai/stable-diffusion-3.5-medium"],
)
img2text = Img2Text(config)

# Use half precision only when the config asks for "fp16"; any other value
# falls back to float32.
if config.diffusion.dtype == "fp16":
    diffusion_dtype = torch.float16
else:
    diffusion_dtype = torch.float32

# Weights are cached under config.model_dir; device placement comes from the config.
pipe = AutoPipelineForText2Image.from_pretrained(
    config.diffusion.variant,
    torch_dtype=diffusion_dtype,
    use_safetensors=True,
    cache_dir=config.model_dir,
    device_map=config.diffusion.device_map,
)

# Project wrapper around the pipeline, used by the generation cell.
diffusion = Diffusion(pipe)

# Input

In [9]:
# User-tunable inputs consumed by the generation cell.
random_seed = 4309  # seeds the Python, NumPy and PyTorch RNGs
infer_timesteps = 50  # divides the scheduler's training timesteps to get the sampling stride
prompt = "a painting of a cat"  # text prompt to condition the generation on

# Generation

In [10]:
# Seed every RNG in use and request deterministic cuDNN behaviour so that a
# given (prompt, random_seed, infer_timesteps) combination can be re-run.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Where the generated image is written.
output_path = "./tmp/example.png"

with torch.inference_mode():

    # 1. prepare latent
    latent = diffusion.prepare_latent()

    # 2. prepare text embeddings; the empty prompt provides the
    #    unconditional branch for classifier-free guidance
    negative_prompt = ""
    pos_text_emb = diffusion.encode_prompt(prompt)
    neg_text_emb = diffusion.encode_prompt(negative_prompt)

    # 3. prepare timesteps
    train_timesteps = diffusion.scheduler.config.num_train_timesteps
    # Fail early with a clear message: 0 would divide by zero, a value above
    # train_timesteps would give a zero stride (range() rejects it), and a
    # negative value would silently skip the denoising loop.
    if not 0 < infer_timesteps <= train_timesteps:
        raise ValueError(
            f"infer_timesteps must be in [1, {train_timesteps}], got {infer_timesteps}"
        )
    step_ratio = train_timesteps // infer_timesteps
    # Walk from the last training timestep down towards 0 with a fixed stride.
    # NOTE: when train_timesteps is not a multiple of infer_timesteps this can
    # yield slightly more than infer_timesteps steps.
    timesteps = list(range(train_timesteps - 1, 0, -step_ratio))

    # 4. reverse diffusion process
    for t in timesteps:
        # One batched call for both branches; the result is split in the same
        # order as the inputs (conditional first, unconditional second).
        model_pred_cond, model_pred_uncond = diffusion.get_model_prediction(
            [latent, latent], [t, t], [pos_text_emb, neg_text_emb]
        ).chunk(2)
        model_pred = diffusion.classifier_free_guidance(
            model_pred_uncond, model_pred_cond
        )
        # Step from t to the next timestep, clamped so it never goes below 0.
        latent = diffusion.step(latent, t, max(0, t - step_ratio), model_pred)

    # 5. decode latent to image
    img = diffusion.decode_latent(latent)[0]
    # Create the output directory first so saving works on a fresh checkout.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    img.save(output_path)