Classification with logistic regression on the balanced dataset.
 - This is a baseline against which to assess the performance of more complex models.
 - There is no mitigation of batch (donor) effects here; a next iteration could consider a binning strategy.

Overall, using a 10% subsample of the balanced dataset (`fraction=0.1`) and a couple of minutes to fit the model parameters, we achieve a test accuracy of ~0.84.

In [1]:
import numpy as np
import anndata as ad
import scanpy as sc
from sklearn.model_selection import train_test_split
from schelp.utils.config import load_config
from rich import print as rprint
import pandas as pd


paths = load_config(dataset_key="init")

def load_dataset_balance(paths, seed=0, fraction=0.1, n_top_genes=4000, test_size=0.2, stratify=False):
    """Load the balanced SEA-AD MTG dataset and split it into reference/query sets.

    Parameters
    ----------
    paths : mapping
        Config returned by ``load_config``; ``paths["data"]`` must be the data
        root that contains the ``Human-Brain/`` directory.
    seed : int
        Random seed used for both the cell subsampling and the train/test split.
    fraction : float
        Fraction of cells kept by ``sc.pp.subsample`` (default 0.1, i.e. 10%).
    n_top_genes : int
        Number of highly variable genes kept (seurat_v3 flavor, batches = donors).
    test_size : float
        Fraction of the subsampled cells held out as the query (test) set.
    stratify : bool
        If True, stratify the train/test split by ``celltype_id`` so both sets
        keep the same class proportions. Defaults to False (original behavior).

    Returns
    -------
    adata, adata_test : AnnData
        In-memory reference (train) and query (test) sets.
    celltypes : pandas.Categorical
        Observed celltype names ordered by category code, i.e. aligned with the
        integer labels stored in ``obs["celltype_id"]``.
    id2type : dict
        Map from ``celltype_id`` code to celltype name.
    """
    data_dir = str(paths["data"])
    adata_ = sc.read_h5ad(
        data_dir + '/Human-Brain/balanced_SEAAD_MTG_RNAseq_Singleome_final-nuclei.2024-06-18.h5ad'
    )
    adata_.obs["celltype"] = adata_.obs["Supertype"]
    adata_.obs["batch"] = adata_.obs["Donor ID"]

    # The integer labels are the category codes, so both the id->name map and
    # the returned celltype list must follow category order (not the order of
    # first appearance that ``.unique()`` yields) to line up with the labels.
    celltype_dtype = adata_.obs["celltype"].dtype
    categories = adata_.obs["celltype"].cat.categories
    id2type = dict(enumerate(categories))
    observed = set(adata_.obs["celltype"].unique())
    celltypes = pd.Categorical([c for c in categories if c in observed], dtype=celltype_dtype)
    adata_.obs["celltype_id"] = adata_.obs["celltype"].cat.codes.values
    adata_.obs["batch_id"] = adata_.obs["batch"].cat.codes.values
    adata_.var["gene_name"] = adata_.var.index.tolist()

    # get high variance genes
    # NOTE(review): HVGs are selected on all cells before the train/test split,
    # so the query set influences feature selection (mild leakage).
    sc.pp.highly_variable_genes(adata_, n_top_genes=n_top_genes, flavor="seurat_v3", batch_key="batch")
    adata_ = adata_[:, adata_.var["highly_variable"]]
    adata_ = sc.pp.subsample(adata_, fraction=fraction, copy=True, random_state=seed)
    n_cells = adata_.shape[0]

    # split the data into train and test
    strat_labels = adata_.obs["celltype_id"].values if stratify else None
    train_ind, test_ind = train_test_split(
        range(n_cells), test_size=test_size, random_state=seed, stratify=strat_labels
    )
    adata = adata_[train_ind].to_memory()
    adata_test = adata_[test_ind].to_memory()

    del adata_
    print(f"There are {adata.shape[0]} cells in the reference set")
    print(f"There are {adata_test.shape[0]} cells in the query set")

    return adata, adata_test, celltypes, id2type

In [2]:
# Build the reference (train) / query (test) splits; the fixed seed makes them reproducible.
adata, adata_test, celltypes, id2type = load_dataset_balance(paths, seed=0)



There are 11120 cells in the reference set
There are 2780 cells in the query set


In [3]:
# pipeline for logistic regression classifier
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.pipeline import Pipeline


def to_dense(X):
    """Return ``X`` as a dense NumPy array.

    ``AnnData.X`` may be a SciPy sparse matrix or already a dense array. The
    ``.A`` shortcut only ever existed on sparse matrices and was removed in
    SciPy 1.14, so use ``toarray()`` when available and ``np.asarray`` otherwise.
    """
    return X.toarray() if hasattr(X, "toarray") else np.asarray(X)


# extract dense feature matrices and integer celltype labels for the
# reference (train) and query (test) sets built by load_dataset_balance
X_train = to_dense(adata.X)
y_train = adata.obs["celltype_id"]
X_test = to_dense(adata_test.X)
y_test = adata_test.obs["celltype_id"]

# standardise each gene, then fit a logistic regression whose inverse
# regularisation strength C is chosen from a grid of 10 values by 5-fold CV
pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", LogisticRegressionCV(Cs=10, cv=5, max_iter=1000, n_jobs=-1)),
])

# fit the pipeline on the reference set
pipe.fit(X_train, y_train)

# predict the query set
y_pred = pipe.predict(X_test)

# print the accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"Overall test accuracy: {accuracy:0.3f}")

Overall test accuracy: 0.840


In [4]:
# Per-class precision / recall / F1 on the query set.
# Pass the label ids explicitly together with their names, both taken from
# id2type (code -> name), so every report row is attached to the right
# celltype regardless of how ``celltypes`` happens to be ordered.
label_ids = sorted(id2type)
label_names = [id2type[i] for i in label_ids]
report = classification_report(
    y_test, y_pred, labels=label_ids, target_names=label_names, output_dict=True
)

# round entries to 3 decimal places
df = pd.DataFrame(report).transpose().round(3)

# display full dataframe (note: these options are set globally for the session)
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
pd.set_option("display.width", None)
pd.set_option("display.max_colwidth", None)

# per-class rows in category order, followed by the averaged rows
# (reindexing on class names alone would silently drop the averages)
summary_rows = [row for row in ("macro avg", "weighted avg") if row in df.index]
df = df.reindex(label_names + summary_rows)


def color_negative_red(val):
    """Return a CSS color rule that highlights values below 0.5 in light red.

    Despite the name, this flags low scores rather than negative numbers;
    NaN compares False and is left black.
    """
    color = '#ff9999' if val < 0.5 else 'black'
    return 'color: %s' % color


# Styler.applymap is deprecated since pandas 2.1 in favour of Styler.map;
# fall back to applymap on older pandas versions.
styler = df.style
style_elementwise = styler.map if hasattr(styler, "map") else styler.applymap
display(style_elementwise(color_negative_red))

  display(df.round(3).style.applymap(color_negative_red))


Unnamed: 0,precision,recall,f1-score,support
Astro_1,0.615,0.533,0.571,15.0
Astro_2,0.833,0.789,0.811,19.0
Astro_3,0.778,0.824,0.8,17.0
Astro_4,1.0,1.0,1.0,24.0
Astro_5,0.962,0.926,0.943,27.0
Astro_6-SEAAD,0.893,1.0,0.943,25.0
Chandelier_1,0.611,0.688,0.647,16.0
Chandelier_2,0.778,0.824,0.8,17.0
Endo_1,0.737,0.824,0.778,17.0
Endo_2,0.952,0.952,0.952,21.0
