diff --git a/dabest/_dabest_object.py b/dabest/_dabest_object.py index c29a2d8a..70145656 100644 --- a/dabest/_dabest_object.py +++ b/dabest/_dabest_object.py @@ -7,6 +7,7 @@ # %% ../nbs/API/dabest_object.ipynb 5 # Import standard data science libraries +import warnings from numpy import array, repeat, random, issubdtype, number import numpy as np import pandas as pd @@ -62,7 +63,6 @@ def __init__( # Check if there is NaN under any of the paired settings if self.__is_paired and self.__output_data.isnull().values.any(): - import warnings warn1 = f"NaN values detected under paired setting and removed," warn2 = f" please check your data." warnings.warn(warn1 + warn2) @@ -500,10 +500,10 @@ def _check_errors(self, x, y, idx, experiment, experiment_label, x1_level): if x is None: error_msg = "If `delta2` is True. `x` parameter cannot be None. String or list expected" raise ValueError(error_msg) - + if self.__proportional: - err0 = "`proportional` and `delta2` cannot be True at the same time." - raise ValueError(err0) + mes1 = "Only mean_diff is supported for proportional data when `delta2` is True" + warnings.warn(message=mes1, category=UserWarning) # idx should not be specified if idx: @@ -581,8 +581,6 @@ def _get_plot_data(self, x, y, all_plot_groups): """ # Check if there is NaN under any of the paired settings if self.__is_paired is not None and self.__output_data.isnull().values.any(): - print("Nan") - import warnings warn1 = f"NaN values detected under paired setting and removed," warn2 = f" please check your data." warnings.warn(warn1 + warn2) @@ -634,7 +632,6 @@ def _get_plot_data(self, x, y, all_plot_groups): # Check if there is NaN under any of the paired settings if self.__is_paired is not None and self.__output_data.isnull().values.any(): - import warnings warn1 = f"NaN values detected under paired setting and removed," warn2 = f" please check your data." warnings.warn(warn1 + warn2) diff --git a/dabest/_effsize_objects.py b/dabest/_effsize_objects.py index ab638692..cd4aa53f 100644 --- a/dabest/_effsize_objects.py +++ b/dabest/_effsize_objects.py @@ -257,7 +257,8 @@ def _check_errors(self, control, test): raise ValueError(err1) if self.__proportional and self.__effect_size not in ["mean_diff", "cohens_h"]: - err1 = "`proportional` is True; therefore effect size other than mean_diff and cohens_h is not defined." + err1 = "`proportional` is True; therefore effect size other than mean_diff and cohens_h is not defined." + \ + "If you are calculating deltas' g, it's the same as delta-delta when `proportional` is True" raise ValueError(err1) if self.__proportional and ( @@ -884,6 +885,7 @@ def __pre_calc(self): self.__is_paired, self.__resamples, self.__random_seed, + self.__proportional, ) for j, current_tuple in enumerate(idx): diff --git a/dabest/_stats_tools/confint_2group_diff.py b/dabest/_stats_tools/confint_2group_diff.py index afdb44b2..59b53894 100644 --- a/dabest/_stats_tools/confint_2group_diff.py +++ b/dabest/_stats_tools/confint_2group_diff.py @@ -159,24 +159,26 @@ def compute_bootstrapped_diff( return out -@njit(cache=True) # parallelization must be turned off for random number generation -def delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sd, rng_seed, is_paired): + +@njit(cache=True) +def delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sd, rng_seed, is_paired, proportional=False): + """ + Compute bootstrapped differences for delta-delta, handling both regular and proportional data + """ np.random.seed(rng_seed) - out_delta_g = np.empty(resamples) deltadelta = np.empty(resamples) + out_delta_g = np.empty(resamples) n1, n2, n3, n4 = len(x1), len(x2), len(x3), len(x4) - if is_paired: - if n1 != n2 or n3 != n4: - raise ValueError("Each control group must have the same length as its corresponding test group in paired analysis.") - + if is_paired and (n1 != n2 or n3 != n4): + raise ValueError("Each control group must have the same length as its corresponding test group in paired analysis.") # Bootstrapping for i in range(resamples): # Paired or unpaired resampling if is_paired: - indices_1 = np.random.choice(len(x1),len(x1)) - indices_2 = np.random.choice(len(x3),len(x3)) + indices_1 = np.random.choice(len(x1), len(x1)) + indices_2 = np.random.choice(len(x3), len(x3)) x1_sample, x2_sample = x1[indices_1], x2[indices_1] x3_sample, x4_sample = x3[indices_2], x4[indices_2] else: @@ -187,13 +189,14 @@ def delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sd, rng_seed, is_pai x1_sample, x2_sample = x1[indices_1], x2[indices_2] x3_sample, x4_sample = x3[indices_3], x4[indices_4] - # Calculating deltas + # Calculate deltas delta_1 = np.mean(x2_sample) - np.mean(x1_sample) delta_2 = np.mean(x4_sample) - np.mean(x3_sample) delta_delta = delta_2 - delta_1 - + deltadelta[i] = delta_delta - out_delta_g[i] = delta_delta / pooled_sd + + out_delta_g[i] = delta_delta if proportional else delta_delta/pooled_sd return out_delta_g, deltadelta @@ -204,39 +207,42 @@ def compute_delta2_bootstrapped_diff( x3: np.ndarray, # Control group 2 x4: np.ndarray, # Test group 2 is_paired: str = None, - resamples: int = 5000, # The number of bootstrap resamples to be taken for the calculation of the confidence interval limits. - random_seed: int = 12345, # `random_seed` is used to seed the random number generator during bootstrap resampling. This ensures that the confidence intervals reported are replicable. -) -> ( - tuple -): # bootstraped result and empirical result of deltas' g, and the bootstraped result of delta-delta + resamples: int = 5000, + random_seed: int = 12345, + proportional: bool = False +) -> tuple: """ - Bootstraps the effect size deltas' g. - + Bootstraps the effect size deltas' g or proportional delta-delta """ - x1, x2, x3, x4 = map(np.asarray, [x1, x2, x3, x4]) - - # Calculating pooled sample standard deviation - stds = [np.std(x) for x in [x1, x2, x3, x4]] - ns = [len(x) for x in [x1, x2, x3, x4]] - - sd_numerator = sum((n - 1) * s**2 for n, s in zip(ns, stds)) - sd_denominator = sum(n - 1 for n in ns) - - # Avoid division by zero - if sd_denominator == 0: - raise ValueError("Insufficient data to compute pooled standard deviation.") - - pooled_sample_sd = np.sqrt(sd_numerator / sd_denominator) - - # Ensure pooled_sample_sd is not NaN or zero (to avoid division by zero later) - if np.isnan(pooled_sample_sd) or pooled_sample_sd == 0: - raise ValueError("Pooled sample standard deviation is NaN or zero.") - - out_delta_g, deltadelta = delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sample_sd, random_seed, is_paired) - - # Empirical delta_g calculation - delta_g = ((np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1))) / pooled_sample_sd + + if proportional: + # For proportional data, pass 1.0 as dummy pooled_sd (won't be used) + out_delta_g, deltadelta = delta2_bootstrap_loop( + x1, x2, x3, x4, resamples, 1.0, random_seed, is_paired, proportional=True + ) + # For proportional data, delta_g is the empirical delta-delta + delta_g = ((np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1))) + else: + # Calculate pooled sample standard deviation for non-proportional data + stds = [np.std(x) for x in [x1, x2, x3, x4]] + ns = [len(x) for x in [x1, x2, x3, x4]] + + sd_numerator = sum((n - 1) * s**2 for n, s in zip(ns, stds)) + sd_denominator = sum(n - 1 for n in ns) + + if sd_denominator == 0: + raise ValueError("Insufficient data to compute pooled standard deviation.") + + pooled_sample_sd = np.sqrt(sd_numerator / sd_denominator) + + if np.isnan(pooled_sample_sd) or pooled_sample_sd == 0: + raise ValueError("Pooled sample standard deviation is NaN or zero.") + + out_delta_g, deltadelta = delta2_bootstrap_loop( + x1, x2, x3, x4, resamples, pooled_sample_sd, random_seed, is_paired, proportional=False + ) + delta_g = ((np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1))) / pooled_sample_sd return out_delta_g, delta_g, deltadelta diff --git a/dabest/misc_tools.py b/dabest/misc_tools.py index 558f7674..0ede7c42 100644 --- a/dabest/misc_tools.py +++ b/dabest/misc_tools.py @@ -590,12 +590,12 @@ def get_color_palette( if color_by_subgroups: plot_palette_raw = dict() plot_palette_contrast = dict() - # plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots - plot_palette_bar = None + plot_palette_bar = dict() for i in range(len(idx)): for names_i in idx[i]: plot_palette_raw[names_i] = swarm_colors[i] plot_palette_contrast[names_i] = contrast_colors[i] + plot_palette_bar[names_i] = bar_color[i] else: plot_palette_raw = dict(zip(categories, swarm_colors)) plot_palette_contrast = dict(zip(categories, contrast_colors)) @@ -612,11 +612,12 @@ def get_color_palette( if color_by_subgroups: plot_palette_raw = dict() plot_palette_contrast = dict() - plot_palette_bar = None # plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots + plot_palette_bar = dict() for i in range(len(idx)): for names_i in idx[i]: plot_palette_raw[names_i] = swarm_colors[i] plot_palette_contrast[names_i] = contrast_colors[i] + plot_palette_bar[names_i] = bar_color[i] else: plot_palette_raw = dict(zip(names, swarm_colors)) plot_palette_contrast = dict(zip(names, contrast_colors)) @@ -1018,6 +1019,7 @@ def lookup_value(text): ticks_with_counts.append(f"{t}\n(N={value})") fontsize_rawxlabel = plot_kwargs.get("fontsize_rawxlabel") + set_major_loc_method(plt.FixedLocator(get_ticks())) set_label(ticks_with_counts, fontsize=fontsize_rawxlabel) # Ensure ticks are at the correct locations diff --git a/dabest/plot_tools.py b/dabest/plot_tools.py index 68d83a30..2448241b 100644 --- a/dabest/plot_tools.py +++ b/dabest/plot_tools.py @@ -731,14 +731,17 @@ def sankeydiag( right_idx = [] # Design for Sankey Flow Diagram sankey_idx = ( - [ - (control, test) - for i in idx - for control, test in zip(i[:], (i[1:] + (i[0],))) - ] - if flow - else temp_idx - ) + [ + (control, test) + for i in idx + for control, test in zip( + i[:], + (tuple(i[1:]) + (i[0],)) if isinstance(i, tuple) else (list(i[1:]) + [i[0]]) + ) + ] + if flow + else temp_idx +) for i in sankey_idx: left_idx.append(i[0]) right_idx.append(i[1]) @@ -2065,6 +2068,7 @@ def barplotter( plot_data: pd.DataFrame, bar_color: str, plot_palette_bar: dict, + color_col: str, plot_kwargs: dict, barplot_kwargs: dict, horizontal: bool @@ -2088,6 +2092,8 @@ def barplotter( Color of the bar. plot_palette_bar : dict Dictionary of colors used in the bar plot. + color_col : str + Column name of the color column. plot_kwargs : dict Keyword arguments for the plot. barplot_kwargs : dict @@ -2102,7 +2108,26 @@ def barplotter( else: x_var, y_var, orient = all_plot_groups, np.ones(len(all_plot_groups)), "v" - bar1_df = pd.DataFrame({xvar: x_var, "proportion": y_var}) + # Create bar1_df with basic columns + bar1_df = pd.DataFrame({ + xvar: x_var, + "proportion": y_var + }) + + # Handle colors + if color_col: + # Get first color value for each group + color_mapping = plot_data.groupby(xvar, observed=False)[color_col].first() + bar1_df[color_col] = [color_mapping.get(group) for group in all_plot_groups] + + # Map colors, defaulting to bar_color if no match + edge_colors = [ + plot_palette_bar.get(hue_val, bar_color) + for hue_val in bar1_df[color_col] + ] + else: + edge_colors = bar_color + bar1 = sns.barplot( data=bar1_df, @@ -2112,7 +2137,7 @@ def barplotter( order=all_plot_groups, linewidth=2, facecolor=(1, 1, 1, 0), - edgecolor=bar_color, + edgecolor=edge_colors, zorder=1, orient=orient, ) @@ -2123,6 +2148,8 @@ def barplotter( ax=rawdata_axes, order=all_plot_groups, palette=plot_palette_bar, + hue=color_col, + dodge=False, zorder=1, orient=orient, **barplot_kwargs diff --git a/dabest/plotter.py b/dabest/plotter.py index 57e41615..aa825b2d 100644 --- a/dabest/plotter.py +++ b/dabest/plotter.py @@ -277,6 +277,7 @@ def effectsize_df_plotter(effectsize_df: object, **plot_kwargs) -> matplotlib.fi plot_data = plot_data, bar_color = bar_color, plot_palette_bar = plot_palette_bar, + color_col = color_col, plot_kwargs = plot_kwargs, barplot_kwargs = barplot_kwargs, horizontal = horizontal, diff --git a/nbs/API/confint_2group_diff.ipynb b/nbs/API/confint_2group_diff.ipynb index c080d452..c5dc22da 100644 --- a/nbs/API/confint_2group_diff.ipynb +++ b/nbs/API/confint_2group_diff.ipynb @@ -213,24 +213,26 @@ "\n", " return out\n", "\n", - "@njit(cache=True) # parallelization must be turned off for random number generation\n", - "def delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sd, rng_seed, is_paired):\n", + "\n", + "@njit(cache=True)\n", + "def delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sd, rng_seed, is_paired, proportional=False):\n", + " \"\"\"\n", + " Compute bootstrapped differences for delta-delta, handling both regular and proportional data\n", + " \"\"\"\n", " np.random.seed(rng_seed)\n", - " out_delta_g = np.empty(resamples)\n", " deltadelta = np.empty(resamples)\n", + " out_delta_g = np.empty(resamples)\n", " \n", " n1, n2, n3, n4 = len(x1), len(x2), len(x3), len(x4)\n", - " if is_paired:\n", - " if n1 != n2 or n3 != n4:\n", - " raise ValueError(\"Each control group must have the same length as its corresponding test group in paired analysis.\")\n", - " \n", + " if is_paired and (n1 != n2 or n3 != n4):\n", + " raise ValueError(\"Each control group must have the same length as its corresponding test group in paired analysis.\")\n", "\n", " # Bootstrapping\n", " for i in range(resamples):\n", " # Paired or unpaired resampling\n", " if is_paired:\n", - " indices_1 = np.random.choice(len(x1),len(x1))\n", - " indices_2 = np.random.choice(len(x3),len(x3))\n", + " indices_1 = np.random.choice(len(x1), len(x1))\n", + " indices_2 = np.random.choice(len(x3), len(x3))\n", " x1_sample, x2_sample = x1[indices_1], x2[indices_1]\n", " x3_sample, x4_sample = x3[indices_2], x4[indices_2]\n", " else:\n", @@ -241,13 +243,14 @@ " x1_sample, x2_sample = x1[indices_1], x2[indices_2]\n", " x3_sample, x4_sample = x3[indices_3], x4[indices_4]\n", "\n", - " # Calculating deltas\n", + " # Calculate deltas\n", " delta_1 = np.mean(x2_sample) - np.mean(x1_sample)\n", " delta_2 = np.mean(x4_sample) - np.mean(x3_sample)\n", " delta_delta = delta_2 - delta_1\n", - "\n", + " \n", " deltadelta[i] = delta_delta\n", - " out_delta_g[i] = delta_delta / pooled_sd\n", + "\n", + " out_delta_g[i] = delta_delta if proportional else delta_delta/pooled_sd\n", "\n", " return out_delta_g, deltadelta\n", "\n", @@ -258,39 +261,42 @@ " x3: np.ndarray, # Control group 2\n", " x4: np.ndarray, # Test group 2\n", " is_paired: str = None,\n", - " resamples: int = 5000, # The number of bootstrap resamples to be taken for the calculation of the confidence interval limits.\n", - " random_seed: int = 12345, # `random_seed` is used to seed the random number generator during bootstrap resampling. This ensures that the confidence intervals reported are replicable.\n", - ") -> (\n", - " tuple\n", - "): # bootstraped result and empirical result of deltas' g, and the bootstraped result of delta-delta\n", + " resamples: int = 5000,\n", + " random_seed: int = 12345,\n", + " proportional: bool = False\n", + ") -> tuple:\n", " \"\"\"\n", - " Bootstraps the effect size deltas' g.\n", - "\n", + " Bootstraps the effect size deltas' g or proportional delta-delta\n", " \"\"\"\n", - "\n", " x1, x2, x3, x4 = map(np.asarray, [x1, x2, x3, x4])\n", - "\n", - " # Calculating pooled sample standard deviation\n", - " stds = [np.std(x) for x in [x1, x2, x3, x4]]\n", - " ns = [len(x) for x in [x1, x2, x3, x4]]\n", - "\n", - " sd_numerator = sum((n - 1) * s**2 for n, s in zip(ns, stds))\n", - " sd_denominator = sum(n - 1 for n in ns)\n", - "\n", - " # Avoid division by zero\n", - " if sd_denominator == 0:\n", - " raise ValueError(\"Insufficient data to compute pooled standard deviation.\")\n", - "\n", - " pooled_sample_sd = np.sqrt(sd_numerator / sd_denominator)\n", - "\n", - " # Ensure pooled_sample_sd is not NaN or zero (to avoid division by zero later)\n", - " if np.isnan(pooled_sample_sd) or pooled_sample_sd == 0:\n", - " raise ValueError(\"Pooled sample standard deviation is NaN or zero.\")\n", - "\n", - " out_delta_g, deltadelta = delta2_bootstrap_loop(x1, x2, x3, x4, resamples, pooled_sample_sd, random_seed, is_paired)\n", - "\n", - " # Empirical delta_g calculation\n", - " delta_g = ((np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1))) / pooled_sample_sd\n", + " \n", + " if proportional:\n", + " # For proportional data, pass 1.0 as dummy pooled_sd (won't be used)\n", + " out_delta_g, deltadelta = delta2_bootstrap_loop(\n", + " x1, x2, x3, x4, resamples, 1.0, random_seed, is_paired, proportional=True\n", + " )\n", + " # For proportional data, delta_g is the empirical delta-delta\n", + " delta_g = ((np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1)))\n", + " else:\n", + " # Calculate pooled sample standard deviation for non-proportional data\n", + " stds = [np.std(x) for x in [x1, x2, x3, x4]]\n", + " ns = [len(x) for x in [x1, x2, x3, x4]]\n", + " \n", + " sd_numerator = sum((n - 1) * s**2 for n, s in zip(ns, stds))\n", + " sd_denominator = sum(n - 1 for n in ns)\n", + " \n", + " if sd_denominator == 0:\n", + " raise ValueError(\"Insufficient data to compute pooled standard deviation.\")\n", + " \n", + " pooled_sample_sd = np.sqrt(sd_numerator / sd_denominator)\n", + " \n", + " if np.isnan(pooled_sample_sd) or pooled_sample_sd == 0:\n", + " raise ValueError(\"Pooled sample standard deviation is NaN or zero.\")\n", + " \n", + " out_delta_g, deltadelta = delta2_bootstrap_loop(\n", + " x1, x2, x3, x4, resamples, pooled_sample_sd, random_seed, is_paired, proportional=False\n", + " )\n", + " delta_g = ((np.mean(x4) - np.mean(x3)) - (np.mean(x2) - np.mean(x1))) / pooled_sample_sd\n", "\n", " return out_delta_g, delta_g, deltadelta\n", "\n", diff --git a/nbs/API/dabest_object.ipynb b/nbs/API/dabest_object.ipynb index dee0aef6..011c056f 100644 --- a/nbs/API/dabest_object.ipynb +++ b/nbs/API/dabest_object.ipynb @@ -64,6 +64,7 @@ "source": [ "#| export\n", "# Import standard data science libraries\n", + "import warnings\n", "from numpy import array, repeat, random, issubdtype, number\n", "import numpy as np\n", "import pandas as pd\n", @@ -138,7 +139,6 @@ "\n", " # Check if there is NaN under any of the paired settings\n", " if self.__is_paired and self.__output_data.isnull().values.any():\n", - " import warnings\n", " warn1 = f\"NaN values detected under paired setting and removed,\"\n", " warn2 = f\" please check your data.\"\n", " warnings.warn(warn1 + warn2)\n", @@ -576,10 +576,10 @@ " if x is None:\n", " error_msg = \"If `delta2` is True. `x` parameter cannot be None. String or list expected\"\n", " raise ValueError(error_msg)\n", - " \n", + " \n", " if self.__proportional:\n", - " err0 = \"`proportional` and `delta2` cannot be True at the same time.\"\n", - " raise ValueError(err0)\n", + " mes1 = \"Only mean_diff is supported for proportional data when `delta2` is True\"\n", + " warnings.warn(message=mes1, category=UserWarning)\n", "\n", " # idx should not be specified\n", " if idx:\n", @@ -657,8 +657,6 @@ " \"\"\"\n", " # Check if there is NaN under any of the paired settings\n", " if self.__is_paired is not None and self.__output_data.isnull().values.any():\n", - " print(\"Nan\")\n", - " import warnings\n", " warn1 = f\"NaN values detected under paired setting and removed,\"\n", " warn2 = f\" please check your data.\"\n", " warnings.warn(warn1 + warn2)\n", @@ -710,7 +708,6 @@ "\n", " # Check if there is NaN under any of the paired settings\n", " if self.__is_paired is not None and self.__output_data.isnull().values.any():\n", - " import warnings\n", " warn1 = f\"NaN values detected under paired setting and removed,\"\n", " warn2 = f\" please check your data.\"\n", " warnings.warn(warn1 + warn2)\n", diff --git a/nbs/API/effsize_objects.ipynb b/nbs/API/effsize_objects.ipynb index eec0d119..43bb3a00 100644 --- a/nbs/API/effsize_objects.ipynb +++ b/nbs/API/effsize_objects.ipynb @@ -317,7 +317,8 @@ " raise ValueError(err1)\n", "\n", " if self.__proportional and self.__effect_size not in [\"mean_diff\", \"cohens_h\"]:\n", - " err1 = \"`proportional` is True; therefore effect size other than mean_diff and cohens_h is not defined.\"\n", + " err1 = \"`proportional` is True; therefore effect size other than mean_diff and cohens_h is not defined.\" + \\\n", + " \"If you are calculating deltas' g, it's the same as delta-delta when `proportional` is True\"\n", " raise ValueError(err1)\n", "\n", " if self.__proportional and (\n", @@ -1043,6 +1044,7 @@ " self.__is_paired,\n", " self.__resamples,\n", " self.__random_seed,\n", + " self.__proportional,\n", " )\n", "\n", " for j, current_tuple in enumerate(idx):\n", diff --git a/nbs/API/misc_tools.ipynb b/nbs/API/misc_tools.ipynb index 24a1e54b..a1290bff 100644 --- a/nbs/API/misc_tools.ipynb +++ b/nbs/API/misc_tools.ipynb @@ -643,12 +643,12 @@ " if color_by_subgroups:\n", " plot_palette_raw = dict()\n", " plot_palette_contrast = dict()\n", - " # plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots\n", - " plot_palette_bar = None\n", + " plot_palette_bar = dict()\n", " for i in range(len(idx)):\n", " for names_i in idx[i]:\n", " plot_palette_raw[names_i] = swarm_colors[i]\n", " plot_palette_contrast[names_i] = contrast_colors[i]\n", + " plot_palette_bar[names_i] = bar_color[i]\n", " else:\n", " plot_palette_raw = dict(zip(categories, swarm_colors))\n", " plot_palette_contrast = dict(zip(categories, contrast_colors))\n", @@ -665,11 +665,12 @@ " if color_by_subgroups:\n", " plot_palette_raw = dict()\n", " plot_palette_contrast = dict()\n", - " plot_palette_bar = None # plot_palette_bar set to None because currently there is no empty_circle toggle for proportion plots\n", + " plot_palette_bar = dict()\n", " for i in range(len(idx)):\n", " for names_i in idx[i]:\n", " plot_palette_raw[names_i] = swarm_colors[i]\n", " plot_palette_contrast[names_i] = contrast_colors[i]\n", + " plot_palette_bar[names_i] = bar_color[i]\n", " else:\n", " plot_palette_raw = dict(zip(names, swarm_colors))\n", " plot_palette_contrast = dict(zip(names, contrast_colors))\n", @@ -1071,6 +1072,7 @@ " ticks_with_counts.append(f\"{t}\\n(N={value})\")\n", "\n", " fontsize_rawxlabel = plot_kwargs.get(\"fontsize_rawxlabel\")\n", + " set_major_loc_method(plt.FixedLocator(get_ticks()))\n", " set_label(ticks_with_counts, fontsize=fontsize_rawxlabel)\n", "\n", " # Ensure ticks are at the correct locations\n", diff --git a/nbs/API/plot_tools.ipynb b/nbs/API/plot_tools.ipynb index 61c3ac67..3f72e633 100644 --- a/nbs/API/plot_tools.ipynb +++ b/nbs/API/plot_tools.ipynb @@ -781,14 +781,17 @@ " right_idx = []\n", " # Design for Sankey Flow Diagram\n", " sankey_idx = (\n", - " [\n", - " (control, test)\n", - " for i in idx\n", - " for control, test in zip(i[:], (i[1:] + (i[0],)))\n", - " ]\n", - " if flow\n", - " else temp_idx\n", - " )\n", + " [\n", + " (control, test)\n", + " for i in idx\n", + " for control, test in zip(\n", + " i[:],\n", + " (tuple(i[1:]) + (i[0],)) if isinstance(i, tuple) else (list(i[1:]) + [i[0]])\n", + " )\n", + " ]\n", + " if flow\n", + " else temp_idx\n", + ")\n", " for i in sankey_idx:\n", " left_idx.append(i[0])\n", " right_idx.append(i[1])\n", @@ -2115,6 +2118,7 @@ " plot_data: pd.DataFrame, \n", " bar_color: str, \n", " plot_palette_bar: dict, \n", + " color_col: str,\n", " plot_kwargs: dict, \n", " barplot_kwargs: dict, \n", " horizontal: bool\n", @@ -2138,6 +2142,8 @@ " Color of the bar.\n", " plot_palette_bar : dict\n", " Dictionary of colors used in the bar plot.\n", + " color_col : str\n", + " Column name of the color column.\n", " plot_kwargs : dict\n", " Keyword arguments for the plot.\n", " barplot_kwargs : dict\n", @@ -2152,7 +2158,26 @@ " else:\n", " x_var, y_var, orient = all_plot_groups, np.ones(len(all_plot_groups)), \"v\"\n", "\n", - " bar1_df = pd.DataFrame({xvar: x_var, \"proportion\": y_var})\n", + " # Create bar1_df with basic columns\n", + " bar1_df = pd.DataFrame({\n", + " xvar: x_var, \n", + " \"proportion\": y_var\n", + " })\n", + "\n", + " # Handle colors\n", + " if color_col:\n", + " # Get first color value for each group\n", + " color_mapping = plot_data.groupby(xvar, observed=False)[color_col].first()\n", + " bar1_df[color_col] = [color_mapping.get(group) for group in all_plot_groups]\n", + " \n", + " # Map colors, defaulting to bar_color if no match\n", + " edge_colors = [\n", + " plot_palette_bar.get(hue_val, bar_color) \n", + " for hue_val in bar1_df[color_col]\n", + " ]\n", + " else:\n", + " edge_colors = bar_color\n", + "\n", "\n", " bar1 = sns.barplot(\n", " data=bar1_df,\n", @@ -2162,7 +2187,7 @@ " order=all_plot_groups,\n", " linewidth=2,\n", " facecolor=(1, 1, 1, 0),\n", - " edgecolor=bar_color,\n", + " edgecolor=edge_colors,\n", " zorder=1,\n", " orient=orient,\n", " )\n", @@ -2173,6 +2198,8 @@ " ax=rawdata_axes,\n", " order=all_plot_groups,\n", " palette=plot_palette_bar,\n", + " hue=color_col,\n", + " dodge=False,\n", " zorder=1,\n", " orient=orient,\n", " **barplot_kwargs\n", diff --git a/nbs/API/plotter.ipynb b/nbs/API/plotter.ipynb index 060e8f64..8d661f96 100644 --- a/nbs/API/plotter.ipynb +++ b/nbs/API/plotter.ipynb @@ -334,6 +334,7 @@ " plot_data = plot_data, \n", " bar_color = bar_color, \n", " plot_palette_bar = plot_palette_bar, \n", + " color_col = color_col,\n", " plot_kwargs = plot_kwargs, \n", " barplot_kwargs = barplot_kwargs,\n", " horizontal = horizontal,\n", diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_74_unpaired_prop_delta2.png b/nbs/tests/mpl_image_tests/baseline_images/test_74_unpaired_prop_delta2.png new file mode 100644 index 00000000..a036733c Binary files /dev/null and b/nbs/tests/mpl_image_tests/baseline_images/test_74_unpaired_prop_delta2.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_75_unpaired_specified_prop_delta2.png b/nbs/tests/mpl_image_tests/baseline_images/test_75_unpaired_specified_prop_delta2.png new file mode 100644 index 00000000..3d450b61 Binary files /dev/null and b/nbs/tests/mpl_image_tests/baseline_images/test_75_unpaired_specified_prop_delta2.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_76_paired_prop_delta2.png b/nbs/tests/mpl_image_tests/baseline_images/test_76_paired_prop_delta2.png new file mode 100644 index 00000000..986abce2 Binary files /dev/null and b/nbs/tests/mpl_image_tests/baseline_images/test_76_paired_prop_delta2.png differ diff --git a/nbs/tests/mpl_image_tests/baseline_images/test_77_paired_specified_prop_delta2.png b/nbs/tests/mpl_image_tests/baseline_images/test_77_paired_specified_prop_delta2.png new file mode 100644 index 00000000..cfc9cc9b Binary files /dev/null and b/nbs/tests/mpl_image_tests/baseline_images/test_77_paired_specified_prop_delta2.png differ diff --git a/nbs/tests/mpl_image_tests/test_07_delta-delta_plots.py b/nbs/tests/mpl_image_tests/test_07_delta-delta_plots.py index 9dd63e1e..abe1af3c 100644 --- a/nbs/tests/mpl_image_tests/test_07_delta-delta_plots.py +++ b/nbs/tests/mpl_image_tests/test_07_delta-delta_plots.py @@ -26,6 +26,9 @@ def create_demo_dataset_delta(seed=9999, N=20): y = norm.rvs(loc=3, scale=0.4, size=N*4) y[N:2*N] = y[N:2*N]+1 y[2*N:3*N] = y[2*N:3*N]-0.5 + ind = np.random.binomial(1, 0.5, size=N*4) + ind[N:2*N] = np.random.binomial(1, 0.2, size=N) + ind[2*N:3*N] = np.random.binomial(1, 0.7, size=N) # Add drug column t1 = np.repeat('Placebo', N*2).tolist() @@ -54,10 +57,11 @@ def create_demo_dataset_delta(seed=9999, N=20): # Combine all columns into a DataFrame. df = pd.DataFrame({'ID' : id_col, - 'Rep' : rep, + 'Rep' : rep, 'Genotype' : genotype, - 'Treatment': treatment, - 'Y' : y + 'Treatment' : treatment, + 'Y' : y, + 'Cat' :ind }) return df @@ -81,6 +85,34 @@ def create_demo_dataset_delta(seed=9999, N=20): experiment = "Genotype", paired="sequential", id_col="ID") +unpaired_prop = load(data = df, proportional=True, + # id_col="index", paired='baseline', + x = ["Genotype", "Genotype"], + y = "Cat", delta2=True, + experiment="Treatment",) + +unpaired_specified_prop = load(data = df, proportional=True, + # id_col="index", paired='baseline', + x = ["Genotype", "Genotype"], + y = "Cat", delta2=True, + experiment="Treatment", + experiment_label = ["Drug", "Placebo"], + x1_level = ["M", "W"]) + +paired_prop = load(data = df, proportional=True, + id_col="ID", paired='baseline', + x = ["Genotype", "Genotype"], + y = "Cat", delta2=True, + experiment="Treatment",) + +paired_specified_prop = load(data = df, proportional=True, + id_col="ID", paired='baseline', + x = ["Genotype", "Genotype"], + y = "Cat", delta2=True, + experiment="Treatment", + experiment_label = ["Drug", "Placebo"], + x1_level = ["M", "W"]) + @pytest.mark.mpl_image_compare(tolerance=8) def test_47_cummings_unpaired_delta_delta_meandiff(): @@ -164,4 +196,20 @@ def test_72_sequential_delta_g(): @pytest.mark.mpl_image_compare(tolerance=8) def test_73_baseline_delta_g(): - return baseline.mean_diff.plot(); \ No newline at end of file + return baseline.mean_diff.plot(); + +@pytest.mark.mpl_image_compare(tolerance=8) +def test_74_unpaired_prop_delta2(): + return unpaired_prop.mean_diff.plot() + +@pytest.mark.mpl_image_compare(tolerance=8) +def test_75_unpaired_specified_prop_delta2(): + return unpaired_specified_prop.mean_diff.plot() + +@pytest.mark.mpl_image_compare(tolerance=8) +def test_76_paired_prop_delta2(): + return paired_prop.mean_diff.plot() + +@pytest.mark.mpl_image_compare(tolerance=8) +def test_77_paired_specified_prop_delta2(): + return paired_specified_prop.mean_diff.plot() \ No newline at end of file diff --git a/nbs/tests/test_load_errors.py b/nbs/tests/test_load_errors.py index eb598796..fa9f1aa8 100644 --- a/nbs/tests/test_load_errors.py +++ b/nbs/tests/test_load_errors.py @@ -35,18 +35,6 @@ def test_wrong_params_combinations(): assert error_msg in str(excinfo.value) - error_msg = "`proportional` and `delta2` cannot be True at the same time." - with pytest.raises(ValueError) as excinfo: - my_data = load( - dummy_df, - x=["Control 1", "Control 1"], - y="Test 1", - delta2=True, - proportional=True - ) - - assert error_msg in str(excinfo.value) - error_msg = "`idx` should not be specified when `delta2` is True.".format(N) with pytest.raises(ValueError) as excinfo: my_data = load(