# Min-dalle classes

In [None]:
# Base URL that the tokenizer files and encoder/decoder checkpoints are fetched from.
MIN_DALLE_REPO = 'https://huggingface.co/Larvik/mmin-dalle/resolve/main/'
import os
# One-time environment setup: skipped when the marker file already exists.
if not os.path.isfile('/content/once.txt'):
  !pip install accelerate
  !apt-get install -y libvulkan-dev libomp5
  !pip install ncnn-vulkan
  # Move the ncnn_vulkan shared objects into the ncnn package directory.
  # NOTE(review): the path is hard-coded to the Python 3.7 dist-packages location.
  !mv /usr/local/lib/python3.7/dist-packages/ncnn_vulkan/*.so /usr/local/lib/python3.7/dist-packages/ncnn/
  !pip install emoji
  # ncnn param/bin files; /tmp/vq.param and /tmp/vq.bin are loaded into the ncnn net later on.
  !wget -O /tmp/vq.param https://raw.githubusercontent.com/TabuaTambalam/vqqncnn/main/vq.param
  !wget -O /tmp/vq.bin https://github.com/TabuaTambalam/vqqncnn/releases/download/0.0/vq.bin
  !wget -O /tmp/vq_vert.param https://raw.githubusercontent.com/TabuaTambalam/vqqncnn/main/vq_vert.param
  # Keep only the min_dalle package directory from the upstream checkout.
  !git clone https://github.com/kuprel/min-dalle.git
  !mv 'min-dalle/min_dalle' min_dalle
  !rm -rf '/content/min-dalle'

In [None]:
%%writefile /content/once.txt

2@0

In [None]:
from PIL import Image
from threading import Thread
import sys
import numpy as np
from torch import LongTensor, FloatTensor, BoolTensor,nn
from math import sqrt
import torch
import torch.backends.cudnn, torch.backends.cuda
import json
import requests
from typing import Iterator, List, Tuple, Dict
from min_dalle.text_tokenizer import TextTokenizer


# Inference only: turn autograd off.
torch.set_grad_enabled(False)
# Use as many intra-op threads as os.cpu_count() reports.
torch.set_num_threads(os.cpu_count())
# cuDNN: enabled, TF32 allowed, autotuner (benchmark mode) on.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
# CUDA matmul: allow TF32 and reduced-precision fp16 reductions.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

# Number of image tokens per generated sequence (handled as 16x16 grids further down).
IMAGE_TOKEN_COUNT = 256



class DalleBartDecoder():
  """Thin wrapper around a TorchScript decoder that samples one image token per call."""

  def __init__(self,jitpath,device):
    """Load the TorchScript module stored at ``jitpath``.

    ``device`` is used for the ``token_indices`` tensor only; the loaded
    module is moved with ``.cuda()`` when the module-level ``UseFP16`` flag
    is set.
    """
    self.token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=device)
    self.jit=torch.jit.load(jitpath)
    if UseFP16:
      self.jit=self.jit.cuda()

  def sample_tokens(self, *args, settings) -> Tuple[LongTensor, FloatTensor]:
    """Run the decoder on ``*args`` and sample the next token for each image.

    ``settings`` is indexed as ``[temperature, top_k, supercondition_factor]``.
    The decoder output batch is split in two halves that are blended with the
    supercondition factor, restricted to the ``top_k`` largest logits and
    sampled with ``torch.multinomial``.

    Returns ``(image_tokens, attention_state)`` where ``attention_state`` is
    whatever the decoder returned as its second output.
    """
    raw_logits, attention_state = self.jit(*args)
    half = raw_logits.shape[0] // 2
    temperature, top_k, supercondition_factor = (
        settings[[0]],
        settings[[1]].to(torch.long),
        settings[[2]],
    )
    # Last position only, first 2**14 vocabulary entries.
    last_step = raw_logits[:, -1, : 2 ** 14]
    mixed: FloatTensor = (
        last_step[:half] * (1 - supercondition_factor) +
        last_step[half:] * supercondition_factor
    )
    ranked, _ = mixed.sort(descending=True)
    # Keep entries at least as large as the top_k-th largest value of each row.
    is_kept = mixed >= ranked[:, top_k - 1]
    # Shift by the row maximum before exponentiating.
    mixed -= ranked[:, [0]]
    mixed /= temperature
    mixed.exp_()
    mixed *= is_kept.to(torch.float32)
    image_tokens = torch.multinomial(mixed, 1)[:, 0]
    return image_tokens, attention_state


class MinDalle:
    def __init__(
        self,
        models_root: str = 'pretrained',
        dtype: torch.dtype = torch.float32,
        device: str = None,
        is_reusable: bool = True,
        is_verbose = True
    ):
        if torch.cuda.is_available():
          if device == None:
              device = 'cuda'
        else:
          device = 'cpu'
          dtype=torch.float32

        if UseFP16:
          dtype=torch.float16

        if is_verbose: print('using device', device)
        self.device = device
        is_mega = True
        self.is_mega = True
        self.is_reusable = is_reusable
        self.dtype = dtype
        self.is_verbose = is_verbose
        self.text_token_count = 64
        self.layer_count = 24 if is_mega else 12
        self.attention_head_count = 32 if is_mega else 16
        self.embed_count = 2048 if is_mega else 1024
        self.glu_embed_count = 4096 if is_mega else 2730
        self.text_vocab_count = 50272 if is_mega else 50264
        self.image_vocab_count = 16415 if is_mega else 16384

        model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
        dalle_path = os.path.join(models_root, model_name)
        vqgan_path = os.path.join(models_root, 'vqgan')
        if not os.path.exists(dalle_path): os.makedirs(dalle_path)
        if not os.path.exists(vqgan_path): os.makedirs(vqgan_path)
        self.vocab_path = os.path.join(dalle_path, 'vocab.json')
        self.merges_path = os.path.join(dalle_path, 'merges.txt')
        self.encoder_params_path = os.path.join(dalle_path, 'encoder.pt')
        self.decoder_params_path = os.path.join(dalle_path, 'decoder.pt')
        self.detoker_params_path = os.path.join(vqgan_path, 'detoker.pt')

        self.init_tokenizer()
        if UseFP16:
          self.init_decoder()
          self.init_encoder()
        else:
          self.init_encoder()

        


    def download_tokenizer(self):
        if self.is_verbose: print('downloading tokenizer params')
        suffix = '' if self.is_mega else '_mini'
        vocab = requests.get(MIN_DALLE_REPO + 'vocab{}.json'.format(suffix))
        merges = requests.get(MIN_DALLE_REPO + 'merges{}.txt'.format(suffix))
        with open(self.vocab_path, 'wb') as f: f.write(vocab.content)
        with open(self.merges_path, 'wb') as f: f.write(merges.content)


    def download_encoder(self):
        if self.is_verbose: print('downloading encoder params')
        suffix = '_fp16' if UseFP16 else ''
        urli=MIN_DALLE_REPO + 'encoder{}.pt'.format(suffix)
        !wget -O {self.encoder_params_path} {urli}


    def download_decoder(self):
        if self.is_verbose: print('downloading decoder params')
        suffix = '_fp16' if UseFP16 else ''
        urli=MIN_DALLE_REPO + 'decoder{}.pt'.format(suffix)
        !wget -O {self.decoder_params_path} {urli}
    

    def init_tokenizer(self):
        is_downloaded = os.path.exists(self.vocab_path)
        is_downloaded &= os.path.exists(self.merges_path)
        if not is_downloaded: self.download_tokenizer()
        if self.is_verbose: print('intializing TextTokenizer')
        with open(self.vocab_path, 'r', encoding='utf8') as f:
            vocab = json.load(f)
        with open(self.merges_path, 'r', encoding='utf8') as f:
            merges = f.read().split('\n')[1:-1]
        self.tokenizer = TextTokenizer(vocab, merges)


    def init_encoder(self):
        is_downloaded = os.path.exists(self.encoder_params_path)
        if not is_downloaded: self.download_encoder()
        if self.is_verbose: print('initializing DalleBartEncoder')
        self.encoder = torch.jit.load(self.encoder_params_path)
        if UseFP16:
          self.encoder=self.encoder.cuda()
        return

    def init_decoder(self):
        is_downloaded = os.path.exists(self.decoder_params_path)
        if not is_downloaded: self.download_decoder()
        if self.is_verbose: print('initializing DalleBartDecoder')
        self.decoder=DalleBartDecoder(self.decoder_params_path,device=self.device)
        return

# Delete the file localhttp() uses as its "server already started" marker (it also receives the server output).
!rm /content/sample_data/izh.txt


In [None]:
import random

class ExtractorGPU(object):
  """Context manager yielding an extractor of the global ``net`` with explicit allocators.

  Relies on the module-level ``net``, ``blob_vkallocator`` and
  ``staging_vkallocator``. On exit both allocators and the extractor are
  cleared.
  """

  def __enter__(self):
    self.ex = net.create_extractor()
    extractor = self.ex
    extractor.set_blob_vkallocator(blob_vkallocator)
    # The workspace allocator is the blob allocator as well.
    extractor.set_workspace_vkallocator(blob_vkallocator)
    extractor.set_staging_vkallocator(staging_vkallocator)
    return extractor

  def __exit__(self, exc_type, exc_val, exc_tb):
    blob_vkallocator.clear()
    staging_vkallocator.clear()
    extractor = self.ex
    extractor.clear()


def mkCBemb(seq):
  """Feed ``seq`` into blob 'in0' and return the raw ncnn output of blob '2'.

  NOTE(review): the name suggests a codebook embedding - confirm against the
  vq.param graph.
  """
  with NCNNex() as extractor:
    extractor.input('in0', ncnn.Mat(seq).clone())
    _status, embedding = extractor.extract('2')
  del extractor
  return embedding

def emb2img(emb):
  """Feed ``emb`` into blob '2' and return blob 'out0' as a uint8 PIL image."""
  with NCNNex() as extractor:
    extractor.input('2', emb)
    _status, decoded = extractor.extract('out0')
  del extractor
  pixels = np.array(decoded).astype(np.uint8)
  return Image.fromarray(pixels)

def npmkCBemb(seq):
  """NumPy variant of mkCBemb: ``seq`` is cast to uint32, blob '2' is returned as an ndarray."""
  with NCNNex() as extractor:
    token_mat = ncnn.Mat(seq.astype(np.uint32)).clone()
    extractor.input('in0', token_mat)
    _status, embedding = extractor.extract('2')
  del extractor
  return np.array(embedding)


def npemb2img(emb):
  """NumPy variant of emb2img: wrap ``emb`` in an ncnn.Mat, feed blob '2', return 'out0' as a PIL image."""
  with NCNNex() as extractor:
    emb_mat = ncnn.Mat(emb).clone()
    extractor.input('2', emb_mat)
    _status, decoded = extractor.extract('out0')
  del extractor
  pixels = np.array(decoded).astype(np.uint8)
  return Image.fromarray(pixels)

def pbla(step,scale):
  """Return ``step`` divisors lying on a parabola over ``i = 0 .. step-1``.

  With ``k = step - 1`` and ``m = 4 - 4/scale`` the i-th value is
  ``k + m * (i**2 / k - i)``: it equals ``k`` at both ends and dips in the
  middle when ``scale > 1``. ``step == 1`` divides by zero.
  """
  curvature = 4 - (4 / scale)
  k = step - 1
  return [k + curvature * (((i ** 2) / k) - i) for i in range(step)]


def img_float():
  """Run columns ``sta:endo`` of the global ``curfull`` through the net and return blob '252'.

  Reads the module-level ``curfull``, ``sta``, ``endo`` and ``device``. The
  result is a float tensor on ``device`` with a leading dimension of size 1.
  """
  with NCNNex() as extractor:
      window = curfull[:,sta:endo].reshape(16*(endo-sta))
      extractor.input('in0', ncnn.Mat(window.astype(np.uint32)).clone())
      _status, features = extractor.extract('252')
  del extractor
  as_tensor = torch.FloatTensor(np.array(features))
  return as_tensor.to(device).unsqueeze(0)


def npimg_float(seq):
  """Feed ``seq`` (cast to uint32) into blob 'in0' and return blob '252' as an ndarray."""
  with NCNNex() as extractor:
      token_mat = ncnn.Mat(seq.astype(np.uint32)).clone()
      extractor.input('in0', token_mat)
      _status, features = extractor.extract('252')
  del extractor
  return np.array(features)


def hstack(sta,n=16,crop=[]):
  """Join the left 8 columns of ``n`` consecutive 16x16 grids from ``dumped_seqs``.

  Grids ``dumped_seqs[sta] .. dumped_seqs[sta+n-1]`` are placed side by side.
  Without ``crop`` the flattened 16 x 8n strip is returned. With
  ``crop = (offset, width)`` both values are shifted right by 4 bits and only
  that column range of the strip is returned, flattened.
  """
  halves = [dumped_seqs[sta+k].reshape((16,16))[:,:8] for k in range(n)]
  if not crop:
    return np.hstack(halves).reshape(128*n)
  first_col = crop[0]>>4
  last_col = (crop[0]+crop[1])>>4
  strip = np.hstack(halves)[:,first_col:last_col]
  return strip.reshape(16*(last_col-first_col))
    

def fromsteps(liz):
  """Assemble one wide token strip from the per-step dumps in /content/steps.

  For step ``n`` the file ``s<n>.bin`` is read as uint16 rows of 256 values
  and row ``liz[n]`` is taken as a 16x16 grid. Every grid but the last one
  contributes its left 8 columns, the last one all 16. The strip is returned
  flattened.
  """
  last = len(liz)-1

  def grid(step_no):
    # Selected row of one step dump, as a 16x16 int32 grid.
    rows = np.fromfile('/content/steps/s%d.bin'%step_no,dtype=np.uint16).astype(np.int32).reshape((-1,256))
    return rows[liz[step_no]].reshape((16,16))

  strips = [grid(n)[:,:8] for n in range(last)]
  strips.append(grid(last))
  return np.hstack(strips).reshape(256 + 128*last)


def mk3x3(idx):
  """Tile nine 16x16 grids from ``dumped_seqs`` into a flattened 3x3 mosaic.

  ``idx`` holds nine indices into ``dumped_seqs`` in row-major order.
  """
  bands = []
  for y in range(3):
    tiles = [dumped_seqs[idx[y*3+x]].reshape((16,16)) for x in range(3)]
    bands.append(np.hstack(tiles).reshape(768))
  return np.concatenate(bands)

def interpo(seq1,seq2,step=30,scale=1.21,outfmt='/content/avif/%02d.png'):
  """Write ``step`` frames blending the embeddings of ``seq2`` into those of ``seq1``.

  Frame ``i`` is ``(emb1*i + emb2*(step-1-i)) / pbla(step, scale)[i]``, so
  frame 0 is built from ``seq2`` alone and the last frame from ``seq1``
  alone. Each frame is saved to ``outfmt % i``.
  """
  last = step-1
  divisors = pbla(step,scale)
  first_emb = npmkCBemb(seq1)
  second_emb = npmkCBemb(seq2)
  for i, divisor in enumerate(divisors):
    blended = (first_emb*i+second_emb*(last-i))/divisor
    npemb2img(blended).save(outfmt%i)

def interpoC():
  !mkdir /content/avif
  for n in range(candidate_count):
    interpo(dumped_seqs[n],dumped_seqs[n+1])
    !/tmp/ffmpeg/ffmpeg -framerate 18 -i /content/avif/%02d.png -sn -map_metadata -1 -map_chapters -1 -crf 10 -c:v libaom-av1 -aom-params enable-keyframe-filtering=0:enable-tpl-model=1 -lag-in-frames 48 -cpu-used 5 -row-mt 1 -tiles 1x1 -threads 2 -strict experimental -movflags +faststart -flags +cgop -pix_fmt yuv420p10le -c:a libopus -b:a 96k -ac 2 -f webm /content/avif/intp.webm
    os.rename('/content/avif/intp.webm','/content/avif/v%02d.webm'%n)

def showp(n, prt=False):
  """Decode ``dumped_seqs[n]`` to an image and save it as /content/sample_data/<n>.png.

  With ``prt`` false the PIL image is returned. With ``prt`` true the image
  is displayed, the globals ``mfn`` (index as text) and ``mlat`` (the
  sequence printed without its brackets) are set, and ``mlat`` is returned.
  """
  global mlat
  global mfn
  seq = dumped_seqs[n]
  with NCNNex() as extractor:
    extractor.input('in0', ncnn.Mat(seq.astype(np.uint32)).clone())
    _status, decoded = extractor.extract('out0')
  del extractor
  picture = Image.fromarray(np.array(decoded).astype(np.uint8))
  picture.save('/content/sample_data/%d.png'%n)
  if not prt:
    return picture
  display(picture)
  mfn = '%d'%n
  mlat = str(seq)[1:-1]
  return mlat

def showp2(seq):
  """Decode ``seq`` to an image, save it as /content/sample_data/000.png and return it."""
  with NCNNex() as extractor:
    token_mat = ncnn.Mat(seq.astype(np.uint32)).clone()
    extractor.input('in0', token_mat)
    _status, decoded = extractor.extract('out0')
  del extractor
  picture = Image.fromarray(np.array(decoded).astype(np.uint8))
  picture.save('/content/sample_data/000.png')
  return picture

def hcopy(tk,left_sta=8):
  """Shift every 16-token row of ``tk`` left by ``left_sta`` columns, in place.

  ``tk`` carries one leading element followed by 16 rows of 16 tokens. In
  each row columns ``left_sta..15`` are copied onto columns
  ``0..15-left_sta``; the remaining columns keep their values.
  """
  width = 16-left_sta
  # Rows start at offsets 1, 17, ..., 241 because of the leading element.
  for row_start in range(1, 257, 16):
    tk[row_start:row_start+width] = tk[row_start+left_sta:row_start+16]

def hcopy_dup(tk,sele,left_sta=8):
  """Batched hcopy from one selected sequence, plus attention-state duplication.

  ``sele`` is expanded to ``candidate_count`` rows; for each 16-token row its
  columns ``left_sta..15`` are written onto columns ``0..15-left_sta`` of
  every sequence in ``tk``. Afterwards, in every layer of the global
  ``attention_state``, the entry of ``candidate_select`` is copied over all
  candidates within each of the 4 groups of ``candidate_count`` entries.
  """
  source = sele.expand(candidate_count,-1)
  width = 16-left_sta
  for row_start in range(1, 257, 16):
    tk[:,row_start:row_start+width] = source[:,row_start+left_sta:row_start+16]
  for layer in range(mindd.layer_count):
    for group in range(4):
      base = group*candidate_count
      chosen = attention_state[layer][base+candidate_select]
      for offset in range(candidate_count):
        attention_state[layer][base+offset][:] = chosen[:]

def hcopy_dst(src,tk,len=8):
  """Write the first ``len`` columns of each 16-token row of ``src`` into ``tk``.

  ``src`` is a flat NumPy array without a leading element; ``tk`` is a
  tensor with one leading element, so destination rows start at offset 1.
  """
  for row_start in range(0, 256, 16):
    chunk = torch.from_numpy(src[row_start:row_start+len])
    tk[1+row_start:1+row_start+len] = chunk

def rumpla():
  """Sample image tokens for rows ``ROW_START..15``, columns ``COL_START..15``.

  Works on the module-level ``image_tokens`` / ``attention_state``: position
  ``i = 16*row + col`` feeds token ``i`` to the decoder and stores the
  sampled token at ``i + 1``. The row index is printed in hex as progress.
  """
  global attention_state
  for row_index in range(ROW_START, 16):
    print('%x:'%row_index, end='')
    row_base = row_index * 16
    for col_index in range(COL_START, 16):
      pos = row_base + col_index
      with torch.cuda.amp.autocast(dtype=mindd.dtype):
          sampled, next_state = mindd.decoder.sample_tokens(
              attention_mask,encoder_state,attention_state,image_tokens[:,[pos]],token_indices[[pos]],settings=settings
          )
          image_tokens[:,pos + 1] = sampled
          attention_state = next_state


def rumplaL(glist):
  """Like rumpla, but with one start column per row taken from ``glist``.

  The last ``len(glist)`` rows are regenerated; entry ``k`` of ``glist`` is
  the first column sampled in row ``16 - len(glist) + k``.
  """
  global attention_state
  first_row = 16-len(glist)
  for offset, col_start in enumerate(glist):
    row_index = first_row + offset
    print('%x:'%row_index, end='')
    row_base = row_index * 16
    for col_index in range(col_start, 16):
      pos = row_base + col_index
      with torch.cuda.amp.autocast(dtype=mindd.dtype):
          sampled, next_state = mindd.decoder.sample_tokens(
              attention_mask,encoder_state,attention_state,image_tokens[:,[pos]],token_indices[[pos]],settings=settings
          )
          image_tokens[:,pos + 1] = sampled
          attention_state = next_state

def rudallestuff():
  """Install the ruDALL-E related tooling and download its model files, once.

  Nothing is done when '/content/model-ru-latest.pt' already exists. Command
  output on stdout is discarded.
  """
  if not os.path.isfile('/content/model-ru-latest.pt'):
    !git clone https://github.com/sberbank-ai/Real-ESRGAN > /dev/null
    !git clone https://github.com/Jack000/guided-diffusion > /dev/null
    !pip install rudalle > /dev/null
    !pip install -e ./guided-diffusion > /dev/null
    !pip install -r Real-ESRGAN/requirements.txt > /dev/null
    # wget stores into the current directory.
    # NOTE(review): the guard above only matches if that directory is /content - confirm.
    !wget https://huggingface.co/shonenkov/rudalle-utils/resolve/main/RealESRGAN_x2.pth > /dev/null
    !wget https://huggingface.co/shonenkov/rudalle-utils/resolve/main/RealESRGAN_x4.pth > /dev/null
    !wget https://dall-3.com/models/guided-diffusion/ru-dalle/model-ru-latest.pt > /dev/null

def mksettings(top_k0,temperature0,supercondition_factor0):
  """Pack the sampling parameters into one tensor on the model's device and dtype.

  Note the storage order: ``[temperature, top_k, supercondition_factor]``,
  which differs from the argument order.
  """
  values = [temperature0, top_k0, supercondition_factor0]
  return torch.tensor(values, dtype=mindd.dtype, device=mindd.device)
  
def gen0(dr=0,dc=0):
  """Run one generation pass starting at row ``dr`` / column ``dc`` and dump the result.

  Sets the globals ``ROW_START`` and ``COL_START``, calls ``rumpla()``,
  stores the tokens (leading column dropped) as uint16 in the global
  ``dumped_seqs`` and appends them to /content/ozv.bin.
  """
  global ROW_START
  global COL_START
  global dumped_seqs
  ROW_START, COL_START = dr, dc
  rumpla()
  host_tokens = image_tokens[:, 1:].to('cpu').numpy()
  dumped_seqs = host_tokens.astype(np.uint16)
  with open('/content/ozv.bin',mode='ba+') as f:
    dumped_seqs.tofile(f)


def gen0L(glist):
  """Run one generation pass with per-row start columns and dump the result.

  Calls ``rumplaL(glist)``, stores the tokens (leading column dropped) as
  uint16 in the global ``dumped_seqs`` and appends them to /content/ozv.bin.
  """
  global dumped_seqs
  rumplaL(glist)
  host_tokens = image_tokens[:, 1:].to('cpu').numpy()
  dumped_seqs = host_tokens.astype(np.uint16)
  with open('/content/ozv.bin',mode='ba+') as f:
    dumped_seqs.tofile(f)


def chkstz(stz):
  """Return True when ``stz`` has fewer than two lines or its second line is empty."""
  return len(stz) < 2 or stz[1] == ''


def gen1():
  """Generate in a loop for as long as '/content/once.txt' exists.

  The file doubles as a control channel and is re-read (spaces removed)
  after every generation pass:

    * blank first line: keep the current parameters;
    * one line, or a blank second line: the first line is
      ``row@col@top_k@temperature@supercondition`` with trailing fields
      optional; row/col become the start position for ``gen0`` and the
      remaining fields overwrite the matching entries of ``settings``;
    * otherwise: the all-digit lines (at most 16) become the per-row start
      columns handed to ``gen0L``.

  Read or parse errors are ignored on purpose so that a half-written file
  never stops the loop; values assigned before the error are kept. When the
  loop ends, '/content/-.txt' is renamed back to '/content/once.txt'.
  """
  # Repeated here because this function runs in its own thread (grad mode
  # is tracked per thread by PyTorch).
  torch.set_grad_enabled(False)
  dr=ROW_START
  dc=COL_START
  useL=False
  while os.path.isfile('/content/once.txt'):
    if useL:
      gen0L(glist)
    else:
      gen0(dr,dc)
    try:
      with open('/content/once.txt','rt') as f:
        stz=f.read().replace(' ','').splitlines()
      # An empty file raises IndexError here and is handled below.
      if stz[0] == '':
        continue
      if chkstz(stz):
        useL=False
        fields=stz[0].split('@')
        count=len(fields)
        dr=int(fields[0])
        if count > 1:
          dc=int(fields[1])
        if count > 2:
          # settings layout: [temperature, top_k, supercondition_factor]
          settings[1]=float(fields[2])
        if count > 3:
          settings[0]=float(fields[3])
        if count > 4:
          settings[2]=float(fields[4])
      else:
        glist=[int(x) for x in stz if x.isdigit()][:16]
        useL=True
    except Exception:
      # Best effort: keep going with whatever parameters are in effect.
      # Unlike a bare except, this lets SystemExit/KeyboardInterrupt through.
      pass
  os.rename('/content/-.txt','/content/once.txt')

def localhttp():
  """Start a background HTTP server for /content/sample_data on port 8233, once.

  The server output goes to /content/sample_data/izh.txt, and the presence
  of that file is what marks the server as already started. ``HTML`` from
  IPython is imported into the module namespace as a side effect.
  """
  global HTML
  if not os.path.isfile('/content/sample_data/izh.txt'):
    from IPython.core.display import HTML
    !nohup python3 -m http.server -d /content/sample_data/ 8233 > /content/sample_data/izh.txt &


def prmp(filltoken=False):
  """Tokenize and encode the global prompt ``text`` and prepare the sampling state.

  Reads the module-level ``text``, ``is_verbose``, ``candidate_count``,
  ``seed``, ``UseFP16`` and ``mindd``. Publishes ``tokens``, ``text_tokens``,
  ``encoder_state`` and ``attention_mask`` as globals, and seeds torch.

  Args:
    filltoken: when true, also (re)create the globals ``attention_state``
      (zeros) and ``image_tokens`` (filled with ``mindd.image_vocab_count``).

  A ``seed`` of 0 is replaced by a random seed, which is printed.
  """
  global seed
  global text_tokens
  global image_tokens
  global attention_mask
  global encoder_state
  global attention_state
  # Module-level on purpose: the prompt-logging code reads `tokens` after
  # this call and failed with NameError while it was a local.
  global tokens
  if is_verbose: print('tokenizing text')
  tokens = mindd.tokenizer.tokenize(text, is_verbose=is_verbose)
  if len(tokens) > mindd.text_token_count:
      tokens = tokens[:mindd.text_token_count]
  if is_verbose: print('{} text tokens'.format(len(tokens)), tokens)
  # Row 0 holds only the first and last token, row 1 the whole prompt;
  # unused slots stay 1, the value masked out by attention_mask below.
  token_grid = np.ones((2, mindd.text_token_count), dtype=np.int32)
  token_grid[0, :2] = [tokens[0], tokens[-1]]
  token_grid[1, :len(tokens)] = tokens
  # The encoder is moved to CUDA only in fp16 mode, so feed it from the
  # CPU otherwise.
  dev = mindd.device if UseFP16 else 'cpu'
  text_tokens = torch.tensor(
      token_grid,
      dtype=torch.long,
      device=dev
  )

  if not mindd.is_reusable: mindd.init_encoder()
  if is_verbose: print('encoding text tokens')
  with torch.cuda.amp.autocast(dtype=mindd.dtype):
      encoder_state = mindd.encoder(text_tokens)
  if not mindd.is_reusable: del mindd.encoder
  torch.cuda.empty_cache()

  if not mindd.is_reusable: mindd.init_decoder()

  with torch.cuda.amp.autocast(dtype=mindd.dtype):
      # candidate_count copies of row 0 followed by candidate_count copies
      # of row 1; the decoder blends the two halves of the batch.
      expanded_indices = [0] * candidate_count + [1] * candidate_count
      text_tokens = text_tokens[expanded_indices]
      encoder_state = encoder_state[expanded_indices]
      attention_mask = text_tokens.not_equal(1)[:, None, None, :]
      if filltoken:
        attention_state = torch.zeros(
            size=(
                mindd.layer_count,
                candidate_count * 4,
                IMAGE_TOKEN_COUNT,
                mindd.embed_count
            ),
            device=mindd.device
        )
        # One extra leading slot: position 0 is the start token that the
        # first sampling step reads.
        image_tokens = torch.full(
            (candidate_count, IMAGE_TOKEN_COUNT + 1),
            mindd.image_vocab_count,
            dtype=torch.long,
            device=mindd.device
        )

      if seed == 0:
        seed=random.randint(0, 2**32)
        print('rndseed: '+str(seed))
      torch.manual_seed(seed)
        

# First-run flag: the PlayGround setup cell builds the model and the ncnn net only while this is True.
newstart=True

# PlayGround

In [None]:
# Form flags: fp16 checkpoints on CUDA, and the optional ruDALL-E extras.
UseFP16=False #@param {type:'boolean'}
PrepareRuDalleStuff=False #@param {type:'boolean'}
# Build the model wrapper and the ncnn network once; newstart is cleared afterwards.
if newstart:
  mindd = MinDalle(is_reusable=True)
  import ncnn
  import gc
  net = ncnn.Net()
  # Vulkan compute is switched off for this net.
  net.opt.use_vulkan_compute = False
  # NCNNex() is what the helper functions call to obtain a fresh extractor.
  NCNNex=net.create_extractor
  net.load_param(  '/tmp/vq.param'   )  #	'/content/vq3x3.txt'
  net.load_model('/tmp/vq.bin')
  newstart=False
  !nvidia-smi
if PrepareRuDalleStuff:
  # Run the installs/downloads in a background thread; start() returns None, so a1 is None.
  t1 = Thread(target = rudallestuff)
  a1 = t1.start()


In [None]:
# Prompt, number of candidates sampled in parallel, and RNG seed (0 makes prmp pick a random one).
text ="desert oasis merchant market high fantasy book cover painting" #@param {type:'string'}
candidate_count =  2#@param {type:'integer'}
seed= 775577  #@param {type:'integer'}
log_everything = False #@param {type:'boolean'}

'''
555
'''

is_verbose=False


# Encode the prompt and allocate fresh image_tokens / attention_state.
prmp(True)


userselect=[]
userselectN=[0]*128

# Start from an empty steps directory and remove previously saved PNGs.
!rm -rf /content/steps
!mkdir /content/steps
!rm /content/sample_data/*.png

# Optional dump of the prompt and the encoder outputs into /content/steps.
if log_everything:
  with open('/content/steps/prompt.txt','wt') as f:
    f.write(text+'\ntokens='+str(tokens))
  torch.save(attention_mask, '/content/steps/attention_mask.pt')
  torch.save(encoder_state, '/content/steps/encoder_state.pt')


# Generation starts at the top-left token; newprompt makes the generation cell do a full first pass.
ROW_START =0
COL_START =0
step=0
newprompt=True

In [None]:
# The encoder is not needed once prmp() has produced encoder_state; drop it and release cached CUDA memory.
del mindd.encoder
torch.cuda.empty_cache()

In [None]:
# Load the decoder and expose its position-index tensor under the global name used by rumpla()/rumplaL().
mindd.init_decoder()
token_indices = mindd.decoder.token_indices

In [None]:
# Sampling parameters (form fields). candidate_select picks the candidate to continue from;
# a value >= candidate_count triggers a plain regeneration instead (last branch below).
candidate_select=0 #@param {type:'integer'}
top_k = 2048 #@param {type:'integer'}
temperature = 3  #@param {type:'integer'}
supercondition_factor = 64 #@param {type:'integer'}


# Start position for the threaded run; both values are shifted right by 4 bits (divided by 16) below.
# NOTE(review): presumably pixel coordinates at 16 pixels per token - confirm.
r_ROW_START =35 #@param {type:'integer'}
r_COL_START =0 #@param {type:'integer'}

settings=mksettings(top_k,temperature,supercondition_factor)





if newprompt:
  # First run after a new prompt: one full pass, then show every candidate.
  newprompt=False
  gen0()
  for n in range(candidate_count):
    print('init%d'%n,end='')
    display(showp(n))
elif candidate_select < candidate_count:
  # Continue from the selected candidate: copy its tokens over all candidates ...
  image_tokens[:]=image_tokens[candidate_select].expand(candidate_count,-1)[:]
  # ... and do the same for its attention-state entry in each of the 4 groups of every layer.
  for as0 in range(mindd.layer_count):
    for p in range(4):
      pkan=p*candidate_count
      sle4=attention_state[as0][pkan+candidate_select]
      for j in range(candidate_count):
        attention_state[as0][pkan+j][:]=sle4[:]
  del sle4
  ROW_START=r_ROW_START>>4
  COL_START=r_COL_START>>4
  print('infinite gen started in thread')
  # gen1 loops until /content/once.txt is removed; start() returns None, so a1 is None.
  t1 = Thread(target = gen1)
  a1 = t1.start()
else:
  # candidate_select out of range: regenerate from scratch and show every candidate.
  gen0()
  for n in range(candidate_count):
    print('init%d'%n,end='')
    display(showp(n))

