# TRANS_017 - Sample Integration with scVI

## Overview
This notebook integrates all 4 preprocessed CITE-seq samples using totalVI, the CITE-seq model in scvi-tools (the library built around scVI, single-cell Variational Inference). totalVI is a deep learning-based method that:
- Removes batch effects while preserving biological variation
- Handles multi-modal data (RNA + protein) simultaneously
- Generates a shared low-dimensional representation
- Enables joint clustering and visualization

## Why scVI for CITE-seq?
- **totalVI**: Extension of scVI specifically designed for CITE-seq
- **Probabilistic framework**: Models both technical and biological variation
- **Scalable**: Efficient for large datasets
- **Competitive performance**: Reported to compare favorably with traditional batch-correction methods (CCA, MNN, etc.)

## Workflow:
1. Load all preprocessed samples
2. Merge and prepare data
3. Set up totalVI model
4. Train the model
5. Extract integrated embeddings
6. Perform joint clustering
7. Visualize and analyze integrated data
8. Differential expression analysis
9. Cell type annotation

In [None]:
# Import libraries
import scanpy as sc
import scvi
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Set up plotting
sc.settings.verbosity = 3
sc.settings.set_figure_params(dpi=100, facecolor='white', frameon=False)
sns.set_style('whitegrid')

# Set random seed for reproducibility
import random
import torch
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
scvi.settings.seed = 42  # scvi-tools' own global seed setting

print(f"Scanpy version: {sc.__version__}")
print(f"scvi-tools version: {scvi.__version__}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 1. Define Paths and Parameters

In [None]:
# Define paths
DATA_DIR = Path("../../data/processed/")
OUTPUT_DIR = Path("../../data/integrated/")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Sample names
SAMPLES = ['sample1', 'sample2', 'sample3', 'sample4']

# Integration parameters
INTEGRATION_PARAMS = {
    'n_latent': 30,           # Latent space dimensions (the scvi-tools default for totalVI is 20)
    'n_layers': 2,            # Number of hidden layers
    'max_epochs': 400,        # Training epochs (adjust based on convergence)
    'batch_size': 128,        # Batch size for training
    'early_stopping': True,   # Stop if validation loss plateaus
}

print(f"Integrating {len(SAMPLES)} samples")
print(f"Output directory: {OUTPUT_DIR}")

## 2. Load All Preprocessed Samples

Load the h5ad files created during preprocessing. Each contains:
- Normalized RNA data
- Normalized protein data (ADT)
- Individual sample clustering
- QC metrics

In [None]:
# Load all samples
adatas = {}
for sample in SAMPLES:
    file_path = DATA_DIR / sample / f"{sample}_processed.h5ad"
    if file_path.exists():
        adatas[sample] = sc.read_h5ad(file_path)
        print(f"Loaded {sample}: {adatas[sample].shape}")
    else:
        print(f"Warning: {file_path} not found! Run preprocessing first.")

if len(adatas) == 0:
    raise ValueError("No processed samples found! Please run preprocessing notebooks first.")

print(f"\nSuccessfully loaded {len(adatas)} samples")

## 3. Quality Check Before Integration

Verify that all samples have compatible features and check for batch effects.
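
One way to make the feature-compatibility check concrete is to compare the feature names of every sample before merging. The sketch below is plain Python and runs on toy data; the helper name `compare_feature_sets` and the protein panels are hypothetical, not part of the preprocessing output.

```python
def compare_feature_sets(features_by_sample):
    """Return (shared, missing) for a dict mapping sample name -> feature names.

    shared  : sorted list of features present in every sample
    missing : dict mapping sample name -> sorted features absent from that
              sample but present in at least one other sample
    """
    all_sets = {name: set(feats) for name, feats in features_by_sample.items()}
    shared = set.intersection(*all_sets.values())
    union = set.union(*all_sets.values())
    missing = {name: sorted(union - feats) for name, feats in all_sets.items()}
    return sorted(shared), missing


# Toy protein panels (hypothetical names, not taken from the real samples)
panels = {
    "sample1": ["CD3", "CD4", "CD8", "CD19"],
    "sample2": ["CD3", "CD4", "CD8", "CD19"],
    "sample3": ["CD3", "CD4", "CD19"],  # CD8 antibody missing in this sample
}
shared, missing = compare_feature_sets(panels)
print("Shared:", shared)
print("Missing per sample:", missing)
```

With the real data, the input could be built from the loaded objects, for example from each sample's `uns['protein_names']` or `var_names`, assuming those fields are populated as in the cell below.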

In [None]:
# Summary statistics per sample
summary_data = []
for sample_name, adata in adatas.items():
    summary_data.append({
        'Sample': sample_name,
        'Cells': adata.n_obs,
        'Genes': adata.n_vars,
        'Proteins': len(adata.uns['protein_names']),
        'Median_genes': adata.obs['n_genes_by_counts'].median(),
        'Median_counts': adata.obs['total_counts'].median(),
        'Clusters': len(adata.obs['leiden'].unique()),
    })

summary_df = pd.DataFrame(summary_data)
print("\nPer-sample summary:")
print(summary_df.to_string(index=False))
print(f"\nTotal cells: {summary_df['Cells'].sum():,}")

## 4. Concatenate Samples

Merge all samples into a single AnnData object. We need to:
1. Ensure common genes across samples
2. Concatenate observations (cells)
3. Preserve sample identity for batch correction
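
Cell barcodes can repeat across sequencing runs, so the merged object needs unique cell names. The `index_unique` argument of `concat` handles this by appending the delimiter and the sample key to each original name. The sketch below imitates that naming rule in plain Python on made-up barcodes; `make_unique_barcodes` is a hypothetical helper for illustration only.

```python
def make_unique_barcodes(barcodes_by_sample, delimiter="_"):
    """Append the sample key to every barcode, imitating index_unique."""
    return [
        f"{barcode}{delimiter}{sample}"
        for sample, barcodes in barcodes_by_sample.items()
        for barcode in barcodes
    ]


# Hypothetical barcodes: the same sequence occurs in two different runs
barcodes = {
    "sample1": ["AAACCTGAGAAACCAT-1", "AAACCTGAGAAACGAG-1"],
    "sample2": ["AAACCTGAGAAACCAT-1"],
}
merged = make_unique_barcodes(barcodes)
print(merged)
print("All unique:", len(merged) == len(set(merged)))
```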

In [None]:
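# Concatenate all samples
# Use 'inner' join to keep only common genes
adata_concat = sc.concat(
    adatas,            # dict: its keys become the values of obs['sample']
    axis=0,
    join='inner',      # Keep only genes present in all samples
    label='sample',
    index_unique='_',
    uns_merge='same',  # Keep .uns entries (e.g. protein_names) that are identical in all samples
)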

print(f"\nConcatenated data shape: {adata_concat.shape}")
print(f"Samples: {adata_concat.obs['sample'].value_counts().to_dict()}")

# Verify protein data is present
print(f"\nProtein data available: {'protein_clr' in adata_concat.obsm}")
if 'protein_clr' in adata_concat.obsm:
    print(f"Protein matrix shape: {adata_concat.obsm['protein_clr'].shape}")

## 5. Visualize Batch Effects (Before Integration)

Let's see how much the samples differ before integration. Strong separation by sample indicates batch effects.

In [None]:
# Quick PCA and UMAP for visualization
# Use the existing processed data (already normalized and scaled)
sc.tl.pca(adata_concat, n_comps=50)
sc.pp.neighbors(adata_concat, n_pcs=30)
sc.tl.umap(adata_concat)

# Plot before integration
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

sc.pl.umap(adata_concat, color='sample', ax=axes[0], show=False, title='Before Integration - by Sample')
sc.pl.umap(adata_concat, color='leiden', ax=axes[1], show=False, title='Before Integration - by Cluster')
sc.pl.umap(adata_concat, color='n_genes_by_counts', ax=axes[2], show=False, title='Before Integration - nGenes')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / "before_integration.png", dpi=300, bbox_inches='tight')
plt.show()

print("If cells separate mainly by sample, batch correction is likely needed (some separation can also reflect real biology).")

## 6. Prepare Data for totalVI

totalVI requires:
1. Raw counts for RNA (in .layers['counts'])
2. Raw protein (ADT) counts, not normalized values (in .obsm)
3. Batch information (sample labels)
4. Proper gene filtering
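
Because the model is fit to counts, it can help to confirm that a matrix really holds non-negative whole numbers before training. The check below is a minimal, dependency-free sketch; `looks_like_raw_counts` is a hypothetical helper, and on a large matrix it would be sensible to pass only a sample of entries.

```python
def looks_like_raw_counts(values):
    """Return True if every value is a non-negative whole number."""
    return all(v >= 0 and float(v).is_integer() for v in values)


raw_like = [0, 3, 1, 12.0]            # e.g. UMI counts
normalized_like = [0.0, 1.38, 2.71]   # e.g. log-normalized values

print("Raw-like passes:", looks_like_raw_counts(raw_like))
print("Normalized passes:", looks_like_raw_counts(normalized_like))
```

On an AnnData object the function could be applied to a flattened slice of the counts layer (or to the `.data` attribute of a sparse matrix), assuming the layer names used in this notebook.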

In [None]:
# Prepare for totalVI
# totalVI needs raw counts in .X
adata_totalvi = adata_concat.copy()

# Set raw counts as main layer
if 'counts' in adata_totalvi.layers:
    adata_totalvi.X = adata_totalvi.layers['counts'].copy()
else:
    print("Warning: No raw counts found! Using current X")

# Ensure protein data is properly formatted
# Store raw protein counts under .obsm['protein_expression'], the key passed to setup_anndata below
if 'protein_counts' in adata_totalvi.obsm:
    adata_totalvi.obsm['protein_expression'] = adata_totalvi.obsm['protein_counts'].copy()
    print(f"Protein expression shape: {adata_totalvi.obsm['protein_expression'].shape}")
else:
    raise ValueError("No protein data found! Check preprocessing.")

# Store protein names
if 'protein_names' in adata_totalvi.uns:
    protein_names = adata_totalvi.uns['protein_names']
    print(f"Number of proteins: {len(protein_names)}")
else:
    raise ValueError("Protein names not found!")

print(f"\nData prepared for totalVI:")
print(f"RNA: {adata_totalvi.shape}")
print(f"Proteins: {adata_totalvi.obsm['protein_expression'].shape[1]}")
print(f"Batches: {adata_totalvi.obs['sample'].nunique()}")

## 7. Filter Genes for Integration

For better performance and faster training:
- Keep highly variable genes from each sample
- Typically 2000-4000 genes is sufficient

In [None]:
# Select highly variable genes across all samples
# This improves computational efficiency and focuses on informative genes
sc.pp.highly_variable_genes(
    adata_totalvi,
    n_top_genes=3000,
    flavor='seurat_v3',
    batch_key='sample',  # Calculate HVGs per batch
    subset=False  # Don't subset yet, just mark
)

n_hvg = adata_totalvi.var['highly_variable'].sum()
print(f"Highly variable genes: {n_hvg}")

# Subset to HVGs
adata_totalvi = adata_totalvi[:, adata_totalvi.var['highly_variable']].copy()
print(f"Filtered data shape: {adata_totalvi.shape}")

## 8. Set Up totalVI Model

### totalVI Architecture:
- **Encoder**: Learns latent representation from both RNA and protein
- **Decoder**: Reconstructs RNA and protein from latent space
- **Batch correction**: Removes sample-specific effects
- **Background modeling**: Handles protein background (important for CITE-seq)

### Key parameters:
- `n_latent`: Dimensionality of latent space (like PCA dimensions)
- `n_layers`: Neural network depth
- `gene_likelihood`: Distribution for RNA (negative binomial is standard)

In [None]:
# Register the AnnData with scvi
# This tells scvi where to find batch info and protein data
scvi.model.TOTALVI.setup_anndata(
    adata_totalvi,
    batch_key='sample',
    protein_expression_obsm_key='protein_expression',
    protein_names_uns_key='protein_names'
)

print("AnnData registered with totalVI")
print(f"\nData summary:")
print(adata_totalvi)

In [None]:
# Initialize the totalVI model
model = scvi.model.TOTALVI(
    adata_totalvi,
    n_latent=INTEGRATION_PARAMS['n_latent'],
    n_layers_encoder=INTEGRATION_PARAMS['n_layers'],
    n_layers_decoder=INTEGRATION_PARAMS['n_layers'],
    gene_likelihood='nb',  # Negative binomial for count data
)

print("totalVI model initialized")
print(f"\nModel parameters:")
print(f"  Latent dimensions: {INTEGRATION_PARAMS['n_latent']}")
print(f"  Hidden layers: {INTEGRATION_PARAMS['n_layers']}")
print(f"  Batches: {adata_totalvi.obs['sample'].nunique()}")

## 9. Train the Model

This is the computationally intensive step. The model learns to:
1. Encode cells into latent space
2. Decode back to RNA and protein
3. Minimize reconstruction error
4. Remove batch effects

**Training tips:**
- GPU recommended for large datasets (>50k cells)
- Monitor training loss - should decrease and plateau
- Early stopping prevents overfitting
- Typical training time: 10-30 minutes (depends on data size)
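
The early-stopping rule mentioned above can be illustrated with a few lines of plain Python. This is a simplified sketch of patience-based stopping on a list of validation losses, not the callback that scvi-tools uses internally; the function name and the `min_delta` parameter are assumptions made for the example.

```python
def early_stopping_epoch(val_losses, patience=20, min_delta=0.0):
    """Return the 1-based epoch at which training would stop.

    Training stops once `patience` consecutive epochs fail to improve the
    best validation loss by more than `min_delta`. If that never happens,
    the last epoch is returned.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(val_losses)


# Toy validation curve: improves for 5 epochs, then plateaus
toy_losses = [10.0, 8.0, 7.0, 6.5, 6.4] + [6.4] * 10
stop_epoch = early_stopping_epoch(toy_losses, patience=3)
print(f"Training would stop at epoch {stop_epoch} of {len(toy_losses)}")
```

A larger patience tolerates longer plateaus at the cost of extra training time, which is the trade-off behind the value chosen in the training cell.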

In [None]:
# Train the model
print("Starting model training...")
print("This may take 10-30 minutes depending on dataset size and hardware.")
print("Watch the training loss - it should decrease steadily.\n")

model.train(
    max_epochs=INTEGRATION_PARAMS['max_epochs'],
    batch_size=INTEGRATION_PARAMS['batch_size'],
    early_stopping=INTEGRATION_PARAMS['early_stopping'],
    early_stopping_patience=20,  # Stop if no improvement for 20 epochs
    train_size=0.9,  # Use 90% for training, 10% for validation
)

print("\n✓ Training complete!")

## 10. Evaluate Model Performance

Check training history and convergence.
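
For orientation, the curves plotted in this section track a variational objective. In simplified, generic form (totalVI's full objective includes further latent variables, such as protein background, that are omitted here), a variational autoencoder maximizes the evidence lower bound. Writing $x$ for a cell's observed RNA and protein counts, $s$ for its batch (sample), and $z$ for its latent embedding:

$$
\mathrm{ELBO}(x) = \mathbb{E}_{q(z \mid x, s)}\big[\log p(x \mid z, s)\big] - \mathrm{KL}\big(q(z \mid x, s)\,\|\,p(z)\big)
$$

The negative of the first term is the reconstruction loss; the second term penalizes posteriors that drift far from the prior. The history values plotted here appear to be logged as losses, so lower is better for both panels.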

In [None]:
# Plot training history
train_history = model.history

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# ELBO loss (Evidence Lower Bound)
epochs = train_history['elbo_train'].index
axes[0].plot(epochs, train_history['elbo_train'], label='Train')
axes[0].plot(epochs, train_history['elbo_validation'], label='Validation')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('ELBO Loss')
axes[0].set_title('Training Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Reconstruction loss
axes[1].plot(epochs, train_history['reconstruction_loss_train'], label='Train')
axes[1].plot(epochs, train_history['reconstruction_loss_validation'], label='Validation')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Reconstruction Loss')
axes[1].set_title('Reconstruction Loss')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(OUTPUT_DIR / "training_history.png", dpi=300, bbox_inches='tight')
plt.show()

print("Loss curves should show steady decrease and convergence.")
print("If validation loss increases, the model may be overfitting.")

## 11. Extract Integrated Latent Representation

The latent representation is the batch-corrected, low-dimensional embedding that integrates RNA and protein data.

In [None]:
# Get latent representation
# This is the integrated, batch-corrected embedding
latent = model.get_latent_representation()

print(f"Latent representation shape: {latent.shape}")
print(f"Dimensions: {latent.shape[1]} (cells compressed from {adata_totalvi.n_vars} genes + {len(protein_names)} proteins)")

# Store in AnnData
adata_totalvi.obsm['X_totalvi'] = latent

# Also get normalized expression for downstream analysis
# TOTALVI.get_normalized_expression returns a (RNA, protein) pair
adata_totalvi.layers['totalvi_normalized'] = model.get_normalized_expression(
    n_samples=25,  # Monte Carlo samples for better estimates
    return_mean=True,
    return_numpy=True,
)[0]  # [0] returns RNA data

# Get denoised protein expression
adata_totalvi.obsm['protein_totalvi'] = model.get_normalized_expression(
    n_samples=25,
    return_mean=True,
    return_numpy=True,
    transform_batch=list(adata_totalvi.obs['sample'].unique())  # Average across batches
)[1]  # [1] returns protein data

print("\nIntegrated representations extracted")

## 12. Compute UMAP on Integrated Data

Now we'll visualize the integrated latent space. If integration worked well:
- Cells should mix by sample (no strong batch effects)
- Biological cell types should cluster together
- Overall structure should resemble the pre-integration clustering, but with samples more evenly mixed

In [None]:
# Compute neighborhood graph on integrated latent space
sc.pp.neighbors(
    adata_totalvi,
    use_rep='X_totalvi',  # Use totalVI latent space
    n_neighbors=15
)

# Compute UMAP
sc.tl.umap(adata_totalvi, min_dist=0.3)

print("UMAP computed on integrated data")

## 13. Integrated Clustering

Perform clustering on the integrated data. This should identify cell types across all samples.

In [None]:
# Leiden clustering on integrated data
resolutions = [0.3, 0.5, 0.7, 1.0, 1.5]

for res in resolutions:
    sc.tl.leiden(
        adata_totalvi,
        resolution=res,
        key_added=f'leiden_integrated_r{res}'
    )
    n_clusters = len(adata_totalvi.obs[f'leiden_integrated_r{res}'].unique())
    print(f"Resolution {res}: {n_clusters} clusters")

# Set default integrated clustering
adata_totalvi.obs['leiden_integrated'] = adata_totalvi.obs['leiden_integrated_r0.7']

print(f"\nDefault integrated clustering: {len(adata_totalvi.obs['leiden_integrated'].unique())} clusters")

## 14. Visualize Integration Results

Compare before and after integration to assess batch correction quality.

In [None]:
# Comprehensive visualization
fig, axes = plt.subplots(2, 4, figsize=(20, 10))

# Row 1: After integration
sc.pl.umap(adata_totalvi, color='sample', ax=axes[0, 0], show=False, 
           title='Integrated - by Sample', frameon=False)
sc.pl.umap(adata_totalvi, color='leiden_integrated', ax=axes[0, 1], show=False,
           title='Integrated - by Cluster', frameon=False, legend_loc='on data', legend_fontsize=6)
sc.pl.umap(adata_totalvi, color='n_genes_by_counts', ax=axes[0, 2], show=False,
           title='Integrated - nGenes', frameon=False)
sc.pl.umap(adata_totalvi, color='total_counts', ax=axes[0, 3], show=False,
           title='Integrated - Total Counts', frameon=False)

# Row 2: Sample composition and metrics
# Sample distribution per cluster
cluster_sample_counts = pd.crosstab(
    adata_totalvi.obs['leiden_integrated'],
    adata_totalvi.obs['sample'],
    normalize='index'
)
cluster_sample_counts.plot(kind='bar', stacked=True, ax=axes[1, 0], 
                           colormap='Set3', legend=True)
axes[1, 0].set_xlabel('Cluster')
axes[1, 0].set_ylabel('Proportion')
axes[1, 0].set_title('Sample Composition per Cluster')
axes[1, 0].legend(title='Sample', bbox_to_anchor=(1, 1))

# Cluster sizes
cluster_sizes = adata_totalvi.obs['leiden_integrated'].value_counts().sort_index()
axes[1, 1].bar(range(len(cluster_sizes)), cluster_sizes.values, color='steelblue')
axes[1, 1].set_xlabel('Cluster')
axes[1, 1].set_ylabel('Number of Cells')
axes[1, 1].set_title('Cluster Sizes')
axes[1, 1].set_xticks(range(len(cluster_sizes)))
axes[1, 1].set_xticklabels(cluster_sizes.index, rotation=45)

# QC metrics by sample
sample_qc = adata_totalvi.obs.groupby('sample')[['n_genes_by_counts', 'total_counts']].median()
x = np.arange(len(sample_qc))
width = 0.35
axes[1, 2].bar(x - width/2, sample_qc['n_genes_by_counts'], width, label='Genes', color='coral')
axes[1, 2].bar(x + width/2, sample_qc['total_counts']/10, width, label='Counts/10', color='skyblue')
axes[1, 2].set_xlabel('Sample')
axes[1, 2].set_ylabel('Median Value')
axes[1, 2].set_title('QC Metrics by Sample')
axes[1, 2].set_xticks(x)
axes[1, 2].set_xticklabels(sample_qc.index, rotation=45)
axes[1, 2].legend()

# Summary text
axes[1, 3].axis('off')
summary_text = f"""
INTEGRATION SUMMARY

Total cells: {adata_totalvi.n_obs:,}
Total genes: {adata_totalvi.n_vars:,}
Total proteins: {len(protein_names)}

Samples: {len(SAMPLES)}
Clusters: {len(adata_totalvi.obs['leiden_integrated'].unique())}

Latent dimensions: {INTEGRATION_PARAMS['n_latent']}
Training epochs: {len(train_history['elbo_train'])}

Cells per sample:
"""
for sample in SAMPLES:
    n = (adata_totalvi.obs['sample'] == sample).sum()
    summary_text += f"  {sample}: {n:,}\n"

axes[1, 3].text(0.1, 0.5, summary_text, fontsize=10, family='monospace',
                verticalalignment='center')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / "integration_summary.png", dpi=300, bbox_inches='tight')
plt.show()

print("\nIntegration complete! Check plots for quality assessment.")

## 15. Visualize Multi-Resolution Clustering

Different resolutions reveal different levels of cell type granularity.

In [None]:
# Plot multiple resolutions
sc.pl.umap(
    adata_totalvi,
    color=['leiden_integrated_r0.3', 'leiden_integrated_r0.5', 
           'leiden_integrated_r0.7', 'leiden_integrated_r1.0', 'leiden_integrated_r1.5'],
    ncols=3,
    frameon=False,
    save='_multi_resolution_clustering.png'
)

## 16. Visualize Protein Expression

One of the key advantages of CITE-seq: direct protein measurements to validate cell identities.

In [None]:
# Get denoised protein expression from totalVI
print(f"Available proteins ({len(protein_names)}):")
print(protein_names)

# Create a separate adata for proteins
import anndata
adata_prot = anndata.AnnData(
    X=adata_totalvi.obsm['protein_totalvi'],
    obs=adata_totalvi.obs,
    var=pd.DataFrame(index=protein_names)
)
adata_prot.obsm['X_umap'] = adata_totalvi.obsm['X_umap']

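# Plot key immune markers (adjust based on your panel)
import re
common_markers = ['CD3', 'CD4', 'CD8', 'CD19', 'CD14', 'CD56', 'CD45', 'CD16']
available_markers = []
for marker in common_markers:
    # Match the marker as a whole token so 'CD3' does not hit 'CD38' and 'CD4' does not hit 'CD45'
    pattern = re.compile(rf'(?<![A-Za-z0-9]){re.escape(marker)}(?!\d)')
    matching = [p for p in protein_names if pattern.search(p)]
    if matching:
        available_markers.extend(matching[:1])  # Take first match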

if len(available_markers) > 0:
    print(f"\nPlotting {len(available_markers)} key markers...")
    sc.pl.umap(
        adata_prot,
        color=available_markers[:min(9, len(available_markers))],
        ncols=3,
        vmax='p99',
        cmap='RdBu_r',
        frameon=False,
        save='_key_proteins.png'
    )
else:
    print("\nNo common immune markers found, plotting first 9 proteins...")
    sc.pl.umap(
        adata_prot,
        color=protein_names[:min(9, len(protein_names))],
        ncols=3,
        vmax='p99',
        cmap='RdBu_r',
        frameon=False,
        save='_proteins.png'
    )

## 17. Find Marker Genes per Cluster

Identify genes that distinguish each cluster for cell type annotation.

In [None]:
# Differential expression analysis
print("Finding marker genes for each cluster...")
print("This may take several minutes...\n")

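# .X holds raw counts here (needed by totalVI), but the Wilcoxon test in
# rank_genes_groups expects log-normalized values, so compute them in a layer
adata_totalvi.layers['log1p_norm'] = adata_totalvi.X.copy()
sc.pp.normalize_total(adata_totalvi, target_sum=1e4, layer='log1p_norm')
sc.pp.log1p(adata_totalvi, layer='log1p_norm')

sc.tl.rank_genes_groups(
    adata_totalvi,
    groupby='leiden_integrated',
    method='wilcoxon',  # Non-parametric test
    layer='log1p_norm',
    use_raw=False,
    key_added='rank_genes_integrated'
)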

print("Marker gene analysis complete!")

In [None]:
# Visualize top marker genes
sc.pl.rank_genes_groups(
    adata_totalvi,
    n_genes=20,
    sharey=False,
    key='rank_genes_integrated',
    save='_marker_genes.png'
)

# Show top 5 markers per cluster
print("\nTop 5 marker genes per cluster:")
print("="*60)
result = adata_totalvi.uns['rank_genes_integrated']
groups = result['names'].dtype.names
for group in groups:
    print(f"\nCluster {group}:")
    genes = result['names'][group][:5]
    scores = result['scores'][group][:5]
    for gene, score in zip(genes, scores):
        print(f"  {gene:20s} (score: {score:.2f})")

## 18. Visualize Key Marker Genes

Plot expression of canonical cell type markers to aid annotation.

In [None]:
# Define canonical cell type markers
# Adjust based on your expected cell types
marker_genes = {
    'T cells': ['CD3D', 'CD3E', 'CD3G'],
    'CD4 T cells': ['CD4', 'IL7R'],
    'CD8 T cells': ['CD8A', 'CD8B'],
    'B cells': ['CD19', 'MS4A1', 'CD79A'],  # MS4A1 = CD20
    'NK cells': ['NCAM1', 'NKG7', 'GNLY'],  # NCAM1 = CD56
    'Monocytes': ['CD14', 'FCGR3A', 'S100A8'],  # FCGR3A = CD16
    'Dendritic cells': ['FCER1A', 'CD1C'],
    'Proliferating': ['MKI67', 'TOP2A'],
}

# Find available markers
available_markers = []
for cell_type, genes in marker_genes.items():
    for gene in genes:
        if gene in adata_totalvi.var_names:
            available_markers.append(gene)

if len(available_markers) > 0:
    print(f"Plotting {len(available_markers)} canonical markers...")
    sc.pl.umap(
        adata_totalvi,
        color=available_markers[:min(12, len(available_markers))],
        ncols=4,
        vmax='p99',
        cmap='Reds',
        frameon=False,
        save='_canonical_markers.png'
    )
    
    # Dotplot for clearer visualization
    sc.pl.dotplot(
        adata_totalvi,
        available_markers[:min(20, len(available_markers))],
        groupby='leiden_integrated',
        save='_marker_dotplot.png'
    )
else:
    print("No canonical markers found in dataset")

## 19. Compute Integration Quality Metrics

Quantify how well integration worked using standard metrics.
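
The mixing score computed in the next cell is a normalized Shannon entropy of the sample proportions within each cluster. A small worked example on made-up counts shows how the value behaves at the extremes; `normalized_entropy` is a hypothetical stand-alone helper written only for this illustration.

```python
import math


def normalized_entropy(counts):
    """Shannon entropy of batch counts, scaled to [0, 1] by log(n_batches)."""
    total = sum(counts)
    n_batches = len(counts)
    if total == 0 or n_batches < 2:
        return 0.0
    # Zero-count batches contribute nothing to the entropy
    probs = [c / total for c in counts if c > 0]
    entropy = sum(-p * math.log(p) for p in probs)
    return entropy / math.log(n_batches)


# Toy clusters with cells from 4 samples (made-up numbers)
balanced = normalized_entropy([250, 250, 250, 250])
single = normalized_entropy([1000, 0, 0, 0])
skewed = normalized_entropy([700, 100, 100, 100])
print(f"balanced={balanced:.3f}, single-sample={single:.3f}, skewed={skewed:.3f}")
```

A cluster drawn equally from every sample scores 1, a cluster made of a single sample scores 0, and skewed clusters fall in between. A low score is not necessarily a failure: a cell type that genuinely exists in only one sample will also score low.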

In [None]:
# Calculate mixing metrics
# These quantify how well samples are mixed in the integrated space

from scipy.stats import entropy

def calculate_mixing_metric(adata, cluster_key, batch_key):
    """
    Calculate mixing metric: how well batches are mixed within clusters.
    Value of 1 = perfect mixing, 0 = no mixing
    """
    mixing_scores = []
    for cluster in adata.obs[cluster_key].unique():
        cluster_mask = adata.obs[cluster_key] == cluster
        batch_counts = adata.obs[cluster_mask][batch_key].value_counts()
        batch_props = batch_counts / batch_counts.sum()
        
        # Calculate normalized entropy
        max_entropy = np.log(len(batch_props))
        if max_entropy > 0:
            mixing_score = entropy(batch_props) / max_entropy
            mixing_scores.append(mixing_score)
    
    return np.mean(mixing_scores)

mixing_score = calculate_mixing_metric(
    adata_totalvi,
    'leiden_integrated',
    'sample'
)

print("\n" + "="*60)
print("INTEGRATION QUALITY METRICS")
print("="*60)
print(f"\nMixing score: {mixing_score:.3f}")
print("  (1.0 = perfect mixing, 0.0 = no mixing)")
print("\nInterpretation:")
if mixing_score > 0.8:
    print("  ✓ Excellent integration - samples well mixed")
elif mixing_score > 0.6:
    print("  ✓ Good integration - acceptable mixing")
elif mixing_score > 0.4:
    print("  ⚠ Moderate integration - some batch effects remain")
else:
    print("  ✗ Poor integration - strong batch effects")

print("\nNote: Some separation is expected if samples have different biology!")
print("="*60)

## 20. Save Integrated Data

Save the fully integrated dataset for downstream analysis.

In [None]:
# Save integrated AnnData
output_file = OUTPUT_DIR / "integrated_totalvi.h5ad"
adata_totalvi.write(output_file)

print(f"Integrated data saved to: {output_file}")
print(f"File size: {output_file.stat().st_size / 1024**2:.1f} MB")

# Save model for future use
model_dir = OUTPUT_DIR / "totalvi_model"
model.save(model_dir, overwrite=True)
print(f"\ntotalVI model saved to: {model_dir}")

# Save marker genes
marker_df = sc.get.rank_genes_groups_df(
    adata_totalvi,
    group=None,
    key='rank_genes_integrated'
)
marker_df.to_csv(OUTPUT_DIR / "marker_genes.csv", index=False)
print(f"Marker genes saved to: {OUTPUT_DIR / 'marker_genes.csv'}")

## 21. Generate Final Summary Report

In [None]:
# Create comprehensive summary
final_summary = {
    'n_samples': len(SAMPLES),
    'total_cells': adata_totalvi.n_obs,
    'total_genes': adata_totalvi.n_vars,
    'total_proteins': len(protein_names),
    'n_clusters': len(adata_totalvi.obs['leiden_integrated'].unique()),
    'latent_dimensions': INTEGRATION_PARAMS['n_latent'],
    'training_epochs': len(train_history['elbo_train']),
    'mixing_score': f"{mixing_score:.3f}",
}

# Per-sample statistics
for sample in SAMPLES:
    n_cells = (adata_totalvi.obs['sample'] == sample).sum()
    final_summary[f'{sample}_cells'] = n_cells

# Cluster statistics
cluster_stats = adata_totalvi.obs.groupby('leiden_integrated').agg({
    'sample': 'count',
    'n_genes_by_counts': 'median',
    'total_counts': 'median'
})
cluster_stats.columns = ['n_cells', 'median_genes', 'median_counts']
cluster_stats.to_csv(OUTPUT_DIR / "cluster_statistics.csv")

# Save summary
summary_df = pd.DataFrame([final_summary])
summary_df.to_csv(OUTPUT_DIR / "integration_summary.csv", index=False)

print("\n" + "="*60)
print("FINAL INTEGRATION SUMMARY")
print("="*60)
for key, value in final_summary.items():
    print(f"{key:25s}: {value}")
print("="*60)

print(f"\n✓ Integration complete!")
print(f"\nAll outputs saved to: {OUTPUT_DIR}")
print("\nNext steps:")
print("  1. Annotate cell types based on marker genes and proteins")
print("  2. Perform differential expression between conditions")
print("  3. Investigate cell-cell interactions")
print("  4. Trajectory/pseudotime analysis if relevant")

## Next Steps: Cell Type Annotation Template

Use marker genes and proteins to assign cell types:
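
When the annotation dictionary is incomplete, `Series.map` leaves unmapped clusters as NaN, which is easy to miss in plots. The sketch below shows one possible fallback in plain Python, labelling unmapped clusters explicitly; `annotate_clusters` and the example labels are hypothetical.

```python
def annotate_clusters(cluster_labels, annotations, fallback="Unassigned"):
    """Map cluster IDs to cell type names, labelling unmapped clusters."""
    return [
        annotations.get(str(label), f"{fallback} (cluster {label})")
        for label in cluster_labels
    ]


# Hypothetical labels; cluster '7' has no entry in the dictionary
example_annotations = {"0": "CD4+ T cells", "1": "CD8+ T cells"}
annotated = annotate_clusters(["0", "1", "7", "0"], example_annotations)
print(annotated)
```

Leiden cluster labels are stored as strings, which is why the dictionary keys in the template below are quoted.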

In [None]:
# TEMPLATE: Manual cell type annotation
# Examine marker genes and proteins for each cluster, then assign labels

# Example annotation dictionary (adjust based on your data)
cluster_annotations = {
    '0': 'CD4+ T cells',
    '1': 'CD8+ T cells',
    '2': 'B cells',
    '3': 'NK cells',
    '4': 'CD14+ Monocytes',
    '5': 'CD16+ Monocytes',
    # Add more as needed...
}

# Apply annotations
# adata_totalvi.obs['cell_type'] = adata_totalvi.obs['leiden_integrated'].map(cluster_annotations)

# Visualize
# sc.pl.umap(adata_totalvi, color='cell_type', legend_loc='on data')

print("Uncomment and modify the code above after examining your clusters!")