In [4]:
import argparse, os, sys, glob, cv2
from os import path
from base64 import b64encode
from PIL import Image, ImageDraw
from IPython.display import HTML

import numpy as np
import torch
from torch import autocast, float16
from torch.nn import functional as F

from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers import UNet2DConditionModel, PNDMScheduler, LMSDiscreteScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

from transformers import CLIPTextModel, CLIPTokenizer
from tqdm.auto import tqdm

from base64 import b64encode
# from IPython.display import HTML

# Device that all model weights and latents are moved to.
device = 'cuda'

# Checkpoint directory, resolved relative to the parent of the current
# working directory and normalised so the '..' segment is collapsed.
modelpath = 'models/ldm/stable-diffusion-v1'
cwd = os.getcwd()
loadpath = path.normpath(path.join(cwd, '..', modelpath))

# print(loadpath)


  from .autonotebook import tqdm as notebook_tqdm


In [5]:

# Root directory containing the Stable Diffusion v1-4 sub-folders
# (vae/, tokenizer/, text_encoder/, unet/). The default is the original
# hard-coded location; set the SD_MODEL_DIR environment variable to load the
# weights from somewhere else without editing this cell.
SD_MODEL_DIR = os.environ.get('SD_MODEL_DIR', '/home/finn/data/stable-diffusion-v1-4')

# 1. Load the autoencoder model which will be used to decode the latents into image space.
vae = AutoencoderKL.from_pretrained(path.join(SD_MODEL_DIR, 'vae'))
vae.to(device)

# 2. Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained(path.join(SD_MODEL_DIR, 'tokenizer'))
text_encoder = CLIPTextModel.from_pretrained(path.join(SD_MODEL_DIR, 'text_encoder'))
text_encoder.to(device)

# 3. The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained(path.join(SD_MODEL_DIR, 'unet'))
unet.to(device)


# 4. Create a scheduler for inference
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule='scaled_linear',
    num_train_timesteps=1000
)

In [6]:
def get_text_embeds(prompt):
  """Encode prompts into the embeddings used for classifier-free guidance.

  Args:
    prompt: A single prompt string or a list of prompt strings.

  Returns:
    The text-encoder output for one empty prompt per input prompt,
    concatenated along the first dimension with the output for the given
    prompts (unconditional half first, conditional half second).
  """
  # A bare string must be wrapped: otherwise len(prompt) below would count
  # characters and the unconditional batch would not match the prompt batch.
  if isinstance(prompt, str):
    prompt = [prompt]

  # Tokenize text and get embeddings
  text_input = tokenizer(
      prompt,
      padding='max_length',
      max_length=tokenizer.model_max_length,
      truncation=True,
      return_tensors='pt'
    )

  with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

  # Do the same for unconditional embeddings: one empty prompt per input prompt
  uncond_input = tokenizer(
      [''] * len(prompt), padding='max_length',
      max_length=tokenizer.model_max_length, return_tensors='pt')
  with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

  # Cat for final embeddings (unconditional first, then conditional)
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
  return text_embeddings


In [7]:
def produce_latents(text_embeddings, height=512, width=512,
                    num_inference_steps=50, guidance_scale=7.5, latents=None):
  """Run the denoising loop and return the final latents.

  Args:
    text_embeddings: Unconditional and conditional embeddings concatenated
      along the first dimension (unconditional half first), as returned by
      get_text_embeds.
    height: Target image height in pixels; the latent height is height // 8.
    width: Target image width in pixels; the latent width is width // 8.
    num_inference_steps: Number of scheduler steps to run.
    guidance_scale: Weight applied to the (text - unconditional) noise
      difference during guidance.
    latents: Optional starting latents. When None, random normal latents are
      drawn with one sample per prompt.

  Returns:
    The latents after the last scheduler step, on `device`.
  """
  if latents is None:
    # First dimension of text_embeddings is 2 * batch (uncond + cond halves)
    latents = torch.randn((text_embeddings.shape[0] // 2, unet.in_channels,
                           height // 8, width // 8))
  latents = latents.to(device)

  scheduler.set_timesteps(num_inference_steps)
  # Scale the starting latents by the first (largest) sigma of the schedule
  latents = latents * scheduler.sigmas[0]

  with autocast('cuda'):
    # enumerate() has no len(), so pass the step count explicitly; otherwise
    # the progress bar cannot show a total or an ETA.
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
      # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
      latent_model_input = torch.cat([latents] * 2)
      sigma = scheduler.sigmas[i]
      latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

      # predict the noise residual
      with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']

      # perform guidance
      noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
      noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

      # compute the previous noisy sample x_t -> x_t-1
      # NOTE(review): step() receives the loop index i, not the timestep t;
      # confirm this matches the installed diffusers scheduler API.
      latents = scheduler.step(noise_pred, i, latents)['prev_sample']

  return latents

In [8]:
def decode_img_latents(latents):
  """Decode latents into a list of PIL images.

  Args:
    latents: Batch of latents as produced by produce_latents.

  Returns:
    A list with one PIL.Image per latent in the batch.
  """
  # Undo the latent scaling (constant 0.18215) before decoding
  latents = 1 / 0.18215 * latents

  with torch.no_grad():
    imgs = vae.decode(latents).sample

  # The decoder output is centred on zero in [-1, 1]; map it to [0, 1].
  # The division by 2 is required: adding 0.5 alone clips every value above
  # 0.5 to pure white and yields washed-out images.
  imgs = (imgs / 2 + 0.5).clamp(0, 1)
  # Move the channel axis last (axis order 0, 2, 3, 1) for Image.fromarray
  imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
  imgs = (imgs * 255).round().astype('uint8')
  pil_images = [Image.fromarray(image) for image in imgs]
  return pil_images


In [9]:

def prompt_to_image(prompts, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):
    """Generate images from one or more text prompts.

    Chains get_text_embeds, produce_latents and decode_img_latents.

    Args:
        prompts: A single prompt string or a list of prompt strings.
        height: Image height in pixels, forwarded to produce_latents.
        width: Image width in pixels, forwarded to produce_latents.
        num_inference_steps: Number of denoising steps, forwarded to produce_latents.
        guidance_scale: Guidance weight, forwarded to produce_latents.
        latents: Optional starting latents, forwarded to produce_latents.

    Returns:
        The list of images returned by decode_img_latents.
    """
    # Wrap a lone prompt so the embedding step always receives a list
    prompt_batch = [prompts] if isinstance(prompts, str) else prompts

    final_latents = produce_latents(
        get_text_embeds(prompt_batch),
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        latents=latents,
    )
    return decode_img_latents(final_latents)
