In [None]:
from __future__ import annotations

import copy

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder

In [None]:
class Transformation:
    """Named wrapper around a scikit-learn-style transformer.

    ``transformer`` must provide ``fit(xs)``, ``transform(xs)`` and
    ``inverse_transform(xs)``; this class simply forwards to them. Subclasses
    override :meth:`apply_constraint` to project counterfactuals onto values
    that are valid for the encoding they produce.
    """

    def __init__(self, name, transformer):
        self.name = name
        self.transformer = transformer

    def fit(self, xs, y=None):
        """Fit the wrapped transformer on ``xs`` and return ``self``.

        ``y`` is accepted for scikit-learn API compatibility and ignored.
        """
        self.transformer.fit(xs)
        return self

    def transform(self, xs):
        """Return ``self.transformer.transform(xs)``."""
        return self.transformer.transform(xs)

    def inverse_transform(self, xs):
        """Return ``self.transformer.inverse_transform(xs)``."""
        return self.transformer.inverse_transform(xs)

    def apply_constraint(self, xs, cfs=None, hard: bool = False):
        """Default, unconstrained projection of counterfactuals.

        Accepts the same ``(xs, cfs, hard)`` arguments that subclasses and
        ``Feature.apply_constraint`` use, so any transformation can be plugged
        into a feature.

        Args:
            xs: original (transformed) inputs.
            cfs: counterfactuals to constrain. When omitted, ``xs`` is returned,
                which preserves the original single-argument behaviour.
            hard: ignored by the default implementation.

        Returns:
            ``cfs`` unchanged, or ``xs`` when ``cfs`` is ``None``.
        """
        return xs if cfs is None else cfs

In [None]:
class OneHotTransformation(Transformation):
    """One-hot encode a categorical feature with scikit-learn's ``OneHotEncoder``.

    :meth:`apply_constraint` maps counterfactual scores for the encoded columns
    back onto either soft (softmax) or hard (argmax one-hot) category vectors.
    """

    def __init__(self):
        super().__init__("ohe", self._make_encoder())

    @staticmethod
    def _make_encoder():
        """Build a dense-output ``OneHotEncoder`` across scikit-learn versions."""
        # scikit-learn 1.2 renamed ``sparse`` to ``sparse_output`` and 1.4 removed
        # ``sparse`` entirely; prefer the new keyword and fall back for old releases.
        try:
            return OneHotEncoder(sparse_output=False)
        except TypeError:
            return OneHotEncoder(sparse=False)

    @property
    def categories(self) -> int:
        """Total number of one-hot columns produced by the fitted encoder."""
        # ``categories_`` holds one array of categories *per input column*, so its
        # length is the number of input columns, not the number of categories.
        return sum(len(cats) for cats in self.transformer.categories_)

    def apply_constraint(self, xs, cfs, hard: bool = False):
        """Project counterfactual scores onto one-hot category vectors.

        Args:
            xs: original inputs (unused; kept for a uniform interface).
            cfs: counterfactual scores whose last axis indexes the categories.
            hard: if true, return an exact one-hot of the argmax; otherwise a
                softmax over the last axis.

        Returns:
            Array with the same shape and dtype as ``cfs``.
        """
        def _hard(x):
            # Width and dtype follow ``x`` so both ``lax.cond`` branches return
            # identically-typed outputs, as ``jax.lax.cond`` requires.
            return jax.nn.one_hot(jnp.argmax(x, axis=-1), x.shape[-1], dtype=x.dtype)

        def _soft(x):
            return jax.nn.softmax(x, axis=-1)

        return jax.lax.cond(hard, _hard, _soft, cfs)

In [None]:
class MinMaxTransformation(Transformation):
    """Scale a continuous feature with scikit-learn's ``MinMaxScaler``.

    With the scaler's default ``feature_range=(0, 1)`` the fitted training data
    maps into ``[0, 1]``; :meth:`apply_constraint` clips counterfactuals to it.
    """

    def __init__(self):
        super().__init__("minmax", MinMaxScaler())

    def apply_constraint(self, xs, cfs, hard: bool = False):
        """Clip ``cfs`` to ``[0, 1]``; ``xs`` and ``hard`` are unused."""
        return jnp.clip(cfs, 0.0, 1.0)


# Name -> transformation registry resolved by ``Feature`` for string transformer
# names. ``'minmax'`` and ``'ohe'`` are the default continuous and categorical
# transformer names used by ``Dataset``.
# NOTE: values are shared module-level instances; anything that fits them should
# work on a copy so one feature's fit does not leak into another.
PREPROCESSING_TRANSFORMATIONS = {
    'ohe': OneHotTransformation(),
    'minmax': MinMaxTransformation(),
}

In [None]:
class Feature:
    """A single named column together with its fitted transformation.

    Args:
        name: column name.
        data: raw column values, passed as-is to the transformer's ``fit`` and
            ``transform`` (NOTE(review): sklearn transformers expect 2-D input —
            confirm callers pass ``(n_samples, n_cols)`` arrays).
        transformer: a key of ``PREPROCESSING_TRANSFORMATIONS`` or a
            ``Transformation`` instance. Registry entries are copied; a passed
            instance is used (and fitted) directly.
        transformed_data: precomputed transformed values; when ``None`` they
            are computed as ``self.transform(data)``.
        is_immutable: if true, :meth:`apply_constraint` returns ``xs`` unchanged.

    Raises:
        KeyError: ``transformer`` is a string missing from the registry.
        ValueError: ``transformer`` is neither a string nor a ``Transformation``.
    """

    def __init__(
        self,
        name: str,
        data: np.ndarray,
        transformer: str | Transformation,
        transformed_data = None,
        is_immutable: bool = False,
    ):
        self.name = name
        self.data = data
        if isinstance(transformer, str):
            # Registry entries are shared module-level instances and ``_fit``
            # mutates the transformer; copy so fitting this feature does not refit
            # (and silently corrupt) every other feature built from the same key.
            self.transformer = copy.deepcopy(PREPROCESSING_TRANSFORMATIONS[transformer])
        elif isinstance(transformer, Transformation):
            self.transformer = transformer
        else:
            raise ValueError(f"Unknown transformer {transformer}")
        self._fit()
        if transformed_data is None:
            self.transformed_data = self.transform(self.data)
        else:
            self.transformed_data = transformed_data
        self.is_immutable = is_immutable

    @classmethod
    def from_dict(cls, d):
        """Rebuild a feature from the mapping produced by :meth:`to_dict`."""
        return cls(**d)

    def to_dict(self):
        """Return the constructor arguments as a dict (inverse of :meth:`from_dict`)."""
        return {
            'name': self.name,
            'data': self.data,
            'transformed_data': self.transformed_data,
            'transformer': self.transformer,
            'is_immutable': self.is_immutable,
        }

    def __repr__(self):
        return f"Feature(" \
               f"name={self.name}, \ndata={self.data}, \n" \
               f"transformed_data={self.transformed_data}, \n" \
               f"transformer={self.transformer}, \n" \
               f"is_immutable={self.is_immutable})"

    __str__ = __repr__

    def __getitem__(self, idx):
        """Return the raw and transformed values at ``idx`` as a dict."""
        return {
            'data': self.data[idx],
            'transformed_data': self.transformed_data[idx],
        }

    # Backward-compatible alias for the original misspelled name, which Python
    # never invoked for ``feature[idx]``.
    __get_item__ = __getitem__

    def _fit(self):
        """Fit the transformer on this feature's raw data and return ``self``."""
        self.transformer.fit(self.data)
        return self

    def transform(self, xs):
        """Encode ``xs`` with this feature's fitted transformer."""
        return self.transformer.transform(xs)

    def inverse_transform(self, xs):
        """Decode ``xs`` with this feature's fitted transformer."""
        return self.transformer.inverse_transform(xs)

    def apply_constraint(self, xs, cfs, hard: bool = False):
        """Constrain the counterfactuals ``cfs`` for this feature.

        Immutable features return the original values ``xs`` (as a JAX array);
        mutable features delegate to ``self.transformer.apply_constraint``.
        """
        # ``is_immutable`` is a plain flag fixed at construction, so branch in
        # Python instead of ``jax.lax.cond``: the transformer's constraint is not
        # traced for immutable features, and ``xs`` need not match the dtype of
        # the constrained ``cfs``.
        if self.is_immutable:
            return jnp.asarray(xs)
        return self.transformer.apply_constraint(xs, cfs, hard)

In [None]:
class Dataset:
    """Tabular dataset whose columns are wrapped as ``Feature`` objects.

    ``data`` maps column names to arrays; every column must appear in
    ``continuous_cols`` or ``categorical_cols``. Each feature's transformed
    array is concatenated along the last axis, in ``data``'s iteration order,
    to form :attr:`transformed_data`.
    """

    def __init__(
        self, 
        data: dict[str, np.ndarray], 
        continuous_cols: list[str],
        categorical_cols: list[str],
        immutable_cols: list[str],
        continuous_transformer: str = 'minmax',
        categorical_transformer: str = 'ohe',
    ):
        self._features, self._transformed_data = self.convert_to_features(
            data, continuous_cols, categorical_cols, immutable_cols, 
            continuous_transformer, categorical_transformer
        )

    @property
    def features(self):
        """List of ``Feature`` objects, in column order."""
        return self._features

    @property
    def transformed_data(self):
        """All features' transformed values concatenated along the last axis."""
        return self._transformed_data

    def convert_to_features(
        self, 
        data: dict[str, np.ndarray], 
        continuous_cols: list[str],
        categorical_cols: list[str],
        immutable_cols: list[str],
        continuous_transformer: str = 'minmax',
        categorical_transformer: str = 'ohe',
    ):
        """Build one ``Feature`` per column and concatenate their transformed data.

        Returns:
            ``(features, transformed_data)``.

        Raises:
            ValueError: a column is in neither ``continuous_cols`` nor
                ``categorical_cols``, or ``data`` has no columns.
        """
        continuous = set(continuous_cols)
        categorical = set(categorical_cols)
        immutable = set(immutable_cols)
        features = []
        # TODO: This can be done in parallel
        for name, xs in data.items():
            # Continuous takes precedence if a column is listed in both.
            if name in continuous:
                transformer = continuous_transformer
            elif name in categorical:
                transformer = categorical_transformer
            else:
                raise ValueError(f"Unknown column type {name}")
            features.append(
                Feature(name, xs, transformer=transformer, is_immutable=name in immutable)
            )
        if not features:
            raise ValueError("`data` must contain at least one column")
        transformed_data = np.concatenate(
            [feat.transformed_data for feat in features], axis=-1)
        return features, transformed_data

    def transform(self, data):
        """Transform a ``{column name: values}`` mapping into one concatenated array."""
        return np.concatenate(
            [feat.transform(data[feat.name]) for feat in self._features], axis=-1)

    def inverse_transform(self, data):
        """Map a concatenated transformed array back to per-column original values.

        Args:
            data: array whose last axis is laid out like :attr:`transformed_data`.

        Returns:
            dict mapping each feature name to ``feature.inverse_transform`` of
            that feature's column slice.

        Raises:
            ValueError: the last axis of ``data`` does not match the total
                transformed width of the features.
        """
        data = np.asarray(data)
        slices = self._column_slices()
        self._check_width(data, slices, "data")
        return {
            feat.name: feat.inverse_transform(data[..., sl])
            for feat, sl in zip(self._features, slices)
        }

    def apply_constraint(self, xs, cfs, hard: bool = False):
        """Apply every feature's ``apply_constraint`` to its own column slice.

        Args:
            xs: original transformed inputs, laid out like :attr:`transformed_data`.
            cfs: counterfactuals with the same layout.
            hard: forwarded to each feature's constraint.

        Returns:
            The constrained slices concatenated along the last axis.

        Raises:
            ValueError: ``xs`` or ``cfs`` has the wrong number of columns.
        """
        slices = self._column_slices()
        self._check_width(xs, slices, "xs")
        self._check_width(cfs, slices, "cfs")
        return jnp.concatenate(
            [feat.apply_constraint(xs[..., sl], cfs[..., sl], hard)
             for feat, sl in zip(self._features, slices)],
            axis=-1,
        )

    def _column_slices(self):
        """Return one ``slice`` per feature locating its columns on the last axis."""
        slices = []
        start = 0
        for feat in self._features:
            stop = start + np.shape(feat.transformed_data)[-1]
            slices.append(slice(start, stop))
            start = stop
        return slices

    @staticmethod
    def _check_width(arr, slices, what):
        """Raise ``ValueError`` if ``arr``'s last axis does not span all ``slices``."""
        expected = slices[-1].stop if slices else 0
        actual = np.shape(arr)[-1]
        if actual != expected:
            raise ValueError(
                f"`{what}` has {actual} columns along the last axis; expected {expected}")
