# EDA (Exploratory Data Analysis) of the dataset

In this notebook, explore the Abalone dataset, by showing relevant visualizations that help understand the problem you are modelling.

Please make sure to write down your conclusions in the final notebook and to remove these instructions.

# Imports

In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd
pd.set_option('display.max_columns', 500)

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Tuple, Any, Optional

# Data

In [None]:
# Load the Abalone dataset into a DataFrame. The path is relative, so it is
# resolved against the current working directory of the kernel.
df = pd.read_csv("../data/abalone.csv")

# EDA

In [None]:
# Preview the first rows to check column names and value formats.
df.head()

In [None]:
def profile_summary(df: pd.DataFrame, n_top: int = 5) -> pd.DataFrame:
    """Create a concise per-column summary table for a DataFrame.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe to summarise.
    n_top : int, optional
        Unused; kept only so existing callers that pass it keep working, by default 5

    Returns
    -------
    pd.DataFrame
        One row per input column, sorted by ``pct_missing`` (descending; columns with
        equal missingness keep their original relative order), with columns:
        - column: column name
        - dtype: dtype as string
        - n_missing: number of missing values
        - pct_missing: fraction of missing values (0..1)
        - n_unique: number of unique non-null values
        - top: most frequent non-null value (or None)
        - top_count: count of the most frequent value (0 if none)
        If ``df`` has no columns, an empty frame with these columns is returned.

    Raises
    ------
    TypeError
        If ``df`` is not a pandas DataFrame.

    Notes
    -----
    - This function does not modify the input DataFrame.
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError("df must be a pandas DataFrame")

    # Declared up front so that a frame without columns still yields a
    # well-formed (empty) summary instead of failing on the sort key.
    summary_columns = [
        "column", "dtype", "n_missing", "pct_missing", "n_unique", "top", "top_count",
    ]

    rows = []
    for col in df.columns:
        series = df[col]
        is_missing = series.isna()
        non_null = series.dropna()
        if non_null.empty:
            top_val, top_count = None, 0
        else:
            # mode() is sorted, so ties resolve to the smallest value.
            top_val = non_null.mode().iloc[0]
            top_count = int(non_null.value_counts().iloc[0])
        rows.append({
            "column": col,
            "dtype": str(series.dtype),
            "n_missing": int(is_missing.sum()),
            "pct_missing": float(is_missing.mean()),
            "n_unique": int(series.nunique(dropna=True)),
            "top": top_val,
            "top_count": top_count,
        })

    result = pd.DataFrame(rows, columns=summary_columns)
    # A stable sort keeps the original column order among equally-missing columns.
    result = result.sort_values("pct_missing", ascending=False, kind="mergesort")
    return result.reset_index(drop=True)


def missing_summary(df: pd.DataFrame) -> pd.Series:
    """Count missing values per column, keeping only columns that have some.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe.

    Returns
    -------
    pd.Series
        Missing-value counts indexed by column name, largest first. Columns
        without any missing value are left out, so the result may be empty.

    Raises
    ------
    TypeError
        If ``df`` is not a pandas DataFrame.
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError("df must be a pandas DataFrame")

    counts_per_column = df.isna().sum()
    has_missing = counts_per_column > 0
    return counts_per_column[has_missing].sort_values(ascending=False)


def summary_statistics(df: pd.DataFrame, numeric_cols: Optional[List[str]] = None) -> pd.DataFrame:
    """Build a descriptive-statistics table extended with skewness and kurtosis.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe.
    numeric_cols : list[str] | None
        Names of the columns to describe. When None, every numeric column of
        ``df`` is used.

    Returns
    -------
    pd.DataFrame
        One row per column with the ``describe()`` statistics (count, mean, std,
        min, 25%, 50%, 75%, max) followed by 'skew' and 'kurt'. An empty
        DataFrame is returned when there is nothing to describe.

    Raises
    ------
    TypeError
        If ``df`` is not a pandas DataFrame.
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError("df must be a pandas DataFrame")

    selected = numeric_cols
    if selected is None:
        selected = df.select_dtypes(include=[np.number]).columns.tolist()
    if not selected:
        return pd.DataFrame()

    numeric_frame = df[selected]
    # describe() puts the statistics on the rows; transpose so each feature is a row.
    table = numeric_frame.describe().T
    table['skew'] = numeric_frame.skew()
    table['kurt'] = numeric_frame.kurt()
    return table


def plot_numeric_distributions(df: pd.DataFrame,
                               numeric_cols: Optional[List[str]] = None,
                               bins: int = 30,
                               figsize: Tuple[int, int] = (12, 8)) -> None:
    """Draw a histogram (with KDE) and a boxplot for each numeric column.

    Features are laid out on a grid that is at most three features wide. Each
    feature takes two stacked axes: the histogram on top, the boxplot below.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe.
    numeric_cols : list[str] | None
        Columns to plot. When None, all numeric columns of ``df`` are used.
    bins : int
        Number of histogram bins.
    figsize : tuple
        Size of the whole figure (shared by all subplots).

    Returns
    -------
    None
        The figure is shown with ``plt.show()``; nothing is returned. If there
        are no columns to plot, a message is printed instead.

    Notes
    -----
    - With many columns the subplots become small; pass a subset or a larger ``figsize``.
    """
    columns_to_plot = numeric_cols
    if columns_to_plot is None:
        columns_to_plot = df.select_dtypes(include=[np.number]).columns.tolist()
    if not columns_to_plot:
        print("No numeric columns found to plot.")
        return

    n_features = len(columns_to_plot)
    grid_width = min(3, n_features)
    grid_height = int(np.ceil(n_features / grid_width))
    axis_rows = grid_height * 2  # two stacked axes per feature

    _, axes = plt.subplots(axis_rows, grid_width, figsize=figsize)
    # plt.subplots squeezes singleton dimensions; force a 2-D grid for uniform indexing.
    axes = np.array(axes).reshape(axis_rows, grid_width)

    for position, feature in enumerate(columns_to_plot):
        grid_row, grid_col = divmod(position, grid_width)
        hist_ax = axes[grid_row * 2, grid_col]
        box_ax = axes[grid_row * 2 + 1, grid_col]

        sns.histplot(df[feature].dropna(), bins=bins, ax=hist_ax, kde=True)
        hist_ax.set_title(f"Histogram: {feature}")
        sns.boxplot(x=df[feature], ax=box_ax)
        box_ax.set_title(f"Boxplot: {feature}")

    # Grid cells after the last feature have nothing to show: hide both of their axes.
    for position in range(n_features, grid_height * grid_width):
        grid_row, grid_col = divmod(position, grid_width)
        for unused_ax in (axes[grid_row * 2, grid_col], axes[grid_row * 2 + 1, grid_col]):
            unused_ax.set_visible(False)

    plt.tight_layout()
    plt.show()


def correlation_matrix(df: pd.DataFrame,
                       numeric_only: bool = True,
                       figsize: Tuple[int, int] = (10, 8),
                       method: str = 'pearson') -> pd.DataFrame:
    """Compute a correlation matrix and plot it as an annotated heatmap.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe.
    numeric_only : bool
        If True, only the numeric columns of ``df`` are correlated. If False,
        ``df`` is passed to ``DataFrame.corr`` as-is.
    figsize : tuple
        Figure size for the heatmap.
    method : str
        Correlation method passed to `pandas.DataFrame.corr` (e.g., 'pearson', 'spearman').

    Returns
    -------
    pd.DataFrame
        The correlation matrix. When there is nothing to correlate (for example
        no numeric columns), the empty matrix is returned and no figure is drawn.
    """
    # corr() does not modify its input, so no defensive copy is needed here.
    df_num = df.select_dtypes(include=[np.number]) if numeric_only else df
    corr = df_num.corr(method=method)

    # Same convention as the other plotting helpers: with no columns to show,
    # report it and return instead of asking seaborn to draw an empty heatmap.
    if corr.empty:
        print("No numeric columns found to plot.")
        return corr

    plt.figure(figsize=figsize)
    sns.heatmap(corr, annot=True, fmt='.2f', cmap='coolwarm', square=True)
    plt.title('Correlation matrix')
    plt.show()
    return corr


def categorical_summary(df: pd.DataFrame, cat_cols: Optional[List[str]] = None, top_n: int = 10) -> Dict[str, pd.Series]:
    """Collect the value counts of each categorical column.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe.
    cat_cols : list[str] | None
        Columns to treat as categorical. When None, columns of dtype ``object``
        or ``category`` are selected.
    top_n : int
        Not used by the computation: the full counts are always returned, and
        callers may truncate them for display.

    Returns
    -------
    dict
        Mapping column name -> pandas Series of value counts. Missing values
        are counted too (``dropna=False``).
    """
    selected = cat_cols
    if selected is None:
        selected = df.select_dtypes(include=['object', 'category']).columns.tolist()
    return {name: df[name].value_counts(dropna=False) for name in selected}


def detect_outliers(df: pd.DataFrame, col: str, method: str = 'iqr') -> pd.Series:
    """Detect outliers in a numeric column and return a boolean mask.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe.
    col : str
        Column name to check for outliers.
    method : {'iqr', 'zscore'}
        Detection method. 'iqr' flags values outside 1.5*IQR from Q1/Q3. 'zscore' flags |z| > 3,
        where z is computed from the non-missing values using the population standard
        deviation (ddof=0).

    Returns
    -------
    pd.Series
        Boolean mask indexed like `df[col]` where True indicates an outlier.
        Missing values are never flagged.

    Raises
    ------
    KeyError
        If `col` is not in `df`.
    TypeError
        If `col` is not a numeric column.
    ValueError
        If `method` is unsupported.
    """
    if col not in df.columns:
        raise KeyError(f"Column '{col}' not found in dataframe")
    if not pd.api.types.is_numeric_dtype(df[col]):
        raise TypeError(f"Column '{col}' must be numeric for outlier detection")

    series = df[col]
    if method == 'iqr':
        q1 = series.quantile(0.25)
        q3 = series.quantile(0.75)
        iqr = q3 - q1
        lower = q1 - 1.5 * iqr
        upper = q3 + 1.5 * iqr
        # Comparisons with NaN are False, so missing values are not flagged.
        mask = (series < lower) | (series > upper)
    elif method == 'zscore':
        # mean()/std() skip NaN. Filling NaN with the mean first would shrink
        # the standard deviation and inflate every z-score, flagging false outliers.
        centre = series.mean()
        spread = series.std(ddof=0)
        z = (series - centre) / spread
        # NaN z-scores (missing values, or a constant column) compare False.
        mask = z.abs() > 3
    else:
        raise ValueError("method must be 'iqr' or 'zscore'")
    return mask


def top_feature_correlations(df: pd.DataFrame, target_col: str, top_n: int = 5) -> List[Tuple[str, float]]:
    """Return the top features most (absolute) correlated with the numeric target column.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe.
    target_col : str
        Name of the numeric target column.
    top_n : int
        Number of top features to return.

    Returns
    -------
    list of (feature, abs_correlation)
        At most ``top_n`` pairs, strongest first. Correlations are plain Python
        floats. Features whose correlation is undefined (NaN, e.g. constant
        columns) are left out.

    Raises
    ------
    KeyError
        If `target_col` not found.
    TypeError
        If `target_col` is not numeric.
    """
    if target_col not in df.columns:
        raise KeyError(f"{target_col} not in dataframe")
    if not pd.api.types.is_numeric_dtype(df[target_col]):
        raise TypeError(f"{target_col} must be numeric")

    num_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    target_corr = df[num_cols].corr()[target_col].drop(index=target_col)
    # An undefined correlation carries no ranking information, so drop it
    # rather than letting it take up a slot in the result.
    ranked = target_corr.abs().dropna().sort_values(ascending=False)
    # float() turns numpy scalars into builtin floats so the list prints cleanly.
    return [(feature, float(value)) for feature, value in ranked.head(top_n).items()]


def pairplot_preview(df: pd.DataFrame, cols: Optional[List[str]] = None, sample_frac: float = 0.2, height: float = 2.5) -> None:
    """Show a seaborn pairplot of the chosen columns, optionally on a random sample.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe.
    cols : list[str] | None
        Columns to include. When None, all numeric columns of ``df`` are used.
    sample_frac : float
        Fraction of rows to plot. Sampling (with a fixed ``random_state=0``) is
        applied only when ``0 < sample_frac < 1``; any other value, including
        None, plots all rows.
    height : float
        Height parameter passed to seaborn.pairplot.

    Returns
    -------
    None
        The plot is shown with ``plt.show()``. If no columns are selected, a
        message is printed instead.
    """
    selected = cols
    if selected is None:
        selected = df.select_dtypes(include=[np.number]).columns.tolist()
    if not selected:
        print("No columns selected for pairplot")
        return

    plot_data = df[selected]
    use_sample = sample_frac is not None and 0 < sample_frac < 1.0
    if use_sample:
        # Fixed seed so re-running the cell shows the same subset.
        plot_data = plot_data.sample(frac=sample_frac, random_state=0)

    sns.pairplot(plot_data, diag_kind='kde', height=height)
    plt.show()


def target_analysis(df: pd.DataFrame, target_col: str) -> None:
    """
    Summarise a regression target: statistics, distribution and correlations.

    Parameters
    ----------
    df : pd.DataFrame
        Input dataframe containing the target and features.
    target_col : str
        Name of the numeric target column to analyse.

    Returns
    -------
    None
        Output is displayed rather than returned: the target's ``describe()``
        table, a histogram with KDE, and, when other numeric columns exist,
        the correlation of every numeric column with the target, sorted descending.

    Raises
    ------
    KeyError
        If `target_col` is not present in `df`.
    TypeError
        If `target_col` is not numeric.

    Notes
    -----
    - Tables are shown with `display`, which this function does not import; it is
      expected to be available in the notebook environment.
    - The correlation series includes the target's correlation with itself.
    """
    if target_col not in df.columns:
        raise KeyError(f"Target column '{target_col}' not found in dataframe")
    target = df[target_col]
    if not pd.api.types.is_numeric_dtype(target):
        raise TypeError(f"Target column '{target_col}' must be numeric")

    print(f"Target: {target_col}")
    display(target.describe())

    # Distribution of the target
    plt.figure(figsize=(6, 4))
    sns.histplot(target.dropna(), kde=True)
    plt.title(f"Distribution of {target_col}")
    plt.xlabel(target_col)
    plt.show()

    # Linear relationship between the target and the other numeric columns
    all_numeric = df.select_dtypes(include=[np.number]).columns.tolist()
    feature_cols = [name for name in all_numeric if name != target_col]
    if feature_cols:
        corr_with_target = df[feature_cols + [target_col]].corr()[target_col]
        display(corr_with_target.sort_values(ascending=False))


In [None]:
# Runs every EDA helper defined above on `df`, in the order discussed in the
# conclusions. Tables are shown with `display`; plots render inline.

# 1) Profile summary (table)
profile_df = profile_summary(df)
display(profile_df.head(50))

# 2) Missing values per column (an empty Series means nothing is missing)
missing_counts = missing_summary(df)
display(missing_counts)

# 3) Descriptive statistics (with skew/kurtosis)
stats = summary_statistics(df)
display(stats)

# 4) Value counts of the categorical columns
for cat_col, cat_counts in categorical_summary(df).items():
    print(f"Value counts for {cat_col}:")
    display(cat_counts)

# 5) Numeric distributions (plots will render inline)
plot_numeric_distributions(df)

# 6) Correlation matrix (returns DataFrame and plots heatmap)
corr = correlation_matrix(df)
display(corr)

# 7) Target analysis (replace 'Rings' if your target column has another name)
target_col = "Rings"
if target_col in df.columns:
    target_analysis(df, target_col)
    top_feats = top_feature_correlations(df, target_col, top_n=10)
    print("Top correlations with target (abs):", top_feats)
else:
    print(f"Target column '{target_col}' not found; skip target-specific analysis.")

# 8) Pairplot preview for selected numeric columns (sample fraction avoids huge plots)
# np.number matches the numeric selection used by the helper functions above.
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
# keep the pairplot readable: at most the first 4 numeric columns
pair_cols = numeric_cols[:4]
pairplot_preview(df, cols=pair_cols, sample_frac=0.2)

# 9) Outlier detection example on target or another numeric column
example_col = target_col if target_col in df.columns else (numeric_cols[0] if numeric_cols else None)
# Compare against None explicitly: a falsy column label (e.g. 0) is still a valid column.
if example_col is not None:
    mask = detect_outliers(df, col=example_col, method='iqr')
    print(f"Outliers in {example_col}: {int(mask.sum())}")
    display(df.loc[mask])
else:
    print("No numeric column available for outlier detection.")


## Conclusions and next steps

### Function-by-function notes

- `profile_summary(df)`
  - Purpose: quick column-level inventory (dtypes, missing counts, unique counts, most frequent values).
  - How we used it: checked how many columns have missing data and confirmed that most features are numeric. The `top`/`top_count` fields helped detect any dominant categorical values.
  - Impact: had any column shown low missingness, the plan was a conservative imputation (mean for numeric, mode for categorical); since no values are missing (see `missing_summary` below), no imputation is required.

- `missing_summary(df)`
  - Purpose: concise list of columns with missing values and their counts.
  - How we used it: identified which columns need imputation and whether missingness is widespread or limited to a few features.
  - Impact: no missing values were found, so no imputation needs to be considered.

- `summary_statistics(df)`
  - Purpose: descriptive statistics plus skew/kurtosis for numeric features.
  - How we used it: inspected means, medians and skewness to identify non-normal features and potential need for transformations.
  - Impact: noted the target (Rings) and some predictors are skewed — consider transform only if it improves model residuals.

- `plot_numeric_distributions(df, ...)`
  - Purpose: visual check (histograms + boxplots) for numeric features.
  - How we used it: visually inspected distributions and the presence of extreme values.
  - Impact: confirmed presence of many extreme values (outliers) across several numeric columns.

- `correlation_matrix(df)` and `top_feature_correlations(df, target_col)`
  - Purpose: measure linear relationships between numeric features and with the target.
  - How we used it: identified pairs of highly correlated features and ranked predictors by absolute correlation with the target.
  - Impact: discovered `Length` and `Diameter` are highly correlated (near-duplicate information). To avoid redundancy and potential multicollinearity in linear models, we decided to drop `Length` and keep `Diameter` as the representative feature.

- `pairplot_preview(df, cols=...)`
  - Purpose: pairwise scatter and density plots for small sets of features.
  - How we used it: validated linear relationships and non-linear patterns between the target and candidate predictors.
  - Impact: helped confirm the `Length`/`Diameter` redundancy visually.

- `categorical_summary(df)`
  - Purpose: value counts for categorical columns.
  - How we used it: assessed cardinality and whether one-hot or ordinal encoding is appropriate.
  - Impact: low-cardinality categoricals will be one-hot encoded; if high-cardinality appears later we will use target or frequency encoding.

- `detect_outliers(df, col, method='iqr')` and overall outlier handling
  - Purpose: flag values that are extreme according to IQR or z-score rules.
  - How we used it: measured how many rows would be flagged as outliers for each numeric column.
  - Impact (final decision): although many columns contain extreme values, the number of flagged rows is large and appears to reflect the natural variability of the dataset rather than measurement error. Therefore we decided to *keep* outliers for modelling rather than removing them, and to prefer robust models (tree-based) or robust scalers if necessary. If a model performs poorly because of extreme values, we will revisit this decision and consider targeted winsorization or transformations for specific features.


### Concrete preprocessing decisions made from the EDA

1. Keep outliers: many rows are flagged across features, and the distribution of the target/variables suggests the extremes are real observations. We will keep them for now and rely on robust modeling choices (tree-based models, robust scalers) or reconsider targeted clipping if a particular model shows sensitivity.

2. Drop `Length`: `Length` and `Diameter` were found to be highly correlated (almost redundant). To reduce multicollinearity and simplify the feature set we remove `Length` and keep `Diameter` as the representative geometric dimension.

3. Encoding and scaling: one-hot encode low-cardinality categoricals; for numeric features, scale only for models that require it (standard scaling for linear models; tree-based models can be used without scaling).