# Test mini_tVAE

### 1. Setup

In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import json
import matplotlib.pyplot as plt
import sys
sys.path.append('/home/mfacotti/martin/tVAE_project')
from mini_tvae_v2 import MiniTVAE
from data_load import DataLoader


from sdv.evaluation.single_table import run_diagnostic, evaluate_quality
from sdv.metadata import Metadata

### 2. Load data

In [2]:
# Read the RHC table and its metadata file through the project's DataLoader.
data_loader = DataLoader(
    csv_filename='../rhc.csv',
    meta_filename='../metadata.json'
)
# `data` is the training table; `discrete_columns` is the list of column names
# that are passed to the model as discrete/categorical.
data, discrete_columns = data_loader.load_data()

# Preview the first rows of the loaded table.
data.head()

Detected missing values in 5 columns:
  - cat2: 4535 missing values (79.1%)
  - dschdte: 1 missing values (0.0%)
  - dthdte: 2013 missing values (35.1%)
  - adld3p: 4296 missing values (74.9%)
  - urin1: 3028 missing values (52.8%)
No discrete columns found in metadata. Inferring from data types...
Loaded dataset with 5735 rows and 63 columns
Identified 21 discrete columns


Unnamed: 0.1,Unnamed: 0,cat1,cat2,ca,sadmdte,dschdte,dthdte,lstctdte,death,cardiohx,...,meta,hema,seps,trauma,ortho,adld3p,urin1,race,income,ptid
0,1,COPD,,Yes,11142,11151.0,,11382,No,0,...,No,No,No,No,No,0.0,,white,Under $11k,5
1,2,MOSF w/Sepsis,,No,11799,11844.0,11844.0,11844,Yes,1,...,No,No,Yes,No,No,,1437.0,white,Under $11k,7
2,3,MOSF w/Malignancy,MOSF w/Sepsis,Yes,12083,12143.0,,12400,No,0,...,No,No,No,No,No,,599.0,white,$25-$50k,9
3,4,ARF,,No,11146,11183.0,11183.0,11182,Yes,0,...,No,No,No,No,No,,,white,$11-$25k,10
4,5,MOSF w/Sepsis,,No,12035,12037.0,12037.0,12036,Yes,0,...,No,No,No,No,No,,64.0,white,Under $11k,11


### 3. Training

In [3]:
def train_tvae_model(data, discrete_columns, hyperparams=None):
    """
    Build a MiniTVAE from default settings merged with overrides, fit it, and return it.

    Parameters:
    -----------
    data : DataFrame
        The training data
    discrete_columns : list
        List of discrete/categorical columns
    hyperparams : dict, optional
        Keyword arguments for MiniTVAE; each entry overrides (or extends)
        the defaults listed below

    Returns:
    --------
    MiniTVAE
        The model after `fit` has been called on `data`
    """
    # Start from the baseline configuration ...
    params = dict(
        embedding_dim=128,
        compress_dims=(128, 128),
        decompress_dims=(128, 128),
        l2scale=1e-5,
        batch_size=500,
        epochs=1000,
        loss_factor=2,
        cuda=True,  # Set to False if running on CPU
        verbose=True,
    )
    # ... and let caller-supplied values take precedence.
    if hyperparams:
        params.update(hyperparams)

    # Echo the effective configuration so the run is self-describing.
    print("Training Mini TVAE model with the following parameters:")
    for name, setting in params.items():
        print(f"  {name}: {setting}")

    # Instantiate and fit.
    fitted_model = MiniTVAE(**params)
    fitted_model.fit(data, discrete_columns)
    return fitted_model

# TODO: add a way to save the trained model.
# NOTE(review): the output recorded for this cell ends in
# "RuntimeError: mat1 and mat2 shapes cannot be multiplied (500x401 and 400x128)",
# so `model` is not defined for the cells below until that is fixed. The
# mismatch is raised inside the model code, not in this notebook — possibly
# related to the NaN handling mentioned in the log; confirm in mini_tvae_v2.
model = train_tvae_model(data, discrete_columns)

Training Mini TVAE model with the following parameters:
  embedding_dim: 128
  compress_dims: (128, 128)
  decompress_dims: (128, 128)
  l2scale: 1e-05
  batch_size: 500
  epochs: 1000
  loss_factor: 2
  cuda: True
  verbose: True
Note: Found NaN values in columns: ['cat2', 'dschdte', 'dthdte', 'adld3p', 'urin1']. These will be handled using the 'from_column' approach.



Best performing initialization did not converge. Try different init parameters, or increase max_iter, tol, or check for degenerate data.


Best performing initialization did not converge. Try different init parameters, or increase max_iter, tol, or check for degenerate data.


Best performing initialization did not converge. Try different init parameters, or increase max_iter, tol, or check for degenerate data.


Best performing initialization did not converge. Try different init parameters, or increase max_iter, tol, or check for degenerate data.


Best performing initialization did not converge. Try different init parameters, or increase max_iter, tol, or check for degenerate data.


Best performing initialization did not converge. Try different init parameters, or increase max_iter, tol, or check for degenerate data.


Best performing initialization did not converge. Try different init parameters, or increase max_iter, tol, or check for degenerate data.


Best performing initializa

RuntimeError: mat1 and mat2 shapes cannot be multiplied (500x401 and 400x128)

In [None]:
def plot_loss_over_epochs(loss_values):
    """
    Plot the loss components across epochs in a three-panel figure.

    Parameters:
    -----------
    loss_values : DataFrame
        Loss log with the columns 'Epoch', 'loss_1', 'loss_2' and
        'total_loss'. The panels label 'loss_1' as the reconstruction loss
        and 'loss_2' as the KL divergence; presumably there is one row per
        training batch (confirm against the model's logging).

    Returns:
    --------
    None
        The figure is shown with `plt.show()`.
    """
    # Mean of every loss component per epoch (left panel).
    epoch_loss = loss_values.groupby('Epoch')[['loss_1', 'loss_2', 'total_loss']].mean().reset_index()

    fig, (ax_mean, ax_recon, ax_kl) = plt.subplots(1, 3, figsize=(18, 6))

    # Left panel: mean loss components per epoch.
    mean_curves = (
        ('loss_1', 'b-', 'Reconstruction Loss'),
        ('loss_2', 'r-', 'KL Divergence'),
        ('total_loss', 'g-', 'Total Loss'),
    )
    for column, line_format, label in mean_curves:
        ax_mean.plot(epoch_loss['Epoch'], epoch_loss[column], line_format, label=label)
    ax_mean.set_title('Mean Loss Components per Epoch')
    ax_mean.legend()

    # Middle and right panels: one point per logged row. A single scatter
    # call per panel draws the same points as filtering the frame once per
    # epoch, without rescanning the whole frame for every epoch.
    batch_panels = (
        (ax_recon, 'loss_1', 'blue', 'Reconstruction Loss per Batch'),
        (ax_kl, 'loss_2', 'red', 'KL Divergence per Batch'),
    )
    for ax, column, color, title in batch_panels:
        ax.scatter(loss_values['Epoch'], loss_values[column],
                   alpha=0.3, s=10, color=color)
        ax.set_title(title)

    # Shared axis decoration.
    for ax in (ax_mean, ax_recon, ax_kl):
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()

    
# Visualise the loss log kept on the trained model (requires the training cell to have succeeded).
plot_loss_over_epochs(model.loss_values)

### 5. Generation

In [None]:
def generate_synthetic_data(model, num_samples, save_path=None):
    """
    Draw synthetic rows from a trained model and optionally write them to CSV.

    Parameters:
    -----------
    model : MiniTVAE
        The trained model; only its `sample` method is used
    num_samples : int
        Number of samples to generate
    save_path : str, optional
        Destination CSV file; nothing is written when omitted or empty

    Returns:
    --------
    DataFrame
        The generated synthetic data
    """
    print(f"Generating {num_samples} synthetic samples...")
    samples = model.sample(num_samples)

    # Nothing to persist: hand the samples straight back.
    if not save_path:
        return samples

    # Write without the index so the file has the same columns as the frame.
    samples.to_csv(save_path, index=False)
    print(f"Synthetic data saved to '{save_path}'")
    return samples

# Generate 10000 rows and save them to 'synthetic_rhc.csv' (read back by the
# evaluation cell). Alternatively pass len(data) to match the size of the
# original dataset.
generate_synthetic_data(model, 10000, 'synthetic_rhc.csv')


In [None]:
def _sorted_categories(categories):
    """Return the category labels in sorted order, tolerating mixed value types."""
    try:
        return sorted(categories)
    except TypeError:
        # Labels of different types (e.g. str and int) cannot be compared
        # directly; fall back to a deterministic order grouped by type name.
        return sorted(categories, key=lambda value: (type(value).__name__, str(value)))


def _numeric_summary(series):
    """Return the mean, std, min and max of a numeric Series as a dict."""
    return {
        'mean': series.mean(),
        'std': series.std(),
        'min': series.min(),
        'max': series.max(),
    }


def compare_statistics(original_data, synthetic_data, discrete_columns,
                       num_numeric=5, num_categorical=3, num_categories=5,
                       exclude_first_column=True):
    """
    Compare statistics between original and synthetic data.

    Prints a per-column comparison and returns the same numbers as a dict.
    Columns that exist in the original data but not in the synthetic data
    are reported as skipped instead of raising a KeyError. The input frames
    are never modified.

    Parameters:
    -----------
    original_data : DataFrame
        The original data
    synthetic_data : DataFrame
        The synthetic data
    discrete_columns : list
        List of discrete/categorical columns; names that are not columns of
        the original data are ignored
    num_numeric : int, optional
        Number of numeric columns to compare (the first ones, by column order)
    num_categorical : int, optional
        Number of categorical columns to compare
    num_categories : int, optional
        Number of categories to display per categorical column
    exclude_first_column : bool, optional
        Whether to exclude the first column of the original data (and the
        column of the same name in the synthetic data) from the comparison

    Returns:
    --------
    dict
        {'numeric': {column: {'original': ..., 'synthetic': ..., 'difference': ...}},
         'categorical': {column: {category: {'original': pct, 'synthetic': pct,
                                             'difference': pct}}}}
    """
    print("Comparing statistics between original and synthetic data:")

    # `drop` returns new frames, so the caller's objects stay untouched
    # without copying them up front.
    orig_data = original_data
    synth_data = synthetic_data

    # Exclude the first column if requested (and if there is one).
    if exclude_first_column and len(orig_data.columns) > 0:
        first_col = orig_data.columns[0]
        print(f"Excluding first column (patient ID): {first_col}")
        orig_data = orig_data.drop(columns=[first_col])
        if first_col in synth_data.columns:
            synth_data = synth_data.drop(columns=[first_col])

    results = {
        'numeric': {},
        'categorical': {}
    }

    # Compare numeric columns (numeric according to the original data's dtypes).
    numeric_columns = orig_data.select_dtypes(include=['number']).columns
    if len(numeric_columns) > 0:
        print("\nNumeric columns comparison:")
        for col in numeric_columns[:num_numeric]:
            print(f"\nColumn: {col}")
            if col not in synth_data.columns:
                print("  Skipped: column is missing from the synthetic data")
                continue

            orig_stats = _numeric_summary(orig_data[col])
            syn_stats = _numeric_summary(synth_data[col])
            mean_diff = abs(orig_stats['mean'] - syn_stats['mean'])
            std_diff = abs(orig_stats['std'] - syn_stats['std'])

            results['numeric'][col] = {
                'original': orig_stats,
                'synthetic': syn_stats,
                'difference': {'mean': mean_diff, 'std': std_diff}
            }

            print(f"  Original - Mean: {orig_stats['mean']:.4f}, Std: {orig_stats['std']:.4f}, "
                  f"Range: [{orig_stats['min']:.4f}, {orig_stats['max']:.4f}]")
            print(f"  Synthetic - Mean: {syn_stats['mean']:.4f}, Std: {syn_stats['std']:.4f}, "
                  f"Range: [{syn_stats['min']:.4f}, {syn_stats['max']:.4f}]")
            print(f"  Difference - Mean: {mean_diff:.4f}, Std: {std_diff:.4f}")

    # Compare categorical columns.
    valid_discrete_columns = [col for col in discrete_columns if col in orig_data.columns]
    if len(valid_discrete_columns) > 0:
        print("\nCategorical columns comparison (value counts percentage):")
        for col in valid_discrete_columns[:num_categorical]:
            print(f"\nColumn: {col}")
            if col not in synth_data.columns:
                print("  Skipped: column is missing from the synthetic data")
                continue

            # Relative frequencies; values are looked up by label below, so
            # the index order of these Series does not matter.
            orig_counts = orig_data[col].value_counts(normalize=True)
            syn_counts = synth_data[col].value_counts(normalize=True)

            # Union of the labels so categories present on one side only
            # are still reported (with 0% on the other side).
            all_cats = _sorted_categories(set(orig_counts.index) | set(syn_counts.index))
            cat_results = {}

            for cat in all_cats[:num_categories]:
                orig_pct = orig_counts.get(cat, 0) * 100
                syn_pct = syn_counts.get(cat, 0) * 100
                diff_pct = abs(orig_pct - syn_pct)

                cat_results[cat] = {
                    'original': orig_pct,
                    'synthetic': syn_pct,
                    'difference': diff_pct
                }

                print(f"  {cat}: Original {orig_pct:.1f}%, Synthetic {syn_pct:.1f}%, Diff {diff_pct:.1f}%")

            results['categorical'][col] = cat_results

    return results


# Compare the training table with a freshly sampled synthetic table of the
# same number of rows (this sample is not written to disk).
stats = compare_statistics(data, generate_synthetic_data(model, len(data)), discrete_columns, exclude_first_column=True)

### 6. Evaluation

In [None]:
# NOTE(review): this rebinds `data`, replacing the frame returned by
# data_loader.load_data() with the raw CSV. Re-running earlier cells after
# this one will see the raw frame — consider a distinct name, and confirm the
# raw CSV has the same columns as the data the model was trained on.
data = pd.read_csv('../rhc.csv')
metadata = Metadata.load_from_json('../metadata.json')

# Load the synthetic data written by the generation cell
synthetic_data = pd.read_csv('synthetic_rhc.csv')

In [None]:
# Run the SDV diagnostic on the real and synthetic tables using the loaded metadata.
diagnostic = run_diagnostic(
    real_data=data,
    synthetic_data=synthetic_data,
    metadata=metadata
)

In [None]:
# Build the SDV quality report for the real vs. synthetic tables.
quality_report = evaluate_quality(
    data,
    synthetic_data,
    metadata
)

# Optional persistence (QualityReport is not imported in this notebook;
# import it before enabling the load line):
# quality_report.save(filepath='results/quality_report.pkl')
# quality_report = QualityReport.load('results/quality_report.pkl')


In [None]:
# Per-column scores for the 'Column Shapes' property.
quality_report.get_details('Column Shapes')
# KSComplement for numerical columns
# TVComplement for categorical columns

In [None]:
# Per-column-pair scores for the 'Column Pair Trends' property.
quality_report.get_details('Column Pair Trends')

In [None]:
# Build one figure per report property, then display both.
fig1 = quality_report.get_visualization('Column Pair Trends')
fig2 = quality_report.get_visualization('Column Shapes')

fig1.show()
fig2.show() # FIXME: this one is not working properly