## installation

In [None]:
!git clone https://github.com/CompVis/latent-diffusion.git
!git clone https://github.com/CompVis/taming-transformers
!pip install -e ./taming-transformers
!pip3 install transformers
!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops
!mkdir -p models/ldm/text2img-large/
!wget -O models/ldm/text2img-large/model.ckpt https://ommer-lab.com/files/latent-diffusion/nitro/txt2img-f8-large/model.ckpt

## utils

In [None]:
cd /content/latent-diffusion/

In [3]:
import sys

# Add the current directory and the taming-transformers checkout to the
# module search path (in that order).
sys.path.extend([".", "/content/taming-transformers"])

In [None]:
import argparse, os, sys, glob
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from einops import rearrange
from torchvision.utils import make_grid

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler


def load_model_from_config(config, ckpt, verbose=False):
    """Instantiate the model described by ``config.model`` and load ``ckpt`` into it.

    Args:
        config: configuration object whose ``model`` attribute is handed to
            ``instantiate_from_config``.
        ckpt: path of a checkpoint readable by ``torch.load``; it must contain
            a ``"state_dict"`` entry.
        verbose: if true, print the missing and unexpected keys reported by
            ``load_state_dict``.

    Returns:
        The instantiated model with the weights loaded (non-strictly), switched
        to eval mode. Moving it to a device is left to the caller.
    """
    print(f"Loading model from {ckpt}")
    # Deserialize onto the CPU rather than a hard-coded "cuda": this works on
    # hosts without a GPU and avoids holding a second copy of the weights in
    # GPU memory. The caller moves the finished model to its target device.
    # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches; they are only reported when verbose.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if verbose:
        if len(missing) > 0:
            print("missing keys:")
            print(missing)
        if len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)

    model.eval()
    return model

# Both paths below are relative to the current working directory.
# TODO: optionally download the config from the same location as the ckpt and change this logic
config = OmegaConf.load("configs/latent-diffusion/txt2img-1p4B-eval.yaml")
# TODO: check path
model = load_model_from_config(config, "../models/ldm/text2img-large/model.ckpt")

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
sampler = DDIMSampler(model)

## process

In [5]:
# Text prompt to render; it is also used to build the grid image's file name.
prompt = "3 red apples and 2 bananas on a silver plate"

In [6]:
# Sampling configuration.
outdir = 'outputs'  # directory to write results to
ddim_steps = 200    # number of DDIM sampling steps
ddim_eta = 0        # DDIM eta (eta=0.0 corresponds to deterministic sampling)
n_iter = 1          # number of sampling rounds
H = W = 256         # image height and width in pixel space
# unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))
scale = 5.0
n_samples = 2       # how many samples to produce for the given prompt in each round


In [None]:
# Prepare the output directory and a "samples" sub-directory for the individual images.
os.makedirs(outdir, exist_ok=True)
outpath = outdir
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)

# Start numbering after the entries already present in the samples directory.
base_count = len(os.listdir(sample_path))

all_samples = []
with torch.no_grad(), model.ema_scope():
    # The empty-prompt conditioning is only computed when scale != 1.0.
    uc = model.get_learned_conditioning(n_samples * [""]) if scale != 1.0 else None

    for n in trange(n_iter, desc="Sampling"):
        c = model.get_learned_conditioning(n_samples * [prompt])
        # Per-sample shape handed to the sampler: 4 channels, H/8 x W/8.
        shape = [4, H//8, W//8]
        samples_ddim, _ = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=n_samples,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
        )

        # Decode, then apply (x + 1) / 2 and clamp to [0, 1].
        x_samples_ddim = model.decode_first_stage(samples_ddim)
        x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)

        # Write each sample of the batch to its own numbered PNG.
        for x_sample in x_samples_ddim:
            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
            sample_file = os.path.join(sample_path, f"{base_count:04}.png")
            Image.fromarray(x_sample.astype(np.uint8)).save(sample_file)
            base_count += 1

        # Keep the whole batch for the grid image assembled in a later cell.
        all_samples.append(x_samples_ddim)


In [8]:
# additionally, save all collected samples as one grid image
# Flatten the (round, batch) axes so every sample becomes one grid entry.
stacked = rearrange(torch.stack(all_samples, 0), 'n b c h w -> (n b) c h w')
grid = make_grid(stacked, nrow=n_samples)

# to image: channels-last layout, scaled by 255 and cast to uint8
grid_path = os.path.join(outpath, f'{prompt.replace(" ", "-")}.png')
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(grid_path)

In [None]:
# Re-open the saved grid image; as the cell's last expression it is displayed.
grid_file = os.path.join(outpath, f'{prompt.replace(" ", "-")}.png')
Image.open(grid_file)