# UNET

In [65]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [78]:
class ConvBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.net(x)

In [79]:
class Unet(nn.Module):
    def __init__(self, in_channel=7, out_channel=1, base=32):
        super().__init__()
        self.enc1 = ConvBlock(in_channel, base)
        self.enc2 = ConvBlock(base, base*2)
        self.enc3 = ConvBlock(base*2, base*4)
        self.pool = nn.MaxPool2d(2)
        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, stride=2)
        self.up1 = nn.ConvTranspose2d(base*2, base, 2, stride=2)
        self.dec2 = ConvBlock(base*4, base*2)
        self.dec1 = ConvBlock(base*2, base)
        self.out = nn.Conv2d(base, out_channel, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        
        d2 = self.up2(e3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)
        out = self.out(d1)

        return out


# Dataloader

In [80]:
# dataset_gtav_exr.py
import os
import cv2
import json
import random
import torch
import numpy as np
from torch.utils.data import Dataset
import OpenEXR
import Imath


def load_exr_depth(path):
    exr = OpenEXR.InputFile(path)
    dw = exr.header()['dataWindow']
    width  = dw.max.x - dw.min.x + 1
    height = dw.max.y - dw.min.y + 1

    channel = 'Z' if 'Z' in exr.header()['channels'] else 'Y'
    pt = Imath.PixelType(Imath.PixelType.FLOAT)
    depth_str = exr.channel(channel, pt)

    depth = np.frombuffer(depth_str, dtype=np.float32)
    return depth.reshape((height, width))


def load_pose_json(path):
    with open(path, "r") as f:
        meta = json.load(f)

    extr = np.array(meta["extrinsic"], dtype=np.float32)
    fx, fy = meta["f_x"], meta["f_y"]
    cx, cy = meta["c_x"], meta["c_y"]

    K = np.array([
        [fx, 0,  cx],
        [0,  fy, cy],
        [0,   0,  1]
    ], dtype=np.float32)

    # Fix left-handed coordinate issues
    R = extr[:3, :3]
    if np.linalg.det(R) < 0:
        R[:, 2] *= -1
        extr[:3, :3] = R

    return extr, K


class GTAVEXRDataset(Dataset):
    def __init__(self, root, n_views=3, max_stride=3):
        assert n_views % 2 == 1, "n_views must be odd: 3,5,7..."
        self.n_views = n_views
        self.max_stride = max_stride
        self.root = root

        self.scenes = []

        scene_ids = sorted(os.listdir(root))
        for sid in scene_ids:
            spath = os.path.join(root, sid)
            if not os.path.isdir(spath):
                continue

            imgs   = sorted(os.listdir(os.path.join(spath, "images")))
            depths = sorted(os.listdir(os.path.join(spath, "depths")))
            poses  = sorted(os.listdir(os.path.join(spath, "poses")))

            self.scenes.append({
                "img":   [os.path.join(spath, "images", f) for f in imgs],
                "depth": [os.path.join(spath, "depths", f) for f in depths],
                "pose":  [os.path.join(spath, "poses", f) for f in poses],
                "len": len(imgs),
                "root": spath
            })

    def __len__(self):
        return sum(s["len"] for s in self.scenes)


    # ----------------------------------------------------------
    # Generalized symmetric sampling around center (Option 1)
    # ----------------------------------------------------------
    def sample_views(self, L):
        """
        Returns sorted list of n_views indices:
        symmetric around a random center.
        """
        half = self.n_views // 2
        # ensure center can shift both sides
        c = random.randint(half, L - half - 1)

        views = [c]

        for h in range(1, half + 1):
            # random stride for left/right
            k = random.randint(1, self.max_stride)

            left  = max(0,     c - k)
            right = min(L - 1, c + k)

            views.append(left)
            views.append(right)

        views = sorted(views)[:self.n_views]   # ensure correct count
        return views


    # ----------------------------------------------------------
    # __getitem__
    # ----------------------------------------------------------
    def __getitem__(self, _):
        scene = random.choice(self.scenes)
        L = scene["len"]

        # pick generalized multi-view indices
        idxs = self.sample_views(L)   # e.g. [c-k2, c-k1, c, c+k1, c+k2] for 5 views

        # Load RGB
        def load_rgb(path):
            img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
            return img.astype(np.float32) / 255.0

        rgb = np.stack([load_rgb(scene["img"][v]) for v in idxs])   # (V,H,W,3)
        depth = np.stack([load_exr_depth(scene["depth"][v]) for v in idxs])   # (V,H,W)

        # Load extrinsics/intrinsics
        extrinsics = []
        intrinsics = []

        for v in idxs:
            extr, K = load_pose_json(scene["pose"][v])
            extrinsics.append(extr)
            intrinsics.append(K)

        extrinsics = np.stack(extrinsics)   # (V,4,4)
        intrinsics = np.stack(intrinsics)   # (V,3,3)

        return {
            "rgb": torch.from_numpy(rgb).permute(0, 3, 1, 2),  # (V,3,H,W)
            "depth": torch.from_numpy(depth).float(),         # (V,H,W)
            "extrinsic": torch.from_numpy(extrinsics).float(),# (V,4,4)
            "intrinsic": torch.from_numpy(intrinsics).float() # (V,3,3)
        }


In [81]:
dataset = GTAVEXRDataset(root="dataset/GTAV_540", max_stride=3, n_views=5)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, shuffle=True)

# Test 1 batch
test_sample = None
for batch in dataloader:
    test_sample = batch
    print("RGB:", batch["rgb"].shape)           # (B, V, 3, H, W) (batch, view_count, channels, width, height)
    print("Depth:", batch["depth"].shape)       # (B, V, H, W)
    print("Extrinsic:", batch["extrinsic"].shape) # (B, V, 4, 4)
    print("Intrinsic:", batch["intrinsic"].shape) # (B, V, 3, 3)
    break

RGB: torch.Size([3, 5, 3, 540, 960])
Depth: torch.Size([3, 5, 540, 960])
Extrinsic: torch.Size([3, 5, 4, 4])
Intrinsic: torch.Size([3, 5, 3, 3])


# Utility functions

In [82]:
def calculate_ray_dirs(intrinsics, W, H):
    i, j = torch.meshgrid(
        torch.arange(W, dtype=torch.float32),
        torch.arange(H, dtype=torch.float32),
        indexing='xy'
    ) 

    dirs = torch.stack([
        (i - intrinsics[0, 2]) / intrinsics[0, 0], # [(X_i) - (fx)]/ (B, cx) 
        (j - intrinsics[1, 2]) / intrinsics[1, 1],
        torch.ones_like(i)
    ], -1) 

    dirs = dirs / torch.norm(dirs, dim=-1, keepdim=True)  # normalize

    return dirs  # (W,H,[dir_x, dir_y, dir_z])

# Test
W, H = test_sample['rgb'].shape[-1], test_sample['rgb'].shape[-2]
calculate_ray_dirs(test_sample["intrinsic"][0,1], W, H).shape

torch.Size([540, 960, 3])

In [83]:
def project_depth_to_camera_3d(depth_map, ray_dirs):
    return depth_map[..., None] * ray_dirs # return [X, Y, Z] 3D points in camera space

def project_3d_to_camera_2d(points_3d, intrinsic):
    """
    points_3d: (..., 3) tensor of 3D points in camera space
    intrinsic: (3, 3) camera intrinsic matrix
    Returns:
        points_2d: (..., 2) tensor of 2D pixel coordinates
    """
    fx = intrinsic[0, 0]
    fy = intrinsic[1, 1]
    cx = intrinsic[0, 2]
    cy = intrinsic[1, 2]

    x = points_3d[..., 0]
    y = points_3d[..., 1]
    z = points_3d[..., 2].clamp(min=1e-6)  # Prevent division by zero

    u = fx * (x / z) + cx
    v = fy * (y / z) + cy

    return torch.stack([u, v], dim=-1)

# Test
project_depth_to_camera_3d(
    test_sample['depth'][0,1], 
    calculate_ray_dirs(test_sample["intrinsic"][0,1], W, H)
).shape

torch.Size([540, 960, 3])

In [84]:
def camera_to_world(cam_point_3d, extrinsic):
    """
    cam_point_3d: (H, W, 3) camera-space 3D points
    extrinsic:   (4,4) world -> camera matrix
    Returns:     (H, W, 3) world-space points
    """

    R = extrinsic[:3, :3]      # world->cam rotation
    t = extrinsic[:3, 3]       # world->cam translation

    # Inverse transform:
    R_inv = R.T                # cam->world rotation
    t_inv = -R_inv @ t         # cam->world translation

    # reshape points: (H,W,3) -> (H*W,3)
    H, W = cam_point_3d.shape[:2]
    pts = cam_point_3d.reshape(-1, 3).T   # (3, HW)

    # apply transform
    pts_w = R_inv @ pts + t_inv[:, None]  # (3, HW)

    # reshape back
    return pts_w.T.reshape(H, W, 3)

def world_to_camera(world_point_3d, extrinsic):
    """
    world_point_3d: (H, W, 3) world-space 3D points
    extrinsic:      (4,4) world -> camera matrix
    Returns:        (H, W, 3) camera-space points
    """

    R = extrinsic[:3, :3]      # world->cam rotation
    t = extrinsic[:3, 3]       # world->cam translation

    # reshape points: (H,W,3) -> (H*W,3)
    H, W = world_point_3d.shape[:2]
    pts = world_point_3d.reshape(-1, 3).T   # (3, HW)

    # apply transform
    pts_c = R @ pts + t[:, None]  # (3, HW)

    # reshape back
    return pts_c.T.reshape(H, W, 3)

In [85]:
def sample_depth_bilinear(point_2d, depth_map):
    """
    point_2d: (N, 2) tensor of 2D pixel coordinates
    depth_map: (H, W) tensor of depth values
    Returns:
        sampled_depths: (N,) tensor of sampled depth values
    """
    H, W = depth_map.shape
    x = point_2d[:, 0]
    y = point_2d[:, 1]

    x0 = torch.floor(x).long().clamp(0, W - 1)
    x1 = (x0 + 1).clamp(0, W - 1)
    y0 = torch.floor(y).long().clamp(0, H - 1)
    y1 = (y0 + 1).clamp(0, H - 1)

    Ia = depth_map[y0, x0]
    Ib = depth_map[y1, x0]
    Ic = depth_map[y0, x1]
    Id = depth_map[y1, x1]

    wa = (x1.float() - x) * (y1.float() - y)
    wb = (x1.float() - x) * (y - y0.float())
    wc = (x - x0.float()) * (y1.float() - y)
    wd = (x - x0.float()) * (y - y0.float())

    sampled_depths = wa * Ia + wb * Ib + wc * Ic + wd * Id

    return sampled_depths

# Batched ultilities

In [106]:
import torch
import torch.nn.functional as F

EPS = 1e-6

# ---- helpers (slightly cleaned) ----
def _pixel_grid(H, W, device, dtype):
    # Returns pixel coords u (x) and v (y) shaped (H, W)
    v, u = torch.meshgrid(
        torch.arange(H, device=device, dtype=dtype),
        torch.arange(W, device=device, dtype=dtype),
        indexing='ij'
    )
    return u, v  # u->x, v->y

def batched_ray_dirs(intrinsics, H, W):
    """
    intrinsics: (B, V, 3, 3)
    returns: dirs (B, V, H, W, 3) normalized camera ray directions
    """
    device = intrinsics.device
    dtype = intrinsics.dtype
    B, V = intrinsics.shape[0], intrinsics.shape[1]

    u, v = _pixel_grid(H, W, device, dtype)   # (H, W) u=x, v=y
    u = u.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)
    v = v.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)

    fx = intrinsics[..., 0, 0].unsqueeze(-1).unsqueeze(-1)  # (B,V,1,1)
    fy = intrinsics[..., 1, 1].unsqueeze(-1).unsqueeze(-1)
    cx = intrinsics[..., 0, 2].unsqueeze(-1).unsqueeze(-1)
    cy = intrinsics[..., 1, 2].unsqueeze(-1).unsqueeze(-1)

    dirs_x = (u - cx) / (fx + EPS)   # (B,V,H,W)
    dirs_y = (v - cy) / (fy + EPS)
    dirs_z = torch.ones_like(dirs_x)

    dirs = torch.stack([dirs_x, dirs_y, dirs_z], dim=-1)  # (B,V,H,W,3)
    norm = torch.norm(dirs, dim=-1, keepdim=True)
    dirs = dirs / (norm + EPS)
    return dirs


def batched_depth_to_camera(depth, ray_dirs):
    # depth: (B,V,H,W), ray_dirs: (B,V,H,W,3)
    return depth.unsqueeze(-1) * ray_dirs  # (B,V,H,W,3)


def batched_camera_to_world(cam_pts, extrinsics):
    """
    cam_pts: (B, V, H, W, 3)
    extrinsics: (B, V, 4, 4)  (world -> camera)
    returns: world_pts (B, V, H, W, 3)
    """
    B, V, H, W, _ = cam_pts.shape
    R = extrinsics[..., :3, :3]         # (B,V,3,3) world->cam
    t = extrinsics[..., :3, 3]          # (B,V,3)

    R_inv = R.transpose(-1, -2)         # (B,V,3,3) cam->world
    t_inv = -torch.matmul(R_inv, t.unsqueeze(-1)).squeeze(-1)  # (B,V,3)

    pts = cam_pts.reshape(B, V, -1, 3)  # (B,V,HW,3)
    pts_t = pts.permute(0,1,3,2)        # (B,V,3,HW)
    world = torch.matmul(R_inv, pts_t) + t_inv.unsqueeze(-1)    # (B,V,3,HW)
    world = world.permute(0,1,3,2).reshape(B, V, H, W, 3)
    return world


def batched_world_to_camera(world_pts, extrinsics):
    """
    world_pts: (B, V, H, W, 3)
    extrinsics: (B, V, 4, 4)  (world -> camera)
    returns: cam_pts (B, V, H, W, 3)
    """
    B, V, H, W, _ = world_pts.shape
    R = extrinsics[..., :3, :3]    # (B,V,3,3)
    t = extrinsics[..., :3, 3]     # (B,V,3)

    pts = world_pts.reshape(B, V, -1, 3).permute(0,1,3,2)  # (B,V,3,HW)
    cam = torch.matmul(R, pts) + t.unsqueeze(-1)           # (B,V,3,HW)
    cam = cam.permute(0,1,3,2).reshape(B, V, H, W, 3)
    return cam


def batched_project_3d_to_2d(pts_3d, intrinsics):
    """
    pts_3d: (B, V, H, W, 3) in camera coords
    intrinsics: (B, V, 3, 3)
    returns: uv (B, V, H, W, 2), z (B,V,H,W)
    """
    x = pts_3d[..., 0]
    y = pts_3d[..., 1]
    z = pts_3d[..., 2].clamp(min=EPS)

    z = torch.clamp(z, min=1e-4)

    fx = intrinsics[..., 0, 0].unsqueeze(-1).unsqueeze(-1)
    fy = intrinsics[..., 1, 1].unsqueeze(-1).unsqueeze(-1)
    cx = intrinsics[..., 0, 2].unsqueeze(-1).unsqueeze(-1)
    cy = intrinsics[..., 1, 2].unsqueeze(-1).unsqueeze(-1)

    u = fx * (x / z) + cx
    v = fy * (y / z) + cy
    uv = torch.stack([u, v], dim=-1)
    return uv, z


def normalize_uv_for_grid_sample(uv, H, W):
    """
    uv: (..., 2) pixel coords with u in [0..W-1], v in [0..H-1]
    returns: grid coords in [-1,1] last-dim order (x,y) for grid_sample
    """
    u = uv[..., 0]
    v = uv[..., 1]

    nx = (u / (W - 1)) * 2 - 1
    ny = (v / (H - 1)) * 2 - 1
    return torch.stack([nx, ny], dim=-1)


# ================================================================
# UNIT DETECTION & AUTO-SCALING
# ================================================================
def camera_centers_from_extrinsics(extrinsics):
    """
    extrinsics: (B, V, 4, 4) world->camera
    returns: camera centers in world coords (B, V, 3)
    """
    R = extrinsics[..., :3, :3]   # (B,V,3,3)
    t = extrinsics[..., :3, 3]    # (B,V,3)
    R_inv = R.transpose(-1, -2)   # cam->world
    cam_centers = -torch.matmul(R_inv, t.unsqueeze(-1)).squeeze(-1)  # (B,V,3)
    return cam_centers


def detect_and_fix_depth_unit(depth_batch, extrinsics, threshold_scale=10.0, apply_fix=True):
    """
    Heuristic: if median depth >> median camera-translation magnitude,
    likely depth is in mm (or cm). We scale down by 1000 or 100 accordingly.
    Returns (depth_batch_scaled, scale_factor, did_scale_flag, med_scale, med_depth_mean, baseline_mean)
    """
    device = depth_batch.device
    B, V, H, W = depth_batch.shape

    # stats on depth center
    depth_center = depth_batch[:, 1]  # (B,H,W)
    med_depth = torch.median(depth_center.reshape(B, -1), dim=1).values  # (B,)

    # camera center distances (per sample)
    centers = camera_centers_from_extrinsics(extrinsics)  # (B,V,3)
    # baseline magnitude between src(0) and center(1)
    baseline = torch.norm(centers[:, 0] - centers[:, 1], dim=-1)  # (B,)
    # avoid zero baseline
    baseline = baseline + 1e-6

    scale_factors = med_depth / baseline  # if >> threshold -> likely depth in mm
    med_scale = float(torch.median(scale_factors).item())

    # Decide conversion
    scale = 1.0
    did_scale = False
    if med_scale > threshold_scale:
        # typical: depth in mm -> divide by 1000
        if med_scale > 1000:
            scale = 1.0 / 1000.0
        else:
            # try 1000 first
            scale = 1.0 / 1000.0
        if apply_fix:
            depth_batch = depth_batch * scale
            did_scale = True

    return depth_batch, scale, did_scale, med_scale, float(med_depth.mean().item()), float(baseline.mean().item())


# ================================================================
# SAFE REPROJECTION (NO IN-PLACE OPS)
# ================================================================
def reprojection_pair_to_center_safe(
    depth_batch, intrinsics, extrinsics,
    center_idx=1, src_idx=0,
    center_depth_override=None,
    debug=False
):
    """
    Safe reprojection: src -> center, no in-place ops on tensors that require grad.
    center_depth_override: optional (B,H,W) tensor to use instead of depth_batch[:,center_idx]
    Returns: X_src_world (B,H,W,3), X_center_world_reproj (B,H,W,3), valid (B,H,W), uv (B,H,W,2)
    """
    B, V, H, W = depth_batch.shape
    device = depth_batch.device
    dtype = depth_batch.dtype

    # 1) ray dirs
    ray_dirs = batched_ray_dirs(intrinsics, H, W)  # (B,V,H,W,3)

    # 2) src cam pts (use only src view)
    ray_src = ray_dirs[:, src_idx]                 # (B,H,W,3)
    depth_src = depth_batch[:, src_idx].unsqueeze(-1)   # (B,H,W,1)
    depth_src = torch.clamp(depth_src, min=1e-4, max=1e6)
    depth_src = torch.nan_to_num(depth_src, nan=1e-4, posinf=1e6, neginf=1e-4)
    cam_pts_src = depth_src * ray_src              # (B,H,W,3)

    # 3) world points src
    X_src_world = batched_camera_to_world(cam_pts_src.unsqueeze(1), extrinsics[:, src_idx:src_idx+1]).squeeze(1)
    X_src_world = torch.nan_to_num(X_src_world, nan=0.0, posinf=1e6, neginf=-1e6)

    # 4) reproject src world into center camera coords (no in-place)
    X_src_in_center_cam = batched_world_to_camera(X_src_world.unsqueeze(1), extrinsics[:, center_idx:center_idx+1]).squeeze(1)
    x_xy = X_src_in_center_cam[..., :2]   # new view copies
    z = X_src_in_center_cam[..., 2]
    z_clamped = z.clamp(min=1e-4)
    X_src_in_center_cam_clamped = torch.cat([x_xy, z_clamped.unsqueeze(-1)], dim=-1)

    # 5) project to pixels in center (uses a copy)
    intr_c = intrinsics[:, center_idx:center_idx+1]
    uv, z_proj = batched_project_3d_to_2d(X_src_in_center_cam_clamped.unsqueeze(1), intr_c)
    uv = uv.squeeze(1); z_proj = z_proj.squeeze(1)

    # clamp uv to image bounds without in-place writes
    u = uv[..., 0].clamp(0.0, W - 1.0)
    v = uv[..., 1].clamp(0.0, H - 1.0)
    uv_clamped = torch.stack([u, v], dim=-1)
    uv_clamped = torch.nan_to_num(uv_clamped, nan=0.0, posinf=W-1.0, neginf=0.0)

    # 6) sample center depth at uv using either override or original depth
    grid = normalize_uv_for_grid_sample(uv_clamped, H, W)  # safe normalization
    if center_depth_override is None:
        depth_center = depth_batch[:, center_idx].unsqueeze(1)  # (B,1,H,W)
    else:
        # center_depth_override: (B,H,W) -> convert to (B,1,H,W) for grid_sample
        depth_center = center_depth_override.unsqueeze(1)

    depth_center = torch.clamp(depth_center, min=1e-4, max=1e6)
    sampled_center = F.grid_sample(depth_center, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
    sampled_center = torch.nan_to_num(sampled_center, nan=0.0, posinf=1e6, neginf=0.0).squeeze(1)

    # 7) valid mask
    in_bounds = (uv_clamped[...,0] >= 0) & (uv_clamped[...,0] <= (W - 1)) & (uv_clamped[...,1] >= 0) & (uv_clamped[...,1] <= (H - 1))
    valid = in_bounds & (z_proj > 0) & (sampled_center > 0)

    # 8) sample center ray dirs
    ray_center = ray_dirs[:, center_idx]  # (B,H,W,3)
    ray_center_t = ray_center.permute(0,3,1,2)
    sampled_rays = F.grid_sample(ray_center_t, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
    sampled_rays = torch.nan_to_num(sampled_rays, nan=0.0, posinf=1.0, neginf=-1.0)
    sampled_rays = sampled_rays.permute(0,2,3,1)

    # 9) backproject sampled center -> world
    X_center_cam_reproj = sampled_center.unsqueeze(-1) * sampled_rays
    X_center_world_reproj = batched_camera_to_world(X_center_cam_reproj.unsqueeze(1), extrinsics[:, center_idx:center_idx+1]).squeeze(1)
    X_center_world_reproj = torch.nan_to_num(X_center_world_reproj, nan=0.0, posinf=1e6, neginf=-1e6)

    if debug:
        print(f"[reproj debug SAFE] src={src_idx} B={B} H={H} W={W}")
        print(" X_src_world minZ:", float(X_src_world[...,2].min().item()))
        print(" X_center_world_reproj minZ:", float(X_center_world_reproj[...,2].min().item()))
        print(" uv min/max:", float(uv_clamped[...,0].min().item()), float(uv_clamped[...,0].max().item()),
              float(uv_clamped[...,1].min().item()), float(uv_clamped[...,1].max().item()))
        print(" sampled_center min/max:", float(sampled_center.min().item()), float(sampled_center.max().item()))
        print(" valid_count:", int(valid.sum().item()))

    return X_src_world, X_center_world_reproj, valid, uv_clamped


In [107]:
# quick test
B, V, H, W = 2, 3, 64, 80
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float32

depth = torch.rand(B, V, H, W, device=device, dtype=dtype) * 10.0 + 0.1  # positive depths
intr = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).unsqueeze(0).repeat(B, V, 1, 1)
# tweak fx,fy,cx,cy for visual
intr[..., 0, 0] = 60.0
intr[..., 1, 1] = 60.0
intr[..., 0, 2] = W / 2.0
intr[..., 1, 2] = H / 2.0

# build extrinsics: identity cameras at different x offsets
extr = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).unsqueeze(0).repeat(B, V, 1, 1)
# shift cameras along x
for v in range(V):
    extr[:, v, 0, 3] = (v - 1) * 0.2  # small baseline

X_src_world, X_center_world_reproj, valid, uv = reprojection_pair_to_center(
    depth, intr, extr, center_idx=1, src_idx=0
)

print(X_src_world.shape)         # (B,H,W,3)
print(X_center_world_reproj.shape)  # (B,H,W,3)
print(valid.shape)               # (B,H,W)
print(uv.shape)                  # (B,H,W,2)
print(valid.float().mean().item())


torch.Size([2, 64, 80, 3])
torch.Size([2, 64, 80, 3])
torch.Size([2, 64, 80])
torch.Size([2, 64, 80, 2])
1.0


# Train loop

In [None]:
# -------------------------
# Wrapper: DepthAnything forward on RGB batch
# -------------------------
def depth_anything_forward(rgb_batch, depth_anything_model):
    """
    rgb_batch: (B, V, 3, H, W) in [0, 1] float32
    Returns: depth_init_all (B, V, H, W) normalized to [0, 1]
    """
    from PIL import Image
    import numpy as np
    
    device = rgb_batch.device
    B, V, C, H, W = rgb_batch.shape
    
    depth_init_all = []
    
    for b in range(B):
        depth_batch_b = []
        for v in range(V):
            # Extract single RGB frame: (3, H, W) -> PIL Image
            rgb_frame = rgb_batch[b, v].permute(1, 2, 0)  # (H, W, 3)
            rgb_frame_np = (rgb_frame * 255).clamp(0, 255).byte().cpu().numpy()
            
            # Convert numpy to PIL Image
            pil_image = Image.fromarray(rgb_frame_np.astype(np.uint8), mode='RGB')
            
            # Run Depth Anything inference
            with torch.no_grad():
                result = depth_anything_model(pil_image)
            
            # result is a dict with 'depth' key (PIL Image)
            depth_pred = result['depth']
            
            # Convert PIL Image to numpy
            if isinstance(depth_pred, Image.Image):
                depth_pred = np.array(depth_pred)
            
            # Normalize to [0, 1]
            depth_pred = depth_pred.astype(np.float32)
            depth_min, depth_max = depth_pred.min(), depth_pred.max()
            if depth_max > depth_min:
                depth_pred = (depth_pred - depth_min) / (depth_max - depth_min)
            else:
                depth_pred = np.ones_like(depth_pred) * 0.5
            
            # Convert back to tensor
            depth_pred = torch.from_numpy(depth_pred).float().to(device)  # (H, W)
            depth_batch_b.append(depth_pred)
        
        depth_init_all.append(torch.stack(depth_batch_b, dim=0))  # (V, H, W)
    
    depth_init_all = torch.stack(depth_init_all, dim=0)  # (B, V, H, W)
    return depth_init_all

In [129]:
# -------------------------
# Loss Functions with Huber
# -------------------------

def affine_align_depth(depth_init, gt_depth, mask, eps=1e-6):
    """
    Compute affine alignment: s, t such that s*depth_init + t â‰ˆ gt_depth
    Returns aligned depth and scale factors
    
    depth_init: (B, 1, H, W)
    gt_depth:   (B, 1, H, W)
    mask:       (B, 1, H, W) boolean
    """
    B = depth_init.shape[0]
    s_list = []
    t_list = []
    depth_aligned = torch.zeros_like(depth_init)
    
    for b in range(B):
        m = mask[b, 0].reshape(-1)
        if m.sum() < 10:
            # Not enough valid pixels, use identity
            s_list.append(1.0)
            t_list.append(0.0)
            depth_aligned[b] = depth_init[b]
            continue
        
        d = depth_init[b, 0].reshape(-1)[m]  # (N,)
        g = gt_depth[b, 0].reshape(-1)[m]    # (N,)
        
        # Least squares: minimize || s*d + t - g ||^2
        # A = [d, 1], x = [s, t]^T, b = g
        A = torch.stack([d, torch.ones_like(d)], dim=1)  # (N, 2)
        g_col = g.unsqueeze(1)  # (N, 1)
        
        # Solve using lstsq (stable)
        try:
            x = torch.linalg.lstsq(A, g_col).solution  # (2, 1)
            s, t = float(x[0].item()), float(x[1].item())
        except:
            s, t = 1.0, 0.0
        
        s_list.append(s)
        t_list.append(t)
        depth_aligned[b, 0] = depth_init[b, 0] * s + t
    
    return depth_aligned, s_list, t_list


def edge_aware_smoothness(depth, rgb):
    """
    Edge-aware smoothness loss: smooth depth except at RGB edges
    depth: (B, 1, H, W)
    rgb: (B, 3, H, W)
    """
    grad_depth_x = torch.abs(depth[:, :, :, :-1] - depth[:, :, :, 1:])
    grad_depth_y = torch.abs(depth[:, :, :-1, :] - depth[:, :, 1:, :])
    
    grad_rgb_x = torch.mean(torch.abs(rgb[:, :, :, :-1] - rgb[:, :, :, 1:]), dim=1, keepdim=True)
    grad_rgb_y = torch.mean(torch.abs(rgb[:, :, :-1, :] - rgb[:, :, 1:, :]), dim=1, keepdim=True)
    
    L_edge_x = torch.mean(grad_depth_x * torch.exp(-grad_rgb_x))
    L_edge_y = torch.mean(grad_depth_y * torch.exp(-grad_rgb_y))
    
    return L_edge_x + L_edge_y


def normals_from_depth(depth, intrinsics):
    """
    Compute surface normals from depth
    depth: (B, 1, H, W)
    intrinsics: (B, V, 3, 3) - use first V=0 for center
    returns: normals (B, 1, H, W, 3) normalized
    """
    B = depth.shape[0]
    intr_center = intrinsics[:, 0]  # (B, 3, 3)
    
    fx = intr_center[:, 0, 0].view(B, 1, 1, 1)
    fy = intr_center[:, 1, 1].view(B, 1, 1, 1)
    
    # Finite difference for gradients
    dz_dx = (depth[:, :, :, 1:] - depth[:, :, :, :-1]) / (fx + 1e-6)
    dz_dy = (depth[:, :, 1:, :] - depth[:, :, :-1, :]) / (fy + 1e-6)
    
    # Pad to original size
    dz_dx = torch.cat([dz_dx, dz_dx[:, :, :, -1:]], dim=3)
    dz_dy = torch.cat([dz_dy, dz_dy[:, :, -1:, :]], dim=2)
    
    # Normal: (-dz/dx, -dz/dy, 1)
    nx = -dz_dx
    ny = -dz_dy
    nz = torch.ones_like(nx)
    
    normals = torch.stack([nx, ny, nz], dim=-1)  # (B, 1, H, W, 3)
    norm = torch.norm(normals, dim=-1, keepdim=True) + 1e-6
    normals = normals / norm
    
    return normals


def normal_smoothness_loss(normals):
    """
    Encourage smooth normals (small differences between neighbors)
    normals: (B, 1, H, W, 3)
    """
    n = normals
    L_x = torch.mean(torch.abs(n[:, :, :, :-1, :] - n[:, :, :, 1:, :]))
    L_y = torch.mean(torch.abs(n[:, :, :-1, :, :] - n[:, :, 1:, :, :]))
    return L_x + L_y


def masked_l1_huber(pred, gt, mask, eps=1e-6, beta=0.1):
    """
    Masked Huber loss (SmoothL1)
    pred: (B, 1, H, W)
    gt: (B, 1, H, W)
    mask: (B, 1, H, W) boolean
    beta: Huber transition threshold
    """
    valid = mask.float()
    
    # Clamp both pred and gt to avoid extreme values
    pred_safe = torch.clamp(pred, min=0.0, max=1e4)
    gt_safe = torch.clamp(gt, min=0.0, max=1e4)
    
    huber = torch.nn.SmoothL1Loss(reduction='none', beta=beta)
    loss = huber(pred_safe, gt_safe)
    
    # Safety: remove any NaN/Inf
    loss = torch.nan_to_num(loss, nan=0.0, posinf=1e2, neginf=0.0)
    loss = torch.clamp(loss, min=0.0, max=1e2)
    
    masked_loss = (loss * valid).sum() / (valid.sum() + eps)
    masked_loss = torch.nan_to_num(masked_loss, nan=0.0, posinf=1e2, neginf=0.0)
    
    return masked_loss


def train_one_epoch(model, depth_anything_model, dataloader, optimizer, device, 
                    use_gt_for_loss=True, lambda_mv=0.001, lambda_init=0.1, lambda_edge=0.05, 
                    lambda_norm=0.1, num_steps=1):
    """
    Single training epoch with multi-view consistency and Huber loss
    
    Args:
        model: UNet (7 in, 1 out)
        depth_anything_model: Frozen depth estimation pipeline
        dataloader: GTAV loader
        optimizer: Adam
        device: cuda/cpu
        use_gt_for_loss: use GT depth for L_gt
        lambda_init: weight for L_init (smoothness)
        lambda_edge: weight for L_edge (edge-aware)
        lambda_norm: weight for L_norm (normal smoothness)
        num_steps: number of batches to process
    """
    model.train()
    GT_INVALID_THRESH = 100.0
    EPS = 1e-6
    
    for step, batch in enumerate(dataloader):
        if step >= num_steps:
            break
        
        batch_rgb = batch['rgb'].to(device)  # (B, V, 3, H, W)
        batch_depth = batch['depth'].to(device)  # (B, V, H, W)
        batch_intr = batch['intrinsic'].to(device)  # (B, V, 3, 3)
        batch_extr = batch['extrinsic'].to(device)  # (B, V, 4, 4)
        
        B, V, _, H, W = batch_rgb.shape
        center_idx = V // 2
        
        print(f"\n[Step {step}] B={B}, V={V}, H={H}, W={W}")
        
        # ============================================
        # 1) DepthAnything inference (frozen)
        # ============================================
        depth_init_all = depth_anything_forward(batch_rgb, depth_anything_model)  # (B, V, H, W)
        depth_init_all = torch.clamp(depth_init_all, min=1e-4, max=1.0)
        print(f"[DepthAnything] min={depth_init_all.min():.4f}, max={depth_init_all.max():.4f}")
        
        # ============================================
        # 2) Detect & fix depth unit mismatch
        # ============================================
        depth_init_scaled, scale, did_scale, med_scale, med_depth, baseline = \
            detect_and_fix_depth_unit(batch_depth, batch_extr, threshold_scale=10.0, apply_fix=True)
        print(f"[Unit detection] scale={scale:.4f}, did_scale={did_scale}, med_scale={med_scale:.2f}")
        
        # ============================================
        # 3) Prepare center view depth
        # ============================================
        depth_init_center = depth_init_all[:, center_idx].unsqueeze(1)  # (B, 1, H, W)
        gt_center = depth_init_scaled[:, center_idx].unsqueeze(1)  # (B, 1, H, W)
        
        # GT mask: finite, positive, not too large
        gt_valid = (~torch.isinf(gt_center)) & (gt_center < GT_INVALID_THRESH) & (gt_center > 0)
        n_valid_gt = gt_valid.float().sum().item()
        print(f"[GT mask] valid_pixels={n_valid_gt:.0f} / {B*H*W}")
        
        # ============================================
        # 4) Affine alignment: scale init_depth to match GT scale
        # ============================================
        try:
            depth_init_aligned, s_scales, t_offsets = affine_align_depth(depth_init_center, gt_center, gt_valid)
            print(f"[Affine align] scales={[f'{s:.2f}' for s in s_scales]}, offsets={[f'{t:.2f}' for t in t_offsets]}")
        except Exception as e:
            print(f"[Affine align ERROR] {e}, using identity (s=1, t=0)")
            depth_init_aligned = depth_init_center.clone()
            s_scales = [1.0] * B
            t_offsets = [0.0] * B
        
        # ============================================
        # 5) Prepare input to UNet (use aligned depth)
        # ============================================
        ray_dirs_batch = batched_ray_dirs(batch_intr, H, W)  # (B, V, H, W, 3)
        
        rgb_center = batch_rgb[:, center_idx]  # (B, 3, H, W)
        ray_dirs_center = ray_dirs_batch[:, center_idx]  # (B, H, W, 3)
        ray_dirs_center = ray_dirs_center.permute(0, 3, 1, 2)  # (B, 3, H, W)
        
        # Concatenate input: RGB(3) + depth_init_aligned(1) + ray_dirs(3) = 7 channels
        model_input = torch.cat([rgb_center, depth_init_aligned, ray_dirs_center], dim=1)  # (B, 7, H, W)
        
        # ============================================
        # 6) Forward pass
        # ============================================
        depth_delta = model(model_input)  # (B, 1, H, W)
        pred_depth_center = depth_init_aligned + depth_delta
        pred_depth_center = torch.clamp(pred_depth_center, min=1e-4, max=1e6)
        
        print(f"[Prediction] pred_depth min={pred_depth_center.min():.4f}, max={pred_depth_center.max():.4f}")
        
        # ============================================
        # 7) Multi-view reprojection loss
        # ============================================
        mv_loss = 0.0
        for src_idx in [0, 2]:  # skip center_idx=1
            X_src_world, X_center_world_reproj, valid, _ = reprojection_pair_to_center_safe(
                depth_init_scaled, batch_intr, batch_extr,
                center_idx=center_idx, src_idx=src_idx,
                center_depth_override=pred_depth_center.squeeze(1),
                debug=(step == 0)
            )
            
            # 3D point reprojection error in world space
            diff_3d = X_src_world - X_center_world_reproj  # (B, H, W, 3)
            error_3d = torch.norm(diff_3d, dim=-1)  # (B, H, W)
            valid_f = valid.float()
            
            loss_pair = (error_3d * valid_f).sum() / (valid_f.sum() + EPS)
            mv_loss = mv_loss + loss_pair
            
            n_valid = valid_f.sum().item()
            print(f"  src={src_idx}: loss_pair={loss_pair:.4f}, valid_pixels={n_valid:.0f}")
        
        mv_loss = mv_loss / 2.0  # average over 2 pairs
        print(f"[MV Loss] {mv_loss:.4f}")
        
        # ============================================
        # 8) GT supervised loss (Huber, now on aligned scale)
        # ============================================
        L_gt = 0.0
        if use_gt_for_loss and n_valid_gt > 10:
            # Apply the same affine scale/offset to pred as we did to init
            # This brings pred into GT's metric space
            s_tensor = torch.tensor(s_scales, device=device, dtype=pred_depth_center.dtype).view(B, 1, 1, 1)
            t_tensor = torch.tensor(t_offsets, device=device, dtype=pred_depth_center.dtype).view(B, 1, 1, 1)
            pred_aligned = pred_depth_center * s_tensor + t_tensor
            
            print(f"[Affine applied to pred] s={s_scales}, t={t_offsets}")
            print(f"  pred_aligned min/max: {pred_aligned.min():.4f}, {pred_aligned.max():.4f}")
            print(f"  gt_center min/max: {gt_center.min():.4f}, {gt_center.max():.4f}")
            
            L_gt = masked_l1_huber(pred_aligned, gt_center, gt_valid, beta=0.1)
            L_gt = torch.nan_to_num(L_gt, nan=0.0, posinf=1e2, neginf=0.0)
            print(f"[L_gt] {L_gt:.4f} (Huber, affine-aligned)")

        
        # ============================================
        # 9) Initial depth smoothness (edge-aware)
        # ============================================
        L_init = edge_aware_smoothness(depth_init_aligned, rgb_center)
        print(f"[L_init (edge_smooth)] {L_init:.4f}")
        
        # ============================================
        # 10) Edge-aware smoothness on prediction
        # ============================================
        L_edge = edge_aware_smoothness(pred_depth_center, rgb_center)
        print(f"[L_edge] {L_edge:.4f}")
        
        # ============================================
        # 11) Normal smoothness
        # ============================================
        normals = normals_from_depth(pred_depth_center, batch_intr)
        L_norm = normal_smoothness_loss(normals)
        print(f"[L_norm] {L_norm:.4f}")
        
        # ============================================
        # 12) Combined loss
        # ============================================
        total_loss = lambda_mv*mv_loss + 1.0*L_gt + lambda_init*L_init + lambda_edge*L_edge + lambda_norm*L_norm
        print(f"[Total Loss] {total_loss:.4f} (mv_weighted={lambda_mv*mv_loss:.4f})")
        
        # ============================================
        # 13) Backward pass
        # ============================================
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        print(f"[Backward] completed, gradients updated")


In [None]:
# ============================================
# Load Depth Anything Small (if not loaded)
# ============================================
if 'depth_anything' not in dir():
    from transformers import pipeline
    print("[Loading] Depth Anything Small model...")
    depth_anything = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-small-hf")
    print("[Loaded] Depth Anything Small")
else:
    print("[Using] Existing Depth Anything pipeline")

# ============================================
# Example: run training with Depth Anything Small
# ============================================
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"[Device] Using: {device}")
    
    # Initialize model (fresh)
    model = Unet(in_channel=7, out_channel=1, base=32).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    
    # Train for 5 epochs
    print("\n" + "="*80)
    print("Starting training with Depth Anything Small model - 5 epochs")
    print("="*80)
    
    for epoch in range(5):
        print(f"\n{'='*80}")
        print(f"EPOCH {epoch+1}/5")
        print(f"{'='*80}")
        
        train_one_epoch(
            model=model,
            depth_anything_model=depth_anything,
            dataloader=dataloader,
            optimizer=optimizer,
            device=device,
            use_gt_for_loss=True,
            lambda_mv=0.001,  # Reduce MV loss weight for balance
            lambda_init=0.1,
            lambda_edge=0.05,
            lambda_norm=0.1,
            num_steps=5  # 5 batches per epoch
        )
    
    print("\n" + "="*80)
    print("Training finished!")
    print("="*80)


[Using] Existing Depth Anything pipeline
[Device] Using: cuda

Starting training with Depth Anything Small model - 5 epochs

EPOCH 1/5

[Step 0] B=3, V=5, H=540, W=960

[Step 0] B=3, V=5, H=540, W=960
[DepthAnything] min=0.0001, max=1.0000
[Unit detection] scale=0.0010, did_scale=True, med_scale=266750000.00
[GT mask] valid_pixels=1399672 / 1555200
[DepthAnything] min=0.0001, max=1.0000
[Unit detection] scale=0.0010, did_scale=True, med_scale=266750000.00
[GT mask] valid_pixels=1399672 / 1555200
[Affine align] scales=['-1.51', '-1.21', '-1.83'], offsets=['0.86', '0.88', '1.37']
[Affine align] scales=['-1.51', '-1.21', '-1.83'], offsets=['0.86', '0.88', '1.37']
[Prediction] pred_depth min=0.0001, max=1.5878
[Prediction] pred_depth min=0.0001, max=1.5878
[reproj debug SAFE] src=0 B=3 H=540 W=960
 X_src_world minZ: -201601.421875
 X_center_world_reproj minZ: 215.28631591796875
 uv min/max: 0.0 925.4099731445312 0.0 390.44287109375
 sampled_center min/max: 0.6472070813179016 1.580121994018

In [125]:
# Quick test of affine align function
test_init = torch.rand(2, 1, 10, 10) * 0.5  # [0, 0.5]
test_gt = torch.rand(2, 1, 10, 10) * 50.0  # [0, 50]
test_mask = torch.ones(2, 1, 10, 10, dtype=torch.bool)

try:
    aligned, s, t = affine_align_depth(test_init, test_gt, test_mask)
    print(f"Affine test SUCCESS: s={s}, t={t}")
    print(f"  init range: [{test_init.min():.4f}, {test_init.max():.4f}]")
    print(f"  gt range: [{test_gt.min():.4f}, {test_gt.max():.4f}]")
    print(f"  aligned range: [{aligned.min():.4f}, {aligned.max():.4f}]")
except Exception as e:
    print(f"Affine test FAILED: {e}")
    import traceback
    traceback.print_exc()


Affine test SUCCESS: s=[-0.3661380410194397, 0.516023576259613], t=[22.120376586914062, 23.312768936157227]
  init range: [0.0012, 0.4991]
  gt range: [0.0719, 49.7345]
  aligned range: [21.9399, 23.5703]
