In [1]:
import scanpy as sc
import torch
import lightning.pytorch as pl
from torch import nn
from torch.optim import AdamW
from self_supervision.models.lightning_modules.cellnet_autoencoder import MLPAutoEncoder
from self_supervision.estimator.cellnet import EstimatorAutoEncoder
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, classification_report
import numpy as np
import pandas as pd

# 1. 加载数据集
data_dir = '../../dataset/uniport_imputed_Xenium_breast_cancer_sample1_replicate1.h5ad'
adata = sc.read_h5ad(data_dir)



# 2. 设置 checkpoint 路径
ckpt_path = "../../sc_pretrained/Pretrained Models/GPMask.ckpt"

# 3. 模型参数
units_encoder = [512, 512, 256, 256, 64]
units_decoder = [256, 256, 512, 512]

# 初始化 EstimatorAutoEncoder 实例
estim = EstimatorAutoEncoder(data_path=None)  # 没有实际数据路径，可以设置为None

# 4. 加载预训练模型
estim.model = MLPAutoEncoder.load_from_checkpoint(
    ckpt_path,
    gene_dim=19331,  # 根据你的数据调整
    batch_size=128,  # 根据你的需要调整
    units_encoder=units_encoder, 
    units_decoder=units_decoder,
    masking_strategy="random",  # 假设模型使用了随机掩码
    masking_rate=0.5,  # 根据需要调整
)

  warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
  warn(f"Triton dtype mappings did not load successfully due to an error: {exc.msg}")


In [2]:
# 添加分类层 (FC)
n_classes = len(adata.obs['cell_type'].unique())
estim.model.fc = nn.Linear(units_encoder[-1], n_classes)
n_classes

20

In [3]:
var_df = pd.read_parquet('../../sc_pretrained/var.parquet')
all_genes = var_df['feature_name'].tolist()

new_data = np.zeros((adata.X.shape[0], len(all_genes)), dtype=np.float32)

adata.var['gene_name']=adata.var.index
existing_genes = adata.var['gene_name']

In [4]:
# 将所有基因名称转换为小写
all_genes_lower = [gene.lower() for gene in all_genes]
adata_genes_lower = [gene.lower() for gene in existing_genes]

# 将两个列表转换为集合
all_genes_set = set(all_genes_lower)
adata_genes_set = set(adata_genes_lower)

# 计算交集
matching_genes = all_genes_set.intersection(adata_genes_set)
matching_count = len(matching_genes)
# 计算不匹配的基因
non_matching_genes = adata_genes_set - matching_genes
non_matching_count = len(non_matching_genes)


# 输出结果
print(f"匹配的基因数量: {matching_count}")
print(f"匹配的基因列表: {matching_genes}")
non_matching_genes

匹配的基因数量: 17768
匹配的基因列表: {'col20a1', 'slc8b1', 'fbxw10', 'ildr2', 'edf1', 'morc3', 'malrd1', 'tmtc1', 'amh', 'abhd5', 'plag1', 'pccb', 'pprc1', 'theg', 'paxbp1', 'adamts9', 'isl2', 'usp8', 'stau1', 'or2t2', 'rad9b', 'plk5', 'mib2', 'slc4a5', 'arhgap31', 'dnaja1', 'fam186b', 'cyb561d1', 'angptl1', 'trmt11', 'adm', 'plaa', 'slc27a5', 'clecl1', 'znrf1', 'macrod2', 'zswim1', 'mbd1', 'rtl8a', 'nefm', 'ddx53', 'zranb2', 'ints1', 'hepacam', 'gpsm2', 'mc1r', 'ten1', 'soga1', 'clcn3', 'krtap21-1', 'med21', 'il13', 'actrt3', 'lrrc34', 'pcsk2', 'fkbp14', 'ngfr', 'ncl', 'magi1', 'klhl24', 'c1orf194', 'mrgprx3', 'arhgap10', 'map1a', 'six5', 'mt3', 'calhm4', 'zscan26', 'uaca', 'nbdy', 'prdm5', 'cep192', 'tinag', 'brd8', 'mtnr1b', 'rhbdd1', 'map3k13', 'loxl1', 'znf341', 'ncor2', 'cox10', 'fst', 'commd4', 'c8orf88', 'nme8', 'meltf', 'prlh', 'etf1', 'map2k3', 'zfp69b', 'trmt10a', 'pgbd4', 'timp1', 'c1orf100', 'sema3g', 'syne1', 'malsu1', 'tex29', 'znf587', 'slc7a3', 'kank4', 'cep170', 'tmem150b', 'lamc2

{'aars',
 'ac004593.3',
 'ac005551.1',
 'ac007244.1',
 'ac007906.2',
 'ac008397.1',
 'ac008687.4',
 'ac010255.3',
 'ac010325.1',
 'ac011005.1',
 'ac011195.2',
 'ac013470.2',
 'ac015871.1',
 'ac021072.1',
 'ac021097.2',
 'ac025283.2',
 'ac025287.4',
 'ac067752.1',
 'ac072022.2',
 'ac073111.4',
 'ac087498.1',
 'ac090360.1',
 'ac091057.6',
 'ac092835.1',
 'ac093323.1',
 'ac099489.1',
 'ac104389.5',
 'ac106774.4',
 'ac113348.1',
 'ac115220.1',
 'ac118549.1',
 'ac119396.1',
 'ac132217.2',
 'ac134684.8',
 'ac135068.1',
 'ac138647.1',
 'ac187653.1',
 'ac233723.1',
 'ac236972.4',
 'acpp',
 'adprhl2',
 'adss',
 'adssl1',
 'al032819.3',
 'al109810.2',
 'al121578.2',
 'al135905.2',
 'al160269.1',
 'al162231.1',
 'al162596.1',
 'al353572.3',
 'al354761.1',
 'al391650.1',
 'al445238.1',
 'al451007.3',
 'al590560.2',
 'al603764.2',
 'al772284.2',
 'al845331.2',
 'alg1l',
 'ap000552.4',
 'ap002495.1',
 'arih2os',
 'armc4',
 'arse',
 'atp5md',
 'atp5mpl',
 'atp6ap1l',
 'b3gnt10',
 'bhmg1',
 'bx255925.

In [5]:
gene_to_index = {gene: idx for idx, gene in enumerate(all_genes_lower)}
dense_adata_X = adata.X
for i, gene in enumerate(adata_genes_lower):
    if gene in gene_to_index:
        new_data[:, gene_to_index[gene]] = dense_adata_X[:, i]
    else:
        print(f'Gene {gene} not found in all_genes list')

Gene al391650.1 not found in all_genes list


Gene yars not found in all_genes list
Gene adprhl2 not found in all_genes list


Gene tctex1d4 not found in all_genes list


Gene tctex1d1 not found in all_genes list
Gene wdr78 not found in all_genes list
Gene hhla3 not found in all_genes list
Gene ac118549.1 not found in all_genes list
Gene wdr63 not found in all_genes list


Gene kiaa1324 not found in all_genes list
Gene sars not found in all_genes list


Gene hist2h2be not found in all_genes list


Gene al162596.1 not found in all_genes list
Gene lor not found in all_genes list


Gene c1orf61 not found in all_genes list


Gene al590560.2 not found in all_genes list


Gene rgs5 not found in all_genes list
Gene dusp27 not found in all_genes list


Gene eprs not found in all_genes list
Gene marc2 not found in all_genes list
Gene marc1 not found in all_genes list
Gene h3f3a not found in all_genes list


Gene hist3h3 not found in all_genes list
Gene hist3h2a not found in all_genes list
Gene hist3h2bb not found in all_genes list
Gene al109810.2 not found in all_genes list
Gene tbce not found in all_genes list


Gene adss not found in all_genes list
Gene al451007.3 not found in all_genes list
Gene gcsaml-as1 not found in all_genes list


Gene mycnos not found in all_genes list
Gene fam49a not found in all_genes list


Gene c2orf91 not found in all_genes list


Gene igkc not found in all_genes list
Gene al845331.2 not found in all_genes list
Gene ac092835.1 not found in all_genes list
Gene kiaa1211l not found in all_genes list


Gene dars not found in all_genes list
Gene march7 not found in all_genes list


Gene pde11a not found in all_genes list
Gene dirc1 not found in all_genes list


Gene march4 not found in all_genes list


Gene ccdc140 not found in all_genes list
Gene c2orf83 not found in all_genes list


Gene arih2os not found in all_genes list
Gene qars not found in all_genes list
Gene ccdc36 not found in all_genes list
Gene cyb561d2 not found in all_genes list


Gene c3orf67 not found in all_genes list


Gene maats1 not found in all_genes list
Gene alg1l not found in all_genes list


Gene kiaa1257 not found in all_genes list
Gene h1fx not found in all_genes list
Gene h1foo not found in all_genes list
Gene acpp not found in all_genes list


Gene slc66a1l not found in all_genes list
Gene terc not found in all_genes list


Gene ccdc39 not found in all_genes list


Gene ac072022.2 not found in all_genes list
Gene tctex1d2 not found in all_genes list


Gene ac093323.1 not found in all_genes list


Gene kiaa1211 not found in all_genes list


Gene h2afz not found in all_genes list


Gene tmem155 not found in all_genes list


Gene march1 not found in all_genes list
Gene fam218a not found in all_genes list


Gene march6 not found in all_genes list
Gene march11 not found in all_genes list
Gene ac106774.4 not found in all_genes list
Gene h3.y not found in all_genes list
Gene tars not found in all_genes list


Gene c5orf67 not found in all_genes list


Gene atp6ap1l not found in all_genes list
Gene c5orf30 not found in all_genes list


Gene ac010255.3 not found in all_genes list
Gene march3 not found in all_genes list


Gene h2afy not found in all_genes list
Gene tmem173 not found in all_genes list


Gene hars not found in all_genes list


Gene lars not found in all_genes list


Gene rars not found in all_genes list


Gene ac113348.1 not found in all_genes list
Gene c5orf60 not found in all_genes list
Gene c6orf201 not found in all_genes list


Gene hist1h2aa not found in all_genes list
Gene hist1h2ba not found in all_genes list
Gene hist1h1a not found in all_genes list
Gene hist1h4b not found in all_genes list
Gene hist1h2bb not found in all_genes list
Gene hist1h1c not found in all_genes list
Gene hist1h4c not found in all_genes list
Gene hist1h2ac not found in all_genes list
Gene hist1h1e not found in all_genes list
Gene hist1h4e not found in all_genes list
Gene hist1h2bg not found in all_genes list
Gene hist1h2ae not found in all_genes list
Gene hist1h3e not found in all_genes list
Gene hist1h1d not found in all_genes list
Gene hist1h4g not found in all_genes list
Gene hist1h2bh not found in all_genes list
Gene hist1h3g not found in all_genes list
Gene hist1h2ag not found in all_genes list
Gene hist1h4i not found in all_genes list
Gene hist1h2ai not found in all_genes list
Gene hist1h3h not found in all_genes list
Gene hist1h4j not found in all_genes list
Gene hist1h2bn not found in all_genes list
Gene hist1h2ak not found

Gene znrd1 not found in all_genes list


Gene vars not found in all_genes list
Gene snhg32 not found in all_genes list


Gene c6orf223 not found in all_genes list
Gene defb133 not found in all_genes list


Gene ick not found in all_genes list
Gene al135905.2 not found in all_genes list


Gene fgfr1op not found in all_genes list
Gene tcte3 not found in all_genes list
Gene ac187653.1 not found in all_genes list


Gene ac013470.2 not found in all_genes list
Gene twistnb not found in all_genes list


Gene ac004593.3 not found in all_genes list
Gene gars not found in all_genes list
Gene trgc2 not found in all_genes list
Gene trgjp2 not found in all_genes list
Gene trgc1 not found in all_genes list
Gene trgjp1 not found in all_genes list


Gene ac115220.1 not found in all_genes list


Gene kiaa1324l not found in all_genes list


Gene castor3 not found in all_genes list


Gene c7orf77 not found in all_genes list
Gene ac011005.1 not found in all_genes list


Gene trbc1 not found in all_genes list
Gene trbc2 not found in all_genes list
Gene sspo not found in all_genes list
Gene ac073111.4 not found in all_genes list


Gene ac021097.2 not found in all_genes list
Gene wdr60 not found in all_genes list
Gene ac134684.8 not found in all_genes list


Gene pinx1 not found in all_genes list


Gene impad1 not found in all_genes list


Gene wdyhv1 not found in all_genes list
Gene fam49b not found in all_genes list
Gene ac138647.1 not found in all_genes list


Gene tsta3 not found in all_genes list
Gene dock8-as1 not found in all_genes list


Gene c9orf92 not found in all_genes list


Gene al162231.1 not found in all_genes list


Gene fam122a not found in all_genes list
Gene al353572.3 not found in all_genes list


Gene iars not found in all_genes list
Gene c9orf129 not found in all_genes list
Gene al160269.1 not found in all_genes list


Gene tmem246 not found in all_genes list
Gene palm2-akap2 not found in all_genes list
Gene znf883 not found in all_genes list


Gene dec1 not found in all_genes list
Gene b3gnt10 not found in all_genes list


Gene wdr34 not found in all_genes list


Gene al354761.1 not found in all_genes list
Gene bx255925.3 not found in all_genes list


Gene mir1915hg not found in all_genes list
Gene armc4 not found in all_genes list


Gene c10orf142 not found in all_genes list
Gene march8 not found in all_genes list


Gene ac067752.1 not found in all_genes list
Gene kif1bp not found in all_genes list
Gene h2afy2 not found in all_genes list


Gene c10orf55 not found in all_genes list
Gene dupd1 not found in all_genes list


Gene march5 not found in all_genes list


Gene atp5md not found in all_genes list


Gene al603764.2 not found in all_genes list


Gene pano1 not found in all_genes list
Gene ac132217.2 not found in all_genes list
Gene cars not found in all_genes list
Gene c11orf40 not found in all_genes list


Gene ac104389.5 not found in all_genes list


Gene st5 not found in all_genes list
Gene mrvi1 not found in all_genes list


Gene c11orf74 not found in all_genes list


Gene or5r1 not found in all_genes list


Gene ap002495.1 not found in all_genes list


Gene card16 not found in all_genes list
Gene card17 not found in all_genes list
Gene c11orf88 not found in all_genes list


Gene ccdc84 not found in all_genes list
Gene h2afx not found in all_genes list


Gene hist4h4 not found in all_genes list
Gene h2afj not found in all_genes list
Gene lrmp not found in all_genes list
Gene casc1 not found in all_genes list


Gene h3f3c not found in all_genes list
Gene h1fnt not found in all_genes list


Gene c12orf81 not found in all_genes list
Gene grasp not found in all_genes list
Gene ac021072.1 not found in all_genes list
Gene c12orf10 not found in all_genes list


Gene mars not found in all_genes list
Gene slc26a10 not found in all_genes list
Gene march9 not found in all_genes list


Gene cllu1os not found in all_genes list
Gene c12orf74 not found in all_genes list


Gene c12orf49 not found in all_genes list
Gene wdr66 not found in all_genes list


Gene spata13 not found in all_genes list


Gene spert not found in all_genes list
Gene al445238.1 not found in all_genes list


Gene trdc not found in all_genes list
Gene trac not found in all_genes list


Gene sfta3 not found in all_genes list


Gene elmsan1 not found in all_genes list


Gene c14orf177 not found in all_genes list
Gene wars not found in all_genes list
Gene atp5mpl not found in all_genes list
Gene adssl1 not found in all_genes list


Gene igha2 not found in all_genes list
Gene ighe not found in all_genes list
Gene igha1 not found in all_genes list
Gene ighg1 not found in all_genes list
Gene ighg3 not found in all_genes list
Gene ighd not found in all_genes list
Gene ighm not found in all_genes list
Gene fam30a not found in all_genes list
Gene ac135068.1 not found in all_genes list
Gene golga8m not found in all_genes list
Gene ac091057.6 not found in all_genes list
Gene c15orf41 not found in all_genes list


Gene linc02694 not found in all_genes list


Gene casc4 not found in all_genes list


Gene ct62 not found in all_genes list


Gene ac015871.1 not found in all_genes list


Gene spata8 not found in all_genes list
Gene fam169b not found in all_genes list
Gene tarsl2 not found in all_genes list
Gene tmem8a not found in all_genes list


Gene al032819.3 not found in all_genes list


Gene ac025283.2 not found in all_genes list


Gene ac099489.1 not found in all_genes list
Gene fopnl not found in all_genes list


Gene kiaa0556 not found in all_genes list


Gene c16orf58 not found in all_genes list
Gene ac007906.2 not found in all_genes list


Gene fam192a not found in all_genes list


Gene lrrc29 not found in all_genes list


Gene aars not found in all_genes list
Gene kars not found in all_genes list
Gene ac025287.4 not found in all_genes list


Gene fam92b not found in all_genes list
Gene cenpbd1 not found in all_genes list


Gene ac087498.1 not found in all_genes list
Gene ac233723.1 not found in all_genes list


Gene trim16l not found in all_genes list
Gene linc02693 not found in all_genes list


Gene slfn12l not found in all_genes list


Gene tmem99 not found in all_genes list
Gene ttc25 not found in all_genes list


Gene g6pc not found in all_genes list
Gene c17orf53 not found in all_genes list


Gene ac011195.2 not found in all_genes list
Gene march10 not found in all_genes list


Gene h3f3b not found in all_genes list


Gene eloa3 not found in all_genes list
Gene nars not found in all_genes list


Gene ac090360.1 not found in all_genes list


Gene ac005551.1 not found in all_genes list


Gene ac119396.1 not found in all_genes list
Gene march2 not found in all_genes list


Gene ccdc151 not found in all_genes list


Gene c19orf57 not found in all_genes list


Gene ac008397.1 not found in all_genes list


Gene kiaa0355 not found in all_genes list


Gene cntd2 not found in all_genes list


Gene cd3eap not found in all_genes list
Gene bhmg1 not found in all_genes list
Gene ppp5d1 not found in all_genes list


Gene ccdc114 not found in all_genes list
Gene ac008687.4 not found in all_genes list
Gene ccdc155 not found in all_genes list


Gene ac010325.1 not found in all_genes list
Gene c19orf48 not found in all_genes list
Gene siglec5 not found in all_genes list


Gene gdf5os not found in all_genes list


Gene tmem189 not found in all_genes list


Gene fp565260.1 not found in all_genes list


Gene ap000552.4 not found in all_genes list
Gene iglc1 not found in all_genes list
Gene iglc7 not found in all_genes list
Gene lrp5l not found in all_genes list


Gene elfn2 not found in all_genes list
Gene h1f0 not found in all_genes list
Gene z82206.1 not found in all_genes list


Gene arse not found in all_genes list


Gene cxorf21 not found in all_genes list
Gene hypm not found in all_genes list
Gene al121578.2 not found in all_genes list


Gene bx276092.9 not found in all_genes list


Gene nxf5 not found in all_genes list
Gene glra4 not found in all_genes list
Gene tmsb15b not found in all_genes list
Gene h2bfwt not found in all_genes list
Gene h2bfm not found in all_genes list
Gene pih1d3 not found in all_genes list


Gene al772284.2 not found in all_genes list
Gene cxorf56 not found in all_genes list
Gene fam122b not found in all_genes list
Gene fam122c not found in all_genes list


Gene cxorf40a not found in all_genes list
Gene ac236972.4 not found in all_genes list


Gene prky not found in all_genes list
Gene ac007244.1 not found in all_genes list


In [6]:
# 使用 GPU 进行评估（如果可用）
estim.trainer = pl.Trainer(accelerator="gpu", devices=1 if torch.cuda.is_available() else None)

# 5. 数据集划分（70% 训练，15% 验证，15% 测试）
label_encoder = LabelEncoder()
labels_encoded = label_encoder.fit_transform(adata.obs['cell_type'])  # 预先编码标签


random_seed = 42

X_train_val, X_test, y_train_val, y_test = train_test_split(
    new_data, labels_encoded, test_size=0.15, random_state=random_seed)

X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.1765, random_state=random_seed)  # 0.1765 是为了让验证集占 15%

GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores


HPU available: False, using: 0 HPUs


In [7]:
# 将训练数据转换为张量
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_train_tensor = torch.tensor(X_train).float().to(device)
y_train_tensor = torch.tensor(y_train).long().to(device)
X_val_tensor = torch.tensor(X_val).float().to(device)
y_val_tensor = torch.tensor(y_val).long().to(device)
estim.model.to(device)

MLPAutoEncoder(
  (train_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=train_
  )
  (val_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=val_
  )
  (test_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=test_
  )
  (encoder): MLP(
    (0): Linear(in_features=19331, out_features=512, bias=True)
    (1): SELU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): SELU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): SELU()
    (8): Dropout(p=0.1, inplace=False)
    (9): Linear(in_features=256, out_features=256, b

In [8]:
# 6. 微调模型，仅微调 encoder 的最后两层，其他层参数冻结
for param in estim.model.encoder.parameters():
    param.requires_grad = False  # 冻结所有层

# 解冻最后两层
for param in list(estim.model.encoder.parameters())[-5:]:
    param.requires_grad = True

In [9]:
# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(filter(lambda p: p.requires_grad, estim.model.parameters()), lr=9e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

In [10]:
from torch.utils.data import DataLoader, TensorDataset

# 设置 batch size
batch_size = 128  # 根据实际需求调整 batch size

# 使用 TensorDataset 将训练数据和标签打包
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)

# 使用 DataLoader 来创建批次
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# 7. 微调模型
def train_epoch(model, optimizer, loss_fn, train_loader, val_loader):
    model.train()
    total_train_loss = 0
    
    # 训练集批次训练
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model.encoder(X_batch)
        logits = model.fc(outputs)
        
        # 计算损失
        loss = loss_fn(logits, y_batch)
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    # 验证集
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for X_val_batch, y_val_batch in val_loader:
            val_outputs = model.encoder(X_val_batch)
            val_logits = model.fc(val_outputs)
            val_loss = loss_fn(val_logits, y_val_batch)
            total_val_loss += val_loss.item()

    # 返回平均损失
    return total_train_loss / len(train_loader), total_val_loss / len(val_loader)

In [11]:
import copy  # 用于保存模型的最佳状态

# Early Stopping 参数
patience = 20  # 如果验证损失在 10 个 epoch 中没有改善，停止训练
min_delta = 1e-4  # 最小改善幅度
patience_counter = 0
best_val_loss = float('inf')  # 初始设置为正无穷大
best_model_weights = copy.deepcopy(estim.model.state_dict())  # 保存最佳模型权重
train_losses = []
val_losses = []

# 训练 500 个 epoch
for epoch in range(500):
    train_loss, val_loss = train_epoch(estim.model, optimizer, loss_fn, train_loader, val_loader)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Validation Loss: {val_loss}')
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    # Early Stopping 检查
    if val_loss < best_val_loss - min_delta:
        best_val_loss = val_loss  # 更新最佳验证损失
        patience_counter = 0  # 重置 patience 计数器
        best_model_weights = copy.deepcopy(estim.model.state_dict())  # 保存当前最佳模型
        print(f"Validation loss improved to {best_val_loss}, resetting patience.")
    else:
        patience_counter += 1
        print(f"No improvement in validation loss. Patience counter: {patience_counter}/{patience}")
    
    # 如果 patience_counter 超过设置的 patience，停止训练
    if patience_counter >= patience:
        print(f"Early stopping triggered. Stopping training at epoch {epoch+1}.")
        break

# 恢复到训练中性能最好的模型权重
estim.model.load_state_dict(best_model_weights)
print("Loaded best model weights based on validation loss.")

Epoch 1, Train Loss: 1.892190592629569, Validation Loss: 1.4981382571733914
Validation loss improved to 1.4981382571733914, resetting patience.


Epoch 2, Train Loss: 1.7417605875612616, Validation Loss: 1.4271620738200652
Validation loss improved to 1.4271620738200652, resetting patience.


Epoch 3, Train Loss: 1.7127423198668512, Validation Loss: 1.4538633701128838
No improvement in validation loss. Patience counter: 1/20


Epoch 4, Train Loss: 1.6947079427949674, Validation Loss: 1.4112614961770864
Validation loss improved to 1.4112614961770864, resetting patience.


Epoch 5, Train Loss: 1.6846390813261598, Validation Loss: 1.3680381878828392
Validation loss improved to 1.3680381878828392, resetting patience.


Epoch 6, Train Loss: 1.6712794580302395, Validation Loss: 1.3962430752240695
No improvement in validation loss. Patience counter: 1/20


Epoch 7, Train Loss: 1.6639571576328067, Validation Loss: 1.3794479743028298
No improvement in validation loss. Patience counter: 2/20


Epoch 8, Train Loss: 1.6545293400575827, Validation Loss: 1.3613105987891172
Validation loss improved to 1.3613105987891172, resetting patience.


Epoch 9, Train Loss: 1.6530160192604904, Validation Loss: 1.3524856371757312
Validation loss improved to 1.3524856371757312, resetting patience.


Epoch 10, Train Loss: 1.639305068896367, Validation Loss: 1.3398430426915486
Validation loss improved to 1.3398430426915486, resetting patience.


Epoch 11, Train Loss: 1.641281808601631, Validation Loss: 1.4153000611525315
No improvement in validation loss. Patience counter: 1/20


Epoch 12, Train Loss: 1.630989197060302, Validation Loss: 1.3317451116366263
Validation loss improved to 1.3317451116366263, resetting patience.


Epoch 13, Train Loss: 1.6261838966673547, Validation Loss: 1.3551920805221949
No improvement in validation loss. Patience counter: 1/20


Epoch 14, Train Loss: 1.6252197187025468, Validation Loss: 1.3998368471096723
No improvement in validation loss. Patience counter: 2/20


Epoch 15, Train Loss: 1.6223542723026905, Validation Loss: 1.3408564604245699
No improvement in validation loss. Patience counter: 3/20


Epoch 16, Train Loss: 1.6190678722255833, Validation Loss: 1.296898857752482
Validation loss improved to 1.296898857752482, resetting patience.


Epoch 17, Train Loss: 1.6128789808723953, Validation Loss: 1.3365348461346749
No improvement in validation loss. Patience counter: 1/20


Epoch 18, Train Loss: 1.6146468559464255, Validation Loss: 1.3103879818549522
No improvement in validation loss. Patience counter: 2/20


Epoch 19, Train Loss: 1.6071602086444476, Validation Loss: 1.30440182808118
No improvement in validation loss. Patience counter: 3/20


Epoch 20, Train Loss: 1.6070226550102233, Validation Loss: 1.2977955096807234
No improvement in validation loss. Patience counter: 4/20


Epoch 21, Train Loss: 1.6041028025386097, Validation Loss: 1.3374273232924632
No improvement in validation loss. Patience counter: 5/20


Epoch 22, Train Loss: 1.6015415906906127, Validation Loss: 1.2998760162255703
No improvement in validation loss. Patience counter: 6/20


Epoch 23, Train Loss: 1.6037300733419566, Validation Loss: 1.3139782135303204
No improvement in validation loss. Patience counter: 7/20


Epoch 24, Train Loss: 1.602486045543964, Validation Loss: 1.298084474221254
No improvement in validation loss. Patience counter: 8/20


Epoch 25, Train Loss: 1.597577103808686, Validation Loss: 1.2871277748010097
Validation loss improved to 1.2871277748010097, resetting patience.


Epoch 26, Train Loss: 1.5981917077368433, Validation Loss: 1.3313607442073332
No improvement in validation loss. Patience counter: 1/20


Epoch 27, Train Loss: 1.5975774471576398, Validation Loss: 1.2713143079708784
Validation loss improved to 1.2713143079708784, resetting patience.


Epoch 28, Train Loss: 1.595478254360157, Validation Loss: 1.2823990717912332
No improvement in validation loss. Patience counter: 1/20


Epoch 29, Train Loss: 1.5895923740261204, Validation Loss: 1.348939372942998
No improvement in validation loss. Patience counter: 2/20


Epoch 30, Train Loss: 1.5892824262053102, Validation Loss: 1.2667406510084103
Validation loss improved to 1.2667406510084103, resetting patience.


Epoch 31, Train Loss: 1.5923275930540903, Validation Loss: 1.3249787440666785
No improvement in validation loss. Patience counter: 1/20


Epoch 32, Train Loss: 1.5949806455727462, Validation Loss: 1.2917368699342777
No improvement in validation loss. Patience counter: 2/20


Epoch 33, Train Loss: 1.5853276371955871, Validation Loss: 1.2608514816333085
Validation loss improved to 1.2608514816333085, resetting patience.


Epoch 34, Train Loss: 1.5886982439638495, Validation Loss: 1.286879233824901
No improvement in validation loss. Patience counter: 1/20


Epoch 35, Train Loss: 1.5913800221223098, Validation Loss: 1.2680377587293967
No improvement in validation loss. Patience counter: 2/20


Epoch 36, Train Loss: 1.585889547604781, Validation Loss: 1.2966479105827136
No improvement in validation loss. Patience counter: 3/20


Epoch 37, Train Loss: 1.582185109762045, Validation Loss: 1.273205098739037
No improvement in validation loss. Patience counter: 4/20


Epoch 38, Train Loss: 1.5861847910252247, Validation Loss: 1.2647455343833336
No improvement in validation loss. Patience counter: 5/20


Epoch 39, Train Loss: 1.5825955135481697, Validation Loss: 1.2688266552411593
No improvement in validation loss. Patience counter: 6/20


Epoch 40, Train Loss: 1.5887635242807996, Validation Loss: 1.2754566155947171
No improvement in validation loss. Patience counter: 7/20


Epoch 41, Train Loss: 1.5841120725149638, Validation Loss: 1.2827105057545198
No improvement in validation loss. Patience counter: 8/20


Epoch 42, Train Loss: 1.5803781437349844, Validation Loss: 1.2779253152700571
No improvement in validation loss. Patience counter: 9/20


Epoch 43, Train Loss: 1.5825410852065454, Validation Loss: 1.2786066544361603
No improvement in validation loss. Patience counter: 10/20


Epoch 44, Train Loss: 1.5833091467291445, Validation Loss: 1.293207816588573
No improvement in validation loss. Patience counter: 11/20


Epoch 45, Train Loss: 1.5767056863386553, Validation Loss: 1.248843208948771
Validation loss improved to 1.248843208948771, resetting patience.


Epoch 46, Train Loss: 1.5803340427168122, Validation Loss: 1.3204137282493786
No improvement in validation loss. Patience counter: 1/20


Epoch 47, Train Loss: 1.5792794877356224, Validation Loss: 1.3004733684735421
No improvement in validation loss. Patience counter: 2/20


Epoch 48, Train Loss: 1.5787285593839793, Validation Loss: 1.2856864085564246
No improvement in validation loss. Patience counter: 3/20


Epoch 49, Train Loss: 1.5801452694358407, Validation Loss: 1.3011344891328078
No improvement in validation loss. Patience counter: 4/20


Epoch 50, Train Loss: 1.5801833452759209, Validation Loss: 1.2520975461372963
No improvement in validation loss. Patience counter: 5/20


Epoch 51, Train Loss: 1.581165923522069, Validation Loss: 1.3031530392475617
No improvement in validation loss. Patience counter: 6/20


Epoch 52, Train Loss: 1.5808077159818712, Validation Loss: 1.2697694216019069
No improvement in validation loss. Patience counter: 7/20


Epoch 53, Train Loss: 1.579150381586054, Validation Loss: 1.2918670226366091
No improvement in validation loss. Patience counter: 8/20


Epoch 54, Train Loss: 1.5819461176683614, Validation Loss: 1.2902140812996106
No improvement in validation loss. Patience counter: 9/20


Epoch 55, Train Loss: 1.575628843674293, Validation Loss: 1.2948355032847478
No improvement in validation loss. Patience counter: 10/20


Epoch 56, Train Loss: 1.5784383055927989, Validation Loss: 1.271274321507185
No improvement in validation loss. Patience counter: 11/20


Epoch 57, Train Loss: 1.577838177733369, Validation Loss: 1.307960295677185
No improvement in validation loss. Patience counter: 12/20


Epoch 58, Train Loss: 1.5725228803498403, Validation Loss: 1.2818125675886105
No improvement in validation loss. Patience counter: 13/20


Epoch 59, Train Loss: 1.5767528396386368, Validation Loss: 1.2506222908313458
No improvement in validation loss. Patience counter: 14/20


Epoch 60, Train Loss: 1.5728412890172268, Validation Loss: 1.2431713556632018
Validation loss improved to 1.2431713556632018, resetting patience.


Epoch 61, Train Loss: 1.5775636840652634, Validation Loss: 1.2655392224972064
No improvement in validation loss. Patience counter: 1/20


Epoch 62, Train Loss: 1.5746688606974844, Validation Loss: 1.2647027089045597
No improvement in validation loss. Patience counter: 2/20


Epoch 63, Train Loss: 1.5773389600135468, Validation Loss: 1.3305607123252674
No improvement in validation loss. Patience counter: 3/20


Epoch 64, Train Loss: 1.5759022775587144, Validation Loss: 1.2659139669858492
No improvement in validation loss. Patience counter: 4/20


Epoch 65, Train Loss: 1.5718247200106525, Validation Loss: 1.290163941261096
No improvement in validation loss. Patience counter: 5/20


Epoch 66, Train Loss: 1.5755277052030459, Validation Loss: 1.2622979482014973
No improvement in validation loss. Patience counter: 6/20


Epoch 67, Train Loss: 1.5782692880420894, Validation Loss: 1.2642918550051176
No improvement in validation loss. Patience counter: 7/20


Epoch 68, Train Loss: 1.57633174264824, Validation Loss: 1.2888376969557542
No improvement in validation loss. Patience counter: 8/20


Epoch 69, Train Loss: 1.5731343387247443, Validation Loss: 1.2559104161384778
No improvement in validation loss. Patience counter: 9/20


Epoch 70, Train Loss: 1.5758811671655257, Validation Loss: 1.2794842139268532
No improvement in validation loss. Patience counter: 10/20


Epoch 71, Train Loss: 1.5768200781319168, Validation Loss: 1.277298106902685
No improvement in validation loss. Patience counter: 11/20


Epoch 72, Train Loss: 1.5721304548965707, Validation Loss: 1.252407530026558
No improvement in validation loss. Patience counter: 12/20


Epoch 73, Train Loss: 1.5711018796805496, Validation Loss: 1.2398320650443053
Validation loss improved to 1.2398320650443053, resetting patience.


Epoch 74, Train Loss: 1.5746946227419507, Validation Loss: 1.260490485949394
No improvement in validation loss. Patience counter: 1/20


Epoch 75, Train Loss: 1.5740330449827424, Validation Loss: 1.3063658145757822
No improvement in validation loss. Patience counter: 2/20


Epoch 76, Train Loss: 1.5755265580428826, Validation Loss: 1.2808178235323002
No improvement in validation loss. Patience counter: 3/20


Epoch 77, Train Loss: 1.576483260013245, Validation Loss: 1.303632698303614
No improvement in validation loss. Patience counter: 4/20


Epoch 78, Train Loss: 1.5739756467578174, Validation Loss: 1.2692323128382366
No improvement in validation loss. Patience counter: 5/20


Epoch 79, Train Loss: 1.5731806456387698, Validation Loss: 1.2963940657102144
No improvement in validation loss. Patience counter: 6/20


Epoch 80, Train Loss: 1.570631651170961, Validation Loss: 1.247649072072445
No improvement in validation loss. Patience counter: 7/20


Epoch 81, Train Loss: 1.5716140406472343, Validation Loss: 1.2863646366657355
No improvement in validation loss. Patience counter: 8/20


Epoch 82, Train Loss: 1.5740868547460536, Validation Loss: 1.2837514565541195
No improvement in validation loss. Patience counter: 9/20


Epoch 83, Train Loss: 1.577333141552223, Validation Loss: 1.2710661775026566
No improvement in validation loss. Patience counter: 10/20


Epoch 84, Train Loss: 1.5686981744818635, Validation Loss: 1.2735868777984227
No improvement in validation loss. Patience counter: 11/20


Epoch 85, Train Loss: 1.5732525723321098, Validation Loss: 1.2692386370438795
No improvement in validation loss. Patience counter: 12/20


Epoch 86, Train Loss: 1.5752425233086387, Validation Loss: 1.2895029648756369
No improvement in validation loss. Patience counter: 13/20


Epoch 87, Train Loss: 1.5718904416639725, Validation Loss: 1.2603942467616154
No improvement in validation loss. Patience counter: 14/20


Epoch 88, Train Loss: 1.5692151131210745, Validation Loss: 1.2363288610409469
Validation loss improved to 1.2363288610409469, resetting patience.


Epoch 89, Train Loss: 1.570662634058313, Validation Loss: 1.2979048765622652
No improvement in validation loss. Patience counter: 1/20


Epoch 90, Train Loss: 1.572082101905739, Validation Loss: 1.2636821624560235
No improvement in validation loss. Patience counter: 2/20


Epoch 91, Train Loss: 1.5738044808199116, Validation Loss: 1.2647307108610104
No improvement in validation loss. Patience counter: 3/20


Epoch 92, Train Loss: 1.5705757987368238, Validation Loss: 1.2554017837230975
No improvement in validation loss. Patience counter: 4/20


Epoch 93, Train Loss: 1.5723201197582286, Validation Loss: 1.286663782902253
No improvement in validation loss. Patience counter: 5/20


Epoch 94, Train Loss: 1.570893567604023, Validation Loss: 1.2961475133895874
No improvement in validation loss. Patience counter: 6/20


Epoch 95, Train Loss: 1.5727092704930148, Validation Loss: 1.285610471627651
No improvement in validation loss. Patience counter: 7/20


Epoch 96, Train Loss: 1.5719295007841927, Validation Loss: 1.2587693559817779
No improvement in validation loss. Patience counter: 8/20


Epoch 97, Train Loss: 1.5726159859489608, Validation Loss: 1.2982487684641129
No improvement in validation loss. Patience counter: 9/20


Epoch 98, Train Loss: 1.5726066713804727, Validation Loss: 1.2989840550300402
No improvement in validation loss. Patience counter: 10/20


Epoch 99, Train Loss: 1.5711589185746162, Validation Loss: 1.2713413152939235
No improvement in validation loss. Patience counter: 11/20


Epoch 100, Train Loss: 1.5691641912355527, Validation Loss: 1.262720274925232
No improvement in validation loss. Patience counter: 12/20


Epoch 101, Train Loss: 1.5674439765594819, Validation Loss: 1.2551354811741755
No improvement in validation loss. Patience counter: 13/20


Epoch 102, Train Loss: 1.5754019563014692, Validation Loss: 1.2473497182894975
No improvement in validation loss. Patience counter: 14/20


Epoch 103, Train Loss: 1.570409950843224, Validation Loss: 1.2617111640098768
No improvement in validation loss. Patience counter: 15/20


Epoch 104, Train Loss: 1.567026573747069, Validation Loss: 1.251012479953277
No improvement in validation loss. Patience counter: 16/20


Epoch 105, Train Loss: 1.5713671043678954, Validation Loss: 1.2565010676017174
No improvement in validation loss. Patience counter: 17/20


Epoch 106, Train Loss: 1.5711480867731702, Validation Loss: 1.2590299380131258
No improvement in validation loss. Patience counter: 18/20


Epoch 107, Train Loss: 1.5706374134336198, Validation Loss: 1.2928651632406774
No improvement in validation loss. Patience counter: 19/20


Epoch 108, Train Loss: 1.5744117066100403, Validation Loss: 1.270858065287272
No improvement in validation loss. Patience counter: 20/20
Early stopping triggered. Stopping training at epoch 108.
Loaded best model weights based on validation loss.


In [12]:
# 8. 使用 KNN 替代测试阶段的 FC 分类层
# 使用 encoder 提取训练集和测试集的 embedding
estim.model.eval()
with torch.no_grad():
    train_embeddings = estim.model.encoder(torch.tensor(X_train).float().to(device)).cpu().numpy()
    test_embeddings = estim.model.encoder(torch.tensor(X_test).float().to(device)).cpu().numpy()

In [13]:
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report

    

    # 初始化和训练KNN分类器
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(train_embeddings, y_train)
    
    # 模型预测
predictions = knn.predict(test_embeddings)

    # 计算准确率和 F1 分数
accuracy = accuracy_score(y_test, predictions)
print(f"KNN Accuracy on Test Data: {accuracy}")
f1 = f1_score(y_test, predictions, average='weighted')
print(f"Weighted F1 Score: {f1}")
    
macro_f1 = f1_score(y_test, predictions, average='macro')
print(f'Macro F1 Score: {macro_f1}')

    # 计算随机猜测的准确率
class_probabilities = np.bincount(y_test) / len(y_test)
random_accuracy = np.sum(class_probabilities ** 2)
print(f"Random Guess Accuracy: {random_accuracy}")

    # 生成分类报告
report = classification_report(y_test, predictions, target_names=label_encoder.classes_)
print(report)

KNN Accuracy on Test Data: 0.6339264246213032
Weighted F1 Score: 0.6230539061032725
Macro F1 Score: 0.4601248647892448
Random Guess Accuracy: 0.13330938159201633
                         precision    recall  f1-score   support

                B_Cells       0.73      0.76      0.75       770
           CD4+_T_Cells       0.54      0.70      0.61      1264
           CD8+_T_Cells       0.55      0.56      0.56      1069
                 DCIS_1       0.39      0.41      0.40      1927
                 DCIS_2       0.38      0.29      0.33      1739
            Endothelial       0.82      0.86      0.84      1347
              IRF7+_DCs       0.69      0.51      0.58        73
         Invasive_Tumor       0.60      0.70      0.64      5153
             LAMP3+_DCs       0.60      0.07      0.12        45
          Macrophages_1       0.72      0.77      0.74      1646
          Macrophages_2       0.47      0.29      0.36       256
             Mast_Cells       1.00      0.07      0.13   

In [14]:
with torch.no_grad():
    new_data_tensor = torch.tensor(new_data).float().to(device)
    SSL_embeddings = estim.model.encoder(new_data_tensor).detach().cpu().numpy()
new_adata = sc.read_h5ad(data_dir)
new_adata.obsm[f'SSL_GP_FT_{random_seed}'] = SSL_embeddings
new_adata.uns[f'GP_FT_y_test_{random_seed}'] = y_test
new_adata.uns[f'GP_FT_predictions_{random_seed}'] = predictions
new_adata.uns[f'GP_FT_target_names_{random_seed}'] = label_encoder.classes_
new_adata.uns[f'GP_FT_train_loss_{random_seed}'] = train_losses
new_adata.uns[f'GP_FT_val_loss_{random_seed}'] = val_losses
new_adata.write_h5ad(data_dir)

In [15]:

import pandas as pd
import os
import re

# 当前 Notebook 文件名
notebook_name = "uniport_imputed_Xenium_breast_cancer_sample1_replicate1_GP_mask_fine_tune_42.ipynb"

# 初始化需要打印的值
init_train_loss = train_losses[0] if 'train_losses' in globals() else None
init_val_loss = val_losses[0] if 'val_losses' in globals() else None
converged_epoch = len(train_losses) - patience if 'train_losses' in globals() else None
converged_val_loss = best_val_loss if 'best_val_loss' in globals() else None

# 打印所有所需的指标
print("Metrics Summary:")
if 'train_losses' in globals():
    print(f"init_train_loss\tinit_val_loss\tconverged_epoch\tconverged_val_loss\tmacro_f1\tweighted_f1\tmicro_f1")
    print(f"{init_train_loss:.3f}\t{init_val_loss:.3f}\t{converged_epoch}\t{converged_val_loss:.3f}\t{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")
else:
    print(f"macro_f1\tweighted_f1\tmicro_f1")
    print(f"{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")

# 保存结果到 CSV 文件
output_data = {
    'dataset_split_random_seed': [int(random_seed)],
    'dataset': ['uniport_imputed_xenium_breast_cancer_sample1_replicate1'],
    'method': [re.search(r'replicate1_(.*?)_\d+', notebook_name).group(1)],
    'init_train_loss': [init_train_loss if init_train_loss is not None else ''],
    'init_val_loss': [init_val_loss if init_val_loss is not None else ''],
    'converged_epoch': [converged_epoch if converged_epoch is not None else ''],
    'converged_val_loss': [converged_val_loss if converged_val_loss is not None else ''],
    'macro_f1': [macro_f1],
    'weighted_f1': [f1],
    'micro_f1': [accuracy]
}
output_df = pd.DataFrame(output_data)

# 保存到当前目录下名为 results 的文件夹中
if not os.path.exists('results'):
    os.makedirs('results')

csv_filename = f"results/{os.path.splitext(notebook_name)[0]}_results.csv"
output_df.to_csv(csv_filename, index=False)


Metrics Summary:
init_train_loss	init_val_loss	converged_epoch	converged_val_loss	macro_f1	weighted_f1	micro_f1
1.892	1.498	88	1.236	0.460	0.623	0.634
