In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import os.path

In [19]:
########### GENERAL ###########

def get_xscorer_nbgenes(gwasA, gwasB, crossDir):
    """Count the significant coherent and anti-coherent genes of a trait pair.

    Delegates the file lookup and the significance filtering to
    ``get_xscorer_signif_genes`` so that both functions always apply the same
    rule (this function used to carry its own copy of that logic).

    Args:
        gwasA, gwasB: identifiers of the two GWAS; the result files may be
            named with the pair in either order.
        crossDir: directory prefix holding the xscorer result CSVs.

    Returns:
        tuple ``(cohe_nb, anti_nb)`` of ints: number of rows passing the
        threshold in the coherence and in the anti-coherence table.
    """
    cohe_df, anti_df = get_xscorer_signif_genes(gwasA, gwasB, crossDir)
    return len(cohe_df), len(anti_df)

def get_xscorer_signif_genes(gwasA, gwasB, crossDir):
    """Load one trait pair's xscorer results and keep the significant genes.

    The result files exist for one ordering of the pair only: ``gwasA__gwasB``
    is tried first and ``gwasB__gwasA`` is used otherwise. The significance
    threshold is 0.05 divided by the number of rows of the coherence table,
    and that same threshold is applied to both tables.

    Args:
        gwasA, gwasB: identifiers of the two GWAS.
        crossDir: directory prefix of the CSV files. It is concatenated as-is
            with the file name, so it has to end with a path separator.

    Returns:
        tuple ``(cohe_df, anti_df)``: the rows of the coherence and of the
        anti-coherence table whose ``pval`` is strictly below the threshold.
    """
    def result_path(first, second, mode):
        return f'{crossDir}{first}__{second}__{mode}__xscorer_results.csv'

    # Decide which ordering of the pair was used when the files were written
    if os.path.exists(result_path(gwasA, gwasB, 'coherence')):
        first, second = gwasA, gwasB
    else:
        first, second = gwasB, gwasA

    coherent = pd.read_csv(result_path(first, second, 'coherence'))
    anticoherent = pd.read_csv(result_path(first, second, 'anti-coherence'))

    threshold = 0.05 / len(coherent)
    coherent_signif = coherent[coherent['pval'] < threshold]
    anticoherent_signif = anticoherent[anticoherent['pval'] < threshold]
    return coherent_signif, anticoherent_signif

def get_xscorer_gene_list(listA, listB, crossDir):
    """List every significant gene with the trait pairs in which it was found.

    Each unordered pair (gwasA from listA, gwasB from listB) is visited once,
    and a trait is never paired with itself. For every pair the significant
    coherent genes are recorded first, then the significant anti-coherent
    ones. A gene that is significant in both tables of the same pair is
    therefore counted twice and that pair appears twice in its list.

    Args:
        listA, listB: GWAS identifiers (may be the same list).
        crossDir: directory prefix holding the xscorer result CSVs.

    Returns:
        DataFrame with one row per gene, in order of first appearance:
        ``gene_name``; ``pairs_nb`` (number of recorded hits); ``pairs_list``
        ('gwasA-gwasB' labels joined with '; ').
    """
    # gene -> pair labels, in the order found. Dicts keep insertion order, so
    # the rows come out in first-appearance order. Accumulating here and
    # building the frame once avoids a concat and several full-column scans
    # for every single hit.
    pairs_by_gene = {}
    gwasA_done = set()  # traits already used as gwasA: avoids duplicate pairs when listA == listB
    for gwasA in listA:
        gwasA_done.add(gwasA)
        for gwasB in listB:
            if gwasB in gwasA_done:  # also covers gwasA == gwasB
                continue
            cohe_df, anti_df = get_xscorer_signif_genes(gwasA, gwasB, crossDir)
            pair_label = gwasA + '-' + gwasB
            for signif_df in (cohe_df, anti_df):
                for gene in signif_df['gene_name']:
                    pairs_by_gene.setdefault(gene, []).append(pair_label)

    if not pairs_by_gene:
        # Keep the typed, empty frame callers get when nothing is significant
        return pd.DataFrame({'gene_name': pd.Series(dtype='str'),
                             'pairs_nb': pd.Series(dtype='int'),
                             'pairs_list': pd.Series(dtype='str')})

    return pd.DataFrame({
        'gene_name': list(pairs_by_gene),
        'pairs_nb': [len(pairs) for pairs in pairs_by_gene.values()],
        'pairs_list': ['; '.join(pairs) for pairs in pairs_by_gene.values()],
    })


##################################################################
########### HEATMAPS ###########
##################################################################

def xscorer_nbgenes_heatmaps(listA, listB, labelsA, labelsB, crossDir, show, export, outname=None):
    """Draw two stacked heatmaps with the number of significant genes per trait pair.

    Top panel: coherent genes (red scale). Bottom panel: anti-coherent genes
    (blue scale). Columns follow listA, rows follow listB.

    Args:
        listA: GWAS identifiers shown on the x axis.
        listB: GWAS identifiers shown on the y axis.
        labelsA, labelsB: tick labels, applied by position to listA / listB.
        crossDir: directory prefix holding the xscorer result CSVs; the figure
            is written to its ``figures/`` sub-folder. Concatenated as-is, so
            it has to end with a path separator.
        show: display the figure if True, otherwise close it.
        export: save the figure as a 300 dpi JPG if True.
        outname: file name (without extension) of the exported figure;
            required when ``export`` is True.

    Raises:
        ValueError: if ``export`` is True and ``outname`` is None.
    """
    # Fail before reading any file rather than after all the work is done
    if export and outname is None:
        raise ValueError("'outname' is required when export=True")

    # One (cohe_nb, anti_nb) tuple per cell, rows = listB, columns = listA
    counts = [[get_xscorer_nbgenes(gwasA, gwasB, crossDir) for gwasA in listA]
              for gwasB in listB]
    cohe_table = pd.DataFrame([[cohe_nb for cohe_nb, _ in row] for row in counts],
                              index=listB, columns=listA)
    anti_table = pd.DataFrame([[anti_nb for _, anti_nb in row] for row in counts],
                              index=listB, columns=listA)

    # Heatmaps
    # NOTE: this is a global matplotlib setting and stays in effect after the call
    plt.rcParams['figure.constrained_layout.use'] = False
    # The x axis is shared between the two panels only for short tables
    fig, ax = plt.subplots(nrows=2, ncols=1,
                           figsize=(len(listA)*0.5, (len(listB)*0.4)*2),
                           sharex=len(listB) <= 10)
    sns.heatmap(cohe_table, annot=True, fmt='g', cmap='Reds', ax=ax[0], cbar_kws={'pad':0.02, 'label':'N coherent genes'})
    sns.heatmap(anti_table, annot=True, fmt='g', cmap='Blues', ax=ax[1], cbar_kws={'pad':0.02, 'label':'N anti-coherent genes'})
    for panel in ax:
        panel.set_xticklabels(labelsA, rotation = 45, ha='right', rotation_mode='anchor')
        panel.set_yticklabels(labelsB)
    plt.tight_layout()

    # Save while the figure is still registered with pyplot, then show / close
    if export:
        out_path = crossDir+'figures/'+outname+'.jpg'
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path, dpi=300, format='jpg', bbox_inches='tight', pad_inches=0.1)
    if show:
        plt.show()
    else:
        plt.close(fig)  # close this figure explicitly, not whatever is "current"


##################################################################
########### COMBINED PLOT ###########
##################################################################

def _xscorer_pair_count_tables(listA, listB, crossDir):
    """Build the annotation table and the colouring table of the pair heatmap.

    For every unordered pair the number of coherent genes is written at
    [gwasA, gwasB] and the number of anti-coherent genes at [gwasB, gwasA]
    (top-right / bottom-left triangles when listA == listB).

    Returns:
        tuple ``(results, cmap_results)``: ``results`` holds the counts to
        print, with zeros blanked out; ``cmap_results`` holds the numeric
        values to colour by (coherent positive, anti-coherent negative,
        diagonal zero).
    """
    results = pd.DataFrame(index=listA, columns=listB)
    cmap_results = pd.DataFrame(index=listA, columns=listB)
    gwasA_done = []  # traits already used as gwasA: avoids duplicate pairs when listA == listB
    for gwasA in listA:
        gwasA_done.append(gwasA)
        for gwasB in listB:
            if gwasA == gwasB:
                results.loc[gwasA, gwasB] = 0  # empty diagonal
                cmap_results.loc[gwasA, gwasB] = 0  # neutral colour
            if gwasB not in gwasA_done:
                cohe_nb, anti_nb = get_xscorer_nbgenes(gwasA, gwasB, crossDir)
                results.loc[gwasA, gwasB] = cohe_nb
                cmap_results.loc[gwasA, gwasB] = cohe_nb
                results.loc[gwasB, gwasA] = anti_nb
                cmap_results.loc[gwasB, gwasA] = -anti_nb  # negative -> other end of the diverging colormap
    results.replace(0, '', inplace=True)  # blank where zero (annotation only)
    cmap_results = cmap_results.apply(pd.to_numeric)
    return results, cmap_results


def _xscorer_top_gene_info(listA, listB, crossDir, top_nb):
    """Build the long-format (gene, trait) table behind the scatter panel.

    Keeps at most ``top_nb`` genes, those recorded in the most pairs. For
    each kept gene and each trait of listA, ``count`` is the number of the
    gene's recorded pairs that contain the trait and ``mean_logp`` is the
    mean -log10(p) of the gene over those pairs (0 when count is 0).
    """
    cross_genes = get_xscorer_gene_list(listA, listB, crossDir)
    # Slicing past the end is harmless, so no length check is needed
    top_genes = cross_genes.sort_values(by='pairs_nb', ascending=False, ignore_index=True)[0:top_nb]

    signif_by_pair = {}  # pair label -> (cohe_df, anti_df); each pair's CSVs are read once
    records = []
    # Most repeated gene first, i.e. at the bottom of the scatter y axis
    # (iterate top_genes in reverse to flip that order)
    for gene, pairs_list in zip(top_genes['gene_name'], top_genes['pairs_list']):
        pairs = pairs_list.split('; ')
        for trait in listA:
            count = 0  # nb of pairs of this gene that contain the trait
            logp_sum = 0  # sum of -log10(p) over those pairs
            for pair in pairs:
                members = pair.split('-')
                # Exact match on the pair members: a substring test on the
                # label would also match traits whose name is contained in
                # another trait's name.
                if trait not in members[:2]:
                    continue
                if pair not in signif_by_pair:
                    signif_by_pair[pair] = get_xscorer_signif_genes(members[0], members[1], crossDir)
                cohe_df, anti_df = signif_by_pair[pair]
                # A gene present in both tables is scored from the coherent one only
                if gene in cohe_df['gene_name'].values:
                    count += 1
                    logp_sum += -np.log10(cohe_df.loc[cohe_df['gene_name']==gene, 'pval'].values[0])
                elif gene in anti_df['gene_name'].values:
                    count += 1
                    logp_sum += -np.log10(anti_df.loc[anti_df['gene_name']==gene, 'pval'].values[0])
            mean_logp = logp_sum / count if count > 0 else 0
            records.append({'gene_name': gene, 'trait_name': trait, 'count': count, 'mean_logp': mean_logp})
    # Explicit columns keep the frame well-formed even when no gene was found
    return pd.DataFrame(records, columns=['gene_name', 'trait_name', 'count', 'mean_logp'])


def xscorer_retina_cohe_anticohe_heatmap_scatter(listA, listB, labelsA, crossDir, top_nb, cmap, show, export, outname=None):
    """Draw the combined retina-retina figure: pair heatmap plus gene/trait scatter.

    Left panel: heatmap of the number of significant genes per trait pair,
    coherent counts in one triangle and anti-coherent counts in the other,
    blank diagonal. Right panel: table-like scatterplot of (at most)
    ``top_nb`` genes against the traits of listA.

    Args:
        listA, listB: GWAS identifiers (the notebook passes the same list twice).
        labelsA: tick labels, applied by position to listA on both panels.
        crossDir: directory prefix holding the xscorer result CSVs; the figure
            is written to its ``figures/`` sub-folder. Concatenated as-is, so
            it has to end with a path separator.
        top_nb: maximum number of genes shown in the scatterplot.
        cmap: colormap of the scatterplot.
        show: display the figure if True, otherwise close it.
        export: save the figure as a 300 dpi JPG if True.
        outname: file name (without extension) of the exported figure;
            required when ``export`` is True.

    Raises:
        ValueError: if ``export`` is True and ``outname`` is None.
    """
    # Fail before reading any file rather than after all the work is done
    if export and outname is None:
        raise ValueError("'outname' is required when export=True")

    results, cmap_results = _xscorer_pair_count_tables(listA, listB, crossDir)
    gene_info = _xscorer_top_gene_info(listA, listB, crossDir, top_nb)

    # Figure (ax[0] = heatmap, ax[1] = scatterplot)
    # NOTE: this is a global matplotlib setting and stays in effect after the call
    plt.rcParams['figure.constrained_layout.use'] = True
    fig, ax = plt.subplots(1, 2, gridspec_kw={'width_ratios': [4, 3]}, figsize=(13,7))

    diag = np.eye(*results.shape, dtype=bool)  # diagonal will be blank
    sns.heatmap(cmap_results, annot=results, mask=diag, fmt='', center=0, cmap='seismic', cbar=False, ax=ax[0])
    ax[0].set_xticklabels(labelsA, rotation = 45, ha='right', rotation_mode='anchor')
    ax[0].set_yticklabels(labelsA)
    red_patch = mpatches.Patch(color='red', label='Coherent N genes')
    blue_patch = mpatches.Patch(color='blue', label='Anti-coherent N genes')
    ax[0].legend(bbox_to_anchor=(0, 1.2), loc='upper left', handles=[red_patch, blue_patch], fontsize=12)

    # NOTE(review): colour encodes 'mean_logp' and marker size encodes
    # 3 * 'count', but the colourbar below is labelled 'Pleiotropy (N trait
    # pairs)' and the size legend is titled 'Mean -log10(p)' (and lists the
    # scaled sizes). Either the encodings or the labels look swapped --
    # confirm which is intended before publishing.
    scatter = ax[1].scatter(gene_info['trait_name'], gene_info['gene_name'], c=gene_info['mean_logp'], s=3*gene_info['count'], cmap=cmap)
    cmap_label = 'Pleiotropy (N trait pairs)'
    # Address the scatter axes explicitly instead of relying on pyplot's current axes
    ax[1].set_xticks(listA)
    ax[1].set_xticklabels(labelsA, rotation=45, ha='right', rotation_mode='anchor')
    ax[1].grid(color='gray', linestyle='-', linewidth=0.1)
    cbax = ax[1].inset_axes([0, 1.05, 0.6, 0.04], transform=ax[1].transAxes)  # axis hosting the colourbar
    fig.colorbar(scatter, cax=cbax, shrink=0.6, orientation='horizontal').set_label(label=cmap_label, position=(0.5,-0.5), size=12)
    cbax.xaxis.set_label_position('top')
    cbax.xaxis.tick_top()
    kw = dict(prop='sizes', num=4, color=scatter.cmap(0.7))
    ax[1].legend(*scatter.legend_elements(**kw), loc='upper right', title='Mean -log10(p)', bbox_to_anchor=(1, 1.25), title_fontsize=12)

    # Save while the figure is still registered with pyplot, then show / close
    if export:
        out_path = crossDir+'figures/'+outname+'.jpg'
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path, dpi=300, format='jpg', bbox_inches='tight', pad_inches=0.1)
    if show:
        plt.show()
    else:
        plt.close(fig)  # close this figure explicitly, not whatever is "current"


In [3]:
# Put the retinal traits in the order given by the phenotypic correlation clustering
unordered_traits = 'mean_angle_taa,mean_angle_tva,tau1_vein,tau1_artery,ratio_AV_DF,eq_CRAE,ratio_CRAE_CRVE,D_A_std,D_V_std,eq_CRVE,ratio_VD,VD_orig_artery,bifurcations,VD_orig_vein,medianDiameter_artery,medianDiameter_vein,ratio_AV_medianDiameter'.split(',')

# The file stores the positions (into the list above) as comma-separated integers
with open('/NVME/decrypted/scratch/multitrait/UK_BIOBANK_PREPRINT/participant_phenotype/PC__2022_11_23_covar_fix__labels_order.csv', 'r') as order_file:
    for row in order_file:  # there's only one row (if there were several, the last one would win)
        order = list(map(int, row.split(',')))

# traits[k] is the trait found at position order[k] of the unordered list
traits = [unordered_traits[position] for position in order]

In [1]:
# FIGURES

# Disease / risk-factor GWAS and their display labels (same order in both lists)
diseases = ['4079_irnt', '4080_irnt', '102_irnt', '21021_irnt', '30760_irnt', '30780_irnt', '30870_irnt', '30750_irnt', '1558', '21001_irnt']
labelsB = ['DBP', 'SBP', 'PR', 'PWASI', 'HDL', 'LDL', 'Triglycerides', 'HbA1c', 'Alcohol', 'BMI']

# Display label of every retinal trait, keyed by trait name.
# `traits` is reordered in the previous cell, and the plotting functions apply
# the labels by position. A fixed label list would therefore end up on the
# wrong rows/columns whenever that reordering is not the identity, so the
# labels are looked up per trait instead.
retina_trait_labels = {
    'mean_angle_taa': 'A temporal angle',
    'mean_angle_tva': 'V temporal angle',
    'tau1_vein': 'V tortuosity',
    'tau1_artery': 'A tortuosity',
    'ratio_AV_DF': 'ratio tortuosity',
    'eq_CRAE': 'A central retinal eq',
    'ratio_CRAE_CRVE': 'ratio central retinal eq',
    'D_A_std': 'A std diameter',
    'D_V_std': 'V std diameter',
    'eq_CRVE': 'V central retinal eq',
    'ratio_VD': 'ratio vascular density',
    'VD_orig_artery': 'A vascular density',
    'bifurcations': 'bifurcations',
    'VD_orig_vein': 'V vascular density',
    'medianDiameter_artery': 'A median diameter',
    'medianDiameter_vein': 'V median diameter',
    'ratio_AV_medianDiameter': 'ratio median diameter',
}
labelsA = [retina_trait_labels[trait] for trait in traits]

# Input / output locations
crossDir_ret_ret = '/NVME/scratch/olga/output/PascalX/xscorer/retina-retina/2022-11-25/'
crossDir_ret_dis = '/NVME/scratch/olga/output/PascalX/xscorer/retina-disease/2023-02-13/'
geneDir = '/NVME/decrypted/scratch/multitrait/UK_BIOBANK_PREPRINT/gwas/2022_11_23_covar_fix/'  # not used in this cell
gene_sfx = '__gene_scores.p'  # not used in this cell

# Plot settings
top_nb = 30
cmap = 'Greys'
# NOTE(review): `show` and `export` are not passed to the calls below, which
# hard-code show=False / export=True -- confirm which values are intended.
show=True
export=True
outname_ret_ret = 'xscorer_retina_cohe_anticohe_scatter'
outname_ret_dis = 'xscorer_retina_disease_cohe_anticohe_heatmaps'

# Retina-retina heatmap + scatterplot (figure 3c-d)
xscorer_retina_cohe_anticohe_heatmap_scatter(traits, traits, labelsA, crossDir_ret_ret, top_nb, cmap, show=False, export=True, outname=outname_ret_ret)
# Retina-disease heatmaps (coherence, anti-coherence) (figure 6b-c)
xscorer_nbgenes_heatmaps(traits, diseases, labelsA, labelsB, crossDir_ret_dis, show=False, export=True, outname=outname_ret_dis)

