In [1]:
using DrWatson
@quickactivate 

using JsonGrinder, Flux, MLDatasets, Statistics, Random, Printf, JSON3, HierarchicalUtils
using SumProductSet
import Mill

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mPrecompiling SumProductSet [d0366596-3556-49ae-b3ef-851ab4ad1106]


# Override JsonGrinder's default scalar extractors

In [2]:
"""
    default_scalar_extractor()

Build the list of `(predicate, constructor)` rules that is passed to `suggestextractor`
as `scalar_extractors`. Each predicate takes a scalar schema entry `e`; the paired
constructor takes `(e, uniontypes)` and returns an extractor. The rules, in order:

1. numeric (or numeric-string) leaves with at most 100 distinct values → `ExtractCategorical`;
2. leaves accepted by `JsonGrinder.is_intable` → `Int32` scalar via `extractscalar`;
3. leaves accepted by `JsonGrinder.is_floatable` → `Float32` scalar via `extractscalar`;
4. non-numeric leaves with fewer than 10000 distinct values whose distinct-to-observed
   ratio (`length(keys(e)) / e.updated`) is below 0.1 → `ExtractCategorical`;
5. anything else → a fixed `ExtractScalar(Float32, 0., 1., false)` (ignores `uniontypes`).

NOTE(review): presumably JsonGrinder applies the first rule whose predicate matches;
confirm against the JsonGrinder version in use.
"""
function default_scalar_extractor()
    to_categorical = (ent, utypes) -> ExtractCategorical(keys(ent), utypes)

    # Numeric leaves with few distinct values are one-hot encoded.
    low_cardinality_numeric = (
        ent -> length(keys(ent)) <= 100 && JsonGrinder.is_numeric_or_numeric_string(ent),
        to_categorical,
    )
    # Integer-like leaves become Int32 scalars.
    integer_scalar = (
        ent -> JsonGrinder.is_intable(ent),
        (ent, utypes) -> extractscalar(Int32, ent, utypes),
    )
    # Float-like leaves become Float32 scalars.
    float_scalar = (
        ent -> JsonGrinder.is_floatable(ent),
        (ent, utypes) -> extractscalar(Float32, ent, utypes),
    )
    # Non-numeric leaves whose values repeat a lot (few distinct values relative to
    # the number of observations) are one-hot encoded as well.
    repetitive_non_numeric = (
        function (ent)
            nkeys = length(keys(ent))
            return nkeys / ent.updated < 0.1 && nkeys < 10000 &&
                   !JsonGrinder.is_numeric_or_numeric_string(ent)
        end,
        to_categorical,
    )
    # Catch-all rule.
    fallback = (
        ent -> true,
        (ent, utypes) -> ExtractScalar(Float32, 0., 1., false),
    )

    return [low_cardinality_numeric, integer_scalar, float_scalar,
            repetitive_non_numeric, fallback]
end

default_scalar_extractor (generic function with 1 method)

# Extract data into Mill format

In [3]:
# Load the Mutagenesis training split (documents plus labels).
train_data = MLDatasets.Mutagenesis(split = :train)
x_train = train_data.features
y_train = train_data.targets
# Shift labels up by one in place (this also mutates `train_data.targets`);
# presumably 0/1 → 1/2 so they match 1-based class indices — confirm.
y_train .+= 1
# Infer a schema from the training documents and derive an extractor from it,
# using the custom scalar rules from `default_scalar_extractor`.
sch = JsonGrinder.schema(x_train)
extractor = suggestextractor(sch, (; scalar_extractors = default_scalar_extractor()))
# Extract every document and concatenate the results into one Mill data node.
ds_train = Mill.catobs(map(extractor, x_train))

[34mProductNode[39m[90m  # 100 obs, 104 bytes[39m
[34m  ├─── lumo: [39m[39mArrayNode(99×100 OneHotArray with Bool elements)[90m  # 100 obs, 456  [39m[90m⋯[39m
[34m  ├─── inda: [39m[39mArrayNode(2×100 OneHotArray with Bool elements)[90m  # 100 obs, 456 b [39m[90m⋯[39m
[34m  ├─── logp: [39m[39mArrayNode(63×100 OneHotArray with Bool elements)[90m  # 100 obs, 456  [39m[90m⋯[39m
[34m  ├─── ind1: [39m[39mArrayNode(3×100 OneHotArray with Bool elements)[90m  # 100 obs, 456 b [39m[90m⋯[39m
[34m  ╰── atoms: [39m[31mBagNode[39m[90m  # 100 obs, 1.680 KiB[39m
[34m             [39m[31m  ╰── [39m[32mProductNode[39m[90m  # 2529 obs, 64 bytes[39m
[34m             [39m[31m      [39m[32m  ├──── element: [39m[39mArrayNode(7×2529 OneHotArray with Bool ele [90m⋯[39m
[34m             [39m[31m      [39m[32m  ├────── bonds: [39m[33mBagNode[39m[90m  # 2529 obs, 39.602 KiB[39m
[34m             [39m[31m      [39m[32m  │              [39m[33m 

In [4]:
# Print the schema inferred from the training set as a tree, with the `htrunc`/`vtrunc`
# truncation limits set to 25.
printtree(sch; htrunc = 25, vtrunc = 25)

[34m[Dict][39m[90m  # updated = 100[39m
[34m  ├─── lumo: [39m[39m[Scalar - Float64], 98 unique values[90m  # updated = 100[39m
[34m  ├─── inda: [39m[39m[Scalar - Int64], 1 unique values[90m  # updated = 100[39m
[34m  ├─── logp: [39m[39m[Scalar - Float64,Int64], 62 unique values[90m  # updated = 100[39m
[34m  ├─── ind1: [39m[39m[Scalar - Int64], 2 unique values[90m  # updated = 100[39m
[34m  ╰── atoms: [39m[31m[List][39m[90m  # updated = 100[39m
[34m             [39m[31m  ╰── [39m[32m[Dict][39m[90m  # updated = 2529[39m
[34m             [39m[31m      [39m[32m  ├──── element: [39m[39m[Scalar - String], 6 unique values[90m  # updated = 2529[39m
[34m             [39m[31m      [39m[32m  ├────── bonds: [39m[33m[List][39m[90m  # updated = 2529[39m
[34m             [39m[31m      [39m[32m  │              [39m[33m  ╰── [39m[36m[Dict][39m[90m  # updated = 5402[39m
[34m             [39m[31m      [39m[32m  │              [39

In [5]:
# Load the test split and extract it with the extractor built from the training schema.
test_data = MLDatasets.Mutagenesis(split = :test)
x_test = test_data.features
y_test = test_data.targets
y_test .+= 1  # same in-place label shift as applied to the training split
ds_test = Mill.catobs(map(extractor, x_test));

# Define training utility functions

In [6]:
"""
    train!(m, x, y; niter=100, opt=ADAM(0.1), cb=iter->())

Fit the parameters of `m` to data `x` with labels `y` by taking `niter` full-batch
gradient steps on `SumProductSet.ce_loss(m, x, y)` using optimiser `opt`; the
parameters collected by `Flux.params(m)` are updated in place.

`cb` is called with `0` before the first step and with the step number after every
step (e.g. to log progress). Returns `nothing`.
"""
function train!(m, x, y; niter::Int = 100, opt = ADAM(0.1), cb = iter -> ())
    θ = Flux.params(m)
    objective() = SumProductSet.ce_loss(m, x, y)

    cb(0)
    for step in 1:niter
        ∇θ = gradient(objective, θ)
        Flux.Optimise.update!(opt, θ, ∇θ)
        cb(step)
    end
    return nothing
end

train! (generic function with 1 method)

In [7]:
"""
    predict(model, x)
    predict(x)

Return the predicted class index of every observation in `x`: `Flux.onecold` of the
`softmax` of `logjnt(model, x)` (presumably a classes × observations matrix, so one
index per column — confirm).

The one-argument form keeps the original notebook interface and uses the global model
`m`, which is assigned in a later cell; `m` is looked up when `predict(x)` is called,
not when it is defined. Prefer `predict(model, x)` to make the dependency explicit.
"""
predict(model, x) = Flux.onecold(softmax(logjnt(model, x)))
predict(x) = predict(m, x)

#28 (generic function with 1 method)

In [8]:
# Fraction of observations in `x` whose predicted class (via `predict`) equals the label in `y`.
accuracy(y, x) = mean(predict(x) .== y)

"""
    status(iter, x_trn, y_trn, x_tst, y_tst)

Print the training and test accuracy reached at iteration `iter` as one line of the
form `Epoch <iter> - acc: | <train> <test> |`. Returns `nothing`.
"""
function status(iter, x_trn, y_trn, x_tst, y_tst)
    train_acc, test_acc = accuracy(y_trn, x_trn), accuracy(y_tst, x_tst)
    @printf("Epoch %i - acc: | %.3f  %.3f | \n", iter, train_acc, test_acc)
    return nothing
end

status (generic function with 1 method)

# Initialize a model whose structure reflects the training data

In [9]:
Random.seed!(1234);

# Random probability vector of length `d`: uniform draws normalised to sum to one.
# NOTE(review): despite the name this is not a Dirichlet sample (normalised
# `-log.(rand(d))` would give Dirichlet(1, …, 1)); harmless if only used for random init.
function dir_rand(d)
    r = rand(d)
    return r ./ sum(r)
end

# Leaf constructors: `f_cat` for categorical variables (log of a random probability
# vector), `f_cont` for continuous variables (`gmm(2, d)`).
# NOTE(review): neither is passed to `reflectinmodel` below, so they do not affect the
# model built in this cell — confirm whether they were meant to be forwarded.
f_cat = d -> Categorical(log.(dir_rand(d)))
f_cont = d -> gmm(2, d)

# Build a model whose structure mirrors one training observation (the `2` is presumably
# the number of classes/components — confirm), then train it on the full training set,
# reporting train/test accuracy before training and after every step.
m = reflectinmodel(ds_train[1], 2)
cb = iter -> status(iter, ds_train, y_train, ds_test, y_test)
@time train!(m, ds_train, y_train; niter = 100, opt = ADAM(0.2), cb = cb)

Epoch 0 - acc: | 0.610  0.682 | 
Epoch 1 - acc: | 0.610  0.682 | 
Epoch 2 - acc: | 0.610  0.682 | 
Epoch 3 - acc: | 0.670  0.682 | 
Epoch 4 - acc: | 0.850  0.773 | 
Epoch 5 - acc: | 0.850  0.795 | 
Epoch 6 - acc: | 0.830  0.773 | 
Epoch 7 - acc: | 0.820  0.818 | 
Epoch 8 - acc: | 0.800  0.750 | 
Epoch 9 - acc: | 0.800  0.727 | 
Epoch 10 - acc: | 0.790  0.682 | 
Epoch 11 - acc: | 0.790  0.659 | 
Epoch 12 - acc: | 0.790  0.682 | 
Epoch 13 - acc: | 0.790  0.705 | 
Epoch 14 - acc: | 0.790  0.727 | 
Epoch 15 - acc: | 0.800  0.750 | 
Epoch 16 - acc: | 0.800  0.773 | 
Epoch 17 - acc: | 0.840  0.773 | 
Epoch 18 - acc: | 0.890  0.773 | 
Epoch 19 - acc: | 0.880  0.773 | 
Epoch 20 - acc: | 0.880  0.795 | 
Epoch 21 - acc: | 0.870  0.795 | 
Epoch 22 - acc: | 0.870  0.795 | 
Epoch 23 - acc: | 0.880  0.795 | 
Epoch 24 - acc: | 0.890  0.795 | 
Epoch 25 - acc: | 0.880  0.795 | 
Epoch 26 - acc: | 0.880  0.795 | 
Epoch 27 - acc: | 0.880  0.773 | 
Epoch 28 - acc: | 0.870  0.773 | 
Epoch 29 - acc: | 0.850 

In [10]:
# Log-density of each training observation under the trained model `m`
# (displayed as a 1×N matrix, one column per observation).
logpdf(m, ds_train)

1×100 Matrix{Float64}:
 -189.62  -208.523  -248.18  -266.854  …  -259.407  -218.131  -210.526

In [11]:
# Total number of trainable scalars in `m`: element counts summed over every array
# collected by `Flux.params`.
sum(length, Flux.params(m))

516

In [12]:
# Print the structure of the trained model as a tree, with the `htrunc`/`vtrunc`
# truncation limits set to 25.
printtree(m; htrunc = 25, vtrunc = 25)

[34mSumNode[39m
[34m  ├── [39m[31mProductNode (:lumo, :inda, :logp, :ind1, :atoms)[39m
[34m  │   [39m[31m  ├── [39m[39mCategorical
[34m  │   [39m[31m  ├── [39m[39mCategorical
[34m  │   [39m[31m  ├── [39m[39mCategorical
[34m  │   [39m[31m  ├── [39m[39mCategorical
[34m  │   [39m[31m  ╰── [39m[32mSetNode[39m
[34m  │   [39m[31m      [39m[32m  ├── c: [39m[39mPoisson
[34m  │   [39m[31m      [39m[32m  ╰── f: [39m[33mProductNode (:element, :bonds, :charge, :atom_type)[39m
[34m  │   [39m[31m      [39m[32m         [39m[33m  ├── [39m[39mCategorical
[34m  │   [39m[31m      [39m[32m         [39m[33m  ├── [39m[36mSetNode[39m
[34m  │   [39m[31m      [39m[32m         [39m[33m  │   [39m[36m  ├── c: [39m[39mPoisson
[34m  │   [39m[31m      [39m[32m         [39m[33m  │   [39m[36m  ╰── f: [39m[35mProductNode (:element, :bond_type, :charge, :atom_type)[39m
[34m  │   [39m[31m      [39m[32m         [39m[33m  │   [