In [None]:
'''
Train a logistic-regression cell-type classifier on STATE embeddings (`X_state`) of the Brain Cell Atlas
'''

In [4]:
# Imports 
import importlib

import json
import logging
import os
import re

import pandas as pd
import numpy as np
import anndata as ad

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc

from scipy import sparse
from joblib import dump, load

import subprocess
from tqdm import tqdm

import psutil
import pyarrow as pa
import pyarrow.parquet as pq
import random

# Register tqdm's progress-bar variants (`progress_apply`, `progress_map`, ...) on pandas objects.
# NOTE(review): no visible cell calls a `progress_*` method — confirm this is still needed.
tqdm.pandas()

In [1]:
# Directory whose entries are read one by one with `sc.read_h5ad` in the load cell.
# It can be overridden through the BRAIN_CELL_ATLAS_DIR environment variable so the
# notebook also runs where the default location does not exist; when the variable
# is unset the original hard-coded path is used.
BRAIN_PATH = os.environ.get(
    "BRAIN_CELL_ATLAS_DIR", "/large_storage/ctc/ML/brain_cell_atlas"
)

In [4]:
# Load every .h5ad file under BRAIN_PATH and keep only the cells that carry a
# usable `cell_type` annotation. The filtered objects are collected in `adatas`.
adatas = []
# os.listdir returns entries in arbitrary order, so sort to make the later
# concatenation order reproducible; skip anything that is not an .h5ad file so
# a stray entry in the directory cannot crash the read.
brain_adata_files = sorted(
    name for name in os.listdir(BRAIN_PATH) if name.endswith('.h5ad')
)
for file in tqdm(brain_adata_files):
    adata = sc.read_h5ad(os.path.join(BRAIN_PATH, file))
    cell_type = adata.obs['cell_type']
    # Drop missing labels, the literal string 'nan', and the catch-all 'other'.
    keep = cell_type.notna() & ~cell_type.isin(['nan', 'other'])
    # .copy() materialises the subset instead of holding a view of the full object.
    adata = adata[keep].copy()
    adatas.append(adata)

100%|█████████████████████████████████████████| 443/443 [04:26<00:00,  1.66it/s]


In [9]:
# Number of cells per `cell_type` label in the combined dataset.
# NOTE(review): `total_adata` is only assigned in cells that appear further down
# (the concat cell and the read_h5ad cell), so on a fresh Restart & Run All this
# cell raises NameError — it should be moved below the cell that loads `total_adata`.
total_adata.obs['cell_type'].value_counts()

cell_type
neuron                                    1202604
oligodendrocyte                            238107
astrocyte                                   83385
oligodendrocyte precursor cell              63844
central nervous system macrophage           48195
Bergmann glial cell                          6331
choroid plexus epithelial cell               5003
fibroblast                                   4050
endothelial cell                             3021
ependymal cell                               2539
leukocyte                                    1914
pericyte                                     1859
vascular associated smooth muscle cell        349
Name: count, dtype: int64

In [None]:
# Merge the per-file AnnData objects into one and persist the result to disk.
# NOTE(review): `output_adata_dir` is not assigned in any visible cell; it has to
# be defined before this cell runs, otherwise a NameError is raised — TODO add it
# to a configuration cell.
# Create the target directory first so the (slow) concat is not wasted on a
# write that fails because the folder is missing.
os.makedirs(output_adata_dir, exist_ok=True)
total_adata = sc.concat(adatas)
total_adata.write(output_adata_dir + '/TSP_brain.h5ad')

In [6]:
# Reload the dataset from the embedded-output folder; the training cell reads
# `obsm["X_state"]` from this object.
# NOTE(review): presumably this file is the concatenated atlas after an external
# STATE embedding step — TODO confirm, since no visible cell computes `X_state`.
total_adata = sc.read_h5ad('output/adatas_embedded/TSP_brain.h5ad')

In [None]:
# Train the logistic regression classifier
output_model_dir = "output/classifiers"
# Make sure the destination exists *before* the long fit: joblib.dump does not
# create missing folders, and failing at save time would discard the trained model.
os.makedirs(output_model_dir, exist_ok=True)

# Standardise the features, then fit logistic regression (at most 1000 solver
# iterations, with solver progress printed).
pipeline = Pipeline([
    ("scaler", StandardScaler()),
    ("logreg", LogisticRegression(max_iter=1000, verbose=1)),
])
# Features are the stored `X_state` embeddings; targets are the `cell_type` labels.
embeddings, labels = total_adata.obsm["X_state"], total_adata.obs["cell_type"]
pipeline.fit(embeddings, labels)

# Save pipeline
dump(pipeline, output_model_dir + '/brain_ref_model_logreg.joblib')
print(f"Successfully saved model to {output_model_dir}")