# Data Transformations and Features

In [None]:
#| default_exp data_utils

In [None]:
#| export
from __future__ import annotations
from fastcore.test import *
import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp
import einops
import os, sys, json, pickle
import shutil
from relax.utils import *

In [None]:
#| hide
import sklearn.preprocessing as skp

## Data Preprocessors

`DataPreprocessor` transforms *individual* features into numerical representations for the machine learning and recourse generation workflows. 
It can be considered a drop-in, JAX-friendly replacement for the 
[sklearn.preprocessing](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.preprocessing) module.
The supported preprocessing methods include `MinMaxScaler`, `OneHotEncoder`, and `OrdinalPreprocessor`. 

:::{.callout-important}

Unlike the [sklearn.preprocessing](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.preprocessing) module,
every data preprocessor here works only with a single feature (e.g., Dim: `(B, 1)`). 

:::

In [None]:
#| export
def _check_xs(xs: np.ndarray, name: str):
    """Raise `ValueError` unless `xs` holds exactly one feature.

    Accepted shapes: 0-D, 1-D `(B,)` and 2-D `(B, 1)`. Arrays with more than
    two dimensions, or 2-D arrays whose second axis is not of size 1, are rejected.
    `name` is the caller's name, embedded in the error message.
    """
    holds_single_feature = xs.ndim < 2 or (xs.ndim == 2 and xs.shape[1] == 1)
    if not holds_single_feature:
        raise ValueError(f"`{name}` only supports array with a single feature, but got shape={xs.shape}.")
    
        
class DataPreprocessor:
    """Common interface of the single-feature preprocessors in this module.

    Subclasses override `fit`, `transform`, `inverse_transform`, `to_dict` and
    `from_dict`; `fit_transform` is expressed in terms of `fit` and `transform`.
    """

    def __init__(
        self, 
        name: str = None, # Name of this preprocessor; an empty or None value falls back to the class name.
    ):
        """Abstract base class for data preprocessors."""
        self.name = name if name else type(self).__name__

    def fit(self, xs, y=None):
        """Learn the preprocessor's parameters from `xs` (and optionally `y`)."""
        raise NotImplementedError

    def transform(self, xs):
        """Map `xs` to its transformed representation."""
        raise NotImplementedError

    def fit_transform(self, xs, y=None):
        """Run `fit(xs, y)` and then return `transform(xs)`."""
        self.fit(xs, y)
        transformed = self.transform(xs)
        return transformed

    def inverse_transform(self, xs):
        """Map transformed `xs` back to the original representation."""
        raise NotImplementedError

    def to_dict(self) -> dict:
        """Serialize the fitted state into a dictionary."""
        raise NotImplementedError

    def from_dict(self, params: dict):
        """Restore the fitted state from a dictionary produced by `to_dict`."""
        raise NotImplementedError

    __ALL__ = ["fit", "transform", "fit_transform", "inverse_transform", "to_dict", "from_dict"]

In [None]:
#| export
class MinMaxScaler(DataPreprocessor): 
    """Scale a single feature to the [0, 1] range using its fitted minimum and maximum."""

    def __init__(self):
        super().__init__(name="minmax")

    def fit(self, xs, y=None):
        """Record the column-wise minimum and maximum of `xs` (`y` is ignored); returns `self`."""
        _check_xs(xs, name="MinMaxScaler")
        self.min_ = xs.min(axis=0)
        self.max_ = xs.max(axis=0)
        return self

    def transform(self, xs):
        """Return `(xs - min_) / (max_ - min_)`.

        A constant feature (`max_ == min_`) is mapped to 0 instead of NaN/inf,
        matching `sklearn.preprocessing.MinMaxScaler`.
        """
        data_range = self.max_ - self.min_
        # Replace a zero range by 1. `range + (range == 0)` adds 1 only where the
        # range is zero; plain arithmetic keeps this working for NumPy and JAX arrays alike.
        safe_range = data_range + (data_range == 0)
        return (xs - self.min_) / safe_range

    def inverse_transform(self, xs):
        """Map scaled values back to the original range (a constant feature maps back to `min_`)."""
        return xs * (self.max_ - self.min_) + self.min_

    def from_dict(self, params: dict):
        """Restore `min_` / `max_` from a dict produced by `to_dict`; returns `self`."""
        self.min_ = params["min_"]
        self.max_ = params["max_"]
        return self

    def to_dict(self) -> dict:
        """Return the fitted statistics as `{"min_": ..., "max_": ...}`."""
        return {"min_": self.min_, "max_": self.max_}

In [None]:
# 1-D input of shape (100,): output keeps the input shape and inverts exactly.
xs = np.random.randn(100, )
scaler = MinMaxScaler()
transformed_xs = scaler.fit_transform(xs)
assert transformed_xs.shape == (100, )
assert np.allclose(xs, scaler.inverse_transform(transformed_xs))
# Test correctness against sklearn (fed a (100, 1) reshape of the same data)
assert np.allclose(
    transformed_xs, 
    skp.MinMaxScaler().fit_transform(xs.reshape(100, 1)).reshape(100,)
)
# Also works with a 2-D (B, 1) array
xs = xs.reshape(100, 1)
scaler = MinMaxScaler()
transformed_xs = scaler.fit_transform(xs)
assert np.allclose(xs, scaler.inverse_transform(transformed_xs))
assert np.allclose(
    transformed_xs, 
    skp.MinMaxScaler().fit_transform(xs.reshape(100, 1))
)

`MinMaxScaler` only supports scaling a single feature; passing an array with more than one column raises a `ValueError`.

In [None]:
# Reuses the 100 values from the previous cell: a (50, 2) array has two columns,
# so `fit` (via `_check_xs`) must reject it.
xs = xs.reshape(50, 2)
scaler = MinMaxScaler()
test_fail(lambda: scaler.fit_transform(xs), 
          contains="`MinMaxScaler` only supports array with a single feature")

Convert a fitted scaler to a dictionary (i.e., its pytree representation) with `to_dict`, and restore it with `from_dict`.

In [None]:
# Round-trip the fitted state through `to_dict` / `from_dict`; both scalers must agree.
xs = xs.reshape(-1, 1)
scaler = MinMaxScaler().fit(xs)
scaler_1 = MinMaxScaler().from_dict(scaler.to_dict())
assert np.allclose(scaler.transform(xs), scaler_1.transform(xs))

In [None]:
#| exporti
def _unique(xs):
    """Return the sorted unique values of `xs`.

    `np.unique` can fail on object arrays (e.g. strings mixed with NaN), so object
    arrays are cast to `str` first; a NaN then becomes the string 'nan'. This
    assumes the object array holds strings and may misbehave for other mixtures
    (e.g. strings together with numbers).
    """
    values = xs.astype(str) if xs.dtype == object else xs
    return np.unique(values)

In [None]:
#| export
class EncoderPreprocessor(DataPreprocessor):
    """Encode categorical features as an integer array."""

    def _fit(self, xs, y=None):
        """Validate `xs` and store its sorted unique values in `categories_`."""
        _check_xs(xs, name="EncoderPreprocessor")
        self.categories_ = _unique(xs)

    def _transform(self, xs):
        """Transform data to ordinal encoding: each value becomes its index in `categories_`."""
        # `_unique` casts object arrays to str when fitting, so cast the same way here
        # for the lookup to line up (e.g. NaN -> 'nan').
        values = xs.astype(str) if xs.dtype == object else xs
        # NOTE(review): `searchsorted` does not check membership; a value never seen
        # during fitting silently maps to its insertion position.
        return np.searchsorted(self.categories_, values)

    def _inverse_transform(self, xs):
        """Transform ordinal-encoded data back to the original category values."""
        return self.categories_[xs.T].T

    def from_dict(self, params: dict):
        """Restore `categories_` from a dict produced by `to_dict`; returns `self`."""
        self.categories_ = params["categories_"]
        return self

    def to_dict(self) -> dict:
        """Return the fitted vocabulary as `{"categories_": ...}`."""
        return {"categories_": self.categories_}

In [None]:
#| export
class OrdinalPreprocessor(EncoderPreprocessor):
    """Ordinal encoder for a single feature: each category becomes its integer index."""

    def fit(self, xs, y=None):
        """Learn the sorted category vocabulary from `xs`; returns `self`."""
        self._fit(xs, y)
        return self

    def transform(self, xs):
        """Encode `xs` as integer indices into `categories_`.

        Raises:
            ValueError: If `xs` is 1-D.
        """
        is_flat = xs.ndim == 1
        if is_flat:
            raise ValueError(
                f"OrdinalPreprocessor only supports 2D array with a single feature, "
                f"but got shape={xs.shape}."
            )
        return self._transform(xs)

    def inverse_transform(self, xs):
        """Map integer indices back to their category values."""
        return self._inverse_transform(xs)

In [None]:
# String categories of shape (100, 1): encoding then decoding must reproduce the input.
xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
enc = OrdinalPreprocessor().fit(xs)
transformed_xs = enc.transform(xs)
assert np.all(enc.inverse_transform(transformed_xs) == xs)
# Test from_dict and to_dict
enc_1 = OrdinalPreprocessor().from_dict(enc.to_dict())
assert np.all(enc.transform(xs) == enc_1.transform(xs))

# Object array containing NaN: values are cast to str, so NaN becomes the category 'nan'.
xs = np.array(['a', 'b', 'c', np.nan, 'a', 'b', 'c', np.nan], dtype=object).reshape(-1, 1)
enc = OrdinalPreprocessor().fit(xs)
# Check categories_
assert np.array_equiv(enc.categories_, np.array(['a', 'b', 'c', np.nan], dtype=str)) 
transformed_xs = enc.transform(xs)
assert transformed_xs.shape == (8, 1)
inverse_transformed_xs = enc.inverse_transform(transformed_xs)
# Decoded values are strings, so compare against the str-cast input.
assert np.all(inverse_transformed_xs == xs.astype(str))
# Test from_dict and to_dict
enc_1 = OrdinalPreprocessor().from_dict(enc.to_dict())
assert np.all(enc.transform(xs) == enc_1.transform(xs))
assert np.array_equal(enc.categories_, enc_1.categories_)

# 1-D input: `fit` accepts it, but `transform` requires a 2-D (B, 1) array.
xs = np.random.choice(['a', 'b', 'c'], size=(100, ))
test_fail(lambda: OrdinalPreprocessor().fit_transform(xs), 
    contains="OrdinalPreprocessor only supports 2D array with a single feature")


In [None]:
#| export
class OneHotEncoder(EncoderPreprocessor):
    """One-hot encoder for a single categorical feature."""

    def fit(self, xs, y=None):
        """Learn the sorted category vocabulary from `xs`; returns `self`."""
        self._fit(xs, y)
        return self

    def transform(self, xs):
        """One-hot encode `xs`; a `(B, 1)` input yields a `(B, n_categories)` array.

        Raises:
            ValueError: If `xs` is 1-D.
        """
        if xs.ndim == 1:
            raise ValueError(
                f"OneHotEncoder only supports 2D array with a single feature, "
                f"but got shape={xs.shape}."
            )
        codes = self._transform(xs)
        # For a (B, 1) input, `codes` is (B, 1) and the one-hot tensor is (B, 1, K);
        # merging the last two axes gives (B, K).
        one_hot = jax.nn.one_hot(codes, len(self.categories_))
        return einops.rearrange(one_hot, 'n k d -> n (k d)')

    def inverse_transform(self, xs):
        """Decode rows back to a `(B, 1)` array of categories via argmax over the last axis."""
        codes = np.argmax(xs, axis=-1)
        decoded = self._inverse_transform(codes)
        return decoded.reshape(-1, 1)

In [None]:
# String categories of shape (100, 1): one-hot encoding then argmax decoding must reproduce the input.
xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
enc = OneHotEncoder().fit(xs)
transformed_xs = enc.transform(xs)
assert np.all(enc.inverse_transform(transformed_xs) == xs)
# Test from_dict and to_dict
enc_1 = OneHotEncoder().from_dict(enc.to_dict())
assert np.all(enc.transform(xs) == enc_1.transform(xs))

# Object array containing NaN: NaN is encoded as the string category 'nan'.
xs = np.array(['a', 'b', 'c', np.nan, 'a', 'b', 'c', np.nan], dtype=object).reshape(-1, 1)
enc = OneHotEncoder().fit(xs)
transformed_xs = enc.transform(xs)
assert np.all(enc.inverse_transform(transformed_xs) == xs.astype(str))
# Test from_dict and to_dict
enc_1 = OneHotEncoder().from_dict(enc.to_dict())
# `from_dict` also mutates the instance in place, so its return value may be ignored.
enc_2 = OneHotEncoder()
enc_2.from_dict(enc_1.to_dict())
assert np.all(enc.transform(xs) == enc_1.transform(xs))
assert np.all(enc.transform(xs) == enc_2.transform(xs))

# 1-D input: `fit` accepts it, but `transform` requires a 2-D (B, 1) array.
xs = np.random.choice(['a', 'b', 'c'], size=(100, ))
test_fail(lambda: OneHotEncoder().fit_transform(xs), 
    contains="OneHotEncoder only supports 2D array with a single feature")


## Transformation


In [None]:
#| export
class Transformation:
    """Named wrapper around a preprocessor that adds counterfactual constraint handling.

    `transformer` must provide `fit`, `transform`, `fit_transform`,
    `inverse_transform`, `to_dict` and `from_dict`; all data methods delegate to it.
    """

    def __init__(self, name, transformer):
        self.name = name
        self.transformer = transformer

    def fit(self, xs, y=None):
        """Fit the wrapped transformer on `xs` (`y` is ignored); returns `self`."""
        self.transformer.fit(xs)
        return self

    def transform(self, xs):
        return self.transformer.transform(xs)

    def fit_transform(self, xs, y=None):
        return self.transformer.fit_transform(xs)

    def inverse_transform(self, xs):
        return self.transformer.inverse_transform(xs)

    def apply_constraints(self, xs, cfs=None, hard: bool = False):
        """Default constraint: none. Return the counterfactuals `cfs` unchanged.

        Callers such as `Feature.apply_constraints` pass `(xs, cfs, hard)`, so the
        default must accept that signature; subclasses without constraints inherit it.
        For backward compatibility, calling with `xs` alone still returns `xs`.
        """
        return xs if cfs is None else cfs

    def from_dict(self, params: dict):
        """Restore `name` and the wrapped transformer's state; returns `self`."""
        self.name = params["name"]
        self.transformer.from_dict(params["transformer"])
        return self

    def to_dict(self) -> dict:
        """Return `{"name": ..., "transformer": <transformer.to_dict()>}`."""
        return {"name": self.name, "transformer": self.transformer.to_dict()}

In [None]:
#| export
class MinMaxTransformation(Transformation):
    """Min-max scaling (`MinMaxScaler`) wrapped as a `Transformation` named "minmax"."""

    def __init__(self):
        super().__init__("minmax", MinMaxScaler())

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Clip counterfactuals into the scaled [0, 1] range; `xs` and `hard` are unused."""
        lower, upper = 0., 1.
        return jnp.clip(cfs, lower, upper)

In [None]:
xs = np.random.randn(100, 1)
scaler = MinMaxTransformation()
transformed_xs = scaler.fit_transform(xs)

# Unconstrained random counterfactuals must be clipped into [0, 1].
cfs = np.random.randn(100, 1)
cf_constrained = scaler.apply_constraints(xs, cfs)
assert np.all(cf_constrained >= 0) and np.all(cf_constrained <= 1)

# Test from_dict and to_dict
scaler_1 = MinMaxTransformation().from_dict(scaler.to_dict())
assert np.allclose(scaler.transform(xs), scaler_1.transform(xs))

In [None]:
#| export
class OneHotTransformation(Transformation):
    """One-hot encoding wrapped as a `Transformation` named "ohe"."""

    def __init__(self):
        super().__init__("ohe", OneHotEncoder())

    @property
    def categories(self) -> int:
        """Number of categories learned by the underlying `OneHotEncoder`."""
        return len(self.transformer.categories_)

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Project counterfactual scores `cfs` onto one-hot form; `xs` is unused.

        With `hard=True` each row becomes an exact one-hot vector at its argmax;
        otherwise each row is passed through a softmax over the last axis.
        """
        def _hard(scores):
            winners = jnp.argmax(scores, axis=-1)
            return jax.nn.one_hot(winners, self.categories)

        def _soft(scores):
            return jax.nn.softmax(scores, axis=-1)

        return jax.lax.cond(hard, _hard, _soft, cfs)

In [None]:
xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
ohe_t = OneHotTransformation().fit(xs)
transformed_xs = ohe_t.transform(xs)

# Random scores for the 3 categories.
cfs = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 3))
# Test hard=False, which applies the softmax function (rows sum to 1, values in [0, 1]).
soft = ohe_t.apply_constraints(transformed_xs, cfs, hard=False)
assert jnp.allclose(soft.sum(axis=-1), 1)
assert jnp.all(soft >= 0)
assert jnp.all(soft <= 1)

# Test hard=True, which enforces the one-hot constraint (exactly one 1 per row).
hard = ohe_t.apply_constraints(transformed_xs, cfs, hard=True)
assert np.all([1 in x for x in hard])
assert np.all([0 in x for x in hard])
assert jnp.allclose(hard.sum(axis=-1), 1)

# Test from_dict and to_dict
ohe_t_1 = OneHotTransformation().from_dict(ohe_t.to_dict())
assert np.allclose(ohe_t.transform(xs), ohe_t_1.transform(xs))

In [None]:
#| export
class OrdinalTransformation(Transformation):
    """Ordinal (integer) encoding wrapped as a `Transformation` named "ordinal"."""

    def __init__(self):
        super().__init__("ordinal", OrdinalPreprocessor())

    @property
    def categories(self) -> int:
        """Number of categories learned by the underlying `OrdinalPreprocessor`."""
        return len(self.transformer.categories_)

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Return `cfs` unchanged; no constraint is enforced on ordinal codes.

        Defined explicitly so the `(xs, cfs, hard)` call made by
        `Feature.apply_constraints` works regardless of the base-class default.
        """
        return cfs
    
class IdentityTransformation(Transformation):
    """Pass-through `Transformation` named "identity": data is used as-is."""

    def __init__(self):
        # No underlying preprocessor, so every data method is overridden below.
        super().__init__("identity", None)

    def fit(self, xs, y=None):
        return self

    def transform(self, xs):
        return xs

    def fit_transform(self, xs, y=None):
        return xs

    def inverse_transform(self, xs):
        # Must be overridden: the inherited version calls
        # `self.transformer.inverse_transform`, and `transformer` is None here.
        return xs

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """No constraint: return `cfs` unchanged."""
        return cfs

    def to_dict(self):
        return {'name': 'identity'}

    def from_dict(self, params: dict):
        self.name = params["name"]
        return self

In [None]:
# Ordinal transformation: encode/decode round-trip and dict serialization.
xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
encoder = OrdinalTransformation().fit(xs)
transformed_xs = encoder.transform(xs)
assert np.all(encoder.inverse_transform(transformed_xs) == xs)

# Test from_dict and to_dict
encoder_1 = OrdinalTransformation().from_dict(encoder.to_dict())
assert np.allclose(encoder.transform(xs), encoder_1.transform(xs))

# Identity transformation: data passes through unchanged.
xs = np.random.randn(100, 1)
scaler = IdentityTransformation()
transformed_xs = scaler.fit_transform(xs)
assert np.all(transformed_xs == xs)

# Test from_dict and to_dict
scaler_1 = IdentityTransformation().from_dict(scaler.to_dict())
assert np.allclose(scaler.transform(xs), scaler_1.transform(xs))

In [None]:
#| export
# Registry mapping a transformation name to the `Transformation` subclass that
# implements it. Each key equals the `name` that subclass passes to
# `Transformation.__init__`, so `Feature` can rebuild a transformation both from a
# string key and from the `{'name': ...}` dict returned by `to_dict`.
PREPROCESSING_TRANSFORMATIONS = {
    'ohe': OneHotTransformation,
    'minmax': MinMaxTransformation,
    'ordinal': OrdinalTransformation,
    'identity': IdentityTransformation,
}

## Feature

In [None]:
#| export
class Feature:
    """A single named feature: its raw `data` plus the `Transformation` that encodes it."""

    def __init__(
        self,
        name: str,
        data: np.ndarray,
        transformation: str | Transformation | dict,
        transformed_data = None,
        is_immutable: bool = False,
    ):
        """Create a feature.

        Args:
            name: Name of the feature.
            data: Raw feature values (the examples in this notebook use `(B, 1)` arrays).
            transformation: A key of `PREPROCESSING_TRANSFORMATIONS` (e.g. `'minmax'`,
                `'ohe'`), a `Transformation` instance used as-is, or a dict produced by
                `Transformation.to_dict` whose `'name'` is a registered key.
            transformed_data: Precomputed transformed data. If None, it is computed
                (fitting the transformation on `data`) on first access and cached.
            is_immutable: If True, `apply_constraints` returns `xs` unchanged.

        Raises:
            KeyError: If `transformation` is a string that is not a registered name.
            ValueError: If `transformation` is a dict whose `'name'` is not registered,
                or is of any other unsupported type.
        """
        self.name = name
        self.data = data
        if isinstance(transformation, str):
            self.transformation = PREPROCESSING_TRANSFORMATIONS[transformation]()
        elif isinstance(transformation, Transformation):
            self.transformation = transformation
        elif isinstance(transformation, dict):
            # TODO: only supported transformation can be used for serialization
            t_name = transformation['name']
            if t_name not in PREPROCESSING_TRANSFORMATIONS.keys():
                raise ValueError("Only supported transformation can be inited from dict. "
                                 f"Got {t_name}, but should be one of {PREPROCESSING_TRANSFORMATIONS.keys()}.")
            self.transformation = PREPROCESSING_TRANSFORMATIONS[t_name]().from_dict(transformation)
        else:
            raise ValueError(f"Unknown transformer {transformation}")
        self._transformed_data = transformed_data
        self.is_immutable = is_immutable

    @property
    def transformed_data(self):
        """Transformed `data`, computed (fitting the transformation) on first access.

        The result is cached so repeated access (e.g. by `__repr__`, `to_dict` or
        `FeaturesList`) does not refit the transformation each time. Reassigning
        `self.data` afterwards does not invalidate this cache.
        """
        if self._transformed_data is None:
            self._transformed_data = self.fit_transform(self.data)
        return self._transformed_data

    @classmethod
    def from_dict(cls, d):
        """Rebuild a `Feature` from the dict returned by `to_dict`."""
        return cls(**d)

    def to_dict(self):
        """Serialize name, raw and transformed data, transformation state and mutability."""
        return {
            'name': self.name,
            'data': self.data,
            'transformed_data': self.transformed_data,
            'transformation': self.transformation.to_dict(),
            'is_immutable': self.is_immutable,
        }

    def __repr__(self):
        return f"Feature(" \
               f"name={self.name}, \ndata={self.data}, \n" \
               f"transformed_data={self.transformed_data}, \n" \
               f"transformer={self.transformation}, \n" \
               f"is_immutable={self.is_immutable})"

    __str__ = __repr__

    def __getitem__(self, idx):
        """Return the raw and transformed rows selected by `idx`."""
        return {
            'data': self.data[idx],
            'transformed_data': self.transformed_data[idx],
        }

    # Backward-compatible alias for the original misspelled method name.
    __get_item__ = __getitem__

    def fit(self):
        """Fit the transformation on `self.data`; returns `self`."""
        self.transformation.fit(self.data)
        return self

    def transform(self, xs):
        return self.transformation.transform(xs)

    def fit_transform(self, xs):
        return self.transformation.fit_transform(xs)

    def inverse_transform(self, xs):
        return self.transformation.inverse_transform(xs)

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Keep `xs` for immutable features; otherwise delegate to the transformation.

        Both `lax.cond` branches must produce the same shape/dtype, so `xs` and the
        constrained `cfs` are presumably expected to match in shape.
        """
        return jax.lax.cond(
            self.is_immutable,
            true_fun=lambda xs: xs,
            false_fun=lambda _: self.transformation.apply_constraints(xs, cfs, hard),
            operand=xs,
        )

In [None]:
# Continuous feature scaled with 'minmax': transformed values lie in [0, 1] and invert back.
feat_cont = Feature(
    name='continuous',
    data=np.random.randn(100, 1),
    transformation='minmax',
    is_immutable=False,
)
assert feat_cont.transformed_data.shape == (100, 1)
assert feat_cont.transformed_data.min() >= 0
assert feat_cont.transformed_data.max() <= 1
assert jnp.allclose(
    feat_cont.inverse_transform(feat_cont.transformed_data), feat_cont.data)

# Categorical feature with 3 categories encoded by 'ohe': 3 output columns.
feat_cat = Feature(
    name='category',
    data=np.random.choice(['a', 'b', 'c'], size=(100, 1)),
    transformation='ohe',
    is_immutable=False,
)
assert feat_cat.transformed_data.shape == (100, 3)
assert np.all(feat_cat.inverse_transform(feat_cat.transformed_data) == feat_cat.data)

# TODO: Need test cases for from_dict and to_dict

In [None]:
#| export
class FeaturesList:
    """An ordered collection of `Feature`s whose transformed data is concatenated column-wise."""

    def __init__(
        self,
        features: list[Feature] | FeaturesList,
        *args, **kwargs
    ):
        """Build from a list of `Feature`s, a single `Feature`, or another `FeaturesList`.

        Extra positional and keyword arguments are accepted and ignored.

        Raises:
            ValueError: If `features` is a list containing a non-`Feature` element,
                or is of any other unsupported type.
        """
        if isinstance(features, FeaturesList):
            # Copying from another list forces its transformed data to be computed.
            self._features = features.features
            self._feature_indices = features.feature_indices
            self._transformed_data = features.transformed_data
        elif isinstance(features, Feature):
            self._features = [features]
            self._feature_indices = []
            self._transformed_data = None
        elif isinstance(features, list):
            # Validate every element (not only the first) so a bad entry fails here
            # rather than later inside `_transform_data`.
            for feat in features:
                if not isinstance(feat, Feature):
                    raise ValueError(f"Invalid features type: {type(feat).__name__}")
            self._features = features
            self._feature_indices = []
            self._transformed_data = None
        else:
            raise ValueError(f"Unknown features type {type(features)}")

    @property
    def features(self):
        """The underlying list of `Feature`s."""
        return self._features

    @property
    def feature_indices(self):
        """`(start, end)` column span of each feature inside `transformed_data` (computed lazily)."""
        if self._feature_indices is None or len(self._feature_indices) == 0:
            self._transform_data()
        return self._feature_indices

    @property
    def transformed_data(self):
        """All features' transformed data concatenated along the last axis (computed lazily)."""
        if self._transformed_data is None:
            self._transform_data()
        return self._transformed_data

    def to_dict(self):
        """Serialize as `{'features': [feature.to_dict(), ...]}`."""
        return {
            'features': [feat.to_dict() for feat in self.features],
        }

    @classmethod
    def from_dict(cls, d):
        """Rebuild a `FeaturesList` from the dict returned by `to_dict`."""
        return cls(
            features=[Feature.from_dict(feat) for feat in d['features']],
        )

    def save(self, saved_dir):
        """Create `saved_dir` if needed and write `to_dict()` into it with `save_pytree`."""
        os.makedirs(saved_dir, exist_ok=True)
        save_pytree(self.to_dict(), saved_dir)

    @classmethod
    def load_from_path(cls, saved_dir):
        """Rebuild a `FeaturesList` from a directory written by `save`."""
        return cls.from_dict(load_pytree(saved_dir))

    def _transform_data(self):
        """Concatenate each feature's transformed data and record its column span.

        Results are built in locals and assigned only on success, so a failure in
        one feature does not leave partially-filled cached state behind.
        """
        indices, blocks = [], []
        start = 0
        for feat in self.features:
            block = feat.transformed_data
            end = start + block.shape[-1]
            indices.append((start, end))
            blocks.append(block)
            start = end
        self._transformed_data = jnp.concatenate(blocks, axis=-1)
        self._feature_indices = indices

    def transform(self, data):
        raise NotImplementedError

    def inverse_transform(self, xs):
        raise NotImplementedError

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Apply each feature's constraints to its own column span and re-concatenate.

        `xs` and `cfs` are sliced as `[:, start:end]`, so both are expected to be
        2-D with the same column layout as `transformed_data`.
        """
        constrained_cfs = [
            feat.apply_constraints(xs[:, start:end], cfs[:, start:end], hard)
            for (start, end), feat in zip(self.feature_indices, self.features)
        ]
        return jnp.concatenate(constrained_cfs, axis=-1)

In [None]:
import tempfile

df = pd.read_csv('assets/adult/data/data.csv')
cont_feats = ['age', 'hours_per_week']
cat_feats = ["workclass", "education", "marital_status","occupation", "race", "gender"]

# Continuous features are min-max scaled; categorical features are one-hot encoded.
feats_list = FeaturesList([
    Feature(name, df[name].to_numpy().reshape(-1, 1), 'minmax') for name in cont_feats
] + [
    Feature(name, df[name].to_numpy().reshape(-1, 1), 'ohe') for name in cat_feats
])
assert feats_list.transformed_data.shape == (32561, 29)
cfs = np.random.randn(10, 29)
assert feats_list.apply_constraints(feats_list.transformed_data[:10, :], cfs, hard=False).shape == (10, 29)
assert feats_list.apply_constraints(feats_list.transformed_data[:10, :], cfs, hard=True).shape == (10, 29)

# Test save and load.
# Use a private temporary directory so no existing `tmp/` folder in the working
# directory is touched, and the files are removed even if loading fails.
with tempfile.TemporaryDirectory() as tmp_dir:
    saved_dir = os.path.join(tmp_dir, 'data_module', '')
    feats_list.save(saved_dir)
    feats_list_1 = FeaturesList.load_from_path(saved_dir)
    # The reloaded list must reproduce the original transformed data.
    assert np.allclose(feats_list_1.transformed_data, feats_list.transformed_data)

In [None]:
from sklearn.preprocessing import OneHotEncoder as SkOneHotEncoder, MinMaxScaler as SkMinMaxScaler

In [None]:
# Cross-check each feature in `feats_list` against scikit-learn's reference encoders.
# `sparse_output=False` returns a dense array (parameter name from scikit-learn >= 1.2).
sk_ohe = SkOneHotEncoder(sparse_output=False)
sk_minmax = SkMinMaxScaler()

for feat in feats_list.features:
    # Continuous features were built with 'minmax'; the rest with 'ohe'.
    if feat.name in cont_feats:
        assert np.allclose(
            sk_minmax.fit_transform(feat.data),
            feat.transformed_data,
        ), f"Failed at {feat.name}. "
    else:
        assert np.allclose(
            sk_ohe.fit_transform(feat.data),
            feat.transformed_data,
        ), f"Failed at {feat.name}"