In [None]:
import polars as pl
import pandas as pd
import os, sys
from tqdm import tqdm

from grelu.data.preprocess import filter_blacklist
from grelu.data.utils import get_chromosomes

## Paths

In [None]:
# --- Inputs ---
# Root of the benchmark GWAS tree; per-trait fine-mapping files are read from
# its `Complete/<trait_ID>/` subdirectories.
gwas_dir = '/gstore/data/humgenet/projects/statgen/GWAS/Benchmark_GWAS/'
# Space-separated, header-less list of traits (trait ID, study, trait name).
trait_file = os.path.join(gwas_dir, 'disease_list.txt')
# Header-less, tab-separated table of gnomAD regulatory variants.
gnomad_file = '/data/gnomAD/gnomad-regulatory-variants.tsv'

# --- Output ---
# Directory that receives `negative_variants.csv`.
out_dir = '/gstore/data/resbioai/grelu/decima/20240823/gwas_44traits/negative_variants'

## Load gnomAD regulatory variants within 100 kb of a TSS

In [None]:
%%time
snps = pl.read_csv(gnomad_file, has_header=False, separator='\t', columns=[0, 1, 2, 3, 4, 5, 8],
    new_columns=['chrom', 'pos', 'rsid', 'ref', 'alt', 'af', 'vep']).unique()
print(len(snps))
snps.head(3)

## Filter by allele frequency

In [None]:
# Fold the allele frequency onto the minor allele: maf = af if af <= 0.5 else 1 - af.
# This is a native polars expression, so it is evaluated column-wise instead of
# calling a Python lambda once per row through the deprecated `Series.apply`.
snps = snps.with_columns(
    maf=pl.when(pl.col('af') > 0.5)
    .then(1 - pl.col('af'))
    .otherwise(pl.col('af'))
)
# The raw frequency is no longer needed once `maf` exists.
# Positional form works across polars versions (the `columns=` keyword does not).
snps = snps.drop('af')
# MAF > 1%
snps = snps.filter(pl.col('maf') > 0.01)
len(snps)

## Filter by chromosome

In [None]:
%%time
snps = snps.filter(pl.col("chrom").is_in(get_chromosomes('autosomesXY')))
len(snps)

## Keep only SNPs whose ref and alt alleles are a single base (A, C, G or T)

In [None]:
%%time
snps = snps.filter(pl.col("ref").is_in(["A", "C", "G", "T"]))
snps = snps.filter(pl.col("alt").is_in(["A", "C", "G", "T"]))
len(snps)

## Load GWAS SNPs

In [None]:
# The trait list has no header row and is space-separated; name its columns here.
trait_columns = ['trait_ID', 'study', 'trait_name']
traits = pl.read_csv(
    trait_file,
    separator=' ',
    has_header=False,
    new_columns=trait_columns,
)
traits.head(2)

In [None]:
%%time
gwas = []
for row in tqdm(traits.iter_rows()):
    susie_file = os.path.join(gwas_dir, 'Complete', row[0], f'{row[1]}.susie.gwfinemap.b38.gz')
    df = pl.read_csv(susie_file, separator='\t',columns=[0,1,2, 3, 4, 6,9,10],
             new_columns = ['chrom', 'rsid', 'pos', 'ref', 'alt', 'MAF', 'p', 'PIP'])
    gwas.append(df)

gwas = pl.concat(gwas)
gwas.head()

## Get max PIP for each GWAS SNP

In [None]:
# One row per (chrom, pos, ref, alt, rsid), carrying the largest PIP
# observed for that variant across all loaded traits.
variant_keys = ['chrom', 'pos', 'ref', 'alt', 'rsid']
gwas = gwas.group_by(variant_keys).agg(pl.col('PIP').max())
gwas.head()

## Filter by GWAS PIP

In [None]:
# Remove every gnomAD variant whose rsid appears in the GWAS table with a
# max PIP above the cutoff; variants at or below it, or with an rsid not
# listed in the GWAS table, are kept. Matching is by rsid only.
pip_cutoff = 0.01
flagged_rsids = gwas.filter(pl.col('PIP') > pip_cutoff)['rsid']
is_flagged = pl.col('rsid').is_in(flagged_rsids)
snps = snps.filter(~is_flagged)
len(snps)

## Filter blacklist

In [None]:
%%time
snps = snps.to_pandas()
snps['start'] = snps['pos'].tolist()
snps['end'] = snps['start']+1
snps = snps[['chrom', 'start', 'end'] + [x for x in snps.columns if x not in ['chrom', 'start', 'end']]]
snps = filter_blacklist(snps, 'hg38')
snps = snps.drop(columns=['start', 'end'])
len(snps)

## Add variant ID

In [None]:
# Variant ID has the form <chrom>_<pos>_<ref>_<alt>.
id_parts = [snps['chrom'], snps['pos'].astype(str), snps['ref'], snps['alt']]
variant_id = id_parts[0]
for part in id_parts[1:]:
    variant_id = variant_id + '_' + part
snps['variant'] = variant_id

## Save

In [None]:
# `to_csv` does not create missing parent directories, so make sure the
# output directory exists before writing.
os.makedirs(out_dir, exist_ok=True)
out_file = os.path.join(out_dir, 'negative_variants.csv')
# Write without the pandas row index.
snps.to_csv(out_file, index=False)