# Feature Transformation

In [None]:
#| default_exp data_utils.transforms

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import annotations
from relax.data_utils.preprocessing import *
from relax.utils import get_config, gumbel_softmax
from relax.import_essentials import *

In [None]:
#| hide
import sklearn.preprocessing as skp
from fastcore.test import test_fail
from copy import deepcopy

In [None]:
#| export
class BaseTransformation:
    """Interface shared by every feature transformation.

    A transformation carries a short `name` and an optional wrapped
    preprocessor (`transformer`). Every method here is a stub that raises
    `NotImplementedError`; concrete subclasses must override them.
    """

    def __init__(self, name, transformer: DataPreprocessor = None):
        # Short identifier of the transformation (e.g. "minmax").
        self.name = name
        # Wrapped preprocessor; may be None.
        self.transformer = transformer

    @property
    def is_categorical(self) -> bool:
        """Whether the transformed feature is categorical."""
        raise NotImplementedError

    def fit(self, xs, y=None):
        """Fit the transformation on `xs`."""
        raise NotImplementedError

    def transform(self, xs):
        """Transform `xs` with the fitted state."""
        raise NotImplementedError

    def fit_transform(self, xs, y=None):
        """Fit on `xs` and return the transformed `xs`."""
        raise NotImplementedError

    def inverse_transform(self, xs):
        """Map transformed values back to the original space."""
        raise NotImplementedError

    def apply_constraints(self, xs, cfs, hard, rng_key, **kwargs):
        """Project counterfactuals `cfs` onto the feasible set."""
        raise NotImplementedError

    def compute_reg_loss(self, xs, cfs, hard: bool = False):
        """Regularization penalty for constraint violations in `cfs`."""
        raise NotImplementedError

    def from_dict(self, params):
        """Restore state from a dict produced by `to_dict`."""
        raise NotImplementedError

    def to_dict(self):
        """Serialize the transformation's state to a dict."""
        raise NotImplementedError

In [None]:
#| export
class _DefaultTransformation(BaseTransformation):
    """Transformation that delegates to an optional wrapped preprocessor.

    When `transformer` is None, the transformation is a no-op:
    `transform`, `fit_transform` and `inverse_transform` return their input
    unchanged, and the feature is reported as non-categorical.
    """

    @property
    def is_categorical(self) -> bool:
        # Categorical iff the wrapped preprocessor is an encoder.
        if self.transformer is None:
            return False
        return isinstance(self.transformer, EncoderPreprocessor)

    def fit(self, xs, y=None):
        """Fit the wrapped preprocessor (if any) on `xs`; `y` is ignored."""
        if self.transformer is not None:
            self.transformer.fit(xs)
        return self

    def transform(self, xs):
        if self.transformer is None:
            return xs
        return self.transformer.transform(xs)

    def fit_transform(self, xs, y=None):
        if self.transformer is None:
            return xs
        return self.transformer.fit_transform(xs)

    def inverse_transform(self, xs):
        if self.transformer is None:
            return xs
        return self.transformer.inverse_transform(xs)

    def apply_constraints(self, xs: jax.Array, cfs: jax.Array, hard: bool = False,
                          rng_key: jrand.PRNGKey = None, **kwargs):
        """Default: no constraints, `cfs` is returned unchanged."""
        return cfs

    def compute_reg_loss(self, xs, cfs, hard: bool = False):
        """Default: no regularization penalty."""
        return 0.

    def from_dict(self, params: dict):
        """Restore state from a dict produced by `to_dict`.

        A missing (or None) "transformer" entry resets `self.transformer` to
        None. Otherwise, the existing transformer loads its own state from
        that entry.
        """
        self.name = params["name"]
        transformer_params = params.get("transformer")
        if transformer_params is None:
            self.transformer = None
        else:
            self.transformer.from_dict(transformer_params)
        return self

    def to_dict(self) -> dict:
        """Serialize to a dict.

        The "transformer" key is omitted when there is no wrapped
        preprocessor, so the result round-trips through `from_dict`.
        """
        params = {"name": self.name}
        if self.transformer is not None:
            params["transformer"] = self.transformer.to_dict()
        return params

In [None]:
#| export
class MinMaxTransformation(_DefaultTransformation):
    """Min-max scaling transformation; constraints clip values into [0, 1]."""

    def __init__(self):
        super().__init__("minmax", MinMaxScaler())

    def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):
        """Clip `cfs` element-wise into [0, 1].

        `xs`, `hard` and `rng_key` are unused. They are accepted (also
        positionally) so the signature matches
        `BaseTransformation.apply_constraints`.
        """
        return np.clip(cfs, 0., 1.)

In [None]:
# Round trip: inverse_transform must undo fit_transform.
xs = np.random.randn(100, 1)
minmax_t = MinMaxTransformation()
transformed_xs = minmax_t.fit_transform(xs)
assert np.allclose(minmax_t.inverse_transform(transformed_xs), xs)
assert minmax_t.is_categorical is False

# Constraints: arbitrary values must end up inside [0, 1].
x = np.random.randn(100, 1)
cf_constrained = minmax_t.apply_constraints(xs, x)
assert np.all(cf_constrained >= 0)
assert np.all(cf_constrained <= 1)

# Serialization: a transformation rebuilt via to_dict/from_dict
# must transform identically.
scaler_1 = MinMaxTransformation().from_dict(minmax_t.to_dict())
assert np.allclose(minmax_t.transform(xs), scaler_1.transform(xs))

In [None]:
#| export
class _OneHotTransformation(_DefaultTransformation):
    """Shared base for transformations backed by a `OneHotEncoder`.

    `hard_constraints` snaps each row to the one-hot vector of its argmax.
    Subclasses supply the differentiable `soft_constraints`.
    """

    def __init__(self, name: str = None):
        super().__init__(name, OneHotEncoder())

    @property
    def num_categories(self) -> int:
        """Length of the fitted encoder's `categories_` attribute."""
        fitted = self.transformer.categories_
        return len(fitted)

    def hard_constraints(self, operand: tuple[jax.Array, jrand.PRNGKey, dict]):
        """Replace every row of `cfs` by the one-hot vector at its argmax."""
        cfs, _key, _extra = operand
        winners = jnp.argmax(cfs, axis=-1)
        return jax.nn.one_hot(winners, self.num_categories)

    def soft_constraints(self, operand: tuple[jax.Array, jrand.PRNGKey, dict]):
        """Differentiable relaxation; must be implemented by subclasses."""
        raise NotImplementedError

    def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):
        """Select the hard or soft constraint with `jax.lax.cond`.

        Both branches receive the same `(cfs, rng_key, kwargs)` tuple.
        """
        branch_input = (cfs, rng_key, kwargs)
        return jax.lax.cond(
            hard, self.hard_constraints, self.soft_constraints, branch_input
        )

    def compute_reg_loss(self, xs, cfs, hard: bool = False):
        """Mean squared deviation of each row-sum of `cfs` from 1."""
        row_gap = cfs.sum(axis=-1, keepdims=True) - 1.0
        return (row_gap ** 2).mean()

In [None]:
#| export
class SoftmaxTransformation(_OneHotTransformation):
    """One-hot feature whose soft constraint is a softmax over the last axis."""

    def __init__(self):
        super().__init__("ohe")

    def soft_constraints(self, operand: tuple[jax.Array, jrand.PRNGKey, dict]):
        """Row-wise softmax; the PRNG key and extra kwargs are ignored."""
        logits, _unused_key, _unused_kwargs = operand
        return jax.nn.softmax(logits, axis=-1)
    
class GumbelSoftmaxTransformation(_OneHotTransformation):
    """Apply Gumbel softmax tricks for categorical transformation.

    Args:
        tau: Temperature passed to `gumbel_softmax` in the soft branch.
    """

    def __init__(self, tau: float = .1):
        super().__init__("gumbel")
        self.tau = tau

    def soft_constraints(self, operand: tuple[jax.Array, jrand.PRNGKey, dict]):
        x, rng_key, _ = operand
        if rng_key is None:
            # Fall back to a fixed key built from the global seed, so repeated
            # calls give the same sample as long as the seed is unchanged.
            rng_key = jrand.PRNGKey(get_config().global_seed)
        return gumbel_softmax(rng_key, x, self.tau)

    def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):
        """Apply constraints to the counterfactuals.

        If `rng_key` is None, a fixed key derived from the global seed is used,
        so the soft (Gumbel) branch is deterministic rather than freshly random.
        """
        return super().apply_constraints(xs, cfs, hard, rng_key, **kwargs)

    def from_dict(self, params: dict):
        """Restore state from `to_dict` output, including the temperature `tau`."""
        super().from_dict(params)
        # Older dicts may lack "tau"; keep the current value in that case.
        self.tau = params.get("tau", self.tau)
        return self

    def to_dict(self) -> dict:
        return super().to_dict() | {"tau": self.tau}
    
def OneHotTransformation():
    """Deprecated alias that returns a `SoftmaxTransformation`.

    Emits a `DeprecationWarning` attributed to the caller's line.
    """
    # stacklevel=2 makes the warning point at the caller, not at this shim.
    warnings.warn("OneHotTransformation is deprecated since v0.2.5. "
                  "Use `SoftmaxTransformation`.", DeprecationWarning,
                  stacklevel=2)
    return SoftmaxTransformation()

In [None]:
def test_ohe_t(ohe_cls):
    """Exercise a one-hot-style transformation class end to end.

    Checks fitting, the soft (hard=False) and hard (hard=True) constraints,
    the regularization loss, and the to_dict/from_dict round trip.
    """
    xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
    ohe_t = ohe_cls().fit(xs)
    transformed_xs = ohe_t.transform(xs)
    rng_key = jax.random.PRNGKey(get_config().global_seed)
    assert ohe_t.is_categorical

    x = jax.random.uniform(rng_key, shape=(100, 3))
    # Test hard=False, which applies the class's soft relaxation:
    # every row must be a valid probability vector.
    soft = ohe_t.apply_constraints(transformed_xs, x, hard=False, rng_key=rng_key)
    assert jnp.allclose(soft.sum(axis=-1), 1)
    assert jnp.all(soft >= 0)
    assert jnp.all(soft <= 1)
    assert jnp.allclose(jnp.zeros((len(x), 1)), ohe_t.compute_reg_loss(xs, soft, hard=False))
    # Omitting rng_key must match passing the global-seed key explicitly.
    assert jnp.allclose(soft, ohe_t.apply_constraints(transformed_xs, x, hard=False))

    # Test hard=True, which enforces an exact one-hot vector per row.
    hard = ohe_t.apply_constraints(transformed_xs, x, hard=True, rng_key=rng_key)
    assert np.all([1 in row for row in hard])
    assert np.all([0 in row for row in hard])
    assert jnp.allclose(hard.sum(axis=-1), 1)
    assert jnp.allclose(jnp.zeros((len(x), 1)), ohe_t.compute_reg_loss(xs, hard, hard=False))

    # Test compute_reg_loss returns a scalar.
    assert jnp.ndim(ohe_t.compute_reg_loss(xs, soft, hard=False)) == 0

    # Test from_dict and to_dict
    ohe_t_1 = ohe_cls().from_dict(ohe_t.to_dict())
    assert np.allclose(ohe_t.transform(xs), ohe_t_1.transform(xs))


test_ohe_t(SoftmaxTransformation)
test_ohe_t(GumbelSoftmaxTransformation)

In [None]:
#| export
class OrdinalTransformation(_DefaultTransformation):
    """Categorical feature transformation backed by an `OrdinalPreprocessor`."""

    def __init__(self):
        super().__init__("ordinal", OrdinalPreprocessor())

    @property
    def num_categories(self) -> int:
        """Length of the fitted preprocessor's `categories_` attribute."""
        fitted_categories = self.transformer.categories_
        return len(fitted_categories)
    
class IdentityTransformation(_DefaultTransformation):
    """No-op transformation: features and counterfactuals pass through unchanged."""

    def __init__(self):
        super().__init__("identity", None)

    def fit(self, xs, y=None):
        return self

    def transform(self, xs):
        return xs

    def fit_transform(self, xs, y=None):
        return xs

    def apply_constraints(self, xs, cfs, hard: bool = False, rng_key=None, **kwargs):
        """Return `cfs` unchanged.

        `hard` and `rng_key` are accepted (also positionally) so the
        signature matches `BaseTransformation.apply_constraints`.
        """
        return cfs

    def to_dict(self):
        return {'name': 'identity'}

    def from_dict(self, params: dict):
        self.name = params["name"]
        return self

In [None]:
# OrdinalTransformation: encode string categories and decode them back.
xs = np.random.choice(['a', 'b', 'c'], size=(100, 1))
encoder = OrdinalTransformation().fit(xs)
transformed_xs = encoder.transform(xs)
assert np.all(encoder.inverse_transform(transformed_xs) == xs)
assert encoder.is_categorical

# A rebuilt encoder (to_dict -> from_dict) must encode identically.
encoder_1 = OrdinalTransformation().from_dict(encoder.to_dict())
assert np.allclose(encoder.transform(xs), encoder_1.transform(xs))

# IdentityTransformation: fit_transform must return the data untouched.
xs = np.random.randn(100, 1)
scaler = IdentityTransformation()
transformed_xs = scaler.fit_transform(xs)
assert np.all(transformed_xs == xs)

# A rebuilt identity (to_dict -> from_dict) must behave the same.
scaler_1 = IdentityTransformation().from_dict(scaler.to_dict())
assert np.allclose(scaler.transform(xs), scaler_1.transform(xs))

In [None]:
#| export
# Registry mapping a transformation key to its class.
# 'ohe' and 'softmax' are aliases for the same class.
FEATURE_TRANSFORMATIONS = dict(
    ohe=SoftmaxTransformation,
    softmax=SoftmaxTransformation,
    gumbel=GumbelSoftmaxTransformation,
    minmax=MinMaxTransformation,
    ordinal=OrdinalTransformation,
    identity=IdentityTransformation,
)