# BioBatchNet Tutorial

BioBatchNet is a deep learning framework for batch effect correction in biological data, supporting both single-cell RNA-seq (scRNA-seq) and Imaging Mass Cytometry (IMC) data.

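This tutorial covers:
1. Quick Start - Using the simple API (Section 3)
2. Advanced Usage - Using the model classes directly (Sections 4-5)
3. Custom Configuration - Adjusting model architecture and training parameters (Section 6)
4. Evaluation, Visualization, and Model Saving (Sections 7-9)
5. Real-world Examples (Section 11)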

## 1. Installation and Imports

In [None]:
# Install the package (if not already installed)
# !pip install biobatchnet

# Or install from source
# !git clone https://github.com/Manchester-HealthAI/BioBatchNet
# !cd BioBatchNet && pip install -e .

In [None]:
# Import necessary packages
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Import BioBatchNet
import biobatchnet
from biobatchnet import correct_batch_effects, IMCVAE, GeneVAE

print(f"BioBatchNet version: {biobatchnet.__version__}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 2. Prepare Example Data

In [None]:
# Generate simulated data for demonstration
np.random.seed(42)

# Simulate IMC data: 1000 cells, 40 protein markers, 3 batches
n_cells = 1000
n_features = 40
n_batches = 3

# Generate base data
base_data = np.random.randn(n_cells, n_features)

# Add batch effects
batch_labels = np.random.choice(n_batches, n_cells)
batch_effects = np.zeros_like(base_data)
for i in range(n_batches):
    batch_mask = batch_labels == i
    # Add different shifts for each batch
    batch_effects[batch_mask] = np.random.randn(1, n_features) * 0.5

# Final data = base data + batch effects
data_with_batch = base_data + batch_effects

# Convert to DataFrame
data_df = pd.DataFrame(
    data_with_batch, 
    columns=[f'Protein_{i+1}' for i in range(n_features)]
)

batch_df = pd.DataFrame({
    'batch_id': batch_labels,
    'cell_id': [f'cell_{i}' for i in range(n_cells)]
})

print(f"Data shape: {data_df.shape}")
print(f"Batch distribution: {np.bincount(batch_labels)}")
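The simulated data above contains only noise plus batch shifts, so there is no biological signal for the corrected embedding to preserve. If you also want to check biological conservation, here is a minimal sketch that adds cell-type cluster centers before the batch shifts. The helper name `simulate_imc_with_cell_types` and all of its parameters are hypothetical and not part of BioBatchNet.

```python
import numpy as np

def simulate_imc_with_cell_types(n_cells=1000, n_features=40, n_batches=3,
                                 n_cell_types=4, type_scale=2.0,
                                 batch_scale=0.5, seed=0):
    """Hypothetical simulator: data = cell-type signal + per-batch shift + noise."""
    rng = np.random.default_rng(seed)
    # Biological signal: one random center per cell type
    type_centers = rng.normal(0.0, type_scale, size=(n_cell_types, n_features))
    cell_types = rng.integers(0, n_cell_types, size=n_cells)
    # Technical signal: one random shift vector per batch
    batch_shifts = rng.normal(0.0, batch_scale, size=(n_batches, n_features))
    batches = rng.integers(0, n_batches, size=n_cells)
    noise = rng.normal(0.0, 1.0, size=(n_cells, n_features))
    data = type_centers[cell_types] + batch_shifts[batches] + noise
    return data, batches, cell_types

sim_data, sim_batches, sim_types = simulate_imc_with_cell_types()
print(sim_data.shape, np.bincount(sim_batches), np.bincount(sim_types))
```

With `type_scale` larger than `batch_scale`, cell types dominate the variance, so a good correction should mix batches while keeping the cell types (`sim_types`) separated.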

In [None]:
# Visualize batch effects
def plot_batch_effect(data, batch_labels, title):
    """Visualize batch effects using PCA"""
    pca = PCA(n_components=2)
    data_pca = pca.fit_transform(StandardScaler().fit_transform(data))
    
    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(data_pca[:, 0], data_pca[:, 1], 
                         c=batch_labels, cmap='viridis', alpha=0.6)
    plt.colorbar(scatter, label='Batch')
    plt.xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%})')
    plt.ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%})')
    plt.title(title)
    plt.show()

# Show batch effects in original data
plot_batch_effect(data_with_batch, batch_labels, 'Original Data (with batch effects)')
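`plot_batch_effect` standardizes each feature and then projects the cells onto the first two principal components. To show what `StandardScaler` + `PCA` compute, here is an equivalent NumPy sketch using the SVD. The function name `pca_2d` is hypothetical, and component signs may differ from scikit-learn's, which does not change how the plot is read.

```python
import numpy as np

def pca_2d(data):
    """Standardize features, then project onto the top-2 principal components."""
    X = np.asarray(data, dtype=float)
    # Same as StandardScaler: zero mean, unit (population) standard deviation
    X = (X - X.mean(axis=0)) / X.std(axis=0)
    # Rows of Vt are the principal axes; S holds the singular values
    U, S, Vt = np.linalg.svd(X, full_matrices=False)
    scores = X @ Vt[:2].T
    # Fraction of total variance captured by each of the two components
    explained_ratio = S[:2] ** 2 / np.sum(S ** 2)
    return scores, explained_ratio

rng = np.random.default_rng(0)
demo = rng.normal(size=(200, 10))
demo[:, 1] = demo[:, 0] + 0.1 * rng.normal(size=200)  # Two strongly correlated features
scores, ratio = pca_2d(demo)
print(scores.shape, ratio)
```

The percentages in the axis labels of `plot_batch_effect` correspond to `explained_ratio` here.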

## 3. Method 1: Using the Simple API (Recommended)

This is the easiest way to use BioBatchNet, suitable for most users.

In [None]:
# Batch effect correction with default parameters
bio_embeddings, batch_embeddings = correct_batch_effects(
    data=data_df,           # Expression data
    batch_info=batch_df,    # Batch information
    batch_key='batch_id',   # Batch column name
    data_type='imc',        # Data type: 'imc' or 'scrna'
    latent_dim=20,          # Latent space dimension
    epochs=100              # Training epochs
)

print(f"Biological embedding shape: {bio_embeddings.shape}")
print(f"Batch embedding shape: {batch_embeddings.shape}")

In [None]:
# Visualize corrected data
plot_batch_effect(bio_embeddings, batch_labels, 'Corrected Biological Embeddings')

### 3.1 Custom Loss Weights

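If the default weights under-correct (batches still separate) or over-correct (biological structure is lost), pass a `loss_weights` dictionary to override them. Section 10.2 describes what each weight controls.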

In [None]:
# Custom loss weights
custom_loss_weights = {
    'recon_loss': 10,      # Reconstruction loss weight
    'discriminator': 0.3,   # Discriminator loss weight
    'classifier': 1,        # Classifier loss weight
    'kl_loss_1': 0.005,    # KL divergence loss 1
    'kl_loss_2': 0.1,      # KL divergence loss 2
    'ortho_loss': 0.01     # Orthogonality loss weight
}

bio_embeddings_custom, batch_embeddings_custom = correct_batch_effects(
    data=data_df,
    batch_info=batch_df,
    batch_key='batch_id',
    data_type='imc',
    latent_dim=20,
    epochs=100,
    loss_weights=custom_loss_weights  # Use custom weights
)

print("Training completed with custom loss weights")
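Conceptually, each entry in `loss_weights` scales one term of the training objective. The sketch below assumes the total loss is a plain weighted sum of the named terms. That is an assumption for illustration: the actual BioBatchNet training loop may combine terms differently (for example, the adversarial discriminator term may enter with an opposite sign or a separate optimizer). The function name and the per-term loss values are hypothetical.

```python
def combine_losses(loss_values, loss_weights):
    """Weighted sum of named loss terms; a term without a weight raises KeyError."""
    total = 0.0
    for name, value in loss_values.items():
        # Every term must have an explicit weight so nothing is silently dropped
        total += loss_weights[name] * value
    return total

weights = {'recon_loss': 10, 'discriminator': 0.3, 'classifier': 1,
           'kl_loss_1': 0.005, 'kl_loss_2': 0.1, 'ortho_loss': 0.01}
# Hypothetical per-term loss values from one training step
step_losses = {'recon_loss': 0.8, 'discriminator': 1.1, 'classifier': 0.5,
               'kl_loss_1': 20.0, 'kl_loss_2': 3.0, 'ortho_loss': 0.2}
print(combine_losses(step_losses, weights))
```

In this toy example the reconstruction term contributes 8.0 of a total of about 9.23, which shows how a single large weight can dominate the objective; raw loss magnitudes matter as much as the weights themselves.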

### 3.2 Automatic Parameter Adjustment for Different Batch Counts

When `loss_weights` is not given, the API chooses default loss weights based on the number of batches in `batch_info`.

In [None]:
# Simulate data with different batch counts
def test_different_batch_counts():
    for n_batches in [3, 15, 35]:
        # Create test data
        test_batch_labels = np.random.choice(n_batches, n_cells)
        test_batch_df = pd.DataFrame({'batch_id': test_batch_labels})
        
        print(f"\nNumber of batches: {n_batches}")
        print("API will automatically select appropriate loss weights")
        
        # API automatically adjusts parameters
        bio_emb, batch_emb = correct_batch_effects(
            data=data_df,
            batch_info=test_batch_df,
            batch_key='batch_id',
            data_type='imc',
            epochs=50  # Fewer epochs for demonstration
        )
        
        print(f"Complete! Embedding dimensions: {bio_emb.shape}")

# test_different_batch_counts()  # Uncomment to run
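This notebook does not show which values the API picks for each batch count. Purely as an illustration, the sketch below shows how such a rule could be written if you want to set the weights yourself via `loss_weights`. The function name and the thresholds are hypothetical; they loosely follow the advice in Section 10.2 to lower the discriminator weight (0.1-0.3) when there are many batches, and are not BioBatchNet's actual rule.

```python
def discriminator_weight_for(n_batches, few=0.3, moderate=0.2, many=0.1):
    """Hypothetical rule: use a weaker adversarial term as the batch count grows."""
    if n_batches < 1:
        raise ValueError("n_batches must be positive")
    if n_batches <= 10:
        return few
    if n_batches <= 30:
        return moderate
    return many

for n in [3, 15, 35]:
    print(n, discriminator_weight_for(n))
```

The returned value could then be placed in a custom dictionary, e.g. `{**custom_loss_weights, 'discriminator': discriminator_weight_for(n)}`.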

## 4. Method 2: Using Models Directly (Advanced)

For more control, use the model classes directly.

In [None]:
# Create IMCVAE model instance
model = IMCVAE(
    in_sz=n_features,                              # Input dimension
    out_sz=n_features,                             # Output dimension  
    latent_sz=20,                                  # Latent space dimension
    num_batch=n_batches,                           # Number of batches
    bio_encoder_hidden_layers=[512, 1024, 1024],   # Bio encoder architecture
    batch_encoder_hidden_layers=[256],             # Batch encoder architecture
    decoder_hidden_layers=[1024, 1024, 512],       # Decoder architecture
    batch_classifier_layers_power=[512, 1024, 1024], # Strong classifier architecture
    batch_classifier_layers_weak=[128]             # Weak classifier architecture
)

print(f"Model parameter count: {sum(p.numel() for p in model.parameters()):,}")
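To see where the parameter count comes from: a fully connected layer with `a` inputs and `b` outputs has `a*b + b` parameters (weights plus biases). The rough sketch below counts one MLP stack, assuming plain `Linear` layers with bias. The real `IMCVAE` may add normalization layers or separate mean/variance heads, so its total will differ; the helper name is hypothetical.

```python
def mlp_param_count(in_sz, hidden_layers, out_sz):
    """Parameters of a chain of Linear layers with bias: sum of a*b + b."""
    sizes = [in_sz, *hidden_layers, out_sz]
    return sum(a * b + b for a, b in zip(sizes[:-1], sizes[1:]))

# Bio encoder widths from the example above: 40 -> 512 -> 1024 -> 1024 -> 20
print(f"{mlp_param_count(40, [512, 1024, 1024], 20):,}")  # → 1,616,404
```

The single 1024 × 1024 layer accounts for about 1.05 million of these parameters, which is why shrinking the widest hidden layers (Section 10.4) is the most effective way to reduce model size.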

In [None]:
# Train the model
model.fit(
    data=data_df.values,
    batch_info=batch_labels,
    epochs=100,
    lr=1e-3,
    batch_size=256,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

print("Model training complete")
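Note that `batch_size=256` is the number of cells per optimization step (a mini-batch), not the experimental batches being corrected. The generic sketch below shows one epoch of shuffled mini-batching, assuming `fit` iterates over shuffled cells in this common way (not verified against BioBatchNet's implementation). The function name is hypothetical.

```python
import numpy as np

def minibatch_indices(n_cells, batch_size, rng):
    """Yield shuffled index arrays that cover every cell exactly once per epoch."""
    order = rng.permutation(n_cells)
    for start in range(0, n_cells, batch_size):
        yield order[start:start + batch_size]

rng = np.random.default_rng(0)
chunks = list(minibatch_indices(1000, 256, rng))
print([len(c) for c in chunks])  # → [256, 256, 256, 232]
```

With 1000 cells and `batch_size=256`, each epoch therefore takes four optimization steps, the last one on a smaller mini-batch.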

In [None]:
# Get corrected embeddings
bio_embeddings_direct = model.get_bio_embeddings(data_df.values)
print(f"Biological embedding shape: {bio_embeddings_direct.shape}")

# Or get both biological and batch embeddings
bio_emb, batch_emb = model.correct_batch_effects(data_df.values)
print(f"Biological embeddings: {bio_emb.shape}, Batch embeddings: {batch_emb.shape}")

## 5. Method 3: Using GeneVAE for scRNA-seq Data

GeneVAE is specifically designed for single-cell RNA-seq data.

In [None]:
# Simulate scRNA-seq data
n_cells_rna = 5000
n_genes = 2000
n_batches_rna = 4

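# Generate count data from a negative binomial distribution
rna_data = np.random.negative_binomial(5, 0.3, size=(n_cells_rna, n_genes))
rna_data = rna_data.astype(np.float32)

# Simulate dropout: with these NB parameters almost no counts are zero,
# so set a random 70% of entries (an illustrative choice) to zero
dropout_mask = np.random.rand(n_cells_rna, n_genes) < 0.7
rna_data[dropout_mask] = 0

# Add batch effects: batch-specific scaling, rounded so values stay integer counts
rna_batch_labels = np.random.choice(n_batches_rna, n_cells_rna)
for i in range(n_batches_rna):
    mask = rna_batch_labels == i
    rna_data[mask] = np.round(rna_data[mask] * np.random.uniform(0.8, 1.2))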

print(f"scRNA-seq data shape: {rna_data.shape}")
print(f"Zero proportion: {(rna_data == 0).mean():.2%}")
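The `kl_loss_size` weight used further below refers to a size-factor term, which suggests GeneVAE models per-cell sequencing depth (an inference from the parameter name, not confirmed in this notebook). A quick diagnostic that needs no model is the log library size per cell: a batch-specific scaling like the one simulated above shows up as a shift in the per-batch mean. A NumPy sketch with a hypothetical function name:

```python
import numpy as np

def mean_log_library_size_by_batch(counts, batch_labels):
    """Mean log total counts per cell, grouped by batch label."""
    log_lib = np.log(counts.sum(axis=1) + 1.0)  # +1 guards against empty cells
    return {b.item(): float(log_lib[batch_labels == b].mean())
            for b in np.unique(batch_labels)}

rng = np.random.default_rng(0)
counts = rng.negative_binomial(5, 0.3, size=(600, 200)).astype(np.float32)
labels = np.repeat([0, 1], 300)
counts[labels == 1] *= 2.0  # Batch 1 sequenced roughly twice as deeply
lib_by_batch = mean_log_library_size_by_batch(counts, labels)
print(lib_by_batch)
```

Here the two batch means differ by roughly log(2) ≈ 0.69, the log of the simulated depth ratio.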

In [None]:
# Process scRNA-seq data using the API
bio_emb_rna, batch_emb_rna = correct_batch_effects(
    data=rna_data,
    batch_info=rna_batch_labels,
    data_type='scrna',  # Specify scRNA-seq data
    latent_dim=30,      # scRNA-seq usually needs a higher latent dimension than IMC
    epochs=100
)

print(f"scRNA embedding shape: {bio_emb_rna.shape}")

In [None]:
# Using GeneVAE model directly
gene_model = GeneVAE(
    in_sz=n_genes,
    out_sz=n_genes,
    latent_sz=30,
    num_batch=n_batches_rna,
    bio_encoder_hidden_layers=[500, 2000, 2000],   # Default scRNA architecture
    batch_encoder_hidden_layers=[500],
    decoder_hidden_layers=[2000, 2000, 500],
    batch_classifier_layers_power=[500, 2000, 2000],
    batch_classifier_layers_weak=[128]
)

# Custom loss weights for scRNA-seq
scrna_loss_weights = {
    'recon_loss': 10,
    'discriminator': 0.04,
    'classifier': 1,
    'kl_loss_1': 1e-7,
    'kl_loss_2': 0.01,
    'ortho_loss': 0.0002,
    'mmd_loss_1': 0,
    'kl_loss_size': 0.002  # scRNA-specific size factor KL loss
}

gene_model.fit(
    data=rna_data,
    batch_info=rna_batch_labels,
    epochs=100,
    loss_weights=scrna_loss_weights
)

print("GeneVAE model training complete")

## 6. Advanced Configuration and Tuning Tips

In [None]:
# Advanced configuration example
advanced_config = {
    # Data parameters
    'data': data_df,
    'batch_info': batch_df,
    'batch_key': 'batch_id',
    'data_type': 'imc',
    
    # Model architecture parameters
    'latent_dim': 25,
    'bio_encoder_hidden_layers': [256, 512, 512],  # Custom encoder architecture
    'batch_encoder_hidden_layers': [128, 128],     # Two-layer batch encoder
    'decoder_hidden_layers': [512, 512, 256],      # Custom decoder
    
    # Training parameters
    'epochs': 150,
    'lr': 5e-4,           # Learning rate
    'batch_size': 128,    # Batch size
    
    # Loss weights
    'loss_weights': {
        'recon_loss': 15,
        'discriminator': 0.2,
        'classifier': 1.5,
        'kl_loss_1': 0.001,
        'kl_loss_2': 0.05,
        'ortho_loss': 0.02
    }
}

# Use advanced configuration
bio_emb_advanced, batch_emb_advanced = correct_batch_effects(**advanced_config)
print("Advanced configuration training complete")
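In every architecture in this notebook the decoder mirrors the biological encoder ([512, 1024, 1024] vs [1024, 1024, 512], and so on). To keep that symmetry while experimenting, a small helper can derive the decoder from the encoder. This is a sketch; the helper name `make_architecture` is hypothetical, and its keys simply match the `correct_batch_effects` arguments used in `advanced_config` above.

```python
def make_architecture(bio_encoder_hidden_layers, batch_encoder_hidden_layers,
                      latent_dim):
    """Build architecture kwargs with a decoder that mirrors the bio encoder."""
    layers = list(bio_encoder_hidden_layers) + list(batch_encoder_hidden_layers)
    if latent_dim <= 0 or any(width <= 0 for width in layers):
        raise ValueError("latent_dim and all layer widths must be positive")
    return {
        'latent_dim': latent_dim,
        'bio_encoder_hidden_layers': list(bio_encoder_hidden_layers),
        'batch_encoder_hidden_layers': list(batch_encoder_hidden_layers),
        # Reverse the encoder widths so the decoder expands back symmetrically
        'decoder_hidden_layers': list(reversed(bio_encoder_hidden_layers)),
    }

arch = make_architecture([256, 512, 512], [128, 128], latent_dim=25)
print(arch['decoder_hidden_layers'])  # → [512, 512, 256]
```

The result can be merged into a configuration such as `{**advanced_config, **arch}` before calling `correct_batch_effects`.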

## 7. Evaluate Batch Correction Performance

In [None]:
from sklearn.metrics import silhouette_score
from scipy.stats import f_oneway

def evaluate_batch_correction(original_data, corrected_data, batch_labels):
    """
    Evaluate batch correction performance
    """
    # 1. Silhouette coefficient (lower is better for batch mixing)
    sil_original = silhouette_score(original_data, batch_labels)
    sil_corrected = silhouette_score(corrected_data, batch_labels)
    
    print(f"Silhouette coefficient (batch separation, lower is better):")
    print(f"  Original data: {sil_original:.4f}")
    print(f"  Corrected: {sil_corrected:.4f}")
    print(f"  Improvement: {(sil_original - sil_corrected) / sil_original * 100:.1f}%\n")
    
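    # 2. ANOVA F-statistic (lower means smaller differences between batch means)
    # Iterate over the actual label values so non-contiguous labels also work
    unique_batches = np.unique(batch_labels)
    groups_original = [original_data[batch_labels == b] for b in unique_batches]
    groups_corrected = [corrected_data[batch_labels == b] for b in unique_batches]
    
    # Calculate F-statistics for the first 5 columns. Original features and latent
    # dimensions are not matched one-to-one, so treat this as a rough indicator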
    f_stats_original = []
    f_stats_corrected = []
    
    n_features_to_test = min(5, original_data.shape[1])
    for i in range(n_features_to_test):
        f_orig, _ = f_oneway(*[g[:, i] for g in groups_original])
        f_corr, _ = f_oneway(*[g[:, i] for g in groups_corrected])
        f_stats_original.append(f_orig)
        f_stats_corrected.append(f_corr)
    
    print(f"Mean F-statistic (batch difference, lower is better):")
    print(f"  Original data: {np.mean(f_stats_original):.4f}")
    print(f"  Corrected: {np.mean(f_stats_corrected):.4f}")
    print(f"  Improvement: {(np.mean(f_stats_original) - np.mean(f_stats_corrected)) / np.mean(f_stats_original) * 100:.1f}%")
    
    return sil_corrected, np.mean(f_stats_corrected)

# Evaluate correction performance
print("=" * 50)
print("Batch Correction Performance Evaluation")
print("=" * 50)
evaluate_batch_correction(data_with_batch, bio_embeddings, batch_labels)
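The silhouette score summarizes global separation. A complementary, local view asks how mixed the batches are among each cell's nearest neighbors. The sketch below computes a normalized neighborhood batch entropy (1 means neighbors are evenly spread over batches, 0 means they all come from one batch). It is a simplified stand-in for published metrics such as kBET or iLISI, not an implementation of either, and it uses brute-force distances, so it suits a few thousand cells at most. The function name is hypothetical.

```python
import numpy as np

def knn_batch_entropy(embedding, batch_labels, k=30):
    """Mean normalized entropy of batch labels among each cell's k neighbors."""
    X = np.asarray(embedding, dtype=float)
    labels = np.unique(batch_labels, return_inverse=True)[1].reshape(-1)
    n_batches = labels.max() + 1
    if n_batches < 2:
        return 0.0
    k = min(k, len(X) - 1)
    # Brute-force squared Euclidean distances (O(n^2) memory)
    sq = np.sum(X ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
    np.fill_diagonal(d2, np.inf)  # Exclude each cell from its own neighborhood
    neighbors = np.argpartition(d2, k, axis=1)[:, :k]
    entropies = np.empty(len(X))
    for i, nbrs in enumerate(neighbors):
        p = np.bincount(labels[nbrs], minlength=n_batches) / k
        p = p[p > 0]
        entropies[i] = -np.sum(p * np.log(p)) / np.log(n_batches)
    return float(entropies.mean())

rng = np.random.default_rng(0)
batches = np.repeat([0, 1], 200)
separated = rng.normal(size=(400, 5)) + 10.0 * batches[:, None]  # Batches far apart
mixed = rng.normal(size=(400, 5))                                # Same distribution
print(knn_batch_entropy(separated, batches), knn_batch_entropy(mixed, batches))
```

In this notebook you would compare `knn_batch_entropy(data_with_batch, batch_labels)` with `knn_batch_entropy(bio_embeddings, batch_labels)`; a higher value after correction indicates better local mixing.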

## 8. Visualization Comparison

In [None]:
# Create comparison visualization
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Original data PCA
pca_original = PCA(n_components=2)
data_pca_original = pca_original.fit_transform(StandardScaler().fit_transform(data_with_batch))

# Corrected data PCA
pca_corrected = PCA(n_components=2)
data_pca_corrected = pca_corrected.fit_transform(StandardScaler().fit_transform(bio_embeddings))

# Plot original data
scatter1 = axes[0].scatter(data_pca_original[:, 0], data_pca_original[:, 1],
                           c=batch_labels, cmap='viridis', alpha=0.6, s=20)
axes[0].set_xlabel(f'PC1 ({pca_original.explained_variance_ratio_[0]:.2%})')
axes[0].set_ylabel(f'PC2 ({pca_original.explained_variance_ratio_[1]:.2%})')
axes[0].set_title('Original Data (with batch effects)')
axes[0].legend(*scatter1.legend_elements(), title="Batch", loc="best")

# Plot corrected data
scatter2 = axes[1].scatter(data_pca_corrected[:, 0], data_pca_corrected[:, 1],
                           c=batch_labels, cmap='viridis', alpha=0.6, s=20)
axes[1].set_xlabel(f'PC1 ({pca_corrected.explained_variance_ratio_[0]:.2%})')
axes[1].set_ylabel(f'PC2 ({pca_corrected.explained_variance_ratio_[1]:.2%})')
axes[1].set_title('After BioBatchNet Correction')
axes[1].legend(*scatter2.legend_elements(), title="Batch", loc="best")

plt.suptitle('Before and After Batch Effect Correction', fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

## 9. Save and Load Models

In [None]:
# Save trained model
model_path = 'biobatchnet_model.pt'
torch.save({
    'model_state_dict': model.state_dict(),
    'model_config': {
        'in_sz': n_features,
        'out_sz': n_features,
        'latent_sz': 20,
        'num_batch': n_batches,
        'bio_encoder_hidden_layers': [512, 1024, 1024],
        'batch_encoder_hidden_layers': [256],
        'decoder_hidden_layers': [1024, 1024, 512],
        'batch_classifier_layers_power': [512, 1024, 1024],
        'batch_classifier_layers_weak': [128]
    }
}, model_path)

print(f"Model saved to {model_path}")

In [None]:
# Load model
checkpoint = torch.load(model_path, map_location='cpu')
config = checkpoint['model_config']

# Recreate model
loaded_model = IMCVAE(**config)
loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_model.eval()

print("Model loaded successfully")

# Use loaded model for prediction
with torch.no_grad():
    bio_emb_loaded = loaded_model.get_bio_embeddings(data_df.values)
    print(f"Loaded model prediction shape: {bio_emb_loaded.shape}")
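Storing the architecture inside the checkpoint works well. If you also want the configuration readable without PyTorch (for example, to record it next to your results), the sketch below writes it to a JSON side file. The file name is arbitrary, and this is a convention suggestion rather than something BioBatchNet requires.

```python
import json
import os
import tempfile

def save_config(config, path):
    """Write a model config dict as human-readable JSON."""
    with open(path, 'w') as f:
        json.dump(config, f, indent=2, sort_keys=True)

def load_config(path):
    """Read a config dict written by save_config."""
    with open(path) as f:
        return json.load(f)

config_example = {'in_sz': 40, 'out_sz': 40, 'latent_sz': 20, 'num_batch': 3,
                  'bio_encoder_hidden_layers': [512, 1024, 1024]}
path = os.path.join(tempfile.mkdtemp(), 'biobatchnet_model_config.json')
save_config(config_example, path)
print(load_config(path) == config_example)  # → True
```

JSON only accepts built-in Python types, so cast NumPy integers (such as a batch count computed with `np.unique`) with `int()` before saving.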

## 10. Common Questions and Tuning Suggestions

### 10.1 How to choose latent dimension (latent_dim)?
- IMC data: typically 15-25 dimensions
- scRNA-seq data: typically 20-50 dimensions
- Compare a few candidate values using the evaluation metrics in Section 7, and check that known cell populations remain separated
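As a rough starting point before any tuning (a heuristic, not a BioBatchNet recommendation), you can count how many principal components are needed to explain most of the variance of the standardized data and try latent dimensions around that number. A NumPy sketch; the function name and the 90% threshold are hypothetical choices.

```python
import numpy as np

def n_components_for_variance(data, threshold=0.9):
    """Smallest number of principal components explaining >= threshold of variance."""
    X = np.asarray(data, dtype=float)
    X = X - X.mean(axis=0)
    std = X.std(axis=0)
    X = X / np.where(std > 0, std, 1.0)  # Standardize; constant features stay at 0
    s = np.linalg.svd(X, compute_uv=False)
    cumulative = np.cumsum(s ** 2) / np.sum(s ** 2)
    return min(int(np.searchsorted(cumulative, threshold)) + 1, len(s))

rng = np.random.default_rng(0)
factors = rng.normal(size=(500, 3))
loadings = rng.normal(size=(3, 40))
low_rank = factors @ loadings + 0.01 * rng.normal(size=(500, 40))
print(n_components_for_variance(low_rank))  # expected to be at most 3
```

Because a VAE latent space is nonlinear, treat this number as a lower-end hint and still compare candidates with the evaluation metrics.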

### 10.2 Loss weight adjustment strategies
- **recon_loss**: Reconstruction quality, typically set to 10
- **discriminator**: Batch mixing degree, reduce for many batches (0.1-0.3)
- **classifier**: Batch information retention, typically set to 1
- **kl_loss_1**, **kl_loss_2**: KL regularization strength; increase when overfitting
- **ortho_loss**: Orthogonality constraint, keep default
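For intuition about what the `kl_loss_*` weights scale: in a standard Gaussian VAE, each cell's KL divergence between the approximate posterior N(μ, σ²) and a standard normal prior has the closed form 0.5 · Σ(μ² + σ² − 1 − log σ²). This sketch assumes BioBatchNet uses that standard formulation, which this notebook does not confirm; the function name is hypothetical.

```python
import numpy as np

def gaussian_kl(mu, logvar):
    """Per-cell KL( N(mu, exp(logvar)) || N(0, I) ), summed over latent dims."""
    mu = np.asarray(mu, dtype=float)
    logvar = np.asarray(logvar, dtype=float)
    # 0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2)
    return 0.5 * np.sum(mu ** 2 + np.exp(logvar) - 1.0 - logvar, axis=-1)

print(gaussian_kl([[0.0, 0.0]], [[0.0, 0.0]]))  # → [0.]
print(gaussian_kl([[1.0, 0.0]], [[0.0, 0.0]]))  # → [0.5]
```

A larger KL weight pushes μ toward 0 and σ² toward 1, i.e. stronger regularization of the latent space, which matches the advice above to increase it when overfitting.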

### 10.3 What if training is unstable?
- Reduce learning rate
- Decrease batch_size
- Adjust loss weights
- Increase training epochs (a lower learning rate usually needs more epochs to converge)

### 10.4 What if memory is insufficient?
- Reduce batch_size
- Use CPU training: device='cpu'
- Reduce model complexity (fewer hidden layer nodes)
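A quick way to anticipate memory problems is to estimate the size of the dense input matrix; note that the AnnData example below calls `.toarray()`, which turns sparse data into a dense array. A sketch with a hypothetical helper name; it counts only the raw matrix, not model activations or optimizer state.

```python
import numpy as np

def dense_matrix_gb(n_rows, n_cols, dtype=np.float32):
    """Size in gigabytes (1e9 bytes) of a dense n_rows x n_cols array."""
    return n_rows * n_cols * np.dtype(dtype).itemsize / 1e9

print(dense_matrix_gb(5000, 2000))                   # → 0.04
print(dense_matrix_gb(100_000, 30_000))              # → 12.0
print(dense_matrix_gb(100_000, 30_000, np.float64))  # → 24.0
```

Keeping data in `float32` halves memory compared with `float64`, and subsetting genes before densifying shrinks the matrix proportionally.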

## 11. Real Data Example (Using AnnData)

In [None]:
# If you have a real .h5ad file, set its path below and run this cell
import os
import anndata as ad

h5ad_path = 'your_data.h5ad'  # Replace with the path to your own file

if os.path.exists(h5ad_path):
    # Load AnnData object
    adata = ad.read_h5ad(h5ad_path)

    # Extract the expression matrix (densify it if it is a sparse matrix)
    X = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X

    # Batch labels; adjust 'batch' to the column name used in your data
    real_batch_labels = adata.obs['batch'].values

    # Batch effect correction
    bio_embeddings_real, _ = correct_batch_effects(
        data=X,
        batch_info=real_batch_labels,
        data_type='scrna',  # or 'imc'
        epochs=200
    )

    # Store results back in AnnData and save
    adata.obsm['X_biobatchnet'] = bio_embeddings_real
    adata.write_h5ad('corrected_data.h5ad')
else:
    print("Real data processing workflow example (requires an h5ad file)")
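Batch columns in real AnnData objects are often strings or categoricals (for example donor or slide IDs). The model classes in Section 4 take an integer `num_batch`, so if you use them directly, mapping labels to consecutive integer codes is a safe preparation step. Whether `correct_batch_effects` does this for you is not shown in this notebook. A NumPy sketch with a hypothetical function name:

```python
import numpy as np

def encode_batches(labels):
    """Map arbitrary batch labels to codes 0..K-1; return codes and label order."""
    categories, codes = np.unique(np.asarray(labels), return_inverse=True)
    return codes.reshape(-1), categories

codes, categories = encode_batches(['slide_B', 'slide_A', 'slide_B', 'slide_C'])
print(codes, categories)  # → [1 0 1 2] ['slide_A' 'slide_B' 'slide_C']
num_batch = len(categories)
```

Keep `categories` alongside the codes so results can be mapped back to the original batch names.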

## Summary

BioBatchNet provides flexible batch effect correction solutions:

1. **Simple API** (`correct_batch_effects`): Suitable for quick use with automatic parameter selection
2. **Model Classes** (`IMCVAE`, `GeneVAE`): Provide more control for advanced users
3. **Custom Configuration**: Adjust model architecture, loss weights, and other parameters

### Key Features:
- Supports both IMC and scRNA-seq data types
- Automatically adjusts parameters based on batch count
- Customizable model architecture and training parameters
- Provides both biological and batch embedding outputs

### Best Practices:
1. Start with default parameters
2. Adjust loss weights based on results
3. Adjust model architecture if necessary
4. Use evaluation metrics to validate performance

For more information, see:
- GitHub: https://github.com/Manchester-HealthAI/BioBatchNet
- Documentation: USAGE.md