diff --git a/requirements.txt b/requirements.txt index 62768ef..444a30c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,9 @@ numpy~=2.1.3 pandas~=2.2.3 -matplotlib~=3.9.2 \ No newline at end of file +matplotlib~=3.9.2 +diptest +scipy +scikit-learn +pymannkendall +singleton-decorator +cydets \ No newline at end of file diff --git a/src/external_explainers/metainsight_explainer/__init__.py b/src/external_explainers/metainsight_explainer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/external_explainers/metainsight_explainer/cache.py b/src/external_explainers/metainsight_explainer/cache.py new file mode 100644 index 0000000..87bf162 --- /dev/null +++ b/src/external_explainers/metainsight_explainer/cache.py @@ -0,0 +1,118 @@ +from singleton_decorator import singleton +from collections import OrderedDict + +PATTERN_CACHE_MAX_SIZE = 40000 +DATASCOPE_CACHE_MAX_SIZE = 40000 +PATTERN_EVAL_CACHE_MAX_SIZE = 40000 +GROUPBY_CACHE_MAX_SIZE = 5000 + +@singleton +class Cache: + """ + A singleton class to hold various caches used in the MetaInsight explainer. + This helps in avoiding redundant computations and speeds up the evaluation process. + We use a singleton pattern to make the cache: + 1. Global across the application. + 2. Persistent throughout the lifetime of the application. + This cache is a simple LRU (Least Recently Used) cache implementation, removing the least recently used items when the cache exceeds its maximum size. + The caches in this class are: + - pattern_cache: Stores the data pattern objects evaluated for different data scopes and patterns. + - datascope_cache: Stores the scores for different data scopes. + - groupby_cache: Stores the results of groupby operations. + - pattern_eval_cache: Stores the results of pattern evaluations on series. 
+ """ + + def __init__(self): + self._pattern_cache = OrderedDict() + self._datascope_cache = OrderedDict() + self._groupby_cache = OrderedDict() + self._pattern_eval_cache = OrderedDict() + self.pattern_cache_max_size = PATTERN_CACHE_MAX_SIZE + self.datascope_cache_max_size = DATASCOPE_CACHE_MAX_SIZE + self.groupby_cache_max_size = GROUPBY_CACHE_MAX_SIZE + self.pattern_eval_cache_max_size = PATTERN_EVAL_CACHE_MAX_SIZE + + + def _add_to_cache(self, cache, key, value, max_size) -> None: + """ + Adds a key-value pair to the specified cache. + If the cache exceeds its maximum size, it removes the least recently used item. + """ + if key in cache: + # Update the value and mark as recently used + cache.move_to_end(key) + cache[key] = value + if len(cache) > max_size: + # Pop the first item (least recently used) + cache.popitem(last=False) + + + def _get_from_cache(self, cache, key) -> any: + """ + Retrieves a value from the specified cache by key. + If the key exists, it marks the key as recently used. + """ + if key in cache: + # Move the accessed item to the end to mark it as recently used + cache.move_to_end(key) + return cache[key] + return None + + + def add_to_pattern_cache(self, key, value) -> None: + """ + Adds a key-value pair to the pattern cache. + If the cache exceeds its maximum size, it removes the least recently used item. + """ + self._add_to_cache(self._pattern_cache, key, value, PATTERN_CACHE_MAX_SIZE) + + + def add_to_datascope_cache(self, key, value) -> None: + """ + Adds a key-value pair to the datascope cache. + If the cache exceeds its maximum size, it removes the least recently used item. + """ + self._add_to_cache(self._datascope_cache, key, value, DATASCOPE_CACHE_MAX_SIZE) + + def add_to_groupby_cache(self, key, value): + """ + Adds a key-value pair to the groupby cache. + If the cache exceeds its maximum size, it removes the least recently used item. 
+ """ + self._add_to_cache(self._groupby_cache, key, value, GROUPBY_CACHE_MAX_SIZE) + + def add_to_pattern_eval_cache(self, key, value) -> None: + """ + Adds a key-value pair to the pattern evaluation cache. + If the cache exceeds its maximum size, it removes the least recently used item. + """ + self._add_to_cache(self._pattern_eval_cache, key, value, PATTERN_EVAL_CACHE_MAX_SIZE) + + + def get_from_pattern_cache(self, key): + """ + Retrieves a value from the pattern cache by key. + If the key exists, it marks the key as recently used. + """ + return self._get_from_cache(self._pattern_cache, key) + + def get_from_datascope_cache(self, key): + """ + Retrieves a value from the datascope cache by key. + If the key exists, it marks the key as recently used. + """ + return self._get_from_cache(self._datascope_cache, key) + + def get_from_groupby_cache(self, key): + """ + Retrieves a value from the groupby cache by key. + If the key exists, it marks the key as recently used. + """ + return self._get_from_cache(self._groupby_cache, key) + + def get_from_pattern_eval_cache(self, key): + """ + Retrieves a value from the pattern evaluation cache by key. + If the key exists, it marks the key as recently used. 
+ """ + return self._get_from_cache(self._pattern_eval_cache, key) \ No newline at end of file diff --git a/src/external_explainers/metainsight_explainer/data_pattern.py b/src/external_explainers/metainsight_explainer/data_pattern.py new file mode 100644 index 0000000..f10d909 --- /dev/null +++ b/src/external_explainers/metainsight_explainer/data_pattern.py @@ -0,0 +1,226 @@ +import typing + +import pandas as pd +from typing import Dict, List, Tuple + +from external_explainers.metainsight_explainer.data_scope import DataScope, HomogenousDataScope +from external_explainers.metainsight_explainer.pattern_evaluations import PatternEvaluator, PatternType +from external_explainers.metainsight_explainer.patterns import PatternInterface +from external_explainers.metainsight_explainer.cache import Cache + + +class BasicDataPattern: + """ + A data pattern, as defined in the MetaInsight paper. + Contains 3 elements: data scope, type (interpretation type) and highlight. + """ + cache = Cache() + + def __init__(self, data_scope: DataScope, pattern_type: PatternType, highlight: PatternInterface | None): + """ + Initialize the BasicDataPattern with the provided data scope, type and highlight. + + :param data_scope: The data scope of the pattern. a DataScope object. + :param pattern_type: str, e.g., 'Unimodality', 'Trend', 'Other Pattern', 'No Pattern' + :param highlight: depends on type, e.g., ('April', 'Valley') for Unimodality + """ + self.data_scope = data_scope + self.pattern_type = pattern_type + self.highlight = highlight + self.hash = None + + def __eq__(self, other): + if not isinstance(other, BasicDataPattern): + return False + return self.pattern_type == other.pattern_type and \ + self.highlight == other.highlight and \ + self.data_scope == other.data_scope + + def sim(self, other) -> bool: + """ + Computes the similarity between two BasicDataPattern objects. 
+ They are similar if they have the same pattern type and highlight, as well as neither having + a pattern type of NONE or OTHER. + + :param other: The other BasicDataPattern object to compare with. + :return: True if similar, False otherwise. + """ + if not isinstance(other, BasicDataPattern): + return False + # There is no REAL need to check that both don't have NONE or OTHER pattern types, since if one + # has it but the other doesn't, the equality will be false anyway. If they both have it, then + # the equality conditions will be true but the inequality conditions will be false. + return self.pattern_type == other.pattern_type and self.highlight == other.highlight and \ + self.pattern_type != PatternType.NONE and self.pattern_type != PatternType.OTHER + + def __hash__(self): + if self.hash is not None: + return self.hash + self.hash = hash((hash(self.data_scope), self.pattern_type, self.highlight)) + return self.hash + + def __repr__(self): + return f"BasicDataPattern(ds={self.data_scope}, type='{self.pattern_type}', highlight={self.highlight})" + + @staticmethod + def evaluate_pattern(data_scope: DataScope, df: pd.DataFrame, pattern_type: PatternType) -> List['BasicDataPattern']: + """ + Evaluates a specific pattern type for the data distribution of a data scope. + :param data_scope: The data scope to evaluate. + :param df: The DataFrame containing the data. + :param pattern_type: The type of the pattern to evaluate. 
+ """ + # Apply subspace filters + filtered_df = data_scope.apply_subspace() + + # Group by breakdown dimension and aggregate measure + if any([dim for dim in data_scope.breakdown if dim not in filtered_df.columns]): + # Cannot group by breakdown if it's not in the filtered data + return [BasicDataPattern(data_scope, PatternType.NONE, None)] + + measure_col, agg_func = data_scope.measure + if measure_col not in filtered_df.columns: + # Cannot aggregate if measure column is not in the data + return [BasicDataPattern(data_scope, PatternType.NONE, None)] + + try: + # Perform the aggregation + if agg_func != "std": + aggregated_series = filtered_df.groupby(data_scope.breakdown)[measure_col].agg(agg_func) + else: + # For standard deviation, we need to use the std function directly + aggregated_series = filtered_df.groupby(data_scope.breakdown)[measure_col].std(ddof=1) + except Exception as e: + print(f"Error during aggregation for {data_scope}: {e}") + return [BasicDataPattern(data_scope, PatternType.NONE, None)] + + # Ensure series is sortable if breakdown is temporal + if all([True for dim in data_scope.breakdown if df[dim].dtype.kind in 'iuMmfc']): + # If the breakdown is temporal or at-least can be sorted, sort the series + aggregated_series = aggregated_series.sort_index() + + # Evaluate the specific pattern type + returned_patterns = [] + pattern_evaluator = PatternEvaluator() + is_valid, highlight = pattern_evaluator(aggregated_series, pattern_type) + if is_valid: + # A returned highlight can contain multiple highlights, for example, if a peak and a valley are found + # in the same series. 
+ for hl in highlight: + returned_patterns.append(BasicDataPattern(data_scope, pattern_type, hl)) + else: + # Check for other pattern types + for other_type in PatternType: + if other_type == PatternType.OTHER or other_type == PatternType.NONE: + continue + if other_type != pattern_type: + other_is_valid, highlight = pattern_evaluator(aggregated_series, other_type) + if other_is_valid: + for hl in highlight: + returned_patterns.append(BasicDataPattern(data_scope, PatternType.OTHER, hl)) + + if len(returned_patterns) == 0: + # If no pattern is found, return a 'No Pattern' type + return [BasicDataPattern(data_scope, PatternType.NONE, None)] + + return returned_patterns + + def create_hdp(self, pattern_type: PatternType, + hds: List[DataScope] = None, group_by_dims: List[List[str]] = None, + measures: List[Tuple[str,str]] = None, n_bins: int = 10, + extend_by_measure: bool = False, extend_by_breakdown: bool = False) -> 'HomogenousDataPattern': + """ + Generates a Homogenous Data Pattern (HDP) either from a given HDS or from the current DataScope. + + :param pattern_type: The type of the pattern (e.g., 'Unimodality', 'Trend', etc.), provided as a PatternType enum. + :param hds: A list of DataScopes to create the HDP from. If None, it will be created from the current DataScope. + :param group_by_dims: The temporal dimensions to extend the breakdown with. Expected as a list of lists of strings. + :param measures: The measures to extend the measure with. Expected to be a dict {measure_column: aggregate_function}. Only needed if hds is None. + :param n_bins: The number of bins to use for numeric columns. Defaults to 10. + :param extend_by_measure: Whether to extend the hds by measure. Defaults to False. + :param extend_by_breakdown: Whether to extend the hds by breakdown. Defaults to False. + :return: The HomogenousDataPattern object containing the evaluated patterns. 
+ """ + if hds is None or len(hds) == 0: + hds = self.data_scope.create_hds(dims=group_by_dims, measures=measures, + n_bins=n_bins, extend_by_measure=extend_by_measure, + extend_by_breakdown=extend_by_breakdown) + # All the data scopes in the HDS should have the same source_df, and it should be + # the same as the source_df of the current DataScope (otherwise, this pattern should not be + # the one producing the HDP with this HDS). + source_df = self.data_scope.source_df + + # Create the HDP + hdp = [] + for ds in hds: + if ds != self.data_scope: + # Check pattern cache first + cache_key = (ds.__hash__(), pattern_type) + cache_result = self.cache.get_from_pattern_cache(cache_key) + if cache_result is not None: + dp = cache_result + else: + # Evaluate the pattern if not in cache, and add to cache + dp = self.evaluate_pattern(ds, source_df, pattern_type) + self.cache.add_to_pattern_cache(cache_key, dp) + + # Some evaluation functions can return multiple patterns, so it is simpler to just + # convert it to a list and then treat it as an iterable. + if not isinstance(dp, typing.Iterable): + dp = [dp] + + # # Only add patterns that are not 'No Pattern' to the HDP for MetaInsight evaluation + # for d in dp: + # if d is not None and d.pattern_type != PatternType.NONE: + # hdp.append(d) + + # Add all patterns, including 'No Pattern', since it is important to know that we had a 'No Pattern'. + for d in dp: + if dp is not None: + hdp.append(d) + + if self.pattern_type != PatternType.NONE: + # Add the current pattern to the HDP + hdp.append(self) + hdp = HomogenousDataPattern(hdp) + + return hdp + + +class HomogenousDataPattern(HomogenousDataScope): + """ + A homogenous data pattern. + A list of data patterns induced by the same pattern type on a homogenous data scope. + """ + + def __init__(self, data_patterns: List[BasicDataPattern]): + """ + Initialize the HomogenousDataPattern with the provided data patterns. + + :param data_patterns: A list of BasicDataPattern objects. 
+ """ + if not data_patterns: + raise ValueError("data_patterns cannot be empty.") + super(HomogenousDataPattern, self).__init__([dp.data_scope for dp in data_patterns]) + self.data_patterns = data_patterns + + def __iter__(self): + """ + Allows iteration over the data patterns. + """ + return iter(self.data_patterns) + + def __len__(self): + """ + Returns the number of data patterns. + """ + return len(self.data_patterns) + + def __repr__(self): + return f"HomogenousDataPattern(#Patterns={len(self.data_patterns)})" + + def __getitem__(self, item): + """ + Allows indexing into the data patterns. + """ + return self.data_patterns[item] diff --git a/src/external_explainers/metainsight_explainer/data_scope.py b/src/external_explainers/metainsight_explainer/data_scope.py new file mode 100644 index 0000000..7d5d265 --- /dev/null +++ b/src/external_explainers/metainsight_explainer/data_scope.py @@ -0,0 +1,325 @@ +import pandas as pd +from typing import Dict, List, Tuple +from scipy.special import kl_div +import re +from external_explainers.metainsight_explainer.cache import Cache + +cache = Cache() + +class DataScope: + """ + A data scope, as defined in the MetaInsight paper. + Contains 3 elements: subspace, breakdown and measure. + Example: for the query SELECT Month, SUM(Sales) FROM DATASET WHERE City==“Los Angeles” GROUP BY Month + The subspace is {City: Los Angeles, Month: *}, the breakdown is {Month} and the measure is {SUM(Sales)}. + """ + + + def __init__(self, source_df: pd.DataFrame, subspace: Dict[str, str], + breakdown: str | List[str], + measure: tuple): + """ + Initialize the DataScope with the provided subspace, breakdown and measure. + + :param source_df: The DataFrame containing the data. + :param subspace: dict of filters, e.g., {'City': 'Los Angeles', 'Month': '*'} + :param breakdown: The dimension(s) to group by. Can be a string or a list of strings. 
+ :param measure: tuple, (measure_column_name, aggregate_function_name) + """ + # We want to allow for multi-value groupbys, so we work with lists of strings + if isinstance(breakdown, str): + breakdown = [breakdown] + self.source_df = source_df + self.subspace = subspace + self.breakdown = breakdown + self.measure = measure + self.breakdown_frozen = frozenset(self.breakdown) + self.hash = None + + def __hash__(self): + if self.hash is not None: + return self.hash + # Need a hashable representation of subspace for hashing + subspace_tuple = tuple(sorted(self.subspace.items())) if isinstance(self.subspace, dict) else tuple( + self.subspace) + self.hash = hash((subspace_tuple, frozenset(self.breakdown), self.measure)) + return self.hash + + def __repr__(self): + return f"DataScope(subspace={self.subspace}, breakdown='{self.breakdown}', measure={self.measure})" + + def __eq__(self, other): + if not isinstance(other, DataScope): + return False + return (self.subspace == other.subspace and + self.breakdown == other.breakdown and + self.measure == other.measure) + + def apply_subspace(self) -> pd.DataFrame: + """ + Applies the subspace filters to the source DataFrame and returns the filtered DataFrame. + """ + filtered_df = self.source_df.copy() + for dim, value in self.subspace.items(): + if value != '*': + pattern = rf"^.+<= {dim} <= .+$" + pattern_matched = re.match(pattern, str(value)) + if pattern_matched: + # If the value is a range, split it and filter accordingly + split = re.split(r"<=|>=|<|>", value) + lower_bound, dim, upper_bound = float(split[0].strip()), split[1].strip(), float(split[2].strip()) + filtered_df = filtered_df[(filtered_df[dim] >= lower_bound) & (filtered_df[dim] <= upper_bound)] + else: + filtered_df = filtered_df[filtered_df[dim] == value] + return filtered_df + + def _subspace_extend(self, n_bins: int = 10) -> List['DataScope']: + """ + Extends the subspace of the DataScope into its sibling group by the dimension dim_to_extend. 
+ Subspaces with the same sibling group only differ from each other in 1 non-empty filter. + + :param n_bins: The number of bins to use for numeric columns. Defaults to 10. + + :return: A list of new DataScope objects with the extended subspace. + """ + new_ds = [] + if isinstance(self.subspace, dict): + for dim_to_extend in self.subspace.keys(): + unique_values = self.source_df[dim_to_extend].dropna().unique() + # If there are too many unique values, we bin them if it's a numeric column, or only choose the + # top 10 most frequent values if it's a categorical column + if len(unique_values) > n_bins: + if self.source_df[dim_to_extend].dtype.kind in 'biufcmM': + # Bin the numeric column + bins = pd.cut(self.source_df[dim_to_extend], bins=n_bins, retbins=True)[1] + unique_values = [f"{bins[i]} <= {dim_to_extend} <= {bins[i + 1]}" for i in range(len(bins) - 1)] + # else: + # # Choose the top 10 most frequent values + # top_values = self.source_df[dim_to_extend].value_counts().nlargest(10).index.tolist() + # unique_values = [v for v in unique_values if v in top_values] + for value in unique_values: + # Ensure it's a sibling + if self.subspace.get(dim_to_extend) != value: + # Add the new DataScope with the extended subspace + new_subspace = self.subspace.copy() + new_subspace[dim_to_extend] = value + new_ds.append(DataScope(self.source_df, new_subspace, self.breakdown, self.measure)) + return new_ds + + def _measure_extend(self, measures: List[Tuple[str, str]]) -> List['DataScope']: + """ + Extends the measure of the DataScope while keeping the same breakdown and subspace. + + :param measures: The measures to extend. + :return: A list of new DataScope objects with the extended measure. 
+ """ + new_ds = [] + for measure_col, agg_func in measures: + if (measure_col, agg_func) != self.measure: + new_ds.append(DataScope(self.source_df, self.subspace, self.breakdown, (measure_col, agg_func))) + return new_ds + + def _breakdown_extend(self, dims: List[List[str]]) -> List['DataScope']: + """ + Extends the breakdown of the DataScope while keeping the same subspace and measure. + + :param dims: The dimensions to extend the breakdown with. + :return: A list of new DataScope objects with the extended breakdown. + """ + new_ds = [] + + for breakdown_dim in dims: + if breakdown_dim != self.breakdown: + new_ds.append(DataScope(self.source_df, self.subspace, breakdown_dim, self.measure)) + return new_ds + + def create_hds(self, dims: List[List[str]] = None, + measures: List[Tuple[str,str]] = None, n_bins: int = 10, + extend_by_measure: bool = False, + extend_by_breakdown: bool = False, + ) -> 'HomogenousDataScope': + """ + Generates a Homogeneous Data Scope (HDS) from a base data scope, using subspace, measure and breakdown + extensions as defined in the MetaInsight paper. + + :param dims: The temporal dimensions to extend the breakdown with. Expected as a list of strings. + :param measures: The measures to extend the measure with. Expected to be a dict {measure_column: aggregate_function}. + :param n_bins: The number of bins to use for numeric columns. Defaults to 10. + :param extend_by_measure: Whether to use measure extension or not. Defaults to False. Setting this to true + can lead to metainsights with mixed aggregation functions, which may often be undesirable. + :param extend_by_breakdown: Whether to use breakdown extension or not. Defaults to False. Setting this to True + can lead to metainsights with several disjoint indexes, which may often be undesirable. + + :return: A HDS in the form of a list of DataScope objects. 
+ """ + hds = [self] + if dims is None: + dims = [] + if measures is None: + measures = {} + + # Subspace Extending + hds.extend(self._subspace_extend(n_bins=n_bins)) + + # Measure Extending. + # We may not want to do it though, if we want our HDS to only contain the original measure. + if extend_by_measure: + hds.extend(self._measure_extend(measures)) + + # Breakdown Extending + if extend_by_breakdown: + hds.extend(self._breakdown_extend(dims)) + + return HomogenousDataScope(hds) + + def compute_impact(self) -> float: + """ + Computes the impact of the data scope based on the provided impact measure. + We define impact as the proportion of rows between the data scope and the total date scope, multiplied + by their KL divergence. + """ + if len(self.subspace) == 0: + # No subspace, no impact + return 0 + # Use the provided impact measure or default to the data scope's measure + impact_col, agg_func = self.measure + if impact_col not in self.source_df.columns: + raise ValueError(f"Impact column '{impact_col}' not found in source DataFrame.") + + # Perform subspace filtering + filtered_df = self.apply_subspace() + # Group by breakdown dimension and aggregate measure + if any([True for dim in self.breakdown if dim not in filtered_df.columns]): + # Cannot group by breakdown if it's not in the filtered data + return 0 + if impact_col not in filtered_df.columns: + # Cannot aggregate if measure column is not in the data + return 0 + try: + numeric_columns = filtered_df.select_dtypes(include=['number']).columns.tolist() + # Perform the aggregation + if agg_func != "std": + aggregated_series = filtered_df.groupby(impact_col)[numeric_columns].agg(agg_func) + else: + # If the aggregation is std, we need to manually provide ddof + aggregated_series = filtered_df.groupby(impact_col)[numeric_columns].std(ddof=1) + cache_result = cache.get_from_groupby_cache((impact_col, agg_func)) + if cache_result is not None: + # If the aggregation is in the cache, use it + aggregated_source = 
cache_result + else: + if agg_func != "std": + aggregated_source = self.source_df.groupby(impact_col)[numeric_columns].agg(agg_func) + else: + # If the aggregation is std, we need to manually provide ddof + aggregated_source = self.source_df.groupby(impact_col)[numeric_columns].std(ddof=1) + # Cache the result of the groupby operation + cache.add_to_groupby_cache((impact_col, agg_func), aggregated_source) + except Exception as e: + # raise e + print(f"Error during aggregation for {self}: {e}") + return 0 + + kl_divergence = kl_div(aggregated_series, aggregated_source).mean() + # If it is still a series, then the first mean was on a dataframe and not a series, and thus we need + # to take the mean to get a float. + if isinstance(kl_divergence, pd.Series): + kl_divergence = kl_divergence.mean() + row_proportion = len(filtered_df.index.to_list()) / len(self.source_df.index.to_list()) + impact = row_proportion * kl_divergence + return impact + + def create_query_string(self, df_name: str = None) -> str: + """ + Create a query string for the data scope. + :param df_name: The name of the DataFrame to use in the query string. 
+ :return: + """ + if df_name is None: + df_name = self.source_df.name if self.source_df.name else "df" + subspace_where_string = [] + for dim, value in self.subspace.items(): + # If the value is a range, we can just add it as is + pattern = rf"^.+<= {dim} <= .+$" + pattern_matched = re.match(pattern, str(value)) + if pattern_matched: + subspace_where_string.append(value) + else: + # Otherwise, we need to add it as an equality string + subspace_where_string.append(f"{dim} == '{value}'") + subspace_where_string = 'WHERE ' + ' AND '.join(subspace_where_string) + measures_select_string = f'SELECT {self.measure[1].upper()}({self.measure[0]})' + breakdown_groupby_string = f"GROUP BY {self.breakdown}" + query_string = f"{measures_select_string} FROM {df_name} {subspace_where_string} {breakdown_groupby_string}" + return query_string + + + + + + +class HomogenousDataScope: + """ + A homogenous data scope. + A list of data scopes that are all from the same source_df, and are all created using + one of the 3 extension methods of the DataScope class. + """ + + def __init__(self, data_scopes: List[DataScope]): + """ + Initialize the HomogenousDataScope with the provided data scopes. + + :param data_scopes: A list of DataScope objects. + """ + self.data_scopes = data_scopes + self.source_df = data_scopes[0].source_df if data_scopes else None + self.impact = 0 + + def __iter__(self): + """ + Allows iteration over the data scopes. + """ + return iter(self.data_scopes) + + def __len__(self): + """ + Returns the number of data scopes. + """ + return len(self.data_scopes) + + def __getitem__(self, item): + """ + Allows indexing into the data scopes. + """ + return self.data_scopes[item] + + def __repr__(self): + return f"HomogenousDataScope(#DataScopes={len(self.data_scopes)})" + + def __lt__(self, other): + """ + Less than comparison for sorting. + :param other: Another HomogenousDataScope object. + :return: True if this object is less than the other, False otherwise. 
+ """ + # We use the negative impact, since we want to use a max-heap but only have min-heap available + return - self.impact < - other.impact + + def compute_impact(self) -> float: + """ + Computes the impact of the HDS. This is the sum of the impacts of all data scopes in the HDS. + :return: The total impact of the HDS. + """ + impact = 0 + for ds in self.data_scopes: + # Use the cached impact if available to avoid recomputation, since computing the impact + # is the single most expensive operation in the entire pipeline + cache_result = cache.get_from_datascope_cache(ds.__hash__()) + if cache_result is not None: + ds_impact = cache_result + else: + ds_impact = ds.compute_impact() + cache.add_to_datascope_cache(ds.__hash__(), ds_impact) + impact += ds_impact + self.impact = impact + return impact diff --git a/src/external_explainers/metainsight_explainer/meta_insight.py b/src/external_explainers/metainsight_explainer/meta_insight.py new file mode 100644 index 0000000..77aca1a --- /dev/null +++ b/src/external_explainers/metainsight_explainer/meta_insight.py @@ -0,0 +1,729 @@ +from collections import defaultdict +from typing import List, Dict + +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +import textwrap + +import math + +from external_explainers.metainsight_explainer.data_pattern import HomogenousDataPattern +from external_explainers.metainsight_explainer.data_pattern import BasicDataPattern +from external_explainers.metainsight_explainer.pattern_evaluations import PatternType + +COMMONNESS_THRESHOLD = 0.5 +BALANCE_PARAMETER = 1 +ACTIONABILITY_REGULARIZER_PARAM = 0.1 +EXCEPTION_CATEGORY_COUNT = 3 + + +class MetaInsight: + """ + Represents a MetaInsight (HDP, commonness_set, exceptions). 
+ """ + + def __init__(self, hdp: HomogenousDataPattern, + commonness_set: List[List[BasicDataPattern]], + exceptions: Dict[str, List[BasicDataPattern]], score=0, + commonness_threshold: float = COMMONNESS_THRESHOLD, + balance_parameter: float = BALANCE_PARAMETER, + actionability_regularizer_param: float = ACTIONABILITY_REGULARIZER_PARAM, + source_name: str = None, + ): + """ + :param hdp: list of BasicDataPattern objects + :param commonness_set: A dictionary mapping commonness patterns to lists of BasicDataPattern objects + :param exceptions: A dictionary mapping exception categories to lists of BasicDataPattern objects + """ + self.hdp = hdp + self.commonness_set = commonness_set + self.exceptions = exceptions + self.score = score + self.commonness_threshold = commonness_threshold + self.balance_parameter = balance_parameter + self.actionability_regularizer_param = actionability_regularizer_param + self.source_name = source_name if source_name else "df" + self.hash = None + + def __repr__(self): + return f"MetaInsight(score={self.score:.4f}, #HDP={len(self.hdp)}, #Commonness={len(self.commonness_set)}, #Exceptions={len(self.exceptions)})" + + def __hash__(self): + if self.hash is not None: + return self.hash + self.hash = 0 + for commonness in self.commonness_set: + for pattern in commonness: + self.hash += pattern.__hash__() + return self.hash + + + def __eq__(self, other): + """ + Compares two MetaInsight objects for equality. + Two MetaInsight objects are considered equal if they have the same commonness sets. 
        :param other: The object to compare against.
        :return: True if both MetaInsights have equal commonness sets, False otherwise.
        """
        if not isinstance(other, MetaInsight):
            return False
        # If the commonness sets are not the same size, they are not equal
        if len(self.commonness_set) != len(other.commonness_set):
            return False
        all_equal = True
        # NOTE(review): this compares every self commonness against *every* other
        # commonness, and the inner `break` only exits the pattern loop — for
        # MetaInsights with more than one commonness a non-corresponding pair can
        # flip all_equal to False. Confirm the intended pairing semantics.
        for self_commonness in self.commonness_set:
            for other_commonness in other.commonness_set:
                # Check if the commonness sets are equal
                if len(self_commonness) != len(other_commonness):
                    all_equal = False
                    break
                for pattern in self_commonness:
                    if pattern not in other_commonness:
                        all_equal = False
                        break

        return all_equal


    def __str__(self):
        """
        :return: A string representation of the MetaInsight, describing all of the commonnesses in it.
        """
        ret_str = ""
        # Concatenate one human-readable title per commonness set.
        for commonness in self.commonness_set:
            ret_str += self._create_commonness_set_title(commonness)
        return ret_str


    def _write_exceptions_list_string(self, category: PatternType, patterns: List[BasicDataPattern], category_name: str) -> str:
        """
        Helper function to create a string representation of a list of exception patterns.
        :param category: The category of the exceptions.
        :param patterns: The list of BasicDataPattern objects in this category.
        :param category_name: The name of the category.
        :return: A string representation of the exceptions in this category.
        """
        if not patterns:
            return ""
        # NOTE(review): this condition and the comment below look inverted. For a
        # category name like "Pattern type change" (NOT in the list) the filter
        # excludes PatternType.OTHER — exactly the type that category holds — so
        # `exceptions` comes out empty. Verify intent before relying on output.
        if category_name.lower() not in ["no pattern", "no-pattern", "none", "highlight-change", "highlight change"]:
            # If the category is "No Pattern" or "Highlight Change", we don't need to write anything
            exceptions = [pattern for pattern in patterns if pattern.pattern_type not in [PatternType.NONE, PatternType.OTHER]]
        else:
            exceptions = [pattern for pattern in patterns if pattern.pattern_type == category]
        subspaces = [pattern.data_scope.subspace for pattern in exceptions]
        # Group subspace values by dimension key for compact printing.
        subspace_dict = defaultdict(list)
        for subspace in subspaces:
            for key, val in subspace.items():
                subspace_dict[key].append(val)
        out_str = f"Exceptions in category '{category_name}' ({len(exceptions)}): ["
        for key, val in subspace_dict.items():
            out_str += f"{key} = {val}, "
        # Drop the trailing ", " before closing the bracket.
        out_str = out_str[:-2] + "]\n"
        return out_str

    def get_exceptions_string(self):
        """
        A string representation of the list of exception categories.
        :return: A string describing every non-empty exception category, or a
            fixed "no exceptions" message if all categories are empty.
        """
        exceptions_string = ""
        for category, patterns in self.exceptions.items():
            if not patterns:
                continue
            # No-Pattern category: summarized as a list of subspace values.
            if category.lower() == "no-pattern" or category.lower() == "none":
                exceptions_string += self._write_exceptions_list_string(PatternType.NONE, patterns, "No Pattern")
            if category.lower() == "highlight-change" or category.lower() == "highlight change":
                # Doesn't matter which PatternType we use here, so long as it is not None or PatternType.OTHER.
                exceptions_string += self._write_exceptions_list_string(PatternType.UNIMODALITY, patterns, "Same pattern, different highlight")
            elif category.lower() == "type-change" or category.lower() == "type change":
                exceptions_string += self._write_exceptions_list_string(PatternType.OTHER, patterns, "Pattern type change")
        if not exceptions_string:
            exceptions_string = "All values belong to a commonness set, no exceptions found."
        return exceptions_string


    def to_str_full(self):
        """
        :return: A full string representation of the MetaInsight, including commonness sets and exceptions.
        """
        ret_str = self.__str__()
        if len(self.exceptions) > 0:
            ret_str += f"Exceptions to this pattern were found:\n"
            ret_str += self.get_exceptions_string()
        return ret_str


    @staticmethod
    def categorize_exceptions(commonness_set, exceptions):
        """
        Categorizes exceptions based on differences from commonness highlights/types.
        Simplified categorization: Highlight-Change, Type-Change, No-Pattern (though No-Pattern
        should ideally not be in the exceptions list generated by generate_hdp).
        Returns a dictionary mapping category names to lists of exception patterns.
        """
        categorized = defaultdict(list)
        commonness_highlights = set()
        for commonness in commonness_set:
            if commonness:  # Ensure commonness is not empty
                commonness_highlights.add(str(commonness[0].highlight))  # Assume all in commonness have same highlight

        for exc_dp in exceptions:
            if exc_dp.pattern_type == PatternType.OTHER:
                categorized['Type-Change'].append(exc_dp)
            elif exc_dp.pattern_type == PatternType.NONE:
                # This case should ideally not happen if generate_hdp filters 'No Pattern'
                categorized['No-Pattern'].append(exc_dp)
            elif str(exc_dp.highlight) not in commonness_highlights:
                categorized['Highlight-Change'].append(exc_dp)

        # Keeping this commented out, since I couldn't figure out what to do with something in this catch-all category.
        # For now it will be ignored, but it could maybe be useful.
        # else:
        #     # Exception has a valid pattern type and highlight, but didn't meet commonness threshold
        #     # This could be another category or grouped with Highlight-Change
        #     categorized['Other-Exception'].append(exc_dp)  # Add a catch-all category

        return categorized

    @staticmethod
    def create_meta_insight(hdp: HomogenousDataPattern, commonness_threshold=COMMONNESS_THRESHOLD) -> 'MetaInsight':
        """
        Evaluates the HDP and creates a MetaInsight object.
        :param hdp: A HomogenousDataPattern object.
        :param commonness_threshold: The threshold for commonness.
        :return: A MetaInsight object if possible, None otherwise.
        """
        if len(hdp) == 0:
            return None

        # Group patterns by similarity: each group is keyed by its first member,
        # and membership is decided by dp.sim() against that first member.
        similarity_groups = defaultdict(list)
        for dp in hdp:
            found_group = False
            for key in similarity_groups:
                # Check similarity with the first element of an existing group
                if dp.sim(similarity_groups[key][0]):
                    similarity_groups[key].append(dp)
                    found_group = True
                    break
            if not found_group:
                # Create a new group with this pattern as the first element (key)
                similarity_groups[dp].append(dp)

        # Identify commonness(es) based on the threshold
        commonness_set = []
        exceptions = []
        total_patterns_in_hdp = len(hdp)

        # Need to iterate through the original HDP to ensure all patterns are considered
        # and assigned to either commonness or exceptions exactly once.
        processed_patterns = set()
        for dp in hdp:
            if dp in processed_patterns:
                continue

            is_commonness = False
            for key, group in similarity_groups.items():
                if dp in group:
                    # An equivalence class is a commonness if it contains more than COMMONNESS_THRESHOLD of the HDP
                    if len(group) / total_patterns_in_hdp > commonness_threshold:
                        commonness_set.append(group)
                        for pattern in group:
                            processed_patterns.add(pattern)
                        is_commonness = True
                    break  # Found the group for this pattern

            if not is_commonness:
                # If the pattern wasn't part of a commonness, add it to exceptions
                exceptions.append(dp)
                processed_patterns.add(dp)

        # A valid MetaInsight requires at least one commonness
        if not commonness_set:
            return None

        # Categorize exceptions (optional for basic MetaInsight object, but needed for scoring)
        categorized_exceptions = MetaInsight.categorize_exceptions(commonness_set, exceptions)

        return MetaInsight(hdp, commonness_set, categorized_exceptions, commonness_threshold=commonness_threshold)

    def calculate_conciseness(self) -> float:
        """
        Calculates the conciseness score of a MetaInsight.
        Based on the entropy of category proportions: commonness categories
        contribute at full weight, exception categories are down-weighted by
        the balance parameter. The entropy is normalized by its upper bound S*.
        :return: The conciseness score (intended range roughly [0, 1]).
        """
        n = len(self.hdp)
        if n == 0:
            return 0

        # Calculate entropy
        S = 0
        commonness_proportions = []
        for patterns in self.commonness_set:
            if len(patterns) > 0:
                proportion = len(patterns) / n
                S += proportion * math.log2(proportion)
                commonness_proportions.append(proportion)

        exception_proportions = []
        for category, patterns in self.exceptions.items():
            if len(patterns) > 0:
                proportion = len(patterns) / n
                S += self.balance_parameter * (proportion * math.log2(proportion))
                exception_proportions.append(proportion)

        # Convert to positive entropy
        S = -S

        # Compute S* (the upper bound of S)
        threshold = ((1 - self.commonness_threshold) * math.e) / (
            math.pow(self.commonness_threshold, 1 / self.balance_parameter))
        if EXCEPTION_CATEGORY_COUNT > threshold:
            S_star = -math.log2(self.commonness_threshold) + (self.balance_parameter * EXCEPTION_CATEGORY_COUNT
                                                              * math.pow(self.commonness_threshold,
                                                                         1 / self.balance_parameter)
                                                              * math.log2(math.e))
        else:
            # NOTE(review): this branch mixes math.log (natural) and math.log2 —
            # confirm against the paper's formula for the entropy upper bound.
            S_star = - self.commonness_threshold * math.log(self.commonness_threshold) - (
                self.balance_parameter * (1 - self.commonness_threshold) * math.log2(
                    (1 - self.commonness_threshold) / EXCEPTION_CATEGORY_COUNT)
            )

        # Actionability regularizer: reward MetaInsights with no exceptions.
        indicator_value = 1 if len(exception_proportions) == 0 else 0
        conciseness = 1 - ((S + self.actionability_regularizer_param * indicator_value) / S_star)

        # NOTE(review): despite the comment below, the value is NOT clamped to
        # [0, 1] — callers should not assume it is bounded.
        # Ensure conciseness is within a reasonable range, e.g., [0, 1]
        return conciseness

    def compute_score(self) -> float:
        """
        Computes the score of the MetaInsight.
        The score is the product of the conciseness of the MetaInsight and the impact score of the HDS
        making up the HDP. The result is also stored on self.score.
        :return: The score of the MetaInsight.
        """
        conciseness = self.calculate_conciseness()
        # If the impact has already been computed, use it
        hds_score = self.hdp.impact if self.hdp.impact != 0 else self.hdp.compute_impact()
        self.score = conciseness * hds_score
        return self.score

    def compute_pairwise_overlap_ratio(self, other: 'MetaInsight') -> float:
        """
        Computes the pairwise overlap ratio between two MetaInsights, as the ratio between the
        size of the intersection and the size of the union of their HDPs (Jaccard index).
        :param other: Another MetaInsight object to compare with.
        :return: The overlap ratio between the two MetaInsights.
        """
        if not isinstance(other, MetaInsight):
            raise ValueError("The other object must be an instance of MetaInsight.")
        hds_1 = set(self.hdp.data_scopes)
        hds_2 = set(other.hdp.data_scopes)

        overlap = len(hds_1.intersection(hds_2))
        total = len(hds_1.union(hds_2))
        # Avoid division by 0
        if total == 0:
            return 0.0
        return overlap / total

    def compute_pairwise_overlap_score(self, other: 'MetaInsight') -> float:
        """
        Computes the pairwise overlap score between two MetaInsights.
        This is computed as min(I_1.score, I_2.score) * overlap_ratio(I_1, I_2)
        :param other: Another MetaInsight object to compare with.
        :return: The pairwise overlap score between the two MetaInsights.
        """
        if not isinstance(other, MetaInsight):
            raise ValueError("The other object must be an instance of MetaInsight.")
        overlap_ratio = self.compute_pairwise_overlap_ratio(other)
        return min(self.score, other.score) * overlap_ratio


    def _create_commonness_set_title(self, commonness_set: List[BasicDataPattern]) -> str:
        """
        Create a title for the commonness set based on the patterns it contains.
        :param commonness_set: A list of BasicDataPattern objects.
        :return: A string representing the title for the commonness set.
        """
        if not commonness_set:
            return "No Patterns"
        title = ""
        # Check the type of the first pattern in the set. All patterns in the set should be of the same type.
        pattern_type = commonness_set[0].pattern_type
        if pattern_type == PatternType.UNIMODALITY:
            title += "Common unimodality detected - "
            umimodality = commonness_set[0].highlight
            type = umimodality.type
            index = umimodality.highlight_index
            title += f"common {type} at index {index} "
        elif pattern_type == PatternType.TREND:
            trend = commonness_set[0].highlight
            trend_type = trend.type
            title += f"Common {trend_type} trend detected "
        elif pattern_type == PatternType.OUTLIER:
            title += "Common outliers detected "
            outliers = [pattern.highlight for pattern in commonness_set]
            common_outlier_indexes = {}
            # Create a counter for the outlier indexes
            for outlier in outliers:
                if outlier.outlier_indexes is not None:
                    for idx in outlier.outlier_indexes:
                        if idx in common_outlier_indexes:
                            common_outlier_indexes[idx] += 1
                        else:
                            common_outlier_indexes[idx] = 1
            # Sort the outlier indexes by their count
            common_outlier_indexes = sorted(common_outlier_indexes.items(), key=lambda x: x[1], reverse=True)
            # Take the top 5 most common outlier indexes
            # NOTE(review): the list is never actually truncated to 5 — only the
            # "..." marker is appended when there are more than 5. Confirm intent.
            num_outliers = len(common_outlier_indexes)
            common_outlier_indexes = list(dict(common_outlier_indexes).keys())
            # If there are more than 5, truncate the list and add "..."
            if num_outliers > 5:
                common_outlier_indexes.append("...")
            title += f"at indexes {' / '.join(map(str, common_outlier_indexes))}: "
        elif pattern_type == PatternType.CYCLE:
            title += "Common cycles detected "
        # Find the common subspace of the patterns in the set
        # First, get the data scope of all of the patterns in the set
        data_scopes = [pattern.data_scope for pattern in commonness_set]
        subspaces = [datascope.subspace for datascope in data_scopes]
        # Now, find the common subspace they share.
        shared_subspace = set(subspaces[0].keys())
        for subspace in subspaces[1:]:
            shared_subspace.intersection_update(subspace.keys())
        title += f"for over {self.commonness_threshold * 100}% of values of {', '.join(shared_subspace)}, "
        breakdowns = set([str(datascope.breakdown) for datascope in data_scopes])
        measures = set([datascope.measure for datascope in data_scopes])
        measures_str = []
        for measure in measures:
            if isinstance(measure, tuple):
                measures_str.append(f"{{{measure[0]}: {measure[1]}}}")
            else:
                measures_str.append(measure)
        title += f"when grouping by {' or '.join(breakdowns)} and aggregating by {' or '.join(measures_str)}"
        title = textwrap.wrap(title, 70)
        title = "\n".join(title)
        return title

    def visualize_commonesses_individually(self, fig=None, subplot_spec=None, figsize=(15, 10)) -> None:
        """
        Visualize only the commonness sets of the metainsight, with each set in its own column.
        Within each column, patterns are arranged in a grid with at most 3 patterns per column.
        This was the initial visualization method, but it was too cluttered and not very useful, so it was renamed and
        replaced with the more compact and informative visualize method.

        :param fig: Optional figure to plot on (or create a new one if None)
        :param subplot_spec: Optional subplot specification to plot within
        :param figsize: Figure size if creating a new figure
        :return: The figure with visualization
        """
        # Create figure if not provided
        if fig is None:
            fig = plt.figure(figsize=figsize)

        # Only proceed if there are commonness sets
        if not self.commonness_set:
            return fig

        # Create the main grid with one column per commonness set
        num_commonness_sets = len(self.commonness_set)

        if subplot_spec is not None:
            # Use the provided subplot area
            outer_grid = gridspec.GridSpecFromSubplotSpec(1, num_commonness_sets,
                                                          subplot_spec=subplot_spec,
                                                          wspace=0.6, hspace=0.4)
        else:
            # Use the entire figure
            outer_grid = gridspec.GridSpec(1, num_commonness_sets, figure=fig, wspace=0.6, hspace=0.4)

        # For each commonness set
        for i, patterns in enumerate(self.commonness_set):
            # Calculate how many sub-columns needed for this set
            num_patterns = len(patterns)
            num_cols = math.ceil(num_patterns / 3)  # At most 3 patterns per column
            max_patterns_per_col = min(3, math.ceil(num_patterns / num_cols))

            # Create a sub-grid for this commonness set's title and patterns
            set_grid = gridspec.GridSpecFromSubplotSpec(
                max_patterns_per_col + 1,  # Title row + pattern rows
                num_cols,
                subplot_spec=outer_grid[i],
                height_ratios=[0.2] + [1] * max_patterns_per_col,  # Title row smaller
                hspace=1.5,  # Increased spacing between rows
                wspace=0.5,  # Increased spacing between columns
            )

            # Add the set title spanning all columns in the first row
            title_ax = fig.add_subplot(set_grid[0, :])
            set_title = self._create_commonness_set_title(patterns)
            title_ax.text(0.5, 0.5, set_title,
                          ha='center', va='center',
                          fontsize=12, fontweight='bold')
            title_ax.axis('off')  # Hide axis for the title

            # Plot each pattern
            j = 0
            for pattern in patterns:
                # Visualize the pattern
                if hasattr(pattern, 'highlight') and pattern.highlight is not None:
                    # Calculate which column and row this pattern should be in
                    col = j // max_patterns_per_col
                    row = (j % max_patterns_per_col) + 1  # +1 to skip title row
                    # Create subplot for this pattern
                    ax = fig.add_subplot(set_grid[row, col])

                    pattern.highlight.visualize(ax)

                    # Rotate x-axis tick labels
                    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', fontsize=8)

                    # Instead of setting title, add text box for query below the plot
                    query_text = pattern.data_scope.create_query_string(df_name=self.source_name)
                    query_text = textwrap.fill(query_text, width=40)

                    # Add text box with query string instead of title
                    props = dict(boxstyle='round', facecolor='wheat', alpha=0.3)
                    ax.text(0.5, 1.5, query_text, transform=ax.transAxes, fontsize=9,
                            ha='center', va='top', bbox=props)

                    j += 1

        return fig


    def _create_labels(self, patterns: List[BasicDataPattern]) -> List[str]:
        """
        Create labels for the patterns in a commonness set, describing each
        pattern's subspace filters (e.g. "key = value, " or a binned range).
        :param patterns: A list of BasicDataPattern objects.
        :return: A list of strings representing the labels for the patterns.
        """
        labels = []
        for pattern in patterns:
            subspace_str = ""
            for key, val in pattern.data_scope.subspace.items():
                if isinstance(val, str):
                    # Binned numeric filters look like "a <= col <= b" and are
                    # already self-describing; other strings get "key = val".
                    split = val.split("<=")
                    if len(split) > 1:
                        subspace_str += f"{val}"
                    else:
                        subspace_str += f"{key} = {val}, "
                else:
                    subspace_str += f"{key} = {val}, "

            labels.append(f"{subspace_str}")
        return labels

    def visualize(self, fig=None, subplot_spec=None, figsize=(15, 10), additional_text: str = None) -> None:
        """
        Visualize the metainsight, showing commonness sets on the left and exceptions on the right.

        :param fig: Matplotlib figure to plot on. If None, a new figure is created.
        :param subplot_spec: GridSpec to plot on. If None, a new GridSpec is created.
        :param figsize: Size of the figure if a new one is created.
        :param additional_text: Optional additional text to display in the bottom-middle of the figure.
        :return: The figure containing the visualization.
        """
        # Create a new figure if not provided
        # n_cols = 2 if self.exceptions and len(self.exceptions) > 0 else 1
        # Above line makes it so the plot of the commonness sets takes up the entire figure if there are no exceptions.
        # However, this can potentially make for some confusion, so I elected to always use 2 columns.
        n_cols = 2
        if fig is None:
            fig = plt.figure(figsize=figsize)
        if subplot_spec is None:
            outer_grid = gridspec.GridSpec(1, n_cols, width_ratios=[1] * n_cols, figure=fig, wspace=0.2)
        else:
            outer_grid = gridspec.GridSpecFromSubplotSpec(1, n_cols, width_ratios=[1] * n_cols,
                                                          subplot_spec=subplot_spec, wspace=0.2)

        # Wrap the existing 1x2 layout in a 2-row local GridSpec
        # NOTE(review): wrapper_gs is built from `subplot_spec`, which may be
        # None when this method is called without one — confirm GridSpecFromSubplotSpec
        # accepts that, or whether outer_grid[0] was intended here.
        if additional_text:
            wrapper_gs = gridspec.GridSpecFromSubplotSpec(
                2, 1, subplot_spec=subplot_spec, height_ratios=[10, 1], hspace=0.8
            )
        else:
            wrapper_gs = gridspec.GridSpecFromSubplotSpec(
                1, 1, subplot_spec=subplot_spec
            )
        top_gs = gridspec.GridSpecFromSubplotSpec(
            1, 2, subplot_spec=wrapper_gs[0], wspace=0.2
        )

        # Set up the left side for commonness sets
        left_grid = gridspec.GridSpecFromSubplotSpec(1, len(self.commonness_set),
                                                     subplot_spec=top_gs[0, 0], wspace=0.3)

        # Plot each commonness set in its own column
        for i, commonness_set in enumerate(self.commonness_set):
            if not commonness_set:  # Skip empty sets
                continue

            # Create a subplot for this commonness set
            ax = fig.add_subplot(left_grid[0, i])

            # Add light orange background to commonness sets
            # ax.set_facecolor((1.0, 0.9, 0.8, 0.2))  # Light orange with alpha

            # Get the highlights for visualization
            highlights = [pattern.highlight for pattern in commonness_set]

            # Create labels based on subspace
            labels = self._create_labels(commonness_set)

            # Create title for this commonness set
            title = self._create_commonness_set_title(commonness_set)
            # Wrap title to prevent overflowing
            title = textwrap.fill(title, width=40)

            # Call the appropriate visualize_many function based on pattern type
            if highlights:
                if hasattr(highlights[0], "visualize_many"):
                    highlights[0].visualize_many(plt_ax=ax, patterns=highlights, labels=labels, title=title)
                else:
                    ax.set_title(title)

        # Handle exceptions area if there are any
        if self.exceptions and n_cols > 1:
            none_patterns_exist = self.exceptions.get("No-Pattern", None) is not None
            # Set up the right side for exceptions with one row per exception type
            # If there are no exceptions, we create a grid with equal height ratios for each exception type.
            # Else, we create a grid where the last row is smaller if there are None exceptions.
            # NOTE(review): the first branch anchors to top_gs[0, 1] while the
            # second anchors to outer_grid[0, 1] — confirm this asymmetry is intended.
            if not none_patterns_exist:
                right_grid = gridspec.GridSpecFromSubplotSpec(len(self.exceptions), 1,
                                                              subplot_spec=top_gs[0, 1],
                                                              hspace=1.2)  # Add more vertical space
            else:
                # If there are None exceptions, place them at the bottom with very little space, since it just text
                height_ratios = [10] * (len(self.exceptions) - 1) + [1] if len(self.exceptions) > 1 else [1]
                right_grid = gridspec.GridSpecFromSubplotSpec(len(self.exceptions), 1,
                                                              subplot_spec=outer_grid[0, 1],
                                                              height_ratios=height_ratios,
                                                              hspace=1.4)  # Add more vertical space
                # Get the None patterns and "summarize" them in a dictionary
                exception_patterns = self.exceptions.get("No-Pattern", [])
                non_exceptions = [pattern for pattern in exception_patterns if pattern.pattern_type == PatternType.NONE]
                non_exceptions_subspaces = [pattern.data_scope.subspace for pattern in non_exceptions]
                non_exceptions_dict = defaultdict(list)
                for subspace in non_exceptions_subspaces:
                    for key, val in subspace.items():
                        non_exceptions_dict[key].append(val)
                # Create a title for the None patterns
                title = f"No patterns detected ({len(non_exceptions)})"
                title = textwrap.fill(title, width=40)
                # Create text saying all the values for which no patterns were detected
                no_patterns_text = ""
                for key, val in non_exceptions_dict.items():
                    no_patterns_text += f"{key} = {val}\n"
                no_patterns_text = textwrap.fill(no_patterns_text, width=60)
                # Create a subplot for the None patterns
                ax = fig.add_subplot(right_grid[len(self.exceptions) - 1, 0])
                # Add title and text
                if len(self.exceptions) == 1:
                    title_y = None
                    text_y = 0.9
                else:
                    title_y = -0.3
                    text_y = -1.1
                text_x = 0.5
                ax.set_title(title, y=title_y, fontsize=18, fontweight='bold')
                ax.text(text_x, text_y, no_patterns_text,
                        ha='center', va='center',
                        fontsize=18)
                ax.axis('off')  # Hide axis for the title

            # Process each exception category
            i = 0
            for category, exception_patterns in self.exceptions.items():
                if not exception_patterns:  # Skip empty categories
                    continue

                # For "None" category, already handled it above
                if category.lower() == "none" or category.lower() == "no-pattern":
                    continue


                # For "highlight change" category, visualize all in one plot
                if category.lower() == "highlight-change" or category.lower() == "highlight change":
                    ax = fig.add_subplot(right_grid[i, 0])
                    # ax.set_facecolor((0.8, 0.9, 1.0, 0.2))  # Light blue with alpha

                    # Get the highlights for visualization
                    highlights = [pattern.highlight for pattern in exception_patterns]

                    # Create labels based on subspace and measure
                    labels = self._create_labels(exception_patterns)

                    title = f"Same pattern, different highlights ({len(exception_patterns)})"

                    if highlights and hasattr(highlights[0], "visualize_many"):
                        highlights[0].visualize_many(plt_ax=ax, patterns=highlights, labels=labels, title=title)

                # For "type change" or other categories, create a nested grid
                elif category.lower() == "type-change" or category.lower() == "type change":
                    # Make sure there are highlights to visualize
                    highlights = [pattern.highlight for pattern in exception_patterns]
                    if all(highlight is None for highlight in highlights):
                        continue

                    # Create a nested grid for this row with more space
                    type_grid = gridspec.GridSpecFromSubplotSpec(2, 1,
                                                                 subplot_spec=right_grid[i, 0],
                                                                 height_ratios=[1, 15], hspace=0.6, wspace=0.3)

                    # Add title for the category in the first row
                    title_ax = fig.add_subplot(type_grid[0, 0])
                    title_ax.axis('off')
                    title_ax.set_facecolor((0.8, 0.9, 1.0, 0.2))
                    title_ax.text(0.5, 0,
                                  s=f"Different patterns types detected ({len(exception_patterns)})",
                                  horizontalalignment='center',
                                  verticalalignment='center',
                                  fontsize=16,
                                  fontweight='bold'
                                  )

                    # Create subplots for each pattern in the second row
                    num_patterns = len(exception_patterns)
                    # At most 2 patterns per row
                    n_cols = 2 if num_patterns >= 2 else 1
                    n_rows = math.ceil(num_patterns / n_cols)
                    pattern_grid = gridspec.GridSpecFromSubplotSpec(n_rows, n_cols,
                                                                    subplot_spec=type_grid[1, 0],
                                                                    wspace=0.4, hspace=0.6)  # More horizontal space


                    for j, pattern in enumerate(exception_patterns):
                        col_index = j % n_cols
                        row_index = j // n_cols
                        ax = fig.add_subplot(pattern_grid[row_index, col_index])
                        # ax.set_facecolor((0.8, 0.9, 1.0, 0.2))  # Light blue with alpha

                        # Format labels for title
                        subspace_str = ", ".join([f"{key}={val}" for key, val in pattern.data_scope.subspace.items()])

                        title = f"{pattern.highlight.__name__} when {subspace_str}"
                        title = "\n".join(textwrap.wrap(title, 30))  # Wrap title to prevent overflow

                        # Visualize the individual pattern with internal legend
                        if pattern.highlight:
                            pattern.highlight.visualize(ax, title=title)

                i += 1

        # If there is additional text, add it to the bottom middle of the grid
        if additional_text:
            text_ax = fig.add_subplot(wrapper_gs[1])
            text_ax.axis('off')
            text_ax.text(
                0.5, 0.5, additional_text,
                ha='center', va='center', fontsize=18
            )

        # Allow more space for the figure elements
        plt.subplots_adjust(bottom=0.15, top=0.9)  # Adjust bottom and top margins

        return fig
import itertools
from typing import List, Tuple
import numpy as np
from queue import PriorityQueue

import pandas as pd
from matplotlib import pyplot as plt, gridspec

from external_explainers.metainsight_explainer.data_pattern import BasicDataPattern
from external_explainers.metainsight_explainer.meta_insight import (MetaInsight,
                                                                    ACTIONABILITY_REGULARIZER_PARAM,
                                                                    BALANCE_PARAMETER,
                                                                    COMMONNESS_THRESHOLD)
from external_explainers.metainsight_explainer.data_scope import DataScope
from external_explainers.metainsight_explainer.pattern_evaluations import PatternType
from external_explainers.metainsight_explainer.cache import Cache

MIN_IMPACT = 0.01


class MetaInsightMiner:
    """
    This class is responsible for the actual process of mining MetaInsights.
    The full process is described in the paper "MetaInsight: Automatic Discovery of Structured Knowledge for
    Exploratory Data Analysis" by Ma et al. (2021).
    """

    def __init__(self, k=5, min_score=MIN_IMPACT, min_commonness=COMMONNESS_THRESHOLD, balance_factor=BALANCE_PARAMETER,
                 actionability_regularizer=ACTIONABILITY_REGULARIZER_PARAM
                 ):
        """
        Initialize the MetaInsightMiner with the provided parameters.

        :param k: The number of top MetaInsights to return.
        :param min_score: The minimum score for a MetaInsight to be considered.
        :param min_commonness: The minimum commonness for a MetaInsight to be considered.
        :param balance_factor: The balance factor for the MetaInsight.
        :param actionability_regularizer: The actionability regularizer for the MetaInsight.
        """
        self.k = k
        self.min_score = min_score
        self.min_commonness = min_commonness
        self.balance_factor = balance_factor
        self.actionability_regularizer = actionability_regularizer

    def _compute_variety_factor(self, metainsight: MetaInsight, included_pattern_types_count: dict) -> float:
        """
        Compute the variety factor for a given MetaInsight based on the pattern types
        already present in the selected set.

        :param metainsight: The MetaInsight object to compute the variety factor for.
        :param included_pattern_types_count: Dictionary tracking count of selected pattern types.
        :return: The variety factor between 0 and 1.
        """
        # Get pattern types in this metainsight
        candidate_pattern_types = [commonness[0].pattern_type for commonness in metainsight.commonness_set]

        if not candidate_pattern_types:
            return 0.0

        # Calculate how many of this metainsight's pattern types are already included
        pattern_repetition = [included_pattern_types_count.get(pt, 0) for pt in candidate_pattern_types]
        # Any completely-new pattern type gets no penalty at all.
        if any(pt == 0 for pt in pattern_repetition):
            return 1
        pattern_repetition = sum(pattern_repetition)

        # Normalize by the number of pattern types in this metainsight
        avg_repetition = pattern_repetition / len(candidate_pattern_types)

        # Exponential decay: variety_factor decreases as pattern repetition increases
        # The 0.5 constant controls how quickly the penalty grows
        variety_factor = np.exp(-0.5 * avg_repetition)

        return variety_factor


    def rank_metainsights(self, metainsight_candidates: List[MetaInsight]):
        """
        Rank the MetaInsights based on their scores.

        :param metainsight_candidates: A list of MetaInsights to rank.
        :return: A list of the top k MetaInsights.
        """

        selected_metainsights = []
        # Sort candidates by score initially (descending)
        candidate_set = sorted(list(set(metainsight_candidates)), key=lambda mi: mi.score, reverse=True)

        included_pattern_types_count = {
            pattern_type: 0
            for pattern_type in PatternType if pattern_type != PatternType.NONE and pattern_type != PatternType.OTHER
        }

        # Greedy selection of MetaInsights.
        # We compute the total use of the currently selected MetaInsights, then how much a candidate would add to that.
        # We take the candidate that adds the most to the total use, repeating until we have k MetaInsights or no candidates left.
        while len(selected_metainsights) < self.k and candidate_set:
            best_candidate = None
            max_gain = -np.inf

            total_use_approx = sum(mi.score for mi in selected_metainsights) - \
                               sum(mi1.compute_pairwise_overlap_score(mi2) for mi1, mi2 in
                                   itertools.combinations(selected_metainsights, 2))

            for candidate in candidate_set:
                total_use_with_candidate = total_use_approx + (candidate.score - sum(
                    mi.compute_pairwise_overlap_score(candidate) for mi in selected_metainsights))

                gain = total_use_with_candidate - total_use_approx
                # Added penalty for repeating the same pattern types
                variety_factor = self._compute_variety_factor(candidate, included_pattern_types_count)
                gain *= variety_factor

                if gain > max_gain:
                    max_gain = gain
                    best_candidate = candidate

            if best_candidate:
                selected_metainsights.append(best_candidate)
                candidate_set.remove(best_candidate)
                # Store a counter for the pattern types of the selected candidates
                candidate_pattern_types = [commonness[0].pattern_type for commonness in best_candidate.commonness_set]
                for pattern_type in candidate_pattern_types:
                    if pattern_type in included_pattern_types_count:
                        included_pattern_types_count[pattern_type] += 1
            else:
                # No candidate provides a positive gain, or candidate_set is empty
                break

        return selected_metainsights

    def mine_metainsights(self, source_df: pd.DataFrame,
                          filter_dimensions: List[str],
                          measures: List[Tuple[str, str]], n_bins: int = 10,
                          extend_by_measure: bool = False,
                          extend_by_breakdown: bool = False,
                          breakdown_dimensions: List[List[str]] = None,
                          ) -> List[MetaInsight]:
        """
        The main function to mine MetaInsights.
        Mines metainsights from the given data frame based on the provided dimensions, measures, and impact measure.
        :param source_df: The source DataFrame to mine MetaInsights from.
        :param breakdown_dimensions: The dimensions to consider for breakdown (groupby).
        :param filter_dimensions: The dimensions to consider for applying filters on.
        :param measures: The measures (aggregations) to consider for mining.
        :param n_bins: The number of bins to use for numeric columns.
        :param extend_by_measure: Whether to extend the data scope by measure. Setting this to true can cause strange results,
        because we will consider multiple aggregation functions on the same filter dimension.
        :param extend_by_breakdown: Whether to extend the data scope by breakdown. Setting this to true can cause strange results,
        because we will consider multiple different groupby dimensions on the same filter dimension, which can lead to
        having a metainsight on 2 disjoint sets of indexes.
        :return: The top-k ranked MetaInsights.
        """
        cache = Cache()
        hdp_queue = PriorityQueue()

        if breakdown_dimensions is None:
            breakdown_dimensions = filter_dimensions

        # Generate data scopes with one dimension as breakdown, all '*' subspace
        base_data_scopes = []
        for breakdown_dim in breakdown_dimensions:
            for measure_col, agg_func in measures:
                base_data_scopes.append(
                    DataScope(source_df, {}, breakdown_dim, (measure_col, agg_func)))

        # Generate data scopes with one filter in subspace and one breakdown
        for filter_dim in filter_dimensions:
            unique_values = source_df[filter_dim].dropna().unique()
            # If there are too many unique values, we bin them if it's a numeric column, or only choose the
            # top 10 most frequent values if it's a categorical column
            if len(unique_values) > n_bins:
                if source_df[filter_dim].dtype in ['int64', 'float64']:
                    # Bin the numeric column
                    bins = pd.cut(source_df[filter_dim], bins=n_bins, retbins=True)[1]
                    unique_values = [f"{bins[i]} <= {filter_dim} <= {bins[i + 1]}" for i in range(len(bins) - 1)]
                else:
                    # Choose the top 10 most frequent values
                    top_values = source_df[filter_dim].value_counts().nlargest(10).index.tolist()
                    unique_values = [v for v in unique_values if v in top_values]
            for value in unique_values:
                for breakdown_dim in breakdown_dimensions:
                    # Prevents the same breakdown dimension from being used as filter. This is because it
                    # is generally not very useful to groupby the same dimension as the filter dimension.
                    if breakdown_dim != filter_dim:
                        for measure_col, agg_func in measures:
                            base_data_scopes.append(
                                DataScope(source_df, {filter_dim: value}, breakdown_dim, (measure_col, agg_func)))

        # The source dataframe with a groupby on various dimensions and measures can be precomputed,
        # instead of computed each time we need it.
        numeric_columns = source_df.select_dtypes(include=[np.number]).columns.tolist()
        for col, agg_func in measures:
            groupby_key = (col, agg_func)
            cache_result = cache.get_from_groupby_cache(groupby_key)
            # BUGFIX: only compute and store the groupby when it is NOT already
            # cached. The original condition (`is not None`) meant the cache was
            # only ever written when it already contained the entry — i.e. never.
            if cache_result is None:
                # Handle 'std' aggregation specially
                if agg_func == 'std':
                    cache.add_to_groupby_cache(groupby_key, source_df.groupby(col)[numeric_columns].std(ddof=1))
                else:
                    cache.add_to_groupby_cache(groupby_key, source_df.groupby(col)[numeric_columns].agg(agg_func))


        for base_ds in base_data_scopes:
            # Evaluate basic patterns for the base data scope for selected types
            for pattern_type in PatternType:
                if pattern_type == PatternType.OTHER or pattern_type == PatternType.NONE:
                    continue
                base_dps = BasicDataPattern.evaluate_pattern(base_ds, source_df, pattern_type)

                for base_dp in base_dps:
                    if base_dp.pattern_type not in [PatternType.NONE, PatternType.OTHER]:
                        # If a valid basic pattern is found, extend the data scope to generate HDS
                        hdp = base_dp.create_hdp(group_by_dims=breakdown_dimensions, measures=measures,
                                                 pattern_type=pattern_type,
                                                 extend_by_measure=extend_by_measure, extend_by_breakdown=extend_by_breakdown)

                        # Pruning 1 - if the HDP is unlikely to form a commonness, discard it
                        if len(hdp) < len(hdp.data_scopes) * self.min_commonness:
                            continue

                        # Pruning 2: Discard HDS with extremely low impact
                        hds_impact = hdp.compute_impact()
                        if hds_impact < MIN_IMPACT:
                            continue

                        # Add HDS to a queue for evaluation.
                        # NOTE(review): PriorityQueue orders tuples, so this relies on
                        # HDP objects being comparable when priorities tie — confirm
                        # HomogenousDataPattern defines ordering.
                        hdp_queue.put((hdp, pattern_type))

        metainsight_candidates = {}
        while not hdp_queue.empty():
            hdp, pattern_type = hdp_queue.get()

            # Evaluate HDP to find MetaInsight
            metainsight = MetaInsight.create_meta_insight(hdp, commonness_threshold=self.min_commonness)

            if metainsight:
                # Calculate and assign the score
                metainsight.compute_score()
                if metainsight in metainsight_candidates:
                    other_metainsight = metainsight_candidates[metainsight]
                    if metainsight.score > other_metainsight.score:
                        # If the new metainsight is better, replace the old one
                        metainsight_candidates[metainsight] = metainsight
                else:
                    metainsight_candidates[metainsight] = metainsight

        # BUGFIX: iterate the VALUES, not the keys. Python dicts keep the original
        # key object when an equal key is re-assigned, so the higher-scoring
        # replacement above is only reachable through the dict's values.
        return self.rank_metainsights(list(metainsight_candidates.values()))


if __name__ == "__main__":
    # Create a sample Pandas DataFrame (similar to the paper's example)
    df = pd.read_csv("C:\\Users\\Yuval\\PycharmProjects\\pd-explain\\Examples\\Datasets\\adult.csv")
    df = df.sample(5000, random_state=42)  # Sample 5000 rows for testing
    print(df.columns)

    # Define dimensions, measures
    dimensions = ['education', 'occupation', 'marital-status']
    breakdown_dimensions = [['age'],
                            ['education-num'],
                            ['occupation'],
                            ['marital-status'],
                            ]
    measures = [
        ('capital-gain', 'mean'),
        ('capital-loss', 'mean'),
        ('hours-per-week', 'mean'),
        ('income', 'count'),
        ('education-num', 'mean'),
    ]

    # Run the mining process
    import time
    start_time = time.time()
    miner = MetaInsightMiner(k=4, min_score=0.01, min_commonness=0.5)
    top_metainsights = miner.mine_metainsights(
        source_df=df,
        filter_dimensions=dimensions,
        measures=measures,
        breakdown_dimensions=breakdown_dimensions,
    )
    end_time = time.time()
    print(f"Time taken: {end_time - start_time:.2f} seconds")

    nrows = 4
    ncols = 1

    fig_len = 20 * ncols
    fig_height = 15 * nrows

    fig = plt.figure(figsize=(fig_len, fig_height))
    main_grid = gridspec.GridSpec(nrows, ncols, figure=fig, wspace=0.2, hspace=0.3)

    for i, mi in enumerate(top_metainsights[:4]):
        row, col = i, 0
        mi.visualize(fig=fig, subplot_spec=main_grid[row, col])

    # plt.tight_layout()
    plt.show()
class PatternType(Enum):
    """
    An enumeration of the types of patterns.
    """
    NONE = 0
    OTHER = 1
    UNIMODALITY = 2
    TREND = 3
    OUTLIER = 4
    CYCLE = 5


@singleton
class PatternEvaluator:
    """
    Evaluates the supported pattern types (unimodality, trend, outlier, cycle)
    on a pandas Series.

    Results are memoized in the global Cache singleton, keyed by a hash of the
    series values together with the requested pattern type (see __call__).
    """

    def __init__(self):
        self.cache = Cache()
        self.OUTLIER_ZSCORE_THRESHOLD = 2.0  # Z-score threshold for outlier detection
        self.TREND_SLOPE_THRESHOLD = 0.01  # Minimum absolute slope for trend detection

    def _is_time_series(self, series: pd.Series) -> bool:
        """
        Checks if the series can be treated as a time series.
        A series counts as a time series if its index is a datetime index, or a
        strictly increasing numeric index (an ordered numeric index may not
        really encode time, but we cannot rule it out).
        :param series: The series to check.
        :return: True if the series looks like a time series, False otherwise.
        """
        if isinstance(series.index, pd.DatetimeIndex):
            return True
        if np.issubdtype(series.index.dtype, np.number):
            # Check if the index is strictly increasing
            return bool(np.all(np.diff(series.index) > 0))
        return False

    def unimodality(self, series: pd.Series) -> (bool, frozenset | None):
        """
        Evaluates if the series is unimodal using Hartigan's Dip test.
        If it is, locates the peak and/or valley.
        :param series: The series to evaluate.
        :return: Tuple (is_unimodal, frozenset of UnimodalityPattern or None).
        """
        if not isinstance(series, pd.Series):
            return False, None
        series = series.sort_index()
        vals = series.values
        # The dip test needs at least 4 observations.
        if len(vals) < 4:
            return False, None
        dip_statistic, p_value = diptest(vals)
        # A high p-value means we cannot reject unimodality.
        if p_value <= 0.05:
            return False, None
        max_value = series.max()
        min_value = series.min()
        peaks = series[series == max_value]
        valleys = series[series == min_value]
        # If both extrema repeat, there is no single well-defined peak/valley.
        if len(peaks) > 1 and len(valleys) > 1:
            return False, None
        max_value_index = peaks.index[0] if len(peaks) == 1 else None
        min_value_index = valleys.index[0] if len(valleys) == 1 else None
        # If both extrema sit at the series edges, this is more likely a trend
        # than a unimodal pattern.
        edges = (series.index[0], series.index[-1])
        if (max_value_index is not None and max_value_index in edges) and \
                (min_value_index is not None and min_value_index in edges):
            return False, None
        to_return = []
        # Bug fix: compare against None explicitly — a falsy index label (e.g. 0)
        # is a valid highlight — and report BOTH a peak and a valley when both
        # exist, as the docstring promised (the original `elif` dropped the
        # valley whenever a peak was present).
        if max_value_index is not None:
            to_return.append(UnimodalityPattern(series, 'Peak', max_value_index, value_name=series.name))
        if min_value_index is not None:
            to_return.append(UnimodalityPattern(series, 'Valley', min_value_index, value_name=series.name))
        if not to_return:
            return False, None
        return True, frozenset(to_return)

    def trend(self, series: pd.Series) -> (bool, TrendPattern | None):
        """
        Evaluates if a time series exhibits a significant monotonic trend
        (upward or downward), using the Mann-Kendall test.

        :param series: The series to evaluate.
        :return: Tuple (trend_detected, TrendPattern or None).
        """
        if len(series) < 2:
            return False, None
        # Trends only make sense on (apparent) time series.
        if not self._is_time_series(series):
            return False, None
        mk_result = mk.original_test(series)
        # Keep the null hypothesis (no trend) when p > 0.05 or mk says so.
        if mk_result.p > 0.05 or mk_result.trend == 'no trend':
            return False, None
        return True, TrendPattern(series, type=mk_result.trend,
                                  slope=mk_result.slope, intercept=mk_result.intercept,
                                  value_name=series.name)

    def outlier(self, series: pd.Series) -> (bool, OutlierPattern | None):
        """
        Evaluates if a series contains significant outliers using the Z-score
        method.
        :param series: The series to evaluate.
        :return: Tuple (has_outliers, OutlierPattern or None).
        """
        if len(series) < 2:
            # Bug fix: return (False, None) like every other branch, instead of
            # the inconsistent (False, (None, None)) the original returned here.
            return False, None
        # Bug fix: compute z-scores on the NaN-free series AND index into that
        # same series. The original computed positions on series.dropna() but
        # applied them with .iloc to the full series, misaligning the reported
        # outliers whenever NaNs were present.
        clean = series.dropna()
        if len(clean) < 2:
            return False, None
        z_scores = np.abs(zscore(clean))
        outlier_positions = np.where(z_scores > self.OUTLIER_ZSCORE_THRESHOLD)[0]
        if len(outlier_positions) == 0:
            return False, None
        outlier_values = clean.iloc[outlier_positions]
        outlier_indexes = clean.index[outlier_positions]
        return True, OutlierPattern(series, outlier_indexes=outlier_indexes,
                                    outlier_values=outlier_values, value_name=series.name)

    def cycle(self, series: pd.Series) -> (bool, CyclePattern | None):
        """
        Evaluates if a series exhibits cyclical patterns, using cydets.
        :param series: The series to evaluate.
        :return: Tuple (is_cyclical, CyclePattern or None).
        """
        if len(series) < 2:
            return False, None
        # A (near-)constant series cannot contain cycles.
        if series.std() < 1e-10 or (series.max() - series.min()) < 1e-8:
            return False, None
        # Quick pre-filtering using autocorrelation (much faster than full
        # detection); suppress divide-by-zero noise from the correlation.
        with np.errstate(divide='ignore', invalid='ignore'):
            if len(series) >= 20:
                try:
                    autocorr = pd.Series(series.values).autocorr(lag=len(series) // 4)
                    # NaN or weak autocorrelation -> assume no cycle.
                    if pd.isna(autocorr) or abs(autocorr) < 0.3:
                        return False, None
                except (ValueError, ZeroDivisionError):
                    return False, None
        # Cycles only make sense on (apparent) time series.
        if not self._is_time_series(series):
            return False, None
        try:
            cycle_info = detect_cycles(series)
            if cycle_info is not None and len(cycle_info) > 0:
                return True, CyclePattern(series, cycle_info, value_name=series.name)
            return False, None
        except ValueError:
            # cydets raises ValueError when it fails to detect cycles instead
            # of returning an empty result, so treat that as "no cycle".
            return False, None

    def __call__(self, series: pd.Series, pattern_type: PatternType) -> (bool, frozenset | None):
        """
        Dispatches to the evaluation method matching ``pattern_type``, with
        memoization keyed on the series values and the pattern type.
        :param series: The series to evaluate.
        :param pattern_type: The type of the pattern to evaluate.
        :return: (is_valid, frozenset of patterns found — empty if none).
        :raises ValueError: If the pattern type has no evaluator (NONE / OTHER).
        """
        cache_key = (hash(tuple(series.values)), pattern_type)
        cache_result = self.cache.get_from_pattern_eval_cache(cache_key)
        if cache_result is not None:
            # Cached entries are always (bool, frozenset) tuples, never None.
            return cache_result

        # Evaluate on a NaN-free, index-sorted copy of the series.
        series = series[~series.isna()].sort_index()

        dispatch = {
            PatternType.UNIMODALITY: self.unimodality,
            PatternType.TREND: self.trend,
            PatternType.OUTLIER: self.outlier,
            PatternType.CYCLE: self.cycle,
        }
        evaluator = dispatch.get(pattern_type)
        if evaluator is None:
            raise ValueError(f"Unsupported pattern type: {pattern_type}")
        result = evaluator(series)
        is_valid = result[0] if isinstance(result, tuple) else False
        patterns = result[1] if isinstance(result, tuple) else None
        # Normalize to a frozenset. Bug fix: a None result becomes the EMPTY
        # frozenset — the original produced frozenset({None}), leaking a None
        # element to every consumer that iterated the patterns.
        if not isinstance(patterns, frozenset):
            if patterns is None:
                patterns = frozenset()
            elif isinstance(patterns, typing.Iterable):
                patterns = frozenset(patterns)
            else:
                patterns = frozenset([patterns])
        # Cache the normalized result for future calls.
        self.cache.add_to_pattern_eval_cache(cache_key, (is_valid, patterns))
        return is_valid, patterns
+ """ + # Note for all the implementations below: all of them just use the visualize_many method internally, + # because that one handles all the complex cases already and can also visualize just one pattern. + raise NotImplementedError("Subclasses must implement this method.") + + @abstractmethod + def __eq__(self, other) -> bool: + """ + Check if two patterns are equal + :param other: Another pattern of the same type + :return: + """ + raise NotImplementedError("Subclasses must implement this method.") + + @abstractmethod + def __repr__(self) -> str: + """ + String representation of the pattern. + """ + raise NotImplementedError("Subclasses must implement this method.") + + + @abstractmethod + def __str__(self) -> str: + """ + String representation of the pattern. + """ + raise NotImplementedError("Subclasses must implement this method.") + + + @abstractmethod + def __hash__(self) -> int: + """ + Hash representation of the pattern. + """ + raise NotImplementedError("Subclasses must implement this method.") + + @staticmethod + def prepare_patterns_for_visualization(patterns): + """ + Prepare patterns for visualization by creating a consistent numeric position mapping. + Returns a mapping of original indices to numeric positions for plotting. + + :param patterns: List of pattern objects with source_series attribute + :return: Dictionary mapping original indices to positions and sorted unique indices + """ + # Collect all unique indices from all patterns + all_indices = set() + for pattern in patterns: + all_indices.update(pattern.source_series.index) + + # Sort indices in their natural order - this works for dates, numbers, etc. + sorted_indices = sorted(list(all_indices)) + + # Create mapping from original index to position (0, 1, 2, ...) 
+ index_to_position = {idx: pos for pos, idx in enumerate(sorted_indices)} + + return index_to_position, sorted_indices + + + @staticmethod + def handle_sorted_indices(plt_ax, sorted_indices): + """ + Handle setting x-ticks and labels for the plot based on sorted indices. + :param plt_ax: The matplotlib axes to set ticks on + :param sorted_indices: The sorted indices to use for x-ticks + """ + # For large datasets, show fewer tick labels + step = max(1, len(sorted_indices) // 10) + positions = list(range(0, len(sorted_indices), step)) + tick_labels = [str(sorted_indices[pos]) for pos in positions] + + plt_ax.set_xticks(positions) + plt_ax.set_xticklabels(tick_labels, rotation=45, ha='right', fontsize=16) + + + @staticmethod + @abstractmethod + def visualize_many(plt_ax, patterns: List['PatternInterface'], labels:List[str], title: str = None) -> None: + """ + Visualize many patterns of the same type on the same plot. + :param plt_ax: The matplotlib axes to plot on + :param patterns: The patterns to plot + :param labels: The labels to display in the legend. + :param title: The title of the plot + """ + raise NotImplementedError("Subclasses must implement this method.") + + __name__ = "PatternInterface" + + +class UnimodalityPattern(PatternInterface): + + __name__ = "Unimodality pattern" + + @staticmethod + def visualize_many(plt_ax, patterns: List['UnimodalityPattern'], labels: List[str], title: str = None) -> None: + """ + Visualize multiple unimodality patterns on a single plot. + + :param plt_ax: Matplotlib axes to plot on + :param patterns: List of UnimodalityPattern objects + :param labels: List of labels for each pattern (e.g. 
class UnimodalityPattern(PatternInterface):
    """A unimodal series with a single highlighted peak or valley."""

    __name__ = "Unimodality pattern"

    @staticmethod
    def visualize_many(plt_ax, patterns: List['UnimodalityPattern'], labels: List[str], title: str = None) -> None:
        """
        Draw several unimodality patterns on one axes, marking each extremum.

        :param plt_ax: Matplotlib axes to plot on
        :param patterns: List of UnimodalityPattern objects
        :param labels: Legend label for each pattern (e.g. data scope descriptions)
        :param title: Optional title; a default is derived from the patterns
        """
        palette = plt.cm.tab10.colors

        # Shared numeric positions so all series line up on one axis.
        pos_of, ordered_indices = PatternInterface.prepare_patterns_for_visualization(patterns)

        for n, (pat, lbl) in enumerate(zip(patterns, labels)):
            line_color = palette[n % len(palette)]

            xs = [pos_of[ix] for ix in pat.source_series.index]
            plt_ax.plot(xs, pat.source_series.values, color=line_color, alpha=0.7, label=lbl)

            # Mark the extremum: 'o' for a peak, 'v' for a valley.
            kind = pat.type.lower()
            if kind in ('peak', 'valley') and pat.highlight_index in pat.source_series.index:
                marker = 'o' if kind == 'peak' else 'v'
                plt_ax.plot(pos_of[pat.highlight_index],
                            pat.source_series.loc[pat.highlight_index],
                            marker, color=line_color, markersize=8, markeredgecolor='black')

        # Show the original index labels on the x axis.
        if ordered_indices:
            PatternInterface.handle_sorted_indices(plt_ax, ordered_indices)

        plt_ax.set_xlabel(patterns[0].index_name if patterns else 'Index')
        plt_ax.set_ylabel(patterns[0].value_name if patterns else 'Value')
        plt_ax.set_title(
            f"Multiple {patterns[0].type if patterns else 'Unimodality'} Patterns" if title is None else title)

        plt_ax.legend()
        plt.setp(plt_ax.get_xticklabels(), rotation=45, ha='right', fontsize=16)

        # Keep room for the rotated x-labels.
        plt_ax.figure.subplots_adjust(bottom=0.15)

    def __init__(self, source_series: pd.Series, type: Literal['Peak', 'Valley'], highlight_index,
                 value_name: str = None):
        """
        :param source_series: The series the pattern was found in.
        :param type: Either 'Peak' or 'Valley'.
        :param highlight_index: Index label of the peak or valley.
        :param value_name: Display name for the measured value.
        """
        self.source_series = source_series
        self.type = type
        self.highlight_index = highlight_index
        self.index_name = source_series.index.name if source_series.index.name else 'Index'
        self.value_name = value_name if value_name else 'Value'
        self.hash = None  # lazily computed by __hash__

    def visualize(self, plt_ax, title: str = None) -> None:
        """Visualize this single pattern; delegates to visualize_many."""
        self.visualize_many(plt_ax, [self], [self.value_name], title=None)
        plt_ax.set_title(title if title is not None
                         else f"{self.type} at {self.highlight_index} in {self.value_name}")

    def __eq__(self, other) -> bool:
        """
        Equal when the other pattern has the same type and the same highlight
        index (and the highlight indices are of the same Python type).
        """
        return (isinstance(other, UnimodalityPattern)
                and type(self.highlight_index) is type(other.highlight_index)
                and self.type == other.type
                and self.highlight_index == other.highlight_index)

    def __repr__(self) -> str:
        """String representation of the UnimodalityPattern."""
        return f"UnimodalityPattern(type={self.type}, highlight_index={self.highlight_index})"

    def __str__(self) -> str:
        """String representation of the UnimodalityPattern."""
        return f"UnimodalityPattern(type={self.type}, highlight_index={self.highlight_index})"

    def __hash__(self) -> int:
        """Hash, cached on first use; consistent with __eq__ (type + highlight index)."""
        if self.hash is None:
            self.hash = hash(f"UnimodalityPattern(type={self.type}, highlight_index={self.highlight_index})")
        return self.hash
class TrendPattern(PatternInterface):
    """A monotonic (increasing/decreasing) trend, stored as a slope/intercept line."""

    __name__ = "Trend pattern"

    @staticmethod
    def visualize_many(plt_ax, patterns: List['TrendPattern'], labels: List[str], title: str = None,
                       show_data: bool = True, alpha_data: float = 0.5) -> None:
        """
        Visualize multiple trend patterns on a single plot.

        :param plt_ax: Matplotlib axes to plot on
        :param patterns: List of TrendPattern objects
        :param labels: List of labels for each pattern
        :param title: Optional custom title for the plot
        :param show_data: Whether to draw the mean of the underlying data as a reference line
        :param alpha_data: Opacity of the reference data line (lower reduces clutter)
        """
        colors = plt.cm.tab10.colors

        # Line styles for additional differentiation (taken from the matplotlib docs).
        line_styles = [
            ('loosely dotted', (0, (1, 10))),
            ('dotted', (0, (1, 5))),
            ('densely dotted', (0, (1, 1))),
            ('long dash with offset', (5, (10, 3))),
            ('loosely dashed', (0, (5, 10))),
            ('dashed', (0, (5, 5))),
            ('densely dashed', (0, (5, 1))),
            ('loosely dashdotted', (0, (3, 10, 1, 10))),
            ('dashdotted', (0, (3, 5, 1, 5))),
            ('densely dashdotted', (0, (3, 1, 1, 1))),
            ('dashdotdotted', (0, (3, 5, 1, 5, 1, 5))),
            ('loosely dashdotdotted', (0, (3, 10, 1, 10, 1, 10))),
            ('densely dashdotdotted', (0, (3, 1, 1, 1, 1, 1)))]

        index_to_position, sorted_indices = PatternInterface.prepare_patterns_for_visualization(patterns)

        for i, (pattern, label) in enumerate(zip(patterns, labels)):
            color = colors[i % len(colors)]
            line_style = line_styles[i % len(line_styles)][1]

            # Draw the fitted trend line over the combined index range.
            # (Removed unused x_positions/values locals the original computed here.)
            # NOTE(review): slope/intercept were fitted on the pattern's own
            # positions; drawing them over the combined range assumes all
            # patterns share the same index — confirm with callers.
            x_range = np.arange(len(sorted_indices))
            plt_ax.plot(x_range, pattern.slope * x_range + pattern.intercept,
                        linestyle=line_style, color=color, linewidth=2,
                        label=f"{label}" + " (trend line)")

        # Set x-ticks to show original index values
        if sorted_indices:
            PatternInterface.handle_sorted_indices(plt_ax, sorted_indices)

        if show_data:
            # Average the raw values of all patterns at each index and draw the
            # result as a single gray reference line.
            values_at = {idx: [] for idx in index_to_position}
            for idx in index_to_position:
                for pattern in patterns:
                    if idx in pattern.source_series.index:
                        values_at[idx].append(pattern.source_series.loc[idx])
            # Bug fix: the original also passed index=index_to_position (a dict)
            # to pd.Series, which only worked by accident; the data dict's keys
            # already define the (sorted) index.
            overall_mean_series = pd.Series(
                {idx: np.mean(vals) for idx, vals in values_at.items()},
                name='Overall Mean Data',
            )
            mean_x_positions = [index_to_position[idx] for idx in overall_mean_series.index]
            mean_values = [overall_mean_series.loc[idx] for idx in overall_mean_series.index]
            plt_ax.plot(mean_x_positions, mean_values, color='gray', alpha=alpha_data,
                        linewidth=5, label='Mean Over All Data')

        # Set labels and title
        if patterns:
            plt_ax.set_xlabel(patterns[0].source_series.index.name if patterns[0].source_series.index.name else 'Index')
            plt_ax.set_ylabel(patterns[0].value_name if patterns[0].value_name else 'Value')

        plt_ax.set_title(title if title is not None else "Multiple Trend Patterns")

        plt.setp(plt_ax.get_xticklabels(), rotation=45, ha='right', fontsize=16)
        plt_ax.legend()

        # Ensure bottom margin for x-labels
        plt_ax.figure.subplots_adjust(bottom=0.15)

    def __init__(self, source_series: pd.Series, type: Literal['increasing', 'decreasing'],
                 slope: float, intercept: float = 0, value_name: str = None):
        """
        :param source_series: The series the trend was found in.
        :param type: Trend direction. Annotation fixed to the lowercase values
            pymannkendall's original_test actually reports ('increasing' /
            'decreasing'); the original 'Increasing'/'Decreasing' hint never
            matched what callers pass.
        :param slope: Slope of the fitted trend line.
        :param intercept: Intercept of the fitted trend line.
        :param value_name: Display name for the measured value.
        """
        self.source_series = source_series
        self.type = type
        self.slope = slope
        self.intercept = intercept
        self.value_name = value_name if value_name else 'Value'
        self.hash = None  # lazily computed by __hash__

    def visualize(self, plt_ax, title: str = None) -> None:
        """Visualize this single trend; delegates to visualize_many."""
        self.visualize_many(plt_ax, [self], [self.value_name], title=None)
        if title is not None:
            plt_ax.set_title(title)
        else:
            plt_ax.set_title(f"{self.type} trend in {self.value_name} with slope {self.slope:.2f} and intercept {self.intercept:.2f}")

    def __eq__(self, other) -> bool:
        """
        Equal when the other trend has the same direction. Slope and intercept
        are deliberately ignored — only the kind of trend matters (comparisons
        are assumed to be on the same series).
        """
        if not isinstance(other, TrendPattern):
            return False
        return self.type == other.type

    def __repr__(self) -> str:
        """String representation of the TrendPattern."""
        return f"TrendPattern(type={self.type})"

    def __str__(self) -> str:
        """String representation of the TrendPattern."""
        return f"TrendPattern(type={self.type})"

    def __hash__(self) -> int:
        """Hash, cached on first use; consistent with __eq__ (direction only)."""
        if self.hash is None:
            self.hash = hash(f"TrendPattern(type={self.type})")
        return self.hash
+ """ + colors = plt.cm.tab10.colors + regular_marker = 'o' + outlier_marker = 'X' + + # Prepare patterns with consistent numeric positions + index_to_position, sorted_indices = PatternInterface.prepare_patterns_for_visualization(patterns) + + # Plot each pattern + for i, (pattern, label) in enumerate(zip(patterns, labels)): + color = colors[i % len(colors)] + + # Plot regular data points + if show_regular: + # Get positions and values for plotting + positions = [index_to_position[idx] for idx in pattern.source_series.index] + values = pattern.source_series.values + + plt_ax.scatter( + positions, + values, + color=color, + alpha=alpha_regular, + marker=regular_marker, + s=30, + label=label + ) + else: + plt_ax.scatter([], [], color=color, marker=regular_marker, s=30, label=label) + + # Plot outliers + if pattern.outlier_indexes is not None and len(pattern.outlier_indexes) > 0: + # Map outliers to positions + outlier_positions = [] + outlier_values = [] + + for idx in pattern.outlier_indexes: + if idx in pattern.source_series.index: + outlier_positions.append(index_to_position[idx]) + outlier_values.append(pattern.source_series.loc[idx]) + + plt_ax.scatter( + outlier_positions, + outlier_values, + color=color, + alpha=alpha_outliers, + marker=outlier_marker, + s=100, + edgecolors='black', + linewidth=1.5 + ) + + # Set x-ticks to show original index values + if sorted_indices: + PatternInterface.handle_sorted_indices(plt_ax, sorted_indices) + + # Setup the rest of the plot + from matplotlib.lines import Line2D + custom_lines = [Line2D([0], [0], marker=outlier_marker, color='black', + markerfacecolor='black', markersize=10, linestyle='')] + custom_labels = ['Outliers (marked with X)'] + + # Set labels and title + if patterns: + plt_ax.set_xlabel(patterns[0].source_series.index.name if patterns[0].source_series.index.name else 'Index') + plt_ax.set_ylabel(patterns[0].value_name if patterns[0].value_name else 'Value') + + plt_ax.set_title(title if title is not None else 
"Multiple Outlier Patterns") + + # Setup legend + handles, labels_current = plt_ax.get_legend_handles_labels() + all_handles = handles + custom_lines + all_labels = labels_current + custom_labels + plt_ax.legend(all_handles, all_labels) + + plt.setp(plt_ax.get_xticklabels(), rotation=45, ha='right', fontsize=16) + + # Ensure bottom margin for x-labels + plt_ax.figure.subplots_adjust(bottom=0.15) + + def __init__(self, source_series: pd.Series, outlier_indexes: pd.Index, outlier_values: pd.Series, + value_name: str = None): + """ + Initialize the Outlier pattern with the provided parameters. + + :param source_series: The source series to evaluate. + :param outlier_indexes: The indexes of the outliers. + :param outlier_values: The values of the outliers. + """ + self.source_series = source_series + self.outlier_indexes = outlier_indexes + self.outlier_values = outlier_values + self.value_name = value_name if value_name else 'Value' + self.hash = None + + def visualize(self, plt_ax, title: str = None) -> None: + """ + Visualize the outlier pattern. + :param plt_ax: + :return: + """ + self.visualize_many(plt_ax, [self], [self.value_name], title=None) + if title is not None: + plt_ax.set_title(title) + else: + plt_ax.set_title(f"Outliers in {self.value_name} at {self.outlier_indexes.tolist()}") + + + def __eq__(self, other): + """ + Check if two OutlierPattern objects are equal. + :param other: Another OutlierPattern object. + :return: True if they are equal, False otherwise. They are considered equal if the index set of one is a subset + of the other or vice versa. 
+ """ + if not isinstance(other, OutlierPattern): + return False + # If one index is a multi-index and the other is not, for example, they cannot be equal + if not type(self.outlier_indexes) == type(other.outlier_indexes): + return False + return self.outlier_indexes.isin(other.outlier_indexes).all() or \ + other.outlier_indexes.isin(self.outlier_indexes).all() + + def __repr__(self) -> str: + """ + String representation of the OutlierPattern. + :return: A string representation of the OutlierPattern. + """ + return f"OutlierPattern(outlier_indexes={self.outlier_indexes})" + + def __str__(self) -> str: + """ + String representation of the OutlierPattern. + :return: A string representation of the OutlierPattern. + """ + return f"OutlierPattern(outlier_indexes={self.outlier_indexes})" + + def __hash__(self) -> int: + """ + Hash representation of the OutlierPattern. + :return: A hash representation of the OutlierPattern. + """ + if self.hash is not None: + return self.hash + self.hash = hash(f"OutlierPattern(outlier_indexes={self.outlier_indexes})") + return self.hash + + +class CyclePattern(PatternInterface): + + __name__ = "Cycle pattern" + + @staticmethod + def visualize_many(plt_ax, patterns: List['CyclePattern'], labels: List[str], title: str = None, + alpha_cycles: float = 0.3, line_alpha: float = 0.8) -> None: + """ + Visualize multiple cycle patterns on a single plot with common cycles highlighted. 
+ + :param plt_ax: Matplotlib axes to plot on + :param patterns: List of CyclePattern objects + :param labels: List of labels for each pattern + :param title: Optional custom title for the plot + :param alpha_cycles: Opacity for the highlighted cycle regions + :param line_alpha: Opacity for the time series lines + """ + import numpy as np + + # Define a color cycle for lines + colors = plt.cm.tab10.colors + + # Prepare patterns with consistent numeric positions + index_to_position, sorted_indices = PatternInterface.prepare_patterns_for_visualization(patterns) + + # Color for common cycles + common_cycle_color = 'darkviolet' + + # Plot each dataset and collect legend handles + legend_handles = [] + legend_labels = [] + + # First, identify time ranges covered by cycles for each pattern + all_cycle_data = [] + + for pattern in patterns: + if hasattr(pattern, 'cycles') and not pattern.cycles.empty: + for _, cycle in pattern.cycles.iterrows(): + # Map to numeric positions + t_start_pos = index_to_position.get(cycle['t_start'], None) + t_end_pos = index_to_position.get(cycle['t_end'], None) + if t_start_pos is not None and t_end_pos is not None: + all_cycle_data.append((t_start_pos, t_end_pos)) + + # Find common cycle periods (using numeric positions) + common_periods = [] + if len(patterns) > 1 and all_cycle_data: + # Get all unique numeric positions from starts and ends + all_positions = sorted(list(set([pos for start, end in all_cycle_data for pos in [start, end]]))) + + # Create additional points between positions if needed + if len(all_positions) > 1: + position_points = np.linspace(min(all_positions), max(all_positions), 100) + else: + position_points = all_positions + + # For each position point, check if it falls within a cycle for each pattern + overlap_counts = np.zeros(len(position_points)) + + for pattern in patterns: + if hasattr(pattern, 'cycles') and not pattern.cycles.empty: + pattern_mask = np.zeros(len(position_points), dtype=bool) + for _, cycle in 
pattern.cycles.iterrows(): + t_start_pos = index_to_position.get(cycle['t_start'], None) + t_end_pos = index_to_position.get(cycle['t_end'], None) + if t_start_pos is not None and t_end_pos is not None: + pattern_mask = pattern_mask | ( + (position_points >= t_start_pos) & (position_points <= t_end_pos)) + overlap_counts += pattern_mask + + # Find regions where all patterns have a cycle + common_mask = overlap_counts == len(patterns) + + # Find contiguous regions of common cycles + if np.any(common_mask): + changes = np.diff(np.concatenate(([0], common_mask.astype(int), [0]))) + start_indices = np.where(changes == 1)[0] + end_indices = np.where(changes == -1)[0] - 1 + + for start_idx, end_idx in zip(start_indices, end_indices): + common_periods.append((position_points[start_idx], position_points[end_idx])) + + # Plot each pattern + for i, (pattern, label) in enumerate(zip(patterns, labels)): + color = colors[i % len(colors)] + + # Map series to numeric positions for plotting + x_positions = [index_to_position[idx] for idx in pattern.source_series.index] + values = pattern.source_series.values + + # Plot the time series + line, = plt_ax.plot(x_positions, values, color=color, alpha=line_alpha, linewidth=2, label=label) + legend_handles.append(line) + legend_labels.append(label) + + # Highlight each cycle with a semi-transparent fill + if hasattr(pattern, 'cycles') and not pattern.cycles.empty: + # Add individual cycle legend element + cycle_patch = plt.Rectangle((0, 0), 1, 1, color=color, alpha=alpha_cycles) + + for _, cycle in pattern.cycles.iterrows(): + t_start_pos = index_to_position.get(cycle['t_start'], None) + t_end_pos = index_to_position.get(cycle['t_end'], None) + + if t_start_pos is None or t_end_pos is None: + continue + + # Check if this cycle overlaps with common cycles + is_common = any( + start <= t_start_pos <= end and start <= t_end_pos <= end + for start, end in common_periods + ) + + # Highlight the cycle only if it is not in the common cycles + 
if not is_common: + # Highlight the cycle region + plt_ax.axvspan(t_start_pos, t_end_pos, color=color, alpha=alpha_cycles) + + # Highlight common cycles + if common_periods: + for start, end in common_periods: + plt_ax.axvspan(start, end, color=common_cycle_color, alpha=alpha_cycles * 1.5, zorder=-1) + + # Add legend item for common cycles + common_patch = plt.Rectangle((0, 0), 1, 1, color=common_cycle_color, alpha=alpha_cycles * 1.5) + legend_handles.append(common_patch) + legend_labels.append('Common cycles (all patterns)') + + # Set x-ticks to show original index values + if sorted_indices: + PatternInterface.handle_sorted_indices(plt_ax, sorted_indices) + + # Set labels and title + if patterns: + plt_ax.set_xlabel(patterns[0].source_series.index.name if patterns[0].source_series.index.name else 'Index') + plt_ax.set_ylabel(patterns[0].value_name if patterns[0].value_name else 'Value') + + default_title = "Multiple Cycle Patterns" + plt_ax.set_title(title if title is not None else default_title) + + # Add legend + plt_ax.legend(legend_handles, legend_labels) + + plt.setp(plt_ax.get_xticklabels(), rotation=45, ha='right', fontsize=16) + + # Ensure bottom margin for x-labels + plt_ax.figure.subplots_adjust(bottom=0.15) + + def __init__(self, source_series: pd.Series, cycles: pd.DataFrame, value_name: str = None): + """ + Initialize the Cycle pattern with the provided parameters. + + :param source_series: The source series to evaluate. + :param cycles: The cycles detected in the series. + """ + self.source_series = source_series + # Cycles is a dataframe with the columns: t_start, t_end, t_minimum, doc, duration + self.cycles = cycles + self.hash = None + self._cycle_tuples = frozenset((row['t_start'], row['t_end']) for _, row in cycles.iterrows()) + self.value_name = value_name if value_name else 'Value' + + def visualize(self, plt_ax, title: str = None): + """ + Visualize the cycle pattern. 
+ :param plt_ax: + :return: + """ + self.visualize_many(plt_ax, [self], [self.value_name], title=None, alpha_cycles=0.5, line_alpha=0.8) + if title is not None: + plt_ax.set_title(title) + else: + plt_ax.set_title(f"Cycles in {self.value_name} at {self._cycle_tuples}") + + def __eq__(self, other): + """ + Check if two CyclePattern objects are equal. + :param other: + :return: True if they are equal, False otherwise. They are considered equal if the cycles of one are a + subset of the other or vice versa. + """ + if not isinstance(other, CyclePattern): + return False + + # Use precomputed cycle tuples instead of computing them each time + return self._cycle_tuples.issubset(other._cycle_tuples) or other._cycle_tuples.issubset(self._cycle_tuples) + + def __repr__(self) -> str: + """ + String representation of the CyclePattern. + :return: A string representation of the CyclePattern. + """ + return f"CyclePattern(cycles={self.cycles})" + + def __str__(self) -> str: + """ + String representation of the CyclePattern. + :return: A string representation of the CyclePattern. + """ + return f"CyclePattern(cycles={self.cycles})" + + def __hash__(self) -> int: + """ + Hash representation of the CyclePattern. + :return: A hash representation of the CyclePattern. 
+ """ + if self.hash is not None: + return self.hash + # Create a hashable representation of the key cycle properties + if len(self.cycles) == 0: + return hash("empty_cycle") + # Use a tuple of tuples for cycle start/end times + cycle_tuples = tuple((row['t_start'], row['t_end']) for _, row in self.cycles.iterrows()) + self.hash = hash(cycle_tuples) + return self.hash \ No newline at end of file diff --git a/src/external_explainers/outlier_explainer/outlier_explainer.py b/src/external_explainers/outlier_explainer/outlier_explainer.py index 0a2dfc1..3e7cf91 100644 --- a/src/external_explainers/outlier_explainer/outlier_explainer.py +++ b/src/external_explainers/outlier_explainer/outlier_explainer.py @@ -66,7 +66,11 @@ def calc_influence_pred(self, df_before: DataFrame, df_after: DataFrame, target: try: # Compute target influence - the ratio between the change in the output and the number of # tuples that satisfy the predicate, multiplied by the direction factor. - target_inf = ((df_before[target] - df_after[target]) * dir) / (df_before[target] + df_after[target]) + denominator = df_before[target] + df_after[target] + # We may have a try catch here, but division by zero is still causing a runtime warning. 
+ if denominator == 0: + return -1 + target_inf = ((df_before[target] - df_after[target]) * dir) / denominator except: return -1 @@ -253,8 +257,31 @@ def compute_predicates_per_attribute(self, attr: str, df_in: DataFrame, g_att: s return predicates + + def pred_to_human_readable(self, non_formatted_pred): + explanation = f'This outlier is not as significant when excluding rows with:\n' + for_wizard = '' + for a, bins in non_formatted_pred.items(): + for b in bins: + if type(b[0]) is tuple: + pred = f"{b[0][0]} < {a} < {b[0][1]}" + inter_exp = r'$\bf{{{}}}$'.format(utils.to_valid_latex(pred)) + else: + pred = f"{a}={b[0]}" + inter_exp = r'$\bf{{{}}}$'.format(utils.to_valid_latex(pred)) + if b[1] is not None: + if b[1] <= 5: + inter_exp = inter_exp + '-' + r'$\bf{low}$' + elif b[1] >= 25: + inter_exp = inter_exp + '-' + r'$\bf{high}$' + inter_exp += '\n' + for_wizard += inter_exp + explanation += inter_exp + + return explanation, for_wizard + def draw_bar_plot(self, df_agg: DataFrame | Series, final_df: DataFrame, g_att: str, g_agg: str, final_pred_by_attr: dict, - target: str, agg_title: str) -> None: + target: str, agg_title: str, added_text: dict = None) -> None: """ Draw a bar plot to visualize the influence of predicates on the target attribute. @@ -269,10 +296,11 @@ def draw_bar_plot(self, df_agg: DataFrame | Series, final_df: DataFrame, g_att: :param final_pred_by_attr: Dictionary containing the final predicates grouped by attribute. :param target: The target attribute for which the influence is being visualized. :param agg_title: Title for the aggregation method used in the plot. + :param added_text: Additional text to add to the plot. Optional. Expected: dict with 'text' and 'position' keys. :return: None. Displays the bar plot. 
""" - fig, ax = plt.subplots(layout='constrained', figsize=(5, 5)) + fig, ax = plt.subplots(figsize=(5, 5)) x1 = list(df_agg.index) ind1 = np.arange(len(x1)) y1 = df_agg.values @@ -281,24 +309,7 @@ def draw_bar_plot(self, df_agg: DataFrame | Series, final_df: DataFrame, g_att: ind2 = np.arange(len(x2)) y2 = final_df.values - explanation = f'This outlier is not as significant when excluding rows with:\n' - for_wizard = '' - for a, bins in final_pred_by_attr.items(): - for b in bins: - if type(b[0]) is tuple: - pred = f"{b[0][0]} < {a} < {b[0][1]}" - inter_exp = r'$\bf{{{}}}$'.format(utils.to_valid_latex(pred)) - else: - pred = f"{a}={b[0]}" - inter_exp = r'$\bf{{{}}}$'.format(utils.to_valid_latex(pred)) - if b[1] is not None: - if b[1] <= 5: - inter_exp = inter_exp + '-' + r'$\bf{low}$' - elif b[1] >= 25: - inter_exp = inter_exp + '-' + r'$\bf{high}$' - inter_exp += '\n' - for_wizard += inter_exp - explanation += inter_exp + explanation, for_wizard = self.pred_to_human_readable(final_pred_by_attr) bar1 = ax.bar(ind1 - 0.2, y1, 0.4, alpha=1., label='All') bar2 = ax.bar(ind2 + 0.2, y2, 0.4, alpha=1., label=f'without\n{for_wizard}') @@ -314,11 +325,52 @@ def draw_bar_plot(self, df_agg: DataFrame | Series, final_df: DataFrame, g_att: bar2[x2.index(target)].set_linewidth(2) ax.get_xticklabels()[x1.index(target)].set_color('tab:green') + plt.tight_layout() + + if added_text is not None: + # Draw the plot first to establish the bounding boxes. 
+ plt.draw() + text = added_text['text'] + position = added_text['position'] + renderer = ax.figure.canvas.get_renderer() + max_label_height = 0 + + for label in ax.get_xticklabels() + [ax.xaxis.get_label()]: + bbox = label.get_window_extent(renderer=renderer) + if bbox.height > max_label_height: + max_label_height = bbox.height + + if position == "bottom": + offset_in_points = -(max_label_height + 10) + + ax.annotate( + text, + xy=(0.5, 0), # anchor at the bottom of the axes + xycoords='axes fraction', + xytext=(0, offset_in_points), + textcoords='offset points', + ha='center', va='top', + fontsize=16 + ) + elif position == "top": + offset_in_points = max_label_height + 10 + + ax.annotate( + text, + xy=(0.5, 1), # anchor at the top of the axes + xycoords='axes fraction', + xytext=(0, offset_in_points), + textcoords='offset points', + ha='center', va='bottom', + fontsize=16 + ) + plt.show() def explain(self, df_agg: DataFrame, df_in: DataFrame, g_att: str, g_agg: str, agg_method: str, target: str, - dir: int, control=None, hold_out: List = [], k: int = 1) -> str | None: + dir: int, control=None, hold_out: List = None, k: int = 1, draw_plot: bool = True) \ + -> str | None | Tuple: """ Explain the outlier in the given DataFrame. @@ -340,6 +392,9 @@ def explain(self, df_agg: DataFrame, df_in: DataFrame, g_att: str, g_agg: str, a :return: None. Will generate a plot with the explanation for the outlier. """ + if hold_out is None: + hold_out = [] + # Get the attributes from the input DataFrame and remove the hold-out attributes. attrs = df_in.columns attrs = [a for a in attrs if a not in hold_out + [g_att, g_agg]] @@ -390,6 +445,9 @@ def explain(self, df_agg: DataFrame, df_in: DataFrame, g_att: str, g_agg: str, a final_pred_by_attr[a] = [] final_pred_by_attr[a].append((i, rank)) - # Create a plot to display the explanation for the outlier. 
- self.draw_bar_plot(df_agg, final_df, g_att, g_agg, final_pred_by_attr, target, agg_title) - return None + # Create a plot to display the explanation for the outlier, or return everything needed to draw the plot later. + if draw_plot: + self.draw_bar_plot(df_agg, final_df, g_att, g_agg, final_pred_by_attr, target, agg_title) + return None + else: + return df_agg, final_df, g_att, g_agg, final_pred_by_attr, target, agg_title