# Model Evaluator
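Use this notebook to evaluate trained model runs. It loads a saved `PointCloudAutoEncoder`, voxelizes input point clouds, and displays three views: the input, the model's reconstruction, and the ground truth when one is available.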

In [1]:
import torch
import torch.nn as nn
import open3d as o3d
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np

# Compute device for the model and tensors: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
class PointCloudAutoEncoder(nn.Module):
    """3D convolutional autoencoder that reconstructs (completes) a voxel occupancy grid.

    Input and output have shape (N, 1, input_dim, input_dim, input_dim).
    - The encoder shrinks each spatial side by a factor of 8 using three stride-2
      convolutions.
    - A fully connected bottleneck mixes the flattened latent features.
    - The decoder upsamples back to the input resolution, with a sigmoid output in [0, 1].

    Input voxels above ``MASK_THRESHOLD`` are forced to 1.0 in the output, so the
    network only has to predict the voxels that are missing from the input.

    Args:
        input_dim: side length of the cubic input grid; must be a positive multiple of 8.

    Raises:
        ValueError: if ``input_dim`` is not a positive multiple of 8.
    """

    # Input voxels strictly above this value are copied to the output as 1.0.
    # Kept on the class rather than set in __init__ so that whole models
    # unpickled with torch.load (which does not call __init__) still find it.
    MASK_THRESHOLD = 0.8

    def __init__(self, input_dim=32):
        super(PointCloudAutoEncoder, self).__init__()

        # Each stride-2 conv below (kernel 3, padding 1) maps a side of length L
        # to ceil(L / 2), and the decoder upsamples by exactly 2 * 2 * 2. The
        # bottleneck shapes therefore only line up when input_dim divides by 8.
        stride_factor = 2 * 2 * 2
        if input_dim <= 0 or input_dim % stride_factor != 0:
            raise ValueError(
                "input_dim must be a positive multiple of {}, got {}".format(stride_factor, input_dim)
            )
        latent_side = input_dim // stride_factor
        grid_count = latent_side ** 3  # number of cells in the latent grid

        # NOTE: layer order and types in the three Sequentials must stay as-is.
        # nn.Sequential names its children by position, so saved state dicts
        # depend on this exact layout.

        # Encoder: (N, 1, D, D, D) -> (N, 32, D/8, D/8, D/8)
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(inplace=True),
            nn.Conv3d(16, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )

        # Fully connected bottleneck over the flattened latent grid.
        self.flatBottleneck = nn.Sequential(
            nn.Flatten(),  # (N, 32, s, s, s) -> (N, 32 * grid_count)
            nn.Dropout(p=0.2),
            nn.Linear(32 * grid_count, 16 * grid_count),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2),
            nn.Linear(16 * grid_count, 16 * grid_count),
            nn.BatchNorm1d(16 * grid_count),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.2),
            nn.Linear(16 * grid_count, 32 * grid_count),
            nn.ReLU(inplace=True),
            # (N, 32 * grid_count) -> (N, 32, s, s, s)
            nn.Unflatten(dim=1, unflattened_size=(32, latent_side, latent_side, latent_side)),
        )

        # Decoder: (N, 32, D/8, D/8, D/8) -> (N, 1, D, D, D)
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose3d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose3d(16, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm3d(16),
            nn.ReLU(inplace=True),
            nn.ConvTranspose3d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),  # Output values in [0, 1]
        )

    def forward(self, x):
        """Reconstruct an occupancy grid.

        Args:
            x: tensor of shape (N, 1, input_dim, input_dim, input_dim).

        Returns:
            Tensor of the same shape. Values are the sigmoid reconstruction,
            except at voxels where ``x > MASK_THRESHOLD``, which are set to 1.0.
        """
        # Voxels the input already marks as occupied are trusted as-is.
        keep_mask = x > self.MASK_THRESHOLD

        latent = self.flatBottleneck(self.encoder(x))
        reconstructed = self.decoder(latent)

        # masked_fill avoids building and transferring a scalar tensor on every call.
        return reconstructed.masked_fill(keep_mask, 1.0)

In [11]:
"""
Returns the tensor given a point cloud.
Also uses min and max bound to avoid empty space in the tensor/pcd
copied from: alex_ml_model_experiments_voxel_grid notebook dataset class
"""
def get_3d_tensor_from_pcd(pcd):
        points = np.asarray(pcd.points)
        min_bound = np.min(points, axis=0)
        max_bound = np.max(points, axis=0)
        grid_size = 32 # TODO IN PARAMS
        voxel_size = (max_bound - min_bound) / grid_size
        
        normalized_points = (points - min_bound) / voxel_size
        grid_points = np.floor(normalized_points).astype(int)
        grid_points = np.clip(grid_points, 0, grid_size - 1)
        grid_tensor = torch.zeros((grid_size, grid_size, grid_size), dtype=torch.int32)
        for point in grid_points:
            grid_tensor[tuple(point)] = 1
        return grid_tensor.float()

def visualize_3d_tensor(voxel_tensor, threshold=0.5, width=500, height=500):
    """Show the voxels of a 3D grid whose value exceeds ``threshold``.

    Each selected voxel becomes one point located at its (i, j, k) grid index.
    The resulting point cloud is passed to ``o3d.visualization.draw_geometries``.

    Args:
        voxel_tensor: 3D tensor of occupancy values, on any device.
        threshold: voxels with a value strictly greater than this are drawn.
        width: viewer width in pixels.
        height: viewer height in pixels.
    """
    # detach().cpu() first: Tensor.numpy() rejects CUDA tensors and tensors that require grad.
    occupied_indices = (voxel_tensor.detach().cpu() > threshold).nonzero().numpy()
    point_cloud = o3d.geometry.PointCloud()
    point_cloud.points = o3d.utility.Vector3dVector(occupied_indices.astype(np.float64))
    o3d.visualization.draw_geometries([point_cloud], width=width, height=height)

"""
Visualize the results of a PCD using a given model
"""
def predict_and_visualize(input_pcd_path, truth_pcd_path, model, threshold=0.5):
    input_pcd = o3d.io.read_point_cloud(input_pcd_path)
    truth_pcd = o3d.io.read_point_cloud(truth_pcd_path)
    input_tensor = get_3d_tensor_from_pcd(input_pcd).to(device)
    if truth_pcd_path != "":
        truth_tensor = get_3d_tensor_from_pcd(truth_pcd).to(device)
    
    model.eval()
    with torch.no_grad():
        input_tensor = input_tensor.unsqueeze(0).unsqueeze(0) # Add batch dimension + channel
        reconstructed_tensor = model(input_tensor)
        # Visualize
        voxel_tensor = input_tensor.squeeze(0).squeeze(0).cpu()
        visualize_3d_tensor(voxel_tensor, threshold)
        voxel_tensor = reconstructed_tensor.squeeze(0).squeeze(0).cpu()
        visualize_3d_tensor(voxel_tensor, threshold)
        #print(truth_tensor.shape)
        if truth_tensor is not None:
            visualize_3d_tensor(truth_tensor.cpu())
        

In [9]:
# Load the trained model onto `device`.
# - LOAD_FULL_MODEL = True: checkpoint saved with torch.save(model), i.e. the whole pickled
#   module. PointCloudAutoEncoder must be defined in this notebook for unpickling to work.
# - LOAD_FULL_MODEL = False: checkpoint saved with torch.save(model.state_dict()).
LOAD_FULL_MODEL = True
FULL_MODEL_PATH = "../scripts/alex/alex_model_mask_10ktrain.pth"  # older run: "../assets/model_exports/model_epoch_9.pth"
STATE_DICT_PATH = "../assets/model_exports/model_v1.pt"

if LOAD_FULL_MODEL:
    # Unpickling a whole model can execute arbitrary code -- only load trusted files.
    # NOTE: PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects
    # whole-model pickles; pass weights_only=False there.
    model = torch.load(FULL_MODEL_PATH, map_location=device)
else:
    model = PointCloudAutoEncoder()
    model.load_state_dict(torch.load(STATE_DICT_PATH, map_location=device))
model = model.to(device)


In [18]:
input_path = "../assets/voxel10000/29330_cut.ply"
full_path = "../assets/voxel10000/29330_full.ply"
predict_and_visualize(input_path, full_path, model, threshold=0.2)

input_path = "../assets/unclean-estimated-pcds/pc_generator_frame_50.ply"
full_path = ""
predict_and_visualize(input_path, full_path, model, threshold=0.2)



UnboundLocalError: local variable 'truth_tensor' referenced before assignment