In [1]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from lightning.pytorch import Trainer, seed_everything, callbacks
from lightning.pytorch.loggers import TensorBoardLogger

# Add the parent directory of NN_TopOpt to the system path
sys.path.append(os.path.abspath('..'))

from models.sdf_models import LitSdfAE, LitSdfAE_MINE
from models.sdf_models import AE, VAE, MMD_VAE
from models.sdf_models import AE_DeepSDF, VAE_DeepSDF, MMD_VAE_DeepSDF
from datasets.SDF_dataset import SdfDataset, SdfDatasetSurface, collate_fn_surface
from datasets.SDF_dataset import RadiusDataset
from datasets.SDF_dataset import ReconstructionDataset
from models.sdf_models import LitRecon_MINE


# Enable anomaly detection to help find where NaN/Inf values originate
torch.autograd.set_detect_anomaly(True)

# Enable deterministic algorithms for better debugging
# torch.use_deterministic_algorithms(True)

# # Set debug mode for floating point operations
# torch.set_printoptions(precision=10, sci_mode=False)

# # Function to check for NaN/Inf values in tensors
# def check_tensor(tensor, tensor_name=""):
#     if torch.isnan(tensor).any():
#         print(f"NaN detected in {tensor_name}")
#         print(tensor)
#         raise ValueError(f"NaN detected in {tensor_name}")
#     if torch.isinf(tensor).any():
#         print(f"Inf detected in {tensor_name}") 
#         print(tensor)
#         raise ValueError(f"Inf detected in {tensor_name}")


# Add the parent directory of NN_TopOpt to the system path
sys.path.append(os.path.abspath('..'))

root_path = '../shape_datasets'

dataset_train_files = [f'{root_path}/ellipse_reconstruction_dataset_train.csv',
                       f'{root_path}/triangle_reconstruction_dataset_train.csv',
                        f'{root_path}/quadrangle_reconstruction_dataset_train.csv',
                        ]

dataset_test_files = [f'{root_path}/ellipse_reconstruction_dataset_test.csv',
                      f'{root_path}/triangle_reconstruction_dataset_test.csv',
                      f'{root_path}/quadrangle_reconstruction_dataset_test.csv',
                 ]

train_dataset = ReconstructionDataset(dataset_train_files)
test_dataset = ReconstructionDataset(dataset_test_files)


# Create DataLoaders with shuffling
batch_size = 64
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,  # Enable shuffling for training data
    num_workers=15
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, 
    batch_size=1024,
    shuffle=True,  # No need to shuffle test data
    num_workers=15
)


print(f"Training set size: {len(train_dataset)}")
print(f"Test set size: {len(test_dataset)}")


Training set size: 900000
Test set size: 15000


In [2]:
from models.sdf_models import LitSdfAE_Reconstruction

configs_dir = '../configs/NN_sdf_experiments/recon_MILoss'
models_dir = '../model_weights'

config_name = 'AE_DeepSDF_ReconDec'
# config_name = 'VAE_DeepSDF_ReconDec'
# uba_VAE_DeepSDF_minMI.pt

run_name = f'local_{config_name}_Midtest11'
# saved_model_path = f'{configs_dir}/{run_name}/checkpoints/epoch=0-step=0.ckpt'
saved_model_path = f'{models_dir}/uba_{config_name}.pt'

models = {'AE_DeepSDF': AE_DeepSDF,
          'AE': AE, 
          'VAE': VAE,
          'VAE_DeepSDF': VAE_DeepSDF,
          'MMD_VAE': MMD_VAE,
          'MMD_VAE_DeepSDF': MMD_VAE_DeepSDF}

In [3]:
# from lightning.pytorch.callbacks import Callback
torch.autograd.set_detect_anomaly(True)
import yaml

MAX_EPOCHS = 1
MAX_STEPS = MAX_EPOCHS * len(train_loader)

# Training setup
trainer = Trainer(
    max_epochs=MAX_EPOCHS, # the first epoch for training all model, the second one for training rec decoder
    accelerator='auto',
    devices=1,
    logger=TensorBoardLogger(
        name='VAE_MI_ReconDec', 
        save_dir='../logs', 
        default_hp_metric=False, 
        version=run_name
    ),
    callbacks=[
        callbacks.ModelCheckpoint(
            monitor='val_reconstruction_loss',
            mode='min',
            save_top_k=1,
            filename='best-model-{epoch:02d}-{val_reconstruction_loss:.2f}'
        ),
        callbacks.EarlyStopping(
            monitor='val_reconstruction_loss',
            patience=10,
            mode='min'
        ) #,
        # FirstEvalCallback()
    ],
    check_val_every_n_epoch=None,  # Disable validation every epoch
    val_check_interval=5000  # Perform validation every 2000 training steps
)

# Load configuration from YAML file
with open(f'{configs_dir}/{config_name}.yaml', 'r') as file:
    config = yaml.safe_load(file)

# Initialize VAE model
model_params = config['model']['params']
model_params['input_dim'] = 17 # train_dataset.feature_dim
vae_model = models[config['model']['type']](**model_params)

# Load pre-trained weights for the model
# pretrained_weights_path = config['model']['pretrained_weights_path'
# state_dict = torch.load(saved_model_path)
# new_state_dict = vae_model.state_dict()

# # Update the new_state_dict with the loaded state_dict, ignoring size mismatches
# for key in state_dict:
#     if key in new_state_dict and state_dict[key].size() == new_state_dict[key].size():
#         new_state_dict[key] = state_dict[key]

# vae_model.load_state_dict(new_state_dict)

# Initialize VAE trainer
trainer_params = config['trainer']['params']
trainer_params['vae_model'] = vae_model
trainer_params['max_steps'] = MAX_STEPS
vae_trainer = LitRecon_MINE(**trainer_params)

# Train the model
trainer.validate(vae_trainer, dataloaders=test_loader)
# trainer.fit(vae_trainer,  train_dataloaders=[train_loader, train_loader], val_dataloaders=[test_loader, surface_test_loader, radius_samples_loader])

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/kalexu97/Projects/carpenter-sdf-topology-optimizer/.venv/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:208: Attribute 'vae_model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['vae_model'])`.
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


regularization: l2, reg_weight: 0.1
Using orthogonality loss: None
Freezing weights SDF decoder


/home/kalexu97/Projects/carpenter-sdf-topology-optimizer/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Validation DataLoader 0: 100%|██████████| 15/15 [00:00<00:00, 19.18it/s]


[{'val_total_loss': 0.10786858201026917,
  'val_tau_loss': 0.390581339597702,
  'val_reg_loss': 0.0010059033520519733,
  'val_reconstruction_loss': 0.0687098577618599,
  'val_ortho_mi_original': 0.45854923129081726,
  'val_ortho_mi_tau': 0.5285501480102539,
  'val_ortho_z_original_std': 0.018209142610430717,
  'val_ortho_z_tau_std': 0.015603490173816681,
  'val_ortho_z_std_ratio': 1.1677333116531372,
  'val_ortho_mi_ratio': 0.8674944043159485}]

In [4]:
trainer.fit(vae_trainer, train_loader, val_dataloaders=test_loader)
trainer.validate(vae_trainer, dataloaders=test_loader)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type        | Params | Mode 
-------------------------------------------------
0 | vae      | AE_DeepSDF  | 2.5 M  | train
1 | mine     | MINE_Critic | 17.7 K | train
2 | mine_tau | MINE_Critic | 17.3 K | train
-------------------------------------------------
704 K     Trainable params
1.8 M     Non-trainable params
2.6 M     Total params
10.208    Total estimated model params size (MB)
64        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/kalexu97/Projects/carpenter-sdf-topology-optimizer/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Epoch 0:   7%|▋         | 1010/14063 [00:19<04:11, 51.97it/s, v_num=st11, train_total_loss=-0.00802, train_tau_loss=0.00955, train_reg_loss=0.204, train_reconstruction_loss=0.0101, train_info_xz=0.0645, train_info_xz_tau=0.0251, lr_vae=0.0001, lr_mine=9.28e-6, lr_mine_tau=9.28e-6] 



Epoch 0: 100%|██████████| 14063/14063 [04:20<00:00, 53.91it/s, v_num=st11, train_total_loss=-0.258, train_tau_loss=0.003, train_reg_loss=0.139, train_reconstruction_loss=0.00139, train_info_xz=1.45e-5, train_info_xz_tau=-0.273, lr_vae=0.000, lr_mine=0.000, lr_mine_tau=0.000, val_total_loss=0.0138, val_tau_loss=0.000761, val_reg_loss=0.133, val_reconstruction_loss=0.000503, val_ortho_mi_original=0.265, val_ortho_mi_tau=1.440, val_ortho_z_original_std=0.0641, val_ortho_z_tau_std=0.404, val_ortho_z_std_ratio=0.159, val_ortho_mi_ratio=0.183]           

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 14063/14063 [04:20<00:00, 53.91it/s, v_num=st11, train_total_loss=-0.258, train_tau_loss=0.003, train_reg_loss=0.139, train_reconstruction_loss=0.00139, train_info_xz=1.45e-5, train_info_xz_tau=-0.273, lr_vae=0.000, lr_mine=0.000, lr_mine_tau=0.000, val_total_loss=0.0138, val_tau_loss=0.000761, val_reg_loss=0.133, val_reconstruction_loss=0.000503, val_ortho_mi_original=0.265, val_ortho_mi_tau=1.440, val_ortho_z_original_std=0.0641, val_ortho_z_tau_std=0.404, val_ortho_z_std_ratio=0.159, val_ortho_mi_ratio=0.183]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/kalexu97/Projects/carpenter-sdf-topology-optimizer/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:475: Your `val_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Validation DataLoader 0: 100%|██████████| 15/15 [00:00<00:00, 24.58it/s]


[{'val_total_loss': 0.013464599847793579,
  'val_tau_loss': 0.0008184274192899466,
  'val_reg_loss': 0.1287045031785965,
  'val_reconstruction_loss': 0.0005123050650581717,
  'val_ortho_mi_original': 0.2830864489078522,
  'val_ortho_mi_tau': 1.6353726387023926,
  'val_ortho_z_original_std': 0.05837404355406761,
  'val_ortho_z_tau_std': 0.405010461807251,
  'val_ortho_z_std_ratio': 0.14414367079734802,
  'val_ortho_mi_ratio': 0.17319339513778687}]

In [6]:
# Save model weights
checkpoint_path = f'../model_weights/{run_name}.ckpt'
trainer.save_checkpoint(checkpoint_path)
print(f"Model weights saved to {checkpoint_path}")

# Save just the model weights
model_weights_path = f'../model_weights/{run_name}.pt'
torch.save(vae_model.state_dict(), model_weights_path)
print(f"Model weights saved to {model_weights_path}")

Model weights saved to ../model_weights/local_AE_DeepSDF_ReconDec_Midtest11.ckpt
Model weights saved to ../model_weights/local_AE_DeepSDF_ReconDec_Midtest11.pt


In [6]:
# Get a batch from test dataset
test_batch = next(iter(test_loader))
x, sdf, tau = test_batch

# Get reconstruction
with torch.no_grad():
    output = vae_trainer.vae(x, reconstruction=True)
    x_reconstructed = output["x_reconstructed"]

# Convert to numpy for printing
x_original = x[:, 2:].cpu().numpy()  # Remove first two columns (query points)
x_recon = x_reconstructed.cpu().numpy()

# Print first example
print("Original input:")
print(x_original[0])
print("\nReconstructed input:")
print(x_recon[0])

print("\nMean squared error:")
print(np.mean((x_original[0] - x_recon[0])**2))



Original input:
[ 0.5         0.         -0.38978198  0.518743    0.00532513  0.01095491
  0.06030432  0.          0.          0.          0.          0.
  0.          0.          0.        ]

Reconstructed input:
[ 0.5187373   0.01557274 -0.07140062  0.4800849   0.02134559  0.00292744
  0.01397747  0.01079088  0.00570327  0.00451591  0.01464799  0.00974598
  0.02305626  0.00395259  0.01383292]

Mean squared error:
0.0071426327


In [6]:
mse_list = []
mse_loss_list = []

min_x_orignal_list = []
max_x_orignal_list = []

from tqdm import tqdm

for batch in tqdm(test_loader, desc="Processing batches"):
    x, sdf, tau = batch
    with torch.no_grad():
        output = vae_trainer.vae(x, reconstruction=True)
        x_reconstructed = output["x_reconstructed"]
    # mse_list.append(np.mean((x_original[0] - x_recon[0])**2))

    reconstruction_loss = F.mse_loss(x_reconstructed, x[:, 2:])
    mse_loss_list.append(reconstruction_loss.item())

    x_original = x[:, 2:].cpu().numpy()  # Remove first two columns (query points)
    x_recon = x_reconstructed.cpu().numpy()

    min_x_orignal_list.append(np.min(x_original, axis=0))
    max_x_orignal_list.append(np.max(x_original, axis=0))

    for i in range(x.shape[0]):
        mse_list.append(np.mean((x_original[i] - x_recon[i])**2))
        if np.mean((x_original[i] - x_recon[i])**2) > 1:
            print(x_original[i])
            print(x_recon[i])
            print('##########################')

print(np.mean(mse_list))





Processing batches: 100%|██████████| 15/15 [00:01<00:00, 10.17it/s]

0.00036851873





In [7]:
min_x_orignal_list = np.array(min_x_orignal_list)
max_x_orignal_list = np.array(max_x_orignal_list)

print(np.min(min_x_orignal_list, axis=0))
print(np.max(max_x_orignal_list, axis=0))


[ 0.          0.         -0.799881    0.          0.          0.
  0.         -0.7766095  -0.19944747 -0.7998588   0.          0.
  0.          0.          0.        ]
[1.         0.9998729  0.7997384  0.79972833 0.24950403 0.12693372
 0.2823601  0.7999988  0.7993542  0.79519516 0.7998265  0.2582661
 0.99214834 0.8288874  0.22966386]


In [8]:
mse_loss_list

[0.00034879017039202154,
 0.0003549386456143111,
 0.0003599647316150367,
 0.0003472205135039985,
 0.00034912803675979376,
 0.0003525547799654305,
 0.0003557983145583421,
 0.00037380977300927043,
 0.0003648570564109832,
 0.00034283087006770074,
 0.0004042361688334495,
 0.00037826638435944915,
 0.00042293747537769377,
 0.00039777022902853787,
 0.0003780173137784004]

In [9]:
print(np.array(mse_loss_list).sum() / len(mse_list))
print(np.array(mse_loss_list).mean())

3.687413642182946e-07
0.00036874136421829464


In [10]:
mse_list

[0.00089397567,
 0.00028445365,
 0.00028412035,
 0.00038552293,
 0.00052072044,
 0.00038100893,
 7.006084e-06,
 0.00018717346,
 1.4121474e-05,
 0.00048327286,
 0.0014195127,
 0.00031384852,
 1.3597034e-05,
 0.00081781443,
 0.0005745119,
 1.3189976e-05,
 1.658989e-05,
 0.00065361493,
 0.00038415586,
 0.00030403162,
 0.0011057444,
 0.00033689215,
 0.00024440902,
 7.8754265e-06,
 0.0001119691,
 1.1222213e-05,
 0.00013188084,
 0.00045341672,
 7.934791e-06,
 7.857333e-06,
 0.00016160218,
 0.00050436304,
 9.92003e-06,
 3.0374644e-05,
 0.00013023602,
 0.00039858778,
 0.0005015517,
 0.00055397686,
 2.086111e-05,
 0.00027418332,
 8.8951265e-06,
 8.6563705e-05,
 8.37228e-06,
 1.3808148e-05,
 9.091062e-06,
 0.00020172635,
 0.0002590347,
 0.0006907134,
 0.0008388558,
 0.00051521894,
 0.00016279935,
 7.150091e-06,
 0.0016878349,
 0.00041884265,
 0.0008044269,
 7.2807766e-06,
 0.00023960324,
 2.2364118e-05,
 2.6603122e-05,
 8.450766e-05,
 0.00023395265,
 8.076307e-06,
 1.5820546e-05,
 0.00039077544,