In [1]:
import scanpy as sc
data_dir = '../../dataset/Xenium_breast_cancer_sample1_replicate1.h5ad'
adata = sc.read_h5ad(data_dir)
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)



In [2]:
import torch
import lightning.pytorch as pl
from self_supervision.models.lightning_modules.cellnet_autoencoder import MLPAutoEncoder
from self_supervision.estimator.cellnet import EstimatorAutoEncoder

# 设置你的 .ckpt 文件路径
ckpt_path = "../../sc_pretrained/Pretrained Models/GPMask.ckpt"

# 模型参数
units_encoder = [512, 512, 256, 256, 64]
units_decoder = [256, 256, 512, 512]

# 初始化 EstimatorAutoEncoder 实例
estim = EstimatorAutoEncoder(data_path=None)  # 如果没有实际数据路径，可以设置为None

# 加载预训练模型
estim.model = MLPAutoEncoder.load_from_checkpoint(
    ckpt_path,
    gene_dim=19331,  # 根据你的数据调整
    batch_size=128,  # 根据你的需要调整
    units_encoder=units_encoder, 
    units_decoder=units_decoder,
    masking_strategy="random",  # 假设模型使用了随机掩码
    masking_rate=0.5,  # 根据需要调整
)

# 使用 GPU 进行评估（如果可用）
estim.trainer = pl.Trainer(accelerator="gpu", devices=1 if torch.cuda.is_available() else None)
estim.model

  warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
  warn(f"Triton dtype mappings did not load successfully due to an error: {exc.msg}")


GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores


HPU available: False, using: 0 HPUs


MLPAutoEncoder(
  (train_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=train_
  )
  (val_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=val_
  )
  (test_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=test_
  )
  (encoder): MLP(
    (0): Linear(in_features=19331, out_features=512, bias=True)
    (1): SELU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): SELU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): SELU()
    (8): Dropout(p=0.1, inplace=False)
    (9): Linear(in_features=256, out_features=256, b

In [3]:
import pandas as pd
var_df = pd.read_parquet('../../sc_pretrained/var.parquet')
var_df

Unnamed: 0,feature_id,feature_name
0,ENSG00000186092,OR4F5
1,ENSG00000284733,OR4F29
2,ENSG00000284662,OR4F16
3,ENSG00000187634,SAMD11
4,ENSG00000188976,NOC2L
...,...,...
19326,ENSG00000288702,UGT1A3
19327,ENSG00000288705,UGT1A5
19328,ENSG00000182484,WASH6P
19329,ENSG00000288622,PDCD6-AHRR


In [4]:
all_genes = var_df['feature_name'].tolist()
all_genes

['OR4F5',
 'OR4F29',
 'OR4F16',
 'SAMD11',
 'NOC2L',
 'KLHL17',
 'PLEKHN1',
 'PERM1',
 'HES4',
 'ISG15',
 'AGRN',
 'RNF223',
 'C1orf159',
 'TTLL10',
 'TNFRSF18',
 'TNFRSF4',
 'SDF4',
 'B3GALT6',
 'C1QTNF12',
 'UBE2J2',
 'SCNN1D',
 'ACAP3',
 'PUSL1',
 'INTS11',
 'CPTP',
 'TAS1R3',
 'DVL1',
 'MXRA8',
 'AURKAIP1',
 'CCNL2',
 'MRPL20',
 'ANKRD65',
 'TMEM88B',
 'VWA1',
 'ATAD3C',
 'ATAD3B',
 'ATAD3A',
 'TMEM240',
 'SSU72',
 'FNDC10',
 'MIB2',
 'MMP23B',
 'CDK11B',
 'SLC35E2B',
 'CDK11A',
 'NADK',
 'GNB1',
 'CALML6',
 'TMEM52',
 'CFAP74',
 'GABRD',
 'PRKCZ',
 'FAAP20',
 'SKI',
 'MORN1',
 'RER1',
 'PEX10',
 'PLCH2',
 'PANK4',
 'HES5',
 'TNFRSF14',
 'PRXL2B',
 'MMEL1',
 'TTC34',
 'ACTRT2',
 'PRDM16',
 'ARHGEF16',
 'MEGF6',
 'TPRG1L',
 'WRAP73',
 'TP73',
 'CCDC27',
 'SMIM1',
 'LRRC47',
 'CEP104',
 'DFFB',
 'C1orf174',
 'AJAP1',
 'NPHP4',
 'KCNAB2',
 'CHD5',
 'RPL22',
 'RNF207',
 'ICMT',
 'HES3',
 'GPR153',
 'ACOT7',
 'HES2',
 'ESPN',
 'TNFRSF25',
 'PLEKHG5',
 'NOL9',
 'TAS1R1',
 'ZBTB48',
 'KLH

In [5]:
adata.var['gene_name']=adata.var.index
adata.var['gene_name']

ABCC11    ABCC11
ACTA2      ACTA2
ACTG2      ACTG2
ADAM9      ADAM9
ADGRE5    ADGRE5
           ...  
VWF          VWF
WARS        WARS
ZEB1        ZEB1
ZEB2        ZEB2
ZNF562    ZNF562
Name: gene_name, Length: 313, dtype: object

In [6]:
import numpy as np
# 初始化一个新的数据矩阵，形状为 (adata.X.shape[0], len(all_genes))，填充为零
new_data = np.zeros((adata.X.shape[0], len(all_genes)), dtype=np.float32)


In [7]:
existing_genes = adata.var['gene_name']
existing_genes

ABCC11    ABCC11
ACTA2      ACTA2
ACTG2      ACTG2
ADAM9      ADAM9
ADGRE5    ADGRE5
           ...  
VWF          VWF
WARS        WARS
ZEB1        ZEB1
ZEB2        ZEB2
ZNF562    ZNF562
Name: gene_name, Length: 313, dtype: object

In [8]:
# 将所有基因名称转换为小写
all_genes_lower = [gene.lower() for gene in all_genes]
adata_genes_lower = [gene.lower() for gene in existing_genes]

# 将两个列表转换为集合
all_genes_set = set(all_genes_lower)
adata_genes_set = set(adata_genes_lower)

# 计算交集
matching_genes = all_genes_set.intersection(adata_genes_set)
matching_count = len(matching_genes)
# 计算不匹配的基因
non_matching_genes = adata_genes_set - matching_genes
non_matching_count = len(non_matching_genes)


# 输出结果
print(f"匹配的基因数量: {matching_count}")
print(f"匹配的基因列表: {matching_genes}")
non_matching_genes


匹配的基因数量: 305
匹配的基因列表: {'fgl2', 'cd83', 'dst', 'basp1', 'slamf7', 'hmga1', 'tnfrsf17', 'igf1', 'derl3', 'svil', 'itgax', 'lif', 'lag3', 'ociad2', 'ceacam6', 'angpt2', 'lyz', 'fcgr3a', 'adipoq', 'klrf1', 's100a14', 'tubb2b', 'fstl3', 'mki67', 'bace2', 'enah', 'cd80', 'ltb', 'tifa', 'ctsg', 'slc5a6', 'hook2', 'trib1', 'cytip', 'mylk', 'dsp', 'dsc2', 'tmem147', 'cd27', 'cd79b', 'stc1', 'tpsab1', 'cxcr4', 'adam9', 'pecam1', 'lep', 'myo5b', 'thap2', 'cd93', 'trappc3', 'ccdc80', 'mlph', 'cxcl12', 'sh3yl1', 'scd', 'dusp2', 'akr1c3', 'agr3', 'cd69', 'cd68', 'cd86', 'ucp1', 'dnttip1', 'pigr', 'edn1', 'smap2', 'esm1', 'gjb2', 'tcf4', 'crispld2', 'dnaaf1', 'cd4', 'tigit', 'cd19', 'sfrp4', 'tcf15', 's100a4', 'ccr7', 'adgre5', 'ctla4', 'apobec3a', 'foxp3', 'ptn', 'ccl8', 'medag', 'cth', 'cxcl16', 'krt8', 'csf3', 'pgr', 'cavin2', 'fcer1a', 'scgb2a1', 'sstr2', 'nkg7', 'ly86', 'pdgfra', 'ankrd29', 'flnb', 'snai1', 'runx1', 'peli1', 'avpr1a', 'cldn5', 'prf1', 'dmkn', 'cd8b', 'mmp1', 'traf4', 'rab30', 's

{'fam49a', 'kars', 'lars', 'nars', 'polr2j3', 'qars', 'trac', 'wars'}

In [9]:
gene_to_index = {gene: idx for idx, gene in enumerate(all_genes_lower)}
gene_to_index

{'or4f5': 0,
 'or4f29': 1,
 'or4f16': 2,
 'samd11': 3,
 'noc2l': 4,
 'klhl17': 5,
 'plekhn1': 6,
 'perm1': 7,
 'hes4': 8,
 'isg15': 9,
 'agrn': 10,
 'rnf223': 11,
 'c1orf159': 12,
 'ttll10': 13,
 'tnfrsf18': 14,
 'tnfrsf4': 15,
 'sdf4': 16,
 'b3galt6': 17,
 'c1qtnf12': 18,
 'ube2j2': 19,
 'scnn1d': 20,
 'acap3': 21,
 'pusl1': 22,
 'ints11': 23,
 'cptp': 24,
 'tas1r3': 25,
 'dvl1': 26,
 'mxra8': 27,
 'aurkaip1': 28,
 'ccnl2': 29,
 'mrpl20': 30,
 'ankrd65': 31,
 'tmem88b': 32,
 'vwa1': 33,
 'atad3c': 34,
 'atad3b': 35,
 'atad3a': 36,
 'tmem240': 37,
 'ssu72': 38,
 'fndc10': 39,
 'mib2': 40,
 'mmp23b': 41,
 'cdk11b': 42,
 'slc35e2b': 43,
 'cdk11a': 44,
 'nadk': 45,
 'gnb1': 46,
 'calml6': 47,
 'tmem52': 48,
 'cfap74': 49,
 'gabrd': 50,
 'prkcz': 51,
 'faap20': 52,
 'ski': 53,
 'morn1': 54,
 'rer1': 55,
 'pex10': 56,
 'plch2': 57,
 'pank4': 58,
 'hes5': 59,
 'tnfrsf14': 60,
 'prxl2b': 61,
 'mmel1': 62,
 'ttc34': 63,
 'actrt2': 64,
 'prdm16': 65,
 'arhgef16': 66,
 'megf6': 67,
 'tprg1l': 68

In [10]:
only_in_all_genes = all_genes_set - adata_genes_set

only_in_adata_genes = adata_genes_set - all_genes_set

# 输出结果
print(f"仅在 all_genes 中存在的基因数量: {len(only_in_all_genes)}")
print(f"仅在 all_genes 中存在的基因: {only_in_all_genes}")

print(f"仅在 adata_genes 中存在的基因数量: {len(only_in_adata_genes)}")
print(f"仅在 adata_genes 中存在的基因: {only_in_adata_genes}")


仅在 all_genes 中存在的基因数量: 19026
仅在 all_genes 中存在的基因: {'ice2', 'atp9a', 'nfkbiz', 'pry', 'adat2', 'mov10', 'adss2', 'eva1c', 'rtl10', 'znf556', 'tceanc', 'tectb', 'twf2', 'cabp1', 'rtl3', 'ly6g5c', 'pola2', 'c19orf25', 'khdrbs2', 'arl4c', 'rexo1', 'or4c16', 'zfp3', 'uvrag', 'uts2b', 'hspa4', 'sap130', 'ssna1', 'nnat', 'pvalef', 'mafg', 'c12orf40', 'ccdc81', 'itgb5', 'sco2', 'vrk1', 'rpl24', 'smim17', 'mrpl43', 'upf1', 'znf223', 'b3galt9', 'plekhn1', 'thumpd1', 'znf691', 'minpp1', 'smarcc2', 'rbm41', 'h2bs1', 'hsf4', 'pdss1', 'rtn4r', 'lhx9', 'etnppl', 'phf21a', 'rhoxf1', 'eif4e1b', 'sik1b', 'dus2', 'cmc1', 'ccdc168', 'chaf1a', 'slc49a4', 'oscp1', 'slc36a3', 'rnf6', 'ttbk1', 'amy1a', 'dmac2l', 'mcrip2', 'cox20', 'rnaseh1', 'cct8', 'atp13a1', 'hrh1', 'stau1', 'mmp26', 'luzp2', 'kcnj6', 'ndufa13', 'nrxn2', 'usp17l1', 'c3orf20', 'mta3', 'tmem131l', 'abl2', 'sh3bgrl3', 'yars1', 'nudt1', 'prac1', 'pcyt1a', 'mfap3l', 'znf773', 'ypel3', 'snrnp200', 'aip', 'abcd3', 'tnfrsf11b', 'mef2a', 'bco2', 'pp

In [11]:
dense_adata_X = adata.X.toarray()
for i, gene in enumerate(adata_genes_lower):
    if gene in gene_to_index:
        new_data[:, gene_to_index[gene]] = dense_adata_X[:, i]
    else:
        print(f'Gene {gene} not found in all_genes list')


Gene fam49a not found in all_genes list


Gene kars not found in all_genes list
Gene lars not found in all_genes list


Gene nars not found in all_genes list
Gene polr2j3 not found in all_genes list


Gene qars not found in all_genes list


Gene trac not found in all_genes list
Gene wars not found in all_genes list


In [12]:
from sklearn.metrics import accuracy_score, f1_score
from sklearn.preprocessing import LabelEncoder
import numpy as np
from sklearn.model_selection import train_test_split


label_encoder = LabelEncoder()
labels_encoded = label_encoder.fit_transform(adata.obs['cell_type'])  # 预先编码标签


random_seed = 42
X_train_val, X_test, y_train_val, y_test = train_test_split(
    new_data, labels_encoded, test_size=0.15, random_state=random_seed)


X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.1765, random_state=random_seed)  # 0.1765 是为了让验证集占 15%

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

estim.model.eval()
with torch.no_grad():
    X_train_tensor = torch.tensor(X_train).float().to(device)
    X_test_tensor = torch.tensor(X_test).float().to(device)
    train_embeddings = estim.model.encoder(X_train_tensor).detach().cpu().numpy()
    test_embeddings = estim.model.encoder(X_test_tensor).detach().cpu().numpy()


cuda


In [13]:
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report

    

    # 初始化和训练KNN分类器
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(train_embeddings, y_train)
    
    # 模型预测
predictions = knn.predict(test_embeddings)

    # 计算准确率和 F1 分数
accuracy = accuracy_score(y_test, predictions)
print(f"KNN Accuracy on Test Data: {accuracy}")
f1 = f1_score(y_test, predictions, average='weighted')
print(f"Weighted F1 Score: {f1}")
    
macro_f1 = f1_score(y_test, predictions, average='macro')
print(f'Macro F1 Score: {macro_f1}')

    # 计算随机猜测的准确率
class_probabilities = np.bincount(y_test) / len(y_test)
random_accuracy = np.sum(class_probabilities ** 2)
print(f"Random Guess Accuracy: {random_accuracy}")

    # 生成分类报告
report = classification_report(y_test, predictions, target_names=label_encoder.classes_)
print(report)

KNN Accuracy on Test Data: 0.7182818770612309
Weighted F1 Score: 0.7059551806931811
Macro F1 Score: 0.5505823062338087
Random Guess Accuracy: 0.1320247177116093
                         precision    recall  f1-score   support

                B_Cells       0.74      0.85      0.79       772
           CD4+_T_Cells       0.59      0.65      0.62      1286
           CD8+_T_Cells       0.59      0.60      0.60      1026
                 DCIS_1       0.63      0.75      0.69      1937
                 DCIS_2       0.66      0.49      0.57      1746
            Endothelial       0.81      0.84      0.82      1348
              IRF7+_DCs       0.77      0.73      0.75        74
         Invasive_Tumor       0.78      0.89      0.83      5230
             LAMP3+_DCs       0.52      0.22      0.31        49
          Macrophages_1       0.61      0.57      0.59      1692
          Macrophages_2       0.37      0.18      0.24       223
             Mast_Cells       0.68      0.68      0.68    

In [14]:
with torch.no_grad():
    new_data_tensor = torch.tensor(new_data).float().to(device)
    SSL_embeddings = estim.model.encoder(new_data_tensor).detach().cpu().numpy()
new_adata = sc.read_h5ad(data_dir)
new_adata.obsm[f'SSL_GP_ZS_{random_seed}'] = SSL_embeddings
new_adata.uns[f'GP_ZS_y_test_{random_seed}'] = y_test
new_adata.uns[f'GP_ZS_predictions_{random_seed}'] = predictions
new_adata.uns[f'GP_ZS_target_names_{random_seed}'] = label_encoder.classes_
new_adata.write_h5ad(data_dir)

In [15]:

import pandas as pd
import os
import re

# 当前 Notebook 文件名
notebook_name = "Xenium_breast_cancer_sample1_replicate1_GP_mask_zero_shot_42.ipynb"

# 初始化需要打印的值
init_train_loss = train_losses[0] if 'train_losses' in globals() else None
init_val_loss = val_losses[0] if 'val_losses' in globals() else None
converged_epoch = len(train_losses) - patience if 'train_losses' in globals() else None
converged_val_loss = best_val_loss if 'best_val_loss' in globals() else None

# 打印所有所需的指标
print("Metrics Summary:")
if 'train_losses' in globals():
    print(f"init_train_loss\tinit_val_loss\tconverged_epoch\tconverged_val_loss\tmacro_f1\tweighted_f1\tmicor_f1")
    print(f"{init_train_loss:.3f}\t{init_val_loss:.3f}\t{converged_epoch}\t{converged_val_loss:.3f}\t{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")
else:
    print(f"macro_f1\tweighted_f1\tmicor_f1")
    print(f"{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")

# 保存结果到 CSV 文件
output_data = {
    'dataset_split_random_seed': [int(random_seed)],
    'dataset': ['xenium_breast_cancer_sample1_replicate1'],
    'method': [re.search(r'replicate1_(.*?)_\d+', notebook_name).group(1)],
    'init_train_loss': [init_train_loss if init_train_loss is not None else ''],
    'init_val_loss': [init_val_loss if init_val_loss is not None else ''],
    'converged_epoch': [converged_epoch if converged_epoch is not None else ''],
    'converged_val_loss': [converged_val_loss if converged_val_loss is not None else ''],
    'macro_f1': [macro_f1],
    'weighted_f1': [f1],
    'micor_f1': [accuracy]
}
output_df = pd.DataFrame(output_data)

# 保存到当前目录下名为 results 的文件夹中
if not os.path.exists('results'):
    os.makedirs('results')

csv_filename = f"results/{os.path.splitext(notebook_name)[0]}_results.csv"
output_df.to_csv(csv_filename, index=False)


Metrics Summary:
macro_f1	weighted_f1	micor_f1
0.551	0.706	0.718
