In [None]:
# Input Video
# Trajectory Planning
# Camera_0 = Estimate video camera pose
# Target angle: 100, 4 segments, from 1st frame
# Camera_1 = 0 to 25 – Camera_0
# Camera_2 

def pad_video(frames, target_length):
    """Pad a frame stack to ``target_length`` frames by repeating its last frame.

    Args:
        frames: NumPy array whose first axis indexes frames.
        target_length: desired number of frames along axis 0.

    Returns:
        ``frames`` itself when it already has at least ``target_length``
        frames (longer inputs are NOT truncated); otherwise a new array in
        which the last frame is repeated until axis 0 has ``target_length``
        entries.

    Raises:
        ValueError: if padding is required but ``frames`` is empty, because
            there is no last frame to repeat.
    """
    num_pad = target_length - frames.shape[0]
    if num_pad <= 0:
        return frames
    if frames.shape[0] == 0:
        # frames[-1:] would be empty too, so the "padded" result would still
        # have zero frames and silently miss the requested length.
        raise ValueError("cannot pad an empty video: no last frame to repeat")
    pad_frames = np.repeat(frames[-1:], num_pad, axis=0)
    return np.concatenate([frames, pad_frames], axis=0)


# Pipeline sketch (planning pseudo-code, not runnable as-is).
# NOTE(review): in a Jupyter kernel __name__ is normally "__main__", so this
# block executes when the cell runs. Several names it uses (opts,
# geometric_fm, plan_trajectory, setup_exp_directory, merge_point_clouds) are
# not defined anywhere visible in this notebook, and read_video_frames is only
# imported in a later cell. The implemented version is infer_autoregressive().
if __name__ == "__main__":
    # 0: set up stage
    
    # read input video
    frames = read_video_frames(
            opts.video_path, opts.video_length, opts.stride, opts.max_res
        )
    
    # pad if too short
    frames = pad_video(frames, opts.video_length)
    
    # prompt
    # NOTE(review): `self` does not exist at module level; the implemented
    # version calls vis_crafter.get_caption(...) instead.
    prompt = self.get_caption(opts, frames[opts.video_length // 2])

    # get global cam_pos of input video + n=video_length 3D point clouds
    # NOTE(review): geometric_fm is unpacked as (poses, point_cloud) here but
    # bound to a single name in the loop below; one of the two is wrong.
    poses_input, pc_input = geometric_fm(frames, opts)
    
    # plan trajectory
    traj_segments = plan_trajectory(
        poses_input[0], opts.target_pose, opts.n_splits
    )
    
    # save everything
    # Directory Structure
    # - exp_name/
    
    #   - input/
    #     + input.mp4
    #     + cameras_input.npy
    #     + point_cloud_input.ply
    #     + prompt.txt
    
    #   - stage_1/
    #     - input.mp4 mask.mp4 render.mp4 gen.mp4
    #     - point_cloud_input.ply
    #     + cameras_target.npy
    setup_exp_directory(opts, frames, poses_input, pc_input, prompt, traj_segments)
    
    pc_global = pc_input
    
    # 1: autoregressive generation
    for i in range(opts.n_splits):
        
        # TODO: video reversal for even segments
        
        # NOTE(review): the generate_segment / save_segment_results defined
        # later take different arguments than these sketch calls.
        inpainted_video = generate_segment(frames, pc_global, traj_segments[i], opts)
        
        pc_inpainted = geometric_fm(inpainted_video, opts)
        pc_global = merge_point_clouds(pc_global, pc_inpainted)
        
        save_segment_results(inpainted_video, pc_global, traj_segments[i], opts, segment_idx=i)    
    
    # maybe final global inpainting here
        

# Inpaint
# Project to 3D, save PC_1 in COLMAP format
# Inpaint
# Save camera_1, gen_1, render_1
# Extend
# Reverse gen_1
# camera_1 are known - Global
# Estimate Depth from gen_1_r
# Project to 3D
# Merge with PC_1, save in COLMAP format
# Inpaint
# Save camera_2


### Implementation

In [None]:
import torch
import numpy as np
import copy
import os
from pathlib import Path
import cv2
import shutil

from demo import TrajCrafter
from models.utils import Warper, read_video_frames, sphere2pose, save_video
import torch.nn.functional as F
from tqdm import tqdm

from models.infer import DepthCrafterDemo

import os
from datetime import datetime
import torch
import copy
import time
import sys
import tempfile
from pathlib import Path

# Add core.py to path if needed
sys.path.append('/home/azhuravl/work/TrajectoryCrafter/notebooks/28_08_25_trajectories')
from core import VisualizationWarper

sys.path.append('/home/azhuravl/work/TrajectoryCrafter/notebooks/06_10_25_vggt')
from parsing import get_parser


class TrajCrafterAutoregressive(TrajCrafter):
    """``TrajCrafter`` subclass used by the autoregressive notebook pipeline.

    Adds a ``prompt`` slot (filled in later with the video caption) and a
    fixed intrinsics matrix ``K`` replicated once per frame.
    """

    def __init__(self, opts):
        super().__init__(opts)

        # Caption placeholder; assigned by the driver once the video is read.
        self.prompt = None

        # Pinhole intrinsics: fx = fy = 500, principal point (512, 288),
        # i.e. the centre of a 1024x576 frame.
        single_k = torch.tensor(
            [
                [500, 0.0, 512.0],
                [0.0, 500, 288.0],
                [0.0, 0.0, 1.0],
            ]
        )
        # Same intrinsics for every frame -> (video_length, 3, 3).
        self.K = single_k.repeat(opts.video_length, 1, 1).to(opts.device)

In [None]:
# ============================================================================
# UTILITY FUNCTIONS (moved outside class)
# ============================================================================

def pad_video(frames, target_length):
    """Pad a frame stack to ``target_length`` frames by repeating its last frame.

    Args:
        frames: NumPy array whose first axis indexes frames.
        target_length: desired number of frames along axis 0.

    Returns:
        ``frames`` itself when it already has at least ``target_length``
        frames (longer inputs are NOT truncated); otherwise a new array in
        which the last frame is repeated until axis 0 has ``target_length``
        entries.

    Raises:
        ValueError: if padding is required but ``frames`` is empty, because
            there is no last frame to repeat.
    """
    num_pad = target_length - frames.shape[0]
    if num_pad <= 0:
        return frames
    if frames.shape[0] == 0:
        # frames[-1:] would be empty too, so the "padded" result would still
        # have zero frames and silently miss the requested length.
        raise ValueError("cannot pad an empty video: no last frame to repeat")
    pad_frames = np.repeat(frames[-1:], num_pad, axis=0)
    return np.concatenate([frames, pad_frames], axis=0)


def generate_traj_specified(c2ws_anchor, target_pose, n_frames, device):
    """Build ``n_frames`` camera poses moving linearly from the anchor to ``target_pose``.

    ``target_pose`` is unpacked as ``(theta, phi, d_r, d_x, d_y)``. Each
    component is interpolated with ``np.linspace(0, value, n_frames)``, so
    both endpoints are included. The per-frame values are passed to
    ``sphere2pose`` together with ``c2ws_anchor`` and ``device``.

    Returns:
        The per-frame ``sphere2pose`` outputs concatenated along dim 0. This is
        presumably ``(n_frames, 4, 4)`` when each call returns ``(1, 4, 4)``;
        TODO confirm against ``sphere2pose``.
    """
    theta, phi, d_r, d_x, d_y = target_pose
    # One interpolation track per pose component, each starting at zero.
    tracks = [
        np.linspace(0, end, n_frames) for end in (theta, phi, d_r, d_x, d_y)
    ]

    poses = [
        sphere2pose(
            c2ws_anchor,
            np.float32(th),
            np.float32(ph),
            np.float32(r),
            device,
            np.float32(x),
            np.float32(y),
        )
        for th, ph, r, x, y in zip(*tracks)
    ]
    return torch.cat(poses, dim=0)


def save_poses_torch(c2ws, filepath):
    """Write camera poses to ``filepath`` with ``torch.save`` after moving them to CPU."""
    poses_on_cpu = c2ws.cpu()
    torch.save(poses_on_cpu, filepath)

def save_point_clouds_torch(pc_list, color_list, dirpath):
    """Save per-frame point clouds and their colors as PyTorch tensor files.

    Writes ``points_XXX.pth`` and ``colors_XXX.pth`` (zero-padded 3-digit
    frame index) into ``dirpath``, one pair per frame. The directory is
    created if needed. Tensors are moved to CPU before saving.

    Args:
        pc_list: sequence of point tensors, one per frame.
        color_list: sequence of color tensors, aligned with ``pc_list``.
        dirpath: output directory (``str`` or ``Path``).

    Raises:
        ValueError: if ``pc_list`` and ``color_list`` differ in length
            (``zip`` would otherwise silently drop the unmatched tail).
    """
    if len(pc_list) != len(color_list):
        raise ValueError(
            f"got {len(pc_list)} point clouds but {len(color_list)} color sets"
        )
    os.makedirs(dirpath, exist_ok=True)

    # Points and colors go to separate files so either can be loaded alone.
    for idx, (pc, color) in enumerate(zip(pc_list, color_list)):
        torch.save(pc.cpu(), os.path.join(dirpath, f'points_{idx:03d}.pth'))
        torch.save(color.cpu(), os.path.join(dirpath, f'colors_{idx:03d}.pth'))
    

def save_segment_results(pc_input, color_input, pc_inpainted, color_inpainted, 
                        pc_merged, color_merged, traj_segment, opts, segment_idx):
    """Persist the artifacts of one autoregressive segment.

    Output goes to ``<opts.save_dir>/stage_<segment_idx + 1>/``, which is
    created if missing. It contains:
      * ``point_cloud_input/``: the per-frame input clouds, written via
        ``save_point_clouds_torch``;
      * ``cameras_target.pth``: the segment's camera matrices.

    The inpainted and merged clouds are accepted for interface stability but
    are not written yet.
    """
    out_dir = Path(opts.save_dir).joinpath(f'stage_{segment_idx+1}')
    out_dir.mkdir(parents=True, exist_ok=True)

    save_point_clouds_torch(pc_input, color_input, out_dir / 'point_cloud_input')
    save_poses_torch(traj_segment, out_dir / 'cameras_target.pth')



In [None]:

def extract_point_cloud(frames, c2ws, opts, crafter=None, warper=None):
    """Estimate per-frame depth and lift every frame to a colored 3D point cloud.

    Args:
        frames: NumPy frames. The conversion below treats them as
            ``(N, H, W, 3)`` in ``[0, 1]``: they are permuted to
            ``(N, 3, H, W)`` and mapped to ``[-1, 1]``. ``N`` must equal
            ``opts.video_length``.
        c2ws: per-frame camera poses, sliced ``c2ws[i:i+1]``. Assumed to be
            ``(N, 4, 4)``; TODO confirm against the warper's convention.
        opts: run options (depth-inference settings, ``device``,
            ``video_length``, ``radius_scale``).
        crafter: object exposing ``depth_estimater`` and ``K``. Defaults to
            the notebook-global ``vis_crafter``.
        warper: object exposing ``extract_3d_points_with_colors``. Defaults
            to the notebook-global ``funwarp``.

    Returns:
        ``(pc_list, color_list, radius)``. There is one point tensor and one
        color tensor per frame. ``radius`` is the depth at the centre pixel of
        the first frame times ``opts.radius_scale``, capped by ``min(..., 5)``;
        it is therefore either a CPU tensor or the int ``5``.

    Raises:
        ValueError: if ``frames`` does not hold ``opts.video_length`` frames.
    """
    if crafter is None:
        crafter = vis_crafter
    if warper is None:
        warper = funwarp

    # Validate up front instead of after the (expensive) depth inference;
    # an `assert` would also vanish under `python -O`.
    if frames.shape[0] != opts.video_length:
        raise ValueError(
            f"expected {opts.video_length} frames, got {frames.shape[0]}"
        )

    depths = crafter.depth_estimater.infer(
        frames,
        opts.near,
        opts.far,
        opts.depth_inference_steps,
        opts.depth_guidance_scale,
        window_size=opts.window_size,
        overlap=opts.overlap,
    ).to(opts.device)

    # Scene radius: centre-pixel depth of the first frame, scaled and capped.
    radius = (
        depths[0, 0, depths.shape[-2] // 2, depths.shape[-1] // 2].cpu()
        * opts.radius_scale
    )
    radius = min(radius, 5)

    frames_t = (
        torch.from_numpy(frames).permute(0, 3, 1, 2).to(opts.device) * 2.0 - 1.0
    )  # (N, H, W, 3) in [0, 1] -> (N, 3, H, W) in [-1, 1]

    pc_list = []
    color_list = []
    for i in range(opts.video_length):
        pc, color = warper.extract_3d_points_with_colors(
            frames_t[i:i+1],
            depths[i:i+1],
            c2ws[i:i+1],
            crafter.K[i:i+1],
            subsample_step=1,
        )
        pc_list.append(pc)
        color_list.append(color)

    return pc_list, color_list, radius


def _plot_render_debug(frames, warped_images, masks, frame_idx):
    """Show the input frame, point-cloud render and render mask for one frame.

    Debug aid for ``generate_segment``. ``frame_idx`` is clamped to the
    available frames so short videos do not raise ``IndexError``.
    """
    import matplotlib.pyplot as plt

    k = min(max(frame_idx, 0), len(warped_images) - 1)
    _, axs = plt.subplots(3, figsize=(20, 6))
    # Frames and renders are in [-1, 1]; map back to [0, 1] for display.
    axs[0].imshow((frames[k].permute(1, 2, 0).cpu().numpy() + 1.0) / 2.0)
    axs[1].imshow((warped_images[k][0].permute(1, 2, 0).cpu().numpy() + 1.0) / 2.0)
    axs[2].imshow(masks[k][0].permute(1, 2, 0).cpu().numpy())
    plt.show()


def generate_segment(frames, pc_input, color_input, traj_segment, segment_dir, opts,
                     debug_plot=False, debug_frame=30):
    """Render the point clouds along ``traj_segment`` and produce one video segment.

    Steps:
      1. Render each frame's point cloud from the matching target camera.
      2. Resize inputs, renders and masks to ``opts.sample_size``.
      3. Write ``input.mp4``, ``render.mp4`` and ``mask.mp4`` into ``segment_dir``.
      4. Write ``gen.mp4`` and read it back.

    NOTE: the diffusion call is currently disabled (see the commented block
    below), so ``gen.mp4`` is just the resized input frames.

    Args:
        frames: NumPy frames, treated as ``(N, H, W, 3)`` in ``[0, 1]``.
            ``N`` must equal ``opts.video_length``.
        pc_input: per-frame point tensors, indexed by frame.
        color_input: per-frame color tensors aligned with ``pc_input``.
        traj_segment: per-frame camera matrices for this segment, sliced
            ``traj_segment[i:i+1]``.
        segment_dir: output directory for the videos (created by the caller).
        opts: run options (device, video_length, sample_size, fps, mask, seed,
            stride, max_res).
        debug_plot: if True, display input / render / mask for
            ``debug_frame`` with matplotlib.
        debug_frame: frame index used for the debug plot (clamped).

    Returns:
        The frames of ``gen.mp4`` as returned by ``read_video_frames``.

    Raises:
        ValueError: if ``frames`` does not hold ``opts.video_length`` frames.
    """
    frames = (
        torch.from_numpy(frames).permute(0, 3, 1, 2).to(opts.device) * 2.0 - 1.0
    )  # (N, H, W, 3) in [0, 1] -> (N, 3, H, W) in [-1, 1]
    if frames.shape[0] != opts.video_length:
        raise ValueError(
            f"expected {opts.video_length} frames, got {frames.shape[0]}"
        )

    # Render every frame's point cloud from its target camera.
    # NOTE(review): image_size is hard-coded to 576x1024, matching the fixed
    # intrinsics in vis_crafter.K; other input resolutions are not handled.
    warped_images = []
    masks = []
    for i in tqdm(range(opts.video_length)):
        output_frame, output_mask = funwarp.render_pointcloud_native(
            pc_input[i],
            color_input[i],
            traj_segment[i:i+1],
            vis_crafter.K[i:i+1],
            image_size=(576, 1024),
            mask=opts.mask,
        )
        warped_images.append(output_frame)
        masks.append(output_mask)

    if debug_plot:
        _plot_render_debug(frames, warped_images, masks, debug_frame)

    # Renders are mapped from [-1, 1] to [0, 1] for the conditioning video.
    cond_video = (torch.cat(warped_images) + 1.0) / 2.0
    cond_masks = torch.cat(masks)

    frames = F.interpolate(
        frames, size=opts.sample_size, mode='bilinear', align_corners=False
    )
    cond_video = F.interpolate(
        cond_video, size=opts.sample_size, mode='bilinear', align_corners=False
    )
    cond_masks = F.interpolate(cond_masks, size=opts.sample_size, mode='nearest')

    save_video(
        (frames.permute(0, 2, 3, 1) + 1.0) / 2.0,
        os.path.join(segment_dir, 'input.mp4'),
        fps=opts.fps,
    )
    save_video(
        cond_video.permute(0, 2, 3, 1),
        os.path.join(segment_dir, 'render.mp4'),
        fps=opts.fps,
    )
    save_video(
        cond_masks.repeat(1, 3, 1, 1).permute(0, 2, 3, 1),
        os.path.join(segment_dir, 'mask.mp4'),
        fps=opts.fps,
    )

    # Diffusion inputs in (1, C, N, H, W) layout; masks inverted and scaled
    # to 0..255. Only used once the pipeline call below is re-enabled.
    frames = (frames.permute(1, 0, 2, 3).unsqueeze(0) + 1.0) / 2.0
    frames_ref = frames[:, :, :10, :, :]
    cond_video = cond_video.permute(1, 0, 2, 3).unsqueeze(0)
    cond_masks = (1.0 - cond_masks.permute(1, 0, 2, 3).unsqueeze(0)) * 255.0
    generator = torch.Generator(device=opts.device).manual_seed(opts.seed)

    # with torch.no_grad():
    #     sample = vis_crafter.pipeline(
    #         vis_crafter.prompt,
    #         num_frames=opts.video_length,
    #         negative_prompt=opts.negative_prompt,
    #         height=opts.sample_size[0],
    #         width=opts.sample_size[1],
    #         generator=generator,
    #         guidance_scale=opts.diffusion_guidance_scale,
    #         num_inference_steps=opts.diffusion_inference_steps,
    #         video=cond_video.to(opts.device),
    #         mask_video=cond_masks.to(opts.device),
    #         reference=frames_ref,
    #     ).videos

    # Placeholder while diffusion is disabled: pass the input through.
    sample = frames

    save_video(
        sample[0].permute(1, 2, 3, 0),
        os.path.join(segment_dir, 'gen.mp4'),
        fps=opts.fps,
    )

    # Read back with the same loader used for the input video, so the next
    # segment receives frames in the same format.
    return read_video_frames(
        os.path.join(segment_dir, 'gen.mp4'), opts.video_length, opts.stride, opts.max_res
    )

    

def infer_autoregressive(opts):
    """Run the autoregressive trajectory pipeline and return the final frames.

    Steps:
      1. Read and pad the input video.
      2. Caption the middle frame (stored on the global ``vis_crafter.prompt``).
      3. Lift every frame to a point cloud from a fixed initial camera.
      4. Plan ``opts.video_length * opts.n_splits`` target poses.
      5. Generate ``opts.n_splits`` segments, each starting from the previous
         segment's output.

    Per-segment artifacts go to ``<opts.save_dir>/stage_<k>``.

    Args:
        opts: run options (see the configuration cell).

    Returns:
        The value returned by the last ``generate_segment`` call, or the
        padded input frames when ``opts.n_splits`` is 0.
    """
    frames = read_video_frames(
        opts.video_path, opts.video_length, opts.stride, opts.max_res
    )
    frames = pad_video(frames, opts.video_length)

    vis_crafter.prompt = vis_crafter.get_caption(opts, frames[opts.video_length // 2])

    ########################################################
    # Geometric FM
    ########################################################

    # Initial camera: diag(-1, 1, -1, 1), i.e. a 180-degree rotation about the
    # y-axis with zero translation, repeated for every frame.
    c2ws_init = torch.tensor([
                [-1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, -1.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
        ]).repeat(opts.video_length, 1, 1).to(opts.device)

    # The scene radius is returned but currently unused: the radius offset on
    # the camera z-translation is disabled.
    pc_input, color_input, _radius = extract_point_cloud(frames, c2ws_init, opts)

    ########################################################
    # Camera Pose Planning
    ########################################################

    c2ws_target = generate_traj_specified(
        c2ws_init[0:1],
        opts.target_pose,
        opts.video_length * opts.n_splits,
        opts.device
    )

    # NOTE(review): inverting turns the planned camera-to-world poses into
    # world-to-camera matrices; confirm that is what the renderer expects.
    c2ws_target = torch.inverse(c2ws_target)

    # (n_splits * video_length, 4, 4) -> (n_splits, video_length, 4, 4)
    traj_segments = c2ws_target.view(opts.n_splits, opts.video_length, 4, 4)

    ########################################################
    # Autoregressive Generation
    ########################################################

    for i in range(opts.n_splits):

        segment_dir = os.path.join(opts.save_dir, f'stage_{i+1}')
        os.makedirs(segment_dir, exist_ok=True)

        inpainted_video = generate_segment(
            frames, pc_input, color_input, traj_segments[i], segment_dir, opts
            )

        # Placeholders: re-extracting a cloud from the generated video and
        # merging it into the input cloud are not implemented yet, so the
        # input cloud is carried forward unchanged.
        pc_inpainted, color_inpainted = pc_input, color_input
        pc_merged, color_merged = pc_inpainted, color_inpainted

        save_segment_results(
            pc_input,
            color_input,
            pc_inpainted,
            color_inpainted,
            pc_merged,
            color_merged,
            traj_segments[i],
            opts, segment_idx=i)

        frames = inpainted_video
        pc_input = pc_merged
        color_input = color_merged

    return frames


In [None]:
# Run the full autoregressive pipeline (intended for large trajectories).
# In notebook order this cell depends on later cells: `opts` (configuration),
# `funwarp` and `vis_crafter` must already exist when it runs.
final_video = infer_autoregressive(opts=opts)

In [None]:
# Notebook-style CLI: fake argv so get_parser() can be reused unchanged.
sys.argv = [
    "",
    "--video_path", "/home/azhuravl/nobackup/DAVIS_testing/trainval/rhino.mp4",
    "--n_splits", "2",
    "--overlap_frames", "0",
    "--radius", "0",
    "--mode", "gradual",
]

parser = get_parser()
opts_base = parser.parse_args()

timestamp = datetime.now().strftime("%Y%m%d_%H%M")
video_basename = os.path.splitext(os.path.basename(opts_base.video_path))[0]

# Setup shared by every variant
opts_base.weight_dtype = torch.bfloat16
opts_base.exp_name = f"{video_basename}_{timestamp}_autoregressive"
opts_base.save_dir = os.path.join(opts_base.out_dir, opts_base.exp_name)

# Target poses are (theta, phi, d_r, d_x, d_y), as unpacked by
# generate_traj_specified.
radius = opts_base.radius

variants = [
    ("right_90", [0, 90, radius, 0, 0]),
]
# Single-variant run: take name/pose from the table instead of duplicating
# the literals (the two copies could otherwise drift apart).
name, pose = variants[0]

print(f"\n=== Running Autoregressive {name} ===")
opts = copy.deepcopy(opts_base)
opts.exp_name = f"{video_basename}_{timestamp}_{name}_auto_s{opts_base.n_splits}"
opts.save_dir = os.path.join(opts.out_dir, opts.exp_name)
opts.camera = "target"
opts.target_pose = pose
opts.traj_txt = 'test/trajs/loop2.txt'

# Make directories
os.makedirs(opts.save_dir, exist_ok=True)


In [None]:
sys.path.append('/home/azhuravl/work/TrajectoryCrafter/notebooks/28_08_25_trajectories')
import core


In [None]:
import importlib
# Re-execute core.py so edits to it (e.g. VisualizationWarper) are picked up
# without restarting the kernel.
importlib.reload(core)

In [None]:

# Global warper; extract_point_cloud() and generate_segment() look it up by
# the name `funwarp`.
funwarp = core.VisualizationWarper(device=opts.device)
# NOTE(review): device is already passed to the constructor; this explicit
# re-assignment presumably guards against the constructor storing a
# different value. Confirm whether it is still needed.
funwarp.device = opts.device

In [None]:
# Global TrajCrafter instance; infer_autoregressive(), extract_point_cloud()
# and generate_segment() use it for captioning, depth estimation and K.
# NOTE(review): built from opts_base rather than the per-variant opts. The
# subclass __init__ only reads video_length and device (identical in both),
# but the TrajCrafter base class may read other fields such as save_dir,
# which differ. Confirm this is intended.
vis_crafter = TrajCrafterAutoregressive(opts_base)
