# **ruDALLE** <font size="+2">🥑</font>

ruDALL-E is a 1.3-billion-parameter Russian-language recreation of OpenAI's DALL·E: a model that generates arbitrary images from a text prompt describing the desired result.

The generation pipeline consists of ruDALL-E itself, ruCLIP for ranking the results, and a super-resolution model. Prompts can be written in English and automatically translated into Russian before being passed to ruDALL-E.

The model was trained by the Sber AI and SberDevices teams.

The English→Russian translator is Facebook FAIR's WMT'19 model (`facebook/wmt19-en-ru`).



Notebook changes by Philipuss#4066: added a GUI and cleaned up the code, added an English→Russian translator, added a super-resolution script (taken from [this notebook](https://colab.research.google.com/drive/1_evIaGmeo-RXWJrmavB3hR4UBhJJV5xL) made by danielrussruss#6125), and merged in the image-prompt notebook from the official ruDALL-E repo.

In [None]:
#@title GPU Info <font size="+2">📊</font>
import multiprocessing
import torch
from psutil import virtual_memory

ram_gb = round(virtual_memory().total / 1024**3, 1)

print('CPU:', multiprocessing.cpu_count())
print('RAM GB:', ram_gb)
print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("cuDNN version:", torch.backends.cudnn.version())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device.type)

!nvidia-smi

In [None]:
#@title <font color="lightgreen" size="+3">←</font> **Installing and initializing libraries** <font size="+2">🎡</font>

!pip install rudalle           > /dev/null
!pip install transformers      > /dev/null

from rudalle.pipelines import generate_images, show, super_resolution, cherry_pick_by_clip
from rudalle.image_prompts import ImagePrompts
from rudalle import get_rudalle_model, get_tokenizer, get_vae, get_realesrgan, get_ruclip
from rudalle.utils import seed_everything
import requests
from PIL import Image
import torch

from transformers import FSMTForConditionalGeneration, FSMTTokenizer

device = 'cuda'
dalle = get_rudalle_model('Malevich', pretrained=True, fp16=True, device=device)
try:
    realesrgan, tokenizer, ruclip, ruclip_processor
except NameError:
    realesrgan = get_realesrgan('x4', device=device)
    tokenizer = get_tokenizer()
    vae = get_vae().to(device)
    ruclip, ruclip_processor = get_ruclip('ruclip-vit-base-patch32-v5')
    ruclip = ruclip.to(device)



mname = "facebook/wmt19-en-ru"
enru_tokenizer = FSMTTokenizer.from_pretrained(mname)
enru_model = FSMTForConditionalGeneration.from_pretrained(mname)


import torch
import torchvision
import transformers
import more_itertools
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from rudalle import utils

import math
import torch
import torch.nn.functional as F
from rudalle.dalle.utils import divide, split_tensor_along_last_dim

@torch.jit.script
def gelu_impl(x):
    """Tanh approximation of GELU, following OpenAI's implementation.

    Evaluates 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3))),
    with the cubic term written in the factored form x * (1 + 0.044715 * x * x).
    0.7978845608028654 is sqrt(2/pi).
    """
    tanh_arg = 0.7978845608028654 * x * (1.0 + 0.044715 * x * x)
    return 0.5 * x * (1.0 + torch.tanh(tanh_arg))


def gelu(x):
    """Python-level entry point for the TorchScript-compiled `gelu_impl`."""
    activated = gelu_impl(x)
    return activated


In [None]:
#@title # <font color="lightgreen" size="+3">←</font> **Generate Images** <font size="+2">🎨</font>
#@markdown Type your prompt and run this cell to let the magic happen
import io
import random
from tqdm.auto import tqdm

prompt = "A photo of a male Philipuss. Pencil sketch." #@param {type:"string"}
translate_english_to_russian = True #@param {type:"boolean"}
image_count =  6#@param {type:"number"}
seed = -1 #@param {type:"number"}

#@markdown ---

#@markdown **ADVANCED SETTINGS**

num_resolutions = 7 #@param {type:"integer"}
top_K = 1024 #@param {type:"number"}
top_P = 0.92 #@param {type:"number"}
#@markdown ---
#@markdown **Image Prompt Settings**
image_prompt = "" #@param {type:"string"}
crop_up = 10 #@param {type:"number"}
crop_left = 0 #@param {type:"number"}
crop_right = 0 #@param {type:"number"}
crop_down = 0 #@param {type:"number"}



#########################################################
# Generates `image_count` images for `prompt` (optionally translated to Russian
# first) and collects them in `pil_images` for the ranking/upscaling cells.
# If `image_prompt` is set (a local path or an http(s) URL), that image is used
# as an ImagePrompts prompt with the `crop_*` border settings.

if translate_english_to_russian:
    input_ids = enru_tokenizer.encode(prompt, return_tensors="pt")
    outputs = enru_model.generate(input_ids)
    prompt = enru_tokenizer.decode(outputs[0], skip_special_tokens=True)

print(prompt)
text = prompt

if seed == -1:
    #Thanks to kendrick#9537 for spotting the problem
    seed = random.randint(0, 2**32-1)
# Show the seed so a good result can be reproduced by entering it above.
print('seed:', seed)

# List of (top_k, top_p, images_num) sampling passes. The slice keeps the
# legacy `num_resolutions` knob meaningful (0 disables generation).
sampling_passes = [(top_K, top_P, image_count)][:num_resolutions]

pil_images = []
scores = []

# Seed once for both modes. The image-prompt mode used a fixed 42 before,
# which silently ignored the user's seed.
seed_everything(seed)

if image_prompt == "":
    for top_k, top_p, images_num in tqdm(sampling_passes):
        _pil_images, _scores = generate_images(text, tokenizer, dalle, vae, top_k=top_k, images_num=images_num, top_p=top_p)
        pil_images += _pil_images
        scores += _scores
        show([pil_image for pil_image, score in sorted(zip(pil_images, scores), key=lambda x: -x[1])], 6)
else:
    if image_prompt.startswith(('http://', 'https://')):
        response = requests.get(image_prompt, timeout=30)
        response.raise_for_status()
        prompt_source = io.BytesIO(response.content)
    else:
        prompt_source = image_prompt
    # convert('RGB') so RGBA / grayscale files still yield 3 channels for the VAE.
    prompt_pil = Image.open(prompt_source).convert('RGB').resize((256, 256))
    borders = {'up': crop_up, 'left': crop_left, 'right': crop_right, 'down': crop_down}
    # Separate name so the `image_prompt` form value is not shadowed.
    prompt_obj = ImagePrompts(prompt_pil, borders, vae, torch.device(device), crop_first=True)

    for top_k, top_p, images_num in tqdm(sampling_passes):
        _pil_images, _ = generate_images(
            text,
            tokenizer,
            dalle,
            vae,
            top_k=top_k,
            images_num=images_num,
            image_prompts=prompt_obj,
            top_p=top_p,
        )
        pil_images += _pil_images
    show(pil_images, 6)

In [None]:
#@title <font color="lightgreen" size="+3">←</font> **Show Best Outputs** <font size="+2">🥇</font>

top_images_count = -1 #@param {type:"number"}

# Rank the generated images against the prompt with ruCLIP and display the
# best `top_images_count` of them (-1 means every generated image).
top_images_count = image_count if top_images_count == -1 else top_images_count

top_images, clip_scores = cherry_pick_by_clip(
    pil_images,
    prompt,
    ruclip,
    ruclip_processor,
    device=device,
    count=top_images_count,
)
show(top_images, top_images_count)

In [None]:
#@title <font color="lightgreen" size="+3">←</font> **Increase Resolution** <font size="+2">✨</font>

upscale_images_count = -1 #@param {type:"number"}

# Pick the best `upscale_images_count` images by ruCLIP score (-1 = every
# generated image), upscale them with Real-ESRGAN x4, save them under
# dalle_outputs/<timestamp>/ and display them best-first.
if upscale_images_count == -1:
    upscale_images_count = image_count

import os
import time
from IPython import display

top_images, clip_scores = cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device=device, count=upscale_images_count)

# Reuse the upscaler loaded by the initialization cell. Previously it was
# reloaded onto the GPU on every run; now it is loaded only if missing.
try:
    realesrgan
except NameError:
    realesrgan = get_realesrgan('x4', device=device)

sr_images = super_resolution(top_images, realesrgan)
timestring = time.strftime('%Y%m%d%H%M%S')
out_dir = os.path.join('dalle_outputs', timestring)
os.makedirs(out_dir, exist_ok=True)

for pil_image, score in sorted(zip(sr_images, clip_scores), key=lambda x: -x[1]):
    pil_image.save(os.path.join(out_dir, f'rudalle_{score}.png'))
    print(f'CLIP score: {score}')
    display.display(pil_image)