In [1]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import roc_curve, auc, accuracy_score, confusion_matrix, classification_report
import shap
import xgboost as xgb
import anndata as ad
import scipy.sparse
from tqdm import tqdm
import logging
import random
import h5py
import joblib

In [None]:
# 设置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# 设置全局随机种子
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# 可调参数
MAX_CELLS_PER_LABEL = 500
HIGH_VAR_GENES = 1000
MIN_CELLS_PER_LABEL = 50
TARGET_CANCER_TYPES = 9  # 目标癌种数
MIN_CELLS_PER_CANCER = 10  # 每种癌种每标签的最小采样量

# 读取 Ensembl 转换表
ENSEMBL_MAP = pd.read_csv("ensemble-tran.csv")
ENSEMBL_DICT = dict(zip(ENSEMBL_MAP['ensembl'], ENSEMBL_MAP['feature']))


In [None]:
import os
os.environ["SCIPY_ARRAY_API"] = "1"
from sklearn.model_selection import GridSearchCV, cross_val_score, StratifiedKFold, train_test_split
from sklearn.metrics import make_scorer, accuracy_score, roc_auc_score,roc_curve, auc
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import StackingClassifier
from sklearn.feature_selection import SelectKBest, f_classif
from lightgbm import LGBMClassifier
from catboost import CatBoostClassifier
from xgboost import XGBClassifier
import pandas as pd
import numpy as np
import warnings
import matplotlib.pyplot as plt
import shap
import pickle
from tabpfn import TabPFNClassifier
from tabpfn_extensions.post_hoc_ensembles.sklearn_interface import AutoTabPFNClassifier
from scipy.stats import loguniform,randint, uniform
model_param_grid = {
    "Random_Forest": (RandomForestClassifier(random_state=42), {
        'n_estimators': [50, 100, 200, 300, 500],
        'max_depth': [None, 5, 10, 20, 30, 50],
        'min_samples_split': [2, 5, 10],
        'min_samples_leaf': [1, 2, 4],
        'max_features': ['auto', 'sqrt', 'log2'],
        'bootstrap': [True, False],
        'criterion': ['gini', 'entropy'],
        'class_weight': [None, 'balanced']
    }),
    "Logistic Regression": (LogisticRegression(max_iter=1000), {
        'penalty': ['l1', 'l2', 'elasticnet', 'none'],
        'C': loguniform(1e-4, 1e4), 
        'solver': ['liblinear', 'saga', 'lbfgs', 'newton-cg', 'sag'],
        'max_iter': [100, 200, 500, 1000],
        'l1_ratio': [0.0, 0.25, 0.5, 0.75, 1.0]
    }),
    "LightGBM": (LGBMClassifier(verbose=-1,random_state=42,device='gpu'), {
        'num_leaves': randint(20, 150),                 
        'max_depth': randint(3, 15),                    
        'learning_rate': uniform(0.01, 0.2),            
        'n_estimators': randint(50, 500),               
        'min_child_samples': randint(10, 100),          
        'subsample': uniform(0.5, 0.5),                 
        'colsample_bytree': uniform(0.5, 0.5),          
        'reg_alpha': uniform(0.0, 1.0),                 
        'reg_lambda': uniform(0.0, 1.0),                
        'boosting_type': ['gbdt', 'dart', 'goss'],      
    'objective': ['binary']
    }),
    "CatBoost": (CatBoostClassifier(task_type = 'GPU',verbose=0, random_state=42), {
        'iterations': randint(50, 100),                
        'learning_rate': uniform(0.01, 0.3),             
        'depth': randint(4, 6),                         
        'l2_leaf_reg': uniform(1, 10),                  
        'bagging_temperature': uniform(0, 1),            
        # 'border_count': randint(32, 255),           
        'border_count': randint(32, 64),       
        'random_strength': uniform(1, 20),               
        # 'leaf_estimation_iterations': randint(1, 10),   
        'leaf_estimation_iterations':  [1],   
        'scale_pos_weight': [1, 2, 5],               
        # 'bootstrap_type': ['Bayesian', 'Bernoulli', 'MVS'],  
        'bootstrap_type': ['Bayesian'], 
        # 'grow_policy': ['SymmetricTree', 'Depthwise', 'Lossguide']  
        'grow_policy': ['SymmetricTree'] 
    }),
    "XGBoost": (XGBClassifier(tree_mothod='gpu_hist',use_label_encoder=False, eval_metric='logloss', random_state=42), {
        'max_depth': [3, 5, 7, 10],
        'learning_rate': [0.01, 0.05, 0.1, 0.3],
        'n_estimators': [50, 100, 200, 300],
        'reg_lambda': [0.1, 1.0, 10.0, 100.0],
        'reg_alpha': [0.0, 0.5, 1.0, 5.0],
        'subsample': [0.6, 0.8, 1.0],
        'colsample_bytree': [0.6, 0.8, 1.0]
    }),
    # "Stacking": (StackingClassifier(
    #     estimators=[
    #         ('rf', RandomForestClassifier(n_estimators=50, random_state=42)),
    #         ('lr', LogisticRegression(max_iter=2000)),
    #         ('lgb', LGBMClassifier(verbose=-1))
    #     ],
    #     final_estimator=LogisticRegression(max_iter=2000),
    #     cv=5
    # ),{
    #     'estimators__lgb__learning_rate': [0.01, 0.1, 0.2],
    #     'estimators__lgb__n_estimators': [100, 200],
    # }),
    # "tabpfn":(TabPFNClassifier( device="cuda") ,{})
    # "tabpfn":(TabPFNClassifier( ) ,{})
}
warnings.filterwarnings("ignore")
def check_h5ad_file(filepath):
    try:
        with h5py.File(filepath, 'r') as f:
            logging.info(f"H5AD 文件 {filepath} 可读，包含键：{list(f.keys())}")
            f.close()
        return True
    except Exception as e:
        logging.error(f"H5AD 文件读取失败：{str(e)}")
        return False

def SamplePick(adata, types, cluster, label):
    if types == 'all':
        adata1 = adata[adata.obs[cluster].isin(valid_cell_types)].copy()
    else:
        adata1 = adata[adata.obs[cluster] == types].copy()
    logging.info(f"细胞类型 {types} 的样本量: {adata1.shape[0]}")

    cancer_types = ['HCC']

    lists = {'responder': [], 'non-responder': []}
    for lb in ['responder', 'non-responder']:     
        # adata_label = adata1[adata1.obs['Combined_outcome'] == lb].copy()
        adata_label = adata1[adata1.obs['response'] == lb].copy()
        if adata_label.shape[0] < MIN_CELLS_PER_LABEL:
            logging.warning(f"细胞类型 {types} 标签 {lb} 样本量 {adata_label.shape[0]} 小于 {MIN_CELLS_PER_LABEL}，跳过")
            # return None, None, None
        
        sampled_cells = []
        sampled_cancers = set()
        # 按癌种采样
        for cancer in cancer_types:
            adata_cancer=adata_label

            n_available = adata_cancer.shape[0]

            n_sample = n_available
            if n_sample > 0:
                random.seed(RANDOM_SEED)
                indices = random.sample(range(n_available), n_sample)
                sampled_cells.append(adata_cancer[indices])
                sampled_cancers.add(cancer)
                logging.info(f"细胞类型 {types} 标签 {lb} 癌种 {cancer} 采样 {n_sample} 个细胞")

        # 合并采样结果
        if sampled_cells:
            adata_sampled = ad.concat(sampled_cells)
            lists[lb].append(adata_sampled)
        else:
            logging.warning(f"细胞类型 {types} 标签 {lb} 无可用样本，跳过")
            # return None, None, None
    
    # 合并 Favourable 和 Unfavourable 的样本
    if lists['responder'] and lists['non-responder']:
        adata_combined = ad.concat(lists['responder'] + lists['non-responder'])
        logging.info(f"细胞类型 {types} 采样完成，样本量: {len(adata_combined)}")
        # 统计采样后的癌种分布
        sampled_cancer_dist = adata_combined.obs['Cancer_type_update'].value_counts().to_dict()
        logging.info(f"细胞类型 {types} 采样后的癌种分布: {sampled_cancer_dist}")
        return adata_combined, np.array(adata_combined.obs[label]), sampled_cancer_dist
    else:
        logging.warning(f"细胞类型 {types} 无足够标签类别（需要 responder 和 non-responder")
        return None, None, None

def TumorClassify(subset, label,model_name='XGBoost'):
    cv = 5
    import scipy.sparse as sp
    label = np.array(label)
    train_idx, test_idx = train_test_split(np.arange(len(label)), test_size=0.2, 
                                          random_state=RANDOM_SEED, stratify=label)
    le = LabelEncoder()
    y_train = le.fit_transform(label[train_idx])
    y_test = le.transform(label[test_idx])
    X = subset.X.toarray() if sp.issparse(subset.X) else subset.X
    gene_means = X.mean(axis=0)
    valid_gene_mask = gene_means > 1e-3
    subset_ = subset[:, valid_gene_mask].copy()


    adata_train = subset_[train_idx].copy()
    adata_test = subset_[test_idx].copy()

    sc.pp.highly_variable_genes(adata_train, n_top_genes=HIGH_VAR_GENES, flavor='cell_ranger')
    selected_genes = adata_train.var.highly_variable
    adata_train = adata_train[:, selected_genes]
    adata_test = adata_test[:, selected_genes]

    X_train = adata_train.X.toarray() if scipy.sparse.issparse(adata_train.X) else adata_train.X
    X_test = adata_test.X.toarray() if scipy.sparse.issparse(adata_test.X) else adata_test.X

    imputer = SimpleImputer(strategy='mean')
    X_train = imputer.fit_transform(X_train)
    X_test = imputer.transform(X_test)
    # for model_name in model_param_grid:
    model,param_grid = model_param_grid[model_name]
    if model_name == "tabpfn":
        if X_train.shape[1] > 500:
            print("Reducing dimensions for TabPFN...")
            selector = SelectKBest(score_func=f_classif, k=500)
            X_train = selector.fit_transform(X_train, y_train)
            X_test = selector.transform(X_test)
    base_score = cross_val_score(model, X_train, y_train, cv=cv, scoring='roc_auc')
    print(f"Default AUROC (CV): {base_score.mean():.4f}")
    logging.info(f"Default AUROC (CV): {base_score.mean():.4f}")
    # grid = GridSearchCV(model, param_grid, cv=cv, scoring='roc_auc', n_jobs=-1, verbose=0)
    # grid = RandomizedSearchCV(model, param_grid, cv=cv, scoring='roc_auc', n_jobs=-1, verbose=0)
    grid = RandomizedSearchCV(model, param_grid, cv=cv, scoring='roc_auc', n_jobs=1, verbose=0) #gpu
    grid.fit(X_train, y_train)
    print(f"Best AUROC (CV): {grid.best_score_:.4f}")
    logging.info(f"Best Params: {grid.best_params_}")
    logging.info(f"Best AUROC (CV): {grid.best_score_:.4f}")
    best_model = grid.best_estimator_
    y_train_pred = best_model.predict(X_train)
    y_train_proba = best_model.predict_proba(X_train)[:, 1]
    train_acc = accuracy_score(y_train, y_train_pred)
    train_fpr, train_tpr, _ = roc_curve(y_train, y_train_proba)
    train_auc = auc(train_fpr, train_tpr)
    logging.info(f"训练集准确率: {train_acc:.2f}, 训练集 AUC: {train_auc:.2f}")

    y_pred = best_model.predict(X_test)
    y_proba = best_model.predict_proba(X_test)[:, 1]
    acc = accuracy_score(y_test, y_pred)
    fpr, tpr, thresholds = roc_curve(y_test, y_proba)
    roc_auc = auc(fpr, tpr)

    optimal_idx = np.argmax(tpr - fpr)
    optimal_threshold = thresholds[optimal_idx]
    y_pred_optimal = (y_proba >= optimal_threshold).astype(int)
    acc_optimal = accuracy_score(y_test, y_pred_optimal)
    logging.info(f"最优阈值: {optimal_threshold:.3f}, 优化后的准确率: {acc_optimal:.2f}")

    cm = confusion_matrix(y_test, y_pred)
    report = classification_report(y_test, y_pred, target_names=le.classes_)
    logging.info(f"混淆矩阵:\n{cm}")
    logging.info(f"分类报告:\n{report}")
    cell_type = list(set(adata_train.obs['Cell_Type'].values))
    if len(cell_type)>1:
        cell_type = 'all'
    else: 
        cell_type = cell_type[0]
    model_filename = f"{model_name}_model_{cell_type}.pkl"
    joblib.dump(best_model, model_filename)
    logging.info(f"模型已保存至: {model_filename}")

    le_filename = f"{model_name}_label_encoder_{cell_type}.pkl"
    joblib.dump(le, le_filename)
    logging.info(f"LabelEncoder 已保存至: {le_filename}")

    return acc, roc_auc, fpr, tpr, best_model, X_test, adata_test, le,X_train,adata_train,y_train,y_test

def plot_roc(fpr, tpr, auc_val, types,model_name='XGBoost'):
    if not os.path.exists(f'{model_name}_roc_{types}.pdf'):
        plt.figure()
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {auc_val:.2f})')
        plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'{model_name}_ROC Curve - {types}')
        plt.legend(loc='lower right')
        plt.tight_layout()
        plt.savefig(f'{model_name}_roc_{types}.pdf')
        plt.close()
def shap_analysis(model, X_test, adata_test, types,model_name='XGBoost'):
    file =f'feature_importance_{model_name}_{types}.pickle'
    if not os.path.exists(file) or not os.path.exists(f'{model_name}_shap_{types}.dot.pdf'):
        feature_names = adata_test.var_names.tolist()
        if isinstance(model, (RandomForestClassifier, LGBMClassifier, XGBClassifier, CatBoostClassifier)):
            explainer = shap.TreeExplainer(model, feature_names=feature_names)
        elif isinstance(model, LogisticRegression):
            explainer = shap.LinearExplainer(model, X_test, feature_perturbation="interventional")
        else:
            print(f"[警告] {str(model).split('(')[0]} 模型类型暂不支持 SHAP")

        shap_values = explainer(X_test, check_additivity=False)
        shap_vals = shap_values.values
        if shap_vals.shape[1] == X_test.shape[1] and shap_vals.ndim == 3:  # 特征在中间（错误）
            shap_vals = np.transpose(shap_vals, (0, 2, 1))

            shap_mean = np.abs(shap_vals).mean(axis=(0, 1))
        else:
            shap_mean = np.abs(shap_vals).mean(axis=0)
        
        ensembl_ids = adata_test.var.index.tolist()
        gene_names = [ENSEMBL_DICT.get(eid, eid) for eid in ensembl_ids]
        feature_importance = pd.DataFrame({
            'Feature': gene_names,
            'Mean SHAP Value': shap_mean
        }).sort_values(by='Mean SHAP Value', ascending=False).reset_index(drop=True)
        with open(file,'wb')as f:
            pickle.dump(feature_importance,f)
    else:
        feature_importance = pickle.load(open(file,'rb'))
    if not os.path.exists(f'{model_name}_shap_{types}.pdf'):
        plt.figure()
        shap.summary_plot(shap_values, X_test, feature_names=gene_names, plot_type='bar', show=False)
        plt.title(f'SHAP Summary - {types}')
        plt.tight_layout()
        plt.savefig(f'{model_name}_shap_{types}.pdf')
        plt.close()
    if not os.path.exists(f'{model_name}_shap_{types}.dot.pdf'):
        print(f'{model_name}_shap_{types}.dot.pdf')
        shap.summary_plot(shap_values, X_test, plot_type="dot")
        fig = plt.gcf()
        plt.savefig(f'{model_name}_shap_{types}.dot.pdf')
        plt.close()
    return feature_importance.head(10)

In [None]:
import sys
import os

filepath = 'test_adata_combined.h5ad'

import pickle
import scanpy as sc
count = 0
adata = None
if os.path.exists("test_adata_combined.pickle"):
    adata = pickle.load(open("test_adata_combined.pickle",'rb'))
else:
    adata = sc.read_h5ad(filepath)
    with open("test_adata_combined.pickle", "wb") as f:
        pickle.dump(adata, f)
while not adata:
    try:
        adata = sc.read_h5ad(filepath)
        with open("test_adata_combined.pickle", "wb") as f:
            pickle.dump(adata, f)
    except Exception as e:
        count+=1
        print('try num :',count)
print(adata)

AnnData object with n_obs × n_vars = 286843 × 28913
    obs: 'sample', 'tissue', 'patient', 'n_genes_by_counts', 'total_counts', 'total_counts_rp', 'pct_counts_rp', 'total_counts_mt', 'pct_counts_mt', 'total_counts_hb', 'pct_counts_hb', 'total_counts_hsp', 'pct_counts_hsp', 'doublet_scores', 'predicted_doublets', 'doublet_info', 'n_genes', 'S_score', 'G2M_score', 'phase', 'UMAP_1', 'UMAP_2', 'sub_cluster', 'major_cluster', 'batch', 'cell_type', 'response', 'pre_post'


In [None]:
adata.obs

Unnamed: 0,sample,tissue,patient,n_genes_by_counts,total_counts,total_counts_rp,pct_counts_rp,total_counts_mt,pct_counts_mt,total_counts_hb,...,G2M_score,phase,UMAP_1,UMAP_2,sub_cluster,major_cluster,batch,cell_type,response,pre_post
P53-post-P-CD8-GGCAATTGTGCGCTTG-1,P53-post-P-CD8,P,P53,1674,5891.0,2574.0,43.693768,227.0,3.853336,0.0,...,-0.120244,G1,11.528003,0.799107,CD8_C01_LEF1,,cd8,CD8T,non-responder,post
P58-post-P-CD8-GAATAAGTCCGCTGTT-1,P58-post-P-CD8,P,P58,2432,7200.0,2288.0,31.777779,398.0,5.527778,0.0,...,-0.081107,G1,12.740842,0.170495,CD8_C01_LEF1,,cd8,CD8T,responder,post
P3-af-P-CD8p-GTCACAATCTCAACTT-1,P3-af-P-CD8p,P,P3,1622,4478.0,1455.0,32.492184,201.0,4.488611,0.0,...,-0.045128,S,5.800357,4.500662,CD8_C02_GPR183,,cd8,CD8T,responder,post
P58-post-P-CD8-CACCAGGGTCGGATCC-1,P58-post-P-CD8,P,P58,1966,6532.0,2578.0,39.467239,210.0,3.214942,0.0,...,-0.122368,G1,11.725809,2.026451,CD8_C01_LEF1,,cd8,CD8T,responder,post
P3-af-P-CD8p-GCTGGGTCAGGCGATA-1,P3-af-P-CD8p,P,P3,853,1784.0,458.0,25.672644,83.0,4.652466,0.0,...,-0.203805,G1,0.680999,6.438696,CD8_C03_CX3CR1,,cd8,CD8T,responder,post
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
P18-pre-P-CTCTGGTCAGCGTAAG-1,P18-pre-P,P,P18,1396,2318.0,214.0,9.232097,96.0,4.141501,0.0,...,-0.059760,S,7.865705,-0.785990,CD8_C03_CX3CR1,CD8T,cd45,CD45,responder,pre
P27-pre-T-TGAGGGAAGAAGGCCT-1,P27-pre-T,T,P27,1319,3377.0,1045.0,30.944624,137.0,4.056855,0.0,...,0.030680,S,3.888998,1.845784,CD4_C03_CD44,CD4T,cd45,CD45,responder,pre
P27-pre-T-CGTAGGCTCGCATGGC-1,P27-pre-T,T,P27,1205,2150.0,80.0,3.720930,82.0,3.813953,0.0,...,-0.135053,G1,9.128882,4.708649,NK_C03_GZMK,ILC,cd45,CD45,responder,pre
P11-pre-P-ACACCCTCAAGGCTCC-1,P11-pre-P,P,P11,2415,6466.0,1061.0,16.408909,193.0,2.984844,0.0,...,1.040371,G2M,4.558650,1.222805,NK_C01_FCGR3A,ILC,cd45,CD45,responder,pre


In [6]:
adata.obs['Cell_Type'] = adata.obs['sub_cluster'].apply(lambda x:x.split('_')[0])
adata.obs['Cell_Type']


P53-post-P-CD8-GGCAATTGTGCGCTTG-1    CD8
P58-post-P-CD8-GAATAAGTCCGCTGTT-1    CD8
P3-af-P-CD8p-GTCACAATCTCAACTT-1      CD8
P58-post-P-CD8-CACCAGGGTCGGATCC-1    CD8
P3-af-P-CD8p-GCTGGGTCAGGCGATA-1      CD8
                                    ... 
P18-pre-P-CTCTGGTCAGCGTAAG-1         CD8
P27-pre-T-TGAGGGAAGAAGGCCT-1         CD4
P27-pre-T-CGTAGGCTCGCATGGC-1          NK
P11-pre-P-ACACCCTCAAGGCTCC-1          NK
P5-pre-P-CTTCTCTAGGAGCGTT-1          CD4
Name: Cell_Type, Length: 286843, dtype: object

In [7]:
# 检查 pre_post 列是否存在
if 'pre_post' not in adata.obs.columns:
    logging.error("adata.obs 中缺少 'pre_post' 列，程序退出")
    sys.exit()
logging.info(f"pre_post 列的唯一值: {adata.obs['pre_post'].unique().tolist()}")

# 筛选 pre_post == 'Pre' 的行
adata_ = adata[adata.obs['pre_post'] == 'pre'].copy()
logging.info(f"筛选后 pre_post == 'pre' 的样本量: {adata_.shape[0]}")

2025-09-13 10:42:10,455 - INFO - pre_post 列的唯一值: ['post', 'pre']
2025-09-13 10:42:12,033 - INFO - 筛选后 pre_post == 'pre' 的样本量: 98270


In [8]:
# 检查 pre_post 列是否存在
if 'pre_post' not in adata_.obs.columns:
    logging.error("adata.obs 中缺少 'pre_post' 列，程序退出")
    sys.exit()

# 检查 pre_post 列的唯一值
logging.info(f"pre_post 列的唯一值: {adata_.obs['pre_post'].unique().tolist()}")

# 筛选 pre_post == 'Pre' 的行
adata_ = adata_[adata_.obs['pre_post'] == 'pre'].copy()
logging.info(f"筛选后 pre_post == 'Pre' 的样本量: {adata_.shape[0]}")

# 如果筛选后样本量为 0，打印更多诊断信息并退出
if adata_.shape[0] == 0:
    logging.error("筛选 pre_post == 'pre' 后样本量为 0，请检查数据")
    logging.info(f"原始 adata 的列: {adata_.obs.columns.tolist()}")
    logging.info(f"response 列的唯一值: {adata_.obs['response'].unique().tolist() if 'response' in adata_.obs.columns else '列不存在'}")
    sys.exit()

# 筛选 Combined_outcome
if 'response' not in adata_.obs.columns:
    logging.error("adata_.obs 中缺少 'response' 列，程序退出")
    sys.exit()


2025-09-13 10:42:13,048 - INFO - pre_post 列的唯一值: ['pre']


2025-09-13 10:42:14,936 - INFO - 筛选后 pre_post == 'Pre' 的样本量: 98270


In [None]:
adata_.obs

Unnamed: 0,sample,tissue,patient,n_genes_by_counts,total_counts,total_counts_rp,pct_counts_rp,total_counts_mt,pct_counts_mt,total_counts_hb,...,phase,UMAP_1,UMAP_2,sub_cluster,major_cluster,batch,cell_type,response,pre_post,Cell_Type
P16-pre-P-CD8p2-GGCGACTTCCACGTTC-1,P16-pre-P-CD8p,P,P16,1427,4671.0,1542.0,33.012203,154.0,3.296938,0.0,...,S,-0.212849,2.458946,CD8_C03_CX3CR1,,cd8,CD8T,non-responder,pre,CD8
P16-pre-P-CD8p1-CACCACTCAACAACCT-1,P16-pre-P-CD8p,P,P16,1835,7281.0,3163.0,43.441833,264.0,3.625876,0.0,...,G1,7.944293,-0.144074,CD8_C01_LEF1,,cd8,CD8T,non-responder,pre,CD8
P16-pre-P-CD8n-TATTACCGTCCAACTA-1,P16-pre-P-CD8n,P,P16,959,2529.0,996.0,39.383156,231.0,9.134046,0.0,...,G1,5.682766,3.407964,CD8_C02_GPR183,,cd8,CD8T,non-responder,pre,CD8
P16-pre-T-CD8p-CTAATGGGTTCCATGA-1,P16-pre-T-CD8p,T,P16,1081,2505.0,496.0,19.800400,72.0,2.874251,0.0,...,S,0.651383,-3.216115,CD8_C08_GZMK,,cd8,CD8T,non-responder,pre,CD8
P16-pre-P-CD8p1-TGACTTTAGCGTGAGT-1,P16-pre-P-CD8p,P,P16,894,1815.0,353.0,19.449036,205.0,11.294765,0.0,...,G1,0.857628,-1.018432,CD8_C08_GZMK,,cd8,CD8T,non-responder,pre,CD8
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
P18-pre-P-CTCTGGTCAGCGTAAG-1,P18-pre-P,P,P18,1396,2318.0,214.0,9.232097,96.0,4.141501,0.0,...,S,7.865705,-0.785990,CD8_C03_CX3CR1,CD8T,cd45,CD45,responder,pre,CD8
P27-pre-T-TGAGGGAAGAAGGCCT-1,P27-pre-T,T,P27,1319,3377.0,1045.0,30.944624,137.0,4.056855,0.0,...,S,3.888998,1.845784,CD4_C03_CD44,CD4T,cd45,CD45,responder,pre,CD4
P27-pre-T-CGTAGGCTCGCATGGC-1,P27-pre-T,T,P27,1205,2150.0,80.0,3.720930,82.0,3.813953,0.0,...,G1,9.128882,4.708649,NK_C03_GZMK,ILC,cd45,CD45,responder,pre,NK
P11-pre-P-ACACCCTCAAGGCTCC-1,P11-pre-P,P,P11,2415,6466.0,1061.0,16.408909,193.0,2.984844,0.0,...,G2M,4.558650,1.222805,NK_C01_FCGR3A,ILC,cd45,CD45,responder,pre,NK


In [9]:
# 检查 Combined_outcome 列的唯一值
logging.info(f"response 列的唯一值: {adata_.obs['response'].unique().tolist()}")

adata_ = adata_[adata_.obs['response'].isin(['responder', 'non-responder']), :]
logging.info(f"筛选后 response 包含 response/non-responder 的样本量: {adata_.shape[0]}")

# 如果筛选后样本量为 0，退出
if adata_.shape[0] == 0:
    logging.error("筛选 response 后样本量为 0，请检查数据")
    sys.exit()

2025-09-13 10:42:16,751 - INFO - response 列的唯一值: ['non-responder', 'responder']


2025-09-13 10:42:16,972 - INFO - 筛选后 response 包含 response/non-responder 的样本量: 98270


In [None]:
adata_.obs.shape

(98270, 29)

In [10]:
pd.set_option('display.max_columns', None)
adata_.obs['Cancer_type_update'] = ['HCC']*adata_.obs.shape[0]
adata_.obs

Unnamed: 0,sample,tissue,patient,n_genes_by_counts,total_counts,total_counts_rp,pct_counts_rp,total_counts_mt,pct_counts_mt,total_counts_hb,pct_counts_hb,total_counts_hsp,pct_counts_hsp,doublet_scores,predicted_doublets,doublet_info,n_genes,S_score,G2M_score,phase,UMAP_1,UMAP_2,sub_cluster,major_cluster,batch,cell_type,response,pre_post,Cell_Type,Cancer_type_update
P16-pre-P-CD8p2-GGCGACTTCCACGTTC-1,P16-pre-P-CD8p,P,P16,1427,4671.0,1542.0,33.012203,154.0,3.296938,0.0,0.0,13.0,0.278313,0.087618,False,False,1427,0.037808,-0.000656,S,-0.212849,2.458946,CD8_C03_CX3CR1,,cd8,CD8T,non-responder,pre,CD8,HCC
P16-pre-P-CD8p1-CACCACTCAACAACCT-1,P16-pre-P-CD8p,P,P16,1835,7281.0,3163.0,43.441833,264.0,3.625876,0.0,0.0,15.0,0.206016,0.112999,False,False,1835,-0.002969,-0.124453,G1,7.944293,-0.144074,CD8_C01_LEF1,,cd8,CD8T,non-responder,pre,CD8,HCC
P16-pre-P-CD8n-TATTACCGTCCAACTA-1,P16-pre-P-CD8n,P,P16,959,2529.0,996.0,39.383156,231.0,9.134046,0.0,0.0,8.0,0.316331,0.133690,False,False,959,-0.188267,-0.082084,G1,5.682766,3.407964,CD8_C02_GPR183,,cd8,CD8T,non-responder,pre,CD8,HCC
P16-pre-T-CD8p-CTAATGGGTTCCATGA-1,P16-pre-T-CD8p,T,P16,1081,2505.0,496.0,19.800400,72.0,2.874251,0.0,0.0,69.0,2.754491,0.064125,False,False,1081,0.139061,0.017359,S,0.651383,-3.216115,CD8_C08_GZMK,,cd8,CD8T,non-responder,pre,CD8,HCC
P16-pre-P-CD8p1-TGACTTTAGCGTGAGT-1,P16-pre-P-CD8p,P,P16,894,1815.0,353.0,19.449036,205.0,11.294765,0.0,0.0,9.0,0.495868,0.222841,False,False,894,-0.032617,-0.149058,G1,0.857628,-1.018432,CD8_C08_GZMK,,cd8,CD8T,non-responder,pre,CD8,HCC
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
P18-pre-P-CTCTGGTCAGCGTAAG-1,P18-pre-P,P,P18,1396,2318.0,214.0,9.232097,96.0,4.141501,0.0,0.0,26.0,1.121657,0.137306,False,False,1396,0.041544,-0.059760,S,7.865705,-0.785990,CD8_C03_CX3CR1,CD8T,cd45,CD45,responder,pre,CD8,HCC
P27-pre-T-TGAGGGAAGAAGGCCT-1,P27-pre-T,T,P27,1319,3377.0,1045.0,30.944624,137.0,4.056855,0.0,0.0,131.0,3.879183,0.045350,False,False,1319,0.100014,0.030680,S,3.888998,1.845784,CD4_C03_CD44,CD4T,cd45,CD45,responder,pre,CD4,HCC
P27-pre-T-CGTAGGCTCGCATGGC-1,P27-pre-T,T,P27,1205,2150.0,80.0,3.720930,82.0,3.813953,0.0,0.0,181.0,8.418605,0.066667,False,False,1205,-0.219012,-0.135053,G1,9.128882,4.708649,NK_C03_GZMK,ILC,cd45,CD45,responder,pre,NK,HCC
P11-pre-P-ACACCCTCAAGGCTCC-1,P11-pre-P,P,P11,2415,6466.0,1061.0,16.408909,193.0,2.984844,0.0,0.0,49.0,0.757810,0.120760,False,False,2415,-0.131148,1.040371,G2M,4.558650,1.222805,NK_C01_FCGR3A,ILC,cd45,CD45,responder,pre,NK,HCC


In [None]:
os.environ['CATBOOST_GPU_MEMORY_PART'] = '0.8'

In [None]:
adata_.obs['Cancer_type_update']

logging.info(f"adata.X 数据类型: {adata_.X.dtype}, 最小值: {adata_.X.min()}, 最大值: {adata_.X.max()}")

cluster_col = 'Cell_Type'
if cluster_col not in adata_.obs.columns:
    logging.error("adata.obs 中缺少 'Cell_Type' 列，程序退出")
    sys.exit()

logging.info("Cell_Type 分布：\n" + str(adata_.obs['Cell_Type'].value_counts()))
logging.info("response 分布：\n" + str(adata_.obs.groupby(['Cell_Type', 'response']).size()))

valid_cell_types = []
for types in set(adata_.obs[cluster_col]):
    adata1 = adata_[adata_.obs[cluster_col] == types]
    label_counts = adata1.obs['response'].value_counts()
    if (label_counts.get('responder', 0) >= MIN_CELLS_PER_LABEL and 
        label_counts.get('non-responder', 0) >= MIN_CELLS_PER_LABEL):
        valid_cell_types.append(types)
    else:
        logging.warning(f"跳过细胞类型 {types}：样本量不足或缺少标签类别")

valid_cell_types_ =  ['all'] +valid_cell_types
print(valid_cell_types)

print('valid_cell_types:',valid_cell_types_)
model_reult = {}
for types in tqdm(valid_cell_types_, desc="处理细胞类型"):
    subset = None
    model_reult[types] = {}
    subset, labels, cancer_dist = SamplePick(adata_, types, cluster_col, 'response')
    
    if subset is None:
        continue
    for model_name in model_param_grid:
        print(model_name)
        if model_name != 'XGBoost':
            continue
        file =f'model_reult_{types}_{model_name}.pickle'
        if not os.path.exists(file):
            dflist = []
            metrics_list = []
            acc, auc_val, fpr, tpr, model, X_test, adata_test, le,X_train,adata_train,y_train,y_test = TumorClassify(subset, labels,model_name)
            types_clean = types.replace(' ', '_')
            plot_roc(fpr, tpr, auc_val, types_clean,model_name)
            logging.info(f"{types_clean} — 准确率: {acc:.2f}, AUC: {auc_val:.2f}, 采样后癌种分布: {cancer_dist}")

            metrics_list.append({
                'Cell_Type': types_clean,
                'Accuracy': acc,
                'AUC': auc_val,
                'Cancer_Distribution': str(cancer_dist)  # 将癌种分布存为字符串
            })
            
        else:
            print('loding data:',file)
            model_reult[types][model_name] = pickle.load(open(file,'rb'))
        
        if isinstance(model_param_grid[model_name][0], (XGBClassifier, LGBMClassifier, CatBoostClassifier, RandomForestClassifier)):
            feature_df = shap_analysis(model, X_test, adata_test, types_clean)
            if feature_df:
                feature_df['Cell_Type'] = types_clean
                dflist.append(feature_df)

                if dflist:
                    df_all = pd.concat(dflist)
                    df_all.to_csv(f"{model_name}_SHAP_Feature_Top10_Table.tsv", sep='\t', index=False)
                    logging.info("所有 SHAP 特征已保存至 'SHAP_Feature_Top10_Table.tsv'")

                if metrics_list:
                    metrics_df = pd.DataFrame(metrics_list)
                    metrics_df.to_csv(f"{model_name}_Cell_Type_Metrics.tsv", sep='\t', index=False)
                    logging.info("所有细胞类型的 ROC、准确率及癌种分布已保存至 '.tsv'")
        model_reult[types][model_name] = (acc, auc_val, fpr, tpr, model, X_test, adata_test, le,metrics_list,types_clean,dflist)
        with open(file,'wb')as f:
            pickle.dump(model_reult,f)

2025-09-13 10:46:22,932 - INFO - adata.X 数据类型: float32, 最小值: 0.0, 最大值: 27563.0
2025-09-13 10:46:22,964 - INFO - Cell_Type 分布：
Cell_Type
CD8              44949
CD4              24953
NK               11119
B                 5425
Mono              4820
Macro             1817
gdT               1525
Plasma            1368
cDC2               991
ILC2               269
pDC                268
cDC1               192
Mast               178
NKT-like           157
ILC3               113
Megakaryocyte       99
LAMP3+ DC           27
Name: count, dtype: int64
2025-09-13 10:46:23,004 - INFO - response 分布：
Cell_Type      response     
B              non-responder     1949
               responder         3476
CD4            non-responder    10904
               responder        14049
CD8            non-responder    25828
               responder        19121
ILC2           non-responder       30
               responder          239
ILC3           non-responder       18
               responder      

['CD8', 'cDC2', 'Mono', 'cDC1', 'Macro', 'Plasma', 'gdT', 'B', 'NK', 'CD4', 'pDC']
valid_cell_types: ['all', 'CD8', 'cDC2', 'Mono', 'cDC1', 'Macro', 'Plasma', 'gdT', 'B', 'NK', 'CD4', 'pDC']


处理细胞类型:   0%|          | 0/12 [00:00<?, ?it/s]2025-09-13 10:46:24,786 - INFO - 细胞类型 all 的样本量: 97427
2025-09-13 10:46:25,414 - INFO - 细胞类型 all 标签 responder 癌种 HCC 采样 48669 个细胞
2025-09-13 10:46:26,602 - INFO - 细胞类型 all 标签 non-responder 癌种 HCC 采样 48758 个细胞
2025-09-13 10:46:27,733 - INFO - 细胞类型 all 采样完成，样本量: 97427
2025-09-13 10:46:27,739 - INFO - 细胞类型 all 采样后的癌种分布: {'HCC': 97427}


Random Forest
Logistic Regression
LightGBM
CatBoost
XGBoost


2025-09-13 10:47:14,268 - INFO - Default AUROC (CV): 0.9449


Default AUROC (CV): 0.9449


2025-09-13 10:47:58,468 - INFO - Best Params: {'n_estimators': 100, 'max_depth': 5}
2025-09-13 10:47:58,472 - INFO - Best AUROC (CV): 0.9390


Best AUROC (CV): 0.9390


2025-09-13 10:48:02,422 - INFO - 训练集准确率: 0.88, 训练集 AUC: 0.96
2025-09-13 10:48:03,140 - INFO - 最优阈值: 0.504, 优化后的准确率: 0.86
2025-09-13 10:48:03,161 - INFO - 混淆矩阵:
[[8549 1203]
 [1500 8234]]
2025-09-13 10:48:03,163 - INFO - 分类报告:
               precision    recall  f1-score   support

non-responder       0.85      0.88      0.86      9752
    responder       0.87      0.85      0.86      9734

     accuracy                           0.86     19486
    macro avg       0.86      0.86      0.86     19486
 weighted avg       0.86      0.86      0.86     19486

2025-09-13 10:48:03,185 - INFO - 模型已保存至: XGBoost_model_all.pkl
2025-09-13 10:48:03,212 - INFO - LabelEncoder 已保存至: XGBoost_label_encoder_all.pkl
2025-09-13 10:48:05,606 - INFO - all — 准确率: 0.86, AUC: 0.94, 采样后癌种分布: {'HCC': 97427}


XGBoots_shap_all.dot.pdf


处理细胞类型:   8%|▊         | 1/12 [02:20<25:43, 140.35s/it]

Stacking
tabpfn


2025-09-13 10:48:44,314 - INFO - 细胞类型 CD8 的样本量: 44949
2025-09-13 10:48:44,559 - INFO - 细胞类型 CD8 标签 responder 癌种 HCC 采样 19121 个细胞
2025-09-13 10:48:45,115 - INFO - 细胞类型 CD8 标签 non-responder 癌种 HCC 采样 25828 个细胞
2025-09-13 10:48:45,688 - INFO - 细胞类型 CD8 采样完成，样本量: 44949
2025-09-13 10:48:45,697 - INFO - 细胞类型 CD8 采样后的癌种分布: {'HCC': 44949}


Random Forest
Logistic Regression
LightGBM
CatBoost
XGBoost


2025-09-13 10:49:09,943 - INFO - Default AUROC (CV): 0.9901


Default AUROC (CV): 0.9901


2025-09-13 10:49:40,301 - INFO - Best Params: {'n_estimators': 100, 'max_depth': 5}
2025-09-13 10:49:40,304 - INFO - Best AUROC (CV): 0.9887


Best AUROC (CV): 0.9887


2025-09-13 10:49:41,961 - INFO - 训练集准确率: 0.98, 训练集 AUC: 1.00
2025-09-13 10:49:42,370 - INFO - 最优阈值: 0.399, 优化后的准确率: 0.95
2025-09-13 10:49:42,403 - INFO - 混淆矩阵:
[[4953  213]
 [ 279 3545]]
2025-09-13 10:49:42,405 - INFO - 分类报告:
               precision    recall  f1-score   support

non-responder       0.95      0.96      0.95      5166
    responder       0.94      0.93      0.94      3824

     accuracy                           0.95      8990
    macro avg       0.94      0.94      0.94      8990
 weighted avg       0.95      0.95      0.95      8990

2025-09-13 10:49:42,443 - INFO - 模型已保存至: XGBoost_model_CD8.pkl
2025-09-13 10:49:42,450 - INFO - LabelEncoder 已保存至: XGBoost_label_encoder_CD8.pkl
2025-09-13 10:49:44,206 - INFO - CD8 — 准确率: 0.95, AUC: 0.99, 采样后癌种分布: {'HCC': 44949}


XGBoots_shap_CD8.dot.pdf


处理细胞类型:  17%|█▋        | 2/12 [03:48<18:17, 109.71s/it]2025-09-13 10:50:12,106 - INFO - 细胞类型 cDC2 的样本量: 991
2025-09-13 10:50:12,128 - INFO - 细胞类型 cDC2 标签 responder 癌种 HCC 采样 431 个细胞
2025-09-13 10:50:12,170 - INFO - 细胞类型 cDC2 标签 non-responder 癌种 HCC 采样 560 个细胞


Stacking
tabpfn


2025-09-13 10:50:12,231 - INFO - 细胞类型 cDC2 采样完成，样本量: 991
2025-09-13 10:50:12,233 - INFO - 细胞类型 cDC2 采样后的癌种分布: {'HCC': 991}


Random Forest
Logistic Regression
LightGBM
CatBoost
XGBoost


2025-09-13 10:50:15,832 - INFO - Default AUROC (CV): 0.9216


Default AUROC (CV): 0.9216


2025-09-13 10:50:25,638 - INFO - Best Params: {'n_estimators': 50, 'max_depth': 5}
2025-09-13 10:50:25,640 - INFO - Best AUROC (CV): 0.9220
2025-09-13 10:50:25,798 - INFO - 训练集准确率: 1.00, 训练集 AUC: 1.00


Best AUROC (CV): 0.9220


2025-09-13 10:50:25,907 - INFO - 最优阈值: 0.378, 优化后的准确率: 0.85
2025-09-13 10:50:25,923 - INFO - 混淆矩阵:
[[101  11]
 [ 22  65]]
2025-09-13 10:50:25,924 - INFO - 分类报告:
               precision    recall  f1-score   support

non-responder       0.82      0.90      0.86       112
    responder       0.86      0.75      0.80        87

     accuracy                           0.83       199
    macro avg       0.84      0.82      0.83       199
 weighted avg       0.84      0.83      0.83       199

2025-09-13 10:50:25,936 - INFO - 模型已保存至: XGBoost_model_cDC2.pkl
2025-09-13 10:50:25,939 - INFO - LabelEncoder 已保存至: XGBoost_label_encoder_cDC2.pkl
2025-09-13 10:50:25,984 - INFO - cDC2 — 准确率: 0.83, AUC: 0.91, 采样后癌种分布: {'HCC': 991}


XGBoots_shap_cDC2.dot.pdf


处理细胞类型:  25%|██▌       | 3/12 [04:26<11:32, 76.93s/it] 2025-09-13 10:50:50,018 - INFO - 细胞类型 Mono 的样本量: 4820
2025-09-13 10:50:50,057 - INFO - 细胞类型 Mono 标签 responder 癌种 HCC 采样 2496 个细胞
2025-09-13 10:50:50,137 - INFO - 细胞类型 Mono 标签 non-responder 癌种 HCC 采样 2324 个细胞


Stacking
tabpfn


2025-09-13 10:50:50,216 - INFO - 细胞类型 Mono 采样完成，样本量: 4820
2025-09-13 10:50:50,219 - INFO - 细胞类型 Mono 采样后的癌种分布: {'HCC': 4820}


Random Forest
Logistic Regression
LightGBM
CatBoost
XGBoost


2025-09-13 10:50:54,397 - INFO - Default AUROC (CV): 0.8989


Default AUROC (CV): 0.8989


2025-09-13 10:51:02,020 - INFO - Best Params: {'n_estimators': 100, 'max_depth': 5}
2025-09-13 10:51:02,024 - INFO - Best AUROC (CV): 0.8937


Best AUROC (CV): 0.8937


2025-09-13 10:51:02,264 - INFO - 训练集准确率: 0.97, 训练集 AUC: 0.99
2025-09-13 10:51:02,330 - INFO - 最优阈值: 0.435, 优化后的准确率: 0.81
2025-09-13 10:51:02,345 - INFO - 混淆矩阵:
[[364 101]
 [ 89 410]]
2025-09-13 10:51:02,346 - INFO - 分类报告:
               precision    recall  f1-score   support

non-responder       0.80      0.78      0.79       465
    responder       0.80      0.82      0.81       499

     accuracy                           0.80       964
    macro avg       0.80      0.80      0.80       964
 weighted avg       0.80      0.80      0.80       964

2025-09-13 10:51:02,363 - INFO - 模型已保存至: XGBoost_model_Mono.pkl
2025-09-13 10:51:02,368 - INFO - LabelEncoder 已保存至: XGBoost_label_encoder_Mono.pkl
2025-09-13 10:51:02,544 - INFO - Mono — 准确率: 0.80, AUC: 0.89, 采样后癌种分布: {'HCC': 4820}


XGBoots_shap_Mono.dot.pdf


处理细胞类型:  33%|███▎      | 4/12 [05:03<08:08, 61.06s/it]2025-09-13 10:51:26,693 - INFO - 细胞类型 cDC1 的样本量: 192
2025-09-13 10:51:26,708 - INFO - 细胞类型 cDC1 标签 responder 癌种 HCC 采样 89 个细胞
2025-09-13 10:51:26,736 - INFO - 细胞类型 cDC1 标签 non-responder 癌种 HCC 采样 103 个细胞
2025-09-13 10:51:26,776 - INFO - 细胞类型 cDC1 采样完成，样本量: 192
2025-09-13 10:51:26,778 - INFO - 细胞类型 cDC1 采样后的癌种分布: {'HCC': 192}


Stacking
tabpfn
Random Forest
Logistic Regression
LightGBM
CatBoost
XGBoost


2025-09-13 10:51:28,929 - INFO - Default AUROC (CV): 0.8807


Default AUROC (CV): 0.8807


2025-09-13 10:51:35,272 - INFO - Best Params: {'n_estimators': 50, 'max_depth': 3}
2025-09-13 10:51:35,278 - INFO - Best AUROC (CV): 0.9075
2025-09-13 10:51:35,340 - INFO - 训练集准确率: 1.00, 训练集 AUC: 1.00
2025-09-13 10:51:35,407 - INFO - 最优阈值: 0.808, 优化后的准确率: 0.92
2025-09-13 10:51:35,425 - INFO - 混淆矩阵:
[[19  2]
 [ 3 15]]
2025-09-13 10:51:35,426 - INFO - 分类报告:
               precision    recall  f1-score   support

non-responder       0.86      0.90      0.88        21
    responder       0.88      0.83      0.86        18

     accuracy                           0.87        39
    macro avg       0.87      0.87      0.87        39
 weighted avg       0.87      0.87      0.87        39

2025-09-13 10:51:35,450 - INFO - 模型已保存至: XGBoost_model_cDC1.pkl
2025-09-13 10:51:35,458 - INFO - LabelEncoder 已保存至: XGBoost_label_encoder_cDC1.pkl
2025-09-13 10:51:35,460 - INFO - cDC1 — 准确率: 0.87, AUC: 0.95, 采样后癌种分布: {'HCC': 192}


Best AUROC (CV): 0.9075
XGBoots_shap_cDC1.dot.pdf


处理细胞类型:  42%|████▏     | 5/12 [05:35<05:54, 50.61s/it]2025-09-13 10:51:58,815 - INFO - 细胞类型 Macro 的样本量: 1817
2025-09-13 10:51:58,846 - INFO - 细胞类型 Macro 标签 responder 癌种 HCC 采样 1244 个细胞
2025-09-13 10:51:58,887 - INFO - 细胞类型 Macro 标签 non-responder 癌种 HCC 采样 573 个细胞
2025-09-13 10:51:58,929 - INFO - 细胞类型 Macro 采样完成，样本量: 1817
2025-09-13 10:51:58,935 - INFO - 细胞类型 Macro 采样后的癌种分布: {'HCC': 1817}


Stacking
tabpfn
Random Forest
Logistic Regression
LightGBM
CatBoost
XGBoost


2025-09-13 10:52:02,584 - INFO - Default AUROC (CV): 0.9324


Default AUROC (CV): 0.9324


2025-09-13 10:52:12,582 - INFO - Best Params: {'n_estimators': 100, 'max_depth': 3}
2025-09-13 10:52:12,585 - INFO - Best AUROC (CV): 0.9349
2025-09-13 10:52:12,689 - INFO - 训练集准确率: 0.97, 训练集 AUC: 0.99
2025-09-13 10:52:12,775 - INFO - 最优阈值: 0.576, 优化后的准确率: 0.88


Best AUROC (CV): 0.9349


2025-09-13 10:52:12,802 - INFO - 混淆矩阵:
[[ 98  17]
 [ 23 226]]
2025-09-13 10:52:12,804 - INFO - 分类报告:
               precision    recall  f1-score   support

non-responder       0.81      0.85      0.83       115
    responder       0.93      0.91      0.92       249

     accuracy                           0.89       364
    macro avg       0.87      0.88      0.87       364
 weighted avg       0.89      0.89      0.89       364

2025-09-13 10:52:12,823 - INFO - 模型已保存至: XGBoost_model_Macro.pkl
2025-09-13 10:52:12,827 - INFO - LabelEncoder 已保存至: XGBoost_label_encoder_Macro.pkl
2025-09-13 10:52:12,919 - INFO - Macro — 准确率: 0.89, AUC: 0.96, 采样后癌种分布: {'HCC': 1817}


XGBoots_shap_Macro.dot.pdf


处理细胞类型:  50%|█████     | 6/12 [06:14<04:41, 46.85s/it]2025-09-13 10:52:38,379 - INFO - 细胞类型 Plasma 的样本量: 1368
2025-09-13 10:52:38,420 - INFO - 细胞类型 Plasma 标签 responder 癌种 HCC 采样 1280 个细胞
2025-09-13 10:52:38,471 - INFO - 细胞类型 Plasma 标签 non-responder 癌种 HCC 采样 88 个细胞


Stacking
tabpfn


2025-09-13 10:52:38,510 - INFO - 细胞类型 Plasma 采样完成，样本量: 1368
2025-09-13 10:52:38,511 - INFO - 细胞类型 Plasma 采样后的癌种分布: {'HCC': 1368}


Random Forest
Logistic Regression
LightGBM
CatBoost
XGBoost


2025-09-13 10:52:41,670 - INFO - Default AUROC (CV): 0.9718


Default AUROC (CV): 0.9718


2025-09-13 10:52:50,205 - INFO - Best Params: {'n_estimators': 50, 'max_depth': 5}
2025-09-13 10:52:50,210 - INFO - Best AUROC (CV): 0.9729
2025-09-13 10:52:50,299 - INFO - 训练集准确率: 1.00, 训练集 AUC: 1.00
2025-09-13 10:52:50,377 - INFO - 最优阈值: 0.996, 优化后的准确率: 0.87
2025-09-13 10:52:50,402 - INFO - 混淆矩阵:
[[  9   9]
 [  7 249]]
2025-09-13 10:52:50,404 - INFO - 分类报告:
               precision    recall  f1-score   support

non-responder       0.56      0.50      0.53        18
    responder       0.97      0.97      0.97       256

     accuracy                           0.94       274
    macro avg       0.76      0.74      0.75       274
 weighted avg       0.94      0.94      0.94       274



Best AUROC (CV): 0.9729


2025-09-13 10:52:50,425 - INFO - 模型已保存至: XGBoost_model_Plasma.pkl
2025-09-13 10:52:50,430 - INFO - LabelEncoder 已保存至: XGBoost_label_encoder_Plasma.pkl
2025-09-13 10:52:50,509 - INFO - Plasma — 准确率: 0.94, AUC: 0.96, 采样后癌种分布: {'HCC': 1368}


XGBoots_shap_Plasma.dot.pdf


处理细胞类型:  58%|█████▊    | 7/12 [06:54<03:42, 44.51s/it]2025-09-13 10:53:18,026 - INFO - 细胞类型 gdT 的样本量: 1525
2025-09-13 10:53:18,059 - INFO - 细胞类型 gdT 标签 responder 癌种 HCC 采样 1244 个细胞
2025-09-13 10:53:18,109 - INFO - 细胞类型 gdT 标签 non-responder 癌种 HCC 采样 281 个细胞
2025-09-13 10:53:18,152 - INFO - 细胞类型 gdT 采样完成，样本量: 1525
2025-09-13 10:53:18,154 - INFO - 细胞类型 gdT 采样后的癌种分布: {'HCC': 1525}


Stacking
tabpfn
Random Forest
Logistic Regression
LightGBM
CatBoost
XGBoost


2025-09-13 10:53:26,212 - INFO - Default AUROC (CV): 0.9383


Default AUROC (CV): 0.9383


2025-09-13 10:53:42,327 - INFO - Best Params: {'n_estimators': 100, 'max_depth': 3}
2025-09-13 10:53:42,333 - INFO - Best AUROC (CV): 0.9426
2025-09-13 10:53:42,410 - INFO - 训练集准确率: 1.00, 训练集 AUC: 1.00
2025-09-13 10:53:42,510 - INFO - 最优阈值: 0.842, 优化后的准确率: 0.93


Best AUROC (CV): 0.9426


2025-09-13 10:53:42,535 - INFO - 混淆矩阵:
[[ 43  13]
 [  4 245]]
2025-09-13 10:53:42,536 - INFO - 分类报告:
               precision    recall  f1-score   support

non-responder       0.91      0.77      0.83        56
    responder       0.95      0.98      0.97       249

     accuracy                           0.94       305
    macro avg       0.93      0.88      0.90       305
 weighted avg       0.94      0.94      0.94       305

2025-09-13 10:53:42,551 - INFO - 模型已保存至: XGBoost_model_gdT.pkl
2025-09-13 10:53:42,555 - INFO - LabelEncoder 已保存至: XGBoost_label_encoder_gdT.pkl
2025-09-13 10:53:42,632 - INFO - gdT — 准确率: 0.94, AUC: 0.98, 采样后癌种分布: {'HCC': 1525}


XGBoots_shap_gdT.dot.pdf


处理细胞类型:  67%|██████▋   | 8/12 [07:45<03:06, 46.55s/it]2025-09-13 10:54:09,027 - INFO - 细胞类型 B 的样本量: 5425
2025-09-13 10:54:09,088 - INFO - 细胞类型 B 标签 responder 癌种 HCC 采样 3476 个细胞


Stacking
tabpfn


2025-09-13 10:54:09,194 - INFO - 细胞类型 B 标签 non-responder 癌种 HCC 采样 1949 个细胞
2025-09-13 10:54:09,300 - INFO - 细胞类型 B 采样完成，样本量: 5425
2025-09-13 10:54:09,303 - INFO - 细胞类型 B 采样后的癌种分布: {'HCC': 5425}


Random Forest
Logistic Regression
LightGBM
CatBoost
XGBoost


2025-09-13 10:54:19,070 - INFO - Default AUROC (CV): 0.9400


Default AUROC (CV): 0.9400


2025-09-13 10:54:38,371 - INFO - Best Params: {'n_estimators': 100, 'max_depth': 5}
2025-09-13 10:54:38,377 - INFO - Best AUROC (CV): 0.9392


Best AUROC (CV): 0.9392


2025-09-13 10:54:38,761 - INFO - 训练集准确率: 0.97, 训练集 AUC: 1.00
2025-09-13 10:54:38,884 - INFO - 最优阈值: 0.693, 优化后的准确率: 0.85
2025-09-13 10:54:38,905 - INFO - 混淆矩阵:
[[310  80]
 [ 79 616]]
2025-09-13 10:54:38,906 - INFO - 分类报告:
               precision    recall  f1-score   support

non-responder       0.80      0.79      0.80       390
    responder       0.89      0.89      0.89       695

     accuracy                           0.85      1085
    macro avg       0.84      0.84      0.84      1085
 weighted avg       0.85      0.85      0.85      1085

2025-09-13 10:54:38,930 - INFO - 模型已保存至: XGBoost_model_B.pkl
2025-09-13 10:54:38,936 - INFO - LabelEncoder 已保存至: XGBoost_label_encoder_B.pkl
2025-09-13 10:54:39,120 - INFO - B — 准确率: 0.85, AUC: 0.93, 采样后癌种分布: {'HCC': 5425}


XGBoots_shap_B.dot.pdf


处理细胞类型:  75%|███████▌  | 9/12 [08:46<02:33, 51.15s/it]

Stacking
tabpfn


2025-09-13 10:55:10,506 - INFO - 细胞类型 NK 的样本量: 11119
2025-09-13 10:55:10,650 - INFO - 细胞类型 NK 标签 responder 癌种 HCC 采样 5077 个细胞
2025-09-13 10:55:11,045 - INFO - 细胞类型 NK 标签 non-responder 癌种 HCC 采样 6042 个细胞
2025-09-13 10:55:11,316 - INFO - 细胞类型 NK 采样完成，样本量: 11119
2025-09-13 10:55:11,322 - INFO - 细胞类型 NK 采样后的癌种分布: {'HCC': 11119}


Random Forest
Logistic Regression
LightGBM
CatBoost
XGBoost


2025-09-13 10:55:23,507 - INFO - Default AUROC (CV): 0.9691


Default AUROC (CV): 0.9691


2025-09-13 10:55:39,066 - INFO - Best Params: {'n_estimators': 100, 'max_depth': 5}
2025-09-13 10:55:39,067 - INFO - Best AUROC (CV): 0.9678
2025-09-13 10:55:39,172 - INFO - 训练集准确率: 0.99, 训练集 AUC: 1.00


Best AUROC (CV): 0.9678


2025-09-13 10:55:39,278 - INFO - 最优阈值: 0.553, 优化后的准确率: 0.92
2025-09-13 10:55:39,291 - INFO - 混淆矩阵:
[[1101  108]
 [  93  922]]
2025-09-13 10:55:39,291 - INFO - 分类报告:
               precision    recall  f1-score   support

non-responder       0.92      0.91      0.92      1209
    responder       0.90      0.91      0.90      1015

     accuracy                           0.91      2224
    macro avg       0.91      0.91      0.91      2224
 weighted avg       0.91      0.91      0.91      2224

2025-09-13 10:55:39,318 - INFO - 模型已保存至: XGBoost_model_NK.pkl
2025-09-13 10:55:39,325 - INFO - LabelEncoder 已保存至: XGBoost_label_encoder_NK.pkl
2025-09-13 10:55:39,622 - INFO - NK — 准确率: 0.91, AUC: 0.97, 采样后癌种分布: {'HCC': 11119}


XGBoots_shap_NK.dot.pdf


处理细胞类型:  83%|████████▎ | 10/12 [09:44<01:46, 53.27s/it]

Stacking
tabpfn


2025-09-13 10:56:08,557 - INFO - 细胞类型 CD4 的样本量: 24953
2025-09-13 10:56:08,769 - INFO - 细胞类型 CD4 标签 responder 癌种 HCC 采样 14049 个细胞
2025-09-13 10:56:09,117 - INFO - 细胞类型 CD4 标签 non-responder 癌种 HCC 采样 10904 个细胞
2025-09-13 10:56:09,422 - INFO - 细胞类型 CD4 采样完成，样本量: 24953
2025-09-13 10:56:09,426 - INFO - 细胞类型 CD4 采样后的癌种分布: {'HCC': 24953}


Random Forest
Logistic Regression
LightGBM
CatBoost
XGBoost


2025-09-13 10:56:27,694 - INFO - Default AUROC (CV): 0.9185


Default AUROC (CV): 0.9185


2025-09-13 10:56:50,888 - INFO - Best Params: {'n_estimators': 100, 'max_depth': 5}
2025-09-13 10:56:50,890 - INFO - Best AUROC (CV): 0.9184


Best AUROC (CV): 0.9184


2025-09-13 10:56:51,680 - INFO - 训练集准确率: 0.89, 训练集 AUC: 0.96
2025-09-13 10:56:51,942 - INFO - 最优阈值: 0.534, 优化后的准确率: 0.83
2025-09-13 10:56:51,958 - INFO - 混淆矩阵:
[[1779  402]
 [ 433 2377]]
2025-09-13 10:56:51,959 - INFO - 分类报告:
               precision    recall  f1-score   support

non-responder       0.80      0.82      0.81      2181
    responder       0.86      0.85      0.85      2810

     accuracy                           0.83      4991
    macro avg       0.83      0.83      0.83      4991
 weighted avg       0.83      0.83      0.83      4991

2025-09-13 10:56:51,977 - INFO - 模型已保存至: XGBoost_model_CD4.pkl
2025-09-13 10:56:51,981 - INFO - LabelEncoder 已保存至: XGBoost_label_encoder_CD4.pkl
2025-09-13 10:56:52,619 - INFO - CD4 — 准确率: 0.83, AUC: 0.92, 采样后癌种分布: {'HCC': 24953}
处理细胞类型:  83%|████████▎ | 10/12 [10:57<02:11, 65.78s/it]


KeyboardInterrupt: 

In [None]:
valid_cell_types_ = ['all', 'CD4', 'Plasma', 'Macro', 'B', 'CD8', 'pDC', 'Mono', 'NK', 'cDC1', 'cDC2', 'gdT'] 
for types in tqdm(valid_cell_types_, desc="处理细胞类型"):
    for model_name in model_param_grid:
        file =f'model_reult_{types}_{model_name}.pickle'
        if not os.path.exists(file):
            
            model_reult[types][model_name] = (acc, auc_val, fpr, tpr, model, X_test, adata_test, le,metrics_list,types_clean,dflist)
            with open(file,'wb')as f:
                pickle.dump(model_reult,f)
        else:
            model_reult[types][model_name] = pickle.load(open(file,'rb'))

In [None]:

model_reult[types]

In [None]:
types = 'all'

In [None]:
model_reult.keys()

In [17]:
test = pickle.load(open('model_reult_all_XGBoost.pickle','rb'))
test.keys()

dict_keys(['all'])

In [None]:
test['all'].keys()

In [None]:
test['all']['CatBoost']

In [18]:
model_reult['all'] = {}
model_reult['all']['Random_Forest'] = test['all']['Random_Forest']['all']['Random_Forest']
model_reult['all']['Logistic Regression'] = test['all']['Logistic Regression']['all']['Logistic Regression']
model_reult['all']['LightGBM'] = test['all']['LightGBM']['all']['LightGBM']
model_reult['all']['CatBoost'] = test['all']['CatBoost']
model_reult['all']['XGBoost'] = test['all']['XGBoost']

In [None]:
file ='model_reult.pickle'
if not os.path.exists(file):
    with open(file,'wb')as f:
        pickle.dump(model_reult,f)
else:
    model_reult = pickle.load(open(file,'rb'))

In [19]:
file ='model_reult.pickle'
with open(file,'wb')as f:
    pickle.dump(model_reult,f)

In [None]:
model_reult['']

In [None]:
# for types in tqdm(valid_cell_types, desc="处理细胞类型"):
error_dic= {}
for model_name in model_param_grid:
    for types in tqdm(model_reult, desc="处理细胞类型"):
    # model_reult[types] = {}
    # subset, labels, cancer_dist = SamplePick(adata_, types, cluster_col, 'response')
    # if subset is None:
    #     continue
    # for model_name in model_param_grid:
    # for model_name in model_reult[types]:
        print(types,model_name)
        # dflist = []
        # metrics_list = []
        # acc, auc_val, fpr, tpr, model, X_test, adata_test, le = TumorClassify(subset, labels,model_name)
        # types_clean = types.replace(' ', '_')
        # plot_roc(fpr, tpr, auc_val, types_clean,model_name)
        # logging.info(f"{types_clean} — 准确率: {acc:.2f}, AUC: {auc_val:.2f}, 采样后癌种分布: {cancer_dist}")

        # metrics_list.append({
        #     'Cell_Type': types_clean,
        #     'Accuracy': acc,
        #     'AUC': auc_val,
        #     'Cancer_Distribution': str(cancer_dist)  # 将癌种分布存为字符串
        # })
        if os.path.exists(f"{model_name}_{types}_Metrics.tsv"):
            continue
        acc, auc_val, fpr, tpr, model, X_test, adata_test, le,metrics_list,types_clean,dflistt =model_reult[types][model_name]
        plot_roc(fpr, tpr, auc_val, types_clean,model_name)
        print(X_test.shape)
        try:
            if isinstance(model_param_grid[model_name][0], (XGBClassifier, LGBMClassifier, CatBoostClassifier, RandomForestClassifier)):
                feature_df = shap_analysis(model, X_test, adata_test, types_clean,model_name)
                if len(feature_df):
                    feature_df['Cell_Type'] = types_clean
                    dflist.append(feature_df)

                    if dflist:
                        df_all = pd.concat(dflist)
                        df_all.to_csv(f"{model_name}_{types}_SHAP_Feature_Top10_Table.tsv", sep='\t', index=False)
                        logging.info(f"所有 SHAP 特征已保存至 '{model_name}_{types}_SHAP_Feature_Top10_Table.tsv'")

                    if metrics_list:
                        metrics_df = pd.DataFrame(metrics_list)
                        metrics_df.to_csv(f"{model_name}_{types}_Metrics.tsv", sep='\t', index=False)
                        logging.info(f"所有细胞类型的 ROC、准确率及癌种分布已保存至 '{model_name}_{types}_Metrics.tsv'")
        except Exception as e:
            error_dic[f'{model_name}_{types}'] = str(e)

处理细胞类型:   0%|          | 0/12 [00:00<?, ?it/s]

Mono Random_Forest
Plasma Random_Forest
B Random_Forest
gdT Random_Forest
cDC2 Random_Forest
pDC Random_Forest
CD8 Random_Forest
NK Random_Forest
cDC1 Random_Forest
CD4 Random_Forest
Macro Random_Forest
all Random_Forest
(19486, 1000)
Random_Forest_shap_all.dot.pdf


2025-06-21 07:28:49,269 - INFO - 所有 SHAP 特征已保存至 'Random_Forest_all_SHAP_Feature_Top10_Table.tsv'
2025-06-21 07:28:49,274 - INFO - 所有细胞类型的 ROC、准确率及癌种分布已保存至 'Random_Forest_all_Metrics.tsv'
处理细胞类型: 100%|██████████| 12/12 [21:57:05<00:00, 6585.44s/it]
处理细胞类型: 100%|██████████| 12/12 [00:00<00:00, 1141.54it/s]


Mono Logistic Regression
(964, 1000)
Plasma Logistic Regression
(274, 1000)
B Logistic Regression
(1085, 1000)
gdT Logistic Regression
(305, 1000)
cDC2 Logistic Regression
(199, 1034)
pDC Logistic Regression
(54, 1000)
CD8 Logistic Regression
(8990, 1000)
NK Logistic Regression
(2224, 1000)
cDC1 Logistic Regression
(39, 1018)
CD4 Logistic Regression
(4991, 1011)
Macro Logistic Regression
(364, 1002)
all Logistic Regression
(19486, 1000)


处理细胞类型:   0%|          | 0/12 [00:00<?, ?it/s]

Mono LightGBM
Plasma LightGBM
B LightGBM
gdT LightGBM
cDC2 LightGBM
pDC LightGBM
CD8 LightGBM
NK LightGBM
cDC1 LightGBM
CD4 LightGBM
Macro LightGBM
all LightGBM
(19486, 1000)
LightGBM_shap_all.dot.pdf


2025-06-21 07:29:55,536 - INFO - 所有 SHAP 特征已保存至 'LightGBM_all_SHAP_Feature_Top10_Table.tsv'
2025-06-21 07:29:55,540 - INFO - 所有细胞类型的 ROC、准确率及癌种分布已保存至 'LightGBM_all_Metrics.tsv'
处理细胞类型: 100%|██████████| 12/12 [01:06<00:00,  5.52s/it]
处理细胞类型:   0%|          | 0/12 [00:00<?, ?it/s]

Mono CatBoost
Plasma CatBoost
B CatBoost
gdT CatBoost
cDC2 CatBoost
pDC CatBoost
CD8 CatBoost
NK CatBoost
cDC1 CatBoost
CD4 CatBoost
Macro CatBoost
all CatBoost
(19486, 1000)
CatBoost_shap_all.dot.pdf


2025-06-21 07:30:11,153 - INFO - 所有 SHAP 特征已保存至 'CatBoost_all_SHAP_Feature_Top10_Table.tsv'
2025-06-21 07:30:11,158 - INFO - 所有细胞类型的 ROC、准确率及癌种分布已保存至 'CatBoost_all_Metrics.tsv'
处理细胞类型: 100%|██████████| 12/12 [00:15<00:00,  1.30s/it]
处理细胞类型:   0%|          | 0/12 [00:00<?, ?it/s]

Mono XGBoost
Plasma XGBoost
B XGBoost
gdT XGBoost
cDC2 XGBoost
pDC XGBoost
CD8 XGBoost
NK XGBoost
cDC1 XGBoost
CD4 XGBoost
Macro XGBoost
all XGBoost
(19486, 1000)
XGBoost_shap_all.dot.pdf


2025-06-21 07:30:42,215 - INFO - 所有 SHAP 特征已保存至 'XGBoost_all_SHAP_Feature_Top10_Table.tsv'
2025-06-21 07:30:42,220 - INFO - 所有细胞类型的 ROC、准确率及癌种分布已保存至 'XGBoost_all_Metrics.tsv'
处理细胞类型: 100%|██████████| 12/12 [00:31<00:00,  2.59s/it]


In [None]:
if not feature_df is None :
    print('ture')

In [None]:
feature_df = shap_analysis(model, X_test, adata_test, types_clean)

In [None]:
shap_analysis(model, X_test, adata_test, types_clean)

In [None]:
label = np.array(labels)
train_idx, test_idx = train_test_split(np.arange(len(label)), test_size=0.2, 
                                        random_state=RANDOM_SEED, stratify=label)
le = LabelEncoder()
y_train = le.fit_transform(label[train_idx])
y_test = le.transform(label[test_idx])

In [None]:
model

In [None]:
feature_names = adata_test.var_names.tolist()
if hasattr(model, "predict_proba") and "tree_" in str(type(model)).lower():
    explainer = shap.TreeExplainer(model, feature_names=feature_names)
elif "logisticregression" in str(type(model)).lower():
    explainer = shap.LinearExplainer(model, X_test, feature_perturbation="interventional")
else:
    print(f"[警告] {str(model).split('(')[0]} 模型类型暂不支持 SHAP")
    # return None

In [None]:
# def shap_analysis(model, X_test, adata_test, types):
if 1 == 1 :
    feature_names = adata_test.var_names.tolist()
    if isinstance(model, (RandomForestClassifier, LGBMClassifier, XGBClassifier, CatBoostClassifier)):
        explainer = shap.TreeExplainer(model, feature_names=feature_names)
    elif isinstance(model, LogisticRegression):
        explainer = shap.LinearExplainer(model, X_test, feature_perturbation="interventional")
    else:
        print(f"[警告] {str(model).split('(')[0]} 模型类型暂不支持 SHAP")
        # return None
    # explainer = shap.TreeExplainer(model, model_output='probability')
    shap_values = explainer(X_test)
    shap_vals = shap_values.values
    if shap_vals.shape[1] == X_test.shape[1] and shap_vals.ndim == 3:  # 特征在中间（错误）
        shap_vals = np.transpose(shap_vals, (0, 2, 1))
    # if shap_vals.ndim == 3:
        shap_mean = np.abs(shap_vals).mean(axis=(0, 1))
    else:
        shap_mean = np.abs(shap_vals).mean(axis=0)
    
    # print("shap_values.values shape:", shap_values.values.shape)
      # mean over samples & classes → shape (1000,)
    ensembl_ids = adata_test.var.index.tolist()
    gene_names = [ENSEMBL_DICT.get(eid, eid) for eid in ensembl_ids]
    feature_importance = pd.DataFrame({
        'Feature': gene_names,
        'Mean SHAP Value': shap_mean
    }).sort_values(by='Mean SHAP Value', ascending=False).reset_index(drop=True)
    plt.figure()
    shap.summary_plot(shap_values, X_test, feature_names=gene_names, plot_type='bar', show=False)
    plt.title(f'SHAP Summary - {types}')
    plt.tight_layout()
    plt.savefig(f'shap_{types}.pdf')
    plt.close()
    # return feature_importance.head(10)

In [None]:
X_test.shape

In [None]:
acc, auc_val, fpr, tpr, model, X_test, adata_test, le,metrics_list,types_clean,dflist =model_reult['B']['Random_Forest']

In [None]:
np.max(X_test), np.min(X_test)

In [None]:
explainer(X_test, check_additivity=False)

In [None]:
for model_name in model_param_grid:
    for types in tqdm(model_reult, desc="处理细胞类型"):
        print(types,model_name)
        acc, auc_val, fpr, tpr, model, X_test, adata_test, le,metrics_list,types_clean,dflist =model_reult[types][model_name]
        print(X_test.shape)

In [None]:
print(shap_vals.ndim)
print(shap_vals.shape)
np.abs(shap_vals).mean(axis=0).shape

In [None]:
def shap_analysis(model, X_test, adata_test, types,model_name='XGBoots'):
    feature_names = adata_test.var_names.tolist()
    if isinstance(model, (RandomForestClassifier, LGBMClassifier, XGBClassifier, CatBoostClassifier)):
        explainer = shap.TreeExplainer(model, feature_names=feature_names)
    elif isinstance(model, LogisticRegression):
        explainer = shap.LinearExplainer(model, X_test, feature_perturbation="interventional")
    else:
        print(f"[警告] {str(model).split('(')[0]} 模型类型暂不支持 SHAP")
        # return None
    # explainer = shap.TreeExplainer(model, model_output='probability')
    shap_values = explainer(X_test, check_additivity=False)
    shap_vals = shap_values.values
    if shap_vals.shape[1] == X_test.shape[1] and shap_vals.ndim == 3:  # 特征在中间（错误）
        shap_vals = np.transpose(shap_vals, (0, 2, 1))
    # if shap_vals.ndim == 3:
        shap_mean = np.abs(shap_vals).mean(axis=(0, 1))
    else:
        shap_mean = np.abs(shap_vals).mean(axis=0)
    
    # print("shap_values.values shape:", shap_values.values.shape)
      # mean over samples & classes → shape (1000,)
    ensembl_ids = adata_test.var.index.tolist()
    gene_names = [ENSEMBL_DICT.get(eid, eid) for eid in ensembl_ids]
    feature_importance = pd.DataFrame({
        'Feature': gene_names,
        'Mean SHAP Value': shap_mean
    }).sort_values(by='Mean SHAP Value', ascending=False).reset_index(drop=True)
    plt.figure()
    shap.summary_plot(shap_values, X_test, feature_names=gene_names, plot_type='bar', show=False)
    plt.title(f'SHAP Summary - {types}')
    plt.tight_layout()
    plt.savefig(f'{model_name}_shap_{types}.pdf')
    plt.close()
    return feature_importance.head(10)
 

In [None]:
X_train.shape[1] == X_test.shape[1]

In [None]:
sc.pp.highly_variable_genes(adata_train, n_top_genes=HIGH_VAR_GENES, flavor='cell_ranger')

In [None]:
print(subset_.shape)
X = subset_.X.toarray() if sp.issparse(subset_.X) else subset_.X

has_nan = np.isnan(X).any()
has_posinf = np.isposinf(X).any()
has_neginf = np.isneginf(X).any()

print(f"Has NaN: {has_nan}")
print(f"Has +inf: {has_posinf}")
print(f"Has -inf: {has_neginf}")

In [None]:
TumorClassify(subset, labels)

In [None]:
import scipy.sparse as sp
label = np.array(labels)
train_idx, test_idx = train_test_split(np.arange(len(label)), test_size=0.2, 
                                        random_state=RANDOM_SEED, stratify=label)
le = LabelEncoder()
y_train = le.fit_transform(label[train_idx])
y_test = le.transform(label[test_idx])
X = subset.X.toarray() if sp.issparse(subset.X) else subset.X
gene_means = X.mean(axis=0)
valid_gene_mask = gene_means > 0.1
subset_ = subset[:, valid_gene_mask].copy()
print(subset_.shape)
X = subset_.X.toarray() if sp.issparse(subset_.X) else subset_.X

has_nan = np.isnan(X).any()
has_posinf = np.isposinf(X).any()
has_neginf = np.isneginf(X).any()

print(f"Has NaN: {has_nan}")
print(f"Has +inf: {has_posinf}")
print(f"Has -inf: {has_neginf}")
adata_train = subset_[train_idx].copy()
adata_test = subset_[test_idx].copy()

sc.pp.highly_variable_genes(adata_train, n_top_genes=HIGH_VAR_GENES, flavor='cell_ranger')
selected_genes = adata_train.var.highly_variable
adata_train = adata_train[:, selected_genes]
adata_test = adata_test[:, selected_genes]

X_train = adata_train.X.toarray() if scipy.sparse.issparse(adata_train.X) else adata_train.X
X_test = adata_test.X.toarray() if scipy.sparse.issparse(adata_test.X) else adata_test.X

imputer = SimpleImputer(strategy='mean')
X_train = imputer.fit_transform(X_train)
X_test = imputer.transform(X_test)

sc.pp.highly_variable_genes(adata_train, n_top_genes=HIGH_VAR_GENES, flavor='cell_ranger')
selected_genes = adata_train.var.highly_variable
adata_train = adata_train[:, selected_genes]
adata_test = adata_test[:, selected_genes]
sc.pp.highly_variable_genes(adata_train, n_top_genes=HIGH_VAR_GENES, flavor='cell_ranger')

In [None]:
gene_means = X.mean(axis=0)
genes = subset.var_names
genes,X.shape

In [None]:
X = adata_train.X.toarray() if sp.issparse(adata_train.X) else adata_train.X

# 计算每个基因的均值
gene_means = X.mean(axis=0)

# 只保留均值大于一个小阈值的基因（如1e-5）
valid_gene_mask = gene_means < 1e-5

# 筛选掉均值太低的基因
test = adata_train[:, valid_gene_mask].copy()
test.shape,adata_train.shape
# # 然后再计算 HVG，不需要 duplicates 参数
# sc.pp.highly_variable_genes(adata_train, n_top_genes=HIGH_VAR_GENES, flavor='cell_ranger')

In [None]:
X_all = subset_.X.toarray() if sp.issparse(subset_.X) else subset_.X
max_val = np.max(X_all)
print(f"Maximum expression value in entire dataset: {max_val}")

In [None]:
test_df = pd.DataFrame(subset_.X.toarray(),columns=subset_.var_names)
test_df

In [None]:

max(test_df.values[0])

In [None]:
count = 0
sample_list = []
for i in test_df.index:
    data = test_df.loc[i]
    if sum(data) ==0:
        sample_list.append(i)
len(sample_list),len(test_df)

In [None]:
import scipy.sparse as sp
X = subset.X.toarray() if sp.issparse(subset.X) else subset.X

has_nan = np.isnan(X).any()
has_posinf = np.isposinf(X).any()
has_neginf = np.isneginf(X).any()

print(f"Has NaN: {has_nan}")
print(f"Has +inf: {has_posinf}")
print(f"Has -inf: {has_neginf}")

In [None]:
acc, auc_val, fpr, tpr, model, X_test, adata_test, le = TumorClassify(subset, labels)