Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 94 additions & 27 deletions pyautoplot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pandas as pd
import numpy as np
import os
import warnings
# from statsmodels.tsa.stattools import acf # Removed as _detect_seasonality is removed

class AutoPlot:
"""
Expand Down Expand Up @@ -47,6 +49,11 @@ def _detect_column_types(self):
elif pd.api.types.is_numeric_dtype(self.data[column]):
self.numeric.append(column)

# Check for unclassified columns
for column in self.data.columns:
if column not in self.numeric and column not in self.categorical and column not in self.time_series:
warnings.warn(f"Warning: Column '{column}' was not classified and will be ignored.")

def _apply_theme(self, theme):
"""Apply theme settings for light, dark, or custom themes."""
if theme == "dark":
Expand Down Expand Up @@ -101,8 +108,8 @@ def _generate_analysis(self):
n = len(data_column) # Get the number of valid data points
mean = data_column.mean()
stddev = data_column.std()
skewness = self._calculate_skewness(self, data=data_column, mean=mean, stddev=stddev, n=n)
kurtosis = self._calculate_kurtosis(self, data=data_column, mean=mean, stddev=stddev, n=n)
skewness = self._calculate_skewness(data=data_column, mean=mean, stddev=stddev, n=n)
kurtosis = self._calculate_kurtosis(data=data_column, mean=mean, stddev=stddev, n=n)

stats = {
"Type": "Numeric",
Expand All @@ -119,7 +126,7 @@ def _generate_analysis(self):
"50th Percentile (Median)": data_column.quantile(0.50),
"75th Percentile": data_column.quantile(0.75),
"Outliers": self._detect_outliers(data_column),
"Missing Values": data_column.isnull().sum(),
"Missing Values": self.data[column].isnull().sum(),
}
analysis[column] = stats

Expand All @@ -136,43 +143,57 @@ def _generate_analysis(self):
"Balance Ratio": self._calculate_balance_ratio(self.data[column]),
"Frequency Distribution": value_counts.to_dict(),
"Mode Variability": value_counts[value_counts == value_counts.max()].index.tolist(),
"Missing Values": data_column.isnull().sum(),
"Missing Values": self.data[column].isnull().sum(),
}
analysis[column] = stats

# Time Series Analysis
for column in self.time_series:
data_column = self.data[column].dropna()

# seasonality_result logic removed
autocorrelation_result = [np.nan] * 10 # Default for autocorrelation

if not data_column.empty:
# seasonality_result = self._detect_seasonality(data_column) # Removed
autocorrelation_result = self._calculate_autocorrelation(data_column)

stats = {
"Type": "Time Series",
"Count": len(data_column),
"Min": data_column.min(),
"Max": data_column.max(),
"Mean": data_column.mean(),
"Median": data_column.median(),
"Missing Values": data_column.isnull().sum(),
"Seasonality": self._detect_seasonality(data_column),
"Autocorrelation": self._calculate_autocorrelation(data_column),
"Count": len(data_column), # This will be 0 if empty
"Min": data_column.min() if not data_column.empty else np.nan,
"Max": data_column.max() if not data_column.empty else np.nan,
"Mean": data_column.mean() if not data_column.empty else np.nan,
"Median": data_column.median() if not data_column.empty else np.nan,
"Missing Values": self.data[column].isnull().sum(),
# "Seasonality": seasonality_result, # Removed
"Autocorrelation": autocorrelation_result,
}
analysis[column] = stats

return analysis

@staticmethod
def _calculate_skewness(data, mean, stddev, n):
    """Calculate the skewness (third standardized moment) of a dataset.

    Args:
        data: Numeric values (the caller passes a NaN-free pandas Series;
            a NumPy array also works).
        mean: Precomputed mean of ``data``.
        stddev: Precomputed standard deviation of ``data``.
        n: Number of observations in ``data``.

    Returns:
        ``sum(((x - mean) / stddev) ** 3) / n``, or ``np.nan`` when the
        statistic is undefined (no observations, or a zero/NaN ``stddev``).
    """
    # A NaN stddev (e.g. the sample standard deviation of a single value)
    # would turn every z-score into NaN; np.sum on a pandas Series skips NaN
    # and would then report a misleading 0.0, so bail out explicitly.
    if n == 0 or stddev == 0 or np.isnan(stddev):
        return np.nan
    z_scores = (data - mean) / stddev
    return np.sum(z_scores ** 3) / n

@staticmethod
def _calculate_kurtosis(data, mean, stddev, n):
    """Calculate the excess kurtosis (fourth standardized moment minus 3).

    Args:
        data: Numeric values (the caller passes a NaN-free pandas Series;
            a NumPy array also works).
        mean: Precomputed mean of ``data``.
        stddev: Precomputed standard deviation of ``data``.
        n: Number of observations in ``data``.

    Returns:
        ``sum(((x - mean) / stddev) ** 4) / n - 3``, or ``np.nan`` when the
        statistic is undefined (no observations, or a zero/NaN ``stddev``).
    """
    # A NaN stddev (e.g. the sample standard deviation of a single value)
    # would turn every z-score into NaN; np.sum on a pandas Series skips NaN
    # and the result would silently become -3.0, so bail out explicitly.
    if n == 0 or stddev == 0 or np.isnan(stddev):
        return np.nan
    z_scores = (data - mean) / stddev
    return np.sum(z_scores ** 4) / n - 3

@staticmethod
def _calculate_autocorrelation(self, data):
def _calculate_autocorrelation(data):
"""Calculate autocorrelation at various lags."""
if len(data) < 2 or data.var() == 0:
return [np.nan] * 10 # Return list of NaNs for 10 lags
mean = data.mean()
autocorrelations = []
for lag in range(1, 11): # Calculate autocorrelation for lags 1 through 10
Expand All @@ -181,6 +202,7 @@ def _calculate_autocorrelation(self, data):
autocorrelations.append(correlation)
return autocorrelations

# _detect_seasonality method removed

@staticmethod
def _detect_outliers(series, threshold=3):
Expand Down Expand Up @@ -221,15 +243,15 @@ def _plot_detailed_analysis(self, analysis, output_file=None):

if output_file:
plt.savefig(output_file, dpi=300)
plt.show()
# plt.show() will be called from auto_plot

def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs):
"""
Automatically analyze the dataset and generate visualizations.

This method produces:
1. Detailed analysis summary as a text-based plot.
2. Numeric visualizations: Histograms, Boxplots, and Pairwise Scatter Matrix.
2. Numeric visualizations: Histograms, Boxplots, and Pairwise Scatter Matrix. Note: Pairwise Scatter Matrix can be resource-intensive for datasets with many numeric columns. Consider using the `excludes=['pairwise_scatter']` option for large datasets.
3. Categorical visualizations: Enhanced Bar Plots and Pie Charts.
4. Time-series visualizations: Line and Stacked Area Plots.

Expand All @@ -256,7 +278,9 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs):
# Section 1: Detailed Analysis Summary
if "detailed_analysis" not in excludes:
analysis = self._generate_analysis()
self._plot_detailed_analysis(analysis, output_file=f"{base_filename}_analysis{file_extension}")
self._plot_detailed_analysis(analysis, output_file=f"{base_filename}_analysis{file_extension}" if output_file else None)
plt.show()
plt.close('all')

# Section 2: Numeric Distributions and Boxplots
if self.numeric and "numeric" not in excludes:
Expand All @@ -278,6 +302,8 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs):
if output_file:
# Save Numeric section plots
plt.savefig(f"{base_filename}_numeric{file_extension}", dpi=kwargs.get("dpi", 300))
plt.show()
plt.close('all')

# Section 3: Enhanced Bar Plots for Categorical Variables
if self.categorical and "categorical" not in excludes:
Expand Down Expand Up @@ -308,6 +334,8 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs):
if output_file:
# Save Categorical section plots
plt.savefig(f"{base_filename}_categorical{file_extension}", dpi=kwargs.get("dpi", 300))
plt.show()
plt.close('all')

# Section 4: Pairwise Scatter Plots for Numeric Variables
if len(self.numeric) > 1 and "pairwise_scatter" not in excludes:
Expand All @@ -329,6 +357,8 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs):
if output_file:
# Save Pairwise section plots
plt.savefig(f"{base_filename}_pairwise{file_extension}", dpi=kwargs.get("dpi", 300))
plt.show()
plt.close('all')

# Section 5: Pie Charts for Categorical Variables
if self.categorical and "pie_charts" not in excludes:
Expand All @@ -342,6 +372,8 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs):
if output_file:
# Save Pie section plots
plt.savefig(f"{base_filename}_pie_{cat_column}{file_extension}", dpi=kwargs.get("dpi", 300))
plt.show()
plt.close('all')

# Section 6: Line Plots for Time Series Data
if self.time_series and "line_plots" not in excludes: # Assuming self.time_series contains time series columns
Expand All @@ -354,6 +386,8 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs):
if output_file:
# Save Line plot section
plt.savefig(f"{base_filename}_line_{time_col}{file_extension}", dpi=kwargs.get("dpi", 300))
plt.show()
plt.close('all')

# Section 7: Stacked Area Plots for Time Series Data
if self.time_series and "stacked_area" not in excludes: # If self.time_series contains multiple time series columns for stacking
Expand All @@ -365,9 +399,10 @@ def auto_plot(self, output_file=None, theme="light", excludes=None, **kwargs):
if output_file:
# Save Stacked Area section plot
plt.savefig(f"{base_filename}_stacked_area{file_extension}", dpi=kwargs.get("dpi", 300))

# Show all generated plots
plt.show()
plt.close('all')

# No final plt.show() here anymore

def plot(self, plot_type, x=None, y=None, **kwargs):
"""
Expand All @@ -387,14 +422,46 @@ def plot(self, plot_type, x=None, y=None, **kwargs):
autoplot.plot(plot_type="scatter", x="age", y="income")
autoplot.plot(plot_type="distribution", x="salary", bins=20)
"""
if plot_type == "scatter" and x and y:
if plot_type == "scatter":
if x is None:
raise ValueError("Argument 'x' is required for scatter plot.")
if y is None:
raise ValueError("Argument 'y' is required for scatter plot.")
if x not in self.data.columns:
raise ValueError(f"Column '{x}' not found in dataset.")
if y not in self.data.columns:
raise ValueError(f"Column '{y}' not found in dataset.")
self.data.plot.scatter(x=x, y=y, **kwargs)
elif plot_type == "distribution" and x:
self.data[x].plot(kind="hist", **kwargs)
elif plot_type == "boxplot" and x:
self.data.boxplot(column=x, **kwargs)
elif plot_type == "bar" and x:
self.data[x].value_counts().plot(kind="bar", **kwargs)
elif plot_type in ["distribution", "boxplot", "bar"]:
if x is None:
raise ValueError("Argument 'x' is required for this plot type.")
if x not in self.data.columns:
raise ValueError(f"Column '{x}' not found in dataset.")

if plot_type == "distribution":
self.data[x].plot(kind="hist", **kwargs)
elif plot_type == "boxplot":
self.data.boxplot(column=x, **kwargs)
elif plot_type == "bar":
self.data[x].value_counts().plot(kind="bar", **kwargs)
else:
# Handling other plot types or raising an error for unsupported ones
# For now, let's assume if it's not scatter, distribution, boxplot or bar, it's an issue or needs specific handling
# This part can be expanded based on future plot types.
# If x is generally required for other plots, this check can be more generic.
if x is None: # A more generic check if x is usually required
raise ValueError("Argument 'x' is required for this plot type.")
# If x is provided, but plot_type is unknown
if x is not None and x not in self.data.columns:
raise ValueError(f"Column '{x}' not found in dataset.")
# Defaulting to a simple plot or error if plot_type is not recognized
# For now, just showing plot, which might lead to errors if not configured correctly.
# Consider raising ValueError for unknown plot_type for robustness.
# For example: raise ValueError(f"Plot type '{plot_type}' is not supported.")
# However, current structure seems to rely on specific plot_type checks above.
else:
raise ValueError(f"Plot type '{plot_type}' is not supported or invalid arguments provided.")

plt.show()

def customize(self, **kwargs):
Expand Down
89 changes: 83 additions & 6 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,87 @@
"""Smoke test for PyAutoPlot.

Runs ``auto_plot()``, the manual ``plot()`` types and ``customize()`` against
``energy_consumption_dataset.csv`` (a small dummy dataset is generated when the
file is missing) and writes every figure into ``test_output/``.  The process
exits with a non-zero status when initialization or any step fails.
"""
import os
import shutil
import sys

# Default to a non-interactive backend so the script runs unattended and so
# that plt.savefig() called after AutoPlot.plot() (which itself calls
# plt.show()) still has a live figure to write.  Must happen before pyplot is
# imported; export MPLBACKEND yourself to get interactive windows instead.
os.environ.setdefault("MPLBACKEND", "Agg")

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd

from pyautoplot import AutoPlot

OUTPUT_DIR = "test_output"
DATASET_PATH = "energy_consumption_dataset.csv"


def _ensure_dataset(path):
    """Create a three-day hourly dummy dataset at ``path`` if it is missing."""
    if os.path.exists(path):
        return
    days = 3
    rows = 24 * days
    dummy_data = {
        # One timestamp per hour across three consecutive days.
        'Timestamp': pd.to_datetime(
            [f'2023-01-{d:02} {h:02}:00' for d in range(1, days + 1) for h in range(24)]
        ),
        'Temperature': [10 + i * 0.1 for i in range(rows)],
        'EnergyConsumption': [100 + i for i in range(rows)],
        'DayOfWeek': ['Mon'] * 24 + ['Tue'] * 24 + ['Wed'] * 24,
        'Category': ['A', 'B', 'C'] * 24,
    }
    pd.DataFrame(dummy_data).to_csv(path, index=False)
    # Report only once the file has actually been written.
    print(f"Dummy data created for {path}")


def _run_step(label, step, failures):
    """Run ``step()``; report any exception, record ``label`` and keep going."""
    try:
        step()
    except Exception as e:  # smoke-test boundary: report, record, continue
        print(f"Error {label}: {e}")
        failures.append(label)
    finally:
        # Never let figures from one step leak into the next.
        plt.close('all')


def _manual_plot(plotter, filename, **plot_kwargs):
    """Draw one manual plot and save the current figure into OUTPUT_DIR."""
    plotter.plot(**plot_kwargs)
    plt.savefig(os.path.join(OUTPUT_DIR, filename))


def main():
    """Run every smoke-test step; return 0 on success, 1 on any failure."""
    if os.path.exists(OUTPUT_DIR):
        shutil.rmtree(OUTPUT_DIR)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    print(f"PyAutoPlot Test. Output: {os.path.abspath(OUTPUT_DIR)}")

    _ensure_dataset(DATASET_PATH)

    try:
        plotter = AutoPlot(DATASET_PATH)
    except Exception as e:
        print(f"Init Error: {e}")
        return 1
    print("AutoPlot initialized.")

    failures = []

    print("Testing auto_plot()...")
    _run_step(
        "default auto_plot",
        lambda: plotter.auto_plot(output_file=os.path.join(OUTPUT_DIR, "ap_default.png")),
        failures,
    )
    _run_step(
        "custom auto_plot",
        lambda: plotter.auto_plot(
            output_file=os.path.join(OUTPUT_DIR, "ap_custom.png"),
            theme="dark",
            excludes=['pairwise_scatter'],
        ),
        failures,
    )

    print("Testing manual plot()...")
    n1 = plotter.numeric[0] if plotter.numeric else None
    # With a single numeric column the scatter plots that column against itself.
    n2 = plotter.numeric[1] if len(plotter.numeric) > 1 else n1
    c1 = plotter.categorical[0] if plotter.categorical else None

    if n1 is not None and n2 is not None:
        _run_step(
            "scatter",
            lambda: _manual_plot(plotter, "m_scatter.png",
                                 plot_type="scatter", x=n1, y=n2, title="Scatter"),
            failures,
        )
    if n1 is not None:
        _run_step(
            "distribution",
            lambda: _manual_plot(plotter, "m_dist.png",
                                 plot_type="distribution", x=n1, title="Dist", bins=10),
            failures,
        )
        _run_step(
            "boxplot",
            lambda: _manual_plot(plotter, "m_box.png",
                                 plot_type="boxplot", x=n1, title="Box"),
            failures,
        )
    if c1 is not None:
        _run_step(
            "bar",
            lambda: _manual_plot(plotter, "m_bar.png",
                                 plot_type="bar", x=c1, title="Bar"),
            failures,
        )

    print("Testing customize()...")
    if n1 is not None:
        def customize_step():
            plotter.customize(**{"font.size": 8, "figure.facecolor": "lightyellow"})
            try:
                _manual_plot(plotter, "m_custom_dist.png",
                             plot_type="distribution", x=n1, title="Custom Dist")
            finally:
                # Revert even when plotting fails so later steps see defaults.
                plotter.customize(**matplotlib.rcParamsDefault)

        _run_step("customize", customize_step, failures)

    print("Testing small dataset...")

    def small_dataset_step():
        small_data = {
            'A': [1, 2, 3],
            'B': ['x', 'y', 'z'],
            'C': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03']),
        }
        small_df_path = os.path.join(OUTPUT_DIR, "small_ds.csv")
        pd.DataFrame(small_data).to_csv(small_df_path, index=False)
        small_plotter = AutoPlot(csv_path=small_df_path)
        small_plotter.auto_plot(output_file=os.path.join(OUTPUT_DIR, "ap_small_ds.png"))

    _run_step("small dataset", small_dataset_step, failures)

    print("Test Script Finished.")
    if failures:
        print(f"Failed steps: {', '.join(failures)}")
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())
```