Env setting

In [None]:
# Pin `datasets` to the 3.6 release line: the Hugging Face NYU Depth V2
# datasets loaded later in this notebook are reported to need this version.
!pip install datasets==3.6
# ptflops is only needed for the MACs / parameter-count comparison.
# NOTE(review): ptflops is unpinned -- consider pinning a version for reproducibility.
!pip install ptflops

import

In [None]:
from google.colab import drive
import os

# Mount Google Drive at /content/drive (Colab-only API).
drive.mount('/content/drive')

# save path: directory on the mounted drive that receives model checkpoints;
# it is created on first use.
save_dir = '/content/drive/MyDrive/DepthProject'
os.makedirs(save_dir, exist_ok=True)

print(f"Model save path: {save_dir}")

In [None]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF
from PIL import Image
import timm
from datasets import load_dataset
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import time


Config

In [None]:
# Central experiment configuration read by the dataset, model and training code.
CONFIG = {
    "backbone": "swin_tiny_patch4_window7_224",  # timm model identifier
    "batch_size": 128,
    "lr": 1e-3,             # initial learning rate handed to the optimizer
    "epochs": 40,
    "image_size": (224, 224),  # size every RGB / depth sample is resized to
    "max_depth": 10.0,      # Dataset standard
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "seed": 42,             # base seed for all random number generators
}

# Seed every RNG the notebook relies on so that weight initialisation and
# data shuffling are repeatable after "Restart & Run All".
# torch.manual_seed also seeds the CUDA generators.
random.seed(CONFIG["seed"])
np.random.seed(CONFIG["seed"])
torch.manual_seed(CONFIG["seed"])

print(f"device: {CONFIG['device']}")


Data

In [None]:
class HF_NYUDepthDataset(Dataset):
    """Torch dataset wrapper that turns Hugging Face NYU-style samples into (image, depth) tensors.

    Each item is a tuple ``(image, depth)``:
      * ``image`` -- RGB tensor resized to ``CONFIG['image_size']`` and normalised
        with the mean/std constants below (the usual ImageNet statistics).
      * ``depth`` -- depth tensor resized with nearest-neighbour interpolation,
        rescaled by the heuristic in ``__getitem__`` and clamped to
        ``[0, CONFIG['max_depth']]``.

    Args:
        hf_dataset: indexable dataset whose samples are mappings containing an
            ``'image'`` or ``'rgb'`` entry and a ``'depth_map'`` or ``'depth'`` entry.
        transform: stored on the instance but never applied by this class.
        is_train: stored on the instance but never read by this class.
    """
    def __init__(self, hf_dataset, transform=None, is_train=True):
        self.dataset = hf_dataset
        self.transform = transform
        self.is_train = is_train
        # Reference to the module-level CONFIG dict (not a copy), so later
        # edits to CONFIG are visible here.
        self.config = CONFIG
    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, depth)`` tensors.

        Raises:
            KeyError: if the sample has neither an image key nor a depth key
                among the supported names.
        """
        sample = self.dataset[idx]

        # Image/RGB, Depth/Depth_map
        if 'image' in sample:
            image = sample['image']
        elif 'rgb' in sample:
            image = sample['rgb']
        else:
            raise KeyError(f"Image not found. Keys: {sample.keys()}")

        if 'depth_map' in sample:
            depth = sample['depth_map'] # For 'Raw' Dataset in NYU-V2
        elif 'depth' in sample:
            depth = sample['depth']     # For 'Standard'
        else:
            raise KeyError(f"Depth not found. Keys: {sample.keys()}")

        # Turn List/Numpy to Tensor
        # Nested Python lists are first converted to a float32 array.

        if isinstance(depth, list):
            depth = np.array(depth, dtype=np.float32)

        if isinstance(depth, np.ndarray):
            depth = torch.from_numpy(depth).float()
            # A 2-D map gets a leading channel axis so it can be resized like an image.
            if depth.ndim == 2:
                depth = depth.unsqueeze(0)

        # Resize
        # Bilinear for RGB; nearest for depth so no new depth values are
        # interpolated across object boundaries.
        image = TF.resize(image, self.config['image_size'], interpolation=TF.InterpolationMode.BILINEAR)
        depth = TF.resize(depth, self.config['image_size'], interpolation=TF.InterpolationMode.NEAREST)

        image = TF.to_tensor(image) # PIL 0-255 to 0-1 Tensor
        image = TF.normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # Depth that was not already a tensor (e.g. still a PIL image) is converted here.
        if not isinstance(depth, torch.Tensor):
            depth = TF.to_tensor(depth)

        # Scaling
        # The unit of the depth map is guessed per sample from its maximum value.
        # NOTE(review): this is a per-sample heuristic; a metric depth map whose
        # farthest point is closer than 1.1 would be multiplied by 10, and one
        # whose maximum lies between 20 and 100 is divided by 10 -- confirm
        # against the actual value ranges of the datasets in use.
        current_max = depth.max().item()

        if current_max > 100.0:
            # mm
            depth = depth / 1000.0
        elif current_max > 20.0:

            depth = depth / 10.0
        elif current_max < 1.1:
            # 0-1
            depth = depth * 10.0
        # clamp
        depth = torch.clamp(depth, 0, self.config['max_depth'])

        return image, depth
print("Downloading HuggingFace Dataset...")

# NOTE(review): trust_remote_code=True executes the dataset repository's loading
# script, and verification_mode="no_checks" skips integrity checks.
Train_data = load_dataset(
    "sayakpaul/nyu_depth_v2",
    split="train",
    trust_remote_code=True,
    verification_mode="no_checks",
    num_proc=8,
)

# For a quick smoke test, train on a subset instead, e.g.
# Train_data.select(range(10000)).

Val_data = load_dataset(
    "sayakpaul/nyu_depth_v2",
    split="validation",
    trust_remote_code=True,
    verification_mode="no_checks",
    num_proc=8,
)

train_dataset = HF_NYUDepthDataset(Train_data, is_train=True)
val_dataset = HF_NYUDepthDataset(Val_data, is_train=False)

# Training batches are shuffled and the incomplete last batch is dropped;
# validation keeps every sample in its original order.
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=6,
    pin_memory=True,
    persistent_workers=True,
    drop_last=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=6,
    pin_memory=True,
    persistent_workers=True,
)

print(f"Train set: {len(train_dataset)} images, Val set: {len(val_dataset)} images")

Model

In [None]:
class ScratchFusionBlock(nn.Module):
    """Decoder stage: upsample a coarse map, add a skip map, refine the sum.

    The refinement is two consecutive 3x3 conv -> BatchNorm -> ReLU stages that
    keep the channel count at ``out_channels``.
    """

    def __init__(self, out_channels):
        super().__init__()
        stages = []
        for _ in range(2):
            stages += [
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order are kept so the state-dict keys stay
        # `layer.0` ... `layer.5`.
        self.layer = nn.Sequential(*stages)

    def forward(self, x, skip):
        """Fuse the coarse map ``x`` with the finer map ``skip``.

        ``x`` is doubled in resolution first and, if it still does not match
        ``skip`` spatially, resampled again to the exact size of ``skip``.
        """
        target_hw = skip.shape[-2:]
        upsampled = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
        if upsampled.shape[-2:] != target_hw:
            upsampled = F.interpolate(upsampled, size=target_hw, mode="bilinear", align_corners=True)
        fused = upsampled + skip
        return self.layer(fused)

class LightweightDPT(nn.Module):
    """Encoder-decoder depth network: timm feature backbone + top-down fusion decoder.

    The last four backbone feature maps are projected to a common width, fused
    from coarsest to finest with ``ScratchFusionBlock`` and decoded by a small
    convolutional head. The head ends in a sigmoid, so the returned depth lies
    in ``[0, max_depth]``.

    Args:
        backbone_name: timm model name, created with ``features_only=True``.
        max_depth: value the sigmoid output is multiplied by.

    Raises:
        ValueError: if the backbone exposes fewer than four feature levels.

    Note:
        Channel counts come from the backbone's ``feature_info`` metadata, and
        channels-last (NHWC) feature maps are converted to NCHW in ``forward``.
        Previously the counts were read from ``shape[1]`` of a dummy forward
        pass. For a backbone that returns NHWC features that axis is the
        height, so the projection layers were built over the wrong axis.
        Checkpoints saved from such a model have differently shaped
        ``projs.*`` weights and will not load into this version.
    """

    def __init__(self, backbone_name, max_depth=10.0):
        super().__init__()
        self.max_depth = max_depth

        # 1. Backbone (Swin / ResNet)
        self.backbone = timm.create_model(backbone_name, pretrained=True, features_only=True)

        for param in self.backbone.parameters():
            param.requires_grad = True

        # Reading the channel counts from metadata is independent of the
        # tensor layout. It also avoids a forward pass on random data in train
        # mode, which would update pretrained BatchNorm running statistics.
        all_channels = list(self.backbone.feature_info.channels())
        if len(all_channels) < 4:
            raise ValueError(
                f"Backbone '{backbone_name}' exposes {len(all_channels)} feature levels; "
                "at least 4 are required."
            )
        self.feat_channels = all_channels[-4:]

        # 2. Decoder Components
        embed_dim = 128

        # 1x1 convolutions that bring every feature level to `embed_dim` channels.
        self.projs = nn.ModuleList([
            nn.Conv2d(c, embed_dim, 1) for c in self.feat_channels
        ])

        self.fusion_blocks = nn.ModuleList([
            ScratchFusionBlock(embed_dim),
            ScratchFusionBlock(embed_dim),
            ScratchFusionBlock(embed_dim),
        ])

        # Two x2 upsampling steps inside the head raise the resolution of the
        # finest fused map by a factor of four.
        self.head = nn.Sequential(
            nn.Conv2d(embed_dim, 64, 3, 1, 1),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(64, 32, 3, 1, 1),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(32, 1, 1),
            nn.Sigmoid()

        )

    @staticmethod
    def _to_nchw(feat, channels):
        """Return ``feat`` in NCHW layout, permuting it when it is channels-last.

        A 4-D map is treated as NHWC when its last axis matches the expected
        channel count and its second axis does not.
        """
        if feat.ndim == 4 and feat.shape[1] != channels and feat.shape[-1] == channels:
            return feat.permute(0, 3, 1, 2).contiguous()
        return feat

    def forward(self, x):
        """Predict a depth map with the same spatial size as ``x``."""
        features = self.backbone(x)[-4:]
        features = [self._to_nchw(f, c) for f, c in zip(features, self.feat_channels)]
        projs = [p(f) for p, f in zip(self.projs, features)]

        # Top-down pathway: start at the coarsest level and merge in finer ones.
        out = projs[3]
        out = self.fusion_blocks[0](out, projs[2])
        out = self.fusion_blocks[1](out, projs[1])
        out = self.fusion_blocks[2](out, projs[0])

        depth = self.head(out)

        # Safety net for inputs whose size is not recovered exactly by the head.
        if depth.shape[-2:] != x.shape[-2:]:
            depth = F.interpolate(depth, size=x.shape[-2:], mode="bilinear", align_corners=True)

        # Sigmoid output 0~1, then multiply by max_depth to get predicted depth 0~10m
        return depth * self.max_depth

Loss

In [None]:
class SILogLoss(nn.Module):
    """Scale-invariant logarithmic (SILog) depth loss.

    With ``g = log(input) - log(target)`` over the valid pixels, the loss is
    ``scale * sqrt(var(g) + mean_weight * mean(g)**2)``.

    Args:
        mean_weight: weight of the squared mean log-error term.
        scale: constant factor applied to the final value.
        min_depth: lower clamp applied to both tensors before the logarithm.
    """

    def __init__(self, mean_weight=0.15, scale=10.0, min_depth=1e-3):
        super().__init__()
        self.name = "SILog"
        self.mean_weight = mean_weight
        self.scale = scale
        self.min_depth = min_depth

    def forward(self, input, target, mask=None):
        """Return the scalar loss.

        Args:
            input: predicted depth.
            target: ground-truth depth, same shape as ``input``.
            mask: optional boolean tensor selecting the pixels to evaluate.

        Returns:
            A scalar tensor. When fewer than two pixels are selected the
            variance is undefined and a zero is returned instead.
        """
        if mask is not None:
            input = input[mask]
            target = target[mask]

        if input.numel() < 2:
            # The zero stays connected to `input` (and therefore to the
            # model), so backward() still produces zero-valued gradients for
            # every parameter. A fresh leaf tensor would leave all gradients
            # unset, and an AMP GradScaler step is expected to fail then.
            return input.sum() * 0.0

        log_input = torch.log(torch.clamp(input, min=self.min_depth))
        log_target = torch.log(torch.clamp(target, min=self.min_depth))

        g = log_input - log_target

        dg = torch.var(g) + self.mean_weight * torch.pow(torch.mean(g), 2)
        return self.scale * torch.sqrt(dg)

Metrics

In [None]:
def compute_errors(gt, pred):
    """Compute standard depth-estimation metrics over matching tensors.

    Args:
        gt: ground-truth depths (callers pass strictly positive values).
        pred: predicted depths, same shape as ``gt``.

    Returns:
        Tuple ``(abs_rel, rmse, a1, a2, a3)`` of scalar tensors:
          * abs_rel -- mean of ``|gt - pred| / gt``
          * rmse    -- root of the mean squared error
          * a1/a2/a3 -- fraction of elements whose ratio
            ``max(gt / pred, pred / gt)`` is below 1.25, 1.25**2 and 1.25**3
    """
    ratio = torch.max((gt / pred), (pred / gt))
    a1 = (ratio < 1.25).float().mean()
    a2 = (ratio < 1.25 ** 2).float().mean()
    a3 = (ratio < 1.25 ** 3).float().mean()

    rmse = torch.sqrt(((gt - pred) ** 2).mean())

    abs_rel = torch.mean(torch.abs(gt - pred) / gt)

    return abs_rel, rmse, a1, a2, a3

def validate_model(model, loader, device, min_depth=0.1, max_depth=10.0):
    """Evaluate ``model`` on ``loader`` and return batch-averaged depth metrics.

    Args:
        model: network mapping an image batch to a depth batch.
        loader: iterable yielding ``(images, depths)`` batches.
        device: device the batches are moved to.
        min_depth: ground-truth pixels at or below this value are ignored.
        max_depth: ground-truth pixels at or above this value are ignored;
            predictions are also clamped to this maximum.

    Returns:
        Dict with keys ``abs_rel``, ``rmse``, ``a1``, ``a2``, ``a3``. Every
        value is the mean of the per-batch metric over batches that contain at
        least one valid pixel. If no batch qualifies, all values are NaN.

    NOTE(review): metrics are averaged per batch, so a smaller final batch
    weighs as much as a full one.
    """
    model.eval()

    metric_names = ('abs_rel', 'rmse', 'a1', 'a2', 'a3')
    metrics = {name: 0.0 for name in metric_names}
    count = 0

    with torch.no_grad():
        for images, depths in loader:
            images = images.to(device)
            depths = depths.to(device)

            preds = model(images)
            # A strictly positive lower bound keeps the ratio metrics finite.
            preds = torch.clamp(preds, min=1e-3, max=max_depth)

            mask = (depths > min_depth) & (depths < max_depth)

            if mask.sum() > 0:

                valid_gt = depths[mask]
                valid_pred = preds[mask]

                abs_rel, rmse, a1, a2, a3 = compute_errors(valid_gt, valid_pred)

                metrics['abs_rel'] += abs_rel.item()
                metrics['rmse'] += rmse.item()
                metrics['a1'] += a1.item()
                metrics['a2'] += a2.item()
                metrics['a3'] += a3.item()
                count += 1

    if count == 0:
        # Nothing could be evaluated: report NaN instead of dividing by zero.
        return {name: float('nan') for name in metric_names}

    for k in metrics.keys():
        metrics[k] /= count

    return metrics

Comparison

In [None]:
def count_parameters(model):
    """Return how many scalar weights of ``model`` have ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Transformer variant: Swin-Tiny backbone.
swin_model = LightweightDPT(
    backbone_name="swin_tiny_patch4_window7_224",
    max_depth=CONFIG['max_depth'],
)

# Convolutional variant with the same decoder: ResNet-34 backbone.
resnet_model = LightweightDPT(
    backbone_name="resnet34",
    max_depth=CONFIG['max_depth'],
)

In [None]:

from ptflops import get_model_complexity_info

# Complexity is measured for a single 3x224x224 input; no gradients are needed.
with torch.no_grad():
    swin_macs, swin_params = get_model_complexity_info(
        swin_model, (3, 224, 224), as_strings=True, print_per_layer_stat=False
    )
    Res_macs, Res_params = get_model_complexity_info(
        resnet_model, (3, 224, 224), as_strings=True, print_per_layer_stat=False
    )

print("Swin_MACs:", swin_macs)
print("Swin_Params:", swin_params)
print("Res_MACs:", Res_macs)
print("Res_Params:", Res_params)

In [None]:
# Replace both models with their torch.compile-wrapped versions.
# Re-running this cell wraps the already wrapped models again.
# NOTE(review): the wrapper's state_dict keys are reported to carry an
# `_orig_mod.` prefix -- confirm that checkpoints saved from these objects
# load into plain (uncompiled) models.
swin_model = torch.compile(swin_model)
resnet_model = torch.compile(resnet_model)

Train

In [None]:
def visualize_batch(model, loader, device, sample_idx=1, max_depth=10.0):
    """Show RGB input, ground-truth depth and predicted depth for one sample.

    The sample is taken from the first batch produced by ``loader``.

    Args:
        model: depth network; it is switched to eval mode.
        loader: iterable yielding ``(images, depths)`` batches.
        device: device used for the forward pass.
        sample_idx: position of the sample inside the batch. It is clamped to
            the last available sample, so a batch with a single element no
            longer raises ``IndexError``.
        max_depth: upper limit of the colour scale shared by both depth plots.
    """
    model.eval()
    images, depths = next(iter(loader))
    images = images.to(device)

    with torch.no_grad():
        preds = model(images)

    idx = min(sample_idx, images.shape[0] - 1)

    # Undo the input normalisation so the RGB image displays with natural colours.
    img = images[idx].cpu().permute(1, 2, 0).numpy()
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = std * img + mean
    img = np.clip(img, 0, 1)

    depth_gt = depths[idx, 0].cpu().numpy()
    # .float() keeps the NumPy conversion valid for reduced-precision outputs.
    depth_pred = preds[idx, 0].float().cpu().numpy()

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(img)
    axs[0].set_title("Input RGB")

    axs[1].imshow(depth_gt, cmap='inferno', vmin=0, vmax=max_depth)
    axs[1].set_title("Ground Truth Depth")

    im = axs[2].imshow(depth_pred, cmap='inferno', vmin=0, vmax=max_depth)
    axs[2].set_title("Predicted Depth")

    plt.colorbar(im, ax=axs[2])
    plt.show()


In [None]:
def plot_loss_curves(model_name, train_losses, val_losses):
    """Draw per-epoch training loss (SILog) and validation RMSE on one figure.

    Epochs are numbered from 1 and their count is taken from ``train_losses``.
    """
    epoch_axis = range(1, len(train_losses) + 1)

    fig, ax = plt.subplots(figsize=(10, 6))

    # Training curve in blue, validation curve in red.
    ax.plot(epoch_axis, train_losses, 'o-', color='blue', label=f'{model_name} Training Loss (SILog)')
    ax.plot(epoch_axis, val_losses, 'o-', color='red', label=f'{model_name} Validation Loss (RMSE)')

    ax.set_title(f'Loss Curves for {model_name} Model')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss Value')
    ax.legend()
    ax.grid(True)
    plt.show()

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch
import os
from tqdm.notebook import tqdm

def train_and_evaluate(model_name, model, train_loader, val_loader, save_dir, epochs=3):
    """Train ``model`` with the SILog loss and validate after every epoch.

    Uses AdamW (learning rate from ``CONFIG['lr']``), cosine annealing over
    ``epochs`` and bfloat16 autocast when running on CUDA. Whenever the
    validation delta-1 accuracy improves, a checkpoint is written to
    ``<save_dir>/<model_name>_best.pth``.

    Args:
        model_name: label used in log messages and in the checkpoint file name.
        model: network to train; all of its parameters are made trainable.
        train_loader: iterable of ``(images, depths)`` training batches.
        val_loader: iterable of ``(images, depths)`` validation batches.
        save_dir: directory for the checkpoint (created if missing).
        epochs: number of training epochs.

    Returns:
        Tuple ``(val_metrics, train_losses_history, val_losses_history)``.
        ``val_metrics`` holds the metrics of the LAST epoch (not necessarily
        the best one) and is an empty dict when ``epochs`` is 0. The histories
        hold the mean training loss and the validation RMSE per epoch.
    """
    device = CONFIG['device']
    model = model.to(device)

    # Make every parameter trainable. Nothing in the loop below freezes
    # parameters, so doing this once is equivalent to repeating it each epoch.
    for param in model.parameters():
        param.requires_grad = True

    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['lr'], weight_decay=0.1)
    scheduler = CosineAnnealingLR(optimizer, T_max=max(epochs, 1), eta_min=1e-6)
    criterion = SILogLoss()

    # Initialize loss and metric trackers
    train_losses_history = []
    val_losses_history = []
    val_metrics = {}  # stays empty when epochs == 0

    # Mixed precision is only enabled on CUDA; on CPU autocast and the scaler
    # are switched off.
    device_type = torch.device(device).type
    use_amp = device_type == 'cuda'
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"{model_name}_best.pth")
    best_a1 = 0.0

    # torch.compile returns a wrapper around the original module. Saving the
    # inner module's state dict keeps the checkpoint keys free of the wrapper
    # prefix, so the file loads into a plain, uncompiled model.
    raw_model = getattr(model, "_orig_mod", model)

    print(f"\n {model_name} AMP + Cosine Annealing ...")

    for epoch in range(epochs):
        model.train()

        train_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
        for imgs, gts in pbar:
            imgs, gts = imgs.to(device), gts.to(device)
            # Only pixels with a plausible ground-truth depth contribute to the loss.
            mask = (gts > 0.1) & (gts < CONFIG['max_depth'])

            optimizer.zero_grad()

            with torch.amp.autocast(device_type, dtype=torch.bfloat16, enabled=use_amp):
                preds = model(imgs)
                preds = torch.clamp(preds, min=1e-3, max=CONFIG['max_depth'])
                loss = criterion(preds, gts, mask)

            scaler.scale(loss).backward()
            # Gradients must be unscaled before clipping by norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            scaler.step(optimizer)
            scaler.update()

            batch_loss = loss.item()
            train_loss += batch_loss
            pbar.set_postfix({'loss': batch_loss})

        # Calculate and record average training loss for the epoch
        # (max(..., 1) guards against an empty training loader).
        avg_train_loss = train_loss / max(len(train_loader), 1)
        train_losses_history.append(avg_train_loss)

        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        val_metrics = validate_model(model, val_loader, device)
        current_rmse = val_metrics['rmse']
        current_a1 = val_metrics['a1']
        val_losses_history.append(current_rmse)

        print(f"Epoch {epoch+1} (LR={current_lr:.2e}): "
              f"Loss={avg_train_loss:.3f} | "
              f"AbsRel={val_metrics['abs_rel']:.3f} | "
              f"RMSE={val_metrics['rmse']:.3f} | "
              f"δ1={val_metrics['a1']:.3f}")

        if current_a1 > best_a1:
            best_a1 = current_a1

            print(f"Higher acc (δ1: {best_a1:.3f})! Saving...")

            torch.save({
                'epoch': epoch,
                'model_state_dict': raw_model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_a1': best_a1,
                'config': CONFIG
            }, save_path)

            print(f"Model has been saved: {save_path}")

    return val_metrics, train_losses_history, val_losses_history
if __name__ == "__main__":

    # Train both variants with identical data loaders and epoch budget.
    swin_results, swin_train_loss, swin_val_loss = train_and_evaluate(
        "Swin-Tiny",
        swin_model,
        train_loader,
        val_loader,
        save_dir,
        epochs=CONFIG['epochs'],
    )

    resnet_results, resnet_train_loss, resnet_val_loss = train_and_evaluate(
        "ResNet-34",
        resnet_model,
        train_loader,
        val_loader,
        save_dir,
        epochs=CONFIG['epochs'],
    )

    # Side-by-side summary of the metrics returned by the two runs.
    print("\n" + "="*50)
    print(f"{'Metric':<10} | {'Swin-Tiny (Ours)':<18} | {'ResNet-34 (Base)':<18}")
    print("-" * 50)
    for _label, _key in (
        ('AbsRel', 'abs_rel'),
        ('RMSE', 'rmse'),
        ('δ1 (Acc)', 'a1'),
        ('δ2 ', 'a2'),
    ):
        print(f"{_label:<10} | {swin_results[_key]:<18.4f} | {resnet_results[_key]:<18.4f}")
    print("="*50)

    print("Visualize Swin-Tiny:")
    visualize_batch(swin_model, val_loader, CONFIG['device'])

In [None]:
# --- Plot Loss Curves ---
# One figure per model: mean training loss and validation RMSE per epoch,
# using the histories returned by train_and_evaluate.
print("\nVisualizing Loss Curves:")
plot_loss_curves("Swin-Tiny", swin_train_loss, swin_val_loss)
plot_loss_curves("ResNet-34", resnet_train_loss, resnet_val_loss)

In [None]:
print("\n" + "="*65)
print(f"{'Metric':<12} | {'Swin-Tiny (Ours)':<20} | {'ResNet-34 (Base)':<20}")
print("="*65)

# Error metrics (lower is better)
for _label, _key in (('AbsRel (↓)', 'abs_rel'), ('RMSE   (↓)', 'rmse')):
    print(f"{_label:<12} | {swin_results[_key]:<20.4f} | {resnet_results[_key]:<20.4f}")

print("-" * 65)

# Threshold accuracies (higher is better)
for _label, _key in (('δ1     (↑)', 'a1'), ('δ2     (↑)', 'a2'), ('δ3     (↑)', 'a3')):
    print(f"{_label:<12} | {swin_results[_key]:<20.4f} | {resnet_results[_key]:<20.4f}")

print("="*65)

Finetune

In [None]:


# Labelled NYUv2 set used for fine-tuning and the final evaluation.
# NOTE(review): trust_remote_code=True executes the dataset repository's loading script.
Standard_Full = load_dataset("0jl/NYUv2", split="train", trust_remote_code=True)
# 795 Training Set
# NOTE(review): the split is taken by index range (first 795 samples vs. the
# following 654) -- confirm this ordering matches the intended train/test split.
Standard_Train = Standard_Full.select(range(795))
# 654 Testing Set
Standard_Test = Standard_Full.select(range(795, 1449))

ft_train_dataset = HF_NYUDepthDataset(Standard_Train, is_train=True)
ft_val_dataset = HF_NYUDepthDataset(Standard_Test, is_train=False)

# Small batches, loaded in the main process (num_workers=0).
ft_train_loader = DataLoader(ft_train_dataset, batch_size=16, shuffle=True, num_workers=0,pin_memory=True)
ft_val_loader = DataLoader(ft_val_dataset, batch_size=16, shuffle=False, num_workers=0,pin_memory=True)


In [None]:

print("Fine-tuning Swin-Tiny...")

from google.colab import drive
drive.mount('/content/drive')

load_path = '/content/drive/MyDrive/DepthProject/Swin-Tiny_best.pth'

# Build a fresh, uncompiled model to receive the pre-trained weights. It was
# previously used without being defined, so this cell raised NameError on a
# fresh kernel.
swin_ft_model = LightweightDPT(
    backbone_name="swin_tiny_patch4_window7_224",
    max_depth=CONFIG['max_depth'],
)

if os.path.exists(load_path):
    # Load onto the CPU first; train_and_evaluate moves the model to the target device.
    checkpoint = torch.load(load_path, map_location="cpu")
    # Checkpoints saved from a torch.compile-wrapped model prefix every key
    # with "_orig_mod."; strip it so the keys match the plain model.
    _prefix = "_orig_mod."
    _state_dict = {
        (_key[len(_prefix):] if _key.startswith(_prefix) else _key): _value
        for _key, _value in checkpoint['model_state_dict'].items()
    }
    swin_ft_model.load_state_dict(_state_dict)
    print(f"δ1: {checkpoint['best_a1']:.3f}")
else:
    # Without a checkpoint, fine-tuning starts from the freshly built model.
    print("File doesn't exist")

# Hyperparameter
# NOTE: this changes the shared CONFIG, so every later call to
# train_and_evaluate also uses the fine-tuning learning rate.
CONFIG['lr'] = 1e-5

print("Start Fine-tuning on 795 images")
ft_swin_results, ft_swin_loss, ft_swin_val_loss = train_and_evaluate(
    "Swin-Tiny-FT", swin_ft_model, ft_train_loader, ft_val_loader, save_dir, epochs=10
)


In [None]:

print("\n" + "="*65)
print(f"{'Metric':<12} | {'Swin-Tiny':<20}")
print("="*65)

# Error metrics (lower is better)
for _label, _key in (('AbsRel (↓)', 'abs_rel'), ('RMSE   (↓)', 'rmse')):
    print(f"{_label:<12} | {ft_swin_results[_key]:<20.4f}")

print("-" * 65)

# Threshold accuracies (higher is better)
for _label, _key in (('δ1     (↑)', 'a1'), ('δ2     (↑)', 'a2'), ('δ3     (↑)', 'a3')):
    print(f"{_label:<12} | {ft_swin_results[_key]:<20.4f}")

In [None]:
print("Fine-tuning Resnet...")

from google.colab import drive
drive.mount('/content/drive')

load_path = '/content/drive/MyDrive/DepthProject/ResNet-34_best.pth'

# Build a fresh, uncompiled model to receive the pre-trained weights. It was
# previously used without being defined, so this cell raised NameError on a
# fresh kernel.
resnet_ft_model = LightweightDPT(
    backbone_name="resnet34",
    max_depth=CONFIG['max_depth'],
)

if os.path.exists(load_path):
    # Load onto the CPU first; train_and_evaluate moves the model to the target device.
    checkpoint = torch.load(load_path, map_location="cpu")
    # Checkpoints saved from a torch.compile-wrapped model prefix every key
    # with "_orig_mod."; strip it so the keys match the plain model.
    _prefix = "_orig_mod."
    _state_dict = {
        (_key[len(_prefix):] if _key.startswith(_prefix) else _key): _value
        for _key, _value in checkpoint['model_state_dict'].items()
    }
    resnet_ft_model.load_state_dict(_state_dict)
    print(f"δ1: {checkpoint['best_a1']:.3f}")
else:
    # Without a checkpoint, fine-tuning starts from the freshly built model.
    print("File doesn't exist")

# NOTE: this changes the shared CONFIG, so every later call to
# train_and_evaluate also uses the fine-tuning learning rate.
CONFIG['lr'] = 1e-5

print("Start Fine-tuning on 795 images")
ft_res_results, ft_res_loss, ft_res_val_loss = train_and_evaluate(
    "ResNet-34-FT", resnet_ft_model, ft_train_loader, ft_val_loader, save_dir, epochs=10
)

In [None]:
print("\n" + "="*65)
print(f"{'Metric':<12} | {'ResNet-34 (Base)':<20}")
print("="*65)

# Error metrics (lower is better)
for _label, _key in (('AbsRel (↓)', 'abs_rel'), ('RMSE   (↓)', 'rmse')):
    print(f"{_label:<12} | {ft_res_results[_key]:<20.4f}")

print("-" * 65)

# Threshold accuracies (higher is better)
for _label, _key in (('δ1     (↑)', 'a1'), ('δ2     (↑)', 'a2'), ('δ3     (↑)', 'a3')):
    print(f"{_label:<12} | {ft_res_results[_key]:<20.4f}")

In [None]:
# --- Plot Loss Curves ---
# Fine-tuning curves for both models, followed by one qualitative example each
# taken from the first batch of the fine-tuning validation loader.
print("\nVisualizing Loss Curves:")
plot_loss_curves("Swin-Tiny-FT", ft_swin_loss, ft_swin_val_loss)
plot_loss_curves("ResNet-34-FT", ft_res_loss, ft_res_val_loss)
visualize_batch(swin_ft_model, ft_val_loader, CONFIG['device'])
visualize_batch(resnet_ft_model, ft_val_loader, CONFIG['device'])
