In [None]:
# Add /home/azhuravl/work to the module search path so that packages located
# there can be imported below. NOTE(review): machine-specific absolute path.
import sys
sys.path.append('/home/azhuravl/work')

In [None]:
import stereoanyvideo.datasets.video_datasets as video_datasets

In [None]:
# Re-execute the video_datasets module so that edits made to its source are
# picked up without restarting the kernel.
import importlib
importlib.reload(video_datasets)

In [None]:
# Build the camera-aware SceneFlow sequence dataset with only the Monkaa subset
# enabled (add_monkaa=True; add_things / add_driving / things_test are False),
# no augmentation, the "frames_cleanpass" data type and the "test" split.
# NOTE(review): sample_len=49 is presumably the number of frames per sample — confirm.
dataset_monkaa = video_datasets.SequenceSceneFlowDatasetCamera(
    aug_params=None,
    root="/home/azhuravl/nobackup/SceneFlow",
    dstype="frames_cleanpass",
    sample_len=49,
    things_test=False,
    add_things=False,
    add_monkaa=True,
    add_driving=False,
    split="test"
)

In [None]:
# Fetch the first sample of the dataset; all cells below work on this sample.
data_0 = dataset_monkaa[0]

In [None]:
# List the fields available in the sample.
data_0.keys()

In [None]:
# Camera entry of the first frame; index 0 within a frame is the left camera.
first_frame_cameras = data_0["viewpoint"][0]
viewpoint = first_frame_cameras[0]

# Translate the PyTorch3D-style camera into OpenCV-style parameters
# for a (540, 960) image.
opencv_params = video_datasets.pytorch3d_to_opencv_camera_general(
    viewpoint,
    (540, 960),
)

# Unpack the 3x3 intrinsic matrix, the 3x3 rotation matrix and the
# 3x1 translation vector in one step.
K, R, t = (opencv_params[key] for key in ('K', 'R', 't'))

In [None]:
# calculate depth from disparity using torch
# depth = -(focal_length * baseline) / disparity, evaluated for a single
# disparity map: first frame, first view of the sample.

import torch
disp = data_0['disp'][0][0]
focal_length = K[0, 0]  # top-left entry of the intrinsic matrix K
baseline = 1  # NOTE(review): placeholder value; depth comes out in units of the baseline — confirm the real baseline
depth = -(focal_length * baseline) / disp  # NOTE(review): the leading minus presumably compensates for negative stored disparities — confirm

In [None]:
# Show RGB, disparity and depth of the first frame side by side.
import matplotlib.pyplot as plt

# One entry per panel: (title, image array, colormap).
# A colormap of None marks the plain RGB panel, which gets no colorbar.
panels = [
    ('RGB Frame 0', data_0['img'][0][0].permute(1, 2, 0).int().numpy(), None),
    ('Disparity Frame 0', data_0['disp'][0][0].permute(1, 2, 0).numpy(), 'plasma'),
    ('Depth Frame 0', depth.permute(1, 2, 0).numpy(), 'plasma'),
]

plt.figure(figsize=(12,6))
for column, (title, image, cmap) in enumerate(panels, start=1):
    plt.subplot(1, 3, column)
    plt.imshow(image, cmap=cmap)
    plt.title(title)
    plt.axis('off')
    if cmap is not None:
        plt.colorbar()
plt.show()

In [None]:
# Extend the module search path with the notebook helper directory and import
# the point-cloud warper module from it (presumably located there — the import
# would otherwise have to resolve from an earlier path entry).
# NOTE(review): machine-specific absolute path.
import sys
sys.path.append('/home/azhuravl/work/TrajectoryCrafter/notebooks/06_10_25_vggt')

import warper_point_cloud

In [None]:
# Point-cloud warper configured for the 'cuda' device; used below to build and
# render per-frame point clouds.
warper = warper_point_cloud.GlobalPointCloudWarper(device='cuda')

In [None]:
# Inspect the shape of the raw image tensor of the sample.
data_0['img'].shape

In [None]:
def extract_video_data(data, baseline=1, image_size=(540, 960)):
    """
    Extract frames, depths, poses, and camera intrinsics from data object.

    Only index 0 along the second axis of 'img' / 'disp', and index 0 of each
    per-frame 'viewpoint' entry, is used.

    Args:
        data: Data object containing 'img', 'disp', and 'viewpoint'
        baseline: Baseline for depth calculation (default: 1)
        image_size: Image size forwarded unchanged to
            video_datasets.pytorch3d_to_opencv_camera_general
            (default: (540, 960), the value that used to be hard-coded)

    Returns:
        frames_tensor: [T, 3, H, W] in [-1, 1] range
        depths: [T, 1, H, W] depth maps, computed per frame as
            -(focal_length * baseline) / disparity
        poses_tensor: [T, 4, 4] camera poses
        K_tensor: [T, 3, 3] camera intrinsics
    """
    # Convert to [-1, 1] range (assumes 'img' holds values in [0, 255])
    frames_tensor = data['img'][:, 0] / 127.5 - 1.0  # [T, 3, H, W]
    disparity_tensor = data['disp'][:, 0]  # [T, 1, H, W]

    poses_list = []
    K_list = []
    for i in range(frames_tensor.shape[0]):
        opencv_params = video_datasets.pytorch3d_to_opencv_camera_general(
            data["viewpoint"][i][0], image_size
        )

        # Assemble the homogeneous 4x4 matrix [[R, t], [0, 0, 0, 1]]
        pose = torch.eye(4)
        pose[:3, :3] = opencv_params['R']
        pose[:3, 3] = opencv_params['t'].squeeze()
        poses_list.append(pose)

        K_list.append(opencv_params['K'])

    poses_tensor = torch.stack(poses_list)  # [T, 4, 4]
    K_tensor = torch.stack(K_list)  # [T, 3, 3]

    # Use every frame's own focal length instead of frame 0's for all frames,
    # so sequences with varying intrinsics get correct depths. The values are
    # moved to the disparity's device/dtype because, unlike the 0-dim scalar
    # used before, a 1-D tensor does not mix implicitly across devices or
    # leave the result dtype untouched.
    focal_lengths = K_tensor[:, 0, 0].to(disparity_tensor.device)
    if disparity_tensor.is_floating_point():
        focal_lengths = focal_lengths.to(disparity_tensor.dtype)
    # Shape [T, 1, ..., 1] so it broadcasts over all non-time axes.
    focal_lengths = focal_lengths.reshape(-1, *([1] * (disparity_tensor.dim() - 1)))

    # NOTE(review): the leading minus presumably compensates for negative
    # stored disparities — confirm. Zero disparity yields +/-inf depth.
    depths = -(focal_lengths * baseline) / disparity_tensor  # [T, 1, H, W]

    return frames_tensor, depths, poses_tensor, K_tensor

In [None]:
import torch

# Unpack the sample into normalised frames, depth maps, 4x4 camera poses and
# 3x3 intrinsics (baseline left at its default of 1).
frames_tensor, depths, poses_tensor, K_tensor = extract_video_data(data_0)

In [None]:
# Inspect the shape of the normalised frame tensor.
frames_tensor.shape

In [None]:
from tqdm import tqdm

# Per-frame point clouds and their colours, kept in frame order.
pc_list, color_list = [], []

num_frames = frames_tensor.shape[0]

# Nothing here is trained, so autograd bookkeeping is disabled for the pass.
with torch.no_grad():
    for i in tqdm(range(num_frames)):
        # A length-1 slice keeps the leading dimension on every input.
        one_frame = slice(i, i + 1)
        # NOTE(review): the roles of the `None` and `1` positional arguments
        # are defined by the warper — confirm against its signature.
        points, colors, _ = warper.create_pointcloud_from_image(
            frames_tensor[one_frame],
            None,
            depths[one_frame],
            poses_tensor[one_frame],
            K_tensor[one_frame],
            1,
        )
        pc_list += [points]
        color_list += [colors]

In [None]:
# `utils_ar` is otherwise only imported in the later "Run Diffusion" cell, so
# on a fresh top-to-bottom run this cell failed with a NameError. Import it
# here; its directory is already on sys.path from the warper import cell.
import utils_autoregressive as utils_ar

warped_images = []
masks = []

# Every point cloud is rendered into the camera of frame 0, so that camera's
# pose and intrinsics are moved to the GPU once instead of on every iteration.
target_pose = poses_tensor[0:1].to('cuda')
target_K = K_tensor[0:1].to('cuda')

for i in tqdm(range(frames_tensor.shape[0])):

    warped_image, mask = warper.render_pointcloud_zbuffer_vectorized_point_size(
        pc_list[i],
        color_list[i],
        target_pose,
        target_K,
        (540, 960),
        point_size=2,
    )

    # One erosion step followed by one dilation step with a 9-pixel kernel.
    cleaned_mask = utils_ar.clean_single_mask_simple(
        mask[0],
        kernel_size=9,
        n_erosion_steps=1,
        n_dilation_steps=1
        )
    # should stay in [-1, 1] range

    # Restore the leading dimension dropped by mask[0], then blank out the
    # pixels that the cleaned mask rejects.
    cleaned_mask = cleaned_mask.unsqueeze(0)
    warped_image = warped_image * cleaned_mask

    warped_images.append(warped_image)
    masks.append(cleaned_mask)


In [None]:
import matplotlib.pyplot as plt

# Index of the frame to inspect. The panel title is built from this value so
# the label always matches the data shown (it used to say "Frame 10" while
# frame 40 was plotted).
frame_idx = 40

# show image and mask
plt.figure(figsize=(12,6))
plt.subplot(1,2,1)
# Map the image from [-1, 1] to [0, 1] for display.
plt.imshow(warped_images[frame_idx][0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)
plt.title(f'Warped Image Frame {frame_idx}')
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(masks[frame_idx][0].cpu().permute(1, 2, 0).numpy(), cmap='gray')
plt.title('Mask')
plt.axis('off')
plt.show()

In [None]:
# Plot the first frame next to a later frame of the original sequence.
# The later frame's index is kept in one variable so the title cannot disagree
# with the data shown (it used to say "Frame 10" while frame 40 was plotted).
frame_idx = 40

plt.figure(figsize=(12,6))
plt.subplot(1,2,1)
# Map the frames from [-1, 1] to [0, 1] for display.
plt.imshow(frames_tensor[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)
plt.title('Original Image Frame 0')
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(frames_tensor[frame_idx].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)
plt.title(f'Original Image Frame {frame_idx}')
plt.axis('off')
plt.show()

## Run Diffusion

In [None]:
# Make the notebook helper directory importable (the same path is already
# appended in the warper import cell, so this adds a duplicate entry) and pull
# in the option parser and the autoregressive utilities.
sys.path.append('/home/azhuravl/work/TrajectoryCrafter/notebooks/06_10_25_vggt')
from parsing import get_parser
import utils_autoregressive as utils_ar
from datetime import datetime
import os
import copy


# Emulate a command-line invocation by overwriting sys.argv before parsing.
# NOTE(review): presumably get_parser() returns an argparse parser whose
# parse_args() reads sys.argv when called without arguments — confirm.
sys.argv = [
    "",
    "--video_path", "/home/azhuravl/nobackup/DAVIS_testing/trainval/monkaa.mp4",
    "--n_splits", "4",
    "--overlap_frames", "0",
    "--radius", "0",
    "--mode", "gradual",
]

parser = get_parser()
opts_base = parser.parse_args()

# Timestamp and input-video file stem, used to build the experiment names.
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
video_basename = os.path.splitext(os.path.basename(opts_base.video_path))[0]

# Setup
opts_base.weight_dtype = torch.bfloat16
opts_base.exp_name = f"{video_basename}_{timestamp}_autoregressive"
opts_base.save_dir = os.path.join(opts_base.out_dir, opts_base.exp_name)

# Radius as parsed from the options above.
radius = opts_base.radius

# NOTE(review): `variants` is not referenced again in this notebook — confirm
# whether it is still needed.
variants = [
    ("right_90", [0, 90, radius, 0, 0]),
]

# Target pose; its five components are joined into the experiment name.
# NOTE(review): the meaning of the components is defined downstream — confirm.
pose = [90, 0, 0, 0, 1]
name = f"{pose[0]}_{pose[1]}_{pose[2]}_{pose[3]}_{pose[4]}"

# Work on a deep copy so opts_base itself stays unchanged, then override the
# experiment name, output directory, camera mode, target pose and trajectory file.
opts = copy.deepcopy(opts_base)
opts.exp_name = f"{video_basename}_{timestamp}_{name}_auto_s{opts_base.n_splits}"
opts.save_dir = os.path.join(opts.out_dir, opts.exp_name)
opts.camera = "target"
opts.target_pose = pose
opts.traj_txt = 'test/trajs/loop2.txt'  # relative path: resolved against the kernel's working directory

# Make directories
os.makedirs(opts.save_dir, exist_ok=True)

In [None]:
# Re-execute utils_autoregressive so that source edits are picked up without
# restarting the kernel.
importlib.reload(utils_ar)

In [None]:
# Instantiate the autoregressive TrajCrafter wrapper from the prepared options.
trajcrafter = utils_ar.TrajCrafterAutoregressive(opts)

In [None]:
import numpy as np

# Undo the tensor layout and normalisation: channels-first frames in [-1, 1]
# become a channels-last float32 NumPy array in [0, 1], i.e. the inverse of
# `permute(0, 3, 1, 2)` followed by `* 2.0 - 1.0`.
frames_channels_last = frames_tensor.cpu().permute(0, 2, 3, 1).numpy()
frames_np = ((frames_channels_last + 1.0) / 2.0).astype(np.float32)

# Caption the frame at index opts.video_length // 2 and keep the caption as
# the prompt on the TrajCrafter object.
caption_frame = frames_np[opts.video_length // 2]
trajcrafter.prompt = trajcrafter.get_caption(opts, caption_frame)
print(trajcrafter.prompt)

In [None]:
# Run the diffusion sampling step; the first return value is discarded and the
# second is kept as `segment_dir`.
# NOTE(review): the argument roles are defined by utils_ar.sample_diffusion.
# The fourth argument passes only the first 10 frames — confirm that this is
# the intended input.
_, segment_dir = utils_ar.sample_diffusion(
    trajcrafter,
    frames_tensor,
    warped_images,
    frames_tensor[:10],
    masks,
    opts,
)

In [None]:
# Grab the features stored on the pipeline object (presumably populated during
# the sampling call — confirm).
collected_features = trajcrafter.pipeline.collected_features

In [None]:
# Print the name and shape of every feature stored for timestep 839.
for feature_name, feature in collected_features['timestep_839'].items():
    print(feature_name, '           ', feature.shape)

In [None]:
# Scratch arithmetic: two element counts displayed side by side for comparison.
# NOTE(review): the meaning of the constants is not stated in the notebook —
# presumably pixels-per-clip versus a feature map size; confirm or remove.
384 * 672*49, 13104 * 3072

In [None]:
# Total memory footprint of all collected features, reported in MB:
# bytes per element times number of elements, summed over every feature of
# every timestep.
total_size = sum(
    feature.element_size() * feature.nelement()
    for features_at_step in collected_features.values()
    for feature in features_at_step.values()
)
print(f"Total size of collected features: {total_size / (1024 ** 2):.2f} MB")

In [None]:
# Inspect the sample size configured in the options.
opts.sample_size

In [None]:
# TODO
# - fix cameras: why does the warped view drift up and to the left?
# - how do we get GT depth?
# - why do we have so many features?
# - generate 100 samples of features + depths