Skip to content

Commit

Permalink
Merge pull request #683 from Dessia-tech/update/dataset_plot
Browse files Browse the repository at this point in the history
update: dataset plot
  • Loading branch information
GhislainJ committed Mar 26, 2024
2 parents fe0642f + d44c368 commit 32daf8b
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 26 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Expand Up @@ -6,6 +6,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## 0.16.3

### Fix

- Dataset : Update plot to new decorator paradigm


## 0.16.2

### Fix
Expand Down
6 changes: 3 additions & 3 deletions code_pylint.py
Expand Up @@ -29,16 +29,16 @@
"protected-access": 48, # Highly dependant on our "private" conventions. Keeps getting raised
"arguments-differ": 1,
"too-many-locals": 5, # Reduce by dropping vectored objects
"too-many-branches": 9, # Huge refactor needed. Will be reduced by schema refactor
"too-many-branches": 8, # Huge refactor needed. Will be reduced by schema refactor
"unused-argument": 3, # Some abstract functions have unused arguments (plot_data). Hence cannot decrease
"cyclic-import": 2, # Still work to do on Specific based DessiaObject
"too-many-arguments": 18, # Huge refactor needed
"too-few-public-methods": 3, # Abstract classes (Errors, Checks,...)
"too-many-return-statements": 8, # Huge refactor needed. Will be reduced by schema refactor
"too-many-return-statements": 7, # Huge refactor needed. Will be reduced by schema refactor
"import-outside-toplevel": 5, # TODO : will reduced in a future work (when tests are ready)
"too-many-instance-attributes": 6, # Huge refactor needed (workflow, etc...)
"broad-exception-caught": 9, # Necessary in order not to raise non critical errors.
"too-many-public-methods": 2, # Try to lower by splitting DessiaObject and Workflow
"too-many-public-methods": 3, # Try to lower by splitting DessiaObject and Workflow
}

ERRORS_WITHOUT_TIME_DECREASE = ['protected-access', 'arguments-differ', 'too-many-locals', 'too-many-branches',
Expand Down
51 changes: 30 additions & 21 deletions dessia_common/datatools/dataset.py
@@ -1,22 +1,24 @@
""" Library for building Dataset. """
from typing import List, Dict, Any
from copy import copy
import itertools
from copy import copy
from typing import Any, Dict, List

from scipy.spatial.distance import pdist, squareform
import numpy as npy
from scipy.spatial.distance import pdist, squareform
from sklearn import preprocessing

try:
from plot_data.core import Scatter, Histogram, MultiplePlots, Tooltip, ParallelPlot, PointFamily, EdgeStyle, Axis, \
PointStyle, Sample
from plot_data.colors import BLUE, GREY
from plot_data.colors import BLUE, GREY, Color
except ImportError:
pass
from dessia_common.core import DessiaObject, DessiaFilter, FiltersList
from dessia_common.exports import MarkdownWriter
from dessia_common import templates
from dessia_common.datatools.metrics import mean, std, variance, covariance_matrix
from dessia_common.core import DessiaFilter, DessiaObject, FiltersList
from dessia_common.datatools.metrics import (covariance_matrix, mean, std,
variance)
from dessia_common.decorators import plot_data_view
from dessia_common.exports import MarkdownWriter


class Dataset(DessiaObject):
Expand Down Expand Up @@ -622,25 +624,32 @@ def _scale_data(data_matrix: List[List[float]]):
scaled_matrix = preprocessing.StandardScaler().fit_transform(data_matrix)
return [list(map(float, row.tolist())) for row in scaled_matrix]

def plot_data(self, reference_path: str = "#", **kwargs):
""" Plot a standard scatter matrix of all attributes in common_attributes and a dimensionality plot. """
@plot_data_view(selector='DataSet scatter matrix')
def plot_scatter_matrix(self, reference_path: str = "#", **kwargs):
    """
    Build the scatter matrix of the attributes listed in common_attributes.

    :param reference_path: Path forwarded to `_to_samples` when the samples are built.
    :param kwargs: Accepted for signature compatibility; not read by this method.
    :return: Whatever `_build_multiplot` produces for the samples and tooltip attributes.
    :raises ValueError: When there are fewer than two common attributes.
    """
    samples = self._to_samples(reference_path=reference_path)
    # Guard clause: a scatter matrix needs at least two attributes to pair up.
    if len(self.common_attributes) <= 1:
        raise ValueError("Scatter matrix can only be plotted with more than one common attribute.")
    return self._build_multiplot(samples, self._tooltip_attributes())

@plot_data_view(selector='DataSet parallel plot')
def plot_parallel_plot(self, reference_path: str = "#", **kwargs):
""" Plot a parallel plot for attributes in common_attributes. """
data_list = self._to_samples(reference_path=reference_path)
if len(self.common_attributes) > 1:
# Plot a correlation matrix : To develop
# correlation_matrix = []
# Dimensionality plot
dimensionality_plot = self._plot_dimensionality()
# Scatter Matrix
scatter_matrix = self._build_multiplot(data_list, self._tooltip_attributes(),
axis=dimensionality_plot.axis,
point_style=dimensionality_plot.point_style)
# Parallel plot
parallel_plot = self._parallel_plot(data_list)
return [parallel_plot, scatter_matrix] # , dimensionality_plot]
return parallel_plot
if len(self.common_attributes) == 1:
return self.plot_histogram(reference_path=reference_path)
raise ValueError("No common attributes found for plotting a parallel plot.")

def plot_histogram(self, reference_path: str = "#", **kwargs):
""" Plot a histogram when there's only one common attribute. """
data_list = self._to_samples(reference_path=reference_path)
plot_mono_attr = self._histogram_unic_value(0, name_attr=self.common_attributes[0])
plot_mono_attr.elements = data_list
return [plot_mono_attr]
return plot_mono_attr

def _build_multiplot(self, data_list: List[Dict[str, float]], tooltip: List[str], **kwargs: Dict[str, Any]):
subplots = []
Expand Down Expand Up @@ -677,7 +686,7 @@ def _to_samples(self, reference_path: str = '#'):
for row, dessia_object in enumerate(self.dessia_objects)]

def _point_families(self):
return [PointFamily(GREY, list(range(len(self))))]
return [PointFamily(Color(182/255, 225/255, 251/255), list(range(len(self))))]

def _parallel_plot(self, data_list: List[Dict[str, float]]):
return ParallelPlot(elements=data_list, axes=self._parallel_plot_attr(), disposition='vertical')
Expand Down
4 changes: 2 additions & 2 deletions scripts/dataset.py
Expand Up @@ -132,10 +132,10 @@ def to_vector(self):
assert(empty_list == Dataset())

# Plotting the scatter matrix of an empty Dataset must raise ValueError.
# The failure signal lives in the `else` branch and uses AssertionError: a
# sentinel ValueError raised inside the `try` would be caught by the handler
# below and be indistinguishable from the expected error, so the check could
# never fail.
try:
    empty_list.plot_scatter_matrix()
except ValueError:
    pass  # Expected: fewer than two common attributes.
else:
    raise AssertionError("plot_scatter_matrix should not work on empty Datasets")

try:
empty_list.singular_values()
Expand Down

0 comments on commit 32daf8b

Please sign in to comment.