In [1]:

import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"

In [2]:
import torch
import numpy as np
import jax
import jax.numpy as jnp
from kauldron.modules import pos_embeddings
from kauldron.modules import vit as kd_vit
import mediapy
from representations4d.models import model as model_lib
from representations4d.models import readout
import numpy as np
from representations4d.utils import checkpoint_utils
from einops import rearrange

In [3]:
# @title Hyperparameters
model_patch_size = (2, 16, 16)
im_size = (224, 224)
model_size = "B"
dtype = jnp.float32
model_output_patch_size = (2, 8, 8)
n_pixels_patch = (
    model_output_patch_size[0]
    * model_output_patch_size[1]
    * model_output_patch_size[2]
)
num_input_frames = 16
n_pixels_video = num_input_frames * im_size[0] * im_size[1]

embedding_shape = (
    num_input_frames // model_patch_size[0],
    im_size[0] // model_patch_size[1],
    im_size[1] // model_patch_size[2],
)
num_tokens = embedding_shape[0] * embedding_shape[1] * embedding_shape[2]

In [4]:
jax_ckpt_path = "representations4d/scaling4d_dist_b_depth.npz"

In [5]:
encoder_state_dict = torch.load("checkpoints/encoder.pth")
readout_state_dict = torch.load("checkpoints/readout.pth")

### Torch model


In [6]:
from encoder import Encoder
from encoder_to_readout import EncoderToReadout
from readout import AttentionReadout
import torch.nn as nn

torch_encoder = Encoder(
    input_size=(3, 16, 224, 224),
    patch_size=(2, 16, 16),
    num_heads=12,
    num_layers=12,
    hidden_size=768,
    n_iter=1
)
torch_encoder.load_state_dict(encoder_state_dict)

torch_encoder2readout = EncoderToReadout(
    embedding_shape=embedding_shape,
    readout_depth=0.95,
    num_input_frames=num_input_frames,
    sampling_mode="bilinear"
)

torch_attn_readout = AttentionReadout(
    num_classes=n_pixels_patch,
    num_params=1024,
    num_heads=16,
    num_queries=n_pixels_video // n_pixels_patch,
    output_shape=(
        num_input_frames,
        im_size[0],
        im_size[1],
        1,
    ),
    decoding_patch_size=model_output_patch_size,
)
torch_attn_readout.load_state_dict(readout_state_dict)

torch_depth_model = nn.Sequential(
    torch_encoder,
    torch_encoder2readout,
    torch_attn_readout
)

### JAX model


In [7]:
from flax import linen as nn

encoder = model_lib.Model(
    encoder=model_lib.Tokenizer(
        patch_embedding=model_lib.PatchEmbedding(
            patch_size=model_patch_size,
            num_features=kd_vit.VIT_SIZES[model_size][0],
        ),
        posenc=pos_embeddings.LearnedEmbedding(dtype=dtype),
        posenc_axes=(-4, -3, -2),
    ),
    processor=model_lib.GeneralizedTransformer.from_variant_str(
        variant_str=model_size,
        dtype=dtype,
    ),
)
encoder2readout = model_lib.EncoderToReadout(
    embedding_shape=(
        num_input_frames // model_patch_size[0],
        im_size[0] // model_patch_size[1],
        im_size[1] // model_patch_size[2],
    ),
    readout_depth=0.95,
    num_input_frames=num_input_frames,
    mode="linear"
)
readout_head = readout.AttentionReadout(
    num_classes=n_pixels_patch,
    num_params=1024,
    num_heads=16,
    num_queries=n_pixels_video // n_pixels_patch,
    output_shape=(
        num_input_frames,
        im_size[0],
        im_size[1],
        1,
    ),
    decoding_patch_size=model_output_patch_size,
)

jax_depth_model = nn.Sequential([encoder, encoder2readout, readout_head])

### Test


In [8]:
video = mediapy.read_video("representations4d/horsejump-high.mp4")
video = mediapy.resize_video(video, im_size) / 255.0
video = video[jnp.newaxis, :num_input_frames].astype(jnp.float32)
video.shape

(1, 16, 224, 224, 3)

In [9]:
key = jax.random.key(0)
jax_params = jax_depth_model.init(key, video, is_training_property=False)

I0000 00:00:1754264837.677353  291422 service.cc:145] XLA service 0x320c68d60 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1754264837.677368  291422 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1754264837.678898  291422 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1754264837.678909  291422 mps_client.cc:384] XLA backend will use up to 22906109952 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M2 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB



In [10]:
jax_restored_params = checkpoint_utils.recover_tree(
    checkpoint_utils.npload(jax_ckpt_path)
)

In [11]:
output_jax = (
    jax_depth_model.apply(
        jax_restored_params,
        video,
        is_training_property=False
    )
)

output_jax.shape

(1, 16, 224, 224, 1)

In [12]:


video_torch = rearrange(
    torch.from_numpy(jax.device_get(video).copy()),
    "b t h w c -> b c t h w"
)

In [13]:
output_torch = torch_depth_model(video_torch)
output_torch.shape

torch.Size([1, 16, 224, 224, 1])

In [14]:
assert torch.allclose(
    output_torch,
    torch.from_numpy(jax.device_get(output_jax).copy()),
    atol=1e-5,
    rtol=1e-5
)

In [15]:
# @title Visualize JAX depth maps
out = np.array(output_jax[0])
out = jnp.tile(out, [1, 1, 1, 3])
out = out / np.max(out)
vis = np.concatenate([video[0], out], axis=2)
mediapy.show_video(vis, fps=20)

0
This browser does not support the video tag.


In [16]:
# @title Visualize Torch depth maps
out = output_torch[0].detach().numpy()
out = jnp.tile(out, [1, 1, 1, 3])
out = out / np.max(out)
vis = np.concatenate([video[0], out], axis=2)
mediapy.show_video(vis, fps=20)

0
This browser does not support the video tag.
