In [None]:
import multiprocessing as mp

import numpy as np
import shap
import lightgbm as lgb
from sklearn.base import clone
from sklearn.tree import DecisionTreeClassifier
import fasttreeshap

def train_model(X, y):
    """Fit a default-configured decision tree classifier and return it.

    Parameters
    ----------
    X : feature matrix passed straight to ``DecisionTreeClassifier.fit``.
    y : class labels passed straight to ``DecisionTreeClassifier.fit``.

    Returns
    -------
    The fitted ``DecisionTreeClassifier`` instance.
    """
    # scikit-learn estimators return themselves from ``fit``, so the
    # construct/fit/return sequence collapses into a single expression.
    return DecisionTreeClassifier().fit(X, y)

def compute_shap_values(model, X_subsample):
    """Explain ``model`` on ``X_subsample`` with fasttreeshap and return an ndarray.

    Parameters
    ----------
    model : fitted tree model handed to ``fasttreeshap.TreeExplainer``.
    X_subsample : rows to explain.

    Returns
    -------
    numpy.ndarray
        ``np.array`` of whatever ``TreeExplainer.shap_values`` returns.
        The original author notes the layout as
        (num_classes, num_samples, num_features) -- TODO confirm for the
        installed fasttreeshap version.
    """
    # NOTE(review): n_jobs=-1 asks the explainer for all cores; when this runs
    # inside pool workers that may oversubscribe the machine -- confirm.
    tree_explainer = fasttreeshap.TreeExplainer(model, algorithm='auto', n_jobs=-1)
    per_class_values = tree_explainer.shap_values(X_subsample, check_additivity=True)
    return np.array(per_class_values)

def permute_and_compute_shap(args):
    """Pool worker: SHAP values of a null model refit on shuffled labels.

    Parameters
    ----------
    args : tuple
        ``(model, X, y, sample_idx)`` packed into one object so the function
        can be dispatched through ``multiprocessing.Pool``.
        ``model`` is used only as a template (its hyper-parameters are cloned),
        ``X``/``y`` are the full training data and ``sample_idx`` selects the
        rows of ``X`` to explain.

    Returns
    -------
    The result of ``compute_shap_values`` for the null model on
    ``X[sample_idx]``.
    """
    model, X, y, sample_idx = args

    # Forked workers inherit the parent's global NumPy RNG state, so calling
    # np.random.permutation here can yield the same shuffle in several workers.
    # A fresh generator seeded from OS entropy keeps permutations independent
    # (at the cost of run-to-run reproducibility of the null draws).
    rng = np.random.default_rng()
    y_permuted = rng.permutation(y)

    # The shuffled labels must actually be used: refit an unfitted copy of the
    # model on them. Previously the labels were shuffled and then ignored, so
    # every "permutation" reproduced the observed SHAP values exactly.
    null_model = clone(model).fit(X, y_permuted)

    X_subsample = X[sample_idx]  # explain only the shared subsample
    return compute_shap_values(null_model, X_subsample)

def permutation_test_shap(X, y, num_permutations=200, num_cores=mp.cpu_count(), sample_fraction=0.2, early_stopping_threshold=0.001):
    """Permutation test for SHAP values, evaluated on a random subsample of X.

    A model is trained once on the real labels and explained on a random
    subset of rows. ``permute_and_compute_shap`` is then run
    ``num_permutations`` times in a process pool, and for every SHAP entry the
    fraction of permutation results that are ``>=`` the observed value is
    reported as its p-value.

    Parameters
    ----------
    X : array with ``.shape`` and integer-array row indexing (e.g. ndarray).
    y : labels aligned with the rows of ``X``.
    num_permutations : int, number of permutation runs (must be >= 1).
    num_cores : int or None, upper bound on pool workers; ``None`` means
        ``mp.cpu_count()``.
    sample_fraction : float in (0, 1], fraction of rows to explain.
    early_stopping_threshold : float; the loop stops once every p-value is
        below this threshold or every p-value is above ``1 - threshold``.

    Returns
    -------
    (p_values, observed_shap) : two arrays of identical shape.

    Raises
    ------
    ValueError
        If ``num_permutations < 1`` or ``sample_fraction`` is outside (0, 1].
    """
    # Validate first: with zero permutations the loop body never runs and
    # there would be no p-values to return.
    if num_permutations < 1:
        raise ValueError("num_permutations must be at least 1")
    if not 0 < sample_fraction <= 1:
        raise ValueError("sample_fraction must be in the interval (0, 1]")

    num_samples = X.shape[0]

    # Step 1: Train once on real labels
    print("Training model...")
    model = train_model(X, y)

    # Step 2: Use subsampling for SHAP computation (never fewer than one row,
    # otherwise small datasets would produce an empty subsample).
    subsample_size = max(1, int(num_samples * sample_fraction))
    sample_idx = np.random.choice(num_samples, subsample_size, replace=False)
    observed_shap = compute_shap_values(model, X[sample_idx])

    # Step 3: Parallelized permutation SHAP computation
    print(f"Running permutation test with {num_permutations} permutations on {num_cores} cores...")
    requested_workers = mp.cpu_count() if num_cores is None else num_cores
    # No point in starting more workers than there are tasks.
    n_workers = max(1, min(requested_workers, num_permutations))

    exceed_count = np.zeros_like(observed_shap)  # per-entry count of permuted >= observed
    p_values = None
    # Lazily generated so early stopping does not pickle tasks it never needs.
    tasks = ((model, X, y, sample_idx) for _ in range(num_permutations))

    # The context manager terminates the pool on exit, including on early stop,
    # exceptions and KeyboardInterrupt, so no worker processes are left behind.
    with mp.Pool(processes=n_workers) as pool:
        # imap_unordered keeps all workers busy and streams results back as
        # they finish; calling apply_async(...).get() per iteration would block
        # on each task and run the whole test serially.
        finished = pool.imap_unordered(permute_and_compute_shap, tasks)
        for done, permuted_shap in enumerate(finished, start=1):
            exceed_count += (permuted_shap >= observed_shap)
            p_values = exceed_count / done

            # NOTE(review): with few completed permutations the p-values are
            # very coarse, so this criterion can trigger early -- confirm
            # whether a minimum permutation count is wanted.
            if np.all(p_values < early_stopping_threshold) or np.all(p_values > (1 - early_stopping_threshold)):
                print(f"Early stopping at permutation {done}")
                break

    return p_values, observed_shap

# Example: synthetic multi-class dataset with purely random features and labels
np.random.seed(42)
num_samples = 100
num_features = 10
num_classes = 10  # High number of unique labels

X = np.random.rand(num_samples, num_features)  # Simulated features
y = np.random.randint(0, num_classes, num_samples)  # Multi-class labels

# Run the permutation test. The worker count is capped at the number of CPUs
# actually available instead of always forking 50 processes.
p_values, observed_shap = permutation_test_shap(X, y, num_permutations=100, num_cores=min(50, mp.cpu_count()), sample_fraction=0.2, early_stopping_threshold=0.001)

# Example: Print results for the first few features of the first class
# (indexing assumes a (class, sample, feature) layout -- TODO confirm).
for feature_idx in range(5):
    print(f"Feature {feature_idx}, Class 0: SHAP = {observed_shap[0, :, feature_idx]}, p-value = {p_values[0, :, feature_idx]}")


Training model...
Running permutation test with 100 permutations on 50 cores...


Process ForkPoolWorker-85:
Process ForkPoolWorker-74:
Process ForkPoolWorker-76:
Process ForkPoolWorker-83:
Process ForkPoolWorker-79:
Process ForkPoolWorker-84:
Process ForkPoolWorker-68:
Process ForkPoolWorker-71:
Process ForkPoolWorker-78:
Process ForkPoolWorker-61:
Process ForkPoolWorker-81:
Process ForkPoolWorker-77:
Process ForkPoolWorker-63:
Process ForkPoolWorker-73:
Process ForkPoolWorker-80:
Process ForkPoolWorker-86:
Process ForkPoolWorker-72:
Process ForkPoolWorker-87:
Process ForkPoolWorker-70:
Process ForkPoolWorker-65:
Process ForkPoolWorker-69:
Process ForkPoolWorker-75:
Process ForkPoolWorker-62:
Process ForkPoolWorker-82:
Process ForkPoolWorker-67:
Process ForkPoolWorker-58:
Process ForkPoolWorker-66:
Process ForkPoolWorker-60:
Process ForkPoolWorker-64:
Process ForkPoolWorker-59:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Process ForkPoo

KeyboardInterrupt: 

  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()


  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kw

In [None]:
# NOTE(review): clf, X_transformed, y_genes and cell_type_indexes are not
# defined in the visible cells -- they must come from earlier notebook state.
clf.fit(X_transformed, y_genes)

# Explain the fitted classifier on the selected rows of the densified matrix.
shap_explainer = fasttreeshap.TreeExplainer(clf, algorithm='auto', n_jobs=-1)
shap_values = shap_explainer(
    X_transformed.toarray()[cell_type_indexes],
    check_additivity=True,
).values

# Mean absolute SHAP value along the first axis.
average = np.abs(shap_values).mean(axis=0)