# NOTE(review): no click group ``.command()`` decorator is visible on this
# function in the reviewed change; without one click will not expose it as a
# subcommand -- confirm against the other commands in this module.
@click.argument('ewas_result', type=EWAS_RESULT)
@arg_output
# The choice values must match the column names accepted by
# ``plot.top_results`` exactly ('pvalue_fdr' is lowercase there); the previous
# 'pvalue_FDR' spelling was always rejected by the plotting function.
@click.option('--pvalue_name', type=click.Choice(['pvalue', 'pvalue_bonferroni', 'pvalue_fdr']), default='pvalue',
              help='Which pvalues to use in the plot')
@click.option('--cutoff', type=click.FLOAT, default=0.05, help="cutoff value for plotting the significance line")
@click.option('--num_rows', type=click.INT, default=20, help="How many EWAS result rows to plot")
def top_results(ewas_result, output, pvalue_name, cutoff, num_rows):
    """Plot the top EWAS results and save the figure to OUTPUT.

    Parameters
    ----------
    ewas_result:
        EWAS result parsed by the ``EWAS_RESULT`` click parameter type.
    output:
        Destination passed to ``plot.top_results`` as ``filename``.
    pvalue_name: str
        One of 'pvalue', 'pvalue_bonferroni', or 'pvalue_fdr'.
    cutoff: float
        Significance cutoff forwarded to the plot.
    num_rows: int
        Number of EWAS result rows forwarded to the plot.
    """
    plot.top_results(ewas_result, pvalue_name=pvalue_name, cutoff=cutoff, num_rows=num_rows, filename=output)
    # Log
    click.echo(click.style(f"Done: Saved plot to {output}", fg='green'))
def top_results(
    ewas_result: pd.DataFrame,
    pvalue_name: str = "pvalue",
    cutoff: float = 0.05,
    num_rows: int = 20,
    filename: Optional[str] = None,
) -> None:
    """
    Create a dotplot for EWAS Results showing pvalues and beta coefficients

    Parameters
    ----------
    ewas_result: DataFrame
        EWAS Result to plot
    pvalue_name: str
        'pvalue', 'pvalue_fdr', or 'pvalue_bonferroni'
    cutoff: float (default 0.05)
        A vertical line is drawn in the pvalue column to show a significance cutoff
    num_rows: int (default 20)
        How many rows to show in the plot
    filename: Optional str
        If provided, a copy of the plot will be saved to the specified file

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `pvalue_name` is not one of the accepted names, if a corrected pvalue
        column is requested but missing from `ewas_result`, or if there are no
        rows to plot.

    Examples
    --------
    >>> clarite.plot.top_results(ewas_result)
    """
    # Ensure the requested pvalue column is a known one and, for corrected
    # pvalues, that it has actually been added to the result
    if pvalue_name in ("pvalue_fdr", "pvalue_bonferroni"):
        if pvalue_name not in ewas_result.columns:
            raise ValueError("Missing corrected pvalues in ewas result. Run clarite.analyze.add_corrected_pvalues")
    elif pvalue_name != "pvalue":
        raise ValueError(
            "Incorrect value specified for 'pvalue_name': must be one of 'pvalue', 'pvalue_fdr',"
            " or 'pvalue_bonferroni'."
        )

    # Sort and filter data
    df = ewas_result.sort_values(pvalue_name, ascending=True).head(num_rows).reset_index()
    if df.empty:
        # Fail early with a clear message instead of an obscure plotting error
        raise ValueError("No EWAS results to plot: 'ewas_result' is empty or 'num_rows' selected no rows.")

    # Build the row labels with vectorised string concatenation; casting to str
    # keeps this working when a label is not already a string.
    df["Variable, Phenotype"] = df["Variable"].astype(str) + ", " + df["Phenotype"].astype(str)
    df["Significant"] = df[pvalue_name] <= cutoff

    # Plot
    sns.set(style="whitegrid")
    g = sns.PairGrid(
        df,
        x_vars=[pvalue_name, "Beta"],
        y_vars=["Variable, Phenotype"],
        hue="Significant",
        height=10,
        aspect=0.50,
        layout_pad=1.3,
    )
    pvalue_ax, beta_ax = g.axes.flat[0], g.axes.flat[1]

    # Draw vertical lines before plotting points
    pvalue_ax.axvline(x=cutoff, ls="-", color="black")  # Significance cutoff
    beta_ax.axvline(x=0, ls="-", color="black")  # 0 Beta

    # Plot points
    g.map(
        sns.stripplot,
        size=10,
        orient="h",
        palette="ch:s=1,r=-0.1,h=1_r",
        linewidth=1,
        edgecolor="w",
    )

    # Format
    for ax in g.axes.flat:
        ax.xaxis.grid(False)
        ax.yaxis.grid(True)
    sns.despine(left=True, bottom=True)

    # Update Axes
    # pvalue
    pvalue_ax.set_xscale("log")
    # A log axis cannot start at 0, so base the lower limit on the smallest
    # strictly positive pvalue (identical to the overall minimum when no
    # pvalue is exactly 0).
    positive_pvalues = df.loc[df[pvalue_name] > 0, pvalue_name]
    if not positive_pvalues.empty:
        pvalue_ax.set_xlim(0.1 * positive_pvalues.min(), 100)
    pvalue_ax.set_xlabel(f"{pvalue_name} (cutoff = {cutoff:.3f})")
    # Beta
    max_beta = df["Beta"].abs().max()
    # Only set symmetric limits when they are usable: NaN (no Beta values among
    # the selected rows) or 0 would give invalid or degenerate axis limits.
    if pd.notna(max_beta) and max_beta > 0:
        beta_ax.set_xlim(-1.10 * max_beta, 1.1 * max_beta)  # max value +/- 10%

    # Save
    if filename is not None:
        plt.savefig(filename, bbox_inches="tight")
    plt.show()