# Min-dalle classes

In [1]:
import os
# One-time runtime setup: install Python deps, fetch the ncnn VQ decoder
# (param/bin files into /tmp) and the min-dalle sources.
# /content/once.txt is written by the following %%writefile cell and marks
# that this block has already run in the current runtime.
if not os.path.isfile('/content/once.txt'):
  !pip install accelerate
  !pip install ncnn
  # ncnn VQ decoder: horizontal layout param, shared weights, vertical layout param.
  !wget -O /tmp/vq.param https://raw.githubusercontent.com/TabuaTambalam/vqqncnn/main/vq.param
  !wget -O /tmp/vq.bin https://github.com/TabuaTambalam/vqqncnn/releases/download/0.0/vq.bin
  !wget -O /tmp/vq_vert.param https://raw.githubusercontent.com/TabuaTambalam/vqqncnn/main/vq_vert.param
  # Only the min_dalle package directory is needed; move it next to the notebook.
  !git clone https://github.com/kuprel/min-dalle.git
  !mv 'min-dalle/min_dalle' min_dalle

In [2]:
%%writefile /content/once.txt
uwa

Overwriting /content/once.txt


In [3]:
from PIL import Image
import numpy as np
from torch import LongTensor, FloatTensor
from math import sqrt
import torch
import torch.backends.cudnn, torch.backends.cuda
import json
import requests
from typing import Iterator
from min_dalle.text_tokenizer import TextTokenizer
from min_dalle.models import DalleBartEncoder, DalleBartDecoder

# Global torch configuration. No gradients are needed anywhere in this
# notebook, and TF32 / reduced-precision matmuls are allowed for speed.
torch.set_grad_enabled(False)
# os.cpu_count() returns None when the count cannot be determined, which
# set_num_threads() rejects; fall back to a single thread in that case.
torch.set_num_threads(os.cpu_count() or 1)
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

# Base URL the tokenizer / encoder / decoder parameter files are fetched from.
MIN_DALLE_REPO = 'https://huggingface.co/kuprel/min-dalle/resolve/main/'
# Image tokens per generated image (laid out as a 16x16 grid elsewhere).
IMAGE_TOKEN_COUNT = 256


class MinDalle:
    """Loads the min-dalle text tokenizer, BART encoder and image-token decoder.

    Parameter files are downloaded from ``MIN_DALLE_REPO`` on first use and
    cached under ``models_root``. With ``is_reusable=True`` the encoder and
    decoder are built immediately; otherwise the caller is expected to call
    ``init_encoder`` / ``init_decoder`` when it needs them.

    Args:
        models_root: directory holding the cached parameter files.
        dtype: parameter dtype; forced to ``torch.float32`` whenever the
            resolved device is the CPU.
        device: torch device string. Defaults to ``'cuda'`` when CUDA is
            available; forced to ``'cpu'`` when it is not.
        is_mega: use the "mega" model sizes instead of "mini".
        is_reusable: build encoder and decoder in the constructor.
        is_verbose: print progress messages.
    """

    def __init__(
        self,
        models_root: str = 'pretrained',
        dtype: torch.dtype = torch.float16,
        device: str = None,
        is_mega: bool = True,
        is_reusable: bool = True,
        is_verbose = True
    ):
        device, dtype = self._select_device(device, dtype)

        if is_verbose: print("using device", device)
        self.device = device
        self.is_mega = is_mega
        self.is_reusable = is_reusable
        self.dtype = dtype
        self.is_verbose = is_verbose
        # Architecture hyper-parameters of the mega / mini checkpoints.
        self.text_token_count = 64
        self.layer_count = 24 if is_mega else 12
        self.attention_head_count = 32 if is_mega else 16
        self.embed_count = 2048 if is_mega else 1024
        self.glu_embed_count = 4096 if is_mega else 2730
        self.text_vocab_count = 50272 if is_mega else 50264
        self.image_vocab_count = 16415 if is_mega else 16384

        model_name = 'dalle_bart_{}'.format('mega' if is_mega else 'mini')
        dalle_path = os.path.join(models_root, model_name)
        vqgan_path = os.path.join(models_root, 'vqgan')
        os.makedirs(dalle_path, exist_ok=True)
        os.makedirs(vqgan_path, exist_ok=True)
        self.vocab_path = os.path.join(dalle_path, 'vocab.json')
        self.merges_path = os.path.join(dalle_path, 'merges.txt')
        self.encoder_params_path = os.path.join(dalle_path, 'encoder.pt')
        self.decoder_params_path = os.path.join(dalle_path, 'decoder.pt')
        # Only the path is recorded; this class never loads the detokenizer.
        self.detoker_params_path = os.path.join(vqgan_path, 'detoker.pt')

        self.init_tokenizer()
        if is_reusable:
            self.init_encoder()
            self.init_decoder()


    @staticmethod
    def _select_device(device, dtype):
        """Resolve the (device, dtype) pair actually used for the models.

        Without CUDA the device is always ``'cpu'``; with CUDA an unset device
        becomes ``'cuda'``. Whenever the resulting device is the CPU the dtype
        is forced to float32 -- previously this only happened when CUDA was
        missing, so an explicit ``device='cpu'`` on a GPU machine kept the
        float16 default, which is poorly supported on CPU.
        """
        if not torch.cuda.is_available():
            device = 'cpu'
        elif device is None:
            device = 'cuda'
        if torch.device(device).type == 'cpu':
            dtype = torch.float32
        return device, dtype


    @staticmethod
    def _download_file(url, path):
        """Stream ``url`` into ``path``, replacing it only on success.

        Data is written to ``path + '.part'`` and renamed onto ``path`` once
        the transfer finished, so an interrupted download never leaves a
        truncated file that the ``os.path.exists`` checks in the ``init_*``
        methods would mistake for a complete one.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        tmp_path = path + '.part'
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, path)


    def download_tokenizer(self):
        """Fetch vocab.json and merges.txt for the selected model size."""
        if self.is_verbose: print("downloading tokenizer params")
        suffix = '' if self.is_mega else '_mini'
        self._download_file(MIN_DALLE_REPO + 'vocab{}.json'.format(suffix), self.vocab_path)
        self._download_file(MIN_DALLE_REPO + 'merges{}.txt'.format(suffix), self.merges_path)


    def download_encoder(self):
        """Fetch the encoder state dict for the selected model size."""
        if self.is_verbose: print("downloading encoder params")
        suffix = '' if self.is_mega else '_mini'
        urli = MIN_DALLE_REPO + 'encoder{}.pt'.format(suffix)
        self._download_file(urli, self.encoder_params_path)


    def download_decoder(self):
        """Fetch the decoder state dict for the selected model size."""
        if self.is_verbose: print("downloading decoder params")
        suffix = '' if self.is_mega else '_mini'
        urli = MIN_DALLE_REPO + 'decoder{}.pt'.format(suffix)
        self._download_file(urli, self.decoder_params_path)


    def init_tokenizer(self):
        """Build ``self.tokenizer`` from the cached vocab/merges (downloading if needed)."""
        is_downloaded = os.path.exists(self.vocab_path)
        is_downloaded &= os.path.exists(self.merges_path)
        if not is_downloaded: self.download_tokenizer()
        if self.is_verbose: print("intializing TextTokenizer")
        with open(self.vocab_path, 'r', encoding='utf8') as f:
            vocab = json.load(f)
        with open(self.merges_path, 'r', encoding='utf8') as f:
            # Drop the first line (presumably a header) and the empty string
            # after the trailing newline.
            merges = f.read().split("\n")[1:-1]
        self.tokenizer = TextTokenizer(vocab, merges)


    def init_encoder(self):
        """Build ``self.encoder`` and load its weights (downloading if needed)."""
        if not os.path.exists(self.encoder_params_path): self.download_encoder()
        if self.is_verbose: print("initializing DalleBartEncoder")
        self.encoder = DalleBartEncoder(
            attention_head_count = self.attention_head_count,
            embed_count = self.embed_count,
            glu_embed_count = self.glu_embed_count,
            text_token_count = self.text_token_count,
            text_vocab_count = self.text_vocab_count,
            layer_count = self.layer_count,
            device=self.device
        ).to(self.dtype).eval()
        # map_location='cpu' makes loading independent of the device the
        # checkpoint was saved from; load_state_dict copies each tensor onto
        # the model's own device and dtype.
        params = torch.load(self.encoder_params_path, map_location='cpu')
        # strict=False: missing/unexpected keys are ignored instead of raising.
        self.encoder.load_state_dict(params, strict=False)
        del params
        self.encoder = self.encoder.to(device=self.device)


    def init_decoder(self):
        """Build ``self.decoder`` and load its weights (downloading if needed)."""
        if not os.path.exists(self.decoder_params_path): self.download_decoder()
        if self.is_verbose: print("initializing DalleBartDecoder")
        self.decoder = DalleBartDecoder(
            image_vocab_count = self.image_vocab_count,
            attention_head_count = self.attention_head_count,
            embed_count = self.embed_count,
            glu_embed_count = self.glu_embed_count,
            layer_count = self.layer_count,
            device=self.device
        ).to(self.dtype).eval()
        # See init_encoder for why the checkpoint is mapped to the CPU first.
        params = torch.load(self.decoder_params_path, map_location='cpu')
        self.decoder.load_state_dict(params, strict=False)
        del params
        self.decoder = self.decoder.to(device=self.device)



In [4]:
def mkCBemb(seq):
  """Run the global ncnn ``net`` on a token sequence up to blob "2".

  ``seq`` is wrapped in an ``ncnn.Mat`` and fed to input "in0"; the
  intermediate blob "2" (presumably the codebook embedding, judging by the
  name) is returned as an ``ncnn.Mat``.

  Raises:
    RuntimeError: if ncnn reports a non-zero status for the extraction.
  """
  with net.create_extractor() as ex:
    ex.input("in0", ncnn.Mat(seq).clone())
    ret, out0 = ex.extract("2")
  if ret != 0:
    raise RuntimeError('ncnn extract("2") failed with code %d' % ret)
  return out0

def emb2img(emb):
  """Decode an ``ncnn.Mat`` fed at blob "2" into a PIL image.

  Runs the global ``net`` from blob "2" to output "out0" and converts the
  result to uint8 pixels.

  Raises:
    RuntimeError: if ncnn reports a non-zero status for the extraction.
  """
  with net.create_extractor() as ex:
    ex.input("2", emb)
    ret, out0 = ex.extract("out0")
  if ret != 0:
    raise RuntimeError('ncnn extract("out0") failed with code %d' % ret)
  return Image.fromarray(np.array(out0).astype(np.uint8))

def npmkCBemb(seq):
  """Like ``mkCBemb`` but return blob "2" as a numpy array.

  Raises:
    RuntimeError: if ncnn reports a non-zero status for the extraction.
  """
  with net.create_extractor() as ex:
    ex.input("in0", ncnn.Mat(seq).clone())
    ret, out0 = ex.extract("2")
  if ret != 0:
    raise RuntimeError('ncnn extract("2") failed with code %d' % ret)
  return np.array(out0)


def npemb2img(emb):
  """Like ``emb2img`` but take the blob-"2" embedding as a numpy array.

  Raises:
    RuntimeError: if ncnn reports a non-zero status for the extraction.
  """
  with net.create_extractor() as ex:
    ex.input("2", ncnn.Mat(emb).clone())
    ret, out0 = ex.extract("out0")
  if ret != 0:
    raise RuntimeError('ncnn extract("out0") failed with code %d' % ret)
  return Image.fromarray(np.array(out0).astype(np.uint8))

def pbla(step,scale):
  """Per-frame divisors for ``interpo``'s embedding blend.

  Frame ``i`` of ``step`` frames gets ``k + m * (i**2 / k - i)`` with
  ``k = step - 1`` and ``m = 4 - 4 / scale``. The value equals ``k`` at both
  ends and, for ``scale > 1``, dips parabolically towards the middle;
  ``scale == 1`` yields a constant ``k``. ``step == 1`` divides by zero.
  """
  curve = 4 - (4 / scale)
  last = step - 1
  return [last + curve * (((i ** 2) / last) - i) for i in range(step)]


def hstack(sta,n=16,crop=None):
  """Join the left halves of ``n`` consecutive dumped candidates side by side.

  Each ``dumped_seqs[sta + k]`` (global) is viewed as a 16x16 token grid
  and its left 8 columns are kept. The halves form a (16, 8*n) grid that is
  flattened row-major.

  Args:
    sta: index of the first sequence in ``dumped_seqs``.
    n: number of consecutive sequences to join.
    crop: optional ``(left, width)``; ``>> 4`` turns both into token
      columns (presumably pixels, 16 px per token). Columns past the right
      edge are clipped instead of crashing the reshape.

  Returns:
    1-D numpy array of tokens.
  """
  grid = np.hstack([dumped_seqs[sta + k].reshape((16, 16))[:, :8] for k in range(n)])
  if crop:
    grid = grid[:, crop[0] >> 4:(crop[0] + crop[1]) >> 4]
  # reshape(-1) matches the actual (possibly clipped) width; a fixed
  # 16*(endo-sta) size fails whenever the crop exceeds the grid.
  return grid.reshape(-1)
    


def mk3x3(idx):
  """Tile nine dumped candidates into a 3x3 token mosaic.

  ``idx`` lists nine indices into the global ``dumped_seqs`` in row-major
  order; each sequence is viewed as a 16x16 grid. Every mosaic row (three
  grids side by side) is flattened to 768 tokens and the three rows are
  concatenated into one 1-D array of 2304 tokens.
  """
  def mosaic_row(r):
    tiles = [dumped_seqs[idx[3 * r + c]].reshape((16, 16)) for c in range(3)]
    return np.hstack(tiles).reshape(768)
  return np.concatenate([mosaic_row(r) for r in range(3)])

def interpo(seq1,seq2,step=30,scale=1.21,outfmt='/content/avif/%02d.png'):
  """Render ``step`` frames blending the embeddings of two token sequences.

  Frame ``i`` decodes ``(em1*i + em2*(step-1-i)) / pbla(step, scale)[i]``,
  so frame 0 is ``seq2``'s embedding and the last frame is ``seq1``'s.
  Frames are saved to ``outfmt % i``; the output directory is created if it
  does not exist yet.
  """
  stp = step - 1
  divi = pbla(step, scale)
  em1 = npmkCBemb(seq1)
  em2 = npmkCBemb(seq2)
  # Nothing in the notebook creates the default /content/avif directory.
  out_dir = os.path.dirname(outfmt % 0)
  if out_dir:
    os.makedirs(out_dir, exist_ok=True)
  for i in range(step):
    npemb2img((em1*i + em2*(stp-i))/divi[i]).save(outfmt % i)

def showp(n):
  """Decode ``dumped_seqs[n]`` to an image, save it and return it.

  The sequence is cast to uint32, run through the global ncnn ``net``
  ("in0" -> "out0"), saved as ``/content/sample_data/<n>.png`` and returned
  as a PIL image.

  Raises:
    RuntimeError: if ncnn reports a non-zero status for the extraction.
  """
  with net.create_extractor() as ex:
    ex.input("in0", ncnn.Mat(dumped_seqs[n].astype(np.uint32)).clone())
    ret, out0 = ex.extract("out0")
  if ret != 0:
    raise RuntimeError('ncnn extract("out0") failed with code %d' % ret)
  uz = Image.fromarray(np.array(out0).astype(np.uint8))
  uz.save('/content/sample_data/%d.png'%n)
  return uz

def showp2(seq):
  """Decode an arbitrary token sequence to an image, save it and return it.

  Like ``showp`` but takes the sequence directly; the image is always
  written to ``/content/sample_data/000.png`` (the generate loop renames it
  afterwards).

  Raises:
    RuntimeError: if ncnn reports a non-zero status for the extraction.
  """
  with net.create_extractor() as ex:
    ex.input("in0", ncnn.Mat(seq.astype(np.uint32)).clone())
    ret, out0 = ex.extract("out0")
  if ret != 0:
    raise RuntimeError('ncnn extract("out0") failed with code %d' % ret)
  uz = Image.fromarray(np.array(out0).astype(np.uint8))
  uz.save('/content/sample_data/000.png')
  return uz

def hcopy(tk,left_sta=8):
  """Shift the right part of every 16-token row of ``tk`` to the row start.

  ``tk`` holds a leading slot at index 0 followed by a 16x16 token grid in
  row-major order (row ``y`` starts at ``1 + 16*y``). For each row, columns
  ``left_sta``..15 are copied to columns 0.. and the last ``left_sta``
  columns keep their previous values. Works in place on torch tensors and
  numpy arrays; trailing dimensions (e.g. one column per candidate) are
  carried along.
  """
  width = 16 - left_sta
  for y in range(16):
    start = 1 + y * 16
    chunk = tk[start + left_sta:start + 16]
    # With left_sta < 8 source and destination overlap, and torch rejects
    # partially overlapping copies -- snapshot the source first.
    chunk = chunk.clone() if isinstance(chunk, torch.Tensor) else chunk.copy()
    tk[start:start + width] = chunk

def hcopy_dup(tk,sele,left_sta=8):
  """Seed the start of every row of all candidates with one candidate's tokens.

  ``tk`` is a 2-D token tensor with one column per candidate (in the
  notebook ``image_tokens``, shape (257, candidate_count): a leading slot at
  row 0, then a 16x16 grid in row-major order). ``sele`` is a length-257
  sequence, usually one column of ``tk``. For every grid row, columns
  ``left_sta``..15 of ``sele`` are written into columns 0.. of every
  candidate. Modifies ``tk`` in place.
  """
  width = 16 - left_sta
  for y in range(16):
    start = 1 + y * 16
    # clone(): sele is normally a view of tk, and the ranges overlap when
    # left_sta < 8. unsqueeze(-1) broadcasts the single column across all
    # of tk's columns, so the candidate count comes from tk itself rather
    # than from the global candidate_count.
    tk[start:start + width] = sele[start + left_sta:start + 16].clone().unsqueeze(-1)

def hcopy_dst(src,tk,len=8):
  """Copy the first ``len`` tokens of each 16-token row of ``src`` into ``tk``.

  ``src`` is a numpy token sequence walked as 16 rows of 16; ``tk`` has a
  leading slot at index 0, so row ``r`` of ``tk`` starts at ``1 + 16*r``.
  Modifies ``tk`` in place.
  """
  width = len
  for base in range(0, 256, 16):
    tk[base + 1:base + 1 + width] = torch.from_numpy(src[base:base + width])

def rumpla():
  """Run the decoder token by token over the 16x16 image-token grid.

  Rows go from ROW_START to 15 and, within every row, columns from
  COL_START to 15. Each step feeds ``image_tokens[i]`` and writes the
  decoder's prediction into ``image_tokens[i + 1]`` (flat index
  ``i = 16*row + col``). Reads the globals ROW_START, COL_START, mindd,
  settings, attention_mask, encoder_state, image_tokens and token_indices.
  Updates ``image_tokens`` in place and rebinds the global
  ``attention_state``. Prints the hex row index as progress, with no
  trailing newline.
  """
  global attention_state
  for row_index in range(   ROW_START   ,16):
    print('%x:'%row_index, end='')
    kt=16 * row_index
    # COL_START applies to every row, not only the first one.
    for col_index in range(COL_START,16):
      i =  kt + col_index       
      with torch.cuda.amp.autocast(dtype=mindd.dtype):
          image_tokens[i + 1], attention_state = mindd.decoder.forward(
              settings=settings,
              attention_mask=attention_mask,
              encoder_state=encoder_state,
              attention_state=attention_state,
              prev_tokens=image_tokens[i],
              token_index=token_indices[[i]]
          )
  
!rm /content/sample_data/*.png
newstart=True

# PlayGround

In [None]:
# Build the min-dalle models and the ncnn VQ decoder once per runtime.
# `newstart` is set to True by the helper-definitions cell and cleared here,
# so re-running this cell does not reload everything.
if newstart:
  UseMega=False #@param {type:"boolean"}
  mindd = MinDalle(is_mega=UseMega, is_reusable=True)
  import ncnn
  import gc
  net = ncnn.Net()
  # Enable Vulkan compute for ncnn; set before the param/model files are loaded.
  net.opt.use_vulkan_compute = True
  net.load_param(  "/tmp/vq.param"   )  #   "/content/vq3x3.txt"
  net.load_model("/tmp/vq.bin")
  newstart=False
  # Show GPU status / memory after loading.
  !nvidia-smi

Make a prompt (tokenize and encode it, then reset the generation state)

In [None]:
text ="desert oasis merchant market high fantasy book cover painting" #@param {type:"string"}
candidate_count =  4#@param {type:"integer"}
seed= 776677  #@param {type:"integer"}

'''
555
'''

is_verbose=False

if is_verbose: print("tokenizing text")
tokens = mindd.tokenizer.tokenize(text, is_verbose=is_verbose)
if len(tokens) > mindd.text_token_count: 
    tokens = tokens[:mindd.text_token_count]
if is_verbose: print("{} text tokens".format(len(tokens)), tokens)
text_tokens = np.ones((2, 64), dtype=np.int32)
text_tokens[0, :2] = [tokens[0], tokens[-1]]
text_tokens[1, :len(tokens)] = tokens
text_tokens = torch.tensor(
    text_tokens, 
    dtype=torch.long, 
    device=mindd.device
)

if not mindd.is_reusable: mindd.init_encoder()
if is_verbose: print("encoding text tokens")
with torch.cuda.amp.autocast(dtype=mindd.dtype):
    encoder_state = mindd.encoder.forward(text_tokens)
if not mindd.is_reusable: del mindd.encoder
torch.cuda.empty_cache()

if not mindd.is_reusable: mindd.init_decoder()

with torch.cuda.amp.autocast(dtype=mindd.dtype):
    expanded_indices = [0] * candidate_count + [1] * candidate_count
    text_tokens = text_tokens[expanded_indices]
    encoder_state = encoder_state[expanded_indices]
    attention_mask = text_tokens.not_equal(1)
    attention_state = torch.zeros(
        size=(
            mindd.layer_count,
            candidate_count * 4,
            IMAGE_TOKEN_COUNT,
            mindd.embed_count
        ), 
        device=mindd.device
    )
    image_tokens = torch.full(
        (IMAGE_TOKEN_COUNT + 1, candidate_count), 
        mindd.image_vocab_count,
        dtype=torch.long,
        device=mindd.device
    )
    
    if seed > 0: torch.manual_seed(seed)

token_indices = torch.arange(IMAGE_TOKEN_COUNT, device=mindd.device)

userselect=[]
userselectN=[0]*128

!rm -rf /content/steps
!mkdir /content/steps
!rm /content/sample_data/*.png

step=0
newprompt=True

Generate loop<br>
(Run this cell multiple times, changing candidate_select & other settings;<br>a candidate_select ≥ candidate_count rejects all candidates and regenerates)

In [None]:
candidate_select=0 #@param {type:"integer"}
ROW_START=0 
COL_START=7 #@param {type:"integer"}
top_k= 2048 #@param {type:"integer"}
temperature= 3  #@param {type:"integer"}
supercondition_factor= 64 #@param {type:"integer"}

# How many previously kept strips to show left of each new candidate.
PreviewLimit=2

settings = torch.tensor(
    [temperature, top_k, supercondition_factor], 
    dtype=torch.float32,
    device=mindd.device
)

# First run after a new prompt: generate the whole grid from scratch.
# Otherwise candidate_select < candidate_count keeps that candidate: its
# left 8 columns are stored and its right 8 columns become the new left
# context. Any larger value rejects all candidates and regenerates.
# userselectN[step] accumulates an offset that appears to index the chosen
# sequence within the appended /content/steps/s<step>.bin dump.
if newprompt:
  ROW_START=0
  COL_START=0
elif candidate_select < candidate_count:
  userselectN[step]+=candidate_select
  userselect.append(dumped_seqs[candidate_select].reshape((16,16))[:,:8])
  hcopy_dup(image_tokens,image_tokens.T[candidate_select])
  step+=1
else:
  userselectN[step]+=candidate_count

rumpla()
# rumpla() prints its row progress without a newline; terminate that line.
print()

dumped_seqs=image_tokens[1:].T.to('cpu').numpy().astype(np.uint16)
with open('/content/steps/s%d.bin'%step,mode='ba+') as f:
  dumped_seqs.tofile(f)

if newprompt:
  for n in range(candidate_count):
    print('Init%d= '%n,end='')
    display(showp(n))
    print('=====================')
else:
  syz=len(userselect)
  if syz > PreviewLimit:
    tview=userselect[syz-PreviewLimit:]
    syz=PreviewLimit
  else:
    tview=userselect
  # Selection-history prefix for the preview file names (loop-invariant).
  dfna='-'.join(str(x) for x in userselectN[:step])
  for n in range(candidate_count):
    print('Next%d= '%n,end='')
    display(showp2( np.hstack(tview+[dumped_seqs[n].reshape((16,16))]).reshape(256+128*syz) ))
    os.rename('/content/sample_data/000.png','/content/sample_data/'+dfna+'-'+str(n)+'.png')
    print('=====================')


newprompt=False

Final<br>(last selection)

In [None]:
candidate_select=2 #@param {type:"integer"}
# Finalize the strip: append the chosen candidate in full (all 16 columns,
# unlike the 8-column left halves the generate loop stores) and record it.
# NOTE(review): not idempotent. userselect, step and userselectN are mutated,
# so a second run appends another strip and then raises IndexError on the
# truncated userselectN. Re-run the prompt cell before finalizing again.
userselect.append(dumped_seqs[candidate_select].reshape((16,16)))
userselectN[step]+=candidate_select
step+=1


userselectN=userselectN[:step]
# Full token strip, shape (16, 8*k + 16); reused by the crop cell.
curfull=np.hstack(userselect).astype(np.uint32)
print(userselectN)
showp2(curfull.reshape(128+128*len(userselect)))

================<br>Crop & Re-decode

In [None]:
Left=219 #@param {type:"integer"}
Width=355 #@param {type:"integer"}

# >>4 converts Left/Width to token columns (presumably pixels, 16 px per
# token). Clip to the strip width so a window running past the right edge
# shows what exists instead of failing in reshape, and reject empty windows.
sta=Left>>4
endo=min((Left+Width)>>4, curfull.shape[1])
if endo <= sta:
  raise ValueError('crop window is empty: Left=%d Width=%d' % (Left, Width))
showp2(curfull[:,sta:endo].reshape(-1))

# Tools

Pack the saved steps into a 7z archive

In [None]:
!7z a /content/pk.7z /content/steps
dfna='-'.join(str(x) for x in userselectN)
os.rename('/content/pk.7z','/content/'+dfna+'.7z')

Reload the ncnn decoder<br>
(when it leaks memory)

In [None]:
UseV=False #@param {type:"boolean"}
# Drop the old network and force a collection before building a fresh one.
del net
gc.collect()

# Both layouts share the same weights file; only the graph definition differs.
param_file = "/tmp/vq_vert.param" if UseV else "/tmp/vq.param"
net = ncnn.Net()
net.opt.use_vulkan_compute = True
net.load_param(param_file)
net.load_model("/tmp/vq.bin")