In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [4]:
class SimpleImputer:
    """Column-wise imputer for missing values, modelled on scikit-learn's ``SimpleImputer``.

    Parameters
    ----------
    strategy : {"mean", "median", "most_frequent", "constant"}, default "mean"
        How the per-column fill value is computed in :meth:`fit`.
        ``"mean"`` and ``"median"`` require every column to be numeric.
    fill_value : object, optional
        Value used by the ``"constant"`` strategy. When ``None``, numeric
        columns are filled with ``0`` and all other columns with
        ``"missing_value"`` (the scikit-learn convention).

    Attributes
    ----------
    statistics_ : np.ndarray
        One fill value per fitted column, in column order.
    feature_names_ : pd.Index
        Column labels seen during :meth:`fit` (``0..n-1`` for array input).
    """

    def __init__(self, strategy="mean", fill_value=None):
        self.strategy = strategy
        self.fill_value = fill_value

    def fit(self, X):
        """Compute one fill value per column of ``X``.

        Parameters
        ----------
        X : pd.DataFrame or array-like
            Training data; non-DataFrame input is wrapped with ``pd.DataFrame``.

        Returns
        -------
        SimpleImputer
            ``self``, so calls can be chained.

        Raises
        ------
        ValueError
            If the strategy is unknown, or ``"mean"``/``"median"`` is requested
            for data that contains non-numeric columns.
        """
        X_df = X if isinstance(X, pd.DataFrame) else pd.DataFrame(X)
        self.feature_names_ = X_df.columns

        if self.strategy in ["mean", "median"]:
            numeric_cols = X_df.select_dtypes(include=[np.number])
            if numeric_cols.shape[1] != X_df.shape[1]:
                non_numeric = set(X_df.columns) - set(numeric_cols.columns)
                raise ValueError(
                    f"Strategy '{self.strategy}' cannot be applied to non-numeric columns: {non_numeric}"
                )

            if self.strategy == "mean":
                self.statistics_ = numeric_cols.mean(axis=0, skipna=True).values
            else:  # median
                self.statistics_ = numeric_cols.median(axis=0, skipna=True).values

        elif self.strategy == "constant":
            self.statistics_ = self._constant_statistics(X_df)

        elif self.strategy == "most_frequent":
            # `mode` returns its rows sorted, so row 0 holds the smallest mode
            # (ties resolve to the smallest value). A frame with no observed
            # values yields zero rows; fall back to NaN so such columns are
            # simply left unfilled, as mean/median do for all-NaN columns.
            modes = X_df.mode(dropna=True)
            if len(modes) == 0:
                self.statistics_ = np.full(X_df.shape[1], np.nan, dtype=object)
            else:
                self.statistics_ = modes.iloc[0].values

        else:
            raise ValueError(f"Unknown strategy: {self.strategy}")

        return self

    def _constant_statistics(self, X_df):
        """Return the per-column fill values for the ``"constant"`` strategy."""
        if self.fill_value is not None:
            return np.array([self.fill_value] * X_df.shape[1])
        # No explicit value: a None fill would make `fillna` fail at transform
        # time, so pick a dtype-appropriate default per column instead.
        return np.array(
            [
                0 if pd.api.types.is_numeric_dtype(dtype) else "missing_value"
                for dtype in X_df.dtypes
            ],
            dtype=object,
        )

    def transform(self, X):
        """Replace missing entries of ``X`` with the fitted statistics.

        Columns are matched to the fitted statistics by label when ``X`` carries
        the same set of column labels seen in :meth:`fit` (so a reordered
        DataFrame is handled correctly). Otherwise they are matched by position,
        e.g. a NumPy array passed after fitting on a labelled DataFrame.

        Parameters
        ----------
        X : pd.DataFrame or array-like
            Data to impute; must have as many columns as the fitted data.

        Returns
        -------
        pd.DataFrame or np.ndarray
            A DataFrame when ``X`` is a DataFrame, otherwise a NumPy array.

        Raises
        ------
        ValueError
            If the number of columns differs from the fitted data.
        """
        is_dataframe = isinstance(X, pd.DataFrame)
        X_df = X if is_dataframe else pd.DataFrame(X)

        if X_df.shape[1] != len(self.statistics_):
            raise ValueError(
                f"Shape mismatch: fitted on {len(self.statistics_)} features, "
                f"but transform data has {X_df.shape[1]}"
            )

        if set(X_df.columns) == set(self.feature_names_):
            fill_index = self.feature_names_
        else:
            # `fillna` aligns a Series on column labels and silently skips labels
            # it cannot find, so label-based filling would impute nothing here.
            fill_index = X_df.columns

        X_filled = X_df.fillna(pd.Series(self.statistics_, index=fill_index))
        return X_filled if is_dataframe else X_filled.values

    def fit_transform(self, X):
        """Fit on ``X`` and return ``X`` with missing values imputed."""
        return self.fit(X).transform(X)


In [6]:
def run_strategy_demo(frame, strategies, constant_fill, header, table_label):
    """Fit/transform ``frame`` with each strategy and print the statistics and result.

    Parameters
    ----------
    frame : pd.DataFrame
        Input data containing missing values.
    strategies : iterable of str
        Strategy names passed to ``SimpleImputer``.
    constant_fill : object
        ``fill_value`` used only for the ``"constant"`` strategy.
    header : str
        Text of the per-strategy heading line (before the strategy name).
    table_label : str
        Caption printed before the transformed frame.

    Returns
    -------
    tuple
        ``(imputer, transformed)`` from the last strategy run, or
        ``(None, None)`` when ``strategies`` is empty.
    """
    imputer, transformed = None, None
    for strategy in strategies:
        print(f"--- {header}: {strategy} ---")

        fill_value = constant_fill if strategy == "constant" else None
        imputer = SimpleImputer(strategy=strategy, fill_value=fill_value)

        transformed = imputer.fit_transform(frame)
        # Every column in the demo data has at least one observed value, so each
        # strategy should keep the frame's shape and leave no gaps behind.
        assert transformed.shape == frame.shape
        assert not transformed.isna().any().any(), f"{strategy!r} left missing values"

        print("Fitted statistics:", imputer.statistics_)
        print(f"{table_label}:\n", transformed, "\n")
    return imputer, transformed


data_df = pd.DataFrame({
    "A": [1, np.nan, 3, np.nan],   # numeric
    "B": [4, 5, np.nan, 7]         # numeric
})

print("Original DataFrame:\n", data_df, "\n")

# Numeric columns support every strategy. The last run's objects stay bound to
# `imputer` / `transformed_df` for any later cells.
imputer, transformed_df = run_strategy_demo(
    data_df,
    ["mean", "median", "most_frequent", "constant"],
    constant_fill=0,
    header="Testing strategy",
    table_label="Transformed DataFrame",
)

# Categorical column: "mean"/"median" reject non-numeric data, so only
# "most_frequent" and "constant" are exercised here.
data_cat = pd.DataFrame({
    "C": [np.nan, "cat", "dog", np.nan]
})

imputer, transformed_cat = run_strategy_demo(
    data_cat,
    ["most_frequent", "constant"],
    constant_fill="missing",
    header="Testing categorical column with strategy",
    table_label="Transformed categorical DataFrame",
)


Original DataFrame:
      A    B
0  1.0  4.0
1  NaN  5.0
2  3.0  NaN
3  NaN  7.0 

--- Testing strategy: mean ---
Fitted statistics: [2.         5.33333333]
Transformed DataFrame:
      A         B
0  1.0  4.000000
1  2.0  5.000000
2  3.0  5.333333
3  2.0  7.000000 

--- Testing strategy: median ---
Fitted statistics: [2. 5.]
Transformed DataFrame:
      A    B
0  1.0  4.0
1  2.0  5.0
2  3.0  5.0
3  2.0  7.0 

--- Testing strategy: most_frequent ---
Fitted statistics: [1. 4.]
Transformed DataFrame:
      A    B
0  1.0  4.0
1  1.0  5.0
2  3.0  4.0
3  1.0  7.0 

--- Testing strategy: constant ---
Fitted statistics: [0 0]
Transformed DataFrame:
      A    B
0  1.0  4.0
1  0.0  5.0
2  3.0  0.0
3  0.0  7.0 

--- Testing categorical column with strategy: most_frequent ---
Fitted statistics: ['cat']
Transformed categorical DataFrame:
      C
0  cat
1  cat
2  dog
3  cat 

--- Testing categorical column with strategy: constant ---
Fitted statistics: ['missing']
Transformed categorical DataFrame