In [5]:
import torch
import numpy as np
import copy
import os
from pathlib import Path
import cv2
import shutil

from demo import TrajCrafter
from models.utils import Warper, read_video_frames, sphere2pose, save_video
import torch.nn.functional as F
from tqdm import tqdm

from models.infer import DepthCrafterDemo

import os
from datetime import datetime
import torch
import copy
import time
import sys
import tempfile
from pathlib import Path

# Add core.py to path if needed
sys.path.append('/home/azhuravl/work/TrajectoryCrafter/notebooks/28_08_25_trajectories')
from core import VisualizationWarper

sys.path.append('/home/azhuravl/work/TrajectoryCrafter/notebooks/06_10_25_vggt')
from parsing import get_parser


class TrajCrafterAutoregressive(TrajCrafter):
    """TrajCrafter variant that also carries a caption slot and fixed camera intrinsics.

    Attributes set here:
        prompt: initialised to None; filled in later by the notebook.
        K: one 3x3 intrinsics matrix (focal 500, principal point (512, 288))
           repeated `opts.video_length` times and moved to `opts.device`.
    """

    def __init__(self, opts):
        super().__init__(opts)

        # The point-cloud warper is instantiated separately in the notebook, not here.
        self.prompt = None

        # Single pinhole intrinsics matrix shared by every frame.
        intrinsics = torch.tensor(
            [
                [500, 0.0, 512.],
                [0.0, 500, 288.],
                [0.0, 0.0, 1.0],
            ]
        )
        per_frame_intrinsics = intrinsics.repeat(opts.video_length, 1, 1)
        self.K = per_frame_intrinsics.to(opts.device)

In [6]:
# Emulate a command-line invocation so the shared argument parser can be reused
# from inside the notebook.
sys.argv = [""]
sys.argv += ["--video_path", "/home/azhuravl/nobackup/DAVIS_testing/trainval/rhino.mp4"]
sys.argv += ["--n_splits", "4"]
sys.argv += ["--overlap_frames", "0"]
sys.argv += ["--radius", "0"]
sys.argv += ["--mode", "gradual"]

parser = get_parser()
opts_base = parser.parse_args()

# Experiment naming: <video file stem>_<minute-resolution timestamp>_...
video_filename = os.path.basename(opts_base.video_path)
video_basename = os.path.splitext(video_filename)[0]
timestamp = datetime.now().strftime("%Y%m%d_%H%M")

# Base options shared by every run.
opts_base.weight_dtype = torch.bfloat16
opts_base.exp_name = f"{video_basename}_{timestamp}_autoregressive"
opts_base.save_dir = os.path.join(opts_base.out_dir, opts_base.exp_name)

radius = opts_base.radius

# Preset list; only defined here, not consumed by this cell.
variants = [
    ("right_90", [0, 90, radius, 0, 0]),
]
# Other presets that were tried:
#   right_90 -> [0, 90, radius, 0, 0]
#   top_90   -> [90, 0, radius, 0, 0]

# Active target pose; the run name is derived from its five entries
# (e.g. a pose of [120, 0, 0, 0, 3] gives '120_0_0_0_3').
pose = [90, 0, 0, 0, 1]
name = f"{pose[0]}_{pose[1]}_{pose[2]}_{pose[3]}_{pose[4]}"

print(f"\n=== Running Autoregressive {name} ===")

# Per-run options: an independent copy of the base options with run-specific fields.
opts = copy.deepcopy(opts_base)
opts.camera = "target"
opts.target_pose = pose
opts.traj_txt = 'test/trajs/loop2.txt'
opts.exp_name = f"{video_basename}_{timestamp}_{name}_auto_s{opts_base.n_splits}"
opts.save_dir = os.path.join(opts.out_dir, opts.exp_name)

# Make sure the output directory for this run exists.
os.makedirs(opts.save_dir, exist_ok=True)


In [7]:
# Instantiate the model wrapper.
# NOTE(review): this is built from `opts_base`, not the per-run `opts` copy that
# carries camera/target_pose/save_dir — confirm that is intended.
vis_crafter = TrajCrafterAutoregressive(opts_base)

In [2]:
# Hot-reload the point-cloud warper module after editing it on disk.
#
# The module is imported here first so this cell also works on a fresh kernel /
# "Restart & Run All": reloading a name that has never been imported raises
# NameError. Objects created from the module before the reload keep using the
# old class definitions, so re-create them afterwards.
import importlib

import warper_point_cloud

importlib.reload(warper_point_cloud)

In [8]:
import warper_point_cloud

# Warper used below to build and render point clouds. It is the global
# point-cloud implementation; VisualizationWarper was the earlier choice.
funwarp = warper_point_cloud.GlobalPointCloudWarper(
    device=opts.device,
    max_points=2_000_000,
)

In [9]:
import models.utils as utils
from utils_autoregressive import pad_video, generate_traj_specified, clean_single_mask_simple

# Read the input video as a numpy array of frames.
frames_np = utils.read_video_frames(
    opts.video_path,
    opts.video_length,
    opts.stride,
    opts.max_res,
)

# Pad clips that are shorter than the requested length.
frames_np = pad_video(frames_np, opts.video_length)

# Channels-last numpy frames -> channels-first tensor on the target device,
# e.g. (49, 576, 1024, 3) -> (49, 3, 576, 1024).
frames_tensor = torch.from_numpy(frames_np).permute(0, 3, 1, 2).to(opts.device)
# Affine rescale x -> 2x - 1 (presumably [0, 1] -> [-1, 1]; verify the reader's output range).
frames_tensor = frames_tensor * 2.0 - 1.0
assert frames_tensor.shape[0] == opts.video_length

# Caption the clip from its middle frame and store it on the model wrapper.
middle_frame = frames_np[opts.video_length // 2]
vis_crafter.prompt = vis_crafter.get_caption(opts, middle_frame)

In [10]:
# ------------------------------------------------------------------
# Input depth
# ------------------------------------------------------------------
# TODO: the estimator receives the frames at their loaded resolution
# (1024 x 576?), while the sample will be about 640 x 360 - is that ok?
depths = vis_crafter.depth_estimater.infer(
    frames_np,
    opts.near,
    opts.far,
    opts.depth_inference_steps,
    opts.depth_guidance_scale,
    window_size=opts.window_size,
    overlap=opts.overlap,
)
# Keep the depth maps on the same device as the rest of the pipeline.
depths = depths.to(opts.device)

In [11]:
# ------------------------------------------------------------------
# Cameras
# ------------------------------------------------------------------

# Radius = depth at the centre pixel of the first frame, scaled by
# opts.radius_scale and capped at 5.
center_row = depths.shape[-2] // 2
center_col = depths.shape[-1] // 2
radius = depths[0, 0, center_row, center_col].cpu() * opts.radius_scale
radius = min(radius, 5)

# Anchor camera-to-world pose: x and z axes flipped, no translation; leading batch dim of 1.
c2ws_anchor = torch.tensor(
    [
        [-1.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, -1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
)
c2ws_anchor = c2ws_anchor.unsqueeze(0).to(opts.device)

# One pose per frame across all splits, generated from the anchor and the target pose.
total_frames = opts.video_length * opts.n_splits
c2ws_target = generate_traj_specified(
    c2ws_anchor,
    opts.target_pose,
    total_frames,
    opts.device,
)
# Shift every pose along z by the radius.
c2ws_target[:, 2, 3] += radius

# Source cameras: the first target pose repeated for every input frame.
c2ws_init = c2ws_target[0].repeat(opts.video_length, 1, 1)

# The same trajectory grouped as (split, frame, 4, 4).
traj_segments = c2ws_target.view(opts.n_splits, opts.video_length, 4, 4)


In [14]:
import copy

# Per-frame point-cloud buffers, plus a second pair of lists that receives the
# same points/colors.
point_clouds, colors_list, weights_list = [], [], []
global_pc, global_colors = [], []

for i in tqdm(range(opts.video_length)):
    # Length-1 slice keeps the leading frame dimension on every input.
    frame_sel = slice(i, i + 1)

    with torch.no_grad():
        points, colors, weights = funwarp.create_pointcloud_from_image(
            frames_tensor[frame_sel],
            None,
            depths[frame_sel],
            c2ws_init[frame_sel],
            vis_crafter.K[frame_sel],
            1,
        )

    point_clouds.append(points)
    global_pc.append(points)

    colors_list.append(colors)
    global_colors.append(colors)

    weights_list.append(weights)
    
    

In [17]:
# Render each frame's point cloud from its target camera and clean up the mask.
warped_images, warped_depths, masks = [], [], []

for i in tqdm(range(opts.video_length)):
    # Intrinsics of frame 0 are used for every render.
    intrinsics = vis_crafter.K[0:1].to(opts.device)

    warped_image, mask, warped_depth = funwarp.render_pointcloud_zbuffer_vectorized_point_size(
        point_clouds[i],
        colors_list[i],
        c2ws_target[i:i+1],
        intrinsics,
        (576, 1024),
        point_size=2,
        return_depth=True,
    )

    # Clean the first mask of the batch (one erosion and one dilation step,
    # 9-pixel kernel), then restore the leading dimension.
    cleaned_mask = clean_single_mask_simple(
        mask[0],
        kernel_size=9,
        n_erosion_steps=1,
        n_dilation_steps=1,
    ).unsqueeze(0)

    # Apply the cleaned mask to both outputs (image values should stay in [-1, 1]).
    warped_image = warped_image * cleaned_mask
    warped_depth = warped_depth * cleaned_mask

    masks.append(cleaned_mask)
    warped_images.append(warped_image)
    warped_depths.append(warped_depth)


In [22]:
# Sanity check: value range of the cleaned mask left over from the last loop iteration.
cleaned_mask.min(), cleaned_mask.max()

In [23]:
# Sanity check: shape of the first warped depth map and value range of the one at index 10.
warped_depths[0].shape, warped_depths[10].min(), warped_depths[10].max()

In [21]:
# Side-by-side view of frame 30: warped image, warped depth, and the original input frame.
import matplotlib.pyplot as plt

fig, (ax_warped, ax_depth, ax_input) = plt.subplots(1, 3, figsize=(20, 6))

# Warped RGB, mapped from [-1, 1] back to [0, 1] for display.
warped_rgb = (warped_images[30][0].permute(1, 2, 0).cpu().numpy() + 1) / 2
ax_warped.imshow(warped_rgb)
ax_warped.axis('off')

# Warped depth with its own colorbar.
depth_map = warped_depths[30][0].permute(1, 2, 0).cpu().numpy()
depth_artist = ax_depth.imshow(depth_map)
ax_depth.axis('off')
fig.colorbar(depth_artist, ax=ax_depth, shrink=0.5)

# Original input frame for reference.
input_rgb = (frames_tensor[30].permute(1, 2, 0).cpu().numpy() + 1) / 2
ax_input.imshow(input_rgb, cmap='gray')
ax_input.axis('off')

plt.show()