In [None]:
from main import MAST3RGaussians
import sys
from pathlib import Path
import torch
import torchvision.transforms as tfm
from typing import Tuple, Union, Optional
from natsort import natsorted
import os
from huggingface_hub import hf_hub_download
import numpy as np

sys.path.append('src/mast3r_src')
sys.path.append('src/mast3r_src/dust3r')
sys.path.append('src/pixelsplat_src')

from dust3r.utils.image import load_images
from mast3r.model import AsymmetricMASt3R
from mast3r.utils.misc import hash_md5

In [None]:
from model_replacement_test import MASt3R

# Build the MASt3R wrapper; imgdir/outdir point at ./images and ./pointclouds.
gen3d = MASt3R(imgdir=Path("./images"), outdir=Path("./pointclouds"))

# run multi-view Mast3R SfM
# NOTE(review): cache_dir="mast3r_cache" is repeated as a string literal in the export
# cells further down (load_gaussian reads "<cache_dir>/gaussian_attributes/..."); keep them in sync.
# scene_graph "swin-3-noncyclic" presumably pairs each image with a sliding window of
# neighbours without wrap-around, and lr1/niter1 vs lr2/niter2 presumably configure two
# successive optimisation stages — confirm against MASt3R's reconstruct_scene.
scene = gen3d.reconstruct_scene(outdir=str(gen3d.outdir),
                                cache_dir="mast3r_cache",
                                scene_graph="swin-3-noncyclic",
                                optim_level="refine+depth",
                                lr1=0.07, niter1=300,
                                lr2=0.01, niter2=300)

In [None]:
# NOTE(review): not referenced by any visible cell; confirm it is still needed before removing.
coords_to_gaussians_map = {}

def map_dense_pts3d_to_pixels_with_colors(scene):
    """Pair every dense 3-D point of ``scene`` with the pixel (and colour) it belongs to.

    The first value returned by ``scene.get_dense_pts3d(clean_depth=True)`` is treated
    as one point array per image, laid out in row-major pixel order (point index
    ``y * W + x``). The grid size ``(H, W)`` is read from the matching image in
    ``scene.imgs`` rather than assumed, so images need not be 512x512.
    NOTE(review): assumes each entry of ``scene.imgs`` is indexed ``img[y, x]`` (HWC) — confirm.

    Args:
        scene: reconstructed scene exposing ``get_dense_pts3d`` and ``imgs``.

    Returns:
        A list with one ``H``-row by ``W``-column grid per image. Each cell is
        ``{"pt_3d": point, "color": img[y, x]}``, or ``None`` when the image has
        fewer dense points than pixels.
    """
    pts3d_dense, _, _ = scene.get_dense_pts3d(clean_depth=True)

    result = []
    for img_idx, pts3d_img in enumerate(pts3d_dense):
        img = scene.imgs[img_idx]
        # Take the grid size from the image itself: hard-coding 512x512 misindexes
        # points (y * 512 + x) and overruns img[y, x] for non-square images.
        H, W = img.shape[:2]
        # Accept both a flat (H*W, C) and an (H, W, C) point layout.
        pts_flat = pts3d_img.reshape(-1, pts3d_img.shape[-1])
        n_pts = len(pts_flat)

        grid = [
            [
                {"pt_3d": pts_flat[y * W + x], "color": img[y, x]}
                if y * W + x < n_pts
                else None
                for x in range(W)
            ]
            for y in range(H)
        ]
        result.append(grid)

    return result


# One grid of {"pt_3d", "color"} cells (or None) per image of the reconstructed scene.
mapping = map_dense_pts3d_to_pixels_with_colors(scene)

In [None]:
# Inspect the bottom-right cell of the first image; negative indices work for any image size.
mapping[0][-1][-1]

In [None]:
def get_image_hashes_list(scene):
    """Return ``hash_md5(path)`` for every entry of ``scene.img_paths``, preserving order.

    These hashes are the file stems of the per-image cache files read by ``load_gaussian``.
    """
    return [hash_md5(path) for path in scene.img_paths]

def load_gaussian(scene, cache_dir, index, map_location=None):
    """Load the cached Gaussian attributes of image ``index`` of ``scene``.

    The file read is ``{cache_dir}/gaussian_attributes/{hash_md5(img_path)}.pth``.
    Only the requested image path is hashed, so looping over all images hashes each
    file once instead of hashing every file on every call.

    Args:
        scene: object exposing an ``img_paths`` sequence.
        cache_dir: root directory of the MASt3R cache.
        index: position of the image in ``scene.img_paths``.
        map_location: forwarded to ``torch.load`` (e.g. ``"cpu"`` to open caches saved
            from a GPU on a CPU-only machine); ``None`` keeps torch's default behaviour.

    Returns:
        The object stored in the cache file; callers in this notebook unpack it as
        ``(sh, scales, rotations, opacities, means)``.
    """
    img_hash = hash_md5(scene.img_paths[index])
    return torch.load(f"{cache_dir}/gaussian_attributes/{img_hash}.pth", map_location=map_location)

### Export Gaussians — colour taken from the source-image pixels

In [None]:
import trimesh
from src.mast3r_src.dust3r.dust3r.viz import OPENGL, pts3d_to_trimesh, cat_meshes
from plyfile import PlyData, PlyElement
from scipy.spatial.transform import Rotation
import einops
import utils.geometry as geometry

def save_to_ply_optim(mapping, scene, cache_dir, save_path):
    """Export one Gaussian per mapped pixel to a 3DGS-style PLY, coloured from the source images.

    For each image ``idx`` the cached attributes returned by
    ``load_gaussian(scene, cache_dir, idx)`` are unpacked as
    ``(sh, scales, rotations, opacities, means)`` and the covariance is rebuilt with
    ``geometry.build_covariance(scales, rotations)``. Every non-``None`` cell
    ``mapping[idx][y][x]`` then yields one vertex: position ``cell['pt_3d']``,
    covariance and opacity read at pixel ``(y, x)``, and ``f_dc_*`` set to the
    degree-0 SH coefficient of ``cell['color']``.

    Args:
        mapping: per-image grids of ``{"pt_3d", "color"}`` dicts or ``None``.
        scene: scene object passed through to ``load_gaussian``.
        cache_dir: cache root passed through to ``load_gaussian``.
        save_path: destination ``.ply`` path.

    Returns:
        None. If no valid cell exists, an error is printed and nothing is written.
    """
    def construct_list_of_attributes(num_rest: int) -> list[str]:
        '''PLY vertex property names in the 3DGS layout.'''
        attributes = ["x", "y", "z", "nx", "ny", "nz"]
        attributes += [f"f_dc_{i}" for i in range(3)]
        attributes += [f"f_rest_{i}" for i in range(num_rest)]
        attributes.append("opacity")
        attributes += [f"scale_{i}" for i in range(3)]
        attributes += [f"rot_{i}" for i in range(4)]
        return attributes

    def covariance_to_quaternion_and_scale(covariances):
        '''Split (N, 3, 3) symmetric PSD covariances into (w, x, y, z) quaternions and std-dev scales.'''
        # torch.linalg.svd returns (U, S, Vh). For a symmetric PSD matrix the columns of U
        # are its eigenvectors, so U itself is the rotation (U @ Vh.mT would be U @ U).
        U, S, _ = torch.linalg.svd(covariances.detach())
        # Make U a proper rotation (det = +1); negating one eigenvector leaves
        # U diag(S) U^T unchanged.
        det_sign = torch.sign(torch.linalg.det(U))
        U = torch.cat([U[..., :2], U[..., 2:] * det_sign[..., None, None]], dim=-1)
        scale = torch.sqrt(S).cpu().numpy()
        quat_xyzw = Rotation.from_matrix(U.cpu().numpy()).as_quat()
        # scipy quaternions are scalar-last; the 3DGS PLY layout stores rot_0 = w.
        return quat_xyzw[..., [3, 0, 1, 2]], scale

    def rgb_to_sh0(rgb):
        """Convert RGB in [0, 1] to the degree-0 spherical-harmonic (DC) coefficient."""
        C0 = 0.28209479177387814
        return (rgb - 0.5) / C0

    def to_numpy(value):
        '''Detach a tensor to a CPU numpy array; pass other array-likes through np.asarray.'''
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def collect_valid_cells(image_map):
        '''Return row indices, column indices and cells of every non-None entry of a grid.'''
        ys, xs, cells = [], [], []
        for y, row in enumerate(image_map):
            for x, cell in enumerate(row):
                if cell is not None:
                    ys.append(y)
                    xs.append(x)
                    cells.append(cell)
        return ys, xs, cells

    per_image_attrs = []
    for idx, image_map in enumerate(mapping):
        print(f"Processing image {idx}...")
        ys, xs, cells = collect_valid_cells(image_map)
        if not cells:
            print(f"No valid gaussians in image {idx}, skipping...")
            continue
        print(f"Found {len(cells)} gaussians in image {idx}")

        # Load the cached attributes for THIS image only, to bound peak memory.
        sh, scales, rotations, opacities_all, means_cached = load_gaussian(scene, cache_dir, idx)
        covariances_all = geometry.build_covariance(scales, rotations).squeeze().detach().cpu()
        opacities_all = opacities_all.squeeze().detach().cpu()
        del sh, scales, rotations, means_cached

        # Gather all selected pixels with one indexing op per tensor.
        rows = torch.as_tensor(ys)
        cols = torch.as_tensor(xs)
        covariances = covariances_all[rows, cols]
        opacities = opacities_all[rows, cols].numpy().reshape(-1, 1)
        del covariances_all, opacities_all

        means = np.stack([to_numpy(cell["pt_3d"]) for cell in cells])
        harmonics = rgb_to_sh0(np.stack([to_numpy(cell["color"]) for cell in cells]))
        rotations_wxyz, scales_std = covariance_to_quaternion_and_scale(covariances)
        normals = np.zeros_like(means)  # normals are not estimated

        per_image_attrs.append(np.concatenate([
            means,               # x, y, z
            normals,             # nx, ny, nz
            harmonics,           # f_dc_0, f_dc_1, f_dc_2
            # NOTE(review): 3DGS loaders apply a sigmoid to this field; confirm the
            # cache stores pre-activation opacities.
            opacities,           # opacity
            np.log(scales_std),  # scale_0..2 (log space)
            rotations_wxyz,      # rot_0..3 (w, x, y, z)
        ], axis=-1))

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if not per_image_attrs:
        print("ERROR: No valid Gaussians found across all images!")
        return

    float_attrs = np.concatenate(per_image_attrs, axis=0)
    print(f"Total gaussians collected: {len(float_attrs)}")

    float_names = construct_list_of_attributes(0)
    if float_attrs.shape[1] != len(float_names):
        raise ValueError(
            f"Expected {len(float_names)} attribute columns, got {float_attrs.shape[1]}")

    elements = np.empty(len(float_attrs), dtype=[(name, "f4") for name in float_names])
    # Fill column-wise instead of building one Python tuple per vertex.
    for column, name in enumerate(float_names):
        elements[name] = float_attrs[:, column]

    point_cloud = PlyElement.describe(elements, "vertex")
    PlyData([point_cloud]).write(save_path)

    print(f"Saved {len(elements)} Gaussians to {save_path}")

# Usage: export one Gaussian per mapped pixel, with f_dc derived from the source-image colours.
# NOTE(review): "mast3r_cache" must match the cache_dir passed to reconstruct_scene earlier.
save_to_ply_optim(mapping, scene, "mast3r_cache", "pointclouds/imgrgbtest.ply")

### Export Gaussians — colour taken from the cached spherical-harmonic coefficients

In [None]:
def save_to_ply_optim(mapping, scene, cache_dir, save_path):
    """Export one Gaussian per mapped pixel to a 3DGS-style PLY, with f_dc read from the cached SH.

    This definition shadows the image-RGB ``save_to_ply_optim`` defined in an earlier cell.

    For each image ``idx`` the cached attributes returned by
    ``load_gaussian(scene, cache_dir, idx)`` are unpacked as
    ``(sh, scales, rotations, opacities, means)``. Every non-``None`` cell
    ``mapping[idx][y][x]`` yields one vertex: position ``cell['pt_3d']``; covariance
    (``geometry.build_covariance(scales, rotations)``) and opacity read at pixel
    ``(y, x)``; and ``f_dc_*`` set to ``sh[y, x, :3]`` clipped to ``[0, 1]``.

    Args:
        mapping: per-image grids of ``{"pt_3d", "color"}`` dicts or ``None``.
        scene: scene object passed through to ``load_gaussian``.
        cache_dir: cache root passed through to ``load_gaussian``.
        save_path: destination ``.ply`` path.

    Returns:
        None. If no valid cell exists, an error is printed and nothing is written.
    """
    def construct_list_of_attributes(num_rest: int) -> list[str]:
        '''PLY vertex property names in the 3DGS layout.'''
        attributes = ["x", "y", "z", "nx", "ny", "nz"]
        attributes += [f"f_dc_{i}" for i in range(3)]
        attributes += [f"f_rest_{i}" for i in range(num_rest)]
        attributes.append("opacity")
        attributes += [f"scale_{i}" for i in range(3)]
        attributes += [f"rot_{i}" for i in range(4)]
        return attributes

    def covariance_to_quaternion_and_scale(covariances):
        '''Split (N, 3, 3) symmetric PSD covariances into (w, x, y, z) quaternions and std-dev scales.'''
        # torch.linalg.svd returns (U, S, Vh). For a symmetric PSD matrix the columns of U
        # are its eigenvectors, so U itself is the rotation (U @ Vh.mT would be U @ U).
        U, S, _ = torch.linalg.svd(covariances.detach())
        # Make U a proper rotation (det = +1); negating one eigenvector leaves
        # U diag(S) U^T unchanged.
        det_sign = torch.sign(torch.linalg.det(U))
        U = torch.cat([U[..., :2], U[..., 2:] * det_sign[..., None, None]], dim=-1)
        scale = torch.sqrt(S).cpu().numpy()
        quat_xyzw = Rotation.from_matrix(U.cpu().numpy()).as_quat()
        # scipy quaternions are scalar-last; the 3DGS PLY layout stores rot_0 = w.
        return quat_xyzw[..., [3, 0, 1, 2]], scale

    def to_numpy(value):
        '''Detach a tensor to a CPU numpy array; pass other array-likes through np.asarray.'''
        if torch.is_tensor(value):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def collect_valid_cells(image_map):
        '''Return row indices, column indices and cells of every non-None entry of a grid.'''
        ys, xs, cells = [], [], []
        for y, row in enumerate(image_map):
            for x, cell in enumerate(row):
                if cell is not None:
                    ys.append(y)
                    xs.append(x)
                    cells.append(cell)
        return ys, xs, cells

    per_image_attrs = []
    for idx, image_map in enumerate(mapping):
        print(f"Processing image {idx}...")
        ys, xs, cells = collect_valid_cells(image_map)
        if not cells:
            print(f"No valid gaussians in image {idx}, skipping...")
            continue
        print(f"Found {len(cells)} gaussians in image {idx}")

        # Load the cached attributes for THIS image only, to bound peak memory.
        sh, scales, rotations, opacities_all, means_cached = load_gaussian(scene, cache_dir, idx)
        covariances_all = geometry.build_covariance(scales, rotations).squeeze().detach().cpu()
        opacities_all = opacities_all.squeeze().detach().cpu()
        sh_all = sh.squeeze().detach().cpu()
        del sh, scales, rotations, means_cached

        # Gather all selected pixels with one indexing op per tensor.
        rows = torch.as_tensor(ys)
        cols = torch.as_tensor(xs)
        covariances = covariances_all[rows, cols]
        opacities = opacities_all[rows, cols].numpy().reshape(-1, 1)
        # First three entries of the last SH axis at each pixel.
        # NOTE(review): 3DGS viewers read f_dc_* as SH DC coefficients, yet the values are
        # clipped to [0, 1] as if they were RGB — confirm this is the intended output.
        harmonics = np.clip(sh_all[rows, cols, :3].numpy(), 0, 1)
        del covariances_all, opacities_all, sh_all

        means = np.stack([to_numpy(cell["pt_3d"]) for cell in cells])
        rotations_wxyz, scales_std = covariance_to_quaternion_and_scale(covariances)
        normals = np.zeros_like(means)  # normals are not estimated

        per_image_attrs.append(np.concatenate([
            means,               # x, y, z
            normals,             # nx, ny, nz
            harmonics,           # f_dc_0, f_dc_1, f_dc_2
            # NOTE(review): 3DGS loaders apply a sigmoid to this field; confirm the
            # cache stores pre-activation opacities.
            opacities,           # opacity
            np.log(scales_std),  # scale_0..2 (log space)
            rotations_wxyz,      # rot_0..3 (w, x, y, z)
        ], axis=-1))

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if not per_image_attrs:
        print("ERROR: No valid Gaussians found across all images!")
        return

    float_attrs = np.concatenate(per_image_attrs, axis=0)
    print(f"Total gaussians collected: {len(float_attrs)}")

    float_names = construct_list_of_attributes(0)
    if float_attrs.shape[1] != len(float_names):
        raise ValueError(
            f"Expected {len(float_names)} attribute columns, got {float_attrs.shape[1]}")

    elements = np.empty(len(float_attrs), dtype=[(name, "f4") for name in float_names])
    # Fill column-wise instead of building one Python tuple per vertex.
    for column, name in enumerate(float_names):
        elements[name] = float_attrs[:, column]

    point_cloud = PlyElement.describe(elements, "vertex")
    PlyData([point_cloud]).write(save_path)

    print(f"Saved {len(elements)} Gaussians to {save_path}")

# Usage: this calls the save_to_ply_optim defined in this cell (it shadows the image-RGB
# variant above), so f_dc comes from the cached SH values rather than image colours.
save_to_ply_optim(mapping, scene, "mast3r_cache", "pointclouds/sh_converted_rgb.ply")

### Save the dense coloured point cloud

In [None]:
def save_pointcloud_from_mapping_binary(mapping, save_path):
    """
    Save every mapped pixel as a coloured vertex in a binary PLY (much smaller than ASCII).

    Args:
        mapping: per-image grids whose cells are ``{"pt_3d", "color"}`` dicts or ``None``.
        save_path: destination ``.ply`` path.

    Returns:
        ``save_path``.

    Raises:
        ValueError: if ``mapping`` contains no non-``None`` cell.
    """
    def _to_numpy(x):
        '''Detach a tensor to a CPU numpy array; pass None through; np.asarray anything else.'''
        if x is None:
            return None
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def _to_uint8(color):
        '''Float colours are treated as [0, 1]: clipped, scaled by 255 and truncated; others are cast.'''
        if color.dtype.kind == "f":
            return (np.clip(color, 0.0, 1.0) * 255.0).astype(np.uint8)
        return color.astype(np.uint8)

    points = []
    colors = []
    for idx, image_map in enumerate(mapping):
        print(f"Processing image {idx}...")
        for row in image_map:
            for cell in row:
                if cell is None:
                    continue
                points.append(_to_numpy(cell['pt_3d']))
                colors.append(_to_uint8(_to_numpy(cell['color'])))

    if len(points) == 0:
        raise ValueError("No valid points found")

    print(f"Found {len(points)} points, saving as binary PLY...")

    points_arr = np.stack(points)
    colors_arr = np.stack(colors)
    vertex_data = np.empty(len(points_arr), dtype=[
        ('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
        ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')
    ])
    # Fill the structured array column-wise instead of one Python tuple per vertex.
    for column, name in enumerate(('x', 'y', 'z')):
        vertex_data[name] = points_arr[:, column]
    for column, name in enumerate(('red', 'green', 'blue')):
        vertex_data[name] = colors_arr[:, column]

    vertex_element = PlyElement.describe(vertex_data, 'vertex')
    PlyData([vertex_element], text=False).write(save_path)  # text=False = binary

    print(f"Saved {len(vertex_data)} points to {save_path} (binary)")
    return save_path

# Usage: write a coloured binary point cloud (positions + RGB only, no Gaussian attributes).
save_pointcloud_from_mapping_binary(mapping, "pointclouds/binary_pointcloud.ply")

In [None]:
# Load one file from the canon_gaussians cache and list its keys.
# NOTE(review): the hash in the filename is hard-coded to a single cache entry; it is
# presumably hash_md5 of one image path (as in gaussian_attributes) — confirm and derive it
# from scene.img_paths so this cell survives a re-run on other inputs.
gaussians = torch.load("mast3r_cache/canon_gaussians/0ad929c4b83d461d351cfe97d8cb7558_subsample=8.pth")
gaussians.keys()

In [None]:
# Shape of the cached 'sh' entry — check its layout before indexing SH tensors by pixel.
gaussians['sh'].shape