In [None]:
#| default_exp data_utils

In [None]:
#| export
from __future__ import annotations
from fastcore.test import *
import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp
import einops


## Data Preprocessor

Lightweight, scikit-learn-style preprocessors such as `MinMaxScaler`, `OrdinalPreprocessor` and `OneHotEncoder`.

Note: each preprocessor handles a single feature (one column) at a time.

In [None]:
#| export
def _check_xs(xs: np.ndarray):
    if xs.ndim > 2 or (xs.ndim == 2 and xs.shape[1] != 1):
        raise ValueError(f"MinMaxScaler only supports array with a single feature, but got shape={xs.shape}.")
    
        
class DataPreprocessor:
    """Interface for single-feature preprocessors with a scikit-learn-like API.

    Subclasses implement ``fit``, ``transform`` and ``inverse_transform``;
    ``fit_transform`` is derived as ``fit`` followed by ``transform``.
    """

    def fit(self, xs, y=None):
        """Learn preprocessing parameters from ``xs`` (``y`` is accepted for API parity)."""
        raise NotImplementedError()

    def transform(self, xs):
        """Map ``xs`` into the preprocessed space."""
        raise NotImplementedError()

    def fit_transform(self, xs, y=None):
        """Fit on ``xs`` and return ``self.transform(xs)``."""
        self.fit(xs, y)
        result = self.transform(xs)
        return result

    def inverse_transform(self, xs):
        """Map preprocessed values back to the original space."""
        raise NotImplementedError()

In [None]:
#| export
class MinMaxScaler(DataPreprocessor):
    """Scale a single feature linearly onto ``[0, 1]`` using its observed min/max.

    Constant features (``max == min``) are mapped to ``0`` instead of producing
    ``nan``/``inf`` from a division by zero.
    """

    def fit(self, xs, y=None):
        """Record the column-wise minimum and maximum of ``xs``; returns ``self``.

        Raises:
            ValueError: if ``xs`` has more than one feature column.
        """
        _check_xs(xs)
        self.min_ = xs.min(axis=0)
        self.max_ = xs.max(axis=0)
        return self

    def transform(self, xs):
        """Return ``(xs - min_) / (max_ - min_)``, using a unit scale where the range is zero."""
        scale = self.max_ - self.min_
        # A zero range would divide by zero; substitute 1 so that constant
        # features map to 0, mirroring scikit-learn's zero-range handling.
        scale = np.where(scale == 0, 1, scale)
        return (xs - self.min_) / scale

    def inverse_transform(self, xs):
        """Map scaled values back via ``xs * (max_ - min_) + min_``.

        For a constant feature this returns the constant for any input, so the
        round trip ``inverse_transform(transform(xs))`` still recovers ``xs``.
        """
        return xs * (self.max_ - self.min_) + self.min_

In [None]:
# Round trip: a 1-D array and the same values as a single (100, 1) column
# must both survive fit_transform -> inverse_transform unchanged.
xs = np.random.randn(100)
for shape in [(100,), (100, 1)]:
    xs = xs.reshape(shape)
    scaler = MinMaxScaler()
    transformed_xs = scaler.fit_transform(xs)
    assert np.allclose(xs, scaler.inverse_transform(transformed_xs))

# An array with two feature columns must be rejected.
xs = xs.reshape(50, 2)
scaler = MinMaxScaler()
test_fail(lambda: scaler.fit_transform(xs),
          contains="MinMaxScaler only supports array with a single feature")

In [None]:
#| exporti
def _unique(xs):
    if xs.dtype == object:
        # Note: np.unique does not work with object dtype
        # We will enforce xs to be string type
        # It assumes that xs is a list of strings, and might not work
        # for other cases (e.g., list of string and numbers)
        return np.unique(xs.astype(str))
    return np.unique(xs)

In [None]:
#| export
class EncoderPreprocessor(DataPreprocessor):
    """Shared machinery for categorical encoders of a single feature.

    Categories are the sorted unique values of the training data (object arrays
    are compared through their ``str`` form), and each value is encoded as its
    index into that sorted array.
    """

    def _fit(self, xs, y=None):
        """Validate ``xs`` and store its sorted unique values in ``categories_``."""
        _check_xs(xs)
        self.categories_ = _unique(xs)

    def _check_known(self, xs):
        """Raise ``ValueError`` if ``xs`` contains values absent from ``categories_``."""
        known = np.isin(xs, self.categories_)
        if xs.dtype.kind in 'fc' and self.categories_.dtype.kind in 'fc':
            # nan != nan, so isin never matches it; accept nan if fit saw one.
            known |= np.isnan(xs) & np.isnan(self.categories_).any()
        if not np.all(known):
            unseen = np.unique(xs[~known]).tolist()
            raise ValueError(
                f"Found unknown categories {unseen} that were not seen during fit "
                f"(known categories: {self.categories_.tolist()})."
            )

    def _transform(self, xs):
        """Transform data to ordinal encoding.

        Returns the category indices with the two axes of ``xs`` swapped, so an
        ``(n, 1)`` input yields a ``(1, n)`` array.

        Raises:
            ValueError: if ``xs`` contains a value that was not seen during ``_fit``.
        """
        if xs.dtype == object:
            xs = xs.astype(str)
        # searchsorted returns an insertion point, not a match: an unseen value
        # would silently be encoded as a neighbouring category or as an
        # out-of-range index, so unknown values are rejected first.
        self._check_known(xs)
        ordinal = np.searchsorted(self.categories_, xs)
        return einops.rearrange(ordinal, 'k n -> n k')

    def _inverse_transform(self, xs):
        """Transform ordinal encoded data back to original data.

        Indexes ``categories_`` with ``xs`` and transposes the result, undoing the
        axis swap performed by ``_transform``.
        """
        return self.categories_[xs].T

In [None]:
#| export
class OrdinalPreprocessor(EncoderPreprocessor):
    """Encode a single categorical column as integer category indices."""

    def fit(self, xs, y=None):
        """Learn the sorted categories of ``xs``; returns ``self``."""
        self._fit(xs, y)
        return self

    def transform(self, xs):
        """Encode an ``(n, 1)`` column as an ``(n, 1)`` array of category indices.

        Raises:
            ValueError: if ``xs`` is 1-D.
        """
        if xs.ndim == 1:
            raise ValueError(f"OrdinalPreprocessor only supports 2D array with a single feature, "
                             f"but got shape={xs.shape}.")
        # ``_transform`` swaps the two axes (an (n, 1) input comes back as
        # (1, n)); transpose back so the output keeps one row per sample, like
        # the other preprocessors, and can be concatenated column-wise.
        return self._transform(xs).T

    def inverse_transform(self, xs):
        """Map an ``(n, 1)`` array of category indices back to an ``(n, 1)`` array of categories."""
        # ``_inverse_transform`` transposes its result, so hand it the
        # transposed indices to get an (n, 1) output.
        return self._inverse_transform(np.asarray(xs).T)

In [None]:
# Round trip on a plain string column, then on an object column holding NaNs
# (NaN is encoded through its string form 'nan', so compare against astype(str)).
string_col = np.random.choice(['a', 'b', 'c'], size=(100, 1))
nan_col = np.array(['a', 'b', 'c', np.nan] * 2, dtype=object).reshape(-1, 1)
for xs, expected in ((string_col, string_col), (nan_col, nan_col.astype(str))):
    enc = OrdinalPreprocessor().fit(xs)
    transformed_xs = enc.transform(xs)
    assert np.all(enc.inverse_transform(transformed_xs) == expected)

# 1-D input is rejected: the encoder needs an explicit (n, 1) column.
xs = np.random.choice(['a', 'b', 'c'], size=100)
test_fail(lambda: OrdinalPreprocessor().fit_transform(xs),
          contains="OrdinalPreprocessor only supports 2D array with a single feature")


In [None]:
#| export
class OneHotEncoder(EncoderPreprocessor):
    """One-hot encode a single categorical column (implemented without scikit-learn)."""

    def fit(self, xs, y=None):
        """Learn the sorted categories of ``xs``; returns ``self``."""
        self._fit(xs, y)
        return self

    def transform(self, xs):
        """Encode an ``(n, 1)`` column as an ``(n, n_categories)`` one-hot matrix.

        Raises:
            ValueError: if ``xs`` is 1-D.
        """
        if xs.ndim == 1:
            raise ValueError(f"OneHotEncoder only supports 2D array with a single feature, "
                             f"but got shape={xs.shape}.")
        codes = self._transform(xs)
        n_categories = len(self.categories_)
        encoded = jax.nn.one_hot(codes, n_categories)
        # ``codes`` is (1, n), so ``encoded`` is (1, n, d); fold it to (n, d).
        return einops.rearrange(encoded, 'k n d -> n (k d)')

    def inverse_transform(self, xs):
        """Decode one-hot (or score) rows via their argmax into an ``(n, 1)`` column."""
        codes = np.argmax(xs, axis=-1)
        decoded = self._inverse_transform(codes)
        return decoded.reshape(-1, 1)

In [None]:
xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
ohe = OneHotEncoder().fit(xs)
transformed_xs = ohe.transform(xs)
# One column per learned category, and exactly one active entry per row.
assert transformed_xs.shape == (100, len(ohe.categories_))
assert np.allclose(np.asarray(transformed_xs).sum(axis=-1), 1)
# Decoding the one-hot rows recovers the original labels.
assert np.all(ohe.inverse_transform(transformed_xs) == xs)


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


## Transformation


In [None]:
#| export
class Transformation:
    """Named wrapper around a preprocessor, plus a hook for constraining counterfactuals.

    Args:
        name: Short identifier of the transformation (e.g. ``"minmax"``).
        transformer: Object exposing ``fit``, ``transform``, ``fit_transform``
            and ``inverse_transform``; those calls are forwarded to it.
    """

    def __init__(self, name, transformer):
        self.name = name
        self.transformer = transformer

    def fit(self, xs, y=None):
        """Fit the wrapped transformer on ``xs`` (``y`` is ignored); returns ``self``."""
        self.transformer.fit(xs)
        return self

    def transform(self, xs):
        """Forward to the wrapped transformer's ``transform``."""
        return self.transformer.transform(xs)

    def fit_transform(self, xs, y=None):
        """Forward to the wrapped transformer's ``fit_transform`` (``y`` is ignored)."""
        return self.transformer.fit_transform(xs)

    def inverse_transform(self, xs):
        """Forward to the wrapped transformer's ``inverse_transform``."""
        return self.transformer.inverse_transform(xs)

    def apply_constraints(self, xs, cfs=None, hard: bool = False):
        """Default (unconstrained) projection of counterfactuals.

        Accepts the ``(xs, cfs, hard)`` call used by the subclasses and by
        ``Feature.apply_constraints`` and returns ``cfs`` unchanged. When called
        with ``xs`` only (the original one-argument form), ``xs`` is returned.
        """
        return xs if cfs is None else cfs

In [None]:
#| export
class MinMaxTransformation(Transformation):
    """Min-max scaling whose counterfactuals are kept inside the scaled range ``[0, 1]``."""

    def __init__(self):
        scaler = MinMaxScaler()
        super().__init__("minmax", scaler)

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Clip ``cfs`` element-wise to ``[0, 1]``; ``xs`` and ``hard`` are not used."""
        lower, upper = 0., 1.
        return jnp.clip(cfs, lower, upper)

In [None]:
# Counterfactuals drawn from N(0, 1) routinely leave [0, 1]; the min-max
# constraint must clip them back into the scaled range.
xs = np.random.randn(100, 1)
scaler = MinMaxTransformation()
transformed_xs = scaler.fit_transform(xs)

cfs = np.random.randn(100, 1)
cf_constrained = scaler.apply_constraints(xs, cfs)
assert ((cf_constrained >= 0) & (cf_constrained <= 1)).all()

In [None]:
#| export
class OneHotTransformation(Transformation):
    """One-hot encoding whose counterfactuals are relaxed (softmax) or hardened (one-hot)."""

    def __init__(self):
        super().__init__("ohe", OneHotEncoder())

    @property
    def categories(self) -> int:
        """Number of categories learned by the wrapped encoder."""
        n_categories = len(self.transformer.categories_)
        return n_categories

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Constrain the one-hot block ``cfs``; ``xs`` is not used.

        When ``hard`` is true each row becomes an exact one-hot vector at its
        argmax; otherwise the row-wise softmax is returned. The branch is chosen
        with ``jax.lax.cond``.
        """
        def _harden(logits):
            return jax.nn.one_hot(jnp.argmax(logits, axis=-1), self.categories)

        def _soften(logits):
            return jax.nn.softmax(logits, axis=-1)

        return jax.lax.cond(hard, _harden, _soften, cfs)

In [None]:
xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
ohe_t = OneHotTransformation().fit(xs)
transformed_xs = ohe_t.transform(xs)

cfs = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 3))
# hard=False: softmax relaxation, so every row is a probability vector.
soft = ohe_t.apply_constraints(transformed_xs, cfs, hard=False)
assert jnp.allclose(soft.sum(axis=-1), 1)
assert jnp.all(soft >= 0)
assert jnp.all(soft <= 1)

# hard=True: every row is an exact one-hot vector placed at the argmax of cfs.
hard = ohe_t.apply_constraints(transformed_xs, cfs, hard=True)
assert jnp.all((hard == 0) | (hard == 1))
assert jnp.allclose(hard.sum(axis=-1), 1)
assert jnp.all(hard.argmax(axis=-1) == cfs.argmax(axis=-1))


In [None]:
#| export
class OrdinalTransformation(Transformation):
    """Ordinal (integer index) encoding of a single categorical feature."""

    def __init__(self):
        super().__init__("ordinal", OrdinalPreprocessor())

    @property
    def categories(self) -> int:
        """Number of categories learned by the wrapped encoder."""
        return len(self.transformer.categories_)

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Return ``cfs`` unchanged.

        Defined explicitly so the three-argument ``(xs, cfs, hard)`` call made by
        ``Feature.apply_constraints`` works regardless of the base-class signature.
        """
        # TODO(review): consider rounding/clipping cfs to valid indices
        # [0, self.categories - 1], at least when ``hard`` is set.
        return cfs
    
class IdentityTransformation(Transformation):
    """No-op transformation: data and counterfactuals pass through unchanged."""

    def __init__(self):
        # There is no underlying transformer, so every forwarding method of
        # the base class must be overridden here.
        super().__init__("identity", None)

    def fit(self, xs, y=None):
        """Nothing to learn; returns ``self``."""
        return self

    def transform(self, xs):
        """Return ``xs`` unchanged."""
        return xs

    def fit_transform(self, xs, y=None):
        """Return ``xs`` unchanged."""
        return xs

    def inverse_transform(self, xs):
        """Return ``xs`` unchanged (the inherited version would call a ``None`` transformer)."""
        return xs

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Return ``cfs`` unchanged."""
        return cfs

In [None]:
# Ordinal encoding must round-trip string labels.
xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
encoder = OrdinalTransformation().fit(xs)
transformed_xs = encoder.transform(xs)
decoded = encoder.inverse_transform(transformed_xs)
assert (decoded == xs).all()

# The identity transformation hands back its input untouched.
xs = np.random.randn(100, 1)
scaler = IdentityTransformation()
transformed_xs = scaler.fit_transform(xs)
assert (transformed_xs == xs).all()

In [None]:
#| export
# Registry mapping the string names accepted by ``Feature(transformation=...)``
# to transformation classes; ``Feature`` instantiates them with no arguments.
PREPROCESSING_TRANSFORMATIONS = {
    'ohe': OneHotTransformation,
    'minmax': MinMaxTransformation,
    # Must be the Transformation wrapper, not the bare OrdinalPreprocessor:
    # ``Feature.apply_constraints`` calls ``apply_constraints`` on it, which the
    # preprocessor does not define.
    'ordinal': OrdinalTransformation,
    'identity': IdentityTransformation,
}

## Feature

In [None]:
#| export
class Feature:
    """A single named input column together with its transformation.

    Args:
        name: Column name.
        data: Raw values of the column (the examples use ``(n, 1)`` arrays).
        transformation: Either a key of ``PREPROCESSING_TRANSFORMATIONS``
            (instantiated with no arguments) or a ``Transformation`` instance.
        transformed_data: Optional precomputed transformed values. When omitted,
            they are computed on first access by fitting the transformation on
            ``data``, and then cached.
        is_immutable: If true, ``apply_constraints`` returns the original values.

    Raises:
        KeyError: if ``transformation`` is a string that is not a registry key.
        ValueError: if ``transformation`` is neither a string nor a ``Transformation``.
    """

    def __init__(
        self,
        name: str,
        data: np.ndarray,
        transformation: str | Transformation,
        transformed_data = None,
        is_immutable: bool = False,
    ):
        self.name = name
        self.data = data
        if isinstance(transformation, str):
            self.transformation = PREPROCESSING_TRANSFORMATIONS[transformation]()
        elif isinstance(transformation, Transformation):
            self.transformation = transformation
        else:
            raise ValueError(f"Unknown transformer {transformation}")
        self._transformed_data = transformed_data
        self.is_immutable = is_immutable

    @property
    def transformed_data(self):
        """Transformed ``data``; the transformation is fitted on first access and the result cached."""
        if self._transformed_data is None:
            # Cache so that repeated reads (``__repr__``, ``to_dict``,
            # ``FeaturesList``) do not refit the transformation every time.
            self._transformed_data = self.fit_transform(self.data)
        return self._transformed_data

    @classmethod
    def from_dict(cls, d):
        """Build a ``Feature`` from a dict of constructor keyword arguments."""
        return cls(**d)

    def to_dict(self):
        """Return the constructor arguments (including the computed ``transformed_data``)."""
        return {
            'name': self.name,
            'data': self.data,
            'transformed_data': self.transformed_data,
            'transformation': self.transformation,
            'is_immutable': self.is_immutable,
        }

    def __repr__(self):
        return f"Feature(" \
               f"name={self.name}, \ndata={self.data}, \n" \
               f"transformed_data={self.transformed_data}, \n" \
               f"transformer={self.transformation}, \n" \
               f"is_immutable={self.is_immutable})"

    __str__ = __repr__

    def __getitem__(self, idx):
        """Return the raw and transformed values at ``idx`` as a dict."""
        return {
            'data': self.data[idx],
            'transformed_data': self.transformed_data[idx],
        }

    # Backward-compatible alias for the original misspelled name, which Python
    # never invoked for ``feature[idx]``.
    __get_item__ = __getitem__

    def fit(self):
        """Fit the transformation on ``self.data``; returns ``self``."""
        self.transformation.fit(self.data)
        return self

    def transform(self, xs):
        """Apply the fitted transformation to ``xs``."""
        return self.transformation.transform(xs)

    def fit_transform(self, xs):
        """Fit the transformation on ``xs`` and return the transformed values."""
        return self.transformation.fit_transform(xs)

    def inverse_transform(self, xs):
        """Map transformed values back to the original space."""
        return self.transformation.inverse_transform(xs)

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Constrain counterfactuals ``cfs`` for this feature.

        Immutable features return the original values ``xs``; otherwise the
        result of the transformation's ``apply_constraints(xs, cfs, hard)``.
        """
        # ``is_immutable`` is a plain bool set at construction, so a Python branch
        # suffices. Unlike ``jax.lax.cond`` it neither traces both branches nor
        # requires ``xs`` and the constrained ``cfs`` to share a dtype
        # (e.g. integer ordinal codes vs. float counterfactuals).
        if self.is_immutable:
            return xs
        return self.transformation.apply_constraints(xs, cfs, hard)

In [None]:
# A continuous feature is min-max scaled into [0, 1] and round-trips.
feat_cont = Feature(name='continuous', data=np.random.randn(100, 1),
                    transformation='minmax', is_immutable=False)
cont_scaled = feat_cont.transformed_data
assert cont_scaled.shape == (100, 1)
assert 0 <= cont_scaled.min() and cont_scaled.max() <= 1
assert jnp.allclose(feat_cont.inverse_transform(cont_scaled), feat_cont.data)

# A categorical feature gets one one-hot column per category and round-trips.
feat_cat = Feature(name='category', data=np.random.choice(['a', 'b', 'c'], size=(100, 1)),
                   transformation='ohe', is_immutable=False)
cat_encoded = feat_cat.transformed_data
assert cat_encoded.shape == (100, 3)
assert np.all(feat_cat.inverse_transform(cat_encoded) == feat_cat.data)

In [None]:
#| export
class FeaturesList:
    """An ordered collection of ``Feature`` objects viewed as one concatenated matrix.

    Args:
        features: The features, in column order.
        *args, **kwargs: Accepted and ignored.
    """
    def __init__(
        self,
        features: list[Feature],
        *args, **kwargs
    ):
        self._features = features
        # Both caches start as None ("not computed yet") and are filled together
        # by ``_transform_data``. Initialising the indices to [] made
        # ``feature_indices`` return an empty list until ``transformed_data`` had
        # been read, so ``apply_constraints`` found no column ranges.
        self._feature_indices = None
        self._transformed_data = None

    @property
    def features(self):
        """The wrapped list of features."""
        return self._features

    @property
    def feature_indices(self):
        """``(start, end)`` column range of each feature in ``transformed_data``."""
        if self._feature_indices is None:
            self._transform_data()
        return self._feature_indices

    @property
    def transformed_data(self):
        """All features' transformed data concatenated along the last axis."""
        if self._transformed_data is None:
            self._transform_data()
        return self._transformed_data

    def _transform_data(self):
        """Concatenate the features' transformed data and record each feature's column range."""
        feature_indices, blocks = [], []
        start = 0
        for feat in self.features:
            block = feat.transformed_data
            end = start + block.shape[-1]
            feature_indices.append((start, end))
            blocks.append(block)
            start = end
        self._feature_indices = feature_indices
        self._transformed_data = jnp.concatenate(blocks, axis=-1)

    def transform(self, data):
        raise NotImplementedError

    def inverse_transform(self, xs):
        raise NotImplementedError

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Apply each feature's constraints to its column slice and re-assemble the matrix.

        ``xs`` and ``cfs`` are sliced as ``[:, start:end]`` per feature, so they are
        expected to be 2-D with the same column layout as ``transformed_data``.
        """
        constrained_cfs = [
            feat.apply_constraints(xs[:, start:end], cfs[:, start:end], hard)
            for (start, end), feat in zip(self.feature_indices, self.features)
        ]
        return jnp.concatenate(constrained_cfs, axis=-1)

In [None]:
df = pd.read_csv('assets/adult/data.csv')
cont_feats = ['age', 'hours_per_week']
cat_feats = ["workclass", "education", "marital_status", "occupation", "race", "gender"]

# Continuous columns are min-max scaled, categorical ones one-hot encoded;
# each column is passed as an (n, 1) array.
feature_specs = [(name, 'minmax') for name in cont_feats] + [(name, 'ohe') for name in cat_feats]
feats_list = FeaturesList([
    Feature(name, df[name].to_numpy().reshape(-1, 1), transformation)
    for name, transformation in feature_specs
])

# 2 scaled columns + 27 one-hot columns over the 32,561 rows of the dataset.
assert feats_list.transformed_data.shape == (32561, 29)
cfs = np.random.randn(10, 29)
xs_head = feats_list.transformed_data[:10, :]
for hard_mode in (False, True):
    assert feats_list.apply_constraints(xs_head, cfs, hard=hard_mode).shape == (10, 29)

In [None]:
from sklearn.preprocessing import OneHotEncoder as SkOneHotEncoder, MinMaxScaler as SkMinMaxScaler

In [None]:
# Cross-check every feature against the scikit-learn reference implementation.
sk_ohe = SkOneHotEncoder(sparse_output=False)
sk_minmax = SkMinMaxScaler()

for feat in feats_list.features:
    is_continuous = feat.name in cont_feats
    reference = sk_minmax if is_continuous else sk_ohe
    message = f"Failed at {feat.name}. " if is_continuous else f"Failed at {feat.name}"
    assert np.allclose(reference.fit_transform(feat.data), feat.transformed_data), message