In [14]:
# ================================================================
# ✅ Final Stable MiDaS Fine-Tuning Script for KITTI-like Depth Data
# ================================================================
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from data_loader import get_data_loaders
import os
import numpy as np
from skimage.metrics import structural_similarity as ssim_metric

# ----------------------------
# Setup
# ----------------------------
# One seed drives both the PyTorch and NumPy generators so runs are repeatable.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[INFO] Using device: {device}")

# ----------------------------
# Load MiDaS & Transforms
# ----------------------------
# The network and its preprocessing pipelines are two entry points of the
# same torch.hub repository.
MIDAS_REPO = "intel-isl/MiDaS"
midas = torch.hub.load(MIDAS_REPO, "DPT_Hybrid")
midas_transforms = torch.hub.load(MIDAS_REPO, "transforms")

# dpt_transform is the preprocessing pipeline used with DPT_Hybrid.
midas_transform = midas_transforms.dpt_transform

# ----------------------------
# Modify model for fine-tuning
# ----------------------------
# Freeze every pretrained weight; only modules created after this point train.
midas.requires_grad_(False)

# Install a freshly initialised 3x3 conv head mapping 256 channels to 1.
# Being new, its parameters keep the default requires_grad=True, so this head
# is the only trainable part of the model.
# NOTE(review): this discards whatever head the hub checkpoint shipped with --
# confirm that dropping its pretrained weights is intended.
midas.scratch.output_conv = nn.Conv2d(
    in_channels=256, out_channels=1, kernel_size=3, stride=1, padding=1
)
midas = midas.to(device)

# ----------------------------
# Loss & Optimizer
# ----------------------------
# Mean absolute error between predicted and target depth maps.
criterion = nn.L1Loss()
# Hand Adam only the parameters that are still trainable.
optimizer = optim.Adam(
    [p for p in midas.parameters() if p.requires_grad],
    lr=1e-4,
)

# ----------------------------
# Load Data
# ----------------------------
# get_data_loaders is project code; it returns a (train, validation) pair.
train_loader, val_loader = get_data_loaders(
    "data", pseudo_dir="pseudo_data", batch_size=2
)

# ----------------------------
# Training Loop
# ----------------------------
# Runs `num_epochs` passes over `train_loader`, optimising the trainable
# parameters of `midas` against the L1 criterion and printing the mean batch
# loss after every epoch.
num_epochs = 3
os.makedirs("outputs", exist_ok=True)

for epoch in range(num_epochs):
    midas.train()
    total_loss = 0.0
    num_batches = 0

    for imgs, depths in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        depths = depths.to(device)
        # Predictions are shaped [N, 1, H, W] below. A [N, H, W] target would
        # not raise; it would silently broadcast inside the loss, so give the
        # target an explicit channel axis.
        if depths.ndim == 3:
            depths = depths.unsqueeze(1)

        # Apply MiDaS transform to each image. The transform is fed NumPy
        # (H, W, 3) arrays, so the raw batch is NOT moved to the GPU first;
        # only the preprocessed batch is transferred.
        # NOTE(review): confirm the loader's pixel value range matches what
        # the hub transform expects.
        batch_imgs = []
        for img in imgs:
            np_img = img.permute(1, 2, 0).cpu().numpy()  # (H, W, 3)
            t = midas_transform(np_img)
            if isinstance(t, dict):
                t = t["image"]
            batch_imgs.append(t)

        imgs = torch.stack(batch_imgs).to(device)
        # The transform may return each image with its own leading batch axis,
        # giving a 5-D stack; drop that extra axis.
        if imgs.dim() == 5:
            imgs = imgs.squeeze(1)

        optimizer.zero_grad()
        preds = midas(imgs)

        # --- Fix output shape: always [N, 1, H, W] ---
        if preds.ndim == 3:  # [N, H, W]
            preds = preds.unsqueeze(1)
        elif preds.ndim == 2:  # single image [H, W]
            preds = preds.unsqueeze(0).unsqueeze(0)

        # --- Match depth map size ---
        preds = torch.nn.functional.interpolate(
            preds, size=depths.shape[-2:], mode="bilinear", align_corners=False
        )

        loss = criterion(preds, depths)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    # Average over the batches actually seen; max() keeps an empty loader from
    # raising ZeroDivisionError.
    print(f"[Epoch {epoch+1}] Training Loss: {total_loss / max(num_batches, 1):.4f}")

# ----------------------------
# Evaluation
# ----------------------------
# Runs the fine-tuned model over `val_loader` without gradients and collects
# predictions and targets on the CPU as two [N, 1, H, W] tensors.
midas.eval()
predictions, ground_truths = [], []

with torch.no_grad():
    for imgs, depths in tqdm(val_loader, desc="Evaluating"):
        # Targets are only used for their spatial size and for the CPU-side
        # metrics, so they never need to visit the GPU.
        depths = depths.cpu()
        # Keep the target layout identical to the predictions; a [N, H, W]
        # target would broadcast silently in the later elementwise metrics.
        if depths.ndim == 3:
            depths = depths.unsqueeze(1)

        # Preprocess on the CPU (the transform is fed NumPy arrays) and
        # transfer only the finished batch to the device.
        batch_imgs = []
        for img in imgs:
            np_img = img.permute(1, 2, 0).cpu().numpy()
            t = midas_transform(np_img)
            if isinstance(t, dict):
                t = t["image"]
            batch_imgs.append(t)

        imgs = torch.stack(batch_imgs).to(device)
        # Drop the per-image batch axis the transform may have added.
        if imgs.dim() == 5:
            imgs = imgs.squeeze(1)

        preds = midas(imgs)
        if preds.ndim == 3:
            preds = preds.unsqueeze(1)

        # Resize predictions to the target resolution.
        preds = torch.nn.functional.interpolate(
            preds, size=depths.shape[-2:], mode="bilinear", align_corners=False
        )

        predictions.append(preds.cpu())
        ground_truths.append(depths)

predictions = torch.cat(predictions)
ground_truths = torch.cat(ground_truths)

# ----------------------------
# Compute Evaluation Metrics
# ----------------------------
# Pixel-wise errors over the whole validation set.
# NOTE(review): every pixel is included; if the targets contain invalid or
# missing depth values they presumably need masking -- confirm.
mae = torch.mean(torch.abs(predictions - ground_truths)).item()
mse = torch.mean((predictions - ground_truths) ** 2).item()
rmse = np.sqrt(mse)

# Convert to numpy for SSIM. Reshape to (N, H, W) explicitly rather than
# squeeze(): squeeze() also removes the batch axis when N == 1, which forced a
# separate code path. With the explicit reshape every case is "one score per
# image, then average".
pred_np = predictions.reshape(-1, *predictions.shape[-2:]).numpy()
gt_np = ground_truths.reshape(-1, *ground_truths.shape[-2:]).numpy()
ssim_scores = [
    ssim_metric(p, g, data_range=g.max() - g.min())
    for p, g in zip(pred_np, gt_np)
]
mean_ssim = np.mean(ssim_scores)

print("\n📊 Evaluation Metrics:")
print(f"MAE:  {mae:.4f}")
print(f"MSE:  {mse:.4f}")
print(f"RMSE: {rmse:.4f}")
print(f"SSIM: {mean_ssim:.4f}")

# ----------------------------
# Save Outputs
# ----------------------------
# Persist the fine-tuned weights and the raw validation predictions/targets
# under "outputs/" (the directory is created before the training loop).
torch.save(midas.state_dict(), "outputs/midas_finetuned.pth")
torch.save(
    {
        "predictions": predictions,
        "ground_truths": ground_truths,
    },
    "outputs/midas_predictions.pt",
)

print("✅ Fine-tuned MiDaS model and predictions saved to 'outputs/'")


[INFO] Using device: cuda


Using cache found in C:\Users\kirth/.cache\torch\hub\intel-isl_MiDaS_master
Using cache found in C:\Users\kirth/.cache\torch\hub\intel-isl_MiDaS_master


[INFO] Found 250 real KITTI pairs.
[INFO] Found 2706 pseudo-labeled pairs.
[INFO] Using 2956 image–depth pairs total.
[INFO] Using 2364 image–depth pairs for training.
[INFO] Using 592 image–depth pairs for validation.


Epoch 1/3: 100%|██████████| 1182/1182 [01:47<00:00, 10.95it/s]


[Epoch 1] Training Loss: 0.5603


Epoch 2/3: 100%|██████████| 1182/1182 [01:49<00:00, 10.81it/s]


[Epoch 2] Training Loss: 0.2635


Epoch 3/3: 100%|██████████| 1182/1182 [01:51<00:00, 10.55it/s]


[Epoch 3] Training Loss: 0.2493


Evaluating: 100%|██████████| 296/296 [00:40<00:00,  7.33it/s]



📊 Evaluation Metrics:
MAE:  0.4510
MSE:  0.2528
RMSE: 0.5027
SSIM: 0.3223
✅ Fine-tuned MiDaS model and predictions saved to 'outputs/'


In [15]:
# ================================================================
# Visualization Script for MiDaS Fine-Tuned Model
# ================================================================
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
from skimage.metrics import structural_similarity as ssim

# ----------------------------
# Load Saved Predictions
# ----------------------------
# The file was written as a plain dict of tensors, so the restricted
# unpickler (weights_only=True) is sufficient. It also avoids the warning
# torch.load emits when the argument is left unspecified.
data = torch.load("outputs/midas_predictions.pt", weights_only=True)

# Reshape to (N, H, W) explicitly. squeeze() would also drop the batch axis
# when only one sample was saved, after which len() and predictions[i] would
# operate on image rows instead of images.
pred_t = data["predictions"]
gt_t = data["ground_truths"]
predictions = pred_t.reshape(-1, *pred_t.shape[-2:]).numpy()
ground_truths = gt_t.reshape(-1, *gt_t.shape[-2:]).numpy()

# Make sure both have same number of samples
num_samples = min(len(predictions), len(ground_truths))
print(f"Loaded {num_samples} samples for visualization.")

os.makedirs("outputs/visuals", exist_ok=True)

# ----------------------------
# Utility Function: Normalize for display
# ----------------------------
def normalize_depth(depth):
    """Rescale a depth map to roughly [0, 1] for display.

    The minimum is shifted to zero and the result is divided by its new
    maximum. A small epsilon in the denominator keeps a constant map from
    dividing by zero; such a map comes back as all zeros.
    """
    shifted = depth - np.min(depth)
    peak = np.max(shifted)
    return shifted / (peak + 1e-8)

# ----------------------------
# Plot a few samples
# ----------------------------
# Writes up to five side-by-side prediction / ground-truth comparison images.
for i in range(min(5, num_samples)):
    pred = normalize_depth(predictions[i])
    gt = normalize_depth(ground_truths[i])

    # Compute SSIM for display. Both maps were just rescaled to [0, 1], so the
    # data range is 1.0 by construction. Deriving it from gt.max() - gt.min()
    # gives 0 for a flat ground-truth map, which zeroes SSIM's stabilising
    # constants and can yield NaN.
    ssim_score = ssim(pred, gt, data_range=1.0)

    fig, (ax_pred, ax_gt) = plt.subplots(1, 2, figsize=(10, 4))

    ax_pred.imshow(pred, cmap="inferno")
    ax_pred.set_title(f"Predicted Depth (SSIM={ssim_score:.3f})")
    ax_pred.axis("off")

    ax_gt.imshow(gt, cmap="inferno")
    ax_gt.set_title("Ground Truth Depth")
    ax_gt.axis("off")

    fig.tight_layout()
    fig.savefig(f"outputs/visuals/depth_compare_{i}.png")
    # Close explicitly so figures do not accumulate across iterations.
    plt.close(fig)

print("✅ Saved visual comparisons in 'outputs/visuals/'")

# ----------------------------
# Optional: Histogram of depth values
# ----------------------------
# Overlays the value distributions of all predicted and ground-truth pixels.
# NOTE(review): the arrays plotted here are the loaded values as-is, not the
# output of normalize_depth(); the "Normalized Depth" axis label presumably
# relies on the data being normalized upstream -- confirm.
plt.figure(figsize=(6, 4))
# ravel() avoids the extra full-size copy that flatten() always makes.
plt.hist(predictions.ravel(), bins=50, alpha=0.6, label="Predicted")
plt.hist(ground_truths.ravel(), bins=50, alpha=0.6, label="Ground Truth")
plt.title("Depth Value Distribution")
plt.xlabel("Normalized Depth")
plt.ylabel("Frequency")
plt.legend()
plt.tight_layout()
plt.savefig("outputs/visuals/depth_distribution.png")
plt.close()

print("✅ Saved depth distribution plot as 'depth_distribution.png'")


  data = torch.load("outputs/midas_predictions.pt")


Loaded 592 samples for visualization.
✅ Saved visual comparisons in 'outputs/visuals/'
✅ Saved depth distribution plot as 'depth_distribution.png'
