In [1]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from lightning.pytorch import Trainer, seed_everything, callbacks
from lightning.pytorch.loggers import TensorBoardLogger

# Add the parent directory of NN_TopOpt to the system path
sys.path.append(os.path.abspath('..'))

from models.sdf_models import LitSdfAE, LitSdfAE_MINE
from models.sdf_models import AE, VAE, MMD_VAE
from models.sdf_models import AE_DeepSDF, VAE_DeepSDF, MMD_VAE_DeepSDF
from datasets.SDF_dataset import SdfDataset, SdfDatasetSurface, collate_fn_surface
from datasets.SDF_dataset import RadiusDataset
from datasets.SDF_dataset import ReconstructionDataset


# Enable anomaly detection to help find where NaN/Inf values originate
torch.autograd.set_detect_anomaly(True)

# Enable deterministic algorithms for better debugging
# torch.use_deterministic_algorithms(True)

# # Set debug mode for floating point operations
# torch.set_printoptions(precision=10, sci_mode=False)

# # Function to check for NaN/Inf values in tensors
# def check_tensor(tensor, tensor_name=""):
#     if torch.isnan(tensor).any():
#         print(f"NaN detected in {tensor_name}")
#         print(tensor)
#         raise ValueError(f"NaN detected in {tensor_name}")
#     if torch.isinf(tensor).any():
#         print(f"Inf detected in {tensor_name}") 
#         print(tensor)
#         raise ValueError(f"Inf detected in {tensor_name}")


# Add the parent directory of NN_TopOpt to the system path
sys.path.append(os.path.abspath('..'))

root_path = '../shape_datasets'

dataset_train_files = [f'{root_path}/triangle_reconstruction_dataset_train.csv',
                 f'{root_path}/quadrangle_reconstruction_dataset_train.csv',
                 f'{root_path}/ellipse_reconstruction_dataset_train.csv']

dataset_test_files = [f'{root_path}/triangle_reconstruction_dataset_test.csv',
                 f'{root_path}/quadrangle_reconstruction_dataset_test.csv',
                 f'{root_path}/ellipse_reconstruction_dataset_test.csv']

train_dataset = ReconstructionDataset(dataset_train_files)
test_dataset = ReconstructionDataset(dataset_test_files)


# Create DataLoaders with shuffling
batch_size = 64
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,  # Enable shuffling for training data
    num_workers=15
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, 
    batch_size=batch_size,
    shuffle=False,  # No need to shuffle test data
    num_workers=15
)


print(f"Training set size: {len(train_dataset)}")
print(f"Test set size: {len(test_dataset)}")


Training set size: 300000
Test set size: 15000


In [2]:
from models.sdf_models import LitSdfAE_Reconstruction

configs_dir = '../configs/NN_sdf_experiments/model_arch_minmi2'
models_dir = '../model_weights'

config_name = 'VAE_DeepSDF_minMI'
# uba_VAE_DeepSDF_minMI.pt

run_name = f'local_{config_name}_test_large'
# saved_model_path = f'{configs_dir}/{run_name}/checkpoints/epoch=0-step=0.ckpt'
saved_model_path = f'{models_dir}/uba_{config_name}.pt'

models = {'AE_DeepSDF': AE_DeepSDF,
          'AE': AE, 
          'VAE': VAE,
          'VAE_DeepSDF': VAE_DeepSDF,
          'MMD_VAE': MMD_VAE,
          'MMD_VAE_DeepSDF': MMD_VAE_DeepSDF}

In [6]:
# from lightning.pytorch.callbacks import Callback
torch.autograd.set_detect_anomaly(True)
import yaml

MAX_EPOCHS = 5
MAX_STEPS = MAX_EPOCHS * len(train_loader)

# Training setup
trainer = Trainer(
    max_epochs=MAX_EPOCHS, # the first epoch for training all model, the second one for training rec decoder
    accelerator='auto',
    devices=1,
    logger=TensorBoardLogger(
        name='VAE_reconstructor', 
        save_dir='../logs', 
        default_hp_metric=False, 
        version=run_name
    ),
    callbacks=[
        callbacks.ModelCheckpoint(
            monitor='val_reconstruction_loss',
            mode='min',
            save_top_k=1,
            filename='best-model-{epoch:02d}-{val_reconstruction_loss:.2f}'
        ),
        callbacks.EarlyStopping(
            monitor='val_reconstruction_loss',
            patience=10,
            mode='min'
        ) #,
        # FirstEvalCallback()
    ],
    check_val_every_n_epoch=None,  # Disable validation every epoch
    val_check_interval=5000  # Perform validation every 2000 training steps
)

# Load configuration from YAML file
with open(f'{configs_dir}/{config_name}.yaml', 'r') as file:
    config = yaml.safe_load(file)

# Initialize VAE model
model_params = config['model']['params']
model_params['input_dim'] = 17 # train_dataset.feature_dim
vae_model = models[config['model']['type']](**model_params)

# Load pre-trained weights for the model
# pretrained_weights_path = config['model']['pretrained_weights_path'
state_dict = torch.load(saved_model_path)
new_state_dict = vae_model.state_dict()

# Update the new_state_dict with the loaded state_dict, ignoring size mismatches
for key in state_dict:
    if key in new_state_dict and state_dict[key].size() == new_state_dict[key].size():
        new_state_dict[key] = state_dict[key]

vae_model.load_state_dict(new_state_dict)

# Initialize VAE trainer
trainer_params = config['trainer']['params']
trainer_params['vae_model'] = vae_model
trainer_params['max_steps'] = MAX_STEPS
vae_trainer = LitSdfAE_Reconstruction(**trainer_params)

# Train the model
trainer.validate(vae_trainer, dataloaders=test_loader)
# trainer.fit(vae_trainer,  train_dataloaders=[train_loader, train_loader], val_dataloaders=[test_loader, surface_test_loader, radius_samples_loader])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


  state_dict = torch.load(saved_model_path)
/home/kalexu97/Projects/carpenter-sdf-topology-optimizer/.venv/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'vae_model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['vae_model'])`.
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation DataLoader 0: 100%|██████████| 235/235 [00:01<00:00, 167.59it/s]


[{'val_reconstruction_loss': 0.2703113853931427}]

In [7]:
trainer.fit(vae_trainer,  train_loader, val_dataloaders=test_loader)
trainer.validate(vae_trainer, dataloaders=test_loader)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type        | Params | Mode 
---------------------------------------------
0 | vae  | VAE_DeepSDF | 2.3 M  | train
---------------------------------------------
334 K     Trainable params
1.9 M     Non-trainable params
2.3 M     Total params
9.015     Total estimated model params size (MB)
52        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Epoch 0:  22%|██▏       | 1019/4688 [00:09<00:32, 112.05it/s, v_num=arge, train_reconstruction_loss=0.0661]



Epoch 4: 100%|██████████| 4688/4688 [00:42<00:00, 110.59it/s, v_num=arge, train_reconstruction_loss=0.0975, val_reconstruction_loss=0.162] 

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 4688/4688 [00:42<00:00, 110.59it/s, v_num=arge, train_reconstruction_loss=0.0975, val_reconstruction_loss=0.162]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation DataLoader 0: 100%|██████████| 235/235 [00:01<00:00, 215.75it/s]


[{'val_reconstruction_loss': 0.16054707765579224}]

In [30]:
# Get a batch from test dataset
test_batch = next(iter(train_loader))
x, sdf, tau = test_batch

# Get reconstruction
with torch.no_grad():
    output = vae_trainer.vae(x, reconstruction=True)
    x_reconstructed = output["x_reconstructed"]

# Convert to numpy for printing
x_original = x[:, 2:].cpu().numpy()  # Remove first two columns (query points)
x_recon = x_reconstructed.cpu().numpy()

# Print first example
print("Original input:")
print(x_original[0])
print("\nReconstructed input:")
print(x_recon[0])
print("\nMean squared error:")
print(np.mean((x_original[0] - x_recon[0])**2))



Original input:
[ 1.          0.          0.          0.          0.          0.
  0.          0.6828931   0.7955297  -0.78050554  0.7546705   0.26917306
  0.20375885  0.1248081   0.354417  ]

Reconstructed input:
[ 0.9672451   0.01955883 -0.0208485   0.04493848  0.01949345  0.01587804
  0.01565391  0.35579953  0.2746787  -0.28623143  0.4095033   0.15522192
  0.5061113   0.17680027  0.20572081]

Mean squared error:
0.05838171
