In [None]:
---
format: 
    commonmark:
        variant: -raw_html
        wrap: none
        self-contained: true
crossref:
  fig-prefix: Figure
  tbl-prefix: Table
bibliography: https://raw.githubusercontent.com/pat-alt/bib/main/bib.bib
output: asis
---

```@meta
CurrentModule = CounterfactualExplanations 
```

# MNIST

In this example we will see how different counterfactual generators can be used to explain deep learning models for image classification. In particular, we will look at MNIST data and visually inspect how the different generators perturb images of handwritten digits in order to change the predicted label to a target label. @fig-samples shows a random sample of handwritten digits.

In [25]:
using CounterfactualExplanations, Plots, MLDatasets
using MLDatasets.MNIST: convert2image
using BSON: @save, @load

In [26]:
train_x, train_y = MNIST.traindata()
input_dim = prod(size(train_x[:,:,1]))
using Images, Random, StatsBase
Random.seed!(1)
n_samples = 10
samples = train_x[:,:,sample(1:end, n_samples, replace=false)]
mosaic = mosaicview([convert2image(samples[:,:,i]) for i ∈ 1:n_samples]...,ncol=Int(n_samples/2))
plt = plot(mosaic, size=(500,260), axis=nothing, background=:transparent)
savefig(plt, "www/mnist_samples.png")

![A few random handwritten digits.](www/mnist_samples.png){#fig-samples}

## Pre-trained classifiers

Next we will load two pre-trained deep-learning classifiers:

1. Simple MLP - `model`
2. Deep ensemble - `𝓜`

In [39]:
using Flux
using CounterfactualExplanations.Data: mnist_data, mnist_model, mnist_ensemble
x,y,data = getindex.(Ref(mnist_data()), ("x", "y", "data"))
model = mnist_model()
𝓜 = mnist_ensemble();

The following code just prepares the models to be used with CounterfactualExplanations.jl:

In [28]:
using CounterfactualExplanations, CounterfactualExplanations.Models
import CounterfactualExplanations.Models: logits, probs # import functions in order to extend

# MLP:
# Step 1)
struct NeuralNetwork <: Models.FittedModel
    nn::Any
end
# Step 2)
logits(𝑴::NeuralNetwork, X::AbstractArray) = 𝑴.nn(X)
probs(𝑴::NeuralNetwork, X::AbstractArray)= softmax(logits(𝑴, X))
𝑴 = NeuralNetwork(model)

# Deep ensemble:
# Step 1)
struct FittedEnsemble <: Models.FittedModel
    𝓜::AbstractArray
end
# Step 2)
using Statistics
logits(𝑴::FittedEnsemble, X::AbstractArray) = mean(Flux.stack([nn(X) for nn in 𝑴.𝓜],3), dims=3)
probs(𝑴::FittedEnsemble, X::AbstractArray) = mean(Flux.stack([softmax(nn(X)) for nn in 𝑴.𝓜],3),dims=3)
𝑴_ensemble=FittedEnsemble(𝓜);
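
For the ensemble, `probs` averages the softmax outputs of the individual networks rather than applying softmax to the averaged logits. The sketch below illustrates that averaging step in plain Julia, without Flux. The names `softmax_cols`, `Ws` and `nets`, as well as the toy linear "networks", are hypothetical stand-ins chosen for illustration and are not part of CounterfactualExplanations.jl.

```julia
using Statistics

# Column-wise softmax (assumed helper, numerically stabilised)
softmax_cols(Z) = (E = exp.(Z .- maximum(Z, dims=1)); E ./ sum(E, dims=1))

# Three hypothetical "networks", each mapping a 4×N input to 3×N logits
Ws = [reshape(Float64.(1:12) .* s, 3, 4) ./ 10 for s in (1.0, -0.5, 2.0)]
nets = [X -> W * X for W in Ws]

# Two toy inputs stored as columns
X = reshape(Float64.(1:8) ./ 8, 4, 2)

# Stack the per-network probabilities along a third dimension and average them
stacked = cat([softmax_cols(nn(X)) for nn in nets]...; dims=3)
ensemble_probs = dropdims(mean(stacked, dims=3), dims=3)
```

Because each member outputs a valid probability vector per column, the averaged columns of `ensemble_probs` are again valid probability vectors.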

## Generating counterfactuals

We will look at four different approaches here: 

1. Generic approach for the MLP [@wachter2017counterfactual].
2. Greedy approach for the MLP.
3. Generic approach for the deep ensemble.
4. Greedy approach for the deep ensemble [@schut2021generating].

They can be implemented using the `GenericGenerator` and the `GreedyGenerator`.
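
As a rough guide to what the generic approach optimises, counterfactual search in the style of @wachter2017counterfactual is often written as a penalised problem of the form below. The notation is an assumption made here for illustration and is not taken from the package: $\bar{x}$ is the factual, $\underline{x}$ a candidate counterfactual, $t$ the target class, $M$ the classifier, $\ell$ a loss such as logit cross-entropy, $\text{dist}$ a distance measure and $\lambda$ the strength of the penalty.

$$
\underline{x}^{\star} = \arg\min_{\underline{x}} \; \ell\big(M(\underline{x}), t\big) + \lambda \, \text{dist}\big(\bar{x}, \underline{x}\big)
$$

The first term pushes the prediction towards the target, while the second keeps the counterfactual close to the original image.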

### Turning a 9 into a 4

We will start with an example that should yield intuitive results: for a human, turning the handwritten 9 in @fig-nine into a 4 is straightforward (just erase the top part). Let's see how the different algorithmic approaches perform.

In [29]:
# Randomly selected factual:
Random.seed!(1234);
x̅ = Flux.unsqueeze(x[:,rand(1:size(x)[2])],2)
target = 5 # 1-based class index of the digit 4 (labels are 0:9)
γ = 0.95
img = convert2image(reshape(x̅,Int(sqrt(input_dim)),Int(sqrt(input_dim))))
plt_orig = plot(img, title="Original", axis=nothing)
savefig(plt_orig, "www/mnist_original.png")

![A random handwritten 9.](www/mnist_original.png){#fig-nine}
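
Two conventions in the code above are easy to trip over: the target is a 1-based class index into the labels `0:9`, so `target = 5` stands for the digit 4, and each image is stored as a flattened column that has to be reshaped into a square before plotting. A minimal sketch of both conventions follows. The helpers `digit_to_index` and `index_to_digit` are hypothetical names introduced here, and `x_flat` is random stand-in data rather than an actual MNIST digit.

```julia
labels = 0:9                                   # digit labels, as in `onecold(y, 0:9)`
digit_to_index(d) = findfirst(==(d), labels)   # hypothetical helper: digit -> class index
index_to_digit(i) = labels[i]                  # hypothetical helper: class index -> digit

x_flat = rand(784, 1)                          # stand-in for a flattened 28×28 image
side = isqrt(size(x_flat, 1))                  # integer square root of the input dimension
img_matrix = reshape(x_flat, side, side)       # square matrix, ready for `convert2image`
```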

The code below implements the four approaches one by one, and @fig-example shows the resulting counterfactuals. In every case the desired label switch is achieved: the corresponding classifier classifies the counterfactual as a 4. From a human perspective, however, arguably only the counterfactuals for the deep ensemble look like a 4. For the MLP, both the generic and the greedy approach generate counterfactuals that look much like adversarial examples.

In [None]:
# Generic - MLP
generator = GenericGenerator(0.1,0.1,1e-5,:logitcrossentropy,nothing)
recourse = generate_counterfactual(generator, x̅, 𝑴, target, γ; feasible_range=(0.0,1.0)) # generate recourse
img = convert2image(reshape(recourse.x̲,Int(sqrt(input_dim)),Int(sqrt(input_dim))))
plt_wachter = plot(img, title="MLP - Wachter")

# Greedy - MLP
generator = GreedyGenerator(0.1,15,:logitcrossentropy,nothing)
recourse = generate_counterfactual(generator, x̅, 𝑴, target, γ; feasible_range=(0.0,1.0)) # generate recourse
img = convert2image(reshape(recourse.x̲,Int(sqrt(input_dim)),Int(sqrt(input_dim))))
plt_greedy = plot(img, title="MLP - Greedy")

# Generic - Deep Ensemble
generator = GenericGenerator(0.1,0.1,1e-5,:logitcrossentropy,nothing)
recourse = generate_counterfactual(generator, x̅, 𝑴_ensemble, target, γ; feasible_range=(0.0,1.0)) # generate recourse
img = convert2image(reshape(recourse.x̲,Int(sqrt(input_dim)),Int(sqrt(input_dim))))
plt_wachter_de = plot(img, title="Ensemble - Wachter")

# Greedy - Deep Ensemble
generator = GreedyGenerator(0.1,15,:logitcrossentropy,nothing)
recourse = generate_counterfactual(generator, x̅, 𝑴_ensemble, target, γ; feasible_range=(0.0,1.0)) # generate recourse
img = convert2image(reshape(recourse.x̲,Int(sqrt(input_dim)),Int(sqrt(input_dim))))
plt_greedy_de = plot(img, title="Ensemble - Greedy")

plt_list = [plt_orig, plt_wachter, plt_greedy, plt_wachter_de, plt_greedy_de]
plt = plot(plt_list...,layout=(1,length(plt_list)),axis=nothing, size=(1200,240))
savefig(plt, "www/MNIST_9to4.png")

![Counterfactual explanations for MNIST data: turning a 9 into a 4](www/MNIST_9to4.png){#fig-example}
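
To give a feel for how a greedy search differs from plain gradient descent, the sketch below perturbs only the single most salient feature in each iteration and caps how often any one feature may change. It is loosely in the spirit of @schut2021generating and is not the implementation behind `GreedyGenerator`. The function `greedy_sketch`, the toy linear classifier `W`, and the reading of `δ` as step size and `n` as the per-feature change limit are all assumptions made for illustration.

```julia
# Numerically stable softmax for a vector of logits
softmax_vec(z) = (e = exp.(z .- maximum(z)); e ./ sum(e))

# Greedy search for a linear-softmax classifier with logits W * x (assumed toy model)
function greedy_sketch(W, x, target; δ=0.1, n=15, γ=0.95, T=1000)
    x_cf = copy(x)
    n_changes = zeros(Int, length(x))      # how often each feature has been changed
    for _ in 1:T
        p = softmax_vec(W * x_cf)
        p[target] ≥ γ && break             # stop once the target probability is reached
        e_t = zeros(length(p)); e_t[target] = 1.0
        g = W' * (p .- e_t)                # gradient of the cross-entropy loss w.r.t. x_cf
        g[n_changes .≥ n] .= 0.0           # features that used up their budget are skipped
        i = argmax(abs.(g))                # most salient remaining feature
        g[i] == 0 && break                 # nothing left to change
        x_cf[i] = clamp(x_cf[i] - δ * sign(g[i]), 0.0, 1.0)  # stay in the feasible range
        n_changes[i] += 1
    end
    return x_cf
end

W = [5.0 -5.0; -5.0 5.0]                   # toy two-class, two-feature classifier
x = [0.9, 0.1]                             # factual, predicted as class 1
x_cf = greedy_sketch(W, x, 2)              # search for a counterfactual in class 2
```

Restricting each step to one feature tends to produce sparse changes, which is one reason greedy counterfactuals for images can look like localised edits.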

In [31]:
#| echo: false

using Random

# Single:
function from_digit_to_digit(from, to, generator, model; γ=0.95, x=x, y=y, seed=1234, T=1000)

    Random.seed!(seed)

    candidates = findall(onecold(y,0:9).==from)
    x̅ = Flux.unsqueeze(x[:,rand(candidates)],2)
    target = to + 1
    recourse = generate_counterfactual(generator, x̅, model, target, γ; feasible_range=(0.0,1.0), T=T)

    return recourse
end

# Multiple:
function from_digit_to_digit(from, to, generator::Dict, model::Dict; γ=0.95, x=x, y=y, seed=1234, T=1000)

    Random.seed!(seed)

    candidates = findall(onecold(y,0:9).==from)
    x̅ = Flux.unsqueeze(x[:,rand(candidates)],2)
    target = to + 1
    recourses = Dict()

    for (k_gen,v_gen) ∈ generator # iterate over the arguments rather than the globals
        for (k_mod,v_mod) ∈ model
            k = k_mod * " - " * k_gen
            recourses[k] = generate_counterfactual(v_gen, x̅, v_mod, target, γ; feasible_range=(0.0,1.0), T=T)
        end
    end

    return recourses
end;

In [32]:
#| echo: false
generators = Dict("Wachter" => GenericGenerator(0.1,1,1e-5,:logitcrossentropy,nothing),"Greedy" => GreedyGenerator(0.1,15,:logitcrossentropy,nothing))
models = Dict("MLP" => 𝑴, "Ensemble" => 𝑴_ensemble);

In [None]:
#| echo: false
from = 3
to = 8
recourses = from_digit_to_digit(from,to,generators,models)
plts =  first(values(recourses)).x̅ |> x -> plot(convert2image(reshape(x,Int(sqrt(input_dim)),Int(sqrt(input_dim)))),title="Original")
plts = vcat(plts, [plot(convert2image(reshape(v.x̲,Int(sqrt(input_dim)),Int(sqrt(input_dim)))),title=k) for (k,v) in recourses])
plt = plot(plts...,layout=(1,length(plts)),axis=nothing, size=(1200,240))
savefig(plt, "www/MNIST_$(from)to$(to).png")

In [None]:
#| echo: false
from = 7
to = 2
recourses = from_digit_to_digit(from,to,generators,models)
plts =  first(values(recourses)).x̅ |> x -> plot(convert2image(reshape(x,Int(sqrt(input_dim)),Int(sqrt(input_dim)))),title="Original")
plts = vcat(plts, [plot(convert2image(reshape(v.x̲,Int(sqrt(input_dim)),Int(sqrt(input_dim)))),title=k) for (k,v) in recourses])
plt = plot(plts...,layout=(1,length(plts)),axis=nothing, size=(1200,240))
savefig(plt, "www/MNIST_$(from)to$(to).png")

In [None]:
#| echo: false
from = 1
to = 7
recourses = from_digit_to_digit(from,to,generators,models)
plts =  first(values(recourses)).x̅ |> x -> plot(convert2image(reshape(x,Int(sqrt(input_dim)),Int(sqrt(input_dim)))),title="Original")
plts = vcat(plts, [plot(convert2image(reshape(v.x̲,Int(sqrt(input_dim)),Int(sqrt(input_dim)))),title=k) for (k,v) in recourses])
plt = plot(plts...,layout=(1,length(plts)),axis=nothing, size=(1200,240))
savefig(plt, "www/MNIST_$(from)to$(to).png")

In [None]:
#| echo: false
from = 9
recourses = map(d -> from_digit_to_digit(from,d,GreedyGenerator(0.1,15,:logitcrossentropy,nothing),𝑴_ensemble;T=2500),filter(x -> x!=from, Vector(0:9)))
plts = [plot(convert2image(reshape(rec.x̲,Int(sqrt(input_dim)),Int(sqrt(input_dim)))),title=rec.target-1) for rec in recourses]
plot(plts...,layout=(1,length(plts)),axis=nothing,size=(1200,500))

### References