<a href="https://colab.research.google.com/github/SaulLu/google_colab_notebooks/blob/main/jax_vqgan_clip_clean_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
!nvidia-smi

Tue Jul 13 15:27:16 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   46C    P0    26W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Profiling

Install graphviz

In [3]:
# !sudo apt install graphviz

Install go

In [4]:
# !add-apt-repository ppa:longsleep/golang-backports -y
# !apt update
# !apt install golang-go
# %env GOPATH=/root/go
# !go get -u github.com/gopherdata/gophernotes
# !cp ~/go/bin/gophernotes /usr/bin/
# !mkdir /usr/local/share/jupyter/kernels/gophernotes
# !cp ~/go/src/github.com/gopherdata/gophernotes/kernel/* \
#        /usr/local/share/jupyter/kernels/gophernotes

Install the profiler (pprof)

In [5]:
# !go get -u github.com/google/pprof

# Init

In [6]:
!pip install git+https://github.com/SaulLu/vqgan-jax@create-package 

Collecting git+https://github.com/SaulLu/vqgan-jax@create-package
  Cloning https://github.com/SaulLu/vqgan-jax (to revision create-package) to /tmp/pip-req-build-eesxj8rx
  Running command git clone -q https://github.com/SaulLu/vqgan-jax /tmp/pip-req-build-eesxj8rx
  Running command git checkout -b create-package --track origin/create-package
  Switched to a new branch 'create-package'
  Branch 'create-package' set up to track remote branch 'create-package' from 'origin'.
Building wheels for collected packages: vqgan-jax
  Building wheel for vqgan-jax (setup.py) ... [?25l[?25hdone
  Created wheel for vqgan-jax: filename=vqgan_jax-0.0.1-cp37-none-any.whl size=7467 sha256=14c24e0906b5c04c306d37f012e79ac13589c67dc25d3163bf8ed2d6daa8a5f1
  Stored in directory: /tmp/pip-ephem-wheel-cache-y2vn584a/wheels/81/fe/b0/d77c661ac2ddc2753719df436c8a6f744d0e23b569eb19d507
Successfully built vqgan-jax


In [7]:
!pip install transformers ftfy wandb Pillow



In [8]:
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33msaullu[0m (use `wandb login --relogin` to force relogin)


In [9]:
import wandb

# Core

In [None]:
import argparse
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax import core, struct
from flax.training.common_utils import get_metrics
from jax import custom_vjp
from PIL import Image
from torchvision.transforms import functional as TF
from transformers import (
    CLIPFeatureExtractor,
    CLIPProcessor,
    CLIPTokenizer,
    CLIPTokenizerFast,
    FlaxCLIPModel,
    is_tensorboard_available,
)
from vqgan_jax.modeling_flax_vqgan import VQModel


class TrainState(struct.PyTreeNode):
    """Simple train state for the common case with a single Optax optimizer.

    Adapted from Flax's ``TrainState`` but WITHOUT an ``apply_fn`` field: it only
    tracks the step counter, the parameters being optimised and the optimizer.

    Synopsis::

        state = TrainState.create(
            params=variables['params'],
            tx=tx)
        grad_fn = jax.grad(loss_fn)
        for batch in data:
            grads = grad_fn(state.params, batch)
            state = state.apply_gradients(grads=grads)

    Note that you can easily extend this dataclass by subclassing it for storing
    additional data (e.g. additional variable collections). Extra keyword
    arguments given to ``create`` / ``apply_gradients`` must name fields that such
    a subclass declares; otherwise construction fails.

    For more exotic usecases (e.g. multiple optimizers) it's probably best to
    fork the class and modify it.

    Args:
        step: Counter starts at 0 and is incremented by every call to
            `.apply_gradients()`.
        params: The parameters to be updated by `tx` (any pytree; in this
            notebook, the latent array being optimised).
        tx: An Optax gradient transformation. Marked ``pytree_node=False`` so it
            is treated as static metadata rather than a pytree leaf.
        opt_state: The state for `tx`.
    """

    step: int
    params: core.FrozenDict[str, Any]
    tx: optax.GradientTransformation = struct.field(pytree_node=False)
    opt_state: optax.OptState

    def apply_gradients(self, *, grads, **kwargs):
        """Updates `step`, `params`, `opt_state` and `**kwargs` in return value.

        Note that internally this function calls `.tx.update()` followed by a call
        to `optax.apply_updates()` to update `params` and `opt_state`.

        Args:
            grads: Gradients that have the same pytree structure as `.params`.
            **kwargs: Additional dataclass attributes that should be `.replace()`-ed.

        Returns:
            An updated instance of `self` with `step` incremented by one, `params`
            and `opt_state` updated by applying `grads`, and additional attributes
            replaced as specified by `kwargs`.
        """
        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)
        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=new_opt_state,
            **kwargs,
        )

    @classmethod
    def create(cls, *, params, tx, **kwargs):
        """Creates a new instance with `step=0` and `opt_state = tx.init(params)`."""
        opt_state = tx.init(params)
        return cls(
            step=0,
            params=params,
            tx=tx,
            opt_state=opt_state,
            **kwargs,
        )


@dataclass
class ModelArguments:
    """Model-related options: CLIP / VQGAN checkpoints, tokenizer and dtype.

    Not wired in yet: the cells below configure the run through ``args``.
    """

    # Checkpoint used to initialise the CLIP model.
    clip_model_name_or_path: Optional[str] = field(
        default="openai/clip-vit-base-patch32",
        metadata={
            "help": "The model checkpoint for weights initialization of CLIP model."
        },
    )
    # Checkpoint used to initialise the VQGAN model.
    vqgan_model_name_or_path: Optional[str] = field(
        default="valhalla/vqgan-imagenet-f16-1024",
        metadata={
            "help": "The model checkpoint for weights initialization of VQGAN model."
        },
    )
    # The original help text appended ", ".join(MODEL_TYPES), but MODEL_TYPES is
    # never defined, so defining this class raised NameError.
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type."},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    # Stored as a string; the three names listed in the help text are expected.
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
        },
    )

@dataclass
class TrainingArguments:
    """Placeholder for training options; intentionally declares no fields yet."""

# Run configuration. A plain argparse.Namespace stands in for parsed command-line
# flags so the notebook runs top to bottom. Field order is preserved, so
# `vars(args)` lists the options in the same order as before.
args = argparse.Namespace(
    **{
        "prompts": ["superrealistic house in forest"],
        "output_dir": "logs",
        "image_prompts": [],
        "noise_prompt_seeds": [],
        "noise_prompt_weights": [],
        # (width, height) in pixels; rounded down to a multiple of the VQGAN stride later.
        "size": [480, 480],
        "init_image": None,
        "init_weight": 0.0,
        # Checkpoint names passed to `from_pretrained`.
        "clip_model": "openai/clip-vit-base-patch32",
        "vqgan_model": "valhalla/vqgan-imagenet-f16-1024",
        "step_size": 0.05,  # Adam learning rate
        "cutn": 5,  # random crops per training step
        "cut_pow": 1.0,
        "display_freq": 50,
        "seed": 0,
        "scale_min": 1.0,
        "scale_max": 1.0,
    }
)


# Start a Weights & Biases run, passing `args` as the run config.
wandb.init(project="test-vqgan-clip", config=args)


# VQGAN: quantizes / decodes the latent that is being optimised (see train_step).
model = VQModel.from_pretrained(args.vqgan_model)

# CLIP: tokenizer for the text prompts and the Flax model ("perceptor") that
# embeds images and text. `processor` is not used in the visible cells.
tokenizer = CLIPTokenizer.from_pretrained(args.clip_model)
processor = CLIPProcessor.from_pretrained(args.clip_model)
perceptor = FlaxCLIPModel.from_pretrained(args.clip_model)


# Padded token length for the prompts. Hard-coded; TODO: presumably this should be
# read from the tokenizer / CLIP text config instead.
context_length = 77  # todo fix that


# Side length of the square crops fed to CLIP (the CLIP vision input size).
cut_size = perceptor.config.vision_config.image_size  # perceptor.visual.input_resolution
# Latent channel size: the last axis of the optimised latent `z`.
e_dim = model.config.embed_dim  # model.quantize.e_dim

# Pixel-to-token stride of the VQGAN: the requested image size is divided by f
# to obtain the token grid.
f = 2 ** (model.config.num_resolutions - 1)

# Codebook size; initial codes are drawn uniformly from [0, n_toks).
n_toks = model.config.n_embed
toksX, toksY = args.size[0] // f, args.size[1] // f
# Image size actually produced (args.size rounded down to a multiple of f);
# not used in the visible cells.
sideX, sideY = toksX * f, toksY * f

# Per-dimension min / max over the codebook vectors (reduced along axis 0 of the
# embedding table). Only referenced by the disabled clamp in the training loop.
z_min = jnp.min(model.params["quantize"]["embedding"]["embedding"], axis=0)
z_max = jnp.max(model.params["quantize"]["embedding"]["embedding"], axis=0)


def parse_prompt(prompt):
    """Split a ``"text[:weight[:stop]]"`` prompt into ``(text, weight, stop)``.

    The split is taken from the right, so colons inside the text survive as long
    as weight / stop are given explicitly. Missing fields default to a weight of
    ``1.0`` and a stop of ``-inf``. A non-numeric weight/stop raises ``ValueError``.
    """
    pieces = prompt.rsplit(":", 2)
    defaults = ["", "1", "-inf"]
    pieces.extend(defaults[len(pieces):])
    text, weight, stop = pieces
    return text, float(weight), float(stop)


# Keep only the text part of each prompt. parse_prompt also yields a weight and a
# stop value, but they are currently dropped: the loss in train_step takes a plain
# mean over all prompts.
texts = []
for prompt in args.prompts:
    txt, weight, stop = parse_prompt(prompt)
    texts.append(txt)

# Tokenize every prompt, padded to `context_length`, as JAX arrays. This dict is
# forwarded as `**batch` to the CLIP model in train_step.
inputs = tokenizer(texts, padding="max_length", max_length=context_length, return_tensors="jax")
# Display the tokenized batch.
inputs




# Clamp to [0, 1] with a custom gradient (registered with jax.custom_vjp below).
@custom_vjp
def clip_with_grad(x):
    """Clamp ``x`` to ``[0, 1]`` while keeping a useful gradient outside that range.

    The forward value equals ``jnp.clip(x, 0, 1)``. The backward pass lets the
    cotangent through everywhere inside the range. Outside the range it only
    passes when a gradient-descent step would move ``x`` back toward the range;
    otherwise it is zeroed.
    """
    return jnp.clip(x, 0, 1)


def clip_with_grad_fwd(x):
    """Forward rule: the clamped value, with the raw input kept as the residual."""
    return clip_with_grad(x), x


def clip_with_grad_bwd(x, y_bar):
    """Backward rule: gate the incoming cotangent ``y_bar`` element-wise.

    Returns a 1-tuple with the cotangent for ``x``.
    """
    # 0 inside [0, 1]; signed distance past the violated bound outside it.
    overshoot = x - clip_with_grad(x)
    # heaviside(v, 1) is 1 for v >= 0 and 0 for v < 0: keep the gradient unless
    # y_bar and the overshoot disagree in sign.
    keep = jnp.heaviside(y_bar * overshoot, 1)
    return (y_bar * keep,)


clip_with_grad.defvjp(clip_with_grad_fwd, clip_with_grad_bwd)


def resample(input, size, align_corners=True):
    """Bicubic resize of ``input`` to exactly the shape ``size`` (``jax.image.resize``).

    ``align_corners`` is accepted so existing call sites keep working, but it has
    no effect.
    """
    del align_corners  # unused
    return jax.image.resize(input, shape=size, method="bicubic")


def random_resized_crop(img, rng, shape, n_subimg, cut_pow=1.0):
    """Take ``n_subimg`` random square crops of ``img`` and resize each to ``shape``.

    Args:
        img: image batch with height / width on axes 2 and 3 (channels on axis 1,
            as produced by the caller's ``moveaxis``). Values are expected in [0, 1].
        rng: JAX PRNG key for the crop sizes and offsets. It must not be a
            differentiated input, because the draws are converted to Python ints.
        shape: ``(height, width)`` every crop is resized to.
        n_subimg: number of crops; must be >= 1.
        cut_pow: exponent applied to the uniform draw that picks the crop size
            (1.0 = uniform between the smallest and largest size; larger values
            favour small crops).

    Returns:
        ``(crops, metrics)``. ``crops`` is every resized crop concatenated along
        axis 0 and clamped with ``clip_with_grad``. ``metrics`` maps
        ``"cutout {j}"`` to a ``wandb.Image`` preview of crop ``j``.

    Raises:
        ValueError: if ``n_subimg < 1``.
    """
    if n_subimg < 1:
        raise ValueError(f"n_subimg must be >= 1, got {n_subimg}")

    side_y, side_x = img.shape[2:4]
    max_size = min(side_x, side_y)
    min_size = min(side_x, side_y, shape[0])
    # Keep all leading axes and replace the two spatial ones. (jax.ops.index_update,
    # used before, has been removed from JAX.)
    target_shape = tuple(img.shape[:-2]) + (shape[0], shape[1])
    cutouts = []
    metrics = {}

    for j in range(n_subimg):
        # Crop side drawn in [min_size, max_size). The previous
        # randint(..., maxval=1.0) always returned 0, so every crop had min_size.
        rng, subrng = jax.random.split(rng)
        fraction = float(jax.random.uniform(subrng, ())) ** cut_pow
        size = int(fraction * (max_size - min_size) + min_size)

        rng, subrng = jax.random.split(rng)
        offset_x = int(jax.random.randint(subrng, (), 0, side_x - size + 1))

        rng, subrng = jax.random.split(rng)
        offset_y = int(jax.random.randint(subrng, (), 0, side_y - size + 1))

        cutout = img[:, :, offset_y : offset_y + size, offset_x : offset_x + size]
        cutout = resample(cutout, target_shape)
        cutouts.append(cutout)

        # Logging preview. Detach from autodiff first: numpy cannot convert a JAX
        # tracer, and this function is called from inside jax.value_and_grad.
        # Bicubic resampling can overshoot [0, 1], so clamp before the uint8 cast.
        preview = jnp.clip(jax.lax.stop_gradient(cutout[0]), 0, 1)
        preview = np.moveaxis(np.asarray((preview * 255).astype(np.uint8)), 0, -1)
        metrics[f"cutout {j}"] = wandb.Image(Image.fromarray(preview))

    imgs_stacked = jnp.concatenate(cutouts, axis=0)
    return clip_with_grad(imgs_stacked), metrics


def train_step(rng, state, batch, n_subimg):
    """Run one optimisation step on the VQGAN latent held in ``state.params``.

    Steps performed:

    1. Snap the latent to the codebook with a straight-through estimator.
    2. Decode it and map the result from ``[-1, 1]`` to ``[0, 1]``.
    3. Cut it into ``n_subimg`` random crops.
    4. Score the crops against the tokenized prompts in ``batch`` with CLIP.

    The loss is the mean over all (crop, prompt) pairs of
    ``2 * arcsin(||e_img - e_txt|| / 2) ** 2``. This is a squared spherical
    distance, assuming CLIP returns unit-norm embeddings.

    Args:
        rng: JAX PRNG key for this step (drives the random crops).
        state: ``TrainState`` whose ``params`` is the latent being optimised.
        batch: tokenized text inputs, forwarded as keyword arguments to CLIP.
        n_subimg: number of random crops per step.

    Returns:
        ``(new_state, metrics)``. ``metrics`` holds the crop previews from
        ``random_resized_crop`` plus:

        - ``"loss"``;
        - ``"step"``: the step count *before* this update;
        - ``"image"``: a preview of the first decoded image.
    """

    def snap_to_codebook(latent):
        # Straight-through estimator: the forward pass sees the quantized latent,
        # the backward pass treats quantization as the identity.
        return latent + jax.lax.stop_gradient(model.quantize(latent)[0] - latent)

    def loss_fn(params, key):
        # NOTE(review): assumes model.decode is deterministic — confirm.
        decoded = clip_with_grad((model.decode(snap_to_codebook(params)) + 1) / 2)
        # Move axis 2 -> 3 and axis 1 -> 2, so the original last axis lands on
        # axis 1 (presumably NHWC -> NCHW) as random_resized_crop expects.
        decoded_for_crops = jnp.moveaxis(decoded, (2, 1), (3, 2))

        _, crop_key = jax.random.split(key)
        crops, crop_metrics = random_resized_crop(
            decoded_for_crops, crop_key, shape=(cut_size, cut_size), n_subimg=n_subimg
        )

        clip_out = perceptor(pixel_values=crops, **batch)
        # Pairwise differences via broadcasting: (crops, 1, d) - (1, prompts, d).
        diffs = jnp.expand_dims(clip_out.image_embeds, axis=1) - jnp.expand_dims(clip_out.text_embeds, axis=0)
        chord = jnp.linalg.norm(diffs, ord=2, axis=2)
        spherical = 2 * jnp.arcsin(chord / 2) ** 2
        return spherical.mean(), (decoded, crop_metrics)

    _, loss_key = jax.random.split(rng)
    loss_and_grad = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (decoded, metrics)), grads = loss_and_grad(state.params, loss_key)

    new_state = state.apply_gradients(grads=grads)

    preview = Image.fromarray(np.asarray((decoded[0] * 255).astype(np.uint8)))
    metrics.update({"loss": np.array(loss), "step": state.step, "image": wandb.Image(preview)})
    return new_state, metrics


# Seed the run; fall back to 0 when no seed is configured.
rng = jax.random.PRNGKey(0 if args.seed is None else args.seed)

# Initial latent: one uniformly sampled codebook entry per token position,
# reshaped to (1, toksY, toksX, e_dim).
rng, subrng = jax.random.split(rng)
one_hot = jax.nn.one_hot(jax.random.randint(subrng, [toksY * toksX], 0, n_toks), n_toks)
z = jnp.matmul(one_hot, model.params["quantize"]["embedding"]["embedding"])
z = jnp.reshape(z, (-1, toksY, toksX, e_dim))

# Untouched copy of the starting latent. JAX arrays have no `.clone()` (that is
# a PyTorch method, and calling it raised AttributeError); `jnp.copy` is the
# JAX equivalent.
z_orig = jnp.copy(z)

# Adam directly on the latent: the latent is the only thing being optimised.
tx = optax.adam(args.step_size)

state = TrainState.create(params=z, tx=tx)

# Optimise the latent for up to 10,000 steps. Interrupting the kernel stops
# training early while keeping the latest `state`.
i = 0
try:
    train_time = 0  # placeholder for the currently disabled timing code
    for i in range(1, 10000 + 1):
        # ======================== Training ================================
        rng, subrng = jax.random.split(rng)
        state, train_metric = train_step(subrng, state, inputs, args.cutn)

        # Disabled trick: clamp the latent to the codebook's per-dimension range.
        # `replace` returns a new state, so the result would have to be re-assigned:
        # state = state.replace(params=jnp.clip(state.params, a_min=z_min, a_max=z_max))

        # Only process 0 logs metrics.
        if jax.process_index() == 0:
            wandb.log(train_metric)
except KeyboardInterrupt:
    # Deliberate: Ctrl-C / "interrupt kernel" ends training without an error.
    pass
