Skip to content

Commit

Permalink
black compare_distributions
Browse files: browse the repository at this point in the history
  • Loading branch information
joshuawe committed Nov 10, 2023
1 parent 21f5a02 commit 0363cc3
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 37 deletions.
2 changes: 1 addition & 1 deletion plotsandgraphs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from . import binary_classifier
from . import compare_distributions
from . import compare_distributions
78 changes: 46 additions & 32 deletions plotsandgraphs/compare_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,22 @@
from typing import List, Tuple, Optional


def plot_raincloud(df: pd.DataFrame,
x_col: str,
y_col: str,
colors: Optional[List[str]] = None,
order: Optional[List[str]] = None,
title: Optional[str] = None,
x_label: Optional[str] = None,
x_range: Optional[Tuple[float, float]] = None,
show_violin = True,
show_scatter = True,
show_boxplot = True):

def plot_raincloud(
df: pd.DataFrame,
x_col: str,
y_col: str,
colors: Optional[List[str]] = None,
order: Optional[List[str]] = None,
title: Optional[str] = None,
x_label: Optional[str] = None,
x_range: Optional[Tuple[float, float]] = None,
show_violin=True,
show_scatter=True,
show_boxplot=True,
):
"""
Generate a raincloud plot using Pandas DataFrame.
Parameters:
- df (pd.DataFrame): The data frame containing the data.
- x_col (str): The column name for the x-axis data.
Expand All @@ -36,7 +37,7 @@ def plot_raincloud(df: pd.DataFrame,
Returns:
- matplotlib.figure.Figure: The generated plot figure.
"""

fig, ax = plt.subplots(figsize=(16, 8))
offset = 0.2 # Offset value to move plots

Expand All @@ -45,34 +46,47 @@ def plot_raincloud(df: pd.DataFrame,

# if colors are none, use distinct colors for each group
if colors is None:
cmap = plt.get_cmap('tab10')
cmap = plt.get_cmap("tab10")
colors = [mpl.colors.to_hex(cmap(i)) for i in np.linspace(0, 1, len(order))]
else:
assert len(colors) == len(order), 'colors and order must be the same length'
assert len(colors) == len(order), "colors and order must be the same length"
colors = colors

# Boxplot
if show_boxplot:
bp = ax.boxplot([df[df[y_col] == grp][x_col].values for grp in order],
patch_artist=True, vert=False, positions=np.arange(1 + offset, len(order) + 1 + offset), widths=0.2)
bp = ax.boxplot(
[df[df[y_col] == grp][x_col].values for grp in order],
patch_artist=True,
vert=False,
positions=np.arange(1 + offset, len(order) + 1 + offset),
widths=0.2,
)

# Customize boxplot colors
for patch, color in zip(bp['boxes'], colors):
for patch, color in zip(bp["boxes"], colors):
patch.set_facecolor(color)
patch.set_alpha(0.8)

# Set median line color to black
for median in bp['medians']:
median.set_color('black')
for median in bp["medians"]:
median.set_color("black")

# Violinplot
if show_violin:
vp = ax.violinplot([df[df[y_col] == grp][x_col].values for grp in order],
positions=np.arange(1 + offset, len(order) + 1 + offset), showmeans=False, showextrema=False, showmedians=False, vert=False)
vp = ax.violinplot(
[df[df[y_col] == grp][x_col].values for grp in order],
positions=np.arange(1 + offset, len(order) + 1 + offset),
showmeans=False,
showextrema=False,
showmedians=False,
vert=False,
)

# Customize violinplot colors
for idx, b in enumerate(vp['bodies']):
b.get_paths()[0].vertices[:, 1] = np.clip(b.get_paths()[0].vertices[:, 1], idx + 1 + offset, idx + 2 + offset)
for idx, b in enumerate(vp["bodies"]):
b.get_paths()[0].vertices[:, 1] = np.clip(
b.get_paths()[0].vertices[:, 1], idx + 1 + offset, idx + 2 + offset
)
b.set_color(colors[idx])

# Scatterplot with jitter
Expand All @@ -82,7 +96,7 @@ def plot_raincloud(df: pd.DataFrame,
y = np.full(len(features), idx + 1 - offset)
jitter_amount = 0.12
y += np.random.uniform(low=-jitter_amount, high=jitter_amount, size=len(y))
plt.scatter(features, y, s=10, c=colors[idx], alpha=0.3, facecolors='none')
plt.scatter(features, y, s=10, c=colors[idx], alpha=0.3, facecolors="none")

# Labels
plt.yticks(np.arange(1, len(order) + 1), order)
Expand All @@ -91,13 +105,13 @@ def plot_raincloud(df: pd.DataFrame,
x_label = x_col
plt.xlabel(x_label)
if title:
plt.title(title + '\n')
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
plt.title(title + "\n")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["left"].set_visible(False)
ax.xaxis.grid(True)

if x_range:
plt.xlim(x_range)

return fig
4 changes: 0 additions & 4 deletions setup_module.sh

This file was deleted.

0 comments on commit 0363cc3

Please sign in to comment.