# **Evolutionary Prompt-Mining based on Deforum Stable Diffusion v0.2**
[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer and the [Stability.ai](https://stability.ai/) Team. [K Diffusion](https://github.com/crowsonkb/k-diffusion) by [Katherine Crowson](https://twitter.com/RiversHaveWings). You need to get the ckpt file and put it on your Google Drive first to use this. It can be downloaded from [HuggingFace](https://huggingface.co/CompVis/stable-diffusion).

The aesthetics model that is an integral part of this method was made by [Katherine Crowson](https://twitter.com/RiversHaveWings) and can be found on her [Github account](https://github.com/crowsonkb/simulacra-aesthetic-models). 

Notebook by [Magnus Petersen](https://twitter.com/Omorfiamorphism). The baseline of the notebook (setup, description, and image generation) is based on the
[deforum](https://discord.gg/upmXXsrwZc) notebook.

# Setup

In [None]:
#@markdown **NVIDIA GPU**
# Print the name, total memory and free memory of each visible NVIDIA GPU
# (nvidia-smi emits one CSV line per GPU, without a header row).
# NOTE(review): subprocess.run raises FileNotFoundError if nvidia-smi is not on
# PATH (e.g. a CPU-only runtime).
import subprocess
sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')
print(sub_p_res)

In [None]:
#@markdown **Model and Output Paths**
# Resolve where checkpoints are read from and where outputs are written. The
# local defaults are replaced by the Google Drive paths when mounting succeeds.
print("Local Path Variables:\n")

models_path = "/content/models" #@param {type:"string"}
output_path = "/content/output" #@param {type:"string"}

#@markdown **Google Drive Path Variables (Optional)**
mount_google_drive = True #@param {type:"boolean"}
force_remount = False

if mount_google_drive:
    try:
        # Imported inside the try so that running outside Colab (no
        # google.colab module) falls back to the local paths instead of
        # crashing the cell.
        from google.colab import drive # type: ignore
        drive_path = "/content/drive"
        drive.mount(drive_path,force_remount=force_remount)
        models_path_gdrive = "/content/drive/MyDrive/AI/models" #@param {type:"string"}
        output_path_gdrive = "/content/drive/MyDrive/AI/StableDiffusion" #@param {type:"string"}
        models_path = models_path_gdrive
        output_path = output_path_gdrive
    except Exception as e:
        # `except Exception` (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate; report the cause before falling back.
        print("...error mounting drive or with drive path variables")
        print(f"...{type(e).__name__}: {e}")
        print("...reverting to default path variables")

import os
os.makedirs(models_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)

print(f"models_path: {models_path}")
print(f"output_path: {output_path}")

In [None]:
#@markdown **Setup Environment**

setup_environment = True #@param {type:"boolean"}
print_subprocess = False #@param {type:"boolean"}

# Install pinned Python dependencies and clone the repositories this notebook
# imports from (they are added to sys.path in the next cell).
# Each command's stdout is captured and printed only when print_subprocess is set.
# stderr is not captured, and return codes are not checked, so a failed
# command does not stop the cell.
if setup_environment:
    import subprocess, time
    print("Setting up environment...")
    start_time = time.time()
    all_process = [
        ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],
        ['pip', 'install', 'omegaconf==2.2.3', 'einops==0.4.1', 'pytorch-lightning==1.7.4', 'torchmetrics==0.9.3', 'torchtext==0.13.1', 'transformers==4.21.2', 'kornia==0.6.7'],
        ['git', 'clone', 'https://github.com/deforum/stable-diffusion'],
        ['pip', 'install', '-e', 'git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers'],
        ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],
        ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'matplotlib', 'resize-right', 'timm', 'torchdiffeq'],
        ['git', 'clone', 'https://github.com/shariqfarooq123/AdaBins.git'],
        ['git', 'clone', 'https://github.com/isl-org/MiDaS.git'],
        ['git', 'clone', 'https://github.com/MSFTserver/pytorch3d-lite.git'],
    ]
    for process in all_process:
        running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')
        if print_subprocess:
            print(running)

    # k-diffusion is cloned separately, and its clone output is always printed.
    print(subprocess.run(['git', 'clone', 'https://github.com/deforum/k-diffusion/'], stdout=subprocess.PIPE).stdout.decode('utf-8'))
    # Truncate k_diffusion/__init__.py to an empty file.
    # NOTE(review): presumably this keeps `import k_diffusion.external` from
    # running the package's own top-level imports; confirm against the deforum fork.
    with open('k-diffusion/k_diffusion/__init__.py', 'w') as f:
        f.write('')

    end_time = time.time()
    print(f"Environment set up in {end_time-start_time:.0f} seconds")

In [None]:
#@markdown **Python Definitions**
import json
from IPython import display

import gc, math, os, pathlib, subprocess, sys, time
import cv2
import numpy as np
import pandas as pd
import random
import requests
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from contextlib import contextmanager, nullcontext
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from skimage.exposure import match_histograms
from torchvision.utils import make_grid
from tqdm import tqdm, trange
from types import SimpleNamespace
from torch import autocast

# Make the repositories cloned/installed by the setup cell importable.
# Paths are relative to the working directory.
# NOTE(review): the 'src/...' entries are assumed to be where
# `pip install -e git+...#egg=...` places its checkouts by default; confirm for
# the pip version in use.
sys.path.extend([
    'src/taming-transformers',
    'src/clip',
    'stable-diffusion/',
    'k-diffusion',
    'pytorch3d-lite',
    'AdaBins',
    'MiDaS',
])

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# IPython shell escape: clone the EvoGen repository, whose Wordlists/*.csv
# files are read from /content/EvoGen-Prompt-Evolution when the model is loaded.
! git clone https://github.com/MagnusPetersen/EvoGen-Prompt-Evolution.git

from helpers import save_samples, sampler_fn
from k_diffusion.external import CompVisDenoiser
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

class CFGDenoiser(nn.Module):
    """Classifier-free guidance wrapper around a denoiser.

    The unconditional and conditional predictions are computed in a single
    batched call to the wrapped model and blended as
    ``uncond + (cond - uncond) * cond_scale``.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        # Stack [unconditional; conditional] along dim 0 so both predictions
        # come from one forward pass, then split them back apart.
        doubled_x = torch.cat((x, x))
        doubled_sigma = torch.cat((sigma, sigma))
        stacked_cond = torch.cat((uncond, cond))
        denoised = self.inner_model(doubled_x, doubled_sigma, cond=stacked_cond)
        out_uncond, out_cond = denoised.chunk(2)
        return out_uncond + (out_cond - out_uncond) * cond_scale

def add_noise(sample: torch.Tensor, noise_amt: float):
    """Return ``sample`` plus standard-normal noise scaled by ``noise_amt``.

    The noise comes from ``torch.randn`` on ``sample``'s device using the
    default dtype; ``sample`` itself is not modified.
    """
    gaussian = torch.randn(sample.shape, device=sample.device)
    return sample + gaussian * noise_amt

def get_output_folder(output_path, batch_folder):
    """Create (if needed) and return the dated output folder for a batch.

    The folder is ``<output_path>/<YYYY-MM>``, with ``batch_folder`` appended
    as a further subdirectory when it is a non-empty string.
    """
    parts = [output_path, time.strftime('%Y-%m')]
    if batch_folder != "":
        parts.append(batch_folder)
    folder = os.path.join(*parts)
    os.makedirs(folder, exist_ok=True)
    return folder

def load_img(path, shape, timeout=30):
    """Load an RGB image from a local path or http(s) URL as a model input tensor.

    Args:
        path: Local file path, or a URL starting with ``http://`` / ``https://``.
        shape: ``(width, height)`` passed to ``PIL.Image.resize``.
        timeout: Seconds to wait for the HTTP server when ``path`` is a URL.

    Returns:
        A float16 tensor of shape ``(1, 3, height, width)`` scaled to [-1, 1].

    Raises:
        requests.HTTPError: if the URL responds with an HTTP error status.
    """
    if path.startswith(('http://', 'https://')):
        # The timeout keeps a stalled download from hanging the notebook.
        # raise_for_status surfaces HTTP errors directly, instead of letting
        # PIL fail later with "cannot identify image file" on an error page.
        response = requests.get(path, stream=True, timeout=timeout)
        response.raise_for_status()
        image = Image.open(response.raw).convert('RGB')
    else:
        image = Image.open(path).convert('RGB')

    image = image.resize(shape, resample=Image.LANCZOS)
    image = np.array(image).astype(np.float16) / 255.0
    # (H, W, 3) -> (1, 3, H, W): add a batch axis and move channels first.
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.*image - 1.

def _match_histograms_channels_last(image, reference):
    """Histogram-match a channels-last ``image`` to ``reference``, per channel.

    scikit-image 0.19 deprecated ``multichannel=True`` in favour of
    ``channel_axis`` and later removed it. Use the new keyword, and fall back
    to the old one for installs that predate ``channel_axis``.
    """
    try:
        return match_histograms(image, reference, channel_axis=-1)
    except TypeError:
        return match_histograms(image, reference, multichannel=True)


def maintain_colors(prev_img, color_match_sample, mode):
    """Match the colour distribution of ``prev_img`` to ``color_match_sample``.

    Args:
        prev_img: RGB image array (channels last) to recolour.
        color_match_sample: RGB reference image array (channels last).
        mode: 'Match Frame 0 RGB' or 'Match Frame 0 HSV'. Any other value
            matches in LAB space.

    Returns:
        The recoloured image. For HSV/LAB modes the image is converted back to RGB.
    """
    if mode == 'Match Frame 0 RGB':
        return _match_histograms_channels_last(prev_img, color_match_sample)
    if mode == 'Match Frame 0 HSV':
        to_space, from_space = cv2.COLOR_RGB2HSV, cv2.COLOR_HSV2RGB
    else: # Match Frame 0 LAB
        to_space, from_space = cv2.COLOR_RGB2LAB, cv2.COLOR_LAB2RGB
    prev_converted = cv2.cvtColor(prev_img, to_space)
    reference_converted = cv2.cvtColor(color_match_sample, to_space)
    matched = _match_histograms_channels_last(prev_converted, reference_converted)
    return cv2.cvtColor(matched, from_space)

def make_callback(sampler, dynamic_threshold=None, static_threshold=None):
    """Build the per-step thresholding callback passed to the samplers.

    Args:
        sampler: Sampler name. "plms" and "ddim" (CompVis samplers) get a
            callback with signature ``(img, i)``. Every other name gets a
            k-diffusion style callback that takes a dict with key ``'x'``.
        dynamic_threshold: Percentile (0-100) for dynamic thresholding, or
            None to disable it.
        static_threshold: Bound ``t`` for clamping to ``[-t, t]``, or None to
            disable it.

    Returns:
        A callback that modifies the current latent tensor in place.
    """
    def dynamic_thresholding_(img, threshold):
        # Dynamic thresholding from Imagen paper (May 2022).
        # 1. Compute the `threshold` percentile of |img| for each sample.
        # 2. Take the largest of those, never below 1.0.
        # 3. Clamp to [-s, s] and rescale by s.
        abs_values = img.detach().abs().cpu().numpy()
        s = np.percentile(abs_values, threshold, axis=tuple(range(1, img.ndim)))
        s = float(np.max(np.append(s, 1.0)))
        img.clamp_(-s, s)
        img.div_(s)

    # Callback for samplers in the k-diffusion repo, called thus:
    #   callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
    def k_callback(args_dict):
        # NOTE(review): static is applied before dynamic here, but after it in
        # img_callback; confirm which order is intended when both are set.
        if static_threshold is not None:
            args_dict['x'].clamp_(-static_threshold, static_threshold)
        if dynamic_threshold is not None:
            dynamic_thresholding_(args_dict['x'], dynamic_threshold)

    # Callback for CompVis samplers: called with the image (img) and step (i).
    def img_callback(img, i):
        if dynamic_threshold is not None:
            dynamic_thresholding_(img, dynamic_threshold)
        if static_threshold is not None:
            img.clamp_(-static_threshold, static_threshold)

    return img_callback if sampler in ["plms", "ddim"] else k_callback

def generate(args, prompt_batch, return_latent=False, return_sample=False, return_c=False):
    """Run Stable Diffusion for one batch of prompts and collect the outputs.

    Relies on the module-level ``model`` and ``device`` set up by earlier cells.

    Args:
        args: Namespace of generation settings (see ``DeforumArgs``), e.g.
            ``seed``, ``sampler``, ``steps``, ``scale``, ``W`` / ``H``,
            ``n_samples``, plus the optional ``init_latent`` / ``init_sample`` /
            ``init_image`` / ``init_c`` overrides.
        prompt_batch: Prompt(s) passed to ``model.get_learned_conditioning``.
            A tuple is converted to a list first.
        return_latent: Also append a clone of the sampled latents.
        return_sample: Also append a clone of the decoded samples before they
            are rescaled/clamped to [0, 1].
        return_c: Also append a clone of the conditioning tensor.

    Returns:
        A list containing, in order, the optional latent, decoded-sample and
        conditioning tensors, followed by one decoded image tensor per sample
        with values clamped to [0, 1].

    Raises:
        ValueError: if no init latent is available and ``args.sampler`` is not
            a supported sampler name.
    """
    seed_everything(args.seed)
    os.makedirs(args.outdir, exist_ok=True)

    sampler = PLMSSampler(model) if args.sampler == 'plms' else DDIMSampler(model)
    model_wrap = CompVisDenoiser(model)
    batch_size = args.n_samples
    data = [prompt_batch]

    # Starting latent, in priority order:
    #   1. an explicit latent,
    #   2. an encoded init sample,
    #   3. an encoded init image (local path or URL).
    init_latent = None
    if args.init_latent is not None:
        init_latent = args.init_latent
    elif args.init_sample is not None:
        init_latent = model.get_first_stage_encoding(model.encode_first_stage(args.init_sample))
    elif args.init_image:
        init_image = load_img(args.init_image, shape=(args.W, args.H)).to(device)
        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

    sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, verbose=False)

    t_enc = int((1.0-args.strength) * args.steps)

    start_code = None
    if args.fixed_code and init_latent is None:
        start_code = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)

    callback = make_callback(sampler=args.sampler,
                             dynamic_threshold=args.dynamic_threshold,
                             static_threshold=args.static_threshold)

    results = []
    precision_scope = autocast if args.precision == "autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                for prompts in data:
                    uc = None
                    if args.scale != 1.0:
                        uc = model.get_learned_conditioning(batch_size * [""])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
                    c = model.get_learned_conditioning(prompts)

                    if args.init_c is not None:
                        c = args.init_c

                    if args.sampler in ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral"]:
                        # k-diffusion samplers receive init_latent / t_enc directly.
                        samples = sampler_fn(
                            c=c,
                            uc=uc,
                            args=args,
                            model_wrap=model_wrap,
                            init_latent=init_latent,
                            t_enc=t_enc,
                            device=device,
                            cb=callback)
                    elif init_latent is not None:
                        # CompVis sampler from an init latent: stochastically
                        # encode it at step t_enc, then decode from there.
                        # NOTE(review): PLMSSampler is not shown to provide
                        # stochastic_encode/decode; PLMS + init relies on the
                        # caller switching samplers.
                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
                        samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=args.scale,
                                                 unconditional_conditioning=uc,)
                    elif args.sampler in ('plms', 'ddim'):
                        shape = [args.C, args.H // args.f, args.W // args.f]
                        samples, _ = sampler.sample(S=args.steps,
                                                    conditioning=c,
                                                    batch_size=args.n_samples,
                                                    shape=shape,
                                                    verbose=False,
                                                    unconditional_guidance_scale=args.scale,
                                                    unconditional_conditioning=uc,
                                                    eta=args.ddim_eta,
                                                    x_T=start_code,
                                                    img_callback=callback)
                    else:
                        # Previously this fell through to an UnboundLocalError on `samples`.
                        raise ValueError(f"Unsupported sampler: {args.sampler!r}")

                    if return_latent:
                        results.append(samples.clone())

                    x_samples = model.decode_first_stage(samples)
                    if return_sample:
                        results.append(x_samples.clone())

                    # Rescale decoder output from [-1, 1] to [0, 1].
                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                    if return_c:
                        results.append(c.clone())

                    # One tensor per image (iterates along the batch dimension).
                    results.extend(x_samples)
    return results

def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:
    """Convert an HWC image array with values in [0, 255] into a tensor.

    Returns a float16 tensor of shape ``(1, C, H, W)`` with values mapped to [-1, 1].
    """
    scaled = sample.astype(float) / 255.0 * 2 - 1
    batched = np.expand_dims(scaled, 0)
    nchw = np.transpose(batched, (0, 3, 1, 2)).astype(np.float16)
    return torch.from_numpy(nchw)

def sample_to_cv2(sample: torch.Tensor) -> np.ndarray:
    """Convert a [-1, 1] CHW image tensor into an HWC uint8 array.

    Singleton dimensions are squeezed away first, so a ``(1, C, H, W)`` tensor
    is accepted. Values are mapped to [0, 1], clipped, scaled by 255 and
    truncated to uint8.
    """
    chw = sample.squeeze().cpu().numpy()
    hwc = rearrange(chw, "c h w -> h w c").astype(np.float32)
    unit = np.clip(hwc * 0.5 + 0.5, 0, 1)
    return (unit * 255).astype(np.uint8)

In [None]:
#@markdown **Select and Load Model**
from IPython.display import clear_output, display

model_config = "v1-inference.yaml" #@param ["custom","v1-inference.yaml"]
model_checkpoint =  "sd-v1-4.ckpt" #@param ["custom","sd-v1-4-full-ema.ckpt","sd-v1-4.ckpt","sd-v1-3-full-ema.ckpt","sd-v1-3.ckpt","sd-v1-2-full-ema.ckpt","sd-v1-2.ckpt","sd-v1-1-full-ema.ckpt","sd-v1-1.ckpt"]
custom_config_path = "" #@param {type:"string"}
custom_checkpoint_path = "" #@param {type:"string"}

check_sha256 = True #@param {type:"boolean"}

load_on_run_all = True #@param {type: 'boolean'}
half_precision = True # needs to be fixed

# Expected SHA-256 digests of the known checkpoints, used to verify the file.
model_map = {
    "sd-v1-4-full-ema.ckpt": {'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a'},
    "sd-v1-4.ckpt": {'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'},
    "sd-v1-3-full-ema.ckpt": {'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca'},
    "sd-v1-3.ckpt": {'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f'},
    "sd-v1-2-full-ema.ckpt": {'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a'},
    "sd-v1-2.ckpt": {'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d'},
    "sd-v1-1-full-ema.ckpt": {'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829'},
    "sd-v1-1.ckpt": {'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea'}
}

# Config path: use the chosen/custom config if it exists, otherwise fall back
# to the config shipped with the cloned stable-diffusion repo.
ckpt_config_path = custom_config_path if model_config == "custom" else os.path.join(models_path, model_config)
if os.path.exists(ckpt_config_path):
    print(f"{ckpt_config_path} exists")
else:
    ckpt_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
print(f"Using config: {ckpt_config_path}")

# Checkpoint path: the file must already be present (nothing is downloaded here).
ckpt_path = custom_checkpoint_path if model_checkpoint == "custom" else os.path.join(models_path, model_checkpoint)
ckpt_valid = True
if os.path.exists(ckpt_path):
    print(f"{ckpt_path} exists")
else:
    print(f"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}")
    ckpt_valid = False

if check_sha256 and model_checkpoint != "custom" and ckpt_valid:
    import hashlib
    print("\n...checking sha256")
    # Hash in 1 MiB chunks instead of reading the whole (large) checkpoint
    # into memory at once. This also avoids shadowing the builtins bytes/hash.
    sha256 = hashlib.sha256()
    with open(ckpt_path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha256.update(chunk)
    ckpt_hash = sha256.hexdigest()
    if model_map[model_checkpoint]["sha256"] == ckpt_hash:
        print("hash is correct\n")
    else:
        print("hash is not correct\n")
        ckpt_valid = False

if ckpt_valid:
    print(f"Using ckpt: {ckpt_path}")

def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):
    """Instantiate the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint containing a ``state_dict`` entry (and
            optionally ``global_step``).
        verbose: Print missing / unexpected state-dict keys.
        device: Device the model is moved to after loading.
        half_precision: Convert the model to float16 before moving it.

    Returns:
        The model in eval mode on ``device``.
    """
    map_location = "cuda" #@param ["cpu", "cuda"]
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location=map_location)
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False: key mismatches are reported below (when verbose), not raised.
    m, u = model.load_state_dict(sd, strict=False)
    # load_state_dict copies the weights into the model's own parameters, so
    # release the loaded checkpoint before converting and moving the model.
    # Otherwise both copies are alive at peak (on the GPU when map_location is
    # "cuda").
    del pl_sd, sd
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    if half_precision:
        model = model.half().to(device)
    else:
        model = model.to(device)
    model.eval()
    return model

# Load the diffusion model.
# load_model_from_config moves the model to its own `device` argument
# (default 'cuda'); the extra .to(device) then moves it to the device selected
# for this notebook.
if load_on_run_all and ckpt_valid:
    local_config = OmegaConf.load(f"{ckpt_config_path}")
    model = load_model_from_config(local_config, f"{ckpt_path}",half_precision=half_precision)
    model = model.to(device)

# Word lists from the cloned EvoGen repository.
# PromptGenerator samples from the `artist`, `genre` and `word` columns of these tables.
artists = pd.read_csv('/content/EvoGen-Prompt-Evolution/Wordlists/artists.csv')
genres = pd.read_csv('/content/EvoGen-Prompt-Evolution/Wordlists/genres.csv')
words = pd.read_csv('/content/EvoGen-Prompt-Evolution/Wordlists/wordlist.csv')
words_aes = pd.read_csv('/content/EvoGen-Prompt-Evolution/Wordlists/wordsprompt.csv')
engrams_aes = pd.read_csv('/content/EvoGen-Prompt-Evolution/Wordlists/engramprompt.csv')

clear_output()

In [None]:
#@title Aesthetics Helpers

from torchvision.transforms import functional as TF
import torch.nn.functional as F

# IPython shell escapes:
# - clone OpenAI CLIP and the simulacra aesthetic models (the aesthetic weights
#   are loaded from this clone further below);
# - install CLIP from the local checkout and make it importable.
!git clone https://github.com/openai/CLIP
!git clone https://github.com/crowsonkb/simulacra-aesthetic-models
!pip install -e ./CLIP
import sys
sys.path.append('./CLIP')

import clip
from torchvision import transforms
import matplotlib.pyplot as plt 

class AestheticMeanPredictionLinearModel(nn.Module):
    """Linear head mapping an embedding to a single aesthetic score.

    The input is L2-normalised along its last dimension and rescaled by
    ``sqrt(feats_in)`` before the linear layer.
    """

    def __init__(self, feats_in):
        super().__init__()
        self.linear = nn.Linear(feats_in, 1)

    def forward(self, input):
        width = input.shape[-1]
        scaled = F.normalize(input, dim=-1) * width ** 0.5
        return self.linear(scaled)

# CLIP image/text encoder. NOTE(review): presumably its image embeddings feed
# the aesthetic model below; the scoring code is not visible here.
clip_model_name = 'ViT-B/16'
clip_model = clip.load(clip_model_name, jit=False, device=device)[0]
clip_model.eval().requires_grad_(False)

# Per-channel mean/std normalisation. NOTE(review): assumed to be CLIP's
# preprocessing statistics.
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])

# 512 is embed dimension for ViT-B/16 CLIP
aes_model = AestheticMeanPredictionLinearModel(512)
# map_location="cpu" lets the weights load whatever device they were saved
# from (this notebook falls back to a CPU `device` when CUDA is unavailable).
# The model is moved to `device` right after.
aes_model.load_state_dict(
    torch.load("/content/simulacra-aesthetic-models/models/sac_public_2022_06_29_vit_b_16_linear.pth", map_location="cpu")
)

aes_model = aes_model.to(device)
clear_output()

In [None]:
#@title Evolution Helpers

class PromptGenerator:
    """Genetic algorithm over prompts represented as lists of words.

    Words are drawn from these module-level tables:
      - ``artists`` (column ``artist``),
      - ``genres`` (``genre``),
      - ``custom`` (``custom``),
      - a general word list (``word``) chosen by the module-level
        ``use_aes_words`` / ``use_aes_engrams`` flags.

    Each new generation is built in four steps:
      1. tournament selection;
      2. single-point cross-over;
      3. per-word mutation / deletion / insertion, plus an optional shuffle;
      4. duplicate children are replaced by fresh random prompts.
    """

    def __init__(self, population_count, prompt_length_max, prompt_length_min,
                 artist_prop, genre_prop, custom_prop, delete_prop, add_prop, mutate_prop,
                 shuffle_prop, cross_prop, k):
        """
        Args:
            population_count: Number of prompts per generation.
            prompt_length_max: Maximum word count of a random prompt (inclusive).
            prompt_length_min: Minimum word count of a random prompt (inclusive).
            artist_prop, genre_prop, custom_prop: Probability that a sampled
                word comes from the artist / genre / custom list. The remaining
                probability goes to the general word list.
            delete_prop: Per-word deletion probability during mutation.
            add_prop: Per-position insertion probability during mutation.
            mutate_prop: Per-word replacement probability during mutation.
            shuffle_prop: Probability of shuffling a mutated prompt.
            cross_prop: Probability that a parent pair exchanges tails.
            k: Tournament size used by ``selection``.
        """
        self.artists = artists
        self.genres = genres
        if use_aes_words and use_aes_engrams:
            # DataFrame.append was removed in pandas 2.0; concatenate instead.
            self.words = pd.concat([words_aes, engrams_aes], ignore_index=True)
        elif use_aes_words:
            self.words = words_aes
        elif use_aes_engrams:
            self.words = engrams_aes
        else:
            self.words = words

        self.custom = custom
        self.population_count = population_count
        self.prompt_length_max = prompt_length_max
        self.prompt_length_min = prompt_length_min

        self.artist_prop = artist_prop
        self.genre_prop = genre_prop
        self.custom_prop = custom_prop

        # Probability of drawing from the general word list (the complement of
        # all three named lists, including custom_prop).
        self.word_prop = 1 - self.artist_prop - self.genre_prop - self.custom_prop
        self.delete_prop = delete_prop
        self.add_prop = add_prop

        self.mutate_prop = mutate_prop
        self.shuffle_prop = shuffle_prop
        self.cross_prop = cross_prop
        self.k = k

        self.fittness_history = []

    def _random_word(self):
        """Draw one word, choosing its source list by the configured probabilities."""
        rand_num = np.random.random()
        if rand_num < self.artist_prop:
            word = self.artists.sample(1).artist.values[0]
        elif rand_num < self.artist_prop + self.genre_prop:
            word = self.genres.sample(1).genre.values[0]
        elif rand_num < self.artist_prop + self.genre_prop + self.custom_prop:
            word = self.custom.sample(1).custom.values[0]
        else:
            word = self.words.sample(1).word.values[0]
        # read_csv turns entries such as "nan" or "null" into float NaN. Coerce
        # to str so ' '.join in population_as_string never fails.
        return str(word)

    def _random_prompt(self):
        """Build a random prompt of prompt_length_min..prompt_length_max words (inclusive)."""
        # randint's upper bound is exclusive. The +1 makes prompt_length_max
        # reachable and allows prompt_length_min == prompt_length_max.
        length = np.random.randint(self.prompt_length_min, self.prompt_length_max + 1)
        return [self._random_word() for _ in range(length)]

    def initialize_prompt_population(self):
        """Fill ``self.prompt_population`` with ``population_count`` random prompts."""
        self.prompt_population = [self._random_prompt() for _ in range(self.population_count)]

    def selection(self, scores):
        """Tournament selection: return the best-scoring of ``k`` randomly drawn prompts."""
        selection_ix = np.random.randint(self.population_count)
        for ix in np.random.randint(0, self.population_count, self.k-1):
            if scores[ix] > scores[selection_ix]:
                selection_ix = ix
        return self.prompt_population[selection_ix]

    def cross_over(self, prompt_1, prompt_2):
        """With probability ``cross_prop``, swap the prompts' tails at a random point.

        Returns ``[child_1, child_2]``. Without cross-over these are the parent
        lists themselves.
        """
        c1, c2 = prompt_1, prompt_2
        if np.random.random() < self.cross_prop:
            shortest = min(len(prompt_1), len(prompt_2))
            # randint(0, 0) raises ValueError, so if either parent is empty,
            # cross at index 0.
            prompt_index = np.random.randint(0, shortest) if shortest > 0 else 0
            c1 = prompt_1[:prompt_index] + prompt_2[prompt_index:]
            c2 = prompt_2[:prompt_index] + prompt_1[prompt_index:]
        return [c1, c2]

    def mutate_prompts(self, prompt):
        """Return a mutated copy of ``prompt``; the argument itself is not modified.

        Steps, in order:
          1. insert one word if the prompt is empty;
          2. replace each word with probability ``mutate_prop``;
          3. delete binomial(len, delete_prop) distinct words, keeping at least
             two (or all, if there are fewer);
          4. insert new words with probability ``add_prop`` per position;
          5. shuffle with probability ``shuffle_prop``.
        """
        # Work on a copy. Selection can hand the same parent list to several
        # children, so in-place edits would leak into siblings and into the
        # current population.
        prompt = list(prompt)
        if not prompt:
            prompt.insert(0, self._random_word())

        for i in range(len(prompt)):
            if np.random.random() < self.mutate_prop:
                prompt[i] = self._random_word()

        delete_count = np.random.binomial(len(prompt), self.delete_prop)
        if len(prompt) - delete_count < 2:
            # Clamp at 0 so prompts shorter than two words never get a
            # negative count.
            delete_count = max(0, len(prompt) - 2)
        # Distinct indices, so exactly delete_count words are removed.
        drop_idx = set(np.random.choice(len(prompt), size=delete_count, replace=False).tolist())
        prompt = [word for idx, word in enumerate(prompt) if idx not in drop_idx]

        for i in range(len(prompt)):
            if np.random.random() < self.add_prop:
                prompt.insert(i, self._random_word())

        if np.random.random() < self.shuffle_prop:
            prompt = np.random.permutation(prompt).tolist()

        return prompt

    def create_next_generation(self, scores):
        """Replace ``self.prompt_population`` with the next generation.

        Args:
            scores: Fitness per prompt, indexed like ``self.prompt_population``.
        """
        selected = [self.selection(scores) for _ in range(self.population_count)]
        children = []
        for i in range(0, self.population_count, 2):
            # With an odd population the last parent is paired with the first.
            prompt_1, prompt_2 = selected[i], selected[(i + 1) % self.population_count]
            for c in self.cross_over(prompt_1, prompt_2):
                children.append(self.mutate_prompts(c))

        # Drop duplicate children, keeping first occurrences in order, and trim
        # the extra child an odd population size produces.
        seen = set()
        unique_children = []
        for child in children:
            key = tuple(child)
            if key not in seen:
                seen.add(key)
                unique_children.append(child)
        children = unique_children[:self.population_count]

        missing_prompts = self.population_count - len(children)
        print("The following number of duplicate prompts had to be replaced with random ones:"+str(missing_prompts))
        children.extend(self._random_prompt() for _ in range(missing_prompts))

        self.prompt_population = children

    def population_as_string(self):
        """Return each prompt of the population joined into a space-separated string."""
        return [' '.join(prompt) for prompt in self.prompt_population]

# Settings

In [None]:
custom_words = ["😀", "😃", "😄", "😁", "😆", "😅", "😂", "🤣", "🥲", "☺️", "😊", "😇", "🙂", "🙃", "😉", "😌", "😍", "🥰", "😘", "😗"]
# One-column table of user-supplied tokens. PromptGenerator samples from its
# "custom" column with probability custom_prop.
custom = pd.DataFrame({"custom": custom_words})

In [None]:
#@markdown **Evolutionary Algorithm Settings**

#@markdown General population settings, such as how many generations the algorithm runs for, how many prompts there are in each generation, and the word length range of the prompts.
generations = 25 #@param
n_samples = 12 #@param
population_count = 100 #@param
# Round population_count up to the nearest multiple of n_samples.
# Ceiling division keeps a value that is already a multiple unchanged; the
# previous `// n_samples + 1` added a whole extra n_samples in that case.
population_count = int(n_samples * -(-population_count // n_samples))
prompt_length_max = 15 #@param
prompt_length_min = 3 #@param
#@markdown Probability of sampling from one of the named word lists when adding or mutating a word. One minus the sum of these three probabilities is the probability of sampling from the general word list chosen below.
artist_prop = 0.04 #@param
genre_prop = 0.08 #@param
custom_prop = 0.0 #@param
#@markdown Choose the general word list used when none of the artist, genre, or custom lists is sampled: words from high-scoring prompts, 2-/3-grams of those prompts, or both. If neither is selected, a complete English dictionary is used.
use_aes_words = True #@param {type:"boolean"}
use_aes_engrams = True #@param {type:"boolean"}
#@markdown Generation evolution settings, including the probability to delete, add, and swap out each word for a new one from the dictionary.
delete_prop = 0.1 #@param
add_prop = 0.1 #@param
mutate_prop = 0.2 #@param
shuffle_prop = 0.1 #@param
#@markdown Settings for parent selection and breeding of the new generation. The cross-over probability is the probability of the parents swapping prompt parts. K is the tournament size in the tournament selection process. A higher K means fewer parents produce the next generation; this gives faster score increases but less diversity in the prompts.
cross_prop = 0.8 #@param
k = 4 #@param
#@markdown Cutoff score to save the image and prompt
cutoff = 5.5 #@param

prompt_generator = PromptGenerator(population_count, prompt_length_max, prompt_length_min,
                                  artist_prop, genre_prop, custom_prop, delete_prop, add_prop, mutate_prop,
                                  shuffle_prop, cross_prop, k)
prompt_generator.initialize_prompt_population()

# Per-generation history of summary statistics.
mean_score = []
best_score = []
mean_prompt_length = []

In [None]:
def DeforumArgs(n_samples):
    """Collect the image-generation settings for this run.

    Every local variable below becomes a key of the returned dict (via
    ``locals()``), which the caller wraps in a ``SimpleNamespace``. Adding a
    local here therefore adds a setting.

    Args:
        n_samples: Number of images generated per prompt.

    Returns:
        dict mapping setting names to their values.
    """
    #@markdown **Image Generation Settings**

    #@markdown Image generation settings have an impact on the speed and behavior of the evolutionary algorithm. The euler_a sampler with a low step count and resolution is advisable for quick prompt evolution. Good prompts can then later be used to generate higher-quality images. The parameter n_samples determines how many images are generated per prompt; the scores are then averaged. A higher n_samples slows down the generation but stabilizes the optimization.
    
    #@markdown **Save & Display Settings**
    batch_name = "StableFun" #@param {type:"string"}
    # Side effect: creates <output_path>/<YYYY-MM>/<batch_name> if it is missing.
    outdir = get_output_folder(output_path, batch_name)
    save_samples = False #@param {type:"boolean"}
    display_samples = False #@param {type:"boolean"}

    #@markdown **Image Settings**
    n_samples = n_samples
    W = 400 #@param
    H = 400 #@param
    W, H = map(lambda x: x - x % 64, (W, H))  # resize to integer multiple of 64

    #@markdown **Init Settings**
    use_init = False #@param {type:"boolean"}
    strength = 0.5 #@param {type:"number"}
    init_image = "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg" #@param {type:"string"}

    #@markdown **Sampling Settings**
    seed = -1 #@param
    sampler = 'euler_ancestral' #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim"]
    steps = 10 #@param
    scale = 7 #@param
    ddim_eta = 0.0 #@param
    # None disables the corresponding thresholding in make_callback.
    dynamic_threshold = None
    static_threshold = None   

    #@markdown **Batch Settings**
    seed_behavior = "random" #@param ["iter","fixed","random"]

    # "autocast" selects torch.autocast in generate(); anything else uses nullcontext.
    precision = 'autocast' 
    # When True and there is no init latent, generate() draws a random start code.
    fixed_code = True
    # Latent channels (C) and downsampling factor (f): generate() uses latent
    # shape (n_samples, C, H // f, W // f).
    C = 4
    f = 8

    prompt = ""
    timestring = ""
    # Optional overrides consumed by generate(); None means "not provided".
    init_latent = None
    init_sample = None
    init_c = None

    return locals()

args = SimpleNamespace(**DeforumArgs(n_samples))
args.timestring = time.strftime('%Y%m%d%H%M%S')
args.strength = max(0.0, min(1.0, args.strength))

if args.seed == -1:
    # random.randint is inclusive at both ends; keep the seed within the
    # unsigned 32-bit range [0, 2**32 - 1].
    args.seed = random.randint(0, 2**32 - 1)
if not args.use_init:
    args.init_image = None
    args.strength = 0
# use_init is a boolean flag. Only runs with an init image need the KLMS
# fallback. (Comparing it with the string 'None' was always true, which
# silently disabled PLMS.)
if args.sampler == 'plms' and args.use_init:
    print("Init images aren't supported with PLMS yet, switching to KLMS")
    args.sampler = 'klms'
if args.sampler != 'ddim':
    args.ddim_eta = 0

def next_seed(args):
    """Advance ``args.seed`` according to ``args.seed_behavior`` and return it.

    - 'iter' increments the seed.
    - 'fixed' leaves it unchanged.
    - Any other value (e.g. 'random') draws a fresh seed from [0, 2**32 - 1].
    """
    if args.seed_behavior == 'iter':
        args.seed += 1
    elif args.seed_behavior == 'fixed':
        pass # always keep seed the same
    else:
        # random.randint is inclusive at both ends, so 2**32 itself would fall
        # outside the unsigned 32-bit seed range.
        args.seed = random.randint(0, 2**32 - 1)
    return args.seed

def render_image_batch(args, prompts):
    """Render one batch of images for ``prompts`` with ``generate``.

    When ``args.use_init`` is set, ``args.init_image`` may be an http(s) URL,
    a single file, or a directory. For a directory, every .png/.jpg/.jpeg
    inside it is used, in sorted order, with extensions matched
    case-insensitively. ``generate`` is called once per init image; without
    ``use_init`` it is called once with no init image. After each call the
    seed is advanced with ``next_seed``.

    ``args.init_image`` is set per call but restored afterwards, so a
    directory is re-scanned on every call instead of being replaced by its
    last file.

    Args:
        args: the evolution-loop settings namespace (mutated: ``seed``).
        prompts: the prompts for this batch, passed straight to ``generate``.

    Returns:
        list: every result of every ``generate`` call, in order.

    Raises:
        FileNotFoundError: if ``use_init`` is set but ``init_image`` is empty.
    """
    original_init_image = args.init_image

    init_array = []
    if args.use_init:
        if args.init_image == "":
            raise FileNotFoundError("No path was given for init_image")
        if args.init_image.startswith(('http://', 'https://')):
            init_array.append(args.init_image)
        elif not os.path.isfile(args.init_image):
            # Treat it as a directory; ensure a trailing "/" so names can be appended.
            if not args.init_image.endswith("/"):
                args.init_image += "/"
            for name in sorted(os.listdir(args.init_image)):
                if name.split(".")[-1].lower() in ("png", "jpg", "jpeg"):
                    init_array.append(args.init_image + name)
        else:
            init_array.append(args.init_image)
    else:
        init_array = [""]  # a single run without an init image

    all_images = []
    try:
        for init_image in init_array:
            args.init_image = init_image
            all_images.extend(generate(args, prompts))
            args.seed = next_seed(args)
    finally:
        args.init_image = original_init_image

    return all_images

# Run

In [None]:
#@title Evolution Loop

def plot_fittness_history():
  """Draw the evolution progress as a 2x2 grid of subplots.

  Panels, left to right and top to bottom: mean score per generation, best
  score per generation, a 20-bin histogram of the current ``scores``, and the
  mean prompt length per generation. Reads the module-level ``mean_score``,
  ``best_score``, ``scores`` and ``mean_prompt_length``.
  """
  plt.figure(figsize=(16,10))
  plt.rcParams.update({'font.size': 12})

  # (draw as histogram?, data, title, x label, y label) for each panel.
  panels = [
      (False, mean_score, "Mean Score", "Generation", "Score"),
      (False, best_score, "Best Score", "Generation", "Score"),
      (True, scores, "Score Histogram", "Score", "Frequency"),
      (False, mean_prompt_length, "Mean Prompt Length", "Generation", "Prompt Length"),
  ]
  for position, (as_histogram, series, title, x_label, y_label) in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    if as_histogram:
      plt.hist(series, bins=20)
    else:
      plt.plot(series)
    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)

  plt.tight_layout()
  plt.show()

def plot_top_9():
  """Print the 9 best prompts of the current generation and show their images.

  Reads the module-level ``scores``, ``prompts``, ``image_population`` and
  ``gen_path``. Prompts are printed best first. The matching images are
  tiled into a 3x3 grid in column-major order (best image top-left, the
  2nd-best below it, and so on). The grid is displayed and saved as
  ``best_9.png`` in ``gen_path``. A population with fewer than 9 members is
  padded with black tiles instead of raising.
  """
  n_top = min(9, len(scores))
  # Indices of the highest scores, best first.
  top_idx = np.argsort(np.asarray(scores))[::-1][:n_top].tolist()
  print(*[prompts[i] for i in top_idx], sep = "\n")

  tiles = image_population[top_idx]
  if n_top < 9:
    padding = torch.zeros((9 - n_top, *tiles.shape[1:]), dtype=tiles.dtype, device=tiles.device)
    tiles = torch.cat([tiles, padding], dim=0)

  # (9, C, H, W) -> (3, C, H, 3W): row r holds images r, r+3, r+6 side by side.
  rows = torch.cat([tiles[i:i+3] for i in range(0, 9, 3)], dim=3)
  # (3, C, H, 3W) -> (1, C, 3H, 3W): stack the three rows vertically.
  grid = torch.cat([rows[i:i+1] for i in range(3)], dim=2)

  best_img = transforms.ToPILImage()(grid[0])
  display(best_img)
  best_img.save(os.path.join(gen_path, "best_9.png"))

def fittness_function(images):
  """Score PIL images with the CLIP-based aesthetic model.

  Each image is processed as follows:
  - resized so its shorter side is 224 (Lanczos) and center-cropped to 224x224;
  - converted to a tensor on ``device`` and passed through ``normalize``;
  - embedded with ``clip_model`` and L2-normalized;
  - scored by ``aes_model``, with the output averaged over dim 1.

  Args:
    images: sequence of PIL images. Any length works, including a final
      partial batch shorter than ``args.n_samples``.

  Returns:
    Tensor with one score per input image (assumes ``aes_model`` returns an
    (N, k) tensor -- TODO confirm). An empty tensor is returned for empty input.
  """
  if len(images) == 0:
    return torch.zeros(0, device=device)

  prepped = []
  for pil_img in images:
    img = TF.resize(pil_img, 224, transforms.InterpolationMode.LANCZOS)
    img = TF.center_crop(img, (224,224))
    img = TF.to_tensor(img).to(device)
    prepped.append(normalize(img))
  # Size the batch by the actual number of images rather than args.n_samples.
  clip_preped_images = torch.stack(prepped)

  clip_image_embed = F.normalize(
      clip_model.encode_image(clip_preped_images).float(),
      dim=-1)
  scores = aes_model(clip_image_embed).mean(axis = 1)
  return scores

def images_to_pil(images):
  """Convert a sequence of CHW image tensors into a list of PIL images.

  Each tensor is moved to the CPU, transposed to HWC, and multiplied by 255.
  It is then truncated to uint8. Values are presumably in [0, 1] -- confirm
  against the caller.
  """
  def _tensor_to_pil(chw):
    hwc = np.transpose(chw.cpu().numpy(), (1, 2, 0))
    return Image.fromarray((255. * hwc).astype(np.uint8))

  return [_tensor_to_pil(chw) for chw in images]

with torch.no_grad():
  for i in range(generations):
    # Output layout: <batch folder>/gen_<i>/ for this generation, and
    # gen_<i>/best/ for images scoring >= cutoff (each with a prompt/seed .txt).
    gen_path = os.path.join(get_output_folder(output_path, args.batch_name), 'gen_'+str(i))
    best_path = os.path.join(gen_path, 'best')
    os.makedirs(best_path, exist_ok=True)  # also creates gen_path

    prompts = prompt_generator.population_as_string()
    image_population = torch.zeros(size = (prompt_generator.population_count, 3, args.H, args.W))
    scores = torch.zeros(prompt_generator.population_count)

    for j in range(0, prompt_generator.population_count, args.n_samples):
      gc.collect()
      torch.cuda.empty_cache()

      # render_image_batch advances args.seed after rendering, so remember the
      # seed this batch was actually generated with. It reproduces the whole
      # batch (all images of the batch share it).
      batch_seed = args.seed
      images = render_image_batch(args, prompts[j:(j+args.n_samples)])
      images_pil = images_to_pil(images)
      n_images = len(images)
      # Slice so a final partial batch only fills the slots it rendered.
      scores[j:(j+n_images)] = fittness_function(images_pil)[:n_images]

      for k, (image, image_pil) in enumerate(zip(images, images_pil)):
        prompt = prompts[j+k]
        image_population[j+k] = image

        if args.display_samples:
          print(prompt)
          display(image_pil)

        if args.save_samples:
          image_pil.save(os.path.join(gen_path, prompt+".png"))

        if scores[j+k] >= cutoff:
          # Long prompts are truncated to 150 characters for the file name.
          stem = os.path.join(best_path, prompt[:150])
          image_pil.save(stem+'.png')
          with open(stem+'.txt', 'w') as txt_file:
            txt_file.write(prompt)
            txt_file.write('\n')
            txt_file.write(str(batch_seed))

    clear_output(wait=True)
    mean_score.append(scores.mean().item())
    best_score.append(scores.max().item())
    mean_prompt_length.append(np.mean([len(prompt) for prompt in prompt_generator.prompt_population]))

    plot_fittness_history()
    plot_top_9()

    prompt_generator.create_next_generation(scores)

# Re-render Prompts at Higher Quality

In [None]:
# Prompts to re-render with the high-quality settings (one entry per prompt).
prompts = [
    "epiphenomenalism Lovis Corinth carders",
]

In [None]:
def DeforumArgs():
    """Return the settings for the high-quality re-render as a dict.

    The dict is built from ``locals()``, so every local variable assigned in
    this function becomes a setting. When ``save_settings`` is True, the
    settings are also written to the ``<timestring>_settings.txt`` file.
    Do not add temporary helper variables here; they would leak into the
    settings.

    NOTE: this redefines the ``DeforumArgs(n_samples)`` used by the evolution
    loop. It takes no arguments and has its own ``n_samples`` setting.
    """
    #@markdown **Save & Display Settings**
    batch_name = "StableFun" #@param {type:"string"}
    outdir = get_output_folder(output_path, batch_name)
    save_settings = True #@param {type:"boolean"}
    save_samples = False #@param {type:"boolean"}
    display_samples = True #@param {type:"boolean"}

    #@markdown **Image Settings**
    # images generated per generate_non_batch call (one prompt, one seed)
    n_samples = 1 #@param
    W = 512 #@param
    H = 512 #@param
    W, H = map(lambda x: x - x % 64, (W, H))  # resize to integer multiple of 64

    #@markdown **Init Settings**
    use_init = False #@param {type:"boolean"}
    strength = 0.5 #@param {type:"number"}
    init_image = "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg" #@param {type:"string"}

    #@markdown **Sampling Settings**
    # -1 means a random seed is drawn when args_high_res is built
    seed = -1 #@param
    sampler = 'klms' #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim"]
    steps = 75 #@param
    scale = 7 #@param
    ddim_eta = 0.0 #@param
    # passed to make_callback by generate_non_batch
    dynamic_threshold = None
    static_threshold = None   

    #@markdown **Batch Settings**
    # how many times each prompt is rendered (per init image)
    n_batch = 1 #@param
    seed_behavior = "iter" #@param ["iter","fixed","random"]

    #@markdown **Grid Settings**
    make_grid = False #@param {type:"boolean"}
    grid_rows = 2 #@param 

    # "autocast" wraps sampling in autocast("cuda"); anything else uses nullcontext
    precision = 'autocast' 
    # when no init latent is used, draw the starting noise explicitly
    fixed_code = True
    # starting noise shape is (n_samples, C, H // f, W // f)
    C = 4
    f = 8

    # filled in / overridden per render by render_high_quality_image_batch
    # and generate_non_batch
    prompt = ""
    timestring = ""
    init_latent = None
    init_sample = None
    init_c = None

    return locals()


args_high_res = SimpleNamespace(**DeforumArgs())
args_high_res.timestring = time.strftime('%Y%m%d%H%M%S')
# Clamp the img2img strength into [0, 1].
args_high_res.strength = max(0.0, min(1.0, args_high_res.strength))

# -1 requests a random seed. random.randint is inclusive, so cap at 2**32 - 1,
# the largest value NumPy's seeding accepts.
if args_high_res.seed == -1:
    args_high_res.seed = random.randint(0, 2**32 - 1)
if not args_high_res.use_init:
    args_high_res.init_image = None
    args_high_res.strength = 0
# PLMS cannot start from an init image; fall back to KLMS only in that case.
if args_high_res.sampler == 'plms' and args_high_res.use_init:
    print("Init images aren't supported with PLMS yet, switching to KLMS")
    args_high_res.sampler = 'klms'
# ddim_eta is zeroed for every sampler other than DDIM.
if args_high_res.sampler != 'ddim':
    args_high_res.ddim_eta = 0

def next_seed(args_high_res):
    """Advance ``args_high_res.seed`` per its ``seed_behavior`` and return it.

    - ``'iter'``: increment the seed, wrapping back to 0 after 2**32 - 1.
    - ``'fixed'``: leave the seed unchanged.
    - anything else (e.g. ``'random'``): draw a fresh seed in [0, 2**32 - 1].

    The namespace is mutated in place; the new seed is also returned.
    """
    if args_high_res.seed_behavior == 'iter':
        # Keep the seed inside NumPy's valid seed range [0, 2**32 - 1].
        args_high_res.seed = (args_high_res.seed + 1) % 2**32
    elif args_high_res.seed_behavior == 'fixed':
        pass  # always keep seed the same
    else:
        # randint is inclusive on both ends, so 2**32 itself must be excluded.
        args_high_res.seed = random.randint(0, 2**32 - 1)
    return args_high_res.seed

def generate_non_batch(args, return_latent=False, return_sample=False, return_c=False):
    """Generate ``args.n_samples`` images for the single prompt ``args.prompt``.

    Seeds the RNGs from ``args.seed``. An optional starting latent is taken
    from ``args.init_latent``, else encoded from ``args.init_sample``, else
    encoded from ``args.init_image`` (file/URL loaded via ``load_img``).
    Sampling then uses ``args.sampler``: k-diffusion samplers go through
    ``sampler_fn``; otherwise the DDIM/PLMS sampler is used.

    Args:
        args: settings namespace built from ``DeforumArgs()``.
        return_latent: also return the raw latent samples.
        return_sample: also return the decoded samples (before they are
            rescaled to [0, 1]).
        return_c: also return the conditioning tensor.

    Returns:
        list: the optional latent / decoded-sample / conditioning tensors,
        in that order as requested by the flags, followed by one PIL image
        per sample.

    Raises:
        ValueError: if ``args.prompt`` is None, or if ``args.sampler`` is not a
            supported sampler and no init latent is available.
    """
    seed_everything(args.seed)
    os.makedirs(args.outdir, exist_ok=True)

    if args.sampler == 'plms':
        sampler = PLMSSampler(model)
    else:
        sampler = DDIMSampler(model)

    # wrapper used by the k-diffusion samplers (via sampler_fn)
    model_wrap = CompVisDenoiser(model)
    batch_size = args.n_samples
    prompt = args.prompt
    if prompt is None:
        raise ValueError("args.prompt must be set before calling generate_non_batch")
    data = [batch_size * [prompt]]

    init_latent = None
    if args.init_latent is not None:
        init_latent = args.init_latent
    elif args.init_sample is not None:
        init_latent = model.get_first_stage_encoding(model.encode_first_stage(args.init_sample))
    elif args.init_image:  # neither None nor ''
        init_image = load_img(args.init_image, shape=(args.W, args.H)).to(device)
        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

    sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, verbose=False)

    # encode/decode depth used with an init latent, derived from (1 - strength)
    t_enc = int((1.0-args.strength) * args.steps)

    start_code = None
    if args.fixed_code and init_latent is None:
        start_code = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)

    callback = make_callback(sampler=args.sampler,
                            dynamic_threshold=args.dynamic_threshold,
                            static_threshold=args.static_threshold)

    results = []
    precision_scope = autocast if args.precision == "autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                for batch_prompts in data:
                    uc = None
                    if args.scale != 1.0:
                        # unconditional embedding for classifier-free guidance
                        uc = model.get_learned_conditioning(batch_size * [""])
                    if isinstance(batch_prompts, tuple):
                        batch_prompts = list(batch_prompts)
                    c = model.get_learned_conditioning(batch_prompts)

                    if args.init_c is not None:
                        c = args.init_c

                    if args.sampler in ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral"]:
                        samples = sampler_fn(
                            c=c,
                            uc=uc,
                            args=args,
                            model_wrap=model_wrap,
                            init_latent=init_latent,
                            t_enc=t_enc,
                            device=device,
                            cb=callback)
                    elif init_latent is not None:
                        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
                        samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=args.scale,
                                                unconditional_conditioning=uc,)
                    elif args.sampler in ('plms', 'ddim'):
                        shape = [args.C, args.H // args.f, args.W // args.f]
                        samples, _ = sampler.sample(S=args.steps,
                                                    conditioning=c,
                                                    batch_size=args.n_samples,
                                                    shape=shape,
                                                    verbose=False,
                                                    unconditional_guidance_scale=args.scale,
                                                    unconditional_conditioning=uc,
                                                    eta=args.ddim_eta,
                                                    x_T=start_code,
                                                    img_callback=callback)
                    else:
                        # Previously this fell through with `samples` unbound (NameError).
                        raise ValueError(f"Unsupported sampler: {args.sampler!r}")

                    if return_latent:
                        results.append(samples.clone())

                    x_samples = model.decode_first_stage(samples)
                    if return_sample:
                        results.append(x_samples.clone())

                    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

                    if return_c:
                        results.append(c.clone())

                    for x_sample in x_samples:
                        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                        image = Image.fromarray(x_sample.astype(np.uint8))
                        results.append(image)
    return results

def render_high_quality_image_batch(args_high_res):
    """Render every prompt in the module-level ``prompts`` list at high quality.

    For each prompt, ``generate_non_batch`` runs ``n_batch`` times per init
    image, and the seed is advanced with ``next_seed`` after each run.
    Depending on the settings:
    - images are saved to ``outdir`` as ``<timestring>_<index>_<seed>.png``;
    - images are displayed inline;
    - the images of each prompt are tiled into one grid image.
    When ``save_settings`` is set, the full settings are written once to
    ``<timestring>_settings.txt``.

    Args:
        args_high_res: namespace built from ``DeforumArgs()``. Mutated in place
            (``prompts``, ``prompt``, ``init_image`` and ``seed``).

    Raises:
        FileNotFoundError: if ``use_init`` is set but ``init_image`` is empty.
    """
    args_high_res.prompts = prompts

    # create output folder for the batch
    os.makedirs(args_high_res.outdir, exist_ok=True)
    if args_high_res.save_settings or args_high_res.save_samples:
        print(f"Saving to {os.path.join(args_high_res.outdir, args_high_res.timestring)}_*")

    # save settings for the batch
    if args_high_res.save_settings:
        filename = os.path.join(args_high_res.outdir, f"{args_high_res.timestring}_settings.txt")
        with open(filename, "w+", encoding="utf-8") as f:
            json.dump(dict(args_high_res.__dict__), f, ensure_ascii=False, indent=4)

    index = 0

    # Resolve the init images. The source is an http(s) URL, a single file,
    # or every .png/.jpg/.jpeg in a directory (sorted). Without use_init, a
    # single empty entry means "no init image".
    init_array = []
    if args_high_res.use_init:
        if args_high_res.init_image == "":
            raise FileNotFoundError("No path was given for init_image")
        if args_high_res.init_image.startswith(('http://', 'https://')):
            init_array.append(args_high_res.init_image)
        elif not os.path.isfile(args_high_res.init_image):
            if not args_high_res.init_image.endswith("/"):  # so file names can be appended
                args_high_res.init_image += "/"
            for name in sorted(os.listdir(args_high_res.init_image)):
                if name.split(".")[-1].lower() in ("png", "jpg", "jpeg"):
                    init_array.append(args_high_res.init_image + name)
        else:
            init_array.append(args_high_res.init_image)
    else:
        init_array = [""]

    # when doing large batches don't flood browser with images
    clear_between_batches = args_high_res.n_batch >= 32

    for iprompt, prompt in enumerate(prompts):
        args_high_res.prompt = prompt

        all_images = []

        for batch_index in range(args_high_res.n_batch):
            if clear_between_batches:
                # `display` is the display() function here, so call clear_output directly.
                clear_output(wait=True)
            print(f"Batch {batch_index+1} of {args_high_res.n_batch}")

            for init_image in init_array:
                args_high_res.init_image = init_image
                for image in generate_non_batch(args_high_res):
                    if args_high_res.make_grid:
                        all_images.append(T.functional.pil_to_tensor(image))
                    if args_high_res.save_samples:
                        filename = f"{args_high_res.timestring}_{index:05}_{args_high_res.seed}.png"
                        image.save(os.path.join(args_high_res.outdir, filename))
                    if args_high_res.display_samples:
                        display(image)
                    index += 1
                args_high_res.seed = next_seed(args_high_res)

        if args_high_res.make_grid and all_images:
            # make_grid's nrow is the number of images per row. Guard against 0
            # when fewer images than grid_rows were produced.
            nrow = max(1, len(all_images) // args_high_res.grid_rows)
            grid = make_grid(all_images, nrow=nrow)
            grid = rearrange(grid, 'c h w -> h w c').cpu().numpy()
            filename = f"{args_high_res.timestring}_{iprompt:05d}_grid_{args_high_res.seed}.png"
            grid_image = Image.fromarray(grid.astype(np.uint8))
            grid_image.save(os.path.join(args_high_res.outdir, filename))
            clear_output(wait=True)
            display(grid_image)

# Re-render every prompt in `prompts` with the high-quality settings.
render_high_quality_image_batch(args_high_res=args_high_res)