From f2fe4dc2b4d8be06984abbcf247ca4800f271d3f Mon Sep 17 00:00:00 2001 From: Arthur HAMARD Date: Mon, 17 Nov 2025 11:18:44 +0100 Subject: [PATCH 1/5] =?UTF-8?q?=F0=9F=8E=89=20feat(utils):=20add=20visuali?= =?UTF-8?q?zation=20functions=20for=20dataset=20analysis?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plaid/utils/__init__.py | 4 + src/plaid/utils/viz.py | 622 ++++++++++++++++++++++++++++++++++++ 2 files changed, 626 insertions(+) create mode 100644 src/plaid/utils/viz.py diff --git a/src/plaid/utils/__init__.py b/src/plaid/utils/__init__.py index c6835793..96f32873 100644 --- a/src/plaid/utils/__init__.py +++ b/src/plaid/utils/__init__.py @@ -6,3 +6,7 @@ # file 'LICENSE.txt', which is part of this source code package. # # + +from plaid.utils.viz import kdeplot, pairplot, scatter_plot + +__all__ = ["scatter_plot", "pairplot", "kdeplot"] diff --git a/src/plaid/utils/viz.py b/src/plaid/utils/viz.py new file mode 100644 index 00000000..554e0a38 --- /dev/null +++ b/src/plaid/utils/viz.py @@ -0,0 +1,622 @@ +"""Visualization utilities for analyzing datasets.""" + +# -*- coding: utf-8 -*- +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# +# + +# %% Imports + +import logging +from typing import Optional, Union + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.figure import Figure + +from plaid.containers.dataset import Dataset +from plaid.containers.sample import Sample + +logger = logging.getLogger(__name__) + + +# %% Functions + + +def scatter_plot( + dataset: Union[Dataset, list[Sample]], + feature_names: Optional[list[str]] = None, + sample_ids: Optional[list[int]] = None, + figsize: Optional[tuple[float, float]] = None, + title: Optional[str] = None, + max_features_per_plot: int = 6, + **kwargs, +) -> Union[Figure, list[Figure]]: + """Create scatter plots of feature values vs sample IDs. + + This function visualizes how feature values vary across samples in the dataset. + Each feature is plotted as a scatter plot with sample IDs on the x-axis and + feature values on the y-axis. Useful for detecting trends, outliers, or patterns. + + Args: + dataset (Union[Dataset, list[Sample]]): The dataset or list of samples to visualize. + feature_names (list[str], optional): List of feature names to plot. If None, plots all scalar features. + For field features, use the format "base_name/zone_name/location/field_name". Defaults to None. + sample_ids (list[int], optional): List of sample IDs to include. If None, uses all samples. Defaults to None. + figsize (tuple[float, float], optional): Figure size (width, height) in inches. + If None, automatically calculated based on number of features. Defaults to None. + title (str, optional): Main title for the plot. Defaults to None. + max_features_per_plot (int, optional): Maximum number of features to display in a single figure. + If more features are requested, multiple figures are created. Defaults to 6. + **kwargs: Additional keyword arguments passed to matplotlib's scatter function. + + Returns: + Union[plt.Figure, list[plt.Figure]]: The created figure(s). Returns a single figure if all features + fit in one plot, otherwise returns a list of figures. + + Raises: + TypeError: If dataset is not a Dataset or list[Sample]. + ValueError: If no features are found or feature_names contains invalid names. + + Example: + >>> from plaid import Dataset + >>> from plaid.utils.viz import scatter_plot + >>> dataset = Dataset("path/to/dataset") + >>> # Plot all scalar features + >>> scatter_plot(dataset) + >>> # Plot specific features + >>> scatter_plot(dataset, feature_names=["temperature", "pressure"]) + >>> # Customize appearance + >>> scatter_plot(dataset, feature_names=["velocity"], figsize=(12, 6), alpha=0.6) + """ + # Input validation + if isinstance(dataset, list): + # Convert list of samples to Dataset for easier handling + temp_dataset = Dataset() + temp_dataset.add_samples(dataset) + dataset = temp_dataset + elif not isinstance(dataset, Dataset): + raise TypeError( + f"dataset must be a Dataset or list[Sample], got {type(dataset)}" + ) + + if len(dataset) == 0: + raise ValueError("Dataset is empty") + + # Get sample IDs + if sample_ids is None: + sample_ids = dataset.get_sample_ids() + else: + # Validate sample_ids + available_ids = dataset.get_sample_ids() + invalid_ids = [sid for sid in sample_ids if sid not in available_ids] + if invalid_ids: + raise ValueError(f"Invalid sample IDs: {invalid_ids}") + + # Get feature names if not provided + if feature_names is None: + feature_names = dataset.get_scalar_names(ids=sample_ids) + if not feature_names: + raise ValueError("No scalar features found in dataset") + + # Validate feature names + all_scalar_names = dataset.get_scalar_names(ids=sample_ids) + all_field_names = dataset.get_field_names(ids=sample_ids) + all_feature_names = all_scalar_names + all_field_names + + invalid_features = [f for f in feature_names if f not in all_feature_names] + if invalid_features: + raise ValueError( + f"Invalid feature names: {invalid_features}. Available features: {all_feature_names}" + ) + + # Collect feature data + feature_data = {} + for feature_name in feature_names: + values = [] + valid_ids = [] + + for sid in sample_ids: + sample = dataset[sid] + + # Try to get as scalar first + if feature_name in sample.get_scalar_names(): + value = sample.get_scalar(feature_name) + if value is not None: + # Handle both scalar and array scalars + if isinstance(value, np.ndarray): + # For array scalars, take the mean + values.append(float(np.mean(value))) + else: + values.append(float(value)) + valid_ids.append(sid) + # Try to get as field + elif "/" in feature_name: + # Parse field identifier: base_name/zone_name/location/field_name + parts = feature_name.split("/") + if len(parts) >= 4: + base_name, zone_name, location, field_name = ( + parts[0], + parts[1], + parts[2], + "/".join(parts[3:]), + ) + try: + field = sample.get_field( + field_name, + location=location, + zone_name=zone_name, + base_name=base_name, + ) + if field is not None: + # For fields, take the mean value + values.append(float(np.mean(field))) + valid_ids.append(sid) + except Exception: + # Field not found in this sample, skip + pass + + if values: + feature_data[feature_name] = (np.array(valid_ids), np.array(values)) + else: + logger.warning(f"No valid data found for feature '{feature_name}'") + + if not feature_data: + raise ValueError("No valid feature data found") + + # Create plots + n_features = len(feature_data) + n_plots = (n_features + max_features_per_plot - 1) // max_features_per_plot + figures = [] + + for plot_idx in range(n_plots): + start_idx = plot_idx * max_features_per_plot + end_idx = min((plot_idx + 1) * max_features_per_plot, n_features) + plot_features = list(feature_data.keys())[start_idx:end_idx] + n_subplot_features = len(plot_features) + + # Calculate subplot layout + n_cols = min(2, n_subplot_features) + n_rows = (n_subplot_features + n_cols - 1) // n_cols + + # Set figure size + if figsize is None: + fig_width = 6 * n_cols + fig_height = 4 * n_rows + fig_size = (fig_width, fig_height) + else: + fig_size = figsize + + fig, axes = plt.subplots(n_rows, n_cols, figsize=fig_size, squeeze=False) + + # Set main title + if title: + if n_plots > 1: + fig.suptitle(f"{title} (Part {plot_idx + 1}/{n_plots})", fontsize=14) + else: + fig.suptitle(title, fontsize=14) + + # Plot each feature + for idx, feature_name in enumerate(plot_features): + row = idx // n_cols + col = idx % n_cols + ax = axes[row, col] + + ids, values = feature_data[feature_name] + ax.scatter(ids, values, **kwargs) + ax.set_xlabel("Sample ID", fontsize=10) + ax.set_ylabel("Value", fontsize=10) + ax.set_title(feature_name, fontsize=11) + ax.grid(True, alpha=0.3) + + # Hide unused subplots + for idx in range(n_subplot_features, n_rows * n_cols): + row = idx // n_cols + col = idx % n_cols + axes[row, col].axis("off") + + plt.tight_layout() + figures.append(fig) + + return figures[0] if len(figures) == 1 else figures + + +def pairplot( + dataset: Union[Dataset, list[Sample]], + scalar_names: Optional[list[str]] = None, + sample_ids: Optional[list[int]] = None, + figsize: Optional[tuple[float, float]] = None, + title: Optional[str] = None, + diag_kind: str = "hist", + corner: bool = False, + **kwargs, +) -> Figure: + """Create a pairplot matrix showing relationships between scalar features. + + This function creates a grid of plots where each off-diagonal subplot shows + a scatter plot of two features, and diagonal subplots show the distribution + of individual features. Useful for detecting correlations and understanding + multivariate relationships. + + Args: + dataset (Union[Dataset, list[Sample]]): The dataset or list of samples to visualize. + scalar_names (list[str], optional): List of scalar feature names to include in the pairplot. + If None, uses all scalar features. Defaults to None. + sample_ids (list[int], optional): List of sample IDs to include. If None, uses all samples. Defaults to None. + figsize (tuple[float, float], optional): Figure size (width, height) in inches. + If None, automatically calculated based on number of features. Defaults to None. + title (str, optional): Main title for the plot. Defaults to None. + diag_kind (str, optional): Type of plot for diagonal subplots. Options are: + - "hist": Histogram + - "kde": Kernel Density Estimation + Defaults to "hist". + corner (bool, optional): If True, only shows the lower triangle of the pairplot. + This reduces redundancy since scatter(x, y) and scatter(y, x) show the same information. + Defaults to False. + **kwargs: Additional keyword arguments passed to matplotlib's scatter function. + + Returns: + Figure: The created figure containing the pairplot. + + Raises: + TypeError: If dataset is not a Dataset or list[Sample]. + ValueError: If dataset is empty, no scalar features found, or invalid parameters. + + Example: + >>> from plaid import Dataset + >>> from plaid.utils.viz import pairplot + >>> dataset = Dataset("path/to/dataset") + >>> # Create pairplot for all scalars + >>> pairplot(dataset) + >>> # Create pairplot for specific scalars + >>> pairplot(dataset, scalar_names=["temperature", "pressure", "density"]) + >>> # Create corner pairplot with KDE on diagonal + >>> pairplot(dataset, diag_kind="kde", corner=True) + """ + # Input validation + if isinstance(dataset, list): + # Convert list of samples to Dataset for easier handling + temp_dataset = Dataset() + temp_dataset.add_samples(dataset) + dataset = temp_dataset + elif not isinstance(dataset, Dataset): + raise TypeError( + f"dataset must be a Dataset or list[Sample], got {type(dataset)}" + ) + + if len(dataset) == 0: + raise ValueError("Dataset is empty") + + if diag_kind not in ["hist", "kde"]: + raise ValueError(f"diag_kind must be 'hist' or 'kde', got '{diag_kind}'") + + # Get sample IDs + if sample_ids is None: + sample_ids = dataset.get_sample_ids() + else: + # Validate sample_ids + available_ids = dataset.get_sample_ids() + invalid_ids = [sid for sid in sample_ids if sid not in available_ids] + if invalid_ids: + raise ValueError(f"Invalid sample IDs: {invalid_ids}") + + # Get scalar names if not provided + if scalar_names is None: + scalar_names = dataset.get_scalar_names(ids=sample_ids) + if not scalar_names: + raise ValueError("No scalar features found in dataset") + else: + # Validate scalar names + all_scalar_names = dataset.get_scalar_names(ids=sample_ids) + invalid_scalars = [s for s in scalar_names if s not in all_scalar_names] + if invalid_scalars: + raise ValueError( + f"Invalid scalar names: {invalid_scalars}. Available scalars: {all_scalar_names}" + ) + + # Collect scalar data + scalar_data = {} + for scalar_name in scalar_names: + values = [] + for sid in sample_ids: + sample = dataset[sid] + value = sample.get_scalar(scalar_name) + if value is not None: + # Handle both scalar and array scalars + if isinstance(value, np.ndarray): + # For array scalars, take the mean + values.append(float(np.mean(value))) + else: + values.append(float(value)) + else: + # Use NaN for missing values + values.append(np.nan) + + scalar_data[scalar_name] = np.array(values) + + if not scalar_data: + raise ValueError("No valid scalar data found") + + # Remove samples with any NaN values + data_matrix = np.column_stack([scalar_data[name] for name in scalar_names]) + valid_mask = ~np.any(np.isnan(data_matrix), axis=1) + data_matrix = data_matrix[valid_mask] + + if len(data_matrix) == 0: + raise ValueError("No samples with complete scalar data found") + + n_features = len(scalar_names) + + # Set figure size + if figsize is None: + fig_size = (3 * n_features, 3 * n_features) + else: + fig_size = figsize + + # Create figure and axes + fig, axes = plt.subplots(n_features, n_features, figsize=fig_size) + + # Handle single feature case + if n_features == 1: + axes = np.array([[axes]]) + + # Set main title + if title: + fig.suptitle(title, fontsize=16, y=0.995) + + # Create pairplot + for i in range(n_features): + for j in range(n_features): + ax = axes[i, j] + + if corner and j > i: + # Hide upper triangle if corner=True + ax.axis("off") + continue + + if i == j: + # Diagonal: plot distribution + data = data_matrix[:, i] + + if diag_kind == "hist": + ax.hist(data, bins=20, edgecolor="black", alpha=0.7) + elif diag_kind == "kde": + # Simple KDE using histogram as approximation + from scipy import stats + + try: + kde = stats.gaussian_kde(data) + x_range = np.linspace(data.min(), data.max(), 100) + ax.plot(x_range, kde(x_range), linewidth=2) + ax.fill_between(x_range, kde(x_range), alpha=0.3) + except Exception: + # Fallback to histogram if KDE fails + ax.hist(data, bins=20, edgecolor="black", alpha=0.7) + logger.warning( + f"KDE failed for {scalar_names[i]}, using histogram instead" + ) + + ax.set_ylabel("Frequency" if diag_kind == "hist" else "Density") + else: + # Off-diagonal: scatter plot + x_data = data_matrix[:, j] + y_data = data_matrix[:, i] + ax.scatter(x_data, y_data, alpha=0.5, **kwargs) + + # Set labels + if i == n_features - 1: + ax.set_xlabel(scalar_names[j], fontsize=10) + else: + ax.set_xticklabels([]) + + if j == 0 and not (corner and i == 0): + ax.set_ylabel(scalar_names[i], fontsize=10) + elif i != j: + ax.set_yticklabels([]) + + ax.grid(True, alpha=0.3) + + plt.tight_layout() + return fig + + +def kdeplot( + dataset: Union[Dataset, list[Sample]], + feature_names: Optional[list[str]] = None, + sample_ids: Optional[list[int]] = None, + figsize: Optional[tuple[float, float]] = None, + title: Optional[str] = None, + fill: bool = True, + bw_method: Optional[Union[str, float]] = None, + **kwargs, +) -> Figure: + """Create kernel density estimation plots for feature distributions. + + This function visualizes the probability density of features using KDE, + which provides a smooth continuous estimate of the distribution. Multiple + features can be overlaid on the same plot for comparison. + + Args: + dataset (Union[Dataset, list[Sample]]): The dataset or list of samples to visualize. + feature_names (list[str], optional): List of feature names to plot. If None, plots all scalar features. + For field features, use the format "base_name/zone_name/location/field_name". Defaults to None. + sample_ids (list[int], optional): List of sample IDs to include. If None, uses all samples. Defaults to None. + figsize (tuple[float, float], optional): Figure size (width, height) in inches. + If None, defaults to (10, 6). Defaults to None. + title (str, optional): Main title for the plot. Defaults to None. + fill (bool, optional): If True, fills the area under the KDE curve. Defaults to True. + bw_method (str or float, optional): Bandwidth selection method for KDE. + Can be 'scott', 'silverman', or a scalar. If None, uses 'scott'. Defaults to None. + **kwargs: Additional keyword arguments passed to matplotlib's plot function. + + Returns: + Figure: The created figure containing the KDE plots. + + Raises: + TypeError: If dataset is not a Dataset or list[Sample]. + ValueError: If dataset is empty or no valid features found. + + Example: + >>> from plaid import Dataset + >>> from plaid.utils.viz import kdeplot + >>> dataset = Dataset("path/to/dataset") + >>> # Create KDE plot for all scalars + >>> kdeplot(dataset) + >>> # Compare distributions of specific features + >>> kdeplot(dataset, feature_names=["temperature", "pressure"]) + >>> # Customize appearance + >>> kdeplot(dataset, fill=False, bw_method='silverman', linewidth=2) + """ + # Input validation + if isinstance(dataset, list): + # Convert list of samples to Dataset for easier handling + temp_dataset = Dataset() + temp_dataset.add_samples(dataset) + dataset = temp_dataset + elif not isinstance(dataset, Dataset): + raise TypeError( + f"dataset must be a Dataset or list[Sample], got {type(dataset)}" + ) + + if len(dataset) == 0: + raise ValueError("Dataset is empty") + + # Get sample IDs + if sample_ids is None: + sample_ids = dataset.get_sample_ids() + else: + # Validate sample_ids + available_ids = dataset.get_sample_ids() + invalid_ids = [sid for sid in sample_ids if sid not in available_ids] + if invalid_ids: + raise ValueError(f"Invalid sample IDs: {invalid_ids}") + + # Get feature names if not provided + if feature_names is None: + feature_names = dataset.get_scalar_names(ids=sample_ids) + if not feature_names: + raise ValueError("No scalar features found in dataset") + + # Validate feature names + all_scalar_names = dataset.get_scalar_names(ids=sample_ids) + all_field_names = dataset.get_field_names(ids=sample_ids) + all_feature_names = all_scalar_names + all_field_names + + invalid_features = [f for f in feature_names if f not in all_feature_names] + if invalid_features: + raise ValueError( + f"Invalid feature names: {invalid_features}. Available features: {all_feature_names}" + ) + + # Collect feature data + feature_data = {} + for feature_name in feature_names: + values = [] + + for sid in sample_ids: + sample = dataset[sid] + + # Try to get as scalar first + if feature_name in sample.get_scalar_names(): + value = sample.get_scalar(feature_name) + if value is not None: + # Handle both scalar and array scalars + if isinstance(value, np.ndarray): + # For array scalars, flatten all values + values.extend(value.flatten().tolist()) + else: + values.append(float(value)) + # Try to get as field + elif "/" in feature_name: + # Parse field identifier: base_name/zone_name/location/field_name + parts = feature_name.split("/") + if len(parts) >= 4: + base_name, zone_name, location, field_name = ( + parts[0], + parts[1], + parts[2], + "/".join(parts[3:]), + ) + try: + field = sample.get_field( + field_name, + location=location, + zone_name=zone_name, + base_name=base_name, + ) + if field is not None: + # For fields, flatten all values + values.extend(field.flatten().tolist()) + except Exception: + # Field not found in this sample, skip + pass + + if values: + feature_data[feature_name] = np.array(values) + else: + logger.warning(f"No valid data found for feature '{feature_name}'") + + if not feature_data: + raise ValueError("No valid feature data found") + + # Set figure size + if figsize is None: + figsize = (10, 6) + + # Create figure + fig, ax = plt.subplots(figsize=figsize) + + # Set main title + if title: + fig.suptitle(title, fontsize=14) + + # Import scipy for KDE + # Plot KDE for each feature + from matplotlib import cm + from scipy import stats + + colors = cm.get_cmap("tab10")(np.linspace(0, 1, len(feature_data))) + + for idx, (feature_name, data) in enumerate(feature_data.items()): + # Remove NaN values + data = data[~np.isnan(data)] + + if len(data) < 2: + logger.warning( + f"Skipping '{feature_name}': need at least 2 data points for KDE" + ) + continue + + try: + # Compute KDE + kde = stats.gaussian_kde(data, bw_method=bw_method) + + # Create evaluation points + x_min, x_max = data.min(), data.max() + x_range = x_max - x_min + x_eval = np.linspace(x_min - 0.1 * x_range, x_max + 0.1 * x_range, 200) + + # Evaluate KDE + density = kde(x_eval) + + # Plot + color = colors[idx] + ax.plot(x_eval, density, label=feature_name, color=color, **kwargs) + + if fill: + ax.fill_between(x_eval, density, alpha=0.3, color=color) + + except Exception as e: + logger.warning(f"KDE failed for '{feature_name}': {e}") + continue + + # Configure plot + ax.set_xlabel("Value", fontsize=12) + ax.set_ylabel("Density", fontsize=12) + ax.grid(True, alpha=0.3) + ax.legend(loc="best", fontsize=10) + + plt.tight_layout() + return fig From 4e516b2c3e06ee3f8f8ab1d0feae009bbf427cb9 Mon Sep 17 00:00:00 2001 From: Arthur HAMARD Date: Mon, 17 Nov 2025 11:57:46 +0100 Subject: [PATCH 2/5] =?UTF-8?q?=F0=9F=90=9B=20fix(viz):=20resolve=20duplic?= =?UTF-8?q?ate=20alpha=20parameter=20error=20in=20pairplot=20function?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/utils/viz_example.py | 313 +++++++++++++++++++++++++ src/plaid/utils/viz.py | 24 +- tests/utils/test_viz.py | 423 ++++++++++++++++++++++++++++++++++ 3 files changed, 755 insertions(+), 5 deletions(-) create mode 100644 examples/utils/viz_example.py create mode 100644 tests/utils/test_viz.py diff --git a/examples/utils/viz_example.py b/examples/utils/viz_example.py new file mode 100644 index 00000000..529587a9 --- /dev/null +++ b/examples/utils/viz_example.py @@ -0,0 +1,313 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.17.3 +# kernelspec: +# display_name: Python 3 +# language: python +# name: python3 +# --- + +# %% [markdown] +# # Visualization Examples +# +# This notebook demonstrates the visualization capabilities of PLAID for analyzing datasets. +# It covers three main visualization functions: +# +# 1. **scatter_plot**: Visualize how feature values vary across samples +# 2. **pairplot**: Show pairwise relationships between scalar features +# 3. **kdeplot**: Display probability density distributions of features +# +# Each visualization function helps in understanding different aspects of your dataset: +# - Detecting trends and outliers +# - Understanding correlations between features +# - Comparing distributions +# +# **Each section is documented and explained.** + +# %% +# Import required libraries +import numpy as np +import matplotlib.pyplot as plt + +# %% +# Import necessary PLAID classes and visualization functions +from plaid import Dataset, Sample +from plaid.utils.viz import scatter_plot, pairplot, kdeplot + +# %% [markdown] +# ## Section 1: Creating a Sample Dataset +# +# First, we'll create a sample dataset with multiple samples containing both scalar +# and field features. This dataset will be used to demonstrate all visualization functions. + +# %% +print("#---# Create sample dataset") + +# Number of samples to create +n_samples = 50 + +# Create samples with scalars and fields +samples = [] +for i in range(n_samples): + sample = Sample() + + # Add scalar features with some relationships + # temperature and pressure are correlated + temperature = 20 + 5 * np.random.randn() + 0.5 * i + pressure = 100 + 10 * np.random.randn() + 0.8 * temperature + + # density is independent + density = 1.2 + 0.2 * np.random.randn() + + # velocity has a trend + velocity = 10 + 0.1 * i + 2 * np.random.randn() + + sample.add_scalar("temperature", temperature) + sample.add_scalar("pressure", pressure) + sample.add_scalar("density", density) + sample.add_scalar("velocity", velocity) + + # Add a field feature + sample.init_base(2, 3, "mesh_base") + zone_shape = np.array([100, 0, 0]) + sample.init_zone(zone_shape, zone_name="zone_1") + + # Create a field with spatial variation + field_data = np.sin(np.linspace(0, 2 * np.pi, 100)) * (1 + 0.1 * i) + sample.add_field("displacement", field_data) + + samples.append(sample) + +# Create dataset from samples +dataset = Dataset() +dataset.add_samples(samples) + +print(f"Created dataset with {len(dataset)} samples") +print(f"Scalar features: {dataset.get_scalar_names()}") +print(f"Field features: {dataset.get_field_names()}") + +# %% [markdown] +# ## Section 2: Scatter Plot - Feature vs Sample ID +# +# The scatter_plot function visualizes how feature values change across samples. +# This is useful for detecting trends, outliers, or patterns in your data. + +# %% [markdown] +# ### Example 2.1: Plot all scalar features + +# %% +print("#---# Scatter plot of all scalar features") + +fig = scatter_plot(dataset) +plt.show() + +# %% [markdown] +# ### Example 2.2: Plot specific features with customization + +# %% +print("#---# Scatter plot of specific features") + +fig = scatter_plot( + dataset, + feature_names=["temperature", "velocity"], + figsize=(12, 5), + alpha=0.6, + s=50, # marker size + c="blue", # marker color + title="Temperature and Velocity Trends", +) +plt.show() + +# %% [markdown] +# ### Example 2.3: Plot subset of samples + +# %% +print("#---# Scatter plot for first 20 samples only") + +# Get first 20 sample IDs +sample_ids = dataset.get_sample_ids()[:20] + +fig = scatter_plot( + dataset, + feature_names=["pressure", "density"], + sample_ids=sample_ids, + title="Pressure and Density (First 20 Samples)", +) +plt.show() + +# %% [markdown] +# ## Section 3: Pairplot - Relationships Between Features +# +# The pairplot function creates a matrix showing pairwise relationships between features. +# The diagonal shows distributions, while off-diagonal plots show scatter plots. +# This helps identify correlations and multivariate patterns. + +# %% [markdown] +# ### Example 3.1: Basic pairplot with all scalars + +# %% +print("#---# Pairplot of all scalar features") + +fig = pairplot(dataset) +plt.show() + +# %% [markdown] +# ### Example 3.2: Pairplot with specific features and KDE on diagonal + +# %% +print("#---# Pairplot with KDE on diagonal") + +fig = pairplot( + dataset, + scalar_names=["temperature", "pressure", "velocity"], + diag_kind="kde", + title="Feature Relationships (KDE on diagonal)", +) +plt.show() + +# %% [markdown] +# ### Example 3.3: Corner pairplot (lower triangle only) + +# %% +print("#---# Corner pairplot") + +fig = pairplot( + dataset, + scalar_names=["temperature", "pressure", "density"], + corner=True, + diag_kind="hist", + title="Corner Pairplot", + alpha=0.5, +) +plt.show() + +# %% [markdown] +# ## Section 4: KDE Plot - Distribution Comparison +# +# The kdeplot function shows smooth probability density estimates for features. +# Multiple features can be overlaid on the same plot for easy comparison. + +# %% [markdown] +# ### Example 4.1: KDE plot of all scalar features + +# %% +print("#---# KDE plot of all scalar features") + +fig = kdeplot(dataset) +plt.show() + +# %% [markdown] +# ### Example 4.2: Compare distributions of specific features + +# %% +print("#---# KDE plot comparing specific features") + +fig = kdeplot( + dataset, + feature_names=["temperature", "velocity"], + title="Temperature vs Velocity Distributions", + fill=True, +) +plt.show() + +# %% [markdown] +# ### Example 4.3: KDE plot without fill and custom bandwidth + +# %% +print("#---# KDE plot with custom styling") + +fig = kdeplot( + dataset, + feature_names=["pressure", "density"], + fill=False, + bw_method="silverman", # alternative bandwidth method + linewidth=2.5, + title="Pressure and Density Distributions (No Fill)", +) +plt.show() + +# %% [markdown] +# ### Example 4.4: KDE plot for a single feature + +# %% +print("#---# KDE plot for single feature") + +fig = kdeplot( + dataset, + feature_names=["temperature"], + title="Temperature Distribution", + figsize=(8, 5), +) +plt.show() + +# %% [markdown] +# ## Section 5: Combining Visualizations +# +# Often, it's useful to combine multiple visualization types to get a complete +# picture of your data. + +# %% +print("#---# Create a combined visualization layout") + +# Create a figure with multiple subplots +fig = plt.figure(figsize=(15, 10)) + +# Subplot 1: Scatter plot of temperature +ax1 = plt.subplot(2, 2, 1) +temp_ids = dataset.get_sample_ids() +temp_values = [dataset[sid].get_scalar("temperature") for sid in temp_ids] +ax1.scatter(temp_ids, temp_values, alpha=0.6) +ax1.set_xlabel("Sample ID") +ax1.set_ylabel("Temperature") +ax1.set_title("Temperature Trend") +ax1.grid(True, alpha=0.3) + +# Subplot 2: Scatter plot of pressure vs temperature +ax2 = plt.subplot(2, 2, 2) +pressure_values = [dataset[sid].get_scalar("pressure") for sid in temp_ids] +ax2.scatter(temp_values, pressure_values, alpha=0.6) +ax2.set_xlabel("Temperature") +ax2.set_ylabel("Pressure") +ax2.set_title("Pressure vs Temperature") +ax2.grid(True, alpha=0.3) + +# Subplot 3: Histograms +ax3 = plt.subplot(2, 2, 3) +ax3.hist(temp_values, bins=15, alpha=0.7, label="Temperature") +ax3.set_xlabel("Value") +ax3.set_ylabel("Frequency") +ax3.set_title("Temperature Distribution (Histogram)") +ax3.legend() +ax3.grid(True, alpha=0.3) + +# Subplot 4: KDE comparison +ax4 = plt.subplot(2, 2, 4) +from scipy import stats + +# Temperature KDE +kde_temp = stats.gaussian_kde(temp_values) +x_temp = np.linspace(min(temp_values), max(temp_values), 100) +ax4.plot(x_temp, kde_temp(x_temp), label="Temperature", linewidth=2) +ax4.fill_between(x_temp, kde_temp(x_temp), alpha=0.3) + +# Velocity KDE +velocity_values = [dataset[sid].get_scalar("velocity") for sid in temp_ids] +kde_vel = stats.gaussian_kde(velocity_values) +x_vel = np.linspace(min(velocity_values), max(velocity_values), 100) +ax4.plot(x_vel, kde_vel(x_vel), label="Velocity", linewidth=2) +ax4.fill_between(x_vel, kde_vel(x_vel), alpha=0.3) + +ax4.set_xlabel("Value") +ax4.set_ylabel("Density") +ax4.set_title("KDE Comparison") +ax4.legend() +ax4.grid(True, alpha=0.3) + +plt.tight_layout() +plt.show() \ No newline at end of file diff --git a/src/plaid/utils/viz.py b/src/plaid/utils/viz.py index 554e0a38..ef378597 100644 --- a/src/plaid/utils/viz.py +++ b/src/plaid/utils/viz.py @@ -7,17 +7,19 @@ # # -# %% Imports +from __future__ import annotations +# %% Imports import logging -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union import matplotlib.pyplot as plt import numpy as np from matplotlib.figure import Figure -from plaid.containers.dataset import Dataset -from plaid.containers.sample import Sample +if TYPE_CHECKING: + from plaid.containers.dataset import Dataset + from plaid.containers.sample import Sample logger = logging.getLogger(__name__) @@ -71,6 +73,9 @@ def scatter_plot( >>> # Customize appearance >>> scatter_plot(dataset, feature_names=["velocity"], figsize=(12, 6), alpha=0.6) """ + # Lazy import to avoid circular dependency + from plaid.containers.dataset import Dataset + # Input validation if isinstance(dataset, list): # Convert list of samples to Dataset for easier handling @@ -275,6 +280,9 @@ def pairplot( >>> # Create corner pairplot with KDE on diagonal >>> pairplot(dataset, diag_kind="kde", corner=True) """ + # Lazy import to avoid circular dependency + from plaid.containers.dataset import Dataset + # Input validation if isinstance(dataset, list): # Convert list of samples to Dataset for easier handling @@ -403,7 +411,10 @@ def pairplot( # Off-diagonal: scatter plot x_data = data_matrix[:, j] y_data = data_matrix[:, i] - ax.scatter(x_data, y_data, alpha=0.5, **kwargs) + # Set default alpha if not provided in kwargs + scatter_kwargs = {"alpha": 0.5} + scatter_kwargs.update(kwargs) + ax.scatter(x_data, y_data, **scatter_kwargs) # Set labels if i == n_features - 1: @@ -469,6 +480,9 @@ def kdeplot( >>> # Customize appearance >>> kdeplot(dataset, fill=False, bw_method='silverman', linewidth=2) """ + # Lazy import to avoid circular dependency + from plaid.containers.dataset import Dataset + # Input validation if isinstance(dataset, list): # Convert list of samples to Dataset for easier handling diff --git a/tests/utils/test_viz.py b/tests/utils/test_viz.py new file mode 100644 index 00000000..0f12ddb1 --- /dev/null +++ b/tests/utils/test_viz.py @@ -0,0 +1,423 @@ +# -*- coding: utf-8 -*- +# +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# +# + +# %% Imports + +import matplotlib.pyplot as plt +import numpy as np +import pytest + +from plaid.containers.dataset import Dataset +from plaid.containers.sample import Sample +from plaid.utils.viz import kdeplot, pairplot, scatter_plot + +# %% Fixtures + + +@pytest.fixture() +def sample_dataset(): + """Create a dataset with multiple samples for testing.""" + samples = [] + n_samples = 20 + + for i in range(n_samples): + sample = Sample() + + # Add scalar features + sample.add_scalar("temperature", 20.0 + i * 0.5 + np.random.randn() * 0.1) + sample.add_scalar("pressure", 100.0 + i * 2.0 + np.random.randn() * 0.5) + sample.add_scalar("density", 1.2 + np.random.randn() * 0.05) + + # Add field feature + sample.init_base(2, 3, "base_1") + zone_shape = np.array([50, 0, 0]) + sample.init_zone(zone_shape, zone_name="zone_1") + sample.set_nodes(np.random.randn(50, 3)) + + field_data = np.random.randn(50) * (1 + i * 0.1) + sample.add_field("velocity", field_data) + + samples.append(sample) + + dataset = Dataset() + dataset.add_samples(samples) + return dataset + + +@pytest.fixture() +def empty_dataset(): + """Create an empty dataset for testing error handling.""" + return Dataset() + + +@pytest.fixture() +def single_sample_dataset(): + """Create a dataset with a single sample.""" + sample = Sample() + sample.add_scalar("test_scalar", 42.0) + + dataset = Dataset() + dataset.add_samples([sample]) + return dataset + + +# %% Tests for scatter_plot + + +def test_scatter_plot_basic(sample_dataset): + """Test basic scatter plot functionality.""" + fig = scatter_plot(sample_dataset) + assert fig is not None + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +def test_scatter_plot_with_feature_names(sample_dataset): + """Test scatter plot with specific feature names.""" + fig = scatter_plot(sample_dataset, feature_names=["temperature", "pressure"]) + assert fig is not None + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +def test_scatter_plot_with_sample_ids(sample_dataset): + """Test scatter plot with specific sample IDs.""" + sample_ids = sample_dataset.get_sample_ids()[:10] + fig = scatter_plot(sample_dataset, sample_ids=sample_ids) + assert fig is not None + plt.close(fig) + + +def test_scatter_plot_custom_figsize(sample_dataset): + """Test scatter plot with custom figure size.""" + fig = scatter_plot(sample_dataset, figsize=(12, 8)) + assert fig is not None + assert fig.get_figwidth() == 12 + assert fig.get_figheight() == 8 + plt.close(fig) + + +def test_scatter_plot_with_title(sample_dataset): + """Test scatter plot with title.""" + fig = scatter_plot(sample_dataset, title="Test Plot") + assert fig is not None + plt.close(fig) + + +def test_scatter_plot_with_kwargs(sample_dataset): + """Test scatter plot with additional matplotlib kwargs.""" + fig = scatter_plot(sample_dataset, alpha=0.5, s=100, c="red") + assert fig is not None + plt.close(fig) + + +def test_scatter_plot_list_of_samples(sample_dataset): + """Test scatter plot with list of samples instead of dataset.""" + samples = [sample_dataset[sid] for sid in sample_dataset.get_sample_ids()] + fig = scatter_plot(samples) + assert fig is not None + plt.close(fig) + + +def test_scatter_plot_empty_dataset_raises_error(empty_dataset): + """Test that scatter plot raises error for empty dataset.""" + with pytest.raises(ValueError, match="Dataset is empty"): + scatter_plot(empty_dataset) + + +def test_scatter_plot_invalid_feature_name(sample_dataset): + """Test that scatter plot raises error for invalid feature names.""" + with pytest.raises(ValueError, match="Invalid feature names"): + scatter_plot(sample_dataset, feature_names=["nonexistent_feature"]) + + +def test_scatter_plot_invalid_sample_ids(sample_dataset): + """Test that scatter plot raises error for invalid sample IDs.""" + with pytest.raises(ValueError, match="Invalid sample IDs"): + scatter_plot(sample_dataset, sample_ids=[999, 1000]) + + +def test_scatter_plot_invalid_type(): + """Test that scatter plot raises error for invalid input type.""" + with pytest.raises(TypeError, match="dataset must be a Dataset or list"): + scatter_plot("invalid_input") + + +# %% Tests for pairplot + + +def test_pairplot_basic(sample_dataset): + """Test basic pairplot functionality.""" + fig = pairplot(sample_dataset) + assert fig is not None + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +def test_pairplot_with_scalar_names(sample_dataset): + """Test pairplot with specific scalar names.""" + fig = pairplot(sample_dataset, scalar_names=["temperature", "pressure"]) + assert fig is not None + plt.close(fig) + + +def test_pairplot_with_kde_diagonal(sample_dataset): + """Test pairplot with KDE on diagonal.""" + fig = pairplot(sample_dataset, diag_kind="kde") + assert fig is not None + plt.close(fig) + + +def test_pairplot_with_hist_diagonal(sample_dataset): + """Test pairplot with histogram on diagonal.""" + fig = pairplot(sample_dataset, diag_kind="hist") + assert fig is not None + plt.close(fig) + + +def test_pairplot_corner_mode(sample_dataset): + """Test pairplot in corner mode (lower triangle only).""" + fig = pairplot(sample_dataset, corner=True) + assert fig is not None + plt.close(fig) + + +def test_pairplot_with_sample_ids(sample_dataset): + """Test pairplot with specific sample IDs.""" + sample_ids = sample_dataset.get_sample_ids()[:10] + fig = pairplot(sample_dataset, sample_ids=sample_ids) + assert fig is not None + plt.close(fig) + + +def test_pairplot_custom_figsize(sample_dataset): + """Test pairplot with custom figure size.""" + fig = pairplot(sample_dataset, figsize=(15, 15)) + assert fig is not None + plt.close(fig) + + +def test_pairplot_with_title(sample_dataset): + """Test pairplot with title.""" + fig = pairplot(sample_dataset, title="Test Pairplot") + assert fig is not None + plt.close(fig) + + +def test_pairplot_single_feature(sample_dataset): + """Test pairplot with single feature.""" + fig = pairplot(sample_dataset, scalar_names=["temperature"]) + assert fig is not None + plt.close(fig) + + +def test_pairplot_list_of_samples(sample_dataset): + """Test pairplot with list of samples.""" + samples = [sample_dataset[sid] for sid in sample_dataset.get_sample_ids()] + fig = pairplot(samples) + assert fig is not None + plt.close(fig) + + +def test_pairplot_empty_dataset_raises_error(empty_dataset): + """Test that pairplot raises error for empty dataset.""" + with pytest.raises(ValueError, match="Dataset is empty"): + pairplot(empty_dataset) + + +def test_pairplot_invalid_diag_kind(sample_dataset): + """Test that pairplot raises error for invalid diag_kind.""" + with pytest.raises(ValueError, match="diag_kind must be"): + pairplot(sample_dataset, diag_kind="invalid") + + +def test_pairplot_invalid_scalar_names(sample_dataset): + """Test that pairplot raises error for invalid scalar names.""" + with pytest.raises(ValueError, match="Invalid scalar names"): + pairplot(sample_dataset, scalar_names=["nonexistent_scalar"]) + + +def test_pairplot_invalid_sample_ids(sample_dataset): + """Test that pairplot raises error for invalid sample IDs.""" + with pytest.raises(ValueError, match="Invalid sample IDs"): + pairplot(sample_dataset, sample_ids=[999, 1000]) + + +def test_pairplot_invalid_type(): + """Test that pairplot raises error for invalid input type.""" + with pytest.raises(TypeError, match="dataset must be a Dataset or list"): + pairplot(123) + + +# %% Tests for kdeplot + + +def test_kdeplot_basic(sample_dataset): + """Test basic kdeplot functionality.""" + fig = kdeplot(sample_dataset) + assert fig is not None + assert isinstance(fig, plt.Figure) + plt.close(fig) + + +def test_kdeplot_with_feature_names(sample_dataset): + """Test kdeplot with specific feature names.""" + fig = kdeplot(sample_dataset, feature_names=["temperature", "pressure"]) + assert fig is not None + plt.close(fig) + + +def test_kdeplot_single_feature(sample_dataset): + """Test kdeplot with single feature.""" + fig = kdeplot(sample_dataset, feature_names=["temperature"]) + assert fig is not None + plt.close(fig) + + +def test_kdeplot_with_sample_ids(sample_dataset): + """Test kdeplot with specific sample IDs.""" + sample_ids = sample_dataset.get_sample_ids()[:10] + fig = kdeplot(sample_dataset, sample_ids=sample_ids) + assert fig is not None + plt.close(fig) + + +def test_kdeplot_no_fill(sample_dataset): + """Test kdeplot without fill.""" + fig = kdeplot(sample_dataset, fill=False) + assert fig is not None + plt.close(fig) + + +def test_kdeplot_with_fill(sample_dataset): + """Test kdeplot with fill.""" + fig = kdeplot(sample_dataset, fill=True) + assert fig is not None + plt.close(fig) + + +def test_kdeplot_custom_bandwidth(sample_dataset): + """Test kdeplot with custom bandwidth method.""" + fig = kdeplot(sample_dataset, bw_method="silverman") + assert fig is not None + plt.close(fig) + + +def test_kdeplot_custom_figsize(sample_dataset): + """Test kdeplot with custom figure size.""" + fig = kdeplot(sample_dataset, figsize=(12, 8)) + assert fig is not None + assert fig.get_figwidth() == 12 + assert fig.get_figheight() == 8 + plt.close(fig) + + +def test_kdeplot_with_title(sample_dataset): + """Test kdeplot with title.""" + fig = kdeplot(sample_dataset, title="Test KDE Plot") + assert fig is not None + plt.close(fig) + + +def test_kdeplot_with_kwargs(sample_dataset): + """Test kdeplot with additional matplotlib kwargs.""" + fig = kdeplot(sample_dataset, linewidth=3) + assert fig is not None + plt.close(fig) + + +def test_kdeplot_list_of_samples(sample_dataset): + """Test kdeplot with list of samples.""" + samples = [sample_dataset[sid] for sid in sample_dataset.get_sample_ids()] + fig = kdeplot(samples) + assert fig is not None + plt.close(fig) + + +def test_kdeplot_empty_dataset_raises_error(empty_dataset): + """Test that kdeplot raises error for empty dataset.""" + with pytest.raises(ValueError, match="Dataset is empty"): + kdeplot(empty_dataset) + + +def test_kdeplot_invalid_feature_name(sample_dataset): + """Test that kdeplot raises error for invalid feature names.""" + with pytest.raises(ValueError, match="Invalid feature names"): + kdeplot(sample_dataset, feature_names=["nonexistent_feature"]) + + +def test_kdeplot_invalid_sample_ids(sample_dataset): + """Test that kdeplot raises error for invalid sample IDs.""" + with pytest.raises(ValueError, match="Invalid sample IDs"): + kdeplot(sample_dataset, sample_ids=[999, 1000]) + + +def test_kdeplot_invalid_type(): + """Test that kdeplot raises error for invalid input type.""" + with pytest.raises(TypeError, match="dataset must be a Dataset or list"): + kdeplot({"invalid": "type"}) + + +# %% Integration tests + + +def test_all_functions_with_minimal_data(single_sample_dataset): + """Test that all functions can handle minimal datasets.""" + # scatter_plot should work with single sample + fig1 = scatter_plot(single_sample_dataset) + assert fig1 is not None + plt.close(fig1) + + # pairplot should work with single sample + fig2 = pairplot(single_sample_dataset) + assert fig2 is not None + plt.close(fig2) + + # kdeplot might fail with single sample due to KDE requirements + # We expect it to potentially raise a warning but not crash + try: + fig3 = kdeplot(single_sample_dataset) + if fig3 is not None: + plt.close(fig3) + except Exception: + # KDE may fail with insufficient data, which is acceptable + pass + + +def test_multiple_features_scatter_plot(sample_dataset): + """Test scatter plot with many features (multiple figures).""" + # Create many features to trigger multiple figures + feature_names = ["temperature", "pressure", "density"] + figs = scatter_plot( + sample_dataset, feature_names=feature_names, max_features_per_plot=2 + ) + + # Should return list of figures when features exceed max_features_per_plot + if isinstance(figs, list): + assert len(figs) > 1 + for fig in figs: + plt.close(fig) + else: + plt.close(figs) + + +def test_visualization_functions_consistency(sample_dataset): + """Test that all visualization functions produce valid outputs.""" + # Test scatter_plot + fig1 = scatter_plot(sample_dataset, feature_names=["temperature"]) + assert fig1 is not None + plt.close(fig1) + + # Test pairplot + fig2 = pairplot(sample_dataset, scalar_names=["temperature", "pressure"]) + assert fig2 is not None + plt.close(fig2) + + # Test kdeplot + fig3 = kdeplot(sample_dataset, feature_names=["temperature"]) + assert fig3 is not None + plt.close(fig3) From 15aa98fa4257410bc8887144412ce5ecb658c163 Mon Sep 17 00:00:00 2001 From: Arthur HAMARD Date: Mon, 17 Nov 2025 12:17:55 +0100 Subject: [PATCH 3/5] =?UTF-8?q?=F0=9F=90=9B=20fix(viz):=20remove=20warning?= =?UTF-8?q?=20from=20matplotlib?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plaid/utils/viz.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/plaid/utils/viz.py b/src/plaid/utils/viz.py index ef378597..c1c12155 100644 --- a/src/plaid/utils/viz.py +++ b/src/plaid/utils/viz.py @@ -630,7 +630,10 @@ def kdeplot( ax.set_xlabel("Value", fontsize=12) ax.set_ylabel("Density", fontsize=12) ax.grid(True, alpha=0.3) - ax.legend(loc="best", fontsize=10) + + # Only add legend if there are labeled artists + if ax.get_legend_handles_labels()[0]: + ax.legend(loc="best", fontsize=10) plt.tight_layout() return fig From 4a16d607fc289af2e0bebb3295809e6f46c62388 Mon Sep 17 00:00:00 2001 From: Arthur HAMARD Date: Mon, 17 Nov 2025 12:21:14 +0100 Subject: [PATCH 4/5] =?UTF-8?q?=F0=9F=90=9B=20fix(actions):=20add=20kokko-?= =?UTF-8?q?lib=20to=20conda=20environment?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index 3bb8268d..9b88a650 100644 --- a/environment.yml +++ b/environment.yml @@ -22,6 +22,7 @@ dependencies: ##### DEV/TESTS/EXAMPLES ##### #---# mesh/graph libs - muscat-core=2.5 + - kokkos>=4.6.2,<4.7 # Required for muscat-core compatibility #---# base - rich #---# optim From cffa2faa0e9f87d998211017f8748d2011df24ab Mon Sep 17 00:00:00 2001 From: Fabien Casenave Date: Thu, 20 Nov 2025 21:27:11 +0100 Subject: [PATCH 5/5] revert env modifs --- environment.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/environment.yml b/environment.yml index 9b88a650..3bb8268d 100644 --- a/environment.yml +++ b/environment.yml @@ -22,7 +22,6 @@ dependencies: ##### DEV/TESTS/EXAMPLES ##### #---# mesh/graph libs - muscat-core=2.5 - - kokkos>=4.6.2,<4.7 # Required for muscat-core compatibility #---# base - rich #---# optim