In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch

from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images

In [2]:
# Pick the compute device, falling back to the CPU when no CUDA GPU is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"

# bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+); older GPUs use float16.
# The capability is only queried when CUDA is actually available, because
# torch.cuda.get_device_capability() raises on a machine without a usable GPU.
if device == "cuda" and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    dtype = torch.float16

In [3]:
# Build VGGT from the pretrained "facebook/VGGT-1B" checkpoint, then move it to `device`.
# The first run downloads the weights automatically, which may take a while.
model = VGGT.from_pretrained("facebook/VGGT-1B")
model = model.to(device)

In [4]:
# Load and preprocess the example images (replace with your own image paths).
# The 33 paths are frame_0.png, frame_10.png, ..., frame_320.png.
rt = "../data/mustard360/images/frame_"
ext = ".png"
image_names = [f"{rt}{frame_idx * 10}{ext}" for frame_idx in range(33)]
images = load_and_preprocess_images(image_names)
images = images.to(device)

In [5]:
# Run the forward pass with gradient tracking disabled and CUDA autocast at `dtype`.
# torch.cuda.amp.autocast(...) is deprecated; torch.amp.autocast("cuda", ...) is the
# replacement named by PyTorch's FutureWarning and takes the same keyword arguments.
with torch.no_grad():
    with torch.amp.autocast("cuda", dtype=dtype):
        # Predict attributes including cameras, depth maps, and point maps.
        predictions = model(images)

  with torch.cuda.amp.autocast(dtype=dtype):
  with torch.cuda.amp.autocast(enabled=False):


In [6]:
import open3d as o3d

# Copy the predictions used below to host memory as NumPy arrays and keep only
# the first entry along the leading axis of each one.
depths, world_points, world_points_conf = (
    predictions[key].cpu().numpy()[0]
    for key in ("depth", "world_points", "world_points_conf")
)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [7]:
# Show the shapes of the point map and of its confidence map, one per line.
print(world_points.shape, world_points_conf.shape, sep="\n")

(33, 294, 518, 3)
(33, 294, 518)


In [15]:
# Keep only the points whose confidence exceeds 1.1 and show them as a point cloud.
cmask = np.greater(world_points_conf, 1.1)
cpoints = world_points[cmask]  # boolean indexing leaves one xyz row per kept point

# Build the cloud directly from the selected points and open the Open3D viewer.
pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(cpoints))
o3d.visualization.draw_geometries([pcd])

In [10]:
# Colour every predicted point by its confidence and display the resulting cloud.

# Flatten the maps into one row per point: xyz coordinates and a scalar confidence.
pts = world_points.reshape(-1, 3)
conf = world_points_conf.reshape(-1)

# A point is usable only when all three coordinates and its confidence are finite.
valid = np.isfinite(pts).all(axis=1) & np.isfinite(conf)

# Min-max normalise the confidence of the usable points to [0, 1].
# Unusable points keep 0; if every usable confidence is equal, they all become 1.
conf_norm = np.zeros_like(conf, dtype=np.float64)
if valid.any():
    conf_valid = conf[valid]
    cmin = conf_valid.min()
    cmax = conf_valid.max()
    if cmax == cmin:
        conf_norm[valid] = 1.0
    else:
        conf_norm[valid] = (conf_valid - cmin) / (cmax - cmin)

# HSV -> RGB with saturation fixed at 1 and constant value V.
# Hue runs from 0 (red, lowest confidence) up to 2/3 (blue, highest confidence).
V = 0.9
hue = conf_norm * (2.0 / 3.0)
hue_scaled = hue * 6.0
sector = np.floor(hue_scaled).astype(int)  # which sixth of the hue circle
frac = hue_scaled - sector                 # position inside that sixth
sector_idx = sector % 6

# Candidate channel values, one column each: 0 = v, 1 = t, 2 = p, 3 = q.
value = np.full_like(frac, V)
candidates = np.stack(
    [value, value * frac, np.zeros_like(frac), value * (1.0 - frac)],
    axis=1,
)

# For each hue sector, the candidate column feeding the (R, G, B) channels.
sector_to_columns = np.array(
    [
        [0, 1, 2],  # sector 0: (v, t, p)
        [3, 0, 2],  # sector 1: (q, v, p)
        [2, 0, 1],  # sector 2: (p, v, t)
        [2, 3, 0],  # sector 3: (p, q, v)
        [1, 2, 0],  # sector 4: (t, p, v)
        [0, 2, 3],  # sector 5: (v, p, q)
    ]
)
rgb = np.take_along_axis(candidates, sector_to_columns[sector_idx], axis=1)

# Unusable points are blanked to black, then dropped from what gets drawn.
rgb[~valid] = 0.0
pts_draw = pts[valid]
rgb_draw = rgb[valid]

# Assemble the coloured point cloud and open the Open3D viewer.
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(pts_draw)
pcd.colors = o3d.utility.Vector3dVector(rgb_draw)
o3d.visualization.draw_geometries([pcd])
