# Mutagenesis

This example demonstrates how to predict [the mutagenicity on Salmonella typhimurium](https://relational.fel.cvut.cz/dataset/Mutagenesis).

The full environment, the script and the data are accessible [here](https://github.com/CTUAvastLab/JsonGrinder.jl/tree/master/docs/src/examples/mutagenesis).

We start by activating the environment and installing the required packages:

In [1]:
using Pkg
Pkg.activate(pwd())
Pkg.instantiate()
Pkg.status()

  Activating project at `~/work/JsonGrinder.jl/JsonGrinder.jl/docs/src/examples/mutagenesis`
Status `~/work/JsonGrinder.jl/JsonGrinder.jl/docs/src/examples/mutagenesis/Project.toml`
⌅ [587475ba] Flux v0.14.25
  [682c06a0] JSON v0.21.4
  [d201646e] JsonGrinder v2.6.2
  [f1d291b0] MLUtils v0.4.4
  [1d0525e4] Mill v2.11.2
Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated`


We load all dependencies and fix the seed:

In [2]:
using JsonGrinder, Mill, Flux, JSON, MLUtils, Statistics

using Random; Random.seed!(42);

### Loading the data

We load the dataset and split it into training and testing sets.

In [3]:
dataset = JSON.parsefile("mutagenesis.json");
jss_train, jss_test = dataset[1:100], dataset[101:end];

`jss_train` and `jss_test` are just lists of parsed JSONs:

In [4]:
jss_train[1]

Dict{String, Any} with 6 entries:
  "ind1"      => 1
  "lumo"      => -1.246
  "inda"      => 0
  "logp"      => 4.23
  "mutagenic" => 1
  "atoms"     => Any[Dict{String, Any}("element"=>"c", "atom_type"=>22, "bonds"…

We also extract binary labels, which are stored in the `"mutagenic"` key:

In [5]:
y_train = getindex.(jss_train, "mutagenic");
y_test = getindex.(jss_test, "mutagenic");
y_train

100-element Vector{Int64}:
 1
 1
 0
 1
 1
 1
 1
 1
 1
 1
 ⋮
 0
 0
 1
 1
 0
 0
 1
 0
 0
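Before training, it can help to check how balanced the labels are, since that sets the accuracy a trivial classifier would reach. The snippet below is a minimal sketch: `y_demo` is a short hypothetical stand-in for `y_train` so that the code runs on its own, and `majority_baseline` is an illustrative name, not part of any package.

```julia
using Statistics

# Hypothetical stand-in for `y_train`; the real vector has 100 entries.
y_demo = [1, 1, 0, 1, 0, 0, 1, 1]

# Fraction of positive (mutagenic) samples.
positive_rate = mean(y_demo .== 1)

# Accuracy of a trivial classifier that always predicts the majority class.
majority_baseline = max(positive_rate, 1 - positive_rate)

println("positive rate = ", positive_rate)
println("majority-class baseline = ", majority_baseline)
```

Replacing `y_demo` with `y_train` gives the corresponding numbers for the actual training set.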

We first create the `schema` of the training data, which is the first important step in using
[`JsonGrinder.jl`](https://github.com/CTUAvastLab/JsonGrinder.jl).
The schema captures both the hierarchical structure of the documents and basic statistics of individual values.

In [6]:
sch = schema(jss_train)

DictEntry 100x updated
  ├────── atoms: ArrayEntry 100x updated
  │                ╰── DictEntry 2529x updated
  │                      ├── atom_type: LeafEntry (28 unique `Real` values) 25 ⋯
  │                      ├────── bonds: ArrayEntry 2529x updated
  │                      │                ╰── DictEntry 5402x updated
  │                      │                      ┊
  │                      ├───── charge: LeafEntry (318 unique `Real` values) 2 ⋯
  │                      ╰──── element: LeafEntry (6 unique `String` values)

Of course, we have to remove the `"mutagenic"` key from the schema, as we don't want to include it
in the data:

In [7]:
delete!(sch, :mutagenic);
sch

DictEntry 100x updated
  ├── atoms: ArrayEntry 100x updated
  │            ╰── DictEntry 2529x updated
  │                  ├── atom_type: LeafEntry (28 unique `Real` values) 2529x  ⋯
  │                  ├────── bonds: ArrayEntry 2529x updated
  │                  │                ╰── DictEntry 5402x updated
  │                  │                      ┊
  │                  ├───── charge: LeafEntry (318 unique `Real` values) 2529x ⋯
  │                  ╰──── element: LeafEntry (6 unique `String` values) 2529x ⋯

Now we create an extractor capable of converting JSONs to [`Mill.jl`](https://github.com/CTUAvastLab/Mill.jl) structures.
We use the `suggestextractor` function with the default settings:

In [8]:
e = suggestextractor(sch)

DictExtractor
  ├─── lumo: CategoricalExtractor(n=99)
  ├─── inda: CategoricalExtractor(n=2)
  ├─── logp: CategoricalExtractor(n=63)
  ├─── ind1: CategoricalExtractor(n=3)
  ╰── atoms: ArrayExtractor
               ╰── DictExtractor
                     ├──── element: CategoricalExtractor(n=7)
                     ├────── bonds: ArrayExtractor
                     │                ╰── DictExtractor
                     │                      ┊
                     ├───── charge: ScalarExtractor(c=-0.781, s=0.60790277)
                     ╰── atom_type: CategoricalExtractor(n=29)

We also need to convert JSONs to [`Mill.jl`](https://github.com/CTUAvastLab/Mill.jl) data samples.
The extractor `e` is callable, so we can use it to extract a single document as follows:

In [9]:
x_single = e(jss_train[1])

ProductNode  1 obs
  ├─── lumo: ArrayNode(99×1 OneHotArray with Bool elements)  1 obs
  ├─── inda: ArrayNode(2×1 OneHotArray with Bool elements)  1 obs
  ├─── logp: ArrayNode(63×1 OneHotArray with Bool elements)  1 obs
  ├─── ind1: ArrayNode(3×1 OneHotArray with Bool elements)  1 obs
  ╰── atoms: BagNode  1 obs
               ╰── ProductNode  26 obs
                     ├──── element: ArrayNode(7×26 OneHotArray with Bool eleme ⋯
                     ├────── bonds: BagNode  26 obs
                     │                ╰── ProductNode  56 obs

To extract a batch of 10 documents, we can extract the individual documents and then concatenate them with `Mill.catobs`:

In [10]:
x_batch = reduce(catobs, e.(jss_train[1:10]))

ProductNode  10 obs
  ├─── lumo: ArrayNode(99×10 OneHotArray with Bool elements)  10 obs
  ├─── inda: ArrayNode(2×10 OneHotArray with Bool elements)  10 obs
  ├─── logp: ArrayNode(63×10 OneHotArray with Bool elements)  10 obs
  ├─── ind1: ArrayNode(3×10 OneHotArray with Bool elements)  10 obs
  ╰── atoms: BagNode  10 obs
               ╰── ProductNode  299 obs
                     ├──── element: ArrayNode(7×299 OneHotArray with Bool elem ⋯
                     ├────── bonds: BagNode  299 obs
                     │                ╰── ProductNode  650 obs

Alternatively, we can use the much more efficient `extract` function, which operates on a list of documents.
Because the dataset is small, we can extract all data at once and keep it in memory:

In [11]:
x_train = extract(e, jss_train);
x_test = extract(e, jss_test);
x_train

ProductNode  100 obs
  ├─── lumo: ArrayNode(99×100 OneHotArray with Bool elements)  100 obs
  ├─── inda: ArrayNode(2×100 OneHotArray with Bool elements)  100 obs
  ├─── logp: ArrayNode(63×100 OneHotArray with Bool elements)  100 obs
  ├─── ind1: ArrayNode(3×100 OneHotArray with Bool elements)  100 obs
  ╰── atoms: BagNode  100 obs
               ╰── ProductNode  2529 obs
                     ├──── element: ArrayNode(7×2529 OneHotArray with Bool ele ⋯
                     ├────── bonds: BagNode  2529 obs
                     │                ╰── ProductNode  5402 obs


Then we create an encoding model capable of embedding each JSON document into a fixed-size vector.

In [12]:
encoder = reflectinmodel(sch, e)

ProductModel ↦ Dense(50 => 10)  2 arrays, 510 params, 2.078 KiB
  ├─── lumo: ArrayModel(Dense(99 => 10))  2 arrays, 1_000 params, 3.992 KiB
  ├─── inda: ArrayModel(Dense(2 => 10))  2 arrays, 30 params, 208 bytes
  ├─── logp: ArrayModel(Dense(63 => 10))  2 arrays, 640 params, 2.586 KiB
  ├─── ind1: ArrayModel(Dense(3 => 10))  2 arrays, 40 params, 248 bytes
  ╰── atoms: BagModel ↦ BagCount([SegmentedMean(10); SegmentedMax(10)]) ↦ Dens ⋯
               ╰── ProductModel ↦ Dense(31 => 10)  2 arrays, 320 params, 1.336 ⋯
                     ├──── element: ArrayModel(Dense(7 => 10))  2 arrays, 80 p ⋯
                     ├────── bonds: BagModel ↦

For further details about `reflectinmodel`, see the [Mill.jl documentation](https://CTUAvastLab.github.io/Mill.jl/stable/manual/reflectin/#Model-Reflection).

Finally, we chain the `encoder` with one more dense layer, which computes the logit of the probability that a compound is mutagenic:

In [13]:
model = vec ∘ Dense(10, 1) ∘ encoder

vec ∘ Dense(10 => 1) ∘ ProductModel ↦ Dense(50 => 10)

We can train the model in the standard [`Flux.jl`](https://fluxml.ai) way. We define the loss
function, optimizer, and minibatch iterator:

In [14]:
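# Predicted probabilities: the model outputs logits, so we apply the sigmoid.
pred(m, x) = σ.(m(x))
# The loss works on raw logits, which is numerically more stable.
loss(m, x, y) = Flux.Losses.logitbinarycrossentropy(m(x), y);
# Plain gradient descent with the default learning rate.
opt_state = Flux.setup(Flux.Optimise.Descent(), model);
# Shuffled minibatches of 32 samples.
minibatch_iterator = Flux.DataLoader((x_train, y_train), batchsize=32, shuffle=true);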
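The loss above is computed from logits rather than probabilities. The following plain-Julia sketch illustrates why the two formulations are equivalent. It does not reproduce Flux's internals: `sigmoid`, `softplus`, `logit_bce` and `prob_bce` are hypothetical helper names, and the logits and labels are made-up numbers.

```julia
using Statistics

# Hypothetical helpers for illustration only; this is not Flux code.
sigmoid(z) = 1 / (1 + exp(-z))

# Numerically stable log(1 + exp(z)).
softplus(z) = max(z, 0) + log1p(exp(-abs(z)))

# Per-sample binary cross-entropy computed directly from the logit `z`.
logit_bce(z, y) = (1 - y) * z + softplus(-z)

# Per-sample binary cross-entropy computed from the probability `p`.
prob_bce(p, y) = -y * log(p) - (1 - y) * log(1 - p)

logits = [2.0, -1.0, 0.5]   # made-up model outputs
labels = [1, 0, 1]          # made-up binary labels

loss_from_logits = mean(logit_bce.(logits, labels))
loss_from_probs = mean(prob_bce.(sigmoid.(logits), labels))

# Both formulations should agree up to floating-point error.
println(loss_from_logits ≈ loss_from_probs)
```

The logit form avoids evaluating `log` on a probability that has been rounded to exactly 0 or 1, which is the usual reason for preferring it.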

We train for 10 epochs, and after each epoch we report the training accuracy:

In [15]:
accuracy(p, y) = mean((p .> 0.5) .== y)
for i in 1:10
    Flux.train!(loss, model, minibatch_iterator, opt_state)
    @info "Epoch $i" accuracy=accuracy(pred(model, x_train), y_train)
end

┌ Info: Epoch 1
└   accuracy = 0.61
┌ Info: Epoch 2
└   accuracy = 0.63
┌ Info: Epoch 3
└   accuracy = 0.64
┌ Info: Epoch 4
└   accuracy = 0.74
┌ Info: Epoch 5
└   accuracy = 0.61
┌ Info: Epoch 6
└   accuracy = 0.82
┌ Info: Epoch 7
└   accuracy = 0.82
┌ Info: Epoch 8
└   accuracy = 0.84
┌ Info: Epoch 9
└   accuracy = 0.82
┌ Info: Epoch 10
└   accuracy = 0.82


We can compute the accuracy on the testing set now:

In [16]:
accuracy(pred(model, x_test), y_test)

0.8636363636363636
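Accuracy alone does not show which kind of errors the model makes. A confusion matrix can be computed in a few lines, as sketched below. `p_demo` and `y_demo` are hypothetical stand-ins for `pred(model, x_test)` and `y_test`, so the counts say nothing about the actual model.

```julia
# Hypothetical stand-ins for `pred(model, x_test)` and `y_test`.
p_demo = [0.9, 0.2, 0.7, 0.4, 0.6]
y_demo = [1, 0, 0, 1, 1]

# Threshold the probabilities at 0.5, as in `accuracy` above.
yhat = p_demo .> 0.5
is_pos = y_demo .== 1

tp = count(yhat .& is_pos)        # predicted mutagenic, truly mutagenic
fp = count(yhat .& .!is_pos)      # predicted mutagenic, truly not
fn = count(.!yhat .& is_pos)      # predicted not mutagenic, truly mutagenic
tn = count(.!yhat .& .!is_pos)    # predicted not mutagenic, truly not

acc = (tp + tn) / length(y_demo)
println((tp = tp, fp = fp, fn = fn, tn = tn, accuracy = acc))
```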

---

*This notebook was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*