In [None]:
# Install the Python packages this notebook needs (only triton is version-pinned).
!pip install ftfy transformers omegaconf triton==2.0.0.dev20220701 einops accelerate taming-transformers-rom1504
# Download and unpack the vrs_lib archive. Presumably this provides the `lib`
# package and the `vdoff` module imported by later cells -- TODO confirm
# against the archive contents.
!wget https://github.com/TabuaTambalam/DalleWebms/releases/download/0.1/vrs_lib.7z
!7z x vrs_lib.7z
# Download the checkpoint and unpack it with 7z. The weight-loading cell reads
# raw tensor files from 'archive/data/', which is presumably what this
# extraction produces -- TODO confirm.
!wget https://huggingface.co/spaces/shi-labs/Versatile-Diffusion/resolve/main/pretrained/vd-official.pth
!7z x vd-official.pth

In [None]:
import os
import PIL
from PIL import Image
from pathlib import Path
import numpy as np
import numpy.random as npr
from contextlib import nullcontext

import torch
import torchvision.transforms as tvtrans
from lib.cfg_helper import model_cfg_bank
from lib.model_zoo import get_model
from lib.model_zoo.ddim_vd import DDIMSampler_VD, DDIMSampler_VD_DualContext
from lib.model_zoo.ddim_dualcontext import DDIMSampler_DualContext

from lib.experiments.sd_default import color_adjust

from accelerate import init_empty_weights

# --- Model selection --------------------------------------------------------
cfgm_name = 'vd_noema'
model_name = cfgm_name
# Checkpoint path. It is not read by any code visible in this notebook; the
# weight-loading cell pulls tensors from the extracted 'archive/data/' files.
pth = 'pretrained/vd-official.pth'

# --- Constants shared by the helper functions -------------------------------
BICUBIC = PIL.Image.BICUBIC
use_cuda = torch.cuda.is_available()

# --- Network construction ---------------------------------------------------
# The network is built inside accelerate's init_empty_weights() context; the
# actual parameter values are attached later by the weight-loading cell.
cfgm = model_cfg_bank()(cfgm_name)
with init_empty_weights():
    net = get_model()(cfgm)

# DDIM sampler wrapping the network; used by every generation helper.
sampler = DDIMSampler_VD(net)

In [None]:
# Colab form field.
# NOTE(review): the value chosen in the form is overwritten unconditionally by
# the assignment below, so the toggle currently has no effect -- confirm
# whether a separate flag name was intended.
metadev=False #@param {type:'boolean'}

# Handle for PyTorch's 'meta' device. Not referenced by the other cells
# visible in this notebook.
metadev=torch.device('meta')

def get_keys_to_submodule(model):
    """Map each parameter's state-dict key to the module that directly owns it.

    Args:
        model: a ``torch.nn.Module``.

    Returns:
        dict: ``{key: owning_module}`` where ``key`` is
        ``"<module path>.<parameter name>"``. Parameters registered directly
        on ``model`` itself are keyed by their bare name (no leading dot), so
        every key matches the corresponding ``state_dict`` key.
    """
    keys_to_submodule = {}
    for submodule_name, submodule in model.named_modules():
        # recurse=False yields only the parameters owned by this module, so
        # each parameter is visited once instead of once per ancestor.
        for param_name, _ in submodule.named_parameters(recurse=False):
            # The root module's name is '', which must not produce ".name".
            if submodule_name:
                key = f"{submodule_name}.{param_name}"
            else:
                key = param_name
            keys_to_submodule[key] = submodule

    return keys_to_submodule


def load_state_dict_with_low_memory(model, state_dict,modifyfunc=None,fill=True):
    """Attach tensors from ``state_dict`` to ``model`` one parameter at a time.

    Each matching parameter is replaced on its owning module (via ``setattr``)
    by a new ``torch.nn.Parameter`` with ``requires_grad=False`` wrapping the
    tensor from ``state_dict``. Only parameters are handled; entries of
    ``state_dict`` that do not correspond to a parameter of ``model`` are
    ignored.

    Args:
        model: the ``torch.nn.Module`` to populate.
        state_dict: mapping from parameter key to tensor.
        modifyfunc: optional callable applied to ``state_dict`` first; its
            return value is used in place of ``state_dict``.
        fill: when true, print the key of every model parameter that has no
            entry in ``state_dict``. Missing parameters are left untouched
            either way.

    Returns:
        None. ``model`` is modified in place.
    """
    if modifyfunc is not None:
        state_dict = modifyfunc(state_dict)
    print('======hacky load======')

    for key, submodule in get_keys_to_submodule(model).items():
        # Parameters owned by the root module may be reported with a leading
        # '.', which never appears in state-dict keys.
        if key.startswith('.'):
            key = key[1:]

        if key not in state_dict:
            if fill:
                print(key)
            continue

        param_name = key.rsplit('.', 1)[-1]
        new_val = torch.nn.Parameter(state_dict[key], requires_grad=False)
        setattr(submodule, param_name, new_val)

def savcmpl(fna,z):
    """Save the first element of ``z`` to ``<fna>.compiled_prompt``.

    The file is written with ``torch.save`` and contains the dict
    ``{0: True, 1: z[0]}``.
    """
    payload = {0: True, 1: z[0]}
    torch.save(payload, fna + '.compiled_prompt')

def regularize_image(x):
    """Convert an image given in one of several forms to a 512x512 tensor.

    Args:
        x: a file path (``str``), a ``PIL.Image.Image``, a ``numpy.ndarray``
            accepted by ``PIL.Image.fromarray``, or a ``torch.Tensor``.
            Non-tensor inputs are converted to RGB, resized to 512x512 with
            bicubic resampling and passed through ``ToTensor``. Tensors are
            used as given and are only size-checked.

    Returns:
        torch.Tensor: the image tensor, moved to CUDA when ``use_cuda`` is set.

    Raises:
        AssertionError: if ``x`` has an unsupported type, or if dimensions 1
            and 2 of the resulting tensor are not both 512.
    """
    if not isinstance(x, torch.Tensor):
        if isinstance(x, str):
            # Close the file handle as soon as the pixels have been read.
            with Image.open(x) as opened:
                img = opened.convert('RGB')
        elif isinstance(x, PIL.Image.Image):
            img = x.convert('RGB')
        elif isinstance(x, np.ndarray):
            img = PIL.Image.fromarray(x).convert('RGB')
        else:
            # Raised explicitly (instead of `assert False`) so the check is
            # not stripped when Python runs with -O; the type is unchanged.
            raise AssertionError('Unknown image type')

        # Converting to RGB above guarantees 3 channels for grayscale,
        # palette and RGBA inputs. Presumably the vision encoder downstream
        # expects 3-channel input -- TODO confirm.
        img = img.resize([512, 512], resample=BICUBIC)
        x = tvtrans.ToTensor()(img)

    if x.shape[1] != 512 or x.shape[2] != 512:
        raise AssertionError('Wrong image size')
    if use_cuda:
        x = x.to('cuda')
    return x

def find_low_rank(x, demean=True, q=20, niter=10):
    """Rebuild a batch of matrices from their ``q`` leading PCA components.

    Args:
        x: 3-D tensor ``(batch, m, n)``.
        demean: when true, the mean over the last dimension is subtracted
            before the decomposition and added back to the result.
        q: number of components requested from ``torch.pca_lowrank``.
        niter: number of iterations passed to ``torch.pca_lowrank``.

    Returns:
        torch.Tensor: ``U @ diag(S) @ V^T`` per batch entry (plus the mean
        when ``demean`` is true), with the same shape as ``x``.
    """
    row_mean = x.mean(-1, keepdim=True) if demean else None
    centered = x - row_mean if demean else x

    left, sing, right = torch.pca_lowrank(centered, q=q, center=False, niter=niter)
    # One diagonal matrix per batch entry, built from its singular values.
    sing_mat = torch.diag_embed(sing)
    approx = torch.bmm(torch.bmm(left, sing_mat), right.transpose(1, 2))

    if demean:
        approx += row_mean
    return approx

def remove_low_rank(x, demean=True, q=20, niter=10, q_remove=10):
    """Rebuild a batch of matrices from PCA components, dropping the leading ones.

    The ``q`` leading components are computed with ``torch.pca_lowrank``; the
    first ``q_remove`` singular values are set to zero and the matrices are
    rebuilt from what remains.

    Args:
        x: 3-D tensor ``(batch, m, n)``.
        demean: when true, the mean over the last dimension is subtracted
            before the decomposition and added back to the result.
        q: number of components requested from ``torch.pca_lowrank``.
        niter: number of iterations passed to ``torch.pca_lowrank``.
        q_remove: how many leading singular values to zero out.

    Returns:
        torch.Tensor: the rebuilt matrices (plus the mean when ``demean`` is
        true), with the same shape as ``x``.
    """
    if demean:
        row_mean = x.mean(-1, keepdim=True)
        centered = x - row_mean
    else:
        centered = x

    left, sing, right = torch.pca_lowrank(centered, q=q, center=False, niter=niter)
    # Discard the strongest components for every batch entry.
    sing[:, :q_remove] = 0
    sing_mat = torch.diag_embed(sing)
    rebuilt = torch.bmm(torch.bmm(left, sing_mat), right.transpose(1, 2))

    if demean:
        rebuilt += row_mean
    return rebuilt


def decode(z, xtype, ctype, color_adj='None', color_adj_to=None):
    """Decode sampled latents ``z`` into images or text.

    Args:
        z: latents produced by the sampler.
        xtype: ``'image'`` or ``'text'``; any other value yields ``None``.
        ctype: conditioning type. Colour adjustment is applied only when this
            is exactly ``'vision'``.
        color_adj: ``None`` or ``'None'`` disables colour adjustment;
            ``'Simple'`` selects the simple mode of ``color_adjust``; any
            other value enables the non-simple mode.
        color_adj_to: reference passed to ``color_adjust`` as ``ref_to``.

    Returns:
        For ``'image'``: a list with one entry per sample -- PIL images, or
        whatever the ``color_adjust`` callable returns when colour adjustment
        is active. For ``'text'``: a list of strings in which runs of
        identical adjacent words are collapsed to a single word.
    """
    if xtype == 'image':
        decoded = net.autokl_decode(z)

        keep_ratio = 0.5
        wants_adjust = not (color_adj == 'None' or color_adj is None)
        use_simple = color_adj == 'Simple'

        if wants_adjust and ctype == 'vision':
            adjusted = []
            for sample in decoded:
                # Map the sample from [-1, 1] to [0, 1] before adjusting.
                adjust_fn = color_adjust(ref_from=(sample+1)/2, ref_to=color_adj_to)
                adjusted.append(
                    adjust_fn((sample+1)/2, keep=keep_ratio, simple=use_simple))
            return adjusted

        unit_range = torch.clamp((decoded+1.0)/2.0, min=0.0, max=1.0)
        to_pil = tvtrans.ToPILImage()
        return [to_pil(sample) for sample in unit_range]

    if xtype == 'text':
        sentences = net.optimus_decode(z, temperature=1.0)
        merged = []
        for sentence in sentences:
            words = sentence.split()
            # Keep a word unless it repeats the one immediately before it.
            kept = [w for i, w in enumerate(words) if i == 0 or w != words[i-1]]
            merged.append(' '.join(kept))
        return merged

    return None

def _filter_local_tokens(emb, level):
    """Apply the low-rank filter selected by ``level`` to all tokens but the first."""
    # The first token along dim 1 is passed through untouched (the original
    # code calls it the "global" token); the rest are filtered.
    glb = emb[:, 0:1]
    loc = emb[:, 1:]

    if level == -1:
        loc = remove_low_rank(loc, demean=True, q=50, q_remove=1)
    elif level == -2:
        loc = remove_low_rank(loc, demean=True, q=50, q_remove=2)
    elif level == 1:
        loc = find_low_rank(loc, demean=True, q=10)
    elif level == 2:
        loc = find_low_rank(loc, demean=True, q=2)
    # Any other level leaves the tokens unchanged.

    return torch.cat([glb, loc], dim=1)


def application_disensemble(cin, n_samples=1, level=0, color_adj=None,
                            scale=7.5, ddim_steps=50, ddim_eta=0.0):
    """Generate images conditioned on a low-rank-filtered encoding of ``cin``.

    Args:
        cin: input image in any form accepted by ``regularize_image``.
        n_samples: number of images to generate.
        level: filter applied to the vision embedding (all tokens except the
            first): ``0`` none; ``-1`` / ``-2`` zero the 1 / 2 leading of 50
            PCA components; ``1`` / ``2`` keep only the 10 / 2 leading
            components. Other values behave like ``0``.
        color_adj: forwarded to ``decode``.
        scale: unconditional guidance scale. When it equals ``1.0`` no
            unconditional embedding is computed and ``None`` is passed to the
            sampler.
        ddim_steps: number of sampler steps.
        ddim_eta: ``eta`` passed to the sampler.

    Returns:
        The list returned by ``decode`` for ``xtype='image'``.
    """
    cin = regularize_image(cin)
    # Rescale from [0, 1] to [-1, 1] and replicate once per requested sample.
    ctemp = cin*2 - 1
    ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)

    c = net.clip_encode_vision(ctemp)
    u = None
    if scale != 1.0:
        # The unconditional embedding is the encoding of an all-zero image.
        dummy = torch.zeros_like(ctemp)
        u = net.clip_encode_vision(dummy)

    if level != 0:
        c = _filter_local_tokens(c, level)
        # u is None when guidance is disabled; there is nothing to filter then.
        if u is not None:
            u = _filter_local_tokens(u, level)

    h, w = [512, 512]
    shape = [n_samples, 4, h//8, w//8]
    z, _ = sampler.sample(
        steps=ddim_steps,
        shape=shape,
        conditioning=c,
        unconditional_guidance_scale=scale,
        unconditional_conditioning=u,
        xtype='image', ctype='vision',
        eta=ddim_eta,
        verbose=False,)
    x = decode(z, 'image', 'vision', color_adj=color_adj, color_adj_to=cin)
    return x

def inference(xtype, cin, ctype, scale=7.5, n_samples=1, color_adj=None,
              ddim_steps=50, ddim_eta=0.0):
    """Sample images or text conditioned on a prompt or an image.

    Args:
        xtype: output type, ``'image'`` or ``'text'``.
        cin: the conditioning input -- a prompt string when ``ctype`` is
            ``'prompt'``/``'text'``, or an image accepted by
            ``regularize_image`` when ``ctype`` is ``'vision'``/``'image'``.
        ctype: conditioning type. It is forwarded unchanged to the sampler
            and to ``decode``; note that ``decode`` applies colour adjustment
            only when it is exactly ``'vision'``.
        scale: unconditional guidance scale. When it equals ``1.0`` no
            unconditional embedding is computed and ``None`` is passed to the
            sampler.
        n_samples: number of outputs to generate.
        color_adj: forwarded to ``decode`` for image outputs.
        ddim_steps: number of sampler steps.
        ddim_eta: ``eta`` passed to the sampler.

    Returns:
        The list returned by ``decode`` (images or strings).

    Raises:
        ValueError: if ``xtype`` or ``ctype`` is not one of the supported
            values. Previously an unsupported ``ctype`` failed later with an
            unbound-variable error and an unsupported ``xtype`` silently
            returned ``None``.
    """
    if xtype not in ('image', 'text'):
        raise ValueError(f"Unsupported xtype {xtype!r}; expected 'image' or 'text'")

    if ctype in ['prompt', 'text']:
        c = net.clip_encode_text(n_samples * [cin])
        u = None
        if scale != 1.0:
            # The unconditional embedding is the encoding of the empty prompt.
            u = net.clip_encode_text(n_samples * [""])

    elif ctype in ['vision', 'image']:
        cin = regularize_image(cin)
        # Rescale from [0, 1] to [-1, 1] and replicate once per sample.
        ctemp = cin*2 - 1
        ctemp = ctemp[None].repeat(n_samples, 1, 1, 1)
        c = net.clip_encode_vision(ctemp)
        u = None
        if scale != 1.0:
            # The unconditional embedding is the encoding of an all-zero image.
            dummy = torch.zeros_like(ctemp)
            u = net.clip_encode_vision(dummy)

    else:
        raise ValueError(
            f"Unsupported ctype {ctype!r}; expected 'prompt', 'text', 'vision' or 'image'")

    if xtype == 'image':
        h, w = [512, 512]
        shape = [n_samples, 4, h//8, w//8]
    else:
        n = 768
        shape = [n_samples, n]

    z, _ = sampler.sample(
        steps=ddim_steps,
        shape=shape,
        conditioning=c,
        unconditional_guidance_scale=scale,
        unconditional_conditioning=u,
        xtype=xtype, ctype=ctype,
        eta=ddim_eta,
        verbose=False,)

    if xtype == 'image':
        return decode(z, xtype, ctype, color_adj=color_adj, color_adj_to=cin)
    return decode(z, xtype, ctype)

def not_txtTB(k):
    """Key filter for the weight manifest; ``True`` means "keep this key".

    Returns ``False`` for keys under ``optimus.`` or
    ``clip.model.vision_model.``, and for keys under
    ``model.diffusion_model.unet_text.`` that contain ``.in_layers.`` or
    ``.out_layers.``. Every other key returns ``True``.
    """
    if k.startswith(('optimus.', 'clip.model.vision_model.')):
        return False
    if k.startswith('model.diffusion_model.unet_text.'):
        return not ('.in_layers.' in k or '.out_layers.' in k)
    return True

def not_txtALL(k):
    """Key filter for the weight manifest; ``True`` means "keep this key".

    Returns ``False`` for every key under
    ``model.diffusion_model.unet_text.`` and ``True`` for all others.
    """
    return not k.startswith('model.diffusion_model.unet_text.')

In [None]:
from vdoff import dik

# Build a state dict directly from the raw per-tensor files listed in `dik`,
# then attach the tensors to the network.
dout = {}
for k in dik:
    #   use `not_txtTB(k)` when (xtype = 'image' & ctype = 'prompt')
    if not not_txtALL(k):
        continue

    # Only the type code, file name and shape of each manifest entry are used.
    storage_info, _offset, shape, _stride, _grad = dik[k]
    type_code, fna, _device, _fsiz = storage_info

    # 'F' and 'I' are mapped to explicit numpy dtypes; any other code is
    # handed to numpy unchanged.
    if type_code == 'F':
        np_dtype = np.float32
    elif type_code == 'I':
        np_dtype = np.int64
    else:
        np_dtype = type_code

    flat = np.fromfile('archive/data/' + fna, dtype=np_dtype)
    dout[k] = torch.tensor(flat.reshape(shape))


load_state_dict_with_low_memory(net, dout)
if use_cuda:
    net.to('cuda')

In [None]:
# Example: download a sample input image and save it as xipooh.jpg.
!wget -O xipooh.jpg https://i.imgur.com/MYMdnVY.jpg

In [None]:
# Run application_disensemble on the sample image with its defaults
# (n_samples=1, level=0, i.e. no low-rank filtering of the embedding).
# `samp` is the list returned by decode().
samp=application_disensemble('xipooh.jpg')

In [None]:
# Index into `samp` of the generated sample to show (Colab form field).
display_image = 0 #@param {type:'integer'}
# A bare expression on the last line of the cell is rendered as its output.
samp[display_image]

In [None]:
# Image output conditioned on an image: the sample picture is encoded with
# the vision encoder and used as the conditioning. color_adj=None disables
# colour adjustment in decode().
samp=inference(
      xtype = 'image',
      cin = 'xipooh.jpg',
      ctype = 'vision',
      color_adj = None)

In [None]:
# Image output conditioned on a text prompt.
# NOTE: per the comment in the weight-loading cell, this combination
# (xtype='image', ctype='prompt') expects the weights to have been selected
# with not_txtTB rather than not_txtALL.
samp=inference(
      xtype = 'image',
      cin = 'a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ',
      ctype = 'prompt')