# Multi-Model Embedding Visualization

Compare embeddings from different model variants: zero-shot and fine-tuned (human, mouse, combined).


### Setup


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import anndata as ad
from pathlib import Path
from dotenv import load_dotenv

# Load variables from a local .env file (if present) into the process
# environment; INFLAMM_DEBATE_FM_DATA_ROOT is read from the environment
# when the data directory is resolved.
load_dotenv()

# Set scanpy settings
# verbosity=3 is scanpy's "hint" level, so neighbors/UMAP calls log their progress.
sc.settings.verbosity = 3
# Global figure defaults applied to every scanpy plot in this notebook.
sc.settings.set_figure_params(dpi=100, facecolor='white')


### Load Data with Multi-Model Embeddings


In [None]:
# Data root: taken from the environment when set, otherwise the
# repository-relative default.
DATA_DIR = os.getenv("INFLAMM_DEBATE_FM_DATA_ROOT", "../data/")

ANN_DATA_DIR = Path(DATA_DIR).joinpath("processed", "anndata_cleaned")

# Read every .h5ad file in the directory, keyed by file stem
adatas = {}
for h5ad_path in sorted(ANN_DATA_DIR.glob("*.h5ad")):
    loaded = ad.read_h5ad(h5ad_path)
    adatas[h5ad_path.stem] = loaded
    print(f"Loaded {h5ad_path.stem}: {loaded.shape}")

    # Report the obsm entries whose key starts with "X_"
    found_keys = [key for key in loaded.obsm.keys() if key.startswith("X_")]
    if not found_keys:
        print("  Warning: No multi-model embeddings found in obsm")
        continue
    print(f"  Available embeddings: {found_keys}")

print(f"\nTotal datasets loaded: {len(adatas)}")


### Helper Functions for Visualization


In [None]:
def plot_umap_scanpy(adata, embedding_key, color_vars, components='1,2', size=None,
                     title_suffix="", n_components=10):
    """
    Plot UMAP using scanpy for a specific embedding.

    The embedding stored under ``embedding_key`` is copied into a temporary
    AnnData, a neighbor graph and UMAP are computed from it, and one panel per
    entry of ``color_vars`` is drawn. The input ``adata`` is not modified.

    Parameters
    ----------
    adata : AnnData
        AnnData object with embeddings in obsm
    embedding_key : str
        Key in obsm to use for embedding (e.g., 'X_zero_shot')
    color_vars : list or str
        Column name(s) in obs to color by
    components : str
        UMAP components to plot (e.g., '1,2' or '2,3')
    size : float, optional
        Point size for scatter plot
    title_suffix : str
        Additional text for plot title
    n_components : int
        Number of UMAP dimensions to compute. The default of 10 keeps higher
        component pairs such as '2,3' available; it must be at least as large
        as the highest component requested in ``components``.

    Returns
    -------
    None
        The figure is shown via ``plt.show()``. If ``embedding_key`` is not in
        ``adata.obsm`` a warning is printed and nothing is plotted.
    """
    if embedding_key not in adata.obsm:
        print(f"Warning: {embedding_key} not found in obsm. Available keys: {list(adata.obsm.keys())}")
        return

    # Create a copy to avoid modifying original
    adata_plot = adata.copy()

    # Set the embedding as the main representation
    adata_plot.obsm["X_embedding"] = adata.obsm[embedding_key]

    # Compute neighbors and UMAP
    sc.pp.neighbors(adata_plot, use_rep="X_embedding", n_neighbors=min(15, adata_plot.n_obs - 1))
    sc.tl.umap(adata_plot, n_components=n_components)

    model_name = embedding_key.replace("X_", "").replace("_", " ").title()
    title = f"{model_name} Embeddings{title_suffix}"

    # scanpy consumes one title per panel; a single string would only label
    # the first panel, so give every colored panel its own title.
    if color_vars is None or isinstance(color_vars, str):
        panel_titles = title
    else:
        color_vars = list(color_vars)
        if len(color_vars) <= 1:
            panel_titles = title
        else:
            panel_titles = [f"{title} ({var})" for var in color_vars]

    # Without an `ax`, scanpy creates its own figure, so a preceding
    # plt.figure(...) would only leave an empty figure behind. Setting the
    # size through rcParams makes scanpy's own figure pick up the 8x6 panels.
    with plt.rc_context({"figure.figsize": (8, 6)}):
        sc.pl.umap(adata_plot, color=color_vars, show=False, alpha=0.7,
                   components=components, size=size, title=panel_titles)
    plt.tight_layout()
    plt.show()


def compare_models_umap(adatas_dict, dataset_name, color_vars, components='1,2', size=None):
    """
    Draw one UMAP figure per model embedding stored for a single dataset.

    Every ``obsm`` key starting with ``"X_"`` is treated as a model embedding
    and handed to ``plot_umap_scanpy`` in sorted key order.

    Parameters
    ----------
    adatas_dict : dict
        Mapping from dataset name to AnnData object
    dataset_name : str
        Name of the dataset to visualize
    color_vars : list
        Column names in obs used to color the points
    components : str
        UMAP components to plot
    size : float, optional
        Point size

    Returns
    -------
    None
        A message is printed and nothing is plotted when the dataset is
        unknown or has no ``X_*`` embeddings.
    """
    if dataset_name not in adatas_dict:
        print(f"Dataset {dataset_name} not found")
        return

    dataset = adatas_dict[dataset_name]

    # Collect the model embeddings available for this dataset
    model_keys = sorted(key for key in dataset.obsm.keys() if key.startswith("X_"))
    if len(model_keys) == 0:
        print(f"No embeddings found for {dataset_name}")
        return

    print(f"Comparing {len(model_keys)} models for {dataset_name}")

    # One figure per model, all tagged with the dataset name
    suffix = f" - {dataset_name}"
    for model_key in model_keys:
        plot_umap_scanpy(
            dataset,
            model_key,
            color_vars,
            components=components,
            size=size,
            title_suffix=suffix,
        )


In [None]:
# Visualize human datasets
human_datasets = [name for name in adatas if name.startswith("human_")]

# Optional obs columns, added after "group" in this order when present
human_optional_columns = ("time_point_hours", "takao_status")

for dataset_name in human_datasets:
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Dataset: {dataset_name}")
    print(banner)

    # "group" is always used; the optional columns only if the dataset has them
    obs_columns = adatas[dataset_name].obs.columns
    color_vars = ["group"]
    color_vars.extend(col for col in human_optional_columns if col in obs_columns)

    compare_models_umap(adatas, dataset_name, color_vars, size=100)


In [None]:
# Visualize mouse datasets
mouse_datasets = [name for name in adatas if name.startswith("mouse_")]

# Optional obs columns, added after "group" in this order when present
mouse_optional_columns = ("time_point_hours", "time_point", "takao_status", "patient_id")

for dataset_name in mouse_datasets:
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"Dataset: {dataset_name}")
    print(banner)

    # "group" is always used; the optional columns only if the dataset has them
    obs_columns = adatas[dataset_name].obs.columns
    color_vars = ["group"]
    color_vars.extend(col for col in mouse_optional_columns if col in obs_columns)

    compare_models_umap(adatas, dataset_name, color_vars, size=100)


### Cross-Model Comparison: Side-by-Side UMAP


In [None]:
def plot_all_models_side_by_side(adata, dataset_name, color_var="group", components='1,2', size=None):
    """
    Plot all available models side-by-side for easy comparison.

    Every ``obsm`` key starting with ``"X_"`` gets its own panel in a single
    row. For each one a neighbor graph and UMAP are computed on a copy of
    ``adata``, so the input object is not modified.

    Parameters
    ----------
    adata : AnnData
        AnnData object with embeddings in obsm
    dataset_name : str
        Name used in messages and in the figure title
    color_var : str
        Column name in obs to color by
    components : str
        UMAP components to plot (e.g., '1,2' or '2,3'). The UMAP is computed
        with enough dimensions to cover the highest component requested
        (2 for the default '1,2').
    size : float, optional
        Point size

    Returns
    -------
    None
        The figure is shown via ``plt.show()``. A message is printed and
        nothing is plotted when no ``X_*`` embeddings exist.
    """
    embedding_keys = sorted(k for k in adata.obsm.keys() if k.startswith("X_"))

    if not embedding_keys:
        print(f"No embeddings found for {dataset_name}")
        return

    # A 2-D UMAP cannot serve components such as '2,3'; size the UMAP to the
    # highest component requested. Unparseable specs fall back to 2-D.
    try:
        highest_component = max(int(part) for part in str(components).split(","))
    except ValueError:
        highest_component = 2
    n_components = max(2, highest_component)

    n_models = len(embedding_keys)
    # squeeze=False always yields a 2-D axes array, so a single model needs
    # no special casing.
    fig, axes = plt.subplots(1, n_models, figsize=(6 * n_models, 5), squeeze=False)

    for ax, emb_key in zip(axes[0], embedding_keys):
        # Create temporary adata for this embedding
        adata_temp = adata.copy()
        adata_temp.obsm["X_embedding"] = adata.obsm[emb_key]

        # Compute neighbors and UMAP
        sc.pp.neighbors(adata_temp, use_rep="X_embedding", n_neighbors=min(15, adata_temp.n_obs - 1))
        sc.tl.umap(adata_temp, n_components=n_components)

        # Plot on specific axis
        model_name = emb_key.replace("X_", "").replace("_", " ").title()
        sc.pl.umap(adata_temp, color=color_var, ax=ax, show=False,
                   alpha=0.7, components=components, size=size, title=model_name)

    fig.suptitle(f"{dataset_name} - Model Comparison", fontsize=14, y=1.02)
    fig.tight_layout()
    plt.show()


# Example: side-by-side model comparison for one dataset (skipped if it was not loaded)
example_dataset = "human_burn"
if example_dataset in adatas:
    plot_all_models_side_by_side(
        adata=adatas[example_dataset],
        dataset_name=example_dataset,
        color_var="group",
        size=50,
    )


### Combined Datasets: Cross-Species Visualization


In [None]:
# Cross-species comparison: human and mouse burn datasets embedded together,
# once per model embedding present in both.
burn_datasets = ["human_burn", "mouse_burn"]
missing_burn = [name for name in burn_datasets if name not in adatas]
if not missing_burn:
    print("Combining burn datasets for cross-species comparison")

    # Embedding keys shared by both species
    human_keys = set(adatas["human_burn"].obsm.keys())
    mouse_keys = set(adatas["mouse_burn"].obsm.keys())
    common_keys = sorted(key for key in human_keys.intersection(mouse_keys) if key.startswith("X_"))

    print(f"Common embedding keys: {common_keys}")

    # (obs column to color by, label used in the panel title)
    panels = (("species", "Species"), ("group", "Group"))

    for emb_key in common_keys:
        model_name = emb_key.replace("X_", "").replace("_", " ").title()
        banner = "=" * 60
        print(f"\n{banner}")
        print(f"Model: {model_name}")
        print(banner)

        # A fresh concatenation per model, so each run starts from the stored data
        combined = ad.concat(
            [adatas["human_burn"], adatas["mouse_burn"]],
            label="species",
            keys=["human", "mouse"],
            index_unique=None,
        )
        combined.obs['species'] = combined.obs['species'].astype('category')

        # Neighbors and UMAP are computed on the selected model embedding
        combined.obsm["X_embedding"] = combined.obsm[emb_key]
        n_neighbors = min(15, combined.n_obs - 1)
        sc.pp.neighbors(combined, use_rep="X_embedding", n_neighbors=n_neighbors)
        sc.tl.umap(combined, n_components=2)

        # Left panel colored by species, right panel by group
        fig, panel_axes = plt.subplots(1, 2, figsize=(10, 5))
        for panel_ax, (column, label) in zip(panel_axes, panels):
            sc.pl.umap(combined, color=column, ax=panel_ax, show=False,
                       alpha=0.7, size=50, title=f"{model_name} - {label}")

        fig.suptitle("Human vs Mouse Burn - Cross-Species Comparison", fontsize=14)
        fig.tight_layout()
        plt.show()


### Embedding Similarity Analysis


In [None]:
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import pearsonr

def _mean_rowwise_pearson(emb_a, emb_b):
    """Return the mean Pearson correlation between matching rows of two 2-D arrays.

    Rows whose correlation is undefined (e.g. a constant row) are skipped;
    0.0 is returned when no row yields a defined correlation.
    """
    n_rows = min(emb_a.shape[0], emb_b.shape[0])
    emb_a, emb_b = emb_a[:n_rows], emb_b[:n_rows]

    centered_a = emb_a - emb_a.mean(axis=1, keepdims=True)
    centered_b = emb_b - emb_b.mean(axis=1, keepdims=True)
    numerator = (centered_a * centered_b).sum(axis=1)
    denominator = np.sqrt((centered_a ** 2).sum(axis=1) * (centered_b ** 2).sum(axis=1))

    # Zero variance gives 0/0 -> NaN; those rows are dropped below
    with np.errstate(divide="ignore", invalid="ignore"):
        row_corrs = numerator / denominator

    row_corrs = row_corrs[np.isfinite(row_corrs)]
    if row_corrs.size == 0:
        return 0.0
    # Clip away floating-point overshoot outside [-1, 1]
    return float(np.clip(row_corrs, -1.0, 1.0).mean())


def compare_embedding_similarity(adata, dataset_name):
    """
    Compare similarity between different model embeddings using correlation.

    For every pair of ``obsm`` entries whose key starts with ``"X_"``, the
    Pearson correlation between the two embedding vectors of each sample is
    computed and averaged over samples. The resulting matrix is shown as a
    heatmap.

    Parameters
    ----------
    adata : AnnData
        AnnData object with embeddings in obsm
    dataset_name : str
        Name used in messages and in the plot title

    Returns
    -------
    tuple or None
        ``(corr_matrix, model_names)`` where ``corr_matrix`` is a square
        numpy array ordered like ``model_names``. ``None`` is returned (after
        printing a message) when fewer than two embeddings are available.
        Pairs whose embeddings have different widths cannot be correlated
        element-wise and are reported as NaN.
    """
    embedding_keys = sorted(k for k in adata.obsm.keys() if k.startswith("X_"))

    if len(embedding_keys) < 2:
        print(f"Need at least 2 embeddings for comparison. Found: {embedding_keys}")
        return None

    print(f"Comparing embeddings for {dataset_name}")
    print(f"Available models: {[k.replace('X_', '') for k in embedding_keys]}\n")

    # Convert each embedding once and flatten any trailing axes to (n_samples, n_features)
    embeddings = []
    for key in embedding_keys:
        emb = np.asarray(adata.obsm[key], dtype=float)
        embeddings.append(emb.reshape(emb.shape[0], -1))

    # The measure is symmetric, so only the upper triangle is computed
    n_models = len(embedding_keys)
    corr_matrix = np.zeros((n_models, n_models))

    for i, emb_i in enumerate(embeddings):
        for j in range(i, n_models):
            emb_j = embeddings[j]
            if emb_i.shape[1] != emb_j.shape[1]:
                print(
                    f"Warning: {embedding_keys[i]} and {embedding_keys[j]} have different "
                    f"embedding widths ({emb_i.shape[1]} vs {emb_j.shape[1]}); "
                    "correlation set to NaN"
                )
                value = np.nan
            else:
                value = _mean_rowwise_pearson(emb_i, emb_j)
            corr_matrix[i, j] = corr_matrix[j, i] = value

    # Plot correlation matrix
    model_names = [k.replace("X_", "").replace("_", " ").title() for k in embedding_keys]

    plt.figure(figsize=(8, 6))
    sns.heatmap(corr_matrix, annot=True, fmt='.3f', cmap='viridis',
                xticklabels=model_names, yticklabels=model_names,
                square=True, cbar_kws={'label': 'Correlation'})
    plt.title(f"Embedding Similarity Matrix - {dataset_name}")
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

    return corr_matrix, model_names


# Compare embeddings for each dataset
for dataset_name in list(adatas)[:3]:  # First 3 datasets as example
    similarity = compare_embedding_similarity(adatas[dataset_name], dataset_name)
    # The function returns None (after printing why) when a dataset has fewer
    # than two embeddings; unpacking that directly would raise a TypeError.
    if similarity is not None:
        corr_matrix, model_names = similarity
    print()


### Summary: Available Embeddings per Dataset


In [None]:
# Build one summary row per (dataset, model embedding) pair
def _embedding_row(dataset, key, matrix):
    """Describe a single obsm embedding as a flat dict for the summary table."""
    dims = matrix.shape
    return {
        "dataset": dataset,
        "model": key.replace("X_", ""),
        "embedding_shape": str(dims),
        "n_samples": dims[0],
        "embedding_dim": dims[1] if len(dims) > 1 else "N/A",
    }


summary_data = []
for dataset_name, adata in adatas.items():
    model_keys = sorted(key for key in adata.obsm.keys() if key.startswith("X_"))
    summary_data.extend(
        _embedding_row(dataset_name, key, adata.obsm[key]) for key in model_keys
    )

summary_df = pd.DataFrame(summary_data)
print("Summary of Available Embeddings:")
print("=" * 80)
print(summary_df.to_string(index=False))
