# 0.1 - Exploratory Data Analysis and Data Cleaning

This notebook serves as the initial step in the project. Its purpose is to:
1.  Load the raw "Adult Income" dataset.
2.  Perform a thorough exploratory data analysis (EDA) to understand its structure, distributions, and relationships.
3.  Identify and handle data quality issues like missing values.
4.  Apply initial cleaning and preprocessing steps.
5.  Save the cleaned, processed dataset to the `data/processed` directory, so all subsequent modeling notebooks can start from a consistent and clean state.


In [None]:
# Library Imports

# Data manipulation
import pandas as pd
import numpy as np
import math
import os

# Dataset loading
from sklearn.datasets import fetch_openml

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# General settings
import warnings
warnings.filterwarnings('ignore')

# Visualization style
sns.set(style='whitegrid')

## 1. Load and Explore the Dataset

In this first step, we will work with the **Adult Income** dataset.

### Specific Goals:
- **Load the dataset** from `OpenML`.
- **Explore and clean the data**, handling missing or inconsistent values.
- **Save the cleaned dataset** to be used by other notebooks.
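
As a rough illustration of what the exploration goal involves before any cleaning, the sketch below builds a per-column overview (dtype, missing count, missing share, number of distinct values). The helper name `summarize_columns` and the toy DataFrame are assumptions made for this example and are not part of the project; on the real data the helper would be called on `df`.

```python
import pandas as pd


def summarize_columns(frame: pd.DataFrame) -> pd.DataFrame:
    """Return one row per column with dtype, missing count/share and distinct values.

    Hypothetical helper, for illustration only.
    """
    return pd.DataFrame({
        "dtype": frame.dtypes.astype(str),
        "n_missing": frame.isna().sum(),
        "pct_missing": (frame.isna().mean() * 100).round(2),
        "n_unique": frame.nunique(dropna=True),
    })


# Toy data standing in for the real dataset (values are made up)
toy = pd.DataFrame({
    "age": [39, 50, 38, 53],
    "workclass": pd.Categorical(["Private", None, "Private", "State-gov"]),
})

summary = summarize_columns(toy)
print(summary)
```

A table like this makes it easier to spot which columns need attention than reading the raw `df.info()` output alone.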


In [None]:
# Load the dataset
# Note: We are not using the load_data function here as this notebook *creates* the processed data
adult_data = fetch_openml(name='adult', version=2, as_frame=True)

# Extract the full DataFrame (features and target together)
df = adult_data.frame

# Show the dimensions and the first few rows
print('DataFrame dimensions:', df.shape)
df.head()

In [None]:
# General dataset info
df.info()

# Descriptive statistics for numerical variables
df.describe()

In [None]:
# Categorical columns
cat_cols = df.select_dtypes(include='category').columns.tolist()
print("Categorical variables:", cat_cols)

# Numerical columns
num_cols = df.select_dtypes(exclude='category').columns.tolist()
print("Numerical variables:", num_cols)

## 2. Data Cleaning

In this section, we will address the data quality issues identified in the initial exploration.

- **Missing Values:** We will review the columns with `NaN` values and apply a strategy to handle them. Since the percentage of missing data is low, we will choose to remove the corresponding rows.
- **Simplify 'native-country'**: The `native-country` variable is heavily dominated by one category (`United-States`). To avoid a long tail of sparsely populated categories, all other countries will be grouped into a single category, `Other`.
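
The two cleaning steps above can be sketched on a small toy DataFrame. The data below is made up for illustration, and the variable names (`toy`, `toy_clean`, `pct_rows_missing`) are assumptions of this example rather than names used elsewhere in the project. The sketch first measures how many rows would be lost, then drops them and groups the countries.

```python
import pandas as pd

# Toy frame standing in for the real data (values are made up)
toy = pd.DataFrame({
    "occupation": pd.Categorical(["Sales", None, "Tech-support", "Sales", "Sales"]),
    "native-country": pd.Categorical(
        ["United-States", "Mexico", "United-States", "India", "United-States"]
    ),
})

# Share of rows that contain at least one missing value
pct_rows_missing = toy.isna().any(axis=1).mean() * 100
print(f"Rows with at least one NaN: {pct_rows_missing:.1f}%")

# Drop incomplete rows; .copy() makes the result an independent DataFrame
toy_clean = toy.dropna().copy()

# Keep 'United-States' and replace every other value with 'Other'
country = toy_clean["native-country"].astype(str)
toy_clean["native-country"] = (
    country.where(country == "United-States", "Other").astype("category")
)

print(toy_clean["native-country"].value_counts())
```

Checking the share of affected rows first is one way to back up the decision to drop rather than impute. Casting back to `category` keeps the column's dtype in line with the other categorical variables.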


In [None]:
# Count missing values per column
print("Columns with NaN values before cleaning:")
print(df.isnull().sum())

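# Drop rows with NaN values (.copy() so later column assignments
# modify an independent DataFrame rather than a view of df)
df_cleaned = df.dropna().copy()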

print("\nDimensions after dropping NaNs:", df_cleaned.shape)

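# Simplify the 'native-country' column.
# apply() on a categorical column can return a plain object column, so cast
# back to 'category' to keep it visible to select_dtypes(include='category')
df_cleaned['native-country'] = (
    df_cleaned['native-country']
    .apply(lambda x: x if x == 'United-States' else 'Other')
    .astype('category')
)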

print("\nUnique values in 'native-country' after simplification:")
print(df_cleaned['native-country'].value_counts())

## 3. Data Visualization

Once the data is clean, we proceed to visualize the distributions of numerical and categorical variables to better understand their characteristics.
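
The plots are easier to interpret alongside the numbers behind them. The following is a minimal sketch on made-up toy data (the names `toy`, `sex_share`, and `hours_stats` are assumptions of this example): relative category frequencies correspond to what a countplot shows, and summary statistics correspond to what a histogram shows.

```python
import pandas as pd

# Toy data standing in for the cleaned dataset (values are made up)
toy = pd.DataFrame({
    "sex": pd.Categorical(["Male", "Female", "Male", "Male"]),
    "hours-per-week": [40, 35, 60, 40],
})

# Relative frequency of each category (a countplot expressed as shares)
sex_share = toy["sex"].value_counts(normalize=True)
print(sex_share)

# Summary statistics for a numerical column (a histogram expressed as numbers)
hours_stats = toy["hours-per-week"].describe()
print(hours_stats[["mean", "min", "max"]])
```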


In [None]:
# Histograms for numerical variables
df_cleaned[num_cols].hist(figsize=(12, 8), bins=30)
plt.suptitle("Distribution of Numerical Variables")
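# Make sure the output directory exists before saving
os.makedirs('../reports/figures', exist_ok=True)
plt.savefig('../reports/figures/0.1_numerical_distribution.png')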
plt.show()

In [None]:
# Countplots for categorical variables

# Total number of categorical variables
cat_cols_cleaned = df_cleaned.select_dtypes(include='category').columns.tolist()
num_vars = len(cat_cols_cleaned)


# Define number of columns and rows for the grid
cols = 3
rows = math.ceil(num_vars / cols)

# Create figure and axes
fig, axes = plt.subplots(rows, cols, figsize=(cols*6, rows*4), constrained_layout=True)

# Flatten axes for easy iteration
axes = axes.flatten()

# Plot each categorical variable in a subplot
for i, col in enumerate(cat_cols_cleaned):
    sns.countplot(data=df_cleaned, y=col, order=df_cleaned[col].value_counts().index, ax=axes[i])
    axes[i].set_title(f"Distribution of '{col}'")

# Remove empty subplots if any (start from num_vars so this does not
# depend on the loop variable from the previous loop)
for j in range(num_vars, len(axes)):
    fig.delaxes(axes[j])

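# Make sure the output directory exists before saving
os.makedirs('../reports/figures', exist_ok=True)
plt.savefig('../reports/figures/0.1_categorical_distribution.png')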
plt.show()

## 4. Save Processed Data

The final step in this notebook is to save the cleaned `DataFrame` to a CSV file. This will allow the modeling notebooks to directly load the processed data, ensuring consistency and avoiding the repetition of cleaning steps.
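
One caveat worth keeping in mind: CSV stores no dtype information, so categorical columns do not come back as `category` unless the reader asks for it. The sketch below shows the round trip on made-up toy data, using an in-memory buffer instead of a file so it stays self-contained. How the project's own loading code handles this is not shown here, so treat the `dtype` argument as one possible approach.

```python
import io

import pandas as pd

# Toy data standing in for the cleaned dataset (values are made up)
toy = pd.DataFrame({
    "age": [39, 50],
    "native-country": pd.Categorical(["United-States", "Other"]),
})

# Write to an in-memory buffer instead of a file on disk
buffer = io.StringIO()
toy.to_csv(buffer, index=False)

# Default read: the column is no longer of dtype 'category'
buffer.seek(0)
plain = pd.read_csv(buffer)
print(plain.dtypes)

# Restore the dtype explicitly when loading
buffer.seek(0)
restored = pd.read_csv(buffer, dtype={"native-country": "category"})
print(restored.dtypes)
```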


In [None]:
# Define the output path
output_path = '../data/processed/adult_cleaned.csv'
os.makedirs(os.path.dirname(output_path), exist_ok=True)


# Save the dataframe
df_cleaned.to_csv(output_path, index=False)

print(f"Cleaned dataset saved to: {output_path}")