In [1]:
import scanpy as sc
# Path of the AnnData file; later cells re-read from and write back to this same path.
data_dir = '../../dataset/Marshall2022High_human_sampled.h5ad'
adata = sc.read_h5ad(data_dir)

# Start from the matrix stored in `.raw`, which must exist for this notebook.
if adata.raw is None:
    raise ValueError(f"{data_dir} has no .raw matrix to restore into adata.X")
# Copy so that the in-place normalisation applied afterwards cannot write
# through to adata.raw (assigning raw.X directly may share the same buffer).
adata.X = adata.raw.X.copy()

In [2]:
# Per-cell total-count normalisation followed by a log1p transform; both calls
# are made for their effect on `adata` (their return values are discarded).
# NOTE(review): target_sum is left at scanpy's default here — confirm this
# matches the preprocessing the pretrained encoder expects.
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

In [3]:

# Keep the original var index in a `gene_id` column before replacing it.
# Guarded so that re-running this cell does not overwrite the saved ids with
# gene names (after the first run the index already holds the gene names).
if 'gene_id' not in adata.var.columns:
    adata.var['gene_id'] = adata.var.index
# Index the genes by the values of the `gene_name` column from here on.
adata.var.index = adata.var['gene_name']


In [4]:
%load_ext autoreload
%autoreload 2
import scanpy as sc
import torch
import lightning.pytorch as pl
from self_supervision.models.lightning_modules.cellnet_autoencoder import MLPBarlowTwins
from self_supervision.estimator.cellnet import EstimatorAutoEncoder

# Path to the pretrained Barlow Twins checkpoint (.ckpt).
ckpt_path = "../../sc_pretrained/Pretrained Models/BarlowTwins.ckpt"

# Layer widths of the encoder MLP.
units_encoder = [512, 512, 256, 256, 64]

# Create the estimator wrapper; no data path is needed because the model is
# only used for inference on matrices built in this notebook.
estim = EstimatorAutoEncoder(data_path=None)

# Instantiate the model that the pretrained weights are loaded into.
estim.model = MLPBarlowTwins(
    gene_dim=19331,   # input width of the encoder; adjust to the gene vocabulary in use
    batch_size=2048,  # adjust as needed
    units_encoder=units_encoder,
    CHECKPOINT_PATH=ckpt_path,
)

# Choose the accelerator to match the hardware. The previous call always
# requested "gpu" and passed devices=None when CUDA was missing, which fails on
# CPU-only machines; with CUDA available the configuration is unchanged.
estim.trainer = pl.Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
)

  warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
  warn(f"Triton dtype mappings did not load successfully due to an error: {exc.msg}")


GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores


HPU available: False, using: 0 HPUs


In [5]:
# Load the checkpoint onto the CPU so the cell also works on machines that lack
# the device the weights were saved from; load_state_dict copies the values
# into the existing model parameters afterwards.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(ckpt_path, map_location="cpu")

# Accept either a flat state dict or a checkpoint that nests it under 'state_dict'.
state_dict = checkpoint.get('state_dict', checkpoint)

# Keep only the backbone parameters and strip the 'backbone.' marker so the
# keys line up with the parameter names of the inner MLP.
backbone_weights = {
    key.replace('backbone.', ''): tensor
    for key, tensor in state_dict.items()
    if 'backbone' in key
}
estim.model.inner_model.load_state_dict(backbone_weights)

# Switch to evaluation mode (the returned module is displayed as the cell output).
estim.model.eval()

MLPBarlowTwins(
  (train_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=train_
  )
  (val_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=val_
  )
  (test_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=test_
  )
  (inner_model): MLP(
    (0): Linear(in_features=19331, out_features=512, bias=True)
    (1): SELU()
    (2): Dropout(p=0.0, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): SELU()
    (5): Dropout(p=0.0, inplace=False)
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): SELU()
    (8): Dropout(p=0.0, inplace=False)
    (9): Linear(in_features=256, out_features=25

In [6]:
import pandas as pd
# Read the gene table from var.parquet; the displayed frame has `feature_id`
# and `feature_name` columns.
var_df = pd.read_parquet('../../sc_pretrained/var.parquet')
# Display the table as the cell output.
var_df

Unnamed: 0,feature_id,feature_name
0,ENSG00000186092,OR4F5
1,ENSG00000284733,OR4F29
2,ENSG00000284662,OR4F16
3,ENSG00000187634,SAMD11
4,ENSG00000188976,NOC2L
...,...,...
19326,ENSG00000288702,UGT1A3
19327,ENSG00000288705,UGT1A5
19328,ENSG00000182484,WASH6P
19329,ENSG00000288622,PDCD6-AHRR


In [7]:
# Gene names from var.parquet as a plain Python list, in table order; the
# position of a name in this list is later used as its column index.
all_genes = var_df['feature_name'].tolist()
# Display the list as the cell output.
all_genes

['OR4F5',
 'OR4F29',
 'OR4F16',
 'SAMD11',
 'NOC2L',
 'KLHL17',
 'PLEKHN1',
 'PERM1',
 'HES4',
 'ISG15',
 'AGRN',
 'RNF223',
 'C1orf159',
 'TTLL10',
 'TNFRSF18',
 'TNFRSF4',
 'SDF4',
 'B3GALT6',
 'C1QTNF12',
 'UBE2J2',
 'SCNN1D',
 'ACAP3',
 'PUSL1',
 'INTS11',
 'CPTP',
 'TAS1R3',
 'DVL1',
 'MXRA8',
 'AURKAIP1',
 'CCNL2',
 'MRPL20',
 'ANKRD65',
 'TMEM88B',
 'VWA1',
 'ATAD3C',
 'ATAD3B',
 'ATAD3A',
 'TMEM240',
 'SSU72',
 'FNDC10',
 'MIB2',
 'MMP23B',
 'CDK11B',
 'SLC35E2B',
 'CDK11A',
 'NADK',
 'GNB1',
 'CALML6',
 'TMEM52',
 'CFAP74',
 'GABRD',
 'PRKCZ',
 'FAAP20',
 'SKI',
 'MORN1',
 'RER1',
 'PEX10',
 'PLCH2',
 'PANK4',
 'HES5',
 'TNFRSF14',
 'PRXL2B',
 'MMEL1',
 'TTC34',
 'ACTRT2',
 'PRDM16',
 'ARHGEF16',
 'MEGF6',
 'TPRG1L',
 'WRAP73',
 'TP73',
 'CCDC27',
 'SMIM1',
 'LRRC47',
 'CEP104',
 'DFFB',
 'C1orf174',
 'AJAP1',
 'NPHP4',
 'KCNAB2',
 'CHD5',
 'RPL22',
 'RNF207',
 'ICMT',
 'HES3',
 'GPR153',
 'ACOT7',
 'HES2',
 'ESPN',
 'TNFRSF25',
 'PLEKHG5',
 'NOL9',
 'TAS1R1',
 'ZBTB48',
 'KLH

In [8]:
# Overwrite the `gene_name` column with the current var index values.
adata.var['gene_name']=adata.var.index
# Display the column as the cell output.
adata.var['gene_name']

gene_name
A1BG                                              A1BG
A1BG-AS1                                      A1BG-AS1
A1CF                                              A1CF
A2M                                                A2M
A2M-AS1                                        A2M-AS1
                                       ...            
ZZZ3                                              ZZZ3
MISP3                                            MISP3
ENSG00000234352.9                    ENSG00000234352.9
ARHGAP27P1-BPTFP1-KPNA2P3    ARHGAP27P1-BPTFP1-KPNA2P3
SBNO1-AS1                                    SBNO1-AS1
Name: gene_name, Length: 20299, dtype: category
Categories (25391, object): ['7SK_ENSG00000202198', 'A1BG', 'A1BG-AS1', 'A1CF', ..., 'ZYG11B', 'ZYX', 'ZZEF1', 'ZZZ3']

In [9]:
# Gene names present in the dataset, one per column of adata (same order as adata.var).
existing_genes = adata.var['gene_name']
# Display the series as the cell output.
existing_genes

gene_name
A1BG                                              A1BG
A1BG-AS1                                      A1BG-AS1
A1CF                                              A1CF
A2M                                                A2M
A2M-AS1                                        A2M-AS1
                                       ...            
ZZZ3                                              ZZZ3
MISP3                                            MISP3
ENSG00000234352.9                    ENSG00000234352.9
ARHGAP27P1-BPTFP1-KPNA2P3    ARHGAP27P1-BPTFP1-KPNA2P3
SBNO1-AS1                                    SBNO1-AS1
Name: gene_name, Length: 20299, dtype: category
Categories (25391, object): ['7SK_ENSG00000202198', 'A1BG', 'A1BG-AS1', 'A1CF', ..., 'ZYG11B', 'ZYX', 'ZZEF1', 'ZZZ3']

In [10]:
# Lower-case both gene lists so the comparison is case-insensitive.
all_genes_lower = [name.lower() for name in all_genes]
adata_genes_lower = [name.lower() for name in existing_genes]

# Unique names on each side.
all_genes_set = set(all_genes_lower)
adata_genes_set = set(adata_genes_lower)

# Genes shared by the reference list and the dataset.
matching_genes = all_genes_set & adata_genes_set
matching_count = len(matching_genes)

# Dataset genes that have no counterpart in the reference list.
non_matching_genes = adata_genes_set.difference(matching_genes)
non_matching_count = len(non_matching_genes)

# Report the overlap, then show the unmatched dataset genes as the cell output.
print(f"匹配的基因数量: {matching_count}")
print(f"匹配的基因列表: {matching_genes}")
non_matching_genes


匹配的基因数量: 15360
匹配的基因列表: {'mpp3', 'llgl1', 'dipk1b', 'grk6', 'irak4', 'spef1', 'il23a', 'sh2d4b', 'plg', 'stox1', 'capn5', 'slc38a11', 'det1', 'hnrnph2', 'drich1', 'pitrm1', 'gucy2c', 'rps15', 'nudt21', 'telo2', 'aco2', 'c1qa', 'strc', 'gm2a', 'grin3a', 'acbd3', 'dock3', 'dchs2', 'hus1', 'hpd', 'stmn1', 'stra6', 'pcnx3', 'cep152', 'plac8', 'cradd', 'kmo', 'smco3', 'frmpd2', 'mapk10', 'cast', 'bak1', 'josd1', 'ndufa5', 'otub1', 'gpr173', 'ifrd1', 'aurkaip1', 'prr11', 'tmem230', 'spry1', 'trim4', 'snapc2', 'usp47', 'hspg2', 'rnmt', 'slc28a1', 'ankrd13b', 'sh3bgrl', 'sfrp4', 'traf6', 'dnajb6', 'nup160', 'luc7l2', 'cabp1', 'acap2', 'ptprd', 'gab3', 'btn2a2', 'atp6v0e1', 'raf1', 'znf326', 'fam83f', 'uba1', 'akap10', 'clk4', 'crppa', 'gmip', 'maf1', 'ankrd36', 'cntnap2', 'fech', 'ttc16', 'ppp1r3f', 'znf600', 'mrpl10', 'susd1', 'akt2', 'wbp2', 'zc4h2', 'anos1', 'gdf6', 'fth1', 'nfix', 'stk26', 'mtmr11', 'cryz', 'hyal4', 'fosl1', 'myo15a', 'ctnnbl1', 'coq10a', 'ptgs2', 'proser3', 'znf480', 'mga

{'ensg00000227681.5',
 'tesc-as1',
 'nepro-as1',
 'syt15-as1',
 'dhrsx_ensg00000169084',
 'gata3-as1',
 'tns2-as1',
 'ensg00000257286.1',
 'ensg00000228150.1',
 'pvt1',
 'ensg00000272221.1',
 'hdac4-as1',
 'ensg00000261386.2',
 'tprg1-as1',
 'ensg00000227388.3',
 'linc02018',
 'ensg00000270956.1',
 'gata2-as1',
 'ensg00000224738.2',
 'nup107-dt',
 'ensg00000263257.3',
 'ensg00000224356.5',
 'ensg00000271978.2',
 'ensg00000269621.1',
 'rnu6-1327p',
 'fhad1-as1',
 'ccr5as',
 'ensg00000241345.1',
 'linc02398',
 'ensg00000268083.5',
 'ensg00000224985.1',
 'ensg00000248101.2',
 'ensg00000266397.1',
 'dennd6a-as1',
 'ensg00000243004.7',
 'ensg00000248491.6',
 'ensg00000259891.1',
 'linc-pint',
 'ensg00000238280.2',
 'catip-as1',
 'xpc-as1',
 'rsl1d1-dt',
 'ensg00000243389.1',
 'ensg00000265393.1',
 'linc02778',
 'adam20p1_ensg00000259158',
 'ensg00000260661.1',
 'ensg00000270871.1',
 'atxn7l3-as1',
 'ensg00000273456.2',
 'namptp1',
 'dlgap1-as3',
 'cbr3-as1',
 'ensg00000260710.3',
 'slc16a4-

In [11]:
# Lower-cased reference gene name -> its position in the reference gene order
# (if a name repeats, the last position wins).
gene_to_index = dict(zip(all_genes_lower, range(len(all_genes_lower))))
gene_to_index

{'or4f5': 0,
 'or4f29': 1,
 'or4f16': 2,
 'samd11': 3,
 'noc2l': 4,
 'klhl17': 5,
 'plekhn1': 6,
 'perm1': 7,
 'hes4': 8,
 'isg15': 9,
 'agrn': 10,
 'rnf223': 11,
 'c1orf159': 12,
 'ttll10': 13,
 'tnfrsf18': 14,
 'tnfrsf4': 15,
 'sdf4': 16,
 'b3galt6': 17,
 'c1qtnf12': 18,
 'ube2j2': 19,
 'scnn1d': 20,
 'acap3': 21,
 'pusl1': 22,
 'ints11': 23,
 'cptp': 24,
 'tas1r3': 25,
 'dvl1': 26,
 'mxra8': 27,
 'aurkaip1': 28,
 'ccnl2': 29,
 'mrpl20': 30,
 'ankrd65': 31,
 'tmem88b': 32,
 'vwa1': 33,
 'atad3c': 34,
 'atad3b': 35,
 'atad3a': 36,
 'tmem240': 37,
 'ssu72': 38,
 'fndc10': 39,
 'mib2': 40,
 'mmp23b': 41,
 'cdk11b': 42,
 'slc35e2b': 43,
 'cdk11a': 44,
 'nadk': 45,
 'gnb1': 46,
 'calml6': 47,
 'tmem52': 48,
 'cfap74': 49,
 'gabrd': 50,
 'prkcz': 51,
 'faap20': 52,
 'ski': 53,
 'morn1': 54,
 'rer1': 55,
 'pex10': 56,
 'plch2': 57,
 'pank4': 58,
 'hes5': 59,
 'tnfrsf14': 60,
 'prxl2b': 61,
 'mmel1': 62,
 'ttc34': 63,
 'actrt2': 64,
 'prdm16': 65,
 'arhgef16': 66,
 'megf6': 67,
 'tprg1l': 68

In [12]:
# Genes that appear on exactly one side of the comparison.
only_in_all_genes = all_genes_set.difference(adata_genes_set)
only_in_adata_genes = adata_genes_set.difference(all_genes_set)

# Print the size and the members of each one-sided set.
n_only_reference = len(only_in_all_genes)
print(f"仅在 all_genes 中存在的基因数量: {n_only_reference}")
print(f"仅在 all_genes 中存在的基因: {only_in_all_genes}")

n_only_dataset = len(only_in_adata_genes)
print(f"仅在 adata_genes 中存在的基因数量: {n_only_dataset}")
print(f"仅在 adata_genes 中存在的基因: {only_in_adata_genes}")


仅在 all_genes 中存在的基因数量: 3971
仅在 all_genes 中存在的基因: {'foxd4l6', 'tmem275', 'rfx6', 'ceacam8', 'cxorf65', 'plscr5', 'or51e2', 'zbed2', 'h2ac16', 'tas2r8', 'tmem239', 'rpel1', 'tas2r20', 'or4k13', 'mafa', 'shox2', 'golga8n', 'or5h14', 'prag1', 'dmrtc2', 'nxpe2', 'ct83', 'spdye5', 'cdy1b', 'poted', 'bpifb2', 'tiaf1', 'nyap1', 'tchh', 'nt5c1a', 'ct55', 'c17orf50', 'cgb5', 'linc02210-crhr1', 'or5d16', 'trim75p', 'dipk1c', 'h3c10', 'or2i1p', 'cbsl', 'znf560', 'myocos', 'rnase13', 'cdh19', 'prr18', 'usp29', 'or4c16', 'tgfbr3l', 'spanxd', 'dnajc25-gng10', 'rnase2', 'ftmt', 'usp17l26', 'taz', 'defb103b', 'or5ar1', 'ppp4r3c', 'gage12j', 'cnmd', 'psma1', 'spata31a1', 'tex13a', 'foxh1', 'cby2', 'akr1b10', 'or4a15', 'kbtbd4', 'or56a1', 'klrc4', 'cd8b2', 'phox2b', 'cylc1', 'krt26', 'ace', 'fam90a18p', 'h2aw', 'btnl3', 'h4c2', 'csf2ra', 'cfhr5', 'catsperz', 'h1-6', 'ctage6', 'gpr179', 'or2t1', 'klhl34', 'higd2b', 'defb136', 'foxl2nb', 'accsl', 'or5d13', 'or5ap2', 'igfals', 'magec1', 'tectb', 'c22orf15',

In [13]:
import numpy as np
from scipy.sparse import csr_matrix

# Re-order the columns of adata.X into the reference gene order, producing a
# dense (n_cells, len(all_genes)) float32 matrix. Reference genes missing from
# the dataset stay zero; dataset genes missing from the reference are dropped.

# The per-column gene list must line up with the columns of adata.X.
assert len(adata_genes_lower) == adata.X.shape[1], "gene list and adata.X columns differ in length"

# Mapping from lower-cased dataset gene name to its column in adata.X
# (not used by the remapping below; kept for interactive lookups).
adata_gene_to_index = {gene: idx for idx, gene in enumerate(adata_genes_lower)}

# For every adata.X column, the target column in new_data (-1 = no target).
adata_to_new_data_indices = np.full(adata.X.shape[1], -1, dtype=int)
for idx, gene in enumerate(adata_genes_lower):
    if gene in gene_to_index:
        adata_to_new_data_indices[idx] = gene_to_index[gene]

# Work on an explicit CSR view: the raw-array arithmetic below treats `indices`
# as column ids, which is only true for CSR. Without this, a CSC matrix would be
# remapped silently wrong and a dense array would fail outright.
X_csr = csr_matrix(adata.X)
data = X_csr.data
indices = X_csr.indices
indptr = X_csr.indptr

# Translate every stored column index into the reference column space.
mapped_indices = adata_to_new_data_indices[indices]

# Drop stored values whose gene has no reference column.
valid_entries = mapped_indices != -1
new_data_values = data[valid_entries]
new_data_indices = mapped_indices[valid_entries]

# Row pointers of the filtered matrix: entry i is the number of kept values
# stored before row i, i.e. a running count of kept values sampled at the old
# row boundaries. This replaces the per-row Python loop with one cumulative sum.
kept_before = np.concatenate(([0], np.cumsum(valid_entries)))
new_indptr = kept_before[indptr]

# Assemble the remapped sparse matrix and densify it for the encoder.
new_data = csr_matrix(
    (new_data_values, new_data_indices, new_indptr),
    shape=(X_csr.shape[0], len(all_genes)),
    dtype=np.float32
)
new_data = new_data.toarray()

In [14]:
from sklearn.metrics import accuracy_score, f1_score
from sklearn.preprocessing import LabelEncoder
import numpy as np
from sklearn.model_selection import train_test_split


# Encode the cell-type annotations as integer class ids once, up front.
label_encoder = LabelEncoder()
labels_encoded = label_encoder.fit_transform(adata.obs['cell_type'])

# Fixed seed so the split is reproducible.
random_seed = 42

# First hold out 15 % of the cells as the test set ...
X_train_val, X_test, y_train_val, y_test = train_test_split(
    new_data, labels_encoded, test_size=0.15, random_state=random_seed)

# ... then take 0.1765 of the remaining 85 % (about 15 % of all cells) as the
# validation set. The validation split is not used below.
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.1765, random_state=random_seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
estim.model.to(device)
estim.model.eval()

# Embed in fixed-size batches so that only one batch of the dense expression
# matrix lives on the device at a time, instead of the whole train/test matrix.
embed_batch_size = 4096
with torch.no_grad():
    train_embeddings = np.concatenate([
        estim.model.inner_model(
            torch.as_tensor(X_train[start:start + embed_batch_size], dtype=torch.float32).to(device)
        ).cpu().numpy()
        for start in range(0, len(X_train), embed_batch_size)
    ])
    test_embeddings = np.concatenate([
        estim.model.inner_model(
            torch.as_tensor(X_test[start:start + embed_batch_size], dtype=torch.float32).to(device)
        ).cpu().numpy()
        for start in range(0, len(X_test), embed_batch_size)
    ])


cuda


In [15]:
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report

    

    # 初始化和训练KNN分类器
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(train_embeddings, y_train)

# Predict cell types for the held-out test embeddings.
predictions = knn.predict(test_embeddings)

# Accuracy and F1 scores on the test split.
accuracy = accuracy_score(y_test, predictions)
print(f"KNN Accuracy on Test Data: {accuracy}")
f1 = f1_score(y_test, predictions, average='weighted')
print(f"Weighted F1 Score: {f1}")

macro_f1 = f1_score(y_test, predictions, average='macro')
print(f'Macro F1 Score: {macro_f1}')

# Expected accuracy of guessing labels at random with the test-set class
# frequencies: the sum of squared class probabilities.
class_probabilities = np.bincount(y_test) / len(y_test)
random_accuracy = np.sum(class_probabilities ** 2)
print(f"Random Guess Accuracy: {random_accuracy}")

# Per-class report. Passing every encoded label explicitly keeps `target_names`
# aligned even if some class is absent from the test split (otherwise the call
# raises on the length mismatch); zero_division=0 reports undefined metrics as
# 0 without emitting a warning for each of them.
report = classification_report(
    y_test,
    predictions,
    labels=np.arange(len(label_encoder.classes_)),
    target_names=label_encoder.classes_,
    zero_division=0,
)
print(report)

KNN Accuracy on Test Data: 0.595791270783848
Weighted F1 Score: 0.5636010112940153
Macro F1 Score: 0.27089193307730425
Random Guess Accuracy: 0.2927884776016484
                                                           precision    recall  f1-score   support

                          blood vessel smooth muscle cell       0.26      0.14      0.19       852
                                         endothelial cell       0.64      0.83      0.72     13511
                 kidney collecting duct intercalated cell       0.23      0.16      0.19      1032
                    kidney collecting duct principal cell       0.51      0.53      0.52      2040
          kidney distal convoluted tubule epithelial cell       0.37      0.22      0.28       618
                                     kidney granular cell       0.00      0.00      0.00       105
                           kidney interstitial fibroblast       0.30      0.06      0.10       781
kidney loop of Henle thick ascending limb epit

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [16]:
# Embed every cell in fixed-size batches so the full dense matrix is never
# moved to the device as a single tensor.
embed_batch_size = 4096
with torch.no_grad():
    SSL_embeddings = np.concatenate([
        estim.model.inner_model(
            torch.as_tensor(new_data[start:start + embed_batch_size], dtype=torch.float32).to(device)
        ).cpu().numpy()
        for start in range(0, len(new_data), embed_batch_size)
    ])

# Re-read the dataset from disk and attach the embeddings together with the
# evaluation artefacts of this run, keyed by the split seed.
new_adata = sc.read_h5ad(data_dir)
new_adata.obsm[f'SSL_BT_ZS_{random_seed}'] = SSL_embeddings
new_adata.uns[f'BT_ZS_y_test_{random_seed}'] = y_test
new_adata.uns[f'BT_ZS_predictions_{random_seed}'] = predictions
new_adata.uns[f'BT_ZS_target_names_{random_seed}'] = label_encoder.classes_
# NOTE(review): this writes back to the same path the data was read from,
# replacing the input file in place.
new_adata.write_h5ad(data_dir)

In [17]:

import pandas as pd
import os
import re

# Name of this notebook; the method label and the CSV file name derive from it.
notebook_name = "slide_seq_human_kidney_barlow_twins_zero_shot_42.ipynb"

# Training-history values exist only when a training loop ran earlier in the
# kernel. Each one is looked up independently so that a partially defined
# history (e.g. `train_losses` without `patience`) yields None instead of a
# NameError.
init_train_loss = train_losses[0] if 'train_losses' in globals() else None
init_val_loss = val_losses[0] if 'val_losses' in globals() else None
converged_epoch = (
    len(train_losses) - patience
    if 'train_losses' in globals() and 'patience' in globals()
    else None
)
converged_val_loss = best_val_loss if 'best_val_loss' in globals() else None

has_training_history = None not in (
    init_train_loss, init_val_loss, converged_epoch, converged_val_loss
)

# Print all requested metrics as tab-separated columns.
# NOTE(review): the 'micor_f1' column holds `accuracy`, and the spelling is kept
# as-is because it is an output column name that other files may rely on.
print("Metrics Summary:")
if has_training_history:
    print("init_train_loss\tinit_val_loss\tconverged_epoch\tconverged_val_loss\tmacro_f1\tweighted_f1\tmicor_f1")
    print(f"{init_train_loss:.3f}\t{init_val_loss:.3f}\t{converged_epoch}\t{converged_val_loss:.3f}\t{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")
else:
    print("macro_f1\tweighted_f1\tmicor_f1")
    print(f"{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")

# The method label is the part of the notebook name between 'kidney_' and the
# trailing seed; fall back to the bare notebook name if the pattern is absent
# rather than failing on a missing match.
method_match = re.search(r'kidney_(.*?)_\d+', notebook_name)
method_name = method_match.group(1) if method_match else os.path.splitext(notebook_name)[0]

# One-row table of results; missing training-history values become empty cells.
output_data = {
    'dataset_split_random_seed': [int(random_seed)],
    'dataset': ['slide_seq_human_kidney'],
    'method': [method_name],
    'init_train_loss': [init_train_loss if init_train_loss is not None else ''],
    'init_val_loss': [init_val_loss if init_val_loss is not None else ''],
    'converged_epoch': [converged_epoch if converged_epoch is not None else ''],
    'converged_val_loss': [converged_val_loss if converged_val_loss is not None else ''],
    'macro_f1': [macro_f1],
    'weighted_f1': [f1],
    'micor_f1': [accuracy]
}
output_df = pd.DataFrame(output_data)

# Save into a `results` folder under the current directory (created if needed).
os.makedirs('results', exist_ok=True)

csv_filename = f"results/{os.path.splitext(notebook_name)[0]}_results.csv"
output_df.to_csv(csv_filename, index=False)


Metrics Summary:
macro_f1	weighted_f1	micor_f1
0.271	0.564	0.596
