In [1]:
from model.model import PlaneEstimationModel
import h5py
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import json
import open3d as o3d
import random
from tqdm import tqdm
import os
from datetime import datetime
import numpy as np

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor


  __import__("pkg_resources").declare_namespace(__name__)

A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.4 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/opt/conda/envs/neural_acd/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/envs/neural_acd/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/opt/conda/envs/neural_acd/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/opt/conda/envs/neural_acd/lib/python3.10/site-packages/traitlets/confi

In [2]:
def apply_rotation_to_plane(a,b,c,d,rotation):
    normal = np.array([a, b, c])

    rotation = rotation[:3,:3]
    
    rotated_normal = rotation @ normal

    if np.linalg.norm(normal) == 0:
        raise ValueError("Invalid plane normal (0,0,0).")

    point_on_plane = -d * normal / np.linalg.norm(normal) ** 2 
    rotated_point = rotation @ point_on_plane 

    d_new = -np.dot(rotated_normal, rotated_point)

    if d_new < 0: #make the signs of coeffs consistent
        rotated_normal = -rotated_normal
        d_new = -d_new

    return rotated_normal[0], rotated_normal[1], rotated_normal[2], d_new 

class NeuralACDDataset(Dataset):
    def __init__(self,pc_folder,planes_folder,rotate=True):
        self.rotate =rotate
        with h5py.File(pc_folder, 'r') as f:
            self.data = f['point_clouds'][:]  # shape (N, 512, 3)
            self.hashes = f['hashes'][:]
        with open(planes_folder,'r') as f:
            self.labels = json.load(f)
            
            
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        points = self.data[idx]
        mesh_hash = self.hashes[idx].decode('utf-8')
        planes = self.labels[mesh_hash]

        if self.rotate and random.random() < 0.75:
            rotation = o3d.geometry.get_rotation_matrix_from_xyz(np.random.rand(3) * 2 * np.pi)
        else:
            rotation = np.eye(3)

        points = np.dot(points, rotation[:3,:3].T)

        points = points.transpose(1, 0)

        planes = [apply_rotation_to_plane(*plane[:4],rotation) for plane in planes]
        return points,planes

train_dataset = NeuralACDDataset("data/train_data.h5","data/plane_cache.json")
val_dataset = NeuralACDDataset("data/val_data.h5","data/plane_cache.json")

In [3]:
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [4]:
pl.seed_everything(42)

model = PlaneEstimationModel(learning_rate=1e-3)


callbacks = [
    ModelCheckpoint(monitor='val_loss',
        dirpath='checkpoints/',
        filename='best-model-{epoch:02d}-{val_loss:.2f}',
        save_top_k=3,
        mode='min'),
    LearningRateMonitor()]

trainer = pl.Trainer(
        devices="auto",
        accelerator="auto",
        callbacks=callbacks,
        max_epochs=100,
        log_every_n_steps=100,
        check_val_every_n_epoch=10
    )

# Start Training
trainer.fit(
    model=model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader
)

Global seed set to 42


AttributeError: `np.Inf` was removed in the NumPy 2.0 release. Use `np.inf` instead.