# Population Health Analytics with Pythae VAE

This notebook demonstrates advanced population health analytics using Variational Autoencoders (VAE) from the Pythae library. We'll analyze member embeddings to identify risk patterns, health phenotypes, and care opportunities.

In [None]:
# Import required libraries
import json
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

# Pythae imports
from pythae.models import VAE, BetaVAE, VAE_LinNF, RHVAE
from pythae.models import VAEConfig, BetaVAEConfig, RHVAEConfig
from pythae.models.base.base_utils import ModelOutput
from pythae.trainers import BaseTrainer, BaseTrainerConfig
from pythae.data.preprocessors import DataProcessor

# Local imports
import sys
sys.path.append('..')
from pipelines.embedding_pipeline import EmbeddingPipeline
from models.config_models import PipelineConfig
from utils.logging_utils import get_logger

# Set random seeds
np.random.seed(42)
torch.manual_seed(42)

# Configure plotting
plt.style.use('seaborn-v0_8')
%matplotlib inline

## 1. Load and Prepare Healthcare Data

In [None]:
# Load the raw claims extract for a quick look. The embedding pipeline in the
# next section reads the same file again from `data_path`.
data_path = Path('data/medical_claims_complete.csv')
df = pd.read_csv(data_path)

n_records = len(df)
print(f"Loaded {n_records} member records")
print(f"Columns: {df.columns.tolist()}")
label_distribution = df['label'].value_counts()
print(f"\nLabel distribution:\n{label_distribution}")

## 2. Generate Member Embeddings

In [None]:
# Embedding-pipeline configuration, assembled one section at a time.
pipeline_section = {'job_name': 'population_health_vae', 'log_level': 'INFO'}
data_section = {
    'data_path': str(data_path.absolute()),
    'claim_column': 'claim',
    'label_column': 'label',
    'mcid_column': 'mcid',
}
llm_section = {'model_url': 'http://localhost:8000', 'batch_size': 32, 'max_retries': 3}
outputs_section = {'output_dir': 'outputs/population_health_vae', 'save_embeddings': True}

config = {
    'pipeline': pipeline_section,
    'data': data_section,
    'llm': llm_section,
    'outputs': outputs_section,
}

# Validate the config and run the pipeline; `run()` returns the embeddings table.
pipeline_config = PipelineConfig(**config)
embedding_pipeline = EmbeddingPipeline(pipeline_config)
embeddings_df = embedding_pipeline.run()

print(f"Generated embeddings shape: {embeddings_df.shape}")

## 3. Prepare Data for VAE Training

In [None]:
# Extract embedding features produced by the pipeline.
embedding_cols = [col for col in embeddings_df.columns if col.startswith('embedding_')]
if not embedding_cols:
    # Fail fast with a readable message instead of a cryptic 0-feature error
    # from StandardScaler further down.
    raise ValueError(
        "No 'embedding_*' columns found in embeddings_df; "
        f"available columns: {embeddings_df.columns.tolist()}"
    )
X = embeddings_df[embedding_cols].to_numpy()
y = embeddings_df['label'].to_numpy()
member_ids = embeddings_df['mcid'].to_numpy()

# Standardize each embedding dimension (zero mean, unit variance) so that no
# single high-variance dimension dominates the reconstruction error.
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Convert to a float32 tensor (torch's default parameter dtype); this replaces
# the legacy torch.FloatTensor constructor.
X_tensor = torch.tensor(X_scaled, dtype=torch.float32)

print(f"Data shape: {X_tensor.shape}")
print(f"Number of members: {len(np.unique(member_ids))}")

## 4. Train Population Health VAE Models

In [None]:
# Define VAE architectures for population health.
# Pythae models are configured through their *Config classes. `input_dim` is the
# per-sample input shape as a tuple (no batch dimension). With no explicit
# encoder/decoder passed, Pythae builds its default MLP networks from it.
input_dim = X_tensor.shape[1]
latent_dim = 32  # Latent space for health phenotypes

# Model 1: Standard VAE for baseline
vae_config = VAEConfig(input_dim=(input_dim,), latent_dim=latent_dim)
vae_model = VAE(model_config=vae_config)

# Model 2: Beta-VAE for disentangled health factors
beta_vae_config = BetaVAEConfig(
    input_dim=(input_dim,),
    latent_dim=latent_dim,
    beta=4.0,  # beta > 1 up-weights the KL term to encourage disentanglement
)
beta_vae_model = BetaVAE(model_config=beta_vae_config)

# Model 3: RHVAE for complex health patterns
rhvae_config = RHVAEConfig(input_dim=(input_dim,), latent_dim=latent_dim)
rhvae_model = RHVAE(model_config=rhvae_config)

models = {
    'VAE': vae_model,
    'Beta-VAE': beta_vae_model,
    'RHVAE': rhvae_model
}

In [None]:
# Train models.
# Pythae's BaseTrainer iterates over a Pythae dataset whose batches expose the
# samples under a "data" key, so the raw tensor is wrapped with DataProcessor.
train_dataset = DataProcessor().to_dataset(X_tensor)

trainer_config = BaseTrainerConfig(
    output_dir='outputs/population_health_vae/checkpoints',
    num_epochs=50,
    learning_rate=1e-3,
    per_device_train_batch_size=64,  # older Pythae releases called this `batch_size`
    steps_saving=10
)

trained_models = {}
training_losses = {}

for model_name, model in models.items():
    print(f"\nTraining {model_name}...")

    trainer = BaseTrainer(
        model=model,
        train_dataset=train_dataset,
        training_config=trainer_config
    )

    trainer.train()
    trained_models[model_name] = trainer.model

    # NOTE(review): `training_logs` is not a documented BaseTrainer attribute
    # (Pythae reports losses through its training callbacks). Keep the entry
    # when this Pythae version exposes it; otherwise record None.
    training_losses[model_name] = getattr(trainer, 'training_logs', {}).get('train_loss')

    print(f"{model_name} training complete!")

## 5. Member Risk Stratification Using Reconstruction Error

In [None]:
# Calculate reconstruction errors for risk assessment
def calculate_reconstruction_error(model, data):
    """Return the per-sample mean squared reconstruction error of ``model``.

    Args:
        model: Trained Pythae-style model. Its forward pass takes a dict-like
            batch with a ``"data"`` key and returns an object exposing
            ``recon_x``.
        data: Tensor of shape ``(n_samples, ...)``.

    Returns:
        NumPy array of shape ``(n_samples,)``. Each entry is the squared error
        averaged over all non-batch dimensions of that sample.

    Notes:
        The model runs in eval mode under ``torch.no_grad()``; its previous
        train/eval mode is restored afterwards. ``data`` is moved to the
        model's device (the trainer may have placed the model on a GPU), and
        the result is brought back to the CPU before conversion to NumPy.

        NOTE(review): Pythae's VAE forward pass appears to sample z with the
        reparameterization trick even in eval mode, so errors may vary
        slightly between calls. Confirm this if exact reproducibility matters.
    """
    param = next(model.parameters(), None)
    device = param.device if param is not None else data.device
    was_training = model.training

    model.eval()
    try:
        with torch.no_grad():
            inputs = data.to(device)
            # Pythae models read the batch from inputs["data"]
            model_output = model({"data": inputs})
            # Align shapes explicitly so a decoder that adds or drops singleton
            # dims cannot silently broadcast.
            recon_x = model_output.recon_x.reshape(inputs.shape)
            errors = ((inputs - recon_x) ** 2).flatten(start_dim=1).mean(dim=1)
    finally:
        model.train(was_training)
    return errors.cpu().numpy()

# Score every member with every trained model.
reconstruction_errors = {
    name: calculate_reconstruction_error(trained, X_tensor)
    for name, trained in trained_models.items()
}

# Risk table: identifiers and label, then one raw-error column per model.
risk_columns = {'mcid': member_ids, 'label': y}
for name, errs in reconstruction_errors.items():
    risk_columns[f'{name}_error'] = errs
risk_df = pd.DataFrame(risk_columns)

# Turn errors into percentile ranks (0-100]: a larger reconstruction error
# (a less typical member) maps to a higher risk percentile.
for name in reconstruction_errors:
    risk_df[f'{name}_risk_percentile'] = risk_df[f'{name}_error'].rank(pct=True) * 100

print("Risk Stratification Summary:")
print(risk_df.describe())

In [None]:
# Visualize the reconstruction-error ("risk") distribution of each model, split
# by label. One panel per trained model, so the figure adapts if models change.
n_models = len(reconstruction_errors)
fig, axes = plt.subplots(1, n_models, figsize=(5 * n_models, 5), squeeze=False)

for ax, (model_name, errors) in zip(axes[0], reconstruction_errors.items()):
    # Densities (not counts) so label groups of different sizes are comparable
    for label in np.unique(y):
        mask = y == label
        ax.hist(errors[mask], bins=30, alpha=0.6, label=f'Label {label}', density=True)

    ax.set_title(f'{model_name} Risk Distribution')
    ax.set_xlabel('Reconstruction Error')
    ax.set_ylabel('Density')
    ax.legend()

plt.tight_layout()
plt.show()

# Identify high-risk members: VAE risk percentile at or above the cut-off
high_risk_threshold = 95  # percentile cut-off; 95 -> top 5%
high_risk_members = risk_df[risk_df['VAE_risk_percentile'] >= high_risk_threshold]
print(f"\nIdentified {len(high_risk_members)} high-risk members "
      f"(top {100 - high_risk_threshold}%)")

## 6. Health Phenotype Discovery Through Latent Space Clustering

In [None]:
# Extract latent representations
def get_latent_representations(model, data):
    """Encode ``data`` with ``model.encoder`` and return the latent codes.

    Args:
        model: Trained Pythae-style model exposing an ``encoder`` module.
        data: Tensor of shape ``(n_samples, ...)`` accepted by the encoder.

    Returns:
        NumPy array with one latent vector per sample. The encoder output's
        ``embedding`` field is used when present (Pythae's VAE models use it
        as the posterior mean); otherwise its ``z`` field is used.

    Raises:
        AttributeError: if the encoder output has neither field.

    The encoder runs in eval mode under ``torch.no_grad()``, and the model's
    previous train/eval mode is restored. ``data`` is moved to the model's
    device, and the result is returned on the CPU.
    """
    param = next(model.parameters(), None)
    device = param.device if param is not None else data.device
    was_training = model.training

    model.eval()
    try:
        with torch.no_grad():
            encoder_output = model.encoder(data.to(device))
    finally:
        model.train(was_training)

    if hasattr(encoder_output, 'embedding'):
        z = encoder_output.embedding
    elif hasattr(encoder_output, 'z'):
        z = encoder_output.z
    else:
        raise AttributeError(
            "Encoder output exposes neither 'embedding' nor 'z'; "
            f"got {type(encoder_output).__name__}"
        )
    return z.detach().cpu().numpy()

# Get latent representations from Beta-VAE (its beta > 1 objective is meant to
# encourage disentangled latent factors).
latent_representations = get_latent_representations(trained_models['Beta-VAE'], X_tensor)

# Choose the number of health phenotypes by silhouette score.
# silhouette_score is only defined for 2 <= k <= n_samples - 1, so the search
# range is capped by the population size (2..10 for any realistic population).
max_k = 10
n_members = latent_representations.shape[0]
K_range = range(2, min(max_k, n_members - 1) + 1)
if len(K_range) == 0:
    raise ValueError(f"Need at least 3 members to cluster phenotypes, got {n_members}")

silhouette_scores = []
for k in K_range:
    kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
    labels = kmeans.fit_predict(latent_representations)
    silhouette_scores.append(silhouette_score(latent_representations, labels))

optimal_k = K_range[int(np.argmax(silhouette_scores))]
print(f"Optimal number of health phenotypes: {optimal_k}")

# Refit with the chosen k; `kmeans` is reused later to plot the cluster centres.
kmeans = KMeans(n_clusters=optimal_k, random_state=42, n_init=10)
phenotype_clusters = kmeans.fit_predict(latent_representations)

# Attach phenotype labels to the per-member risk table
risk_df['phenotype'] = phenotype_clusters

In [None]:
# Project the latent space to 2-D with PCA for display only; clustering above
# used the full latent space. Cluster centres use the same projection.
pca = PCA(n_components=2)
latent_2d = pca.fit_transform(latent_representations)
centers_2d = pca.transform(kmeans.cluster_centers_)

fig, ax = plt.subplots(figsize=(12, 8))
scatter = ax.scatter(
    latent_2d[:, 0], latent_2d[:, 1],
    c=phenotype_clusters, cmap='tab10',
    alpha=0.6, edgecolors='black', linewidth=0.5,
)
fig.colorbar(scatter, ax=ax, label='Health Phenotype')
ax.set_xlabel('First Principal Component')
ax.set_ylabel('Second Principal Component')
ax.set_title('Member Health Phenotypes in Latent Space')

# Overlay the cluster centres as stars
ax.scatter(centers_2d[:, 0], centers_2d[:, 1],
           marker='*', s=300, c='red', edgecolors='black', linewidth=2)

fig.tight_layout()
plt.show()

# Per-phenotype size, share of label == 1, and mean VAE risk percentile
print("\nPhenotype Distribution:")
phenotype_stats = (
    risk_df.groupby('phenotype')
    .agg(**{
        'Member Count': ('mcid', 'count'),
        'Positive Rate': ('label', lambda x: (x == 1).mean()),
        'Avg Risk Score': ('VAE_risk_percentile', 'mean'),
    })
    .round(2)
)
print(phenotype_stats)

## 7. Member Journey Analysis Through Latent Interpolation

In [None]:
# Analyze transitions between health states
def interpolate_latent_path(z_start, z_end, steps=10):
    """Linearly interpolate between two latent points.

    Args:
        z_start: Start point (array-like).
        z_end: End point, the same shape as ``z_start``.
        steps: Number of evenly spaced points, endpoints included.

    Returns:
        Array stacking ``(1 - t) * z_start + t * z_end`` for ``steps`` values
        of ``t`` evenly spaced over [0, 1]; shape ``(steps, *z_start.shape)``.
    """
    return np.array([
        (1 - t) * z_start + t * z_end
        for t in np.linspace(0, 1, steps)
    ])

# Representative member per phenotype: the one with the lowest VAE risk
# percentile. `idxmin` returns a risk_df index label, which equals the row
# position in `latent_representations` because risk_df has a default RangeIndex.
phenotype_examples = {
    phenotype: risk_df.loc[risk_df['phenotype'] == phenotype, 'VAE_risk_percentile'].idxmin()
    for phenotype in range(optimal_k)
}

# Linear latent paths between every pair of representatives (p1 < p2)
transition_paths = {
    f'P{p1}_to_P{p2}': interpolate_latent_path(
        latent_representations[phenotype_examples[p1]],
        latent_representations[phenotype_examples[p2]],
    )
    for p1 in range(optimal_k)
    for p2 in range(p1 + 1, optimal_k)
}

# Draw the first path over the member cloud in the same 2-D PCA projection
sample_transition = next(iter(transition_paths))
path = transition_paths[sample_transition]
path_2d = pca.transform(path)

fig, ax = plt.subplots(figsize=(10, 8))
ax.scatter(latent_2d[:, 0], latent_2d[:, 1],
           c=phenotype_clusters, cmap='tab10', alpha=0.3)
ax.plot(path_2d[:, 0], path_2d[:, 1], 'r-', linewidth=2, label=sample_transition)
ax.scatter(*path_2d[0], marker='o', s=200, c='green', label='Start')
ax.scatter(*path_2d[-1], marker='s', s=200, c='red', label='End')

ax.set_xlabel('First Principal Component')
ax.set_ylabel('Second Principal Component')
ax.set_title('Health State Transition Path Analysis')
ax.legend()
fig.tight_layout()
plt.show()

## 8. Population Health Insights Dashboard

In [None]:
# Create comprehensive population health dashboard (2 x 3 panels)
fig, axes = plt.subplots(2, 3, figsize=(20, 12))
ax1, ax2, ax3, ax4, ax5, ax6 = axes.ravel()

# 1. Risk Distribution by Phenotype
risk_by_phenotype = risk_df.groupby('phenotype')['VAE_risk_percentile'].apply(list)
ax1.boxplot(risk_by_phenotype.values)
# Set tick labels explicitly: boxplot's `labels=` argument was renamed
# `tick_labels` in Matplotlib 3.9.
ax1.set_xticklabels([str(p) for p in risk_by_phenotype.index])
ax1.set_xlabel('Health Phenotype')
ax1.set_ylabel('Risk Percentile')
ax1.set_title('Risk Distribution by Phenotype')

# 2. Phenotype Size and Composition
phenotype_sizes = risk_df.groupby(['phenotype', 'label']).size().unstack(fill_value=0)
phenotype_sizes.plot(kind='bar', stacked=True, ax=ax2)
ax2.set_xlabel('Health Phenotype')
ax2.set_ylabel('Member Count')
ax2.set_title('Phenotype Composition')
# Name legend entries from the actual label columns rather than assuming the
# columns are exactly [0, 1] in that order.
label_names = {0: 'Negative', 1: 'Positive'}
ax2.legend([label_names.get(lbl, f'Label {lbl}') for lbl in phenotype_sizes.columns])

# 3. High-Risk Member Distribution, using the same cut-off as the high-risk
#    member list (`high_risk_threshold`) for a consistent definition.
high_risk_dist = (
    risk_df[risk_df['VAE_risk_percentile'] >= high_risk_threshold]
    .groupby('phenotype')
    .size()
)
ax3.pie(high_risk_dist.values, labels=high_risk_dist.index, autopct='%1.1f%%')
ax3.set_title(f'High-Risk Members by Phenotype (percentile >= {high_risk_threshold})')

# 4. Latent Space Visualization coloured by VAE risk percentile
scatter = ax4.scatter(latent_2d[:, 0], latent_2d[:, 1],
                      c=risk_df['VAE_risk_percentile'], cmap='RdYlBu_r',
                      alpha=0.6, edgecolors='black', linewidth=0.5)
fig.colorbar(scatter, ax=ax4, label='Risk Percentile')
ax4.set_xlabel('First Principal Component')
ax4.set_ylabel('Second Principal Component')
ax4.set_title('Risk Landscape in Latent Space')

# 5. Model Comparison (summary statistics of each model's reconstruction error)
model_performance = pd.DataFrame({
    name: risk_df[f'{name}_error'].describe()
    for name in reconstruction_errors.keys()
})
model_performance.loc[['mean', 'std', 'min', 'max']].plot(kind='bar', ax=ax5)
ax5.set_ylabel('Reconstruction Error')
ax5.set_title('Model Performance Comparison')
ax5.legend(loc='upper right')

# 6. Risk Score Correlations
risk_cols = [col for col in risk_df.columns if 'risk_percentile' in col]
correlation_matrix = risk_df[risk_cols].corr()
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0, ax=ax6)
ax6.set_title('Risk Score Correlations Across Models')

fig.tight_layout()
# savefig does not create missing directories
dashboard_path = Path('outputs/population_health_dashboard.png')
dashboard_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(dashboard_path, dpi=300, bbox_inches='tight')
plt.show()

## 9. Actionable Insights for Care Management

In [None]:
# Generate care management recommendations
def generate_care_recommendations(risk_df, top_n=10, anomaly_quantile=0.95,
                                  phenotype_risk_threshold=70):
    """Summarize care-management targets from a member risk table.

    Args:
        risk_df: DataFrame with at least the columns ``mcid``, ``phenotype``,
            ``VAE_risk_percentile`` and ``VAE_error``.
        top_n: Maximum number of members listed in each member-level table.
        anomaly_quantile: Members whose ``VAE_error`` is strictly above this
            quantile of ``VAE_error`` are treated as anomalous.
        phenotype_risk_threshold: Phenotypes whose mean
            ``VAE_risk_percentile`` is strictly above this value are flagged.

    Returns:
        dict with keys:
          - ``immediate_intervention``: the ``top_n`` members with the highest
            ``VAE_risk_percentile`` (columns ``mcid``, ``phenotype``,
            ``VAE_risk_percentile``).
          - ``anomalous_patterns``: up to ``top_n`` anomalous members, most
            anomalous first (columns ``mcid``, ``phenotype``, ``VAE_error``).
          - ``high_risk_phenotypes``: list of flagged phenotype ids, highest
            mean risk first.
          - ``phenotype_risk_scores``: mean ``VAE_risk_percentile`` per
            phenotype, sorted descending.
    """
    # 1. Highest-risk members needing immediate intervention
    immediate_intervention = risk_df.nlargest(top_n, 'VAE_risk_percentile')[
        ['mcid', 'phenotype', 'VAE_risk_percentile']
    ]

    # 2. Members with anomalous patterns (upper tail of reconstruction error).
    # Sort before truncating so the table lists the *most* anomalous members
    # rather than whichever come first in row order.
    anomaly_threshold = risk_df['VAE_error'].quantile(anomaly_quantile)
    anomalous_members = risk_df[risk_df['VAE_error'] > anomaly_threshold]
    anomalous_patterns = anomalous_members.nlargest(top_n, 'VAE_error')[
        ['mcid', 'phenotype', 'VAE_error']
    ]

    # 3. Phenotype-specific interventions
    phenotype_risks = (
        risk_df.groupby('phenotype')['VAE_risk_percentile']
        .mean()
        .sort_values(ascending=False)
    )
    high_risk_phenotypes = phenotype_risks[phenotype_risks > phenotype_risk_threshold].index.tolist()

    return {
        'immediate_intervention': immediate_intervention,
        'anomalous_patterns': anomalous_patterns,
        'high_risk_phenotypes': high_risk_phenotypes,
        'phenotype_risk_scores': phenotype_risks
    }

recommendations = generate_care_recommendations(risk_df)

print("=== CARE MANAGEMENT RECOMMENDATIONS ===")

# Member-level tables: a heading followed by the table itself
for heading, key in (
    ("\n1. Members Requiring Immediate Intervention:", 'immediate_intervention'),
    ("\n2. Members with Anomalous Health Patterns:", 'anomalous_patterns'),
):
    print(heading)
    print(recommendations[key])

# Phenotype-level targeting
print("\n3. High-Risk Phenotypes for Targeted Programs:")
print(f"Phenotypes: {recommendations['high_risk_phenotypes']}")
print("\nPhenotype Risk Scores:")
print(recommendations['phenotype_risk_scores'])

## 10. Export Results for Clinical Integration

In [None]:
# Prepare comprehensive member risk profile
member_risk_profile = risk_df[[
    'mcid', 'label', 'phenotype',
    'VAE_risk_percentile', 'Beta-VAE_risk_percentile', 'RHVAE_risk_percentile'
]].copy()

# Bucket the VAE risk percentile into categories (percentile ranks are in (0, 100])
member_risk_profile['risk_category'] = pd.cut(
    member_risk_profile['VAE_risk_percentile'],
    bins=[0, 50, 80, 95, 100],
    labels=['Low', 'Moderate', 'High', 'Critical'],
    include_lowest=True,  # a value of exactly 0 falls in 'Low' instead of NaN
)

# Add phenotype descriptions
phenotype_descriptions = {
    i: f"Phenotype_{i}_{'HighRisk' if i in recommendations['high_risk_phenotypes'] else 'Standard'}"
    for i in range(optimal_k)
}
member_risk_profile['phenotype_description'] = member_risk_profile['phenotype'].map(phenotype_descriptions)

# Save results
output_dir = Path('outputs/population_health_vae')
output_dir.mkdir(parents=True, exist_ok=True)

# Save member risk profiles
member_risk_profile.to_csv(output_dir / 'member_risk_profiles.csv', index=False)

# Save model artifacts (state dicts only)
for model_name, model in trained_models.items():
    torch.save(model.state_dict(), output_dir / f'{model_name.lower()}_weights.pt')

# Save summary statistics. `risk_category` is a column of member_risk_profile
# (risk_df has no such column). Counts are cast to int so json can serialize them.
anomaly_threshold = risk_df['VAE_error'].quantile(0.95)
summary_stats = {
    'total_members': len(risk_df),
    'phenotypes_identified': int(optimal_k),
    'high_risk_members': int((member_risk_profile['risk_category'] == 'Critical').sum()),
    'anomalous_members': int((risk_df['VAE_error'] > anomaly_threshold).sum()),
    'model_performance': {
        name: float(risk_df[f'{name}_error'].mean())
        for name in reconstruction_errors.keys()
    }
}

with open(output_dir / 'population_health_summary.json', 'w') as f:
    json.dump(summary_stats, f, indent=2)

print("\n=== EXPORT COMPLETE ===")
print(f"Results saved to: {output_dir}")
print("\nSummary:")
for key, value in summary_stats.items():
    if isinstance(value, dict):
        print(f"{key}:")
        for k, v in value.items():
            print(f"  {k}: {v:.4f}")
    else:
        print(f"{key}: {value}")

## Conclusion

This notebook demonstrated advanced population health analytics using Pythae VAE models:

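1. **Risk Stratification**: Ranked members by VAE reconstruction error. This error flags members whose embeddings the model finds atypical; it is an anomaly score used as a risk proxy and should be validated against clinical outcomes before it drives care decisions.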
2. **Phenotype Discovery**: Found distinct health phenotypes through latent space clustering
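3. **Journey Analysis**: Visualized hypothetical health state transitions by linearly interpolating in latent space between representative members of each phenotype (these paths are not observed member trajectories)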
4. **Actionable Insights**: Generated specific recommendations for care management

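The VAE approach provides a framework for exploring complex health patterns at the population level. Because the risk scores and phenotypes above are learned without supervision, they should be validated before they are used to guide proactive, personalized care management.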