Patch: rewrite plots/seaborn/boxplot/box-basic/default.py as a flat script.

Removes the parameterised create_plot() helper (input validation, box and
median-line restyling, per-group "n=" annotations) together with its
__main__ demo, and replaces them with top-level statements that draw four
50-sample normal groups with a fixed four-colour palette and save plot.png
at 300 dpi. The module docstring header is shortened to "Library: seaborn".

diff --git a/plots/seaborn/boxplot/box-basic/default.py b/plots/seaborn/boxplot/box-basic/default.py
index d2a1cd1f30..f80e9206ab 100644
--- a/plots/seaborn/boxplot/box-basic/default.py
+++ b/plots/seaborn/boxplot/box-basic/default.py
@@ -1,189 +1,53 @@
 """
 box-basic: Basic Box Plot
-Implementation for: seaborn
-Variant: default
-Python: 3.10+
+Library: seaborn
 """
 
-from typing import TYPE_CHECKING, Optional
-
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import seaborn as sns
 
 
-if TYPE_CHECKING:
-    from matplotlib.figure import Figure
-
-
-def create_plot(
-    data: pd.DataFrame,
-    values: str,
-    groups: str,
-    title: Optional[str] = None,
-    xlabel: Optional[str] = None,
-    ylabel: Optional[str] = None,
-    palette: Optional[str] = "Set2",
-    figsize: tuple[float, float] = (16, 9),
-    showfliers: bool = True,
-    **kwargs,
-) -> Figure:
-    """
-    Create a basic box plot showing statistical distribution of multiple groups using seaborn.
-
-    Args:
-        data: Input DataFrame with required columns
-        values: Column name containing numeric values
-        groups: Column name containing group categories
-        title: Plot title (optional)
-        xlabel: Custom x-axis label (optional, defaults to groups column name)
-        ylabel: Custom y-axis label (optional, defaults to values column name)
-        palette: Color palette name for boxes (default: 'Set2')
-        figsize: Figure size as (width, height) in inches (default: (16, 9))
-        showfliers: Whether to show outliers (default: True)
-        **kwargs: Additional parameters passed to seaborn boxplot function
-
-    Returns:
-        Matplotlib Figure object
-
-    Raises:
-        ValueError: If data is empty
-        KeyError: If required columns not found
-
-    Example:
-        >>> data = pd.DataFrame({
-        ...     'Group': ['A', 'A', 'B', 'B', 'C', 'C'],
-        ...     'Value': [1, 2, 2, 3, 3, 4]
-        ... })
-        >>> fig = create_plot(data, values='Value', groups='Group')
-    """
-    # Input validation
-    if data.empty:
-        raise ValueError("Data cannot be empty")
-
-    # Check required columns
-    for col in [values, groups]:
-        if col not in data.columns:
-            available = ", ".join(data.columns)
-            raise KeyError(f"Column '{col}' not found. Available columns: {available}")
-
-    # Create figure
-    fig, ax = plt.subplots(figsize=figsize)
-
-    # Create boxplot with seaborn
-    sns.boxplot(
-        data=data,
-        x=groups,
-        y=values,
-        hue=groups,
-        palette=palette,
-        ax=ax,
-        showfliers=showfliers,
-        width=0.7,
-        linewidth=1.5,
-        fliersize=6,
-        legend=False,
-        **kwargs,
-    )
-
-    # Customize the appearance
-    # Set median line color to be more visible
-    for patch in ax.artists:
-        # Get the current face color
-        r, g, b, a = patch.get_facecolor()
-        # Set the box face color with some transparency
-        patch.set_facecolor((r, g, b, 0.7))
-        # Set edge color
-        patch.set_edgecolor("black")
-        patch.set_linewidth(1.2)
-
-    # Style the median lines
-    for line in ax.lines:
-        # Median lines are the ones inside the boxes
-        if line.get_linestyle() == "-" and line.get_marker() == "None":
-            line.set_color("red")
-            line.set_linewidth(2)
-
-    # Labels and title
-    ax.set_xlabel(xlabel or groups)
-    ax.set_ylabel(ylabel or values)
-
-    if title:
-        ax.set_title(title, fontsize=14, fontweight="bold", pad=20)
-
-    # Grid for better readability
-    ax.grid(True, axis="y", alpha=0.3, linestyle="--")
-    ax.set_axisbelow(True)
-
-    # Rotate x-axis labels if there are many groups
-    unique_groups = data[groups].nunique()
-    if unique_groups > 5:
-        plt.xticks(rotation=45, ha="right")
-
-    # Add some statistical annotations
-    # Calculate and display the number of data points per group
-    group_counts = data.groupby(groups)[values].count()
-    y_bottom = ax.get_ylim()[0]
-    for i, (_group_name, count) in enumerate(group_counts.items()):
-        ax.text(i, y_bottom, f"n={count}", ha="center", va="top", fontsize=9, alpha=0.7)
-
-    # Apply seaborn style for better aesthetics
-    sns.despine(ax=ax)
-
-    # Layout
-    plt.tight_layout()
-
-    return fig
-
-
-if __name__ == "__main__":
-    # Sample data for testing with different distributions per group
-    np.random.seed(42)  # For reproducibility
-
-    # Generate sample data with 4 groups
-    data_dict = {"Group": [], "Value": []}
-
-    # Group A: Normal distribution, mean=50, std=10
-    group_a_data = np.random.normal(50, 10, 40)
-    # Add some outliers
-    group_a_data = np.append(group_a_data, [80, 85, 15])
-
-    # Group B: Normal distribution, mean=60, std=15
-    group_b_data = np.random.normal(60, 15, 35)
-    # Add outliers
-    group_b_data = np.append(group_b_data, [100, 10])
-
-    # Group C: Normal distribution, mean=45, std=8
-    group_c_data = np.random.normal(45, 8, 45)
-
-    # Group D: Skewed distribution
-    group_d_data = np.random.gamma(2, 2, 30) + 40
-    # Add outliers
-    group_d_data = np.append(group_d_data, [75, 78, 20])
-
-    # Combine all data
-    for group, values in zip(
-        ["Group A", "Group B", "Group C", "Group D"],
-        [group_a_data, group_b_data, group_c_data, group_d_data],
-        strict=False,
-    ):
-        data_dict["Group"].extend([group] * len(values))
-        data_dict["Value"].extend(values)
-
-    data = pd.DataFrame(data_dict)
-
-    # Create plot
-    fig = create_plot(
-        data,
-        values="Value",
-        groups="Group",
-        title="Statistical Distribution Comparison Across Groups",
-        ylabel="Measurement Value",
-        xlabel="Categories",
-        palette="Set2",
-    )
-
-    # Save for inspection
-    plt.savefig("plot.png", dpi=300, bbox_inches="tight")
-    print("Plot saved to plot.png")
+# Data
+np.random.seed(42)
+data = pd.DataFrame(
+    {
+        "group": ["A"] * 50 + ["B"] * 50 + ["C"] * 50 + ["D"] * 50,
+        "value": np.concatenate(
+            [
+                np.random.normal(50, 10, 50),
+                np.random.normal(60, 15, 50),
+                np.random.normal(45, 8, 50),
+                np.random.normal(70, 20, 50),
+            ]
+        ),
+    }
+)
+
+# Custom color palette using style guide colors
+colors = ["#306998", "#FFD43B", "#DC2626", "#059669"]
+
+# Create plot
+fig, ax = plt.subplots(figsize=(16, 9))
+sns.boxplot(
+    data=data,
+    x="group",
+    y="value",
+    hue="group",
+    palette=colors,
+    legend=False,
+    ax=ax,
+    linewidth=2,
+    flierprops={"marker": "o", "markersize": 8, "alpha": 0.7},
+)
+
+# Labels and styling
+ax.set_xlabel("Group", fontsize=20)
+ax.set_ylabel("Value", fontsize=20)
+ax.set_title("Basic Box Plot", fontsize=20)
+ax.tick_params(axis="both", labelsize=16)
+ax.grid(True, alpha=0.3, axis="y")
+
+plt.tight_layout()
+plt.savefig("plot.png", dpi=300, bbox_inches="tight")