This is a JAX implementation of the VQGAN based on this: 

https://github.com/patil-suraj/vqgan-jax

https://huggingface.co/flax-community/vqgan_f16_16384

However, I could not run it with a batch larger than 18 images, even though I tried every combination of parameters I could think of: it consistently fails with an out-of-memory (OOM) error. By contrast, I managed to run the PyTorch implementation with a batch size of 32, so I did not end up using JAX at all. I am therefore leaving this notebook here just in case it is useful.

In [None]:
import io

import requests
from PIL import Image
import numpy as np
from datetime import datetime
import os
# os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.08'
# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] ='platform'
# os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
# os.environ['XLA_FLAGS'] = '--xla_gpu_strict_conv_algorithm_picker=false'

In [None]:
from vqgan_jax.modeling_flax_vqgan import VQModel

# Checkpoint to load: the repository id plus an exact commit, so the weights
# stay the same even if the repository is updated later.
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
# The result is unpacked into two names (model, parameters).
# NOTE(review): with `_do_init=False` commented out below, `from_pretrained`
# presumably returns only the model, in which case this two-name unpack would
# fail -- confirm against the installed vqgan_jax / transformers versions.
model, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO, 
    revision=VQGAN_COMMIT_ID,
    # _do_init=False
)

In [None]:
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

In [None]:
def download_image(url, timeout=30):
    """Download an image over HTTP(S) and open it with PIL.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server, passed straight to
            ``requests.get``. ``None`` disables the limit.

    Returns:
        The PIL image opened from the response body.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within ``timeout``.
    """
    # requests has no default timeout, so an unresponsive server would
    # otherwise block this call (and the notebook cell) forever.
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    # The whole body is already in memory; wrap it so PIL can read it like a file.
    return Image.open(io.BytesIO(resp.content))

def preprocess_vqgan(x):
    """Apply the affine map ``2 * x - 1`` to ``x``.

    With inputs in ``[0, 1]`` the result lies in ``[-1, 1]``. Works on scalars
    and on anything supporting arithmetic with floats (e.g. NumPy arrays).
    """
    doubled = 2. * x
    return doubled - 1.

def custom_to_pil(x):
    """Convert an array with values nominally in ``[-1, 1]`` to an RGB PIL image.

    Values outside ``[-1, 1]`` are clipped, the range is mapped to ``[0, 255]``
    and the result is cast to ``uint8`` (the cast truncates, it does not round).
    The image is converted to RGB if PIL did not already open it in that mode.
    """
    unit_range = (np.clip(x, -1., 1.) + 1.) / 2.
    pixels = (255 * unit_range).astype(np.uint8)
    pil_image = Image.fromarray(pixels)
    return pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")

def preprocess(img, size=512):
    """Turn a PIL image into a single-image, channels-last NumPy batch.

    Args:
        img: PIL image in any mode.
        size: Side length in pixels of the square the image is resized to.
            The aspect ratio is not preserved.

    Returns:
        Array of shape ``(1, size, size, 3)``: ``ToTensor`` yields a
        channels-first tensor, a leading batch axis is added, and the channel
        axis is then moved last.
    """
    # Grayscale, palette or alpha-carrying images (e.g. WebP with transparency)
    # would otherwise yield 1 or 4 channels instead of the 3 used downstream.
    if img.mode != "RGB":
        img = img.convert("RGB")
    img = TF.resize(img, (size, size), interpolation=Image.LANCZOS)
    # ToTensor gives (C, H, W); add the batch axis in front.
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    # (N, C, H, W) -> (N, H, W, C)
    return img.permute(0, 2, 3, 1).numpy()

In [None]:
# Load one local sample image and run it through `preprocess`.
# NOTE(review): absolute, machine-specific path (it looks like a local
# DiffusionDB copy) -- adjust before running elsewhere.
test_img_path = '/mnt/home/data/diffusiondb_img/part-013448/cfb9cd9a-84f4-402a-9547-060654a1e9a3.webp'
image = Image.open(test_img_path)
# `image` is rebound from the PIL image to the array returned by `preprocess`.
image = preprocess(image)

In [None]:
# Synthetic batch of 30 random 512x512 3-channel "images" with values in
# [0, 1), used as model input in place of real data.
# NOTE(review): the notebook introduction reports OOM for batches above 18, so
# this size presumably reproduces that failure -- confirm.
image_batch = np.random.rand(30, 512, 512, 3)

In [None]:
# Round-trip the random batch through the model: encode, then decode the
# first value returned by `encode`.
# NOTE(review): if the model is loaded with `_do_init=False`, these calls
# presumably need `params=vqgan_params` -- confirm against the vqgan_jax API.
quant_states, indices = model.encode(image_batch)
rec = model.decode(quant_states)

In [None]:
# Display the preprocessed test image: rescale with `preprocess_vqgan`, then
# convert to a PIL image (rendered as the cell output).
custom_to_pil(preprocess_vqgan(image[0]))

In [None]:
# Display the first decoded sample. Note that `rec` was decoded from the
# random `image_batch`, not from `image`, so this is not a reconstruction of
# the test image shown in the previous cell.
custom_to_pil(preprocess_vqgan(np.asarray(rec[0])))