In [5]:
import matplotlib.pyplot as plt
import numpy as np
from nuscenes.nuscenes import NuScenes
from nuscenes.utils.data_classes import LidarPointCloud
from PIL import Image
import torch

def load_nuscenes_data(nuscenes_data_path):
    """Load the nuScenes ``v1.0-mini`` split rooted at *nuscenes_data_path*.

    The ``NuScenes`` constructor is called with ``verbose=True``, so
    table-loading progress is printed.

    Returns:
        The constructed ``NuScenes`` object.
    """
    return NuScenes(
        dataroot=nuscenes_data_path,
        version='v1.0-mini',
        verbose=True,
    )

# Cache of MiDaS preprocessing transforms, keyed by model type, so the
# torch.hub lookup happens once per model type instead of once per image.
_MIDAS_TRANSFORM_CACHE = {}


def _notebook_global(name):
    """Return the module-level setup variable *name* (set in the ``__main__`` block).

    Raises:
        NameError: if *name* has not been defined at module level.
    """
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"run_midas: '{name}' was not passed and is not defined at module level"
        ) from None


def _get_midas_transform(model_type):
    """Return the MiDaS input transform matching *model_type*, loading it at most once."""
    if model_type not in _MIDAS_TRANSFORM_CACHE:
        midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
        if model_type in ("DPT_Large", "DPT_Hybrid"):
            _MIDAS_TRANSFORM_CACHE[model_type] = midas_transforms.dpt_transform
        else:
            _MIDAS_TRANSFORM_CACHE[model_type] = midas_transforms.small_transform
    return _MIDAS_TRANSFORM_CACHE[model_type]


def run_midas(image, model=None, transform=None, device=None):
    """Run MiDaS on one image and return its prediction at the image's resolution.

    NOTE(review): per the MiDaS project documentation the network predicts
    *relative inverse depth* (arbitrary scale, larger = closer), not metric
    depth. Confirm before comparing values directly with LiDAR ranges.

    Args:
        image: Path to an image file (anything ``PIL.Image.open`` accepts), or an
            already-loaded ``H x W x 3`` RGB numpy array with 0-255 pixel values.
        model: MiDaS network. Defaults to the module-level ``midas``.
        transform: Preprocessing callable mapping the RGB array to a batched
            tensor. Defaults to the hub transform for the module-level ``model_type``.
        device: Torch device for the input batch. Defaults to the module-level
            ``device``.

    Returns:
        2-D float numpy array of shape ``(H, W)``, the same height and width as
        the input image.

    Raises:
        NameError: if a defaulted argument's module-level variable is missing.
    """
    if model is None:
        model = _notebook_global("midas")
    if device is None:
        device = _notebook_global("device")
    if transform is None:
        transform = _get_midas_transform(_notebook_global("model_type"))

    if isinstance(image, np.ndarray):
        rgb = image
    else:
        # The MiDaS hub transforms expect an RGB numpy array (their first step
        # divides the input by 255), so a PIL image must not be passed directly.
        # The context manager releases the file handle once the pixels are copied.
        with Image.open(image) as pil_image:
            rgb = np.asarray(pil_image.convert("RGB"))

    model.eval()
    input_batch = transform(rgb).to(device)
    with torch.no_grad():
        prediction = model(input_batch)
        # Upsample the low-resolution prediction back to the input image size
        # (previously hard-coded to the nuScenes camera size 900x1600).
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=rgb.shape[:2],
            mode="bicubic",
            align_corners=False,
        )[0, 0]
    return prediction.cpu().numpy()

def extract_lidar_depth(nusc, sample_token):
    """Return the range (distance from the sensor) of every LIDAR_TOP point in a sample.

    Args:
        nusc: A loaded ``NuScenes`` instance.
        sample_token: Token of the ``sample`` record to read.

    Returns:
        1-D numpy array with one entry per LiDAR point: the Euclidean distance
        ``sqrt(x**2 + y**2 + z**2)`` from the LiDAR origin, in the point
        cloud's units (metres for nuScenes).

    Note:
        An earlier version returned ``points[2]``, the z coordinate. In the
        nuScenes LIDAR_TOP frame z points up, so that value is point height,
        not depth.
        TODO(review): for a pixel-aligned comparison with MiDaS, the points must
        be projected into the camera frame (e.g. with
        ``nusc.explorer.map_pointcloud_to_image``). This function does not do
        that.
    """
    sample_record = nusc.get('sample', sample_token)
    lidar_record = nusc.get('sample_data', sample_record['data']['LIDAR_TOP'])
    pc = LidarPointCloud.from_file(nusc.get_sample_data_path(lidar_record['token']))
    # Rows 0-2 of ``pc.points`` are x, y, z (row 3 is intensity).
    xyz = pc.points[:3, :]
    return np.linalg.norm(xyz, axis=0)

def compare_depth(nusc, sample_token):
    """Show a sample's CAM_FRONT image beside its MiDaS prediction, then plot LiDAR depth.

    Args:
        nusc: A loaded ``NuScenes`` instance.
        sample_token: Token of the ``sample`` record to visualise.

    Produces two matplotlib figures, each displayed with ``plt.show()``.
    Returns None.
    """
    # Locate the front-camera image for this sample.
    sample_record = nusc.get('sample', sample_token)
    cam_record = nusc.get('sample_data', sample_record['data']['CAM_FRONT'])
    image_path = nusc.get_sample_data_path(cam_record['token'])

    midas_depth = run_midas(image_path)
    lidar_depth = extract_lidar_depth(nusc, sample_token)

    fig, axes = plt.subplots(1, 2, figsize=(15, 7))
    # imshow converts the PIL image to an array immediately, so the file can be
    # closed right afterwards. The previous code never closed it.
    with Image.open(image_path) as camera_image:
        axes[0].imshow(camera_image)
    axes[0].set_title('Camera Image')
    # Clip the colour scale at the 95th percentile so a few extreme values do
    # not wash out the rest of the map.
    depth_artist = axes[1].imshow(
        midas_depth, cmap='viridis', vmax=np.percentile(midas_depth, 95)
    )
    axes[1].set_title('MiDaS Depth Prediction')
    fig.colorbar(depth_artist, ax=axes[1], fraction=0.046, pad=0.04)
    plt.show()

    # Plot LiDAR ground-truth depth separately for comparison. The x-axis is
    # just the point's index in the cloud.
    plt.figure(figsize=(10, 5))
    plt.plot(lidar_depth, label='LiDAR Depth')
    plt.xlabel('Points')
    plt.ylabel('Depth')
    plt.title('LiDAR Ground-Truth Depth')
    plt.legend()
    plt.show()

# Example usage: load the mini split and MiDaS, then visualise one sample.
if __name__ == '__main__':
    nuscenes_data_path = '/mnt/d/datasets/nuscenes/'
    nusc = load_nuscenes_data(nuscenes_data_path)

    # run_midas reads model_type, midas and device as module-level globals,
    # so these names must be bound at this scope.
    model_type = "DPT_Large"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    midas = torch.hub.load("intel-isl/MiDaS", model_type).to(device)

    # Visualise the first sample; iterate over nusc.sample to process more.
    first_sample = nusc.sample[0]
    sample_token = first_sample['token']
    compare_depth(nusc, sample_token)


Loading NuScenes tables for version v1.0-mini...
23 category,
8 attribute,
4 visibility,
911 instance,
12 sensor,
120 calibrated_sensor,
31206 ego_pose,
8 log,
10 scene,
404 sample,
31206 sample_data,
18538 sample_annotation,
4 map,
Done loading in 1.307 seconds.
Reverse indexing ...
Done reverse indexing in 0.1 seconds.


Using cache found in /home/alberto/.cache/torch/hub/intel-isl_MiDaS_master


: 