In [1]:
import torch
import struct
import numpy as np
import matplotlib.pyplot as plt
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from torch.cuda.amp import autocast

# device = "cuda" if torch.cuda.is_available() else "cpu"
device_vggt = "cpu"

# bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+) 
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

In [2]:
# Initialize the model and load the pretrained weights.
# This will automatically download the model weights the first time it's run, which may take a while.
model_vggt = VGGT()
_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
model_vggt.load_state_dict(torch.hub.load_state_dict_from_url(_URL))

# TODO: Fix depth pruning
model_vggt.aggregator.patch_embed.blocks = model_vggt.aggregator.patch_embed.blocks[:]
model_vggt.to(device_vggt)

VGGT(
  (aggregator): Aggregator(
    (patch_embed): DinoVisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
        (norm): Identity()
      )
      (blocks): ModuleList(
        (0-23): 24 x NestedTensorBlock(
          (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (attn): MemEffAttention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): LayerScale()
          (drop_path1): Identity()
          (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (act): GELU(approximate=

In [3]:
def print_model_layer_sizes_sorted(model):
    param_list = []

    for name, param in model.named_parameters():
        if param.requires_grad:
            param_count = param.numel()
            size_mb = param_count * param.element_size() / 1e6  # Bytes to MB
            param_list.append((name, param_count, size_mb))

    # Sort by size descending
    param_list.sort(key=lambda x: x[2], reverse=True)

    print(f"{'Layer':60} {'Param Count':>15} {'Size (MB)':>12}")
    print("-" * 90)
    
    total_params = 0
    total_size_mb = 0.0

    for name, count, size in param_list:
        total_params += count
        total_size_mb += size
        print(f"{name:60} {count:15,} {size:12.2f}")

    print("-" * 90)
    print(f"{'TOTAL':60} {total_params:15,} {total_size_mb:12.2f} MB")

# Example usage
print_model_layer_sizes_sorted(model_vggt)


Layer                                                            Param Count    Size (MB)
------------------------------------------------------------------------------------------
camera_head.trunk.0.mlp.fc1.weight                                16,777,216        67.11
camera_head.trunk.0.mlp.fc2.weight                                16,777,216        67.11
camera_head.trunk.1.mlp.fc1.weight                                16,777,216        67.11
camera_head.trunk.1.mlp.fc2.weight                                16,777,216        67.11
camera_head.trunk.2.mlp.fc1.weight                                16,777,216        67.11
camera_head.trunk.2.mlp.fc2.weight                                16,777,216        67.11
camera_head.trunk.3.mlp.fc1.weight                                16,777,216        67.11
camera_head.trunk.3.mlp.fc2.weight                                16,777,216        67.11
camera_head.trunk.0.attn.qkv.weight                               12,582,912        50.33
camera_he

In [None]:
# # truck
# SKIP=1
# BATCH_SIZE=2
# SCENE="truck"
# TARGET = "truck" # added to the query
# SAMPLED_POINTS = 1000
# SCORE_THRESHOLD = 0.25

# # banana
# SKIP=1
# BATCH_SIZE=2
# SCENE="banana"
# TARGET = "fruit" # added to the query
# SAMPLED_POINTS = 1000
# SCORE_THRESHOLD = 0.25

# train
SKIP = 10
BATCH_SIZE = 10
SCENE = "train"
TARGET = "train" # added to the query
SAMPLED_POINTS = 500
SCORE_THRESHOLD = 0.25

if SCENE=="banana": 
    # Load and preprocess example images (replace with your own image paths)
    image_names = [
        "/home/skhalid/Documents/data/banana/input/frame_00001.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00002.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00003.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00004.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00005.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00006.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00007.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00008.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00009.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00010.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00011.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00012.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00013.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00014.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00015.JPG", 
        "/home/skhalid/Documents/data/banana/input/frame_00016.JPG"
    ]
    ### BANANA
    width = 3008
    height = 2000
    BASE_PATH = "/home/skhalid/Documents/data/banana"
    INTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/cameras.bin"
    EXTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/images.bin"
    PTS_PATH = BASE_PATH+"/sparse/0/points3D.ply"
    PREFIX = "frame_"
    START_ID = 0
    N = 200_000

elif SCENE=="lego": 
    ### LEGO
    image_names = ["/home/skhalid/Documents/data/nerf_synthetic/lego/train/r_"+str(v)+".png" for v in range(0, 99, SKIP)]
    width = 800
    height = 800
    BASE_PATH = "/home/skhalid/Documents/data/nerf_synthetic/lego/"
    INTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/cameras.bin"
    EXTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/images.bin"
    PTS_PATH = BASE_PATH+"/sparse/0/points3D.ply"
    PREFIX = "r_"
    START_ID = 0
    N = 200_000

elif SCENE=="bicycle": 
    ### BICYCLE
    BASE="/home/skhalid/Documents/data/360_v2/bicycle/images_4/_DSC"
    image_names = [BASE+str(v)+".JPG" for v in range(8679, 8873, SKIP)]
    width = 1236
    height = 821    
    BASE_PATH = "/home/skhalid/Documents/data/360_v2/bicycle"
    INTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/cameras.bin"
    EXTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/images.bin"
    PTS_PATH = BASE_PATH+"/sparse/0/points3D.ply"
    PREFIX = "_DSC"
    START_ID = 0
    N = 1_000_000
    # test_cases = ["8679.JPG",
    #               "8687.JPG",
    #               "8695.JPG",
    #               "8703.JPG",
    #               "8711.JPG",
    #               "8719.JPG",
    #               "8727.JPG",
    #               "8735.JPG",
    #               "8744.JPG",
    #               "8752.JPG",
    #               "8760.JPG",
    #               "8768.JPG",
    #               "8776.JPG",
    #               "8784.JPG",
    #               "8792.JPG",
    #               "8800.JPG",
    #               "8808.JPG",
    #               "8816.JPG",
    #               "8824.JPG",
    #               "8832.JPG",
    #               "8840.JPG",
    #               "8848.JPG",
    #               "8856.JPG",
    #               "8864.JPG",
    #               "8872.JPG"]
    # for test_case in test_cases:
    #     image_names.append(BASE+str(test_case))

elif SCENE=="truck": 
    ### BICYCLE
    BASE="/home/skhalid/Documents/data/tandt_db/tandt/truck/images/"
    image_names = [BASE+str(v).zfill(6)+".jpg" for v in range(1, 252, SKIP)]
    width = 1957
    height = 1091    
    BASE_PATH = "/home/skhalid/Documents/data/tandt_db/tandt/truck"
    INTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/cameras.bin"
    EXTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/images.bin"
    PTS_PATH = BASE_PATH+"/sparse/0/points3D.ply"
    PREFIX = ""
    START_ID = 0
    N = 1_000_000
    # test_cases = ["8679.JPG",
    #               "8687.JPG",
    #               "8695.JPG",
    #               "8703.JPG",
    #               "8711.JPG",
    #               "8719.JPG",
    #               "8727.JPG",
    #               "8735.JPG",
    #               "8744.JPG",
    #               "8752.JPG",
    #               "8760.JPG",
    #               "8768.JPG",
    #               "8776.JPG",
    #               "8784.JPG",
    #               "8792.JPG",
    #               "8800.JPG",
    #               "8808.JPG",
    #               "8816.JPG",
    #               "8824.JPG",
    #               "8832.JPG",
    #               "8840.JPG",
    #               "8848.JPG",
    #               "8856.JPG",
    #               "8864.JPG",
    #               "8872.JPG"]
    # for test_case in test_cases:
    #     image_names.append(BASE+str(test_case))

elif SCENE=="train": 
    ### BICYCLE
    BASE="/home/skhalid/Documents/data/tandt_db/tandt/train/images/"
    image_names = [BASE+str(v).zfill(5)+".jpg" for v in range(1, 252, SKIP)]
    width = 980
    height = 505    
    BASE_PATH = "/home/skhalid/Documents/data/tandt_db/tandt/train"
    INTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/cameras.bin"
    EXTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/images.bin"
    PTS_PATH = BASE_PATH+"/sparse/0/points3D.ply"
    PREFIX = ""
    START_ID = 0
    N = 1_000_000
    # test_cases = ["8679.JPG",
    #               "8687.JPG",
    #               "8695.JPG",
    #               "8703.JPG",
    #               "8711.JPG",
    #               "8719.JPG",
    #               "8727.JPG",
    #               "8735.JPG",
    #               "8744.JPG",
    #               "8752.JPG",
    #               "8760.JPG",
    #               "8768.JPG",
    #               "8776.JPG",
    #               "8784.JPG",
    #               "8792.JPG",
    #               "8800.JPG",
    #               "8808.JPG",
    #               "8816.JPG",
    #               "8824.JPG",
    #               "8832.JPG",
    #               "8840.JPG",
    #               "8848.JPG",
    #               "8856.JPG",
    #               "8864.JPG",
    #               "8872.JPG"]
    # for test_case in test_cases:
    #     image_names.append(BASE+str(test_case))


elif SCENE=="barn": 
    ### BICYCLE
    BASE="/home/skhalid/Documents/data/tandt_db/tandt/barn/images/"
    image_names = [BASE+str(v).zfill(6)+".jpg" for v in range(1, 252, SKIP)]
    width = 1920
    height = 1080    
    BASE_PATH = "/home/skhalid/Documents/data/tandt_db/tandt/barn"
    INTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/cameras.bin"
    EXTRINSICS_BINARY_PATH = BASE_PATH+"/sparse/0/images.bin"
    PTS_PATH = BASE_PATH+"/sparse/0/points3D.ply"
    PREFIX = ""
    START_ID = 0
    N = 1_000_000
    # test_cases = ["8679.JPG",
    #               "8687.JPG",
    #               "8695.JPG",
    #               "8703.JPG",
    #               "8711.JPG",
    #               "8719.JPG",
    #               "8727.JPG",
    #               "8735.JPG",
    #               "8744.JPG",
    #               "8752.JPG",
    #               "8760.JPG",
    #               "8768.JPG",
    #               "8776.JPG",
    #               "8784.JPG",
    #               "8792.JPG",
    #               "8800.JPG",
    #               "8808.JPG",
    #               "8816.JPG",
    #               "8824.JPG",
    #               "8832.JPG",
    #               "8840.JPG",
    #               "8848.JPG",
    #               "8856.JPG",
    #               "8864.JPG",
    #               "8872.JPG"]
    # for test_case in test_cases:
    #     image_names.append(BASE+str(test_case))


In [5]:
import torch
from tqdm import tqdm

def run_batched_camera_inference(model, image_names, batch_size=8, device='cuda', dtype=torch.float16):
    from vggt.utils.pose_enc import pose_encoding_to_extri_intri
    from vggt.utils.geometry import unproject_depth_map_to_point_map
    # from vggt.utils.io import load_and_preprocess_images

    all_extrinsics = []
    all_intrinsics = []
    all_world_points = []
    depth_maps = []
    depth_conf_maps = []
    batch_tensors = []
    agg_tokens_tensor = torch.Tensor([]).to(device)
    ps_idx_list = []

    # Batch the rest of the images
    print(f"Processing the rest of {len(image_names)} images in batches of {batch_size}...")
    for i in tqdm(range(0, len(image_names), batch_size)):
        batch_names = image_names[i:i + batch_size]
        if i==0:
            first_fn = batch_names[0]
            print("first_fn: {}".format(first_fn))
        else:
            batch_names.insert(0, first_fn)
            print("batch_names: {}".format(batch_names))

        batch_tensor = load_and_preprocess_images(batch_names).to(device)

        with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
            batch_tensor = batch_tensor[None]  # Add batch dim
            agg_tokens, ps_idx = model.aggregator(batch_tensor)

            pose_enc = model.camera_head(agg_tokens)[-1]
            extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, batch_tensor.shape[-2:])

            depth_map, depth_conf_map = model.depth_head(agg_tokens, batch_tensor, ps_idx)
            
            point_map_unproj = unproject_depth_map_to_point_map(depth_map.squeeze(0), extrinsic.squeeze(0), intrinsic.squeeze(0))
    
            agg_tokens = torch.stack(agg_tokens)

            if i==0:
                all_extrinsics.append(extrinsic[0, ...])
                all_intrinsics.append(intrinsic[0, ...])
                all_world_points.append(point_map_unproj)
                depth_maps.append(depth_map[0, ...])
                depth_conf_maps.append(depth_conf_map[0, ...])
                batch_tensors.append(batch_tensor[0, ...])
                agg_tokens_tensor = torch.cat((agg_tokens_tensor, agg_tokens), dim=2)
                ps_idx_list.append(ps_idx)
                print()
                print("when i==0")
                print("extrinsic: {}".format(extrinsic.shape))
                print("intrinsic: {}".format(intrinsic.shape))
                print("point_map_unproj: {}".format(point_map_unproj.shape))
                print("depth_map: {}".format(depth_map.shape))
                print("depth_conf_map: {}".format(depth_conf_map.shape))
                print("batch_tensor: {}".format(batch_tensor.shape))
                # extrinsic: torch.Size([1, 10, 3, 4])
                # intrinsic: torch.Size([1, 10, 3, 3])
                # point_map_unproj: (10, 350, 518, 3)
                # depth_map: torch.Size([1, 10, 350, 518, 1])
                # depth_conf_map: torch.Size([1, 10, 350, 518])
                # batch_tensor: torch.Size([1, 10, 3, 350, 518])
            else:
                all_extrinsics.append(extrinsic[0, 1:])
                all_intrinsics.append(intrinsic[0, 1:])
                all_world_points.append(point_map_unproj[1:, ...])
                depth_maps.append(depth_map[0, 1:, ...])
                depth_conf_maps.append(depth_conf_map[0, 1:, ...])
                batch_tensors.append(batch_tensor[0, 1:, ...])
                print("agg_tokens_tensor: {} | agg_tokens: {}".format(agg_tokens_tensor.shape, agg_tokens.shape))
                agg_tokens_tensor = torch.cat((agg_tokens_tensor, agg_tokens[:, :, 1:, ...]), dim=2)
                ps_idx_list.append(ps_idx)
                print("when i>0")
                print("extrinsic: {}".format(extrinsic.shape))
                print("intrinsic: {}".format(intrinsic.shape))
                print("point_map_unproj: {}".format(point_map_unproj.shape))
                print("depth_map: {}".format(depth_map.shape))
                print("depth_conf_map: {}".format(depth_conf_map.shape))
                print("batch_tensor: {}".format(batch_tensor.shape))
                # extrinsic: torch.Size([1, 7, 3, 4])
                # intrinsic: torch.Size([1, 7, 3, 3])
                # point_map_unproj: (7, 350, 518, 3)
                # depth_map: torch.Size([1, 7, 350, 518, 1])
                # depth_conf_map: torch.Size([1, 7, 350, 518])
                # batch_tensor: torch.Size([1, 7, 3, 350, 518])


    # Stack everything
    batch_tensors = torch.cat(batch_tensors)  # [N, 4, 4]
    all_extrinsics = torch.cat(all_extrinsics)  # [N, 4, 4]
    all_intrinsics = torch.cat(all_intrinsics)  # [N, 3, 3]
    all_world_points = np.concatenate(all_world_points)  # [N, H, W, 3]
    depth_maps = torch.cat(depth_maps, dim=0)  # [N, H, W, 3]
    depth_conf_maps = torch.cat(depth_conf_maps, dim=0)  # [N, H, W, 3]
    # aggregated_tokens_list = torch.cat(aggregated_tokens_list, dim=0)  # [N, H, W, 3]
    # ps_idx_list = torch.cat(ps_idx_list, dim=0)  # [N, H, W, 3]

    return {
        "all_extrinsics": all_extrinsics, 
        "all_intrinsics": all_intrinsics, 
        "all_world_points": all_world_points,
        "depth_maps": depth_maps,
        "depth_conf_maps": depth_conf_maps,
        "all_images": batch_tensors,
        "aggregated_tokens_list": agg_tokens_tensor,
        "ps_idx_list": ps_idx_list
    }

    # # Predict Tracks
    # # choose your own points to track, with shape (N, 2) for one scene
    # query_points = torch.FloatTensor([[100.0, 200.0], 
    #                                     [60.72, 259.94]]).to(device)
    # track_list, vis_score, conf_score = model.track_head(aggregated_tokens_list, images, ps_idx, query_points=query_points[None])

In [6]:
predictions = run_batched_camera_inference(model_vggt, image_names, device=device_vggt, batch_size=BATCH_SIZE)

Processing the rest of 26 images in batches of 2...


  0%|          | 0/13 [00:00<?, ?it/s]

first_fn: /home/skhalid/Documents/data/tandt_db/tandt/train/images/00001.jpg


  with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):


5 None None
torch.Size([1, 2, 777, 2048])
tensor([[[[ 0.1673,  0.0909, -0.0761,  ...,  0.0427, -0.0496,  0.0481],
          [ 0.0940,  0.1096, -0.1159,  ...,  0.0482, -0.0069,  0.0618],
          [ 0.0743,  0.0254, -0.0543,  ...,  0.0177, -0.0484,  0.0746],
          ...,
          [ 0.1780,  0.2604,  0.1352,  ...,  0.0397, -0.1301,  0.0439],
          [ 0.2488, -0.0037,  0.0573,  ...,  0.0611, -0.0537,  0.1125],
          [ 0.1231,  0.0809,  0.0473,  ...,  0.0296, -0.0943,  0.0192]],

         [[ 0.0752,  0.1526, -0.0438,  ..., -0.0575,  0.0858,  0.0430],
          [ 0.0564, -0.0451,  0.1498,  ..., -0.0171,  0.1415,  0.0907],
          [ 0.1446,  0.1104,  0.0016,  ..., -0.0689,  0.0769,  0.0445],
          ...,
          [-0.0059,  0.1162,  0.3848,  ..., -0.1466,  0.0770, -0.0598],
          [-0.0151,  0.0765,  0.1950,  ..., -0.0416, -0.0109, -0.0715],
          [-0.0248,  0.0683,  0.0687,  ..., -0.0160, -0.0245, -0.0505]]]])
5 None None
torch.Size([1, 2, 777, 2048])
tensor([[[[-0.228

  8%|▊         | 1/13 [00:06<01:20,  6.69s/it]


when i==0
extrinsic: torch.Size([1, 2, 3, 4])
intrinsic: torch.Size([1, 2, 3, 3])
point_map_unproj: (2, 294, 518, 3)
depth_map: torch.Size([1, 2, 294, 518, 1])
depth_conf_map: torch.Size([1, 2, 294, 518])
batch_tensor: torch.Size([1, 2, 3, 294, 518])
batch_names: ['/home/skhalid/Documents/data/tandt_db/tandt/train/images/00001.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00021.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00031.jpg']
5 None None
torch.Size([1, 3, 777, 2048])
tensor([[[[ 0.1624,  0.0863, -0.0741,  ...,  0.0430, -0.0472,  0.0506],
          [ 0.0900,  0.1055, -0.1159,  ...,  0.0481, -0.0045,  0.0645],
          [ 0.0694,  0.0203, -0.0542,  ...,  0.0178, -0.0466,  0.0764],
          ...,
          [ 0.1696,  0.2570,  0.1283,  ...,  0.0409, -0.1306,  0.0409],
          [ 0.2414, -0.0074,  0.0510,  ...,  0.0607, -0.0535,  0.1090],
          [ 0.1142,  0.0778,  0.0428,  ...,  0.0299, -0.0935,  0.0169]],

         [[ 0.1260,  0.0931, -0.0

 15%|█▌        | 2/13 [00:24<02:26, 13.31s/it]

agg_tokens_tensor: torch.Size([24, 1, 2, 782, 2048]) | agg_tokens: torch.Size([24, 1, 3, 782, 2048])
when i>0
extrinsic: torch.Size([1, 3, 3, 4])
intrinsic: torch.Size([1, 3, 3, 3])
point_map_unproj: (3, 294, 518, 3)
depth_map: torch.Size([1, 3, 294, 518, 1])
depth_conf_map: torch.Size([1, 3, 294, 518])
batch_tensor: torch.Size([1, 3, 3, 294, 518])
batch_names: ['/home/skhalid/Documents/data/tandt_db/tandt/train/images/00001.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00041.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00051.jpg']
5 None None
torch.Size([1, 3, 777, 2048])
tensor([[[[ 0.1631,  0.0944, -0.0690,  ...,  0.0425, -0.0513,  0.0533],
          [ 0.0907,  0.1145, -0.1094,  ...,  0.0481, -0.0082,  0.0675],
          [ 0.0714,  0.0301, -0.0488,  ...,  0.0178, -0.0489,  0.0799],
          ...,
          [ 0.1711,  0.2701,  0.1345,  ...,  0.0409, -0.1322,  0.0454],
          [ 0.2440,  0.0008,  0.0536,  ...,  0.0605, -0.0549,  0.1127],
        

 23%|██▎       | 3/13 [00:46<02:53, 17.32s/it]

agg_tokens_tensor: torch.Size([24, 1, 4, 782, 2048]) | agg_tokens: torch.Size([24, 1, 3, 782, 2048])
when i>0
extrinsic: torch.Size([1, 3, 3, 4])
intrinsic: torch.Size([1, 3, 3, 3])
point_map_unproj: (3, 294, 518, 3)
depth_map: torch.Size([1, 3, 294, 518, 1])
depth_conf_map: torch.Size([1, 3, 294, 518])
batch_tensor: torch.Size([1, 3, 3, 294, 518])
batch_names: ['/home/skhalid/Documents/data/tandt_db/tandt/train/images/00001.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00061.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00071.jpg']
5 None None
torch.Size([1, 3, 777, 2048])
tensor([[[[ 0.1668,  0.0950, -0.0597,  ...,  0.0447, -0.0535,  0.0467],
          [ 0.0954,  0.1161, -0.0989,  ...,  0.0517, -0.0101,  0.0604],
          [ 0.0774,  0.0357, -0.0404,  ...,  0.0217, -0.0502,  0.0717],
          ...,
          [ 0.1751,  0.2753,  0.1377,  ...,  0.0413, -0.1334,  0.0398],
          [ 0.2505,  0.0102,  0.0579,  ...,  0.0606, -0.0563,  0.1079],
        

 31%|███       | 4/13 [00:58<02:17, 15.24s/it]

agg_tokens_tensor: torch.Size([24, 1, 6, 782, 2048]) | agg_tokens: torch.Size([24, 1, 3, 782, 2048])
when i>0
extrinsic: torch.Size([1, 3, 3, 4])
intrinsic: torch.Size([1, 3, 3, 3])
point_map_unproj: (3, 294, 518, 3)
depth_map: torch.Size([1, 3, 294, 518, 1])
depth_conf_map: torch.Size([1, 3, 294, 518])
batch_tensor: torch.Size([1, 3, 3, 294, 518])
batch_names: ['/home/skhalid/Documents/data/tandt_db/tandt/train/images/00001.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00081.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00091.jpg']
5 None None
torch.Size([1, 3, 777, 2048])
tensor([[[[ 1.7430e-01,  7.5612e-02, -5.4490e-02,  ...,  4.1338e-02,
           -4.6672e-02,  5.2860e-02],
          [ 1.0248e-01,  9.6647e-02, -9.5574e-02,  ...,  4.8399e-02,
           -2.2146e-03,  6.7467e-02],
          [ 8.2970e-02,  1.5988e-02, -3.6173e-02,  ...,  1.8166e-02,
           -4.2330e-02,  7.9070e-02],
          ...,
          [ 1.8362e-01,  2.5689e-01,  1.5513e-0

 38%|███▊      | 5/13 [01:13<02:01, 15.16s/it]

agg_tokens_tensor: torch.Size([24, 1, 8, 782, 2048]) | agg_tokens: torch.Size([24, 1, 3, 782, 2048])
when i>0
extrinsic: torch.Size([1, 3, 3, 4])
intrinsic: torch.Size([1, 3, 3, 3])
point_map_unproj: (3, 294, 518, 3)
depth_map: torch.Size([1, 3, 294, 518, 1])
depth_conf_map: torch.Size([1, 3, 294, 518])
batch_tensor: torch.Size([1, 3, 3, 294, 518])
batch_names: ['/home/skhalid/Documents/data/tandt_db/tandt/train/images/00001.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00101.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00111.jpg']
5 None None
torch.Size([1, 3, 777, 2048])
tensor([[[[ 1.6527e-01,  9.2079e-02, -7.5282e-02,  ...,  4.4466e-02,
           -5.3713e-02,  5.6851e-02],
          [ 9.2391e-02,  1.1436e-01, -1.1259e-01,  ...,  5.1272e-02,
           -1.0859e-02,  7.1286e-02],
          [ 7.6359e-02,  3.3170e-02, -4.9823e-02,  ...,  2.1850e-02,
           -5.1543e-02,  8.2112e-02],
          ...,
          [ 1.7738e-01,  2.6619e-01,  1.3161e-0

 46%|████▌     | 6/13 [01:30<01:48, 15.55s/it]

agg_tokens_tensor: torch.Size([24, 1, 10, 782, 2048]) | agg_tokens: torch.Size([24, 1, 3, 782, 2048])
when i>0
extrinsic: torch.Size([1, 3, 3, 4])
intrinsic: torch.Size([1, 3, 3, 3])
point_map_unproj: (3, 294, 518, 3)
depth_map: torch.Size([1, 3, 294, 518, 1])
depth_conf_map: torch.Size([1, 3, 294, 518])
batch_tensor: torch.Size([1, 3, 3, 294, 518])
batch_names: ['/home/skhalid/Documents/data/tandt_db/tandt/train/images/00001.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00121.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00131.jpg']
5 None None
torch.Size([1, 3, 777, 2048])
tensor([[[[ 0.1775,  0.0929, -0.0862,  ...,  0.0442, -0.0527,  0.0452],
          [ 0.1057,  0.1120, -0.1250,  ...,  0.0492, -0.0110,  0.0618],
          [ 0.0897,  0.0260, -0.0564,  ...,  0.0199, -0.0505,  0.0743],
          ...,
          [ 0.1849,  0.2589,  0.1272,  ...,  0.0404, -0.1311,  0.0420],
          [ 0.2596, -0.0060,  0.0483,  ...,  0.0598, -0.0532,  0.1094],
       

 54%|█████▍    | 7/13 [01:40<01:23, 13.84s/it]

agg_tokens_tensor: torch.Size([24, 1, 12, 782, 2048]) | agg_tokens: torch.Size([24, 1, 3, 782, 2048])
when i>0
extrinsic: torch.Size([1, 3, 3, 4])
intrinsic: torch.Size([1, 3, 3, 3])
point_map_unproj: (3, 294, 518, 3)
depth_map: torch.Size([1, 3, 294, 518, 1])
depth_conf_map: torch.Size([1, 3, 294, 518])
batch_tensor: torch.Size([1, 3, 3, 294, 518])
batch_names: ['/home/skhalid/Documents/data/tandt_db/tandt/train/images/00001.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00141.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00151.jpg']
5 None None
torch.Size([1, 3, 777, 2048])
tensor([[[[ 0.1785,  0.0889, -0.0751,  ...,  0.0420, -0.0552,  0.0527],
          [ 0.1068,  0.1098, -0.1166,  ...,  0.0472, -0.0134,  0.0683],
          [ 0.0908,  0.0273, -0.0532,  ...,  0.0169, -0.0523,  0.0808],
          ...,
          [ 0.1854,  0.2691,  0.1394,  ...,  0.0398, -0.1318,  0.0487],
          [ 0.2599,  0.0027,  0.0550,  ...,  0.0597, -0.0544,  0.1175],
       

 62%|██████▏   | 8/13 [01:56<01:12, 14.55s/it]

agg_tokens_tensor: torch.Size([24, 1, 14, 782, 2048]) | agg_tokens: torch.Size([24, 1, 3, 782, 2048])
when i>0
extrinsic: torch.Size([1, 3, 3, 4])
intrinsic: torch.Size([1, 3, 3, 3])
point_map_unproj: (3, 294, 518, 3)
depth_map: torch.Size([1, 3, 294, 518, 1])
depth_conf_map: torch.Size([1, 3, 294, 518])
batch_tensor: torch.Size([1, 3, 3, 294, 518])
batch_names: ['/home/skhalid/Documents/data/tandt_db/tandt/train/images/00001.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00161.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00171.jpg']
5 None None
torch.Size([1, 3, 777, 2048])
tensor([[[[ 1.6882e-01,  9.1660e-02, -6.1622e-02,  ...,  4.1994e-02,
           -5.1072e-02,  5.1361e-02],
          [ 9.6989e-02,  1.1218e-01, -1.0345e-01,  ...,  4.8286e-02,
           -8.2342e-03,  6.5631e-02],
          [ 7.8853e-02,  3.0872e-02, -4.3043e-02,  ...,  1.7112e-02,
           -4.9166e-02,  7.8764e-02],
          ...,
          [ 1.8337e-01,  2.6482e-01,  1.4348e-

 69%|██████▉   | 9/13 [02:17<01:06, 16.52s/it]

agg_tokens_tensor: torch.Size([24, 1, 16, 782, 2048]) | agg_tokens: torch.Size([24, 1, 3, 782, 2048])
when i>0
extrinsic: torch.Size([1, 3, 3, 4])
intrinsic: torch.Size([1, 3, 3, 3])
point_map_unproj: (3, 294, 518, 3)
depth_map: torch.Size([1, 3, 294, 518, 1])
depth_conf_map: torch.Size([1, 3, 294, 518])
batch_tensor: torch.Size([1, 3, 3, 294, 518])
batch_names: ['/home/skhalid/Documents/data/tandt_db/tandt/train/images/00001.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00181.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00191.jpg']
5 None None
torch.Size([1, 3, 777, 2048])
tensor([[[[ 1.7345e-01,  6.9685e-02, -7.3563e-02,  ...,  4.3827e-02,
           -4.5267e-02,  4.0690e-02],
          [ 1.0204e-01,  8.8703e-02, -1.1491e-01,  ...,  5.0104e-02,
           -6.0400e-05,  5.5940e-02],
          [ 7.8778e-02,  1.5322e-03, -5.0961e-02,  ...,  2.0923e-02,
           -4.0424e-02,  6.9691e-02],
          ...,
          [ 1.7633e-01,  2.4224e-01,  1.2732e-

 77%|███████▋  | 10/13 [02:33<00:48, 16.32s/it]

when i>0
extrinsic: torch.Size([1, 3, 3, 4])
intrinsic: torch.Size([1, 3, 3, 3])
point_map_unproj: (3, 294, 518, 3)
depth_map: torch.Size([1, 3, 294, 518, 1])
depth_conf_map: torch.Size([1, 3, 294, 518])
batch_tensor: torch.Size([1, 3, 3, 294, 518])
batch_names: ['/home/skhalid/Documents/data/tandt_db/tandt/train/images/00001.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00201.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00211.jpg']
5 None None
torch.Size([1, 3, 777, 2048])
tensor([[[[ 0.1698,  0.0983, -0.0689,  ...,  0.0433, -0.0541,  0.0537],
          [ 0.0989,  0.1181, -0.1085,  ...,  0.0495, -0.0102,  0.0687],
          [ 0.0796,  0.0346, -0.0468,  ...,  0.0195, -0.0506,  0.0813],
          ...,
          [ 0.1774,  0.2808,  0.1459,  ...,  0.0393, -0.1333,  0.0468],
          [ 0.2555,  0.0115,  0.0653,  ...,  0.0586, -0.0575,  0.1155],
          [ 0.1235,  0.0937,  0.0629,  ...,  0.0299, -0.0994,  0.0235]],

         [[ 0.0685,  0.1572, -0.071

 85%|████████▍ | 11/13 [02:54<00:35, 17.94s/it]

when i>0
extrinsic: torch.Size([1, 3, 3, 4])
intrinsic: torch.Size([1, 3, 3, 3])
point_map_unproj: (3, 294, 518, 3)
depth_map: torch.Size([1, 3, 294, 518, 1])
depth_conf_map: torch.Size([1, 3, 294, 518])
batch_tensor: torch.Size([1, 3, 3, 294, 518])
batch_names: ['/home/skhalid/Documents/data/tandt_db/tandt/train/images/00001.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00221.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00231.jpg']
5 None None
torch.Size([1, 3, 777, 2048])
tensor([[[[ 0.1690,  0.0712, -0.0699,  ...,  0.0467, -0.0493,  0.0548],
          [ 0.0965,  0.0931, -0.1107,  ...,  0.0528, -0.0065,  0.0684],
          [ 0.0794,  0.0131, -0.0485,  ...,  0.0228, -0.0465,  0.0780],
          ...,
          [ 0.1787,  0.2517,  0.1302,  ...,  0.0408, -0.1336,  0.0494],
          [ 0.2530, -0.0129,  0.0550,  ...,  0.0603, -0.0587,  0.1161],
          [ 0.1248,  0.0688,  0.0440,  ...,  0.0301, -0.0993,  0.0253]],

         [[ 0.0497, -0.0381, -0.188

 92%|█████████▏| 12/13 [03:08<00:16, 16.75s/it]

when i>0
extrinsic: torch.Size([1, 3, 3, 4])
intrinsic: torch.Size([1, 3, 3, 3])
point_map_unproj: (3, 294, 518, 3)
depth_map: torch.Size([1, 3, 294, 518, 1])
depth_conf_map: torch.Size([1, 3, 294, 518])
batch_tensor: torch.Size([1, 3, 3, 294, 518])
batch_names: ['/home/skhalid/Documents/data/tandt_db/tandt/train/images/00001.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00241.jpg', '/home/skhalid/Documents/data/tandt_db/tandt/train/images/00251.jpg']
5 None None
torch.Size([1, 3, 777, 2048])
tensor([[[[ 0.1695,  0.0860, -0.0827,  ...,  0.0456, -0.0513,  0.0410],
          [ 0.0985,  0.1064, -0.1231,  ...,  0.0510, -0.0082,  0.0549],
          [ 0.0819,  0.0237, -0.0558,  ...,  0.0207, -0.0482,  0.0672],
          ...,
          [ 0.1810,  0.2645,  0.1228,  ...,  0.0428, -0.1336,  0.0349],
          [ 0.2553,  0.0008,  0.0519,  ...,  0.0631, -0.0562,  0.1031],
          [ 0.1288,  0.0831,  0.0411,  ...,  0.0313, -0.0992,  0.0118]],

         [[ 0.0609,  0.0571, -0.070

100%|██████████| 13/13 [03:27<00:00, 15.95s/it]

when i>0
extrinsic: torch.Size([1, 3, 3, 4])
intrinsic: torch.Size([1, 3, 3, 3])
point_map_unproj: (3, 294, 518, 3)
depth_map: torch.Size([1, 3, 294, 518, 1])
depth_conf_map: torch.Size([1, 3, 294, 518])
batch_tensor: torch.Size([1, 3, 3, 294, 518])





In [7]:
all_extrinsics = predictions["all_extrinsics"].cpu().numpy()
all_intrinsics = predictions["all_intrinsics"].cpu().numpy()
all_world_points = predictions["all_world_points"]
depth_maps = predictions["depth_maps"].cpu().numpy()
depth_conf_maps = predictions["depth_conf_maps"].cpu().numpy()
all_images = predictions["all_images"].cpu()
ps_idx_list = predictions["ps_idx_list"]
aggregated_tokens_tensor = predictions["aggregated_tokens_list"]
# ps_idx_list = predictions["ps_idx_list"].cpu().numpy()

print(all_extrinsics.shape, all_intrinsics.shape, all_world_points.shape, depth_maps.shape, depth_conf_maps.shape, all_images.shape, aggregated_tokens_tensor.shape)

(26, 3, 4) (26, 3, 3) (26, 294, 518, 3) (26, 294, 518, 1) (26, 294, 518) torch.Size([26, 3, 294, 518]) torch.Size([24, 1, 26, 782, 2048])


In [8]:
torch.cuda.empty_cache()

In [9]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2

def mask_to_query_points(mask, max_points=100):
    """
    Convert a binary mask to a list of query points (N, 2)
    """
    mask = mask.astype(np.uint8)
    coords = np.column_stack(np.where(mask > 0))  # shape: (N, 2) as (row, col)
    if coords.shape[0] == 0:
        raise ValueError("No foreground pixels found in mask.")
    
    # Subsample if too many points
    if coords.shape[0] > max_points:
        idx = np.random.choice(len(coords), max_points, replace=False)
        coords = coords[idx]

    # Flip from (row, col) to (x, y)
    query_points = torch.FloatTensor(coords[:, [1, 0]])  # (N, 2)
    return query_points

#track_list: torch.Size([4, 1, 8, 50, 2])
def visualize_tracks(all_images, track_list, ps_idx_list, radius=2):
    """
    Visualize tracked points on image sequence.
    
    all_images: (T, 3, H, W) NumPy array
    track_list: (T, 1, N, 2) torch.Tensor
    ps_idx_list: (N, 2) torch.Tensor or NumPy array
    """
    T = all_images.shape[0]
    N = ps_idx_list.shape[0]

    all_images = all_images.detach().cpu().numpy()

    # Normalize images if necessary
    if all_images.max() <= 1.0:
        all_images = (all_images * 255).astype(np.uint8)
    else:
        all_images = all_images.astype(np.uint8)

    # Generate unique colors
    colors = plt.cm.jet(np.linspace(0, 1, N))[:, :3] * 255

    for t in range(T):
        # Convert image to (H, W, 3) and transpose channels
        img = np.transpose(all_images[t], (1, 2, 0)).copy()  # (H, W, 3)

        for i in range(N):
            point = track_list[1, 0, t, i].detach().cpu().numpy()
            if np.any(np.isnan(point)) or np.any(np.isinf(point)):
                continue
            # x, y = int(round(point[0])), int(round(point[1]))
            x, y = int(round(float(point[0]))), int(round(float(point[1])))
            color = tuple(map(int, colors[i]))
            cv2.circle(img, (x, y), radius, color, -1)

        plt.imshow(img)
        plt.title(f"Frame {t}")
        plt.axis("off")
        plt.show()

In [10]:
'''
OBJECT_MASKING
'''
import torch
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from segment_anything import sam_model_registry, SamPredictor
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt

# Load OWL-ViT
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-large-patch14").to(device)
processor = OwlViTProcessor.from_pretrained("google/owlvit-large-patch14")

# Load SAM
sam_checkpoint = "/home/skhalid/Downloads/sam_vit_l.pth"
sam = sam_model_registry["vit_l"](checkpoint=sam_checkpoint).to(device).eval()
predictor = SamPredictor(sam)

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [11]:
# Load image
# image_paths = ["/home/skhalid/Documents/data/tandt_db/tandt/truck/images/000001.jpg"]
# # image_paths = ["/home/skhalid/Documents/data/nerf_synthetic/lego/images/r_0.png"]
# # image_paths = ["/home/skhalid/Documents/data/bicycle/images_4/_DSC8679.JPG"]
# # image_paths = ["/home/skhalid/Documents/data/Synthetic4Relight/hotdog/train/000.png"]
# # image_paths = ["/home/skhalid/Documents/data/data_dtu/DTU_scan24/inputs/images/000000.png"]
# # image_paths = ["/home/skhalid/Documents/data/banana/images/frame_00002.JPG"]
# image_paths = []

object_masks = []

for image_path in image_names:
    image = Image.open(image_path).convert("RGB").resize((518, 350), resample=Image.BILINEAR) # <--- match VGGT
    image_np = np.array(image)

    # Text prompt
    # texts = [["car", "bicycle", "trees", "grass", "ground", "bench", "lego", "fruit"]]  # You can modify this list
    texts = [["truck", "bicycle", "ground", "bench", "tree", "chair", "building", "sky", "clouds", "road", "lego", "grass", "toy", "hotdog", "fruit", "food", "window"]]

    # Prepare input for OWL-ViT
    inputs = processor(text=texts, images=image, return_tensors="pt").to(device)

    # Detect with OWL-ViT
    # print("Begin detect")
    with torch.no_grad():
        outputs = model(**inputs)
    # print("End detect")

    # Get boxes and scores
    target_sizes = torch.tensor([image.size[::-1]]).to(device)
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=SCORE_THRESHOLD)[0]

    # Run SAM
    predictor.set_image(image_np)

    # Process each box from OWL-ViT
    for i, (box, score, label) in enumerate(zip(results["boxes"], results["scores"], results["labels"])):

        # Modify the class as needed
        if texts[0][label] == TARGET:
            box = box.cpu().numpy().astype(int)
            x0, y0, x1, y1 = box
            print(f"{texts[0][label]}: {score:.2f} at box {box}")

            # SAM expects box in XYXY format
            input_box = np.array([x0, y0, x1, y1])
            masks, _, _ = predictor.predict(box=input_box[None, :], multimask_output=False)

            # Overlay mask
            mask = masks[0]
            overlay = image_np.copy()
            overlay[mask] = (255, 0, 0)  # Red mask

            # Convert mask to query points
            query_points = mask_to_query_points(mask, max_points=SAMPLED_POINTS).to(device_vggt)

            # Track
            print("query_points.shape: {}".format(query_points.shape))
            print("all_images.shape: {}".format(all_images.shape))
            print("ps_idx_list: {}".format(ps_idx_list))
            print("aggregated_tokens_tensor: {}".format(aggregated_tokens_tensor.shape))
            track_list, vis_score, conf_score = model_vggt.track_head(
                aggregated_tokens_tensor, 
                all_images[None], 
                ps_idx_list[0], 
                query_points=query_points[None]
            )

            # Draw bounding box
            cv2.rectangle(overlay, (x0, y0), (x1, y1), (0, 255, 0), 2)
            cv2.putText(overlay, f"{texts[0][label]} {score:.2f}", (x0, y0 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

            # Show
            plt.imshow(overlay)
            plt.axis("off")
            plt.show()
            resized_mask = cv2.resize(mask.astype(np.uint8), (depth_conf_maps[0].shape[1], depth_conf_maps[0].shape[0]), interpolation=cv2.INTER_NEAREST)
            resized_mask = resized_mask[None, ...].astype(np.float32)
            
            # Visualize
            track_list = torch.stack(track_list)
            query_points = query_points.detach().cpu().numpy()
            visualize_tracks(all_images, track_list, query_points)

            torch.cuda.empty_cache()

            break

    object_masks.append(resized_mask)
    break

object_masks = np.concatenate((object_masks), axis=0)

NameError: name 'resized_mask' is not defined

In [None]:
# print(mask.view(depth_conf_maps[0].shape[0], depth_conf_maps[0].shape[1]).shape, depth_conf_maps[0].shape)
print(mask.shape, resized_mask.shape, depth_conf_maps[0].shape, object_masks.shape)

In [None]:
N = all_intrinsics.shape[0]
MASKING_MODE = "depth"

for i in range(N):
    # filter out pixels
    if MASKING_MODE == "depth":
        conf_mask = depth_conf_maps[i]
    elif MASKING_MODE == "object":
        conf_mask = object_masks[i]
    else:
        assert MASKING_MODE == "depth" or MASKING_MODE == "object"

    conf_mask /= conf_mask.max()
    conf_mask[conf_mask<0.5] = 0.0
    conf_mask[conf_mask>0.0] = 1.0
    all_world_points[i, :] *= conf_mask[..., None]
    depth_maps[i, :] *= conf_mask[..., None]
    all_images[i, :] *= conf_mask[None, ...].astype(np.uint8)

In [12]:
import numpy as np
import open3d as o3d
import torch

def coords_to_mask(coords_all, shape, conf_score):
    """
    Creates a binary mask from a list of 2D coordinates.

    Args:
        coords: (N, 2) tensor or array of (x, y) or (u, v) pixel coordinates.
        shape: (H, W) shape of the desired mask.

    Returns:
        mask: (H, W) binary mask with 1s at specified coordinates.
    """
    B, H, W, C = shape
    mask = torch.zeros(B, H, W, dtype=torch.bool)
    threshold = 0.20

    for b in range(B):
        coords = coords_all[b].round().long()  # Ensure integer pixel positions
        conf = conf_score[b]
        coords = coords[conf>threshold] 
        x, y = coords[:, 0], coords[:, 1]

        # Filter valid in-bounds coords
        valid = (x >= 0) & (x < W) & (y >= 0) & (y < H)
        x, y = x[valid], y[valid]

        mask[b, y, x] = 1.0

    return mask

def show_point_clouds_batch(point_maps, images, masks=None, show_individually=False):
    """
    Generates and optionally visualizes point clouds from batched point maps and images.

    Args:
        point_maps: (N, H, W, 3)
        images: (N, H, W, 3) or (N, 3, H, W)
        masks: optional (N, H, W) binary masks
        show_individually: if True, shows each scene separately. If False, merges all into one scene.
    """
    if isinstance(point_maps, torch.Tensor):
        point_maps = point_maps.detach().cpu().numpy()
    if isinstance(images, torch.Tensor):
        images = images.detach().cpu().numpy()
        if images.shape[1] == 3:
            images = np.transpose(images, (0, 2, 3, 1))  # (N, H, W, 3)

    N, H, W, _ = point_maps.shape
    pcd_list = []

    for i in range(N):
        points = point_maps[i].reshape(-1, 3)
        colors = images[i].reshape(-1, 3)
        colors = colors / 255.0 if colors.max() > 1 else colors

        if masks is not None:
            mask = masks[i].reshape(-1)
            points = points[mask > 0]
            colors = colors[mask > 0]

        # Filter invalid points
        valid_mask = np.isfinite(points).all(axis=1)
        points = points[valid_mask]
        colors = colors[valid_mask]

        # Create Open3D point cloud
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        pcd.colors = o3d.utility.Vector3dVector(colors)

        if show_individually:
            o3d.visualization.draw_geometries([pcd])
        else:
            pcd_list.append(pcd)

    if not show_individually:
        o3d.visualization.draw_geometries(pcd_list)

def show_meshes_batch(point_maps, images, masks=None, show_individually=False):
    """
    Generates and visualizes meshes from batched point maps and images.

    Args:
        point_maps: (N, H, W, 3)
        images: (N, H, W, 3) or (N, 3, H, W)
        masks: optional (N, H, W) binary masks
        show_individually: if True, shows each mesh separately.
    """
    if isinstance(point_maps, torch.Tensor):
        point_maps = point_maps.detach().cpu().numpy()

    if images.shape[1] == 3:
        images = np.transpose(images, (0, 2, 3, 1))  # (N, H, W, 3)

    N, H, W, _ = point_maps.shape
    mesh_list = []

    for i in range(N):
        xyz = point_maps[i]
        rgb = images[i] / 255.0 if images[i].max() > 1 else images[i]

        print(rgb.shape)

        if masks is not None:
            mask = masks[i]
        else:
            mask = np.isfinite(xyz[..., 2]) & (xyz[..., 2] > 0)

        # Create point cloud
        pcd = o3d.geometry.PointCloud()
        valid_points = xyz[mask]
        valid_colors = rgb[mask]
        pcd.points = o3d.utility.Vector3dVector(valid_points)
        pcd.colors = o3d.utility.Vector3dVector(valid_colors)

        # Estimate normals (required for meshing)
        pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.05, max_nn=30))
        # pcd.orient_normals_consistent_tangent_plane(10)

        # Reconstruct mesh using Poisson (good for dense, smooth surfaces)
        mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=8)

        # Crop mesh to input bounding box to remove floaty artifacts
        bbox = pcd.get_axis_aligned_bounding_box()
        mesh = mesh.crop(bbox)

        # Optional clean-up
        mesh.remove_non_manifold_edges()
        mesh.compute_vertex_normals()

        if show_individually:
            o3d.visualization.draw_geometries([mesh])
        else:
            mesh_list.append(mesh)

    if not show_individually and mesh_list:
        o3d.visualization.draw_geometries(mesh_list)


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [None]:
track_list.shape

In [None]:
print(vis_score.shape)
print(vis_score.min(), vis_score.max())

In [None]:
conf_score.shape

In [None]:
pixel_values = track_list[0, 0, ...].detach().cpu()
print(pixel_values.shape)

In [None]:
print(all_world_points.shape)

In [None]:
all_images.shape

In [None]:
masks = coords_to_mask(pixel_values, all_world_points.shape, conf_score[0].detach().cpu())

In [None]:
masks.shape
print(masks.min(), masks.max())

In [None]:
plt.figure()
plt.imshow(masks[3])

In [None]:
show_point_clouds_batch(all_world_points, all_images, masks)

In [None]:
indx = 4
print(conf_score[0, indx].min(), conf_score[0, indx].max())

In [None]:
# show_meshes_batch(all_world_points, all_images, masks)

In [None]:
model_vggt