Copyright 2024 DeepMind Technologies Limited

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Neural Assets MOVi Example

This notebook shows how to set up a Neural Assets model for training on the
public MOVi dataset. It does not implement the entire training process; for
example, pre-trained weight loading and optimizer settings are omitted.

Instead, it helps you understand the input & output pipeline and the 3D control
interface of our model.

We will load a batch of data, apply data pre-processing, visualize it, and
run the model on it to compute the training loss. We borrow the diffusion
implementation from Hugging Face Diffusers.


In [None]:
#@title Imports

import diffusion
from etils.lazy_imports import *
import modules
import preprocessing
import viz_utils


In [None]:
#@title Build the RAW MOVi Dataset

# @markdown Dataset settings:
# MOVi variant letter. It selects the TFDS dataset `movi_{VARIANT}` and must be
# one of the variants accepted by `_get_max_obj_num` ('a', 'b', 'c', 'e', 'f').
VARIANT = 'e'  # @param {type:"string"}
# Frame resolution. It appears as `{RESOLUTION}x{RESOLUTION}` in the dataset
# name and is also passed to the pre-processing function.
RESOLUTION = 256  # @param {type:"integer"}
# Number of pre-processed examples grouped into one batch by the data loader.
BATCH_SIZE = 4  # @param {type:"integer"}


def _get_max_obj_num(variant):
  """Max number of objects in the dataset."""
  if variant in ['a', 'b', 'c']:
    return 10
  elif variant in ['e', 'f']:
    return 23
  else:
    raise ValueError(f'Invalid MOVi variant: {variant}')


def load_movi():
  """Loads the MOVi train split and fetches its first example.

  The dataset name is built from the notebook-level `VARIANT` and `RESOLUTION`
  settings, and the data is read from `gs://kubric-public/tfds`.

  Returns:
    A `(dataset, example)` tuple: the train-split dataset (created with file
    shuffling disabled) and the first element drawn from it.
  """
  builder = tfds.builder(
      f'movi_{VARIANT}/{RESOLUTION}x{RESOLUTION}:1.0.0',
      data_dir='gs://kubric-public/tfds',
  )
  dataset = builder.as_dataset(split='train', shuffle_files=False)
  first_example = next(iter(dataset))
  return dataset, first_example


# All visualization outputs of this notebook are written under this directory
save_path = f'./viz/movi_{VARIANT}/'

train_ds, sample = load_movi()

# The raw boxes come as [N, T, 8, 3]; move the time axis to the front so the
# boxes are indexed per frame.
tf_bboxes_3d = einops.rearrange(
    sample['instances']['bboxes_3d'], 'n t ... -> t n ...'
)
bboxes_3d = tf_bboxes_3d.numpy()  # [T, N, 8, 3]
# Box centers are taken as the mean over the 8 corners of each box
bboxes_center_3d = bboxes_3d.mean(-2)

video_w_bbox_3d = viz_utils.show_3d_bbox_on_image(
    viz_utils.to_numpy(sample['video']),
    bboxes_3d=bboxes_3d,
    cameras=viz_utils.to_numpy(sample['camera']),
    bboxes_center_3d=bboxes_center_3d,
)
print('Projected 3D bboxes and centers to 2D video frames')
bbox_video_file = os.path.join(save_path, 'video_bbox_3d.gif')
viz_utils.show_video(
    video_w_bbox_3d, fps=8, codec='gif', save_path=bbox_video_file
)


In [None]:
#@title Apply Our Data Pre-processing Pipeline and Visualize them

def preproc_fn(x):
  """Runs the MOVi pre-processing on a single raw example `x`."""
  return preprocessing.preprocess_gv_movi_example(
      x,
      max_instances=_get_max_obj_num(variant=VARIANT),
      resolution=RESOLUTION,
      drop_cond_prob=0.1,
  )


def _viz_file(filename):
  """Returns the path of `filename` inside the visualization directory."""
  return os.path.join(save_path, filename)


# Pre-process every example, batch them, and pull one batch as numpy arrays
train_loader = train_ds.map(preproc_fn).batch(batch_size=BATCH_SIZE)
train_loader_iter = iter(train_loader)
batch = viz_utils.to_numpy(next(train_loader_iter))

print('Loaded training data batch:')
print(etree.spec_like(batch))

# Keyword arguments shared by both 3D bbox drawing calls below. The camera
# entries are None when the batch does not contain the corresponding key.
vis_3d_args = {
    'focal_length_lst': batch.get('camera_focal_length'),
    'sensor_width_lst': batch.get('camera_sensor_width'),
    'camera2image_lst': batch.get('camera_projection'),
    'is_proj_4_corner': True,
    'has_background_bbox': True,
}

# ---- Source frame ----
print('\nSource Images with 2D bboxes (used for appearance token extraction)')
viz_utils.draw_bbox(
    batch['src_image'],
    batch['src_bboxes'],
    save_path=_viz_file('src_bbox_2d.png'),
)
print('Source Images with 3D bboxes (not used in training)')
viz_utils.draw_bbox_3d(
    batch['src_image'],
    batch['src_bboxes_3d'],
    save_path=_viz_file('src_bbox_3d.png'),
    **vis_3d_args,
)
print('Background images (used for background token extraction)')
viz_utils.show_images(batch['src_bg_image'], save_path=_viz_file('src_bg.png'))

# ---- Target frame ----
print('\nTarget Images with 2D bboxes (not used for training)')
viz_utils.draw_bbox(
    batch['tgt_image'],
    batch['tgt_bboxes'],
    save_path=_viz_file('tgt_bbox_2d.png'),
)
print('Target Images with 3D bboxes (used for pose token extraction and the diffusion reconstruction target)')
viz_utils.draw_bbox_3d(
    batch['tgt_image'],
    batch['tgt_bboxes_3d'],
    save_path=_viz_file('tgt_bbox_3d.png'),
    **vis_3d_args,
)


In [None]:
#@title Build the Neural Assets Model

# This cell assembles the full model from three parts: a diffusion generator,
# a conditioning encoder (appearance + pose tokens), and a neck that fuses the
# tokens before they are handed to the generator.

# We use SD v2.1 as our base generator
model_name = 'stable_diffusion_v2_1'
generator = diffusion.DiffuserDiffusionWrapper(model_name=model_name)
hidden_size = 1024  # the cross-attention dim in the denoising U-Net

# Learnable appearance tokens + pose tokens
# Each token type uses half of `hidden_size`; the neck below projects the
# fused tokens to `hidden_size`.
# NOTE(review): presumably the two halves are concatenated channel-wise before
# the neck -- confirm in `modules`.
token_dim = hidden_size // 2
# We will do RoIAlign to extract 2x2 feature maps as object appearance tokens
roi_align_size = 2
# We use DINO as our visual encoder
dino_version, dino_variant = 'v1', 'B/8'
conditioning_encoder = modules.ConditioningEncoder(
    appearance_encoder=modules.RoIAlignAppearanceEncoder(
        # +1 because we add a global background bbox
        shape=(_get_max_obj_num(variant=VARIANT) + 1, token_dim),
        image_backbone=modules.DINOViT.from_variant_str(
            version=dino_version,
            variant=dino_variant,
            in_vrange=(0, 1),
            use_imagenet_value_range=True,
            frozen_model=False,  # we fine-tune DINO
        ),
        # NOTE(review): a (28, 28) feature map matches a 224px input with 8px
        # patches, while RESOLUTION defaults to 256 -- confirm whether the
        # backbone resizes its input.
        roi_align_size=roi_align_size,  # feature map size: (28, 28)
        aggregate_method='flatten',  # flatten the 2x2 feature map
    ),
    # Treat 3D bbox as object pose
    object_pose_encoder=modules.MLPPoseEncoder(
        mlp_module=modules.MLP(
            hidden_size=token_dim * 2,
            output_size=token_dim,
            num_hidden_layers=1,
        ),
        # Duplicate bbox tokens to match the length of appearance tokens
        # (roi_align_size**2 == 4 tokens per object with the 2x2 RoIAlign above)
        duplicate_factor=roi_align_size**2,
    ),
    # Mask out background pixels when encoding object tokens
    mask_out_bg_for_appearance=True,
    # NOTE(review): presumably the fill value used for masked-out pixels, i.e.
    # mid-gray for images in [0, 1] -- confirm in `modules.ConditioningEncoder`.
    background_value=0.5,
    # Map the relative camera pose with a MLP
    # This serves as the pose token for the background
    background_pos_enc_type='mlp',
)
# Fuse appearance and pose tokens with a neck module
# A single dense layer whose output size equals the generator's
# cross-attention dim (`hidden_size`).
conditioning_neck = modules.FeedForwardNeck(
    feed_forward_module=nn.Dense(hidden_size),
)

# Full Model
ns_model = modules.ControllableGenerator(
    generator=generator,
    conditioning_encoder=conditioning_encoder,
    conditioning_neck=conditioning_neck,
)


In [None]:
#@title Model Forward & Loss Computation

# Keyword arguments for the model call, taken from the pre-processed batch
input_dict = dict(
    tgt_images=batch['tgt_image'],
    tgt_object_poses=batch['tgt_bboxes_3d'],
    src_images=batch['src_image'],
    src_bboxes=batch['src_bboxes'],
    src_bg_images=batch['src_bg_image'],
)
# Initialize the parameters and run one forward pass in a single call
init_rng = jax.random.key(0)
output_dict, params = ns_model.init_with_output(init_rng, **input_dict)
# Denoising loss: mean squared error between the 'diff' and 'pred_diff' outputs
squared_error = (output_dict['diff'] - output_dict['pred_diff']) ** 2
loss = squared_error.mean()
print('Training loss: ', loss)
