# Interfaces

In [None]:
#| default_exp methods.base

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
from nbdev import show_doc

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
#| export
from __future__ import annotations
from cfnet.import_essentials import *
from cfnet.datasets import TabularDataModule
from cfnet.train import TrainingConfigs
from copy import deepcopy

In [None]:
#| export
class BaseCFModule(ABC):
    """Abstract base class for counterfactual (CF) modules.

    Subclasses must provide a `name` property and implement `generate_cfs`.
    Categorical and immutable-feature metadata is copied from a data module
    via `update_cat_info`.
    """

    # Class-level defaults used until `update_cat_info` is called.
    # NOTE: the list defaults are shared by all instances; `update_cat_info`
    # rebinds instance attributes rather than mutating these in place.
    cat_arrays = []
    cat_idx = 0
    # Default for consistency with the two attributes above, so reading it
    # before `update_cat_info` runs no longer raises AttributeError.
    imutable_idx_list = []

    @property
    @abstractmethod
    def name(self):
        """Identifier of this CF module; implemented by subclasses."""
        pass

    @abstractmethod
    def generate_cfs(
        self,
        X: jnp.ndarray,  # input `X`
        pred_fn: Callable | None = None  # optional prediction function
    ) -> jnp.ndarray:
        """Generate counterfactuals for `X`; implemented by subclasses."""
        pass

    def update_cat_info(self, data_module: TabularDataModule):
        """Copy categorical / immutable-feature metadata from `data_module`.

        `cat_arrays` and `imutable_idx_list` are deep-copied, so later changes
        to the data module's lists do not leak into this module (and vice
        versa). `cat_idx` is assigned directly.
        """
        # TODO: need refactor
        self.cat_arrays = deepcopy(data_module.cat_arrays)
        self.cat_idx = data_module.cat_idx
        self.imutable_idx_list = deepcopy(data_module.imutable_idx_list)


In [None]:
#| export
class BaseParametricCFModule(ABC):
    """Base class for parametric CF modules.

    Subclasses implement `train` and report whether training has happened
    through `_is_module_trained`.
    """

    @abstractmethod
    def train(
        self,
        datamodule: TabularDataModule, # data module
        t_configs: TrainingConfigs | dict | None = None # training configs; see docs in `TrainingConfigs`
    ):
        """Train the module on `datamodule`, configured by `t_configs`."""
        pass

    @abstractmethod
    def _is_module_trained(self) -> bool:
        """Return whether the module has been trained."""
        pass

In [None]:
# Render the API docs (signature and parameter table) for the abstract `train` method.
show_doc(BaseParametricCFModule.train)

---

### BaseParametricCFModule.train

>      BaseParametricCFModule.train
>                                    (datamodule:cfnet.datasets.TabularDataModule,
>                                    t_configs:Union[cfnet.train.TrainingConfigs,dict]=None)

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| datamodule | TabularDataModule |  | data module |
| t_configs | TrainingConfigs \| dict | None | training configs; see docs in `TrainingConfigs` |

In [None]:
#| export
class BasePredFnCFModule(ABC):
    """Base class of CF Module with a predictive module."""

    @abstractmethod
    def pred_fn(
        self,
        X: jnp.ndarray  # input `X`
    ) -> jnp.ndarray:   # prediction
        """Return predictions for `X`.

        Subclasses must override this. Calling the base implementation
        (e.g. via `super()`) raises `NotImplementedError`.
        """
        # `jnp.ndarray` replaces the deprecated/removed `jnp.DeviceArray`,
        # matching the annotations used elsewhere in this module.
        raise NotImplementedError

In [None]:
# Render the API docs (signature and parameter table) for the abstract `pred_fn` method.
show_doc(BasePredFnCFModule.pred_fn)

---

### BasePredFnCFModule.pred_fn

>      BasePredFnCFModule.pred_fn (X:jaxlib.xla_extension.DeviceArrayBase)

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| X | jnp.DeviceArray | input `X` |
| **Returns** | **jnp.DeviceArray** | **prediction** |