# Data Transformations and Features

In [None]:
#| default_exp data_utils

In [None]:
#| export
from __future__ import annotations
from fastcore.test import *
import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp
import einops
import os, sys, json, pickle
import shutil
from relax.utils import *
import chex

In [None]:
#| hide
import sklearn.preprocessing as skp
from copy import deepcopy

## Data Preprocessors

`DataPreprocessor` transforms *individual* features into numerical representations for the machine learning and recourse generation workflows. 
It can be considered a drop-in, jax-friendly replacement for the 
[sklearn.preprocessing](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.preprocessing) module.
The supported preprocessing methods include `MinMaxScaler` and `OneHotEncoder`. 

:::{.callout-important}

Unlike the [sklearn.preprocessing](https://scikit-learn.org/stable/modules/classes.html#module-sklearn.preprocessing) module,
every `DataPreprocessor` works only with a single feature (e.g., Dim: `(B, 1)`). 

:::

In [None]:
#| export
def _check_xs(xs: np.ndarray, name: str):
    if xs.ndim > 2 or (xs.ndim == 2 and xs.shape[1] != 1):
        raise ValueError(f"`{name}` only supports array with a single feature, but got shape={xs.shape}.")
    
        
class DataPreprocessor:
    """Interface shared by all single-feature preprocessors."""

    def __init__(
        self, 
        name: str = None # The name of the preprocessor. If None, the class name will be used.
    ):
        """Base class for data preprocessors."""
        # A missing (or empty) name falls back to the concrete class name.
        self.name = name if name else type(self).__name__
    
    def fit(self, xs, y=None):
        """Learn the preprocessor state from `xs` (and optionally `y`)."""
        raise NotImplementedError
    
    def transform(self, xs):
        """Map `xs` to its transformed representation."""
        raise NotImplementedError
    
    def fit_transform(self, xs, y=None):
        """Call `fit(xs, y)` and then return `transform(xs)`."""
        self.fit(xs, y)
        transformed = self.transform(xs)
        return transformed
    
    def inverse_transform(self, xs):
        """Map transformed `xs` back to the original representation."""
        raise NotImplementedError
    
    def to_dict(self) -> dict:
        """Export the state of the preprocessor as a dictionary."""
        raise NotImplementedError
    
    def from_dict(self, params: dict):
        """Restore the state of the preprocessor from a dictionary."""
        raise NotImplementedError
        
    # Names of the methods that make up the preprocessor interface.
    __ALL__ = [
        "fit",
        "transform",
        "fit_transform",
        "inverse_transform",
        "to_dict",
        "from_dict",
    ]

In [None]:
#| export
class MinMaxScaler(DataPreprocessor): 
    """Scale a single feature with the minimum and maximum seen during `fit`."""

    def __init__(self):
        super().__init__(name="minmax")
        
    def _scale(self):
        """Return `max_ - min_`, with a zero range replaced by 1."""
        span = np.asarray(self.max_ - self.min_)
        # A constant feature has a zero range; dividing by it would produce
        # NaN/inf. Dividing by 1 instead maps such a feature to 0, and
        # `inverse_transform` maps it back to `min_`.
        return np.where(span == 0, np.ones_like(span), span)

    def fit(self, xs, y=None):
        """Record the min and max of `xs` along axis 0 and return `self`.

        Raises `ValueError` (via `_check_xs`) if `xs` has more than one feature.
        """
        _check_xs(xs, name="MinMaxScaler")
        self.min_ = xs.min(axis=0)
        self.max_ = xs.max(axis=0)
        return self
    
    def transform(self, xs):
        """Return `(xs - min_) / (max_ - min_)`."""
        return (xs - self.min_) / self._scale()
    
    def inverse_transform(self, xs):
        """Undo `transform`: return `xs * (max_ - min_) + min_`."""
        return xs * self._scale() + self.min_
    
    def from_dict(self, params: dict):
        """Load `min_` and `max_` from `params` and return `self`."""
        self.min_ = params["min_"]
        self.max_ = params["max_"]
        return self
    
    def to_dict(self) -> dict:
        """Return the fitted `min_` and `max_` as a dictionary."""
        return {"min_": self.min_, "max_": self.max_}

In [None]:
# 1D input: the output keeps the input shape and the transform round-trips.
xs = np.random.randn(100, )
scaler = MinMaxScaler()
transformed_xs = scaler.fit_transform(xs)
assert transformed_xs.shape == (100, )
assert np.allclose(xs, scaler.inverse_transform(transformed_xs))
# Test correctness against the scikit-learn implementation
assert np.allclose(
    transformed_xs, 
    skp.MinMaxScaler().fit_transform(xs.reshape(100, 1)).reshape(100,)
)
# Also works with a 2D array that has a single column
xs = xs.reshape(100, 1)
scaler = MinMaxScaler()
transformed_xs = scaler.fit_transform(xs)
assert np.allclose(xs, scaler.inverse_transform(transformed_xs))
assert np.allclose(
    transformed_xs, 
    skp.MinMaxScaler().fit_transform(xs.reshape(100, 1))
)

`MinMaxScaler` only supports scaling a single feature.

In [None]:
# Reuses `xs` from the previous cell, reshaped into two columns;
# fitting must fail because only a single feature is supported.
xs = xs.reshape(50, 2)
scaler = MinMaxScaler()
test_fail(lambda: scaler.fit_transform(xs), 
          contains="`MinMaxScaler` only supports array with a single feature")

Convert to a dictionary (or the pytree representations).

In [None]:
# A scaler restored from `to_dict()` must transform exactly like the original.
xs = xs.reshape(-1, 1)
scaler = MinMaxScaler().fit(xs)
scaler_1 = MinMaxScaler().from_dict(scaler.to_dict())
assert np.allclose(scaler.transform(xs), scaler_1.transform(xs))

In [None]:
#| exporti
def _unique(xs):
    if xs.dtype == object:
        # Note: np.unique does not work with object dtype
        # We will enforce xs to be string type
        # It assumes that xs is a list of strings, and might not work
        # for other cases (e.g., list of string and numbers)
        return np.unique(xs.astype(str))
    return np.unique(xs)

In [None]:
#| export
class EncoderPreprocessor(DataPreprocessor):
    """Encode categorical features as an integer array."""

    def _fit(self, xs, y=None):
        """Validate `xs` and store its sorted unique values in `categories_`."""
        _check_xs(xs, name="EncoderPreprocessor")
        self.categories_ = _unique(xs)

    def _transform(self, xs):
        """Return the position of each value of `xs` within `categories_`."""
        # Object arrays are compared as strings, matching how `_unique`
        # builds `categories_` for object input.
        values = xs.astype(str) if xs.dtype == object else xs
        return np.searchsorted(self.categories_, values)
    
    def _inverse_transform(self, xs):
        """Look up the category stored at each ordinal index in `xs`."""
        looked_up = self.categories_[xs.T]
        return looked_up.T
    
    def from_dict(self, params: dict):
        """Load `categories_` from `params` and return `self`."""
        self.categories_ = params["categories_"]
        return self
    
    def to_dict(self) -> dict:
        """Return the fitted `categories_` as a dictionary."""
        return dict(categories_=self.categories_)

In [None]:
#| export
class OrdinalPreprocessor(EncoderPreprocessor):
    """Ordinal encoder for a single feature."""
    
    def fit(self, xs, y=None):
        """Learn the categories of `xs` and return `self`."""
        self._fit(xs, y)
        return self
    
    def transform(self, xs):
        """Encode a 2D, single-column `xs` as ordinal indices.

        Raises `ValueError` for 1D input.
        """
        if xs.ndim != 1:
            return self._transform(xs)
        raise ValueError(f"OrdinalPreprocessor only supports 2D array with a single feature, "
                         f"but got shape={xs.shape}.")
    
    def inverse_transform(self, xs):
        """Map ordinal indices back to the original categories."""
        return self._inverse_transform(xs)

In [None]:
# Round trip on string categories.
xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
enc = OrdinalPreprocessor().fit(xs)
transformed_xs = enc.transform(xs)
assert np.all(enc.inverse_transform(transformed_xs) == xs)
# Test from_dict and to_dict
enc_1 = OrdinalPreprocessor().from_dict(enc.to_dict())
assert np.all(enc.transform(xs) == enc_1.transform(xs))

# Object array containing NaN: values are compared as strings, so NaN becomes 'nan'.
xs = np.array(['a', 'b', 'c', np.nan, 'a', 'b', 'c', np.nan], dtype=object).reshape(-1, 1)
enc = OrdinalPreprocessor().fit(xs)
# Check categories_
assert np.array_equiv(enc.categories_, np.array(['a', 'b', 'c', np.nan], dtype=str)) 
transformed_xs = enc.transform(xs)
assert transformed_xs.shape == (8, 1)
inverse_transformed_xs = enc.inverse_transform(transformed_xs)
assert np.all(inverse_transformed_xs == xs.astype(str))
# Test from_dict and to_dict
enc_1 = OrdinalPreprocessor().from_dict(enc.to_dict())
assert np.all(enc.transform(xs) == enc_1.transform(xs))
assert np.array_equal(enc.categories_, enc_1.categories_)

# 1D input is rejected by `transform`.
xs = np.random.choice(['a', 'b', 'c'], size=(100, ))
test_fail(lambda: OrdinalPreprocessor().fit_transform(xs), 
    contains="OrdinalPreprocessor only supports 2D array with a single feature")


In [None]:
#| export
class OneHotEncoder(EncoderPreprocessor):
    """One-hot encoder for a single categorical feature."""
    
    def fit(self, xs, y=None):
        """Learn the categories of `xs` and return `self`."""
        self._fit(xs, y)
        return self

    def transform(self, xs):
        """One-hot encode a 2D, single-column `xs`.

        Raises `ValueError` for 1D input.
        """
        if xs.ndim == 1:
            raise ValueError(f"OneHotEncoder only supports 2D array with a single feature, "
                             f"but got shape={xs.shape}.")
        ordinal = self._transform(xs)
        num_classes = len(self.categories_)
        encoded = jax.nn.one_hot(ordinal, num_classes)
        # Merge the feature axis and the one-hot axis into a single axis.
        return einops.rearrange(encoded, 'n k d -> n (k d)')

    def inverse_transform(self, xs):
        """Return the category with the largest entry in each row, as a column."""
        ordinal = np.argmax(xs, axis=-1)
        labels = self._inverse_transform(ordinal)
        return labels.reshape(-1, 1)

In [None]:
# Round trip on string categories.
xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
enc = OneHotEncoder().fit(xs)
transformed_xs = enc.transform(xs)
assert np.all(enc.inverse_transform(transformed_xs) == xs)
# Test from_dict and to_dict
enc_1 = OneHotEncoder().from_dict(enc.to_dict())
assert np.all(enc.transform(xs) == enc_1.transform(xs))

# Object array containing NaN: the inverse transform returns the string form.
xs = np.array(['a', 'b', 'c', np.nan, 'a', 'b', 'c', np.nan], dtype=object).reshape(-1, 1)
enc = OneHotEncoder().fit(xs)
transformed_xs = enc.transform(xs)
assert np.all(enc.inverse_transform(transformed_xs) == xs.astype(str))
# Test from_dict and to_dict (chained and in-place usage of `from_dict`)
enc_1 = OneHotEncoder().from_dict(enc.to_dict())
enc_2 = OneHotEncoder()
enc_2.from_dict(enc_1.to_dict())
assert np.all(enc.transform(xs) == enc_1.transform(xs))
assert np.all(enc.transform(xs) == enc_2.transform(xs))

# 1D input is rejected by `transform`.
xs = np.random.choice(['a', 'b', 'c'], size=(100, ))
test_fail(lambda: OneHotEncoder().fit_transform(xs), 
    contains="OneHotEncoder only supports 2D array with a single feature")


## Transformation


In [None]:
#| export
class Transformation:
    """Pair a named transformer with constraint and regularization hooks."""

    def __init__(self, name, transformer):
        self.name = name
        self.transformer = transformer

    @property
    def is_categorical(self) -> bool:
        """Whether the wrapped transformer is an `EncoderPreprocessor`."""
        return isinstance(self.transformer, EncoderPreprocessor)

    def fit(self, xs, y=None):
        """Fit the wrapped transformer on `xs` and return `self` (`y` is not forwarded)."""
        self.transformer.fit(xs)
        return self
    
    def transform(self, xs):
        """Delegate to the wrapped transformer's `transform`."""
        return self.transformer.transform(xs)

    def fit_transform(self, xs, y=None):
        """Delegate to the wrapped transformer's `fit_transform` (`y` is not forwarded)."""
        return self.transformer.fit_transform(xs)
    
    def inverse_transform(self, xs):
        """Delegate to the wrapped transformer's `inverse_transform`."""
        return self.transformer.inverse_transform(xs)

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Return `cfs` unchanged; subclasses override this hook."""
        return cfs
    
    def compute_reg_loss(self, xs, cfs, hard: bool = False):
        """Return a zero regularization loss; subclasses override this hook."""
        return 0.0
    
    def from_dict(self, params: dict):
        """Restore the name and the transformer state from `params`; return `self`."""
        self.name = params["name"]
        self.transformer.from_dict(params["transformer"])
        return self
    
    def to_dict(self) -> dict:
        """Export the name and the transformer state as a dictionary."""
        return dict(name=self.name, transformer=self.transformer.to_dict())

In [None]:
#| export
class MinMaxTransformation(Transformation):
    """Transformation named "minmax" that wraps a `MinMaxScaler`."""

    def __init__(self):
        super().__init__(name="minmax", transformer=MinMaxScaler())

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Clip `cfs` to the [0, 1] interval."""
        lower, upper = 0., 1.
        return jnp.clip(cfs, lower, upper)

In [None]:
# A min-max transformation is not categorical.
xs = np.random.randn(100, 1)
scaler = MinMaxTransformation()
transformed_xs = scaler.fit_transform(xs)
assert scaler.is_categorical is False

# Constraints clip arbitrary values into [0, 1].
cfs = np.random.randn(100, 1)
cf_constrained = scaler.apply_constraints(xs, cfs)
assert np.all(cf_constrained >= 0) and np.all(cf_constrained <= 1)

# Test from_dict and to_dict
scaler_1 = MinMaxTransformation().from_dict(scaler.to_dict())
assert np.allclose(scaler.transform(xs), scaler_1.transform(xs))

In [None]:
#| export
class OneHotTransformation(Transformation):
    """Transformation named "ohe" that wraps a `OneHotEncoder`."""

    def __init__(self):
        super().__init__(name="ohe", transformer=OneHotEncoder())

    @property
    def num_categories(self) -> int:
        """Number of categories learned by the wrapped encoder."""
        categories = self.transformer.categories_
        return len(categories)

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Project `cfs` onto one-hot vectors (`hard`) or apply a softmax (not `hard`)."""

        def _to_one_hot(values):
            return jax.nn.one_hot(jnp.argmax(values, axis=-1), self.num_categories)

        def _to_softmax(values):
            return jax.nn.softmax(values, axis=-1)

        return jax.lax.cond(
            hard,
            true_fun=_to_one_hot,
            false_fun=_to_softmax,
            operand=cfs,
        )
    
    def compute_reg_loss(self, xs, cfs, hard: bool = False):
        """Mean squared deviation of each row sum of `cfs` from 1."""
        row_sums = cfs.sum(axis=-1, keepdims=True)
        squared_gap = (row_sums - 1.0) ** 2
        return squared_gap.mean()

In [None]:
xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
ohe_t = OneHotTransformation().fit(xs)
transformed_xs = ohe_t.transform(xs)

cfs = jax.random.uniform(jax.random.PRNGKey(0), shape=(100, 3))
# Test hard=False which applies the softmax function.
soft = ohe_t.apply_constraints(transformed_xs, cfs, hard=False)
assert jnp.allclose(soft.sum(axis=-1), 1)
assert jnp.all(soft >= 0)
assert jnp.all(soft <= 1)
# Rows that sum to 1 incur no regularization loss.
assert jnp.allclose(jnp.zeros((len(cfs), 1)), ohe_t.compute_reg_loss(xs, soft, hard=False))

# Test hard=True which enforces the one-hot constraint.
hard = ohe_t.apply_constraints(transformed_xs, cfs, hard=True)
assert np.all([1 in x for x in hard])
assert np.all([0 in x for x in hard])
assert jnp.allclose(hard.sum(axis=-1), 1)
assert jnp.allclose(jnp.zeros((len(cfs), 1)), ohe_t.compute_reg_loss(xs, hard, hard=False))

# Test compute_reg_loss returns a scalar
assert jnp.ndim(ohe_t.compute_reg_loss(xs, soft, hard=False)) == 0

# Test from_dict and to_dict
ohe_t_1 = OneHotTransformation().from_dict(ohe_t.to_dict())
assert np.allclose(ohe_t.transform(xs), ohe_t_1.transform(xs))

In [None]:
#| export
class OrdinalTransformation(Transformation):
    """Transformation named "ordinal" that wraps an `OrdinalPreprocessor`."""

    def __init__(self):
        super().__init__(name="ordinal", transformer=OrdinalPreprocessor())

    @property
    def num_categories(self) -> int:
        """Number of categories learned by the wrapped encoder."""
        categories = self.transformer.categories_
        return len(categories)
    
class IdentityTransformation(Transformation):
    """Transformation named "identity" that leaves the data unchanged.

    It wraps no transformer (`self.transformer` is `None`), so every method
    that the base class delegates to the transformer is overridden here.
    """

    def __init__(self):
        super().__init__("identity", None)

    def fit(self, xs, y=None):
        """Nothing to fit; return `self`."""
        return self
    
    def transform(self, xs):
        """Return `xs` unchanged."""
        return xs
    
    def fit_transform(self, xs, y=None):
        """Return `xs` unchanged."""
        return xs

    def inverse_transform(self, xs):
        """Return `xs` unchanged.

        Without this override the inherited method would call
        `None.inverse_transform` and raise `AttributeError`.
        """
        return xs

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Return `cfs` unchanged."""
        return cfs
    
    def to_dict(self):
        """Export only the name, since there is no transformer state."""
        return {'name': 'identity'}
    
    def from_dict(self, params: dict):
        """Restore the name from `params` and return `self`."""
        self.name = params["name"]
        return self

In [None]:
# Ordinal transformation round-trips string categories.
xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
encoder = OrdinalTransformation().fit(xs)
transformed_xs = encoder.transform(xs)
assert np.all(encoder.inverse_transform(transformed_xs) == xs)

# Test from_dict and to_dict
encoder_1 = OrdinalTransformation().from_dict(encoder.to_dict())
assert np.allclose(encoder.transform(xs), encoder_1.transform(xs))

# Identity transformation returns its input unchanged.
xs = np.random.randn(100, 1)
scaler = IdentityTransformation()
transformed_xs = scaler.fit_transform(xs)
assert np.all(transformed_xs == xs)

# Test from_dict and to_dict
scaler_1 = IdentityTransformation().from_dict(scaler.to_dict())
assert np.allclose(scaler.transform(xs), scaler_1.transform(xs))

## Feature

In [None]:
#| export
# Registry that maps a transformation name to the class that implements it.
PREPROCESSING_TRANSFORMATIONS = dict(
    ohe=OneHotTransformation,
    minmax=MinMaxTransformation,
    ordinal=OrdinalTransformation,
    identity=IdentityTransformation,
)

In [None]:
#| export
class Feature:
    """The feature class which represents a column in the dataset."""
    
    def __init__(
        self,
        name: str,
        data: np.ndarray,
        transformation: str | Transformation | dict,
        transformed_data = None,
        is_immutable: bool = False,
        is_categorical: bool = None,
    ):
        """Create a feature.

        `transformation` may be a registered name, a dict produced by
        `Transformation.to_dict`, or a `Transformation` instance. When
        `transformed_data` is None it is computed on first access. When
        `is_categorical` is None it is taken from the transformation.
        """
        self.name = name
        self._data = data
        self._transformation = self._dispatch_transformation(transformation)
        self._transformed_data = transformed_data
        self._is_immutable = is_immutable
        self._is_categorical = self._init_is_categorical(is_categorical)

    def with_transformed_data(
        self,
        transformed_data: np.ndarray,
        **kwargs        
    ) -> Feature:
        """Create a new feature with transformed data.

        The original data of the new feature is obtained by inverse
        transforming `transformed_data`; the transformation object is shared
        with this feature. `kwargs` is accepted but currently unused.
        """
        
        org_data = self.inverse_transform(transformed_data)
        return Feature(
            name=self.name,
            data=org_data,
            transformation=self.transformation,
            transformed_data=transformed_data,
            is_immutable=self.is_immutable,
            is_categorical=self.is_categorical,
        )
    
    @property
    def data(self) -> jax.Array:
        """The original (untransformed) data."""
        return self._data
    
    @property
    def is_immutable(self) -> bool:
        """Whether `apply_constraints` keeps this feature at its original value."""
        return self._is_immutable
    
    @property
    def transformation(self) -> Transformation:
        """The transformation applied to this feature."""
        return self._transformation
    
    @property
    def is_categorical(self) -> bool:
        """Whether this feature is categorical."""
        return self._is_categorical

    @property
    def transformed_data(self) -> jax.Array:
        """The transformed data, computed on first access and then cached."""
        if self._transformed_data is None:
            # Fit and transform once instead of refitting on every access;
            # `set_transformation` resets the cache.
            self._transformed_data = self.fit_transform(self.data)
        return self._transformed_data

    @classmethod
    def from_dict(cls, d):
        """Build a feature from a dict produced by `to_dict`."""
        return cls(**d)
    
    def to_dict(self):
        """Export the feature (data, transformed data and transformation state) as a dict."""
        return {
            'name': self.name,
            'data': self.data,
            'transformed_data': self.transformed_data,
            'transformation': self.transformation.to_dict(),
            'is_immutable': self.is_immutable,
            'is_categorical': self.is_categorical,
        }
    
    def __repr__(self):
        return "Feature(" + ",\n".join([
            f"{k}={v}" for k, v in self.to_dict().items()]) + ")"
    
    __str__ = __repr__

    def __getitem__(self, idx):
        """Return the dict form of this feature restricted to the rows `idx`.

        `dict.update` returns None, so the dict is updated first and then
        returned explicitly.
        """
        sliced = self.to_dict()
        sliced.update({
            'data': self.data[idx],
            'transformed_data': self.transformed_data[idx],
        })
        return sliced

    # Backward-compatible alias for the original (misspelled) method name.
    __get_item__ = __getitem__
    
    def _dispatch_transformation(self, transformation: str | dict | Transformation):
        """Resolve `transformation` (name, dict, or instance) into a `Transformation`.

        Raises `ValueError` for an unknown name or an unsupported type.
        """
        T = PREPROCESSING_TRANSFORMATIONS
        if isinstance(transformation, str):
            if transformation not in T.keys():
                raise ValueError(f"Unknown transformation: {transformation}")
            return T[transformation]()
        elif isinstance(transformation, dict):
            # TODO: only supported transformation can be used for serialization
            t_name = transformation['name']
            if t_name not in T.keys():
                raise ValueError("Only supported transformation can be inited from dict. "
                                 f"Got {t_name}, but should be one of {T.keys()}.")
            return T[t_name]().from_dict(transformation)
        elif isinstance(transformation, Transformation):
            return transformation
        else:
            raise ValueError(f"Unknown transformation: {transformation}")
        
    def _init_is_categorical(self, is_categorical: bool = None):
        """Return `is_categorical`, falling back to the transformation when it is None."""
        if is_categorical is None:
            return self.transformation.is_categorical
        else:
            # An explicit value must agree with a previously set one.
            if hasattr(self, '_is_categorical'):
                assert self.is_categorical == is_categorical
            return is_categorical

    def set_transformation(self, transformation: str | dict | Transformation) -> Feature:
        """Replace the transformation, refresh `is_categorical`, and drop the cached transformed data."""
        self._transformation = self._dispatch_transformation(transformation)
        self._is_categorical = self._init_is_categorical()
        self._transformed_data = None # Reset transformed data
        return self
    
    def fit(self):
        """Fit the transformation on this feature's data and return `self`."""
        self.transformation.fit(self.data)
        return self
    
    def transform(self, xs):
        """Transform `xs` with this feature's transformation."""
        return self.transformation.transform(xs)

    def fit_transform(self, xs):
        """Fit the transformation on `xs` and return the transformed `xs`."""
        return self.transformation.fit_transform(xs)
    
    def inverse_transform(self, xs):
        """Map transformed `xs` back to the original representation."""
        return self.transformation.inverse_transform(xs)
    
    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Constrain `cfs`: immutable features return `xs` broadcast to the shape of `cfs`."""
        return jax.lax.cond(
            self.is_immutable,
            true_fun=lambda xs: jnp.broadcast_to(xs, cfs.shape),
            false_fun=lambda _: self.transformation.apply_constraints(xs, cfs, hard),
            operand=xs,
        )
    
    def compute_reg_loss(self, xs, cfs, hard: bool = False):
        """Regularization loss of `cfs`, as defined by the transformation."""
        return self.transformation.compute_reg_loss(xs, cfs, hard)

In [None]:
# Continuous feature with min-max scaling.
feat_cont = Feature(
    name='continuous',
    data=np.random.randn(100, 1),
    transformation='minmax',
    is_immutable=False,
)
assert feat_cont.transformed_data.shape == (100, 1)
assert feat_cont.transformed_data.min() >= 0
assert feat_cont.transformed_data.max() <= 1
assert jnp.allclose(
    feat_cont.inverse_transform(feat_cont.transformed_data), feat_cont.data)
assert feat_cont.is_categorical is False

# `with_transformed_data` returns a new feature that shares the transformation state.
feat_cont_1 = feat_cont.with_transformed_data(feat_cont.transformed_data)
assert isinstance(feat_cont_1, Feature)
assert feat_cont_1 is not feat_cont
assert np.allclose(
    feat_cont_1.data, feat_cont.data
)
assert feat_cont.transformation.to_dict() == feat_cont_1.transformation.to_dict()

# Categorical feature with one-hot encoding (three categories -> three columns).
feat_cat = Feature(
    name='category',
    data=np.random.choice(['a', 'b', 'c'], size=(100, 1)),
    transformation='ohe',
    is_immutable=False,
)
assert feat_cat.transformed_data.shape == (100, 3)
assert np.all(feat_cat.inverse_transform(feat_cat.transformed_data) == feat_cat.data)
assert feat_cat.is_categorical

# New one-hot rows are decoded back into the original categories.
feat_cat_1 = feat_cat.with_transformed_data(jax.nn.one_hot(jnp.array([0, 1, 2, 0, 1, 2]), 3))
assert feat_cat_1 is not feat_cat
assert np.array_equal(
    feat_cat_1.data, np.array(['a', 'b', 'c', 'a', 'b', 'c']).reshape(-1, 1)
) 

# Test serialization
d = feat_cont.to_dict()
feat_cont_1 = Feature.from_dict(d)
assert feat_cont_1.name == feat_cont.name
assert np.allclose(feat_cont_1.data, feat_cont.data)
assert np.allclose(feat_cont_1.transformed_data, feat_cont.transformed_data)
assert feat_cont_1.is_immutable == feat_cont.is_immutable

In [None]:
# Test set_transformation
feat_cat = Feature(
    name='category',
    data=np.random.choice(['a', 'b', 'c'], size=(100, 1)),
    transformation='ohe',
    is_immutable=False,
)
assert feat_cat.transformation.name == 'ohe'
assert feat_cat.transformed_data.shape == (100, 3)
# Switching to ordinal encoding changes the transformed width from 3 to 1
# while the feature stays categorical and mutable.
feat_cat.set_transformation('ordinal')
assert feat_cat.transformation.name == 'ordinal'
assert feat_cat.is_categorical
assert feat_cat.transformed_data.shape == (100, 1)
assert feat_cat.is_immutable is False

In [None]:
#| export
class FeaturesList:
    """An ordered collection of `Feature`s whose transformed data are concatenated column-wise."""

    def __init__(
        self,
        features: list[Feature] | FeaturesList,
        *args, **kwargs
    ):
        """Wrap a `Feature`, a list of `Feature`s, or another `FeaturesList`.

        Raises `ValueError` for any other input type, or when the first
        element of the list is not a `Feature`.
        """
        if isinstance(features, FeaturesList):
            # Reuse the features and the already computed results of the source list.
            self._features = features.features
            self._feature_indices = features.feature_indices
            self._transformed_data = features.transformed_data
        elif isinstance(features, Feature):
            self._features = [features]
        elif isinstance(features, list):
            if len(features) > 0 and not isinstance(features[0], Feature):
                raise ValueError(f"Invalid features type: {type(features[0]).__name__}")
            self._features = features
        else:
            raise ValueError(f"Unknown features type. Got {type(features).__name__}")
        
        # Cursor used only by direct `next(features_list)` calls.
        self.pose = 0

    # Iterator
    def __len__(self):
        return len(self._features)
    
    def __iter__(self):
        # Return a fresh iterator for every loop. Returning `self` shares one
        # cursor between all loops, so a loop that exits early corrupts the
        # next one and nested loops never terminate.
        return iter(self._features)
    
    def __next__(self):
        # Kept for callers that invoke `next(features_list)` directly.
        if self.pose < len(self):
            feat = self._features[self.pose]
            self.pose += 1
            return feat
        else:
            self.pose = 0
            raise StopIteration
        
    # Indexing
    def __getitem__(self, idx: str | list[str]) -> Feature | list[Feature]:
        """Look up a feature by name, or a list of features by a list of names.

        Raises `ValueError` for an unknown name or an unsupported `idx` type.
        """
        if not hasattr(self, "_feature_name_list_indices"):
            # Built lazily on first lookup: feature name -> position in the list.
            self._feature_name_list_indices = {feat.name: i for i, feat in enumerate(self._features)}
        indices = self._feature_name_list_indices
        if isinstance(idx, str):
            if idx not in indices:
                raise ValueError(f"Invalid feature name: {idx}")
            return self._features[indices[idx]]
        elif isinstance(idx, list):
            return [self[i] for i in idx]
        else:
            raise ValueError(f"Invalid idx type: {type(idx).__name__}")
    
    #############################
    # Properties
    #############################
    @property
    def features(self) -> list[Feature]: # Return [Feature(...), ...]
        return self._features

    @property
    def feature_indices(self) -> list[tuple[int, int]]: # Return [(start, end), ...]
        if not hasattr(self, "_feature_indices") or self._feature_indices is None or len(self._feature_indices) == 0:
            self._transform_data()
        return self._feature_indices
    
    @property
    def features_and_indices(self) -> list[tuple[Feature, tuple[int, int]]]: # Return [(Feature(...), (start, end)), ...]
        return list(zip(self.features, self.feature_indices))
    
    @property
    def feature_name_indices(self) -> dict[str, tuple[int, int, int]]: # Return {feature_name: (feat_idx, start, end), ...}
        if not hasattr(self, "_feature_name_indices") or self._feature_name_indices is None:
            self._transform_data()
        return self._feature_name_indices
    
    @property
    def transformed_data(self) -> jax.Array:
        if not hasattr(self, "_transformed_data") or self._transformed_data is None:
            self._transform_data()
        return self._transformed_data
    
    #############################
    # Methods
    #############################
    def with_transformed_data(
        self,
        transformed_data: np.ndarray,
        **kwargs        
    ) -> FeaturesList:
        """Create a new features list from transformed data.

        Each feature receives its own column slice of `transformed_data`.
        """
        
        return FeaturesList(
            features=[feat.with_transformed_data(transformed_data[:, start:end], **kwargs) 
                      for feat, (start, end) in self.features_and_indices],
        )
    
    def set_transformations(self, feature_names_to_transformation: dict[str, Transformation]) -> FeaturesList:
        """Set the transformations for the features.

        Raises `ValueError` if the argument is not a dict.
        """
        
        if not isinstance(feature_names_to_transformation, dict):
            raise ValueError(f"Invalid feature_names_to_transformation type: "
                             f"{type(feature_names_to_transformation).__name__}."
                             f"Should be dict[str, Transformation]")
        
        for feat, transformation in feature_names_to_transformation.items():
            self[feat].set_transformation(transformation)
        # A new transformation can change the width of a feature, so the
        # column indices must be recomputed together with the transformed data.
        self._transformed_data = None # Reset transformed data
        self._feature_indices = None
        self._feature_name_indices = None
        return self
        
    def _transform_data(self):
        """Compute the concatenated transformed data and the column range of every feature."""
        self._feature_indices = []
        self._feature_name_indices = {}
        self._transformed_data = []
        start, end = 0, 0
        for i, feat in enumerate(self.features):
            transformed_data = feat.transformed_data
            end += transformed_data.shape[-1]
            self._feature_indices.append((start, end))
            self._feature_name_indices[feat.name] = (i, start, end)
            self._transformed_data.append(transformed_data)
            start = end

        self._transformed_data = jnp.concatenate(self._transformed_data, axis=-1)
    
    def transform(self, data: dict[str, jax.Array]):
        """Transform a `{feature_name: values}` dict into one concatenated array.

        Columns follow the order of the features in this list. Raises
        `ValueError` if `data` is not a dict or lacks one of the features.
        """
        if not isinstance(data, dict):
            raise ValueError(f"Invalid data type: {type(data).__name__}, should be dict[str, jax.Array]")

        transformed_data = [None] * len(self.features)
        for feat_name, val in data.items():
            feat_idx, _, _ = self.feature_name_indices[feat_name]
            feat = self.features[feat_idx]
            transformed_data[feat_idx] = feat.transform(val)
        # Report missing features explicitly instead of failing inside `concatenate`.
        missing = [feat.name for feat, t in zip(self.features, transformed_data) if t is None]
        if missing:
            raise ValueError(f"Missing data for features: {missing}")
        return jnp.concatenate(transformed_data, axis=-1)

    def inverse_transform(self, xs) -> dict[str, jax.Array]:
        """Split `xs` by feature columns and inverse transform each slice."""
        original_data = {}
        for feat, (start, end) in self.features_and_indices:
            original_data[feat.name] = feat.inverse_transform(xs[:, start:end])
        return original_data

    def apply_constraints(self, xs, cfs, hard: bool = False):
        """Apply the constraints of every feature to its column slice and concatenate the results."""
        return jnp.concatenate(
            [feat.apply_constraints(xs[:, start:end], cfs[:, start:end], hard) for feat, (start, end) in self.features_and_indices], axis=-1)
    
    def compute_reg_loss(self, xs, cfs, hard: bool = False):
        """Sum the regularization losses of all features."""
        reg_loss = 0.
        for feat, (start, end) in self.features_and_indices:
            reg_loss += feat.compute_reg_loss(xs[:, start:end], cfs[:, start:end], hard)
        return reg_loss

    def to_dict(self):
        """Export all features as `{'features': [feature_dict, ...]}`."""
        return {
            'features': [feat.to_dict() for feat in self.features],
        }
    
    @classmethod
    def from_dict(cls, d):
        """Build a features list from a dict produced by `to_dict`."""
        return cls(
            features=[Feature.from_dict(feat) for feat in d['features']],
        )
    
    def to_pandas(self, use_transformed: bool = False) -> pd.DataFrame:
        """Return the original data as a DataFrame with one column per feature.

        `use_transformed=True` is not implemented.
        """
        if use_transformed:
            # data = {feat.name: feat.transformed_data.reshape(-1) for feat in self.features}
            raise NotImplementedError
        else:
            data = {feat.name: feat.data.reshape(-1) for feat in self.features}
        return pd.DataFrame(data=data)

    def save(self, saved_dir):
        """Save the dict form of this list under `saved_dir`, creating it if needed."""
        os.makedirs(saved_dir, exist_ok=True)
        save_pytree(self.to_dict(), saved_dir)
        
    @classmethod
    def load_from_path(cls, saved_dir):
        """Load a features list previously written by `save`."""
        return cls.from_dict(load_pytree(saved_dir))

In [None]:
# Build a features list from a CSV file: two min-max scaled columns
# followed by six one-hot encoded columns.
df = pd.read_csv('assets/adult/data/data.csv')
cont_feats = ['age', 'hours_per_week']
cat_feats = ["workclass", "education", "marital_status","occupation", "race", "gender"]

feats_list = FeaturesList([
    Feature(name, df[name].to_numpy().reshape(-1, 1), 'minmax') for name in cont_feats
] + [
    Feature(name, df[name].to_numpy().reshape(-1, 1), 'ohe') for name in cat_feats
])
# The transformed features are concatenated column-wise.
assert feats_list.transformed_data.shape == (32561, 29)

In [None]:
# test __get_item__
# Indexing by a single name returns the matching feature.
assert np.allclose(
    feats_list['age'].transformed_data,
    feats_list.transformed_data[:, 0:1]
)
# Indexing by a list of names returns a list of features, in the requested order.
assert np.allclose(
    FeaturesList(feats_list[['age', 'hours_per_week', 'workclass']]).transformed_data,
    feats_list.transformed_data[:, :6]
)

In [None]:
# Test with_transformed_data
# A list rebuilt from a random subset of transformed rows must decode
# to the same rows of the original data.
transformed_xs = feats_list.transformed_data
indices = np.random.choice(len(transformed_xs), size=100)
feats_list_1 = feats_list.with_transformed_data(transformed_xs[indices])

pd.testing.assert_frame_equal(
    feats_list.to_pandas().iloc[indices].reset_index(drop=True),
    feats_list_1.to_pandas(),
    check_exact=False,
    check_dtype=False,
    check_index_type=False
)

In [None]:
# Test set_transformations
# Work on a copy so that `feats_list` keeps its one-hot transformations.
feats_list_2 = deepcopy(feats_list)
feats_list_2.set_transformations({
    feat: 'ordinal' for feat in cat_feats
})
# Every feature now occupies a single column: 2 continuous + 6 ordinal.
assert feats_list_2.transformed_data.shape == (32561, 8)

for feat in feats_list_2:
    if feat.name in cat_feats:  
        assert feat.transformation.name == 'ordinal'
        assert feat.is_categorical
    else:
        # Features that were not listed keep their transformation.
        assert feat.transformation.name == 'minmax'
        assert feat.is_categorical is False
    assert feat.is_immutable is False

In [None]:
# Test transform and inverse_transform
# Convert df to dict[str, np.ndarray]
# Convert df (all columns except the last one) to dict[str, np.ndarray] of column vectors
df_dict = {k: np.array(v).reshape(-1, 1) for k, v in df.iloc[:, :-1].to_dict(orient='list').items()}
# feats_list.transform(df_dict) should be the same as feats_list.transformed_data
transformed_data = feats_list.transform(df_dict)
assert np.equal(feats_list.transformed_data, transformed_data).all()
# feats_list.inverse_transform(transformed_data) should be the same as df_dict
inverse_transformed_data = feats_list.inverse_transform(transformed_data)
pd.testing.assert_frame_equal(
    pd.DataFrame.from_dict({k: v.reshape(-1) for k, v in inverse_transformed_data.items()}),
    pd.DataFrame.from_dict({k: v.reshape(-1) for k, v in df_dict.items()}),
    check_dtype=False,
)

In [None]:
# Test apply_constraints and compute_reg_loss
cfs = np.random.randn(10, 29)
# Use the rows that correspond to `cfs` instead of a variable left over from an earlier cell.
xs_batch = feats_list.transformed_data[:10, :]
constraint_cfs = feats_list.apply_constraints(xs_batch, cfs, hard=False)
assert constraint_cfs.shape == (10, 29)
# Each of the six one-hot features sums to 1 after the softmax.
assert np.allclose(
    constraint_cfs[:, 2:].sum(axis=-1),
    np.ones((10,)) * 6
)
# The first two columns are the min-max features, which are clipped to [0, 1].
assert constraint_cfs[:, :2].min() >= 0 and constraint_cfs[:, :2].max() <= 1
assert feats_list.apply_constraints(xs_batch, cfs, hard=True).shape == (10, 29)

# The regularization loss is a scalar: positive for unconstrained values
# and zero once the constraints have been applied.
reg_loss = feats_list.compute_reg_loss(xs_batch, cfs)
assert jnp.ndim(reg_loss) == 0
assert np.all(reg_loss > 0)
assert np.allclose(feats_list.compute_reg_loss(xs_batch, constraint_cfs), 0)

In [None]:
# Test `to_pandas`
# `to_pandas` must reproduce the original columns held in `df_dict`.
feats_pd = feats_list.to_pandas()
pd.testing.assert_frame_equal(
    feats_pd,
    pd.DataFrame.from_dict({k: v.reshape(-1) for k, v in df_dict.items()}),
    check_dtype=False,
)

In [None]:
# Test save and load
try:
    feats_list.save('tmp/data_module/')
    feats_list_1 = FeaturesList.load_from_path('tmp/data_module/')
finally:
    # remove tmp folder, even when saving or loading fails
    shutil.rmtree('tmp', ignore_errors=True)

In [None]:
# Compare every transformed feature against the scikit-learn implementation.
sk_ohe = skp.OneHotEncoder(sparse_output=False)
sk_minmax = skp.MinMaxScaler()

# Iterate over the features list directly (equivalent to `feats_list.features`).
for feat in feats_list:
    if feat.name in cont_feats:
        assert np.allclose(
            sk_minmax.fit_transform(feat.data),
            feat.transformed_data,
        ), f"Failed at {feat.name}. "
    else:
        assert np.allclose(
            sk_ohe.fit_transform(feat.data),
            feat.transformed_data,
        ), f"Failed at {feat.name}"