diff --git a/pyautoplot/__init__.py b/pyautoplot/__init__.py index 23c83ef..a33bea6 100644 --- a/pyautoplot/__init__.py +++ b/pyautoplot/__init__.py @@ -2,6 +2,8 @@ import pandas as pd import numpy as np import os +import warnings +# from statsmodels.tsa.stattools import acf # Removed as _detect_seasonality is removed class AutoPlot: """ @@ -47,6 +49,11 @@ def _detect_column_types(self): elif pd.api.types.is_numeric_dtype(self.data[column]): self.numeric.append(column) + # Check for unclassified columns + for column in self.data.columns: + if column not in self.numeric and column not in self.categorical and column not in self.time_series: + warnings.warn(f"Warning: Column '{column}' was not classified and will be ignored.") + def _apply_theme(self, theme): """Apply theme settings for light, dark, or custom themes.""" if theme == "dark": @@ -101,8 +108,8 @@ def _generate_analysis(self): n = len(data_column) # Get the number of valid data points mean = data_column.mean() stddev = data_column.std() - skewness = self._calculate_skewness(self, data=data_column, mean=mean, stddev=stddev, n=n) - kurtosis = self._calculate_kurtosis(self, data=data_column, mean=mean, stddev=stddev, n=n) + skewness = self._calculate_skewness(data=data_column, mean=mean, stddev=stddev, n=n) + kurtosis = self._calculate_kurtosis(data=data_column, mean=mean, stddev=stddev, n=n) stats = { "Type": "Numeric", @@ -119,7 +126,7 @@ def _generate_analysis(self): "50th Percentile (Median)": data_column.quantile(0.50), "75th Percentile": data_column.quantile(0.75), "Outliers": self._detect_outliers(data_column), - "Missing Values": data_column.isnull().sum(), + "Missing Values": self.data[column].isnull().sum(), } analysis[column] = stats @@ -136,43 +143,57 @@ def _generate_analysis(self): "Balance Ratio": self._calculate_balance_ratio(self.data[column]), "Frequency Distribution": value_counts.to_dict(), "Mode Variability": value_counts[value_counts == value_counts.max()].index.tolist(), - "Missing Values": data_column.isnull().sum(), + "Missing Values": self.data[column].isnull().sum(), } analysis[column] = stats # Time Series Analysis for column in self.time_series: data_column = self.data[column].dropna() + + # seasonality_result logic removed + autocorrelation_result = [np.nan] * 10 # Default for autocorrelation + + if not data_column.empty: + # seasonality_result = self._detect_seasonality(data_column) # Removed + autocorrelation_result = self._calculate_autocorrelation(data_column) + stats = { "Type": "Time Series", - "Count": len(data_column), - "Min": data_column.min(), - "Max": data_column.max(), - "Mean": data_column.mean(), - "Median": data_column.median(), - "Missing Values": data_column.isnull().sum(), - "Seasonality": self._detect_seasonality(data_column), - "Autocorrelation": self._calculate_autocorrelation(data_column), + "Count": len(data_column), # This will be 0 if empty + "Min": data_column.min() if not data_column.empty else np.nan, + "Max": data_column.max() if not data_column.empty else np.nan, + "Mean": data_column.mean() if not data_column.empty else np.nan, + "Median": data_column.median() if not data_column.empty else np.nan, + "Missing Values": self.data[column].isnull().sum(), + # "Seasonality": seasonality_result, # Removed + "Autocorrelation": autocorrelation_result, } analysis[column] = stats return analysis @staticmethod - def _calculate_skewness(self, data, mean, stddev, n): + def _calculate_skewness(data, mean, stddev, n): """Calculate skewness of a dataset.""" + if stddev == 0 or n == 0: + return np.nan skewness = np.sum(((data - mean) / stddev) ** 3) / n return skewness @staticmethod - def _calculate_kurtosis(self, data, mean, stddev, n): + def _calculate_kurtosis(data, mean, stddev, n): """Calculate kurtosis of a dataset.""" + if stddev == 0 or n == 0: + return np.nan kurtosis = (np.sum(((data - mean) / stddev) ** 4) / n) - 3 return kurtosis @staticmethod - def _calculate_autocorrelation(self, data): + def _calculate_autocorrelation(data): """Calculate autocorrelation at various lags.""" + if len(data) < 2 or data.var() == 0: + return [np.nan] * 10 # Return list of NaNs for 10 lags mean = data.mean() autocorrelations = [] for lag in range(1, 11): # Calculate autocorrelation for lags 1 through 10 @@ -181,6 +202,7 @@ def _calculate_autocorrelation(self, data): autocorrelations.append(correlation) return autocorrelations + # _detect_seasonality method removed @staticmethod def _detect_outliers(series, threshold=3): @@ -221,7 +243,7 @@ def _plot_detailed_analysis(self, analysis, output_file=None): if output_file: plt.savefig(output_file, dpi=300) - plt.show() + # plt.show() will be called from auto_plot def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs): """ @@ -229,7 +251,7 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs): This method produces: 1. Detailed analysis summary as a text-based plot. - 2. Numeric visualizations: Histograms, Boxplots, and Pairwise Scatter Matrix. + 2. Numeric visualizations: Histograms, Boxplots, and Pairwise Scatter Matrix. Note: Pairwise Scatter Matrix can be resource-intensive for datasets with many numeric columns. Consider using the `excludes=['pairwise_scatter']` option for large datasets. 3. Categorical visualizations: Enhanced Bar Plots and Pie Charts. 4. Time-series visualizations: Line and Stacked Area Plots. @@ -256,7 +278,9 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs): # Section 1: Detailed Analysis Summary if "detailed_analysis" not in excludes: analysis = self._generate_analysis() - self._plot_detailed_analysis(analysis, output_file=f"{base_filename}_analysis{file_extension}") + self._plot_detailed_analysis(analysis, output_file=f"{base_filename}_analysis{file_extension}" if output_file else None) + plt.show() + plt.close('all') # Section 2: Numeric Distributions and Boxplots if self.numeric and "numeric" not in excludes: @@ -278,6 +302,8 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs): if output_file: # Save Numeric section plots plt.savefig(f"{base_filename}_numeric{file_extension}", dpi=kwargs.get("dpi", 300)) + plt.show() + plt.close('all') # Section 3: Enhanced Bar Plots for Categorical Variables if self.categorical and "categorical" not in excludes: @@ -308,6 +334,8 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs): if output_file: # Save Categorical section plots plt.savefig(f"{base_filename}_categorical{file_extension}", dpi=kwargs.get("dpi", 300)) + plt.show() + plt.close('all') # Section 4: Pairwise Scatter Plots for Numeric Variables if len(self.numeric) > 1 and "pairwise_scatter" not in excludes: @@ -329,6 +357,8 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs): if output_file: # Save Pairwise section plots plt.savefig(f"{base_filename}_pairwise{file_extension}", dpi=kwargs.get("dpi", 300)) + plt.show() + plt.close('all') # Section 5: Pie Charts for Categorical Variables if self.categorical and "pie_charts" not in excludes: @@ -342,6 +372,8 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs): if output_file: # Save Pie section plots plt.savefig(f"{base_filename}_pie_{cat_column}{file_extension}", dpi=kwargs.get("dpi", 300)) + plt.show() + plt.close('all') # Section 6: Line Plots for Time Series Data if self.time_series and "line_plots" not in excludes: # Assuming self.time_series contains time series columns @@ -354,6 +386,8 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs): if output_file: # Save Line plot section plt.savefig(f"{base_filename}_line_{time_col}{file_extension}", dpi=kwargs.get("dpi", 300)) + plt.show() + plt.close('all') # Section 7: Stacked Area Plots for Time Series Data if self.time_series and "stacked_area" not in excludes: # If self.time_series contains multiple time series columns for stacking @@ -365,9 +399,10 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs): if output_file: # Save Stacked Area section plot plt.savefig(f"{base_filename}_stacked_area{file_extension}", dpi=kwargs.get("dpi", 300)) - - # Show all generated plots plt.show() + plt.close('all') + + # No final plt.show() here anymore def plot(self, plot_type, x=None, y=None, **kwargs): """ @@ -387,14 +422,46 @@ def plot(self, plot_type, x=None, y=None, **kwargs): autoplot.plot(plot_type="scatter", x="age", y="income") autoplot.plot(plot_type="distribution", x="salary", bins=20) """ - if plot_type == "scatter" and x and y: + if plot_type == "scatter": + if x is None: + raise ValueError("Argument 'x' is required for scatter plot.") + if y is None: + raise ValueError("Argument 'y' is required for scatter plot.") + if x not in self.data.columns: + raise ValueError(f"Column '{x}' not found in dataset.") + if y not in self.data.columns: + raise ValueError(f"Column '{y}' not found in dataset.") self.data.plot.scatter(x=x, y=y, **kwargs) - elif plot_type == "distribution" and x: - self.data[x].plot(kind="hist", **kwargs) - elif plot_type == "boxplot" and x: - self.data.boxplot(column=x, **kwargs) - elif plot_type == "bar" and x: - self.data[x].value_counts().plot(kind="bar", **kwargs) + elif plot_type in ["distribution", "boxplot", "bar"]: + if x is None: + raise ValueError("Argument 'x' is required for this plot type.") + if x not in self.data.columns: + raise ValueError(f"Column '{x}' not found in dataset.") + + if plot_type == "distribution": + self.data[x].plot(kind="hist", **kwargs) + elif plot_type == "boxplot": + self.data.boxplot(column=x, **kwargs) + elif plot_type == "bar": + self.data[x].value_counts().plot(kind="bar", **kwargs) + else: + # Handling other plot types or raising an error for unsupported ones + # For now, let's assume if it's not scatter, distribution, boxplot or bar, it's an issue or needs specific handling + # This part can be expanded based on future plot types. + # If x is generally required for other plots, this check can be more generic. + if x is None: # A more generic check if x is usually required + raise ValueError("Argument 'x' is required for this plot type.") + # If x is provided, but plot_type is unknown + if x is not None and x not in self.data.columns: + raise ValueError(f"Column '{x}' not found in dataset.") + # Defaulting to a simple plot or error if plot_type is not recognized + # For now, just showing plot, which might lead to errors if not configured correctly. + # Consider raising ValueError for unknown plot_type for robustness. + # For example: raise ValueError(f"Plot type '{plot_type}' is not supported.") + # However, current structure seems to rely on specific plot_type checks above. + else: + raise ValueError(f"Plot type '{plot_type}' is not supported or invalid arguments provided.") + plt.show() def customize(self, **kwargs): diff --git a/test.py b/test.py index 3a612d6..e79f189 100644 --- a/test.py +++ b/test.py @@ -1,10 +1,87 @@ +import os +import shutil from pyautoplot import AutoPlot +import pandas as pd +import matplotlib.pyplot as plt -# Initialize with a CSV file -plotter = AutoPlot("energy_consumption_dataset.csv") +OUTPUT_DIR = "test_output" +DATASET_PATH = "energy_consumption_dataset.csv" -# Automatically analyze and plot -plotter.auto_plot(output_file='test', theme="dark", color='orange', excludes=['detailed_analysis']) +if os.path.exists(OUTPUT_DIR): + shutil.rmtree(OUTPUT_DIR) +os.makedirs(OUTPUT_DIR, exist_ok=True) -# Manually plot data -plotter.plot(plot_type="scatter", x="Month", y="Hour") \ No newline at end of file +print(f"PyAutoPlot Test. Output: {os.path.abspath(OUTPUT_DIR)}") + +if not os.path.exists(DATASET_PATH): + print(f"Dummy data created for {DATASET_PATH}") + dummy_data = { + 'Timestamp': pd.to_datetime([f'2023-01-01 {h:02}:00' for h in range(24)] * 3), # 3 days + 'Temperature': [10+i*0.1 for i in range(24*3)], + 'EnergyConsumption': [100+i for i in range(24*3)], + 'DayOfWeek': (['Mon']*24 + ['Tue']*24 + ['Wed']*24), + 'Category': (['A','B','C']*24) + } + pd.DataFrame(dummy_data).to_csv(DATASET_PATH, index=False) + +try: + plotter = AutoPlot(DATASET_PATH) + print("AutoPlot initialized.") +except Exception as e: + print(f"Init Error: {e}"); exit() + +print("Testing auto_plot()...") +try: + plotter.auto_plot(output_file=os.path.join(OUTPUT_DIR, "ap_default.png")) +except Exception as e: print(f"Error default auto_plot: {e}") + +try: + plotter.auto_plot(output_file=os.path.join(OUTPUT_DIR, "ap_custom.png"), theme="dark", excludes=['pairwise_scatter']) +except Exception as e: print(f"Error custom auto_plot: {e}") + +print("Testing manual plot()...") +n1 = plotter.numeric[0] if plotter.numeric else None +n2 = plotter.numeric[1] if len(plotter.numeric) > 1 else n1 +c1 = plotter.categorical[0] if plotter.categorical else None + +if n1 and n2: + try: + plotter.plot(plot_type="scatter", x=n1, y=n2, title="Scatter") + plt.savefig(os.path.join(OUTPUT_DIR, "m_scatter.png")); plt.show(); plt.close('all') + except Exception as e: print(f"Error scatter: {e}") +if n1: + try: + plotter.plot(plot_type="distribution", x=n1, title="Dist", bins=10) + plt.savefig(os.path.join(OUTPUT_DIR, "m_dist.png")); plt.show(); plt.close('all') + except Exception as e: print(f"Error distribution: {e}") + try: + plotter.plot(plot_type="boxplot", x=n1, title="Box") + plt.savefig(os.path.join(OUTPUT_DIR, "m_box.png")); plt.show(); plt.close('all') + except Exception as e: print(f"Error boxplot: {e}") +if c1: + try: + plotter.plot(plot_type="bar", x=c1, title="Bar") + plt.savefig(os.path.join(OUTPUT_DIR, "m_bar.png")); plt.show(); plt.close('all') + except Exception as e: print(f"Error bar: {e}") + +print("Testing customize()...") +if n1: + try: + plotter.customize(**{"font.size": 8, "figure.facecolor": "lightyellow"}) + plotter.plot(plot_type="distribution", x=n1, title="Custom Dist") + plt.savefig(os.path.join(OUTPUT_DIR, "m_custom_dist.png")); plt.show(); plt.close('all') + import matplotlib + plotter.customize(**matplotlib.rcParamsDefault) # Revert + except Exception as e: print(f"Error customize: {e}") + +print("Testing small dataset...") +try: + small_data = {'A': [1,2,3], 'B': ['x','y','z'], 'C': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03'])} + small_df_path = os.path.join(OUTPUT_DIR, "small_ds.csv") + pd.DataFrame(small_data).to_csv(small_df_path, index=False) + small_plotter = AutoPlot(csv_path=small_df_path) + small_plotter.auto_plot(output_file=os.path.join(OUTPUT_DIR, "ap_small_ds.png")) +except Exception as e: print(f"Error small dataset: {e}") + +print("Test Script Finished.") +```