In [1]:
import scanpy as sc
# Path to the Xenium breast-cancer AnnData file (relative to this notebook).
# Despite the name, this is a file path, not a directory.
data_dir = '../../dataset/Xenium_breast_cancer_sample1_replicate1.h5ad'
adata = sc.read_h5ad(data_dir)
# Library-size normalisation followed by log1p; both calls modify `adata` in place
# (their return values are not captured).
# NOTE(review): normalize_total is called without `target_sum`, so scanpy's default
# target is used. Confirm this matches the preprocessing the pretrained checkpoint
# was trained with.
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)



In [2]:
%load_ext autoreload
%autoreload 2
import torch
import lightning.pytorch as pl
from self_supervision.models.lightning_modules.cellnet_autoencoder import MLPBarlowTwins
from self_supervision.estimator.cellnet import EstimatorAutoEncoder

# Path to the pretrained Barlow Twins checkpoint (.ckpt file).
ckpt_path = "../../sc_pretrained/Pretrained Models/BarlowTwins.ckpt"

# Output widths of the encoder layers, in order.
units_encoder = [512, 512, 256, 256, 64]

# Estimator wrapper; data_path=None because no on-disk training data is attached
# in this notebook (the model is only used for inference).
estim = EstimatorAutoEncoder(data_path=None)

# Build the model that will receive the pretrained weights.
estim.model = MLPBarlowTwins(
    # Input width of the encoder.
    # NOTE(review): presumably must equal the number of genes in var.parquet - confirm.
    gene_dim=19331,
    # Passed through to the module; its effect on inference is not visible here.
    batch_size=8192,
    units_encoder=units_encoder,
    CHECKPOINT_PATH=ckpt_path,
)

# Select the accelerator itself based on CUDA availability. Previously the
# accelerator was hard-coded to "gpu" and only `devices` was conditional (falling
# back to None), which cannot work on a CPU-only machine.
estim.trainer = pl.Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
)
estim.model

  warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
  warn(f"Triton dtype mappings did not load successfully due to an error: {exc.msg}")


GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores


HPU available: False, using: 0 HPUs


MLPBarlowTwins(
  (train_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=train_
  )
  (val_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=val_
  )
  (test_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=test_
  )
  (inner_model): MLP(
    (0): Linear(in_features=19331, out_features=512, bias=True)
    (1): SELU()
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): SELU()
    (5): Dropout(p=0.0, inplace=False)
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): SELU()
    (8): Dropout(p=0.0, inplace=False)
    (9): Linear(in_features=256, out_features=25

In [3]:
# Load the checkpoint onto the CPU first so that deserialisation also works on
# machines without CUDA; load_state_dict below copies the values into the
# model's own parameters regardless of the device they were loaded on.
checkpoint = torch.load(ckpt_path, map_location="cpu")

# Keep only the entries whose key mentions 'backbone' and strip the 'backbone.'
# prefix (presumably so the keys line up with inner_model's parameter names).
backbone_state = {
    key.replace('backbone.', ''): tensor
    for key, tensor in checkpoint.items()
    if 'backbone' in key
}
estim.model.inner_model.load_state_dict(backbone_state)

# Switch the model to evaluation mode (also displays the module as the cell output).
estim.model.eval()

MLPBarlowTwins(
  (train_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=train_
  )
  (val_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=val_
  )
  (test_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=test_
  )
  (inner_model): MLP(
    (0): Linear(in_features=19331, out_features=512, bias=True)
    (1): SELU()
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): SELU()
    (5): Dropout(p=0.0, inplace=False)
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): SELU()
    (8): Dropout(p=0.0, inplace=False)
    (9): Linear(in_features=256, out_features=25

In [4]:
import pandas as pd
# Gene annotation table stored next to the pretrained models; the displayed
# columns are `feature_id` and `feature_name`.
# NOTE(review): presumably this is the gene vocabulary (and column order) the
# pretrained model expects - confirm.
var_df = pd.read_parquet('../../sc_pretrained/var.parquet')
var_df

Unnamed: 0,feature_id,feature_name
0,ENSG00000186092,OR4F5
1,ENSG00000284733,OR4F29
2,ENSG00000284662,OR4F16
3,ENSG00000187634,SAMD11
4,ENSG00000188976,NOC2L
...,...,...
19326,ENSG00000288702,UGT1A3
19327,ENSG00000288705,UGT1A5
19328,ENSG00000182484,WASH6P
19329,ENSG00000288622,PDCD6-AHRR


In [5]:
# Gene symbols from var_df in row order. The position of a gene in this list is
# later used (via gene_to_index) as its column index in the model input matrix.
all_genes = var_df['feature_name'].tolist()
# Displays the complete list as the cell output.
all_genes

['OR4F5',
 'OR4F29',
 'OR4F16',
 'SAMD11',
 'NOC2L',
 'KLHL17',
 'PLEKHN1',
 'PERM1',
 'HES4',
 'ISG15',
 'AGRN',
 'RNF223',
 'C1orf159',
 'TTLL10',
 'TNFRSF18',
 'TNFRSF4',
 'SDF4',
 'B3GALT6',
 'C1QTNF12',
 'UBE2J2',
 'SCNN1D',
 'ACAP3',
 'PUSL1',
 'INTS11',
 'CPTP',
 'TAS1R3',
 'DVL1',
 'MXRA8',
 'AURKAIP1',
 'CCNL2',
 'MRPL20',
 'ANKRD65',
 'TMEM88B',
 'VWA1',
 'ATAD3C',
 'ATAD3B',
 'ATAD3A',
 'TMEM240',
 'SSU72',
 'FNDC10',
 'MIB2',
 'MMP23B',
 'CDK11B',
 'SLC35E2B',
 'CDK11A',
 'NADK',
 'GNB1',
 'CALML6',
 'TMEM52',
 'CFAP74',
 'GABRD',
 'PRKCZ',
 'FAAP20',
 'SKI',
 'MORN1',
 'RER1',
 'PEX10',
 'PLCH2',
 'PANK4',
 'HES5',
 'TNFRSF14',
 'PRXL2B',
 'MMEL1',
 'TTC34',
 'ACTRT2',
 'PRDM16',
 'ARHGEF16',
 'MEGF6',
 'TPRG1L',
 'WRAP73',
 'TP73',
 'CCDC27',
 'SMIM1',
 'LRRC47',
 'CEP104',
 'DFFB',
 'C1orf174',
 'AJAP1',
 'NPHP4',
 'KCNAB2',
 'CHD5',
 'RPL22',
 'RNF207',
 'ICMT',
 'HES3',
 'GPR153',
 'ACOT7',
 'HES2',
 'ESPN',
 'TNFRSF25',
 'PLEKHG5',
 'NOL9',
 'TAS1R1',
 'ZBTB48',
 'KLH

In [6]:
# Copy the var index (the gene symbols of the Xenium panel) into an explicit
# 'gene_name' column so it can be handled like a regular column.
adata.var['gene_name']=adata.var.index
adata.var['gene_name']

ABCC11    ABCC11
ACTA2      ACTA2
ACTG2      ACTG2
ADAM9      ADAM9
ADGRE5    ADGRE5
           ...  
VWF          VWF
WARS        WARS
ZEB1        ZEB1
ZEB2        ZEB2
ZNF562    ZNF562
Name: gene_name, Length: 313, dtype: object

In [7]:
import numpy as np
# Allocate a zero-filled float32 matrix of shape (adata.X.shape[0], len(all_genes)):
# one row per row of adata.X and one column per entry of all_genes. Columns that
# are never filled afterwards stay at zero.
new_data = np.zeros((adata.X.shape[0], len(all_genes)), dtype=np.float32)


In [8]:
# Gene symbols present in the dataset (the 'gene_name' column of adata.var),
# kept as a pandas Series in the column order of adata.
existing_genes = adata.var['gene_name']
existing_genes

ABCC11    ABCC11
ACTA2      ACTA2
ACTG2      ACTG2
ADAM9      ADAM9
ADGRE5    ADGRE5
           ...  
VWF          VWF
WARS        WARS
ZEB1        ZEB1
ZEB2        ZEB2
ZNF562    ZNF562
Name: gene_name, Length: 313, dtype: object

In [9]:
# Lower-case both gene vocabularies so the comparison is case-insensitive.
all_genes_lower = [symbol.lower() for symbol in all_genes]
adata_genes_lower = [symbol.lower() for symbol in existing_genes]

# Set views of the two vocabularies.
all_genes_set = set(all_genes_lower)
adata_genes_set = set(adata_genes_lower)

# Genes shared by both vocabularies, and dataset genes the model vocabulary lacks.
matching_genes = all_genes_set & adata_genes_set
non_matching_genes = adata_genes_set.difference(matching_genes)

matching_count = len(matching_genes)
non_matching_count = len(non_matching_genes)

# Report the overlap; the unmatched dataset genes are shown as the cell output.
print(f"匹配的基因数量: {matching_count}")
print(f"匹配的基因列表: {matching_genes}")
non_matching_genes


匹配的基因数量: 305
匹配的基因列表: {'sec11c', 'c6orf132', 'cd247', 'tnfrsf17', 'lgalsl', 'oprpn', 'cd3e', 'derl3', 'aqp3', 'gzmb', 'tyrobp', 'cth', 'cenpf', 'ctsg', 'clec14a', 'prf1', 'ankrd30a', 'cdh1', 'ramp2', 'dnaaf1', 'ms4a1', 'cd274', 'il7r', 'scgb2a1', 's100a4', 'ptgds', 'ceacam8', 'mmrn2', 'dst', 'postn', 'dmkn', 'cd163', 'trappc3', 'ccr7', 'pdgfrb', 'trib1', 'adipoq', 'lpl', 'krt15', 'tcf15', 'clec9a', 'klrc1', 'cttn', 'agr3', 'pparg', 'igf1', 'tifa', 'znf562', 'oxtr', 'pgr', 'hoxd8', 'cxcr4', 'pcolce', 'sh3yl1', 's100a8', 'igsf6', 'ociad2', 'sec24a', 'zeb1', 'pecam1', 'klrb1', 'glipr1', 'foxa1', 'epcam', 'hdc', 'lif', 'c1qa', 'ncam1', 'klf5', 'medag', 'esm1', 'ankrd29', 'ly86', 'pigr', 'ccl20', 'usp53', 'map3k8', 'vopp1', 'tceal7', 'clic6', 'bank1', 'hoxd9', 'csf3', 'ccnd1', 'lypd3', 'ldhb', 'ndufa4l2', 'dusp2', 'ptrhd1', 'peli1', 'c15orf48', 'itgax', 'apobec3b', 'il2ra', 'srpk1', 'myh11', 'tubb2b', 'lum', 'erbb2', 'fbln1', 'tigit', 'gpr183', 'mdm2', 'angpt2', 'sfrp1', 'akr1c1', 'basp1', 

{'fam49a', 'kars', 'lars', 'nars', 'polr2j3', 'qars', 'trac', 'wars'}

In [10]:
# Lookup from lower-cased gene symbol to its position in all_genes_lower
# (if a symbol occurred more than once, the last position would win).
gene_to_index = dict(zip(all_genes_lower, range(len(all_genes_lower))))
gene_to_index

{'or4f5': 0,
 'or4f29': 1,
 'or4f16': 2,
 'samd11': 3,
 'noc2l': 4,
 'klhl17': 5,
 'plekhn1': 6,
 'perm1': 7,
 'hes4': 8,
 'isg15': 9,
 'agrn': 10,
 'rnf223': 11,
 'c1orf159': 12,
 'ttll10': 13,
 'tnfrsf18': 14,
 'tnfrsf4': 15,
 'sdf4': 16,
 'b3galt6': 17,
 'c1qtnf12': 18,
 'ube2j2': 19,
 'scnn1d': 20,
 'acap3': 21,
 'pusl1': 22,
 'ints11': 23,
 'cptp': 24,
 'tas1r3': 25,
 'dvl1': 26,
 'mxra8': 27,
 'aurkaip1': 28,
 'ccnl2': 29,
 'mrpl20': 30,
 'ankrd65': 31,
 'tmem88b': 32,
 'vwa1': 33,
 'atad3c': 34,
 'atad3b': 35,
 'atad3a': 36,
 'tmem240': 37,
 'ssu72': 38,
 'fndc10': 39,
 'mib2': 40,
 'mmp23b': 41,
 'cdk11b': 42,
 'slc35e2b': 43,
 'cdk11a': 44,
 'nadk': 45,
 'gnb1': 46,
 'calml6': 47,
 'tmem52': 48,
 'cfap74': 49,
 'gabrd': 50,
 'prkcz': 51,
 'faap20': 52,
 'ski': 53,
 'morn1': 54,
 'rer1': 55,
 'pex10': 56,
 'plch2': 57,
 'pank4': 58,
 'hes5': 59,
 'tnfrsf14': 60,
 'prxl2b': 61,
 'mmel1': 62,
 'ttc34': 63,
 'actrt2': 64,
 'prdm16': 65,
 'arhgef16': 66,
 'megf6': 67,
 'tprg1l': 68

In [11]:
# Genes exclusive to the model vocabulary, and genes exclusive to the dataset.
only_in_all_genes = all_genes_set.difference(adata_genes_set)
only_in_adata_genes = adata_genes_set.difference(all_genes_set)

# Print the size and the members of each exclusive set.
for label, genes in (
    ("all_genes", only_in_all_genes),
    ("adata_genes", only_in_adata_genes),
):
    print(f"仅在 {label} 中存在的基因数量: {len(genes)}")
    print(f"仅在 {label} 中存在的基因: {genes}")


仅在 all_genes 中存在的基因数量: 19026
仅在 all_genes 中存在的基因: {'fxr1', 'haus3', 'atf5', 'fbxo4', 'pcdh19', 'camsap3', 'atp6v1e2', 'itga1', 'iqcf5', 'defb131b', 'cyb5b', 'usp24', 'or2i1p', 'mccd1', 'bms1', 'clec18a', 'mindy1', 'or2t2', 'rnf17', 'pnliprp3', 'pigv', 'p4ha1', 'wdpcp', 'timd4', 'kdm5d', 'nxph2', 'kiaa0232', 'htra1', 'rsph3', 'slc45a2', 'tmem232', 'slc6a13', 'pde3a', 'arhgef25', 'dhps', 'muc5b', 'tmem101', 'riok1', 'dnajc5g', 'gpr135', 'ppp1r27', 'c17orf100', 'pik3cb', 'map6', 'afp', 'pilrb', 'homer1', 'pcdhga11', 'scn7a', 's100a11', 'adm5', 'agpat5', 'nudt5', 'zmat5', 'cyb5rl', 'krtap9-1', 'fgfbp1', 'parp3', 'unc45b', 'nktr', 'zmym6', 'mapre3', 'dlg5', 'or51a4', 'adgrg6', 'cldn10', 'pkhd1', 'best3', 'gfm2', 'pspc1', 'capns1', 'tspan32', 'mgst1', 'mon1a', 'zscan31', 'ppp1r3e', 'cdk6', 'pcdha8', 'kiaa2026', 'skap1', 'mtmr4', 'hcn3', 'gtf2f2', 'znf48', 'camk2b', 'klhl5', 'rab40b', 'morc1', 'asphd1', 'papss2', 'metap1d', 'h4c8', 'cdk9', 'gls2', 'cthrc1', 'otud1', 'gnat3', 'ska1', 'rfwd3', 

In [12]:
# Copy the dataset's expression values into the model-ordered matrix `new_data`.
#
# First resolve, for every dataset gene, which column of `new_data` it belongs to.
# The mapping is keyed by destination column so that, should two dataset genes
# resolve to the same column, the later one wins (same outcome as assigning the
# columns one by one in order).
# NOTE(review): the genes reported as missing look like they could be older
# symbols of genes that exist under a newer name in the model vocabulary;
# consider an alias mapping - confirm against var.parquet.
dst_to_src = {}
for i, gene in enumerate(adata_genes_lower):
    if gene in gene_to_index:
        dst_to_src[gene_to_index[gene]] = i
    else:
        print(f"Gene {gene} not found in all_genes list.")

# Then move all matched columns with a single slice instead of one slice per gene.
# This avoids repeated column slicing of a sparse matrix and also works when
# adata.X is already a dense array (which has no .toarray()).
if dst_to_src:
    dst_cols = list(dst_to_src.keys())
    src_cols = list(dst_to_src.values())
    matched_block = adata.X[:, src_cols]
    if hasattr(matched_block, "toarray"):
        matched_block = matched_block.toarray()
    new_data[:, dst_cols] = matched_block


Gene fam49a not found in all_genes list.


Gene kars not found in all_genes list.


Gene lars not found in all_genes list.


Gene nars not found in all_genes list.


Gene polr2j3 not found in all_genes list.


Gene qars not found in all_genes list.


Gene trac not found in all_genes list.


Gene wars not found in all_genes list.


In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


# Use the GPU when one is available, otherwise fall back to the CPU. The
# accelerator itself must follow CUDA availability; previously it was hard-coded
# to "gpu" with `devices` falling back to None, which cannot work without a GPU.
estim.trainer = pl.Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
)


from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Encode the cell-type labels as integers once, before splitting.
label_encoder = LabelEncoder()
labels_encoded = label_encoder.fit_transform(adata.obs['cell_type'])


random_seed = 42

# Dataset split: 70% train, 15% validation, 15% test.
X_train_val, X_test, y_train_val, y_test = train_test_split(
    new_data, labels_encoded, test_size=0.15, random_state=random_seed)

# 0.1765 of the remaining 85% is ~15% of the full dataset (0.15 / 0.85).
# X_val / y_val are produced to keep the split proportions; they are not used
# by the embedding step below.
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.1765, random_state=random_seed)


X_train_tensor = torch.tensor(X_train).float().to(device)
X_test_tensor = torch.tensor(X_test).float().to(device)
estim.model.to(device)
# Make sure inference runs in evaluation mode even if this cell is executed
# without the earlier cell that called eval().
estim.model.eval()

# Zero-shot embeddings: run the pretrained encoder without tracking gradients.
with torch.no_grad():
    train_embeddings = estim.model.inner_model(X_train_tensor).cpu().numpy()
    test_embeddings = estim.model.inner_model(X_test_tensor).cpu().numpy()

GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores


HPU available: False, using: 0 HPUs


cuda


In [14]:
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report
import numpy as np


random_states = [42]

# Evaluate once per listed seed. Note that `state` is only printed: the split was
# made earlier and the KNN classifier itself takes no random state.
for state in random_states:
    print(f"Testing with random_state={state}")

    # Fit a 5-nearest-neighbour classifier on the training embeddings.
    knn = KNeighborsClassifier(n_neighbors=5)
    knn.fit(train_embeddings, y_train)

    # Predict labels for the held-out test embeddings.
    predictions = knn.predict(test_embeddings)

    # Accuracy and F1 scores on the test set.
    accuracy = accuracy_score(y_test, predictions)
    print(f"KNN Accuracy on Test Data: {accuracy}")

    f1 = f1_score(y_test, predictions, average='weighted')
    print(f"Weighted F1 Score: {f1}")

    macro_f1 = f1_score(y_test, predictions, average='macro')
    print(f'Macro F1 Score: {macro_f1}')

    # Baseline: expected accuracy of guessing labels at random according to the
    # class frequencies of the test set (sum of squared class proportions).
    class_probabilities = np.bincount(y_test) / len(y_test)
    random_accuracy = np.sum(class_probabilities ** 2)
    print(f"Random Guess Accuracy: {random_accuracy}")

    # Per-class report. `labels` lists every encoded class explicitly so that
    # `target_names` always has a matching length, even if a rare class happens
    # to be absent from both y_test and the predictions.
    report = classification_report(
        y_test,
        predictions,
        labels=np.arange(len(label_encoder.classes_)),
        target_names=label_encoder.classes_,
    )
    print(report)




Testing with random_state=42


KNN Accuracy on Test Data: 0.7988238566376604
Weighted F1 Score: 0.7932973758105425
Macro F1 Score: 0.6924289482302289
Random Guess Accuracy: 0.1320247177116093
                         precision    recall  f1-score   support

                B_Cells       0.87      0.91      0.89       772
           CD4+_T_Cells       0.69      0.76      0.73      1286
           CD8+_T_Cells       0.71      0.70      0.70      1026
                 DCIS_1       0.63      0.75      0.69      1937
                 DCIS_2       0.68      0.46      0.55      1746
            Endothelial       0.90      0.94      0.92      1348
              IRF7+_DCs       0.74      0.61      0.67        74
         Invasive_Tumor       0.82      0.88      0.85      5230
             LAMP3+_DCs       0.87      0.69      0.77        49
          Macrophages_1       0.82      0.84      0.83      1692
          Macrophages_2       0.62      0.68      0.65       223
             Mast_Cells       0.80      0.57      0.67    

In [15]:
# Embed every cell (all rows of new_data) in a single forward pass without
# gradient tracking, then bring the result back to the CPU as a NumPy array.
with torch.no_grad():
    new_data_tensor = torch.tensor(new_data).float().to(device)
    SSL_embeddings = estim.model.inner_model(new_data_tensor).detach().cpu().numpy()
# Re-read the dataset from disk: this object holds the file's stored contents,
# not the in-memory `adata` that was normalised at the top of the notebook.
new_adata = sc.read_h5ad(data_dir)
# Attach the embeddings and the zero-shot evaluation artefacts, keyed by the
# split seed.
new_adata.obsm[f'SSL_BT_ZS_{random_seed}'] = SSL_embeddings
new_adata.uns[f'BT_ZS_y_test_{random_seed}'] = y_test
new_adata.uns[f'BT_ZS_predictions_{random_seed}'] = predictions
new_adata.uns[f'BT_ZS_target_names_{random_seed}'] = label_encoder.classes_
# NOTE(review): this writes to the same path the data was read from, so the
# input file is overwritten in place - confirm that is intended.
new_adata.write_h5ad(data_dir)

In [16]:

import pandas as pd
import os
import re

# Name of this notebook; the method name and the CSV file name are derived from it.
notebook_name = "Xenium_breast_cancer_sample1_replicate1_barlow_twins_zero_shot_42.ipynb"

# Training-curve values exist only if a training run defined them in this kernel;
# otherwise they stay None.
init_train_loss = train_losses[0] if 'train_losses' in globals() else None
init_val_loss = val_losses[0] if 'val_losses' in globals() else None
# `patience` is checked as well, so a kernel that has train_losses but no
# patience does not fail with a NameError.
converged_epoch = (
    len(train_losses) - patience
    if 'train_losses' in globals() and 'patience' in globals()
    else None
)
converged_val_loss = best_val_loss if 'best_val_loss' in globals() else None

# The full table is printed only when every training value is available, because
# formatting None with ':.3f' would raise.
has_training_metrics = all(
    value is not None
    for value in (init_train_loss, init_val_loss, converged_epoch, converged_val_loss)
)

# Print all requested metrics.
# NOTE(review): 'micor_f1' looks like a typo of 'micro_f1'. It is kept unchanged
# because it is an output column name that other tooling may rely on. The value
# reported under it is the accuracy.
print("Metrics Summary:")
if has_training_metrics:
    print("init_train_loss\tinit_val_loss\tconverged_epoch\tconverged_val_loss\tmacro_f1\tweighted_f1\tmicor_f1")
    print(f"{init_train_loss:.3f}\t{init_val_loss:.3f}\t{converged_epoch}\t{converged_val_loss:.3f}\t{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")
else:
    print("macro_f1\tweighted_f1\tmicor_f1")
    print(f"{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")

# The method name is the part of the notebook name between 'replicate1_' and the
# trailing '_<seed>'; fail with a clear message if the name does not follow that
# pattern.
method_match = re.search(r'replicate1_(.*?)_\d+', notebook_name)
if method_match is None:
    raise ValueError(
        f"Cannot derive the method name from notebook_name={notebook_name!r}"
    )

# Collect the results as a single-row table; unavailable values become ''.
output_data = {
    'dataset_split_random_seed': [int(random_seed)],
    'dataset': ['xenium_breast_cancer_sample1_replicate1'],
    'method': [method_match.group(1)],
    'init_train_loss': [init_train_loss if init_train_loss is not None else ''],
    'init_val_loss': [init_val_loss if init_val_loss is not None else ''],
    'converged_epoch': [converged_epoch if converged_epoch is not None else ''],
    'converged_val_loss': [converged_val_loss if converged_val_loss is not None else ''],
    'macro_f1': [macro_f1],
    'weighted_f1': [f1],
    'micor_f1': [accuracy]
}
output_df = pd.DataFrame(output_data)

# Save into a 'results' folder in the current directory (created if missing;
# exist_ok avoids a separate existence check).
os.makedirs('results', exist_ok=True)

csv_filename = f"results/{os.path.splitext(notebook_name)[0]}_results.csv"
output_df.to_csv(csv_filename, index=False)


Metrics Summary:
macro_f1	weighted_f1	micor_f1
0.692	0.793	0.799
