In [1]:
import argparse, os
import cv2
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
from torchvision.utils import make_grid
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
from imwatermark import WatermarkEncoder
import time

In [2]:
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.dpm_solver import DPMSolverSampler

In [3]:
# ESRGAN Imports
import os.path as osp
import glob
import RRDBNet_arch as arch

In [4]:
def chunk(it, size):
    """Lazily split ``it`` into consecutive tuples of at most ``size`` items.

    The last tuple may hold fewer than ``size`` items. Iteration stops the
    first time an empty tuple would be produced, i.e. once ``it`` is
    exhausted (or immediately when ``size`` is 0).
    """
    # Obtain the iterator eagerly so a non-iterable argument fails at call time.
    source = iter(it)

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if batch == ():
                return
            yield batch

    return _batches()


def load_model_from_config(config, ckpt, device=torch.device("cuda"), verbose=False):
    """Instantiate the model described by ``config.model`` and load ``ckpt`` into it.

    Args:
        config: Configuration object exposing a ``model`` attribute, which is
            handed to ``instantiate_from_config``.
        ckpt: Path of a checkpoint readable by ``torch.load``. It must contain
            a ``"state_dict"`` entry; a ``"global_step"`` entry is printed
            when present.
        device: Target device, given as a ``torch.device`` or a device string.
            CUDA devices (with or without an index, e.g. ``"cuda:0"``) and
            CPU devices are accepted.
        verbose: When true, print the missing / unexpected state-dict keys
            reported by the non-strict ``load_state_dict`` call.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        ValueError: If ``device`` is not a CUDA or CPU device.
    """
    # Normalise and validate the device before doing any expensive work, so
    # that strings and indexed devices ("cuda", "cuda:0") are accepted and a
    # bad value fails before the checkpoint is read.
    try:
        device = torch.device(device)
    except (RuntimeError, TypeError) as err:
        raise ValueError(f"Incorrect device name. Received: {device}") from err
    if device.type not in ("cuda", "cpu"):
        raise ValueError(f"Incorrect device name. Received: {device}")

    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False: mismatching keys are tolerated and only reported below.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if len(missing) > 0 and verbose:
        print("missing keys:")
        print(missing)
    if len(unexpected) > 0 and verbose:
        print("unexpected keys:")
        print(unexpected)

    if device.type == "cuda":
        # device.index is None for a plain "cuda" device, which selects the
        # current CUDA device exactly like a bare model.cuda() call.
        # NOTE(review): for a non-default index, cond_stage_model may carry
        # its own `device` attribute that also needs updating — confirm.
        model.cuda(device.index)
    else:
        model.cpu()
        model.cond_stage_model.device = "cpu"
    model.eval()
    return model

In [62]:
# Set up the ESRGAN model used to upscale the decoded samples.
esrgan_model_path = './esrgan_models/RRDB_ESRGAN_x4.pth'
esrgan_device = torch.device('cuda')
# NOTE(review): positional arguments are presumably (in_nc, out_nc, nf, nb) —
# confirm against RRDBNet_arch.
esrgan_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
# Deserialize onto the CPU so that loading does not depend on the device the
# checkpoint was saved from; the module is moved to esrgan_device below.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted
# source.
esrgan_model.load_state_dict(
    torch.load(esrgan_model_path, map_location='cpu'), strict=True
)
esrgan_model.eval()
esrgan_model = esrgan_model.to(esrgan_device)

def upscale_samples(samples_img_gen):
    """Upscale a list of image batches with the module-level ``esrgan_model``.

    Args:
        samples_img_gen: Sequence of tensors that are concatenated along
            dim 0 before being fed to the model. Assumes each tensor is
            laid out as (batch, channels, height, width) with values in
            [0, 1] — TODO confirm against the caller.

    Returns:
        A float NumPy array with the channel axis moved last
        (axes reordered as ``(0, 2, 3, 1)``), values clamped to [0, 1],
        scaled by 255 and rounded. Callers cast to ``uint8`` themselves.
    """
    batch = torch.cat(samples_img_gen, dim=0).to(
        device=esrgan_device, dtype=torch.float32
    )
    with torch.no_grad():
        upscaled = esrgan_model(batch)
    # Keep the batch axis as-is: a bare squeeze() would also drop any other
    # size-1 axis (e.g. a one-pixel-high image), breaking the 4-D transpose.
    output = upscaled.detach().float().cpu().clamp_(0, 1).numpy()
    output = np.transpose(output, (0, 2, 3, 1))
    return (output * 255.0).round()

In [6]:
# Set up stable diffusion model
# The global seed is fixed before the model is built.
seed = 42
seed_everything(seed)
# Config and checkpoint paths are relative to the notebook's working directory.
config = OmegaConf.load("../configs/stable-diffusion/v2-inference.yaml")
device = torch.device("cuda")
# Instantiates config.model, loads the checkpoint's state dict (non-strict),
# moves the model to `device` and puts it in eval mode.
model = load_model_from_config(config, "../model_weights/v2-1_512-ema-pruned.ckpt", device)

Global seed set to 42


Loading model from ../model_weights/v2-1_512-ema-pruned.ckpt
Global Step: 220000


A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'


LatentDiffusion: Running in eps-prediction mode
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 1024 and using 5 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 1024 and using 10 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 20 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 

In [83]:
# Sampler selection flags: PLMS takes precedence over DPM-Solver; DDIM is
# the fallback when neither flag is set.
plms = False
dpm = False
sampler = (
    PLMSSampler(model, device=device) if plms
    else DPMSolverSampler(model, device=device) if dpm
    else DDIMSampler(model, device=device)
)

# Batch layout for the run.
n_samples = 1
batch_size = 1
n_rows = 1

# Output locations; the samples directory is created on demand.
outpath = '../outputs/sandbox'
sample_path = os.path.join(outpath, "samples")
os.makedirs(sample_path, exist_ok=True)

# Counters derived from what is already on disk.
sample_count = 0
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1

In [84]:
# --- Sampling configuration -------------------------------------------------
start_code = None  # passed as x_T to the sampler
precision_scope = autocast
# The sampler built in the sampler-setup cell (PLMS / DPM-Solver / DDIM) is
# reused here. Re-creating a DDIMSampler at this point would silently discard
# the `plms` / `dpm` selection.
scale = 9        # unconditional guidance scale
opt_C = 4        # first entry of the sampling shape
opt_H = 576      # image height; divided by opt_f for the sampling shape
opt_f = 8        # divisor applied to opt_H / opt_W
opt_W = 1024     # image width; divided by opt_f for the sampling shape
steps = 50
ddim_eta = 0.0
seed_everything(np.random.randint(9999999))

prompt = 'A realistic, highly-detailed australian shepherd catching a fish with its mouth from a river in the countryside, mountainous background, beautiful sunny day, symmetrical face, beautiful eyes, detailed eyes, detailed paws, symmetrical legs, realistic fur, high-resolution.'
data = [batch_size * [prompt]]

# perf_counter is monotonic, so the elapsed time cannot be skewed by
# wall-clock adjustments.
start_time = time.perf_counter()

with torch.no_grad(), precision_scope("cuda"), model.ema_scope():
    all_samples = list()

    # Loop-invariant inputs: the empty-prompt conditioning used for guidance
    # and the sampling shape do not change between iterations.
    uc = None
    if scale != 1.0:
        uc = model.get_learned_conditioning(batch_size * [""])
    shape = [opt_C, opt_H // opt_f, opt_W // opt_f]

    for n in trange(n_samples, desc="Sampling"):
        for prompts in tqdm(data, desc="data"):
            if isinstance(prompts, tuple):
                prompts = list(prompts)
            c = model.get_learned_conditioning(prompts)
            samples, _ = sampler.sample(S=steps,
                                        conditioning=c,
                                        batch_size=batch_size,
                                        shape=shape,
                                        verbose=False,
                                        unconditional_guidance_scale=scale,
                                        unconditional_conditioning=uc,
                                        eta=ddim_eta,
                                        x_T=start_code)

            x_samples = model.decode_first_stage(samples)
            # Map the decoder output from [-1, 1] to [0, 1], clamping overshoot.
            x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
            all_samples.append(x_samples)

print('Upscaling...')
upscaled_samples = upscale_samples(all_samples)

end_time = time.perf_counter()
print(f"Done. {round(end_time-start_time, 3)}")

Global seed set to 2086993
Sampling:   0%|          | 0/1 [00:00<?, ?it/s]

Data shape for DDIM sampling is (1, 4, 72, 128), eta 0.0
Running DDIM Sampling with 50 timesteps



DDIM Sampler: 100%|██████████| 50/50 [00:05<00:00,  9.17it/s]
data: 100%|██████████| 1/1 [00:05<00:00,  5.70s/it]
Sampling: 100%|██████████| 1/1 [00:05<00:00,  5.71s/it]


Upscaling...
Done. 9.691


In [81]:
# Write every upscaled frame into sample_path as a PNG, named by the running
# zero-padded base_count index, which is advanced after each file.
for upscaled_img in upscaled_samples:
    out_file = os.path.join(sample_path, f"{base_count:05}.png")
    img = Image.fromarray(upscaled_img.astype(np.uint8))
    img.save(out_file)
    base_count += 1

In [57]:
# for x_samples in all_samples:
#     for x_sample in x_samples:
#         x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
#         img = Image.fromarray(x_sample.astype(np.uint8))
#         img.save(os.path.join(sample_path, f"{base_count:05}.png"))