In [238]:
#| default_exp data_preprocessor

# Data Preprocessing

`DataPreprocessor` transforms *individual* features into numerical representations for the machine learning and recourse generation workflows. 
It can be considered as a drop-in jax-friendly replacement to the 
[sklearn.preprocessing](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.preprocessing) module.
The supported preprocessing methods include `MinMaxScaler` and `OneHotEncoder`. 

However, unlike the `DataPreprocessor` [sklearn.preprocessing](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.preprocessing) module,
all of the data preprocessors work only with single features (e.g., Dim: `(B, 1)`). 


In [239]:
#| export
from __future__ import annotations
import jax
import jax.numpy as jnp
import numpy as np

In [240]:
import sklearn.preprocessing as skp
from fastcore.test import test_fail
from sklearn.compose import ColumnTransformer

In [241]:
#| export
class DataPreprocessor:
    
    def fit(self, xs, y=None):
        """Fit the preprocessor with `xs` and `y`."""
        self._fit(xs, y)

    def transform(self, xs):
        """Transform `xs`."""
        self._transform(xs)

    def fit_transform(self, xs, y=None):
        """Fit the preprocessor with `xs` and `y`, then transform `xs`."""
        self.fit(xs, y)
        return self.transform(xs)

    def inverse_transform(self, xs):
        """Inverse transform `xs`."""
        self._inverse_transform(xs)


In [242]:
#| export
def _check_xs(xs: np.ndarray, name: str):
    """Check if `xs` is a 1D array with shape (n_samples,) or a 2D array with shape (n_samples, 1)."""
    if xs.ndim > 2 or (xs.ndim == 2 and xs.shape[1] != 1):
        raise ValueError(f"`{name}` only supports array with a single feature, but got shape={xs.shape}.")

In [250]:
#| export
class MinMaxScaler(DataPreprocessor):         
    def fit(self, xs, y=None):
        _check_xs(xs, name="MinMaxScaler")
        self.min_ = xs.min(axis=0)
        self.max_ = xs.max(axis=0)
        return self
    
    def transform(self, xs):
        if self.min_ == self.max_:
            # return xs
            return np.zeros(xs.shape)
        return (xs - self.min_) / (self.max_ - self.min_)
    
    def inverse_transform(self, xs):
        if self.min_ + self.max_ == 0:
            return xs
        elif self.min_ == self.max_ == 1:
            return xs
        # if (self.min_ == self.max_):
        #     return xs
        return xs * (self.max_ - self.min_) + self.min_

In [251]:
#!!! Do not edit things below.
# `xs` represents 100 data points with 1 feature.
xs = np.random.randn(100, )
scaler = MinMaxScaler()
transformed_xs = scaler.fit_transform(xs)
assert transformed_xs.shape == (100, )
assert np.allclose(xs, scaler.inverse_transform(transformed_xs))
# Test correctness 
assert np.allclose(
    transformed_xs, 
    skp.MinMaxScaler().fit_transform(xs.reshape(100, 1)).reshape(100,)
)

# Can also represented in 2D array.
xs = xs.reshape(100, 1)
scaler = MinMaxScaler()
transformed_xs = scaler.fit_transform(xs)
assert np.allclose(xs, scaler.inverse_transform(transformed_xs))
assert np.allclose(
    transformed_xs, 
    skp.MinMaxScaler().fit_transform(xs.reshape(100, 1))
)

# It will fail if `xs` has more than 1 feature.
xs = xs.reshape(50, 2)
scaler = MinMaxScaler()
test_fail(lambda: scaler.fit_transform(xs), 
          contains="`MinMaxScaler` only supports array with a single feature")

# The above implementation will fail here. Fix it.
xs = np.ones((100, 1))
scaler = MinMaxScaler()
transformed_xs = scaler.fit_transform(xs)
# print("transformed_xs", transformed_xs)
# print("Bro here", skp.MinMaxScaler().fit_transform(xs.reshape(100, 1)))
assert np.allclose(
    transformed_xs, 
    skp.MinMaxScaler().fit_transform(xs.reshape(100, 1))
)
print("xs", xs)
print("transformed_xs", transformed_xs)
print("transformed_xs", scaler.inverse_transform(transformed_xs))
assert np.allclose(xs, scaler.inverse_transform(transformed_xs))
# print("transformed_xs", transformed_xs)
# print("xs.reshape(100,1)", xs.reshape(100,1))
# print("skp.MinMaxScaler().fit_transform(xs.reshape(100,1))", skp.MinMaxScaler().fit_transform(xs.reshape(100, 1)))
print("fit_transform(xs.reshape(100,1))", scaler.fit_transform(xs.reshape(100,1)))
print("MinMaxScaler().fit_transform(xs.reshape(100,1))", skp.MinMaxScaler().fit_transform(xs.reshape(100,1)))
assert np.allclose(
    transformed_xs, 
    skp.MinMaxScaler().fit_transform(xs.reshape(100, 1))
)

In [252]:
#| export
# class OneHotEncoder(DataPreprocessor):
#     """One-hot encoder for a single categorical feature."""
#     
#         # def __init__(self, categories=None):
#         # """
#         # Initialize the OneHotEncoder.
#         # 
#         # Args:
#         #     categories: A list of categories to encode. If None, the categories will be inferred from the data.
#         # """
#         # self.categories = categories
# 
#     def __init__(self):
#         self.categories = None
# 
#     def fit(self, xs, y=None):
#         """Fit the OneHotEncoder with `xs`."""
#         # _check_xs(xs, name="OneHotEncoder")
#         # self.categories_ = np.unique(xs)
#         # return self
#                 # If categories are not specified, infer them from the data.
#         self.categories = set()
#         for x in xs:
#             self.categories.add(x)
#         ...
# 
#     def transform(self, xs):
#         """Transform `xs`."""
#         # Check if the input is a single column.
#         # xs = xs[:, np.newaxis]
#         # 
#         # # Check if the input categories are the same as the fitted categories.
#         # if not np.array_equal(np.unique(xs), self.categories_):
#         #     raise ValueError("The input categories must be the same as the fitted categories.")
#         # 
#         # # Create a one-hot encoding for each input category.
#         # one_hot = np.zeros((xs.shape[0], len(self.categories_)))
#         # for i, category in enumerate(self.categories_):
#         #     one_hot[xs == category, i] = 1
#         # return one_hot
#         # If categories are not specified, infer them from the data.
#         # Check if `xs` has more than 1 feature.
#         if xs.ndim > 1:
#             raise ValueError("`OneHotEncoder` only supports array with a single feature")
# 
#         # Create a binary matrix for each category.
#         encoded = np.zeros((len(xs), len(self.categories)))
#         for i, x in enumerate(xs):
#             # Convert the `numpy.ndarray` object to a tuple.
#             x_tuple = tuple(x)
# 
#             # Get the index of the category in the encoded matrix.
#             category_index = self.categories.index(x_tuple)
# 
#             # Set the corresponding element in the encoded matrix to 1.
#             encoded[i, category_index] = 1
# 
#     def inverse_transform(self, xs):
#         """Inverse transform `xs`."""
# 
#         # Get the index of the maximum value in each row.
#         decoded = np.argmax(xs, axis=1)
# 
#         # Convert the indices back to categories.
#         decoded = [self.categories[i] for i in decoded]
# 
#         return decoded
#         # xs = xs[:, np.newaxis]
#         # 
#         # # Find the index of the maximum value in each row.
#         # max_indices = np.argmax(xs, axis=1)
#         # 
#         # # Convert the max indices back to categories.
#         # categories = self.categories_[max_indices]
#         # return categories
#         ...
class OneHotEncoder:
    def __init__(self, categories, handle_unknown="ignore"):
        """Initialize the OneHotEncoder with the categories to encode and the handling of unknown categories."""
        self.categories = categories
        self.num_categories = len(categories)
        self.handle_unknown = handle_unknown

    def fit(self, xs):
        """Fit the OneHotEncoder to the data `xs`."""
        self.categories = jnp.unique(xs)
        self.num_categories = len(self.categories)
        return self

    def transform(self, xs):
        """Transform the data `xs` using one-hot encoding."""
        encoded_xs = jnp.zeros((xs.shape[0], self.num_categories))
        for i, category in enumerate(self.categories):
            encoded_xs[:, i] = xs == category
        return encoded_xs

    def inverse_transform(self, xs):
        """Inverse transform the data `xs` using one-hot encoding."""
        decoded_xs = jnp.zeros(xs.shape[0])
        for i, category in enumerate(self.categories):
            decoded_xs[xs[:, i] == 1] = category
        return decoded_xs

    def encode(self, xs):
        """Encode the data `xs` using one-hot encoding.

        Args:
            xs: A JAX array of shape (n_samples, 1).

        Returns:
            A JAX array of shape (n_samples, num_categories).
        """

        encoded_xs = self.transform(xs)

        if self.handle_unknown == "ignore":
            # Ignore any new categories that were not seen during training.
            encoded_xs = encoded_xs[:, : self.num_categories]
        elif self.handle_unknown == "error":
            # Raise an error if a new category is encountered.
            new_categories = jnp.setdiff1d(xs, self.categories)
            if len(new_categories) > 0:
                raise ValueError(
                    f"Encountered new categories during encoding: {new_categories}"
                )

        return encoded_xs

In [253]:
#!!! Do not edit things below.
xs = np.random.choice([0, 1, 2], size=(100, 1))
enc = OneHotEncoder().fit(xs)
transformed_xs = enc.transform(xs)
assert np.array_equal(
    transformed_xs,
    skp.OneHotEncoder(sparse=False).fit_transform(xs)
)
assert np.all(enc.inverse_transform(transformed_xs) == xs)

xs = np.array([0, 1, 2, np.nan, 0, 1, 2, np.nan], dtype=object).reshape(-1, 1)
enc = OneHotEncoder().fit(xs)
transformed_xs = enc.transform(xs)
assert np.array_equal(
    transformed_xs,
    skp.OneHotEncoder(sparse=False).fit_transform(xs)
)
assert np.all(enc.inverse_transform(transformed_xs) == xs.astype(str))

# It will fail if `xs` has more than 1 feature.
xs = xs.reshape(-1, 2)
enc = OneHotEncoder()
test_fail(lambda: enc.fit_transform(xs), 
          contains="`OneHotEncoder` only supports array with a single feature")



In [254]:
#| export
class Feature:
    def __init__(
        self,
        name: str,
        data: np.ndarray,
        preprocessor: DataPreprocessor = None,
    ):
        self.name = name
        self.data = data
        self.preprocessor = preprocessor

    def transform(self, xs):
        return self.preprocessor.transform(xs)

    def inverse_transform(self, xs):
        return self.preprocessor.inverse_transform(xs)

In [255]:
#| export
class FeaturesList:
    def __init__(self, features: list[Feature]):
        self.features = features
    
    def transform(self, xs):
        """Transform `xs`."""
        transformed_xs = np.empty((xs.shape[0], 0), dtype=np.float32)
        for feat in self.features:
            transformed_xs = np.concatenate([transformed_xs, feat.transform(xs)], axis=1)
        return transformed_xs

    def inverse_transform(self, xs):
        """Inverse transform `xs`."""
        inv_xs = np.empty((xs.shape[0], 0), dtype=np.float32)
        for feat in self.features:
            inv_xs = np.concatenate([inv_xs, feat.inverse_transform(xs)], axis=1)
        return inv_xs
    # def transform(self, xs):
    #     transformed_xs = []
    #     for feature in self.features:
    #         transformed_xs.append(feature.transform(xs[:, feature.name]))
    #     return jnp.concatenate(transformed_xs, axis=1)
    # 
    # def inverse_transform(self, xs):
    #     inverse_transformed_xs = []
    #     for i in range(len(self.features)):
    #         inverse_transformed_xs.append(self.features[i].inverse_transform(xs[:, i]))
    #     return jnp.concatenate(inverse_transformed_xs, axis=1)
# class FeaturesList:
#     def __init__(self, features: list[Feature]):
#         self.features = features
# 
#     def transform(self, xs):
#         transformed_xs = []
#         for i, feature in enumerate(self.features):
#             transformed_xs.append(feature.transform(xs[:, i]))
#         return np.concatenate(transformed_xs, axis=-1)
# 
#     def inverse_transform(self, xs):
#         inverse_transformed_xs = []
#         for i, feature in enumerate(self.features):
#             inverse_transformed_xs.append(feature.inverse_transform(xs[:, i]))
#         return np.concatenate(inverse_transformed_xs, axis=-1)

In [256]:
#!!! Do not edit things below.
train_xs = np.concatenate([
    np.random.randn(100, 1), 
    np.random.choice([0, 1, 2], size=(100, 1)), 
    np.random.randn(100, 1), 
    np.random.choice([0, 1, np.nan], size=(100, 1)),
], axis=-1)
test_xs = np.concatenate([
    np.random.randn(100, 1), 
    np.random.choice([0, 1, 2], size=(100, 1)), 
    np.random.randn(100, 1), 
    np.random.choice([0, 1, np.nan], size=(100, 1)),
], axis=-1)

feats = [
    Feature("a", train_xs[:, 0], MinMaxScaler()),
    Feature("b", train_xs[:, 1], OneHotEncoder()),
    Feature("c", train_xs[:, 2], MinMaxScaler()),
    Feature("d", train_xs[:, -1], OneHotEncoder()),
]
feats_list = FeaturesList(feats)
transformed_xs = feats_list.transform(test_xs)
assert transformed_xs.shape == (100, 8)
inv_xs = feats_list.inverse_transform(transformed_xs)
assert np.allclose(test_xs, inv_xs)

AssertionError: 

In [257]:
#!!! Do not edit things below.
ct = ColumnTransformer([
    ("a", skp.MinMaxScaler(), [0]),
    ("b", skp.OneHotEncoder(), [1]),
    ("c", skp.MinMaxScaler(), [2]),
    ("d", skp.OneHotEncoder(), [3]),
])
sk_transformed_xs = ct.fit_transform(test_xs)
assert np.allclose(transformed_xs, sk_transformed_xs)