# Common
Always run this cell whenever you start or restart the runtime.

In [None]:
import math
import torch
from torch import nn
from scipy import integrate
from threading import Thread

from PIL import Image
import numpy as np
from tqdm.auto import trange, tqdm

def append_dims(x, target_dims):
    """Return ``x`` with trailing singleton axes added until it has ``target_dims`` dims.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing < 0:
        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
    # Ellipsis keeps every existing axis; each None inserts one new trailing axis.
    index = (Ellipsis,) + (None,) * missing
    return x[index]


def append_zero(x):
    """Return ``x`` with a single zero concatenated at the end.

    The zero is created with ``x.new_zeros`` so it shares ``x``'s dtype and device.
    """
    tail = x.new_zeros([1])
    return torch.cat([x, tail])


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cuda'):
    """Constructs the noise schedule of Karras et al. (2022).

    Returns ``n`` sigmas running from ``sigma_max`` down to ``sigma_min``
    (interpolated linearly in ``sigma ** (1 / rho)`` space) followed by a
    trailing zero, on ``device``.
    """
    inv_rho = 1 / rho
    lo = sigma_min ** inv_rho
    hi = sigma_max ** inv_rho
    ramp = torch.linspace(0, 1, n, device=device)
    # ramp == 0 yields sigma_max, ramp == 1 yields sigma_min
    sigmas = (hi + ramp * (lo - hi)) ** rho
    return append_zero(sigmas).to(device)


def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Constructs an exponential noise schedule.

    ``n`` sigmas are spaced evenly in log space from ``sigma_max`` down to
    ``sigma_min``; a zero is appended at the end.
    """
    log_hi = math.log(sigma_max)
    log_lo = math.log(sigma_min)
    log_sigmas = torch.linspace(log_hi, log_lo, n, device=device)
    return append_zero(log_sigmas.exp())


def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Constructs a continuous VP noise schedule.

    Evaluates ``sqrt(exp(beta_d * t**2 / 2 + beta_min * t) - 1)`` at ``n``
    values of ``t`` spaced linearly from 1 down to ``eps_s`` and appends a zero.
    """
    t = torch.linspace(1, eps_s, n, device=device)
    exponent = beta_d * t ** 2 / 2 + beta_min * t
    sigmas = torch.sqrt(torch.exp(exponent) - 1)
    return append_zero(sigmas)


def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative.

    Computes ``(x - denoised) / sigma`` with ``sigma`` given trailing
    singleton axes so that it broadcasts against ``x``.
    """
    sigma_b = append_dims(sigma, x.ndim)
    residual = x - denoised
    return residual / sigma_b


def get_ancestral_step(sigma_from, sigma_to):
    """Split an ancestral sampling step into a deterministic and a noisy part.

    Returns:
        ``(sigma_down, sigma_up)``: the noise level to step down to and the
        amount of fresh noise to add afterwards. By construction
        ``sigma_down ** 2 + sigma_up ** 2 == sigma_to ** 2``.
    """
    var_from = sigma_from ** 2
    var_to = sigma_to ** 2
    sigma_up = (var_to * (var_from - var_to) / var_from) ** 0.5
    sigma_down = (var_to - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up


@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    ``model`` is called as ``model(i, revpre(x, sigmas, i), sigma, **extra_args)``,
    where ``revpre`` is the module-level pre-processing hook. ``callback``, when
    given, receives a dict with the keys 'x', 'i', 'sigma', 'sigma_hat' and
    'denoised' once per step.
    """
    if extra_args is None:
        extra_args = {}
    n_steps = len(sigmas) - 1
    ones = x.new_ones([x.shape[0]])
    for i in trange(n_steps, disable=disable):
        sigma_cur = sigmas[i]
        if s_tmin <= sigma_cur <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Noise is drawn on every step, even when gamma == 0, so the RNG stream is unchanged.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma_cur * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigma_cur ** 2) ** 0.5
        denoised = model(i, revpre(x, sigmas, i), sigma_hat * ones, **extra_args)
        # The derivative uses the raw x, not the revpre-processed input.
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma_cur, 'sigma_hat': sigma_hat, 'denoised': denoised})
        # Euler method
        x = x + d * (sigmas[i + 1] - sigma_hat)
    return x


@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Ancestral sampling with Euler method steps.

    Each step takes an Euler step down to ``sigma_down`` and then adds fresh
    Gaussian noise scaled by ``sigma_up`` (both from ``get_ancestral_step``).
    """
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[i]
        denoised = model(i, revpre(x, sigmas, i), sigma_cur * ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma_cur, sigmas[i + 1])
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma_cur, 'sigma_hat': sigma_cur, 'denoised': denoised})
        d = to_d(x, sigma_cur, denoised)
        # Euler method
        x = x + d * (sigma_down - sigma_cur)
        # Re-inject noise for the ancestral part of the step.
        x = x + torch.randn_like(x) * sigma_up
    return x


@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    The final step (where the next sigma is 0) falls back to a plain Euler
    step; every other step averages the derivative at both ends of the step.
    Only the first model call of a step goes through the ``revpre`` hook.
    """
    if extra_args is None:
        extra_args = {}
    n_steps = len(sigmas) - 1
    ones = x.new_ones([x.shape[0]])
    for i in trange(n_steps, disable=disable):
        sigma_cur, sigma_next = sigmas[i], sigmas[i + 1]
        if s_tmin <= sigma_cur <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma_cur * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigma_cur ** 2) ** 0.5
        denoised = model(i, revpre(x, sigmas, i), sigma_hat * ones, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma_cur, 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigma_next - sigma_hat
        if sigma_next == 0:
            # Euler method
            x = x + d * dt
            continue
        # Heun's method: predictor step, then average the two slopes.
        x_2 = x + d * dt
        denoised_2 = model(i, x_2, sigma_next * ones, **extra_args)
        d_2 = to_d(x_2, sigma_next, denoised_2)
        x = x + (d + d_2) / 2 * dt
    return x


@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022).

    Each step evaluates the model a second time at a midpoint sigma and uses
    that derivative for the full step. Only the first model call of a step
    goes through the ``revpre`` hook.
    """
    if extra_args is None:
        extra_args = {}
    n_steps = len(sigmas) - 1
    ones = x.new_ones([x.shape[0]])
    for i in trange(n_steps, disable=disable):
        sigma_cur, sigma_next = sigmas[i], sigmas[i + 1]
        if s_tmin <= sigma_cur <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma_cur * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigma_cur ** 2) ** 0.5
        denoised = model(i, revpre(x, sigmas, i), sigma_hat * ones, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma_cur, 'sigma_hat': sigma_hat, 'denoised': denoised})
        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma_hat ** (1 / 3) + sigma_next ** (1 / 3)) / 2) ** 3
        x_mid = x + d * (sigma_mid - sigma_hat)
        denoised_mid = model(i, x_mid, sigma_mid * ones, **extra_args)
        d_mid = to_d(x_mid, sigma_mid, denoised_mid)
        x = x + d_mid * (sigma_next - sigma_hat)
    return x


@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Ancestral sampling with DPM-Solver inspired second-order steps.

    Each step takes a midpoint step down to ``sigma_down`` and then adds fresh
    Gaussian noise scaled by ``sigma_up`` (both from ``get_ancestral_step``).
    """
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma_cur = sigmas[i]
        denoised = model(i, revpre(x, sigmas, i), sigma_cur * ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma_cur, sigmas[i + 1])
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma_cur, 'sigma_hat': sigma_cur, 'denoised': denoised})
        d = to_d(x, sigma_cur, denoised)
        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma_cur ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
        x_mid = x + d * (sigma_mid - sigma_cur)
        denoised_mid = model(i, x_mid, sigma_mid * ones, **extra_args)
        d_mid = to_d(x_mid, sigma_mid, denoised_mid)
        x = x + d_mid * (sigma_down - sigma_cur)
        # Re-inject noise for the ancestral part of the step.
        x = x + torch.randn_like(x) * sigma_up
    return x


def linear_multistep_coeff(order, t, i, j):
    """Integrate the ``j``-th Lagrange basis polynomial over ``[t[i], t[i + 1]]``.

    The basis is built on the ``order`` nodes ``t[i], t[i - 1], ...,
    t[i - order + 1]``; the integral is evaluated numerically with
    ``scipy.integrate.quad``.

    Raises:
        ValueError: if fewer than ``order`` nodes are available at step ``i``.
    """
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')

    node_j = t[i - j]
    other_nodes = [t[i - k] for k in range(order) if k != j]

    def basis(tau):
        value = 1.
        for node in other_nodes:
            value *= (tau - node) / (node_j - node)
        return value

    integral = integrate.quad(basis, t[i], t[i + 1], epsrel=1e-4)
    return integral[0]

# Latent snapshots collected by revpre0_log, in call order.
latlog = []


def revpre0_log(img, sigmas, t):
    """Pass-through hook that also records ``img - noise * sigmas[t]`` in ``latlog``.

    ``noise`` is read from module globals; it is not defined in this block.
    """
    snapshot = img - noise * sigmas[t]
    latlog.append(snapshot.cpu().numpy())
    return img


def revpre0(img, sigmas, t):
    """Pass-through hook: return ``img`` unchanged."""
    return img


def revpre1(img, sigmas, t):
    """Blend a re-noised reference image with ``img`` using a mask.

    Where ``zamask`` is 0 the result is ``revpreimg + noise * sigmas[t]``;
    where it is 1 the result is ``img``. ``revpreimg``, ``noise`` and
    ``zamask`` are module globals that are not defined in this block.
    """
    renoised = revpreimg + noise * sigmas[t]
    return renoised * (1 - zamask) + img * zamask


# Hook applied by the samplers to x before each primary model call.
revpre = revpre0

@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Linear multistep sampler using up to ``order`` previous derivatives.

    ``model`` is called as ``model(i, revpre(x, sigmas, i), sigma, **extra_args)``.
    The effective order ramps up from 1 on the first step to ``order`` once
    enough derivatives have been collected. ``callback``, when given, receives
    a dict with the keys 'x', 'i', 'sigma', 'sigma_hat' and 'denoised' once
    per step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # The coefficient integration runs on the host. Copy the schedule over once
    # instead of once per coefficient per step.
    sigmas_cpu = sigmas.cpu()
    ds = []
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(i, revpre(x, sigmas, i), sigmas[i] * s_in, **extra_args)
        d = to_d(x, sigmas[i], denoised)
        ds.append(d)
        if len(ds) > order:
            # keep only the most recent `order` derivatives
            ds.pop(0)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        cur_order = min(i + 1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigmas_cpu, i, j) for j in range(cur_order)]
        # coeffs[0] pairs with the newest derivative, hence reversed(ds)
        x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
    return x


@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
    """Estimate the log-likelihood of ``x`` by integrating an ODE from ``sigma_min`` to ``sigma_max``.

    Returns:
        A tuple ``(ll, info)`` where ``ll`` is ``ll_prior + delta_ll`` (one
        value per batch element) and ``info`` is ``{'fevals': <count>}``, the
        number of ``ode_fn`` evaluations.

    NOTE(review): ``odeint`` is not imported anywhere in the visible source, so
    calling this raises NameError unless it is imported elsewhere (it looks
    like torchdiffeq's ``odeint`` -- confirm).
    NOTE(review): ``model`` is called here as ``model(x, sigma, ...)``, without
    the leading step index that the samplers in this file pass -- confirm which
    wrapper is expected.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    # Random +/-1 probe vector, drawn once and reused by every ode_fn call.
    v = torch.randint_like(x, 2) * 2 - 1
    fevals = 0
    def ode_fn(sigma, x):
        nonlocal fevals
        # The outer no_grad is overridden here because autograd.grad is needed below.
        with torch.enable_grad():
            x = x[0].detach().requires_grad_()
            denoised = model(x, sigma * s_in, **extra_args)
            d = to_d(x, sigma, denoised)
            fevals += 1
            grad = torch.autograd.grad((d * v).sum(), x)[0]
            # v . grad, summed over all non-batch axes
            d_ll = (v * grad).flatten(1).sum(1)
        return d.detach(), d_ll
    # ODE state: (sample, accumulated log-likelihood change starting at zero)
    x_min = x, x.new_zeros([x.shape[0]])
    t = x.new_tensor([sigma_min, sigma_max])
    sol = odeint(ode_fn, x_min, t, atol=atol, rtol=rtol, method='dopri5')
    latent, delta_ll = sol[0][-1], sol[1][-1]
    ll_prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
    return ll_prior + delta_ll, {'fevals': fevals}



class DiscreteSchedule(nn.Module):
    """Maps between continuous noise levels (sigmas) and positions in a fixed
    table of discrete noise levels."""

    def __init__(self, sigmas, quantize):
        super().__init__()
        self.quantize = quantize
        # Registered as a buffer so the table follows the module across devices.
        self.register_buffer('sigmas', sigmas)

    def get_sigmas(self, n=None):
        """Return a sampling schedule that ends in a trailing zero.

        With ``n=None`` the whole table is returned in reversed order;
        otherwise ``n`` sigmas are interpolated at evenly spaced positions
        from the last table index down to 0.
        """
        if n is not None:
            last = len(self.sigmas) - 1
            positions = torch.linspace(last, 0, n, device=self.sigmas.device)
            return append_zero(self.t_to_sigma(positions))
        return append_zero(self.sigmas.flip(0))

    def sigma_to_t(self, sigma, quantize=None):
        """Map ``sigma`` to a table position.

        When quantizing, the index of the nearest table entry is returned.
        Otherwise the result interpolates linearly between the two nearest
        entries, clamped to that interval.
        """
        if quantize is None:
            quantize = self.quantize

        table = self.sigmas
        dists = torch.abs(sigma - table[:, None])
        if quantize:
            return torch.argmin(dists, dim=0).view(sigma.shape)
        nearest_two = torch.topk(dists, dim=0, k=2, largest=False).indices
        low_idx, high_idx = torch.sort(nearest_two, dim=0)[0]
        low = table[low_idx]
        high = table[high_idx]
        w = ((low - sigma) / (low - high)).clamp(0, 1)
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)

    def t_to_sigma(self, t):
        """Linearly interpolate the table at the (possibly fractional) positions ``t``."""
        t = t.float()
        lower = t.floor().long()
        upper = t.ceil().long()
        frac = t.frac()
        return (1 - frac) * self.sigmas[lower] + frac * self.sigmas[upper]


class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
    """Wraps a discrete-schedule DDPM model that outputs eps (the predicted
    noise) so that it behaves as a denoiser."""

    def __init__(self, model, alphas_cumprod, quantize):
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        super().__init__(sigmas, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return ``(c_out, c_in)``: the output and input scale factors for ``sigma``."""
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return -sigma, c_in

    def get_eps(self, *args, **kwargs):
        """Forward all arguments to the wrapped model and return its output."""
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        """Per-sample mean squared error between the predicted eps and ``noise``."""
        ndim = input.ndim
        c_out, c_in = (append_dims(s, ndim) for s in self.get_scalings(sigma))
        noised_input = input + noise * append_dims(sigma, ndim)
        eps = self.get_eps(noised_input * c_in, self.sigma_to_t(sigma), **kwargs)
        squared_error = (eps - noise).pow(2)
        return squared_error.flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Return the denoised estimate ``input + eps * c_out`` (``c_out`` is ``-sigma``)."""
        ndim = input.ndim
        c_out, c_in = (append_dims(s, ndim) for s in self.get_scalings(sigma))
        eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
        return input + eps * c_out



def make_ddim_timesteps(num_ddim_timesteps, num_ddpm_timesteps):
    """Pick evenly strided DDPM timesteps for DDIM sampling.

    The stride is ``num_ddpm_timesteps // num_ddim_timesteps``; every selected
    index is shifted up by one before being returned as a NumPy array.
    """
    stride = num_ddpm_timesteps // num_ddim_timesteps
    picked = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    return picked + 1


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta):
    """Compute the DDIM variance schedule for the selected timesteps.

    Returns:
        ``(sigmas, alphas, alphas_prev)`` where ``alphas`` are the cumulative
        alphas at ``ddim_timesteps`` and ``alphas_prev`` is the same sequence
        shifted by one step, starting from ``alphacums[0]``.
    """
    # select alphas for computing the variance schedule
    alphas = alphacums[ddim_timesteps]
    earlier = alphacums[ddim_timesteps[:-1]].tolist()
    alphas_prev = np.asarray([alphacums[0]] + earlier)

    # formula from https://arxiv.org/abs/2010.02502
    variance_ratio = (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
    sigmas = eta * np.sqrt(variance_ratio)
    return sigmas, alphas, alphas_prev

def makerng():
  """Seed NumPy and PyTorch from the module-level ``seed``.

  When ``seed`` is 0 a fresh random seed is drawn, applied to both libraries
  and printed so the run can be reproduced. Any other value is used as is.
  """
  # Local import: `random` is not imported at file level in the visible source.
  import random

  if seed == 0:
    # np.random.seed only accepts 0 .. 2**32 - 1. randrange excludes the upper
    # bound, whereas randint(0, 2**32) could return 2**32 itself.
    rng = random.randrange(2 ** 32)
    np.random.seed(rng)
    torch.manual_seed(rng)
    print('random seed=')
    print(rng)
  else:
    np.random.seed(seed)
    torch.manual_seed(seed)

def dlpromptexample():
  """Download the prompt-example archive ``pexmp.7z`` and extract it with 7z.

  Uses IPython shell magics, so it only works inside a notebook runtime.
  """
  !wget https://github.com/TabuaTambalam/DalleWebms/releases/download/0.1/pexmp.7z
  !7z x pexmp.7z
  

def mkmodel_state_dict():
  try:
    import jkt
  except:
    !wget https://raw.githubusercontent.com/TabuaTambalam/DalleWebms/main/docs/sd/jkt.py
    import jkt
  
  difjit=[diffusion_emb,diffusion_mid,diffusion_out]
  model_state_dict = {}
  jna1=jkt.nam1
  for i in range(3):
    sd=difjit[i].state_dict()
    jna2=jkt.nam2[i]
    for k in sd:
      uwa=sd[k]
      if 'pnnx' in k:
        model_state_dict[jna2[k]]=uwa
      else:
        model_state_dict[jna1[k]]=uwa
  return model_state_dict

# Lazily loaded TorchScript decoder, cached across latdec() calls.
SDlatDEC=None
def latdec(fna,scale=5.5):
  """Decode a latent stored in the ``.npy`` file ``fna`` with the cached decoder.

  On first use the decoder 'autoencoder_pnnx.pt' is downloaded if it is not
  present, loaded with ``torch.jit.load`` and moved to CUDA. The latent is
  multiplied by ``scale`` before being passed to the decoder.
  """
  global SDlatDEC
  if SDlatDEC is None:
    if not os.path.isfile('autoencoder_pnnx.pt'):
      !wget https://huggingface.co/Larvik/sd470k/resolve/main/autoencoder_pnnx.pt
    SDlatDEC=torch.jit.load('autoencoder_pnnx.pt').cuda()
  lat=torch.tensor(np.load(fna)).cuda()
  return SDlatDEC(lat*scale)


def localhttp(root='/'):
  """Start a background ``http.server`` on port 8233 serving ``root``.

  The server's output is redirected to /content/sample_data/izh.txt; that
  file's existence is also used as the "already started" marker, so later
  calls do nothing. ``HTML`` from IPython is imported into module globals on
  the first call only.
  """
  global HTML
  if not os.path.isfile('/content/sample_data/izh.txt'):
    from IPython.core.display import HTML
    !nohup python3 -m http.server -d {root} 8233 > /content/sample_data/izh.txt &


def f_sampler():
  """Point the global ``UseSamplr`` at the sampler named by the global ``Sampler``.

  Unrecognized names leave ``UseSamplr`` unchanged.
  """
  global UseSamplr
  samplers = {
    'euler': sample_euler,
    'euler_a': sample_euler_ancestral,
    'heun': sample_heun,
    'dpm_2': sample_dpm_2,
    'dpm_2_a': sample_dpm_2_ancestral,
    'lms': sample_lms,
  }
  if Sampler in samplers:
    UseSamplr = samplers[Sampler]

def f_sigmas():
  """Build the sampling schedule, scaled by the global ``ddim_eta``.

  With ``Karras`` set, a Karras schedule is built from
  ``model_wrap.sigmas[0]`` and ``model_wrap.sigmas[-1]``; otherwise
  ``model_wrap.get_sigmas`` is used. Both use ``ddim_num_steps`` steps.
  """
  if not Karras:
    return ddim_eta*model_wrap.get_sigmas(ddim_num_steps)
  sigma_first = model_wrap.sigmas[0].item()
  sigma_last = model_wrap.sigmas[-1].item()
  schedule = get_sigmas_karras(ddim_num_steps, sigma_first, sigma_last, rho=KarrasRho, device=cudev)
  return ddim_eta*schedule

def fixver(ver,dfsver):
  """Return ``dfsver`` when ``ver`` is '470k', otherwise an empty string."""
  return dfsver if ver == '470k' else ''
def f_dljit(ver='470k',dfsver=''):
  """Install dependencies and download the JIT model files for version ``ver``.

  Nothing happens when 'diffusion_out_pnnx.pt' already exists in the working
  directory. ``dfsver`` is a suffix applied to the repository name for the
  three diffusion parts only; ``fixver`` drops it unless ``ver`` is '470k'.
  """
  dfsver=fixver(ver,dfsver)
  if not os.path.isfile('diffusion_out_pnnx.pt'):
    !pip install ftfy transformers einops accelerate
    !wget https://huggingface.co/Larvik/sd{ver}/resolve/main/alphas_cumprod.npz
    !wget https://huggingface.co/Larvik/tfmod/resolve/main/transformer_pnnx.pt
    !wget https://huggingface.co/Larvik/sd{ver}/resolve/main/autoencoder_pnnx.pt
    !wget https://huggingface.co/Larvik/sd{ver}/resolve/main/imgencoder_pnnx.pt
    # From here on the repository name includes the suffix.
    ver+=dfsver
    !wget https://huggingface.co/Larvik/sd{ver}/resolve/main/diffusion_emb_pnnx.pt
    !wget https://huggingface.co/Larvik/sd{ver}/resolve/main/diffusion_mid_pnnx.pt
    !wget https://huggingface.co/Larvik/sd{ver}/resolve/main/diffusion_out_pnnx.pt


class Insertor:
  """Binds a placeholder string in a prompt to a user embedding loaded from disk.

  ``inzdict`` (a module global not defined in this block) supplies the
  replacement text at index ``n`` and the matching token at index ``n + 1``.
  """

  def __init__(self, string, n):
    self.rpla = string
    self.rplb = inzdict[n]
    self.token = inzdict[n + 1]
    self.idx = []
    # The file name is the placeholder without its first and last character.
    emb_path = 'UserEmb/' + string[1:-1] + '.bin'
    self.emb = torch.tensor(np.fromfile(emb_path, dtype=np.float32))

  def repl(self, string):
    """Return ``string`` with every occurrence of the placeholder replaced."""
    return string.replace(self.rpla, self.rplb)


# encoder
class BERTEmbedder:
    """Text encoder built from a CLIP tokenizer and a supplied transformer.

    Despite the class name, tokenization uses ``CLIPTokenizer`` loaded from
    'openai/clip-vit-large-patch14'. ``encode`` is an instance attribute that
    initially points at ``encode0``.
    """

    def __init__(self, transformer, max_length=77):
        self.tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
        self.max_length = max_length
        self.insertor = []
        self.transformer = transformer
        token_weights = self.transformer.state_dict()['text_model_embeddings_token_embedding.weight']
        self.embedding = torch.nn.Embedding.from_pretrained(token_weights)
        self.encode = self.encode0

    def _tokenize(self, text):
        """Tokenize ``text`` to ``max_length`` padded/truncated ids and return the id tensor."""
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        return batch_encoding["input_ids"]

    def insert(self, inz):
        """Register a user-embedding placeholder, or clear them all.

        A string shorter than 3 characters clears the list. At most five
        placeholders are kept; further ones are ignored.
        """
        if len(inz) < 3:
            self.insertor = []
            return
        count = len(self.insertor)
        if count <= 4:
            self.insertor.append(Insertor(inz, count * 2))

    def encode0(self, text, nsamp):
        """Encode ``text`` for ``nsamp`` samples, splicing in user embeddings.

        Placeholders are first replaced in the text, then the embedding rows
        at the positions of each placeholder's token are overwritten with the
        user embedding before the transformer runs. The result is moved to CUDA.
        """
        use_inserts = bool(self.insertor) and len(text) > 0
        if use_inserts:
            for iz in self.insertor:
                text = iz.repl(text)
                iz.idx = []

        tokens = self._tokenize(text)
        if use_inserts:
            for iz in self.insertor:
                iz.idx = torch.where(tokens == iz.token)[1]

        tokens = tokens.expand(nsamp, -1)
        amb = self.embedding(tokens)
        if use_inserts:
            for iz in self.insertor:
                for pidx in iz.idx:
                    # same embedding for every sample in the batch
                    amb[:, pidx] = iz.emb

        return self.transformer(amb).cuda()

    def encode2(self, text, nsamp):
        """Encode ``text`` by passing token ids straight to the transformer."""
        tokens = self._tokenize(text)
        tokens = torch.tensor(tokens.numpy()).expand(nsamp, -1)
        return self.transformer(tokens).cuda()

    def encode3(self, text, nsamp):
        """Return only the token ids for ``text``; ``nsamp`` is not used."""
        return self._tokenize(text)

    def encode4(self, tokens):
        """Run the transformer on ``tokens`` and move the result to CUDA."""
        return self.transformer(tokens).cuda()




class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for CompVis diffusion models.

    The wrapped ``model`` must provide ``alphas_cumprod`` and ``apply_model``.
    """

    def __init__(self, model, quantize=False, device='cpu'):
        # `device` is accepted but not used in this constructor.
        cumulative_alphas = model.alphas_cumprod
        super().__init__(model, cumulative_alphas, quantize=quantize)

    def get_eps(self, *args, **kwargs):
        """Forward all arguments to the wrapped model's ``apply_model``."""
        eps = self.inner_model.apply_model(*args, **kwargs)
        return eps

def get_cond_simp(d, cond):
  """Return ``cond`` unchanged; the step index ``d`` is ignored."""
  return cond


def get_cond_list(d, cond):
  """Return the ``d``-th entry of ``cond`` (one conditioning per step)."""
  return cond[d]


# Conditioning selector used by CFGDenoiser.forward.
get_cond = get_cond_simp

class CFGDenoiser(nn.Module):
    """Runs the wrapped model on an unconditional and a conditional batch at
    once and mixes the two outputs with ``cond_scale``."""

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, d, x, sigma, uncond, cond, cond_scale):
        """Return ``uncond_out + (cond_out - uncond_out) * cond_scale``.

        ``d`` is the step index passed to the module-level ``get_cond`` hook
        to select the conditioning for this step.
        """
        doubled_x = torch.cat([x, x])
        doubled_sigma = torch.cat([sigma, sigma])
        stacked_cond = torch.cat([uncond, get_cond(d, cond)])
        out = self.inner_model(doubled_x, doubled_sigma, cond=stacked_cond)
        out_uncond, out_cond = out.chunk(2)
        return out_uncond + (out_cond - out_uncond) * cond_scale


class SRDenoiser(nn.Module):
    """Thin module wrapper that forwards ``(x, sigma, cond)`` to the wrapped model."""

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, cond):
        """Return ``inner_model(x, sigma, cond=cond)``."""
        return self.inner_model(x, sigma, cond=cond)


class CompVisJIT():
  """Adapter exposing the two attributes ``CompVisDenoiser`` reads from a model.

  Both are taken from module globals at construction time: ``alphas_cumprod``
  (converted to a tensor on ``cudev``) and the ``apply_model`` function.
  """

  def __init__(self):
    cumulative = torch.tensor(alphas_cumprod, device=cudev)
    self.alphas_cumprod = cumulative
    self.apply_model = apply_model

class ifeeder():
  """Supplies a per-step input through ``getn(n)``.

  ``getn`` initially points at ``get_simp``, which returns the value stored
  by ``setbs``. ``get_npbins`` reads ``self.pattern``, ``self.shape`` and
  ``self.noiseadd``, which must be assigned by the caller beforehand.
  """

  def __init__(self):
    self.getn = self.get_simp

  def setbs(self, in_bs):
    """Store the value that ``get_simp`` returns."""
    self.bs = in_bs

  def get_simp(self, n):
    """Return the stored value; ``n`` is ignored."""
    return self.bs

  def get_npbins(self, n):
    """Load float32 file number ``n + 1`` (named via ``self.pattern``) onto ``cudev`` and add ``self.noiseadd``."""
    path = self.pattern % (n + 1)
    values = np.fromfile(path, dtype=np.float32).reshape(self.shape)
    return torch.tensor(values, device=cudev) + self.noiseadd

# Default schedule choice read by f_sigmas(); a later cell can override it.
Karras = False
# Set by the task cell once the model is loaded; None means "not loaded yet".
model_wrap = None

In [None]:
import os
# Fetch the prompt examples in a background thread, but only if they are not already present.
if not os.path.isfile('MultiPromptExample1.txt'):
  t3 = Thread(target = dlpromptexample)
  # Thread.start() returns None, so a3 is always None.
  a3 = t3.start()

Overwriting /content/MultiPromptExample1.txt


# Super Resolution 4x<br>
Select one of these tasks: Super Resolution, txt2img, or (old LDM) infilling.

In [None]:
# Colab form field; this flag is not read anywhere in the visible source.
jit=False #@param {type:'boolean'}

import os

# Download the super-resolution model files once; 'fsd_pnnx.pt' serves as the marker.
if not os.path.isfile('fsd_pnnx.pt'):
  !wget https://huggingface.co/Larvik/LDMjit/resolve/main/alphas_cumprod.npy
  !wget https://huggingface.co/Larvik/LDMjit/resolve/main/dm_pnnx.pt
  !wget https://huggingface.co/Larvik/LDMjit/resolve/main/fsd_pnnx.pt

import sys
import time

import numpy as np
import cv2
import functools
import torch

# Device used for all tensors created by this cell's helpers.
cudev=torch.device('cuda')

# Cumulative alpha table of the diffusion model, read by CompVisJIT and the ddim helpers.
alphas_cumprod = np.load('alphas_cumprod.npy')


# Inference only: switch autograd off globally for the rest of the session.
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
# Enable cuDNN with autotuning, and allow TF32 / reduced-precision fp16 reductions.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True


# ======================
# Argument Parser Config
# ======================

def imread(filename, flags=cv2.IMREAD_COLOR):
    """Read an image by loading the file's bytes and decoding them with ``cv2.imdecode``.

    If ``filename`` is not an existing file, a message is printed and the
    process exits through ``sys.exit()``.
    """
    if os.path.isfile(filename):
        raw = np.fromfile(filename, np.int8)
        return cv2.imdecode(raw, flags)
    print(f"File does not exist: {filename}")
    sys.exit()

def preprocessing_img(img):
    """Convert a grayscale or 3-channel image to BGRA.

    2-D and single-channel inputs are treated as grayscale, 3-channel inputs
    as BGR. Anything else is returned unchanged.
    """
    if len(img.shape) < 3 or img.shape[2] == 1:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGRA)
    if img.shape[2] == 3:
        return cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
    return img


def load_image(image_path):
    """Read ``image_path`` and pass the result through ``preprocessing_img``.

    Raises:
        FileNotFoundError: if ``image_path`` is not an existing file.
    """
    if not os.path.isfile(image_path):
        # Previously this only printed the message and then failed with an
        # UnboundLocalError on `img`; fail with a meaningful error instead.
        raise FileNotFoundError(f'{image_path} not found.')
    img = imread(image_path, cv2.IMREAD_UNCHANGED)
    return preprocessing_img(img)



def meshgrid(h, w):
    """Return an ``(h, w, 2)`` integer tensor whose entry ``[y, x]`` is ``(y, x)``."""
    rows = torch.arange(0, h).view(h, 1).expand(h, w)
    cols = torch.arange(0, w).view(1, w).expand(h, w)
    return torch.stack([rows, cols], dim=-1)


def delta_border(h, w):
    """
    :param h: height
    :param w: width
    :return: normalized distance to the image border,
      with min distance = 0 at the border and max distance = 0.5 at the image center
    """
    corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
    # Coordinates rescaled so that the lower-right pixel is (1, 1).
    rel = meshgrid(h, w) / corner
    to_top_left = torch.min(rel, dim=-1, keepdims=True)[0]
    to_bottom_right = torch.min(1 - rel, dim=-1, keepdims=True)[0]
    both = torch.cat([to_top_left, to_bottom_right], dim=-1)
    return torch.min(both, dim=-1)[0]



def get_weighting(h, w, Ly, Lx, device):
  """Build per-pixel blending weights for ``Ly * Lx`` crops of size ``h x w``.

  The weight is the border distance from ``delta_border`` clipped to
  [0.01, 0.5]. The result has shape ``(1, h * w, Ly * Lx)`` and lives on
  ``device``.
  """
  clip_min_weight, clip_max_weight = 0.01, 0.5
  border = torch.clip(delta_border(h, w), clip_min_weight, clip_max_weight)
  return border.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)

def get_fold_unfold(x, kernel_size, stride, uf=1, df=1):  # todo load once not every time, shorten code
    """Build the modules and weights for processing ``x`` as overlapping crops.

    :param x: img of size (bs, c, h, w)
    :param kernel_size: (height, width) of each crop
    :param stride: (vertical, horizontal) step between crops
    :param uf: upscaling factor applied on the fold (output) side
    :param df: downscaling factor applied on the fold (output) side
    :return: ``(fold, unfold, normalization, weighting)``: ``unfold`` cuts the
      input into crops, ``fold`` stitches (possibly rescaled) crops back
      together, ``weighting`` holds the per-pixel blend weights of every crop
      and ``normalization`` is the folded sum of those weights.
    :raises NotImplementedError: if ``uf`` and ``df`` are both different from 1
      (or either is below 1)
    """
    bs, nc, h, w = x.shape

    # number of crops in image
    Ly = (h - kernel_size[0]) // stride[0] + 1
    Lx = (w - kernel_size[1]) // stride[1] + 1

    # Pick how sizes map from the input (unfold) side to the output (fold) side.
    if uf == 1 and df == 1:
        def scaled(v):
            return v
    elif uf > 1 and df == 1:
        def scaled(v):
            return v * uf
    elif df > 1 and uf == 1:
        def scaled(v):
            return v // df
    else:
        raise NotImplementedError

    unfold = torch.nn.Unfold(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)

    # The kernel width must come from kernel_size[1]. Using kernel_size[0] for
    # both dimensions only happened to work for square kernels.
    out_kh, out_kw = scaled(kernel_size[0]), scaled(kernel_size[1])
    out_h, out_w = scaled(h), scaled(w)
    fold = torch.nn.Fold(output_size=(out_h, out_w),
                         kernel_size=(out_kh, out_kw),
                         dilation=1, padding=0,
                         stride=(scaled(stride[0]), scaled(stride[1])))

    weighting = get_weighting(out_kh, out_kw, Ly, Lx, x.device).to(x.dtype)
    normalization = fold(weighting).view(1, 1, out_h, out_w)  # normalizes the overlap
    weighting = weighting.view((1, 1, out_kh, out_kw, Ly * Lx))

    return fold, unfold, normalization, weighting



def normalize_image(image, normalize_type='255'):
    """
    Normalize image
    Parameters
    ----------
    image: numpy array
        The image you want to normalize
    normalize_type: string
        Normalize type should be chosen from the type below.
        - '255': simply dividing by 255.0
        - '127.5': output range : -1 and 1
        - 'ImageNet': normalize by mean and std of ImageNet
        - 'None': no normalization
    Returns
    -------
    normalized_image: numpy array
        ``None`` is returned for an unrecognized ``normalize_type``.
    """
    if normalize_type == 'None':
        return image
    if normalize_type == '255':
        return image / 255.0
    if normalize_type == '127.5':
        return image / 127.5 - 1.0
    if normalize_type != 'ImageNet':
        # unrecognized type: no result
        return None

    scaled = image / 255.0
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    for ch, (mu, sd) in enumerate(zip(channel_mean, channel_std)):
        scaled[:, :, ch] = (scaled[:, :, ch] - mu) / sd
    return scaled



def preprocess(img):
    """Turn an ``(H, W, C)`` image with 0..255 values into a float32 ``(1, C, H, W)`` batch in [-1, 1].

    Returns:
        ``(None, batch)``; the first element is always ``None``.
    """
    # Unpacking also rejects inputs that are not 3-dimensional.
    im_h, im_w, _ = img.shape

    unit = normalize_image(img, normalize_type='255')
    chw = (unit * 2 - 1).transpose(2, 0, 1)  # HWC -> CHW
    batch = np.expand_dims(chw, axis=0).astype(np.float32)
    return None, batch


def postprocess(sample):
    """Convert a ``(C, H, W)`` array in [-1, 1] to an ``(H, W, C)`` uint8 image.

    Values are clipped to [-1, 1], mapped to 0..255 and the channel order is
    reversed (RGB -> BGR).
    """
    clipped = np.clip(sample, -1., 1.)
    scaled = (clipped + 1.) / 2. * 255
    hwc = np.transpose(scaled, (1, 2, 0))
    bgr = hwc[:, :, ::-1]  # RGB -> BGR
    return bgr.astype(np.uint8)


def decode_first_stage(z):
    """Decode ``z`` crop by crop with ``first_stage_decode`` and blend the results.

    ``z`` is cut into 128x128 crops with stride 64; each crop is decoded, the
    outputs are weighted and folded back together at 4x the input resolution,
    then divided by the folded weights to normalize the overlaps.
    """
    ks = (128, 128)
    stride = (64, 64)
    uf = 4

    # Unpacking rejects inputs that are not 4-dimensional.
    bs, nc, h, w = z.shape

    fold, unfold, normalization, weighting = get_fold_unfold(z, ks, stride, uf=uf)

    crops = unfold(z)  # (bn, nc * prod(**ks), L)
    # Reshape to img shape
    crops = crops.view((crops.shape[0], -1, ks[0], ks[1], crops.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

    print('first_stage_decode...')

    # Only element [0] of each decoder output is kept.
    decoded_crops = [first_stage_decode(crops[:, :, :, :, i])[0] for i in range(crops.shape[-1])]

    o = torch.stack(decoded_crops, axis=-1) * weighting
    # Reverse 1. reshape to img shape
    o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
    # stitch crops together, then normalize the overlap; norm is shape (1, 1, h, w)
    return fold(o) / normalization




# ddpm
def apply_model(x, t, cond):
    """Run ``diffusion_model`` on overlapping crops of ``x`` and blend the outputs.

    ``x`` and ``cond`` are both cut into 128x128 crops with stride 64. Each
    pair of crops is concatenated along dim 1 and passed to the model with
    ``t``; the outputs are weighted, folded back together and divided by the
    folded weights.
    """
    ks = (128, 128)
    stride = (64, 64)

    fold, unfold, normalization, weighting = get_fold_unfold(x, ks, stride)

    z = unfold(x)  # (bn, nc * prod(**ks), L)
    # Reshape to img shape
    z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

    c = unfold(cond)
    c = c.view((c.shape[0], -1, ks[0], ks[1], c.shape[-1]))  # (bn, nc, ks[0], ks[1], L )

    # apply model by loop over crops
    outputs = []
    for i in range(z.shape[-1]):
        crop_in = torch.cat([z[:, :, :, :, i], c[:, :, :, :, i]], dim=1)
        # Only element [0] of each model output is kept.
        outputs.append(diffusion_model(crop_in, t)[0])

    o = torch.stack(outputs, axis=-1) * weighting
    # Reverse reshape to img shape
    o = o.view((o.shape[0], -1, o.shape[-1]))  # (bn, nc * ks[0] * ks[1], L)
    # stitch crops together
    return fold(o) / normalization

def warmup():
  """Run both JIT models twice on random inputs, then release cached CUDA memory.

  The diffusion model is run under fp16 autocast on a half-precision
  (1, 6, 128, 128) input; the first-stage decoder on a float32
  (1, 3, 128, 128) input.
  """
  dm_input = torch.rand(1, 6, 128, 128, dtype=torch.float).half().cuda()
  dm_step = torch.randint(10, (1, ), dtype=torch.long).cuda()
  for _ in range(2):
    with torch.cuda.amp.autocast(dtype=torch.float16):
      diffusion_model(dm_input, dm_step)

  dec_input = torch.rand(1, 3, 128, 128, dtype=torch.float).cuda()
  for _ in range(2):
    first_stage_decode(dec_input)

  torch.cuda.empty_cache()


# Default sampler; f_sampler() can replace it.
UseSamplr=sample_euler_ancestral
def predict(c):
    """Upscale the conditioning image ``c`` and return the result as a uint8 BGR image.

    Noise with the same shape as ``c`` is sampled, scaled by the first sigma
    and by a factor derived from the global ``detail_strength`` (which is
    printed), denoised with ``UseSamplr`` under fp16 autocast, decoded with
    ``decode_first_stage`` and converted with ``postprocess``.
    """
    cond = torch.tensor(c, device=cudev)
    sigmas = f_sigmas()
    noise = torch.randn(cond.shape, dtype=torch.float, device=cudev)

    df = detail_strength / (detail_strength - 1 + float(sigmas[0]))
    print(df)
    with torch.cuda.amp.autocast(dtype=torch.float16):
        samples = UseSamplr(model_wrap_cfg, noise * sigmas[0] * df, sigmas,
                            extra_args={'cond': cond}, disable=False)

    decoded = decode_first_stage(samples)
    return postprocess(decoded[0].cpu().numpy())

# Load the models only once per runtime; model_wrap stays None until this has run.
if model_wrap is None:
  # TorchScript models: the first-stage decoder and the half-precision diffusion model.
  first_stage_decode=torch.jit.load('/content/fsd_pnnx.pt').eval().cuda()
  diffusion_model=torch.jit.load('/content/dm_pnnx.pt').eval().half().cuda()
  warmup()
  torch.cuda.empty_cache()
  # CompVisJIT supplies alphas_cumprod and apply_model to the denoiser wrapper.
  model_wrap = CompVisDenoiser(CompVisJIT())
  model_wrap_cfg = SRDenoiser(model_wrap)

In [None]:

# Input to upscale: an image file, or a '.npy' latent that is decoded with latdec() first.
image_path='/content/sample_data/10_0x0v1.png' #@param {type:'string'}

"""
ddim_timesteps
"""
# ddim_eta scales the whole sigma schedule in f_sigmas(); ddim_num_steps is the number of sampling steps.
ddim_eta = 0.75  #@param {type:'number'}
ddim_num_steps = 100  #@param {type:'number'}
ddpm_num_timesteps = 1000 #@param {type:'number'}

# Used by predict() to scale the initial noise.
detail_strength=20000  #@param {type:'number'}

ddim_timesteps = make_ddim_timesteps(ddim_num_steps, ddpm_num_timesteps)




"""
ddim sampling parameters
"""

# NOTE(review): these DDIM tables are not read by predict() in the visible
# code, which builds its schedule through f_sigmas() -- confirm whether they
# are still needed.
ddim_sigmas, ddim_alphas, ddim_alphas_prev = \
    make_ddim_sampling_parameters(
        alphacums=alphas_cumprod,
        ddim_timesteps=ddim_timesteps,
        eta=ddim_eta)

#ddim_sigmas=torch.tensor(ddim_sigmas.astype(np.float32),device=cudev)

ddim_sqrt_one_minus_alphas = np.sqrt(1. - ddim_alphas)



 




# inference
print('Start inference...')
# Build the conditioning `c`: '.npy' inputs go through latdec(), anything else
# is loaded as an image and preprocessed.
if image_path.endswith('.npy'):
  c=latdec(image_path).detach()
else:
  img = load_image(image_path)
  img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
  img = img[:, :, ::-1]  # BGR -> RGB
  _, c = preprocess(img)
  

img = predict(c)

# plot result
# The output name drops the last 4 characters of image_path (the extension is
# assumed to be a dot plus 3 letters) and appends '_4x.png'.
savepath = image_path[:-4]+'_4x.png'
# NOTE: the message is printed before the file is actually written.
print(f'saved at : {savepath}')
cv2.imwrite(savepath, img)



In [None]:
# Sampler choice for the notebook; f_sampler() is called right after the name
# is set (presumably it binds UseSamplr from `Sampler` - confirm in f_sampler).
Sampler='euler_a' #@param ['euler', 'euler_a', 'heun','dpm_2','dpm_2_a','lms']
f_sampler()

# Noise-schedule options. NOTE(review): these are assigned after f_sampler()
# has already run in this cell; presumably they are read later by f_sigmas().
Karras=False #@param {type:'boolean'}
KarrasRho = 7.0 #@param {type:'number'}

Optional: SD latent decoder

In [None]:
# Path of the latent file to decode.
latent='4x6_1x1v1.npy' #@param {type:'string'}

# Decode the latent, map the first result from [-1, 1] to [0, 255], reorder
# CHW -> HWC, clip, and convert to an 8-bit PIL image.
ymg=Image.fromarray( (( ( latdec(latent)[0] +1)*127.5 ).cpu().numpy()).transpose(1,2,0).clip(0,255).astype(np.uint8) )
# Saved next to the latent, with the last 4 characters replaced by '.png'.
ymg.save(latent[:-4]+'.png')
# Bare expression so the notebook displays the image.
ymg

Optional: GFPgan-jit

In [None]:
import cv2
import glob
import numpy as np
import os
import torch
from torch import nn
import math


from torchvision.transforms.functional import normalize
from itertools import product




def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write `img` to `file_path` with OpenCV.

    Args:
        img: Image array handed to `cv2.imwrite`.
        file_path: Destination path.
        params: Optional encoder parameters forwarded to `cv2.imwrite`.
        auto_mkdir: Create the parent directory first when True.

    Raises:
        IOError: If `cv2.imwrite` reports failure.
    """
    if auto_mkdir:
        parent = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(parent, exist_ok=True)

    if not cv2.imwrite(file_path, img, params):
        raise IOError('Failed in writing images.')



def bb_intersection_over_union(boxA, boxB):
    """Return the intersection-over-union of two (x1, y1, x2, y2) boxes.

    Coordinates are treated as inclusive pixel indices, hence the `+ 1` in
    every width/height computation.
    """
    # Corners of the overlap rectangle.
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    # A negative extent means the boxes do not overlap on that axis.
    overlap_w = max(0, right - left + 1)
    overlap_h = max(0, bottom - top + 1)
    overlap = overlap_w * overlap_h

    area_a = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    area_b = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)

    # Union = sum of both areas minus the part counted twice.
    return overlap / float(area_a + area_b - overlap)


def nms_boxes(boxes, scores, iou_thres):
    """Non-maximum suppression; returns the indices of the boxes to keep.

    Each box is compared against every earlier box that is still kept. When
    the IoU reaches `iou_thres`, the lower-scoring box of the pair is dropped
    (the earlier box wins ties).
    """
    kept = []
    for cur, cur_box in enumerate(boxes):
        survives = True
        for prev in range(cur):
            if kept[prev] and bb_intersection_over_union(cur_box, boxes[prev]) >= iou_thres:
                if scores[cur] > scores[prev]:
                    # The current box beats an earlier one: drop the earlier
                    # box and keep checking the remaining ones.
                    kept[prev] = False
                else:
                    survives = False
                    break
        kept.append(survives)

    return np.array(kept).nonzero()[0]





def get_anchor(image_size):
    """Build the anchor (prior box) table for an image of size (height, width).

    Three feature-map levels with strides 8, 16 and 32 are used, each with
    two anchor sizes per cell. Every row of the result is
    (cx, cy, w, h) normalised by the image width/height.

    Returns:
        numpy array of shape (num_anchors, 4).
    """
    img_h, img_w = image_size[0], image_size[1]
    strides = [8, 16, 32]
    sizes_per_level = [[16, 32], [64, 128], [256, 512]]

    flat = []
    for stride, sizes in zip(strides, sizes_per_level):
        grid_h = math.ceil(img_h / stride)
        grid_w = math.ceil(img_w / stride)
        for row, col in product(range(grid_h), range(grid_w)):
            # Anchor centre sits in the middle of the feature-map cell.
            cx = (col + 0.5) * stride / img_w
            cy = (row + 0.5) * stride / img_h
            for size in sizes:
                flat.extend([cx, cy, size / img_w, size / img_h])

    return np.array(flat).reshape(-1, 4)


# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
    """Turn box regression offsets into corner-form boxes.

    Args:
        loc: array of shape [num_priors, 4] with predicted offsets.
        priors: array of shape [num_priors, 4] in (cx, cy, w, h) form.
        variances: two scale factors; the first applies to the centre
            offsets, the second to the log-size offsets.

    Returns:
        Array of shape [num_priors, 4] holding (x1, y1, x2, y2).
    """
    centers = priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:]
    sizes = priors[:, 2:] * np.exp(loc[:, 2:] * variances[1])

    # Centre/size -> top-left / bottom-right corners.
    top_left = centers - sizes / 2
    bottom_right = sizes + top_left
    return np.concatenate((top_left, bottom_right), 1)


def decode_landm(pre, priors, variances):
    """Turn landmark regression offsets into landmark coordinates.

    Args:
        pre: array of shape [num_priors, 10] - five (x, y) offset pairs.
        priors: array of shape [num_priors, 4] in (cx, cy, w, h) form.
        variances: scale factors; only the first one is used.

    Returns:
        Array of shape [num_priors, 10] with the decoded (x, y) pairs.
    """
    anchor_xy = priors[:, :2]
    anchor_wh = priors[:, 2:]
    points = [
        anchor_xy + pre[:, start:start + 2] * variances[0] * anchor_wh
        for start in range(0, 10, 2)
    ]
    return np.concatenate(points, axis=1)



def detect_faces(
        image,
        conf_threshold=0.8,
        nms_threshold=0.4,
        use_origin_size=True,
    ):
    """Detect faces with the module-level `RetinaFace` model.

    Args:
        image: HWC numpy image (the per-channel means 104/117/123 are
            subtracted in channel order 0, 1, 2).
        conf_threshold: Detections scoring at or below this are discarded.
        nms_threshold: IoU threshold handed to `nms_boxes`.
        use_origin_size: Unused; kept for interface compatibility.

    Returns:
        Array with one row per kept face: 4 box coordinates, the score, and
        10 landmark coordinates (15 columns), sorted by descending score.
    """
    height, width = image.shape[:2]
    image = image.transpose(2, 0, 1).astype(np.float32)
    image = torch.from_numpy(image).to(cudevg).unsqueeze(0)

    # Build the mean on the same device as the image; a CPU mean would raise
    # a device-mismatch error when the detector runs on CUDA.
    channel_mean = torch.tensor([[[[104.]], [[117.]], [[123.]]]], device=image.device)
    image = image - channel_mean

    loc, conf, landmarks = RetinaFace(image)
    priors = get_anchor((height, width))

    variance = [0.1, 0.2]
    # Factors that map normalised coordinates back to pixels.
    scale = np.array([width, height, width, height])
    scale1 = np.array([
        width, height, width, height, width, height, width, height, width, height
    ])

    boxes = decode(loc[0].cpu().numpy(), priors, variance)
    boxes = boxes * scale

    scores = conf[0][:, 1].cpu().numpy()

    landmarks = decode_landm(landmarks[0].cpu().numpy(), priors, variance)
    landmarks = landmarks * scale1

    # ignore low scores
    inds = np.where(scores > conf_threshold)[0]
    boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]

    # sort by descending score
    order = scores.argsort()[::-1]
    boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]

    # do NMS
    bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
    keep = nms_boxes(bounding_boxes[:, :4], bounding_boxes[:, 4], nms_threshold)
    bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
    return np.concatenate((bounding_boxes, landmarks), axis=1)

def get_largest_face(det_faces, h, w):
    """Pick the detection with the largest area inside an h x w image.

    Box coordinates are clamped to [0, w] / [0, h] before the area is
    computed. Returns `(face, index)`; the first face wins ties.
    """

    def clamp(val, upper):
        if val < 0:
            return 0
        return upper if val > upper else val

    areas = [
        (clamp(face[2], w) - clamp(face[0], w)) * (clamp(face[3], h) - clamp(face[1], h))
        for face in det_faces
    ]
    best = areas.index(max(areas))
    return det_faces[best], best


def get_center_face(det_faces, h=0, w=0, center=None):
    """Pick the detection whose box centre is closest to a reference point.

    The reference is `center` when given, otherwise the middle of an
    h x w image. Returns `(face, index)`; the first face wins ties.
    """
    target = np.array([w / 2, h / 2]) if center is None else np.array(center)

    distances = [
        np.linalg.norm(np.array([(face[0] + face[2]) / 2, (face[1] + face[3]) / 2]) - target)
        for face in det_faces
    ]
    nearest = distances.index(min(distances))
    return det_faces[nearest], nearest






def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert an HWC numpy image (or a list of them) to CHW torch tensors.

    Args:
        imgs: A numpy image or a list of numpy images.
        bgr2rgb: Swap BGR -> RGB for 3-channel images before conversion.
        float32: Cast the resulting tensor to float32.

    Returns:
        A tensor, or a list of tensors when a list was passed in.
    """

    def convert(arr):
        if arr.shape[2] == 3 and bgr2rgb:
            # cvtColor is given float32 rather than float64 input.
            if arr.dtype == 'float64':
                arr = arr.astype('float32')
            arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
        chw = torch.from_numpy(arr.transpose(2, 0, 1))
        return chw.float() if float32 else chw

    if not isinstance(imgs, list):
        return convert(imgs)
    return [convert(one) for one in imgs]



def read_image(img):
    """Normalise an input image to a 3-channel array in the [0, 255] range.

    Args:
        img: Either a path to an image file or an already loaded image
            array (h, w[, c]) as returned by cv2.

    Returns:
        The image with 16-bit data rescaled to [0, 255], grayscale expanded
        to BGR, and any alpha channel dropped. An 8-bit 3-channel array is
        returned unchanged.

    Raises:
        IOError: If `img` is a path that OpenCV cannot read.
    """
    # Honour the documented contract: a string is treated as a file path.
    if isinstance(img, str):
        loaded = cv2.imread(img, cv2.IMREAD_UNCHANGED)
        if loaded is None:
            raise IOError(f'Failed to read image: {img}')
        img = loaded

    if np.max(img) > 256:  # 16-bit image
        img = (img / 65535) * 255
    if len(img.shape) == 2:  # gray image
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif img.shape[2] == 4:  # RGBA image with alpha channel
        img = img[:, :, 0:3]

    return img
'''
def srproc(img,fac):
  return cv2.resize(img, None,fx=fac,fy=fac, interpolation=cv2.INTER_LINEAR)
'''
def srproc(img,fac):
  """Upscale `img` with the diffusion super-resolution `predict`.

  The channel order of `img` is reversed before `preprocess`; `fac` is
  accepted for interface compatibility but not used.
  """
  reversed_channels = img[:, :, ::-1]
  _unused, cond = preprocess(reversed_channels)
  return predict(cond)

class faceimg:
  """Per-image helper that detects, aligns, and pastes back restored faces.

  The constructor treats the given image as the already-upscaled picture
  (`nXimage`) and derives a downscaled copy (`input_img`) on which detection
  and alignment run. The scale factor is read from the module-level name
  `upscale`, which must be defined before an instance is created.
  """

  def __init__(self, image,
                 face_size=512,
                 crop_ratio=(1, 1),
                 save_ext='png',
                 template_3points=False,
                 pad_blur=False,
                 use_parse=False,
                 device=None):
    """Prepare images, the landmark template, and the empty result lists.

    Args:
        image: Input accepted by `read_image`.
        face_size: Side length of the aligned square face crop.
        crop_ratio: (h, w) ratio of the crop relative to the square face;
            both entries must be >= 1.
        save_ext: Extension used when crops/results are written to disk.
        template_3points: Use a 3-point template instead of 5 landmarks.
        pad_blur: Pad and blur the image borders before cropping.
        use_parse: Stored on the instance; not otherwise used here.
        device: Unused; kept for interface compatibility.
    """
    self.nXimage=read_image(image)
    # `upscale` is a module-level setting, not a constructor argument.
    downscale=1/upscale
    self.input_img=cv2.resize(self.nXimage,None,fx=downscale,fy=downscale,interpolation=cv2.INTER_AREA)
    self.template_3points = template_3points  # improve robustness
    self.upscale_factor = upscale
    # the cropped face ratio based on the square face
    self.crop_ratio = crop_ratio  # (h, w)
    assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
    self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))

    if self.template_3points:
        self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
    else:
        # standard 5 landmarks for FFHQ faces with 512 x 512
        self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
                                        [201.26117, 371.41043], [313.08905, 371.15118]])
    self.face_template = self.face_template * (face_size / 512.0)
    # Shift the template so the face stays centred in a non-square crop.
    if self.crop_ratio[0] > 1:
        self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
    if self.crop_ratio[1] > 1:
        self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
    self.save_ext = save_ext
    self.pad_blur = pad_blur
    # The padding code indexes landmarks 3 and 4, so it needs all 5 points.
    if self.pad_blur is True:
        self.template_3points = False

    self.all_landmarks_5 = []
    self.det_faces = []
    self.affine_matrices = []
    self.inverse_affine_matrices = []
    self.cropped_faces = []
    self.pad_input_imgs = []
    self.restored_faces=[]


    # init face parsing model
    self.use_parse = use_parse

  def get_face_landmarks_5(self,
              only_keep_largest=False,
              only_center_face=False,
              resize=None,
              blur_ratio=0.01,
              eye_dist_threshold=None):
    """Detect faces on `input_img` and collect their boxes and landmarks.

    Args:
        only_keep_largest: Keep only the face with the largest area.
        only_center_face: Keep only the face closest to the image centre.
        resize: If given, detection runs on a copy whose shorter side is
            resized to this length; results are scaled back.
        blur_ratio: Blur kernel size relative to the face quad (pad_blur).
        eye_dist_threshold: Skip faces whose eye distance is below this.

    Returns:
        Number of retained landmark sets (0 when nothing was detected).
    """
    if resize is None:
        scale = 1
        input_img = self.input_img
    else:
        h, w = self.input_img.shape[0:2]
        scale = min(h, w) / resize
        h, w = int(h / scale), int(w / scale)
        input_img = cv2.resize(self.input_img, (w, h), interpolation=cv2.INTER_LANCZOS4)

    with torch.no_grad():
        bboxes = detect_faces( input_img ) * scale #0.97
    for bbox in bboxes:
        # remove faces with too small eye distance: side faces or too small faces
        eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
        if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
            continue

        # Columns 5.. of a detection row hold the landmark (x, y) pairs.
        if self.template_3points:
            landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
        else:
            landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
        self.all_landmarks_5.append(landmark)
        self.det_faces.append(bbox[0:5])
    if len(self.det_faces) == 0:
        return 0
    if only_keep_largest:
        h, w, _ = self.input_img.shape
        self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
        self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
    elif only_center_face:
        h, w, _ = self.input_img.shape
        self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
        self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]

    # pad blurry images
    if self.pad_blur:
        self.pad_input_imgs = []
        img_h, img_w = self.input_img.shape[0:2]
        for landmarks in self.all_landmarks_5:
            # get landmarks
            eye_left = landmarks[0, :]
            eye_right = landmarks[1, :]
            eye_avg = (eye_left + eye_right) * 0.5
            mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
            eye_to_eye = eye_right - eye_left
            eye_to_mouth = mouth_avg - eye_avg

            # Get the oriented crop rectangle
            # x: half width of the oriented crop rectangle
            x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
            #  - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
            # norm with the hypotenuse: get the direction
            x /= np.hypot(*x)  # get the hypotenuse of a right triangle
            rect_scale = 1.5
            x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
            # y: half height of the oriented crop rectangle
            y = np.flipud(x) * [-1, 1]

            # c: center
            c = eye_avg + eye_to_mouth * 0.1
            # quad: (left_top, left_bottom, right_bottom, right_top)
            quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
            # qsize: side length of the square
            qsize = np.hypot(*x) * 2
            border = max(int(np.rint(qsize * 0.1)), 3)

            # get pad
            # pad: (width_left, height_top, width_right, height_bottom)
            pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
                    int(np.ceil(max(quad[:, 1]))))
            # The right overflow is measured against the image width and the
            # bottom overflow against the height (pad[2] is applied along
            # axis 1 and pad[3] along axis 0 in np.pad below).
            pad = [
                max(-pad[0] + border, 1),
                max(-pad[1] + border, 1),
                max(pad[2] - img_w + border, 1),
                max(pad[3] - img_h + border, 1)
            ]

            if max(pad) > 1:
                # pad image
                pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
                # modify landmark coords
                landmarks[:, 0] += pad[0]
                landmarks[:, 1] += pad[1]
                # blur pad images
                h, w, _ = pad_img.shape
                y, x, _ = np.ogrid[:h, :w, :1]
                mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
                                                    np.float32(w - 1 - x) / pad[2]),
                                  1.0 - np.minimum(np.float32(y) / pad[1],
                                                    np.float32(h - 1 - y) / pad[3]))
                blur = int(qsize * blur_ratio)
                # Keep the box-filter kernel size odd.
                if blur % 2 == 0:
                    blur += 1
                blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
                # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)

                pad_img = pad_img.astype('float32')
                pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
                pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
                pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
                self.pad_input_imgs.append(pad_img)
            else:
                self.pad_input_imgs.append(np.copy(self.input_img))

    return len(self.all_landmarks_5)

  def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
    """Align and warp every detected face onto the face template.

    Appends one affine matrix and one cropped face per landmark set.

    Args:
        save_cropped_path: If given, each crop is also written to
            `<path without extension>_<idx>.<save_ext>`.
        border_mode: 'constant', 'reflect101', 'reflect', or a value that
            is passed to `cv2.warpAffine` unchanged.
    """
    if self.pad_blur:
        assert len(self.pad_input_imgs) == len(
            self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
    # Resolve the border mode once; unknown values are forwarded as given.
    border_modes = {
        'constant': cv2.BORDER_CONSTANT,
        'reflect101': cv2.BORDER_REFLECT101,
        'reflect': cv2.BORDER_REFLECT,
    }
    border_mode = border_modes.get(border_mode, border_mode)
    for idx, landmark in enumerate(self.all_landmarks_5):
        # use 5 landmarks to get affine matrix
        # use cv2.LMEDS method for the equivalence to skimage transform
        # ref: https://blog.csdn.net/yichxi/article/details/115827338
        affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
        self.affine_matrices.append(affine_matrix)
        # warp and crop faces
        if self.pad_blur:
            input_img = self.pad_input_imgs[idx]
        else:
            input_img = self.input_img
        cropped_face = cv2.warpAffine(
            input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132))  # gray
        self.cropped_faces.append(cropped_face)
        # save the cropped face
        if save_cropped_path is not None:
            path = os.path.splitext(save_cropped_path)[0]
            save_path = f'{path}_{idx:02d}.{self.save_ext}'
            imwrite(cropped_face, save_path)

  def add_restored_face(self, face):
    """Append one restored face image to `restored_faces`."""
    self.restored_faces.append(face)

  def get_inverse_affine(self, save_inverse_affine_path=None):
    """Get inverse affine matrix.

    Only the translation column is multiplied by the upscale factor here;
    `paste_faces_to_input_image` finishes the scaling per face.
    """
    for idx, affine_matrix in enumerate(self.affine_matrices):
        inverse_affine = cv2.invertAffineTransform(affine_matrix)
        inverse_affine[:, 2]*= self.upscale_factor
        #inverse_affine *= self.upscale_factor
        self.inverse_affine_matrices.append(inverse_affine)
        # save inverse affine matrices
        if save_inverse_affine_path is not None:
            path, _ = os.path.splitext(save_inverse_affine_path)
            save_path = f'{path}_{idx:02d}.pth'
            torch.save(inverse_affine, save_path)

  def paste_faces_to_input_image(self, save_path=None):
    """Warp the restored faces back and return them with an alpha mask.

    The result has the upscaled size and 4 channels: the pasted faces plus
    a soft mask (scaled to the image bit depth) as the last channel. When
    no face was restored the result is all zeros (fully transparent).

    NOTE: the stored inverse affine matrices are modified in place.

    Args:
        save_path: If given, the result is also written to
            `<path without extension>.<save_ext>`.
    """
    h, w, _ = self.input_img.shape
    h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

    # Only used below to decide between 8-bit and 16-bit output.
    upsample_img = self.nXimage

    assert len(self.restored_faces) == len(
        self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
    maskpool=None
    restorepool=None
    for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):

        # Small faces (scale * upscale below 1.5) are pasted from the
        # restored crop directly, so the whole matrix is scaled; larger
        # faces are first enlarged with srproc and keep the matrix as is.
        if (inverse_affine[0][0]*self.upscale_factor) < 1.5:
          inverse_affine[:, 2]/= self.upscale_factor
          inverse_affine*=self.upscale_factor
          restored_face=restored_face.astype('uint8')
        else:
          restored_face=srproc(restored_face,self.upscale_factor).astype('uint8')


        if self.upscale_factor > 1:
            extra_offset = 0.5 * self.upscale_factor
        else:
            extra_offset = 0
        inverse_affine[:, 2] += extra_offset

        inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))


        # inference: run the face-parsing model on a 512x512 copy
        face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
        face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
        normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        face_input = torch.unsqueeze(face_input, 0).to(cudevg)
        with torch.no_grad():
            out = face_parse(face_input)[0]
        out = out.argmax(dim=1).squeeze().cpu().numpy()

        # Map each parsing class index to mask value 0 or 255.
        mask = np.zeros(out.shape)
        MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
        for idx, color in enumerate(MASK_COLORMAP):
            mask[out == idx] = color
        #  blur the mask
        mask = cv2.GaussianBlur(mask, (101, 101), 11)
        mask = cv2.GaussianBlur(mask, (101, 101), 11)
        # remove the black borders
        thres = 10
        mask[:thres, :] = 0
        mask[-thres:, :] = 0
        mask[:, :thres] = 0
        mask[:, -thres:] = 0
        mask = mask / 255.

        mask = cv2.resize(mask, restored_face.shape[:2])
        mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up), flags=3)
        inv_soft_mask = mask[:, :, None]
        pasted_face = inv_restored

        if maskpool is None:
          maskpool=inv_soft_mask
          restorepool=np.zeros(inv_soft_mask.shape)
          blanc=np.ones(inv_soft_mask.shape)
        else:
          maskpool = inv_soft_mask*blanc+(1 - inv_soft_mask)*maskpool

        # Faces are composited with a hard (0/1) mask; the soft mask is kept
        # separately as the alpha channel.
        inv_hard_mask=np.array(inv_soft_mask, copy=True)
        inv_hard_mask[np.where(inv_hard_mask!=0)]=1.0
        restorepool = inv_hard_mask * pasted_face + (1 - inv_hard_mask) * restorepool

    if restorepool is None:
        # No faces were restored: emit an empty, fully transparent image
        # instead of failing on the unset pools.
        restorepool = np.zeros((h_up, w_up, 3))
        maskpool = np.zeros((h_up, w_up, 1))

    if np.max(upsample_img) > 256:  # 16-bit image
        upsample_img = np.concatenate((restorepool, maskpool*65535), axis=2).astype(np.uint16)
    else:
        upsample_img = np.concatenate((restorepool, maskpool*255), axis=2).astype(np.uint8)
    if save_path is not None:
        path = os.path.splitext(save_path)[0]
        save_path = f'{path}.{self.save_ext}'
        imwrite(upsample_img, save_path)
    return upsample_img


def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch tensor(s) to numpy image(s).

    Values are clamped to `min_max` and rescaled to [0, 1]. The input
    tensor is never modified.

    Args:
        tensor: A tensor or a list of tensors. After squeezing a leading
            size-1 dimension each must be 4D (B, C, H, W), 3D (C, H, W)
            or 2D (H, W).
        rgb2bgr: Convert RGB to BGR for colour outputs.
        out_type: numpy dtype of the result; for `np.uint8` the values are
            scaled to [0, 255] and rounded first.
        min_max: (min, max) range used for clamping and normalisation.

    Returns:
        A numpy array, or a list of arrays when several tensors were given.

    Raises:
        TypeError: If the input is not a tensor / list of tensors, or a
            tensor has an unsupported number of dimensions.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        # Out-of-place clamp: squeeze/float/detach/cpu can all return views
        # of the caller's tensor, so clamp_ would overwrite the input.
        _tensor = _tensor.squeeze(0).float().detach().cpu().clamp(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])

        n_dim = _tensor.dim()
        if n_dim == 4:
            # Imported here because make_grid is not imported alongside this
            # function; torchvision is already a dependency of this file.
            from torchvision.utils import make_grid
            img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
            img_np = img_np.transpose(1, 2, 0)
            if rgb2bgr:
                img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = img_np.transpose(1, 2, 0)
            if img_np.shape[2] == 1:  # gray image
                img_np = np.squeeze(img_np, axis=2)
            else:
                if rgb2bgr:
                    img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    if len(result) == 1:
        result = result[0]
    return result

def doenh_gfp(cropped_face_t):
  """Restore one aligned face tensor with the GFPGAN encoder/decoder pair.

  Both TorchScript modules are loaded on first use (while the global
  `gfpgan_enc` is still None) and moved to `cudevg`.
  """
  global gfpgan_enc, gfpgan_dec
  if gfpgan_enc is None:
    gfpgan_enc = torch.jit.load('gfpgan_enc_pnnx.pt').eval().to(cudevg)
    gfpgan_dec = torch.jit.load('gfpgan_dec_pnnx.pt').eval().to(cudevg)

  latent, conditions = gfpgan_enc(cropped_face_t)
  return gfpgan_dec(latent, *conditions)

# Enhancement backend used by enhance().
doenh=doenh_gfp

@torch.no_grad()
def enhance(img, has_aligned=False, only_center_face=False, paste_back=True):
  """Restore every face in `img` and optionally paste the results back.

  Args:
      img: Input image array.
      has_aligned: Treat `img` as a single, already aligned face (it is
          resized to 512x512 and no detection is run).
      only_center_face: Forwarded to face detection.
      paste_back: Build the pasted-back result image.

  Returns:
      `(faces, restored_img)`; `restored_img` is None when `has_aligned`
      is set or `paste_back` is False.
  """
  faces = faceimg(img)

  if not has_aligned:
      # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
      faces.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
      faces.align_warp_face()
  else:
      faces.cropped_faces = [cv2.resize(img, (512, 512))]

  for cropped_face in faces.cropped_faces:
      # Scale to [0, 1], then normalise to [-1, 1] and add a batch axis.
      face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
      normalize(face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
      face_t = face_t.unsqueeze(0).to(cudevg)

      output = doenh(face_t)
      faces.add_restored_face(tensor2img(output[0].cpu(), rgb2bgr=True, min_max=(-1, 1)))

  if has_aligned or not paste_back:
      return faces, None

  faces.get_inverse_affine(None)
  # paste each restored face to the input image
  return faces, faces.paste_faces_to_input_image()



# Download all four TorchScript files once. NOTE: only the RetinaFace file is
# checked, so a partially completed download is not detected.
if not os.path.isfile('retinaface_pnnx.pt'):
  !wget https://huggingface.co/Larvik/GFPGANjit/resolve/main/face_parse_pnnx.pt
  !wget https://huggingface.co/Larvik/GFPGANjit/resolve/main/gfpgan_dec_pnnx.pt
  !wget https://huggingface.co/Larvik/GFPGANjit/resolve/main/gfpgan_enc_pnnx.pt
  !wget https://huggingface.co/Larvik/GFPGANjit/resolve/main/retinaface_pnnx.pt

# Device used by the face detector, the face parser and the GFPGAN models.
GFPgan_device='cpu' #@param ['cpu', 'cuda']
cudevg=torch.device(GFPgan_device)


# The GFPGAN encoder/decoder are loaded lazily by doenh_gfp(); None marks
# "not loaded yet". The detector and parser are loaded right away.
gfpgan_enc=None
RetinaFace =torch.jit.load('retinaface_pnnx.pt').eval().to(cudevg)
face_parse =torch.jit.load('face_parse_pnnx.pt').eval().to(cudevg)


In [None]:
# Input image file or directory, and the folder that receives all results.
# NOTE: `input` shadows the builtin; the name is kept because it is a
# notebook form field.
input='/content/aaa2_4x.png' #@param {type:'string'}
output='results'

# `upscale` is read by faceimg: the input is treated as already upscaled by
# this factor.
upscale=4
suffix=None
only_center_face=False
aligned=False
ext='auto'


# ------------------------ input & output ------------------------
if input.endswith('/'):
    input = input[:-1]
if os.path.isfile(input):
    img_list = [input]
else:
    img_list = sorted(glob.glob(os.path.join(input, '*')))

os.makedirs(output, exist_ok=True)


# ------------------------ restore ------------------------
for img_path in img_list:
    # read image
    img_name = os.path.basename(img_path)
    print(f'Processing {img_name} ...')
    basename = os.path.splitext(img_name)[0]
    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if input_img is None:
        # A directory listing can contain sub-folders or non-image files;
        # cv2.imread returns None for those instead of raising.
        print(f'Skipping {img_name}: not a readable image.')
        continue

    # restore faces and background if necessary
    faces, restored_img = enhance(input_img, has_aligned=aligned, only_center_face=only_center_face, paste_back=True)

    # save faces
    for idx, (cropped_face, restored_face) in enumerate(zip(faces.cropped_faces, faces.restored_faces)):
        # save cropped face
        save_crop_path = os.path.join(output, 'cropped_faces', f'{basename}_{idx:02d}.png')
        imwrite(cropped_face, save_crop_path)
        # save restored face
        if suffix is not None:
            save_face_name = f'{basename}_{idx:02d}_{suffix}.png'
        else:
            save_face_name = f'{basename}_{idx:02d}.png'
        save_restore_path = os.path.join(output, 'restored_faces', save_face_name)
        imwrite(restored_face, save_restore_path)
        # save comparison image
        cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
        imwrite(cmp_img, os.path.join(output, 'cmp', f'{basename}_{idx:02d}.png'))

    # save restored img
    # The pasted result carries an alpha channel, so it is always written as
    # PNG regardless of the input extension.
    if restored_img is not None:
        if suffix is not None:
            save_restore_path = os.path.join(output, 'restored_imgs', f'{basename}_{suffix}.png')
        else:
            save_restore_path = os.path.join(output, 'restored_imgs', f'{basename}.png')
        imwrite(restored_img, save_restore_path)

print(f'Results are in the [{output}] folder.')


# txt2img

In [None]:
# Model variant selection for the txt2img section; both values are passed to
# f_dljit() below.
SDver='470k' #@param ['440k', '470k']
Dfm='Orig' #@param ['Orig', '_imgemb','_a19561','_a17750','_a17750_e5750']
# 'Orig' is only a form label: it is replaced by an empty string before use.
if Dfm=='Orig':
  Dfm=''

import os

f_dljit(SDver,Dfm)

import sys
import time
import random
import numpy as np
import cv2
from PIL import Image
import PIL
from transformers import CLIPTokenizer

import torch

# Inference-only setup: gradients are disabled globally for everything that
# runs after this cell.
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
# Enable cuDNN autotuning and allow TF32 / reduced-precision fp16 reductions.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

# Cumulative alpha products, loaded from key 'a' of the .npz file.
alphas_cumprod = np.load('alphas_cumprod.npz')['a']

cudev=torch.device('cuda')



# Flat list alternating a character and an integer.
# NOTE(review): presumably special prompt characters paired with token ids -
# confirm against the code that consumes inzdict.
inzdict=[
'*',265,
'»',7599,
'¿',17133,
'¥',20199,
'®',8436]




# Optional initial image state read by predict(); None means "start from noise".
preimg=None
revpreimg=None



# ddpm
def apply_model(x, t, cond):
    """Run the three-part diffusion network and return its output.

    `diffusion_emb` produces the hidden state, the embedding and a sequence
    of intermediate tensors; entries from index 6 on feed `diffusion_mid`,
    the first six feed `diffusion_out`.
    """
    h, emb, skips = diffusion_emb(x, t, cond)
    h = diffusion_mid(h, emb, cond, *skips[6:])
    return diffusion_out(h, emb, cond, *skips[:6])


# decoder
def decode_first_stage(z):
    """Decode latents with `autoencoder` after dividing by the 0.18215 scale."""
    return autoencoder(z/0.18215)


    
def load_img(path):
    """Load an image file as a (1, 3, H, W) float tensor in [-1, 1].

    The image is converted to RGB and, when needed, resized with Lanczos
    resampling so that both sides are multiples of 32 (rounded down).
    """
    image = Image.open(path).convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from {path}")

    # Round each side down to the nearest multiple of 32.
    w2 = w - w % 32
    h2 = h - h % 32
    if w != w2 or h != h2:
        image = image.resize((w2, h2), resample=PIL.Image.LANCZOS)

    pixels = np.array(image).astype(np.float32) / 255.0
    # HWC -> 1 x C x H x W
    batch = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2))
    return 2.*batch - 1.

# Maximum nesting depth of '.txt' prompt references accepted by makeCs().
depthLimit=10

def txtErr(prmt0,msg):
  """Report a prompt-file problem and fall back to a plain-text prompt.

  Prints `msg`, then encodes the last path component of `prmt0` with its
  final 4 characters removed as an ordinary prompt.
  """
  print(msg)
  fallback = prmt0.rsplit('/', 1)[-1][:-4]
  print('err prompt: ' + fallback)
  return [cond_stage_model.encode(fallback, n_samples)]

def intptxtemb(stz):
  """Build a list of linearly interpolated prompt embeddings.

  `stz` alternates prompt and step-count lines. Consecutive prompts are
  blended over `steps + 1` positions each. If the last prompt's count is
  greater than zero the sequence is blended back to the first prompt;
  otherwise the last embedding is appended once.
  """
  def blend(start, end, steps):
    # `steps` evenly spaced mixes, starting at `start` and excluding `end`.
    return [(end * i + start * (steps - i)) / steps for i in range(steps)]

  pair_count = len(stz) >> 1
  embeddings = []
  step_counts = []
  for idx in range(pair_count):
    embeddings.append(makeCs(stz[2 * idx], 1)[0])
    step_counts.append(int(stz[2 * idx + 1]) + 1)

  frames = []
  for start, end, steps in zip(embeddings, embeddings[1:], step_counts):
    frames.extend(blend(start, end, steps))

  closing_steps = step_counts[-1]
  if closing_steps > 1:
    # Loop back from the last prompt to the first one.
    frames.extend(blend(embeddings[pair_count - 1], embeddings[0], closing_steps))
  else:
    frames.append(embeddings[-1])
  return frames

def mkcondlist(stz):
  """Build a per-step conditioning schedule from prompt/count line pairs.

  `stz` alternates prompt and step-count lines; each prompt's embedding is
  repeated for its number of steps. Returns the schedule wrapped in a
  one-element list.
  """
  pair_count = len(stz) >> 1
  embeddings = []
  repeats = []
  for idx in range(pair_count):
    embeddings.append(makeCs(stz[2 * idx], 1)[0])
    repeats.append(int(stz[2 * idx + 1]))

  schedule = [0] * sum(repeats)
  cursor = 0
  for emb, count in zip(embeddings, repeats):
    for _ in range(count):
      schedule[cursor] = emb
      cursor += 1
  return [schedule]



def encodepatt():
  """Encode every PNG in the output-pattern directory to a raw latent file.

  Uses the global `output_pattern`: its directory part is scanned, each
  '.png' file is encoded with `imgenc`, scaled by 0.18215 and dumped as a
  '.bin' file next to it. A '.txt' file with the latent shape and the number
  of encoded files is written, and the PNG files are then deleted.
  """
  # Split output_pattern into directory and file-name pattern.
  ozi=output_pattern.split('/')[-1]
  pdir=output_pattern[:-len(ozi)-1]
  flist=os.listdir(pdir)
  flist.sort()
  pdir+='/'
  # The first directory entry (whatever its extension) defines the image size.
  rpt=load_img(pdir+flist[0])
  vB=1
  vH=rpt.size(2)
  vW=rpt.size(3)
  # Latent shape: 4 channels at 1/8 of the image height and width.
  thsize=torch.Size([vB,4,vH>>3,vW>>3])
  # The same noise tensor is passed to imgenc for every image.
  noyaz=torch.randn(thsize)
  zadd=0
  for f in flist:
    if f.endswith('.png'):
      vlat=imgenc(  load_img( pdir+f) ,  noyaz )*0.18215
      vlat.numpy().tofile(pdir+f[:-3]+'bin')
      zadd+=1
  # Metadata file: pattern path with '%' replaced by '!@!' and a 'txt' suffix;
  # first line is the latent shape, second line the number of latents.
  with open(output_pattern[:-3].replace('%','!@!')+'txt','wt') as f:
    f.write(str(list(thsize))[1:-1]+'\n'+str(zadd))
  # IPython shell magic: remove the source PNGs once they are encoded.
  !rm {pdir}*.png





def makeCs(prmt,depth):
  """Turn a prompt specification into a one-element list of conditioning.

  `prmt` is interpreted by its suffix:
    * '.txt' - a recipe file. Its first line is a command ('intp...',
      'dymc...', 'avg...' or anything else); the remaining lines alternate
      a prompt (possibly another reference) and a number.
    * '.bin' - raw float32 data reshaped to (-1, 768) and broadcast to
      (n_samples, 77, 768) on the GPU.
    * anything else - plain text encoded with `cond_stage_model`.

  `depth` is the current reference nesting level. On errors the result of
  `txtErr` is returned instead of raising.

  Side effect: a 'dymc' recipe sets the globals `get_cond` and
  `ddim_num_steps`.
  """
  global get_cond
  global ddim_num_steps
  if prmt.endswith('.txt'):
    # Guard against reference cycles.
    if depth > depthLimit:
      return txtErr(prmt,'Too many ref, probably circular reference.')
    depth+=1
    if not os.path.isfile(prmt):
      return txtErr(prmt,'ref not found.')
    with open(prmt,'rt') as f:
      stz=f.read().splitlines()
    # Command line: spaces and tabs removed, fields separated by '/'.
    cmd=stz[0].replace(' ','').replace('\t','').split('/')
    cmd0=cmd[0]
    if cmd0.startswith('intp'):
      # Interpolation is only allowed in the file named by the top-level
      # prompt (depth is 1 there after the increment above).
      if depth > 1:
        return txtErr(stz[1],'do not intp in ref')
      return intptxtemb(stz[1:])
    elif cmd0.startswith('dymc'):
      if depth > 1:
        return txtErr(stz[1],'do not dymc in ref')
      # Switch to per-step conditioning; ddim_num_steps=-1 tells predict()
      # to take the step count from the length of the returned list.
      get_cond=get_cond_list
      ddim_num_steps=-1
      return mkcondlist(stz[1:])


    # Default: weighted combination of (prompt, weight) line pairs.
    prmpl=(len(stz)-1)>>1
    stz=stz[1:]
    ptxt=[]
    pwgt=[]
    wgtsum=0
    for i in range(prmpl):
      ptxt.append(  makeCs(stz[2*i],depth)[0]  )
      wgt=float(stz[2*i+1])
      wgtsum+=wgt
      pwgt.append(  wgt  )
    # 'avg' normalises the weights so that they sum to 1.
    if cmd0.startswith('avg'):
      for i in range(prmpl):
        pwgt[i]=pwgt[i]/wgtsum
    
    cout=ptxt[0]*pwgt[0]
    for i in range(1,prmpl):
      cout+=(ptxt[i]*pwgt[i])
    return [cout]
  elif prmt.endswith('.bin'):
    return [torch.tensor(np.fromfile(prmt,dtype=np.float32)).reshape((-1,768)).broadcast_to(n_samples,77,768).cuda()]
  else:
    return [cond_stage_model.encode(prmt, n_samples)]


def warmup():
  """Run ``apply_model`` and ``autoencoder`` twice each on random dummy inputs.

  The diffusion call uses half-precision inputs under fp16 autocast; the
  autoencoder call uses a float32 input.  Results are discarded and the CUDA
  cache is emptied at the end.
  """
  dummy = torch.rand(2, 4, 32, 32, dtype=torch.float).half().cuda()
  dummy_idx = torch.randint(10, (2, ), dtype=torch.long).cuda()
  dummy_ctx = torch.rand(2, 77, 768, dtype=torch.float).half().cuda()
  for _ in range(2):
    with torch.cuda.amp.autocast(dtype=torch.float16):
      result = apply_model(dummy, dummy_idx, dummy_ctx)

  dummy = torch.rand(1, 4, 32, 32, dtype=torch.float).cuda()
  for _ in range(2):
    result = autoencoder(dummy)
  torch.cuda.empty_cache()
  



# filename suffix: (iteration index, image index in batch, conditioning counter)
fext='_%dx%dv%d.png'
def saver():
  """Write the current latents and decoded images to disk.

  Takes no arguments; everything is read from module globals: ``iita``
  (iteration index), ``ktta`` (conditioning counter), ``outputp`` (path
  prefix), ``samples`` (latents) and ``x_samples`` (decoded images).  The
  latents go to one ``.npy`` file; every image in ``x_samples`` is rescaled
  from [-1, 1] to [0, 255], converted to HWC / BGR and written with
  ``cv2.imwrite``.  The global ``x_samples`` is replaced by the clipped numpy
  array as a side effect.
  """
  global x_samples
  iteration=iita
  latent_path=(outputp+fext%(iteration,1,ktta))[:-4] + '.npy'
  np.save(latent_path, samples)
  x_samples = np.clip((x_samples.numpy() + 1.0) / 2.0, a_min=0.0, a_max=1.0)
  for idx, chw in enumerate(x_samples):
      hwc = chw.transpose(1, 2, 0) * 255  # CHW -> HWC, scaled to 0..255
      bgr = hwc.astype(np.uint8)[:, :, ::-1]  # RGB -> BGR for OpenCV
      cv2.imwrite(outputp+fext%(iteration,idx,ktta), bgr)
  

UseSamplr=sample_lms

def predict(prompt, uc):
    """Sample, decode and save one image batch per conditioning built from ``prompt``.

    ``makeCs`` turns ``prompt`` into a list of conditionings.  Each one is
    sampled with ``UseSamplr`` under fp16 autocast, decoded with
    ``decode_first_stage`` and written by ``saver`` on a background thread.
    ``saver`` takes no arguments and reads the globals ``samples``,
    ``x_samples`` and ``ktta`` published here, so a running save is always
    joined before those globals are overwritten, and the last save is joined
    before returning.

    The starting latent depends on the global ``preimg``:

    * ``None``: noise scaled by ``sigmas[0]``, full sigma schedule.
    * 1-D tensor whose first value is 2: the feeder is configured from the
      global ``tmpfeeder`` (pattern, shape, xpenlen) to read latent ``.bin``
      files, and the first conditioning is repeated ``xpenlen`` times.
    * any other tensor: ``preimg`` plus noise scaled to the first sigma of the
      schedule truncated according to the global ``strength``.

    Args:
      prompt: prompt spec forwarded to ``makeCs``.
      uc: unconditional conditioning, passed to the sampler as ``uncond``.
    """
    global x_samples
    global samples
    global ktta
    global tmpfeeder
    global noise
    global ddim_num_steps
    global get_cond
    get_cond=get_cond_simp

    c_list = makeCs(prompt,0)

    # makeCs sets ddim_num_steps to -1 for 'dymc' recipes; the step count is then
    # taken from the length of the first returned entry.
    if ddim_num_steps < 0:
      ddim_num_steps=len(c_list[0])

    feeder=ifeeder()

    sigmas = f_sigmas()

    noise = torch.randn(shape, dtype=torch.float,device=cudev)
    if preimg is not None:
      t_enc= int(strength * ddim_num_steps)
      first=ddim_num_steps - t_enc - 1
      sigma_sched = sigmas[first:]
      if preimg.dim()==1:
        cmd0=int(preimg[0])
        if cmd0 == 2:
          feeder.pattern=tmpfeeder.pattern
          feeder.shape=tmpfeeder.shape
          feeder.getn=feeder.get_npbins
          feeder.noiseadd=noise * sigmas[first]
          c_list=[c_list[0]]*tmpfeeder.xpenlen
          feeder.xpenlen=tmpfeeder.xpenlen
          tmpfeeder=feeder
      else:
        img = preimg.cuda() + noise * sigmas[first]
        feeder.setbs(img)
    else:
      img = noise*sigmas[0]
      sigma_sched=sigmas
      feeder.setbs(img)


    ktta=0
    pending=None
    for c in c_list:
      extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
      with torch.cuda.amp.autocast(dtype=torch.float16):
        latents = UseSamplr(model_wrap_cfg, feeder.getn(ktta), sigma_sched, extra_args=extra_args, disable=False)

      # The previous save ran in parallel with the sampling above; make sure it
      # is done before the globals it reads are replaced.
      if pending is not None:
        pending.join()
      samples = latents
      ktta+=1

      x_samples = decode_first_stage(  samples ).cpu()
      samples=samples.cpu()
      pending = Thread(target = saver)
      pending.start()

    # Finish the last save before the caller changes iita or resets ktta.
    if pending is not None:
      pending.join()
    return

predict_orig=predict

def init_img_type():
  """Classify the global ``init_img`` path by its suffix.

  Returns:
    0 for a ``.npy`` latent, or for a ``.jpg``/``.png`` that already has a
      ``<path>.npy`` file next to it (``init_img`` is then switched to that
      file);
    1 for a ``.jpg``/``.png`` without such a file;
    2 for a ``.txt`` path;
    99 for anything else.
  """
  global init_img
  if init_img.endswith('.npy'):
    return 0
  if init_img.endswith(('.jpg', '.png')):
    cached = init_img+'.npy'
    if not os.path.isfile(cached):
      return 1
    init_img = cached
    return 0
  return 2 if init_img.endswith('.txt') else 99


# One-time model setup.  Guarded by model_wrap (which must already be defined,
# possibly as None) so that re-running the cell does not reload everything.
if model_wrap is None:
  # text encoder wrapper around the TorchScript transformer
  cond_stage_model = BERTEmbedder(torch.jit.load('transformer_pnnx.pt').eval())
  # the diffusion model is split over three TorchScript files, run in fp16 on the GPU
  diffusion_emb = torch.jit.load('diffusion_emb_pnnx.pt').eval().half().cuda()
  diffusion_mid = torch.jit.load('diffusion_mid_pnnx.pt').eval().half().cuda()
  diffusion_out = torch.jit.load('diffusion_out_pnnx.pt').eval().half().cuda()
  # latent decoder, kept in the loaded precision on the GPU
  autoencoder = torch.jit.load('autoencoder_pnnx.pt').eval().cuda()
  SDlatDEC=autoencoder
  # image -> latent encoder; no .cuda() call here
  imgenc = torch.jit.load('imgencoder_pnnx.pt').eval()
  warmup()
  torch.cuda.empty_cache()

  # denoiser wrappers handed to the samplers (model_wrap_cfg is what predict uses)
  model_wrap = CompVisDenoiser(CompVisJIT())
  model_wrap_cfg = CFGDenoiser(model_wrap)

👇Optional👇

In [None]:
# Sampler choice.  f_sampler() is defined outside this section; presumably it
# reads Sampler and selects the sampling function - confirm.
Sampler='euler_a' #@param ['euler', 'euler_a', 'heun','dpm_2','dpm_2_a','lms']
f_sampler()

# NOTE(review): Karras / KarrasRho are not read in this section; presumably
# f_sigmas() uses them to pick the Karras schedule and its rho - confirm.
Karras=False #@param {type:'boolean'}
KarrasRho = 7.0 #@param {type:'number'}

In [None]:
# cond_stage_model.insert is defined outside this section; presumably it
# registers these placeholder tokens with the text encoder - confirm.
cond_stage_model.insert('<>')
cond_stage_model.insert('<majipuri>')
cond_stage_model.insert('<pekora>')

In [None]:
# Optional init image / latent.  init_img_type() classifies the path:
#   0 = saved latent (.npy), 1 = .jpg/.png still to be encoded,
#   2 = .txt descriptor of a latent pack, anything else = no init image.
init_img='xxx' #@param {type:'string'}
strength=0.5 #@param {type:'number'}
initymgtyp=init_img_type()
if initymgtyp == 0:
  # latent used as-is; batch size and pixel size (latent size << 3) come from its shape
  preimg=torch.tensor(np.load(init_img), device='cpu')
  n_samples=preimg.size(0)
  H=preimg.size(2)<<3
  W=preimg.size(3)<<3
elif initymgtyp == 1:
  # encode the picture once and store the latent next to it as <init_img>.npy,
  # which init_img_type() picks up on later runs
  n_samples=1
  rpt=load_img(init_img)
  H=rpt.size(2)
  W=rpt.size(3)
  preimg=imgenc(  rpt, torch.randn(torch.Size([n_samples,4,H>>3,W>>3]))  )*0.18215
  np.save(init_img+'.npy',preimg.numpy())
elif initymgtyp == 2:
  # descriptor in the format written by encodepatt(): line 1 is the
  # comma-separated latent shape, line 2 the number of frames; the .bin pattern
  # is the descriptor name with '!@!' turned back into '%'
  tmpfeeder=ifeeder()
  tmpfeeder.pattern=init_img[:-3].replace('!@!','%')+'bin'
  with open(init_img,'rt') as f:
    stz=f.read().replace(' ','').replace('\t','').splitlines()
  tmpfeeder.xpenlen=int(stz[1])
  stz=stz[0].split(',')
  tmpfeeder.shape=[ int(stz[0]), int(stz[1]), int(stz[2]), int(stz[3]) ]
  n_samples=tmpfeeder.shape[0]
  H=tmpfeeder.shape[2]<<3
  W=tmpfeeder.shape[3]<<3
  # 1-D marker tensor: predict() checks preimg.dim()==1 and a first value of 2
  preimg=torch.tensor([2])
else:
  # no usable init image; revpre is only reset in this branch
  revpre=revpre0
  preimg=None
revpreimg=None

infilling

In [None]:
# Infilling setup.  Requires preimg to be a tensor (set by the init-image cell);
# with preimg=None the .cuda() call below fails.
FillFromNoise=False #@param {type:'boolean'}
masknpy='bench2_mask.npy' #@param {type:'string'}
zamask=torch.tensor(np.load(masknpy)).cuda()
# GPU copy of the init latent, kept even when sampling starts from noise
revpreimg=preimg.cuda()
if FillFromNoise:
  preimg=None
# revpre1 is defined outside this section; presumably the masked re-injection
# hook that uses zamask and revpreimg - confirm.
revpre=revpre1

Prompt interpolation with latent re-feeding

In [None]:
Revert2Orig=False #@param {type:'boolean'}


def predict(prompt, uc):
    """Variant of ``predict`` that re-feeds each result as the next init latent.

    The first conditioning is sampled from pure noise.  Every later one starts
    from the previous latent plus noise scaled to the first sigma of the
    schedule truncated according to the global ``strength``.  Before each
    re-feed the noise tensor is permuted and reshaped back to its original
    shape, which reorders its values.

    ``saver`` runs on a background thread and reads the globals ``samples``,
    ``x_samples`` and ``ktta``; a running save is joined before those globals
    are overwritten and the last save is joined before returning.

    NOTE(review): unlike ``predict_orig`` this variant neither resets
    ``get_cond`` nor handles ``ddim_num_steps == -1`` (set by 'dymc' recipes).

    Args:
      prompt: prompt spec forwarded to ``makeCs``.
      uc: unconditional conditioning, passed to the sampler as ``uncond``.
    """
    global x_samples
    global samples
    global ktta
    global tmpfeeder
    global noise
    # local on purpose: the global preimg is ignored by this variant
    preimg=None


    c_list = makeCs(prompt,0)
    feeder=ifeeder()

    sigmas = f_sigmas()

    noise = torch.randn(shape, dtype=torch.float,device=cudev)

    ktta=0
    pending=None
    for c in c_list:

      if preimg is not None:
        noise=torch.permute(noise, (0,3,1,2)).reshape(noise.shape)
        t_enc= int(strength * ddim_num_steps)
        first=ddim_num_steps - t_enc - 1
        sigma_sched = sigmas[first:]
        img = preimg.cuda() +  noise* sigmas[first]
        feeder.setbs(img)
      else:
        img = noise*sigmas[0]
        sigma_sched=sigmas
        feeder.setbs(img)

      extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
      with torch.cuda.amp.autocast(dtype=torch.float16):
        latents = UseSamplr(model_wrap_cfg, feeder.getn(ktta), sigma_sched, extra_args=extra_args, disable=False)

      # The previous save ran in parallel with the sampling above; make sure it
      # is done before the globals it reads are replaced.
      if pending is not None:
        pending.join()
      samples = latents
      ktta+=1
      preimg=samples
      x_samples = decode_first_stage( samples ).cpu()
      samples=samples.cpu()

      pending = Thread(target = saver)
      pending.start()

    # Finish the last save before the caller changes iita or resets ktta.
    if pending is not None:
      pending.join()
    return
if Revert2Orig:
  predict=predict_orig
else:
  strength=0.75 #@param {type:'number'}

NoiseMap interpolation<br>re-feed previous when strength > 0

In [None]:
Revert2Orig=False #@param {type:'boolean'}


def mknoises():
  """Build a closed loop of interpolated noise tensors.

  The global ``Seed_Interval_list`` is read as ``seed, interval`` pairs (a
  trailing unpaired value is ignored).  Each seed produces one noise tensor of
  the global ``shape`` on ``cudev``; a seed below 1 is replaced by a random one,
  which is printed.  Consecutive tensors (the last one wraps around to the
  first) are joined by ``interval + 1`` frames starting at the segment's first
  tensor and stopping short of its end.  Frames are spherically interpolated,
  or linearly when the two tensors are almost parallel (|cosine| > 0.9995).

  Returns:
    list of noise tensors, one per frame.
  """
  pair_count=len(Seed_Interval_list)//2

  # one anchor tensor per seed, plus the first again to close the loop
  anchors=[]
  for idx in range(pair_count):
    seed_value=Seed_Interval_list[2*idx]
    if seed_value < 1:
      seed_value=random.randint(0, 2**32)
      print('seed%d='%idx)
      print(seed_value)
    torch.manual_seed(seed_value)
    anchors.append( torch.randn(shape, dtype=torch.float,device=cudev) )
  anchors.append(anchors[0])

  frames=[]
  for idx in range(pair_count):
    count=Seed_Interval_list[2*idx+1]+1
    start=anchors[idx]
    end=anchors[idx+1]
    fractions=[j/count for j in range(count)]
    dot = torch.sum(start * end / (torch.linalg.norm(start) * torch.linalg.norm(end)))
    if torch.abs(dot) > 0.9995:
      # nearly (anti)parallel: plain linear blend
      frames.extend( (1 - t) * start + t * end for t in fractions )
      continue
    theta_0 = torch.acos(dot)
    sin_theta_0 = torch.sin(theta_0)
    for t in fractions:
      theta_t = theta_0 * t
      w_start = torch.sin(theta_0 - theta_t) / sin_theta_0
      w_end = torch.sin(theta_t) / sin_theta_0
      frames.append( w_start * start + w_end * end )

  return frames
  


def predict(prompt, uc):
    """Variant of ``predict`` that walks through the noise maps from ``mknoises``.

    The conditioning list from ``makeCs`` is repeated once per noise map, and
    iteration ``k`` uses ``noise[k]``.  While the global ``strength`` is > 0,
    each result scaled by ``(1 - strength)`` is re-fed as the init latent of the
    next iteration, with a schedule truncated according to ``strength``;
    otherwise every iteration starts from its noise map scaled by ``sigmas[0]``.

    ``saver`` runs on a background thread and reads the globals ``samples``,
    ``x_samples`` and ``ktta``; a running save is joined before those globals
    are overwritten and the last save is joined before returning.

    NOTE(review): unlike ``predict_orig`` this variant neither resets
    ``get_cond`` nor handles ``ddim_num_steps == -1`` (set by 'dymc' recipes).

    Args:
      prompt: prompt spec forwarded to ``makeCs``.
      uc: unconditional conditioning, passed to the sampler as ``uncond``.
    """
    global x_samples
    global samples
    global ktta
    global tmpfeeder
    global noise
    # local on purpose: the global preimg is ignored by this variant
    preimg=None


    c_list = makeCs(prompt,0)
    feeder=ifeeder()

    sigmas = f_sigmas()

    noise = mknoises()
    c_list=c_list*len(noise)

    ktta=0
    pending=None
    for c in c_list:

      if preimg is not None:
        t_enc= int(strength * ddim_num_steps)
        first=ddim_num_steps - t_enc - 1
        sigma_sched = sigmas[first:]
        img = preimg.cuda() +  noise[ktta]* sigmas[first]
        feeder.setbs(img)
      else:
        img = noise[ktta]*sigmas[0]
        sigma_sched=sigmas
        feeder.setbs(img)

      extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
      with torch.cuda.amp.autocast(dtype=torch.float16):
        latents = UseSamplr(model_wrap_cfg, feeder.getn(ktta), sigma_sched, extra_args=extra_args, disable=False)

      # The previous save ran in parallel with the sampling above; make sure it
      # is done before the globals it reads are replaced.
      if pending is not None:
        pending.join()
      samples = latents
      ktta+=1
      if strength > 0:
        preimg=samples*(1-strength)
      x_samples = decode_first_stage( samples ).cpu()
      samples=samples.cpu()

      pending = Thread(target = saver)
      pending.start()

    # Finish the last save before the caller changes iita or resets ktta.
    if pending is not None:
      pending.join()
    return
if Revert2Orig:
  predict=predict_orig
else:
  strength=0 #@param {type:'number'}
  Seed_Interval_list=[    775577,10,    881188,10,    996699,10    ] #@param {type:'raw'}

☝️Optional☝️

In [None]:
# Main generation parameters (Colab form fields).
# InThread: run the generation loop on a background thread instead of blocking the cell.
InThread=False #@param {type:'boolean'}

prompt = 'a photograph of an astronaut riding a horse' #@param {type:'string'}
neg_prompt = '' #@param {type:'string'}

n_iter = 1 #@param {type:'integer'}
# batch size and resolution are only taken from the form when no init latent
# fixed them already (the init-image cell sets n_samples / H / W itself)
if preimg is None and revpreimg is None:
  n_samples = 1 #@param {type:'integer'}
  H=704 #@param {type:'integer'}
  W=768 #@param {type:'integer'}



ddim_num_steps = 50  #@param {type:'integer'}
ddpm_num_timesteps = 1000

seed=0 #@param {type:'integer'}

outputp='/content/sample_data' #@param {type:'string'}







cfg_scale = 7.5 #@param {type:'number'}
ddim_eta = 1.0  #@param {type:'number'}


# outputp becomes '<dir>/<number of entries currently in dir>'; saver() appends
# the fext suffix to it, so each run gets its own filename prefix
outputp=outputp+'/'+str(len(os.listdir(outputp)))
"""
ddim_timesteps
"""

ddim_timesteps = make_ddim_timesteps(
    ddim_num_steps, ddpm_num_timesteps)

"""
ddim sampling parameters
"""

ddim_sigmas, ddim_alphas, ddim_alphas_prev = \
    make_ddim_sampling_parameters(
        alphacums=alphas_cumprod,
        ddim_timesteps=ddim_timesteps,
        eta=ddim_eta)

ddim_sqrt_one_minus_alphas = np.sqrt(1. - ddim_alphas)

# latent shape: the pixel size divided by 8 (>>3), 4 channels
shape = [n_samples, 4, H>>3 , W>>3 ]


# makerng() is defined outside this section; presumably it seeds the RNG from
# `seed` - confirm.
makerng()


print("prompt: %s" % prompt)



print('Start inference...')
# the negative prompt is only encoded when guidance is active (cfg_scale != 1)
uc = None
if cfg_scale != 1.0:
  uc = cond_stage_model.encode(neg_prompt, n_samples)


  
def wpa():
  """Run ``predict(prompt, uc)`` ``n_iter`` times.

  The loop variable is the global ``iita``, which ``saver`` reads to name its
  output files.  Gradient tracking is switched off first, and the CUDA cache
  is emptied once all iterations are done.
  """
  global iita
  # disabled inside the function so it also applies when wpa runs on its own
  # Thread (presumably because grad mode is per-thread - confirm)
  torch.set_grad_enabled(False)

  for iita in range(n_iter):
      print("iteration: %s" % (iita + 1))
      predict(prompt, uc)

  print('Script finished successfully.')
  torch.cuda.empty_cache()


# Either block the cell until generation finishes, or hand it to a background thread.
if not InThread:
  wpa()
else:
  t1 = Thread(target = wpa)
  a1 = t1.start()


In [None]:
# Show GPU status and memory usage.
!nvidia-smi

In [None]:
# Assemble saved frames into a video at 3 frames per second.  The input path is
# an example following the outputp + fext naming; adjust it to the actual run.
!ffmpeg -framerate 3 -i /content/sample_data/48_0x3v%d.png intp03.mp4

# Tools
These tools are designed to be used while generation is running with `InThread` or in the Gradio app,<br>
so `imgenc` (the image-to-latent encoder) runs on the CPU.

Gif/Video to latent pack

In [None]:
# Split an animation into numbered PNG frames; output_pattern is also read by
# encodepatt(), which encodes the frames in that directory.
input_anim  = '/content/senpai.gif' #@param {type:'string'}
output_pattern = '/content/ijj/senpai_%04d.png' #@param {type:'string'}
!ffmpeg -i {input_anim} {output_pattern}

Resize the output frames to `(64*n)x(64*m)` (width and height both multiples of 64) before running the next cell.

In [None]:
# Encode the PNG frames next to output_pattern into latent .bin files plus a
# .txt descriptor; the PNG frames are deleted afterwards.
encodepatt()

In [None]:
# NOTE(review): ImageEmbSetup is not read anywhere in this section - confirm
# whether it is still needed.
ImageEmbSetup  = False #@param {type:'boolean'}

import os
import torch
from PIL import Image
from torchvision import transforms
def load_im(im_path):
    """Load an image as a ``(1, 3, 224, 224)`` tensor with values in [-1, 1].

    Args:
        im_path: local file path, or a URL starting with ``http``.

    Returns:
        The RGB image resized so its shorter side is 224, center-cropped to
        224x224, converted to a tensor and rescaled from [0, 1] to [-1, 1].

    Raises:
        requests.HTTPError: if downloading a URL fails.
    """
    if im_path.startswith("http"):
        # imported here because this cell does not import them itself
        import requests
        from io import BytesIO
        response = requests.get(im_path)
        response.raise_for_status()
        im = Image.open(BytesIO(response.content))
    else:
        im = Image.open(im_path)
    # always force 3 channels, for downloads as well as local files
    im = im.convert("RGB")
    tforms = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
    ])
    inp = tforms(im).unsqueeze(0)
    return inp*2-1
# Download the TorchScript image-embedding model once, then load it in float32.
if not os.path.isfile('imgemb.pt'):
  !wget https://huggingface.co/Larvik/imgemb/resolve/main/imgemb.pt
imgemb=torch.jit.load('imgemb.pt').float()

In [None]:
# Embed one image and dump the raw values to a .bin file (presumably usable as a
# '.bin' prompt for makeCs, which reads float32 rows of 768 values - confirm).
imgemb(load_im('/content/chaz512.jpg')).numpy().tofile('chaz.bin')

# Gradio Gui
Though I don't really understand why you would want a web UI inside another web UI.

In [None]:
# NOTE(review): this flag is not read anywhere in this section, and the name
# differs from the module imported below only because that one is bound as `gr`.
gradio=False #@param {type:'boolean'}

!pip install gradio
from google.colab import output
import gradio as gr

# NOTE(review): placeholder callback. It accepts no arguments and returns
# nothing, while the interface below declares 11 inputs and 2 outputs, so it
# has to be implemented before the UI can generate anything.
def dream():
  return


# Form layout for the text-to-image tab; the inputs are listed in the order
# they would be passed to the callback.
dream_interface = gr.Interface(
    dream,
    inputs=[
        gr.Textbox(placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
        gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
        gr.Checkbox(label='Enable PLMS sampling', value=False),
        gr.Checkbox(label='Enable Fixed Code sampling', value=False),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
        gr.Slider(minimum=1, maximum=50, step=1, label='Sampling iterations', value=8),
        gr.Slider(minimum=1, maximum=8, step=1, label='Samples per iteration', value=1),
        gr.Slider(minimum=1.0, maximum=20.0, step=0.5, label='Classifier Free Guidance Scale', value=7.0),
        gr.Number(label='Seed', value=-1),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=704),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=768),
    ],
    outputs=[
        gr.Gallery(),
        gr.Number(label='Seed')
    ],
    title="Stable Diffusion Text-to-Image",
    description="Generate images from text with Stable Diffusion",
)


gdemo = gr.TabbedInterface(interface_list=[dream_interface], tab_names=["Dream"])


# Expose kernel port 8233 through Colab; the printed link is what the next cell
# expects in GoogleLocal.
output.serve_kernel_port_as_window(8233, path='/dl.htm')

Copy the link above to `GoogleLocal`

In [None]:
# Launch the Gradio app only once a Colab proxy link has been pasted in.
# NOTE(review): GoogleLocal is only checked here, it is not passed to launch().
GoogleLocal = 'aaaaa' #@param {type:'string'}
if '.googleusercontent.com' in GoogleLocal:
  gdemo.launch()
else:
  print('set a valid GoogleLocal')

# glid-3-xl-stable

In [None]:
# Model variant selection; both values are passed to f_dljit() further down.
SDver='470k' #@param ['440k', '470k']
Dfm='Orig' #@param ['Orig', '_imgemb','_a19561','_a17750','_a17750_e5750']
# 'Orig' means "no suffix"
if Dfm=='Orig':
  Dfm=''
import os
import torch
from torch import nn
from torch.nn import functional as F


# Inference-only process: no autograd, all CPU cores, cuDNN autotuning, and
# TF32 / reduced-precision fp16 reductions allowed for matmuls and convolutions.
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True


def get_keys_to_submodule(model):
    """Map each parameter's state-dict key to the submodule that directly owns it.

    Args:
        model: the ``nn.Module`` to inspect.

    Returns:
        dict of ``"<submodule path>.<parameter name>"`` -> owning submodule.
        Parameters registered on ``model`` itself are keyed by their bare name,
        matching the keys of ``model.state_dict()``.
    """
    keys_to_submodule = {}
    for submodule_name, submodule in model.named_modules():
        # recurse=False yields only the parameters this submodule owns itself
        for param_name, _ in submodule.named_parameters(recurse=False):
            # the root module has an empty name; do not emit a leading dot
            key = f"{submodule_name}.{param_name}" if submodule_name else param_name
            keys_to_submodule[key] = submodule

    return keys_to_submodule

def load_state_dict_with_low_memory(model, state_dict):
    """Install ``state_dict`` tensors on ``model`` one parameter at a time.

    Each parameter found by ``get_keys_to_submodule`` is replaced through
    ``setattr`` on its owning submodule with a new non-trainable
    ``nn.Parameter``.  A key missing from ``state_dict`` is printed and filled
    with float32 ones of the shape the model reports for it.

    Args:
        model: module whose parameters are replaced in place.
        state_dict: mapping of parameter key -> tensor.
    """
    print('======hacky load======')
    owners = get_keys_to_submodule(model)
    reference = model.state_dict()
    for key, owner in owners.items():
        if key not in state_dict:
            # report the missing key and fall back to ones
            print(key)
            tensor = torch.ones(reference[key].shape, dtype= torch.float32)
        else:
            tensor = state_dict[key]

        leaf_name = key.rsplit('.', 1)[-1]
        setattr(owner, leaf_name, torch.nn.Parameter(tensor,requires_grad=False))






class imgencdec:
  """Adapter exposing ``encode`` / ``decode`` on top of the globals ``imgenc`` and ``autoencoder``."""

  def encode(self,im):
    """Encode an image batch with ``imgenc``, using fresh noise at 1/8 of the spatial size."""
    batch=im.size(0)
    lat_h=im.size(2)>>3
    lat_w=im.size(3)>>3
    noise_in=torch.randn(torch.Size([batch,4,lat_h,lat_w]))
    return imgenc(  im, noise_in  )

  def decode(self,im):
    """Decode a latent batch with ``autoencoder``."""
    return autoencoder(im)


# f_dljit is defined outside this section; presumably it downloads the
# TorchScript files for the selected variant - confirm.
f_dljit(SDver,Dfm)

# Fetch the helper script and the guided_diffusion package once.
if not os.path.isfile('/content/guided_diffusion/unet.py'):
  !wget https://raw.githubusercontent.com/TabuaTambalam/DalleWebms/main/docs/sd/jkt.py
  !git clone https://github.com/Jack000/glid-3-xl-stable.git
  !mv /content/glid-3-xl-stable/guided_diffusion /content/guided_diffusion 

from transformers import CLIPTokenizer
# Same TorchScript modules as the main setup cell, but the diffusion parts are
# not converted to half precision here.
cond_stage_model = BERTEmbedder(torch.jit.load('transformer_pnnx.pt').eval())
diffusion_emb = torch.jit.load('diffusion_emb_pnnx.pt').eval().cuda()
diffusion_mid = torch.jit.load('diffusion_mid_pnnx.pt').eval().cuda()
diffusion_out = torch.jit.load('diffusion_out_pnnx.pt').eval().cuda()
autoencoder = torch.jit.load('autoencoder_pnnx.pt').eval().cuda()
SDlatDEC=autoencoder
# image -> latent encoder; no .cuda() call here
imgenc = torch.jit.load('imgencoder_pnnx.pt').eval()

In [None]:
import gc
import io
import math
import sys

from PIL import Image, ImageOps
import requests

from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm

import numpy as np

from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults


from accelerate import init_empty_weights
from einops import rearrange
from math import log2, sqrt


!mkdir output_npy
!mkdir output

def save_sample(i, sample, clip_score=False):
    for k, image in enumerate(sample['pred_xstart'][:1]):
        image /= 0.18215
        im = image.unsqueeze(0)
        out = ldm.decode(im)

        npy_filename = f'output_npy/{i * batchsz + k:05}.npy'
        with open(npy_filename, 'wb') as outfile:
            np.save(outfile, image.detach().cpu().numpy())

        out = TF.to_pil_image(out.squeeze(0).add(1).div(2).clamp(0, 1))

        filename = f'output/{i * batchsz + k:05}.png'
        out.save(filename)


# Classifier-free guidance wrapper around the global `model`
def model_fn(x_t, ts, **kwargs):
    """Evaluate ``model`` with classifier-free guidance.

    Only the first half of ``x_t`` is used; it is duplicated so that ``model``
    sees the same latents twice (the two halves differ only through
    ``kwargs``).  On the first three output channels the first-half and
    second-half predictions are mixed as
    ``uncond + guidance_scale * (cond - uncond)`` and the result is written to
    both halves; the remaining channels are passed through unchanged.

    NOTE(review): the split at channel 3 means a fourth latent channel is not
    guided - confirm this is intended.
    """
    n_half = len(x_t) // 2
    first_half = x_t[:n_half]
    prediction = model(torch.cat([first_half, first_half], dim=0), ts, **kwargs)
    eps = prediction[:, :3]
    passthrough = prediction[:, 3:]
    cond_eps, uncond_eps = eps.split(len(eps) // 2, dim=0)
    guided = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    return torch.cat([torch.cat([guided, guided], dim=0), passthrough], dim=1)

device = torch.device('cuda:0')
print('Using device:', device)



# Architecture / diffusion settings passed to create_model_and_diffusion.
model_params = {
    'attention_resolutions': '32,16,8',
    'class_cond': False,
    'diffusion_steps': 1000,
    'rescale_timesteps': True,
    'timestep_respacing': '50',  # Modify this value to decrease the number of
                                 # timesteps.
    'image_size': 32,
    'learn_sigma': False,
    'noise_schedule': 'linear',
    'num_channels': 320,
    'num_heads': 8,
    'num_res_blocks': 2,
    'resblock_updown': False,
    'use_fp16': False,
    'use_scale_shift_norm': False,
    'clip_embed_dim': None, #768,
    'image_condition': False,
    #'image_condition': True if model_state_dict['input_blocks.0.0.weight'].shape[1] == 8 else False,
    'super_res_condition': False,
}

# overrides the '50' set in the dict above
model_params['timestep_respacing'] = '100'

model_config = model_and_diffusion_defaults()
model_config.update(model_params)


# overrides the 'use_fp16': False set in the dict above
model_config['use_fp16'] = True

# Load models
# The model is created without allocating real weights; the actual tensors are
# installed parameter by parameter afterwards.
with init_empty_weights():
  model, diffusion = create_model_and_diffusion(**model_config)

# mkmodel_state_dict is defined outside this section
load_state_dict_with_low_memory(model,mkmodel_state_dict())

if model_config['use_fp16']:
  model.convert_to_fp16()

In [None]:


# Inference mode, on the GPU, with a fixed RNG seed.
model.requires_grad_(False).eval().to(device)


torch.manual_seed(114514)


# vae

ldm=imgencdec()


# guidance_scale and batchsz are read as globals by model_fn / save_sample
guidance_scale=7
height=832
width=896
batchsz=1


args_text='thicc farm girl, long blonde hair, japanimation, by Alfons Maria Mucha, cinematic lightning, cinematic wallpaper'
args_negative=''
# clip context


text_emb = cond_stage_model.encode(args_text,batchsz)
text_emb_blank = cond_stage_model.encode(args_negative,batchsz)

image_embed = None



# NOTE(review): input_image is only referenced by the disabled code below
input_image = torch.zeros(batchsz, 4, height//8, width//8, device=device)
'''
lat=torch.tensor(np.load('96_4x1v1.npy'))


input_image[0][:,:,:32]=lat[0][:,:,:32]
'''

      
image_embed = None #torch.cat(batchsz*2*[input_image], dim=0).float()



# Conditional context first, blank/negative context second: model_fn treats the
# first half of the batch as conditional and the second half as unconditional.
kwargs = {
    "context": torch.cat([text_emb, text_emb_blank], dim=0).half().cuda(),
    "clip_embed": None,
    "image_embed": image_embed
}



cur_t = None

sample_fn = diffusion.plms_sample_loop_progressive



'''
init = Image.open('xipooh.jpg').convert('RGB')

init = TF.to_tensor(init).to(device).unsqueeze(0).clamp(0,1)
h = ldm.encode(init * 2 - 1) *  0.18215
init = torch.cat(1*2*[h], dim=0)
'''
init=None

for i in range(1):
    cur_t = diffusion.num_timesteps - 1
    with torch.cuda.amp.autocast(dtype=torch.float16):
      # batch is doubled (batchsz*2) to hold the conditional and unconditional halves
      samples = sample_fn(
          model_fn,
          (batchsz*2, 4, height>>3, width>>3),
          clip_denoised=False,
          model_kwargs=kwargs,
          cond_fn=None,
          device=device,
          progress=True,
          init_image=init,
          skip_timesteps=0,
      )

    # Iterating drives the sampler step by step; only the value left in
    # `sample` after the loop is saved.
    for j, sample in enumerate(samples):
        cur_t -= 1

    save_sample(i, sample)
torch.cuda.empty_cache()
