In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats

class DataAnalyzer:
    """Print and plot exploratory statistics for a dataframe with a target column.

    The numerical analysis compares the rows where the target equals 0 with the
    rows where it equals 1, so the t-test is only meaningful for a target
    encoded with those two values.
    """

    def __init__(self, df, target_col, categorical_cols, numerical_cols):
        """
        Initialize the DataAnalyzer with a dataframe and column specifications

        Parameters:
        df (pandas.DataFrame): Input dataframe (a copy is stored, the caller's
            frame is never modified)
        target_col (str): Name of the target column
        categorical_cols (list): List of categorical column names
        numerical_cols (list): List of numerical column names
        """
        self.df = df.copy()
        self.target_col = target_col
        self.categorical_cols = categorical_cols
        self.numerical_cols = numerical_cols

    def get_basic_stats(self):
        """Print the dataset shape, per-column missing-value counts and the
        relative frequency of each target value."""
        print("Dataset Shape:", self.df.shape)
        print("\nMissing Values:\n", self.df.isnull().sum())
        print("\nTarget Distribution:\n", self.df[self.target_col].value_counts(normalize=True))

    def analyze_categorical_vars(self):
        """Analyze categorical variables

        For every column in ``categorical_cols`` this prints the overall
        relative frequencies, the frequencies within each target value, the
        p-value of a chi-square test on the column/target contingency table,
        and shows two count plots (overall and split by target).
        """
        for col in self.categorical_cols:
            print(f"\n=== {col} Analysis ===")

            # Overall distribution
            print("\nOverall Distribution:")
            print(self.df[col].value_counts(normalize=True))

            # Distribution by target (each target column sums to 1)
            print("\nDistribution by Target:")
            print(pd.crosstab(self.df[col], self.df[self.target_col], normalize='columns'))

            # Chi-square test on the raw counts
            contingency_table = pd.crosstab(self.df[col], self.df[self.target_col])
            chi2, p_value = stats.chi2_contingency(contingency_table)[:2]
            print(f"\nChi-square test p-value: {p_value:.4f}")

            # Plotting
            plt.figure(figsize=(10, 5))

            # Overall distribution
            plt.subplot(1, 2, 1)
            sns.countplot(data=self.df, x=col)
            plt.title(f'Overall {col} Distribution')
            plt.xticks(rotation=45)

            # Distribution by target
            plt.subplot(1, 2, 2)
            sns.countplot(data=self.df, x=col, hue=self.target_col)
            plt.title(f'{col} Distribution by Target')
            plt.xticks(rotation=45)

            plt.tight_layout()
            plt.show()

    def _ttest_p_value(self, col):
        """Return the two-sample t-test p-value of ``col`` between target 0 and target 1.

        Missing values in ``col`` are dropped from each group first; passing
        them through would turn the statistic and the p-value into NaN.
        """
        target = self.df[self.target_col]
        group_zero = self.df.loc[target == 0, col].dropna()
        group_one = self.df.loc[target == 1, col].dropna()
        _, p_value = stats.ttest_ind(group_zero, group_one)
        return p_value

    def analyze_numerical_vars(self):
        """Analyze numerical variables

        For every column in ``numerical_cols`` this prints overall and
        per-target descriptive statistics, the correlation with the target,
        the p-value of a t-test between the target==0 and target==1 rows, and
        shows a histogram, a box plot and a violin plot.
        """
        for col in self.numerical_cols:
            print(f"\n=== {col} Analysis ===")

            # Overall statistics
            print("\nOverall Statistics:")
            print(self.df[col].describe())

            # Statistics by target
            print("\nStatistics by Target:")
            print(self.df.groupby(self.target_col)[col].describe())

            # Correlation with target
            correlation = self.df[col].corr(self.df[self.target_col])
            print(f"\nCorrelation with target: {correlation:.4f}")

            # T-test between classes (NaN-safe, see _ttest_p_value)
            p_value = self._ttest_p_value(col)
            print(f"T-test p-value: {p_value:.4f}")

            # Plotting
            plt.figure(figsize=(15, 5))

            # Overall distribution
            plt.subplot(1, 3, 1)
            sns.histplot(data=self.df, x=col)
            plt.title(f'Overall {col} Distribution')

            # Distribution by target
            plt.subplot(1, 3, 2)
            sns.boxplot(data=self.df, x=self.target_col, y=col)
            plt.title(f'{col} Distribution by Target')

            # Violin plot
            plt.subplot(1, 3, 3)
            sns.violinplot(data=self.df, x=self.target_col, y=col)
            plt.title(f'{col} Violin Plot by Target')

            plt.tight_layout()
            plt.show()

    def analyze_all(self):
        """Run all analyses: basic stats, then categorical, then numerical."""
        self.get_basic_stats()
        self.analyze_categorical_vars()
        self.analyze_numerical_vars()

    def compare_datasets(self, other_df, label1="Original", label2="Modified"):
        """Compare two datasets

        Prints, side by side, the shape, target distribution, categorical
        frequencies and numerical summaries of the stored dataframe and
        ``other_df``.

        Parameters:
        other_df (pandas.DataFrame): Dataframe to compare against; it is
            indexed with the same target/categorical/numerical column names
            as the stored dataframe
        label1 (str): Heading used for the stored dataframe
        label2 (str): Heading used for ``other_df``
        """
        print(f"\n=== Comparing {label1} vs {label2} ===")

        # Compare shapes
        print(f"\n{label1} shape:", self.df.shape)
        print(f"{label2} shape:", other_df.shape)

        # Compare target distributions
        print(f"\n{label1} target distribution:")
        print(self.df[self.target_col].value_counts(normalize=True))
        print(f"\n{label2} target distribution:")
        print(other_df[self.target_col].value_counts(normalize=True))

        # Compare categorical variables
        for col in self.categorical_cols:
            print(f"\n{col} distribution comparison:")
            # Use the caller's labels here too, not fixed "Original"/"Modified".
            print(f"\n{label1}:")
            print(self.df[col].value_counts(normalize=True))
            print(f"\n{label2}:")
            print(other_df[col].value_counts(normalize=True))

        # Compare numerical variables
        for col in self.numerical_cols:
            print(f"\n{col} statistics comparison:")
            print(f"\n{label1}:")
            print(self.df[col].describe())
            print(f"\n{label2}:")
            print(other_df[col].describe())

In [None]:
import os
task = 'dep'
folder_path = '/opt/notebooks/sex_balanced'
all_dataframes_dict = {}

# Read every CSV in folder_path whose filename contains the task tag and
# store it under a shortened name.
for file in os.listdir(folder_path):
    # Guard clause: ignore non-CSV files and CSVs belonging to other tasks.
    if not (file.endswith('.csv') and task in file):
        continue

    file_path = os.path.join(folder_path, file)
    try:
        df = pd.read_csv(file_path)
    except Exception as e:
        # Best effort: report the failure and move on to the next file.
        print(f'Failed to load {file_path} due to {e}')
    else:
        # Key = text before the first dot, with the "sex_balanced_" tag removed.
        name = file.split('.')[0].replace("sex_balanced_", "")
        all_dataframes_dict[name] = df

In [None]:
# Display the column labels of `df`. `df` is a notebook-level name, so this is
# whichever DataFrame was bound to it most recently — presumably the last CSV
# read by the loading loop; confirm against the cell execution order.
df.columns

In [None]:
# Column groups passed unchanged to every analyzer.
categorial_cols = ["assessment_centre", "sex"]
numerical_cols = ["bmi", "deprivation_index", "age_at_assessment", "RDS"]

# The dictionary key is used both as the heading and as the target column name
# of the corresponding dataframe.
for key, df in all_dataframes_dict.items():
    print(f"=== {key} Analysis ===")
    analyzer = DataAnalyzer(df, key, categorial_cols, numerical_cols)
    analyzer.analyze_all()
