In [None]:
import os
import inspect
import seaborn
import matplotlib
import matplotlib.pyplot as plt
import torch
import scanpy as sc
from tqdm import tqdm
import sys
import pickle
import PyComplexHeatmap as pch
import scvi
import IPython
import pandas as pd
import scipy
import numpy as np
import itertools
import xarray as xr
# Figures written by sc.pl.* calls with `save=` land in this directory.
FIGURE_DIR = os.path.expanduser('/allen/programs/celltypes/workgroups/rnaseqanalysis/HMBA/Aim1_Atlases/BasalGanglia_paper_package/analysis/dev_map')
sc.settings.figdir = FIGURE_DIR
# Private scanpy flag behind set_figure_params(vector_friendly=...): rasterize scatter
# layers even when exporting pdf/svg.
sc._settings.settings._vector_friendly = True

has_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if has_cuda else "cpu")

if not has_cuda:
    print("GPU is not available")
else:
    print("GPU is available")
    print("Number of GPUs:", torch.cuda.device_count())
    print("GPU Name:", torch.cuda.get_device_name(0))

sys.path.append('/home/matthew.schmitz/utils/mts-utils/')
from genomics import sc_analysis


import antipode

# Expose scanpy's 102-colour "godsnot" palette as a named matplotlib colormap so it can be
# referenced by name (cmap='godsnot_102').
gs1 = matplotlib.colors.ListedColormap(sc.pl.palettes.godsnot_102, name='godsnot_102')
# Re-running this cell would raise "already registered"; check membership instead of a
# bare `except:` that would also hide unrelated errors.
if gs1.name not in matplotlib.colormaps:
    matplotlib.colormaps.register(name='godsnot_102', cmap=gs1)


In [None]:
# Directory holding the RNA annotation tables; also where the species x group means cache is written.
table_dir: str = '/allen/programs/celltypes/workgroups/rnaseqanalysis/HMBA/Aim1_Atlases/BasalGanglia_paper_package/anno_tables/RNA'

In [None]:
# Column / layer names in the AnnData object used throughout the notebook.
batch_key: str = 'load_name'
species_key: str = 'organism'
donor_key: str = 'donor_id'
layer_key: str = 'UMIs'          # layer passed as `layer=` to the antipode aggregation helpers
leaf_key: str = 'Group'          # cell-type label column used for group-level aggregation
MDE_KEY: str = "X_umap_species_integrated"
sex_key: str = "self_reported_sex"

In [None]:
# Load the cross-species consensus AnnData.
consensus_h5ad = '/allen/programs/celltypes/workgroups/rnaseqanalysis/HMBA/Aim1_Atlases/BasalGanglia_paper_package/data/xspecies/Consensus_HMBA_basalganglia_AIT_pre-print.h5ad'
adata = sc.read(consensus_h5ad)

In [None]:
# Curated group order: the first column of a header-less table.
groups_path = '/allen/programs/celltypes/workgroups/rnaseqanalysis/HMBA/Aim1_Atlases/BasalGanglia_paper_package/anno_tables/consensus_annotation_groups.tsv'
group_table = pd.read_csv(groups_path, header=None)
groups = group_table[0]

In [None]:
# Impose the curated order; reorder_categories raises if `groups` is not exactly the existing category set.
adata.obs[leaf_key] = adata.obs[leaf_key].cat.reorder_categories(groups)

In [None]:
# Harmonize sex labels. For marmoset cells the sex column is overwritten with the donor ID,
# which is then translated through a fixed donor -> sex lookup.
MARMOSET_DONOR_SEX = {'cjCroissant': 'male', 'cjNutmeg': 'male', 'cjJellybean': 'female', 'cjRambo': 'female'}
is_marmoset = adata.obs[species_key] == 'Marmoset'
adata.obs[sex_key] = adata.obs[sex_key].astype(str)
adata.obs.loc[is_marmoset, sex_key] = adata.obs.loc[is_marmoset, donor_key]
# replace() runs over the whole column; the donor-ID keys are expected to occur only in marmoset rows.
adata.obs[sex_key] = adata.obs[sex_key].replace(MARMOSET_DONOR_SEX).astype('string')

In [None]:
# Fixed per-species palette; scanpy reads `<key>_colors` aligned with the category order.
adata.obs[species_key] = adata.obs[species_key].astype('category')
species_colors = {'Human':'#377eb8','Macaque':'#4daf4a','Marmoset':'#FF5F5D','Mouse':'#ffa300'}
species_order = adata.obs[species_key].cat.categories
adata.uns[f'{species_key}_colors'] = list(map(species_colors.__getitem__, species_order))

In [None]:
# Lower-cased categorical copy of the harmonized sex labels; sex_key then points at it.
sex_labels = (
    adata.obs[sex_key]
    .astype(str)
    .str.lower()
    .astype('category')
)
adata.obs['sex'] = sex_labels
sex_key = 'sex'

In [None]:
# Group -> colour lookup (tab-separated, no header); groups absent from the table get grey.
colors = pd.read_csv('/allen/programs/celltypes/workgroups/rnaseqanalysis/HMBA/Aim1_Atlases/BasalGanglia_paper_package/anno_tables/Group_colors.txt',header=None,sep='\t')
group_color_dict = {name: colour for name, colour in zip(colors[0], colors[1])}
adata.uns['Group_color_dict'] = group_color_dict
group_categories = adata.obs['Group'].cat.categories
adata.uns['Group_colors'] = [group_color_dict.get(name, "#777777") for name in group_categories]

In [None]:
# Overview UMAP coloured by consensus group.
sc.pl.embedding(
    adata,
    basis='X_umap',
    color='Group',
)

In [None]:
# UMAP restricted to MSN / NUDAP groups, coloured by group.
msn_mask = adata.obs['Group'].str.contains('MSN|NUDAP', regex=True)
sc.pl.embedding(adata[msn_mask, :], basis='X_umap', color='Group', save='_MSNs_only.pdf')

In [None]:
# UMAP restricted to MSN / NUDAP groups, coloured by species.
msn_mask = adata.obs['Group'].str.contains('MSN|NUDAP', regex=True)
sc.pl.embedding(adata[msn_mask, :], basis='X_umap', color=species_key, save='_species_MSNs_only.pdf')

In [None]:
# Per (species, group, gene) summaries of the adult atlas, cached as netCDF so the
# aggregation only runs when the cache file is missing.
#   scalars     -> log means from antipode's get_real_leaf_means
#   proportions -> zero proportions from group_aggr_anndata(agg_func=prop_zeros)
species_group_means_path = os.path.join(table_dir, "species_group_means.netcdf")
if not os.path.exists(species_group_means_path):
    # Get real means and zero proportions
    log_real_means, real_mean_levels = antipode.model_functions.get_real_leaf_means(adata, species_key, leaf_key, layer=layer_key)
    # Species-averaged log means as a group x gene table.
    real_means = pd.DataFrame(log_real_means.mean(0), columns=adata.var.index, index=real_mean_levels[leaf_key])
    aggr_zeros = antipode.model_functions.group_aggr_anndata(adata, [species_key, leaf_key], layer=layer_key, agg_func=antipode.model_functions.prop_zeros, normalize=True)

    # NOTE(review): these labels assume both antipode helpers order species and groups
    # exactly like the obs categories — confirm against antipode.
    coords = {
        species_key: list(adata.obs[species_key].cat.categories),
        leaf_key: list(adata.obs[leaf_key].cat.categories),
        'var': list(adata.var.index.astype(str)),
    }
    # dims are inferred from the coords dict keys (in insertion order).
    lrm = xr.DataArray(log_real_means, coords=coords)
    az = xr.DataArray(aggr_zeros[0], coords=coords)
    ads = xr.Dataset({'scalars': lrm, 'proportions': az})
    ads.to_netcdf(species_group_means_path, engine="scipy")
else:
    ads = xr.open_dataset(species_group_means_path, engine="scipy")

In [None]:
# HANN mapping failed entirely, so this cell is disabled; the code is kept only for reference.
# import v1utils
# mapping_df = v1utils.import_hann('/allen/programs/celltypes/workgroups/rnaseqanalysis/HMBA/Aim1_Atlases/BasalGanglia_paper_package/data/human/Human_HMBA_basalganglia_AIT_pre-print_WB_MAPPING/hann_results.json')
# # mapping_df.index = list(adata.obs['cell_id'])
# mapping_df = mapping_df.loc[~mapping_df.index.duplicated(),:]
# # ["class_label","subclass_label","cluster_label"]
# for x in mapping_df.columns:
#     adata.obs[x]=mapping_df[x]
#     adata.obs[x] = [str(i) for i in adata.obs[x]]
#     adata.obs[x] = adata.obs[x].astype('category')

In [None]:
# Developmental reference summaries (species x class x gene netCDF from the dev model run).
# TODO(review): this lives under a personal home directory; consider moving it next to table_dir.
dev_means_path = '/home/matthew.schmitz/Matthew/models/1.9.1.8.5_Dev_final_600clusters/species_class_means.netcdf'
dads = xr.open_dataset(dev_means_path, engine="scipy")

In [None]:
def _scaled_linear_means(log_means, species_dim, group_dim):
    """Average log means over species, exponentiate, and scale each gene so its max over groups is 1.

    Returns a DataFrame indexed by `group_dim` labels with one column per gene
    (assuming 'var' is the last remaining dimension — TODO confirm for both inputs).
    """
    species_avg = log_means.mean(species_dim)
    rescaled = np.exp(species_avg - species_avg.max(group_dim))
    return rescaled.to_dataframe().unstack().droplevel(0, axis=1)


adult_means = _scaled_linear_means(ads['scalars'], species_key, leaf_key)  # scaled linear-space pseudobulk
dev_means = _scaled_linear_means(dads['scalars'], 'species', 'Initial_Class_markers_level_2')  # scaled linear-space pseudobulk

In [None]:
# Transcription-factor gene list (tab-separated, 'gene' column), upper-cased before matching against var names.
tf_genes = pd.read_csv('/home/matthew.schmitz/Matthew/utils/zizhens_tf_code.txt', sep='\t')
include_genes = tf_genes['gene'].str.upper().tolist()

In [None]:
# gl_path='/allen/programs/celltypes/workgroups/rnaseqanalysis/EvoGen/Team/Matthew/data/taxtest/gene_lists/'
# tf_genes=pd.read_csv(os.path.join(gl_path,'TFs_lambert_pmid29425488_1.01.txt'),sep='\t')
# include_genes=list(tf_genes['hgnc_symbol'])

In [None]:
# Sanity check (displayed): does the gene list contain an empty name?
any(gene == '' for gene in include_genes)

In [None]:
# Genes missing from zizhen's list.
extra_genes = ['FOXP1','PBX3','TSHZ1','CASZ1','SOX1','SKOR1','SKOR2','DRD1','DRD2']
include_genes = [*include_genes, *extra_genes]

In [None]:
# Keep only genes present in both the adult (ads) and developmental (dads) datasets.
# Sorted so gene/column order is reproducible across sessions: set iteration order for
# strings depends on hash seeding.
include_genes = sorted(set(include_genes) & set(ads.coords['var'].data) & set(dads.coords['var'].data))

In [None]:
# include_genes = list(set(dev_means.loc[:,include_genes].std(0).sort_values().index[-500:]) & set(adult_means.loc[:,include_genes].std(0).sort_values().index[-500:]))
# print(len(include_genes))

In [None]:
# Number of genes entering the correlation analysis.
n_include_genes = len(include_genes)
print(n_include_genes)

In [None]:
# Drop developmental classes whose names match these prefixes/substrings before correlating.
excluded_dev_classes = '^Cb|^NPC|^Schwann|^L-Q|Progen_Astro|Progen_Tany|Progen_FP|Progen_Hem'
keep_rows = ~dev_means.index.str.contains(excluded_dev_classes)
dev_means = dev_means.loc[keep_rows]

In [None]:
# Correlate every developmental class profile (rows) with every adult group profile
# (columns) over the selected genes, using sc_analysis.corr2_coeff (external helper).
corrs = sc_analysis.corr2_coeff(dev_means.loc[:, include_genes], adult_means.loc[:, include_genes])
# np.nan_to_num's 2nd positional parameter is `copy`, not the fill value; name nan= explicitly.
# Undefined correlations (NaN) are treated as 0.
corrs = np.nan_to_num(corrs, nan=0.0)
corrs = pd.DataFrame(corrs, index=dev_means.index, columns=adult_means.index)

In [None]:
# Clustered expression heatmaps of the selected genes: adult groups first, then dev classes.
adult_tf_means = adult_means.loc[:, adult_means.columns.isin(include_genes)]
seaborn.clustermap(adult_tf_means, col_cluster=True, xticklabels=True, yticklabels=True)
plt.show()

dev_tf_means = dev_means.loc[:, dev_means.columns.isin(include_genes)]
seaborn.clustermap(dev_tf_means, col_cluster=True)
plt.show()

In [None]:
# Clustered dev-class x adult-group correlation heatmap; the full matrix is also saved as CSV.
g = seaborn.clustermap(
    corrs,
    cmap='coolwarm',
    figsize=(30, 30),
    xticklabels=True,
    yticklabels=True,
    row_cluster=True,
    col_cluster=True,
)
heatmap_ax = g.ax_heatmap
heatmap_ax.set_xticklabels(heatmap_ax.get_xmajorticklabels(), fontsize=6)
heatmap_ax.set_yticklabels(heatmap_ax.get_ymajorticklabels(), fontsize=6)
corrs.to_csv(os.path.join(sc.settings.figdir, 'dev_corrs.csv'))
plt.show()

In [None]:
# Neuron-only view: drop non-neuronal dev classes (rows) and keep neuronal adult groups (columns).
neuron_corrs = corrs.loc[
    ~corrs.index.str.contains('Astro|Oligo|OPC|Mesench|Immune|CSF1|P2RY12|Ependy|Endo|Angio|SMC|Choroid|Peri|Hypendy|VLMC|ABC'),
    corrs.columns.str.contains('GABA|Glut|Dopa|Gly|MSN|Core|Shell'),
]
g = seaborn.clustermap(neuron_corrs,cmap='coolwarm',figsize=(30,30),xticklabels=True,yticklabels=True,row_cluster=True,col_cluster=True)
g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xmajorticklabels(), fontsize = 6)
g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_ymajorticklabels(), fontsize = 6)
# Save the neuron-only subset shown in this heatmap (not the full `corrs` matrix).
neuron_corrs.to_csv(os.path.join(sc.settings.figdir,'dev_neuron_corrs.csv'))
plt.show()

In [None]:
# Label each adult cell with its best-matching developmental class: for every adult group
# (column of `corrs`) take the dev class (row) with the highest correlation, then strip a
# trailing "_<digits>" suffix.
best_dev_match = corrs.idxmax(axis=0).to_dict()
# Cast to object first: Series.replace on categorical dtype is deprecated in pandas 2.x.
# astype(object) (unlike astype(str)) leaves missing labels as NaN.
adata.obs['Initial_Class_markers_level_2'] = (
    adata.obs['Group']
    .astype(object)
    .replace(best_dev_match)
    .str.replace('_[0-9]+$', '', regex=True)
)

In [None]:
# Adult UMAP coloured by best-matching developmental class.
# scanpy appends `save` to its default filename, so it needs a leading '_' like the other saves.
sc.pl.embedding(adata,color='Initial_Class_markers_level_2',basis='X_umap',legend_loc='on data',legend_fontsize=5,save='_initial_class.pdf')

In [None]:
pd.set_option('display.max_rows', 100)
# One row per distinct (adult Group, matched dev class) pair among non-'Nonneuron' cells,
# written to the figure directory.
is_neuron = adata.obs['Neighborhood'] != 'Nonneuron'
group_to_dev_class = (
    adata.obs.loc[is_neuron, ['Group', 'Initial_Class_markers_level_2']]
    .drop_duplicates()
    # drop=True discards the obs index whatever its name; dropping a column literally
    # called 'index' raises KeyError when the obs index is named.
    .reset_index(drop=True)
)
group_to_dev_class.to_csv(os.path.join(sc.settings.figdir, 'mappings.csv'))

In [None]:
# Adult UMAP with group labels drawn on the embedding.
sc.pl.embedding(
    adata,
    basis='X_umap',
    color=leaf_key,
    legend_loc='on data',
    legend_fontsize=4,
)