In [2]:
import jax
import jax.numpy as jnp

# Model references used by the loading cell.
#
# Each name is assigned exactly once. Previously this cell assigned every name
# twice, so the first (local-path) set was silently overwritten by the second.
# The alternatives are kept below as comments; to use one, replace the active
# value rather than adding a second assignment.

# dalle-mega (the reference name suggests fp16 weights).
# Alternatives:
#   local checkpoint: DALLE_MODEL = "../datasets/dalle-mini-models/dallebart"
#   smaller model:    DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"
DALLE_COMMIT_ID = None  # passed as `revision=` when loading; None pins nothing

# VQGAN model
# Alternative:
#   local checkpoint: VQGAN_REPO = "../datasets/dalle-mini-models/vqgan-jax"
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"


ModuleNotFoundError: jax requires jaxlib to be installed. See https://github.com/google/jax#installation for installation instructions.

In [3]:
from ipywidgets import FloatProgress as IProgress
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import CLIPProcessor, FlaxCLIPModel


# utils.logging.disable_progress_bar()
# DALL-E Bart: request float16 and skip in-constructor initialisation
# (`_do_init=False`); the call hands back the module and its parameters
# as a pair.
model, params = DalleBart.from_pretrained(
    DALLE_MODEL,
    revision=DALLE_COMMIT_ID,
    dtype=jnp.float16,
    _do_init=False,
)

# VQGAN: same (module, parameters) pair, pinned to VQGAN_COMMIT_ID.
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO,
    revision=VQGAN_COMMIT_ID,
    _do_init=False,
)

ModuleNotFoundError: No module named 'ipywidgets'

In [None]:
from flax.jax_utils import replicate

# Pass both parameter trees through `replicate` so the pmapped functions can
# consume them.
# NOTE(review): both names are rebound to the replicated result, so re-running
# this cell would feed already-replicated values back into `replicate`.
# Run it once per kernel (re-run the loading cell first if needed).
params, vqgan_params = replicate(params), replicate(vqgan_params)

In [None]:
from functools import partial

# pmapped sampling step
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    """Run ``model.generate`` under ``jax.pmap`` (axis name ``"batch"``).

    ``tokenized_prompt``, ``key`` and ``params`` are the mapped arguments.
    The arguments at positions 3-6 (``top_k``, ``top_p``, ``temperature``,
    ``condition_scale``) are declared static/broadcast in the decorator.

    ``tokenized_prompt`` is unpacked as keyword arguments, so it must be a
    mapping. Returns whatever ``model.generate`` returns.
    """
    sampling_options = {
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "condition_scale": condition_scale,
    }
    return model.generate(
        **tokenized_prompt, prng_key=key, params=params, **sampling_options
    )
# pmapped VQGAN decoding step
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    """Decode ``indices`` with ``vqgan.decode_code`` under ``jax.pmap``.

    Both arguments are mapped by pmap. ``params`` is forwarded to the VQGAN
    as its ``params`` keyword. Returns the decoder output unchanged.
    """
    decoded = vqgan.decode_code(indices, params=params)
    return decoded

In [None]:
# Create the processor piece by piece

from dalle_mini.model.configuration import DalleBartConfig
from dalle_mini.model.text import TextNormalizer
from dalle_mini.model.tokenizer import DalleBartTokenizer
from dalle_mini.model.utils import PretrainedFromWandbMixin

# Fetch the tokenizer and the model config from the same reference, then build
# the processor from the tokenizer plus the config's two text settings.
# NOTE(review): the following cell assigns `processor` again through
# DalleBartProcessor.from_pretrained, replacing the instance built here.
tokenizer = DalleBartTokenizer.from_pretrained(
    'dalle-mini/dalle-mega'
)
config = DalleBartConfig.from_pretrained(
    'dalle-mini/dalle-mega'
)
processor = DalleBartProcessor(
    tokenizer,
    config.normalize_text,
    config.max_text_length,
)

In [None]:
# Download all model files (~5 GB)
from dalle_mini import DalleBartProcessor

# Build the text processor from the same model reference and revision that the
# loading cell uses for the DALL-E weights.
processor = DalleBartProcessor.from_pretrained(
    DALLE_MODEL,
    revision=DALLE_COMMIT_ID,
)

In [None]:
# Text prompts to render; the whole list is tokenized in one processor call.
prompts = [
    "fine art painting of a foolish samurai warrior wielding a magic sword stepping forth to oppose the evil that is Aku",
    "Barack Obama holding up the World Cup trophy",
    "Obi Wan Kenobi standing over lava with a lightsaber",
]

In [None]:
# Run the text processor over the prompt list, then pass the result through
# `replicate`. Note the two names: `tokenized_prompts` (plural) is the raw
# processor output, while `tokenized_prompt` (singular) is the replicated value
# that the generation loop hands to p_generate.
tokenized_prompts = processor(prompts)
tokenized_prompt = replicate(tokenized_prompts)

In [None]:
# Generation settings read by the sampling loop.
# n_predictions is floor-divided by jax.device_count() to get the number of
# p_generate rounds; at least one round always runs.
n_predictions = 8

# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
# The three None values are forwarded unchanged to model.generate as top_k,
# top_p and temperature — presumably selecting the library defaults; TODO confirm.
gen_top_k = None
gen_top_p = None
temperature = None
# Forwarded to model.generate as `condition_scale`.
cond_scale = 10.0

In [None]:
from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange

import random

# Create the PRNG key consumed by the loop below. No visible cell defines
# `key`, so `jax.random.split(key)` raised NameError on a fresh kernel.
# The seed is printed so that a particular run can be reproduced by
# hard-coding it here.
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)

print(f"Prompts: {prompts}\n")
print(f"Seed: {seed}\n")
# generate images
images = []
# One p_generate round per iteration; always run at least one round, even when
# n_predictions is smaller than the device count.
for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key: carry `key` forward, spend `subkey` on this round
    key, subkey = jax.random.split(key)
    # generate images (the key is sharded before being handed to the pmapped function)
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS: drop the first position along the last axis
    encoded_images = encoded_images.sequences[..., 1:]
    # decode images
    decoded_images = p_decode(encoded_images, vqgan_params)
    # Clamp to [0, 1] and flatten all leading axes into a single image axis.
    # NOTE(review): assumes the decoder emits 256x256 RGB images — confirm if
    # the VQGAN checkpoint changes.
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
        # scale to 0-255 and convert to uint8 for PIL
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        images.append(img)
        display(img)
        print()