In [None]:
'''
Train per-tissue logistic-regression cell-type classifiers on Tabula Sapiens reference data embedded with STATE
'''

In [7]:
# Imports 
import importlib

import json
import logging
import os
import re

import pandas as pd
import numpy as np
import anndata as ad

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
import scanpy as sc

from scipy import sparse
from joblib import dump, load

import subprocess
from tqdm import tqdm

tqdm.pandas()

In [8]:
# Configure logging for the whole notebook.
# force=True removes any handlers already attached to the root logger (e.g. from
# an earlier run of this cell); without it basicConfig silently does nothing
# whenever the root logger already has a handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    handlers=[
        logging.StreamHandler()
    ],
    force=True,
)

In [9]:
# Directory holding the Tabula Sapiens (TSP) reference AnnData (.h5ad) files.
# Set the TSP_FILE_PATH environment variable to point at another copy.
TSP_FILE_PATH = os.environ.get(
    'TSP_FILE_PATH',
    '/large_storage/ctc/public/scBasecamp/tabula_sapiens/GeneFull_Ex50pAS',
)

In [10]:
# Parquet table of STATE embeddings for the TSP cells (read with fastparquet).
# Set the STATE_TSP_FILE_PATH environment variable to point at another copy.
STATE_TSP_FILE_PATH = os.environ.get(
    'STATE_TSP_FILE_PATH',
    '/large_storage/ctc/userspace/rohankshah/tabula_sapiens/tabula_sapiens_with_state.parquet',
)
state_embeds = pd.read_parquet(STATE_TSP_FILE_PATH, engine='fastparquet')

In [114]:
# Build the per-cell join keys used to match embeddings to the reference AnnData:
#   tissue     - second '_'-separated token of the 'dataset' name
#   barcode    - copy of the 'cell' column
#   identifier - 'dataset' with any '_S<digits>_' tag collapsed to '_' and a
#                trailing '_GeneFull_Ex50pAS' removed, then '_' + barcode
# Vectorised .str accessors replace the slower row-wise progress_apply.
state_embeds['tissue'] = state_embeds['dataset'].str.split('_').str[1]
state_embeds['barcode'] = state_embeds['cell'].values
state_embeds['identifier'] = (
    state_embeds['dataset']
    .str.replace(r'_S\d+_', '_', regex=True)
    .str.replace(r'_GeneFull_Ex50pAS$', '', regex=True)
) + '_' + state_embeds['barcode']

100%|██████████████████████████████| 833668/833668 [00:00<00:00, 1570497.90it/s]


In [132]:
np.random.seed(42)

# Matches one or more consecutive decimal digits.
_DIGIT_RUNS = re.compile(r'\d+')


def remove_integers_from_list(in_string):
    '''
    Return ``in_string`` with every run of decimal digits deleted,
    e.g. 'Lung2' -> 'Lung'.

    Despite its name, this takes and returns a single string, not a list.
    '''
    return _DIGIT_RUNS.sub('', in_string)

def get_tissues_list():
    '''
    Get the list of tissues in Tabula Sapiens.

    A tissue name is the second '_'-separated token of each entry in
    TSP_FILE_PATH, with digits removed and lower-cased, so 'Lung2' and
    'lung' both become 'lung'. Entries whose name has no '_' carry no
    tissue token and are skipped.

    Returns:
        list[str]: unique tissue names, sorted alphabetically. Sorting makes
        the processing order independent of os.listdir's arbitrary order.
    '''
    tissues = set()
    for file_name in os.listdir(TSP_FILE_PATH):
        parts = file_name.split('_')
        if len(parts) < 2:
            continue  # no tissue token in this name
        tissues.add(remove_integers_from_list(parts[1]).lower())
    return sorted(tissues)
    
def get_tsp_files(tissue=None):
    '''
    Get the list of TSP files for a given tissue.

    A file matches when the second '_'-separated token of its name, with
    digits removed and lower-cased, equals ``tissue`` (case-insensitive).
    'lymphnode' also matches files tagged 'lymphnodes', so both spellings
    are pooled into one tissue.

    Args:
        tissue: tissue name, e.g. as returned by get_tissues_list(). If None,
            every entry in TSP_FILE_PATH that has a tissue token is returned.

    Returns:
        list[str]: matching file names (not full paths), in os.listdir order.
    '''
    # Alternate spellings that are pooled under a single tissue name.
    aliases = {'lymphnode': {'lymphnode', 'lymphnodes'}}

    wanted = None
    if tissue is not None:
        key = tissue.lower()
        wanted = aliases.get(key, {key})

    tsp_filtered = []
    for filename in os.listdir(TSP_FILE_PATH):
        parts = filename.split('_')
        if len(parts) < 2:
            continue  # no tissue token in this name
        file_tissue = remove_integers_from_list(parts[1]).lower()
        if wanted is None or file_tissue in wanted:
            tsp_filtered.append(filename)
    return tsp_filtered

def get_ref_adata(file_name):
    '''
    Load one Tabula Sapiens reference file as an AnnData object.

    ``file_name`` is resolved relative to TSP_FILE_PATH (joined with '/').
    '''
    h5ad_path = f'{TSP_FILE_PATH}/{file_name}'
    return sc.read_h5ad(h5ad_path)

def merge_anndatas(file_list=None):
    '''
    Read all reference AnnData files and concatenate them into one object.

    Args:
        file_list: iterable of file names relative to TSP_FILE_PATH; each is
            loaded with get_ref_adata().

    Returns:
        AnnData: all files concatenated with ``ad.concat(..., merge='same')``
        (see anndata.concat for the merge semantics). An obs column named
        'Unnamed: 0', if present, is renamed to 'identifier'.

    Raises:
        ValueError: if ``file_list`` is None or empty (e.g. a tissue with no
            matching files). Iterating None or concatenating nothing would
            otherwise fail with a less informative error.
    '''
    file_names = list(file_list) if file_list is not None else []
    if not file_names:
        raise ValueError("merge_anndatas() needs at least one file name")

    adatas = [get_ref_adata(file_name) for file_name in file_names]
    ref_adata = ad.concat(adatas, merge='same')
    ref_adata.obs = ref_adata.obs.rename(columns={'Unnamed: 0': 'identifier'})
    return ref_adata

def remove_null_cts(adata, null_labels=('nan', 'unknown')):
    '''
    Remove cells whose cell type is missing or a placeholder label.

    Args:
        adata: AnnData-like object with a ``cell_type`` column in ``obs``.
        null_labels: cell-type strings treated as missing, matched exactly
            (case-sensitive). The default keeps the original 'nan'/'unknown'.

    Returns:
        A copy of ``adata`` restricted to cells with a real (non-null,
        non-placeholder) cell-type label.
    '''
    cell_type = adata.obs['cell_type']
    keep = cell_type.notna() & ~cell_type.isin(null_labels)
    # Positional mask: avoids index alignment if obs_names are not unique.
    return adata[keep.to_numpy()].copy()

def align_embeddings(ref_adata, state_embeds, identifier_col, embed_cols=slice(1, 2059)):
    '''
    Attach STATE embeddings to the matching cells of a reference AnnData.

    Cells are matched on ``identifier_col``, compared as strings.
    - A duplicated reference identifier is kept once (first occurrence).
    - Reference cells without an embedding are dropped, with a warning.
    - Duplicated embedding keys keep their first row, with a warning.
    Neither input object is modified.

    Args:
        ref_adata: AnnData whose ``obs`` contains ``identifier_col``.
        state_embeds: DataFrame with an ``identifier_col`` column plus the
            embedding columns.
        identifier_col: name of the join-key column in both inputs.
        embed_cols: positional selector applied with ``.iloc`` after
            ``identifier_col`` has been moved to the index. The default
            ``slice(1, 2059)`` is the original hard-coded range; it assumes
            the embedding occupies those column positions — TODO confirm
            against the parquet schema.

    Returns:
        A new AnnData restricted to the matched cells, in reference order.
        ``obs[identifier_col]`` is cast to str and the embeddings are stored
        in ``obsm['X_state']``.

    Raises:
        ValueError: if any reference identifier is missing (NaN/None or the
            string 'nan'). Returning None here would only defer the failure
            to the caller's first use of the result.
    '''
    ref_ids = ref_adata.obs[identifier_col]
    ref_ids_str = ref_ids.astype(str)
    if ref_ids.isna().any() or (ref_ids_str == 'nan').any():
        logging.error("Null identifiers exist in the reference data")
        raise ValueError("Null identifiers exist in the reference data")

    # Keep the first occurrence of each identifier. Positional numpy masks are
    # used throughout because obs_names are not guaranteed to be unique.
    first_seen = ~ref_ids_str.duplicated(keep='first').to_numpy()
    aligned = ref_adata[first_seen].copy()
    aligned.obs[identifier_col] = ref_ids_str.to_numpy()[first_seen]

    embed_ids = state_embeds[identifier_col].astype(str)
    common_ids = set(aligned.obs[identifier_col]).intersection(embed_ids)

    if aligned.n_obs != len(common_ids):
        logging.warning("Cells in reference data not included in common ids")

    aligned = aligned[aligned.obs[identifier_col].isin(common_ids).to_numpy()].copy()

    # Copy only the matching embedding rows, then store their string key on
    # the copy so the caller's DataFrame is left untouched.
    in_common = embed_ids.isin(common_ids).to_numpy()
    embeds = state_embeds.loc[in_common].copy()
    embeds[identifier_col] = embed_ids.to_numpy()[in_common]

    # A repeated key would make .loc return several rows for one cell and
    # break the one-row-per-cell alignment below.
    repeated = embeds.duplicated(subset=identifier_col, keep='first').to_numpy()
    if repeated.any():
        logging.warning(f"Dropping {int(repeated.sum())} duplicated embedding rows")
        embeds = embeds.loc[~repeated]

    embeds = embeds.set_index(identifier_col).loc[aligned.obs[identifier_col].to_numpy()]
    X_embed = embeds.iloc[:, embed_cols].to_numpy()

    assert X_embed.shape[0] == aligned.n_obs, "Mismatch in number of cells and embeddings"

    aligned.obsm['X_state'] = X_embed
    return aligned

In [None]:
# Train one classifier per tissue. For each tissue: build (or reuse) the
# reference AnnData carrying STATE embeddings in obsm['X_state'], then fit a
# StandardScaler + LogisticRegression pipeline on obs['cell_type'].
tissues = get_tissues_list()
output_model_dir = "output/classifiers"
output_adata_dir = "output/adatas_embedded"
# When True, retrain classifiers even if a saved model exists. Embedded AnnData
# files are still reused if they are already on disk.
ignore_prev = False
# 'lymphnodes' files are pooled into 'lymphnode' by get_tsp_files(), so the
# plural is skipped. The reasons for the other entries are not recorded here.
to_skip_list = ['endopancreas', 'exopancreas', 'lymphnodes', 'wholeblood']
# Tissues joined to the embeddings on 'barcode' rather than 'identifier'
# (presumably their dataset-derived identifiers do not match — TODO confirm).
barcode_aligned = ['eye', 'mammary', 'myometrium', 'endometrium']

# os.listdir/os.path checks and the writes below need these to exist.
os.makedirs(output_model_dir, exist_ok=True)
os.makedirs(output_adata_dir, exist_ok=True)

for tissue in tqdm(tissues):
    if tissue in to_skip_list:
        logging.info(f"{tissue} in the skip list, continuing to the next tissue")
        continue

    adata_path = os.path.join(output_adata_dir, f'TSP_{tissue}.h5ad')
    model_path = os.path.join(output_model_dir, f'{tissue}_ref_model_logreg.joblib')
    # Set when the embedded AnnData is built in this iteration, so training can
    # reuse it instead of reading it back from disk.
    ref_adata_embedded = None

    if not os.path.exists(adata_path):
        logging.info(f"Processing {tissue} TSP data...")
        files_list = get_tsp_files(tissue)
        ref_adata = merge_anndatas(files_list)
        logging.info("Merged anndatas")

        ref_adata_processed = remove_null_cts(ref_adata)
        logging.info("Removed null cell types")

        if tissue in barcode_aligned:
            # Drop every barcode that occurs more than once on either side
            # (keep=False) so the barcode join is one-to-one.
            mask = ~ref_adata_processed.obs.duplicated(subset=['barcode'], keep=False)
            ref_adata_processed = ref_adata_processed[mask.to_numpy()].copy()
            tissue_rows = state_embeds['identifier'].str.contains(tissue, case=False, regex=False)
            embeds = state_embeds[tissue_rows].drop_duplicates(subset=['barcode'], keep=False)
            ref_adata_embedded = align_embeddings(ref_adata_processed, embeds, identifier_col='barcode')
            logging.info(f'{tissue} data shape {ref_adata_embedded.shape}')
        else:
            ref_adata_embedded = align_embeddings(ref_adata_processed, state_embeds, identifier_col='identifier')
        logging.info("Aligned STATE embeddings with reference anndata")

        ref_adata_embedded.write(adata_path)
        logging.info(f"Successfully wrote anndata to {output_adata_dir}")
    else:
        logging.info(f"{tissue} already has embeddings")

    if ignore_prev or not os.path.exists(model_path):
        logging.info("Training logistic regression classifier...")

        pipeline = Pipeline([
            ("scaler", StandardScaler()),
            ("logreg", LogisticRegression(max_iter=1000)),
        ])

        if ref_adata_embedded is None:
            ref_adata_embedded = sc.read_h5ad(adata_path)

        embeddings = ref_adata_embedded.obsm["X_state"]
        labels = ref_adata_embedded.obs["cell_type"]
        pipeline.fit(embeddings, labels)
        logging.info("Finished training logistic regression classifier")

        # Save pipeline
        dump(pipeline, model_path)
        logging.info(f"Successfully saved model to {output_model_dir}")
    else:
        logging.info(f"Logistic regression classifier for {tissue} tissue already exists")

    logging.info(f"All steps completed for {tissue}")
