
# DrEyeBender's Stable Diffusion notebook

Demo videos:

Musicians and Game of Thrones characters
https://www.youtube.com/watch?v=4MQVlVZO9JU&t=2s

Dogs turning into cats, turning into birds, turning back into dogs
https://www.youtube.com/watch?v=vl1vij9tpyI&t=1s

Fork it: https://github.com/isaac-bender/stable_diffusion_interp/blob/main/DrEyeBender's_Stable_Diffusion_notebook_Public_copy.ipynb

This is based on https://github.com/CompVis/stable-diffusion/blob/main/scripts/txt2img.py

Stuff I added:
1.   Prompt interpolation
2.   Multiple prompts per batch
3.   Better output naming
4.   Automatic grid layout
5.   Prompt prefixes
6.   Random start code options (shared or different per sample in batch)
7.   Fixed or new start code per batch

TODOs:
1.   ~~Installation~~
2.   Assemble a video out of the rendered frames
3.   Init images
4.   Probably some debugging :)
5.   Bring back prompt file support (temporarily broken)
6.   Add more samplers
7.   Prompt combiner
8.   Menus for enumerated choices (rather than string entry)
9.   Cleanup of some copypasta
10.  Reorganize the settings so they're grouped more logically

# Use examples
These are just a few suggestions that explain how you can use the options. Of course, these aren't the only valid configurations. As you become familiar with the options, I'm sure you'll think of some more creative ways to use them.

Example use case: Prompt variation test
---
You want to compare the effects of variations on a prompt in a controlled way.
*   use_slot_modifiers = True
*   prompt_mode = 'per_batch'
*   new_start_code_per_sample = False
*   seed = [whatever number you want]

Put your prompts in the prompt list (`data`), and put your prefixes in the `slot_modifiers` list.

Each sample gets the same prompt, but with a different prefix.

You could also skip the prefixes and set prompt mode to 'per_sample'. That will put a different prompt in each cell.

They will all be seeded the exact same way, so you can compare the effects of the prefixes or different prompts in isolation.
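To make the difference between the two prompt modes concrete, here is a small pure-Python sketch of how one batch's prompt list is assembled. It mirrors the slicing and prefixing in the main sampling loop, but `batch_prompts` is a hypothetical helper written for illustration; the notebook does this inline.

```python
def batch_prompts(data, batch_num, batch_size, prompt_mode, slot_modifiers=None):
    """Return the prompts for one batch (hypothetical helper, illustration only)."""
    if prompt_mode == 'per_sample':
        # each slot takes the next prompt from the list
        prompts = data[batch_num * batch_size:(batch_num + 1) * batch_size]
    elif prompt_mode == 'per_batch':
        # every slot repeats the same prompt
        prompts = [data[batch_num % len(data)]] * batch_size
    else:
        raise ValueError(f'unknown prompt_mode: {prompt_mode}')
    if slot_modifiers is not None:
        # prefix slot b with slot_modifiers[b], joined with ', ' like the notebook
        prompts = [slot_modifiers[b] + ', ' + p for b, p in enumerate(prompts)]
    return prompts


data = ['a lighthouse', 'a forest', 'a city']
print(batch_prompts(data, 0, 3, 'per_batch', ['red', 'green', 'blue']))
# ['red, a lighthouse', 'green, a lighthouse', 'blue, a lighthouse']
print(batch_prompts(data, 0, 3, 'per_sample'))
# ['a lighthouse', 'a forest', 'a city']
```

Because the notebook joins prefix and prompt with `', '`, the trailing spaces in the example `slot_modifiers` lists produce prompts like `'red , ...'`.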

Example use case: Just render a bunch of prompts
---
*   use_slot_modifiers = False
*   prompt_mode = 'per_sample'
*   new_start_code_per_sample = True
*   new_start_code_per_batch = True
*   seed = int(time.time())

Put your prompts in the prompt list.

Each sample in the batch will have a different prompt and random start code, and each batch will get a new set of start codes. This way, you get maximum variety.
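`new_start_code_per_sample` controls whether the samples in a batch share one random start code (the initial latent noise); the notebook passes `shared_code = not new_start_code_per_sample` to `make_start_code`. The sketch below is a NumPy analog of that function so it runs without a GPU. `make_start_code_np` is a hypothetical name (the notebook uses `torch.randn`), and the defaults assume the notebook's latent settings (`C = 4`, `f = 8`, 512x512 images), which give 4x64x64 latents.

```python
import numpy as np

def make_start_code_np(batch_size, shared_code, C=4, H=512, W=512, f=8, rng=None):
    """NumPy analog of make_start_code (hypothetical, illustration only)."""
    rng = np.random.default_rng() if rng is None else rng
    if shared_code:
        # one latent, repeated for every sample in the batch
        code = rng.standard_normal((C, H // f, W // f))
        return np.repeat(code[None], batch_size, axis=0)
    # an independent latent for every sample in the batch
    return rng.standard_normal((batch_size, C, H // f, W // f))


shared = make_start_code_np(3, True, rng=np.random.default_rng(0))
varied = make_start_code_np(3, False, rng=np.random.default_rng(0))
print(shared.shape)  # (3, 4, 64, 64)
print(np.array_equal(shared[0], shared[1]), np.array_equal(varied[0], varied[1]))  # True False
```

Re-seeding the generator reproduces the same codes, which is what makes a fixed `seed` useful for controlled comparisons.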

Example use case: Prompt interpolation
---
Now it's starting to get interesting :)
*   num_interpolation_steps = 10 #higher numbers will give smoother animations
*   use_slot_modifiers = [up to you]
*   prompt_mode = 'per_sample'
*   new_start_code_per_sample = [up to you]

Put your prompts in the prompt list.

Each sample in the batch will have a different prompt.

**If new_start_code_per_batch is True, the start codes will also be interpolated from batch to batch.**
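Why interpolate start codes with slerp instead of a straight linear blend? The sampling loop only notes that "lerp doesn't work well in this space". A plausible explanation: two independent Gaussian noise vectors of this size are almost orthogonal and have almost equal norms, so their linear midpoint has only about 1/√2 of the original norm (less noise than the sampler presumably expects), while slerp keeps the norm roughly constant. The NumPy sketch below checks that. `slerp_np` interpolates over the whole flattened vector, a simplification of the notebook's `slerp`, which normalizes along dimension 1 (per latent position, across channels).

```python
import numpy as np

def lerp(t, a, b):
    # straight-line blend between a and b
    return (1.0 - t) * a + t * b

def slerp_np(t, a, b, eps=1e-7):
    # blend along the great circle between a and b (whole-vector simplification)
    cos_omega = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))
    so = np.sin(omega)
    if so < eps:
        return lerp(t, a, b)  # (nearly) parallel vectors: fall back to lerp
    return (np.sin((1.0 - t) * omega) / so) * a + (np.sin(t * omega) / so) * b


rng = np.random.default_rng(0)
a = rng.standard_normal(4 * 64 * 64)  # a flattened 4x64x64 start code
b = rng.standard_normal(4 * 64 * 64)

ratios = {
    'lerp': float(np.linalg.norm(lerp(0.5, a, b)) / np.linalg.norm(a)),
    'slerp': float(np.linalg.norm(slerp_np(0.5, a, b)) / np.linalg.norm(a)),
}
print(ratios)
```

You should see a midpoint norm ratio of roughly 0.71 for lerp and roughly 1.0 for slerp.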

The learned conditioning for each prompt is interpolated between itself and the learned conditioning of the prompt in the same position in the next batch.

num_interpolation_steps controls how smooth the animation is. 10 is just an example starting point. If you have all night, set it to 1000, I'm sure that would look neat!

At the end of the prompt list, it wraps around to interpolate the last batch into the first.

It may look like it doesn't render the very last frame, but that frame would be identical to the first frame anyway.
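The batch-to-batch schedule described above can be sketched as a list of `(from_batch, to_batch, t)` triples, one per rendered frame. This is a simplified model of the notebook's loop (it ignores the sampler and step-count sweeps), and `interpolation_schedule` is a hypothetical helper. Because `t` runs over `0, 1/n, ..., (n-1)/n`, the frame at `t = 1` is never rendered; it would match the `t = 0` frame of the next batch.

```python
def interpolation_schedule(num_batches, num_interpolation_steps):
    """List (from_batch, to_batch, t) for every rendered frame (hypothetical helper)."""
    frames = []
    for batch_num in range(num_batches):
        # the last batch wraps around and blends into the first
        next_batch = (batch_num + 1) % num_batches
        for step in range(num_interpolation_steps):
            t = step / num_interpolation_steps  # never reaches 1.0
            frames.append((batch_num, next_batch, t))
    return frames


for frame in interpolation_schedule(num_batches=2, num_interpolation_steps=4):
    print(frame)
```

The total frame count is `num_batches * num_interpolation_steps` (8 here), so doubling `num_interpolation_steps` doubles the render time.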

---
P.S. Share your prompts! Concealing your prompts is a coward's move.

## Clone repo and install packages

**'do_install' defaults to False to help prevent redundant or unintentional installs**

This is from https://colab.research.google.com/github/cpacker/stable-diffusion/blob/interactive-notebook/scripts/stable_diffusion_interactive_colab.ipynb

In [None]:
# Check what GPU we're using
!nvidia-smi

In [None]:
#defaults to False to help prevent redundant / unintentional installs
do_install = False #@param {type:"boolean"}
get_weights_from_gdrive = True #@param {type:"boolean"}
if do_install:
  #@title Run once { display-mode: "form" }

  # clone repo
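  !git clone https://github.com/CompVis/stable-diffusion.git
  # the installs below (including `pip install -e .`) must run from inside the repo
  %cd /content/stable-diffusion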

  # base colab installs cause issues in lightning.seed_everything
  !pip uninstall -y torchtext

  # Copy-pasta of https://github.com/cpacker/stable-diffusion/blob/main/environment.yaml
  # But try skipping the torch and cudatoolkit installs
  #!pip install numpy==1.19.2  # omit, causes this issue: https://stackoverflow.com/questions/66060487/valueerror-numpy-ndarray-size-changed-may-indicate-binary-incompatibility-exp
  !pip install albumentations==0.4.3
  !pip install opencv-python==4.1.2.30
  !pip install pudb==2019.2
  !pip install imageio==2.9.0
  !pip install imageio-ffmpeg==0.4.2
  !pip install pytorch-lightning==1.4.2
  !pip install omegaconf==2.1.1
  # quote specifiers containing '>' so the shell doesn't treat them as output redirection
  !pip install "test-tube>=0.7.5"
  !pip install "streamlit>=0.73.1"
  !pip install einops==0.3.0
  !pip install torch-fidelity==0.3.0
  !pip install transformers==4.19.2
  !pip install torchmetrics==0.6.0
  !pip install kornia==0.6
  !pip install git+https://github.com/crowsonkb/k-diffusion
  !pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
  !pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip
  !pip install -e .

  # Colab broke widget support on 8/19/2022, here's the temp fix:
  # https://github.com/googlecolab/colabtools/issues/3020
  !pip install "ipywidgets>=7,<8"  
%cd /content/stable-diffusion


## Linking your copy of the Stable Diffusion weights
*This notebook assumes you already have access to the Stable Diffusion weights.*

### Upload the weights to Google Drive, then mount into Colab

Upload your copy of the weights (e.g., `sd-v1-4.ckpt`) to a folder called `stable-diffusion-checkpoints` at the top level of your Google Drive. If you store them somewhere else, update the path `/content/drive/MyDrive/stable-diffusion-checkpoints/sd-v1-4.ckpt` in the following code to match.

You can also mount the Drive folder using the Colab file browser.
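The checkpoint can live in either of the two places this notebook uses: the Drive path in the next cell, or the local default `models/ldm/stable-diffusion-v1/sd-v1-4.ckpt` used when Drive isn't mounted. A small existence check fails early with a clear message instead of a long `torch.load` traceback. `resolve_checkpoint` is a hypothetical helper, not part of the notebook.

```python
import os

def resolve_checkpoint(*candidates):
    """Return the first candidate path that is an existing file (hypothetical helper)."""
    for path in candidates:
        # skip None entries, e.g. when Drive was not mounted
        if path and os.path.isfile(path):
            return path
    raise FileNotFoundError(f'No checkpoint found in: {candidates}')

# Possible usage after mounting Drive:
# ckpt = resolve_checkpoint(
#     '/content/drive/MyDrive/stable-diffusion-checkpoints/sd-v1-4.ckpt',
#     'models/ldm/stable-diffusion-v1/sd-v1-4.ckpt',
# )
```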

In [None]:
model_path_from_gdrive = None
# This will open a pop-up window asking you to link your Google Drive account to this notebook for access
if get_weights_from_gdrive:
  from google.colab import drive
  drive.mount('/content/drive/')
  model_path_from_gdrive = '/content/drive/MyDrive/stable-diffusion-checkpoints/sd-v1-4.ckpt'  

# Init

In [None]:
from IPython import display
import argparse, os, sys, glob
import torch
torch.backends.cudnn.benchmark = False
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm.auto import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import contextmanager, nullcontext

try:
  # if you have a local installation, you can cd to it here
  # either set the env var, or just set the path here directly
  sd_path = os.getenv("LOCAL_STABLE_DIFFUSION_PATH")
  os.chdir(sd_path)
  sys.path.append(os.getcwd())
except Exception as e:
  pass

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

# Utility functions
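Classifier-free guidance, as implemented by `CFGDenoiser` below, runs the model on a doubled batch (unconditional and conditional inputs together) and combines the two predictions as `uncond + (cond - uncond) * cond_scale`, the same formula quoted next to the `scale` setting. Just that combination step looks like this in NumPy, with made-up two-element predictions; `cfg_combine` is a hypothetical name used for illustration.

```python
import numpy as np

def cfg_combine(uncond, cond, cond_scale):
    # classifier-free guidance: start at the unconditional prediction and move
    # cond_scale times along the direction from uncond to cond
    return uncond + (cond - uncond) * cond_scale


uncond = np.array([0.0, 1.0])  # made-up unconditional prediction
cond = np.array([1.0, 1.0])    # made-up conditional prediction
print(cfg_combine(uncond, cond, 1.0))  # [1. 1.]
print(cfg_combine(uncond, cond, 7.5))  # [7.5 1. ]
```

With `cond_scale = 1.0` the result is exactly the conditional prediction; larger values push it further along the `cond - uncond` direction.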

In [None]:
from PIL import Image, ImageFont, ImageDraw
import math
import numpy as np
import time
import io
import k_diffusion as K
import torch.nn as nn

def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.cuda()
    model.eval()
    return model

def clean_filename(text):
  # replace every non-alphanumeric character with a space, collapse runs of
  # whitespace (including leading/trailing), and join the words with hyphens
  cleaned = ''.join(c if c.isalnum() else ' ' for c in text)
  return '-'.join(cleaned.split())

def make_start_code(batch_size, shared_code, C, H, W, f, device):
  if shared_code:
    #the same random start code is shared by each sample
    start_code_max = torch.randn([C, H // f, W // f], device=device)
    start_code_max = start_code_max.repeat(batch_size, *[1 for n in range(len(start_code_max.shape))])
  else:
    #start code is random for the whole tensor, so each sample is different
    start_code_max = torch.randn([batch_size, C, H // f, W // f], device=device)    

  return start_code_max

# fall back to PIL's built-in font if no preferred TrueType font is found
font = ImageFont.load_default()
try:
  from matplotlib import font_manager
  system_fonts = font_manager.findSystemFonts(fontpaths=None, fontext='ttf')
  #print('\n'.join(system_fonts))
  for font_name in system_fonts:
    if 'segoeui.ttf' in font_name.lower() or 'liberationsans-regular.ttf' in font_name.lower():
      font = ImageFont.truetype(font_name, 14)
      break
except Exception:
  pass

def draw_text(text, x, y, drawing_context, font):
  drawing_context.text((x+2, y+2), text, font = font, fill=(0, 0, 0))
  drawing_context.text((x, y), text, font = font, fill=(255, 255, 255))

def get_grid_dims(num_elements):
  best_diff = num_elements+1
  for x in range(1, num_elements+1):
    y = num_elements // x
    if x * y != num_elements:
      continue
    if y > x:
      continue
    diff = x - y
    if diff < best_diff:
      best_diff = diff
      result = (x, y)
  return result

def get_grid_dims_squareish(num_elements):
  y = int(math.sqrt(num_elements))
  x = y
  while (x * y) < num_elements:
    x += 1
  return x, y

def make_image_grid(images, conversion_func, blanks_ok=True, captions=None):
  individual_images = [conversion_func(image[None, :]) for image in images]
  _, _, image_h, image_w = images.shape
  return make_image_grid_from_pil(individual_images, image_h, image_w, captions, blanks_ok)

def make_image_grid_from_pil(individual_images, image_h, image_w, captions=None, blanks_ok=True):

  image_count = len(individual_images)
  assert(image_w == image_h) #todo

  if blanks_ok:
    num_cols, num_rows = get_grid_dims_squareish(image_count)
  else:
    num_cols, num_rows = get_grid_dims(image_count)

  w = image_w * num_cols
  h = image_h * num_rows
  combo = Image.new('RGB', (w, h))
  if captions != None:
    drawing_context = ImageDraw.Draw(combo)

  for sample_num in range(image_count):
    col = sample_num % num_cols
    row = sample_num // num_cols
    x = col * image_w
    y = row * image_h
    image = individual_images[sample_num]
    combo.paste(image, (x, y))
    if captions != None:
      draw_text(captions[sample_num], x+1, y, drawing_context, font)

  return combo

def load_list_file(list_file):
  with io.open(list_file,'rt', encoding='utf-8') as f:

    lines = [line.rstrip() for line in f]

    #contents = f.read()
    ##print(f'load_list_file {contents}')
    ##lines = contents.decode("ISO-8859-1").rstrip("\n")
    #lines = contents.rstrip("\n")
    #lines = lines.split("\n")
  #print(len(lines))
  return lines
  
display_counter = 0
def display_image(image):
  global display_counter
  print(f'display_counter {display_counter}')
  if display_counter > 50:
    display_counter = 0
    display.clear_output()
    print(f'display.clear_output()')
  display.display(image)
  display_counter += 1  

def load_captions_from_file_list(file_list, match_suffix=None, match_substr=None, sample_rate=1.0, num_to_take=None, randomize_seed=False):

  if randomize_seed:
    original_state = np.random.get_state()
    np.random.seed(int(time.time()))

  if match_suffix != None or match_substr != None:
    t = []
    for fn in tqdm(file_list):
      if (match_suffix == None or fn.endswith(match_suffix)) and (match_substr == None or match_substr in fn):
        t.append(fn)
    file_list = t

  if (sample_rate > 0 and sample_rate < 1) or num_to_take != None:
    random_idx = (list(range(0, len(file_list))))
    np.random.shuffle(random_idx)
    if num_to_take != None:
      #random_idx = sorted(random_idx[:int(num_to_take)])
      random_idx = (random_idx[:int(num_to_take)])
    else:
      #random_idx = sorted(random_idx[:int(len(random_idx) * sample_rate)])
      random_idx = (random_idx[:int(len(random_idx) * sample_rate)])
    file_list = [file_list[i] for i in random_idx]

  result = []
  for f in tqdm(file_list):
    try:
      caption = load_list_file(f)[0]
      caption = caption.replace('\r', '')
      caption = caption.replace('\n', '')
      result.append(caption)
    except Exception as e:
      print(f'load_captions_from_path: Error loading {f}\n{e}')

  if randomize_seed:
    np.random.set_state(original_state)

  return result

def shuffle_list(the_list):
  random_idx = (list(range(0, len(the_list))))
  np.random.shuffle(random_idx)
  return [the_list[i] for i in random_idx]

import datetime
ring_size = 10
timing_ring = ring_size*[0]
ring_idx = 0

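def print_timing_stats(num_done, num_total, start_time, cur_time, n=None, update_freq = None):
  try:
    if num_done <= 0:
        return
    if update_freq == None:
        # at least 1, otherwise the modulo below fails for runs under 10000 items
        update_freq = max(1, num_total // 10000)
    if int(num_done) % int(update_freq) != 0:
        return

    global timing_ring
    global ring_idx
    timing_ring[ring_idx%ring_size] = cur_time
    ring_idx += 1
    num = min(ring_size, ring_idx)
    if num < 2:
        # a rate needs at least two timestamps
        return
    # the num stored timestamps span (num - 1) intervals of update_freq items each
    dt = (max(timing_ring[:num]) - min(timing_ring[:num])) / ((num - 1) * update_freq)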

    #dt = cur_time - start_time
    rate = dt# / num_done
    eta = rate * (num_total - num_done)
    formatted_eta = str(datetime.timedelta(seconds = eta))
    if n != None:
      n = f'n: {n}'
    else:
      n = ''
    print(f'{num_done}/{num_total} ({(100.0*num_done/num_total):.2f}%) ETA: {formatted_eta} it/sec: {(1/rate):.2f} {n}')
  except:
    pass
  
def process_prompt(use_subjects, s, use_prompts, p, use_hard_prompt, hp):
  result = ''
  need_comma = False
  if use_subjects:
    result += s
    need_comma = True
  if use_hard_prompt:
    if need_comma:
      result += ', '
    result += hp
    need_comma = True
  if use_prompts:
    if need_comma:
      result += ', '
    result += p
  return result

class PromptFilename:
  def __init__(self):
    self.prompt_list = []
    self.hash_list = []
    pass

  def add_prompt(self, prompt):
    self.prompt_list.append(clean_filename(prompt))
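    # note: str hashes are randomized per Python process unless PYTHONHASHSEED is set,
    # so the same prompt can get a different hash in a new session
    self.hash_list.append(hash(prompt))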

  def get_string(self, max_len):
    result = ''
    for hash in self.hash_list:
      result = f'{result}-ph`{hash:X}`'
    #print(f'len(result) {len(result)}')
    remaining_chars = max_len - len(result)
    #print(f'remaining_chars {remaining_chars}')
    chars_per_prompt = max(remaining_chars // len(self.prompt_list), 0)
    #print(f'chars_per_prompt {chars_per_prompt}')
    for prompt in self.prompt_list:
      result = f'{result}+{prompt[:chars_per_prompt]}'
    return result

class CFGDenoiser(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
        return uncond + (cond - uncond) * cond_scale


class KDiffusionSampler:
    def __init__(self, m, sampler):
        self.model = m
        self.model_wrap = K.external.CompVisDenoiser(m)
        self.schedule = sampler

    def sample(self, S, conditioning, batch_size, shape, verbose, unconditional_guidance_scale, unconditional_conditioning, eta, x_T):
        sigmas = self.model_wrap.get_sigmas(S)
        x = x_T * sigmas[0]
        model_wrap_cfg = CFGDenoiser(self.model_wrap)

        samples_ddim = K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}, disable=False)

        return samples_ddim, None

#https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
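def slerp(val, low, high, eps=1e-7):
    low_norm = low/torch.norm(low, dim=1, keepdim=True)
    high_norm = high/torch.norm(high, dim=1, keepdim=True)
    # clamp: rounding can push the cosine slightly outside [-1, 1], where acos returns NaN
    omega = torch.acos((low_norm*high_norm).sum(1).clamp(-1.0, 1.0))
    so = torch.sin(omega)
    res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
    # where low and high are (anti)parallel, sin(omega) ~ 0 and res is NaN/inf; use lerp there
    lerp = (1.0-val)*low + val*high
    return torch.where((so.abs() < eps).unsqueeze(1), lerp, res)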

# Technical Settings
See Python comments for variable descriptions. Many are unchanged from the original script.
This section is for the boring settings you're not likely to change much.
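`grid_mode_blanks_ok` in the next cell chooses between the two layout helpers from the utility cell: `get_grid_dims` (exact factorization, no blank cells) when it is False, and `get_grid_dims_squareish` (near-square, may leave blanks) when it is True. The copies below restate their logic so the difference can be checked without a GPU; a prime sample count is where they diverge most.

```python
import math

def get_grid_dims(num_elements):
    # exact (cols, rows) factorization with cols >= rows, as close to square as possible
    best = None
    for x in range(1, num_elements + 1):
        y = num_elements // x
        if x * y != num_elements or y > x:
            continue
        if best is None or x - y < best[0] - best[1]:
            best = (x, y)
    return best

def get_grid_dims_squareish(num_elements):
    # near-square (cols, rows) grid that may contain blank cells
    y = int(math.sqrt(num_elements))
    x = y
    while x * y < num_elements:
        x += 1
    return x, y


for n in (6, 7):
    print(n, get_grid_dims(n), get_grid_dims_squareish(n))
# 6 (3, 2) (3, 2)
# 7 (7, 1) (4, 2)
```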

In [None]:
#quality level for saved JPGs
jpg_quality = 99 #@param {type:"integer"}

#do not save a grid, only individual samples. Helpful when evaluating lots of samples
skip_grid = False #@param {type:"boolean"}
#do not save individual samples. For speed measurements.
skip_save = False #@param {type:"boolean"}
#ddim eta (eta=0.0 corresponds to deterministic sampling)
ddim_eta = 0.0 #@param {type:"number"}
#sample this many times
n_iter = 1 #@param {type:"integer"}
#image height, in pixel space
H = 512 #@param {type:"integer"}
#image width, in pixel space
W = 512 #@param {type:"integer"}
#latent channels
C = 4 #@param {type:"integer"}
#downsampling factor
f = 8 #@param {type:"integer"}
#evaluate at this precision
precision = "autocast" #@param {type:"string"}
#choices: "full" or "autocast"

#for grid layout, should blank cells be allowed?
#True tries to make the grid as square as possible
#False only uses grid dimensions that exactly multiply to the number of samples
grid_mode_blanks_ok = False #@param {type:"boolean"}

precision_scope = autocast if precision=="autocast" else nullcontext


# Settings that require model reload

In [None]:
force_model_reload = False #@param {type:"boolean"}

if not 'model_loaded' in locals() or force_model_reload:
  #path to config which constructs model"
  model_config_path = "configs/stable-diffusion/v1-inference.yaml" #@param {type:"string"}
  #path to checkpoint of model"
  if model_path_from_gdrive != None:
    ckpt = model_path_from_gdrive
  else:
    ckpt = r'models/ldm/stable-diffusion-v1/sd-v1-4.ckpt' #@param {type:"raw"}

  config = OmegaConf.load(f"{model_config_path}")
  model = load_model_from_config(config, f"{ckpt}")

  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
  model = model.to(device)

  model_loaded = True

# Prompts and prompt settings
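With `use_files` enabled, the next cell builds the prompt list from a subjects file and a prompts file: `"all"` pairs every subject with every prompt, while `"one_each"` walks both lists in step and wraps the shorter one. The sketch below restates those two modes in plain Python. `combine_prompts` is a hypothetical helper, and it simply skips empty parts instead of checking the `use_subjects` / `use_prompts` / `use_hard_prompt` flags the way `process_prompt` does.

```python
def combine_prompts(subjects, prompts, mode, hard_prompt='', repeats=1):
    """Hypothetical sketch of the notebook's 'all' / 'one_each' prompt combination."""
    def join(s, p):
        # subject, then hard prompt, then prompt, comma-separated (empty parts skipped)
        return ', '.join(part for part in (s, hard_prompt, p) if part)

    if mode == 'all':
        # every subject with every prompt: len(subjects) * len(prompts) combinations
        pairs = [(s, p) for s in subjects for p in prompts]
    elif mode == 'one_each':
        # walk both lists together; the shorter one wraps around
        n = max(len(subjects), len(prompts))
        pairs = [(subjects[i % len(subjects)], prompts[i % len(prompts)]) for i in range(n)]
    else:
        raise ValueError(f'unknown mode: {mode}')

    # each combined prompt is repeated `repeats` times in a row
    return [join(s, p) for s, p in pairs for _ in range(repeats)]


subjects = ['a cat', 'a dog']
prompts = ['oil painting', 'pixel art', 'watercolor']
print(combine_prompts(subjects, prompts, 'one_each'))
# ['a cat, oil painting', 'a dog, pixel art', 'a cat, watercolor']
print(len(combine_prompts(subjects, prompts, 'all')))  # 6
```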

In [None]:
#the prompts to render
import random

#if you have a file full of prompts, you can load it here
#there's a simple prompt generator at the end of the notebook
use_files = False #@param {type: "boolean"}
repeats_per_prompt = 1 #@param {type: "integer"}

use_subjects = False #@param {type: "boolean"}
randomize_subjects = False #@param {type: "boolean"}
subject_file_path = r'subjects.txt' #@param {type: "raw"}
use_prompts = False #@param {type: "boolean"}
randomize_prompts = False #@param {type: "boolean"}
prompt_file_path = r'prompts.txt' #@param {type: "raw"}

use_hard_prompt = False #@param {type: "boolean"}
hard_prompt = "something about a lighthouse playing a guitar, I think"  #@param {type: "string"}

prompot_combo_mode = "one_each" #@param ["all", "one_each"]

subjects = ['']
prompts = ['']

if use_files:
  if use_subjects:
    subjects = load_list_file(subject_file_path)
  if use_prompts:
    prompts = load_list_file(prompt_file_path)

  # shuffle the source lists; data is rebuilt from them below
  if randomize_subjects:
    subjects = shuffle_list(subjects)

  if randomize_prompts:
    prompts = shuffle_list(prompts)

  data = []
  if prompot_combo_mode == "all":
    for s in subjects:
      for p in prompts:
        p = process_prompt(use_subjects, s, use_prompts, p, use_hard_prompt, hard_prompt)
        for r in range(repeats_per_prompt):
          data.append(p)
  if prompot_combo_mode == "one_each":
    num_items = max(len(subjects), len(prompts))
    for i in range(num_items):
      s = subjects[i%len(subjects)]
      p = prompts[i%len(prompts)]
      p = process_prompt(use_subjects, s, use_prompts, p, use_hard_prompt, hard_prompt)
      for r in range(repeats_per_prompt):      
        data.append(p)
  
  print('\n'.join(data[:100]))
  print(len(data))

else:
  data = [
    "painting of a virus monster playing guitar",
    "A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation.",
    "In the foreground, happy chocolate beagles are playing in a field of wildflowers on a clear day, in the distance you can see mountains and a forest.",
    "A sentient android made of wood, by Ivan shishkin and beeple, trending on artstation",
    "The cybernetic humanoid stood in the haunted forest like an ancient statue, pen and colorful ink gouache painting by Hiroshi Yoshida, Dan Mumford, Artstation, Behance, atmospheric environment art",
    "My mother told me a fairytale, about a whimsical fairy land found in an ancient children's storybook, comic illustration by Larry Elmore, Artstation, zbrushcentral",
    #
    "So glad I got up for this -15F sunrise. There's only a handful of days in the year where the morning light lines up perfectly with the hole in Hollow Rock. Grand Portage, MN. ",
    "For a 15-minute period during my flight back to Canada yesterday, there were no clouds blocking the view over Greenland's glaciers and icebergs ",
    "Woke up at 5AM to catch the tulips with morning mist, the Netherlands (OC)",
    "Eclipse Phases over Brasstown Bald, Georgia ",
    "You might have seen it before, but here's that one place in Indonesia with a volcano behind waterfalls. ",
    "One of my scariest moments as a photographer- what you dont see here is the 100m drop in front of me and the gale force wind from behind. Two minutes of light and then it was dark again. Faroe islands ",
    #  
    "This Neo-Tokyo is now a wasteland after the cybernetic robots attacked, pen and colorful ink gouache painting by Hiroshi Yoshida, Dan Mumford, Artstation, Behance, atmospheric environment art",
    "Faceless Phantoms surrounded her asking where she was heading, she's lost walking down a desolate road, pen and colorful ink gouache painting by Hiroshi Yoshida, Naoko Takeuchi, Artstation, Behance, magical fantasy art",
    "A vast magic shop filled with potion bottles and oddities, 3D illustration by Greg Rutkowski and Gediminas Pranckevicius, Artstation, zbrushcentral, 3D shading, magical realism",
    "Bacon double cheeseburger and fries, tarot card by Alphonse Mucha",
    "A tropical sea port at night, mountains and a small city in the background artstation, cgsociety, matte painting, in watercolor",
    "Landscape oil on canvas by Bob Ross",
  ]

#if True, each sample in the batch gets prefixed with the corresponding entry in this list
use_slot_modifiers = False #@param {type:"boolean"}

#various slot modifier ideas
#These examples have 6 entries because that's the batch size I'm using. If you use bigger batches and want to use this feature, you'll need longer lists.
slot_modifiers = ['red ', 'orange ', 'yellow ', 'green ', 'blue ', 'purple ']
#slot_modifiers = ['majestic', 'made of legos', 'fiery', 'friendly', 'stuffed animal', 'made of cake']
#slot_modifiers = ['Superman in ', 'Batman in ', 'Wonder Woman in ', 'The Flash in ', 'Aquaman in ', 'Harley Quinn in ']
#slot_modifiers = ['Captain America in ', 'Iron Man in ', 'Black Widow in ', 'Spider-Man in ', 'The Hulk in ', 'Thor in ']
#slot_modifiers = ['art nouveau ', 'comic book ', 'stained glass ', 'a painting by Mary Jane Ansell of ', 'claymation ', 'a marble sculpture of ']
#slot_modifiers = ['award winning', 'top rated', "world's best", 'incredibly detailed', 'hyperrealistic', 'artstation']


# This is where the magic happens

Frequently used settings go here.
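The `prompt_blend_mode` setting in the next cell decides which prompt each slot blends toward during interpolation. With `prompt_mode = 'per_sample'`, the partner's index into the prompt list works out as below. `blend_partner` is a hypothetical helper that restates the index arithmetic of the blending code; it is not defined in the notebook.

```python
def blend_partner(batch_num, slot, batch_size, num_prompts, prompt_blend_mode):
    """Prompt-list index that slot `slot` of batch `batch_num` blends toward
    (hypothetical helper; assumes prompt_mode == 'per_sample')."""
    if prompt_blend_mode == 'right_shift':
        # neighbour in the same batch; the last slot wraps to the first
        return batch_num * batch_size + (slot + 1) % batch_size
    if prompt_blend_mode == 'batch_shift':
        # same slot in the next batch; the end of the list wraps to the start
        return ((batch_num + 1) * batch_size + slot) % num_prompts
    raise ValueError(f'unknown prompt_blend_mode: {prompt_blend_mode}')


# 6 prompts, batch_size 3 -> 2 batches
print([blend_partner(0, s, 3, 6, 'right_shift') for s in range(3)])  # [1, 2, 0]
print([blend_partner(1, s, 3, 6, 'batch_shift') for s in range(3)])  # [0, 1, 2]
```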

In [None]:
#dirs to write results to
outdir = r"gallery" #@param {type:"raw"}
ckpt_name = os.path.split(ckpt)[1]
ckpt_name = os.path.splitext(ckpt_name)[0]
outdir = os.path.join(outdir, ckpt_name)
sample_path = os.path.join(outdir, "samples") #@param {type:"raw"}
os.makedirs(sample_path, exist_ok=True)
grid_path = os.path.join(outdir, "grids") #@param {type:"raw"}
#debug images show the prompts and blend values
grid_d_path = os.path.join(outdir, "grids_debug") #@param {type:"raw"}
os.makedirs(grid_path, exist_ok=True)
os.makedirs(grid_d_path, exist_ok=True)
#number of samples rendered together in one batch (a.k.a. batch size)
#A 3090 can handle a maximum of 6 at 512x512
batch_size = 3 #@param {type:"integer"}

#the seed (for reproducible sampling)
seed = int(time.time()) #@param {type:"raw"}
#useful if you're trying to evaluate a large number of seeds
increment_seed_per_iter = False #@param {type:"boolean"}

# 1 step means there's no interpolation
num_interpolation_steps = 4 #@param {type:"integer"}
#if True, new start codes are generated per batch. If False, start codes are only generated once at the beginning of the run.
new_start_code_per_batch = True #@param {type:"boolean"}
#if True, each sample in the batch gets a different start code.  If False, each sample in the batch gets the same start code.
new_start_code_per_sample = True #@param {type:"boolean"}

#batch_shift: each prompt blends with the one in the same slot of the next batch
#right_shift: each prompt blends with the one to its right (or it wraps around to the first one, in the case of the last sample in the batch)
prompt_blend_mode = 'batch_shift' #@param ['batch_shift', 'right_shift']

# The path taken through conditioning space when interpolating. Can be linear or spherical. Personally I think linear looks better, probably because its velocity is lower.
prompt_blend_type = 'linear' #@param ['linear', 'spherical']

#if 'per_sample', each sample in the batch uses a different prompt. If 'per_batch', each sample in the batch uses the same prompt.
prompt_mode = 'per_sample' #@param ['per_batch', 'per_sample']

#range of sampling step counts to sweep (inclusive); counts that don't divide 1000 evenly are skipped
ddim_steps_min = 20 #@param {type:"integer"}
ddim_steps_max = 20 #@param {type:"integer"}

#unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))
scale = 7.5 #@param {type:"number"}

# you can enable more than one sampler
sampler_names = []
use_PLMS = False #@param {type:'boolean'}
if use_PLMS: sampler_names.append('PLMS')
use_DDIM = False #@param {type:'boolean'}
if use_DDIM: sampler_names.append('DDIM')
use_k_dpm_2_a = False #@param {type:'boolean'}
if use_k_dpm_2_a: sampler_names.append('k_dpm_2_a')
use_k_dpm_2 = False #@param {type:'boolean'}
if use_k_dpm_2: sampler_names.append('k_dpm_2')
use_k_euler_a = False #@param {type:'boolean'}
if use_k_euler_a: sampler_names.append('k_euler_a')
use_k_euler = False #@param {type:'boolean'}
if use_k_euler: sampler_names.append('k_euler')
use_k_heun = False #@param {type:'boolean'}
if use_k_heun: sampler_names.append('k_heun')
use_k_lms = True #@param {type:'boolean'}
if use_k_lms: sampler_names.append('k_lms')

for sampler_name in sampler_names:
  if sampler_name == 'PLMS':
      sampler = PLMSSampler(model)
  elif sampler_name == 'DDIM':
      sampler = DDIMSampler(model)
  elif sampler_name == 'k_dpm_2_a':
      sampler = KDiffusionSampler(model,'dpm_2_ancestral')
  elif sampler_name == 'k_dpm_2':
      sampler = KDiffusionSampler(model,'dpm_2')
  elif sampler_name == 'k_euler_a':
      sampler = KDiffusionSampler(model,'euler_ancestral')
  elif sampler_name == 'k_euler':
      sampler = KDiffusionSampler(model,'euler')
  elif sampler_name == 'k_heun':
      sampler = KDiffusionSampler(model,'heun')
  elif sampler_name == 'k_lms':
      sampler = KDiffusionSampler(model,'lms')
  else:
      raise Exception("Unknown sampler: " + sampler_name)    

  start_codes = []
  seed_everything(seed)
  start_codes.append(make_start_code(batch_size, not new_start_code_per_sample, C, H, W, f, device))
  start_codes.append(make_start_code(batch_size, not new_start_code_per_sample, C, H, W, f, device))
  frame = 0

  original_start_code = start_codes[0]
  #print(f'original_start_code {original_start_code}')

  with torch.no_grad():
      with precision_scope("cuda"):
          with model.ema_scope():
              tic = time.time()
              all_samples = list()
              for n in trange(n_iter, desc="Sampling"):
                  if n > 0:
                    if increment_seed_per_iter:
                      #seed = time.time()
                      seed += 1
                      print(f'new seed: {seed}')
                      seed_everything(seed)
                      start_codes[0] = make_start_code(batch_size, not new_start_code_per_sample, C, H, W, f, device)
                      start_codes[1] = make_start_code(batch_size, not new_start_code_per_sample, C, H, W, f, device)

                  if prompt_mode == 'per_sample':
                    num_batches = len(data) // batch_size
                  elif prompt_mode == 'per_batch':
                    num_batches = len(data)
                  num_batches *= len(sampler_names)
                  interp_batch = 0
                  for batch_num in tqdm(range(num_batches), desc="data"):
                    if new_start_code_per_batch:
                      start_codes[0] = start_codes[1]
                      print(f'batch_num {batch_num}')
                      if batch_num == num_batches - 1:
                        print('looping')
                        start_codes[1] = original_start_code
                        #print(f'original_start_code {original_start_code}')
                      else:
                        start_codes[1] = make_start_code(batch_size, not new_start_code_per_sample, C, H, W, f, device)

                    #todo: add support for more parameter sweeps
                    #for scale in np.arange(4, 9.01, 1.0).tolist():
                    #for scale in np.arange(0, 100.01, 1.0).tolist():
                    #for batch_size in range(1, 1000):
                    #for ddim_steps in range(1, 1001):
                    for ddim_steps in range(ddim_steps_min, ddim_steps_max+1):
                      if 1000 % ddim_steps != 0:
                        continue
                      num_interp_batches = num_interpolation_steps * num_batches
                      start_time = time.time()
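                      # exact fractions k/n; np.arange with a float step can yield an extra value just below 1.0
                      for prompt_interp in tqdm([step / num_interpolation_steps for step in range(num_interpolation_steps)]):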
                        sampling_start_time = time.time()                    
                        print_timing_stats(interp_batch, num_interp_batches, start_time, time.time(), n=None, update_freq = 1)
                        interp_batch += 1
                        prompt_interp_inv = 1.0 - prompt_interp

                        if new_start_code_per_batch:
                          #start_code = start_codes[0][:batch_size, :] * prompt_interp_inv + start_codes[1][:batch_size, :] * prompt_interp
                          #lerp doesn't work well in this space, use slerp instead
                          start_code = slerp(prompt_interp, start_codes[0][:batch_size, :], start_codes[1][:batch_size, :])
                        else:
                          start_code = start_codes[0][:batch_size, :]


                        uc = None
                        if scale != 1.0:
                            uc = model.get_learned_conditioning(batch_size * [""])
                        #if isinstance(prompt, tuple):
                        #    prompts = list(prompt)
                        #else:

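                        if prompt_mode == 'per_sample':
                          # Wrap around the prompt list, matching the next-batch logic used for 'batch_shift'
                          prompts = [data[i % len(data)] for i in range(batch_num*batch_size, (batch_num+1)*batch_size)]
                        elif prompt_mode == 'per_batch':
                          prompts = [data[batch_num%len(data)]] * batch_size
                        else:
                          raise ValueError(f"Unknown prompt_mode {prompt_mode!r}; expected 'per_sample' or 'per_batch'")

                        #print(f'\n'.join(prompts))
                        if use_slot_modifiers:
                          for b in range(batch_size):
                            prompts[b] = slot_modifiers[b] + ', ' + prompts[b]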
                        c = model.get_learned_conditioning(prompts)

                        if prompt_blend_mode is not None:
                          if prompt_blend_mode == 'right_shift':
                            blended_c = torch.zeros_like(c)
                            next_prompts = []
                            for prompt_num in range(batch_size):
                              if prompt_blend_type == 'linear':
                                blended_c[prompt_num] = c[prompt_num] * prompt_interp_inv + c[(prompt_num+1) % batch_size] * prompt_interp
                              elif prompt_blend_type == 'spherical':
                                blended_c[prompt_num] = slerp(prompt_interp, c[prompt_num], c[(prompt_num+1) % batch_size])
                              next_prompts.append(prompts[(prompt_num+1) % batch_size])
                            c = blended_c
                          elif prompt_blend_mode == 'batch_shift':
                            if prompt_mode == 'per_sample':
                              next_begin = (batch_num+1)*batch_size
                              next_end   = (batch_num+2)*batch_size
                              next_prompts = []
                              for i in range(next_begin, next_end):
                                i = i % len(data)
                                next_prompts.append(data[i])
                            elif prompt_mode == 'per_batch':
                              next_prompts = [data[(batch_num+1)%len(data)]] * batch_size

                            if use_slot_modifiers:
                              for b in range(batch_size):
                                next_prompts[b] = slot_modifiers[b] + ', ' +next_prompts[b]
                            next_c = model.get_learned_conditioning(next_prompts)
                            

                            if prompt_blend_type == 'linear':
                              c = c * prompt_interp_inv + next_c * prompt_interp
                            elif prompt_blend_type == 'spherical':
                              c = slerp(prompt_interp, c, next_c)


                        shape = [C, H // f, W // f]
                        try:
                          samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                            conditioning=c,
                                                            batch_size=batch_size,
                                                            shape=shape,
                                                            verbose=False,
                                                            unconditional_guidance_scale=scale,
                                                            unconditional_conditioning=uc,
                                                            eta=ddim_eta,
                                                            x_T=start_code)
                        except Exception as e:
                          print(f'Exception in sampler.sample\n{e}')
                          continue


                        x_samples_ddim = model.decode_first_stage(samples_ddim)
                        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

                        all_images = []

                        if not skip_save:
                            for x_sample in x_samples_ddim:
                                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                                all_images.append(Image.fromarray(x_sample.astype(np.uint8)))

                            
                            grid_name = PromptFilename()
                            for prompt_string in prompts:
                              grid_name.add_prompt(prompt_string)
                              #grid_name.append(clean_filename(prompt_string)[:20])
                            cleaned_name = grid_name.get_string(max_len=150)
                            for suffix in range(0, 100000): #todo, cleanup
                              #gs means guidance scale
                              output_filepath = os.path.join(grid_path, f"{frame:06d}{cleaned_name}-gs_{scale}-{sampler_name}{ddim_steps}-pi{prompt_interp:.04f}-s{seed}-{suffix:04d}.jpg")
                              grid_prompt_filepath = os.path.join(grid_d_path, f'{cleaned_name}.txt')
                              if not os.path.isfile(output_filepath):
                                break

                            if prompt_blend_mode is not None and prompt_interp_inv != 1.0:
                              d_captions = ['\n'.join([f'{prompt_interp_inv:.02f} {prompts[i]}', f'{prompt_interp:.02f} {next_prompts[i]}']) for i in range(batch_size)]
                            else:
                              d_captions = prompts

                            if not skip_grid and batch_size > 1:
                              image_grid = make_image_grid_from_pil(all_images, H, W, blanks_ok=True)
                              image_grid.save(output_filepath, quality=jpg_quality, optimize=True)
                              image_grid = make_image_grid_from_pil(all_images, H, W, captions=d_captions, blanks_ok=True)
                              display_image(image_grid)
                              grid_filename = output_filepath.replace(grid_path, grid_d_path)
                              image_grid.save(grid_filename, quality=jpg_quality, optimize=True)
                              with io.open(grid_prompt_filepath, 'wt', encoding='utf-8') as caption_file:
                                caption_file.write('\n\n'.join(d_captions))
                              
                              #print(f'saved {output_filepath}')
                            save_captions = []
                            for image_num, image in enumerate(all_images):
                              image_name = PromptFilename()
                              image_name.add_prompt(prompts[image_num])
                              #cleaned_name = clean_filename(prompts[image_num])
                              #cleaned_name = f'{cleaned_name[:100]}-ph{hash(prompts[image_num])}'
                              save_caption = [prompts[image_num]]
                              if prompt_blend_mode is not None and prompt_interp_inv != 1.0:
                                #cleaned_name += '+' + clean_filename(next_prompts[image_num])
                                image_name.add_prompt(next_prompts[image_num])
                                save_caption.append(next_prompts[image_num])
                              cleaned_name = image_name.get_string(max_len=150)
                              save_captions.append(save_caption)
                              for suffix in range(0, 100000): #todo, cleanup
                                image_path = os.path.join(sample_path, f"{image_num:02d}-{frame:06d}{cleaned_name}-gs_{scale}-{sampler_name}{ddim_steps}-pi{prompt_interp:.04f}-s{seed}-{suffix:04d}.jpg")
                                image_prompt_path = os.path.join(sample_path, f"{cleaned_name}.txt")
                                if not os.path.isfile(image_path):
                                  break
                              # Show each image once: when grids are skipped, or when there is only one image
                              if skip_grid or batch_size == 1:
                                display_image(image)
                              image.save(image_path, quality=jpg_quality, optimize=True)
                              with io.open(image_prompt_path, 'wt', encoding='utf-8') as caption_file:
                                caption_file.write(d_captions[image_num])
                              #print(f'saved {image_path}')

                            sampling_dtime = time.time() - sampling_start_time
                            print(f'Generated {batch_size} samples in {sampling_dtime:.2f} s ({batch_size / sampling_dtime:.2f} samples per second)')
                            try:
                              with io.open('gen_time_log.csv', 'at') as gen_time_log:
                                log_line = f'{batch_size},{sampling_dtime},{batch_size / sampling_dtime},{scale},{sampler_name},{ddim_steps},{seed}\n'
                                gen_time_log.write(log_line)
                            except Exception as e:
                              print(f'Exception writing gen_time_log.csv:\n{e}') 
                        frame += 1

              toc = time.time()

print(f"Your samples are ready and waiting for you here: \n{outdir} \n"
      f" \nEnjoy.")
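
The start-code interpolation above uses `slerp` rather than a plain linear blend. The notebook defines its own `slerp` helper outside this section; the sketch below is a common NumPy formulation (an assumption, not necessarily identical to the notebook's version) that treats each array as one flattened vector. It shows why slerp suits Gaussian noise: halfway between two independent noise tensors, a linear blend shrinks the norm by roughly 1/sqrt(2), while slerp keeps it close to the original. The 4x64x64 shape is a hypothetical latent size chosen for illustration.

```python
import numpy as np

def slerp(t, v0, v1, dot_threshold=0.9995):
    """Spherically interpolate between same-shaped arrays v0 and v1 at fraction t.

    Both arrays are treated as single flattened vectors. When they are nearly
    parallel or anti-parallel, sin(theta) is close to zero, so fall back to a
    linear blend to avoid an unstable division.
    """
    a = v0.ravel()
    b = v1.ravel()
    cos_theta = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    if abs(cos_theta) > dot_threshold:
        return (1.0 - t) * v0 + t * v1
    theta = np.arccos(np.clip(cos_theta, -1.0, 1.0))
    sin_theta = np.sin(theta)
    # Weights that move along the great-circle arc between the two directions
    w0 = np.sin((1.0 - t) * theta) / sin_theta
    w1 = np.sin(t * theta) / sin_theta
    return w0 * v0 + w1 * v1

# Two hypothetical start codes shaped like one latent (4 channels, 64x64)
rng = np.random.default_rng(0)
code_a = rng.standard_normal((4, 64, 64))
code_b = rng.standard_normal((4, 64, 64))

print("norm of code_a:      ", np.linalg.norm(code_a))
print("norm of lerp at 0.5: ", np.linalg.norm(0.5 * code_a + 0.5 * code_b))
print("norm of slerp at 0.5:", np.linalg.norm(slerp(0.5, code_a, code_b)))
```

The argument order `(t, v0, v1)` mirrors the call `slerp(prompt_interp, ...)` in the loop above. For a batched start code, this version interpolates the whole batch as one vector; a per-sample variant would call it on each batch slice.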

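The `ddim_steps` loop above skips any step count that does not divide 1000. A plausible reason, assuming the sampler builds a uniform schedule with stride `1000 // steps` (as the CompVis DDIM code does for its uniform discretization), is that any other count silently visits a different number of timesteps than requested. The sketch below illustrates that; `uniform_ddim_timesteps` and `valid_step_counts` are hypothetical helpers, not functions from the notebook or the CompVis repo.

```python
def uniform_ddim_timesteps(num_steps, num_train_timesteps=1000):
    """Timesteps visited by an assumed uniform-stride DDIM schedule."""
    stride = num_train_timesteps // num_steps
    return list(range(0, num_train_timesteps, stride))

def valid_step_counts(min_steps, max_steps, num_train_timesteps=1000):
    """Step counts kept by the notebook's filter: exact divisors of the training timesteps."""
    return [s for s in range(min_steps, max_steps + 1) if num_train_timesteps % s == 0]

for steps in (50, 30):
    print(f"{steps} requested -> {len(uniform_ddim_timesteps(steps))} timesteps")
print(valid_step_counts(1, 100))
```

With this schedule, 50 steps gives exactly 50 timesteps, while 30 steps (stride 33) gives 31, so restricting `ddim_steps` to divisors keeps the step count in the output filenames honest.
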
# Video (todo)

# Prompt combiner
This isn't needed for anything above; it's just a little tool for combining prompt templates with a list of subjects.

subjects = [
"Beagle Puppy",
"Maine Coon Cat",
"Colorful macaw",
]

generated_prompts = []
for name in subjects:
  scenes = [
  f"An incredibly beautiful richly colored portrait illustration of the Tarot {name} in the style of stained glass by Alphonse Mucha as featured on Artstation",
  f"beautiful {name} portrait photo photograph face, artstation",
  f"Conceptual Portrait of {name}, 35mm Portrait, vivid gouache and oil matte character portrait by Josephine Wall and Kelly McKernan, inspired by Disney, Artstation, CGsociety, 3d shading, Character concept art, #oc, character development",
  f"extremely detailed depiction of {name} in ominous timeless space casting a spell that emits a colorful {name} by artstation warcraft",
  f"Portrait of Cybernetic Cyberpunk {name}, vivid character portrait by Kelly McKernan and Skeeva, Artstation, CGsociety, 3d shading, Character concept art, #oc, character development",
  f"{name} stands in front of three large mirrors isolated staring at the reflection, the reflection in the mirrors are of a {name}, each mirror depicting a different {name}, vivid gouache and oil 3D illustration by Greg Rutkowski and Mark Ryden, Artstation, zbrushcentral, cel shading, magic realism.",
  ]  
  for scene in scenes:
    print(scene)
    generated_prompts.append(scene)

save_generated_prompt_file = False #@param {type: "boolean"}
if save_generated_prompt_file:
  # Write one prompt per line once all subjects have been expanded
  with open('prompts.txt', 'wt', encoding='utf-8') as prompt_out_file:
    prompt_out_file.write('\n'.join(generated_prompts))
  # Also use the generated prompts as `data` for the next sampling run
  data = generated_prompts
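
The same expansion can be written without rebuilding the template list for every subject: keep the templates as plain strings with a `{name}` placeholder and fill them with `str.format` over `itertools.product`. This is an alternative sketch, not the notebook's code; `templates` and `combine_prompts` are hypothetical names, and only two of the templates above are reused for brevity.

```python
from itertools import product

subjects = ["Beagle Puppy", "Maine Coon Cat", "Colorful macaw"]

# Plain (non-f-string) templates; {name} is filled in later by str.format
templates = [
    "beautiful {name} portrait photo photograph face, artstation",
    "Portrait of Cybernetic Cyberpunk {name}, vivid character portrait by Kelly McKernan and Skeeva, Artstation, CGsociety, 3d shading, Character concept art, #oc, character development",
]

def combine_prompts(subjects, templates):
    """Fill every template with every subject, grouped by subject like the loop above."""
    return [template.format(name=name) for name, template in product(subjects, templates)]

generated_prompts = combine_prompts(subjects, templates)
for prompt in generated_prompts:
    print(prompt)
```

Because `str.format` treats braces as placeholders, any template that needs a literal `{` or `}` has to double it (`{{`, `}}`).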
