In [None]:
%cd ~/big_vision/
import big_vision.datasets.core as ds_core

/home/austinwang/big_vision


In [None]:
# Dataset name and split, forwarded as keyword arguments to `ds_core.get`.
input_data = {"name": "laion400m/images", "split": "train"}
train_data = ds_core.get(**input_data)

2024-03-26 04:16:31.351190: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-26 04:16:31.351338: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-26 04:16:31.818085: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Display the repr of the data source object returned by `ds_core.get`.
train_data

<big_vision.datasets.tfds.DataSource at 0x7feeb012fcd0>

In [2]:
%cd ~/big_vision/
import importlib
from typing import Any, Optional, Tuple, Union
from absl import logging

from big_vision import utils
import flax.linen as nn
import jax.numpy as jnp

# Loose type alias for the per-tower configuration mappings (`Model.image` and
# `Model.text`), which are unpacked as keyword arguments for the towers.
ConfigDict = Any

class Model(nn.Module):
  """Two-tower model producing L2-normalized image and text embeddings.

  Each tower is resolved by importing `big_vision.models.<name>` and
  instantiating that module's `Model` class.

  Attributes:
    image: Optional keyword arguments for the image tower.
    text: Optional keyword arguments for the text tower.
    text_model: Module path, relative to `big_vision.models`, of the text
      tower.
    image_model: Module path, relative to `big_vision.models`, of the image
      tower.
    out_dim: Embedding size handed to the towers as `num_classes`. An int is
      used for both towers; a pair is read as `(image_dim, text_dim)`.
    temperature_init: Initial temperature. The parameter `t` stores its log.
    bias_init: Initial value of the bias parameter `b`. When `None`, no bias
      parameter is created.
  """
  image: Optional[ConfigDict] = None
  text: Optional[ConfigDict] = None
  text_model: str = "proj.image_text.text_transformer"
  image_model: str = "vit"
  out_dim: Union[int, Tuple[int, int]] = 128
  temperature_init: float = 1.0
  bias_init: Optional[float] = None

  @nn.compact
  def __call__(self, image, text=None, **kw):
    """Returns (B,C) image and (B,C) text representations.

    Either input may be `None` (for example for few-shot), in which case the
    corresponding tower is not built and its embedding is returned as `None`.

    Args:
      image: Input for the image tower, or `None` to skip it.
      text: Input for the text tower, or `None` to skip it.
      **kw: Extra keyword arguments forwarded to every tower that is called.

    Returns:
      A tuple `(zimg, ztxt, out)`: the normalized image embedding, the
      normalized text embedding, and a dict holding the towers' auxiliary
      outputs (prefixed with `img/` and `txt/`), the embedding norms, the
      temperature (`t`, `t/parameter`) and, if enabled, the bias (`b`).
    """
    out = {}
    zimg = ztxt = None

    # A single int means both towers share the same embedding size.
    dims = self.out_dim
    if isinstance(dims, int):
      dims = (dims,) * 2

    # Text tower.
    if text is not None:
      txt_module = importlib.import_module(
          f"big_vision.models.{self.text_model}")
      # Entries of `self.text` take precedence over the default `num_classes`.
      txt_kwargs = {"num_classes": dims[1], **(self.text or {})}
      ztxt, txt_aux = txt_module.Model(**txt_kwargs, name="txt")(text, **kw)
      out.update({f"txt/{key}": val for key, val in txt_aux.items()})

      # L2-normalize; the epsilon keeps the division finite for a zero norm.
      txt_norm = jnp.linalg.norm(ztxt, axis=1, keepdims=True)
      ztxt = ztxt / (txt_norm + 1e-8)
      out["txt/norm"] = txt_norm
      out["txt/normalized"] = ztxt

    # Image tower.
    if image is not None:
      img_module = importlib.import_module(
          f"big_vision.models.{self.image_model}")
      # Entries of `self.image` take precedence over the default `num_classes`.
      img_kwargs = {"num_classes": dims[0], **(self.image or {})}  # pylint: disable=not-a-mapping
      zimg, img_aux = img_module.Model(**img_kwargs, name="img")(image, **kw)
      out.update({f"img/{key}": val for key, val in img_aux.items()})

      # L2-normalize; the epsilon keeps the division finite for a zero norm.
      img_norm = jnp.linalg.norm(zimg, axis=1, keepdims=True)
      zimg = zimg / (img_norm + 1e-8)
      out["img/norm"] = img_norm
      out["img/normalized"] = zimg

    # The temperature is learned in log-space and exponentiated on use.
    log_temp_init = jnp.log(self.temperature_init)

    def _log_temperature_init(key, shape, dtype):
      del key  # Constant initializer; the PRNG key is unused.
      return log_temp_init * jnp.ones(shape, dtype)

    t = self.param("t", _log_temperature_init, (1,), jnp.float32)
    out["t"] = jnp.exp(t)
    out["t/parameter"] = t

    # The bias parameter only exists when an initial value was configured.
    if self.bias_init is not None:
      bias_value = self.bias_init
      out["b"] = self.param(
          "b",
          lambda key, shape, dtype: bias_value * jnp.ones(shape, dtype),
          (1,), jnp.float32)

    # We could actually play with pre-multiplying by temperature here, such
    # that out["t"] is nothing special to the trainer anymore.

    return zimg, ztxt, out


/home/austinwang/big_vision


2024-03-28 19:44:46.655054: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-28 19:44:46.655123: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-28 19:44:46.656134: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# With these settings the log-temperature parameter `t` starts at log(10.0)
# and a bias parameter `b` is created with initial value -10.0.
model = Model(temperature_init=10.0, bias_init=-10.0)

In [4]:
import jax
import big_vision.utils as u
def init(rng):
  """Initializes the global `model` on a dummy all-zeros image batch.

  Args:
    rng: PRNG key passed to `model.init`.

  Returns:
    Whatever `model.init` returns for a single float32 image of shape
    (1, 224, 224, 3).

  NOTE(review): no `text` input is supplied, so the text tower is not built
  here and contributes no variables. Confirm that is intended.
  NOTE(review): presumably the result is the full Flax variables dict (with a
  top-level "params" collection) rather than the bare parameter tree. Verify
  that name patterns such as 'img/.*' still match as expected downstream.
  """
  dummy_image = jnp.zeros((1, 224, 224, 3), dtype=jnp.float32)
  return model.init(rng, image=dummy_image)

In [5]:
total_steps = 65_000
# Warm up for 3% of the run, but never for fewer than 100 steps.
warmup_steps = max(100, int(total_steps * 0.03))
print("Warmup steps:", warmup_steps)

# (name pattern, schedule kwargs) pairs; a `None` schedule freezes the
# parameters matched by that pattern.
schedule = [
    ("img/.*", None),  # Freezes image tower.
    (".*", {"decay_type": "cosine", "warmup_steps": warmup_steps}),
]

Warmup steps: 1950


In [6]:
def _make_mask_trees(params, patterns_values, log):
  """Builds one mask tree per pattern and returns them with the paired values.

  Args:
    params: Tree handed to `u.make_mask_trees`.
    patterns_values: Non-empty sequence of `(pattern, value)` pairs.
    log: Forwarded to `u.make_mask_trees` as its `log` keyword argument.

  Returns:
    A tuple `(masks, values)`: `masks` is the result of `u.make_mask_trees` for
    the patterns, and `values` is the tuple of values in the same order.
  """
  # Transpose the list of pairs into one tuple of patterns and one of values.
  unzipped = tuple(zip(*patterns_values))
  patterns, values = unzipped
  return u.make_mask_trees(params, patterns, log=log), values

def _split_frozen(masks, scheds):
  """Computes `frozen_mask` and updates `masks` and `scheds`.

  Specifying `None` as a scheduler freezes params.

  Args:
    masks: Sequence of boolean mask trees sharing one structure, one per
      schedule entry.
    scheds: Sequence parallel to `masks`; a `None` entry marks the matching
      mask as frozen.

  Returns:
    A tuple `(frozen_mask, masks, scheds)`. `frozen_mask` is a tree that is
    True wherever any frozen entry's mask is True. `masks` and `scheds` are
    tuples holding only the non-frozen entries, in their original order; both
    are empty tuples when every entry is frozen.

  Raises:
    AssertionError: If some leaf is not covered by any of the masks.
  """
  # `jax.tree_map` is a deprecated alias; `jax.tree_util.tree_map` is the
  # long-standing equivalent.
  tree_map = jax.tree_util.tree_map

  # True exactly at the leaves that no mask selects.
  all_false = tree_map(lambda *bools: not any(bools), *masks)
  not_covered = [k for k, v in u.tree_flatten_with_names(all_false)[0] if v]
  assert not not_covered, (
      f"All params must be covered (use `None` for freezing): {not_covered}")

  frozen_masks = [mask for mask, sched in zip(masks, scheds) if sched is None]
  # Once the coverage check passes, `all_false` is False everywhere, so it does
  # not affect `any`; it is required so that `tree_map` still receives a tree
  # when `frozen_masks == []`.
  frozen_mask = tree_map(lambda *bools: any(bools), *frozen_masks, all_false)

  trainable = [
      (mask, sched) for mask, sched in zip(masks, scheds) if sched is not None]
  if not trainable:
    # Everything is frozen: `zip(*[])` yields nothing to unpack, so return the
    # empty result explicitly instead of raising.
    return frozen_mask, (), ()
  masks, scheds = zip(*trainable)
  return frozen_mask, masks, scheds



In [7]:
import operator
import optax
optax_name = "scale_by_adam"
# Resolve the (possibly dotted) transformation name on the `optax` module.
tx_func = operator.attrgetter(optax_name)(optax)
# Keyword arguments for `tx_func`. `optax.scale_by_adam` calls its
# second-moment decay rate `b2`; it has no `beta2_cap` parameter.
optax_dict = dict(b2=0.95)
# Unpack as keywords. Passing the dict positionally would bind the whole dict
# to the first parameter (`b1`) and only fail later, when the transformation
# is applied.
temp = tx_func(**optax_dict)

In [9]:
# Build the root key from seed 0 (passed through `u.put_cpu`) and split off a
# dedicated key for parameter initialization.
rng, rng_init = jax.random.split(jax.random.PRNGKey(u.put_cpu(0)))
# `jax.eval_shape` evaluates `init` abstractly, yielding only shapes and dtypes.
params_shape = jax.eval_shape(init, rng_init)
# One mask tree per schedule pattern, plus the schedule value for each pattern.
masks, scheds = _make_mask_trees(params_shape, schedule, "config.schedule")


In [10]:
# Inspect the schedule values paired with each pattern (`None` marks an entry
# whose parameters are frozen).
scheds

(None, {'decay_type': 'cosine', 'warmup_steps': 1950})