In [1]:
# Automatically reload imported modules when their source changes, so edits to
# local packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys

# Make the parent directory importable; the next cell imports `depth_anything_v2`
# from there. NOTE(review): "../" is resolved against the kernel's current working
# directory, not this notebook's location — confirm the notebook is launched from
# its own folder.
sys.path.append("../")

In [None]:
import torch
from datasets import load_dataset
from depth_anything_v2.depth_anything_v2.dpt import DepthAnythingV2

import os

# Prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
ENCODER = 'vits'  # one of 'vits', 'vitb', 'vitl', 'vitg' (must be a key of model_configs below)

# Directory holding the downloaded checkpoints. Set DEPTH_ANYTHING_CKPT_DIR to point
# elsewhere instead of editing this machine-specific default.
CHECKPOINT_DIR = os.environ.get('DEPTH_ANYTHING_CKPT_DIR', '/home/tomchen/Downloads')
MODEL_PATH = os.path.join(CHECKPOINT_DIR, f'depth_anything_v2_{ENCODER}.pth')

# Constructor kwargs for DepthAnythingV2, keyed by encoder name.
model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}

# Fail fast with a readable message instead of a bare KeyError when the model is built.
if ENCODER not in model_configs:
    raise ValueError(f"Unknown ENCODER {ENCODER!r}; expected one of {sorted(model_configs)}")

model = DepthAnythingV2(**model_configs[ENCODER])
# The checkpoint is passed straight to load_state_dict, so it is expected to be a plain
# state dict; weights_only=True loads that while refusing to unpickle arbitrary objects
# (a tampered .pth file could otherwise execute code on load).
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu', weights_only=True))
model = model.to(DEVICE).eval()
# Load the dataset (its "train" split is consumed by the DataLoader cell below)
dataset = load_dataset("ntudlcv/dlcv_2024_final1")

In [None]:
import numpy as np
from torch.utils.data import DataLoader


# Custom collate function to handle dictionaries with PIL images
def collate_fn(batch):
    """Split a batch of dataset rows into images and the remaining fields.

    Keeps PIL images as a plain Python list so they reach the caller untouched
    instead of going through the DataLoader's default collation.

    Args:
        batch: list of row dicts, each containing an ``"image"`` key.

    Returns:
        ``(images, other_properties)``: the ``"image"`` values in batch order,
        and, per row, a new dict with every key/value pair except ``"image"``.
    """
    images, other_properties = [], []
    for row in batch:
        images.append(row["image"])
        other_properties.append(
            {key: value for key, value in row.items() if key != "image"}
        )
    return images, other_properties


# Fixed seed so the shuffled batch shown below is the same on every Restart & Run All.
SEED = 0
BATCH_SIZE = 8  # number of images fetched (and later plotted)

# Create a DataLoader for the dataset with the custom collate function
data_loader = DataLoader(
    dataset["train"],
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    # A dedicated, seeded generator makes the shuffle order reproducible without
    # touching torch's global RNG state.
    generator=torch.Generator().manual_seed(SEED),
)

# Load only one batch
images, other_properties = next(iter(data_loader))

# Get depth maps for the images, min-max scaled to [0, 255] for display.
depths = []
for image in images:
    # NOTE(review): upstream DepthAnythingV2.infer_image expects an OpenCV-style BGR
    # array (its README feeds cv2.imread output), whereas np.array(PIL image) is RGB,
    # so flip the channel order. `.convert("RGB")` first normalizes grayscale/RGBA
    # images to 3 channels. Confirm the local ../depth_anything_v2 copy is unchanged.
    bgr = np.ascontiguousarray(np.array(image.convert("RGB"))[:, :, ::-1])
    depth = model.infer_image(bgr)
    depth_range = depth.max() - depth.min()
    # A constant map would otherwise divide by zero; show it as all zeros instead.
    if depth_range > 0:
        depth = (depth - depth.min()) / depth_range * 255
    else:
        depth = np.zeros_like(depth)
    depths.append(depth)

In [21]:
# Sanity check on a single image. The result is deliberately NOT appended to
# `depths`: doing so duplicated image 0 and grew the list on every re-run.
image = images[0]
# NOTE(review): infer_image upstream expects BGR input (cv2 convention); PIL gives RGB,
# so flip channels after normalizing to 3-channel RGB.
depth = model.infer_image(np.ascontiguousarray(np.array(image.convert("RGB"))[:, :, ::-1]))
depth_range = depth.max() - depth.min()
# Min-max scale to [0, 255]; a constant map becomes zeros rather than dividing by zero.
depth = (depth - depth.min()) / depth_range * 255 if depth_range > 0 else np.zeros_like(depth)

torch.Size([1, 64, 296, 560]) torch.Size([1, 64, 148, 280]) torch.Size([1, 64, 74, 140]) torch.Size([1, 64, 37, 70])
torch.Size([1, 518, 980])
torch.Size([720, 1355])


In [None]:
import matplotlib.pyplot as plt

# Plot a grid showing the original images and the depth maps: each group of N_COLS
# images uses two rows (originals on top, their depth maps directly below).
N_COLS = 4
n_images = min(len(images), len(depths))  # zip() below stops at the shorter list
n_groups = max(1, -(-n_images // N_COLS))  # ceil division; keep at least one row pair
fig, axes = plt.subplots(
    nrows=2 * n_groups,
    ncols=N_COLS,
    figsize=(10, 3.5 * n_groups),  # 10x7 for the default batch of 8
    squeeze=False,  # always a 2-D array, even for a single row pair
)
# Hide every panel up front so slots left unused (batch not a multiple of N_COLS) stay blank.
for ax in axes.flat:
    ax.axis("off")

for i, (image, depth) in enumerate(zip(images, depths)):
    row = (i // N_COLS) * 2
    col = i % N_COLS

    axes[row, col].imshow(image)
    axes[row, col].set_title(f"Original {i+1}")

    axes[row + 1, col].imshow(depth, cmap="viridis")
    axes[row + 1, col].set_title(f"Depth {i+1}")

plt.tight_layout(h_pad=0)
plt.show()