# 03 - Seurat Integration

Compare multiple Seurat integration methods: CCA, RPCA, Harmony, FastMNN.

This notebook drives R through rpy2 to run Seurat v5's `IntegrateLayers()` function on a single object whose RNA assay is split into one layer per batch.

## Workflow
1. Load merged h5ad
2. Convert to Seurat object in R
3. Preprocess in Seurat
4. Run multiple integration methods
5. Compare results

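## Outputs
All paths are relative to `config["output"]["dir"]` (default `./results/seurat/`):
- `integrated_seurat.h5ad` - AnnData object with the uncorrected and integrated embeddings, plus their UMAPs, in `.obsm`
- `figures/` - Method comparison plots: `comparison_by_batch.png`, `comparison_by_celltype.png` (only if `adata.obs` has a `celltype` column) and `metrics_comparison.png`
- `metrics/integration_metrics.csv` - Quantitative comparison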

In [None]:
import sys
sys.path.insert(0, "..")

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import yaml

# Local utilities
from utils.preprocessing import store_raw_counts
from utils.evaluation import compare_integration_methods
from utils.visualization import plot_method_comparison, plot_metrics_comparison

# Scanpy logging level and default figure appearance for this notebook
sc.settings.verbosity = 2
sc.settings.set_figure_params(facecolor="white", dpi=100)

In [None]:
# Load rpy2 and its IPython extension, which provides the %%R cell magic used below
%load_ext rpy2.ipython

import rpy2.robjects as ro
from rpy2.robjects import pandas2ri, numpy2ri
from rpy2.robjects.packages import importr

# Enable pandas <-> R data.frame and numpy <-> R array conversion globally.
# Later cells rely on this when assigning numpy/pandas objects into
# ro.globalenv (e.g. ro.conversion.py2rpy(metadata)) and when pulling values
# back with `%%R -o`.
# NOTE(review): activate() is deprecated in recent rpy2 releases in favour of
# explicit converter contexts; migrating means touching every cell that moves
# data between Python and R, so it is left as-is here.
pandas2ri.activate()
numpy2ri.activate()

## Configuration

In [None]:
# Notebook parameters: everything a reader might want to tune lives here.
config = dict(
    # Merged AnnData produced upstream, and its dataset-level obs column
    input=dict(
        h5ad_path="./results/merged.h5ad",
        batch_key="dataset",
    ),
    preprocessing=dict(
        n_top_genes=3000,
        regress_out=None,  # e.g., ["pct_counts_mt"]
        n_pcs=50,
    ),
    # `key` is the obs column whose values define the batches Seurat integrates;
    # `methods` lists the Seurat integration methods to compare.
    integration=dict(
        key="sample_id",
        seurat=dict(
            methods=["cca", "rpca", "harmony", "fastmnn"],
        ),
    ),
    clustering=dict(
        resolutions=[0.5],
        n_neighbors=30,
    ),
    # Root directory for the h5ad, figures/ and metrics/ outputs
    output=dict(
        dir="./results/seurat/",
    ),
)

In [None]:
# Pull the settings used throughout the notebook out of `config`
input_path = Path(config["input"]["h5ad_path"])
batch_key = config["input"]["batch_key"]
integration_key = config["integration"]["key"]
methods = config["integration"]["seurat"]["methods"]
output_dir = Path(config["output"]["dir"])

# Output tree: <output_dir>/ with figures/ and metrics/ subfolders
output_dir.mkdir(parents=True, exist_ok=True)
for subdir in ("figures", "metrics"):
    (output_dir / subdir).mkdir(exist_ok=True)

# Echo the resolved run settings
run_summary = (
    ("Input", input_path),
    ("Integration key", integration_key),
    ("Methods", methods),
    ("Output", output_dir),
)
for label, value in run_summary:
    print(f"{label}: {value}")

## Load data

In [None]:
print(f"Loading {input_path}...")
adata = sc.read_h5ad(input_path)
print(f"Shape: {adata.shape}")

# Validate the integration key up front. The R cells split the RNA assay by
# this column, so it must exist, have no missing values, and contain at least
# two batches (integrating a single batch is meaningless).
if integration_key not in adata.obs.columns:
    raise KeyError(
        f"Integration key {integration_key!r} not found in adata.obs "
        f"(available columns: {list(adata.obs.columns)})"
    )
if adata.obs[integration_key].isna().any():
    raise ValueError(f"adata.obs[{integration_key!r}] contains missing values")

n_batches = adata.obs[integration_key].nunique()
if n_batches < 2:
    raise ValueError(
        f"Need at least 2 batches in adata.obs[{integration_key!r}] to integrate, found {n_batches}"
    )
print(f"Batches ({integration_key}): {n_batches}")

In [None]:
# Ensure a "counts" layer exists; an existing one is left untouched.
# NOTE(review): store_raw_counts is assumed to copy the current adata.X, so
# adata.X should still hold raw counts here -- confirm for the input file.
counts_layer_present = "counts" in adata.layers
if not counts_layer_present:
    store_raw_counts(adata, layer_name="counts")

## Load R packages

In [None]:
%%R
# Attach Seurat and the integration backends:
#   harmony        - used by Seurat's HarmonyIntegration
#   SeuratWrappers - provides FastMNNIntegration
#   batchelor      - FastMNN backend
for (pkg in c("Seurat", "SeuratObject", "harmony", "SeuratWrappers", "batchelor")) {
    library(pkg, character.only = TRUE)
}

## Convert to Seurat and preprocess

In [None]:
# Extract count matrix and metadata for R
counts = adata.layers["counts"] if "counts" in adata.layers else adata.X
if hasattr(counts, "toarray"):
    counts = counts.toarray()

# Transpose for Seurat (genes x cells)
counts_t = counts.T

# Cell and gene names
cell_names = adata.obs_names.tolist()
gene_names = adata.var_names.tolist()

# Metadata
metadata = adata.obs.copy()

print(f"Count matrix shape: {counts_t.shape}")

In [None]:
# Push to R
ro.globalenv["counts_matrix"] = counts_t
ro.globalenv["cell_names"] = ro.StrVector(cell_names)
ro.globalenv["gene_names"] = ro.StrVector(gene_names)
ro.globalenv["metadata"] = ro.conversion.py2rpy(metadata)
ro.globalenv["integration_key"] = integration_key

In [None]:
%%R
# Label genes (rows) and cells (columns) of the count matrix, and key the
# metadata rows by the same cell names so Seurat can align them.
dimnames(counts_matrix) <- list(gene_names, cell_names)
rownames(metadata) <- cell_names

seurat_obj <- CreateSeuratObject(counts = counts_matrix, meta.data = metadata)
print(seurat_obj)

In [None]:
# Hand the preprocessing settings to R as plain values. rpy2's default
# converters do not handle a nested dict containing None, so `%%R -i config`
# cannot be used; the needed entries are unpacked here instead.
preprocessing_cfg = config["preprocessing"]
ro.globalenv["n_top_genes"] = int(preprocessing_cfg["n_top_genes"])
ro.globalenv["n_pcs"] = int(preprocessing_cfg["n_pcs"])

regress_out = preprocessing_cfg["regress_out"]
if isinstance(regress_out, str):
    regress_out = [regress_out]  # a bare string would otherwise be split into characters
# NULL is ScaleData's default for vars.to.regress (no regression)
ro.globalenv["vars_to_regress"] = ro.StrVector(regress_out) if regress_out else ro.NULL

In [None]:
%%R
# Seurat v5 integration works on an assay whose layers are split by batch.
# Split the RNA assay in place by the integration key (a plain vector, hence
# drop = TRUE), then normalise, select variable features, scale and run PCA
# once on the split object.
seurat_obj[[integration_key]] <- as.factor(seurat_obj[[integration_key, drop = TRUE]])
seurat_obj[["RNA"]] <- split(seurat_obj[["RNA"]], f = seurat_obj[[integration_key, drop = TRUE]])

seurat_obj <- NormalizeData(seurat_obj)
seurat_obj <- FindVariableFeatures(seurat_obj, nfeatures = n_top_genes)
seurat_obj <- ScaleData(seurat_obj, vars.to.regress = vars_to_regress)
seurat_obj <- RunPCA(seurat_obj, npcs = n_pcs, reduction.name = "pca")

print("Split and re-processed")

## Run Integration Methods

In [None]:
# Pass the configured method list to R. A plain Python list has no default
# rpy2 conversion, so it is wrapped explicitly in a character vector.
ro.globalenv["integration_methods"] = ro.StrVector(methods)

In [None]:
%%R
# Run each configured Seurat v5 integration on the split-layer object.
# Integration functions are looked up by name only when requested, so e.g.
# FastMNNIntegration (SeuratWrappers) is resolved only if "fastmnn" is listed.
# Reduction names must match what the UMAP and extraction cells expect.
integration_specs <- list(
    cca     = list(label = "CCA",     fn = "CCAIntegration",     reduction = "integrated_cca"),
    rpca    = list(label = "RPCA",    fn = "RPCAIntegration",    reduction = "integrated_rpca"),
    harmony = list(label = "Harmony", fn = "HarmonyIntegration", reduction = "integrated_harmony"),
    fastmnn = list(label = "FastMNN", fn = "FastMNNIntegration", reduction = "integrated_mnn")
)

# Reject typos in config instead of silently skipping them
unknown_methods <- setdiff(integration_methods, names(integration_specs))
if (length(unknown_methods) > 0) {
    stop(
        "Unknown integration method(s): ", paste(unknown_methods, collapse = ", "),
        ". Supported: ", paste(names(integration_specs), collapse = ", ")
    )
}

# Iterate in the fixed order above so results do not depend on config order
for (method_name in intersect(names(integration_specs), integration_methods)) {
    spec <- integration_specs[[method_name]]
    message("Running ", spec$label, " integration...")
    seurat_obj <- IntegrateLayers(
        object = seurat_obj,
        method = get(spec$fn, mode = "function"),
        orig.reduction = "pca",
        new.reduction = spec$reduction
    )
    message(spec$label, " complete")
}

## Compute UMAP for each method

In [None]:
%%R
# Join the per-batch layers back together now that integration (which needs
# split layers) is finished.
seurat_obj <- JoinLayers(seurat_obj)

# Embedding -> UMAP reduction name, in the order the UMAPs are computed.
# RunUMAP works directly from each embedding, so no neighbour graph is built.
umap_targets <- c(
    integrated_cca     = "umap_cca",
    integrated_rpca    = "umap_rpca",
    integrated_harmony = "umap_harmony",
    integrated_mnn     = "umap_mnn",
    pca                = "umap_uncorrected"
)
# Leading components per UMAP; capped at each reduction's width so a
# lower-dimensional embedding does not make RunUMAP fail.
umap_max_dims <- 30

for (red in names(umap_targets)) {
    if (!(red %in% names(seurat_obj@reductions))) next  # method was not run
    umap_name <- umap_targets[[red]]
    n_dims <- min(umap_max_dims, ncol(Embeddings(seurat_obj, reduction = red)))
    message(paste("Computing UMAP for", red))
    seurat_obj <- RunUMAP(
        seurat_obj,
        reduction = red,
        dims = seq_len(n_dims),
        reduction.name = umap_name,
        # a distinct key per UMAP avoids Seurat's "key already in use" renaming
        reduction.key = paste0(gsub("_", "", umap_name), "_")
    )
}

print("UMAPs computed")

## Extract embeddings back to Python

In [None]:
%%R -o embeddings_dict -o cell_order
# Collect every reduction's cell embeddings as a named list of matrices.
# Rows are forced into the object's cell order (`cell_order`), so the Python
# side can assign them to adata.obsm positionally after reordering adata.
cell_order <- colnames(seurat_obj)

embeddings_dict <- lapply(
    setNames(nm = names(seurat_obj@reductions)),
    function(red_name) {
        emb <- Embeddings(seurat_obj, reduction = red_name)
        as.matrix(emb[cell_order, , drop = FALSE])
    }
)

print(names(embeddings_dict))

In [None]:
# Store embeddings in adata.
# Reorder adata to Seurat's cell order; the extraction cell emits every
# embedding's rows in that order, so the assignment below is positional.
cell_order_py = [str(cell) for cell in cell_order]
adata = adata[cell_order_py].copy()

# Map R reduction names to adata obsm keys
name_map = {
    "pca": "X_pca",
    "umap_uncorrected": "X_umap_uncorrected",
    "integrated_cca": "X_seurat_cca",
    "integrated_rpca": "X_seurat_rpca",
    "integrated_harmony": "X_seurat_harmony",
    "integrated_mnn": "X_seurat_mnn",
    "umap_cca": "X_umap_cca",
    "umap_rpca": "X_umap_rpca",
    "umap_harmony": "X_umap_harmony",
    "umap_mnn": "X_umap_mnn",
}

# `.items()` is available both on rpy2's R-list wrapper and on plain Python
# mappings, so this works however `%%R -o` converted the named list.
r_embeddings = dict(embeddings_dict.items())

for r_name, py_name in name_map.items():
    if r_name not in r_embeddings:
        continue  # reduction was not computed in R
    # np.array copies the values out of R-owned memory
    embedding = np.array(r_embeddings[r_name], dtype=float)
    if embedding.ndim != 2 or embedding.shape[0] != adata.n_obs:
        raise ValueError(
            f"Embedding {r_name!r} has shape {embedding.shape}; "
            f"expected ({adata.n_obs}, n_components)"
        )
    adata.obsm[py_name] = embedding
    print(f"Stored {r_name} as {py_name}: {adata.obsm[py_name].shape}")

## Compare Integration Methods

In [None]:
# UMAP coordinates to visualise, one panel per method (label -> obsm key)
umap_candidates = [
    ("Uncorrected", "X_umap_uncorrected"),
    ("CCA", "X_umap_cca"),
    ("RPCA", "X_umap_rpca"),
    ("Harmony", "X_umap_harmony"),
    ("FastMNN", "X_umap_mnn"),
]

# Keep only the methods whose UMAP was actually stored
embeddings = {label: obsm_key for label, obsm_key in umap_candidates if obsm_key in adata.obsm}

In [None]:
# One UMAP panel per method, coloured by batch (batches should overlap after integration)
batch_plot_path = output_dir / "figures" / "comparison_by_batch.png"
fig = plot_method_comparison(
    adata,
    embeddings=embeddings,
    color_by=integration_key,
    ncols=3,
    save_path=batch_plot_path,
)
plt.show()

In [None]:
# Same panels coloured by cell type, only when annotations are available
has_celltype = "celltype" in adata.obs.columns
if has_celltype:
    celltype_plot_path = output_dir / "figures" / "comparison_by_celltype.png"
    fig = plot_method_comparison(
        adata,
        embeddings=embeddings,
        color_by="celltype",
        ncols=3,
        save_path=celltype_plot_path,
    )
    plt.show()

## Compute Metrics

In [None]:
# Metrics are computed on the integrated latent spaces, not on the 2-D UMAPs
latent_candidates = [
    ("Uncorrected", "X_pca"),
    ("CCA", "X_seurat_cca"),
    ("RPCA", "X_seurat_rpca"),
    ("Harmony", "X_seurat_harmony"),
    ("FastMNN", "X_seurat_mnn"),
]
integrated_embeddings = {
    label: obsm_key for label, obsm_key in latent_candidates if obsm_key in adata.obsm
}

# Score every available embedding against the integration batch column
metrics_df = compare_integration_methods(
    adata,
    batch_key=integration_key,
    embeddings=integrated_embeddings,
)

display(metrics_df)

In [None]:
# Persist the metrics table next to the figures
metrics_path = output_dir / "metrics" / "integration_metrics.csv"
metrics_df.to_csv(metrics_path)
print(f"Metrics saved to {metrics_path}")

In [None]:
# Visual summary of the metrics table
metrics_plot_path = output_dir / "figures" / "metrics_comparison.png"
fig = plot_metrics_comparison(metrics_df, save_path=metrics_plot_path)
plt.show()

## Save Results

In [None]:
# Write the AnnData carrying all Seurat-derived embeddings in .obsm
output_path = output_dir / "integrated_seurat.h5ad"
print(f"Saving to {output_path}...")
adata.write_h5ad(filename=output_path)
print("Done!")

## Summary

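Seurat integration comparison complete. The saved object contains these embeddings in `.obsm` (a method's keys are present only if that integration was run):

- `X_pca` - Uncorrected PCA
- `X_seurat_cca` - CCA integrated
- `X_seurat_rpca` - RPCA integrated
- `X_seurat_harmony` - Harmony integrated
- `X_seurat_mnn` - FastMNN integrated

It also contains the corresponding UMAPs for visualization: `X_umap_uncorrected`, `X_umap_cca`, `X_umap_rpca`, `X_umap_harmony` and `X_umap_mnn`.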

Review the metrics and visualizations to select the best method for your data.