In [None]:
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import sys
import bioframe as bf
import anndata

sys.path.append('/code/decima/src/decima/')
from variant import process_variants

## Paths

In [None]:
# matrix_file is read below with anndata.read_h5ad; h5_file is passed to the
# prediction scripts through their -h5_file argument.
matrix_file='/gstore/data/resbioai/grelu/decima/20240823/data.h5ad'
h5_file='/gstore/data/resbioai/grelu/decima/20240823/data.h5'

# Prediction files for the positive variants; only their .obs tables are used below.
gwas_dir='/gstore/data/resbioai/grelu/decima/20240823/gwas_44traits'
decima_preds_file = os.path.join(gwas_dir, 'positive_variants/decima_preds_agg.h5ad')
borzoi_gene_preds_file = os.path.join(gwas_dir, 'positive_variants/gene_preds_agg.h5ad')
borzoi_tss_preds_file = os.path.join(gwas_dir, 'positive_variants/tss_preds_agg.h5ad')

# out_dir holds the negative-variant input table and every file written by this notebook.
out_dir = os.path.join(gwas_dir, 'negative_variants')
neg_file = os.path.join(out_dir, 'negative_variants.csv')

## Load gene intervals

In [None]:
# Load the AnnData object; its .var table supplies the gene intervals used in the overlap below.
ad = anndata.read_h5ad(matrix_file)

## Load predictions on positive variants

In [None]:
# Only the per-row annotation table (.obs) of each prediction file is needed here.
decima_pos, borzoi_gene_pos, borzoi_tss_pos = (
    anndata.read_h5ad(preds_path).obs
    for preds_path in (decima_preds_file, borzoi_gene_preds_file, borzoi_tss_preds_file)
)

decima_pos.shape, borzoi_gene_pos.shape, borzoi_tss_pos.shape

## Combine all positive variant-gene pairs

In [None]:
# Columns shared by the three tables that identify a variant-gene pair.
pair_cols = ['variant', 'rsid', 'pos', 'gene', 'gene_start_', 'gene_end_', 'strand']

# Stack the three tables in the same order and keep each distinct row once.
positives = pd.concat(
    [obs[pair_cols] for obs in (decima_pos, borzoi_gene_pos, borzoi_tss_pos)]
).drop_duplicates()

## Load negative variants

In [None]:
# Read the table of negative variants; the displayed value is the row count before any filtering.
neg = pd.read_csv(neg_file)
len(neg)

## Keep only negatives that overlap a gene interval

In [None]:
# NOTE: this cell rebinds `neg`; re-running it alone applies the overlap to the
# already-joined table rather than to the table read from neg_file.
print(len(neg))

# Gene intervals come from ad.var, with its index exposed as a 'gene' column.
gene_intervals = ad.var.reset_index(names='gene')

# Each variant is passed as the interval (chrom, pos, pos); the inner join keeps
# one row per overlapping variant/gene-interval pair.
overlapping = bf.overlap(neg, gene_intervals, how='inner', cols1=['chrom', 'pos', 'pos'])

# The joined gene column arrives as 'gene_'; restore the plain name used downstream.
neg = overlapping.rename(columns={'gene_':'gene'})
print(len(neg))

## Match negatives to positives by distance to the TSS

In [None]:
# TSS is the gene start for '+' strand genes and the gene end for every other strand value.
# Computed with a vectorised np.where instead of a Python-level call per row.
positives['tss'] = np.where(positives.strand == '+', positives.gene_start_, positives.gene_end_)

In [None]:
# Unsigned distance between each positive variant and the TSS of its gene.
positives['abs_tss_dist'] = (positives.pos - positives.tss).abs()

In [None]:
def match_pos_variant(row, neg, n, min_dist=10, max_abs_tss_dist=150000):
    """Pick the negatives of the same gene whose TSS distance best matches a positive.

    Parameters
    ----------
    row : object with attributes ``variant``, ``gene``, ``pos``, ``tss`` and
        ``abs_tss_dist`` (for example a row produced by ``DataFrame.itertuples``).
        Describes one positive variant-gene pair.
    neg : pandas.DataFrame
        Candidate negative variants; must contain ``gene`` and ``pos`` columns.
    n : int
        Maximum number of negatives to return.
    min_dist : int, default 10
        A candidate must be strictly farther than this from both the positive
        variant and the TSS.
    max_abs_tss_dist : int, default 150000
        A candidate must be strictly closer than this to the TSS.

    Returns
    -------
    pandas.DataFrame
        Up to ``n`` rows of ``neg`` for ``row.gene``, ordered by increasing
        ``| abs_tss_dist(candidate) - row.abs_tss_dist |``, with the added columns
        ``pos_variant``, ``dist_to_pos``, ``tss_dist`` and ``abs_tss_dist``.
        Fewer rows (possibly none) are returned when fewer candidates qualify.
    """
    # Negatives assigned to the same gene; a fresh 0..k-1 index keeps labels unique
    # so the label-based selection at the end picks exactly one row per label.
    same_gene = neg[neg.gene == row.gene].reset_index(drop=True)

    # Compute every derived column before filtering: assigning columns onto a
    # boolean-filtered frame is what triggers pandas' SettingWithCopyWarning.
    candidates = same_gene.assign(
        pos_variant=row.variant,
        dist_to_pos=(same_gene.pos - row.pos).abs(),
        tss_dist=same_gene.pos - row.tss,
        abs_tss_dist=lambda df: df.tss_dist.abs(),
    )

    # Apply all three distance constraints in a single filter.
    keep = (
        (candidates.dist_to_pos > min_dist)
        & (candidates.abs_tss_dist > min_dist)
        & (candidates.abs_tss_dist < max_abs_tss_dist)
    )
    candidates = candidates[keep]

    # Rank by how closely the candidate's TSS distance matches the positive's.
    match_dist = (candidates.abs_tss_dist - row.abs_tss_dist).abs()
    closest = match_dist.sort_values().head(n).index

    return candidates.loc[closest]

In [None]:
# Match up to 10 negatives to every positive variant-gene pair and stack the results.
# itertuples() has no length, so the total is passed explicitly to give the
# progress bar a percentage and an ETA.
matched_neg = pd.concat([
    match_pos_variant(row, neg, n=10)
    for row in tqdm(positives.itertuples(), total=len(positives))
])

In [None]:
# Row counts and distinct-variant counts for the positives and their matched negatives.
(
    len(positives),
    positives.variant.nunique(dropna=False),
    len(matched_neg),
    matched_neg.variant.nunique(dropna=False),
)

In [None]:
# Smallest and largest number of matched rows per (positive variant, gene) pair.
matches_per_pair = matched_neg[['pos_variant', 'gene']].value_counts()
matches_per_pair.min(), matches_per_pair.max()

In [None]:
# Distribution of the absolute TSS distance among the positives.
positives.abs_tss_dist.abs().describe()

In [None]:
# Distribution of the absolute TSS distance among the matched negatives.
matched_neg.abs_tss_dist.abs().describe()

## Process variants for prediction

In [None]:
# Pass the matched negatives and the AnnData object to process_variants (imported from the local `variant` module).
# NOTE(review): this rebinds matched_neg, so re-running the cell feeds already-processed rows back in — confirm that is safe.
matched_neg = process_variants(matched_neg, ad)

In [None]:
# Preview the first rows of the processed table.
matched_neg.head()

## Get unique variant-gene pairs

In [None]:
# Columns kept for prediction. The per-match columns pos_variant, dist_to_pos and
# abs_tss_dist are left out, so identical variant-gene rows collapse to one.
pair_columns = [
    'chrom', 'pos', 'rsid', 'ref', 'alt', 'vep', 'maf', 'variant',
    'gene', 'chrom_', 'gene_type_', 'mean_counts_', 'n_tracks_', 'gene_start_', 'gene_end_',
    'gene_mask_start_', 'gene_mask_end_', 'dataset_', 'gene_id_', 'pearson_', 'size_factor_pearson_',
    'ensembl_canonical_tss_', 'tss_dist', 'start', 'end', 'strand', 'gene_mask_start',
    'rel_pos', 'ref_tx', 'alt_tx',
]
variant_gene_pairs = matched_neg[pair_columns].drop_duplicates()
len(variant_gene_pairs)

## Save

In [None]:
# Write the full matched table without the DataFrame index.
# index=False is the documented boolean form and matches the other to_csv call in this notebook.
matched_neg.to_csv(os.path.join(out_dir, 'negatives_matched.csv'), index=False)

In [None]:
# Write the de-duplicated variant-gene pairs; var_out_file is later passed to the
# prediction scripts through their -variant_df_file argument.
var_out_file = os.path.join(out_dir, 'negative_variants_processed.csv')
variant_gene_pairs.to_csv(var_out_file, index=False)

## Run Decima

In [None]:
# One GPU id per checkpoint; devices and ckpts are paired positionally by zip below.
devices = [0,1,3,4]
ckpts = ['/gstore/data/resbioai/grelu/decima/20240823/lightning_logs/kugrjb50/checkpoints/epoch=3-step=2920.ckpt',
        '/gstore/data/resbioai/grelu/decima/20240823/lightning_logs/0as9e8of/checkpoints/epoch=7-step=5840.ckpt',
        '/gstore/data/resbioai/grelu/decima/20240823/lightning_logs/i68hdsdk/checkpoints/epoch=2-step=2190.ckpt',
        '/gstore/data/resbioai/grelu/decima/20240823/lightning_logs/i9zsp4nm/checkpoints/epoch=8-step=6570.ckpt'
        ]
# Build and print one vep.py command per (device, checkpoint); nothing is executed here.
# Each command writes its predictions to decima_preds_<device>.npy in out_dir.
# NOTE(review): CUDA_VISIBLE_DEVICES={d} exposes a single GPU to the process, which is then
# usually addressed as index 0 — confirm that vep.py's -device expects the physical id {d}.
for d, c in zip(devices, ckpts):
    out_file = os.path.join(out_dir, f'decima_preds_{d}.npy')
    cmd = f"CUDA_VISIBLE_DEVICES={d} python /code/decima/scripts/vep.py \
-device {d} -ckpts {c} -h5_file {h5_file} -variant_df_file {var_out_file} \
-out_file {out_file}"
    print(cmd)

In [None]:
# Per-device prediction files, in the same order as `devices`.
files = [
    os.path.join(out_dir, f'decima_preds_{d}.npy')
    for d in devices
]

# Stack the per-file arrays along a new leading axis and average over it.
preds = np.mean(np.stack([np.load(npy_path) for npy_path in files]), axis=0)
preds.shape

In [None]:
# Save the averaged predictions; this rebinds out_file to the combined-output path.
out_file = os.path.join(out_dir, 'decima_preds.npy')
np.save(out_file, preds)

In [None]:
# Alternative invocation: a single vep.py command that receives all four checkpoints at once.
# NOTE(review): out_file is the same decima_preds.npy path that the averaged predictions are
# saved to with np.save, so running this command would overwrite that file — confirm intended.
device = 0
out_file = os.path.join(out_dir, 'decima_preds.npy')
ckpts = ['/gstore/data/resbioai/grelu/decima/20240823/lightning_logs/kugrjb50/checkpoints/epoch=3-step=2920.ckpt',
        '/gstore/data/resbioai/grelu/decima/20240823/lightning_logs/0as9e8of/checkpoints/epoch=7-step=5840.ckpt',
        '/gstore/data/resbioai/grelu/decima/20240823/lightning_logs/i68hdsdk/checkpoints/epoch=2-step=2190.ckpt',
        '/gstore/data/resbioai/grelu/decima/20240823/lightning_logs/i9zsp4nm/checkpoints/epoch=8-step=6570.ckpt'
        ]

# Checkpoint paths are space-joined into the -ckpts argument; the command is only printed.
cmd = f"CUDA_VISIBLE_DEVICES={device} python /code/decima/scripts/vep.py \
-device {device} -ckpts {' '.join(ckpts)} \
-h5_file {h5_file} -variant_df_file {var_out_file} \
-out_file {out_file}"
print(cmd)

## Run Borzoi

In [None]:
# Build and print the vep_borzoi.py command for the same variant table; it takes an
# output directory (-out_dir) rather than a single output file. Nothing is executed here.
# NOTE(review): CUDA_VISIBLE_DEVICES={device} exposes a single GPU to the process — confirm
# that vep_borzoi.py's -device expects the physical id rather than the visible index.
device = 1
cmd = f"CUDA_VISIBLE_DEVICES={device} python /code/decima/scripts/vep_borzoi.py \
-device {device} -h5_file {h5_file} -variant_df_file {var_out_file} \
-out_dir {out_dir}"
print(cmd)