In [None]:
#| eval: false

# Move up to the project root so that relative paths (e.g. "paper/figures") resolve:
projectdir = splitpath(pwd()) |>
    ss -> joinpath(ss[1:findfirst(==("CounterfactualTraining.jl"), ss)]...)
cd(projectdir)

using CTExperiments
using CTExperiments.CounterfactualExplanations
using CTExperiments.CounterfactualTraining
using CTExperiments.Flux
using CTExperiments.Plots
using CTExperiments.TaijaPlotting
using Plots.PlotMeasures

## Protecting Mutability Constraints with Linear Classifiers {#sec-app-constraints}

In @sec-constraints, we explain that, to avoid penalizing implausibility that arises from mutability constraints, we impose a point mass prior on $p(\mathbf{x})$ for the corresponding feature. We argue there that this approach induces models to be less sensitive to immutable features and demonstrate this empirically in @sec-experiments. Below, we derive the analytical result stated in @thm-mtblty.

::: {.proof}

Let $d_{\text{mtbl}}$ and $d_{\text{immtbl}}$ denote a mutable and an immutable feature, respectively. Suppose that $\mu_{y^-,d_{\text{immtbl}}} < \mu_{y^+,d_{\text{immtbl}}}$ and $\mu_{y^-,d_{\text{mtbl}}} > \mu_{y^+,d_{\text{mtbl}}}$, where $\mu_{k,d}$ denotes the conditional sample mean of feature $d$ in class $k$. In words, we assume that the immutable feature tends to take lower values for samples in the non-target class $y^-$ than in the target class $y^+$. We assume the opposite to hold for the mutable feature.

Assuming multivariate Gaussian class densities with a common diagonal covariance matrix $\Sigma_k=\Sigma$ for all $k \in \mathcal{K}$, the log-odds between any two classes $k,m \in \mathcal{K}$ are linear in $\mathbf{x}$ [@hastie2009elements]:

$$
\log \frac{p(k|\mathbf{x})}{p(m|\mathbf{x})}=\mathbf{x}^\intercal \Sigma^{-1}(\mu_{k}-\mu_{m})  + \text{const}
$$ {#eq-loglike}
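
For reference, the constant term can be written out explicitly. The expression below follows the standard linear discriminant analysis result; the symbol $\pi_k$ for the prior probability of class $k$ is notation assumed here and is not used elsewhere in this appendix. The point is only that the constant does not depend on $\mathbf{x}$:

$$
\log \frac{p(k|\mathbf{x})}{p(m|\mathbf{x})} = \mathbf{x}^\intercal \Sigma^{-1}(\mu_{k}-\mu_{m}) + \underbrace{\log \frac{\pi_k}{\pi_m} - \frac{1}{2}(\mu_{k}+\mu_{m})^\intercal \Sigma^{-1}(\mu_{k}-\mu_{m})}_{\text{const}}
$$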

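Because $\Sigma$ is diagonal, the features $x_1,\ldots,x_D$ are conditionally independent given the class. Writing $\sigma_d^2$ for the $d$-th diagonal element of $\Sigma$, the log-odds decompose into: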

$$
\log \frac{p(k|\mathbf{x})}{p(m|\mathbf{x})} = \sum_{d=1}^D \frac{\mu_{k,d}-\mu_{m,d}}{\sigma_{d}^2} x_{d} + \text{const}
$$ {#eq-loglike-decomp}

By the properties of our classifier (*multinomial logistic regression*), we have:

$$
\log \frac{p(k|\mathbf{x})}{p(m|\mathbf{x})} = \sum_{d=1}^D \left( \theta_{k,d} - \theta_{m,d} \right)x_d + \text{const}
$$ {#eq-multi}

where $\theta_{k,d}=\Theta[k,d]$ denotes the coefficient on feature $d$ for class $k$. 

Comparing @eq-loglike-decomp and @eq-multi term by term, we can identify $(\theta_{k,d} - \theta_{m,d}) = (\mu_{k,d}-\mu_{m,d})/\sigma_{d}^2$, so that $(\mu_{k,d}-\mu_{m,d}) \propto (\theta_{k,d} - \theta_{m,d})$ with a positive constant of proportionality under the assumptions we made above. Hence, we have that $(\theta_{y^-,d_{\text{immtbl}}} - \theta_{y^+,d_{\text{immtbl}}}) < 0$ and $(\theta_{y^-,d_{\text{mtbl}}} - \theta_{y^+,d_{\text{mtbl}}}) > 0$.

Let $\mathbf{x}^\prime$ denote some randomly chosen individual from class $y^-$ and let $y^+ \sim p(y)$ denote the randomly chosen target class. Then the partial derivative of the contrastive divergence penalty [@eq-div] with respect to coefficient $\theta_{y^+,d}$ is equal to 

$$
\frac{\partial}{\partial\theta_{y^+,d}} \left(\text{div}(\mathbf{x},\mathbf{x^\prime},\mathbf{y};\theta)\right) = \frac{\partial}{\partial\theta_{y^+,d}} \left( \left(-\mathbf{M}_\theta(\mathbf{x})[y^+]\right) - \left(-\mathbf{M}_\theta(\mathbf{x}^\prime)[y^+]\right) \right) = x_{d}^\prime - x_{d}
$$ {#eq-grad}

and equal to zero everywhere else.

Since $(\mu_{y^-,d_{\text{immtbl}}} < \mu_{y^+,d_{\text{immtbl}}})$ we are more likely to have $(x_{d_{\text{immtbl}}}^\prime - x_{d_{\text{immtbl}}}) < 0$ than vice versa at initialization. Similarly, we are more likely to have $(x_{d_{\text{mtbl}}}^\prime - x_{d_{\text{mtbl}}}) > 0$ since $(\mu_{y^-,d_{\text{mtbl}}} > \mu_{y^+,d_{\text{mtbl}}})$.

This implies that if we do not protect feature $d_{\text{immtbl}}$, the contrastive divergence penalty will push down on $\theta_{y^-,d_{\text{immtbl}}}$ thereby exacerbating the existing effect $(\theta_{y^-,d_{\text{immtbl}}} - \theta_{y^+,d_{\text{immtbl}}}) < 0$. In words, not protecting the immutable feature would have the undesirable effect of making the classifier more sensitive to this feature, in that it would be more likely to predict class $y^-$ as opposed to $y^+$ for lower values of $d_{\text{immtbl}}$. 

By the same rationale, the contrastive divergence penalty can generally be expected to push up on $\theta_{y^-,d_{\text{mtbl}}}$ exacerbating $(\theta_{y^-,d_{\text{mtbl}}} - \theta_{y^+,d_{\text{mtbl}}}) > 0$. In words, this has the effect of making the classifier more sensitive to the mutable feature, in that it would be more likely to predict class $y^-$ as opposed to $y^+$ for higher values of $d_{\text{mtbl}}$.

Thus, our proposed approach of protecting feature $d_{\text{immtbl}}$ has the net effect of decreasing the classifier's sensitivity to the immutable feature relative to the mutable feature (i.e., no change in sensitivity for $d_{\text{immtbl}}$ relative to increased sensitivity for $d_{\text{mtbl}}$).

:::
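
The two ingredients of the argument above can be checked numerically. The following is a minimal sketch under stated assumptions, not the implementation used in our experiments: the class means, unit variances, the helper names `logits`, `div_penalty` and `fd_grad`, and the choice of feature 1 as $d_{\text{mtbl}}$ and feature 2 as $d_{\text{immtbl}}$ are all hypothetical. It compares the analytical gradient in @eq-grad with a finite-difference approximation and then estimates how often the gradient has the sign assumed in the proof.

```julia
using Random, Statistics

Random.seed!(42)

# Hypothetical class means: feature 1 = mutable, feature 2 = immutable.
mu_neg = [1.0, -1.0]    # class y⁻: higher mutable mean, lower immutable mean
mu_pos = [-1.0, 1.0]    # class y⁺
sigma = [1.0, 1.0]      # common standard deviations (diagonal covariance)

# Linear classifier without bias: logits M_θ(x) = Θ x (assumed form).
logits(Θ, x) = Θ * x

# Penalty as written in eq-grad: (-M(x)[y⁺]) - (-M(x_neg)[y⁺]).
div_penalty(Θ, x_pos, x_neg, ypos) = -logits(Θ, x_pos)[ypos] + logits(Θ, x_neg)[ypos]

# Central finite difference with respect to the coefficient Θ[ypos, d].
function fd_grad(Θ, x_pos, x_neg, ypos, d; h=1e-6)
    Θ_up = copy(Θ); Θ_up[ypos, d] += h
    Θ_dn = copy(Θ); Θ_dn[ypos, d] -= h
    return (div_penalty(Θ_up, x_pos, x_neg, ypos) - div_penalty(Θ_dn, x_pos, x_neg, ypos)) / (2h)
end

# 1) Gradient check for a single pair of samples.
Θ = randn(2, 2)
ypos = 2
x_pos = mu_pos .+ sigma .* randn(2)   # sample from the target class
x_neg = mu_neg .+ sigma .* randn(2)   # sample from the non-target class
grad_fd = [fd_grad(Θ, x_pos, x_neg, ypos, d) for d in 1:2]
grad_analytic = x_neg .- x_pos

# 2) Monte Carlo estimate of the sign of x_neg - x_pos per feature.
n = 10_000
X_pos = mu_pos .+ sigma .* randn(2, n)
X_neg = mu_neg .+ sigma .* randn(2, n)
Δ = X_neg .- X_pos
share_mtbl_positive = mean(Δ[1, :] .> 0)     # share of pairs with a positive gradient entry
share_immtbl_negative = mean(Δ[2, :] .< 0)   # share of pairs with a negative gradient entry

println("finite-difference gradient: ", grad_fd)
println("analytical gradient:        ", grad_analytic)
println("share positive (mutable):   ", share_mtbl_positive)
println("share negative (immutable): ", share_immtbl_negative)
```

Under these assumed means, both shares should be well above one half, which is the "more likely than vice versa" statement used in the proof.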

We provide an illustrative example in @exm-grad.

::: {#exm-grad}

## Prediction of Consumer Credit Default

Suppose now that $d_{\text{immtbl}}$ represents an individual's *age* and $d_{\text{mtbl}}$ represents an individual's existing level of credit card debt. Assume that these two features are independent, the class conditional densities are Gaussian and we use no other features to predict the risk of individuals defaulting on a consumer loan using a linear classifier. We have simulated this scenario using synthetic data in @fig-mtblty.

In panel (a) of @fig-mtblty, we have trained the linear classifier using counterfactual training, treating both features as mutable. The linear decision boundary is roughly equally sensitive to both *age* and existing levels of *debt*. 

Conversely, in panel (b) of @fig-mtblty, we have trained the same classifier using counterfactual training, but this time treating *age* as an immutable feature. The result is a new decision boundary that has tilted in favor of higher sensitivity to the mutable feature (existing *debt*) and lower sensitivity to the immutable feature (*age*).

:::
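
To make the notion of a "tilted" decision boundary concrete, the sketch below computes the slope and intercept of the boundary for a two-class, two-feature linear classifier. It is an illustration under stated assumptions: `boundary_line` is a hypothetical helper and not the `get_decision_boundary` function from `CTExperiments`, and the coefficient matrices are made up rather than taken from the trained models. Feature 1 plays the role of *debt* (horizontal axis) and feature 2 the role of *age* (vertical axis).

```julia
# Hypothetical helper: decision boundary of a two-class softmax classifier
# with logits Θ x + b, expressed as x2 = slope * x1 + intercept.
function boundary_line(Θ::AbstractMatrix, b::AbstractVector)
    w = Θ[1, :] .- Θ[2, :]   # coefficient differences θ_{1,d} - θ_{2,d}
    c = b[1] - b[2]
    w[2] == 0 && error("boundary is vertical; slope is undefined")
    # Boundary: w[1] * x1 + w[2] * x2 + c = 0
    return (slope = -w[1] / w[2], intercept = -c / w[2])
end

b = [0.0, 0.0]

# Made-up coefficients with equal sensitivity to both features.
Θ_equal = [1.0 -1.0; -1.0 1.0]
db_equal = boundary_line(Θ_equal, b)

# Made-up coefficients with reduced sensitivity to feature 2 (age).
Θ_tilted = [1.0 -0.25; -1.0 0.25]
db_tilted = boundary_line(Θ_tilted, b)

println("equal sensitivity:   ", db_equal)
println("reduced age weight:  ", db_tilted)
```

With these made-up numbers the slope increases from 1 to 4: the smaller the coefficient difference on *age*, the steeper the boundary in the debt-age plane, so the predicted class depends mostly on *debt*.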

![Visual illustration of the effect of imposing mutability constraints. See @exm-grad for details.](/paper/figures/app_mtblty.svg){#fig-mtblty}



In [None]:
#| eval: false

using CTExperiments: 
    get_ce_data, 
    train_val_split, 
    build_model, 
    LinearModel, 
    get_input_encoder


# Callback:
using CTExperiments: get_log_reg_params, get_decision_boundary

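# Plots the data, the current decision boundary and the counterfactual paths.
# Relies on the globals `_title`, `_constraint` and `ce_data`, which are set in the training loop.
function plot_db(model, ces; _xlab="Debt", _ylab="Age")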

    x = (ce -> ce.counterfactual).(ces) |> xs -> reduce(hcat, xs)
    x0 = (ce -> ce.factual).(ces) |> xs -> reduce(hcat, xs)

    xlab = _constraint[1] == "both" ? "$(_xlab) (mutable)" : "$(_xlab) (immutable)"
    ylab = _constraint[2] == "both" ? "$(_ylab) (mutable)" : "$(_ylab) (immutable)"

    # Data and decision boundary:
    coeff = get_log_reg_params(model)
    db = get_decision_boundary(coeff)
    plt = Plots.scatter(
        ce_data, 
        # label=["1=Default" "2=No Default"], 
        # legend_position=:topright, 
        xlab=xlab, ylab=ylab,  
        axis=nothing, 
        legend=false,
        title=_title
    )
    Plots.abline!(plt, db.slope, db.intercept; lw=5, label="Dec. Boundary")

    # Counterfactuals:
    if !any(isnothing.(x))
        yhat = [argmax(y) for y in eachcol(model(x))]
        yhat0 = [argmax(y) for y in eachcol(model(x0))]
        if any(yhat.==2)

            # Paths:
            u = []
            v = []
            for (i,ce) in enumerate(eachcol(x))
                Δ = ce - x0[:,i]
                push!(u, Δ[1])
                push!(v, Δ[2])
            end
            Plots.quiver!(x0[1,yhat.==2], x0[2,yhat.==2], quiver=(u[yhat.==2], v[yhat.==2]), color=1)
            
            # End points:
            Plots.scatter!(x[1,yhat.==2], x[2,yhat.==2], label=["CE (y⁺=1)" "CE (y⁺=2)"], ms=15, shape=:star, color=yhat[yhat.==2], group=yhat[yhat.==2], mscolor=yhat0[yhat.==2])

        end
    end
    display(plt)
end

# Data:
specs = [
    ("(a)", VanillaObjective(needs_ce=true), ["both", "both"]),
    ("(b)", EnergyDifferentialObjective(lambda=[1.0,1.0,0.0]), ["both", "both"]),
    ("(c)", VanillaObjective(needs_ce=true), ["both", "none"]),
    ("(d)", EnergyDifferentialObjective(lambda=[1.0,1.0,0.0]), ["both", "none"]),
]

In [None]:
#| eval: false

using CounterfactualExplanations.Convergence
using Random
Random.seed!(42)

models = []
plts = []

for (title, obj, constraint) in specs

    # Globals for the callback:
    global _title = title
    global _constraint = constraint

    # Data:
    data = LinearlySeparable(
        n_train=500,
        batchsize=50,
        mutability=constraint
    )
    global ce_data = get_ce_data(data)
    val_size = data.n_validation / (data.n_validation + data.n_train)
    train_set, val_set, _ = train_val_split(data, ce_data, val_size)

    # Model:
    nin = size(first(train_set)[1], 1)
    nout = size(first(train_set)[2], 1)
    model = build_model(LinearModel(), nin, nout)

    # Objective:
    generator = GenericGenerator()
    opt_state = Flux.setup(Descent(), model)
    conv = MaxIterConvergence()

    model, logs = counterfactual_training(
        obj,
        model,
        generator,
        train_set,
        opt_state;
        val_set = val_set,
        nepochs = 50,
        mutability = Symbol.(constraint),
        callback = plot_db,
        nce=50,
        convergence=conv, 
    )

    push!(models, model)
    push!(plts, current())
end

In [None]:
#| eval: false

plt = Plots.plot(
    plts..., 
    layout=(1,4), 
    size=(1150,250),  
    left_margin = 10mm, 
    bottom_margin = 5mm,
    top_margin = 3mm,
    right_margin = 10mm,
)
display(plt)
Plots.savefig(plt, "paper/figures/poc.svg")

In [None]:
#| eval: false

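plts = []
titles = ["(a)", "(b)"]
_xlab = "Existing Debt"
_ylab = "Age"

# Assumption: panels (a) and (b) show the two models trained with the
# EnergyDifferentialObjective, i.e. entries 2 and 4 of `specs` (see @exm-grad).
panel_idx = [2, 4]
constraints = [specs[j][3] for j in panel_idx]

# Note: `ce_data` is the global left over from the last training run.
for (i, model) in enumerate(models[panel_idx])
    # Decision boundary implied by the fitted coefficients:
    coeff = get_log_reg_params(model)
    db = get_decision_boundary(coeff)
    xlab = constraints[i][1] == "both" ? "$(_xlab) (mutable)" : "$(_xlab) (immutable)"
    ylab = constraints[i][2] == "both" ? "$(_ylab) (mutable)" : "$(_ylab) (immutable)"
    plt = Plots.scatter(ce_data; xlab=xlab, ylab=ylab, bottom_margin = 5mm, left_margin = 5mm, label=["Default" "No Default"], title=titles[i])
    Plots.abline!(plt, db.slope, db.intercept; lw=5, label="Dec. Boundary")

    push!(plts, plt)
end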

plt = Plots.plot(plts..., layout=(1,2), size=(600,300))
Plots.savefig(plt, "paper/figures/app_mtblty.svg")