In [1]:
import scanpy as sc
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
import lightning.pytorch as pl
from sklearn.preprocessing import LabelEncoder
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report
import numpy as np
import pandas as pd
import os
import re
import copy

# Paths for datasets
train_data_dir = '../../dataset/PBMC_train_set_mapped.h5ad'
val_data_dir = '../../dataset/PBMC_val_set_mapped.h5ad'
test_data_dir = '../../dataset/PBMC_test_set_mapped.h5ad'

# Load the datasets (no change in loading)
adata_train = sc.read_h5ad(train_data_dir)
adata_val = sc.read_h5ad(val_data_dir)
adata_test = sc.read_h5ad(test_data_dir)


In [2]:
sc.pp.normalize_total(adata_train, target_sum=1e4)
sc.pp.log1p(adata_train)

sc.pp.normalize_total(adata_val, target_sum=1e4)
sc.pp.log1p(adata_val)

sc.pp.normalize_total(adata_test, target_sum=1e4)
sc.pp.log1p(adata_test)


In [3]:
# Access Anndata.X and transform to tensors
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



# Step 1: Take the union of all unique labels across the three datasets
all_labels = np.concatenate([
    adata_train.obs['cell_type'].values, 
    adata_val.obs['cell_type'].values, 
    adata_test.obs['cell_type'].values
])

# Step 2: Fit LabelEncoder on the combined labels
label_encoder = LabelEncoder()
label_encoder.fit(all_labels)

# Directly use the 'cell_type' column, assuming it is already encoded as int64
X_train = torch.tensor(adata_train.X.toarray()).float().to(device)
y_train = torch.tensor(label_encoder.transform(adata_train.obs['cell_type'])).long().to(device)

X_val = torch.tensor(adata_val.X).float().to(device)
y_val = torch.tensor(label_encoder.transform(adata_val.obs['cell_type'])).long().to(device)

X_test = torch.tensor(adata_test.X).float().to(device)
y_test = torch.tensor(label_encoder.transform(adata_test.obs['cell_type'])).long().to(device)

# The rest of the code remains the same


# Create TensorDataset and DataLoader for train, val, test
batch_size = 256
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)
test_dataset = TensorDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
# Model Initialization
ckpt_path = "../../sc_pretrained/Pretrained Models/BarlowTwins.ckpt"
units_encoder = [512, 512, 256, 256, 64]

from self_supervision.models.lightning_modules.cellnet_autoencoder import MLPBarlowTwins
from self_supervision.estimator.cellnet import EstimatorAutoEncoder

estim = EstimatorAutoEncoder(data_path=None)
estim.model = MLPBarlowTwins(
    gene_dim=X_train.shape[1],  # Number of genes
    batch_size=batch_size,
    units_encoder=units_encoder,
    CHECKPOINT_PATH=ckpt_path
)

# Load pre-trained checkpoint
checkpoint = torch.load(ckpt_path)
estim.model.inner_model.load_state_dict({k.replace('backbone.', ''): v for k, v in checkpoint.items() if 'backbone' in k})

# Add classification layer
n_classes = len(label_encoder.classes_)
estim.model.fc = nn.Linear(units_encoder[-1], n_classes)

# Fine-tuning: Enable gradient updates for the inner model
for param in list(estim.model.inner_model.parameters()):
    param.requires_grad = True

estim.model.to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(filter(lambda p: p.requires_grad, estim.model.parameters()), lr=9e-4, weight_decay=0.05)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

# Trainer with Lightning
estim.trainer = pl.Trainer(accelerator="gpu", devices=1 if torch.cuda.is_available() else None)

  warn(f"Tensorflow dtype mappings did not load successfully due to an error: {exc.msg}")
  warn(f"Triton dtype mappings did not load successfully due to an error: {exc.msg}")


GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores


HPU available: False, using: 0 HPUs


In [5]:


# Training and Validation Pipeline
def train_epoch(model, optimizer, loss_fn, train_loader, val_loader):
    model.train()
    total_train_loss = 0
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        outputs = model.inner_model(X_batch)
        logits = model.fc(outputs)
        loss = loss_fn(logits, y_batch)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

    model.eval()
    total_val_loss = 0
    with torch.no_grad():
        for X_val_batch, y_val_batch in val_loader:
            val_outputs = model.inner_model(X_val_batch)
            val_logits = model.fc(val_outputs)
            val_loss = loss_fn(val_logits, y_val_batch)
            total_val_loss += val_loss.item()

    return total_train_loss / len(train_loader), total_val_loss / len(val_loader)

# Early Stopping and Model Saving
patience = 20
min_delta = 1e-4
patience_counter = 0
best_val_loss = float('inf')
best_model_weights = copy.deepcopy(estim.model.state_dict())
train_losses = []
val_losses = []

for epoch in range(500):
    train_loss, val_loss = train_epoch(estim.model, optimizer, loss_fn, train_loader, val_loader)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Validation Loss: {val_loss}')
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    if val_loss < best_val_loss - min_delta:
        best_val_loss = val_loss
        patience_counter = 0
        best_model_weights = copy.deepcopy(estim.model.state_dict())
        print(f"Validation loss improved to {best_val_loss}, resetting patience.")
    else:
        patience_counter += 1
        print(f"No improvement in validation loss. Patience counter: {patience_counter}/{patience}")

    if patience_counter >= patience:
        print(f"Early stopping triggered. Stopping training at epoch {epoch+1}.")
        break

estim.model.load_state_dict(best_model_weights)
print("Loaded best model weights based on validation loss.")


Epoch 1, Train Loss: 0.4533096645209416, Validation Loss: 3.525856673717499
Validation loss improved to 3.525856673717499, resetting patience.


Epoch 2, Train Loss: 0.14050636112884124, Validation Loss: 4.097159638549343
No improvement in validation loss. Patience counter: 1/20


Epoch 3, Train Loss: 0.05790462550867943, Validation Loss: 4.6973727309342586
No improvement in validation loss. Patience counter: 2/20


Epoch 4, Train Loss: 0.03768684581628961, Validation Loss: 5.223945057753361
No improvement in validation loss. Patience counter: 3/20


Epoch 5, Train Loss: 0.03363424789119536, Validation Loss: 5.733627145940607
No improvement in validation loss. Patience counter: 4/20


Epoch 6, Train Loss: 0.03360751651550555, Validation Loss: 5.461417147607515
No improvement in validation loss. Patience counter: 5/20


Epoch 7, Train Loss: 0.029334898458085408, Validation Loss: 4.402367647850152
No improvement in validation loss. Patience counter: 6/20


Epoch 8, Train Loss: 0.019592523601942115, Validation Loss: 5.158466411359383
No improvement in validation loss. Patience counter: 7/20


Epoch 9, Train Loss: 0.015802096076467428, Validation Loss: 5.483580448410728
No improvement in validation loss. Patience counter: 8/20


Epoch 10, Train Loss: 0.02134154356228482, Validation Loss: 5.360293883265871
No improvement in validation loss. Patience counter: 9/20


Epoch 11, Train Loss: 0.020102212195669936, Validation Loss: 5.321753173163443
No improvement in validation loss. Patience counter: 10/20


Epoch 12, Train Loss: 0.01238552865052102, Validation Loss: 5.418553966464418
No improvement in validation loss. Patience counter: 11/20


Epoch 13, Train Loss: 0.02152327167673483, Validation Loss: 5.701917883121606
No improvement in validation loss. Patience counter: 12/20


Epoch 14, Train Loss: 0.016949822525160103, Validation Loss: 5.379720120718985
No improvement in validation loss. Patience counter: 13/20


Epoch 15, Train Loss: 0.019381837598098396, Validation Loss: 4.504064455176845
No improvement in validation loss. Patience counter: 14/20


Epoch 16, Train Loss: 0.020323378619757262, Validation Loss: 5.272372249400977
No improvement in validation loss. Patience counter: 15/20


Epoch 17, Train Loss: 0.021489090608944435, Validation Loss: 6.082648660197402
No improvement in validation loss. Patience counter: 16/20


Epoch 18, Train Loss: 0.0072614683039369425, Validation Loss: 4.841500827760408
No improvement in validation loss. Patience counter: 17/20


Epoch 19, Train Loss: 0.012333865210160624, Validation Loss: 4.712962874860475
No improvement in validation loss. Patience counter: 18/20


Epoch 20, Train Loss: 0.011867061840126281, Validation Loss: 5.475575981718121
No improvement in validation loss. Patience counter: 19/20


Epoch 21, Train Loss: 0.016169846820160982, Validation Loss: 4.922058528119868
No improvement in validation loss. Patience counter: 20/20
Early stopping triggered. Stopping training at epoch 21.
Loaded best model weights based on validation loss.


In [6]:
  

# Evaluate on Test Set  
estim.model.eval()  
with torch.no_grad():  
    test_embeddings = estim.model.inner_model(X_test).cpu().numpy()  
    val_embeddings = estim.model.inner_model(X_val).detach().cpu().numpy()  
    train_embeddings = estim.model.inner_model(X_train).detach().cpu().numpy()  

# KNN Classification  
knn = KNeighborsClassifier(n_neighbors=5)  
knn.fit(val_embeddings, y_val.cpu().numpy())  
predictions = knn.predict(test_embeddings)  

# First, get the actual unique classes present in both y_test and predictions  
unique_classes = np.unique(np.concatenate([y_test.cpu().numpy(), predictions]))


accuracy = accuracy_score(y_test.cpu().numpy(), predictions)  
f1 = f1_score(y_test.cpu().numpy(), predictions, average='weighted')  
macro_f1 = f1_score(y_test.cpu().numpy(), predictions, average='macro')  

print(f"KNN Accuracy: {accuracy}")  
print(f"Weighted F1 Score: {f1}")  
print(f"Macro F1 Score: {macro_f1}")  

# Get the class names for only the classes present in the data  
present_classes = [label_encoder.classes_[i] for i in unique_classes]  
report = classification_report(y_test.cpu().numpy(), predictions,   
                             labels=unique_classes,  # specify which labels to include  
                             target_names=present_classes)  # their corresponding names  
print(report)  

# Optionally, print which class is missing  
all_classes_set = set(range(len(label_encoder.classes_)))  
present_classes_set = set(unique_classes)  
missing_classes = all_classes_set - present_classes_set  
if missing_classes:  
    print("\nMissing class indices:", missing_classes)  
    print("Missing class names:", [label_encoder.classes_[i] for i in missing_classes])
random_seed = 42


import numpy as np
import os
import json

# Create directory to store embeddings and predictions
output_dir = os.path.join('./prediction_results', f'barlow_twins_zero_shot_seed_{random_seed}')
os.makedirs(output_dir, exist_ok=True)

# Save embeddings
np.save(os.path.join(output_dir, 'train_embeddings.npy'), train_embeddings)
np.save(os.path.join(output_dir, 'val_embeddings.npy'), val_embeddings) 
np.save(os.path.join(output_dir, 'test_embeddings.npy'), test_embeddings)

# Save predictions and ground truth
np.save(os.path.join(output_dir, 'test_predictions.npy'), predictions)
np.save(os.path.join(output_dir, 'test_ground_truth.npy'), y_test.cpu().numpy())
np.save(os.path.join(output_dir, 'train_ground_truth.npy'), y_train.cpu().numpy())
np.save(os.path.join(output_dir, 'val_ground_truth.npy'), y_val.cpu().numpy())

# Save training history if exists
if 'train_losses' in globals() and 'val_losses' in globals():
    np.save(os.path.join(output_dir, 'train_losses.npy'), np.array(train_losses))
    np.save(os.path.join(output_dir, 'val_losses.npy'), np.array(val_losses))

# Save label encoder classes (target names)
label_mapping = {i: label_name for i, label_name in enumerate(label_encoder.classes_)}
with open(os.path.join(output_dir, 'label_mapping.json'), 'w') as f:
    json.dump(label_mapping, f, indent=4)

print(f"Saved embeddings, predictions and label mapping to {output_dir}")


KNN Accuracy: 0.4208878770631759
Weighted F1 Score: 0.3857580637327141
Macro F1 Score: 0.2617841079856328


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                                                      precision    recall  f1-score   support

                                              B cell       0.46      0.40      0.43     11338
                CD1c-positive myeloid dendritic cell       0.00      0.00      0.00        95
                          CD4-positive helper T cell       0.00      0.00      0.00      7784
                     CD4-positive, alpha-beta T cell       0.15      0.21      0.18     11008
              CD4-positive, alpha-beta memory T cell       0.00      0.00      0.00      2866
                     CD8-positive, alpha-beta T cell       0.22      0.22      0.22      8505
           CD8-positive, alpha-beta cytotoxic T cell       0.00      0.00      0.00       543
              CD8-positive, alpha-beta memory T cell       0.00      0.00      0.00      4408
                                        Schwann cell       0.00      0.00      0.00        10
                                              T cell       

In [7]:

import pandas as pd
import os
import re

# 当前 Notebook 文件名
notebook_name = "PBMC_barlow_twins_fine_tune_42.ipynb"

# 初始化需要打印的值
init_train_loss = train_losses[0] if 'train_losses' in globals() else None
init_val_loss = val_losses[0] if 'val_losses' in globals() else None
converged_epoch = len(train_losses) - patience if 'train_losses' in globals() else None
converged_val_loss = best_val_loss if 'best_val_loss' in globals() else None

# 打印所有所需的指标
print("Metrics Summary:")
if 'train_losses' in globals():
    print(f"init_train_loss\tinit_val_loss\tconverged_epoch\tconverged_val_loss\tmacro_f1\tweighted_f1\tmicro_f1")
    print(f"{init_train_loss:.3f}\t{init_val_loss:.3f}\t{converged_epoch}\t{converged_val_loss:.3f}\t{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")
else:
    print(f"macro_f1\tweighted_f1\tmicro_f1")
    print(f"{macro_f1:.3f}\t{f1:.3f}\t{accuracy:.3f}")

# 保存结果到 CSV 文件
output_data = {
    'dataset_split_random_seed': [int(random_seed)],
    'dataset': ['PBMC'],
    'method': [re.search(r'PBMC_(.*?)_\d+', notebook_name).group(1)],
    'init_train_loss': [init_train_loss if init_train_loss is not None else ''],
    'init_val_loss': [init_val_loss if init_val_loss is not None else ''],
    'converged_epoch': [converged_epoch if converged_epoch is not None else ''],
    'converged_val_loss': [converged_val_loss if converged_val_loss is not None else ''],
    'macro_f1': [macro_f1],
    'weighted_f1': [f1],
    'micro_f1': [accuracy]
}
output_df = pd.DataFrame(output_data)

# 保存到当前目录下名为 results 的文件夹中
if not os.path.exists('results'):
    os.makedirs('results')

csv_filename = f"results/{os.path.splitext(notebook_name)[0]}_results.csv"
output_df.to_csv(csv_filename, index=False)


Metrics Summary:
init_train_loss	init_val_loss	converged_epoch	converged_val_loss	macro_f1	weighted_f1	micro_f1
0.453	3.526	1	3.526	0.262	0.386	0.421
