In [1]:
from argparse import Namespace
from functools import partial
from pathlib import Path
from pprint import pprint

import numpy as np
import torch
from torch import nn
import yaml

In [2]:
assert torch.cuda.is_available()

In [3]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [4]:
from ldm.models.autoencoder import AutoencoderKL  # First Stage
from ldm.modules.encoders.modules import BERTEmbedder  # Condition Model
from ldm.modules.diffusionmodules.openaimodel import UNetModel  # Diffusion Model

  from .autonotebook import tqdm as notebook_tqdm
2024-03-08 04:38:58.178902: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-03-08 04:38:58.222795: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [24]:
class DiffusionWrapper(nn.Module):
    """Thin wrapper that feeds cross-attention conditioning to a ``UNetModel``.

    The wrapped network is stored as ``self.diffusion_model``.
    """

    def __init__(self, diff_model_config):
        """Build the wrapped network.

        Args:
            diff_model_config: mapping of keyword arguments forwarded verbatim
                to ``UNetModel``.
        """
        super().__init__()
        self.diffusion_model = UNetModel(**diff_model_config)

    def forward(self, x, t, c_crossattn: list = None):
        """Run the wrapped network on ``x`` at timestep ``t``.

        Args:
            x: input passed through to the wrapped network.
            t: timestep argument passed through to the wrapped network.
            c_crossattn: list of conditioning tensors that are concatenated
                along dimension 1 and passed as ``context``. ``None`` or an
                empty list means "no conditioning" and results in
                ``context=None`` (previously the default ``None`` crashed
                inside ``torch.cat``).

        Returns:
            Whatever the wrapped network returns.
        """
        if c_crossattn is None or len(c_crossattn) == 0:
            context = None
        else:
            context = torch.cat(c_crossattn, 1)
        return self.diffusion_model(x, t, context=context)


class DDPM(nn.Module):
    """Diffusion model container: a ``UNetModel`` plus the noise-schedule tables.

    All schedule tables are registered as buffers so they follow the module
    across devices. ``lvlb_weights`` and (when not learned) ``logvar`` are
    non-persistent, i.e. they are excluded from ``state_dict()``.
    """

    def __init__(
        self,
        unet_config,
        timesteps=1000,
        image_size=256,
        channels=3,
        log_every_t=100,
        clip_denoised=True,
        linear_start=1e-4,
        linear_end=2e-2,
        original_elbo_weight=0.0,
        v_posterior=0.0,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
        l_simple_weight=1.0,
        use_positional_encodings=False,
        learn_logvar=False,
        logvar_init=0.0,
    ):
        """Build the network and precompute the diffusion schedule.

        Args:
            unet_config: keyword arguments forwarded to ``UNetModel``.
            timesteps: number of diffusion steps in the schedule.
            image_size, channels, log_every_t, clip_denoised,
            original_elbo_weight, l_simple_weight, use_positional_encodings:
                stored on the instance unchanged.
            linear_start, linear_end: first and last beta of the schedule; the
                betas are the squares of a linear ramp between their square
                roots.
            v_posterior: interpolation weight between the posterior variance
                and beta (see the inline comment on the parameter).
            learn_logvar: when true, ``logvar`` becomes a trainable
                ``nn.Parameter``; otherwise it is a non-persistent buffer.
            logvar_init: initial value of every entry of ``logvar``.
        """
        super().__init__()
        self.cond_stage_model = None
        self.clip_denoised = clip_denoised
        self.log_every_t = log_every_t
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.use_positional_encodings = use_positional_encodings
        self.model = UNetModel(**unet_config)

        self.v_posterior = v_posterior
        self.original_elbo_weight = original_elbo_weight
        self.l_simple_weight = l_simple_weight

        #### register_schedule() start
        # Schedule is computed in float64 and only cast to float32 on registration.
        betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps, dtype=np.float64) ** 2
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        self.num_timesteps = int(betas.shape[0])
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
            + self.v_posterior * betas
        )
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

        lvlb_weights = self.betas ** 2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
        # TODO how to choose this term
        # The first weight divides by a zero posterior variance when
        # v_posterior == 0, so it is replaced by its neighbour (if there is one).
        if self.num_timesteps > 1:
            lvlb_weights[0] = lvlb_weights[1]
        self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
        # `.any()` rather than `.all()`: a single NaN weight is already an error.
        assert not torch.isnan(self.lvlb_weights).any(), 'lvlb_weights contain NaN'
        #### register_schedule() end

        self.learn_logvar = learn_logvar
        # float() keeps the tensor floating-point even for an integer logvar_init
        logvar = torch.full(fill_value=float(logvar_init), size=(self.num_timesteps,))
        if self.learn_logvar:
            self.logvar = nn.Parameter(logvar, requires_grad=True)
        else:
            # Non-persistent buffer: follows .to(device) without adding a
            # key to state_dict().
            self.register_buffer('logvar', logvar, persistent=False)



class LatentDiffusion(DDPM):
    """``DDPM`` extended with a first-stage autoencoder and a text encoder.

    ``first_stage_model`` is an ``AutoencoderKL`` and ``cond_stage_model`` a
    ``BERTEmbedder``; both are built from the given config mappings.
    """

    def __init__(
        self,
        first_stage_config,
        cond_stage_config,
        num_timesteps_cond=None,
        cond_stage_key="image",
        cond_stage_trainable=False,
        concat_mode=True,
        cond_stage_forward=None,
        scale_factor=1.0,
        *args,
        **kwargs
    ):
        """Build the diffusion model and its two auxiliary networks.

        Args:
            first_stage_config: keyword arguments for ``AutoencoderKL``; must
                contain ``["ddconfig"]["ch_mult"]``.
            cond_stage_config: keyword arguments for ``BERTEmbedder``.
            num_timesteps_cond: stored on the instance; ``None`` is treated
                as 1. Must not exceed the number of diffusion timesteps.
            cond_stage_key, cond_stage_trainable, concat_mode,
            cond_stage_forward: stored on the instance unchanged.
            scale_factor: latents are divided by this before decoding.
            *args, **kwargs: forwarded to ``DDPM.__init__``. The keys
                ``ckpt_path``, ``ignore_keys``, ``conditioning_key`` and
                ``first_stage_key`` are removed first because ``DDPM`` does
                not accept them.
        """
        # Options that DDPM.__init__ does not know; drop them before forwarding.
        # ckpt_path / ignore_keys are accepted for config compatibility only.
        kwargs.pop("ckpt_path", None)
        kwargs.pop("ignore_keys", [])
        conditioning_key = kwargs.pop("conditioning_key", None)
        first_stage_key = kwargs.pop("first_stage_key", "image")

        super().__init__(*args, **kwargs)

        # Validate against the schedule that was actually built, which also
        # works when `timesteps` was left at its default or passed positionally.
        self.num_timesteps_cond = 1 if num_timesteps_cond is None else num_timesteps_cond
        assert self.num_timesteps_cond <= self.num_timesteps

        self.conditioning_key = conditioning_key
        self.first_stage_key = first_stage_key
        self.concat_mode = concat_mode
        self.cond_stage_trainable = cond_stage_trainable
        self.cond_stage_key = cond_stage_key
        self.num_downs = len(first_stage_config["ddconfig"]["ch_mult"]) - 1
        self.scale_factor = scale_factor

        self.first_stage_model = AutoencoderKL(**first_stage_config)
        self.cond_stage_model = BERTEmbedder(**cond_stage_config)

        self.cond_stage_forward = cond_stage_forward
        # Overrides the value stored by DDPM.__init__.
        self.clip_denoised = False
        self.bbox_tokenizer = None

        self.restarted_from_ckpt = False

    @torch.no_grad()
    def decode_first_stage(self, z):
        """Undo the latent scaling and decode ``z`` with the first-stage model.

        Args:
            z: latent tensor produced by the sampler.

        Returns:
            The result of ``first_stage_model.decode`` on ``z / scale_factor``.
        """
        z = 1. / self.scale_factor * z
        return self.first_stage_model.decode(z)

In [None]:
class DDIMSamper:
    """DDIM sampler operating on a sub-sequence of the model's DDPM timesteps.

    The wrapped ``model`` must expose ``num_timesteps``, ``betas``,
    ``alphas_cumprod`` and ``alphas_cumprod_prev``. Noise is predicted with
    ``model.apply_model(x, t, c)`` when the model defines it, otherwise with
    ``model.model(x, t, context=c)``.
    """

    def __init__(self, model: "LatentDiffusion"):
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        # Alias for the attribute name used by the earlier (misspelled) version.
        self.ddpm_num_time_steps = self.ddpm_num_timesteps

    def _make_schedule(self, ddim_steps, eta, verbose):
        """Select the DDIM timesteps and store the matching alpha/sigma tables."""
        total = self.ddpm_num_timesteps
        if not 1 <= ddim_steps <= total:
            raise ValueError(f"ddim_steps must be in [1, {total}], got {ddim_steps}")

        stride = total // ddim_steps
        timesteps = np.arange(0, total, stride) + 1
        # The "+1" shift can push the last entry to `total`, which is not a
        # valid index into the schedule tables; clip it and drop duplicates.
        self.ddim_timesteps = np.unique(np.minimum(timesteps, total - 1))

        alphas_cumprod = self.model.alphas_cumprod
        alphas_cumprod_prev = self.model.alphas_cumprod_prev
        index = torch.as_tensor(self.ddim_timesteps, dtype=torch.long, device=alphas_cumprod.device)

        ddim_alphas = alphas_cumprod[index]
        ddim_alphas_prev = torch.cat([alphas_cumprod[:1], ddim_alphas[:-1]])
        ddim_sigmas = eta * torch.sqrt(
            (1 - ddim_alphas_prev) / (1 - ddim_alphas) * (1 - ddim_alphas / ddim_alphas_prev)
        )

        self.ddim_sigmas = ddim_sigmas
        self.ddim_alphas = ddim_alphas
        self.ddim_alphas_prev = ddim_alphas_prev
        self.ddim_sqrt_one_minus_alphas = torch.sqrt(1. - ddim_alphas)
        self.ddim_sigmas_for_original_num_steps = eta * torch.sqrt(
            (1 - alphas_cumprod_prev) / (1 - alphas_cumprod) * (1 - alphas_cumprod / alphas_cumprod_prev)
        )

        if verbose:
            print(f"Selected timesteps for DDIM sampler: {self.ddim_timesteps}")

    def _predict_noise(self, x, t, c):
        """Return the model's noise estimate for ``x`` at timesteps ``t`` given ``c``."""
        apply_model = getattr(self.model, "apply_model", None)
        if callable(apply_model):
            return apply_model(x, t, c)
        return self.model.model(x, t, context=c)

    @torch.no_grad()
    def sample(
        self,
        ddim_steps,
        batch_size,
        shape,
        conditioning=None,
        callback=None,
        normals_sequence=None,
        img_callback=None,
        quantize_x0=False,
        eta=0.0,
        x0=None,
        temperature=1.0,
        noise_dropout=0.0,
        score_corrector=None,
        corrector_kwargs=None,
        verbose=True,
        x_T=None,
        log_every_t=100,
        unconditional_guidance_scale=1.0,
        unconditional_conditioning=None,
    ):
        """Build the DDIM schedule and run the full sampling loop.

        Args:
            ddim_steps: requested number of DDIM steps (1..num DDPM steps).
            batch_size: number of samples to draw.
            shape: ``(C, H, W)`` of a single sample.
            conditioning: conditioning passed to the noise predictor.
            callback: optional ``callback(i)`` invoked after every step.
            normals_sequence: accepted for interface compatibility; unused.
            img_callback: optional ``img_callback(pred_x0, i)`` per step.
            quantize_x0: forwarded as ``quantize_denoised``; currently unused.
            eta: DDIM stochasticity; 0 gives a deterministic trajectory.
            x0: accepted for interface compatibility; unused.
            temperature: multiplier applied to the injected noise.
            noise_dropout: dropout probability applied to the injected noise.
            score_corrector, corrector_kwargs: accepted for interface
                compatibility; unused.
            verbose: print the selected timesteps.
            x_T: optional starting tensor; random normal noise if ``None``.
            log_every_t: interval (in step index) for storing intermediates.
            unconditional_guidance_scale: classifier-free guidance scale.
            unconditional_conditioning: conditioning for the unconditional
                branch of classifier-free guidance.

        Returns:
            ``(samples, intermediates)`` where ``intermediates`` is a dict with
            the lists ``'x_inter'`` and ``'pred_x0'``.

        Raises:
            ValueError: if ``ddim_steps`` is outside the valid range.
        """
        self._make_schedule(ddim_steps, eta, verbose)

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f"Data shape for DDIM sampling is {size}, eta {eta}")

        samples, intermediates = self.ddim_sampling(
            conditioning,
            size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            x0=x0,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
        )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None,
                      callback=None, quantize_denoised=False,
                      x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
        """Iterate the DDIM update from the last selected timestep down to the first.

        Requires the schedule built by ``sample``. Returns ``(img, intermediates)``.

        Raises:
            RuntimeError: if no schedule has been built yet.
        """
        if not hasattr(self, "ddim_timesteps"):
            raise RuntimeError("DDIM schedule is not initialised; call sample() first")

        device = self.model.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device) if x_T is None else x_T
        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = np.flip(self.ddim_timesteps)
        total_steps = time_range.shape[0]

        # Progress bar is optional so the sampler also runs without tqdm.
        try:
            from tqdm.auto import tqdm
        except ImportError:
            iterator = time_range
        else:
            iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            # `index` addresses the DDIM tables, `step` is the DDPM timestep.
            index = total_steps - i - 1
            ts = torch.full((b,), int(step), device=device, dtype=torch.long)

            img, pred_x0 = self.p_sample_ddim(
                img, cond, ts, index=index,
                quantize_denoised=quantize_denoised, temperature=temperature,
                noise_dropout=noise_dropout, score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
            )

            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        """Perform one DDIM step.

        Args:
            x: current sample; coefficients are broadcast as ``(b, 1, 1, 1)``.
            c: conditioning for the noise predictor.
            t: per-sample DDPM timesteps for the noise predictor.
            index: position of this step in the DDIM schedule tables.
            repeat_noise: draw one noise sample and repeat it across the batch.
            quantize_denoised, score_corrector, corrector_kwargs: accepted
                for interface compatibility; unused.
            temperature: multiplier applied to the injected noise.
            noise_dropout: dropout probability applied to the injected noise.
            unconditional_guidance_scale, unconditional_conditioning:
                classifier-free guidance settings; guidance is skipped when
                the conditioning is ``None`` or the scale is 1.

        Returns:
            ``(x_prev, pred_x0)``.
        """
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self._predict_noise(x, t, c)
        else:
            # Run conditional and unconditional branches in one batched call.
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self._predict_noise(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        # select parameters corresponding to the currently considered timestep
        def _coef(table):
            return torch.full((b, 1, 1, 1), float(table[index]), device=device, dtype=x.dtype)

        a_t = _coef(self.ddim_alphas)
        a_prev = _coef(self.ddim_alphas_prev)
        sigma_t = _coef(self.ddim_sigmas)
        sqrt_one_minus_at = _coef(self.ddim_sqrt_one_minus_alphas)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t

        if repeat_noise:
            single = torch.randn((1, *x.shape[1:]), device=device, dtype=x.dtype)
            raw_noise = single.repeat(b, *((1,) * (x.dim() - 1)))
        else:
            raw_noise = torch.randn(x.shape, device=device, dtype=x.dtype)
        noise = sigma_t * raw_noise * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)

        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

In [5]:
# Sampling options, gathered in a single namespace.
opt = Namespace(
    prompt="a virus monster is playing guitar, oil on canvas",
    ddim_eta=0.0,
    n_samples=4,
    n_iter=4,
    scale=5.0,
    ddim_steps=50,
    outdir="outputs/txt2img-samples",
    plms=False,
    H=256,
    W=256,
)
print(opt)

Namespace(H=256, W=256, ddim_eta=0.0, ddim_steps=50, n_iter=4, n_samples=4, outdir='outputs/txt2img-samples', plms=False, prompt='a virus monster is playing guitar, oil on canvas', scale=5.0)


In [18]:
# Parse the model configuration file, then show the resulting mapping.
# NOTE(review): yaml.load with FullLoader is kept unchanged here; for a
# plain-data config file yaml.safe_load would be the more conservative choice.
with Path("latent-diffusion.yaml").open(mode="r") as f:
    model_config = yaml.load(f, Loader=yaml.FullLoader)
pprint(model_config)

{'channels': 4,
 'cond_stage_config': {'n_embed': 1280, 'n_layer': 32},
 'cond_stage_key': 'caption',
 'cond_stage_trainable': True,
 'conditioning_key': 'crossattn',
 'first_stage_config': {'ddconfig': {'attn_resolutions': [],
                                     'ch': 128,
                                     'ch_mult': [1, 2, 4, 4],
                                     'double_z': True,
                                     'dropout': 0.0,
                                     'in_channels': 3,
                                     'num_res_blocks': 2,
                                     'out_ch': 3,
                                     'resolution': 256,
                                     'z_channels': 4},
                        'embed_dim': 4,
                        'lossconfig': {'target': 'torch.nn.Identity'},
                        'monitor': 'val/rec_loss'},
 'first_stage_key': 'image',
 'image_size': 32,
 'linear_end': 0.012,
 'linear_start': 0.00085,
 'log_every_t': 200,


In [None]:
# Read the checkpoint onto the CPU and keep only its "state_dict" entry.
# NOTE(review): torch.load unpickles the file — only use checkpoints from a
# trusted source.
state_dict = torch.load(
    "models/ldm/text2img-large/model.ckpt",
    map_location="cpu",
)["state_dict"]
# List the checkpoint keys whose name contains "alphas_cumprod".
print([key for key in state_dict if "alphas_cumprod" in key])

In [None]:
# Instantiate the latent diffusion model from the parsed configuration mapping.
# NOTE(review): no visible cell copies `state_dict` into this model or moves
# the model to the GPU — confirm this happens before sampling.
model = LatentDiffusion(**model_config)

In [None]:
# Output directory comes from the options namespace (`outpath` was undefined).
sample_path = os.path.join(opt.outdir, "samples")
os.makedirs(sample_path, exist_ok=True)
# Continue numbering after any files already present in the directory.
base_count = len(os.listdir(sample_path))

# The sampler was never instantiated in the earlier cells.
sampler = DDIMSamper(model)

with torch.no_grad():
    # Empty prompts provide the unconditional branch for guidance.
    uc = model.cond_stage_model.encode(opt.n_samples * [""])
    c = model.cond_stage_model.encode(opt.n_samples * [opt.prompt])
    shape = [4, opt.H // 8, opt.W // 8]
    samples_ddim, _ = sampler.sample(
        ddim_steps=opt.ddim_steps,
        conditioning=c,
        batch_size=opt.n_samples,
        shape=shape,
        verbose=False,
        unconditional_guidance_scale=opt.scale,
        unconditional_conditioning=uc,
        eta=opt.ddim_eta,
    )

    # decode_first_stage divides by model.scale_factor and runs the decoder,
    # replacing the manual `self.scale_factor` / `samples_ddim5` lines.
    x_samples_ddim = model.decode_first_stage(samples_ddim)
    # Map decoder output from [-1, 1] to [0, 1].
    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

    for x_sample in x_samples_ddim:
        # (c, h, w) -> (h, w, c), scaled to the 0..255 range.
        x_sample = 255. * x_sample.permute(1, 2, 0).cpu().numpy()
        # NOTE(review): `Image` (presumably PIL.Image) is not imported in any
        # visible cell — add the import to the imports cell.
        Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(sample_path, f"{base_count:04}.png"))
        base_count += 1

