Prepoccess data

In [1]:
import numpy as np
import pandas as pd
import scanpy as sc
from sklearn.model_selection import train_test_split

# === Step 1: 载入 AnnData 和 barcode ===
adata = sc.read_10x_h5("data/Parent_NGSC3_DI_PBMC_filtered_feature_bc_matrix.h5")
barcodes = adata.obs_names.tolist()

# === Step 2: 载入 latent 向量 ===
# X = np.load("latent_db/latent_128.npy")  # shape: [n_cells, latent_dim]
X = adata.X.toarray()

# === Step 3: 载入 cluster label 并构建映射 dict ===
df_label = pd.read_csv("data/Parent_NGSC3_DI_PBMC_analysis/analysis/clustering/graphclust/clusters.csv")
barcode2cluster = dict(zip(df_label["Barcode"], df_label["Cluster"]))

# === Step 4: 构建标签向量 y，未匹配的设为 -1 ===
y = np.array([barcode2cluster.get(bc, -1) for bc in barcodes])  # -1 表示缺失标签
valid_idx = y != -1

# === Step 5: 仅保留有标签的样本 ===
X = X[valid_idx]
y = y[valid_idx]

# === Step 6: 划分训练、验证、测试集 ===
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.30, stratify=y, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.50, stratify=y_temp, random_state=42)

# === Step 7: 打印信息检查 ===
print("X_train:", X_train.shape, "y_train:", y_train.shape)
print("X_val:", X_val.shape, "y_val:", y_val.shape)
print("X_test:", X_test.shape, "y_test:", y_test.shape)

  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


X_train: (7135, 36601) y_train: (7135,)
X_val: (1529, 36601) y_val: (1529,)
X_test: (1530, 36601) y_test: (1530,)


In [8]:
np.unique(y_train)

array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])

In [9]:
def print_label_stats(y, name):
    unique, counts = np.unique(y, return_counts=True)
    total = len(y)
    print(f"\n=== {name.upper()} SET LABEL DISTRIBUTION ===")
    for label, count in zip(unique, counts):
        percent = 100 * count / total
        print(f"Label {label:2d} → Count: {count:4d}  |  {percent:5.2f}%")
    print(f"Total: {total} samples")

print_label_stats(y, "train")



=== TRAIN SET LABEL DISTRIBUTION ===
Label  1 → Count: 1970  |  19.33%
Label  2 → Count: 1540  |  15.11%
Label  3 → Count:  967  |   9.49%
Label  4 → Count:  878  |   8.61%
Label  5 → Count:  753  |   7.39%
Label  6 → Count:  630  |   6.18%
Label  7 → Count:  606  |   5.94%
Label  8 → Count:  534  |   5.24%
Label  9 → Count:  468  |   4.59%
Label 10 → Count:  384  |   3.77%
Label 11 → Count:  353  |   3.46%
Label 12 → Count:  328  |   3.22%
Label 13 → Count:  278  |   2.73%
Label 14 → Count:  259  |   2.54%
Label 15 → Count:  246  |   2.41%
Total: 10194 samples


In [10]:
print_label_stats(y_train, "train")


=== TRAIN SET LABEL DISTRIBUTION ===
Label  1 → Count: 1379  |  19.33%
Label  2 → Count: 1078  |  15.11%
Label  3 → Count:  677  |   9.49%
Label  4 → Count:  614  |   8.61%
Label  5 → Count:  527  |   7.39%
Label  6 → Count:  441  |   6.18%
Label  7 → Count:  424  |   5.94%
Label  8 → Count:  374  |   5.24%
Label  9 → Count:  327  |   4.58%
Label 10 → Count:  269  |   3.77%
Label 11 → Count:  247  |   3.46%
Label 12 → Count:  230  |   3.22%
Label 13 → Count:  195  |   2.73%
Label 14 → Count:  181  |   2.54%
Label 15 → Count:  172  |   2.41%
Total: 7135 samples


In [11]:
print_label_stats(y_val, "val")


=== VAL SET LABEL DISTRIBUTION ===
Label  1 → Count:  295  |  19.29%
Label  2 → Count:  231  |  15.11%
Label  3 → Count:  145  |   9.48%
Label  4 → Count:  132  |   8.63%
Label  5 → Count:  113  |   7.39%
Label  6 → Count:   94  |   6.15%
Label  7 → Count:   91  |   5.95%
Label  8 → Count:   80  |   5.23%
Label  9 → Count:   70  |   4.58%
Label 10 → Count:   58  |   3.79%
Label 11 → Count:   53  |   3.47%
Label 12 → Count:   49  |   3.20%
Label 13 → Count:   42  |   2.75%
Label 14 → Count:   39  |   2.55%
Label 15 → Count:   37  |   2.42%
Total: 1529 samples


In [12]:
print_label_stats(y_test, "test")


=== TEST SET LABEL DISTRIBUTION ===
Label  1 → Count:  296  |  19.35%
Label  2 → Count:  231  |  15.10%
Label  3 → Count:  145  |   9.48%
Label  4 → Count:  132  |   8.63%
Label  5 → Count:  113  |   7.39%
Label  6 → Count:   95  |   6.21%
Label  7 → Count:   91  |   5.95%
Label  8 → Count:   80  |   5.23%
Label  9 → Count:   71  |   4.64%
Label 10 → Count:   57  |   3.73%
Label 11 → Count:   53  |   3.46%
Label 12 → Count:   49  |   3.20%
Label 13 → Count:   41  |   2.68%
Label 14 → Count:   39  |   2.55%
Label 15 → Count:   37  |   2.42%
Total: 1530 samples


Classfication

In [13]:
import torch
torch.cuda.is_available()

True

In [18]:
import warnings
import random
random.seed(42)

warnings.filterwarnings('ignore')
import importlib, cellbert_runner as cr
importlib.reload(cr)

cfg = {
    "adata_h5": "data/Parent_NGSC3_DI_PBMC_filtered_feature_bc_matrix.h5",
    "cluster_csv": "data/Parent_NGSC3_DI_PBMC_analysis/analysis/clustering/graphclust/clusters.csv",
    # 若要用 SAE latent 向量：
    "use_latent": True,
    "latent_path": "latent_db/latent_8.npy",

    # 模型超参
    "bins": 100,
    "embed_dim": 128,
    "layers": 4,
    "heads": 8,
    "dropout": 0.1,
    "epochs": 15,
    "batch_size": 64,
    "lr": 1e-4,
}

model, test_acc = cr.run_experiment(cfg)
print("Final test acc:", test_acc)


Epoch 1/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 01] train_loss=2.5213 | val_acc=0.1511


Epoch 2/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 02] train_loss=2.5065 | val_acc=0.2041


Epoch 3/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 03] train_loss=2.5035 | val_acc=0.2041


Epoch 4/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 04] train_loss=2.4976 | val_acc=0.2041


Epoch 5/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 05] train_loss=2.4882 | val_acc=0.1956


Epoch 6/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 06] train_loss=2.4929 | val_acc=0.2041


Epoch 7/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 07] train_loss=2.4943 | val_acc=0.2041


Epoch 8/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 08] train_loss=2.4913 | val_acc=0.2041


Epoch 9/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 09] train_loss=2.4884 | val_acc=0.1511


Epoch 10/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 10] train_loss=2.4882 | val_acc=0.2041


Epoch 11/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 11] train_loss=2.4866 | val_acc=0.2041


Epoch 12/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 12] train_loss=2.4868 | val_acc=0.2041


Epoch 13/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 13] train_loss=2.4871 | val_acc=0.2041


Epoch 14/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 14] train_loss=2.4845 | val_acc=0.2041


Epoch 15/15:   0%|          | 0/112 [00:00<?, ?it/s]

[Epoch 15] train_loss=2.4844 | val_acc=0.2041

>>> Test Accuracy: 0.2059

Final test acc: 0.20588235294117646
