## Dataset

In [None]:
# import sys
# sys.path.append('/home/azhuravl/work/TrajectoryCrafter/notebooks/05_11_25_training/lora_utils_ours')

# from dataset_videos import SimpleValidationDataset  # Add this import


# dataset = SimpleValidationDataset(
#     validation_dir='/home/azhuravl/scratch/datasets_latents/monkaa_1000',
#     use_depth=True,
#     max_samples=100,
#     num_ref_frames=49,
# )

In [None]:
# data_0 = dataset[0]
# data_0

In [None]:
import torch
# Start recording CUDA caching-allocator history (private PyTorch API) so the
# `_dump_snapshot` calls further down in the notebook have data to write.
torch.cuda.memory._record_memory_history()

## Video Depth Anything

In [None]:
import sys
# Make the local Video-Depth-Anything checkout importable (machine-specific path).
sys.path.append('/home/azhuravl/work/Video-Depth-Anything')

from video_depth_anything.video_depth import VideoDepthAnything
from utils.dc_utils import read_video_frames, save_video


In [None]:
class ArgsVDA:
    """Inference settings for Video-Depth-Anything.

    All values are class attributes; the notebook reads them through the
    ``args_vda`` instance created below.
    """

    # Input / output
    input_video = '/home/azhuravl/scratch/datasets_latents/monkaa_1000/000/videos/input_video.mp4'
    output_dir = '/home/azhuravl/work/Video-Depth-Anything/outputs'
    grayscale = False
    save_npz = False
    save_exr = False

    # Model selection / precision
    encoder = 'vitl'
    metric = False
    fp32 = False

    # Resolution and frame sampling (passed to read_video_frames / inference)
    input_size = 256
    max_res = 1280
    max_len = -1
    target_fps = -1

    # Focal lengths
    focal_length_x = 470.4
    focal_length_y = 470.4


args_vda = ArgsVDA()

In [None]:
import torch

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Architecture hyper-parameters per encoder size; must match the checkpoint
# that is loaded below (strict=True enforces this).
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
}
checkpoint_name = 'video_depth_anything'

video_depth_anything = VideoDepthAnything(**model_configs[args_vda.encoder], metric=args_vda.metric)
# The checkpoint is consumed directly as a state dict, so load it with
# weights_only=True (refuses arbitrary pickled objects; consistent with the
# other torch.load calls in this notebook). Loading inline means no extra
# module-level reference keeps the CPU copy of the weights alive.
video_depth_anything.load_state_dict(
    torch.load(
        f'/home/azhuravl/work/Video-Depth-Anything/checkpoints/{checkpoint_name}_{args_vda.encoder}.pth',
        map_location='cpu',
        weights_only=True,
    ),
    strict=True,
)
video_depth_anything = video_depth_anything.to(DEVICE).eval()

In [None]:
# video_depth_anything to bf16

# video_depth_anything = video_depth_anything.to(torch.bfloat16)

In [None]:
# Freeze all VideoDepthAnything weights in one call; only the visual prompt is
# optimised later. The trailing semicolon suppresses the module repr that
# `requires_grad_` would otherwise display as the cell output.
video_depth_anything.requires_grad_(False);

In [None]:
# Display the encoder sub-module (its repr is the cell output).
video_depth_anything.encoder

In [None]:
# Count parameters that would receive gradients across the WHOLE model, so the
# number matches the printed label (the previous version only looked at
# `.head`). After the freezing cell above this is expected to be 0.
num_trainable_params = sum(p.numel() for p in video_depth_anything.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in video_depth_anything.parameters())
print(f'Number of trainable parameters in VideoDepthAnything: {num_trainable_params} (of {num_total_params} total)')

In [None]:

# Read the clip, resampled/limited according to args_vda.
frames, target_fps = read_video_frames(args_vda.input_video, args_vda.max_len, args_vda.target_fps, args_vda.max_res)

# Inference only, so no autograd graph is needed.
# `torch.cuda.amp.autocast` is deprecated; use the device-generic
# `torch.autocast`. Mixed precision is enabled only on CUDA, and only when
# fp32 was not requested, so it never overrides `fp32=args_vda.fp32`.
use_amp = (DEVICE == 'cuda') and not args_vda.fp32
with torch.no_grad(), torch.autocast(device_type='cuda', enabled=use_amp):
    depths, fps = video_depth_anything.infer_video_depth(
        frames,
        target_fps,
        input_size=args_vda.input_size,
        device=DEVICE,
        fp32=args_vda.fp32,
    )


In [None]:
import numpy as np
import matplotlib
import matplotlib.cm as cm

# `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; the colormap registry
# `matplotlib.colormaps` is the supported lookup. Rows are RGB in [0, 1].
colormap = np.array(matplotlib.colormaps["inferno"].colors)
d_min, d_max = depths.min(), depths.max()
# A constant depth map would otherwise divide by zero (and yield NaN -> garbage uint8).
d_range = (d_max - d_min) if d_max > d_min else 1.0

# One global range for all frames keeps colours comparable over time.
# Frames are processed one at a time to keep peak memory at one frame's worth
# of float64 colormap output.
depth_vis_list = []
for depth in depths:
    depth_norm = ((depth - d_min) / d_range * 255).astype(np.uint8)
    depth_vis = depth_norm if args_vda.grayscale else (colormap[depth_norm] * 255).astype(np.uint8)
    depth_vis_list.append(depth_vis)


In [None]:
import matplotlib.pyplot as plt

# Quick look at the first frame returned by infer_video_depth, before any
# scale/shift alignment to the ground truth.
plt.imshow(depths[0])
plt.colorbar()

## Test Time Finetuning

In [None]:
import os
import sys
import cv2
import copy
import time
import tqdm
import warnings
import argparse
import torch.optim 
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from importlib import import_module 
import torchvision.transforms.functional as TF


class Arguments:
    """Settings for the test-time fine-tuning experiment.

    Values are class attributes, read through the ``args_ttt`` instance.
    """

    # Bookkeeping
    exp_name = 'base'
    gpu = '0'
    random_seed = 2025

    # Optimisation
    epochs = 50
    mode = 'VP'  # one of 'VP', 'FT'

    # Data
    dataset = 'ibims'  # one of 'ibims', 'ddad'
    dataset_path = '/workspace/data_all'


args_ttt = Arguments()

In [None]:
# from unidepth_custom.models import UniDepthV2


def compute_scale_and_shift(predicted_depth, sparse_depth):
    """Fit ``scale``/``shift`` so ``predicted_depth * scale + shift`` matches ``sparse_depth``.

    Pixels where ``sparse_depth > 0`` are valid; non-positive values are treated
    as missing. The two tensors must have the same shape, because the mask
    built from one indexes the other.

    The result equals the least-squares solution ``pinverse(X) @ y`` for the
    design matrix ``X = [pred, 1]``. It is computed through the identity
    ``pinv(X) = pinv(X^T X) @ X^T``, i.e. from the 2x2 normal equations, in
    float64. This has two effects:

    * It avoids an SVD of an N x 2 matrix, where N can be millions of pixels.
    * It works for bfloat16/float16 inputs. The SVD-based pseudo-inverse only
      supports float32/float64, and half precision is too coarse for this fit
      anyway.

    The computation stays differentiable with respect to ``predicted_depth``.

    Args:
        predicted_depth: predicted (relative) depth tensor.
        sparse_depth: target tensor with the same shape; values <= 0 are ignored.

    Returns:
        (scale, shift): 0-dim tensors with ``predicted_depth``'s dtype and
        device. If no pixel is valid, returns (1, 0).
    """
    device = predicted_depth.device
    dtype = predicted_depth.dtype

    valid_mask = sparse_depth > 0
    pred_valid = predicted_depth[valid_mask].to(torch.float64)
    sparse_valid = sparse_depth[valid_mask].to(torch.float64)

    if pred_valid.numel() == 0:
        return torch.tensor(1.0, device=device, dtype=dtype), torch.tensor(0.0, device=device, dtype=dtype)

    n = torch.full((), pred_valid.numel(), dtype=torch.float64, device=device)
    sum_p = pred_valid.sum()
    sum_pp = (pred_valid * pred_valid).sum()

    # X^T X and X^T y for X = [pred, 1].
    xtx = torch.stack([sum_pp, sum_p, sum_p, n]).reshape(2, 2)
    xty = torch.stack([(pred_valid * sparse_valid).sum(), sparse_valid.sum()])

    # pinv (not solve) keeps the minimum-norm answer when all valid predictions
    # are equal and X^T X is singular, matching the original pinverse behaviour.
    a = torch.linalg.pinv(xtx) @ xty
    scale = a[0].to(dtype)
    shift = a[1].to(dtype)

    return scale, shift



In [None]:
# Dense ground-truth depth; the small epsilon keeps the inverse finite where depth is 0.
depth_gt = torch.load(
    '/home/azhuravl/scratch/datasets_latents/monkaa_1000/000/videos/input_depths.pt',
    weights_only=True,
).squeeze(1)
depth_gt_inv = 1.0 / (depth_gt + 1e-8)

# Warped depth: non-positive entries have no value, and they must remain 0
# (i.e. "missing") after inversion instead of turning into inf.
depth_warped = torch.load(
    '/home/azhuravl/scratch/datasets_latents/monkaa_1000/000/videos/warped_depths.pt',
    weights_only=True,
).squeeze(1)
has_value = depth_warped > 0
depth_warped_inv = torch.where(has_value, 1.0 / depth_warped, torch.zeros_like(depth_warped))

In [None]:
# Side-by-side view of one frame: the VDA prediction, the ground-truth inverse
# depth and the warped (sparse) inverse depth, each with its own colorbar.
# Titles are derived from `frame_idx` so they always name the frame shown.
frame_idx = 40
panels = [
    (depths[frame_idx], f'Predicted Depth[{frame_idx}]'),
    (depth_gt_inv[frame_idx], f'Ground Truth Inverse Depth[{frame_idx}]'),
    (depth_warped_inv[frame_idx], f'Warped Inverse Depth[{frame_idx}]'),
]

fig, axs = plt.subplots(1, 3, figsize=(12, 6))
for ax, (img, title) in zip(axs, panels):
    im = ax.imshow(img, cmap='inferno')
    ax.set_title(title)
    fig.colorbar(im, ax=ax)


In [None]:
# Make the VDA `util` directory importable so its preprocessing transforms
# (Resize / NormalizeImage / PrepareForNet) can be used directly.
sys.path.append('/home/azhuravl/work/Video-Depth-Anything/video_depth_anything/util')

from transform import Resize, NormalizeImage, PrepareForNet
from torchvision.transforms import Compose
import cv2
import numpy as np
import torch.nn.functional as F
import torch

# infer settings, do not change
# INFER_LEN is the exact clip length prepare_frames accepts. OVERLAP, KEYFRAMES
# and INTERP_LEN are not used by any visible cell of this notebook.
INFER_LEN = 32
OVERLAP = 10
KEYFRAMES = [0,12,24,25,26,27,28,29,30,31]
INTERP_LEN = 8


def prepare_frames(frames, input_size=518):
    """Resize and ImageNet-normalise a clip of exactly ``INFER_LEN`` frames.

    Args:
        frames: numpy array of shape [T, H, W, C]. Each frame is divided by 255
            before normalisation, so 0..255 pixel values are assumed.
        input_size: nominal model input size. It is reduced for very wide or
            tall clips and kept a multiple of 14.

    Returns:
        torch.Tensor: processed clip of shape [1, T, C, H', W'].
        tuple: the original (height, width) of the frames.

    Raises:
        ValueError: if ``frames`` does not hold exactly ``INFER_LEN`` frames.
    """
    if frames.shape[0] != INFER_LEN:
        raise ValueError(f"Expected {INFER_LEN} frames, but got {frames.shape[0]} frames")

    height, width = frames[0].shape[:2]
    aspect = max(height, width) / min(height, width)

    # Elongated clips get a smaller target so the long side stays bounded;
    # snap the result back to a multiple of 14.
    if aspect > 1.78:
        input_size = round(int(input_size * 1.777 / aspect) / 14) * 14

    pipeline = Compose([
        Resize(
            width=input_size,
            height=input_size,
            resize_target=False,
            keep_aspect_ratio=True,
            ensure_multiple_of=14,
            resize_method='lower_bound',
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ])

    def _to_tensor(frame):
        # One frame -> transformed [C, H', W'] tensor.
        sample = pipeline({'image': frame.astype(np.float32) / 255.0})
        return torch.from_numpy(sample['image'])

    # Stacking per-frame [C, H', W'] tensors on a new time axis and then adding
    # a batch axis gives the same [1, T, C, H', W'] layout as concatenating
    # [1, 1, C, H', W'] pieces along dim 1.
    clip = torch.stack([_to_tensor(frames[t]) for t in range(INFER_LEN)], dim=0).unsqueeze(0)

    return clip, (height, width)


In [None]:
# Preprocess the first INFER_LEN frames (prepare_frames rejects any other
# length), then resample the depth targets onto the model's input grid.
frames_resized, orig_dims = prepare_frames(frames[:INFER_LEN], input_size=args_vda.input_size)
target_hw = frames_resized.shape[3:]  # (H', W') of the [1, T, C, H', W'] clip

# Dense GT inverse depth: bilinear resampling (align_corners=False is the
# default, stated explicitly).
depths_gt_inv_resized = F.interpolate(
    depth_gt_inv[:INFER_LEN].unsqueeze(1),
    size=target_hw,
    mode='bilinear',
    align_corners=False,
).unsqueeze(0)

# Sparse warped inverse depth: nearest-neighbour so missing pixels (0) are
# never blended into valid ones.
depths_warped_inv_resized = F.interpolate(
    depth_warped_inv[:INFER_LEN].unsqueeze(1),
    size=target_hw,
    mode='nearest',
).unsqueeze(0)


frames_resized.shape, depths_gt_inv_resized.shape, depths_warped_inv_resized.shape

In [None]:
import torch

# Sanity-check the resized inputs for one frame:
# - the preprocessed RGB, with the ImageNet normalisation applied in
#   prepare_frames undone so imshow displays valid [0, 1] colours;
# - the resized GT inverse depth;
# - the resized warped inverse depth.
frame_idx = 20
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
rgb_vis = (frames_resized[0, frame_idx] * imagenet_std + imagenet_mean).clamp(0, 1).permute(1, 2, 0)

fig, axs = plt.subplots(1, 3, figsize=(12, 6))
# RGB panel: no cmap/colorbar, both are meaningless for a 3-channel image.
axs[0].imshow(rgb_vis)
axs[0].set_title(f'Input RGB[{frame_idx}]')

im2 = axs[1].imshow(depths_gt_inv_resized[0, frame_idx, 0], cmap='inferno')
axs[1].set_title(f'Ground Truth Inverse Depth[{frame_idx}]')
fig.colorbar(im2, ax=axs[1], shrink=0.3)

im3 = axs[2].imshow(depths_warped_inv_resized[0, frame_idx, 0], cmap='inferno')
axs[2].set_title(f'Warped Inverse Depth[{frame_idx}]')
fig.colorbar(im3, ax=axs[2], shrink=0.3)


In [None]:

# rmse_mean_ft, mae_mean_ft = 0.0, 0.0
# rmse_mean_vp, mae_mean_vp = 0.0, 0.0


# rgb = frames_resized.cuda()
# depth = depths_gt_inv_resized.cuda()
# sparse_depth = depths_warped_inv_resized.cuda()


# gt_mask = depth > 0
# sparse_mask = sparse_depth > 0

# # Visual Prompt
# # visual_prompt = torch.nn.Parameter(torch.zeros_like(rgb, device='cuda'))
# # optimizer = torch.optim.AdamW([{'params': visual_prompt, 'lr': 2e-3}])

# pbar = tqdm.tqdm(total=args_ttt.epochs)

# # Create prompt for single frame and repeat across time
# single_frame_prompt = torch.nn.Parameter(torch.zeros_like(rgb[:, :1], device='cuda'))  # [1, 1, C, H, W]
# visual_prompt = single_frame_prompt.repeat(1, rgb.shape[1], 1, 1, 1)  # Repeat across time dimension
# optimizer = torch.optim.AdamW([{'params': single_frame_prompt, 'lr': 2e-3}])


In [None]:
import torch, pickle

# Earlier variant that pickled the snapshot manually (kept for reference):
# snapshot = torch.cuda.memory_snapshot()
# with open("/home/azhuravl/work/TrajectoryCrafter/notebooks/05_11_25_training/snapshot.pickle", "wb") as f:
#     pickle.dump(snapshot, f)

# Write the CUDA allocator history recorded so far (private PyTorch API) as the
# reference snapshot taken before the fine-tuning loop runs.
torch.cuda.memory._dump_snapshot("/home/azhuravl/work/TrajectoryCrafter/notebooks/05_11_25_training/snapshot_before.pickle")


In [None]:
rmse_mean_ft, mae_mean_ft = 0.0, 0.0
rmse_mean_vp, mae_mean_vp = 0.0, 0.0

# Keep model in original precision, don't convert to bf16
# video_depth_anything = video_depth_anything.to(torch.bfloat16)

# Inputs and targets stay float32. Autocast below still runs the model's heavy
# ops in bf16, while targets and metrics keep full precision.
# Both depth maps drop their singleton channel axis (dim 2) so they index
# consistently with the model output. The GT previously kept it, which broke
# boolean masking in the evaluation cell.
rgb = frames_resized.float().cuda()
depth = depths_gt_inv_resized.squeeze(2).float().cuda()
sparse_depth = depths_warped_inv_resized.squeeze(2).float().cuda()

# The trainable prompt is kept in float32 (master weights). A bf16 parameter
# has only 8 significant bits, so AdamW steps of about lr=2e-3 are rounded
# away once prompt values reach O(1).
single_frame_prompt = torch.nn.Parameter(torch.zeros_like(rgb[:, :1]))

gt_mask = depth > 0
sparse_mask = sparse_depth > 0

optimizer = torch.optim.AdamW([{'params': single_frame_prompt, 'lr': 2e-3}])
# No GradScaler: bf16 has float32's exponent range, so loss scaling is not
# needed (and torch.cuda.amp.GradScaler is deprecated).

pbar = tqdm.tqdm(total=args_ttt.epochs)

for epoch in range(args_ttt.epochs):
    # One prompt shared by every frame; `repeat` keeps gradients flowing back
    # to the single-frame parameter.
    visual_prompt = single_frame_prompt.repeat(1, rgb.shape[1], 1, 1, 1)

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        new_rgb = rgb + visual_prompt
        pre_depth_ = video_depth_anything.forward(new_rgb)

    # Fit scale/shift and compute the losses in float32, outside autocast.
    # bf16 is too coarse for a least-squares fit, and the pseudo-inverse used
    # by compute_scale_and_shift does not accept bf16.
    pre_depth_ = pre_depth_.float()
    scale, shift = compute_scale_and_shift(pre_depth_, sparse_depth)
    pre_depth = pre_depth_ * scale + shift
    loss_l1 = F.l1_loss(pre_depth[sparse_mask], sparse_depth[sparse_mask])
    loss_rmse = torch.sqrt(((pre_depth[sparse_mask] - sparse_depth[sparse_mask]) ** 2).mean())
    loss = loss_l1 + loss_rmse

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    pbar.set_description(f'exp: {args_ttt.exp_name} l1: {loss_l1.item():.4f} rmse: {loss_rmse.item():.4f}')
    pbar.update()

pbar.close()

In [None]:
# Inspect the raw model output shape from the last training step.
pre_depth_.shape

In [None]:
# Inspect the sparse target shape; it must match pre_depth_ for the masking to work.
sparse_depth.shape

In [None]:

# for epoch in range(args_ttt.epochs):               
#     new_rgb = rgb + visual_prompt
    
    
#     # pre_depth_ = foundation_model({'image': new_rgb, 'depth': sparse_depth}, {})['depth']
#     # Run model inference
    
#     with torch.autocast(device_type='cuda', enabled=(not args_vda.fp32)):
#         depth = video_depth_anything.forward(new_rgb)  # depth shape: [1, T, H, W]

    
    
#     scale, shift = compute_scale_and_shift(pre_depth_, sparse_depth)
    
#     print(scale, shift)    
    
#     pre_depth = pre_depth_ * scale + shift    
        
#     loss_l1 = F.l1_loss(pre_depth[sparse_mask], sparse_depth[sparse_mask])
#     loss_rmse = torch.sqrt(((pre_depth[sparse_mask] - sparse_depth[sparse_mask]) ** 2).mean())
#     loss = loss_l1 + loss_rmse

#     optimizer.zero_grad()
#     loss.backward()          
#     optimizer.step()

#     pbar.set_description(f'exp: {args_ttt.exp_name} l1: {loss_l1.item():.4f} rmse: {loss_rmse.item():.4f}')
#     pbar.update()

In [None]:
# Write the CUDA allocator history (private PyTorch API) after the fine-tuning
# loop, for comparison with snapshot_before.pickle.
torch.cuda.memory._dump_snapshot("/home/azhuravl/work/TrajectoryCrafter/notebooks/05_11_25_training/snapshot_after.pickle")

In [None]:
# empty cuda cache
# Returns cached-but-unused blocks held by PyTorch's CUDA allocator; tensors
# that are still referenced stay allocated.
torch.cuda.empty_cache()


In [None]:
# Report GPU memory usage (IPython shell escape, not plain Python).
!nvidia-smi

In [None]:

# Evaluate the learned prompt against the dense GT inverse depth.
# The prediction is recomputed here because the `pre_depth` left over from the
# training loop was produced *before* the final optimizer step.
with torch.no_grad():
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        pre_depth_ = video_depth_anything.forward(
            rgb + single_frame_prompt.repeat(1, rgb.shape[1], 1, 1, 1)
        )
    # Align in float32: bf16 is too coarse for the least-squares fit.
    pre_depth_ = pre_depth_.float()
    scale, shift = compute_scale_and_shift(pre_depth_, sparse_depth.float())
    pre_depth = pre_depth_ * scale + shift

    # The GT may still carry a singleton channel axis ([1, T, 1, H, W]).
    # Drop it so the boolean mask lines up with the prediction.
    gt = depth.float()
    if gt.dim() == pre_depth.dim() + 1:
        gt = gt.squeeze(2)
    gt_mask = gt > 0

    err = pre_depth[gt_mask] - gt[gt_mask]
    rmse_vp = torch.sqrt((err ** 2).mean())
    mae_vp = err.abs().mean()
    rmse_mean_vp += rmse_vp.item()
    mae_mean_vp += mae_vp.item()

# `idx` is not defined in any visible cell (it raised NameError), so it is not printed.
print(f'RMSE: {rmse_mean_vp}, MAE: {mae_mean_vp}')