# Aggregate Borzoi and Decima predictions for GWAS-matched negative control variants

In [None]:
import anndata
import pandas as pd
import numpy as np
import os
import sys

## Paths

In [None]:
# Input matrix (AnnData on disk) for this data release.
matrix_file = '/gstore/data/resbioai/grelu/decima/20240823/data.h5ad'

# Root of the GWAS analysis and its positive / negative variant subfolders.
gwas_dir = '/gstore/data/resbioai/grelu/decima/20240823/gwas_44traits'
neg_dir = os.path.join(gwas_dir, 'negative_variants')
pos_dir = os.path.join(gwas_dir, 'positive_variants')

# Tables describing the negative variants and how they map to positives.
neg_file = os.path.join(neg_dir, 'negative_variants_processed.csv')
matching_file = os.path.join(neg_dir, 'negatives_matched.csv')

# Prediction files: raw array for negatives, aggregated AnnData for positives.
decima_pos_file = os.path.join(pos_dir, 'decima_preds_agg.h5ad')
decima_preds_file = os.path.join(neg_dir, 'decima_preds.npy')

## Load data

In [None]:
# Read the AnnData object stored at matrix_file. Its .obs table and .obs_names
# are used further down to group prediction columns by cell type.
ad = anndata.read_h5ad(matrix_file)

## Load variant-gene pairs

In [None]:
# Table of negative variant-gene pairs. Later cells merge it against the
# matched (variant, gene) pairs and use the selected rows as .obs of the output.
# NOTE(review): presumably contains 'variant' and 'gene' columns — confirm.
neg = pd.read_csv(neg_file)

## Load positive-to-negative mapping

In [None]:
# Matching table; its 'variant', 'pos_variant' and 'gene' columns are read below
# to link each negative variant to a positive variant-gene pair.
match = pd.read_csv(matching_file)

## Load negative predictions

In [None]:
# Raw prediction array for the negative variants.
# NOTE(review): later cells index rows with positions taken from `neg` and
# columns with a mask over ad.obs_names, so this presumably has shape
# (len(neg), ad.n_obs) in matching order — confirm against the producer.
decima_preds = np.load(decima_preds_file)
decima_preds.shape

## Load positive predictions

In [None]:
# Aggregated predictions for the positive variants. Only the 'variant' and
# 'gene' columns of its .obs table are used in the cells below.
decima_pos_preds = anndata.read_h5ad(decima_pos_file)

print(decima_pos_preds.shape)

## Average over all tracks of the same cell type

In [None]:
# Build a lookup from cell type to the tuple of ad.obs row labels that belong
# to it. reset_index() exposes the row labels as the 'index' column; groups
# whose aggregated value is missing are removed by dropna().
obs_table = ad.obs.reset_index()
tracks_by_cell_type = obs_table.groupby(['cell_type']).agg({'index': tuple})
idx_map = tracks_by_cell_type.reset_index().dropna()
idx_map.head()

In [None]:
# For every cell type, select the prediction columns whose track name is in
# that cell type's group and average them, giving one vector per cell type.
# NOTE(review): assumes the columns of decima_preds follow ad.obs_names order — confirm.
cell_type_means = [
    decima_preds[:, ad.obs_names.isin(track_ids)].mean(1)
    for track_ids in idx_map['index']
]
# Stack to (cell types, rows) and transpose so cell types become the columns.
decima_preds = np.stack(cell_type_means).T
decima_preds.shape

In [None]:
# .var table for the output AnnData: an empty frame indexed by cell type (as
# strings), in idx_map order so it lines up with the aggregated columns.
var = pd.DataFrame(index=idx_map.cell_type.astype(str))

## Subset the variant-gene pairs that are matched to the respective positive pairs

In [None]:
# Keep only the matching rows whose (pos_variant, gene) pair is present among
# the positive predictions. The merge yields two variant columns: 'variant_x'
# (the negative variant, from `match`) and 'variant_y' (the positive variant,
# identical to 'pos_variant'); the latter is dropped and the former renamed.
pos_pairs = decima_pos_preds.obs[['variant', 'gene']]
matched = match[['variant', 'pos_variant', 'gene']].merge(
    pos_pairs,
    left_on=['pos_variant', 'gene'],
    right_on=['variant', 'gene'],
)
matched = matched.drop(columns='variant_y')
decima_neg_pairs = matched.rename(columns={'variant_x': 'variant'})

In [None]:
# Number of negative variant-gene rows that were matched to a positive pair.
len(decima_neg_pairs)

In [None]:
# Same matching step as for decima_neg_pairs, applied to two other sets of
# positive predictions.
# NOTE(review): gene_pos_preds and tss_pos_preds are not defined anywhere in
# the visible source, so this cell raises NameError on a clean top-to-bottom
# run. Either load them alongside decima_pos_preds or remove this cell.
gene_neg_pairs = match[['variant', 'pos_variant', 'gene']].merge(
    gene_pos_preds.obs[['variant', 'gene']], left_on=['pos_variant', 'gene'], right_on=['variant', 'gene']).drop(
    columns='variant_y').rename(columns={'variant_x':'variant'})

tss_neg_pairs = match[['variant', 'pos_variant', 'gene']].merge(
    tss_pos_preds.obs[['variant', 'gene']], left_on=['pos_variant', 'gene'], right_on=['variant', 'gene']).drop(
    columns='variant_y').rename(columns={'variant_x':'variant'})

# Compare the number of matched rows across the three prediction sets.
len(decima_neg_pairs), len(gene_neg_pairs), len(tss_neg_pairs)

In [None]:
# Find the row positions in `neg` of the matched negative pairs.
# reset_index() exposes the row labels of `neg` as an 'index' column; the
# merge joins on the columns shared by both frames and keeps matching rows.
# NOTE(review): if decima_neg_pairs holds the same (variant, gene) pair more
# than once, the corresponding position is repeated in sel_decima — confirm
# whether that is intended.
neg_with_positions = neg.reset_index()
wanted_pairs = decima_neg_pairs[['variant', 'gene']]
matched_negatives = neg_with_positions.merge(wanted_pairs)
sel_decima = matched_negatives['index'].tolist()

In [None]:
# Number of selected rows of `neg` (and of the prediction matrix).
len(sel_decima)

In [None]:
# Row positions in `neg` for the gene- and TSS-based matched pairs, computed
# the same way as sel_decima.
# NOTE(review): gene_neg_pairs and tss_neg_pairs are derived from
# gene_pos_preds / tss_pos_preds, which are not defined in the visible source,
# and sel_gene / sel_tss are not used by any later visible cell — confirm
# whether this cell is still needed.
sel_gene = neg.reset_index().merge(
    gene_neg_pairs[['variant', 'gene']])['index'].tolist()

sel_tss = neg.reset_index().merge(
    tss_neg_pairs[['variant', 'gene']])['index'].tolist()

## Make anndata

In [None]:
# Assemble the output AnnData from the selected negative pairs:
#   X   - cell-type-averaged predictions for the selected rows
#   obs - the corresponding rows of `neg`, re-indexed from 0
#   var - one entry per cell type
selected_obs = neg.iloc[sel_decima].copy().reset_index(drop=True)
selected_X = decima_preds[sel_decima]
decima_preds = anndata.AnnData(X=selected_X, var=var, obs=selected_obs)
decima_preds.shape

## Save

In [None]:
# Write the aggregated negative-variant predictions next to the other
# negative-variant files, under the same file name used for the positives.
decima_out_file = os.path.join(neg_dir, 'decima_preds_agg.h5ad')
decima_preds.write_h5ad(decima_out_file)