<a href="https://colab.research.google.com/github/FractalLibrary/ruDALL-E/blob/main/ruDALL_E_Mass_Batcher_Arbitrary_Resolution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ruDALL-E at arbitrary resolution

by @nev#4905

## Garbage Collect as necessary
##### (Shouldn't be necessary any more, but kept just in case)

In [None]:
import torch
import gc
# Free memory in two steps. First collect unreachable Python objects, which drops
# their tensor references. Then hand PyTorch's cached-but-unused CUDA blocks back.
for _free in (gc.collect, torch.cuda.empty_cache):
    _free()
del _free

## Install dependencies

In [None]:
# Always install into /content. Once the Drive cell has run, the cwd is the Drive folder.
%cd /content
# Clone and flatten the fork only on the first run. A second clone cannot be moved
# over the existing checkout ("Directory not empty"), and the old version of this
# cell then deleted the fresh clone. Any stale `ru-dalle` dir from a half-finished
# run is removed first.
!test -d .git || (rm -rf ru-dalle && git clone https://github.com/neverix/ru-dalle && mv -f ru-dalle/* ru-dalle/.git . && rm -rf ru-dalle)
!git checkout better-caching
!pip install -e .

Cloning into 'ru-dalle'...
remote: Enumerating objects: 573, done.[K
remote: Counting objects: 100% (573/573), done.[K
remote: Compressing objects: 100% (417/417), done.[K
remote: Total 573 (delta 356), reused 309 (delta 148), pack-reused 0[K
Receiving objects: 100% (573/573), 17.79 MiB | 33.17 MiB/s, done.
Resolving deltas: 100% (356/356), done.
mv: cannot move 'ru-dalle/jupyters' to './jupyters': Directory not empty
mv: cannot move 'ru-dalle/pics' to './pics': Directory not empty
mv: cannot move 'ru-dalle/rudalle' to './rudalle': Directory not empty
mv: cannot move 'ru-dalle/tests' to './tests': Directory not empty
mv: cannot move 'ru-dalle/.git' to './.git': Directory not empty
D	.coveragerc
D	.gitignore
D	.gitlab-ci.yml
D	.pre-commit-config.yaml
Already on 'better-caching'
Your branch is up to date with 'origin/better-caching'.
Obtaining file:///content
Installing collected packages: rudalle
  Attempting uninstall: rudalle
    Found existing installation: rudalle 0.0.1rc11
    

In [None]:
from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip
from rudalle.utils import seed_everything

import torch
import gc

In [None]:
# Use the GPU when one is available. fp16 weights are only enabled on CUDA.
# On CPU the 1.3B-parameter model works but sampling is very slow.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = get_tokenizer()
dalle = get_rudalle_model('Malevich', pretrained=True,
                          fp16=device == "cuda",
                          device=device)

tokenizer --> ready
◼️ Malevich is 1.3 billion params model from the family GPT3-like, that uses Russian language and text+image multi-modality.


In [None]:
# Real-ESRGAN (x4 upscaler) is loaded only once per session. Re-running this
# cell reuses the existing instance.
if 'realesrgan' not in globals():
    realesrgan = get_realesrgan('x4', device=device)

# The VAE is reloaded on every run. get_vae() presumably returns a fresh instance,
# so the patched `vae.decode` from the code section must be re-applied afterwards.
vae = get_vae().to(device)

# Optional ruCLIP (for cherry-picking); uncomment to enable:
# ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5')
# ruclip = ruclip.to(device)

x4 --> ready
Working with z of shape (1, 256, 32, 32) = 262144 dimensions.
vae --> ready


In [None]:
import os
import torch

# Debugging aid: CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous, so CUDA
# errors are reported at the call that caused them (at the cost of speed).
# The CUDA runtime reads it when the context is created. Setting it after the
# model has been moved to the GPU therefore has no effect in this session.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
if torch.cuda.is_initialized():
    print("CUDA is already initialized: CUDA_LAUNCH_BLOCKING only takes effect if set "
          "before any CUDA work (e.g. in the first cell of a fresh runtime).")

## Code: sliding-window sampler and non-square VAE decoder

In [None]:
from glob import glob
from os.path import join

import cv2
import torch
import torchvision
import transformers
import more_itertools
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from PIL import Image

from rudalle import utils


def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image_prompts=None, temperature=1.0, bs=8,
                    seed=None, use_cache=True, w=32, h=48):
    """Sample a ``w`` x ``h`` grid of VQGAN tokens with a sliding context window and decode it.

    The model was trained on a square ``real`` x ``real`` token grid, where
    ``real = sqrt(image_seq_length)``. The original code hard-coded 32. To build
    larger or non-square grids, each token is sampled from the text tokens plus a
    window of already generated tokens:
      * up to ``real // 2`` rows above the current one, each restricted to a
        ``real``-wide column span around the current column;
      * the tokens to its left in the current row (same column span).

    Args:
        text: Prompt. It is lower-cased and stripped before tokenization.
        tokenizer, dalle, vae: ruDALL-E tokenizer, transformer and VQGAN wrapper.
        top_k, top_p: Passed to ``transformers.top_k_top_p_filtering``.
        images_num: Number of requested images, processed in chunks of ``bs``.
            Each chunk currently yields one image (a single grid is sampled per chunk).
        image_prompts: Optional object exposing ``image_prompts_idx`` and
            ``image_prompts``. See the NOTE in the loop.
        temperature: Divides the logits before filtering.
        bs: Chunk size.
        seed: If not None, passed to ``utils.seed_everything`` first.
        use_cache: Accepted for API compatibility but ignored. The key/value cache
            is always used, with only its text-prefix part reused between steps.
        w, h: Grid width/height in tokens.

    Returns:
        ``(pil_images, scores)``. ``pil_images`` holds one decoded PIL image per
        finished chunk. If generation is interrupted (KeyboardInterrupt) or fails,
        the partially filled grid of the current chunk is decoded and appended, so
        ``pil_images[-1]`` is always the latest progress. ``scores`` is currently
        always empty.
    """
    import traceback

    if seed is not None:
        utils.seed_everything(seed)
    vocab_size = dalle.get_param('vocab_size')
    text_seq_length = dalle.get_param('text_seq_length')
    image_seq_length = dalle.get_param('image_seq_length')
    total_seq_length = dalle.get_param('total_seq_length')
    device = dalle.get_param('device')
    # Side of the square training grid. Previously hard-coded as 32.
    real = int(round(image_seq_length ** 0.5))
    if real * real != image_seq_length:
        raise ValueError(f"image_seq_length={image_seq_length} is not a square token grid")

    text = text.lower().strip()
    input_ids = tokenizer.encode_text(text, text_seq_length=text_seq_length)
    text_ids = input_ids.to(device).ravel()  # constant prefix of every model input
    pil_images, scores = [], []
    cache = None
    grid = None  # (h, w) token grid of the chunk currently being generated

    def _decode_grid(token_grid):
        """Decode an (h, w) grid of VQGAN codes into a list of PIL images."""
        with torch.no_grad():
            images = vae.decode(token_grid.flatten().unsqueeze(0))
        return utils.torch_tensors_to_pil_list(images)

    try:
        for chunk in more_itertools.chunked(range(images_num), bs):
            chunk_bs = len(chunk)
            with torch.no_grad():
                attention_mask = torch.tril(torch.ones((chunk_bs, 1, total_seq_length, total_seq_length), device=device))
                out = input_ids.unsqueeze(0).repeat(chunk_bs, 1).to(device)
                # Allocate on the model's device (this used to be hard-coded to .cuda()).
                grid = torch.zeros((h, w), dtype=torch.long, device=device)
                if image_prompts is not None:
                    prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
                    prompts = prompts.repeat(chunk_bs, 1)
                for idx in tqdm(range(out.shape[1], total_seq_length - real * real + w * h)):
                    idx -= text_seq_length  # index into the (h, w) grid, row-major
                    if image_prompts is not None and idx in prompts_idx:
                        # NOTE(review): prompt tokens are appended to `out` but never written
                        # into `grid`, so they affect neither the context nor the decoded
                        # image. Also, `idx` spans w*h while the prompt covers a
                        # real x real layout. Image prompts are effectively unsupported here.
                        out = torch.cat((out, prompts[:, idx].unsqueeze(1)), dim=-1)
                        continue
                    y, x = divmod(idx, w)
                    x_from = max(0, min(w - real, x - real // 2))
                    x_to = min(w, x_from + real)
                    y_from = max(0, y - real // 2)
                    # Context, row-major: window rows above (y, x), then tokens left of it.
                    # Tensor slicing replaces ~real*real/2 per-element .item() GPU syncs.
                    context = torch.cat((grid[y_from:y, x_from:x_to].reshape(-1), grid[y, x_from:x]))
                    if cache is not None:
                        # The window moves every step, so cached image positions are stale.
                        # Keep only the text-prefix part of each cached tensor.
                        cache = {k: [t[..., :text_seq_length, :] for t in layer]
                                 for k, layer in enumerate(cache.values())}
                    logits, cache = dalle(torch.cat((text_ids, context), dim=0).unsqueeze(0), attention_mask,
                                          cache=cache, use_cache=True, return_loss=False)
                    logits = logits[:, :, vocab_size:].view((-1, logits.shape[-1] - vocab_size))
                    logits /= temperature
                    filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
                    probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
                    # Every position is sampled, which keeps the seeded RNG stream unchanged.
                    # Only the last row is the new token.
                    sample = torch.multinomial(probs, 1)
                    grid[y, x] = sample[-1].item()
            # Decode once per chunk, not once per token.
            pil_images += _decode_grid(grid)
            grid = None
    except KeyboardInterrupt:
        pass
    except Exception:
        traceback.print_exc()
    if grid is not None:
        # Stopped mid-chunk: return the partial grid so progress is not lost.
        try:
            pil_images += _decode_grid(grid)
        except Exception:
            traceback.print_exc()
    return pil_images, scores

#@title adapt the vqgan decoder to a new non-square resolution. uses the global `h` 
from math import sqrt, log

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum
from einops import rearrange
from taming.modules.diffusionmodules.model import Encoder, Decoder

from functools import partial
def decode(self, img_seq, height=None):
    """Decode a flat, row-major sequence of VQGAN codes into images of a non-square grid.

    Intended to replace a VAE wrapper's ``decode`` via ``functools.partial``.

    Args:
        self: VAE wrapper exposing ``model.quantize.embed.weight`` (the codebook,
            one row per token id) and ``model.decode``.
        img_seq: ``(batch, n)`` integer tensor of codebook indices.
        height: Grid height in tokens. Defaults to the notebook-global ``h``, which
            is read at call time. The width is ``n // height``.

    Returns:
        The output of ``self.model.decode``, clamped to [-1, 1] and mapped to [0, 1].

    Raises:
        ValueError: if ``n`` cannot be split into rows of the requested height.
    """
    b, n = img_seq.shape
    rows = h if height is None else height
    if rows <= 0 or n % rows:
        raise ValueError(f"cannot reshape {n} tokens into a grid with height {rows}")
    cols = n // rows
    # Embedding lookup gives the same values as one_hot(img_seq) @ weight, without
    # materializing a (batch, n, num_tokens) one-hot tensor.
    z = F.embedding(img_seq, self.model.quantize.embed.weight)
    # Equivalent to rearrange(z, 'b (h w) c -> b c h w', h=rows).
    z = z.view(b, rows, cols, z.shape[-1]).permute(0, 3, 1, 2)
    img = self.model.decode(z)
    return (img.clamp(-1., 1.) + 1) * 0.5
# Bind the non-square decoder to this VAE instance, shadowing its own `decode`.
# `vae` is recreated whenever the model-loading cell (`vae = get_vae()...`) is re-run,
# and this patch must then be applied again.
vae.decode = partial(decode, vae)


## Directory Setup

In [None]:
#@title Connect Google Drive
import os
drive_path = "/content"

from google.colab import drive
drive.mount('/content/drive')

def ensureProperRootPath():
    """Change into `drive_path` (when non-empty) and print the resulting working directory."""
    if drive_path:
        os.chdir(drive_path) # Changes directory to absolute root path
        print("Root path check: ")
        print(os.getcwd())  # pure-Python equivalent of the `!pwd` shell escape

ensureProperRootPath()

folder_name = "AI_ART" #@param {type: "string"}
# Tolerate accidental leading/trailing slashes. An empty field keeps the current
# root instead of raising IndexError.
folder_name = folder_name.strip('/')
if folder_name:
    drive_path = drive_path + "/drive/MyDrive/" + folder_name
    os.makedirs(drive_path, exist_ok=True)  # also creates missing intermediate folders
    print("Created folder & set root path to: " + drive_path)

#@markdown The folder where the images are dumped

project_name = "rudalle-arb" #@param {type: "string"}
project_name = project_name.strip('/')
if project_name:
    drive_path = drive_path + "/" + project_name
    os.makedirs(drive_path, exist_ok=True)
    print("Created project subfolder & set root path to: " + drive_path)

ensureProperRootPath()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Root path check: 
/content
Created folder & set root path to: /content/drive/MyDrive/AI_ART
Created project subfolder & set root path to: /content/drive/MyDrive/AI_ART/rudalle-arb
Root path check: 
/content/drive/MyDrive/AI_ART/rudalle-arb


## Generation with ruDALL-E

In [None]:
#@markdown settings
#@markdown random seed (set to positive to use)
seed =  1#@param {type: "integer"}
#@markdown text prompt (russian)
text = '\u0441\u043D\u0435\u0436\u043E\u043A \u0432 \u0430\u0434\u0441\u043A\u043E\u0439 \u043E\u0433\u043D\u0435\u043D\u043D\u043E\u0439 \u043F\u0435\u0449\u0435\u0440\u0435'  #@param {type: "string"}
#@markdown image size (width/height in tokens, px / 8)

#@markdown width can't be lower than 32
w = 64  #@param {type: "number"}
h =   36#@param {type: "number"}
#@markdown note: re-run the code section whenever you change the resolution (height)
import hashlib

# "number" form fields may arrive as floats. torch.zeros and range need ints.
w, h = int(w), int(h)
if w < 32 or h < 1:
    raise ValueError(f"need width >= 32 and height >= 1 (in tokens), got w={w}, h={h}")

if seed > 0:
    seed_everything(seed)

num_renders =   3#@param {type: "number"}

#@markdown Image quality/match.  It's recommended that you use the defaults here
top_k_ =  1024#@param {type:"integer"}
top_p_ =  .99#@param {type:"number"}
# Stable file-name tag. The built-in hash() of a str is salted per process, so
# names used to change between Colab sessions.
hash_val = hashlib.sha1((text + str(top_k_) + str(top_p_)).encode("utf-8")).hexdigest()[-5:]

pil_images = []
scores = []

for i in range(int(num_renders)):
    new_images, new_scores = generate_images(text, tokenizer, dalle, vae, top_k=top_k_, images_num=1, top_p=top_p_,
                                             h=h, w=w, use_cache=False)
    pil_images += new_images
    scores += new_scores

    if pil_images:
        pil_images[-1].save("sample.png")
        display(pil_images[-1])  # a bare expression inside a loop is never displayed

        sr_images = super_resolution([pil_images[-1]], realesrgan)
        filename = drive_path + f"/{hash_val}-{seed}{i:03}.png"
        sr_images[-1].save(filename)
    else:
        print(f"render {i}: no image was generated, skipping")

    pil_images = []
    sr_images = []
    gc.collect()
    torch.cuda.empty_cache()

  0%|          | 0/2304 [00:00<?, ?it/s]

  dtype=torch.long, device=self.device) // self.image_tokens_per_dim


  0%|          | 0/2304 [00:00<?, ?it/s]