## 1. Load Data

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
from livae import agent

# Configure Scanpy's logging verbosity and default figure appearance.
sc.settings.verbosity = 3
sc.settings.set_figure_params(dpi=80, facecolor='white')

# Fetch the PBMC3k dataset (3,000 PBMCs from a healthy donor) through
# Scanpy's bundled dataset loader.
adata = sc.datasets.pbmc3k()

# Report the dimensions and any layers already attached, then show the
# AnnData summary representation.
n_cells_loaded = adata.n_obs
n_genes_loaded = adata.n_vars
layer_names = list(adata.layers.keys())

print("\nLoaded PBMC3k dataset:")
print(f"  Cells: {n_cells_loaded}")
print(f"  Genes: {n_genes_loaded}")
print(f"  Layers: {layer_names}")
print("\nFirst look at the data:")
print(adata)

## 2. Quality Control and Preprocessing

In [None]:
# Keep a copy of the unfiltered count matrix in its own layer.
adata.layers['counts'] = adata.X.copy()

# Compute QC metrics in place; the obs columns read below are produced here.
sc.pp.calculate_qc_metrics(adata, percent_top=None, log1p=False, inplace=True)

genes_per_cell = adata.obs['n_genes_by_counts']
counts_per_cell = adata.obs['total_counts']

# One row of three panels: two histograms and a counts-vs-genes scatter.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# (values, x-axis label, title) for each histogram panel.
histogram_panels = [
    (genes_per_cell, 'Number of genes', 'Genes per cell'),
    (counts_per_cell, 'Total counts', 'UMI counts per cell'),
]
for panel, (values, x_label, panel_title) in zip(axes[:2], histogram_panels):
    panel.hist(values, bins=60, edgecolor='black')
    panel.set_xlabel(x_label)
    panel.set_ylabel('Number of cells')
    panel.set_title(panel_title)

# Relationship between sequencing depth and detected genes per cell.
depth_panel = axes[2]
depth_panel.scatter(counts_per_cell, genes_per_cell, alpha=0.3, s=5)
depth_panel.set_xlabel('Total counts')
depth_panel.set_ylabel('Number of genes')
depth_panel.set_title('Counts vs Genes')

plt.tight_layout()
plt.show()

print("\nQC Statistics:")
print(f"  Mean genes per cell: {genes_per_cell.mean():.0f}")
print(f"  Mean counts per cell: {counts_per_cell.mean():.0f}")

In [None]:
# Report dataset size before and after filtering with one shared template.
size_report = "{} filtering: {} cells, {} genes"
print(size_report.format("Before", adata.n_obs, adata.n_vars))

# Drop cells with fewer than 200 or more than 2500 detected genes; the two
# bounds are applied as separate in-place passes, lower bound first.
for gene_bound in ({'min_genes': 200}, {'max_genes': 2500}):
    sc.pp.filter_cells(adata, **gene_bound)

# Drop genes detected in fewer than 3 cells.
sc.pp.filter_genes(adata, min_cells=3)

print(size_report.format("After", adata.n_obs, adata.n_vars))

# Refresh the counts layer from the filtered matrix.
adata.layers['counts'] = adata.X.copy()

## 3. Standard Scanpy Processing (for comparison)

In [None]:
# Run the conventional pipeline on a copy so `adata` itself is not modified
# by the normalization below.
adata_processed = adata.copy()

# Scale each cell to 10,000 total counts, then apply log1p.
sc.pp.normalize_total(adata_processed, target_sum=1e4)
sc.pp.log1p(adata_processed)

# Flag the top 2,000 highly variable genes.
sc.pp.highly_variable_genes(adata_processed, n_top_genes=2000)
n_highly_variable = adata_processed.var['highly_variable'].sum()
print(f"\nIdentified {n_highly_variable} highly variable genes")

# Linear reduction (PCA), then a neighbor graph on the first 40 PCs, a UMAP
# layout of that graph, and Leiden clustering.
sc.tl.pca(adata_processed, svd_solver='arpack')
sc.pp.neighbors(adata_processed, n_neighbors=10, n_pcs=40)
sc.tl.umap(adata_processed)
sc.tl.leiden(adata_processed)

pca_shape = adata_processed.obsm['X_pca'].shape
umap_shape = adata_processed.obsm['X_umap'].shape
n_standard_clusters = adata_processed.obs['leiden'].nunique()

print("\nStandard processing complete:")
print(f"  PCA: {pca_shape}")
print(f"  UMAP: {umap_shape}")
print(f"  Clusters: {n_standard_clusters}")

## 4. Train LiVAE Model

In [None]:
# Fit LiVAE on the count data stored in the 'counts' layer of `adata`.
print("Training LiVAE model...\n")

# All hyperparameters in one mapping so they are easy to inspect and tweak.
# The per-option notes restate the notebook's stated intent for each setting.
livae_config = dict(
    adata=adata,        # original AnnData carrying the raw counts
    layer='counts',     # train on the count layer
    latent_dim=20,      # higher-dimensional latent space for complex data
    i_dim=2,            # 2D interpretable embedding for visualization
    hidden_dim=128,     # larger hidden layer
    percent=0.1,        # 10% of cells per batch
    lr=1e-3,            # learning rate
    # Regularization weights for biological data
    beta=2.0,           # moderate disentanglement
    lorentz=1.0,        # hyperbolic-geometry term
    irecon=0.5,         # interpretable-feature reconstruction term
)
model = agent(**livae_config)

# Train for 100 epochs.
model.fit(epochs=100)

print("\n✅ LiVAE training complete!")

## 5. Extract LiVAE Embeddings

In [None]:
# Pull both representations from the trained model: the full latent space
# and the low-dimensional interpretable embedding.
latent_livae = model.get_latent()
iembed_livae = model.get_iembed()

# Attach them to the AnnData object under dedicated obsm keys.
for obsm_key, embedding in (('X_livae', latent_livae), ('X_livae_2d', iembed_livae)):
    adata.obsm[obsm_key] = embedding

print("LiVAE embeddings:")
print(f"  Latent (20D): {latent_livae.shape}")
print(f"  Interpretable (2D): {iembed_livae.shape}")

# Mirror the standard pipeline on the LiVAE latent space (neighbors -> UMAP
# -> Leiden) so the two representations can be compared on equal footing.
adata_livae = adata.copy()
sc.pp.neighbors(adata_livae, use_rep='X_livae', n_neighbors=10)
sc.tl.umap(adata_livae)
sc.tl.leiden(adata_livae)

n_livae_clusters = adata_livae.obs['leiden'].nunique()
print(f"\nLiVAE clustering: {n_livae_clusters} clusters")

## 6. Compare Visualizations

In [None]:
# 2x3 comparison grid: standard pipeline views, LiVAE views, and marker-gene
# overlays on the LiVAE interpretable embedding.
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Coordinates of the LiVAE interpretable embedding, reused by several panels.
embed_x = iembed_livae[:, 0]
embed_y = iembed_livae[:, 1]

# Top-left / top-middle: PCA and PCA-derived UMAP, colored by Leiden cluster.
sc.pl.pca(adata_processed, color='leiden', ax=axes[0, 0], show=False, title='PCA')
sc.pl.umap(adata_processed, color='leiden', ax=axes[0, 1], show=False, title='UMAP (from PCA)')

# Top-right: LiVAE interpretable embedding colored by LiVAE-derived clusters.
cluster_ax = axes[0, 2]
scatter = cluster_ax.scatter(
    embed_x, embed_y,
    c=adata_livae.obs['leiden'].astype(int),
    cmap='tab10', s=5, alpha=0.7
)
cluster_ax.set_title('LiVAE Interpretable (2D)')
cluster_ax.set_xlabel('Dimension 1')
cluster_ax.set_ylabel('Dimension 2')
plt.colorbar(scatter, ax=cluster_ax, label='Cluster')

# Bottom-left: UMAP computed from the LiVAE latent space.
sc.pl.umap(adata_livae, color='leiden', ax=axes[1, 0], show=False, title='UMAP (from LiVAE)')

# Bottom-middle / bottom-right: marker-gene expression over the LiVAE
# embedding. A panel is left empty when its gene is absent from the data.
marker_panels = [
    ('CD3D', 'T cells', axes[1, 1]),
    ('CD79A', 'B cells', axes[1, 2]),
]
for gene, cell_type, gene_ax in marker_panels:
    if gene not in adata.var_names:
        continue
    # The expression column may be sparse; densify before flattening.
    gene_column = adata[:, gene].X
    if hasattr(gene_column, 'toarray'):
        gene_expr = gene_column.toarray().flatten()
    else:
        gene_expr = gene_column.flatten()
    scatter = gene_ax.scatter(
        embed_x, embed_y,
        c=gene_expr, cmap='viridis', s=5, alpha=0.7
    )
    gene_ax.set_title(f'LiVAE: {gene} ({cell_type})')
    gene_ax.set_xlabel('Dimension 1')
    gene_ax.set_ylabel('Dimension 2')
    plt.colorbar(scatter, ax=gene_ax, label='Expression')

plt.tight_layout()
plt.show()

## 7. Quantitative Comparison

In [None]:
from sklearn.metrics import silhouette_score, calinski_harabasz_score

# Each representation is scored against the Leiden labels derived from the
# same pipeline (standard labels for PCA/UMAP, LiVAE labels for LiVAE).
standard_labels = adata_processed.obs['leiden']
livae_labels = adata_livae.obs['leiden']

methods = {
    'PCA': (adata_processed.obsm['X_pca'][:, :20], standard_labels),
    'UMAP': (adata_processed.obsm['X_umap'], standard_labels),
    'LiVAE_latent': (adata_livae.obsm['X_livae'], livae_labels),
    'LiVAE_2D': (adata_livae.obsm['X_livae_2d'], livae_labels)
}

# One row per representation; labels are cast to int before scoring.
results = [
    {
        'Method': name,
        'Silhouette': silhouette_score(embedding, labels.astype(int)),
        'Calinski-Harabasz': calinski_harabasz_score(embedding, labels.astype(int)),
    }
    for name, (embedding, labels) in methods.items()
]

df_results = pd.DataFrame(results)

separator = "=" * 60
print("\nClustering Quality Metrics:")
print(separator)
print(df_results.to_string(index=False))
print(separator)
print("\nHigher is better for both metrics")
print("Silhouette: [-1, 1], measures cluster cohesion")
print("Calinski-Harabasz: [0, ∞), measures cluster separation")

## 8. Cell Type Annotation (Optional)

In [None]:
# Marker-gene ranking expects logarithmized expression, but adata_livae.X is
# a copy of `adata`, which is never normalized in this notebook. Build a
# log-normalized matrix on a scratch copy and attach it as a layer, so the
# raw counts in .X stay intact for anything else that relies on them.
adata_lognorm = adata_livae.copy()
sc.pp.normalize_total(adata_lognorm, target_sum=1e4)
sc.pp.log1p(adata_lognorm)
adata_livae.layers['lognorm'] = adata_lognorm.X

# Find marker genes for LiVAE clusters using the log-normalized layer.
# use_raw=False is explicit because `layer` cannot be combined with .raw.
sc.tl.rank_genes_groups(
    adata_livae, 'leiden', method='wilcoxon', layer='lognorm', use_raw=False
)

# Plot top marker genes per cluster.
sc.pl.rank_genes_groups(adata_livae, n_genes=5, sharey=False)

# Visualize key marker genes, keeping only those present after gene filtering.
marker_genes = ['CD3D', 'CD8A', 'CD4', 'CD79A', 'MS4A1', 'CD14', 'LYZ', 'NKG7', 'GNLY']
available_markers = [g for g in marker_genes if g in adata_livae.var_names]

if available_markers:
    # Color by log-normalized expression so a few high-count cells do not
    # dominate the color scale. At most six genes are shown (2 rows x 3 cols).
    sc.pl.umap(
        adata_livae,
        color=available_markers[:6],
        ncols=3,
        cmap='viridis',
        layer='lognorm',
    )

print("\nCell type markers help interpret clusters:")
print("  CD3D, CD8A: T cells")
print("  CD79A, MS4A1: B cells")
print("  CD14, LYZ: Monocytes")
print("  NKG7, GNLY: NK cells")

## 9. Export Results

In [None]:
# Build an export table: one row per cell, one column per LiVAE latent
# dimension, plus the Leiden cluster assignment.
livae_latent = adata_livae.obsm['X_livae']

# Derive the column count from the embedding itself rather than hard-coding
# it, so the export keeps working if the model's latent_dim is changed.
embeddings_df = pd.DataFrame(
    livae_latent,
    columns=[f'LiVAE_{i}' for i in range(livae_latent.shape[1])],
    index=adata_livae.obs_names
)
embeddings_df['cluster'] = adata_livae.obs['leiden'].values

# Save to CSV (left disabled so running the notebook writes no files;
# uncomment to export)
# embeddings_df.to_csv('pbmc3k_livae_embeddings.csv')

print("Embeddings ready for export:")
print(embeddings_df.head())
print(f"\nShape: {embeddings_df.shape}")

## 10. Summary

In [None]:
# Assemble the final report as a list of lines and emit it in one call.
# NOTE: the model settings listed here are literal text mirroring the values
# used in the training cell; they are not read back from the model.
banner = "=" * 70
n_final_clusters = len(adata_livae.obs['leiden'].unique())

summary_lines = [
    banner,
    "LiVAE Real Data Analysis Summary",
    banner,
    "\nDataset: PBMC3k",
    f"  Final size: {adata.n_obs} cells × {adata.n_vars} genes",
    "\nLiVAE Model:",
    "  Latent dimension: 20",
    "  Interpretable dimension: 2",
    "  Training epochs: 100",
    "  Regularization: beta=2.0, lorentz=1.0, irecon=0.5",
    f"\nClusters identified: {n_final_clusters}",
    "\n✅ Successfully analyzed real single-cell data with LiVAE",
    "\nKey advantages of LiVAE:",
    "  • Works directly on raw counts (no normalization needed)",
    "  • Learns interpretable 2D embeddings",
    "  • Captures hierarchical relationships via hyperbolic geometry",
    "  • Provides both high-D and low-D representations",
    banner,
]
print("\n".join(summary_lines))