In [2]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from lightning.pytorch import Trainer, seed_everything, callbacks
from lightning.pytorch.loggers import TensorBoardLogger

sys.path.append(os.path.abspath('..'))

from models.sdf_models import LitSdfAE, LitSdfAE_MINE
from models.sdf_models import AE, VAE, MMD_VAE
from models.sdf_models import AE_DeepSDF, VAE_DeepSDF, MMD_VAE_DeepSDF
from datasets.SDF_dataset import SdfDataset, SdfDatasetSurface, collate_fn_surface
from datasets.SDF_dataset import RadiusDataset

# Enable anomaly detection to help find where NaN/Inf values originate
torch.autograd.set_detect_anomaly(True)

# Enable deterministic algorithms for better debugging
# torch.use_deterministic_algorithms(True)

# # Set debug mode for floating point operations
# torch.set_printoptions(precision=10, sci_mode=False)

# # Function to check for NaN/Inf values in tensors
# def check_tensor(tensor, tensor_name=""):
#     if torch.isnan(tensor).any():
#         print(f"NaN detected in {tensor_name}")
#         print(tensor)
#         raise ValueError(f"NaN detected in {tensor_name}")
#     if torch.isinf(tensor).any():
#         print(f"Inf detected in {tensor_name}") 
#         print(tensor)
#         raise ValueError(f"Inf detected in {tensor_name}")


# Add the parent directory of NN_TopOpt to the system path
sys.path.append(os.path.abspath('NN_TopOpt'))

dataset_path = '../shape_datasets'

    
# dataset_test_files = [f'{dataset_path}/ellipse_sdf_dataset_smf22_arc_ratio_500_test.csv',
#                 f'{dataset_path}/triangle_sdf_dataset_smf20_arc_ratio_500_test.csv', 
#                 f'{dataset_path}/quadrangle_sdf_dataset_smf20_arc_ratio_500_test.csv']

dataset_test_files = [f'{dataset_path}/ellipse_sdf_dataset_smf22_arc_ratio_500.csv',
                 f'{dataset_path}/triangle_sdf_dataset_smf20_arc_ratio_500.csv', 
                 f'{dataset_path}/quadrangle_sdf_dataset_smf20_arc_ratio_500.csv']


test_dataset = SdfDataset(dataset_test_files, exclude_ellipse=False)

# Create DataLoaders with shuffling
batch_size = 64

test_loader = torch.utils.data.DataLoader(
    test_dataset, 
    batch_size=batch_size,
    shuffle=False,  # No need to shuffle test data
    num_workers=15
)

In [3]:
from models.sdf_models import LitSdfAE_Reconstruction

# configs_dir = '../configs/NN_sdf_experiments/recon_MILoss'
configs_dir = '../configs/NN_sdf_experiments/model_arch_minmi'
models_dir = '../model_weights'

# config_name = 'AE_DeepSDF_ReconDec'
config_name = 'AE_DeepSDF_minMI'
# uba_VAE_DeepSDF_minMI.pt

run_name = f'local_{config_name}_test_large'
# saved_model_path = f'{configs_dir}/{run_name}/checkpoints/epoch=0-step=0.ckpt'
# saved_model_path = f'{models_dir}/local_{config_name}_test6.pt'
saved_model_path = f'{models_dir}/{config_name}_full.pt'

models = {'AE_DeepSDF': AE_DeepSDF,
          'AE': AE, 
          'VAE': VAE,
          'VAE_DeepSDF': VAE_DeepSDF,
          'MMD_VAE': MMD_VAE,
          'MMD_VAE_DeepSDF': MMD_VAE_DeepSDF}

In [4]:
import yaml


# Load configuration from YAML file
with open(f'{configs_dir}/{config_name}.yaml', 'r') as file:
    config = yaml.safe_load(file)

# Initialize VAE model
model_params = config['model']['params']
model_params['input_dim'] = 17 # train_dataset.feature_dim
vae_model = models[config['model']['type']](**model_params)

# Load pre-trained weights for the model
# pretrained_weights_path = config['model']['pretrained_weights_path'
state_dict = torch.load(saved_model_path)
new_state_dict = vae_model.state_dict()

# Update the new_state_dict with the loaded state_dict, ignoring size mismatches
for key in state_dict:
    if key in new_state_dict and state_dict[key].size() == new_state_dict[key].size():
        new_state_dict[key] = state_dict[key]

vae_model.load_state_dict(state_dict)


regularization: l2, reg_weight: 0.1
Using orthogonality loss: None


  state_dict = torch.load(saved_model_path)


<All keys matched successfully>

In [5]:
class_names = ['Ellipse', 'Triangle', 'Quadrangle']

def investigate_latent_space(model, dataloader, stats_dir=None, config_name=None):
    """Visualize the latent space"""

    # Create stats directory if it doesn't exist
    if stats_dir is not None:
        os.makedirs(stats_dir, exist_ok=True)

    model.eval()
    latent_vectors = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Processing batches"):
            output = model(batch[0])
            latent_vectors.append(output['z'])
                
    latent_vectors = torch.cat(latent_vectors, dim=0)
    latent_vectors = latent_vectors.cpu().numpy()

    print(latent_vectors.shape)

    latent_mins = np.min(latent_vectors, axis=0)
    latent_maxs = np.max(latent_vectors, axis=0)

    if stats_dir is not None and config_name is not None:
        np.savez(
            f"{stats_dir}/{config_name}_full_stats.npz",
            latent_mins=latent_mins,
            latent_maxs=latent_maxs
        )


investigate_latent_space(vae_model, test_loader, stats_dir=f'../z_limits', config_name=config_name)

Processing batches: 100%|██████████| 23438/23438 [01:34<00:00, 247.02it/s]


(1500000, 9)
