# UNET

In [65]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [78]:
class ConvBlock(nn.Module):
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.net(x)

In [79]:
class Unet(nn.Module):
    def __init__(self, in_channel=7, out_channel=1, base=32):
        super().__init__()
        self.enc1 = ConvBlock(in_channel, base)
        self.enc2 = ConvBlock(base, base*2)
        self.enc3 = ConvBlock(base*2, base*4)
        self.pool = nn.MaxPool2d(2)
        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, stride=2)
        self.up1 = nn.ConvTranspose2d(base*2, base, 2, stride=2)
        self.dec2 = ConvBlock(base*4, base*2)
        self.dec1 = ConvBlock(base*2, base)
        self.out = nn.Conv2d(base, out_channel, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        
        d2 = self.up2(e3)
        d2 = torch.cat([d2, e2], dim=1)
        d2 = self.dec2(d2)
        d1 = self.up1(d2)
        d1 = torch.cat([d1, e1], dim=1)
        d1 = self.dec1(d1)
        out = self.out(d1)

        return out


# Dataloader

In [80]:
# dataset_gtav_exr.py
import os
import cv2
import json
import random
import torch
import numpy as np
from torch.utils.data import Dataset
import OpenEXR
import Imath


def load_exr_depth(path):
    exr = OpenEXR.InputFile(path)
    dw = exr.header()['dataWindow']
    width  = dw.max.x - dw.min.x + 1
    height = dw.max.y - dw.min.y + 1

    channel = 'Z' if 'Z' in exr.header()['channels'] else 'Y'
    pt = Imath.PixelType(Imath.PixelType.FLOAT)
    depth_str = exr.channel(channel, pt)

    depth = np.frombuffer(depth_str, dtype=np.float32)
    return depth.reshape((height, width))


def load_pose_json(path):
    with open(path, "r") as f:
        meta = json.load(f)

    extr = np.array(meta["extrinsic"], dtype=np.float32)
    fx, fy = meta["f_x"], meta["f_y"]
    cx, cy = meta["c_x"], meta["c_y"]

    K = np.array([
        [fx, 0,  cx],
        [0,  fy, cy],
        [0,   0,  1]
    ], dtype=np.float32)

    # Fix left-handed coordinate issues
    R = extr[:3, :3]
    if np.linalg.det(R) < 0:
        R[:, 2] *= -1
        extr[:3, :3] = R

    return extr, K


class GTAVEXRDataset(Dataset):
    def __init__(self, root, n_views=3, max_stride=3):
        assert n_views % 2 == 1, "n_views must be odd: 3,5,7..."
        self.n_views = n_views
        self.max_stride = max_stride
        self.root = root

        self.scenes = []

        scene_ids = sorted(os.listdir(root))
        for sid in scene_ids:
            spath = os.path.join(root, sid)
            if not os.path.isdir(spath):
                continue

            imgs   = sorted(os.listdir(os.path.join(spath, "images")))
            depths = sorted(os.listdir(os.path.join(spath, "depths")))
            poses  = sorted(os.listdir(os.path.join(spath, "poses")))

            self.scenes.append({
                "img":   [os.path.join(spath, "images", f) for f in imgs],
                "depth": [os.path.join(spath, "depths", f) for f in depths],
                "pose":  [os.path.join(spath, "poses", f) for f in poses],
                "len": len(imgs),
                "root": spath
            })

    def __len__(self):
        return sum(s["len"] for s in self.scenes)


    # ----------------------------------------------------------
    # Generalized symmetric sampling around center (Option 1)
    # ----------------------------------------------------------
    def sample_views(self, L):
        """
        Returns sorted list of n_views indices:
        symmetric around a random center.
        """
        half = self.n_views // 2
        # ensure center can shift both sides
        c = random.randint(half, L - half - 1)

        views = [c]

        for h in range(1, half + 1):
            # random stride for left/right
            k = random.randint(1, self.max_stride)

            left  = max(0,     c - k)
            right = min(L - 1, c + k)

            views.append(left)
            views.append(right)

        views = sorted(views)[:self.n_views]   # ensure correct count
        return views


    # ----------------------------------------------------------
    # __getitem__
    # ----------------------------------------------------------
    def __getitem__(self, _):
        scene = random.choice(self.scenes)
        L = scene["len"]

        # pick generalized multi-view indices
        idxs = self.sample_views(L)   # e.g. [c-k2, c-k1, c, c+k1, c+k2] for 5 views

        # Load RGB
        def load_rgb(path):
            img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
            return img.astype(np.float32) / 255.0

        rgb = np.stack([load_rgb(scene["img"][v]) for v in idxs])   # (V,H,W,3)
        depth = np.stack([load_exr_depth(scene["depth"][v]) for v in idxs])   # (V,H,W)

        # Load extrinsics/intrinsics
        extrinsics = []
        intrinsics = []

        for v in idxs:
            extr, K = load_pose_json(scene["pose"][v])
            extrinsics.append(extr)
            intrinsics.append(K)

        extrinsics = np.stack(extrinsics)   # (V,4,4)
        intrinsics = np.stack(intrinsics)   # (V,3,3)

        return {
            "rgb": torch.from_numpy(rgb).permute(0, 3, 1, 2),  # (V,3,H,W)
            "depth": torch.from_numpy(depth).float(),         # (V,H,W)
            "extrinsic": torch.from_numpy(extrinsics).float(),# (V,4,4)
            "intrinsic": torch.from_numpy(intrinsics).float() # (V,3,3)
        }


In [81]:
dataset = GTAVEXRDataset(root="dataset/GTAV_540", max_stride=3, n_views=5)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=3, shuffle=True)

# Test 1 batch
test_sample = None
for batch in dataloader:
    test_sample = batch
    print("RGB:", batch["rgb"].shape)           # (B, V, 3, H, W) (batch, view_count, channels, width, height)
    print("Depth:", batch["depth"].shape)       # (B, V, H, W)
    print("Extrinsic:", batch["extrinsic"].shape) # (B, V, 4, 4)
    print("Intrinsic:", batch["intrinsic"].shape) # (B, V, 3, 3)
    break

RGB: torch.Size([3, 5, 3, 540, 960])
Depth: torch.Size([3, 5, 540, 960])
Extrinsic: torch.Size([3, 5, 4, 4])
Intrinsic: torch.Size([3, 5, 3, 3])


# Utility functions

In [82]:
def calculate_ray_dirs(intrinsics, W, H):
    i, j = torch.meshgrid(
        torch.arange(W, dtype=torch.float32),
        torch.arange(H, dtype=torch.float32),
        indexing='xy'
    ) 

    dirs = torch.stack([
        (i - intrinsics[0, 2]) / intrinsics[0, 0], # [(X_i) - (fx)]/ (B, cx) 
        (j - intrinsics[1, 2]) / intrinsics[1, 1],
        torch.ones_like(i)
    ], -1) 

    dirs = dirs / torch.norm(dirs, dim=-1, keepdim=True)  # normalize

    return dirs  # (W,H,[dir_x, dir_y, dir_z])

# Test
W, H = test_sample['rgb'].shape[-1], test_sample['rgb'].shape[-2]
calculate_ray_dirs(test_sample["intrinsic"][0,1], W, H).shape

torch.Size([540, 960, 3])

In [83]:
def project_depth_to_camera_3d(depth_map, ray_dirs):
    return depth_map[..., None] * ray_dirs # return [X, Y, Z] 3D points in camera space

def project_3d_to_camera_2d(points_3d, intrinsic):
    """
    points_3d: (..., 3) tensor of 3D points in camera space
    intrinsic: (3, 3) camera intrinsic matrix
    Returns:
        points_2d: (..., 2) tensor of 2D pixel coordinates
    """
    fx = intrinsic[0, 0]
    fy = intrinsic[1, 1]
    cx = intrinsic[0, 2]
    cy = intrinsic[1, 2]

    x = points_3d[..., 0]
    y = points_3d[..., 1]
    z = points_3d[..., 2].clamp(min=1e-6)  # Prevent division by zero

    u = fx * (x / z) + cx
    v = fy * (y / z) + cy

    return torch.stack([u, v], dim=-1)

# Test
project_depth_to_camera_3d(
    test_sample['depth'][0,1], 
    calculate_ray_dirs(test_sample["intrinsic"][0,1], W, H)
).shape

torch.Size([540, 960, 3])

In [84]:
def camera_to_world(cam_point_3d, extrinsic):
    """
    cam_point_3d: (H, W, 3) camera-space 3D points
    extrinsic:   (4,4) world -> camera matrix
    Returns:     (H, W, 3) world-space points
    """

    R = extrinsic[:3, :3]      # world->cam rotation
    t = extrinsic[:3, 3]       # world->cam translation

    # Inverse transform:
    R_inv = R.T                # cam->world rotation
    t_inv = -R_inv @ t         # cam->world translation

    # reshape points: (H,W,3) -> (H*W,3)
    H, W = cam_point_3d.shape[:2]
    pts = cam_point_3d.reshape(-1, 3).T   # (3, HW)

    # apply transform
    pts_w = R_inv @ pts + t_inv[:, None]  # (3, HW)

    # reshape back
    return pts_w.T.reshape(H, W, 3)

def world_to_camera(world_point_3d, extrinsic):
    """
    world_point_3d: (H, W, 3) world-space 3D points
    extrinsic:      (4,4) world -> camera matrix
    Returns:        (H, W, 3) camera-space points
    """

    R = extrinsic[:3, :3]      # world->cam rotation
    t = extrinsic[:3, 3]       # world->cam translation

    # reshape points: (H,W,3) -> (H*W,3)
    H, W = world_point_3d.shape[:2]
    pts = world_point_3d.reshape(-1, 3).T   # (3, HW)

    # apply transform
    pts_c = R @ pts + t[:, None]  # (3, HW)

    # reshape back
    return pts_c.T.reshape(H, W, 3)

In [85]:
def sample_depth_bilinear(point_2d, depth_map):
    """
    point_2d: (N, 2) tensor of 2D pixel coordinates
    depth_map: (H, W) tensor of depth values
    Returns:
        sampled_depths: (N,) tensor of sampled depth values
    """
    H, W = depth_map.shape
    x = point_2d[:, 0]
    y = point_2d[:, 1]

    x0 = torch.floor(x).long().clamp(0, W - 1)
    x1 = (x0 + 1).clamp(0, W - 1)
    y0 = torch.floor(y).long().clamp(0, H - 1)
    y1 = (y0 + 1).clamp(0, H - 1)

    Ia = depth_map[y0, x0]
    Ib = depth_map[y1, x0]
    Ic = depth_map[y0, x1]
    Id = depth_map[y1, x1]

    wa = (x1.float() - x) * (y1.float() - y)
    wb = (x1.float() - x) * (y - y0.float())
    wc = (x - x0.float()) * (y1.float() - y)
    wd = (x - x0.float()) * (y - y0.float())

    sampled_depths = wa * Ia + wb * Ib + wc * Ic + wd * Id

    return sampled_depths

# Batched ultilities

In [None]:
import torch
import torch.nn.functional as F

EPS = 1e-6

# ---- helpers (slightly cleaned) ----
def _pixel_grid(H, W, device, dtype):
    # Returns pixel coords u (x) and v (y) shaped (H, W)
    v, u = torch.meshgrid(
        torch.arange(H, device=device, dtype=dtype),
        torch.arange(W, device=device, dtype=dtype),
        indexing='ij'
    )
    return u, v  # u->x, v->y

def batched_ray_dirs(intrinsics, H, W):
    """
    intrinsics: (B, V, 3, 3)
    returns: dirs (B, V, H, W, 3) normalized camera ray directions
    """
    device = intrinsics.device
    dtype = intrinsics.dtype
    B, V = intrinsics.shape[0], intrinsics.shape[1]

    u, v = _pixel_grid(H, W, device, dtype)   # (H, W) u=x, v=y
    u = u.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)
    v = v.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)

    fx = intrinsics[..., 0, 0].unsqueeze(-1).unsqueeze(-1)  # (B,V,1,1)
    fy = intrinsics[..., 1, 1].unsqueeze(-1).unsqueeze(-1)
    cx = intrinsics[..., 0, 2].unsqueeze(-1).unsqueeze(-1)
    cy = intrinsics[..., 1, 2].unsqueeze(-1).unsqueeze(-1)

    dirs_x = (u - cx) / (fx + EPS)   # (B,V,H,W)
    dirs_y = (v - cy) / (fy + EPS)
    dirs_z = torch.ones_like(dirs_x)

    dirs = torch.stack([dirs_x, dirs_y, dirs_z], dim=-1)  # (B,V,H,W,3)
    norm = torch.norm(dirs, dim=-1, keepdim=True)
    dirs = dirs / (norm + EPS)
    return dirs


def batched_depth_to_camera(depth, ray_dirs):
    # depth: (B,V,H,W), ray_dirs: (B,V,H,W,3)
    return depth.unsqueeze(-1) * ray_dirs  # (B,V,H,W,3)


def batched_camera_to_world(cam_pts, extrinsics):
    """
    cam_pts: (B, V, H, W, 3)
    extrinsics: (B, V, 4, 4)  (world -> camera)
    returns: world_pts (B, V, H, W, 3)
    """
    B, V, H, W, _ = cam_pts.shape
    R = extrinsics[..., :3, :3]         # (B,V,3,3) world->cam
    t = extrinsics[..., :3, 3]          # (B,V,3)

    R_inv = R.transpose(-1, -2)         # (B,V,3,3) cam->world
    t_inv = -torch.matmul(R_inv, t.unsqueeze(-1)).squeeze(-1)  # (B,V,3)

    pts = cam_pts.reshape(B, V, -1, 3)  # (B,V,HW,3)
    pts_t = pts.permute(0,1,3,2)        # (B,V,3,HW)
    world = torch.matmul(R_inv, pts_t) + t_inv.unsqueeze(-1)    # (B,V,3,HW)
    world = world.permute(0,1,3,2).reshape(B, V, H, W, 3)
    return world


def batched_world_to_camera(world_pts, extrinsics):
    """
    world_pts: (B, V, H, W, 3)
    extrinsics: (B, V, 4, 4)  (world -> camera)
    returns: cam_pts (B, V, H, W, 3)
    """
    B, V, H, W, _ = world_pts.shape
    R = extrinsics[..., :3, :3]    # (B,V,3,3)
    t = extrinsics[..., :3, 3]     # (B,V,3)

    pts = world_pts.reshape(B, V, -1, 3).permute(0,1,3,2)  # (B,V,3,HW)
    cam = torch.matmul(R, pts) + t.unsqueeze(-1)           # (B,V,3,HW)
    cam = cam.permute(0,1,3,2).reshape(B, V, H, W, 3)
    return cam


def batched_project_3d_to_2d(pts_3d, intrinsics):
    """
    pts_3d: (B, V, H, W, 3) in camera coords
    intrinsics: (B, V, 3, 3)
    returns: uv (B, V, H, W, 2), z (B,V,H,W)
    """
    x = pts_3d[..., 0]
    y = pts_3d[..., 1]
    z = pts_3d[..., 2].clamp(min=EPS)

    z = torch.clamp(z, min=1e-4)

    fx = intrinsics[..., 0, 0].unsqueeze(-1).unsqueeze(-1)
    fy = intrinsics[..., 1, 1].unsqueeze(-1).unsqueeze(-1)
    cx = intrinsics[..., 0, 2].unsqueeze(-1).unsqueeze(-1)
    cy = intrinsics[..., 1, 2].unsqueeze(-1).unsqueeze(-1)

    u = fx * (x / z) + cx
    v = fy * (y / z) + cy
    uv = torch.stack([u, v], dim=-1)
    return uv, z


def normalize_uv_for_grid_sample(uv, H, W):
    """
    uv: (..., 2) pixel coords with u in [0..W-1], v in [0..H-1]
    returns: grid coords in [-1,1] last-dim order (x,y) for grid_sample
    """
    u = uv[..., 0]
    v = uv[..., 1]

    nx = (u / (W - 1)) * 2 - 1
    ny = (v / (H - 1)) * 2 - 1
    return torch.stack([nx, ny], dim=-1)


# ================================================================
# UNIT DETECTION & AUTO-SCALING
# ================================================================
def camera_centers_from_extrinsics(extrinsics):
    """
    extrinsics: (B, V, 4, 4) world->camera
    returns: camera centers in world coords (B, V, 3)
    """
    R = extrinsics[..., :3, :3]   # (B,V,3,3)
    t = extrinsics[..., :3, 3]    # (B,V,3)
    R_inv = R.transpose(-1, -2)   # cam->world
    cam_centers = -torch.matmul(R_inv, t.unsqueeze(-1)).squeeze(-1)  # (B,V,3)
    return cam_centers


def detect_and_fix_depth_unit(depth_batch, extrinsics, threshold_scale=10.0, apply_fix=True):
    """
    Heuristic: if median depth >> median camera-translation magnitude,
    likely depth is in mm (or cm). We scale down by 1000 or 100 accordingly.
    Returns (depth_batch_scaled, scale_factor, did_scale_flag, med_scale, med_depth_mean, baseline_mean)
    """
    device = depth_batch.device
    B, V, H, W = depth_batch.shape

    # stats on depth center
    depth_center = depth_batch[:, 1]  # (B,H,W)
    med_depth = torch.median(depth_center.reshape(B, -1), dim=1).values  # (B,)

    # camera center distances (per sample)
    centers = camera_centers_from_extrinsics(extrinsics)  # (B,V,3)
    # baseline magnitude between src(0) and center(1)
    baseline = torch.norm(centers[:, 0] - centers[:, 1], dim=-1)  # (B,)
    # avoid zero baseline
    baseline = baseline + 1e-6

    scale_factors = med_depth / baseline  # if >> threshold -> likely depth in mm
    med_scale = float(torch.median(scale_factors).item())

    # Decide conversion
    scale = 1.0
    did_scale = False
    if med_scale > threshold_scale:
        # typical: depth in mm -> divide by 1000
        if med_scale > 1000:
            scale = 1.0 / 1000.0
        else:
            # try 1000 first
            scale = 1.0 / 1000.0
        if apply_fix:
            depth_batch = depth_batch * scale
            did_scale = True

    return depth_batch, scale, did_scale, med_scale, float(med_depth.mean().item()), float(baseline.mean().item())


def reprojection_pair_to_center(depth_batch, intrinsics, extrinsics,
                                 center_idx=1, src_idx=0, center_depth_override=None):
    """
    Reproject source view to center view via 3D world points.
    
    Args:
        depth_batch: (B, V, H, W) depth maps
        intrinsics: (B, V, 3, 3) camera intrinsics
        extrinsics: (B, V, 4, 4) camera extrinsics (world->camera)
        center_idx: index of center view
        src_idx: index of source view to reproject
        center_depth_override: optional (B, H, W) predicted depth for center view
    
    Returns:
        X_src_world: (B, H, W, 3) 3D points from source view
        X_center_world: (B, H, W, 3) reprojected 3D points in center view
        valid: (B, H, W) valid pixel mask
        uv: (B, H, W, 2) projected pixel coordinates
    """
    B, V, H, W = depth_batch.shape
    EPS = 1e-4
    
    # Compute ray directions for all views
    ray_dirs = batched_ray_dirs(intrinsics, H, W)
    
    # Backproject source depth to 3D camera space
    depth_src = torch.clamp(depth_batch[:, src_idx], min=EPS, max=1e6).unsqueeze(-1)
    cam_pts_src = depth_src * ray_dirs[:, src_idx]
    
    # Transform source points to world space
    X_src_world = batched_camera_to_world(
        cam_pts_src.unsqueeze(1), 
        extrinsics[:, src_idx:src_idx+1]
    ).squeeze(1)
    
    # Transform world points to center camera space
    X_src_in_center = batched_world_to_camera(
        X_src_world.unsqueeze(1),
        extrinsics[:, center_idx:center_idx+1]
    ).squeeze(1)
    
    # Clamp depth (z) to positive values (avoid in-place ops)
    X_src_in_center = torch.cat([
        X_src_in_center[..., :2],
        X_src_in_center[..., 2:3].clamp(min=EPS)
    ], dim=-1)
    
    # Project to 2D pixel coordinates in center view
    uv, z_proj = batched_project_3d_to_2d(
        X_src_in_center.unsqueeze(1),
        intrinsics[:, center_idx:center_idx+1]
    )
    uv = uv.squeeze(1)
    z_proj = z_proj.squeeze(1)
    
    # Clamp UV to image bounds
    uv = torch.stack([
        uv[..., 0].clamp(0.0, W - 1.0),
        uv[..., 1].clamp(0.0, H - 1.0)
    ], dim=-1)
    
    # Sample depth at projected coordinates
    grid = normalize_uv_for_grid_sample(uv, H, W)
    depth_center = (center_depth_override if center_depth_override is not None 
                   else depth_batch[:, center_idx])
    depth_center = torch.clamp(depth_center, min=EPS, max=1e6).unsqueeze(1)
    
    sampled_depth = F.grid_sample(
        depth_center, grid, 
        mode='bilinear', 
        padding_mode='zeros', 
        align_corners=True
    ).squeeze(1)
    
    # Sample ray directions at projected coordinates
    ray_center = ray_dirs[:, center_idx].permute(0, 3, 1, 2)
    sampled_rays = F.grid_sample(
        ray_center, grid,
        mode='bilinear',
        padding_mode='zeros',
        align_corners=True
    ).permute(0, 2, 3, 1)
    
    # Backproject sampled center depth to world space
    cam_pts_center = sampled_depth.unsqueeze(-1) * sampled_rays
    X_center_world = batched_camera_to_world(
        cam_pts_center.unsqueeze(1),
        extrinsics[:, center_idx:center_idx+1]
    ).squeeze(1)
    
    # Compute valid mask
    in_bounds = (uv[..., 0] >= 0) & (uv[..., 0] < W) & \
                (uv[..., 1] >= 0) & (uv[..., 1] < H)
    valid = in_bounds & (z_proj > 0) & (sampled_depth > 0)
    
    return X_src_world, X_center_world, valid, uv


# Train loop

In [None]:
# -------------------------
# Wrapper: DepthAnything forward on RGB batch
# -------------------------
def depth_anything_forward(rgb_batch, depth_anything_model):
    """
    rgb_batch: (B, V, 3, H, W) in [0, 1] float32
    Returns: depth_init_all (B, V, H, W) normalized to [0, 1]
    """
    from PIL import Image
    import numpy as np
    
    device = rgb_batch.device
    B, V, C, H, W = rgb_batch.shape
    
    depth_init_all = []
    
    for b in range(B):
        depth_batch_b = []
        for v in range(V):
            # Extract single RGB frame: (3, H, W) -> PIL Image
            rgb_frame = rgb_batch[b, v].permute(1, 2, 0)  # (H, W, 3)
            rgb_frame_np = (rgb_frame * 255).clamp(0, 255).byte().cpu().numpy()
            
            # Convert numpy to PIL Image
            pil_image = Image.fromarray(rgb_frame_np.astype(np.uint8), mode='RGB')
            
            # Run Depth Anything inference
            with torch.no_grad():
                result = depth_anything_model(pil_image)
            
            # result is a dict with 'depth' key (PIL Image)
            depth_pred = result['depth']
            
            # Convert PIL Image to numpy
            if isinstance(depth_pred, Image.Image):
                depth_pred = np.array(depth_pred)
            
            # Normalize to [0, 1]
            depth_pred = depth_pred.astype(np.float32)
            depth_min, depth_max = depth_pred.min(), depth_pred.max()
            if depth_max > depth_min:
                depth_pred = (depth_pred - depth_min) / (depth_max - depth_min)
            else:
                depth_pred = np.ones_like(depth_pred) * 0.5
            
            # Convert back to tensor
            depth_pred = torch.from_numpy(depth_pred).float().to(device)  # (H, W)
            depth_batch_b.append(depth_pred)
        
        depth_init_all.append(torch.stack(depth_batch_b, dim=0))  # (V, H, W)
    
    depth_init_all = torch.stack(depth_init_all, dim=0)  # (B, V, H, W)
    return depth_init_all

In [None]:
# ===========================
# Loss Functions
# ===========================

def affine_align_depth(depth_init, gt_depth, mask, eps=1e-6):
    """Compute affine alignment: s, t such that s*depth_init + t ≈ gt_depth"""
    B = depth_init.shape[0]
    s_list, t_list = [], []
    depth_aligned = torch.zeros_like(depth_init)
    
    for b in range(B):
        m = mask[b, 0].reshape(-1)
        if m.sum() < 10:
            s_list.append(1.0)
            t_list.append(0.0)
            depth_aligned[b] = depth_init[b]
            continue
        
        d = depth_init[b, 0].reshape(-1)[m]
        g = gt_depth[b, 0].reshape(-1)[m]
        A = torch.stack([d, torch.ones_like(d)], dim=1)
        g_col = g.unsqueeze(1)
        
        try:
            x = torch.linalg.lstsq(A, g_col).solution
            s, t = float(x[0].item()), float(x[1].item())
        except:
            s, t = 1.0, 0.0
        
        s_list.append(s)
        t_list.append(t)
        depth_aligned[b, 0] = depth_init[b, 0] * s + t
    
    return depth_aligned, s_list, t_list


def edge_aware_smoothness(depth, rgb):
    """Edge-aware smoothness: smooth depth except at RGB edges"""
    grad_depth_x = torch.abs(depth[:, :, :, :-1] - depth[:, :, :, 1:])
    grad_depth_y = torch.abs(depth[:, :, :-1, :] - depth[:, :, 1:, :])
    grad_rgb_x = torch.mean(torch.abs(rgb[:, :, :, :-1] - rgb[:, :, :, 1:]), dim=1, keepdim=True)
    grad_rgb_y = torch.mean(torch.abs(rgb[:, :, :-1, :] - rgb[:, :, 1:, :]), dim=1, keepdim=True)
    return torch.mean(grad_depth_x * torch.exp(-grad_rgb_x)) + torch.mean(grad_depth_y * torch.exp(-grad_rgb_y))


def normals_from_depth(depth, intrinsics):
    """Compute surface normals from depth"""
    B = depth.shape[0]
    intr_center = intrinsics[:, 0]
    fx = intr_center[:, 0, 0].view(B, 1, 1, 1)
    fy = intr_center[:, 1, 1].view(B, 1, 1, 1)
    
    dz_dx = (depth[:, :, :, 1:] - depth[:, :, :, :-1]) / (fx + 1e-6)
    dz_dy = (depth[:, :, 1:, :] - depth[:, :, :-1, :]) / (fy + 1e-6)
    dz_dx = torch.cat([dz_dx, dz_dx[:, :, :, -1:]], dim=3)
    dz_dy = torch.cat([dz_dy, dz_dy[:, :, -1:, :]], dim=2)
    
    normals = torch.stack([-dz_dx, -dz_dy, torch.ones_like(dz_dx)], dim=-1)
    norm = torch.norm(normals, dim=-1, keepdim=True) + 1e-6
    return normals / norm


def normal_smoothness_loss(normals):
    """Encourage smooth normals"""
    n = normals
    return torch.mean(torch.abs(n[:, :, :, :-1, :] - n[:, :, :, 1:, :])) + \
           torch.mean(torch.abs(n[:, :, :-1, :, :] - n[:, :, 1:, :, :]))


def masked_l1_huber(pred, gt, mask, eps=1e-6, beta=0.1):
    """Masked Huber loss"""
    valid = mask.float()
    pred_safe = torch.clamp(pred, min=0.0, max=1e4)
    gt_safe = torch.clamp(gt, min=0.0, max=1e4)
    huber = torch.nn.SmoothL1Loss(reduction='none', beta=beta)
    loss = huber(pred_safe, gt_safe)
    loss = torch.nan_to_num(loss, nan=0.0, posinf=1e2, neginf=0.0)
    loss = torch.clamp(loss, min=0.0, max=1e2)
    masked_loss = (loss * valid).sum() / (valid.sum() + eps)
    return torch.nan_to_num(masked_loss, nan=0.0, posinf=1e2, neginf=0.0)


# ===========================
# Training Loop
# ===========================

def train_one_epoch(model, depth_anything_model, dataloader, optimizer, device, 
                    epoch=0, num_epochs=1, use_gt_for_loss=True, 
                    lambda_mv=0.001, lambda_init=0.1, lambda_edge=0.05, 
                    lambda_norm=0.1, num_steps=1):
    """
    Training epoch with multi-view consistency and Huber loss
    """
    model.train()
    GT_INVALID_THRESH = 100.0
    EPS = 1e-6
    
    for step, batch in enumerate(dataloader):
        if step >= num_steps:
            break
        
        batch_rgb = batch['rgb'].to(device)
        batch_depth = batch['depth'].to(device)
        batch_intr = batch['intrinsic'].to(device)
        batch_extr = batch['extrinsic'].to(device)
        
        B, V, _, H, W = batch_rgb.shape
        center_idx = V // 2
        
        print(f"\n[Epoch {epoch+1}/{num_epochs}] Step {step+1}/{num_steps}")
        
        # DepthAnything inference
        depth_init_all = depth_anything_forward(batch_rgb, depth_anything_model)
        depth_init_all = torch.clamp(depth_init_all, min=1e-4, max=1.0)
        
        # Detect & fix depth unit mismatch
        depth_init_scaled, _, _, _, _, _ = \
            detect_and_fix_depth_unit(batch_depth, batch_extr, threshold_scale=10.0, apply_fix=True)
        
        # Prepare center view depth
        depth_init_center = depth_init_all[:, center_idx].unsqueeze(1)
        gt_center = depth_init_scaled[:, center_idx].unsqueeze(1)
        gt_valid = (~torch.isinf(gt_center)) & (gt_center < GT_INVALID_THRESH) & (gt_center > 0)
        n_valid_gt = gt_valid.float().sum().item()
        
        # Affine alignment
        try:
            depth_init_aligned, s_scales, t_offsets = affine_align_depth(depth_init_center, gt_center, gt_valid)
        except:
            depth_init_aligned = depth_init_center.clone()
            s_scales = [1.0] * B
            t_offsets = [0.0] * B
        
        # Prepare UNet input
        ray_dirs_batch = batched_ray_dirs(batch_intr, H, W)
        rgb_center = batch_rgb[:, center_idx]
        ray_dirs_center = ray_dirs_batch[:, center_idx].permute(0, 3, 1, 2)
        model_input = torch.cat([rgb_center, depth_init_aligned, ray_dirs_center], dim=1)
        
        # Forward pass
        depth_delta = model(model_input)
        pred_depth_center = depth_init_aligned + depth_delta
        pred_depth_center = torch.clamp(pred_depth_center, min=1e-4, max=1e6)
        
        # Multi-view reprojection loss
        mv_loss = 0.0
        for src_idx in [0, 2]:
            X_src, X_center, valid, _ = reprojection_pair_to_center(
                depth_init_scaled, batch_intr, batch_extr,
                center_idx=center_idx, src_idx=src_idx,
                center_depth_override=pred_depth_center.squeeze(1),
                debug=False
            )
            error_3d = torch.norm(X_src - X_center, dim=-1)
            mv_loss += (error_3d * valid.float()).sum() / (valid.float().sum() + EPS)
        mv_loss = mv_loss / 2.0
        
        # GT supervised loss (Huber with affine alignment)
        L_gt = 0.0
        if use_gt_for_loss and n_valid_gt > 10:
            s_t = torch.tensor(s_scales, device=device, dtype=pred_depth_center.dtype).view(B, 1, 1, 1)
            t_t = torch.tensor(t_offsets, device=device, dtype=pred_depth_center.dtype).view(B, 1, 1, 1)
            pred_aligned = pred_depth_center * s_t + t_t
            L_gt = masked_l1_huber(pred_aligned, gt_center, gt_valid, beta=0.1)
        
        # Regularization losses
        L_init = edge_aware_smoothness(depth_init_aligned, rgb_center)
        L_edge = edge_aware_smoothness(pred_depth_center, rgb_center)
        normals = normals_from_depth(pred_depth_center, batch_intr)
        L_norm = normal_smoothness_loss(normals)
        
        # Combined loss
        total_loss = lambda_mv*mv_loss + 1.0*L_gt + lambda_init*L_init + lambda_edge*L_edge + lambda_norm*L_norm
        
        # Backward pass
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        # Log
        print(f"  Total: {total_loss:.4f} | MV: {lambda_mv*mv_loss:.4f} | GT: {L_gt:.4f} | "
              f"Init: {L_init:.4f} | Edge: {L_edge:.4f} | Norm: {L_norm:.4f}")


In [133]:
# ===========================
# Load Depth Anything Model
# ===========================
if 'depth_anything' not in dir():
    from transformers import pipeline
    print("[Loading] Depth Anything Small...")
    depth_anything = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-small-hf")
    print("[Loaded] Depth Anything Small")
else:
    print("[Using] Existing Depth Anything pipeline")

# ===========================
# Training Configuration
# ===========================
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"[Device] {device}")

# Initialize model
model = Unet(in_channel=7, out_channel=1, base=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training parameters
num_epochs = 1
steps_per_epoch = 1

print(f"\n{'='*80}")
print(f"Training: {num_epochs} epochs × {steps_per_epoch} steps")
print(f"{'='*80}\n")

# Train
for epoch in range(num_epochs):
    train_one_epoch(
        model=model,
        depth_anything_model=depth_anything,
        dataloader=dataloader,
        optimizer=optimizer,
        device=device,
        epoch=epoch,
        num_epochs=num_epochs,
        use_gt_for_loss=True,
        lambda_mv=0.001,  # Reduce MV loss weight to balance with GT
        lambda_init=0.1,
        lambda_edge=0.05,
        lambda_norm=0.1,
        num_steps=steps_per_epoch
    )

print(f"\n{'='*80}")
print("Training completed!")
print(f"{'='*80}")


[Using] Existing Depth Anything pipeline
[Device] cuda

Training: 1 epochs × 1 steps


[Epoch 1/1] Step 1/1

[Epoch 1/1] Step 1/1
  Total: 118.6900 | MV: 118.1199 | GT: 0.5697 | Init: 0.0030 | Edge: 0.0032 | Norm: 0.0000
  Total: 118.6900 | MV: 118.1199 | GT: 0.5697 | Init: 0.0030 | Edge: 0.0032 | Norm: 0.0000

Training completed!

Training completed!
