In [None]:
# Cell 1: Imports and Path Setup
import sys
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import copy
from datetime import datetime
from tqdm import tqdm

# Add TrajectoryCrafter to Python path.
# The checkout location can be overridden with the TRAJCRAFTER_PATH environment variable
# instead of editing the notebook; the default is the original author's checkout.
trajcrafter_path = os.environ.get("TRAJCRAFTER_PATH", "/home/azhuravl/work/TrajectoryCrafter")
# Drop any earlier copy first so re-running this cell keeps exactly one entry, at the front
# (a plain insert would prepend a duplicate on every re-run).
sys.path[:] = [entry for entry in sys.path if entry != trajcrafter_path]
sys.path.insert(0, trajcrafter_path)

# Change working directory to TrajectoryCrafter so repo-relative paths
# (e.g. the './test/videos/...' video path used in the argument cell) resolve.
os.chdir(trajcrafter_path)

# Now import TrajectoryCrafter modules
from demo import TrajCrafter
from models.utils import Warper, read_video_frames
from models.infer import DepthCrafterDemo
import inference_orbits

print("Imports successful!")
print(f"Working directory: {os.getcwd()}")
print(f"CUDA available: {torch.cuda.is_available()}")

In [None]:
# Cell 2: Argument Setup
# Pass an explicit argument list: with no list, argparse would parse the kernel's own sys.argv.
parser = inference_orbits.get_parser()
opts_base = parser.parse_args(
    [
        '--video_path', './test/videos/0-NNvgaTcVzAG0-r.mp4',  # Change this path
        '--radius', '1.0',
        '--device', 'cuda:0',
    ]
)

# Settings shared by every run in this notebook
vars(opts_base).update(
    weight_dtype=torch.bfloat16,
    camera="target",
    mode="gradual",
    mask=True,
)
# Trajectory spec [dtheta, dphi, dr, dx, dy]; this one is the right_90 example
opts_base.target_pose = [0, 90, opts_base.radius, 0, 0]

# Experiment name: <video file stem>_<YYYYmmdd_HHMM>_vis
video_basename = os.path.splitext(os.path.basename(opts_base.video_path))[0]
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
opts_base.exp_name = f"{video_basename}_{timestamp}_vis"

print(f"Video: {opts_base.video_path}")
print(f"Target pose: {opts_base.target_pose}")
print(f"Device: {opts_base.device}")

In [None]:
# Cell 3: Visualization Classes
class VisualizationWarper(Warper):
    """Warper subclass that back-projects a depth map into a colored 3D point cloud."""

    def extract_3d_points_with_colors(
        self,
        frame1: torch.Tensor,
        depth1: torch.Tensor,
        transformation1: torch.Tensor,
        intrinsic1: torch.Tensor,
        subsample_step: int = 10
    ):
        """Back-project every ``subsample_step``-th pixel of ``frame1`` into 3D.

        Args:
            frame1: (b, c, h, w) image; its channels become the point colors.
            depth1: (b, ?, h, w) depth map; only channel 0 is read.
            transformation1: (b, 4, 4) matrix applied to the homogeneous camera-space points.
            intrinsic1: (b, 3, 3) intrinsics; its inverse un-projects pixel coordinates.
            subsample_step: keep one pixel every ``subsample_step`` rows and columns.

        Returns:
            (points_3d, colors_rgb): (N, 3) transformed points and (N, c) colors for the
            sampled pixels with depth > 0, flattened over the batch. Both use
            ``self.device`` / ``self.dtype``.

        Raises:
            ValueError: if ``subsample_step`` < 1.
        """
        if subsample_step < 1:
            raise ValueError(f"subsample_step must be >= 1, got {subsample_step}")

        b, _, h, w = frame1.shape

        # Move everything to the warper's device and working dtype
        frame1 = frame1.to(device=self.device, dtype=self.dtype)
        depth1 = depth1.to(device=self.device, dtype=self.dtype)
        transformation1 = transformation1.to(device=self.device, dtype=self.dtype)
        intrinsic1 = intrinsic1.to(device=self.device, dtype=self.dtype)

        # Subsampled pixel grid. It must share the intrinsics' dtype: a float32 grid cannot
        # be matmul'ed with intrinsics cast to a different self.dtype.
        x_coords = torch.arange(0, w, subsample_step, dtype=torch.float32, device=depth1.device).to(depth1.dtype)
        y_coords = torch.arange(0, h, subsample_step, dtype=torch.float32, device=depth1.device).to(depth1.dtype)
        x2d, y2d = torch.meshgrid(x_coords, y_coords, indexing='xy')  # (h_sub, w_sub)

        # Homogeneous pixel coordinates: (1, h_sub, w_sub, 3, 1)
        pos_vectors_homo = torch.stack([x2d, y2d, torch.ones_like(x2d)], dim=2)[None, :, :, :, None]

        # Same stride as the grid, so the shapes line up with it
        depth_sub = depth1[:, 0, ::subsample_step, ::subsample_step]  # (b, h_sub, w_sub)
        colors_sub = frame1[:, :, ::subsample_step, ::subsample_step]  # (b, c, h_sub, w_sub)

        # Un-project: depth * K^-1 @ [x, y, 1]^T -> (b, h_sub, w_sub, 3, 1)
        intrinsic1_inv_4d = torch.linalg.inv(intrinsic1)[:, None, None]
        camera_points = depth_sub[:, :, :, None, None] * torch.matmul(intrinsic1_inv_4d, pos_vectors_homo)

        # Apply the 4x4 transformation to the homogeneous points
        ones_4d = torch.ones_like(camera_points[:, :, :, :1, :])
        world_points_homo = torch.matmul(
            transformation1[:, None, None], torch.cat([camera_points, ones_4d], dim=3)
        )
        world_points = world_points_homo[:, :, :, :3, 0]  # (b, h_sub, w_sub, 3)

        colors = colors_sub.permute(0, 2, 3, 1)  # (b, h_sub, w_sub, c)

        # Keep only samples with positive depth, flattened across the batch
        valid_mask = depth_sub > 0
        return world_points[valid_mask], colors[valid_mask]


class TrajCrafterVisualization(TrajCrafter):
    """Lightweight TrajCrafter subclass for camera trajectory visualization.

    Only the depth estimator is built. The parent's ``__init__`` is deliberately not
    called, so per the message printed below the diffusion pipeline is never loaded.
    """

    def __init__(self, opts):
        """Set ``device`` and build ``depth_estimater`` from ``opts``; skip the parent init."""
        # Only initialize what we need for pose generation and depth estimation
        self.device = opts.device
        self.depth_estimater = DepthCrafterDemo(
            unet_path=opts.unet_path,
            pre_train_path=opts.pre_train_path,
            cpu_offload=opts.cpu_offload,
            device=opts.device,
        )
        print("TrajCrafterVisualization initialized (diffusion pipeline skipped)")

    def extract_scene_data(self, opts):
        """Read the video, estimate depth and generate camera poses for 3D visualization.

        Args:
            opts: options namespace. Reads video_path, video_length, stride, max_res,
                near, far, depth_inference_steps, depth_guidance_scale, window_size,
                overlap, device, radius_scale and, if present, target_pose.

        Returns:
            dict with keys:
                'frames_numpy': frames as returned by read_video_frames.
                'frames_tensor': the frames permuted (0, 3, 1, 2) and mapped x -> 2x - 1.
                    This presumably yields (n, c, h, w) in [-1, 1], assuming
                    read_video_frames returns (n, h, w, c) in [0, 1].
                'depths': depth tensor on opts.device.
                'pose_source', 'pose_target', 'intrinsics': outputs of get_poses.
                'radius': float, center depth of frame 0 * radius_scale, capped at 5.
                'trajectory_params': opts.target_pose, or None if not set.
        """
        print("Reading video frames...")
        frames = read_video_frames(
            opts.video_path, opts.video_length, opts.stride, opts.max_res
        )

        print("Estimating depth...")
        depths = self.depth_estimater.infer(
            frames,
            opts.near,
            opts.far,
            opts.depth_inference_steps,
            opts.depth_guidance_scale,
            window_size=opts.window_size,
            overlap=opts.overlap,
        ).to(opts.device)

        print("Converting frames to tensors...")
        frames_tensor = (
            torch.from_numpy(frames).permute(0, 3, 1, 2).to(opts.device) * 2.0 - 1.0
        )

        print("Generating camera poses...")
        pose_s, pose_t, K = self.get_poses(opts, depths, num_frames=opts.video_length)

        # Scene radius: depth at the image center of frame 0, scaled and capped at 5.
        # Converted to a plain float so the value always has one type; min(tensor, 5)
        # used to return either a 0-d tensor or the int 5.
        center_depth = depths[0, 0, depths.shape[-2] // 2, depths.shape[-1] // 2].cpu()
        radius = min(float(center_depth * opts.radius_scale), 5.0)

        return {
            'frames_numpy': frames,
            'frames_tensor': frames_tensor,
            'depths': depths,
            'pose_source': pose_s,
            'pose_target': pose_t,
            'intrinsics': K,
            'radius': radius,
            'trajectory_params': getattr(opts, 'target_pose', None),
        }

print("Classes defined successfully!")

In [None]:
# Cell 4: Run Visualization
# Build the depth-only crafter, then go video -> depths -> camera poses.
print("Initializing TrajCrafter for visualization...")
vis_crafter = TrajCrafterVisualization(opts_base)

print("Extracting scene data...")
scene_data = vis_crafter.extract_scene_data(opts_base)

# Back-project every frame into one combined, colored point cloud.
print("Creating 3D point cloud from all frames...")
vis_warper = VisualizationWarper(device=opts_base.device)

all_points_3d, all_colors_rgb = [], []
frames_t = scene_data['frames_tensor']
num_frames = frames_t.shape[0]

for frame_idx in tqdm(range(num_frames), desc="Processing frames"):
    sel = slice(frame_idx, frame_idx + 1)  # keep a batch dimension of 1
    frame_points, frame_colors = vis_warper.extract_3d_points_with_colors(
        frames_t[sel],
        scene_data['depths'][sel],
        scene_data['pose_source'][sel],
        scene_data['intrinsics'][sel],
        subsample_step=20,  # coarser than the default 10 to keep the multi-frame cloud small
    )
    # A frame with no positive-depth samples contributes nothing
    if frame_points.shape[0] == 0:
        continue
    all_points_3d.append(frame_points)
    all_colors_rgb.append(frame_colors)

if not all_points_3d:
    print("No valid 3D points extracted!")
    points_3d = None
    colors_rgb = None
else:
    points_3d = torch.cat(all_points_3d, dim=0)
    colors_rgb = torch.cat(all_colors_rgb, dim=0)
    print(f"Generated {points_3d.shape[0]} 3D points from {len(all_points_3d)} frames")

print(f"Camera trajectory: {scene_data['pose_target'].shape[0]} poses")

In [None]:
# Cell 1: Create Viser Server (run once)
import viser

VISER_PORT = 8080

# Re-running this cell would otherwise overwrite the only reference to a still-running
# server. Stop any existing instance first (best-effort: it may already be stopped).
if 'viser_server' in globals():
    try:
        viser_server.stop()
    except Exception as exc:
        print(f"Could not stop previous viser server: {exc!r}")

viser_server = viser.ViserServer(port=VISER_PORT)

In [None]:
# Shut down the viser server created above.
# NOTE(review): under Restart & Run All this runs right after the server is created, so the
# later update_viser_content(...) call populates a stopped server. Run this cell manually only.
viser_server.stop()

In [None]:
# Cell 2: Update Viser Content with Camera Direction Arrows (Previous working version)
# 180° rotation about the Y axis (negates X and Z). It is right-multiplied onto each pose's
# rotation before it is passed to viser: the empirically "working correction" for how the
# camera frustums are oriented.
_FLIP_Z = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]])


def _to_numpy(tensor):
    """Detach a torch tensor, upcast to float32 and copy it into a host NumPy array.

    The upcast matters because NumPy has no bfloat16 dtype, so ``.numpy()`` would fail.
    """
    return tensor.detach().float().cpu().numpy()


def _prepare_point_cloud(points_3d, colors_rgb, max_points, rng):
    """Return (points, colors) NumPy arrays, randomly capped at ``max_points``.

    Colors with negative values are treated as [-1, 1] and mapped to [0, 1], then clipped.
    The range decision is made on the full set, so a random subset cannot flip it.
    """
    points_np = _to_numpy(points_3d)
    colors_np = _to_numpy(colors_rgb)
    if len(points_np) == 0:
        return points_np, colors_np

    if colors_np.min() < 0:
        colors_np = (colors_np + 1.0) / 2.0
    colors_np = np.clip(colors_np, 0.0, 1.0)

    if len(points_np) > max_points:
        keep = rng.choice(len(points_np), size=max_points, replace=False)
        points_np = points_np[keep]
        colors_np = colors_np[keep]
    return points_np, colors_np


def _add_camera_glyphs(server, idx, pose, arrow_length):
    """Draw one camera: a frustum, a cyan look-direction arrow and an orange center marker.

    ``pose`` is a 4x4 matrix. Its last column holds the camera position and its upper-left
    3x3 block the rotation. The camera is drawn looking along -Z of that rotation.
    """
    position = pose[:3, 3]
    rotation_matrix = pose[:3, :3]
    wxyz = viser.transforms.SO3.from_matrix(rotation_matrix @ _FLIP_Z).wxyz

    server.scene.add_camera_frustum(
        f"/camera_{idx}",
        fov=np.deg2rad(60.0),  # viser expects radians; a bare 60 was read as 60 rad
        aspect=16/9,
        scale=0.2,
        position=position,
        wxyz=wxyz,
        color=(0.8, 0.2, 0.2),
    )

    # -Z of the original rotation, which is +Z of the flipped rotation used by the frustum
    look_direction = -rotation_matrix[:, 2]
    arrow_end = position + look_direction * arrow_length

    # Arrow shaft as a thick two-point line, arrowhead as a small sphere
    server.scene.add_spline_catmull_rom(
        f"/camera_direction_{idx}",
        positions=np.array([position, arrow_end]),
        color=(0.0, 1.0, 1.0),  # Cyan for visibility
        line_width=6.0,
    )
    server.scene.add_icosphere(
        f"/camera_arrowhead_{idx}",
        radius=0.05,
        position=arrow_end,
        color=(0.0, 1.0, 1.0),
    )
    server.scene.add_icosphere(
        f"/camera_center_{idx}",
        radius=0.03,
        position=position,
        color=(1.0, 0.5, 0.0),  # Orange
    )


def update_viser_content(server, scene_data, points_3d, colors_rgb, max_points=50000,
                         point_size=0.05, pose_stride=5, seed=None):
    """Replace the viser scene with the point cloud and the target camera trajectory.

    Args:
        server: viser server whose ``scene`` is reset and repopulated.
        scene_data: dict with 'pose_target' (tensor of 4x4 poses), 'radius' (anything
            ``float()`` accepts) and optionally 'trajectory_params' ([dtheta, dphi, dr, dx, dy]).
        points_3d: (N, 3) tensor of points, or None to skip the point cloud.
        colors_rgb: (N, 3) tensor of colors in [-1, 1] or [0, 1], or None.
        max_points: display cap; larger clouds are randomly subsampled.
        point_size: point size passed to ``add_point_cloud``.
        pose_stride: draw a frustum and arrow for every ``pose_stride``-th pose.
        seed: seed for the subsampling RNG; None picks a different subset on each call.
    """
    # Clear existing content
    server.scene.reset()

    num_shown = 0
    if points_3d is not None and colors_rgb is not None:
        rng = np.random.default_rng(seed)
        points_np, colors_np = _prepare_point_cloud(points_3d, colors_rgb, max_points, rng)
        if len(points_np) > 0:
            server.scene.add_point_cloud(
                "/scene_points",
                points=points_np,
                colors=colors_np,
                point_size=point_size,
            )
            num_shown = len(points_np)

    poses_np = _to_numpy(scene_data['pose_target'])
    positions = poses_np[:, :3, 3]
    arrow_length = float(scene_data['radius']) * 0.4

    # A spline needs at least two control points
    if len(positions) >= 2:
        server.scene.add_spline_catmull_rom(
            "/trajectory",
            positions=positions,
            color=(1.0, 0.0, 0.0),
            line_width=3.0,
        )

    for i, pose in enumerate(poses_np[::pose_stride]):
        _add_camera_glyphs(server, i, pose, arrow_length)

    # Start (green) / end (magenta) markers
    if len(positions) > 0:
        server.scene.add_icosphere("/start", radius=0.1, position=positions[0], color=(0.0, 1.0, 0.0))
        server.scene.add_icosphere("/end", radius=0.1, position=positions[-1], color=(1.0, 0.0, 1.0))
    server.scene.add_frame(
        "/world",
        axes_length=0.5,
        axes_radius=0.02,
        position=(0, 0, 0),
        wxyz=(1, 0, 0, 0),
    )

    # Report what was actually sent (the displayed count can be below the total)
    path_length = (
        float(np.linalg.norm(np.diff(positions, axis=0), axis=1).sum()) if len(positions) > 1 else 0.0
    )
    total_points = points_3d.shape[0] if points_3d is not None else 0
    print(f"Updated scene: {num_shown} of {total_points} points shown, {len(poses_np)} poses")
    print(f"Trajectory length: {path_length:.3f}, Arrow length: {arrow_length:.3f}")
    trajectory_params = scene_data.get('trajectory_params')
    if trajectory_params is not None and len(trajectory_params) == 5:
        dtheta, dphi, dr, dx, dy = trajectory_params
        print(f"Motion: θ={dtheta}°, φ={dphi}°, r={dr}, x={dx}, y={dy}")

# Push the extracted cloud (capped at 5000 displayed points) and the target trajectory
update_viser_content(
    viser_server,
    scene_data,
    points_3d=points_3d,
    colors_rgb=colors_rgb,
    max_points=5000,
)

In [None]:
# Simple Point Size Control
# Must match the max_points used when the scene was built; otherwise moving the slider
# would replace the subsampled cloud with every extracted point.
POINT_SIZE_MAX_POINTS = 5000


def _slider_point_cloud(points, colors, max_points, seed=0):
    """Convert the cloud to NumPy once: colors mapped to [0, 1], randomly capped at max_points.

    Returns (None, None) when either input is None. ``seed`` fixes which subset is kept.
    """
    if points is None or colors is None:
        return None, None
    points_np = points.detach().float().cpu().numpy()  # float(): NumPy has no bfloat16
    colors_np = colors.detach().float().cpu().numpy()
    if colors_np.size and colors_np.min() < 0:
        colors_np = (colors_np + 1) / 2
    if len(points_np) > max_points:
        keep = np.random.default_rng(seed).choice(len(points_np), max_points, replace=False)
        points_np, colors_np = points_np[keep], colors_np[keep]
    return points_np, colors_np


# Computed once; the slider callback only re-sends these cached arrays with a new size
# instead of re-copying the full tensors from the device on every tick.
_slider_points_np, _slider_colors_np = _slider_point_cloud(points_3d, colors_rgb, POINT_SIZE_MAX_POINTS)


@viser_server.on_client_connect
def _(client: viser.ClientHandle) -> None:
    """Give each newly connected client a slider controlling the shared cloud's point size."""
    point_size_slider = client.gui.add_slider(
        "Point Size",
        min=0.005,
        max=0.1,
        step=0.005,
        initial_value=0.015,
    )

    # Re-add the cloud under the same name with the new size (affects all clients)
    @point_size_slider.on_update
    def _(_) -> None:
        if _slider_points_np is None:
            return
        viser_server.scene.add_point_cloud(
            "/scene_points",
            points=_slider_points_np,
            colors=_slider_colors_np,
            point_size=point_size_slider.value,
        )

print("Point size control added!")

In [None]:
# set

In [None]:
# Cell 1: Clean Server Restart
import viser
import asyncio

VISER_PORT = 8080

# Stop and drop any existing server before creating a new one on the same port.
if 'viser_server' in globals():
    try:
        viser_server.stop()
        print("Stopped existing server")
    except Exception as exc:
        # Best-effort: the server may already be stopped. Report it instead of hiding it.
        print(f"Could not stop existing server cleanly: {exc!r}")
    del viser_server

# Make sure the current thread has an event loop that is not closed.
# NOTE(review): kept from the original; whether viser needs this is not shown here.
try:
    loop = asyncio.get_event_loop()
    if loop.is_closed():
        asyncio.set_event_loop(asyncio.new_event_loop())
except RuntimeError:
    # get_event_loop() raises RuntimeError when no loop can be obtained for this thread
    asyncio.set_event_loop(asyncio.new_event_loop())

# Create fresh server
viser_server = viser.ViserServer(port=VISER_PORT)
print(f"Fresh viser server created at http://localhost:{VISER_PORT}")

In [None]:
# Cell 4: To test different trajectories (example)
# Only target_pose changes, so reuse the frames and depths already in scene_data and just
# regenerate the camera poses. This skips re-reading the video and re-running depth
# estimation, and keeps the poses consistent with the depths that points_3d was built from.
opts_test = copy.deepcopy(opts_base)
opts_test.target_pose = [0, 90, 1, 0, 0]  # [dtheta, dphi, dr, dx, dy] -- edit to try another trajectory

pose_s_test, pose_t_test, K_test = vis_crafter.get_poses(
    opts_test, scene_data['depths'], num_frames=opts_test.video_length
)
scene_data_test = {
    **scene_data,
    'pose_source': pose_s_test,
    'pose_target': pose_t_test,
    'intrinsics': K_test,
    'trajectory_params': opts_test.target_pose,
}

# Update the same server with new content
update_viser_content(viser_server, scene_data_test, points_3d, colors_rgb)

In [None]:
import math
import numpy as np
import viser


def _add_orbit_slider(label, lo, hi, step, start):
    """Add one slider to the shared viser server's GUI and return its handle."""
    return viser_server.gui.add_slider(label, min=lo, max=hi, step=step, initial_value=start)


# Orbit-camera sliders, created in display order
theta_slider = _add_orbit_slider("Camera Theta (deg)", 0, 360, 1, 0)
phi_slider = _add_orbit_slider("Camera Phi (deg)", -90, 90, 1, 0)
radius_slider = _add_orbit_slider("Camera Distance", 1, 10, 0.1, 5)
roll_slider = _add_orbit_slider("Camera Roll (deg)", -180, 180, 1, 0)

def orbit_camera_vectors(theta_deg, phi_deg, r, roll_deg, look_at=(0.0, 0.0, 0.0)):
    """Return (position, look_at, up) for a camera orbiting ``look_at``.

    Args:
        theta_deg: azimuth in the XY plane, in degrees.
        phi_deg: elevation above the XY plane, in degrees (expected within [-90, 90]).
        r: distance from ``look_at``.
        roll_deg: rotation of the up vector about the viewing direction, in degrees.
        look_at: point the camera faces; defaults to the origin. World up is +Z.

    Returns:
        Tuple of NumPy arrays (position, look_at, up); ``up`` is a unit vector.
    """
    theta = math.radians(theta_deg)
    phi = math.radians(phi_deg)
    roll = math.radians(roll_deg)
    target = np.asarray(look_at, dtype=float)

    # Spherical -> cartesian offset from the target
    position = target + np.array([
        r * math.cos(phi) * math.cos(theta),
        r * math.cos(phi) * math.sin(theta),
        r * math.sin(phi),
    ])

    forward = target - position
    forward = forward / np.linalg.norm(forward)

    world_up = np.array([0.0, 0.0, 1.0])
    right = np.cross(forward, world_up)
    right_norm = np.linalg.norm(right)
    if right_norm < 1e-9:
        # Looking straight up or down: forward is parallel to world_up and the cross product
        # vanishes. Use the limit of the regular formula as phi -> ±90°, which depends only on theta.
        right = np.array([-math.sin(theta), math.cos(theta), 0.0])
    else:
        right = right / right_norm
    up_initial = np.cross(right, forward)
    up_initial = up_initial / np.linalg.norm(up_initial)

    # Roll: rotate the up vector about the forward axis
    up = math.cos(roll) * up_initial + math.sin(roll) * right
    return position, target, up


def update_camera_position():
    """Point every connected client's camera according to the four orbit sliders."""
    position, look_at, up = orbit_camera_vectors(
        theta_slider.value, phi_slider.value, radius_slider.value, roll_slider.value
    )
    for client in viser_server.get_clients().values():
        client.camera.position = position
        client.camera.look_at = look_at
        client.camera.up_direction = up

# Any of the four orbit sliders moving re-aims every client's camera. Registration order
# matches the slider creation order.
for _camera_slider in (theta_slider, phi_slider, radius_slider, roll_slider):
    _camera_slider.on_update(lambda _event: update_camera_position())