In [None]:
# Fetch big_vision repository and move it into the current workdir (import path).
!git clone --depth=1 https://github.com/google-research/big_vision big_vision_repo
!cp -R big_vision_repo/big_vision big_vision
!pip install -qr big_vision/requirements.txt

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

from big_vision.models.proj.uvim import vtt  # stage-II model
from big_vision.models.proj.uvim import vit  # stage-I model

from big_vision.models.proj.uvim import decode
from big_vision.trainers.proj.uvim import depth_task as task
from big_vision.configs.proj.uvim import train_nyu_depth_pretrained as config_module

import big_vision.pp.ops_image
import big_vision.pp.ops_general
import big_vision.pp.proj.uvim.pp_ops
from big_vision.pp import builder as pp_builder

config = config_module.get_config()
# Resolution images are resized to before inference; also used to build the
# preprocessing spec below so the two cannot drift apart.
res = 512
# Length of the code sequence sampled by the stage-II model (used as the
# prompt length in `predict_code`).
seq_len = config.model.seq_len

# Stage-II: model that samples a discrete code sequence conditioned on the image.
lm_model = vtt.Model(**config.model)
# Stage-I "oracle": its encode/decode methods map between labels and codes.
oracle_model = vit.Model(**config.oracle.model)

# Resize to `res`, scale pixel values to [-1, 1], and keep a copy of the image
# under "image_ctx" (presumably consumed as oracle context via task.input_pp).
preprocess_fn = pp_builder.get_preprocess_fn(
    f'resize({res})|value_range(-1,1)|'
    'copy(inkey="image",outkey="image_ctx")')

@jax.jit
def predict_code(params, x, rng, temperature):
  """Samples one code sequence per image from the stage-II model.

  Args:
    params: variables dict for `lm_model` (e.g. {"params": ...}).
    x: batch dict; only x["image"] is read, and its leading axis is the batch.
    rng: PRNG key forwarded to the sampler as `seed`.
    temperature: sampling temperature forwarded to the sampler.

  Returns:
    Integer array of sampled tokens with the num_samples axis removed and all
    values shifted down by one.
  """
  images = x["image"]
  batch_size = images.shape[0]
  # Full-length all-zero prompts; decoding is not prefilled (prefill=False).
  empty_prompts = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)
  # eos_token=-1: presumably no real token matches, so sampling never stops
  # early. NOTE(review): confirm against decode.temperature_sampling.
  sampled, _, _ = decode.temperature_sampling(
      params=params,
      model=lm_model,
      seed=rng,
      inputs=images,
      prompts=empty_prompts,
      temperature=temperature,
      num_samples=1,
      eos_token=-1,
      prefill=False,
  )
  # Axis 1 is the num_samples axis (size 1 here). The -1 shift presumably
  # undoes a +1 offset that reserves id 0 for the prompt. TODO confirm.
  return jnp.squeeze(sampled, axis=1) - 1
  
@jax.jit
def labels2code(params, x, ctx):
  """Encodes labels `x` into discrete codes using the oracle's encoder.

  Args:
    params: oracle variables dict (e.g. {"params": ..., "state": ...}).
    x: labels to encode.
    ctx: context input forwarded to the oracle.

  Returns:
    The "code" entry of the encoder's auxiliary outputs.
  """
  _, encoder_aux = oracle_model.apply(
      params, x, ctx=ctx, train=False, method=oracle_model.encode)
  return encoder_aux["code"]

@jax.jit
def code2labels(params, code, ctx):
  """Decodes discrete codes back into task predictions with the oracle.

  Args:
    params: oracle variables dict (e.g. {"params": ..., "state": ...}).
    code: discrete code tokens (passed with discrete_input=True).
    ctx: context input forwarded to the oracle decoder.

  Returns:
    Output of task.predict_outputs on the decoder logits (a dict; the
    inference loop reads its "depth" entry).
  """
  decoder_logits, _ = oracle_model.apply(
      params, code, ctx=ctx, train=False, discrete_input=True,
      method=oracle_model.decode)
  return task.predict_outputs(decoder_logits, config.oracle)

In [None]:
# Load checkpoints.
# Download the stage-I (oracle) and stage-II (language model) parameters;
# `-n` (no-clobber) skips files that already exist locally.
!gsutil cp -n gs://big_vision/uvim/depth_stageI_params.npz gs://big_vision/uvim/depth_stageII_params.npz .

# Stage-I oracle: loaded as separate params and state, then bundled into the
# variables dict that `oracle_model.apply` receives.
oracle_params, oracle_state = vit.load(None, "depth_stageI_params.npz")
oracle_params = jax.device_put({"params": oracle_params, "state": oracle_state})

# Stage-II language model: only a "params" collection is loaded.
lm_params = vtt.load(None, "depth_stageII_params.npz")
lm_params = jax.device_put({"params": lm_params})

In [None]:
# Prepare dataset of images from NYU Depth V2:
#  - https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html
import os
import h5py
import numpy as np
import tensorflow as tf

# Download the labeled NYU Depth V2 file only if it is not already present.
if not os.path.exists("nyu_depth_v2_labeled.mat"):
  !wget --no-clobber http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat

# The .mat file is opened with h5py, read-only. Nothing in this notebook closes
# the handle; the `nyu_depth_examples` generator reads from it lazily.
dataset_file = h5py.File("nyu_depth_v2_labeled.mat", "r")

def nyu_depth_examples():
  """Yields {"image": array} dicts, one per image in `dataset_file["images"]`.

  Each stored image has its axes fully reversed (transpose (2, 1, 0)) so it
  comes out as (height, width, channels), matching the (480, 640, 3) spec the
  tf.data pipeline declares. The reversed storage order is presumably the
  MATLAB column-major layout. TODO confirm.
  """
  images = dataset_file["images"]
  num_images = images.shape[0]
  for i in range(num_images):
    yield {"image": np.transpose(images[i], (2, 1, 0))}

# Wrap the Python generator in a tf.data pipeline with a fixed uint8 image spec,
# then apply the model's preprocessing to every example.
image_spec = {"image": tf.TensorSpec((480,640,3), tf.uint8)}
dataset = tf.data.Dataset.from_generator(
    nyu_depth_examples, output_signature=image_spec)
dataset = dataset.map(preprocess_fn)

In [None]:
# Run the model on a few examples:
from matplotlib import pyplot as plt
from matplotlib import patches

num_examples = 4  # Number of dataset images to run through the model.
# One image per batch. This is a single-pass iterator, rebuilt each time the
# cell runs.
data = dataset.batch(1).take(num_examples).as_numpy_iterator()
key = jax.random.PRNGKey(0)  # Fixed seed so sampling is reproducible.
# A near-zero temperature presumably makes sampling effectively greedy
# (argmax). NOTE(review): confirm against decode.temperature_sampling.
temperature = jnp.array(1e-7)

def to_depth(x, nbins=256, mind=1e-3, maxd=10):
  """Maps integer depth bins back to continuous depth values.

  Bin b becomes (b + 0.5) / nbins * (maxd - mind) + mind, i.e. the bin centre
  on a linear scale from `mind` to `maxd` (units presumably metres).

  Args:
    x: array of bin indices; cast to float32.
    nbins: number of bins the range was quantized into.
    mind: lower end of the depth range.
    maxd: upper end of the depth range.

  Returns:
    float32 array of depths, same shape as `x`.
  """
  span = maxd - mind
  centred = x.astype(np.float32) + 0.5  # Undoes floor in expectation.
  return centred / nbins * span + mind

def render_example(image, prediction, with_legend=True):
  """Draws an input image and its depth prediction side by side.

  Args:
    image: preprocessed image with values in [-1, 1] (value_range(-1,1));
      mapped back to [0, 1] for display.
    prediction: depth prediction, converted for display with `to_depth`
      (presumably integer depth bins).
    with_legend: NOTE(review): currently ignored; no legend or colorbar is
      drawn either way.
  """
  _, (image_ax, depth_ax) = plt.subplots(1, 2, figsize=(10, 10))
  image_ax.imshow(image*0.5 + 0.5)
  depth_ax.imshow(to_depth(prediction))
  for panel in (image_ax, depth_ax):
    panel.axis("off")

for idx, batch in enumerate(data):
  # Derive a distinct, reproducible key for each example and pass that subkey
  # (not the shared `key`), so examples do not all reuse identical randomness.
  subkey = jax.random.fold_in(key, idx)
  code = predict_code(lm_params, batch, subkey, temperature)
  # Build the oracle's auxiliary inputs, including the "ctx" decoder context.
  aux_inputs = task.input_pp(batch, config.oracle)
  prediction = code2labels(oracle_params, code, aux_inputs["ctx"])
  # Batch size is 1, so index 0 selects the single example.
  render_example(batch["image"][0], prediction["depth"][0])