In [66]:
from utils.DataManager import DynamicGaussianDatasetManager
from utils.DynamicGaussianModel import GaussianModelTrainer, VoxelModelManager
from utils.Metrices import Metrics
from utils.StatSaver import StatSaver
from utils.Helpers import visulaize_point_cloud_6d_torch_array
import torch
import os
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
# open3d is kept after torch: that is the import order this notebook originally ran with.
import open3d as o3d


# --- Experiment configuration ---------------------------------------------
# Machine-specific roots may be overridden through environment variables; the
# defaults are the paths this notebook was originally run with.
exp_name = "start4"
dataset_name = "dynamic"
sequence = "basketball"
dataset_root_path = os.environ.get("MV4D_DATASET_ROOT", "/mnt/c/MyFiles/Datasets/dynamic/data")
output_root = os.environ.get("MV4D_OUTPUT_ROOT", "/home/anurag/codes/MV4D_reconstruction/output")

dataset_path = os.path.join(dataset_root_path, sequence)
# Per-experiment output directory: <output_root>/<exp_name>/<dataset_name>/<sequence>
output_path = os.path.join(output_root, exp_name, dataset_name, sequence)
# Create it up front: the stats CSV and the model checkpoint are written into it.
os.makedirs(output_path, exist_ok=True)

# --- Data, metrics, stats and model ----------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset_manager = DynamicGaussianDatasetManager(dataset_path, device=device)
metrics_manager = Metrics(dataset_manager)
stat_saver = StatSaver(os.path.join(output_path, "train_stats.csv"))

num_timesteps = dataset_manager.num_timesteps
print(f"Number of timesteps: {num_timesteps}")
print(f"Dataset for timestep 0: {dataset_manager.train_num_cams} cameras")
model = GaussianModelTrainer(dataset_manager)
# --- Initial timestep (t = 0) optimization ---------------------------------
# Train from scratch unless a checkpoint for this experiment already exists;
# metrics and stats are only computed/logged when training actually ran.
curr_timestep = 0
num_iterations = 5000
start = time.time()

checkpoint_path = os.path.join(output_path, "final_model.pth")
optimizer = model.initialize_optimizer(scaling_factor=15.0)
if os.path.exists(checkpoint_path):
    model.load_model(checkpoint_path)
else:
    pbar = tqdm(range(num_iterations), desc=f"Timestep {curr_timestep}")
    running_loss = 0.0
    for i in pbar:
        # Accumulate gradients over every training camera, then take a single step.
        batch_loss = 0.0
        for j in range(dataset_manager.train_num_cams):
            loss, variables = model.get_loss_ij(curr_timestep, j)
            model.update_variables(variables)
            batch_loss += loss.item()
            loss.backward()
        with torch.no_grad():
            if i % 10 == 0:
                # NOTE(review): this call used to pass `j`, the camera loop's leftover
                # value (always train_num_cams - 1), so densification never saw the
                # iteration count. The iteration index `i` is presumably what
                # adaptive_densification expects - confirm against its signature.
                model.adaptive_densification(optimizer, i)
            optimizer.step()
            optimizer.zero_grad()
        # Exponential moving average of the summed loss, seeded with the first value.
        running_loss = 0.9 * running_loss + 0.1 * batch_loss if i > 0 else batch_loss
        pbar.set_postfix({"loss": f"{running_loss:.6f}"})
    end = time.time()
    model.save_model(checkpoint_path)

    train_all_metrics = metrics_manager.get_metrics(model, curr_timestep, train=True)
    test_all_metrics = metrics_manager.get_metrics(model, curr_timestep, train=False)
    # Row layout: timestep, iterations, 5 train metrics, 5 test metrics,
    # elapsed seconds, parameter count, model size.
    metric_keys = ['avg_L1', 'avg_PSNR', 'avg_SSIM', 'avg_LPIPS', 'avg_MSSSIM']
    log_stat = [
        curr_timestep,
        num_iterations,
        *(train_all_metrics[k] for k in metric_keys),
        *(test_all_metrics[k] for k in metric_keys),
        end - start,
        model.get_num_params(),
        model.get_model_size(),
    ]
    stat_saver.save_stat(log_stat)

# --- Timestep 1: compare against the timestep-0 model -----------------------
# Bookkeeping for the next optimization stage.
i = 1
curr_timestep = i
running_loss = 0.0
num_iterations1 = 100
start = time.time()

# get_changes yields three values for the new timestep; only the first two
# are inspected here.
all_points, remove_points, data_to_optimize = model.get_changes(curr_timestep)
print(all_points.shape, remove_points.shape)
print('Initial', model.gm.params['means3D'].shape)

# Voxel grid over the current Gaussian centres, voxel size = 2 x model.avg_distance.
# Presumably the result is a per-centre boolean mask (it is cast to bool below) -
# confirm against VoxelModelManager.indices_of_points_to_remove.
VM = VoxelModelManager(model.avg_distance * 2)
VM.initialize_points(model.gm.params['means3D'])
remove_points_indices = VM.indices_of_points_to_remove(remove_points)
remove_points_mask = torch.tensor(
    remove_points_indices,
    dtype=torch.bool,
    device=model.gm.params['means3D'].device,
)

  self.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)


Number of timesteps: 150
Dataset for timestep 0: 27 cameras


2025-05-26 20:59:22.932 | INFO     | ptlflow:restore_model:283 - Restored model state from checkpoint: things
  self.params = torch.load(path)


torch.Size([160578, 6]) torch.Size([516871, 3])
Initial torch.Size([194105, 3])


In [67]:
# Build a removal mask with open3d: voxelize the points flagged for removal
# (voxel edge = 4 x VM.voxel_size) and test which current Gaussian centres fall
# inside an occupied voxel.
prev_points = model.gm.params['means3D'].detach().cpu().numpy()
points_to_remove = remove_points.detach().cpu().numpy()
point_cloud = o3d.geometry.PointCloud()
point_cloud.points = o3d.utility.Vector3dVector(points_to_remove)
voxel_grid_to_remove = o3d.geometry.VoxelGrid.create_from_point_cloud(point_cloud, voxel_size=VM.voxel_size*4)
# check_if_included returns one bool per query point, so despite its name this
# is a boolean mask aligned with prev_points (not a list of indices).
indices_to_remove = voxel_grid_to_remove.check_if_included(o3d.utility.Vector3dVector(prev_points))
indices_to_remove = np.array(indices_to_remove, dtype=bool)


In [68]:
# Sizes: current Gaussian centres, points flagged for removal, per-centre mask.
(prev_points.shape, points_to_remove.shape, indices_to_remove.shape)

((194105, 3), (516871, 3), (194105,))

In [69]:
# Keep only the centres NOT covered by the removal voxels (mask is a numpy bool array).
new_points = prev_points[~indices_to_remove]

In [70]:
# Number of surviving centres after masking.
np.shape(new_points)

(66188, 3)

In [71]:
# Visual check of the centres that survive the removal mask.
point_cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(new_points))
o3d.visualization.draw_geometries([point_cloud])

In [77]:
# Reload the timestep-0 checkpoint so pruning always starts from the full set
# of Gaussians (this also makes the cell safe to re-run), and show it.
model.load_model(os.path.join(output_path, "final_model.pth"))
point_cloud = o3d.geometry.PointCloud()
point_cloud.points = o3d.utility.Vector3dVector(model.gm.params['means3D'].detach().cpu().numpy())
o3d.visualization.draw_geometries([point_cloud])

# Keep-mask = complement of the open3d removal mask.
to_keep = torch.logical_not(torch.from_numpy(indices_to_remove))
# indices_to_remove is expected to be built from this same checkpoint's means3D;
# fail loudly rather than silently pruning nothing when the sizes disagree.
num_gaussians = model.gm.params['means3D'].shape[0]
assert to_keep.shape[0] == num_gaussians, (
    f"mask has {to_keep.shape[0]} entries but the model has {num_gaussians} Gaussians"
)

with torch.no_grad():
    for k, v in list(model.gm.params.items()):
        # Only per-Gaussian tensors (leading dim == number of Gaussians) are pruned.
        if v.shape[0] != to_keep.shape[0]:
            continue
        pruned = v[to_keep.to(v.device)]
        # Outside no_grad, indexing a grad-requiring tensor gives a non-leaf, plain
        # tensor that cannot be optimized. Restore the original container type and
        # requires_grad flag on the detached result.
        if isinstance(v, torch.nn.Parameter):
            model.gm.params[k] = torch.nn.Parameter(pruned, requires_grad=v.requires_grad)
        else:
            model.gm.params[k] = pruned.requires_grad_(v.requires_grad)
# NOTE(review): an optimizer created before this cell still references the old,
# un-pruned tensors; presumably it must be re-created before any further training.

In [78]:
# Visual check of the Gaussian centres left after pruning.
point_cloud = o3d.geometry.PointCloud(
    o3d.utility.Vector3dVector(model.gm.params['means3D'].detach().cpu().numpy())
)
o3d.visualization.draw_geometries([point_cloud])