# MNIST pre-trained model

In [1]:
# Load the project helper scripts (e.g. `generate_artifact` is used further below).
for helper_file in ("generate_artifacts.jl", "utils.jl")
    include(helper_file)
end
# Name prefix of every file written by this notebook, and the folder they are saved in.
artifact_name = "mnist"
data_dir = "../data";

### Retrieve data

In [7]:
using Plots, MLDatasets
using MLDatasets.MNIST: convert2image
using BSON: @save, @load
# Fetch the MNIST training split and convert both the images and the labels to Float32.
train_x, train_y = map(arr -> Float32.(arr), MNIST.traindata());

### Preprocess data

In [8]:
using Flux
using Flux: onehotbatch, onecold, DataLoader
# Flatten every image into one column, giving a (features × samples) matrix.
x = Flux.flatten(train_x)
# One-hot encode the digit labels against the classes 0:9 (one column per sample).
y = Flux.onehotbatch(train_y, 0:9)
# Batch size chosen so the training set is split into about ten batches.
bs = round(Int, size(x, 2) / 10)
data = DataLoader((x, y), batchsize = bs)
# Persist the batched data; this file is bundled into the artifact at the end of the notebook.
@save joinpath(data_dir, artifact_name * "_data.bson") data

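### Classifier

Training is disabled by default so that the previously saved model is reused; the epoch log shown below the cell was recorded during an earlier training run.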

In [9]:
# Classifier dimensions: ten output scores (digits 0–9) and a single hidden layer width.
output_dim = 10
# Number of pixels per image: the product of the first two (per-image) dimensions.
# Reading `size` directly avoids copying an entire image slice just to measure it.
input_dim = prod(size(train_x)[1:2])
hidden_dim = 32

kw_args = (input_dim=input_dim, n_hidden=hidden_dim, output_dim=output_dim, batch_norm=true)
# `build_model` is provided by the included project helpers.
model = build_model(; kw_args...)
# NOTE(review): `logitcrossentropy` assumes `model` outputs raw logits — confirm in utils.jl.
loss(x, y) = Flux.Losses.logitcrossentropy(model(x), y)


using Flux.Optimise: update!, ADAM
using Statistics
opt = ADAM()
epochs = 10

# Mean of the per-batch losses over all batches in `data`.
avg_loss(data) = mean(map(d -> loss(d[1], d[2]), data))

"""
    accuracy(data)

Return the classification accuracy of `model`, averaged over the batches of `data`.
`softmax` is monotone, so `onecold` on the raw model output picks the same class
without the extra normalisation work.
"""
function accuracy(data)
    batch_accuracy(d) = mean(onecold(model(d[1]), 0:9) .== onecold(d[2], 0:9))
    return mean(map(batch_accuracy, data))
end

# Training is skipped by default so the previously saved model is reused;
# set to `true` to retrain and overwrite the saved model file.
retrain_model = false

if retrain_model
    # The trainable arrays are updated in place, so collect them once rather than per batch.
    ps = params(model)
    for epoch in 1:epochs
        for d in data
            gs = gradient(ps) do
                loss(d...)
            end
            update!(opt, ps, gs)
        end
        @info "Epoch $epoch"
        @show accuracy(data)
    end
    @save joinpath(data_dir, artifact_name * "_model.bson") model
end

┌ Info: Epoch 1
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/mnist.ipynb:25


accuracy(data) = 0.5373666666666667




┌ Info: Epoch 2
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/mnist.ipynb:25


accuracy(data) = 0.7550833333333333


┌ Info: Epoch 3
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/mnist.ipynb:25


accuracy(data) = 0.8258666666666666


┌ Info: Epoch 4
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/mnist.ipynb:25


accuracy(data) = 0.8582000000000001


┌ Info: Epoch 5
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/mnist.ipynb:25


accuracy(data) = 0.8751333333333333


┌ Info: Epoch 6
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/mnist.ipynb:25


accuracy(data) = 0.8875666666666666


┌ Info: Epoch 7
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/mnist.ipynb:25


accuracy(data) = 0.8974666666666666


┌ Info: Epoch 8
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/mnist.ipynb:25


accuracy(data) = 0.9047333333333334


┌ Info: Epoch 9
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/mnist.ipynb:25


accuracy(data) = 0.9100666666666667


┌ Info: Epoch 10
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/mnist.ipynb:25


accuracy(data) = 0.9144833333333333


### Generate artifact

In [10]:
# Bundle the saved data and model files (both prefixed with `artifact_name`) into artifacts.
datafiles = [artifact_name * suffix for suffix in ("_data.bson", "_model.bson")]
generate_artifact(datafiles)

└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/generate_artifacts.jl:13


┌ Info: Binding mnist_data in Artifacts.toml...
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/generate_artifacts.jl:56


┌ Info: Binding mnist_model in Artifacts.toml...
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/generate_artifacts.jl:56
┌ Info: Uploading tarballs to pat-alt/CounterfactualExplanations.jl tag `data`
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/generate_artifacts.jl:67


--> Deleting: mnist_model.tar.gz
--> Deleting: mnist_data.tar.gz
--> Uploading: mnist_data.tar.gz
--> Uploading: mnist_model.tar.gz


┌ Info: Artifacts.toml file now contains all bound artifact names
└ @ Main /Users/FA31DU/Documents/code/CounterfactualExplanations.jl/dev/artifacts/generate_artifacts.jl:73
