# MVSNeRF DTU Reconstruction
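This notebook takes several calibrated photos of one scene from different viewpoints (DTU dataset) and reconstructs a 3D point cloud and a mesh using MVSNeRF. It saves the geometry as PLY (point cloud) and OBJ (mesh) files, and writes interactive HTML visualizations (point cloud, mesh, and a combined view with the input images) that you can open in a browser.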

In [18]:
# Optional: install visualization dependencies if missing
import sys, subprocess
# pip distribution names whose importable module name is not simply the name with '-' -> '_'.
IMPORT_NAME_OVERRIDES = {'scikit-image': 'skimage'}


def ensure(pkgs, import_names=None):
    """Install any of ``pkgs`` that cannot be imported in the current kernel.

    Args:
        pkgs: pip requirement strings, optionally pinned as ``name==version``.
        import_names: optional mapping from pip distribution name to the module name
            used to probe whether it is installed; merged over ``IMPORT_NAME_OVERRIDES``.

    Raises:
        subprocess.CalledProcessError: if a ``pip install`` fails.
    """
    import importlib

    overrides = {**IMPORT_NAME_OVERRIDES, **(import_names or {})}
    for spec in pkgs:
        dist_name = spec.split('==')[0].strip()
        module_name = overrides.get(dist_name, dist_name.replace('-', '_'))
        try:
            importlib.import_module(module_name)
        except ImportError:
            # Install with the interpreter running this kernel, not whatever `pip` is on PATH.
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', spec])
# Third-party packages needed only by the visualization / meshing cells below.
REQUIRED_PACKAGES = ('plotly', 'scikit-image', 'trimesh')
ensure(list(REQUIRED_PACKAGES))
print('Dependencies ready.')

Dependencies ready.


In [19]:
# Imports and configuration
import os, json, math
import numpy as np
import torch
from types import SimpleNamespace
from skimage.measure import marching_cubes
import plotly.graph_objects as go
from pathlib import Path

# Repo modules
from data.dtu import MVSDatasetDTU
from models import create_nerf_mvs, RefVolume
from renderer import render_density
from utils import get_ptsvolume, build_color_volume

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', DEVICE)

# Paths and parameters: each value can be overridden through the environment variable
# named in the first argument, or edited here.
# NOTE: the DTU_ROOT default is a cluster-specific absolute path; set DTU_ROOT elsewhere.
DTU_ROOT = os.environ.get('DTU_ROOT', '/ibex/user/yaoz0b/mvs_training/dtu')  # e.g., '/data/DTU'
SCAN = os.environ.get('DTU_SCAN', 'scan1')
N_VIEWS = int(os.environ.get('DTU_N_VIEWS', '4'))          # total views including the reference
DOWNSAMPLE = float(os.environ.get('DTU_DOWNSAMPLE', '1.0'))
PAD = int(os.environ.get('DTU_PAD', '24'))
CKPT_PATH = os.environ.get('MVSNERF_CKPT', 'runs_fine_tuning/scan1-ft/ckpts/latest.tar')

# All artifacts (PLY / OBJ / HTML) are written here.
OUT_DIR = Path('results/mvsnerf_dtu')
OUT_DIR.mkdir(parents=True, exist_ok=True)

for label, value in (('DTU_ROOT:', DTU_ROOT), ('SCAN:', SCAN), ('Checkpoint:', CKPT_PATH)):
    print(label, value)

Device: cuda
DTU_ROOT: /ibex/user/yaoz0b/mvs_training/dtu
SCAN: scan1
Checkpoint: runs_fine_tuning/scan1-ft/ckpts/latest.tar


In [20]:
# Patch MVSDatasetDTU to fix proj_mats stacking issue
def patched_build_proj_mats(self):
    """Replacement for ``MVSDatasetDTU.build_proj_mats`` (assigned onto the class below).

    For every view id in ``self.id_list`` it reads ``Cameras/train/{vid:08d}_cam.txt``
    under ``self.root_dir``, rescales the camera, and stores on ``self``:

    - ``proj_mats``: list of ``(proj_mat, near_far)`` tuples. ``proj_mat`` is a 4x4 matrix
      whose top 3x4 block is ``K @ extrinsic[:3, :4]``, with ``K`` at 1/4 of the resolution
      of the stored ``intrinsics`` (presumably the MVS feature-map resolution -- confirm).
    - ``intrinsics``: stacked intrinsics, first two rows scaled by 4 and by ``self.downSample``.
    - ``world2cams`` / ``cam2worlds``: stacked extrinsics (translation scaled by
      ``self.scale_factor``) and their matrix inverses.

    Projection matrices and near/far values are stacked separately and then re-zipped.
    NOTE(review): presumably the upstream version called ``np.stack`` on the
    ``(matrix, near_far)`` tuples, which fails for ragged tuples -- confirm.
    """
    proj_mats, intrinsics, world2cams, cam2worlds = [], [], [], []
    for vid in self.id_list:
        proj_mat_filename = os.path.join(self.root_dir, f'Cameras/train/{vid:08d}_cam.txt')
        intrinsic, extrinsic, near_far = self.read_cam_file(proj_mat_filename)
        # NOTE(review): assumes the cam files store intrinsics at 1/4 image resolution -- confirm.
        intrinsic[:2] *= 4
        # Scale the camera translation into the dataset's scene units.
        extrinsic[:3, 3] *= self.scale_factor
        intrinsic[:2] = intrinsic[:2] * self.downSample
        # Snapshot before the in-place /4 below changes `intrinsic`.
        intrinsics += [intrinsic.copy()]
        # Homogeneous 4x4; only the top 3x4 block is filled with the projection.
        proj_mat_l = np.eye(4)
        # Back to 1/4 resolution for the projection matrix.
        intrinsic[:2] = intrinsic[:2] / 4
        proj_mat_l[:3, :4] = intrinsic @ extrinsic[:3, :4]
        proj_mats += [(proj_mat_l, near_far)]
        world2cams += [extrinsic]
        cam2worlds += [np.linalg.inv(extrinsic)]
    # Fix: separate tuples before stacking
    proj_mats_only = np.stack([pm[0] for pm in proj_mats])
    near_fars = np.stack([pm[1] for pm in proj_mats])
    self.proj_mats = list(zip(proj_mats_only, near_fars))
    self.intrinsics = np.stack(intrinsics)
    self.world2cams, self.cam2worlds = np.stack(world2cams), np.stack(cam2worlds)

# Apply patch before instantiation
MVSDatasetDTU.build_proj_mats = patched_build_proj_mats

# Load DTU dataset sample for the selected scan
dataset = MVSDatasetDTU(root_dir=DTU_ROOT, split='val', n_views=N_VIEWS, downSample=DOWNSAMPLE)

# Which of the scan's dataset entries to reconstruct. The default of 1 reproduces the saved
# outputs; override with the DTU_SCAN_ENTRY environment variable.
SCAN_ENTRY = int(os.environ.get('DTU_SCAN_ENTRY', '1'))
scan_indices = [i for i, m in enumerate(dataset.metas) if m[0] == SCAN]
if len(scan_indices) == 0:
    raise RuntimeError(f'No entries found for scan {SCAN}. Available scans: {dataset.scans[:5]} ...')
if not 0 <= SCAN_ENTRY < len(scan_indices):
    raise IndexError(
        f'DTU_SCAN_ENTRY={SCAN_ENTRY} is out of range: scan {SCAN} has {len(scan_indices)} entries')
idx = scan_indices[SCAN_ENTRY]
sample = dataset[idx]


def _to_device(value):
    """Move an array or tensor from the dataset sample to DEVICE as float32.

    ``torch.as_tensor`` accepts both numpy arrays and tensors. Unlike ``torch.tensor``,
    it does not emit the copy-construct warning when the dataset already returns tensors.
    """
    return torch.as_tensor(value, dtype=torch.float32, device=DEVICE)


# Pack pose information expected by renderer/utils
pose_source = {key: _to_device(sample[key]) for key in ('w2cs', 'c2ws', 'intrinsics', 'near_fars')}

imgs = _to_device(sample['images']).unsqueeze(0)  # [1, V, 3, H, W]
proj_mats = _to_device(sample['proj_mats']).unsqueeze(0)  # [1, V, 3, 4]
near_far_source = _to_device(sample['near_fars'][0])  # [2]
print('Images:', tuple(imgs.shape), 'Proj:', tuple(proj_mats.shape), 'Near/Far:', near_far_source.tolist())

==> image down scale: 1.0
Images: (1, 4, 3, 512, 640) Proj: (1, 4, 3, 4) Near/Far: [2.125, 4.525000095367432]



To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



In [21]:
# Initialize MVSNeRF with checkpoint
from opt import config_parser

# Command-line style options for the repo's config_parser (dict order = argument order).
CLI_OVERRIDES = {
    '--expname': 'dtu_recon',
    '--dataset_name': 'dtu',
    '--net_type': 'v2',
    '--multires': '10',
    '--multires_views': '4',
    '--N_samples': '128',
    '--N_importance': '0',
    '--netchunk': '2048',
    '--chunk': '1024',
    '--pad': str(PAD),
    '--raw_noise_std': '0.0',
    '--ckpt': CKPT_PATH,
}
args = config_parser(cmd=[token for pair in CLI_OVERRIDES.items() for token in pair])

# Use 3 views for color features to match checkpoint (8 + 3*4 = 20)
FEATURE_VIEWS = 3
# Fields expected by the model/renderer factories, set after parsing.
EXTRA_ARGS = {
    'pts_dim': 3,
    'dir_dim': 3,
    'feat_dim': 8 + FEATURE_VIEWS * 4,
    'use_viewdirs': False,
    'white_bkgd': False,
    'use_color_volume': True,
    'use_density_volume': True,
}
for attr_name, attr_value in EXTRA_ARGS.items():
    setattr(args, attr_name, attr_value)

render_kwargs_train, render_kwargs_test, _, _ = create_nerf_mvs(
    args, use_mvs=True, dir_embedder=False, pts_embedder=True)
MVSNet, network_fn, network_query_fn = (
    render_kwargs_train[key] for key in ('network_mvs', 'network_fn', 'network_query_fn'))
print('Models ready.')

Found ckpts ['runs_fine_tuning/scan1-ft/ckpts/latest.tar']
Reloading from runs_fine_tuning/scan1-ft/ckpts/latest.tar
Models ready.


In [22]:
# Build the MVS feature volume from the source views (inference only, no gradients).
with torch.no_grad():
    mvs_out = MVSNet(imgs, proj_mats, near_far_source, pad=PAD, lindisp=args.use_disp)
volume_feature, img_feats, depth_values = mvs_out
print('Volume feature:', tuple(volume_feature.shape))

# Create reference volume wrapper around the detached feature volume.
ref_volume = RefVolume(volume_feature.detach()).to(DEVICE)

# Voxel grid in world coordinates (later reshaped to [D, H, W, 3]).
# The feature volume is at 1/4 image resolution, so the reference intrinsics are scaled to match.
D, H, W = volume_feature.shape[-3:]
c2w_ref = pose_source['c2ws'][0]
intrinsic_ref = pose_source['intrinsics'][0].clone()
intrinsic_ref[:2] /= 4
inner_h, inner_w = H - 2 * PAD, W - 2 * PAD  # volume extent without the padding border
vox_pts = get_ptsvolume(inner_h, inner_w, D, PAD, near_far_source, intrinsic_ref, c2w_ref)
# Unnormalize images for color projection
def unpreprocess(data, shape=(1, 1, 3, 1, 1),
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo per-channel mean/std normalization, returning ``data * std + mean``.

    The defaults are the standard ImageNet statistics. NOTE(review): presumably these
    match the normalization applied by the dataset loader -- confirm.

    Args:
        data: tensor whose channel axis lines up with the 3-element axis of ``shape``.
        shape: view shape used to broadcast ``mean``/``std`` against ``data``.
        mean: per-channel means that were subtracted.
        std: per-channel standard deviations that were divided out.

    Returns:
        Tensor broadcast from ``data``, on the same device.
    """
    mean_t = torch.tensor(mean, dtype=torch.float32, device=data.device).view(*shape)
    std_t = torch.tensor(std, dtype=torch.float32, device=data.device).view(*shape)
    return data * std_t + mean_t
# Everything below is inference only. Running it under no_grad keeps autograd from recording
# a graph over the full voxel grid. NOTE(review): this matters if RefVolume stores
# feat_volume as a trainable parameter -- confirm in models.RefVolume.
with torch.no_grad():
    imgs_un = unpreprocess(imgs)

    # Limit color features to FEATURE_VIEWS to match checkpoint feature dim
    imgs_un_feat = imgs_un[:, :FEATURE_VIEWS]
    pose_source_feat = {key: pose_source[key][:FEATURE_VIEWS]
                        for key in ('w2cs', 'c2ws', 'intrinsics', 'near_fars')}

    # Color feature volume from multi-view images
    color_feat = build_color_volume(vox_pts, pose_source_feat, imgs_un_feat, with_mask=True)
    color_feat = color_feat.view(D, H, W, -1).unsqueeze(0).permute(0, 4, 1, 2, 3)  # [1, C, D, H, W]

    # Compute density over the voxel grid
    features = torch.cat((ref_volume.feat_volume, color_feat), dim=1).permute(0, 2, 3, 4, 1).reshape(D * H, W, -1)
    density = render_density(network_fn, vox_pts, features, network_query_fn)

density_vol = density.reshape(D, H, W).detach().cpu()
print('Density volume:', tuple(density_vol.shape))

Volume feature: (1, 8, 128, 176, 208)
Density volume: (128, 176, 208)


In [23]:
# Extract point cloud by thresholding density
DENSITY_PERCENTILE = 90  # keep voxels denser than this percentile (top 10%)
D, H, W = density_vol.shape
coords = vox_pts.view(D, H, W, 3).cpu().numpy()  # world coordinates per voxel
sigma = density_vol.numpy()
thr = np.percentile(sigma, DENSITY_PERCENTILE)
mask = sigma > thr
points = coords[mask]
# Use reference view RGB for coloring.
# NOTE(review): assumes build_color_volume puts view 0's RGB in channels 0:3 -- confirm.
rgb_vol = color_feat[0, 0:3].permute(1, 2, 3, 0).cpu().numpy()  # [D,H,W,3]
colors = (np.clip(rgb_vol[mask], 0, 1) * 255).astype(np.uint8)
print('Point cloud size:', points.shape[0])


def write_ascii_ply(path, xyz, rgb):
    """Write an ASCII PLY with float x/y/z and uchar red/green/blue vertex properties.

    Rows are written with a single vectorized ``np.savetxt`` call instead of a Python loop.
    ``%.9g`` keeps float32 coordinates round-trippable.

    Raises:
        ValueError: if ``xyz`` and ``rgb`` have different numbers of rows.
    """
    xyz = np.asarray(xyz, dtype=np.float64).reshape(-1, 3)
    rgb = np.asarray(rgb).reshape(-1, 3).astype(np.uint8)
    if len(xyz) != len(rgb):
        raise ValueError(f'{len(xyz)} points but {len(rgb)} colors')
    header = (
        'ply\nformat ascii 1.0\n'
        f'element vertex {len(xyz)}\n'
        'property float x\nproperty float y\nproperty float z\n'
        'property uchar red\nproperty uchar green\nproperty uchar blue\n'
        'end_header\n'
    )
    with open(path, 'w') as f:
        f.write(header)
        np.savetxt(f, np.column_stack([xyz, rgb]), fmt=['%.9g'] * 3 + ['%d'] * 3)


# Save PLY
ply_path = OUT_DIR / f'{SCAN}_pointcloud.ply'
write_ascii_ply(ply_path, points, colors)
print('Saved PLY:', ply_path)

# Interactive point cloud HTML
pc_fig = go.Figure(data=[go.Scatter3d(
    x=points[:, 0], y=points[:, 1], z=points[:, 2],
    mode='markers',
    marker=dict(size=1.5, color=['rgb(%d,%d,%d)' % tuple(c) for c in colors])
)])
pc_fig.update_layout(scene_aspectmode='data', title=f'{SCAN} Point Cloud')
pc_html = OUT_DIR / f'{SCAN}_pointcloud.html'
pc_fig.write_html(pc_html, include_plotlyjs='cdn')
print('Saved HTML:', pc_html)

Point cloud size: 468583
Saved PLY: results/mvsnerf_dtu/scan1_pointcloud.ply
Saved HTML: results/mvsnerf_dtu/scan1_pointcloud.html


In [24]:
# Reconstruct mesh via marching cubes on density volume
def voxel_to_world(grid_xyz, ijk):
    """Trilinearly sample a ``[D, H, W, C]`` grid at fractional ``(d, h, w)`` indices.

    marching_cubes returns vertices at sub-voxel positions, in index units when spacing
    is 1. Interpolating the per-voxel world-coordinate grid keeps them on the extracted
    iso-surface. Rounding to the nearest voxel would collapse neighbouring vertices and
    create degenerate faces.

    Args:
        grid_xyz: array of shape [D, H, W, C] (here C=3 world coordinates per voxel).
        ijk: array of shape [N, 3] of fractional indices into the first three axes;
            values outside the grid are clamped to its boundary.

    Returns:
        Array of shape [N, C].
    """
    grid = np.asarray(grid_xyz)
    idx = np.asarray(ijk, dtype=np.float64).reshape(-1, 3)
    dims = np.array(grid.shape[:3])
    # Lower corner, clamped so lower+1 stays inside the grid (size-1 axes degenerate to 0).
    lower = np.clip(np.floor(idx).astype(np.int64), 0, np.maximum(dims - 2, 0))
    frac = np.clip(idx - lower, 0.0, 1.0)
    upper = np.minimum(lower + 1, dims - 1)
    out = np.zeros((idx.shape[0], grid.shape[3]))
    for use_d in (False, True):
        wd = frac[:, 0] if use_d else 1.0 - frac[:, 0]
        id_ = upper[:, 0] if use_d else lower[:, 0]
        for use_h in (False, True):
            wh = frac[:, 1] if use_h else 1.0 - frac[:, 1]
            ih = upper[:, 1] if use_h else lower[:, 1]
            for use_w in (False, True):
                ww = frac[:, 2] if use_w else 1.0 - frac[:, 2]
                iw = upper[:, 2] if use_w else lower[:, 2]
                out += (wd * wh * ww)[:, None] * grid[id_, ih, iw]
    return out


level = float(thr)
verts, faces, _normals, _values = marching_cubes(volume=sigma, level=level, spacing=(1.0, 1.0, 1.0))
world_coords = voxel_to_world(coords, verts)

# Save a simple OBJ mesh (1-based face indices, as OBJ requires)
obj_path = OUT_DIR / f'{SCAN}_mesh.obj'
with open(obj_path, 'w') as f:
    np.savetxt(f, world_coords, fmt='v %.9g %.9g %.9g')
    np.savetxt(f, np.asarray(faces, dtype=np.int64) + 1, fmt='f %d %d %d')
print('Saved OBJ:', obj_path)

# Interactive mesh HTML
mesh_fig = go.Figure(data=[go.Mesh3d(
    x=world_coords[:, 0], y=world_coords[:, 1], z=world_coords[:, 2],
    i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
    color='lightblue', opacity=0.5
)])
mesh_fig.update_layout(scene_aspectmode='data', title=f'{SCAN} Mesh')
mesh_html = OUT_DIR / f'{SCAN}_mesh.html'
mesh_fig.write_html(mesh_html, include_plotlyjs='cdn')
print('Saved HTML:', mesh_html)

Saved OBJ: results/mvsnerf_dtu/scan1_mesh.obj
Saved HTML: results/mvsnerf_dtu/scan1_mesh.html


In [25]:
# Side-by-side visualization: input views + point cloud + mesh
from plotly.subplots import make_subplots

# Build a simple montage from input views
def make_montage(tensor_imgs, max_views=4, grid=(2, 2)):
    """Tile the first views of a ``[1, V, 3, H, W]`` image tensor into one uint8 RGB image.

    Args:
        tensor_imgs: tensor shaped [1, V, 3, H, W], values expected in [0, 1]
            (anything outside is clipped).
        max_views: maximum number of views to show.
        grid: (rows, cols) of the montage; at most rows*cols views are shown, and
            empty tiles repeat the last shown view. Defaults to a 2x2 grid.

    Returns:
        np.ndarray of shape [rows*H, cols*W, 3], dtype uint8.

    Raises:
        ValueError: if ``grid`` is not positive or no view is selected
            (V == 0 or max_views < 1).
    """
    rows, cols = grid
    if rows < 1 or cols < 1:
        raise ValueError(f'grid must be positive, got {grid}')
    n_tiles = rows * cols
    imgs_np = tensor_imgs[0].permute(0, 2, 3, 1).detach().cpu().numpy()  # [V, H, W, 3]
    n_views = min(max_views, n_tiles, imgs_np.shape[0])
    if n_views < 1:
        # Padding repeats the last view, so at least one view is required.
        raise ValueError('make_montage needs at least one view')
    tiles = imgs_np[:n_views]
    if n_views < n_tiles:
        tiles = np.concatenate([tiles, np.repeat(tiles[-1:], n_tiles - n_views, axis=0)], axis=0)
    grid_rows = [np.concatenate(list(tiles[r * cols:(r + 1) * cols]), axis=1) for r in range(rows)]
    montage = np.concatenate(grid_rows, axis=0)
    return (np.clip(montage, 0, 1) * 255).astype(np.uint8)

# Show every view that was fed to MVSNet (N_VIEWS of them), not only the FEATURE_VIEWS
# subset used for colour features. With 3 colour views the 2x2 montage would otherwise show
# a repeated view in place of the real 4th input view.
montage = make_montage(imgs_un, max_views=N_VIEWS)

# Per-point colour strings for the point-cloud trace.
marker_colors = ['rgb(%d,%d,%d)' % tuple(c) for c in colors]

fig = make_subplots(
    rows=1, cols=3,
    specs=[[{"type": "image"}, {"type": "scene"}, {"type": "scene"}]],
    column_widths=[0.35, 0.325, 0.325],
    subplot_titles=["Input Views", "Point Cloud", "Mesh"]
)
fig.add_trace(go.Image(z=montage), row=1, col=1)
fig.add_trace(go.Scatter3d(
    x=points[:, 0], y=points[:, 1], z=points[:, 2],
    mode='markers',
    marker=dict(size=1.5, color=marker_colors)
), row=1, col=2)
fig.add_trace(go.Mesh3d(
    x=world_coords[:, 0], y=world_coords[:, 1], z=world_coords[:, 2],
    i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
    color='lightblue', opacity=0.5
), row=1, col=3)
# The two 3D subplots are the layout's 'scene' and 'scene2'.
fig.update_layout(height=500, showlegend=False, scene_aspectmode='data', scene2_aspectmode='data')

# Save combined visualization to HTML (browser-friendly without nbformat)
combined_html = OUT_DIR / f'{SCAN}_combined.html'
fig.write_html(combined_html, include_plotlyjs='cdn')
print('Saved HTML:', combined_html)

Saved HTML: results/mvsnerf_dtu/scan1_combined.html


## Usage
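- Set the `DTU_ROOT`, `DTU_SCAN`, and `MVSNERF_CKPT` environment variables, or edit the config cell. Optional: `DTU_N_VIEWS`, `DTU_DOWNSAMPLE`, and `DTU_PAD`.
- Run all code cells in order, top to bottom (Restart & Run All). Outputs are saved under `results/mvsnerf_dtu/`.
- Open the HTML files in your browser: point cloud, mesh, and combined visualizations. The PLY and OBJ files can be opened in any 3D viewer.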

Example to run in terminal before opening the notebook:

```bash
export DTU_ROOT=/data/DTU
export DTU_SCAN=scan1
export MVSNERF_CKPT=runs_fine_tuning/scan1-ft/ckpts/latest.tar
```