# MNIST pre-trained model

In [44]:
# Load the artifact-generation helpers (the log output of the final cell shows
# `generate_artifact` running from generate_artifacts.jl) and the shared
# utilities. NOTE(review): `build_model`/`build_ensemble` are presumably
# defined in ../utils.jl — confirm.
include("generate_artifacts.jl")
include("../utils.jl")

# Directory for saved outputs and the file-name prefix shared by all artifacts.
data_dir, artifact_name = "../data", "mnist";

### Retrieve data

In [45]:
using Plots, MLDatasets
using MLDatasets.MNIST: convert2image
using BSON
using BSON: @save, @load
# Fetch the MNIST training split and convert both the image array and the
# label vector to Float32 in one pass over the returned (features, labels) pair.
train_x, train_y = map(arr -> Float32.(arr), MNIST.traindata());

### Preprocess data

In [46]:
using Flux
using Flux: onehotbatch, onecold, DataLoader
# One-hot encode the digit labels over the classes 0:9 and flatten every image
# (all dimensions but the last, observation one) into a column of `x`.
y = Flux.onehotbatch(train_y, 0:9)
x = Flux.flatten(train_x)

# Split the training set into `n_batches` mini-batches of (roughly) equal size.
n_batches = 10
bs = round(Int, size(x, 2) / n_batches)
data_train = DataLoader((x, y), batchsize=bs)

# Persist the mini-batch loader together with the raw feature/label matrices.
# The loader built above is reused, not rebuilt, so the saved object cannot
# drift from `data_train` if its settings are later changed in one place only.
data = Dict(
    "data" => data_train,
    "x" => x,
    "y" => y,
)
@save joinpath(data_dir, artifact_name * "_data.bson") data

### Classifier

In [47]:
# From here on `data` refers to the mini-batch loader (not the saved Dict).
data = data_train

# Classifier dimensions: one input per pixel, `hidden_dim` hidden units and one
# output logit per digit class.
output_dim = 10
# Product of all non-observation dimensions of the image array; computed from
# `size` directly so no image slice has to be copied.
input_dim = prod(size(train_x)[1:end-1])
hidden_dim = 32
kw_args = (
    input_dim=input_dim,
    n_hidden=hidden_dim,
    output_dim=output_dim,
    batch_norm=true,
)
model = build_model(; kw_args...)

# Cross-entropy applied to raw logits; equivalent to crossentropy(softmax(ŷ), y)
# but numerically more stable.
loss(x, y) = Flux.Losses.logitcrossentropy(model(x), y)

using Flux.Optimise: update!, ADAM
using Statistics
opt = ADAM()
epochs = 10

# Mean loss over all mini-batches in `data`.
avg_loss(data) = mean(map(d -> loss(d[1], d[2]), data))
# Mean per-batch classification accuracy over `data`. softmax is monotone, so
# taking `onecold` of the raw logits already yields the predicted class.
accuracy(data) = mean(map(d -> mean(onecold(model(d[1]), 0:9) .== onecold(d[2], 0:9)), data))

# Set to `true` to retrain the single classifier and overwrite the saved model
# artifact; by default training and saving are skipped.
retrain_model = false

if retrain_model
    # The set of trainable parameters does not change between steps, so it is
    # collected once instead of on every iteration.
    ps = Flux.params(model)
    for epoch in 1:epochs
        for (xb, yb) in data
            gs = gradient(ps) do
                loss(xb, yb)
            end
            update!(opt, ps, gs)
        end
        @info "Epoch $epoch"
        @show accuracy(data)  # accuracy on the training batches
    end
    @save joinpath(data_dir, artifact_name * "_model.bson") model
end

### Ensemble classifier

In [48]:
# Ensemble of 5 classifiers, each configured with the same `kw_args` used for
# the single model.
𝓜 = build_ensemble(5; kw=kw_args)

# Set to `true` to fit the ensemble and overwrite the saved ensemble artifact;
# by default fitting and saving are skipped.
retrain_ensemble = false

if retrain_ensemble
    # Fit every member on the mini-batch loader with the same optimiser and
    # epoch count as the single classifier.
    𝓜, anim = forward(𝓜, data, opt, n_epochs=epochs, plot_loss=false)
    save_ensemble(𝓜; root=joinpath(data_dir, artifact_name * "_ensemble"))
end

### Generate artifacts

In [49]:
# Names of the three saved outputs (preprocessed data, single model, ensemble);
# the prefix string broadcasts as a scalar over the suffix vector.
suffixes = ["_data.bson", "_model.bson", "_ensemble"]
datafiles = artifact_name .* suffixes
generate_artifact(datafiles)

└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/generate_artifacts.jl:13


┌ Info: Binding mnist_data in Artifacts.toml...
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/generate_artifacts.jl:53


┌ Info: Binding mnist_model in Artifacts.toml...
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/generate_artifacts.jl:53


┌ Info: Binding mnist_ensemble in Artifacts.toml...
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/generate_artifacts.jl:53
┌ Info: Uploading tarballs to pat-alt/CounterfactualExplanations.jl tag `data`
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/generate_artifacts.jl:64


--> Deleting: mnist_model.tar.gz
--> Deleting: mnist_ensemble.tar.gz
--> Deleting: mnist_data.tar.gz
--> Uploading: mnist_data.tar.gz
--> Uploading: mnist_ensemble.tar.gz
--> Uploading: mnist_model.tar.gz


┌ Info: Artifacts.toml file now contains all bound artifact names
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/generate_artifacts.jl:70
