# In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.multiclass import OneVsRestClassifier
from sklearn.model_selection import GridSearchCV
from scipy.stats import mannwhitneyu

# ==============================================================================
#  scPred Class Definition
# ==============================================================================
class scPred:
    """
    A Python implementation of the scPred algorithm for single-cell classification.

    Training pipeline (see ``train``): counts -> CPM -> log2(x + 1) ->
    standard scaling -> PCA -> per-class selection of informative principal
    components (two-sided Mann-Whitney U test, in-class vs. all other cells)
    -> one-vs-rest RBF SVM with probability estimates.

    ``predict`` pushes query cells through the fitted transforms and labels a
    cell "Unassigned" when its best class probability is below
    ``probability_threshold``.
    """

    def __init__(self, probability_threshold=0.9):
        """Create an untrained model.

        Args:
            probability_threshold: Minimum winning-class probability a cell
                needs in ``predict`` to keep its predicted label.
        """
        # Fitted StandardScaler / PCA / classifier; all None until train().
        self.scaler = None
        self.pca = None
        self.classifier = None
        # Sorted column indices (into the full PCA output) used as features.
        self.informative_pcs = None
        # Sorted unique training labels, as returned by np.unique.
        self.class_labels = None
        self.probability_threshold = probability_threshold

    def _cpm_transform(self, data):
        """Rescale every row so that it sums to one million (counts per million).

        Args:
            data: 2-D matrix of counts, one row per cell as passed by
                ``train``/``predict``. Objects exposing ``toarray`` (e.g.
                scipy sparse matrices) are densified first.

        Returns:
            A new float ndarray; the input is not modified.
        """
        if hasattr(data, "toarray"):
            data = data.toarray()
        # np.asarray (instead of .astype) also accepts plain nested sequences.
        data = np.asarray(data, dtype=float)
        total_counts = np.sum(data, axis=1, keepdims=True)
        # All-zero rows would divide by zero; dividing by 1 keeps them at zero.
        total_counts[total_counts == 0] = 1
        return (data / total_counts) * 1_000_000

    def _log2_transform(self, data):
        """Return ``log2(data + 1)`` element-wise."""
        return np.log2(data + 1)

    @staticmethod
    def _select_informative_pcs(pc_scores, y, class_labels, explained_variance,
                                p_value_threshold, variance_threshold):
        """Pick the PCs whose scores differ between any one class and the rest.

        Args:
            pc_scores: 2-D array, rows are cells and columns are PCs.
            y: Labels, one per row of ``pc_scores``.
            class_labels: The distinct labels to test, each one-vs-rest.
            explained_variance: Explained-variance ratio per PC column.
            p_value_threshold: A PC is kept for a class when its two-sided
                Mann-Whitney U p-value is strictly below this value.
            variance_threshold: PCs whose explained-variance ratio is not
                strictly above this value are never tested.

        Returns:
            Sorted list of column indices into ``pc_scores`` (union over all
            classes).
        """
        # A plain list compared with `==` yields a single bool, not a mask.
        y = np.asarray(y)
        candidate_indices = np.where(explained_variance > variance_threshold)[0]

        selected = set()
        for cell_type in class_labels:
            print(f"  - Finding informative PCs for class: {cell_type}")
            in_class_mask = (y == cell_type)
            for pc_index in candidate_indices:
                pc_col = pc_scores[:, pc_index]
                inside = pc_col[in_class_mask]
                outside = pc_col[~in_class_mask]
                # A group that is empty or constant is treated as
                # uninformative (p = 1.0) instead of being tested.
                p_val = 1.0
                if len(np.unique(inside)) > 1 and len(np.unique(outside)) > 1:
                    _, p_val = mannwhitneyu(inside, outside, alternative='two-sided')
                if p_val < p_value_threshold:
                    selected.add(int(pc_index))
        return sorted(selected)

    def _assign_labels(self, probabilities):
        """Turn class probabilities into final labels.

        Args:
            probabilities: 2-D array with one column per entry of
                ``self.class_labels``, in the same order.

        Returns:
            Tuple ``(final_labels, max_probs)``. ``final_labels`` is an object
            array holding the winning class label, or the string
            "Unassigned" where the winning probability is not
            ``>= self.probability_threshold``. ``max_probs`` is the winning
            probability per row.
        """
        max_probs = np.max(probabilities, axis=1)
        best_columns = np.argmax(probabilities, axis=1)
        # Object dtype keeps each label's own type. Mixing numeric labels with
        # the "Unassigned" string in np.where would turn every label into text.
        final_labels = np.asarray(self.class_labels, dtype=object)[best_columns]
        # Negated `>=` so that a NaN probability also ends up unassigned.
        final_labels[~(max_probs >= self.probability_threshold)] = "Unassigned"
        return final_labels, max_probs

    def train(self, X_train, y_train, p_value_threshold=0.05, variance_threshold=0.0001, perform_hpt=False):
        """Fit the scaler, the PCA and the SVM classifier.

        Args:
            X_train: 2-D count matrix, one row per cell; objects exposing
                ``toarray`` are densified.
            y_train: 1-D sequence of labels, one per row of ``X_train``.
            p_value_threshold: Significance cut-off for the per-class
                Mann-Whitney U test on each PC.
            variance_threshold: Minimum explained-variance ratio (exclusive)
                for a PC to be tested at all.
            perform_hpt: When True, tune ``C`` and ``gamma`` with a 3-fold
                ``GridSearchCV`` instead of using ``C=1.0, gamma='scale'``.

        Raises:
            ValueError: If ``y_train`` is not 1-D or its length differs from
                the number of rows of ``X_train``, or if no PC passes the
                selection.
        """
        print("Starting training process...")
        if hasattr(X_train, "toarray"):
            X_train = X_train.toarray()

        y_train = np.asarray(y_train)
        n_cells = np.shape(X_train)[0]
        if y_train.ndim != 1 or y_train.shape[0] != n_cells:
            raise ValueError(
                f"y_train must be 1-D with one label per row of X_train "
                f"(got shape {y_train.shape} for {n_cells} rows)."
            )

        # Drop any earlier fit so a failure below cannot leave an old
        # classifier paired with a new scaler/PCA.
        self.scaler = None
        self.pca = None
        self.classifier = None
        self.informative_pcs = None

        print("Step 1: Preprocessing data (CPM, Log2)...")
        X_cpm = self._cpm_transform(X_train)
        X_log = self._log2_transform(X_cpm)

        print("Step 2: Scaling data and performing PCA...")
        self.scaler = StandardScaler()
        X_scaled = self.scaler.fit_transform(X_log)
        self.pca = PCA()
        pc_scores = self.pca.fit_transform(X_scaled)

        print("Step 3: Selecting informative PCs...")
        self.class_labels = np.unique(y_train)
        self.informative_pcs = self._select_informative_pcs(
            pc_scores,
            y_train,
            self.class_labels,
            self.pca.explained_variance_ratio_,
            p_value_threshold,
            variance_threshold,
        )
        if not self.informative_pcs:
            raise ValueError("No informative PCs found. Try relaxing the thresholds.")
        print(f"Found {len(self.informative_pcs)} unique informative PCs.")

        X_train_final = pc_scores[:, self.informative_pcs]

        print("Step 4: Training the SVM classifier...")
        if perform_hpt:
            print("  - Performing hyperparameter tuning (GridSearchCV)...")
            param_grid = {'estimator__C': [0.1, 1, 10], 'estimator__gamma': ['scale', 'auto']}
            base_model = OneVsRestClassifier(SVC(probability=True, kernel='rbf'))
            classifier = GridSearchCV(base_model, param_grid, cv=3)
        else:
            classifier = OneVsRestClassifier(SVC(probability=True, kernel='rbf', C=1.0, gamma='scale'))
        classifier.fit(X_train_final, y_train)
        # Published only after a successful fit.
        self.classifier = classifier
        print("Training complete.")

    def predict(self, X_test):
        """Classify query cells with the trained model.

        Args:
            X_test: 2-D count matrix with the same columns as the training
                matrix; objects exposing ``toarray`` are densified.

        Returns:
            pandas DataFrame with one row per cell and the columns
            ``Predicted_Label`` (class label or "Unassigned"),
            ``Max_Probability`` and one ``Prob_<label>`` column per class.

        Raises:
            RuntimeError: If ``train`` has not completed successfully.
        """
        if self.scaler is None or self.pca is None or self.classifier is None:
            raise RuntimeError("Model not trained. Call train() first.")
        print("Starting prediction process...")
        if hasattr(X_test, "toarray"):
            X_test = X_test.toarray()

        print("Step 1: Preprocessing query data...")
        X_cpm = self._cpm_transform(X_test)
        X_log = self._log2_transform(X_cpm)

        print("Step 2: Scaling and projecting data...")
        X_scaled = self.scaler.transform(X_log)
        pc_scores_test = self.pca.transform(X_scaled)

        print("Step 3: Predicting class probabilities...")
        X_test_final = pc_scores_test[:, self.informative_pcs]
        # NOTE(review): column i is read as the probability of
        # self.class_labels[i]; this assumes the classifier orders its classes
        # like np.unique (sorted) - confirm against classifier.classes_.
        probabilities = self.classifier.predict_proba(X_test_final)

        print("Step 4: Assigning final labels...")
        final_labels, max_probs = self._assign_labels(probabilities)

        results_df = pd.DataFrame({'Predicted_Label': final_labels, 'Max_Probability': max_probs})
        for i, label in enumerate(self.class_labels):
            results_df[f'Prob_{label}'] = probabilities[:, i]
        print("Prediction complete.")
        return results_df

# ==============================================================================
#  Main Execution Block for pbmc3k dataset
# ==============================================================================
if __name__ == '__main__':
    # Demo: cluster the pbmc3k dataset with scanpy, use the cluster ids as
    # cell-type labels, then train and evaluate scPred on a 75/25 split.
    print("--- scPred on Scanpy pbmc3k Dataset ---")

    # --- 1. Load and process data to get cell type labels ---
    print("\nStep 1: Loading and preprocessing pbmc3k dataset to generate labels...")
    # Load the dataset
    adata = sc.datasets.pbmc3k()

    # Store raw counts before normalization
    adata.raw = adata

    # Standard preprocessing to find clusters (our 'cell types')
    sc.pp.filter_cells(adata, min_genes=200)
    sc.pp.filter_genes(adata, min_cells=3)
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
    # Slicing an AnnData returns a view; take an explicit copy because
    # sc.pp.scale below modifies the data in place.
    adata = adata[:, adata.var.highly_variable].copy()
    sc.pp.scale(adata, max_value=10)
    sc.tl.pca(adata, svd_solver='arpack')
    sc.pp.neighbors(adata, n_neighbors=10, n_pcs=40)
    sc.tl.leiden(adata) # The 'leiden' column in adata.obs will be our labels

    print("Preprocessing complete. Labels are in 'leiden' column.")
    print("Cell type distribution:")
    print(adata.obs['leiden'].value_counts())

    # --- 2. Prepare data for scPred ---
    # We use the raw counts for scPred, as it does its own normalization.
    # NOTE(review): this relies on adata.raw staying row-aligned with adata
    # after filter_cells - confirm for the installed anndata version.
    X = adata.raw.X
    y = adata.obs['leiden'].values

    print("\nData prepared for scPred:")
    print(f"  - Expression matrix shape: {X.shape}")
    print(f"  - Number of labels: {len(y)}\n")

    # --- 3. Split data into training and testing sets ---
    # Stratified so every cluster is represented in both subsets; the fixed
    # random_state makes the split reproducible.
    print("Step 2: Splitting data into training (75%) and testing (25%) sets...")
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=42, stratify=y
    )
    print(f"  - Training set size: {X_train.shape[0]} cells")
    print(f"  - Test set size: {X_test.shape[0]} cells\n")

    # --- 4. Initialize and train the scPred model ---
    model = scPred(probability_threshold=0.9)
    model.train(X_train, y_train)

    # --- 5. Predict on the test set and evaluate ---
    print("\nStep 3: Predicting labels for the test set and evaluating performance...")
    predictions_df = model.predict(X_test)

    print("\n--- Model Performance Evaluation ---")
    true_labels = y_test
    predicted_labels = predictions_df['Predicted_Label'].values

    # Cells below the probability threshold appear as their own "Unassigned"
    # row; zero_division=0 silences warnings for classes never predicted.
    print("\nClassification Report:")
    report = classification_report(true_labels, predicted_labels, zero_division=0)
    print(report)
