Purpose: compare the Manhattan plots for the nicsa traits to determine which trait is optimal to use.

In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import statsmodels.stats.multitest
import statsmodels.api as sm 
import pylab as py 
import statsmodels.api as sm

In [2]:
# Switch to the rare_common_alcohol notebooks directory; presumably needed so the
# local `rca_functions` module imported in the next cell can be found — confirm.
os.chdir('/tscc/projects/ps-palmer/brittany/rare_common_alcohol/rare_common_alcohol_comparison/notebooks/')

In [3]:
from rca_functions import manhattan
from rca_functions import porcupine

In [4]:
# Switch to the SUD_cross_species project directory (after rca_functions has been
# imported); relative paths used later in the notebook presumably resolve from here.
os.chdir('/tscc/projects/ps-palmer/brittany/SUD_cross_species/')

In [6]:
def porcupine(pval, test, pos, chr, label,
              cut_SKAT=5e-8,
              cut_SKATO=5e-8,
              cut_burden=5e-8,
              chrs_plot=None, chrs_names=None,
              cut=2,
              colors=['k', '0.5'],
              title='Title',
              xlabel='chromosome',
              ylabel='-log10(p-value)',
              top=0,
              lines=[10, 15],
              lines_colors=['g', 'r'],
              lines_styles=['-', '--'],
              lines_widths=[1, 1],
              zoom=None,
              scaling='-log10',
              plot_grid_lines=True,
              color_dict=None,
              **kwargs):
    """
    Draw a Porcupine (multi-test Manhattan) plot for rare-variant gene tests.

    Every point whose scaled p-value exceeds ``cut`` is drawn in its
    chromosome's alternating colour. Points that also pass their test-specific
    cutoff are over-plotted in the colour of their test type (Burden, SKAT-O,
    SKAT). The list-valued defaults are only read, never mutated.

    Parameters:
    - pval (array-like): p-values, one per gene/test row.
    - test (array-like): test type of each row. Rows equal to 'Burden', 'SKATO'
      or 'SKAT' are eligible for highlighting.
    - pos (array-like): position of each row on its chromosome.
    - chr (array-like): chromosome of each row.
    - label (str): text drawn in the upper-right corner of the plot.
    - cut_SKAT, cut_SKATO, cut_burden (float): per-test cutoffs given as raw
      p-values. They are converted with -log10 when scaling == '-log10'.
    - chrs_plot (list, optional): chromosomes to plot, in order. Defaults to all
      chromosomes in ``chr``: ascending for numbers, natural order for strings
      (so '2' comes before '10').
    - chrs_names (list, optional): x tick labels. Defaults to str(chromosome).
    - cut (float): only points with a scaled value > cut are drawn. This is also
      the lower y-limit.
    - colors (list): colours cycled across chromosomes.
    - title, xlabel, ylabel (str): plot title and axis labels.
    - top (float): upper y-limit. 0 means derive it from the data.
    - lines (list): y-values of horizontal reference lines.
    - lines_colors, lines_styles, lines_widths (list): styling for ``lines``.
      Each must be at least as long as ``lines``.
    - zoom (tuple, optional): (chromosome, center position, half-width).
      Restricts the x-axis to center +/- half-width on that chromosome.
    - scaling (str): '-log10' plots -log10(p). Any other value plots raw values.
    - plot_grid_lines (bool): draw a light vertical line after each chromosome.
    - color_dict (dict, optional): colours keyed by 'Burden', 'SKAT-O' and
      'SKAT'. If omitted, a module-level ``color_dict`` is used when one is
      defined; otherwise a default matplotlib palette is used.
    - **kwargs: forwarded to every ``plt.plot`` call.

    Returns:
    matplotlib.pyplot: the pyplot module with the figure drawn on it.

    Raises:
    - ValueError: if ``zoom`` names a chromosome that is not plotted.
    """
    import re

    def _natural_key(name):
        # Split 'chr10' into ['chr', 10, ''] so numeric runs compare as integers.
        return [int(part) if part.isdigit() else part
                for part in re.split(r'(\d+)', str(name))]

    # np.where() below returns positions. Plain numpy arrays guarantee that
    # positional indexing is used, even when callers pass pandas Series with a
    # non-default index.
    pval = np.asarray(pval, dtype=float)
    test = np.asarray(test)
    pos = np.asarray(pos)
    chr = np.asarray(chr)

    # Resolve the per-test highlight colours. This keeps the old behaviour of
    # reading a global `color_dict`, but no longer requires it to exist.
    if color_dict is None:
        color_dict = globals().get('color_dict')
    if color_dict is None:
        color_dict = {'Burden': 'tab:blue', 'SKAT-O': 'tab:orange', 'SKAT': 'tab:green'}

    # Initialize plot settings and clear any existing figure.
    # shift[k] is the cumulative x offset at which chromosome k starts.
    shift = np.array([0.0])
    plt.clf()
    plt.subplot(1, 1, 1)

    # Determine which chromosomes to plot.
    if chrs_plot is None:
        chrs_list = np.unique(chr)  # already ascending for numeric chromosomes
        if len(chrs_list) and isinstance(chrs_list[0], str):
            chrs_list = sorted(chrs_list, key=_natural_key)
    else:
        chrs_list = chrs_plot

    # Generate chromosome labels if not provided.
    if chrs_names is None:
        chrs_names = [str(c) for c in chrs_list]

    # With a single chromosome, keep matplotlib's position ticks.
    plot_positions = len(chrs_list) == 1

    # Convert cutoffs to the plotted scale.
    if scaling == '-log10':
        cut_burden = -np.log10(cut_burden)
        cut_SKATO = -np.log10(cut_SKATO)
        cut_SKAT = -np.log10(cut_SKAT)

    # (value in `test`, key in color_dict, cutoff), plotted in this order.
    test_styles = (
        ('Burden', 'Burden', cut_burden),
        ('SKATO', 'SKAT-O', cut_SKATO),
        ('SKAT', 'SKAT', cut_SKAT),
    )

    zoom_shift = None
    for ii, i in enumerate(chrs_list):
        filt = np.where(chr == i)[0]
        x = shift[-1] + pos[filt]
        y = -np.log10(pval[filt]) if scaling == '-log10' else pval[filt]
        test_filter = test[filt]

        # Background points in the alternating chromosome colour.
        above = y > cut
        plt.plot(x[above], y[above], '.', color=colors[ii % len(colors)], **kwargs)
        # Over-plot points passing their test-specific cutoff.
        for test_name, color_key, test_cut in test_styles:
            sig = above & (test_filter == test_name) & (y > test_cut)
            plt.plot(x[sig], y[sig], '.', color=color_dict[color_key], **kwargs)

        # Genome-wide x coordinate of the zoom center, if it is on this chromosome.
        if zoom is not None and zoom[0] == i:
            zoom_shift = shift[-1] + zoom[1]

        # The next chromosome starts after the last point of this one.
        # An empty chromosome adds no width.
        shift_f = np.max(x) if x.size else shift[-1]
        shift = np.append(shift, max(shift_f, 0))

        # Set grid lines and limits.
        if plot_grid_lines:
            plt.plot([shift[-1], shift[-1]], [0, 1000], '-', lw=0.5, color='lightgray', **kwargs)
        plt.xlim([0, shift[-1]])

    # Determine the upper limit for the y-axis.
    if top == 0:
        top = np.ceil(np.max(-np.log10(pval))) if scaling == '-log10' else np.ceil(np.max(pval))

    # Tick positions are the midpoints of each chromosome's x span.
    shift_label = shift[-1]
    shift = (shift[1:] + shift[:-1]) / 2
    for i, line_height in enumerate(lines):
        plt.axhline(y=line_height, color=lines_colors[i], linestyle=lines_styles[i], linewidth=lines_widths[i])

    plt.ylim([cut, top])
    plt.title(title)
    if not plot_positions:
        plt.xticks(shift, chrs_names)
    plt.text(shift_label * 0.95, top * 0.95, label, verticalalignment='top', horizontalalignment='right')
    plt.ylabel(ylabel)
    plt.xlabel(xlabel)

    # Apply zoom settings if specified.
    if zoom is not None:
        if zoom_shift is None:
            raise ValueError(f'zoom chromosome {zoom[0]!r} is not among the plotted chromosomes')
        plt.xlim([zoom_shift - zoom[2], zoom_shift + zoom[2]])

    return plt

<function rca_functions.porcupine(pval, test, pos, chr, label, cut_SKAT=5e-08, cut_SKATO=5e-08, cut_burden=5e-08, chrs_plot=None, chrs_names=None, cut=2, colors=['k', '0.5'], title='Title', xlabel='chromosome', ylabel='-log10(p-value)', top=0, lines=[10, 15], lines_colors=['g', 'r'], lines_styles=['-', '--'], lines_widths=[1, 1], zoom=None, scaling='-log10', plot_grid_lines=True, **kwargs)>

In [None]:
    def porcupineplotv2(self, qtltable = '', traitlist: list = [], display_figure = False, skip_manhattan = False, maxtraits = 60):
        """Draw per-trait Manhattan plots and one combined porcupine plot of QTLs.

        For each trait, the GCTA ``.mlma`` files found under
        ``{self.path}results/gwas/`` are concatenated. Positions are shifted onto
        a genome-wide x axis. Unless ``skip_manhattan`` is set, a Manhattan plot
        is saved to ``{self.path}images/manhattan/{trait}.png``.

        At most ``maxtraits`` traits go into the combined porcupine plot, saved to
        ``{self.path}images/porcupineplot.png``. These are every trait with a
        QTL, padded with randomly chosen other traits. QTL peaks are marked with
        triangles. Blue and red lines mark ``self.threshold05`` and
        ``self.threshold``.

        Parameters:
        - qtltable: DataFrame of QTLs. Columns read include trait, Chr, bp, p,
          SNP, A1, A2, Freq, b. Its 'p' column is plotted directly on the
          -log10(p) axis, so it is presumably already -log10 transformed —
          confirm. An empty string loads ``{self.path}results/qtls/finalqtl.csv``
          filtered to ``QTL == True``. The table gains 'color', 'traitnum' and
          'x' columns in place.
        - traitlist (list): trait names without the 'regressedlr_' prefix.
          Defaults to every trait in ``self.traits``. It is never mutated.
        - display_figure (bool): display the porcupine plot and return None.
        - skip_manhattan (bool): skip the per-trait Manhattan plots. Only the
          GWAS of traits used in the porcupine plot are then read.
        - maxtraits (int): maximum number of traits drawn in the porcupine plot.

        Returns:
        The holoviews porcupine figure, or None when ``display_figure`` is True.
        """
        from functools import reduce
        printwithlog('starting porcupine plot v2')
        hv.opts.defaults(hv.opts.Points(width=1200, height=600), hv.opts.RGB(width=1200, height=600) )
        if isinstance(qtltable, str):
            if not len(qtltable): qtltable = pd.read_csv(f'{self.path}results/qtls/finalqtl.csv').reset_index().query('QTL == True')
        if not len(traitlist): traitlist = list(map(lambda x:x.replace('regressedlr_', ''),self.traits))
        # One tab20 colour and one short numeric id per trait, assigned in sorted order.
        cmap = sns.color_palette("tab20", len(traitlist))
        d = {t: cmap[v] for v,t in enumerate(sorted(traitlist))}
        tnum = {t:num for num,t in enumerate(sorted(traitlist))}
        qtltable['color'] =  qtltable.trait.apply(lambda x: d[x])
        qtltable['traitnum'] =  qtltable.trait.apply(lambda x: f'{tnum[x]}')
        # Keep every trait with a QTL. Pad with random other traits up to maxtraits.
        if len(traitlist) > maxtraits:
            traitlist_new = list(qtltable.trait.unique())
            if maxtraits - len(traitlist_new) > 0:
                traitlist_new += list(np.random.choice(list(set(traitlist) - set(traitlist_new)), maxtraits - len(traitlist_new), replace = False))
        else: traitlist_new = traitlist
        fdf = []
        h2file = pd.read_csv(f'{self.path}results/heritability/heritability.tsv', sep = '\t', index_col = 0).rename(lambda x: x.replace('regressedlr_', ''))
        for t in tqdm(traitlist):
            if skip_manhattan and t not in traitlist_new: continue
            trait_name = t.replace("regressedlr_", "")
            # Collect whichever result layouts exist: LOCO, whole-genome or per-chromosome.
            df_gwas = []
            for opt in [f'regressedlr_{trait_name}.loco.mlma',
                        f'regressedlr_{trait_name}.mlma']+ \
                       [f'regressedlr_{trait_name}_chrgwas{chromp2}.mlma' for chromp2 in self.chrList()]:
                if glob(f'{self.path}results/gwas/{opt}'):
                    df_gwas += [pd.read_csv(f'{self.path}results/gwas/{opt}', sep = '\t')]
            if len(df_gwas) == 0:
                # Nothing to plot for this trait; pd.concat([]) would raise.
                printwithlog(f'could not open mlma files for {t}')
                continue
            df_gwas = pd.concat(df_gwas)
            # Offset of each chromosome on the genome-wide x axis (sum of earlier chromosome lengths).
            append_position = df_gwas.groupby('Chr').bp.agg('max').sort_index().cumsum().shift(1,fill_value=0)
            qtltable['x'] = qtltable.apply(lambda x: x.bp +  append_position[x.Chr], axis = 1)
            df_gwas['-log10p'] = -np.log10(df_gwas.p)
            df_gwas.drop(['A1', 'A2', 'Freq', 'b', 'se', 'p'], axis = 1, inplace = True)
            def mapcolor(c):
                # Alternate black/gray by chromosome number, mapping X/Y/MT to numbers first.
                if int(str(c).replace('X',str(self.n_autosome+1)).replace('Y', str(self.n_autosome+2)).replace('MT', str(self.n_autosome+4)))%2 == 0: return 'black'
                return 'gray'
            # .iloc: each group keeps its original index labels, so label 0 may be absent.
            df_gwas = df_gwas.groupby('Chr') \
                             .apply(lambda df: df.assign(color = mapcolor(df.Chr.iloc[0]), x = df.bp + append_position[df.Chr.iloc[0]])) \
                             .reset_index(drop = True)
            # Significant SNPs get the trait colour as an RGB float tuple.
            above = df_gwas['-log10p']> self.threshold
            df_gwas.loc[above, 'color' ] = str(d[t])[1:-1]
            df_gwas.loc[above, 'color' ] = df_gwas.loc[above, 'color' ].str.split(',').map(lambda x: tuple(map(float, x)))
            if not skip_manhattan:
                yrange = (-.05,max(6, df_gwas['-log10p'].max()+.5))
                xrange = tuple(df_gwas.x.agg(['min', 'max'])+ np.array([-1e7,+1e7]))
                fig = []
                for idx, dfs in df_gwas[df_gwas.color.isin(['gray', 'black'])].groupby('color'):
                    temp = datashade(hv.Points(dfs, kdims = ['x','-log10p']), pixel_ratio= 2, aggregator=ds.count(), width = 1200,height = 600, y_range= yrange,
                             min_alpha=.7, cmap = [idx], dynamic = False )
                    temp = dynspread(temp, max_px=4,threshold= 1 )
                    fig += [temp]
                # Overlay however many background layers exist (one or two).
                fig = reduce(lambda a, b: a * b, fig)
                fig = fig*hv.HLine((self.threshold05)).opts(color='blue')*hv.HLine(self.threshold).opts(color='red')
                fig = fig*hv.Points(df_gwas[df_gwas['-log10p']> self.threshold].drop('color', axis = 1),
                                    kdims = ['x','-log10p']).opts(color = 'red', size = 5)
                figh2 = round(h2file.loc[trait_name,'V(G)/Vp'],3)
                fig = fig.opts(xticks=[((dfs.x.agg(['min', 'max'])).sum()//2 , self.replacenumstoXYMT(names)) for names,dfs in  df_gwas.groupby('Chr')],
                               xlim =xrange, ylim=yrange, width = 1200,height = 600,  xlabel='Chromosome',
                               title = f'{trait_name} n={self.df["regressedlr_"+ trait_name].count()} h2={figh2}')
                hv.save(fig, f'{self.path}images/manhattan/{trait_name}.png')
            if t in traitlist_new: fdf += [df_gwas]
        # Combined porcupine plot over the selected traits.
        fdf = pd.concat(fdf).reset_index(drop = True).sort_values('x')
        fig = []
        yrange = (-.05,max(6, fdf['-log10p'].max()+.5))
        xrange = tuple(fdf.x.agg(['min', 'max'])+ np.array([-1e7,+1e7]))
        for idx, dfs in fdf[fdf.color.isin(['gray', 'black'])].groupby('color'):
            temp = datashade(hv.Points(dfs, kdims = ['x','-log10p']), pixel_ratio= 2, aggregator=ds.count(), width = 1200,height = 600, y_range= yrange,
                     min_alpha=.7, cmap = [idx], dynamic = False )
            temp = dynspread(temp, max_px=4,threshold= 1 )
            fig += [temp]
        fig = reduce(lambda a, b: a * b, fig)

        fig = fig*hv.HLine((self.threshold05)).opts(color='blue')
        fig = fig*hv.HLine(self.threshold).opts(color='red')

        # Significant SNPs, one layer per trait colour.
        for idx, dfs in fdf[~fdf.color.isin(['gray', 'black'])].groupby('color'):
            fig = fig*hv.Points(dfs.drop('color', axis = 1), kdims = ['x','-log10p']).opts(color = idx, size = 5)

        # QTL peaks as triangles with hover info, labelled by trait number.
        for t, dfs in qtltable.groupby('trait'):
            fig = fig*hv.Points(dfs.assign(**{'-log10p': qtltable.p}), kdims = ['x','-log10p'],vdims=[ 'trait','SNP' ,'A1','A2','Freq' ,'b','traitnum'], label = f'({tnum[t]}) {t}' ) \
                                          .opts(size = 17, color = d[t], marker='inverted_triangle', line_color = 'black', tools=['hover'])
        fig = fig*hv.Labels(qtltable.rename({'p':'-log10p'}, axis = 1)[['x', '-log10p', 'traitnum']],
                            ['x','-log10p'],vdims=['traitnum']).opts(text_font_size='5pt', text_color='black')
        fig.opts(xticks=[((dfs.x.agg(['min', 'max'])).sum()//2 , self.replacenumstoXYMT(names)) for names, dfs in fdf.groupby('Chr')],
                                   xlim =xrange, ylim=yrange, xlabel='Chromosome', shared_axes=False,
                               width=1200, height=600, title = f'porcupineplot',legend_position='right',show_legend=True)
        hv.save(fig, f'{self.path}images/porcupineplot.png')
        if display_figure:
            display(fig)
            return
        return fig

In [None]:
    def prune_genotypes(self):
        """Keep one SNP per group of SNPs with byte-identical genotype vectors.

        The genotype matrix of ``self.genotypes_subset`` is read with
        pandas_plink. Variants are grouped by the exact bytes of their int8
        genotype row, with missing calls encoded as -1.

        Writes:
        - ``{self.path}pvalthresh/genomaping.parquet.gz``: each distinct
          genotype vector and the '|'-joined row numbers sharing it.
        - ``{self.path}pvalthresh/prunned_dup_snps.in``: the first SNP id of
          each group.
        - ``{self.path}genotypes/prunedgenotypes``: a plink bed/bim/fam set
          restricted to those SNPs.

        Returns:
        str: a message naming the output genotype prefix.
        """
        printwithlog('starting genotype prunner...')
        snps,_,gens = pandas_plink.read_plink(self.genotypes_subset)
        # `nan` must be a keyword: nan_to_num's second positional parameter is
        # `copy`, so passing -1 positionally left NaN -> 0, which collides with
        # the real 0 genotype during duplicate detection.
        gens = da.nan_to_num(gens, nan=-1).astype(np.int8)
        def prune_dups(array):
            # Map each row's raw bytes to the list of row numbers with that content.
            groups = defaultdict(list)
            for num, row in enumerate(array): groups[row.tobytes()].append(num)
            return groups
        printwithlog('starting genotype dups finder...')
        pruned = prune_dups(gens.compute())
        # Row numbers are positions, so look SNP ids up positionally.
        first_snps = [snps['snp'].iloc[v[0]] for v in pruned.values()]
        printwithlog(f'saving resulst to:\n1){self.path}pvalthresh/genomaping.parquet.gz\n2){self.path}pvalthresh/prunned_dup_snps.in\n3){self.path}genotypes/prunedgenotypes')
        prunedset = pd.DataFrame([[k,'|'.join(map(str, v))] for k,v in pruned.items()], columns = ['genotypes', 'snps'])
        prunedset.to_parquet(f'{self.path}pvalthresh/genomaping.parquet.gz', compression = 'gzip')
        pd.DataFrame(first_snps).to_csv(f'{self.path}pvalthresh/prunned_dup_snps.in', index = False, header = None)
        plink(bfile=self.genotypes_subset,  thread_num = self.threadnum, extract = f'{self.path}pvalthresh/prunned_dup_snps.in',
              make_bed = '', out = f'{self.path}genotypes/prunedgenotypes')
        printwithlog(f'prunned data has {format(prunedset.shape[0], ",")} out of the original {format(gens.shape[0], ",")}')
        return f'saved prunned genotypes to {self.path}genotypes/prunedgenotypes'
        
    def estimate_pval_threshold(self, replicates = 1000, sample_size = 'all', exact_prunner = True ,prunning_window = 5000000, prunning_step = 1000, remove_after = False):
        """Estimate GWAS significance thresholds from GWAS runs on random traits.

        Each replicate works as follows:
        - It draws a standard-normal phenotype, seeded by the replicate id.
        - It keeps each animal with probability ``sample_size / len(self.df)``.
        - It runs GCTA ``--mlma`` per chromosome on the pruned genotypes.
        - It records the smallest log10(p) across the genome.

        Percentiles of |log10 p| over all replicates are the thresholds.

        Parameters:
        - replicates (int): number of random-trait GWAS runs.
        - sample_size: 'all', an animal count, or a fraction (< 1) of ``self.df``.
        - exact_prunner (bool): True collapses identical SNPs with
          ``prune_genotypes``. False prunes with plink ``--indep-pairwise
          {prunning_window} {prunning_step} 0.999``.
        - prunning_window, prunning_step: window and step for plink pruning.
        - remove_after (bool): delete each replicate's GWAS directory after use.

        Side effects:
        - Writes ``maxpvaltable*.csv``, ``threshfig*.png`` and
          ``PVALTHRESHOLD.csv`` under ``{self.path}pvalthresh/``.
        - Sets ``self.threshold`` (10% percentile) and ``self.threshold05``
          (5% percentile).

        Returns:
        DataFrame: the |log10 p| summary with a single 'thresholds' column.
        """
        printwithlog('starting P-value threshold calculation...')
        if sample_size == 'all': sample_size = len(self.df)
        # A value below 1 is a fraction of the animals; convert it to a count.
        if sample_size < 1: sample_size = round(len(self.df)*sample_size)
        cline = Client( processes = False) if not client._get_global_client() else client._get_global_client()
        os.makedirs(f'{self.path}pvalthresh', exist_ok = True)
        os.makedirs(f'{self.path}pvalthresh/gwas/', exist_ok = True)
        os.makedirs(f'{self.path}pvalthresh/randomtrait/', exist_ok = True)

        if exact_prunner: self.prune_genotypes()
        if not exact_prunner:
            prunning_params = f'{prunning_window} {prunning_step} 0.999'
            printwithlog(f'prunning gentoypes using {prunning_params}')
            if not os.path.exists(f'{self.path}pvalthresh/pruned_data.prune.in'):
                plink(bfile=self.genotypes_subset, indep_pairwise = prunning_params, out = f'{self.path}pvalthresh/pruned_data', thread_num = self.threadnum)
                plink(bfile=self.genotypes_subset,  thread_num = self.threadnum, extract = f'{self.path}pvalthresh/pruned_data.prune.in',
                      make_bed = '', out = f'{self.path}genotypes/prunedgenotypes')
            # [kept, removed] SNP counts; the original total is kept + removed.
            npruned_snps = [pd.read_csv(f'{self.path}pvalthresh/pruned_data.prune.{i}', header = None).shape[0] for i in ['in', 'out']]
            display(f'prunned data has {npruned_snps[0]} out of the original {sum(npruned_snps)}' )

        def get_maxp_1sample(ranid, skip_already_present = True, remove_after = True ):
            # Return min over chromosomes of log10(min p) (<= 0) for random trait `ranid`.
            os.makedirs(f'{self.path}pvalthresh/gwas/{ranid}', exist_ok = True)
            r = np.random.RandomState(ranid)
            valuelis = r.normal(size = self.df.shape[0])
            # Mask animals to NaN ('NA' in the phenotype file) to subsample.
            valuelis *= r.choice([1, np.nan],size = self.df.shape[0] ,
                                 p = [sample_size/self.df.shape[0], 1-sample_size/self.df.shape[0]])
            self.df[['rfid', 'rfid']].assign(trait = valuelis).fillna('NA').astype(str).to_csv(f'{self.path}pvalthresh/randomtrait/{ranid}.txt',  index = False, sep = ' ', header = None)
            maxp = 0
            for c in self.chrList():
                chrom = self.replaceXYMTtonums(c)
                filename = f'{self.path}pvalthresh/gwas/{ranid}/chrgwas{c}'
                if os.path.exists(f'{filename}.mlma') and skip_already_present: pass
                else:
                    subgrmflag = f'--mlma-subtract-grm {self.path}grm/{c}chrGRM' if c not in ['x','y'] else ''
                    bash(f'{self.gcta} --thread-num 1 --pheno {self.path}pvalthresh/randomtrait/{ranid}.txt --bfile {self.path}genotypes/prunedgenotypes \
                                               --grm {self.path}grm/AllchrGRM --autosome-num {self.n_autosome} \
                                               --chr {chrom} {subgrmflag} --mlma \
                                               --out {filename}',
                                 print_call = False)
                if os.path.exists(f'{filename}.mlma'): chrmaxp = np.log10(pd.read_csv(f'{filename}.mlma', sep = '\t')['p'].min())
                else: chrmaxp = 0
                # Values are <= 0, so "smaller" means a more significant p-value.
                if chrmaxp < maxp: maxp = chrmaxp
            if remove_after:
                bash(f'rm -r {self.path}pvalthresh/gwas/{ranid}')
            return maxp

        looppd = pd.DataFrame(range(replicates), columns = ['reps'])
        loop   = dd.from_pandas(looppd, npartitions=min(replicates, 200))
        def _gwas_pval(df, rmv_aft):
            ret = df['reps'].map(lambda x: get_maxp_1sample(x,  skip_already_present = True, remove_after = rmv_aft))
            if len(df) != len(ret): print(ret)
            return ret
        # Explicit float meta: each partition returns a float Series of log10 p values.
        gwas_series = loop.map_partitions(lambda x: _gwas_pval(x, rmv_aft = remove_after), meta = pd.Series(dtype = 'float64'))
        printwithlog(f'running gwas for {replicates} replicates')
        future = cline.compute(gwas_series)
        progress(future,notebook = False,  interval="300s")
        wait(future)
        out = looppd.assign(maxp = future.result())
        # Threshold stability vs. number of replicates, resampled with and without replacement.
        for tf in [True, False]:
            maxrange = 2000 if tf else len(out)
            lis = pd.concat([out.sample(n = x, replace = tf)['maxp'].describe(percentiles=[.1, .05, 0.01, 1e-3, 1e-4]).abs().to_frame().rename({'min': x}, axis = 1) \
                   for x in np.linspace(1, maxrange, 200).round().astype(int)], axis = 1)
            lis = lis.T.reset_index(drop = True).rename({'count': 'samplesize'}, axis = 1)
            lis = lis.drop(['mean', 'min', 'max'], axis = 1).fillna(0)
            melted = lis.melt(id_vars=['samplesize'], value_vars=lis.columns[1:], value_name='pval')
            lis.to_csv(f'{self.path}pvalthresh/maxpvaltable{"with" if tf else "without"}replacement.csv', index = False)
            fig = sns.lmplot(x="samplesize", y="pval",
                 hue="variable",  data=melted.query('variable != "std"'),logx= True,height = 5, aspect=2 )
            fig.savefig(f'{self.path}pvalthresh/threshfig{"with" if tf else "without"}replacement.png')
        oo = out['maxp'].describe(percentiles=[.1, .05, 0.01, 1e-3, 1e-4]).to_frame().set_axis(['thresholds'], axis =1).abs()
        oo.to_csv(f'{self.path}pvalthresh/PVALTHRESHOLD.csv')
        display(oo)
        printwithlog(f"new_thresholds = 5% : {oo.loc['5%','thresholds']} , 10% : {oo.loc['10%','thresholds']}")
        self.threshold = oo.loc['10%','thresholds']
        self.threshold05 = oo.loc['5%','thresholds']
        return oo
        
        