# Installation and Imports

In [None]:
!pip install diffusers[torch] transformers accelerate image-reward clip

In [69]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from PIL import Image
import json
from io import BytesIO
import os
from functools import partialmethod
from PIL import Image

from diffusers import DiffusionPipeline, EulerDiscreteScheduler
from transformers import CLIPProcessor, CLIPModel
import ImageReward as RM

# Constant Definitions

In [70]:
# Prefer the GPU when torch can see one; everything else falls back to the CPU.
_DEVICE_NAME = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_DEVICE_NAME)
# Half precision is used for latents and model weights throughout the notebook.
DTYPE = torch.float16

# Hugging Face model identifiers.
TURBO_URL = "stabilityai/sd-turbo"
# TURBO_URL = "stabilityai/sdxl-turbo"

SD_URL = "stabilityai/stable-diffusion-2-1-base"

CLIP_URL = "openai/clip-vit-large-patch14"
REWARD_URL = "ImageReward-v1.0"


def LATENT_SHAPE(bs):
  """ Shape tuple of a latent batch holding `bs` samples. """
  return (bs, 4, 64, 64)


# Root folder for the benchmark prompts and saved results.
BASE_PATH = "./drive/MyDrive/RewardSearch/"

# Model Initialization

In [None]:
def get_turbo_pipeline():
  """ Load the SD-Turbo pipeline (fp16 weights, Euler discrete scheduler) onto the CPU.

  Returns:
    the loaded diffusers pipeline
  """
  euler = EulerDiscreteScheduler.from_pretrained(TURBO_URL, subfolder="scheduler")

  load_kwargs = dict(torch_dtype=torch.float16, variant="fp16", scheduler=euler)
  pipe = DiffusionPipeline.from_pretrained(TURBO_URL, **load_kwargs)
  pipe = pipe.to("cpu")

  # NOTE(review): the object returned by torch.compile is discarded, so the
  # caller receives the original pipeline -- confirm this call is intentional.
  torch.compile(pipe, mode="reduce-overhead", fullgraph=True)

  return pipe


def get_sd_pipeline():
  """ Load the Stable Diffusion 2.1 base pipeline (fp16 weights, Euler discrete scheduler) onto the CPU.

  Returns:
    the loaded diffusers pipeline
  """
  euler = EulerDiscreteScheduler.from_pretrained(SD_URL, subfolder="scheduler")

  load_kwargs = dict(torch_dtype=torch.float16, variant="fp16", scheduler=euler)
  pipe = DiffusionPipeline.from_pretrained(SD_URL, **load_kwargs)
  pipe = pipe.to("cpu")

  # NOTE(review): the object returned by torch.compile is discarded, so the
  # caller receives the original pipeline -- confirm this call is intentional.
  torch.compile(pipe, mode="reduce-overhead", fullgraph=True)

  return pipe


def get_clip():
  """ Load the CLIP processor and the CLIP model (in DTYPE) onto the CPU.

  Returns:
    (processor, model) tuple
  """
  processor = CLIPProcessor.from_pretrained(CLIP_URL)

  load_kwargs = dict(torch_dtype=DTYPE, variant="fp16", from_tf=True)
  model = CLIPModel.from_pretrained(CLIP_URL, **load_kwargs)
  model = model.to("cpu")

  # NOTE(review): the object returned by torch.compile is discarded, so the
  # caller receives the original model -- confirm this call is intentional.
  torch.compile(model, mode="reduce-overhead", fullgraph=True)

  return processor, model


def get_reward_model():
  """ Load the ImageReward model named by REWARD_URL onto the CPU.

  Returns:
    the loaded reward model
  """
  reward_net = RM.load(REWARD_URL)
  reward_net = reward_net.to("cpu")

  # NOTE(review): the object returned by torch.compile is discarded, so the
  # caller receives the original model -- confirm this call is intentional.
  torch.compile(reward_net, mode="reduce-overhead", fullgraph=True)

  return reward_net


# Instantiate every model once at notebook scope. Each loader leaves its
# model on the CPU; the Cudize context manager moves models to DEVICE on demand.
turbo_pipeline = get_turbo_pipeline()
sd_pipeline = get_sd_pipeline()
clip_processor, clip_model = get_clip()
reward_model = get_reward_model()

# Utilities

In [72]:
class DotDict(dict):
  """ A dictionary that allows item = d.key access for brevity.

  Attribute reads fall back to key lookup, and attribute writes are
  always stored as keys. Names that already exist on dict (e.g. `items`,
  `copy`) still resolve to the dict attribute when read.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args)

    for k, v in kwargs.items():
      self[k] = v


  def copy(self):
    """ Return a shallow copy that is still a DotDict. """
    return DotDict(self)


  @staticmethod
  def _is_item_sequence(v):
    """ True for lists, and for non-empty arrays whose first element is a dict. """
    if isinstance(v, list):
      return True
    # Guard against 0-d and empty arrays, where v[0] would raise IndexError.
    return (
        isinstance(v, np.ndarray)
        and v.ndim > 0
        and v.size > 0
        and isinstance(v[0], dict)
    )


  def from_dict(self, d, recursive=False):
    """ Copy the entries of `d` into this DotDict and return self.

    Args:
      d: mapping to copy from
      recursive: when True, nested dicts (also inside lists, or inside
        arrays of dicts) are converted to DotDicts; such arrays become lists
    """
    for k, v in d.items():

      if recursive and isinstance(v, dict):
        self[k] = DotDict().from_dict(v, recursive=True)

      elif recursive and self._is_item_sequence(v):
        self[k] = [
            DotDict().from_dict(it, recursive=True) if isinstance(it, dict) else it
            for it in v
        ]

      else:
        self[k] = v

    return self

  def to_dict(self, recursive=False):
    """ Return the content as a plain dict.

    Args:
      recursive: when True, nested dicts/DotDicts (also inside lists, or
        inside arrays of dicts) are converted to plain dicts as well
    """
    d = dict()
    for k, v in self.items():

      if recursive and isinstance(v, dict):
        # Call through the class so plain dicts can be converted as well.
        d[k] = DotDict.to_dict(v, recursive=True)

      elif recursive and DotDict._is_item_sequence(v):
        d[k] = [
            DotDict.to_dict(it, recursive=True) if isinstance(it, dict) else it
            for it in v
        ]

      else:
        d[k] = v

    return d


  def __getattr__(self, k):
    # Only reached when normal attribute lookup fails, so treat k as a key.
    # Missing keys must surface as AttributeError so that hasattr() and
    # getattr(obj, name, default) behave as expected.
    try:
      return self[k]
    except KeyError:
      raise AttributeError(k) from None


  def __setattr__(self, k, v):
    # Every attribute assignment is stored as a dictionary entry.
    self[k] = v

In [73]:
class Cudize:
  """ To save GPU RAM, temporarily move models to gpu,
  and enable torch.inference_mode().

  On exit the models are moved back to the CPU and inference mode is left,
  even if moving a model fails.
  """

  def __init__(self, *args):
    # Anything exposing a `.to(device)` method can be managed.
    self.models = list(args)
    self.inference_manager = None

  def __enter__(self):
    try:
      for m in self.models:
        m.to(DEVICE)
    except BaseException:
      # __exit__ is not called when __enter__ raises, so undo the partial
      # move here to avoid leaving models on the GPU.
      for m in self.models:
        m.to("cpu")
      raise

    self.inference_manager = torch.inference_mode(True)
    self.inference_manager.__enter__()
    return self

  def __exit__(self, *args):
    try:
      for m in self.models:
        m.to("cpu")
    finally:
      # Always leave inference mode, even if a model failed to move back.
      self.inference_manager.__exit__(*args)


In [74]:
class Silence:
  """ Silence tqdm progress bars, including inside of packages.

  While active, `tqdm.__init__` is replaced by a variant that forces
  `disable=True`; the previous initializer is restored on exit.
  """

  def __init__(self):
    pass

  def __enter__(self):
    self.tmp__init__ = tqdm.__init__
    tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    tqdm.__init__ = self.tmp__init__
    # Returning False lets an in-flight exception propagate unchanged, with
    # its original type, arguments and traceback. Re-raising
    # exc_type(exc_val, exc_tb) would build a new, malformed exception.
    return False

In [75]:
def slerp(val, low, high):
    """ Batched spherical interpolation between high and low.
    val in [0, 1], 0=low, 1=high.

    Args:
      val: interpolation weight
      low: tensor with the batch on dim 0
      high: tensor with the same shape as low

    Returns
      tensor with the same shape as low

    When a pair of vectors is (anti)parallel the spherical weights are
    undefined (division by sin(omega) == 0), so that pair falls back to
    linear interpolation instead of producing NaNs.
    """

    assert low.shape == high.shape
    og_shape = low.shape

    low = low.reshape(low.shape[0], -1)
    high = high.reshape(high.shape[0], -1)

    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)

    # Rounding can push the cosine slightly outside [-1, 1], where acos is NaN.
    cos_omega = (low_norm*high_norm).sum(1).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)

    degenerate = so.abs() < 1e-6
    safe_so = torch.where(degenerate, torch.ones_like(so), so)

    w_low = torch.sin((1.0-val)*omega)/safe_so
    w_high = torch.sin(val*omega)/safe_so

    # Linear weights for the degenerate pairs.
    w_low = torch.where(degenerate, torch.ones_like(so)*(1.0-val), w_low)
    w_high = torch.where(degenerate, torch.ones_like(so)*val, w_high)

    res = w_low.unsqueeze(1)*low + w_high.unsqueeze(1)*high

    return res.reshape(og_shape)


def slap(theta, x, y):
  """ Rotate through the plane spanned by x and y using quarter-turn slerps.

  The angle is wrapped into [0, 2*pi) and the quadrant it lands in selects
  which pair of axes to interpolate between, following the cycle
  x -> y -> -x -> -y -> x.

  theta in radians
  """
  angle = theta % (2*np.pi)
  quarter = np.pi/2

  if angle < quarter:
    return slerp(angle / quarter, x, y)

  if angle < np.pi:
    start, src, dst = quarter, y, -x
  elif angle < 3*np.pi/2:
    start, src, dst = np.pi, -x, -y
  else:
    start, src, dst = 3*np.pi/2, -y, x

  # Fraction of the current quarter turn already covered.
  frac = (angle - start) / quarter
  return slerp(frac, src, dst)


def slink(thetas, basis):
  """ Chain slap rotations across a basis of any size.

  Starts from the first basis vector and rotates the running result
  towards each following basis vector by the matching angle, so a
  k-vector basis takes k-1 angles.

  thetas in radians
  """
  assert len(thetas) == len(basis)-1

  rotated = basis[0][None]
  for angle, axis in zip(thetas, basis[1:]):
    rotated = slap(angle, rotated, axis[None])

  return rotated

# Scoring

In [76]:
class Scorer:
  """ Wrapper for various scoring functions.

  Every scoring method takes (prompt, images) and returns a list with one
  score per image.
  """

  def __init__(self, reward_model, clip_processor, clip_model):
    self.reward_model = reward_model
    self.clip_processor = clip_processor
    self.clip_model = clip_model

    # Not used by the scoring methods below; kept so the attribute stays available.
    self.buffer = BytesIO()

    # Normalized text embedding of the most recently scored prompt.
    self.clip_text_cache = None
    self.clip_prompt_cache = None


  def reward_score(self, prompt, images):
    # ImageReward
    out = self.reward_model.inference_rank(prompt, images)[1]
    # Wrap a bare score so the return value is always a list.
    if isinstance(out, list):
      return out
    return [out]


  def clip_score(self, prompt, images):
    # CLIP Score: cosine similarity between the prompt and image embeddings.

    if self.clip_text_cache is None or prompt != self.clip_prompt_cache:

      # NOTE(review): no truncation is requested here -- confirm that prompts
      # never exceed the text encoder's maximum length.
      inputs = self.clip_processor(text=[prompt], return_tensors="pt").to(self.clip_model.device)
      # Use the model given to this instance, not the notebook-level global.
      text_emb = self.clip_model.get_text_features(**inputs)

      text_emb /= torch.norm(text_emb, dim=-1, keepdim=True)

      self.clip_text_cache = text_emb
      self.clip_prompt_cache = prompt

    else:

      text_emb = self.clip_text_cache

    inputs = self.clip_processor(images=images, return_tensors="pt").to(self.clip_model.device)
    img_emb = self.clip_model.get_image_features(**inputs)

    img_emb /= torch.norm(img_emb, dim=-1, keepdim=True)

    return torch.sum(text_emb * img_emb, dim=-1).detach().cpu().numpy().tolist()


  def jpeg_score(self, prompt, images):
    # JPEG compression ratio: raw RGB byte count divided by the encoded size.
    # `prompt` is unused; it is accepted so all scorers share one signature.

    out = []
    for im in images:

      buffer = BytesIO()
      im.save(buffer, "JPEG")
      b = len(buffer.getvalue())

      out.append(im.width * im.height * 3 / b)

    return out


# Shared scorer instance built from the models loaded above.
scorer = Scorer(reward_model, clip_processor, clip_model)

# Model Inference

In [77]:
def get_prompt_embeds(prompt, pipeline, batch_size=1, use_guidance=False):
  """ Encode the prompt once so repeated generations can reuse the result.

  Args:
    prompt: image prompt
    pipeline: diffusers pipeline; its unet's device is used for encoding
    batch_size: number of images per prompt
    use_guidance: forwarded as do_classifier_free_guidance

  Returns
    whatever pipeline.encode_prompt returns
  """
  encode_kwargs = dict(
      prompt=prompt,
      device=pipeline.unet.device,
      num_images_per_prompt=batch_size,
      do_classifier_free_guidance=use_guidance,
  )
  return pipeline.encode_prompt(**encode_kwargs)


def generate_images_turbo(prompt, pipeline, num_inference_steps, latents=None, prompt_embeds=None, batch_size=1, output_type="pil"):
  """ Run the pipeline with guidance disabled and progress bars silenced.

  Args:
    prompt: image prompt, forwarded as-is
    pipeline: diffusers pipeline
    num_inference_steps: steps per inference rollout
    latents: optional starting latents
    prompt_embeds: optional sequence of precomputed embeddings. With two
      entries they are passed as (prompt_embeds, negative_prompt_embeds);
      otherwise entries 2 and 3 are also passed as the pooled variants.
    batch_size: number of images per prompt
    output_type: forwarded to the pipeline

  Returns
    the `.images` attribute of the pipeline output
  """

  with Silence():
    call_kwargs = dict(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=batch_size,
        latents=latents,
        output_type=output_type,
        guidance_scale=0.0,
    )

    if prompt_embeds is not None:
      has_pooled = len(prompt_embeds) != 2

      call_kwargs["prompt_embeds"] = prompt_embeds[0]
      call_kwargs["negative_prompt_embeds"] = prompt_embeds[1]

      if has_pooled:
        call_kwargs["pooled_prompt_embeds"] = prompt_embeds[2]
        call_kwargs["negative_pooled_prompt_embeds"] = prompt_embeds[3]

    return pipeline(**call_kwargs).images

# Search Algorithms

In [78]:
def gradient_ascent(
    prompt,
    pipeline,
    score_fn,
    num_steps,
    num_inference_steps,
    step_size
):
  """ Gradient ascent/descent with a fixed m.
  Becomes Random Sampling with step_size=1.

  Each step slerps the current latents towards fresh noise by `step_size`
  and keeps the candidate only if its score is at least the current score.

  Args:
    prompt: image prompt
    pipeline: diffusers pipeline
    score_fn: (prompt, images) -> list(scores)
    num_steps: number of steps to search
    num_inference_steps: steps per inference rollout
    step_size: interpolation value m

  Returns
    scores: array of scores per step
    images: image for each step
    acc_steps: steps where the new state was accepted
    rej_steps: steps where the new state was rejected
  """

  prompt_embeds = get_prompt_embeds(prompt, pipeline, batch_size=1)

  curr_latents = torch.randn(LATENT_SHAPE(1), device=DEVICE, dtype=DTYPE)
  curr_img = generate_images_turbo(None, pipeline, num_inference_steps, latents=curr_latents, prompt_embeds=prompt_embeds)
  curr_score = score_fn(prompt, curr_img)[0]

  scores = [curr_score]
  images = [np.array(curr_img[0])]
  # Step 0 is the initial state: it counts as accepted and is never rejected.
  acc_steps = [0]
  rej_steps = []

  for step in (pbar:=tqdm(range(1, 1+num_steps))):

    basis = torch.randn(LATENT_SHAPE(1), device=DEVICE, dtype=DTYPE)
    new_latents = slerp(step_size, curr_latents, basis)

    new_img = generate_images_turbo(None, pipeline, num_inference_steps, latents=new_latents, prompt_embeds=prompt_embeds)
    new_score = score_fn(prompt, new_img)[0]

    if new_score >= curr_score:
      curr_latents = new_latents
      curr_score = new_score

      acc_steps.append(step)
    else:
      rej_steps.append(step)

    # Every candidate is recorded, whether or not it was accepted.
    scores.append(new_score)
    images.append(np.array(new_img[0]))

    pbar.set_postfix(curr_score=f"{curr_score:.3f}", new_score=f"{new_score:.3f}")

  # Explicit integer dtype keeps the step arrays usable as indices when empty.
  return DotDict(
    scores=np.array(scores),
    images=np.stack(images),
    acc_steps=np.array(acc_steps, dtype=int),
    rej_steps=np.array(rej_steps, dtype=int)
  )

In [79]:
def simulated_annealing(
    prompt,
    pipeline,
    score_fn,
    num_steps,
    num_inference_steps,
    mutation_start,
    mutation_decay,
    temp_start,
    temp_decay,
):
  """ Simulated Annealing algorithm.
  Becomes Stochastic Hill Climbing with temp_start -> 0.

  Each step slerps the current latents towards fresh noise by the current
  mutation size. A candidate that scores at least as well is always
  accepted; a worse one is accepted with probability
  exp((new_score - curr_score) / temp). Mutation size and temperature are
  multiplied by their decay rates after every step.

  Args:
    prompt: image prompt
    pipeline: diffusers pipeline
    score_fn: (prompt, images) -> list(scores)
    num_steps: number of steps to search
    num_inference_steps: steps per inference rollout
    mutation_start: initial m
    mutation_decay: m decay rate
    temp_start: initial temperature
    temp_decay: temperature decay rate.

  Returns
    scores: array of scores per step
    images: image for each step
    acc_steps: steps where the new state was accepted
    rej_steps: steps where the new state was rejected
  """

  prompt_embeds = get_prompt_embeds(prompt, pipeline, batch_size=1)

  curr_latents = torch.randn(LATENT_SHAPE(1), device=DEVICE, dtype=DTYPE)
  curr_img = generate_images_turbo(None, pipeline, num_inference_steps, latents=curr_latents, prompt_embeds=prompt_embeds)
  curr_score = score_fn(prompt, curr_img)[0]

  mut = mutation_start
  temp = temp_start

  scores = [curr_score]
  images = [np.array(curr_img[0])]
  # Step 0 is the initial state: it counts as accepted and is never rejected.
  acc_steps = [0]
  rej_steps = []

  for step in (pbar:=tqdm(range(1, num_steps+1))):

    basis = torch.randn(LATENT_SHAPE(1), device=DEVICE, dtype=DTYPE)
    new_latents = slerp(mut, curr_latents, basis)

    new_img = generate_images_turbo(None, pipeline, num_inference_steps, latents=new_latents, prompt_embeds=prompt_embeds)
    new_score = score_fn(prompt, new_img)[0]

    accept = new_score >= curr_score
    # A non-positive temperature means pure hill climbing: worse candidates
    # are never accepted, and the division by zero is avoided.
    if not accept and temp > 0:
      accept = np.random.rand() < np.exp((new_score - curr_score)/temp)

    if accept:
      curr_latents = new_latents
      curr_score = new_score

      acc_steps.append(step)
    else:
      rej_steps.append(step)

    # Every candidate is recorded, whether or not it was accepted.
    scores.append(new_score)
    images.append(np.array(new_img[0]))

    mut *= mutation_decay
    temp *= temp_decay

    pbar.set_postfix(curr_score=f"{curr_score:.3f}", new_score=f"{new_score:.3f}")

  # Explicit integer dtype keeps the step arrays usable as indices when empty.
  return DotDict(
    scores=np.array(scores),
    images=np.stack(images),
    acc_steps=np.array(acc_steps, dtype=int),
    rej_steps=np.array(rej_steps, dtype=int)
  )

# Configs

In [80]:
# Search budget shared by every multi-step configuration.
BASE_CONFIG = DotDict(num_steps=50, num_inference_steps=2)


def _score_variant(reward_config, score_fn, overrides=None):
  """ Shallow-copy `reward_config`, swap in `score_fn`, then apply `overrides`. """
  variant = reward_config.copy()
  variant["score_fn"] = score_fn
  for key, value in (overrides or {}).items():
    variant[key] = value
  return variant


def _algorithm_config(algorithm, reward_config, clip_overrides=None, jpeg_overrides=None):
  """ Bundle an algorithm with its REWARD, CLIP and JPEG keyword arguments. """
  entry = DotDict(algorithm=algorithm)
  entry["REWARD"] = reward_config
  entry["CLIP"] = _score_variant(reward_config, scorer.clip_score, clip_overrides)
  entry["JPEG"] = _score_variant(reward_config, scorer.jpeg_score, jpeg_overrides)
  return entry


TEST_CONFIG = DotDict()

# One-Shot: zero search steps, so only the initial sample is scored.
TEST_CONFIG["ONE_SHOT"] = _algorithm_config(
    gradient_ascent,
    DotDict(
        score_fn=scorer.reward_score,
        num_steps=0,
        num_inference_steps=BASE_CONFIG["num_inference_steps"],
        step_size=1.0,
    ),
)

# Random: a step size of 1.0 replaces the latents with fresh noise every step.
TEST_CONFIG["RANDOM"] = _algorithm_config(
    gradient_ascent,
    DotDict(
        score_fn=scorer.reward_score,
        num_steps=BASE_CONFIG["num_steps"],
        num_inference_steps=BASE_CONFIG["num_inference_steps"],
        step_size=1.0,
    ),
)

# Gradient Descent: small fixed step size.
TEST_CONFIG["GRADIENT_DESCENT"] = _algorithm_config(
    gradient_ascent,
    DotDict(
        score_fn=scorer.reward_score,
        num_steps=BASE_CONFIG["num_steps"],
        num_inference_steps=BASE_CONFIG["num_inference_steps"],
        step_size=0.01,
    ),
)

# Simulated Annealing: the starting temperature differs per score function.
TEST_CONFIG["SIMULATED_ANNEALING"] = _algorithm_config(
    simulated_annealing,
    DotDict(
        score_fn=scorer.reward_score,
        num_steps=BASE_CONFIG["num_steps"],
        num_inference_steps=BASE_CONFIG["num_inference_steps"],
        mutation_start=1.0,
        mutation_decay=0.92,
        temp_start=0.25,
        temp_decay=0.92,
    ),
    clip_overrides={"temp_start": 0.025},
    jpeg_overrides={"temp_start": 1.0},
)

# Greedy Annealing: same algorithm with a near-zero starting temperature.
TEST_CONFIG["GREEDY"] = _algorithm_config(
    simulated_annealing,
    DotDict(
        score_fn=scorer.reward_score,
        num_steps=BASE_CONFIG["num_steps"],
        num_inference_steps=BASE_CONFIG["num_inference_steps"],
        mutation_start=1.0,
        mutation_decay=0.92,
        temp_start=0.0001,
        temp_decay=0.92,
    ),
)

# Extra Info: the model each score function needs on the GPU. The JPEG score
# uses no model, so a tiny placeholder module stands in for it.
NAME_TO_SCORE_MODEL = dict(
  REWARD=reward_model,
  CLIP=clip_model,
  JPEG=nn.Linear(1, 1),
)

# Testing

In [None]:
# Load the benchmark prompts; each entry is read through its "id" and "prompt" keys.
with open(os.path.join(BASE_PATH, "benchmark-prompts.json"), "r", encoding="utf-8") as f:
  PROMPTS = json.load(f)

RESULT_PATH = os.path.join(BASE_PATH, "results")
os.makedirs(RESULT_PATH, exist_ok=True)


# Generate predictions for every benchmark prompt and save one .npy file
# per (algorithm, score) pair.
for alg_name, alg_config in TEST_CONFIG.items():
  # NOTE(review): this filter and the one below restrict the run to
  # GREEDY with CLIP/JPEG -- presumably left over from a partial re-run.
  if alg_name in ["ONE_SHOT", "RANDOM", "GRADIENT_DESCENT", "SIMULATED_ANNEALING"]:
    continue

  for r_name, r_config in alg_config.items():
    if r_name in ["algorithm", "REWARD"]:
      continue

    save_path = os.path.join(RESULT_PATH, f"{alg_name}-{r_name}.npy")
    results = []
    final_scores = []

    with Cudize(turbo_pipeline, NAME_TO_SCORE_MODEL[r_name]):
      for prompt_dict in (pbar:=tqdm(PROMPTS, desc=f"{alg_name}-{r_name}")):
        # Named prompt_id so the builtin `id` is not rebound at notebook scope.
        prompt_id = prompt_dict["id"]
        prompt = prompt_dict["prompt"]

        # Re-seed for every prompt.
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        np.random.seed(0)

        with Silence():
          out = alg_config.algorithm(
            prompt,
            turbo_pipeline,
            **r_config
          )
        # Drop the images before storing; only scores and step indices are saved.
        out.pop("images")

        results.append({
            "id": prompt_id,
            "prompt": prompt,
            "result": out.to_dict(recursive=True)
        })
        final_scores.append(np.max(out.scores))

        pbar.set_postfix(
            median=np.median(final_scores),
            mean=np.mean(final_scores),
            std=np.std(final_scores)
        )

      np.save(save_path, results)


In [None]:
def load_database():
  """ Load the result files written to RESULT_PATH by the benchmark run.

  Only files named "<ALG>-<REWARD>.npy" are read; anything else in the
  folder is ignored.

  Returns
    DotDict indexed as database[reward][alg] -> list of per-prompt results
  """

  database = {}

  for f in os.listdir(RESULT_PATH):
    stem, ext = os.path.splitext(f)
    if ext != ".npy":
      continue

    alg, sep, reward = stem.partition("-")
    if not (sep and alg and reward) or "-" in reward:
      continue

    if reward not in database:
      database[reward] = {}
    # NOTE: allow_pickle=True unpickles file content; only load result files
    # produced by this notebook.
    database[reward][alg] = np.load(os.path.join(RESULT_PATH, f), allow_pickle=True)

  return DotDict().from_dict(database, recursive=True)

# Results of the benchmark run, indexed as database[reward][alg].
database = load_database()

In [None]:
# Line color per algorithm.
COLORS = {
    "RANDOM": "k",
    "SIMULATED_ANNEALING": "b",
    "GREEDY": "r",
    "ONE_SHOT": "y",
}

# Y-axis label per score.
AXIS_NAMES = {
    "CLIP": "CLIP Score",
    "JPEG": "JPEG Compression Ratio",
    "REWARD": "ImageReward Score",
}

# Legend label per algorithm.
ALG_NAMES = {
    "RANDOM": "Random",
    "SIMULATED_ANNEALING": "SA",
    "GREEDY": "SHC",
    # Was "y", the color code copied over from COLORS.
    "ONE_SHOT": "One-Shot",
}

# Plot title prefix per score.
TITLE_NAMES = {
    "CLIP": "CLIP",
    "JPEG": "JPEG Compression Ratio",
    "REWARD": "ImageReward",
}

def visualize_reward_curves(database, reward, ax):
  """ Visualize the score curves for each algorithm,
  also print the mean performance.

  For every algorithm the best score found so far is tracked per prompt,
  then averaged over prompts and plotted against the search step.

  Args:
    database: results indexed as database[reward][alg]
    reward: score name, one of the keys of AXIS_NAMES
    ax: matplotlib axes to draw on
  """

  print(f"\n{reward}")
  for alg, alg_data in database[reward].items():
    if alg in ["GRADIENT_DESCENT"]:
      continue

    # np.stack needs a real sequence; passing a generator is not supported.
    scores = np.stack([r.result.scores for r in alg_data])
    # Running maximum along the step axis: best score found up to each step.
    infer_scores = np.maximum.accumulate(scores, axis=1)

    print(f"\t{alg}: {np.mean(infer_scores[:,-1]):.4f}")

    ax.plot(
        range(scores.shape[1]),
        np.mean(infer_scores, axis=0),
        color=COLORS[alg],
        label=ALG_NAMES[alg]
    )
    # plt.plot(
    #     range(len(scores[0])),
    #     np.median(infer_scores, axis=0),
    #     ":",
    #     color=COLORS[alg],
    #     label=alg
    # )
    # plt.fill_between(
    #     range(len(scores[0])),
    #     np.mean(infer_scores, axis=0) - np.std(infer_scores, axis=0)/10,
    #     np.mean(infer_scores, axis=0) + np.std(infer_scores, axis=0)/10,
    #     alpha=0.1,
    #     color=COLORS[alg],
    # )

  print("")

  ax.legend()
  ax.set_xlabel("Search Step")
  ax.set_ylabel(AXIS_NAMES[reward])
  ax.set_title(f"{TITLE_NAMES[reward]} Score")


# One panel per score; the fourth panel of the 2x2 grid stays empty.
fig, ax = plt.subplots(2, 2, figsize=(8, 8))

for reward_name, panel in (("REWARD", ax[0, 0]), ("CLIP", ax[0, 1]), ("JPEG", ax[1, 0])):
  visualize_reward_curves(database, reward_name, panel)

ax[1, 1].axis("off")

plt.tight_layout()
plt.show()


# Visualization

In [None]:
def get_grid(prompt, pipeline, grid_size, batch_size):
  """
  Generate a grid of 2d interpolated images.

  Two pairs of random latents are drawn. For each column, both pairs are
  slerped by col / grid_size, and the rows then slerp between the two
  resulting latents from 0 to 1.

  Args:
    prompt: image prompt
    pipeline: diffusers pipeline, moved to DEVICE for the duration of the call
    grid_size: number of rows and columns
    batch_size: images generated per pipeline call; must divide grid_size

  Returns
    array of shape (grid_size, grid_size, 512, 512, 3)
  """

  # Validate the arguments themselves, not the notebook-level globals.
  assert grid_size % batch_size == 0

  # torch.manual_seed(0)
  # torch.cuda.manual_seed(0)
  # np.random.seed(0)

  x = torch.randn(LATENT_SHAPE(2), device=DEVICE, dtype=DTYPE)
  y = torch.randn(LATENT_SHAPE(2), device=DEVICE, dtype=DTYPE)

  grid = np.zeros((grid_size, grid_size, 512, 512, 3))
  row_weights = np.linspace(0, 1, grid_size)

  with Cudize(pipeline):

    prompt_embeds = get_prompt_embeds(prompt, pipeline, batch_size=batch_size)

    for col in tqdm(range(grid_size)):

      # The column end points are the same for every batch of rows.
      theta = col / grid_size
      l1 = slerp(theta, x[:1], x[1:])
      l2 = slerp(theta, y[:1], y[1:])

      for batch in range(grid_size // batch_size):
        first_row = batch * batch_size

        latents = torch.cat(
            [slerp(lamb, l1, l2) for lamb in row_weights[first_row:first_row+batch_size]],
            dim=0
        )

        with Silence():
          images = generate_images_turbo(
              None,
              pipeline,
              2,
              prompt_embeds=prompt_embeds,
              latents=latents,
              batch_size=batch_size,
              output_type="np"
          )

        for row in range(batch_size):
          grid[first_row+row, col] = images[row]

  return grid


# Prompt used for the interpolation grid.
PROMPT = "a concept art of a vehicle, cyberpunk"

# Grid side length and number of images generated per pipeline call.
GRID_SIZE = 33
BATCH_SIZE = 1

# The last three rows and columns are dropped, leaving a 30x30 grid of images.
grid = get_grid(PROMPT, turbo_pipeline, GRID_SIZE, BATCH_SIZE)[:-3, :-3]


In [None]:
def get_score_grid(prompt, grid, score_fn, score_model):
  """
  Score every image of a grid.

  Args:
    prompt: image prompt handed to score_fn
    grid: array of images indexed as grid[row, col]; pixel values are
      scaled by 255 before conversion to uint8
    score_fn: (prompt, images) -> list(scores)
    score_model: model moved to DEVICE while scoring

  Returns
    array of scores with the grid's first two dimensions
  """

  n_rows, n_cols = grid.shape[:2]
  s_grid = np.zeros((n_rows, n_cols))

  with Cudize(score_model):
    for i in tqdm(range(n_rows)):
      for j, pixels in enumerate(grid[i]):
        as_bytes = (pixels * 255).astype(np.uint8)
        img = Image.fromarray(as_bytes)
        s_grid[i, j] = score_fn(prompt, [img])[0]

  return s_grid

# Sanity check: show the first image of the grid.
plt.imshow(grid[0,0])
plt.show()

reward_grid = get_score_grid(PROMPT, grid, scorer.reward_score, reward_model)
plt.matshow(reward_grid)
plt.show()

clip_grid = get_score_grid(PROMPT, grid, scorer.clip_score, clip_model)
plt.matshow(clip_grid)
plt.show()

# The JPEG score needs no model, so pass the placeholder module used
# elsewhere instead of moving the unrelated reward model to the GPU.
jpeg_grid = get_score_grid(PROMPT, grid, scorer.jpeg_score, NAME_TO_SCORE_MODEL["JPEG"])
plt.matshow(jpeg_grid)
plt.show()

In [None]:
# Visualize select images from a grid.

fig, ax = plt.subplots(4, 4, figsize=(24, 24))

# Every 8th row and column of the grid.
g_plot = grid[::8, ::8]
for i in range(4):
  for j in range(4):
    ax[i, j].imshow(g_plot[i, j])
    ax[i, j].axis("off")

# plt.title would only title the last subplot; this is a figure-level title.
plt.suptitle("2 Dimensional Image Interpolation")
plt.tight_layout()
plt.show()


In [None]:
# Show score grids as a single image.

fig, ax = plt.subplots(2, 2, figsize=(6, 6))

# Each heatmap is drawn once and the resulting image is reused for its colorbar.
reward_im = ax[0,0].matshow(reward_grid)
plt.colorbar(reward_im, ax=ax[0,0])
ax[0,0].set_title("ImageReward Score")

# Titles now match the plotted data; the CLIP and JPEG labels were swapped.
jpeg_im = ax[1,0].matshow(jpeg_grid)
plt.colorbar(jpeg_im, ax=ax[1,0])
ax[1,0].set_title("JPEG Compression Score")

clip_im = ax[0,1].matshow(clip_grid)
plt.colorbar(clip_im, ax=ax[0,1])
ax[0,1].set_title("CLIP Score")

ax[1,1].axis("off")

plt.tight_layout()
plt.show()


In [None]:
""" Get a before/after search is applied to the prompt.
Also visualize the accepted/rejected graph.
"""

PROMPT = "delicious plate of food"
# Text appended to the prompt after a comma.
# NOTE(review): with an empty suffix the prompt ends in ", " -- confirm this is intended.
SUFFIX = ""

SCORE_FUNCTION = scorer.reward_score

NUM_STEPS = 50
NUM_INFERENCE_STEPS = 2

# Only referenced by the commented-out argument in the call below.
STEP_SIZE = 1.0

# Small mutation size that never decays.
MUTATION_START = 0.01
MUTATION_DECAY = 1.0

START_TEMP = 0.000001
TEMP_DECAY = 0.92

# Seed every RNG in use so the rollout is repeatable.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)

with Cudize(turbo_pipeline, reward_model):
  out = simulated_annealing(
    f"{PROMPT}, {SUFFIX}",
    turbo_pipeline,
    SCORE_FUNCTION,
    NUM_STEPS,
    NUM_INFERENCE_STEPS,
    MUTATION_START,
    MUTATION_DECAY,
    START_TEMP,
    TEMP_DECAY,
    # STEP_SIZE
  )

# Score of every candidate, marked by whether it was accepted or rejected.
# NOTE(review): the title below says "Random Sampling" although the call above
# runs simulated_annealing -- presumably left over from an earlier run.
plt.scatter(out.rej_steps, out.scores[out.rej_steps], color="r", marker="x", label="rejected")
plt.scatter(out.acc_steps, out.scores[out.acc_steps], color="b", label="accepted")
plt.plot(out.acc_steps, out.scores[out.acc_steps], color="b")
plt.legend()
plt.title(f"Random Sampling Rollout")
plt.xlabel("Step")
plt.ylabel("ImageReward Score")
plt.show()

# Side by side: the initial image and the best-scoring image of the rollout.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

ax[0].imshow(out.images[0])
ax[0].axis("off")
ax[0].set_title(f"Initial: Score={out.scores[0]:.4f}")

ax[1].imshow(out.images[np.argmax(out.scores)])
ax[1].axis("off")
ax[1].set_title(f"Final: Score={out.scores.max():.4f}")

plt.suptitle("SHC with Small Fixed m")
plt.show()


# Examples

In [None]:
# Generate large visualizations
# of images from different scores/algorithms/prompts/models.
# One figure per score: a row per prompt, with the prompt text, the best image
# found by each algorithm, and a Stable Diffusion 2.1 sample as columns.

# from benchmark
PROMPTS = [
    "a concept art of a vehicle, cyberpunk",
    "a beautiful portrait of a beautiful woman in the jungle surrounded by pink flowers, shamanism, matte painting, fantasy art",
    "a cute cat",
    "close up photo of anthropomorphic fox animal dressed in white shirt and khaki cargo pants, fox animal, glasses",
    "Portrait of an old sea captain, male, detailed face, fantasy, highly detailed, cinematic, art painting by greg rutkowski",
    "classic model of atoms, made out of glass marbles and chrome steel rods, studio",
    "cartoon character in a style of adventure time cartoon",
    "an alien planet viewed from space, extremely, beautiful, dynamic, creative, cinematic",
]

# Column headers, in the order the algorithms are run.
SHOW_NAMES = [
    "One-Shot",
    "Random Sampling",
    "Simulated Annealing",
    "Stochastic Hill Climbing",
]

R_NAMES = {
    "REWARD": "ImageReward",
    "CLIP": "CLIP",
    "JPEG": "JPEG Compression Ratio",
}

# Every configured algorithm except GRADIENT_DESCENT, in TEST_CONFIG order.
to_run = [
    (name, config) for name, config in TEST_CONFIG.items()
    if name != "GRADIENT_DESCENT"
]

for r_name in TEST_CONFIG.ONE_SHOT:
  if r_name == "algorithm":
    continue

  n_rows = len(PROMPTS)
  fig, ax = plt.subplots(n_rows, 6, figsize=(4*6, 0.35+4*n_rows))

  for i, prompt in enumerate(tqdm(PROMPTS)):

    # First column: the prompt itself, as wrapped text.
    prompt_cell = ax[i, 0]
    text = prompt_cell.text(
        0.5, 0.5, prompt,
        horizontalalignment='center',
        verticalalignment='center',
        size='x-large',
        wrap=True,
        transform=prompt_cell.transAxes,
    )
    text._get_wrap_line_width = lambda : 250.
    prompt_cell.axis("off")

    # Middle columns: best image found by each search algorithm.
    for j, (alg_name, alg_config) in enumerate(tqdm(to_run, leave=False)):

      with Cudize(turbo_pipeline, NAME_TO_SCORE_MODEL[r_name]):
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        np.random.seed(0)
        out = alg_config.algorithm(
          prompt,
          turbo_pipeline,
          **alg_config[r_name]
        )

      im = out.images[out.scores.argmax()]

      alg_cell = ax[i, j+1]
      alg_cell.set_title(
          f"{SHOW_NAMES[j]}\n\nScore={out.scores.max():.4f}" if i == 0
          else f"Score={out.scores.max():.4f}"
      )
      alg_cell.imshow(im)
      alg_cell.axis("off")

    # Last column: a Stable Diffusion 2.1 sample scored with the same function.
    with Cudize(sd_pipeline, NAME_TO_SCORE_MODEL[r_name]):
      with Silence():
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)
        np.random.seed(0)
        im = sd_pipeline(
          prompt=prompt
        ).images
        score = TEST_CONFIG.ONE_SHOT[r_name].score_fn(prompt, im)[0]

    sd_cell = ax[i, -1]
    sd_cell.set_title(
        f"Stable Diffusion 2.1\n\nScore={score:.4f}" if i == 0
        else f"Score={score:.4f}"
    )
    sd_cell.imshow(im[0])
    sd_cell.axis("off")

  plt.suptitle(f"{R_NAMES[r_name]} Optimized Examples", fontsize="xx-large")
  # plt.tight_layout()
  plt.show()
  plt.clf()


# Junk

In [None]:

# Scratch experiment: batched search where each step samples BATCH_SIZE
# candidates around the current latents.
PROMPT = "delicious plate of food"

NUM_EPOCHS = 10
STEPS_PER_EPOCH = 5
BATCH_SIZE = 4

MUTATION_START = 1 # 1/2
MUTATION_DECAY = 1 # 0.95

START_TEMP = 1/2
TEMP_DECAY = 0.93

# fig, ax = plt.subplots(1, NUM_EPOCHS, figsize=(3*NUM_EPOCHS, 4))
scores = []
sampled_scores = []

# Cudize moves both models to DEVICE and enables inference mode, like the
# other cells; a bare torch.inference_mode() left the models on the CPU.
with Cudize(turbo_pipeline, reward_model):

  latents = torch.randn(LATENT_SHAPE(1), device=DEVICE, dtype=DTYPE)
  image = turbo_pipeline(
    prompt=PROMPT,
    num_inference_steps=4,
    guidance_scale=0.0,
    num_images_per_prompt=1,
    latents=latents
  ).images[0]
  score = reward_model.inference_rank(PROMPT, [image])[1]

  t = 0
  # ax[0].imshow(image)
  # ax[0].axis("off")
  # ax[0].set_title(f"Step {t}: r={score:.2f}")
  # scores.append(score)

  temp = START_TEMP
  mut = MUTATION_START

  for epoch in (pbar:=tqdm(range(1, NUM_EPOCHS))):
    for step in range(STEPS_PER_EPOCH):

      basis = torch.randn(LATENT_SHAPE(BATCH_SIZE), device=DEVICE, dtype=DTYPE)

      test_latents = slerp(mut, torch.cat([latents]*BATCH_SIZE,dim=0), basis)
      # `pipeline` is not defined at notebook scope; use the turbo pipeline,
      # as the initial generation above does.
      test_images = turbo_pipeline(
        prompt=PROMPT,
        num_inference_steps=1,
        guidance_scale=0.0,
        num_images_per_prompt=BATCH_SIZE,
        latents=test_latents
      ).images
      test_scores = reward_model.inference_rank(PROMPT, test_images)[1]

      best = np.argmax(test_scores)
      if test_scores[best] > score:
        # Greedy move: the best candidate beats the current state.
        latents = test_latents[best][None]
        image = test_images[best]
        score = test_scores[best]

      else:
        # Otherwise sample among the current state (index 0) and the candidates,
        # weighted by score / temperature.
        logits = torch.tensor([score]+test_scores) / temp
        keeper = torch.distributions.Categorical(logits=logits).sample()

        if keeper > 0:
          latents = test_latents[keeper-1][None]
          image = test_images[keeper-1]
          score = test_scores[keeper-1]

      pbar.set_postfix(score=f"{score:.2f}")
      scores.append(score)
      sampled_scores.append(test_scores)

      temp *= TEMP_DECAY
      mut *= MUTATION_DECAY
      t += 1

    # ax[epoch].imshow(image)
    # ax[epoch].axis("off")
    # ax[epoch].set_title(f"Step {t}: r={score:.2f}")

plt.imshow(image)
plt.axis("off")
plt.title("Stable DIffusion Turbo")
plt.show()

# plt.suptitle(f'Image Progress For prompt "{PROMPT}"')
# plt.tight_layout()
# plt.show()
# plt.clf()

# sampled_scores = np.array(sampled_scores)
# plt.scatter(
#     1+np.arange(len(sampled_scores)).repeat(sampled_scores.shape[1]),
#     sampled_scores.reshape(-1)
# )
# plt.xlabel("Sampling Step")
# plt.ylabel("Rewards")
# plt.title(f'Rewards For prompt "{PROMPT}"')
# plt.plot(scores)
# plt.show()


In [None]:

# Scratch experiment: one-shot ImageReward score for every benchmark prompt.
# The batched search itself is commented out below.
NUM_EPOCHS = 5
STEPS_PER_EPOCH = 5
BATCH_SIZE = 4

MUTATION_START = 1 # 1/2
MUTATION_DECAY = 1 # 0.95**2

START_TEMP = 1/2
TEMP_DECAY = 0.93**2

# NOTE(review): this path is relative to the working directory, while the
# benchmark cell reads the file from BASE_PATH -- confirm which is intended.
with open("benchmark-prompts.json", "r") as f:
  prompts = json.load(f)

rewards = []

# Cudize moves both models to DEVICE and enables inference mode, like the
# other cells; a bare torch.inference_mode() left the models on the CPU.
with Cudize(turbo_pipeline, reward_model):

  for itm in (pbar:=tqdm(prompts)):
    prompt = itm["prompt"]

    latents = torch.randn(LATENT_SHAPE(1), device=DEVICE, dtype=DTYPE)
    # `pipeline` is not defined at notebook scope; use the turbo pipeline.
    image = turbo_pipeline(
      prompt=prompt,
      num_inference_steps=4,
      guidance_scale=0.0,
      num_images_per_prompt=1,
      latents=latents
    ).images[0]
    score = reward_model.inference_rank(prompt, [image])[1]

    best_score = score

    # temp = START_TEMP
    # mut = MUTATION_START

    # for epoch in tqdm(range(1, NUM_EPOCHS)):
    #   for step in range(STEPS_PER_EPOCH):

    #     basis = torch.randn(LATENT_SHAPE(BATCH_SIZE), device=DEVICE, dtype=DTYPE)

    #     test_latents = slerp(mut, torch.cat([latents]*BATCH_SIZE,dim=0), basis)
    #     test_images = pipeline(
    #       prompt=prompt,
    #       num_inference_steps=1,
    #       guidance_scale=0.0,
    #       num_images_per_prompt=BATCH_SIZE,
    #       latents=test_latents
    #     ).images
    #     test_scores = reward_model.inference_rank(prompt, test_images)[1]

    #     best = np.argmax(test_scores)
    #     if test_scores[best] > score:
    #       latents = test_latents[best][None]
    #       image = test_images[best]
    #       score = test_scores[best]

    #     else:
    #       logits = torch.tensor([score]+test_scores) / temp
    #       keeper = torch.distributions.Categorical(logits=logits).sample()

    #       if keeper > 0:
    #         latents = test_latents[keeper-1][None]
    #         image = test_images[keeper-1]
    #         score = test_scores[keeper-1]

    #     if score > best_score:
    #       best_score = score

    #     temp *= TEMP_DECAY
    #     mut *= MUTATION_DECAY

    rewards.append(best_score)
    pbar.set_postfix(mean=f"{np.mean(rewards):.3f}", std=f"{np.std(rewards):.3f}")

    # Saved after every prompt, so partial results survive an interrupted run.
    np.save("rewards.npy", np.array(rewards))


In [None]:
# Mean and standard deviation of the rewards collected in the previous cell.
print(np.mean(rewards), np.std(rewards))