In [103]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import math
import numpy as np

In [104]:
# Load the Iris CSV into a DataFrame. The path is relative, so it resolves
# against the directory the notebook kernel is running in.
iris_df = pd.read_csv('../data/Iris-Dataset.csv')

In [164]:
class EdaVisualization:
    """Small exploratory-data-analysis helper built around one DataFrame.

    Attributes:
        df: Private copy of the frame passed to the constructor. No method of
            this class assigns to it.
        categorized_df: Copy of ``df``; its object-dtype columns are replaced
            by integer category codes once ``mapping()`` has been called.
        correlation_df: Copy of ``df`` that ``correlation_matrix()`` reads.
    """

    # Marker shapes handed to seaborn for the pair plot. The sequence is
    # reused cyclically when the hue column has more levels than shapes.
    _MARKER_CYCLE = ("o", "s", "D", "^", "v", "P", "X", "*")

    def __init__(self, df: pd.DataFrame):
        """Store independent copies of ``df`` so the caller's frame is untouched.

        Args:
            df: The data set to explore.
        """
        self.df = df.copy()
        self.categorized_df = self.df.copy()
        self.correlation_df = self.df.copy()

    def snapshot(self):
        """Return the first rows of the original frame (``DataFrame.head()``)."""
        return self.df.head()

    def describe(self):
        """Print the ``DataFrame.info()`` summary of the original frame.

        Returns:
            Whatever ``DataFrame.info()`` returns. pandas documents that as
            ``None`` because the summary is written to stdout.
        """
        return self.df.info()

    def mapping(self):
        """Encode every object-dtype column as integer category codes.

        The codes are written to ``categorized_df`` only, so ``df`` keeps the
        original labels (``dynamic_pairplot`` relies on them). Each column is
        re-derived from ``df`` on every call, so calling this repeatedly is
        safe. pandas assigns the code ``-1`` to missing values.

        Returns:
            The updated ``categorized_df``.
        """
        for col in self.df.select_dtypes(include=['object']).columns:
            self.categorized_df[col] = self.df[col].astype('category').cat.codes
        return self.categorized_df

    def visualization(self):
        """Plot a histogram with a KDE overlay for each column of ``categorized_df``.

        The plots are laid out on a two-column grid and unused grid cells are
        removed. Call ``mapping()`` first if object-dtype columns should be
        plotted as integer codes; until then ``categorized_df`` is an
        unmodified copy of ``df``. Nothing is drawn for a frame with no columns.
        """
        columns = list(self.categorized_df.columns)
        if not columns:
            # A grid with zero rows cannot be created, so there is nothing to do.
            return

        sns.set(style="whitegrid")

        num_cols = 2
        num_rows = math.ceil(len(columns) / num_cols)

        # squeeze=False always yields a 2-D array of axes, so a single
        # flatten() works for one row as well as for many.
        fig, axes = plt.subplots(
            num_rows, num_cols, figsize=(12, 5 * num_rows), squeeze=False
        )
        axes = axes.flatten()

        for ax, column in zip(axes, columns):
            # str() keeps the title logic working for non-string column labels.
            title = f"{str(column).replace('_', ' ').title()} Distribution"
            sns.histplot(self.categorized_df[column], kde=True, ax=ax).set(title=title)

        # Hide the grid cells that did not receive a plot (odd column count).
        for spare_ax in axes[len(columns):]:
            fig.delaxes(spare_ax)

        plt.tight_layout()
        plt.show()

    def _hue_column(self):
        """Return the label of the first non-numeric column of ``df``, or None."""
        non_numeric = self.df.select_dtypes(exclude='number').columns
        return non_numeric[0] if len(non_numeric) > 0 else None

    def _markers_for(self, hue):
        """Return one marker per distinct non-null value of ``df[hue]``.

        Returns None when ``hue`` is None or the column has no non-null values,
        which lets seaborn fall back to its default marker.
        """
        if hue is None:
            return None
        n_levels = self.df[hue].nunique()
        if n_levels == 0:
            return None
        cycle = self._MARKER_CYCLE
        return [cycle[i % len(cycle)] for i in range(n_levels)]

    def dynamic_pairplot(self):
        """Draw a seaborn pair plot of the original frame.

        The first non-numeric column, if any, is used as the hue. The marker
        list is sized to the number of hue levels instead of being fixed at
        three, so data sets with any number of classes (or none) are handled.
        """
        sns.set(style="whitegrid")

        hue = self._hue_column()

        sns.pairplot(self.df, hue=hue, markers=self._markers_for(hue))
        plt.show()

    def _correlation_table(self):
        """Return the pairwise correlation of the numeric columns of ``correlation_df``."""
        # Selecting numeric columns (rather than dropping the first non-numeric
        # one) works with zero, one or several non-numeric columns.
        return self.correlation_df.select_dtypes(include='number').corr()

    def correlation_matrix(self):
        """Show an annotated heat map of the numeric columns' correlation matrix."""
        correlation_table = self._correlation_table()

        plt.figure(figsize=(8, 6))

        sns.heatmap(correlation_table, annot=True, cmap="coolwarm", fmt='.2f', linewidths=0.5)
        plt.title('Correlation Matrix')
        plt.show()

In [166]:
# Wrap the loaded frame in the EDA helper. The constructor copies the frame,
# so iris_df itself is not modified by the calls below.
visualization = EdaVisualization(iris_df)

In [None]:
# Encode object-dtype columns as integer category codes on the helper's
# categorized copy; the returned frame is shown as this cell's output.
visualization.mapping()

In [None]:
# Histogram + KDE for every column of the categorized frame (run after mapping()).
visualization.visualization()

In [None]:
# Pair plot of the original frame, coloured by its first non-numeric column if one exists.
visualization.dynamic_pairplot()

In [None]:
# Annotated heat map of the correlation matrix.
visualization.correlation_matrix()