In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from hydra import compose, initialize

from sklearn.feature_selection import mutual_info_classif
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from src.data import load_data

# 1. Setup

In [None]:
# Initialise Hydra against the "../config" directory and compose the config
# named "config". Later cells read cfg["data"]["path"] and cfg["data"]["target"].
with initialize(config_path="../config"):
    cfg = compose(config_name="config")

In [None]:
# Base directory that the configured data path is joined onto when loading data.
root_path = Path("..")

### Helper functions

In [None]:
def plot_grid(
    df,
    cols,
    plot_func=None,
    n_cols=3,
    figsize=(12, 4),
    orient="auto",
    order="auto",
    x=None,
    y=None,
    hue=None,
    plot_type=None,  # "kde" or "heatmap_crosstab"; None means "use plot_func"
    **plot_kwargs,
):
    """
    Render one subplot per column of ``cols`` in a grid layout and show it.

    For every column the first matching rule below decides what is drawn:

    1. ``plot_type == "heatmap_crosstab"`` and ``hue`` is set: heatmap of the
       row-normalised crosstab of the column against ``hue``.
    2. ``plot_type == "kde"`` and ``hue`` is set: KDE of the column, one curve
       per ``hue`` level.
    3. Only ``x`` is set: ``plot_func`` with ``x`` fixed and the column on y.
    4. Only ``y`` is set: ``plot_func`` with the column on x and ``y`` fixed.
    5. ``plot_func`` is named ``countplot``: count plot with automatic
       orientation and category ordering.
    6. Otherwise: ``plot_func`` with the column on x.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe.
    cols : iterable
        Columns to plot. Nothing is drawn when it is empty.
    plot_func : callable or None
        Seaborn-style plotting function accepting ``data``, ``x``/``y``,
        ``hue`` and ``ax`` keywords. Not used by rules 1 and 2.
    n_cols : int
        Number of subplot columns.
    figsize : tuple
        (width, height); the height is applied per grid row.
    orient : str
        "auto", "h" or "v". Only used for count plots; "auto" switches to
        horizontal bars when the column has more than 10 distinct values.
    order : str or list-like
        "auto" (order categories by frequency) or an explicit ordering. Only
        used for count plots.
    x, y, hue : str or None
        Column mappings.
    plot_type : str or None
        None for the standard behaviour, "kde" for conditional
        distributions, "heatmap_crosstab" for categorical interactions.
    plot_kwargs :
        Extra keyword arguments forwarded to the plotting call. For the
        special plot types they override the built-in defaults.

    Raises
    ------
    ValueError
        If ``plot_func`` is None and no special ``plot_type`` applies, since
        there would be nothing to draw.
    """
    cols = list(cols)
    if len(cols) == 0:
        return

    # Fail early with a clear message instead of calling None or silently
    # producing a grid of blank axes.
    uses_special_plot = plot_type in ("heatmap_crosstab", "kde") and hue is not None
    if plot_func is None and not uses_special_plot:
        raise ValueError(
            "plot_func is required unless plot_type is 'kde' or "
            "'heatmap_crosstab' and hue is provided"
        )

    n_rows = int(np.ceil(len(cols) / n_cols))
    # squeeze=False always yields a 2D array of axes; without it a 1x1 grid
    # comes back as a bare Axes object that cannot be flattened.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(figsize[0], figsize[1] * n_rows),
        squeeze=False,
    )
    axes = axes.ravel()

    # getattr: wrapped callables such as functools.partial have no __name__.
    is_countplot = getattr(plot_func, "__name__", None) == "countplot"
    # isinstance guard: comparing an array-like order to "auto" is elementwise
    # and has no unambiguous truth value.
    auto_order = isinstance(order, str) and order == "auto"

    for ax, col in zip(axes, cols):
        # --- Special case: heatmap of categorical interactions ---
        if plot_type == "heatmap_crosstab" and hue is not None:
            ct = pd.crosstab(df[col], df[hue], normalize="index")
            # Caller-supplied kwargs win over the defaults instead of
            # colliding with them as duplicate keywords.
            heatmap_kwargs = {"annot": True, "fmt": ".2f", "cmap": "Blues", **plot_kwargs}
            sns.heatmap(ct, ax=ax, **heatmap_kwargs)
            ax.set_title(f"{col} vs {hue}")
            continue

        # --- Special case: conditional distributions (KDE) ---
        if plot_type == "kde" and hue is not None:
            kde_kwargs = {"common_norm": False, "fill": True, **plot_kwargs}
            sns.kdeplot(data=df, x=col, hue=hue, ax=ax, **kde_kwargs)
            ax.set_title(f"Distribution of {col} by {hue}")
            continue

        # --- Special case: one axis pinned to a fixed column ---
        if x is not None and y is None:
            plot_func(data=df, x=x, y=col, ax=ax, hue=hue, **plot_kwargs)
            ax.set_title(f"{col} by {x}")
            continue
        if y is not None and x is None:
            plot_func(data=df, x=col, y=y, ax=ax, hue=hue, **plot_kwargs)
            ax.set_title(f"{y} by {col}")
            continue

        # --- Count plot with dynamic orientation ---
        if is_countplot:
            n_unique = df[col].nunique(dropna=False)
            use_h = (orient == "h") or (orient == "auto" and n_unique > 10)
            ord_vals = df[col].value_counts(dropna=False).index if auto_order else order

            if use_h:
                plot_func(data=df, y=col, order=ord_vals, ax=ax, hue=hue, **plot_kwargs)
                ax.set_xlabel("count")
                ax.set_ylabel(col)
            else:
                plot_func(data=df, x=col, order=ord_vals, ax=ax, hue=hue, **plot_kwargs)
                ax.set_xlabel(col)
                ax.set_ylabel("count")
                # Many categories on the x axis: tilt labels to limit overlap.
                if ord_vals is not None and len(ord_vals) > 8:
                    ax.tick_params(axis="x", rotation=45)
            continue

        # --- General case ---
        plot_func(data=df, x=col, ax=ax, hue=hue, **plot_kwargs)
        ax.set_title(f"{col}")

    # Remove the unused trailing axes of the grid.
    for unused_ax in axes[len(cols):]:
        fig.delaxes(unused_ax)

    plt.tight_layout()
    plt.show()

# 2. Load data

In [None]:
# Load the dataset from the path configured under cfg["data"]["path"], joined
# onto root_path. load_data's result is unpacked as a pair and only the first
# element is kept (presumably the full DataFrame; the discarded second element
# is not shown here -- TODO confirm against src.data.load_data).
df, _ = load_data(root_path / cfg["data"]["path"])

# 3. Data Overview & Quality Checks

### Basic info

In [None]:
# Report the (rows, columns) shape of the loaded dataset.
print(f"Shape of dataset: {df.shape}")

In [None]:
# First five rows, transposed so that wide frames read as one row per column.
df.head().T

In [None]:
# Print the per-column dtype / non-null summary of the dataset.
df.info()

### Descriptive statistics (numeric + categorical)

In [None]:
# Summary statistics split by dtype family. `describe(include=...)` raises a
# ValueError when no column matches the requested dtypes, so each table is only
# rendered when the dataset actually contains columns of that kind.
if df.select_dtypes(include=[np.number]).shape[1] > 0:
    display(df.describe(include=[np.number]).T)

if df.select_dtypes(include=["object", "category"]).shape[1] > 0:
    display(df.describe(include=["object", "category"]).T)

### Duplicates

In [None]:
# Count rows that are exact duplicates of an earlier row.
print(f"Duplicated rows: {df.duplicated().sum()}")

### Cardinality & low variance

In [None]:
# Distinct non-null values per column, in ascending order.
unique_counts = df.nunique().sort_values()
print("High cardinality features:", unique_counts[unique_counts > 50])
# `<= 1` rather than `== 1`: a column that is entirely missing has zero
# distinct values and carries no information either, so it is reported too.
print("Constant features:", unique_counts[unique_counts <= 1])

# 4. Missing Values Analysis

In [None]:
# Per-column missing-value counts and their share of all rows.
missing = df.isnull().sum()
missing_pct = (missing / len(df)) * 100
missing_df = (
    pd.DataFrame({"Missing Count": missing, "Missing %": missing_pct})
    .query("`Missing Count` > 0")
    .sort_values(by="Missing %", ascending=False)
)

# Only columns with at least one missing value are tabulated and plotted.
if not missing_df.empty:
    display(missing_df)
    fig, ax = plt.subplots(figsize=(10, 4))
    sns.barplot(x=missing_df.index, y=missing_df["Missing %"], ax=ax)
    ax.set_title("Missing Values (%)")
    ax.set_xlabel("Column")
    ax.tick_params(axis="x", rotation=45)
    # Finish the figure explicitly, like every other plotting cell here does.
    plt.show()

# 5. Target Variable Analysis

In [None]:
# Name of the target column, taken from the configuration; the analysis cells
# below use it to index `df`.
target = cfg["data"]["target"]

### Distribution of target

In [None]:
# Bar chart of how many rows fall into each target value.
fig, ax = plt.subplots(figsize=(5, 4))
sns.countplot(data=df, x=target, ax=ax)
ax.set_title("Target Variable Distribution")
plt.show()

### Class balance in %

In [None]:
# Share of each target value as a percentage, rounded to two decimals.
print((df[target].value_counts(normalize=True) * 100).round(2))

# 6. Univariate Feature Analysis

In [None]:
# Feature columns by dtype family. The target is removed from BOTH lists: a
# numerically encoded target left inside `numeric_cols` would be selected twice
# wherever later cells add the target explicitly, be correlated with itself,
# and leak into the mutual-information and PCA feature matrices.
numeric_cols = df.select_dtypes(include=np.number).columns.drop(target, errors="ignore")
cat_cols = df.select_dtypes(exclude=np.number).columns.drop(target, errors="ignore")

### Histograms for numeric variables

In [None]:
# One histogram (30 bins, with KDE overlay) per numeric column.
plot_grid(df, cols=numeric_cols, plot_func=sns.histplot, bins=30, kde=True)

### Count plots for categorical variables

In [None]:
# One count plot per non-numeric column.
plot_grid(df, cols=cat_cols, plot_func=sns.countplot)

# 7. Bivariate Analysis

### Correlation matrix

In [None]:
# Feature columns for the heatmap. The target is filtered out first so that a
# numeric target already present in `numeric_cols` is not selected twice.
numeric_plus_target = [col for col in numeric_cols if col != target]

# The correlation input is assembled in a separate frame: `df` is never given
# a temporary helper column, so nothing has to be cleaned up afterwards and an
# exception part-way through the cell cannot leave `df` modified.
corr_input = df[numeric_plus_target].copy()

if pd.api.types.is_numeric_dtype(df[target]):
    corr_input[target] = df[target]
    numeric_plus_target.append(target)
elif df[target].nunique() <= 10:  # categorical target with few classes
    target_codes = pd.factorize(df[target])[0]
    # factorize encodes missing values as -1; turn those back into NaN so they
    # are excluded from the correlation instead of acting as an extra class.
    corr_input["_target_num_"] = np.where(target_codes >= 0, target_codes, np.nan)
    numeric_plus_target.append("_target_num_")
    print(
        f"Target '{target}' convertido a numérico y agregado al heatmap como '_target_num_'"
    )
else:
    print(f"Target '{target}' no agregado al heatmap (demasiados valores únicos).")

corr = corr_input.corr()

plt.figure(figsize=(12, 8))
sns.heatmap(corr, cmap="coolwarm", annot=True, fmt=".2f", vmax=1, vmin=-1)
plt.title("Correlation Matrix (Numeric Features + Target)")
plt.show()

### Correlations with target (numeric, or categorical with at most 10 classes)

In [None]:
# Build a numeric view of the target to correlate the features against.
if pd.api.types.is_numeric_dtype(df[target]):
    target_series = df[target]
elif df[target].nunique() <= 10:  # categorical target with few classes
    target_codes = pd.factorize(df[target])[0]  # simple integer encoding
    # factorize encodes missing values as -1; map them back to NaN so they are
    # skipped by the correlation instead of being treated as a real class.
    target_series = pd.Series(
        np.where(target_codes >= 0, target_codes, np.nan), index=df.index
    )
    print(f"Target '{target}' convertido a numérico para correlación.")
else:
    target_series = None
    print(
        f"Target '{target}' tiene demasiados valores únicos, no se calcularán correlaciones."
    )

if target_series is not None:
    # Never correlate the target with itself: that bar is always 1.0 and only
    # compresses the scale for the real features.
    corr_feature_cols = [col for col in numeric_cols if col != target]
    corr_target = (
        df[corr_feature_cols].corrwith(target_series).sort_values(ascending=False)
    )

    # An empty result (no numeric features) cannot be drawn as a bar chart.
    if not corr_target.empty:
        plt.figure(figsize=(10, 4))
        corr_target.plot(kind="bar")
        plt.title("Correlation with Target")
        plt.ylabel("Pearson correlation")
        plt.show()

### Boxplots: numeric vs target (if categorical target)

In [None]:
# Box plots of every numeric feature split by target value; skipped when the
# target has more than 10 distinct values.
if df[target].nunique() <= 10:
    plot_grid(df, cols=numeric_cols, plot_func=sns.boxplot, x=target)

### Conditional distributions by target (KDE plots)

In [None]:
# Per-target-value KDE curves for every numeric feature; skipped when the
# target has more than 10 distinct values.
if df[target].nunique() <= 10:
    plot_grid(df, cols=numeric_cols, hue=target, plot_type="kde")

### Interactions categorical vs target

In [None]:
# Row-normalised crosstab heatmap of each categorical feature against the
# target. Guarded like the other target-conditioned plots: with a
# high-cardinality target every heatmap would get one annotated column per
# target value and become unreadable.
if df[target].nunique() <= 10:
    plot_grid(df, cat_cols, plot_type="heatmap_crosstab", hue=target)

# 8. Outlier Detection

In [None]:
# One box plot per numeric column to eyeball outliers.
plot_grid(df, cols=numeric_cols, plot_func=sns.boxplot)

# 9. Multivariate Insights

### Pairplot of numeric features + target

In [None]:
# Numeric features plus the target used for colouring. dict.fromkeys removes
# duplicates while keeping order: if the target is itself numeric and part of
# `numeric_cols`, selecting it twice would create two identically named columns
# and make the `hue` lookup ambiguous.
pairplot_cols = list(dict.fromkeys([*numeric_cols, target]))
sns.pairplot(df[pairplot_cols], hue=target)
plt.show()

# 10. Feature Importance (Mutual Information)

In [None]:
# Features for the mutual-information estimate; the target itself is excluded
# so it cannot score against itself.
mi_feature_cols = [col for col in numeric_cols if col != target]

# Drop rows with a missing value in ANY feature or in the target, in a single
# step, so X and y stay aligned and neither contains NaN.
mi_rows = df[mi_feature_cols + [target]].dropna()
X = mi_rows[mi_feature_cols]
y = mi_rows[target]

# At least one feature, one row and two target classes are needed.
# NOTE(review): mutual_info_classif treats the target as class labels; a
# continuous target would call for a regression estimator instead.
if X.shape[1] > 0 and not y.empty and y.nunique() > 1:
    # Fixed random_state: the estimator adds small random noise to continuous
    # features, so without a seed the scores change from run to run.
    mi = mutual_info_classif(X, y, discrete_features=False, random_state=42)
    mi_df = pd.DataFrame({"Feature": X.columns, "MI": mi}).sort_values(
        "MI", ascending=False
    )
    sns.barplot(data=mi_df, x="MI", y="Feature")
    plt.title("Mutual Information with Target")
    plt.show()

# 11. Dimensionality Reduction (PCA)

In [None]:
# Numeric features to project; the target is kept out of the feature matrix.
pca_feature_cols = [col for col in numeric_cols if col != target]

# Boolean mask of rows with no missing feature value. The same mask selects
# both the PCA input and the target values used for colouring, so every
# projected point is paired with the target of the row it came from.
pca_row_mask = df[pca_feature_cols].notna().all(axis=1)
pca_input = df.loc[pca_row_mask, pca_feature_cols]

# A 2-component PCA needs at least two features and two complete rows.
if pca_input.shape[0] >= 2 and pca_input.shape[1] >= 2:
    X_scaled = StandardScaler().fit_transform(pca_input)
    pca = PCA(n_components=2)
    X_pca = pca.fit_transform(X_scaled)

    # Positional assembly (to_numpy) avoids any index-based realignment
    # between the projected coordinates and the target values.
    pca_points = pd.DataFrame(
        {
            "PC1": X_pca[:, 0],
            "PC2": X_pca[:, 1],
            target: df.loc[pca_row_mask, target].to_numpy(),
        }
    )

    plt.figure(figsize=(6, 5))
    sns.scatterplot(data=pca_points, x="PC1", y="PC2", hue=target, palette="Set2")
    plt.title("PCA (2D projection)")
    plt.show()
else:
    print("PCA skipped: at least 2 numeric features and 2 complete rows are required.")