```@meta
CurrentModule = AlgorithmicRecourse 
```

# Recourse for multi-class targets

In [104]:
# Packages used throughout this page.
using Flux, Random, Plots, PlotThemes, AlgorithmicRecourse
# Apply the :wong plot theme to the figures created below.
theme(:wong)
using Logging
# Suppress log messages at `Info` level and below.
disable_logging(Logging.Info)

LogLevel(1)

In [105]:
# Load the toy multi-class data: `x` is a collection of feature vectors
# (presumably one per sample — confirm against `toy_data_multi`), `y` the labels.
x, y = toy_data_multi()
# Concatenate the feature vectors column-wise. `reduce(hcat, x)` builds the same
# matrix as `hcat(x...)` without splatting a collection of data-dependent length
# into varargs.
X = reduce(hcat, x)
# One-hot encode the labels against the set of observed classes ...
y_train = Flux.onehotbatch(y, unique(y))
# ... then split the transposed encoding along its first dimension so that it can
# be zipped with `x` for training.
y_train = Flux.unstack(y_train', 1)
# Plot the samples and write the figure to disk.
plt = plot()
plt = plot_data!(plt, X', y);
savefig(plt, "www/multi_samples.png")

![](www/multi_samples.png)

## Classifier

In [106]:
# Layer sizes: hidden width, and one output unit per observed class.
n_hidden = 32
out_dim = length(unique(y))

# Two stacked dense layers taking 2 input features to `out_dim` outputs.
# NOTE(review): neither layer is given an activation function — confirm that
# this is intended.
nn = Chain(Dense(2, n_hidden), Dense(n_hidden, out_dim))

# Cross-entropy computed on the network's raw outputs for `x` against target `y`.
function loss(x, y)
    return Flux.Losses.logitcrossentropy(nn(x), y)
end

# Trainable parameters of `nn`.
ps = Flux.params(nn)
# Pair every input with its encoded label.
data = zip(x, y_train);

In [107]:
using Flux.Optimise: update!, ADAM
using Statistics
using Plots

opt = ADAM()
epochs = 50
# Mean loss over every (input, label) pair in `data`.
avg_loss(data) = mean(map(d -> loss(d[1], d[2]), data))

anim = Animation()
# The y-axis is capped at the loss of the untrained network.
plt = plot(ylim=(0, avg_loss(data)), xlim=(0, epochs), legend=false, xlab="Epoch")
# Loss history. It is typed, and grown in place with `push!` rather than rebound
# with `vcat`: rebinding a global inside a top-level `for` only works under
# notebook/REPL soft scope and fails when this code is run as a script.
avg_l = Float64[]

for epoch in 1:epochs
    # One optimiser step per (input, label) pair, reusing the parameter
    # collection `ps` created alongside the model.
    for d in data
        gs = gradient(ps) do
            loss(d...)
        end
        update!(opt, ps, gs)
    end
    # Record this epoch's average loss and add a frame to the animation.
    push!(avg_l, avg_loss(data))
    plot!(plt, avg_l, color=1, title="Average (training) loss")
    frame(anim, plt)
end

gif(anim, "www/multi_loss.gif");

![](www/multi_loss.gif)

In [108]:
using AlgorithmicRecourse, AlgorithmicRecourse.Models
import AlgorithmicRecourse.Models: logits, probs # import functions in order to extend

# Step 1) Wrap the trained network in a subtype of `Models.FittedModel`.
struct NeuralNetwork <: Models.FittedModel
    nn::Any
end

# Step 2) Extend the imported `logits` and `probs` for the wrapper type.

# Raw output of the wrapped network for input `X`.
function logits(𝑴::NeuralNetwork, X::AbstractArray)
    return 𝑴.nn(X)
end

# Softmax of the logits for input `X`.
function probs(𝑴::NeuralNetwork, X::AbstractArray)
    return softmax(logits(𝑴, X))
end

𝑴 = NeuralNetwork(nn)

NeuralNetwork(Chain(Dense(2, 32), Dense(32, 4)))

In [113]:
# Build the multi-class contour plot from the samples, their labels and the
# wrapped model 𝑴 (presumably the model's predictions over the feature space —
# confirm against `plot_contour_multi`), then write it to disk.
plt = plot_contour_multi(X',y,𝑴);
savefig(plt, "www/multi_contour.png")

![](www/multi_contour.png)

In [116]:
# Randomly selected factual (one column of `X`):
Random.seed!(1234);
x̅ = X[:, rand(1:size(X, 2))]
# Label the model currently assigns to the factual.
y̅ = Flux.onecold(probs(𝑴, x̅), unique(y))
# Draw the target at random from every label other than the predicted one.
# Labels are compared by value: masking with positions (`1:end .!= y̅`) only
# excludes the predicted class when `unique(y)` happens to equal `1:n`.
target = rand(filter(label -> label != y̅, unique(y)))
# Used by the (currently disabled) `generate_recourse` call below.
γ = 0.75
# # Define Generator:
# generator = GenericGenerator(0.1,0.1,1e-5,:logitcrossentropy,nothing)
# # Generate recourse:
# recourse = generate_recourse(generator, x̅, 𝑴, target, γ); # generate recourse

3.0