In [1]:
import scanpy as sc
import torch
import lightning.pytorch as pl
from torch import nn
from torch.optim import AdamW
from self_supervision.models.lightning_modules.cellnet_autoencoder import MLPAutoEncoder
from self_supervision.estimator.cellnet import EstimatorAutoEncoder
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, classification_report
import numpy as np
import pandas as pd

# 1. 加载数据集
data_dir = '../../dataset/uniport_imputed_Xenium_breast_cancer_sample1_replicate1.h5ad'
adata = sc.read_h5ad(data_dir)


# 2. 模型参数
units_encoder = [512, 512, 256, 256, 64]
units_decoder = [256, 256, 512, 512]

# 初始化 EstimatorAutoEncoder 实例
estim = EstimatorAutoEncoder(data_path=None)  # 没有实际数据路径，可以设置为None

# 3. 直接初始化从头训练的模型
estim.model = MLPAutoEncoder(
    gene_dim=19331,  # 根据你的数据调整
    batch_size=128,  # 根据你的需要调整
    units_encoder=units_encoder, 
    units_decoder=units_decoder,
    masking_strategy="random",  # 假设模型使用了随机掩码
    masking_rate=0.5,  # 根据需要调整
)

  warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
  warn(f"Triton dtype mappings did not load successfully due to an error: {exc.msg}")


In [2]:
# 添加分类层 (FC)
n_classes = len(adata.obs['cell_type'].unique())
estim.model.fc = nn.Linear(units_encoder[-1], n_classes)

# 处理基因数据
var_df = pd.read_parquet('../../sc_pretrained/var.parquet')
all_genes = var_df['feature_name'].tolist()

new_data = np.zeros((adata.X.shape[0], len(all_genes)), dtype=np.float32)
adata.var['gene_name'] = adata.var.index
existing_genes = adata.var['gene_name']

# 将基因名称转换为小写
all_genes_lower = [gene.lower() for gene in all_genes]
adata_genes_lower = [gene.lower() for gene in existing_genes]

# 计算基因匹配
all_genes_set = set(all_genes_lower)
adata_genes_set = set(adata_genes_lower)
matching_genes = all_genes_set.intersection(adata_genes_set)

# 填充 new_data
gene_to_index = {gene: idx for idx, gene in enumerate(all_genes_lower)}
dense_adata_X = adata.X
for i, gene in enumerate(adata_genes_lower):
    if gene in gene_to_index:
        new_data[:, gene_to_index[gene]] = dense_adata_X[:, i]
    else:
        print(f'Gene {gene} not found in all_genes list')

# 使用 GPU 进行评估（如果可用）
estim.trainer = pl.Trainer(accelerator="gpu", devices=1 if torch.cuda.is_available() else None)

# 4. 数据集划分（70% 训练，15% 验证，15% 测试）
label_encoder = LabelEncoder()
labels_encoded = label_encoder.fit_transform(adata.obs['cell_type'])

random_seed = 42

X_train_val, X_test, y_train_val, y_test = train_test_split(
    new_data, labels_encoded, test_size=0.15, random_state=random_seed)

X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.1765, random_state=random_seed)  # 0.1765 是为了让验证集占 15%

# 将训练数据转换为张量
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_train_tensor = torch.tensor(X_train).float().to(device)
y_train_tensor = torch.tensor(y_train).long().to(device)
X_val_tensor = torch.tensor(X_val).float().to(device)
y_val_tensor = torch.tensor(y_val).long().to(device)
estim.model.to(device)

# 5. 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(estim.model.parameters(), lr=9e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

# 使用 TensorDataset 将训练数据和标签打包
from torch.utils.data import DataLoader, TensorDataset
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)

Gene al391650.1 not found in all_genes list


Gene yars not found in all_genes list
Gene adprhl2 not found in all_genes list


Gene tctex1d4 not found in all_genes list


Gene tctex1d1 not found in all_genes list
Gene wdr78 not found in all_genes list
Gene hhla3 not found in all_genes list
Gene ac118549.1 not found in all_genes list
Gene wdr63 not found in all_genes list


Gene kiaa1324 not found in all_genes list
Gene sars not found in all_genes list


Gene hist2h2be not found in all_genes list


Gene al162596.1 not found in all_genes list
Gene lor not found in all_genes list


Gene c1orf61 not found in all_genes list
Gene al590560.2 not found in all_genes list


Gene rgs5 not found in all_genes list
Gene dusp27 not found in all_genes list


Gene eprs not found in all_genes list
Gene marc2 not found in all_genes list
Gene marc1 not found in all_genes list
Gene h3f3a not found in all_genes list
Gene hist3h3 not found in all_genes list
Gene hist3h2a not found in all_genes list
Gene hist3h2bb not found in all_genes list


Gene al109810.2 not found in all_genes list
Gene tbce not found in all_genes list
Gene adss not found in all_genes list
Gene al451007.3 not found in all_genes list


Gene gcsaml-as1 not found in all_genes list


Gene mycnos not found in all_genes list
Gene fam49a not found in all_genes list


Gene c2orf91 not found in all_genes list


Gene igkc not found in all_genes list
Gene al845331.2 not found in all_genes list
Gene ac092835.1 not found in all_genes list
Gene kiaa1211l not found in all_genes list


Gene dars not found in all_genes list
Gene march7 not found in all_genes list


Gene pde11a not found in all_genes list
Gene dirc1 not found in all_genes list


Gene march4 not found in all_genes list
Gene ccdc140 not found in all_genes list


Gene c2orf83 not found in all_genes list


Gene arih2os not found in all_genes list
Gene qars not found in all_genes list
Gene ccdc36 not found in all_genes list
Gene cyb561d2 not found in all_genes list


Gene c3orf67 not found in all_genes list


Gene maats1 not found in all_genes list
Gene alg1l not found in all_genes list


Gene kiaa1257 not found in all_genes list
Gene h1fx not found in all_genes list
Gene h1foo not found in all_genes list
Gene acpp not found in all_genes list


Gene slc66a1l not found in all_genes list
Gene terc not found in all_genes list


Gene ccdc39 not found in all_genes list
Gene ac072022.2 not found in all_genes list


Gene tctex1d2 not found in all_genes list


Gene ac093323.1 not found in all_genes list


Gene kiaa1211 not found in all_genes list


Gene h2afz not found in all_genes list


Gene tmem155 not found in all_genes list


Gene march1 not found in all_genes list
Gene fam218a not found in all_genes list


Gene march6 not found in all_genes list
Gene march11 not found in all_genes list
Gene ac106774.4 not found in all_genes list
Gene h3.y not found in all_genes list
Gene tars not found in all_genes list


Gene c5orf67 not found in all_genes list


Gene atp6ap1l not found in all_genes list
Gene c5orf30 not found in all_genes list


Gene ac010255.3 not found in all_genes list
Gene march3 not found in all_genes list


Gene h2afy not found in all_genes list
Gene tmem173 not found in all_genes list
Gene hars not found in all_genes list


Gene lars not found in all_genes list


Gene rars not found in all_genes list


Gene ac113348.1 not found in all_genes list
Gene c5orf60 not found in all_genes list
Gene c6orf201 not found in all_genes list


Gene hist1h2aa not found in all_genes list
Gene hist1h2ba not found in all_genes list
Gene hist1h1a not found in all_genes list
Gene hist1h4b not found in all_genes list
Gene hist1h2bb not found in all_genes list
Gene hist1h1c not found in all_genes list
Gene hist1h4c not found in all_genes list
Gene hist1h2ac not found in all_genes list
Gene hist1h1e not found in all_genes list
Gene hist1h4e not found in all_genes list
Gene hist1h2bg not found in all_genes list
Gene hist1h2ae not found in all_genes list
Gene hist1h3e not found in all_genes list
Gene hist1h1d not found in all_genes list
Gene hist1h4g not found in all_genes list
Gene hist1h2bh not found in all_genes list
Gene hist1h3g not found in all_genes list
Gene hist1h2ag not found in all_genes list
Gene hist1h4i not found in all_genes list
Gene hist1h2ai not found in all_genes list
Gene hist1h3h not found in all_genes list
Gene hist1h4j not found in all_genes list
Gene hist1h2bn not found in all_genes list
Gene hist1h2ak not found

Gene vars not found in all_genes list
Gene snhg32 not found in all_genes list


Gene c6orf223 not found in all_genes list
Gene defb133 not found in all_genes list
Gene ick not found in all_genes list


Gene al135905.2 not found in all_genes list


Gene fgfr1op not found in all_genes list
Gene tcte3 not found in all_genes list
Gene ac187653.1 not found in all_genes list


Gene ac013470.2 not found in all_genes list
Gene twistnb not found in all_genes list


Gene ac004593.3 not found in all_genes list
Gene gars not found in all_genes list
Gene trgc2 not found in all_genes list
Gene trgjp2 not found in all_genes list
Gene trgc1 not found in all_genes list
Gene trgjp1 not found in all_genes list


Gene ac115220.1 not found in all_genes list


Gene kiaa1324l not found in all_genes list


Gene castor3 not found in all_genes list


Gene c7orf77 not found in all_genes list
Gene ac011005.1 not found in all_genes list


Gene trbc1 not found in all_genes list
Gene trbc2 not found in all_genes list
Gene sspo not found in all_genes list
Gene ac073111.4 not found in all_genes list


Gene ac021097.2 not found in all_genes list
Gene wdr60 not found in all_genes list
Gene ac134684.8 not found in all_genes list


Gene pinx1 not found in all_genes list


Gene impad1 not found in all_genes list


Gene wdyhv1 not found in all_genes list
Gene fam49b not found in all_genes list
Gene ac138647.1 not found in all_genes list


Gene tsta3 not found in all_genes list
Gene dock8-as1 not found in all_genes list


Gene c9orf92 not found in all_genes list


Gene al162231.1 not found in all_genes list


Gene fam122a not found in all_genes list
Gene al353572.3 not found in all_genes list


Gene iars not found in all_genes list
Gene c9orf129 not found in all_genes list
Gene al160269.1 not found in all_genes list


Gene tmem246 not found in all_genes list
Gene palm2-akap2 not found in all_genes list
Gene znf883 not found in all_genes list


Gene dec1 not found in all_genes list
Gene b3gnt10 not found in all_genes list


Gene wdr34 not found in all_genes list


Gene al354761.1 not found in all_genes list
Gene bx255925.3 not found in all_genes list


Gene mir1915hg not found in all_genes list
Gene armc4 not found in all_genes list


Gene c10orf142 not found in all_genes list
Gene march8 not found in all_genes list
Gene ac067752.1 not found in all_genes list


Gene kif1bp not found in all_genes list
Gene h2afy2 not found in all_genes list
Gene c10orf55 not found in all_genes list
Gene dupd1 not found in all_genes list


Gene march5 not found in all_genes list


Gene atp5md not found in all_genes list


Gene al603764.2 not found in all_genes list


Gene pano1 not found in all_genes list
Gene ac132217.2 not found in all_genes list
Gene cars not found in all_genes list
Gene c11orf40 not found in all_genes list


Gene ac104389.5 not found in all_genes list


Gene st5 not found in all_genes list
Gene mrvi1 not found in all_genes list


Gene c11orf74 not found in all_genes list


Gene or5r1 not found in all_genes list


Gene ap002495.1 not found in all_genes list


Gene card16 not found in all_genes list
Gene card17 not found in all_genes list
Gene c11orf88 not found in all_genes list


Gene ccdc84 not found in all_genes list
Gene h2afx not found in all_genes list


Gene hist4h4 not found in all_genes list
Gene h2afj not found in all_genes list
Gene lrmp not found in all_genes list
Gene casc1 not found in all_genes list


Gene h3f3c not found in all_genes list
Gene h1fnt not found in all_genes list


Gene c12orf81 not found in all_genes list
Gene grasp not found in all_genes list
Gene ac021072.1 not found in all_genes list
Gene c12orf10 not found in all_genes list


Gene mars not found in all_genes list
Gene slc26a10 not found in all_genes list
Gene march9 not found in all_genes list


Gene cllu1os not found in all_genes list
Gene c12orf74 not found in all_genes list


Gene c12orf49 not found in all_genes list
Gene wdr66 not found in all_genes list


Gene spata13 not found in all_genes list


Gene spert not found in all_genes list
Gene al445238.1 not found in all_genes list


Gene trdc not found in all_genes list
Gene trac not found in all_genes list


Gene sfta3 not found in all_genes list


Gene elmsan1 not found in all_genes list


Gene c14orf177 not found in all_genes list
Gene wars not found in all_genes list
Gene atp5mpl not found in all_genes list
Gene adssl1 not found in all_genes list


Gene igha2 not found in all_genes list
Gene ighe not found in all_genes list
Gene igha1 not found in all_genes list
Gene ighg1 not found in all_genes list
Gene ighg3 not found in all_genes list
Gene ighd not found in all_genes list
Gene ighm not found in all_genes list
Gene fam30a not found in all_genes list
Gene ac135068.1 not found in all_genes list
Gene golga8m not found in all_genes list
Gene ac091057.6 not found in all_genes list
Gene c15orf41 not found in all_genes list
Gene linc02694 not found in all_genes list


Gene casc4 not found in all_genes list


Gene ct62 not found in all_genes list


Gene ac015871.1 not found in all_genes list


Gene spata8 not found in all_genes list
Gene fam169b not found in all_genes list
Gene tarsl2 not found in all_genes list
Gene tmem8a not found in all_genes list


Gene al032819.3 not found in all_genes list


Gene ac025283.2 not found in all_genes list


Gene ac099489.1 not found in all_genes list
Gene fopnl not found in all_genes list


Gene kiaa0556 not found in all_genes list


Gene c16orf58 not found in all_genes list
Gene ac007906.2 not found in all_genes list


Gene fam192a not found in all_genes list


Gene lrrc29 not found in all_genes list


Gene aars not found in all_genes list
Gene kars not found in all_genes list
Gene ac025287.4 not found in all_genes list


Gene fam92b not found in all_genes list
Gene cenpbd1 not found in all_genes list


Gene ac087498.1 not found in all_genes list
Gene ac233723.1 not found in all_genes list


Gene trim16l not found in all_genes list
Gene linc02693 not found in all_genes list


Gene slfn12l not found in all_genes list


Gene tmem99 not found in all_genes list
Gene ttc25 not found in all_genes list


Gene g6pc not found in all_genes list
Gene c17orf53 not found in all_genes list


Gene ac011195.2 not found in all_genes list
Gene march10 not found in all_genes list


Gene h3f3b not found in all_genes list


Gene eloa3 not found in all_genes list
Gene nars not found in all_genes list


Gene ac090360.1 not found in all_genes list


Gene ac005551.1 not found in all_genes list


Gene ac119396.1 not found in all_genes list
Gene march2 not found in all_genes list


Gene ccdc151 not found in all_genes list


Gene c19orf57 not found in all_genes list


Gene ac008397.1 not found in all_genes list


Gene kiaa0355 not found in all_genes list


Gene cntd2 not found in all_genes list


Gene cd3eap not found in all_genes list
Gene bhmg1 not found in all_genes list
Gene ppp5d1 not found in all_genes list


Gene ccdc114 not found in all_genes list
Gene ac008687.4 not found in all_genes list
Gene ccdc155 not found in all_genes list


Gene ac010325.1 not found in all_genes list
Gene c19orf48 not found in all_genes list
Gene siglec5 not found in all_genes list


Gene gdf5os not found in all_genes list


Gene tmem189 not found in all_genes list


Gene fp565260.1 not found in all_genes list


Gene ap000552.4 not found in all_genes list
Gene iglc1 not found in all_genes list
Gene iglc7 not found in all_genes list
Gene lrp5l not found in all_genes list


Gene elfn2 not found in all_genes list
Gene h1f0 not found in all_genes list
Gene z82206.1 not found in all_genes list


Gene arse not found in all_genes list


Gene cxorf21 not found in all_genes list
Gene hypm not found in all_genes list
Gene al121578.2 not found in all_genes list


Gene bx276092.9 not found in all_genes list


Gene nxf5 not found in all_genes list
Gene glra4 not found in all_genes list
Gene tmsb15b not found in all_genes list
Gene h2bfwt not found in all_genes list
Gene h2bfm not found in all_genes list
Gene pih1d3 not found in all_genes list


Gene al772284.2 not found in all_genes list
Gene cxorf56 not found in all_genes list
Gene fam122b not found in all_genes list
Gene fam122c not found in all_genes list


Gene cxorf40a not found in all_genes list
Gene ac236972.4 not found in all_genes list


GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores


HPU available: False, using: 0 HPUs


Gene prky not found in all_genes list
Gene ac007244.1 not found in all_genes list


In [3]:
# 6. 定义训练过程
def train_epoch(model, optimizer, loss_fn, train_loader, val_loader):
    model.train()
    total_train_loss = 0
    
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model.encoder(X_batch)
        logits = model.fc(outputs)
        
        loss = loss_fn(logits, y_batch)
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    # 验证集
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for X_val_batch, y_val_batch in val_loader:
            val_outputs = model.encoder(X_val_batch)
            val_logits = model.fc(val_outputs)
            val_loss = loss_fn(val_logits, y_val_batch)
            total_val_loss += val_loss.item()

    return total_train_loss / len(train_loader), total_val_loss / len(val_loader)



In [4]:
import copy  # 用于保存模型的最佳状态

# Early Stopping 参数
patience = 20  # 如果验证损失在 10 个 epoch 中没有改善，停止训练
min_delta = 1e-4  # 最小改善幅度
patience_counter = 0
best_val_loss = float('inf')  # 初始设置为正无穷大
best_model_weights = copy.deepcopy(estim.model.state_dict())  # 保存最佳模型权重
train_losses = []
val_losses = []

# 训练 500 个 epoch
for epoch in range(500):
    train_loss, val_loss = train_epoch(estim.model, optimizer, loss_fn, train_loader, val_loader)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Validation Loss: {val_loss}')
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    # Early Stopping 检查
    if val_loss < best_val_loss - min_delta:
        best_val_loss = val_loss  # 更新最佳验证损失
        patience_counter = 0  # 重置 patience 计数器
        best_model_weights = copy.deepcopy(estim.model.state_dict())  # 保存当前最佳模型
        print(f"Validation loss improved to {best_val_loss}, resetting patience.")
    else:
        patience_counter += 1
        print(f"No improvement in validation loss. Patience counter: {patience_counter}/{patience}")
    
    # 如果 patience_counter 超过设置的 patience，停止训练
    if patience_counter >= patience:
        print(f"Early stopping triggered. Stopping training at epoch {epoch+1}.")
        break

# 恢复到训练中性能最好的模型权重
estim.model.load_state_dict(best_model_weights)
print("Loaded best model weights based on validation loss.")

Epoch 1, Train Loss: 1.200995475774283, Validation Loss: 1.0355298977631788
Validation loss improved to 1.0355298977631788, resetting patience.


Epoch 2, Train Loss: 0.9402124064964252, Validation Loss: 0.8765075674423805
Validation loss improved to 0.8765075674423805, resetting patience.


Epoch 3, Train Loss: 0.8416675545357086, Validation Loss: 0.7947758961946536
Validation loss improved to 0.7947758961946536, resetting patience.


Epoch 4, Train Loss: 0.7850184887320131, Validation Loss: 0.7912634586676573
Validation loss improved to 0.7912634586676573, resetting patience.


Epoch 5, Train Loss: 0.7521631270974547, Validation Loss: 0.7202446509630253
Validation loss improved to 0.7202446509630253, resetting patience.


Epoch 6, Train Loss: 0.7242976691041674, Validation Loss: 0.712485614953897
Validation loss improved to 0.712485614953897, resetting patience.


Epoch 7, Train Loss: 0.6902748158344856, Validation Loss: 0.6289988788274619
Validation loss improved to 0.6289988788274619, resetting patience.


Epoch 8, Train Loss: 0.6755253189867669, Validation Loss: 0.626382526067587
Validation loss improved to 0.626382526067587, resetting patience.


Epoch 9, Train Loss: 0.6624657945318536, Validation Loss: 0.5940469613442054
Validation loss improved to 0.5940469613442054, resetting patience.


Epoch 10, Train Loss: 0.6525622173980042, Validation Loss: 0.6221708161708636
No improvement in validation loss. Patience counter: 1/20


Epoch 11, Train Loss: 0.6246208542323375, Validation Loss: 0.5712071968958928
Validation loss improved to 0.5712071968958928, resetting patience.


Epoch 12, Train Loss: 0.6212340256015023, Validation Loss: 0.5458626247369326
Validation loss improved to 0.5458626247369326, resetting patience.


Epoch 13, Train Loss: 0.6200294301732556, Validation Loss: 0.6019711797053997
No improvement in validation loss. Patience counter: 1/20


Epoch 14, Train Loss: 0.6036850278194134, Validation Loss: 0.6211816757153242
No improvement in validation loss. Patience counter: 2/20


Epoch 15, Train Loss: 0.5953770921780512, Validation Loss: 0.5771061402100783
No improvement in validation loss. Patience counter: 3/20


Epoch 16, Train Loss: 0.5895047693134664, Validation Loss: 0.6011325776576996
No improvement in validation loss. Patience counter: 4/20


Epoch 17, Train Loss: 0.5938911947575244, Validation Loss: 0.602213009657004
No improvement in validation loss. Patience counter: 5/20


Epoch 18, Train Loss: 0.5862748742758572, Validation Loss: 0.5274118281327761
Validation loss improved to 0.5274118281327761, resetting patience.


Epoch 19, Train Loss: 0.581870428611944, Validation Loss: 0.5602459060840118
No improvement in validation loss. Patience counter: 1/20


Epoch 20, Train Loss: 0.5800493725708553, Validation Loss: 0.5456084635013189
No improvement in validation loss. Patience counter: 2/20


Epoch 21, Train Loss: 0.5738716944233402, Validation Loss: 0.6299399278102777
No improvement in validation loss. Patience counter: 3/20


Epoch 22, Train Loss: 0.5640309213936984, Validation Loss: 0.7110881725947062
No improvement in validation loss. Patience counter: 4/20


Epoch 23, Train Loss: 0.5700772547459865, Validation Loss: 0.5611729867947407
No improvement in validation loss. Patience counter: 5/20


Epoch 24, Train Loss: 0.5683083554218104, Validation Loss: 0.6665354160162119
No improvement in validation loss. Patience counter: 6/20


Epoch 25, Train Loss: 0.5496911740892536, Validation Loss: 0.5145564238230388
Validation loss improved to 0.5145564238230388, resetting patience.


Epoch 26, Train Loss: 0.557967670900481, Validation Loss: 0.5860654254754384
No improvement in validation loss. Patience counter: 1/20


Epoch 27, Train Loss: 0.5659666468808939, Validation Loss: 0.5694941543615781
No improvement in validation loss. Patience counter: 2/20


Epoch 28, Train Loss: 0.5563792853237508, Validation Loss: 0.5346918471348592
No improvement in validation loss. Patience counter: 3/20


Epoch 29, Train Loss: 0.5553786510622108, Validation Loss: 0.5571298733735696
No improvement in validation loss. Patience counter: 4/20


Epoch 30, Train Loss: 0.545308039181835, Validation Loss: 0.575553851097058
No improvement in validation loss. Patience counter: 5/20


Epoch 31, Train Loss: 0.5412817710703546, Validation Loss: 0.5445984177100353
No improvement in validation loss. Patience counter: 6/20


Epoch 32, Train Loss: 0.5450410026770371, Validation Loss: 0.5462327394730006
No improvement in validation loss. Patience counter: 7/20


Epoch 33, Train Loss: 0.543111904161972, Validation Loss: 0.5434467977438218
No improvement in validation loss. Patience counter: 8/20


Epoch 34, Train Loss: 0.536257026385475, Validation Loss: 0.5355052717221089
No improvement in validation loss. Patience counter: 9/20


Epoch 35, Train Loss: 0.5343443465429348, Validation Loss: 0.5249460717042287
No improvement in validation loss. Patience counter: 10/20


Epoch 36, Train Loss: 0.5446323803165457, Validation Loss: 0.5162046736631638
No improvement in validation loss. Patience counter: 11/20


Epoch 37, Train Loss: 0.5344909676483699, Validation Loss: 0.5567813659325624
No improvement in validation loss. Patience counter: 12/20


Epoch 38, Train Loss: 0.5362645502601351, Validation Loss: 0.5174213935167361
No improvement in validation loss. Patience counter: 13/20


Epoch 39, Train Loss: 0.5347065085238153, Validation Loss: 0.5573529238884266
No improvement in validation loss. Patience counter: 14/20


Epoch 40, Train Loss: 0.5233768841722509, Validation Loss: 0.5771578130049583
No improvement in validation loss. Patience counter: 15/20


Epoch 41, Train Loss: 0.5299537624304111, Validation Loss: 0.5297957249176808
No improvement in validation loss. Patience counter: 16/20


Epoch 42, Train Loss: 0.523912415596155, Validation Loss: 0.5686193716831697
No improvement in validation loss. Patience counter: 17/20


Epoch 43, Train Loss: 0.5241746771139103, Validation Loss: 0.509927411415638
Validation loss improved to 0.509927411415638, resetting patience.


Epoch 44, Train Loss: 0.5200768241515527, Validation Loss: 0.5846061980113005
No improvement in validation loss. Patience counter: 1/20


Epoch 45, Train Loss: 0.5286348323573123, Validation Loss: 0.5814156653025211
No improvement in validation loss. Patience counter: 2/20


Epoch 46, Train Loss: 0.5209753133765944, Validation Loss: 0.5270983350582612
No improvement in validation loss. Patience counter: 3/20


Epoch 47, Train Loss: 0.5264436589820044, Validation Loss: 0.5227468440165887
No improvement in validation loss. Patience counter: 4/20


Epoch 48, Train Loss: 0.5223156854674056, Validation Loss: 0.5318618516127268
No improvement in validation loss. Patience counter: 5/20


Epoch 49, Train Loss: 0.5161451673278442, Validation Loss: 0.5086104408288613
Validation loss improved to 0.5086104408288613, resetting patience.


Epoch 50, Train Loss: 0.5167131979714383, Validation Loss: 0.5295504178756323
No improvement in validation loss. Patience counter: 1/20


Epoch 51, Train Loss: 0.5125376052908845, Validation Loss: 0.5041561308579567
Validation loss improved to 0.5041561308579567, resetting patience.


Epoch 52, Train Loss: 0.5236652278638148, Validation Loss: 0.5476966588925093
No improvement in validation loss. Patience counter: 1/20


Epoch 53, Train Loss: 0.5198762421588321, Validation Loss: 0.528236433940056
No improvement in validation loss. Patience counter: 2/20


Epoch 54, Train Loss: 0.5086517284532169, Validation Loss: 0.5630388470796438
No improvement in validation loss. Patience counter: 3/20


Epoch 55, Train Loss: 0.5069361075267687, Validation Loss: 0.5236268170368977
No improvement in validation loss. Patience counter: 4/20


Epoch 56, Train Loss: 0.5114556656761484, Validation Loss: 0.5586515863736471
No improvement in validation loss. Patience counter: 5/20


Epoch 57, Train Loss: 0.5182456556584809, Validation Loss: 0.5345027691278702
No improvement in validation loss. Patience counter: 6/20


Epoch 58, Train Loss: 0.5160801383492711, Validation Loss: 0.5281859316887
No improvement in validation loss. Patience counter: 7/20


Epoch 59, Train Loss: 0.5136912525355161, Validation Loss: 0.6088448501550234
No improvement in validation loss. Patience counter: 8/20


Epoch 60, Train Loss: 0.5095795485508311, Validation Loss: 0.5145069935382941
No improvement in validation loss. Patience counter: 9/20


Epoch 61, Train Loss: 0.5105521682199541, Validation Loss: 0.5428140469086475
No improvement in validation loss. Patience counter: 10/20


Epoch 62, Train Loss: 0.5048705567697902, Validation Loss: 0.504240654523556
No improvement in validation loss. Patience counter: 11/20


Epoch 63, Train Loss: 0.5144772296095942, Validation Loss: 0.5253415280427688
No improvement in validation loss. Patience counter: 12/20


Epoch 64, Train Loss: 0.5106177722359752, Validation Loss: 0.5391617564054636
No improvement in validation loss. Patience counter: 13/20


Epoch 65, Train Loss: 0.5089944371482829, Validation Loss: 0.514797378044862
No improvement in validation loss. Patience counter: 14/20


Epoch 66, Train Loss: 0.5144783279725483, Validation Loss: 0.5634847361307878
No improvement in validation loss. Patience counter: 15/20


Epoch 67, Train Loss: 0.513566154732809, Validation Loss: 0.6357977702067449
No improvement in validation loss. Patience counter: 16/20


Epoch 68, Train Loss: 0.5052673748561314, Validation Loss: 0.5558232782742916
No improvement in validation loss. Patience counter: 17/20


Epoch 69, Train Loss: 0.5113974365559253, Validation Loss: 0.5519007805066231
No improvement in validation loss. Patience counter: 18/20


Epoch 70, Train Loss: 0.5035942475874345, Validation Loss: 0.5006249757913442
Validation loss improved to 0.5006249757913442, resetting patience.


Epoch 71, Train Loss: 0.50306822346462, Validation Loss: 0.5509606309426136
No improvement in validation loss. Patience counter: 1/20


Epoch 72, Train Loss: 0.5008709654703245, Validation Loss: 0.5606446813314389
No improvement in validation loss. Patience counter: 2/20


Epoch 73, Train Loss: 0.5088483688595531, Validation Loss: 0.6331399269593068
No improvement in validation loss. Patience counter: 3/20


Epoch 74, Train Loss: 0.5031845673754975, Validation Loss: 0.5482458203266829
No improvement in validation loss. Patience counter: 4/20


Epoch 75, Train Loss: 0.5078253944169034, Validation Loss: 0.6065582868380425
No improvement in validation loss. Patience counter: 5/20


Epoch 76, Train Loss: 0.4979758246586873, Validation Loss: 0.518274818169765
No improvement in validation loss. Patience counter: 6/20


Epoch 77, Train Loss: 0.5045623761939478, Validation Loss: 0.5122874131569496
No improvement in validation loss. Patience counter: 7/20


Epoch 78, Train Loss: 0.49893169982747715, Validation Loss: 0.5412351556313344
No improvement in validation loss. Patience counter: 8/20


Epoch 79, Train Loss: 0.49899564581912953, Validation Loss: 0.5706658059205765
No improvement in validation loss. Patience counter: 9/20


Epoch 80, Train Loss: 0.5007627178679456, Validation Loss: 0.5565910458564758
No improvement in validation loss. Patience counter: 10/20


Epoch 81, Train Loss: 0.5085005972084108, Validation Loss: 0.5540711485422575
No improvement in validation loss. Patience counter: 11/20


Epoch 82, Train Loss: 0.5044734107596534, Validation Loss: 0.5672119806974362
No improvement in validation loss. Patience counter: 12/20


Epoch 83, Train Loss: 0.4924975306450666, Validation Loss: 0.5752753431980426
No improvement in validation loss. Patience counter: 13/20


Epoch 84, Train Loss: 0.4890658096625255, Validation Loss: 0.5602985129906581
No improvement in validation loss. Patience counter: 14/20


Epoch 85, Train Loss: 0.5028014655296619, Validation Loss: 0.5707090455752153
No improvement in validation loss. Patience counter: 15/20


Epoch 86, Train Loss: 0.49927075528181514, Validation Loss: 0.5439852488346589
No improvement in validation loss. Patience counter: 16/20


Epoch 87, Train Loss: 0.4937706049982008, Validation Loss: 0.548560419296607
No improvement in validation loss. Patience counter: 17/20


Epoch 88, Train Loss: 0.4941006390439285, Validation Loss: 0.5657554806807102
No improvement in validation loss. Patience counter: 18/20


Epoch 89, Train Loss: 0.4988612699967164, Validation Loss: 0.5520495208410117
No improvement in validation loss. Patience counter: 19/20


Epoch 90, Train Loss: 0.4948109354619141, Validation Loss: 0.5587547776026603
No improvement in validation loss. Patience counter: 20/20
Early stopping triggered. Stopping training at epoch 90.
Loaded best model weights based on validation loss.


In [5]:
# 8. 使用 KNN 替代测试阶段的 FC 分类层
estim.model.eval()
with torch.no_grad():
    train_embeddings = estim.model.encoder(torch.tensor(X_train).float().to(device)).cpu().numpy()
    test_embeddings = estim.model.encoder(torch.tensor(X_test).float().to(device)).cpu().numpy()

from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(train_embeddings, y_train)

predictions = knn.predict(test_embeddings)

accuracy = accuracy_score(y_test, predictions)
print(f"KNN Accuracy on Test Data: {accuracy}")
f1 = f1_score(y_test, predictions, average='weighted')
print(f"Weighted F1 Score: {f1}")

macro_f1 = f1_score(y_test, predictions, average='macro')
print(f'Macro F1 Score: {macro_f1}')

class_probabilities = np.bincount(y_test) / len(y_test)
random_accuracy = np.sum(class_probabilities ** 2)
print(f"Random Guess Accuracy: {random_accuracy}")

report = classification_report(y_test, predictions, target_names=label_encoder.classes_)
print(report)

KNN Accuracy on Test Data: 0.8216718762523042
Weighted F1 Score: 0.817840238547572
Macro F1 Score: 0.7070231329069646
Random Guess Accuracy: 0.13330938159201633
                         precision    recall  f1-score   support

                B_Cells       0.82      0.86      0.84       770
           CD4+_T_Cells       0.68      0.79      0.73      1264
           CD8+_T_Cells       0.69      0.68      0.69      1069
                 DCIS_1       0.79      0.83      0.81      1927
                 DCIS_2       0.72      0.74      0.73      1739
            Endothelial       0.88      0.93      0.90      1347
              IRF7+_DCs       0.82      0.89      0.86        73
         Invasive_Tumor       0.90      0.90      0.90      5153
             LAMP3+_DCs       0.67      0.73      0.70        45
          Macrophages_1       0.81      0.82      0.81      1646
          Macrophages_2       0.64      0.61      0.62       256
             Mast_Cells       0.68      0.46      0.55    

In [6]:
with torch.no_grad():
    new_data_tensor = torch.tensor(new_data).float().to(device)
    SSL_embeddings = estim.model.encoder(new_data_tensor).detach().cpu().numpy()
new_adata = sc.read_h5ad(data_dir)
new_adata.obsm[f'supervised_{random_seed}'] = SSL_embeddings
new_adata.uns[f'supervised_y_test_{random_seed}'] = y_test
new_adata.uns[f'supervised_predictions_{random_seed}'] = predictions
new_adata.uns[f'supervised_target_names_{random_seed}'] = label_encoder.classes_
new_adata.uns[f'supervised_train_loss_{random_seed}'] = train_losses
new_adata.uns[f'supervised_val_loss_{random_seed}'] = val_losses
new_adata.write_h5ad(data_dir)

In [7]:

import pandas as pd
import os
import re

# 当前 Notebook 文件名
notebook_name = "uniport_imputed_Xenium_breast_cancer_sample1_replicate1_supervised_42.ipynb"

# 初始化需要打印的值
init_train_loss = train_losses[0] if 'train_losses' in globals() else None
init_val_loss = val_losses[0] if 'val_losses' in globals() else None
converged_epoch = len(train_losses) - patience if 'train_losses' in globals() else None
converged_val_loss = best_val_loss if 'best_val_loss' in globals() else None

# 打印所有所需的指标
print("Metrics Summary:")
if 'train_losses' in globals():
    print(f"init_train_loss\tinit_val_loss\tconverged_epoch\tconverged_val_loss\tmacro_f1\tweighted_f1\tmicro_f1")
    print(f"{init_train_loss:.3f}\t{init_val_loss:.3f}\t{converged_epoch}\t{converged_val_loss:.3f}\t{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")
else:
    print(f"macro_f1\tweighted_f1\tmicro_f1")
    print(f"{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")

# 保存结果到 CSV 文件
output_data = {
    'dataset_split_random_seed': [int(random_seed)],
    'dataset': ['uniport_imputed_xenium_breast_cancer_sample1_replicate1'],
    'method': [re.search(r'replicate1_(.*?)_\d+', notebook_name).group(1)],
    'init_train_loss': [init_train_loss if init_train_loss is not None else ''],
    'init_val_loss': [init_val_loss if init_val_loss is not None else ''],
    'converged_epoch': [converged_epoch if converged_epoch is not None else ''],
    'converged_val_loss': [converged_val_loss if converged_val_loss is not None else ''],
    'macro_f1': [macro_f1],
    'weighted_f1': [f1],
    'micro_f1': [accuracy]
}
output_df = pd.DataFrame(output_data)

# 保存到当前目录下名为 results 的文件夹中
if not os.path.exists('results'):
    os.makedirs('results')

csv_filename = f"results/{os.path.splitext(notebook_name)[0]}_results.csv"
output_df.to_csv(csv_filename, index=False)


Metrics Summary:
init_train_loss	init_val_loss	converged_epoch	converged_val_loss	macro_f1	weighted_f1	micro_f1
1.201	1.036	70	0.501	0.707	0.818	0.822
