In [4]:
import argparse, os, sys, glob, cv2
from os import path
from base64 import b64encode
from PIL import Image, ImageDraw
from IPython.display import HTML

import numpy as np
import torch
from torch import autocast, float16
from torch.nn import functional as F

from diffusers import StableDiffusionPipeline, AutoencoderKL
from diffusers import UNet2DConditionModel, PNDMScheduler, LMSDiscreteScheduler
from diffusers.schedulers.scheduling_ddim import DDIMScheduler

from transformers import CLIPTextModel, CLIPTokenizer
from tqdm.auto import tqdm

from base64 import b64encode
# from IPython.display import HTML

device = 'cuda'

cwd = path.join(os.getcwd())
modelpath = 'models/ldm/stable-diffusion-v1'
loadpath = path.normpath(path.join(cwd, '..', modelpath))

# print(loadpath)


  from .autonotebook import tqdm as notebook_tqdm


In [5]:

# 1. Load the autoencoder model which will be used to decode the latents into image space. 
vae = AutoencoderKL.from_pretrained('/home/finn/data/stable-diffusion-v1-4/vae')
vae.to(device)

# 2. Load the tokenizer and text encoder to tokenize and encode the text. 
tokenizer = CLIPTokenizer.from_pretrained('/home/finn/data/stable-diffusion-v1-4/tokenizer')
text_encoder = CLIPTextModel.from_pretrained('/home/finn/data/stable-diffusion-v1-4/text_encoder')
text_encoder.to(device)

# 3. The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained('/home/finn/data/stable-diffusion-v1-4/unet')
unet.to(device)


# 4. Create a scheduler for inference
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule='scaled_linear',
    num_train_timesteps=1000
)

In [6]:
def get_text_embeds(prompt):
  # Tokenize text and get embeddings
  text_input = tokenizer(
      prompt, 
      padding='max_length', 
      max_length=tokenizer.model_max_length,
      truncation=True, 
      return_tensors='pt'
    )

  with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

  # Do the same for unconditional embeddings
  uncond_input = tokenizer(
      [''] * len(prompt), padding='max_length',
      max_length=tokenizer.model_max_length, return_tensors='pt')
  with torch.no_grad():
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

  # Cat for final embeddings
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
  return text_embeddings


In [8]:
def decode_img_latents(latents):
  latents = 1 / 0.18215 * latents

  with torch.no_grad():
    imgs = vae.decode(latents).sample

  imgs = (imgs + 0.5).clamp(0, 1)
  imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
  imgs = (imgs * 255).round().astype('uint8')
  pil_images = [Image.fromarray(image) for image in imgs]
  return pil_images


# Making Videos

In [10]:
def produce_latents(
    text_embeddings,
    height=512, 
    width=512,
    num_inference_steps=50, 
    guidance_scale=7.5, 
    latents=None,
    return_all_latents=False
):

  if latents is None:
    latents = torch.randn(
        (text_embeddings.shape[0] // 2, 
        unet.in_channels, 
        height // 8, 
        width // 8)
    )
  latents = latents.to(device)

  scheduler.set_timesteps(num_inference_steps)
  latents = latents * scheduler.sigmas[0]

  latents_history = [latents]
  with autocast('cuda'):
    for i, t in tqdm(enumerate(scheduler.timesteps)):
      # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
      latent_model_input = torch.cat([latents] * 2)
      sigma = scheduler.sigmas[i]
      latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)

      # predict the noise residual
      with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']

      # perform guidance
      noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
      noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

      # compute the previous noisy sample x_t -> x_t-1
      latents = scheduler.step(noise_pred, i, latents)['prev_sample']
      latents_history.append(latents)
  
  if not return_all_latents:
    return latents

  all_latents = torch.concat(latents_history, dim=0)
  return all_latents

In [11]:
def prompt_to_images(prompts, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None, return_all_latents=False, batch_size=4):
    if isinstance(prompts, str):
        prompts = [prompts]

    text_embeds = get_text_embeds(prompts)

    latents = produce_latents(
        text_embeds,
        height=height,
        width=width,
        latents=latents,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        return_all_latents=return_all_latents
    )

    
    all_imgs = []
    for i in tqdm(range(0, len(latents), batch_size)):
        imgs = decode_img_latents(latents[i:i+batch_size])
        all_imgs.extend(imgs)

    return all_imgs

In [22]:
prompt = 'flowers on a pond, Dinosaur tyrannosaurus rex holding a fishing rod, purple sky, classic romantic renaissance art style'
video_frames = prompt_to_images(prompt, num_inference_steps=50, return_all_latents=True)

50it [00:04, 10.50it/s]
100%|██████████| 13/13 [00:07<00:00,  1.77it/s]


In [16]:
def images_to_video(imgs, video_name='video.mp4', fps=15):
    video_dims = (imgs[0].width, imgs[0].height)
    fourcc = cv2.VideoWriter_fourcc(*'DIVX')
    video = cv2.VideoWriter(video_name, fourcc, fps, video_dims)

    for img in imgs:
        tmp_img=img.copy() 
        video.write(cv2.cvtColor(np.array(tmp_img), cv2.COLOR_RGB2BGR))

    for i in range(15):
        tmp_img=imgs[-1]
        video.write(cv2.cvtColor(np.array(tmp_img), cv2.COLOR_RGB2BGR))
    video.release()

In [17]:
def display_video(file_path, width=512):
    compressed_vid_path = 'comp_' + file_path
    if os.path.exists(compressed_vid_path):
        os.remove(compressed_vid_path)
    os.system(f'ffmpeg -i {file_path} -vcodec libx264 {compressed_vid_path}')

    mp4 = open(compressed_vid_path, 'rb').read()
    data_url = 'data:simul2/mp4;base64,' +b64encode(mp4).decode()
    return HTML("""
        <video width={} controls>
            <source src="{}" type="video/mp4" />
        </video>
    """.format(width, data_url))

In [23]:
video_name = prompt.replace(" ", "_") + ".mp4"
images_to_video(video_frames, video_name)
display_video(video_name)

OpenCV: FFMPEG: tag 0x58564944/'DIVX' is not supported with codec id 12 and format 'mp4 / MP4 (MPEG-4 Part 14)'
OpenCV: FFMPEG: fallback to use tag 0x7634706d/'mp4v'
ffmpeg version 4.2.2 Copyright (c) 2000-2019 the FFmpeg developers
  built with gcc 7.3.0 (crosstool-NG 1.23.0.449-a04d0)
  configuration: --prefix=/home/finn/miniconda3/envs/ldm --cc=/tmp/build/80754af9/ffmpeg_1587154242452/_build_env/bin/x86_64-conda_cos6-linux-gnu-cc --disable-doc --enable-avresample --enable-gmp --enable-hardcoded-tables --enable-libfreetype --enable-libvpx --enable-pthreads --enable-libopus --enable-postproc --enable-pic --enable-pthreads --enable-shared --enable-static --enable-version3 --enable-zlib --enable-libmp3lame --disable-nonfree --enable-gpl --enable-gnutls --disable-openssl --enable-libopenh264 --enable-libx264
  libavutil      56. 31.100 / 56. 31.100
  libavcodec     58. 54.100 / 58. 54.100
  libavformat    58. 29.100 / 58. 29.100
  libavdevice    58.  8.100 / 58.  8.100
  libavfilter     