# Interfaces

In [None]:
#| default_exp methods.base

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc

In [None]:
#| export
from __future__ import annotations
from relax.import_essentials import *
from relax.data import TabularDataModule
from relax.trainer import TrainingConfigs
from copy import deepcopy

In [None]:
#| export
class BaseCFModule(ABC):
    """Base CF Explanation Module."""
    # Assigned by `hook_data_module`; the annotation alone does not create the
    # attribute, so it is unset until that method is called.
    _data_module: TabularDataModule

    @property
    @abstractmethod
    def name(self):
        """Name of the CF Explanation Module."""
        raise NotImplementedError

    @property
    def data_module(self) -> TabularDataModule:
        """Bound `TabularDataModule`; call `hook_data_module` before reading it."""
        try:
            return self._data_module
        except AttributeError:
            # Replace the opaque "no attribute '_data_module'" error with one
            # that tells the caller how to fix it (same exception type).
            raise AttributeError(
                f"`{type(self).__name__}` has no bound data module; "
                "call `hook_data_module(data_module)` first."
            ) from None

    @abstractmethod
    def generate_cfs(
        self,
        X: jnp.ndarray, # Input to be explained
        pred_fn: Callable | None = None # Predictive function 
    ) -> jnp.ndarray: # Generated counterfactuals
        """Abstract method to generate counterfactuals."""
        pass

    def hook_data_module(self, data_module: TabularDataModule) -> None:
        """Bind `data_module` to this module (stored as `self._data_module`)."""
        self._data_module = data_module


In [None]:
show_doc(BaseCFModule.name, title_level=4)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/methods/base.py#L17){target="_blank" style="float:right; font-size:smaller"}

#### BaseCFModule.name

>      BaseCFModule.name ()

Name of the CF Explanation Module.

In [None]:
show_doc(BaseCFModule.data_module, title_level=4)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/methods/base.py#L22){target="_blank" style="float:right; font-size:smaller"}

#### BaseCFModule.data_module

>      BaseCFModule.data_module ()

Bound `TabularDataModule`.

In [None]:
show_doc(BaseCFModule.generate_cfs, title_level=4)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/methods/base.py#L26){target="_blank" style="float:right; font-size:smaller"}

#### BaseCFModule.generate_cfs

>      BaseCFModule.generate_cfs (X:jax._src.numpy.ndarray.ndarray,
>                                 pred_fn:Callable=None)

Abstract method to generate counterfactuals

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| X | jnp.ndarray |  | Input to be explained |
| pred_fn | Callable | None | Predictive function |
| **Returns** | **jnp.ndarray** |  | **Generated counterfactuals** |

In [None]:
show_doc(BaseCFModule.hook_data_module, title_level=4)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/methods/base.py#L34){target="_blank" style="float:right; font-size:smaller"}

#### BaseCFModule.hook_data_module

>      BaseCFModule.hook_data_module
>                                     (data_module:cfnet.data.module.TabularData
>                                     Module)

Bind `TabularDataModule` to `self._data_module`.

In [None]:
#| export
class BaseParametricCFModule(ABC):
    """Interface for CF Modules that expose a `train` step.

    Subclasses must implement `train` and `_is_module_trained`.
    """

    @abstractmethod
    def train(
        self, 
        datamodule: TabularDataModule, # data module
        t_configs: TrainingConfigs | dict | None = None, # training configs; see docs in `TrainingConfigs`
        pred_fn: Callable | None = None # predictive function
    ): 
        """Train the CF module on `datamodule` (implemented by subclasses)."""
        pass

    @abstractmethod
    def _is_module_trained(self) -> bool:
        """Report whether the module is trained; semantics are defined by subclasses."""
        pass

In [None]:
show_doc(BaseParametricCFModule.train, title_level=4)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/methods/base.py#L48){target="_blank" style="float:right; font-size:smaller"}

#### BaseParametricCFModule.train

>      BaseParametricCFModule.train
>                                    (datamodule:cfnet.data.module.TabularDataMo
>                                    dule, t_configs:Union[cfnet.train.TrainingC
>                                    onfigs,dict]=None, pred_fn:Callable=None)

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| datamodule | TabularDataModule |  | data module |
| t_configs | TrainingConfigs \| dict | None | training configs; see docs in `TrainingConfigs` |
| pred_fn | Callable | None | predictive function |

In [None]:
#| export
class BasePredFnCFModule(ABC):
    """Base class of CF Module with a predictive module."""
    @abstractmethod
    def pred_fn(
        self, 
        X: jnp.ndarray  # input `X`
    ) -> jnp.ndarray:   # prediction
        """Predict on input `X`; must be implemented by subclasses."""
        # `jnp.ndarray` (rather than the removed `jnp.DeviceArray`) keeps this
        # annotation valid on current JAX and consistent with `generate_cfs`.
        raise NotImplementedError

In [None]:
show_doc(BasePredFnCFModule.pred_fn, title_level=4)

---

[source](https://github.com/birkhoffg/cfnet/tree/master/blob/master/cfnet/methods/base.py#L62){target="_blank" style="float:right; font-size:smaller"}

#### BasePredFnCFModule.pred_fn

>      BasePredFnCFModule.pred_fn (X:jaxlib.xla_extension.DeviceArrayBase)

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| X | jnp.DeviceArray | input `X` |
| **Returns** | **jnp.DeviceArray** | **prediction** |