In [1]:
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from torch import autocast
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm.auto import tqdm

# Setup
#runwayml/stable-diffusion-v1-5
diffusion_model_id = 'CompVis/stable-diffusion-v1-4'
text_encoder_model_id = 'openai/clip-vit-large-patch14'
device = 'cuda'
seed_cpu = torch.Generator('cpu').manual_seed(1024)
seed_gpu = torch.Generator('cuda').manual_seed(1024)

# Hugging Face access token
token = ''
with open('hugging_face_token.txt', 'r') as secret:
    token = secret.readline().strip()

In [None]:
# Load model components

# Variational Autoencoder
vae = AutoencoderKL.from_pretrained(
    diffusion_model_id, subfolder='vae', torch_dtype=torch.float16,
    revision='fp16', use_auth_token=token)
vae.to(device)

# U-Net Model
u_net = UNet2DConditionModel.from_pretrained(
    diffusion_model_id, subfolder='unet', torch_dtype=torch.float16,
    revision='fp16', use_auth_token=token)
u_net.to(device)

# Text Encoder + Tokenizer
tokenizer = CLIPTokenizer.from_pretrained(text_encoder_model_id)
text_encoder = CLIPTextModel.from_pretrained(text_encoder_model_id)
text_encoder = text_encoder.to(device)

# Scheduler
scheduler = DDPMScheduler.from_config(
    diffusion_model_id, subfolder='scheduler', torch_dtype=torch.float16,
    revision='fp16', use_auth_token=token)

In [3]:
# Get prompt embeddings
def get_text_embeddings(prompt):
    # Tokenize text and get embeddings
    text_input = tokenizer(
        prompt, padding='max_length', max_length=tokenizer.model_max_length,
        truncation=True, return_tensors='pt')
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # Do the same for unconditional embeddings
    unconditional_input = tokenizer(
        [''] * len(prompt), padding='max_length', max_length=tokenizer.model_max_length,
        return_tensors='pt')
    unconditional_embeddings = text_encoder(unconditional_input.input_ids.to(device))[0]

    # Cat for final embeddings
    text_embeddings = torch.cat([unconditional_embeddings, text_embeddings])
    return text_embeddings

In [4]:
def produce_latents(text_embeddings, height=512, width=512,
    num_inference_steps=50, guidance_scale=7.5, latents=None):
    if latents is None:
        latents = torch.randn((text_embeddings.shape[0] // 2, u_net.in_channels, height // 8, width // 8), generator=seed_cpu)
        latents = latents.to(device)

    scheduler.set_timesteps(num_inference_steps)
    latents *= scheduler.init_noise_sigma

    for t in tqdm(scheduler.timesteps):
        # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = u_net(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        noise_pred_unconditional, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_unconditional + guidance_scale * (noise_pred_text - noise_pred_unconditional)

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample
    return latents

In [5]:
def decode_img_latents(latents):
    latents = 1 / 0.18215 * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype('uint8')
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images

In [6]:
prompt = 'astronaut in water'
text_embeddings = get_text_embeddings(prompt)
latents = produce_latents(text_embeddings)
# TODO: fix this
#imgs = decode_img_latents(latents)
#imgs[0]

  0%|          | 0/50 [00:00<?, ?it/s]

RuntimeError: Input type (float) and bias type (struct c10::Half) should be the same