# Get started

> A basic tutorial on the key features of `ReLax`

## Import Data

To import data, we use `TabularDataModule`. First, we set up the input `data_configs` for the module. `data_configs` has 6 attributes:
- `data_dir` is the path to your data file.
- `data_name` is the name of your dataset.
- `batch_size` is the batch size used when loading the data (it is not set in the example below).
- `continous_cols` lists the continuous (numeric) columns in the data.
- `discret_cols` lists the discrete (categorical) columns in the data. These columns are converted to one-hot encodings for training.
- `imutable_cols` lists the immutable columns, i.e., features that the counterfactual explanations should not change.


In [None]:
from relax.import_essentials import *

In [None]:
data_configs = {
    "data_dir": "../assets/data/s_adult.csv",
    "data_name": "adult",
    "continous_cols": ["age","hours_per_week"],
    "discret_cols": ["workclass","education","marital_status","occupation","race","gender"],
    "imutable_cols": ["race","gender"]
}
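Because the column lists are plain strings, a typo in `data_configs` may only surface when the data is loaded. The helper below is a hypothetical sanity check, not part of ReLax. It assumes `header` holds the CSV's column names and verifies three things: every configured column exists, no column is both continuous and discrete, and every immutable column is one of the configured features.

```python
from typing import Dict, List, Sequence


def check_data_configs(configs: Dict, header: Sequence[str]) -> List[str]:
    """Return the problems found in a ReLax-style `data_configs` dict.

    Hypothetical helper: `header` is assumed to be the CSV column names.
    An empty list means no problems were detected.
    """
    problems = []
    cont = set(configs.get("continous_cols", []))
    disc = set(configs.get("discret_cols", []))
    immutable = set(configs.get("imutable_cols", []))
    features = cont | disc

    # Every configured feature column should exist in the data file.
    for col in sorted(features - set(header)):
        problems.append(f"column not found in data: {col}")
    # A column cannot be treated as both numeric and categorical.
    for col in sorted(cont & disc):
        problems.append(f"column is both continuous and discrete: {col}")
    # Immutable columns should be among the features the model sees.
    for col in sorted(immutable - features):
        problems.append(f"immutable column is not a feature: {col}")
    return problems
```

For example, `check_data_configs(data_configs, header)` returns an empty list when every configured column appears in `header`.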


Then we pass `data_configs` to `TabularDataModule` to load the data.

In [None]:
from relax.data import TabularDataModule

In [None]:
dm = TabularDataModule(data_configs)
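The attribute list above notes that `discret_cols` are one-hot encoded for training. The plain-Python sketch below illustrates that idea for a single column. It is not `TabularDataModule`'s actual implementation, and the name `one_hot_column` is hypothetical. Categories are sorted so that the output column order is deterministic.

```python
from typing import List, Sequence, Tuple


def one_hot_column(values: Sequence[str]) -> Tuple[List[str], List[List[int]]]:
    """One-hot encode one categorical column.

    Returns the sorted category names and one 0/1 row per input value.
    """
    categories = sorted(set(values))
    index = {cat: i for i, cat in enumerate(categories)}
    rows = []
    for v in values:
        row = [0] * len(categories)
        row[index[v]] = 1  # exactly one active entry per row
        rows.append(row)
    return categories, rows


cats, encoded = one_hot_column(["Male", "Female", "Female"])
print(cats)     # ['Female', 'Male']
print(encoded)  # [[0, 1], [1, 0], [1, 0]]
```

A column with k distinct categories becomes k binary columns, so the model's input width grows with the number of categories in `discret_cols`.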

## Train the Machine Learning Classifier

Now that the data is loaded, we specify the classifier.

1. Specify the machine learning model configurations `m_configs`. Its main attributes are:

- `lr` is the learning rate.
- `sizes` is the list of hidden-layer sizes of the model.
- `dropout_rate` is the dropout rate.

In [None]:
m_configs = {
    'lr': 0.003,
    "sizes": [50, 10, 50],
    "batch_size": 256,
    "dropout_rate": 0.3
}
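To make `sizes` concrete, the sketch below lists each dense layer's weight shape and counts the parameters. It assumes the classifier is a fully connected network whose hidden layers have the widths in `sizes`, followed by a single output unit for the binary label. That architecture is an assumption, not something this tutorial confirms. The input width of 10 is a hypothetical placeholder, since the real value depends on how many columns the data module produces after one-hot encoding.

```python
from typing import List, Sequence, Tuple


def dense_layer_shapes(input_dim: int, sizes: Sequence[int],
                       output_dim: int = 1) -> List[Tuple[int, int]]:
    """(fan_in, fan_out) for every dense layer of an assumed MLP.

    Hidden widths come from `sizes`; an output layer of `output_dim` follows.
    """
    widths = [input_dim, *sizes, output_dim]
    return list(zip(widths[:-1], widths[1:]))


def count_params(shapes: Sequence[Tuple[int, int]]) -> int:
    # Each dense layer has a (fan_in x fan_out) weight matrix plus a bias vector.
    return sum(fan_in * fan_out + fan_out for fan_in, fan_out in shapes)


# input_dim=10 is a hypothetical placeholder, not the adult dataset's width.
shapes = dense_layer_shapes(input_dim=10, sizes=[50, 10, 50])
print(shapes)                # [(10, 50), (50, 10), (10, 50), (50, 1)]
print(count_params(shapes))  # 1661
```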

2. Pass `m_configs` to `PredictiveTrainingModule` to create the classification model.

In [None]:
from relax.module import PredictiveTrainingModule

In [None]:
training_module = PredictiveTrainingModule(m_configs)

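3. Specify the training configurations `t_configs`. `n_epochs` is the number of training epochs, `monitor_metrics` is the metric monitored during training (here the validation loss), `logger_name` is the name of the logger, and `batch_size` is the training batch size.

In [None]: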
t_configs = {
    'n_epochs': 10,
    'monitor_metrics': 'val/val_loss',
    'logger_name': 'pred',
    "batch_size": 256
}
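`monitor_metrics` names the quantity tracked during training, here the validation loss. A common use of such a metric is to keep the parameters from the epoch where it is best. Whether `train_model` does exactly this is an assumption. The hypothetical helper below illustrates that selection rule on a per-epoch history of logged metrics.

```python
from typing import Dict, List


def best_epoch(history: List[Dict[str, float]], metric: str = "val/val_loss",
               mode: str = "min") -> int:
    """Index of the epoch with the best value of `metric`.

    `history` holds one dict of logged metrics per epoch (hypothetical format).
    Lower is better when mode="min" (e.g. a loss); higher when mode="max".
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    values = [epoch[metric] for epoch in history]
    target = min(values) if mode == "min" else max(values)
    return values.index(target)  # first epoch that reaches the best value


history = [{"val/val_loss": 0.41}, {"val/val_loss": 0.36}, {"val/val_loss": 0.38}]
print(best_epoch(history))  # 1
```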

4. Pass the `training_module`, `dm`, and `t_configs` specified above to `train_model` to train the model.

In [None]:
from relax.trainer import train_model

params, opt_state = train_model(
    training_module, dm, t_configs
)

Epoch 9: 100%|██████████| 96/96 [00:01<00:00, 87.87batch/s, train/train_loss_1=0.038] 


`train_model` returns the trained parameters `params` and the optimizer state `opt_state`. We use `params` in the next section.
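For intuition about what a call like `train_model` does, here is a generic minibatch training loop in plain Python. It fits a logistic-regression model with stochastic gradient descent, making `n_epochs` passes over shuffled batches of size `batch_size`. This is a simplified stand-in, not ReLax's implementation, which trains the network described by `m_configs`. The function name `train_logreg` is hypothetical.

```python
import math
from random import Random  # avoids shadowing the `random` used for random.PRNGKey in this notebook
from typing import List, Sequence, Tuple


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_logreg(X: Sequence[Sequence[float]], y: Sequence[int],
                 n_epochs: int = 10, batch_size: int = 256, lr: float = 0.1,
                 seed: int = 0) -> Tuple[List[float], float]:
    """Minibatch SGD on the mean binary cross-entropy; returns (weights, bias)."""
    rng = Random(seed)
    n, d = len(X), len(X[0])
    w, b = [0.0] * d, 0.0
    order = list(range(n))
    for _ in range(n_epochs):
        rng.shuffle(order)  # visit the samples in a new order every epoch
        for start in range(0, n, batch_size):
            batch = order[start:start + batch_size]
            grad_w, grad_b = [0.0] * d, 0.0
            for i in batch:
                logit = sum(wj * xj for wj, xj in zip(w, X[i])) + b
                err = sigmoid(logit) - y[i]  # d(BCE)/d(logit) for one sample
                for j in range(d):
                    grad_w[j] += err * X[i][j]
                grad_b += err
            m = len(batch)  # the last batch of an epoch may be smaller
            w = [wj - lr * gj / m for wj, gj in zip(w, grad_w)]
            b -= lr * grad_b / m
    return w, b
```

For example, `train_logreg([[-2.0], [-1.0], [1.0], [2.0]], [0, 0, 1, 1], n_epochs=200, batch_size=2, lr=0.5)` learns a positive weight that separates the two classes.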

## Generate Counterfactual Examples


1. Set up the prediction function using the trained parameters.

In [None]:
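# Fix the trained params and PRNG key; is_training=False turns off
# training-only behavior such as dropout, so only the input x varies.
pred_fn = lambda x: training_module.forward(
    params, random.PRNGKey(0), x, is_training=False)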
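The lambda binds the trained `params`, the PRNG key, and `is_training=False`, so downstream code only has to supply the inputs `x`. The same pattern is sketched below in plain Python. A hypothetical `toy_forward(params, x)` stands in for `training_module.forward`, and `functools.partial` binds the parameters just as the lambda's closure does.

```python
import math
from functools import partial
from typing import List, Sequence, Tuple

Params = Tuple[List[float], float]  # (weights, bias) of the toy model


def toy_forward(params: Params, x: Sequence[float]) -> float:
    """Hypothetical stand-in for training_module.forward: returns P(y=1 | x)."""
    w, b = params
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    return 1.0 / (1.0 + math.exp(-z))


toy_params: Params = ([2.0, -1.0], 0.0)

# Bind the trained parameters once so callers only pass the input x.
toy_pred_fn = partial(toy_forward, toy_params)

print(toy_pred_fn([0.0, 0.0]))  # 0.5
```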

2. Set up the counterfactual configurations. `n_steps` is the number of optimization steps and `lr` is the learning rate of the counterfactual search.

In [None]:
cf_configs = {
    'n_steps': 1000,
    'lr': 0.001
}

3. Set up the counterfactual method. Here we use `VanillaCF`.

In [None]:
from relax.methods import VanillaCF

cf_exp = VanillaCF(cf_configs)
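The sketch below shows a generic gradient-based counterfactual search in the spirit of Wachter et al. (2017). It illustrates how hyperparameters like `n_steps` and `lr` can enter such a method, but it is not `VanillaCF`'s actual code. The toy logistic model, the distance weight `lam`, and `immutable_mask` are hypothetical, and the defaults are chosen for the toy rather than taken from `cf_configs`. The search runs `n_steps` of gradient descent on a squared prediction loss toward the target class plus an L1 distance penalty, and it skips updates for immutable features so they never change.

```python
import math
from typing import List, Sequence


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def gradient_cf(x0: Sequence[float], w: Sequence[float], b: float,
                target: float = 1.0, n_steps: int = 1000, lr: float = 0.1,
                lam: float = 0.01,
                immutable_mask: Sequence[int] = ()) -> List[float]:
    """Gradient-descent counterfactual search for a toy logistic model.

    Minimizes (p(x') - target)^2 + lam * ||x' - x0||_1 with
    p(x') = sigmoid(w . x' + b). Features with immutable_mask[j] == 1 are frozen.
    """
    mask = list(immutable_mask) or [0] * len(x0)
    x = list(x0)
    for _ in range(n_steps):
        p = sigmoid(sum(wj * xj for wj, xj in zip(w, x)) + b)
        # Chain rule: d/dx_j (p - target)^2 = 2 (p - target) p (1 - p) w_j
        pred_grad = 2.0 * (p - target) * p * (1.0 - p)
        for j in range(len(x)):
            if mask[j]:
                continue  # immutable feature: never updated
            diff = x[j] - x0[j]
            # Subgradient of the L1 term (0 while the feature is unchanged)
            l1_grad = (diff > 0) - (diff < 0)
            x[j] -= lr * (pred_grad * w[j] + lam * l1_grad)
    return x
```

For example, `gradient_cf([-1.0, -1.0], w=[1.0, 1.0], b=0.0, immutable_mask=[0, 1])` moves only the first feature until the toy model predicts the target class. A larger `lam` keeps the counterfactual closer to the input at the cost of a weaker prediction change.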

4. Generate counterfactual examples.

In [None]:
from relax.evaluate import generate_cf_explanations

cf_results = generate_cf_explanations(cf_exp, dm, pred_fn)

100%|██████████| 1000/1000 [00:05<00:00, 182.98it/s]


## Benchmark the Counterfactual Method

After we obtain the counterfactual results, we can use `benchmark_cfs` to evaluate the accuracy of the classifier and the validity and proximity of the counterfactual examples.

In [None]:
from relax.evaluate import benchmark_cfs

In [None]:
benchmark_cfs([cf_results])

| Dataset | Method | acc | validity | proximity |
|---|---|---|---|---|
| adult | VanillaCF | 0.825574 | 0.866724 | 7.6325455 |
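The table's columns can be read roughly as follows: `acc` is the classifier's accuracy, `validity` is the fraction of counterfactuals that actually flip the prediction, and `proximity` is an average distance between each input and its counterfactual. The exact definitions used by `benchmark_cfs` are assumptions here, including which distance is used and whether it is computed on encoded features. The plain-Python sketch below uses a 0.5 decision threshold and the L1 distance.

```python
from typing import Callable, Sequence

PredFn = Callable[[Sequence[float]], float]


def accuracy(pred_fn: PredFn, X: Sequence[Sequence[float]], y: Sequence[int]) -> float:
    # Fraction of inputs whose thresholded prediction matches the label.
    return sum(int(pred_fn(x) > 0.5) == yi for x, yi in zip(X, y)) / len(y)


def validity(pred_fn: PredFn, X: Sequence[Sequence[float]],
             X_cf: Sequence[Sequence[float]]) -> float:
    # Fraction of counterfactuals whose predicted class differs from the input's.
    flipped = [int(pred_fn(x) > 0.5) != int(pred_fn(c) > 0.5) for x, c in zip(X, X_cf)]
    return sum(flipped) / len(flipped)


def proximity(X: Sequence[Sequence[float]], X_cf: Sequence[Sequence[float]]) -> float:
    # Mean L1 distance between each input and its counterfactual.
    return sum(sum(abs(a - b) for a, b in zip(x, c)) for x, c in zip(X, X_cf)) / len(X)
```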
