# TVAE Latent Space Visualization

This notebook demonstrates how to visualize and analyze the latent space of a Tabular Variational Autoencoder (TVAE) model. We'll use the RHC dataset and explore different visualization techniques to understand the structure of the learned latent space.

## Import Required Libraries

First, let's import all the necessary libraries and modules that we'll need for our analysis.

In [1]:
import os
import sys
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display
from sdv.metadata import Metadata, SingleTableMetadata

# Add the parent of the working directory (assumed to be the project root) to the path
# so we can import our modules. __file__ is not defined in a notebook, so resolve from the cwd.
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from tvae.tvae_wrapper import TVAESynthesizer
from tvae.visualization import (
    visualize_latent_space,
    compare_original_synthetic_latent,
    calculate_latent_statistics,
    plot_latent_dimensions
)

## Load Data and Trained Model

Now we'll load the RHC dataset and either load an existing trained TVAE model or train a new one if none exists.

In [None]:
# Define paths
data_path = os.path.join('data', 'rhc.csv')
metadata_path = os.path.join('data', 'metadata.json')
model_path = os.path.join('examples', 'test_model_tvae_ep1000_compress32.pkl')
synthetic_path = os.path.join('examples', 'synthetic_data_tvae_ep1000_compress32.csv')

# Create output directory for visualizations
output_dir = os.path.join('Images', 'latent_space')
os.makedirs(output_dir, exist_ok=True)

# Load the data
print("Loading data...")
data = pd.read_csv(data_path)

# Display the first few rows of the data
data.head()

In [None]:
# Function to load a trained model.
# Note: pickle can execute arbitrary code on load, so only open model files you trust.
def load_model(model_path):
    """Load a trained TVAE model from disk."""
    with open(model_path, 'rb') as f:
        return pickle.load(f)

In [None]:
# Check if a trained model exists, otherwise train a new one
if os.path.exists(model_path):
    print(f"Loading existing model from {model_path}")
    tvae = load_model(model_path)
else:
    print("Training a new TVAE model...")
    # Load metadata
    try:
        metadata = Metadata.load_from_json(metadata_path)
        # Extract single table metadata for 'rhc' table
        metadata = metadata.tables['rhc']
    except Exception as exc:
        # If loading fails, detect the metadata from the data instead
        print(f"Could not load metadata ({exc}); creating new metadata...")
        metadata = SingleTableMetadata()
        metadata.detect_from_dataframe(data)
    
    # Initialize and train TVAE
    tvae = TVAESynthesizer(
        metadata=metadata,
        epochs=50,  # Using fewer epochs for demonstration
        embedding_dim=32,
        compress_dims=(64, 32),
        decompress_dims=(32, 64),
        verbose=True
    )
    tvae.fit(data)
    
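    # Save the model. The file name says ep1000, but this fallback model was trained for only 50 epochs.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    with open(model_path, 'wb') as f: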
        pickle.dump(tvae, f)
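
The load-or-train logic above follows a common cache-on-disk pattern. As a minimal sketch (the helper name `load_or_build` and the stand-in builder are hypothetical, not part of this project), the pattern can be factored into a single function:

```python
import os
import pickle
import tempfile


def load_or_build(path, build_fn):
    """Return the object pickled at `path`, or build, save, and return a new one."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    obj = build_fn()
    # Create the target directory if needed ("." when the path has no directory part).
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(obj, f)
    return obj


# Hypothetical stand-in for an expensive training step; records how often it runs.
build_calls = []


def build_toy_model():
    build_calls.append(1)
    return {"epochs": 50, "embedding_dim": 32}


with tempfile.TemporaryDirectory() as tmp_dir:
    toy_path = os.path.join(tmp_dir, "models", "toy_model.pkl")
    first = load_or_build(toy_path, build_toy_model)   # builds and saves
    second = load_or_build(toy_path, build_toy_model)  # loads from disk

print(len(build_calls), first == second)
```

The builder runs only on the first call; the second call reads the pickle back. A real builder would construct and fit the synthesizer instead of returning a dictionary.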

## Generate and Save Synthetic Data

Let's generate synthetic data using our trained model or load existing synthetic data if available.

In [None]:
# Load synthetic data if it exists or generate new data
if os.path.exists(synthetic_path):
    print(f"Loading existing synthetic data from {synthetic_path}")
    synthetic_data = pd.read_csv(synthetic_path)
else:
    print("Generating synthetic data...")
    synthetic_data = tvae.sample(len(data))
    synthetic_data.to_csv(synthetic_path, index=False)

# Display first few rows of synthetic data
synthetic_data.head()
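
Before moving to the latent space, a quick look at marginal statistics can catch obvious problems in the synthetic data. The sketch below illustrates one way to do this on toy frames; it is not part of the notebook's pipeline, and `compare_numeric_summaries` and the toy columns are hypothetical.

```python
import numpy as np
import pandas as pd


def compare_numeric_summaries(original, synthetic):
    """Side-by-side mean and std for columns that are numeric in both frames."""
    numeric_cols = original.select_dtypes(include="number").columns.intersection(
        synthetic.select_dtypes(include="number").columns
    )
    summary = pd.DataFrame({
        "orig_mean": original[numeric_cols].mean(),
        "synth_mean": synthetic[numeric_cols].mean(),
        "orig_std": original[numeric_cols].std(),
        "synth_std": synthetic[numeric_cols].std(),
    })
    summary["mean_abs_diff"] = (summary["orig_mean"] - summary["synth_mean"]).abs()
    # Largest discrepancies first.
    return summary.sort_values("mean_abs_diff", ascending=False)


# Toy stand-ins for the real and synthetic tables.
rng = np.random.default_rng(0)
toy_original = pd.DataFrame({
    "age": rng.normal(60, 15, 500),
    "sex": rng.choice(["Male", "Female"], 500),
})
toy_synthetic = pd.DataFrame({
    "age": rng.normal(62, 14, 500),
    "sex": rng.choice(["Male", "Female"], 500),
})

summary = compare_numeric_summaries(toy_original, toy_synthetic)
print(summary)
```

In this notebook the same function could be called with `data` and `synthetic_data`. Categorical columns are skipped here and would need a separate frequency comparison.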

## Visualize Latent Space

Now we'll visualize the latent space of our TVAE model, coloring the points by various attributes to understand the structure of the learned representations.

In [None]:
# List of columns we want to color by
color_columns = ['death', 'sex', 'age']

# Initialize a dictionary to store results
visualization_results = {}

# Visualize latent space colored by different attributes
print("Visualizing latent space...")

for color_column in color_columns:
    if color_column in data.columns:
        save_path = os.path.join(output_dir, f'latent_space_{color_column}.png')
        fig, latent_emb, umap_emb = visualize_latent_space(
            tvae_synthesizer=tvae,
            data=data,
            color_by=color_column,
            save_path=save_path
        )
        
        # Store results for later use
        visualization_results[color_column] = {
            'fig': fig,
            'latent_emb': latent_emb,
            'umap_emb': umap_emb
        }
        
        # Display the figure
        plt.figure(fig.number)
        plt.title(f'Latent Space colored by {color_column}')
        plt.show()
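
`visualize_latent_space` returns the raw latent embeddings together with what the variable name suggests is a 2D UMAP embedding. Its internals are not shown in this notebook, so the following is only a sketch of the general encode-then-project idea. It swaps in PCA (computed with NumPy's SVD) for UMAP to avoid extra dependencies; `project_to_2d` and the random stand-in embeddings are hypothetical.

```python
import numpy as np


def project_to_2d(latent_emb):
    """Project (n_samples, n_dims) embeddings onto their top two principal components."""
    centered = latent_emb - latent_emb.mean(axis=0)
    # Rows of vt are the principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


# Random stand-in for encoder output with 32 latent dimensions.
rng = np.random.default_rng(0)
toy_latent = rng.normal(size=(200, 32))
toy_2d = project_to_2d(toy_latent)
print(toy_2d.shape)
```

The resulting array can be passed to `plt.scatter`, with `c=` set to a numeric column such as `age`. Unlike UMAP, PCA is linear, so it may miss curved structure that UMAP would reveal.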

## Compare Original and Synthetic Data

Let's compare the original data and synthetic data in the latent space to see how well our model captures the data distribution.

In [None]:
# Compare original and synthetic data in latent space
print("Comparing original and synthetic data in latent space...")
compare_path = os.path.join(output_dir, 'compare_original_synthetic.png')
compare_fig, compare_latent, compare_umap = compare_original_synthetic_latent(
    tvae_synthesizer=tvae,
    original_data=data,
    synthetic_data=synthetic_data,
    save_path=compare_path
)

# Display the comparison figure
plt.figure(compare_fig.number)
plt.title('Original vs Synthetic Data in Latent Space')
plt.show()
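
A visual overlay can be complemented with one number per latent dimension. The sketch below computes a standardized mean difference between two sets of embeddings. It assumes the embeddings are NumPy arrays of shape `(n_samples, n_dims)`; only the `'original'` key of `compare_latent` is confirmed by this notebook, so the matching synthetic array is an assumption. `latent_mean_gap` and the toy arrays are hypothetical.

```python
import numpy as np


def latent_mean_gap(original_emb, synthetic_emb):
    """Per-dimension |mean difference| scaled by the pooled standard deviation."""
    pooled_std = np.sqrt((original_emb.var(axis=0) + synthetic_emb.var(axis=0)) / 2.0)
    # Avoid division by zero for constant dimensions.
    pooled_std = np.where(pooled_std > 0, pooled_std, 1.0)
    mean_diff = np.abs(original_emb.mean(axis=0) - synthetic_emb.mean(axis=0))
    return mean_diff / pooled_std


# Toy embeddings: the synthetic set is shifted along dimension 2 only.
rng = np.random.default_rng(0)
toy_original_emb = rng.normal(size=(300, 4))
toy_synthetic_emb = rng.normal(size=(300, 4))
toy_synthetic_emb[:, 2] += 3.0

gap = latent_mean_gap(toy_original_emb, toy_synthetic_emb)
print(np.round(gap, 2))
```

Dimensions with a large gap are those where the synthetic rows sit in a different region of the latent space than the original rows. This measure only compares means, so it will not detect differences in shape or spread.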

## Analyze Latent Dimensions

Now we'll calculate statistics of the latent dimensions and visualize the distributions of these dimensions to better understand the learned representations.

In [None]:
# Calculate statistics of the latent dimensions
print("Calculating latent space statistics...")
latent_stats = calculate_latent_statistics(compare_latent['original'])
print("Statistics of latent dimensions:")
display(pd.DataFrame(latent_stats))
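
The exact output of `calculate_latent_statistics` is not shown in this notebook. As a rough sketch of the kind of per-dimension summary it might produce, assuming the embeddings are a NumPy array of shape `(n_samples, n_dims)`, the following computes basic statistics and flags dimensions that barely vary across rows. `summarize_latent` and its threshold are hypothetical.

```python
import numpy as np


def summarize_latent(latent_emb, low_std_threshold=0.05):
    """Per-dimension summary statistics plus a flag for near-constant dimensions."""
    stats = {
        "mean": latent_emb.mean(axis=0),
        "std": latent_emb.std(axis=0),
        "min": latent_emb.min(axis=0),
        "max": latent_emb.max(axis=0),
    }
    # A dimension with almost no spread across rows is unlikely to encode much.
    stats["low_variance"] = stats["std"] < low_std_threshold
    return stats


# Toy embeddings with one nearly constant dimension (index 3).
rng = np.random.default_rng(0)
toy_latent = rng.normal(size=(500, 5))
toy_latent[:, 3] *= 0.001

toy_stats = summarize_latent(toy_latent)
print(toy_stats["low_variance"])
```

Because every entry in the dictionary has one value per dimension, `pd.DataFrame(toy_stats)` yields one row per latent dimension. The threshold of 0.05 is arbitrary and should be tuned to the scale of the actual embeddings.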

In [None]:
# Plot distributions of latent dimensions
print("Plotting latent dimension distributions...")
dims_path = os.path.join(output_dir, 'latent_dimensions.png')
dims_fig = plot_latent_dimensions(
    latent_embeddings=compare_latent['original'],
    n_dims=min(20, tvae.embedding_dim),  # Plot at most 20 dimensions
    save_path=dims_path
)

# Display the dimensions figure
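plt.figure(dims_fig.number)
# Use a figure-level title: if the figure has several panels, plt.title would label only the current subplot
dims_fig.suptitle('Distributions of Latent Dimensions')
plt.tight_layout()
plt.show()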

## Conclusion

In this notebook, we've demonstrated how to:
1. Load and prepare data for a TVAE model
2. Load or train a TVAE model
3. Generate synthetic data using the trained model
4. Visualize the latent space with different colorings
5. Compare original and synthetic data in the latent space
6. Analyze the statistical properties of latent dimensions

These visualizations and analyses help us understand the structure of the TVAE's latent space and how well it captures the original data distribution.