In [None]:
import argparse
import os

import torch
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, PNDMScheduler
from diffusers.utils import randn_tensor
from PIL import Image
from sympy import false
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

In [None]:
# def parse_args():
#     parser = argparse.ArgumentParser()
#     parser.add_argument(
#         "-m",
#         "--model_id",
#         type=str,
#         default="runwayml/stable-diffusion-v1-5",
#         help="Path to pretrained model or model identifier from huggingface.co/models.",
#     )
#     parser.add_argument(
#         "-p",
#         "--prompt",
#         type=str,
#         default="a photograph of an astronaut riding a horse",
#         help="Text used to generate images.",
#     )
#     parser.add_argument(
#         "-n",
#         "--images_num",
#         type=int,
#         default=1,
#         help="How many images to generate.",
#     )
#     parser.add_argument(
#         "-s",
#         "--steps",
#         type=int,
#         default=50,
#         help="The number of denoising steps.",
#     )
#     parser.add_argument(
#         "-wi",
#         "--width",
#         type=int,
#         default=512,
#         help="The width in pixels of the generated image.",
#     )
#     parser.add_argument(
#         "-he",
#         "--height",
#         type=int,
#         default=512,
#         help="The height in pixels of the generated image.",
#     )
#     parser.add_argument(
#         "-g",
#         "--guidance",
#         type=float,
#         default=7.5,
#         help="Higher guidance scale encourages to generate images that are closely linked to the text.",
#     )
#     parser.add_argument(
#         "-sd",
#         "--seed",
#         type=int,
#         default=42,
#         help="Seed for random process.",
#     )
#     parser.add_argument(
#         "-ci",
#         "--cuda_id",
#         type=int,
#         default=0,
#         help="cuda_id.",
#     )
#     parser.add_argument(
#         "-o",
#         "--path",
#         type=str,
#         default="outputs",
#         help="output path",
#     )
#     parser.add_argument(
#         "-pr",
#         "--pre_defined_pipeline",
#         type=bool,
#         default=False,
#         help="Use the pre-defined pipeline instead of the custom pipeline.",
#     )
#     args = parser.parse_args()
#     return args

In [None]:
def image_grid(imgs, rows, cols):
    """Tile PIL images into a single ``rows`` x ``cols`` grid image.

    Images are pasted in row-major order: image ``i`` goes to row
    ``i // cols`` and column ``i % cols``. The cell size is taken from the
    first image, so all images are assumed to share that size.

    Args:
        imgs: Sequence of PIL images; its length must equal ``rows * cols``.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB PIL image of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        ValueError: If ``len(imgs) != rows * cols`` or ``imgs`` is empty.
    """
    if len(imgs) != rows * cols:
        raise ValueError("The specified number of rows and columns are not correct.")
    if not imgs:
        # Without a first image there is nothing to derive the cell size from;
        # fail with a clear error instead of an IndexError on imgs[0].
        raise ValueError("At least one image is required to build a grid.")

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

In [None]:
# 1. Configure stable diffusion parameters
# args = parse_args()
# args.model

# NOTE(review): `B` is not referenced in any visible cell; this looks like an
# accidental editor auto-import — confirm, then remove it or move it to the
# import cell.
from regex import B


model_id = "runwayml/stable-diffusion-v1-5"
a_height = 512
a_width = 512
prompt = ["a photograph of an astronaut riding a horse"]
num_images_per_prompt = 1
num_inference_steps = 50
guidance_scale = 7.5
# torch.manual_seed seeds the global RNG and returns the default generator,
# which is then handed to the pipeline / noise sampler.
generator = torch.manual_seed(930319)

# Preferred GPU index. Fall back gracefully so the notebook still runs on
# hosts with fewer GPUs (or none) instead of failing when models are moved.
cuda_id = 7
if not torch.cuda.is_available():
    torch_device = torch.device("cpu")
elif cuda_id < torch.cuda.device_count():
    torch_device = torch.device("cuda", cuda_id)
else:
    torch_device = torch.device("cuda", 0)

batch_size = len(prompt)
# True  -> run the ready-made StableDiffusionPipeline
# False -> run the hand-assembled tokenizer / text encoder / UNet / VAE path
with_predefined_pipeline = True
output_path = "outputs"
# The result images are saved into this directory; create it up front so
# saving does not fail on a fresh checkout.
os.makedirs(output_path, exist_ok=True)

In [None]:
# Load only the models required by the selected execution path.
if with_predefined_pipeline == True:

    # 2. Construct pre-defined diffusion pipeline
    pipeline = StableDiffusionPipeline.from_pretrained(model_id)

    # 3. Load pre-defined diffusion pipeline to GPU
    pipeline.to(torch_device)

elif with_predefined_pipeline == False:

    # 2. Construct custom diffusion pipeline
    # 2-1. Load the tokenizer and text encoder to tokenize and encode the text.
    # text_encoder: Other diffusion models may use other encoders such as BERT (Default: CLIP)
    # tokenizer: It must match the one used by the text_encoder model (Default: CLIPTokenizer)
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")

    # 2-2. The UNet model for generating the latents.
    # unet: Model used to generate the latent representation of the input (Default: UNet)
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

    # 2-3. The scheduler for denoising the latent vector. (Default: PNDM)
    # scheduler: Scheduling algorithm used to progressively add noise to the image during training
    scheduler = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler")

    # 2-4. Load the autoencoder model which will be used to decode the latents into image space.
    # vae: Autoencoder module used to decode latent representations into real images.
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")

    # 3. Load custom diffusion pipeline to GPU
    text_encoder.to(torch_device)
    unet.to(torch_device)
    vae.to(torch_device)

    # Derive the latent <-> pixel scale from the VAE config, and fall back to
    # the UNet's native sample size when no explicit height/width is given.
    # These must be computed here, after `vae` and `unet` exist.
    vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
    height = a_height or unet.config.sample_size * vae_scale_factor
    width = a_width or unet.config.sample_size * vae_scale_factor

In [None]:
# Inference through the ready-made pipeline object
if with_predefined_pipeline == True:
    with torch.no_grad():
        # 4~8. The pipeline call covers encoding, denoising, decoding and
        # post-processing in one step.
        height, width = a_height, a_width
        pipeline_kwargs = dict(
            prompt=prompt,
            height=height,
            width=width,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        )
        image = pipeline(**pipeline_kwargs).images

        # One grid row per prompt, one column per image generated for it.
        grid = image_grid(image, rows=batch_size, cols=num_images_per_prompt)
        grid.save(output_path+f"/predefined_result_step_{num_inference_steps}.png")

In [None]:
if with_predefined_pipeline == False:
    with torch.no_grad():

        # 4. Turn the prompts into the embeddings that condition the UNet.
        # 4-1. Conditional embeddings from the text prompt
        # 4-1-1. Tokenize the prompt, padded/truncated to the tokenizer's max length
        text_input = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_input.input_ids

        # 4-1-2. Encode the tokens; element [0] of the encoder output is used
        text_embeddings = text_encoder(text_input_ids.to(torch_device))[0].to(
            dtype=text_encoder.dtype, device=torch_device
        )
        # Replicate each prompt embedding once per requested image
        # (repeat + view, an mps-friendly way of duplicating)
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1).view(
            bs_embed * num_images_per_prompt, seq_len, -1
        )

        # 4-2. Unconditional embeddings for classifier-free guidance
        # 4-2-1. Tokenize one empty prompt per batch entry, padded to the
        #        same sequence length as the conditional embeddings
        uncond_tokens = ["" for _ in range(batch_size)]
        uncond_input = tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=text_embeddings.shape[1],
            truncation=True,
            return_tensors="pt",
        )
        uncond_input_ids = uncond_input.input_ids

        # 4-2-2. Encode the empty prompts
        uncond_embeddings = text_encoder(uncond_input_ids.to(torch_device))[0]
        # Replicate the unconditional embeddings once per requested image
        seq_len = uncond_embeddings.shape[1]
        uncond_embeddings = (
            uncond_embeddings
            .to(dtype=text_encoder.dtype, device=torch_device)
            .repeat(1, num_images_per_prompt, 1)
            .view(batch_size * num_images_per_prompt, seq_len, -1)
        )

        # 4-3. Stack [unconditional, conditional] into one batch so a single
        # UNet forward pass serves both halves
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

In [None]:
if with_predefined_pipeline == False:
    with torch.no_grad():

        # 5. Initialize denoising network
        # 5-1. Generate random noise in latent space: one sample per requested
        #      image, at 1/vae_scale_factor of the pixel resolution
        shape = (batch_size*num_images_per_prompt, unet.config.in_channels,
                height//vae_scale_factor, width//vae_scale_factor)
        latents = randn_tensor(shape, generator=generator, device=torch_device,
                dtype=text_embeddings.dtype)
        # 5-2. scale the initial noise by the standard deviation required by the scheduler
        latents = latents * scheduler.init_noise_sigma
        # 5-3. initialize the scheduler with our chosen num_inference_steps.
        #      The step count is passed through unchanged (no "- 1") so the
        #      custom path uses the same setting that the predefined-pipeline
        #      cell receives and that the output file name reports.
        scheduler.set_timesteps(num_inference_steps)

In [None]:
if with_predefined_pipeline == False:
    with torch.no_grad():

        # 6. Denoising loop: one UNet evaluation per scheduler timestep
        for t in tqdm(scheduler.timesteps):
            # 6-1. Duplicate the latents so the unconditional and conditional
            # halves are evaluated in a single forward pass, then let the
            # scheduler scale the model input for this timestep.
            doubled_latents = torch.cat((latents, latents))
            latent_model_input = scheduler.scale_model_input(doubled_latents, timestep=t)

            # 6-2. UNet: predict the noise residual
            noise_pred = unet(
                latent_model_input,
                timestep=t,
                encoder_hidden_states=text_embeddings,
                return_dict=False,
            )[0]

            # 6-3. Classifier-free guidance: start from the unconditional
            # prediction and move towards the text-conditioned one
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            guidance_delta = noise_pred_text - noise_pred_uncond
            noise_pred = noise_pred_uncond + guidance_scale * guidance_delta

            # 6-4. Let the scheduler compute sample x_(t-1) from x_(t) and the
            # predicted noise
            step_output = scheduler.step(noise_pred, t, latents)
            latents = step_output.prev_sample


In [None]:
if with_predefined_pipeline == False:
    # Inference only: run the decoder without autograd tracking, like the
    # other custom-pipeline cells.
    with torch.no_grad():
        # 7. Decode the image
        # 7-1. scale the denoised latent by scaling factor required by the VAE.
        #      A new name is used so `latents` is left untouched and re-running
        #      this cell does not apply the scaling twice.
        scaled_latents = 1 / vae.config.scaling_factor * latents
        # 7-2. decode the image latents with vae
        image = vae.decode(scaled_latents, return_dict=False)[0]

In [None]:
if with_predefined_pipeline == False:
    with torch.no_grad():

        # 8. Post-process the image
        # Map the decoder output from [-1, 1] to [0, 1]. Intermediate results
        # get their own names so `image` keeps the decoded tensor and this
        # cell can be re-run.
        image_unit = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead
        # and it is compatible with bfloat16
        # (N, C, H, W) -> (N, H, W, C), the layout Image.fromarray is given below
        image_np = image_unit.cpu().permute(0, 2, 3, 1).float().numpy()
        images = (image_np * 255).round().astype("uint8")
        pil_images = [Image.fromarray(array) for array in images]

        # One row per prompt and one column per image of that prompt, matching
        # the predefined-pipeline cell; rows=1 raised ValueError for more than
        # one prompt.
        grid = image_grid(pil_images, rows=batch_size, cols=num_images_per_prompt)
        grid.save(output_path+f"/custom_result_step_{num_inference_steps}.png")