In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from adjustText import adjust_text
# Use Helvetica as the font family for every figure created after this point.
# NOTE(review): presumably Helvetica must be installed on the machine for this
# to take effect — confirm on the rendering environment.
plt.rcParams['font.family'] = 'Helvetica'

In [None]:
# Load the PheWAS results table that every later cell plots from.
phewas_csv = 'results/PheWAS/phewas_data_for_plot.csv'
df = pd.read_csv(phewas_csv)

# Quick sanity check: available columns and the Domain values, listed in
# order of first appearance in the file.
print(df.columns)
domains_in_file_order = df['Domain'].unique().tolist()
print(domains_in_file_order)

# Rows are deliberately left in file order (no sort by Domain/Trait here).
# NOTE(review): later cells derive x positions and colours from row order, so
# the CSV is presumably already grouped by Domain — confirm.

In [None]:
# y value: -log10(P). The tiny offset keeps P == 0 from producing infinity.
df['neg_log_p_value'] = np.negative(np.log10(df['P'] + 1e-300))

# x value: each phenotype takes one slot on the x axis, except phenotypes with
# P < 5e-8, which take two slots so that their labels have more room.
is_significant = df['P'].lt(5e-8)
df['weight'] = np.where(is_significant, 2, 1)
df['phenotype_id'] = df['weight'].cumsum() - 1

# Mean x position of each Domain, usable as tick positions/labels.
category_centers = df.groupby('Domain')['phenotype_id'].mean()
category_labels = category_centers.index
category_ticks = category_centers.values
print(category_ticks.size)

In [None]:
# Assign one colour per Domain (in order of first appearance in `df`) and store
# it per row in `plot_color`.
#
# The hand-picked palette below is used whenever it has enough entries.
# Previously a tab10/hsv palette was computed first and then unconditionally
# overwritten by this list, and with more Domains than colours `zip` silently
# dropped the extra Domains (NaN `plot_color`, missing keys in the seaborn
# palette dict). Now the hsv palette is used only as a fallback in that case.
domain_hex_colors = ["#aa5063",
                     "#d24344",
                     "#da8a6c",
                     "#c56428",
                     "#cfa640",
                     "#8e7a39",
                     "#7db844",
                     "#53803b",
                     "#57b786",
                     "#4bafd0",
                     "#6975c9",
                     "#ac58c4",
                     "#c981bc",
                     "#d34688"]

domains = df['Domain'].unique()
num_categories = len(domains)

if num_categories <= len(domain_hex_colors):
    palette = domain_hex_colors
else:
    # Not enough hand-picked colours: generate one distinct hue per Domain.
    palette = sns.color_palette("hsv", num_categories)

# zip stops at the shorter input, so unused trailing colours are ignored.
category_colors = dict(zip(domains, palette))
df['plot_color'] = df['Domain'].map(category_colors)

In [None]:
# Preview plot: all phenotypes with -log10(P) of at most 50 on a single axis,
# with a dashed red guide line at y = 27 (the same value the next cell uses as
# `cut_thres` to break the y axis).
preview_points = df.loc[df['neg_log_p_value'] <= 50]
point_style = dict(s=60,
                   alpha=0.8,
                   edgecolor='black',
                   linewidth=0.)  # zero width: points are drawn without a border

fig, ax = plt.subplots(figsize=(18, 6))
sns.scatterplot(data=preview_points,
                x='phenotype_id',
                y='neg_log_p_value',
                hue='Domain',
                palette=category_colors,
                ax=ax,
                **point_style)
ax.axhline(27, color='red', ls='--', lw=1.5)


In [None]:
# Broken-axis PheWAS Manhattan plot.
#
# The y axis is split at `cut_thres`: the upper panel (ax1) shows points with
# -log10(P) above the threshold, the lower panel (ax2) shows the rest. Selected
# traits are labelled from an external CSV and the figure is saved as a PDF.
cut_thres = 27      # -log10(P) value at which the y axis is broken
upper_y_max = 320   # top of the upper panel
df_upper = df[df['neg_log_p_value'] > cut_thres]
df_lower = df[df['neg_log_p_value'] <= cut_thres]

# Two stacked panels of equal height sharing the x axis.
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(18, 8),
                               gridspec_kw={'height_ratios': [5, 5]})

# Scatter settings shared by both panels.
scatter_style = dict(x='phenotype_id',
                     y='neg_log_p_value',
                     hue='Domain',
                     palette=category_colors,
                     s=60,
                     alpha=0.8,
                     edgecolor='black',
                     linewidth=0.)  # zero width: no border around points

# --- Upper panel ---
# legend=False replaces the former `ax1.legend_.remove()`, which raised
# AttributeError when the upper panel had no points (no legend was created).
sns.scatterplot(data=df_upper, legend=False, ax=ax1, **scatter_style)
# hide the x axis of the upper panel (bottom spine, ticks and tick labels)
ax1.spines['bottom'].set_visible(False)
ax1.tick_params(axis='x', which='both', bottom=False, labelbottom=False)
# y label is attached to the upper panel but shifted down (y=0.1) towards the break
ax1.set_ylabel(r'$-\log_{10}(P)$', fontsize=18, ha='right', y=0.1)
# y range: just below cut_thres up to upper_y_max, one tick every 100 starting at cut_thres
ax1.set_ylim(cut_thres - 1, upper_y_max)
ax1.set_yticks(np.arange(cut_thres, upper_y_max + 1, 100))
# empty title (kept so the title font size / padding settings stay in place)
ax1.set_title("", fontsize=16, pad=20)

# --- Lower panel ---
sns.scatterplot(data=df_lower, ax=ax2, **scatter_style)
# no x tick labels and no axis labels on the lower panel
# ax2.set_xticks(category_ticks)
# ax2.set_xticklabels(category_labels, rotation=45, ha='right', fontsize=11)
ax2.set_xticklabels('')
ax2.set_xlabel('')
ax2.set_ylabel('')
# y range: 0 up to cut_thres, one tick every 10
ax2.set_ylim(0, cut_thres)
ax2.set_yticks(np.arange(0, cut_thres, 10))
# dashed red reference line at -log10(P) = 5
# NOTE(review): the label-spacing weight earlier in the notebook uses
# P < 5e-8 (-log10 ~ 7.3) — confirm that 5 is the intended threshold here.
ax2.axhline(y=5, color='red', linestyle='--', linewidth=1.5)

# --- Settings common to both panels ---
x_limits = (-10, df['phenotype_id'].max() + 10)
for panel in (ax1, ax2):
    panel.set_xlim(*x_limits)
    # no top/right spines, no grid
    panel.spines['top'].set_visible(False)
    panel.spines['right'].set_visible(False)
    panel.grid(False)

# Legend anchored relative to the lower panel so it sits beside the two panels.
ax2.legend(title="", bbox_to_anchor=(1, 1.3), fontsize=15, loc='center right',
           ncol=1, frameon=False, markerscale=1.2, labelspacing=0.5, handletextpad=0.3,
           )

# Thicker black axis lines for the spines that remain visible.
for spine in (ax1.spines['left'], ax2.spines['left'], ax2.spines['bottom']):
    spine.set_linewidth(2)
    spine.set_color('black')

for panel in (ax1, ax2):
    panel.tick_params(axis='y', which='both', left=True, labelleft=True, labelsize=15)

# --- Break markers: short diagonal strokes where the y axis is cut ---
kwargs = dict(marker=[(-1, -0.5), (1, 0.5)], markersize=12,
              linestyle="none", color='k', mec='k', mew=1, clip_on=False)
# bottom-left corner of the upper panel (axes coordinates)
ax1.plot([0], [0], transform=ax1.transAxes, **kwargs)
# top-left corner of the lower panel (axes coordinates)
ax2.plot([0], [1], transform=ax2.transAxes, **kwargs)
plt.subplots_adjust(hspace=0.06)

# --- Trait labels ---
# NOTE(review): this CSV carries its own phenotype_id / neg_log_p_value /
# plot_color columns; presumably it was exported from `df` — confirm it is in
# sync with the values computed in this notebook.
top_annot = pd.read_csv('results/PheWAS/top_to_annot.csv')
anot_x = top_annot['phenotype_id'].values
anot_y = top_annot['neg_log_p_value'].values
annots = top_annot['Trait'].values
colors = top_annot['plot_color'].values

# True for labels whose point lies in the upper panel.
in_upper = anot_y > cut_thres

texts_ax1, texts_ax2 = [], []
for x_pos, y_pos, txt, color, is_upper in zip(anot_x, anot_y, annots, colors, in_upper):
    target_ax, target_texts = (ax1, texts_ax1) if is_upper else (ax2, texts_ax2)
    # coloured text inside a white rounded box with a matching border
    target_texts.append(
        target_ax.text(x_pos, y_pos, txt, fontsize=15, color=color,
                       ha='center', va='center',
                       bbox=dict(facecolor='white', edgecolor=color, boxstyle='round,pad=0.2')))

# De-overlap the labels panel by panel. Each call receives only the anchor
# points that belong to its own panel (previously both calls were given the
# anchors of both panels, so labels were pushed away from points that are not
# drawn on that axis). A panel without labels is skipped.
if texts_ax1:
    adjust_text(texts=texts_ax1, x=anot_x[in_upper], y=anot_y[in_upper], ax=ax1,
                arrowprops=dict(arrowstyle='-', color='lightgray', lw=1.5),
                force_text=1.15, force_points=1.1, expand_text=(1.4, 1.4), autoalign='y',)
if texts_ax2:
    adjust_text(texts=texts_ax2, x=anot_x[~in_upper], y=anot_y[~in_upper], ax=ax2,
                arrowprops=dict(arrowstyle='-', color='lightgray', lw=1.5),
                force_text=0.5, force_points=1.1, expand_text=(1.5, 1.5), autoalign='x',
                va='bottom', ha='left')

# save the figure
plt.savefig('results/PheWAS/phewas_manhattan_plot.pdf', dpi=300, bbox_inches='tight')