In [2]:
import torch
import numpy as np
from PIL import Image
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.geometry import unproject_depth_map_to_point_map

# Configuration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1. Initialize VGGT model
model = VGGT.from_pretrained("facebook/VGGT-1B").to(DEVICE)
model.eval()

  from .autonotebook import tqdm as notebook_tqdm


VGGT(
  (aggregator): Aggregator(
    (patch_embed): DinoVisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
        (norm): Identity()
      )
      (blocks): ModuleList(
        (0-23): 24 x NestedTensorBlock(
          (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (attn): MemEffAttention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): LayerScale()
          (drop_path1): Identity()
          (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (act): GELU(approximate=

In [3]:
LOW_RES_WIDTH = 518  # recommended low-res width

In [4]:
import os

# Read images from 'people/images' directory
# image_dir = 'examples/room/images'
image_dir = 'people/images'
IMAGE_PATHS = [os.path.join(image_dir, img) for img in os.listdir(image_dir) if img.endswith(('.png', '.jpg', '.jpeg'))]
print(IMAGE_PATHS)

['people/images/WechatIMG199.jpg', 'people/images/WechatIMG200.jpg', 'people/images/WechatIMG201.jpg', 'people/images/WechatIMG202.jpg', 'people/images/WechatIMG203.jpg']


In [5]:
# 2. First stage: low-resolution camera parameter estimation
# Load and resize images to fixed width
low_res_images = load_and_preprocess_images(
    IMAGE_PATHS,
    mode='pad'
).to(DEVICE)  # shape: (B, N, 3, H_low, W_low)

# Forward pass to get camera intrinsics and extrinsics
with torch.no_grad():
    predictions_low = model(low_res_images)

In [None]:
# Extract intrinsics (fx, fy, cx, cy) and extrinsics (4x4 matrices)
intrinsics = predictions_low['intrinsics']  # shape: (B, N, 4)
extrinsics = predictions_low['extrinsics']  # shape: (B, N, 4, 4)

# Save or store for second stage
torch.save({'intrinsics': intrinsics, 'extrinsics': extrinsics}, 'camera_params.pth')

# 3. Second stage: high-resolution depth estimation
# Load high-resolution images without resizing
def load_high_res(paths):
    imgs = []
    for p in paths:
        img = Image.open(p).convert('RGB')
        arr = np.array(img).astype(np.float32) / 255.0  # normalize to [0,1]
        tensor = torch.from_numpy(arr).permute(2, 0, 1)  # C,H,W
        imgs.append(tensor)
    batch = torch.stack(imgs, dim=0)  # (N,3,H,W)
    return batch.unsqueeze(0).to(DEVICE)  # (1,N,3,H,W)

high_res_images = load_high_res(IMAGE_PATHS)

# Load saved camera parameters
cp = torch.load('camera_params.pth', map_location=DEVICE)
intrinsics = cp['intrinsics']
extrinsics = cp['extrinsics']

# Forward pass for depth
with torch.no_grad():
    predictions_high = model.forward_depth(
        high_res_images,
        intrinsics=intrinsics,
        extrinsics=extrinsics
    )

# Extract depth maps
depth_maps = predictions_high['depth_map']  # (B,N,H,W)

# 4. Unproject depth maps to point maps
point_maps = []
for i in range(depth_maps.shape[1]):  # iterate over N views
    depth = depth_maps[0, i]  # (H,W)
    intr = intrinsics[0, i]   # (4,) fx,fy,cx,cy
    ext = extrinsics[0, i]    # (4,4)
    pts = unproject_depth_map_to_point_map(depth, intr, ext)
    point_maps.append(pts.cpu().numpy())  # (H*W,3)

# Save or merge point clouds
def save_point_cloud(points, filename):
    with open(filename, 'w') as f:
        for x, y, z in points:
            f.write(f"v {x} {y} {z}\n")

for idx, pts in enumerate(point_maps):
    save_point_cloud(pts.reshape(-1, 3), f"point_map_view{idx}.obj")

print("Two-stage VGGT pipeline completed. Point clouds saved as OBJ files.")
