In [1]:
using Distributions
using Random
using Flux, Statistics, ProgressMeter, Plots, TaijaData, Distances
using Flux.Data: DataLoader
using Flux: onehotbatch, onecold, crossentropy, logitcrossentropy, mse, throttle, update!, push!
using Base.Iterators: repeated, partition
using LinearAlgebra: norm
using CounterfactualExplanations
using Distances
using BSON

include("../utils/train.jl")
include("../utils/plot.jl")
include("../utils/evaluate.jl")



evaluate_model (generic function with 1 method)

## Adversarial attack algorithms

In [2]:
function FGSM(model, loss, x, y; ϵ = 0.3, clamp_range = (0, 1))
    grads = gradient(x -> loss(model(x), y), x)[1]
    x_adv = clamp.(x + (Float32(ϵ) * sign.(grads)), clamp_range...)
    return x_adv
end

function PGD(model, loss, x, y; ϵ = 0.3, step_size = 0.01, iterations = 40, clamp_range = (0, 1))
    x_adv = clamp.(x + (randn(Float32, size(x)...) * Float32(step_size)), clamp_range...); # start from the random point
    δ = Distances.chebyshev(x, x_adv)
    iteration = 1; while (δ < ϵ) && iteration <= iterations
        x_adv = FGSM(model, loss, x_adv, y; ϵ = step_size, clamp_range = clamp_range)
        δ = chebyshev(x, x_adv)
        iteration += 1
    end
    return x_adv
end

PGD (generic function with 1 method)

In [3]:
Random.seed!(43)

X, y = TaijaData.load_california_housing()
# X, y = TaijaData.load_gmsc()
# X, y = TaijaData.load_uci_adult()
# X, y = TaijaData.load_credit_default()
# X, y = TaijaData.load_multi_class(; centers=4)
# X, y = TaijaData.load_overlapping()

┌ Info: Training machine(Standardizer(features = Symbol[], …), …).
└ @ MLJBase C:\Users\Hp\.julia\packages\MLJBase\hoZmq\src\machines.jl:499


([0.6267953329549938 1.2492793703552236 … 3.517855099280236 1.608911071881494; 0.2671551191297572 -1.4013319680081837 … -0.7657178395746824 -0.05065194508699348; … ; -0.6797471205432204 -0.8201877867819832 … -0.6984725427083885 -0.9793538751859174; 0.5088884466495056 0.7983669914605189 … 0.48892440907632917 0.903178188719676], [1, 1, 1, 1, 1, 1, 0, 1, 1, 0  …  0, 1, 0, 1, 1, 0, 0, 1, 1, 1])

In [4]:
n = size(X, 2)
shuffled_indices = shuffle(1:n)
train_ratio = 0.8

n_train = Int(floor(train_ratio * n))
train_indices = shuffled_indices[1:n_train]
test_indices = shuffled_indices[n_train + 1:end]

x_train = X[:, train_indices]
y_train = y[train_indices]

x_test = X[:, test_indices]
y_test = y[test_indices]

1000-element Vector{Int64}:
 1
 0
 0
 1
 1
 1
 1
 0
 0
 1
 ⋮
 0
 1
 0
 1
 0
 0
 0
 0
 0

In [18]:
extrema(X)

(-2.360941954494169, 119.38737362611174)

In [19]:
mean(X)

0.00378496447909859

In [20]:
std(X)

1.1534893086401663

In [5]:
X[:, 19]

8-element Vector{Float64}:
 -0.35395138338882964
  1.1411245457258214
 -0.661790490598118
 -0.05370511582762124
 -0.6475900725734871
 -0.11984216643777258
 -1.3538623184892868
  1.2026387523172812

In [9]:
model = Chain(
    Dense(8, 10, relu; init=Flux.glorot_normal),
    Dense(10, 2; init=Flux.glorot_normal)
)

adv_pgd_strong = deepcopy(model)
adv_pgd_medium = deepcopy(model)
adv_pgd_weak = deepcopy(model)

spare = deepcopy(model)

Chain(
  Dense(8 => 10, relu),                 [90m# 90 parameters[39m
  Dense(10 => 2),                       [90m# 22 parameters[39m
) [90m                  # Total: 4 arrays, [39m112 parameters, 704 bytes.

In [8]:
loss(x, y) = logitcrossentropy(x, y) # Not defining softmax in the model to help with CE
batch_size = 32
epochs = 20
clamp_range = extrema(X)
opt = ADAM()

Adam(0.001, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())

In [11]:
count(a -> a == 1, y_test)

517

### Training our models

#### Clean model (no adversarial training)

In [12]:
vanilla_losses = vanilla_train(model, loss, opt, x_train, y_train, epochs, batch_size, 0, 1)

Epoch: 1
Average loss: 0.6015058772563935


│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(8 => 10, relu)
│   summary(x) = 8×32 Matrix{Float64}
└ @ Flux C:\Users\Hp\.julia\packages\Flux\Wz6D4\src\layers\stateless.jl:60


Epoch: 2
Average loss: 0.5146374053955078


[32mProgress:  10%|█████                                    |  ETA: 0:02:02[39m[K

Epoch: 3
Average loss: 0.47426541805267336
Epoch: 4
Average loss: 0.45124980902671813
Epoch: 5
Average loss: 0.4358201282024384
Epoch: 6
Average loss: 0.4227229048013687
Epoch: 7
Average loss: 0.41268244457244874
Epoch: 8
Average loss: 0.40472635710239413
Epoch: 9
Average loss: 0.39845274126529695
Epoch: 10
Average loss: 0.3921532881259918
Epoch: 11
Average loss: 0.38660020542144774
Epoch: 12
Average loss: 0.3825889868736267
Epoch: 13
Average loss: 0.3790155427455902
Epoch: 14
Average loss: 0.3747734922170639
Epoch: 15
Average loss: 0.3711442241668701
Epoch: 16
Average loss: 0.36917440271377566
Epoch: 17
Average loss: 0.36700569051504134
Epoch: 18
Average loss: 0.3645698983669281
Epoch: 19
Average loss: 0.36261710727214813
Epoch: 20
Average loss: 0.36039947867393496

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:14[39m[K





20-element Vector{Any}:
 0.6015058772563935
 0.5146374053955078
 0.47426541805267336
 0.45124980902671813
 0.4358201282024384
 0.4227229048013687
 0.41268244457244874
 0.40472635710239413
 0.39845274126529695
 0.3921532881259918
 0.38660020542144774
 0.3825889868736267
 0.3790155427455902
 0.3747734922170639
 0.3711442241668701
 0.36917440271377566
 0.36700569051504134
 0.3645698983669281
 0.36261710727214813
 0.36039947867393496

#### Adversarial training: PGD with epsilon 0.1 and 13 iterations of 0.01 step size

In [14]:
adversarial_losses_strong = adversarial_train(adv_pgd_strong, loss, opt, x_train, y_train, epochs, batch_size, PGD, 0, 1, 0.3; attack_method=:PGD, iterations=40, step_size=0.01, clamp_range=clamp_range)

Epoch: 1
Average loss: 1.1794280576705933
Epoch: 2


[32mProgress:  10%|█████                                    |  ETA: 0:00:02[39m[K

Average loss: 1.1791986560821532
Epoch: 3


[32mProgress:  15%|███████                                  |  ETA: 0:00:02[39m[K

Average loss: 1.1780894556045531
Epoch: 4

[32mProgress:  20%|█████████                                |  ETA: 0:00:02[39m[K


Average loss: 1.177021487236023
Epoch: 5


[32mProgress:  25%|███████████                              |  ETA: 0:00:02[39m[K

Average loss: 1.176909496307373
Epoch: 6

[32mProgress:  30%|█████████████                            |  ETA: 0:00:02[39m[K


Average loss: 1.176136549949646
Epoch: 7

[32mProgress:  35%|███████████████                          |  ETA: 0:00:01[39m[K


Average loss: 1.1741202688217163
Epoch: 8


[32mProgress:  40%|█████████████████                        |  ETA: 0:00:01[39m[K

Average loss: 1.174392972946167
Epoch: 9

[32mProgress:  45%|███████████████████                      |  ETA: 0:00:01[39m[K


Average loss: 1.1731367793083192
Epoch: 10
Average loss: 1.1734899697303771
Epoch: 11


[32mProgress:  55%|███████████████████████                  |  ETA: 0:00:01[39m[K

Average loss: 1.172388087272644
Epoch: 12

[32mProgress:  60%|█████████████████████████                |  ETA: 0:00:01[39m[K


Average loss: 1.1722591047286988
Epoch: 13


[32mProgress:  65%|███████████████████████████              |  ETA: 0:00:01[39m[K

Average loss: 1.172169141292572


[32mProgress:  70%|█████████████████████████████            |  ETA: 0:00:01[39m[K

Epoch: 14
Average loss: 1.1710399956703186
Epoch: 15

[32mProgress:  75%|███████████████████████████████          |  ETA: 0:00:01[39m[K


Average loss: 1.1709223408699037
Epoch: 16


[32mProgress:  80%|█████████████████████████████████        |  ETA: 0:00:00[39m[K

Average loss: 1.1710241751670838
Epoch: 17
Average loss: 1.1702409310340882
Epoch: 18


[32mProgress:  90%|█████████████████████████████████████    |  ETA: 0:00:00[39m[K

Average loss: 1.1705050539970399
Epoch: 19


[32mProgress:  95%|███████████████████████████████████████  |  ETA: 0:00:00[39m[K

Average loss: 1.1703418545722961
Epoch: 20


[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:02[39m[K


Average loss: 1.1702612476348877


20-element Vector{Any}:
 1.1794280576705933
 1.1791986560821532
 1.1780894556045531
 1.177021487236023
 1.176909496307373
 1.176136549949646
 1.1741202688217163
 1.174392972946167
 1.1731367793083192
 1.1734899697303771
 1.172388087272644
 1.1722591047286988
 1.172169141292572
 1.1710399956703186
 1.1709223408699037
 1.1710241751670838
 1.1702409310340882
 1.1705050539970399
 1.1703418545722961
 1.1702612476348877

#### Adversarial training: PGD with epsilon 0.05 and 7 iterations of 0.01 step size

In [15]:
adversarial_losses_medium = adversarial_train(adv_pgd_medium, loss, opt, x_train, y_train, epochs, batch_size, PGD, 0, 1, 0.1; attack_method=:PGD, iterations=13, step_size=0.01, clamp_range=clamp_range)

Epoch: 1
Average loss: 1.3308736066818236
Epoch: 2
Average loss: 1.1531618180274963
Epoch: 3


[32mProgress:  15%|███████                                  |  ETA: 0:00:01[39m[K

Average loss: 1.0817458477020263
Epoch: 4
Average loss: 1.049683916091919
Epoch: 5
Average loss: 1.0303350548744201
Epoch: 6

[32mProgress:  30%|█████████████                            |  ETA: 0:00:01[39m[K


Average loss: 1.0206252689361572
Epoch: 7
Average loss: 1.0139870595932008
Epoch: 8
Average loss: 1.0091471314430236
Epoch: 9


[32mProgress:  45%|███████████████████                      |  ETA: 0:00:00[39m[K

Average loss: 1.004466987133026
Epoch: 10
Average loss: 0.9995256018638611
Epoch: 11
Average loss: 0.9947215929031372
Epoch: 12


[32mProgress:  60%|█████████████████████████                |  ETA: 0:00:00[39m[K

Average loss: 0.9906098599433899
Epoch: 13
Average loss: 0.9862371153831482
Epoch: 14
Average loss: 0.9848227453231811
Epoch: 15


[32mProgress:  75%|███████████████████████████████          |  ETA: 0:00:00[39m[K

Average loss: 0.9827404928207397
Epoch: 16
Average loss: 0.9797001795768738
Epoch: 17
Average loss: 0.9771924352645874
Epoch: 18

[32mProgress:  90%|█████████████████████████████████████    |  ETA: 0:00:00[39m[K


Average loss: 0.9763699388504028
Epoch: 19
Average loss: 0.9743294353485108
Epoch: 20

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00[39m[K



Average loss: 0.9726215448379517


20-element Vector{Any}:
 1.3308736066818236
 1.1531618180274963
 1.0817458477020263
 1.049683916091919
 1.0303350548744201
 1.0206252689361572
 1.0139870595932008
 1.0091471314430236
 1.004466987133026
 0.9995256018638611
 0.9947215929031372
 0.9906098599433899
 0.9862371153831482
 0.9848227453231811
 0.9827404928207397
 0.9797001795768738
 0.9771924352645874
 0.9763699388504028
 0.9743294353485108
 0.9726215448379517

#### Adversarial training: PGD with epsilon 0.01 and 13 iterations of 0.001 step size

In [16]:
adversarial_losses_weak = adversarial_train(adv_pgd_weak, loss, opt, x_train, y_train, epochs, batch_size, PGD, 0, 1, 0.05; attack_method=:PGD, iterations=7, step_size=0.01, clamp_range=clamp_range)

Epoch: 1
Average loss: 1.2436175413131714
Epoch: 2
Average loss: 1.0704891090393067
Epoch: 3


[32mProgress:  25%|███████████                              |  ETA: 0:00:00[39m[K

Average loss: 0.9979364490509033
Epoch: 4
Average loss: 0.962753788471222
Epoch: 5
Average loss: 0.9396940665245056
Epoch: 6
Average loss: 0.9238256707191467
Epoch: 7
Average loss: 0.9154068455696106
Epoch: 8


[32mProgress:  60%|█████████████████████████                |  ETA: 0:00:00[39m[K

Average loss: 0.9014751884937287
Epoch: 9
Average loss: 0.8933466787338257
Epoch: 10
Average loss: 0.8875015015602112
Epoch: 11
Average loss: 0.8836750316619874
Epoch: 12
Average loss: 0.8749061369895935


[32mProgress:  85%|███████████████████████████████████      |  ETA: 0:00:00[39m[K

Epoch: 13
Average loss: 0.8728516874313355
Epoch: 14
Average loss: 0.8642656865119934
Epoch: 15
Average loss: 0.8619214978218078
Epoch: 16
Average loss: 0.8599686172008515
Epoch: 17
Average loss: 0.8565892689228057
Epoch: 18


[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00[39m[K


Average loss: 0.8543297114372254
Epoch: 19
Average loss: 0.8476223366260529
Epoch: 20
Average loss: 0.8463143496513367


20-element Vector{Any}:
 1.2436175413131714
 1.0704891090393067
 0.9979364490509033
 0.962753788471222
 0.9396940665245056
 0.9238256707191467
 0.9154068455696106
 0.9014751884937287
 0.8933466787338257
 0.8875015015602112
 0.8836750316619874
 0.8749061369895935
 0.8728516874313355
 0.8642656865119934
 0.8619214978218078
 0.8599686172008515
 0.8565892689228057
 0.8543297114372254
 0.8476223366260529
 0.8463143496513367

### Evaluating robustness: accuracy on clean and adversarial examples

#### Clean model

In [16]:
overview = evaluate_model(x_test, y_test, 0, 1, clean_model, loss, PGD, 0.2; iterations=26, step_size=0.01, attack_method=:PGD, clamp_range=clamp_range)

overview

Dict{Any, Any} with 4 entries:
  "adversary_forced_error" => 640
  "clean_accuracy"         => 0.857
  "adversary_changed"      => 640
  "adversarial_accuracy"   => 0.217

In [17]:
pred1 = 0
pred0 = 0
corr1 = 0
corr0 = 0

for index = 1:1000
    datapt = x_test[:, index]
    actual = y_test[index]
    prediction = (model(datapt) |> Flux.onecold |> getindex) - 1

    if prediction == 1
        pred1 += 1
    else
        pred0 += 1
    end

    if prediction == 1 && actual == 1
        corr1 += 1
    elseif prediction == 0 && actual == 0
        corr0 += 1
    end
end

println("times 1 predicted: ", pred1)
println("times 0 predicted: ", pred0)
println("times 1 was correct: ", corr1)
println("times 0 was correct: ", corr0)


times 1 predicted: 496
times 0 predicted: 504
times 1 was correct: 426
times 0 was correct: 440


#### Adversarial models

In [22]:
# model_to_use = adv_pgd_strong
# model_to_use = adv_pgd_medium
model_to_use = adv_pgd_weak

overview_adv = evaluate_model(x_test, y_test, 0, 1, model_to_use, loss, PGD, 0.2; step_size=0.01, iterations=26, attack_method= :PGD, clamp_range = clamp_range)

overview_adv

Dict{Any, Any} with 4 entries:
  "adversary_forced_error" => 490
  "clean_accuracy"         => 0.84
  "adversary_changed"      => 490
  "adversarial_accuracy"   => 0.35

In [22]:
pred1 = 0
pred0 = 0
corr1 = 0
corr0 = 0

for index = 1:1000
    datapt = x_test[:, index]
    actual = y_test[index]
    prediction = (model_to_use(datapt) |> Flux.onecold |> getindex) - 1

    if prediction == 1
        pred1 += 1
    elseif prediction == 0
        pred0 += 1
    end

    if prediction == 1 && actual == 1
        corr1 += 1
    elseif prediction == 0 && actual == 0
        corr0 += 1
    end
end

println("times 1 predicted: ", pred1)
println("times 0 predicted: ", pred0)
println("times 1 was correct: ", corr1)
println("times 0 was correct: ", corr0)

times 1 predicted: 397
times 0 predicted: 603
times 1 was correct: 322
times 0 was correct: 435


In [107]:
using BSON: @save, @load

@save "../models/CaliHousing/clean_20ep_32bs.bson" model
@save "../models/CaliHousing/adv_20ep_32bs_40it_0.01ss_0.3eps.bson" adv_pgd_strong
@save "../models/CaliHousing/adv_20ep_32bs_13it_0.01ss_0.1eps.bson" adv_pgd_medium
@save "../models/CaliHousing/adv_20ep_32bs_7it_0.01ss_0.05eps.bson" adv_pgd_weak

### Counterfactual generation

In [6]:
clean_model = BSON.load("../models/CaliHousing/clean_20ep_32bs.bson")[:model]
adv_pgd_strong = BSON.load("../models/CaliHousing/adv_20ep_32bs_40it_0.01ss_0.3eps.bson")[:adv_pgd_strong]
adv_pgd_medium = BSON.load("../models/CaliHousing/adv_20ep_32bs_13it_0.01ss_0.1eps.bson")[:adv_pgd_medium]
adv_pgd_weak = BSON.load("../models/CaliHousing/adv_20ep_32bs_7it_0.01ss_0.05eps.bson")[:adv_pgd_weak]

Chain(
  Dense(8 => 10, relu),                 [90m# 90 parameters[39m
  Dense(10 => 2),                       [90m# 22 parameters[39m
) [90m                  # Total: 4 arrays, [39m112 parameters, 704 bytes.

### Measure of Implausibility

In [24]:
using Distances

function distance_from_targets(
    ce::AbstractCounterfactualExplanation;
    agg = mean,
    n_nearest_neighbors::Union{Int,Nothing} = 100,
    p::Int = 2,
)
    target_idx = ce.data.output_encoder.labels .== ce.target
    target_samples = ce.data.X[:, target_idx]
    x′ = CounterfactualExplanations.counterfactual(ce)
    loss = map(eachslice(x′, dims = ndims(x′))) do x
        Δ = map(eachcol(target_samples)) do xsample
            norm(x - xsample, p)
            # euclidean(x, xsample)
            # chebyshev(x, xsample)
        end
        if n_nearest_neighbors != nothing
            Δ = sort(Δ)[1:n_nearest_neighbors]
        end
        return mean(Δ)
    end
    loss = agg(loss)[1]

    return loss

end

function custom_euclidean(counterfactual, label, x_train, y_train, n_datapoints)
    indices = findall(x -> x == label, y_train)
    random_indices = rand(indices, n_datapoints)

    distance = 0

    for index in random_indices
        distance += norm(counterfactual - x_train[:, index])
    end

    return distance/n_datapoints
end

custom_euclidean (generic function with 1 method)

In [25]:
flux_clean = CounterfactualExplanations.MLP(clean_model; likelihood=:classification_multi)
flux_adv_strong = CounterfactualExplanations.MLP(adv_pgd_strong; likelihood=:classification_multi)
flux_adv_medium = CounterfactualExplanations.MLP(adv_pgd_medium; likelihood=:classification_multi) 
flux_adv_weak = CounterfactualExplanations.MLP(adv_pgd_weak; likelihood=:classification_multi) 

CounterfactualExplanations.Models.Model(Chain(Dense(8 => 10, relu), Dense(10 => 2)), :classification_multi, Chain(Dense(8 => 10, relu), Dense(10 => 2)), MLP())

### Random counterfactual

In [26]:
using CounterfactualExplanations.Evaluation: evaluate, validity

# model_to_use = flux_clean
model_to_use = flux_adv_strong
# model_to_use = flux_adv_medium
# model_to_use = flux_adv_weak

# random point's Counterfactual
index = rand(1:1000)
label = y_test[index]
println("actual label: ", label)

different_label = (label == 1 ? 0 : 1)

counterfactual_data = CounterfactualData(x_train, y_train)
generator = CounterfactualExplanations.ECCoGenerator(; λ=[0.1, 0.1])

counterfactual_data.domain = [extrema(X) for var in counterfactual_data.features_continuous]
# convergence = CounterfactualExplanations.DecisionThresholdConvergence(decision_threshold=0.5, max_iter=1000)
convergence = CounterfactualExplanations.GeneratorConditionsConvergence(decision_threshold=0.5, max_iter=1000)
 
ce = generate_counterfactual(
        reshape(x_test[:, index], 8, 1), different_label, counterfactual_data, model_to_use, generator; num_counterfactuals=1, convergence=convergence
    )

println(ce)

ces = CounterfactualExplanations.counterfactual(ce)
cf = ces[:, 1]

# inf_model = model
inf_model = adv_pgd_strong
# inf_model = adv_pgd_medium
# inf_model = adv_pgd_weak

label_reached = (inf_model(cf) |> Flux.onecold |> getindex) - 1

println("distance: ", distance_from_targets(ce))
println("label to reach: ", different_label)
println("label reached: ", label_reached)
println("did CE cross boundary? ", (label_reached == different_label))
println("valid? ", evaluate(ce; measure=validity))

actual label: 0


│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(8 => 10, relu)
│   summary(x) = 8×1 Matrix{Float64}
└ @ Flux C:\Users\Hp\.julia\packages\Flux\Wz6D4\src\layers\stateless.jl:60


CounterfactualExplanation
[0m[1mConvergence: ❌ after 1000 steps.[22m
distance: 1.4421449133726187
label to reach: 1
label reached: 1
did CE cross boundary? true
valid? [[1.0]]


### Experimental Setup

#### Setting up data splits

In [27]:
function generate_split(num_datapoints)
    rand(1:num_datapoints, 100)
end

function target_label(label)
    return label == 1 ? 0 : 1
end

splits = [generate_split(length(y_test)) for _ in 1:5]
targets = [[target_label(y_test[num]) for num in split] for split in splits]

5-element Vector{Vector{Int64}}:
 [1, 0, 0, 0, 1, 1, 1, 1, 0, 1  …  1, 0, 0, 1, 1, 0, 1, 0, 1, 1]
 [0, 0, 1, 1, 0, 0, 0, 1, 1, 0  …  0, 1, 1, 0, 1, 0, 1, 0, 0, 0]
 [1, 1, 0, 0, 0, 0, 1, 0, 0, 1  …  0, 1, 0, 1, 0, 1, 1, 0, 0, 0]
 [0, 0, 0, 1, 1, 1, 1, 1, 0, 1  …  1, 0, 0, 1, 0, 1, 1, 1, 1, 1]
 [1, 0, 1, 0, 1, 0, 0, 0, 0, 1  …  1, 1, 1, 0, 1, 0, 0, 1, 1, 0]

In [37]:
flux_clean = CounterfactualExplanations.MLP(clean_model; likelihood=:classification_multi)
flux_adv_strong = CounterfactualExplanations.MLP(adv_pgd_strong; likelihood=:classification_multi)
flux_adv_medium = CounterfactualExplanations.MLP(adv_pgd_medium; likelihood=:classification_multi) 
flux_adv_weak = CounterfactualExplanations.MLP(adv_pgd_weak; likelihood=:classification_multi) 

counterfactual_data = CounterfactualData(x_train, y_train)
counterfactual_data.domain = [extrema(X) for var in counterfactual_data.features_continuous]
generator = CounterfactualExplanations.ECCoGenerator(; λ=[0.1, 0.1])
# generator = CounterfactualExplanations.ECCoGenerator(; λ=[0.1, 0.5])
convergences = []
push!(convergences, CounterfactualExplanations.GeneratorConditionsConvergence(decision_threshold=0.5,max_iter=1000))
push!(convergences, CounterfactualExplanations.DecisionThresholdConvergence(decision_threshold=0.5, max_iter=1000))

convergences

2-element Vector{Any}:
 CounterfactualExplanations.Convergence.GeneratorConditionsConvergence(0.5, 0.01, 1000, 0.75)
 CounterfactualExplanations.Convergence.DecisionThresholdConvergence(0.5, 1000, 0.75)

In [38]:
mean_implausibilities_clean_dt = []
mean_implausibilities_clean_gc = []
total_validity_dt = 0
total_validity_gc = 0
model_to_use = clean_model
skipped_clean = 0

@showprogress for (i, split) in enumerate(splits)
    println("here!")
    implausibilities_dt = []
    implausibilities_gc = []
    for (j, index) in enumerate(split)

        model_pred = (model_to_use(x_test[:, index]) |> Flux.onecold |> getindex) - 1
        if model_pred != y_test[index]
            skipped_clean += 1
            println("Skipping because of model misclassification")
            continue
        end

        if (j % 10 == 0)
            println("datapoint $j of split $i reached")
        end

        count = 0

        for convergence in convergences
            ce = generate_counterfactual(
            reshape(x_test'[index, :], 8, 1), targets[i][j], counterfactual_data, flux_clean, generator; num_counterfactuals=1, convergence=convergence
            )

            implausibility = distance_from_targets(ce)

            if count == 0
                total_validity_gc += evaluate(ce; measure=validity)[1][1]
                push!(implausibilities_gc, implausibility)
            elseif count == 1
                total_validity_dt += evaluate(ce; measure=validity)[1][1]
                push!(implausibilities_dt, implausibility)
            end
            
            count += 1
        end
    end
    push!(mean_implausibilities_clean_dt, mean(implausibilities_dt))
    push!(mean_implausibilities_clean_gc, mean(implausibilities_gc))
end

mean_implausibilities_clean_dt

here!
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 10 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 20 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 30 of split 1 reached
Skipping because of model misclassification
datapoint 40 of split 1 reached
datapoint 50 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 60 of split 1 reached
datapoint 70 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 80 of split 1 reached
Skipping because of model misclassification
datapoint 90 of split 1 reached
Skipping because of model misclassification
da

[32mProgress:  40%|█████████████████                        |  ETA: 0:07:05[39m[K

here!
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 10 of split 3 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 20 of split 3 reached
datapoint 30 of split 3 reached
datapoint 40 of split 3 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 50 of split 3 reached
Skipping because of model misclassification
datapoint 60 of split 3 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 70 of split 3 reached
Skipping because of model misclassification
datapoint 80 of split 3 reached
Skipping because of model misclassification
datapoint 90 of split 3 reached
datapoint 100 of split 3 reached


[32mProgress:  60%|█████████████████████████                |  ETA: 0:04:45[39m[K

here!
datapoint 10 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 30 of split 4 reached
datapoint 40 of split 4 reached
Skipping because of model misclassification
datapoint 50 of split 4 reached
Skipping because of model misclassification
datapoint 60 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 70 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 80 of split 4 reached
Skipping because of model misclassification
datapoint 90 of split 4 reached
datapoint 100 of split 4 reached


[32mProgress:  80%|█████████████████████████████████        |  ETA: 0:02:23[39m[K

here!
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 10 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 20 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 30 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 40 of split 5 reached
datapoint 50 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 60 of split 5 reached
datapoint 70 of split 5 reached
datapoint 80 of split 5 reached
Skipping because of model misclassification
datapoint 90 of split 5 reached
Skipping because of model misclassification
datapoint 100 of split 5 reached

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:11:57[39m[K





5-element Vector{Any}:
 1.4267614255115362
 1.5246982977274603
 1.476316894570456
 1.4781864458496086
 1.5886510148647106

In [39]:
println("skips: ", skipped_clean)

println("mean implausibilities using DTC: ", mean_implausibilities_clean_dt)
println("mean implausibilities using GCC: ", mean_implausibilities_clean_gc)

println("Using Decision Threshold Convergence:")

println("valid switches: ", total_validity_dt)
println("mean: ", mean(mean_implausibilities_clean_dt))
println("std: ", std(mean_implausibilities_clean_dt))

println("Using Generator Conditions Convergence:")

println("valid switches: ", total_validity_gc)
println("mean: ", mean(mean_implausibilities_clean_gc))
println("std: ", std(mean_implausibilities_clean_gc))

skips: 63
mean implausibilities using DTC: Any[1.4267614255115362, 1.5246982977274603, 1.476316894570456, 1.4781864458496086, 1.5886510148647106]
mean implausibilities using GCC: Any[1.8970965886924729, 2.029216841747039, 2.051623454582683, 1.9991429317202878, 2.083095632022346]
Using Decision Threshold Convergence:
valid switches: 436.0
mean: 1.4989228157047543
std: 0.06095862559496609
Using Generator Conditions Convergence:
valid switches: 436.0
mean: 2.012035089752966
std: 0.07121993868269524


In [40]:
mean_implausibilities_strong_dt = []
mean_implausibilities_strong_gc = []
total_validity_strong_dt = 0
total_validity_strong_gc = 0
model_to_use = adv_pgd_strong
skipped_strong = 0

@showprogress for (i, split) in enumerate(splits)
    println("here!")
    implausibilities_dt = []
    implausibilities_gc = []
    for (j, index) in enumerate(split)

        model_pred = (model_to_use(x_test[:, index]) |> Flux.onecold |> getindex) - 1
        if model_pred != y_test[index]
            skipped_strong += 1
            println("Skipping because of model misclassification")
            continue
        end

        if (j % 10 == 0)
            println("datapoint $j of split $i reached")
        end

        count = 0

        for convergence in convergences
            ce = generate_counterfactual(
            reshape(x_test'[index, :], 8, 1), targets[i][j], counterfactual_data, flux_adv_strong, generator; num_counterfactuals=1, convergence=convergence
            )

            implausibility = distance_from_targets(ce)

            if count == 0
                total_validity_strong_gc += evaluate(ce; measure=validity)[1][1]
                push!(implausibilities_gc, implausibility)
            elseif count == 1
                total_validity_strong_dt += evaluate(ce; measure=validity)[1][1]
                push!(implausibilities_dt, implausibility)
            end
            
            count += 1
        end
    end
    push!(mean_implausibilities_strong_dt, mean(implausibilities_dt))
    push!(mean_implausibilities_strong_gc, mean(implausibilities_gc))
end

mean_implausibilities_strong_dt

here!
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 10 of split 1 reached
Skipping because of model misclassification
datapoint 20 of split 1 reached
Skipping because of model misclassification
datapoint 30 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 40 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 50 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 60 of split 1 reached
datapoint 70 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 80 of split 1 reached
Skipping because of model misclass

[32mProgress:  40%|█████████████████                        |  ETA: 0:06:15[39m[K

here!
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 10 of split 3 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 30 of split 3 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 40 of split 3 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 50 of split 3 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 60 of split 3 reached
Skipping because of model misclassification
datapoint 70 of split 3 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 80 of split 3 reached
Skipping because of mo

[32mProgress:  60%|█████████████████████████                |  ETA: 0:04:12[39m[K

here!
Skipping because of model misclassification
datapoint 10 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 30 of split 4 reached
datapoint 40 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 50 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 60 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 70 of split 4 reached
Skipping b

[32mProgress:  80%|█████████████████████████████████        |  ETA: 0:02:06[39m[K


here!
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 10 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 20 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 30 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 40 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 50 of split 5 reached
Skipping because of model misclassification
datapoint 70 of split 5 reached
Skipping 

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:10:38[39m[K


5-element Vector{Any}:
 1.1249032601576487
 1.1697012552973238
 1.1731097891522333
 1.0750806253737086
 1.273030417344753

In [41]:
println("skips: ", skipped_strong)

println("mean implausibilities using DTC: ", mean_implausibilities_strong_dt)
println("mean implausibilities using GCC: ", mean_implausibilities_strong_gc)

println("Using Decision Threshold Convergence:")

println("valid switches: ", total_validity_strong_dt)
println("mean: ", mean(mean_implausibilities_strong_dt))
println("std: ", std(mean_implausibilities_strong_dt))

println("Using Generator Conditions Convergence:")

println("valid switches: ", total_validity_strong_gc)
println("mean: ", mean(mean_implausibilities_strong_gc))
println("std: ", std(mean_implausibilities_strong_gc))

skips: 108
mean implausibilities using DTC: Any[1.1249032601576487, 1.1697012552973238, 1.1731097891522333, 1.0750806253737086, 1.273030417344753]
mean implausibilities using GCC: Any[1.5187181720725076, 1.545800406612234, 1.568100018658987, 1.4743692675998696, 1.6121980293570308]
Using Decision Threshold Convergence:
valid switches: 392.0
mean: 1.1631650694651334
std: 0.07320322599103032
Using Generator Conditions Convergence:
valid switches: 392.0
mean: 1.543837178860126
std: 0.051774754310936345


In [42]:
mean_implausibilities_medium_dt = []
mean_implausibilities_medium_gc = []
total_validity_medium_dt = 0
total_validity_medium_gc = 0
model_to_use = adv_pgd_medium
skipped_medium = 0

@showprogress for (i, split) in enumerate(splits)
    println("here!")
    implausibilities_dt = []
    implausibilities_gc = []
    for (j, index) in enumerate(split)

        model_pred = (model_to_use(x_test[:, index]) |> Flux.onecold |> getindex) - 1
        if model_pred != y_test[index]
            skipped_medium += 1
            println("Skipping because of model misclassification")
            continue
        end

        if (j % 10 == 0)
            println("datapoint $j of split $i reached")
        end

        count = 0

        for convergence in convergences
            ce = generate_counterfactual(
            reshape(x_test'[index, :], 8, 1), targets[i][j], counterfactual_data, flux_adv_medium, generator; num_counterfactuals=1, convergence=convergence
            )

            implausibility = distance_from_targets(ce)

            if count == 0
                total_validity_medium_gc += evaluate(ce; measure=validity)[1][1]
                push!(implausibilities_gc, implausibility)
            elseif count == 1
                total_validity_medium_dt += evaluate(ce; measure=validity)[1][1]
                push!(implausibilities_dt, implausibility)
            end
            
            count += 1
        end
    end
    push!(mean_implausibilities_medium_dt, mean(implausibilities_dt))
    push!(mean_implausibilities_medium_gc, mean(implausibilities_gc))
end

mean_implausibilities_medium_dt

here!
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 10 of split 1 reached
Skipping because of model misclassification
datapoint 20 of split 1 reached
datapoint 30 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 40 of split 1 reached
Skipping because of model misclassification
datapoint 50 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 60 of split 1 reached
datapoint 70 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 80 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 90 of split 1 reached
datapoint 100 of split 1 reached
here!
Skippin

[32mProgress:  40%|█████████████████                        |  ETA: 0:07:07[39m[K

here!
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 10 of split 3 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 20 of split 3 reached
datapoint 30 of split 3 reached
Skipping because of model misclassification
datapoint 40 of split 3 reached
Skipping because of model misclassification
datapoint 50 of split 3 reached
Skipping because of model misclassification
datapoint 60 of split 3 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 70 of split 3 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 80 of split 3 reached
Skipping because of model misclassification
datapoint 90 of split 3 reached
Skipping because of model misclassification
Sk

[32mProgress:  60%|█████████████████████████                |  ETA: 0:04:39[39m[K

here!
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 10 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 30 of split 4 reached
datapoint 40 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 50 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 60 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 70 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping b

[32mProgress:  80%|█████████████████████████████████        |  ETA: 0:02:17[39m[K

here!
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 10 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 20 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 30 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 70 of split 5 reached
Skipping because of model misclassification
datapoint 80 of split 5 reached
Skipping because of model misclassificatio

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:11:17[39m[K





5-element Vector{Any}:
 1.1495166238552401
 1.2144272361734945
 1.2003238366939815
 1.1653396055424623
 1.3173814516920654

In [43]:
println("skips: ", skipped_medium)

println("mean implausibilities using DTC: ", mean_implausibilities_medium_dt)
println("mean implausibilities using GCC: ", mean_implausibilities_medium_gc)

println("Using Decision Threshold Convergence:")

println("valid switches: ", total_validity_medium_dt)
println("mean: ", mean(mean_implausibilities_medium_dt))
println("std: ", std(mean_implausibilities_medium_dt))

println("Using Generator Conditions Convergence:")

println("valid switches: ", total_validity_medium_gc)
println("mean: ", mean(mean_implausibilities_medium_gc))
println("std: ", std(mean_implausibilities_medium_gc))

skips: 88
mean implausibilities using DTC: Any[1.1495166238552401, 1.2144272361734945, 1.2003238366939815, 1.1653396055424623, 1.3173814516920654]
mean implausibilities using GCC: Any[1.6993332830514305, 1.8939813805394716, 1.9412652155218406, 1.8185770893336735, 2.0299497927689982]
Using Decision Threshold Convergence:
valid switches: 412.0
mean: 1.2093977507914486
std: 0.06575519148018161
Using Generator Conditions Convergence:
valid switches: 412.0
mean: 1.8766213522430828
std: 0.1252896706143049


In [44]:
mean_implausibilities_weak_dt = []
mean_implausibilities_weak_gc = []
total_validity_weak_dt = 0
total_validity_weak_gc = 0
model_to_use = adv_pgd_weak
skipped_weak = 0

@showprogress for (i, split) in enumerate(splits)
    println("here!")
    implausibilities_dt = []
    implausibilities_gc = []
    for (j, index) in enumerate(split)

        model_pred = (model_to_use(x_test[:, index]) |> Flux.onecold |> getindex) - 1
        if model_pred != y_test[index]
            skipped_weak += 1
            println("Skipping because of model misclassification")
            continue
        end

        if (j % 10 == 0)
            println("datapoint $j of split $i reached")
        end

        count = 0

        for convergence in convergences
            ce = generate_counterfactual(
            reshape(x_test'[index, :], 8, 1), targets[i][j], counterfactual_data, flux_adv_weak, generator; num_counterfactuals=1, convergence=convergence
            )

            implausibility = distance_from_targets(ce)

            if count == 0
                total_validity_weak_gc += evaluate(ce; measure=validity)[1][1]
                push!(implausibilities_gc, implausibility)
            elseif count == 1
                total_validity_weak_dt += evaluate(ce; measure=validity)[1][1]
                push!(implausibilities_dt, implausibility)
            end
            
            count += 1
        end
    end
    push!(mean_implausibilities_weak_dt, mean(implausibilities_dt))
    push!(mean_implausibilities_weak_gc, mean(implausibilities_gc))
end

mean_implausibilities_weak_dt

here!
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 10 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 20 of split 1 reached
Skipping because of model misclassification
datapoint 30 of split 1 reached
Skipping because of model misclassification
datapoint 40 of split 1 reached
Skipping because of model misclassification
datapoint 50 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 60 of split 1 reached
datapoint 70 of split 1 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 80 of split 1 reached
Skipping because of model misclassification
datapoint 90 of split 1 reached
datapoint 100 of split 1 reached
here!
datapoi

[32mProgress:  40%|█████████████████                        |  ETA: 0:06:58[39m[K

here!
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 10 of split 3 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 20 of split 3 reached
datapoint 30 of split 3 reached
datapoint 40 of split 3 reached
Skipping because of model misclassification
datapoint 50 of split 3 reached
Skipping because of model misclassification
datapoint 60 of split 3 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 70 of split 3 reached
Skipping because of model misclassification
datapoint 80 of split 3 reached
Skipping because of model misclassification
datapoint 90 of split 3 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 100 of split 3 reached


[32mProgress:  60%|█████████████████████████                |  ETA: 0:04:40[39m[K

here!
Skipping because of model misclassification
datapoint 10 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 30 of split 4 reached
datapoint 40 of split 4 reached
Skipping because of model misclassification
datapoint 50 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 60 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 70 of split 4 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 80 of split 4 reached
Skipping because of model misclassification
datapoint 90 of split 4 reached
datapoint 100 of split 4 reached


[32mProgress:  80%|█████████████████████████████████        |  ETA: 0:02:22[39m[K

here!
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 10 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 20 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 30 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 40 of split 5 reached
Skipping because of model misclassification
datapoint 50 of split 5 reached
Skipping because of model misclassification
Skipping because of model misclassification
Skipping because of model misclassification
datapoint 70 of split 5 reached
Skipping because of model misclassification
datapoint 80 of split 5 reached
Skipping because of model misclassification
datapoint 90 of split 5 reached
Skipping because of model misclass

[32mProgress: 100%|█████████████████████████████████████████| Time: 0:11:44[39m[K





5-element Vector{Any}:
 1.2676878286376387
 1.3645726726615375
 1.3164353558399768
 1.3017174688659383
 1.4549688923239505

In [45]:
println("skips: ", skipped_weak)

println("mean implausibilities using DTC: ", mean_implausibilities_weak_dt)
println("mean implausibilities using GCC: ", mean_implausibilities_weak_gc)

println("Using Decision Threshold Convergence:")

println("valid switches: ", total_validity_weak_dt)
println("mean: ", mean(mean_implausibilities_weak_dt))
println("std: ", std(mean_implausibilities_weak_dt))

println("Using Generator Conditions Convergence:")

println("valid switches: ", total_validity_weak_gc)
println("mean: ", mean(mean_implausibilities_weak_gc))
println("std: ", std(mean_implausibilities_weak_gc))

skips: 70
mean implausibilities using DTC: Any[1.2676878286376387, 1.3645726726615375, 1.3164353558399768, 1.3017174688659383, 1.4549688923239505]
mean implausibilities using GCC: Any[2.027445888388999, 2.105669214505507, 2.1845605486629203, 2.0210585454905257, 2.230242890262945]
Using Decision Threshold Convergence:
valid switches: 429.0
mean: 1.3410764436658085
std: 0.07257024796818826
Using Generator Conditions Convergence:
valid switches: 429.0
mean: 2.113795417462179
std: 0.09312642144456523
