# Data Module

> `DataModule` for training parametric models and for generating and benchmarking counterfactual (CF) explanations.

In [None]:
#| default_exp data_module

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from __future__ import annotations
from relax.utils import load_json, validate_configs, get_config, save_pytree, load_pytree, get_config
from relax.base import *
from relax.data_utils import *
from relax.import_essentials import *
import jax
from jax import numpy as jnp, random as jrand, lax, Array
import pandas as pd
import numpy as np
from pathlib import Path
import json, os, shutil
from urllib.request import urlretrieve
from pydantic.fields import ModelField, Field
from typing import List, Dict, Union, Optional, Tuple, Callable, Any, Iterable
import warnings
from pandas.testing import assert_frame_equal

In [None]:
#| hide
from fastcore.test import *
from copy import deepcopy
from sklearn.datasets import make_classification


## Data Module Interfaces

High-level interfaces for `DataModule`. Docs to be added. 

In [None]:
#| export
class BaseDataModule(BaseModule):
    """Abstract interface shared by all `DataModule` implementations.

    Concrete subclasses override the hooks below; the base versions only
    raise `NotImplementedError`.
    """

    def _prepare(self, *args, **kwargs):
        """Hook that prepares the data for training."""
        raise NotImplementedError()

    def apply_constraints(self, x: Array, cf: Array, hard: bool = False, **kwargs) -> Array:
        """Hook that projects counterfactuals `cf` (for inputs `x`) onto the data constraints."""
        raise NotImplementedError()

    def compute_reg_loss(self, x: Array, cf: Array, hard: bool = False, **kwargs) -> float:
        """Hook that returns a regularization loss for counterfactuals `cf` of inputs `x`."""
        raise NotImplementedError()

In [None]:
#| exporti
class DataModuleInfoMixin:
    """Read-only shortcuts to the main attributes of a `DataModule`.

    The host class is expected to provide `_data`, `_features`, `_label`
    and `config`; this mixin only forwards to them.
    """

    # --- raw containers -------------------------------------------------

    @property
    def data(self) -> pd.DataFrame:
        """The raw data frame."""
        return self._data

    @property
    def features(self) -> FeaturesList:
        """The features as a `FeaturesList`."""
        return self._features

    @property
    def label(self) -> FeaturesList:
        """The labels as a `FeaturesList`."""
        return self._label

    # --- transformed arrays ---------------------------------------------

    @property
    def xs(self) -> Array:
        """Transformed features (`_features.transformed_data`)."""
        return self._features.transformed_data

    @property
    def ys(self) -> Array:
        """Transformed labels (`_label.transformed_data`)."""
        return self._label.transformed_data

    @property
    def dataset(self) -> Tuple[Array, Array]:
        """The `(xs, ys)` pair."""
        return self.xs, self.ys

    # --- split indices --------------------------------------------------

    @property
    def train_indices(self) -> List[int]:
        """Training indices stored in `config`."""
        return self.config.train_indices

    @property
    def test_indices(self) -> List[int]:
        """Testing indices stored in `config`."""
        return self.config.test_indices


## Data Module

`DataModule` for processing data, training models, and benchmarking CF explanations.


### Config

In [None]:
#| export
class DataModuleConfig(BaseConfig):
    """Configurator of `DataModule`."""

    # NOTE: `DataModule.from_config` reads this with `pd.read_csv`, so in practice
    # it is the path to the data CSV file rather than a directory.
    data_dir: str = Field(None, description="The directory of dataset.")
    data_name: str = Field(None, description="The name of `DataModule`.")
    continous_cols: List[str] = Field([], description="Continuous features/columns in the data.")
    discret_cols: List[str] = Field([], description="Categorical features/columns in the data.")
    imutable_cols: List[str] = Field([], description="Immutable features/columns in the data.")
    continuous_transformation: Optional[str] = Field('minmax', description="Transformation for continuous features. `None` indicates unknown.")
    discret_transformation: Optional[str] = Field('ohe', description="Transformation for categorical features. `None` indicates unknown.")
    sample_frac: Optional[float] = Field(
        None, description="Sample fraction of the data. Default to use the entire data.", ge=0., le=1.0
    )
    train_indices: List[int] = Field([], description="Indices of training data.")
    test_indices: List[int] = Field([], description="Indices of testing data.")
    
    def shuffle(self, data: Array, test_size: float, seed: Optional[int] = None):
        """Shuffle data with a seed.

        Fills `train_indices` / `test_indices` from a single random permutation
        of `range(data.shape[0])`. The first `int((1 - test_size) * n)` positions
        become the training split and the rest the test split. Index lists that
        are already non-empty are left untouched.

        `seed` defaults to `get_config().global_seed`.
        Raises `ValueError` if `test_size` is outside [0, 1].
        """
        if not 0 <= test_size <= 1:
            raise ValueError(f"`test_size` should be within [0, 1], but got test_size={test_size}.")
        if self.train_indices and self.test_indices:
            # Both splits are already set (e.g. loaded from `config.json`),
            # so there is no need to draw a permutation.
            return
        if seed is None:
            seed = get_config().global_seed
        total_length = data.shape[0]
        train_length = int((1 - test_size) * total_length)
        # One permutation shared by both splits keeps them disjoint.
        # NOTE(review): if only one split was pre-set, the generated other
        # split may overlap with it.
        permutation = jrand.permutation(jrand.PRNGKey(seed), total_length)
        if not self.train_indices:
            self.train_indices = permutation[:train_length].tolist()
        if not self.test_indices:
            self.test_indices = permutation[train_length:].tolist()

In [None]:
#| hide
config = DataModuleConfig(data_name="TabularDataModule", data_dir="data", continous_cols=[], discret_cols=[], imutable_cols=[])
# Test shuffle: a fresh config has no split; `shuffle` fills a disjoint 80/20 split of range(100).
assert len(config.train_indices) == 0
assert len(config.test_indices) == 0
config.shuffle(data=np.arange(100), test_size=0.2)
assert len(config.train_indices) == 100 * 0.8
assert len(config.test_indices) == 100 * 0.2
assert isinstance(config.train_indices, list)
assert isinstance(config.test_indices, list)
assert (sorted(config.train_indices + config.test_indices) == list(range(100)))

# A config built from a dict starts without a split and uses the default transformations.
configs_dict = {
    "data_dir": "assets/adult/data/data.csv",
    "data_name": "adult",
    "continous_cols": ["age", "hours_per_week"], 
    "discret_cols": ["workclass", "education", "marital_status", "occupation", "race", "gender"], 
    "imutable_cols": ["race", "gender"],
    "sample_frac": 0.1,
}
configs = DataModuleConfig(**configs_dict)
assert len(configs.train_indices) == 0
assert len(configs.test_indices) == 0
# Check the defaults on `configs` (the config just built), not on the earlier `config`.
assert configs.continuous_transformation == 'minmax'
assert configs.discret_transformation == 'ohe'

### Utils

> Utility functions for `DataModule`.

In [None]:
#| export
def features2config(
    features: FeaturesList, # FeaturesList to be converted
    name: str, # Name of the data used for `DataModuleConfig`
    return_dict: bool = False # Whether to return a dict or `DataModuleConfig`
) -> Union[DataModuleConfig, Dict]: # Return configs
    """Get `DataModuleConfig` from `FeaturesList`.

    Categorical features are listed in `discret_cols`, all others in
    `continous_cols`. Immutable features are additionally listed in
    `imutable_cols`. Both transformation settings are left as `None`.
    """
    feature_list = list(features)
    configs_dict = dict(
        data_dir=".",
        data_name=name,
        continous_cols=[feat.name for feat in feature_list if not feat.is_categorical],
        discret_cols=[feat.name for feat in feature_list if feat.is_categorical],
        imutable_cols=[feat.name for feat in feature_list if feat.is_immutable],
        # Transformations are not derived from the features here.
        continuous_transformation=None,
        discret_transformation=None,
    )
    return configs_dict if return_dict else DataModuleConfig(**configs_dict)


In [None]:
#| hide
# Three toy features: 'age' (immutable) plus two categorical features.
feats = FeaturesList([
    Feature("age", np.random.normal(0, 1, (10, 1)), transformation='minmax', is_immutable=True),
    Feature("workclass", np.random.randint(0, 2, (10, 1)), transformation='ohe'),
    Feature("education", np.random.randint(0, 2, (10, 1)), transformation='ordinal'),    
])
config = features2config(feats, "test")
assert config.data_dir == "."
assert config.data_name == "test"
# Non-categorical features go to `continous_cols`; `is_immutable` drives `imutable_cols`.
assert config.continous_cols == ["age"]
assert config.discret_cols == ["workclass", "education"]
assert config.imutable_cols == ["age"]
# `features2config` does not fill in transformations, so both stay `None`.
assert config.continuous_transformation is None
assert config.discret_transformation is None

# `return_dict=True` returns the raw kwargs instead of a `DataModuleConfig`.
config_dict = features2config(feats, "test", return_dict=True)
assert isinstance(config_dict, dict)

In [None]:
#| export
def features2pandas(
    features: FeaturesList, # FeaturesList to be converted
    labels: FeaturesList # labels to be converted
) -> pd.DataFrame: # Return pandas dataframe
    """Convert `FeaturesList` to pandas dataframe.

    The feature columns come first, followed by the label columns.
    """
    frames = [features.to_pandas(), labels.to_pandas()]
    return pd.concat(frames, axis=1)

Example: 

In [None]:
# Build a small features/labels pair and join them into one DataFrame.
feats = FeaturesList([
    Feature("age", np.random.normal(0, 1, (10, 1)), 
            transformation='minmax', is_immutable=True),
    Feature("workclass", np.random.randint(0, 2, (10, 1)), 
            transformation='ohe'),
    Feature("education", np.random.randint(0, 2, (10, 1)), 
            transformation='ordinal'),    
])
labels = FeaturesList([
    Feature("income", np.random.randint(0, 2, (10, 1)), 
            transformation='identity'),
])
df = features2pandas(feats, labels)
assert isinstance(df, pd.DataFrame)
# 10 rows; 3 feature columns + 1 label column.
assert df.shape == (10, 4)

In [None]:
#| exporti
def to_feature(col: str, data: pd.DataFrame, config: DataModuleConfig, transformation: str):
    """Wrap column `col` of `data` as a single-column `Feature`.

    The values are reshaped to a column vector. The feature is marked
    immutable when `col` is listed in `config.imutable_cols`.
    """
    column_values = data[col].to_numpy()
    immutable = col in config.imutable_cols
    return Feature(
        name=col,
        data=column_values.reshape(-1, 1),
        transformation=transformation,
        is_immutable=immutable,
    )

In [None]:
#| export
def dataframe2features(
    data: pd.DataFrame,
    config: DataModuleConfig,
) -> FeaturesList:
    """Convert pandas dataframe of features to `FeaturesList`.

    Continuous columns (`config.continous_cols`, using
    `config.continuous_transformation`) come first. They are followed by the
    categorical columns (`config.discret_cols`, using
    `config.discret_transformation`).
    """
    cont_features = [to_feature(col, data, config, config.continuous_transformation) for col in config.continous_cols]
    cat_features = [to_feature(col, data, config, config.discret_transformation) for col in config.discret_cols]
    return FeaturesList(cont_features + cat_features)


def dataframe2labels(
    data: pd.DataFrame,
    config: DataModuleConfig,
) -> FeaturesList:
    """Convert pandas dataframe of labels to `FeaturesList`.

    Every column of `data` that is neither in `config.continous_cols` nor in
    `config.discret_cols` is treated as a label, with the `'identity'`
    transformation. Labels keep the column order of `data`.
    """
    feature_cols = set(config.continous_cols) | set(config.discret_cols)
    # Walk `data.columns` instead of iterating a set difference: set iteration
    # order of strings changes between interpreter runs (hash randomization),
    # which made the label order non-deterministic. `dict.fromkeys` keeps the
    # de-duplication the set provided.
    label_cols = [col for col in dict.fromkeys(data.columns) if col not in feature_cols]
    labels = [to_feature(col, data, config, 'identity') for col in label_cols]
    return FeaturesList(labels)

### Main Data Module

> Main module.

In [None]:
#| export
class DataModule(BaseDataModule, DataModuleInfoMixin):
    """DataModule for tabular data.

    Bundles the features and labels (as `FeaturesList`s), the raw
    `pd.DataFrame`, and a `DataModuleConfig` whose `train_indices` /
    `test_indices` define the train/test split.
    """

    def __init__(
        self, 
        features: FeaturesList,
        label: FeaturesList,
        config: DataModuleConfig = None,
        data: pd.DataFrame = None,
        **kwargs
    ):
        self._prepare(features, label)
        if config is None:
            # `from_numpy`/`from_features` forward `name=None` when no name is
            # given, so fall back to the default name for `None` as well.
            name = kwargs.pop('name', None) or 'DataModule'
            config = features2config(features, name)
        # Generates a 75%/25% split for any split indices that are still empty.
        config.shuffle(self.xs, test_size=0.25)
        self._data = features2pandas(features, label) if data is None else data
        super().__init__(config, name=config.data_name)

    def _prepare(self, features, label):
        """Wrap `features` and `label` in `FeaturesList`s. Neither may be `None`."""
        if features is None:
            raise ValueError("Features cannot be None.")
        if label is None:
            raise ValueError("Label cannot be None.")
        self._features = FeaturesList(features)
        self._label = FeaturesList(label)
            
    def save(
        self, 
        path: str # Path to the directory to save `DataModule`
    ):
        """Save `DataModule` to a directory.

        Writes `features/`, `label/`, `config.json` and, when raw data is
        available, `data.csv` into `path`. The directory is created if missing.
        """
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        self._features.save(path / 'features')
        self._label.save(path / 'label')
        if self._data is not None:
            self._data.to_csv(path / 'data.csv', index=False)
        with open(path / "config.json", "w") as f:
            json.dump(self.config.dict(), f)

    @classmethod
    def load_from_path(
        cls, 
        path: str,  # Path to the directory to load `DataModule`
        config: Dict|DataModuleConfig = None # Configs of `DataModule`. This argument is ignored.
    ) -> DataModule: # Initialized `DataModule` from path
        """Load `DataModule` from a directory written by `save`."""
        if config is not None:
            warnings.warn("Passing `config` will have no effect.")
        
        path = Path(path)
        config = DataModuleConfig.load_from_json(path / 'config.json')
        features = FeaturesList.load_from_path(path / 'features')
        label = FeaturesList.load_from_path(path / 'label')
        # `save` skips `data.csv` when there is no raw data, and
        # `download_data_module_files` only fetches it on request. If it is
        # missing, let `__init__` rebuild the frame from features and labels.
        data_csv = path / 'data.csv'
        data = pd.read_csv(data_csv) if data_csv.is_file() else None
        return cls(features=features, label=label, config=config, data=data)
    
    @classmethod
    def from_path(cls, path, config: DataModuleConfig = None):
        """Alias of `load_from_path`"""
        return cls.load_from_path(path, config)
    
    @classmethod
    def from_config(
        cls, 
        config: Dict|DataModuleConfig, # Configs of `DataModule`
        data: pd.DataFrame=None # Passed in pd.Dataframe
    ) -> DataModule: # Initialized `DataModule` from configs and data
        """Create `DataModule` from configs.

        When `data` is not given, it is read with `pd.read_csv(config.data_dir)`.
        """
        config = validate_configs(config, DataModuleConfig)
        if data is None:
            data = pd.read_csv(config.data_dir)
        if not isinstance(data, pd.DataFrame):
            raise ValueError("`data` should be a pandas DataFrame.")
        features = dataframe2features(data, config)
        label = dataframe2labels(data, config)
        return cls(features=features, label=label, config=config, data=data)
    
    @classmethod
    def from_numpy(
        cls,
        xs: np.ndarray, # Input data
        ys: np.ndarray, # Labels
        name: str = None, # Name of `DataModule`
        transformation='minmax'
    ) -> DataModule: # Initialized `DataModule` from numpy arrays
        """Create `DataModule` from numpy arrays. Note that the `xs` are treated as continuous features."""
        features = FeaturesList([
            Feature(f"feature_{i}", xs[:, i].reshape(-1, 1), transformation=transformation)
            for i in range(xs.shape[1])
        ])
        labels = FeaturesList([Feature("label", ys.reshape(-1, 1), transformation='identity')])
        return cls(features=features, label=labels, name=name)
    
    @classmethod
    def from_features(
        cls, 
        features: FeaturesList, # Features of `DataModule`
        label: FeaturesList, # Labels of `DataModule`
        name: str = None # Name of `DataModule`
    ) -> DataModule: # Initialized `DataModule` from features and labels
        """Create `DataModule` from `FeaturesList`."""
        return cls(features=features, label=label, name=name)
        
    def _get_data(self, indices):
        """Return the `(xs, ys)` rows selected by `indices`."""
        if isinstance(indices, list):
            indices = jnp.array(indices)
        return (self.xs[indices], self.ys[indices])
        
    def __getitem__(self, name: str):
        """Return the `(xs, ys)` of a split. 'valid' and 'test' both map to the test split."""
        if name == 'train':
            return self._get_data(self.config.train_indices)
        elif name in ['valid', 'test']:
            return self._get_data(self.config.test_indices)
        else:
            raise ValueError(f"Unknown data name: {name}. Should be one of ['train', 'valid', 'test']")
    
    def set_transformations(
        self, 
        feature_names_to_transformation: Dict[str, Union[str, Dict, BaseTransformation]], # Dict[feature_name, Transformation]
    ) -> DataModule:
        """Reset transformations for features. Returns `self` for chaining."""
        self._features = self._features.set_transformations(feature_names_to_transformation)
        return self
    
    def sample(
        self, 
        size: float | int, # Size of the sample. If float, should be 0<=size<=1.
        stage: str = 'train', # Stage of data to sample from. Should be one of ['train', 'valid', 'test']
        key: jrand.PRNGKey = None # Random key. 
    ) -> Tuple[Array, Array]: # Sampled data
        """Sample data from `DataModule` without replacement.

        A float `size` in [0, 1] is a fraction of the stage's rows. A
        non-negative integer `size` is a row count, capped at the number of
        available rows. `key` defaults to `PRNGKey(get_config().global_seed)`.
        """
        key = jrand.PRNGKey(get_config().global_seed) if key is None else key
        xs, ys = self[stage]
        n_rows = xs.shape[0]

        if isinstance(size, float) and 0 <= size <= 1:
            n_samples = int(size * n_rows)
        elif isinstance(size, (int, np.integer)) and not isinstance(size, bool) and size >= 0:
            # Negative integers must be rejected: `[:size]` would otherwise
            # silently return "all but the last |size|" rows.
            n_samples = min(int(size), n_rows)
        else:
            raise ValueError(f"`size` should be a floating number 0<=size<=1, or an integer >= 0,"
                             f" but got size={size}.")

        indices = jrand.permutation(key, jnp.arange(n_rows))[:n_samples]
        return xs[indices], ys[indices]

    def transform(
        self, 
        data: pd.DataFrame | Dict[str, Array] # Data to be transformed
    ) -> Array: # Transformed data
        """Transform data to `jax.Array`.

        A `pd.DataFrame` may contain extra columns such as the label. If it has
        a column for every feature, only those columns are used, in feature
        order. Otherwise all columns but the last are used (legacy behavior,
        which assumes the label is the last column). A dict maps feature names
        to array-likes, and each value is reshaped to a column.
        """
        # TODO: test this function
        if isinstance(data, pd.DataFrame):
            feature_names = [feat.name for feat in self._features]
            if set(feature_names).issubset(data.columns):
                data = data.loc[:, feature_names]
            else:
                data = data.iloc[:, :-1]
            data_dict = {k: np.array(v).reshape(-1, 1) for k, v in data.to_dict(orient='list').items()}
            return self._features.transform(data_dict)
        elif isinstance(data, dict):
            # Reshape each value directly: `jax.tree_util.tree_map` would
            # recurse into list/tuple values and reshape every scalar separately.
            data = {k: np.array(v).reshape(-1, 1) for k, v in data.items()}
            return self._features.transform(data)
        else:
            raise ValueError("data should be a pandas DataFrame or `Dict[str, jax.Array]`.")
        
    def inverse_transform(
        self, 
        data: Array, # Data to be inverse transformed
        return_type: str = 'pandas' # Type of the returned data. Should be one of ['pandas', 'dict']
    ) -> pd.DataFrame: # Inverse transformed data
        """Inverse transform data to `pandas.DataFrame`. `return_type='dict'` is not implemented yet."""
        # TODO: test this function
        inversed = self._features.inverse_transform(data)
        if return_type == 'pandas':
            return inversed
        elif return_type == 'dict':
            raise NotImplementedError
        else:
            raise ValueError(f"Unknown return type: {return_type}. Should be one of ['pandas', 'dict']")
        
    def apply_constraints(
        self, 
        xs: Array, # Input data
        cfs: Array, # Counterfactuals to be constrained
        hard: bool = False, # Whether to apply hard constraints or not
        rng_key: jrand.PRNGKey = None, # Random key
        **kwargs
    ) -> Array: # Constrained counterfactuals
        """Apply constraints to counterfactuals (delegates to the features)."""
        return self._features.apply_constraints(xs, cfs, hard=hard, rng_key=rng_key, **kwargs)
    
    def compute_reg_loss(
        self, 
        xs: Array, # Input data
        cfs: Array, # Counterfactuals to be constrained
        hard: bool = False # Whether to apply hard constraints or not
    ) -> float:
        """Compute regularization loss (delegates to the features)."""
        return self._features.compute_reg_loss(xs, cfs, hard)
    
    __ALL__ = [
        'load_from_path', 
        'from_config', 
        'from_features',
        'from_numpy',
        'save',
        'transform',
        'inverse_transform',
        'apply_constraints',
        'compute_reg_loss',
        'set_transformations',
        'sample'
    ]

In [None]:
#| exporti
def dm_equals(dm1: DataModule, dm2: DataModule, indices_equals: bool = True) -> bool:
    """Check whether two `DataModule`s hold the same data.

    Compares the raw `data` frames (`assert_frame_equal`) and the transformed
    `xs`/`ys` arrays (same shape and `np.allclose`). When `indices_equals` is
    True, it also compares the train/test indices. Mismatches return `False`
    instead of raising.
    """
    def _arrays_close(a, b) -> bool:
        # `np.allclose` broadcasts, so shapes such as (1, d) and (n, d) could
        # otherwise compare "equal" (or raise when not broadcastable).
        return np.shape(a) == np.shape(b) and bool(np.allclose(a, b))

    try:
        assert_frame_equal(dm1.data, dm2.data)
    except AssertionError:
        return False
    if not (_arrays_close(dm1.xs, dm2.xs) and _arrays_close(dm1.ys, dm2.ys)):
        return False
    if not indices_equals:
        return True
    return bool(
        np.array_equal(dm1.train_indices, dm2.train_indices)
        and np.array_equal(dm1.test_indices, dm2.test_indices)
    )

In [None]:
# Test initialization
config = DataModuleConfig.load_from_json("assets/adult/data/config.json")
# `config_1`: the same config, but without immutable columns.
config_1 = config.dict()
config_1.update({"imutable_cols": []})
# `from_config` accepts both a `DataModuleConfig` and a plain dict.
dm = DataModule.from_config(config)
dm_1 = DataModule.from_config(config.dict())
assert dm_equals(dm, dm_1)
# Loading the saved module from disk gives the same data and split.
dm_2 = DataModule.from_path("assets/adult/data")
assert dm_equals(dm, dm_2)
# Dropping `imutable_cols` leaves data, xs/ys and the split unchanged.
dm_3 = DataModule.from_config(config_1)
assert dm_equals(dm, dm_3)
assert dm_3.config.imutable_cols == []
# `from_features` builds a fresh config (and split), so only the data is compared.
feats = FeaturesList.load_from_path("assets/adult/data/features")
label = FeaturesList.load_from_path("assets/adult/data/label")
dm_4 = DataModule.from_features(feats, label)
assert dm_equals(dm, dm_4, indices_equals=False) # Indices are not supposed to be equal

In [None]:
# Test from_numpy
xs, ys = make_classification(n_samples=100, n_features=5, n_informative=3, random_state=0)
dm_5 = DataModule.from_numpy(xs, ys, name="test", transformation='identity')
config_5 = dm_5.config
assert dm_5.config.data_name == "test"
# 5 feature columns + 1 label column.
assert dm_5.data.shape == (100, 6)
assert np.allclose(dm_5.data.to_numpy(), np.concatenate([xs, ys.reshape(-1, 1)], axis=1))
# With the 'identity' transformation, the split indices select rows of `xs` directly.
assert np.allclose(
    xs[config_5.train_indices],
    dm_5['train'][0]
)
assert np.allclose(
    xs[config_5.test_indices],
    dm_5['test'][0]
)
# Save/load round-trip, then clean up the temporary directory.
dm_5.save('tmp/test')
dm_6 = DataModule.load_from_path('tmp/test')
assert dm_equals(dm_5, dm_6)
shutil.rmtree("tmp/test")

In [None]:
# Test save and load
# The round-trip must preserve data, xs/ys and the split indices.
dm.save("tmp/adult")
dm_5 = DataModule.load_from_path("tmp/adult")
assert dm_equals(dm, dm_5)
shutil.rmtree("tmp/adult")

In [None]:
# Test set_transformations
# Work on a copy so that `dm` keeps its original transformations.
dm_6 = deepcopy(dm)
# With 'identity', the first column of xs equals the raw 'age' column.
dm_6.set_transformations({"age": 'identity'})
assert dm_6.features['age'].transformation.name == 'identity'
assert np.array_equal(dm_6.xs[:, :1], dm_6.data[['age']].to_numpy())
# 'ordinal' encodes each categorical feature as a single column.
dm_6.set_transformations({feat: 'ordinal' for feat in config.discret_cols})
assert dm_6.xs.shape == (dm.data.shape[0], len(config.continous_cols) + len(config.discret_cols))

# Re-encoding the categorical features leaves 'age' untouched.
assert np.array_equal(dm_6.xs[:, :1], dm_6.data[['age']].to_numpy())

# Invalid keys, names, transformations and argument types are rejected.
test_fail(lambda: dm_6.set_transformations({1: 'identity'}), contains="Invalid idx type")
test_fail(lambda: dm_6.set_transformations({"❤": 'identity'}), contains="Invalid feature name")
test_fail(lambda: dm_6.set_transformations({"age": '❤'}), contains="Unknown transformation")
test_fail(lambda: dm_6.set_transformations('❤'), contains="Invalid feature_names_to_transformation type")

# A transformation instance is accepted too; MinMax reproduces the original 'age' encoding.
dm_6.set_transformations({"age": MinMaxTransformation()})
assert np.allclose(dm_6.xs[:, :1], dm.xs[:, :1])

In [None]:
# Test sample
# A float size draws that fraction of the training split.
sampled_xs, sampled_ys = dm.sample(0.1)
assert len(sampled_xs) == len(sampled_ys)
assert sampled_xs.shape[0] == int(0.1 * dm['train'][0].shape[0])
# Rows are shuffled, not simply the leading rows.
assert not jnp.all(sampled_xs == dm['train'][0][:sampled_xs.shape[0]])

# An int size draws that many rows.
sampled_xs, sampled_ys = dm.sample(100)
assert len(sampled_xs) == len(sampled_ys)
assert sampled_xs.shape[0] == 100
assert not jnp.all(sampled_xs == dm['train'][0][:100])

# Out-of-range floats and non-numeric sizes are rejected.
test_fail(lambda: dm.sample(1.1), contains='should be a floating number 0<=size<=1,')
test_fail(lambda: dm.sample('train'), contains='or an integer')

# Test apply_constraints: soft and hard constraints both keep the input shape.
xs = dm['train'][0]
cfs = jrand.uniform(jrand.PRNGKey(0), shape=xs.shape, minval=0.01, maxval=0.99)
cfs = dm.apply_constraints(xs, cfs, hard=False)
assert cfs.shape == xs.shape

cfs = dm.apply_constraints(xs, cfs, hard=True)
assert cfs.shape == xs.shape

In [None]:
# Test transform
# Transforming the raw data frame reproduces the module's `xs`.
data = dm.transform(dm.data)
assert np.allclose(data, dm.xs)

## Load Data


In [None]:
#| exporti
# Names of the example datasets that `load_data` can fetch.
DEFAULT_DATA = [
    'adult', 'heloc', 'oulad', 'credit', 'cancer',
    'student_performance', 'titanic', 'german', 'spam', 'ozone',
    'qsar', 'bioresponse', 'churn', 'road', 'dummy',
]

# Relative asset paths per dataset, e.g. 'adult' -> {'data': 'adult/data', 'model': 'adult/model'}.
DEFAULT_DATA_CONFIGS = {
    name: dict(data=f"{name}/data", model=f"{name}/model")
    for name in DEFAULT_DATA
}

In [None]:
# from sklearn.datasets import make_classification

# xs, ys = make_classification(n_samples=1000, n_features=10)
# xs = pd.DataFrame(xs, columns=[f"col_{i}" for i in range(10)])
# ys = pd.DataFrame(ys, columns=['label'])
# data = pd.concat([xs, ys], axis=1)
# os.makedirs('assets/dummy/data', exist_ok=True)
# data.to_csv('assets/dummy/data/data.csv', index=False)
# config = DataModuleConfig(
#     data_name="dummy", 
#     data_dir="assets/dummy/data/data.csv", 
#     continous_cols=[f"col_{i}" for i in range(10)]
# )
# dm = DataModule(config)
# dm.save('assets/dummy/data')

In [None]:
# for data_name in DEFAULT_DATA_CONFIGS.keys():
#     print(f"Loading {data_name}...")
#     shutil.rmtree(f'../relax-assets/{data_name}', ignore_errors=True)
#     conf_path = DEFAULT_DATA_CONFIGS[data_name]['conf']
#     config = load_json(conf_path)['data_configs']
#     dm_config = DataModuleConfig(**config)
#     dm = DataModule(dm_config)
#     dm.save(f'../relax-assets/{data_name}/data')
    

In [None]:
# for data_name in DEFAULT_DATA_CONFIGS.keys():
#     print(f"Loading {data_name}...")
#     DataModule.load_from_path(f'../relax-assets/{data_name}/data')    

In [None]:
# config = load_json('assets/adult/configs.json')['data_configs']
# dm_config = DataModuleConfig(**config)
# dm = DataModule(dm_config)

In [None]:
#| exporti
def _validate_dataname(data_name: str):
    """Raise `ValueError` unless `data_name` is one of `DEFAULT_DATA`."""
    if data_name in DEFAULT_DATA:
        return
    raise ValueError(f'`data_name` must be one of {DEFAULT_DATA}, '
        f'but got data_name={data_name}.')

In [None]:
#| export
def download_data_module_files(
    data_name: str, # The name of data
    data_parent_dir: Path, # The directory to save data.
    download_original_data: bool = False, # Download original data or not
):
    """Download the files of a saved `DataModule` from the ReLax-Assets dataset on Hugging Face.

    Files go to `data_parent_dir/<data_name>/data/`. Files that already exist
    are skipped. `data.csv` is only fetched when `download_original_data` is True.
    """
    data_parent_dir = Path(data_parent_dir)  # also accept plain `str` paths
    files = [
        "features/data.npy", "features/treedef.json",
        "label/data.npy", "label/treedef.json",
        "config.json"
    ]
    if download_original_data:
        files.append("data.csv")
    for f in files:
        f_path = data_parent_dir / f'{data_name}/data' / f
        if f_path.is_file():
            continue
        url = f"https://huggingface.co/datasets/birkhoffg/ReLax-Assets/resolve/main/{data_name}/data/{f}"
        os.makedirs(f_path.parent, exist_ok=True)
        # Download to a temporary sibling and rename on success. An interrupted
        # download then never leaves a truncated file that the `is_file()`
        # check above would treat as complete on the next call.
        tmp_path = f_path.with_name(f_path.name + '.part')
        try:
            urlretrieve(url, tmp_path)
            os.replace(tmp_path, f_path)
        except BaseException:
            if tmp_path.exists():
                tmp_path.unlink()
            raise


def load_data(
    data_name: str, # The name of data
    return_config: bool = False, # Deprecated
    data_configs: dict = None, # Data configs to override default configuration
) -> DataModule | Tuple[DataModule, DataModuleConfig]: # Return `DataModule` or (`DataModule`, `DataModuleConfig`)
    """High-level util function for loading `data` and `data_config`.

    Downloads the files of `data_name` (including `data.csv`) into
    `./relax-assets/<data_name>/data`, skipping files already present, and
    loads them with `DataModule.load_from_path`. `data_configs` is currently
    not applied; `load_from_path` warns when it is given.
    """
    _validate_dataname(data_name)

    if return_config:
        warnings.warn("`return_config` is deprecated since v0.2. "
                      "Please access `config` from `DataModule`.", DeprecationWarning)

    # create the asset dir (no-op if it already exists) and download files
    data_parent_dir = Path(os.getcwd()) / "relax-assets"
    os.makedirs(data_parent_dir, exist_ok=True)
    download_data_module_files(
        data_name, data_parent_dir, 
        download_original_data=True
    )

    # Overriding the stored config with `data_configs` is disabled for now,
    # since we cannot guarantee that the override configs are valid.
    data_dir = data_parent_dir / f'{data_name}/data'
    data_module = DataModule.load_from_path(data_dir, config=data_configs)

    if return_config:
        # Deprecated but still honored: callers unpack `(DataModule, config)`.
        return data_module, data_module.config
    return data_module


`load_data` easily loads example datasets by passing the `data_name`. 
For example, you can load the [adult](https://archive.ics.uci.edu/ml/datasets/adult) as:

In [None]:
# Downloads the adult files into `./relax-assets` on first use, then loads them.
dm = load_data(data_name = 'adult')

#### Supported Datasets

`load_data` currently supports the following datasets:

In [None]:
#| echo: false
#| eval: false
def display_data_attrbutes(names: list):
    """Summarize each dataset in `names` as one row of a `pd.DataFrame`.

    The columns are the number of continuous features, the number of
    categorical features and the number of rows in the raw data. Every
    dataset is loaded through `load_data`.
    """
    columns = ('# Cont Features', '# Cat Features', '# of Data Points')
    attrs = {column: {data_name: 0 for data_name in names} for column in columns}
    for data_name in names:
        module = load_data(data_name)
        counts = (
            len(module.config.continous_cols),
            len(module.config.discret_cols),
            len(module.data),
        )
        for column, count in zip(columns, counts):
            attrs[column][data_name] = count
    return pd.DataFrame.from_dict(attrs)

display_data_attrbutes(DEFAULT_DATA_CONFIGS.keys())

Unnamed: 0,# Cont Features,# Cat Features,# of Data Points
adult,2,6,32561
heloc,21,2,10459
oulad,23,8,32593
credit,20,3,30000
cancer,30,0,569
student_performance,2,14,649
titanic,2,24,891
german,7,13,1000
spam,57,0,4601
ozone,72,0,2534


In [None]:
#| hide
# for data_name in DEFAULT_DATA_CONFIGS.keys():
#     dm, config = load_data(
#         data_name, return_config=True, data_configs=dict(sample_frac=0.1)
#     )
#     assert config.sample_frac == 0.1
    # check_datamodule(dm, config)