# Visualize example variants

In [None]:
import numpy as np
import pandas as pd
import os, sys
import anndata
from plotnine import *
import wandb 
import torch

sys.path.append('/code/decima/src/decima/')
from visualize import plot_gene_scatter
from read_hdf5 import extract_gene_data
from lightning import LightningModel

from grelu.transforms.prediction_transforms import Aggregate
from grelu.visualize import plot_attributions
from grelu.interpret.motifs import scan_sequences, compare_motifs
from grelu.sequence.format import strings_to_one_hot
from captum.attr import InputXGradient
import scanpy as sc

from matplotlib import pyplot as plt
import seaborn as sns
%matplotlib inline

from grelu.io.motifs import read_meme_file
from tangermeme.plot import plot_pwm
# Load the HOCOMOCO v12 core motif collection (MEME format, path relative to the notebook's
# working directory); motifs are looked up by ID, e.g. hocomoco['SPI1.H12CORE.1.S.B'].
hocomoco = read_meme_file('../H12CORE_meme_format.meme')

## Paths

In [None]:
# Root of this Decima run; every other input path is derived from it.
save_dir = '/gstore/data/resbioai/grelu/decima/20240823'
matrix_file, h5_file = (os.path.join(save_dir, name) for name in ('data.h5ad', 'data.h5'))

# GWAS variant-effect outputs (the directory name suggests a 44-trait variant set).
out_dir = os.path.join(save_dir, 'gwas_44traits')
pred_file = os.path.join(out_dir, 'gwas_variant_predictions_matched.h5ad')
pos_dir, neg_dir = (os.path.join(out_dir, sub) for sub in ('positive_variants', 'negative_variants'))

# Positive variants with their trait labels, and the negative variants matched to them.
trait_file = os.path.join(pos_dir, 'positive_variants_and_traits.csv')
matched_neg_file = os.path.join(neg_dir, 'negatives_matched.csv')

## Load Decima predictions

In [None]:
# Load the Decima training AnnData. Later cells use .obs (cell_type, dataset, celltype_coarse)
# and .var (per-gene chrom, strand, gene_start/gene_end, TSS and pearson columns).
ad = anndata.read_h5ad(matrix_file)

In [None]:
# Strand-aware TSS: gene_start for '+' strand genes, gene_end for everything else.
ad.var['tss'] = np.where(ad.var.strand == '+', ad.var.gene_start, ad.var.gene_end)
# Genes with no canonical Ensembl TSS fall back to the strand-aware TSS computed above.
ad.var['ensembl_canonical_tss'] = ad.var.ensembl_canonical_tss.fillna(ad.var.tss)

In [None]:
# Only skin_atlas rows keep their own coarse label; every other dataset uses cell_type as-is.
is_skin_atlas = ad.obs['dataset'] == 'skin_atlas'
ad.obs['celltype_coarse'] = np.where(is_skin_atlas, ad.obs['celltype_coarse'], ad.obs['cell_type'])

## Load variant predictions

In [None]:
# Load per-variant Decima predictions. .obs has variant/gene/rsid columns (rows look like
# variant-gene entries — confirm); .var is indexed by cell_type (see the merge below).
gwas_ = anndata.read_h5ad(pred_file)

In [None]:
# Distinct (cell_type -> celltype_coarse) pairs from the training data.
ct_to_coarse = ad.obs[['cell_type', 'celltype_coarse']].drop_duplicates()
# Join onto gwas_.var by its cell_type index, then restore cell_type as the index.
gwas_.var = (
    gwas_.var
    .merge(ct_to_coarse, left_index=True, right_on='cell_type')
    .set_index('cell_type')
)

In [None]:
# Take absolute values of the predicted effects so the per-group means computed next average
# magnitudes (positive and negative effects cannot cancel out).
gwas_.X = np.abs(gwas_.X)

In [None]:
# Relabel the rows 0..n-1 so every row has a unique name (one assignment through the AnnData
# setter is enough; it also updates gwas_.obs.index).
gwas_.obs_names = range(gwas_.n_obs)

In [None]:
# Average the (absolute) effect columns within each celltype_coarse group (axis=1 groups .var).
# The result keeps the means in the 'mean' layer, which is promoted to X.
gwas = sc.get.aggregate(gwas_, by='celltype_coarse', func='mean', axis=1)
gwas.X = gwas.layers['mean']

In [None]:
# Point gwas_.X at the 'gene_exp' layer so the next aggregation averages expression instead of
# effects (the effect means are already stored in `gwas`).
gwas_.X = gwas_.layers['gene_exp']

In [None]:
# Mean 'gene_exp' per celltype_coarse group, using the same grouping as the effect aggregation.
exp = sc.get.aggregate(gwas_, by='celltype_coarse', func='mean', axis=1)

In [None]:
# Attach grouped expression to `gwas`. Assumes both aggregations emit the groups in the same
# column order (same input var table and grouping key) — TODO confirm if either call changes.
gwas.layers['gene_exp'] = exp.layers['mean']

## Annotate Decima predictions

In [None]:
def _nearest_gene_by(tss_col):
    """Return, for each row of gwas.obs, the ad.var gene whose `tss_col` is closest to `pos`.

    Only genes on the variant's chromosome are candidates. Ties (e.g. genes sharing a TSS)
    resolve to the first such gene in ad.var order via `idxmin`, so the result is
    reproducible; a full `sort_values()` uses an unstable quicksort and gives no such guarantee.
    """
    # Split gene coordinates by chromosome once, instead of re-filtering ad.var for every variant.
    tss_by_chrom = {
        chrom: coords
        for chrom, coords in ad.var.groupby('chrom', observed=True)[tss_col]
    }
    return gwas.obs.apply(
        lambda row: (row.pos - tss_by_chrom[row.chrom]).abs().idxmin(),
        axis=1,
    )


gwas.obs['nearest_gene'] = _nearest_gene_by('tss')
gwas.obs['nearest_canonical_gene'] = _nearest_gene_by('ensembl_canonical_tss')

In [None]:
# Signed distance from each variant to its assigned gene's TSS: variant pos minus TSS coordinate.
for dist_col, tss_col in (('tss_dist', 'tss'), ('canonical_tss_dist', 'ensembl_canonical_tss')):
    gene_tss = ad.var.loc[gwas.obs['gene'], tss_col].to_numpy()
    gwas.obs[dist_col] = gwas.obs['pos'].to_numpy() - gene_tss

In [None]:
# True where the variant's assigned gene is also its nearest gene by `tss`
# ('same') or by `ensembl_canonical_tss` ('same_canonical').
assigned_gene = gwas.obs.gene.astype(str)
gwas.obs['same'] = assigned_gene.eq(gwas.obs.nearest_gene.astype(str))
gwas.obs['same_canonical'] = assigned_gene.eq(gwas.obs.nearest_canonical_gene.astype(str))

In [None]:
# Copy per-gene `pearson` / `size_factor_pearson` values from ad.var onto each variant row.
for obs_col, var_col in (('gene_pearson', 'pearson'), ('gene_sf_pearson', 'size_factor_pearson')):
    gwas.obs[obs_col] = ad.var.loc[gwas.obs.gene, var_col].to_numpy(dtype=float)

## Annotate positive variants

In [None]:
# Matched negatives collapsed to one row per (pos_variant, gene), with the matched negative
# variant ids collected into a list in the 'variant' column.
match = (
    pd.read_csv(matched_neg_file)
    .groupby(['pos_variant', 'gene'])
    .variant
    .apply(list)
    .reset_index()
)

# Positive variants with their trait annotations.
traits = pd.read_csv(trait_file)

In [None]:
# Attach matched negatives to each positive variant. After this join the positive id lives in
# 'pos_variant', the list of negative ids in 'variant_y'; 'variant_x' duplicates pos_variant.
pos_with_negs = traits[['variant', 'trait_name', 'vep']].merge(
    match[['pos_variant', 'gene', 'variant']],
    left_on='variant',
    right_on='pos_variant',
)
pos_with_negs = pos_with_negs.drop(columns='variant_x')

# Per-(variant, gene) prediction summaries, keyed like the frame above.
score_cols = [
    'variant', 'rsid', 'gene', 'pval', 'score', 'has_eqtl', 'is_best', 'same',
    'same_canonical', 'gene_pearson', 'gene_sf_pearson', 'tss_dist', 'canonical_tss_dist',
]
pos_scores = gwas.obs[score_cols].rename(columns={'variant': 'pos_variant'})

# The two frames share exactly these key columns.
traits = pos_with_negs.merge(pos_scores, on=['pos_variant', 'gene']).copy()
traits.head(2)

In [None]:
def _matched_negative_score(row):
    """Mean gwas.obs `score` of this row's matched negatives (row.variant_y) for row.gene."""
    is_neg = gwas.obs.variant.isin(row.variant_y) & (gwas.obs.gene == row.gene)
    return gwas.obs.loc[is_neg, 'score'].mean()


traits['neg_score'] = traits.apply(_matched_negative_score, axis=1)
# Background-subtracted effect: positive score minus the matched-negative mean.
traits['delta'] = traits['score'] - traits['neg_score']
traits.head(2)

## Load models

In [None]:
wandb.login(host="https://genentech.wandb.io")

# Checkpoints to ensemble, as (run id, checkpoint file) under <save_dir>/lightning_logs.
ckpt_runs = [
    ('kugrjb50', 'epoch=3-step=2920.ckpt'),
    ('i68hdsdk', 'epoch=2-step=2190.ckpt'),
    ('0as9e8of', 'epoch=7-step=5840.ckpt'),
    ('i9zsp4nm', 'epoch=8-step=6570.ckpt'),
]
ckpts = [f'{save_dir}/lightning_logs/{run}/checkpoints/{fname}' for run, fname in ckpt_runs]

# Load every checkpoint and switch it to eval mode.
models = [LightningModel.load_from_checkpoint(path).eval() for path in ckpts]

## Add per-cell type deltas

In [None]:
def _celltype_deltas(row):
    """Per-cell-type-group effect of the positive variant minus its matched negatives, for row.gene.

    Both the positive and the negative rows are restricted to row.gene. Without the gene filter
    on the positive side, a variant paired with several genes yields several positive rows, which
    either fail to broadcast against the negatives or silently mix effects across genes.
    """
    same_gene = gwas.obs.gene == row.gene
    pos_x = gwas[(gwas.obs.variant == row.pos_variant) & same_gene].X
    neg_x = gwas[gwas.obs.variant.isin(row.variant_y) & same_gene].X
    # With a single positive row this equals pos - mean(neg) per cell-type group.
    return (pos_x - neg_x).mean(0)


traits['deltas'] = traits.apply(_celltype_deltas, axis=1)
# Z-score the deltas across cell-type groups within each variant.
traits['delta_z'] = traits.deltas.apply(lambda x: (x - x.mean()) / x.std())
# Five groups with the largest z (ascending order), and all groups with z > 2.
traits['top5'] = traits.delta_z.apply(lambda x: gwas.var_names[np.argsort(x)[-5:]].tolist())
traits['ct_z2'] = traits.delta_z.apply(lambda x: gwas.var_names[x > 2].tolist())

## Select top positive variants

In [None]:
# Positive variants with pval < 0.01, largest background-subtracted delta first.
is_nominal = traits['pval'] < 0.01
top = traits.loc[is_nominal].sort_values(by='delta', ascending=False)

## Examples

In [None]:
def get_attrs(gene, ref_allele, alt_allele, rel_pos, top_cts, h5_file=h5_file, models=models, device=0):
    """Model-averaged Input x Gradient attributions for the ref and alt versions of a variant.

    The gene's sequence and mask are read from `h5_file`; the base at `rel_pos` is overwritten
    with `ref_allele` / `alt_allele`. Each model is given an `Aggregate` transform averaging its
    outputs over the `ad.obs` tasks that belong to `top_cts`. Attributions are then averaged
    across `models`.

    Args:
        gene: gene identifier passed to `extract_gene_data`.
        ref_allele: allele string written at `rel_pos` for the reference input.
        alt_allele: allele string written at `rel_pos` for the alternate input.
        rel_pos: index along the sequence axis (second axis of the sequence tensor) of the variant.
        top_cts: cell-type labels. A task (row of `ad.obs`) is used if its `cell_type` OR its
            `celltype_coarse` is in this collection.
        h5_file: HDF5 file holding gene sequences and masks.
        models: models to ensemble.
        device: anything accepted by `torch.device`; the default 0 keeps the original GPU 0.

    Returns:
        (attr_ref, attr_alt): model-averaged attributions with the last row (the mask row
        appended to the inputs) removed.

    Raises:
        ValueError: if no row of `ad.obs` matches `top_cts`.

    Side effects:
        `add_transform` is called on every model and not undone here, and each model is moved
        to `device`.
    """
    seq, mask = extract_gene_data(h5_file, gene, merge=False)
    device = torch.device(device)

    # Two copies of the gene sequence that differ only at rel_pos.
    ref_seq = seq.clone()
    alt_seq = seq.clone()
    ref_seq[:, rel_pos] = strings_to_one_hot(ref_allele).squeeze()
    alt_seq[:, rel_pos] = strings_to_one_hot(alt_allele).squeeze()

    # top_cts usually comes from gwas.var_names, which are celltype_coarse labels; these can
    # differ from cell_type for skin_atlas rows, so match on either column.
    is_task = ad.obs.cell_type.isin(top_cts) | ad.obs.celltype_coarse.isin(top_cts)
    tasks = ad.obs_names[is_task].tolist()
    if not tasks:
        raise ValueError(
            f"No rows of ad.obs have cell_type or celltype_coarse in {list(top_cts)}"
        )

    # Append the mask as an extra row after the sequence channels.
    ref_inputs = torch.vstack([ref_seq, mask]).to(device)
    alt_inputs = torch.vstack([alt_seq, mask]).to(device)

    attr_ref = []
    attr_alt = []
    for model in models:
        # Aggregate averages the model output over `tasks` (presumably giving a single target).
        model.add_transform(Aggregate(tasks=tasks, task_aggfunc="mean", model=model))
        attributer = InputXGradient(model.to(device))
        # NOTE: kept from the original. Presumably captum re-enables autograd for its own gradient
        # step, while no_grad keeps the returned input*gradient tensors detached so .numpy()
        # works — confirm before removing.
        with torch.no_grad():
            attr_ref.append(attributer.attribute(ref_inputs).cpu().numpy())
            attr_alt.append(attributer.attribute(alt_inputs).cpu().numpy())

    # Average over models, then drop the last row (the appended mask row; assumes the
    # attributions keep the input's row layout).
    attr_ref = np.stack(attr_ref).mean(0)[:-1]
    attr_alt = np.stack(attr_alt).mean(0)[:-1]
    return attr_ref, attr_alt

## Individual examples

## rs138682554

https://www.genetics.opentargets.org/Variant/15_90884462_G_A/associations
eQTL: Artery tibial, Blood, Pancreas, Skin, Fibroblast, Macrophage, Monocyte

In [None]:
rs = 'rs138682554'

# Highest-delta row for this rsid (`top` is sorted by delta, descending).
top_row = top[top.rsid == rs].iloc[0]
gene = top_row['gene']
# Separate name so the `traits` DataFrame is not overwritten by this string.
trait_name = top_row['trait_name']
top_cts = top_row['ct_z2']

# rel_pos indexes into `gene`'s sequence window (get_attrs writes the allele there), so take it,
# the alleles and the expression from the prediction row for this (rsid, gene) pair, not just
# the first row that happens to carry the rsid.
is_var = (gwas.obs.rsid == rs) & (gwas.obs.gene == gene)
var_obs = gwas.obs.loc[is_var]
rel_pos = var_obs['rel_pos'].tolist()[0]
ref_allele = var_obs['ref_tx'].tolist()[0]
alt_allele = var_obs['alt_tx'].tolist()[0]

# One row per cell-type group: background-subtracted effect vs measured expression.
df = pd.DataFrame({
    'ct': gwas.var_names,
    'vep': top_row['deltas'],
    'exp': np.array(gwas[is_var].layers['gene_exp']).squeeze(),
})
df['z>2'] = df.ct.isin(top_cts)

print(gene)
print(trait_name)
print(top_cts)
print(var_obs['tss_dist'].tolist()[0])
display(gwas.obs.loc[gwas.obs.rsid == rs])

In [None]:
# Plot categories for this example; cell types not listed stay 'Other'.
celltype_groups = {
    'DC': ['DC', 'Langerhans cell', 'conventional dendritic cell'],
    'Macrophages': ['Macrophages', 'alveolar macrophage', 'macrophage'],
    'Monocytes': ['classical monocyte', 'intermediate monocyte', 'non-classical monocyte'],
}
df['Cell Type'] = 'Other'
for group, members in celltype_groups.items():
    df.loc[df.ct.isin(members), 'Cell Type'] = group

In [None]:
# Background-subtracted predicted effect vs measured expression, one point per cell-type
# group, coloured by 'Cell Type'.
fig = ggplot(df, aes(x='exp', y='vep', color='Cell Type')) + geom_point()
fig = fig + theme_classic() + theme(figure_size=(3.2, 2.2))
fig = fig + ylab('Predicted absolute logFC\n(Background subtracted)') + xlab('Measured gene\n    expression')
fig = fig + scale_color_manual(values=['red', 'blue', 'orange', 'gray'])
fig

In [None]:
# Model-averaged attributions for ref and alt alleles, aggregated over this variant's z>2
# cell-type groups (top_cts).
attr_ref, attr_alt = get_attrs(gene,ref_allele,alt_allele, rel_pos, top_cts)

In [None]:
# Reference-allele attributions over a 20-position window; window index 10 is the variant
# (highlighted).
plot_attributions(attr_ref[:, rel_pos-10:rel_pos+10], figsize=(6,2), ylim=(-.55, 1.1),
                  highlight_positions=[10], alpha=.25)

In [None]:
# Alternate-allele attributions over the same window and y-limits as the reference plot, so
# the two are directly comparable.
plot_attributions(attr_alt[:, rel_pos-10:rel_pos+10], figsize=(6,2), ylim=(-.55, 1.1), highlight_positions=[10],
                 alpha=.25)

In [None]:
# SPI1 motif logo (HOCOMOCO v12) for visual comparison with the attribution tracks.
plot_pwm(hocomoco['SPI1.H12CORE.1.S.B'])

## rs8105903

https://www.genetics.opentargets.org/Variant/19_46784893_C_A/associations

eQTL: adipose, blood, thyroid, tibial nerve, monocyte, fibroblast

In [None]:
rs = 'rs8105903'

# Highest-delta row for this rsid (`top` is sorted by delta, descending).
top_row = top[top.rsid == rs].iloc[0]
gene = top_row['gene']
# Separate name so the `traits` DataFrame is not overwritten by this string.
trait_name = top_row['trait_name']
top_cts = top_row['ct_z2']

# rel_pos indexes into `gene`'s sequence window (get_attrs writes the allele there), so take it,
# the alleles and the expression from the prediction row for this (rsid, gene) pair, not just
# the first row that happens to carry the rsid.
is_var = (gwas.obs.rsid == rs) & (gwas.obs.gene == gene)
var_obs = gwas.obs.loc[is_var]
rel_pos = var_obs['rel_pos'].tolist()[0]
ref_allele = var_obs['ref_tx'].tolist()[0]
alt_allele = var_obs['alt_tx'].tolist()[0]

# One row per cell-type group: background-subtracted effect vs measured expression.
df = pd.DataFrame({
    'ct': gwas.var_names,
    'vep': top_row['deltas'],
    'exp': np.array(gwas[is_var].layers['gene_exp']).squeeze(),
})
df['z>2'] = df.ct.isin(top_cts)

print(gene)
print(trait_name)
print(top_cts)
print(var_obs['tss_dist'].tolist()[0], ref_allele, alt_allele)

In [None]:
# Remove one cell-type group from the plot data. df has a default RangeIndex, so label 40 is
# position 40 in gwas.var_names order.
# NOTE(review): which cell type this drops depends on gwas.var_names ordering and is not stated;
# prefer dropping by name (e.g. df[df.ct != '<cell type>']) so the intent survives reordering.
df=df.drop(index=40)

In [None]:
# Plot categories for this example; cell types not listed stay 'Other'.
celltype_groups = {
    'Fibroblasts': ['Fibroblasts'],
    'Macrophages': ['Macrophages', 'alveolar macrophage', 'macrophage'],
    'Monocytes': ['classical monocyte', 'intermediate monocyte', 'non-classical monocyte'],
}
df['Cell Type'] = 'Other'
for group, members in celltype_groups.items():
    df.loc[df.ct.isin(members), 'Cell Type'] = group

In [None]:
# Background-subtracted predicted effect vs measured expression, one point per cell-type
# group, coloured by 'Cell Type'.
fig = ggplot(df, aes(x='exp', y='vep', color='Cell Type')) + geom_point()
fig = fig + theme_classic() + theme(figure_size=(3.2, 2.2))
fig = fig + ylab('Predicted absolute logFC\n(Background subtracted)')
fig = fig + xlab('Measured gene\n   expression')
fig = fig + scale_color_manual(values=['Red', 'Blue', 'Orange', 'Gray'])
fig

In [None]:
# Attributions aggregated over fibroblast tasks only ('Fibroblasts' / 'fibroblast' labels),
# rather than over all z>2 cell-type groups.
attr_ref, attr_alt = get_attrs(gene,ref_allele,alt_allele, rel_pos, ['Fibroblasts', 'fibroblast'])

In [None]:
# Reference-allele attributions over a 20-position window; window index 10 is the variant.
plot_attributions(attr_ref[:, rel_pos-10:rel_pos+10], figsize=(5.5,2), ylim=(-.17, .07),
                 highlight_positions=[10], alpha=.25)

In [None]:
# Alternate-allele attributions; same window and y-limits as the reference plot for comparison.
plot_attributions(attr_alt[:, rel_pos-10:rel_pos+10], figsize=(5.5,2), ylim=(-.17, .07),
                 highlight_positions=[10], alpha=.25)

In [None]:
# ZEB2 motif logo (HOCOMOCO v12) for visual comparison with the attribution tracks.
plot_pwm(hocomoco['ZEB2.H12CORE.0.P.B'])

In [None]:
# Compare HOCOMOCO v12 motif matches between local ref and alt sequences around the variant.
# NOTE(review): ref_seq is 13 nt but alt_seq is 12 nt — alt looks like one G was deleted from
# the GGG run rather than substituted, so this scores a deletion, not the single-base change
# (19_46784893_C_A) this variant is. Confirm the intended alt sequence (e.g. ref with
# ref_allele replaced by alt_allele at the variant offset) before interpreting `cmp`.
cmp = compare_motifs(ref_seq='TGTCCAGGGTATT', alt_seq='TGTCCAGGTATT', motifs='hocomoco_v12')
cmp

## rs79755767

In [None]:
rs = 'rs79755767'

# Highest-delta row for this rsid (`top` is sorted by delta, descending).
top_row = top[top.rsid == rs].iloc[0]
gene = top_row['gene']
# Separate name so the `traits` DataFrame is not overwritten by this string.
trait_name = top_row['trait_name']
top_cts = top_row['ct_z2']

# rel_pos indexes into `gene`'s sequence window (get_attrs writes the allele there), so take it,
# the alleles and the expression from the prediction row for this (rsid, gene) pair, not just
# the first row that happens to carry the rsid.
is_var = (gwas.obs.rsid == rs) & (gwas.obs.gene == gene)
var_obs = gwas.obs.loc[is_var]
rel_pos = var_obs['rel_pos'].tolist()[0]
ref_allele = var_obs['ref_tx'].tolist()[0]
alt_allele = var_obs['alt_tx'].tolist()[0]

# One row per cell-type group: background-subtracted effect vs measured expression.
df = pd.DataFrame({
    'ct': gwas.var_names,
    'vep': top_row['deltas'],
    'exp': np.array(gwas[is_var].layers['gene_exp']).squeeze(),
})
df['z>2'] = df.ct.isin(top_cts)

print(gene)
print(trait_name)
print(top_cts)
display(gwas.obs.loc[gwas.obs.rsid == rs])


In [None]:
# Plot categories for this example. The label strings carry embedded newlines and trailing
# spaces (presumably to align the legend entries); cell types not listed stay 'Other'.
celltype_labels = {
    'erythroid lineage cell': 'Erythroid   \nlineage cell',
    'hematopoietic stem cell': 'Hematopoietic\nstem cell        ',
    'megakaryocyte': 'Megakaryocyte         ',
}
df['Cell Type'] = 'Other'
for ct_name, label in celltype_labels.items():
    df.loc[df.ct.isin([ct_name]), 'Cell Type'] = label

# Background-subtracted predicted effect vs measured expression, coloured by 'Cell Type'.
fig = ggplot(df, aes(x='exp', y='vep', color='Cell Type')) + geom_point()
fig = fig + theme_classic() + theme(figure_size=(3.5, 2.2))
fig = fig + ylab('Predicted absolute logFC\n(background subtracted)')
fig = fig + xlab('Measured gene\n  expression')
fig = fig + scale_color_manual(['red', 'blue', 'orange', 'gray'])
fig

In [None]:
# Attributions aggregated over HSC and erythroid tasks, plotted for ref then alt over a
# 20-position window centred on the variant (window index 10), with shared y-limits.
attr_ref, attr_alt = get_attrs(
    gene, ref_allele, alt_allele, rel_pos, ['hematopoietic stem cell', 'erythroid lineage cell']
)
attr_plot_kwargs = dict(figsize=(5.5, 2), alpha=.25, highlight_positions=[10], ylim=(-.15, .29))
plot_attributions(attr_ref[:, rel_pos - 10:rel_pos + 10], **attr_plot_kwargs)
plot_attributions(attr_alt[:, rel_pos - 10:rel_pos + 10], **attr_plot_kwargs)

In [None]:
# Compare HOCOMOCO v12 motif matches between local ref and alt sequences around the variant;
# the two 18-nt sequences differ only at index 9 (C -> T).
cmp = compare_motifs(ref_seq='CGGGGTAACCGCCCGGCT', alt_seq='CGGGGTAACTGCCCGGCT', motifs='hocomoco_v12')
cmp

In [None]:
# RUNX2 motif logo (HOCOMOCO v12) for visual comparison with the attribution tracks.
plot_pwm(hocomoco['RUNX2.H12CORE.1.S.B'])