<a href="https://colab.research.google.com/github/pollinations/hive/blob/main/notebooks/1%20Text-To-Image/1%20DALLE-Mega.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://i.imgur.com/a9hByt4.png" width="300">

*Sunset over a lake*

This notebook generates images from a text prompt with DALL·E Mega, the larger sibling of DALL·E mini, and then upscales them with SwinIR.

In [None]:

text_input = 'sunset over a lake'  #@param {type: "string"}
output_path = '/content'

## 🛠️ Installation and set-up

In [None]:
# Install required libraries
!pip install -q dalle-mini
!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

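We load the required models:
* DALL·E (Mega by default, mini as a fallback) to turn text into encoded images
* VQGAN to decode the encoded images
* CLIP for scoring predictions (imported below, although this notebook never runs the scoring step)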

In [None]:
#!git lfs clone https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/
#!git clone https://huggingface.co/flax-community/dalle-mini

In [None]:
import jax
import jax.numpy as jnp

# check how many devices are available
jax.local_device_count()

In [None]:
!wget -c https://ipfs.pollinations.ai/ipfs/QmTo7LHa2U1MuRy1GE41hFGuXxAD5iWcWH4kp9rwpyZvNX -O /content/dallemegamodels.zip

!unzip -n /content/dallemegamodels.zip

In [None]:
# Model references

# dalle-mega
DALLE_MODEL = "/content/artifacts/mega-1-fp16:v14"  # can be wandb artifact or 🤗 Hub or local folder or google bucket
DALLE_COMMIT_ID = None

# if the notebook crashes too often, use dalle-mini instead by uncommenting the line below
# DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"

# VQGAN model
VQGAN_REPO = "/content/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = None

In [None]:
# Load models & tokenizer
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import CLIPProcessor, FlaxCLIPModel

# Load dalle-mini
model, params = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
)

# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
)

In [None]:
from flax.jax_utils import replicate

params = replicate(params)
vqgan_params = replicate(vqgan_params)

Model functions are compiled and parallelized to take advantage of multiple devices.

In [None]:
from functools import partial

# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )


# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)
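
As a minimal sketch of the same pattern on a toy function (the name `p_scale` and its arguments are made up for illustration and are not part of dalle-mini): `jax.pmap` maps a function over the leading axis of its array arguments, one slice per device, while arguments listed in `static_broadcasted_argnums` are treated as compile-time constants shared by all devices.

```python
from functools import partial

import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()


# Hypothetical toy function: argument 1 (`factor`) is static, so it is
# broadcast to every device and baked into the compiled program.
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(1,))
def p_scale(x, factor):
    return x * factor


# The leading axis of `x` must match the number of local devices.
x = jnp.arange(n_devices * 3, dtype=jnp.float32).reshape((n_devices, 3))
y = p_scale(x, 2.0)
print(y.shape)  # (n_devices, 3)
```

This is why `params` and the tokenized prompt need a leading device axis before they are passed to `p_generate`, while `top_k`, `top_p`, `temperature` and `condition_scale` can be passed as plain Python values.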

A separate PRNG key is passed to each device so that every device generates different images.

In [None]:
import random

# create a random key
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)
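
To illustrate, here is a rough sketch of how a single key can be turned into one key per device. It assumes that `shard_prng_key`, which the generation loop uses, amounts to splitting the key once per local device; `device_keys` is a name chosen for this example.

```python
import jax

key = jax.random.PRNGKey(0)  # fixed seed, for this example only

# Carry `key` forward and use `subkey` for the current round.
key, subkey = jax.random.split(key)

# One key per local device (leading axis = device count).
device_keys = jax.random.split(subkey, num=jax.local_device_count())
print(device_keys.shape)
```

Splitting before every round, instead of reusing the same key, is what keeps successive rounds from producing identical samples.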

## 🖍 Text Prompt

Prompts must be tokenized by the model's processor before generation.

In [None]:
from dalle_mini import DalleBartProcessor

processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)

Let's define some text prompts.

In [None]:
prompts = [text_input]

Note: we could use the same prompt multiple times for faster inference.

In [None]:
tokenized_prompts = processor(prompts)

Finally, we replicate the tokenized prompts onto each device.

In [None]:
tokenized_prompt = replicate(tokenized_prompts)
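
Conceptually, replicating gives every array in a pytree a new leading axis whose size equals the number of devices, so that `pmap` can hand one copy to each device. The sketch below only imitates the resulting shapes with plain `jnp.stack`; it is not how `flax.jax_utils.replicate` is implemented (the real function also places each copy on its device), and `toy_batch` is made-up data.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

# Made-up stand-in for a tokenized prompt batch: 1 prompt, 4 tokens.
toy_batch = {
    "input_ids": jnp.array([[0, 11, 12, 2]]),
    "attention_mask": jnp.array([[1, 1, 1, 1]]),
}

# Add a leading device axis to every leaf: (1, 4) -> (n_devices, 1, 4).
stacked = jax.tree_util.tree_map(
    lambda x: jnp.stack([x] * n_devices), toy_batch
)
print({name: leaf.shape for name, leaf in stacked.items()})
```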

## 🎨 Generate images

We generate encoded images with the DALL·E model and decode them with the VQGAN.

In [None]:
# number of predictions per prompt
n_predictions = 4

# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0
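
For intuition about `gen_top_k` and `gen_top_p`, the sketch below applies top-k and top-p (nucleus) filtering to a toy probability list in plain Python, following the general description in the guide linked above. The function `filter_probs` is written for this example and is not dalle-mini's sampling code.

```python
def filter_probs(probs, top_k=None, top_p=None):
    """Keep the most likely tokens and renormalize their probabilities.

    top_k keeps the k most likely tokens; top_p keeps the smallest set of
    most likely tokens whose cumulative probability reaches top_p.
    """
    # Token indices, most likely first.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k is not None:
        order = order[:top_k]
    if top_p is not None:
        kept, mass = [], 0.0
        for i in order:
            kept.append(i)
            mass += probs[i]
            if mass >= top_p:
                break
        order = kept
    total = sum(probs[i] for i in order)
    return {i: probs[i] / total for i in order}


probs = [0.5, 0.3, 0.15, 0.05]
print(sorted(filter_probs(probs, top_k=2)))    # [0, 1]
print(sorted(filter_probs(probs, top_p=0.9)))  # [0, 1, 2]
```

Smaller values of either parameter restrict sampling to fewer candidate tokens, which tends to make outputs less varied.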

In [None]:
from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange

print(f"Prompts: {prompts}\n")
# generate images
images = []
for _ in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key
    key, subkey = jax.random.split(key)
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    # decode images
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        # number files by the running image count so that later rounds
        # do not overwrite the files saved by earlier rounds
        image_path = f"{output_path}/{len(images):04d}.png"
        images.append(img)
        # display(img)
        img.save(image_path)
        print("saving image", len(images) - 1, "to", image_path)
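
The loop above runs `max(n_predictions // jax.device_count(), 1)` rounds, and each round yields one image per prompt on every device, so the total number of images can differ from `n_predictions`. A small sketch of that arithmetic (the helper `images_generated` is hypothetical, and it assumes the prompts are replicated to every device as in this notebook):

```python
def images_generated(n_predictions: int, n_devices: int, n_prompts: int = 1) -> int:
    """Total images produced by the generation loop for the given settings."""
    rounds = max(n_predictions // n_devices, 1)
    return rounds * n_devices * n_prompts


print(images_generated(4, 1))  # 4
print(images_generated(4, 8))  # 8
print(images_generated(6, 4))  # 4
```

On a single GPU the count matches `n_predictions`; on multi-device hardware it is rounded to a multiple of the device count.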

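## 🔍 Upscale images

We upscale the saved images with SwinIR, using its x4 real-world super-resolution weights.

In [None]:
import torch

# assumption: `gpu` was meant to hold the GPU name; it is an empty string
# when no GPU is available
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else ''

if 'A100' in gpu:
  torch.backends.cudnn.enabled = False
  print('Finished setup for A100')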

In [None]:
%cd /content
!git clone https://github.com/voodoohop/SwinIR
%cd /content/SwinIR

In [None]:
!wget -c "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.pth" -P experiments/pretrained_models

#!pip install "torchvision==0.9.0" "torch==1.8.0" "numpy==1.19.4" "opencv-python==4.4.0.46" "tqdm==4.62.2" "Pillow==8.3.2" "timm==0.4.12"
!pip install timm


In [None]:
#!pip install cog redis

from glob import glob

from tqdm import tqdm
from predict import Predictor

# set up the SwinIR predictor (predict.py comes from the cloned repo)
p = Predictor()
p.setup()

# upscale every generated image and overwrite it in place
image_files = sorted(glob(f"{output_path}/*.png"))

for image_file in tqdm(image_files):
    path = p.predict(image_file)
    !cp "{path}" "{image_file}"
