In [None]:
import numpy as np
import pandas as pd
import os, sys
import anndata
from plotnine import *
import seaborn as sns
import scanpy as sc

## Paths

In [None]:
# Root of the Decima data release. Set the DECIMA_DATA_DIR environment variable to point the
# notebook at another copy; otherwise the original cluster path is used.
save_dir = os.environ.get('DECIMA_DATA_DIR', '/gstore/data/resbioai/grelu/decima/20240823')
matrix_file = os.path.join(save_dir, 'data.h5ad')
out_dir = os.path.join(save_dir, 'gwas_44traits')
pred_file = os.path.join(out_dir, 'gwas_variant_predictions_matched.h5ad')

# Positive GWAS variants with their traits, and the matched negative variants.
pos_file = os.path.join(out_dir, 'positive_variants', 'positive_variants_and_traits.csv')
matched_neg_file = os.path.join(out_dir, 'negative_variants', 'negatives_matched.csv')

## Load data

In [None]:
# `ad`: Decima matrix whose obs carries cell_type / celltype_coarse / dataset.
# `gwas`: variant predictions; its obs is looked up below for both positive and matched negative variants.
ad, gwas = (anndata.read_h5ad(path) for path in (matrix_file, pred_file))

In [None]:
# Matched negatives, collapsed to one row per (pos_variant, gene) whose `variant` cell is
# the list of negative variant IDs for that pair.
match = (
    pd.read_csv(matched_neg_file)
    .groupby(['pos_variant', 'gene'])['variant']
    .apply(list)
    .reset_index()
)

# Positive variants and their trait names.
traits = pd.read_csv(pos_file)

In [None]:
# Attach the matched-negative list to each positive variant. Both frames have a `variant`
# column, so the merge suffixes them; the positive's copy (`variant_x`) duplicates
# `pos_variant` and is dropped, leaving the negatives list as `variant_y`.
traits = (
    traits[['variant', 'trait_name']]
    .merge(match, how='inner', left_on='variant', right_on='pos_variant', suffixes=('_x', '_y'))
    .drop(columns='variant_x')
)
traits.head(2)

## Keep variant-gene pairs that have Decima predictions

In [None]:
# Keep only (pos_variant, gene) pairs that have a row in the Decima predictions, and attach
# that row's `pval`. The join keys are explicit. Any `pval` left by a previous run is dropped
# first; otherwise re-running would silently join on `pval` as a third key.
traits = (
    traits
    .drop(columns='pval', errors='ignore')
    .merge(
        gwas.obs[['variant', 'gene', 'pval']].rename(columns={'variant': 'pos_variant'}),
        on=['pos_variant', 'gene'],
        how='inner',
    )
)

## Filter variant-gene pairs by p-value

In [None]:
# Keep pairs whose `pval` (taken from gwas.obs above) is below the cutoff.
PVAL_CUTOFF = 0.01
traits = traits.loc[traits['pval'] < PVAL_CUTOFF]
len(traits)

## Assign traits to categories

In [None]:
# Coarse category for each GWAS trait_name. The table is written as category -> trait names
# and then inverted into the trait_name -> category dict used with .map().
# Trait names not listed here map to NaN.
# NOTE(review): 'college_educatiojn' looks like a misspelling of 'college_education'. It is kept
# in case the input trait_name carries the same typo; confirm against the positives file.
_traits_by_category = {
    'lipids': ['hdl', 'ldl', 'cholesterol'],
    'triglycerides': ['triglycerides'],
    'cardiovascular_disease': ['cardiovascular_disease'],
    'coronary_artery_disease': ['coronary_artery_disease'],
    'hypertension': ['hypertension'],
    'Metabolic Traits': ['fasting_glucose', 'glucose', 'hba1c', 't2d'],
    'BMI-related Traits': ['whr_adj_bmi', 'body_mass_index'],
    'Blood-related Traits': ['mean_corpuscular_hemoglobin', 'red_count', 'platelet_count',
                             'red_blood_cell_width'],
    'Autoimmune and Inflammatory Diseases': ['autoimmune_disease', 'crohns_disease',
                                             'inflammatory_bowel_disease', 'multiple_sclerosis',
                                             'lupus', 'asthma', 'eczema'],
    'Endocrine and Reproductive Traits': ['hypothyroidism', 'age_of_menarche'],
    'Bone Health': ['bone_mineral_density'],
    'Respiratory Conditions': ['resp_ent'],
    'Neurological and Psychiatric Disorders': ['alzheimers_disease', 'schizophrenia', 'bipolar_disorder',
                                               'neuroticism', 'positive_mood_disorder'],
    'Cognitive and Educational Traits': ['college_education', 'intelligence', 'years_education',
                                         'college_educatiojn'],
    'height': ['height'],
}
category_mapping = {
    trait: category
    for category, trait_names in _traits_by_category.items()
    for trait in trait_names
}

In [None]:
# Label each pair with its coarse trait category; trait names missing from category_mapping get NaN.
traits = traits.assign(category=traits['trait_name'].map(category_mapping))

## Score each positive variant-gene pair against its matched negatives

In [None]:
# Use the dataset's own coarse label for skin_atlas rows and fall back to cell_type everywhere else.
# Vectorized with np.where instead of a row-wise DataFrame.apply(axis=1). Both branches are cast to
# object so categorical columns yield plain values and NaN is preserved, as with the row-wise version.
is_skin_atlas = (ad.obs['dataset'] == 'skin_atlas').to_numpy()
ad.obs['celltype_coarse'] = np.where(
    is_skin_atlas,
    ad.obs['celltype_coarse'].astype(object),
    ad.obs['cell_type'].astype(object),
)

In [None]:
# Attach a celltype_coarse label to every gwas variable: gwas.var_names are looked up in
# ad.obs['cell_type'].
# Assigning a single column keeps the cell idempotent. A merge onto gwas.var gains
# celltype_coarse_x / celltype_coarse_y on re-run, and the aggregation by 'celltype_coarse' then fails.
# The asserts turn a non-1:1 mapping or an unknown name into a clear message instead of a var-length error.
ct_pairs = ad.obs[['cell_type', 'celltype_coarse']].drop_duplicates()
assert ct_pairs['cell_type'].is_unique, "a cell_type maps to more than one celltype_coarse"
ct_to_coarse = dict(zip(ct_pairs['cell_type'], ct_pairs['celltype_coarse']))
missing_cts = [ct for ct in gwas.var_names if ct not in ct_to_coarse]
assert not missing_cts, f"gwas.var_names not found in ad.obs['cell_type']: {missing_cts[:5]}"
gwas.var['celltype_coarse'] = [ct_to_coarse[ct] for ct in gwas.var_names]

In [None]:
# Average gwas columns within each celltype_coarse group (axis=1 aggregates over var).
# The means land in gwas_agg.layers['mean']. `gwas` must stay positional: aggregate dispatches on it.
gwas_agg = sc.get.aggregate(gwas, func='mean', by='celltype_coarse', axis=1)

In [None]:
# Score matrix = magnitude of the per-celltype_coarse means; the signed means remain in layers['mean'].
gwas_agg.X = np.absolute(gwas_agg.layers['mean'])

In [None]:
# Columns added for every (pos_variant, gene) pair in `traits`:
#   scores     - the positive variant's row of gwas_agg.X (one value per coarse cell type)
#   neg_scores - the same, averaged over the pair's matched negatives (`variant_y`)
#   deltas     - scores - neg_scores
#   delta_z    - deltas z-scored across cell types within the pair
# gwas_agg rows are matched on BOTH variant and gene. `traits` is keyed on (pos_variant, gene),
# and the negatives were grouped per (pos_variant, gene) when loaded. Matching on variant alone
# returns one row per gene a variant was scored against: the positive's "score" becomes 2-D and
# the negatives are averaged across unrelated genes.
# NOTE(review): this assumes gwas.obs stores each matched negative under the same `gene` as its
# positive. The second assertion fails loudly if that is not the case.
agg_X = np.asarray(gwas_agg.X)
agg_variant = gwas_agg.obs['variant'].to_numpy()
agg_gene = gwas_agg.obs['gene'].to_numpy()


def _agg_rows(variants, gene):
    """Return the rows of gwas_agg.X whose obs `variant` is in `variants` and obs `gene` equals `gene`."""
    return agg_X[np.isin(agg_variant, list(variants)) & (agg_gene == gene)]


pos_rows = [_agg_rows([v], g) for v, g in zip(traits['pos_variant'], traits['gene'])]
neg_rows = [_agg_rows(negs, g) for negs, g in zip(traits['variant_y'], traits['gene'])]
assert all(r.shape[0] == 1 for r in pos_rows), "expected exactly one gwas_agg row per (pos_variant, gene)"
assert all(r.shape[0] > 0 for r in neg_rows), "no matched-negative predictions for some (pos_variant, gene) pairs"

traits['scores'] = pd.Series([r[0] for r in pos_rows], index=traits.index, dtype=object)
traits['neg_scores'] = pd.Series([r.mean(0) for r in neg_rows], index=traits.index, dtype=object)
traits['deltas'] = pd.Series(
    [s - n for s, n in zip(traits['scores'], traits['neg_scores'])], index=traits.index, dtype=object
)
traits['delta_z'] = traits['deltas'].apply(lambda d: (d - d.mean()) / d.std())

In [None]:
# Number of variant-gene pairs per category. dropna=False also counts pairs whose trait_name is
# not in category_mapping (shown as NaN); these pairs are dropped by the per-category groupby averages.
traits['category'].value_counts(dropna=False)

## Plot heatmap

In [None]:
# One row per variant-gene pair (labelled by its category), one column per coarse cell type.
res = pd.DataFrame(
    np.vstack(traits['delta_z'].to_list()),
    index=traits['category'],
    columns=gwas_agg.var_names,
)
# Average delta_z within each category. Grouping on the index level averages only the numeric
# cell-type columns. reset_index().groupby('category').apply(lambda x: x.mean(0)) also hands the
# string 'category' column to mean(), which raises TypeError on pandas 2.x.
# Pairs with a NaN category are dropped by the groupby.
res = res.groupby(level=0).mean()
res = res.loc[['Autoimmune and Inflammatory Diseases', 'Blood-related Traits', 'height', 'triglycerides',
               'Neurological and Psychiatric Disorders', 'Respiratory Conditions']]

In [None]:
# Cell types for the first heatmap, listed in rough tissue/lineage groups. np.unique merges the
# groups, removes repeats, and returns the names sorted alphabetically (uppercase before lowercase).
_heatmap_ct_groups = (
    # T, NK and innate lymphoid cells
    ('CD8-positive, alpha-beta T cell', 'CD4-positive, alpha-beta T cell', 'mucosal invariant T cell',
     'mature NK T cell', 'regulatory T cell', 'natural killer cell', 'innate lymphoid cell'),
    # B cells, progenitors and other blood cells
    ('memory B cell', 'naive B cell', 'common lymphoid progenitor', 'megakaryocyte', 'erythroid lineage cell'),
    # brain
    ('Oligodendrocyte', 'Committed oligodendrocyte precursor', 'Hippocampal dentate gyrus', 'MGE interneuron'),
    # airway
    ('club cell', 'lung secretory cell'),
    # vascular and stromal
    ('VEC', 'blood vessel endothelial cell', 'vascular associated smooth muscle cell',
     'capillary endothelial cell', 'Fibroblasts'),
    # gut and liver
    ('enterocyte', 'hepatocyte'),
)
sel_cts = list(np.unique([ct for group in _heatmap_ct_groups for ct in group]))

In [None]:
# Heatmap of mean delta_z with cell types as rows (clustered) and trait categories as columns
# (kept in the order chosen above, since col_cluster=False), centred at zero.
heatmap_data = res.loc[:, sel_cts].T
g = sns.clustermap(
    heatmap_data,
    center=0,
    cmap='RdBu_r',
    col_cluster=False,
    figsize=(5.5, 10),
)
# Shrink the plotting area and move the colour bar into the freed space on the right.
g.fig.subplots_adjust(right=0.7)
g.ax_cbar.set_position((0.8, .2, .03, .4))

In [None]:
# Restrict to positive variants linked to exactly one trait, so each variant's delta_z
# contributes to a single category.
# NOTE(review): this cell previously read `traits_filt = traits_filt[(traits.unique_trait)]`.
# `traits_filt` was only defined in a commented-out line and `unique_trait` is never created in this
# notebook, so the cell raised NameError on a fresh kernel. The flag below is a reconstruction
# ("pos_variant maps to exactly one trait_name"); confirm it matches the original definition.
is_unique_trait = traits.groupby('pos_variant')['trait_name'].transform('nunique').eq(1)
traits_filt = traits[is_unique_trait]

# Optional category restriction. It is disabled, as the original filter line was commented out.
# Names must match category_mapping exactly: the old filter used 'Triglycerides', which never matched
# the lowercase 'triglycerides' category.
focus_categories = None  # e.g. ['Blood-related Traits', 'height', 'Autoimmune and Inflammatory Diseases', 'triglycerides']
if focus_categories is not None:
    traits_filt = traits_filt[traits_filt['category'].isin(focus_categories)]

print(len(traits_filt))
traits_filt[['category', 'trait_name']].value_counts().reset_index().sort_values(['category', 'count'], ascending=False)

In [None]:
# Mean delta_z per (category, coarse cell type) over the filtered variant-gene pairs.
# Grouping on the index level averages only the numeric cell-type columns.
# reset_index().groupby('category').apply(lambda x: x.mean(0)) also hands the string 'category'
# column to mean(), which raises TypeError on pandas 2.x.
res = pd.DataFrame(
    np.vstack(traits_filt['delta_z'].to_list()),
    index=traits_filt['category'],
    columns=gwas_agg.var_names,
)
res = res.groupby(level=0).mean()

In [None]:
# Print the coarse cell types with the highest mean delta_z for each category in `res`.
# Series.nlargest returns the column labels directly and skips NaN. Series.argsort marks NaN
# positions with -1, which indexed gwas_agg.var_names from the end and reported the wrong cell type.
n_top = 9
for category, row in res.iterrows():
    print(category)
    print(row.nlargest(n_top).index.tolist())
    print("")

In [None]:
# Coarse cell types for the filtered heatmap, written group by group; the flattened list keeps this order.
_filt_ct_groups = (
    ('T cell', 'CD8-positive, alpha-beta T cell', 'NK', 'mucosal invariant T cell',
     'mature NK T cell', 'regulatory T cell'),
    ('erythroid lineage cell', 'megakaryocyte', 'common lymphoid progenitor', 'hematopoietic stem cell'),
    ('hepatocyte', 'enterocyte'),
    ('VEC', 'blood vessel endothelial cell', 'capillary endothelial cell'),
    ('vascular associated smooth muscle cell', 'fibroblast'),
)
sel_cts = [ct for group in _filt_ct_groups for ct in group]


In [None]:
# Heatmap of mean delta_z. Rows are the filtered categories, kept in `res` order because
# row_cluster=False. Columns are the selected coarse cell types, clustered.
# Binding the grid to a name keeps the ClusterGrid repr from being printed under the figure.
g = sns.clustermap(res.loc[:, sel_cts], cmap='viridis', figsize=(10, 5), row_cluster=False)