diff --git a/CHANGELOG.md b/CHANGELOG.md index c809c52b0..588b7c3e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## 0.16.3 + +### Fix + +- Dataset : Update plot to new decorator paradigm + + ## 0.16.2 ### Fix diff --git a/code_pylint.py b/code_pylint.py index a1fcaa7d6..ec305f3a2 100644 --- a/code_pylint.py +++ b/code_pylint.py @@ -29,16 +29,16 @@ "protected-access": 48, # Highly dependant on our "private" conventions. Keeps getting raised "arguments-differ": 1, "too-many-locals": 5, # Reduce by dropping vectored objects - "too-many-branches": 9, # Huge refactor needed. Will be reduced by schema refactor + "too-many-branches": 8, # Huge refactor needed. Will be reduced by schema refactor "unused-argument": 3, # Some abstract functions have unused arguments (plot_data). Hence cannot decrease "cyclic-import": 2, # Still work to do on Specific based DessiaObject "too-many-arguments": 18, # Huge refactor needed "too-few-public-methods": 3, # Abstract classes (Errors, Checks,...) - "too-many-return-statements": 8, # Huge refactor needed. Will be reduced by schema refactor + "too-many-return-statements": 7, # Huge refactor needed. Will be reduced by schema refactor "import-outside-toplevel": 5, # TODO : will reduced in a future work (when tests are ready) "too-many-instance-attributes": 6, # Huge refactor needed (workflow, etc...) "broad-exception-caught": 9, # Necessary in order not to raise non critical errors. - "too-many-public-methods": 2, # Try to lower by splitting DessiaObject and Workflow + "too-many-public-methods": 3, # Try to lower by splitting DessiaObject and Workflow } ERRORS_WITHOUT_TIME_DECREASE = ['protected-access', 'arguments-differ', 'too-many-locals', 'too-many-branches', diff --git a/dessia_common/datatools/dataset.py b/dessia_common/datatools/dataset.py index 03d811e23..b6f8dfdeb 100644 --- a/dessia_common/datatools/dataset.py +++ b/dessia_common/datatools/dataset.py @@ -1,22 +1,24 @@ """ Library for building Dataset. """ -from typing import List, Dict, Any -from copy import copy import itertools +from copy import copy +from typing import Any, Dict, List -from scipy.spatial.distance import pdist, squareform import numpy as npy +from scipy.spatial.distance import pdist, squareform from sklearn import preprocessing try: from plot_data.core import Scatter, Histogram, MultiplePlots, Tooltip, ParallelPlot, PointFamily, EdgeStyle, Axis, \ PointStyle, Sample - from plot_data.colors import BLUE, GREY + from plot_data.colors import BLUE, GREY, Color except ImportError: pass -from dessia_common.core import DessiaObject, DessiaFilter, FiltersList -from dessia_common.exports import MarkdownWriter from dessia_common import templates -from dessia_common.datatools.metrics import mean, std, variance, covariance_matrix +from dessia_common.core import DessiaFilter, DessiaObject, FiltersList +from dessia_common.datatools.metrics import (covariance_matrix, mean, std, + variance) +from dessia_common.decorators import plot_data_view +from dessia_common.exports import MarkdownWriter class Dataset(DessiaObject): @@ -622,25 +624,32 @@ def _scale_data(data_matrix: List[List[float]]): scaled_matrix = preprocessing.StandardScaler().fit_transform(data_matrix) return [list(map(float, row.tolist())) for row in scaled_matrix] - def plot_data(self, reference_path: str = "#", **kwargs): - """ Plot a standard scatter matrix of all attributes in common_attributes and a dimensionality plot. """ + @plot_data_view(selector='DataSet scatter matrix') + def plot_scatter_matrix(self, reference_path: str = "#", **kwargs): + """ Plot a scatter matrix for attributes in common_attributes. """ + data_list = self._to_samples(reference_path=reference_path) + if len(self.common_attributes) > 1: + scatter_matrix = self._build_multiplot(data_list, self._tooltip_attributes()) + return scatter_matrix + raise ValueError("Scatter matrix can only be plotted with more than one common attribute.") + + @plot_data_view(selector='DataSet parallel plot') + def plot_parallel_plot(self, reference_path: str = "#", **kwargs): + """ Plot a parallel plot for attributes in common_attributes. """ data_list = self._to_samples(reference_path=reference_path) if len(self.common_attributes) > 1: - # Plot a correlation matrix : To develop - # correlation_matrix = [] - # Dimensionality plot - dimensionality_plot = self._plot_dimensionality() - # Scatter Matrix - scatter_matrix = self._build_multiplot(data_list, self._tooltip_attributes(), - axis=dimensionality_plot.axis, - point_style=dimensionality_plot.point_style) - # Parallel plot parallel_plot = self._parallel_plot(data_list) - return [parallel_plot, scatter_matrix] # , dimensionality_plot] + return parallel_plot + if len(self.common_attributes) == 1: + return self.plot_histogram(reference_path=reference_path) + raise ValueError("No common attributes found for plotting a parallel plot.") + def plot_histogram(self, reference_path: str = "#", **kwargs): + """ Plot a histogram when there's only one common attribute. """ + data_list = self._to_samples(reference_path=reference_path) plot_mono_attr = self._histogram_unic_value(0, name_attr=self.common_attributes[0]) plot_mono_attr.elements = data_list - return [plot_mono_attr] + return plot_mono_attr def _build_multiplot(self, data_list: List[Dict[str, float]], tooltip: List[str], **kwargs: Dict[str, Any]): subplots = [] @@ -677,7 +686,7 @@ def _to_samples(self, reference_path: str = '#'): for row, dessia_object in enumerate(self.dessia_objects)] def _point_families(self): - return [PointFamily(GREY, list(range(len(self))))] + return [PointFamily(Color(182/255, 225/255, 251/255), list(range(len(self))))] def _parallel_plot(self, data_list: List[Dict[str, float]]): return ParallelPlot(elements=data_list, axes=self._parallel_plot_attr(), disposition='vertical') diff --git a/scripts/dataset.py b/scripts/dataset.py index a5ad789b0..dc662940b 100644 --- a/scripts/dataset.py +++ b/scripts/dataset.py @@ -132,10 +132,10 @@ def to_vector(self): assert(empty_list == Dataset()) try: - empty_list.plot_data() + empty_list.plot_scatter_matrix() raise ValueError("plot_data should not work on empty Datasets") except Exception as e: - assert(e.__class__.__name__ == "IndexError") + assert(e.__class__.__name__ == "ValueError") try: empty_list.singular_values()