# MuVIcell Tutorial: Multi-View Integration for Sample-Aggregated Single-Cell Data

This notebook demonstrates how to use the MuVIcell package for multi-view integration and analysis of sample-aggregated single-cell data using MuVI (Multi-View Integration).

## Overview

MuVIcell provides a streamlined workflow for:
1. **Generating/Loading** multi-view data in muon format (samples x features)
2. **Preprocessing** data for MuVI analysis
3. **Running MuVI** to identify latent factors using `muvi.tl.from_mdata`
4. **Analyzing** and interpreting factors
5. **Visualizing** results

Note: Each row represents a **sample** (not individual cells) and views contain **cell type aggregated data per sample**.

In [None]:
import muvicell
muvicell.__version__

In [None]:
import tensordict
tensordict.__version__

In [None]:
import numpy as np
import pandas as pd
import muon as mu
import warnings
warnings.filterwarnings('ignore')

from plotnine import *
import scanpy as sc

# Import muvicell modules
from muvicell import synthetic, preprocessing, analysis, visualization

# Import MuVI directly to show compatibility
import muvi

In [None]:
# Default to the CPU; use the CUDA index returned by muvi.get_free_gpu_idx()
# only if that call succeeds.
device = "cpu"
try:
    device = f"cuda:{muvi.get_free_gpu_idx()}"
except Exception as e:
    # Best-effort lookup: on any failure, print the reason and keep "cpu".
    print(e)

## 1. Generate Synthetic Multi-View Data

Generate synthetic data with 3 views (5, 10, 15 features) and 200 samples:

In [None]:
# Generate synthetic multi-view data: 200 samples and three views with
# 5, 10 and 15 features, each view with its own `sparsity` setting.
mdata = synthetic.generate_synthetic_data(
    n_samples=200,
    view_configs={
        'Cell Type 1': {'n_vars': 5, 'sparsity': 0.15},
        'Cell Type 2': {'n_vars': 10, 'sparsity': 0.25},
        'Cell Type 3': {'n_vars': 15, 'sparsity': 0.35}
    }
)

# Summarize the generated object: sample count, per-view feature counts, total.
print(f"Generated synthetic data:")
print(f"- Samples: {mdata.n_obs}")
print(f"- Views: {len(mdata.mod)} ({', '.join([f'{k}: {v.n_vars} features' for k, v in mdata.mod.items()])})")
print(f"- Total features: {sum(v.n_vars for v in mdata.mod.values())}")

## 2. Add Latent Factor Structure

Add realistic latent factor structure to the synthetic data:

In [None]:
# Add latent structure with 3 factors; `factor_variance` lists one value per factor.
# NOTE(review): the exact meaning of structure_strength / baseline_strength is
# defined by synthetic.add_latent_structure — confirm against its documentation.
mdata_structured = synthetic.add_latent_structure(
    mdata, 
    n_latent_factors = 3,
    factor_variance = [0.5, 0.4, 0.3],
    structure_strength = 1.0,
    baseline_strength = 0.6
)

# Show which sample-level metadata columns are now available in .obs.
print(f"Sample metadata columns: {list(mdata_structured.obs.columns)}")

In [None]:
# Run PCA and build a neighbor graph separately for every view.
for mod in mdata_structured.mod:
    # Highly variable features can be used if there's enough of them
    sc.pp.pca(mdata_structured[mod], 
              use_highly_variable=False)
    sc.pp.neighbors(mdata_structured[mod])

# Then compute neighbors on the MuData object as a whole (muon's multi-view variant).
mu.pp.neighbors(mdata_structured)

In [None]:
# Choose the view for the single-view UMAP explicitly. This is the last view of
# the MuData object, i.e. the same view the leftover loop variable `mod` used to
# point at, but the cell no longer depends on a variable leaked from another cell.
umap_view = list(mdata_structured.mod)[-1]

# Copy the simulated per-sample factor values (sample-level .obs columns)
# into that view so they can be used to color its embedding.
for var in ['sim_factor_1', 'sim_factor_2', 'sim_factor_3']:
    mdata_structured[umap_view].obs[var] = mdata_structured.obs[var]

sc.tl.umap(mdata_structured[umap_view])
sc.pl.umap(mdata_structured[umap_view], color=['batch', 
                                               'sim_factor_1',
                                               'sim_factor_2',
                                               'sim_factor_3'])

In [None]:
# Compute and plot a UMAP for the whole MuData object, colored by batch and by
# the three simulated factor columns stored in .obs.
mu.tl.umap(mdata_structured)
mu.pl.umap(mdata_structured, wspace=0.3, color=['batch', 
                                                'sim_factor_1',
                                                'sim_factor_2',
                                                'sim_factor_3'])

With these parameters, we create 3 latent factors with specified variances, a strong structured signal, and a moderate baseline signal across all features.

In [None]:
# Let's say an exposure variable is highly correlated with factor 1.
# For each sample the label is drawn at random, with probabilities that depend
# on sim_factor_1 (x):
#   x > 1       -> mostly "Hi"     (p = 0.7 / 0.2 / 0.1 for Hi / Medium / Low)
#   0 < x <= 1  -> mostly "Medium" (p = 0.1 / 0.7 / 0.2)
#   x <= 0      -> mostly "Low"    (p = 0.1 / 0.2 / 0.7)
# NOTE(review): no random seed is set in this cell, so the drawn labels may
# differ between runs unless a seed is set elsewhere.
mdata_structured.obs['exposure'] = [np.random.choice(["Hi", "Medium", "Low"], p = [0.7,0.2,0.1]) if x > 1 else
                                    np.random.choice(["Hi", "Medium", "Low"], p = [0.1,0.7,0.2]) if x > 0 else
                                    np.random.choice(["Hi", "Medium", "Low"], p = [0.1,0.2,0.7]) 
                                    for x in mdata_structured.obs['sim_factor_1']]
# Store the labels as an ordered categorical: Low < Medium < Hi.
mdata_structured.obs['exposure'] = pd.Categorical(mdata_structured.obs['exposure'], 
                                                  categories=["Low", "Medium", "Hi"], 
                                                  ordered=True)
# Display how many samples fall into each level.
mdata_structured.obs['exposure'].value_counts()

## 3. Preprocess Data for MuVI

Apply preprocessing pipeline (optimized for synthetic data):

In [None]:
# Preprocess for MuVI analysis: only normalization is switched on; filtering
# and highly-variable-feature selection are disabled for the synthetic data.
mdata_processed = preprocessing.preprocess_for_muvi(
    mdata_structured,
    filter_cells=False,  # Don't filter synthetic data
    filter_genes=False,  # Don't filter synthetic data
    normalize=True,
    find_hvg=False,      # Skip HVG for synthetic data
    subset_hvg=False
)

# Report the shape of the preprocessed object.
print(f"Preprocessed data shape: {mdata_processed.shape}")
print("Data ready for MuVI analysis")

## 4. Run MuVI Analysis

Run MuVI using the exact same API as the original analysis, with 3 factors to match our synthetic data:

In [None]:
# Build a MuVI model from the preprocessed MuData object using the standard API:
# 3 factors, nmf disabled, running on the device selected earlier.
model = muvi.tl.from_mdata(
    mdata_processed,
    n_factors=3,
    nmf=False,
    device=device
)

# Fit the model
model.fit()

# Confirm the number of factors held by the fitted model.
print(f"MuVI model fitted with {model.n_factors} factors")

In [None]:
# Reconstruction quality per view: rebuild each view as scores @ loadings and
# correlate the flattened reconstruction with the flattened input matrix.
# The model getters are called once, outside the loop, instead of once per view.
factor_scores = model.get_factor_scores()
factor_loadings = model.get_factor_loadings()

r2_pool = []  # one Pearson correlation (R, not yet squared) per view
for vn, view_loadings in factor_loadings.items():
    rec = factor_scores @ view_loadings
    corr_matrix = pd.DataFrame({'x': mdata_processed[vn].X.flatten(),
                                'y': rec.flatten()}).corr()
    r2_pool.append(corr_matrix.iloc[0, 1])
# Square each per-view correlation, then average across views.
print(f"Macro R2: {np.mean(np.square(r2_pool))}")

# Check factor scores
print(f"Factor scores shape: {factor_scores.shape}")

## (Bonus) Confirm the factors recovered match the simulation parameters
This is only possible here since we generated the data ourselves, and cannot be done in real applications.

In [None]:
# Put the simulated (true) factors and the inferred factor scores side by side.
# Column counts are taken from the arrays themselves instead of a hard-coded 3,
# so the cell keeps working if the number of simulated factors changes.
true_factors = np.asarray(mdata_processed.obsm['true_factors'])
inferred_factors = np.asarray(factor_scores)
factors_df = pd.DataFrame(
    np.hstack([true_factors, inferred_factors]),
    columns=[f"True_Factor_{i+1}" for i in range(true_factors.shape[1])]
    + [f"MuVI_Factor_{i+1}" for i in range(inferred_factors.shape[1])]
)
# Rank correlation between every pair of true / inferred factors.
corr_factors = factors_df.corr(method='spearman')
corr_factors

We see that many of the true factors are well recovered, with a strong positive or negative correlation (|ρ| > 0.5) between true and inferred factor scores; the sign of an inferred factor is arbitrary, so a strong negative correlation also counts as recovery. Some effects are split across multiple inferred factors, because different combinations of independent factors can explain the same variance.

## 5. Characterize Factors
Identify top genes contributing to each factor:

In [None]:
# =============================================================================
# MuVI reusable utilities: info + plot pairs
# =============================================================================
from typing import Dict, List, Optional, Sequence, Tuple, Callable
import numpy as np
import pandas as pd

from plotnine import (
    ggplot, aes, geom_col, geom_rect, geom_text, geom_tile, geom_violin, geom_path,
    scale_fill_gradientn, scale_fill_gradient2, scale_fill_manual,
    scale_x_continuous, scale_y_continuous,
    theme_classic, theme, element_text, coord_fixed, coord_flip, coord_equal,
    labs, ggtitle, guides
)
from plotnine import ggsave

# If you use liana, keep the import; the code below does not require it.
# import liana as li

# -----------------------------------------------------------------------------
# Helpers
# -----------------------------------------------------------------------------

def _to_factor_labels(names: Sequence[str]) -> List[str]:
    """
    Map ['factor_0','factor_1', ...] to ['Factor 1','Factor 2', ...] if pattern matches,
    otherwise return names unchanged.
    """
    out = []
    for n in names:
        try:
            if isinstance(n, str) and "factor_" in n:
                idx = int(n.split("_")[1])
                out.append(f"Factor {idx+1}")
            else:
                out.append(n)
        except Exception:
            out.append(n)
    return out

def _rename_factor_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Rename factor_* columns to human readable Factor i."""
    renamed = df.copy()
    newcols = []
    for c in renamed.columns:
        if isinstance(c, str) and c.startswith("factor_"):
            try:
                newcols.append(f"Factor {int(c.split('_')[1]) + 1}")
            except Exception:
                newcols.append(c)
        else:
            newcols.append(c)
    renamed.columns = newcols
    return renamed

def _ggsave_if(p, save_path: Optional[str], width: float = 6, height: float = 4, dpi: int = 300, verbose: bool = False):
    if save_path:
        ggsave(save_path, plot=p, width=width, height=height, dpi=dpi, verbose=verbose)

def _nan_pearsonr(x: np.ndarray, y: np.ndarray) -> float:
    m = np.isfinite(x) & np.isfinite(y)
    if m.sum() < 3:
        return np.nan
    return np.corrcoef(x[m], y[m])[0, 1]

def _factor_names_from_model(model) -> List[str]:
    # Try to pull names from factor scores df if available
    try:
        fs = model.get_factor_scores(as_df=True)
        cols = list(fs.columns)
        return cols
    except Exception:
        # Fallback to factor_0..factor_{k-1}
        try:
            k = model.n_factors
        except Exception:
            k = 10
        return [f"factor_{i}" for i in range(k)]

def _view_names(model, mdata) -> List[str]:
    for attr in ["view_names", "views", "modalities"]:
        if hasattr(model, attr):
            v = getattr(model, attr)
            return list(v)
    return list(mdata.mod.keys())

# -----------------------------------------------------------------------------
# 1) Reconstruction R2 per view
# -----------------------------------------------------------------------------

def muvi_reconstruction_info(model, mdata, views: Optional[Sequence[str]] = None, verbosity: int = 1) -> dict:
    """
    Compute R and R2 between original X and reconstructed X = scores @ loadings, per view.

    Parameters
    ----------
    model : object exposing get_factor_scores() and get_factor_loadings().
    mdata : container indexable by view name; mdata[view].X is the original matrix.
    views : view names to evaluate; defaults to the names reported by _view_names.
    verbosity : if truthy, print the macro averages and the per-view table.

    Returns
    -------
    dict with keys:
      'by_view': DataFrame[view, R, R2]
      'macro': {'R': macro_R, 'R2': macro_R2}  (means across views, NaN skipped)
    """
    if views is None:
        views = _view_names(model, mdata)

    # scores: (n_obs x n_factors)
    scores = model.get_factor_scores()
    if isinstance(scores, pd.DataFrame):
        scores = scores.to_numpy()

    loadings = model.get_factor_loadings()  # dict or similar keyed by view
    per_view = []
    for v in views:
        rec = scores @ loadings[v]  # reconstruction, (n_obs x n_features_v)
        X = mdata[v].X  # original, (n_obs x n_features_v)
        # Sparse matrices expose toarray(); np.asarray on them would not densify.
        if hasattr(X, "toarray"):
            X = X.toarray()
        r = _nan_pearsonr(np.asarray(X).ravel(), np.asarray(rec).ravel())
        # Use NaN (not None) for an undefined R2 so the column stays numeric.
        per_view.append({"view": v, "R": r, "R2": np.nan if pd.isna(r) else r * r})

    # Explicit columns keep the frame well-formed even when `views` is empty.
    df = pd.DataFrame(per_view, columns=["view", "R", "R2"])
    macro_R = df["R"].mean(skipna=True)
    macro_R2 = df["R2"].mean(skipna=True)

    if verbosity:
        print("Reconstruction macro R:", np.round(macro_R, 3))
        print("Reconstruction macro R2:", np.round(macro_R2, 3))
        print(df.sort_values("R2", ascending=False).to_string(index=False))

    return {"by_view": df, "macro": {"R": macro_R, "R2": macro_R2}}

def muvi_reconstruction_plot(stats_df: pd.DataFrame,
                             title: str = "Reconstruction R2 by view",
                             save_path: Optional[str] = None,
                             width: float = 6, height: float = 4, dpi: int = 300):
    """
    Draw one bar per view showing its R2.

    `stats_df` must provide the columns 'view' and 'R2'. The plot is written to
    `save_path` when one is given, and the ggplot object is returned either way.
    """
    canvas = ggplot(stats_df, aes(x="view", y="R2"))
    tilted_labels = theme(axis_text_x=element_text(angle=45, hjust=1))
    annotations = labs(title=title, x="View", y="R2")

    p = canvas + geom_col() + theme_classic() + tilted_labels + annotations
    _ggsave_if(p, save_path, width, height, dpi)
    return p

# -----------------------------------------------------------------------------
# 2) Variance explained by factors per view with marginal sums
# -----------------------------------------------------------------------------

def muvi_variance_by_view_info(model, view_name_transform: Optional[Callable[[str], str]] = None,
                               verbosity: int = 1) -> pd.DataFrame:
    """
    Tidy the per-view variance explained into long form.

    Takes element [1] of muvi.tl.variance_explained(model) (factors x views),
    melts it into the columns [Factor, View, Variance], relabels factors via
    _to_factor_labels and optionally maps view names through
    `view_name_transform`. 'Factor' is returned as an ordered categorical whose
    order is the order of first appearance.
    """
    import muvi as _muvi

    per_view = _muvi.tl.variance_explained(model)[1].copy()  # DataFrame factors x views
    per_view.index.name = "Factor"
    tidy = per_view.reset_index().melt(id_vars="Factor", var_name="View", value_name="Variance")

    # Human friendly factor labels where possible
    tidy["Factor"] = _to_factor_labels(tidy["Factor"])

    # Optional view transformation
    if view_name_transform is not None:
        tidy["View"] = tidy["View"].map(view_name_transform)

    # Freeze the factor order as it appears in the data
    appearance_order = pd.unique(tidy["Factor"])
    tidy["Factor"] = pd.Categorical(tidy["Factor"], categories=appearance_order, ordered=True)

    if verbosity:
        print(tidy.head().to_string(index=False))
    return tidy

def muvi_variance_by_view_plot(df: pd.DataFrame,
                               subtitle: Optional[str] = None,
                               save_path: Optional[str] = None,
                               width: float = 6, height: float = 5, dpi: int = 300):
    """
    Heatmap of variance per (Factor, View) with marginal sum bars.

    Parameters
    ----------
    df : long DataFrame with columns 'Factor', 'View' and 'Variance'.
    subtitle : optional plot subtitle.
    save_path : if truthy, the plot is also written there via _ggsave_if.
    width, height, dpi : forwarded to _ggsave_if.

    Returns
    -------
    The ggplot object. Factors are on the x axis, views on the y axis (sorted by
    their total variance); grey bars above and to the right show the per-factor
    and per-view sums, scaled to the largest sum and labelled with the value
    rounded to two decimals.
    """
    # Compute marginals
    row_sums = df.groupby("View", as_index=False)["Variance"].sum()
    row_sums["Factor"] = "Sum"

    col_sums = df.groupby("Factor", as_index=False)["Variance"].sum()
    col_sums["View"] = "Sum"

    # sort views descending by total variance
    sorted_views = row_sums.sort_values("Variance", ascending=False)["View"].tolist()

    # Prepare extended frame with 'Sum' row and col
    # Keep an existing categorical order for factors; otherwise use order of appearance.
    factor_levels = list(df["Factor"].cat.categories if isinstance(df["Factor"].dtype, pd.CategoricalDtype) else df["Factor"].unique())
    factor_levels = list(factor_levels) + ["Sum"]
    view_levels = sorted_views + ["Sum"]

    dfm = pd.concat([df, row_sums, col_sums], ignore_index=True)
    dfm["Factor"] = pd.Categorical(dfm["Factor"], categories=factor_levels, ordered=True)
    dfm["View"] = pd.Categorical(dfm["View"], categories=view_levels, ordered=True)

    # Split for plotting: heatmap cells, per-view sums (rowb), per-factor sums (colb)
    main = dfm[(dfm["Factor"] != "Sum") & (dfm["View"] != "Sum")].copy()
    rowb = dfm[(dfm["Factor"] == "Sum") & (dfm["View"] != "Sum")].copy()
    colb = dfm[(dfm["View"] == "Sum") & (dfm["Factor"] != "Sum")].copy()

    # Normalize bar lengths so the largest marginal bar has length 1
    rowb["bar_length"] = rowb["Variance"] / rowb["Variance"].max()
    colb["bar_length"] = colb["Variance"] / colb["Variance"].max()
    rowb["Variance_label"] = rowb["Variance"].round(2).astype(str)
    colb["Variance_label"] = colb["Variance"].round(2).astype(str)

    # Positions: integer coordinate for every factor / view (excluding 'Sum')
    fac_no_sum = [x for x in factor_levels if x != "Sum"]
    view_no_sum = [x for x in view_levels if x != "Sum"]
    fpos = {k: i for i, k in enumerate(fac_no_sum)}
    vpos = {k: i for i, k in enumerate(view_no_sum)}

    # NOTE(review): 'Factor' and 'View' are categorical here; the arithmetic below
    # assumes .map() yields numeric positions — confirm with the pandas version in use.
    main["x"] = main["Factor"].map(fpos)
    main["y"] = main["View"].map(vpos)
    colb["x"] = colb["Factor"].map(fpos)
    rowb["y"] = rowb["View"].map(vpos)

    # tile coords: unit squares centered on the integer positions
    main["xmin"] = main["x"] - 0.5
    main["xmax"] = main["x"] + 0.5
    main["ymin"] = main["y"] - 0.5
    main["ymax"] = main["y"] + 0.5

    # top bars: start just above the last view row and grow upwards
    colb["xmin"] = colb["x"] - 0.5
    colb["xmax"] = colb["x"] + 0.5
    colb["ymin"] = len(view_no_sum) - 0.5
    colb["ymax"] = colb["ymin"] + colb["bar_length"]

    # right bars: start just right of the last factor column and grow rightwards
    rowb["ymin"] = rowb["y"] - 0.5
    rowb["ymax"] = rowb["y"] + 0.5
    rowb["xmin"] = len(fac_no_sum) - 0.5
    rowb["xmax"] = rowb["xmin"] + rowb["bar_length"]

    # Text labels sit at a fixed offset (0.2) beyond the edge of the heatmap tiles.
    p = (
        ggplot()
        + geom_rect(main, aes(xmin="xmin", xmax="xmax", ymin="ymin", ymax="ymax", fill="Variance"))
        + geom_rect(colb, aes(xmin="xmin", xmax="xmax", ymin="ymin", ymax="ymax"), fill="#acabab")
        + geom_text(colb, aes(x="x", y=main["ymax"].max() + 0.2, label="Variance_label"), va="bottom", size=8)
        + geom_rect(rowb, aes(xmin="xmin", xmax="xmax", ymin="ymin", ymax="ymax"), fill="#acabab")
        + geom_text(rowb, aes(x=main["xmax"].max() + 0.2, y="y", label="Variance_label"), ha="left", size=8)
        + scale_fill_gradientn(colors=["#EFF822", "#CC4977", "#0F0782"])
        + scale_x_continuous(breaks=list(range(len(fac_no_sum))), labels=fac_no_sum)
        + scale_y_continuous(breaks=list(range(len(view_no_sum))), labels=view_no_sum)
        + theme_classic()
        + theme(axis_text_x=element_text(angle=45, hjust=1))
        + labs(title="Variance explained by MuVI factors", subtitle=subtitle, x="Factor", y="View")
    )
    _ggsave_if(p, save_path, width, height, dpi)
    return p

# -----------------------------------------------------------------------------
# 3) Feature class vs factor variance explained (aggregated across views)
# -----------------------------------------------------------------------------

def muvi_featureclass_variance_info(model, mdata,
                                    feature_type_map: Dict[str, List[str]],
                                    aggregator: str = "median",
                                    verbosity: int = 1) -> pd.DataFrame:
    """
    Compute variance explained per factor for each feature class across views.

    Parameters
    ----------
    model : fitted model passed on to muvi.tl.variance_explained.
    mdata : container indexable by view name; mdata[view].var_names lists its features.
    feature_type_map : {'Class name': [feature1, feature2, ...]}
    aggregator : 'median' or 'mean'; how per-view values are combined per factor.
    verbosity : if truthy, print the head of the result.

    Returns
    -------
    Long DataFrame [Factor, Feature_type, Variance]; 'Factor' is an ordered
    categorical sorted by factor number. Classes with no feature in any view are
    omitted; if nothing matches, an empty frame with these columns is returned.

    Raises
    ------
    ValueError if `aggregator` is not 'median' or 'mean'.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if aggregator not in {"median", "mean"}:
        raise ValueError(f"aggregator must be 'median' or 'mean', got {aggregator!r}")
    combine = np.median if aggregator == "median" else np.mean

    import muvi as _muvi  # local import to avoid collisions

    views = _view_names(model, mdata)
    # Build each view's feature set once; membership tests are then O(1).
    features_by_view = {v: set(mdata[v].var_names) for v in views}

    rows = []
    for cls, feats in feature_type_map.items():
        per_view_vals = []
        factor_index = None
        for v in views:
            feats_in_view = [f for f in feats if f in features_by_view[v]]
            if not feats_in_view:
                continue
            # r2[1] is per-factor variance for the selected features in this view
            r2 = _muvi.tl.variance_explained(model, view_idx=v, feature_idx=feats_in_view, cache=False, sort=False)[1]
            per_view_vals.append(r2.values.reshape(-1, 1))  # factors x 1
            factor_index = list(r2.index)

        if not per_view_vals:
            continue

        stacked = np.hstack(per_view_vals)  # factors x n_views_present
        agg = combine(stacked, axis=1)
        factor_labels = _to_factor_labels(factor_index)
        rows.extend({"Factor": f, "Feature_type": cls, "Variance": val} for f, val in zip(factor_labels, agg))

    # Explicit columns keep the frame usable when no class matched any view.
    out = pd.DataFrame(rows, columns=["Factor", "Feature_type", "Variance"])

    def _factor_rank(label) -> int:
        # 'Factor <n>' sorts by n; anything else sorts first.
        text = str(label)
        if not text.startswith("Factor"):
            return 0
        try:
            return int(text.split()[-1])
        except ValueError:
            return 0

    ordered_labels = sorted(set(out["Factor"]), key=_factor_rank)
    out["Factor"] = pd.Categorical(out["Factor"], categories=ordered_labels, ordered=True)

    if verbosity:
        print(out.head().to_string(index=False))
    return out

def muvi_featureclass_variance_plot(df: pd.DataFrame,
                                    save_path: Optional[str] = None,
                                    width: float = 5, height: float = 5, dpi: int = 300):
    """
    Heatmap of variance per (Factor, Feature_type) with marginal sum bars.

    Parameters
    ----------
    df : long DataFrame with columns 'Factor', 'Feature_type' and 'Variance'.
    save_path : if truthy, the plot is also written there via _ggsave_if.
    width, height, dpi : forwarded to _ggsave_if.

    Returns
    -------
    The ggplot object. Factors are on the x axis, feature types on the y axis
    (sorted by their total variance); grey bars above and to the right show the
    per-factor and per-feature-type sums, scaled to the largest sum and labelled
    with the value rounded to two decimals.
    """
    # marginals
    row_sums = df.groupby("Feature_type", as_index=False)["Variance"].sum()
    row_sums["Factor"] = "Sum"
    col_sums = df.groupby("Factor", as_index=False)["Variance"].sum()
    col_sums["Feature_type"] = "Sum"

    # order feature types by total variance, largest first
    sorted_ft = row_sums.sort_values("Variance", ascending=False)["Feature_type"].tolist()

    # categories: factors in order of appearance, plus the synthetic 'Sum' level
    factor_levels = list(pd.unique(df["Factor"])) + ["Sum"]
    ft_levels = sorted_ft + ["Sum"]

    dfx = pd.concat([df, row_sums, col_sums], ignore_index=True)
    dfx["Factor"] = pd.Categorical(dfx["Factor"], categories=factor_levels, ordered=True)
    dfx["Feature_type"] = pd.Categorical(dfx["Feature_type"], categories=ft_levels, ordered=True)

    # Split into heatmap cells, per-feature-type sums (rowb) and per-factor sums (colb)
    main = dfx[(dfx["Factor"] != "Sum") & (dfx["Feature_type"] != "Sum")].copy()
    rowb = dfx[(dfx["Factor"] == "Sum") & (dfx["Feature_type"] != "Sum")].copy()
    colb = dfx[(dfx["Feature_type"] == "Sum") & (dfx["Factor"] != "Sum")].copy()

    # Scale marginal bars so the largest has length 1; labels show the raw sums
    rowb["bar_length"] = rowb["Variance"] / rowb["Variance"].max()
    colb["bar_length"] = colb["Variance"] / colb["Variance"].max()
    rowb["Variance_label"] = rowb["Variance"].round(2).astype(str)
    colb["Variance_label"] = colb["Variance"].round(2).astype(str)

    # Integer coordinate for every factor / feature type (excluding 'Sum')
    fac_no_sum = [x for x in factor_levels if x != "Sum"]
    ft_no_sum = [x for x in ft_levels if x != "Sum"]
    fpos = {k: i for i, k in enumerate(fac_no_sum)}
    ftpos = {k: i for i, k in enumerate(ft_no_sum)}

    # NOTE(review): the mapped columns are categorical; the arithmetic below assumes
    # .map() yields numeric positions — confirm with the pandas version in use.
    main["x"] = main["Factor"].map(fpos)
    main["y"] = main["Feature_type"].map(ftpos)
    colb["x"] = colb["Factor"].map(fpos)
    rowb["y"] = rowb["Feature_type"].map(ftpos)

    # Heatmap tiles: unit squares centered on the integer positions
    main["xmin"] = main["x"] - 0.5
    main["xmax"] = main["x"] + 0.5
    main["ymin"] = main["y"] - 0.5
    main["ymax"] = main["y"] + 0.5

    # Top bars: start just above the last row and grow upwards
    colb["xmin"] = colb["x"] - 0.5
    colb["xmax"] = colb["x"] + 0.5
    colb["ymin"] = len(ft_no_sum) - 0.5
    colb["ymax"] = colb["ymin"] + colb["bar_length"]

    # Right bars: start just right of the last column and grow rightwards
    rowb["ymin"] = rowb["y"] - 0.5
    rowb["ymax"] = rowb["y"] + 0.5
    rowb["xmin"] = len(fac_no_sum) - 0.5
    rowb["xmax"] = rowb["xmin"] + rowb["bar_length"]

    # Text labels sit at a fixed offset (0.2) beyond the edge of the heatmap tiles.
    p = (
        ggplot()
        + geom_rect(main, aes(xmin="xmin", xmax="xmax", ymin="ymin", ymax="ymax", fill="Variance"))
        + geom_rect(colb, aes(xmin="xmin", xmax="xmax", ymin="ymin", ymax="ymax"), fill="#acabab")
        + geom_text(colb, aes(x="x", y=main["ymax"].max() + 0.2, label="Variance_label"), va="bottom", size=8)
        + geom_rect(rowb, aes(xmin="xmin", xmax="xmax", ymin="ymin", ymax="ymax"), fill="#acabab")
        + geom_text(rowb, aes(x=main["xmax"].max() + 0.2, y="y", label="Variance_label"), ha="left", size=8)
        + scale_fill_gradientn(colors=["#EFF822", "#CC4977", "#0F0782"])
        + scale_x_continuous(breaks=list(range(len(fac_no_sum))), labels=fac_no_sum)
        + scale_y_continuous(breaks=list(range(len(ft_no_sum))), labels=ft_no_sum)
        + theme_classic()
        + theme(axis_text_x=element_text(angle=45, hjust=1))
        + labs(title="Variance explained by MuVI factors", x="Factor", y="Feature type")
    )
    _ggsave_if(p, save_path, width, height, dpi)
    return p

# -----------------------------------------------------------------------------
# 4) Variable loadings per feature and top features heatmap
# -----------------------------------------------------------------------------

def muvi_variable_loadings_info(model, mdata, verbosity: int = 1) -> pd.DataFrame:
    """
    Collect the factor loadings of every view into one wide DataFrame.

    Rows are variables (features). Columns are the factor columns (renamed to
    'Factor i' by _rename_factor_columns where they follow the 'factor_<n>'
    pattern), plus 'view' (originating view) and 'variable' (feature name).
    The row index is a plain 0..n-1 range.

    Parameters
    ----------
    model : exposes view_names, factor_names and get_factor_loadings(..., as_df=True).
    mdata : used only as a fallback source of view names (see _view_names).
    verbosity : if truthy, print the head of the result.
    """
    all_loadings = model.get_factor_loadings(model.view_names, model.factor_names, as_df=True)
    chunks = []
    for view in _view_names(model, mdata):
        # transpose to variables x factors
        df_v = all_loadings[view].T.copy()
        # If every column is a factor_<n> column, order them by n. Sorting the
        # names that are present (rather than assuming factor_0..factor_{k-1})
        # avoids a KeyError when the numbering has gaps.
        fcols = [c for c in df_v.columns if isinstance(c, str) and c.startswith("factor_")]
        if len(fcols) == df_v.shape[1]:
            try:
                fcols_sorted = sorted(fcols, key=lambda c: int(c.split("_")[1]))
            except ValueError:
                # Non-numeric suffix somewhere: leave the existing order alone.
                fcols_sorted = fcols
            df_v = df_v.loc[:, fcols_sorted]
        df_v["view"] = view
        df_v["variable"] = df_v.index
        chunks.append(df_v)

    # The feature names already live in the 'variable' column, so the index is dropped.
    var_load = pd.concat(chunks, axis=0, ignore_index=True)
    var_load = _rename_factor_columns(var_load)

    if verbosity:
        print(var_load.head().to_string(index=False))
    return var_load  # columns: Factor i..., view, variable

def muvi_plot_top_loadings_heatmap(variable_loadings: pd.DataFrame,
                                   factor: str = "Factor 1",
                                   top_n: int = 30, by_abs: bool = True,
                                   save_path: Optional[str] = None,
                                   width: float = 5, height: float = 5, dpi: int = 300):
    """
    Tile heatmap (view x feature) of the strongest loadings on one factor.

    The `top_n` rows with the largest loading (absolute value when `by_abs`)
    determine which feature names are shown; every row carrying one of those
    names is plotted, whatever its view. Raises ValueError if `factor` is not a
    column of `variable_loadings`. Returns the ggplot object.
    """
    if factor not in variable_loadings.columns:
        raise ValueError(f"Factor column not found: {factor}")

    df = variable_loadings.copy()
    score = df[factor].abs() if by_abs else df[factor]
    ranked = df.assign(score=score).sort_values("score", ascending=False)
    top_vars = ranked.head(top_n)["variable"].tolist()

    selected = df["variable"].isin(top_vars)
    plot_df = df.loc[selected, ["variable", "view", factor]].copy()

    fill_scale = scale_fill_gradient2(low="#1f77b4", mid="lightgray", high="#c20019", limits=[-1.1, 1.1])
    tilted_labels = theme(axis_text_x=element_text(angle=45, hjust=1))
    annotations = labs(title=factor, x="View", y="Feature", fill="Loading")

    p = (
        ggplot(plot_df)
        + aes(x="view", y="variable", fill=factor)
        + geom_tile()
        + fill_scale
        + theme_classic()
        + tilted_labels
        + annotations
        + coord_fixed()
    )
    _ggsave_if(p, save_path, width, height, dpi)
    return p

def muvi_selected_features_info(variable_loadings: pd.DataFrame,
                                selections: Sequence[Tuple[str, str]],
                                verbosity: int = 1) -> pd.DataFrame:
    """
    Extract the loadings of hand-picked (feature_name, view) pairs in long form.

    Pairs absent from `variable_loadings` are silently ignored, and missing
    (NaN) entries of a selected row are dropped before reshaping. The result has
    the columns [Variable, view, Factor, loading]; it is empty (same columns)
    when no pair is found.
    """
    indexed = variable_loadings.set_index(["variable", "view"])

    picked = []
    for feature, view in selections:
        key = (feature, view)
        if key not in indexed.index:
            continue
        record = indexed.loc[key].dropna().to_frame().T  # include all factor columns
        record["variable"] = feature
        record["view"] = view
        picked.append(record)

    if picked:
        wide = pd.concat(picked, ignore_index=True)
        # gather factors only
        factor_cols = [col for col in wide.columns if str(col).startswith("Factor ")]
        long_form = wide.melt(id_vars=["variable", "view"], value_vars=factor_cols,
                              var_name="Factor", value_name="loading")
        out = long_form.rename(columns={"variable": "Variable"})
    else:
        out = pd.DataFrame(columns=["Variable", "view", "Factor", "loading"])

    if verbosity:
        print(out.head().to_string(index=False))
    return out

def muvi_selected_features_plot(df_long: pd.DataFrame,
                                save_path: Optional[str] = None,
                                width: float = 6, height: float = 5, dpi: int = 300):
    """
    Tile heatmap of loadings: factors on x, selected features on y.

    `df_long` must provide the columns 'Factor', 'Variable' and 'loading'.
    The plot is saved when `save_path` is given; the ggplot object is returned.
    """
    canvas = ggplot(df_long, aes(x="Factor", y="Variable", fill="loading"))
    fill_scale = scale_fill_gradient2(low="#1f77b4", mid="lightgray", high="#c20019", limits=[-1.1, 1.1])
    layout = theme(axis_text_x=element_text(angle=45, hjust=1), legend_position="bottom")
    annotations = labs(title="Selected features loadings", x="Factor", y="Feature/view", fill="Loading")

    p = canvas + geom_tile() + fill_scale + theme_classic() + layout + annotations + coord_fixed()
    _ggsave_if(p, save_path, width, height, dpi)
    return p

# -----------------------------------------------------------------------------
# 5) Factor scores with clinical covariates and tests
# -----------------------------------------------------------------------------

def muvi_factor_scores_info(model, mdata, obs_keys: Optional[Sequence[str]] = None,
                            verbosity: int = 1) -> pd.DataFrame:
    """
    Return factor scores as a DataFrame with optional columns from mdata.obs joined.

    Columns named 'factor_<n>' are renamed 'Factor <n+1>' using the number in
    the name, not the column position, so the labels agree with those produced
    for loadings and variance tables even if the columns are not in numeric
    order. Columns that do not match the pattern keep their name.

    Parameters
    ----------
    model : exposes get_factor_scores(as_df=True).
    mdata : its .obs DataFrame supplies the `obs_keys` columns (joined on the index).
    obs_keys : names of .obs columns to append; nothing is joined if empty/None.
    verbosity : if truthy, print the first factor columns and the joined keys.
    """
    def _label(col):
        name = str(col)
        if not name.startswith("factor_"):
            return col
        try:
            return f"Factor {int(name.split('_')[1]) + 1}"
        except ValueError:
            return col

    fs = model.get_factor_scores(as_df=True)
    fs = fs.rename(columns={c: _label(c) for c in fs.columns})
    if obs_keys:
        # list(...) so that a tuple of keys selects several columns, not one key.
        fs = fs.join(mdata.obs[list(obs_keys)])
    if verbosity:
        # str(c): column labels are not guaranteed to be strings.
        cols_show = [c for c in fs.columns if str(c).startswith("Factor ")][:3]
        print("Scores columns:", ", ".join(cols_show), "...")
        if obs_keys:
            print("Joined obs:", ", ".join(obs_keys))
    return fs

def muvi_kruskal_info(scores_df: pd.DataFrame, group_col: str,
                      factors: Optional[Sequence[str]] = None,
                      bonferroni: bool = True, verbosity: int = 1) -> pd.DataFrame:
    """
    Kruskal-Wallis test of each factor across the levels of `group_col`.

    `factors` defaults to every column whose name starts with 'Factor '. With
    `bonferroni`, each p-value is multiplied by the number of factors tested and
    capped at 1. Returns a DataFrame [Factor, pvalue] sorted by ascending p-value.
    """
    from scipy.stats import kruskal

    if factors is None:
        factors = [col for col in scores_df.columns if str(col).startswith("Factor ")]
    levels = scores_df[group_col].dropna().unique().tolist()
    n_tests = len(factors)

    records = []
    for name in factors:
        per_level = [scores_df.loc[scores_df[group_col] == level, name].values for level in levels]
        pval = kruskal(*per_level)[1]
        if bonferroni:
            pval = min(pval * n_tests, 1.0)
        records.append({"Factor": name, "pvalue": pval})

    out = pd.DataFrame(records).sort_values("pvalue")
    if verbosity:
        print(out.to_string(index=False))
    return out

def muvi_kendall_info(scores_df: pd.DataFrame, ordinal_col: str,
                      factors: Optional[Sequence[str]] = None,
                      bonferroni: bool = True, verbosity: int = 1) -> pd.DataFrame:
    """
    Kendall tau p-values vs an ordinal encoding of ordinal_col.

    `ordinal_col` is encoded through pd.Categorical codes, so an existing
    categorical keeps its category order. Rows where `ordinal_col` is missing
    are excluded from the test: their categorical code is -1, which would
    otherwise be ranked as if it were the lowest level.

    `factors` defaults to every column whose name starts with 'Factor '. With
    `bonferroni`, each p-value is multiplied by the number of factors tested and
    capped at 1. Returns a DataFrame [Factor, pvalue] sorted by ascending p-value.
    """
    from scipy.stats import kendalltau
    if factors is None:
        factors = [c for c in scores_df.columns if str(c).startswith("Factor ")]
    codes = pd.Categorical(scores_df[ordinal_col]).codes
    observed = codes >= 0  # -1 marks a missing ordinal value
    observed_codes = codes[observed]
    rows = []
    for f in factors:
        values = scores_df[f].to_numpy()
        _, p = kendalltau(values[observed], observed_codes)
        if bonferroni:
            p = p * len(factors)
            p = min(p, 1.0)
        rows.append({"Factor": f, "pvalue": p})
    out = pd.DataFrame(rows).sort_values("pvalue")
    if verbosity:
        print(out.to_string(index=False))
    return out

def muvi_violin_plot(scores_df: pd.DataFrame, factor: str, group_col: str,
                     palette: Optional[List[str]] = None, pvalue: Optional[float] = None,
                     save_path: Optional[str] = None,
                     width: float = 4.5, height: float = 4.5, dpi: int = 300):
    """
    Half-violin plot of one factor's scores for each level of `group_col`.

    The axes are flipped so that groups run along the vertical axis. When
    `pvalue` is given it is appended to the title, rounded to five decimals;
    `palette`, if given, sets the fill colors. Returns the ggplot object.
    """
    title = f"{factor}"
    if pvalue is not None:
        title += f" adjusted p = {np.round(pvalue, 5)}"

    canvas = ggplot(scores_df, aes(y=factor, x=group_col, fill=group_col))
    violins = geom_violin(style="right", scale="width", width=1.25)
    annotations = labs(title=title, x=group_col, y=factor)

    p = canvas + violins + theme_classic() + coord_flip() + guides(fill=False) + annotations
    if palette is not None:
        p = p + scale_fill_manual(values=palette)
    _ggsave_if(p, save_path, width, height, dpi)
    return p

# -----------------------------------------------------------------------------
# 6) Confidence ellipses for two factors by group
# -----------------------------------------------------------------------------

def _cov_ellipse_points(cov: np.ndarray, center: np.ndarray, nstd: float = 2.0, num: int = 100) -> pd.DataFrame:
    eigvals, eigvecs = np.linalg.eigh(cov)
    order = eigvals.argsort()[::-1]
    eigvals, eigvecs = eigvals[order], eigvecs[:, order]
    transform = eigvecs @ np.diag(nstd * np.sqrt(eigvals))
    t = np.linspace(0, 2*np.pi, num)
    circle = np.column_stack([np.cos(t), np.sin(t)])
    ellipse = circle @ transform.T
    ellipse += center
    return pd.DataFrame({"x": ellipse[:, 0], "y": ellipse[:, 1]})

def muvi_confidence_ellipses_info(scores_df: pd.DataFrame, x_factor: str, y_factor: str,
                                  group_col: str, nstd: float = 2.0, verbosity: int = 1) -> pd.DataFrame:
    """
    Return ellipse points for each group level.

    For every non-missing level of `group_col`, the mean and covariance of
    (x_factor, y_factor) are computed from the rows where both values are finite
    and turned into ellipse points via _cov_ellipse_points. The same rows are
    used for the center and the covariance, so a stray NaN no longer turns the
    whole covariance matrix into NaN. Groups with fewer than two usable rows
    are skipped because their covariance is undefined.

    Returns a DataFrame with columns 'x', 'y' and `group_col`; it is empty (same
    columns) when no group yields an ellipse.
    """
    out = []
    for g in scores_df[group_col].dropna().unique():
        sub = scores_df[scores_df[group_col] == g]
        x = sub[x_factor].to_numpy(dtype=float)
        y = sub[y_factor].to_numpy(dtype=float)
        usable = np.isfinite(x) & np.isfinite(y)
        if usable.sum() < 2:
            continue
        x, y = x[usable], y[usable]
        center = np.array([x.mean(), y.mean()])
        cov = np.cov(np.vstack([x, y]))
        ell = _cov_ellipse_points(cov, center, nstd=nstd)
        ell[group_col] = g
        out.append(ell)

    if out:
        df = pd.concat(out, ignore_index=True)
    else:
        # pd.concat raises on an empty list; return a well-formed empty frame instead.
        df = pd.DataFrame(columns=["x", "y", group_col])
    if verbosity:
        print(df.head().to_string(index=False))
    return df

def muvi_confidence_ellipses_plot(scores_df: pd.DataFrame, ellipses_df: pd.DataFrame,
                                  x_factor: str, y_factor: str, group_col: str,
                                  palette: Optional[List[str]] = None,
                                  save_path: Optional[str] = None,
                                  width: float = 4.5, height: float = 4.5, dpi: int = 300):
    """
    Plot confidence ellipses only. Returns ggplot.

    Parameters
    ----------
    scores_df : base data of the plot; supplies the ``x_factor`` / ``y_factor`` /
        ``group_col`` mapping (no point layer is drawn).
    ellipses_df : ellipse outlines with columns "x", "y" and ``group_col``.
    x_factor, y_factor : factor columns mapped to the axes.
    group_col : column that groups and colours the ellipses.
    palette : optional list of colours for the group levels.
    save_path, width, height, dpi : forwarded to ``_ggsave_if``.
    """
    p = (
        ggplot(scores_df, aes(x=x_factor, y=y_factor, color=group_col))
        + geom_path(ellipses_df, aes(x="x", y="y", group=group_col, color=group_col), size = 3)
        + theme_classic()
        + ggtitle("Confidence ellipses by group")
        + coord_equal()
    )
    if palette is not None:
        # The ellipses are mapped to the colour aesthetic, so the palette has
        # to be set on the colour scale; a fill scale would be ignored here.
        p = p + scale_color_manual(values=palette)
    _ggsave_if(p, save_path, width, height, dpi)
    return p

# -----------------------------------------------------------------------------
# 7) Export top features by view or by class
# -----------------------------------------------------------------------------

def muvi_top_features_by_view_info(variable_loadings: pd.DataFrame,
                                   factors: Sequence[str], top_per_view: int = 5,
                                   by_abs: bool = True, verbosity: int = 1) -> pd.DataFrame:
    """
    Return a tidy table with top features per view across selected factors.
    Columns: Variable, View, Weight, Factor

    Parameters
    ----------
    variable_loadings : table with "variable", "view" and one column per factor.
    factors : factor columns to scan; an empty sequence yields an empty table.
    top_per_view : number of features kept per view, per factor and in the result.
    by_abs : rank by absolute weight when True, by signed weight otherwise.
        The same rule drives the per-factor pick and the final cross-factor pick.
    verbosity : when truthy, print the head of the result.
    """
    # One ranking rule for every sort below, so by_abs=False is honoured throughout.
    rank_key = (lambda s: s.abs()) if by_abs else None

    rows = []
    for f in factors:
        top = (
            variable_loadings[["variable", "view", f]]
            .sort_values(f, key=rank_key, ascending=False)
            .groupby("view")
            .head(top_per_view)
            .rename(columns={f: "Weight", "view": "View", "variable": "Variable"})
        )
        top["Factor"] = f
        rows.append(top)

    if rows:
        out = pd.concat(rows, ignore_index=True)
        # A feature picked for several factors is kept once, with its best-ranked weight.
        out = out.sort_values(by="Weight", key=rank_key, ascending=False).drop_duplicates(subset=["Variable", "View"])
        # keep final number per view
        out = out.groupby("View", group_keys=False).head(top_per_view)
    else:
        # pd.concat raises on an empty list; return an empty, well-formed table.
        out = pd.DataFrame(columns=["Variable", "View", "Weight", "Factor"])
    if verbosity:
        print(out.head().to_string(index=False))
    return out

def muvi_top_features_by_class_info(variable_loadings: pd.DataFrame,
                                    types_map: Dict[str, str],
                                    factors: Sequence[str], top_per_class: int = 5,
                                    by_abs: bool = True, verbosity: int = 1) -> pd.DataFrame:
    """
    types_map: dict feature_name -> class label
    Returns tidy table with Variable, View, Feature type, Weight, Factor

    Parameters
    ----------
    variable_loadings : table with "variable", "view" and one column per factor.
    types_map : feature name -> class label; unmapped features get the class "NA".
    factors : factor columns to scan; an empty sequence yields an empty table.
    top_per_class : number of features kept per class, per factor and in the result.
    by_abs : rank by absolute weight when True, by signed weight otherwise.
        The same rule drives the per-factor pick and the final cross-factor pick.
    verbosity : when truthy, print the head of the result.
    """
    df = variable_loadings.copy()
    df["Feature type"] = df["variable"].map(types_map).fillna("NA")

    # One ranking rule for every sort below, so by_abs=False is honoured throughout.
    rank_key = (lambda s: s.abs()) if by_abs else None

    rows = []
    for f in factors:
        top = (
            df[["variable", "view", "Feature type", f]]
            .sort_values(f, key=rank_key, ascending=False)
            .groupby("Feature type")
            .head(top_per_class)
            .rename(columns={f: "Weight", "view": "View", "variable": "Variable"})
        )
        top["Factor"] = f
        rows.append(top)

    if rows:
        out = pd.concat(rows, ignore_index=True)
        # A feature picked for several factors is kept once, with its best-ranked weight.
        out = out.sort_values(by="Weight", key=rank_key, ascending=False).drop_duplicates(subset=["Variable", "View"])
        out = out.groupby("Feature type", group_keys=False).head(top_per_class)
    else:
        # pd.concat raises on an empty list; return an empty, well-formed table.
        out = pd.DataFrame(columns=["Variable", "View", "Feature type", "Weight", "Factor"])
    if verbosity:
        print(out.head().to_string(index=False))
    return out

def muvi_build_selected_anndata(mdata, selection_df: pd.DataFrame,
                                obs_anchor_view: Optional[str] = None):
    """
    Assemble one AnnData from the (View, Variable) pairs listed in selection_df.

    selection_df must provide the columns 'Variable' and 'View'. For each of its
    rows, the column(s) of ``mdata[View].X`` whose var name equals Variable are
    taken, and all of them are stacked side by side in row order.

    The result's ``obs`` is a copy of the obs table of ``obs_anchor_view``
    (the first key of ``mdata.mod`` when omitted); its ``var`` is a copy of
    selection_df with the index reset.
    """
    import muon as mu

    assert {"Variable", "View"}.issubset(selection_df.columns)

    anchor = obs_anchor_view if obs_anchor_view is not None else list(mdata.mod.keys())[0]

    # One matrix slice per selected row, picked with a boolean mask over var_names.
    pieces = []
    for _, row in selection_df.iterrows():
        view_data = mdata[row["View"]]
        pieces.append(view_data.X[:, view_data.var_names == row["Variable"]])

    selected = mu.AnnData(np.hstack(pieces))
    selected.obs = mdata[anchor].obs.copy()
    selected.var = selection_df.reset_index(drop=True).copy()
    return selected


In [None]:
# =============================================================================
# Minimal usage examples
# =============================================================================
# Assumes `model` and `mdata_processed` already exist in the notebook session
# (presumably a trained MuVI model and the preprocessed MuData; confirm against earlier cells).

# 1) Reconstruction R2 per view
# `recon` is indexed by "by_view" for the plot here and by ["macro"]["R2"] in the next example.
recon = muvi_reconstruction_info(model, mdata_processed, verbosity=1)
p1 = muvi_reconstruction_plot(recon["by_view"], title="Reconstruction R2 by view")
p1.show()

In [None]:
# 2) Variance explained by factors per view with marginal sums
# Reuses `recon` from example 1: the subtitle shows its macro R2 rounded to 3 decimals.
v_by_view = muvi_variance_by_view_info(model, verbosity=0)
p2 = muvi_variance_by_view_plot(v_by_view, subtitle=f"Macro-R2: {np.round(recon['macro']['R2'], 3)}")
p2.show()

In [None]:
# 3) Feature class vs factor variance explained
# Example feature classes: each class label maps to the list of variable names assigned to it.
all_vars = mdata_processed.var_names.tolist()

# Membership is decided by substring match on "ft_<i>".
# NOTE(review): "ft_1" is also a substring of "ft_10", "ft_11", ...; if variable
# names carry two-digit indices those land in "Pathway 1" too. Confirm this is intended.
feature_type_map = {
    "Pathway 1": [v for v in all_vars if any(f"ft_{i}" in v for i in [0, 1, 2])],
    "Pathway 2": [v for v in all_vars if any(f"ft_{i}" in v for i in [3, 4])],
    # Everything that matches none of ft_0 .. ft_4.
    "Pathway 3": [v for v in all_vars if not any(f"ft_{i}" in v for i in range(5))]
}

# Run analysis
v_by_class = muvi_featureclass_variance_info(
    model,
    mdata_processed,
    feature_type_map=feature_type_map,
    aggregator="median",
    verbosity=1
)

# Plot
p3 = muvi_featureclass_variance_plot(v_by_class)
p3.show()

In [None]:
# 4) Variable loadings and top features for a factor
# `var_load` is reused by examples 5 and 8 below.
var_load = muvi_variable_loadings_info(model, mdata_processed, verbosity=0)
p4 = muvi_plot_top_loadings_heatmap(var_load, factor="Factor 1", top_n=30)
p4.show()

In [None]:
# 5) Selected features across factors
# Presumably (variable name, view name) pairs; confirm against muvi_selected_features_info.
picked = [("Cell Type 1_gene_1", "Cell Type 1"), 
          ("Cell Type 2_gene_1", "Cell Type 2"),
          ("Cell Type 3_gene_3", "Cell Type 3"),]
# Uses `var_load` computed in example 4.
picked_long = muvi_selected_features_info(var_load, picked, verbosity=0)
p5 = muvi_selected_features_plot(picked_long)
p5.show()

In [None]:
# 6) Scores with clinical covariates and tests
clin_variables = ["Cell Type 1:batch", "exposure"]
scores = muvi_factor_scores_info(model, mdata_processed, obs_keys=clin_variables, verbosity=0)

for var in clin_variables:
    print(f"Kruskal-Wallis test for {var}")
    kw = muvi_kruskal_info(scores, group_col=var, verbosity=1)
    print(f"Kendall tau test for {var}") # For ordinal variables
    muvi_kendall_info(scores, ordinal_col=var, verbosity=1)
    # Extract the scalar explicitly: calling float() on a one-element Series
    # is deprecated in recent pandas and slated to raise a TypeError.
    factor2_pvalue = float(kw.loc[kw["Factor"] == "Factor 2", "pvalue"].iloc[0])
    p6 = muvi_violin_plot(scores, factor="Factor 2", group_col=var, pvalue=factor2_pvalue)
    p6.show()

In [None]:
# 7) Confidence ellipses for two factors by group
# Uses `scores` from example 6; groups are the levels of the "exposure" column.
ell = muvi_confidence_ellipses_info(scores, x_factor="Factor 1", y_factor="Factor 2", group_col="exposure", 
                                    nstd=2, verbosity=1)
p7 = muvi_confidence_ellipses_plot(scores, ell, x_factor="Factor 1", y_factor="Factor 2", group_col="exposure")
p7.show()

In [None]:
# 8) Export top features
top_by_view = muvi_top_features_by_view_info(
    var_load,
    factors=["Factor 1", "Factor 2"],
    top_per_view=2,
    verbosity=0,
)
# The table already carries the 'Variable' and 'View' columns the builder requires,
# so it is passed as is.
top_ad = muvi_build_selected_anndata(mdata_processed, top_by_view)
top_ad.write("top_features_multi_pT.h5ad")
top_by_view.to_csv("top_features_multi_pT.csv", index=False)