# In [None]:
import pandas as pd
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np
import matplotlib.pyplot as plt

class DataAnalysis:
    """Exploratory plots for a one-hot encoded, binary-labelled CSV dataset.

    Categorical features are expected as one-hot columns named
    ``"<category>=<value>"``. A ``'class'`` column holds the label; the plots
    treat ``0`` as normal and ``1`` as an anomaly.
    """

    def __init__(self, csv_file):
        """Load *csv_file* and discover its one-hot encoded categories.

        Args:
            csv_file: Path or file-like object accepted by ``pd.read_csv``.
        """
        self.df = pd.read_csv(csv_file)
        # dict.fromkeys de-duplicates while keeping first-seen column order, so
        # the category list (and hence plot order) is the same on every run;
        # a set's iteration order changes with string-hash randomisation.
        self.categories = list(dict.fromkeys(
            col.split('=', 1)[0] for col in self.df.columns if '=' in col
        ))

    def _category_columns(self, category):
        """Return the ``"<category>=..."`` columns, in file order."""
        prefix = category + '='
        return [col for col in self.df.columns if col.startswith(prefix)]

    def _decode_category(self, category):
        """Collapse the one-hot columns of *category* into one label per row.

        Each row gets the value whose column holds the row maximum (the first
        such column on ties). Returns ``None`` if *category* has no columns.
        """
        columns = self._category_columns(category)
        if not columns:
            return None
        # Split only on the first '=' so values containing '=' stay intact.
        return self.df[columns].idxmax(axis=1).map(lambda col: col.split('=', 1)[1])

    def _categorical_column(self, name):
        """Return column *name*, rebuilding it from one-hot columns if absent.

        Raises:
            KeyError: if neither a raw *name* column nor ``"<name>=..."``
                one-hot columns exist.
        """
        if name in self.df.columns:
            return self.df[name]
        decoded = self._decode_category(name)
        if decoded is None:
            raise KeyError(
                f"Column '{name}' not found and no '{name}=...' one-hot columns "
                "to rebuild it from."
            )
        return decoded.rename(name)

    def plot_categories(self):
        """Plot each category's value counts, split into normal vs anomaly.

        Rows with ``class == 0`` are labelled 'Normal' and ``class == 1``
        'Anomalies'; other class values are left out. Values are ordered by
        overall frequency, most frequent first.
        """
        class_labels = self.df['class'].map({0: 'Normal', 1: 'Anomalies'})
        for category in self.categories:
            category_data = self._decode_category(category)
            if category_data is None:
                print(f"Warning: {category} not found in dataframe columns.")
                continue
            category_counts = category_data.value_counts().sort_values(ascending=False)
            plot_df = pd.DataFrame(
                {category: category_data, 'Class': class_labels}
            ).dropna(subset=['Class'])

            # A single hue-grouped countplot draws the two classes side by side
            # on one shared categorical axis, so neither class hides the other.
            plt.figure(figsize=(10, 6))
            sns.countplot(
                data=plot_df,
                x=category,
                hue='Class',
                order=category_counts.index,
                hue_order=['Normal', 'Anomalies'],
                palette={'Normal': 'blue', 'Anomalies': 'red'},
            )
            plt.title(f'Category: {category} (Normal vs Anomalies)')
            plt.xticks(rotation=45)
            plt.show()

    def plot_eda(self):
        """Print per-category value counts and draw the exploratory plots.

        Draws, in order: a correlation heatmap of the numeric columns, age by
        job (box plot), duration by education (violin plot), age vs duration
        coloured by class, and a confusion matrix against placeholder random
        predictions.

        Raises:
            KeyError: if 'job'/'education' are missing both as raw columns and
                as one-hot columns, or 'age'/'duration'/'class' are missing.
        """
        self._print_category_counts()
        self._plot_correlation_heatmap()
        self._plot_feature_relationships()
        self._plot_demo_confusion_matrix()

    def _print_category_counts(self):
        """Print the decoded value counts of every one-hot category."""
        # Decode per category: one idxmax across *all* one-hot columns would keep
        # only the first hot column of each row and ignore every other category.
        for category in self.categories:
            decoded = self._decode_category(category)
            if decoded is not None:
                print(f"{category}:")
                print(decoded.value_counts())

    def _plot_correlation_heatmap(self):
        """Draw an annotated correlation heatmap of the numeric columns."""
        plt.figure(figsize=(12, 8))
        # numeric_only: pandas >= 2.0 raises on non-numeric columns otherwise.
        correlation_matrix = self.df.corr(numeric_only=True)
        sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt='.2f')
        plt.title('Correlation Heatmap')
        plt.show()

    def _plot_feature_relationships(self):
        """Draw the age/job box plot, duration/education violin and age/duration scatter."""
        job_age = pd.DataFrame({'job': self._categorical_column('job'), 'age': self.df['age']})
        plt.figure(figsize=(12, 8))
        sns.boxplot(data=job_age, x='job', y='age')
        plt.title('Boxplot of Age by Job')
        plt.xlabel('Job')
        plt.ylabel('Age')
        plt.xticks(rotation=45)
        plt.show()

        education_duration = pd.DataFrame(
            {'education': self._categorical_column('education'), 'duration': self.df['duration']}
        )
        plt.figure(figsize=(12, 8))
        sns.violinplot(data=education_duration, x='education', y='duration')
        plt.title('Violin Plot of Duration by Education')
        plt.xlabel('Education Level')
        plt.ylabel('Duration')
        plt.xticks(rotation=45)
        plt.show()

        plt.figure(figsize=(10, 6))
        sns.scatterplot(data=self.df, x='age', y='duration', hue='class', palette='coolwarm')
        plt.title('Scatter Plot of Age vs Duration')
        plt.xlabel('Age')
        plt.ylabel('Duration')
        plt.legend(title='Class')
        plt.show()

    def _plot_demo_confusion_matrix(self):
        """Draw a confusion matrix of 'class' against placeholder predictions.

        The predictions are uniform random 0/1 values for demonstration only;
        replace them with a real model's output.
        """
        # A local seeded Generator stays reproducible without resetting
        # NumPy's global random state for the rest of the program.
        rng = np.random.default_rng(42)
        predictions = rng.integers(0, 2, size=len(self.df))

        # labels=[0, 1] keeps the matrix 2x2 even if only one class occurs.
        conf_matrix = confusion_matrix(self.df['class'], predictions, labels=[0, 1])
        plt.figure(figsize=(8, 6))
        sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        plt.show()

if __name__ == '__main__':
    # Script entry point: load the normalised dataset and draw the
    # per-category Normal-vs-Anomalies bar charts.
    data_path = './bank-additional-full_normalised.csv'
    analysis = DataAnalysis(data_path)
    analysis.plot_categories()