# Data Module

> `DataModule` for training parametric models, generating and benchmarking CF explanations.

In [None]:
#| default_exp data.module

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from relax.utils import show_doc
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)
show_doc_parser = show_doc

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export
from __future__ import annotations
from relax.import_essentials import *
from relax.utils import load_json, validate_configs, cat_normalize
from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
from sklearn.base import TransformerMixin
from sklearn.utils.validation import check_is_fitted, NotFittedError
from urllib.request import urlretrieve
from relax.data.loader import Dataset, ArrayDataset, DataLoader, DataloaderBackends
from pydantic.fields import ModelField

In [None]:
#| hide
from fastcore.test import *
from copy import deepcopy

## Data Module Interfaces

High-level interfaces for `DataModule`. Docs to be added. 

In [None]:
#| export
class BaseDataModule(ABC):
    """DataModule Interface.

    Abstract contract for data modules. Concrete subclasses must implement the
    five dataset properties (`data_name`, `data`, `train_dataset`,
    `val_dataset`, `test_dataset`) together with `prepare_data`, `transform`
    and `inverse_transform`. The remaining hooks are optional: `dataset`, the
    dataloader factories and `apply_regularization` raise
    `NotImplementedError` unless overridden, while `apply_constraints`
    returns `cf` untouched.
    """

    # ---- required properties ------------------------------------------------

    @property
    @abstractmethod
    def data_name(self) -> str:
        """Name identifying the dataset."""

    @property
    @abstractmethod
    def data(self) -> Any:
        """The data held by this module."""

    @property
    @abstractmethod
    def train_dataset(self) -> Dataset:
        """Training split."""

    @property
    @abstractmethod
    def val_dataset(self) -> Dataset:
        """Validation split."""

    @property
    @abstractmethod
    def test_dataset(self) -> Dataset:
        """Test split."""

    # ---- required data pipeline ---------------------------------------------

    @abstractmethod
    def prepare_data(self) -> Any:
        """Prepare the data (e.g. fit preprocessors and build the splits)."""
        raise NotImplementedError()

    @abstractmethod
    def transform(self, data) -> jnp.DeviceArray:
        """Encode raw `data` into a numerical representation."""
        raise NotImplementedError()

    @abstractmethod
    def inverse_transform(self, x: jnp.DeviceArray) -> Any:
        """Decode a numerical representation back into raw form."""
        raise NotImplementedError()

    # ---- optional hooks -----------------------------------------------------

    def dataset(self, name: str) -> Dataset:
        """Look up a split by name; not implemented by default."""
        raise NotImplementedError()

    def train_dataloader(self, batch_size):
        """Dataloader over the training split; not implemented by default."""
        raise NotImplementedError()

    def val_dataloader(self, batch_size):
        """Dataloader over the validation split; not implemented by default."""
        raise NotImplementedError()

    def test_dataloader(self, batch_size):
        """Dataloader over the test split; not implemented by default."""
        raise NotImplementedError()

    def apply_constraints(
        self,
        x: jnp.DeviceArray,
        cf: jnp.DeviceArray,
        hard: bool
    ) -> jnp.DeviceArray:
        """Constrain counterfactuals `cf` given inputs `x`; the identity by default."""
        return cf

    def apply_regularization(
        self,
        x: jnp.DeviceArray,
        cf: jnp.DeviceArray,
        hard: bool
    ):
        """Regularization term for `cf`; not implemented by default."""
        raise NotImplementedError()


## Tabular Data Module

`DataModule` for processing tabular data.


In [None]:
#| hide
#| export
def find_imutable_idx_list(
    imutable_col_names: List[str],
    discrete_col_names: List[str],
    continuous_col_names: List[str],
    cat_arrays: List[List[str]],
) -> List[int]:
    """Return the feature-space column indices of all immutable features.

    The index layout this function assumes is: every continuous feature takes
    one column, in `continuous_col_names` order, followed by each discrete
    feature expanded into `len(cat_arrays[i])` consecutive columns, in
    `discrete_col_names` order.

    Args:
        imutable_col_names: names of the features that must not change.
        discrete_col_names: names of the categorical features.
        continuous_col_names: names of the continuous features.
        cat_arrays: categories of each discrete feature, aligned with
            `discrete_col_names`.

    Returns:
        Indices of immutable columns, in ascending order.
    """
    # Set lookup instead of scanning the list for every column.
    imutable = set(imutable_col_names)

    # Continuous features occupy one column each at the front.
    imutable_idx_list = [
        idx for idx, col_name in enumerate(continuous_col_names)
        if col_name in imutable
    ]

    # Each discrete feature spans a contiguous block of encoded columns.
    cat_idx = len(continuous_col_names)
    for col_name, cols in zip(discrete_col_names, cat_arrays):
        cat_end_idx = cat_idx + len(cols)
        if col_name in imutable:
            imutable_idx_list.extend(range(cat_idx, cat_end_idx))
        cat_idx = cat_end_idx
    return imutable_idx_list

In [None]:
#| export
class TransformerMixinType(TransformerMixin):
    """Pydantic field type accepting any `sklearn.base.TransformerMixin`.

    Validation only performs an `isinstance` check; the value is returned
    unchanged (it is not copied, fitted or otherwise modified).
    """

    @classmethod
    def __get_validators__(cls):
        # Custom-type hook: pydantic collects validators from this generator.
        yield cls.validate

    @classmethod
    def validate(cls, v):
        if isinstance(v, TransformerMixin):
            return v
        raise TypeError("`sklearn.base.TransformerMixin` required")

    @classmethod
    def __modify_schema__(
        cls, field_schema: Dict[str, Any], field: Optional[ModelField]
    ):
        # Expose a readable type name in the generated schema.
        if field:
            field_schema.update(type='TransformerMixin')


In [None]:
#| export
def _supported_backends():
    """Return the backends reported by `DataloaderBackends().supported()`."""
    return DataloaderBackends().supported()

class TabularDataModuleConfigs(BaseParser):
    """Configurator of `TabularDataModule`."""

    data_dir: str = Field(description="Path to the dataset csv file.")
    data_name: str = Field(description="The name of `TabularDataModule`.")
    continous_cols: List[str] = Field(
        [], description="Continuous features/columns in the data."
    )
    discret_cols: List[str] = Field(
        [], description="Categorical features/columns in the data."
    )
    imutable_cols: List[str] = Field(
        [], description="Immutable features/columns in the data."
    )
    normalizer: Optional[TransformerMixinType] = Field(
        default_factory=lambda: MinMaxScaler(),
        description="Sklearn scaler for continuous features. Can be unfitted, fitted, or None. "
        "If not fitted, the `TabularDataModule` will fit it on the loaded data (all rows, before the train/test split). "
        "If fitted, no fitting will be applied. "
        "If `None`, no transformation will be applied. Default to `MinMaxScaler()`."
    )
    # NOTE(review): scikit-learn 1.2 renamed `OneHotEncoder(sparse=...)` to
    # `sparse_output` (the old name was later removed); confirm which sklearn
    # versions this default must support.
    encoder: Optional[TransformerMixinType] = Field(
        default_factory=lambda: OneHotEncoder(sparse=False),
        description="Sklearn encoder for categorical features. Can be unfitted, fitted, or None. "
        "If not fitted, the `TabularDataModule` will fit it on the loaded data (all rows, before the train/test split). "
        "If fitted, no fitting will be applied. "
        "If `None`, no transformation will be applied. Default to `OneHotEncoder(sparse=False)`."
    )
    sample_frac: Optional[float] = Field(
        None, description="Fraction of the training split to keep (its first rows are used). "
        "Default to use the entire training split.",
        ge=0., le=1.0
    )
    backend: str = Field(
        "jax", description=f"`Dataloader` backend. Currently supports: {_supported_backends()}"
    )

    class Config:
        # Serialize transformers by class name, e.g. "MinMaxScaler()".
        json_encoders = {
            TransformerMixinType: lambda v: f"{v.__class__.__name__}()",
        }

An example configurator of the **adult** dataset:

In [None]:
# Example configs for the adult dataset. `sample_frac=0.1` keeps only the
# first 10% of the training split (see `TabularDataModule.prepare_data`).
configs_dict = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    "continous_cols": ["age", "hours_per_week"],
    "discret_cols": ["workclass", "education", "marital_status","occupation"],
    "imutable_cols": ["age", "workclass", "marital_status"],
    "normalizer": MinMaxScaler(),
    "encoder": OneHotEncoder(sparse=False),
    "sample_frac": 0.1,
    "backend": "jax"
}
configs = TabularDataModuleConfigs(**configs_dict)

In [None]:
#| exporti
def _check_cols(data: pd.DataFrame, configs: TabularDataModuleConfigs) -> pd.DataFrame:
    data = data.astype({
        col: float for col in configs.continous_cols
    })
    
    cols = configs.continous_cols + configs.discret_cols
    # check target columns
    target_col = data.columns[-1]
    assert not target_col in cols, \
        f"continous_cols or discret_cols contains target_col={target_col}."
    
    # check imutable cols
    for col in configs.imutable_cols:
        assert col in cols, \
            f"imutable_cols=[{col}] is not specified in `continous_cols` or `discret_cols`."
    data = data[cols + [target_col]]
    return data


In [None]:
#| exporti
def _process_data(
    df: pd.DataFrame | None, configs: TabularDataModuleConfigs
) -> pd.DataFrame:
    if df is None:
        df = pd.read_csv(configs.data_dir)
    elif isinstance(df, pd.DataFrame):
        df = df
    else:
        raise ValueError(
            f"{type(df).__name__} is not supported as an input type for `TabularDataModule`.")

    df = _check_cols(df, configs)
    return df

In [None]:
#| hide
# Loading from `data_dir` and passing the same csv as a DataFrame must agree.
df_1 = _process_data(None, configs)
_df_t = pd.read_csv(configs.data_dir)
df_2 = _process_data(_df_t, configs)
assert pd.DataFrame.equals(df_1, df_2)

# Failure modes: a missing csv file and an unsupported input type.
_config_fail = deepcopy(configs)
_config_fail.data_dir = "assets/data/s_adult_1.csv" # wrong data path
test_fail(lambda: _process_data(None, _config_fail), contains="No such file or directory")
test_fail(lambda: _process_data("fail", configs), contains="not supported as an input type")

In [None]:
#| exporti
def _transform_df(
    transformer: TransformerMixin | None,
    data: pd.DataFrame,
    cols: List[str] | None,
):
    if transformer is None:
        return data[cols].to_numpy() if cols else np.array([[] for _ in range(len(data))])
    else:
        return (
            transformer.transform(data[cols])
                if cols else np.array([[] for _ in range(len(data))])
        )

In [None]:
#| hide
# test
df = pd.read_csv('assets/data/s_adult.csv')
# Fitted scaler: one output column per input column.
cols = ['age', 'hours_per_week']
sca = MinMaxScaler().fit(df[cols])
x = _transform_df(sca, df, cols)
assert x.shape == (len(df), len(cols))

# Empty column list: zero-width output; the unfitted scaler is never called.
cols = []
sca = MinMaxScaler()
x = _transform_df(sca, df, cols)
assert x.shape == (len(df), len(cols))

# `None` columns behave like an empty list.
cols = None
sca = MinMaxScaler()
x = _transform_df(sca, df, cols)
assert x.shape == (len(df), 0)


# No transformer: raw values pass through.
cols = ['age', 'hours_per_week']
sca = None
x = _transform_df(sca, df, cols)
assert np.allclose(x, df[cols].to_numpy())

# One-hot encoder: width equals the total number of categories.
cols = ["workclass", "education"]
enc = OneHotEncoder(sparse=False).fit(df[cols])
x = _transform_df(enc, df, cols)
assert x.shape == (len(df), enc.categories_[0].shape[0] + enc.categories_[1].shape[0])

In [None]:
#| export
def _inverse_transform_np(
    transformer: TransformerMixin | None,
    x: np.ndarray,
    cols: List[str] | None
):
    assert len(cols) <= x.shape[-1], \
        f"x.shape={x.shape} probably will not match len(cols)={len(cols)}"
    
    if cols:
        data = transformer.inverse_transform(x) if transformer else x
        df = pd.DataFrame(data=data, columns=cols)
    else:
        df = None
    return df


In [None]:
#| hide
# test
df = pd.read_csv('assets/data/s_adult.csv')
# Fitted scaler: transform followed by inverse_transform round-trips the values.
cols = ['age', 'hours_per_week']
sca = MinMaxScaler().fit(df[cols])
x = _transform_df(sca, df, cols)
data = _inverse_transform_np(sca, x, cols)

assert x.shape == (len(df), len(cols))
assert np.allclose(df[cols].values, data.values)

# No columns: zero-width array forward, `None` back.
cols = []
sca = MinMaxScaler()
x = _transform_df(sca, df, cols)
data = _inverse_transform_np(sca, x, cols)

assert x.shape == (len(df), len(cols))
assert data is None

# No transformer: values pass through unchanged in both directions.
cols = ['age', 'hours_per_week']
sca = None
x = _transform_df(sca, df, cols)
data = _inverse_transform_np(sca, x, cols)

assert x.shape == (len(df), len(cols))
assert np.allclose(df[cols].values, data.values)

# Encoder round trip. NOTE(review): `OneHotEncoder()` uses its default
# (presumably sparse) output here, so only the row count of `x` is checked.
cols = ['workclass', 'education']
sca = OneHotEncoder().fit(df[cols])
x = _transform_df(sca, df, cols)
data = _inverse_transform_np(sca, x, cols)

assert x.shape[0] == len(df)
assert df[cols].equals(data)

# Unfitted encoder with no columns is never called in either direction.
cols = []
sca = OneHotEncoder()
x = _transform_df(sca, df, cols)
data = _inverse_transform_np(sca, x, cols)

assert x.shape == (len(df), len(cols))
assert data is None

In [None]:
#| exporti
def _init_scalar_encoder(
    data: pd.DataFrame,
    configs: TabularDataModuleConfigs
) -> Dict[str, TransformerMixin | None]: 
    # The normlizer and encoder will be either None, fitted or not fitted.
    # If the user has specified the normlizer and encoder, then we will use it.
    # Otherwise, we will fit the normlizer and encoder.
    # fit scalar
    if configs.normalizer is not None:
        scalar = configs.normalizer
        try:
            check_is_fitted(scalar)
        except NotFittedError:
            if configs.continous_cols:  scalar.fit(data[configs.continous_cols])
            else:                       scalar = None
    else:
        scalar = None
    
    if configs.encoder is not None:
        encoder = configs.encoder
        try:
            check_is_fitted(encoder)
        except NotFittedError:
            if configs.discret_cols:    encoder.fit(data[configs.discret_cols])
            else:                       encoder = None
    else:
        encoder = None
    return dict(scalar=scalar, encoder=encoder)


In [None]:
#| hide
# test
# No normalizer, unfitted encoder: scalar stays None, encoder gets fitted.
configs_dict = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    "continous_cols": ["age", "hours_per_week"],
    "discret_cols": ["workclass", "education", "marital_status","occupation"],
    "imutable_cols": ["age", "workclass", "marital_status"],
    "normalizer": None,
    "encoder": OneHotEncoder(),
    "sample_frac": 0.1,
    "backend": "jax"
}
configs = TabularDataModuleConfigs(**configs_dict)
df = pd.read_csv(configs.data_dir)
scalar_and_encoder = _init_scalar_encoder(df, configs)

assert scalar_and_encoder["scalar"] is None
check_is_fitted(scalar_and_encoder["encoder"])

# Pre-fitted normalizer (fitted on the first 100 rows only).
scalar = MinMaxScaler().fit(df[configs.continous_cols].iloc[:100])
configs_dict["normalizer"] = deepcopy(scalar)
configs = TabularDataModuleConfigs(**configs_dict)
scalar_and_encoder = _init_scalar_encoder(df, configs)

# the pre-fitted scalar must NOT be refitted on `df`: it still transforms
# like the scalar fitted on the first 100 rows
assert np.allclose(
    scalar_and_encoder["scalar"].transform(df[configs.continous_cols].iloc[100:200]),
    scalar.transform(df[configs.continous_cols].iloc[100:200])
) 

In [None]:
#| export
class TabularDataModule(BaseDataModule):
    """DataModule for tabular data.

    Validates `data_config`, loads the data (or uses the given DataFrame),
    fits or reuses the continuous scaler and categorical encoder, and builds
    the train/val/test `ArrayDataset`s. After `prepare_data` completes, the
    derived attributes guarded in `__setattr__` can no longer be reassigned.
    """
    cont_scalar = None # scalar for normalizing continuous features
    cat_encoder = None # encoder for encoding categorical features
    # Name-mangled to `_TabularDataModule__initialized`; set to True at the end
    # of `prepare_data` to freeze the derived attributes.
    __initialized = False

    def __init__(
        self, 
        data_config: dict | TabularDataModuleConfigs, # Configurator of `TabularDataModule`
        data: pd.DataFrame = None # Data in `pd.DataFrame`. If `data` is `None`, the DataModule will load data from `data_dir`.
    ):
        self._configs: TabularDataModuleConfigs = validate_configs(
            data_config, TabularDataModuleConfigs
        )
        self._data = _process_data(data, self._configs)
        # init idx lists
        # `transform` puts the continuous features first, so the encoded
        # categorical block starts at this column index.
        self.cat_idx = len(self._configs.continous_cols)
        self._imutable_idx_list = []
        self.prepare_data()

    def prepare_data(self):
        """Fit/attach the scaler and encoder, then build the datasets."""
        # NOTE(review): the scaler/encoder are fitted on all loaded rows,
        # including those that end up in the held-out split.
        scalar_encoder_dict = _init_scalar_encoder(
            data=self._data, configs=self._configs
        )
        self.cont_scalar = scalar_encoder_dict['scalar']
        self.cat_encoder = scalar_encoder_dict['encoder']
        X, y = self.transform(self.data)

        # Per-feature category arrays; their lengths give the width of each
        # encoded categorical group.
        self._cat_arrays = self.cat_encoder.categories_ \
            if self._configs.discret_cols else []

        self._imutable_idx_list = find_imutable_idx_list(
            imutable_col_names=self._configs.imutable_cols,
            discrete_col_names=self._configs.discret_cols,
            continuous_col_names=self._configs.continous_cols,
            cat_arrays=self._cat_arrays,
        )
        
        # prepare train & test (`shuffle=False`: the split follows row order)
        train_test_tuple = train_test_split(X, y, shuffle=False)
        train_X, test_X, train_y, test_y = map(
             lambda x: x.astype(float), train_test_tuple
         )
        # `sample_frac` subsamples only the training split, keeping its first rows.
        if self._configs.sample_frac:
            train_size = int(len(train_X) * self._configs.sample_frac)
            train_X, train_y = train_X[:train_size], train_y[:train_size]
        
        self._train_dataset = ArrayDataset(train_X, train_y)
        self._val_dataset = ArrayDataset(test_X, test_y)
        # The held-out split serves as both the validation and the test set.
        self._test_dataset = self.val_dataset

        self.__initialized = True

    def __setattr__(self, attr: str, val: Any) -> None:
        # Once `prepare_data` has finished, derived state is read-only.
        if self.__initialized and attr in (
            '_data', 'cat_idx', '_imutable_idx_list', '_cat_arrays',
            '_train_dataset', '_val_dataset', '_test_dataset',
            'cont_scalar', 'cat_encoder'
        ):
            raise ValueError(f'{attr} attribute should not be set after '
                             f'{self.__class__.__name__} is initialized')

        super().__setattr__(attr, val)

    @property
    def data_name(self) -> str: 
        return self._configs.data_name
    
    @property
    def data(self) -> pd.DataFrame:
        """Loaded data in `pd.DataFrame`."""
        return self._data
    
    @property
    def train_dataset(self) -> ArrayDataset:
        return self._train_dataset
    
    @property
    def val_dataset(self) -> ArrayDataset:
        return self._val_dataset

    @property
    def test_dataset(self) -> ArrayDataset:
        return self._test_dataset

    def dataset(
        self, name: str # Name of the dataset; should be one of ['train', 'val', 'test'].
    ) -> ArrayDataset:
        if name == 'train': return self._train_dataset
        elif name == 'val': return self._val_dataset
        elif name == 'test': return self._test_dataset
        else: raise ValueError(f"`name` must be one of ['train', 'val', 'test'], but got {name}")

    def train_dataloader(self, batch_size):
        return DataLoader(self.train_dataset, self._configs.backend, 
            batch_size=batch_size, shuffle=True, num_workers=0, drop_last=False
        )

    # NOTE(review): the val/test loaders also shuffle; confirm this is intended.
    def val_dataloader(self, batch_size):
        return DataLoader(self.val_dataset, self._configs.backend,
            batch_size=batch_size, shuffle=True, num_workers=0, drop_last=False
        )

    def test_dataloader(self, batch_size):
        return DataLoader(self.test_dataset, self._configs.backend,
            batch_size=batch_size, shuffle=True, num_workers=0, drop_last=False
        )

    def transform(
        self, 
        data: pd.DataFrame, # Data to be transformed to `numpy.ndarray`
    ) -> Tuple[np.ndarray, np.ndarray]: # Return `(X, y)`
        """Transform data into numerical representations."""
        # TODO: validate `data`
        X_cont = _transform_df(
            self.cont_scalar, data, self._configs.continous_cols
        )
        X_cat = _transform_df(
            self.cat_encoder, data, self._configs.discret_cols
        )
        X = np.concatenate((X_cont, X_cat), axis=1)
        # The last column is the target; the `-1:` slice keeps y 2-D, shape (n, 1).
        y = data.iloc[:, -1:].to_numpy()
        
        return X, y

    def inverse_transform(
        self, 
        x: jnp.DeviceArray, # The transformed input to be scaled back
        y: jnp.DeviceArray = None # The transformed label to be scaled back. If `None`, the target columns will not be scaled back.
    ) -> pd.DataFrame: # Transformed `pd.DataFrame`. 
        """Scaled back into `pd.DataFrame`."""
        X_cont_df = _inverse_transform_np(
            self.cont_scalar, x[:, :self.cat_idx], self._configs.continous_cols
        )
        X_cat_df = _inverse_transform_np(
            self.cat_encoder, x[:, self.cat_idx:], self._configs.discret_cols
        )
        if y is not None:
            y_df = pd.DataFrame(data=y, columns=[self.data.columns[-1]])
        else:
            y_df = None
        
        # `pd.concat` drops the `None` parts.
        return pd.concat(
            [X_cont_df, X_cat_df, y_df], axis=1
        )

    def apply_constraints(
        self, 
        x: jnp.DeviceArray, # input
        cf: jnp.DeviceArray, # Unnormalized counterfactuals
        hard: bool = False # Apply hard constraints or not
    ) -> jnp.DeviceArray:
        """Apply categorical normalization and immutability constraints"""
        cf = cat_normalize(
            cf, cat_arrays=self._cat_arrays, 
            cat_idx=len(self._configs.continous_cols),
            hard=hard
        )
        # apply immutable constraints: copy immutable columns from `x` into `cf`
        if len(self._configs.imutable_cols) > 0:
            cf = cf.at[:, self._imutable_idx_list].set(x[:, self._imutable_idx_list])
        return cf

    def apply_regularization(
        self, 
        x: jnp.DeviceArray, # Input
        cf: jnp.DeviceArray, # Unnormalized counterfactuals
    ) -> float: # Return regularization loss
        """Apply categorical constraints by adding regularization terms.

        For each encoded categorical feature, penalize the squared deviation
        of the sum of its columns from 1. `cf` may be a single counterfactual
        (1-D) or a batch with features on the last axis; for a batch the
        per-sample penalties are summed.
        """
        reg_loss = 0.
        cat_idx = len(self._configs.continous_cols)

        for col in self._cat_arrays:
            cat_idx_end = cat_idx + len(col)
            # Slice the feature (last) axis so batched `cf` is handled correctly.
            group_sum = jnp.sum(cf[..., cat_idx:cat_idx_end], axis=-1)
            reg_loss += jnp.sum(jnp.power(group_sum - 1.0, 2))
            # Advance to the next categorical group.
            cat_idx = cat_idx_end
        return reg_loss


To load `TabularDataModule` from `TabularDataModuleConfigs`,

In [None]:
# `normalizer` and `encoder` are omitted, so the defaults from
# `TabularDataModuleConfigs` are used.
configs = TabularDataModuleConfigs(
    data_name='adult',
    data_dir='assets/data/s_adult.csv',
    continous_cols=['age', 'hours_per_week'],
    discret_cols=['workclass', 'education', 'marital_status', 'occupation'],
    imutable_cols=['age', 'workclass', 'marital_status'],
    sample_frac=0.1
)

dm = TabularDataModule(configs)

We can also explicitly pass a `pd.DataFrame` to `TabularDataModule`. 
In this case, `TabularDataModule` will use the passed `pd.DataFrame`, instead of loading data from `data_dir` in `TabularDataModuleConfigs`. 

In [None]:
# NOTE(review): `configs` is reused from the previous cell; if `validate_configs`
# returns the same object, its scaler/encoder were already fitted there and
# will not be refit on this `df` — confirm that is acceptable.
df = pd.read_csv('assets/data/s_adult.csv')[:1000]
dm = TabularDataModule(configs, data=df)
assert len(dm.data) == 1000 # dm contains `df`

In [None]:
# Render the API docs for the `data` property.
show_doc(TabularDataModule.data)

---

[source](https://github.com/birkhoffg/relax/blob/master/relax/data/module.py#L340){target="_blank" style="float:right; font-size:smaller"}

### TABULARDATAMODULE.DATA

::: {.doc-sig}

 relax.data.module.<b>TabularDataModule.data</b> <em>()</em>

:::

Loaded data in `pd.DataFrame`.

`TabularDataModule` either loads a csv file (specified by `data_dir` in `data_config`) 
or uses a DataFrame passed in directly (as `data`). 
Either way, the data must satisfy the following conditions:

* It requires the **target column** (i.e., the labels) to be the **last** column of the DataFrame, 
and the remaining columns are **features**. 
This **target column** needs to be binary-valued (i.e., it is either `0` or `1`).
    * In the example below, **income** is the **target column**, and the remaining columns are features.
* It requires `continous_cols` and `discret_cols` in `data_config` to be subsets of `data.columns`.
* It only uses the columns specified in `continous_cols` and `discret_cols`.
    * It loads `continous_cols` first, then `discret_cols`.


In [None]:
# Preview the first rows of the loaded, column-filtered data (target last).
dm.data.head()

Unnamed: 0,age,hours_per_week,workclass,education,marital_status,occupation,income
0,42.0,45.0,Private,HS-grad,Married,Blue-Collar,1
1,32.0,40.0,Self-Employed,Some-college,Married,Blue-Collar,0
2,35.0,40.0,Private,Assoc,Single,White-Collar,1
3,36.0,40.0,Private,HS-grad,Single,Blue-Collar,0
4,57.0,35.0,Private,School,Married,Service,0


In [None]:
# Render the API docs for `transform`.
show_doc(TabularDataModule.transform)

---

[source](https://github.com/birkhoffg/relax/blob/master/relax/data/module.py#L379){target="_blank" style="float:right; font-size:smaller"}

### TABULARDATAMODULE.TRANSFORM

::: {.doc-sig}

 relax.data.module.<b>TabularDataModule.transform</b> <em>(data)</em>

:::

Transform data into numerical representations.

::: {#docs .callout-note icon=false}

## Parameters:

* <b>data</b> (`pd.DataFrame`) -- Data to be transformed to `numpy.ndarray`


:::


::: {#docs .callout-note icon=false}




## Returns:

&ensp;&ensp;&ensp;&ensp;(`Tuple[np.ndarray, np.ndarray]`) -- Return `(X, y)`


:::

By default, we transform *continuous features* via `MinMaxScaler`, 
and *discrete features* via `OneHotEncoder`. 

A tabular data point $x$ is encoded as 
$$x = [\underbrace{x_{0}, x_{1}, ..., x_{m}}_{\text{cont features}}, 
\underbrace{x_{m+1}^{c=1},..., x_{m+p}^{c=1}}_{\text{cat feature} (1)}, ..., 
\underbrace{x_{k-q}^{c=i},..., x_{k}^{c=i}}_{\text{cat feature} (i)}]$$

In [None]:
# `transform` returns numpy arrays; `y` is the last column of the frame,
# kept 2-D with shape (n, 1).
df = dm.data.head()
X, y = dm.transform(df)

assert isinstance(X, np.ndarray)
assert isinstance(y, np.ndarray)
assert y.shape == (len(X), 1)

In [None]:
# Render the API docs for `inverse_transform`.
show_doc(TabularDataModule.inverse_transform)

---

[source](https://github.com/birkhoffg/relax/blob/master/relax/data/module.py#L396){target="_blank" style="float:right; font-size:smaller"}

### TABULARDATAMODULE.INVERSE_TRANSFORM

::: {.doc-sig}

 relax.data.module.<b>TabularDataModule.inverse_transform</b> <em>(x, y=None)</em>

:::

Scaled back into `pd.DataFrame`.

::: {#docs .callout-note icon=false}

## Parameters:

* <b>x</b> (`jnp.DeviceArray`) -- The transformed input to be scaled back
* <b>y</b> (`jnp.DeviceArray`, <em>default=None</em>) -- The transformed label to be scaled back. If `None`, the target columns will not be scaled back.


:::


::: {#docs .callout-note icon=false}




## Returns:

&ensp;&ensp;&ensp;&ensp;(`pd.DataFrame`) -- Transformed `pd.DataFrame`.


:::

`TabularDataModule.inverse_transform` scales numerical representations back
to the original DataFrame.

In [None]:
# Features and target back to the original DataFrame layout.
dm.inverse_transform(X, y)

Unnamed: 0,age,hours_per_week,workclass,education,marital_status,occupation,income
0,42.0,45.0,Private,HS-grad,Married,Blue-Collar,1
1,32.0,40.0,Self-Employed,Some-college,Married,Blue-Collar,0
2,35.0,40.0,Private,Assoc,Single,White-Collar,1
3,36.0,40.0,Private,HS-grad,Single,Blue-Collar,0
4,57.0,35.0,Private,School,Married,Service,0


If `y` is not passed, it will only scale back `X`.

In [None]:
# Without `y`, only the feature columns are returned.
dm.inverse_transform(X)

Unnamed: 0,age,hours_per_week,workclass,education,marital_status,occupation
0,42.0,45.0,Private,HS-grad,Married,Blue-Collar
1,32.0,40.0,Self-Employed,Some-college,Married,Blue-Collar
2,35.0,40.0,Private,Assoc,Single,White-Collar
3,36.0,40.0,Private,HS-grad,Single,Blue-Collar
4,57.0,35.0,Private,School,Married,Service


In [None]:
# Render the API docs for `apply_constraints`.
show_doc(TabularDataModule.apply_constraints)

---

[source](https://github.com/birkhoffg/relax/blob/master/relax/data/module.py#L417){target="_blank" style="float:right; font-size:smaller"}

### TABULARDATAMODULE.APPLY_CONSTRAINTS

::: {.doc-sig}

 relax.data.module.<b>TabularDataModule.apply_constraints</b> <em>(x, cf, hard=False)</em>

:::

Apply categorical normalization and immutability constraints

::: {#docs .callout-note icon=false}

## Parameters:

* <b>x</b> (`jnp.DeviceArray`) -- input
* <b>cf</b> (`jnp.DeviceArray`) -- Unnormalized counterfactuals
* <b>hard</b> (`bool`, <em>default=False</em>) -- Apply hard constraints or not


:::


::: {#docs .callout-note icon=false}




## Returns:

&ensp;&ensp;&ensp;&ensp;(`jnp.DeviceArray`)


:::

`TabularDataModule.apply_constraints` does two things: 

1. It ensures that generated counterfactuals respect the one-hot encoding format (i.e., $\sum_{j=p}^{q} x^{c=i}_{j} = 1$, where $x^{c=i}_{p}, \dots, x^{c=i}_{q}$ are the one-hot columns of categorical feature $i$).
2. It ensures the immutability constraints (i.e., immutable features defined in `imutable_cols` will not be changed).

In [None]:
# Take one test batch and random (unconstrained) counterfactuals of the same shape.
x, y = next(iter(dm.test_dataloader(batch_size=128)))
# unnormalized counterfactuals
cf = random.normal(
    random.PRNGKey(0), x.shape
)
# normalized counterfactuals
cf_normed = dm.apply_constraints(x, cf)

In [None]:
# Render the API docs for `apply_regularization`.
show_doc(TabularDataModule.apply_regularization)

---

[source](https://github.com/birkhoffg/relax/blob/master/relax/data/module.py#L434){target="_blank" style="float:right; font-size:smaller"}

### TABULARDATAMODULE.APPLY_REGULARIZATION

::: {.doc-sig}

 relax.data.module.<b>TabularDataModule.apply_regularization</b> <em>(x, cf)</em>

:::

Apply categorical constraints by adding regularization terms

::: {#docs .callout-note icon=false}

## Parameters:

* <b>x</b> (`jnp.DeviceArray`) -- Input
* <b>cf</b> (`jnp.DeviceArray`) -- Unnormalized counterfactuals


:::


::: {#docs .callout-note icon=false}




## Returns:

&ensp;&ensp;&ensp;&ensp;(`float`) -- Return regularization loss


:::

In [None]:
# Take one test batch and random (unconstrained) counterfactuals of the same shape.
x, y = next(iter(dm.test_dataloader(batch_size=128)))
# unnormalized counterfactuals
cf = random.normal(
    random.PRNGKey(0), x.shape
)
# regularization loss penalizing categorical groups whose columns do not sum to 1
reg_loss = dm.apply_regularization(x, cf)

In [None]:
#| export
def sample(datamodule: BaseDataModule, frac: float = 1.0):
    """Return the first `frac` share of the training set as `(X, y)`.

    The cut-off is `int(len(X) * frac)` rows, taken from the front (no
    random sampling).
    """
    features, labels = datamodule.train_dataset[:]
    n_keep = int(len(features) * frac)
    head = slice(None, n_keep)
    return features[head], labels[head]

In [None]:
#| hide
def check_datamodule(dm: TabularDataModule, data_configs: dict | TabularDataModuleConfigs):
    """Sanity-check a `TabularDataModule` built from `data_configs`.

    Checks that:
    - dataset features and labels have matching lengths (labels shaped (n, 1)).
    - each dataloader's first batch has 32 rows.
    - `inverse_transform` followed by `transform` reproduces the training
      features and labels.
    - after soft `apply_constraints` the categorical block sums to
      `len(cf) * n_cat_feat`, and after hard constraints the number of
      entries equal to 1 is `len(cf) * n_cat_feat`.
    """
    batch_size = 32
    data_configs = validate_configs(data_configs, TabularDataModuleConfigs)
    # Index of the first categorical column and number of categorical features.
    cat_idx = len(data_configs.continous_cols)
    n_cat_feat = len(data_configs.discret_cols)

    feats, label = dm.train_dataset[:]
    assert feats.shape[0] == len(label)
    assert label.shape == (feats.shape[0], 1)

    X, y = dm.val_dataset[:]
    assert X.shape[0] == len(y)

    X, y = dm.test_dataset[:]
    assert X.shape[0] == len(y)

    dl = dm.train_dataloader(batch_size)
    x, y = next(iter(dl))
    assert x.shape[0] == batch_size

    dl = dm.val_dataloader(batch_size)
    x, y = next(iter(dl))
    assert x.shape[0] == batch_size

    dl = dm.test_dataloader(batch_size)
    x, y = next(iter(dl))
    assert x.shape[0] == batch_size

    ############################################################
    # test `transform` and `inverse_transform`
    ############################################################
    df = dm.inverse_transform(feats, label)
    assert len(df) == len(feats)
    # one column per continuous feature, per categorical feature, plus the target
    assert len(df.columns) == cat_idx + n_cat_feat + 1

    X_transformed, y = dm.transform(df)
    assert np.allclose(feats, X_transformed)
    assert np.allclose(y, label)
    
    ############################################################
    # test `apply_constraints` (soft, then hard)
    ##########################################################
    dl = dm.test_dataloader(batch_size)
    x, y = next(iter(dl))
    cf = random.normal(
        random.PRNGKey(0), x.shape
    )
    cf = dm.apply_constraints(x, cf, hard=False)
    assert jnp.allclose(jnp.sum(cf[:, cat_idx:]), len(cf) * n_cat_feat)

    cf = dm.apply_constraints(x, cf, hard=True)
    assert jnp.count_nonzero(cf == 1) == len(cf) * n_cat_feat


In [None]:
#| hide
# Base config: adult data with two continuous and six categorical features.
data_configs = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    'sample_frac': 0.1,
    "continous_cols": ["age", "hours_per_week"],
    "discret_cols": [
        "workclass", "education", "marital_status",
        "occupation", "race", "gender"
    ],
}
dm = TabularDataModule(data_configs)
check_datamodule(dm, data_configs)

# immutable: race and gender are frozen
_data_configs = deepcopy(data_configs)
_data_configs["imutable_cols"] = ["race","gender"]
dm = TabularDataModule(_data_configs)
check_datamodule(dm, _data_configs)

# no cont: only categorical features
_data_configs = deepcopy(data_configs)
_data_configs['continous_cols'] = []
dm = TabularDataModule(_data_configs)
check_datamodule(dm, _data_configs)

# no cat: only continuous features
_data_configs = deepcopy(data_configs)
_data_configs['discret_cols'] = []
dm = TabularDataModule(_data_configs)
check_datamodule(dm, _data_configs)

# regression test for https://github.com/BirkhoffG/ReLax/issues/134
data_configs = load_json('assets/configs/data_configs/breast_cancer.json')['data_configs']
dm = TabularDataModule(data_configs)
check_datamodule(dm, data_configs)

## Load Data


In [None]:
#| exporti
# Relative asset locations (data csv, config json, model dir) of every
# built-in dataset; each follows the same `assets/<name>/...` pattern.
DEFAULT_DATA_CONFIGS = {
    name: {
        'data': f'assets/{name}/data.csv',
        'conf': f'assets/{name}/configs.json',
        'model': f'assets/{name}/model',
    }
    for name in (
        'adult', 'heloc', 'oulad', 'credit', 'cancer', 'student_performance',
        'titanic', 'german', 'spam', 'ozone', 'qsar', 'bioresponse', 'churn',
        'road',
    )
}

In [None]:
#| exporti
def _validate_dataname(data_name: str):
    """Raise `ValueError` unless `data_name` is one of the keys of `DEFAULT_DATA_CONFIGS`."""
    if data_name in DEFAULT_DATA_CONFIGS:
        return
    raise ValueError(
        f'`data_name` must be one of {DEFAULT_DATA_CONFIGS.keys()}, '
        f'but got data_name={data_name}.'
    )

In [None]:
#| export
def _urlretrieve_atomic(url: str, dest: Path):
    """Download `url` into `dest` without ever leaving a partial file at `dest`.

    The payload is first written to a sibling ``<name>.part`` file and only
    moved into place once `urlretrieve` has returned successfully. An
    interrupted download therefore cannot be mistaken for a cached file by
    the `is_file()` checks in `load_data` on a later call.
    """
    dest = Path(dest)
    tmp_path = dest.with_name(dest.name + '.part')
    try:
        urlretrieve(url, tmp_path)
        os.replace(tmp_path, dest)
    finally:
        # The temp file only survives to here if the download or rename failed.
        if tmp_path.exists():
            tmp_path.unlink()


def load_data(
    data_name: str, # The name of data
    return_config: bool = False, # Return `data_configs` or not
    data_configs: dict = None # Data configs to override default configuration
) -> TabularDataModule | Tuple[TabularDataModule, TabularDataModuleConfigs]: 
    """High-level util function for loading `data` and `data_config`.

    The data csv, its `configs.json`, and the pretrained model files
    (`params.npy`, `tree.pkl`) are downloaded from the ReLax GitHub repository
    into `./cf_data/<data_name>/` on first use and reused afterwards.
    Entries in `data_configs` override the downloaded `data_configs`.

    Returns the `TabularDataModule`, or `(data_module, config)` when
    `return_config=True`. Raises `ValueError` for an unknown `data_name`.
    """
    _validate_dataname(data_name)

    # Remote locations of the data, config and pretrained model files.
    default_paths = DEFAULT_DATA_CONFIGS[data_name]
    base_url = "https://github.com/BirkhoffG/ReLax/raw/master"
    data_url = f"{base_url}/{default_paths['data']}"
    conf_url = f"{base_url}/{default_paths['conf']}"
    model_params_url = f"{base_url}/{default_paths['model']}/params.npy"
    model_tree_url = f"{base_url}/{default_paths['model']}/tree.pkl"

    # Local cache layout: ./cf_data/<data_name>/{data.csv, configs.json, model/}
    data_dir = Path(os.getcwd()) / "cf_data"
    data_path = data_dir / data_name / 'data.csv'
    conf_path = data_dir / data_name / 'configs.json'
    model_path = data_dir / data_name / "model"
    # Also creates cf_data/ and cf_data/<data_name>/; no-op if they exist.
    model_path.mkdir(parents=True, exist_ok=True)

    # Download only what is not cached yet (atomically, see helper above).
    if not data_path.is_file():
        _urlretrieve_atomic(data_url, data_path)
    if not conf_path.is_file():
        _urlretrieve_atomic(conf_url, conf_path)
    params_path = model_path / "params.npy"
    tree_path = model_path / "tree.pkl"
    # params and tree are used as a pair: refetch both if either is missing.
    if not (params_path.is_file() and tree_path.is_file()):
        _urlretrieve_atomic(model_params_url, params_path)
        _urlretrieve_atomic(model_tree_url, tree_path)

    # Read the downloaded config, point it at the local csv, apply overrides.
    config = load_json(conf_path)['data_configs']
    config['data_dir'] = str(data_path)
    if data_configs is not None:
        config.update(data_configs)

    config = TabularDataModuleConfigs(**config)
    data_module = TabularDataModule(config)

    if return_config:
        return data_module, config
    return data_module


`load_data` loads one of the bundled example datasets, given its `data_name`. 
For example, you can load the [adult](https://archive.ics.uci.edu/ml/datasets/adult) dataset as:

In [None]:
# Downloads (on first use) and caches the dataset under ./cf_data/adult/.
dm = load_data(data_name='adult')

Under the hood, `load_data` loads the default `data_configs`. To access these `data_configs`, pass `return_config=True`:

In [None]:
# Also return the `TabularDataModuleConfigs` that was used to build `dm`.
dm, data_configs = load_data(data_name='adult', return_config=True)

If you want to override some of the data configs, 
you can pass them as an auxiliary argument via `data_configs`. 
For example, if you want to use only 10% of the data, you can write:

In [None]:
# Override one default config entry: use only a 10% sample of the data.
dm = load_data(data_name='adult', data_configs=dict(sample_frac=0.1))

#### Supported Datasets

`load_data` currently supports the following datasets:

In [None]:
#|echo: false
def display_data_attrbutes(names: list):
    """Tabulate basic statistics for each dataset in `names`.

    One row per dataset, with the number of continuous features, the number
    of categorical features, and the number of rows. Each dataset is loaded
    via `load_data` and validated with `check_datamodule` along the way.
    """
    column_names = ('# Cont Features', '# Cat Features', '# of Data Points')
    attrs = {column: dict.fromkeys(names, 0) for column in column_names}
    n_cont, n_cat, n_rows = (attrs[column] for column in column_names)

    for name in names:
        module, cfg = load_data(name, return_config=True)
        n_cont[name] = len(cfg.continous_cols)
        n_cat[name] = len(cfg.discret_cols)
        n_rows[name] = len(module.data)
        # Sanity-check every dataset while we have it loaded.
        check_datamodule(module, cfg)

    return pd.DataFrame.from_dict(attrs)

display_data_attrbutes(list(DEFAULT_DATA_CONFIGS))

Unnamed: 0,# Cont Features,# Cat Features,# of Data Points
adult,2,6,32561
heloc,21,2,10459
oulad,23,8,32593
credit,20,3,30000
cancer,30,0,569
student_performance,2,14,649
titanic,2,24,891
german,7,13,1000
spam,57,0,4601
ozone,72,0,2534


In [None]:
# Names accepted as `data_name` by `load_data` (checked by `_validate_dataname`).
DEFAULT_DATA_CONFIGS.keys()

dict_keys(['adult', 'heloc', 'oulad', 'credit', 'cancer', 'student_performance', 'titanic', 'german', 'spam', 'ozone', 'qsar', 'bioresponse', 'churn', 'road'])

In [None]:
#| hide
# Every bundled dataset must honour a `sample_frac` override passed through
# `data_configs` and still pass `check_datamodule`.
for data_name in DEFAULT_DATA_CONFIGS:
    dm, config = load_data(
        data_name, data_configs={'sample_frac': 0.1}, return_config=True
    )
    assert config.sample_frac == 0.1
    check_datamodule(dm, config)