In [1]:
import anndata as ad
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import StratifiedKFold
from sklearn import preprocessing
from scipy.optimize import linear_sum_assignment

import sys
sys.path.append('..')
import copy as copy
import torch
import os

In [2]:
from src.configs import *

In [3]:
def partitions(celltype, n_partitions, seed=0):
    """
    Split cells into ``n_partitions`` stratified cross-validation folds.

    Adapted from https://github.com/AllenInstitute/coupledAE-patchseq

    Args:
        celltype: 1-D array-like of per-cell labels used for stratification
            (e.g. ``adata.obs['cell_type_TEM']``).
        n_partitions: number of folds (``n_splits`` of ``StratifiedKFold``).
        seed: random seed for the shuffled split.

    Returns:
        list with one dict per fold, ``{'train': train_idx, 'val': val_idx}``, holding positional
        integer indices into ``celltype``.
    """
    import warnings

    skf = StratifiedKFold(n_splits=n_partitions, random_state=seed, shuffle=True)

    # Some cell types have fewer members than n_partitions, and StratifiedKFold.split() emits a
    # UserWarning for that. It is safe to ignore (those types are not crucial to the analysis), but:
    #  * the filter is scoped to this call instead of being installed process-wide, and
    #  * the context wraps the split() iteration, which is where the warning is actually raised.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        ind_dict = [{'train': train_ind, 'val': val_ind} for train_ind, val_ind in
                    skf.split(X=np.zeros(len(celltype)), y=celltype)]
    return ind_dict

In [4]:
def pre_ps(adata_rna_raw, adata_ephys_raw, adata_morph_raw, cv, split=False, folds=None):
    """
    Standardise the RNA / Ephys / Morph AnnData objects and split them for one CV fold.

    Args:
        adata_rna_raw, adata_ephys_raw, adata_morph_raw: raw AnnData objects, one per modality.
            They are copied and never modified in place.
        cv: fold index into ``folds``.
        split: if True, a StandardScaler is fitted per modality on the fold's TRAINING cells only
            and applied to both training and validation cells, so no validation statistics enter
            the preprocessing. If False, the scaler is fitted on all cells before splitting.
        folds: output of ``partitions``. Defaults to the notebook-level ``ind_dict``.

    Returns:
        (adatas_train, adatas_test, adatas_all): lists with one AnnData per modality in the order
        [RNA, Ephys, Morph]. ``adatas_all`` holds the training cells followed by the validation
        cells (not the original cell order).
    """
    if folds is None:
        folds = ind_dict  # notebook-level global produced by partitions(...)
    train_idx, val_idx = folds[cv]['train'], folds[cv]['val']

    adata_rna, adata_ephys, adata_morph = adata_rna_raw.copy(), adata_ephys_raw.copy(), adata_morph_raw.copy()
    assert (adata_rna.X>=0).all(), "poluted input"

    adatas_train, adatas_test = [], []
    for mod in [adata_rna, adata_ephys, adata_morph]:
        mod.obs['label'] = mod.obs['cell_type_TEM']
        if split:
            # .copy() materialises the views, so assigning .X below does not rely on AnnData's
            # implicit view-to-copy conversion.
            m_train = mod[train_idx].copy()
            m_test = mod[val_idx].copy()
            # Fit on training cells only and reuse the same transform for validation cells.
            # (Previously the validation set was standardised with its own statistics.)
            scaler = preprocessing.StandardScaler().fit(m_train.X)
            m_train.X = scaler.transform(m_train.X)
            m_test.X = scaler.transform(m_test.X)
        else:
            scaler = preprocessing.StandardScaler().fit(mod.X)
            mod.X = scaler.transform(mod.X)
            m_train = mod[train_idx]
            m_test = mod[val_idx]

        adatas_train.append(m_train)
        adatas_test.append(m_test)

    adatas_all = [ad.concat([m_train, m_test]) for m_train, m_test in zip(adatas_train, adatas_test)]
    return adatas_train, adatas_test, adatas_all

In [5]:
from src.encoder_decoder_only.encoder_decoder_only_model import EncoderDecoderOnlyUnitedNet

## Load data set

In [6]:
# Number of stratified cross-validation folds passed to `partitions`: each fold validates on
# ~1/k_folds of the cells and trains on the remaining ~(k_folds-1)/k_folds.
k_folds=3

In [7]:
technique = 'patchseq'
data_path = f"../data/{technique}"
# Prefer Apple-silicon MPS (what this notebook was developed on), then CUDA, then CPU, so the
# notebook does not break on machines without an MPS device.
# NOTE(review): the later fix_model_device(...) call still hard-codes 'mps'.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
#load data
adata_rna_raw = sc.read_h5ad(f'{data_path}/adata_RNA_TEM.h5ad')
adata_ephys_raw = sc.read_h5ad(f'{data_path}/adata_Ephys_TEM.h5ad')
adata_morph_raw = sc.read_h5ad(f'{data_path}/adata_Morph_TEM.h5ad')
# CV folds are stratified on the RNA object's TEM cell-type labels and reused for every modality.
ind_dict = partitions(adata_rna_raw.obs['cell_type_TEM'], n_partitions=k_folds, seed=0)


In [8]:
# NOTE(review): this value is overwritten on every iteration of the training loop below, which uses
# "./saved_results/..." (not "../saved_results/...") plus a per-fold suffix — confirm the intended base path.
root_save_path = f"../saved_results/encoder_decoder/{technique}"

## Train Model

In [None]:
# Train one model per CV fold. Only the objects of the LAST fold survive the loop: `model`,
# `adatas_train`/`adatas_test`, `adata_last_fold`, `cv` and `root_save_path` are rebound every iteration.
for cv in range(k_folds):
  # NOTE(review): the third return value (train + val cells, per modality) is bound to `_`, and later
  # cells read `_` as "all cells". This presumably relies on IPython not overwriting a user-assigned
  # `_` with cell outputs — a named variable would be more robust.
  adatas_train,adatas_test,_ = pre_ps(adata_rna_raw,adata_ephys_raw,adata_morph_raw,cv,split=True)
  # Per-fold output directory; later cells load `{root_save_path}/train_best.pt`, i.e. the last fold's.
  root_save_path = f"./saved_results/encoder_decoder/{technique}_{cv}"
  model = EncoderDecoderOnlyUnitedNet(root_save_path, device=device, technique=encoder_decoder_only_patchseq_config)
  print(model.model.config[str_train_task])
  model.train(adatas_train,adatas_val = adatas_test, verbose=True)

  print(model.evaluate(adatas_test))
  adata_last_fold = adatas_test
  

In [10]:
# print("TRAINING FOR str_supervised_group_identigy_only")
# for cv in range(k_folds):
#     _,_,adatas_all = pre_ps(adata_rna_raw,adata_ephys_raw,adata_morph_raw,cv,split=False)
#     model.load_model(f"{root_save_path}/train_best.pt",device=torch.device(device))
#     model.model.device_in_use = device
#     model.train(adatas_all,verbose=True,init_classify=True)
  

In [50]:
# NOTE(review): despite its name, this holds only the validation split of the LAST CV fold
# (`adata_last_fold` is `adatas_test` from the final training-loop iteration), not all cells.
adatas_all = adata_last_fold

In [None]:
print('==============best finetune================')
# `root_save_path` was last rebound inside the training loop, so this loads the final fold's train_best.pt.
model.load_model(f"{root_save_path}/train_best.pt",device=torch.device(device))
# model.model.device_in_use = device
# NOTE(review): held-out (validation) cells are evaluated with stage="train" — confirm this is intended.
model.evaluate(adatas_all,give_losses=True,stage="train")


In [52]:
# losses = model.evaluate(adatas_all,give_losses=True,stage='train')
# NOTE(review): `predict_label` is not read by any later cell shown here; downstream code uses
# adata_fused.obs['predicted_label'], presumably populated by `model.infer`.
predict_label = model.predict_label(adatas_all)
adata_fused = model.infer(adatas_all)
# Ground-truth labels come from the first modality (RNA) object; 'label_less' is the major class,
# i.e. the text before the first '-'.
adata_fused.obs['label'] = list(adatas_all[0].obs['label'])
adata_fused.obs['label_less'] = [ct.split('-')[0] for ct in adata_fused.obs['label'].values]


In [53]:
from sklearn.utils.multiclass import unique_labels


# Rename each predicted cluster after the true label it best matches.
# 1) Hungarian matching on the (true x predicted) confusion matrix maximises total agreement.
pseudo_label = np.array(adata_fused.obs['predicted_label'].values)
cmat = confusion_matrix(adata_fused.obs['label'], pseudo_label)
ri, ci = linear_sum_assignment(-cmat)
# Reordered matrix, reused by the sub-type heatmap further down.
ordered_all = cmat[np.ix_(ri, ci)]
major_sub_names = {}
pred_labels_re_order = copy.deepcopy(pseudo_label)
# 2) Rows/columns of `cmat` follow the sorted union of true and predicted labels (what unique_labels
#    returns), so indexing that union with ri/ci yields (matched true label, predicted label) pairs.
for re_oder,(lb_correct,lb) in enumerate(zip(unique_labels(adata_fused.obs['label'], pseudo_label)[ri],
                                unique_labels(adata_fused.obs['label'], pseudo_label)[ci])):
  idx = pseudo_label==lb
  # Skip union entries that never occur as a prediction.
  if any(idx):
    # 3) New name = matched true label minus its last '-' component + running sub-type counter + '-Uni'.
    nm = '-'.join(lb_correct.split('-')[:-1])
    if nm in major_sub_names.keys():
      major_sub_names[nm]+=1
    else:
      major_sub_names[nm]=1
    
    pred_labels_re_order[idx] = f'{nm}-{major_sub_names[nm]}-Uni'

# NOTE(review): `pseudo_label` keeps the ORIGINAL predicted names; only obs['predicted_label'] is renamed.
adata_fused.obs['predicted_label'] = pred_labels_re_order


In [54]:
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score
def ordered_cmat(labels, pred):
    """
    Confusion matrix and accuracy under the best one-to-one class-to-cluster matching.

    The matching is found with the Hungarian algorithm (maximising the matched counts), and the
    confusion matrix is returned with its rows/columns permuted into that matching.

    :param labels: Label array
    :type labels: np.array
    :param pred: Predictions array
    :type pred: np.array
    :return: Matched accuracy (matched counts / all counts) and the reordered confusion matrix
    :rtype: Tuple[float, np.array]
    """
    counts = confusion_matrix(labels, pred)
    # Negate so the minimising solver maximises the number of matched samples.
    row_order, col_order = linear_sum_assignment(-counts)
    matched = counts[np.ix_(row_order, col_order)]
    return np.trace(matched) / matched.sum(), matched

# Ground truth = TEM labels; predictions = the model's raw (pre-renaming) cluster labels.
# These two were previously swapped. ARI, NMI and the matched accuracy are symmetric in their
# arguments, but the returned confusion matrix had the predictions on the rows.
labels = np.asarray(adata_fused.obs['label'])
predictions = pseudo_label
acc, ordered = ordered_cmat(labels, predictions)
metrics = {
    "confusion": ordered,  # rows: true labels, columns: matched predicted clusters
    "acc": acc,
    "ari": adjusted_rand_score(labels, predictions),
    "nmi": normalized_mutual_info_score(
        labels, predictions, average_method="geometric"
    ),
}

In [None]:
# Matching metrics for the last fold's validation cells.
metrics

In [None]:
# Number of distinct predicted clusters (raw names).
len(unique_labels(pseudo_label))

In [None]:
# Number of cells.
len(pseudo_label)

In [None]:
# Number of distinct true labels.
len(unique_labels(adata_fused.obs['label']))

In [None]:
model.model

In [None]:
sns.set_style('ticks')
# Create the output directory BEFORE the first savefig. It used to be created only after that
# call, so a fresh checkout failed with FileNotFoundError on major_matching_heatmap.pdf.
os.makedirs('./figures/', exist_ok=True)

# Heatmap 1: renamed predicted sub-types (rows) vs. true major class (columns), row-normalised.
adata_fused.obs['predicted_label_less'] = [ct.split('-')[0] for ct in adata_fused.obs['predicted_label'].values]
cmat = confusion_matrix(adata_fused.obs['predicted_label'], adata_fused.obs['label_less'])
# The matrix is indexed by the union of both label sets; dropping all-zero columns/rows keeps only
# true major classes as columns and predicted labels as rows.
cmat = cmat[:,cmat.sum(axis=0)!=0]
cmat = cmat[cmat.sum(axis=1)!=0,:]
cmat = (cmat.T / cmat.sum(axis=1)).T

fig,ax = plt.subplots(figsize=[1.2,5])
sns.heatmap(cmat,ax=ax,yticklabels=unique_labels(adata_fused.obs['predicted_label']),xticklabels=unique_labels(adata_fused.obs['label_less']),vmin=0, vmax=1)
ax.set_xlabel('TEM joint label')
fig.savefig('./figures/major_matching_heatmap.pdf')

# Heatmap 2: Hungarian-ordered sub-type confusion (`ordered_all` from the relabelling cell),
# transposed so rows are predicted clusters, each row normalised to sum to 1.
fig,ax = plt.subplots(figsize=[6,5])
ordered = ordered_all[:,ordered_all.sum(axis=0)!=0]
ordered = ordered[ordered.sum(axis=1)!=0,:]
ordered_re = ordered.T
ordered_norm = (ordered_re.T / ordered_re.sum(axis=1)).T

# NOTE(review): rows of ordered_norm follow the Hungarian column order, while the y tick labels are the
# alphabetically sorted renamed labels. These can disagree (e.g. '...-10-Uni' sorts before '...-2-Uni') — verify.
sns.heatmap(ordered_norm,ax=ax,xticklabels=unique_labels(adata_fused.obs['label']),yticklabels=unique_labels(adata_fused.obs['predicted_label']),vmin=0, vmax=1)
ax.set_xlabel('TEM joint label')
fig.savefig('./figures/sub_matching_heatmap.pdf')



In [None]:

# UMAPs of the fused embedding coloured by true / predicted sub-type and true / predicted major class.
# NOTE(review): sc.pl.umap needs adata_fused.obsm['X_umap']; no sc.pp.neighbors / sc.tl.umap call is
# visible in this notebook, so presumably model.infer provides it — verify.
# scanpy appends each `save=` string to the default file name ("umap") under sc.settings.figdir.
sc.pl.umap(adata_fused, 
           color=['label'], 
           palette='rainbow', 
           show=True, 
           edges=True, 
           edges_width=0.2, 
           edgecolor='k', 
           title='', 
           save='patch_seq_2D_orig_MET.pdf')
sc.pl.umap(adata_fused,
           color=['predicted_label'],
           palette='rainbow',
           show=True,
           edges=True,
           edges_width = 0.2,
           edgecolor='k',
           title='',
           save='patch_seq_2D_Uni_MET.pdf')


sc.pl.umap(adata_fused,color=['label_less'],palette='rainbow',show=True,edges=True,edges_width = 0.2,edgecolor='k',title='',save='patch_seq_2D_MET_comparison.pdf')
# NOTE(review): the "_no_legend" file still gets a legend (no legend_loc=None is passed) — confirm intent.
sc.pl.umap(adata_fused,color=['predicted_label_less'],palette='rainbow',show=True,edges=True,edges_width = 0.2,edgecolor='k',title='',save='patch_seq_2D_MET_comparison_no_legend.pdf')


### SHAP-IQ

First, try to reproduce the current SHAP values with `shapiq`.

In [58]:
import shapiq

In [None]:
# Validation-split AnnData objects (one per modality) of the last CV fold.
adatas_test

In [81]:
def get_fused_latent_codes(adatas):
    """
    Return the best head's fused latent codes for the cells in ``adatas`` as a NumPy array.

    A dataloader is built with ``batch_size`` equal to the number of cells in the first
    modality, so a single batch is expected to cover every cell. One forward pass is run on
    that batch, and ``model.model.fused_latents[model.model.best_head]`` is read afterwards
    (presumably populated by the forward pass).

    Args:
        adatas: list of per-modality AnnData objects, in the order the model expects.

    Returns:
        numpy array copied to CPU.

    Raises:
        ValueError: if the dataloader yields no batch.
    """
    from src.data import create_dataloader
    dataloader = create_dataloader(model.model, adatas, shuffle=False, batch_size=len(adatas[0]))

    model.model.eval()
    with torch.no_grad():
        batch = next(iter(dataloader), None)
        if batch is None:
            raise ValueError("dataloader produced no batches; is `adatas` empty?")
        modalities, labels = batch
        # The forward pass is run for its side effect on `fused_latents`; its return value is unused.
        model.model(modalities, labels)
        fused_latents = model.model.fused_latents[model.model.best_head]

    return fused_latents.detach().cpu().numpy()

In [89]:
# Fused latent codes of the last fold's validation cells; used below as shapiq background data.
fused_latent_codes = get_fused_latent_codes(adatas_all)

In [None]:
fused_latent_codes

In [None]:
adata_fused

In [101]:
# Position (row number) of the single cell explained below.
# NOTE(review): magic index; the global is also rebound later by the Pvalb explanation loop.
instance_id = 144

In [None]:
# Predicted label of the cell at position `instance_id`.
# .iloc makes the lookup explicitly positional. obs is presumably indexed by string cell names,
# where Series[int] relies on pandas' deprecated positional fallback.
adata_fused.obs['predicted_label'].iloc[instance_id]

In [None]:
# True label of the same cell (explicit positional lookup via .iloc).
adata_fused.obs['label'].iloc[instance_id]

In [84]:
# Define simple prediction function (returns class probabilities)
def predict_func(latent_codes):
    """
    Run the best head's projector -> cluster -> prob_layer on fused latent codes.

    ``latent_codes`` is a NumPy array of shape (n_samples, latent_dim). A single 1-D code of
    shape (latent_dim,) is treated as a batch of one. Returns a NumPy array with one row of
    ``prob_layer`` outputs per input row.
    """
    net = model.model
    head = net.best_head
    batch = latent_codes[np.newaxis, :] if latent_codes.ndim == 1 else latent_codes

    with torch.no_grad():
        inputs = torch.tensor(batch, dtype=torch.float32).to(net.device_in_use)
        cluster_out = net.clusters[head](net.projectors[head](inputs))
        probs = net.prob_layer(cluster_out)

    return probs.detach().cpu().numpy()

In [95]:
# First-order Shapley values (index="SV") of the classifier head w.r.t. the fused latent dimensions,
# with all fused latent codes as background data.
explainer = shapiq.TabularExplainer(
        model=predict_func,
        data=fused_latent_codes,
        index="SV",
        max_order=1
    )

In [126]:
# Explain one cell with a budget of 500 model evaluations.
sv = explainer.explain(x=fused_latent_codes[instance_id], budget=500)

In [None]:
print(sv)

In [None]:
# Largest training-cell index of the last CV fold (`cv` is left over from the training loop).
max_index = max(ind_dict[cv]['train'])
max_index

In [141]:
from src.modules import submodel_clus
sub = submodel_clus(model.model).to(model.device)
# Background data for DeepExplainer: per modality, one prototype row per predicted cluster
# (the mean feature vector of the cells assigned to it).
# The cluster names get their own variable. This cell used to rebind `unique_labels`, which shadowed
# sklearn's unique_labels() function that other cells call.
cluster_names = np.unique(pseudo_label)
cluster_prototype_features = []
for ad_x in adatas_all:
    # Every name comes from np.unique(pseudo_label), so each mask selects at least one cell.
    type_means = [np.mean(ad_x.X[pseudo_label == name], axis=0) for name in cluster_names]
    # Explicit float32: MPS has no float64 support, and the model presumably runs in float32.
    cluster_prototype_features.append(
        torch.tensor(np.array(type_means), dtype=torch.float32, device=model.device)
    )

In [142]:
# `shap` (not `shapiq`) provides DeepExplainer, and it is not explicitly imported elsewhere in
# this notebook, so a fresh kernel raised NameError here.
import shap

e = shap.DeepExplainer(sub, cluster_prototype_features)

In [None]:
model.device

In [None]:
# Option 3: Following the exact PatchSeq pattern
# NOTE(review): the DeepExplainer above was built with a background for all three modalities, but only
# one input is passed here, and `.shape` is read on a result that is presumably a list for multi-input
# explainers. This cell likely fails; the next cell explains all modalities together.
modality_to_explain = 0  # Start with first modality (e.g., RNA)
test_type = torch.tensor(adatas_all[modality_to_explain].X, device=model.device)
shap_values = e.shap_values([test_type], check_additivity=True)

print(f"SHAP values shape: {shap_values.shape}")
print(f"Explained modality {modality_to_explain}")

In [None]:
# Explain the last fold's validation cells, passing one tensor per modality (RNA, Ephys, Morph).
test_sample = [torch.tensor(ad_x.X, device=model.device) for ad_x in adatas_all]
shap_values = e.shap_values(test_sample, check_additivity=False)
print(f"Done! Shape: {len(shap_values)} modalities")
# NOTE(review): this rebinds `sv`, which previously held the shapiq explanation.
for i, sv in enumerate(shap_values):
    print(f"Modality {i}: {sv.shape}")

In [None]:
def map_shap_to_features_fixed(shap_values, adatas, sample_idx=0, class_idx=None):
    """
    Map one cell's SHAP values back to feature names, per modality, ranked by |SHAP|.

    Args:
        shap_values: list with one array per modality, shaped (n_samples, n_features, n_classes)
            or, for single-output explanations, (n_samples, n_features).
        adatas: list of AnnData objects (or anything exposing ``var_names``), same modality order.
        sample_idx: row (cell) of the SHAP arrays to analyse.
        class_idx: class to report for 3-D arrays; if None, the mean |SHAP| across classes is
            used. Ignored for 2-D arrays.

    Returns:
        dict modality name ('RNA', 'Ephys', 'Morph') -> list of dicts with keys 'feature_name',
        'shap_value', 'abs_shap_value', 'feature_index', sorted by 'abs_shap_value' descending
        (ties keep feature order).

    Raises:
        ValueError: if a modality's number of feature names differs from its number of SHAP
            values. Previously zip() silently truncated, which mislabels features.
    """
    modality_names = ['RNA', 'Ephys', 'Morph']
    results = {}

    for shap_mod, adata, mod_name in zip(shap_values, adatas, modality_names):
        if shap_mod.ndim == 2:
            # Single-output explanation: (n_samples, n_features).
            sample_shap = shap_mod[sample_idx]
        elif class_idx is not None:
            sample_shap = shap_mod[sample_idx, :, class_idx]
        else:
            # Mean absolute SHAP value across all classes.
            sample_shap = np.mean(np.abs(shap_mod[sample_idx, :, :]), axis=1)

        # NOTE(review): AnnData always has var_names, so the gene_symbols branch is only reachable
        # for non-AnnData inputs.
        if hasattr(adata, 'var_names'):
            feature_names = list(adata.var_names)
        elif hasattr(adata, 'var') and 'gene_symbols' in adata.var.columns:
            feature_names = list(adata.var['gene_symbols'])
        else:
            feature_names = [f"{mod_name}_Feature_{i}" for i in range(len(sample_shap))]

        if len(feature_names) != len(sample_shap):
            raise ValueError(
                f"{mod_name}: {len(feature_names)} feature names but {len(sample_shap)} SHAP values"
            )

        feature_importance = [
            {
                'feature_name': feature_name,
                'shap_value': float(shap_val),
                'abs_shap_value': abs(float(shap_val)),
                'feature_index': i,
            }
            for i, (shap_val, feature_name) in enumerate(zip(sample_shap, feature_names))
        ]
        feature_importance.sort(key=lambda x: x['abs_shap_value'], reverse=True)
        results[mod_name] = feature_importance

    return results

# Now you can use the original analyze_pvalb_features function:
def analyze_pvalb_features(shap_values, adatas_all, pseudo_label, cell_type_prefix='Pvalb', top_n=10):
    """
    Rank features by mean |SHAP| over all cells whose label starts with ``cell_type_prefix``.

    Args:
        shap_values: per-modality SHAP arrays (layout accepted by ``map_shap_to_features_fixed``),
            first axis indexed by cell.
        adatas_all: per-modality AnnData objects, ordered RNA, Ephys, Morph.
        pseudo_label: one string label per cell, aligned with the first axis of the SHAP arrays.
        cell_type_prefix: label prefix selecting the cells to aggregate over (e.g. 'Pvalb').
        top_n: number of features printed per modality.

    Returns:
        dict modality -> list of {'feature_name', 'mean_abs_shap', 'n_neurons'} sorted by
        'mean_abs_shap' descending, or None if no cell matches the prefix.

    Raises:
        ValueError: if ``pseudo_label`` does not have exactly one entry per explained cell.
    """
    # Labels and SHAP rows must describe the same cells. A label vector left over from a different
    # cell set (e.g. one fold vs. all cells) would otherwise silently select the wrong rows.
    n_explained = len(shap_values[0]) if len(shap_values) else 0
    if len(pseudo_label) != n_explained:
        raise ValueError(
            f"pseudo_label has {len(pseudo_label)} entries but the SHAP arrays explain {n_explained} cells"
        )

    cell_type_mask = np.array([label.startswith(cell_type_prefix) for label in pseudo_label])
    cell_type_indices = np.where(cell_type_mask)[0]

    if len(cell_type_indices) == 0:
        print(f"No {cell_type_prefix} neurons found!")
        return None

    print(f"Found {len(cell_type_indices)} {cell_type_prefix} neurons")

    subtypes = set([label for label in pseudo_label if label.startswith(cell_type_prefix)])
    print(f"Subtypes: {subtypes}")
    print("=" * 60)

    # Collect every selected cell's per-feature records, per modality.
    all_features = {mod: [] for mod in ['RNA', 'Ephys', 'Morph']}
    for neuron_idx in cell_type_indices:
        mapped_shap = map_shap_to_features_fixed(shap_values, adatas_all,
                                                 sample_idx=neuron_idx, class_idx=None)
        for modality, records in mapped_shap.items():
            all_features[modality].extend(records)

    # Average |SHAP| per feature name across the selected cells.
    aggregated_features = {}
    for modality, records in all_features.items():
        if not records:
            continue
        per_feature = {}
        for record in records:
            per_feature.setdefault(record['feature_name'], []).append(record['abs_shap_value'])
        aggregated = [
            {'feature_name': name, 'mean_abs_shap': np.mean(values), 'n_neurons': len(values)}
            for name, values in per_feature.items()
        ]
        aggregated.sort(key=lambda x: x['mean_abs_shap'], reverse=True)
        aggregated_features[modality] = aggregated

    for modality in ['RNA', 'Ephys', 'Morph']:
        if modality in aggregated_features:
            print(f"\nTop {top_n} features for {modality} ({cell_type_prefix}-specific):")
            print("-" * 50)
            for i, feature in enumerate(aggregated_features[modality][:top_n]):
                print(f"{i+1:2d}. {feature['feature_name']:25s} | "
                      f"Mean |SHAP|: {feature['mean_abs_shap']:8.4f}")

    return aggregated_features


# Pvalb feature ranking for the last fold's validation cells (shap_values / pseudo_label from the cells above).
pvalb_features = analyze_pvalb_features(shap_values, adatas_all, pseudo_label, 
                                       cell_type_prefix='Pvalb', top_n=10)

In [None]:
# Fix the device mismatch issue
def fix_model_device(model, target_device='mps'):
    """
    Move the wrapped network ``model.model`` and all of its parameters onto ``target_device``.

    Args:
        model: wrapper object whose ``.model`` attribute is the torch network.
        target_device: device string such as 'mps', 'cuda', 'cuda:1' or 'cpu'.

    Returns:
        The same wrapper object, with ``model.model`` moved and ``model.model.device_in_use``
        set to ``target_device``.

    NOTE(review): the wrapper's own ``model.device`` attribute (used by other cells) is not
    changed here.
    """
    print(f"Moving model to {target_device}...")
    target = torch.device(target_device)

    model.model = model.model.to(target)
    # Module.to() already moves registered submodules. The explicit moves are kept in case any of
    # these components is not registered as a submodule of model.model.
    for component in ('encoders', 'decoders', 'fusers', 'latent_projector', 'projectors', 'clusters'):
        setattr(model.model, component, getattr(model.model, component).to(target))

    model.model.device_in_use = target_device

    def _on_target(tensor):
        # Compare the device type, and the index only when one was requested, so a tensor on mps:0
        # counts as being on 'mps'. The old check compared device.type with the full string, so a
        # target like 'cuda:0' never matched.
        if tensor.device.type != target.type:
            return False
        if target.index is None:
            return True
        index = tensor.device.index if tensor.device.index is not None else 0
        return index == target.index

    for name, param in model.model.named_parameters():
        if not _on_target(param):
            print(f"Warning: {name} is on {param.device}, moving to {target_device}")
            param.data = param.data.to(target)

    print("Model device fix complete!")
    return model

# Apply the fix
# NOTE(review): hard-coded 'mps'; fix_model_device does not assign the wrapper's `model.device`.
model = fix_model_device(model, target_device='mps')

# Now try inference again
# `_` holds the last fold's train + val AnnData list bound in the training loop, i.e. every cell.
all_adata_fused = model.infer(_)
all_adata_fused.obs['label'] = list(_[0].obs['label'])
all_adata_fused.obs['label_less'] = [ct.split('-')[0] for ct in all_adata_fused.obs['label'].values]

In [None]:
len(all_adata_fused)

In [216]:
# Re-importing guards against `unique_labels` having been rebound to an array by the SHAP background cell.
from sklearn.utils.multiclass import unique_labels

# Same Hungarian relabelling as for the last fold's validation cells, now applied to all cells.
all_pseudo_label = np.array(all_adata_fused.obs['predicted_label'].values)
cmat = confusion_matrix(all_adata_fused.obs['label'], all_pseudo_label)
ri, ci = linear_sum_assignment(-cmat)
ordered_all = cmat[np.ix_(ri, ci)]
major_sub_names = {}
pred_labels_re_order = copy.deepcopy(all_pseudo_label)
for re_oder,(lb_correct,lb) in enumerate(zip(unique_labels(all_adata_fused.obs['label'], all_pseudo_label)[ri],
                               unique_labels(all_adata_fused.obs['label'], all_pseudo_label)[ci])):
 idx = all_pseudo_label==lb
 # Skip union entries that never occur as a prediction.
 if any(idx):
   # New name = matched true label minus its last '-' component + running counter + '-Uni'.
   nm = '-'.join(lb_correct.split('-')[:-1])
   if nm in major_sub_names.keys():
     major_sub_names[nm]+=1
   else:
     major_sub_names[nm]=1
   
   pred_labels_re_order[idx] = f'{nm}-{major_sub_names[nm]}-Uni'

# NOTE(review): `all_pseudo_label` keeps the ORIGINAL predicted names; only obs['predicted_label'] is renamed.
all_adata_fused.obs['predicted_label'] = pred_labels_re_order

In [None]:
# SHAP values for every cell (`_` = last fold's train + val AnnData list), one tensor per modality.
test_sample = [torch.tensor(ad_x.X, device=model.device) for ad_x in _]
shap_values = e.shap_values(test_sample, check_additivity=False)
print(f"Done! Shape: {len(shap_values)} modalities")
for i, sv in enumerate(shap_values):
    print(f"Modality {i}: {sv.shape}")

In [None]:
# For all cells
# NOTE(review): all_pseudo_label holds the raw predicted names (captured before the relabelling above).
analyze_pvalb_features(shap_values, _, all_pseudo_label, 
                                       cell_type_prefix='Pvalb', top_n=10)

In [441]:
import shapiq
import torch
import numpy as np
from src.modules import submodel_clus

# 1. Extract the clustering submodel (built from model.model) and place it on the wrapper's device.
clustering_submodel = submodel_clus(model.model).to(model.device)

# 2. Create a wrapper for the submodel
def submodel_wrapper(X, modality_dims=(1252, 68, 514)):
    """
    Wrapper to make submodel_clus compatible with SHAPIQ.

    Splits a flat feature matrix back into per-modality tensors and runs the global
    ``clustering_submodel`` on them.

    Args:
        X: array of shape (n_samples, sum(modality_dims)), with the modalities concatenated
            column-wise in the order used by ``prepare_data_for_shapiq``. A single 1-D row is
            also accepted.
        modality_dims: number of feature columns per modality, in order. The default is the
            layout this notebook was written for (RNA, Ephys, Morph); TODO confirm against
            ``[a.n_vars for a in adatas]``.

    Returns:
        numpy array with the submodel's output for each row of X.

    Raises:
        ValueError: if X does not have exactly sum(modality_dims) columns.
    """
    X = np.asarray(X)
    if X.ndim == 1:
        X = X.reshape(1, -1)

    n_expected = sum(modality_dims)
    if X.shape[1] != n_expected:
        # A mismatch used to be sliced silently, handing the model misaligned feature blocks.
        raise ValueError(
            f"X has {X.shape[1]} columns but modality_dims {list(modality_dims)} sum to {n_expected}"
        )

    bounds = np.cumsum([0, *modality_dims])
    modalities = [
        torch.tensor(X[:, start:stop], dtype=torch.float32, device=model.device)
        for start, stop in zip(bounds[:-1], bounds[1:])
    ]

    clustering_submodel.eval()
    with torch.no_grad():
        # submodel_clus.forward() expects one tensor per modality, as separate arguments.
        predictions = clustering_submodel(*modalities)

    if hasattr(predictions, 'cpu'):
        predictions = predictions.cpu().numpy()
    return predictions

# 3. Prepare your data (same as before)
def prepare_data_for_shapiq(adatas):
    """
    Concatenate the per-modality feature matrices of ``adatas`` column-wise into one dense array.

    Sparse matrices (anything exposing ``toarray``) are densified first; modality order is preserved.
    """
    dense_blocks = [
        adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X
        for adata in adatas
    ]
    return np.concatenate(dense_blocks, axis=1)

# # 4. Set up the explainer with the submodel
# X = prepare_data_for_shapiq(_)

# explainer = shapiq.TabularExplainer(
#     model=submodel_wrapper,
#     data=X,
#     index="SV",  # or "k-SII"
#     max_order=1
# )

# # 5. Compute explanations
# explanations = explainer.explain(X[0], budget=256)

In [None]:
# NOTE(review): `explanations` is only assigned by the commented-out code above or by a later cell,
# so this raises NameError on a fresh kernel.
explanations

In [None]:
import shapiq
import torch
import numpy as np

# 1. Create a simple wrapper for just ephys data
class EphysSubmodelWrapper:
    """
    Callable that feeds ONE modality's encoder output straight into the best head's
    projector -> cluster -> prob_layer, so shapiq can explain the prediction from that
    modality's features alone.

    NOTE(review): fusion is bypassed. The projector receives a single-modality latent instead of
    the fused latent codes it is fed elsewhere in this notebook, so outputs may differ from the
    full model's.
    """

    def __init__(self, full_model, device, modality_idx=1, class_idx=None):
        """
        Args:
            full_model: the underlying network (``model.model``).
            device: torch device on which input tensors are created.
            modality_idx: encoder to use; 1 is Ephys in this notebook's [RNA, Ephys, Morph] order.
            class_idx: if None (default), __call__ returns argmax class ids (original behaviour).
                Otherwise it returns the prob_layer output for this class. That is a continuous
                target, better suited to Shapley attribution than an arbitrary integer class id.
        """
        self.device = device
        self.modality_idx = modality_idx
        self.class_idx = class_idx
        self.encoder = full_model.encoders[modality_idx]
        # Kept for reference / backward compatibility; not used in __call__ (fusion is bypassed).
        self.fuser = full_model.fusers[full_model.best_head]
        self.projector = full_model.projectors[full_model.best_head]
        self.cluster = full_model.clusters[full_model.best_head]
        self.prob_layer = full_model.prob_layer

    def __call__(self, X):
        """
        X: numpy array (or tensor) of shape (n_samples, n_features) for the selected modality;
           a single 1-D sample is treated as a batch of one.
        Returns: numpy array of argmax class ids, or of the ``class_idx`` column of prob_layer's output.
        """
        if isinstance(X, np.ndarray):
            X = torch.tensor(X, dtype=torch.float32, device=self.device)
        if X.dim() == 1:
            X = X.unsqueeze(0)

        with torch.no_grad():
            latent = self.encoder(X)
            hidden = self.projector(latent)
            probs = self.prob_layer(self.cluster(hidden))
            if self.class_idx is None:
                out = torch.argmax(probs, dim=1)
            else:
                out = probs[:, self.class_idx]

        return out.cpu().numpy()

# 2. Prepare just the ephys data
def prepare_ephys_data(adatas):
    """
    Return the Ephys feature matrix (modality index 1) of ``adatas``.

    Sparse matrices (anything exposing ``toarray``) are densified; dense data is returned as-is.
    """
    ephys_matrix = adatas[1].X  # Ephys is typically index 1
    return ephys_matrix.toarray() if hasattr(ephys_matrix, 'toarray') else ephys_matrix

# 3. Set up the explainer
# Prepare your ephys data
# Ephys features of every cell (`_` = last fold's train + val AnnData list).
X_ephys = prepare_ephys_data(_)

# Create the wrapper
ephys_wrapper = EphysSubmodelWrapper(model.model, model.device)

# Create explainer
explainer = shapiq.TabularExplainer(
    model=ephys_wrapper,
    data=X_ephys,
    index="SV",  # Shapley values (first-order attributions)
    max_order=1   # first order only: no interaction terms
)

# # 4. Compute explanations for a subset
# print(f"Ephys data shape: {X_ephys.shape}")
# explanations = explainer.explain(X_ephys[0], budget=256)  # Explain first 10 samples

# explanations

In [335]:
# Positions of cells whose raw predicted label starts with 'Pvalb'.
cell_type_mask = np.array([label.startswith('Pvalb') for label in all_pseudo_label])
cell_type_indices = np.where(cell_type_mask)[0]

In [None]:
cell_type_indices

In [324]:
explanation = explainer.explain(X_ephys[0], budget=256)  # Explain only the first cell (budget 256)


In [None]:
explanation.plot_force(feature_names=adata_ephys_raw.var_names, abbreviate=False)

In [337]:
explanations = []
# Same configuration as the Ephys explainer created above.
explainer = shapiq.TabularExplainer(
    model=ephys_wrapper,
    data=X_ephys,
    index="SV",  # Shapley values (first-order attributions)
    max_order=1   # first order only: no interaction terms
)
# NOTE(review): this loop rebinds the global `instance_id` used by earlier cells.
for instance_id in cell_type_indices:
    x_explain = X_ephys[instance_id]
    si = explainer.explain(x=x_explain, budget=256)
    explanations.append(si)

In [342]:
def get_explanations(explainer: shapiq.Explainer, indices, X, budget=256):
    """
    Explain the rows ``X[i]`` for every ``i`` in ``indices`` with ``explainer``.

    Returns a list with one explanation per index, in the order of ``indices``.
    """
    return [explainer.explain(x=X[row], budget=budget) for row in indices]

In [351]:
# Pvalb cells again, with a budget of 5000 instead of the 256 used above.
pvalb_explanations = get_explanations(explainer, cell_type_indices, X_ephys, 5000)

In [None]:
shapiq.plot.bar_plot(pvalb_explanations, feature_names=adata_ephys_raw.var_names, show=True, abbreviate=False)

In [None]:
# Same plot for the budget-256 explanations computed earlier.
shapiq.plot.bar_plot(explanations, feature_names=adata_ephys_raw.var_names, show=True, abbreviate=False)

In [None]:
X_ephys.shape

### Try to reproduce with KernelSHAP estimation

In [353]:
from shapiq.approximator import KernelSHAP
# One player per Ephys feature (68 in this data set); derived from the data instead of hard-coded.
approximator = KernelSHAP(n=X_ephys.shape[1])

In [354]:
# Same Shapley-value explainer, with the KernelSHAP approximator passed explicitly.
kernal_shap_explainer = shapiq.TabularExplainer(
    model=ephys_wrapper,
    data=X_ephys,
    approximator=approximator,
    index="SV",
    max_order=1
)

In [359]:
pvalb__kernal_shap_explanations = get_explanations(kernal_shap_explainer, cell_type_indices, X_ephys, 1000)

In [None]:
shapiq.plot.bar_plot(pvalb__kernal_shap_explanations, feature_names=adata_ephys_raw.var_names, show=True, abbreviate=False)

In [None]:
from shapiq.approximator import SHAPIQ

In [389]:
# Second-order k-SII interactions via the SHAP-IQ approximator (despite the variable name, not KernelSHAP).
# One player per Ephys feature, derived from the data instead of the hard-coded 68.
approximator = SHAPIQ(n=X_ephys.shape[1], max_order=2, index='k-SII')
kernal_shap_explainer_2_order = shapiq.TabularExplainer(
    model=ephys_wrapper,
    data=X_ephys,
    approximator=approximator,
    index="k-SII",
    max_order=2,
    verbose=True
)

In [None]:
# With approximator, budget = 500, took 1 minute 19 seconds
# With no approximator, same budget, took 1 minute 24 seconds, the number of coalitions calculated is a bit smaller, 6.12 vs 6.34. Results are the same
# NOTE(review): the timings above were measured with budget=500; this call uses budget=1000.
si = kernal_shap_explainer_2_order.explain(x=X_ephys[5], budget=1000)

In [None]:
si

In [None]:
# Predicted (raw, pre-relabelling) vs. true label of the explained cell (row 5).
true_labels = all_adata_fused.obs['label'].values
print(f"Cell type is {all_pseudo_label[5]} for predicted label")
print(f"Cell type is {true_labels[5]} for true label")
shapiq.plot.bar_plot([si], feature_names=adata_ephys_raw.var_names, show=True, abbreviate=False)

In [None]:
# si.plot_network(feature_names=adata_ephys_raw.var_names, n_interactions=10)
# NOTE(review): draw_threshold=3.15 is a hand-picked value.
shapiq.plot.si_graph_plot(si, draw_threshold=3.15, show=True, plot_explanation=True)

### SHAP-IQ with an explicit approximator

In [365]:
# `X` was only defined in a commented-out cell above, so rebuild it here for a fresh kernel:
# RNA / Ephys / Morph features of every cell (`_` = last fold's train + val AnnData list), concatenated.
X = prepare_data_for_shapiq(_)
n = X.shape[1]

In [454]:

from shapiq.approximator import KernelSHAPIQ
shapiq_approximator_2_order = KernelSHAPIQ(n=n, max_order=2, index="k-SII")

# Second-order k-SII explainer over all concatenated RNA/Ephys/Morph features (n players).
explainer = shapiq.TabularExplainer(
    model=submodel_wrapper,
    data=X,
    approximator=shapiq_approximator_2_order,
    index="k-SII",  # same index as the approximator above
    max_order=2,
    verbose=True
)


### Run the calculation on all data

In [462]:
import numpy as np

def map_shap_to_features_fixed(shap_values, adatas, sample_idx=0, class_idx=None):
    """
    Map one cell's SHAP values back to feature names, per modality, ranked by |SHAP|.

    NOTE(review): this redefines a function of the same name defined earlier in the notebook;
    keep the two in sync or remove one.

    Args:
        shap_values: list with one array per modality, shaped (n_samples, n_features, n_classes)
            or, for single-output explanations, (n_samples, n_features).
        adatas: list of AnnData objects (or anything exposing ``var_names``), same modality order.
        sample_idx: row (cell) of the SHAP arrays to analyse.
        class_idx: class to report for 3-D arrays; if None, the mean |SHAP| across classes is
            used. Ignored for 2-D arrays.

    Returns:
        dict modality name ('RNA', 'Ephys', 'Morph') -> list of dicts with keys 'feature_name',
        'shap_value', 'abs_shap_value', 'feature_index', sorted by 'abs_shap_value' descending
        (ties keep feature order).

    Raises:
        ValueError: if a modality's number of feature names differs from its number of SHAP
            values. Previously zip() silently truncated, which mislabels features.
    """
    modality_names = ['RNA', 'Ephys', 'Morph']
    results = {}

    for shap_mod, adata, mod_name in zip(shap_values, adatas, modality_names):
        if shap_mod.ndim == 2:
            # Single-output explanation: (n_samples, n_features).
            sample_shap = shap_mod[sample_idx]
        elif class_idx is not None:
            sample_shap = shap_mod[sample_idx, :, class_idx]
        else:
            # Mean absolute SHAP value across all classes.
            sample_shap = np.mean(np.abs(shap_mod[sample_idx, :, :]), axis=1)

        # NOTE(review): AnnData always has var_names, so the gene_symbols branch is only reachable
        # for non-AnnData inputs.
        if hasattr(adata, 'var_names'):
            feature_names = list(adata.var_names)
        elif hasattr(adata, 'var') and 'gene_symbols' in adata.var.columns:
            feature_names = list(adata.var['gene_symbols'])
        else:
            feature_names = [f"{mod_name}_Feature_{i}" for i in range(len(sample_shap))]

        if len(feature_names) != len(sample_shap):
            raise ValueError(
                f"{mod_name}: {len(feature_names)} feature names but {len(sample_shap)} SHAP values"
            )

        feature_importance = [
            {
                'feature_name': feature_name,
                'shap_value': float(shap_val),
                'abs_shap_value': abs(float(shap_val)),
                'feature_index': i,
            }
            for i, (shap_val, feature_name) in enumerate(zip(sample_shap, feature_names))
        ]
        feature_importance.sort(key=lambda x: x['abs_shap_value'], reverse=True)
        results[mod_name] = feature_importance

    return results


def get_top_features_and_data(shap_values, adatas_all, pseudo_label, 
                              top_rna=100, top_ephys=68, top_morph=100, 
                              cell_type_prefix='Pvalb'):
    """
    Rank features by mean |SHAP| over one cell type and gather their values for all cells.

    For every cell whose label starts with ``cell_type_prefix``, per-feature SHAP
    values are obtained from ``map_shap_to_features_fixed`` (with ``class_idx=None``),
    the absolute values are averaged across those cells, and the ``top_*``
    highest-ranked features of each modality are kept. The raw values of the kept
    features are then extracted for ALL cells, columns ordered RNA -> Ephys -> Morph.

    Args:
        shap_values: SHAP values - list of arrays [modality] with shape (samples, features, classes)
        adatas_all: List of AnnData objects for each modality ([RNA, Ephys, Morph])
        pseudo_label: Cell type labels, one per cell (row)
        top_rna: Number of top RNA features to select
        top_ephys: Number of top Ephys features to select
        top_morph: Number of top Morph features to select
        cell_type_prefix: Cell type prefix to analyze

    Returns:
        feature_matrix: numpy array of shape (n_cells, n_selected). n_selected is at most
            top_rna + top_ephys + top_morph; it is smaller when a modality has fewer
            features or a ranked name cannot be found in the modality's feature names.
        feature_names: list of "<modality>_<feature>" names matching the column order
        top_features_info: dict modality -> ranked list of
            {'feature_name', 'mean_abs_shap', 'n_neurons'}
        feature_indices: dict modality -> list of original column indices, in matrix column order
        All four values are None when no cell matches ``cell_type_prefix``.
    """
    modality_names = ['RNA', 'Ephys', 'Morph']

    # str() so non-string labels (e.g. ints / NaN) do not crash the prefix test.
    cell_type_mask = np.array([str(label).startswith(cell_type_prefix) for label in pseudo_label],
                              dtype=bool)
    cell_type_indices = np.flatnonzero(cell_type_mask)

    if len(cell_type_indices) == 0:
        print(f"No {cell_type_prefix} neurons found!")
        # Four values, like the normal return, so callers can always unpack.
        return None, None, None, None

    print(f"Found {len(cell_type_indices)} {cell_type_prefix} neurons")

    # Show the subtypes found
    subtypes = {label for label in pseudo_label if str(label).startswith(cell_type_prefix)}
    print(f"Subtypes: {subtypes}")
    print("=" * 60)

    # modality -> feature name -> list of |SHAP| values (one per matching neuron).
    # Insertion order = first appearance, which keeps tie-breaking of the later sort stable.
    abs_shap_by_feature = {mod: {} for mod in modality_names}
    for neuron_idx in cell_type_indices:
        mapped_shap = map_shap_to_features_fixed(shap_values, adatas_all,
                                                 sample_idx=neuron_idx, class_idx=None)
        for modality, features in mapped_shap.items():
            groups = abs_shap_by_feature.setdefault(modality, {})
            for feature_data in features:
                groups.setdefault(feature_data['feature_name'], []).append(feature_data['abs_shap_value'])

    # Aggregate by mean across neurons and keep the top-N of each modality.
    top_counts = {'RNA': top_rna, 'Ephys': top_ephys, 'Morph': top_morph}
    top_features_per_modality = {}
    for modality in modality_names:
        groups = abs_shap_by_feature.get(modality)
        if not groups:
            continue

        aggregated = [
            {'feature_name': name, 'mean_abs_shap': np.mean(values), 'n_neurons': len(values)}
            for name, values in groups.items()
        ]
        aggregated.sort(key=lambda x: x['mean_abs_shap'], reverse=True)

        top_n = top_counts[modality]
        top_features_per_modality[modality] = aggregated[:top_n]

        print(f"\nTop {top_n} features for {modality} ({cell_type_prefix}-specific):")
        print("-" * 50)
        for i, feature in enumerate(top_features_per_modality[modality]):
            print(f"{i+1:2d}. {feature['feature_name']:25s} | "
                  f"Mean |SHAP|: {feature['mean_abs_shap']:8.4f}")

    # Extract the actual feature values for all cells
    total_cells = len(pseudo_label)
    total_features = sum(top_counts.values())  # requested count, reported below

    columns = []
    feature_names = []
    feature_indices = {}  # Store original indices for each modality

    for mod_idx, modality in enumerate(modality_names):
        if modality not in top_features_per_modality:
            continue

        adata = adatas_all[mod_idx]
        feature_indices[modality] = []

        # Same name-resolution precedence as map_shap_to_features_fixed.
        if hasattr(adata, 'var_names'):
            all_feature_names = list(adata.var_names)
        elif hasattr(adata, 'var') and 'gene_symbols' in adata.var.columns:
            all_feature_names = list(adata.var['gene_symbols'])
        else:
            all_feature_names = [f"{modality}_Feature_{i}" for i in range(adata.shape[1])]

        # First occurrence wins, matching list.index(); O(1) lookups instead of O(n).
        name_to_idx = {}
        for i, name in enumerate(all_feature_names):
            name_to_idx.setdefault(name, i)

        # Densify once per modality rather than once per selected feature.
        if hasattr(adata, 'X'):
            data_matrix = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X
        else:
            data_matrix = None

        for feature_info in top_features_per_modality[modality]:
            feature_name = feature_info['feature_name']

            original_idx = name_to_idx.get(feature_name)
            if original_idx is None:
                print(f"Warning: Feature {feature_name} not found in {modality} data")
                continue
            feature_indices[modality].append(original_idx)

            if data_matrix is not None:
                feature_values = data_matrix[:, original_idx]
            else:
                print(f"Warning: No data matrix found for {modality}")
                feature_values = np.zeros(total_cells)

            # ravel() so an np.matrix column (n, 1) fits a 1-D matrix column.
            columns.append(np.asarray(feature_values, dtype=float).ravel())
            feature_names.append(f"{modality}_{feature_name}")

    # Size the matrix by the features actually found, so columns and names always align.
    feature_matrix = np.zeros((total_cells, len(columns)))
    for col_idx, values in enumerate(columns):
        feature_matrix[:, col_idx] = values

    print(f"\nCreated feature matrix: {feature_matrix.shape}")
    print(f"Total features: {len(feature_names)}")
    print(f"Expected features: {total_features}")

    return feature_matrix, feature_names, top_features_per_modality, feature_indices


def analyze_pvalb_features_matrix(shap_values, adatas_all, pseudo_label, 
                                 cell_type_prefix='Pvalb', 
                                 top_rna=100, top_ephys=68, top_morph=100):
    """
    Wrapper returning the top-feature matrix, names and indices for one cell type.

    Returns:
        feature_matrix: numpy array of shape (n_cells, n_selected), where n_selected is at most
            top_rna + top_ephys + top_morph (268 with the defaults)
        feature_names: list of n_selected feature names, one per column
        feature_indices: dictionary mapping modality to original indices
        (None, None, None) when no cell label starts with ``cell_type_prefix``.
    """
    result = get_top_features_and_data(
        shap_values, adatas_all, pseudo_label,
        top_rna=top_rna, top_ephys=top_ephys, top_morph=top_morph,
        cell_type_prefix=cell_type_prefix
    )

    # The "no matching cells" path signals itself with None entries (and has historically
    # returned fewer than four values), so check before unpacking.
    if result is None or result[0] is None:
        return None, None, None

    feature_matrix, feature_names, feature_indices = result[0], result[1], result[3]
    return feature_matrix, feature_names, feature_indices


def create_submodel_wrapper_with_feature_mapping(clustering_submodel, feature_indices, adatas_all,
                                                original_modality_dims=(1252, 68, 514),
                                                selected_modality_dims=(100, 68, 100),
                                                device=None):
    """
    Create a wrapper that maps selected features back to original feature space using feature means as baseline

    The wrapper's input columns are the selected features concatenated RNA -> Ephys -> Morph,
    ``selected_modality_dims[k]`` columns per modality. Every feature that is not selected is
    held at its mean over the cells in ``adatas_all``, so the model always receives
    full-width modality inputs.

    Args:
        clustering_submodel: torch module called as ``clustering_submodel(rna, ephys, morph)``
        feature_indices: Dictionary with original feature indices for each modality, aligned
            with the selected columns of that modality
        adatas_all: List of AnnData objects ([RNA, Ephys, Morph]) to compute feature means from
        original_modality_dims: Original dimensions [RNA, Ephys, Morph]; only used as the width
            of an all-zero baseline when an object has no ``.X``
        selected_modality_dims: Selected dimensions [RNA, Ephys, Morph]
        device: torch device (or device string) for the model inputs. ``None`` uses the device
            of the model's first parameter, falling back to MPS when available, else CPU.

    Returns:
        Wrapper function that takes selected features (n_samples, sum(selected_modality_dims))
        -- a single 1-D sample is treated as one row -- and returns the model output as numpy.
    """
    modality_names = ['RNA', 'Ephys', 'Morph']

    # Put inputs where the model lives instead of assuming Apple's MPS backend.
    if device is None:
        try:
            device = next(clustering_submodel.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device('mps') if torch.backends.mps.is_available() else torch.device('cpu')
    elif isinstance(device, str):
        device = torch.device(device)

    # Pre-compute feature means for each modality
    modality_means = {}
    print("Computing feature means for baseline...")
    for mod_idx, (modality, adata) in enumerate(zip(modality_names, adatas_all)):
        if hasattr(adata, 'X'):
            data_matrix = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X
            # asarray().ravel() keeps the means 1-D even if X is an np.matrix.
            modality_means[modality] = np.asarray(np.mean(data_matrix, axis=0)).ravel()
        else:
            modality_means[modality] = np.zeros(original_modality_dims[mod_idx])
        print(f"{modality} means computed: shape {modality_means[modality].shape}")

    # Original column positions of the selected features, per modality (empty if none).
    index_arrays = {modality: np.asarray(feature_indices.get(modality, []), dtype=int)
                    for modality in modality_names}

    def submodel_wrapper(X):
        """
        Wrapper to make submodel_clus compatible with SHAPIQ
        X: numpy array of shape (n_samples, total_selected_features) - only top features
        Returns: numpy array of predictions
        """
        X = np.atleast_2d(np.asarray(X))
        batch_size = X.shape[0]
        full_modalities = []

        # Split the selected features back into modalities
        selected_start_idx = 0
        for modality, _orig_dim, sel_dim in zip(modality_names,
                                               original_modality_dims,
                                               selected_modality_dims):
            # Baseline: every feature at its mean (np.tile copies, so the cached means stay intact).
            full_modality_data = np.tile(modality_means[modality], (batch_size, 1))

            selected_end_idx = selected_start_idx + sel_dim
            selected_features = X[:, selected_start_idx:selected_end_idx]

            # Overwrite the selected positions; never map more indices than there are columns.
            original_indices = index_arrays[modality]
            n_mapped = min(len(original_indices), selected_features.shape[1])
            if n_mapped:
                full_modality_data[:, original_indices[:n_mapped]] = selected_features[:, :n_mapped]

            full_modalities.append(torch.tensor(full_modality_data, dtype=torch.float32,
                                                device=device))
            selected_start_idx = selected_end_idx

        # Run through the clustering submodel
        clustering_submodel.eval()
        with torch.no_grad():
            predictions = clustering_submodel(*full_modalities)

        # Convert to numpy
        if hasattr(predictions, 'cpu'):
            predictions = predictions.cpu().numpy()
        return predictions

    return submodel_wrapper


def _feasible_budget_allocation(modality_counts, total_budget_features):
    """Split ``total_budget_features`` across the modalities that have candidate features.

    Each such modality first gets ``max(1, total_budget_features // 10)`` slots. The rest of the
    budget is shared in proportion to ``modality_counts`` with largest-remainder rounding, so the
    allocations sum to exactly the budget (plain truncation silently loses slots). If the budget
    cannot even cover the floor share, the largest allocations are trimmed until it fits.
    Modalities without candidates get 0.
    """
    allocation = {modality: 0 for modality in modality_counts}
    present = [modality for modality, count in modality_counts.items() if count > 0]
    if not present or total_budget_features <= 0:
        return allocation

    min_per_modality = max(1, total_budget_features // 10)
    for modality in present:
        allocation[modality] = min_per_modality

    remaining_budget = total_budget_features - min_per_modality * len(present)
    if remaining_budget > 0:
        total_available = sum(modality_counts[m] for m in present)
        shares = {m: remaining_budget * modality_counts[m] / total_available for m in present}
        for m in present:
            allocation[m] += int(shares[m])
        # Hand out the slots lost to truncation, largest fractional part first (stable on ties).
        leftover = total_budget_features - sum(allocation.values())
        by_remainder = sorted(present, key=lambda m: shares[m] - int(shares[m]), reverse=True)
        for m in by_remainder[:leftover]:
            allocation[m] += 1

    # Only reachable when the floor share alone exceeds the budget (budget < number of modalities).
    while sum(allocation.values()) > total_budget_features:
        largest = max(present, key=lambda m: allocation[m])
        allocation[largest] -= 1
    return allocation


def _feasible_feature_names(adata, modality):
    """Feature names of one modality: var_names, else var['gene_symbols'], else synthetic names."""
    if hasattr(adata, 'var_names'):
        return list(adata.var_names)
    if hasattr(adata, 'var') and 'gene_symbols' in adata.var.columns:
        return list(adata.var['gene_symbols'])
    return [f"{modality}_Feature_{i}" for i in range(adata.shape[1])]


def get_computationally_feasible_features(shap_values, adatas_all, pseudo_label,
                                        total_budget_features=20,
                                        cell_type_prefix='Pvalb'):
    """
    Get a computationally feasible subset of top features while maintaining model performance

    Features are ranked per modality by ``get_top_features_and_data`` (mean |SHAP| over cells
    whose label starts with ``cell_type_prefix``). The budget is then split across modalities:
    a floor share each, the rest proportional to the number of ranked candidates per modality
    (not to their SHAP magnitude). The top-ranked features of each modality fill its share.

    Args:
        shap_values: SHAP values - list of arrays [modality] with shape (samples, features, classes)
        adatas_all: List of AnnData objects for each modality ([RNA, Ephys, Morph])
        pseudo_label: Cell type labels, one per cell (row)
        total_budget_features: Total number of features to select (should be ≤ 25 for 2^n feasibility)
        cell_type_prefix: Cell type prefix used to rank features

    Returns:
        reduced_feature_matrix: numpy array of shape (n_cells, n_selected), n_selected ≤ total_budget_features
        reduced_feature_names: list of "<modality>_<feature>" names, one per column
        reduced_feature_indices: dict modality -> original column indices (every modality key present)
        reduced_modality_dims: [n_RNA, n_Ephys, n_Morph] selected columns per modality, in column order
        All four values are None when no cell matches ``cell_type_prefix``.
    """
    modality_order = ['RNA', 'Ephys', 'Morph']

    # Get the full top features analysis first
    result = get_top_features_and_data(
        shap_values, adatas_all, pseudo_label,
        top_rna=100, top_ephys=68, top_morph=100,  # Get the full analysis
        cell_type_prefix=cell_type_prefix
    )
    # "No matching cells" is signalled with None entries; check before indexing into the result.
    if result is None or result[0] is None:
        return None, None, None, None
    top_features_info = result[2]

    # Count how many features we actually have per modality
    modality_counts = {m: len(top_features_info.get(m, [])) for m in modality_order}
    print(f"Available features per modality: {modality_counts}")

    allocation = _feasible_budget_allocation(modality_counts, total_budget_features)
    print(f"Feature allocation: {allocation} (total: {sum(allocation.values())})")

    # Build the reduced feature matrix
    total_cells = len(pseudo_label)
    columns = []
    reduced_feature_names = []
    reduced_feature_indices = {}

    for mod_idx, modality in enumerate(modality_order):
        reduced_feature_indices[modality] = []
        n_select = min(allocation[modality], modality_counts[modality])
        if n_select <= 0:
            continue

        adata = adatas_all[mod_idx]
        # First occurrence wins, matching list.index(); resolved once per modality.
        name_to_idx = {}
        for i, name in enumerate(_feasible_feature_names(adata, modality)):
            name_to_idx.setdefault(name, i)
        # Densify once per modality rather than once per selected feature.
        data_matrix = adata.X.toarray() if hasattr(adata.X, 'toarray') else adata.X

        for feature_info in top_features_info[modality][:n_select]:
            feature_name = feature_info['feature_name']
            original_idx = name_to_idx.get(feature_name)
            if original_idx is None:
                print(f"Warning: Feature {feature_name} not found in {modality}")
                continue

            reduced_feature_indices[modality].append(original_idx)
            columns.append(np.asarray(data_matrix[:, original_idx], dtype=float).ravel())
            reduced_feature_names.append(f"{modality}_{feature_name}")

    # Width = features actually found, so the matrix, names and modality dims always agree.
    reduced_feature_matrix = np.zeros((total_cells, len(columns)))
    for col_idx, values in enumerate(columns):
        reduced_feature_matrix[:, col_idx] = values

    reduced_modality_dims = [len(reduced_feature_indices[mod]) for mod in modality_order]
    actual_total_features = reduced_feature_matrix.shape[1]

    print(f"\nReduced feature matrix: {reduced_feature_matrix.shape}")
    print(f"Modality dimensions: {reduced_modality_dims}")
    print(f"2^{actual_total_features} = {2**actual_total_features:,} combinations (feasible: {2**actual_total_features < 1e8})")

    return reduced_feature_matrix, reduced_feature_names, reduced_feature_indices, reduced_modality_dims


In [None]:
# Step 1: select a small, computationally feasible feature subset for the interaction analysis.
# NOTE(review): `_` must be the [RNA, Ephys, Morph] AnnData list the SHAP values were computed on.
# If `_` is IPython's "last output" slot rather than a variable assigned explicitly upstream,
# any cell that displays a value silently rebinds it -- bind the list to a real name instead.
TOTAL_BUDGET_FEATURES = 20       # every extra feature doubles the 2^n coalition space
SHAP_CELL_TYPE_PREFIX = 'Pvalb'

assert len(_) == 3, "`_` is expected to hold the [RNA, Ephys, Morph] modality list"
reduced_feature_matrix, reduced_feature_names, reduced_feature_indices, reduced_modality_dims = get_computationally_feasible_features(
    shap_values, _, all_pseudo_label,
    total_budget_features=TOTAL_BUDGET_FEATURES,
    cell_type_prefix=SHAP_CELL_TYPE_PREFIX,
)
assert reduced_feature_matrix is not None, f"no {SHAP_CELL_TYPE_PREFIX} cells found in all_pseudo_label"

In [None]:
# Step 2: wrap the clustering submodel so it takes only the selected columns;
# every non-selected feature is held at its mean over the cells.
# NOTE(review): `_` must still be the same [RNA, Ephys, Morph] list passed in Step 1.
submodel_wrapper = create_submodel_wrapper_with_feature_mapping(
    clustering_submodel=clustering_submodel,
    feature_indices=reduced_feature_indices,
    adatas_all=_,
    # Full per-modality widths, read from the data rather than hard-coded as [1252, 68, 514].
    original_modality_dims=[adata.shape[1] for adata in _],
    selected_modality_dims=reduced_modality_dims,  # selected columns per modality, from Step 1
)

In [473]:
n = reduced_feature_matrix.shape[1]

In [474]:
shapiq_approximator_2_order = KernelSHAPIQ(n=n, max_order=2, index="k-SII")
# Step 3: Run SHAPIQ (now computationally feasible!)
explainer = shapiq.TabularExplainer(
    model=submodel_wrapper,
    data=reduced_feature_matrix,  # Shape: (n_cells, ~20)
    approximator=shapiq_approximator_2_order,
    index="k-SII",
    max_order=2,
    verbose=True
)

In [None]:
# Test prediction variance
preds = submodel_wrapper(reduced_feature_matrix[:100])
print(f"Prediction std: {np.std(preds)}")
print(f"Prediction range: [{np.min(preds):.6f}, {np.max(preds):.6f}]")

In [None]:
reduced_feature_names


In [None]:
a = explainer.explain(x=reduced_feature_matrix[5], budget = 10000)

In [None]:
explanations = explainer.explain_X(reduced_feature_matrix, budget=10000, n_jobs=8, random_state=0)

In [None]:
print(a.get_top_k_interactions(10))

In [None]:
shapiq.plot.bar_plot([a], feature_names=reduced_feature_names, show=True, abbreviate=False)

In [None]:
a.plot_si_graph(feature_names=reduced_feature_names, show=True)

In [None]:
fig, axes = shapiq.plot.stacked_bar_plot(
    interaction_values=a,
    feature_names=reduced_feature_names,
)
plt.show()

In [None]:
shapiq.plot.si_graph_plot(interaction_values=a, feature_names=reduced_feature_names, show=True, compactness=10, n_interactions=10)

In [None]:
a.plot_network(feature_names=reduced_feature_names, show=True)