# Embedding Pooling Strategy Analysis


### Setup


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from dotenv import load_dotenv

load_dotenv()


### Load Embedding Data


In [None]:
# Root of the project data tree, read from the environment
# (load_dotenv() is called earlier in the notebook).
DATA_DIR = os.getenv("INFLAMM_DEBATE_FM_DATA_ROOT")
if DATA_DIR is None:
    # Path(None) would fail with an opaque TypeError; report the real cause instead.
    raise RuntimeError(
        "Environment variable INFLAMM_DEBATE_FM_DATA_ROOT is not set; "
        "define it in the environment or in a .env file before running this notebook."
    )
EMBEDDINGS_BASE_DIR = Path(DATA_DIR) / "processed" / "bulkformer_embeddings"

# Define the three configurations (name -> directory holding its .npy embedding files)
configs = {
    "human_only": EMBEDDINGS_BASE_DIR / "human_only",
    "mouse_only": EMBEDDINGS_BASE_DIR / "mouse_only",
    "human_ortholog_filtered": EMBEDDINGS_BASE_DIR / "human_ortholog_filtered",
}

# Keep only the configurations whose directory exists on disk
available_configs = {k: v for k, v in configs.items() if v.exists()}
print(f"Available configurations: {list(available_configs.keys())}")

# List available embedding files in each config
for config_name, config_dir in available_configs.items():
    print(f"\n{config_name}:")
    for f in sorted(config_dir.glob("*.npy")):
        print(f"  - {f.name}")


### Helper Functions for Memory-Efficient Loading


In [None]:
def load_metadata_for_embeddings(embedding_path: Path):
    """
    Return the per-sample annotation table (``adata.obs``) for an embedding file.

    The dataset name is the embedding file's stem with the text
    ``_transcriptome_embeddings`` removed. The matching AnnData file is looked
    up under ``DATA_DIR/processed`` in two places, in this order:

    1. ``anndata_cleaned/{dataset}.h5ad``
    2. ``anndata_orthologs/{dataset}_orthologs.h5ad``

    Returns ``None`` (after printing a warning) when neither file exists.
    """
    dataset_name = embedding_path.stem.replace("_transcriptome_embeddings", "")
    processed_root = Path(DATA_DIR) / "processed"

    # Primary location first, then the ortholog-specific fallback
    h5ad_path = processed_root / "anndata_cleaned" / f"{dataset_name}.h5ad"
    if not h5ad_path.exists():
        h5ad_path = processed_root / "anndata_orthologs" / f"{dataset_name}_orthologs.h5ad"

    if not h5ad_path.exists():
        print(f"Warning: Could not find AnnData for {dataset_name}")
        return None

    # Imported lazily so anndata is only needed when a file is actually read
    import anndata as ad
    return ad.read_h5ad(h5ad_path).obs


def get_sample_indices(metadata, n_samples_per_group=15):
    """
    Pick row positions of control and inflammation samples from a metadata table.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Per-sample annotations. The first column found among
        ``takao_status``, ``status``, ``condition``, ``group``, ``inflammation``
        is used as the status column.
    n_samples_per_group : int
        Maximum number of samples drawn (without replacement) from each group.

    Returns
    -------
    numpy.ndarray
        Sorted positional indices into ``metadata``.

    Notes
    -----
    * If no status column exists, or the status column has fewer than two
      distinct values, the first ``min(2 * n_samples_per_group, len(metadata))``
      positions are returned, so the result never points past the table.
    * For a status column whose name contains "control" or "takao", groups are
      found by case-insensitive substring match on "control" / "inflam".
      Otherwise the first two distinct values (in order of appearance) are
      treated as control and inflammation respectively.
    * Sampling uses the global ``np.random`` state; seed it for reproducibility.
    """
    status_cols = ["takao_status", "status", "condition", "group", "inflammation"]
    status_col = next((col for col in status_cols if col in metadata.columns), None)

    # Shared fallback, bounded by the table size so callers can index safely with it
    n_fallback = min(n_samples_per_group * 2, len(metadata))

    if status_col is None:
        print(f"Warning: Could not find status column. Available: {metadata.columns.tolist()}")
        return np.arange(n_fallback)

    if "control" in status_col.lower() or "takao" in status_col.lower():
        # Handle takao_status format: free-text labels matched by substring
        control_mask = metadata[status_col].str.contains("control", case=False, na=False)
        inflam_mask = metadata[status_col].str.contains("inflam", case=False, na=False)
    else:
        # Generic handling: first two distinct values stand in for the two groups
        unique_vals = metadata[status_col].unique()
        if len(unique_vals) < 2:
            return np.arange(n_fallback)
        control_mask = metadata[status_col] == unique_vals[0]
        inflam_mask = metadata[status_col] == unique_vals[1]

    control_indices = np.where(control_mask)[0]
    inflam_indices = np.where(inflam_mask)[0]

    # Sample up to n_samples_per_group from each group
    n_control = min(n_samples_per_group, len(control_indices))
    n_inflam = min(n_samples_per_group, len(inflam_indices))

    selected_control = np.random.choice(control_indices, size=n_control, replace=False)
    selected_inflam = np.random.choice(inflam_indices, size=n_inflam, replace=False)

    return np.sort(np.concatenate([selected_control, selected_inflam]))


def load_embeddings_subset(embedding_path: Path, sample_indices: np.ndarray):
    """
    Read selected rows of a ``.npy`` embedding file through a read-only memory map.

    The file is opened with ``mmap_mode='r'`` so the full array is not read up
    front. A non-empty ``sample_indices`` is used to index the first axis of the
    mapped array; an empty ``sample_indices`` returns a slice of the whole
    mapped array instead.
    """
    mapped = np.load(embedding_path, mmap_mode='r')

    # Empty selection means "everything": hand back the whole mapped array
    if len(sample_indices) == 0:
        return mapped[:]

    return mapped[sample_indices]


### Load Embeddings with Sample Limiting


In [None]:
# Set random seed for reproducibility (get_sample_indices draws from np.random)
np.random.seed(42)

# Load embeddings for each configuration.
# Result layout: embeddings_data[config_name][dataset_name] is a dict with keys
# "embeddings", "sample_indices" and "metadata".
embeddings_data = {}

for config_name, config_dir in available_configs.items():
    print(f"\n{'='*60}")
    print(f"Processing: {config_name}")
    print(f"{'='*60}")
    
    config_embeddings = {}
    npy_files = sorted(config_dir.glob("*.npy"))
    
    for emb_file in npy_files:
        dataset_name = emb_file.stem.replace("_transcriptome_embeddings", "")
        print(f"\n  Dataset: {dataset_name}")
        
        # Load metadata to identify samples
        metadata = load_metadata_for_embeddings(emb_file)
        
        if metadata is not None:
            # Get sample indices (up to 15 control + 15 inflammation)
            sample_indices = get_sample_indices(metadata, n_samples_per_group=15)
            print(f"    Selected {len(sample_indices)} samples")
            
            # Load embeddings subset
            emb_subset = load_embeddings_subset(emb_file, sample_indices)
            print(f"    Embedding shape: {emb_subset.shape}")
            
            config_embeddings[dataset_name] = {
                "embeddings": emb_subset,
                "sample_indices": sample_indices,
                "metadata": metadata.iloc[sample_indices],
            }
        else:
            # Fallback: load the first 30 samples, or fewer if the file is smaller.
            # The memory-mapped open only gives the array shape; without this
            # bound, a file with fewer than 30 rows would raise IndexError.
            n_available = np.load(emb_file, mmap_mode='r').shape[0]
            fallback_indices = np.arange(min(30, n_available))
            print(f"    Warning: Loading first {len(fallback_indices)} samples (no metadata found)")
            emb_subset = load_embeddings_subset(emb_file, fallback_indices)
            config_embeddings[dataset_name] = {
                "embeddings": emb_subset,
                "sample_indices": fallback_indices,
                "metadata": None
            }
    
    embeddings_data[config_name] = config_embeddings

print(f"\n{'='*60}")
print("Loading complete!")
print(f"{'='*60}")


### Variance Analysis: Full Embeddings


In [None]:
def compute_variance_per_gene(embeddings):
    """
    Across-sample variance of an embedding array.

    embeddings: (n_samples, n_genes, embedding_dim) or (n_samples, n_features)

    Returns a 2-tuple:
      * 3-D input: ``(var_per_gene, var_per_gene_emb)`` where
        ``var_per_gene_emb`` has shape (n_genes, embedding_dim) and holds the
        variance over samples of every (gene, dimension) entry, and
        ``var_per_gene`` (shape (n_genes,)) is its mean over the embedding
        dimension.
      * any other input: ``(variance over axis 0, None)``.
    """
    if embeddings.ndim != 3:
        # Features are taken as-is; there is no per-gene breakdown to report
        return np.var(embeddings, axis=0), None

    # Variance over the sample axis keeps the (n_genes, embedding_dim) layout
    var_per_gene_emb = np.var(embeddings, axis=0)
    var_per_gene = np.mean(var_per_gene_emb, axis=1)
    return var_per_gene, var_per_gene_emb


def compute_pca_variance(embeddings, n_components=50):
    """
    Fit a PCA on standardized embeddings and report its explained variance.

    3-D input is flattened to (n_samples, n_genes * embedding_dim) first; any
    other input is used as given. The number of components actually fitted is
    the smallest of ``n_components``, the number of features, and
    ``n_samples - 1``.

    Returns a dict with keys ``explained_variance_ratio``,
    ``explained_variance``, ``cumulative_variance`` (running sum of the
    ratios) and ``n_components`` (number of fitted components).
    """
    if len(embeddings.shape) == 3:
        feature_matrix = embeddings.reshape(embeddings.shape[0], -1)
    else:
        feature_matrix = embeddings

    # Standardize every feature before the decomposition
    scaled = StandardScaler().fit_transform(feature_matrix)

    n_rows, n_cols = scaled.shape
    pca = PCA(n_components=min(n_components, n_cols, n_rows - 1))
    pca.fit(scaled)

    ratios = pca.explained_variance_ratio_
    return {
        "explained_variance_ratio": ratios,
        "explained_variance": pca.explained_variance_,
        "cumulative_variance": np.cumsum(ratios),
        "n_components": len(ratios)
    }


In [None]:
# Analyze variance for each configuration.
# Result layout: variance_results[config_name][dataset_name] is a dict with keys
# "var_per_gene", "var_per_gene_emb", "pca_results" and "embeddings_shape".
variance_results = {}

for config_name, config_data in embeddings_data.items():
    print(f"\n{'='*60}")
    print(f"Variance Analysis: {config_name}")
    print(f"{'='*60}")
    
    config_results = {}
    
    for dataset_name, data in config_data.items():
        embeddings = data["embeddings"]
        print(f"\n  Dataset: {dataset_name}")
        print(f"    Shape: {embeddings.shape}")
        
        # Compute variance per gene
        var_per_gene, var_per_gene_emb = compute_variance_per_gene(embeddings)
        print(f"    Variance per gene shape: {var_per_gene.shape}")
        print(f"    Mean variance: {np.mean(var_per_gene):.4f}")
        print(f"    Std variance: {np.std(var_per_gene):.4f}")
        
        # Compute PCA variance
        pca_results = compute_pca_variance(embeddings, n_components=50)
        print(f"    PCA components: {pca_results['n_components']}")
        print(f"    First 5 PC explained variance: {pca_results['explained_variance_ratio'][:5]}")

        # compute_pca_variance fits at most n_samples - 1 components, so a small
        # dataset can yield fewer than 10; report however many exist instead of
        # indexing position 9 unconditionally.
        cumulative = pca_results["cumulative_variance"]
        n_report = min(10, len(cumulative))
        if n_report:
            print(f"    Cumulative variance (first {n_report} PCs): {cumulative[n_report - 1]:.4f}")
        else:
            print("    Cumulative variance: n/a (no PCA components)")
        
        config_results[dataset_name] = {
            "var_per_gene": var_per_gene,
            "var_per_gene_emb": var_per_gene_emb,
            "pca_results": pca_results,
            "embeddings_shape": embeddings.shape
        }
    
    variance_results[config_name] = config_results


### Pooling Strategies


In [None]:
def mean_pooling(embeddings):
    """
    Average over the gene axis.

    embeddings: (n_samples, n_genes, embedding_dim)
    returns: (n_samples, embedding_dim)

    Input that is not 3-D is returned unchanged (treated as already pooled).
    """
    is_per_gene = len(embeddings.shape) == 3
    return np.mean(embeddings, axis=1) if is_per_gene else embeddings


def max_pooling(embeddings):
    """
    Take the element-wise maximum over the gene axis.

    embeddings: (n_samples, n_genes, embedding_dim)
    returns: (n_samples, embedding_dim)

    Input that is not 3-D is returned unchanged (treated as already pooled).
    """
    is_per_gene = len(embeddings.shape) == 3
    return np.max(embeddings, axis=1) if is_per_gene else embeddings


def sum_pooling(embeddings):
    """
    Add up the values along the gene axis.

    embeddings: (n_samples, n_genes, embedding_dim)
    returns: (n_samples, embedding_dim)

    Input that is not 3-D is returned unchanged (treated as already pooled).
    """
    is_per_gene = len(embeddings.shape) == 3
    return np.sum(embeddings, axis=1) if is_per_gene else embeddings


def apply_pooling_strategies(embeddings):
    """
    Run every pooling strategy on one embedding array.

    For 3-D input the result maps ``"mean"``, ``"max"`` and ``"sum"`` (in that
    order) to the output of the corresponding pooling function. Any other
    input is passed through untouched under the single key ``"original"``.
    """
    if len(embeddings.shape) != 3:
        # Nothing to pool over: keep the array as it is
        return {"original": embeddings}

    poolers = {"mean": mean_pooling, "max": max_pooling, "sum": sum_pooling}
    return {name: pool(embeddings) for name, pool in poolers.items()}


In [None]:
# Pool every loaded dataset with each strategy.
# Result layout: pooling_results[config_name][dataset_name] is a dict with keys
# "original" (the unpooled array) and "pooled" (strategy name -> pooled array).
pooling_results = {}

for config_name, config_data in embeddings_data.items():
    print(f"\n{'='*60}")
    print(f"Pooling Analysis: {config_name}")
    print(f"{'='*60}")

    config_pooling = {}
    pooling_results[config_name] = config_pooling

    for dataset_name, data in config_data.items():
        embeddings = data["embeddings"]
        print(f"\n  Dataset: {dataset_name}")
        print(f"    Original shape: {embeddings.shape}")

        pooled = apply_pooling_strategies(embeddings)

        # Report the output shape of every strategy
        for strategy in pooled:
            print(f"    {strategy} pooling shape: {pooled[strategy].shape}")

        config_pooling[dataset_name] = {"original": embeddings, "pooled": pooled}


### Visualization: PCA Variance Explained


In [None]:
# One panel per configuration: explained-variance ratio of the leading
# principal components of the full (unpooled) embeddings
n_panels = len(available_configs)
fig, axes = plt.subplots(n_panels, 1, figsize=(10, 4*n_panels), squeeze=False)
axes = axes[:, 0]  # single column: flatten to a 1-D sequence of axes

for idx, (config_name, config_results) in enumerate(variance_results.items()):
    ax = axes[idx]

    for dataset_name, results in config_results.items():
        pca_res = results["pca_results"]
        # Show at most the first 20 components
        n_shown = min(20, pca_res["n_components"])
        x = np.arange(1, n_shown + 1)

        ax.plot(
            x,
            pca_res["explained_variance_ratio"][:len(x)],
            marker='o',
            label=dataset_name,
            alpha=0.7,
        )

    ax.set_title(f"PCA Explained Variance: {config_name}")
    ax.set_xlabel("Principal Component")
    ax.set_ylabel("Explained Variance Ratio")
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


### Visualization: Variance per Gene


In [None]:
# One panel per configuration: density histogram of the per-gene variance,
# with all datasets of that configuration overlaid
n_panels = len(available_configs)
fig, axes = plt.subplots(n_panels, 1, figsize=(12, 4*n_panels), squeeze=False)
axes = axes[:, 0]  # single column: flatten to a 1-D sequence of axes

for idx, (config_name, config_results) in enumerate(variance_results.items()):
    ax = axes[idx]

    for dataset_name, results in config_results.items():
        ax.hist(
            results["var_per_gene"],
            bins=50,
            alpha=0.6,
            label=dataset_name,
            density=True,
        )

    ax.set_title(f"Variance per Gene Distribution: {config_name}")
    ax.set_xlabel("Variance per Gene")
    ax.set_ylabel("Density")
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()


### Visualization: Pooling Strategy Comparison


In [None]:
# Compare pooling strategies using PCA on pooled embeddings:
# one figure per configuration, one panel per dataset, one line per strategy
for config_name, config_pooling in pooling_results.items():
    print(f"\n{'='*60}")
    print(f"Pooling Strategy Comparison: {config_name}")
    print(f"{'='*60}")

    if not config_pooling:
        # A configuration directory without .npy files yields no datasets, and
        # plt.subplots(1, 0) raises ValueError; skip instead of aborting the cell.
        print("  No datasets loaded for this configuration; skipping plot")
        continue
    
    # Compute PCA for each pooling strategy (only 2-D arrays with more than one feature)
    pooling_pca_results = {
        dataset_name: {
            strategy: compute_pca_variance(pooled_emb, n_components=20)
            for strategy, pooled_emb in data["pooled"].items()
            if len(pooled_emb.shape) == 2 and pooled_emb.shape[1] > 1
        }
        for dataset_name, data in config_pooling.items()
    }
    
    # Plot comparison; squeeze=False keeps `axes` 2-D even for a single dataset
    n_datasets = len(config_pooling)
    fig, axes = plt.subplots(1, n_datasets, figsize=(5*n_datasets, 5), squeeze=False)
    
    for ax, (dataset_name, dataset_pca) in zip(axes[0], pooling_pca_results.items()):
        for strategy, pca_res in dataset_pca.items():
            n_comp = min(20, pca_res["n_components"])
            x = np.arange(1, n_comp + 1)
            ax.plot(x, pca_res["explained_variance_ratio"][:n_comp], 
                   marker='o', label=strategy, alpha=0.7)
        
        ax.set_xlabel("Principal Component")
        ax.set_ylabel("Explained Variance Ratio")
        ax.set_title(f"{dataset_name}\nPooling Strategy Comparison")
        ax.legend()
        ax.grid(True, alpha=0.3)
    
    plt.suptitle(f"PCA Comparison: {config_name}", fontsize=14)
    plt.tight_layout()
    plt.show()


### Visualization: UMAP Clustering for Pooling Strategies


In [None]:
# Import the optional dependency on its own, so that only a missing package
# produces the install hint (an ImportError raised later is not masked).
try:
    import umap
    from sklearn.manifold import TSNE  # not used below in this cell
except ImportError:
    umap = None
    print("UMAP not available. Install with: pip install umap-learn")

if umap is not None:
    # Same candidate status columns, in the same priority order, as get_sample_indices
    status_cols = ["takao_status", "status", "condition", "group", "inflammation"]

    # UMAP visualization for each pooling strategy
    for config_name, config_pooling in pooling_results.items():
        print(f"\n{'='*60}")
        print(f"UMAP Clustering: {config_name}")
        print(f"{'='*60}")

        for dataset_name, data in config_pooling.items():
            metadata = embeddings_data[config_name][dataset_name]["metadata"]

            # Get status labels if available. They are converted to strings so
            # missing values and mixed types can be grouped and compared safely.
            labels = None
            if metadata is not None:
                status_col = next((col for col in status_cols if col in metadata.columns), None)
                if status_col is not None:
                    labels = metadata[status_col].astype(str).to_numpy()

            # Keep only what will actually be drawn, so no panel is left empty:
            # 2-D arrays with more than two features and at least three samples
            # (n_neighbors below is len - 1, and UMAP requires n_neighbors >= 2).
            plottable = {
                strategy: pooled_emb
                for strategy, pooled_emb in data["pooled"].items()
                if len(pooled_emb.shape) == 2
                and pooled_emb.shape[1] > 2
                and len(pooled_emb) > 2
            }
            if not plottable:
                print(f"  Skipping {dataset_name}: no pooled embeddings suitable for UMAP")
                continue

            n_strategies = len(plottable)
            # squeeze=False keeps `axes` 2-D even when there is a single panel
            fig, axes = plt.subplots(1, n_strategies, figsize=(5*n_strategies, 5), squeeze=False)

            for ax, (strategy, pooled_emb) in zip(axes[0], plottable.items()):
                reducer = umap.UMAP(
                    n_components=2,
                    random_state=42,
                    n_neighbors=min(15, len(pooled_emb) - 1),
                )
                embedding_2d = reducer.fit_transform(pooled_emb)

                if labels is not None:
                    # One scatter call per status value so each gets a legend entry
                    for label in np.unique(labels):
                        mask = labels == label
                        ax.scatter(embedding_2d[mask, 0], embedding_2d[mask, 1],
                                   label=str(label), alpha=0.6, s=50)
                    ax.legend()
                else:
                    ax.scatter(embedding_2d[:, 0], embedding_2d[:, 1], alpha=0.6, s=50)

                ax.set_xlabel("UMAP 1")
                ax.set_ylabel("UMAP 2")
                ax.set_title(f"{strategy} pooling")
                ax.grid(True, alpha=0.3)

            plt.suptitle(f"{dataset_name} - UMAP Clustering", fontsize=14)
            plt.tight_layout()
            plt.show()


### Summary Statistics


In [None]:
# Build one summary row per (configuration, dataset, pooling strategy)
summary_rows = []

for config_name, config_pooling in pooling_results.items():
    for dataset_name, data in config_pooling.items():
        original_shape = data["original"].shape

        for strategy, pooled_emb in data["pooled"].items():
            # Only 2-D (samples x features) arrays are summarized
            if len(pooled_emb.shape) != 2:
                continue

            pca_res = compute_pca_variance(pooled_emb, n_components=10)
            ratios = pca_res["explained_variance_ratio"]
            cumulative = pca_res["cumulative_variance"]

            # NaN marks statistics that need more components than were fitted
            first_pc = ratios[0] if len(ratios) > 0 else np.nan
            first_ten_pcs = cumulative[9] if len(cumulative) > 9 else np.nan

            summary_rows.append({
                "config": config_name,
                "dataset": dataset_name,
                "strategy": strategy,
                "original_shape": str(original_shape),
                "pooled_shape": str(pooled_emb.shape),
                "mean": np.mean(pooled_emb),
                "std": np.std(pooled_emb),
                "pca_variance_1pc": first_pc,
                "pca_variance_10pc": first_ten_pcs,
            })

summary_df = pd.DataFrame(summary_rows)
print("Summary Statistics for Pooling Strategies:")
print("="*80)
print(summary_df.to_string(index=False))
