In [1]:
# Cell 1: Imports and Path Setup
import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import copy
from datetime import datetime
from tqdm import tqdm

# Location of the TrajectoryCrafter checkout. Set the TRAJCRAFTER_PATH
# environment variable to point at a different checkout; the default keeps
# the path this notebook was originally written against.
trajcrafter_path = os.environ.get(
    "TRAJCRAFTER_PATH", "/home/azhuravl/work/TrajectoryCrafter"
)
sys.path.insert(0, trajcrafter_path)

# Change working directory to TrajectoryCrafter so that the relative
# "checkpoints/..." paths configured later in the notebook resolve.
os.chdir(trajcrafter_path)

# Now import TrajectoryCrafter modules (only importable after the path setup above)
from demo import TrajCrafter
from models.utils import Warper, read_video_frames
from models.infer import DepthCrafterDemo
import inference_orbits

print("Imports successful!")
print(f"Working directory: {os.getcwd()}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [2]:
# Cell 2: Camera Setup for run_w_cam_poses workflow
import torch
import numpy as np
from datetime import datetime
import os
import cv2

# Create opts manually for notebook use (matching run_w_cam_poses structure)
class Opts:
    """Hand-built option namespace for notebook use (matching run_w_cam_poses structure).

    Args:
        video_path: Input video file. Defaults to the clip this notebook was
            written for, so ``Opts()`` behaves exactly as before. The
            experiment name is derived from this path.
    """

    DEFAULT_VIDEO_PATH = '/home/azhuravl/work/panoptic-toolbox/150821_dance4/vgaVideos/vga_05_08.mp4'

    def __init__(self, video_path=None):
        # Video settings
        self.video_path = self.DEFAULT_VIDEO_PATH if video_path is None else video_path
        self.video_length = 49
        self.fps = 10
        self.stride = 1
        self.max_res = 1024

        # Device
        self.device = 'cuda:0'
        self.weight_dtype = torch.bfloat16

        # Output: "<video file stem>_<YYYYmmdd_HHMM>_cam_pose_vis"
        timestamp = datetime.now().strftime("%Y%m%d_%H%M")
        video_basename = os.path.splitext(os.path.basename(self.video_path))[0]
        self.exp_name = f"{video_basename}_{timestamp}_cam_pose_vis"

        # Depth estimation settings (forwarded to the depth estimator's infer call)
        self.near = 0.0001
        self.far = 10000.0
        self.depth_inference_steps = 5
        self.depth_guidance_scale = 1.0
        self.window_size = 110
        self.overlap = 25

        # Misc (not read by the visualization code in this notebook)
        self.mask = False
        self.seed = 43

        # Checkpoint paths (matching run_w_cam_poses); relative to the working directory
        self.blip_path = "checkpoints/blip2-opt-2.7b"
        self.unet_path = "checkpoints/DepthCrafter"
        self.pre_train_path = "checkpoints/stable-video-diffusion-img2vid"
        self.cpu_offload = 'model'

# Shared option object used by every later cell.
opts_base = Opts()

# Define source and target cameras (matching your run_w_cam_poses example).
# Each dict carries intrinsics "K" (3x3), distortion coefficients "distCoef"
# (5 values), rotation "R" (3x3) and translation "t" (3x1).
# The translations are 1/1000 of the values kept in the trailing comments.
# NOTE(review): presumably a unit conversion of the calibration file values — confirm.
source_camera = {
    "name": "05_08",
    "type": "vga",
    "resolution": [640,480],
    "panel": 5,
    "node": 8,
    "K": [
        [748.194573,0.403304,388.156644],
        [0,747.455308,257.075025],
        [0,0,1]
    ],
    "distCoef": [-0.352118,0.186737,0,0,-0.119772],
    "R": [
        [-0.871497831,-0.004279560553,0.4903806847],
        [0.09322575792,0.980281239,0.1742344701],
        [-0.4814566321,0.1975610738,-0.8539140083]
    ],
    "t": [
        [-0.03934843898], #[-39.34843898],
        [0.09250008112], #[92.50008112],
        [0.3049007109] #[304.9007109]
    ]
}

# Same layout as source_camera; translation scaled the same way.
target_camera = {
    "name": "01_01",
    "type": "vga",
    "resolution": [640,480],
    "panel": 1,
    "node": 1,
    "K": [
        [748.561374,0.083459,378.041653],
        [0,748.351299,223.336713],
        [0,0,1]
    ],
    "distCoef": [-0.32211,0.02854,0,0,0.101902],
    "R": [
        [-0.9610410199,0.02955079861,-0.2748215937],
        [0.005847346208,0.9962196747,0.08667276504],
        [0.2763439281,0.08168910551,-0.957580766]
    ],
    "t": [
        [-0.04625903829], #[-46.25903829],
        [0.1435237551], #[143.5237551],
        [0.2871962273] #[287.1962273]
    ]
}

# Echo the configuration so the cell output records what was used.
print(f"Video: {opts_base.video_path}")
print(f"Source camera: {source_camera['name']} (panel {source_camera['panel']}, node {source_camera['node']})")
print(f"Target camera: {target_camera['name']} (panel {target_camera['panel']}, node {target_camera['node']})")
print(f"Device: {opts_base.device}")
print(f"Video length: {opts_base.video_length} frames")
print(f"Experiment name: {opts_base.exp_name}")

In [8]:
from typing import Optional
from run_w_cam_poses import CameraPoseTrajCrafter

# Cell 3: Visualization Classes
class VisualizationWarper(Warper):
    """Extended Warper class that lifts depth maps to coloured 3D points for visualization."""

    def extract_3d_points_with_colors(
        self,
        frame1: torch.Tensor,
        depth1: torch.Tensor,
        transformation_source: torch.Tensor,
        transformation_target: torch.Tensor,
        intrinsic_source: torch.Tensor,
        intrinsic_target: Optional[torch.Tensor] = None,
        subsample_step: int = 10
    ):
        """
        Extract 3D world points and their corresponding colors for visualization.

        Every ``subsample_step``-th pixel is unprojected with its depth through
        the inverse source intrinsics and then mapped by ``transformation_source``.

        Args:
            frame1: (b, c, h, w) input frame
            depth1: (b, 1, h, w) depth map
            transformation_source: (b, 4, 4) matrix applied to source-camera-space
                points to obtain world-space points
            transformation_target: (b, 4, 4) target camera matrix. Accepted so the
                signature mirrors the source/target workflow; not used here.
            intrinsic_source: (b, 3, 3) source camera intrinsic matrix
            intrinsic_target: (b, 3, 3) target camera intrinsic matrix (optional).
                Accepted for signature compatibility; not used here.
            subsample_step: int, subsampling step for performance

        Returns:
            Tuple ``(points_3d, colors_rgb)``: world-space points of shape (N, 3)
            and the matching per-point channel values of shape (N, c), where N is
            the number of subsampled pixels with depth > 0 across the whole batch.
            Colors are returned in the value range of ``frame1``; no rescaling
            is applied.

        Raises:
            ValueError: if ``subsample_step`` is smaller than 1.
        """
        if subsample_step < 1:
            raise ValueError(f"subsample_step must be >= 1, got {subsample_step}")

        b, _, h, w = frame1.shape

        # Move the tensors that are actually used to the warper's device/dtype
        frame1 = frame1.to(self.device).to(self.dtype)
        depth1 = depth1.to(self.device).to(self.dtype)
        transformation_source = transformation_source.to(self.device).to(self.dtype)
        intrinsic_source = intrinsic_source.to(self.device).to(self.dtype)

        # Create subsampled pixel coordinates; indexing='xy' yields (h_sub, w_sub) grids
        x_coords = torch.arange(0, w, subsample_step, dtype=torch.float32, device=depth1.device)
        y_coords = torch.arange(0, h, subsample_step, dtype=torch.float32, device=depth1.device)
        x2d, y2d = torch.meshgrid(x_coords, y_coords, indexing='xy')

        # Stack into homogeneous coordinates and cast to the working dtype so the
        # matmul below does not mix float32 with self.dtype
        pos_vectors_homo = torch.stack([x2d, y2d, torch.ones_like(x2d)], dim=2)
        pos_vectors_homo = pos_vectors_homo[None, :, :, :, None].to(self.dtype)  # (1, h_sub, w_sub, 3, 1)

        # Subsample depth and colors with the same stride as the pixel grid
        depth_sub = depth1[:, 0, ::subsample_step, ::subsample_step]  # (b, h_sub, w_sub)
        colors_sub = frame1[:, :, ::subsample_step, ::subsample_step]  # (b, c, h_sub, w_sub)

        # Invert the intrinsics in float32/float64, then cast back to the working dtype
        inv_dtype = self.dtype if self.dtype in (torch.float32, torch.float64) else torch.float32
        intrinsic_source_inv = torch.linalg.inv(intrinsic_source.to(inv_dtype)).to(self.dtype)  # (b, 3, 3)
        intrinsic_source_inv_4d = intrinsic_source_inv[:, None, None]  # (b, 1, 1, 3, 3)
        depth_4d = depth_sub[:, :, :, None, None]  # (b, h_sub, w_sub, 1, 1)

        # Get 3D points in source camera coordinate system
        unnormalized_pos = torch.matmul(intrinsic_source_inv_4d, pos_vectors_homo)  # (b, h_sub, w_sub, 3, 1)
        camera_points_source = depth_4d * unnormalized_pos  # (b, h_sub, w_sub, 3, 1)

        # Append the homogeneous 1 in the same dtype/device as the points
        ones_4d = torch.ones(
            b, camera_points_source.shape[1], camera_points_source.shape[2], 1, 1,
            device=camera_points_source.device, dtype=camera_points_source.dtype,
        )
        camera_points_source_homo = torch.cat([camera_points_source, ones_4d], dim=3)  # (b, h_sub, w_sub, 4, 1)
        transformation_source_4d = transformation_source[:, None, None]  # (b, 1, 1, 4, 4)

        # Transform from source camera space to world space
        world_points_homo = torch.matmul(transformation_source_4d, camera_points_source_homo)  # (b, h_sub, w_sub, 4, 1)
        world_points = world_points_homo[:, :, :, :3, 0]  # (b, h_sub, w_sub, 3)

        # Channels-last layout so the boolean mask below selects whole pixels
        colors = colors_sub.permute(0, 2, 3, 1)  # (b, h_sub, w_sub, c)

        # Filter valid points (positive depth)
        valid_mask = depth_sub > 0  # (b, h_sub, w_sub)

        # Boolean indexing flattens batch and spatial dimensions together
        points_3d = world_points[valid_mask]  # (N, 3)
        colors_rgb = colors[valid_mask]       # (N, c)

        return points_3d, colors_rgb


class TrajCrafterVisualization(CameraPoseTrajCrafter):
    """Lightweight TrajCrafter subclass for camera trajectory visualization"""

    def __init__(self, opts):
        """Build only the depth estimator.

        The parent constructor is intentionally not called, so nothing else the
        parent would set up is available on this instance.

        Args:
            opts: option object providing ``device``, ``unet_path``,
                ``pre_train_path`` and ``cpu_offload``.
        """
        # Only initialize what we need for pose generation and depth estimation
        self.device = opts.device
        self.depth_estimater = DepthCrafterDemo(
            unet_path=opts.unet_path,
            pre_train_path=opts.pre_train_path,
            cpu_offload=opts.cpu_offload,
            device=opts.device,
        )
        print("TrajCrafterVisualization initialized (diffusion pipeline skipped)")

    def extract_scene_data(self, opts, source_camera, target_camera):
        """Extract all data needed for 3D visualization following CameraPoseTrajCrafter workflow.

        Reads and pads the video, undistorts it with the source camera model,
        estimates depth, and repeats the (static) source/target poses and
        intrinsics once per frame.

        Args:
            opts: option object (video, depth and device settings).
            source_camera: calibration dict of the camera that recorded the video.
            target_camera: calibration dict of the camera to compare against.

        Returns:
            dict with keys ``frames_numpy``, ``frames_tensor``, ``depths``,
            ``pose_source``, ``pose_target``, ``intrinsics_source``,
            ``intrinsics_target``, ``source_camera_undistorted``,
            ``target_camera`` and ``video_length``.

        Raises:
            ValueError: if no frame could be read from ``opts.video_path``.
        """
        # NOTE(review): read_video_frames, undistort_frames and convert_camera_format
        # are not defined in this class; presumably inherited from CameraPoseTrajCrafter.
        print("Reading video frames...")
        frames = self.read_video_frames(
            opts.video_path, opts.video_length, opts.stride, opts.max_res
        )

        num_read = frames.shape[0]
        if num_read == 0:
            # Without a last frame there is nothing to pad with
            raise ValueError(f"No frames could be read from {opts.video_path}")

        # Pad by repeating the last frame if the video is shorter than requested
        if num_read < opts.video_length:
            num_pad = opts.video_length - num_read
            pad_frames = np.repeat(frames[-1:], num_pad, axis=0)
            frames = np.concatenate([frames, pad_frames], axis=0)
            # Report the count captured before padding, not the padded length
            print(f"Padding video from {num_read} to {opts.video_length} frames")

        # Undistort the frames using source camera distortion coefficients
        print("Undistorting frames using source camera distortion coefficients...")
        frames, undistorted_K = self.undistort_frames(frames, source_camera)

        # Update source camera with undistorted intrinsics. A shallow copy is
        # enough because the two entries are replaced, not mutated in place.
        source_camera_undistorted = source_camera.copy()
        source_camera_undistorted["K"] = undistorted_K.tolist()
        source_camera_undistorted["distCoef"] = [0.0, 0.0, 0.0, 0.0, 0.0]

        print("Estimating depth...")
        depths = self.depth_estimater.infer(
            frames,
            opts.near,
            opts.far,
            opts.depth_inference_steps,
            opts.depth_guidance_scale,
            window_size=opts.window_size,
            overlap=opts.overlap,
        ).to(opts.device)

        print("Converting frames to tensors...")
        # Channels-first layout, values mapped with x * 2 - 1
        frames_tensor = (
            torch.from_numpy(frames).permute(0, 3, 1, 2).to(opts.device) * 2.0 - 1.0
        )

        print("Converting camera poses...")
        # Use undistorted source camera and target camera
        source_c2w, source_K = self.convert_camera_format(source_camera_undistorted)
        target_c2w, target_K = self.convert_camera_format(target_camera)

        # Both cameras are static: repeat the single pose once per frame
        pose_s = source_c2w.to(opts.device).unsqueeze(0).repeat(opts.video_length, 1, 1)
        pose_t = target_c2w.to(opts.device).unsqueeze(0).repeat(opts.video_length, 1, 1)

        # Keep source and target intrinsics separately, one copy per frame
        K_matrices_s = source_K.to(opts.device).unsqueeze(0).repeat(opts.video_length, 1, 1)
        K_matrices_t = target_K.to(opts.device).unsqueeze(0).repeat(opts.video_length, 1, 1)
        print("Prepared per-frame source and target camera intrinsics.")

        return {
            'frames_numpy': frames,
            'frames_tensor': frames_tensor,
            'depths': depths,
            'pose_source': pose_s,
            'pose_target': pose_t,
            'intrinsics_source': K_matrices_s,
            'intrinsics_target': K_matrices_t,
            'source_camera_undistorted': source_camera_undistorted,
            'target_camera': target_camera,
            'video_length': opts.video_length
        }

In [9]:
# Cell 4: Run Visualization
# Build the lightweight crafter (depth estimator only).
print("Initializing TrajCrafter for visualization...")
vis_crafter = TrajCrafterVisualization(opts_base)

# Frames, depths, poses and intrinsics for the configured camera pair.
print("Extracting scene data...")
scene_data = vis_crafter.extract_scene_data(opts_base, source_camera, target_camera)

# Warper used to back-project every frame into world space.
print("Creating 3D point cloud from all frames...")
vis_warper = VisualizationWarper(device=opts_base.device)

# Per-frame results; frames without any valid point are left out.
all_points_3d = []
all_colors_rgb = []

num_frames = scene_data['frames_tensor'].shape[0]
for i in tqdm(range(num_frames), desc="Processing frames"):
    # One-frame slices keep the batch dimension the warper expects.
    frame_data = {
        'frame': scene_data['frames_tensor'][i:i+1],
        'depth': scene_data['depths'][i:i+1],
    }
    frame_data.update({
        key: scene_data[key][i:i+1]
        for key in ('pose_source', 'pose_target', 'intrinsics_source', 'intrinsics_target')
    })

    # Coarser subsampling than the warper default, since all frames are accumulated.
    points_3d_frame, colors_rgb_frame = vis_warper.extract_3d_points_with_colors(
        frame1=frame_data['frame'],
        depth1=frame_data['depth'],
        transformation_source=frame_data['pose_source'],
        transformation_target=frame_data['pose_target'],
        intrinsic_source=frame_data['intrinsics_source'],
        intrinsic_target=frame_data['intrinsics_target'],
        subsample_step=20,
    )

    if points_3d_frame.shape[0] == 0:
        continue
    all_points_3d.append(points_3d_frame)
    all_colors_rgb.append(colors_rgb_frame)

# Merge the per-frame clouds into one tensor pair (or None when nothing was valid).
if not all_points_3d:
    print("No valid 3D points extracted!")
    points_3d = None
    colors_rgb = None
else:
    points_3d = torch.cat(all_points_3d, dim=0)
    colors_rgb = torch.cat(all_colors_rgb, dim=0)
    print(f"Generated {points_3d.shape[0]} 3D points from {len(all_points_3d)} frames")

print(f"Source camera trajectory: {scene_data['pose_source'].shape[0]} poses")
print(f"Target camera trajectory: {scene_data['pose_target'].shape[0]} poses")


In [13]:
# Number of frames that contributed at least one valid 3D point
# (frames with no valid point are never appended to the list).
len(all_points_3d)

In [None]:
import viser
import threading
import time

print("Creating animated viser visualization...")

# Cell 1: Create Viser Server (run once)
# If a server from a previous run of this cell is still around, stop it first.
# Failures are reported instead of being silently swallowed, and they do not
# prevent the new server from being created.
previous_server = globals().get('viser_server')
if previous_server is not None:
    print("Stopping existing server...")
    try:
        previous_server.stop()
    except Exception as exc:
        print(f"Could not stop existing Viser server cleanly: {exc}")
del previous_server  # drop the extra reference to the old server

# Create new server
print("Creating new Viser server on port 8080...")
viser_server = viser.ViserServer(port=8080)
print("Server started successfully!")

# Cell 2: Animated Viser Content for run_w_cam_poses
def setup_viser_scene(server, scene_data):
    """Setup static scene elements (trajectories and camera poses).

    For the source (green) and target (red) camera this adds a spline through
    all camera positions, a frustum for every 5th pose, and start/end spheres.
    A world coordinate frame is added at the origin.

    Args:
        server: viser server whose ``scene`` receives the nodes.
        scene_data: dict with ``pose_source`` and ``pose_target`` tensors of
            shape (num_frames, 4, 4).
    """
    # add_camera_frustum expects the field of view in radians, not degrees.
    frustum_fov = float(np.deg2rad(60.0))

    # (name prefix, scene_data key, path/start colour, frustum colour, end colour)
    tracks = (
        ("source", 'pose_source', (0.0, 1.0, 0.0), (0.2, 0.8, 0.2), (0.0, 0.5, 0.0)),
        ("target", 'pose_target', (1.0, 0.0, 0.0), (0.8, 0.2, 0.2), (0.5, 0.0, 0.0)),
    )

    for prefix, pose_key, path_color, frustum_color, end_color in tracks:
        poses_np = scene_data[pose_key].cpu().numpy()
        positions = poses_np[:, :3, 3]

        # Trajectory through all camera positions (static)
        server.scene.add_spline_catmull_rom(
            f"/{prefix}_trajectory",
            positions=positions,
            color=path_color,
            line_width=3.0
        )

        # Camera poses (static, every 5th to reduce clutter)
        for i, pose in enumerate(poses_np[::5]):
            # Convert rotation to quaternion
            wxyz = viser.transforms.SO3.from_matrix(pose[:3, :3]).wxyz
            server.scene.add_camera_frustum(
                f"/{prefix}_camera_{i}",
                fov=frustum_fov, aspect=4/3, scale=0.08,  # Smaller scale
                position=pose[:3, 3], wxyz=wxyz,
                color=frustum_color
            )

        # Start/end markers
        server.scene.add_icosphere(f"/{prefix}_start", radius=0.05, position=positions[0], color=path_color)
        server.scene.add_icosphere(f"/{prefix}_end", radius=0.05, position=positions[-1], color=end_color)

    server.scene.add_frame("/world", axes_length=0.2, position=(0, 0, 0), wxyz=(1, 0, 0, 0))

def animate_frame(server, scene_data, frame_idx, max_points=8000):
    """Update point cloud and current camera positions for given frame.

    Args:
        server: viser server whose ``scene`` is updated.
        scene_data: dict with ``frames_tensor``, ``depths``, ``pose_source``,
            ``pose_target``, ``intrinsics_source`` and ``intrinsics_target``,
            each indexed by frame along the first dimension.
        frame_idx: index of the frame to show.
        max_points: upper bound on the number of points sent to the viewer;
            larger clouds are randomly subsampled.
    """
    # Clear the previous frame's dynamic nodes one at a time, so that a failure
    # on one name neither skips the others nor goes unnoticed.
    for node_name in ("/current_frame", "/current_source_camera", "/current_target_camera"):
        try:
            server.scene.remove_by_name(node_name)
        except Exception as exc:
            print(f"Could not remove {node_name}: {exc}")

    # One-frame slices keep the batch dimension the warper expects
    window = slice(frame_idx, frame_idx + 1)
    points_3d, colors_rgb = vis_warper.extract_3d_points_with_colors(
        scene_data['frames_tensor'][window],
        scene_data['depths'][window],
        scene_data['pose_source'][window],        # Source camera transformation
        scene_data['pose_target'][window],        # Target camera transformation
        scene_data['intrinsics_source'][window],  # Source camera intrinsics
        scene_data['intrinsics_target'][window],  # Target camera intrinsics
        subsample_step=15
    )

    if points_3d.shape[0] > 0:
        points_np = points_3d.cpu().numpy()
        colors_np = colors_rgb.cpu().numpy()

        # Limit points for performance
        if len(points_np) > max_points:
            indices = np.random.choice(len(points_np), max_points, replace=False)
            points_np = points_np[indices]
            colors_np = colors_np[indices]

        # Heuristic: negative values mean the colors are in [-1, 1]; map to [0, 1]
        if colors_np.min() < 0:
            colors_np = (colors_np + 1) / 2

        # Update point cloud
        server.scene.add_point_cloud(
            "/current_frame",
            points=points_np,
            colors=colors_np,
            point_size=0.08
        )

    # add_camera_frustum expects the field of view in radians, not degrees.
    frustum_fov = float(np.deg2rad(60.0))

    # Highlight the current cameras whether or not any point was valid, so the
    # viewer never ends up without them after the removal above.
    highlights = (
        ("/current_source_camera", 'pose_source', (0.0, 1.0, 0.0)),  # Bright green
        ("/current_target_camera", 'pose_target', (1.0, 0.0, 0.0)),  # Bright red
    )
    for node_name, pose_key, color in highlights:
        pose = scene_data[pose_key][frame_idx].cpu().numpy()
        wxyz = viser.transforms.SO3.from_matrix(pose[:3, :3]).wxyz
        server.scene.add_camera_frustum(
            node_name,
            fov=frustum_fov, aspect=4/3, scale=0.15,
            position=pose[:3, 3], wxyz=wxyz,
            color=color
        )

# Setup scene
setup_viser_scene(viser_server, scene_data)

# Create GUI controls for animation
@viser_server.on_client_connect
def _(client: viser.ClientHandle) -> None:
    """Build one client's animation panel and start its playback thread.

    Runs once per connecting client. The playback thread is a daemon thread
    whose loop never exits, and it redraws the shared server scene.
    """
    gui = client.gui

    def total_frames():
        # Looked up on every use, so a re-assigned scene_data is picked up.
        return scene_data['frames_tensor'].shape[0]

    def frame_label(index):
        return f"Frame {index}/{total_frames()-1}"

    # Animation controls
    with gui.add_folder("Animation Controls"):
        play_button = gui.add_button("▶️ Play/Pause")
        frame_slider = gui.add_slider(
            "Frame", min=0, max=total_frames()-1, step=1, initial_value=0
        )
        speed_slider = gui.add_slider(
            "Speed", min=0.5, max=10.0, step=0.5, initial_value=2.0
        )
        max_points_slider = gui.add_slider(
            "Max Points", min=1000, max=20000, step=1000, initial_value=8000
        )

    # Read-only information fields
    with gui.add_folder("Camera Info"):
        gui.add_text(
            "Source Camera",
            initial_value=scene_data.get('source_camera_undistorted', {}).get('name', 'Unknown'),
        )
        gui.add_text(
            "Target Camera",
            initial_value=scene_data.get('target_camera', {}).get('name', 'Unknown'),
        )
        frame_info = gui.add_text("Current Frame", initial_value="0")

    # Animation state, shared between the button callback and the thread
    playback = {'running': False}

    @play_button.on_click
    def toggle_playback(_event):
        playback['running'] = not playback['running']
        if playback['running']:
            play_button.name = "⏸️ Pause"
        else:
            play_button.name = "▶️ Play"

    @frame_slider.on_update
    def show_selected_frame(_event):
        animate_frame(viser_server, scene_data, frame_slider.value, max_points_slider.value)
        frame_info.value = frame_label(frame_slider.value)

    # Animation loop: while playing, advance one frame per tick and wrap around
    def playback_loop():
        while True:
            if playback['running']:
                upcoming = (frame_slider.value + 1) % total_frames()
                frame_slider.value = upcoming
                animate_frame(viser_server, scene_data, upcoming, max_points_slider.value)
                frame_info.value = frame_label(upcoming)
            # Tick length is the inverse of the speed slider value
            time.sleep(1.0 / speed_slider.value)

    # Start animation thread
    threading.Thread(target=playback_loop, daemon=True).start()

# Show first frame
animate_frame(viser_server, scene_data, 0)

for message in (
    f"Animated Viser server running at: http://localhost:8080",
    "Controls:",
    "- Green trajectory/cameras = Source camera path",
    "- Red trajectory/cameras = Target camera path",
    "- Bright green/red frustums = Current frame cameras",
    "- Point cloud updates show 3D scene from current frame",
    "Use the GUI controls to play/pause animation and adjust settings",
):
    print(message)

In [None]:
# Cell 2: Animated Viser Content
def setup_viser_scene(server, scene_data):
    """Setup static scene elements (trajectory and camera poses).

    Single-trajectory variant: only ``scene_data['pose_target']`` is drawn.
    Running this cell rebinds the name ``setup_viser_scene``, replacing the
    two-camera version defined earlier in the notebook.

    Args:
        server: viser server whose ``scene`` receives the nodes.
        scene_data: dict with a ``pose_target`` tensor of shape (num_frames, 4, 4).
    """
    poses_np = scene_data['pose_target'].cpu().numpy()
    positions = poses_np[:, :3, 3]

    # add_camera_frustum expects the field of view in radians, not degrees.
    frustum_fov = float(np.deg2rad(60.0))

    # Add trajectory (static)
    server.scene.add_spline_catmull_rom(
        "/trajectory",
        positions=positions,
        color=(1.0, 0.0, 0.0),
        line_width=3.0
    )

    # Add camera poses (static); every 2nd pose to reduce clutter
    for i, pose in enumerate(poses_np[::2]):
        # The rotation is used as stored, without any axis correction
        wxyz = viser.transforms.SO3.from_matrix(pose[:3, :3]).wxyz

        server.scene.add_camera_frustum(
            f"/camera_{i}",
            fov=frustum_fov, aspect=16/9, scale=0.15,
            position=pose[:3, 3], wxyz=wxyz,
            color=(0.8, 0.2, 0.2)
        )

    # Add start/end markers
    server.scene.add_icosphere("/start", radius=0.1, position=positions[0], color=(0.0, 1.0, 0.0))
    server.scene.add_icosphere("/end", radius=0.1, position=positions[-1], color=(1.0, 0.0, 1.0))
    server.scene.add_frame("/world", axes_length=0.5, position=(0, 0, 0), wxyz=(1, 0, 0, 0))

def animate_frame(server, scene_data, frame_idx, max_points=5000):
    """Update the point cloud and the current-camera marker for given frame.

    Single-trajectory variant. Running this cell rebinds the name
    ``animate_frame``, replacing the two-camera version defined earlier in
    the notebook.

    Args:
        server: viser server whose ``scene`` is updated.
        scene_data: dict with ``frames_tensor``, ``depths``, ``pose_source``
            and ``pose_target``, plus either ``intrinsics_source`` (optionally
            with ``intrinsics_target``) or a single ``intrinsics`` entry.
        frame_idx: index of the frame to show.
        max_points: upper bound on the number of points sent to the viewer.
    """
    # Clear the previous frame's nodes one at a time, so that a failure on one
    # name neither skips the other nor goes unnoticed.
    for node_name in ("/current_frame", "/current_camera"):
        try:
            server.scene.remove_by_name(node_name)
        except Exception as exc:
            print(f"Could not remove {node_name}: {exc}")

    window = slice(frame_idx, frame_idx + 1)

    # Accept both scene_data layouts: per-camera intrinsics, or one shared
    # 'intrinsics' entry.
    if 'intrinsics_source' in scene_data:
        intrinsics_source = scene_data['intrinsics_source'][window]
    else:
        intrinsics_source = scene_data['intrinsics'][window]
    if 'intrinsics_target' in scene_data:
        intrinsics_target = scene_data['intrinsics_target'][window]
    else:
        intrinsics_target = None

    # Extract points for this frame. The warper takes source and target
    # transformations before the intrinsics, so all of them are passed.
    points_3d, colors_rgb = vis_warper.extract_3d_points_with_colors(
        scene_data['frames_tensor'][window],
        scene_data['depths'][window],
        scene_data['pose_source'][window],
        scene_data['pose_target'][window],
        intrinsics_source,
        intrinsics_target,
        subsample_step=15
    )

    if points_3d.shape[0] > 0:
        points_np = points_3d.cpu().numpy()
        colors_np = colors_rgb.cpu().numpy()

        # Limit points
        if len(points_np) > max_points:
            indices = np.random.choice(len(points_np), max_points, replace=False)
            points_np = points_np[indices]
            colors_np = colors_np[indices]

        # Heuristic: negative values mean the colors are in [-1, 1]; map to [0, 1]
        if colors_np.min() < 0:
            colors_np = (colors_np + 1) / 2

        # Update point cloud
        server.scene.add_point_cloud(
            "/current_frame",
            points=points_np,
            colors=colors_np,
            point_size=0.03
        )

    # Highlight the current camera whether or not any point was valid
    pos = scene_data['pose_target'][frame_idx, :3, 3].cpu().numpy()
    server.scene.add_icosphere(
        "/current_camera",
        radius=0.08,
        position=pos,
        color=(1.0, 1.0, 0.0)  # Yellow
    )

# Setup scene
setup_viser_scene(viser_server, scene_data)

# Create GUI controls for animation
@viser_server.on_client_connect
def _(client: viser.ClientHandle) -> None:
    """Build one client's play/frame/speed controls and start its playback thread.

    Runs once per connecting client. The playback thread is a daemon thread
    whose loop never exits.
    """
    # Imported locally, as in the original cell
    import threading
    import time

    gui = client.gui

    def total_frames():
        # Looked up on every use, so a re-assigned scene_data is picked up.
        return scene_data['frames_tensor'].shape[0]

    # Animation controls
    play_button = gui.add_button("Play/Pause")
    frame_slider = gui.add_slider(
        "Frame", min=0, max=total_frames()-1, step=1, initial_value=0
    )
    speed_slider = gui.add_slider(
        "Speed", min=1, max=10.0, step=0.1, initial_value=3.0
    )

    # Animation state, shared between the button callback and the thread
    playback = {'running': False}

    @play_button.on_click
    def toggle_playback(_event):
        playback['running'] = not playback['running']

    @frame_slider.on_update
    def show_selected_frame(_event):
        animate_frame(viser_server, scene_data, frame_slider.value)

    # Animation loop: while playing, advance one frame per tick and wrap around
    def playback_loop():
        while True:
            if playback['running']:
                upcoming = (frame_slider.value + 1) % total_frames()
                frame_slider.value = upcoming
                animate_frame(viser_server, scene_data, upcoming)
            time.sleep(0.5 / speed_slider.value)

    # Start animation thread
    threading.Thread(target=playback_loop, daemon=True).start()

# Show first frame
animate_frame(viser_server, scene_data, 0)

In [14]:
##########################################
# Viser GUI Controls
##########################################

import math
import numpy as np
import viser


# Simple Point Size Control
@viser_server.on_client_connect
def _(client: viser.ClientHandle) -> None:
    """Aim the new client's camera and give it a point-size slider.

    Moving the slider re-adds the accumulated cloud as "/scene_points" with
    the chosen point size.
    """
    update_camera_position()

    size_slider = client.gui.add_slider(
        "Point Size",
        min=0.005,
        max=0.1,
        step=0.005,
        initial_value=0.015,
    )

    @size_slider.on_update
    def redraw_scene_points(_event) -> None:
        # Nothing to draw when no valid points were extracted
        if points_3d is None:
            return

        cloud_xyz = points_3d.cpu().numpy()
        cloud_rgb = colors_rgb.cpu().numpy()
        # Negative values are taken to mean a [-1, 1] range; map to [0, 1]
        if cloud_rgb.min() < 0:
            cloud_rgb = (cloud_rgb + 1) / 2

        viser_server.scene.add_point_cloud(
            "/scene_points",
            points=cloud_xyz,
            colors=cloud_rgb,
            point_size=size_slider.value
        )

# Set initial camera position
initial_theta, initial_phi, initial_roll, initial_radius = 0, 75, -90, 10

# Server-level sliders for camera control, created in the order
# theta, phi, roll, distance.
theta_slider, phi_slider, roll_slider, radius_slider = (
    viser_server.gui.add_slider(label, min=lowest, max=highest, step=increment, initial_value=start)
    for label, lowest, highest, increment, start in (
        ("Camera Theta (deg)", 0, 360, 1, initial_theta),
        ("Camera Phi (deg)", -90, 270, 1, initial_phi),
        ("Camera Roll (deg)", -180, 180, 1, initial_roll),
        ("Camera Distance", 1, 20, 0.1, initial_radius),
    )
)

def update_camera_position():
    """Point every connected client's camera at the origin from the slider-defined orbit.

    Theta/phi/distance are read as spherical coordinates of the camera
    position; roll rotates the up vector about the viewing direction.
    """
    theta = math.radians(theta_slider.value)
    phi = math.radians(phi_slider.value)
    roll = math.radians(roll_slider.value)
    r = radius_slider.value

    # Spherical -> cartesian
    horizontal = r * math.cos(phi)
    position = np.array([
        horizontal * math.cos(theta),
        horizontal * math.sin(theta),
        r * math.sin(phi),
    ])
    look_at = np.array([0, 0, 0])

    # Unit viewing direction from the camera towards the target
    forward = look_at - position
    forward = forward / np.linalg.norm(forward)

    # World up is +z while |phi| < 90 degrees and -z otherwise, to prevent flipping
    world_up = np.array([0, 0, 1 if abs(phi) < math.pi/2 else -1])

    # Right vector; when forward is parallel to world up, fall back to a fixed x-axis
    right = np.cross(forward, world_up)
    right_length = np.linalg.norm(right)
    if right_length < 1e-6:
        right = np.array([1, 0, 0]) if abs(theta) < math.pi else np.array([-1, 0, 0])
    else:
        right = right / right_length

    # Up vector before roll, perpendicular to both right and forward
    up_initial = np.cross(right, forward)
    up_initial = up_initial / np.linalg.norm(up_initial)

    # Apply roll rotation around the forward axis
    up = np.cos(roll) * up_initial + np.sin(roll) * right

    # Push the pose to every connected client
    for client in viser_server.get_clients().values():
        camera = client.camera
        camera.position = position
        camera.look_at = look_at
        camera.up_direction = up
        

def _(_):
    """Re-aim all connected clients' cameras after any orbit slider changes."""
    update_camera_position()

# Register the same callback on each slider, in the original order.
theta_slider.on_update(_)
phi_slider.on_update(_)
radius_slider.on_update(_)
roll_slider.on_update(_)

# Apply initial camera position
# update_camera_position()

In [None]:
# Cell 4: To test different trajectories (example)
# Copy the base options, change the target pose, re-extract the scene data and
# push the result to the already-running viser server.
# target_pose layout, as used by the trajectory GUI in this notebook:
# [theta (pitch), phi (yaw), distance, x offset, y offset].
opts_test = copy.deepcopy(opts_base)
opts_test.target_pose = [0, 90, 1, 0, 0]  # yaw = 90 (same values as the "Original Right 90°" preset)

# Generate new scene data
# NOTE(review): extract_scene_data is defined outside this part of the notebook;
# presumably it recomputes depths as well as poses — confirm before running repeatedly.
scene_data_test = vis_crafter.extract_scene_data(opts_test)

# Update the same server with new content
# points_3d and colors_rgb are not defined in this cell; they must already exist
# in the kernel from an earlier cell.
update_viser_content(viser_server, scene_data_test, points_3d, colors_rgb)

In [None]:
#########################################################################
# Camera trajectory visualization
##########################################################################

In [None]:
# Extract the scene data for the base options once. The result is stored in the
# notebook-global `scene_data`, which the trajectory helpers and GUI callbacks in
# the following cells read (keys used there: 'frames_numpy', 'frames_tensor',
# 'depths', 'radius').
scene_data = vis_crafter.extract_scene_data(opts_base)

# Generate CIRCULAR trajectory instead of linear
def generate_circular_scene_data(crafter, opts, scene_data, circle_type='horizontal'):
    """Build a scene-data dict whose camera poses follow a circular path.

    Frames, depths and radius are taken unchanged from ``scene_data``; only the
    source/target poses and intrinsics are recomputed through
    ``crafter.get_poses_circular``.

    Args:
        crafter: Object exposing ``get_poses_circular(opts, depths, num_frames=..., circle_type=...)``.
        opts: Options object; ``opts.video_length`` is used as the number of frames.
        scene_data: Existing scene data with the keys 'frames_numpy',
            'frames_tensor', 'depths' and 'radius'.
        circle_type: Circle variant forwarded to ``get_poses_circular``.

    Returns:
        A new dict with the reused entries plus 'pose_source', 'pose_target',
        'intrinsics' and a 'trajectory_params' label of the form
        ``"circular_<circle_type>"``.
    """
    source_poses, target_poses, intrinsics = crafter.get_poses_circular(
        opts,
        scene_data['depths'],
        num_frames=opts.video_length,
        circle_type=circle_type,
    )

    # Start from the entries that are shared with the input scene ...
    circular_scene = {
        key: scene_data[key]
        for key in ('frames_numpy', 'frames_tensor', 'depths')
    }
    # ... then attach the freshly computed camera data.
    circular_scene['pose_source'] = source_poses
    circular_scene['pose_target'] = target_poses
    circular_scene['intrinsics'] = intrinsics
    circular_scene['radius'] = scene_data['radius']
    circular_scene['trajectory_params'] = f"circular_{circle_type}"
    return circular_scene

In [None]:
# Generate a circular motion (only the 'vertical_xz' variant is active; the
# 'horizontal' variant is left commented out).
# horizontal_circle = generate_circular_scene_data(vis_crafter, opts_base, scene_data, 'horizontal')
vertical_circle = generate_circular_scene_data(vis_crafter, opts_base, scene_data, 'vertical_xz')

# Use in your existing Viser visualization
# NOTE(review): update_trajectory_visualization is defined in a later cell of this
# notebook (the "Multiple Trajectories" cell). Unless it is also defined earlier,
# this cell fails on a fresh kernel with Restart & Run All — run that cell first
# or move the definition above this one.
# update_trajectory_visualization(viser_server, horizontal_circle)
update_trajectory_visualization(viser_server, vertical_circle)

# The message is plural, but only `vertical_circle` is generated above.
print("Circular trajectories generated!")
print("Available types: horizontal, vertical_xz, vertical_yz, tilted")

In [None]:
#########################################################################
# Camera trajectory visualization - Multiple Trajectories
##########################################################################

def generate_new_trajectory(vis_crafter, opts_base, target_pose_params, scene_data,
                            circle_type='horizontal'):
    """Generate a new camera trajectory without recomputing frames or depths.

    A deep copy of ``opts_base`` receives ``target_pose_params`` as its
    ``target_pose``; ``opts_base`` itself is left untouched. Poses are then
    recomputed with ``vis_crafter.get_poses_circular`` while frames, depths and
    radius are reused from ``scene_data``.

    Args:
        vis_crafter: Object exposing ``get_poses_circular(opts, depths, num_frames=..., circle_type=...)``.
        opts_base: Base options object; must be deep-copyable and provide ``video_length``.
        target_pose_params: Target pose parameters stored on the copied options
            and echoed back under 'trajectory_params'.
        scene_data: Existing scene data with the keys 'frames_numpy',
            'frames_tensor', 'depths' and 'radius'.
        circle_type: Circle variant forwarded to ``get_poses_circular``. Defaults
            to 'horizontal', which was previously hard-coded, so existing
            four-argument calls behave exactly as before.

    Returns:
        A new dict with the reused entries plus 'pose_source', 'pose_target',
        'intrinsics' and 'trajectory_params'.
    """
    print(f"Generating trajectory for pose: {target_pose_params}")

    # Work on a copy so the shared base options are never modified.
    opts_new = copy.deepcopy(opts_base)
    opts_new.target_pose = target_pose_params

    # Only the poses are regenerated; the existing depths are reused.
    pose_s, pose_t, K = vis_crafter.get_poses_circular(
        opts_new,
        scene_data['depths'],
        num_frames=opts_new.video_length,
        circle_type=circle_type,
    )

    # Same point-cloud inputs, new trajectory.
    return {
        'frames_numpy': scene_data['frames_numpy'],
        'frames_tensor': scene_data['frames_tensor'],
        'depths': scene_data['depths'],
        'pose_source': pose_s,
        'pose_target': pose_t,
        'intrinsics': K,
        'radius': scene_data['radius'],
        'trajectory_params': target_pose_params,
    }

def _remove_scene_node(server, node_name):
    """Best-effort removal of a single scene node; a missing node is not an error."""
    scene = server.scene
    # NOTE(review): viser documents `SceneApi.remove_by_name`; `remove` is what this
    # notebook originally called, so it is kept as a fallback. Confirm which of the
    # two the installed viser version provides.
    remover = getattr(scene, "remove_by_name", None) or getattr(scene, "remove", None)
    if remover is None:
        return
    try:
        remover(node_name)
    except Exception:
        # Cleanup is deliberately best-effort: the node may simply not exist yet.
        pass


def update_trajectory_visualization(server, new_scene_data):
    """Redraw the camera trajectory on ``server``, leaving point clouds untouched.

    Removes the previously drawn trajectory nodes ("/trajectory", "/camera_<i>",
    "/start", "/end") and then adds a spline through the target camera
    positions, a frustum for every second pose and start/end marker spheres.

    Args:
        server: viser server (anything exposing a compatible ``scene`` API).
        new_scene_data: Dict whose 'pose_target' entry supports ``.cpu().numpy()``
            and yields an array indexed as ``[frame, row, col]`` with the
            translation in ``[:3, 3]`` and the rotation in ``[:3, :3]``.
    """
    # Upper bound on frustum nodes left over from a previous call.
    max_stale_cameras = 50

    # Each node is removed independently, so one failing removal no longer
    # prevents the remaining nodes from being cleared.
    stale_nodes = ["/trajectory"]
    stale_nodes += [f"/camera_{i}" for i in range(max_stale_cameras)]
    stale_nodes += ["/start", "/end"]
    for node_name in stale_nodes:
        _remove_scene_node(server, node_name)

    poses_np = new_scene_data['pose_target'].cpu().numpy()
    positions = poses_np[:, :3, 3]

    # Trajectory spline through all camera positions.
    server.scene.add_spline_catmull_rom(
        "/trajectory",
        positions=positions,
        color=(1.0, 0.0, 0.0),
        line_width=3.0
    )

    # viser's add_camera_frustum expects the field of view in radians; the
    # intended 60 degree frustum therefore has to be converted.
    frustum_fov = float(np.deg2rad(60))

    # Camera frustums for every 2nd pose to reduce clutter.
    for i, pose in enumerate(poses_np[::2]):
        # The pose rotation is used as-is (no axis flip is applied).
        wxyz = viser.transforms.SO3.from_matrix(pose[:3, :3]).wxyz

        server.scene.add_camera_frustum(
            f"/camera_{i}",
            # aspect is a fixed 16:9 and is not derived from the scene intrinsics.
            fov=frustum_fov, aspect=16/9, scale=0.15,
            position=pose[:3, 3], wxyz=wxyz,
            color=(0.8, 0.2, 0.2)
        )

    # Start (green) and end (magenta) markers.
    server.scene.add_icosphere("/start", radius=0.1, position=positions[0], color=(0.0, 1.0, 0.0))
    server.scene.add_icosphere("/end", radius=0.1, position=positions[-1], color=(1.0, 0.0, 1.0))

# Predefined trajectory presets, keyed by the label shown in the GUI dropdown.
# Each value is a 5-element list in the order
#   [theta (pitch), phi (yaw), distance, x offset, y offset]
# which is the order in which the dropdown callback copies the entries into the
# "Theta (pitch)", "Phi (yaw)", "Distance", "X offset" and "Y offset" sliders.
# The lists are passed on by reference (not copied) as the target pose.
TRAJECTORY_PRESETS = {
    "Original Right 90°": [0, 90, 1, 0, 0],
    "Left 90°": [0, -90, 1, 0, 0], 
    "Full Circle": [0, 360, 1, 0, 0],
    "Up and Right": [45, 45, 0.5, 0, 1],
    "Pull Back": [0, 0, 3, 0, 0],
    "Orbit Up": [30, 180, 0, 0, 0],
    "Dolly Forward": [0, 0, -2, 0, 0],
    "Rise and Turn": [60, 120, 1, 0, 2],
}

# Create trajectory selection GUI
@viser_server.on_client_connect  
def _(client: viser.ClientHandle) -> None:
    """Build the trajectory GUI for a newly connected client.

    Runs once per client connection. GUI elements are created on ``client.gui``
    (so each client gets its own controls), while the scene updates triggered
    by the callbacks are sent through the shared ``viser_server``.
    Relies on the notebook globals ``scene_data``, ``vis_crafter``,
    ``opts_base``, ``TRAJECTORY_PRESETS`` and ``animate_frame``.
    """
    
    # Trajectory selection dropdown
    trajectory_dropdown = client.gui.add_dropdown(
        "Trajectory Preset",
        options=list(TRAJECTORY_PRESETS.keys()),
        initial_value="Original Right 90°"
    )
    
    # Manual trajectory controls; initial values mirror the
    # "Original Right 90°" preset [0, 90, 1, 0, 0].
    with client.gui.add_folder("Custom Trajectory"):
        theta_input = client.gui.add_slider("Theta (pitch)", min=-90, max=90, step=1, initial_value=0)
        phi_input = client.gui.add_slider("Phi (yaw)", min=-360, max=360, step=5, initial_value=90)
        dr_input = client.gui.add_slider("Distance", min=-3, max=3, step=0.1, initial_value=1.0)
        dx_input = client.gui.add_slider("X offset", min=-2, max=2, step=0.1, initial_value=0.0)
        dy_input = client.gui.add_slider("Y offset", min=-2, max=2, step=0.1, initial_value=0.0)
        
        generate_button = client.gui.add_button("Generate Custom Trajectory")
    
    # Display current trajectory info
    trajectory_info = client.gui.add_text("Trajectory Info", initial_value="Current: [0, 90, 1, 0, 0]")
    
    # Holder for the scene data currently shown to this client. It is local to
    # this handler call (not a module global); a one-element list is used so the
    # nested callbacks can replace the entry.
    current_scene_data = [scene_data]  # Use list for mutability
    
    @trajectory_dropdown.on_update
    def _(_):
        # Preset selected: regenerate the trajectory from the preset's pose.
        selected_preset = trajectory_dropdown.value
        target_pose = TRAJECTORY_PRESETS[selected_preset]
        
        # Generate new trajectory (always from the base `scene_data`, not from
        # the previously generated one)
        new_scene_data = generate_new_trajectory(vis_crafter, opts_base, target_pose, scene_data)
        current_scene_data[0] = new_scene_data
        
        # Update visualization
        update_trajectory_visualization(viser_server, new_scene_data)
        
        # Update info
        trajectory_info.value = f"Current: {target_pose}"
        
        # Update manual controls to match preset
        theta_input.value = target_pose[0]
        phi_input.value = target_pose[1] 
        dr_input.value = target_pose[2]
        dx_input.value = target_pose[3]
        dy_input.value = target_pose[4]
        
        print(f"Switched to trajectory: {selected_preset} -> {target_pose}")
    
    @generate_button.on_click
    def _(_):
        # Button pressed: build the pose from the five manual sliders.
        custom_pose = [theta_input.value, phi_input.value, dr_input.value, dx_input.value, dy_input.value]
        
        # Generate new trajectory
        new_scene_data = generate_new_trajectory(vis_crafter, opts_base, custom_pose, scene_data)
        current_scene_data[0] = new_scene_data
        
        # Update visualization  
        update_trajectory_visualization(viser_server, new_scene_data)
        
        # Update info
        trajectory_info.value = f"Current: {custom_pose}"
        
        print(f"Generated custom trajectory: {custom_pose}")
    
    # Animation controls for current trajectory; the frame slider covers the
    # indices of the first dimension of scene_data['frames_tensor'].
    with client.gui.add_folder("Animation"):
        play_button = client.gui.add_button("Play/Pause")
        frame_slider = client.gui.add_slider(
            "Frame", 
            min=0, 
            max=scene_data['frames_tensor'].shape[0]-1, 
            step=1, 
            initial_value=0
        )
        speed_slider = client.gui.add_slider("Speed", min=0.5, max=5.0, step=0.1, initial_value=1.0)
    
    # Animation state (one-element list so the nested callbacks can toggle it)
    is_playing = [False]
    
    @play_button.on_click
    def _(_):
        # Toggle between playing and paused.
        is_playing[0] = not is_playing[0]
        
    @frame_slider.on_update
    def _(_):
        # NOTE(review): animate_frame is not defined in this cell; it must
        # already exist in the kernel from an earlier cell.
        animate_frame(viser_server, current_scene_data[0], frame_slider.value)
    
    # Animation loop
    # NOTE(review): these imports are local to the handler; they could live in
    # the notebook's import cell instead.
    import threading
    import time
    
    def animation_loop():
        # Advances the frame slider (wrapping around at the last frame) while
        # playing, then sleeps 1/speed seconds per iteration.
        # NOTE(review): the loop has no exit condition, and one such thread is
        # started for every client connection, so threads accumulate across
        # reconnects until the kernel stops.
        while True:
            if is_playing[0]:
                current_frame = frame_slider.value
                next_frame = (current_frame + 1) % current_scene_data[0]['frames_tensor'].shape[0]
                # NOTE(review): assigning the slider value may itself fire the
                # frame_slider on_update callback, in which case animate_frame
                # runs twice per tick — confirm against viser's GUI semantics.
                frame_slider.value = next_frame
                animate_frame(viser_server, current_scene_data[0], next_frame)
            time.sleep(1.0 / speed_slider.value)
    
    # Start animation thread (daemon, so it does not block interpreter exit)
    animation_thread = threading.Thread(target=animation_loop, daemon=True)
    animation_thread.start()

print("Trajectory visualization setup complete!")
print("Available presets:", list(TRAJECTORY_PRESETS.keys()))

In [None]:
# Define a few example target poses and generate the trajectory for one of them
# ("Right 90°"). Nothing is printed here; the following cells plot and print
# the resulting `positions` and `new_data`.
target_poses = {
    "Right 90°": [0, 90, 1, 0, 0],
    "Full Circle": [0, 360, 1, 0, 0],
    "Pull Back": [0, 0, 3, 0, 0]
}

# Only the "Right 90°" entry is used; the other two are kept for manual experiments.
new_data = generate_new_trajectory(
    vis_crafter, opts_base,
    target_poses["Right 90°"],
    scene_data
    )

# Per-frame camera translation: rows 0-2 of column 3 of every target pose matrix.
positions = new_data['pose_target'].cpu().numpy()[:, :3, 3]


In [None]:
# Using matplotlib, plot axis 0 (X) against axis 2 (Z) of the `positions` array.
fig, ax = plt.subplots(figsize=(10, 8))

xs = positions[:, 0]
zs = positions[:, 2]

# Camera path in the X/Z plane
ax.plot(xs, zs, 'b-o', linewidth=2, markersize=4)
ax.set_xlabel('X Position')
ax.set_ylabel('Z Position')
ax.set_title('Camera Trajectory: X vs Z (Top-down view)')
ax.grid(True, alpha=0.3)
ax.axis('equal')

# Highlight the first and last camera position
ax.plot(xs[0], zs[0], 'go', markersize=10, label='Start')
ax.plot(xs[-1], zs[-1], 'ro', markersize=10, label='End')

# Label roughly every eighth of the trajectory with its frame index
label_stride = max(1, len(positions) // 8)
for frame_idx in range(0, len(positions), label_stride):
    ax.annotate(
        f'{frame_idx}',
        (xs[frame_idx], zs[frame_idx]),
        xytext=(5, 5),
        textcoords='offset points',
        fontsize=8,
    )

ax.legend()
fig.tight_layout()
plt.show()

# Numeric summary of the plotted trajectory
print(f"Trajectory shape: {positions.shape}")
print(f"X range: {positions[:, 0].min():.3f} to {positions[:, 0].max():.3f}")
print(f"Z range: {positions[:, 2].min():.3f} to {positions[:, 2].max():.3f}")
print(f"Start (X,Z): ({positions[0, 0]:.3f}, {positions[0, 2]:.3f})")
print(f"End (X,Z): ({positions[-1, 0]:.3f}, {positions[-1, 2]:.3f})")

In [None]:
# Inspect the first target pose to reason about the coordinate convention.
poses = new_data['pose_target'].cpu().numpy()
first_pose = poses[0]
first_rotation = first_pose[:3, :3]

print("Shape:", poses.shape)
print("\nFirst pose (Frame 0):")
print(first_pose)
print("\nPosition:", first_pose[:3, 3])
print("Rotation matrix:")
print(first_rotation)

# Print the rotation columns so they can be compared against the OpenCV and
# OpenGL conventions. The axis names in the labels are the assumed meaning of
# each column; this cell only prints the values and does not verify them.
print("\nCoordinate system analysis:")
axis_columns = (
    ("Z-axis (forward direction):", 2),
    ("Y-axis (up direction):", 1),
    ("X-axis (right direction):", 0),
)
for axis_label, column_idx in axis_columns:
    print(axis_label, first_rotation[:, column_idx])