In [None]:
#| default_exp data_preprocessor

# Data Preprocessing

In [None]:
#| export
from __future__ import annotations
import jax
import jax.numpy as jnp
import numpy as np

In [None]:
import sklearn.preprocessing as skp
from sklearn.preprocessing import OneHotEncoder as SklearnOneHotEncoder

In [None]:
#| export
class DataPreprocessor:
    """Abstract interface shared by every preprocessor in this module.

    Subclasses override `fit`, `transform` and `inverse_transform`;
    `fit_transform` is provided here in terms of the first two.
    """

    def fit(self, xs, y=None):
        """Learn the preprocessing parameters from `xs` (and optionally `y`)."""
        raise NotImplementedError()

    def transform(self, xs):
        """Apply the fitted preprocessing to `xs`."""
        raise NotImplementedError()

    def fit_transform(self, xs, y=None):
        """Call `fit(xs, y)` and return the result of `transform(xs)`."""
        # The return value of `fit` is deliberately ignored, so subclasses
        # whose `fit` does not return `self` still work here.
        self.fit(xs, y)
        transformed = self.transform(xs)
        return transformed

    def inverse_transform(self, xs):
        """Map transformed values `xs` back to the original space."""
        raise NotImplementedError()


In [None]:
#| export
def _check_xs(xs: np.ndarray, name: str):
    """Check if `xs` is a 1D array with shape (n_samples,) or a 2D array with shape (n_samples, 1)."""
    if xs.ndim > 2 or (xs.ndim == 2 and xs.shape[1] != 1):
        raise ValueError(f"`{name}` only supports array with a single feature, but got shape={xs.shape}.")

In [None]:
#| export
class MinMaxScaler(DataPreprocessor):
    """Linearly rescale a single feature so the fitted data spans `[0, 1]`.

    After `fit`, `min_` and `max_` hold the minimum and maximum of the
    fitted data along axis 0. A constant feature (`max_ == min_`) is
    scaled with a range of 1 instead of 0, so it maps to 0 rather than
    producing NaN/inf from a division by zero.
    """

    def fit(self, xs, y=None):
        """Record the minimum and maximum of `xs`; `y` is ignored.

        Raises:
            ValueError: If `xs` is not a single-feature array (see `_check_xs`).
        """
        _check_xs(xs, name="MinMaxScaler")
        self.min_ = xs.min(axis=0)
        self.max_ = xs.max(axis=0)
        return self

    def _range(self):
        """Return `max_ - min_`, with a zero range replaced by 1."""
        span = self.max_ - self.min_
        # `span == 0` is a boolean that adds 1 exactly where the range is
        # zero and 0 elsewhere, leaving non-degenerate ranges untouched.
        return span + (span == 0)

    def transform(self, xs):
        """Scale `xs` using the fitted minimum and range."""
        return (xs - self.min_) / self._range()

    def inverse_transform(self, xs):
        """Undo `transform`, mapping scaled values back to the original range."""
        return xs * self._range() + self.min_

In [None]:
#!!! Do not edit things below.
# Checks for `MinMaxScaler` on a 1D array of 100 random samples.
xs = np.random.randn(100, )
scaler = MinMaxScaler()
transformed_xs = scaler.fit_transform(xs)
# The 1D input shape is preserved by the transform.
assert transformed_xs.shape == (100, )
# Round trip: `inverse_transform` recovers the original values.
assert np.allclose(xs, scaler.inverse_transform(transformed_xs))
# Correctness: the output matches scikit-learn's `MinMaxScaler` applied to
# the same data reshaped to (100, 1).
assert np.allclose(
    transformed_xs, 
    skp.MinMaxScaler().fit_transform(xs.reshape(100, 1)).reshape(100,)
)
# The scaler also works with a 2D array of shape (100, 1).
xs = xs.reshape(100, 1)
scaler = MinMaxScaler()
transformed_xs = scaler.fit_transform(xs)
# Same round-trip and reference checks as above, now on the 2D input.
assert np.allclose(xs, scaler.inverse_transform(transformed_xs))
assert np.allclose(
    transformed_xs, 
    skp.MinMaxScaler().fit_transform(xs.reshape(100, 1))
)

In [None]:
#| export
class OneHotEncoder(DataPreprocessor):
    """One-hot encoder for a single categorical feature.

    After `fit`, `categories_` holds the sorted unique values of the fitted
    data. Object-dtype inputs are cast to `str` first, so their categories
    are strings sorted lexicographically (for example `'10'` sorts before
    `'2'`) and a missing value appears as the category `'nan'`.

    Values not seen during `fit` are encoded as an all-zero row; decoding
    such a row yields the first category, because `argmax` of an all-zero
    row is 0.
    """

    @staticmethod
    def _as_column(xs):
        """Return `xs` as an `(n_samples, 1)` array, casting object arrays to `str`."""
        xs = np.asarray(xs)
        if xs.dtype == object:
            # Mixed-type object arrays (e.g. ints and NaN) cannot be sorted
            # or compared reliably; their string form can.
            xs = xs.astype(str)
        return xs.reshape(-1, 1)

    def fit(self, xs, y=None):
        """Fit the OneHotEncoder with `xs`; `y` is ignored.

        Raises:
            ValueError: If `xs` is not a single-feature array (see `_check_xs`).
        """
        _check_xs(xs, name="OneHotEncoder")
        self.categories_ = np.unique(self._as_column(xs))
        return self

    def transform(self, xs):
        """Transform `xs` into a float array of shape `(n_samples, n_categories)`."""
        col = self._as_column(xs)
        # Broadcasting (n, 1) against (k,) gives the (n, k) membership mask.
        onehot = col == self.categories_
        if self.categories_.dtype.kind == "f":
            # NaN never compares equal to itself, so match missing values
            # in floating-point data explicitly.
            onehot = onehot | (np.isnan(col) & np.isnan(self.categories_))
        return onehot.astype(float)

    def inverse_transform(self, xs):
        """Inverse transform `xs`, returning an `(n_samples, 1)` array of categories."""
        indices = np.asarray(xs).argmax(axis=-1)
        return self.categories_[indices].reshape(-1, 1)

In [None]:
#!!! Do not edit things below.
# Integer categories: 100 samples drawn from {0, 1, 2}, shaped (100, 1).
# `fit` must return the encoder, since `transform` is called on its result.
xs = np.random.choice([0, 1, 2], size=(100, 1))
enc = OneHotEncoder().fit(xs)
transformed_xs = enc.transform(xs)
# The encoding must be exactly equal to scikit-learn's dense one-hot output.
# NOTE(review): newer scikit-learn releases rename `sparse` to
# `sparse_output` — confirm this keyword against the installed version.
assert np.array_equal(
    transformed_xs,
    skp.OneHotEncoder(sparse=False).fit_transform(xs)
)
# Round trip: decoding the one-hot rows recovers the original (100, 1) column.
assert np.all(enc.inverse_transform(transformed_xs) == xs)

# Object-dtype input mixing integers with NaN, shaped (8, 1).
xs = np.array([0, 1, 2, np.nan, 0, 1, 2, np.nan], dtype=object).reshape(-1, 1)
enc = OneHotEncoder().fit(xs)
transformed_xs = enc.transform(xs)
# The encoding must again match scikit-learn's output exactly.
assert np.array_equal(
    transformed_xs,
    skp.OneHotEncoder(sparse=False).fit_transform(xs)
)
# Round trip: here the decoded values are compared with the string form of `xs`.
assert np.all(enc.inverse_transform(transformed_xs) == xs.astype(str))

In [None]:
#| export
class Feature:
    """A named column of data paired with an optional preprocessor.

    Args:
        name: Name of the feature.
        data: Values of the feature; stored as-is and not passed to the
            preprocessor by this class.
        preprocessor: Object providing `transform` and `inverse_transform`.
            It is used as given (this class never calls `fit` on it).
            When `None`, both methods below return their input unchanged.
    """

    def __init__(
        self,
        name: str,
        data: np.ndarray,
        preprocessor: DataPreprocessor | None = None,
    ):
        self.name = name
        self.data = data
        self.preprocessor = preprocessor

    def transform(self, xs):
        """Transform `xs` with the preprocessor, or return `xs` if there is none."""
        if self.preprocessor is None:
            return xs
        return self.preprocessor.transform(xs)

    def inverse_transform(self, xs):
        """Inverse transform `xs` with the preprocessor, or return `xs` if there is none."""
        if self.preprocessor is None:
            return xs
        return self.preprocessor.inverse_transform(xs)

In [None]:
#| export
class FeaturesList:
    """Container for several `Feature` objects (not implemented yet).

    Every method body is currently a placeholder: `__init__` does not store
    `features`, and `transform` / `inverse_transform` return `None`.
    """

    def __init__(self, features: list[Feature]):
        """Placeholder constructor; `features` is not stored yet."""
        ...

    def transform(self, xs):
        """Placeholder; currently returns `None`."""
        # TODO: use `jax.lax.scan` to implement this function.
        ...

    def inverse_transform(self, xs):
        """Placeholder; currently returns `None`."""
        # TODO: use `jax.lax.scan` to implement this function.
        ...