In [1]:
from main import MAST3RGaussians
import sys
from pathlib import Path
import torch
import torchvision.transforms as tfm
from typing import Tuple, Union, Optional
from natsort import natsorted
import os
from huggingface_hub import hf_hub_download
sys.path.append('src/mast3r_src')
sys.path.append('src/mast3r_src/dust3r')
sys.path.append('src/pixelsplat_src')

from dust3r.utils.image import load_images
from mast3r.model import AsymmetricMASt3R
from mast3r.utils.misc import hash_md5


In [None]:
# Working resolution passed to the image loaders below (height == width == 512 px).
H = W = 512

def preprocess_image(img):
    """
    Normalise a 3-channel (C, H, W) image tensor and prepend a batch axis.

    Every channel is normalised with mean 0.5 and std 0.5 (e.g. values in
    [0, 1] end up in [-1, 1]).

    Returns:
        (batched_img, (height, width)) where the size is read before any
        processing.
    """
    _, height, width = img.shape
    normalised = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
    return normalised[None], (height, width)

def load_single_image(path: Union[str, Path], resize: Optional[Union[int, Tuple]] = None, rot_angle: float = 0) -> torch.Tensor:
    """
    Load one image as a tensor, optionally resizing and rotating it.

    Args:
        path: File path (``str`` or ``pathlib.Path``) read with PIL and converted to RGB,
            or an already-loaded image that ``torchvision.transforms.ToTensor`` accepts.
        resize: Target size. An ``int`` is expanded to ``(resize, resize)``; a tuple is
            passed to ``torchvision.transforms.Resize`` unchanged; ``None`` keeps the size.
        rot_angle: Rotation angle passed to ``torchvision.transforms.functional.rotate``.

    Returns:
        The image tensor produced by ``ToTensor`` after the optional resize/rotation.
    """
    # PIL was used here but never imported at the top of the notebook.
    from PIL import Image

    if isinstance(resize, int):
        resize = (resize, resize)
    # Accept pathlib.Path as well as str: previously a Path fell through to the
    # "already loaded" branch and ToTensor was handed a Path object.
    if isinstance(path, (str, Path)):
        # Context manager closes the underlying file once the RGB copy is made.
        with Image.open(path) as pil_img:
            img = pil_img.convert("RGB")
    else:
        img = path
    img = tfm.ToTensor()(img)
    if resize is not None:
        img = tfm.Resize(resize, antialias=True)(img)
    img = tfm.functional.rotate(img, rot_angle)
    return img

# def load_images(img0_path, img1_path):
#     """
#     Loads and calls pre-processing to get the images ready for mast3r inference
#     """
#     img0 = load_single_image(img0_path, (H, W))
#     img1 = load_single_image(img1_path, (H, W))

#     img0, img0_orig_shape = preprocess_image(img0)
#     img1, img1_orig_shape = preprocess_image(img1)

#     img_pair = [
#         {"img": img0, "idx": 0, "instance": 0, "true_shape": torch.tensor(img0.shape[-2:], dtype=torch.int32)},
#         {"img": img1, "idx": 1, "instance": 1, "true_shape": torch.tensor(img1.shape[-2:], dtype=torch.int32)},
#     ]

#     return img_pair

In [None]:
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Input/output locations and how many images the two-view experiments use.
image_dir = Path("images")
output_dir = Path("pointclouds")
top = 2

# Keep only image files before slicing, so stray entries (e.g. `.DS_Store`,
# `.ipynb_checkpoints`) cannot take one of the `top` slots.
# NOTE(review): extend this set if other formats are used.
IMAGE_SUFFIXES = {'.jpg', '.jpeg', '.png'}
image_list = natsorted(
    name for name in os.listdir(image_dir) if Path(name).suffix.lower() in IMAGE_SUFFIXES
)[:top]
image_list = [f'{image_dir}/{imgName}' for imgName in image_list]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Put Splatt3r on a second GPU when there is one; otherwise share `device`
# (a hard-coded 'cuda:1' crashes on single-GPU or CPU-only machines).
splatt3r_device = torch.device('cuda:1') if torch.cuda.device_count() > 1 else device
image_list, device

In [None]:
# Fetch the Splatt3r checkpoint from the Hugging Face Hub (returns a local file path).
model_name = "brandonsmart/splatt3r_v1.0"
filename = "epoch=19-step=1200.ckpt"
weights_path = hf_hub_download(repo_id=model_name, filename=filename)

# Restore the weights directly onto the Splatt3r device, then pin the module there.
splatt3r_model = MAST3RGaussians.load_from_checkpoint(
    weights_path,
    map_location=splatt3r_device,
)
splatt3r_model.to(splatt3r_device)
print(f"Successfully loaded Splatt3r model onto {splatt3r_device}")

In [None]:
# Load the MASt3R baseline from its local checkpoint onto the primary device.
mast3r_model = AsymmetricMASt3R.from_pretrained(
    Path("checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth")
).to(device)
# Message spelled and formatted like the Splatt3r one above.
print(f"Successfully loaded Mast3r model onto {device}")

In [None]:
# Sanity check: which CUDA devices can this kernel see?
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Number of GPUs: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count() if torch.cuda.is_available() else 0):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

In [None]:
silent = False
# Two independent loads of the same files: one copy per model, because each
# copy is moved to its own device in the next cells.
imgs, splatt3r_imgs = (load_images(image_list, size=H, verbose=not silent) for _ in range(2))

In [None]:
# Fields that dust3r's load_images attached to each loaded view.
imgs[1].keys()

In [None]:
def _move_views_to_device(views, target_device):
    """Move each view's 'img' and 'original_img' to `target_device`, in place.

    'true_shape' is converted to a torch tensor but stays on the CPU, as before.
    `torch.as_tensor` replaces `torch.from_numpy` so the cell can be re-run:
    on a second run 'true_shape' is already a tensor and from_numpy would raise.
    """
    for view in views:
        view['img'] = view['img'].to(target_device)
        view['original_img'] = view['original_img'].to(target_device)
        view['true_shape'] = torch.as_tensor(view['true_shape'])


_move_views_to_device(imgs, device)
_move_views_to_device(splatt3r_imgs, splatt3r_device)
    

### Mast3r encoder output

In [None]:
# Direct (non-symmetrized) alternative to the call below, kept for reference:
# m_feat1, m_feat2, m_pos1, m_pos2 = mast3r_model._encode_image_pairs(imgs[0]['img'], imgs[1]['img'], imgs[0]['true_shape'], imgs[1]['true_shape'])
with torch.inference_mode():
    # Encode both views with the MASt3R encoder; outputs are grouped as
    # (shapes, features, positions), one entry per view.
    (m_shape1, m_shape2), (m_feat1, m_feat2), (m_pos1, m_pos2) = mast3r_model._encode_symmetrized(imgs[0], imgs[1])

In [None]:
# Encoder feature and position tensor shapes for both views.
m_feat1.shape, m_feat2.shape, m_pos1.shape, m_pos2.shape

In [None]:
# Release cached, unused GPU memory before running the second model.
torch.cuda.empty_cache()

### Splatt3r encoder output

In [None]:
with torch.inference_mode():
    # Same symmetrized encoding, but through the encoder wrapped inside the Splatt3r model.
    (shape1, shape2), (feat1, feat2), (pos1, pos2) = splatt3r_model.encoder._encode_symmetrized(splatt3r_imgs[0], splatt3r_imgs[1])

In [None]:
# Compare with the MASt3R encoder outputs above.
shape1, shape2, feat1.shape, feat2.shape, pos1.shape, pos2.shape

### Mast3r decoder output

In [None]:
with torch.inference_mode():
    # Run the MASt3R decoder on both views' encoder features and positions.
    m_dec1, m_dec2 = mast3r_model._decoder(m_feat1, m_pos1, m_feat2, m_pos2)

In [None]:
# Number of decoder outputs for view 1 (presumably one per decoder layer) and the first one's shape.
len(m_dec1), m_dec1[0].shape

### Splatt3r decoder output

In [None]:
with torch.inference_mode():
    # Same decoding step with the Splatt3r-wrapped encoder/decoder.
    dec1, dec2 = splatt3r_model.encoder._decoder(feat1, pos1, feat2, pos2)

In [None]:
# Compare with the MASt3R decoder outputs above.
len(dec1), dec1[0].shape

## Get predictions

### Splatt3r

In [None]:
# Run the heads under inference_mode too. The decoder tokens were produced inside
# inference_mode, and autograd-tracked ops cannot save inference tensors for backward.
# If the head's parameters require grad, calling it outside this context can raise.
# It also avoids building an autograd graph we never use.
with torch.inference_mode():
    pred_1 = splatt3r_model.encoder._downstream_head(1, [tok.float() for tok in dec1], shape1)
    pred_2 = splatt3r_model.encoder._downstream_head(2, [tok.float() for tok in dec2], shape2)

In [None]:
# Output fields of the Splatt3r head for view 1.
pred_1.keys()

In [None]:
# ...and for view 2.
pred_2.keys()

In [None]:
# Shapes of every Splatt3r output for view 1 (point map, confidences, Gaussian parameters).
(
    pred_1['pts3d'].shape,
    pred_1['conf'].shape,
    pred_1['desc_conf'].shape,
    pred_1['scales'].shape,
    pred_1['rotations'].shape,
    pred_1['sh'].shape,
    pred_1['opacities'].shape,
    pred_1['means'].shape,
)

### Mast3r

In [None]:
# Run the heads under inference_mode too. The decoder tokens were produced inside
# inference_mode, and autograd-tracked ops cannot save inference tensors for backward.
# If the head's parameters require grad, calling it outside this context can raise.
with torch.inference_mode():
    m_pred_1 = mast3r_model._downstream_head(1, [tok.float() for tok in m_dec1], m_shape1)
    m_pred_2 = mast3r_model._downstream_head(2, [tok.float() for tok in m_dec2], m_shape2)

In [None]:
# Output fields of the MASt3R head for view 1.
m_pred_1.keys()

In [None]:
# Shapes of the outputs MASt3R shares with the Splatt3r head (pts3d, conf, desc_conf).
m_pred_1['pts3d'].shape, m_pred_1['conf'].shape, m_pred_1['desc_conf'].shape

### SparseGA exploration

In [2]:
from model_replacement_test import MASt3R

# Multi-view pipeline wrapper; presumably reads images from `imgdir` and writes results to `outdir`.
gen3d = MASt3R(imgdir=Path("./images"), outdir=Path("./pointclouds"), sample_images=False)

# run multi-view Mast3R SfM
# NOTE(review): cache_dir is an absolute scratch path here, while later cells load Gaussian
# attributes from the relative "mast3r_cache" directory — confirm both refer to the same cache.
scene = gen3d.reconstruct_scene(outdir=str(gen3d.outdir),
                                cache_dir="/scratch/mast3r_cache",
                                scene_graph="complete",
                                optim_level="refine+depth",
                                lr1=0.07, niter1=300,
                                lr2=0.01, niter2=300)

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [on]




Loading model from: /home2/aditya.vadali/miniconda3/envs/splatt3r/lib/python3.11/site-packages/lpips/weights/v0.1/alex.pth
Successfully loaded Splatt3r model
>> Loading a list of 4 images
 - adding images/1.jpg with resolution 1280x963 --> 512x512
 - adding images/2.jpg with resolution 1280x963 --> 512x512
 - adding images/3.jpg with resolution 1280x963 --> 512x512
 - adding images/4.jpg with resolution 1280x963 --> 512x512
 (Found 4 images)
----Using scene graph method complete---
12 pairs constructed


100%|██████████| 12/12 [00:00<00:00, 15665.00it/s]
 25%|██▌       | 1/4 [00:01<00:04,  1.49s/it]

[INFO] Processed canonical data for image: images/1.jpg and avg_gaussians keys: ['sh', 'scales', 'rotations', 'opacities', 'offsets', 'means']


 50%|█████     | 2/4 [00:02<00:02,  1.50s/it]

[INFO] Processed canonical data for image: images/2.jpg and avg_gaussians keys: ['sh', 'scales', 'rotations', 'opacities', 'offsets', 'means']


 75%|███████▌  | 3/4 [00:04<00:01,  1.46s/it]

[INFO] Processed canonical data for image: images/3.jpg and avg_gaussians keys: ['sh', 'scales', 'rotations', 'opacities', 'offsets', 'means']


100%|██████████| 4/4 [00:05<00:00,  1.47s/it]


[INFO] Processed canonical data for image: images/4.jpg and avg_gaussians keys: ['sh', 'scales', 'rotations', 'opacities', 'offsets', 'means']
[CANON_KEYS]['images/1.jpg', 'images/2.jpg', 'images/3.jpg', 'images/4.jpg']
init focals = [463.9657  443.6589  456.11243 459.68405]


100%|██████████| 300/300 [00:11<00:00, 26.43it/s, lr=0.0000, loss=0.145]


>> final loss = 0.14474214613437653


100%|██████████| 300/300 [00:15<00:00, 19.82it/s, lr=0.0000, loss=0.588]


>> final loss = 0.5878666043281555
Final focals = [463.0467  463.86038 467.52655 466.56027]


In [None]:
# Per-image canonical data paths recorded during reconstruction.
scene.canonical_paths

In [None]:
scene.img_paths

In [None]:
# Intrinsics of the first camera, and how many there are.
scene.intrinsics[0], len(scene.intrinsics)

In [None]:
# Camera-to-world pose of the first camera, and how many there are.
scene.cam2w[0], len(scene.cam2w)

In [None]:
scene.depthmaps[0]

In [None]:
scene.pts3d[0], len(scene.pts3d), scene.pts3d[0].shape # TODO: explain the per-image point count (6430 here). Unverified guess: 64*64 subsampled points + MASt3R correspondences.

In [None]:
scene.pts3d_colors[0], scene.pts3d_colors[0].shape 

In [None]:
scene.imgs[0][30, 40] # scene.imgs[img_id][px_y, px_x] gives the color of the pixel of that particular image.

In [None]:
# Anchors of image 0: pixel coordinates, nearest-anchor indices and depth offsets.
pixels, idxs, offsets = scene.anchors[0]

In [None]:
pixels.shape, idxs.shape, offsets.shape

In [None]:
idxs

In [None]:
# Integer (x, y) anchor pixel coordinates as a CPU numpy array.
px_coords = pixels[:, :2].detach().cpu().numpy().astype(int)
px_coords

In [None]:
# Dense per-pixel 3D points and confidences of every image, converted to numpy.
# (The import line was corrupted with a pasted shape, "to_numpy62144, 3", which is a syntax error.)
from dust3r.utils.device import to_numpy
pts3d, _, confs = to_numpy(scene.get_dense_pts3d(clean_depth=True))

In [18]:
import numpy as np
# Median of image 0's dense confidence values.
np.median(confs[0])

1.3804955

### Anchors

`scene.anchors` is a dictionary of the form:

```python
{'img_index': [pixels, idxs, offsets]}
```

- `pixels` is a list of homogeneous pixel coordinates: `[[px_x, px_y, 1],....]`

- `idxs` are the indices of the closest anchor point for each pixel (an anchor point is its own closest anchor).

- `offsets` describe how much the depth of each pixel differs from the depth of its closest anchor point.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Example: coords is your (N, 2) numpy array of [x, y] pixel positions
# coords = np.array([...])

def visualize_pixel_coords(coords, width=512, height=512):
    """
    Scatter-plot 2D pixel positions in image coordinates.

    Args:
        coords: (N, 2) array of [x, y] pixel positions.
        width: Image width used for the x-axis limits (default 512).
        height: Image height used for the y-axis limits (default 512).

    The y axis grows downwards (origin at the top-left), matching image coordinates.
    """
    xs = coords[:, 0]
    ys = coords[:, 1]

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(xs, ys, s=1)  # s = point size
    ax.set_xlim(0, width)
    # Give the y limits top-first. The previous invert_yaxis() followed by ylim(0, 512)
    # restored the default orientation, so the plot was not actually flipped.
    ax.set_ylim(height, 0)
    ax.set_title("Anchor Points Visualization")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.grid(True)
    plt.show()

# Integer (x, y) anchor pixel coordinates of image 3, brought to the CPU as numpy.
pixels = (
    scene.anchors[3][0][:, :2]
    .detach()
    .cpu()
    .numpy()
    .astype(int)
)
visualize_pixel_coords(pixels)


In [None]:
# Shapes of the point map and Gaussian parameters predicted by Splatt3r for view 1.
pred_1['pts3d'].shape, pred_1['scales'].shape, pred_1['rotations'].shape, pred_1['sh'].shape, pred_1['opacities'].shape, pred_1['means'].shape

In [None]:
def get_image_hashes_list(scene):
    """
    Return `hash_md5` of every entry in `scene.img_paths`, in order.

    Prints one "Image <index>: <path> -> <hash>" line per image as a side effect.
    """
    hashes = []
    for idx, path in enumerate(scene.img_paths):
        digest = hash_md5(path)
        print(f"Image {idx}: {path} -> {digest}")
        hashes.append(digest)
    return hashes

def load_gaussian(scene, cache_dir, index):
    """
    Load the cached Gaussian attributes of image `index` in `scene`.

    The cache file is ``{cache_dir}/{hash_md5(scene.img_paths[index])}.pth``. Only the
    requested image is hashed; previously every image was hashed and printed first.

    Returns:
        Whatever ``torch.load`` returns for that file. This notebook unpacks it as
        ``(sh, scales, rotations, opacities, means)``.
    """
    img_hash = hash_md5(scene.img_paths[index])
    # NOTE: torch.load unpickles the file; only use it on caches this pipeline wrote.
    return torch.load(f"{cache_dir}/{img_hash}.pth")

# /home2/aditya.vadali/splatt3r-MR-Project/mast3r_cache/gaussian_attributes/0ad929c4b83d461d351cfe97d8cb7558.pth
# gaussians = torch.load('mast3r_cache/gaussian_attributes/0ad929c4b83d461d351cfe97d8cb7558.pth')
# (sh, scales, rotations, opacities, means) = gaussians

In [None]:
# Cached Gaussian attributes of image 0; these globals are used by the anchor mapping below.
(sh, scales, rotations, opacities, means) = load_gaussian(scene, 'mast3r_cache/gaussian_attributes', 0)
sh.squeeze().shape

In [None]:
# First predicted Gaussian mean of image 0.
# The comparison with `mapping[(0,0,0)]` was dropped here: `mapping` is only built further
# down, so this cell failed on a fresh kernel (the lookup is shown at the end of the notebook).
means.squeeze()[0, 0]

### Create (x,y,z) to {gaussians} mapping

In [None]:
pixels0, pixels1 = scene.anchors[0][0], scene.anchors[1][0]
pts3d_img0 = scene.pts3d[0]    # 3D points for image 0

# Squeeze the per-pixel attribute maps once, instead of once per anchor and attribute.
_image0_attr_maps = {
    'sh': sh.squeeze(),
    'scales': scales.squeeze(),
    'rotations': rotations.squeeze(),
    'opacities': opacities.squeeze(),
    'means': means.squeeze(),
}
pt_to_gaussian_map = {}

for i, [x, y, s] in enumerate(pixels0):
    px, py = int(x), int(y)
    # NOTE(review): anchor i and scene.pts3d[0][i] are assumed to describe the same point;
    # confirm the two sequences are aligned (their lengths may differ).
    point_3d = pts3d_img0[i]

    # Gaussian parameters at pixel (x, y); attribute maps are indexed [row=y, col=x].
    pt_to_gaussian_map[i] = {  # Use index i as key, or convert point_3d to tuple
        'pixel': (px, py),
        'point_3d': point_3d,
        **{name: attr_map[py, px] for name, attr_map in _image0_attr_maps.items()},
    }

In [None]:
from mast3r.utils.misc import hash_md5

# pixels = []
# for i in range(len(scene.anchors)):
#     pixels.append(scene.anchors[i][0])

# Get MD5 hashes for each image in the scene; the Gaussian cache files below are named by them.
# get_image_hashes_list prints one "Image i: path -> hash" line per image.
image_hashes = [
    (i, img_path, img_hash)
    for i, (img_path, img_hash) in enumerate(zip(scene.img_paths, get_image_hashes_list(scene)))
]

# Access Gaussian attributes for each image
cache_dir = "mast3r_cache"
GAUSSIAN_KEYS = ('sh', 'scales', 'rotations', 'opacities', 'means')
gaussian_attributes = {}

for img_idx, img_path, img_hash in image_hashes:
    gaussians_path = f"{cache_dir}/gaussian_attributes/{img_hash}.pth"

    try:
        # NOTE: torch.load unpickles the file; only point it at caches this pipeline produced.
        gaussians = torch.load(gaussians_path)
    except FileNotFoundError:
        print(f"Gaussian attributes not found for image {img_idx} (hash: {img_hash})")
        continue

    # Build the entry straight from the loaded tuple. Unpacking into the module-level names
    # sh/scales/rotations/opacities/means silently replaced the image-0 tensors loaded
    # earlier with the last image's, and a later cell reads those globals.
    # strict=True keeps the old error when the tuple does not have exactly five entries.
    gaussian_attributes[img_idx] = {
        'image_path': img_path,
        'hash': img_hash,
        **dict(zip(GAUSSIAN_KEYS, gaussians, strict=True)),
    }

    print(f"Loaded Gaussians for image {img_idx}: {gaussian_attributes[img_idx]['sh'].shape}")

In [None]:
# Shape of image 0's cached SH coefficients.
gaussian_attributes[0]['sh'].shape

In [None]:
import utils.geometry as geometry

# colors = scene.pts3d_colors
# NOTE(review): `scales`/`rotations` here are whatever module-level tensors earlier cells left
# behind, and `covariances` is recomputed per image in the mapping loop further down — this
# value looks unused; confirm before relying on it.
covariances = geometry.build_covariance(scales, rotations)

# Filled further down: 3D point (as a tuple) -> Gaussian attributes of the pixel it came from.
coords_to_gaussians_map = {}

def map_dense_pts3d_to_pixels_with_colors(scene):
    """
    Pair every dense 3D point of every image with its pixel position and color.

    Args:
        scene: Object exposing ``get_dense_pts3d(clean_depth=...)``, whose first return value
            holds one flat, row-major point sequence per image, and ``imgs``, per-image arrays
            indexed as ``img[y, x]``.

    Returns:
        dict mapping ``(img_idx, x, y)`` -> ``{'pt_3d': point, 'color': img[y, x]}``, in
        row-major pixel order. Pixels past the end of an image's point sequence are skipped.
    """
    pts3d_dense, _, _ = scene.get_dense_pts3d(clean_depth=True)

    mapping = {}
    for img_idx, pts3d_img in enumerate(pts3d_dense):
        img = scene.imgs[img_idx]
        # Read the size from the image itself. It was hard-coded to 512x512, which indexes out
        # of bounds for non-square inputs and shadowed the notebook-level H, W.
        H, W = img.shape[:2]
        # Only the first H*W flat indices can map to a pixel; bounding the range replaces the
        # per-pixel length check.
        for linear_idx in range(min(H * W, len(pts3d_img))):
            y, x = divmod(linear_idx, W)
            mapping[(img_idx, x, y)] = {
                'pt_3d': pts3d_img[linear_idx],
                'color': img[y, x],
            }

    return mapping

# Dense (img_idx, x, y) -> {pt_3d, color} lookup for every image in the scene.
mapping = map_dense_pts3d_to_pixels_with_colors(scene)


In [None]:
mapping[(0,32,32)]

In [None]:
# Attach the per-pixel Gaussian attributes to every dense 3D point, keyed by the point's coordinates.
# Per-image tensors (squeezed attribute maps and covariances) are prepared once per image. The
# covariance build used to run again for every pixel, and each entry then kept its own
# full-size covariance tensor alive.
# NOTE(review): keying by exact float coordinates means points with identical coordinates
# overwrite each other.
_gaussian_maps_per_image = {}
for (img_index, x, y), mapping_dict in mapping.items():
    if img_index not in _gaussian_maps_per_image:
        attrs = gaussian_attributes[img_index]
        covariances = geometry.build_covariance(attrs['scales'], attrs['rotations'])
        per_image = {
            name: attrs[name].squeeze()
            for name in ('sh', 'scales', 'rotations', 'opacities', 'means')
        }
        per_image['covariances'] = covariances.squeeze()
        _gaussian_maps_per_image[img_index] = per_image

    attr_maps = _gaussian_maps_per_image[img_index]
    px, py = int(x), int(y)
    coords_to_gaussians_map[tuple(mapping_dict['pt_3d'].detach().cpu().numpy())] = {
        'pixel': (px, py),
        'image_index': img_index,
        'color': mapping_dict['color'],
        'sh': attr_maps['sh'][py, px],
        'scales': attr_maps['scales'][py, px],
        'rotations': attr_maps['rotations'][py, px],
        'opacities': attr_maps['opacities'][py, px],
        'means': attr_maps['means'][py, px],
        'covariances': attr_maps['covariances'][py, px],
    }

In [None]:
# Peek at one entry: (3D point tuple, Gaussian attributes dict).
list(coords_to_gaussians_map.items())[0]

In [None]:
import trimesh
from src.mast3r_src.dust3r.dust3r.viz import OPENGL, pts3d_to_trimesh, cat_meshes
from plyfile import PlyData, PlyElement
from scipy.spatial.transform import Rotation
import einops

def save_gaussians_as_ply(coords_to_gaussians_map, save_path):
    """
    Save Gaussians as a PLY file in the 3D-Gaussian-Splatting attribute layout.

    Args:
        coords_to_gaussians_map: dict mapping a 3D point tuple ``(x, y, z)`` to a dict with at
            least ``'covariances'`` (3x3 tensor), ``'color'`` (RGB; tensor or array-like) and
            ``'opacities'``. The key itself is written as the Gaussian position.
        save_path: Output ``.ply`` path; its parent directory is created if missing.

    One vertex is written per Gaussian with these fields:
        - x/y/z: position.
        - nx/ny/nz: zeros.
        - f_dc_0..2: DC spherical-harmonic term derived from the RGB color.
        - opacity: stored as-is.
        - scale_0..2: log of the per-axis standard deviation.
        - rot_0..3: unit quaternion in (w, x, y, z) order.
    Prints an error and writes nothing when the map is empty.
    """

    def construct_list_of_attributes(num_rest: int) -> list[str]:
        '''Construct a list of attributes for the PLY file format'''
        attributes = ["x", "y", "z", "nx", "ny", "nz"]
        # Use spherical harmonics for color (first 3 coefficients = DC terms = RGB)
        for i in range(3):
            attributes.append(f"f_dc_{i}")
        for i in range(num_rest):
            attributes.append(f"f_rest_{i}")
        attributes.append("opacity")
        for i in range(3):
            attributes.append(f"scale_{i}")
        for i in range(4):
            attributes.append(f"rot_{i}")
        # No explicit RGB fields - color comes from f_dc_0, f_dc_1, f_dc_2
        return attributes

    def covariance_to_quaternion_and_scale(covariances):
        '''Decompose (N, 3, 3) symmetric covariances into (w, x, y, z) quaternions and scales.

        For C = R diag(s**2) R^T, the eigenvectors of C give R and the square roots of its
        eigenvalues give s (column i of R pairs with s[i]). The previous SVD version built the
        rotation as U @ Vh^T, which for a symmetric matrix is U @ U rather than U.
        '''
        eigvals, eigvecs = torch.linalg.eigh(covariances.detach().double())
        # Round-off can yield tiny negative eigenvalues; clamp so the square root stays real.
        scale = torch.sqrt(eigvals.clamp_min(0)).float().cpu().numpy()
        # Eigenvectors are defined only up to sign. Negating a 3x3 matrix flips the sign of its
        # determinant and leaves R diag(.) R^T unchanged, so this turns reflections into rotations.
        det = torch.linalg.det(eigvecs)
        sign = 1.0 - 2.0 * (det < 0).to(eigvecs.dtype)
        eigvecs = eigvecs * sign[..., None, None]
        quat_xyzw = Rotation.from_matrix(eigvecs.cpu().numpy()).as_quat()  # scipy: scalar-last
        quaternion = quat_xyzw[:, [3, 0, 1, 2]]  # rot_0..3 are scalar-first (w, x, y, z)
        return quaternion, scale

    def rgb_to_sh0(rgb):
        """Convert RGB color to spherical harmonic DC coefficient"""
        # SH DC coefficient is C0 = 1/(2*sqrt(pi))
        C0 = 0.28209479177387814
        return (rgb - 0.5) / C0

    # Collect the Gaussian parameters from the map
    means_list = []
    covariances_list = []
    harmonics_list = []
    opacities_list = []

    for pt_3d_tuple, gaussian_data in coords_to_gaussians_map.items():
        means_list.append(list(pt_3d_tuple))
        covariances_list.append(gaussian_data['covariances'].detach().cpu())

        # Convert the scene's RGB color to the SH DC coefficients (the predicted 'sh' is not used).
        rgb_color = gaussian_data['color']
        if torch.is_tensor(rgb_color):
            rgb_color = rgb_color.detach().cpu()
        else:
            rgb_color = torch.tensor(rgb_color)
        harmonics_list.append(rgb_to_sh0(rgb_color))

        # NOTE(review): opacity is written as stored in the map; 3DGS-style viewers usually
        # expect pre-sigmoid (logit) values here — confirm which the map holds.
        opacities_list.append(gaussian_data['opacities'].detach().cpu())

    if len(means_list) == 0:
        print("ERROR: No Gaussians to save!")
        return

    # Convert to numpy arrays
    means = np.array(means_list)
    covariances = torch.stack(covariances_list, dim=0)
    harmonics = np.array(harmonics_list)
    opacities = np.array(opacities_list).reshape(-1, 1)

    print(f"Processing {len(means)} Gaussians")
    print(f"SH DC coefficients (first 3 = RGB): {harmonics[0][:3]}")

    # Convert covariances to quaternions and scales
    rotations, scales = covariance_to_quaternion_and_scale(covariances)

    # Construct the attributes
    rest = np.zeros_like(means)  # Normals
    float_attrs = np.concatenate((means, rest, harmonics, opacities, np.log(scales), rotations), axis=-1)

    # Create dtype - ONLY float fields (no explicit RGB)
    float_names = construct_list_of_attributes(0)  # num_rest=0 for now
    dtype_full = [(name, "f4") for name in float_names]

    elements = np.empty(float_attrs.shape[0], dtype=dtype_full)
    # Fill all rows in one assignment (one tuple per Gaussian) instead of element by element.
    elements[:] = list(map(tuple, float_attrs))

    # Save the point cloud
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    point_cloud = PlyElement.describe(elements, "vertex")
    ply_data = PlyData([point_cloud])
    ply_data.write(save_path)

    print(f"Saved {len(elements)} Gaussians to {save_path}")
    print("Colors stored in spherical harmonics f_dc_0, f_dc_1, f_dc_2")

# Usage:
# Usage: write the merged Gaussians into the configured output directory (output_dir = pointclouds).
# str() because the PLY writer is given a plain path string, as before.
save_gaussians_as_ply(coords_to_gaussians_map, str(output_dir / "final_maybe_clean_depth.ply"))

In [None]:
from src.mast3r_src.demo import get_3D_model_from_scene

# Export the optimized scene with the MASt3R demo helper (point-cloud mode, depth cleaning on).
model = get_3D_model_from_scene(
    outdir="pointclouds",
    silent=False,
    scene=scene,
    min_conf_thr=1.5,
    as_pointcloud=True,
    clean_depth=True,
    mask_sky=False,
    transparent_cams=False,
    cam_size=0.2,
    TSDF_thresh=0.0,
)

In [None]:
# Per image: number of point colors vs. shape of the point array.
[(len(col), pt.shape) for pt, col in zip(scene.pts3d, scene.pts3d_colors)]

In [None]:

# Print the first color entry of the first image, then stop.
for pt, col in zip(scene.pts3d, scene.pts3d_colors):
    for c in col:
        print(c)
        break
    break    

In [None]:
# Value range of the SH coefficients predicted by Splatt3r for view 1.
{'max':pred_1['sh'].max(),
 'min': pred_1['sh'].min()}

In [None]:
# Per-channel (R, G, B) max/min over every point color of image 0.
# Previously the ranges used len(scene.pts3d_colors[0][0]), the length of a single color,
# so only the first few points were inspected; iterate over all points instead.
r_max, r_min = max(c[0] for c in scene.pts3d_colors[0]), min(c[0] for c in scene.pts3d_colors[0])
g_max, g_min = max(c[1] for c in scene.pts3d_colors[0]), min(c[1] for c in scene.pts3d_colors[0])
b_max, b_min = max(c[2] for c in scene.pts3d_colors[0]), min(c[2] for c in scene.pts3d_colors[0])
{'max': (r_max, g_max, b_max),
 'min': (r_min, g_min, b_min)}

In [None]:
# Shape of image 0's point array in scene.pts3d (compare with the dense points below).
scene.pts3d[0].shape

In [None]:
dense = scene.get_dense_pts3d(clean_depth=True)

In [None]:
# Dense points of image 0 (first element of the returned tuple).
dense[0][0].shape

In [None]:
# Peek at the first mapping entry.
for (img_index, x, y), pt in mapping.items():
    print(img_index, x, y, pt)
    break

In [None]:
mapping[(0,0,0)]