# In [1]:
"""
Feature Visualizer - Comprehensive Visualization Tools for Chaotic Features.

This module provides advanced visualization capabilities for analyzing chaotic
features extracted by the C-HiLAP system, including statistical distributions,
correlations, dimensionality reduction plots, and specialized chaos visualizations.

Author: C-HiLAP Project
Date: 2025
"""

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
from typing import Dict, List, Tuple, Optional, Union, Any, Callable
from dataclasses import dataclass, field
import pandas as pd
from pathlib import Path
import json

import os
import sys

# Import path setup
try:
    from setup_imports import setup_project_imports
    setup_project_imports()
except ImportError:
    # Fall back to adding the project root to sys.path manually
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
    sys.path.insert(0, project_root)
    
# Advanced plotting libraries
try:
    import plotly.graph_objects as go
    import plotly.express as px
    from plotly.subplots import make_subplots
    import plotly.offline as pyo
    PLOTLY_AVAILABLE = True
except ImportError:
    PLOTLY_AVAILABLE = False
    warnings.warn("Plotly not available. Some interactive visualizations will be disabled.")

# Dimensionality reduction for visualization
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

# UMAP is provided by the optional ``umap-learn`` package; scikit-learn has no
# UMAP class, so importing it from ``sklearn.manifold`` fails at import time.
try:
    from umap import UMAP
    UMAP_AVAILABLE = True
except ImportError:
    UMAP_AVAILABLE = False

    class UMAP:
        """Placeholder that raises ImportError when UMAP is used without umap-learn."""

        def __init__(self, *args, **kwargs):
            raise ImportError(
                "UMAP requires the 'umap-learn' package (pip install umap-learn)."
            )

# Statistical analysis
from scipy.stats import pearsonr, spearmanr
from scipy.spatial.distance import pdist, squareform
from scipy.cluster.hierarchy import dendrogram, linkage

# Import project modules
try:
    from features.chaotic_features import ChaoticFeatureExtractor
    from core.phase_space_reconstruction import PhaseSpaceReconstructor
    from core.mlsa_extractor import MLSAExtractor
    from core.rqa_extractor import RQAExtractor
except ImportError as e:
    warnings.warn(f"Could not import project modules: {e}")


@dataclass
class VisualizationConfig:
    """Settings shared by the visualizers in this module.

    Groups figure geometry, typography, artist styling, output-file options,
    interactive (Plotly) options and a per-feature-family colour scheme.
    Every field has a default, so ``VisualizationConfig()`` is always valid.
    """

    # --- Figure appearance -------------------------------------------------
    figure_size: Tuple[int, int] = (12, 8)  # (width, height) passed as figsize
    dpi: int = 300
    style: str = 'whitegrid'  # handed to seaborn.set_style
    color_palette: str = 'husl'  # palette name
    font_size: int = 12
    title_size: int = 14

    # --- Artist styling ----------------------------------------------------
    alpha: float = 0.7
    line_width: float = 2.0
    marker_size: float = 50

    # --- Output files ------------------------------------------------------
    save_format: str = 'png'  # e.g. 'png', 'pdf', 'svg', 'eps'
    save_dpi: int = 300
    bbox_inches: str = 'tight'

    # --- Interactive output ------------------------------------------------
    enable_interactive: bool = True
    plotly_theme: str = 'plotly_white'

    # --- Colours per feature family (hex strings) --------------------------
    # default_factory builds a fresh dict per instance, so instances never
    # share (and accidentally mutate) the same mapping.
    chaos_colors: Dict[str, str] = field(
        default_factory=lambda: dict(
            mlsa='#FF6B6B',
            rqa='#4ECDC4',
            traditional='#45B7D1',
            fused='#96CEB4',
        )
    )


class BaseVisualizer:
    """Common base for the visualizers in this module.

    Holds a ``VisualizationConfig``, applies its style settings to
    seaborn/matplotlib on construction, and provides a helper that saves
    figures in the configured format.
    """

    def __init__(self, config: Optional[VisualizationConfig] = None):
        # A default configuration is created when none (or a falsy value) is given.
        self.config = config or VisualizationConfig()
        self._setup_style()

    def _setup_style(self) -> None:
        """Apply the configured seaborn style and matplotlib rcParams.

        Note: ``sns.set_style`` and ``plt.rcParams`` are process-global, so this
        affects every figure created afterwards, not only this instance's.
        """
        sns.set_style(self.config.style)
        plt.rcParams.update({
            'font.size': self.config.font_size,
            'axes.titlesize': self.config.title_size,
            'axes.labelsize': self.config.font_size,
            'xtick.labelsize': self.config.font_size,
            'ytick.labelsize': self.config.font_size,
            'legend.fontsize': self.config.font_size,
            'figure.dpi': self.config.dpi,
            'savefig.dpi': self.config.save_dpi,
            'savefig.bbox': self.config.bbox_inches
        })

    def _save_figure(self, fig, filename: str, output_dir: Optional[str] = None) -> Path:
        """Save ``fig`` as ``<filename>.<save_format>`` and return the written path.

        Args:
            fig: Matplotlib figure to save.
            filename: File name without extension.
            output_dir: Directory to save into; created (with parents) if it does
                not exist. When omitted, the path is relative to the current
                working directory.

        Returns:
            Path of the saved file (previously nothing was returned; callers that
            ignore the result are unaffected).
        """
        filepath = Path(f"{filename}.{self.config.save_format}")
        if output_dir:
            target_dir = Path(output_dir)
            target_dir.mkdir(parents=True, exist_ok=True)
            filepath = target_dir / filepath

        fig.savefig(filepath, format=self.config.save_format,
                    dpi=self.config.save_dpi, bbox_inches=self.config.bbox_inches)
        print(f"Figure saved: {filepath}")
        return filepath


class FeatureDistributionVisualizer(BaseVisualizer):
    """Visualizer for feature distributions and per-feature summary statistics."""

    @staticmethod
    def _resolve_feature_names(feature_names: Optional[List[str]],
                               n_features: int) -> List[str]:
        """Return ``feature_names`` as a list, or ``Feature_<i>`` placeholders when absent/empty."""
        # Explicit None/len check (not truthiness) so numpy arrays of names also work.
        if feature_names is None or len(feature_names) == 0:
            return [f"Feature_{i}" for i in range(n_features)]
        return list(feature_names)

    def plot_feature_distributions(self, features: np.ndarray,
                                 feature_names: List[str] = None,
                                 labels: np.ndarray = None,
                                 max_features: int = 20,
                                 output_dir: str = None) -> plt.Figure:
        """
        Plot a histogram for each of up to ``max_features`` features.

        When there are more than ``max_features`` columns, only the
        highest-variance columns are shown.

        Args:
            features: Feature matrix (n_samples, n_features)
            feature_names: Names of features; ``Feature_<i>`` is used when omitted
            labels: Optional per-sample class labels. One histogram per class is
                drawn, all on shared bin edges so the classes are comparable.
            max_features: Maximum number of features to plot (must be >= 1)
            output_dir: If given, the figure is saved there as
                ``feature_distributions``

        Returns:
            Matplotlib figure

        Raises:
            ValueError: If there is nothing to plot (no feature columns, or
                ``max_features < 1``).
        """
        features = np.asarray(features)
        n_features = features.shape[1]
        n_plot = min(n_features, max_features)
        if n_plot < 1:
            raise ValueError(
                "Nothing to plot: need at least one feature column and max_features >= 1")
        names = self._resolve_feature_names(feature_names, n_features)

        if n_features > max_features:
            # argsort is ascending, so the tail holds the most variable columns.
            selected = np.argsort(np.var(features, axis=0))[-max_features:]
        else:
            selected = np.arange(n_features)
        features_to_plot = features[:, selected]
        names_to_plot = [names[i] for i in selected]

        n_cols = min(4, n_plot)
        n_rows = (n_plot + n_cols - 1) // n_cols
        # squeeze=False always returns a 2-D array, so 1x1 / 1xN grids need no special cases.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3),
                                 squeeze=False)
        axes = axes.ravel()

        palette = sns.color_palette()
        if labels is not None:
            # Array conversion makes ``labels == label`` an element-wise mask for list input too.
            labels = np.asarray(labels)
            unique_labels = np.unique(labels)

        for i in range(n_plot):
            ax = axes[i]
            feature_data = features_to_plot[:, i]

            if labels is not None:
                # One set of bin edges for every class so the bars line up.
                finite = feature_data[np.isfinite(feature_data)]
                bins = np.histogram_bin_edges(finite, bins=30) if finite.size else 30
                for label in unique_labels:
                    ax.hist(feature_data[labels == label], bins=bins, alpha=0.6,
                            label=f'Class {label}')
                ax.legend()
            else:
                # Cycle through however many colours the active palette has.
                ax.hist(feature_data, bins=30, alpha=self.config.alpha,
                        color=palette[i % len(palette)])

            ax.set_title(names_to_plot[i])
            ax.set_xlabel('Feature Value')
            ax.set_ylabel('Frequency')
            ax.grid(True, alpha=0.3)

        # Hide the unused cells of the last grid row.
        for ax in axes[n_plot:]:
            ax.set_visible(False)

        fig.tight_layout()

        if output_dir:
            self._save_figure(fig, "feature_distributions", output_dir)

        return fig

    def plot_feature_statistics(self, features: np.ndarray,
                               feature_names: List[str] = None,
                               output_dir: str = None) -> plt.Figure:
        """
        Plot per-feature mean, std, skewness, kurtosis, min and max as bar charts.

        Bars are coloured by their value (viridis colormap) so magnitude is
        visible at a glance.

        Args:
            features: Feature matrix (n_samples, n_features)
            feature_names: Names of features; ``Feature_<i>`` is used when omitted
            output_dir: If given, the figure is saved there as ``feature_statistics``

        Returns:
            Matplotlib figure
        """
        features = np.asarray(features)
        n_features = features.shape[1]
        feature_names = self._resolve_feature_names(feature_names, n_features)

        stats = {
            'mean': np.mean(features, axis=0),
            'std': np.std(features, axis=0),
            'skewness': self._calculate_skewness(features),
            'kurtosis': self._calculate_kurtosis(features),
            'min': np.min(features, axis=0),
            'max': np.max(features, axis=0)
        }

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))

        for ax, (stat_name, stat_values) in zip(axes.ravel(), stats.items()):
            values = np.asarray(stat_values, dtype=float)
            bars = ax.bar(range(len(values)), values, alpha=self.config.alpha)

            # Colour each bar by its value (NaNs get the colormap's "bad" colour).
            finite = values[np.isfinite(values)]
            if finite.size:
                norm = plt.Normalize(vmin=finite.min(), vmax=finite.max())
                for bar, value in zip(bars, values):
                    bar.set_color(plt.cm.viridis(norm(value)))

            ax.set_title(f'Feature {stat_name.capitalize()}')
            ax.set_xlabel('Feature Index')
            ax.set_ylabel(stat_name.capitalize())
            ax.grid(True, alpha=0.3)

            # With many features, keep numeric ticks; otherwise show the names.
            if len(feature_names) > 20:
                ax.tick_params(axis='x', rotation=45)
            else:
                ax.set_xticks(range(len(feature_names)))
                ax.set_xticklabels(feature_names, rotation=45, ha='right')

        fig.tight_layout()

        if output_dir:
            self._save_figure(fig, "feature_statistics", output_dir)

        return fig

    @staticmethod
    def _standardized(data: np.ndarray) -> np.ndarray:
        """Column-wise z-scores using the population std; zero-variance columns become NaN."""
        data = np.asarray(data, dtype=float)
        std = np.std(data, axis=0)
        centered = data - np.mean(data, axis=0)
        # Constant columns give 0/0; the moment is undefined there, so return NaN
        # explicitly instead of emitting a RuntimeWarning.
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.where(std > 0, centered / std, np.nan)

    def _calculate_skewness(self, data: np.ndarray) -> np.ndarray:
        """Calculate the (biased, population) skewness of each feature column."""
        return np.mean(self._standardized(data) ** 3, axis=0)

    def _calculate_kurtosis(self, data: np.ndarray) -> np.ndarray:
        """Calculate the (biased, population) excess kurtosis of each feature column."""
        return np.mean(self._standardized(data) ** 4, axis=0) - 3


class CorrelationVisualizer(BaseVisualizer):
    """Visualizer for feature correlations and pairwise relationships."""

    def plot_correlation_matrix(self, features: np.ndarray,
                               feature_names: List[str] = None,
                               method: str = 'pearson',
                               output_dir: str = None) -> plt.Figure:
        """
        Plot the feature correlation matrix as an image on a fixed [-1, 1] scale.

        Args:
            features: Feature matrix (n_samples, n_features)
            feature_names: Feature names; used as tick labels when there are at
                most 50 of them, otherwise a few index ticks are shown
            method: Correlation method ('pearson', 'spearman')
            output_dir: If given, the figure is saved as ``correlation_matrix_<method>``

        Returns:
            Matplotlib figure

        Raises:
            ValueError: If ``method`` is not 'pearson' or 'spearman'.
        """
        features = np.asarray(features)
        if method == 'pearson':
            corr_matrix = np.corrcoef(features, rowvar=False)
        elif method == 'spearman':
            rho, _ = spearmanr(features, axis=0)
            # With exactly two variables spearmanr returns a scalar rho, not a matrix.
            if np.ndim(rho) == 0:
                corr_matrix = np.array([[1.0, rho], [rho, 1.0]])
            else:
                corr_matrix = rho
        else:
            raise ValueError(f"Unknown correlation method: {method}")
        # np.corrcoef collapses to a 0-d value for a single feature; imshow needs 2-D.
        corr_matrix = np.atleast_2d(corr_matrix)

        fig, ax = plt.subplots(figsize=self.config.figure_size)

        im = ax.imshow(corr_matrix, cmap='RdBu_r', aspect='auto',
                       vmin=-1, vmax=1)

        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label(f'{method.capitalize()} Correlation')

        if feature_names is not None and 0 < len(feature_names) <= 50:
            ax.set_xticks(range(len(feature_names)))
            ax.set_yticks(range(len(feature_names)))
            ax.set_xticklabels(feature_names, rotation=45, ha='right')
            ax.set_yticklabels(feature_names)
        else:
            # Too many (or no) names: show up to 10 evenly spaced index ticks.
            n_ticks = min(10, len(corr_matrix))
            tick_indices = np.linspace(0, len(corr_matrix) - 1, n_ticks, dtype=int)
            ax.set_xticks(tick_indices)
            ax.set_yticks(tick_indices)

        ax.set_title(f'Feature Correlation Matrix ({method.capitalize()})')

        fig.tight_layout()

        if output_dir:
            self._save_figure(fig, f"correlation_matrix_{method}", output_dir)

        return fig

    def plot_correlation_heatmap_seaborn(self, features: np.ndarray,
                                        feature_names: List[str] = None,
                                        method: str = 'pearson',
                                        output_dir: str = None) -> plt.Figure:
        """
        Plot a seaborn correlation heatmap (annotated when there are <= 20 features).

        Args:
            features: Feature matrix (n_samples, n_features)
            feature_names: Optional column names
            method: Any method accepted by ``pandas.DataFrame.corr``
            output_dir: If given, the figure is saved as ``correlation_heatmap_<method>``

        Returns:
            Matplotlib figure
        """
        if feature_names is not None and len(feature_names) > 0:
            df = pd.DataFrame(features, columns=feature_names)
        else:
            df = pd.DataFrame(features)

        corr_matrix = df.corr(method=method)

        fig, ax = plt.subplots(figsize=self.config.figure_size)

        # Pin the colour range to [-1, 1] (as in plot_correlation_matrix) so colours
        # mean the same thing regardless of the observed min/max correlation.
        sns.heatmap(corr_matrix, annot=len(corr_matrix) <= 20,
                    cmap='RdBu_r', center=0, vmin=-1, vmax=1, square=True,
                    linewidths=0.5, cbar_kws={"shrink": .8}, ax=ax)

        ax.set_title(f'Feature Correlation Heatmap ({method.capitalize()})')

        fig.tight_layout()

        if output_dir:
            self._save_figure(fig, f"correlation_heatmap_{method}", output_dir)

        return fig

    def plot_pairwise_relationships(self, features: np.ndarray,
                                   feature_names: List[str] = None,
                                   labels: np.ndarray = None,
                                   max_features: int = 10,
                                   output_dir: str = None) -> plt.Figure:
        """
        Plot a seaborn pair plot of the ``max_features`` most variable features.

        Args:
            features: Feature matrix (n_samples, n_features)
            feature_names: Optional names; ``F<i>`` is used when omitted
            labels: Optional per-sample labels used as the hue (column ``'Label'``)
            max_features: Maximum number of features to include (must be >= 1)
            output_dir: If given, the figure is saved as ``pairwise_relationships``

        Returns:
            The pair plot's matplotlib figure

        Raises:
            ValueError: If there is nothing to plot.
        """
        features = np.asarray(features)
        n_features = min(features.shape[1], max_features)
        if n_features < 1:
            # Guard: ``argsort(...)[-0:]`` would silently select *every* feature.
            raise ValueError(
                "Nothing to plot: need at least one feature column and max_features >= 1")

        # argsort is ascending, so the tail holds the most variable columns.
        top_indices = np.argsort(np.var(features, axis=0))[-n_features:]

        selected_features = features[:, top_indices]
        if feature_names is not None and len(feature_names) > 0:
            selected_names = [feature_names[i] for i in top_indices]
        else:
            selected_names = [f"F{i}" for i in top_indices]

        df = pd.DataFrame(selected_features, columns=selected_names)
        pairplot_kwargs = {'diag_kind': 'hist', 'plot_kws': {'alpha': self.config.alpha}}
        if labels is not None:
            df['Label'] = labels
            pairplot_kwargs['hue'] = 'Label'

        g = sns.pairplot(df, **pairplot_kwargs)

        g.fig.suptitle('Pairwise Feature Relationships', y=1.02)

        if output_dir:
            self._save_figure(g.fig, "pairwise_relationships", output_dir)

        return g.fig


class DimensionalityReductionVisualizer(BaseVisualizer):
    """Visualizer for PCA, t-SNE and UMAP projections of feature matrices."""

    def _scatter_by_label(self, ax, x: np.ndarray, y: np.ndarray,
                          labels: Optional[np.ndarray] = None,
                          color_by_order: bool = False) -> None:
        """Scatter ``x`` vs ``y`` on ``ax``; one Set1 colour per class when labels are given.

        Without labels, points get the default colour, or a viridis gradient over
        sample order when ``color_by_order`` is True.
        """
        if labels is not None:
            labels = np.asarray(labels)
            unique_labels = np.unique(labels)
            colors = plt.cm.Set1(np.linspace(0, 1, len(unique_labels)))
            for color, label in zip(colors, unique_labels):
                mask = labels == label
                ax.scatter(x[mask], y[mask], c=[color], label=f'Class {label}',
                           alpha=self.config.alpha, s=self.config.marker_size)
            ax.legend()
        elif color_by_order:
            ax.scatter(x, y, alpha=self.config.alpha, s=self.config.marker_size,
                       c=plt.cm.viridis(np.linspace(0, 1, len(x))))
        else:
            ax.scatter(x, y, alpha=self.config.alpha, s=self.config.marker_size)

    def plot_pca_analysis(self, features: np.ndarray,
                         labels: np.ndarray = None,
                         n_components: int = 10,
                         output_dir: str = None) -> Tuple[plt.Figure, Dict]:
        """
        Plot a six-panel PCA overview: explained variance, cumulative variance,
        PC1/PC2 and PC1/PC3 scatters, loadings heatmap and top PC1 features.

        Args:
            features: Feature matrix (n_samples, n_features)
            labels: Optional class labels for colouring the scatter plots
            n_components: Requested number of PCA components; capped at
                min(n_samples, n_features)
            output_dir: If given, the figure is saved as ``pca_analysis``

        Returns:
            Figure and a dict with keys 'pca_model', 'transformed_features',
            'explained_variance_ratio', 'cumulative_variance', 'components'

        Raises:
            ValueError: If fewer than 2 components can be computed.
        """
        features = np.asarray(features)
        n_samples, n_features = features.shape
        # PCA can return at most min(n_samples, n_features) components.
        n_fit = min(n_components, n_samples, n_features)
        if n_fit < 2:
            raise ValueError(
                "PCA analysis needs at least 2 components; got "
                f"n_components={n_components} for data of shape {features.shape}")

        pca = PCA(n_components=n_fit)
        features_pca = pca.fit_transform(features)
        evr = pca.explained_variance_ratio_
        cumsum = np.cumsum(evr)
        component_ids = range(1, len(evr) + 1)

        fig, axes = plt.subplots(2, 3, figsize=(16, 12))
        ax_var, ax_cum, ax_pc12, ax_pc13, ax_load, ax_top = axes.ravel()

        # 1. Explained variance ratio per component
        ax_var.bar(component_ids, evr, alpha=self.config.alpha)
        ax_var.set_xlabel('Principal Component')
        ax_var.set_ylabel('Explained Variance Ratio')
        ax_var.set_title('PCA Explained Variance')
        ax_var.grid(True, alpha=0.3)

        # 2. Cumulative explained variance with 95% / 99% reference lines
        ax_cum.plot(component_ids, cumsum, 'o-', linewidth=self.config.line_width)
        ax_cum.axhline(y=0.95, color='r', linestyle='--', label='95% Variance')
        ax_cum.axhline(y=0.99, color='g', linestyle='--', label='99% Variance')
        ax_cum.set_xlabel('Number of Components')
        ax_cum.set_ylabel('Cumulative Explained Variance')
        ax_cum.set_title('Cumulative Explained Variance')
        ax_cum.legend()
        ax_cum.grid(True, alpha=0.3)

        # 3. PC1 vs PC2
        self._scatter_by_label(ax_pc12, features_pca[:, 0], features_pca[:, 1], labels)
        ax_pc12.set_xlabel(f'PC1 ({evr[0]:.2%} variance)')
        ax_pc12.set_ylabel(f'PC2 ({evr[1]:.2%} variance)')
        ax_pc12.set_title('PCA Scatter Plot (PC1 vs PC2)')
        ax_pc12.grid(True, alpha=0.3)

        # 4. PC1 vs PC3, only when a third component exists
        if features_pca.shape[1] >= 3:
            self._scatter_by_label(ax_pc13, features_pca[:, 0], features_pca[:, 2], labels)
            ax_pc13.set_xlabel(f'PC1 ({evr[0]:.2%} variance)')
            ax_pc13.set_ylabel(f'PC3 ({evr[2]:.2%} variance)')
            ax_pc13.set_title('PCA Scatter Plot (PC1 vs PC3)')
            ax_pc13.grid(True, alpha=0.3)
        else:
            ax_pc13.set_visible(False)

        # 5. Loadings of the first (up to) 5 components. The diverging colormap is
        # centred on 0 so red/blue reliably indicate the sign of a loading.
        n_show = min(5, pca.components_.shape[0])
        loadings = pca.components_[:n_show]
        vmax = float(np.abs(loadings).max()) or 1.0
        im = ax_load.imshow(loadings, aspect='auto', cmap='RdBu_r', vmin=-vmax, vmax=vmax)
        fig.colorbar(im, ax=ax_load, shrink=0.8)
        ax_load.set_xlabel('Feature Index')
        ax_load.set_ylabel('Principal Component')
        ax_load.set_title('PCA Component Loadings')

        # 6. Up to 10 features with the largest |loading| on PC1
        pc1_loadings = np.abs(pca.components_[0])
        top_indices = np.argsort(pc1_loadings)[-10:]
        ax_top.barh(range(len(top_indices)), pc1_loadings[top_indices])
        ax_top.set_xlabel('Absolute Loading')
        ax_top.set_ylabel('Feature Index')
        ax_top.set_title('Top Features Contributing to PC1')
        ax_top.set_yticks(range(len(top_indices)))
        ax_top.set_yticklabels(top_indices)

        fig.tight_layout()

        pca_results = {
            'pca_model': pca,
            'transformed_features': features_pca,
            'explained_variance_ratio': evr,
            'cumulative_variance': cumsum,
            'components': pca.components_
        }

        if output_dir:
            self._save_figure(fig, "pca_analysis", output_dir)

        return fig, pca_results

    def plot_tsne_embedding(self, features: np.ndarray,
                           labels: np.ndarray = None,
                           perplexity: float = 30.0,
                           n_iter: int = 1000,
                           output_dir: str = None) -> Tuple[plt.Figure, np.ndarray]:
        """
        Plot a 2-D t-SNE embedding (random_state=42 for reproducibility).

        Args:
            features: Feature matrix (n_samples, n_features)
            labels: Optional class labels; without them points are coloured by
                sample order
            perplexity: t-SNE perplexity
            n_iter: Optimisation iteration budget (forwarded as ``max_iter`` on
                scikit-learn >= 1.5, as ``n_iter`` on older versions)
            output_dir: If given, the figure is saved as ``tsne_embedding_perp<perplexity>``

        Returns:
            Figure and the (n_samples, 2) embedding
        """
        print(f"Computing t-SNE embedding with perplexity={perplexity}...")

        tsne_kwargs = {'n_components': 2, 'perplexity': perplexity, 'random_state': 42}
        try:
            # scikit-learn 1.5 renamed ``n_iter`` to ``max_iter`` (``n_iter`` removed in 1.7).
            tsne = TSNE(max_iter=n_iter, **tsne_kwargs)
        except TypeError:
            # Older scikit-learn only accepts ``n_iter``.
            tsne = TSNE(n_iter=n_iter, **tsne_kwargs)
        features_tsne = tsne.fit_transform(features)

        fig, ax = plt.subplots(figsize=self.config.figure_size)
        self._scatter_by_label(ax, features_tsne[:, 0], features_tsne[:, 1], labels,
                               color_by_order=True)

        ax.set_xlabel('t-SNE Dimension 1')
        ax.set_ylabel('t-SNE Dimension 2')
        ax.set_title(f't-SNE Embedding (perplexity={perplexity})')
        ax.grid(True, alpha=0.3)

        fig.tight_layout()

        if output_dir:
            self._save_figure(fig, f"tsne_embedding_perp{perplexity}", output_dir)

        return fig, features_tsne

    def plot_umap_embedding(self, features: np.ndarray,
                           labels: np.ndarray = None,
                           n_neighbors: int = 15,
                           min_dist: float = 0.1,
                           output_dir: str = None) -> Tuple[plt.Figure, np.ndarray]:
        """
        Plot a 2-D UMAP embedding (random_state=42 for reproducibility).

        UMAP comes from the optional ``umap-learn`` package. If constructing the
        module-level ``UMAP`` reducer raises ImportError, a warning is issued and
        ``(None, None)`` is returned.

        Args:
            features: Feature matrix (n_samples, n_features)
            labels: Optional class labels; without them points are coloured by
                sample order
            n_neighbors: UMAP ``n_neighbors``
            min_dist: UMAP ``min_dist``
            output_dir: If given, the figure is saved as
                ``umap_embedding_nn<n_neighbors>_md<min_dist>``

        Returns:
            Figure and the (n_samples, 2) embedding, or ``(None, None)``
        """
        print(f"Computing UMAP embedding with n_neighbors={n_neighbors}, min_dist={min_dist}...")

        # Only the reducer construction may signal a missing package; ImportErrors
        # raised elsewhere (e.g. by plotting) should not be masked as "UMAP missing".
        try:
            reducer = UMAP(n_neighbors=n_neighbors, min_dist=min_dist, random_state=42)
        except ImportError:
            warnings.warn("UMAP not available. Please install umap-learn package.")
            return None, None
        features_umap = reducer.fit_transform(features)

        fig, ax = plt.subplots(figsize=self.config.figure_size)
        self._scatter_by_label(ax, features_umap[:, 0], features_umap[:, 1], labels,
                               color_by_order=True)

        ax.set_xlabel('UMAP Dimension 1')
        ax.set_ylabel('UMAP Dimension 2')
        ax.set_title(f'UMAP Embedding (n_neighbors={n_neighbors}, min_dist={min_dist})')
        ax.grid(True, alpha=0.3)

        fig.tight_layout()

        if output_dir:
            self._save_figure(fig, f"umap_embedding_nn{n_neighbors}_md{min_dist}", output_dir)

        return fig, features_umap


class ChaosSpecificVisualizer(BaseVisualizer):
    """Visualizer for chaos-specific analyses: phase space, recurrence, Lyapunov, RQA."""

    def plot_phase_space(self, embedded_data: np.ndarray,
                        title: str = "Phase Space Reconstruction",
                        output_dir: str = None) -> plt.Figure:
        """
        Plot a delay-embedded trajectory in 2-D, or in 3-D using its first three columns.

        Args:
            embedded_data: 2-D array of embedding vectors, one row per time step
                (at least 2 columns)
            title: Plot title
            output_dir: If given, the figure is saved as ``phase_space``

        Returns:
            Matplotlib figure, or None (with a warning) when the input is not a
            2-D array with at least 2 columns
        """
        embedded_data = np.asarray(embedded_data)
        # Check ndim first so a 1-D series warns instead of raising IndexError.
        if embedded_data.ndim != 2 or embedded_data.shape[1] < 2:
            warnings.warn("Need at least 2D embedding for phase space plot")
            return None

        fig = plt.figure(figsize=self.config.figure_size)

        if embedded_data.shape[1] == 2:
            ax = fig.add_subplot(111)
            ax.plot(embedded_data[:, 0], embedded_data[:, 1],
                    alpha=self.config.alpha, linewidth=0.5)
            ax.scatter(embedded_data[0, 0], embedded_data[0, 1],
                       c='red', s=100, marker='o', label='Start', zorder=5)
            ax.set_xlabel('x(t)')
            ax.set_ylabel('x(t+τ)')
            ax.legend()
            ax.grid(True, alpha=0.3)
        else:
            # Three or more columns: only the first three are drawn.
            ax = fig.add_subplot(111, projection='3d')
            ax.plot(embedded_data[:, 0], embedded_data[:, 1], embedded_data[:, 2],
                    alpha=self.config.alpha, linewidth=0.5)
            ax.scatter(embedded_data[0, 0], embedded_data[0, 1], embedded_data[0, 2],
                       c='red', s=100, marker='o', label='Start')
            ax.set_xlabel('x(t)')
            ax.set_ylabel('x(t+τ)')
            ax.set_zlabel('x(t+2τ)')
            ax.legend()

        ax.set_title(title)
        fig.tight_layout()

        if output_dir:
            self._save_figure(fig, "phase_space", output_dir)

        return fig

    def plot_recurrence_matrix(self, recurrence_matrix: np.ndarray,
                              title: str = "Recurrence Matrix",
                              output_dir: str = None) -> plt.Figure:
        """
        Plot a recurrence matrix as a binary-colormap image with a colorbar.

        Args:
            recurrence_matrix: Square 2-D array indexed by time step on both axes
            title: Plot title
            output_dir: If given, the figure is saved as ``recurrence_matrix``

        Returns:
            Matplotlib figure
        """
        fig, ax = plt.subplots(figsize=self.config.figure_size)

        im = ax.imshow(recurrence_matrix, cmap='binary', aspect='auto')

        ax.set_xlabel('Time Index')
        ax.set_ylabel('Time Index')
        ax.set_title(title)

        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label('Recurrence')

        fig.tight_layout()

        if output_dir:
            self._save_figure(fig, "recurrence_matrix", output_dir)

        return fig

    def plot_lyapunov_spectrum(self, spectrum: np.ndarray,
                              title: str = "Lyapunov Spectrum",
                              output_dir: str = None) -> plt.Figure:
        """
        Plot a Lyapunov spectrum as bars: red for positive, blue for negative and
        gray otherwise (exactly zero or NaN), and annotate the count of positive
        exponents.

        Args:
            spectrum: 1-D sequence of exponents (lists are accepted)
            title: Plot title
            output_dir: If given, the figure is saved as ``lyapunov_spectrum``

        Returns:
            Matplotlib figure
        """
        # Array conversion so the element-wise comparisons below work for lists too.
        spectrum = np.asarray(spectrum, dtype=float)

        fig, ax = plt.subplots(figsize=self.config.figure_size)

        bars = ax.bar(range(1, len(spectrum) + 1), spectrum, alpha=self.config.alpha)

        bar_colors = np.where(spectrum > 0, 'red', np.where(spectrum < 0, 'blue', 'gray'))
        for bar, color in zip(bars, bar_colors):
            bar.set_color(color)

        ax.axhline(y=0, color='black', linestyle='-', linewidth=1)

        ax.set_xlabel('Lyapunov Exponent Index')
        ax.set_ylabel('Lyapunov Exponent Value')
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

        n_positive = int(np.sum(spectrum > 0))
        if n_positive > 0:
            ax.text(0.7, 0.9, f'Chaotic: {n_positive} positive exponents',
                    transform=ax.transAxes, bbox=dict(boxstyle="round", facecolor='wheat'))

        fig.tight_layout()

        if output_dir:
            self._save_figure(fig, "lyapunov_spectrum", output_dir)

        return fig

    def plot_rqa_measures_comparison(self, rqa_results: Dict[str, Any],
                                   output_dir: str = None) -> plt.Figure:
        """
        Plot RR, DET, LAM, L_mean, V_mean and ENTR against scale.

        Args:
            rqa_results: Mapping ``scale -> result dict``. Only results whose
                ``'success'`` entry is truthy are used; measures are read from
                their ``'rqa_measures'`` dict (missing ones plot as NaN).
            output_dir: If given, the figure is saved as ``rqa_measures_comparison``

        Returns:
            Matplotlib figure, or None (with a warning) when there is no usable result
        """
        if not rqa_results:
            warnings.warn("No RQA results provided")
            return None

        measures = ['RR', 'DET', 'LAM', 'L_mean', 'V_mean', 'ENTR']
        data = {measure: [] for measure in measures}
        valid_scales = []

        for scale, result in rqa_results.items():
            if not result.get('success', False):
                continue
            valid_scales.append(scale)
            rqa_measures = result.get('rqa_measures', {})
            for measure in measures:
                data[measure].append(rqa_measures.get(measure, np.nan))

        if not valid_scales:
            warnings.warn("No valid RQA results found")
            return None

        scale_array = np.asarray(valid_scales)
        numeric_scales = scale_array.ndim == 1 and np.issubdtype(scale_array.dtype, np.number)
        # Dict order need not follow scale order; sort numeric scales so each line
        # runs left-to-right instead of zig-zagging.
        order = np.argsort(scale_array, kind='stable') if numeric_scales else range(len(valid_scales))
        x_values = [valid_scales[k] for k in order]

        # Log x-axis when positive numeric scales span more than a factor of 10.
        use_log = (numeric_scales and min(x_values) > 0
                   and max(x_values) / min(x_values) > 10)

        fig, axes = plt.subplots(2, 3, figsize=(18, 12))

        for ax, measure in zip(axes.ravel(), measures):
            y_values = [data[measure][k] for k in order]
            ax.plot(x_values, y_values, 'o-', linewidth=self.config.line_width,
                    markersize=8)
            ax.set_xlabel('Scale')
            ax.set_ylabel(measure)
            ax.set_title(f'RQA {measure} vs Scale')
            ax.grid(True, alpha=0.3)
            if use_log:
                ax.set_xscale('log')

        fig.tight_layout()

        if output_dir:
            self._save_figure(fig, "rqa_measures_comparison", output_dir)

        return fig


class ComparisonVisualizer(BaseVisualizer):
    """Visualizer for comparing different feature types.

    Uses ``self.config`` and ``self._save_figure``. These are not defined in
    this class and are presumably provided by ``BaseVisualizer``.
    """

    def plot_feature_comparison(self, feature_sets: Dict[str, np.ndarray],
                                labels: np.ndarray = None,
                                method: str = 'pca',
                                output_dir: str = None) -> Optional[plt.Figure]:
        """
        Compare different feature sets using dimensionality reduction.

        Each feature set is projected to 2-D on its own and drawn in a separate
        subplot, with at most three subplots per row.

        Args:
            feature_sets: Dictionary of {name: features} pairs. Each value is a
                2-D (n_samples, n_features) array.
            labels: Optional class labels used to colour points. Assumed to be
                aligned with the rows of every feature set (TODO confirm).
            method: Dimensionality reduction method ('pca', 'tsne')
            output_dir: Output directory. When set, the figure is saved as
                ``feature_comparison_<method>``.

        Returns:
            Matplotlib figure, or None when ``feature_sets`` is empty or
            ``method`` is not recognised (a warning is issued in that case).
        """
        n_sets = len(feature_sets)
        if n_sets == 0:
            return None

        # Validate once before building the figure. Otherwise an unknown
        # method would warn once per subplot and still return empty axes.
        if method not in ('pca', 'tsne'):
            warnings.warn(f"Unknown method {method}")
            return None

        # Grid with up to three columns.
        n_cols = min(3, n_sets)
        n_rows = (n_sets + n_cols - 1) // n_cols

        # squeeze=False always returns a 2-D array of Axes, so a single
        # flatten() handles the 1x1, 1xN and MxN layouts uniformly.
        fig, axes = plt.subplots(n_rows, n_cols,
                                 figsize=(n_cols * 6, n_rows * 5),
                                 squeeze=False)
        axes = axes.flatten()

        # Class handling is identical for every subplot, so compute it once.
        # np.asarray lets plain lists work with the boolean masks below.
        if labels is not None:
            labels = np.asarray(labels)
            unique_labels = np.unique(labels)
            class_colors = plt.cm.Set1(np.linspace(0, 1, len(unique_labels)))

        for i, (name, features) in enumerate(feature_sets.items()):
            ax = axes[i]
            features_2d, xlabel, ylabel = self._reduce_to_2d(features, method)

            if labels is not None:
                for label, color in zip(unique_labels, class_colors):
                    mask = labels == label
                    ax.scatter(features_2d[mask, 0], features_2d[mask, 1],
                               c=[color], label=f'Class {label}',
                               alpha=self.config.alpha, s=self.config.marker_size)

                if i == 0:  # Only show legend for first subplot
                    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
            else:
                color = self.config.chaos_colors.get(name.lower(),
                                                     plt.cm.Set1(i / n_sets))
                # Use color= rather than c=. A single RGB(A) tuple passed as c=
                # is read as per-point values when the number of points equals
                # the tuple's length.
                ax.scatter(features_2d[:, 0], features_2d[:, 1],
                           color=color, alpha=self.config.alpha,
                           s=self.config.marker_size)

            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)
            ax.set_title(f'{name} Features ({method.upper()})')
            ax.grid(True, alpha=0.3)

        # Hide grid cells that received no feature set.
        for ax in axes[n_sets:]:
            ax.set_visible(False)

        plt.tight_layout()

        if output_dir:
            self._save_figure(fig, f"feature_comparison_{method}", output_dir)

        return fig

    @staticmethod
    def _reduce_to_2d(features: np.ndarray,
                      method: str) -> Tuple[np.ndarray, str, str]:
        """Project ``features`` to 2-D with ``method`` ('pca' or 'tsne').

        Returns the (n_samples, 2) embedding plus the x and y axis labels.
        """
        if method == 'pca':
            reducer = PCA(n_components=2)
            features_2d = reducer.fit_transform(features)
            var_explained = reducer.explained_variance_ratio_
            return (features_2d,
                    f'PC1 ({var_explained[0]:.1%})',
                    f'PC2 ({var_explained[1]:.1%})')

        # t-SNE requires perplexity < n_samples. Cap sklearn's default of 30
        # so small feature sets still embed instead of raising.
        perplexity = min(30.0, len(features) - 1)
        reducer = TSNE(n_components=2, perplexity=perplexity, random_state=42)
        return (reducer.fit_transform(features),
                't-SNE Dimension 1',
                't-SNE Dimension 2')

    def plot_performance_comparison(self, performance_results: Dict[str, Dict],
                                    metrics: List[str] = None,
                                    output_dir: str = None) -> Optional[plt.Figure]:
        """Plot performance comparison across different feature types.

        Draws a grouped bar chart with one group per method and one bar per
        metric, annotating each bar with its value.

        Args:
            performance_results: Mapping of method name -> {metric: score}.
                Methods missing any requested metric are skipped.
            metrics: Metrics to plot. Defaults to accuracy, precision, recall
                and f1_score.
            output_dir: If given, the figure is saved as
                ``performance_comparison``.

        Returns:
            Matplotlib figure, or None when there are no results or no method
            reports every requested metric.
        """
        if not performance_results:
            return None

        metrics = metrics or ['accuracy', 'precision', 'recall', 'f1_score']

        # Keep only methods that report every requested metric.
        valid_methods = [name for name, results in performance_results.items()
                         if all(metric in results for metric in metrics)]

        if not valid_methods:
            warnings.warn("No valid performance results found")
            return None

        data = {metric: [performance_results[name][metric] for name in valid_methods]
                for metric in metrics}

        # Create grouped bar plot
        fig, ax = plt.subplots(figsize=self.config.figure_size)

        x = np.arange(len(valid_methods))
        n_metrics = len(metrics)
        # Share 80% of each group's slot between its bars so neighbouring
        # groups never overlap. This equals 0.2 for the default four metrics.
        width = 0.8 / n_metrics

        colors = plt.cm.Set1(np.linspace(0, 1, n_metrics))

        for i, metric in enumerate(metrics):
            offset = (i - n_metrics / 2 + 0.5) * width
            bars = ax.bar(x + offset, data[metric], width,
                          label=metric.capitalize(), color=colors[i],
                          alpha=self.config.alpha)

            # Add value labels on bars
            for bar in bars:
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                        f'{height:.3f}', ha='center', va='bottom', fontsize=10)

        ax.set_xlabel('Feature Type')
        ax.set_ylabel('Performance Score')
        ax.set_title('Performance Comparison Across Feature Types')
        ax.set_xticks(x)
        ax.set_xticklabels(valid_methods, rotation=45, ha='right')
        ax.legend()
        ax.grid(True, alpha=0.3, axis='y')

        # Scores usually lie in [0, 1]. Only raise the top limit when a bar
        # (plus its text label) would otherwise be clipped.
        highest = max(max(values) for values in data.values())
        ax.set_ylim(0, max(1.1, highest + 0.1))

        plt.tight_layout()

        if output_dir:
            self._save_figure(fig, "performance_comparison", output_dir)

        return fig


class InteractiveVisualizer:
    """Interactive visualizer using Plotly.

    Every plotting method returns None when Plotly could not be imported
    (``PLOTLY_AVAILABLE`` is False). The ``go.Figure`` return annotations are
    quoted for the same reason. When the Plotly import fails, ``go`` is never
    bound, and an unquoted annotation is evaluated at class-definition time,
    which would raise NameError on module import.
    """

    def __init__(self, config: VisualizationConfig = None):
        self.config = config or VisualizationConfig()

        if not PLOTLY_AVAILABLE:
            warnings.warn("Plotly not available. Interactive features disabled.")

    def create_interactive_feature_explorer(self, features: np.ndarray,
                                            feature_names: List[str] = None,
                                            labels: np.ndarray = None,
                                            output_file: str = None,
                                            max_hover_features: int = 10
                                            ) -> Optional["go.Figure"]:
        """Create interactive feature exploration plot.

        Projects ``features`` onto its first two principal components and
        draws a Plotly scatter, coloured by ``labels`` when given.

        Args:
            features: 2-D (n_samples, n_features) array.
            feature_names: Optional column names. The first
                ``max_hover_features`` of them (capped at the number of
                columns) are added to the hover tooltip.
            labels: Optional per-sample labels, converted to str for colouring.
            output_file: If given, the figure is written there as HTML.
            max_hover_features: Maximum number of raw features shown on hover.

        Returns:
            The Plotly figure, or None if Plotly is unavailable.
        """
        if not PLOTLY_AVAILABLE:
            return None

        features = np.asarray(features)

        # Perform PCA for 2D visualization
        pca = PCA(n_components=2)
        features_2d = pca.fit_transform(features)

        df = pd.DataFrame({
            'PC1': features_2d[:, 0],
            'PC2': features_2d[:, 1],
            'Sample_Index': range(len(features_2d))
        })

        if labels is not None:
            # np.asarray so plain Python lists are accepted too.
            df['Label'] = np.asarray(labels).astype(str)

        # Raw feature columns for the hover tooltip. Never index past the
        # actual number of columns, even if more names were supplied.
        hover_names = []
        if feature_names is not None:
            hover_names = list(feature_names[:max_hover_features])[:features.shape[1]]
        for i, name in enumerate(hover_names):
            df[name] = features[:, i]

        fig = px.scatter(df, x='PC1', y='PC2',
                         color='Label' if labels is not None else None,
                         hover_data=hover_names or None,
                         title='Interactive Feature Explorer')

        fig.update_layout(
            template=self.config.plotly_theme,
            xaxis_title=f'PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)',
            yaxis_title=f'PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)',
            font=dict(size=self.config.font_size)
        )

        if output_file:
            fig.write_html(output_file)
            print(f"Interactive plot saved: {output_file}")

        return fig

    def create_interactive_correlation_matrix(self, features: np.ndarray,
                                              feature_names: List[str] = None,
                                              output_file: str = None
                                              ) -> Optional["go.Figure"]:
        """Create interactive correlation matrix.

        Computes the Pearson correlation between the columns of ``features``
        and renders it as an annotated Plotly heatmap.

        Args:
            features: 2-D (n_samples, n_features) array.
            feature_names: Optional axis labels. ``Feature_<i>`` is used when
                omitted or empty.
            output_file: If given, the figure is written there as HTML.

        Returns:
            The Plotly figure, or None if Plotly is unavailable.
        """
        if not PLOTLY_AVAILABLE:
            return None

        # rowvar=False treats columns as variables, the same as
        # corrcoef(features.T). np.corrcoef returns a scalar for a single
        # feature, so atleast_2d keeps it a 1x1 matrix that len() can handle.
        corr_matrix = np.atleast_2d(np.corrcoef(np.asarray(features), rowvar=False))

        if feature_names is not None and len(feature_names) > 0:
            labels = list(feature_names)
        else:
            labels = [f'Feature_{i}' for i in range(len(corr_matrix))]

        fig = go.Figure(data=go.Heatmap(
            z=corr_matrix,
            x=labels,
            y=labels,
            colorscale='RdBu',
            zmid=0,
            text=np.round(corr_matrix, 3),
            texttemplate="%{text}",
            textfont={"size": 8},
            hovertemplate='%{y} vs %{x}<br>Correlation: %{z:.3f}<extra></extra>'
        ))

        fig.update_layout(
            title='Interactive Correlation Matrix',
            template=self.config.plotly_theme,
            font=dict(size=self.config.font_size),
            width=800,
            height=800
        )

        if output_file:
            fig.write_html(output_file)
            print(f"Interactive correlation matrix saved: {output_file}")

        return fig


def _components_for_variance(explained_variance_ratio,
                             threshold: float = 0.95) -> Optional[int]:
    """Return how many leading components reach ``threshold`` cumulative variance.

    Returns None when the supplied components never reach ``threshold``
    (e.g. a truncated PCA).
    """
    cumulative = np.cumsum(np.asarray(explained_variance_ratio, dtype=float))
    # cumsum of non-negative ratios is non-decreasing, so searchsorted
    # (side='left') finds the first index with cumulative >= threshold.
    idx = int(np.searchsorted(cumulative, threshold, side='left'))
    return idx + 1 if idx < len(cumulative) else None


def create_comprehensive_report(features: np.ndarray,
                                feature_names: List[str] = None,
                                labels: np.ndarray = None,
                                output_dir: str = "visualization_report",
                                config: VisualizationConfig = None):
    """Create a comprehensive visualization report.

    Runs the distribution, statistics, correlation, PCA, t-SNE and UMAP plots
    on ``features``, passing ``output_dir`` to each so they are saved there.
    Also writes ``report_summary.json`` with dataset size and PCA summary values.

    Args:
        features: Sample-by-feature matrix. ``shape[0]`` is reported as the
            sample count and ``shape[1]`` as the feature count.
        feature_names: Optional feature names forwarded to the plots.
        labels: Optional class labels forwarded to the plots.
        output_dir: Directory to create (with parents) and write into.
        config: Visualization settings. A default ``VisualizationConfig`` is
            used when omitted.

    Returns:
        Dict mapping plot name ('distributions', 'statistics', 'correlation',
        'pca', 'tsne', plus 'umap' when one was produced) to its figure.
    """
    print(f"Creating comprehensive visualization report in {output_dir}...")

    Path(output_dir).mkdir(parents=True, exist_ok=True)

    config = config or VisualizationConfig()

    dist_viz = FeatureDistributionVisualizer(config)
    corr_viz = CorrelationVisualizer(config)
    dim_viz = DimensionalityReductionVisualizer(config)

    figures = {}

    # 1. Feature distributions
    print("Creating feature distributions...")
    figures['distributions'] = dist_viz.plot_feature_distributions(
        features, feature_names, labels, output_dir=output_dir
    )

    # 2. Feature statistics
    print("Creating feature statistics...")
    figures['statistics'] = dist_viz.plot_feature_statistics(
        features, feature_names, output_dir=output_dir
    )

    # 3. Correlation analysis
    print("Creating correlation analysis...")
    figures['correlation'] = corr_viz.plot_correlation_heatmap_seaborn(
        features, feature_names, output_dir=output_dir
    )

    # 4. PCA analysis
    print("Creating PCA analysis...")
    figures['pca'], pca_results = dim_viz.plot_pca_analysis(
        features, labels, output_dir=output_dir
    )

    # 5. t-SNE embedding
    print("Creating t-SNE embedding...")
    figures['tsne'], _ = dim_viz.plot_tsne_embedding(
        features, labels, output_dir=output_dir
    )

    # 6. UMAP embedding (only recorded when a figure was produced)
    print("Creating UMAP embedding...")
    umap_fig, _ = dim_viz.plot_umap_embedding(
        features, labels, output_dir=output_dir
    )
    if umap_fig:
        figures['umap'] = umap_fig

    # Cast to native Python numbers. numpy integers are not JSON
    # serializable, and a default=str fallback would turn them into strings.
    explained = np.asarray(pca_results['explained_variance_ratio'], dtype=float)
    summary = {
        'n_samples': int(features.shape[0]),
        'n_features': int(features.shape[1]),
        'n_classes': int(len(np.unique(labels))) if labels is not None else 'Unknown',
        # null in the JSON when the computed components never reach 95%.
        'pca_variance_95': _components_for_variance(explained, 0.95),
        'total_variance_explained': float(explained.sum())
    }

    with open(Path(output_dir) / "report_summary.json", 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)

    print(f"✓ Comprehensive visualization report created in {output_dir}")
    print(f"Generated {len(figures)} visualizations")

    return figures


if __name__ == "__main__":
    # Smoke test / demo: run each visualizer on a synthetic dataset, showing
    # the figures without blocking, then write the full report to disk.
    print("Testing Feature Visualizer...")

    # --- Synthetic data ---------------------------------------------------
    np.random.seed(42)  # reproducible draws across runs

    n_samples, n_features = 200, 50

    # Start from an i.i.d. standard-normal baseline, then plant two
    # relationships. Column 0 becomes a noisy copy of column 1, and
    # column 2 becomes a noisy square of the new column 0.
    features = np.random.randn(n_samples, n_features)
    features[:, 0] = features[:, 1] + np.random.normal(0, 0.1, n_samples)
    features[:, 2] = np.square(features[:, 0]) + np.random.normal(0, 0.2, n_samples)

    labels = np.random.choice([0, 1, 2], n_samples)  # three random classes
    feature_names = ["Feature_{}".format(idx) for idx in range(n_features)]

    print(
        f"Generated synthetic data: {features.shape} features, "
        f"{len(np.unique(labels))} classes"
    )

    config = VisualizationConfig(figure_size=(10, 8))

    # --- Individual visualizers -------------------------------------------
    print("\nTesting distribution visualizer...")
    dist_viz = FeatureDistributionVisualizer(config)
    fig1 = dist_viz.plot_feature_distributions(features, feature_names, labels)
    plt.show(block=False)

    print("Testing correlation visualizer...")
    corr_viz = CorrelationVisualizer(config)
    fig2 = corr_viz.plot_correlation_heatmap_seaborn(features, feature_names)
    plt.show(block=False)

    print("Testing dimensionality reduction visualizer...")
    dim_viz = DimensionalityReductionVisualizer(config)
    fig3, pca_results = dim_viz.plot_pca_analysis(features, labels)
    plt.show(block=False)

    print("Testing t-SNE visualization...")
    fig4, tsne_features = dim_viz.plot_tsne_embedding(features, labels, n_iter=500)
    plt.show(block=False)

    # --- End-to-end report ------------------------------------------------
    print("\nCreating comprehensive report...")
    output_dir = "test_visualization_report"
    figures = create_comprehensive_report(
        features, feature_names, labels, output_dir=output_dir, config=config
    )

    print("✓ Feature Visualizer testing completed!")
    print("Generated %d visualization figures" % len(figures))
    print("Report saved in:", output_dir)

NameError: name '__file__' is not defined