In [1]:

import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"

In [2]:
import torch
import numpy as np
import jax
import jax.numpy as jnp
from kauldron.modules import pos_embeddings
from kauldron.modules import vit as kd_vit
import mediapy
from representations4d.models import model as model_lib
from representations4d.models import readout
import numpy as np
from representations4d.utils import checkpoint_utils
from einops import rearrange

In [3]:
# @title Hyperparameters
model_patch_size = (2, 16, 16)
im_size = (224, 224)
model_size = "B"
dtype = jnp.float32
model_output_patch_size = (2, 8, 8)
n_pixels_patch = (
    model_output_patch_size[0]
    * model_output_patch_size[1]
    * model_output_patch_size[2]
)
num_input_frames = 16
n_pixels_video = num_input_frames * im_size[0] * im_size[1]

embedding_shape = (
    num_input_frames // model_patch_size[0],
    im_size[0] // model_patch_size[1],
    im_size[1] // model_patch_size[2],
)
num_tokens = embedding_shape[0] * embedding_shape[1] * embedding_shape[2]

In [4]:
jax_ckpt_path = "representations4d/scaling4d_dist_b.npz"

In [5]:
encoder_state_dict = torch.load(
    "/Volumes/Storage/checkpoints/scaling4d_dist_b/encoder.pth"
)

### Torch model


In [6]:
from encoder import Encoder

torch_encoder = Encoder(
    input_size=(3, 16, 224, 224),
    patch_size=(2, 16, 16),
    num_heads=12,
    num_layers=12,
    hidden_size=768,
    mlp_size=3072,
    n_iter=1
)
torch_encoder.load_state_dict(encoder_state_dict)

<All keys matched successfully>

### JAX model


In [7]:
jax_depth_model = model_lib.Model(
    encoder=model_lib.Tokenizer(
        patch_embedding=model_lib.PatchEmbedding(
            patch_size=model_patch_size,
            num_features=kd_vit.VIT_SIZES[model_size][0],
        ),
        posenc=pos_embeddings.LearnedEmbedding(dtype=dtype),
        posenc_axes=(-4, -3, -2),
    ),
    processor=model_lib.GeneralizedTransformer.from_variant_str(
        variant_str=model_size,
        dtype=dtype,
    ),
)

### Test


In [8]:
video = mediapy.read_video("representations4d/horsejump-high.mp4")
video = mediapy.resize_video(video, im_size) / 255.0
video = video[jnp.newaxis, :num_input_frames].astype(jnp.float32)
video.shape

(1, 16, 224, 224, 3)

In [9]:
key = jax.random.key(0)
jax_params = jax_depth_model.init(key, video, is_training_property=False)

I0000 00:00:1754304474.673495  407062 service.cc:145] XLA service 0x1196be560 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1754304474.673509  407062 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1754304474.674765  407062 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1754304474.674778  407062 mps_client.cc:384] XLA backend will use up to 22906109952 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M2 Pro

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB



In [10]:
jax_restored_params = checkpoint_utils.recover_tree(
    checkpoint_utils.npload(jax_ckpt_path)
)

In [11]:
output_jax = (
    jax_depth_model.apply(
        jax_restored_params,
        video,
        is_training_property=False
    )
)

In [12]:
video_torch = rearrange(
    torch.from_numpy(jax.device_get(video).copy()),
    "b t h w c -> b c t h w"
)

In [13]:
output_torch = torch_encoder(video_torch)

In [14]:
assert len(output_jax) == len(output_torch)

for (out_jax, out_torch) in zip(output_jax, output_torch):
    assert torch.allclose(
        out_torch,
        torch.from_numpy(jax.device_get(out_jax).copy()),
        atol=1e-3,
        rtol=1e-3
    )