In [None]:
#| default_exp methods.base

In [None]:
#| export
from relax.import_essentials import *
from relax.base import BaseConfig, BaseModule, PredFnMixedin, TrainableMixedin

In [None]:
#| export
class CFModule(BaseModule):
    """Base class for all counterfactual modules.

    Subclasses implement `generate_cf`. A module may also hold two optional
    user-supplied callables:

    * `apply_constraints_fn`, which `apply_constraints` delegates to;
    * `apply_regularization_fn`, which `apply_regularization` delegates to.

    While a callable is unset (`None`), the corresponding `apply_*` method
    does nothing and returns `None`.
    """

    def __init__(
        self,
        config,
        *,
        name: str = None,
        apply_constraints_fn: Callable = None,
        apply_regularization_fn: Callable = None,
    ):
        """
        Args:
            config: Module configuration, forwarded to `BaseModule.__init__`.
            name: Optional module name, forwarded to `BaseModule.__init__`.
            apply_constraints_fn: Callable invoked by `apply_constraints`.
            apply_regularization_fn: Callable invoked by `apply_regularization`.
        """
        super().__init__(config, name=name)
        self.apply_constraints_fn = apply_constraints_fn
        self.apply_regularization_fn = apply_regularization_fn

    def init_apply_fns(
        self,
        apply_constraints_fn: Callable = None,
        apply_regularization_fn: Callable = None,
    ):
        """Install default callables without overriding ones already set.

        Each argument is stored only if the corresponding attribute is still
        `None` and the argument itself is not `None`. This means functions
        passed to `__init__` take precedence over these defaults.
        """
        if self.apply_constraints_fn is None and apply_constraints_fn is not None:
            self.apply_constraints_fn = apply_constraints_fn
        if self.apply_regularization_fn is None and apply_regularization_fn is not None:
            self.apply_regularization_fn = apply_regularization_fn

    def apply_constraints(self, *args, **kwargs):
        """Call `apply_constraints_fn` with the given arguments.

        Returns:
            Whatever `apply_constraints_fn` returns, or `None` if it is unset.
        """
        # The result must be propagated: functional (e.g. JAX) constraint
        # functions return a new value rather than mutating their input.
        if self.apply_constraints_fn is not None:
            return self.apply_constraints_fn(*args, **kwargs)
        return None

    def apply_regularization(self, *args, **kwargs):
        """Call `apply_regularization_fn` with the given arguments.

        Returns:
            Whatever `apply_regularization_fn` returns, or `None` if it is unset.
        """
        # Propagate the result (e.g. a regularization term) to the caller.
        if self.apply_regularization_fn is not None:
            return self.apply_regularization_fn(*args, **kwargs)
        return None

    def generate_cf(
        self,
        x: Array,
        pred_fn: Callable = None,
        pred_fn_args: Dict = None,
    ) -> Array:  # Return counterfactual of x.
        """Generate a counterfactual for `x`. Must be overridden by subclasses.

        Args:
            x: Input to explain.
            pred_fn: Prediction function of the model being explained.
            pred_fn_args: Extra arguments for `pred_fn`.

        Returns:
            The counterfactual of `x`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement `generate_cf`."
        )

In [None]:
#| export
class ParametricCFModule(CFModule, TrainableMixedin):
    """Base class for parametric counterfactual modules.

    Combines the `CFModule` interface with `TrainableMixedin`. It adds no
    behavior of its own and serves as a common base type for trainable
    counterfactual generators.
    """