# Hybrid Correlation + PMM + MICE Imputer

A missing-data imputation library that combines three techniques:
- **Correlation Analysis**: Identifies optimal predictor sets based on feature correlations
- **PMM (Predictive Mean Matching)**: Preserves data distribution through semi-parametric imputation
- **MICE (Multivariate Imputation by Chained Equations)**: Iteratively refines imputations for better accuracy

## How It Works

1. **Correlation Analysis Phase**: Computes pairwise correlations and identifies strongly correlated features
2. **Predictive Mean Matching Phase**: Fits prediction models and selects from donor pools to preserve distribution
3. **MICE Iteration Phase**: Iteratively imputes each variable using updated values from other variables
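
The three phases fit together as one loop. Below is a deliberately minimal, numpy-only sketch of the chained-equations skeleton. It assumes a purely numeric array and uses deterministic least-squares predictions in place of PMM and correlation-based predictor selection. The function name `mice_linear` is hypothetical and not part of this library.

```python
import numpy as np


def mice_linear(X, n_iter=10, tol=1e-6):
    """Chained-equations sketch: mean start, then per-column least-squares refits.

    X is a 2-D array with np.nan marking missing entries (hypothetical helper).
    """
    X = np.array(X, dtype=float)            # copy so the caller's array is untouched
    missing = np.isnan(X)
    # Start from column means so every regression sees complete inputs
    col_means = np.nanmean(X, axis=0)
    X[missing] = np.take(col_means, np.where(missing)[1])

    for _ in range(n_iter):
        previous = X.copy()
        for j in range(X.shape[1]):
            miss_j = missing[:, j]
            if not miss_j.any():
                continue
            # Design matrix: intercept + current values of every other column
            A = np.column_stack([np.ones(len(X)), np.delete(X, j, axis=1)])
            # Fit only on rows where column j was originally observed
            coef, *_ = np.linalg.lstsq(A[~miss_j], X[~miss_j, j], rcond=None)
            # Re-impute the originally missing rows from the refreshed predictors
            X[miss_j, j] = A[miss_j] @ coef
        if np.max(np.abs(X - previous)) < tol:
            break
    return X


rng = np.random.default_rng(0)
x = rng.normal(size=200)
truth = np.column_stack([
    x,
    2 * x + rng.normal(scale=0.1, size=200),
    -x + rng.normal(scale=0.1, size=200),
])
holes = rng.random(truth.shape) < 0.1       # knock out roughly 10% of the cells
observed = truth.copy()
observed[holes] = np.nan
completed = mice_linear(observed)
```

Real MICE needs randomness so that imputations reflect uncertainty; this library adds it through PMM donor sampling and prediction noise, while the deterministic sketch above only shows the control flow.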

---

## Installation

First, let's install the required dependencies:

In [None]:
# Install required packages
!pip install numpy pandas scikit-learn scipy matplotlib seaborn

## Import Libraries

In [None]:
import numpy as np
import pandas as pd
from typing import List, Dict, Tuple, Optional, Union
import warnings
from sklearn.linear_model import LinearRegression, BayesianRidge
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns

warnings.filterwarnings('ignore')

## 1. Correlation Analyzer Module

This module computes pairwise feature correlations and, for each column, selects as candidate predictors the features whose correlation with it meets a threshold.
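
The selection rule can be sketched with pandas alone. This sketch uses a hypothetical `select_predictors` helper and the 60/40 Pearson/Spearman blend that `CorrelationAnalyzer` falls back to when Kendall's tau is unavailable (pandas needs SciPy for Kendall). It ranks each column's predictors by blended absolute correlation.

```python
import numpy as np
import pandas as pd


def select_predictors(df, threshold=0.3, w_pearson=0.6, w_spearman=0.4):
    """Map each column to the other columns whose blended |correlation| meets threshold."""
    combined = (w_pearson * df.corr(method="pearson").abs()
                + w_spearman * df.corr(method="spearman").abs())
    predictors = {}
    for col in df.columns:
        scores = combined[col].drop(col)                  # ignore self-correlation
        strong = scores[scores >= threshold].sort_values(ascending=False)
        predictors[col] = list(strong.index)              # strongest first
    return predictors


rng = np.random.default_rng(42)
a = rng.normal(size=500)
df = pd.DataFrame({
    "a": a,
    "b": np.exp(a),              # monotone but non-linear in a
    "c": rng.normal(size=500),   # independent noise
})
preds = select_predictors(df)
print(preds)
```

Column `b` is a monotone, non-linear function of `a`, so Spearman's rho is exactly 1 while Pearson's r is lower; the blend keeps the pair above the threshold, whereas the independent column `c` gets no predictors.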

In [None]:
class CorrelationAnalyzer:
    """
    Analyzes correlations between features to determine optimal imputation strategies.
    Uses multiple correlation methods (Pearson, Spearman, Kendall) with weighted scoring
    to identify the most predictive features for each variable with missing values.
    """

    def __init__(self, correlation_threshold: float = 0.3, use_mixed_correlations: bool = True):
        """
        Initialize the correlation analyzer.

        Parameters:
        -----------
        correlation_threshold : float, default=0.3
            Minimum absolute correlation coefficient to consider a feature
            as a potential predictor
        use_mixed_correlations : bool, default=True
            If True, combines Pearson, Spearman, and Kendall correlations
            for more robust correlation estimation
        """
        self.correlation_threshold = correlation_threshold
        self.use_mixed_correlations = use_mixed_correlations
        self.correlation_matrix = None
        self.pearson_matrix = None
        self.spearman_matrix = None
        self.kendall_matrix = None
        self.combined_matrix = None
        self.predictor_sets = {}
        self.correlation_weights = {}

    def fit(self, data: pd.DataFrame) -> 'CorrelationAnalyzer':
        """
        Compute correlation matrix and identify predictor sets.
        Uses multiple correlation methods for robust estimation.

        Parameters:
        -----------
        data : pd.DataFrame
            The dataset to analyze

        Returns:
        --------
        self : CorrelationAnalyzer
            Fitted analyzer
        """
        if self.use_mixed_correlations:
            # Compute multiple correlation matrices
            self.pearson_matrix = data.corr(method='pearson')
            self.spearman_matrix = data.corr(method='spearman')

            # Kendall's tau is slow on large datasets and requires SciPy;
            # fall back to Pearson + Spearman if it cannot be computed
            try:
                self.kendall_matrix = data.corr(method='kendall')
            except Exception:
                self.kendall_matrix = None

            # Combine correlations with weighted average
            # Pearson: 50%, Spearman: 30%, Kendall: 20%
            if self.kendall_matrix is not None:
                self.combined_matrix = (
                    0.5 * self.pearson_matrix.abs() +
                    0.3 * self.spearman_matrix.abs() +
                    0.2 * self.kendall_matrix.abs()
                )
            else:
                # Without Kendall: Pearson 60%, Spearman 40%
                self.combined_matrix = (
                    0.6 * self.pearson_matrix.abs() +
                    0.4 * self.spearman_matrix.abs()
                )

            # Use Pearson for the main correlation matrix (for signed correlations)
            self.correlation_matrix = self.pearson_matrix
        else:
            # Simple Pearson correlation
            self.correlation_matrix = data.corr(method='pearson')
            self.combined_matrix = self.correlation_matrix.abs()

        # For each column, identify highly correlated predictors
        for col in data.columns:
            if self.use_mixed_correlations:
                # Use combined matrix for selection
                correlations_abs = self.combined_matrix[col].drop(col)
                correlations_signed = self.correlation_matrix[col].drop(col)
            else:
                correlations_abs = self.correlation_matrix[col].abs().drop(col)
                correlations_signed = self.correlation_matrix[col].drop(col)

            # Select features with correlation above threshold
            strong_mask = correlations_abs >= self.correlation_threshold
            strong_correlates = correlations_abs[strong_mask].sort_values(ascending=False)

            # Store predictor list
            self.predictor_sets[col] = list(strong_correlates.index)

            # Store correlation weights for each predictor (normalized)
            if len(strong_correlates) > 0:
                # Weights are absolute (combined) correlation strengths, not signed values
                weights = {}
                for pred in strong_correlates.index:
                    # Weight is the combined correlation strength
                    weights[pred] = strong_correlates[pred]

                # Normalize weights to sum to 1
                total_weight = sum(weights.values())
                if total_weight > 0:
                    weights = {k: v / total_weight for k, v in weights.items()}

                self.correlation_weights[col] = weights
            else:
                self.correlation_weights[col] = {}

        return self

    def get_predictors(self, target_column: str, max_predictors: Optional[int] = None) -> List[str]:
        """
        Get the list of best predictor columns for a target column.

        Parameters:
        -----------
        target_column : str
            The column to find predictors for
        max_predictors : int, optional
            Maximum number of predictors to return

        Returns:
        --------
        predictors : List[str]
            List of predictor column names
        """
        if target_column not in self.predictor_sets:
            return []

        predictors = self.predictor_sets[target_column]

        if max_predictors is not None:
            predictors = predictors[:max_predictors]

        return predictors

    def get_correlation_strength(self, col1: str, col2: str) -> float:
        """
        Get the correlation coefficient between two columns.

        Parameters:
        -----------
        col1, col2 : str
            Column names

        Returns:
        --------
        correlation : float
            Pearson correlation coefficient
        """
        if self.correlation_matrix is None:
            raise ValueError("Analyzer not fitted. Call fit() first.")

        return self.correlation_matrix.loc[col1, col2]

    def get_predictor_weights(self, target_column: str) -> Dict[str, float]:
        """
        Get normalized correlation weights for predictors of a target column.

        Parameters:
        -----------
        target_column : str
            The column to get predictor weights for

        Returns:
        --------
        weights : Dict[str, float]
            Dictionary mapping predictor names to their normalized weights
        """
        if target_column not in self.correlation_weights:
            return {}
        return self.correlation_weights[target_column]

    def get_imputation_order(self, columns_with_missing: List[str]) -> List[str]:
        """
        Determine optimal order for imputing columns based on correlations.
        Uses weighted scoring based on both quantity and quality of predictors.

        Parameters:
        -----------
        columns_with_missing : List[str]
            Columns that have missing values

        Returns:
        --------
        ordered_columns : List[str]
            Columns ordered by imputation priority
        """
        # Score each column by the quantity and strength of its predictors
        scores = {}
        for col in columns_with_missing:
            predictors = self.get_predictors(col)

            if predictors:
                # Average raw combined |correlation| with the selected predictors.
                # (The normalized weights always sum to 1, so their mean is just
                # 1 / len(predictors) and says nothing about strength.)
                avg_strength = float(self.combined_matrix.loc[predictors, col].mean())
                # Higher score = more and stronger predictors
                score = len(predictors) * (1 + avg_strength)
            else:
                score = 0

            scores[col] = score

        # Sort by score (ascending): columns with fewer or weaker predictors are
        # imputed first, so better-supported columns are imputed later, once more
        # of their predictors already hold refined values
        ordered = sorted(scores.items(), key=lambda x: x[1])

        return [col for col, _ in ordered]

    def visualize_correlations(self, figsize: Tuple[int, int] = (12, 10)):
        """
        Create a heatmap visualization of the correlation matrix.

        Parameters:
        -----------
        figsize : Tuple[int, int]
            Figure size for the plot
        """
        plt.figure(figsize=figsize)
        sns.heatmap(
            self.correlation_matrix,
            annot=True,
            cmap='coolwarm',
            center=0,
            vmin=-1,
            vmax=1,
            fmt='.2f'
        )
        plt.title('Feature Correlation Matrix')
        plt.tight_layout()
        plt.show()
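
`get_imputation_order` scores each column by its number of predictors, scaled by how strong they are, and imputes low-scoring columns first. Below is a plain-Python sketch of that scoring, assuming the inputs are raw absolute correlations; the helper `imputation_order` and the column names are hypothetical.

```python
def imputation_order(strengths):
    """Order columns for imputation, weakest-supported first.

    strengths maps each column with missing data to {predictor: |correlation|}
    for its selected predictors (hypothetical input format).
    """
    scores = {}
    for col, preds in strengths.items():
        if preds:
            avg_strength = sum(preds.values()) / len(preds)
            scores[col] = len(preds) * (1 + avg_strength)
        else:
            scores[col] = 0.0                  # no predictors: impute first
    # sorted() is stable, so ties keep their input order
    return [col for col, _ in sorted(scores.items(), key=lambda kv: kv[1])]


order = imputation_order({
    "age": {"weight": 0.4},                    # 1 * (1 + 0.40) = 1.4
    "bmi": {"weight": 0.9, "waist": 0.8},      # 2 * (1 + 0.85) = 3.7
    "income": {},                              # 0
})
print(order)  # ['income', 'age', 'bmi']
```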

## 2. Predictive Mean Matching (PMM) Imputer Module

Implements the PMM algorithm, a semi-parametric method that fills each missing value with a value actually observed in the data, which preserves the original distribution.
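
The four steps listed in the `PMMImputer` docstring can be seen in isolation in a textbook version of PMM. This numpy sketch (hypothetical `pmm_impute` helper) fits ordinary least squares and samples uniformly among the `k` closest donors. It deliberately omits the prediction noise, predictor weighting, and blended distance that the class below adds.

```python
import numpy as np


def pmm_impute(X, y, k=5, seed=None):
    """Fill NaNs in y by predictive mean matching on complete predictors X."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float).copy()
    miss = np.isnan(y)
    A = np.column_stack([np.ones(len(X)), X])          # add an intercept column
    # 1. Fit a linear model on the observed rows
    coef, *_ = np.linalg.lstsq(A[~miss], y[~miss], rcond=None)
    # 2. Predict for every row, observed and missing
    pred = A @ coef
    y_obs, pred_obs = y[~miss], pred[~miss]
    for i in np.flatnonzero(miss):
        # 3. The k observed rows whose predictions are closest to this row's
        donors = np.argsort(np.abs(pred_obs - pred[i]))[:k]
        # 4. Copy the real observed value of one randomly chosen donor
        y[i] = y_obs[rng.choice(donors)]
    return y


rng = np.random.default_rng(1)
X = rng.normal(size=(100, 2))
y_true = 3 * X[:, 0] - X[:, 1] + rng.normal(scale=0.5, size=100)
y_missing = y_true.copy()
y_missing[rng.choice(100, size=20, replace=False)] = np.nan
y_filled = pmm_impute(X, y_missing, k=5, seed=0)
```

Because every imputed value is copied from a donor, imputations never fall outside the observed range, which is the distribution-preserving property the section describes.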

In [None]:
class PMMImputer:
    """
    Predictive Mean Matching (PMM) imputer.

    PMM is a semi-parametric imputation method that:
    1. Fits a prediction model on observed data
    2. Predicts values for missing data
    3. Finds observed values with similar predictions
    4. Randomly selects from these "donor" values

    This preserves the distribution of the original data better than
    simple regression imputation.
    """

    def __init__(
        self,
        n_neighbors: int = 5,
        model_type: str = 'linear',
        random_state: Optional[int] = None
    ):
        """
        Initialize the PMM imputer.

        Parameters:
        -----------
        n_neighbors : int, default=5
            Number of nearest neighbors to consider for donor pool
        model_type : str, default='linear'
            Type of prediction model: 'linear', 'bayesian', or 'rf' (random forest)
        random_state : int, optional
            Random state for reproducibility
        """
        self.n_neighbors = n_neighbors
        self.model_type = model_type
        self.random_state = random_state
        self.model = None
        self.scaler = StandardScaler()

        # Initialize the prediction model
        if model_type == 'linear':
            self.model = LinearRegression()
        elif model_type == 'bayesian':
            self.model = BayesianRidge()
        elif model_type == 'rf':
            self.model = RandomForestRegressor(
                n_estimators=100,
                random_state=random_state
            )
        else:
            raise ValueError(f"Unknown model_type: {model_type}")

        self.rng = np.random.RandomState(random_state)

    def fit_transform(
        self,
        data: pd.DataFrame,
        target_column: str,
        predictor_columns: List[str],
        predictor_weights: Optional[Dict[str, float]] = None
    ) -> np.ndarray:
        """
        Impute missing values in the target column using enhanced PMM.

        Parameters:
        -----------
        data : pd.DataFrame
            The dataset
        target_column : str
            Column to impute
        predictor_columns : List[str]
            Columns to use as predictors
        predictor_weights : Dict[str, float], optional
            Weights for each predictor (correlation strengths)

        Returns:
        --------
        imputed_values : np.ndarray
            The imputed column (complete, no missing values)
        """
        # Separate observed and missing data
        target = data[target_column]
        predictors = data[predictor_columns]

        # Use only rows whose predictors are complete; rows with a missing
        # predictor are left unimputed by this call
        complete_mask = ~predictors.isnull().any(axis=1)
        observed_mask = ~target.isnull() & complete_mask
        missing_mask = target.isnull() & complete_mask

        if missing_mask.sum() == 0:
            # No missing values to impute
            return target.values

        if observed_mask.sum() < self.n_neighbors:
            # Not enough observed data for PMM, fall back to mean imputation
            mean_value = target[observed_mask].mean()
            result = target.copy()
            result[missing_mask] = mean_value
            return result.values

        # Get observed and missing predictor matrices
        X_observed = predictors[observed_mask]
        X_missing = predictors[missing_mask]
        y_observed = target[observed_mask].values

        # Create weight vector for predictors
        if predictor_weights:
            weight_vector = np.array([predictor_weights.get(col, 1.0) for col in predictor_columns])
            # Rescale so the weights average 1 (they sum to the number of predictors)
            weight_vector = weight_vector / weight_vector.sum() * len(weight_vector)
        else:
            weight_vector = np.ones(len(predictor_columns))

        # Standardize the predictors (scaler fitted on observed rows only)
        X_observed_scaled = self.scaler.fit_transform(X_observed)
        X_missing_scaled = self.scaler.transform(X_missing)

        # Apply weights to scaled predictors
        X_observed_weighted = X_observed_scaled * weight_vector
        X_missing_weighted = X_missing_scaled * weight_vector

        # Fit the prediction model
        self.model.fit(X_observed_weighted, y_observed)

        # Predict for both observed and missing rows
        y_observed_pred = self.model.predict(X_observed_weighted)
        y_missing_pred = self.model.predict(X_missing_weighted)

        # Add stochastic component based on residual standard deviation
        residuals = y_observed - y_observed_pred
        residual_std = np.std(residuals)

        # Add noise to missing predictions (stochastic regression component)
        if residual_std > 0:
            noise = self.rng.normal(0, residual_std * 0.5, len(y_missing_pred))
            y_missing_pred_noisy = y_missing_pred + noise
        else:
            y_missing_pred_noisy = y_missing_pred

        # For each missing value, find donors and select one
        imputed_values = np.zeros(missing_mask.sum())

        for i, (pred_value, x_missing) in enumerate(zip(y_missing_pred_noisy, X_missing_weighted)):
            # Compute combined distance: prediction space + predictor space
            # Distance in prediction space (primary)
            pred_distances = np.abs(y_observed_pred - pred_value)

            # Distance in predictor space (secondary, for tie-breaking)
            predictor_distances = np.sqrt(np.sum((X_observed_weighted - x_missing) ** 2, axis=1))

            # Normalize both distances
            pred_dist_norm = pred_distances / (np.std(pred_distances) + 1e-10)
            predictor_dist_norm = predictor_distances / (np.std(predictor_distances) + 1e-10)

            # Combined distance: 70% prediction, 30% predictor space
            combined_distances = 0.7 * pred_dist_norm + 0.3 * predictor_dist_norm

            # Find the k nearest neighbors
            n_donors = min(self.n_neighbors, len(combined_distances))
            nearest_indices = np.argpartition(combined_distances, n_donors - 1)[:n_donors]

            # Weighted random selection from donors
            # Donors closer in distance have higher probability
            donor_distances = combined_distances[nearest_indices]
            # Convert distances to weights (inverse distance)
            donor_weights = 1.0 / (donor_distances + 0.01)  # Add small constant to avoid division by zero
            donor_weights = donor_weights / donor_weights.sum()

            # Select donor with weighted probability
            donor_idx = self.rng.choice(nearest_indices, p=donor_weights)
            imputed_values[i] = y_observed[donor_idx]

        # Create result array
        result = target.copy()
        result[missing_mask] = imputed_values

        return result.values

    def impute_column(
        self,
        data: pd.DataFrame,
        target_column: str,
        predictor_columns: Optional[List[str]] = None,
        predictor_weights: Optional[Dict[str, float]] = None
    ) -> pd.Series:
        """
        Convenience method to impute a single column and return as Series.

        Parameters:
        -----------
        data : pd.DataFrame
            The dataset
        target_column : str
            Column to impute
        predictor_columns : List[str], optional
            Columns to use as predictors. If None, uses all other columns.
        predictor_weights : Dict[str, float], optional
            Weights for each predictor (correlation strengths)

        Returns:
        --------
        imputed_column : pd.Series
            The imputed column
        """
        if predictor_columns is None:
            predictor_columns = [col for col in data.columns if col != target_column]

        imputed_values = self.fit_transform(data, target_column, predictor_columns, predictor_weights)
        return pd.Series(imputed_values, index=data.index, name=target_column)


class AdaptivePMMImputer(PMMImputer):
    """
    Adaptive PMM imputer that adjusts n_neighbors based on data availability.
    """

    def __init__(
        self,
        n_neighbors: int = 5,
        min_neighbors: int = 3,
        model_type: str = 'linear',
        random_state: Optional[int] = None
    ):
        """
        Initialize adaptive PMM imputer.

        Parameters:
        -----------
        n_neighbors : int, default=5
            Target number of nearest neighbors
        min_neighbors : int, default=3
            Minimum number of neighbors to use
        model_type : str, default='linear'
            Type of prediction model
        random_state : int, optional
            Random state for reproducibility
        """
        super().__init__(n_neighbors, model_type, random_state)
        self.min_neighbors = min_neighbors

    def fit_transform(
        self,
        data: pd.DataFrame,
        target_column: str,
        predictor_columns: List[str],
        predictor_weights: Optional[Dict[str, float]] = None
    ) -> np.ndarray:
        """
        Impute with adaptive neighbor selection.
        """
        target = data[target_column]
        predictors = data[predictor_columns]

        complete_mask = ~predictors.isnull().any(axis=1)
        observed_mask = ~target.isnull() & complete_mask
        n_observed = observed_mask.sum()

        # Adapt n_neighbors based on available data
        if n_observed < self.min_neighbors:
            # Fall back to mean imputation
            mean_value = target[observed_mask].mean()
            result = target.copy()
            result[target.isnull()] = mean_value
            return result.values

        # Temporarily adjust n_neighbors to the available data
        original_n = self.n_neighbors
        self.n_neighbors = min(self.n_neighbors, max(self.min_neighbors, n_observed // 3))

        try:
            # Call parent fit_transform
            result = super().fit_transform(data, target_column, predictor_columns, predictor_weights)
        finally:
            # Restore original n_neighbors even if imputation raises
            self.n_neighbors = original_n

        return result
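
Two details of the PMM code are easy to miss. First, `AdaptivePMMImputer` sets the donor pool to `n_observed // 3`, never below `min_neighbors` and never above `n_neighbors`. Second, `PMMImputer.fit_transform` ranks candidates by a 70/30 blend of prediction-space and predictor-space distance, then samples one donor with inverse-distance probabilities. Below is a standalone sketch of both; the helper names `adaptive_k` and `draw_donor` are hypothetical.

```python
import numpy as np


def adaptive_k(n_neighbors, min_neighbors, n_observed):
    """Donor-pool size rule; None means too few observations, use the mean instead."""
    if n_observed < min_neighbors:
        return None
    return min(n_neighbors, max(min_neighbors, n_observed // 3))


def draw_donor(pred_obs, X_obs, pred_i, x_i, k, rng):
    """Return the index of one donor row for a single missing row."""
    pred_d = np.abs(pred_obs - pred_i)                 # distance in prediction space
    feat_d = np.linalg.norm(X_obs - x_i, axis=1)       # Euclidean distance in predictor space
    # Scale each distance by its spread, then blend 70/30
    blended = 0.7 * pred_d / (pred_d.std() + 1e-10) + 0.3 * feat_d / (feat_d.std() + 1e-10)
    k = min(k, len(blended))
    nearest = np.argpartition(blended, k - 1)[:k]      # the k closest candidates
    w = 1.0 / (blended[nearest] + 0.01)                # closer donors get more weight
    return rng.choice(nearest, p=w / w.sum())


rng = np.random.default_rng(0)
pred_obs = np.array([0.0, 0.1, 5.0, 9.0])
X_obs = pred_obs.reshape(-1, 1)
donor = draw_donor(pred_obs, X_obs, 0.05, np.array([0.05]), k=2, rng=rng)
```

With `n_neighbors=10` and `min_neighbors=3`, 12 observed rows give a pool of 4 donors and 300 observed rows give 10.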

## 3. Hybrid MICE Imputer Module

Combines correlation analysis, PMM, and MICE into a single iterative imputer.
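
The MICE loop needs a way to tell when imputations have stabilized; `convergence_threshold` is the cutoff for that check. `_compute_convergence_metric` takes, for each column, the mean absolute change in its imputed cells divided by a robust scale (1.4826 × MAD, falling back to the standard deviation, then to the raw change), and blends the per-column results as 0.7 × mean + 0.3 × max. Below is a numpy sketch of the same formula on plain arrays; `convergence_score` is a hypothetical name.

```python
import numpy as np


def convergence_score(current, previous, missing):
    """Blend of mean and max per-column change in originally missing cells.

    current, previous : 2-D float arrays from consecutive iterations
    missing           : boolean mask of originally missing cells
    """
    changes = []
    for j in range(current.shape[1]):
        m = missing[:, j]
        if not m.any():
            continue
        diff = np.abs(current[m, j] - previous[m, j])
        col = current[:, j]
        mad = np.median(np.abs(col - np.median(col)))
        if mad > 0:
            scale = 1.4826 * mad              # MAD-based estimate of the standard deviation
        else:
            sd = col.std(ddof=1)
            scale = sd if sd > 0 else 1.0     # last resort: raw absolute change
        changes.append(np.mean(diff / scale))
    if not changes:
        return 0.0
    return 0.7 * np.mean(changes) + 0.3 * np.max(changes)


prev = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]])
curr = prev.copy()
curr[0, 0] = 1.5                              # one imputed cell moved by 0.5
miss = np.zeros(prev.shape, dtype=bool)
miss[0, 0] = True
miss[3, 1] = True
score = convergence_score(curr, prev, miss)
```

The max term stops a single slow-moving column from hiding behind many columns that have already settled.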

In [None]:
class HybridMICEImputer:
    """
    Hybrid imputation model combining:
    - Correlation Analysis: To identify optimal predictor sets
    - PMM (Predictive Mean Matching): For distribution-preserving imputation
    - MICE (Multivariate Imputation by Chained Equations): For iterative refinement

    This hybrid approach:
    1. Uses correlation analysis to determine which features best predict each missing variable
    2. Employs PMM to impute values while preserving data distribution
    3. Iterates using MICE framework to refine imputations based on newly imputed values
    """

    def __init__(
        self,
        n_iterations: int = 15,
        n_neighbors: int = 10,
        correlation_threshold: float = 0.25,
        max_predictors: int = 15,
        pmm_model_type: str = 'bayesian',
        convergence_threshold: float = 0.001,
        random_state: Optional[int] = None,
        verbose: bool = False,
        exclude_columns: Optional[List[str]] = None,
        use_mixed_correlations: bool = True
    ):
        """
        Initialize the Hybrid MICE imputer with enhanced correlation and PMM.

        Parameters:
        -----------
        n_iterations : int, default=15
            Maximum number of MICE iterations
        n_neighbors : int, default=10
            Number of neighbors (donor pool size) for PMM
        correlation_threshold : float, default=0.25
            Minimum combined absolute correlation for a feature to be selected as a predictor
        max_predictors : int, default=15
            Maximum number of predictors to use per variable
        pmm_model_type : str, default='bayesian'
            Model type for PMM: 'linear', 'bayesian', or 'rf'
        convergence_threshold : float, default=0.001
            Threshold for convergence detection
        random_state : int, optional
            Random state for reproducibility
        verbose : bool, default=False
            Whether to print progress information
        exclude_columns : List[str], optional
            Columns to exclude from imputation (e.g., ID columns like 'ptid')
        use_mixed_correlations : bool, default=True
            Use mixed correlation methods (Pearson, Spearman, Kendall) for robust estimation
        """
        self.n_iterations = n_iterations
        self.n_neighbors = n_neighbors
        self.correlation_threshold = correlation_threshold
        self.max_predictors = max_predictors
        self.pmm_model_type = pmm_model_type
        self.convergence_threshold = convergence_threshold
        self.random_state = random_state
        self.verbose = verbose
        self.exclude_columns = exclude_columns or []
        self.use_mixed_correlations = use_mixed_correlations

        # Components
        self.correlation_analyzer = CorrelationAnalyzer(
            correlation_threshold=correlation_threshold,
            use_mixed_correlations=use_mixed_correlations
        )
        self.imputer = AdaptivePMMImputer(
            n_neighbors=n_neighbors,
            model_type=pmm_model_type,
            random_state=random_state
        )

        # State
        self.missing_indicators = None
        self.columns_with_missing = []
        self.numeric_columns_with_missing = []
        self.imputation_order = []
        self.convergence_history = []

    def _identify_missing_data(self, data: pd.DataFrame) -> None:
        """Identify columns with missing data and create missing indicators."""
        self.missing_indicators = data.isnull()
        self.columns_with_missing = [
            col for col in data.columns if self.missing_indicators[col].any()
        ]

        # Filter to only numeric columns for PMM imputation
        # Exclude non-numeric columns and any user-specified exclusions
        self.numeric_columns_with_missing = [
            col for col in self.columns_with_missing
            if col not in self.exclude_columns
            and data[col].dtype in [np.float64, np.float32, np.int64, np.int32, np.float16, np.int16, np.int8]
        ]

        if self.verbose:
            print(f"Columns with missing data: {len(self.columns_with_missing)}")
            print(f"Numeric columns to impute with PMM: {len(self.numeric_columns_with_missing)}")
            if len(self.columns_with_missing) > len(self.numeric_columns_with_missing):
                excluded = set(self.columns_with_missing) - set(self.numeric_columns_with_missing)
                print(f"Excluded columns (non-numeric or specified): {list(excluded)}")
            print()
            for col in self.columns_with_missing:
                n_missing = self.missing_indicators[col].sum()
                pct_missing = 100 * n_missing / len(data)
                col_type = "numeric (PMM)" if col in self.numeric_columns_with_missing else "excluded"
                print(f"  {col}: {n_missing} ({pct_missing:.2f}%) - {col_type}")

    def _initialize_imputation(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Initialize missing values with simple mean imputation.
        This provides a starting point for the MICE iterations.
        """
        imputed = data.copy()

        for col in self.columns_with_missing:
            if imputed[col].dtype in [np.float64, np.float32, np.float16, np.int64, np.int32, np.int16, np.int8]:
                # Numerical: use mean (same dtype list used to select PMM columns)
                mean_val = imputed[col].mean()
                imputed.loc[:, col] = imputed[col].fillna(mean_val)
            else:
                # Categorical: use mode (stays NaN if the column is entirely missing)
                modes = imputed[col].mode()
                mode_val = modes.iloc[0] if len(modes) > 0 else imputed[col].iloc[0]
                imputed.loc[:, col] = imputed[col].fillna(mode_val)

        return imputed

    def _compute_convergence_metric(
        self,
        data_current: pd.DataFrame,
        data_previous: pd.DataFrame
    ) -> float:
        """
        Compute convergence metric based on change in imputed values.
        Only considers numeric columns that are being imputed.
        Uses robust standardization and tracks both mean and max changes.
        """
        if data_previous is None:
            return float('inf')

        changes = []

        # Only compute convergence for numeric columns being imputed
        for col in self.numeric_columns_with_missing:
            missing_mask = self.missing_indicators[col]
            if missing_mask.any():
                current_vals = data_current.loc[missing_mask, col]
                previous_vals = data_previous.loc[missing_mask, col]

                # Compute absolute differences
                abs_diff = np.abs(current_vals - previous_vals)

                # Normalize by robust standard deviation (using MAD - median absolute deviation)
                col_values = data_current[col].values
                median = np.median(col_values)
                mad = np.median(np.abs(col_values - median))

                if mad > 0:
                    # Use MAD for robust scaling
                    normalized_change = abs_diff / (1.4826 * mad)  # 1.4826 * MAD approximates std
                else:
                    # Fallback to standard deviation
                    std = data_current[col].std()
                    if std > 0:
                        normalized_change = abs_diff / std
                    else:
                        # If no variation, use absolute change
                        normalized_change = abs_diff

                # Use mean change for this column
                changes.append(np.mean(normalized_change))

        if len(changes) == 0:
            return 0

        # Return average change across all columns
        # Also consider max change to ensure all columns have converged
        mean_change = np.mean(changes)
        max_change = np.max(changes)

        # Weighted combination: prioritize mean but consider max
        return 0.7 * mean_change + 0.3 * max_change

    def fit_transform(
        self,
        data: pd.DataFrame,
        columns_to_impute: Optional[List[str]] = None
    ) -> pd.DataFrame:
        """
        Impute missing values using the hybrid Correlation + PMM + MICE approach.

        Parameters:
        -----------
        data : pd.DataFrame
            Dataset with missing values
        columns_to_impute : List[str], optional
            Specific columns to impute. If None, imputes all columns with missing values.

        Returns:
        --------
        imputed_data : pd.DataFrame
            Dataset with imputed values
        """
        if data.isnull().sum().sum() == 0:
            if self.verbose:
                print("No missing data found. Returning original data.")
            return data.copy()

        # Identify missing data
        self._identify_missing_data(data)

        # Filter to requested columns
        if columns_to_impute is not None:
            self.columns_with_missing = [
                col for col in self.columns_with_missing if col in columns_to_impute
            ]
            self.numeric_columns_with_missing = [
                col for col in self.numeric_columns_with_missing if col in columns_to_impute
            ]

        if not self.numeric_columns_with_missing:
            if self.verbose:
                print("No numeric columns to impute with PMM. Returning data with simple imputation for non-numeric columns.")
            # Still do simple imputation for non-numeric columns
            return self._initialize_imputation(data)

        # Initialize with simple imputation
        imputed = self._initialize_imputation(data)

        # Fit correlation analyzer on complete cases first, but only on numeric columns
        numeric_cols = [col for col in data.columns
                       if data[col].dtype in [np.float64, np.float32, np.int64, np.int32, np.float16, np.int16, np.int8]]
        complete_data = data[numeric_cols].dropna()
        if len(complete_data) > 0:
            self.correlation_analyzer.fit(complete_data)
        else:
            # If no complete cases, use initialized numeric data
            self.correlation_analyzer.fit(imputed[numeric_cols])

        # Determine imputation order based on correlations (only for numeric columns)
        self.imputation_order = self.correlation_analyzer.get_imputation_order(
            self.numeric_columns_with_missing
        )

        if self.verbose:
            print(f"Imputation order: {self.imputation_order}")
            print(f"Starting MICE iterations (max {self.n_iterations})...")

        # MICE iterations
        previous_imputed = None
        self.convergence_history = []

        for iteration in range(self.n_iterations):
            if self.verbose:
                print(f"\nIteration {iteration + 1}/{self.n_iterations}")

            # Iterate through each column with missing values
            for col in self.imputation_order:
                # Get correlation-based predictors
                predictors = self.correlation_analyzer.get_predictors(
                    col,
                    max_predictors=self.max_predictors
                )

                # If no predictors found, fall back to all other numeric columns
                # (non-numeric columns cannot be scaled or regressed on)
                if not predictors:
                    predictors = [c for c in numeric_cols if c != col]

                # Skip predictors that are still at least 50% missing (e.g. columns left out of columns_to_impute)
                valid_predictors = []
                for pred in predictors:
                    if imputed[pred].isnull().sum() < len(imputed) * 0.5:
                        valid_predictors.append(pred)

                if not valid_predictors:
                    if self.verbose:
                        print(f"  {col}: No valid predictors, skipping")
                    continue

                # Apply PMM imputation
                try:
                    # Create temporary dataset with only needed columns
                    temp_data = imputed[[col] + valid_predictors].copy()

                    # Restore the original missing values in the target column so that
                    # PMM treats them as missing and re-imputes them from the donor pool
                    missing_mask = self.missing_indicators[col]
                    temp_data.loc[missing_mask, col] = np.nan

                    # Get correlation weights for predictors
                    predictor_weights = self.correlation_analyzer.get_predictor_weights(col)
                    # Filter weights to only valid predictors
                    filtered_weights = {k: v for k, v in predictor_weights.items() if k in valid_predictors}

                    # Impute the column using PMM with correlation weights
                    imputed_col = self.imputer.fit_transform(
                        temp_data,
                        col,
                        valid_predictors,
                        predictor_weights=filtered_weights
                    )

                    # Update only the originally missing values
                    imputed.loc[missing_mask, col] = imputed_col[missing_mask]

                    if self.verbose:
                        n_weighted = len(filtered_weights)
                        print(f"  {col}: Imputed with {len(valid_predictors)} predictors ({n_weighted} weighted)")

                except Exception as e:
                    if self.verbose:
                        print(f"  {col}: Error during imputation - {str(e)}")
                    continue

            # Check convergence
            convergence = self._compute_convergence_metric(imputed, previous_imputed)
            self.convergence_history.append(convergence)

            if self.verbose:
                print(f"Convergence metric: {convergence:.6f}")

            if convergence < self.convergence_threshold:
                if self.verbose:
                    print(f"Converged at iteration {iteration + 1}")
                break

            previous_imputed = imputed.copy()

        if self.verbose:
            print("\nImputation complete!")
            print(f"Total iterations: {len(self.convergence_history)}")

        return imputed

    def fit(self, data: pd.DataFrame) -> 'HybridMICEImputer':
        """
        Fit the imputer (mainly for correlation analysis).

        Parameters:
        -----------
        data : pd.DataFrame
            Training data

        Returns:
        --------
        self : HybridMICEImputer
        """
        # Only fit on numeric columns
        numeric_cols = [col for col in data.columns
                       if data[col].dtype in [np.float64, np.float32, np.int64, np.int32, np.float16, np.int16, np.int8]]
        complete_data = data[numeric_cols].dropna()
        if len(complete_data) > 0:
            self.correlation_analyzer.fit(complete_data)
        else:
            # If no complete cases, use all numeric data
            self.correlation_analyzer.fit(data[numeric_cols])

        return self

    def transform(self, data: pd.DataFrame) -> pd.DataFrame:
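        """
        Impute missing values in data.

        Note: this currently delegates to fit_transform(), so the correlation
        structure is re-estimated from `data` rather than reused from fit().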

        Parameters:
        -----------
        data : pd.DataFrame
            Data to impute

        Returns:
        --------
        imputed_data : pd.DataFrame
        """
        return self.fit_transform(data)

    def get_diagnostics(self) -> Dict[str, Union[List, pd.DataFrame]]:
        """
        Get diagnostic information about the imputation process.

        Returns:
        --------
        diagnostics : dict
            Dictionary containing diagnostic information
        """
        # Preserve the original column order (a set difference would not)
        excluded_columns = [c for c in self.columns_with_missing
                            if c not in self.numeric_columns_with_missing]

        diagnostics = {
            'convergence_history': self.convergence_history,
            'imputation_order': self.imputation_order,
            'columns_with_missing': self.columns_with_missing,
            'numeric_columns_imputed': self.numeric_columns_with_missing,
            'excluded_columns': excluded_columns,
            'correlation_matrix': self.correlation_analyzer.correlation_matrix,
            'predictor_sets': self.correlation_analyzer.predictor_sets
        }

        return diagnostics

    def plot_convergence(self):
        """
        Plot the convergence history.
        """
        if not self.convergence_history:
            print("No convergence history available. Run fit_transform first.")
            return

        plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(self.convergence_history) + 1), self.convergence_history, marker='o')
        plt.axhline(y=self.convergence_threshold, color='r', linestyle='--', label='Convergence threshold')
        plt.xlabel('Iteration')
        plt.ylabel('Convergence Metric')
        plt.title('MICE Convergence History')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    @staticmethod
    def load_data(
        file_path: str,
        sheet_name: Optional[Union[str, int]] = 0,
        **kwargs
    ) -> pd.DataFrame:
        """
        Load data from CSV or Excel (.xlsx) file.

        Parameters:
        -----------
        file_path : str
            Path to the file (supports .csv, .xlsx, .xls)
        sheet_name : str or int, default=0
            Sheet name or index for Excel files (ignored for CSV)
        **kwargs : dict
            Additional arguments passed to pd.read_csv() or pd.read_excel()

        Returns:
        --------
        data : pd.DataFrame
            Loaded data

        Examples:
        ---------
        >>> # Load CSV file
        >>> data = HybridMICEImputer.load_data('data.csv')
        >>>
        >>> # Load Excel file (first sheet)
        >>> data = HybridMICEImputer.load_data('data.xlsx')
        >>>
        >>> # Load specific sheet from Excel
        >>> data = HybridMICEImputer.load_data('data.xlsx', sheet_name='Sheet2')
        >>> data = HybridMICEImputer.load_data('data.xlsx', sheet_name=1)
        """
        file_path = str(file_path)
        ext = file_path.lower()  # match extensions case-insensitively (.CSV, .XLSX)

        if ext.endswith('.csv'):
            return pd.read_csv(file_path, **kwargs)
        elif ext.endswith(('.xlsx', '.xls')):
            return pd.read_excel(file_path, sheet_name=sheet_name, **kwargs)
        else:
            # Unknown extension: try Excel first, then fall back to CSV
            try:
                return pd.read_excel(file_path, sheet_name=sheet_name, **kwargs)
            except Exception:
                return pd.read_csv(file_path, **kwargs)

    @staticmethod
    def save_data(
        data: pd.DataFrame,
        file_path: str,
        sheet_name: str = 'Sheet1',
        index: bool = False,
        **kwargs
    ) -> None:
        """
        Save data to CSV or Excel (.xlsx) file.

        Parameters:
        -----------
        data : pd.DataFrame
            Data to save
        file_path : str
            Path to save the file (supports .csv, .xlsx)
        sheet_name : str, default='Sheet1'
            Sheet name for Excel files (ignored for CSV)
        index : bool, default=False
            Whether to include the index in the saved file
        **kwargs : dict
            Additional arguments passed to pd.to_csv() or pd.to_excel()

        Examples:
        ---------
        >>> # Save to CSV
        >>> HybridMICEImputer.save_data(imputed_data, 'output.csv')
        >>>
        >>> # Save to Excel
        >>> HybridMICEImputer.save_data(imputed_data, 'output.xlsx')
        >>>
        >>> # Save to Excel with custom sheet name
        >>> HybridMICEImputer.save_data(imputed_data, 'output.xlsx', sheet_name='Imputed Data')
        """
        file_path = str(file_path)
        ext = file_path.lower()  # match extensions case-insensitively

        if ext.endswith('.csv'):
            data.to_csv(file_path, index=index, **kwargs)
        elif ext.endswith('.xlsx'):
            data.to_excel(file_path, sheet_name=sheet_name, index=index, **kwargs)
        else:
            # Default to CSV if extension not recognized
            data.to_csv(file_path, index=index, **kwargs)
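
The MICE loop in `fit_transform` keeps each column's original missing mask, refits on the rows where that column was observed, overwrites only the originally missing cells, and stops once a convergence metric drops below `convergence_threshold`. The standalone sketch below mirrors that control flow under stated simplifications: it uses plain least-squares regression imputation instead of PMM donor draws, uses every other column as a predictor instead of correlation-based selection, and assumes the convergence metric is the mean absolute change of the imputed cells between sweeps (the library's `_compute_convergence_metric` may differ). `mini_mice` is a hypothetical helper, not part of the library.

```python
from typing import List, Tuple

import numpy as np
import pandas as pd


def mini_mice(data: pd.DataFrame, n_iterations: int = 10,
              convergence_threshold: float = 1e-4) -> Tuple[pd.DataFrame, List[float]]:
    """Toy chained-equations loop with regression imputation (numeric columns only).

    Hypothetical illustration; it is not HybridMICEImputer.fit_transform.
    """
    data = data.astype(float)
    missing = data.isna()                      # original missing mask, kept fixed
    imputed = data.fillna(data.mean())         # initial fill with column means
    order = [c for c in data.columns if missing[c].any()]
    history: List[float] = []
    if not order:
        return imputed, history

    mask = missing.to_numpy()
    previous = imputed.to_numpy().copy()
    for _ in range(n_iterations):
        for col in order:
            predictors = [c for c in data.columns if c != col]
            obs = ~missing[col]
            # Least-squares fit (with intercept) on rows where `col` was observed
            X_obs = np.column_stack([np.ones(obs.sum()), imputed.loc[obs, predictors].to_numpy()])
            beta, *_ = np.linalg.lstsq(X_obs, imputed.loc[obs, col].to_numpy(), rcond=None)
            # Overwrite only the originally missing cells of `col`
            X_mis = np.column_stack([np.ones((~obs).sum()), imputed.loc[~obs, predictors].to_numpy()])
            imputed.loc[~obs, col] = X_mis @ beta
        current = imputed.to_numpy()
        # Assumed metric: mean absolute change of the imputed cells since the last sweep
        change = float(np.abs(current - previous)[mask].mean())
        history.append(change)
        if change < convergence_threshold:
            break
        previous = current.copy()
    return imputed, history
```

Because observed cells are never overwritten, only the imputed cells move between sweeps, which is why a change metric restricted to the missing mask is a natural stopping rule.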

## 4. Helper Functions

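A function that generates a synthetic dataset with known correlations and randomly placed missing values, so that imputations can be scored against the true values.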

In [None]:
def create_sample_data_with_missing(n_samples=1000, missing_rate=0.2, random_state=42):
    """
    Create a sample dataset with missing values for demonstration.

    Parameters:
    -----------
    n_samples : int
        Number of samples to generate
    missing_rate : float
        Proportion of values in each column to set as missing (rows chosen uniformly at random)
    random_state : int
        Random seed

    Returns:
    --------
    data : pd.DataFrame
        Dataset with missing values
    data_complete : pd.DataFrame
        Original complete dataset (for comparison)
    """
    np.random.seed(random_state)

    # Generate correlated features
    # Feature 1: base random variable
    x1 = np.random.randn(n_samples)

    # Feature 2: strongly correlated with x1
    x2 = 0.8 * x1 + 0.2 * np.random.randn(n_samples)

    # Feature 3: moderately correlated with x1 and x2
    x3 = 0.5 * x1 + 0.3 * x2 + 0.4 * np.random.randn(n_samples)

    # Feature 4: weakly correlated
    x4 = 0.2 * x1 + 0.8 * np.random.randn(n_samples)

    # Feature 5: independent
    x5 = np.random.randn(n_samples)

    # Create DataFrame
    data_complete = pd.DataFrame({
        'feature1': x1,
        'feature2': x2,
        'feature3': x3,
        'feature4': x4,
        'feature5': x5
    })

    # Add a target variable
    data_complete['target'] = (
        2 * x1 + 1.5 * x2 - 0.5 * x3 + np.random.randn(n_samples) * 0.5
    )

    # Create missing values
    data_with_missing = data_complete.copy()

    for col in data_with_missing.columns:
        # Randomly select indices to set as missing
        n_missing = int(n_samples * missing_rate)
        missing_indices = np.random.choice(n_samples, n_missing, replace=False)
        data_with_missing.loc[missing_indices, col] = np.nan

    return data_with_missing, data_complete
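
Because every generated column is a linear combination of independent standard-normal draws, the population correlations that the analyzer should approximately recover can be computed exactly: write each column as a coefficient vector over the independent sources, and the covariance of two columns is the dot product of their vectors. The sketch below does that arithmetic for the generating equations above; sample correlations from `n_samples` rows will scatter around these values.

```python
import numpy as np
import pandas as pd

# Coefficients on independent N(0, 1) sources:
# z1 (x1), z2 (noise in x2), z3 (noise in x3), z4 (noise in x4), z5 (x5), z6 (noise in target)
x1 = np.array([1.0, 0, 0, 0, 0, 0])
x2 = 0.8 * x1 + np.array([0, 0.2, 0, 0, 0, 0])
x3 = 0.5 * x1 + 0.3 * x2 + np.array([0, 0, 0.4, 0, 0, 0])
x4 = 0.2 * x1 + np.array([0, 0, 0, 0.8, 0, 0])
x5 = np.array([0, 0, 0, 0, 1.0, 0])
target = 2 * x1 + 1.5 * x2 - 0.5 * x3 + np.array([0, 0, 0, 0, 0, 0.5])

A = np.vstack([x1, x2, x3, x4, x5, target])
cov = A @ A.T                   # population covariance (sources are independent, unit variance)
sd = np.sqrt(np.diag(cov))
corr = cov / np.outer(sd, sd)   # population correlation matrix

names = ['feature1', 'feature2', 'feature3', 'feature4', 'feature5', 'target']
print(pd.DataFrame(np.round(corr, 3), index=names, columns=names))
```

For example, feature2 correlates with feature1 at 0.8/√0.68 ≈ 0.97, feature4 only at 0.2/√0.68 ≈ 0.24, and feature5 at exactly 0, so with a 0.3 `correlation_threshold` feature4 and feature5 are likely to end up with few or no correlation-selected predictors.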

---

# Examples and Usage

Let's explore different use cases of the Hybrid MICE Imputer.

## Example 1: Basic Usage

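Impute a synthetic dataset with a typical configuration and score the imputations against the known true values (RMSE over the originally missing cells).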

In [None]:
# Create sample data
data_missing, data_complete = create_sample_data_with_missing(
    n_samples=500,
    missing_rate=0.15
)

print("="*70)
print("Example 1: Basic Usage")
print("="*70)

print("\nOriginal data with missing values:")
print(data_missing.head(10))
print("\nMissing value statistics:")
print(data_missing.isnull().sum())

# Initialize the hybrid imputer (fitting happens inside fit_transform)
imputer = HybridMICEImputer(
    n_iterations=10,
    n_neighbors=5,
    correlation_threshold=0.3,
    verbose=True,
    random_state=42
)

# Impute missing values
data_imputed = imputer.fit_transform(data_missing)

print("\n\nImputed data:")
print(data_imputed.head(10))

# Calculate imputation accuracy (RMSE)
print("\n\nImputation Quality Assessment:")
for col in data_missing.columns:
    missing_mask = data_missing[col].isnull()
    if missing_mask.any():
        true_values = data_complete.loc[missing_mask, col]
        imputed_values = data_imputed.loc[missing_mask, col]
        rmse = np.sqrt(np.mean((true_values - imputed_values) ** 2))
        print(f"{col}: RMSE = {rmse:.4f}")
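
Examples 1 and 2 both score imputations by RMSE over the cells that were missing before imputation. The same computation can be written once as a vectorized helper; `masked_rmse` is a hypothetical name, and the sketch assumes the three frames share the same index and numeric columns.

```python
import numpy as np
import pandas as pd


def masked_rmse(data_complete: pd.DataFrame,
                data_imputed: pd.DataFrame,
                data_missing: pd.DataFrame) -> pd.Series:
    """Per-column RMSE computed only over cells that were missing before imputation."""
    mask = data_missing.isna()
    sq_err = (data_complete - data_imputed) ** 2
    # where() blanks out observed cells; mean() skips NaN, so only masked cells count
    mse = sq_err.where(mask).mean()
    # Report only columns that actually had missing values
    return np.sqrt(mse[mask.any()])
```

`masked_rmse(data_complete, data_imputed, data_missing).mean()` gives the same average that Example 2 accumulates by hand.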

## Example 2: Comparing Different Model Types

Compare performance of different PMM model types.

In [None]:
print("\n" + "="*70)
print("Example 2: Comparing Model Types")
print("="*70)

# Create data with more missing values
data_missing, data_complete = create_sample_data_with_missing(
    n_samples=300,
    missing_rate=0.30
)

print(f"\nDataset size: {len(data_missing)} samples")
print("Missing rate: ~30%")

# Try different model configurations
configs = [
    {'name': 'Linear PMM', 'pmm_model_type': 'linear'},
    {'name': 'Bayesian PMM', 'pmm_model_type': 'bayesian'},
    {'name': 'Random Forest PMM', 'pmm_model_type': 'rf'}
]

results = {}

for config in configs:
    print(f"\n\nTesting: {config['name']}")
    print("-"*40)

    imputer = HybridMICEImputer(
        n_iterations=15,
        n_neighbors=7,
        pmm_model_type=config['pmm_model_type'],
        correlation_threshold=0.25,
        verbose=False,
        random_state=42
    )

    data_imputed = imputer.fit_transform(data_missing)

    # Average the per-column RMSE over columns that had missing values
    total_rmse = 0
    n_cols = 0

    for col in data_missing.columns:
        missing_mask = data_missing[col].isnull()
        if missing_mask.any():
            true_values = data_complete.loc[missing_mask, col]
            imputed_values = data_imputed.loc[missing_mask, col]
            rmse = np.sqrt(np.mean((true_values - imputed_values) ** 2))
            total_rmse += rmse
            n_cols += 1

    avg_rmse = total_rmse / n_cols if n_cols > 0 else 0
    results[config['name']] = avg_rmse

    print(f"Average RMSE: {avg_rmse:.4f}")
    print(f"Iterations run: {len(imputer.convergence_history)}")

print("\n\nComparison Summary:")
print("-"*40)
for name, rmse in sorted(results.items(), key=lambda x: x[1]):
    print(f"{name}: {rmse:.4f}")
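
Example 2 switches the regressor behind PMM via `pmm_model_type`. Whatever the regressor, the matching step works the same way: the model is only used to rank rows by predicted mean, and each missing value is replaced by the observed value of a randomly drawn donor among the `n_neighbors` closest predictions, so imputed values always come from the observed data. The sketch below assumes that 'linear', 'bayesian' and 'rf' correspond to scikit-learn's `LinearRegression`, `BayesianRidge` and `RandomForestRegressor` (the classes imported at the top of this notebook); `pmm_impute` is a hypothetical helper and omits the parameter-uncertainty draws used in fully Bayesian PMM.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import BayesianRidge, LinearRegression


def pmm_impute(X, y, model, n_neighbors=5, rng=None):
    """Fill NaNs in y by predictive mean matching with any scikit-learn regressor."""
    rng = np.random.default_rng(rng)
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float).copy()
    obs = ~np.isnan(y)
    model.fit(X[obs], y[obs])
    pred_obs = model.predict(X[obs])    # predicted means for potential donors
    pred_mis = model.predict(X[~obs])   # predicted means for rows to impute
    donor_values = y[obs]
    for row, p in zip(np.flatnonzero(~obs), pred_mis):
        # Donor pool: the n_neighbors observed rows with the closest predicted means
        pool = np.argsort(np.abs(pred_obs - p))[:n_neighbors]
        # Impute with the observed value of one randomly chosen donor
        y[row] = donor_values[rng.choice(pool)]
    return y


# Demo: the same PMM step with three regressor families
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(200, 2))
y_demo = 1.5 * X_demo[:, 0] - X_demo[:, 1] + rng.normal(scale=0.3, size=200)
y_demo[rng.choice(200, 40, replace=False)] = np.nan

for name, model in [('linear', LinearRegression()),
                    ('bayesian', BayesianRidge()),
                    ('rf', RandomForestRegressor(n_estimators=50, random_state=0))]:
    y_filled = pmm_impute(X_demo, y_demo, model, n_neighbors=5, rng=0)
    print(name, np.round(y_filled[np.isnan(y_demo)][:3], 3))
```

Since only the ranking of predictions matters, differences between model types in Example 2 come from how well each regressor orders rows, not from the imputed values themselves, which are always drawn from observed data.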

## Example 3: Diagnostics and Visualization

Explore diagnostic features and visualizations.

In [None]:
print("\n" + "="*70)
print("Example 3: Diagnostics and Visualization")
print("="*70)

# Create sample data
data_missing, _ = create_sample_data_with_missing(
    n_samples=400,
    missing_rate=0.20
)

# Initialize imputer
imputer = HybridMICEImputer(
    n_iterations=20,
    verbose=False,
    random_state=42
)

# Impute
data_imputed = imputer.fit_transform(data_missing)

# Get diagnostics
diagnostics = imputer.get_diagnostics()

print("\nImputation Order (based on correlations):")
for i, col in enumerate(diagnostics['imputation_order'], 1):
    print(f"{i}. {col}")

print("\n\nPredictor Sets (correlation-based):")
for col, predictors in diagnostics['predictor_sets'].items():
    if col in diagnostics['columns_with_missing']:
        print(f"\n{col}:")
        print(f"  Top predictors: {predictors[:3]}")

print("\n\nCorrelation Matrix:")
print(diagnostics['correlation_matrix'])

print("\n\nConvergence History:")
for i, conv in enumerate(diagnostics['convergence_history'], 1):
    print(f"Iteration {i}: {conv:.6f}")
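
The predictor sets printed above come from `CorrelationAnalyzer`, which combines Pearson, Spearman and Kendall correlations with weighted scoring. The sketch below isolates the simplest version of the idea, assuming a single Pearson correlation: keep the columns whose absolute correlation with the target reaches the threshold, strongest first, capped at a maximum count. `select_predictors` is a hypothetical helper; the library's own weighting and tie-breaking may differ.

```python
import numpy as np
import pandas as pd


def select_predictors(data: pd.DataFrame, target: str,
                      threshold: float = 0.3, max_predictors: int = 3) -> list:
    """Columns whose |Pearson r| with `target` is at least `threshold`, strongest first."""
    corr = data.corr(method='pearson')[target].drop(target).abs()
    strong = corr[corr >= threshold].sort_values(ascending=False)
    return list(strong.index[:max_predictors])


rng = np.random.default_rng(0)
a = rng.normal(size=500)
df = pd.DataFrame({
    'a': a,
    'b': 0.9 * a + 0.1 * rng.normal(size=500),   # strongly correlated with a
    'c': rng.normal(size=500),                    # independent of a
    'd': -0.7 * a + 0.5 * rng.normal(size=500),   # negatively correlated with a
})
print(select_predictors(df, 'a'))
```

Taking the absolute value keeps negatively correlated columns such as `d`, which are just as useful to a regression model as positively correlated ones.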

### Visualize Convergence

In [None]:
# Plot convergence
imputer.plot_convergence()

### Visualize Correlation Matrix

In [None]:
# Visualize correlation matrix
imputer.correlation_analyzer.visualize_correlations()

## Example 4: Partial Imputation (Specific Columns)

Demonstrate imputing only specific columns.

In [None]:
print("\n" + "="*70)
print("Example 4: Partial Imputation (Specific Columns)")
print("="*70)

# Create sample data
data_missing, data_complete = create_sample_data_with_missing(
    n_samples=500,
    missing_rate=0.15
)

print("\nImputing only 'feature1' and 'target' columns...")

# Initialize imputer
imputer = HybridMICEImputer(
    n_iterations=10,
    verbose=True,
    random_state=42
)

# Impute only specific columns
data_imputed = imputer.fit_transform(
    data_missing,
    columns_to_impute=['feature1', 'target']
)

# Check which columns were imputed
print("\n\nMissing values after imputation:")
print(data_imputed.isnull().sum())
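
Because the initial simple imputation and the MICE loop are separate steps, it is worth checking which columns actually changed after a partial run. This hypothetical helper counts, per column, how many cells went from missing to filled; it assumes `before` and `after` have identical row and column labels.

```python
import pandas as pd


def filled_counts(before: pd.DataFrame, after: pd.DataFrame) -> pd.DataFrame:
    """Per column: NaNs before, NaNs after, and how many cells were filled."""
    report = pd.DataFrame({
        'missing_before': before.isna().sum(),
        'missing_after': after.isna().sum(),
    })
    # A cell counts as filled if it was NaN before and is not NaN afterwards
    report['filled'] = (before.isna() & after.notna()).sum()
    return report
```

For this example, `filled_counts(data_missing, data_imputed)` shows whether columns outside `columns_to_impute` were left untouched.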

## Example 5: Using Your Own Data

Template for using the imputer with your own dataset.

In [None]:
print("\n" + "="*70)
print("Example 5: Using Your Own Data")
print("="*70)

# Uncomment and modify this section to use your own data:

# # Load your data
# your_data = pd.read_csv('your_file.csv')
#
# # Or upload a file in Google Colab:
# from google.colab import files
# uploaded = files.upload()
# your_data = pd.read_csv(list(uploaded.keys())[0])
#
# # Initialize imputer with custom parameters
# imputer = HybridMICEImputer(
#     n_iterations=15,
#     n_neighbors=5,
#     correlation_threshold=0.3,
#     pmm_model_type='linear',  # Options: 'linear', 'bayesian', 'rf'
#     verbose=True,
#     random_state=42
# )
#
# # Impute missing values
# your_data_imputed = imputer.fit_transform(your_data)
#
# # View results
# print("\nOriginal data:")
# print(your_data.head())
# print("\nMissing values:", your_data.isnull().sum().sum())
#
# print("\nImputed data:")
# print(your_data_imputed.head())
# print("\nMissing values:", your_data_imputed.isnull().sum().sum())
#
# # Plot convergence
# imputer.plot_convergence()
#
# # Download the imputed data
# your_data_imputed.to_csv('imputed_data.csv', index=False)
# files.download('imputed_data.csv')

print("\nUncomment and modify the code above to use your own data!")
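
Before running the imputer on your own data, a quick per-column summary helps decide what to pass as `exclude_columns`: non-numeric columns are skipped automatically, while numeric ID-like columns (as many unique values as rows) usually need to be excluded by hand. `missingness_report` is a hypothetical helper built only on standard pandas calls.

```python
import pandas as pd


def missingness_report(df: pd.DataFrame) -> pd.DataFrame:
    """Summarize each column's dtype, missing fraction, numeric-ness and cardinality."""
    return pd.DataFrame({
        'dtype': df.dtypes.astype(str),
        'missing_frac': df.isna().mean(),
        'numeric': pd.Series({c: pd.api.types.is_numeric_dtype(df[c]) for c in df.columns}),
        'n_unique': df.nunique(),   # equal to len(df) often signals an ID column
    })
```

For example, `missingness_report(your_data)` run before `fit_transform` shows which columns will go through PMM and which are candidates for `exclude_columns`.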

## Example 6: Handling Non-Numeric and ID Columns

Demonstrates automatic exclusion of non-numeric columns and manual exclusion of ID columns.

In [None]:
print("\n" + "="*70)
print("Example 6: Handling Non-Numeric and ID Columns")
print("="*70)

# Create sample data with mixed types (like a real dataset)
np.random.seed(42)
n_samples = 200

# Create numeric features
feature1 = np.random.randn(n_samples)
feature2 = 0.7 * feature1 + 0.3 * np.random.randn(n_samples)
age = np.random.randint(20, 80, n_samples).astype(float)
score = 50 + 10 * feature1 + np.random.randn(n_samples) * 5

# Create non-numeric columns
ptid = [f"PT{i:04d}" for i in range(n_samples)]
diagnosis = np.random.choice(['Control', 'AD', 'MCI'], n_samples)
gender = np.random.choice(['M', 'F'], n_samples)

# Create DataFrame with mixed types
data_mixed = pd.DataFrame({
    'ptid': ptid,
    'diagnosis': diagnosis,
    'gender': gender,
    'age': age,
    'feature1': feature1,
    'feature2': feature2,
    'score': score
})

# Introduce missing values in various columns
missing_rate = 0.20
for col in ['age', 'feature1', 'feature2', 'score', 'diagnosis']:
    n_missing = int(n_samples * missing_rate)
    missing_indices = np.random.choice(n_samples, n_missing, replace=False)
    data_mixed.loc[missing_indices, col] = np.nan

print("\nOriginal data with mixed types:")
print(data_mixed.head(10))
print("\nData types:")
print(data_mixed.dtypes)
print("\nMissing values:")
print(data_mixed.isnull().sum())

# Initialize imputer - it will automatically exclude non-numeric columns
# and we can manually exclude 'ptid' as well
imputer = HybridMICEImputer(
    n_iterations=10,
    n_neighbors=5,
    correlation_threshold=0.3,
    exclude_columns=['ptid'],  # Manually exclude patient ID
    verbose=True,
    random_state=42
)

# Impute missing values
data_imputed = imputer.fit_transform(data_mixed)

print("\n\nImputed data:")
print(data_imputed.head(10))
print("\nMissing values after imputation:")
print(data_imputed.isnull().sum())

# Get diagnostics to see what was excluded
diagnostics = imputer.get_diagnostics()
print("\n\nDiagnostic Information:")
print(f"Total columns with missing data: {len(diagnostics['columns_with_missing'])}")
print(f"Numeric columns imputed with PMM: {diagnostics['numeric_columns_imputed']}")
print(f"Excluded columns: {diagnostics['excluded_columns']}")
print("\nNote: Non-numeric columns (diagnosis, gender) and manually excluded columns (ptid)")
print("      are not imputed with PMM. Check the missing-value counts above to see whether")
print("      they kept their missing values or received a simple (e.g. mode) fill.")
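
If `diagnosis` still contains missing values after running the cell above, or if you prefer an explicit rule for categorical columns, a mode fill can be applied separately. The sketch below is a hypothetical helper, not part of `HybridMICEImputer`; keep in mind that mode imputation inflates the most frequent category, unlike PMM's donor draws.

```python
import pandas as pd


def mode_impute(df: pd.DataFrame, columns: list) -> pd.DataFrame:
    """Fill NaNs in the given (typically categorical) columns with each column's mode."""
    out = df.copy()
    for col in columns:
        modes = out[col].mode(dropna=True)
        if not modes.empty:   # an all-NaN column has no mode; leave it unchanged
            out[col] = out[col].fillna(modes.iloc[0])
    return out
```

For example, `data_final = mode_impute(data_imputed, ['diagnosis'])` fills only `diagnosis` and leaves the PMM-imputed numeric columns as they are.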

---

## Summary

This notebook demonstrates the **Hybrid Correlation + PMM + MICE Imputer**, which combines:

1. **Correlation Analysis** - Predictor selection based on pairwise feature correlations
2. **Predictive Mean Matching** - Distribution-preserving imputation
3. **MICE Framework** - Iterative refinement for improved accuracy

### Key Features:
- Multiple model types (Linear, Bayesian, Random Forest)
- Automatic predictor selection
- Convergence monitoring
- Comprehensive diagnostics
- Partial imputation support

### When to Use:
- Datasets with complex missing data patterns
- When preserving data distribution is important
- When features have strong correlations
- For both research and production use cases

---

**Repository**: [https://github.com/CodeSakshamY/correlation-PMM-MICE](https://github.com/CodeSakshamY/correlation-PMM-MICE)

**License**: MIT

---