In [None]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import wandb
from omegaconf import OmegaConf

from torch_3dgs.data import read_data
from torch_3dgs.trainer import Trainer
from torch_3dgs.model import GaussianModel
from torch_3dgs.point import get_point_clouds
from torch_3dgs.utils import dict_to_device, visualize_points_plotly
import random
from typing import BinaryIO, Dict, List, Optional, Union

# Load the run configuration. The keys read in this cell are `output_folder`,
# `device`, `data_folder` and `resize_scale`.
config = OmegaConf.load("config.yaml")
# Create the output directory up front; exist_ok avoids an error on re-runs.
os.makedirs(config.output_folder, exist_ok=True)
device = torch.device(config.device)

# Read the dataset. Later cells index the result with the keys "camera",
# "depth", "alpha" and "rgb".
data = read_data(config.data_folder, resize_scale=config.resize_scale)
# NOTE(review): the transfer below is disabled, so `data` stays on whatever
# device read_data produced (presumably CPU - confirm). Any tensor created on
# `device` and combined with `data` must account for that.
# data = dict_to_device(data, device)



In [None]:
from torch_3dgs.camera import extract_camera_params
from torch_3dgs.point import PointCloud

cameras = data["camera"]
depths = data["depth"]
alphas = data["alpha"]
rgbs = data["rgb"]

# Only the first MAX_VIEWS views are merged into the point cloud.
MAX_VIEWS = 10


def backproject_depth(depth, intrinsic, c2w, height, width):
    """Lift one depth map to 3D points using the camera intrinsics and pose.

    Every pixel ``(u, v)`` is unprojected as ``depth * K^-1 @ [u, v, 1]`` and the
    result is mapped through the pose as ``R @ p + t``.

    Args:
        depth: Tensor holding ``height * width`` depth values in row-major
            (row = v, column = u) order.
            NOTE(review): treated as z-depth along the optical axis, not as
            distance along the ray - confirm against ``read_data``.
        intrinsic: Camera matrix; only its top-left 3x3 block is used.
        c2w: Pose matrix whose ``[:3, :3]`` block is used as rotation and whose
            ``[:3, 3]`` column is used as translation.
            NOTE(review): assumed to be camera-to-world in the same axis
            convention as the intrinsics - confirm.
        height: Image height in pixels.
        width: Image width in pixels.

    Returns:
        Tensor of shape ``(height * width, 3)`` with one 3D point per pixel, in
        the same row-major pixel order as ``depth``.

    Raises:
        ValueError: If ``depth`` does not contain ``height * width`` values.
    """
    K = intrinsic[:3, :3]
    # Build everything on the device/dtype of the camera matrix so the matmuls
    # below never mix devices, regardless of where the dataset lives.
    us = torch.arange(width, device=K.device, dtype=K.dtype)
    vs = torch.arange(height, device=K.device, dtype=K.dtype)
    # With indexing="xy" both grids have shape (height, width).
    u_grid, v_grid = torch.meshgrid(us, vs, indexing="xy")
    u_flat = u_grid.reshape(-1)
    v_flat = v_grid.reshape(-1)
    pix_hom = torch.stack([u_flat, v_flat, torch.ones_like(u_flat)], dim=0)  # (3, H*W)

    depth_flat = depth.reshape(-1).to(device=K.device, dtype=K.dtype)
    if depth_flat.numel() != height * width:
        raise ValueError(
            f"depth has {depth_flat.numel()} values, expected {height * width}"
        )

    cam_points = (torch.linalg.inv(K) @ pix_hom) * depth_flat  # (3, H*W), camera frame

    pose = c2w.to(device=K.device, dtype=K.dtype)
    # The translation is applied here, so the result is already a position in
    # the target frame; the camera origin must not be added a second time.
    world_points = pose[:3, :3] @ cam_points + pose[:3, 3:4]
    return world_points.T


Hs, Ws, intrinsics, c2ws = extract_camera_params(cameras)
W, H = int(Ws[0].item()), int(Hs[0].item())
assert depths.shape == alphas.shape, "depth and alpha maps must share a shape"

coords_per_view = []
rgb_per_view = []
num_views = min(MAX_VIEWS, len(depths))
for idx in range(num_views):
    world_pts = backproject_depth(
        depths[idx],
        intrinsics[idx],
        c2ws[idx],
        int(Hs[idx].item()),
        int(Ws[idx].item()),
    )

    # Keep only pixels whose alpha is non-zero, using THIS view's alpha map.
    mask = alphas[idx].reshape(-1).bool()
    coords_per_view.append(world_pts[mask.to(world_pts.device)].cpu().numpy())

    if rgbs is not None:
        # assumes channel-last colour images (H, W, C>=3) - TODO confirm
        rgb_flat = rgbs[idx].reshape(-1, rgbs[idx].shape[-1])[:, :3]
        rgb_per_view.append(rgb_flat[mask.to(rgb_flat.device)].cpu().numpy())

# Merge all views into one (N, 3) array plus matching per-point colour channels.
if coords_per_view:
    coords = np.concatenate(coords_per_view, axis=0)
else:
    coords = np.zeros((0, 3), dtype=np.float32)

channels = {}
if rgb_per_view:
    merged_rgb = np.concatenate(rgb_per_view, axis=0)
    assert merged_rgb.shape[0] == coords.shape[0], "one colour per point expected"
    channels = {"R": merged_rgb[:, 0], "G": merged_rgb[:, 1], "B": merged_rgb[:, 2]}

# Build the point cloud once, after every view has been accumulated.
# NOTE(review): presumably PointCloud expects an (N, 3) array and a dict of
# length-N channel arrays - confirm against torch_3dgs.point.PointCloud.
point_cloud = PointCloud(coords, channels)



    

In [None]:
import plotly.graph_objects as go
from torch_3dgs.point import PointCloud

# Render the point cloud as a single Plotly 3D scatter plot and export it.

# `point_cloud.coords` may be one (N, 3) array or a list of per-view arrays;
# either way it is flattened into a single (N, 3) array.
raw_coords = point_cloud.coords
if isinstance(raw_coords, (list, tuple)):
    view_arrays = [np.asarray(view).reshape(-1, 3) for view in raw_coords]
    coords_array = (
        np.concatenate(view_arrays, axis=0) if view_arrays else np.zeros((0, 3))
    )
else:
    coords_array = np.asarray(raw_coords).reshape(-1, 3)

# Colour the markers from the cloud's own R/G/B channels when there is exactly
# one colour per point; otherwise fall back to a uniform colour.
cloud_channels = point_cloud.channels or {}
colors = "blue"
if all(key in cloud_channels for key in ("R", "G", "B")):
    rgb = np.stack(
        [np.asarray(cloud_channels[key], dtype=np.float64).reshape(-1) for key in ("R", "G", "B")],
        axis=1,
    )
    if rgb.shape[0] == coords_array.shape[0]:
        # NOTE(review): values with a maximum <= 1 are assumed to be normalised
        # to [0, 1] and are rescaled to 0-255; confirm the range read_data produces.
        if rgb.size and rgb.max() <= 1.0:
            rgb = rgb * 255.0
        rgb_u8 = np.clip(np.rint(rgb), 0, 255).astype(np.uint8)
        colors = [f"rgb({r},{g},{b})" for r, g, b in rgb_u8.tolist()]

scatter = go.Scatter3d(
    x=coords_array[:, 0],
    y=coords_array[:, 1],
    z=coords_array[:, 2],
    mode='markers',
    marker=dict(size=2, color=colors, opacity=0.8)
)

layout = go.Layout(
    scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z'
    ),
    margin=dict(l=0, r=0, b=0, t=0)
)

fig = go.Figure(data=[scatter], layout=layout)
# Static export; Plotly needs an image-export backend installed for write_image.
image_path = os.path.join(config.output_folder, "point_cloud.png")
fig.write_image(image_path, width=800, height=600)
print(f"Point cloud image saved to {image_path}")

# Persist the point cloud next to the rendered image.
# NOTE(review): the on-disk format is whatever PointCloud.save writes; confirm
# that it matches the .ply extension used here.
output_path = os.path.join(config.output_folder, "point_cloud.ply")
point_cloud.save(output_path)
print(f"Point cloud saved to {output_path}")

