# Technical Details of Our Approach {.appendix} 


## Generating Counterfactuals through Gradient Descent {#sec-app-ce}

In this section, we provide some background on gradient-based counterfactual generators (@sec-app-ce-background) and discuss how we define convergence in this context (@sec-app-conv).

### Background {#sec-app-ce-background}

Gradient-based counterfactual search was originally proposed by @wachter2017counterfactual. It generally solves the following unconstrained objective,

$$
\begin{aligned}
\min_{\mathbf{z}^\prime \in \mathcal{Z}^L} \left\{  {\text{yloss}(\mathbf{M}_{\theta}(g(\mathbf{z}^\prime)),\mathbf{y}^+)}+ \lambda {\text{cost}(g(\mathbf{z}^\prime)) }  \right\} 
\end{aligned} 
$$

where $g: \mathcal{Z} \mapsto \mathcal{X}$ is an invertible function that maps from the $L$-dimensional counterfactual state space to the feature space and $\text{cost}(\cdot)$ denotes one or more penalties that are used to induce certain properties of the counterfactual outcome. As above, $\mathbf{y}^+$ denotes the target output and $\mathbf{M}_{\theta}(\mathbf{x})$ returns the logit predictions of the underlying classifier for $\mathbf{x}=g(\mathbf{z})$.

For all generators used in this work, we use standard logit cross-entropy loss for $\text{yloss}(\cdot)$. All generators also penalize the distance ($\ell_1$-norm) of counterfactuals from their original factual state. For *Generic* and *ECCo*, we have $\mathcal{Z}:=\mathcal{X}$ and $g(\mathbf{z})=g^{-1}(\mathbf{z})=\mathbf{z}$, that is, counterfactuals are searched directly in the feature space. Conversely, *REVISE* traverses the latent space of a variational autoencoder (VAE) fitted to the training data, where $g(\cdot)$ corresponds to the decoder [@joshi2019realistic]. In addition to the distance penalty, *ECCo* uses a further penalty component that regularizes the energy associated with the counterfactual, $\mathbf{x}^\prime$ [@altmeyer2024faithful].
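To make the objective concrete, the following is a minimal sketch of the *Generic* case ($g(\mathbf{z})=\mathbf{z}$) for a toy linear classifier, with gradients derived by hand rather than through automatic differentiation. The classifier, the parameter values and the helper names (`logits`, `objective_grad`, `generic_search`) are assumptions made for illustration only; they are not part of our experimental code.

```julia
# Numerically stable softmax over a vector of logits.
softmax(z) = (e = exp.(z .- maximum(z)); e ./ sum(e))

# Hypothetical linear classifier returning logits: M_θ(x) = Θx + b.
logits(Θ, b, x) = Θ * x .+ b

# Gradient of yloss + λ * cost with respect to x′, where yloss is the
# cross-entropy between softmax(M_θ(x′)) and the one-hot target and cost is
# the ℓ1 distance to the factual x (search directly in feature space).
function objective_grad(Θ, b, x′, x, target, λ)
    p = softmax(logits(Θ, b, x′))
    onehot = zeros(length(p))
    onehot[target] = 1.0
    ∇yloss = Θ' * (p .- onehot)   # chain rule through the linear logits
    ∇cost = sign.(x′ .- x)        # subgradient of the ℓ1 norm
    return ∇yloss .+ λ .* ∇cost
end

# Plain gradient descent over a fixed number of iterations T.
function generic_search(Θ, b, x, target; λ=0.1, η=0.1, T=100)
    x′ = copy(x)
    for _ in 1:T
        x′ .-= η .* objective_grad(Θ, b, x′, x, target, λ)
    end
    return x′
end

# Toy example: the factual is predicted as class 1, the target is class 2.
Θ = [1.0 -1.0; -1.0 1.0]
b = [0.0, 0.0]
x = [2.0, -1.0]
x′ = generic_search(Θ, b, x, 2)
```

In this sketch, $\lambda$ trades off validity against proximity: larger values keep $\mathbf{x}^\prime$ closer to the factual at the expense of a lower predicted probability for the target class.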

### Convergence {#sec-app-conv}

An important consideration when generating counterfactual explanations using gradient-based methods is how to define convergence. Two common choices are to 1) perform gradient descent over a fixed number of iterations $T$, or 2) conclude the search as soon as the predicted probability for the target class has reached a pre-determined threshold, $\tau$: $\mathcal{S}(\mathbf{M}_\theta(\mathbf{x}^\prime))[y^+] \geq \tau$. We prefer the latter for our purposes, because it explicitly defines convergence in terms of the black-box model, $\mathbf{M}_\theta(\mathbf{x})$.

Defining convergence in this way allows for a more intuitive interpretation of the resulting counterfactual outcomes than with fixed $T$. Specifically, it allows us to think of counterfactuals as explaining 'high-confidence' predictions by the model for the target class $y^+$. Depending on the context and application, different choices of $\tau$ can be considered as representing 'high-confidence' predictions.
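The threshold-based criterion can be sketched as follows. This is an illustrative stand-alone implementation, not the one used by our software: `model` and `step` are hypothetical placeholders for a function returning logits and for a single search update, respectively, and the search is additionally capped at `max_iter` iterations.

```julia
# Numerically stable softmax over a vector of logits.
softmax(z) = (e = exp.(z .- maximum(z)); e ./ sum(e))

# Convergence criterion: predicted probability of the target class is at least τ.
threshold_reached(logits, target, τ) = softmax(logits)[target] >= τ

# Apply `step` until the threshold is reached or `max_iter` updates were made.
function search_until_confident(model, step, x; target, τ=0.75, max_iter=30)
    x′ = copy(x)
    for t in 1:max_iter
        if threshold_reached(model(x′), target, τ)
            return (counterfactual=x′, iterations=t - 1, converged=true)
        end
        x′ = step(x′)
    end
    converged = threshold_reached(model(x′), target, τ)
    return (counterfactual=x′, iterations=max_iter, converged=converged)
end

# Toy one-dimensional example: logits are (-x, x) and each step moves x by 0.25.
model = x -> [-x[1], x[1]]
step = x -> x .+ 0.25
result = search_until_confident(model, step, [-1.0]; target=2, τ=0.75)
```

Unlike a fixed $T$, the number of iterations here varies across samples: factuals that lie far from the decision boundary need more updates to reach the same level of confidence.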


```{julia}
#| eval: false

projectdir = splitpath(pwd()) |>
    ss -> joinpath(ss[1:findall([s == "CounterfactualTraining.jl" for s in ss])[1]]...) 
cd(projectdir)

using CTExperiments
using CTExperiments.CounterfactualExplanations
using CTExperiments.CounterfactualTraining
using CTExperiments.Flux
using CTExperiments.Plots
using CTExperiments.TaijaPlotting
using Plots.PlotMeasures
```

## Protecting Mutability Constraints with Linear Classifiers {#sec-app-constraints}

In @sec-constraints we explain that to avoid penalizing implausibility that arises due to mutability constraints, we impose a point mass prior on $p(\mathbf{x})$ for the corresponding feature. We argue in @sec-constraints that this approach induces models to be less sensitive to immutable features and demonstrate this empirically in @sec-experiments. Below we derive the analytical results in Prp.~\ref{prp-mtblty}.

::: {.proof}

Let $d_{\text{mtbl}}$ and $d_{\text{immtbl}}$ denote some mutable and immutable feature, respectively. Suppose that $\mu_{y^-,d_{\text{immtbl}}} < \mu_{y^+,d_{\text{immtbl}}}$ and $\mu_{y^-,d_{\text{mtbl}}} > \mu_{y^+,d_{\text{mtbl}}}$, where $\mu_{k,d}$ denotes the conditional sample mean of feature $d$ in class $k$. In words, we assume that the immutable feature tends to take lower values for samples in the non-target class $y^-$ than in the target class $y^+$. We assume the opposite to hold for the mutable feature.

Assuming multivariate Gaussian class densities with common diagonal covariance matrix $\Sigma_k=\Sigma$ for all $k \in \mathcal{K}$, we have for the log likelihood ratio between any two classes $k,m \in \mathcal{K}$ [@hastie2009elements]:

$$
\log \frac{p(k|\mathbf{x})}{p(m|\mathbf{x})}=\mathbf{x}^\intercal \Sigma^{-1}(\mu_{k}-\mu_{m})  + \text{const}
$$ {#eq-loglike}

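Because $\Sigma$ is diagonal, the features $x_1,\dots,x_D$ are conditionally independent given the class, so the full log-likelihood ratio decomposes into a sum over features, where $\sigma_d^2$ denotes the $d$-th diagonal element of $\Sigma$: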

$$
\log \frac{p(k|\mathbf{x})}{p(m|\mathbf{x})} = \sum_{d=1}^D \frac{\mu_{k,d}-\mu_{m,d}}{\sigma_{d}^2} x_{d} + \text{const}
$$ {#eq-loglike-decomp}

By the properties of our classifier (*multinomial logistic regression*), we have:

$$
\log \frac{p(k|\mathbf{x})}{p(m|\mathbf{x})} = \sum_{d=1}^D \left( \theta_{k,d} - \theta_{m,d} \right)x_d + \text{const}
$$ {#eq-multi}

where $\theta_{k,d}=\Theta[k,d]$ denotes the coefficient on feature $d$ for class $k$. 

Based on @eq-loglike-decomp and @eq-multi we can identify that $(\mu_{k,d}-\mu_{m,d}) \propto (\theta_{k,d} - \theta_{m,d})$ under the assumptions we made above. Hence, we have that $(\theta_{y^-,d_{\text{immtbl}}} - \theta_{y^+,d_{\text{immtbl}}}) < 0$ and $(\theta_{y^-,d_{\text{mtbl}}} - \theta_{y^+,d_{\text{mtbl}}}) > 0$.

Let $\mathbf{x}^\prime$ denote some randomly chosen individual from class $y^-$ and let $y^+ \sim p(y)$ denote the randomly chosen target class. Then the partial derivative of the contrastive divergence penalty [@eq-div] with respect to coefficient $\theta_{y^+,d}$ is equal to 

$$
\frac{\partial}{\partial\theta_{y^+,d}} \left(\text{div}(\mathbf{x},\mathbf{x^\prime},\mathbf{y};\theta)\right) = \frac{\partial}{\partial\theta_{y^+,d}} \left( \left(-\mathbf{M}_\theta(\mathbf{x})[y^+]\right) - \left(-\mathbf{M}_\theta(\mathbf{x}^\prime)[y^+]\right) \right) = x_{d}^\prime - x_{d}
$$ {#eq-grad}

and equal to zero everywhere else.

Since $(\mu_{y^-,d_{\text{immtbl}}} < \mu_{y^+,d_{\text{immtbl}}})$ we are more likely to have $(x_{d_{\text{immtbl}}}^\prime - x_{d_{\text{immtbl}}}) < 0$ than vice versa at initialization. Similarly, we are more likely to have $(x_{d_{\text{mtbl}}}^\prime - x_{d_{\text{mtbl}}}) > 0$ since $(\mu_{y^-,d_{\text{mtbl}}} > \mu_{y^+,d_{\text{mtbl}}})$.

This implies that if we do not protect feature $d_{\text{immtbl}}$, gradient descent on the contrastive divergence penalty will tend to increase $\theta_{y^+,d_{\text{immtbl}}}$, since the corresponding gradient in @eq-grad is more likely to be negative, thereby exacerbating the existing effect $(\theta_{y^-,d_{\text{immtbl}}} - \theta_{y^+,d_{\text{immtbl}}}) < 0$. In words, not protecting the immutable feature would have the undesirable effect of making the classifier more sensitive to this feature, in that it would be more likely to predict class $y^-$ as opposed to $y^+$ for lower values of $d_{\text{immtbl}}$.

By the same rationale, the contrastive divergence penalty can generally be expected to decrease $\theta_{y^+,d_{\text{mtbl}}}$, since the corresponding gradient is more likely to be positive, exacerbating $(\theta_{y^-,d_{\text{mtbl}}} - \theta_{y^+,d_{\text{mtbl}}}) > 0$. In words, this has the effect of making the classifier more sensitive to the mutable feature, in that it would be more likely to predict class $y^-$ as opposed to $y^+$ for higher values of $d_{\text{mtbl}}$.

Thus, our proposed approach of protecting feature $d_{\text{immtbl}}$ has the net effect of decreasing the classifier's sensitivity to the immutable feature relative to the mutable feature (i.e. no change in sensitivity for $d_{\text{immtbl}}$ relative to increased sensitivity for $d_{\text{mtbl}}$).

:::
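The gradient in @eq-grad and the sign argument in the proof can be checked numerically. The sketch below assumes a linear classifier without intercept, unit-variance Gaussian class densities and hypothetical class means chosen to satisfy the assumptions of the proof (feature 1 plays the role of $d_{\text{immtbl}}$, feature 2 that of $d_{\text{mtbl}}$). All names and values are illustrative assumptions.

```julia
using Random, Statistics

# Linear classifier without intercept: logits M_θ(x) = Θx.
logits(Θ, x) = Θ * x

# Divergence term from the proof for the target class index `target`.
div_term(Θ, x, x′, target) = -logits(Θ, x)[target] + logits(Θ, x′)[target]

# Central finite difference of the divergence term with respect to Θ[k, d].
function fd_grad(Θ, x, x′, target, k, d; h=1e-6)
    Θ_hi = copy(Θ); Θ_hi[k, d] += h
    Θ_lo = copy(Θ); Θ_lo[k, d] -= h
    return (div_term(Θ_hi, x, x′, target) - div_term(Θ_lo, x, x′, target)) / (2h)
end

rng = MersenneTwister(42)
μ_neg = [-1.0, 1.0]   # class y⁻: lower on feature 1, higher on feature 2
μ_pos = [1.0, -1.0]   # class y⁺
n = 10_000
X = μ_pos .+ randn(rng, 2, n)    # samples x from the target class y⁺
X′ = μ_neg .+ randn(rng, 2, n)   # initial counterfactuals x′ drawn from y⁻

# Average gradient with respect to θ_{y⁺,d} for d = 1, 2.
mean_grad = vec(mean(X′ .- X; dims=2))

# Finite-difference check for a single pair and random coefficients.
Θ = randn(rng, 2, 2)
x, x′ = X[:, 1], X′[:, 1]
target = 2
g_target = fd_grad(Θ, x, x′, target, target, 1)   # row of the target class
g_other = fd_grad(Θ, x, x′, target, 1, 1)         # any other row
```

Under these assumptions, the average gradient is negative for the first feature and positive for the second, and the finite-difference gradient matches $x_d^\prime - x_d$ for the target-class row while vanishing for the other row.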

```{julia}
#| eval: false

using CTExperiments: 
    get_ce_data, 
    train_val_split, 
    build_model, 
    LinearModel, 
    MLPModel,
    get_input_encoder

# Callback:
using CTExperiments: get_log_reg_params, get_decision_boundary

function plot_db(model, ces; _xlab="Debt", _ylab="Age", plot_contour::Bool=false)

    x = (ce -> ce.counterfactual).(ces) |> xs -> reduce(hcat, xs)
    x0 = (ce -> ce.factual).(ces) |> xs -> reduce(hcat, xs)

    xlab = _constraint[1] == "both" ? "$(_xlab) (mutable)" : "$(_xlab) (immutable)"
    ylab = _constraint[2] == "both" ? "$(_ylab) (mutable)" : "$(_ylab) (immutable)"

    # Data and decision boundary:
    plt = Plots.scatter(
        ce_data, 
        # label=["1=Default" "2=No Default"], 
        # legend_position=:topright, 
        xlab=xlab, ylab=ylab,  
        axis=nothing, 
        legend=false,
        title=_title
    )
    if length(model) == 1
        coeff = get_log_reg_params(model)
        db = get_decision_boundary(coeff)
        Plots.abline!(plt, db.slope, db.intercept; lw=5, label="Dec. Boundary")
    end

    # Counterfactuals:
    target = 2
    factual = 1
    if !any(isnothing.(x))
        yhat = [argmax(y) for y in eachcol(model(x))]
        yhat0 = [argmax(y) for y in eachcol(model(x0))]
        idx_plotted = yhat0.==factual .|| yhat0.==target
        if any(idx_plotted)

            # # Paths:
            # u = []
            # v = []
            # for (i,ce) in enumerate(eachcol(x))
            #     Δ = ce - x0[:,i]
            #     push!(u, Δ[1])
            #     push!(v, Δ[2])
            # end
            # Plots.quiver!(x0[1,idx_plotted], x0[2,idx_plotted], quiver=(u[idx_plotted], v[idx_plotted]), color=yhat0[idx_plotted])
            
            # End points:
            Plots.scatter!(x[1,idx_plotted], x[2,idx_plotted], label=["CE (y⁺=1)" "CE (y⁺=2)"], ms=15, shape=:star, color=yhat[idx_plotted], group=yhat[idx_plotted], mscolor=yhat0[idx_plotted])

        end
    end
    display(plt)
end

# Data:
specs = [
    # ("(a)", VanillaObjective(needs_ce=true), ["both", "both"]),
    # ("(b)", FullObjective(lambda=[1.0,0.5,0.01,0.1]), ["both", "both"]),
    # ("(c)", VanillaObjective(needs_ce=true), ["both", "none"]),
    ("(d)", FullObjective(lambda=[1.0,0.5,0.01,0.1]), ["both", "none"]),
]
```

```{julia}
#| eval: false

using CounterfactualExplanations.Convergence
using Random
Random.seed!(42)

models = []
plts = []

for (title, obj, constraint) in specs

    # Globals for the callback:
    global _title = title
    global _constraint = constraint

    # Data:
    data = LinearlySeparable(
        n_train=3000,
        batchsize=50,
        mutability=constraint
    )
    global ce_data = get_ce_data(data)
    val_size = data.n_validation / (data.n_validation + data.n_train)
    train_set, val_set, _ = train_val_split(data, ce_data, val_size)

    # Model:
    nin = size(first(train_set)[1], 1)
    nout = size(first(train_set)[2], 1)
    model = build_model(MLPModel(), nin, nout)

    # Objective:
    opt = AMSGrad()
    generator = GenericGenerator(opt=Descent(0.1))
    opt_state = Flux.setup(opt, model)
    conv = DecisionThresholdConvergence(decision_threshold=0.75, max_iter=30)

    model, logs = counterfactual_training(
        obj,
        model,
        generator,
        train_set,
        opt_state;
        val_set = val_set,
        nepochs = 100,
        mutability = Symbol.(constraint),
        callback = plot_db,
        nce=1000,
        convergence=conv, 
        verbose=3
    )

    push!(models, model)
    push!(plts, current())
end
```

```{julia}
#| eval: false

plt = Plots.plot(
    plts..., 
    layout=(1,4), 
    size=(1150,250),  
    left_margin = 10mm, 
    bottom_margin = 5mm,
    top_margin = 3mm,
    right_margin = 10mm,
)
display(plt)
Plots.savefig(plt, "paper/figures/poc.svg")
```

```{julia}
#| eval: false

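plts = []
titles = ["(a)", "(b)"]
_xlab = "Existing Debt"
_ylab = "Age"
# Mutability constraints of each trained model (third entry of each spec):
constraints = [spec[3] for spec in specs]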
for (i,model) in enumerate(models)
    coeff = get_log_reg_params(model)
    db = get_decision_boundary(coeff)
    xlab = constraints[i][1] == "both" ? "$(_xlab) (mutable)" : "$(_xlab) (immutable)"
    ylab = constraints[i][2] == "both" ? "$(_ylab) (mutable)" : "$(_ylab) (immutable)"
    plt = Plots.scatter(ce_data; xlab=xlab, ylab=ylab, bottom_margin = 5mm, left_margin = 5mm, label=["Default" "No Default"], title=titles[i])
    Plots.abline!(plt, db.slope, db.intercept; lw=5, label="Dec. Boundary")

    push!(plts, plt)
end

plt = Plots.plot(plts..., layout=(1,2), size=(600,300))
Plots.savefig(plt, "paper/figures/app_mtblty.svg")
```

## Domain Constraints

We apply domain constraints on counterfactuals during training and evaluation. There are at least two good reasons for doing so. Firstly, within the context of explainability and algorithmic recourse, real-world attributes are often domain constrained: the *age* feature, for example, is lower bounded by zero and upper bounded by the maximum human lifespan. Secondly, domain constraints help mitigate training instabilities commonly associated with energy-based modelling [@grathwohl2020your;@altmeyer2024faithful].

For our image datasets, features are pixel values and hence the domain is constrained by the lower and upper bound of values that pixels can take depending on how they are scaled (in our case $[-1,1]$). For all other features $d$ in our synthetic and tabular datasets, we automatically infer domain constraints $[x_d^{\text{LB}},x_d^{\text{UB}}]$  as follows,

$$
\begin{aligned}
x_d^{\text{LB}} &= \min \left\{\mu_d - n_{\sigma_d}\sigma_d, \min_{x_d} x_d\right\} \\
x_d^{\text{UB}} &= \max \left\{\mu_d + n_{\sigma_d}\sigma_d, \max_{x_d} x_d\right\}
\end{aligned}
$$ {#eq-domain}

where $\mu_d$ and $\sigma_d$ denote the sample mean and standard deviation of feature $d$. We set $n_{\sigma_d}=3$ across the board but higher values and hence wider bounds may be appropriate depending on the application.
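A minimal sketch of this inference step is shown below. It reads @eq-domain as taking the smaller (larger) of the two candidate values for the lower (upper) bound, and it assumes that features are stored in rows and samples in columns. The function names `infer_domain` and `apply_domain` are hypothetical and chosen for illustration.

```julia
using Statistics

# Infer per-feature bounds from a (features × samples) matrix X.
function infer_domain(X::AbstractMatrix; nσ=3)
    μ = vec(mean(X; dims=2))
    σ = vec(std(X; dims=2))
    lb = min.(μ .- nσ .* σ, vec(minimum(X; dims=2)))
    ub = max.(μ .+ nσ .* σ, vec(maximum(X; dims=2)))
    return lb, ub
end

# Project a counterfactual back onto the inferred domain.
apply_domain(x′, lb, ub) = clamp.(x′, lb, ub)

# Toy data with two features and five samples.
X = [0.0 1.0 2.0 3.0 4.0;
     10.0 10.0 10.0 10.0 50.0]
lb, ub = infer_domain(X)
x′ = apply_domain([-100.0, 100.0], lb, ub)
```

Clamping after each update is one simple way to enforce the bounds; under this reading they are never tighter than the observed range of the data, so no training sample falls outside its own domain.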


## Training Details {#sec-app-training}

In this section, we describe the training procedure in detail. While the details laid out here are not crucial for understanding our proposed approach, they are of importance to anyone looking to implement counterfactual training. 
