# Data Module

> `DataModule` for training parametric models, generating counterfactual (CF) explanations, and benchmarking.

In [None]:
#| default_exp data.module

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import annotations
from cfnet.import_essentials import *
from cfnet.utils import load_json, validate_configs, cat_normalize
from cfnet.data.loader import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
from sklearn.base import TransformerMixin
from urllib.request import urlretrieve

In [None]:
#| export
class BaseDataModule(ABC):
    """DataModule Interface.

    A data module owns a dataset, exposes its train/validation/test splits,
    maps between raw data and the numeric feature space
    (`transform` / `inverse_transform`), and adjusts counterfactuals to the
    feature constraints (`apply_constraints` / `project`).
    """

    @property
    @abstractmethod
    def data_name(self) -> str:
        """Name of the dataset."""

    @property
    @abstractmethod
    def data(self) -> Any:
        """The underlying raw data."""

    @property
    @abstractmethod
    def train_dataset(self) -> Dataset:
        """Training split."""

    @property
    @abstractmethod
    def val_dataset(self) -> Dataset:
        """Validation split."""

    @property
    @abstractmethod
    def test_dataset(self) -> Dataset:
        """Test split."""

    def train_dataloader(self, batch_size):
        """Loader over `train_dataset`; optional, not implemented by default."""
        raise NotImplementedError

    def val_dataloader(self, batch_size):
        """Loader over `val_dataset`; optional, not implemented by default."""
        raise NotImplementedError

    def test_dataloader(self, batch_size):
        """Loader over `test_dataset`; optional, not implemented by default."""
        raise NotImplementedError

    @abstractmethod
    def prepare_data(self) -> Any:
        """Preprocess the data and build the dataset splits."""
        raise NotImplementedError

    @abstractmethod
    def transform(self, data) -> jnp.DeviceArray:
        """Map raw `data` into the numeric feature space."""
        raise NotImplementedError

    @abstractmethod
    def inverse_transform(self, x: jnp.DeviceArray) -> Any:
        """Map features `x` back to the raw data representation."""
        raise NotImplementedError

    @abstractmethod
    def apply_constraints(self, x: jnp.DeviceArray, hard: bool = False) -> jnp.DeviceArray:
        """Return `x` adjusted to satisfy the feature constraints.

        `hard` is an optional switch that implementations may use for strict
        enforcement (e.g. `TabularDataModule` forwards it to `cat_normalize`).
        The base default ignores it and returns `x` unchanged.
        """
        return x

    def project(self, x: jnp.DeviceArray, cf: jnp.DeviceArray) -> jnp.DeviceArray:
        """Project counterfactual `cf` of input `x` onto the feasible set.

        Not every data module supports this; the default raises
        NotImplementedError. `TabularDataModule` copies the immutable
        features of `x` into `cf`.
        """
        raise NotImplementedError
    
    

## Tabular Data Module

`DataModule` for processing tabular data.


#### Process Data

In [None]:
#| export
def find_imutable_idx_list(
    imutable_col_names: List[str],
    discrete_col_names: List[str],
    continuous_col_names: List[str],
    cat_arrays: List[List[str]],
) -> List[int]:
    """Return the encoded-feature indices that belong to immutable columns.

    The encoded feature vector is laid out as the continuous columns (one
    index each) followed by one block per discrete column, where the block of
    the i-th discrete column spans ``len(cat_arrays[i])`` indices.

    Args:
        imutable_col_names: Names of the columns that must not change.
        discrete_col_names: Discrete column names, in encoding order.
        continuous_col_names: Continuous column names, in encoding order.
        cat_arrays: Categories of each discrete column (e.g. a fitted
            encoder's ``categories_``), aligned with ``discrete_col_names``;
            ``zip`` stops at the shorter of the two.

    Returns:
        Increasing list of indices covering every immutable column.
    """
    immutable = set(imutable_col_names)  # O(1) membership per column
    idx_list = [
        idx for idx, col_name in enumerate(continuous_col_names)
        if col_name in immutable
    ]

    # discrete blocks start right after the continuous features
    start = len(continuous_col_names)
    for col_name, categories in zip(discrete_col_names, cat_arrays):
        end = start + len(categories)
        if col_name in immutable:
            idx_list.extend(range(start, end))
        start = end
    return idx_list

In [None]:
#| exporti
def _check_cols(data: pd.DataFrame, configs: TabularDataModuleConfigs) -> pd.DataFrame:
    """Cast continuous columns to float and sanity-check the column config.

    The last column of `data` is treated as the target. An AssertionError is
    raised if the target is listed among the feature columns, or if an
    immutable column is not one of the configured feature columns.
    Returns the (cast) DataFrame.
    """
    data = data.astype(dict.fromkeys(configs.continous_cols, float))

    feature_cols = configs.continous_cols + configs.discret_cols

    # the label lives in the last column and must not double as a feature
    label_col = data.columns[-1]
    assert label_col not in feature_cols, \
        f"continous_cols or discret_cols contains target_col={label_col}."

    # every immutable column has to be a declared feature
    for name in configs.imutable_cols:
        assert name in feature_cols, \
            f"imutable_cols=[{name}] is not specified in `continous_cols` or `discret_cols`."
    return data


In [None]:
#| exporti
def _process_data(
    df: pd.DataFrame | None, configs: TabularDataModuleConfigs
) -> pd.DataFrame:
    """Obtain the raw table and validate it against `configs`.

    When `df` is None the table is read from `configs.data_dir`; otherwise
    `df` must already be a DataFrame, and any other type raises ValueError.
    """
    if df is None:
        df = pd.read_csv(configs.data_dir)
    elif not isinstance(df, pd.DataFrame):
        raise ValueError(f"{type(df).__name__} is not supported as an input type for `TabularDataModule`.")
    return _check_cols(df, configs)

In [None]:
#| export
def _transform_df(
    transformer: TransformerMixin,
    data: pd.DataFrame,
    cols: List[str] | None,
):
    """Apply a fitted `transformer` to ``data[cols]``.

    When `cols` is empty or None the transformer is not called and an empty
    float array of shape ``(len(data), 0)`` is returned. This keeps the
    result column-concatenable with other feature blocks.
    """
    if not cols:
        # `np.empty((n, 0))` stays 2-D even for zero rows, whereas
        # `np.array([[] for _ in range(0)])` collapses to shape (0,).
        return np.empty((len(data), 0))
    return transformer.transform(data[cols])

In [None]:
#| hide
# test
df = pd.read_csv('assets/data/s_adult.csv')

# non-empty columns: a fitted scaler yields one output column per input column
num_cols = ['age', 'hours_per_week']
scaler = MinMaxScaler().fit(df[num_cols])
assert _transform_df(scaler, df, num_cols).shape == (len(df), len(num_cols))

# no columns: the (unfitted) scaler is never called; an (n, 0) array comes back
assert _transform_df(MinMaxScaler(), df, []).shape == (len(df), 0)

In [None]:
#| export
def _inverse_transform_np(
    transformer: TransformerMixin,
    x: jnp.DeviceArray,
    cols: List[str] | None
):
    """Invert `transformer` on the encoded block `x` and label it with `cols`.

    Returns a DataFrame whose columns are `cols`. Returns None when `cols` is
    empty or None; the transformer is not called in that case.
    """
    # check emptiness first: `len(None)` would raise before the check below
    if not cols:
        return None
    assert len(cols) <= x.shape[-1], \
        f"x.shape={x.shape} probably will not match len(cols)={len(cols)}"
    data = transformer.inverse_transform(x)
    return pd.DataFrame(data=data, columns=cols)


In [None]:
#| hide
# test
df = pd.read_csv('assets/data/s_adult.csv')

# round trip through a fitted MinMaxScaler recovers the original values
num_cols = ['age', 'hours_per_week']
scaler = MinMaxScaler().fit(df[num_cols])
encoded = _transform_df(scaler, df, num_cols)
decoded = _inverse_transform_np(scaler, encoded, num_cols)
assert encoded.shape == (len(df), len(num_cols))
assert np.allclose(df[num_cols].values, decoded.values)

# no columns: transform yields an (n, 0) array and the inverse yields None
encoded = _transform_df(MinMaxScaler(), df, [])
assert encoded.shape == (len(df), 0)
assert _inverse_transform_np(MinMaxScaler(), encoded, []) is None

In [None]:
#| hide
# test
df = pd.read_csv('assets/data/s_adult.csv')

# round trip through a fitted OneHotEncoder recovers the original categories
cat_cols = ['workclass', 'education']
onehot = OneHotEncoder().fit(df[cat_cols])
encoded = _transform_df(onehot, df, cat_cols)
decoded = _inverse_transform_np(onehot, encoded, cat_cols)
assert encoded.shape[0] == len(df)
assert df[cat_cols].equals(decoded)

# no columns: transform yields an (n, 0) array and the inverse yields None
encoded = _transform_df(OneHotEncoder(), df, [])
assert encoded.shape == (len(df), 0)
assert _inverse_transform_np(OneHotEncoder(), encoded, []) is None

In [None]:
#| exporti
def _make_dense_onehot_encoder():
    """Build a `OneHotEncoder` that returns dense arrays on any scikit-learn version."""
    try:
        # scikit-learn >= 1.2 names the option `sparse_output`
        return OneHotEncoder(sparse_output=False)
    except TypeError:
        # older scikit-learn only knows `sparse` (removed in 1.4)
        return OneHotEncoder(sparse=False)


def _init_scalar_encoder(
    data: pd.DataFrame,
    configs: TabularDataModuleConfigs
):
    """Set up the continuous scaler and categorical encoder, and transform `data`.

    A user-supplied `configs.normalizer` / `configs.encoder` is used as-is.
    There is no `fit` call for it here, so it is expected to be fitted already.
    Otherwise, a `MinMaxScaler` is fitted on `configs.continous_cols` and a
    dense `OneHotEncoder` on `configs.discret_cols`. Each is fitted only if
    those columns are non-empty; if not, it stays unfitted and is never called.

    Returns:
        dict with the transformed blocks `X_cont` and `X_cat`, plus the
        `scalar` and `encoder` objects that produced them.
    """
    # fit scalar
    if configs.normalizer:
        scalar = configs.normalizer
    else:
        scalar = MinMaxScaler()
        if configs.continous_cols:
            scalar.fit(data[configs.continous_cols])

    X_cont = _transform_df(
        scalar, data, configs.continous_cols
    )

    # fit encoder
    if configs.encoder:
        encoder = configs.encoder
    else:
        encoder = _make_dense_onehot_encoder()
        if configs.discret_cols:
            encoder.fit(data[configs.discret_cols])

    X_cat = _transform_df(
        encoder, data, configs.discret_cols
    )
    return dict(
        X_cont=X_cont, X_cat=X_cat, scalar=scalar, encoder=encoder
    )


#### Module

In [None]:
#| export
class TabularDataModuleConfigs(BaseParser):
    """Config of `TabularDataModule`.

    The last column of the data is the target and must not be listed in
    `continous_cols` or `discret_cols`.

    NOTE(review): the list defaults below are shared literals. They are safe only
    if `BaseParser` copies defaults per instance (as pydantic models do).
    Confirm this.
    """

    data_dir: str  # CSV path read with `pd.read_csv` when no DataFrame is passed in
    data_name: str  # dataset name exposed as `TabularDataModule.data_name`
    discret_cols: List[str] = []  # categorical feature columns, encoded by `encoder` (one-hot by default)
    continous_cols: List[str] = []  # numeric feature columns, scaled by `normalizer` (min-max by default)
    imutable_cols: List[str] = []  # feature columns that `project` keeps fixed; each must be in continous_cols or discret_cols
    normalizer: Optional[Any] = None  # used as-is (not fitted here) for continous_cols; None -> fit a MinMaxScaler
    encoder: Optional[Any] = None  # used as-is (not fitted here) for discret_cols; None -> fit a OneHotEncoder
    sample_frac: Optional[float] = None  # if set, keep only this leading fraction of the training split
    backend: str = 'jax'  # forwarded to `DataLoader`

In [None]:
#| export
class TabularDataModule(BaseDataModule):
    """DataModule for tabular data.

    Continuous columns are scaled and discrete columns are encoded. The two
    blocks are concatenated, continuous first, into the feature matrix, and
    the last column of the data is used as the target. Train/test are split
    without shuffling, and the validation and test datasets are the same
    held-out split.
    """
    cont_scalar = None # scalar for normalizing continuous features
    cat_encoder = None # encoder for encoding categorical features

    def __init__(
        self, 
        data_config: dict | TabularDataModuleConfigs, # Configurator of `TabularDataModule`
        df: pd.DataFrame | None = None # Dataframe which overrides `data_dir` in `data_config` (if not None)
    ):
        self._configs: TabularDataModuleConfigs = validate_configs(
            data_config, TabularDataModuleConfigs
        )
        self._data = _process_data(df, self._configs)
        # continuous features occupy columns [0, cat_idx); categorical blocks follow
        self.cat_idx = len(self._configs.continous_cols)
        self._imutable_idx_list = []
        self.prepare_data()

    def _cat_arrays(self) -> list:
        """Categories of each discrete column, or [] when there are none.

        Without discrete columns the default encoder is never fitted, so it
        has no `categories_` attribute.
        """
        if not self._configs.discret_cols:
            return []
        return self.cat_encoder.categories_

    def prepare_data(self):
        """Fit the scaler/encoder, build the feature matrix, and split train/test."""
        scalar_encoder_dict = _init_scalar_encoder(
            data=self._data, configs=self._configs
        )
        self.cont_scalar = scalar_encoder_dict['scalar']
        self.cat_encoder = scalar_encoder_dict['encoder']
        X = np.concatenate(
            (scalar_encoder_dict['X_cont'], scalar_encoder_dict['X_cat']),
            axis=1
        )
        y = self._data.iloc[:, -1:] # last column is the target column

        self._imutable_idx_list = find_imutable_idx_list(
            imutable_col_names=self._configs.imutable_cols,
            discrete_col_names=self._configs.discret_cols,
            continuous_col_names=self._configs.continous_cols,
            cat_arrays=self._cat_arrays(),
        )

        # prepare train & test (sequential split, no shuffling)
        train_X, test_X, train_y, test_y = (
            arr.astype(float)
            for arr in train_test_split(X, y.to_numpy(), shuffle=False)
        )
        if self._configs.sample_frac:
            # keep only the leading fraction of the training split
            train_size = int(len(train_X) * self._configs.sample_frac)
            train_X, train_y = train_X[:train_size], train_y[:train_size]

        self._train_dataset = Dataset(train_X, train_y)
        self._val_dataset = Dataset(test_X, test_y)
        # validation and test share the same held-out split
        self._test_dataset = self._val_dataset

    @property
    def data_name(self) -> str: 
        return self._configs.data_name

    @property
    def data(self) -> Any:
        return self._data

    @property
    def train_dataset(self) -> Dataset:
        return self._train_dataset

    @property
    def val_dataset(self) -> Dataset:
        return self._val_dataset

    @property
    def test_dataset(self) -> Dataset:
        return self._test_dataset

    # NOTE(review): the val/test loaders also shuffle; harmless for metrics
    # averaged over the whole split, but confirm this is intended.
    def train_dataloader(self, batch_size):
        return DataLoader(self.train_dataset, self._configs.backend, 
            batch_size=batch_size, shuffle=True, num_workers=0, drop_last=False
        )

    def val_dataloader(self, batch_size):
        return DataLoader(self.val_dataset, self._configs.backend,
            batch_size=batch_size, shuffle=True, num_workers=0, drop_last=False
        )

    def test_dataloader(self, batch_size):
        return DataLoader(self.test_dataset, self._configs.backend,
            batch_size=batch_size, shuffle=True, num_workers=0, drop_last=False
        )

    def transform(self, data: pd.DataFrame) -> np.ndarray:
        """Encode the raw DataFrame `data` into the feature matrix (continuous block first)."""
        # TODO: validate `data`
        X_cont = _transform_df(
            self.cont_scalar, data, self._configs.continous_cols
        )
        X_cat = _transform_df(
            self.cat_encoder, data, self._configs.discret_cols
        )
        return np.concatenate((X_cont, X_cat), axis=1)

    def inverse_transform(self, x: jnp.DeviceArray) -> pd.DataFrame:
        """Decode the feature matrix `x` back into a DataFrame of the feature columns."""
        X_cont_df = _inverse_transform_np(
            self.cont_scalar, x[:, :self.cat_idx], self._configs.continous_cols
        )
        X_cat_df = _inverse_transform_np(
            self.cat_encoder, x[:, self.cat_idx:], self._configs.discret_cols
        )
        # `pd.concat` drops a None block (no columns of that kind)
        return pd.concat(
            [X_cont_df, X_cat_df], axis=1
        )

    def apply_constraints(
        self, 
        cf: jnp.DeviceArray, 
        hard: bool = False
    ) -> jnp.DeviceArray:
        """Normalize the categorical blocks of `cf` with `cat_normalize`.

        The categorical blocks start at column `len(continous_cols)`. `hard`
        is forwarded to `cat_normalize`.
        """
        cf = cat_normalize(
            cf, cat_arrays=self._cat_arrays(), 
            cat_idx=len(self._configs.continous_cols),
            hard=hard
        )
        return cf

    def project(self, x: jnp.DeviceArray, cf: jnp.DeviceArray) -> jnp.DeviceArray:
        """Return `cf` with the immutable features overwritten by their values in `x`.

        Uses JAX's functional `.at[...].set`, so `cf` must be a JAX array.
        """
        cf = cf.at[:, self._imutable_idx_list].set(x[:, self._imutable_idx_list])
        return cf


In [None]:
#| export
def samples(datamodule: BaseDataModule, frac: float = 1.0): 
    """Return the leading `frac` fraction of `datamodule`'s training data as `(X, y)`."""
    X_train, y_train = datamodule.train_dataset[:]
    head = slice(int(frac * len(X_train)))
    return X_train[head], y_train[head]

In [None]:
#| hide
def check_datamodule(dm: TabularDataModule, data_configs: dict):
    """Smoke-test a `TabularDataModule` built from `data_configs`.

    Checks dataset and dataloader sizes, that `transform` inverts
    `inverse_transform` on the test split, and the outputs of
    `apply_constraints` / `project`. Column lists missing from
    `data_configs` are treated as empty, matching the config defaults.
    Assumes every split has at least `batch_size` rows.
    """
    batch_size = 256
    cont_cols = data_configs.get('continous_cols', [])
    cat_cols = data_configs.get('discret_cols', [])

    # features and labels line up in every split (X ends up as the test split)
    for dataset in (dm.train_dataset, dm.val_dataset, dm.test_dataset):
        X, y = dataset[:]
        assert X.shape[0] == len(y)

    # every loader yields a full first batch
    for make_loader in (dm.train_dataloader, dm.val_dataloader, dm.test_dataloader):
        x, y = next(iter(make_loader(batch_size)))
        assert x.shape[0] == batch_size

    ############################################################
    # test `transform` and `inverse_transform`
    ############################################################
    df = dm.inverse_transform(X)
    assert len(df) == len(X)
    assert len(df.columns) == len(cont_cols + cat_cols)

    data = dm.transform(df)
    assert np.allclose(X, data)

    ############################################################
    # test `apply_constraints` and `project`
    ############################################################
    cat_idx = len(cont_cols)
    n_cat_feat = len(cat_cols)
    dl = dm.test_dataloader(batch_size)
    x, y = next(iter(dl))
    cf = random.normal(
        random.PRNGKey(0), x.shape
    )
    # soft constraints: categorical part totals one per (row, categorical feature)
    cf = dm.apply_constraints(cf)
    assert jnp.allclose(jnp.sum(cf[:, cat_idx:]), len(cf) * n_cat_feat)

    # hard constraints: exactly one entry equal to 1 per (row, categorical feature)
    cf = dm.apply_constraints(cf, hard=True)
    assert jnp.count_nonzero(cf == 1) == len(cf) * n_cat_feat

    cf = dm.project(x, cf)
    assert x.shape == cf.shape

In [None]:
#| hide
from copy import deepcopy

In [None]:
#| hide
data_configs = {
    "data_dir": "assets/data/s_adult.csv",
    "data_name": "adult",
    'sample_frac': 0.1,
    "continous_cols": ["age", "hours_per_week"],
    "discret_cols": [
        "workclass", "education", "marital_status",
        "occupation", "race", "gender"
    ],
}
dm = TabularDataModule(data_configs)
check_datamodule(dm, data_configs)

# each variant below must build and check its own modified copy `_data_configs`

# immutable
_data_configs = deepcopy(data_configs)
_data_configs["imutable_cols"] = ["race","gender"]
dm = TabularDataModule(_data_configs)
check_datamodule(dm, _data_configs)

# no cont
_data_configs = deepcopy(data_configs)
_data_configs['continous_cols'] = []
dm = TabularDataModule(_data_configs)
check_datamodule(dm, _data_configs)

# no cat
_data_configs = deepcopy(data_configs)
_data_configs['discret_cols'] = []
dm = TabularDataModule(_data_configs)
check_datamodule(dm, _data_configs)


## Load Data

High-level interface for downloading and loading the built-in datasets (`adult`, `heloc`, `oulad`).

In [None]:
#| exporti
# dataset name -> {'data': CSV path, 'conf': JSON config path},
# both relative to the root of the cfnet repository
DEFAULT_DATA_CONFIGS = dict(
    adult=dict(
        data='assets/data/s_adult.csv',
        conf='assets/configs/data_configs/adult.json',
    ),
    heloc=dict(
        data='assets/data/s_home.csv',
        conf='assets/configs/data_configs/home.json',
    ),
    oulad=dict(
        data='assets/data/s_student.csv',
        conf='assets/configs/data_configs/student.json',
    ),
)

In [None]:
#| exporti
def _validate_dataname(data_name: str):
    """Raise ValueError unless `data_name` is a key of `DEFAULT_DATA_CONFIGS`."""
    known = DEFAULT_DATA_CONFIGS.keys()
    if data_name in known:
        return
    raise ValueError(f'`data_name` must be one of {known}, '
        f'but got data_name={data_name}.')

In [None]:
#| export
def _download_file(url: str, path: Path):
    """Download `url` to `path` via a temporary file.

    This ensures an interrupted download never leaves a truncated file at
    `path`, which would otherwise be treated as cached on later calls.
    """
    tmp_path = path.with_name(path.name + '.part')
    try:
        urlretrieve(url, tmp_path)
        os.replace(tmp_path, path)  # rename into place only after a complete download
    finally:
        # remove a leftover partial file if the download or rename failed
        if tmp_path.exists():
            tmp_path.unlink()


def load_data(
    data_name: str, return_config: bool = False
) -> TabularDataModule | tuple[TabularDataModule, dict]:
    """Load one of the built-in datasets as a `TabularDataModule`.

    On first use, the CSV and its JSON config are downloaded from the cfnet
    GitHub repository into `./cf_data/`. Later calls reuse the cached files.

    Args:
        data_name: One of the keys of `DEFAULT_DATA_CONFIGS`.
        return_config: If True, also return the data config dict.

    Returns:
        The data module, or `(data_module, config)` when `return_config` is True.

    Raises:
        ValueError: If `data_name` is not a known dataset.
    """
    _validate_dataname(data_name)

    # get data/config urls
    _data_path = DEFAULT_DATA_CONFIGS[data_name]['data']
    _conf_path = DEFAULT_DATA_CONFIGS[data_name]['conf']

    data_url = f"https://github.com/BirkhoffG/cfnet/raw/master/{_data_path}"
    conf_url = f"https://github.com/BirkhoffG/cfnet/raw/master/{_conf_path}"

    # create the cache dir (no error if it already exists)
    data_dir = Path(os.getcwd()) / "cf_data"
    data_dir.mkdir(parents=True, exist_ok=True)
    data_path = data_dir / f'{data_name}.csv'
    conf_path = data_dir / f'{data_name}.json'

    # download data/configs only when not cached yet
    if not data_path.is_file():
        _download_file(data_url, data_path)
    if not conf_path.is_file():
        _download_file(conf_url, conf_path)

    # read config and point it at the local CSV
    config = load_json(conf_path)['data_configs']
    config['data_dir'] = str(data_path)

    data_module = TabularDataModule(config)
    return (data_module, config) if return_config else data_module


In [None]:
#| hide
# download every built-in dataset and smoke-test its data module
for data_name in DEFAULT_DATA_CONFIGS:
    dm, config = load_data(data_name, return_config=True)
    check_datamodule(dm, config)