In [None]:
import os
import numpy as np
import glob
from PIL import Image
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation


def read_tartanair_extrinsic(extrinsic_path, side='left'):
    """
    Read a TartanAir pose file and return one 4x4 matrix per pose line.

    Every non-blank line must contain 7 whitespace-separated floats in the
    order ``tx ty tz qx qy qz qw`` (quaternion is scalar-last, as expected by
    ``Rotation.from_quat``). For each line the rigid transform ``[R | t]`` is
    inverted to ``[R^T | -R^T t]`` and its rows are then permuted to go from
    the NED convention (z-axis down) to a z-axis-forward convention.

    Args:
        extrinsic_path: Path to the pose text file (e.g. ``pose_left.txt``).
        side: Unused by this function; kept so existing callers keep working.

    Returns:
        list of ``(4, 4)`` float64 numpy arrays, one per pose, in file order.

    Raises:
        AssertionError: If a non-blank line does not hold exactly 7 values.
        ValueError: If a value cannot be parsed as a float.
    """
    # ned(z-axis down) to z-axis forward: new rows (x, y, z) are taken from
    # old rows (y, z, x). The matrix is constant, so build it once.
    m_correct = np.zeros((4, 4))
    m_correct[0, 1] = 1
    m_correct[1, 2] = 1
    m_correct[2, 0] = 1
    m_correct[3, 3] = 1

    data = []
    with open(extrinsic_path, 'r') as fp:
        for line in fp:
            # split() without an argument tolerates repeated spaces and tabs.
            values = line.split()
            if not values:
                # Skip blank lines (e.g. a trailing newline at end of file).
                continue
            assert len(values) == 7, 'Pose must be quaterion format -- 7 params, but {} got'.format(len(values))
            tx, ty, tz, qx, qy, qz, qw = (float(v) for v in values)

            R = Rotation.from_quat((qx, qy, qz, qw)).as_matrix()
            t = np.array([tx, ty, tz])

            # Inverse of the rigid transform [R | t].
            T = np.eye(4)
            T[:3, :3] = R.T
            T[:3, 3] = -R.T.dot(t)

            data.append(np.matmul(m_correct, T))

    return data


def read_tartanair_sequence(sequence_path, max_frames=50):
    """
    Read TartanAir sequence data for debugging.

    Loads up to ``max_frames`` left-camera RGB images, depth maps and poses
    from one sequence directory and prints a summary of what was found.
    Images, depths and poses are truncated independently, so their counts can
    differ when files are missing.

    Args:
        sequence_path: Path to sequence directory (e.g., /path/to/abandonedfactory/Easy/P001)
        max_frames: Maximum number of frames to read

    Returns:
        dict with 'images', 'depths', 'poses', 'intrinsics', 'image_files',
        'depth_files' and 'pose_file'

    Raises:
        FileNotFoundError: If ``pose_left.txt`` does not exist (propagated
            from ``read_tartanair_extrinsic``).
    """

    # 1. Find RGB images (left camera only)
    image_dir = os.path.join(sequence_path, "image_left")
    image_files = sorted(glob.glob(os.path.join(image_dir, "*_left.png")))[:max_frames]
    print(f"Found {len(image_files)} RGB images")
    if len(image_files) > 0:
        print(f"First image: {image_files[0]}")
        print(f"Last image: {image_files[-1]}")

    # 2. Find depth files (left camera only)
    depth_dir = os.path.join(sequence_path, "depth_left")
    depth_files = sorted(glob.glob(os.path.join(depth_dir, "*_left_depth.npy")))[:max_frames]
    print(f"Found {len(depth_files)} depth files")
    if len(depth_files) > 0:
        print(f"First depth: {depth_files[0]}")
        print(f"Last depth: {depth_files[-1]}")

    # 3. Find pose file (left camera)
    pose_file = os.path.join(sequence_path, "pose_left.txt")
    print(f"Pose file exists: {os.path.exists(pose_file)}")
    print(f"Pose file path: {pose_file}")

    # 4. Load RGB images
    images = []
    for img_file in image_files:
        # The context manager closes the underlying file handle once the
        # pixels have been copied into the numpy array.
        with Image.open(img_file) as pil_img:
            img = np.array(pil_img)
        images.append(img)
        print(f"Image {len(images)}: shape={img.shape}, dtype={img.dtype}")
        if len(images) == 1:  # Show info for first image only
            print(f"  Min={img.min()}, Max={img.max()}")

    # 5. Load depth files
    depths = []
    for depth_file in depth_files:
        depth = np.load(depth_file)
        depths.append(depth)
        print(f"Depth {len(depths)}: shape={depth.shape}, dtype={depth.dtype}")
        if len(depths) == 1:  # Show info for first depth only
            print(f"  Min={depth.min():.3f}, Max={depth.max():.3f}")

    # 6. Load poses (the whole file is parsed, then truncated to max_frames)
    poses = read_tartanair_extrinsic(pose_file, side='left')[:max_frames]

    # 7. TartanAir intrinsics (fixed for all sequences)
    K = np.array([[320.0, 0, 320.0],
                  [0, 320.0, 240.0],
                  [0, 0, 1]], dtype=np.float32)

    print(f"\nLoaded:")
    print(f"  {len(images)} RGB images")
    print(f"  {len(depths)} depth maps")
    print(f"  {len(poses)} poses")
    print(f"  Intrinsics K:\n{K}")

    return {
        'images': images,
        'depths': depths,
        'poses': poses,
        'intrinsics': K,
        'image_files': image_files,
        'depth_files': depth_files,
        'pose_file': pose_file
    }

def visualize_data(data, frame_idx=0):
    """
    Show the RGB image and depth map of one frame side by side, then print
    the frame's pose matrix (translation column and 3x3 rotation block).

    Args:
        data: dict with 'images', 'depths' and 'poses' sequences.
        frame_idx: Index of the frame to display.
    """
    fig, (ax_rgb, ax_depth) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: RGB image
    ax_rgb.imshow(data['images'][frame_idx])
    ax_rgb.set_title(f'RGB Frame {frame_idx}')
    ax_rgb.axis('off')

    # Right panel: depth map with its colorbar
    depth_im = ax_depth.imshow(data['depths'][frame_idx], cmap='viridis')
    ax_depth.set_title(f'Depth Frame {frame_idx}')
    ax_depth.axis('off')
    plt.colorbar(depth_im, ax=ax_depth)

    plt.tight_layout()
    plt.show()

    # Pose info is optional: there may be fewer poses than images.
    if frame_idx >= len(data['poses']):
        return

    pose = data['poses'][frame_idx]
    print(f"Frame {frame_idx} pose:")
    print(f"  Camera position: {pose[:3, 3]}")
    print(f"  Camera rotation matrix:\n{pose[:3, :3]}")


In [None]:
# Replace with your actual path
sequence_path = "/home/azhuravl/scratch/tartanair/abandonedfactory/Easy/P001"

# Load data
data = read_tartanair_sequence(sequence_path, max_frames=50)

# Check data consistency: images, depths and poses are truncated
# independently by the loader, so verify that the counts agree.
n_images, n_depths, n_poses = (len(data[name]) for name in ('images', 'depths', 'poses'))

print(f"\nData consistency check:")
print(f"  Images: {n_images}")
print(f"  Depths: {n_depths}")
print(f"  Poses: {n_poses}")

if not (n_images == n_depths == n_poses):
    print("✗ Data length mismatch!")
else:
    print("✓ All data lengths match!")

In [None]:
print(data.keys())

# Pull the per-frame lists and the intrinsics out of the loader's dict.
depths, images, poses, K = (
    data[name] for name in ('depths', 'images', 'poses', 'intrinsics')
)

In [None]:
# torch is first used in this cell; import it here so the notebook survives
# Restart Kernel -> Run All (the only other import is in a much later cell).
import torch

# Convert the lists of per-frame numpy arrays into single stacked arrays.
depths = np.stack(depths, axis=0)
images = np.stack(images, axis=0)
poses = np.stack(poses, axis=0)
# K is already one (3, 3) array rather than a list; just take a copy.
K = np.array(K)


print(f"Depths shape: {depths.shape}, dtype: {depths.dtype}")
print(f"Images shape: {images.shape}, dtype: {images.dtype}")
print(f"Poses shape: {poses.shape}, dtype: {poses.dtype}")
print(f"K shape: {K.shape}, dtype: {K.dtype}")

depths_tensor = torch.from_numpy(depths).unsqueeze(1)  # Convert to (N, 1, H, W)
frames_tensor = torch.from_numpy(images).permute(0, 3, 1, 2)  # Convert to (N, C, H, W)
poses_tensor = torch.from_numpy(poses)
K_tensor = torch.from_numpy(K)

# Normalize pixel values to [-1, 1] (assumes 8-bit images in [0, 255]).
frames_tensor = frames_tensor.float() / 127.5 - 1.0


## Try TartanAir dataset

In [None]:
import sys
sys.path.append('/home/azhuravl/work')

import stereoanyvideo.datasets.video_datasets as video_datasets

In [None]:
import importlib
# Re-import video_datasets so edits to the module source are picked up
# without restarting the kernel. Objects created before the reload keep
# referring to the old class definitions.
importlib.reload(video_datasets)

In [12]:
# Sequences (relative to the dataset root) to build the dataset from.
# Only the uncommented entries are passed on.
train_sequences = [
    # 'abandonedfactory/Easy/P001',
    # 'abandonedfactory/Easy/P005',
    # 'office/Easy/P001',
    'office/Easy/P002',
    # 'office2/Easy/P001'
]

# NOTE(review): the "(default ...)" remarks below describe TartanAirDataset's
# signature, which is not visible in this notebook -- confirm against
# stereoanyvideo.datasets.video_datasets.
dataset_tartanair = video_datasets.TartanAirDataset(
        aug_params=None,
        root="/home/azhuravl/scratch/tartanair",
        split="train",
        sample_len=59,
        only_first_n_samples=-1,
        sampling_stride=3,          # Starting frame stride (default 3)
        min_temporal_step=1,        # Minimum temporal step (default 1)
        max_temporal_step=2,        # Maximum temporal step (default 6)
        train_sequences=train_sequences
    )

In [14]:
# Fetch a single sample (index 100) to inspect what the dataset returns.
data_0 = dataset_tartanair[100]

In [None]:
# Display the sample's 'metadata' entry.
data_0['metadata']

In [None]:
# Display the sample's 'RTK' entry; extract_tartanair reads the poses from
# RTK[0] and the intrinsics from RTK[1].
data_0['RTK']

In [15]:
def extract_tartanair(data):
    """
    Pull frames, depths, poses and intrinsics out of one dataset sample.

    Args:
        data: Mapping with the keys 'img', 'depth' and 'RTK'. Index 0 along
            the second axis of 'img' and 'depth' is selected (presumably the
            left camera -- confirm against the dataset class). 'RTK' is
            indexed as (poses, intrinsics).

    Returns:
        frames_tensor: ``data['img'][:, 0]`` rescaled by ``/ 127.5 - 1.0``
            (expected [T, 3, H, W] in [-1, 1] for 8-bit input)
        depths_tensor: ``data['depth'][:, 0:1]`` (expected [T, 1, H, W])
        poses_tensor: ``data['RTK'][0]`` (expected [T, 4, 4])
        K_tensor: ``data['RTK'][1]`` (expected [3, 3])
    """
    # Select index 0 on axis 1 and map pixel values from [0, 255] to [-1, 1].
    first_view_frames = data['img'][:, 0]
    frames_tensor = first_view_frames / 127.5 - 1.0

    # Slicing with :1 (instead of indexing with 0) keeps the channel axis.
    depths_tensor = data['depth'][:, :1]

    rtk = data['RTK']
    poses_tensor, K_tensor = rtk[0], rtk[1]

    return frames_tensor, depths_tensor, poses_tensor, K_tensor



In [16]:
# Replace the tensors built from the raw sequence with those of the dataset
# sample; the warping cells below operate on these.
frames_tensor, depths_tensor, poses_tensor, K_tensor = extract_tartanair(data_0)

## Warp

In [8]:
# Make the TrajectoryCrafter checkout importable (relies on `sys` having been
# imported in an earlier cell).
sys.path.append('/home/azhuravl/work/TrajectoryCrafter')

import models.utils as utils

# Warper used for forward-warping frames between camera poses; constructed
# for the CUDA device.
warper_old = utils.Warper(device='cuda')

In [17]:
from tqdm import tqdm

# Index of the reference view. Every frame from this index onwards is
# forward-warped into the camera of frame `ref_idx`, so entry n of the lists
# below corresponds to source frame `ref_idx + n`.
ref_idx = 10

# Loop-invariant inputs: the target pose and the batched intrinsics.
ref_pose = poses_tensor[ref_idx:ref_idx + 1]
K_batched = K_tensor.unsqueeze(0)

warped_images = []
masks = []
warped_depths = []

for i in tqdm(range(ref_idx, frames_tensor.shape[0])):
    # The fourth return value (flow12 in the original cell) is not used here.
    warped_frame2, mask2, warped_depth2, _flow12 = warper_old.forward_warp(
        frame1=frames_tensor[i:i+1],
        mask1=None,
        depth1=depths_tensor[i:i+1],
        transformation1=poses_tensor[i:i+1],
        transformation2=ref_pose,
        intrinsic1=K_batched,
        intrinsic2=K_batched,
        mask=False,
        twice=True,
    )
    warped_images.append(warped_frame2)
    masks.append(mask2)
    warped_depths.append(warped_depth2)

In [18]:
import numpy as np
import matplotlib.pyplot as plt

# The warping loop starts at frame 10 and warps every frame into the view of
# frame 10, so warped_images[k] is frame (ref_idx + k) seen from frame
# ref_idx. Keep ref_idx in sync with the warping cell.
ref_idx = 10
j = 0   # index used by the commented-out depth/mask plots below
k = 30  # index into warped_images / masks / warped_depths

src_idx = ref_idx + k


def _to_display(chw_tensor):
    """Convert a [3, H, W] tensor in [-1, 1] to an [H, W, 3] array in [0, 1]."""
    return chw_tensor.cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5


frame = _to_display(frames_tensor[src_idx])         # frame that was warped
warped_image = _to_display(warped_images[k][0])     # its warp into the reference view
target_frame = _to_display(frames_tensor[ref_idx])  # what the reference camera really saw

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

axes[0].imshow(frame)
axes[0].set_title('Source Frame {}'.format(src_idx))
axes[0].axis('off')

axes[1].imshow(warped_image)
axes[1].set_title('Frame {} Warped to Frame {}'.format(src_idx, ref_idx))
axes[1].axis('off')

axes[2].imshow(target_frame)
axes[2].set_title('Target Frame {}'.format(ref_idx))
axes[2].axis('off')

# plt.subplot(1,3,1)
# plt.imshow(depths[10+j].cpu().permute(1, 2, 0).numpy(), cmap='plasma')
# plt.title('Warped Image to Frame {}'.format(j))
# plt.axis('off')
# plt.colorbar()

# plt.subplot(1,3,2)
# plt.imshow((warped_depths[j][0].cpu().permute(1, 2, 0).numpy() + 1e-2), cmap='plasma')
# plt.title('Warped Image to Frame {}'.format(j))
# plt.axis('off')
# plt.colorbar()

# plt.subplot(1,3,3)
# plt.imshow(masks[j][0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5, cmap='gray')
# plt.title('Mask')
# plt.axis('off')
# plt.show()


In [19]:
from models.utils import save_video
import os
import torch

def make_dimensions_even(tensor):
    """
    Zero-pad a channels-last video tensor so height and width are even.

    Args:
        tensor: 4-D tensor laid out as [T, H, W, C].

    Returns:
        The input tensor itself when H and W are already even; otherwise a
        new tensor with one extra row at the bottom and/or one extra column
        on the right.
    """
    _frames, height, width, _channels = tensor.shape
    extra_rows, extra_cols = height % 2, width % 2

    if not (extra_rows or extra_cols):
        return tensor

    # F.pad takes (left, right) pairs starting from the LAST dimension:
    # C: (0, 0), W: (0, extra_cols), H: (0, extra_rows).
    padding = (0, 0, 0, extra_cols, 0, extra_rows)
    return torch.nn.functional.pad(tensor, padding)


# All debug videos of this section go into one directory; create it up front
# so the save_video calls do not depend on it already existing.
output_dir = '/home/azhuravl/work/TrajectoryCrafter/notebooks/22_10_25_scaling_up'
os.makedirs(output_dir, exist_ok=True)


def _depth_to_video(depth):
    """Turn a [T, 1, H, W] depth tensor into an even-sized [T, H, W, 3] video scaled by its max."""
    return make_dimensions_even(
        depth.permute(0, 2, 3, 1).repeat(1, 1, 1, 3) / depth.max()
    )


# --- warped frames (conditioning video) ---
cond_video = (torch.cat(warped_images) + 1.0) / 2.0  # [T, 3, H, W] in [0,1]
cond_video_padded = make_dimensions_even(
    cond_video.permute(0, 2, 3, 1)
)
save_video(
    cond_video_padded,
    os.path.join(output_dir, 'warped_video.mp4'),
    fps=10,
)

# --- save inputs for visualization ---
# Frames from index 10 onwards, i.e. the same frames that were warped.
input_video_padded = make_dimensions_even(
    (frames_tensor[10:].permute(0, 2, 3, 1) + 1.0) / 2.0
)
save_video(
    input_video_padded,
    os.path.join(output_dir, 'input_video.mp4'),
    fps=10,
)

# --- warped depths ---
warped_depths_tensor = torch.cat(warped_depths)
warped_depths_padded = _depth_to_video(warped_depths_tensor)
save_video(
    warped_depths_padded,
    os.path.join(output_dir, 'warped_depths.mp4'),
    fps=10,
)

# --- input depths ---
# NOTE(review): this uses ALL frames of depths_tensor, whereas the other
# videos start at frame 10 -- confirm that the different length is intended.
depths_padded = _depth_to_video(depths_tensor)
save_video(
    depths_padded,
    os.path.join(output_dir, 'input_depths.mp4'),
    fps=10,
)

## Sample diffusion

In [None]:
# Build the option objects for the autoregressive TrajectoryCrafter run.
# Statement order matters: sys.argv must be set before parse_args(), and
# opts_base must be fully configured before it is deep-copied into opts.
# Relies on `sys` and `torch` having been imported in earlier cells.
sys.path.append('/home/azhuravl/work/TrajectoryCrafter/notebooks/06_10_25_vggt')
from parsing import get_parser
import utils_autoregressive as utils_ar
from datetime import datetime
import os
import copy


# Overwrite sys.argv so that parse_args() below reads these options rather
# than the arguments the Jupyter kernel was started with (presumably
# get_parser returns an argparse parser -- confirm in parsing.py).
sys.argv = [
    "",
    "--video_path", "/home/azhuravl/nobackup/DAVIS_testing/trainval/monkaa.mp4",
    "--n_splits", "4",
    "--overlap_frames", "0",
    "--radius", "0",
    "--mode", "gradual",
]

parser = get_parser()
opts_base = parser.parse_args()

# Timestamp and video name are used to build unique experiment names.
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
video_basename = os.path.splitext(os.path.basename(opts_base.video_path))[0]

# Setup
opts_base.weight_dtype = torch.bfloat16
opts_base.exp_name = f"{video_basename}_{timestamp}_autoregressive"
opts_base.save_dir = os.path.join(opts_base.out_dir, opts_base.exp_name)

# Radius as parsed from the --radius option above.
radius = opts_base.radius

# NOTE(review): `variants` is not referenced by any cell visible in this
# notebook section; only `pose` below is used.
variants = [
    ("right_90", [0, 90, radius, 0, 0]),
]

# Target pose handed to the model via opts.target_pose; the meaning of the
# five entries is defined by TrajectoryCrafter (not visible here).
pose = [90, 0, 0, 0, 1]
name = f"{pose[0]}_{pose[1]}_{pose[2]}_{pose[3]}_{pose[4]}"

# Work on a copy so opts_base stays available as a clean template; exp_name
# and save_dir are overridden with pose-specific values.
opts = copy.deepcopy(opts_base)
opts.exp_name = f"{video_basename}_{timestamp}_{name}_auto_s{opts_base.n_splits}"
opts.save_dir = os.path.join(opts.out_dir, opts.exp_name)
opts.camera = "target"
opts.target_pose = pose
opts.traj_txt = 'test/trajs/loop2.txt'

# Make directories
os.makedirs(opts.save_dir, exist_ok=True)

In [None]:
# Instantiate the autoregressive TrajectoryCrafter wrapper from the options
# built above (presumably loads the model weights -- may be slow).
trajcrafter = utils_ar.TrajCrafterAutoregressive(opts)

In [None]:
import numpy as np

# Undo the tensor layout and normalisation of frames_tensor:
# channels-first values in [-1, 1]  ->  channels-last float32 values in [0, 1].
frames_np = frames_tensor.cpu().permute(0, 2, 3, 1).numpy()
frames_np = ((frames_np + 1.0) / 2.0).astype(np.float32)

# Caption the frame at index opts.video_length // 2 and store it as the prompt.
caption_frame = frames_np[opts.video_length // 2]
trajcrafter.prompt = trajcrafter.get_caption(opts, caption_frame)
print(trajcrafter.prompt)

In [None]:
# Re-import utils_autoregressive so edits to its module-level functions
# (e.g. sample_diffusion) take effect. The existing `trajcrafter` instance
# keeps the class definition it was created from.
importlib.reload(utils_ar)

In [None]:
# Run diffusion sampling on the warped sequence.
# NOTE(review): the roles of the positional arguments are defined in
# utils_autoregressive.sample_diffusion (not visible here). As passed:
#   frames_tensor[10:]  - frames from index 10 on (the ones that were warped)
#   warped_images       - list of warped frames from the warping loop
#   frames_tensor[:10]  - the first 10 frames
#   masks               - list of masks from the warping loop
_, segment_dir = utils_ar.sample_diffusion(
    trajcrafter,
    frames_tensor[10:],
    warped_images,
    frames_tensor[:10],
    masks,
    opts,
)

In [None]:
# Profile the loading of a single dataset sample with IPython's %prun magic.
# NOTE(review): this presumably rebinds data_0 to sample 0 (it was sample 100
# above) -- re-run the earlier cell if the original sample is needed.
%prun data_0 = dataset_tartanair[0]
