In [None]:
#| default_exp methods.base

In [None]:
#| export
from relax.import_essentials import *
from relax.base import BaseConfig, BaseModule, PredFnMixedin, TrainableMixedin

In [None]:
#| export
def default_apply_constraints_fn(x, cf, hard, **kwargs):
    """Identity constraint hook: hand back `cf` exactly as received.

    `x`, `hard` and any extra keyword arguments exist only so this function
    shares the call signature of custom constraint hooks; they are ignored.
    """
    del x, hard, kwargs  # unused; kept for signature compatibility
    unconstrained_cf = cf
    return unconstrained_cf

def default_compute_reg_loss_fn(x, cf, **kwargs):
    """No-op regularization hook: always contribute a loss of zero.

    The arguments are accepted only so the signature matches custom
    regularization hooks; none of them influence the result.
    """
    del x, cf, kwargs  # unused; kept for signature compatibility
    zero_loss = 0.0
    return zero_loss

In [None]:
#| export
class CFModule(BaseModule):
    """Base class for all counterfactual modules.

    Subclasses must implement `generate_cf`. Two optional hooks may be
    supplied either to `__init__` or later through `init_fns`:

    * `apply_constraints_fn(x, cf, hard, **kwargs)` -- post-processes a
      candidate counterfactual; the module default returns `cf` unchanged.
    * `compute_reg_loss_fn(x, cf, **kwargs)` -- returns a regularization
      loss term; the module default returns `0.`.
    """

    def __init__(
        self, 
        config,
        *, 
        name: str = None,
        apply_constraints_fn = None,
        compute_reg_loss_fn = None
    ):
        super().__init__(config, name=name)
        # Hooks passed here take precedence over those given to `init_fns`.
        self.apply_constraints_fn = apply_constraints_fn
        self.compute_reg_loss_fn = compute_reg_loss_fn

    def init_fns(
        self,
        apply_constraints_fn = None,
        compute_reg_loss_fn = None
    ):
        """Fill in any hook that is not already set on this instance.

        For each hook the resolution order is:
        1. a hook already set on the instance (e.g. given to `__init__`),
        2. the function passed to this method,
        3. the module-level default (`default_apply_constraints_fn` /
           `default_compute_reg_loss_fn`).
        """
        # Only fill gaps: never overwrite a hook the caller already chose.
        if self.apply_constraints_fn is None:
            self.apply_constraints_fn = (
                apply_constraints_fn
                if apply_constraints_fn is not None
                else default_apply_constraints_fn
            )
        if self.compute_reg_loss_fn is None:
            self.compute_reg_loss_fn = (
                compute_reg_loss_fn
                if compute_reg_loss_fn is not None
                else default_compute_reg_loss_fn
            )
    
    def apply_constraints(self, *args, **kwargs):
        """Call `apply_constraints_fn` with the given arguments and return its result.

        Returns None when no hook has been set yet (see `init_fns`).
        """
        if self.apply_constraints_fn is None:
            return None
        return self.apply_constraints_fn(*args, **kwargs)
    
    def compute_reg_loss(self, *args, **kwargs):
        """Call `compute_reg_loss_fn` with the given arguments and return its result.

        Returns None when no hook has been set yet (see `init_fns`).
        """
        if self.compute_reg_loss_fn is None:
            return None
        return self.compute_reg_loss_fn(*args, **kwargs)

    def generate_cf(
        self,
        x: Array,
        pred_fn: Callable = None,
        y_target: Array = None,
        rng_key: jrand.PRNGKey = None,
        **kwargs
    ) -> Array:
        """Return a counterfactual of `x`.

        Abstract here; every concrete counterfactual module must override it.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

In [None]:
#| export
class ParametricCFModule(CFModule, TrainableMixedin):
    """Base class for parametric counterfactual modules.

    Mixes `TrainableMixedin` into `CFModule`; it defines no members of its
    own, so all behavior comes from those two bases.
    """