In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import cv2 as cv
import numpy as np
!pip install tensorflow_datasets
!pip install keras_cv
!pip install sklearn
import keras_cv

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# Stable Diffusion pipeline, built with jit_compile disabled.
model = keras_cv.models.StableDiffusion(jit_compile=False)

# Text embeddings of the two prompts the latent walk travels between.
encoding_1, encoding_2 = [
    model.encode_text(prompt)
    for prompt in (
        "Photo of an astronaut riding a horse",
        "Photo of a lion surfing on a wave, octane render, realistic, cinematic, focused, sharp, epic, beautiful, vibrant",
    )
]

# Starting diffusion noise on a (512 // 8, 512 // 8, 4) grid — presumably the latent
# size for 512x512 output; TODO confirm against the model's image size.
# NOTE(review): with only an op-level seed, TensorFlow in eager mode returns different
# tensors on repeated calls, so noise and noise2 are expected to differ.
seed = 12345
noise, noise2 = [tf.random.normal((512 // 8, 512 // 8, 4), seed=seed) for _ in range(2)]
# Number of diffusion steps used for every generated image.
quality = 10

def _render_walk_frame(model, encoding, position, diffusion_noise, num_steps):
    """Generate one frame for `encoding` and save it as ``image_at{position}.png``.

    Returns the batch returned by ``model.generate_image`` (batch_size=1); the frame
    itself is element 0 and is written to disk after RGB -> BGR conversion for OpenCV.
    """
    batch = model.generate_image(
        encoding,
        batch_size=1,
        diffusion_noise=diffusion_noise,
        num_steps=num_steps,
    )
    cv.imwrite(f"image_at{position}.png", cv.cvtColor(batch[0], cv.COLOR_RGB2BGR))
    return batch


def _edge_energy_difference(frame_a, frame_b):
    """Compare the Sobel edge energy of two frames channel by channel.

    Each channel is scaled by 1/255 and its edge energy is ``sobel_x**2 + sobel_y**2``
    (3x3 kernel). A channel contributes the squared difference of the two energies.

    Returns:
        (total, maps): the difference summed over all channels and pixels, and the
        per-channel difference maps (used for the debug plot).
    """
    total = 0
    maps = []
    for channel in range(frame_a.shape[2]):
        energies = []
        for frame in (frame_a, frame_b):
            plane = frame[:, :, channel] / 255.0
            grad_x = cv.Sobel(plane, -1, 1, 0, 3)
            grad_y = cv.Sobel(plane, -1, 0, 1, 3)
            energies.append(grad_x**2 + grad_y**2)
        channel_map = (energies[0] - energies[1]) ** 2
        total += np.sum(channel_map)
        maps.append(channel_map)
    return total, maps


def recursive_distance_walk(model, encoding_1, encoding_2, max_frame_distance=0.1, p1=0, p2=1, img1=None, img2=None,
                            max_depth=None, diffusion_noise=None, num_steps=None, show_plots=True):
    """Subdivide the segment between two text encodings until adjacent frames look alike.

    Renders (or reuses) the frames at both ends and measures their Sobel edge-energy
    difference. If it is not below `max_frame_distance`, the encodings' midpoint is
    rendered once and both halves are walked recursively.

    Args:
        model: object providing ``generate_image(encoding, batch_size=, diffusion_noise=, num_steps=)``.
        encoding_1, encoding_2: text encodings at the two ends of the segment.
        max_frame_distance: stop subdividing once the edge difference is below this value.
        p1, p2: positions of the ends along the walk; used only in the saved file names.
        img1, img2: already-rendered batches for the ends (None means render them here).
        max_depth: maximum number of further subdivisions (None means unlimited). With the
            default threshold of 0.1, real frames may never get that close, so set a cap
            or a larger threshold.
        diffusion_noise: noise for every frame; defaults to the global ``noise``.
        num_steps: diffusion steps per frame; defaults to the global ``quality``.
        show_plots: display both frames and the difference maps at every step.

    Returns:
        List of rendered batches (frame at index 0 of each) in walk order. Each frame
        appears exactly once.
    """
    if diffusion_noise is None:
        diffusion_noise = noise
    if num_steps is None:
        num_steps = quality

    image_1 = img1 if img1 is not None else _render_walk_frame(model, encoding_1, p1, diffusion_noise, num_steps)
    image_2 = img2 if img2 is not None else _render_walk_frame(model, encoding_2, p2, diffusion_noise, num_steps)

    if show_plots:
        for image in (image_1, image_2):
            plt.imshow(image[0])
            plt.show()

    similarity, difference_maps = _edge_energy_difference(image_1[0], image_2[0])
    print(similarity)
    if show_plots:
        plt.imshow(np.stack(difference_maps, axis=-1) * 100)
        plt.show()

    if similarity < max_frame_distance or (max_depth is not None and max_depth <= 0):
        return [image_1, image_2]

    p3 = 0.5 * (p1 + p2)
    middle_encoding = (encoding_1 + encoding_2) / 2.0
    # Render the midpoint once and hand it to both halves, so it is neither
    # generated twice nor returned twice.
    middle_image = _render_walk_frame(model, middle_encoding, p3, diffusion_noise, num_steps)
    child_options = dict(
        max_depth=None if max_depth is None else max_depth - 1,
        diffusion_noise=diffusion_noise,
        num_steps=num_steps,
        show_plots=show_plots,
    )
    left = recursive_distance_walk(model, encoding_1, middle_encoding, max_frame_distance,
                                   p1=p1, p2=p3, img1=image_1, img2=middle_image, **child_options)
    right = recursive_distance_walk(model, middle_encoding, encoding_2, max_frame_distance,
                                    p1=p3, p2=p2, img1=middle_image, img2=image_2, **child_options)
    # `right` starts with the midpoint that already ends `left`.
    return left + right[1:]

By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE


In [None]:
class PointRepresentation(object):
  """One sample on an interpolation path: a frame plus the inputs that produce it.

  Attributes (set from the constructor arguments):
    image: the rendered frame, or None until something renders it; indexed as
      image[:, :, channel] (the /255.0 scaling suggests 8-bit values — confirm).
    embedding: text embedding for this sample.
    linear_position: scalar position along the path.
    noise: diffusion noise for this sample.
  """

  def __init__(self, image, embedding, linear_position, noise):
    self.image, self.embedding = image, embedding
    self.linear_position, self.noise = linear_position, noise

  def get_image_distance(self, other):
    """Edge-weighted distance between this point's frame and `other`'s frame.

    For every channel (both frames scaled by 1/255), adds the summed squared
    difference of Sobel gradient energies (sobel_x**2 + sobel_y**2, 3x3 kernel)
    plus 0.01 times the summed squared pixel difference.
    """
    def gradient_energy(plane):
      grad_x = cv.Sobel(plane, -1, 1, 0, 3)
      grad_y = cv.Sobel(plane, -1, 0, 1, 3)
      return grad_x**2 + grad_y**2

    total = 0
    for channel in range(self.image.shape[2]):
      mine = self.image[:, :, channel] / 255.0
      theirs = other.image[:, :, channel] / 255.0
      edge_term = np.sum((gradient_energy(mine) - gradient_energy(theirs)) ** 2)
      pixel_term = np.sum((mine - theirs) ** 2)
      total += edge_term + 0.01 * pixel_term
    return total

  def get_mix(self, other, factor):
    """Return a new, unrendered point `factor` of the way from self to `other`.

    Embedding, position and noise are each blended linearly
    (other * factor + (1 - factor) * self); the new point's image is None.

    NOTE(review): linearly blending two independent Gaussian noise tensors shrinks
    the noise (factor 0.5 scales the std by ~1/sqrt(2)); spherical interpolation
    is often used instead — confirm whether that matters for the callers.
    """
    def blend(mine, theirs):
      return theirs * factor + (1 - factor) * mine

    return PointRepresentation(
        None,
        blend(self.embedding, other.embedding),
        blend(self.linear_position, other.linear_position),
        blend(self.noise, other.noise),
    )


class NbestInterpolationsFinder(object):
  """Greedily choose interpolation points between two text embeddings.

  Renders the two endpoint frames. It then repeatedly renders the midpoint of the
  adjacent pair whose frames are furthest apart (per
  ``PointRepresentation.get_image_distance``) until ``n`` frames exist.
  """

  def __init__(self, model, n, noise, noise2, embedding_start, embedding_end, quality=10):
    """Store the generation settings.

    Args:
      model: object providing ``generate_image(embedding, batch_size=, diffusion_noise=, num_steps=)``.
      n: number of frames ``calculate_images`` stops at (the two endpoints are always rendered).
      noise: diffusion noise for the start frame.
      noise2: diffusion noise for the end frame.
      embedding_start: text embedding at position 0.
      embedding_end: text embedding at position 1.
      quality: ``num_steps`` passed to ``generate_image``.
    """
    self.model = model
    self.n = n
    self.noise = noise
    self.noise2 = noise2
    self.embedding_start = embedding_start
    self.embedding_end = embedding_end
    self.quality = quality

  def _render(self, point):
    """Set ``point.image`` to the first frame generated from the point's embedding and noise."""
    batch = self.model.generate_image(
        point.embedding,
        batch_size=1,
        diffusion_noise=point.noise,
        num_steps=self.quality
    )
    point.image = batch[0]

  def calculate_images(self):
    """Render frames until there are ``n`` of them (never fewer than the two endpoints).

    Returns:
      List of rendered PointRepresentation objects in path order (positions from 0 to 1).
    """
    representations = [
        PointRepresentation(None, self.embedding_start, 0, self.noise),
        PointRepresentation(None, self.embedding_end, 1, self.noise2),
    ]
    for representation in representations:
      self._render(representation)
    if len(representations) >= self.n:
      return representations

    # gaps[i] caches the distance between representations[i] and representations[i + 1].
    # An insertion only splits one gap, so each round costs two distance evaluations
    # instead of re-measuring every adjacent pair (O(n) instead of O(n^2) overall).
    gaps = [representations[0].get_image_distance(representations[1])]
    while len(representations) < self.n:
      # max() returns the first of several equal gaps, like a strict ">" scan.
      worst = max(range(len(gaps)), key=gaps.__getitem__)
      print("Worst distance was: " )
      print(gaps[worst])
      self.sample_at(representations[worst:worst + 2], representations)
      left, middle, right = representations[worst:worst + 3]
      gaps[worst:worst + 1] = [left.get_image_distance(middle), middle.get_image_distance(right)]
    return representations

  def sample_at(self, area, representations):
    """Render the midpoint of ``area[0]`` and ``area[1]`` and insert it right after ``area[0]``.

    Args:
      area: sequence whose first two items are adjacent points of ``representations``.
      representations: list of points; modified in place.

    Returns:
      The newly inserted point.
    """
    point_a, point_b = area[0], area[1]
    middle_point = point_a.get_mix(point_b, 0.5)
    self._render(middle_point)
    insertion_index = representations.index(point_a) + 1
    representations.insert(insertion_index, middle_point)
    return middle_point


class TextToVideo(object):
  """Render a frame sequence that interpolates through a list of prompts, in order."""

  def __init__(self, model, texts, noise_function, steps_per_image=10, steps_per_prompt=300):
    """Store the generation settings.

    Args:
      model: object providing ``encode_text`` and ``generate_image``.
      texts: prompts, visited in order.
      noise_function: called with a prompt's index in ``texts``; returns the diffusion
        noise used to render that prompt's frame.
      steps_per_image: diffusion steps per frame (NbestInterpolationsFinder ``quality``).
      steps_per_prompt: frames rendered for each pair of consecutive prompts
        (NbestInterpolationsFinder ``n``).
    """
    self.model = model
    self.texts = texts
    self.noise_function = noise_function
    self.steps_per_image = steps_per_image
    self.steps_per_prompt = steps_per_prompt

  def generate_video(self):
    """Render every prompt-to-prompt segment and join them.

    Returns:
      List of frames (each interpolated point's ``.image``) in playback order. The
      frame where two segments meet is included once. Fewer than two texts yields [].
    """
    encodings = [self.model.encode_text(text) for text in self.texts]
    video_images = []
    for index, (encoding_a, encoding_b) in enumerate(zip(encodings[:-1], encodings[1:])):
      # Noise is keyed by prompt index, so a prompt shared by two segments is
      # rendered with the same noise in both.
      best_interpolation_finder = NbestInterpolationsFinder(
          self.model,
          self.steps_per_prompt,
          self.noise_function(index),
          self.noise_function(index + 1),
          encoding_a,
          encoding_b,
          quality=self.steps_per_image,
      )
      frames = [point.image for point in best_interpolation_finder.calculate_images()]
      # A later segment's first frame comes from the same embedding and noise as the
      # previous segment's last frame, so skip it.
      video_images += frames if index == 0 else frames[1:]
    return video_images

In [None]:
# Render the prompt-to-prompt walk and save each frame as walk_step_<i>.png
# (the GIF cell below reads these files back through load_images()).
from pathlib import Path

# TODO: the prompts are empty placeholders — fill in the texts to interpolate between.
texts = [
  "",
  ""
]
noise = tf.random.normal((512 // 8, 512 // 8, 4), seed=seed)

def noise_func(x):
  """Return the same diffusion noise for every prompt index `x`."""
  return noise

ttv = TextToVideo(model, texts, noise_func, steps_per_prompt=100)
video = ttv.generate_video()

# Remove frames left by earlier (possibly longer) runs so only this run's frames are exported.
for stale_frame in list(Path(".").glob("walk_step_*.png")):
  stale_frame.unlink()

for step_index, img in enumerate(video):
  # generate_video() may yield raw frames or points that carry the frame in `.image`.
  frame = getattr(img, "image", img)
  plt.imshow(frame)
  plt.show()
  cv.imwrite(f"walk_step_{step_index}.png", cv.cvtColor(frame, cv.COLOR_RGB2BGR))

Worst distance was: 
305947.87916224997
Worst distance was: 
266484.6724185097
Worst distance was: 
266192.77345827717
Worst distance was: 
264702.0424257637
Worst distance was: 
209206.34592630283
Worst distance was: 
176321.80992403967
Worst distance was: 
119121.67625105714
Worst distance was: 
99967.75759962527
Worst distance was: 
96511.13941672556
Worst distance was: 
94061.99184653204

In [None]:
#step_index = 0
#for img in recursive_distance_walk(model, encoding_1, encoding_2, max_frame_distance=10000):
#  plt.imshow(img[0])
#  plt.show()
#  cv.imwrite(f"walk_step_{step_index}.png", cv.cvtColor(img[0], cv.COLOR_RGB2BGR))
#  step_index += 1

In [None]:
import os
from PIL import Image


def load_images(directory="./", prefix="walk_step_"):
    """Load the numbered frames written by the walk cell, in numeric order.

    Only files named ``<prefix><number>.<ext>`` are used; other files that merely
    start with `prefix` (e.g. ``walk_step_final.png``) are skipped instead of breaking
    the sort. Ordering is numeric, so ``walk_step_10`` follows ``walk_step_9``.

    Args:
        directory: folder to scan (defaults to the current directory).
        prefix: filename prefix of the frames.

    Returns:
        List of images as returned by ``cv.imread`` (OpenCV's BGR channel order).

    Raises:
        ValueError: if a matching file cannot be decoded (``cv.imread`` returns None).
    """
    import re

    pattern = re.compile(rf"^{re.escape(prefix)}(\d+)\.[^.]+$")
    numbered = []
    for name in os.listdir(directory):
        match = pattern.match(name)
        if match:
            numbered.append((int(match.group(1)), name))

    images = []
    for _, name in sorted(numbered):
        path = os.path.join(directory, name)
        image = cv.imread(path)
        if image is None:
            raise ValueError(f"could not read image file: {path}")
        images.append(image)
    return images

def export_as_gif(filename, images, frames_per_second=10, rubber_band=False):
    """Write `images` as an endlessly looping animated GIF.

    Args:
        filename: output path.
        images: sequence of PIL images (anything with a PIL-style ``save``); the
            caller's sequence is not modified.
        frames_per_second: playback rate; each frame lasts ``1000 // frames_per_second`` ms.
        rubber_band: if True, play forward then backward (0..n-1, n-2..1) before
            looping, without repeating either end frame.

    Raises:
        ValueError: if `images` is empty.
    """
    frames = list(images)  # copy so the rubber-band extension never touches the caller's list
    if not frames:
        raise ValueError("export_as_gif needs at least one image")
    if rubber_band:
        # Interior frames only: repeating frame n-1 would stall at the turnaround, and
        # frame 0 is shown again when the GIF loops.
        frames += frames[1:-1][::-1]
    frames[0].save(
        filename,
        save_all=True,
        append_images=frames[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )


# Convert the saved BGR frames back to RGB PIL images and write the ping-pong GIF.
export_as_gif(
    filename="Dragon-astronaut.gif",
    images=[
        Image.fromarray(cv.cvtColor(frame_bgr, cv.COLOR_BGR2RGB))
        for frame_bgr in load_images()
    ],
    frames_per_second=30,
    rubber_band=True,
)