# Common
Always run this cell whenever you start or restart the runtime.

In [None]:
import math
import torch
from torch import nn
from scipy import integrate
from threading import Thread

from PIL import Image
import numpy as np
from tqdm.auto import trange, tqdm

import shutil
import subprocess
from google.colab import output


def runpyproc(na):
  """Start /content/<na>.py in a background `python` process.

  The process is launched with close_fds=True and is not waited on;
  nothing is returned to the caller.
  """
  script_path = '/content/' + na + '.py'
  subprocess.Popen(['python', script_path], close_fds=True)

def append_dims(x, target_dims):
    """Return x with trailing singleton dims added until x has target_dims dims.

    Raises ValueError if x already has more than target_dims dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        return x[(...,) + (None,) * missing]
    raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')


def append_zero(x):
    """Return x with a single 0 appended (same dtype and device as x)."""
    tail = x.new_zeros([1])
    return torch.cat((x, tail))


def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cuda'):
    """Build the Karras et al. (2022) noise schedule.

    Returns n sigmas, linearly interpolated in sigma**(1/rho) space from
    sigma_max down to sigma_min, followed by a trailing 0, on `device`.
    """
    inv_rho = 1 / rho
    hi = sigma_max ** inv_rho
    lo = sigma_min ** inv_rho
    steps = torch.linspace(0, 1, n, device=device)
    schedule = (hi + steps * (lo - hi)) ** rho
    return append_zero(schedule).to(device)


def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Build n sigmas evenly spaced in log-space from sigma_max to sigma_min, plus a trailing 0."""
    log_sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device)
    return append_zero(torch.exp(log_sigmas))


def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Build a continuous VP noise schedule, plus a trailing 0.

    For n values of t spaced linearly from 1 down to eps_s:
    sigma(t) = sqrt(exp(beta_d * t**2 / 2 + beta_min * t) - 1).
    """
    t = torch.linspace(1, eps_s, n, device=device)
    exponent = beta_d * t ** 2 / 2 + beta_min * t
    return append_zero(torch.sqrt(torch.exp(exponent) - 1))


def to_d(x, sigma, denoised):
    """Return the Karras ODE derivative (x - denoised) / sigma.

    sigma gets trailing singleton dims so that it broadcasts against x.
    """
    sigma_b = append_dims(sigma, x.ndim)
    return (x - denoised) / sigma_b


def get_ancestral_step(sigma_from, sigma_to):
    """Split an ancestral step from sigma_from to sigma_to.

    Returns (sigma_down, sigma_up):
    - sigma_down is the noise level the deterministic update steps to.
    - sigma_up is the scale of the fresh noise added afterwards.
    sigma_down**2 + sigma_up**2 == sigma_to**2 (up to rounding).
    """
    to_sq = sigma_to ** 2
    up = (to_sq * (sigma_from ** 2 - to_sq) / sigma_from ** 2) ** 0.5
    down = (to_sq - up ** 2) ** 0.5
    return down, up


@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Euler sampler (Algorithm 2 of Karras et al. 2022) with optional churn.

    The model is called as model(i, hlog0.revpre(x, sigmas, i), sigma_hat_batch,
    **extra_args). The derivative is then taken against x itself.
    """
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        sigma = sigmas[i]
        if s_tmin <= sigma <= s_tmax:
            gamma = min(s_churn / n_steps, 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Noise is drawn on every step, even when gamma == 0 and it goes unused.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5
        denoised = model(i, hlog0.revpre(x, sigmas, i), sigma_hat * ones, **extra_args)
        slope = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma_hat, 'denoised': denoised})
        # Plain Euler update from sigma_hat to the next sigma.
        x = x + slope * (sigmas[i + 1] - sigma_hat)
    return x


@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Ancestral sampling with Euler steps.

    On each step the sampler:
    1. Denoises at sigmas[i]; the model input is hlog0.revpre(x, sigmas, i).
    2. Takes an Euler step down to sigma_down.
    3. Adds Gaussian noise scaled by sigma_up (see get_ancestral_step).
    """
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        denoised = model(i, hlog0.revpre(x, sigmas, i), sigma * ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})
        slope = to_d(x, sigma, denoised)
        x = x + slope * (sigma_down - sigma)
        x = x + torch.randn_like(x) * sigma_up
    return x


@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Heun sampler (Algorithm 2 of Karras et al. 2022) with optional churn.

    - The first model call sees hlog0.revpre(x, sigmas, i).
    - The corrector call sees the Euler predictor directly.
    - A step whose next sigma is 0 falls back to a plain Euler update.
    """
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        gamma = min(s_churn / n_steps, 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.
        # Noise is drawn on every step, even when gamma == 0 and it goes unused.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5
        denoised = model(i, hlog0.revpre(x, sigmas, i), sigma_hat * ones, **extra_args)
        slope = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigma_next - sigma_hat
        x_pred = x + slope * dt
        if sigma_next == 0:
            x = x_pred
        else:
            denoised_pred = model(i, x_pred, sigma_next * ones, **extra_args)
            slope_pred = to_d(x_pred, sigma_next, denoised_pred)
            x = x + (slope + slope_pred) / 2 * dt
    return x


@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """DPM-Solver-2 style midpoint sampler with Karras churn.

    The first model call sees hlog0.revpre(x, sigmas, i). The midpoint call
    sees the half-step state directly.
    """
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    n_steps = len(sigmas) - 1
    for i in trange(n_steps, disable=disable):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        gamma = min(s_churn / n_steps, 2 ** 0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.
        # Noise is drawn on every step, even when gamma == 0 and it goes unused.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigma ** 2) ** 0.5
        denoised = model(i, hlog0.revpre(x, sigmas, i), sigma_hat * ones, **extra_args)
        slope = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma_hat, 'denoised': denoised})
        # Midpoint taken on a rho = 3 Karras schedule between sigma_hat and sigma_next.
        sigma_mid = ((sigma_hat ** (1 / 3) + sigma_next ** (1 / 3)) / 2) ** 3
        x_mid = x + slope * (sigma_mid - sigma_hat)
        denoised_mid = model(i, x_mid, sigma_mid * ones, **extra_args)
        slope_mid = to_d(x_mid, sigma_mid, denoised_mid)
        x = x + slope_mid * (sigma_next - sigma_hat)
    return x


@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """Ancestral sampler with DPM-Solver-2 style midpoint steps.

    The sampler takes a midpoint step down to sigma_down, then adds fresh noise
    of scale sigma_up (see get_ancestral_step). The first model call sees
    hlog0.revpre(x, sigmas, i).
    """
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma = sigmas[i]
        denoised = model(i, hlog0.revpre(x, sigmas, i), sigma * ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigma, sigmas[i + 1])
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})
        slope = to_d(x, sigma, denoised)
        # Midpoint taken on a rho = 3 Karras schedule between sigma and sigma_down.
        sigma_mid = ((sigma ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
        x_mid = x + slope * (sigma_mid - sigma)
        denoised_mid = model(i, x_mid, sigma_mid * ones, **extra_args)
        slope_mid = to_d(x_mid, sigma_mid, denoised_mid)
        x = x + slope_mid * (sigma_down - sigma)
        x = x + torch.randn_like(x) * sigma_up
    return x


def linear_multistep_coeff(order, t, i, j):
    """Return the j-th linear multistep coefficient for step i.

    The j-th Lagrange basis polynomial is built on the nodes
    t[i], t[i-1], ..., t[i-order+1]. It is integrated over [t[i], t[i+1]]
    with scipy's quad.

    Raises ValueError when step i has fewer than order - 1 previous nodes.
    """
    if i < order - 1:
        raise ValueError(f'Order {order} too high for step {i}')
    node_j = t[i - j]

    def basis(tau):
        value = 1.
        for k in range(order):
            if k != j:
                value *= (tau - t[i - k]) / (node_j - t[i - k])
        return value

    return integrate.quad(basis, t[i], t[i + 1], epsrel=1e-4)[0]

class area4:
  """Plan the per-side (W/E/N/S) strips for a tensor and its padded canvas.

  Construction:
  - calcUnCrop4(side) is defined elsewhere. It is called with 0..3 and
    unpacked into (lap, pad, all) for W, E, N and S respectively.
  - skey is tenz.shape with dims 2 / 3 (stored as O_h / O_w) grown by the
    N+S and W+E pads.
  - A 4-bit mask of the sides whose `all` is > 0 (W=1, E=2, N=4, S=8)
    selects one of the sixteen c???? methods from jlist. getshapes() calls
    that method.
  - The four lap sizes are forwarded to the shared hlog0 via setWENS.

  Every c???? method returns a list of 5-tuples:
      (dim2_start, dim2_end, dim3_start, dim3_end, code)
  NOTE(review): these look like slice bounds on dims 2/3 of the padded
  tensor plus a handler code (0-3 per side; 4-7 in the single-side
  cases), presumably indices into hlogger.funclist2. Confirm against the
  caller.
  """
  def __init__(self,tenz):
    # Index = W*1 + E*2 + N*4 + S*8 (one bit per side with a non-zero `all`).
    jlist=[self.cxxxx,self.cWxxx,self.cxExx,self.cWExx,
        self.cxxNx,self.cWxNx,self.cxENx,self.cWENx,
        self.cxxxS,self.cWxxS,self.cxExS,self.cWExS,
        self.cxxNS,self.cWxNS,self.cxENS,self.cWENS]
    skey=list(tenz.shape)
    self.O_h=skey[2]
    self.O_w=skey[3]
    self.Wlap,self.Wpad,self.Wall=calcUnCrop4(0)
    self.Elap,self.Epad,self.Eall=calcUnCrop4(1)
    self.Nlap,self.Npad,self.Nall=calcUnCrop4(2)
    self.Slap,self.Spad,self.Sall=calcUnCrop4(3)
    # Padded canvas shape: original size plus the pads on both sides.
    skey[2]=self.O_h+self.Npad+self.Spad
    skey[3]=self.O_w+self.Wpad+self.Epad
    self.skey=skey
    njmp=0
    if self.Wall > 0:
      njmp+=1
    if self.Eall > 0:
      njmp+=2
    if self.Nall > 0:
      njmp+=4
    if self.Sall > 0:
      njmp+=8
    self.calc=jlist[njmp]
    hlog0.setWENS(self.Wlap,self.Elap,self.Nlap,self.Slap)

  def getshapes(self):
    # Return the strip plan for the active side combination.
    return self.calc()

  def cxxxx(self):
    # No side active: nothing to copy.
    return []
  def simpNS(self):
    # Full-width top (code 2) and bottom (code 3) strips.
    return [(0,self.Nall, 0,None, 2),(-self.Sall,None, 0,None, 3)]
  def simpWE(self):
    # Full-height left (code 0) and right (code 1) strips.
    return [(0,None, 0,self.Wall ,0),(0,None, -self.Eall,None ,1)]
  #== single side (codes 4-7)
  def cWxxx(self):
    Npad=self.Npad
    return [(Npad,self.O_h+Npad,0,self.Wall,4)]
  def cxExx(self):
    Npad=self.Npad
    return [(Npad,self.O_h+Npad,-self.Eall,None,5)]
  def cxxNx(self):
    Wpad=self.Wpad
    return [(0,self.Nall,self.Wpad,self.O_w+Wpad,6)]
  def cxxxS(self):
    Wpad=self.Wpad
    return [(-self.Sall,None,self.Wpad,self.O_w+Wpad,7)]
  #== three sides: one strip limited to the original extent plus two full-length strips
  def cxENS(self):
    Npad=self.Npad
    return [(Npad,self.O_h+Npad,-self.Eall,None,1)]+self.simpNS()
  def cWxNS(self):
    Npad=self.Npad
    return [(Npad,self.O_h+Npad,0,self.Wall,0)]+self.simpNS()
  def cWExS(self):
    Wpad=self.Wpad
    return [(-self.Sall,None,self.Wpad,self.O_w+Wpad,3)]+self.simpWE()
  def cWENx(self):
    Wpad=self.Wpad
    return [(0,self.Nall,self.Wpad,self.O_w+Wpad,2)]+self.simpWE()
  #== two opposite sides: concatenation of the single-side plans
  def cWExx(self):
    return self.cWxxx()+self.cxExx()
  def cxxNS(self):
    return self.cxxNx()+self.cxxxS()
  #== two adjacent sides: the shared corner is given to one strip, chosen by
  #   comparing padded height skey[2] with padded width skey[3]
  def cWxNx(self):
    if self.skey[2] > self.skey[3]: #h>w
      return [(self.Npad,None, 0,self.Wall   ,0),(0,self.Nall, 0,None     ,2)]
    return   [(0,self.Nall,   self.Wpad,None ,2),(0,None,    0,self.Wall  ,0)]
  def cxENx(self):
    if self.skey[2] > self.skey[3]:
      return [(self.Npad,None, -self.Eall,None, 1),(0,self.Nall, 0,None,    2)]
    return   [(0,self.Nall,  0,-self.Epad,   2),(0,None,   -self.Eall,None,1)]
  def cWxxS(self):
    if self.skey[2] > self.skey[3]:
      return [(0,-self.Spad,  0,self.Wall   ,0),(-self.Sall,None, 0,None  ,3)]
    return   [(-self.Sall,None, self.Wpad,None ,3),(0,None,     0,self.Wall,0)]
  def cxExS(self):
    if self.skey[2] > self.skey[3]:
      return [(0,-self.Spad,  -self.Eall,None ,1),(-self.Sall,None,  0,None    ,3)]
    return   [(-self.Sall,None, 0,-self.Epad,  3),(0,None,      -self.Eall,None,1)]
  #== all four sides: ordering depends on aspect (h vs w) and on which pad is larger
  def cWENS(self):
    Wpad=self.Wpad
    Npad=self.Npad
    if self.skey[2] > self.skey[3]:
      if Npad > self.Spad:
        return [(-self.Sall,None, Wpad,self.O_w+Wpad,  3),(Npad,None  ,0,self.Wall,0),(Npad,None,-self.Eall,None,1)  ,(0,self.Nall,   0,None  ,2)]
      else:
        return [(0,self.Nall,   Wpad,self.O_w+Wpad,  2),(0,-self.Spad,0,self.Wall,0),(0,-self.Spad,-self.Eall,None,1)  ,(-self.Sall,None, 0,None  ,3)]
    else:
      if Wpad > self.Epad:
        return [(Npad,self.O_h+Npad, -self.Eall,None ,1),(0,self.Nall, Wpad,None   ,2),(-self.Sall,None, Wpad,None ,3)  ,(0,None, 0,self.Wall  ,0)]
      else:
        return [(Npad,self.O_h+Npad, 0,self.Wall   ,0),(0,self.Nall, 0,-self.Epad ,2),(-self.Sall,None, 0,-self.Epad ,3) ,(0,None, -self.Eall,None ,1)]

def arrmover(arr, itm, n):
  """Return the n-th per-slot buffer in arr, creating it on first use.

  - itm is None: returns None and leaves arr untouched.
  - len(arr) == n: appends a new list of len(itm) Nones and returns it.
  - otherwise: returns arr[n].
  """
  if itm is None:
    return None
  if len(arr) != n:
    return arr[n]
  fresh = [None] * len(itm)
  arr.append(fresh)
  return fresh

def mulifnotnone(v,r):
  """Return int(0.5 + v * r), or None when v is None.

  For non-negative products this rounds to the nearest int, with halves
  rounded up.
  """
  return None if v is None else int(0.5 + v * r)

class hlogger:
  """Mutable state plus hook functions shared with the samplers through `hlog0`.

  Two hook families:
  - `revpre` / `revpre_nocpy` take (img, sigmas, t) and return a latent.
    The samplers in this file feed hlog0.revpre(x, sigmas, i) to the model.
    Several variants write into `img` in place and return the same tensor.
  - `Arevpre` takes (h, d) and returns None. The variants read/write `h`
    in place and keep CPU copies in `h_bs[d]`.
  Hooks are selected through setfunc / setfuncb / setfuncN / setfuncNb /
  setfuncNc from the dispatch tables built in __init__.

  NOTE(review): several hooks read module-level globals that are not
  defined in this chunk (`noise`, `preimg`, `revpreimg`, `zamask`).
  Confirm they are set before such a hook is selected.
  """
  def __init__(self):
    # Active hooks; revpre0 / Arevpre0 are no-ops.
    self.Arevpre = self.Arevpre0
    self.revpre = self.revpre0
    self.revpre_nocpy = self.revpre0
    # Per-slot storage; latlog / h_bs point at the active slot's buffer
    # (swapped by set_multinm via arrmover).
    self.latlog_arr=[]
    self.h_bs_arr=[]
    self.latlog=None
    self.h_bs=None
    self.h_bsB=None
    # Per-side overlap sizes (set by setWENS); the *2 fields hold double the size.
    self.Wlap=None
    self.Elap=None
    self.Nlap=None
    self.Slap=None
    self.Wlap2=None
    self.Elap2=None
    self.Nlap2=None
    self.Slap2=None
    # Dispatch tables for the setfunc* selectors.
    self.funclist={'0':self.Arevpre0,'logw0':self.logw0,'logw':self.logw,'loghs':self.loghs}
    self.funclistb={'0':self.revpre0,'masking':self.revpreMSK,'1s':self.revpre1s,'log':self.revpre0_log}
    self.funclist2=[self.revpreW,self.revpreE,self.revpreN,self.revpreS,
            self.revpreW_nocpy,self.revpreE_nocpy,self.revpreN_nocpy,self.revpreS_nocpy]
    self.funclist2b=[self.bW,self.bE,self.bN,self.bS]
    self.funclist2c=[self.bWsimp,self.bEsimp,self.bNsimp,self.bSsimp]
    # 99 means "no hook cached" (activefuncN2x only acts on values < 99).
    self.func2Nb_cache=99
    self.func2Nc_cache=99


  def setWENS(self,Wlap,Elap,Nlap,Slap):
    # Store the overlap size for each side plus its doubled value (x << 1).
    self.Wlap=Wlap
    self.Elap=Elap
    self.Nlap=Nlap
    self.Slap=Slap
    self.Wlap2=Wlap<<1
    self.Elap2=Elap<<1
    self.Nlap2=Nlap<<1
    self.Slap2=Slap<<1

  def activefuncN2x(self, nx_cache, funclist):
    # Re-apply a revpre hook cached by setfuncNb/setfuncNc(..., cache=True).
    # The check compares h_bsB with the global `noise` along one dim:
    # dim 2 for cache index 0/1, dim 3 otherwise. On a size mismatch the hook
    # falls back to the no-op revpre0.
    if nx_cache < 99:
      ndm=3
      if nx_cache < 2:
        ndm=2
      if self.h_bsB.size(ndm) != noise.size(ndm):
        self.revpre = self.revpre0
      else:
        self.revpre=funclist[nx_cache]



  def set_multinm(self,n,cur_h,dst_h,cur_w,dst_w):
    # Switch to slot n:
    # 1. Select the n-th latlog / h_bs buffers.
    # 2. Re-apply any cached hooks.
    # 3. Rescale the overlap sizes by cur/dst (or restore the originals when
    #    the sizes match).
    # NOTE(review): the *_orig values are only captured when n == 0. A first
    # call with n != 0 therefore raises AttributeError below.
    self.latlog = arrmover(self.latlog_arr, self.latlog, n)
    self.h_bs = arrmover(self.h_bs_arr, self.h_bs, n)
    self.activefuncN2x(self.func2Nb_cache, self.funclist2b)
    self.activefuncN2x(self.func2Nc_cache, self.funclist2c)
    if n == 0:
      self.Wlap_orig=self.Wlap
      self.Elap_orig=self.Elap
      self.Nlap_orig=self.Nlap
      self.Slap_orig=self.Slap
      self.Wlap2_orig=self.Wlap2
      self.Elap2_orig=self.Elap2
      self.Nlap2_orig=self.Nlap2
      self.Slap2_orig=self.Slap2
      
    if cur_h != dst_h:
      r=cur_h/dst_h
      self.Nlap=mulifnotnone(self.Nlap_orig,r)
      self.Nlap2=mulifnotnone(self.Nlap2_orig,r)
      self.Slap=mulifnotnone(self.Slap_orig,r)
      self.Slap2=mulifnotnone(self.Slap2_orig,r)
    else:
      self.Nlap2=self.Nlap2_orig
      self.Slap2=self.Slap2_orig
      self.Nlap=self.Nlap_orig
      self.Slap=self.Slap_orig

    if cur_w != dst_w:
      r=cur_w/dst_w
      self.Elap=mulifnotnone(self.Elap_orig,r)
      self.Elap2=mulifnotnone(self.Elap2_orig,r)
      self.Wlap=mulifnotnone(self.Wlap_orig,r)
      self.Wlap2=mulifnotnone(self.Wlap2_orig,r)
    else:
      self.Elap2=self.Elap2_orig
      self.Wlap2=self.Wlap2_orig
      self.Elap=self.Elap_orig
      self.Wlap=self.Wlap_orig


  def setfunc(self,key):
    # Select the Arevpre hook by name ('0', 'logw0', 'logw', 'loghs').
    self.Arevpre=self.funclist[key]
  def setfuncb(self,key,key2='0'):
    # Select revpre and revpre_nocpy by name ('0', 'masking', '1s', 'log').
    self.revpre=self.funclistb[key]
    self.revpre_nocpy=self.funclistb[key2]
  def setfuncN(self,n):
    # Select an Arevpre strip hook by index (0-3 W/E/N/S, 4-7 the _nocpy variants).
    self.Arevpre=self.funclist2[n]

  def setfuncNb(self,n,cache=False):
    # Select bW/bE/bN/bS (indices 4-7 map onto 0-3). With cache=True the
    # choice is only stored for activefuncN2x.
    if n > 3:
      n-=4
    if cache:
      self.func2Nb_cache=n
    else:
      self.revpre=self.funclist2b[n]

  def setfuncNc(self,n,cache=False):
    # Same as setfuncNb, but for the *simp variants (no 4-7 remapping).
    if cache:
      self.func2Nc_cache=n
    else:
      self.revpre=self.funclist2c[n]

  def setbsB(self,fn,lat):
    # Update the h_bsB buffer read by the b*/b*simp hooks.
    # Codes 4-7 map onto 0-3.
    #  -10: store lat as-is.
    #  -11: store lat[..., -Elap2:-Elap] moved to CUDA.
    #  0-3: splice lat into the existing h_bsB along dim 3 (W/E) or dim 2 (N/S).
    if fn > 3:
      fn-=4
    if fn==-10:
      self.h_bsB=lat
    elif fn==-11:
      self.h_bsB=lat[:,:,:,-self.Elap2:-self.Elap].cuda()
    elif fn==0:
      self.h_bsB=torch.cat([ lat[:,:,:,:-self.Wlap], self.h_bsB[:,:,:,self.Wlap:] ],dim=3)
    elif fn==1:
      self.h_bsB=torch.cat([ self.h_bsB[:,:,:,:-self.Elap], lat[:,:,:,self.Elap:] ],dim=3)
    elif fn==2:
      self.h_bsB=torch.cat([ lat[:,:,:-self.Nlap,:], self.h_bsB[:,:,self.Nlap:,:] ],dim=2)
    elif fn==3:
      self.h_bsB=torch.cat([ self.h_bsB[:,:,:-self.Slap,:], lat[:,:,self.Slap:,:] ],dim=2)


  def revpre0(self,img,sigmas,t):
    # No-op hook: the model sees img unchanged.
    return img

  def revpreMSK(self,img,sigmas,t):
    # Blend: the global revpreimg plus noise * sigmas[t] where (1 - zamask)
    # applies, and img where zamask applies.
    return (revpreimg+noise * sigmas[t])*(1-zamask)+img*zamask

  def revpre1s(self,img,sigmas,t):
    # Ignore img and return the global preimg plus noise * sigmas[t].
    return preimg+(noise*sigmas[t])

  def revpre0_log(self,img,sigmas,t):
    # Append a CPU numpy snapshot of (img - noise*sigma) * (1 + sigma*0.18215)
    # to latlog, then pass img through unchanged.
    # NOTE(review): latlog starts as None, so it must hold a list before this
    # hook runs.
    self.latlog.append( ((img-noise*sigmas[t])*(1+sigmas[t]*0.18215) ).cpu().numpy())
    return img
  def Arevpre0(self,h,d):
    # No-op (h, d) hook.
    return
  def logw0(self,h,d):
    # Save the last Elap columns of h to CPU as h_bs[d].
    self.h_bs[d]=h[:,:,:,-self.Elap:].cpu()
    return

  def logw(self,h,d):
    # Overwrite h's first Elap columns with the saved strip, then save the new last strip.
    h[:,:,:,:self.Elap]=self.h_bs[d]
    self.logw0(h,d)
    return
  def loghs(self,h,d):
    # Save a full CPU copy of h as h_bs[d].
    self.h_bs[d]=h.cpu()
    return
  def revpreW(self,h,d):
    # Copy the stored tensor's first Wlap columns into h's last Wlap columns,
    # then prepend the rest of h (on CPU) to the stored tensor along dim 3.
    hbz=self.h_bs[d]
    h[:,:,:,-self.Wlap:]=hbz[:,:,:,:self.Wlap]
    self.h_bs[d]=torch.cat([ h[:,:,:,:-self.Wlap].cpu(), hbz ],dim=3)
    return
  def revpreE(self,h,d):
    # Mirror of revpreW: h's first Elap columns come from the stored tail;
    # the rest of h is appended along dim 3.
    hbz=self.h_bs[d]
    h[:,:,:,:self.Elap]=hbz[:,:,:,-self.Elap:]
    self.h_bs[d]=torch.cat([ hbz, h[:,:,:,self.Elap:].cpu() ],dim=3)
    return
  def revpreN(self,h,d):
    # Like revpreW, along dim 2 with Nlap rows.
    hbz=self.h_bs[d]
    h[:,:,-self.Nlap:,:]=hbz[:,:,:self.Nlap,:]
    self.h_bs[d]=torch.cat([ h[:,:,:-self.Nlap,:].cpu(), hbz ],dim=2)
    return
  def revpreS(self,h,d):
    # Like revpreE, along dim 2 with Slap rows.
    hbz=self.h_bs[d]
    h[:,:,:self.Slap,:]=hbz[:,:,-self.Slap:,:]
    self.h_bs[d]=torch.cat([ hbz,h[:,:,self.Slap:,:].cpu() ],dim=2)
    return

  # *_nocpy variants: only overwrite the strip of h; h_bs is left unchanged.
  def revpreW_nocpy(self,h,d):
    h[:,:,:,-self.Wlap:]=self.h_bs[d][:,:,:,:self.Wlap]
    return
  def revpreE_nocpy(self,h,d):
    h[:,:,:,:self.Elap]=self.h_bs[d][:,:,:,-self.Elap:]
    return
  def revpreN_nocpy(self,h,d):
    h[:,:,-self.Nlap:,:]=self.h_bs[d][:,:,:self.Nlap,:]
    return
  def revpreS_nocpy(self,h,d):
    h[:,:,:self.Slap,:]=self.h_bs[d][:,:,-self.Slap:,:]
    return


  # b* hooks: overwrite one edge strip of img in place with a slice of h_bsB
  # plus the matching slice of the global `noise` scaled by sigmas[t];
  # return the same img tensor.
  def bW(self,img,sigmas,t):
    img[:,:,:,-self.Wlap:]=self.h_bsB[:,:,:,self.Wlap:self.Wlap2]+(noise[:,:,:,-self.Wlap:]*sigmas[t])
    return img
  def bE(self,img,sigmas,t):
    img[:,:,:,:self.Elap]=self.h_bsB[:,:,:,-self.Elap2:-self.Elap]+(noise[:,:,:,:self.Elap]*sigmas[t])
    return img
  def bN(self,img,sigmas,t):
    img[:,:,-self.Nlap:,:]=self.h_bsB[:,:,self.Nlap:self.Nlap2,:]+(noise[:,:,-self.Nlap:,:]*sigmas[t])
    return img
  def bS(self,img,sigmas,t):
    img[:,:,:self.Slap,:]=self.h_bsB[:,:,-self.Slap2:-self.Slap,:]+(noise[:,:,:self.Slap,:]*sigmas[t])
    return img

  # b*simp hooks: same as b*, but h_bsB is used whole (it must already be strip-sized).
  def bWsimp(self,img,sigmas,t):
    img[:,:,:,-self.Wlap:]=self.h_bsB+(noise[:,:,:,-self.Wlap:]*sigmas[t])
    return img
  def bEsimp(self,img,sigmas,t):
    img[:,:,:,:self.Elap]=self.h_bsB+(noise[:,:,:,:self.Elap]*sigmas[t])
    return img
  def bNsimp(self,img,sigmas,t):
    img[:,:,-self.Nlap:,:]=self.h_bsB+(noise[:,:,-self.Nlap:,:]*sigmas[t])
    return img
  def bSsimp(self,img,sigmas,t):
    img[:,:,:self.Slap,:]=self.h_bsB+(noise[:,:,:self.Slap,:]*sigmas[t])
    return img

# Shared module-level hlogger instance. The samplers call hlog0.revpre(...)
# on every step, and area4 calls hlog0.setWENS(...).
hlog0=hlogger()

@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Linear multistep sampler.

    The sampler keeps the most recent `order` derivatives. It combines them with
    coefficients from linear_multistep_coeff, computed on a CPU numpy copy of
    sigmas. The model input on each step is hlog0.revpre(x, sigmas, i).
    """
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    sigmas_np = sigmas.detach().cpu().numpy()
    history = []
    for i in trange(len(sigmas) - 1, disable=disable):
        sigma = sigmas[i]
        denoised = model(i, hlog0.revpre(x, sigmas, i), sigma * ones, **extra_args)
        history.append(to_d(x, sigma, denoised))
        if len(history) > order:
            del history[0]
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigma, 'sigma_hat': sigma, 'denoised': denoised})
        k = min(i + 1, order)
        coeffs = [linear_multistep_coeff(k, sigmas_np, i, j) for j in range(k)]
        # Newest derivative pairs with coefficient j = 0.
        x = x + sum(c * d for c, d in zip(coeffs, reversed(history)))
    return x


@torch.no_grad()
def log_likelihood(model, x, sigma_min, sigma_max, extra_args=None, atol=1e-4, rtol=1e-4):
    """Estimate log p(x) by integrating the probability-flow ODE.

    The ODE is integrated from sigma_min to sigma_max with a Hutchinson trace
    estimator (a +/-1 probe v). The N(0, sigma_max) log-density of the final
    latent is then added. Returns (log_likelihood, {'fevals': count}).

    NOTE(review): `odeint` is not imported in this file chunk (presumably
    torchdiffeq.odeint). `model` is called here as model(x, sigma), whereas the
    samplers above call model(i, x, sigma). Confirm both before use.
    """
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    probe = torch.randint_like(x, 2) * 2 - 1
    fevals = 0

    def ode_fn(sigma, state):
        nonlocal fevals
        with torch.enable_grad():
            xt = state[0].detach().requires_grad_()
            denoised = model(xt, sigma * ones, **extra_args)
            d = to_d(xt, sigma, denoised)
            fevals += 1
            grad = torch.autograd.grad((d * probe).sum(), xt)[0]
            d_ll = (probe * grad).flatten(1).sum(1)
        return d.detach(), d_ll

    start = x, x.new_zeros([x.shape[0]])
    span = x.new_tensor([sigma_min, sigma_max])
    sol = odeint(ode_fn, start, span, atol=atol, rtol=rtol, method='dopri5')
    latent, delta_ll = sol[0][-1], sol[1][-1]
    prior = torch.distributions.Normal(0, sigma_max).log_prob(latent).flatten(1).sum(1)
    return prior + delta_ll, {'fevals': fevals}



class DiscreteSchedule(nn.Module):
    """Map between continuous noise levels (sigmas) and a discrete sigma table.

    `sigmas` is registered as a buffer. `quantize` chooses whether sigma_to_t
    snaps to the nearest table index (True) or interpolates between the two
    nearest entries (False).
    """

    def __init__(self, sigmas, quantize):
        super().__init__()
        self.register_buffer('sigmas', sigmas)
        self.quantize = quantize

    def get_sigmas(self, n=None):
        """Return a sigma schedule in reversed table order, with a trailing zero.

        With n=None the full table is returned reversed. Otherwise n values
        are interpolated at evenly spaced fractional indices from the last
        index down to 0.
        """
        if n is not None:
            last = len(self.sigmas) - 1
            ts = torch.linspace(last, 0, n, device=self.sigmas.device)
            return append_zero(self.t_to_sigma(ts))
        return append_zero(self.sigmas.flip(0))

    def sigma_to_t(self, sigma, quantize=None):
        """Convert sigma(s) to table indices (fractional unless quantized)."""
        if quantize is None:
            quantize = self.quantize
        dists = (sigma - self.sigmas[:, None]).abs()
        if quantize:
            return dists.argmin(dim=0).view(sigma.shape)
        nearest_two = dists.topk(2, dim=0, largest=False).indices
        low_idx, high_idx = nearest_two.sort(dim=0).values
        low = self.sigmas[low_idx]
        high = self.sigmas[high_idx]
        w = ((low - sigma) / (low - high)).clamp(0, 1)
        return ((1 - w) * low_idx + w * high_idx).view(sigma.shape)

    def t_to_sigma(self, t):
        """Linearly interpolate the sigma table at (fractional) indices t."""
        t = t.float()
        lo = t.floor().long()
        hi = t.ceil().long()
        frac = t.frac()
        return (1 - frac) * self.sigmas[lo] + frac * self.sigmas[hi]


class DiscreteEpsDDPMDenoiser(DiscreteSchedule):
    """Wrap a discrete-time DDPM model that predicts eps (the added noise).

    The sigma table is derived from alphas_cumprod as sqrt((1 - a) / a).
    """

    def __init__(self, model, alphas_cumprod, quantize):
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        super().__init__(sigmas, quantize)
        self.inner_model = model
        self.sigma_data = 1.

    def get_scalings(self, sigma):
        """Return (c_out, c_in) for the eps parameterisation."""
        c_in = 1 / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
        return -sigma, c_in

    def get_eps(self, *args, **kwargs):
        """Forward straight to the wrapped model."""
        return self.inner_model(*args, **kwargs)

    def loss(self, input, noise, sigma, **kwargs):
        """Return the per-sample mean squared error of the eps prediction."""
        _, c_in = [append_dims(s, input.ndim) for s in self.get_scalings(sigma)]
        noised = input + noise * append_dims(sigma, input.ndim)
        eps = self.get_eps(noised * c_in, self.sigma_to_t(sigma), **kwargs)
        return (eps - noise).pow(2).flatten(1).mean(1)

    def forward(self, input, sigma, **kwargs):
        """Return the denoised estimate input - sigma * eps."""
        c_out, c_in = (append_dims(s, input.ndim) for s in self.get_scalings(sigma))
        eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
        return input + eps * c_out



def make_ddim_timesteps(num_ddim_timesteps, num_ddpm_timesteps):
    """Return evenly strided DDIM timesteps, each shifted by +1, as a numpy array.

    The stride is num_ddpm_timesteps // num_ddim_timesteps. The +1 offset gets
    the final alpha values right (the ones from first scale to data during
    sampling).
    """
    stride = num_ddpm_timesteps // num_ddim_timesteps
    return np.asarray(range(0, num_ddpm_timesteps, stride)) + 1


def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta):
    """Compute DDIM sigmas plus the current and previous cumulative alphas.

    The formula follows Song et al., "Denoising Diffusion Implicit Models"
    (https://arxiv.org/abs/2010.02502). alphas_prev is alphacums[0] followed
    by the alphas of every selected step except the last.
    Returns (sigmas, alphas, alphas_prev).
    """
    alphas = alphacums[ddim_timesteps]
    prev_list = [alphacums[0]]
    prev_list.extend(alphacums[ddim_timesteps[:-1]].tolist())
    alphas_prev = np.asarray(prev_list)
    ratio = (1 - alphas_prev) / (1 - alphas)
    sigmas = eta * np.sqrt(ratio * (1 - alphas / alphas_prev))
    return sigmas, alphas, alphas_prev

def makerng():
  """Seed numpy and torch from the global `seed`, drawing a fresh seed if it is 0.

  When `seed` is 0, a random value in [1, 2**32 - 1] is chosen, written back
  to the global, and printed.
  """
  import random  # local import: `random` is not imported in this cell's header
  global seed
  if seed == 0:
    # random.randint includes its upper bound, and np.random.seed only accepts
    # values in [0, 2**32 - 1], so the old bound of 2**32 could raise
    # ValueError. 0 is excluded too, because it is the "draw a random seed"
    # sentinel.
    seed=random.randint(1, 2**32 - 1)
    print('random seed=')
    print(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)

def dlpromptexample():
  """Fetch the prompt-example assets via IPython shell escapes.

  Steps:
  1. Download pexmp.7z and extract it into the working directory.
  2. Download psg1.htm and svr.py into ./web/.
  """
  !wget https://github.com/TabuaTambalam/DalleWebms/releases/download/0.1/pexmp.7z
  !7z x pexmp.7z
  !wget -O web/psg1.htm https://raw.githubusercontent.com/TabuaTambalam/DalleWebms/main/docs/web/psg1.htm
  !wget -O web/svr.py https://raw.githubusercontent.com/TabuaTambalam/DalleWebms/main/docs/web/svr.py


def mkmodel_state_dict():
  try:
    import jkt
  except:
    !wget https://raw.githubusercontent.com/TabuaTambalam/DalleWebms/main/docs/sd/jkt.py
    import jkt
  
  difjit=[diffusion_emb,diffusion_mid,diffusion_out]
  model_state_dict = {}
  jna1=jkt.nam1
  for i in range(3):
    sd=difjit[i].state_dict()
    jna2=jkt.nam2[i]
    for k in sd:
      uwa=sd[k]
      if 'pnnx' in k:
        model_state_dict[jna2[k]]=uwa
      else:
        model_state_dict[jna1[k]]=uwa
  return model_state_dict


def procLat(lat):
  """Split a latent tensor into a list of single-item batches.

  Returns (tensors, count):
  - 3-D input: a batch dim is added, giving one tensor and count 1.
  - batch size > 1: one (1, ...) slice per batch element.
  - otherwise: the input itself, with count 1.
  """
  if lat.dim() == 3:
    return [lat.unsqueeze(0)], 1
  nbat = lat.size(0)
  if nbat <= 1:
    return [lat], 1
  return [lat[i].unsqueeze(0) for i in range(nbat)], nbat


SDlatDEC=None
def latdec(fna,scale=5.5):
  global SDlatDEC
  if SDlatDEC is None:
    if not os.path.isfile('autoencoder_pnnx.pt'):
      !wget https://huggingface.co/Larvik/sd470k/resolve/main/autoencoder_pnnx.pt
    SDlatDEC=torch.jit.load('autoencoder_pnnx.pt').cuda()
  lat,l =procLat(torch.tensor(np.load(fna)).cuda())
  for i in range(l):
    lat[i]=SDlatDEC(lat[i]*scale)[0]
  return lat

def latdec2(fna,scale=5.5):
  global SDlatDEC
  if SDlatDEC is None:
    if not os.path.isfile('autoencoder_pnnx.pt'):
      !wget https://huggingface.co/Larvik/sd470k/resolve/main/autoencoder_pnnx.pt
    SDlatDEC=torch.jit.load('autoencoder_pnnx.pt').cuda()
  lat,l =procLat(torch.tensor(fna).cuda())
  for i in range(l):
    lat[i]=SDlatDEC(lat[i]*scale)[0]
  return lat

def localhttp(root='/content/'):
  """Start a background `python3 -m http.server` on port 8333 that serves `root`.

  - The server's output goes to /content/sample_data/izh.txt. That file also
    acts as the "already started" marker, so the server is launched at most
    once.
  - When the server is started, IPython's HTML is imported into the module
    global `HTML`.
  NOTE(review): relies on `os`, which is not imported in this cell's header.
  """
  global HTML
  if not os.path.isfile('/content/sample_data/izh.txt'):
    from IPython.core.display import HTML
    !nohup python3 -m http.server -d {root} 8333 > /content/sample_data/izh.txt &


def f_sampler():
  """Point the global UseSamplr at the sampler named by the global `Sampler`.

  An unrecognised name leaves UseSamplr unchanged.
  """
  global UseSamplr
  by_name = {
      'euler': sample_euler,
      'euler_a': sample_euler_ancestral,
      'heun': sample_heun,
      'dpm_2': sample_dpm_2,
      'dpm_2_a': sample_dpm_2_ancestral,
      'lms': sample_lms,
  }
  if Sampler in by_name:
    UseSamplr = by_name[Sampler]

def f_sigmas():
  """Return the sampling sigma schedule multiplied by the global ddim_eta.

  - Karras truthy: a Karras schedule between model_wrap's first and last
    table sigmas, with rho=KarrasRho, on cudev.
  - otherwise: model_wrap.get_sigmas(ddim_num_steps).
  """
  if not Karras:
    return ddim_eta*model_wrap.get_sigmas(ddim_num_steps)
  sigma_min = model_wrap.sigmas[0].item()
  sigma_max = model_wrap.sigmas[-1].item()
  return ddim_eta*get_sigmas_karras(ddim_num_steps, sigma_min, sigma_max, rho=KarrasRho, device=cudev)

def fixver(ver,dfsver):
  """Return dfsver when ver is '470k'; for any other version return ''."""
  return dfsver if ver == '470k' else ''
def f_dljit(ver='470k',dfsver=''):
  """Download the TorchScript model files for version `ver`; return the diffusion dir.

  - dfsver is kept only when ver == '470k' (see fixver). It is appended to
    ver to form the directory name.
  - If imgencoder_pnnx.pt is missing: pip-install the runtime deps, then fetch
    alphas_cumprod.npz, the text transformer, the autoencoder and the image
    encoder (using the base `ver`).
  - If <ver>/diffusion_out_pnnx.pt is missing: fetch the three diffusion parts
    into <ver>/.
  Returns '<ver><dfsver>/'.
  NOTE(review): relies on `os`, which is not imported in this cell's header.
  """
  dfsver=fixver(ver,dfsver)
  if not os.path.isfile('imgencoder_pnnx.pt'):
    !pip install ftfy transformers omegaconf triton==2.0.0.dev20220701 einops accelerate
    !wget https://huggingface.co/Larvik/sd{ver}/resolve/main/alphas_cumprod.npz
    !wget https://huggingface.co/Larvik/tfmod/resolve/main/transformer_pnnx.pt
    !wget https://huggingface.co/Larvik/sd{ver}/resolve/main/autoencoder_pnnx.pt
    !wget https://huggingface.co/Larvik/sd{ver}/resolve/main/imgencoder_pnnx.pt
  ver+=dfsver
  # If the directory already exists, mkdir only prints an error; the shell escape does not raise.
  !mkdir {ver}
  if not os.path.isfile(ver+'/diffusion_out_pnnx.pt'):
    !wget -P {ver}/ https://huggingface.co/Larvik/sd{ver}/resolve/main/diffusion_emb_pnnx.pt
    !wget -P {ver}/ https://huggingface.co/Larvik/sd{ver}/resolve/main/diffusion_mid_pnnx.pt
    !wget -P {ver}/ https://huggingface.co/Larvik/sd{ver}/resolve/main/diffusion_out_pnnx.pt
  return ver+'/'

def install_xformer():
  """Install a prebuilt xformers wheel that matches the detected GPU.

  Does nothing if xformers/_C.so already exists. Otherwise it:
  1. Downloads jkt.py.
  2. Picks a wheel folder from `nvidia-smi` output: P100, V100 or A100,
     defaulting to T4.
  3. pip-installs that wheel.
  4. Moves the installed package to /content/xformers.
  NOTE(review): the move source hard-codes python3.7's dist-packages path.
  It also relies on `os`, which is not imported in this cell's header.
  """
  print('xformer')
  if not os.path.isfile('xformers/_C.so'):
    !wget https://raw.githubusercontent.com/TabuaTambalam/DalleWebms/main/docs/sd/jkt.py
    from subprocess import getoutput
    pfix='T4'
    gputyp=getoutput('nvidia-smi')
    if 'P100' in gputyp:
      pfix = 'P100'
    elif 'V100' in gputyp:
      pfix = 'V100'
    elif 'A100' in gputyp:
      pfix = 'A100'
    !pip install https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/{pfix}/xformers-0.0.13.dev0-py3-none-any.whl
    !mv /usr/local/lib/python3.7/dist-packages/xformers /content/xformers


def get_keys_to_submodule(model):
  """Map each parameter key to the submodule that directly owns the parameter.

  For every submodule, only parameters registered directly on it (local name
  without a dot) are recorded, under the key '<submodule_name>.<param_name>'.
  """
  mapping = {}
  for sub_name, sub in model.named_modules():
    for p_name, _ in sub.named_parameters():
      if '.' not in p_name:
        mapping[f"{sub_name}.{p_name}"] = sub
  return mapping

# Local path of the extra inpainting input weights (downloaded on demand by wgt_to_inp).
inpaintwgt='UserEmb/inpaintwgt.pt'
def wgt_to_inp(state_dict):
  """Widen state_dict['input_blocks.0.0.weight'] with the stored inpainting weights.

  The tensor loaded from `inpaintwgt` (downloaded if missing) is concatenated
  along dim 1, presumably the input-channel axis. state_dict is mutated and
  returned.
  NOTE(review): torch.load unpickles the downloaded file, so use only a
  trusted source. `wget -O` does not create UserEmb/, so presumably that
  directory must already exist. Relies on `os`, which is not imported in
  this cell's header.
  """
  if not os.path.isfile(inpaintwgt):
    !wget -O {inpaintwgt} https://huggingface.co/Larvik/tfmod/resolve/main/inpaintwgt.pt
  state_dict['input_blocks.0.0.weight']=torch.cat((state_dict['input_blocks.0.0.weight'],torch.load(inpaintwgt)),dim=1)
  return state_dict
# Backup of the original attention to_k / to_v weights, keyed '<prefix>.weight'
# (filled by add_kvbias2, read by mergeWnB).
kvwgtbak = {}
# Attention projection prefixes grouped by weight.size(1); each width has its own list.
wgtkeybysz = {width: [] for width in (320, 640, 768, 1280)}
# Projections (relative to a transformer-block prefix) that get an added bias.
atnnames = ['attn1.to_k', 'attn1.to_v', 'attn2.to_k', 'attn2.to_v']

# Multiplier applied to the adapter term in mergeWnB.
KVmerge_ratio = 1.0

def mergeWnB(subkey,dk,dback):
  """Fold a two-layer linear K/V adapter into one attention projection.

  Inputs:
  - W: the backed-up projection weight kvwgtbak[subkey + '.weight'].
  - dk: the adapter state dict (linear1.*, linear2.*), assuming nn.Linear
    conventions (y = x @ W.T + b).
  Writes fp16 `subkey + '.weight'` (W') and `subkey + '.bias'` (b') into
  dback so that

      x @ W'.T + b' == (x + k * linear2(linear1(x))) @ W.T

  where k = KVmerge_ratio (global). With k = 0 this reproduces the
  original projection.
  """
  a3=kvwgtbak[subkey+'.weight'].float().T
  k=KVmerge_ratio
  w1, b1 = dk['linear1.weight'], dk['linear1.bias']
  w2, b2 = dk['linear2.weight'], dk['linear2.bias']
  dback[subkey+'.weight']=( a3+ k*(( w1.T @ w2.T )@a3) ).half().T
  # The adapter's bias path must be scaled by k too. Previously only the weight
  # term was scaled, so any k != 1 produced an inconsistent bias.
  dback[subkey+'.bias']=( k*( b1.T @ (w2.T @ a3) + b2.T@a3 ) ).half().T


def mkPreKV_dict(pt):
  """Build merged to_k / to_v weights and biases from a KV-adapter checkpoint.

  `pt` is loaded with torch.load onto `cudev`. For every width key in
  wgtkeybysz it must hold a (toK, toV) pair of adapter dicts. Each
  registered prefix ending in '.to_k' is merged with toK and every other
  prefix with toV (see mergeWnB). Returns the dict of merged tensors.
  NOTE(review): torch.load unpickles `pt`, so load only trusted files.
  """
  merged = dict()
  adapters = torch.load(pt, map_location=cudev)
  for width, prefixes in wgtkeybysz.items():
    to_k_sd, to_v_sd = adapters[width]
    for prefix in prefixes:
      chosen = to_k_sd if prefix.endswith('.to_k') else to_v_sd
      mergeWnB(prefix, chosen, merged)
  return merged


def add_kvbias2(state_dict,k):
  """Give the four K/V projections under prefix `k` a zero bias.

  For each name in atnnames, this:
  - backs up the weight into kvwgtbak;
  - adds a float32 zero bias of length weight.size(0) to state_dict;
  - records the prefix in wgtkeybysz[weight.size(1)].
  """
  for name in atnnames[:4]:
    prefix = k + name
    weight_key = prefix + '.weight'
    weight = state_dict[weight_key]
    kvwgtbak[weight_key] = weight
    state_dict[prefix + '.bias'] = torch.zeros(weight.size(0))
    wgtkeybysz[weight.size(1)].append(prefix)

def add_kvbias(state_dict):
  """Add zero K/V biases at the listed UNet transformer-block prefixes.

  The prefixes cover input blocks 1, 2, 4, 5, 7, 8, the middle block, and
  output blocks 3-11. Mutates and returns state_dict.
  """
  prefixes = ['input_blocks.' + str(i) + '.1.transformer_blocks.0.' for i in (1, 2, 4, 5, 7, 8)]
  prefixes.append('middle_block.1.transformer_blocks.0.')
  prefixes += ['output_blocks.' + str(i) + '.1.transformer_blocks.0.' for i in range(3, 12)]
  for prefix in prefixes:
    add_kvbias2(state_dict, prefix)
  return state_dict
      


def load_state_dict_with_low_memory(model, state_dict,modifyfunc=None,fill=True):
  """Assign state_dict tensors directly as frozen Parameters on their owning submodules.

  - modifyfunc, if given, first transforms state_dict.
  - Keys missing from state_dict are skipped when fill=False.
  - Otherwise missing keys are printed and filled with fp16 ones of the
    model's shape.
  """
  if modifyfunc is not None:
    state_dict = modifyfunc(state_dict)
  print('======hacky load======')
  owners = get_keys_to_submodule(model)
  model_sd = model.state_dict()
  for key, owner in owners.items():
    if key in state_dict:
      value = state_dict[key]
    elif not fill:
      continue
    else:
      print(key)
      value = torch.ones(model_sd[key].shape, dtype= torch.float16)
    setattr(owner, key.split('.')[-1], torch.nn.Parameter(value, requires_grad=False))





# Name of the symlink that makes the selected code variant importable as `ldm`.
ldmbase='ldm'
def init_ldm(mode,sd_modify=-1):
  """Build the LDM UNetModel from the local `ldm` package and load its weights.

  If 'ldm_opt' is missing, ldms.7z is downloaded and extracted with IPython
  `!` shell commands (notebook-only syntax). The `ldm` symlink is then
  pointed at the variant chosen by `mode`: 1 -> 'ldm_opt', 2 -> 'ldm_xfm'.
  Any other mode removes the old link and creates none, so the import
  presumably fails — TODO confirm intended.

  Args:
    mode: code-variant selector (see above).
    sd_modify: 1 installs add_kvbias as the state-dict modifier and sets
      CrossAttention_config.kvbias = True.

  Returns:
    The UNetModel, with requires_grad disabled, in eval mode, moved to `cudev`.
  """
  print('orig ldm')
  if not os.path.exists('ldm_opt'):
    !wget https://github.com/TabuaTambalam/DalleWebms/releases/download/0.1/ldms.7z
    !7z x ldms.7z
  # NOTE(review): os.path.exists() is False for a dangling symlink, so a stale
  # `ldm` link would not be unlinked and os.symlink would then raise;
  # os.path.lexists() would cover that case.
  if os.path.exists(ldmbase):
    os.unlink(ldmbase)
  if mode==1:
    os.symlink('ldm_opt',ldmbase)
  elif mode==2:
    os.symlink('ldm_xfm',ldmbase)
  from ldm.modules.diffusionmodules.openaimodel import UNetModel

  # 8 input channels when the global INP is set (presumably inpainting), else 4.
  in_chn=4
  if INP:
    in_chn=8
  sdt_func=None
  if sd_modify == 1:
    sdt_func=add_kvbias
    from ldm.modules.attention import CrossAttention_config
    CrossAttention_config.kvbias=True


  # Built under init_empty_weights() (presumably accelerate's); parameters are
  # then installed by load_state_dict_with_low_memory.
  with init_empty_weights():
    ldm_unet = UNetModel(
        image_size=32,
        in_channels=in_chn,out_channels=4,
            model_channels=320,
            attention_resolutions=[4,2,1],
            num_res_blocks=2,
            channel_mult=[1,2,4,4],
            num_heads=8,
            use_spatial_transformer=True,
            context_dim=768,
            legacy= False).requires_grad_(False)
  load_state_dict_with_low_memory(ldm_unet,mkmodel_state_dict(),modifyfunc=sdt_func)
  ldm_unet=ldm_unet.eval().to(cudev)
  return ldm_unet

def clamp64(n):
  """Return ``n // 8`` rounded to the nearest multiple of 8 (ties round down).

  Example: 100 -> 12 -> 16, 520 -> 65 -> 64.
  """
  units = n >> 3
  remainder = units & 7
  rounded = units - remainder
  return rounded + 8 if remainder > 3 else rounded

def mk_shape():
  """Build the latent shape list used for sampling, from module globals.

  Base shape: [n_samples, 4, clamp64(H), clamp64(W)].

  If ``seed_size`` is non-empty:
    * ``dst = [seed_size[-1], clamp64(H), clamp64(W)]`` is computed first,
      i.e. before the override below;
    * seed_size[0] / seed_size[1] (through clamp64) replace the height and
      width, so at least two entries are required;
    * with more than 3 entries, ``seed_size[2:-1]`` is appended too, with
      the 2nd and 3rd entry of each complete group of three passed through
      clamp64;
    * ``dst`` is appended last.

  NOTE(review): the meaning of the extra entries is not visible here.
  Assumes seed_size is a list, since slices are item-assigned and concatenated
  with lists.
  """
  shape = [n_samples, 4, clamp64(H) , clamp64(W) ]
  nl=len(seed_size)
  if nl> 0:
    dst =[seed_size[-1]]+shape[2:]
    shape[2]=clamp64(seed_size[0])
    shape[3]=clamp64(seed_size[1])

    if nl > 3:
      ksd=seed_size[2:-1]
      nl_2=(nl-3)//3
      for n in range(nl_2):
        ksd[3*n+1]=clamp64(ksd[3*n+1])
        ksd[3*n+2]=clamp64(ksd[3*n+2])
      shape=shape+ksd+dst
    else:
      shape=shape+dst

  return shape

class Insertor:
  """A prompt variable loaded from ``UserEmb/<name>.txt``.

  Attributes:
    varias: list of variants, each a list of parsed units.
    ids: per variant, the set of non-zero unit ids (group markers).
    ll: number of variants.
    cplxlv: cached total complexity, -1 until first computed.
  """
  def __init__(self, string):
    self.rpla = string + '}'
    self.rpla_cut = len(string) + 1
    variants = mkInsertor_pstz(string)
    self.cplxlv = -1
    self.ids = [{unit.id for unit in variant if unit.id != 0} for variant in variants]
    self.varias = variants
    self.ll = len(variants)

  def cplxLevel(self,n):
    """Complexity of variant ``n`` (0x1000 per marker id), or the cached total if n < 0."""
    if n >= 0:
      markers = self.ids[n]
      return len(markers) * 0x1000 if markers else 0
    if self.cplxlv < 0:
      self.cplxlv = sum(self.cplxLevel(idx) for idx in range(self.ll))
    return self.cplxlv






# encoder
class BERTEmbedder:
    """Prompt encoder: a CLIP tokenizer plus a text transformer fed embeddings.

    The token-embedding table is taken from ``transformer``'s state dict, so
    prompts can be assembled at the embedding level (user embeddings,
    per-token weights) before being passed to ``transformer``.
    """

    def __init__(self, transformer, max_length=77):
        self.tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
        self.max_length = max_length
        # User embeddings / prompt-variable Insertors, keyed by their tag.
        self.dedup = dict()

        self.transformer = transformer
        self.embedding = torch.nn.Embedding.from_pretrained(self.transformer.state_dict()['text_model_embeddings_token_embedding.weight'])
        self.encode = self.encode0

        # Tokenizing '' (unpadded) yields [BOS, EOS]; cache their ids and embeddings.
        emptytok = self.tok('')
        self.tok_bos, self.tok_eos = int(emptytok[0]), int(emptytok[1])
        emptyemb = self.amb(emptytok)
        self.emb_bos, self.emb_eos = emptyemb[0].unsqueeze(0), emptyemb[1]

    def insert(self, inz):
        """Load ``UserEmb/<name>.bin`` (raw float32, 768 per row) for tag ``<name>``."""
        self.dedup[inz] = torch.tensor(np.fromfile('UserEmb/'+inz[1:-1]+'.bin', dtype=np.float32).reshape(-1, 768))

    def insert_prompt_vars(self, inz):
        """Register the prompt-variable file for ``inz`` under the key '{' + inz."""
        inz = '{'+inz
        self.dedup[inz] = Insertor(inz)

    def get_empty(self):
        """Embedding of the empty prompt: BOS followed by max_length-1 EOS rows."""
        return torch.cat([self.emb_bos, self.emb_eos.expand(self.max_length-1, -1)])

    def tok(self, text, pad=False):
        """Token ids of ``text`` (truncated to max_length; padded to it if ``pad``)."""
        padstr = 'do_not_pad'
        if pad:
            padstr = 'max_length'
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding=padstr, return_tensors='pt')

        return batch_encoding['input_ids'][0]

    def amb(self, tokens):
        """Look up token embeddings."""
        return self.embedding(tokens)

    def mk_emb_wgt(self, unit_arr, dtal=-1):
        """Concatenate per-unit token embeddings into one BOS ... EOS sequence.

        Args:
            unit_arr: units providing ``emb_wgt()`` -> (embeddings, weights)
                and a ``msg`` string.
            dtal: number of leading units to use (default: all).

        Returns:
            (emb, wgt, txt):
            - emb: ``max_length`` rows. The unit that overflows the budget is
              truncated and later units are dropped.
            - wgt: list of per-piece weight tensors, or None if every unit's
              first weight is 1.0.
            - txt: unit messages longer than one char joined with ' # ', or
              None if there are none.
        """
        if dtal < 0:
            dtal = len(unit_arr)

        emb = [None]*(dtal+2)
        wgt = [None]*(dtal+2)
        txt = []
        emb[0] = self.emb_bos
        wgt[0] = torch.ones(1)
        # Slots left after BOS; at least one must remain for EOS.
        count = self.max_length-1
        NoWgt = True
        for i in range(dtal):
            emb0, wgt0 = unit_arr[i].emb_wgt()
            # A unit may tokenize to zero tokens (e.g. an empty message), so
            # guard the probe of its first weight.
            if wgt0.numel() > 0 and wgt0[0] != 1.0:
                NoWgt = False
            msg = unit_arr[i].msg
            if len(msg) > 1:
                txt.append(msg)
            emb[i+1] = emb0
            wgt[i+1] = wgt0
            count -= wgt0.size(0)
            if count <= 0:
                # Trim this unit so exactly one slot is left for EOS, and drop
                # all remaining units.
                kcut = count-1
                emb[i+1] = emb0[:kcut]
                wgt[i+1] = wgt0[:kcut]
                emb = emb[:i+3]
                wgt = wgt[:i+3]
                count = 1
                print('ignore after: '+msg)
                break

        emb[-1] = self.emb_eos.expand(count, -1)
        if NoWgt:
            wgt = None
        else:
            wgt[-1] = wgt[0].expand(count)
        emb = torch.cat(emb)

        if txt:
            if len(txt) > 1:
                txt = ' # '.join(txt)
            else:
                txt = txt[0]
        else:
            txt = None
        return emb, wgt, txt

    def from_emb(self, emb0, wgt_arr=None, nsamp=1, cuda=True):
        """Run the transformer on one embedding sequence, optionally weighting tokens.

        The concatenated weights are divided by ``abs(mean)`` and multiplied
        into the per-token outputs. The first (BOS) output row is then
        restored to its unweighted value. The result is expanded to
        ``nsamp`` rows when nsamp > 1.
        """
        z = self.transformer(emb0.expand(1, -1, -1))
        if cuda:
            z = z.cuda()
        if wgt_arr is not None:
            wgt = torch.cat(wgt_arr)
            if cuda:
                wgt = wgt.cuda()
            # clone(): basic slicing returns a view, so without a copy the
            # in-place multiply below would also scale the saved BOS row and
            # the restore would be a no-op.
            ynt = z[:, 0, :].clone()
            wgt /= torch.abs(wgt.mean())
            z *= wgt.reshape(-1, 1).expand(1, -1, -1)
            z[:, 0, :] = ynt
        if nsamp > 1:
            z = z.expand(nsamp, -1, -1)
        return z

    def encode0(self, text, nsamp):
        """Encode a prompt string into a cond_getter (prompt variables disabled)."""
        if len(text) == 0:
            # Pass nsamp here as well, matching the non-empty path and pmpmtx.
            return cond_getter(None, nsamp=nsamp)

        units = pmpmtx_preproc([text], enable3d=False)[0]
        emb, wgt, txt = self.mk_emb_wgt(units, len(units))
        # NOTE(review): the per-token weights (wgt) are not forwarded, so this
        # path encodes the prompt unweighted. Kept as-is.
        return cond_getter(emb, fast=0, nsamp=nsamp)

    def encode2(self, text, nsamp):
        """Encode ``text`` padded to max_length, repeated ``nsamp`` times."""
        batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")

        return self.transformer(self.embedding(batch_encoding["input_ids"].expand(nsamp, -1)))

def loadKV_merge(pt):
  """Fold the PreKV checkpoint ``pt`` into the global ``ldm_unet``'s K/V weights."""
  merged_weights = mkPreKV_dict(pt)
  load_state_dict_with_low_memory(ldm_unet, merged_weights, modifyfunc=None, fill=False)


def loadKV(pt):
  """Load a PreKV checkpoint.

  If ``EnableKVmerges`` is set, the checkpoint is merged into the UNet
  weights. Otherwise the 768-sized (K, V) pair is loaded into
  preprocK/preprocV and both are enabled. ``pt`` goes through torch.load,
  which unpickles: load trusted files only.
  """
  if EnableKVmerges:
    return loadKV_merge(pt)
  wgt_k, wgt_v = torch.load(pt)[768]
  for module, weights in ((preprocK, wgt_k), (preprocV, wgt_v)):
    load_state_dict_with_low_memory(module, weights)
  preprocK.use = True
  preprocV.use = True


class CompVisDenoiser(DiscreteEpsDDPMDenoiser):
    """A wrapper for CompVis diffusion models.

    Eps prediction is delegated to the module-level ``apply_model``;
    ``device`` is accepted for interface compatibility and is unused.
    """

    def __init__(self, model, quantize=False, device='cpu'):
        schedule = model.alphas_cumprod
        super().__init__(model, schedule, quantize=quantize)

    def get_eps(self, *args, **kwargs):
        eps = apply_model(*args, **kwargs)
        return eps




def nDfmCodeBase():
  """Map the ``DfmCodeBase`` setting to a backend code (99 if unrecognized)."""
  for name, code in (('JIT', 0), ('ldm_SaveVram', 1), ('ldm_xformers', 2)):
    if DfmCodeBase == name:
      return code
  return 99


class PreKV(nn.Module):
    """Residual two-layer MLP: x + logic_multiplier * linear2(linear1(x)).

    ``heads`` is accepted but unused. ``use`` starts False and is switched
    on by callers.
    """
    logic_multiplier = 1.0
    def __init__(self, dim, heads=0):
        super().__init__()
        self.use = False
        hidden = dim * 2
        self.linear1 = torch.nn.Linear(dim, hidden)
        self.linear2 = torch.nn.Linear(hidden, dim)

    def forward(self, _x):
        residual = self.linear2(self.linear1(_x))
        return _x + (residual * PreKV.logic_multiplier)


class CFGDenoiser(nn.Module):
    """Classifier-free guidance: uncond + (cond - uncond) * cond_scale.

    Both branches run as one batch. Conditionings are fetched per step ``d``
    through their ``get(d)`` method.
    """
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, d, x, sigma, uncond, cond, cond_scale):
        batch_x = torch.cat((x, x))
        batch_sigma = torch.cat((sigma, sigma))
        batch_cond = torch.cat((uncond.get(d), cond.get(d)))
        out = self.inner_model(batch_x, batch_sigma, cond=batch_cond, d=d)
        out_uncond, out_cond = out.chunk(2)
        return out_uncond + (out_cond - out_uncond) * cond_scale


class SRDenoiser(nn.Module):
    """Thin module wrapper that forwards (x, sigma, cond) to the inner model."""
    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, cond ):
        return self.inner_model(x, sigma, cond=cond)

# Module-level state. The code that reads or updates these is not in this
# chunk; presumably they record the previously loaded SD version and
# DfmCodeBase so a reload can be skipped — TODO confirm.
prevSDver=''
prevDfmCodeBase=''

class CompVisJIT():
  # Minimal model stand-in exposing `alphas_cumprod` (module global converted
  # to a tensor on `cudev`) and the module-level `apply_model`; presumably
  # handed to a denoiser wrapper for the JIT backend — TODO confirm.
  def __init__(self):
    self.alphas_cumprod=torch.tensor(alphas_cumprod,device=cudev)
    self.apply_model=apply_model

class ifeeder:
  """Supplies a per-index input; by default the same stored value for every index.

  ``get_npbins`` reads ``self.pattern % (n+1)`` as float32 data. It relies on
  ``pattern``, ``shape`` and ``noiseadd`` being set by the caller.
  """
  def __init__(self):
    self.getn = self.get_simp

  def get_simp(self, n):
    return self.bs

  def setbs(self, in_bs):
    self.bs = in_bs

  def get_npbins(self, n):
    raw = np.fromfile(self.pattern % (n + 1), dtype=np.float32)
    tensor = torch.tensor(raw.reshape(self.shape), device=cudev)
    return tensor + self.noiseadd

# Module-level defaults; assigned and read elsewhere in the notebook (not in
# this chunk).
Karras=False
model_wrap=None

In [None]:
import copy
import random




class vinfo:
  """A prompt-variable reference such as '{name}' or '{name,X}'.

  The lookup key is the tag without its trailing '}' (or without ',X}' when
  the third-to-last character is ','). If the key is in
  cond_stage_model.dedup, the reference is valid and the referenced object
  is recorded in the global ActivedPromptVars under the full tag.
  NOTE(review): tags shorter than 3 characters would raise IndexError on
  txt[-3].
  """
  def __init__(self, txt):
    self.tag=txt
    valid=False
    cut=-1
    if txt[-3] == ',':
      cut=-3
    key=txt[:cut]
    if key in cond_stage_model.dedup:
      valid=True
      self.bazkey=key
      ActivedPromptVars[txt]=cond_stage_model.dedup[key]
    self.valid=valid



  @property
  def baz(self):
    """The referenced Insertor, looked up in cond_stage_model.dedup."""
    return cond_stage_model.dedup[self.bazkey]

  def repl(self, unit, v):
    """Return a deep copy of variant ``v``, adapted to the placeholder ``unit``.

    Units with id 0 inherit unit's weight/start/end wherever their own are
    None; weight defaults to 1.0 via nfo(). Units with a non-zero id (group
    markers) get fresh random ids, remapped consistently, including the ids
    stored as keys/values in their ``eraz`` dicts.
    """
    inzt=copy.deepcopy(self.baz.varias[v])
    idset=self.baz.ids[v]
    _, p_wgt, p_sta, p_end = unit.nfo()

    idmap=dict()
    procid=False
    if idset:
      procid=True
      for id in idset:
        idmap[id]=rdmIDfunc(None)

    if procid:
      for yn in inzt:
        if yn.id == 0:
          yn.update_sta_end(p_wgt,p_sta,p_end)
        else:
          yn.id=idmap[yn.id]
          erz=yn.eraz
          if erz:
            nyu_erz=dict()
            for k in erz:
              nyu_id=idmap[k]
              nyu_erz[nyu_id]=nyu_id
            yn.eraz=nyu_erz
    else:
      for yn in inzt:
        yn.update_sta_end(p_wgt,p_sta,p_end)

    return inzt





class sentUnit:
  """One piece of a parsed prompt with its weight and active step range.

  Attributes set here:
    msg: the text; '[' / ']' for group markers created with fast=1.
    wgt: weight (None = default). For fast=1 markers it starts at 0 and is
      later overwritten with the group length by trimgroup.
    upper / lower: start / end of the active range as fractions of t_enc
      (None = unbounded); converted to steps by get_realstaend.
    id: 0 for ordinary units; a random id (rdmIDfunc) for group markers.
    eraz: group-marker bookkeeping dict of linked group ids (consumed by
      trimgroup / mergedict); None for ordinary units.
    repls: prompt-variable references attached to this unit ({tag: vinfo}).

  Construction modes:
    fast=0: store ``txt`` and the p_* values verbatim.
    fast=1: create a group marker ('[' or ']').
    default: parse ``txt``. ';' sections (with ':' timing / '+' weight
      suffixes, see mktaps) become marker-wrapped groups, and '|' splits a
      section further. The flattened unit list is left in ``real_return``;
      read it with get_real_return.
  """
  def __init__(self, txt,
               fast=-1,
               p_wgt=None,p_sta=None,p_end=None
               ):

    self.emb_wgt = self.emb_wgt0
    self.id=0
    usig=0
    self.repls=dict()
    self.eraz=None
    if fast == 0:
      self.sig=usig
      self.msg, self.wgt, self.upper, self.lower =txt,p_wgt,p_sta,p_end
      return
    elif fast == 1:
      self.id=rdmIDfunc(txt) #id(self)#
      self.sig=usig
      self.msg, self.wgt, self.upper, self.lower =txt,0,p_sta,p_end
      # wgt is reused as the group length (written later by trimgroup).
      self.eraz=dict()
      self.yetproc=True
      return

    # Full parse. Bits of usig: 0x100 ':', 0x200 '+', 0x400 ';', 0x800 '|'.
    prepand=None
    retThis=True
    self.real_return=[]
    if ':' in txt:
      usig+=0x100
    if '+' in txt:
      usig+=0x200
    if ';' in txt:
      usig+=0x400
    if '|' in txt:
      usig+=0x800
    
    self.sig=usig

    if usig < 0x100:
      # No special syntax: the unit itself (re-parsed by emb_and_v below).
      self.msg, self.wgt, self.upper, self.lower = txt,p_wgt,p_sta,p_end
      self.real_return=[self]
    else:
      taps=mktaps(txt,  p_wgt=p_wgt,p_sta=p_sta,p_end=p_end)
      msg, self.wgt, self.upper, self.lower=taps[0]
      if '|' in msg:
        # First section contains '|': return its split pieces instead of self.
        self.msg='^'
        retThis=False
        taps2=mktaps(msg, sep='|', p_wgt=self.wgt,p_sta=self.upper,p_end=self.lower)
        for m2,w2,s2,e2 in taps2:
          dmm=sentUnit(m2,p_wgt=w2,p_sta=s2,p_end=e2)
          for dmm2 in get_real_return(dmm):
            self.real_return.append(dmm2)
      else:
        self.msg=msg
        self.real_return=[self]
      
      len_taps=len(taps)
      
      if len_taps > 1:
        # Several ';' sections: wrap each in '[' ... ']' markers. Every earlier
        # opening marker records the ids of all later groups in its eraz.
        prepand=sentUnit('[',fast=1, p_sta=self.upper,p_end=self.lower )
        lazt_id=prepand.id
        eraz_arr=[prepand]
        for i in range(1,len_taps):
          msg,wgt_bs,sta_bs,endo_bs = taps[i]
          edb=sentUnit(']',fast=1)
          edb.id=lazt_id
          pp2=sentUnit('[',fast=1,p_sta=sta_bs,p_end=endo_bs)
          lazt_id=pp2.id
          if '|' in msg:
            pral=[edb,pp2]
            taps2=mktaps(msg, sep='|', p_wgt=wgt_bs,p_sta=sta_bs,p_end=endo_bs)
            for m2,w2,s2,e2 in taps2:
              dmm=sentUnit(m2,p_wgt=w2,p_sta=s2,p_end=e2)
              for dmm2 in get_real_return(dmm):
                pral.append(dmm2) 
            self.real_return+=pral

          else:
            dmm=sentUnit(msg,p_wgt=wgt_bs, p_sta=sta_bs,p_end=endo_bs)
            dmm_rt=get_real_return(dmm)
            self.real_return+=[edb,pp2]+dmm_rt
          
          for erz in eraz_arr:
            erz.eraz[lazt_id]=lazt_id
          eraz_arr.append(pp2)
        edb=sentUnit(']',fast=1)
        edb.id=lazt_id
        self.real_return.append(edb)
            

    if retThis:
      # real_return[0] is self; replace it with the '<emb>' / '{var}'-aware
      # parse of self.msg.
      self.real_return=emb_and_v(self.msg,p_wgt=self.wgt,p_sta=self.upper,p_end=self.lower)+self.real_return[1:]

    if prepand is not None:
      self.real_return=[prepand]+self.real_return

  def nfo(self,trans_wgt=True,trans_sta=False,trans_end=False,extra=False):
    """Return (msg, wgt, start, end), optionally substituting defaults for None.

    trans_* replace None with 1.0 (weight), 0 (start) and 1.0 (end). With
    ``extra``, a list of is-None flags for (wgt, upper, lower) is appended.
    """
    ret_wgt=self.wgt
    ret_sta=self.upper
    ret_end=self.lower
    if trans_wgt and ret_wgt is None:
      ret_wgt=1.0
    if trans_sta and ret_sta is None:
      ret_sta=0
    if trans_end and ret_end is None:
      ret_end=1.0
    if extra:
      return self.msg, ret_wgt, ret_sta, ret_end,[self.wgt is None,self.upper is None,self.lower is None]
    else:
      return self.msg, ret_wgt, ret_sta, ret_end

  def get_sig(self):
    """Timing signature: +0x1000 if eraz is non-empty, +0x100 if a start/end is set."""
    ret=0
    if self.eraz:
      ret = 0x1000
    if self.upper is not None:
      return ret + 0x100
    if self.lower is not None:
      return ret + 0x100
    return ret

  def get_realstaend(self):
    """Active range in steps: round(upper*t_enc) .. round(lower*t_enc), default 0..t_enc."""
    sta=0
    endo=t_enc
    if self.upper is not None:
      sta=int(self.upper*t_enc +0.5)
    if self.lower is not None:
      endo=int(self.lower*t_enc +0.5)
    return sta, endo

  def get_realwgt(self):
    """Weight, with None read as 1.0."""
    if self.wgt is None:
      return 1.0
    return self.wgt


  def set_emb(self,n):
    """n == 1: use the precomputed user embedding cond_stage_model.dedup[msg]."""
    if n == 1:
      self.tok_len=1
      self.emb_wgt = self.emb_wgt1
      self.fast_emb= cond_stage_model.dedup[self.msg]
      self.fast_tkl=self.fast_emb.size(0)


  def emb_wgt1(self):
    """(embedding rows, per-row weights) for a user-embedding unit."""
    wgg=torch.ones(self.fast_tkl)
    if self.wgt is not None:
      wgg*=self.wgt

    return self.fast_emb, wgg
  def emb_wgt0(self):
    """(token embeddings, per-token weights) for msg, without BOS/EOS."""
    tok=cond_stage_model.tok(self.msg)[1:-1]
    tkl=tok.size(0)
    self.tok_len=tkl
    wgg=torch.ones(tkl)
    if self.wgt is not None:
      wgg*=self.wgt
    amb = cond_stage_model.embedding(tok)
    return amb,wgg

  def update_sta_end(self, wgt, sta, endo):
    """Fill weight/start/end only where they are still None."""
    if self.wgt is None:
      self.wgt=wgt
    if self.upper is None:
      self.upper=sta
    if self.lower is None:
      self.lower=endo

  def __repr__(self):
    ret=stringlizeNfo(self)
    if self.eraz:
      ret+='\n'+str(self.eraz)
    return ret


# arr=emb
class cond_getter:
  """Encoded prompt conditioning, either fixed or varying per sampling step.

  Simple mode (``is_simp``): ``arr`` is one conditioning tensor returned for
  every step. Per-step mode: ``arr`` is a list indexed by step (with its last
  entry duplicated), and ``knd`` holds the per-step kind indices.

  Args:
    arr: None (empty prompt), an embedding sequence to encode (fast=0), an
      already-encoded tensor (fast=1), or a per-step list (default).
    wgt_arr: per-token weights passed to ``from_emb`` when fast=0.
    reftxt: prompt text(s) reported by ``get_txt``.
    kndref: per-step kind indices for per-step mode.
  """
  def __init__(self, arr, wgt_arr=None, reftxt=None, kndref=None, fast=-1, nsamp=1,cuda=True):
    self.txt=[]
    if reftxt is not None:
      self.txt=reftxt

    self.notSave=False
    self.add_sta=0
    self.d_sta=0
    self.is_simp=True
    self.get=self.get_simp
    self.get_txt=self.txt_simp
    if arr is None:
      emb = cond_stage_model.get_empty()
      self.arr = cond_stage_model.from_emb(emb,nsamp=nsamp,cuda=cuda)
      return
    if fast==0:
      self.arr = cond_stage_model.from_emb(arr,wgt_arr=wgt_arr,nsamp=nsamp,cuda=cuda)
      return
    elif fast == 1:
      self.arr=arr
      return

    self.knd=kndref
    self.is_simp=False
    # Duplicate the last entry (presumably so one index past the last step
    # is also valid).
    arr.append(arr[-1])
    self.arr=arr
    self.get=self.get_arr
    self.get_txt=self.txt_arr

  def get_knd(self):
    """Per-step kind indices; all 0xff for a simple getter."""
    if self.is_simp:
      return np.ones(t_enc,dtype=np.uint8)*0xff
    return self.knd

  def get_fullarr(self):
    """Per-step list of conditionings (the simple tensor repeated t_enc times)."""
    if self.is_simp:
      return [self.arr]*t_enc
    return self.arr


  def get_simp(self,d):
    return self.arr

  def reset(self):
    """Clear the step offset and step scale used by get_arr."""
    self.add_sta=0
    self.d_sta=0


  def txt_simp(self,d):
    return self.txt

  def get_arr(self,d):
    """Entry for step d, offset by add_sta or rescaled by d_sta when d_sta > 1."""
    sd=d+self.add_sta
    if self.d_sta > 1:
      sd=int(0.5+d*self.d_sta)
    return self.arr[sd]

  def txt_arr(self,d):
    return self.txt[d]

  def save(self,pname='prmt'):
    """Save to ``<pname><first 20 chars of text>.compiled_prompt``.

    Returns None if saving is disabled or the getter is simple; otherwise
    the step indices where the kind changes (from kndmax_diff).
    """
    if self.notSave:
      return None
    sv=dict()
    if self.is_simp:
      sv[0]=True
      savarr=self.arr[0]
      dfarr=None
    else:
      sv[0]=False
      sv[2]=self.knd
      savarr, dfarr = kndmax_diff(self.knd)
      knd_l=len(savarr)
      for i in range(knd_l):
        savarr[i]=self.arr[ savarr[i] ][0]

    sv[1]=savarr
    # txt may be a string, an empty list (the default when no reftxt was
    # given), or a per-step list whose entries can be None. Only a string is
    # usable as a filename suffix; the old code raised TypeError/IndexError
    # for the other cases.
    label = self.get_txt(0) if self.txt else None
    if isinstance(label, str):
      pname+=label[:20]
    pname+='.compiled_prompt'
    torch.save(sv,pname)
    return dfarr



  def load(self,nsamp=1,cuda=True):
    """Replace this getter's contents with the file at ``self.arr[0]``.

    Uses torch.load (pickle-based): only load trusted files. Loaded getters
    are marked notSave.
    """
    sv=torch.load(self.arr[0],map_location=cudev)
    simp=sv[0]
    self.notSave=True
    
    self.get_txt=self.txt_simp
    self.txt='===secret==='
    if simp:
      self.get=self.get_simp
      self.arr=sv[1].expand(nsamp,-1,-1)
    else:
      self.knd=sv[2]
      self.get=self.get_arr
      karr=sv[1]
      knd_l=len(karr)
      for i in range(knd_l):
        z=karr[i]
        if cuda:
          z=z.cuda()
        else:
          z=z.cpu()
        karr[i]=z.expand(nsamp,-1,-1)

      knduse = resizeknd(self.knd)
      arr=[None]*t_enc
      for i in range(t_enc):
        arr[i]=karr[knduse[i]]
      arr.append(arr[-1])
      self.knd=knduse
      self.arr=arr
    self.is_simp=simp



In [None]:
def get_real_return(unit):
  """Take ``unit.real_return`` and delete the attribute from the unit."""
  result = unit.real_return
  delattr(unit, 'real_return')
  return result

def rdmIDfunc(yd):
  """Return a random non-zero id for a group marker.

  ``yd`` is ignored. 0 is the id of ordinary units (see sentUnit,
  vinfo.repl, Insertor and mergedict, which all test ``id == 0`` / ``!= 0``),
  so it is excluded; previously randint(0, ...) could return it.
  """
  return random.randint(1, 2**32)


def resizeknd(knd):
  """Resample a per-step kind array to ``t_enc`` steps.

  Returned unchanged if it already has t_enc entries. A longer array is
  subsampled. A shorter one is first stretched with ``knd.repeat`` (numpy
  element repetition) and then sampled, giving a list of t_enc entries.
  """
  n_src = len(knd)
  if n_src == t_enc:
    return knd
  if n_src > t_enc:
    step = n_src / t_enc
    source = knd
  else:
    copies = int(0.9999 + t_enc / n_src)
    source = knd.repeat(copies)
    step = n_src * copies / t_enc
  return [source[int(0.5 + d * step)] for d in range(t_enc)]


def kndmax_diff(knd):
  """Summarize a per-step kind array.

  Returns (rev, changes):
    rev[k] is the last step with kind k, for k = 0, 1, ... up to the first
      kind that never occurs.
    changes lists the steps where the kind differs from the previous step.
  """
  current = knd[0]
  changes = []
  last_step = dict()
  for step, kind in enumerate(knd):
    if kind != current:
      current = kind
      changes.append(step)
    last_step[kind] = step
  rev = []
  kind_id = 0
  while kind_id < len(knd) and kind_id in last_step:
    rev.append(last_step[kind_id])
    kind_id += 1
  return rev, changes


def stringlizeNfo(src):
  """Render a unit as 'msg[+wgt][:start%:[end%]]' (bounds shown as rounded percent)."""
  msg, wgt, sta, endo = src.nfo()
  if wgt != 1.0:
    msg = msg + '+' + str(wgt)

  def pct(v):
    return str(int(0.5 + v * 100))

  if sta is None and endo is None:
    return msg
  if endo is None:
    return msg + ':' + pct(sta) + ':'
  if sta is None:
    return msg + '::' + pct(endo)
  return msg + ':' + pct(sta) + ':' + pct(endo)




def mkInsertor_pstz(string):
  """Parse ``UserEmb/<name>.txt`` (name = string[1:]) into prompt variants.

  Lines are joined with '@', and an empty line ('@@') separates variants.
  Inside a variant, each remaining '@'-separated line is parsed with
  sentUnit and the units are concatenated. Variants starting with '#' are
  comments. Empty segments (e.g. from several trailing blank lines) are
  skipped; previously ``txt[0]`` raised IndexError on them.

  Returns:
    A list of variants, each a list of units.
  """
  fna = 'UserEmb/'+string[1:]+'.txt'
  with open(fna,'rt') as f:
    stz=f.read().splitlines()
  stz=('@'.join(stz)).replace('@@','^').split('^')
  stz_n=[]
  for txt in stz:
    if not txt or txt[0] == '#':
      continue
    if '@' in txt:
      arr=[]
      for s in txt.split('@'):
        arr+=get_real_return(sentUnit(s))
      stz_n.append(arr)
    else:
      stz_n.append( get_real_return(sentUnit(txt)) )
  return stz_n




def i2t(strr, ifempty=None):
  """Parse a time fraction; values above 1 are read as percent. Empty -> ifempty."""
  if not strr:
    return ifempty
  value = float(strr)
  return value / 100 if value > 1 else value

def m2mw(strr,prev,p_wgt):
  """Split 'text+weight' into (text, float weight).

  An empty text reuses ``prev``. Without a '+', ``(strr, p_wgt)`` is returned.
  """
  pieces = strr.split('+')
  if len(pieces) <= 1:
    return strr, p_wgt
  weight = float(pieces[1])
  return (pieces[0] or prev), weight


# Characters allowed in a trailing weight/timing suffix such as '+1.2:10:90'.
InfoChrs='1234567890+-:. '
def findposiblesplit(str_in):
  """Return the negative offset where a trailing InfoChrs-only suffix begins.

  The offset points at the last non-info character. It is 0 when the
  string ends in a non-info character, is empty, or consists only of info
  characters.
  """
  last = len(str_in) - 1
  for pos in reversed(range(len(str_in))):
    if str_in[pos] not in InfoChrs:
      return pos - last
  return 0

def mktaps(str,sep=';',p_wgt=None,p_sta=None,p_end=None):
  """Split ``str`` on ``sep`` into (msg, weight, start, end) tuples.

  Each segment may end in '+weight' and ':start:end' (see m2mw and i2t);
  missing values inherit p_wgt / p_sta / p_end. A segment that is only
  '+weight' reuses the previous segment's text. With the default ';'
  separator, a segment containing '|' keeps its text intact and only its
  trailing info suffix (see findposiblesplit) is parsed.
  Empty first/last segments are dropped, so '' and ';' return [] instead of
  raising IndexError as before.
  """
  Enbale_s_in_s = sep == ';'
  segs=str.split(sep)
  if segs and len(segs[-1]) == 0:
    segs=segs[:-1]
  if segs and len(segs[0]) == 0:
    segs=segs[1:]
  prevstr=''
  ret=[]
  for s in segs:
    sta=p_sta
    endo=p_end
    repl_msg=None
    info_s=s
    s_in_s=False
    if Enbale_s_in_s and '|' in s:
      # Parse only the info suffix; the '|' text is put back after m2mw.
      s_in_s=True
      idx=findposiblesplit(s)
      if idx == 0:
        info_s = 'dummy'
        repl_msg=s
      else:
        info_s = 'dummy'+s[idx:]
        repl_msg = s[:idx]

    msg=info_s.split(':')
    if len(msg) > 2:
      sta=i2t(msg[1],p_sta)
      endo=i2t(msg[2],p_end)
    msg, wgt=m2mw(msg[0],prevstr,p_wgt)
    if s_in_s:
      msg=repl_msg+msg[5:]

    msg=msg.strip()
    prevstr=msg
    ret.append((msg,wgt,sta,endo))
  return ret

def m2unit(data,dtal,mtx):
  """Return the first ``dtal`` entries of ``data`` whose mask value is not 0xff."""
  return [data[idx] for idx in range(dtal) if mtx[idx] != 0xff]





def chkrealexist(key,src_n):
  """True if ``src_n`` carries a prompt-variable replacement under ``key``."""
  replacements = src_n.repls
  return bool(replacements) and key in replacements

def vintzproc(src,k,v):
  """Expand every unit in ``src`` that references variable ``k`` with variant ``v``.

  Returns a new list in which each such unit is replaced by the units from
  ``unit.repls[k].repl(unit, v)``; other units are kept as-is.

  Previously the list was spliced while iterating a fixed index range. When
  an expansion produced several units, later references were shifted out of
  the range and left unexpanded; an empty expansion could overrun the list.
  """
  out = []
  for unit in src:
    repls = unit.repls
    if repls and k in repls:
      out.extend(repls[k].repl(unit, v))
    else:
      out.append(unit)
  return out

def recurflatten(seed,key_list):
  """Expand every unit list in ``seed`` over all variants of each active variable.

  Keys are consumed from the end of ``key_list``. Each step yields one deep
  copy per (seed, variant) pair, with that variable substituted.
  """
  remaining = key_list
  expanded = seed
  while True:
    key = remaining[-1]
    n_variants = ActivedPromptVars[key].ll
    expanded = [vintzproc(copy.deepcopy(base), key, v)
                for base in expanded
                for v in range(n_variants)]
    if len(remaining) <= 1:
      return expanded
    remaining = remaining[:-1]

def ActivedPromptVarsByCplx():
  """Active variable tags sorted by ascending complexity, then by tag."""
  tagged = ['%08X' % ActivedPromptVars[tag].cplxLevel(-1) + tag
            for tag in ActivedPromptVars]
  return [entry[8:] for entry in sorted(tagged)]


def proc3d(data):
  """Expand ``data`` over all active prompt variables; trimgroup each result."""
  order = ActivedPromptVarsByCplx()
  return [trimgroup(units) for units in recurflatten([data], order)]


def proc1d(data):
  """Single-combination counterpart of proc3d: one trimgroup'ed unit list."""
  cleaned = trimgroup(data)
  return [cleaned]


def trymakeemb(tag):
  """Ensure the user embedding for '<name>' is loaded; False if no file exists."""
  if tag not in cond_stage_model.dedup:
    if not os.path.isfile('UserEmb/'+tag[1:-1]+'.bin'):
      return False
    cond_stage_model.insert(tag)
  return True
    




def dfind_emb(txt,poz,l,p_wgt,p_sta,p_end):
  """Parse a '<name>' embedding reference whose '<' sits at ``txt[poz-1]``.

  Returns (unit, next_finder_index, next_position); next finder 0 is
  dfind_head. A known embedding yields an embedding unit. Otherwise the
  unit becomes plain text 'name' with weight -333, which emb_and_v treats as
  "merge as text".
  An unterminated '<' (no '>' before ``l``) now treats the rest as plain
  text; previously it returned None and the caller's unpacking failed.
  """
  close=txt.find('>',poz,l)
  if close < 0:
    rest=txt[poz:l].strip()
    if rest:
      unit = sentUnit(rest,fast=0,p_wgt=p_wgt,p_sta=p_sta,p_end=p_end)
      unit.wgt=-333
    else:
      # Nothing after '<': -666 makes emb_and_v drop the unit.
      unit = sentUnit('empty',fast=0,p_wgt=-666)
    return unit, 0, l
  i=close+1
  sig=txt[poz-1:i]
  unit = sentUnit(sig,fast=0,p_wgt=p_wgt,p_sta=p_sta,p_end=p_end)
  valid=trymakeemb(sig)
  if valid:
    unit.set_emb(1)
  else:
    unit.wgt=-333
    unit.msg=sig[1:-1]
  return unit, 0 ,i
    

def dfind_v(txt,poz,l,p_wgt,p_sta,p_end):
  """Parse a '{name}' prompt-variable reference whose '{' sits at ``txt[poz-1]``.

  Returns (unit, next_finder_index, next_position). A valid reference
  yields a '}' placeholder unit carrying the vinfo in ``repls``. Otherwise
  the unit becomes plain text 'name' with weight -333.
  An unterminated '{' now treats the rest as plain text; previously it
  returned None and the caller's unpacking failed.
  """
  close=txt.find('}',poz,l)
  if close < 0:
    rest=txt[poz:l].strip()
    if rest:
      unit = sentUnit(rest,fast=0,p_wgt=p_wgt,p_sta=p_sta,p_end=p_end)
      unit.wgt=-333
    else:
      # Nothing after '{': -666 makes emb_and_v drop the unit.
      unit = sentUnit('empty',fast=0,p_wgt=-666)
    return unit, 0, l
  i=close+1
  sig=txt[poz-1:i]
  unit = sentUnit('}',fast=0,p_wgt=p_wgt,p_sta=p_sta,p_end=p_end)
  dmm=vinfo(sig)
  if dmm.valid:
    unit.repls[sig]=dmm
  else:
    unit.wgt=-333
    unit.msg=sig[1:-1]
  return unit, 0 ,i

def dfind_v_dummy(txt,poz,l,p_wgt,p_sta,p_end):
  """Variables-disabled counterpart of dfind_v: '{name}' becomes plain text 'name'.

  The text unit gets weight -333 ("merge as text" in emb_and_v). An
  unterminated '{' now treats the rest as text instead of returning None.
  """
  close=txt.find('}',poz,l)
  if close < 0:
    rest=txt[poz:l].strip()
    if rest:
      unit = sentUnit(rest,fast=0,p_wgt=p_wgt,p_sta=p_sta,p_end=p_end)
      unit.wgt=-333
    else:
      # Nothing after '{': -666 makes emb_and_v drop the unit.
      unit = sentUnit('empty',fast=0,p_wgt=-666)
    return unit, 0, l
  unit = sentUnit(txt[poz:close],fast=0,p_wgt=p_wgt,p_sta=p_sta,p_end=p_end)
  unit.wgt=-333
  return unit, 0 ,close+1

def dfind_head(txt,poz,l,p_wgt,p_sta,p_end):
  """Scan plain text from ``poz`` up to the next '<' (finder 1) or '{' (finder 2).

  Returns (unit, next_finder_index, position after the delimiter). The text
  before the delimiter becomes a unit. If there is none, a weight -666
  placeholder is returned (dropped by emb_and_v). Without a delimiter, the
  stripped remainder is returned with weight -333 and finder 0.
  """
  for idx in range(poz, l):
    ch = txt[idx]
    if ch in '<{':
      next_finder = 1 if ch == '<' else 2
      if idx > poz:
        unit = sentUnit(txt[poz:idx].strip(),fast=0,p_wgt=p_wgt,p_sta=p_sta,p_end=p_end)
      else:
        unit = sentUnit('empty',fast=0,p_wgt=-666)
      return unit, next_finder, idx + 1

  tail = sentUnit(txt[poz:].strip(),fast=0,p_wgt=p_wgt,p_sta=p_sta,p_end=p_end)
  tail.wgt = -333
  return tail, 0, l


def canmerge(ret):
  """True if plain text may be appended to the last unit of ``ret``.

  The last unit's msg must be at least 2 characters and must not start
  with '<' (an embedding unit).
  """
  if not ret:
    return False
  last_msg = ret[-1].msg
  return len(last_msg) >= 2 and last_msg[0] != '<'

def emb_and_v(txt,p_wgt=None,p_sta=None,p_end=None,enable3d=True):
  """Split text into plain, '<embedding>' and '{variable}' units.

  The finders return sentinel weights: -333 means "plain text", merged into
  the previous unit when canmerge allows, otherwise kept with ``p_wgt``.
  -666 means "drop". With enable3d False, '{...}' is read as plain text.
  """
  finders = [dfind_head, dfind_emb, dfind_v if enable3d else dfind_v_dummy]
  total = len(txt)
  pos = 0
  finder = dfind_head
  units = []
  while pos < total:
    found, next_idx, pos = finder(txt, pos, total, p_wgt, p_sta, p_end)
    finder = finders[next_idx]
    if found.wgt == -333:
      if canmerge(units):
        units[-1].msg += ' ' + found.msg
      else:
        found.wgt = p_wgt
        units.append(found)
    elif found.wgt != -666:
      units.append(found)
  return units





def dumbunit(txt,wgt):
  """Parse ``txt`` with emb_and_v at weight ``wgt`` (1 is passed as None)."""
  effective = None if wgt == 1 else wgt
  return emb_and_v(txt, p_wgt=effective)



def filltimeinfo(arr,sta,endo,wgtfix):
  """Set start/end on every unit of ``arr`` and add ``wgtfix`` to its weight.

  A None weight (the default, meaning 1.0; see dumbunit) is read as 1.0.
  Previously None + float raised TypeError.

  Returns:
    1 if ``arr`` was non-empty, else 0.
  """
  if not arr:
    return 0
  for itm in arr:
    base = 1.0 if itm.wgt is None else itm.wgt
    itm.wgt = base + wgtfix
    itm.upper=sta
    itm.lower=endo
  return 1

def pp_edb(sta,endo):
  """Create a matching '[' / ']' marker pair; returns ([opener], [closer])."""
  opener = sentUnit('[',fast=1, p_sta=sta,p_end=endo )
  closer = sentUnit(']',fast=1)
  closer.id = opener.id
  return [opener], [closer]

def flattenretk(retk):
  """Combine the ':'-separated sections of one bracket into a unit list.

  retk[0] is the main section.
  - Two sections: the first unit of retk[0] takes float(retk[1][0].msg) as
    its weight.
  - Three or more: float(retk[2][0].msg) is a switch time t. retk[0] units
    get end=t and retk[1] units get start=t, and both get +0.1 weight. If
    both sections are non-empty, each is wrapped in '[' ... ']' markers and
    the second group's id is recorded in the first marker's eraz. Sections
    past the third are ignored.

  Returns retk[0] (extended in place), or a new list when wrapping applied.
  """
  ret=retk[0]
  retkl=len(retk)
  if retkl == 2:
    ret[0].wgt=float(retk[1][0].msg)
  elif retkl > 2:
    timeinfo=float(retk[2][0].msg)
    hazcot=0
    hazcot+=filltimeinfo(ret,None,timeinfo,0.1)
    hazcot+=filltimeinfo(retk[1],timeinfo,None,0.1)
    if hazcot > 1:
      pp, edb = pp_edb(None,timeinfo)
      ret=pp+ret+edb
      pp, edb = pp_edb(timeinfo,None)
      erz_id=pp[0].id
      ret[0].eraz[erz_id]=erz_id
      retk[1]=pp+retk[1]+edb

    ret+=retk[1]

  return ret



def parsedumbformat(txt,sta=0,l=-1,wgt=1,sqq=False):
  """Parse '(...)' / '[...]' emphasis syntax into prompt units.

  '(' raises and '[' lowers the weight of the enclosed text by 0.1 per
  level. Inside a bracket (sqq=True), ':' starts a new section; sections
  are combined by flattenretk, e.g. '(a:1.3)' sets a weight and
  '[a:b:0.5]' switches at a time fraction.

  Args:
    txt: prompt text.
    sta: index to start parsing at.
    l: end index (default: len(txt)).
    wgt: weight for plain text at this nesting level.
    sqq: True inside a bracket.

  Returns:
    At top level, the unit list. When a closing bracket ends this level, a
    tuple ``(index after the bracket, units)``.
  """
  cut0=sta
  if l < 0:
    l=len(txt)
  retk=[[]]
  ptidx=0


  i=sta
  while i < l:
    c=txt[i]
    i+=1
    if c == '(':
      if i-cut0>1:
        retk[ptidx]+=dumbunit(txt[cut0:i-1],wgt)
      cut0, ret = parsedumbformat(txt,i,l,wgt+0.1,sqq=True)
      i=cut0
      retk[ptidx]+=ret
    elif c == '[':
      if i-cut0>1:
        retk[ptidx]+= dumbunit(txt[cut0:i-1],wgt)
      cut0, ret = parsedumbformat(txt,i,l,wgt-0.1,sqq=True)
      i=cut0
      retk[ptidx]+=ret
    elif c == ')':
      if i-cut0>1:
        retk[ptidx]+= dumbunit(txt[cut0:i-1],wgt)
      return i,flattenretk(retk)
    elif c == ']':
      if i-cut0>1:
        retk[ptidx]+= dumbunit(txt[cut0:i-1],wgt)
      # The return must not depend on text preceding ']'. Previously it was
      # nested under the if, so e.g. '[(a)]' ran past its closing bracket.
      return i,flattenretk(retk)
    elif sqq and c == ':':
      if i-cut0>1:
        retk[ptidx]+= dumbunit(txt[cut0:i-1],wgt)
      cut0=i
      ptidx+=1
      retk.append([])


  retk=flattenretk(retk)
  if cut0<l:
    retk+= dumbunit(txt[cut0:],wgt)
  return retk




def pmpmtx(data_in,nsamp=1,cuda=True,fromtxt=True,enable3d=True):
  """Parse and encode prompt input into a list of cond_getter objects.

  Produces one getter per prompt-variable combination. An empty first
  entry gives a single empty-prompt getter.
  """
  if len(data_in[0]) == 0:
    return [cond_getter(None,nsamp=nsamp,cuda=cuda)]
  getters = []
  for units in pmpmtx_preproc(data_in,fromtxt=fromtxt,enable3d=enable3d):
    payload, fastmode, ref_txt, kndref = to_arr_for_getter(units,nsamp=nsamp,cuda=cuda)
    getters.append(cond_getter(payload,fast=fastmode,reftxt=ref_txt,kndref=kndref))
  return getters


def pmpmtx_preproc(data_in,fromtxt=True,enable3d=True):
  """Parse prompt input into a list of unit lists; resets ActivedPromptVars.

  With ``fromtxt``:
    - a single string uses parsedumbformat if it contains '((', otherwise
      emb_and_v;
    - several strings are parsed line by line with sentUnit, skipping
      blank lines and '#' comments. A blank line previously raised
      IndexError on d[0].
  Without ``fromtxt``, ``data_in`` is already a unit list.

  Returns:
    One trimgroup'ed unit list per variable combination when enable3d is
    set and variables are referenced, otherwise a single-element list.
  """
  global ActivedPromptVars
  ActivedPromptVars=dict()
  arr=[]

  if fromtxt:
    if len(data_in) == 1:
      if '((' in data_in[0]:
        arr=parsedumbformat(data_in[0])
      else:
        arr= emb_and_v(data_in[0], enable3d=enable3d)
    else:
      for d in data_in:
        if d and d[0] != '#':
          arr+=get_real_return(sentUnit(d))
  else:
    arr=data_in

  if enable3d and ActivedPromptVars:
    arr=proc3d(arr)
  else:
    ActivedPromptVars=dict()
    arr=proc1d(arr)
  return arr

  


def to_arr_for_getter(data,nsamp=1,cuda=True):
  """Encode a unit list into the payload for cond_getter.

  Builds a (units x t_enc) mask. Entries are 1 where a unit is active and
  0xFF where it is outside its start/end range or inside a range erased by
  its ``eraz`` entries. Each distinct per-step mask row selects a unit
  subset, which is encoded once.

  Returns:
    (arr, fastmode, txt, knd).
    - No unit has timing, or every step uses the same subset: arr is one
      encoded tensor, fastmode 1, knd None.
    - Otherwise: per-step lists of encodings and texts, fastmode -1, and the
      per-step kind array.
  NOTE(review): knd is uint8, so more than 256 distinct subsets would wrap.
  txtid=-1 and txtkole are unused.
  """
  dtal=len(data)
  cpy_ones=np.ones(t_enc,dtype=np.uint8)
  cpy_eraz=cpy_ones*0xff
  mtx=np.ones((dtal,t_enc),dtype=np.uint8)

  txtid=-1
  txtkole=[]
  notTime=True
  for i in range(dtal):
    dta_i=data[i]
    sig = dta_i.get_sig()
    if sig > 0xFF:
      # Timed unit: inactive everywhere except its own [start, end) steps.
      notTime=False
      mtx[i]*=0xFF

      
      sta0, end0 = dta_i.get_realstaend()
      
      mtx[i][sta0:end0]=cpy_ones[sta0:end0]

      if sig > 0xfff:
        # Mask out the step ranges of the groups listed in eraz.
        erazd=dta_i.eraz
        for k in erazd:
          sta1, end1=erazd[k].get_realstaend()
          mtx[i][sta1:end1]=cpy_eraz[sta1:end1]

  if notTime:
    emb, wgt, txt = cond_stage_model.mk_emb_wgt(data,dtal)
    arr = cond_stage_model.from_emb(emb,wgt_arr=wgt,nsamp=nsamp,cuda=cuda)
    return  arr, 1, txt, None #arr, fastmode, txt


  # Rows are now steps; assign a kind id to each distinct row.
  mtx=mtx.transpose((1,0))
  knd=np.ones(t_enc,dtype=np.uint8)
  ar2i=dict()
  i2txt=[]
  txtid=0
  for i in range(t_enc):
    sig=str(mtx[i].tobytes())[2:-1].replace('\\','')
    if sig in ar2i:
      i_sig=ar2i[sig]
    else:
      ar2i[sig]=txtid
      i2txt.append( m2unit(data,dtal,mtx[i]) )
      i_sig=txtid
      txtid+=1
    knd[i]=i_sig
  

  if knd.sum() == 0:
    # Every step uses kind 0: encode once.
    emb, wgt, txt = cond_stage_model.mk_emb_wgt(i2txt[0])
    arr = cond_stage_model.from_emb(emb,wgt_arr=wgt,nsamp=nsamp,cuda=cuda)
    return  arr, 1, txt, None
  
  knd_arr=[None]*t_enc
  knd_arr_txt=[None]*t_enc
  enc_l=len(i2txt)

  # Encode each distinct unit subset once; i2txt is reused for the encodings.
  txtk=[None]*enc_l
  for i in range(enc_l):
    emb, wgt, txt = cond_stage_model.mk_emb_wgt(i2txt[i])
    i2txt[i] = cond_stage_model.from_emb(emb,wgt_arr=wgt,nsamp=nsamp,cuda=cuda)
    txtk[i]=txt
  
  for i in range(t_enc):
    poo=knd[i]
    knd_arr[i]=i2txt[poo]
    knd_arr_txt[i]=txtk[poo]

  return  knd_arr, -1, knd_arr_txt, knd




def wgtfix0(wgt):
  """Compress large weights: w > 2 -> 1 + 0.1*w, w < -2 -> -1 + 0.1*w; None passes through."""
  if wgt is None:
    return None
  if wgt > 2:
    return 1+0.1*wgt
  if wgt < -2:
    return -1+0.1*wgt
  return wgt
  
  
def wgtfix(b):
  """Apply wgtfix0 to ``b.wgt`` in place and return ``b``."""
  fixed = wgtfix0(b.wgt)
  b.wgt = fixed
  return b


def trimdpth(dyp):
  """Flatten the per-depth marker-id sets, deepest level (9) first."""
  ordered = []
  for level in range(9, -1, -1):
    members = dyp[level]
    if members:
      ordered.extend(members)
  return ordered

def trimgroup(unit_arr):
  """Remove '[' / ']' group markers from a unit list and resolve their timing.

  Returns the non-marker units, each passed through wgtfix.

  For each group (deepest nesting level first):
    * the marker's wgt becomes (index of ']') - (index after '[') + 1;
    * each group referenced in the marker's eraz is widened, once, to the
      earliest start / latest end of the units it encloses;
    * those referenced groups are merged into the eraz of every ordinary
      unit inside this group (mergedict), and the marker's eraz is cleared.

  NOTE(review): nesting deeper than 9 levels raises IndexError, and a ']'
  whose '[' was not seen raises KeyError.
  """
  bdict=dict()
  stapoz=dict()
  

  ul=len(unit_arr)
  dyp=[]
  for i in range(10):
    dyp.append(set())
  depth=0
  clean_ret=[]
  for i in range(ul):
    b=unit_arr[i]
    bmsg=b.msg
    if len(bmsg) == 1:
      if bmsg == '[':
        depth+=1
        dyp[depth].add(b.id)
        bdict[b.id]=b
        stapoz[b.id]=[i+1,None]
      elif bmsg == ']':
        depth-=1
        stapoz[b.id][1]=i
    else:
      clean_ret.append(wgtfix(b))

  dyp=trimdpth(dyp)
  if len(dyp) == 0:
    return clean_ret
  
  for k in dyp:
    sta, endo =stapoz[k]
    bdict[k].wgt=endo-sta+1



  for k in dyp:
    sta, endo =stapoz[k]
    b=bdict[k]
    nfo=b.eraz
    isany=False

    for erzid in nfo:
      # Replace the stored id with the referenced opening marker.
      cur=bdict[erzid]
      b.eraz[erzid]=cur
      isany=True
      if cur.yetproc:
        # Widen the referenced group to cover all units it encloses.
        sta2, endo2 =stapoz[erzid]
        _,_,cur_osta, cur_oendo = cur.nfo(trans_sta=True,trans_end=True)
        for i in range(sta2,endo2):
          msg,_, cmp_osta, cmp_oendo = unit_arr[i].nfo(trans_sta=True,trans_end=True)
          if msg != ']':
            if cmp_osta < cur_osta:
              cur_osta=cmp_osta
            if cmp_oendo > cur_oendo:
              cur_oendo = cmp_oendo
        cur.upper=cur_osta
        cur.lower=cur_oendo
        cur.yetproc=False
        

    if isany:
      mergedict(unit_arr,b.eraz,sta,endo)
      b.eraz=None


   
  return clean_ret


def mergedict(unit_arr,b_eraz,sta,endo):
  """Merge ``b_eraz`` into the eraz of ordinary units (id 0) in unit_arr[sta:endo].

  A unit with an empty or missing eraz gets ``b_eraz`` itself, not a copy.
  """
  for idx in range(sta, endo):
    unit = unit_arr[idx]
    if unit.id != 0:
      continue
    if not unit.eraz:
      unit.eraz = b_eraz
    else:
      unit.eraz.update(b_eraz)

def tenzclamp(tenz,tolen=77):
  """Tile ``tenz`` along dim 0 as needed and cut it to exactly ``tolen`` rows."""
  copies = int(0.9999 + (tolen / tenz.size(0)))
  tiled = torch.cat([tenz] * copies)
  return tiled[:tolen]


def prmt_bin(binfna,nsamp=1,cuda=True):
  """Load precomputed float32 conditioning (rows of 768) as a single-getter list.

  If ``binfna`` contains '%', the existing files ``binfna % i`` for
  i in 0..77 are concatenated. The result is tiled/cut to 77 rows,
  expanded to ``nsamp`` and optionally moved to CUDA.
  """
  def _read(path):
    return torch.tensor( np.fromfile(path,dtype=np.float32) ).reshape((-1,768))

  if '%' in binfna:
    parts = [_read(binfna % idx) for idx in range(78) if os.path.isfile(binfna % idx)]
    tenz = tenzclamp(torch.cat(parts))
  else:
    tenz = tenzclamp(_read(binfna))

  tenz = tenz.expand(nsamp,-1,-1)
  if cuda:
    tenz = tenz.cuda()
  return [cond_getter(tenz,fast=1)]
  

def calcknd(knd_arr,ptxt):
  """Combine several prompts' per-step kinds into joint kinds.

  Each distinct per-step tuple of kinds (one per prompt) gets a joint id
  in first-seen order. Each ``ptxt[p]`` is replaced in place by a
  ``torch.stack`` of its conditioning at the first step of every joint kind.

  The raw row bytes are used as the dedup key. The previous
  ``str(bytes)[2:-1].replace('\\','')`` key could collide, e.g. kind 9
  (repr '\\t') and kind 116 ('t').

  Returns:
    uint8 array mapping each step to its joint kind id.
  """
  steps_by_prompt = np.stack(knd_arr).transpose((1,0))
  hgt, prmpl = steps_by_prompt.shape

  joint_id = dict()
  first_steps = []
  kndmap = np.ones(hgt,dtype=np.uint8)
  for i in range(hgt):
    key = steps_by_prompt[i].tobytes()
    if key not in joint_id:
      joint_id[key] = len(first_steps)
      first_steps.append(i)
    kndmap[i] = joint_id[key]

  for p in range(prmpl):
    getter = ptxt[p]
    ptxt[p] = torch.stack([getter.get(step) for step in first_steps])

  return kndmap

def kmapout(kndmap,calc_result):
  """Expand per-kind results to a per-step list: result[kndmap[step]] for each step."""
  return [calc_result[kind] for kind in kndmap]

def prmt_avg(ptxt,pwgt,prmpl):
  """Weighted sum of the first ``prmpl`` prompt conditionings.

  If every prompt is "simple" (``is_simp``), the step-0 tensors
  (``get(0)``) are summed with weights ``pwgt``. The result is returned as
  ``[cond_getter(sum, fast=1)]``.

  Otherwise the per-step kinds are deduplicated with ``calcknd``, which
  replaces the ``ptxt`` entries with stacked tensors in place. The stacked
  tensors are summed with weights ``pwgt``, the result is expanded back to
  per-step order with ``kmapout``, and it is returned as
  ``[cond_getter(per_step, kndref=kndmap)]``.
  """
  knd_arr=[]
  cplx=False
  for idx in range(prmpl):
    entry=ptxt[idx]
    if not entry.is_simp:
      cplx=True
    knd_arr.append(entry.get_knd())

  if not cplx:
    total=ptxt[0].get(0)*pwgt[0]
    for idx in range(1,prmpl):
      total+=(ptxt[idx].get(0)*pwgt[idx])
    return [cond_getter(total,fast=1)]

  kndmap=calcknd(knd_arr,ptxt)
  total=ptxt[0]*pwgt[0]
  for idx in range(1,prmpl):
    total+=(ptxt[idx]*pwgt[idx])
  per_step=kmapout(kndmap,total)
  return [cond_getter(per_step,kndref=kndmap)]


def prmt_dymc(stz,cuda):
  """Build a "dynamic" prompt that switches conditioning at scheduled steps.

  ``stz`` alternates prompt specs and relative durations:
  ``[prompt0, dur0, prompt1, dur1, ...]``. Each prompt is compiled with
  ``makeCs`` (depth 1, no 3d). The cumulative durations are rescaled to step
  indices in ``[0, t_enc]``. The running total starts at 1 — NOTE(review): the
  reason for that offset is not visible here.

  Prompt ``i`` (i >= 1) overrides the kind array and the full conditioning
  array from its start step onward. Its kind values are offset by
  ``0x100 * i`` so they stay distinct from the other prompts' kinds.
  NOTE(review): the array returned by prompt 0's ``get_fullarr()`` is written
  to in place; if that is not a copy, prompt 0's object is mutated.

  Returns:
    ``[cond_getter(full_arr, kndref=kinds)]``
  """
  n_prompts=len(stz)>>1
  conds=[]
  bounds=[0]
  running=1
  for idx in range(n_prompts):
    conds.append(makeCs(stz[2*idx],1,cuda=cuda,enable3d=False)[0])
    running+=float(stz[2*idx+1])
    bounds.append(running)

  # cumulative weights -> step index where each prompt takes over
  for idx in range(1,n_prompts+1):
    bounds[idx]=int(0.5+(bounds[idx]/running)*t_enc)

  kinds=conds[0].get_knd().astype(np.uint16)
  full_arr=conds[0].get_fullarr()
  for idx in range(1,n_prompts):
    start=bounds[idx]
    kinds[start:]=conds[idx].get_knd()[start:].astype(np.uint16)+0x100*idx
    full_arr[start:]=conds[idx].get_fullarr()[start:]

  return [cond_getter(full_arr,kndref=kinds)]


def prmt_intp_cplx(ptxt,pstp,knd_arr,prmpl):
  """Linear interpolation between "complex" (per-step) prompts.

  ``calcknd`` first replaces the ``ptxt`` entries with stacked per-signature
  tensors (in place). For each segment ``seg`` in ``range(prmpl)``,
  ``pstp[seg]`` conditionings are produced, blending linearly from
  ``ptxt[seg]`` (inclusive) toward ``ptxt[seg+1]`` (exclusive).

  If ``pstp[-1] > 1``, a final segment blends from ``ptxt[prmpl]`` back
  toward ``ptxt[0]``. Otherwise a single conditioning for ``ptxt[-1]`` is
  appended. Every item is ``cond_getter(per_step, kndref=kndmap)``.
  """
  kndmap=calcknd(knd_arr,ptxt)

  def _ramp(start,end,steps):
    # i=0 gives `start`; i=steps-1 is one step short of `end`
    return [cond_getter(kmapout(kndmap,(end*i+start*(steps-i))/steps),kndref=kndmap)
            for i in range(steps)]

  intpos=[]
  for seg in range(prmpl):
    intpos.extend(_ramp(ptxt[seg],ptxt[seg+1],pstp[seg]))

  tail=pstp[-1]
  if tail > 1:
    intpos.extend(_ramp(ptxt[prmpl],ptxt[0],tail))
  else:
    intpos.append(cond_getter(kmapout(kndmap,ptxt[-1]),kndref=kndmap))
  return intpos

def prmt_intp(stz,cuda):
  """Build an interpolated sequence of conditionings from an ``intp:`` prompt file.

  ``stz`` alternates prompt specs and step counts:
  ``[prompt0, n0, prompt1, n1, ...]``. Each prompt is compiled with ``makeCs``
  (depth 1, no 3d), and each count is used as ``n + 1`` steps.

  If any prompt is not simple, the work is delegated to ``prmt_intp_cplx``.
  Otherwise consecutive step-0 tensors are blended linearly, and each item is
  ``cond_getter(tensor, fast=1)``. The last prompt's count controls a
  wrap-around segment back to the first prompt. When that count is <= 1, the
  last compiled prompt object itself is appended instead.
  """
  n_prompts=len(stz)>>1
  ptxt=[]
  pstp=[]
  knd_arr=[]
  cplx=False
  for idx in range(n_prompts):
    ge=makeCs(stz[2*idx],1,cuda=cuda,enable3d=False)[0]
    knd_arr.append(ge.get_knd())
    if not ge.is_simp:
      cplx=True
    ptxt.append(ge)
    pstp.append(int(stz[2*idx+1])+1)
  segs=n_prompts-1

  if cplx:
    return prmt_intp_cplx(ptxt,pstp,knd_arr,segs)

  def _ramp(start,end,steps):
    return [cond_getter((end*i+start*(steps-i))/steps,fast=1) for i in range(steps)]

  intpos=[]
  for seg in range(segs):
    intpos.extend(_ramp(ptxt[seg].get(0),ptxt[seg+1].get(0),pstp[seg]))

  tail=pstp[-1]
  if tail > 1:
    intpos.extend(_ramp(ptxt[segs].get(0),ptxt[0].get(0),tail))
  else:
    intpos.append(ptxt[-1])
  return intpos

def printprompts(detailed=False):
  """Print the text of every compiled prompt in ``c_list``.

  Each prompt is tagged ``'PromptV<k> at step'``. When ``SaveCompiledPrompt``
  is set, ``c.save(tag)`` is called first, and its return value (if not None)
  lists the steps to print.

  Prompts with empty ``txt`` are skipped. With ``detailed=True`` and a
  non-None ``c.knd``, the text is also printed at every listed step. If
  ``save`` returned no step list, it is printed at every step where
  ``c.knd`` changes value.
  """
  for k,c in enumerate(c_list):
    tstr='PromptV'+str(k)+' at step'
    dfarr=c.save(tstr) if SaveCompiledPrompt else None
    if not c.txt:
      continue
    print(tstr+'0:')
    print(c.get_txt(0))
    if not detailed or c.knd is None:
      continue
    if dfarr is not None:
      for j in dfarr:
        print(tstr+str(j)+':')
        print(c.get_txt(j))
      continue
    knd=c.knd
    last=knd[0]
    for j,cur in enumerate(knd):
      if cur != last:
        last=cur
        print(tstr+str(j)+':')
        print(c.get_txt(j))

# Maximum nesting depth for .txt prompt files that reference other prompt
# files; makeCs reports an error beyond it (guards against circular references).
depthLimit=10

def txtErr(prmt0,msg):
  """Report a prompt-file problem and fall back to a plain prompt.

  Prints ``msg`` and the file's basename with its last four characters
  removed (presumably the ``.txt`` extension). Then returns
  ``pmpmtx([prmt0], ...)``, i.e. the original string is compiled as literal
  prompt text.
  """
  print(msg)
  shown=prmt0.rsplit('/',1)[-1][:-4]
  print('err prompt: '+shown)
  return pmpmtx([prmt0],nsamp=n_samples,enable3d=False)


def cmdtype(cmd0):
  """Classify the (whitespace-stripped) first line of a prompt .txt file.

  Returns 1 for ``intp:``, 2 for ``dymc:``, 10 for ``mad:``, 11 for ``avg:``,
  and 0 for anything else (a plain prompt list).
  """
  for prefix,code in (('intp:',1),('dymc:',2),('mad:',10),('avg:',11)):
    if cmd0.startswith(prefix):
      return code
  return 0

rtdir=''  # directory of the top-level .txt prompt; relative references resolve against it
def makeCs(prmt,depth=0,cuda=True,enable3d=True):
  """Compile a prompt specification into a list of conditionings.

  ``prmt`` may be:
    * a ``.txt`` file. Its first line selects the mode (see ``cmdtype``):
      plain prompt lines (0), ``intp:`` interpolation (1, top level only),
      ``dymc:`` dynamic switching (2), or ``mad:``/``avg:`` weighted sums
      (10/11). With ``avg:`` the weights are normalized to sum to 1.
      References are tried as given, then relative to the directory of the
      top-level file. Nesting is limited to ``depthLimit`` levels;
    * a ``.bin`` raw embedding file (``prmt_bin``);
    * a ``.compiled_prompt`` file, loaded through ``cond_getter.load``;
    * anything else: literal prompt text (``pmpmtx``).

  Problems with .txt files are reported through ``txtErr``.
  """
  global rtdir
  if prmt.endswith('.txt'):
    if depth > depthLimit:
      return txtErr(prmt,'Too many ref, probably circular reference.')
    if depth==0:
      # Directory part including the trailing '/'. rfind returns -1 when there
      # is no '/', which yields '' -- the same result as the old
      # try/rindex/bare-except, but without swallowing unrelated exceptions.
      rtdir=prmt[:prmt.rfind('/')+1]
    depth+=1
    if not os.path.isfile(prmt):
      prmt=rtdir+prmt
      if not os.path.isfile(prmt):
        return txtErr(prmt,'ref not found.')
    with open(prmt,'rt') as f:
      stz=f.read().splitlines()
    if not stz:
      # previously crashed with IndexError on stz[0]
      return txtErr(prmt,'empty prompt file.')
    cmd=stz[0].replace(' ','').replace('\t','').split('/')
    cmd0=cmdtype(cmd[0])
    if cmd0 == 0:
      return pmpmtx(stz,nsamp=n_samples,cuda=cuda,enable3d=enable3d)
    elif cmd0 == 1:
      if depth > 1:
        return txtErr(stz[1],'do not intp in ref')
      return prmt_intp(stz[1:],cuda=cuda)
    elif cmd0 == 2:
      return prmt_dymc(stz[1:],cuda=cuda)

    # mad: (10) weighted sum / avg: (11) weighted average.
    # Remaining lines alternate: prompt spec, weight.
    prmpl=(len(stz)-1)>>1
    stz=stz[1:]
    ptxt=[]
    pwgt=[]
    wgtsum=0
    for i in range(prmpl):
      ptxt.append(  makeCs(stz[2*i],depth, cuda=cuda,enable3d=False )[0]  )
      wgt=float(stz[2*i+1])
      wgtsum+=wgt
      pwgt.append(  wgt  )
    if cmd0 == 11:
      for i in range(prmpl):
        pwgt[i]=pwgt[i]/wgtsum

    return prmt_avg(ptxt,pwgt,prmpl)

  elif prmt.endswith('.bin'):
    return prmt_bin(prmt,nsamp=n_samples,cuda=cuda)
  elif prmt.endswith('.compiled_prompt'):
    kn=cond_getter([prmt],fast=1)
    kn.load(nsamp=n_samples,cuda=cuda)
    return [kn]
  else:
    return pmpmtx([prmt],nsamp=n_samples,cuda=cuda,enable3d=enable3d)

In [None]:
import os
# Download the example prompt files in a background thread if they are missing.
if not os.path.isfile('PromptFuncsExample/MultiPrompt_average.txt'):
  t3 = Thread(target = dlpromptexample)
  a3 = t3.start()  # NOTE(review): Thread.start() returns None, so a3 is always None
# If the web server script still exists under the name 'svr.py_one', rename it to svr.py.
if os.path.isfile('web/svr.py_one'):
  !mv web/svr.py_one web/svr.py
!rm /content/sample_data/izh.txt

# Super Resolution 4x<br>
Select ONE — and only one — of these tasks: Super Resolution or txt2img.

In [None]:
jit=False #@param {type:'boolean'}
# NOTE(review): `jit` is not read in this section; the models below are always loaded with torch.jit.load.

import os

# Download the noise schedule and the TorchScript models once (skipped when fsd_pnnx.pt already exists).
if not os.path.isfile('fsd_pnnx.pt'):
  !wget https://huggingface.co/Larvik/LDMjit/resolve/main/alphas_cumprod.npy
  !wget https://huggingface.co/Larvik/LDMjit/resolve/main/dm_pnnx.pt
  !wget https://huggingface.co/Larvik/LDMjit/resolve/main/fsd_pnnx.pt


import sys
import time

import numpy as np
import cv2
import functools
import torch

cudev=torch.device('cuda')  # device for the super-resolution tensors (used by predict)

# Cumulative alpha products of the diffusion schedule; fed to make_ddim_sampling_parameters in the inference cell.
alphas_cumprod = np.load('alphas_cumprod.npy')


# Inference only: disable autograd globally, use all CPU threads, and enable
# cuDNN autotuning, TF32 and reduced-precision fp16 reductions for speed.
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True


# ======================
# Image I/O and tiled-processing helpers
# ======================

def imread(filename, flags=cv2.IMREAD_COLOR):
    """Read an image file by loading its raw bytes with NumPy and decoding them with ``cv2.imdecode``.

    Prints a message and exits the interpreter (``sys.exit``) when the file
    does not exist.
    """
    if os.path.isfile(filename):
        raw = np.fromfile(filename, np.int8)
        return cv2.imdecode(raw, flags)
    print(f"File does not exist: {filename}")
    sys.exit()

def preprocessing_img(img):
    """Convert an OpenCV image to 4-channel BGRA.

    Grayscale inputs (2-D or single-channel) and 3-channel BGR inputs are
    converted; any other channel count (e.g. already BGRA) is returned as is.
    """
    if img.ndim < 3 or img.shape[2] == 1:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGRA)
    if img.shape[2] == 3:
        return cv2.cvtColor(img, cv2.COLOR_BGR2BGRA)
    return img


def load_image(image_path):
    """Load an image file (any bit depth / alpha) and return it as 4-channel BGRA.

    Raises:
        FileNotFoundError: if ``image_path`` is not an existing file. The old
            code only printed a message and then crashed with UnboundLocalError
            on the undefined ``img``.
        ValueError: if OpenCV cannot decode the file.
    """
    if not os.path.isfile(image_path):
        print(f'{image_path} not found.')
        raise FileNotFoundError(image_path)
    img = imread(image_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError(f'could not decode image: {image_path}')
    return preprocessing_img(img)



def meshgrid(h, w):
    """Return an (h, w, 2) int64 tensor whose entry [i, j] is (i, j)."""
    rows = torch.arange(0, h).view(h, 1, 1).expand(h, w, 1)
    cols = torch.arange(0, w).view(1, w, 1).expand(h, w, 1)
    return torch.cat([rows, cols], dim=-1)


def delta_border(h, w):
    """
    :param h: height
    :param w: width
    :return: (h, w) tensor of the normalized distance to the nearest image border,
      with min distance = 0 at the border and max distance = 0.5 at the image center
    """
    # pixel coordinates scaled to [0, 1] along each axis
    coords = meshgrid(h, w) / torch.tensor([h - 1, w - 1]).view(1, 1, 2)
    to_top_left = coords.min(dim=-1, keepdim=True)[0]
    to_bottom_right = (1 - coords).min(dim=-1, keepdim=True)[0]
    return torch.cat([to_top_left, to_bottom_right], dim=-1).min(dim=-1)[0]



def get_weighting(h, w, Ly, Lx, device):
  """Per-pixel blending weights for ``Ly * Lx`` overlapping ``h x w`` tiles.

  The border distance from ``delta_border`` is clipped to [0.01, 0.5] and
  flattened to shape (1, h*w, Ly*Lx), with the same map for every tile. The
  result is moved to ``device``.
  """
  clipped = torch.clip(delta_border(h, w), 0.01, 0.5)
  per_tile = clipped.view(1, h * w, 1).repeat(1, 1, Ly * Lx)
  return per_tile.to(device)

def get_fold_unfold(x, kernel_size, stride, uf=1, df=1):  # todo load once not every time
    """Build Unfold/Fold operators and overlap weights for tiled processing.

    :param x: tensor of shape (bs, c, h, w); only its shape, device and dtype are used
    :param kernel_size: (kh, kw) tile size in ``x``
    :param stride: (sh, sw) tile stride in ``x``
    :param uf: upscale factor of the stitched output (tiles become ``uf`` times larger)
    :param df: downscale factor of the stitched output (tiles become ``df`` times smaller)
    :return: ``(fold, unfold, normalization, weighting)``.
      ``unfold`` cuts ``x`` into ``Ly * Lx`` tiles. ``fold`` stitches tiles of
      the scaled size into an image of size (h*uf, w*uf) or (h//df, w//df).
      ``normalization`` has shape (1, 1, H, W) and holds the folded weight sum
      used to divide out the overlap. ``weighting`` has shape
      (1, 1, kh', kw', Ly * Lx).
    :raises NotImplementedError: unless (uf, df) is (1, 1), (>1, 1) or (1, >1).
    """
    _, _, h, w = x.shape

    # number of crops in image
    Ly = (h - kernel_size[0]) // stride[0] + 1
    Lx = (w - kernel_size[1]) // stride[1] + 1

    if uf == 1 and df == 1:
        def _scaled(v):
            return v
    elif uf > 1 and df == 1:
        def _scaled(v):
            return v * uf
    elif df > 1 and uf == 1:
        def _scaled(v):
            return v // df
    else:
        raise NotImplementedError

    unfold = torch.nn.Unfold(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)

    # Scaled tile size. Height and width are scaled independently; the old code
    # used kernel_size[0] for the fold width as well, which broke non-square tiles.
    kh, kw = _scaled(kernel_size[0]), _scaled(kernel_size[1])
    out_h, out_w = _scaled(h), _scaled(w)
    fold = torch.nn.Fold(output_size=(out_h, out_w), kernel_size=(kh, kw),
                         dilation=1, padding=0,
                         stride=(_scaled(stride[0]), _scaled(stride[1])))

    weighting = get_weighting(kh, kw, Ly, Lx, x.device).to(x.dtype)
    normalization = fold(weighting).view(1, 1, out_h, out_w)  # normalizes the overlap
    weighting = weighting.view((1, 1, kh, kw, Ly * Lx))

    return fold, unfold, normalization, weighting



def normalize_image(image, normalize_type='255'):
    """
    Normalize image
    Parameters
    ----------
    image: numpy array
        The image you want to normalize
    normalize_type: string
        Normalize type should be chosen from the type below.
        - '255': simply dividing by 255.0
        - '127.5': output range : -1 and 1
        - 'ImageNet': normalize by mean and std of ImageNet (first 3 channels)
        - 'None': no normalization
    Returns
    -------
    normalized_image: numpy array
    Raises
    ------
    ValueError
        If ``normalize_type`` is not one of the values above. The old code
        silently returned None in that case.
    """
    if normalize_type == 'None':
        return image
    if normalize_type == '255':
        return image / 255.0
    if normalize_type == '127.5':
        return image / 127.5 - 1.0
    if normalize_type == 'ImageNet':
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = image / 255.0  # new array, so the caller's input is not modified
        for i in range(3):
            image[:, :, i] = (image[:, :, i] - mean[i]) / std[i]
        return image
    raise ValueError(f'unknown normalize_type: {normalize_type!r}')



def preprocess(img):
    """Turn an (H, W, C) image with values in [0, 255] into a (1, C, H, W) float32 array in [-1, 1].

    Returns ``(None, c)``; the first element is always None.
    """
    _, _, _ = img.shape  # requires a 3-D (H, W, C) input, as before

    unit = normalize_image(img, normalize_type='255')  # [0, 255] -> [0, 1]
    signed = unit * 2 - 1                               # [0, 1] -> [-1, 1]
    chw = signed.transpose(2, 0, 1)                     # HWC -> CHW
    batched = chw[np.newaxis]
    return None, batched.astype(np.float32)


def postprocess(sample):
    """Convert a (3, H, W) RGB array in [-1, 1] to an (H, W, 3) uint8 BGR image (values are clipped first)."""
    pixels = (np.clip(sample, -1., 1.) + 1.) / 2. * 255
    hwc = np.transpose(pixels, (1, 2, 0))
    bgr = hwc[:, :, ::-1]  # RGB -> BGR
    return bgr.astype(np.uint8)


def decode_first_stage(z):
    """Decode latent ``z`` with ``first_stage_decode`` tile by tile.

    ``z`` is cut into overlapping 128x128 tiles (stride 64). Each tile is
    decoded separately, keeping element [0] of the decoder output. The decoded
    tiles are weighted by the border-distance mask, stitched into an output 4x
    larger (``uf=4``), and divided by the summed overlap weights.
    """
    tile = (128, 128)
    fold, unfold, normalization, weighting = get_fold_unfold(z, tile, (64, 64), uf=4)

    patches = unfold(z)  # (bn, nc * prod(tile), L)
    patches = patches.view((patches.shape[0], -1, tile[0], tile[1], patches.shape[-1]))  # (bn, nc, kh, kw, L)

    print('first_stage_decode...')

    decoded_tiles = [first_stage_decode(patches[:, :, :, :, i])[0] for i in range(patches.shape[-1])]

    stacked = torch.stack(decoded_tiles, dim=-1) * weighting  # (bn, nc, kh', kw', L)
    stacked = stacked.view((stacked.shape[0], -1, stacked.shape[-1]))  # (bn, nc * kh' * kw', L)
    return fold(stacked) / normalization  # normalization is (1, 1, H, W)




# ddpm
def apply_model(x, t, cond,d):
    """Run ``diffusion_model`` over overlapping 128x128 tiles of ``x`` and blend the results.

    Each tile of ``x`` is concatenated channel-wise with the matching tile of
    ``cond`` and passed, together with timestep ``t``, to ``diffusion_model``.
    Element [0] of every output is weighted by the border mask, folded back
    to ``x``'s spatial size, and normalized by the overlap weights.
    ``d`` is accepted but unused.
    """
    tile = (128, 128)
    fold, unfold, normalization, weighting = get_fold_unfold(x, tile, (64, 64))

    def _tiles(src):
        flat = unfold(src)  # (bn, nc * prod(tile), L)
        flat = flat.view((flat.shape[0], -1, tile[0], tile[1], flat.shape[-1]))  # (bn, nc, kh, kw, L)
        return [flat[:, :, :, :, i] for i in range(flat.shape[-1])]

    x_tiles = _tiles(x)
    c_tiles = _tiles(cond)

    # apply model by looping over tiles
    outputs = []
    for i, x_tile in enumerate(x_tiles):
        joined = torch.cat([x_tile, c_tiles[i]], dim=1)
        outputs.append(diffusion_model(joined, t)[0])

    blended = torch.stack(outputs, dim=-1) * weighting
    blended = blended.view((blended.shape[0], -1, blended.shape[-1]))  # (bn, nc * kh * kw, L)
    return fold(blended) / normalization

def warmup():
  """Run each TorchScript model twice on random inputs before real use.

  Presumably this lets JIT optimization and cuDNN autotuning (``benchmark``
  is enabled above) happen up front. The CUDA cache is emptied afterwards.
  The diffusion model gets a (1, 6, 128, 128) fp16 input plus a timestep
  tensor, run under fp16 autocast. The decoder gets a (1, 3, 128, 128) fp32
  input.
  """
  noisy = torch.rand(1,6,128,128, dtype=torch.float).half().cuda()
  steps = torch.randint(10, (1, ), dtype=torch.long).cuda()
  for _ in range(2):
    with torch.cuda.amp.autocast(dtype=torch.float16):
      diffusion_model(noisy,steps)

  latent = torch.rand(1,3,128,128, dtype=torch.float).cuda()
  for _ in range(2):
    first_stage_decode(latent)
  torch.cuda.empty_cache()


# Sampling function called by predict(). sample_euler_ancestral is defined
# elsewhere in the notebook. NOTE(review): presumably reassigned by f_sampler()
# in the sampler-settings cell — verify.
UseSamplr=sample_euler_ancestral
def predict(c):
    """Super-resolve one conditioning image ``c`` and return a uint8 BGR image.

    ``c`` is presumably the (1, 3, H, W) float32 array from ``preprocess``. It
    is moved to ``cudev``. Gaussian noise is scaled by ``sigmas[0] * df``, with
    ``df = detail_strength / (detail_strength - 1 + sigmas[0])`` (printed).
    ``UseSamplr`` then denoises it through ``model_wrap_cfg`` under fp16
    autocast, conditioned on ``c``. The resulting latent is decoded tile-wise
    by ``decode_first_stage`` and converted by ``postprocess``.
    """
    cond = torch.tensor(c, device=cudev)
    sigmas = f_sigmas()
    noise = torch.randn(cond.shape, dtype=torch.float, device=cudev)

    df = detail_strength / (detail_strength - 1 + float(sigmas[0]))
    print(df)
    with torch.cuda.amp.autocast(dtype=torch.float16):
        samples = UseSamplr(model_wrap_cfg, noise * sigmas[0] * df, sigmas,
                            extra_args={'cond': cond}, disable=False)

    decoded = decode_first_stage(samples)
    return postprocess(decoded[0].cpu().numpy())

# Load the TorchScript models only once per runtime
# (model_wrap is presumably initialized to None in an earlier cell — verify).
if model_wrap is None:
  # first-stage decoder stays fp32; the diffusion UNet is converted to fp16
  first_stage_decode=torch.jit.load('/content/fsd_pnnx.pt').eval().cuda()
  diffusion_model=torch.jit.load('/content/dm_pnnx.pt').eval().half().cuda()
  warmup()
  torch.cuda.empty_cache()
  # CompVisDenoiser / CompVisJIT / SRDenoiser are defined elsewhere in the notebook
  model_wrap = CompVisDenoiser(CompVisJIT())
  model_wrap_cfg = SRDenoiser(model_wrap)

In [None]:

image_path='/content/sample_data/10_0x0v1.png' #@param {type:'string'}

"""
ddim_timesteps
"""
ddim_eta = 0.75  #@param {type:'number'}
ddim_num_steps = 100  #@param {type:'number'}
ddpm_num_timesteps = 1000 #@param {type:'number'}

detail_strength=20000  #@param {type:'number'}

# NOTE(review): the DDIM values computed below are not referenced by predict(),
# which takes its schedule from f_sigmas(); confirm whether anything else uses them.
ddim_timesteps = make_ddim_timesteps(ddim_num_steps, ddpm_num_timesteps)

"""
ddim sampling parameters
"""

ddim_sigmas, ddim_alphas, ddim_alphas_prev = \
    make_ddim_sampling_parameters(
        alphacums=alphas_cumprod,
        ddim_timesteps=ddim_timesteps,
        eta=ddim_eta)

#ddim_sigmas=torch.tensor(ddim_sigmas.astype(np.float32),device=cudev)

ddim_sqrt_one_minus_alphas = np.sqrt(1. - ddim_alphas)


# inference
print('Start inference...')
if image_path.endswith('.npy'):
  # latdec (defined elsewhere) presumably decodes the latent file into the conditioning image
  c=latdec(image_path).detach()
else:
  img = load_image(image_path)
  img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
  img = img[:, :, ::-1]  # BGR -> RGB
  _, c = preprocess(img)

img = predict(c)

# Save next to the input, replacing its extension with '_4x.png'. splitext
# handles extensions of any length; the old image_path[:-4] mangled names
# ending in e.g. '.jpeg' or '.webp'.
savepath = os.path.splitext(image_path)[0]+'_4x.png'
if cv2.imwrite(savepath, img):
  print(f'saved at : {savepath}')
else:
  print(f'failed to save : {savepath}')



In [None]:
Sampler='euler_a' #@param ['euler', 'euler_a', 'heun','dpm_2','dpm_2_a','lms']
# NOTE(review): f_sampler() is defined elsewhere; presumably it selects the sampling function named by `Sampler`.
f_sampler()

# NOTE(review): Karras/KarrasRho are assigned after f_sampler() runs; presumably
# they are read later (e.g. when the sigma schedule is built) — confirm.
Karras=False #@param {type:'boolean'}
KarrasRho = 7.0 #@param {type:'number'}

Optional: SD latent decoder

In [None]:
latent='4x6_1x1v1.npy' #@param {type:'string'}
ext='.png' #@param ['.png', '.jpg']
# Decode the latent file and save one image per element of the result.
lat=latdec(latent)
k=0
for lla in lat:
  # (x + 1) * 127.5 maps [-1, 1] to [0, 255]; CHW -> HWC, clip, uint8
  ymg=Image.fromarray( (( ( lla +1)*127.5 ).cpu().numpy()).transpose(1,2,0).clip(0,255).astype(np.uint8) )
  # output name: the input stem with every 'x1v' replaced by 'x<k>v', plus `ext`
  ymg.save(latent[:-4].replace('x1v','x'+str(k)+'v')+ext)
  k+=1
ymg  # display the last decoded image in the notebook

Optional: GFPGAN (TorchScript JIT)

In [None]:
import cv2
import glob
import numpy as np
import os
import torch
from torch import nn
import math


from torchvision.transforms.functional import normalize
from itertools import product




def imwrite(img, file_path, params=None, auto_mkdir=True):
    """Write ``img`` to ``file_path`` with ``cv2.imwrite``.

    Args:
        img: image array accepted by cv2.imwrite.
        file_path: destination path.
        params: optional OpenCV encoding parameters.
        auto_mkdir: create the parent directory first when True.

    Raises:
        IOError: if OpenCV reports that the write failed.
    """
    if auto_mkdir:
        parent = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(parent, exist_ok=True)
    if not cv2.imwrite(file_path, img, params):
        raise IOError('Failed in writing images.')



def bb_intersection_over_union(boxA, boxB):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes.

    Coordinates are treated as inclusive pixel indices, so each side length is
    ``x2 - x1 + 1``.
    """
    ix1, iy1 = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    ix2, iy2 = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])
    inter = max(0, ix2 - ix1 + 1) * max(0, iy2 - iy1 + 1)

    def _area(box):
        return (box[2] - box[0] + 1) * (box[3] - box[1] + 1)

    return inter / float(_area(boxA) + _area(boxB) - inter)


def nms_boxes(boxes, scores, iou_thres):
    """Greedy non-maximum suppression based on box IoU.

    Boxes are visited in order. When a box overlaps an earlier surviving box
    with IoU >= ``iou_thres``, the one with the higher score survives (ties
    keep the earlier box).

    Returns:
        Array of indices of the surviving boxes, in ascending order.
    """
    keep = []
    for i, box_a in enumerate(boxes):
        survives = True
        for j in range(i):
            if not keep[j]:
                continue
            if bb_intersection_over_union(box_a, boxes[j]) < iou_thres:
                continue
            if scores[i] > scores[j]:
                keep[j] = False
            else:
                survives = False
                break
        keep.append(survives)

    return np.array(keep).nonzero()[0]





def get_anchor(image_size):
    """Generate the prior (anchor) boxes used by ``detect_faces``.

    ``image_size`` is (height, width). Each of the stride-8/16/32 feature maps
    has cells of ceil(size / stride). Every cell gets one anchor per min size
    (16/32, 64/128, 256/512). Anchors are rows (cx, cy, w, h), normalized by
    the image width/height, in feature-map, row, column, size order.
    """
    min_sizes = [[16, 32], [64, 128], [256, 512]]
    steps = [8, 16, 32]
    img_h, img_w = image_size[0], image_size[1]

    anchors = []
    for step, sizes in zip(steps, min_sizes):
        n_rows, n_cols = math.ceil(img_h / step), math.ceil(img_w / step)
        for i, j in product(range(n_rows), range(n_cols)):
            cx = (j + 0.5) * step / img_w
            cy = (i + 0.5) * step / img_h
            for size in sizes:
                anchors.extend([cx, cy, size / img_w, size / img_h])

    return np.array(anchors).reshape(-1, 4)


# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        loc (ndarray): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (ndarray): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions as (x1, y1, x2, y2), Shape: [num_priors,4]
    """
    centers = priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:]
    sizes = priors[:, 2:] * np.exp(loc[:, 2:] * variances[1])
    top_left = centers - sizes / 2
    return np.concatenate((top_left, top_left + sizes), axis=1)


def decode_landm(pre, priors, variances):
    """Decode landm from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        pre (ndarray): landm predictions for loc layers,
            Shape: [num_priors,10]
        priors (ndarray): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded landm predictions, five (x, y) points per prior, Shape: [num_priors,10]
    """
    points = [
        priors[:, :2] + pre[:, k:k + 2] * variances[0] * priors[:, 2:]
        for k in range(0, 10, 2)
    ]
    return np.concatenate(points, axis=1)



def detect_faces(
        image,
        conf_threshold=0.8,
        nms_threshold=0.4,
        use_origin_size=True,
    ):
    """Detect faces in an (H, W, 3) image with the ``RetinaFace`` model.

    The image is presumably BGR (the subtracted means 104/117/123 are the
    usual BGR means) — confirm.

    Args:
        image: (H, W, 3) numpy image.
        conf_threshold: detections with score <= this are dropped.
        nms_threshold: IoU threshold for non-maximum suppression.
        use_origin_size: accepted for compatibility; unused.

    Returns:
        (N, 15) float array with columns x1, y1, x2, y2, score, followed by
        five (x, y) landmarks, all in input pixel coordinates. Rows are sorted
        by descending score after NMS.
    """
    height, width = image.shape[:2]
    image = image.transpose(2, 0, 1).astype(np.float32)
    image = torch.from_numpy(image).to(cudevg).unsqueeze(0)

    # Subtract per-channel means. The mean tensor is created on the image's
    # device: a non-scalar CPU tensor cannot be combined with a CUDA tensor.
    mean = torch.tensor([104., 117., 123.], dtype=image.dtype, device=image.device).view(1, 3, 1, 1)
    image = image - mean

    loc, conf, landmarks = RetinaFace(image)
    priors = get_anchor((height, width))

    variance = [0.1, 0.2]
    scale = np.array([width, height, width, height])
    scale1 = np.tile(scale[:2], 5)  # (width, height) for each of the 5 landmarks

    boxes = decode(loc[0].cpu().numpy(), priors, variance)
    boxes = boxes * scale

    scores = conf[0][:, 1].cpu().numpy()

    landmarks = decode_landm(landmarks[0].cpu().numpy(), priors, variance)
    landmarks = landmarks * scale1

    # ignore low scores
    inds = np.where(scores > conf_threshold)[0]
    boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]

    # sort by descending score
    order = scores.argsort()[::-1]
    boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]

    # do NMS
    bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
    keep = nms_boxes(bounding_boxes[:, :4], bounding_boxes[:, 4], nms_threshold)
    bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
    return np.concatenate((bounding_boxes, landmarks), axis=1)

def get_largest_face(det_faces, h, w):
    """Return the detection with the largest box area, and its index.

    Box coordinates are clamped to [0, w] horizontally and [0, h] vertically
    before the area is computed. Ties go to the earliest face.
    """
    def _clamp(val, length):
        if val < 0:
            return 0
        return length if val > length else val

    areas = [
        (_clamp(face[2], w) - _clamp(face[0], w)) * (_clamp(face[3], h) - _clamp(face[1], h))
        for face in det_faces
    ]
    idx = areas.index(max(areas))
    return det_faces[idx], idx


def get_center_face(det_faces, h=0, w=0, center=None):
    """Return the detection whose box center is nearest ``center``, and its index.

    ``center`` defaults to the image center (w / 2, h / 2). Ties go to the
    earliest face.
    """
    target = np.array(center) if center is not None else np.array([w / 2, h / 2])
    dists = [
        np.linalg.norm(np.array([(face[0] + face[2]) / 2, (face[1] + face[3]) / 2]) - target)
        for face in det_faces
    ]
    idx = dists.index(min(dists))
    return det_faces[idx], idx






def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert HWC numpy image(s) to CHW torch tensor(s).

    Args:
        imgs: a single array or a list of arrays.
        bgr2rgb: for 3-channel images, convert BGR -> RGB with OpenCV
            (float64 input is cast to float32 first).
        float32: cast the result to float32.

    Returns:
        A tensor, or a list of tensors when ``imgs`` is a list.
    """
    def _one(img):
        if img.shape[2] == 3 and bgr2rgb:
            if img.dtype == 'float64':
                img = img.astype('float32')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(img.transpose(2, 0, 1))
        return tensor.float() if float32 else tensor

    if isinstance(imgs, list):
        return [_one(img) for img in imgs]
    return _one(imgs)



def read_image(img):
    """Normalize an input image to 3-channel BGR.

    ``img`` can be an image path or a cv2-loaded numpy array (h, w[, c]). The
    docstring always promised path support, but paths used to crash in
    ``np.max``. Paths are now read with ``cv2.imread(..., IMREAD_UNCHANGED)``.

    16-bit images (max > 256) are rescaled to floats in [0, 255]. Grayscale
    is converted to BGR, and a 4th (alpha) channel is dropped.

    Raises:
        FileNotFoundError: if a path is given that OpenCV cannot read.
    """
    if isinstance(img, (str, os.PathLike)):
        path = os.fspath(img)
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise FileNotFoundError(f'could not read image: {path}')

    if np.max(img) > 256:  # 16-bit image
        img = (img / 65535) * 255
    if len(img.shape) == 2:  # gray image
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif img.shape[2] == 4:  # RGBA image with alpha channel
        img = img[:, :, 0:3]

    return img
'''
def srproc(img,fac):
  return cv2.resize(img, None,fx=fac,fy=fac, interpolation=cv2.INTER_LINEAR)
'''
def srproc(img,fac):
  """Super-resolve a BGR image with the diffusion SR pipeline (``preprocess`` + ``predict``).

  ``fac`` is ignored. The SR pipeline always scales by its own fixed factor
  (``decode_first_stage`` folds with ``uf=4``).
  """
  rgb = img[:, :, ::-1]
  cond = preprocess(rgb)[1]
  return predict(cond)

class faceimg:
  def __init__(self, image,
                 face_size=512,
                 crop_ratio=(1, 1),
                 save_ext='png',
                 template_3points=False,
                 pad_blur=False,
                 use_parse=False,
                 device=None):
    """Load an image and set up the face-alignment template.

    Args:
        image: image array (or path) passed to ``read_image``; the result is
            kept as ``self.nXimage``. Detection runs on ``self.input_img``,
            which is ``nXimage`` shrunk by the global ``upscale`` factor
            (INTER_AREA).
        face_size: side of the square aligned face (the template is defined for 512).
        crop_ratio: (h, w) enlargement of the crop around the square face; both must be >= 1.
        save_ext: file extension used when saving crops.
        template_3points: align on 3 landmarks instead of 5. Forced to False
            when ``pad_blur`` is True, because the padding step uses all 5
            points (it reads the mouth landmarks 3 and 4).
        pad_blur: reflect-pad and blur around each face before warping.
        use_parse: stored as ``self.use_parse``.
        device: accepted for compatibility; unused here.
    """
    self.nXimage=read_image(image)
    downscale=1/upscale
    self.input_img=cv2.resize(self.nXimage,None,fx=downscale,fy=downscale,interpolation=cv2.INTER_AREA)
    self.upscale_factor = upscale
    self.pad_blur = pad_blur
    # Resolve the pad_blur override BEFORE choosing the template. The old code
    # cleared the flag only after a 3-point template had been built, so the 5
    # extracted landmarks were later matched against a 3-point template.
    self.template_3points = False if self.pad_blur is True else template_3points
    # the cropped face ratio based on the square face
    self.crop_ratio = crop_ratio  # (h, w)
    assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
    self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))

    if self.template_3points:
        self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
    else:
        # standard 5 landmarks for FFHQ faces with 512 x 512
        self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
                                        [201.26117, 371.41043], [313.08905, 371.15118]])
    self.face_template = self.face_template * (face_size / 512.0)
    if self.crop_ratio[0] > 1:
        self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
    if self.crop_ratio[1] > 1:
        self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
    self.save_ext = save_ext

    self.all_landmarks_5 = []
    self.det_faces = []
    self.affine_matrices = []
    self.inverse_affine_matrices = []
    self.cropped_faces = []
    self.pad_input_imgs = []
    self.restored_faces=[]

    # init face parsing model
    self.use_parse = use_parse
  def get_face_landmarks_5(self,
              only_keep_largest=False,
              only_center_face=False,
              resize=None,
              blur_ratio=0.01,
              eye_dist_threshold=None):
    """Detect faces in ``self.input_img`` and collect their landmarks.

    Each kept face appends its landmarks (3 or 5 (x, y) points, depending on
    ``self.template_3points``) to ``self.all_landmarks_5``, and its
    box + score row (``bbox[0:5]``) to ``self.det_faces``. Optionally only
    the largest or the most central face is kept. When ``self.pad_blur`` is
    set, ``self.pad_input_imgs`` also receives one reflect-padded,
    border-blurred copy of the input per face, and that face's landmarks are
    shifted into the padded frame (in place).

    Args:
      only_keep_largest: keep only the face with the largest box.
      only_center_face: keep only the face nearest the image center.
      resize: if given, detect on a copy whose shorter side is ``resize``;
        detections are scaled back to ``input_img`` coordinates.
      blur_ratio: blur kernel size as a fraction of the face quad size.
      eye_dist_threshold: skip faces whose eye distance (pixels) is below this.

    Returns:
      Number of faces kept (0 if none were detected).
    """
    if resize is None:
        scale = 1
        input_img = self.input_img
    else:
        h, w = self.input_img.shape[0:2]
        scale = min(h, w) / resize
        h, w = int(h / scale), int(w / scale)
        input_img = cv2.resize(self.input_img, (w, h), interpolation=cv2.INTER_LANCZOS4)

    with torch.no_grad():
        # NOTE: the whole row is scaled, including the score column (index 4)
        bboxes = detect_faces( input_img ) * scale #0.97
    for bbox in bboxes:
        # remove faces with too small eye distance: side faces or too small faces.
        # Columns 5..14 hold five (x, y) landmarks; the first two points are the
        # eyes, i.e. (bbox[5], bbox[6]) and (bbox[7], bbox[8]). The old formula
        # (bbox[6]-bbox[8], bbox[7]-bbox[9]) mixed x and y coordinates.
        eye_dist = np.linalg.norm([bbox[5] - bbox[7], bbox[6] - bbox[8]])
        if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
            continue

        if self.template_3points:
            landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
        else:
            landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
        self.all_landmarks_5.append(landmark)
        self.det_faces.append(bbox[0:5])
    if len(self.det_faces) == 0:
        return 0
    if only_keep_largest:
        h, w, _ = self.input_img.shape
        self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
        self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
    elif only_center_face:
        h, w, _ = self.input_img.shape
        self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
        self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]

    # pad blurry images
    if self.pad_blur:
        self.pad_input_imgs = []
        for landmarks in self.all_landmarks_5:
            # get landmarks
            eye_left = landmarks[0, :]
            eye_right = landmarks[1, :]
            eye_avg = (eye_left + eye_right) * 0.5
            mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
            eye_to_eye = eye_right - eye_left
            eye_to_mouth = mouth_avg - eye_avg

            # Get the oriented crop rectangle
            # x: half width of the oriented crop rectangle
            x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
            #  - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
            # norm with the hypotenuse: get the direction
            x /= np.hypot(*x)  # get the hypotenuse of a right triangle
            rect_scale = 1.5
            x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
            # y: half height of the oriented crop rectangle
            y = np.flipud(x) * [-1, 1]

            # c: center
            c = eye_avg + eye_to_mouth * 0.1
            # quad: (left_top, left_bottom, right_bottom, right_top)
            quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
            # qsize: side length of the square
            qsize = np.hypot(*x) * 2
            border = max(int(np.rint(qsize * 0.1)), 3)

            # get pad
            # pad: (width_left, height_top, width_right, height_bottom)
            pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
                    int(np.ceil(max(quad[:, 1]))))
            # The right edge (x) is compared with the image width (shape[1]) and
            # the bottom edge (y) with the height (shape[0]); the old code had
            # these two swapped.
            pad = [
                max(-pad[0] + border, 1),
                max(-pad[1] + border, 1),
                max(pad[2] - self.input_img.shape[1] + border, 1),
                max(pad[3] - self.input_img.shape[0] + border, 1)
            ]

            if max(pad) > 1:
                # pad image
                pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
                # modify landmark coords
                landmarks[:, 0] += pad[0]
                landmarks[:, 1] += pad[1]
                # blur pad images
                h, w, _ = pad_img.shape
                y, x, _ = np.ogrid[:h, :w, :1]
                mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
                                                    np.float32(w - 1 - x) / pad[2]),
                                  1.0 - np.minimum(np.float32(y) / pad[1],
                                                    np.float32(h - 1 - y) / pad[3]))
                blur = int(qsize * blur_ratio)
                if blur % 2 == 0:
                    blur += 1
                blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))

                pad_img = pad_img.astype('float32')
                pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
                pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
                pad_img = np.clip(pad_img, 0, 255)  # float32, [0, 255]
                self.pad_input_imgs.append(pad_img)
            else:
                self.pad_input_imgs.append(np.copy(self.input_img))

    return len(self.all_landmarks_5)
  def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
    """Align and warp every detected face onto the face template.

    For each 5-point landmark set an affine matrix to ``self.face_template``
    is estimated (stored in ``self.affine_matrices``) and used to warp the
    source image (the padded/blurred copy when ``self.pad_blur``) into a
    ``self.face_size`` crop, appended to ``self.cropped_faces``.

    Args:
        save_cropped_path: optional path; crops are written as
            ``<stem>_<idx:02d>.<self.save_ext>``.
        border_mode: 'constant', 'reflect101' or 'reflect'; any other value
            is passed to ``cv2.warpAffine`` unchanged.
    """
    if self.pad_blur:
        assert len(self.pad_input_imgs) == len(
            self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
    # Translate the border name once instead of on every loop iteration.
    cv_border = {
        'constant': cv2.BORDER_CONSTANT,
        'reflect101': cv2.BORDER_REFLECT101,
        'reflect': cv2.BORDER_REFLECT,
    }.get(border_mode, border_mode)
    for idx, landmark in enumerate(self.all_landmarks_5):
        # LMEDS keeps the estimate equivalent to the skimage transform
        # ref: https://blog.csdn.net/yichxi/article/details/115827338
        fwd = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
        self.affine_matrices.append(fwd)
        src = self.pad_input_imgs[idx] if self.pad_blur else self.input_img
        crop = cv2.warpAffine(
            src, fwd, self.face_size, borderMode=cv_border, borderValue=(135, 133, 132))  # gray
        self.cropped_faces.append(crop)
        if save_cropped_path is None:
            continue
        stem = os.path.splitext(save_cropped_path)[0]
        imwrite(crop, f'{stem}_{idx:02d}.{self.save_ext}')
  def add_restored_face(self, face):
    """Record one restored face (in the same order as the crops)."""
    pool = self.restored_faces
    pool.append(face)
  def get_inverse_affine(self, save_inverse_affine_path=None):
    """Invert each face affine matrix for pasting back.

    Only the translation column is scaled by ``self.upscale_factor``. The
    linear part is left as-is because ``paste_faces_to_input_image`` later
    either scales it or super-resolves the face itself. Results are appended
    to ``self.inverse_affine_matrices``. When a path is given, each matrix is
    also saved with ``torch.save`` as ``<stem>_<idx:02d>.pth``.
    """
    for idx, fwd in enumerate(self.affine_matrices):
        inv = cv2.invertAffineTransform(fwd)
        inv[:, 2] *= self.upscale_factor
        self.inverse_affine_matrices.append(inv)
        if save_inverse_affine_path is None:
            continue
        stem = os.path.splitext(save_inverse_affine_path)[0]
        torch.save(inv, f'{stem}_{idx:02d}.pth')
  def paste_faces_to_input_image(self, save_path=None):
    """Build an upscaled layer of all restored faces plus a soft alpha mask.

    Each restored face is warped back into the upscaled frame
    (``upscale_factor`` times the input size) with its inverse affine matrix.
    A soft mask is derived from the ``face_parse`` network's labels. Faces
    are accumulated into ``restorepool`` and masks into ``maskpool``. The
    background is NOT composited here: ``self.nXimage`` is only inspected
    for its value range.

    Args:
        save_path: optional path; when given, the result is written with
            ``imwrite`` using ``self.save_ext`` as the extension.

    Returns:
        ``restorepool`` concatenated with the mask as a 4th channel. It is
        uint16 with the mask scaled by 65535 when ``np.max(self.nXimage) > 256``,
        and uint8 with the mask scaled by 255 otherwise.

    NOTE(review): ``inverse_affine`` is modified in place below, so the
    matrices stored in ``self.inverse_affine_matrices`` change; a second call
    would apply the offsets again.
    NOTE(review): with no restored faces, ``maskpool``/``restorepool`` stay
    None and the final ``np.concatenate`` fails.
    """
    h, w, _ = self.input_img.shape
    # size of the upscaled output frame
    h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)

    # only used below for its value range (8- vs 16-bit); rebound to the
    # face layer before returning
    upsample_img = self.nXimage

    assert len(self.restored_faces) == len(
        self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
    # maskpool: accumulated soft masks; restorepool: accumulated pasted faces
    maskpool=None
    restorepool=None
    for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):

        # get_inverse_affine scaled only the translation. If the scaled [0][0]
        # coefficient is small (the face is not enlarged much), scale the whole
        # matrix and warp the face as-is. Otherwise upscale the face itself with
        # srproc by upscale_factor and keep the unscaled linear part.
        if (inverse_affine[0][0]*self.upscale_factor) < 1.5:
          inverse_affine[:, 2]/= self.upscale_factor
          inverse_affine*=self.upscale_factor
          restored_face=restored_face.astype('uint8')
        else:
          restored_face=srproc(restored_face,self.upscale_factor).astype('uint8')


        # extra translation offset of 0.5*upscale_factor when upscaling
        # NOTE(review): origin of this constant is not visible here
        if self.upscale_factor > 1:
            extra_offset = 0.5 * self.upscale_factor
        else:
            extra_offset = 0
        inverse_affine[:, 2] += extra_offset

        inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))


        # inference: face parsing on a 512x512 RGB copy normalised to [-1, 1]
        face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
        face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
        normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
        face_input = torch.unsqueeze(face_input, 0).to(cudevg)
        with torch.no_grad():
            out = face_parse(face_input)[0]
        out = out.argmax(dim=1).squeeze().cpu().numpy()

        # per-label mask value: 255 keeps the label, 0 drops it
        # NOTE(review): which facial parts labels 0/14/16/17/18 are is defined
        # by the face_parse model -- confirm
        mask = np.zeros(out.shape)
        MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
        for idx, color in enumerate(MASK_COLORMAP):
            mask[out == idx] = color
        #  blur the mask
        mask = cv2.GaussianBlur(mask, (101, 101), 11)
        mask = cv2.GaussianBlur(mask, (101, 101), 11)
        # remove the black borders
        thres = 10
        mask[:thres, :] = 0
        mask[-thres:, :] = 0
        mask[:, :thres] = 0
        mask[:, -thres:] = 0
        mask = mask / 255.

        # NOTE(review): cv2.resize expects (width, height) but shape[:2] is
        # (height, width); equivalent only for square faces
        mask = cv2.resize(mask, restored_face.shape[:2])
        # flags=3 is cv2.INTER_AREA
        mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up), flags=3)
        inv_soft_mask = mask[:, :, None]
        pasted_face = inv_restored

        if maskpool is None:
          maskpool=inv_soft_mask
          restorepool=np.zeros(inv_soft_mask.shape)
          blanc=np.ones(inv_soft_mask.shape)
        else:
          # "over"-style union of this face's soft mask with the previous ones
          maskpool = inv_soft_mask*blanc+(1 - inv_soft_mask)*maskpool

        # every non-zero mask pixel takes this face; later faces overwrite earlier ones
        inv_hard_mask=np.array(inv_soft_mask, copy=True)
        inv_hard_mask[np.where(inv_hard_mask!=0)]=1.0
        restorepool = inv_hard_mask * pasted_face + (1 - inv_hard_mask) * restorepool

    # face layer + accumulated mask as a 4th channel, at the input's bit depth
    if np.max(upsample_img) > 256:  # 16-bit image
        upsample_img = np.concatenate((restorepool, maskpool*65535), axis=2).astype(np.uint16)
    else:
        upsample_img = np.concatenate((restorepool, maskpool*255), axis=2).astype(np.uint8)
    if save_path is not None:
        path = os.path.splitext(save_path)[0]
        save_path = f'{path}.{self.save_ext}'
        imwrite(upsample_img, save_path)
    return upsample_img


def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert a tensor (or list of tensors) into numpy image array(s).

    Each tensor is squeezed on dim 0, clamped to ``min_max`` and rescaled to
    [0, 1]. Then, by rank:
      * 4-D: tiled into a grid with ``make_grid`` and converted to HWC;
      * 3-D: CHW -> HWC, with single-channel images squeezed to HW;
      * 2-D: used as-is.
    Colour images are converted RGB -> BGR when ``rgb2bgr`` is set. For
    ``np.uint8`` output, values are multiplied by 255 and rounded before the cast.

    Returns:
        A single array for a single result, otherwise a list of arrays.

    Raises:
        TypeError: for non-tensor input or unsupported tensor rank.
    """
    single = torch.is_tensor(tensor)
    tensor_list = isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor)
    if not (single or tensor_list):
        raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')

    lo, span = min_max[0], min_max[1] - min_max[0]
    images = []
    for item in ([tensor] if single else tensor):
        item = item.squeeze(0).float().detach().cpu().clamp_(*min_max)
        item = (item - lo) / span

        ndim = item.dim()
        if ndim == 4:
            grid = make_grid(item, nrow=int(math.sqrt(item.size(0))), normalize=False)
            arr = grid.numpy().transpose(1, 2, 0)
            if rgb2bgr:
                arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
        elif ndim == 3:
            arr = item.numpy().transpose(1, 2, 0)
            if arr.shape[2] == 1:  # gray image
                arr = np.squeeze(arr, axis=2)
            elif rgb2bgr:
                arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
        elif ndim == 2:
            arr = item.numpy()
        else:
            raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {ndim}')

        if out_type == np.uint8:
            # numpy's uint8 cast truncates, so round explicitly first
            arr = (arr * 255.0).round()
        images.append(arr.astype(out_type))

    return images[0] if len(images) == 1 else images

def doenh_gfp(cropped_face_t):
  """Run the GFPGAN encoder/decoder on a normalised face batch.

  The TorchScript encoder and decoder are loaded lazily on first use and
  moved to ``cudevg``. The encoder output ``(latent, conditions)`` is fed to
  the decoder, whose output is returned unchanged.
  """
  global gfpgan_enc, gfpgan_dec
  if gfpgan_enc is None:
    gfpgan_enc = torch.jit.load('gfpgan_enc_pnnx.pt').eval().to(cudevg)
    gfpgan_dec = torch.jit.load('gfpgan_dec_pnnx.pt').eval().to(cudevg)
  encoded_latent, cond_feats = gfpgan_enc(cropped_face_t)
  return gfpgan_dec(encoded_latent, *cond_feats)

# face-restoration backend used by enhance()
doenh=doenh_gfp

@torch.no_grad()
def enhance(img, has_aligned=False, only_center_face=False, paste_back=True):
  """Detect, align and restore faces in a BGR image.

  Args:
    img: input BGR image (as read by ``cv2.imread``).
    has_aligned: treat ``img`` itself as one aligned face (resized to 512x512).
    only_center_face: forwarded to landmark detection.
    paste_back: when True (and not ``has_aligned``), paste the restored
      faces into the upscaled frame.

  Returns:
    ``(faces, restored_img)``. ``restored_img`` is the result of
    ``paste_faces_to_input_image``, or None when nothing is pasted back.
  """
  faces = faceimg(img)

  if not has_aligned:
      # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
      # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
      faces.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
      faces.align_warp_face()
  else:
      faces.cropped_faces = [cv2.resize(img, (512, 512))]

  # face restoration: [0, 255] BGR -> [-1, 1] RGB tensor -> GFPGAN -> BGR image
  half = (0.5, 0.5, 0.5)
  for crop in faces.cropped_faces:
      face_t = img2tensor(crop / 255., bgr2rgb=True, float32=True)
      normalize(face_t, half, half, inplace=True)
      enhanced = doenh(face_t.unsqueeze(0).to(cudevg))
      faces.add_restored_face(tensor2img(enhanced[0].cpu(), rgb2bgr=True, min_max=(-1, 1)))

  if has_aligned or not paste_back:
      return faces, None
  faces.get_inverse_affine(None)
  return faces, faces.paste_faces_to_input_image()



# Fetch the TorchScript face models (*_pnnx.pt) on first run.
# NOTE(review): only retinaface_pnnx.pt is checked, so if an earlier run
# downloaded it but not the other three files, they are not retried.
if not os.path.isfile('retinaface_pnnx.pt'):
  !wget https://huggingface.co/Larvik/GFPGANjit/resolve/main/face_parse_pnnx.pt
  !wget https://huggingface.co/Larvik/GFPGANjit/resolve/main/gfpgan_dec_pnnx.pt
  !wget https://huggingface.co/Larvik/GFPGANjit/resolve/main/gfpgan_enc_pnnx.pt
  !wget https://huggingface.co/Larvik/GFPGANjit/resolve/main/retinaface_pnnx.pt

# Device for all face-restoration models (Colab form field).
GFPgan_device='cpu' #@param ['cpu', 'cuda']
cudevg=torch.device(GFPgan_device)


# The GFPGAN encoder/decoder are loaded lazily by doenh_gfp() on first use.
gfpgan_enc=None
RetinaFace =torch.jit.load('retinaface_pnnx.pt').eval().to(cudevg)
face_parse =torch.jit.load('face_parse_pnnx.pt').eval().to(cudevg)


In [None]:
input='/content/aaa2_4x.png' #@param {type:'string'}
output='results'

upscale=4
suffix=None
only_center_face=False
aligned=False
# Extension of the full restored image; 'auto' -> 'png'. The pasted result
# has 4 channels (face layer + face mask as alpha), so choose a format that
# can store an alpha channel.
ext='auto'


# ------------------------ input & output ------------------------
if input.endswith('/'):
    input = input[:-1]
if os.path.isfile(input):
    img_list = [input]
else:
    img_list = sorted(glob.glob(os.path.join(input, '*')))

os.makedirs(output, exist_ok=True)

# Resolve the output extension once. The loop below used to rebind ``ext``
# to each input file's extension, so this setting was silently ignored.
extension = 'png' if ext == 'auto' else ext.lstrip('.')


# ------------------------ restore ------------------------
for img_path in img_list:
    # read image
    img_name = os.path.basename(img_path)
    print(f'Processing {img_name} ...')
    basename = os.path.splitext(img_name)[0]
    input_img = cv2.imread(img_path, cv2.IMREAD_COLOR)
    if input_img is None:
        # cv2.imread returns None for unreadable / non-image files
        print(f'Skipping {img_name}: not a readable image')
        continue

    # restore faces and background if necessary
    faces, restored_img = enhance(input_img, has_aligned=aligned, only_center_face=only_center_face, paste_back=True)

    # save faces
    for idx, (cropped_face, restored_face) in enumerate(zip(faces.cropped_faces, faces.restored_faces)):
        # save cropped face
        save_crop_path = os.path.join(output, 'cropped_faces', f'{basename}_{idx:02d}.png')
        imwrite(cropped_face, save_crop_path)
        # save restored face
        if suffix is not None:
            save_face_name = f'{basename}_{idx:02d}_{suffix}.png'
        else:
            save_face_name = f'{basename}_{idx:02d}.png'
        save_restore_path = os.path.join(output, 'restored_faces', save_face_name)
        imwrite(restored_face, save_restore_path)
        # save comparison image (crop | restored, side by side)
        cmp_img = np.concatenate((cropped_face, restored_face), axis=1)
        imwrite(cmp_img, os.path.join(output, 'cmp', f'{basename}_{idx:02d}.png'))

    # save restored img
    if restored_img is not None:
        if suffix is not None:
            restored_name = f'{basename}_{suffix}.{extension}'
        else:
            restored_name = f'{basename}.{extension}'
        imwrite(restored_img, os.path.join(output, 'restored_imgs', restored_name))

print(f'Results are in the [{output}] folder.')


# txt2img

In [None]:
# --- Colab form: which Stable Diffusion weights / UNet code path to use ---
SDver='470k' #@param ['440k', '470k']
Dfm='Orig' #@param ['Orig','_inpaint','_imgemb','_a19561','_a17750','_a17750_e9750','_e26500','_z313000']
DfmCodeBase='JIT' #@param ['JIT', 'ldm_xformers']
EnableKVmerges=False #@param {type:'boolean'}

# INP: True for the inpainting weights; apply_model_* then concatenates the
# global image_embed to the UNet input along dim 1.
INP=False

# 'Orig' is mapped to an empty Dfm string.
if Dfm=='Orig':
  Dfm=''
elif Dfm=='_inpaint':
  INP=True


import os

# f_dljit is defined elsewhere. Its result is used as the filename prefix of
# the diffusion_*_pnnx.pt files (presumably it also downloads them -- confirm).
SDver=f_dljit(SDver,Dfm)

import sys
import time
import random
import numpy as np
import cv2
from PIL import Image
import PIL
from IPython.core.display import HTML
from transformers import CLIPTokenizer

import torch

# Inference-only global settings: no autograd, all CPU threads, and
# TF32 / reduced-precision fp16 reductions plus cuDNN autotuning on GPU.
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

alphas_cumprod = np.load('alphas_cumprod.npz')['a']

cudev=torch.device('cuda')







# img2img state: no init latent until the img2img cell sets one.
preimg=None
revpreimg=None


# Build the K/V preprocessing modules under accelerate's init_empty_weights,
# i.e. without allocating real parameter storage; 768 matches the width of
# the text conditioning used in warmup().
from accelerate import init_empty_weights
with init_empty_weights():
  preprocK=PreKV(768)
  preprocV=PreKV(768)

# ddpm
def apply_model_jit(x, t, cond,d):
  """Run the three-part TorchScript UNet (emb -> mid -> out) on latents ``x``.

  For inpainting weights (``INP``), ``image_embed`` is concatenated to ``x``
  along dim 1 first. The middle activations are handed to ``hlog0.Arevpre``
  with index ``d`` before the output stage. The skip tensors ``hs`` are split:
  ``hs[6:]`` go to the middle block and ``hs[:6]`` to the output block.
  """
  unet_in = torch.cat([x, image_embed], dim=1) if INP else x
  hidden, emb, skips = diffusion_emb(unet_in, t, cond)
  hidden = diffusion_mid(hidden, emb, cond, *skips[6:])
  hlog0.Arevpre(hidden, d)
  return diffusion_out(hidden, emb, cond, *skips[:6])

def apply_model_ldm(x, t, cond,d):
  """Run the ldm (optionally xformers) UNet on latents ``x``.

  Same flow as ``apply_model_jit``, but uses ``ldm_unet``'s three forward
  stages. It also passes separate key/value conditioning: when
  ``preprocK.use`` is set, ``cond`` goes through ``preprocK`` / ``preprocV``;
  otherwise both are ``cond``.
  """
  unet_in = torch.cat([x, image_embed], dim=1) if INP else x

  if preprocK.use:
    key_ctx = preprocK(cond)
    val_ctx = preprocV(cond)
  else:
    key_ctx = val_ctx = cond

  hidden, emb, skips = ldm_unet.forward_crossattn(unet_in, t, key_ctx, val_ctx)
  hidden = ldm_unet.forward2(hidden, emb, key_ctx, *skips[6:], val_ctx)
  hlog0.Arevpre(hidden, d)
  return ldm_unet.forward3(hidden, emb, key_ctx, *skips[:6], val_ctx)

# default UNet backend; the model-loading cell may switch it to apply_model_ldm
apply_model=apply_model_jit

# Linear projection from the 4 latent channels (rows L1-L4) to RGB (columns).
# decode_first_stage and imgthumb use it as a cheap preview instead of
# running the VAE decoder.
preview_mtx = torch.tensor( [
    #   R       G       B
    [ 0.298,  0.207,  0.208],  # L1
    [ 0.187,  0.286,  0.173],  # L2
    [-0.158,  0.189,  0.264],  # L3
    [-0.184, -0.271, -0.473],  # L4
],device=cudev)

# decoder
def decode_first_stage(z, hsz=-1):
  """Decode latents ``z`` to images, or approximate them when too large.

  ``hsz`` is the pixel count used for the size check. A negative value (the
  default) means ``n_samples * H * W`` from the current settings. Above
  0x200000 pixels the VAE is skipped and the latents are projected to RGB
  with ``preview_mtx``. Otherwise ``autoencoder`` decodes ``z / 0.18215``,
  undoing the 0.18215 scale applied when encoding.
  """
  pixels = hsz if hsz >= 0 else n_samples * H * W
  if pixels <= 0x200000:
    return autoencoder(z / 0.18215)
  channels_last = z.permute(0, 2, 3, 1)
  return (channels_last @ preview_mtx).permute(0, 3, 1, 2)



def load_img(path):
    """Load an image file as a ``(1, 3, h, w)`` float tensor in [-1, 1].

    The image is converted to RGB. Width and height are rounded down to
    multiples of 32, resizing with Lanczos when that changes the size.
    """
    pic = Image.open(path).convert("RGB")
    w, h = pic.size
    print(f"loaded input image of size ({w}, {h}) from {path}")
    w32, h32 = w - w % 32, h - h % 32
    if (w32, h32) != (w, h):
        pic = pic.resize((w32, h32), resample=PIL.Image.LANCZOS)
    arr = np.asarray(pic, dtype=np.float32) / 255.0
    chw = torch.from_numpy(arr[None].transpose(0, 3, 1, 2))
    return 2. * chw - 1.

def cutorexpand(tenz,dstbsz):
  """Return ``tenz`` with its batch (dim 0) size forced to ``dstbsz``.

  A larger batch is truncated to its first ``dstbsz`` entries. A smaller one
  is replaced by its FIRST entry broadcast ``dstbsz`` times (an ``expand``
  view, no copy). A batch that already has the target size is returned
  unchanged. Previously that case fell through and returned None, which broke
  callers such as fiximgemb once image_embed already had the target size.
  """
  retsz = tenz.size(0)
  if retsz > dstbsz:
    return tenz[:dstbsz]
  if retsz < dstbsz:
    # keep every non-batch dim; works for any rank (was hard-coded to 4-D)
    return tenz[:1].expand(dstbsz, *tenz.shape[1:])
  return tenz

def fiximgemb(nshape):
  """Fit the global ``image_embed`` to spatial size ``nshape`` and the CFG batch.

  The embed is bicubically resized when its last two dims differ from
  ``nshape`` (a tuple). Its batch is then cut/expanded to twice the batch of
  the global ``noise``.
  """
  emb = image_embed
  if tuple(emb.shape[2:]) != nshape:
    emb = torch.nn.functional.interpolate(emb, size=nshape, mode='bicubic')
  return cutorexpand(emb, noise.size(0) << 1)

def waitingmask():
  """Block, polling every 5 seconds, until /content/user_mask.npy exists."""
  mask_file = '/content/user_mask.npy'
  while True:
    if os.path.isfile(mask_file):
      return
    time.sleep(5)

def findnpy(maskpath, lat=None):
  """Load an inpainting mask as a float32 tensor.

  Lookup order:
    1. ``maskpath`` itself when it ends with ``.npy``;
    2. a ``.npy`` file next to it (its last 3 characters replaced by ``npy``);
    3. the image at ``maskpath``, resized to the last two dims of ``lat``,
       thresholded per channel at 128 and converted to a 1-bit image;
    4. otherwise, start the web mask editor and block until it writes
       /content/user_mask.npy.

  Args:
    maskpath: path to a mask image or ``.npy`` file.
    lat: latent that sets the mask size in case 3. It defaults to the global
      ``preimg`` (the same default ``do_masking`` uses). The original code
      referenced an undefined ``lat`` here and raised NameError.
  """
  if maskpath.endswith('.npy'):
    return torch.tensor( np.load(maskpath).astype(np.float32) )
  maa=maskpath[:-3]+'npy'
  if os.path.isfile(maa):
    return torch.tensor( np.load(maa).astype(np.float32) )
  if os.path.isfile(maskpath):
    if lat is None:
      lat = preimg
    mask = Image.open(maskpath).convert('RGB')
    # PIL's resize takes (width, height) = (last dim, second-to-last dim)
    mask = mask.resize(( lat.size(3) , lat.size(2) )).point( lambda p: 255 if p > 128 else 0 ).convert('1')
    return torch.tensor(np.array(mask).astype(np.float32))
  print("Make a mask with webui (u shouldn't waste colab gpu time doing such thing)")
  localhttp()
  if os.path.isfile('web/svr.py'):
    runpyproc('web/svr')
  web_masking()
  waitingmask()
  return torch.tensor( np.load('/content/user_mask.npy').astype(np.float32) )



def do_masking(maskpath, lat=None):
  """Return ``lat`` (default: the global ``preimg``) multiplied by the mask, on CUDA.

  The mask comes from ``findnpy(maskpath)``.
  """
  base = preimg if lat is None else lat
  return (base * findnpy(maskpath)).cuda()


def encodepatt():
  """Encode every PNG frame in the directory of ``output_pattern`` to latents.

  Each ``*.png`` in that directory (in sorted order) is encoded with
  ``imgenc`` using one shared noise tensor and scaled by 0.18215. The result
  is written next to the image as a raw ``.bin`` file
  (``tensor.numpy().tofile``). A sidecar ``.txt`` records the latent shape
  and the number of encoded frames; its name is ``output_pattern`` with
  ``%`` replaced by ``!@!``. The img2img cell reads this file back when
  ``init_img`` ends with ``.txt``. The source PNGs are then deleted.

  NOTE(review): the latent shape is taken from the first directory entry
  (not filtered to .png), which is assumed to be an image of the same size
  as all other frames.
  """
  ozi=output_pattern.split('/')[-1]
  # directory part of the pattern (everything before the final '/')
  pdir=output_pattern[:-len(ozi)-1]
  flist=os.listdir(pdir)
  flist.sort()
  pdir+='/'
  rpt=load_img(pdir+flist[0])
  vB=1
  vH=rpt.size(2)
  vW=rpt.size(3)
  # latent shape: 4 channels at 1/8 of the pixel resolution
  thsize=torch.Size([vB,4,vH>>3,vW>>3])
  noyaz=torch.randn(thsize)
  zadd=0
  for f in flist:
    if f.endswith('.png'):
      vlat=imgenc(  load_img( pdir+f) ,  noyaz )*0.18215
      vlat.numpy().tofile(pdir+f[:-3]+'bin')
      zadd+=1
  # format: "<B>, 4, <h>, <w>\n<frame count>"
  with open(output_pattern[:-3].replace('%','!@!')+'txt','wt') as f:
    f.write(str(list(thsize))[1:-1]+'\n'+str(zadd))
  !rm {pdir}*.png



def warmup():
  """Run the UNet and the VAE decoder on random 32x32 latents, twice each.

  The UNet is called for hook indices d = 0 and 1 under fp16 autocast. With
  cudnn.benchmark enabled, this presumably moves the autotuning/optimisation
  cost off the first real generation. For inpainting weights, a temporary
  global ``image_embed`` is created and deleted again.
  """
  global image_embed
  unet_lat = torch.rand(2, 4, 32, 32, dtype=torch.float).half().cuda()
  if INP:
    image_embed = unet_lat
  steps = torch.randint(10, (2, ), dtype=torch.long).cuda()
  ctx = torch.rand(2, 77, 768, dtype=torch.float).half().cuda()
  for d in range(2):
    with torch.cuda.amp.autocast(dtype=torch.float16):
      _ = apply_model(unet_lat, steps, ctx, d)
  vae_lat = torch.rand(1, 4, 32, 32, dtype=torch.float).cuda()
  for _d in range(2):
    _ = autoencoder(vae_lat)
  if INP:
    del image_embed
  torch.cuda.empty_cache()
  



fext='_%dx%dv%d.png'
def saver():
  """Save the current batch: raw latents as .npy and decoded images as PNG.

  Reads the globals ``samples``, ``x_samples`` (decoded, in [-1, 1]), ``iita``,
  ``ktta`` and ``outputp``. ``x_samples`` is rebound to its clipped [0, 1]
  numpy form. Image k is written as ``outputp + '_<iita>x<k>v<ktta>.png'``.

  NOTE(review): this runs in a background Thread, while the caller's loop
  rebinds these globals for the next item -- a race.
  """
  global x_samples
  i = iita
  np.save((outputp + fext % (i, 1, ktta))[:-4] + '.npy', samples)
  x_samples = np.clip((x_samples.numpy() + 1.0) / 2.0, a_min=0.0, a_max=1.0)
  for k, x_sample in enumerate(x_samples):
      bgr = (x_sample.transpose(1, 2, 0) * 255).astype(np.uint8)[:, :, ::-1]  # CHW RGB -> HWC BGR
      cv2.imwrite(outputp + fext % (i, k, ktta), bgr)



fextWENS='_%s_%dx%dv%d.png'
WENSsig=['W','E','N','S']
def saverWENS():
  """Save one uncrop region next to ``init_img``, tagged W/E/N/S.

  The region index ``fn`` (global) is folded into 0..3 (values above 3 lose
  4) and selects the W/E/N/S tag. Otherwise this behaves like ``saver``,
  using the stem of ``init_img`` (its last 4 characters dropped).
  """
  global x_samples
  i = iita
  stem = init_img[:-4]
  side = fn
  if side > 3:
    side -= 4
  tag = WENSsig[side]
  np.save((stem + fextWENS % (tag, i, 1, ktta))[:-4] + '.npy', samples)
  x_samples = np.clip((x_samples.numpy() + 1.0) / 2.0, a_min=0.0, a_max=1.0)
  for k, x_sample in enumerate(x_samples):
      bgr = (x_sample.transpose(1, 2, 0) * 255).astype(np.uint8)[:, :, ::-1]  # CHW RGB -> HWC BGR
      cv2.imwrite(stem + fextWENS % (tag, i, k, ktta), bgr)


# Default sampler (presumably replaced by f_sampler() in the Sampler Selector cell).
UseSamplr=sample_lms

def predict(c_list, uc,noi=None):
    """Sample one batch per conditioning in ``c_list`` and save the results.

    Starts from ``noi`` (or fresh Gaussian noise of ``shape[:4]``) scaled to
    the first sigma. When the global ``preimg`` is set (img2img), the
    schedule starts at index ``ddim_num_steps - t_enc`` and sampling starts
    from ``preimg`` plus noise at that sigma. A 1-D ``preimg`` whose first
    value is 2 instead switches the feeder to the ``.bin`` latent sequence
    described by ``tmpfeeder``, with one copy of the first conditioning per
    frame. Entries of ``shape`` beyond the first 4 are passed to ``zemp0`` as
    resize stages. Each result is decoded and saved by ``saver`` in a
    background thread.

    NOTE(review): ``uc`` is unused here; ``zemp0`` reads the global ``uc``.
    NOTE(review): for a 1-D ``preimg`` whose first value is not 2, ``img`` is
    never built and ``feeder.setbs`` is never called.
    NOTE(review): the saver thread reads ``samples``/``x_samples``/``ktta``,
    which the next loop iteration rebinds.
    """
    global x_samples
    global samples
    global ktta
    global tmpfeeder
    global noise



    feeder=ifeeder()

    sigmas = f_sigmas()

    # shape = [B, C, h, w, (rate, h, w)*] -- the tail lists resize stages
    shape_alters=[]
    if len(shape) > 4:
      shape_r=shape[:4]
      shape_alters=shape[4:]
    else:
      shape_r=shape

    if noi is None:
      noise = torch.randn(shape_r, dtype=torch.float,device=cudev)
    else:
      noise = noi
    if preimg is not None:
      vt_enc= t_enc-1
      sigma_sched = sigmas[ddim_num_steps - vt_enc - 1:]
      if preimg.dim()==1:
        # command tensor instead of a latent: 2 = feed .bin latent sequence
        cmd0=int(preimg[0])
        if cmd0 == 2:
          feeder.pattern=tmpfeeder.pattern
          feeder.shape=tmpfeeder.shape
          feeder.getn=feeder.get_npbins
          feeder.noiseadd=noise * sigmas[ddim_num_steps - vt_enc - 1]
          c_list=[c_list[0]]*tmpfeeder.xpenlen
          feeder.xpenlen=tmpfeeder.xpenlen
          tmpfeeder=feeder
      else:
        img = preimg.cuda() + noise * sigmas[ddim_num_steps - vt_enc - 1]
        feeder.setbs(img)
    else:
      img = noise*sigmas[0]
      sigma_sched=sigmas
      feeder.setbs(img)


    ktta=0
    for c in c_list:
      c.reset()
      samples = zemp0(feeder.getn(ktta),c,sigma_sched,shape_alters)
      ktta+=1
      
      
      x_samples = decode_first_stage(  samples ).cpu()
      samples=samples.cpu()
      t3 = Thread(target = saver)
      a3 = t3.start()

    return
# keep a handle so later cells can restore this implementation
predict_orig=predict

def zemp(noi,c,sigmas):
  """Sample once from ``noi * sigmas[0]`` over the full ``sigmas`` schedule.

  Sets the global ``noise`` to ``noi``, resets ``c`` and runs ``UseSamplr``
  under fp16 autocast with CFG args built from the globals ``uc`` and
  ``cfg_scale``. Returns the sampled latents.
  """
  global noise
  noise = noi
  cfg_args = dict(cond=c, uncond=uc, cond_scale=cfg_scale)
  c.reset()
  with torch.cuda.amp.autocast(dtype=torch.float16):
    return UseSamplr(model_wrap_cfg, noise * sigmas[0], sigmas, extra_args=cfg_args, disable=False)


def ConcatNoise_shuf(noi, new_h, new_w):
  """Build noise of spatial size ``(new_h, new_w)`` by reusing ``noi``'s values.

  ``noi`` is permuted and flattened into blocks of 64 values. The block
  sequence is tiled until it covers ``new_h * new_w / 64`` blocks, truncated,
  and folded back with the inverse axis order. Both ``h * w`` and
  ``new_h * new_w`` must be multiples of 64.
  """
  dims = list(noi.shape)
  old_blocks = (dims[-1] * dims[-2]) >> 6
  new_blocks = (new_w * new_h) >> 6
  flat = noi.permute(1, 3, 0, 2).reshape(dims[:-2] + [64, old_blocks])
  repeats = int(0.9999 + (new_blocks / old_blocks))
  tiled = torch.cat([flat] * repeats, dim=3)[:, :, :, :new_blocks]
  return tiled.permute(1, 3, 0, 2).reshape(dims[:-2] + [new_h, new_w])

def ConcatNoise_rdm(noi, new_h, new_w):
  """Return fresh Gaussian noise shaped like ``noi`` but with size ``(new_h, new_w)``."""
  dims = list(noi.shape)
  dims[-2], dims[-1] = new_h, new_w
  return torch.randn(dims, dtype=torch.float, device=cudev)

# noise-resizing strategy used by zemp0 (reuse values rather than redraw)
ConcatNoise=ConcatNoise_shuf

def zemp0(noi,c,sigmas,shape_alters):
  """Sample from ``noi`` over ``sigmas``, optionally in several resize stages.

  ``shape_alters`` is a flat list of triples ``(rate, h, w)``. Stage t+1
  begins at sigma index ``int(len(sigmas) * rate)`` at latent size
  ``(h, w)``, and each stage ends where the next one begins. A resized
  stage's sub-schedule starts at ``sta >> 1``, or ``int(sta * rate)`` when
  rate > 0.4. Before it runs:
    * the current noise component ``curnoise * sigmas[sta]`` is subtracted;
    * the latents are resized bicubically;
    * noise from ``ConcatNoise`` is added at the sub-schedule's first sigma.
  With no ``shape_alters`` there is a single stage over all of ``sigmas``.

  Side effects: the global ``noise`` follows the current stage's noise. For
  inpainting weights, the global ``image_embed`` is resized to each stage.
  Returns the final latents.
  """
  global noise
  global image_embed
  samples=noi
  curnoise=noise
  # stage record: [(h, w) or (0, 0) = keep size, start index, end index, rate]
  shape_alt=[[(0,0),0,None,0]]
  if len(shape_alters) > 0:
    sigml=len(sigmas)
    l2=len(shape_alters)//3
    shape_alt=[None]*(l2+1)
    shape_alt[0]=[(0,0),0,None,0]
    for t in range(l2):
      shape_alt[t+1]=[(0,0),None,None,0]
      shape_alt[t+1][0]= ( shape_alters[t*3+1], shape_alters[t*3+2] )
      tima=int(sigml*shape_alters[t*3])
      shape_alt[t+1][3]=shape_alters[t*3]
      # previous stage ends where this one starts
      shape_alt[t][2]=tima
      shape_alt[t+1][1]=tima-1
      shape_alt[t]=tuple(shape_alt[t])
    dsth,dstw=shape_alt[-1][0]
  else:
    dsth=curnoise.size(2)
    dstw=curnoise.size(3)
  shape_alt[-1]=tuple(shape_alt[-1])


  extra_args = {'cond': c, 'uncond': uc, 'cond_scale': cfg_scale}
  with torch.cuda.amp.autocast(dtype=torch.float16):
    kot1=0
    for r_shape, sta, endo, o_rate in shape_alt:
      new_h,new_w=r_shape
      # restart part of the already-sampled range after a resize
      v_sta=sta>>1
      if o_rate > 0.4:
        v_sta=int(sta*o_rate)
      sub_sigma=sigmas[v_sta:endo]
      if new_h != 0:
        c.add_sta=v_sta
        noise_new=ConcatNoise(curnoise,new_h,new_w)
        noise = noise_new
        # strip the old noise at sigmas[sta], resize, re-noise at sub_sigma[0]
        samples=torch.nn.functional.interpolate(samples-curnoise*sigmas[sta] ,size=r_shape,mode='bicubic')+noise_new*sub_sigma[0]
        curnoise=noise_new
      else:
        new_h=curnoise.size(2)
        new_w=curnoise.size(3)
      if INP:
        image_embed = fiximgemb((new_h,new_w))

      hlog0.set_multinm(kot1,new_h,dsth,new_w,dstw)
      samples = UseSamplr(model_wrap_cfg, samples, sub_sigma, extra_args=extra_args, disable=False)
      kot1+=1
  return samples

def calcsz(tenz):
  """Pixel count (batch * h * w * 64) represented by latent tensor ``tenz``.

  Each latent cell covers 8x8 pixels; the result is passed to
  ``decode_first_stage`` as its size check.
  """
  batch, height, width = tenz.size(0), tenz.size(2), tenz.size(3)
  return (batch * height * width) << 6

def UnCrop(c_list, uc):
  """Outpaint ("uncrop") ``preimg`` onto a larger canvas with inpainting weights.

  ``area4(preimg)`` supplies:
    * the canvas shape (``skey``);
    * the original size (``O_h``, ``O_w``) and its offsets (``Npad``, ``Wpad``);
    * via ``getshapes()``, the regions ``(top, bottom, left, right, fn)``.
  For every conditioning, each region is sampled with that canvas window as
  ``image_embed``. The result is written back into the canvas, and the
  original ``preimg`` is stamped over it again. Each region is saved via
  ``saverWENS`` and the whole canvas via ``saver`` (with ``ktta`` = 999),
  both in background threads. Non-inpainting weights fall back to
  ``UnCrop_old``.

  NOTE(review): ``uc`` is unused here; ``zemp`` reads the global ``uc``.
  NOTE(review): ``fn`` is a global loop variable because ``saverWENS`` reads
  it from its thread; the next iteration may change it first.
  """
  global x_samples
  global samples
  global ktta
  global fn
  global image_embed

  if not INP:
    print('Not inpainting weights, fail back to old buggy dumb uncrop')
    UnCrop_old(c_list, uc)
    return

  
  sigmas = f_sigmas()

  a4=area4(preimg)
  noise_whole = torch.randn(a4.skey, dtype=torch.float,device=cudev)

  # image_embed needs twice the latent batch (presumably cond + uncond for CFG)
  embdstsz=noise_whole.size(0)<<1
  # canvas: zeros everywhere except the original latent at its offset
  zero_whole=torch.zeros(a4.skey, dtype=torch.float,device=cudev)
  zero_whole[:,:,a4.Npad:a4.O_h+a4.Npad,a4.Wpad:a4.O_w+a4.Wpad]=preimg

  corner8=a4.getshapes()
  ktta=0
  for c in c_list:
    nxktta=ktta+1
    for kU,kB,kL,kR,fn in corner8:

      image_embed=cutorexpand(zero_whole[:,:,kU:kB,kL:kR],  embdstsz )
      samples=zemp(noise_whole[:,:,kU:kB,kL:kR],c,sigmas)
      zero_whole[:,:,kU:kB,kL:kR]=samples
      # keep the original content untouched
      zero_whole[:,:,a4.Npad:a4.O_h+a4.Npad,a4.Wpad:a4.O_w+a4.Wpad]=preimg

      x_samples = decode_first_stage(  samples, calcsz(noise) ).cpu()
      samples=samples.cpu()
      ktta=nxktta
      t3 = Thread(target = saverWENS)
      a3 = t3.start()
    x_samples = decode_first_stage(  zero_whole, calcsz(zero_whole) ).cpu()
    samples=zero_whole.cpu()
    ktta=999
    t3 = Thread(target = saver)
    a3 = t3.start()




def UnCrop_old(c_list, uc):
  """Legacy outpainting for non-inpainting weights, driven by ``hlog0`` hooks.

  Uses the same canvas layout as ``UnCrop`` (via ``area4``), but has no
  ``image_embed``. For each conditioning it:
    * samples the centre (original-size) window;
    * calls ``hlog0.setbsB(-10, preimg)``;
    * samples each region with ``hlog0.setfuncNb(fn)`` set (plus ``setfuncN``
      when ``CopyDFout``), then passes each region result to
      ``hlog0.setbsB(fn, ...)``.
  Results are written into the canvas and saved like ``UnCrop``, except that
  ``preimg`` is not re-stamped. ``hlog0`` is replaced by a fresh
  ``hlogger()`` at the end.

  NOTE(review): what the hlog0 hooks do is defined outside this view; this
  docstring only lists the calls. ``uc`` is unused (``zemp`` reads the
  global ``uc``).
  """
  global x_samples
  global samples
  global ktta
  global fn
  global hlog0






  sigmas = f_sigmas()

  a4=area4(preimg)
  noise_whole = torch.randn(a4.skey, dtype=torch.float,device=cudev)
  zero_whole=torch.zeros(a4.skey, dtype=torch.float,device=cudev)
  zero_whole[:,:,a4.Npad:a4.O_h+a4.Npad,a4.Wpad:a4.O_w+a4.Wpad]=preimg


  # one slot per sampling step for the hook logger
  hlog0.h_bs=[None]*ddim_num_steps
  corner8=a4.getshapes()

  ktta=0
  for c in c_list:
    if CopyDFout:
      hlog0.setfuncb('1s')
      hlog0.setfunc('loghs')

    # first pass over the original-size window (result not stored)
    zemp(noise_whole[:,:,a4.Npad:a4.O_h+a4.Npad,a4.Wpad:a4.O_w+a4.Wpad],c,sigmas)
    hlog0.setbsB(-10,preimg)

    hlog0.setfuncb('0')
    nxktta=ktta+1
    for kU,kB,kL,kR,fn in corner8:
      if CopyDFout:
        hlog0.setfuncN(fn)
      hlog0.setfuncNb(fn)
      samples=zemp(noise_whole[:,:,kU:kB,kL:kR],c,sigmas)

      zero_whole[:,:,kU:kB,kL:kR]=samples
      #zero_whole[:,:,a4.Npad:a4.O_h+a4.Npad,a4.Wpad:a4.O_w+a4.Wpad]=preimg

      hlog0.setbsB(fn,samples)
      x_samples = decode_first_stage(  samples, calcsz(noise) ).cpu()
      samples=samples.cpu()
      ktta=nxktta
      t3 = Thread(target = saverWENS)
      a3 = t3.start()
    x_samples = decode_first_stage(  zero_whole, calcsz(zero_whole) ).cpu()
    samples=zero_whole.cpu()
    ktta=999
    t3 = Thread(target = saver)
    a3 = t3.start()
  hlog0=hlogger()


def imgthumb(z):
  """Write a cheap RGB preview of one latent ``z`` (C, h, w) to /content/tb.png.

  The latent channels are projected with ``preview_mtx``, mapped from
  [-1, 1] to [0, 255] and saved as BGR for OpenCV.
  """
  rgb = (z.permute(1, 2, 0) @ preview_mtx).permute(2, 0, 1)
  rgb = np.clip((rgb.cpu().numpy() + 1.0) / 2.0, a_min=0.0, a_max=1.0)
  bgr = (rgb.transpose(1, 2, 0) * 255).astype(np.uint8)[:, :, ::-1]  # CHW RGB -> HWC BGR
  cv2.imwrite('/content/tb.png', bgr)

def readhtml(fna):
  """Return the whole text content of the file at ``fna``."""
  with open(fna, 'rt') as handle:
    return handle.read()

# HTML fragments of the web mask editor; web_masking() splices the canvas
# size and image name between them.
psg0=readhtml('web/psg0.htm')
psg1=readhtml('web/psg1.htm')


def web_masking():
  """Show the browser mask editor for the current init image.

  Picks the editor's background image:
    * latent init (``initymgtyp`` 0): the source image next to the ``.npy``
      (``init_img`` without its last 4 characters, or that plus ``.png``) when
      present; otherwise a preview of ``preimg[0]`` rendered to tb.png;
    * image init (``initymgtyp`` 1): ``init_img`` itself.
  The latent-resolution size (W>>3, H>>3) and the image name are injected
  between ``psg0`` and ``psg1`` and displayed. They are also written to
  web/curmsk.txt.

  Raises:
    ValueError: when ``initymgtyp`` is neither 0 nor 1, since there is no
      image to paint on. This used to surface as an UnboundLocalError.
  """
  if initymgtyp == 0:
    popic=init_img[:-4]
    if os.path.isfile(popic):
      bsimg=popic
    elif os.path.isfile(popic+'.png'):
      bsimg=popic+'.png'
    else:
      imgthumb(preimg[0].cuda())
      bsimg='tb.png'
  elif initymgtyp == 1:
    bsimg=init_img
  else:
    raise ValueError(f'web_masking needs an image or latent init_img (initymgtyp 0 or 1), got initymgtyp={initymgtyp}')
  strW=str(W>>3)
  strH=str(H>>3)
  # the page is served from /content, so use a relative name
  bsimg=bsimg.replace('/content/','')
  display(HTML(psg0+ strW +';\nvar dYh='+strH+";\nvar imgfna='"+bsimg+"';\n\n"+psg1))
  with open('web/curmsk.txt','wt') as f:
    f.write('\n'.join([strW,strH,bsimg]))


def dumplogs(lbg):
  """Render the latent log of slot ``lbg`` to PNG frames and an mp4.

  Entries of ``hlog0.latlog_arr[lbg]`` are decoded with ``latdec2``, mapped
  from [-1, 1] to [0, 255], and written to
  /content/sample_data/vyi<lbg>/stp%05d.png. The walk goes from the last
  entry backwards and stops at the first ``None``. ffmpeg then encodes the
  frames at 3 fps to intp<start>.mp4. <start> is the index of that ``None``
  entry, or 0 when there is none.

  NOTE(review): when a ``None`` was hit, no frame exists at index <start>.
  ffmpeg's image2 demuxer probes a few indices from -start_number onward
  (start_number_range, default 5), which presumably finds <start>+1.
  """
  arnnstr='/content/sample_data/vyi'+str(lbg)
  !mkdir {arnnstr}
  lu=hlog0.latlog_arr[lbg]
  lul=len(lu)-1
  psta='0'
  for n in range(lul,-1,-1):
    if lu[n] is None:
      psta=str(n)
      break
    Image.fromarray( (( ( latdec2(lu[n])[0] +1)*127.5 ).cpu().numpy()).transpose(1,2,0).clip(0,255).astype(np.uint8) ).save(arnnstr+'/stp%05d.png'%n)
  !ffmpeg -framerate 3 -start_number {psta} -i {arnnstr}/stp%05d.png -pix_fmt yuv420p intp{psta}.mp4

def EnLarge2seed_size(enl):
  """Split the EnLarge stage list into the final size and per-stage triples.

  ``enl`` is a flat list of triples ``(h, w, x)``; only complete triples are
  considered and ``x`` is passed through unchanged. Returns
  ``(newH, newW, stages)``:
    * ``newH, newW`` are the h/w of the last triple;
    * ``stages`` is a copy of ``enl`` in which each triple's h/w is replaced
      by the previous triple's h/w, with the first triple getting the current
      globals ``H, W``.
  Without a complete triple, returns ``(H, W, [])``.

  The caller's list is no longer modified. It used to be shifted in place
  and returned as an alias of the argument.
  """
  l_enl = len(enl)//3
  if l_enl == 0:
    return H, W, []
  stages = list(enl)
  newH = stages[l_enl*3-3]
  newW = stages[l_enl*3-2]
  # shift h/w one triple forward, walking backwards so nothing is overwritten early
  for i in range(l_enl-1, 0, -1):
    stages[i*3] = stages[i*3-3]
    stages[i*3+1] = stages[i*3-2]
  stages[0] = H
  stages[1] = W
  return newH, newW, stages


def init_img_type():
  """Classify the global ``init_img`` path.

  Returns 0 for a latent ``.npy``, 1 for a ``.jpg``/``.png`` image, 2 for a
  ``.txt`` latent-sequence description, and 99 for anything else. An image
  with a cached ``<image>.npy`` next to it counts as a latent (0), and the
  global ``init_img`` is switched to that ``.npy`` path.
  """
  global init_img
  if init_img.endswith('.npy'):
    return 0
  if init_img.endswith(('.jpg', '.png')):
    if not os.path.isfile(init_img + '.npy'):
      return 1
    init_img += '.npy'
    return 0
  if init_img.endswith('.txt'):
    return 2
  return 99

# no init image selected yet
initymgtyp=99

def fixb64(dx2,dx1):
  """Round ``dx1`` down to a multiple of 8 and move half the remainder to ``dx2``.

  Returns ``(dx2, dx1)``, unchanged when ``dx1`` is already a multiple of 8.
  """
  rem = dx1 & 7
  if not rem:
    return (dx2, dx1)
  return (dx2 + (rem >> 1), dx1 - rem)


def mkUnCrop4():
  """Convert the ``UnCrop4`` form value into four per-side pairs.

  ``UnCrop4`` holds 4 extension amounts optionally followed by 4 overlap
  values; a list whose length is not 8 gets four zero overlaps appended.
  Side i uses ``W`` for i = 0, 1 and ``H`` for i = 2, 3 as its reference
  size. A side with extension < 1 yields ``(0, 0)``. Otherwise:
    * the overlap (default: half the reference size) is converted to
      16-pixel units plus one;
    * the extension is adjusted so that overlap + extension ends one 64-pixel
      step past a multiple of 64;
    * the result is normalised with ``fixb64``.
  The returned list is what ``calcUnCrop4`` later unpacks as ``(lap, pad)``.

  The global ``UnCrop4`` list is no longer extended in place.
  """
  # copy so the global form value is not mutated by the padding below
  vbdr = list(UnCrop4)
  if len(vbdr) != 8:
    vbdr += [0, 0, 0, 0]
  refwh = [W, W, H, H]
  for i in range(4):
    bsex = vbdr[i]
    if bsex < 1:
      refwh[i] = (0, 0)
      continue
    x = vbdr[4+i]
    if x == 0:
      x = refwh[i] >> 1
    x = (x >> 4) + 1
    xrf = x << 4
    # extend so that xrf + bsex lands on the next 64 boundary (16-bit mask kept)
    bsex = (((xrf + bsex) & 0xffc0) + 64) - xrf
    refwh[i] = fixb64(x, bsex >> 3)
  return refwh

def calcUnCrop4(n):
  """Return ``(lap, pad, pad + 2*lap)`` for side ``n`` of the global ``UnCrop4``."""
  entry = UnCrop4[n]
  overlap, padding = entry
  total = padding + (overlap << 1)
  return overlap, padding, total


# Build the shared text encoder, VAE, image encoder and CFG denoiser once per runtime.
if model_wrap is None:
  cond_stage_model = BERTEmbedder(torch.jit.load('transformer_pnnx.pt').eval())
  autoencoder = torch.jit.load('autoencoder_pnnx.pt').eval().cuda()
  SDlatDEC=autoencoder
  imgenc = torch.jit.load('imgencoder_pnnx.pt').eval()
  model_wrap = CompVisDenoiser(CompVisJIT())
  model_wrap_cfg = CFGDenoiser(model_wrap)
# Reload the three UNet pieces only when the selected SD weights change.
if SDver != prevSDver:
  diffusion_emb = torch.jit.load(SDver+'diffusion_emb_pnnx.pt').eval().half().cuda()
  diffusion_mid = torch.jit.load(SDver+'diffusion_mid_pnnx.pt').eval().half().cuda()
  diffusion_out = torch.jit.load(SDver+'diffusion_out_pnnx.pt').eval().half().cuda()
  if DfmCodeBase != prevDfmCodeBase:
    prevDfmCodeBase=DfmCodeBase
    # nDfmCodeBase() turns the form string into a code: > 0 selects the ldm
    # UNet (2 also installs xformers), otherwise the JIT UNet is used.
    DfmCodeBase=nDfmCodeBase()
    if DfmCodeBase > 0:
      apply_model=apply_model_ldm
      if DfmCodeBase == 2:
        install_xformer()
      # init_ldm flag: 1 when KV merges are enabled, else -1
      mdfy = 1 if EnableKVmerges else -1
      ldm_unet=init_ldm(DfmCodeBase,mdfy)
    else:
      apply_model=apply_model_jit
      warmup()
    torch.cuda.empty_cache()
elif DfmCodeBase != prevDfmCodeBase:
  prevDfmCodeBase=DfmCodeBase
  DfmCodeBase=nDfmCodeBase()
  if DfmCodeBase > 0:
    apply_model=apply_model_ldm
    if DfmCodeBase == 2:
      install_xformer()
    mdfy = 1 if EnableKVmerges else -1
    ldm_unet=init_ldm(DfmCodeBase,mdfy)
  else:
    # Switching back to JIT with unchanged weights: without this branch the
    # previously selected ldm path stayed active.
    apply_model=apply_model_jit
  torch.cuda.empty_cache()
prevSDver=SDver



👇Optional👇

Sampler Selector

In [None]:
Sampler='euler_a' #@param ['euler', 'euler_a', 'heun','dpm_2','dpm_2_a','lms']
# NOTE(review): f_sampler() presumably selects UseSamplr from `Sampler`. It
# runs before Karras/KarrasRho below are assigned, so if it reads them it
# sees the previous run's values -- confirm.
f_sampler()

Karras=False #@param {type:'boolean'}
KarrasRho = 7.0 #@param {type:'number'}

Insert TI

In [None]:
# Register extra prompt tokens with the text-encoder wrapper
# (presumably textual-inversion embeddings -- confirm in BERTEmbedder.insert).
cond_stage_model.insert('<majipuri>')
cond_stage_model.insert('<pekora>')

Insert prompt variables<br>
e.g. `A {animals} in water` produces two images, `A dog in water` and `A cat in water`, when the `animals` list holds `dog` and `cat`

In [None]:
# Enable `{animals}` / `{artists}` placeholders in prompts.
cond_stage_model.insert_prompt_vars('animals')
cond_stage_model.insert_prompt_vars('artists')

So-called img2img. Note that `masking` only works when `Dfm` ends with `_inpaint`

In [None]:
init_img='xxx' #@param {type:'string'}
strength=0.5 #@param {type:'number'}
EnLarge = [     ]  #@param {type:'raw'}
UnCrop4=[     ] #@param {type:'raw'}
# example UnCrop4 value: 200,300,250,330
# initymgtyp: 0 latent .npy, 1 jpg/png image, 2 .txt latent sequence, 99 none
initymgtyp=init_img_type()
if initymgtyp == 0:
  preimg=torch.tensor(np.load(init_img), device='cpu')
  n_samples=preimg.size(0)
  # latent -> pixel size
  H=preimg.size(2)<<3
  W=preimg.size(3)<<3
  H, W, seed_size = EnLarge2seed_size(EnLarge)
elif initymgtyp == 1:
  n_samples=1
  rpt=load_img(init_img)
  H=rpt.size(2)
  W=rpt.size(3)
  # encode the image to a scaled latent and cache it next to the image
  preimg=imgenc(  rpt, torch.randn(torch.Size([n_samples,4,H>>3,W>>3]))  )*0.18215
  np.save(init_img+'.npy',preimg.numpy())
  H, W, seed_size = EnLarge2seed_size(EnLarge)
elif initymgtyp == 2:
  # sequence of .bin latents written by encodepatt(); file = "<shape>\n<count>"
  tmpfeeder=ifeeder()
  tmpfeeder.pattern=init_img[:-3].replace('!@!','%')+'bin'
  with open(init_img,'rt') as f:
    stz=f.read().replace(' ','').replace('\t','').splitlines()
  tmpfeeder.xpenlen=int(stz[1])
  stz=stz[0].split(',')
  tmpfeeder.shape=[ int(stz[0]), int(stz[1]), int(stz[2]), int(stz[3]) ]
  n_samples=tmpfeeder.shape[0]
  H=tmpfeeder.shape[2]<<3
  W=tmpfeeder.shape[3]<<3
  # command tensor understood by predict(): 2 = feed the .bin sequence
  preimg=torch.tensor([2])
else:
  hlog0.setfuncb('0')
  preimg=None
revpreimg=None

# 4+ UnCrop4 values switch to outpainting mode (initymgtyp 10)
if initymgtyp<2:
  if len(UnCrop4)>3:
    strength=999
    initymgtyp=10
    UnCrop4=mkUnCrop4()

masking=False #@param {type:'boolean'}
if (masking or INP) and strength<1:
  mask_path='msk.png'  #@param {type:'string'}
  image_embed=do_masking(mask_path)
  UnCrop4=[]
  initymgtyp=0
  # strength 0: start from pure noise; the masked latent only conditions via image_embed
  if strength == 0:
    preimg=None

Prompt interpolation with latent re-feeding<br>tick `Revert2Orig` to disable it

In [None]:
Revert2Orig=False #@param {type:'boolean'}


def predict(c_list, uc, noi=None):
    """Render one sample per conditioning in `c_list`, re-feeding each result.

    The first conditioning starts from pure noise. Every later one starts from
    the previous sample, noised back up to ``strength`` of the sigma schedule
    (latent re-feeding), which yields a prompt interpolation.

    Args:
      c_list: conditioning objects; each is ``reset()`` and given a
        ``d_sta`` attribute before being passed to ``zemp0``.
      uc: unconditional conditioning; unused here, kept so all ``predict``
        variants share one signature.
      noi: optional starting noise. When None, fresh Gaussian noise of shape
        ``shape[:4]`` is drawn on ``cudev``.

    Side effects: sets the globals ``x_samples``, ``samples``, ``ktta`` and
    ``noise``, and starts one ``saver`` thread per rendered sample.
    """
    global x_samples
    global samples
    global ktta
    global tmpfeeder
    global noise
    preimg = None  # local: re-feeding always starts fresh on each call

    feeder = ifeeder()
    sigmas = f_sigmas()

    # Entries beyond the first four of `shape` are passed to zemp0 as "alters".
    shape_alters = []
    if len(shape) > 4:
        shape_r = shape[:4]
        shape_alters = shape[4:]
    else:
        shape_r = shape

    if noi is None:
        noise = torch.randn(shape_r, dtype=torch.float, device=cudev)
    else:
        noise = noi

    ktta = 0
    for c in c_list:
        c.reset()
        # Denoising steps for a re-fed start. A strength that rounds to zero
        # steps used to crash on `t_enc / vt_enc`; it now disables re-feeding.
        vt_enc = int(strength * ddim_num_steps) if preimg is not None else 0
        if vt_enc > 0:
            # Rearrange the same noise values into a new pattern of the same
            # shape, so each re-fed step sees different noise.
            noise = torch.permute(noise, (0, 3, 1, 2)).reshape(noise.shape)
            c.d_sta = t_enc / vt_enc
            sigma_sched = sigmas[ddim_num_steps - vt_enc - 1:]
            if preimg.shape != noise.shape:
                print('resiz')
                # interpolate() takes the *spatial* output size; passing the
                # full 4-D shape raised an error.
                preimg = torch.nn.functional.interpolate(
                    preimg, size=noise.shape[-2:], mode='area')
            img = preimg + noise * sigma_sched[0]
        else:
            img = noise * sigmas[0]
            sigma_sched = sigmas
        feeder.setbs(img)

        samples = zemp0(feeder.getn(ktta), c, sigma_sched, shape_alters)

        ktta += 1
        preimg = samples
        x_samples = decode_first_stage(samples).cpu()
        samples = samples.cpu()

        # NOTE(review): saver runs concurrently and presumably reads the
        # x_samples/samples globals, which the next iteration overwrites.
        Thread(target=saver).start()

    return
# Keep predict_orig when Revert2Orig is ticked. Otherwise keep the predict
# above and clear the module-level init latent, so later cells see no init image.
if Revert2Orig:
  predict=predict_orig
else:
  preimg=None
  strength=0.75 #@param {type:'number'}

NoiseMap interpolation<br>Re-feeds the previous result when `strength` > 0

In [None]:
Revert2Orig=False #@param {type:'boolean'}


def mknoises():
  """Build a looping sequence of interpolated starting-noise tensors.

  Reads the module-level ``Seed_Interval_list`` as flat ``seed, interval``
  pairs. One noise tensor of shape ``shape[:4]`` is drawn per seed; a seed
  < 1 is replaced by a random one, which is printed. Consecutive noises are
  then joined with ``interval + 1`` frames per hop, using spherical
  interpolation (plain linear interpolation when the two noises are nearly
  parallel). The last hop returns to the first seed, so the sequence loops.

  Returns:
    (interpos, shape_alters): the list of noise frames, and the entries of
    ``shape`` beyond the first four.

  Raises:
    ValueError: if ``Seed_Interval_list`` holds no complete seed,interval pair.
  """
  # `random` is not imported in the visible notebook header.
  import random

  sil = len(Seed_Interval_list) >> 1
  if sil < 1:
    raise ValueError('Seed_Interval_list needs at least one seed,interval pair')

  shape_alters = []
  if len(shape) > 4:
    shape_r = shape[:4]
    shape_alters = shape[4:]
  else:
    shape_r = shape

  nolist = []
  for n in range(sil):
    zeed = Seed_Interval_list[n * 2]
    if zeed < 1:
      zeed = random.randint(0, 2**32)
      print('seed%d=' % n)
      print(zeed)
    torch.manual_seed(zeed)
    nolist.append(torch.randn(shape_r, dtype=torch.float, device=cudev))
  # Close the loop: the last hop goes back to the first seed's noise.
  nolist.append(nolist[0])

  interpos = []
  DOT_THRESHOLD = 0.9995
  for n in range(sil):
    stp = Seed_Interval_list[n * 2 + 1] + 1
    v0 = nolist[n]
    v1 = nolist[n + 1]
    dot = torch.sum(v0 * v1 / (torch.linalg.norm(v0) * torch.linalg.norm(v1)))
    # t runs over [0, 1): the endpoint is the next hop's first frame.
    if torch.abs(dot) > DOT_THRESHOLD:
      for j in range(stp):
        t = j / stp
        interpos.append((1 - t) * v0 + t * v1)
    else:
      theta_0 = torch.acos(dot)
      sin_theta_0 = torch.sin(theta_0)
      for j in range(stp):
        t = j / stp
        theta_t = theta_0 * t
        sin_theta_t = torch.sin(theta_t)
        s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        interpos.append(s0 * v0 + s1 * v1)

  return interpos, shape_alters



def predict(c_list, uc, noi=None):
    """Render the noise-map interpolation built by `mknoises`.

    Every conditioning in ``c_list`` is rendered once per noise frame:
    iterations ``[j*len(c_list), (j+1)*len(c_list))`` all start from noise
    frame ``j``. When ``strength`` maps to at least one denoising step, each
    render after the first starts from the previous sample, re-noised to
    ``strength`` of the sigma schedule. Otherwise every render starts from
    pure noise.

    Args:
      c_list: conditioning objects; each is ``reset()`` and may receive a
        ``d_sta`` attribute before being passed to ``zemp0``.
      uc: unconditional conditioning; unused, kept for a shared signature.
      noi: unused; kept for signature compatibility (noise comes from
        ``mknoises``).

    Side effects: sets the globals ``x_samples``, ``samples``, ``ktta`` and
    ``noise``, and starts one ``saver`` thread per rendered sample.
    """
    global x_samples
    global samples
    global ktta
    global tmpfeeder
    global noise
    preimg = None

    feeder = ifeeder()
    sigmas = f_sigmas()

    noises, shape_alters = mknoises()
    n_prompts = len(c_list)
    c_list = c_list * len(noises)

    ktta = 0
    for c in c_list:
        c.reset()
        # Indexing noises[ktta] ran past the end whenever c_list held more
        # than one conditioning; each block of n_prompts shares one frame.
        noise = noises[ktta // n_prompts]

        if preimg is not None:
            # preimg is only kept when this is >= 1 (see the end of the loop).
            vt_enc = int(strength * ddim_num_steps)
            c.d_sta = t_enc / vt_enc
            sigma_sched = sigmas[ddim_num_steps - vt_enc - 1:]
            if preimg.shape != noise.shape:
                print('resiz')
                preimg = torch.nn.functional.interpolate(
                    preimg, size=noise.shape[-2:], mode='area')
            img = preimg + noise * sigma_sched[0]
        else:
            img = noise * sigmas[0]
            sigma_sched = sigmas
        feeder.setbs(img)

        samples = zemp0(feeder.getn(ktta), c, sigma_sched, shape_alters)

        ktta += 1
        # Re-feed only when strength maps to at least one denoising step.
        # 0 < strength < 1/ddim_num_steps used to crash on t_enc / vt_enc.
        if int(strength * ddim_num_steps) > 0:
            preimg = samples
        x_samples = decode_first_stage(samples).cpu()
        samples = samples.cpu()

        # NOTE(review): saver runs concurrently and presumably reads the
        # x_samples/samples globals, which the next iteration overwrites.
        Thread(target=saver).start()

    return
# Keep predict_orig when Revert2Orig is ticked. Otherwise keep the predict
# above, clear the module-level init latent, and expose its form options.
# Seed_Interval_list holds flat `seed, interval` pairs: the frames go
# seed1 -> seed2 -> ... -> back to seed1, with `interval` in-between frames
# per hop. A seed < 1 means "pick a random seed".
if Revert2Orig:
  predict=predict_orig
else:
  preimg=None
  strength=0 #@param {type:'number'}
  Seed_Interval_list=[    775577,10,    881188,10,    996699,10    ] #@param {type:'raw'}

☝️Optional☝️

In [None]:
# Main generation cell: read the form, build the conditionings, then run the
# selected runner (wpa_*), inline or in a background thread when InThread is set.
InThread=False #@param {type:'boolean'}

prompt = 'a photograph of an astronaut riding a horse' #@param {type:'string'}
neg_prompt = '' #@param {type:'string'}

n_iter = 1 #@param {type:'integer'}
# Batch/size come from the form only when no init latent is loaded; otherwise
# the img2img cell has already set n_samples/H/W.
if preimg is None and revpreimg is None:
  n_samples = 1 #@param {type:'integer'}
  H=704 #@param {type:'integer'}
  W=768 #@param {type:'integer'}
  seed_size=[     ] #@param {type:'raw'}
# Example seed_size values:
# 512,512,0.1
# 512,512,0.3, 768,512,0.6, 512,768,0.8


seed=0 #@param {type:'integer'}

ddim_num_steps = 50  #@param {type:'integer'}
ddpm_num_timesteps = 1000


outputp='/content/sample_data' #@param {type:'string'}







cfg_scale = 7.5 #@param {type:'number'}
ddim_eta = 1.0  #@param {type:'number'}

CopyDFout=True #@param {type:'boolean'}
UnifiedNoise=False #@param {type:'boolean'}
Overlap = 300  #@param {type:'integer'}
# Integer-divide by 16. wpa_wide overlaps tiles by 2*Overlap latent columns,
# i.e. roughly the form value in pixels.
Overlap=Overlap>>4

# Each run gets a new numbered path under outputp (the current entry count).
# This only builds the string; the directory is not created here.
outputp=outputp+'/'+str(len(os.listdir(outputp)))

# Steps used when starting from an init latent with strength < 1; otherwise
# all ddim_num_steps. predict reads this global.
t_enc=ddim_num_steps
if preimg is not None and strength < 1:
  t_enc = int(strength * ddim_num_steps)+1

# Sampling shape (defined elsewhere); entries past the first four are treated
# as "alters" by predict.
shape = mk_shape()


# Defined elsewhere -- presumably seeds the RNGs from `seed`; confirm.
makerng()





print('Start inference...')
c_list = makeCs(prompt,0)
uc = makeCs(neg_prompt,0,enable3d=False)[0]

# NOTE(review): SaveCompiledPrompt is not read in this cell.
PrintPromptPhases=False #@param {type:'boolean'}
SaveCompiledPrompt=False #@param {type:'boolean'}
printprompts(PrintPromptPhases)

  
def wpa_orig():
  """Plain generation: run `predict(c_list, uc)` once per iteration, `n_iter` times.

  Publishes the current iteration index through the global `iita` and frees
  the CUDA cache when done.
  """
  global iita
  torch.set_grad_enabled(False)

  for iita in range(n_iter):
    shown = iita + 1
    print("iteration: %s" % (shown,))
    predict(c_list, uc)

  print('Script finished successfully.')
  torch.cuda.empty_cache()

def wpa_wide():
  """Wide (horizontally tiled) generation from one shared noise strip.

  Renders n_iter tiles, each W pixels (halfa latent columns) wide. Each tile
  is shifted right by `halfb` latent columns, so consecutive tiles overlap by
  2*Overlap latent columns. Every tile slices its starting noise from the same
  strip (`unoise`), so overlapping columns get identical noise. After each
  tile, the `samples` global set by predict is handed to hlog0 via
  setbsB(-11, samples). hlogger is defined elsewhere and presumably uses this
  to blend the next tile's overlap -- verify.
  """
  global iita
  global halfb
  global hlog0
  torch.set_grad_enabled(False)


  halfa=(W>>3)           # tile width in latent columns
  Elap2=Overlap<<1       # overlap between consecutive tiles (latent columns)
  hlog0.Elap2=Elap2
  halfb=halfa-Elap2      # horizontal stride between tile starts
  # NOTE(review): the strip is halfb columns wider than the last tile needs
  # ((n_iter-1)*halfb + halfa). Changing its width would change the noise drawn
  # for a given seed, so it is left as-is.
  unoise=torch.randn([n_samples, 4, H>>3 , (halfb*n_iter)+halfa ], dtype=torch.float,device=cudev)
  hlog0.Elap=Overlap


  if CopyDFout:
    hlog0.h_bs=[None]*ddim_num_steps
    hlog0.setfunc('logw0')
  # First tile covers columns [0, halfa); unlike later tiles, it prints no
  # "iteration" line.
  iita=0
  predict(c_list, uc,unoise[:,:,:,0:halfa])
  if CopyDFout:
    hlog0.setfunc('logw')
  hlog0.setfuncNc(1,cache=True)
  hlog0.setbsB(-11,samples)
  for iita in range(1,n_iter):
      print("iteration: %s" % (iita + 1))
      sta=iita*halfb
      endo=sta+halfa
      predict(c_list, uc,unoise[:,:,:,sta:endo])
      hlog0.setbsB(-11,samples)

  print('Wide finished successfully.')
  torch.cuda.empty_cache()
  # Replace the global logger with a fresh instance for the next run.
  hlog0=hlogger()

def wpa_uncrop():
  """Uncrop mode: run `UnCrop(c_list, uc)` once per iteration, `n_iter` times.

  Publishes the current iteration index through the global `iita` and frees
  the CUDA cache when done.
  """
  global iita
  torch.set_grad_enabled(False)

  for iita in range(n_iter):
    shown = iita + 1
    print("iteration: %s" % (shown,))
    UnCrop(c_list, uc)

  print('UnCrop finished successfully.')
  torch.cuda.empty_cache()

# Pick the runner: uncrop mode (set up by the img2img cell) takes priority,
# then UnifiedNoise wide mode; plain generation otherwise.
if initymgtyp == 10:
  wpa = wpa_uncrop
  if not INP:
    initymgtyp = 99
  preimg = preimg.cuda()
elif UnifiedNoise:
  print('RealOverlap=%d' % (Overlap << 3))
  wpa = wpa_wide
else:
  wpa = wpa_orig

# Run inline, or in a background thread so the notebook stays usable.
if not InThread:
  wpa()
else:
  t1 = Thread(target=wpa)
  a1 = t1.start()


In [None]:
!nvidia-smi

In [None]:
!ffmpeg -framerate 3 -i /content/sample_data/48_0x3v%d.png intp03.mp4

# Tools
Designed for use while the generation process runs with `InThread` or in the Gradio app,<br>
so `imgenc` (the image→latent encoder) runs on the CPU

In [None]:
# Start a fresh latent log and switch hlog0's callback to 'log'
# (hlogger is defined elsewhere -- presumably records latents per step; verify).
hlog0.latlog=[]
hlog0.setfuncb('log')

In [None]:
# Write out the logged latents (dumplogs is defined elsewhere; the meaning of
# its argument is not shown here).
dumplogs(0)
#hlog0.latlog_arr=[]

In [None]:
os.link('/content/sample_data/64_7x0v1.png','/content/sample_data/vyi/stp00100.png')

In [None]:
hlog0.latlog=[]

In [None]:
ConcatNoise=ConcatNoise_rdm #ConcatNoise_shuf

GIF/video to latent pack

In [None]:
# Split an animation into numbered PNG frames. ffmpeg presumably will not
# create the output directory (/content/ijj) -- make sure it exists first.
input_anim  = '/content/senpai.gif' #@param {type:'string'}
output_pattern = '/content/ijj/senpai_%04d.png' #@param {type:'string'}
!ffmpeg -i {input_anim} {output_pattern}

Resize the extracted frames to `(64*n)x(64*m)` first

In [None]:
encodepatt()

In [None]:
ImageEmbSetup  = False #@param {type:'boolean'}

import os
import torch
from PIL import Image
from torchvision import transforms
def load_im(im_path):
    """Load an image as a (1, 3, 224, 224) tensor with values in [-1, 1].

    ``im_path`` may be a local path or an ``http(s)`` URL. The image is
    converted to RGB, resized so its shorter side is 224, center-cropped to
    224x224, and mapped from [0, 1] to [-1, 1].
    """
    if im_path.startswith("http"):
        # Local imports: BytesIO is never imported in this notebook, and
        # `requests` is only imported in a later cell.
        import requests
        from io import BytesIO

        response = requests.get(im_path, timeout=60)
        response.raise_for_status()
        src = BytesIO(response.content)
    else:
        src = im_path
    # Convert in both branches: RGBA/greyscale downloads would otherwise give
    # a tensor with the wrong number of channels.
    im = Image.open(src).convert("RGB")
    tforms = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop((224, 224)),
        transforms.ToTensor(),
    ])
    inp = tforms(im).unsqueeze(0)
    return inp * 2 - 1
# Download the TorchScript image-embedding model once, then load it as float32.
if not os.path.isfile('imgemb.pt'):
  !wget https://huggingface.co/Larvik/imgemb/resolve/main/imgemb.pt
imgemb=torch.jit.load('imgemb.pt').float()

In [None]:
imgemb(load_im('/content/chaz512.jpg')).numpy().tofile('chaz.bin')

# Gradio Gui
Though I don't really understand why you'd want a web UI inside another web UI

In [None]:
# NOTE(review): the `gradio` checkbox is not read here; gradio is always
# installed and imported by this cell.
gradio=False #@param {type:'boolean'}

!pip install gradio
from google.colab import output
import gradio as gr

def dream(*inputs):
  """Placeholder Gradio callback; image generation is not wired up yet.

  Gradio calls this with one positional value per component in
  `dream_interface`'s `inputs`: prompt, steps, PLMS flag, fixed-code flag,
  eta, iterations, samples per iteration, CFG scale, seed, height and width.
  The old zero-argument stub therefore raised TypeError on every click.

  Returns:
    (gallery_images, seed): an empty gallery, plus the requested seed echoed
    back (-1 when not supplied), so both declared outputs receive a value.
  """
  seed = inputs[8] if len(inputs) > 8 else -1
  return [], seed


# Text-to-image UI. Gradio passes one value per input component below to
# `dream`, in this order; the two outputs are a gallery and the used seed.
dream_interface = gr.Interface(
    dream,
    inputs=[
        gr.Textbox(placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
        gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=50),
        gr.Checkbox(label='Enable PLMS sampling', value=False),
        gr.Checkbox(label='Enable Fixed Code sampling', value=False),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="DDIM ETA", value=0.0, visible=False),
        gr.Slider(minimum=1, maximum=50, step=1, label='Sampling iterations', value=8),
        gr.Slider(minimum=1, maximum=8, step=1, label='Samples per iteration', value=1),
        gr.Slider(minimum=1.0, maximum=20.0, step=0.5, label='Classifier Free Guidance Scale', value=7.0),
        gr.Number(label='Seed', value=-1),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=704),
        gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=768),
    ],
    outputs=[
        gr.Gallery(),
        gr.Number(label='Seed')
    ],
    title="Stable Diffusion Text-to-Image",
    description="Generate images from text with Stable Diffusion",
)


gdemo = gr.TabbedInterface(interface_list=[dream_interface], tab_names=["Dream"])


# Print a Colab proxy link for kernel port 8233 (path /dl.htm).
# NOTE(review): gdemo.launch() in the next cell passes no port -- confirm which
# server actually listens on 8233.
output.serve_kernel_port_as_window(8233, path='/dl.htm')

Copy the link printed above into `GoogleLocal`

In [None]:
# Launch the UI only once a Colab proxy URL has been pasted. GoogleLocal is
# only checked for the '.googleusercontent.com' domain; its value is not passed on.
GoogleLocal = 'aaaaa' #@param {type:'string'}
if '.googleusercontent.com' in GoogleLocal:
  gdemo.launch()
else:
  print('set a valid GoogleLocal')

# glid-3-xl-stable

In [None]:
# Checkpoint version and fine-tune suffix, passed to f_dljit below;
# 'Orig' means no suffix.
SDver='470k' #@param ['440k', '470k']
Dfm='Orig' #@param ['Orig', '_imgemb','_a19561','_a17750','_a17750_e9750','_e26500']
if Dfm=='Orig':
  Dfm=''
import os
import torch
from torch import nn
from torch.nn import functional as F


# Inference-only session: turn autograd off and let torch use every CPU core.
torch.set_grad_enabled(False)
torch.set_num_threads(os.cpu_count())

# Favour speed over bit-exactness on the GPU: cuDNN autotuning plus TF32 and
# reduced-precision fp16 reductions.
_cudnn = torch.backends.cudnn
_cudnn.enabled = _cudnn.allow_tf32 = _cudnn.benchmark = True
_matmul = torch.backends.cuda.matmul
_matmul.allow_tf32 = _matmul.allow_fp16_reduced_precision_reduction = True
del _cudnn, _matmul





class imgencdec:
  """Minimal encode/decode adapter over the TorchScript VAE halves.

  `encode` calls the module-level `imgenc` with the image and a fresh
  Gaussian noise tensor at 1/8 of its spatial size; `decode` calls the
  module-level `autoencoder`.
  """

  def encode(self, im):
    batch, height, width = im.size(0), im.size(2), im.size(3)
    latent_noise = torch.randn(torch.Size([batch, 4, height >> 3, width >> 3]))
    return imgenc(im, latent_noise)

  def decode(self, im):
    return autoencoder(im)


SDver=f_dljit(SDver,Dfm)

# One-time fetch of jkt.py and of glid-3-xl-stable's guided_diffusion package
# (skipped once /content/guided_diffusion/unet.py exists).
if not os.path.isfile('/content/guided_diffusion/unet.py'):
  !wget https://raw.githubusercontent.com/TabuaTambalam/DalleWebms/main/docs/sd/jkt.py
  !git clone https://github.com/Jack000/glid-3-xl-stable.git
  !mv /content/glid-3-xl-stable/guided_diffusion /content/guided_diffusion 

from transformers import CLIPTokenizer
# Load the TorchScript modules: the text encoder, three diffusion parts
# (emb / mid / out, on the GPU) and the latent decoder (GPU).
# The image encoder stays on the CPU (no .cuda()).
cond_stage_model = BERTEmbedder(torch.jit.load('transformer_pnnx.pt').eval())
diffusion_emb = torch.jit.load(SDver+'diffusion_emb_pnnx.pt').eval().cuda()
diffusion_mid = torch.jit.load(SDver+'diffusion_mid_pnnx.pt').eval().cuda()
diffusion_out = torch.jit.load(SDver+'diffusion_out_pnnx.pt').eval().cuda()
autoencoder = torch.jit.load('autoencoder_pnnx.pt').eval().cuda()
SDlatDEC=autoencoder
imgenc = torch.jit.load('imgencoder_pnnx.pt').eval()

In [None]:
#https://huggingface.co/Jack000/glid-3-xl-stable/tree/main/super_lg
import gc
import io
import math
import sys

from PIL import Image, ImageOps
import requests

from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm

import numpy as np

from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults


from accelerate import init_empty_weights
from einops import rearrange
from math import log2, sqrt


!mkdir output_npy
!mkdir output

def save_sample(i, sample, clip_score=False):
    """Decode and save the first image of a sampler output.

    Writes the unscaled latent to ``output_npy/NNNNN.npy`` and the decoded
    image to ``output/NNNNN.png``, where NNNNN is ``i * batchsz + k``.

    Args:
      i: sampling-run index, used to number the output files.
      sample: sampler output dict; only ``sample['pred_xstart']`` is read.
      clip_score: unused; kept for interface compatibility.
    """
    for k, image in enumerate(sample['pred_xstart'][:1]):
        # Undo the 0.18215 latent scaling without touching the caller's
        # tensor (the old in-place `/=` rewrote sample['pred_xstart']).
        image = image / 0.18215
        out = ldm.decode(image.unsqueeze(0))

        npy_filename = f'output_npy/{i * batchsz + k:05}.npy'
        with open(npy_filename, 'wb') as outfile:
            np.save(outfile, image.detach().cpu().numpy())

        # Map the decoder output from (presumably) [-1, 1] to [0, 1] for PIL.
        out = TF.to_pil_image(out.squeeze(0).add(1).div(2).clamp(0, 1))

        filename = f'output/{i * batchsz + k:05}.png'
        out.save(filename)


# Create a classifier-free guidance sampling function
def model_fn(x_t, ts, **kwargs):
    """Classifier-free-guidance wrapper around the global ``model``.

    The batch is split in half. The first half of ``x_t`` is duplicated and
    run through ``model`` with ``kwargs``; in the sampling cell, the first half
    of ``kwargs['context']`` is the prompt embedding and the second half the
    negative one. The eps channels are then combined as
    ``uncond + guidance_scale * (cond - uncond)``, and any extra output
    channels (e.g. a learned variance) pass through unchanged.
    """
    half = x_t[: len(x_t) // 2]
    combined = torch.cat([half, half], dim=0)
    model_out = model(combined, ts, **kwargs)
    # The eps prediction has as many channels as the input. A hard-coded 3
    # (an RGB pixel-space assumption) left the 4th channel of the 4-channel
    # latents sampled here unguided.
    n_ch = x_t.shape[1]
    eps, rest = model_out[:, :n_ch], model_out[:, n_ch:]
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = torch.cat([half_eps, half_eps], dim=0)
    return torch.cat([eps, rest], dim=1)

device = torch.device('cuda:0')
print('Using device:', device)



# Hyper-parameters for guided_diffusion's create_model_and_diffusion; the
# remaining keys come from model_and_diffusion_defaults() below.
model_params = {
    'attention_resolutions': '32,16,8',
    'class_cond': False,
    'diffusion_steps': 1000,
    'rescale_timesteps': True,
    'timestep_respacing': '50',  # Modify this value to decrease the number of
                                 # timesteps.
                                 # NOTE: overridden to '100' right after this dict.
    'image_size': 32,
    'learn_sigma': False,
    'noise_schedule': 'linear',
    'num_channels': 320,
    'num_heads': 8,
    'num_res_blocks': 2,
    'resblock_updown': False,
    'use_fp16': False,  # NOTE: overridden to True below.
    'use_scale_shift_norm': False,
    'clip_embed_dim': None, #768,
    'image_condition': False,
    #'image_condition': True if model_state_dict['input_blocks.0.0.weight'].shape[1] == 8 else False,
    'super_res_condition': False,
}

model_params['timestep_respacing'] = '100'

model_config = model_and_diffusion_defaults()
model_config.update(model_params)


model_config['use_fp16'] = True

# Load models: build the skeleton under accelerate's init_empty_weights, then
# fill in the weights from mkmodel_state_dict() (both helpers are defined
# elsewhere).
with init_empty_weights():
  model, diffusion = create_model_and_diffusion(**model_config)

load_state_dict_with_low_memory(model,mkmodel_state_dict())

if model_config['use_fp16']:
  model.convert_to_fp16()

In [None]:

# Freeze the model, switch it to eval mode and move it to the GPU.
model.requires_grad_(False).eval().to(device)


# Fixed seed for reproducible sampling.
torch.manual_seed(114514)


# vae

# Encode/decode adapter used by save_sample.
ldm=imgencdec()


# Sampling settings: model_fn reads guidance_scale; save_sample reads batchsz.
guidance_scale=7
height=832
width=896
batchsz=1


args_text='thicc farm girl, long blonde hair, japanimation, by Alfons Maria Mucha, cinematic lightning, cinematic wallpaper'
args_negative=''
# clip context

# NOTE(review): n_samples and t_enc are not used in this cell; makeCs
# (defined elsewhere) presumably reads them as globals -- confirm.
n_samples=batchsz
t_enc=100
text_emb = makeCs(args_text)[0].get(0)
text_emb_blank = makeCs(args_negative)[0].get(0)

image_embed = None



# Only referenced by the disabled code in the string below and in the
# commented-out image_embed expression; currently unused.
input_image = torch.zeros(batchsz, 4, height//8, width//8, device=device)
'''
lat=torch.tensor(np.load('96_4x1v1.npy'))


input_image[0][:,:,:32]=lat[0][:,:,:32]
'''

      
image_embed = None #torch.cat(batchsz*2*[input_image], dim=0).float()



# Prompt context first, negative context second: model_fn treats the first
# half of the batch as the conditional one.
kwargs = {
    "context": torch.cat([text_emb, text_emb_blank], dim=0).half().cuda(),
    "clip_embed": None,
    "image_embed": image_embed
}



cur_t = None

sample_fn = diffusion.plms_sample_loop_progressive



'''
init = Image.open('xipooh.jpg').convert('RGB')

init = TF.to_tensor(init).to(device).unsqueeze(0).clamp(0,1)
h = ldm.encode(init * 2 - 1) *  0.18215
init = torch.cat(1*2*[h], dim=0)
'''
# The init-image path above is disabled; sampling starts from pure noise.
init=None

# Run the PLMS sampler once and save its final prediction.
for i in range(1):
    cur_t = diffusion.num_timesteps - 1
    # sample_fn returns an iterator that does the model work step by step as
    # it is consumed. The consuming loop must therefore sit inside autocast;
    # before, autocast exited right after the iterator was created and never
    # applied to the sampling.
    sample = None  # don't save a stale `sample` from an earlier run
    with torch.cuda.amp.autocast(dtype=torch.float16):
      samples = sample_fn(
          model_fn,
          (batchsz*2, 4, height>>3, width>>3),
          clip_denoised=False,
          model_kwargs=kwargs,
          cond_fn=None,
          device=device,
          progress=True,
          init_image=init,
          skip_timesteps=0,
      )

      for j, sample in enumerate(samples):
          cur_t -= 1

    if sample is None:
      raise RuntimeError('sampler produced no steps')
    save_sample(i, sample)
torch.cuda.empty_cache()
