# **Mafia Diffusion v0.1**
[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer and the [Stability.ai](https://stability.ai/) Team. [K Diffusion](https://github.com/crowsonkb/k-diffusion) by [Katherine Crowson](https://twitter.com/RiversHaveWings). You need to get the ckpt file and put it on your Google Drive first to use this. It can be downloaded from [HuggingFace](https://huggingface.co/CompVis/stable-diffusion).  

---

Based on [deforum](https://discord.gg/upmXXsrwZc) notebook

Tweaked by Prof. R.J#1965

---

**2022-09-02 update**  
chloebubble#9999 [(Twitter)](https://twitter.com/0xCrung)
  
**changelog:**  
  
~ wrote/rewrote the file-saving functions; prompts are now saved in filenames and, optionally, to a text file; outputs are now organized into their own directories  
~ various bugs and annoyances fixed  
~ TODO: push the updated QoL (quality-of-life) functions; add a check to prevent repeated prompts from being written to the prompts file


In [None]:
#@markdown **NVIDIA GPU**
import subprocess
# Ask nvidia-smi for each GPU's name, total memory and free memory (CSV, no header row).
nvidia_smi_query = ['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader']
sub_p_res = subprocess.run(nvidia_smi_query, stdout=subprocess.PIPE).stdout.decode('utf-8')
print(sub_p_res)

# Setup and Config

In [None]:
#@markdown **Setup Environment**

setup_environment = True #@param {type:"boolean"}
print_subprocess = True #@param {type:"boolean"}

if setup_environment:
    import subprocess
    print("...setting up environment")
    # Pinned torch 1.11 / CUDA 11.3 stack; the CompVis dependencies are installed editable (-e).
    all_process = [['pip', 'install', 'torch==1.11.0+cu113', 'torchvision==0.12.0+cu113', 'torchaudio==0.11.0', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],
                   ['pip', 'install', 'omegaconf==2.1.1', 'einops==0.3.0', 'pytorch-lightning==1.4.2', 'torchmetrics==0.6.0', 'torchtext==0.2.3', 'transformers==4.19.2', 'kornia==0.6'],
                   ['git', 'clone', 'https://github.com/ProfRJ/stable-diffusion'],
                   ['pip', 'install', '-e', 'git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers'],
                   ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],
                   ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'resize-right', 'torchdiffeq', 'python-slugify'],
                 ]
    for process in all_process:
        # Merge stderr into stdout so pip/git error messages appear in the cell output.
        completed = subprocess.run(process, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
        running = completed.stdout.decode('utf-8')
        if print_subprocess:
            print(running)
        if completed.returncode != 0:
            # Keep going (e.g. re-running `git clone` into an existing folder fails harmlessly),
            # but make the failure visible instead of continuing silently.
            print(f"WARNING: exit code {completed.returncode} from: {' '.join(process)}")

    print(subprocess.run(['git', 'clone', 'https://github.com/deforum/k-diffusion/'], stdout=subprocess.PIPE, stderr=subprocess.STDOUT).stdout.decode('utf-8'))
    # Empty the package __init__ so `from k_diffusion import sampling/external` only imports
    # those submodules (presumably to avoid dependencies of the others — TODO confirm).
    with open('k-diffusion/k_diffusion/__init__.py', 'w') as f:
        f.write('')

In [None]:
#@markdown **Python Definitions**
import json
from IPython import display
from google.colab import output

import argparse, glob, os, pathlib, subprocess, sys, time
import cv2
import numpy as np
import pandas as pd
import random
import requests
import shutil
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.transforms.functional as TF
import math
from contextlib import contextmanager, nullcontext
from einops import rearrange, repeat
from itertools import islice
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from skimage.exposure import match_histograms
from torchvision.utils import make_grid
from tqdm import tqdm, trange
from types import SimpleNamespace
from torch import autocast
from os.path import exists
from slugify import slugify, smart_truncate
from slugify.__main__ import slugify_params, parse_args

# Make the repos fetched by the setup cell importable: pip's editable `-e git+...` installs
# check out under ./src/<name>, and the `git clone`s create ./stable-diffusion and ./k-diffusion.
# `helpers`, `ldm` and `qol_functions` presumably come from the ./stable-diffusion fork — TODO confirm.
sys.path.append('./src/taming-transformers')
sys.path.append('./src/clip')
sys.path.append('./stable-diffusion/')
sys.path.append('./k-diffusion')

from helpers import save_samples, sampler_fn
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

from k_diffusion import sampling
from k_diffusion.external import CompVisDenoiser

from qol_functions import split_weighted_subprompts, setres, get_output_folder, load_img, next_seed, split_batches_from_samples, get_gpu_information

class CFGDenoiser(nn.Module):
    """Classifier-free guidance wrapper around a k-diffusion denoiser.

    The wrapped model is evaluated once on a doubled batch (unconditional +
    conditional) and the two predictions are mixed with ``cond_scale``. When
    ``args.latent_truncation`` is set, a guided result that leaves
    ``[-args.threshold, args.threshold]`` is additionally clamped.
    """

    def __init__(self, model, args):
        super().__init__()
        self.inner_model = model
        self.latent_truncation = args.latent_truncation
        self.threshold = args.threshold

    def forward(self, x, sigma, uncond, cond, cond_scale):
        """Return ``uncond + (cond - uncond) * cond_scale`` for the denoised prediction."""
        x_in = torch.cat([x] * 2)
        sigma_in = torch.cat([sigma] * 2)
        cond_in = torch.cat([uncond, cond])
        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)

        res = uncond + (cond - uncond) * cond_scale
        if not self.latent_truncation:
            return res

        # .item() always yields a plain Python float; `0.0 + t.cpu().numpy()` can stay
        # float16/float32 under NumPy 2 (NEP 50) promotion rules.
        maxval = torch.max(res).item()
        minval = torch.min(res).item()
        if maxval < self.threshold and minval > -self.threshold:
            return res
        # Pull an out-of-range bound in to ~0.707x its value, bounded by 1 and the threshold.
        if maxval > self.threshold:
            maxval = min(max(1, 0.707*maxval), self.threshold)
        if minval < -self.threshold:
            minval = max(min(-1, 0.707*minval), -self.threshold)
        return torch.clamp(res, min=minval, max=maxval)

def make_callback(sampler, dynamic_threshold=None, static_threshold=None):
    """Build the per-step thresholding callback handed to the samplers.

    Args:
        sampler: sampler name. "plms" and "ddim" get a CompVis-style
            ``callback(img, i)``; every other name gets a k-diffusion-style
            ``callback(info)`` that reads the latent from ``info['x']``.
        dynamic_threshold: percentile used for dynamic thresholding, or None to skip it.
        static_threshold: clamp the latent to ``[-t, t]``, or None to skip it.

    Returns:
        A callback that modifies the latent in place and returns None.
    """
    def _apply_dynamic(latent, percentile):
        # Dynamic thresholding from the Imagen paper (May 2022): the chosen percentile of
        # |latent| over every non-batch axis, maxed across the batch and with 1.0.
        reduce_axes = tuple(range(1, latent.ndim))
        per_item = np.percentile(np.abs(latent.cpu()), percentile, axis=reduce_axes)
        scale = np.max(np.append(per_item, 1.0))
        latent.clamp_(-1*scale, scale)
        latent.div_(scale)

    def _apply_static(latent):
        latent.clamp_(-1*static_threshold, static_threshold)

    if sampler in ["plms", "ddim"]:
        def compvis_callback(img, i):
            # CompVis samplers: dynamic thresholding first, then static clamping.
            if dynamic_threshold is not None:
                _apply_dynamic(img, dynamic_threshold)
            if static_threshold is not None:
                _apply_static(img)
        return compvis_callback

    def kdiffusion_callback(args_dict):
        # Called as callback({'x': x, 'i': i, 'sigma': ..., 'sigma_hat': ..., 'denoised': ...}).
        # NOTE: static clamping first, then dynamic — the reverse of the CompVis order.
        latent = args_dict['x']
        if static_threshold is not None:
            _apply_static(latent)
        if dynamic_threshold is not None:
            _apply_dynamic(latent, dynamic_threshold)
    return kdiffusion_callback

def generate(args, return_latent=False, return_sample=False, return_c=False):
    """Generate images for ``args.prompt`` with the module-level ``model``.

    Samples are produced in ``args.batch_sequences`` passes whose sizes come from
    ``args.batch_size_schedule``; ``args.n_samples`` is overwritten with the size
    of the current pass. Uses the globals ``model`` and ``device``.

    Args:
        args: settings namespace built from ``DeforumArgs()``.
        return_latent: also append each pass's latent samples (cloned) to the results.
        return_sample: also append each pass's decoded samples, before clamping.
        return_c: also append each pass's conditioning tensor.

    Returns:
        A flat list. For each pass, in order: the optional latent / decoded-sample /
        conditioning tensors, then one ``PIL.Image`` per generated sample.

    Raises:
        ValueError: if ``args.sampler`` is not a supported sampler name.
    """
    k_samplers = {
        "klms": sampling.sample_lms,
        "dpm2": sampling.sample_dpm_2,
        "dpm2_ancestral": sampling.sample_dpm_2_ancestral,
        "heun": sampling.sample_heun,
        "euler": sampling.sample_euler,
        "euler_ancestral": sampling.sample_euler_ancestral,
    }
    use_k_sampler = args.sampler in k_samplers
    if not use_k_sampler and args.sampler not in ("plms", "ddim"):
        # Otherwise `samples` would never be assigned.
        raise ValueError(f"Unknown sampler: {args.sampler!r}")

    seed_everything(args.seed)
    os.makedirs(args.outdir, exist_ok=True)
    sampler = PLMSSampler(model) if args.sampler == 'plms' else DDIMSampler(model)

    # Identical for every pass, so built once.
    callback = make_callback(sampler=args.sampler,
                             dynamic_threshold=args.dynamic_threshold,
                             static_threshold=args.static_threshold)
    precision_scope = autocast if args.precision == "autocast" else nullcontext
    shape = [args.C, args.H // args.f, args.W // args.f]
    # Denoising steps actually run when starting from an init latent.
    t_enc = int((1.0 - args.strength) * args.steps)
    sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, verbose=False)

    results = []
    for batch_sequence in range(args.batch_sequences):
        batch_size = args.batch_size_schedule[batch_sequence]
        args.n_samples = batch_size

        init_latent = None
        if args.init_latent is not None:
            init_latent = args.init_latent
        elif args.init_sample is not None:
            init_latent = model.get_first_stage_encoding(model.encode_first_stage(args.init_sample))
        elif args.init_image:  # skips both None and ""
            init_image = load_img(args.init_image, shape=(args.W, args.H)).to(device)
            init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
            init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image))  # move to latent space

        # NOTE: drawn even for k-diffusion samplers (which ignore it) so the RNG stream —
        # and therefore the images produced for a given seed — stays the same.
        start_code = None
        if args.fixed_code and init_latent is None:
            start_code = torch.randn([batch_size, *shape], device=device)

        with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
            assert args.prompt is not None
            prompts = batch_size * [args.prompt]

            uc = None
            # CFGDenoiser always concatenates uncond with cond, so k-diffusion samplers need an
            # unconditional embedding even at scale == 1.0 (no guidance).
            if args.scale != 1.0 or use_k_sampler:
                uc = model.get_learned_conditioning(batch_size * [""])
            c = model.get_learned_conditioning(prompts)
            if args.init_c is not None:
                c = args.init_c

            if use_k_sampler:
                # NOTE(review): in upstream k-diffusion the 2nd positional parameter of
                # CompVisDenoiser is `quantize`; if the deforum fork matches, `args` acts as a
                # truthy flag here. Kept unchanged — confirm against the fork.
                model_wrap = CompVisDenoiser(model, args)
                sigmas = model_wrap.get_sigmas(args.steps)
                if args.use_init:
                    sigmas = sigmas[len(sigmas)-t_enc-1:]
                    x = init_latent + torch.randn([batch_size, *shape], device=device) * sigmas[0]
                else:
                    x = torch.randn([args.n_samples, *shape], device=device) * sigmas[0]
                model_wrap_cfg = CFGDenoiser(model_wrap, args)
                extra_args = {'cond': c, 'uncond': uc, 'cond_scale': args.scale}
                samples = k_samplers[args.sampler](model_wrap_cfg, x, sigmas, extra_args=extra_args,
                                                   disable=False, callback=callback)
            elif init_latent is not None:
                z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
                samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=args.scale,
                                         unconditional_conditioning=uc,)
            else:
                samples, _ = sampler.sample(S=args.steps,
                                            conditioning=c,
                                            batch_size=args.n_samples,
                                            shape=shape,
                                            verbose=False,
                                            unconditional_guidance_scale=args.scale,
                                            unconditional_conditioning=uc,
                                            eta=args.ddim_eta,
                                            x_T=start_code,
                                            img_callback=callback)

            if return_latent:
                results.append(samples.clone())

            x_samples = model.decode_first_stage(samples)
            if return_sample:
                results.append(x_samples.clone())

            # decoded samples are in [-1, 1]; map to [0, 1]
            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)

            if return_c:
                results.append(c.clone())

            for x_sample in x_samples:
                x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
                results.append(Image.fromarray(x_sample.astype(np.uint8)))
    return results

def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:
    """Convert an H x W x C image array (values assumed 0..255) to a 1 x C x H x W float16 tensor in [-1, 1]."""
    normalized = sample.astype(float) / 255.0 * 2 - 1
    batched = np.transpose(normalized[np.newaxis], (0, 3, 1, 2))
    return torch.from_numpy(batched.astype(np.float16))

def sample_to_cv2(sample: torch.Tensor) -> np.ndarray:
    """Convert a [-1, 1] C x H x W tensor (size-1 dims are squeezed) to an H x W x C uint8 array."""
    chw = sample.squeeze().cpu().numpy()
    hwc = chw.transpose(1, 2, 0).astype(np.float32)
    unit_range = (hwc * 0.5 + 0.5).clip(0, 1)
    return (unit_range * 255).astype(np.uint8)

In [None]:
#@markdown **Model and Output Paths**
# Local default paths; replaced by the Google Drive paths below when the drive mounts.
print("Local Path Variables:\n")

models_path = "/content/models" #@param {type:"string"}
output_path = "/content/output" #@param {type:"string"}

#@markdown **Google Drive Path Variables (Optional)**
mount_google_drive = True #@param {type:"boolean"}
force_remount = False

if mount_google_drive:
    from google.colab import drive # type: ignore
    try:
        drive_path = "/content/drive"
        drive.mount(drive_path,force_remount=force_remount)
        models_path_gdrive = "/content/drive/MyDrive/AI/models" #@param {type:"string"}
        output_path_gdrive = "/content/drive/MyDrive/DiffusionOutput/StableDiffusion" #@param {type:"string"}
        models_path = models_path_gdrive
        output_path = output_path_gdrive
    except Exception as mount_error:
        # Fall back to the local paths, but report why. `except Exception` (not a bare
        # `except:`) lets interrupts such as KeyboardInterrupt propagate.
        print("...error mounting drive or with drive path variables")
        print(f"   ({type(mount_error).__name__}: {mount_error})")
        print("...reverting to default path variables")

import os
os.makedirs(models_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)

print(f"models_path: {models_path}")
print(f"output_path: {output_path}")

In [None]:
#@markdown **Select and Load Model**

# Colab form fields: pick a known config/checkpoint file name, or "custom" to use the explicit paths below.
model_config = "v1-inference.yaml" #@param ["custom","v1-inference.yaml"]
model_checkpoint =  "sd-v1-4.ckpt" #@param ["custom","sd-v1-4-full-ema.ckpt","sd-v1-4.ckpt","sd-v1-3-full-ema.ckpt","sd-v1-3.ckpt","sd-v1-2-full-ema.ckpt","sd-v1-2.ckpt","sd-v1-1-full-ema.ckpt","sd-v1-1.ckpt"]
custom_config_path = "" #@param {type:"string"}
custom_checkpoint_path = "" #@param {type:"string"}

# Compare the checkpoint file against the expected SHA-256 in `model_map` below (skipped for "custom").
check_sha256 = True #@param {type:"boolean"}

# When True (and the checkpoint is valid), the model is loaded at the end of this cell.
load_on_run_all = True #@param {type: 'boolean'}
# Passed to load_model_from_config: converts the model to float16 before moving it to the device.
half_precision = True # needs to be fixed

# Expected SHA-256 hex digests of the known checkpoint files, keyed by file name.
model_map = {
    "sd-v1-4-full-ema.ckpt": {'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a'},
    "sd-v1-4.ckpt": {'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'},
    "sd-v1-3-full-ema.ckpt": {'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca'},
    "sd-v1-3.ckpt": {'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f'},
    "sd-v1-2-full-ema.ckpt": {'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a'},
    "sd-v1-2.ckpt": {'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d'},
    "sd-v1-1-full-ema.ckpt": {'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829'},
    "sd-v1-1.ckpt": {'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea'}
}

def wget(url, outputdir, check=True):
    """Download ``url`` into directory ``outputdir`` using the ``wget`` CLI and print its stdout.

    Args:
        url: URL to fetch.
        outputdir: target directory (passed to ``wget -P``).
        check: when True (default), raise ``subprocess.CalledProcessError`` if wget exits
            non-zero, so callers wrapping this in try/except can detect a failed download.
    """
    res = subprocess.run(['wget', url, '-v', '--show-progress', '-P', f'{outputdir}'], stdout=subprocess.PIPE)
    print(res.stdout.decode('utf-8'))
    if check:
        res.check_returncode()


# config path: the selected/custom YAML if it exists, otherwise the repo's bundled v1 inference config
ckpt_config_path = custom_config_path if model_config == "custom" else os.path.join(models_path, model_config)
if os.path.exists(ckpt_config_path):
    print(f"{ckpt_config_path} exists")
else:
    ckpt_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml"
print(f"Using config: {ckpt_config_path}")

# checkpoint path or download
ckpt_path = custom_checkpoint_path if model_checkpoint == "custom" else os.path.join(models_path, model_checkpoint)

if os.path.exists(ckpt_path):
    print(f"{ckpt_path} exists")
    ckpt_valid = True
elif model_checkpoint == "custom":
    # A custom path has no download source.
    print(f"Please download model checkpoint and place in {ckpt_path}")
    ckpt_valid = False
else:
    # NOTE(review): plain-HTTP download of a pickled checkpoint; keep check_sha256 enabled.
    download_link = f'http://batbot.ai/models/stable-diffusion/{model_checkpoint}'
    print(f"!wget -O {models_path}/{model_checkpoint} {download_link}")
    try:
        wget(download_link, models_path)
        # Trust the filesystem, not only the absence of an exception.
        ckpt_valid = os.path.exists(ckpt_path)
    except Exception as download_error:
        print(f"...download failed: {download_error}")
        # Remove any partial file so the next run retries; guard so the cleanup itself can't raise.
        if os.path.exists(ckpt_path):
            os.remove(ckpt_path)
        ckpt_valid = False
    if not ckpt_valid:
        print(f"Please download model checkpoint and place in {ckpt_path}")

if check_sha256 and model_checkpoint != "custom" and ckpt_valid:
    import hashlib
    print("\n...checking sha256")
    # Hash in 1 MiB chunks rather than reading the whole checkpoint into memory at once
    # (also avoids rebinding the builtins `bytes` and `hash` at module level).
    sha256 = hashlib.sha256()
    with open(ckpt_path, "rb") as f:
        for chunk in iter(lambda: f.read(1024 * 1024), b""):
            sha256.update(chunk)
    ckpt_hash = sha256.hexdigest()
    if model_map[model_checkpoint]["sha256"] == ckpt_hash:
        print("hash is correct\n")
    else:
        print("hash is not correct\n")
        ckpt_valid = False

if ckpt_valid:
    print(f"Using ckpt: {ckpt_path}")

def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):
    """Instantiate the model described by ``config`` and load the weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to ``instantiate_from_config``.
        ckpt: checkpoint path; the file must contain a ``state_dict`` entry.
        verbose: print missing / unexpected state-dict keys.
        device: device the model is moved to.
        half_precision: convert the model to float16 before moving it.

    Returns:
        The loaded model, in eval mode.

    NOTE: the checkpoint is always loaded with ``map_location`` from the form field
    below ("cuda" by default), independently of ``device``.
    """
    map_location = "cuda" #@param ["cpu", "cuda"]
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
    # only load checkpoints from trusted sources.
    pl_sd = torch.load(ckpt, map_location=map_location)
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    # load_state_dict copied the weights into the model's own parameters; drop the
    # checkpoint copy now so it isn't held while the model is converted and moved.
    del pl_sd, sd
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    if half_precision:
        model = model.half().to(device)
    else:
        model = model.to(device)
    model.eval()
    return model

if load_on_run_all and ckpt_valid:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    local_config = OmegaConf.load(f"{ckpt_config_path}")
    # load_model_from_config already moves the model to its default device ('cuda');
    # the .to(device) below only changes placement when that differs.
    model = load_model_from_config(
        local_config,
        f"{ckpt_path}",
        half_precision=half_precision,
    ).to(device)

### Temporary copy of the QoL helper module (`qol_functions`); these definitions override the imported ones

In [None]:
#@title
from torchvision.datasets.utils import download_url
from ldm.util import instantiate_from_config
import torch
import os
# todo ?
from google.colab import files
from IPython.display import Image as ipyimg
import ipywidgets as widgets
from PIL import Image
import numpy as np
from numpy import asarray
from einops import rearrange, repeat
import torch, torchvision
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import ismap
import time
from omegaconf import OmegaConf
import requests
import PIL
from PIL import Image
from torchvision.transforms import functional as TF
import math
import random
import subprocess

def add_noise(sample: torch.Tensor, noise_amt: float):
    """Return ``sample`` plus standard-normal noise scaled by ``noise_amt`` (input is not modified).

    The noise uses torch's default dtype and is created on ``sample``'s device.
    """
    noise = torch.randn(sample.shape, device=sample.device)
    return sample + noise * noise_amt

def split_weighted_subprompts(text):
    """
    Split ``"sub-prompt:weight sub-prompt:weight ..."`` into sub-prompts and weights.

    The text before each ':' becomes a sub-prompt, and the token after the ':' (up to
    the next space) is parsed as its float weight. A missing weight defaults to 1.0;
    an unparsable one also defaults to 1.0 and prints a warning. Trailing text with
    no ':' is added as a final sub-prompt with weight 1.0.

    Returns:
        (prompts, weights): two parallel lists.
    """
    prompts = []
    weights = []
    while text:
        if ":" not in text:
            # leftover text without a weight marker
            prompts.append(text)
            weights.append(1.0)
            break
        prompt, _, text = text.partition(":")
        weight_token, _, text = text.partition(" ")
        weight = 1.0
        if weight_token:
            try:
                weight = float(weight_token)
            except ValueError:
                print(f"Warning: '{weight_token}' is not a value, are you missing a space?")
        prompts.append(prompt)
        weights.append(weight)
    return prompts, weights

def setres(image_shape, W, H):
    """Translate an image-shape dropdown label into a ``(width, height)`` pair.

    Only the text before " |" is used and spaces are ignored, so "Large Square",
    "LargeSquare" and "Large Square | 768x768" all resolve the same way.

    Args:
        image_shape: dropdown label, e.g. "Square | 512x512" or "Custom".
        W: width returned for "Custom".
        H: height returned for "Custom".

    Returns:
        tuple: ``(width, height)``.

    Raises:
        ValueError: if the label names no known shape (instead of returning None,
            which would only fail later when the result is unpacked).
    """
    name, _, _ = image_shape.partition(' |')
    resolutions = {
        "Custom": (W, H),
        "Square": (512, 512),
        "LargeSquare": (768, 768),
        "Landscape": (704, 512),
        "LargeLandscape": (768, 640),  # matches the "768x640" dropdown label
        "Portrait": (512, 704),
        "LargePortrait": (640, 768),
    }
    key = name.replace(" ", "")
    if key not in resolutions:
        raise ValueError(f"Unknown image_shape: {image_shape!r}")
    return resolutions[key]

def get_output_folder(output_path,batch_folder=None):
    """Create (if needed) and return ``output_path/<YYYY-MM>[/batch_folder]``.

    Args:
        output_path: base output directory.
        batch_folder: optional sub-folder name; ``None`` or ``""`` means no sub-folder.

    Returns:
        The directory path, without a trailing separator.
    """
    out_path = os.path.join(output_path, time.strftime('%Y-%m'))
    # truthiness check (not `!= ""`) so the default None never reaches os.path.join
    if batch_folder:
        out_path = os.path.join(out_path, batch_folder)
    os.makedirs(out_path, exist_ok=True)
    return out_path

def get_prompts_folder(output_path,batch_folder=None,save_prompts_file=False):
    """Return ``output_path/<YYYY-MM>[/batch_folder]/prompts``.

    The path is always computed (callers store it in their settings); the directory
    is only created when ``save_prompts_file`` is True.

    Args:
        output_path: base output directory.
        batch_folder: optional sub-folder name; ``None`` or ``""`` means no sub-folder.
        save_prompts_file: create the directory on disk.

    Returns:
        The prompts directory path.
    """
    month_path = os.path.join(output_path, time.strftime('%Y-%m'))
    if batch_folder:
        prompts_folder = os.path.join(month_path, batch_folder, 'prompts')
    else:
        prompts_folder = os.path.join(month_path, 'prompts')
    if save_prompts_file:
        os.makedirs(prompts_folder, exist_ok=True)
    return prompts_folder


def load_img(path, shape):
    """Load an image from a local path or an http(s) URL as a normalized tensor.

    The image is converted to RGB, scaled (LANCZOS) so it covers ``shape``, then
    center-cropped to exactly ``shape``.

    Args:
        path: local file path or ``http://`` / ``https://`` URL.
        shape: target ``(width, height)``.

    Returns:
        float16 tensor of shape (1, 3, height, width) with values in [-1, 1].

    Raises:
        requests.HTTPError: if downloading a URL returns an error status.
    """
    if path.startswith(('http://', 'https://')):
        # A timeout keeps a stalled server from hanging the notebook; raise_for_status turns
        # an HTTP error page into a clear error instead of an image-decoding failure.
        response = requests.get(path, stream=True, timeout=60)
        response.raise_for_status()
        response.raw.decode_content = True  # undo gzip/deflate transfer encoding on the raw stream
        image = Image.open(response.raw).convert('RGB')
    else:
        image = Image.open(path).convert('RGB')

    # Scale so both dimensions reach the target. round() rather than int(): float error can put
    # the product just below the target (e.g. 511.999...), which center_crop would zero-pad.
    fac = max(shape[0] / image.size[0], shape[1] / image.size[1])
    image = image.resize((round(fac * image.size[0]), round(fac * image.size[1])), Image.LANCZOS)
    image = np.array(image).astype(np.float16) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    # center_crop expects (height, width), hence the reversed shape
    image = TF.center_crop(image, shape[::-1])
    return 2.*image - 1.
    

def make_grid(images):
    """Tile PIL images row-major into a near-square grid image.

    The grid has ceil(sqrt(n)) columns and as many rows as needed; every tile is placed
    using the mode and size of the first image. NOTE: this shadows the
    ``torchvision.utils.make_grid`` imported earlier; the render code calls
    ``torchvision.utils.make_grid`` explicitly.
    """
    first = images[0]
    tile_w, tile_h = first.size
    count = len(images)
    cols = math.ceil(count ** 0.5)
    rows = math.ceil(count / cols)

    grid = Image.new(first.mode, (tile_w * cols, tile_h * rows))
    for idx, tile in enumerate(images):
        row, col = divmod(idx, cols)
        grid.paste(tile, (tile_w * col, tile_h * row))
    return grid


def get_gpu_information(image_size):
    """Return how many samples of ``image_size`` pixels fit in one batch on this GPU.

    The first GPU's total memory (from nvidia-smi) selects the table
    ``./stable-diffusion/helpers/gpu-info/<MiB>.txt`` under the current directory. Each
    line is ``<samples> | <max pixels>``; lines are scanned in order, and the last entry
    whose max pixels is >= ``image_size`` before the first smaller one is returned.

    Raises:
        error_message: if no entry fits ``image_size``.
    """
    memory = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')
    # second CSV field of the first GPU, e.g. "15360 MiB" -> "15360"
    memory = memory.split(', ')[1].strip(' MiB')
    path = os.path.join(os.getcwd(), 'stable-diffusion', 'helpers', 'gpu-info', f'{memory}.txt')
    max_samples = 0
    with open(path, "r") as f:  # closes the handle (it was previously leaked)
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing empty line
            samples, max_res = line.split(' | ')[:2]
            if int(max_res) < image_size:
                break
            max_samples = int(samples)
    if max_samples == 0:
        raise error_message("Specified resolution is too large to fit on vram.")
    return max_samples


def split_batches_from_samples(n_samples, image_size, max_samples=None):
    """Split ``n_samples`` into GPU-sized batches.

    Args:
        n_samples: total number of samples requested.
        image_size: pixels per image (W*H), used to look up the per-batch limit.
        max_samples: per-batch limit; when None (default) it is looked up with
            ``get_gpu_information(image_size)``.

    Returns:
        ``(batch_sequences, batch_size_schedule)``: the schedule holds full batches of
        ``max_samples`` followed by one smaller remainder batch if needed, and
        ``batch_sequences`` is its length.
    """
    if max_samples is None:
        max_samples = get_gpu_information(image_size)
    if n_samples <= 0:
        return (0, [])
    full_batches, remainder = divmod(n_samples, max_samples)
    batch_size_schedule = [max_samples] * full_batches
    if remainder:
        batch_size_schedule.append(remainder)
    return (len(batch_size_schedule), batch_size_schedule)


def next_seed(args):
    """Advance ``args.seed`` according to ``args.seed_behavior`` and return it.

    'iter' increments the seed, 'fixed' keeps it, and any other value draws a new
    random seed in ``[0, 2**32 - 1]``.
    """
    if args.seed_behavior == 'iter':
        args.seed += 1
    elif args.seed_behavior == 'fixed':
        pass # always keep seed the same
    else:
        # randint's upper bound is inclusive, and 2**32 is outside the 32-bit seed range
        # accepted by numpy's legacy seeding.
        args.seed = random.randint(0, 2**32 - 1)
    return args.seed

class error_message(Exception):
    """Exception raised by the helpers in this cell, e.g. when a resolution will not fit in VRAM."""

# Run

In [None]:
# Prompts to render. Each entry gets args.n_batch batches and (optionally) one grid image,
# and the full list is appended to prompts.txt when save_prompts is on. If an entry uses
# "text:weight" syntax, only its first sub-prompt is rendered (render_image_batch takes
# element [0] of split_weighted_subprompts). Commented entries are kept as examples.
full_prompts = [
          # "cyberpunk horror render of a masterfully sculpted, beautifully decaying porcelain cyborg princess, 4k cgsociety portrait render by WLOP"
          #'steampunk transcendence by Greg Rutkowski and Alex Grey, cyberpunk transcendence by Ross Tran and Wojciech Siudmak, conceptartworld 4k, outsider art inspired high detail path traced architectural render, colossal aztec-cyberpunk hybrid deities',
          #'a magical portal to the vaporwave dimension, enchanted nostalgiacore 90s artwork by macintosh plus, greg rutkowski, steven belledin, 4k desktop wallpaper',
          'cyberpunk horror render of a decaying mecha-ghoul by WLOP, cgsociety 4k portrait render, highly detailed filmic render',
]

In [None]:
def DeforumArgs():
    """Collect the Colab form settings for a render run.

    Every local variable defined here becomes an attribute of the settings namespace
    (via ``locals()``) and is written to the settings file, so avoid adding temporaries.

    Returns:
        dict: setting name -> value.
    """
    #@markdown **Save & Display Settings**
    batch_name = "mechaghouls" #@param {type:"string"}
    outdir = get_output_folder(output_path, batch_name)
    save_settings = True #@param {type:"boolean"}
    save_samples = True #@param {type:"boolean"}
    display_samples = True #@param {type:"boolean"}
    save_prompts = True #@param {type:"boolean"}
    # Only resolve (and create) the prompts folder when prompts are actually saved.
    prompts_folder = get_prompts_folder(output_path, batch_name, save_prompts) if save_prompts else ""
    #@markdown **Image Settings**
    n_samples = 16 #@param {type:"integer"}
    # The text before " |" must match a name known to setres(), e.g. "Large Square".
    image_shape = "Square | 512x512" #@param ["Custom", "Square | 512x512", "Large Square | 768x768", "Landscape | 704x512", "Large Landscape | 768x640", "Portrait | 512x704", "Large Portrait | 640x768"]
    W = 512 #@param {type:"slider", min:256, max:1536, step:64}
    H = 512 #@param {type:"slider", min:256, max:1536, step:64}
    (W, H) = setres(image_shape, W, H)

    #@markdown **Init Settings**
    use_init = False #@param {type:"boolean"}
    strength = 0.25 #@param {type:"number"}
    init_image = "" #@param {type:"string"}

    #@markdown **Sampling Settings**
    seed = -1 #@param
    sampler = 'euler_ancestral' #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim"]
    steps = 25 #@param
    scale = 7 #@param
    ddim_eta = 0.0 #@param
    latent_truncation = True
    threshold = 2.5

    dynamic_threshold = None
    static_threshold = None

    #@markdown **Batch Settings**
    n_batch = 1 #@param
    seed_behavior = "random" #@param ["iter","fixed","random"]

    #@markdown **Grid Settings**
    make_grid = True #@param {type:"boolean"}

    (batch_sequences, batch_size_schedule) = split_batches_from_samples(n_samples, W*H)

    precision = 'autocast'
    fixed_code = True
    C = 4
    f = 8

    prompt = ""
    timestring = ""
    init_latent = None
    init_sample = None
    init_c = None

    return locals()


args = SimpleNamespace(**DeforumArgs())
args.timestring = time.strftime('%Y-%m-%d-%H%M')
args.strength = max(0.0, min(1.0, args.strength))


if args.seed == -1:
    # randint is inclusive on both ends; 2**32 - 1 is the largest valid 32-bit seed
    args.seed = random.randint(0, 2**32 - 1)
if not args.use_init:
    args.init_image = None
    args.strength = 0
if args.sampler == 'plms' and args.use_init:
    print(f"Init images aren't supported with PLMS yet, switching to KLMS")
    args.sampler = 'klms'
if args.sampler != 'ddim':
    args.ddim_eta = 0

# save full unsanitised prompts in a text file
prompts_path = None  # stays None when prompts are not saved
if args.save_prompts:
    os.makedirs(args.prompts_folder, exist_ok=True)
    prompts_list = '\n'.join(map(str, full_prompts))
    prompts_path = f'{args.prompts_folder}/prompts.txt'
    # append mode creates the file when missing, so no separate existence check is needed
    with open(prompts_path, 'a+', encoding='utf-8') as file:
        file.write(prompts_list + "\n")
  

def render_image_batch(args):
    """Render every prompt in ``full_prompts`` ``args.n_batch`` times for each init image.

    Depending on the flags in ``args``, each image is saved (``save_samples``) and
    displayed (``display_samples``), one grid per prompt is saved and displayed
    (``make_grid``), and the settings snapshot is written as JSON at the end
    (``save_settings``). ``args.seed`` is advanced with ``next_seed`` after every
    ``generate`` call.

    Raises:
        FileNotFoundError: if ``args.use_init`` is set but no ``args.init_image`` is given.
    """
    # create output folder for the batch
    os.makedirs(args.outdir, exist_ok=True)
    if args.save_settings or args.save_samples:
        print(f"Saving to {os.path.join(args.outdir, args.timestring)}_*")

    # snapshot the settings now; they are written once rendering has finished
    settings_filename = None
    if args.save_settings:
        outdir_settings = f'{args.outdir}/settings'
        os.makedirs(outdir_settings, exist_ok=True)
        settings_dict = dict(args.__dict__)
        settings_filename = os.path.join(outdir_settings, f"{args.timestring}_settings.txt")

    index = 0

    # init images: a URL, a single file, or every png/jpg/jpeg file in a directory
    init_array = []
    if args.use_init:
        if not args.init_image:  # covers both "" and None
            raise FileNotFoundError("No path was given for init_image")
        if args.init_image.startswith('http://') or args.init_image.startswith('https://'):
            init_array.append(args.init_image)
        elif not os.path.isfile(args.init_image):
            if args.init_image[-1] != "/": # avoids path error by adding / to end if not there
                args.init_image += "/"
            for image_name in sorted(os.listdir(args.init_image)):
                # case-insensitive so e.g. "photo.JPG" is picked up too
                if image_name.split(".")[-1].lower() in ("png", "jpg", "jpeg"):
                    init_array.append(args.init_image + image_name)
        else:
            init_array.append(args.init_image)
    else:
        init_array = [""]

    # when doing large batches don't flood browser with images
    clear_between_batches = args.n_batch >= 32 or args.n_samples*args.n_batch >= 32

    print('n_samples will run over', args.batch_sequences, 'batches with the following schedule', args.batch_size_schedule)
    for prompt in full_prompts:
        all_images = []

        for batch_index in range(args.n_batch):
            if clear_between_batches:
                display.clear_output(wait=True)
            print(f"Batch {batch_index+1} of {args.n_batch}")

            for init_image in init_array:
                args.init_image = init_image
                sub_prompts, sub_weights = split_weighted_subprompts(prompt)
                # only the first weighted sub-prompt is rendered
                args.prompt, args.weight = sub_prompts[0], sub_weights[0]
                results = generate(args)
                sanitised_prompt = slugify(args.prompt, replacements=[[',', '_'],['.','.']], separator='_', max_length=96)
                for image in results:
                    if args.make_grid:
                        all_images.append(T.functional.pil_to_tensor(image))
                    if args.save_samples:
                        filename = f'{args.timestring}_{index:05}_{sanitised_prompt}_{args.seed}.png'
                        image.save(os.path.join(args.outdir, filename))
                    if args.display_samples:
                        display.display(image)
                    index += 1
                args.seed = next_seed(args)

        # Make grid of entire prompt run (skipped when nothing was produced)
        if args.make_grid and all_images:
            outdir_grid = f'{args.outdir}/grid'
            os.makedirs(outdir_grid, exist_ok=True)
            grid = torchvision.utils.make_grid(all_images, nrow=int(math.ceil(len(all_images)**0.5)), padding=0)
            grid = rearrange(grid, 'c h w -> h w c').cpu().numpy()
            sanitised_prompt = slugify(args.prompt, replacements=[[',', '_'],['.','.']], separator='_', max_length=96)
            filename = f'{args.timestring}_grid_{sanitised_prompt}_{args.seed}.png'
            grid_image = Image.fromarray(grid.astype(np.uint8))
            grid_image.save(os.path.join(outdir_grid, filename))
            display.display(grid_image)

    # save settings for the batch
    if args.save_settings:
        with open(settings_filename, "w+", encoding="utf-8") as f:
            json.dump(dict(settings_dict), f, ensure_ascii=False, indent=4)
    # only report files that were actually written
    if args.save_prompts:
        print(f'Prompts saved to {args.prompts_folder}/prompts.txt')
    if args.save_settings:
        print(f'Settings saved to {settings_filename}')
render_image_batch(args)