In [1]:
import scanpy as sc
import torch
import lightning.pytorch as pl
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from self_supervision.models.lightning_modules.cellnet_autoencoder import MLPAutoEncoder
from self_supervision.estimator.cellnet import EstimatorAutoEncoder
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, classification_report
import numpy as np
import pandas as pd

train_data_dir = '../../dataset/PBMC_train_set_mapped.h5ad'
val_data_dir = '../../dataset/PBMC_val_set_mapped.h5ad'
test_data_dir = '../../dataset/PBMC_test_set_mapped.h5ad'

# Load the datasets (no change in loading)
adata_train = sc.read_h5ad(train_data_dir)
adata_val = sc.read_h5ad(val_data_dir)
adata_test = sc.read_h5ad(test_data_dir)


  warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
  warn(f"Triton dtype mappings did not load successfully due to an error: {exc.msg}")


In [2]:
sc.pp.normalize_total(adata_train, target_sum=1e4)
sc.pp.log1p(adata_train)

sc.pp.normalize_total(adata_val, target_sum=1e4)
sc.pp.log1p(adata_val)

sc.pp.normalize_total(adata_test, target_sum=1e4)
sc.pp.log1p(adata_test)


In [3]:
# 2. 设置 checkpoint 路径
ckpt_path = "../../sc_pretrained/Pretrained Models/GPMask.ckpt"

# 3. 模型参数
units_encoder = [512, 512, 256, 256, 64]
units_decoder = [256, 256, 512, 512]

# 初始化 EstimatorAutoEncoder 实例
estim = EstimatorAutoEncoder(data_path=None)  # 没有实际数据路径，可以设置为None

# 4. 加载预训练模型
estim.model = MLPAutoEncoder.load_from_checkpoint(
    ckpt_path,
    gene_dim=19331,  
    batch_size=128,  
    units_encoder=units_encoder, 
    units_decoder=units_decoder,
    masking_strategy="random",  # 假设模型使用了随机掩码
    masking_rate=0.5,  # 根据需要调整
)

In [4]:
# Access Anndata.X and transform to tensors
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



# Step 1: Take the union of all unique labels across the three datasets
all_labels = np.concatenate([
    adata_train.obs['cell_type'].values, 
    adata_val.obs['cell_type'].values, 
    adata_test.obs['cell_type'].values
])

# Step 2: Fit LabelEncoder on the combined labels
label_encoder = LabelEncoder()
label_encoder.fit(all_labels)

# Directly use the 'cell_type' column, assuming it is already encoded as int64
X_train = torch.tensor(adata_train.X.toarray()).float().to(device)
y_train = torch.tensor(label_encoder.transform(adata_train.obs['cell_type'])).long().to(device)

X_val = torch.tensor(adata_val.X).float().to(device)
y_val = torch.tensor(label_encoder.transform(adata_val.obs['cell_type'])).long().to(device)

X_test = torch.tensor(adata_test.X).float().to(device)
y_test = torch.tensor(label_encoder.transform(adata_test.obs['cell_type'])).long().to(device)

# The rest of the code remains the same


# Create TensorDataset and DataLoader for train, val, test
batch_size = 256
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
test_dataset = TensorDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Add classification layer
n_classes = len(label_encoder.classes_)
estim.model.fc = nn.Linear(units_encoder[-1], n_classes)

# Fine-tuning: Enable gradient updates for the inner model
for param in list(estim.model.encoder.parameters()):
    param.requires_grad = True

estim.model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(filter(lambda p: p.requires_grad, estim.model.parameters()), lr=9e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)


In [6]:
def train_epoch(model, optimizer, loss_fn, train_loader, val_loader):
    model.train()
    total_train_loss = 0
    
    # 训练集批次训练
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        
        # 前向传播
        outputs = model.encoder(X_batch)
        logits = model.fc(outputs)
        
        # 计算损失
        loss = loss_fn(logits, y_batch)
        loss.backward()
        optimizer.step()
        
        total_train_loss += loss.item()
    
    # 验证集
    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for X_val_batch, y_val_batch in val_loader:
            val_outputs = model.encoder(X_val_batch)
            val_logits = model.fc(val_outputs)
            val_loss = loss_fn(val_logits, y_val_batch)
            total_val_loss += val_loss.item()

    # 返回平均损失
    return total_train_loss / len(train_loader), total_val_loss / len(val_loader)

In [7]:
import copy  # 用于保存模型的最佳状态

# Early Stopping 参数
patience = 20  # 如果验证损失在 10 个 epoch 中没有改善，停止训练
min_delta = 1e-4  # 最小改善幅度
patience_counter = 0
best_val_loss = float('inf')  # 初始设置为正无穷大
best_model_weights = copy.deepcopy(estim.model.state_dict())  # 保存最佳模型权重
train_losses = []
val_losses = []

# 训练 500 个 epoch
for epoch in range(500):
    train_loss, val_loss = train_epoch(estim.model, optimizer, loss_fn, train_loader, val_loader)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Validation Loss: {val_loss}')
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    
    # Early Stopping 检查
    if val_loss < best_val_loss - min_delta:
        best_val_loss = val_loss  # 更新最佳验证损失
        patience_counter = 0  # 重置 patience 计数器
        best_model_weights = copy.deepcopy(estim.model.state_dict())  # 保存当前最佳模型
        print(f"Validation loss improved to {best_val_loss}, resetting patience.")
    else:
        patience_counter += 1
        print(f"No improvement in validation loss. Patience counter: {patience_counter}/{patience}")
    
    # 如果 patience_counter 超过设置的 patience，停止训练
    if patience_counter >= patience:
        print(f"Early stopping triggered. Stopping training at epoch {epoch+1}.")
        break

# 恢复到训练中性能最好的模型权重
estim.model.load_state_dict(best_model_weights)
print("Loaded best model weights based on validation loss.")

Epoch 1, Train Loss: 0.5952249283656623, Validation Loss: 3.7269207257213015
Validation loss improved to 3.7269207257213015, resetting patience.


Epoch 2, Train Loss: 0.23102404732948795, Validation Loss: 3.7844744595614346
No improvement in validation loss. Patience counter: 1/20


Epoch 3, Train Loss: 0.15517819488378612, Validation Loss: 4.371565851298246
No improvement in validation loss. Patience counter: 2/20


Epoch 4, Train Loss: 0.10933016774950682, Validation Loss: 5.2204855969457915
No improvement in validation loss. Patience counter: 3/20


Epoch 5, Train Loss: 0.07670493386321814, Validation Loss: 4.942798329122139
No improvement in validation loss. Patience counter: 4/20


Epoch 6, Train Loss: 0.05435169129162202, Validation Loss: 5.530504707134131
No improvement in validation loss. Patience counter: 5/20


Epoch 7, Train Loss: 0.043007080491741435, Validation Loss: 5.674379121173512
No improvement in validation loss. Patience counter: 6/20


Epoch 8, Train Loss: 0.034413929126932245, Validation Loss: 5.529855970180396
No improvement in validation loss. Patience counter: 7/20


Epoch 9, Train Loss: 0.027659065218764458, Validation Loss: 5.867149739554434
No improvement in validation loss. Patience counter: 8/20


Epoch 10, Train Loss: 0.02933732785416482, Validation Loss: 5.452914411371404
No improvement in validation loss. Patience counter: 9/20


Epoch 11, Train Loss: 0.01713581519248419, Validation Loss: 5.759479367371761
No improvement in validation loss. Patience counter: 10/20


Epoch 12, Train Loss: 0.018937552391577222, Validation Loss: 6.2268166722673355
No improvement in validation loss. Patience counter: 11/20


Epoch 13, Train Loss: 0.02007195018643304, Validation Loss: 6.588129534865871
No improvement in validation loss. Patience counter: 12/20


Epoch 14, Train Loss: 0.03293966824918415, Validation Loss: 5.677758397478046
No improvement in validation loss. Patience counter: 13/20


Epoch 15, Train Loss: 0.017730805359846013, Validation Loss: 6.520015366149671
No improvement in validation loss. Patience counter: 14/20


Epoch 16, Train Loss: 0.018439357903014705, Validation Loss: 6.556045535838965
No improvement in validation loss. Patience counter: 15/20


Epoch 17, Train Loss: 0.015887243488663865, Validation Loss: 5.893736853744045
No improvement in validation loss. Patience counter: 16/20


Epoch 18, Train Loss: 0.017090056291102593, Validation Loss: 5.592822631200154
No improvement in validation loss. Patience counter: 17/20


Epoch 19, Train Loss: 0.0277594433282739, Validation Loss: 5.66216702894731
No improvement in validation loss. Patience counter: 18/20


Epoch 20, Train Loss: 0.01320582971037268, Validation Loss: 5.581631526802525
No improvement in validation loss. Patience counter: 19/20


Epoch 21, Train Loss: 0.018528300511026195, Validation Loss: 5.463535178791393
No improvement in validation loss. Patience counter: 20/20
Early stopping triggered. Stopping training at epoch 21.
Loaded best model weights based on validation loss.


In [8]:
from sklearn.neighbors import KNeighborsClassifier

# Evaluate on Test Set  
estim.model.eval()  
with torch.no_grad():  
    test_embeddings = estim.model.encoder(X_test).detach().cpu().numpy()  
    val_embeddings = estim.model.encoder(X_val).detach().cpu().numpy()  
    train_embeddings = estim.model.encoder(X_train).detach().cpu().numpy()  

# KNN Classification  
knn = KNeighborsClassifier(n_neighbors=5)  
knn.fit(val_embeddings, y_val.cpu().numpy())  
predictions = knn.predict(test_embeddings)  

# First, let's get the actual unique classes present in both y_test and predictions  
unique_classes = np.unique(np.concatenate([y_test.cpu().numpy(), predictions]))


accuracy = accuracy_score(y_test.cpu().numpy(), predictions)  
f1 = f1_score(y_test.cpu().numpy(), predictions, average='weighted')  
macro_f1 = f1_score(y_test.cpu().numpy(), predictions, average='macro')  

print(f"KNN Accuracy: {accuracy}")  
print(f"Weighted F1 Score: {f1}")  
print(f"Macro F1 Score: {macro_f1}")  

# Get the class names for only the classes present in the data  
present_classes = [label_encoder.classes_[i] for i in unique_classes]  
report = classification_report(y_test.cpu().numpy(), predictions,   
                             labels=unique_classes,  # specify which labels to include  
                             target_names=present_classes)  # their corresponding names  
print(report)  

# Optionally, print which class is missing  
all_classes_set = set(range(len(label_encoder.classes_)))  
present_classes_set = set(unique_classes)  
missing_classes = all_classes_set - present_classes_set  
if missing_classes:  
    print("\nMissing class indices:", missing_classes)  
    print("Missing class names:", [label_encoder.classes_[i] for i in missing_classes])
    
random_seed = 42

import numpy as np
import os
import json

# Create directory to store embeddings and predictions
output_dir = os.path.join('./prediction_results', f'GP_mask_fine_tune_seed_{random_seed}')
os.makedirs(output_dir, exist_ok=True)

# Save embeddings
np.save(os.path.join(output_dir, 'train_embeddings.npy'), train_embeddings)
np.save(os.path.join(output_dir, 'val_embeddings.npy'), val_embeddings) 
np.save(os.path.join(output_dir, 'test_embeddings.npy'), test_embeddings)

# Save predictions and ground truth
np.save(os.path.join(output_dir, 'test_predictions.npy'), predictions)
np.save(os.path.join(output_dir, 'test_ground_truth.npy'), y_test.cpu().numpy())
np.save(os.path.join(output_dir, 'train_ground_truth.npy'), y_train.cpu().numpy())
np.save(os.path.join(output_dir, 'val_ground_truth.npy'), y_val.cpu().numpy())

# Save training history if exists
if 'train_losses' in globals() and 'val_losses' in globals():
    np.save(os.path.join(output_dir, 'train_losses.npy'), np.array(train_losses))
    np.save(os.path.join(output_dir, 'val_losses.npy'), np.array(val_losses))

# Save label encoder classes (target names)
label_mapping = {i: label_name for i, label_name in enumerate(label_encoder.classes_)}
with open(os.path.join(output_dir, 'label_mapping.json'), 'w') as f:
    json.dump(label_mapping, f, indent=4)

print(f"Saved embeddings, predictions and label mapping to {output_dir}")


KNN Accuracy: 0.41781551044499254
Weighted F1 Score: 0.386143524420991
Macro F1 Score: 0.2729916408541471


                                                      precision    recall  f1-score   support

                                              B cell       0.43      0.55      0.48     11338
                CD1c-positive myeloid dendritic cell       0.00      0.00      0.00        95
                          CD4-positive helper T cell       0.00      0.00      0.00      7784
                     CD4-positive, alpha-beta T cell       0.15      0.20      0.17     11008
              CD4-positive, alpha-beta memory T cell       0.00      0.00      0.00      2866
                     CD8-positive, alpha-beta T cell       0.22      0.14      0.17      8505
           CD8-positive, alpha-beta cytotoxic T cell       0.00      0.00      0.00       543
              CD8-positive, alpha-beta memory T cell       0.00      0.00      0.00      4408
                                        Schwann cell       0.00      0.00      0.00        10
                                              T cell       

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [9]:

import pandas as pd
import os
import re

# 当前 Notebook 文件名
notebook_name = "PBMC_GP_mask_fine_tune_42.ipynb"

# 初始化需要打印的值
init_train_loss = train_losses[0] if 'train_losses' in globals() else None
init_val_loss = val_losses[0] if 'val_losses' in globals() else None
converged_epoch = len(train_losses) - patience if 'train_losses' in globals() else None
converged_val_loss = best_val_loss if 'best_val_loss' in globals() else None

# 打印所有所需的指标
print("Metrics Summary:")
if 'train_losses' in globals():
    print(f"init_train_loss\tinit_val_loss\tconverged_epoch\tconverged_val_loss\tmacro_f1\tweighted_f1\tmicro_f1")
    print(f"{init_train_loss:.3f}\t{init_val_loss:.3f}\t{converged_epoch}\t{converged_val_loss:.3f}\t{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")
else:
    print(f"macro_f1\tweighted_f1\tmicro_f1")
    print(f"{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")

# 保存结果到 CSV 文件
output_data = {
    'dataset_split_random_seed': [int(random_seed)],
    'dataset': ['PBMC'],
    'method': [re.search(r'PBMC_(.*?)_\d+', notebook_name).group(1)],
    'init_train_loss': [init_train_loss if init_train_loss is not None else ''],
    'init_val_loss': [init_val_loss if init_val_loss is not None else ''],
    'converged_epoch': [converged_epoch if converged_epoch is not None else ''],
    'converged_val_loss': [converged_val_loss if converged_val_loss is not None else ''],
    'macro_f1': [macro_f1],
    'weighted_f1': [f1],
    'micro_f1': [accuracy]
}
output_df = pd.DataFrame(output_data)

# 保存到当前目录下名为 results 的文件夹中
if not os.path.exists('results'):
    os.makedirs('results')

csv_filename = f"results/{os.path.splitext(notebook_name)[0]}_results.csv"
output_df.to_csv(csv_filename, index=False)


Metrics Summary:
init_train_loss	init_val_loss	converged_epoch	converged_val_loss	macro_f1	weighted_f1	micro_f1
0.595	3.727	1	3.727	0.273	0.386	0.418
