In [None]:
# default_exp plots_refactored

In [None]:
#hide
from nbdev.showdoc import *

# plots_refactored

> This module contains the plot handler classes that plot the data and annotate the plots with the results of the statistical analyses.

In [None]:
#export
from typing import Tuple, Dict, List, Optional, Union
from abc import ABC, abstractmethod

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

from dcl_stats_n_plots.database import Database

In [None]:
#export
class PlotHandler(ABC):
    """Base class of all plot handlers.

    ``plot`` is the single entry point: it reads the configs, data and stats
    results from a ``Database``, creates a styled figure, lets the subclass
    draw its plots and its stats annotations, applies the axis labels (and,
    if requested, manual y-limits), and finally stores the handler itself on
    ``database.created_plot``.

    Subclasses implement ``plot_options_displayed_in_gui``,
    ``add_handler_specific_plots`` and ``add_handler_specific_stats_annotations``.
    """

    @property
    @abstractmethod
    def plot_options_displayed_in_gui(self) -> List[str]:
        """Names of the plot types this handler offers (values of ``configs.plot_type``)."""


    @abstractmethod
    def add_handler_specific_plots(self) -> Tuple[plt.Figure, plt.Axes]:
        """Draw the handler-specific plot onto ``self.ax`` and return ``(fig, ax)``.

        The default implementation draws nothing and hands back the current figure and axes.
        """
        return self.fig, self.ax


    @abstractmethod
    def add_handler_specific_stats_annotations(self) -> Tuple[plt.Figure, plt.Axes]:
        """Annotate ``self.ax`` with the stats results and return ``(fig, ax)``.

        The default implementation annotates nothing and hands back the current figure and axes.
        """
        return self.fig, self.ax


    def plot(self, database: Database) -> Database:
        """Create the complete plot for ``database`` and attach this handler to it.

        Args:
            database: provides ``configs``, ``data`` and ``stats_results``.

        Returns:
            The same ``database`` object, with ``created_plot`` set to this handler.
        """
        self.database = database
        self.configs = database.configs
        # The handler works on its own copies of the data and the stats results
        # (``.copy()`` of each object held by the database).
        self.data = database.data.copy()
        self.stats_results = database.stats_results.copy()
        # Each step returns the (possibly new) figure and axes, which later steps read from self.
        pipeline = (
            self.initialize_plot,
            self.add_handler_specific_plots,
            self.add_handler_specific_stats_annotations,
            self.finish_plot,
        )
        for step in pipeline:
            self.fig, self.ax = step()
        database.created_plot = self
        return database


    def initialize_plot(self) -> Tuple[plt.Figure, plt.Axes]:
        """Create a white figure with a single axes, styled according to ``self.configs``.

        Figure width/height in the configs are converted from cm to inches.
        The top and right spines are hidden; bottom and left spines as well as
        the ticks get the configured line width / color / label size.
        """
        cfg = self.configs
        cm_per_inch = 2.54
        fig = plt.figure(figsize=(cfg.fig_width / cm_per_inch, cfg.fig_height / cm_per_inch), facecolor='white')
        ax = fig.add_subplot()
        for hidden_side in ('top', 'right'):
            ax.spines[hidden_side].set_visible(False)
        for visible_side in ('bottom', 'left'):
            spine = ax.spines[visible_side]
            spine.set_linewidth(cfg.axes_linewidth)
            spine.set_color(cfg.axes_color)
        ax.tick_params(labelsize=cfg.axes_tick_size, colors=cfg.axes_color)
        return fig, ax


    def finish_plot(self) -> Tuple[plt.Figure, plt.Axes]:
        """Set the axis labels and, in manual scaling mode, the y-axis limits."""
        cfg = self.configs
        fig, ax = self.fig, self.ax
        ax.set_ylabel(cfg.yaxis_label_text, fontsize=cfg.yaxis_label_fontsize, color=cfg.yaxis_label_color)
        ax.set_xlabel(cfg.xaxis_label_text, fontsize=cfg.xaxis_label_fontsize, color=cfg.xaxis_label_color)
        # NOTE(review): an earlier comment said the GUI encodes this mode as 1 and the API
        # as 'manual'; only the string 'manual' is checked here - confirm the GUI maps it.
        if cfg.yaxis_scaling_mode == 'manual':
            ax.set_ylim(cfg.yaxis_lower_lim_value, cfg.yaxis_upper_lim_value)
        return fig, ax

In [None]:
#export
class OneSamplePlots(PlotHandler):
    """Plot handler for one-sample designs (one group tested against a fixed value).

    Draws the data with the plot type selected in ``configs.plot_type``, adds a
    dashed gray reference line at the tested fixed value, and annotates the
    overall test result with its significance stars.
    """

    @property
    def plot_options_displayed_in_gui(self) -> List[str]:
        """Names of the plot types this handler can draw (values of ``configs.plot_type``)."""
        return ['stripplot', 'boxplot', 'boxplot with stripplot overlay', 'violinplot', 'violinplot with stripplot overlay']


    def add_handler_specific_plots(self) -> Tuple[plt.Figure, plt.Axes]:
        """Draw the selected plot type plus the fixed-value reference line.

        If ``configs.plot_type`` is not one of ``plot_options_displayed_in_gui``,
        nothing is drawn (not even the reference line).

        Returns:
            ``(fig, ax)`` - the handler's current figure and axes.
        """
        fig, ax = self.fig, self.ax
        df_infos = self.database.stats_results['df_infos']
        data_column_name = df_infos['data_column_name']
        group_column_name = df_infos['group_column_name']
        fixed_value = df_infos['fixed_value']
        plot_type = self.configs.plot_type
        # Arguments shared by every seaborn call below.
        common_kwargs = dict(data=self.data, x=group_column_name, y=data_column_name,
                             order=self.configs.l_xlabel_order, ax=ax)
        if plot_type == 'stripplot':
            sns.stripplot(palette=self.configs.color_palette, size=self.configs.marker_size, **common_kwargs)
        elif plot_type == 'boxplot':
            sns.boxplot(palette=self.configs.color_palette, **common_kwargs)
        elif plot_type == 'boxplot with stripplot overlay':
            # Outliers are hidden in the boxplot because every point is shown by the stripplot anyway.
            sns.boxplot(palette=self.configs.color_palette, showfliers=False, **common_kwargs)
            sns.stripplot(color='k', size=self.configs.marker_size, **common_kwargs)
        elif plot_type == 'violinplot':
            sns.violinplot(palette=self.configs.color_palette, cut=0, **common_kwargs)
        elif plot_type == 'violinplot with stripplot overlay':
            sns.violinplot(palette=self.configs.color_palette, cut=0, **common_kwargs)
            sns.stripplot(color='k', size=self.configs.marker_size, **common_kwargs)
        else:
            return fig, ax
        # Dashed reference line at the tested fixed value, spanning x = -0.5 .. 0.5,
        # i.e. the slot of the (single) group drawn at x = 0.
        ax.hlines(y=fixed_value, xmin=-0.5, xmax=0.5, color='gray', linestyle='dashed')
        return fig, ax


    def add_handler_specific_stats_annotations(self) -> Tuple[plt.Figure, plt.Axes]:
        """Write the significance stars of the one-sample test above the data.

        Only done when ``configs.l_stats_to_annotate`` is non-empty. The stars
        string is taken from ``stats_results['summary_stats']['stars_str']``.

        Returns:
            ``(fig, ax)`` - the handler's current figure and axes.
        """
        fig, ax = self.fig, self.ax
        if len(self.configs.l_stats_to_annotate) > 0:
            max_total = self.data[self.database.stats_results['df_infos']['data_column_name']].max()
            # Offsets scale with the magnitude of the maximum so the stars end up above
            # the data even when all values are negative (a negative offset would move
            # them down into the data). NOTE(review): a maximum of exactly 0 gives zero offsets.
            y_shift_annotation_line = abs(max_total) * self.configs.distance_brackets_to_data
            y_shift_annotation_text = y_shift_annotation_line * 0.5 * self.configs.distance_stars_to_brackets
            y = max_total + y_shift_annotation_line
            ax.text(0, y + y_shift_annotation_text, self.database.stats_results['summary_stats']['stars_str'],
                    ha='center', va='bottom', color='k', fontsize=self.configs.fontsize_stars,
                    fontweight=self.configs.fontweight_stars)
        return fig, ax

In [None]:
#export
class MultipleIndependentSamplesPlots(PlotHandler):
    """Plot handler for designs with multiple independent groups.

    Draws the groups with the plot type selected in ``configs.plot_type`` and,
    if the overall test is significant (p <= 0.05), annotates the requested
    pairwise comparisons with brackets and significance stars.
    """

    @property
    def plot_options_displayed_in_gui(self) -> List[str]:
        """Names of the plot types this handler can draw (values of ``configs.plot_type``)."""
        return ['stripplot', 'boxplot', 'boxplot with stripplot overlay', 'violinplot', 'violinplot with stripplot overlay']


    def add_handler_specific_plots(self) -> Tuple[plt.Figure, plt.Axes]:
        """Draw the selected plot type; unknown plot types draw nothing.

        Returns:
            ``(fig, ax)`` - the handler's current figure and axes.
        """
        fig, ax = self.fig, self.ax
        df_infos = self.database.stats_results['df_infos']
        plot_type = self.configs.plot_type
        # Arguments shared by every seaborn call below.
        common_kwargs = dict(data=self.data, x=df_infos['group_column_name'], y=df_infos['data_column_name'],
                             order=self.configs.l_xlabel_order, ax=ax)
        if plot_type == 'stripplot':
            sns.stripplot(palette=self.configs.color_palette, size=self.configs.marker_size, **common_kwargs)
        elif plot_type == 'boxplot':
            sns.boxplot(palette=self.configs.color_palette, **common_kwargs)
        elif plot_type == 'boxplot with stripplot overlay':
            # Outliers are hidden in the boxplot because every point is shown by the stripplot anyway.
            sns.boxplot(palette=self.configs.color_palette, showfliers=False, **common_kwargs)
            sns.stripplot(color='k', size=self.configs.marker_size, **common_kwargs)
        elif plot_type == 'violinplot':
            sns.violinplot(palette=self.configs.color_palette, cut=0, **common_kwargs)
        elif plot_type == 'violinplot with stripplot overlay':
            sns.violinplot(palette=self.configs.color_palette, cut=0, **common_kwargs)
            sns.stripplot(color='k', size=self.configs.marker_size, **common_kwargs)
        return fig, ax


    def add_handler_specific_stats_annotations(self) -> Tuple[plt.Figure, plt.Axes]:
        """Draw a bracket with significance stars for every pair in ``configs.l_stats_to_annotate``.

        Nothing is drawn unless the list is non-empty and the overall p-value in
        ``stats_results['summary_stats']['p_value']`` is <= 0.05. Successive
        brackets are stacked upwards above the data maximum.

        Returns:
            ``(fig, ax)`` - the handler's current figure and axes.
        """
        fig, ax = self.fig, self.ax
        if len(self.configs.l_stats_to_annotate) > 0:
            max_total = self.data[self.database.stats_results['df_infos']['data_column_name']].max()
            # Offsets scale with the magnitude of the maximum so brackets point upwards and
            # sit above the data even when all values are negative (a negative offset
            # would draw them downwards into the data). NOTE(review): a maximum of exactly 0
            # gives zero offsets.
            y_shift_annotation_line = abs(max_total) * self.configs.distance_brackets_to_data
            brackets_height = y_shift_annotation_line * 0.5 * self.configs.annotation_brackets_factor
            y_shift_annotation_text = brackets_height + y_shift_annotation_line * 0.5 * self.configs.distance_stars_to_brackets
            y = max_total + y_shift_annotation_line
            if self.database.stats_results['summary_stats']['p_value'] <= 0.05:
                df_temp = self.database.stats_results['pairwise_comparisons'].copy()
                for group1, group2 in self.configs.l_stats_to_annotate:
                    x1 = self.configs.l_xlabel_order.index(group1)
                    x2 = self.configs.l_xlabel_order.index(group2)
                    stars = self.get_stars_str(df_temp, group1, group2)
                    ax.plot([x1, x1, x2, x2], [y, y + brackets_height, y + brackets_height, y],
                            c='k', lw=self.configs.linewidth_annotations)
                    ax.text((x1 + x2) * .5, y + y_shift_annotation_text, stars, ha='center', va='bottom', color='k',
                            fontsize=self.configs.fontsize_stars, fontweight=self.configs.fontweight_stars)
                    # With set_distance_stars_to_brackets being limited to 5, stars will always be closer than next annotation line
                    y = y + 3 * y_shift_annotation_line
        return fig, ax


    def get_stars_str(self, df_tmp: pd.DataFrame, group1: str, group2: str) -> str:
        """Return the significance stars for the pairwise comparison of ``group1`` and ``group2``.

        The comparison is looked up in ``df_tmp`` (columns ``'A'``/``'B'`` hold the
        group names) in the order A=group1, B=group2 first, then the reverse order.
        The corrected p-value (``'p-corr'``) is used if that column exists,
        otherwise the uncorrected one (``'p-unc'``).

        Returns:
            '***' for p <= 0.001, '**' for p <= 0.01, '*' for p <= 0.05, else 'n.s.'.

        Raises:
            ValueError: if ``df_tmp`` contains no comparison between the two groups.
        """
        rows = df_tmp.loc[(df_tmp['A'] == group1) & (df_tmp['B'] == group2)]
        if rows.shape[0] == 0:
            rows = df_tmp.loc[(df_tmp['A'] == group2) & (df_tmp['B'] == group1)]
        if rows.shape[0] == 0:
            # Previously only a message was printed here, followed by a crash on the unbound p-value.
            raise ValueError(f'There was an error with annotating the stats! '
                             f'No pairwise comparison between "{group1}" and "{group2}" was found.')
        p_column = 'p-corr' if 'p-corr' in df_tmp.columns else 'p-unc'
        pval = rows[p_column].iloc[0]
        if pval <= 0.001:
            return '***'
        if pval <= 0.01:
            return '**'
        if pval <= 0.05:
            return '*'
        return 'n.s.'

SyntaxError: invalid syntax (4049804021.py, line 1)