# Intro

*Speaker: S. Zuffi, S. Melzi*

*Author: D. Baieri*

In this demo, we'll design, step by step, a 3D reconstruction method for a very challenging setting: drone-shot videos of dolphins in open water. The full version of this algorithm was developed in a recent research collaboration with the University of Zurich and presented at the CV4Animals Workshop at CVPR 2025.


# 0. Setup

Run the following cells to set up the environment. This takes a while, so do it ASAP!

The notebook was designed to work inside Colab, but it should still work outside by changing file paths. If you're on Colab, remember to select a GPU runtime to use SAM-2.

In [None]:
ROOT = '/content'    # change if outside Colab
!mkdir -p $ROOT/data/
!unzip -n $ROOT/stag_dolphins_assets.zip &> /dev/null
!pip install kaolin==0.18.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.9.0_cu126.html
!pip install huggingface_hub
!pip install git+https://github.com/mattloper/chumpy@9b045ff5d6588a24a0bab52c83f032e2ba433e17
!git clone https://github.com/facebookresearch/segment-anything-2.git
!cd segment-anything-2 && pip install -e .

In [None]:
import sys
sys.path.append(f'{ROOT}/segment-anything-2')
import torch
import random
import numpy as np
from tqdm.notebook import tqdm


SEED = 123456789
DATA_DIR = f'{ROOT}/data/2024_08_20_DF2_S8_SW_TT_ELM-fb_23450-fe_23950'
MODEL_DIR = f'{ROOT}/model'

random.seed(SEED)
np.random.seed(SEED)
torch.random.manual_seed(SEED)

# 1. Preprocessing

Our design starts with the input data. Inside the "data" folder, you will find a "frames" subfolder (containing video frames we want to reconstruct) and a "metadata" subfolder (which we will use later). First off, we need **masks** to tell us which pixels of each image contain the target object (the dolphin): to this end, we will use the powerful SAM-2 segmentation model.
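
SAM-2 is prompted with clicked points: each point is an (x, y) pixel coordinate and carries a label, 1 for "belongs to the object" and 0 for "background". As a minimal sketch of how such a prompt is assembled, the snippet below uses a hypothetical helper, `build_point_prompt`, and a small illustrative subset of clicks.

```python
import numpy as np

# A few illustrative clicks in (x, y) pixel order.
positive = np.array([[360, 245], [348, 251]])   # clicks on the dolphin
negative = np.array([[273, 272], [546, 125]])   # clicks on the water


def build_point_prompt(positive, negative):
    """Stack clicks into one (N, 2) array with matching (N,) labels.

    Label 1 marks a point on the object, label 0 marks background.
    (Hypothetical helper, not part of SAM-2.)
    """
    points = np.concatenate([positive, negative], axis=0).astype(np.float32)
    labels = np.concatenate([np.ones(len(positive)), np.zeros(len(negative))]).astype(np.int32)
    return points, labels


points, labels = build_point_prompt(positive, negative)
print(points.shape, labels)
```

Negative clicks are useful here because waves and reflections near the animal are easy to confuse with the dolphin itself.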

In [None]:
import pathlib
from PIL import Image
from sam2.sam2_video_predictor import SAM2VideoPredictor


def preprocessing():

    predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
    frames_dir = pathlib.Path(DATA_DIR) / 'frames'
    segment_dir = pathlib.Path(DATA_DIR) / 'segmentation'
    segment_dir.mkdir(exist_ok=True)

    positive_points = np.array([[360, 245], [326, 287], [334, 293], [336, 280], [348, 251], [348, 262], [356, 259]])
    negative_points = np.array([[273, 272], [290, 229], [280, 245], [546, 125], [339, 180], [402, 311], [316, 261]])

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        state = predictor.init_state(str(frames_dir), async_loading_frames=True, offload_video_to_cpu=True)

        predictor.add_new_points_or_box(
            inference_state=state, frame_idx=0, obj_id=1,
            points=np.concatenate([positive_points, negative_points], axis=0),
            labels=np.concatenate([np.ones(positive_points.shape[0]), np.zeros(negative_points.shape[0])], axis=0)
        )

        # propagate the prompts to get masklets throughout the video
        t = 0
        for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
            frame = 255 * np.uint8(np.transpose((masks[0] > 0).float().cpu().numpy(), [1, 2, 0]).repeat(3, axis=-1))
            Image.fromarray(frame).save(segment_dir / (f'{str(t).zfill(6)}.png'))
            t += 1

if not (pathlib.Path(DATA_DIR) / 'segmentation').exists():
    preprocessing()

Let's package all our data in a single Dataset class to use later in our reconstruction algorithm:

In [None]:
CAMERA_FOCAL = 12.29 #mm
SENSOR_WIDTH = 17.27 #mm

import math
import pathlib
import kaolin as kal
import torchvision.transforms.functional as FV

from PIL import Image
from torch.utils.data import Dataset
from kaolin.render.camera.intrinsics import CameraFOV


class SegmentedVideo(Dataset):

    def __init__(self, data_dir: pathlib.Path, fps: int):
        super(SegmentedVideo, self).__init__()
        self.frames_original = []
        self.frames_segmented = []
        self.segment_centers = []
        for orig, segm in zip(sorted((data_dir / 'frames').iterdir()), sorted((data_dir / 'segmentation').iterdir())):
            orig_img, segm_img = Image.open(orig), Image.open(segm)
            self.frames_original.append(FV.pil_to_tensor(orig_img).permute(1, 2, 0) / 255.)  # (H, W, 3)
            self.frames_segmented.append(FV.pil_to_tensor(FV.to_grayscale(segm_img)).permute(1, 2, 0) / 255.)  # (H, W, 1)
            self.segment_centers.append(self.frames_segmented[-1].nonzero()[:, [1, 0]].float().mean(dim=0))  # mask centroid as (x, y)
        self.data_resolution = self.frames_original[0].shape[1::-1]  # (W, H)
        self.center_normalization = torch.tensor([self.data_resolution[0], self.data_resolution[1]])

        self.fps = fps

        # Load altitudes and store per-camera information
        altitudes = torch.from_numpy(np.load(data_dir / 'metadata' / 'altitude.npy')).float() / 100.
        altitudes = altitudes[50:250]  # filter specific frames out of 500, selected for this experiment
        self.cam_eye = torch.zeros(len(self.frames_original), 3, dtype=torch.float)
        self.cam_eye[:, 1] = altitudes
        self.cam_at = torch.zeros(3, dtype=torch.float)
        self.cam_up = torch.tensor([0, 0, 1], dtype=torch.float)

    def __getitem__(self, idx):
        c = (self.segment_centers[idx] / (self.center_normalization / 2.)) - 1.0
        return {'original': self.frames_original[idx],
                'segmented': self.frames_segmented[idx],
                'segm_center': torch.stack([c[0], -c[1]], dim=0),
                't': idx,
                'cam_eye': self.cam_eye[idx],
                'cam_at': self.cam_at,
                'cam_up': self.cam_up,
                'cam_width': self.data_resolution[0],
                'cam_height': self.data_resolution[1]}

    def __len__(self):
        return len(self.frames_original)

def collate_batch(frames):
    return {
        'original': torch.stack([f['original'] for f in frames], dim=0),
        'segmented': torch.stack([f['segmented'] for f in frames], dim=0),
        'segm_center': torch.stack([f['segm_center'] for f in frames], dim=0),
        't': torch.tensor([f['t'] for f in frames], dtype=torch.long),
        'cam': kal.render.camera.Camera.from_args(
            eye=torch.stack([f['cam_eye'] for f in frames], dim=0),
            at=torch.stack([f['cam_at'] for f in frames], dim=0),
            up=torch.stack([f['cam_up'] for f in frames], dim=0),
            fov=2 * math.atan(SENSOR_WIDTH / (2 * CAMERA_FOCAL)),
            fov_direction=CameraFOV.HORIZONTAL, width=frames[0]['cam_width'], height=frames[0]['cam_height'],
            near=1e-2, far=1e2,
            dtype=torch.float32,
        )
    }

dataset = SegmentedVideo(pathlib.Path(DATA_DIR), 50)
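
Two small computations in the cell above are worth isolating: the horizontal field of view derived from the focal length and sensor width, and the mapping of a mask center from pixel coordinates to the range [-1, 1] with the y axis pointing up. The sketch below restates both in plain Python; the function names are hypothetical and the 640x480 resolution is only an example.

```python
import math

CAMERA_FOCAL = 12.29  # mm, same value as above
SENSOR_WIDTH = 17.27  # mm, same value as above


def horizontal_fov(focal_mm, sensor_width_mm):
    """Horizontal field of view (radians) of a pinhole camera."""
    return 2.0 * math.atan(sensor_width_mm / (2.0 * focal_mm))


def normalize_center(cx, cy, width, height):
    """Map pixel (cx, cy) to [-1, 1]^2 and flip y so that it points up."""
    x = cx / (width / 2.0) - 1.0
    y = cy / (height / 2.0) - 1.0
    return x, -y


fov = horizontal_fov(CAMERA_FOCAL, SENSOR_WIDTH)
print(math.degrees(fov))
print(normalize_center(0, 0, 640, 480))  # top-left pixel corner
```

The y flip is needed because image rows grow downward, while the normalized camera plane has y growing upward.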

## Exercise 1

Query the dataset object at various indices. Then, use the outputs to visualize the original frames and the SAM-2 segmentations, to ensure that the input data is good to go for the rest of the pipeline.

In [None]:
# @title Your code here:



# 2. Defining reconstruction model

Now that our data is ready to be used, we define the model we will optimize to represent the scene. The first part is a SMPL-inspired **template model** controlling several features of the reconstructed dolphin, which are defined in the following cells (fine-grained details are beyond the scope of this tutorial).
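
At the core of SMPL-style template models is linear blend skinning: each vertex is moved by a weighted average of per-bone rigid transforms. The toy NumPy sketch below illustrates only that idea; it is not the notebook's `LBS` class, and the function name and the two-bone example are made up for illustration.

```python
import numpy as np


def linear_blend_skinning(vertices, weights, transforms):
    """Toy LBS: vertices (V, 3), weights (V, J) with rows summing to 1,
    transforms (J, 4, 4) homogeneous bone transforms."""
    homo = np.concatenate([vertices, np.ones((len(vertices), 1))], axis=1)  # (V, 4)
    blended = np.einsum('vj,jab->vab', weights, transforms)                 # (V, 4, 4)
    posed = np.einsum('vab,vb->va', blended, homo)                          # (V, 4)
    return posed[:, :3]


# Three vertices on a line, two bones; the second bone is shifted by +1 along x.
vertices = np.array([[0., 0., 0.], [1., 0., 0.], [2., 0., 0.]])
weights = np.array([[1., 0.], [0.5, 0.5], [0., 1.]])
transforms = np.stack([np.eye(4), np.eye(4)])
transforms[1, 0, 3] = 1.0

posed = linear_blend_skinning(vertices, weights, transforms)
print(posed)
```

The middle vertex, influenced equally by both bones, moves by half of the second bone's translation, which is what gives skinned meshes their smooth bending.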

In [None]:
# @title Constants

BASE_DOLPHIN_SIZE = 2.6 #m  Adult size of the Indo-Pacific bottlenose dolphin
BASE_MODEL_PARTS = 10   #   Number of bones in the original template model


In [None]:
# @title [[Ignore]] Chumpy backend for template model loading
import cv2
import pickle
import chumpy as ch
import scipy.sparse as sp
from chumpy.ch import MatVecMult
use_python_3 = True

class Rodrigues(ch.Ch):
    dterms = 'rt'

    def compute_r(self):
        return cv2.Rodrigues(self.rt.r)[0]

    def compute_dr_wrt(self, wrt):
        if wrt is self.rt:
            return cv2.Rodrigues(self.rt.r)[1].T

def ischumpy(x): return hasattr(x, 'dterms')

def verts_decorated(trans, pose,
    v_template, J, weights, kintree_table, bs_style, f,
    bs_type=None, posedirs=None, betas=None, shapedirs=None, want_Jtr=False):

    for which in [trans, pose, v_template, weights, posedirs, betas, shapedirs]:
        if which is not None:
            assert ischumpy(which)

    v = v_template

    if shapedirs is not None:
        if betas is None:
            betas = ch.zeros(shapedirs.shape[-1])
        v_shaped = v + shapedirs.dot(betas)
    else:
        v_shaped = v

    if posedirs is not None:
        v_posed = v_shaped + posedirs.dot(posemap(bs_type)(pose))
    else:
        v_posed = v_shaped

    v = v_posed

    if sp.issparse(J):
        regressor = J
        J_tmpx = MatVecMult(regressor, v_shaped[:,0])
        J_tmpy = MatVecMult(regressor, v_shaped[:,1])
        J_tmpz = MatVecMult(regressor, v_shaped[:,2])
        J = ch.vstack((J_tmpx, J_tmpy, J_tmpz)).T
    else:
        assert(ischumpy(J))

    assert(bs_style=='lbs')
    result, Jtr = lbs_verts_core(pose, v, J, weights, kintree_table, want_Jtr=True, xp=ch)

    tr = trans.reshape((1,3))
    result = result + tr
    Jtr = Jtr + tr

    result.trans = trans
    result.f = f
    result.pose = pose
    result.v_template = v_template
    result.J = J
    result.weights = weights
    result.kintree_table = kintree_table
    result.bs_style = bs_style
    result.bs_type =bs_type
    if posedirs is not None:
        result.posedirs = posedirs
        result.v_posed = v_posed
    if shapedirs is not None:
        result.shapedirs = shapedirs
        result.betas = betas
        result.v_shaped = v_shaped
    if want_Jtr:
        result.J_transformed = Jtr
    return result

def global_rigid_transformation(pose, J, kintree_table, xp):
    results = {}
    pose = pose.reshape((-1,3))
    id_to_col = {kintree_table[1,i] : i for i in range(kintree_table.shape[1])}
    parent = {i : id_to_col[kintree_table[0,i]] for i in range(1, kintree_table.shape[1])}

    if xp == ch:
        rodrigues = lambda x : Rodrigues(x)
    else:
        import cv2
        rodrigues = lambda x : cv2.Rodrigues(x)[0]

    with_zeros = lambda x : xp.vstack((x, xp.array([[0.0, 0.0, 0.0, 1.0]])))
    results[0] = with_zeros(xp.hstack((rodrigues(pose[0,:]), J[0,:].reshape((3,1)))))

    for i in range(1, kintree_table.shape[1]):
        results[i] = results[parent[i]].dot(with_zeros(xp.hstack((
            rodrigues(pose[i,:]),
            ((J[i,:] - J[parent[i],:]).reshape((3,1)))
            ))))

    pack = lambda x : xp.hstack([np.zeros((4, 3)), x.reshape((4,1))])

    results = [results[i] for i in sorted(results.keys())]
    results_global = results

    if True:
        results2 = [results[i] - (pack(
            results[i].dot(xp.concatenate( ( (J[i,:]), 0 ) )))
            ) for i in range(len(results))]
        results = results2
    result = xp.dstack(results)
    return result, results_global

def lbs_verts_core(pose, v, J, weights, kintree_table, want_Jtr=False, xp=ch):
    A, A_global = global_rigid_transformation(pose, J, kintree_table, xp)
    T = A.dot(weights.T)

    rest_shape_h = xp.vstack((v.T, np.ones((1, v.shape[0]))))

    v =(T[:,0,:] * rest_shape_h[0, :].reshape((1, -1)) +
        T[:,1,:] * rest_shape_h[1, :].reshape((1, -1)) +
        T[:,2,:] * rest_shape_h[2, :].reshape((1, -1)) +
        T[:,3,:] * rest_shape_h[3, :].reshape((1, -1))).T

    v = v[:,:3]

    if not want_Jtr:
        return v
    Jtr = xp.vstack([g[:3,3] for g in A_global])
    return (v, Jtr)

def verts_core(pose, v, J, weights, kintree_table, bs_style, want_Jtr=False, xp=ch):

    if xp == ch:
        assert(hasattr(pose, 'dterms'))
        assert(hasattr(v, 'dterms'))
        assert(hasattr(J, 'dterms'))
        assert(hasattr(weights, 'dterms'))
    assert(bs_style=='lbs')
    result = lbs_verts_core(pose, v, J, weights, kintree_table, want_Jtr, xp)

    return result

def lrotmin(p):
    if isinstance(p, np.ndarray):
        p = p.ravel()[3:]
        return np.concatenate([(cv2.Rodrigues(np.array(pp))[0]-np.eye(3)).ravel() for pp in p.reshape((-1,3))]).ravel()
    if p.ndim != 2 or p.shape[1] != 3:
        p = p.reshape((-1,3))
    p = p[1:]
    return ch.concatenate([(Rodrigues(pp)-ch.eye(3)).ravel() for pp in p]).ravel()

def posemap(s):
    if s == 'lrotmin':
        return lrotmin
    else:
        raise Exception('Unknown posemapping: %s' % (str(s),))

def backwards_compatibility_replacements(dd):

    # replacements
    if 'default_v' in dd:
        dd['v_template'] = dd['default_v']
        del dd['default_v']
    if 'template_v' in dd:
        dd['v_template'] = dd['template_v']
        del dd['template_v']
    if 'joint_regressor' in dd:
        dd['J_regressor'] = dd['joint_regressor']
        del dd['joint_regressor']
    if 'blendshapes' in dd:
        dd['posedirs'] = dd['blendshapes']
        del dd['blendshapes']
    if 'J' not in dd:
        dd['J'] = dd['joints']
        del dd['joints']

    # defaults
    if 'bs_style' not in dd:
        dd['bs_style'] = 'lbs'

def ready_arguments(fname_or_dict):

    if not isinstance(fname_or_dict, dict):
        if use_python_3:
            dd = pickle.load(open(fname_or_dict, "rb"), encoding='latin1')
        else:
            dd = pickle.load(open(fname_or_dict, "rb"))
    else:
        dd = fname_or_dict

    backwards_compatibility_replacements(dd)

    want_shapemodel = 'shapedirs' in dd
    nposeparms = dd['kintree_table'].shape[1]*3

    if 'trans' not in dd:
        dd['trans'] = np.zeros(3)
    if 'pose' not in dd:
        dd['pose'] = np.zeros(nposeparms)
    if 'shapedirs' in dd and 'betas' not in dd:
        dd['betas'] = np.zeros(dd['shapedirs'].shape[-1])

    for s in ['v_template', 'weights', 'posedirs', 'pose', 'trans', 'shapedirs', 'betas', 'J']:
        if (s in dd) and not hasattr(dd[s], 'dterms'):
            dd[s] = ch.array(dd[s])

    if want_shapemodel:
        dd['v_shaped'] = dd['shapedirs'].dot(dd['betas'])+dd['v_template']
        v_shaped = dd['v_shaped']
        J_tmpx = MatVecMult(dd['J_regressor'], v_shaped[:,0])
        J_tmpy = MatVecMult(dd['J_regressor'], v_shaped[:,1])
        J_tmpz = MatVecMult(dd['J_regressor'], v_shaped[:,2])
        dd['J'] = ch.vstack((J_tmpx, J_tmpy, J_tmpz)).T
        dd['v_posed'] = v_shaped + dd['posedirs'].dot(posemap(dd['bs_type'])(dd['pose']))
    else:
        dd['v_posed'] = dd['v_template'] + dd['posedirs'].dot(posemap(dd['bs_type'])(dd['pose']))

    return dd

def load_model(fname_or_dict):
    dd = ready_arguments(fname_or_dict)

    args = {
        'pose': dd['pose'],
        'v': dd['v_posed'],
        'J': dd['J'],
        'weights': dd['weights'],
        'kintree_table': dd['kintree_table'],
        'xp': ch,
        'want_Jtr': True,
        'bs_style': dd['bs_style']
    }

    result, Jtr = verts_core(**args)
    result = result + dd['trans'].reshape((1,3))
    result.J_transformed = Jtr + dd['trans'].reshape((1,3))

    for k, v in dd.items():
        setattr(result, k, v)

    return result


## Exercise 2

We will now implement a crucial part of our articulated model: the [Rodrigues' rotation formula](https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula). This operation is the backbone of our brand of linear blend skinning (LBS), a simple linear model for mesh animation. The goal is to efficiently map 3D rotations in axis-angle format (a 3-vector whose direction is the rotation axis and whose norm is the rotation angle in radians) to rotation matrices that we can actually use for computation.
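
For reference, the formula for a single rotation vector can be sketched in NumPy as follows. This is only an unbatched illustration with a hypothetical helper name (`rodrigues`); the exercise itself asks for a batched PyTorch version, for which going through quaternions is one possible route.

```python
import numpy as np


def rodrigues(rotvec, eps=1e-8):
    """Rotation matrix for one axis-angle vector (direction = axis, norm = angle)."""
    rotvec = np.asarray(rotvec, dtype=np.float64)
    angle = np.linalg.norm(rotvec)
    if angle < eps:
        # Near-zero rotation: avoid dividing by a vanishing norm.
        return np.eye(3)
    k = rotvec / angle
    # Cross-product (skew-symmetric) matrix of the unit axis.
    K = np.array([[0.0, -k[2], k[1]],
                  [k[2], 0.0, -k[0]],
                  [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * (K @ K)


# A quarter turn around z sends the x axis onto the y axis.
R = rodrigues([0.0, 0.0, np.pi / 2])
print(np.round(R @ np.array([1.0, 0.0, 0.0]), 6))
```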

In [None]:
# @title Your code here:


def quat_to_rotmat(quat):
    """Convert quaternion coefficients to rotation matrix.
    Args:
        quat: size = [B, 4] 4 <===>(w, x, y, z)
    Returns:
        Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
    """
    pass


def batch_rodrigues(theta):
    """Convert axis-angle representation to rotation matrix.
    Args:
        theta: size = [B, 3]
    Returns:
        Rotation matrix corresponding to the axis-angle vector -- size = [B, 3, 3]
    """
    pass

## Template model

In [None]:
# @title Linear Blend Skinning
from torch.nn import functional as F

class LBS:
    '''
    Implementation of linear blend skinning, with additional bone and scale
    Input:
        V (BN, V, 3): vertices to pose and shape
        pose (BN, J, 3, 3) or (BN, J, 3): pose in rot or axis-angle
        bone (BN, K): allow for direct change of relative joint distances
        scale (1): scale the whole kinematic tree
    '''
    def __init__(self, J_regressor, parents, weights, shapedirs, segmentation):
        self.n_joints = weights.shape[1]
        self.parents = parents
        self.weights = weights[None].float()
        self.shapedirs = shapedirs
        self.J_regressor = J_regressor
        self.seg_grouping = segmentation['grouping']
        self.v_labels = segmentation['v_labels']
        self.v_labels_coarse = segmentation['v_labels_coarse']
        self.n_groups = self.v_labels_coarse.max().item() + 1

        self.gidx = []
        self.seg_idx_coarse = []
        for i in range(self.n_groups):
            self.gidx.append(torch.where(self.seg_grouping == i)[0].cpu().tolist())
            self.seg_idx_coarse.append(torch.where(self.v_labels_coarse == i)[0])

        self.seg_idx = []
        self.parent_idx = []
        for i in range(self.n_joints):
            self.seg_idx.append(torch.where(self.v_labels == i)[0])
            self.parent_idx.append(torch.where(self.parents == i)[0].cpu().tolist())

    def __call__(self, V, pose, scale, betas=None, to_rotmats=True):
        batch_size = pose.shape[0]
        device = pose.device

        if betas is not None:
            V, J = self.apply_betas(V, betas)
        else:
            J = (self.J_regressor.unsqueeze(-1) * V.view(1, 1, -1, 3)).sum(dim=-2)

        V = V.expand(batch_size, -1, -1) * scale

        h_joints = F.pad(J.unsqueeze(-1), [0, 0, 0, 1], value=0)
        kin_tree = torch.cat([J[:, [0], :], J[:, 1:] - J[:, self.parents[1:]]], dim=1).unsqueeze(-1)

        V = F.pad(V.unsqueeze(-1), [0, 0, 0, 1], value=1)
        kin_tree = scale * kin_tree.expand(batch_size, -1, -1, -1)

        if to_rotmats:
            pose = batch_rodrigues(pose.view(-1, 3))
        pose = pose.view([batch_size, -1, 3, 3])
        T = torch.zeros([batch_size, self.n_joints, 4, 4]).float().to(device)
        T[:, :, -1, -1] = 1
        T[:, :, :3, :] = torch.cat([pose, kin_tree], dim=-1)
        T_rel = [T[:, 0]]
        for i in range(1, self.n_joints):
            T_rel.append(T_rel[self.parents[i]] @ T[:, i])
        T_rel = torch.stack(T_rel, dim=1)
        T_rel[:, :, :, [-1]] -= T_rel.clone() @ (h_joints * scale)
        T_ = self.weights @ T_rel.view(batch_size, self.n_joints, -1)
        T_ = T_.view(batch_size, -1, 4, 4)
        V = T_ @ V

        return V[:, :, :3, 0]

    def apply_betas(self, V, betas):

        V_betas = V + ((self.shapedirs * betas.view(self.n_groups, 1, 1, 4)).sum(dim=-1))
        J = torch.zeros((self.n_groups, self.n_joints, 3), device=V.device)
        V_shaped = torch.zeros_like(V)
        for i in range(self.n_groups):
            seg_idx = self.seg_idx_coarse[i]
            V_shaped[:, seg_idx, :] = V_betas[i, seg_idx, :]
            J[i] = (self.J_regressor.unsqueeze(-1) * V_betas[i].view(1, 1, -1, 3)).sum(dim=-2)

        # Fix this to obtain joint displacements
        Dv = torch.zeros((self.n_joints, 3), device=V.device)
        for seg in range(self.n_groups):
            gidx = self.gidx[seg]
            for p in gidx:
                if p > 0:
                    parent = self.parents[p]
                    # Look at where the joint would be in the shape of the parent
                    if self.seg_grouping[parent] == self.seg_grouping[p]:
                        dv = Dv[parent,:]
                    else:
                        dv = J[self.seg_grouping[parent], p, :] - J[self.seg_grouping[p], p, :]
                    Dv[p, :] = dv
                    idx = self.seg_idx[p]
                    V_shaped[:, idx, :] = V_shaped[:, idx, :] + dv
                    # find the children
                    idx = self.parent_idx[p]

                    for i in idx:
                        # Add the displacement to the children's joints as I have translated the whole part
                        J[self.seg_grouping[p], i, :] += dv

        J = (self.J_regressor.unsqueeze(-1) * V_shaped.view(1, 1, -1, 3)).sum(dim=-2)
        return V_shaped, J


In [None]:
# @title Template model definition

import kaolin as kal
import pickle as pkl
from torch.nn import functional as F


class DolphinModel:

    def __init__(self, data_dir: pathlib.Path, device=torch.device('cpu')):

        self.device = device
        data_dir = pathlib.Path(data_dir)

        self.vert2kpt = torch.from_numpy(np.load(data_dir / 'dol_verts2kp.npy')).to(device, dtype=torch.float)

        model = load_model(data_dir / 'dolphin_model.pkl')
        # self.J = torch.from_numpy(model.J.r).to(device, dtype=torch.float).unsqueeze(0)
        self.J_regressor = torch.from_numpy(model.J_regressor.todense()).to(device, dtype=torch.float).unsqueeze(0)
        self.V = torch.from_numpy(model.v_template.r).to(device, dtype=torch.float).unsqueeze(0)
        self.F = torch.from_numpy(model.f).to(device, dtype=torch.long)
        self.weights = torch.from_numpy(model.weights.r).to(device, dtype=torch.float)
        self.kintree_table = torch.from_numpy(model.kintree_table).to(device, dtype=torch.long)
        self.parents = self.kintree_table[0]

        kal_mesh = kal.io.obj.import_mesh(data_dir / 'dolphin_template.obj', with_materials=True)
        self.face_uvs = kal_mesh.face_uvs.to(device, dtype=torch.float).unsqueeze(0)

        with open(data_dir / 'dolphin_seg_gloss.pkl', 'rb') as f:
            segmentation = pkl.load(f)
        vert_parts = torch.from_numpy(segmentation['v_labels']).to(device, dtype=torch.long)

        with open(data_dir / 'dolphin_coarse_seg.pkl', 'rb') as f:
            coarse_parts = pkl.load(f)
        coarse_vert_parts = torch.from_numpy(coarse_parts['coarse_v_labels']).to(device, dtype=torch.long)
        coarse_parts_idx = coarse_parts['coarse_parts_idx']
        segmentation_grouping = -1 * torch.ones((BASE_MODEL_PARTS), dtype=int, device=self.device)
        for i in range(BASE_MODEL_PARTS):
            for j in range(len(coarse_parts_idx)):
                if i in coarse_parts_idx[j]:
                    segmentation_grouping[i] = j

        shapedirs = torch.from_numpy(model.shapedirs.r).to(device, dtype=torch.float).unsqueeze(0)
        self.LBS = LBS(self.J_regressor, self.parents, self.weights, shapedirs, {
            'grouping': segmentation_grouping, 'v_labels': vert_parts, 'v_labels_coarse': coarse_vert_parts
        })

        self.global_scale = BASE_DOLPHIN_SIZE / (self.V[0, :, 2].max() - self.V[0, :, 2].min())
        self.fwd_vector = torch.tensor([0.0, 0.0, 1.0], device=device)

        self.bone_selection = torch.tensor([1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], device=device)

    def __call__(self, global_pose, body_pose, translation, part_scales=None, pose2rot=True):

        # concatenate global and body pose, zeroing out bones disabled in bone_selection
        pose = torch.cat([global_pose, body_pose], dim=1) * self.bone_selection.view(1, -1, 1)

        # LBS
        verts = self.LBS(self.V, pose, self.global_scale, betas=part_scales, to_rotmats=pose2rot)

        # Final output after articulation
        output = {'vertices': verts + translation}

        return output

After the template model, it's time to define the set of variables we will optimize.

In [None]:
import torch.nn as nn

NUM_FRAMES = len(dataset)
INIT_ROT_UP = 0.0
INIT_ROT_FWD = 0.0

device = 'cuda' if torch.cuda.is_available() else 'cpu'

dolphin_model = DolphinModel(MODEL_DIR, device)

# Create num_frames pose parameters
global_poses = nn.Parameter(torch.tensor([[[0.0, np.deg2rad(INIT_ROT_UP), np.deg2rad(INIT_ROT_FWD)]]], device=device).repeat(NUM_FRAMES, 1, 1), requires_grad=True)
joint_poses = nn.Parameter(torch.zeros([NUM_FRAMES, dolphin_model.LBS.n_groups, 3], device=device), requires_grad=True)
translations = nn.Parameter(torch.zeros([NUM_FRAMES, 1, 3], device=device), requires_grad=True)

# Dolphin shape can't change over time
part_scales = nn.Parameter(torch.zeros((dolphin_model.LBS.n_groups, 4), requires_grad=True, device=device))

# Create lighting parameters
azimuth = nn.Parameter(torch.full((1,), torch.pi / 2., device=device), requires_grad=True)
elevation = nn.Parameter(torch.full((1,), torch.pi / 2., device=device), requires_grad=True)
strength = nn.Parameter(torch.full((1,), 1.0, device=device), requires_grad=True)
angle = nn.Parameter(torch.full((1,), torch.pi / 2.0, device=device), requires_grad=True)

# Create textures
dolphin_albedo = nn.Parameter(torch.full([1, 1, 512, 512], 0.5, device=device), requires_grad=True)
dolphin_specular_color = nn.Parameter(torch.full([3,], 1.0, device=device), requires_grad=False)
dolphin_roughness = nn.Parameter(torch.full([1,], 0.1, device=device), requires_grad=False)

# Create water filter parameter
water_absorption = nn.Parameter(torch.tensor([1.0, 0.0, 0.0], device=device), requires_grad=True)


### Exercise 3

Due to our particular setting (we can assume the camera to be perfectly perpendicular to the object), we can employ the drone metadata and the SAM-2 masks to initialize the location of the dolphin in the scene, so that the dolphin object and the segmentation mask are intersecting at step 0.
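
The geometric idea can be sketched independently of kaolin. For an ideal pinhole camera looking straight at a plane at distance `depth`, a normalized image coordinate in [-1, 1] scales linearly with the lateral offset on that plane, so the projection is easy to invert. The snippet below works in camera space only and uses hypothetical function names; how camera-space offsets map to world axes depends on the camera's `up` vector, which is not modeled here.

```python
import math


def project(x_cam, y_cam, depth, fov_x, width, height):
    """Camera-space offsets on a plane at `depth` -> normalized coords in [-1, 1]."""
    tan_x = math.tan(fov_x / 2.0)
    tan_y = tan_x * height / width  # vertical half-extent follows from the aspect ratio
    return x_cam / (depth * tan_x), y_cam / (depth * tan_y)


def unproject(ndc_x, ndc_y, depth, fov_x, width, height):
    """Inverse of `project` for a point known to lie on the plane at `depth`."""
    tan_x = math.tan(fov_x / 2.0)
    tan_y = tan_x * height / width
    return ndc_x * depth * tan_x, ndc_y * depth * tan_y


# Example values only: a mask center at (0.25, -0.5) seen from 20 m.
x_m, y_m = unproject(0.25, -0.5, depth=20.0, fov_x=math.radians(70.0), width=1920, height=1080)
print(x_m, y_m)
```

The inversion is only possible because the depth is known from the drone altitude; in general a single pixel corresponds to a whole ray.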

In [None]:
import math
from torch.utils.data import DataLoader
from kaolin.render.camera.intrinsics import CameraFOV

# Employs the normalized center of pixel coordinates of the segmentation masks to compute
# an initialization of each frame's global translation which guarantees alignment between
# the rendered shape and the segmentation mask.

loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_batch)
for i, batch in enumerate(tqdm(loader, desc="Initializing translations...")):
    # We have to invert the following flow of operations:
    # 1. extrinsics projection (easy)
    # 2. intrinsics projection (probably doable only in our special case)
    # 3. normalization (doable only in our special case)
    camera = batch['cam']
    camera_altitude = -camera.extrinsics.t[0, 2, 0:1]
    # We want the y (up) coord for this point in 3D world space to be 0
    image_plane_translation = batch['segm_center'][0]
    # We know that the input for the intrinsic projection is (in this specific setting):
    # [original_x, original_z, original_y - camera_altitude]
    # But since original_y = 0, we get [original_x, original_z, -camera_altitude]
    # Also for all cameras, the normalizing coordinate (w) is exactly -input_z=-camera_altitude
    # Therefore we begin with step 3:
    unnormalized_coords = image_plane_translation * camera_altitude
    # We now only need the 3rd coordinate to invert the intrinsic projection
    intrinsic_proj = camera.intrinsics.projection_matrix()
    z_transform = intrinsic_proj[0, 2, :]
    # We put unnormalized_coords in the first two spots but it doesn't really matter (in our case)
    # since z_proj_output has 0 in the first 2 components
    intrinsic_proj_input = torch.cat([unnormalized_coords, -camera_altitude, torch.ones([1], dtype=torch.float)], dim=-1)
    z_proj_output = (z_transform * intrinsic_proj_input).sum(dim=-1, keepdim=True)
    projection_output = torch.cat([unnormalized_coords, z_proj_output, camera_altitude], dim=-1)
    # And we invert step 2
    projection_input = (intrinsic_proj[0].inverse() @ projection_output)[:3]
    # And we invert step 1
    original_point = camera.extrinsics.R[0].T @ (projection_input - camera.extrinsics.t[0, :, 0])
    translations.data[i, ...] = original_point.unsqueeze(0).to(device)

Now, for each frame $i$:
1. Obtain the initialized translation $T_i$
2. Use the camera object $C_i$ from the dataset to project it to pixel space
3. **Tracking**: Create a black image with the same resolution as the data. Then, find each pixel corresponding to some projection obtained at step (2). Color said pixels according to their temporal order and display the image.
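
As a hint for step 3, the sketch below paints a sequence of already-projected points onto a black image, coloring them from blue (first frame) to red (last frame). It assumes the projections are given in normalized coordinates in [-1, 1] with y pointing up; the helper name `draw_track` and the tiny 5x3 example are hypothetical.

```python
import numpy as np


def draw_track(ndc_points, width, height):
    """Paint (T, 2) normalized points onto a black (H, W, 3) image, colored by time."""
    image = np.zeros((height, width, 3), dtype=np.float32)
    # Normalized [-1, 1] -> pixel indices; rows grow downward, hence the y flip.
    cols = np.round((ndc_points[:, 0] + 1.0) / 2.0 * (width - 1)).astype(int)
    rows = np.round((1.0 - ndc_points[:, 1]) / 2.0 * (height - 1)).astype(int)
    inside = (cols >= 0) & (cols < width) & (rows >= 0) & (rows < height)
    # Temporal color ramp: blue for early frames, red for late frames.
    t = np.linspace(0.0, 1.0, len(ndc_points))
    colors = np.stack([t, np.zeros_like(t), 1.0 - t], axis=1)
    image[rows[inside], cols[inside]] = colors[inside]
    return image


track = np.array([[-1.0, 1.0], [0.0, 0.0], [1.0, -1.0]])
image = draw_track(track, width=5, height=3)
print(image[0, 0], image[1, 2], image[2, 4])
```

The resulting array can be displayed with any image viewer, for instance `matplotlib.pyplot.imshow`.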

In [None]:
# @title Your code here:

# 3. Fitting with differentiable rendering

At this stage, we need to implement a simple differentiable renderer, which will allow us to:
1. Render the current scene parameters into an image
2. Compare the output with the data frames
3. Compute losses and backpropagate gradients

We will use the DIB-R algorithm implemented in the NVIDIA kaolin library.
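
To make the "compare" step concrete, here is a sketch of one common way to score a soft rendered mask against a segmentation mask: a soft intersection-over-union loss. This is an illustrative assumption, not necessarily the loss used in this demo, and it is written in NumPy for clarity; in the actual pipeline the same expression would be computed on PyTorch tensors so that gradients can flow back to the scene parameters.

```python
import numpy as np


def soft_iou_loss(soft_mask, target_mask, eps=1e-6):
    """1 - soft IoU between a rendered mask in [0, 1] and a target mask in [0, 1]."""
    intersection = (soft_mask * target_mask).sum()
    union = (soft_mask + target_mask - soft_mask * target_mask).sum()
    return 1.0 - intersection / (union + eps)


target = np.zeros((4, 4))
target[1:3, 1:3] = 1.0           # a 2x2 square as the "segmentation"
shifted = np.roll(target, 1, 1)  # the same square moved one pixel to the right

print(soft_iou_loss(target, target), soft_iou_loss(shifted, target))
```

The loss is close to 0 for a perfect overlap and grows toward 1 as the rendered silhouette drifts away from the target.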

In [None]:
# @title [[Ignore]] Rendering helper functions

@torch.jit.script
def _dot(a, b):
    """Compute dot product of two tensors on the last axis."""
    return torch.sum(a * b, dim=-1, keepdim=True)

@torch.jit.script
def _ggx_v1(m2, nDotX):
    """Helper for computing the Smith visibility term with Trowbridge-Reitz (GGX) distribution"""
    return 1. / (nDotX + torch.sqrt(m2 + (1. - m2) * nDotX * nDotX))


def sg_warp_specular_term_full_diff(amplitude, direction, sharpness, normal,
                                    roughness, view, spec_albedo):
    assert amplitude.ndim == 2 and amplitude.shape[-1] == 3
    assert direction.shape == amplitude.shape
    assert sharpness.shape == amplitude.shape[:1]
    assert normal.ndim == 2 and normal.shape[-1] == 3
    assert roughness.shape == normal.shape[:1]
    assert view.shape == normal.shape
    assert spec_albedo.shape == normal.shape
    ndf_amplitude, ndf_direction, ndf_sharpness = kal.render.lighting.sg_distribution_term(
        normal, roughness)
    ndf_amplitude, ndf_direction, ndf_sharpness = kal.render.lighting.sg_warp_distribution(
        ndf_amplitude, ndf_direction, ndf_sharpness, view)
    ndl = torch.clamp(_dot(normal, ndf_direction), min=0., max=1.)
    ndv = torch.clamp(_dot(normal, view), min=0., max=1.)
    h = ndf_direction + view
    h = h / torch.sqrt(_dot(h, h))
    ldh = torch.clamp(_dot(ndf_direction, h), min=0., max=1.)

    output = kal.render.lighting.sg.unbatched_reduced_sg_inner_product(
        ndf_amplitude, ndf_direction, ndf_sharpness, amplitude, direction, sharpness)
    m2 = (roughness * roughness).unsqueeze(-1)
    output = output * (_ggx_v1(m2, ndl) * _ggx_v1(m2, ndv))
    output = output * kal.render.lighting.fresnel(ldh, spec_albedo)
    output = output * ndl
    return torch.clamp(output, min=0.)

In [None]:
# @title Constants
RENDER_RESOLUTION = dataset.data_resolution
SIGMAINV = 100000   # Defines sharpness of the optimizable band of pixels around object
BOXLEN = 0.01       # Defines box of influence over pixels for each mesh element
KNUM = 40           # Maximum number of mesh elements contributing to each pixel

In [None]:
# @title Dolphin skinning

def dolphin_skinning(t):
    global_batch = global_poses[t, ...]
    joint_batch = joint_poses[t, ...].view(-1, dolphin_model.LBS.n_groups, 3)[:, dolphin_model.LBS.seg_grouping, :][:, 1:, :]  # filter out spine joint
    trans_batch = translations[t, ...]
    parts_batch = part_scales

    return dolphin_model(global_batch, joint_batch, trans_batch, part_scales=parts_batch)['vertices']


### Exercise 4: Basic rendering

This function is the core routine of our differentiable renderer. Complete the code below following this template:

1. $V_e$ = Apply the extrinsic camera transform to the mesh vertices, projecting them to "camera space". *(Hint: `cam.extrinsics.transform`)*
2. $V_i$ = Apply the intrinsic camera transform to $V_e$. The vertices are now projected to "render space", so we only care about their first two coordinates.
3. $F_e, F_i$ = Index the outputs of points 1 and 2 using the mesh's face indices. *(Hint: `kaolin.ops.mesh.index_vertices_by_faces`)*
4. $F_n$ = Compute face normals in camera space. *(Hint: `kaolin.ops.mesh.face_normals`)*
5. We need to define which mesh features we want DIB-R to interpolate. For texturing (later), we will interpolate face uvs and normals, and we will pass a tensor of 1s to obtain a hard rendering mask. We also want to interpolate the sea level height of each rendered mesh point. *(Hints: the up axis is y, the 2nd coordinate of vertex matrices. Also, the value needs to be per-face.)*
6. `rasterize` needs the following information:
    1. Depth (z coordinate) of face vertices in camera space
    2. Face vertices in render space
    3. Which faces are "looking" at the camera *(Hint: consider the z coordinate of their normals)*
7. `dibr_soft_mask` needs the face vertices in render space to compute the soft mask.
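
To make the geometry of steps 1 to 6 concrete, here is a toy NumPy sketch on a single triangle. It is not the Kaolin solution to the exercise: the rotation `R`, translation `t` and `focal` length are hypothetical stand-ins for the camera object, and we assume an OpenGL-style camera that looks down the negative z axis. Taking the height attribute from world-space vertices is also an assumption.

```python
import numpy as np

# Toy mesh: one triangle facing +z, vertices listed counter-clockwise.
verts = np.array([[0.0, 0.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0]])
faces = np.array([[0, 1, 2]])

# Hypothetical extrinsics: identity rotation and a translation that places
# the mesh 5 units in front of a camera looking down -z.
R = np.eye(3)
t = np.array([0.0, 0.0, -5.0])
focal = 2.0  # hypothetical focal length in normalized units

# Step 1: world space -> camera space.
verts_cam = verts @ R.T + t

# Step 2: camera space -> render space (perspective divide); keep x, y only.
verts_img = focal * verts_cam[:, :2] / -verts_cam[:, 2:3]

# Step 3: gather per-face vertex data, shape (num_faces, 3, dims).
face_verts_cam = verts_cam[faces]
face_verts_img = verts_img[faces]

# Step 4: unit face normals in camera space from two edge vectors.
e1 = face_verts_cam[:, 1] - face_verts_cam[:, 0]
e2 = face_verts_cam[:, 2] - face_verts_cam[:, 0]
normals = np.cross(e1, e2)
normals /= np.linalg.norm(normals, axis=-1, keepdims=True)

# Step 5: per-face height attribute, i.e. the y coordinate of each face
# vertex, kept with a trailing channel axis: shape (num_faces, 3, 1).
face_height = verts[faces][..., 1:2]

# Step 6.3: with the camera looking down -z, front faces have normal z > 0.
front_facing = normals[:, 2] > 0
```

In the actual exercise the same quantities come from the Kaolin camera and mesh utilities named in the hints, with an extra leading batch dimension.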

In [None]:
# @title Your code here:

def dibr_rendering(verts, cam):
    batch_size = verts.shape[0]
    mesh = kal.rep.SurfaceMesh(verts, dolphin_model.F, face_uvs=dolphin_model.face_uvs.repeat(batch_size, 1, 1, 1))
    mesh.face_uvs[..., 1] = 1 - mesh.face_uvs[..., 1]

    ???

    face_attributes = [
        mesh.face_uvs, mesh.face_normals, ???,
        torch.ones((batch_size, mesh.faces.shape[0], 3, 1), device=face_vertices_image.device)
    ]

    image_features, rendered_faces = kal.render.mesh.rasterize(RENDER_RESOLUTION[1], RENDER_RESOLUTION[0],
                                                               ???, ???, face_attributes, valid_faces=???, backend='cuda')
    soft_mask = kal.render.mesh.dibr_soft_mask(???, rendered_faces,
                                               SIGMAINV, BOXLEN, KNUM)[..., None]

    return image_features, soft_mask


## Completing the renderer

In [None]:
# @title Texturing
# @markdown The texture coordinates and hard mask we need here are the first and last face attributes we interpolated with DIB-R!

def texturing(texture_coords, hard_mask):
    batch_size = texture_coords.shape[0]
    image = kal.render.mesh.texture_mapping(texture_coords, dolphin_albedo.repeat(batch_size, 3, 1, 1), mode='bilinear')
    return image * hard_mask

### Exercise 5: Lighting effects

We will implement a simple exponential, frequency-based filter on the image. It works as a post-processing step that attenuates specific frequencies of the pixel colors depending on how far below sea level the rendered object is. The filter is controlled by one of the parameters we optimize, `water_absorption`.

For pixels $ij$ where the sea level height is below 0, we want to modify the image color as: $$I_{ij} \leftarrow I_{ij} \cdot \exp(w \cdot l_{ij}) $$
where $w$ indicates the `water_absorption` parameter and $l_{ij}$ is the sea level height at pixel $ij$.
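
As a quick sanity check of the update rule above, here is a small NumPy sketch on a two-pixel image. The per-channel values in `w` are made up for illustration, and treating `water_absorption` as one coefficient per RGB channel is an assumption, not something the notebook states.

```python
import numpy as np

# Hypothetical 1x2 RGB image: both pixels are mid-gray.
image = np.full((1, 2, 3), 0.5)
# Sea level height per pixel: left is 2 units underwater, right is above water.
height = np.array([[[-2.0], [1.0]]])
# Hypothetical per-channel absorption: red fades fastest, blue slowest.
w = np.array([0.5, 0.1, 0.05])

# Only heights below 0 contribute; the exponent is then negative, so the
# factor lies in (0, 1] whenever w >= 0. Pixels above water get a factor of 1.
attenuation = np.exp(w * np.minimum(height, 0.0))
filtered = image * attenuation
```

The underwater pixel loses most of its red and keeps most of its blue, while the pixel above the surface is unchanged.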

In [None]:
# @title Your code here
# @markdown The image normals and sea level height we need here are the 2nd and 3rd face attributes we interpolated with DIB-R!

def lighting(cam, hard_mask, image, image_normals, sl_height_mapped):
    bool_mask = hard_mask.bool()[..., 0]
    direction = torch.cat(kal.ops.coords.spherical2cartesian(azimuth, elevation), dim=-1).unsqueeze(0)
    sg_sun = kal.render.lighting.SgLightingParameters.from_sun(direction, strength, angle)
    spec_albedo = hard_mask * dolphin_specular_color.view(1, 1, 1, -1)
    roughness = hard_mask * dolphin_roughness.view(1, 1, 1, -1)
    rays_d = -torch.stack(
        [kal.render.camera.generate_pinhole_rays(c)[1].reshape(*RENDER_RESOLUTION[::-1], 3) for c in cam], dim=0)
    diffuse_effect = kal.render.lighting.sg_diffuse_fitted(
        sg_sun.amplitude,
        direction,
        sg_sun.sharpness,
        image_normals[bool_mask, :],
        image[bool_mask, :]
    )
    specular_effect = sg_warp_specular_term_full_diff(
        sg_sun.amplitude,
        direction,
        sg_sun.sharpness,
        image_normals[bool_mask, :],
        roughness[bool_mask, :][..., 0],
        rays_d[bool_mask, :],
        spec_albedo[bool_mask, :]
    )
    image[bool_mask, :] = diffuse_effect + specular_effect

    # Apply water volume absorption:
    ???

    return image

## Exercise 6: Rendering function

Now, put all these sub-routines together to obtain a complete renderer. Here is the pseudocode:

1. Deform the template vertices
2. Run the core differentiable rendering procedure
3. Apply texturing
4. Apply lighting

In [None]:
# @title Your code here:

from typing import List

def render(t: List[int], cam: kal.render.camera.Camera):
    ???
    return image, hard_mask, soft_mask

# 4. Loss function design

Our loss function will be composed of multiple terms, which we will now implement. Remember that we can only optimize the rendered image, the soft mask (both only in the pixels where the parametrized object was hit), and the scene parameters directly (*e.g.* for regularization).

## Data fidelity terms

In [None]:
def photometric_loss(render, original, hard_mask):
    bool_mask = hard_mask.bool()[..., 0]
    return F.mse_loss(render[bool_mask, :], original[bool_mask, :])

def iou_loss(silhouette, segmentation):
    return kal.metrics.render.mask_iou(silhouette[..., 0], segmentation[..., 0])
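
To build intuition for the mask term, the sketch below computes a loss of the form 1 − IoU on soft masks in plain NumPy. To our understanding this is the same idea Kaolin's `mask_iou` follows, but the function here (`soft_iou_loss`, a hypothetical name) is only an illustration and not Kaolin's implementation.

```python
import numpy as np

def soft_iou_loss(pred, target, eps=1e-8):
    """1 - IoU for masks with values in [0, 1] and shape (batch, H, W)."""
    inter = (pred * target).sum(axis=(1, 2))
    union = (pred + target - pred * target).sum(axis=(1, 2))
    return float((1.0 - inter / (union + eps)).mean())

# A 2x2 square and the same square moved one pixel to the right.
target = np.zeros((1, 4, 4))
target[:, 1:3, 1:3] = 1.0
shifted = np.zeros((1, 4, 4))
shifted[:, 1:3, 2:4] = 1.0

loss_perfect = soft_iou_loss(target, target)
loss_shifted = soft_iou_loss(shifted, target)
```

A perfect overlap gives a loss close to 0, while the shifted square shares 2 of 6 covered pixels with the target and gets a loss of about 2/3. Because products and sums replace the boolean operations, the loss stays differentiable with respect to the soft mask.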

## Parameter distribution priors

In [None]:
def scale_prior(scale_params, unseen_axis):
    global_prior = (scale_params ** 2).mean()
    axes_scaling = scale_params[..., 1:]
    x_scale, y_scale, z_scale = axes_scaling.split(1, dim=-1)
    # Make unseen axis scaling similar to that of other axes
    if unseen_axis == 'x':
        per_axis_prior = (
            ((x_scale - y_scale.detach()) ** 2).sum(dim=-1) +
            ((x_scale - z_scale.detach()) ** 2).sum(dim=-1)
        ).mean()
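    elif unseen_axis == 'y':
        per_axis_prior = (
            ((x_scale.detach() - y_scale) ** 2).sum(dim=-1) +
            ((y_scale - z_scale.detach()) ** 2).sum(dim=-1)
        ).mean()
    else:
        # Only 'x' and 'y' are handled; fail loudly instead of hitting an undefined variable below
        raise ValueError(f"unseen_axis must be 'x' or 'y', got {unseen_axis!r}")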
    part_scales = scale_params[dolphin_model.LBS.seg_grouping]
    scale_s = ((part_scales[1:] - part_scales[dolphin_model.LBS.parents[1:]]) ** 2).mean()
    return {'global': global_prior, 'axes': per_axis_prior, 'smoothness': scale_s}

bone_prior_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0], device=device)  # Only one value for lower fins and tails (symmetry)

def pose_prior(joint_poses):
    return ((joint_poses ** 2) * bone_prior_weights.view(1, -1, 1)).mean()

## Temporal priors

In [None]:
def smoothness_prior(global_poses, joint_poses, translations):
    trans_s = ((translations[1:, ...] - translations[:-1, ...]) ** 2).mean()
    global_s = ((global_poses[1:, ...] - global_poses[:-1, ...]) ** 2).mean()
    joint_s = ((joint_poses[1:, ...] - joint_poses[:-1, ...]) ** 2).mean()
    return {'transl': trans_s, 'global': global_s, 'joint': joint_s}
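
The temporal prior above penalizes squared differences between consecutive frames. A minimal NumPy sketch of the same finite-difference idea, on two made-up 1D trajectories with identical start and end points, shows why this favors steady motion (`temporal_smoothness` is a hypothetical helper name):

```python
import numpy as np

def temporal_smoothness(x):
    """Mean squared difference between consecutive frames; x has shape (T, ...)."""
    return float(((x[1:] - x[:-1]) ** 2).mean())

# Steady drift from 0 to 1 in four equal steps.
smooth = np.linspace(0.0, 1.0, 5)[:, None]
# Same endpoints, but all the motion happens in a single jump.
jumpy = np.array([0.0, 0.0, 1.0, 1.0, 1.0])[:, None]

cost_smooth = temporal_smoothness(smooth)
cost_jumpy = temporal_smoothness(jumpy)
```

Since the differences are squared, one large jump costs more than the same displacement spread over several frames, which is the behavior we want for translations and joint poses.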

## Putting it all together:

In [None]:
loss_weights = {
    'mask': 1.0,
    'scale': {
        'global': 0.001,
        'axes': 0.1,
        'smoothness': 5.0
    },
    'pose': 2.0,
    'smoothness': {
        'translations': 500.0,
        'global_poses': 0.0,
        'joint_poses': 500.0
    }
}

def compute_loss(render, silhouette, original, segmentation, hard_mask,
         scale_params, global_poses, joint_poses, translations):
    photometric = photometric_loss(render, original, hard_mask)
    mask = loss_weights['mask'] * iou_loss(silhouette, segmentation)
    pose = loss_weights['pose'] * pose_prior(joint_poses)

    scale = scale_prior(scale_params, 'y')
    scale_global = loss_weights['scale']['global'] * scale['global']
    scale_axes = loss_weights['scale']['axes'] * scale['axes']
    scale_smooth = loss_weights['scale']['smoothness'] * scale['smoothness']

    smoothness = smoothness_prior(global_poses, joint_poses, translations)
    smooth_translations = loss_weights['smoothness']['translations'] * smoothness['transl']
    smooth_gposes = loss_weights['smoothness']['global_poses'] * smoothness['global']
    smooth_jposes = loss_weights['smoothness']['joint_poses'] * smoothness['joint']

    return photometric + mask + pose + \
      scale_global + scale_axes + scale_smooth + \
      smooth_translations + smooth_gposes + smooth_jposes


# 5. Optimization loop

We now design a simple gradient-based optimization loop to fit the parameters to our data.

In [None]:
# @title [[Ignore]] Optimization helper functions

def dict_to_device(data, device):
    return {k: (data[k].to(device) if hasattr(data[k], 'to') else data[k]) for k in data.keys()}

def safe_clamp(param, min=None, max=None):
    param.data.clamp_(min=min, max=max)
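
Clamping a parameter right after each optimizer step, as `safe_clamp` does, is a form of projected gradient descent: take an unconstrained step, then project back onto the allowed range. The NumPy sketch below shows this with a plain gradient step instead of Adam, purely for illustration; `projected_step` and the numbers are hypothetical.

```python
import numpy as np

def projected_step(param, grad, lr, lo=None, hi=None):
    """One gradient-descent step followed by clamping to [lo, hi]."""
    updated = param - lr * grad
    if lo is not None:
        updated = np.maximum(updated, lo)
    if hi is not None:
        updated = np.minimum(updated, hi)
    return updated

# Three values that must stay in [0, 1], e.g. albedo-like parameters.
param = np.array([0.05, 0.5, 0.98])
grad = np.array([1.0, 0.0, -1.0])

new_param = projected_step(param, grad, lr=0.1, lo=0.0, hi=1.0)
```

Without the projection the first and last entries would leave the valid range (−0.05 and 1.08); with it they end up exactly on the bounds.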

## Exercise 7: Optimization loop

Complete the optimization loop below using the rendering and loss functions we designed in the previous sections.

In [None]:
# @title Your code here:

STEPS = 50
BATCH_SIZE = 7
LRS = {
    'transform': 0.01,
    'pose': 0.01,
    'texture': 0.001,
    'lighting': 0.01
}

import torch.optim as opt


optimizer = opt.Adam([
    {'params': [translations, part_scales], 'lr': LRS['transform']},
    {'params': [global_poses, joint_poses], 'lr': LRS['pose']},
    {'params': [dolphin_albedo, water_absorption], 'lr': LRS['texture']},
    {'params': [azimuth, elevation, strength, angle], 'lr': LRS['lighting']}
])

loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)

outer_loop = tqdm(range(STEPS))
for i in outer_loop:
    step_losses = []
    for frames in tqdm(loader, leave=False):
        frames = dict_to_device(frames, device)

        optimizer.zero_grad()

        ???

        loss.backward()
        optimizer.step()

        safe_clamp(water_absorption, 0, 1)
        safe_clamp(dolphin_albedo, 0, 1)
        safe_clamp(strength, min=0.0)
        safe_clamp(azimuth, min=0.0, max=2*torch.pi)
        safe_clamp(angle, min=0.0, max=2*torch.pi)
        safe_clamp(elevation, min=0.0, max=2*torch.pi)

        step_losses.append(loss.item())

    outer_loop.set_description(f'Loss: {np.array(step_losses).mean()}')

In [None]:
# Let's save the trained model:

torch.save(
    {
        'translations': translations,
        'part_scales': part_scales,
        'global_poses': global_poses,
        'joint_poses': joint_poses,
        'dolphin_albedo': dolphin_albedo,
        'water_absorption': water_absorption,
        'azimuth': azimuth,
        'elevation': elevation,
        'strength': strength,
        'angle': angle
    },
    f'{ROOT}/dolphin_model_fit.pth'
)

# 6. Visualizing results

After optimization, we will combine 3D and 2D visualizations to evaluate how well our reconstruction algorithm performed.

In [None]:
import plotly.figure_factory as ff

PICK_FRAME = 100  # Something in [0, 199]
data = dict_to_device(collate_batch([dataset[PICK_FRAME]]), device)

with torch.no_grad():
    out_render, _, silhouette = render(data['t'], data['cam'])
    original, segmentation = data['original'], data['segmented']

    global_batch = global_poses[data['t'], ...]
    joint_batch = joint_poses[data['t'], ...].view(-1, dolphin_model.LBS.n_groups, 3)[:, dolphin_model.LBS.seg_grouping, :][:, 1:, :]  # filter out spine joint
    trans_batch = translations[data['t'], ...]
    parts_batch = part_scales
    skinned_verts = dolphin_model(global_batch, joint_batch, trans_batch, part_scales=parts_batch)['vertices']
    dolphin_mesh = [skinned_verts[0, ...].cpu().numpy(), dolphin_model.F.cpu().numpy()]

dolphin_verts, dolphin_faces = dolphin_mesh
fig = ff.create_trisurf(x=dolphin_verts[:, 0], y=dolphin_verts[:, 2], z=dolphin_verts[:, 1],
                        simplices=dolphin_faces)
fig.update_scenes(xaxis_range=[-2, 2], yaxis_range=[-2, 2], zaxis_range=[-3, 1])
fig.show()

## Exercise 8: Output video

Let's combine the results with the input data to obtain an informative short video showing the output of our method.

Iterate over all frames in the dataset in order and query our optimized model. For each frame, stack the original frame, segmentation, render and soft mask in the following way:
```
original frame | SAM-2 segmentation
---------------|-------------------
render         | DIB-R soft mask
```
For the render, to avoid a black background, let's do some compositing on the original frame: simply copy the original frame, and for all pixels where the DIB-R soft mask is greater than 0.5, replace the color of the original frame with the rendered one at the same pixel.
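
The compositing and the 2x2 layout described above can be sketched in NumPy as follows. This assumes every image is already a `(H, W, 3)` array with values in [0, 1] and the soft mask is `(H, W, 1)`; `compose_frame` is a hypothetical helper, and converting the tensors from the notebook (batch dimension, device, channel order) is left out.

```python
import numpy as np

def compose_frame(original, segmentation, render, soft_mask, threshold=0.5):
    """Build the 2x2 grid from (H, W, 3) images and a (H, W, 1) soft mask."""
    # Paste rendered pixels over the original wherever the soft mask is high.
    composite = np.where(soft_mask > threshold, render, original)
    # Repeat the single-channel mask so it can sit next to RGB images.
    mask_rgb = np.repeat(soft_mask, 3, axis=-1)
    top = np.concatenate([original, segmentation], axis=1)
    bottom = np.concatenate([composite, mask_rgb], axis=1)
    return np.concatenate([top, bottom], axis=0)

# Tiny made-up inputs: black original, white segmentation, gray render,
# and a soft mask that is high only at the top-left pixel.
H, W = 2, 3
original = np.zeros((H, W, 3))
segmentation = np.ones((H, W, 3))
render = np.full((H, W, 3), 0.7)
soft_mask = np.zeros((H, W, 1))
soft_mask[0, 0] = 0.9

frame = compose_frame(original, segmentation, render, soft_mask)
```

The result has shape `(2 * H, 2 * W, 3)`, which matches the frame size the video writer expects.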

Once you're done, download the video and take a look!

In [None]:
import cv2

height, width = dataset.data_resolution[::-1]
video = cv2.VideoWriter(f'{ROOT}/dolphin_fit.avi', cv2.VideoWriter_fourcc(*'XVID'), 50.0, (width * 2, height * 2))

loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_batch)
for data in loader:

    ???

    # frame has to be a numpy array with values in [0, 1] and shape (2 * height, 2 * width, 3)
    video.write((frame * 255).astype(np.uint8))

cv2.destroyAllWindows()
video.release()


Whoops, something went wrong: this is a well-known problem when working with the axis-angle representation, often referred to as **gimbal lock**. Let's look for temporal discontinuities in the global orientation of the dolphin:

In [None]:
import matplotlib.pyplot as plt

# One curve per axis-angle component of the global orientation, in degrees
global_deg = np.rad2deg(global_poses[:, 0, :].cpu().detach().numpy())
for k, name in enumerate('xyz'):
    plt.plot(global_deg[:, k], label=f'{name} component (deg)')
plt.xlabel('frame')
plt.legend()
plt.title("Gimbal lock")
plt.show()
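
One property of axis-angle that can produce sudden jumps in curves like these is that the representation is not unique: two very different vectors can encode the same rotation. The NumPy sketch below checks this with a hand-written Rodrigues formula. It is an illustration with made-up angles, not a diagnosis of this particular fit.

```python
import numpy as np

def rodrigues(rotvec):
    """Rotation matrix for an axis-angle vector (unit axis times angle in radians)."""
    angle = np.linalg.norm(rotvec)
    if angle < 1e-12:
        return np.eye(3)
    k = rotvec / angle
    # Skew-symmetric cross-product matrix of the unit axis.
    K = np.array([[0.0, -k[2], k[1]],
                  [k[2], 0.0, -k[0]],
                  [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * (K @ K)

axis = np.array([0.0, 1.0, 0.0])
a = (np.pi - 0.01) * axis    # just under 180 degrees about +y
b = -(np.pi + 0.01) * axis   # just over 180 degrees about -y

same_rotation = np.allclose(rodrigues(a), rodrigues(b))
param_distance = np.linalg.norm(a - b)
```

Here `a` and `b` give the same rotation matrix even though they are about 2π apart as parameter vectors. If the optimizer lands on different representatives in neighboring frames, the pose curves jump while the orientation itself may change very little, and a squared-difference smoothness prior on the raw parameters will penalize such frames heavily.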