<a href="https://colab.research.google.com/github/AdamPeetz/stable_diffusion/blob/main/latent_space_walks_V5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Latent Space Interpolations with Stable Diffusion

The input manifold of the diffusion model is a hypersphere of noise. Words are encoded to locations in the hypersphere which can be translated into images by the model. The intermediate points between two encoded prompts can be calculated and fed into the model to create a smooth transition between two images. This process is called latent space interpolation.

## Install Keras CV Library

The Keras_CV library contains the Stable Diffusion model and must be installed separately from the rest of the Keras package.

In [None]:
!pip install tensorflow keras_cv --upgrade --quiet


## Load Required Libraries

Several support libraries are used in this notebook.

Keras_CV and Keras provide API support for working with neural networks. <br>
TensorFlow and NumPy are used for tensor manipulation.<br>
Math provides support for equations used in the code.<br>
Matplotlib is used to generate images.<br>
TQDM is used to display the progress of looping operations.<br>
OS and Shutil are used to navigate directory structures.<br>
Google.Colab Drive allows the notebook to communicate with Google Drive.<br>


In [None]:
import keras_cv
from tensorflow import keras
from matplotlib import pyplot
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import math
from tqdm import tqdm
import os, shutil
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


## Clear the Keras Backend

Clear the Keras backend between successive model runs, because variables stored in the backend may hold state from previous runs.

In [None]:
tf.keras.backend.clear_session()

## Load the Stable Diffusion Model

A stable diffusion model object is defined here for use in multiple areas of the notebook.

The dimensions of the image the model produces are selected at this point. They must match the noise dimensions defined in the text encoding cell, where each noise dimension is the corresponding image dimension divided by 8. The image dimensions must also be sizes the model accepts: KerasCV rounds the requested height and width to multiples of 128.
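The link between image size and noise size can be sketched as follows. The helper name `latent_noise_shape` and its parameters are hypothetical; the factor of 8 and the 4 channels are taken from the noise definition used later in this notebook (`768 // 8, 512 // 8, 4`).

```python
def latent_noise_shape(img_height, img_width, downscale=8, channels=4):
    """Return the (rows, cols, channels) shape of the diffusion noise tensor."""
    if img_height % downscale or img_width % downscale:
        raise ValueError("image dimensions must be divisible by %d" % downscale)
    return (img_height // downscale, img_width // downscale, channels)


print(latent_noise_shape(768, 512))  # (96, 64, 4)
```

Deriving the noise shape from `height` and `width` instead of typing the numbers twice keeps the two cells from drifting apart when the image size changes.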

The image dimensions have a significant impact on the image generated by the model and the resources required to create that image. Image resolution can be increased after the model has run using an ESRGAN and a CPU.

In [None]:
# Enable mixed precision
# (only do this if you have a recent NVIDIA GPU)
keras.mixed_precision.set_global_policy("mixed_float16")

# define image dimensions
height = 768
width = 512

# Instantiate the Stable Diffusion model
model = keras_cv.models.StableDiffusion(img_height=height, img_width=width, jit_compile=True)

By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE


## Helper Functions

### save_images

Saves images created by the model to a single directory location.

### glob_save_images

Saves images created by the model to a series of numbered directories. Google Colab and Drive are known to experience errors when a directory contains too many files, so splitting the output into subdirectories of about 100 files each helps avoid those errors. The code is written for runs of up to 1100 total images.
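One way to picture the directory layout is a function that maps an image index to its subdirectory. This is only a sketch: the name `glob_directory` is hypothetical, and it assumes exactly 100 files per directory with five-digit, 1-based directory names.

```python
def glob_directory(index, files_per_dir=100):
    """Return the 1-based, zero-padded subdirectory name for an image index."""
    return "%05d" % (index // files_per_dir + 1)


print(glob_directory(0))     # 00001
print(glob_directory(99))    # 00001
print(glob_directory(100))   # 00002
print(glob_directory(1023))  # 00011
```

Integer division gives the directory number directly, so no separate range check is needed for each block of 100 images.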

### slerp

The latent space of the model is a hypersphere. Linear interpolation between two points in the sphere does not create a smooth transition between two images. A spherical linear interpolation, or slerp, generates an arc between two points in space. It results in smoother transitions between images when used for latent space walks.
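The difference between the two interpolation schemes can be seen on a toy example. The sketch below uses plain Python lists and two perpendicular unit vectors; the names `lerp`, `slerp_unit`, and `norm` are illustrative and are not the notebook's own `slerp` function, which works on tensors.

```python
import math


def lerp(t, v0, v1):
    """Linear interpolation between two vectors given as lists."""
    return [(1 - t) * a + t * b for a, b in zip(v0, v1)]


def slerp_unit(t, v0, v1):
    """Spherical interpolation between two non-parallel unit vectors."""
    dot = sum(a * b for a, b in zip(v0, v1))
    theta = math.acos(max(-1.0, min(1.0, dot)))  # angle between the vectors
    s0 = math.sin((1 - t) * theta) / math.sin(theta)
    s1 = math.sin(t * theta) / math.sin(theta)
    return [s0 * a + s1 * b for a, b in zip(v0, v1)]


def norm(v):
    """Euclidean length of a vector."""
    return math.sqrt(sum(a * a for a in v))


v0, v1 = [1.0, 0.0], [0.0, 1.0]
print(round(norm(lerp(0.5, v0, v1)), 4))        # 0.7071
print(round(norm(slerp_unit(0.5, v0, v1)), 4))  # 1.0
```

The linear midpoint falls inside the sphere, while the spherical midpoint stays on it.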

In [None]:
def save_images(image_library, image_counter):
    # save each generated image as a numbered .png in a single Drive directory
    for idx in tqdm(range(len(image_library))):
        plt.figure(figsize=(10, 10))
        plt.imshow(image_library[idx][0])
        plt.axis('off')
        filename = '/content/gdrive/My Drive/planegan/sample_output_2/generated_plot_%05d.png' % (image_counter + 1)
        plt.savefig(filename, transparent=True, bbox_inches="tight", pad_inches=0.0)
        plt.close()
        # the parameter is image_counter; the original body referenced an undefined image_count
        image_counter += 1

In [None]:
def glob_save_images(base_directory, image_library, image_counter, files_per_dir=100):
    # save images into numbered subdirectories (00001, 00002, ...) of at most
    # files_per_dir files each; the directory number comes from integer division
    for idx in tqdm(range(len(image_library))):
        out_dir = os.path.join(base_directory, "%05d" % (idx // files_per_dir + 1))
        # exist_ok avoids an error when the directory is left over from an earlier run
        os.makedirs(out_dir, exist_ok=True)
        plt.figure(figsize=(10, 10))
        plt.imshow(image_library[idx][0])
        plt.axis('off')
        filename = os.path.join(out_dir, "generated_plot_%05d.png" % (image_counter + 1))
        plt.savefig(filename, transparent=True, bbox_inches="tight", pad_inches=0.0)
        plt.close()
        image_counter += 1

In [None]:
# function for spherical interpolation pathway
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):

    # convert v0 and v1 to numpy arrays
    v0 = v0.numpy()
    v1 = v1.numpy()

    # calculate the normalized dot product (cosine of the angle between the vectors)
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    # if the vectors are nearly parallel, fall back to linear interpolation
    if np.abs(dot) > DOT_THRESHOLD:
        v2 = (1 - t) * v0 + t * v1
    # otherwise, interpolate spherically
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    # convert the result back to a tensor on both branches
    v2 = tf.convert_to_tensor(v2, float)

    # return new vector
    return v2

## Encode Text Prompts to the Model's Latent Space

Generates encodings for text prompts and defines other variables used in the model.

A seed can be imposed. This allows results to be duplicated between runs.

Positive prompts are defined as strings and squeezed into encodings. Model results are improved when prompts have a level of complexity.

A negative prompt can be used to further steer the model. A negative prompt guides the model away from certain types of images.

The number of interpolation steps correlates to the number of intermediate images produced by the model.

The batch size is the number of images produced at each step. This should be set to 1: the same fixed noise is used for every image, so a batch size larger than 1 only produces duplicates of the same image.
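As a rough illustration of how the step count translates into interpolation positions and output volume, the sketch below assumes evenly spaced positions with both endpoints included, which is what `np.linspace(0, 1, steps)` produces. The helper name `interpolation_fractions` is hypothetical.

```python
def interpolation_fractions(steps):
    """Evenly spaced fractions from 0.0 to 1.0 inclusive."""
    if steps < 2:
        raise ValueError("need at least two steps")
    return [i / (steps - 1) for i in range(steps)]


print(interpolation_fractions(5))  # [0.0, 0.25, 0.5, 0.75, 1.0]

# three prompt-to-prompt segments at 256 steps each, one image per step
print(3 * 256)  # 768
```

Because both endpoints are included, the last position of one segment and the first position of the next refer to the same prompt encoding.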



In [None]:
# set noise seed for repeatability
seed = 4121990

# calculate noise input dimensions from the image dimensions defined with the model
noise = tf.random.normal((height // 8, width // 8, 4), seed=seed)

# define positive string prompts
prompt_1 = "galaxies in space from a distance, fractal, psytrance, high quality, masterpeice, high definition, highly detailed, elegant, sharp focus, digital painting, scifi, fantasy, center frame"
prompt_2 = " devouring black holes in space, fractal, psytrance, high quality, masterpeice, high definition, highly detailed, elegant, sharp focus,concept art, digital painting, scifi, fantasy, center frame"
prompt_3 = "battle between heaven and hell over a planet in space, explosions, fractal, psytrance, high quality, masterpeice, high definition, highly detailed, sharp focus, digital painting, scifi, fantasy"

# define negative string prompts
negative = "extra limbs, bad art, watermark, face, dull, pencils, error, malformed, low detail, jpeg artifacts, cropped, plain background, ugly, low-res, poorly drawn face, out of frame, poorly drawn hands, blurry, bad art, extra hands, bad anatomy, amputee, missing limbs, amputated"
# set number of interpolation steps between points in space
interpolation_steps = 256
# define model batch size (1 for 1 image at each step, if seed is fixed, will produce duplicates of the same image)
batch_size = 1
# create number of batches
batches = (interpolation_steps) // batch_size

# encode prompts to latent space
encoding_1 = tf.squeeze(model.encode_text(prompt_1))
encoding_2 = tf.squeeze(model.encode_text(prompt_2))
encoding_3 = tf.squeeze(model.encode_text(prompt_3))

## Interpolate Points Between Prompt N and Prompt N+1 and Generate Images

Generate arcs between two encodings and create images with locations on each arc.

A tensor list variable is defined to hold locations along an arc. Those locations are calculated with slerp and appended to the tensor list.

An images variable is defined and the model is run to generate an image tensor for each point in the tensor list. Those image tensors are appended to the images variable.

This cluster of code is repeated for each pair of text encodings the model interpolates between. In this example, 3 encodings are joined in a closed loop: 1 to 2, 2 to 3, and 3 back to 1.
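The pairing of encodings can also be generated rather than written out by hand. The sketch below only shows the pairing logic, with placeholder strings standing in for the encoded prompts; the helper name `closed_loop_pairs` is hypothetical.

```python
def closed_loop_pairs(items):
    """Pair each item with its successor, wrapping the last back to the first."""
    return list(zip(items, items[1:] + items[:1]))


# placeholder labels stand in for the encoded prompts
print(closed_loop_pairs(["encoding_1", "encoding_2", "encoding_3"]))
# [('encoding_1', 'encoding_2'), ('encoding_2', 'encoding_3'), ('encoding_3', 'encoding_1')]
```

Looping over these pairs would let a single cell cover any number of prompts, at the cost of a longer-running cell.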

In [None]:
# employ slerp between encoding_1 and encoding_2
# np.linspace yields interpolation_steps evenly spaced values of t in [0, 1]
tensor_list = []
for t in np.linspace(0, 1, interpolation_steps):
    new_latent = slerp(float(t), encoding_1, encoding_2)
    tensor_list.append(new_latent)

# generate one image per interpolated encoding
images = []
for batch in range(interpolation_steps):
    print(batch)
    images.append(
        model.generate_image(
            tensor_list[batch],
            negative_prompt=negative,
            batch_size=batch_size,
            diffusion_noise=noise,
        )
    )

In [None]:
# employ slerp between encoding_2 and encoding_3
# np.linspace yields interpolation_steps evenly spaced values of t in [0, 1]
tensor_list = []
for t in np.linspace(0, 1, interpolation_steps):
    new_latent = slerp(float(t), encoding_2, encoding_3)
    tensor_list.append(new_latent)

# append to the images list created in the previous cell
for batch in range(interpolation_steps):
    print(batch)
    images.append(
        model.generate_image(
            tensor_list[batch],
            negative_prompt=negative,
            batch_size=batch_size,
            diffusion_noise=noise,
        )
    )

In [None]:
# employ slerp between encoding_3 and encoding_1 to close the loop
# np.linspace yields interpolation_steps evenly spaced values of t in [0, 1]
tensor_list = []
for t in np.linspace(0, 1, interpolation_steps):
    new_latent = slerp(float(t), encoding_3, encoding_1)
    tensor_list.append(new_latent)

# append to the images list created in the first interpolation cell
for batch in range(interpolation_steps):
    print(batch)
    images.append(
        model.generate_image(
            tensor_list[batch],
            negative_prompt=negative,
            batch_size=batch_size,
            diffusion_noise=noise,
        )
    )

## Save Generated Images to a Directory

The image tensors stored in the images variable are rendered to .png files by Matplotlib and exported to a directory.

In [None]:
counter = 0
base_directory = '/content/gdrive/My Drive/planegan/sample_output_4/'
glob_save_images(base_directory, images, counter)

# Circular Noise Walks
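The next cell keeps the prompt encoding fixed and varies the diffusion noise instead. Each noise tensor is a weighted sum of two fixed noise samples, with weights cos(θ) and sin(θ) for an angle θ that advances at every step. The sketch below shows only the weights, in plain Python; the function name `circular_weights` and the parameter `span` are hypothetical, with `span=1.0` matching the 0 to π range used in the cell.

```python
import math


def circular_weights(steps, span=1.0):
    """Return (cos, sin) weights for angles evenly spaced from 0 to span * pi."""
    angles = [span * math.pi * i / (steps - 1) for i in range(steps)]
    return [(math.cos(a), math.sin(a)) for a in angles]


weights = circular_weights(5)
for wx, wy in weights:
    # the squared weights sum to 1 at every step
    print(round(wx, 3), round(wy, 3), round(wx * wx + wy * wy, 3))
```

Because the squared weights always sum to 1, blending two independent standard normal samples this way keeps the variance of the noise at 1 throughout the walk.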

In [None]:
# set noise seed for repeatability
tf.random.set_seed(12345)
set_seed = 500

# calculate noise input dimensions from the image dimensions defined with the model
noise = tf.random.normal((height // 8, width // 8, 4), seed=set_seed)

circular_prompt = "devouring black holes in space, fractal, psytrance, high quality, masterpeice, high definition, highly detailed, elegant, sharp focus,concept art, digital painting, scifi, fantasy, center frame"
negative = "extra limbs, bad art, watermark, face, dull, pencils, error, malformed, low detail, jpeg artifacts, cropped, plain background, ugly, low-res, poorly drawn face, out of frame, poorly drawn hands, blurry, bad art, extra hands, bad anatomy, amputee, missing limbs, amputated"
encoding = tf.squeeze(model.encode_text(circular_prompt))
walk_steps = 1024
batch_size = 1
batches = walk_steps // batch_size

walk_noise_x = tf.random.normal(noise.shape, dtype=tf.float64, seed=set_seed)
walk_noise_y = tf.random.normal(noise.shape, dtype=tf.float64, seed=set_seed)

# linspace range halved for testing: angles span 0 to pi, i.e. half a circle
walk_scale_x = tf.cos(tf.linspace(0, 1, walk_steps) * math.pi)
walk_scale_y = tf.sin(tf.linspace(0, 1, walk_steps) * math.pi)
noise_x = tf.tensordot(walk_scale_x, walk_noise_x, axes=0)
noise_y = tf.tensordot(walk_scale_y, walk_noise_y, axes=0)
noise = tf.add(noise_x, noise_y)
batched_noise = tf.split(noise, batches)


Downloading data from https://github.com/openai/CLIP/blob/main/clip/bpe_simple_vocab_16e6.txt.gz?raw=true
Downloading data from https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5


In [None]:
images = []
for batch in range(batches):
    print(batch)
    images.append(
        model.generate_image(
            encoding,
            batch_size=batch_size,
            negative_prompt=negative,
            diffusion_noise=batched_noise[batch],
        )
    )

In [None]:
counter = 0
base_directory = '/content/gdrive/My Drive/planegan/sample_output_3/'
glob_save_images(base_directory, images, counter)

## Unassign Cloud GPU

On Google Colab, use this section of code to disconnect the GPU after the run has completed.

In [None]:
from google.colab import runtime
runtime.unassign()