In [None]:
# Cell 1: Imports and Path Setup
import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import copy
from datetime import datetime
from tqdm import tqdm


sys.path.insert(0, "/home/azhuravl/work/TrajectoryCrafter")
# os.chdir(trajcrafter_path)

import inference_orbits

sys.path.insert(0, "/home/azhuravl/work/TrajectoryCrafter/notebooks/28_08_25_trajectories")

import core, trajectory_generation, viser_utils

In [None]:
# Cell 2: Argument Setup
# A notebook has no command line, so feed the parser an explicit argument list.
parser = inference_orbits.get_parser()
opts_base = parser.parse_args(
    [
        '--video_path', './test/videos/0-NNvgaTcVzAG0-r.mp4',  # Change this path
        '--radius', '1.0',
        '--device', 'cuda:0',
    ]
)

# Options shared by every run in this notebook; they are set directly on the
# parsed namespace instead of being passed through the parser.
opts_base.weight_dtype = torch.bfloat16
opts_base.camera = "target"
opts_base.mode = "gradual"
opts_base.mask = True
opts_base.target_pose = [0, 90, opts_base.radius, 0, 0]  # right_90 example
opts_base.exp_name = "test"  # plain literal: the f-string prefix had nothing to format

In [None]:
import torch.nn.functional as F

# Cell 3: Run Visualization
# Build the visualization wrapper from the options prepared in the previous cell.
print("Initializing TrajCrafter for visualization...")
vis_crafter = core.TrajCrafterVisualization(opts_base)

# Extract the scene data dictionary; later cells read its 'frames_tensor',
# 'depths', 'pose_source', 'pose_target' and 'intrinsics' entries.
print("Extracting scene data...")
scene_data = vis_crafter.extract_scene_data(opts_base)

In [None]:
from typing import Optional, Tuple


class WarperDebug(core.VisualizationWarper):
    """Debug variant of ``core.VisualizationWarper`` that exposes intermediate 3D points.

    ``compute_transformed_points`` is overridden to return the back-projected and
    transformed points in addition to the projected coordinates, so they can be
    inspected or visualized. ``forward_warp`` is overridden to consume that
    extended return value. ``bilinear_splatting``, ``clean_points``,
    ``create_grid``, ``device``, ``dtype`` and ``resolution`` are not defined here
    and are presumably provided by the base class.
    """

    def forward_warp(
        self,
        frame1: torch.Tensor,
        mask1: Optional[torch.Tensor],
        depth1: torch.Tensor,
        transformation1: torch.Tensor,
        transformation2: torch.Tensor,
        intrinsic1: torch.Tensor,
        intrinsic2: Optional[torch.Tensor],
        mask=False,
        twice=False,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Given a frame1 and global transformations transformation1 and transformation2, warps frame1 to next view using
        bilinear splatting.
        All arrays should be torch tensors with batch dimension and channel first
        :param frame1: (b, 3, h, w). If frame1 is not in the range [-1, 1], either set is_image=False when calling
                        bilinear_splatting on frame within this function, or modify clipping in bilinear_splatting()
                        method accordingly.
        :param mask1: (b, 1, h, w) - 1 for known, 0 for unknown. Optional
        :param depth1: (b, 1, h, w)
        :param transformation1: (b, 4, 4) extrinsic transformation matrix of first view: [R, t; 0, 1]
        :param transformation2: (b, 4, 4) extrinsic transformation matrix of second view: [R, t; 0, 1]
        :param intrinsic1: (b, 3, 3) camera intrinsic matrix
        :param intrinsic2: (b, 3, 3) camera intrinsic matrix. Optional; defaults to a copy of intrinsic1
        :param mask: when True (and twice is False) the splatted frame and mask are passed through clean_points
        :param twice: when True the frame is warped to the second view and then warped back with the negated flow
        :return: ``(warped_frame2, mask2, None, flow12)`` when ``twice`` is False, otherwise
                 ``(twice_warped_frame1, warped_frame2, None, None)``
        """
        if self.resolution is not None:
            assert frame1.shape[2:4] == self.resolution
        b, c, h, w = frame1.shape
        if mask1 is None:
            mask1 = torch.ones(size=(b, 1, h, w)).to(frame1)
        if intrinsic2 is None:
            intrinsic2 = intrinsic1.clone()

        assert frame1.shape == (b, 3, h, w)
        assert mask1.shape == (b, 1, h, w)
        assert depth1.shape == (b, 1, h, w)
        assert transformation1.shape == (b, 4, 4)
        assert transformation2.shape == (b, 4, 4)
        assert intrinsic1.shape == (b, 3, 3)
        assert intrinsic2.shape == (b, 3, 3)

        frame1 = frame1.to(self.device).to(self.dtype)
        mask1 = mask1.to(self.device).to(self.dtype)
        depth1 = depth1.to(self.device).to(self.dtype)
        transformation1 = transformation1.to(self.device).to(self.dtype)
        transformation2 = transformation2.to(self.device).to(self.dtype)
        intrinsic1 = intrinsic1.to(self.device).to(self.dtype)
        intrinsic2 = intrinsic2.to(self.device).to(self.dtype)

        # The override below returns a 3-tuple; only the projected homogeneous
        # coordinates are needed for warping. Indexing the tuple directly (as the
        # previous version did) raises a TypeError.
        trans_points1, _, _ = self.compute_transformed_points(
            depth1, transformation1, transformation2, intrinsic1, intrinsic2
        )
        # Perspective divide: (x, y) / z gives pixel coordinates in the second view.
        trans_coordinates = (
            trans_points1[:, :, :, :2, 0] / trans_points1[:, :, :, 2:3, 0]
        )
        trans_depth1 = trans_points1[:, :, :, 2, 0]
        grid = self.create_grid(b, h, w).to(trans_coordinates)
        flow12 = trans_coordinates.permute(0, 3, 1, 2) - grid

        # Both modes start with the same splat of frame1 into the second view.
        warped_frame2, mask2 = self.bilinear_splatting(
            frame1, mask1, trans_depth1, flow12, None, is_image=True
        )
        if not twice:
            if mask:
                warped_frame2, mask2 = self.clean_points(warped_frame2, mask2)
            return warped_frame2, mask2, None, flow12

        # Round trip: no clean_points pass here. Splat the flow itself into the
        # second view, then use its negation to warp the result back.
        warped_flow, _ = self.bilinear_splatting(
            flow12, mask1, trans_depth1, flow12, None, is_image=False
        )
        twice_warped_frame1, _ = self.bilinear_splatting(
            warped_frame2,
            mask2,
            depth1.squeeze(1),
            -warped_flow,
            None,
            is_image=True,
        )
        return twice_warped_frame1, warped_frame2, None, None

    def compute_transformed_points(
        self,
        depth1: torch.Tensor,
        transformation1: torch.Tensor,
        transformation2: torch.Tensor,
        intrinsic1: torch.Tensor,
        intrinsic2: Optional[torch.Tensor],
    ):
        """
        Computes transformed position for each pixel location.

        Unlike a single-tensor return, this debug version returns three tensors so
        the intermediate geometry can be inspected.

        :param depth1: (b, 1, h, w)
        :param transformation1: (b, 4, 4) extrinsic of the first view
        :param transformation2: (b, 4, 4) extrinsic of the second view
        :param intrinsic1: (b, 3, 3)
        :param intrinsic2: (b, 3, 3). Optional; defaults to a copy of intrinsic1
        :return: tuple ``(trans_norm_points, trans_world, world_points)``:
                 ``trans_norm_points`` (b, h, w, 3, 1) - transformed points multiplied by intrinsic2
                 (homogeneous pixel coordinates, before the perspective divide);
                 ``trans_world`` (b, h, w, 3, 1) - back-projected points after applying
                 ``transformation2 @ inv(transformation1)``;
                 ``world_points`` (b, h, w, 3, 1) - back-projected points ``depth * inv(intrinsic1) * pixel``
        """
        if self.resolution is not None:
            assert depth1.shape[2:4] == self.resolution
        b, _, h, w = depth1.shape
        if intrinsic2 is None:
            intrinsic2 = intrinsic1.clone()
        # Relative transform taking points from the first view into the second.
        transformation = torch.bmm(
            transformation2, torch.linalg.inv(transformation1)
        )  # (b, 4, 4)

        x1d = torch.arange(0, w)[None]
        y1d = torch.arange(0, h)[:, None]
        x2d = x1d.repeat([h, 1]).to(depth1)  # (h, w)
        y2d = y1d.repeat([1, w]).to(depth1)  # (h, w)
        ones_2d = torch.ones(size=(h, w)).to(depth1)  # (h, w)
        ones_4d = ones_2d[None, :, :, None, None].repeat(
            [b, 1, 1, 1, 1]
        )  # (b, h, w, 1, 1)
        pos_vectors_homo = torch.stack([x2d, y2d, ones_2d], dim=2)[
            None, :, :, :, None
        ]  # (1, h, w, 3, 1)

        intrinsic1_inv = torch.linalg.inv(intrinsic1)  # (b, 3, 3)
        intrinsic1_inv_4d = intrinsic1_inv[:, None, None]  # (b, 1, 1, 3, 3)
        intrinsic2_4d = intrinsic2[:, None, None]  # (b, 1, 1, 3, 3)
        depth_4d = depth1[:, 0][:, :, :, None, None]  # (b, h, w, 1, 1)
        trans_4d = transformation[:, None, None]  # (b, 1, 1, 4, 4)

        unnormalized_pos = torch.matmul(
            intrinsic1_inv_4d, pos_vectors_homo
        )  # (b, h, w, 3, 1)

        # Back-projected 3D points. No extrinsic has been applied at this stage,
        # so despite the name these are expressed relative to the first camera.
        world_points = depth_4d * unnormalized_pos  # (b, h, w, 3, 1)
        world_points_homo = torch.cat([world_points, ones_4d], dim=3)  # (b, h, w, 4, 1)
        trans_world_homo = torch.matmul(trans_4d, world_points_homo)  # (b, h, w, 4, 1)
        trans_world = trans_world_homo[:, :, :, :3]  # (b, h, w, 3, 1)
        trans_norm_points = torch.matmul(intrinsic2_4d, trans_world)  # (b, h, w, 3, 1)
        return trans_norm_points, trans_world, world_points

    def extract_3d_points_with_colors(
        self,
        frame1: torch.Tensor,
        depth1: torch.Tensor,
        transformation1: torch.Tensor,
        intrinsic1: torch.Tensor,
        subsample_step: int = 10
    ):
        """Extract 3D world points and their corresponding colors for visualization.

        :param frame1: (b, c, h, w) image tensor supplying the per-point colors
        :param depth1: (b, 1, h, w)
        :param transformation1: (b, 4, 4) matrix applied to the back-projected points
        :param intrinsic1: (b, 3, 3)
        :param subsample_step: stride, in pixels, used to subsample both image axes
        :return: tuple ``(points_3d, colors_rgb, trans_world, trans_norm_points)``:
                 ``points_3d`` (N, 3) and ``colors_rgb`` (N, c) keep only pixels with depth > 0;
                 ``trans_world`` (b, h_sub, w_sub, 3, 1) and ``trans_norm_points``
                 (b, h_sub, w_sub, 3, 1) are unfiltered debug tensors
        """
        b, c, h, w = frame1.shape

        # Move tensors to device
        frame1 = frame1.to(self.device).to(self.dtype)
        depth1 = depth1.to(self.device).to(self.dtype)
        transformation1 = transformation1.to(self.device).to(self.dtype)
        intrinsic1 = intrinsic1.to(self.device).to(self.dtype)

        # Subsampled pixel grid. Casting with .to(depth1) matches both device and
        # dtype, so the matmul below also works when self.dtype is not float32.
        x_coords = torch.arange(0, w, subsample_step)
        y_coords = torch.arange(0, h, subsample_step)
        x2d, y2d = torch.meshgrid(x_coords, y_coords, indexing='xy')  # each (h_sub, w_sub)
        x2d = x2d.to(depth1)
        y2d = y2d.to(depth1)
        ones_2d = torch.ones_like(x2d)

        # Stack into homogeneous coordinates
        pos_vectors_homo = torch.stack([x2d, y2d, ones_2d], dim=2)[None, :, :, :, None]

        # Subsample depth and colors with the same stride as the pixel grid
        depth_sub = depth1[:, 0, ::subsample_step, ::subsample_step]
        colors_sub = frame1[:, :, ::subsample_step, ::subsample_step]

        # Unproject to 3D camera coordinates
        intrinsic1_inv = torch.linalg.inv(intrinsic1)
        intrinsic1_inv_4d = intrinsic1_inv[:, None, None]
        depth_4d = depth_sub[:, :, :, None, None]

        unnormalized_pos = torch.matmul(intrinsic1_inv_4d, pos_vectors_homo)
        camera_points = depth_4d * unnormalized_pos

        # Transform to world coordinates
        ones_4d = torch.ones(b, camera_points.shape[1], camera_points.shape[2], 1, 1).to(depth1)
        world_points_homo = torch.cat([camera_points, ones_4d], dim=3)
        trans_4d = transformation1[:, None, None]
        world_points_homo = torch.matmul(trans_4d, world_points_homo)
        world_points = world_points_homo[:, :, :, :3, 0]  # (b, h_sub, w_sub, 3)

        # Debug outputs mirroring compute_transformed_points. This method receives
        # no second intrinsic, so intrinsic1 is used, matching the
        # "intrinsic2 defaults to intrinsic1" rule of the other methods. The name
        # intrinsic2_4d was previously undefined here and raised a NameError.
        # NOTE(review): world_points_homo has already been multiplied by trans_4d
        # above, so transformation1 is applied a second time here - confirm that
        # this is intended.
        intrinsic2_4d = intrinsic1[:, None, None]  # (b, 1, 1, 3, 3)
        trans_world_homo = torch.matmul(trans_4d, world_points_homo)  # (b, h_sub, w_sub, 4, 1)
        trans_world = trans_world_homo[:, :, :, :3]  # (b, h_sub, w_sub, 3, 1)
        trans_norm_points = torch.matmul(intrinsic2_4d, trans_world)  # (b, h_sub, w_sub, 3, 1)

        # Prepare colors
        colors = colors_sub.permute(0, 2, 3, 1)  # (b, h_sub, w_sub, c)

        # Filter valid points (positive depth)
        valid_mask = depth_sub > 0  # (b, h_sub, w_sub)

        # Flatten and filter
        points_3d = world_points[valid_mask]  # (N, 3)
        colors_rgb = colors[valid_mask]       # (N, c)

        return points_3d, colors_rgb, trans_world, trans_norm_points



In [None]:
# Show which entries the scene-data extraction produced.
scene_data.keys()

In [None]:
vis_warper = WarperDebug(device=opts_base.device)

num_frames = scene_data['frames_tensor'].shape[0]

# One list per output of compute_transformed_points, filled frame by frame.
all_trans_norm_points, all_trans_world_points, all_world_points = [], [], []

for i in tqdm(range(num_frames), desc="Processing frames"):
    # Single-frame batches (i:i+1 keeps the leading batch axis) on the warper's device.
    warp_inputs = {
        arg_name: scene_data[key][i:i+1].to(vis_warper.device)
        for arg_name, key in (
            ('depth1', 'depths'),
            ('transformation1', 'pose_source'),
            ('transformation2', 'pose_target'),
            ('intrinsic1', 'intrinsics'),
        )
    }
    trans_norm_points, trans_world, world_points = vis_warper.compute_transformed_points(
        **warp_inputs, intrinsic2=None
    )

    all_trans_norm_points.append(trans_norm_points)
    all_trans_world_points.append(trans_world)
    all_world_points.append(world_points)


In [None]:
# Shapes of the first frame's outputs, in the order: projected, transformed, back-projected.
tuple(
    per_frame[0].shape
    for per_frame in (all_trans_norm_points, all_trans_world_points, all_world_points)
)

In [None]:
# Create Viser Server (run once)
import viser

# Stop a server left over from an earlier run of this cell before starting a new
# one. Only ordinary errors are caught (a bare except would also swallow
# KeyboardInterrupt), and a failure is reported instead of silently ignored.
if globals().get('viser_server') is not None:
    print("Stopping existing server...")
    try:
        viser_server.stop()
    except Exception as exc:
        print(f"Could not stop the existing server cleanly: {exc!r}")
    finally:
        del viser_server

# Create new server
print("Creating new Viser server on port 8080...")
viser_server = viser.ViserServer(port=8080)
print("Server started successfully!")

In [None]:
# Clear nodes left over from a previous run of this cell. Each removal is tried
# on its own so one failure does not skip the remaining nodes, and failures are
# reported instead of being swallowed by a bare except.
for node_name in ("/points", "/camera_source", "/camera_target"):
    try:
        viser_server.scene.remove(node_name)
    except Exception as exc:
        print(f"Skipping removal of {node_name}: {exc!r}")


# Index of the frame to visualize, clamped to the last processed frame so that
# clips shorter than 26 frames do not raise an IndexError below.
i = min(25, len(all_world_points) - 1)


def extract_xyz_points(points_raw, subsample_factor=100):
    """Flatten a per-pixel point tensor into a subsampled (N, 3) numpy array.

    Args:
        points_raw: tensor of shape (1, h, w, 3, 1); only batch element 0 is used.
        subsample_factor: keep every ``subsample_factor``-th point of the
            row-major flattened grid. Defaults to 100, the value that was
            previously hard-coded.

    Returns:
        numpy array of shape (ceil(h * w / subsample_factor), 3).
    """
    # detach() lets tensors that require grad be converted to numpy as well.
    points_3d = points_raw[0, :, :, :, 0].detach().cpu().numpy()  # (h, w, 3)
    points_3d = points_3d.reshape(-1, 3)  # (N, 3)

    return points_3d[::subsample_factor]  # Subsample for visualization

# Extract points and poses for frame i
world_points_3d = extract_xyz_points(all_world_points[i])
trans_world_points_3d = extract_xyz_points(all_trans_world_points[i])

pose_source = scene_data['pose_source'][i:i+1].cpu().numpy()[0]  # (4, 4)
pose_target = scene_data['pose_target'][i:i+1].cpu().numpy()[0]  # (4, 4)

# viser expects the frustum field of view in radians, so the intended 60 degrees
# has to be converted; passing 60 directly is read as 60 radians.
frustum_fov = float(np.deg2rad(60))

# Add back-projected points
viser_server.scene.add_point_cloud(
    "/points",
    points=world_points_3d,
    colors=(0.0, 0.0, 0.0),  # Black points
    point_size=0.05
)

# Add transformed points
viser_server.scene.add_point_cloud(
    "/transformed_points",
    points=trans_world_points_3d,
    colors=(1.0, 0.0, 0.0),  # Red points
    point_size=0.05
)

# show world axes
viser_server.scene.add_frame("/world", axes_length=0.5, position=(0, 0, 0), wxyz=(1, 0, 0, 0))

# NOTE(review): the poses are drawn as stored here, while the "plot all target
# cameras" cell inverts pose_target before drawing - confirm which convention
# (world-to-camera or camera-to-world) the stored poses follow.

# Add source camera (green)
viser_server.scene.add_camera_frustum(
    "/camera_source",
    fov=frustum_fov, aspect=16/9, scale=0.2,
    position=pose_source[:3, 3],
    wxyz=viser.transforms.SO3.from_matrix(pose_source[:3, :3]).wxyz,
    color=(0.0, 1.0, 0.0)
)

# Add target camera (blue)
viser_server.scene.add_camera_frustum(
    "/camera_target",
    fov=frustum_fov, aspect=16/9, scale=0.2,
    position=pose_target[:3, 3],
    wxyz=viser.transforms.SO3.from_matrix(pose_target[:3, :3]).wxyz,
    color=(0.0, 0.0, 1.0)
)

print(f"Showing {len(world_points_3d)}")

In [None]:
# plot all target cameras

# Remove frustums created by a previous run of this cell so that re-running it
# starts clean. The earlier version removed each frustum in the same loop
# iteration that added it (under a bare except), so the cell either showed
# nothing or hid the failing removal.
for stale_handle in globals().get('target_camera_handles', []):
    try:
        stale_handle.remove()
    except Exception as exc:
        print(f"Could not remove a stale target camera: {exc!r}")
target_camera_handles = []

# viser expects the field of view in radians; 60 is meant as degrees.
target_frustum_fov = float(np.deg2rad(60))

for i in range(num_frames):
    pose_target = scene_data['pose_target'][i:i+1].cpu().numpy()[0]  # (4, 4)

    # Invert the stored pose before reading the camera position and orientation.
    # Presumably pose_target is world-to-camera, which makes the inverse the
    # camera's pose in the scene - TODO confirm.
    pose_target_inv = np.linalg.inv(pose_target)

    target_camera_handles.append(
        viser_server.scene.add_camera_frustum(
            f"/camera_target_{i}",
            fov=target_frustum_fov, aspect=16/9, scale=0.1,
            position=pose_target_inv[:3, 3],
            wxyz=viser.transforms.SO3.from_matrix(pose_target_inv[:3, :3]).wxyz,
            color=(0.0, 0.0, 1.0)
        )
    )

In [None]:
# points_3d is only assigned by the point-cloud cell further down, so on a fresh
# kernel (Restart & Run All) the bare name would raise a NameError here. Show it
# when it already exists and show nothing otherwise.
points_3d if 'points_3d' in globals() else None

In [None]:
# Create warper for 3D point extraction
print("Creating 3D point cloud from all frames...")
vis_warper = WarperDebug(device=opts_base.device)

# Per-frame results, concatenated after the loop
all_points_3d = []
all_colors_rgb = []
all_trans_world = []
all_trans_norm_points = []

warped_images = []
masks = []

num_frames = scene_data['frames_tensor'].shape[0]
for i in tqdm(range(num_frames), desc="Processing frames"):
    # Single-frame batches: i:i+1 keeps the leading batch dimension.
    frame_data = {
        'frame': scene_data['frames_tensor'][i:i+1],
        'depth': scene_data['depths'][i:i+1],
        'pose_source': scene_data['pose_source'][i:i+1],
        'intrinsics': scene_data['intrinsics'][i:i+1],
    }

    points_3d_frame, colors_rgb_frame, trans_world_frame, trans_norm_points_frame = vis_warper.extract_3d_points_with_colors(
        frame_data['frame'],
        frame_data['depth'],
        frame_data['pose_source'],
        frame_data['intrinsics'],
        subsample_step=20  # Increased for performance with multiple frames
    )

    if points_3d_frame.shape[0] > 0:  # Only add if we have valid points
        all_points_3d.append(points_3d_frame)
        all_colors_rgb.append(colors_rgb_frame)
        all_trans_world.append(trans_world_frame)
        all_trans_norm_points.append(trans_norm_points_frame)

    # Warp the frame from the source pose to the target pose. With twice=False the
    # third returned value is None and the fourth is the flow.
    warped_frame2, mask2, warped_depth2, flow12 = vis_warper.forward_warp(
        scene_data['frames_tensor'][i:i+1],
        None,
        scene_data['depths'][i:i+1],
        scene_data['pose_source'][i:i+1],
        scene_data['pose_target'][i:i+1],
        scene_data['intrinsics'][i:i+1],
        None,
        opts_base.mask,
        twice=False,
    )
    warped_images.append(warped_frame2)
    masks.append(mask2)

# Concatenate all points
if all_points_3d:
    points_3d = torch.cat(all_points_3d, dim=0)
    colors_rgb = torch.cat(all_colors_rgb, dim=0)
    trans_world = torch.cat(all_trans_world, dim=0)
    trans_norm_points = torch.cat(all_trans_norm_points, dim=0)
    print(f"Generated {points_3d.shape[0]} 3D points from {len(all_points_3d)} frames")
else:
    print("No valid 3D points extracted!")
    points_3d = None
    colors_rgb = None
    # Reset these too; otherwise they keep the last per-frame values assigned by
    # an earlier cell and no longer match points_3d / colors_rgb.
    trans_world = None
    trans_norm_points = None

print(f"Camera trajectory: {scene_data['pose_target'].shape[0]} poses")


# Map the warped frames from [-1, 1] to [0, 1], then resize frames and masks to
# the sample size. Masks use nearest-neighbour so they are not blended.
cond_video = (torch.cat(warped_images) + 1.0) / 2.0
cond_masks = torch.cat(masks)

cond_video = F.interpolate(
    cond_video, size=opts_base.sample_size, mode='bilinear', align_corners=False
)
cond_masks = F.interpolate(cond_masks, size=opts_base.sample_size, mode='nearest')