In [2]:
import torch
import struct
import numpy as np
import matplotlib.pyplot as plt
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from torch.cuda.amp import autocast

device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+) 
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

In [3]:
# Initialize the model and load the pretrained weights.
# This will automatically download the model weights the first time it's run, which may take a while.
model = VGGT()
_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
model.to(device)

VGGT(
  (aggregator): Aggregator(
    (patch_embed): DinoVisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
        (norm): Identity()
      )
      (blocks): ModuleList(
        (0-23): 24 x NestedTensorBlock(
          (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (attn): MemEffAttention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): LayerScale()
          (drop_path1): Identity()
          (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (act): GELU(approximate=

In [4]:
SCENE="banana"
SKIP=1

if SCENE=="banana": 
    # Load and preprocess example images (replace with your own image paths)
    image_names = [
        "/home/skhalid/Documents/data/banana/input/frame_00001.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00002.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00003.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00004.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00005.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00006.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00007.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00008.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00009.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00010.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00011.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00012.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00013.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00014.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00015.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00016.JPG"
    ]
    ### BANANA
    width = 3008
    height = 2000
    BASE_PATH = "/home/skhalid/Documents/data/banana"
    INTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/cameras.bin"
    EXTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/images.bin"
    PTS_PATH = BASE_PATH+"/sparse/0/points3D.ply"
    PREFIX = "frame_"
    START_ID = 0
    N = 1_000

elif SCENE=="lego": 
    ### LEGO
    image_names = ["/home/skhalid/Documents/data/nerf_synthetic/lego/train/r_"+str(v)+".png" for v in range(0, 99, SKIP)]
    width = 800
    height = 800
    BASE_PATH = "/home/skhalid/Documents/data/nerf_synthetic/lego/"
    INTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/cameras.bin"
    EXTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/images.bin"
    PTS_PATH = BASE_PATH+"/sparse/0/points3D.ply"
    PREFIX = "r_"
    START_ID = 0
    N = 200_000

elif SCENE=="bicycle": 
    ### BICYCLE
    BASE="/home/skhalid/Documents/data/360_v2/bicycle/images_4/_DSC"
    image_names = [BASE+str(v)+".JPG" for v in range(8679, 8873, SKIP)]
    width = 1236
    height = 821    
    BASE_PATH = "/home/skhalid/Documents/data/360_v2/bicycle"
    INTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/cameras.bin"
    EXTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/images.bin"
    PTS_PATH = BASE_PATH+"/sparse/0/points3D.ply"
    PREFIX = "_DSC"
    START_ID = 0
    N = 3_000_000
    # test_cases = ["8679.JPG",
    #               "8687.JPG",
    #               "8695.JPG",
    #               "8703.JPG",
    #               "8711.JPG",
    #               "8719.JPG",
    #               "8727.JPG",
    #               "8735.JPG",
    #               "8744.JPG",
    #               "8752.JPG",
    #               "8760.JPG",
    #               "8768.JPG",
    #               "8776.JPG",
    #               "8784.JPG",
    #               "8792.JPG",
    #               "8800.JPG",
    #               "8808.JPG",
    #               "8816.JPG",
    #               "8824.JPG",
    #               "8832.JPG",
    #               "8840.JPG",
    #               "8848.JPG",
    #               "8856.JPG",
    #               "8864.JPG",
    #               "8872.JPG"]
    # for test_case in test_cases:
    #     image_names.append(BASE+str(test_case))

elif SCENE=="truck": 
    ### BICYCLE
    BASE="/home/skhalid/Documents/data/tandt_db/tandt/truck/images/"
    image_names = [BASE+str(v).zfill(6)+".jpg" for v in range(1, 252, SKIP)]
    width = 1957
    height = 1091    
    BASE_PATH = "/home/skhalid/Documents/data/tandt_db/tandt/truck"
    INTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/cameras.bin"
    EXTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/images.bin"
    PTS_PATH = BASE_PATH+"/sparse/0/points3D.ply"
    PREFIX = ""
    START_ID = 0
    N = 200_000
    # test_cases = ["8679.JPG",
    #               "8687.JPG",
    #               "8695.JPG",
    #               "8703.JPG",
    #               "8711.JPG",
    #               "8719.JPG",
    #               "8727.JPG",
    #               "8735.JPG",
    #               "8744.JPG",
    #               "8752.JPG",
    #               "8760.JPG",
    #               "8768.JPG",
    #               "8776.JPG",
    #               "8784.JPG",
    #               "8792.JPG",
    #               "8800.JPG",
    #               "8808.JPG",
    #               "8816.JPG",
    #               "8824.JPG",
    #               "8832.JPG",
    #               "8840.JPG",
    #               "8848.JPG",
    #               "8856.JPG",
    #               "8864.JPG",
    #               "8872.JPG"]
    # for test_case in test_cases:
    #     image_names.append(BASE+str(test_case))


In [5]:
import torch
from tqdm import tqdm

def run_batched_camera_inference(model, image_names, batch_size=8, device='cuda', dtype=torch.float16):
    from vggt.utils.pose_enc import pose_encoding_to_extri_intri
    from vggt.utils.geometry import unproject_depth_map_to_point_map
    # from vggt.utils.io import load_and_preprocess_images

    all_extrinsics = []
    all_intrinsics = []
    all_world_points = []
    depth_maps = []
    depth_conf_maps = []
    batch_tensors = []

    # Batch the rest of the images
    print(f"Processing the rest of {len(image_names)} images in batches of {batch_size}...")
    for i in tqdm(range(0, len(image_names), batch_size)):
        batch_names = image_names[i:i + batch_size]
        batch_tensor = load_and_preprocess_images(batch_names).to(device)

        if i==0:
            first_image = batch_tensor[0]
            print("first_image.shape: {}".format(batch_tensor.shape))
        else:
            # Add the first reference image to this batch as well
            batch_tensor = torch.cat((first_image[None], batch_tensor), dim=0)

        with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
            batch_tensor = batch_tensor[None]  # Add batch dim
            agg_tokens, ps_idx = model.aggregator(batch_tensor)

            pose_enc = model.camera_head(agg_tokens)[-1]
            extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, batch_tensor.shape[-2:])

            depth_map, depth_conf_map = model.depth_head(agg_tokens, batch_tensor, ps_idx)
            
            point_map_unproj = unproject_depth_map_to_point_map(depth_map.squeeze(0), extrinsic.squeeze(0), intrinsic.squeeze(0))
    
            # if i==0:    
            #     print("batch: {} | point_map_unproj.shape: {}".format(i, point_map_unproj.shape))
            # else:
            #     print("batch: {} | point_map_unproj[1:, ...].shape: {}".format(i, point_map_unproj[1:, ...].shape))

            if i==0:
                all_extrinsics.append(extrinsic[0, ...])
                all_intrinsics.append(intrinsic[0, ...])
                all_world_points.append(point_map_unproj)
                depth_maps.append(depth_map[0, ...])
                depth_conf_maps.append(depth_conf_map[0, ...])
                batch_tensors.append(batch_tensor[0, ...])
            else:
                all_extrinsics.append(extrinsic[0, 1:])
                all_intrinsics.append(intrinsic[0, 1:])
                all_world_points.append(point_map_unproj[1:, ...])
                depth_maps.append(depth_map[0, 1:, ...])
                depth_conf_maps.append(depth_conf_map[0, 1:, ...])
                batch_tensors.append(batch_tensor[0, 1:, ...])

            print("extrinsic: {}".format(extrinsic.shape))
            print("intrinsic: {}".format(intrinsic.shape))
            print("point_map_unproj: {}".format(point_map_unproj.shape))
            print("depth_map: {}".format(depth_map.shape))
            print("depth_conf_map: {}".format(depth_conf_map.shape))
            print("batch_tensor: {}".format(batch_tensor.shape))

    # Stack everything
    batch_tensors = torch.cat(batch_tensors)  # [N, 4, 4]
    all_extrinsics = torch.cat(all_extrinsics)  # [N, 4, 4]
    all_intrinsics = torch.cat(all_intrinsics)  # [N, 3, 3]
    all_world_points = np.concatenate(all_world_points)  # [N, H, W, 3]
    depth_maps = torch.cat(depth_maps, dim=0)  # [N, H, W, 3]
    depth_conf_maps = torch.cat(depth_conf_maps, dim=0)  # [N, H, W, 3]

    return {
        "all_extrinsics": all_extrinsics, 
        "all_intrinsics": all_intrinsics, 
        "all_world_points": all_world_points,
        "depth_maps": depth_maps,
        "depth_conf_maps": depth_conf_maps,
        "all_images": batch_tensors
    }

    # # Predict Tracks
    # # choose your own points to track, with shape (N, 2) for one scene
    # query_points = torch.FloatTensor([[100.0, 200.0], 
    #                                     [60.72, 259.94]]).to(device)
    # track_list, vis_score, conf_score = model.track_head(aggregated_tokens_list, images, ps_idx, query_points=query_points[None])

In [6]:
'''
Run inference here
'''
BATCH_SIZE=20
SINGLE_SAMPLE = False
normals = None

predictions = run_batched_camera_inference(model, image_names, batch_size=BATCH_SIZE)

all_extrinsics = predictions["all_extrinsics"]
all_intrinsics = predictions["all_intrinsics"]
all_world_points = predictions["all_world_points"]
depth_maps = predictions["depth_maps"]
depth_conf_maps = predictions["depth_conf_maps"]
all_images = predictions["all_images"]

print(all_extrinsics.shape, all_intrinsics.shape, all_world_points.shape, depth_maps.shape, depth_conf_maps.shape, all_images.shape)

Processing the rest of 16 images in batches of 20...


  0%|          | 0/1 [00:00<?, ?it/s]

first_image.shape: torch.Size([16, 3, 350, 518])


  with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
100%|██████████| 1/1 [00:01<00:00,  1.60s/it]

extrinsic: torch.Size([1, 16, 3, 4])
intrinsic: torch.Size([1, 16, 3, 3])
point_map_unproj: (16, 350, 518, 3)
depth_map: torch.Size([1, 16, 350, 518, 1])
depth_conf_map: torch.Size([1, 16, 350, 518])
batch_tensor: torch.Size([1, 16, 3, 350, 518])
torch.Size([16, 3, 4]) torch.Size([16, 3, 3]) (16, 350, 518, 3) torch.Size([16, 350, 518, 1]) torch.Size([16, 350, 518]) torch.Size([16, 3, 350, 518])





In [1]:
import os
import torch
import imageio
from torchvision.utils import save_image

def save_grayscale_image(img_tensor, save_path):
    """
    Save a single-channel image as a 3-channel grayscale PNG.
    """
    img_uint8 = normalize_to_uint8(img_tensor).numpy()
    img_uint8 = np.stack([img_uint8.squeeze()]*3, axis=-1)  # Convert to (H, W, 3)
    print(img_uint8.shape)
    imageio.imwrite(save_path, img_uint8)

# Your prediction tensors
depth_maps = predictions["depth_maps"].cpu()            # (N, H, W)
depth_conf_maps = predictions["depth_conf_maps"].cpu()  # (N, H, W)
all_images = predictions["all_images"].cpu()            # (N, 3, H, W) or (N, H, W, 3)

# Output subfolders
save_dirs = {
    "rgb": "rgb",
    "depth": "depth",
    "depth_conf": "depth_conf"
}

# Create the directories if they don't exist
for d in save_dirs.values():
    os.makedirs(d, exist_ok=True)

# Helper to normalize for saving
def normalize_to_uint8(tensor):
    tensor = tensor.clone()
    tensor -= tensor.min()
    tensor /= tensor.max() + 1e-8
    return (tensor * 255).byte()

# Save each frame
for i in range(depth_maps.shape[0]):
    fname_rgb = image_names[i].replace("input", "rgb")  # same base filename across folders
    fname_depth = image_names[i].replace("input", "depth")  # same base filename across folders
    fname_normal = image_names[i].replace("input", "normal")  # same base filename across folders

    # Save RGB image
    img = all_images[i]
    if img.dim() == 3 and img.shape[0] == 3:
        os.makedirs(os.path.dirname(fname_rgb), exist_ok=True)
        save_image(img, fname_rgb)
    elif img.dim() == 3 and img.shape[2] == 3:
        imageio.imwrite(fname_rgb, img.numpy())

    # Save depth
    os.makedirs(os.path.dirname(fname_depth), exist_ok=True)
    save_grayscale_image(depth_maps[i], fname_depth)    

    # Save confidence
    os.makedirs(os.path.dirname(fname_normal), exist_ok=True)
    save_grayscale_image(depth_conf_maps[i], fname_normal)


NameError: name 'predictions' is not defined