In [1]:
import os
import rpy2
import logging
import warnings
import anndata2ri
import pandas as pd
import scanpy as sc
import anndata as ad
import numpy as np
import rpy2.robjects as robjects
from rpy2.robjects import pandas2ri
from matplotlib.pyplot import rcParams

In [2]:
# Notebook-wide configuration: logging/warning filters, R <-> Python conversion, scanpy verbosity.

# Only log R messages relayed through rpy2 at ERROR level or above.
# Note: this can be commented out to get more verbose R output
rpy2.rinterface_lib.callbacks.logger.setLevel(logging.ERROR)

# Hide noisy Python warnings for the whole session.
warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Automatically convert rpy2 outputs to pandas dataframes (and AnnData objects via anndata2ri),
# then load the rpy2 IPython extension.
pandas2ri.activate()
anndata2ri.activate()
%load_ext rpy2.ipython

# Optional figure overrides, currently disabled (get_sys_dpi is not defined in the visible cells):
# rcParams['figure.dpi'] = get_sys_dpi(1512, 982, 14.125)
# rcParams['figure.figsize']=(4,4) #rescale figures

sc.settings.verbosity = 3
# sc.set_figure_params(dpi=200, dpi_save=300)
# Print the versions of the loaded packages so the environment is recorded with the notebook.
sc.logging.print_versions()



-----
anndata     0.8.0
scanpy      1.9.3
-----
PIL                         9.5.0
anndata2ri                  1.1
appnope                     0.1.3
asttokens                   NA
backcall                    0.2.0
cffi                        1.15.1
comm                        0.1.3
cycler                      0.10.0
cython_runtime              NA
dateutil                    2.8.2
debugpy                     1.6.7
decorator                   5.1.1
executing                   1.2.0
google                      NA
h5py                        3.9.0
igraph                      0.10.4
ipykernel                   6.23.2
ipywidgets                  8.0.6
jedi                        0.18.2
jinja2                      3.1.2
joblib                      1.2.0
kiwisolver                  1.4.4
leidenalg                   0.9.1
llvmlite                    0.39.1
louvain                     0.8.0
markupsafe                  2.1.3
matplotlib                  3.7.1
mpl_toolkits                NA
natsort 

# **Load input data in `.h5ad` format and map cells to the cell class annotations provided by [MapMyCells](https://portal.brain-map.org/atlases-and-data/bkp/mapmycells)**

In [5]:
def convert_columns_to_string(sce, obs_cols=None, var_cols=None):
    """
    Cast selected columns of an AnnData object's .obs and .var tables to ``str``, in place.

    Parameters:
    sce (anndata.AnnData): The single-cell AnnData object whose tables are modified.
    obs_cols (list of str): Names of the sce.obs columns to cast; None leaves sce.obs untouched.
    var_cols (list of str): Names of the sce.var columns to cast; None leaves sce.var untouched.
    """
    # Handle .obs first, then .var, each with its own (possibly absent) column list.
    for table_name, wanted in (("obs", obs_cols), ("var", var_cols)):
        if wanted is None:
            continue
        for column in wanted:
            # Look the table up again for every column rather than caching it.
            table = getattr(sce, table_name)
            table[column] = table[column].astype(str)

def save_anndata(sce, file_path):
    """
    Write an AnnData object to disk via its ``write_h5ad`` method.

    Parameters:
    sce (anndata.AnnData): The single-cell AnnData object to write out.
    file_path (str): Destination path handed to ``write_h5ad`` unchanged.
    """
    destination = file_path
    sce.write_h5ad(destination)


# Function to filter and update AnnData object based on cell barcode annotations
def filter_and_update_anndata(ad, annot):
    """
    Subset an AnnData object to the annotated cells and attach the annotations to ``.obs``.

    Annotation rows whose barcode is not present in ``ad.obs_names`` are dropped before
    subsetting; indexing the object with an unknown barcode would otherwise raise a KeyError.

    Parameters:
    ad (anndata.AnnData): Object to subset. Inside this function the name shadows the
        module-level ``anndata as ad`` alias; it is kept so existing callers keep working.
    annot (pandas.DataFrame): Annotation table with a 'cell_barcode' column.

    Returns:
    A copy of ``ad`` restricted to the annotated barcodes (in annotation order), whose
    ``.obs`` holds the original columns merged with the columns of ``annot`` and whose
    ``obs_names`` are the cell barcodes.
    """
    # Filter valid cell barcodes: keep only annotations for cells that exist in `ad`.
    annot_valid = annot[annot['cell_barcode'].isin(ad.obs_names)]
    ad_filtered = ad[annot_valid['cell_barcode'].to_list()].copy()

    # Merge annotations. The merge result carries a fresh positional index rather than
    # the barcodes, which is why obs_names are restored from the column right after.
    ad_filtered.obs = ad_filtered.obs.merge(
        annot_valid,
        left_on=ad_filtered.obs_names,
        right_on='cell_barcode',
        how='right',
    )

    # Update obs_names with cell barcodes
    ad_filtered.obs_names = ad_filtered.obs['cell_barcode']

    return ad_filtered




def _prefix_mapping_labels(mapping, cell_type):
    """Replace the first '_'-separated token of every label with the upper-cased 3-letter cell-type prefix."""
    prefix = cell_type[:3].upper()
    for col in ('class_label', 'subclass_label', 'supertype_label'):
        mapping[col] = mapping[col].apply(lambda x: '_'.join([prefix] + x.split('_')[1:]))
    return mapping


def map_cell_types(query, celltypes):
    """
    Merge cell type annotations from SEA-AD for a given query data.

    For every cell type, all CSV files in ``../data/raw/{query}/anndata/cell_type_mapping/``
    whose name contains the cell type are read and concatenated, their class / subclass /
    supertype labels are re-prefixed with the first three letters of the cell type, and the
    result is left-merged onto ``.obs`` of ``{cell_type}_raw_anndata.h5ad`` by matching
    ``obs_names`` against the ``cell_id`` column. The .h5ad file is then overwritten in place.

    Parameters:
    - query (str): The query data; name of the dataset folder under ``../data/raw/``.
    - celltypes (list): List of cell types.

    Returns:
    - None

    Raises:
    - FileNotFoundError: if the ``cell_type_mapping/`` directory itself does not exist.
      Missing per-cell-type files are logged and skipped instead.
    """

    dat_dir = f'../data/raw/{query}/anndata/'
    mapping_dir = dat_dir + 'cell_type_mapping/'
    files = os.listdir(mapping_dir)

    for cell_type in celltypes:

        # Substring match: a cell type may be spread over several mapping files.
        matched = [f for f in files if cell_type in f]
        if not matched:
            # Without any mapping there is no 'cell_id' column to merge on; skip this type.
            logging.warning("No cell type mapping file found for '%s' in %s; skipping.", cell_type, mapping_dir)
            continue

        try:
            # skiprows=3 drops the first three lines of each CSV
            # (presumably a MapMyCells metadata header -- TODO confirm against the files).
            mapping = pd.concat(
                [
                    _prefix_mapping_labels(pd.read_csv(mapping_dir + file, skiprows=3), cell_type)
                    for file in matched
                ],
                axis=0,
            )

            h5ad_path = dat_dir + f'/{cell_type}_raw_anndata.h5ad'
            adata = sc.read_h5ad(h5ad_path)

            # DataFrame.merge returns a fresh positional index; restore the cell barcodes so
            # that assigning .obs does not replace obs_names with 0..n-1.
            obs_names = adata.obs_names
            merged = adata.obs.merge(mapping, left_on=obs_names, right_on='cell_id', how='left')
            merged.index = obs_names
            adata.obs = merged

            # NOTE(review): the raw file is overwritten in place, so running this twice merges
            # the mapping columns a second time (pandas then suffixes them with _x / _y).
            adata.write_h5ad(h5ad_path, compression='gzip')
        except FileNotFoundError as err:
            logging.warning("Skipping cell type '%s': %s", cell_type, err)
            continue
        

In [6]:
# Cell types to annotate; each name is matched against the mapping file names and used
# to locate the corresponding `<cell_type>_raw_anndata.h5ad` file.
celltypes = [
    'excitatory',
    'inhibitory',
    'microglia',
    'astrocyte',
    'oligodendrocyte',
    'endothelial',
    'opc',
]
# Dataset folder under ../data/raw/ that holds the files to annotate.
target = 'mathys_pfc'

map_cell_types(query=target, celltypes=celltypes)