In [1]:
import os
import pickle
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import torch
from sklearn.metrics import accuracy_score, f1_score, precision_score
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader

from my_model.mlp_cls import MLP
from my_model.mydata import mydataSet
from my_model.trans_enc_cls import  TransformerEncoder
from my_model.util import setup_seed, count_labels, compute_mean_std

In [2]:
# Silence every warning category for the rest of the session.
warnings.filterwarnings("ignore")

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Train with', device)

# NOTE(review): "128, 16" presumably records previously tried values for the
# two hidden sizes below — TODO confirm.
hidden_size_1 = 512      # passed to MLP as hidden_dim1
hidden_size_2 = 64       # passed to MLP as hidden_dim2
batch_size = 64          # DataLoader batch size used during evaluation
num_epochs = 25          # not read by any cell visible in this notebook
num_folds = 5            # n_splits of the outer StratifiedKFold
learning_rate = 0.0005   # not read by any cell visible in this notebook
random_seed = 42         # handed to setup_seed() before the pipeline runs
dataset_name = 'Baron'   # selects the .h5ad input file and the checkpoint directory

Train with cuda


In [6]:
def data_produce(token_dim):
    """Convert the preprocessed expression matrix into fixed-width token arrays on disk.

    Reads ``./cls_data/preprocessed_{dataset_name}.h5ad``, drops genes detected in
    fewer than 3 cells, integer-encodes ``adata.obs['cell_type']`` and cuts each
    cell's gene vector into consecutive chunks ("tokens") of ``token_dim`` values.
    When the gene count is not a multiple of ``token_dim`` the vector is
    zero-padded on the right so the last token is complete.

    Side effects:
        * pickles the fitted ``LabelEncoder`` to ``label_encoder_{dataset_name}.pkl``
        * writes ``input_data/data_x.npy`` — float64, shape
          ``(num_cells, num_tokens, token_dim)``
        * writes ``input_data/data_y.npy`` — the encoded labels, one per cell
        * prints the number of cells per encoded label

    Args:
        token_dim: number of gene values packed into one token; must be a
            positive integer.

    Raises:
        ValueError: if ``token_dim`` is not positive.
    """
    if token_dim <= 0:
        raise ValueError(f"token_dim must be a positive integer, got {token_dim!r}")

    adata = sc.read_h5ad(f"./cls_data/preprocessed_{dataset_name}.h5ad")
    sc.pp.filter_genes(adata, min_cells=3)

    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(adata.obs['cell_type'])
    # Name the pickle after the active dataset so that switching dataset_name
    # does not silently overwrite another dataset's encoder.
    with open(f'label_encoder_{dataset_name}.pkl', 'wb') as file:
        pickle.dump(label_encoder, file)

    label_counts = pd.Series(y_encoded, name="label").value_counts()
    print("Label Counts:", label_counts)

    # Densify (adata.X may be sparse) and force float64 so the saved array has
    # the same dtype regardless of how the matrix was stored.
    X = adata.X
    features = np.asarray(X.toarray() if hasattr(X, 'toarray') else X, dtype=np.float64)
    num_samples, num_features = features.shape

    # Number of zero columns needed to reach the next multiple of token_dim
    # (0 when the gene count already divides evenly).
    padding_size = -num_features % token_dim
    if padding_size:
        features = np.pad(features, ((0, 0), (0, padding_size)))

    # Row-major reshape: token i of a cell holds columns
    # [i * token_dim, (i + 1) * token_dim) of that cell's padded vector.
    num_tokens = features.shape[1] // token_dim
    grouped_features = features.reshape(num_samples, num_tokens, token_dim)

    # np.save does not create parent directories.
    os.makedirs('input_data', exist_ok=True)
    np.save('input_data/data_x.npy', grouped_features)
    np.save('input_data/data_y.npy', y_encoded)

In [7]:
def train(conv_dim):
    """Evaluate the saved per-fold checkpoints and print test metrics.

    Despite its name this function performs no optimisation. For each of the
    ``num_folds`` stratified folds it:

    1. takes the fold's held-out indices and keeps only the first half of a
       2-way stratified split of them as the evaluation subset,
    2. builds a fresh ``TransformerEncoder`` + ``MLP`` pair and loads the weights
       stored in ``ckpts/{dataset_name}/{dataset_name}_model_fold_{k}.pt``
       (keys ``'transformer'`` and ``'classification'``),
    3. runs inference and prints accuracy, macro-F1 and macro-precision.

    After all folds it prints mean ± std of the three metrics.

    Inputs are read from ``input_data/data_x.npy`` and ``input_data/data_y.npy``.

    NOTE(review): both splitters are created without ``random_state`` and
    therefore draw from NumPy's global RNG. The evaluated subsets are only
    reproducible if that RNG is seeded before this call, and they presumably
    have to match the splits used when the checkpoints were produced — confirm
    against the training notebook before changing the splitter setup or the
    order of the split calls.

    Args:
        conv_dim: forwarded to ``TransformerEncoder`` as ``conv_emb_dim``.
    """
    data = np.load('input_data/data_x.npy')
    _, token_num, embedding_size = data.shape
    label = np.load('input_data/data_y.npy')
    num_classes = int(label.max()) + 1
    print('num_class:', num_classes)
    my_dataset = mydataSet(data, label)

    kf = StratifiedKFold(n_splits=num_folds, shuffle=True)
    kf_second = StratifiedKFold(n_splits=2, shuffle=True)

    test_ACC = []
    test_F1 = []
    test_PRE = []
    for fold, (_, test_indices) in enumerate(kf.split(my_dataset, label)):

        # Halve the held-out fold: keep the first part of a stratified 2-way split.
        fold_samples = my_dataset[test_indices]
        tmp_x, tmp_y = fold_samples[0], fold_samples[1]
        new_test_indices, _ = next(kf_second.split(tmp_x, tmp_y))
        test_indices = test_indices[new_test_indices]

        transformer_model = TransformerEncoder(seq_length=token_num, token_dim=embedding_size, conv_emb_dim=conv_dim).double()
        classification_model = MLP(input_dim=token_num, hidden_dim1=hidden_size_1, hidden_dim2=hidden_size_2, num_classes=num_classes).double()
        transformer_model.to(device)
        classification_model.to(device)

        eval_samples = my_dataset[test_indices]
        test_dataset = mydataSet(eval_samples[0], eval_samples[1])
        # Pinned host memory only helps (and is only available) with CUDA.
        test_loader = DataLoader(test_dataset, batch_size=batch_size,
                                 pin_memory=(device.type == "cuda"), num_workers=0)
        test_label_counts = count_labels(test_loader)

        print(f"Test labels distribution: {test_label_counts}")

        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
        best_model_wts = torch.load(
            f'ckpts/{dataset_name}/{dataset_name}_model_fold_{fold + 1}.pt',
            map_location=device,
        )
        transformer_model.load_state_dict(best_model_wts['transformer'])
        classification_model.load_state_dict(best_model_wts['classification'])
        transformer_model.eval()
        classification_model.eval()

        all_test_predictions = []
        all_test_labels = []
        with torch.no_grad():
            for test_data_batch, test_label_batch in test_loader:
                test_data_batch, test_label_batch = test_data_batch.to(device), test_label_batch.to(device)
                test_transformer_output, _ = transformer_model(test_data_batch)
                test_predictions = classification_model(test_transformer_output)
                all_test_predictions.append(test_predictions.cpu().numpy())
                all_test_labels.append(test_label_batch.cpu().numpy())
        all_test_predictions = np.concatenate(all_test_predictions)
        all_test_labels = np.concatenate(all_test_labels)

        # Predicted class = index of the largest score per sample.
        test_pred_classes = np.argmax(all_test_predictions, axis=1)
        test_accuracy = accuracy_score(all_test_labels, test_pred_classes)
        test_f1 = f1_score(all_test_labels, test_pred_classes, average='macro')
        test_precision = precision_score(all_test_labels, test_pred_classes, average='macro')
        test_ACC.append(test_accuracy)
        test_F1.append(test_f1)
        test_PRE.append(test_precision)

        print(f"Test Accuracy: {test_accuracy:.4f}, Test F1 Score: {test_f1:.4f}, Test Precision Score: {test_precision:.4f}\n")

    acc_mean, acc_std = compute_mean_std(test_ACC)
    f1_mean, f1_std = compute_mean_std(test_F1)
    pre_mean, pre_std = compute_mean_std(test_PRE)

    print(f"ACC: {acc_mean}±{acc_std}")
    print(f"F1: {f1_mean}±{f1_std}")
    print(f"Pre: {pre_mean}±{pre_std}")

In [8]:
# Seed first: the stratified splitters inside train() are created without a
# random_state, so the evaluated subsets depend on RNG state at call time
# (setup_seed presumably seeds the NumPy/torch global RNGs — TODO confirm).
setup_seed(random_seed)
# Write input_data/data_x.npy and data_y.npy with 64 gene values per token.
data_produce(token_dim = 64)
# Despite the name, train() only evaluates the saved per-fold checkpoints.
train(conv_dim = 128)

Label Counts: label
3     2525
2     2326
5     1077
0      958
4      601
1      284
8      255
6      252
11     173
9       55
10      25
7       18
12      13
13       7
Name: count, dtype: int64
num_class: 14
Test labels distribution: Counter({3: 252, 2: 233, 5: 108, 0: 96, 4: 60, 1: 28, 6: 26, 8: 25, 11: 17, 9: 6, 7: 2, 10: 2, 12: 1, 13: 1})
Test Accuracy: 0.9848, Test F1 Score: 0.9679, Test Precision Score: 0.9623

Test labels distribution: Counter({3: 252, 2: 232, 5: 108, 0: 96, 4: 60, 1: 29, 6: 25, 8: 25, 11: 17, 9: 6, 12: 2, 10: 2, 7: 2, 13: 1})
Test Accuracy: 0.9848, Test F1 Score: 0.9890, Test Precision Score: 0.9892

Test labels distribution: Counter({3: 252, 2: 232, 5: 107, 0: 96, 4: 60, 1: 29, 8: 26, 6: 25, 11: 18, 9: 6, 10: 2, 7: 2, 12: 1, 13: 1})
Test Accuracy: 0.9848, Test F1 Score: 0.8825, Test Precision Score: 0.8866

Test labels distribution: Counter({3: 252, 2: 233, 5: 108, 0: 95, 4: 61, 1: 28, 8: 26, 6: 25, 11: 17, 9: 5, 10: 3, 7: 2, 12: 1, 13: 1})
Test Accuracy: