# CompVis Stable Diffusion txt2img Colab notebook

### A fast and high fidelity diffusion model for the masses.

# Step 1.

### Download the model checkpoint and upload it to your Google Drive (make sure you have enough free space — the checkpoint is several gigabytes). Verify its SHA-256 hash with `sha256sum <checkpoint>.ckpt` to make sure the file was not corrupted.

### This Reddit post explains how to download the checkpoints: [Reddit link](https://www.reddit.com/r/StableDiffusion/comments/wv4sqt/checkpoint_v14_mirror_no_huggingface_account/)

### If you haven't already, switch to a GPU runtime (Runtime → Change runtime type → GPU).


##### *As always, I am simply standing on the shoulders of giants here. A massive thank you to CompVis, LAION and Stability AI for making this possible.*

#### This version is a less confusing way of using the model than the Hugging Face `diffusers` pipeline.

In [None]:
# List the GPU(s) attached to this runtime with nvidia-smi.
!nvidia-smi -L #@markdown # Make sure this says T4, P100, V100 or A100. Do not use this with a P4 or K80 as they will not work.

# Step 2.

Mount your Google Drive and authenticate your account with Colab. You only need to do this once per session.

In [None]:
# Mount Google Drive so files under /content/drive/MyDrive (e.g. the checkpoint) are readable here.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

# Step 3.

Run this cell to clone the code and install the dependencies. You only need to do this once per session.

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install diffusers==0.2.4
# TODO: pin these versions too so future releases cannot silently break the notebook.
%pip install transformers scipy ftfy
%pip install dotmap
# omegaconf, einops and pytorch_lightning are imported by the Step 4 cell but were not
# installed anywhere; presumably not preinstalled on Colab.
# TODO: confirm, and pin versions compatible with the cloned repo.
%pip install omegaconf einops pytorch-lightning
# Only clone when the repo is not already present, so re-running this cell does not fail.
![ -d stable-diffusion ] || git clone -b colab https://github.com/JohnnyRacer/stable-diffusion

In [None]:
# Imported here because this cell runs before the Step 4 cell that also imports os;
# without it, os.path.join below raises NameError on a fresh kernel.
import os

save_to_drive = True  # selects a Drive-backed outputs_path instead of a local ./outputs
model_path = "/content/drive/MyDrive/AI/SD/weights"  # Change this to your path
outputs_path = "/content/drive/MyDrive/AI/SD/outputs" if save_to_drive else "./outputs"
skip_installs = True
latent_diffusion_model = 'sd-v1-4'  # checkpoint file name inside model_path, without ".ckpt"
model_fp = os.path.join(model_path, f'{latent_diffusion_model}.ckpt')

# Fail-soft early warning: a wrong path otherwise only surfaces later as a torch.load error.
if not os.path.isfile(model_fp):
    print(f"Warning: checkpoint not found at {model_fp}; check model_path and that Drive is mounted.")

# Step 4.

Run this cell to import the required packages and load the model. You only need to do this once per session.

##### *If loading the model fails or reports missing state-dict keys, your checkpoint is most likely corrupt. Delete it from your Drive and upload it again.*

In [None]:
import argparse, os, sys, glob
sys.path.append('./stable-diffusion')
import cv2
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
import time
from pytorch_lightning import seed_everything
from torch import autocast
from torch import nn
from torch.nn import functional as F
from torchvision import transforms as T
from torchvision.transforms import functional as TF
from contextlib import contextmanager, nullcontext
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
import io

def load_model_from_config(config, ckpt, verbose=False, use_half=False):
    """Build the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: config object whose ``model`` section is passed to ``instantiate_from_config``.
        ckpt: path to the checkpoint file.
        verbose: if True, print state-dict keys that are missing from / unexpected by the model.
            Loading uses ``strict=False``, so mismatches are otherwise silent.
        use_half: if True, convert the model to float16.

    Returns:
        The model in eval mode, on CUDA when available, otherwise on CPU.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles arbitrary objects; only load checkpoints you trust.
    pl_sd = torch.load(ckpt, map_location="cpu")
    # Lightning checkpoints nest the weights under "state_dict"; accept a bare state dict too.
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose and missing:
        print("missing keys:")
        print(missing)
    if verbose and unexpected:
        print("unexpected keys:")
        print(unexpected)

    # The original unconditionally called .cuda(), which crashes on CPU-only runtimes.
    model.to("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    if use_half:
        model.half()
    return model

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
config = OmegaConf.load("stable-diffusion/configs/stable-diffusion/v1-inference.yaml")  # TODO: Optionally download from same location as ckpt and change this logic
# The loader uses load_state_dict(strict=False), so a corrupt checkpoint can load silently;
# verbose=True prints any missing/unexpected keys so that case is visible.
model = load_model_from_config(config, model_fp, verbose=True)  # TODO: check path
model = model.to(device).eval()
# Use fp16 whenever on GPU. Running in full precision just slows everything down, especially on Volta and newer cards (Turing, Ampere).
if device.type == "cuda":
    model = model.half()

In [None]:


def image_grid(imgs, rows, cols):
    """Tile equally sized PIL images into one ``rows`` x ``cols`` RGB grid image.

    Images are placed row-major (left to right, then top to bottom). Every tile uses
    the size of ``imgs[0]``.

    Args:
        imgs: sequence of PIL images; ``len(imgs)`` must equal ``rows * cols``.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new RGB image of size ``(cols * w, rows * h)``.

    Raises:
        ValueError: if ``imgs`` is empty or its length does not match the grid.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(imgs) == 0 or len(imgs) != rows * cols:
        raise ValueError(
            f"image_grid needs exactly rows*cols={rows * cols} images, got {len(imgs)}"
        )

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid


def check_safety(x_image):
    """Run the Stable Diffusion safety checker over a batch of images.

    Uses the globals ``safety_feature_extractor`` and ``safety_checker`` (created in the
    safety-setup cell). Images flagged by the checker are replaced via ``load_replacement``.

    Args:
        x_image: numpy image batch; ``do_run`` passes a ``(batch, H, W, C)`` array
            produced by ``permute(0, 2, 3, 1)``.

    Returns:
        ``(checked_images, has_nsfw_concept)``: the checked batch and the per-image flags
        reported by the safety checker. (The original discarded both and returned
        ``x_image, False``.)
    """
    safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    x_checked_image, has_nsfw_concept = safety_checker(
        images=x_image, clip_input=safety_checker_input.pixel_values
    )
    assert x_checked_image.shape[0] == len(has_nsfw_concept)
    for i, flagged in enumerate(has_nsfw_concept):
        if flagged:
            x_checked_image[i] = load_replacement(x_checked_image[i])
    return x_checked_image, has_nsfw_concept

def dummy_safety_check(x_image):
    """Stand-in for ``check_safety`` when filtering is off.

    The images pass through untouched, and a single ``False`` is returned as the NSFW flag.
    """
    unchanged = x_image
    nothing_flagged = False
    return unchanged, nothing_flagged


def numpy_to_pil(images):
    """Turn a numpy image, or a batch of them, into a list of PIL images.

    A 3-D array is treated as a batch of one. Values are scaled by 255 and rounded to
    uint8, so inputs are presumably floats in [0, 1].
    """
    batch = images[None, ...] if images.ndim == 3 else images
    as_uint8 = (batch * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in as_uint8]

def load_replacement(x, paths=("assets/rick.jpeg", "stable-diffusion/assets/rick.jpeg")):
    """Return a placeholder image with the same shape and dtype as ``x``.

    ``x`` is indexed as ``(H, W, C)``. The placeholder is divided by 255, so ``x`` is
    presumably in [0, 1]. The relative ``assets/rick.jpeg`` only resolves when the
    working directory is the repo root. This notebook clones the repo into
    ``stable-diffusion/``, so that location is tried as well.

    Best-effort: if no candidate can be loaded with a shape matching ``x``, a warning is
    printed and ``x`` is returned unchanged.

    Args:
        x: a single image array.
        paths: a path, or a sequence of paths to try in order.
    """
    if isinstance(paths, str):
        paths = (paths,)
    for path in paths:
        try:
            height, width = x.shape[0], x.shape[1]
            # Context manager closes the underlying file handle once the pixels are loaded.
            with Image.open(path) as src:
                y = src.convert("RGB").resize((width, height))
            y = (np.array(y) / 255.0).astype(x.dtype)
        except Exception:  # missing/unreadable file: try the next candidate
            continue
        if y.shape == x.shape:
            return y
    print("Warning: no usable replacement image found; returning the image unchanged.")
    return x


def display_handler(x, i, cadance=5, decode=True, save_image=True):
    """Sampler ``img_callback``: show a preview grid of the current batch every few steps.

    ``do_run`` passes this as ``img_callback`` to ``sampler.sample``. It is presumably
    called with the sampler's intermediate prediction ``x`` and step index ``i``
    (TODO: confirm against ldm's DDIM/PLMS samplers).

    Args:
        x: batch tensor; decoded with ``model.decode_first_stage`` when ``decode`` is True.
        i: step index; a preview is produced only when ``i % cadance == 0``.
        cadance: preview interval in steps; values <= 0 disable previews
            (the original raised ZeroDivisionError for 0).
        decode: whether ``x`` must be decoded from latent space first.
        save_image: kept for backward compatibility and ignored. The original only wrote
            a PNG into an in-memory buffer that was discarded immediately, so it had no
            visible effect.
    """
    if cadance <= 0 or i % cadance != 0:
        return
    from IPython.display import display as ipy_display

    if decode:
        x = model.decode_first_stage(x)
    nrow = round(x.shape[0] ** 0.5 + 0.2)
    grid = make_grid(torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0), nrow)
    grid = 255. * rearrange(grid, 'c h w -> h w c').detach().cpu().numpy()
    preview = Image.fromarray(grid.astype(np.uint8))

    # Update a single output area instead of stacking a new image every `cadance` steps.
    # A new area is started at step 0 (i.e. a new sampling run) or on the first call.
    handle = getattr(display_handler, "_preview_handle", None)
    if handle is None or i == 0:
        display_handler._preview_handle = ipy_display(preview, display_id=True)
    else:
        handle.update(preview)



from torch import autocast
def do_run():
    """Generate images for ``opt.prompt`` and collect them in ``opt.output_images``.

    Settings are read from the global ``opt`` (a DotMap). The globals ``model`` and
    ``device`` must already be set up.

    The function samples ``opt.n_iter`` rounds of ``opt.n_samples`` images each. It uses
    DDIM, or PLMS when ``opt.plms`` is set. Classifier-free guidance is applied when
    ``opt.scale != 1.0``, with ``opt.uc`` as the optional negative prompt. If
    ``opt.seed`` is set to an int, ``seed_everything`` is called first so runs are
    reproducible.

    Returns:
        The list of generated PIL images (the same list stored in ``opt.output_images``).
    """
    opt.output_images = []

    # Optional seeding; checked with `in` so a missing key is not auto-created on the DotMap.
    seed = opt["seed"] if "seed" in opt else None
    if isinstance(seed, int):
        seed_everything(seed)

    sampler = PLMSSampler(model) if opt.plms else DDIMSampler(model)
    batch_size = opt.n_samples
    data = [batch_size * [opt.prompt]]
    shape = [opt.C, opt.H // opt.f, opt.W // opt.f]

    # Must exist even without fixed_code: it is always passed to sampler.sample as x_T
    # (the original raised NameError when fixed_code was False).
    start_code = None
    if opt.fixed_code:
        start_code = torch.randn([opt.n_samples, *shape], device=device)

    negative_prompt = opt["uc"] if "uc" in opt and isinstance(opt["uc"], str) else ""
    precision_scope = autocast if opt.precision == "autocast" else nullcontext
    safety_func = check_safety if opt.use_safety else dummy_safety_check

    with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
        for _ in trange(opt.n_iter, desc="Sampling"):
            for prompts in tqdm(data, desc="data"):
                uc = None
                if opt.scale != 1.0:
                    uc = model.get_learned_conditioning(batch_size * [negative_prompt])
                # One prompt per sample, so c has the same batch size as uc and the sampler.
                c = model.get_learned_conditioning(list(prompts))
                samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                 conditioning=c,
                                                 batch_size=opt.n_samples,
                                                 shape=shape,
                                                 verbose=False,
                                                 unconditional_guidance_scale=opt.scale,
                                                 unconditional_conditioning=uc,
                                                 eta=opt.ddim_eta,
                                                 img_callback=display_handler,
                                                 x_T=start_code)

                x_samples = model.decode_first_stage(samples_ddim)
                x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                # Channel-first tensor -> (batch, H, W, C) numpy array for the safety check.
                x_samples = x_samples.cpu().permute(0, 2, 3, 1).numpy()

                x_checked_image, _has_nsfw_concept = safety_func(x_samples)

                if not opt.skip_save:
                    for x_sample in x_checked_image:
                        img = Image.fromarray((255. * x_sample).astype(np.uint8))
                        opt.output_images.append(img)

    return opt.output_images



In [None]:
from dotmap import DotMap

# All generation settings live on `opt`; do_run() reads them from this global.
opt = DotMap()
opt.precision = "autocast" if torch.cuda.is_available() else "cpu"  # autocast only when a GPU is present
opt.prompt = "A watercolor portrait of a geisha princess with tattoos with intricate floral patterns background, trending on artstation"
opt.uc = ""  # Optional negative prompt
opt.plms = False  # True: PLMS sampler, False: DDIM sampler
opt.scale = 8.5  # guidance scale; 1.0 disables the unconditional branch
opt.W = 512  # image width
opt.H = 512  # image height
opt.n_iter = 1  # number of sampling rounds
opt.ddim_eta = 0.0
opt.ddim_steps = 50
opt.n_samples = 1  # Samples you want
opt.n_rows = 0
opt.f = 8  # divisor from image size to latent size (H // f, W // f)
opt.C = 4  # latent channels
opt.fixed_code = True  # draw one explicit start noise tensor per run
opt.skip_save = False  # set explicitly; do_run only collects images when this is falsy
opt.use_safety = False  # Set this to True to filter NSFW images with the safety checker.
opt.make_grid = True



In [None]:
# The safety checker is only downloaded when filtering is switched on;
# check_safety() reads the two globals created here.
if opt.use_safety:
    safety_model_id = "CompVis/stable-diffusion-safety-checker"
    safety_feature_extractor, safety_checker = (
        AutoFeatureExtractor.from_pretrained(safety_model_id),
        StableDiffusionSafetyChecker.from_pretrained(safety_model_id),
    )


In [None]:
do_run()

# Show the results: one grid when opt.make_grid is set, otherwise each image on its own.
output_images = opt.output_images
if output_images:
    if opt.make_grid:
        # image_grid requires len(images) == rows * cols exactly (a fixed 3x3 fails for the
        # default n_samples = 1), so use the most square exact factorisation of the count.
        # A prime count becomes a single row.
        n_images = len(output_images)
        rows = max(d for d in range(1, int(n_images ** 0.5) + 1) if n_images % d == 0)
        columns = n_images // rows
        grid = image_grid(output_images, rows=rows, cols=columns)
        display(grid)
    else:
        display(*output_images)