# 7_export_thalamus_data_for_cirrocumulus

In [7]:
%load_ext autoreload
%autoreload 2

In [8]:
import numpy as np
import pandas as pd
import anndata as ad

import scanpy as sc

import matplotlib.pyplot as plt
import seaborn as sns
import colorcet as cc

from thalamus_merfish_analysis import abc_load as abc
from thalamus_merfish_analysis import ccf_plots as cplots
# from thalamus_merfish_analysis import abc_to_cirro as a2c

get_ipython().run_line_magic('matplotlib', 'inline')

## Load abc atlas data

In [9]:
# load custom adata with required metadata for cirrocumulus compatibility
# Load the custom thalamus AnnData carrying the metadata cirrocumulus needs.
adata_th_zi = abc.load_adata_thalamus(
    flip_y=True,        # invert y so sections display in the correct coronal orientation
    with_colors=False,  # per-cell colors are not needed; palettes come from a dict
)

# Filter cells by taxonomy class.
adata_th_zi_neurons = abc.filter_by_class_thalamus(
    adata_th_zi,
    filter_nonneuronal=True,
    filter_midbrain=False,
    filter_other_nonTH=True,
)

# Filter cells by thalamus coordinates (no extra buffer).
adata_th_zi_neurons = abc.filter_by_thalamus_coords(adata_th_zi_neurons, buffer=0)

In [10]:
adata_th_zi_neurons

In [12]:
# which columns in obs to use for plotting
# Pick the obs columns used for plotting: realigned labels pair with the
# 'section' coordinates, the default labels with the 'reconstructed' ones.
realigned = False
ccf_label, coords = (
    ('parcellation_substructure_realigned', 'section')
    if realigned
    else ('parcellation_substructure', 'reconstructed')
)

x_col = f'x_{coords}'
y_col = f'y_{coords}'
section_col = f'z_{coords}'

# every distinct section value present in the filtered data, ascending
sections_all = sorted(adata_th_zi_neurons.obs[section_col].unique())

In [49]:
def add_montage_coords(adata,
                       section_col='z_section',
                       x_col='x_section',
                       y_col='y_section',
                       new_coord_suffix='_cirro',
                       n_cols=4):
    ''' Creates 2D montage spatial coordinates from 2D+section coordinates and
    adds them to adata.obs.

    Montage spatial coordinates are used for simultaneous display of multiple
    sections on a 2D screen, e.g. in cirrocumulus. Sections are laid out on a
    grid, left to right and top to bottom, starting with the largest value of
    `section_col` in the top left.

    Parameters
    ----------
    adata : AnnData object
        AnnData object with spatial data. Its .obs is modified in place.
    section_col : str
        Column in adata.obs with section numbers.
    x_col, y_col : str
        Columns in adata.obs with x,y coordinates.
    new_coord_suffix : str
        Suffix to add to new x,y columns ('x'+suffix, 'y'+suffix)
    n_cols : int, default=4
        Number of columns to display sections in

    Returns
    -------
    adata : AnnData object
        the same object that was passed in, with x & y montage spatial
        coordinates added to adata.obs

    Raises
    ------
    ValueError
        If n_cols is smaller than 1.
    '''
    if n_cols < 1:
        raise ValueError(f"n_cols must be a positive integer, got {n_cols!r}")

    new_x = 'x' + new_coord_suffix
    new_y = 'y' + new_coord_suffix

    # NOTE: adata.obs is looked up afresh for every access on purpose, rather
    # than cached in a local, so that writes always land on the current
    # obs DataFrame of `adata`.

    # start from a copy of the original xy coords
    adata.obs[new_x] = adata.obs[x_col].copy()
    adata.obs[new_y] = adata.obs[y_col].copy()

    # Grid spacing is derived from the overall coordinate extent, so it adapts
    # to whatever units x_col / y_col are in.
    width_max = abs(adata.obs[x_col].max() - adata.obs[x_col].min())
    height_max = abs(adata.obs[y_col].max() - adata.obs[y_col].min())
    x_shift = width_max * 1.2
    y_shift = -(height_max * 1.5)  # neg to start each new row below previous row

    sections = sorted(adata.obs[section_col].unique())
    # walk the sections from the largest section value to the smallest
    for i, sec in enumerate(reversed(sections)):
        # grid cell of this section: the column wraps every n_cols sections,
        # the row advances once per completed row (first row/col are unshifted)
        row, col = divmod(i, n_cols)

        mask = adata.obs[section_col] == sec
        adata.obs.loc[mask, new_x] += x_shift * col
        adata.obs.loc[mask, new_y] += y_shift * row

    return adata

In [50]:
def add_coords_to_obsm(adata):
    ''' Copy cirro and CCF spatial coordinates into adata.obsm, where
    cirrocumulus expects to find them.

    3D CCF coords should be in .obs already from loading the ABC Atlas data.
    2D cirro coords should have been added to .obs with add_montage_coords().

    Parameters
    ----------
    adata : AnnData object
        AnnData object whose .obs may contain ['x_cirro','y_cirro'] and/or
        ['x_ccf','y_ccf','z_ccf']. Its .obsm is modified in place.

    Returns
    -------
    adata : AnnData object
        the same object, with 'cirro_spatial' and/or 'ccf_spatial_3d' added to
        adata.obsm for whichever coordinate sets were found.

    Warns
    -----
    UserWarning
        Once for each coordinate set that is missing from adata.obs; the
        corresponding .obsm entry is then left untouched.
    '''
    import warnings  # local import keeps this notebook cell self-contained

    if {'x_cirro','y_cirro'}.issubset(adata.obs.columns):
        adata.obsm['cirro_spatial'] = adata.obs[['x_cirro','y_cirro']].to_numpy()
    else:
        # actually emit the warning; merely constructing UserWarning(...) does nothing
        warnings.warn("No cirrocumulus spatial coordinates, ['x_cirro','y_cirro'], found in adata.obs. Run add_montage_coords() first.",
                      UserWarning, stacklevel=2)

    if {'x_ccf','y_ccf','z_ccf'}.issubset(adata.obs.columns):
        adata.obsm['ccf_spatial_3d'] = adata.obs[['x_ccf','y_ccf','z_ccf']].to_numpy()
    else:
        warnings.warn("No CCF spatial coordinates, ['x_ccf','y_ccf','z_ccf'], found in adata.obs.",
                      UserWarning, stacklevel=2)

    return adata

In [51]:
# Lay the sections out as a 4-column montage. The helper defined earlier in
# this notebook is add_montage_coords(); with its default suffix it writes
# 'x_cirro' / 'y_cirro' into .obs.
adata_th_zi_neurons = add_montage_coords(adata_th_zi_neurons,
                                         section_col=section_col,
                                         x_col=x_col,
                                         y_col=y_col,
                                         n_cols=4)

# copy the montage (and CCF) coordinates into .obsm for cirrocumulus
adata_th_zi_neurons = add_coords_to_obsm(adata_th_zi_neurons)

In [52]:
adata_th_zi_neurons

In [53]:
# quick visual check of the montage layout: one tiny marker per cell at its
# shifted x/y position
plt.scatter(adata_th_zi_neurons.obs['x_cirro'], adata_th_zi_neurons.obs['y_cirro'], s=0.001)
# equal axis scaling so the sections are not stretched
plt.axis('equal')

## Make color lists for .uns

.uns\["cluster_colors"\] should be a list of colors as hex strings (e.g. #D1C9BA) in the order of the .obs.cluster.cat.categories

and same for .uns\["subclass_colors"\] matching the categories of .obs.subclass

In [71]:
def add_colors_to_uns(adata):
    ''' Store an ABC color palette dict in adata.uns for every taxonomy level
    found in adata.obs.

    Cirrocumulus expects to find, in adata.uns, a dict that maps the categories
    of an adata.obs column to hex string colors. The dict keys MUST follow the
    order returned by adata.obs.my_column.cat.categories.
    e.g. to have cirro use custom colors for adata.obs['cluster'], the color
    dict containing {category: color} should be stored in
    adata.uns['cluster_colors'].

    Parameters
    ----------
    adata : AnnData object
        AnnData object with the taxonomy levels 'class', 'subclass',
        'supertype' and 'cluster' as categorical columns in adata.obs.

    Returns
    -------
    adata : AnnData object
        the same object, with a '<level>_colors' dict added to adata.uns for
        each taxonomy level. Categories that have no entry in the ABC palette
        are left out of the dict.
    '''

    taxonomy_levels = ['class', 'subclass', 'supertype', 'cluster']
    assert set(taxonomy_levels).issubset(adata.obs.columns), f"adata.obs.columns is missing at least one of: {taxonomy_levels}"

    for level in taxonomy_levels:
        # complete ABC palette for this taxonomy level
        full_palette = abc.get_taxonomy_palette(level)

        # categories present in this dataset, in .cat.categories order
        # (dicts preserve insertion order, so the palette keeps that order)
        present_cats = adata.obs[level].cat.categories

        adata.uns[level+'_colors'] = {
            cat: full_palette[cat]
            for cat in present_cats
            if cat in full_palette
        }

    return adata

In [72]:
# attach one color dict per taxonomy level to .uns, then display the AnnData summary
adata_th_zi_neurons = add_colors_to_uns(adata_th_zi_neurons)
adata_th_zi_neurons

In [74]:
# spot-check one of the stored palettes
adata_th_zi_neurons.uns['class_colors']

## Generate UMAP, tSNE, etc.

In [76]:
# PCA pre-processing: compute the PCA with scanpy defaults and plot it
sc.pp.pca(adata_th_zi_neurons)
sc.pl.pca(adata_th_zi_neurons)

In [77]:
# more pre-processing: neighbor graph (scanpy defaults), computed before the UMAP below
sc.pp.neighbors(adata_th_zi_neurons)

In [79]:
# UMAP embedding + plot
sc.tl.umap(adata_th_zi_neurons)
sc.pl.umap(adata_th_zi_neurons)

In [83]:
adata_th_zi_neurons

In [84]:
# tSNE - takes much longer than UMAP to run
sc.tl.tsne(adata_th_zi_neurons)
sc.pl.tsne(adata_th_zi_neurons)

In [91]:
adata_th_zi_neurons

In [85]:
# Drop the intermediate PCA / neighbor-graph results so that only the UMAP and
# tSNE embeddings remain from this section.
# NOTE(review): presumably done to keep the exported file small -- confirm.
# Each pop() raises KeyError if its key is absent, so this cell cannot be
# re-run without recomputing PCA + neighbors first.
adata_th_zi_neurons.obsm.pop('X_pca')
adata_th_zi_neurons.uns.pop('pca')
adata_th_zi_neurons.uns.pop('neighbors')
adata_th_zi_neurons.obsp.pop('connectivities')
adata_th_zi_neurons.obsp.pop('distances')
adata_th_zi_neurons.varm.pop('PCs')

In [92]:
adata_th_zi_neurons

## Load & save SpaGCN domains results

In [104]:
# Load precomputed SpaGCN domain results from a parquet file.
# Temporarily a static file under '/code/resources' until a reproducible run
# is set up for the spagcn capsule.
spagcn_df = pd.read_parquet('/code/resources/spagcn_predicated_domains.parquet')
spagcn_df

In [119]:
# The steps below are ordered so the notebook runs top to bottom: the column
# must be renamed before anything refers to it as 'SpaGCN_domains', and the
# labels are joined onto .obs exactly once.

# Name the SpaGCN result column that gets exported.
# ('res1pt4' is presumably the resolution=1.4 run -- TODO confirm)
spagcn_df.rename(columns={'res1pt4':'SpaGCN_domains'}, inplace=True)

# Store every SpaGCN column as an unordered categorical with sorted categories
for col in spagcn_df.columns:
    spagcn_df[col] = pd.Categorical(spagcn_df[col],
                                    categories=sorted(spagcn_df[col].unique()),
                                    ordered=False)

# Attach the domain labels to the cells, matching obs 'cell_label' against the
# index of spagcn_df; cells without a match get NaN
adata_th_zi_neurons.obs = adata_th_zi_neurons.obs.join(spagcn_df['SpaGCN_domains'], on='cell_label')
adata_th_zi_neurons

In [None]:

# NOTE(review): `adata_log2` is not defined anywhere in the visible part of this
# notebook and this cell was never executed (In [None]); the cell below applies
# the same operation to adata_th_zi_neurons.
adata_log2.obs['SpaGCN_domains'] = adata_log2.obs['SpaGCN_domains'].cat.add_categories('no data').fillna('no data')

adata_log2.obs['SpaGCN_domains']

In [115]:
# Give cells without a SpaGCN label an explicit 'no data' category instead of NaN.
# add_categories() raises if 'no data' already exists, so this cell cannot be re-run as is.
adata_th_zi_neurons.obs['SpaGCN_domains'] = adata_th_zi_neurons.obs['SpaGCN_domains'].cat.add_categories('no data').fillna('no data')
adata_th_zi_neurons.obs['SpaGCN_domains']

In [116]:
# one glasbey color per SpaGCN domain category, in category order
spg_domain_cats = adata_th_zi_neurons.obs['SpaGCN_domains'].cat.categories
print(f'{spg_domain_cats=}')
spg_palette_sns = sns.color_palette(cc.glasbey, n_colors=len(spg_domain_cats))

# paint the last category white -- that is where 'no data' sits after being
# appended to the categories -- so it doesn't show up in cirro
spg_palette_sns[-1] = (1.0, 1.0, 1.0)

# RGB tuples keyed by category: used by seaborn to check the colors
palette_dict_sns = {cat: rgb for cat, rgb in zip(spg_domain_cats, spg_palette_sns)}

# the same colors as hex strings: the form stored for cirro
spg_palette_cirro = [*spg_palette_sns.as_hex()]
print(spg_palette_cirro)

In [117]:
# visual check of the SpaGCN palette: montage scatter colored by domain
fig = plt.figure(figsize=(20,15))
ax = fig.gca()
sns.scatterplot(adata_th_zi_neurons.obs, ax=ax, x='x_cirro', y='y_cirro', 
                hue='SpaGCN_domains', s=10, palette=palette_dict_sns, 
                linewidth=0, legend=False)
# equal axis scaling so the sections are not stretched
plt.axis('equal')

In [118]:
adata_th_zi_neurons.obs['SpaGCN_domains']

In [29]:
# Palettes are looked up in .uns under '<obs column>_colors' (plural), so store
# the SpaGCN palette under the key matching the 'SpaGCN_domains' obs column.
adata_log2.uns['SpaGCN_domains_colors'] = spg_palette_cirro
# Keep the original singular key as well: a later cell reads
# uns['SpaGCN_domains_color'] when building 'spagcn_colors'.
adata_log2.uns['SpaGCN_domains_color'] = spg_palette_cirro

## Load & save NSF results

In [30]:
adata_nsf = ad.read_zarr("/root/capsule/data/nsf_2000_adata/nsf_2000_adata.zarr")

In [31]:
# NSF columns to carry over: the total followed by the 30 numbered
# columns 'nsf0' ... 'nsf29'
nsf_cols = ['nsf_tot'] + [f'nsf{i}' for i in range(30)]
# pull just the NSF columns out of the NSF AnnData's obs
nsf_df = adata_nsf.obs[nsf_cols].copy()
nsf_df

In [32]:
# Attach the NSF columns to the cells, matching obs 'cell_label' against the
# index of nsf_df.
# NOTE(review): `adata_log2` is not defined in the visible part of this notebook.
adata_log2.obs = adata_log2.obs.join(nsf_df, on='cell_label')
adata_log2.obs.head(3)

In [33]:
# cells without NSF results get 0 in every NSF column
adata_log2.obs[nsf_cols] = adata_log2.obs[nsf_cols].fillna(0)

In [34]:
adata_log2.obs.head(5)

## Clean up obs

In [35]:
# obs columns that are not exported; 'parcellation_structure' and
# 'parcellation_substructure' are deliberately kept
cols_to_remove = [
    'parcellation_division',
    'parcellation_index',
    'x_ccf',
    'y_ccf',
    'z_ccf',
    'x_cirro',
    'y_cirro',
]

adata_log2.obs.drop(columns=cols_to_remove, inplace=True)

## Save as h5ad

In [36]:
adata_log2

In [39]:
# export the log2CPM version, gzip-compressed
adata_log2.write('/results/wmb_abc_atlas_v20230830_th_nsf_spagcn_for_cirro_log2CPM.h5ad', compression="gzip")

In [40]:
# NOTE(review): `adata` is not defined in the visible part of this notebook, and
# this file is written again (same path) by the adata_raw.write() call below.
adata.write('/results/wmb_abc_atlas_v20230830_th_nsf_spagcn_for_cirro_raw.h5ad', compression="gzip")

In [41]:
# start the raw export from a copy of adata_log2 so it carries the same obs/uns/obsm
adata_raw = adata_log2.copy()

In [42]:
# swap in the expression matrix of `adata`
# NOTE(review): assumes `adata` has the same cells and genes, in the same order,
# as adata_log2 -- confirm.
adata_raw.X = adata.X.copy()

In [43]:
# export the raw version, gzip-compressed (replaces the file written above)
adata_raw.write('/results/wmb_abc_atlas_v20230830_th_nsf_spagcn_for_cirro_raw.h5ad', compression="gzip")

In [46]:
# Load a second copy of the thalamus data with transform='log2'.
# NOTE(review): `version` is not defined in the visible part of this notebook.
# NOTE(review): the inline comment on `transform` says the norm+log2 happens
# manually later, yet 'log2' is what is requested here -- confirm which is intended.
adata_MKlog2X = abc.load_adata_thalamus(version=version, 
                             transform='log2', # will manually norm+log2 later
                             subset_to_TH_ZI=True, 
                             with_metadata=False, 
                             flip_y=True,
                             round_z=True,
                             with_colors=False)

# filter by class (unlike the first filter call in this notebook, midbrain is
# filtered out here and filter_other_nonTH is not passed)
adata_MKlog2X_th_zi_neurons = abc.filter_by_class_thalamus(adata_MKlog2X, 
                                                           filter_nonneuronal=True,
                                                           filter_midbrain=True)

# filter to thalamus boundaries (add a buffer here if wanted)
# Here the filter is given a copy of .obs and its result is used to index the
# AnnData, whereas earlier in this notebook it is given the AnnData itself.
# NOTE(review): confirm both call styles are supported by the same abc_load version.
filter_buffer = 0  # 5
realigned=False
obs_filtered_MKlog2X = abc.filter_by_thalamus_coords(adata_MKlog2X_th_zi_neurons.obs.copy(), 
                                             realigned=realigned, 
                                             buffer=filter_buffer)
adata_MKlog2X_th_zi_neurons = adata_MKlog2X_th_zi_neurons[obs_filtered_MKlog2X.index].copy()

# drop every variable whose name contains 'Blank'
gene_list = [gene for gene in adata_MKlog2X_th_zi_neurons.var_names if 'Blank' not in gene]
adata_MKlog2X_th_zi_neurons = adata_MKlog2X_th_zi_neurons[:,gene_list]

adata_MKlog2X_th_zi_neurons

In [47]:
# Build the log2CPV export: metadata from adata_log2, expression matrix from the
# freshly loaded transform='log2' data.
# NOTE(review): assumes both objects hold the same cells and genes in the same
# order; they come from separately filtered loads -- confirm before trusting X.
adata_log2CPV = adata_log2.copy()
adata_log2CPV.X = adata_MKlog2X_th_zi_neurons.X.copy()

In [49]:
# shorter obs column name for the SpaGCN domains
adata_log2CPV.obs.rename(columns={'SpaGCN_domains':'spagcn'},inplace=True)

In [51]:
# store the palette under '<obs column>_colors' so it matches the renamed column
adata_log2CPV.uns['spagcn_colors'] = adata_log2CPV.uns['SpaGCN_domains_color']

In [52]:
# export the log2CPV version, gzip-compressed
adata_log2CPV.write('/results/wmb_abc_atlas_v20230830_th_nsf_spagcn_for_cirro_log2CPV.h5ad', compression="gzip")