In [1]:
import scanpy as sc
import torch
import lightning.pytorch as pl
from torch import nn
from torch.optim import AdamW
from self_supervision.models.lightning_modules.cellnet_autoencoder import MLPAutoEncoder
from self_supervision.estimator.cellnet import EstimatorAutoEncoder
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, classification_report
import numpy as np
import pandas as pd

data_dir = '../../dataset/Human_Great_Apes_filtered.h5ad'
adata = sc.read_h5ad(data_dir)
print(adata.X)
sc.pp.normalize_total(adata,target_sum=1e4)
sc.pp.log1p(adata)


# 2. 设置 checkpoint 路径
ckpt_path = "../../sc_pretrained/Pretrained Models/RandomMask.ckpt"

# 3. 模型参数
units_encoder = [512, 512, 256, 256, 64]
units_decoder = [256, 256, 512, 512]

# 初始化 EstimatorAutoEncoder 实例
estim = EstimatorAutoEncoder(data_path=None)  # 没有实际数据路径，可以设置为None

# 4. 加载预训练模型
estim.model = MLPAutoEncoder.load_from_checkpoint(
    ckpt_path,
    gene_dim=19331,  # 根据你的数据调整
    batch_size=8192,  # 根据你的需要调整
    units_encoder=units_encoder, 
    units_decoder=units_decoder,
    masking_strategy="random",  # 假设模型使用了随机掩码
    masking_rate=0.5,  # 根据需要调整
)

  warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
  warn(f"Triton dtype mappings did not load successfully due to an error: {exc.msg}")


<Compressed Sparse Row sparse matrix of dtype 'float32'
	with 753825891 stored elements and shape (156285, 17413)>
  Coords	Values
  (0, 12862)	0.5665589570999146
  (0, 3064)	0.5665589570999146
  (0, 13278)	3.3999741077423096
  (0, 813)	2.833569049835205
  (0, 13496)	4.101833820343018
  (0, 5722)	2.5798444747924805
  (0, 15309)	7.879579067230225
  (0, 2014)	0.5665589570999146
  (0, 7585)	5.2227888107299805
  (0, 6840)	5.497739791870117
  (0, 12721)	3.967062473297119
  (0, 15364)	1.8461450338363647
  (0, 8240)	1.5708974599838257
  (0, 11878)	2.2390189170837402
  (0, 8802)	5.875441551208496
  (0, 4476)	4.394822597503662
  (0, 10574)	0.5665589570999146
  (0, 16860)	6.168292999267578
  (0, 2182)	5.933631896972656
  (0, 8212)	2.520345449447632
  (0, 7982)	5.331713676452637
  (0, 2798)	5.730631351470947
  (0, 6747)	0.9259976744651794
  (0, 2034)	6.668968677520752
  (0, 9759)	3.3205926418304443
  :	:
  (156284, 2496)	3.841771125793457
  (156284, 12674)	3.841771125793457
  (156284, 3173)	4.524

In [2]:
# 添加分类层 (FC)
n_classes = len(adata.obs['cell_type'].unique())
estim.model.fc = nn.Linear(units_encoder[-1], n_classes)
n_classes

18

In [3]:
var_df = pd.read_parquet('../../sc_pretrained/var.parquet')
all_genes = var_df['feature_name'].tolist()

new_data = np.zeros((adata.X.shape[0], len(all_genes)), dtype=np.float32)

adata.var['gene_name']=adata.var.index
existing_genes = adata.var['gene_name']

In [4]:
# 将所有基因名称转换为小写
all_genes_lower = [gene.lower() for gene in all_genes]
adata_genes_lower = [gene.lower() for gene in existing_genes]

# 将两个列表转换为集合
all_genes_set = set(all_genes_lower)
adata_genes_set = set(adata_genes_lower)

# 计算交集
matching_genes = all_genes_set.intersection(adata_genes_set)
matching_count = len(matching_genes)
# 计算不匹配的基因
non_matching_genes = adata_genes_set - matching_genes
non_matching_count = len(non_matching_genes)


# 输出结果
print(f"匹配的基因数量: {matching_count}")
print(f"匹配的基因列表: {matching_genes}")
non_matching_genes

匹配的基因数量: 17413
匹配的基因列表: {'tada1', 'mcm7', 'mageb3', 'myl12b', 'jund', 'tekt3', 'zbtb33', 'znf256', 'mt-co1', 'setdb1', 'tbc1d7', 'xylt2', 'atp13a1', 'nbdy', 'hey1', 'dnajc2', 'nsun4', 'gpr17', 'elavl1', 'xcr1', 'c1orf198', 'sesn3', 'znf563', 'trim72', 'eapp', 'nodal', 'c16orf78', 'lrrtm2', 'trap1', 'pgap1', 'ddhd1', 'lhfpl2', 'ahcyl1', 'or1g1', 'ldlrad2', 'egln1', 'ebf1', 'sh2d2a', 'znf514', 'cpsf4l', 'mybbp1a', 'e2f5', 'timp2', 'fam47e-stbd1', 'rgmb', 'gsta4', 'gna14', 'apobec2', 'kazald1', 'pgap3', 'tsc22d3', 'kcnk2', 'mt1b', 'kdr', 'plaat4', 'atp6v0c', 'grpel1', 'tigd3', 'c5orf24', 'krt75', 'ankrd42', 'atg101', 'ccdc91', 'or4d9', 'cops8', 'fbxw4', 'ugt1a6', 'acsf2', 'rrm1', 'il12rb2', 'znf268', 'vdac3', 'zscan23', 'tmem51', 'myod1', 'arl9', 'mthfs', 'trip4', 'arpc1b', 'znf536', 'pnpla7', 'golga8h', 'rbbp7', 'b4galnt3', 'stac', 'nin', 'dpp8', 'fundc1', 'mpi', 'atg3', 'irgm', 'galnt17', 'plat', 'plekhn1', 'asns', 'ndufs7', 'uncx', 'tut7', 'pde6g', 'fam205c', 'bms1', 'gage12c', 'bnc2',

set()

In [5]:
gene_to_index = {gene: idx for idx, gene in enumerate(all_genes_lower)}
dense_adata_X = adata.X.toarray()
for i, gene in enumerate(adata_genes_lower):
    if gene in gene_to_index:
        new_data[:, gene_to_index[gene]] = dense_adata_X[:, i]
    else:
        print(f'Gene {gene} not found in all_genes list')

In [6]:
# 使用 GPU 进行评估（如果可用）
estim.trainer = pl.Trainer(accelerator="gpu", devices=1 if torch.cuda.is_available() else None)

# 5. 数据集划分（70% 训练，15% 验证，15% 测试）
label_encoder = LabelEncoder()
labels_encoded = label_encoder.fit_transform(adata.obs['cell_type'])  # 预先编码标签


random_seed = 42

X_train_val, X_test, y_train_val, y_test = train_test_split(
    new_data, labels_encoded, test_size=0.20, random_state=random_seed)

X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.25, random_state=random_seed)  # 0.1765 是为了让验证集占 15%

GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores


HPU available: False, using: 0 HPUs


In [7]:
# 将训练数据转换为张量
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_train_tensor = torch.tensor(X_train).float().to(device)
y_train_tensor = torch.tensor(y_train).long().to(device)
X_val_tensor = torch.tensor(X_val).float().to(device)
y_val_tensor = torch.tensor(y_val).long().to(device)
estim.model.to(device)

MLPAutoEncoder(
  (train_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=train_
  )
  (val_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=val_
  )
  (test_metrics): MetricCollection(
    (explained_var_uniform): ExplainedVariance()
    (explained_var_weighted): ExplainedVariance()
    (mse): MeanSquaredError(),
    prefix=test_
  )
  (encoder): MLP(
    (0): Linear(in_features=19331, out_features=512, bias=True)
    (1): SELU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): SELU()
    (5): Dropout(p=0.1, inplace=False)
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): SELU()
    (8): Dropout(p=0.1, inplace=False)
    (9): Linear(in_features=256, out_features=256, b

In [8]:
# 6. 微调模型，仅微调 encoder 的最后两层，其他层参数冻结
for param in estim.model.encoder.parameters():
    param.requires_grad = False  # 冻结所有层

# 解冻所有层
for param in list(estim.model.encoder.parameters())[-5:]:
    param.requires_grad = True

In [9]:
# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(filter(lambda p: p.requires_grad, estim.model.parameters()), lr=9e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

In [10]:
from torch.utils.data import DataLoader, TensorDataset

# 设置 batch size
batch_size = 128  # 根据实际需求调整 batch size

# 使用 TensorDataset 将训练数据和标签打包
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)

# 使用 DataLoader 来创建批次
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# 7. 微调模型
def train_epoch(model, optimizer, loss_fn, train_loader, val_loader):
    model.train()
    total_train_loss = 0
    
    # 训练集批次训练
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model.encoder(X_batch)
        logits = model.fc(outputs)
        
        # 计算损失
        loss = loss_fn(logits, y_batch)
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    # 验证集
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for X_val_batch, y_val_batch in val_loader:
            val_outputs = model.encoder(X_val_batch)
            val_logits = model.fc(val_outputs)
            val_loss = loss_fn(val_logits, y_val_batch)
            total_val_loss += val_loss.item()

    # 返回平均损失
    return total_train_loss / len(train_loader), total_val_loss / len(val_loader)

In [11]:
import copy  # 用于保存模型的最佳状态

# Early Stopping 参数
patience = 20  # 如果验证损失在 10 个 epoch 中没有改善，停止训练
min_delta = 1e-4  # 最小改善幅度
patience_counter = 0
best_val_loss = float('inf')  # 初始设置为正无穷大
best_model_weights = copy.deepcopy(estim.model.state_dict())  # 保存最佳模型权重
train_losses = []
val_losses = []

# 训练 500 个 epoch
for epoch in range(500):
    train_loss, val_loss = train_epoch(estim.model, optimizer, loss_fn, train_loader, val_loader)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Validation Loss: {val_loss}')
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    # Early Stopping 检查
    if val_loss < best_val_loss - min_delta:
        best_val_loss = val_loss  # 更新最佳验证损失
        patience_counter = 0  # 重置 patience 计数器
        best_model_weights = copy.deepcopy(estim.model.state_dict())  # 保存当前最佳模型
        print(f"Validation loss improved to {best_val_loss}, resetting patience.")
    else:
        patience_counter += 1
        print(f"No improvement in validation loss. Patience counter: {patience_counter}/{patience}")
    
    # 如果 patience_counter 超过设置的 patience，停止训练
    if patience_counter >= patience:
        print(f"Early stopping triggered. Stopping training at epoch {epoch+1}.")
        break

# 恢复到训练中性能最好的模型权重
estim.model.load_state_dict(best_model_weights)
print("Loaded best model weights based on validation loss.")


Epoch 1, Train Loss: 0.3318599066695473, Validation Loss: 0.015770088669182305
Validation loss improved to 0.015770088669182305, resetting patience.


Epoch 2, Train Loss: 0.04112781888099095, Validation Loss: 0.015801679304811432
No improvement in validation loss. Patience counter: 1/20


Epoch 3, Train Loss: 0.03726000937646796, Validation Loss: 0.02348339681276619
No improvement in validation loss. Patience counter: 2/20


Epoch 4, Train Loss: 0.0320056788791087, Validation Loss: 0.01938892205109777
No improvement in validation loss. Patience counter: 3/20


Epoch 5, Train Loss: 0.02675753051662369, Validation Loss: 0.016527209525847124
No improvement in validation loss. Patience counter: 4/20


Epoch 6, Train Loss: 0.023874334590393596, Validation Loss: 0.015117963222158378
Validation loss improved to 0.015117963222158378, resetting patience.


Epoch 7, Train Loss: 0.021458201179315857, Validation Loss: 0.012870218150441806
Validation loss improved to 0.012870218150441806, resetting patience.


Epoch 8, Train Loss: 0.0203947855203038, Validation Loss: 0.013778970639730093
No improvement in validation loss. Patience counter: 1/20


Epoch 9, Train Loss: 0.021742410089085763, Validation Loss: 0.013502062404616166
No improvement in validation loss. Patience counter: 2/20


Epoch 10, Train Loss: 0.021545734487162755, Validation Loss: 0.015074957891304119
No improvement in validation loss. Patience counter: 3/20


Epoch 11, Train Loss: 0.021269000446908132, Validation Loss: 0.014508267664161988
No improvement in validation loss. Patience counter: 4/20


Epoch 12, Train Loss: 0.020535354064432854, Validation Loss: 0.014183611303689054
No improvement in validation loss. Patience counter: 5/20


Epoch 13, Train Loss: 0.020280855469276052, Validation Loss: 0.014745032087969831
No improvement in validation loss. Patience counter: 6/20


Epoch 14, Train Loss: 0.01931668992741062, Validation Loss: 0.012452845459780775
Validation loss improved to 0.012452845459780775, resetting patience.


Epoch 15, Train Loss: 0.020568833412456192, Validation Loss: 0.014747460446990455
No improvement in validation loss. Patience counter: 1/20


Epoch 16, Train Loss: 0.01928865736636929, Validation Loss: 0.01547870813184109
No improvement in validation loss. Patience counter: 2/20


Epoch 17, Train Loss: 0.018135238486758798, Validation Loss: 0.01297550047329052
No improvement in validation loss. Patience counter: 3/20


Epoch 18, Train Loss: 0.01908708754915207, Validation Loss: 0.013728341765376432
No improvement in validation loss. Patience counter: 4/20


Epoch 19, Train Loss: 0.019393708624061572, Validation Loss: 0.012826283502615799
No improvement in validation loss. Patience counter: 5/20


Epoch 20, Train Loss: 0.018784685358130354, Validation Loss: 0.014252704373516003
No improvement in validation loss. Patience counter: 6/20


Epoch 21, Train Loss: 0.018560808705369906, Validation Loss: 0.012767227424061571
No improvement in validation loss. Patience counter: 7/20


Epoch 22, Train Loss: 0.01863889792115897, Validation Loss: 0.013181658356494822
No improvement in validation loss. Patience counter: 8/20


Epoch 23, Train Loss: 0.01871981339342622, Validation Loss: 0.013869168042511235
No improvement in validation loss. Patience counter: 9/20


Epoch 24, Train Loss: 0.0176405188407019, Validation Loss: 0.012628219401582954
No improvement in validation loss. Patience counter: 10/20


Epoch 25, Train Loss: 0.016706188783416023, Validation Loss: 0.01341250678617923
No improvement in validation loss. Patience counter: 11/20


Epoch 26, Train Loss: 0.018064691511971348, Validation Loss: 0.014438184068614869
No improvement in validation loss. Patience counter: 12/20


Epoch 27, Train Loss: 0.017341191787296515, Validation Loss: 0.013007820534032214
No improvement in validation loss. Patience counter: 13/20


Epoch 28, Train Loss: 0.018533301075457424, Validation Loss: 0.013887983476161025
No improvement in validation loss. Patience counter: 14/20


Epoch 29, Train Loss: 0.016408248587734743, Validation Loss: 0.01239413189668623
No improvement in validation loss. Patience counter: 15/20


Epoch 30, Train Loss: 0.017481073447605033, Validation Loss: 0.014225954974140873
No improvement in validation loss. Patience counter: 16/20


Epoch 31, Train Loss: 0.017364469432473082, Validation Loss: 0.012025963616728478
Validation loss improved to 0.012025963616728478, resetting patience.


Epoch 32, Train Loss: 0.018080126412426383, Validation Loss: 0.01386375992400533
No improvement in validation loss. Patience counter: 1/20


Epoch 33, Train Loss: 0.017723131099498086, Validation Loss: 0.01230654192345819
No improvement in validation loss. Patience counter: 2/20


Epoch 34, Train Loss: 0.017701392220550803, Validation Loss: 0.0130497068707381
No improvement in validation loss. Patience counter: 3/20


Epoch 35, Train Loss: 0.016213174177866396, Validation Loss: 0.011697120537591756
Validation loss improved to 0.011697120537591756, resetting patience.


Epoch 36, Train Loss: 0.01704241730723661, Validation Loss: 0.013943387491202216
No improvement in validation loss. Patience counter: 1/20


Epoch 37, Train Loss: 0.017332344371185807, Validation Loss: 0.013082848631410696
No improvement in validation loss. Patience counter: 2/20


Epoch 38, Train Loss: 0.016476738797404644, Validation Loss: 0.014378325583582372
No improvement in validation loss. Patience counter: 3/20


Epoch 39, Train Loss: 0.01681560137307387, Validation Loss: 0.015032321868858738
No improvement in validation loss. Patience counter: 4/20


Epoch 40, Train Loss: 0.016142718296167397, Validation Loss: 0.011850832314191895
No improvement in validation loss. Patience counter: 5/20


Epoch 41, Train Loss: 0.01544073383235961, Validation Loss: 0.014010150199650599
No improvement in validation loss. Patience counter: 6/20


Epoch 42, Train Loss: 0.016259572797054152, Validation Loss: 0.014318043637245047
No improvement in validation loss. Patience counter: 7/20


Epoch 43, Train Loss: 0.016885919054539105, Validation Loss: 0.013633404033784505
No improvement in validation loss. Patience counter: 8/20


Epoch 44, Train Loss: 0.015705448389324334, Validation Loss: 0.016330410328106264
No improvement in validation loss. Patience counter: 9/20


Epoch 45, Train Loss: 0.016462317124150243, Validation Loss: 0.012157970830505328
No improvement in validation loss. Patience counter: 10/20


Epoch 46, Train Loss: 0.015521606722670575, Validation Loss: 0.012615681478839215
No improvement in validation loss. Patience counter: 11/20


Epoch 47, Train Loss: 0.015905810638393927, Validation Loss: 0.013680423369857765
No improvement in validation loss. Patience counter: 12/20


Epoch 48, Train Loss: 0.01583389079665054, Validation Loss: 0.011888987541243871
No improvement in validation loss. Patience counter: 13/20


Epoch 49, Train Loss: 0.016244933723792102, Validation Loss: 0.0116249574630014
No improvement in validation loss. Patience counter: 14/20


Epoch 50, Train Loss: 0.01607424018795854, Validation Loss: 0.01312536041967512
No improvement in validation loss. Patience counter: 15/20


Epoch 51, Train Loss: 0.015390348855578287, Validation Loss: 0.011402691615616536
Validation loss improved to 0.011402691615616536, resetting patience.


Epoch 52, Train Loss: 0.016665383042830537, Validation Loss: 0.01301050482416402
No improvement in validation loss. Patience counter: 1/20


Epoch 53, Train Loss: 0.015997713039906145, Validation Loss: 0.01460901148224069
No improvement in validation loss. Patience counter: 2/20


Epoch 54, Train Loss: 0.015686591734121746, Validation Loss: 0.012777660592529378
No improvement in validation loss. Patience counter: 3/20


Epoch 55, Train Loss: 0.01586507439716248, Validation Loss: 0.011861798545467128
No improvement in validation loss. Patience counter: 4/20


Epoch 56, Train Loss: 0.016274918649108756, Validation Loss: 0.012296794596061643
No improvement in validation loss. Patience counter: 5/20


Epoch 57, Train Loss: 0.015690517999323295, Validation Loss: 0.012692008639232324
No improvement in validation loss. Patience counter: 6/20


Epoch 58, Train Loss: 0.016388409537191985, Validation Loss: 0.015977916825951458
No improvement in validation loss. Patience counter: 7/20


Epoch 59, Train Loss: 0.015663692229038684, Validation Loss: 0.015058488052722058
No improvement in validation loss. Patience counter: 8/20


Epoch 60, Train Loss: 0.015652814938568547, Validation Loss: 0.011535633323498532
No improvement in validation loss. Patience counter: 9/20


Epoch 61, Train Loss: 0.014031576744400188, Validation Loss: 0.013110619544885576
No improvement in validation loss. Patience counter: 10/20


Epoch 62, Train Loss: 0.015783236106503052, Validation Loss: 0.012194998943737212
No improvement in validation loss. Patience counter: 11/20


Epoch 63, Train Loss: 0.01643100395156612, Validation Loss: 0.0122029536797731
No improvement in validation loss. Patience counter: 12/20


Epoch 64, Train Loss: 0.014283045023447617, Validation Loss: 0.012739647279025176
No improvement in validation loss. Patience counter: 13/20


Epoch 65, Train Loss: 0.015799530840782464, Validation Loss: 0.01212644466556171
No improvement in validation loss. Patience counter: 14/20


Epoch 66, Train Loss: 0.01561722361156026, Validation Loss: 0.013961328300798896
No improvement in validation loss. Patience counter: 15/20


Epoch 67, Train Loss: 0.015712605543807212, Validation Loss: 0.015624012538259825
No improvement in validation loss. Patience counter: 16/20


Epoch 68, Train Loss: 0.014411457934335722, Validation Loss: 0.013507535050815263
No improvement in validation loss. Patience counter: 17/20


Epoch 69, Train Loss: 0.015089859748693759, Validation Loss: 0.014180564419283266
No improvement in validation loss. Patience counter: 18/20


Epoch 70, Train Loss: 0.016441752672599473, Validation Loss: 0.013492411923638132
No improvement in validation loss. Patience counter: 19/20


Epoch 71, Train Loss: 0.015390273288867221, Validation Loss: 0.013904797630976292
No improvement in validation loss. Patience counter: 20/20
Early stopping triggered. Stopping training at epoch 71.
Loaded best model weights based on validation loss.


In [12]:
# 8. 使用 KNN 替代测试阶段的 FC 分类层
# 使用 encoder 提取训练集和测试集的 embedding
estim.model.eval()
with torch.no_grad():
    train_embeddings = estim.model.encoder(torch.tensor(X_train).float().to(device)).cpu().numpy()
    test_embeddings = estim.model.encoder(torch.tensor(X_test).float().to(device)).cpu().numpy()

In [13]:
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report

    

    # 初始化和训练KNN分类器
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(train_embeddings, y_train)
    
    # 模型预测
predictions = knn.predict(test_embeddings)

    # 计算准确率和 F1 分数
accuracy = accuracy_score(y_test, predictions)
print(f"KNN Accuracy on Test Data: {accuracy}")
f1 = f1_score(y_test, predictions, average='weighted')
print(f"Weighted F1 Score: {f1}")
    
macro_f1 = f1_score(y_test, predictions, average='macro')
print(f'Macro F1 Score: {macro_f1}')

    # 计算随机猜测的准确率
class_probabilities = np.bincount(y_test) / len(y_test)
random_accuracy = np.sum(class_probabilities ** 2)
print(f"Random Guess Accuracy: {random_accuracy}")

    # 生成分类报告
report = classification_report(y_test, predictions, target_names=label_encoder.classes_)
print(report)

KNN Accuracy on Test Data: 0.9970886521419202
Weighted F1 Score: 0.9970813306585145
Macro F1 Score: 0.9916985583188406
Random Guess Accuracy: 0.32775393340845166
                                                                precision    recall  f1-score   support

     L2/3-6 intratelencephalic projecting glutamatergic neuron       1.00      1.00      1.00     17236
L5 extratelencephalic projecting glutamatergic cortical neuron       1.00      1.00      1.00        87
                             L6b glutamatergic cortical neuron       1.00      1.00      1.00       702
                            VIP GABAergic cortical interneuron       0.99      1.00      1.00      2022
                              astrocyte of the cerebral cortex       1.00      1.00      1.00       639
                caudal ganglionic eminence derived interneuron       0.97      0.96      0.97       162
                              cerebral cortex endothelial cell       0.96      1.00      0.98        27
     

In [14]:

import pandas as pd
import os
import re

# 当前 Notebook 文件名
notebook_name = "Human_Great_Apes_random_mask_fine_tune_42.ipynb"

# 初始化需要打印的值
init_train_loss = train_losses[0] if 'train_losses' in globals() else None
init_val_loss = val_losses[0] if 'val_losses' in globals() else None
converged_epoch = len(train_losses) - patience if 'train_losses' in globals() else None
converged_val_loss = best_val_loss if 'best_val_loss' in globals() else None

# 打印所有所需的指标
print("Metrics Summary:")
if 'train_losses' in globals():
    print(f"init_train_loss\tinit_val_loss\tconverged_epoch\tconverged_val_loss\tmacro_f1\tweighted_f1\tmicro_f1")
    print(f"{init_train_loss:.3f}\t{init_val_loss:.3f}\t{converged_epoch}\t{converged_val_loss:.3f}\t{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")
else:
    print(f"macro_f1\tweighted_f1\tmicor_f1")
    print(f"{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")

# 保存结果到 CSV 文件
output_data = {
    'dataset_split_random_seed': [int(random_seed)],
    'dataset': ['Human_Great_Apes'],
    'method': [re.search(r'Human_Great_Apes_(.*?)_\d+', notebook_name).group(1)],
    'init_train_loss': [init_train_loss if init_train_loss is not None else ''],
    'init_val_loss': [init_val_loss if init_val_loss is not None else ''],
    'converged_epoch': [converged_epoch if converged_epoch is not None else ''],
    'converged_val_loss': [converged_val_loss if converged_val_loss is not None else ''],
    'macro_f1': [macro_f1],
    'weighted_f1': [f1],
    'micro_f1': [accuracy]
}
output_df = pd.DataFrame(output_data)

# 保存到当前目录下名为 results 的文件夹中
if not os.path.exists('results'):
    os.makedirs('results')

csv_filename = f"results/{os.path.splitext(notebook_name)[0]}_results.csv"
output_df.to_csv(csv_filename, index=False)


Metrics Summary:
init_train_loss	init_val_loss	converged_epoch	converged_val_loss	macro_f1	weighted_f1	micro_f1
0.332	0.016	51	0.011	0.992	0.997	0.997
