# Common
Always run this cell whenever you start or restart the runtime.

In [1]:
import math
import torch
from torch import nn
from scipy import integrate

from tqdm.auto import trange, tqdm

def append_dims(x, target_dims):
    """Pad `x` with trailing singleton axes until it has `target_dims` dimensions.

    Raises ValueError when `x` already has more than `target_dims` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    # Ellipsis keeps every existing axis; each None adds one new trailing axis.
    index = (Ellipsis,) + tuple(None for _ in range(missing))
    return x[index]


def append_zero(x):
    """Return `x` with a single zero concatenated at the end."""
    # new_zeros inherits dtype and device from x.
    tail = x.new_zeros([1])
    return torch.cat((x, tail))


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Build the Karras et al. (2022) noise schedule.

    Interpolates linearly between sigma_max**(1/rho) and sigma_min**(1/rho)
    over `n` points, raises the result back to the power `rho`, appends a
    final zero and moves the result to `device`.
    """
    inv_rho = 1 / rho
    lo = sigma_min ** inv_rho
    hi = sigma_max ** inv_rho
    # The ramp is built on the default device; only the result is moved.
    frac = torch.linspace(0, 1, n)
    schedule = (hi + frac * (lo - hi)) ** rho
    return append_zero(schedule).to(device)


def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Build a schedule that is linear in log(sigma), from sigma_max down to
    sigma_min over `n` points, with a final zero appended."""
    log_start = math.log(sigma_max)
    log_end = math.log(sigma_min)
    log_sigmas = torch.linspace(log_start, log_end, n, device=device)
    return append_zero(log_sigmas.exp())


def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Build a continuous VP noise schedule.

    Times run linearly from 1 down to `eps_s` over `n` points; each sigma is
    sqrt(exp(beta_d * t**2 / 2 + beta_min * t) - 1). A final zero is appended.
    """
    times = torch.linspace(1, eps_s, n, device=device)
    exponent = beta_d * times ** 2 / 2 + beta_min * times
    return append_zero(torch.sqrt(torch.exp(exponent) - 1))


def to_d(x, sigma, denoised):
    """Turn a denoiser output into the Karras ODE derivative (x - denoised) / sigma."""
    # sigma gets trailing axes so it broadcasts against x.
    sigma_b = append_dims(sigma, x.ndim)
    residual = x - denoised
    return residual / sigma_b


def get_ancestral_step(sigma_from, sigma_to):
    """Split an ancestral sampling step into its deterministic and noisy parts.

    Returns (sigma_down, sigma_up): the noise level to step down to and the
    amount of fresh noise to add afterwards.
    """
    var_from = sigma_from ** 2
    var_to = sigma_to ** 2
    sigma_up = (var_to * (var_from - var_to) / var_from) ** 0.5
    sigma_down = (var_to - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up


@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Euler sampler following Algorithm 2 of Karras et al. (2022).

    `model(x, sigma_batch, **extra_args)` must return the denoised sample.
    `callback`, when given, receives a dict describing each step.
    """
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for step in trange(n_steps, disable=disable):
        sigma_cur = sigmas[step]
        # Churn is applied only while sigma lies inside [s_tmin, s_tmax].
        if s_tmin <= sigma_cur <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Noise is drawn every step (even when unused) to keep the RNG stream fixed.
        noise = torch.randn_like(x) * s_noise
        sigma_hat = sigma_cur * (gamma + 1)
        if gamma > 0:
            x = x + noise * (sigma_hat ** 2 - sigma_cur ** 2) ** 0.5
        denoised = model(x, sigma_hat * ones, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigma_cur, 'sigma_hat': sigma_hat, 'denoised': denoised})
        # Euler step from sigma_hat to the next noise level.
        x = x + d * (sigmas[step + 1] - sigma_hat)
    return x


@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Ancestral sampler: an Euler step down to sigma_down followed by fresh noise of scale sigma_up."""
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    for step in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[step]
        denoised = model(x, sigma_cur * ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma_cur, sigmas[step + 1])
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigma_cur, 'sigma_hat': sigma_cur, 'denoised': denoised})
        d = to_d(x, sigma_cur, denoised)
        # Deterministic Euler move, then re-noise.
        x = x + d * (sigma_down - sigma_cur)
        x = x + torch.randn_like(x) * sigma_up
    return x


@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Heun (second-order) sampler following Algorithm 2 of Karras et al. (2022).

    Falls back to a plain Euler step when the next sigma is exactly zero.
    """
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for step in trange(n_steps, disable=disable):
        sigma_cur = sigmas[step]
        sigma_next = sigmas[step + 1]
        if s_tmin <= sigma_cur <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        noise = torch.randn_like(x) * s_noise
        sigma_hat = sigma_cur * (gamma + 1)
        if gamma > 0:
            x = x + noise * (sigma_hat ** 2 - sigma_cur ** 2) ** 0.5
        denoised = model(x, sigma_hat * ones, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigma_cur, 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigma_next - sigma_hat
        if sigma_next == 0:
            # Last step: no second evaluation is made at sigma == 0.
            x = x + d * dt
            continue
        # Heun correction: average the slope at both ends of the step.
        x_pred = x + d * dt
        denoised_pred = model(x_pred, sigma_next * ones, **extra_args)
        d_pred = to_d(x_pred, sigma_next, denoised_pred)
        x = x + (d + d_pred) / 2 * dt
    return x


@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Midpoint sampler inspired by DPM-Solver-2 and Algorithm 2 of Karras et al. (2022)."""
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for step in trange(n_steps, disable=disable):
        sigma_cur = sigmas[step]
        sigma_next = sigmas[step + 1]
        if s_tmin <= sigma_cur <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        noise = torch.randn_like(x) * s_noise
        sigma_hat = sigma_cur * (gamma + 1)
        if gamma > 0:
            x = x + noise * (sigma_hat ** 2 - sigma_cur ** 2) ** 0.5
        denoised = model(x, sigma_hat * ones, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigma_cur, 'sigma_hat': sigma_hat, 'denoised': denoised})
        # The midpoint is taken in sigma**(1/3) space (a rho=3 Karras schedule).
        sigma_mid = ((sigma_hat ** (1 / 3) + sigma_next ** (1 / 3)) / 2) ** 3
        x_mid = x + d * (sigma_mid - sigma_hat)
        denoised_mid = model(x_mid, sigma_mid * ones, **extra_args)
        d_mid = to_d(x_mid, sigma_mid, denoised_mid)
        # Full step using the slope evaluated at the midpoint.
        x = x + d_mid * (sigma_next - sigma_hat)
    return x


@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Ancestral sampler using a DPM-Solver-inspired midpoint step down to sigma_down, then re-noising."""
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    for step in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[step]
        denoised = model(x, sigma_cur * ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma_cur, sigmas[step + 1])
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigma_cur, 'sigma_hat': sigma_cur, 'denoised': denoised})
        d = to_d(x, sigma_cur, denoised)
        # The midpoint is taken in sigma**(1/3) space (a rho=3 Karras schedule).
        sigma_mid = ((sigma_cur ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
        x_mid = x + d * (sigma_mid - sigma_cur)
        denoised_mid = model(x_mid, sigma_mid * ones, **extra_args)
        d_mid = to_d(x_mid, sigma_mid, denoised_mid)
        x = x + d_mid * (sigma_down - sigma_cur)
        x = x + torch.randn_like(x) * sigma_up
    return x


def linear_multistep_coeff(order, t, i, j):
    """Integrate the j-th Lagrange basis polynomial over [t[i], t[i + 1]].

    The basis is built on the nodes t[i], t[i - 1], ..., t[i - order + 1].
    Raises ValueError when fewer than `order` nodes are available at step `i`.
    """
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')

    other_nodes = [k for k in range(order) if k != j]

    def basis(tau):
        value = 1.
        for k in other_nodes:
            value *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return value

    integral, _abs_err = integrate.quad(basis, t[i], t[i + 1], epsrel=1e-4)
    return integral



def revpre0(img, sigmas, t):
    """Identity pre-hook for the LMS sampler: hand the image back unchanged."""
    return img

def revpre1(img, sigmas, t):
    """Masked pre-hook for the LMS sampler.

    Keeps `img` where the global `zamask` is 1. Elsewhere uses the global
    `revpreimg` plus the global `noise` scaled by `sigmas[t]`.
    """
    renoised = revpreimg + noise * sigmas[t]
    return renoised * (1 - zamask) + img * zamask

# Pre-hook that sample_lms applies to x before calling the model; defaults to the identity hook.
revpre=revpre0

@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Linear multistep sampler.

    The model is evaluated on `revpre(x, sigmas, i)`, while the derivative is
    computed from the unmodified `x`. At most `order` past derivatives are kept.
    """
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    history = []
    for step in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[step]
        denoised = model(revpre(x, sigmas, step), sigma_cur * ones, **extra_args)
        history.append(to_d(x, sigma_cur, denoised))
        if len(history) > order:
            del history[0]
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigma_cur, 'sigma_hat': sigma_cur, 'denoised': denoised})
        # Early steps have fewer past derivatives, so the order is capped.
        cur_order = min(step + 1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), step, j) for j in range(cur_order)]
        # The newest derivative pairs with coefficient j == 0.
        x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(history)))
    return x


@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
    """Estimate per-sample log-likelihood by integrating the ODE from sigma_min to sigma_max.

    Returns a tuple (log_likelihood, info) where info is {'fevals': <number of
    ode_fn evaluations>}.

    NOTE(review): `odeint` is not imported anywhere in the visible source;
    calling this function will raise NameError unless it is provided elsewhere
    (presumably torchdiffeq.odeint — confirm).
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Random probe vector with entries in {-1, +1}.
    v = torch.randint_like(x, 2) * 2 - 1
    fevals = 0
    def ode_fn(sigma, x):
        nonlocal fevals
        # Gradients are needed here even though the outer function runs under no_grad.
        with torch.enable_grad():
            # The ODE state is a tuple; its first element is the sample tensor.
            x = x[0].detach().requires_grad_()
            denoised = model(x, sigma * s_in, **extra_args)
            d = to_d(x, sigma, denoised)
            fevals += 1
            # v . d(d . v)/dx, summed per sample — presumably a Hutchinson-style
            # estimate of the divergence of d.
            grad = torch.autograd.grad((d * v).sum(), x)[0]
            d_ll = (v * grad).flatten(1).sum(1)
        return d.detach(), d_ll
    # Initial state: the sample itself and a zero log-likelihood change per sample.
    x_min = x, x.new_zeros([x.shape[0]])
    t = x.new_tensor([sigma_min, sigma_max])
    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
    # Take the state at the last requested time point (sigma_max).
    latent, delta_ll = sol[0][-1], sol[1][-1]
    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
    return ll_prior + delta_ll, {'fevals': fevals}



class DiscreteSchedule(nn.Module):
    """Maps between continuous noise levels (sigmas) and positions in a fixed
    table of discrete noise levels."""

    def __init__(self, sigmas, quantize):
        super().__init__()
        self.register_buffer('sigmas', sigmas)
        self.quantize = quantize

    def get_sigmas(self, n=None):
        """Return the table reversed (or `n` resampled levels, last index first) with a zero appended."""
        if n is not None:
            last_index = len(self.sigmas) - 1
            positions = torch.linspace(last_index, 0, n, device=self.sigmas.device)
            return append_zero(self.t_to_sigma(positions))
        return append_zero(self.sigmas.flip(0))

    def sigma_to_t(self, sigma, quantize=None):
        """Convert sigmas to table positions: nearest index when quantizing,
        otherwise a fractional index interpolated between the two nearest entries."""
        if quantize is None:
            quantize = self.quantize

        dists = torch.abs(sigma - self.sigmas[:, None])
        if quantize:
            return torch.argmin(dists, dim=0).view(sigma.shape)
        # Two closest table entries, ordered by index.
        nearest = torch.topk(dists, dim=0, k=2, largest=False).indices
        low_idx, high_idx = torch.sort(nearest, dim=0).values
        low = self.sigmas[low_idx]
        high = self.sigmas[high_idx]
        w = ((low - sigma) / (low - high)).clamp(0, 1)
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)

    def t_to_sigma(self, t):
        """Linearly interpolate the sigma table at fractional positions `t`."""
        t = t.float()
        frac = t.frac()
        below = self.sigmas[t.floor().long()]
        above = self.sigmas[t.ceil().long()]
        return (1 - frac) * below + frac * above


class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
    """Wraps a discrete-schedule DDPM model that predicts eps (the noise) so it
    can be called as a denoiser on continuous sigmas."""

    def __init__(self, model, alphas_cumprod, quantize):
        # sigma = sqrt((1 - alpha_bar) / alpha_bar) for every table entry.
        sigma_table = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        super().__init__(sigma_table, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return (c_out, c_in): the output and input scale factors for `sigma`."""
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return -sigma, c_in

    def get_eps(self, *args, **kwargs):
        """Call the wrapped model; subclasses override this to adapt the call."""
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        """Per-sample mean squared error between predicted and true noise."""
        c_out, c_in = (append_dims(s, input.ndim) for s in self.get_scalings(sigma))
        noised = input + noise * append_dims(sigma, input.ndim)
        predicted = self.get_eps(noised * c_in, self.sigma_to_t(sigma), **kwargs)
        return (predicted - noise).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Return the denoised estimate input + c_out * eps."""
        c_out, c_in = (append_dims(s, input.ndim) for s in self.get_scalings(sigma))
        predicted = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
        return input + predicted * c_out



def make_ddim_timesteps(num_ddim_timesteps, num_ddpm_timesteps):
    """Pick evenly strided DDPM timesteps for DDIM sampling.

    The stride is num_ddpm_timesteps // num_ddim_timesteps. Every selected
    index is shifted by one so the final alpha values come out right.
    """
    stride = num_ddpm_timesteps // num_ddim_timesteps
    selected = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
    return selected + 1


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta):
    """Compute the DDIM variance schedule.

    Returns (sigmas, alphas, alphas_prev), where `alphas` are the cumulative
    alphas at the chosen timesteps and `alphas_prev` is the same sequence
    shifted by one step, starting from alphacums[0].
    Formula from https://arxiv.org/abs/2010.02502.
    """
    alphas = alphacums[ddim_timesteps]
    earlier = alphacums[ddim_timesteps[:-1]].tolist()
    alphas_prev = np.asarray([alphacums[0], *earlier])

    variance_ratio = (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
    sigmas = eta * np.sqrt(variance_ratio)
    return sigmas, alphas, alphas_prev

def makerng():
    """Seed NumPy and PyTorch from the global `seed`.

    When `seed` is 0, a random seed is drawn, applied and printed so the run
    can be reproduced. Any other value is used as the seed directly.
    """
    if seed == 0:
        # randint is inclusive on both ends and np.random.seed only accepts
        # values up to 2**32 - 1, so the upper bound must be 2**32 - 1.
        rng = random.randint(0, 2**32 - 1)
        np.random.seed(rng)
        torch.manual_seed(rng)
        print('random seed=')
        print(rng)
    else:
        np.random.seed(seed)
        torch.manual_seed(seed)

class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
    """Denoiser wrapper for CompVis diffusion models, which expose `alphas_cumprod`
    and an `apply_model` method."""

    def __init__(self, model, quantize=False, device='cpu'):
        # `device` is accepted but not used here.
        cumulative_alphas = model.alphas_cumprod
        super().__init__(model, cumulative_alphas, quantize=quantize)

    def get_eps(self, *args, **kwargs):
        """Route the eps prediction through the wrapped model's `apply_model`."""
        predict_eps = self.inner_model.apply_model
        return predict_eps(*args, **kwargs)


class CFGDenoiser(nn.Module):
    """Classifier-free guidance wrapper: runs the inner model on the
    unconditional and conditional inputs in one batch and blends the results."""

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cond_scale):
        # Duplicate the inputs so both branches run in a single call.
        doubled_x = torch.cat([x, x])
        doubled_sigma = torch.cat([sigma, sigma])
        stacked_cond = torch.cat([uncond, cond])
        out = self.inner_model(doubled_x, doubled_sigma, cond=stacked_cond)
        out_uncond, out_cond = out.chunk(2)
        return out_uncond + (out_cond - out_uncond) * cond_scale


class CompVisJIT():
    """Minimal stand-in exposing `alphas_cumprod` and `apply_model`, built from
    the module-level globals of the same names."""

    def __init__(self):
        self.apply_model = apply_model
        # The schedule is placed on the global device `cudev`.
        self.alphas_cumprod = torch.tensor(alphas_cumprod, device=cudev)


In [2]:
%%writefile /content/MultiPromptExample1.txt
intp  / 30    // intp: interpolation between prompts in 2nd line and 3rd line
Elon Musk in swimming pool
Joe biden ride on Trump

Overwriting /content/MultiPromptExample1.txt


In [3]:
%%writefile /content/MultiPromptExample2.txt
avg      // avg: mix prompt with weights, weights should not be negative
Photoshot of an elephant
8.3
Photoshot of an xenomorph
2.2

Overwriting /content/MultiPromptExample2.txt


In [4]:
%%writefile /content/MultiPromptExample3.txt
mad      // mad: calc prompt with weights, weights can be negative
Photoshot of an elephant
1.2
Photoshot of an xenomorph
-0.2

Overwriting /content/MultiPromptExample3.txt


# Super Resolution 4x<br>
Select one of these tasks: Super Resolution, txt2img, or (old LDM) infilling

In [None]:
!wget https://huggingface.co/Larvik/LDMjit/resolve/main/dm_pnnx.pt
!wget https://huggingface.co/Larvik/LDMjit/resolve/main/fsd_pnnx.pt
!wget https://huggingface.co/Larvik/LDMjit/resolve/main/alphas_cumprod.npy

import os
import sys
import time

import numpy as np
import cv2
import functools
import torch



# Cumulative alpha schedule loaded from the downloaded .npy file.
alphas_cumprod = np.load('alphas_cumprod.npy')


# Inference only: turn autograd off globally and use every CPU core for torch ops.
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
# Enable cuDNN with autotuning and allow TF32 / reduced-precision matmul paths.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True


# ======================
# Argument Parser Config
# ======================

def imread(filename, flags=cv2.IMREAD_COLOR):
    """Read an image by loading the raw file bytes and decoding them with OpenCV.

    Prints a message and exits the process when `filename` is not a file.
    """
    if os.path.isfile(filename):
        raw = np.fromfile(filename, np.int8)
        return cv2.imdecode(raw, flags)
    print(f"File does not exist: {filename}")
    sys.exit()

def preprocessing_img(img):
    """Convert a grayscale or 3-channel BGR image to BGRA; any other layout is returned unchanged."""
    # 2-D arrays and single-channel 3-D arrays are both treated as grayscale.
    if len(img.shape) < 3 or img.shape[2] == 1:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGRA)
    if img.shape[2] == 3:
        return cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
    return img


def load_image(image_path):
    """Load an image from disk and convert it to BGRA.

    Raises:
        FileNotFoundError: if `image_path` is not an existing file. The
            original printed a message and then failed on an unbound local.
    """
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f'{image_path} not found.')
    img = imread(image_path, cv2.IMREAD_UNCHANGED)
    return preprocessing_img(img)



def im2col(images, filters, stride=1, pad=0):
    """Unfold sliding windows of `images` into a 2-D column matrix.

    `images` may be 2-D, 3-D or 4-D and is reshaped to (B, C, I_h, I_w).
    `filters` is either a shape tuple or an array; only its last two sizes
    (F_h, F_w) are used. `stride` and `pad` may be scalars or (ud, lr) tuples;
    `pad` may also be the string "same".

    Returns (cols, (O_h, O_w), result_pad) where cols has shape
    (C * F_h * F_w, B * O_h * O_w) and result_pad is the un-rounded padding.

    NOTE(review): a `pad` that is not a tuple, an int or "same" leaves
    pad_ud / pad_lr unassigned and fails below with UnboundLocalError.
    """
    # Normalise images to 4-D (B, C, I_h, I_w).
    if images.ndim == 2:
        images = images.reshape(1, 1, *images.shape)
    elif images.ndim == 3:
        B, I_h, I_w = images.shape
        images = images.reshape(B, 1, I_h, I_w)
    B, C, I_h, I_w = images.shape

    # Extract the window size from either a shape tuple or a filter array.
    if isinstance(filters, tuple):
        if len(filters) == 2:
            filters = (1, 1, *filters)
        elif len(filters) == 3:
            M, F_h, F_w = filters
            filters = (M, 1, F_h, F_w)
        _, _, F_h, F_w = filters
    else:
        if filters.ndim == 2:
            filters = filters.reshape(1, 1, *filters.shape)
        elif filters.ndim == 3:
            M, F_h, F_w = filters.shape
            filters = filters.reshape(M, 1, F_h, F_w)
        _, _, F_h, F_w = filters.shape

    if isinstance(stride, tuple):
        stride_ud, stride_lr = stride
    else:
        stride_ud = stride
        stride_lr = stride
    if isinstance(pad, tuple):
        pad_ud, pad_lr = pad
    elif isinstance(pad, int):
        pad_ud = pad
        pad_lr = pad
    elif pad == "same":
        # Padding (possibly fractional) that keeps the output the same size as the input.
        pad_ud = 0.5 * ((I_h - 1) * stride_ud - I_h + F_h)
        pad_lr = 0.5 * ((I_w - 1) * stride_lr - I_w + F_w)
    pad_zero = (0, 0)

    # Number of window positions along each axis.
    O_h = int((I_h - F_h + 2 * pad_ud) // stride_ud + 1)
    O_w = int((I_w - F_w + 2 * pad_lr) // stride_lr + 1)

    # Keep the exact padding for the caller; round up for the actual np.pad call.
    result_pad = (pad_ud, pad_lr)
    pad_ud = int(np.ceil(pad_ud))
    pad_lr = int(np.ceil(pad_lr))
    pad_ud = (pad_ud, pad_ud)
    pad_lr = (pad_lr, pad_lr)
    images = np.pad(
        images, [pad_zero, pad_zero, pad_ud, pad_lr], "constant")

    # For each offset inside the window, copy the strided slice covering all positions.
    cols = np.empty((B, C, F_h, F_w, O_h, O_w))
    for h in range(F_h):
        h_lim = h + stride_ud * O_h
        for w in range(F_w):
            w_lim = w + stride_lr * O_w
            cols[:, :, h, w, :, :] = \
                images[:, :, h:h_lim:stride_ud, w:w_lim:stride_lr]

    # Rows index (channel, window offset); columns index (batch, window position).
    cols = cols.transpose(1, 2, 3, 0, 4, 5).reshape(C * F_h * F_w, B * O_h * O_w)

    return cols, (O_h, O_w), result_pad


def col2im(cols, I_shape, O_shape, stride=1, pad=0):
    """Fold a column matrix back into images, summing overlapping windows.

    `I_shape` is the target image shape (2, 3 or 4 entries; missing batch and
    channel sizes default to 1) and `O_shape` is (O_h, O_w), the number of
    window positions per axis. The window size is derived from these plus
    `stride` and `pad`. Returns an array of shape (B, C, I_h, I_w).

    NOTE(review): a `pad` that is neither a tuple nor an int leaves
    pad_ud / pad_lr unassigned and fails below with UnboundLocalError.
    """
    def get_f_shape(i, o, s, p):
        # Window size implied by input size i, output count o, stride s and padding p.
        return int(i + 2 * p - (o - 1) * s)

    if len(I_shape) == 2:
        B = C = 1
        I_h, I_w = I_shape
    elif len(I_shape) == 3:
        C = 1
        B, I_h, I_w = I_shape
    else:
        B, C, I_h, I_w = I_shape
    O_h, O_w = O_shape

    if isinstance(stride, tuple):
        stride_ud, stride_lr = stride
    else:
        stride_ud = stride
        stride_lr = stride
    if isinstance(pad, tuple):
        pad_ud, pad_lr = pad
    elif isinstance(pad, int):
        pad_ud = pad
        pad_lr = pad

    F_h = get_f_shape(I_h, O_h, stride_ud, pad_ud)
    F_w = get_f_shape(I_w, O_w, stride_lr, pad_lr)
    pad_ud = int(np.ceil(pad_ud))
    pad_lr = int(np.ceil(pad_lr))
    # Undo the (C, F_h, F_w, B, O_h, O_w) layout back to batch-first order.
    cols = cols.reshape(C, F_h, F_w, B, O_h, O_w).transpose(3, 0, 1, 2, 4, 5)
    # Canvas is larger than the image by the padding plus (stride - 1) per axis.
    images = np.zeros((B, C, I_h + 2 * pad_ud + stride_ud - 1, I_w + 2 * pad_lr + stride_lr - 1))

    # Accumulate each window offset into its strided slice; overlaps add up.
    for h in range(F_h):
        h_lim = h + stride_ud * O_h
        for w in range(F_w):
            w_lim = w + stride_lr * O_w
            images[:, :, h:h_lim:stride_ud, w:w_lim:stride_lr] += cols[:, :, h, w, :, :]

    # Crop the padding off again.
    return images[:, :, pad_ud: I_h + pad_ud, pad_lr: I_w + pad_lr]

def meshgrid(h, w):
    """Return an (h, w, 2) array whose entry [y, x] is the pair (y, x)."""
    row_ids = np.arange(0, h).reshape(h, 1, 1)
    col_ids = np.arange(0, w).reshape(1, w, 1)
    ys = np.repeat(row_ids, w, axis=1)
    xs = np.repeat(col_ids, h, axis=0)
    return np.concatenate([ys, xs], axis=-1)


def delta_border(h, w):
    """
    :param h: height
    :param w: width
    :return: normalized distance to image border,
     with min distance = 0 at border and max dist = 0.5 at image center
    """
    # Coordinates scaled so the lower-right pixel maps to (1, 1).
    scale = np.array([h - 1, w - 1]).reshape(1, 1, 2)
    coords = meshgrid(h, w) / scale
    to_top_left = np.min(coords, axis=-1, keepdims=True)
    to_bottom_right = np.min(1 - coords, axis=-1, keepdims=True)
    both = np.concatenate([to_top_left, to_bottom_right], axis=-1)
    return np.min(both, axis=-1)


def get_weighting(h, w, Ly, Lx):
    """Build per-pixel blending weights for an h x w patch, repeated for Ly * Lx patches.

    Weights are the border distances clipped to [0.01, 0.5]; the result has
    shape (1, h * w, Ly * Lx).
    """
    min_weight, max_weight = 0.01, 0.5
    clipped = np.clip(delta_border(h, w), min_weight, max_weight)
    per_pixel = clipped.reshape(1, h * w, 1)
    return per_pixel.repeat(Ly * Lx, axis=-1)


def get_fold_unfold(x, kernel_size, stride, uf=1, df=1):
    """
    Build the patch split/merge helpers for tiled processing.

    :param x: img of size (bs, c, h, w)
    :param uf: upscaling factor applied to the patches before folding
    :param df: downscaling factor applied to the patches before folding
    :return: (fold, unfold, weighting) - `unfold` is im2col bound to
     kernel_size/stride, `fold` is col2im bound to the rescaled stride, and
     `weighting` has shape (1, 1, patch_h, patch_w, Ly * Lx)
    """
    bs, nc, h, w = x.shape

    # number of crops in image
    Ly = (h - kernel_size[0]) // stride[0] + 1
    Lx = (w - kernel_size[1]) // stride[1] + 1

    unfold = functools.partial(im2col, filters=kernel_size, stride=stride)

    # Work out the stride and patch size on the output side of the model.
    if uf == 1 and df == 1:
        fold_stride = stride
        patch_h, patch_w = kernel_size[0], kernel_size[1]
    elif uf > 1 and df == 1:
        fold_stride = (stride[0] * uf, stride[1] * uf)
        patch_h, patch_w = kernel_size[0] * uf, kernel_size[1] * uf
    elif df > 1 and uf == 1:
        fold_stride = (stride[0] // df, stride[1] // df)
        patch_h, patch_w = kernel_size[0] // df, kernel_size[1] // df
    else:
        raise NotImplementedError

    fold = functools.partial(col2im, stride=fold_stride)
    weighting = get_weighting(patch_h, patch_w, Ly, Lx)
    weighting = weighting.reshape((1, 1, patch_h, patch_w, Ly * Lx))

    return fold, unfold, weighting





def normalize_image(image, normalize_type='255'):
    """
    Normalize image
    Parameters
    ----------
    image: numpy array
        The image you want to normalize
    normalize_type: string
        One of:
        - '255': divide by 255.0
        - '127.5': map to the range -1 .. 1
        - 'ImageNet': divide by 255.0, then normalize the first three
          channels by the ImageNet mean and std
        - 'None': return the input unchanged
    Returns
    -------
    normalized_image: numpy array, or None for an unrecognised type
    """
    if normalize_type == 'None':
        return image
    if normalize_type == '255':
        return image / 255.0
    if normalize_type == '127.5':
        return image / 127.5 - 1.0
    if normalize_type == 'ImageNet':
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        scaled = image / 255.0
        for ch, (m, s) in enumerate(zip(mean, std)):
            scaled[:, :, ch] = (scaled[:, :, ch] - m) / s
        return scaled
    return None



def preprocess(img):
    """Prepare an HWC image for the super-resolution model.

    Returns (c_up, c): `c` is the image scaled to [-1, 1] as a float32
    (1, C, H, W) array; `c_up` is the [0, 1] image bilinearly resized by 4x
    in the same layout.
    """
    def to_batch_chw(arr):
        # HWC -> CHW, add a leading batch axis, cast to float32.
        return np.expand_dims(arr.transpose(2, 0, 1), axis=0).astype(np.float32)

    up_f = 4
    im_h, im_w, _ = img.shape
    oh, ow = up_f * im_h, up_f * im_w

    img = normalize_image(img, normalize_type='255')

    c = to_batch_chw(img * 2 - 1)
    c_up = to_batch_chw(cv2.resize(img, (ow, oh), interpolation=cv2.INTER_LINEAR))

    return c_up, c


def postprocess(sample):
    """Convert a CHW sample in [-1, 1] to an HWC uint8 image with the channel order reversed."""
    scaled = (np.clip(sample, -1., 1.) + 1.) / 2. * 255
    hwc = np.transpose(scaled, (1, 2, 0))
    bgr = hwc[:, :, ::-1]  # RGB -> BGR
    return bgr.astype(np.uint8)


def ddim_sampling(
        models, cond):
    """Run the full DDIM loop, starting from Gaussian noise shaped like `cond`.

    Uses the module-level `ddim_timesteps`, walked from last to first, and
    returns the final float32 sample.
    """
    shape = cond.shape
    n_values = shape[0] * shape[1] * shape[2] * shape[3]
    img = np.random.randn(n_values).reshape(shape).astype(np.float32)

    timesteps = ddim_timesteps
    total_steps = timesteps.shape[0]
    reversed_steps = np.flip(timesteps)

    print(f"Running DDIM Sampling with {total_steps} timesteps")

    try:
        from tqdm import tqdm
        iterator = tqdm(reversed_steps, desc='DDIM Sampler', total=total_steps)
    except ModuleNotFoundError:
        # Plain-print progress fallback when tqdm is not installed.
        def iter_func(a):
            for i, x in enumerate(a):
                print("DDIM Sampler: %s/%s" % (i + 1, len(a)))
                yield x

        iterator = iter_func(reversed_steps)

    for i, step in enumerate(iterator):
        # index counts down through the per-step DDIM parameter arrays.
        ts = np.full((shape[0],), step, dtype=np.int64)
        img, _pred_x0 = p_sample_ddim(
            models,
            img, cond, ts,
            index=total_steps - i - 1,
        )
        img = img.astype(np.float32)

    return img


# ddim
def p_sample_ddim(
        models, x, c, t, index,
        temperature=1):
    """Take one DDIM step.

    Reads the module-level schedule arrays (ddim_alphas, ddim_alphas_prev,
    ddim_sqrt_one_minus_alphas, ddim_sigmas) at position `index`.
    Returns (x_prev, pred_x0).
    """
    e_t = apply_model(models, x, t, c)

    # Broadcast the scalar schedule values for this step over the batch.
    batch = x.shape[0]
    bshape = (batch, 1, 1, 1)
    a_t = np.full(bshape, ddim_alphas[index])
    a_prev = np.full(bshape, ddim_alphas_prev[index])
    sigma_t = np.full(bshape, ddim_sigmas[index])
    sqrt_one_minus_at = np.full(bshape, ddim_sqrt_one_minus_alphas[index])

    # current prediction for x_0
    pred_x0 = (x - sqrt_one_minus_at * e_t) / np.sqrt(a_t)

    # direction pointing to x_t
    dir_xt = np.sqrt(1. - a_prev - sigma_t ** 2) * e_t

    fresh = np.random.randn(x.size).reshape(x.shape)
    noise = sigma_t * fresh * temperature
    x_prev = np.sqrt(a_prev) * pred_x0 + dir_xt + noise

    return x_prev, pred_x0


def decode_first_stage(models, z):
    """Decode a latent with the first-stage decoder, tile by tile.

    The latent is cut into overlapping 128x128 tiles (stride 64). Each tile is
    decoded on the GPU, and the 4x-upscaled outputs are blended with border
    weights and stitched back together.
    """
    ks = (128, 128)
    stride = (64, 64)
    uf = 4

    bs, nc, h, w = z.shape

    fold, unfold, weighting = get_fold_unfold(z, ks, stride, uf=uf)

    tiles, o_shape, _ = unfold(z)  # (bn, nc * prod(**ks), L)

    # Reshape to img shape: (bn, nc, ks[0], ks[1], L)
    tiles = tiles.reshape((bs, -1, ks[0], ks[1], tiles.shape[-1])).astype(np.float32)

    print('first_stage_decode...')

    decoder = models['first_stage_decode']
    decoded_tiles = [
        decoder(torch.tensor(tiles[:, :, :, :, idx]).cuda())[0].cpu().numpy()
        for idx in range(tiles.shape[-1])
    ]

    blended = np.stack(decoded_tiles, axis=-1) * weighting  # (bn, nc, ks[0], ks[1], L)

    # Reverse reshape to img shape: (bn, nc * ks[0] * ks[1], L)
    blended = blended.reshape((blended.shape[0], -1, blended.shape[-1]))
    # stitch crops together
    decoded = fold(blended, I_shape=(1, 3, h * uf, w * uf), O_shape=o_shape)

    # Divide by the summed weights so overlapping regions average out.
    normalization = fold(weighting, I_shape=(1, 1, h * uf, w * uf), O_shape=o_shape)
    return decoded / normalization


# ddpm
def apply_model(models, x_noisy, t, cond):
    """Run the diffusion model over overlapping tiles and blend the results.

    `x_noisy` and `cond` are cut into 128x128 tiles (stride 64). Each pair is
    concatenated along the channel axis and passed to the model on the GPU,
    then the outputs are weighted and stitched back to the input size.
    """
    ks = (128, 128)
    stride = (64, 64)

    bs, nc, h, w = x_noisy.shape

    fold, unfold, weighting = get_fold_unfold(x_noisy, ks, stride)

    z, o_shape, _ = unfold(x_noisy)  # (bn, nc * prod(**ks), L)
    # Reshape to img shape: (bn, nc, ks[0], ks[1], L)
    z = z.reshape((bs, -1, ks[0], ks[1], z.shape[-1]))

    c, *_ = unfold(cond)
    c = c.reshape((bs, -1, ks[0], ks[1], c.shape[-1]))  # (bn, nc, ks[0], ks[1], L)

    # apply model by loop over crops
    diffusion_model = models["diffusion_model"]
    outputs = []
    for idx in range(z.shape[-1]):
        tile = z[:, :, :, :, idx]
        tile_cond = c[:, :, :, :, idx]
        xc = np.concatenate([tile, tile_cond], axis=1).astype(np.float32)
        result = diffusion_model(torch.tensor(xc).cuda(), torch.tensor(t).cuda())
        outputs.append(result[0].cpu().numpy())

    o = np.stack(outputs, axis=-1) * weighting

    # Reverse reshape to img shape: (bn, nc * ks[0] * ks[1], L)
    o = o.reshape((o.shape[0], -1, o.shape[-1]))
    # stitch crops together, dividing by the summed weights
    normalization = fold(weighting, I_shape=(1, 1, h, w), O_shape=o_shape)
    stitched = fold(o, I_shape=(1, 3, h, w), O_shape=o_shape)

    return stitched / normalization


def predict(models, img):
    """Super-resolve a BGR image: preprocess, DDIM-sample, decode, and convert back to a uint8 BGR image."""
    rgb = img[:, :, ::-1]  # BGR -> RGB
    _, cond = preprocess(rgb)
    latent = ddim_sampling(models, cond)
    decoded = decode_first_stage(models, latent)
    return postprocess(decoded[0])

# TorchScript modules for the super-resolution task, loaded once, put in
# eval mode and moved to the GPU.
models = dict(
    first_stage_decode=torch.jit.load('/content/fsd_pnnx.pt').eval().cuda(),
    diffusion_model=torch.jit.load('/content/dm_pnnx.pt').eval().cuda(),
)


In [None]:

"""
ddim_timesteps
"""
# DDIM configuration: 100 sampling steps drawn from a 1000-step DDPM schedule.
ddim_eta = 1.0
ddim_num_steps = 100
ddpm_num_timesteps = 1000
ddim_timesteps = make_ddim_timesteps(ddim_num_steps, ddpm_num_timesteps)

"""
ddim sampling parameters
"""

# These module-level arrays are read by ddim_sampling / p_sample_ddim.
ddim_sigmas, ddim_alphas, ddim_alphas_prev = \
    make_ddim_sampling_parameters(
        alphacums=alphas_cumprod,
        ddim_timesteps=ddim_timesteps,
        eta=ddim_eta)

ddim_sqrt_one_minus_alphas = np.sqrt(1. - ddim_alphas)

# Input images to upscale.
inputz=['/content/sample_data/zkrp.jpg']

 

for image_path in inputz:
    print(image_path)

    # prepare input data
    img = load_image(image_path)
    # load_image returns BGRA; drop the alpha channel again.
    img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

    # inference
    print('Start inference...')
    
    img = predict(models, img)

    # plot result
    # NOTE: the output path is fixed, so each input overwrites the previous result.
    savepath = '/content/sample_data/out.png'
    print(f'saved at : {savepath}')
    cv2.imwrite(savepath, img)

print('Script finished successfully.')

# txt2img

In [None]:
jit=False #@param {type:'boolean'}

import os

# Install dependencies and fetch the exported model files only when the
# autoencoder file is not already in the working directory.
if not os.path.isfile('autoencoder_pnnx.pt'):
  !pip install transformers
  !wget https://huggingface.co/Larvik/tempsd/resolve/main/alphas_cumprod.npz
  !wget https://huggingface.co/Larvik/temp1/resolve/main/transformer_pnnx.pt
  !wget https://huggingface.co/Larvik/temp1/resolve/main/diffusion_emb_pnnx.pt
  !wget https://huggingface.co/Larvik/temp1/resolve/main/diffusion_mid_pnnx.pt
  !wget https://huggingface.co/Larvik/temp1/resolve/main/diffusion_out_pnnx.pt
  !wget https://huggingface.co/Larvik/temp1/resolve/main/autoencoder_pnnx.pt
  !wget https://huggingface.co/Larvik/temp1/resolve/main/imgencoder_pnnx.pt

import sys
import time
import random
from threading import Thread
import numpy as np
import cv2
from PIL import Image
import PIL
from transformers import CLIPTokenizer

import torch

# Inference only: turn autograd off globally and use every CPU core for torch ops.
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
# Enable cuDNN with autotuning and allow TF32 / reduced-precision matmul paths.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

# Cumulative alpha schedule, stored under key 'a' of the downloaded .npz archive.
alphas_cumprod = np.load('alphas_cumprod.npz')['a']

# Device handle used for the txt2img models.
cudev=torch.device('cuda')





# encoder
class BERTEmbedder:
    """Text encoder: tokenizes with the pretrained CLIP tokenizer from
    Hugging Face and runs the tokens through the supplied transformer."""

    def __init__(self, transformer, max_length=77):
        self.transformer = transformer
        self.max_length = max_length
        self.tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')

    def encode(self, text, nsamp):
        """Encode `text`, repeating the token row `nsamp` times along the batch axis.

        Returns the transformer output moved to the GPU.
        """
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"]
        # Copy via numpy, then broadcast to one row per sample.
        batch_ids = torch.tensor(token_ids.numpy()).expand(nsamp, -1)
        embedding = self.transformer(batch_ids)
        return embedding.cuda()




# Module-level image slots, unset by default; revpreimg is read by revpre1.
preimg=None
revpreimg=None




# ddpm
def apply_model(x, t, cond):
    """Run the diffusion network, which is split across three modules.

    ``diffusion_emb`` returns a hidden state, an embedding and a sequence of
    intermediate tensors. Entries from index 6 onward of that sequence are
    passed to ``diffusion_mid`` and the first six to ``diffusion_out``.

    Args:
        x: Input passed to the first module (presumably the noisy latents).
        t: Second input to the first module (presumably the timestep).
        cond: Conditioning tensor, forwarded to all three modules.

    Returns:
        The output of ``diffusion_out``.
    """
    hidden, t_emb, skips = diffusion_emb(x, t, cond)
    late_skips, early_skips = skips[6:], skips[:6]
    hidden = diffusion_mid(hidden, t_emb, cond, *late_skips)
    return diffusion_out(hidden, t_emb, cond, *early_skips)


# decoder
def decode_first_stage(z):
    """Decode ``z`` with the autoencoder after dividing out the 0.18215 factor.

    Args:
        z: Tensor to decode (the sampler output in this notebook).

    Returns:
        Whatever ``autoencoder`` returns for ``z / 0.18215``.
    """
    unscaled = z/0.18215
    return autoencoder(unscaled)


    
def load_img(path):
    """Load an image file as a float tensor of shape (1, 3, H, W) in [-1, 1].

    The image is converted to RGB and, when needed, resized with Lanczos
    resampling so that both sides are multiples of 32 (rounded down).

    Args:
        path: Path of the image file to open.

    Returns:
        A torch tensor with values scaled from [0, 255] to [-1, 1].
    """
    image = Image.open(path).convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from {path}")

    # Snap each side down to the nearest multiple of 32.
    snapped = (w - w % 32, h - h % 32)
    if snapped != (w, h):
        image = image.resize(snapped, resample=PIL.Image.LANCZOS)

    pixels = np.array(image).astype(np.float32) / 255.0
    # HWC -> 1xCxHxW
    batched = np.transpose(pixels[None], (0, 3, 1, 2))
    tensor = torch.from_numpy(batched)
    return 2.*tensor - 1.


def intptxtemb(t1,t2,step):
  """Return `step` embeddings linearly interpolated from prompt t1 to prompt t2.

  Both prompts are encoded with the global `cond_stage_model` for `n_samples`
  rows. Entry i of the result is (c2*i + c1*(step-1-i)) / (step-1), so the
  first entry is the embedding of t1 and the last one that of t2.

  Args:
    t1: Start prompt.
    t2: End prompt.
    step: Number of embeddings to produce.

  Returns:
    A list of `step` tensors. `step == 1` returns just the t1 embedding and
    `step <= 0` returns an empty list.
  """
  c1=cond_stage_model.encode(t1, n_samples)
  c2=cond_stage_model.encode(t2, n_samples)
  if step == 1:
    # The general formula would divide by step-1 == 0 and yield NaNs.
    return [c1]
  stp=step-1
  return [(c2*i+c1*(stp-i))/stp for i in range(step)]

def makeCs(prmt):
  """Build the list of conditioning tensors for a prompt or a prompt file.

  If `prmt` does not end with '.txt' it is encoded directly. Otherwise the
  file is read; its first line is a command (spaces and tabs are ignored,
  fields are separated by '/'):

    * 'intp/N'  - lines 2 and 3 are two prompts; returns N interpolated
                  embeddings (see `intptxtemb`).
    * 'avg...'  - remaining lines are prompt / weight pairs; weights are
                  normalised to sum to 1 before mixing.
    * otherwise - same pairs, mixed with the weights exactly as written.

  Args:
    prmt: Prompt text, or path to a '.txt' prompt file.

  Returns:
    A list of conditioning tensors: one per interpolation step for 'intp',
    otherwise a single-element list.

  Raises:
    ValueError: If the file is empty, has no prompt/weight pair, a weight is
      not a number, or the weights sum to zero under 'avg'.
  """
  if not prmt.endswith('.txt'):
    return [cond_stage_model.encode(prmt, n_samples)]

  with open(prmt,'rt') as f:
    stz=f.read().splitlines()
  if not stz:
    raise ValueError('prompt file %r is empty' % (prmt,))

  cmd=stz[0].replace(' ','').replace('\t','').split('/')
  cmd0=cmd[0]
  if cmd0.startswith('intp'):
    return intptxtemb(stz[1],stz[2],int(cmd[1]))

  body=stz[1:]
  prmpl=len(body)>>1
  if prmpl == 0:
    raise ValueError('prompt file %r contains no prompt/weight pair' % (prmt,))

  # Parse the cheap part (weights) first so a malformed file fails before
  # any prompt is run through the text encoder.
  pwgt=[float(body[2*i+1]) for i in range(prmpl)]
  if cmd0.startswith('avg'):
    wgtsum=sum(pwgt)
    if wgtsum == 0:
      raise ValueError('weights in prompt file %r sum to zero' % (prmt,))
    pwgt=[w/wgtsum for w in pwgt]

  # encode() requires the batch size; every prompt is encoded for n_samples rows.
  ptxt=[cond_stage_model.encode(body[2*i], n_samples) for i in range(prmpl)]

  cout=ptxt[0]*pwgt[0]
  for emb, w in zip(ptxt[1:], pwgt[1:]):
    cout+=emb*w
  return [cout]



# Output name suffix: _<iteration>x<image index>v<conditioning counter>.png
fext='_%dx%dv%d.png'
def saver():
  """Write the latest latents and decoded images to disk.

  Reads the globals `iita`, `ktta`, `outputp`, `samples` and `x_samples`.
  The latents go to a '.npy' file (image index fixed to 1 in the name) and
  every decoded image to its own PNG. `x_samples` is rebound to the clipped
  numpy array in [0, 1] as a side effect.
  """
  global x_samples
  i=iita
  latent_path = (outputp+fext%(i,1,ktta))[:-4] + '.npy'
  np.save(latent_path, samples)
  # Map decoder output from [-1, 1] to [0, 1].
  x_samples = np.clip((x_samples.numpy() + 1.0) / 2.0, a_min=0.0, a_max=1.0)
  for k, x_sample in enumerate(x_samples):
      hwc = x_sample.transpose(1, 2, 0) * 255  # CHW -> HWC, scaled to 0..255
      bgr = hwc.astype(np.uint8)[:, :, ::-1]  # RGB -> BGR, the order cv2 writes
      cv2.imwrite(outputp+fext%(i,k,ktta), bgr)
  

# Default sampler; the optional cell further down may rebind it.
UseSamplr=sample_lms

def predict(prompt, uc):
    """Sample, decode and save images for every conditioning built from `prompt`.

    Configuration is read from module globals (`ddim_num_steps`, `shape`,
    `strength`, `cfg_scale`, `preimg`, `UseSamplr`, ...). The same starting
    latent is reused for every conditioning returned by `makeCs`. Results are
    published through the globals `samples`, `x_samples` and `ktta`, and each
    one is written to disk by `saver` on a background thread.

    NOTE(review): `saver` reads those globals rather than receiving arguments,
    so it relies on finishing before the next result overwrites them.

    Args:
        prompt: Prompt text or path to a '.txt' prompt file (see `makeCs`).
        uc: Unconditional embedding passed to the sampler as 'uncond'.

    Returns:
        None.
    """
    global x_samples
    global samples
    global ktta
    global noise

    c_list = makeCs(prompt)

    sigmas = model_wrap.get_sigmas(ddim_num_steps)
    noise = torch.randn(shape, dtype=torch.float,device=cudev)
    if preimg is not None:
        # img2img: start part-way down the schedule. Clamp t_enc so the start
        # index stays in 0..ddim_num_steps-1. Without the clamp, strength >= 1
        # gives index -1, which wraps to the end of the schedule (no denoising
        # steps at all), and a negative strength indexes past the end.
        t_enc = max(0, min(int(strength * ddim_num_steps), ddim_num_steps - 1))
        start = ddim_num_steps - t_enc - 1
        img = preimg.cuda() + noise * sigmas[start]
        sigma_sched = sigmas[start:]
    else:
        # txt2img: pure noise at the highest sigma, full schedule.
        img = noise*sigmas[0]
        sigma_sched = sigmas

    ktta=0
    for c in c_list:
        extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
        with torch.cuda.amp.autocast(dtype=torch.float16):
            samples = UseSamplr(model_wrap_cfg, img, sigma_sched, extra_args=extra_args, disable=False)
        ktta+=1

        x_samples = decode_first_stage(  samples ).cpu()
        samples=samples.cpu()
        # Save in the background so the next sampling run can start right away.
        Thread(target = saver).start()

    return


def init_img_type():
  """Classify the global `init_img` path by its suffix.

  Returns:
    0  - `init_img` is a latent file: it ends with 'npy', or it ends with
         'jpg'/'png' and a sibling '<init_img>.npy' exists (in which case
         `init_img` is rebound to that file).
    1  - `init_img` ends with 'jpg' or 'png' and has no '.npy' sibling.
    99 - anything else.
  """
  global init_img
  if init_img.endswith('npy'):
    return 0
  if not init_img.endswith(('jpg', 'png')):
    return 99
  cached = init_img+'.npy'
  if not os.path.isfile(cached):
    return 1
  # A cached latent exists next to the image: use it instead.
  init_img = cached
  return 0



# Load the exported TorchScript modules.
# The text transformer and the image encoder are not moved to the GPU here;
# the three diffusion modules run in half precision on the GPU and the
# autoencoder runs on the GPU without the .half() cast.
cond_stage_model = BERTEmbedder(torch.jit.load('transformer_pnnx.pt').eval())
diffusion_emb = torch.jit.load('diffusion_emb_pnnx.pt').eval().half().cuda()
diffusion_mid = torch.jit.load('diffusion_mid_pnnx.pt').eval().half().cuda()
diffusion_out = torch.jit.load('diffusion_out_pnnx.pt').eval().half().cuda()
autoencoder = torch.jit.load('autoencoder_pnnx.pt').eval().cuda()
imgenc = torch.jit.load('imgencoder_pnnx.pt').eval()


# Denoiser wrappers handed to the sampler in predict(); CompVisDenoiser,
# CompVisJIT and CFGDenoiser are defined outside this section.
model_wrap = CompVisDenoiser(CompVisJIT())
model_wrap_cfg = CFGDenoiser(model_wrap)

👇Optional👇

In [None]:
# Optional sampler override. The string below only lists the sampler names
# that can be assigned to UseSamplr; it has no effect when executed.
'''
sample_euler
sample_euler_ancestral
sample_heun
sample_dpm_2
sample_dpm_2_ancestral
sample_lms
'''
# predict() calls whatever sampler is bound to UseSamplr at that time.
UseSamplr = sample_heun


In [None]:
# Optional img2img setup: turn `init_img` into latents (`preimg`) and derive
# n_samples / H / W from it.
init_img='xxx' #@param {type:'string'}
initymgtyp=init_img_type()
if initymgtyp == 0:
  # Saved latents: the batch size comes from dim 0, and the pixel size is
  # the latent height/width multiplied by 8.
  preimg=torch.tensor(np.load(init_img), device='cpu')
  n_samples=preimg.size(0)
  H=preimg.size(2)<<3
  W=preimg.size(3)<<3
elif initymgtyp == 1:
  # Image file: encode it once with the image encoder (fed the image plus a
  # random tensor of latent shape), scale by 0.18215, and cache the result
  # as '<init_img>.npy' so init_img_type() picks the latents up next time.
  n_samples=1
  rpt=load_img(init_img)
  H=rpt.size(2)
  W=rpt.size(3)
  preimg=imgenc(  rpt, torch.randn(torch.Size([n_samples,4,H>>3,W>>3]))  )*0.18215
  np.save(init_img+'.npy',preimg.numpy())
else:
  # Unrecognised suffix: no init image, predict() starts from pure noise.
  preimg=None

infilling

In [None]:
# Optional infilling setup. Loads the mask and moves the init latents from
# `preimg` to `revpreimg`, so predict() no longer takes its img2img branch.
# NOTE(review): `revpre1` and the code that consumes `zamask`, `revpreimg`
# and `revpre` are defined outside this section -- confirm against them.
zamask=np.load('bench2_mask.npy')
revpreimg=preimg
preimg=None
revpre=revpre1

☝️Optional☝️

In [None]:
# Run wpa() on a background thread (True) or inline in this cell (False).
InThread=False #@param {type:'boolean'}

# Prompt text, or a path ending in '.txt' to a prompt file (see makeCs).
prompt = 'a photograph of an astronaut riding a horse' #@param {type:'string'}

# Number of times predict() is called.
n_iter = 1 #@param {type:'integer'}
# Batch size and image size are only taken from the form when no init image
# was set up; otherwise the values derived from the init latents are kept.
if preimg is None and revpreimg is None:
  n_samples = 1 #@param {type:'integer'}
  H=704 #@param {type:'integer'}
  W=768 #@param {type:'integer'}



# Number of sampling steps requested from the sigma schedule in predict().
ddim_num_steps = 50  #@param {type:'integer'}
ddpm_num_timesteps = 1000

# NOTE(review): presumably consumed by makerng() below -- confirm.
seed=0 #@param {type:'integer'}

outputp='/content/sample_data' #@param {type:'string'}



# Used by predict() only when an init image is set: the fraction of the step
# schedule that is re-run on top of the noised init latents.
strength=0.5 #@param {type:'number'}



# Passed to the sampler as 'cond_scale'.
cfg_scale = 7.5 #@param {type:'number'}
ddim_eta = 0  #@param {type:'integer'}


# Output file prefix: '<dir>/<number of entries currently in dir>', so each
# run gets a new numeric prefix as long as earlier outputs stay in place.
outputp=outputp+'/'+str(len(os.listdir(outputp)))
"""
ddim_timesteps
"""

# NOTE(review): the ddim_* values below are not referenced by the code visible
# in this section (predict() uses model_wrap.get_sigmas) -- possibly leftover.
ddim_timesteps = make_ddim_timesteps(
    ddim_num_steps, ddpm_num_timesteps)

"""
ddim sampling parameters
"""

ddim_sigmas, ddim_alphas, ddim_alphas_prev = \
    make_ddim_sampling_parameters(
        alphacums=alphas_cumprod,
        ddim_timesteps=ddim_timesteps,
        eta=ddim_eta)

ddim_sqrt_one_minus_alphas = np.sqrt(1. - ddim_alphas)

# Latent shape: 4 channels at 1/8 of the pixel resolution.
shape = [n_samples, 4, H>>3 , W>>3 ]


makerng()


print("prompt: %s" % prompt)



print('Start inference...')
# Embedding of the empty prompt, passed to the sampler as 'uncond'; it is
# left as None when cfg_scale is exactly 1.0.
uc = None
if cfg_scale != 1.0:
  uc = cond_stage_model.encode([''], n_samples)


  
def wpa():
  """Run `n_iter` prediction passes for the global prompt.

  The loop counter is kept in the global `iita` because `saver` reads it when
  building output file names. CUDA's cached memory is released at the end.
  """
  global iita
  # wpa() may run on its own thread (see InThread), so autograd is switched
  # off again here instead of relying on the module-level call.
  torch.set_grad_enabled(False)

  for iita in range(n_iter):
      print("iteration: %s" % (iita + 1))
      predict(prompt, uc)

  print('Script finished successfully.')
  torch.cuda.empty_cache()


# Start the run: on a background thread so the cell returns immediately,
# or inline so the cell blocks until every iteration has been issued.
if InThread:
  t1 = Thread(target = wpa)
  # Thread.start() returns None, so a1 carries no information.
  a1 = t1.start()
else:
  wpa()
