In [1]:
import random
import torch.utils.data
import torch.nn.parallel
from sklearn.model_selection import StratifiedKFold
from scipy.optimize import linear_sum_assignment
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
from model import actinn, clustering
from model.utilities import *
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
import time
import warnings
warnings.filterwarnings("ignore")

# Data Loading
Load the single-cell count matrices of three kidney samples (10x Genomics format) and the cell-type annotations, then align the two by cell name.

In [2]:
# Data loading: read each 10x sample and prefix its barcodes with the sample name
# so they can be matched against the `cell` column of the annotation CSV.
KIDNEY_SAMPLES = ["10X_P4_5", "10X_P4_6", "10X_P7_5"]
sample_adatas = []
for sample in KIDNEY_SAMPLES:
    sample_adata = sc.read_10x_mtx(
        f"data/Kidney/droplet/Kidney-{sample}",  # the directory with the `.mtx` file
        var_names='gene_symbols',  # use gene symbols for the variable names (variables-axis index)
        cache=True)
    sample_adata.obs_names = f"{sample}_" + sample_adata.obs_names
    sample_adatas.append(sample_adata)
adata = sample_adatas[0].concatenate(*sample_adatas[1:])
del sample_adatas  # free the per-sample copies; only the concatenated object is used
# `concatenate` appends "-<batch>" to each cell name (and 10x barcodes typically carry
# their own "-<n>" suffix); keep only the part before the first "-".
split_names = adata.obs_names.str.split('-', expand=True)
adata.obs_names = split_names.get_level_values(0)
assert adata.obs_names.is_unique, "cell names collide after stripping suffixes"

# Label loading: keep only kidney annotations, indexed by cell name.
gt_df = pd.read_csv("data/Kidney/droplet/annotations_droplet.csv")
gt_df = gt_df[gt_df['tissue'] == 'Kidney']
gt_df.index = gt_df["cell"]

# Data/label alignment: restrict to annotated cells. `.copy()` materializes the
# subset so that adding an `.obs` column does not operate on an AnnData view.
common_cells = adata.obs_names.intersection(gt_df.index)
adata = adata[common_cells, :].copy()
gt_aligned = gt_df.loc[common_cells, :]
adata.obs["CellType"] = gt_aligned["cell_ontology_class"]
print(adata)

AnnData object with n_obs × n_vars = 2781 × 23433
    obs: 'batch', 'CellType'
    var: 'gene_ids'


# Data preprocessing
Apply standard quality control, normalization, log transformation, highly variable gene selection, and scaling to prepare the inputs for scSemiPLC.

In [3]:
# Quality control: drop cells with fewer than 200 detected genes and genes
# detected in fewer than 3 cells.
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)
# Normalization: scale each cell to 10,000 total counts, then apply log(1 + x).
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
# Keep the full log-normalized matrix (all genes) in `.raw` before subsetting.
adata.raw = adata
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
# Subset to highly variable genes. `.copy()` turns the view into a real AnnData
# so the in-place scaling below does not rely on an implicit view conversion.
adata = adata[:, adata.var['highly_variable']].copy()
# Per-gene scaling, clipping scaled values at 10.
sc.pp.scale(adata, max_value=10)
print(adata)

AnnData object with n_obs × n_vars = 2781 × 1988
    obs: 'batch', 'CellType', 'n_genes'
    var: 'gene_ids', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std'
    uns: 'log1p', 'hvg'


# Training Function Definition
Define the `train` function, which chains contrastive pre-training, supervised training, semi-supervised training, and evaluation of the annotation results (computed over both the labeled and the unlabeled cells).

In [4]:
def _hungarian_label_mapping(pred_labels, gt_labels, num_classes):
    """Match model output indices to ground-truth labels by maximum agreement.

    Builds the ``num_classes x num_classes`` co-occurrence matrix
    ``matrix[p, g] = number of cells predicted ``p`` whose ground truth is ``g```
    and solves the assignment problem with the Hungarian algorithm.

    Args:
        pred_labels: 1-D integer array-like of predicted output indices.
        gt_labels: 1-D integer array-like of ground-truth labels (same length).
        num_classes: side of the square matrix; must exceed every value in both inputs.

    Returns:
        (gt2pred, pred2gt): dicts mapping ground-truth label -> output index and
        output index -> ground-truth label.
    """
    pred_np = np.asarray(pred_labels, dtype=np.int64).ravel()
    gt_np = np.asarray(gt_labels, dtype=np.int64).ravel()
    matrix = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(matrix, (pred_np, gt_np), 1)
    # linear_sum_assignment minimizes cost, so negate the counts to maximize agreement.
    row_ind, col_ind = linear_sum_assignment(-matrix)
    gt2pred = {int(g): int(p) for p, g in zip(row_ind, col_ind)}
    pred2gt = {int(p): int(g) for p, g in zip(row_ind, col_ind)}
    return gt2pred, pred2gt


def train(X_label, Y_label, X_unlabel, Y_unlabel, cluster, model, estimator, opt_est,
          Pre_epochs=100, Supervised_epochs=200, SemiSupervised_epochs=150):
    """Train scSemiPLC on one labeled/unlabeled split and evaluate the annotations.

    Stages:
      1. Contrastive pre-training of ``model`` on two augmented views of the unlabeled cells.
      2. Hungarian matching of model output indices to the labeled cells' ground truth,
         then supervised training on the aligned labels.
      3. Semi-supervised training that alternates an ``estimator`` update (model frozen)
         and a ``model`` update (estimator frozen), using clustering-derived pseudo
         labels that are refreshed every 10 epochs.
      4. Evaluation over labeled + unlabeled cells after mapping predictions back to
         ground-truth label indices.

    Args:
        X_label, Y_label: features / integer labels of the labeled cells.
        X_unlabel, Y_unlabel: features / integer labels of the unlabeled cells
            (``Y_unlabel`` is only used for evaluation).
        cluster: clustering helper exposing ``set_init_centers``,
            ``feature_clustering`` and ``samples['p_label']``.
        model: annotation network; ``model(x)`` returns ``(features, logits)``.
        estimator: confidence estimator applied to ``cat([features, logits], dim=1)``.
        opt_est: optimizer over ``estimator``'s parameters.
        Pre_epochs, Supervised_epochs, SemiSupervised_epochs: epochs per stage.

    Returns:
        ``(y_true, y_pred, eval_fea, accuracy, precision, recall, f1)``.
        ``y_true`` and ``y_pred`` list labeled cells first, then unlabeled cells.
        ``y_pred`` is in the model's output-index space (NOT mapped to ground-truth indices).
        ``eval_fea`` stacks the matching embeddings row-wise.
        The metrics are macro-averaged percentages computed on the mapped predictions.
    """
    # Use the model's own device rather than a notebook-global `device`.
    device = next(model.parameters()).device

    # Loader over unlabeled cells. Items unpack as (idx, input_u, inputs_u_w, inputs_u_s),
    # presumably index, un-augmented, weak and strong views.
    # NOTE(review): relies on the notebook-global `transformation_list`.
    dataset = CustomDataset(data=X_unlabel, transform_args=transformation_list)
    pretrain_loader = DataLoader(dataset, batch_size=512, shuffle=True, pin_memory=True)

    # -------------------------------------- 1. Pre-training stage -----------------------------------
    # Initialize the annotation model on unlabeled data with a contrastive loss.
    opt_model = torch.optim.Adam(params=model.parameters(), lr=5e-4, betas=(0.9, 0.999),
                                 eps=1e-08, weight_decay=0.005, amsgrad=False)
    for _ in tqdm(range(Pre_epochs), desc="Pre-Training"):
        model.train()
        for _, _, inputs_u_w, inputs_u_s in pretrain_loader:
            inputs_u_w = inputs_u_w.to(device)
            inputs_u_s = inputs_u_s.to(device)
            opt_model.zero_grad()
            _, logits_u_w = model(inputs_u_w)
            _, logits_u_s = model(inputs_u_s)
            Lc = ContrastiveLoss(logits_u_w, logits_u_s, 0.1)
            Lc.backward()
            opt_model.step()

    # Map model output indices to ground-truth cell types (Hungarian algorithm).
    # This is pure inference, so run it in eval mode without building a graph.
    model.eval()
    with torch.no_grad():
        _, logits = model(torch.tensor(X_label).to(device))
    pred_labels = logits.argmax(1)
    gt_labels = torch.tensor(Y_label).to(device)
    n_outputs = logits.shape[1]
    # Square matrix spanning every output index and every ground-truth label, so
    # each possible prediction receives a mapping.
    num_classes = max(n_outputs, int(gt_labels.max().item()) + 1)
    gt2pred_mapping, pred2gt_mapping = _hungarian_label_mapping(
        pred_labels.cpu().numpy(), gt_labels.cpu().numpy(), num_classes)
    aligned_gt_labels = gt_labels.clone()
    for gt_idx, pred_idx in gt2pred_mapping.items():
        aligned_gt_labels[gt_labels == gt_idx] = pred_idx

    # ------------------------------------ 2. Supervised training stage ---------------------------------
    # Fine-tune the annotation model on the labeled cells with the aligned labels.
    label_data = [[torch.tensor(feat), label, aligned_label]
                  for feat, label, aligned_label in zip(X_label, Y_label, aligned_gt_labels.cpu())]
    label_loader = DataLoader(label_data, batch_size=32, shuffle=True, pin_memory=True)
    opt_model = torch.optim.Adam(params=model.parameters(), lr=1e-4, betas=(0.9, 0.999),
                                 eps=1e-08, weight_decay=0.005, amsgrad=False)
    for _ in range(Supervised_epochs):
        model.train()
        for inputs, _, aligned_targets in label_loader:
            opt_model.zero_grad()
            _, logits_x = model(inputs.to(device))
            Ls = F.cross_entropy(logits_x, aligned_targets.to(device), reduction='mean')
            Ls.backward()
            opt_model.step()

    # ---------------------------------- 3. Semi-supervised training stage -------------------------------
    # Alternate estimator and model updates with clustering-derived pseudo labels,
    # weighted per cell by the confidence estimator.
    update_interval = 10
    inputs_x = torch.tensor(X_label).to(device)
    opt_model = torch.optim.Adam(params=model.parameters(), lr=5e-5, betas=(0.9, 0.999),
                                 eps=1e-08, weight_decay=0.005, amsgrad=False)
    for epoch in tqdm(range(SemiSupervised_epochs), desc="SemiSupervised-Training"):
        if epoch % update_interval == 0:
            # Refresh pseudo labels: seed cluster centers from the labeled cells,
            # then cluster the unlabeled cells.
            with torch.no_grad():
                model.eval()
                init_target_centers = get_centers(net=model, data=inputs_x, labels=aligned_gt_labels,
                                                  num_classes=n_outputs)
                cluster.set_init_centers(init_target_centers)
                cluster.feature_clustering(model, data=torch.tensor(X_unlabel))
                targets_sel = cluster.samples['p_label']

        model.train()
        estimator.train()
        # (a) Estimator optimization with the model fixed.
        for idx, input_u, inputs_u_w, inputs_u_s in pretrain_loader:
            input_u = input_u.to(device)
            inputs_u_w = inputs_u_w.to(device)
            inputs_u_s = inputs_u_s.to(device)

            opt_est.zero_grad()
            with torch.no_grad():
                fea_sel, logits_sel = model(input_u)
                _, logits_sel_s = model(inputs_u_s)
                _, logits_sel_w = model(inputs_u_w)
            # Only the estimator's parameters need gradients; the input is a constant.
            sim_s = estimator(torch.cat([fea_sel, logits_sel], dim=1))
            # NOTE(review): assumes `sim_s` has shape (N,); an (N, 1) output would
            # broadcast against the (N,) loss below into (N, N) — confirm.
            pseudo_targets = targets_sel[idx]
            # Confidence-weighted consistency loss on the weak and strong views.
            Lcw = (F.cross_entropy(logits_sel_w, pseudo_targets, reduction='none') * sim_s).mean()
            Lcs = (F.cross_entropy(logits_sel_s, pseudo_targets, reduction='none') * sim_s).mean()
            Lu_est = Lcs + Lcw
            Lu_est.backward()
            opt_est.step()

        # (b) Model optimization with the estimator fixed.
        for idx, input_u, inputs_u_w, inputs_u_s in pretrain_loader:
            opt_model.zero_grad()
            # Supervised loss on the full labeled set, recomputed for every batch.
            _, logits_x = model(inputs_x)
            Ls = F.cross_entropy(logits_x, aligned_gt_labels, reduction='mean')
            input_u = input_u.to(device)
            inputs_u_w = inputs_u_w.to(device)
            inputs_u_s = inputs_u_s.to(device)
            fea_sel, logits_sel = model(input_u)
            _, logits_sel_s = model(inputs_u_s)
            _, logits_sel_w = model(inputs_u_w)
            # Confidence weights are constants for the model update.
            with torch.no_grad():
                sim_s = estimator(torch.cat([fea_sel, logits_sel], dim=1))
            pseudo_targets = targets_sel[idx]
            Lcw = (F.cross_entropy(logits_sel_w, pseudo_targets, reduction='none') * sim_s).mean()
            Lcs = (F.cross_entropy(logits_sel_s, pseudo_targets, reduction='none') * sim_s).mean()
            Lu_model = Ls + 0.5 * (Lcs + Lcw)
            Lu_model.backward()
            opt_model.step()

    # ---------------------------------- 4. Annotation result evaluation -------------------------------
    model.eval()
    with torch.no_grad():
        fea_x, output_x = model(torch.tensor(X_label).to(device))
        fea_u, output_u = model(torch.tensor(X_unlabel).to(device))
    # Labeled cells first, then unlabeled, in the same row order as `eval_fea`.
    y_true = np.concatenate([np.ravel(np.asarray(Y_label)),
                             np.ravel(np.asarray(Y_unlabel))]).astype(np.float64)
    y_pred = torch.cat([output_x.argmax(1), output_u.argmax(1)]).cpu().numpy().astype(np.float64)
    # Stack embeddings row-wise at their actual width (no hard-coded feature size).
    eval_fea = np.concatenate([fea_x.reshape(len(fea_x), -1).cpu().numpy(),
                               fea_u.reshape(len(fea_u), -1).cpu().numpy()], axis=0).astype(np.float64)

    # Map predicted output indices back to ground-truth label indices.
    aligned_pred_labels = y_pred.copy()
    for pred_idx, gt_idx in pred2gt_mapping.items():
        aligned_pred_labels[y_pred == pred_idx] = gt_idx

    # NOTE(review): these metrics include the labeled cells the model was trained
    # on, which likely inflates the "valid" scores; consider scoring the unlabeled
    # subset alone.
    accuracy = accuracy_score(y_true, aligned_pred_labels)
    precision = precision_score(y_true, aligned_pred_labels, average="macro")
    recall = recall_score(y_true, aligned_pred_labels, average="macro")
    f1 = f1_score(y_true, aligned_pred_labels, average="macro")

    return y_true, y_pred, eval_fea, 100. * accuracy, 100. * precision, 100. * recall, 100. * f1

# Main Program Execution
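Set the random seed, device, and data augmentation parameters, then train and evaluate the model with stratified 10-fold splitting. The split is used in reverse: in each fold, the single held-out fold (about 10% of the cells) is the labeled set and the remaining nine folds (about 90%) are treated as unlabeled.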

In [5]:
# Per-fold metric accumulators (averaged over the folds at the end).
valid_f1_sum, valid_acc_sum = 0, 0
valid_pre_sum, valid_rec_sum = 0, 0

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed every RNG the pipeline uses (PyTorch CPU + all CUDA devices, NumPy, Python).
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Data Preparation
dataset_name = "Kidney"
type_to_label_dict = {'kidney capillary endothelial cell': 0, 'kidney cell': 1,
                      'kidney collecting duct epithelial cell': 2,
                      'kidney loop of Henle ascending limb epithelial cell': 3,
                      'kidney proximal straight tubule epithelial cell': 4, 'leukocyte': 5, 'macrophage': 6,
                      'mesangial cell': 7}

# Data Augmentation Settings (read as a global by `train`)
transformation_list = [{  # weak
    'mask_percentage': 0.5, 'apply_mask_prob': 0.8,
    'noise_percentage': 0.5, 'sigma': 0.5, 'apply_noise_prob': 0.0
}, {  # strong
    'mask_percentage': 0.5, 'apply_mask_prob': 0.0,
    'noise_percentage': 0.5, 'sigma': 0.5, 'apply_noise_prob': 0.8
}]

X = np.array(adata.X).astype(np.float32)
Y = convert_type2label(adata.obs["CellType"], type_to_label_dict)
feature_size = X.shape[1]
NUM_CLASS = len(type_to_label_dict)

N_SPLITS = 10
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=seed)
for unlabel_index, label_index in skf.split(X, Y):
    # Inverted split: the held-out fold (~1/N_SPLITS of the cells) is the labeled
    # set; the remaining folds are treated as unlabeled.
    X_label, X_unlabel = X[label_index], X[unlabel_index]
    Y_label, Y_unlabel = Y[label_index], Y[unlabel_index]

    # Fresh clustering helper per fold so any state it keeps cannot leak between folds.
    # Hyper-parameters unchanged from the original setup (meaning defined in model.clustering).
    cluster = clustering.Clustering(0.005, 128 * 9, device=device)

    # Model definition
    model_cla = actinn.ACTINN(output_dim=NUM_CLASS, input_size=feature_size).to(device)
    model_cla.apply(init_weights)
    # Confidence estimator
    model_est = actinn.Con_estimator(NUM_CLASS=NUM_CLASS).to(device)
    model_est.apply(init_weights)
    optimizer_est = torch.optim.Adam(params=model_est.parameters(), lr=5e-4, betas=(0.9, 0.999),
                                     eps=1e-08, weight_decay=0.005, amsgrad=False)
    # Model Training
    label_true, pred, fea, val_acc, val_pre, val_rec, val_f1 = \
        train(X_label, Y_label, X_unlabel, Y_unlabel, cluster=cluster,
              model=model_cla, estimator=model_est, opt_est=optimizer_est)
    print('valid F1:{:.3f}%, valid_acc:{:.3f}%'.format(val_f1, val_acc))
    print('valid pre:{:.3f}%, valid_rec:{:.3f}%\n'.format(val_pre, val_rec))
    # Results recording
    valid_f1_sum += val_f1
    valid_acc_sum += val_acc
    valid_pre_sum += val_pre
    valid_rec_sum += val_rec

print('average accuracy:{:.3f}%, average F1:{:.3f}%'.format(valid_acc_sum / N_SPLITS, valid_f1_sum / N_SPLITS))
print('average precision:{:.3f}%, average recall:{:.3f}%\n'.format(valid_pre_sum / N_SPLITS, valid_rec_sum / N_SPLITS))

Pre-Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:33<00:00,  2.99it/s]
SemiSupervised-Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 150/150 [01:35<00:00,  1.58it/s]


valid F1:96.785%, valid_acc:99.173%
valid pre:97.918%, valid_rec:96.115%



Pre-Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:31<00:00,  3.13it/s]
SemiSupervised-Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 150/150 [01:35<00:00,  1.57it/s]


valid F1:96.778%, valid_acc:98.849%
valid pre:96.121%, valid_rec:97.856%



Pre-Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:31<00:00,  3.13it/s]
SemiSupervised-Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 150/150 [01:35<00:00,  1.57it/s]


valid F1:94.884%, valid_acc:98.670%
valid pre:95.463%, valid_rec:95.261%



Pre-Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:31<00:00,  3.13it/s]
SemiSupervised-Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 150/150 [01:35<00:00,  1.58it/s]


valid F1:97.754%, valid_acc:99.389%
valid pre:98.010%, valid_rec:97.573%



Pre-Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:32<00:00,  3.12it/s]
SemiSupervised-Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 150/150 [01:35<00:00,  1.57it/s]


valid F1:96.592%, valid_acc:99.065%
valid pre:97.248%, valid_rec:96.200%



Pre-Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:31<00:00,  3.13it/s]
SemiSupervised-Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 150/150 [01:35<00:00,  1.58it/s]


valid F1:96.455%, valid_acc:99.029%
valid pre:98.532%, valid_rec:95.070%



Pre-Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:31<00:00,  3.14it/s]
SemiSupervised-Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 150/150 [01:35<00:00,  1.58it/s]


valid F1:95.593%, valid_acc:98.346%
valid pre:94.415%, valid_rec:97.589%



Pre-Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:32<00:00,  3.11it/s]
SemiSupervised-Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 150/150 [01:34<00:00,  1.58it/s]


valid F1:96.957%, valid_acc:98.993%
valid pre:97.910%, valid_rec:96.251%



Pre-Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:31<00:00,  3.13it/s]
SemiSupervised-Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 150/150 [01:35<00:00,  1.58it/s]


valid F1:98.404%, valid_acc:99.425%
valid pre:99.089%, valid_rec:97.760%



Pre-Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:31<00:00,  3.13it/s]
SemiSupervised-Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 150/150 [01:35<00:00,  1.58it/s]

valid F1:97.821%, valid_acc:99.245%
valid pre:97.411%, valid_rec:98.335%

average accuracy:99.018%, average F1:96.802%
average precision:97.212%, average recall:96.801%




