## 1. Setup and Data Preparation

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import anndata as ad
from livae import agent
from sklearn.metrics import adjusted_rand_score

# Seed NumPy's global random state so the np.random draws used to build the
# synthetic dataset below are repeatable from run to run.
np.random.seed(42)

# Create synthetic data with clear hierarchical structure
def create_hierarchical_data():
    """
    Build a synthetic count matrix with a two-level (lineage -> subtype) layout.

    Every cell starts from Poisson(4.0) background counts over 150 genes. A
    gene block shared by the whole lineage is then overwritten with
    Poisson(20.0) draws and a subtype-specific block with Poisson(15.0) draws:

    - Lineage A (shared genes 0-24):  A1 (100 cells), A2 (80 cells)
    - Lineage B (shared genes 60-79): B1 (90 cells), B2 (70 cells), B3 (60 cells)

    All draws use NumPy's global random state, so seed it before calling
    this function if reproducible output is needed.

    Returns:
        An AnnData object built from the stacked float matrix, with the same
        matrix copied into ``layers['counts']``, ``lineage`` and ``subtype``
        columns in ``obs``, and genes named ``Gene_0`` ... ``Gene_149``.
    """
    n_genes = 150

    # (subtype, lineage, n_cells, lineage-shared genes, subtype-specific genes)
    layout = [
        ('A1', 'A', 100, slice(0, 25), slice(25, 40)),
        ('A2', 'A', 80, slice(0, 25), slice(40, 55)),
        ('B1', 'B', 90, slice(60, 80), slice(80, 90)),
        ('B2', 'B', 70, slice(60, 80), slice(90, 100)),
        ('B3', 'B', 60, slice(60, 80), slice(100, 110)),
    ]

    blocks = []
    lineage = []
    subtype = []
    for sub_name, lin_name, n_cells, shared, specific in layout:
        # Draw order per subtype is background -> shared -> specific; keeping
        # it fixed keeps the generated values stable for a given seed.
        counts = np.random.poisson(4.0, (n_cells, n_genes))
        counts[:, shared] = np.random.poisson(
            20.0, (n_cells, shared.stop - shared.start)
        )
        counts[:, specific] = np.random.poisson(
            15.0, (n_cells, specific.stop - specific.start)
        )
        blocks.append(counts)

        # One label per cell, in the same row order as the stacked matrix
        lineage.extend([lin_name] * n_cells)
        subtype.extend([sub_name] * n_cells)

    # Stack subtypes row-wise and switch from integer counts to float
    X = np.vstack(blocks).astype(float)

    # Wrap in AnnData and attach labels, raw counts layer and gene names
    adata = ad.AnnData(X)
    adata.obs['lineage'] = lineage
    adata.obs['subtype'] = subtype
    adata.layers['counts'] = X.copy()
    adata.var_names = [f'Gene_{gene_idx}' for gene_idx in range(n_genes)]

    return adata

adata = create_hierarchical_data()

# Report the matrix shape and how many cells carry each label
print("Created hierarchical dataset:")
print(f"  Shape: {adata.shape}")
for heading, obs_column in (("Lineages", "lineage"), ("Subtypes", "subtype")):
    print(f"  {heading}: {adata.obs[obs_column].value_counts().to_dict()}")

## 2. Baseline Model (No Regularization)

In [None]:
# Train baseline model with minimal regularization.
# Every optional penalty weight is set to 0.0 so this run serves as the
# reference point for the regularized models trained later.
print("Training baseline model (beta=1.0, no additional regularization)...\n")

model_baseline = agent(
    adata=adata,
    layer='counts',
    latent_dim=10,
    i_dim=2,
    hidden_dim=64,
    percent=0.15,
    lr=1e-3,
    beta=1.0,         # Standard VAE
    lorentz=0.0,      # No Lorentzian regularization
    irecon=0.0,       # No interpretable reconstruction
    dip=0.0,          # No DIP
    tc=0.0,           # No TC
    info=0.0          # No InfoVAE
)

model_baseline.fit(epochs=50)

# Extract embeddings (full latent space and the interpretable embedding)
latent_baseline = model_baseline.get_latent()
iembed_baseline = model_baseline.get_iembed()

# Status message: the check-mark emoji was previously mis-encoded (mojibake)
print("\n✅ Baseline model trained")

## 3. β-VAE: Enhanced Disentanglement

In [None]:
# Train β-VAE with higher beta for disentanglement.
# Only `beta` differs from the baseline configuration (4.0 instead of 1.0);
# all other penalty weights stay at 0.0.
print("Training β-VAE model (beta=4.0)...\n")

model_beta = agent(
    adata=adata,
    layer='counts',
    latent_dim=10,
    i_dim=2,
    hidden_dim=64,
    percent=0.15,
    lr=1e-3,
    beta=4.0,         # Higher beta encourages disentanglement
    lorentz=0.0,
    irecon=0.0,
    dip=0.0,
    tc=0.0,
    info=0.0
)

model_beta.fit(epochs=50)

# Extract embeddings (full latent space and the interpretable embedding)
latent_beta = model_beta.get_latent()
iembed_beta = model_beta.get_iembed()

# Messages below previously contained mis-encoded (mojibake) β, → and ✅
print("\n✅ β-VAE model trained")
print("\nβ-VAE Effect: Higher β → More disentangled latent factors")
print("  β = 1.0: Standard VAE")
print("  β > 1.0: Encourages independence between latent dimensions")

## 4. Lorentzian Regularization: Hyperbolic Geometry

In [None]:
# Train model with Lorentzian regularization.
# Only `lorentz` differs from the baseline configuration (2.0 instead of 0.0).
print("Training Lorentzian model (lorentz=2.0)...\n")

model_lorentz = agent(
    adata=adata,
    layer='counts',
    latent_dim=10,
    i_dim=2,
    hidden_dim=64,
    percent=0.15,
    lr=1e-3,
    beta=1.0,
    lorentz=2.0,      # Lorentzian/hyperbolic geometry regularization
    irecon=0.0,
    dip=0.0,
    tc=0.0,
    info=0.0
)

model_lorentz.fit(epochs=50)

# Extract embeddings (full latent space and the interpretable embedding)
latent_lorentz = model_lorentz.get_latent()
iembed_lorentz = model_lorentz.get_iembed()

# Status message: the check-mark emoji was previously mis-encoded (mojibake)
print("\n✅ Lorentzian model trained")
print("\nLorentzian Effect: Embeds data in hyperbolic space")
print("  Benefit: Better captures hierarchical/tree-like relationships")
print("  Use case: Cell differentiation trajectories, developmental hierarchies")

## 5. Combined Regularization: Full LiVAE

In [None]:
# Train full LiVAE with multiple regularizations.
# beta, lorentz, irecon, dip and tc are all non-zero here; only `info` is off.
# The banner lists every enabled term (it previously omitted dip and tc).
print("Training full LiVAE model (beta + lorentz + irecon + dip + tc)...\n")

model_full = agent(
    adata=adata,
    layer='counts',
    latent_dim=10,
    i_dim=2,
    hidden_dim=64,
    percent=0.15,
    lr=1e-3,
    beta=2.0,         # Disentanglement
    lorentz=1.5,      # Hyperbolic geometry
    irecon=1.0,       # Interpretable reconstruction
    dip=0.5,          # Disentangled prior
    tc=0.3,           # Total correlation
    info=0.0          # Can add InfoVAE if needed
)

model_full.fit(epochs=50)

# Extract embeddings (full latent space and the interpretable embedding)
latent_full = model_full.get_latent()
iembed_full = model_full.get_iembed()

# Messages below previously contained mis-encoded (mojibake) ✅ and β
print("\n✅ Full LiVAE model trained")
print("\nCombined regularization benefits:")
print("  β-VAE: Disentangled factors")
print("  Lorentzian: Hierarchical structure")
print("  irecon: Interpretable compressed features")
print("  DIP: Independent latent dimensions")
print("  TC: Minimize total correlation")

## 6. Compare All Models

In [None]:
# Visualize all models: one column per model, top row colored by lineage,
# bottom row colored by subtype. Only the first two columns of each
# interpretable embedding are plotted.
fig, axes = plt.subplots(2, 4, figsize=(18, 9))

# (panel title, interpretable embedding, full latent) for each trained model
models = [
    ('Baseline\n(beta=1.0)', iembed_baseline, latent_baseline),
    ('β-VAE\n(beta=4.0)', iembed_beta, latent_beta),
    ('Lorentzian\n(lorentz=2.0)', iembed_lorentz, latent_lorentz),
    ('Full LiVAE\n(combined)', iembed_full, latent_full)
]

colors_lineage = {'A': 'red', 'B': 'blue'}
colors_subtype = {'A1': '#ff6b6b', 'A2': '#ff9999',
                  'B1': '#4dabf7', 'B2': '#74c0fc', 'B3': '#a5d8ff'}

# The full latent is carried in `models` but not plotted here
for i, (name, iembed, _latent) in enumerate(models):
    # Plot interpretable embedding colored by lineage
    for lineage in ['A', 'B']:
        # Convert the pandas boolean Series to a plain NumPy mask before
        # using it to index the embedding array
        mask = (adata.obs['lineage'] == lineage).to_numpy()
        axes[0, i].scatter(
            iembed[mask, 0], iembed[mask, 1],
            c=colors_lineage[lineage], label=lineage,
            alpha=0.6, s=25, edgecolors='none'
        )
    axes[0, i].set_title(f'{name}\nLineage', fontsize=10)
    axes[0, i].legend(fontsize=8)
    axes[0, i].grid(alpha=0.3)

    # Plot interpretable embedding colored by subtype
    for subtype in ['A1', 'A2', 'B1', 'B2', 'B3']:
        mask = (adata.obs['subtype'] == subtype).to_numpy()
        axes[1, i].scatter(
            iembed[mask, 0], iembed[mask, 1],
            c=colors_subtype[subtype], label=subtype,
            alpha=0.6, s=25, edgecolors='none'
        )
    axes[1, i].set_title('Subtype', fontsize=10)
    axes[1, i].legend(fontsize=7, ncol=2)
    axes[1, i].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("\nVisualization shows how different regularizations affect the latent space structure.")

## 7. Quantitative Comparison

In [None]:
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, silhouette_score

def evaluate_model(latent, true_labels, n_clusters=5):
    """Score how well an embedding recovers known group labels.

    K-means with a fixed seed is fitted on ``latent`` and its cluster
    assignments are compared with ``true_labels``.

    Args:
        latent: Embedding passed to ``KMeans.fit_predict`` and
            ``silhouette_score`` (one row per sample).
        true_labels: Ground-truth label for each row of ``latent``.
        n_clusters: Number of k-means clusters to fit (default 5).

    Returns:
        dict with three float scores:
            'ARI': adjusted Rand index between true and k-means labels.
            'NMI': normalized mutual information between true and k-means
                labels.
            'ASW': silhouette score of ``latent`` grouped by the *true*
                labels (not the k-means assignments).
    """
    # n_init is set explicitly: scikit-learn changed its default from 10 to
    # 'auto', so leaving it out makes scores depend on the installed version
    # (and triggers a FutureWarning on some releases).
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
    pred_labels = kmeans.fit_predict(latent)

    ari = adjusted_rand_score(true_labels, pred_labels)
    nmi = normalized_mutual_info_score(true_labels, pred_labels)
    # Silhouette uses the ground-truth grouping, so it measures how well the
    # embedding separates the known labels, independent of k-means
    asw = silhouette_score(latent, true_labels)

    return {'ARI': ari, 'NMI': nmi, 'ASW': asw}

# Encode true labels as integers for the clustering metrics
subtype_map = {'A1': 0, 'A2': 1, 'B1': 2, 'B2': 3, 'B3': 4}
true_labels = adata.obs['subtype'].map(subtype_map).values

# Evaluate all models on their full latent spaces.
# The 'β-VAE' key was previously mis-encoded (mojibake); it becomes a row
# label in the table and a tick label in the bar chart below.
results = {
    'Baseline': evaluate_model(latent_baseline, true_labels),
    'β-VAE': evaluate_model(latent_beta, true_labels),
    'Lorentzian': evaluate_model(latent_lorentz, true_labels),
    'Full LiVAE': evaluate_model(latent_full, true_labels)
}

# Display results (transpose so each model is a row, each metric a column)
df_results = pd.DataFrame(results).T
print("\nQuantitative Evaluation (Subtype Clustering):")
print("="*60)
print(df_results.round(3))
print("="*60)
print("\nMetrics:")
print("  ARI: Adjusted Rand Index (higher is better, max=1.0)")
print("  NMI: Normalized Mutual Information (higher is better, max=1.0)")
print("  ASW: Average Silhouette Width (higher is better, range=[-1,1])")

# Plot comparison: grouped bars, one group per model
fig, ax = plt.subplots(figsize=(10, 5))
df_results.plot(kind='bar', ax=ax, rot=0)
ax.set_ylabel('Score')
ax.set_title('Model Comparison: Clustering Metrics')
ax.legend(loc='lower right')
ax.grid(alpha=0.3, axis='y')
plt.tight_layout()
plt.show()

## 8. Parameter Guidelines

### When to Use Each Regularization:

| Parameter | Use Case | Typical Range |
|-----------|----------|---------------|
| `beta` | General disentanglement | 1.0 - 5.0 |
| `lorentz` | Hierarchical/trajectory data | 0.5 - 3.0 |
| `irecon` | Need interpretable features | 0.5 - 2.0 |
| `dip` | Force independent dimensions | 0.1 - 1.0 |
| `tc` | Minimize redundancy | 0.1 - 1.0 |
| `info` | Match latent distribution | 0.1 - 1.0 |

### Recommended Starting Points:

**For standard single-cell analysis:**
```python
model = agent(adata, beta=1.0, lorentz=0.0, irecon=0.0)
```

**For developmental/trajectory data:**
```python
model = agent(adata, beta=2.0, lorentz=1.5, irecon=1.0)
```

**For maximum disentanglement:**
```python
model = agent(adata, beta=4.0, dip=1.0, tc=0.5, irecon=1.0)
```

## 9. Summary

In [None]:
# Print the closing summary of the tutorial.
# The ✅, β and 📊 characters below were previously mis-encoded (mojibake).
print("="*70)
print("LiVAE Regularization Tutorial Summary")
print("="*70)
print("\n✅ Explored regularization techniques:")
print("   1. β-VAE: Disentanglement through KL weighting")
print("   2. Lorentzian: Hyperbolic geometry for hierarchies")
print("   3. irecon: Interpretable compressed features")
print("   4. DIP, TC, InfoVAE: Additional disentanglement")
print("\n✅ Compared 4 model configurations")
print("\n✅ Evaluated clustering performance quantitatively")
print("\n📊 Key Takeaway:")
print("   Different regularizations suit different data structures.")
print("   Start simple (beta=1.0) and add complexity as needed.")
print("="*70)