In [1]:
using DrWatson
@quickactivate 

using JsonGrinder, Flux, MLDatasets, Statistics, Random, Printf, JSON, HierarchicalUtils
using SumProductSet
import Mill
import Distributions

# Override JsonGrinder's default scalar extractors

In [2]:
"""
    default_scalar_extractor()

Build the rules this notebook passes to `suggestextractor` as `scalar_extractors`.

Each rule is a `(predicate, constructor)` tuple. `predicate(entry)` inspects a scalar
schema entry and `constructor(entry, uniontypes)` builds the extractor for it.
Presumably the first rule whose predicate holds is used; confirm against JsonGrinder's
`suggestextractor`.

Rules, in order:
1. numeric (or numeric-string) leaves with at most 100 distinct values -> categorical;
2. integer-like leaves -> `Int64` scalar;
3. float-like leaves -> `Float64` scalar;
4. non-numeric leaves whose distinct-value count is below 10% of `entry.updated` and
   below 10000 -> categorical;
5. anything else -> fixed `ExtractScalar(Float64, 0., 1., false)`.
"""
function default_scalar_extractor()
    # One constructor shared by both categorical rules (1 and 4).
    to_categorical = (e, uniontypes) -> ExtractCategorical(keys(e), uniontypes)

    few_numeric_values = e -> length(keys(e)) <= 100 &&
        JsonGrinder.is_numeric_or_numeric_string(e)
    integer_like = e -> JsonGrinder.is_intable(e)
    float_like = e -> JsonGrinder.is_floatable(e)
    low_cardinality_nonnumeric = function (e)
        n_unique = length(keys(e))
        # `e.updated` is how many times the schema entry was seen.
        return n_unique / e.updated < 0.1 && n_unique < 10000 &&
            !JsonGrinder.is_numeric_or_numeric_string(e)
    end
    anything = e -> true

    return [
        (few_numeric_values, to_categorical),
        (integer_like, (e, uniontypes) -> extractscalar(Int64, e, uniontypes)),
        (float_like, (e, uniontypes) -> extractscalar(Float64, e, uniontypes)),
        (low_cardinality_nonnumeric, to_categorical),
        # NOTE(review): unlike every other rule this ignores `uniontypes` and always
        # passes `false` — confirm that is intended.
        (anything, (e, uniontypes) -> ExtractScalar(Float64, 0., 1., false)),
    ]
end

default_scalar_extractor (generic function with 1 method)

# Extract data into the Mill.jl (multiple-instance learning) data format

In [3]:
# Exploratory pass over the training split: infer a schema from the raw JSON samples,
# derive an extractor from it using the custom scalar rules defined above, and
# concatenate all extracted samples into a single Mill node (displayed below).
# The same steps are repeated in In [6], together with the label shift and the test split.
x_train, y_train = MLDatasets.Mutagenesis.traindata();
sch = JsonGrinder.schema(x_train)
extractor = suggestextractor(sch, (; scalar_extractors = default_scalar_extractor()))
ds_train = Mill.catobs(extractor.(x_train))

[34mProductNode[39m[90m 	# 100 obs, 104 bytes[39m
[34m  ├─── lumo: [39m[39mArrayNode(99×100 OneHotArray with Bool elements)[90m 	# 100 obs, 456  [39m[90m⋯[39m
[34m  ├─── inda: [39m[39mArrayNode(2×100 OneHotArray with Bool elements)[90m 	# 100 obs, 456 b [39m[90m⋯[39m
[34m  ├─── logp: [39m[39mArrayNode(63×100 OneHotArray with Bool elements)[90m 	# 100 obs, 456  [39m[90m⋯[39m
[34m  ├─── ind1: [39m[39mArrayNode(3×100 OneHotArray with Bool elements)[90m 	# 100 obs, 456 b [39m[90m⋯[39m
[34m  ╰── atoms: [39m[31mBagNode[39m[90m 	# 100 obs, 1.680 KiB[39m
[34m             [39m[31m  ╰── [39m[32mProductNode[39m[90m 	# 2529 obs, 64 bytes[39m
[34m             [39m[31m      [39m[32m  ┊[39m

In [4]:
# Pretty-print the first raw training sample as JSON with 2-space indentation.
JSON.print(x_train[1],2)

{
  "lumo": -1.246,
  "inda": 0,
  "logp": 4.23,
  "ind1": 1,
  "atoms": [
    {
      "element": "c",
      "bonds": [
        {
          "element": "c",
          "bond_type": 7,
          "charge": -0.117,
          "atom_type": 22
        },
        {
          "element": "h",
          "bond_type": 1,
          "charge": 0.142,
          "atom_type": 3
        },
        {
          "element": "c",
          "bond_type": 7,
          "charge": -0.117,
          "atom_type": 22
        }
      ],
      "charge": -0.117,
      "atom_type": 22
    },
    {
      "element": "h",
      "bonds": [
        {
          "element": "c",
          "bond_type": 1,
          "charge": -0.117,
          "atom_type": 22
        }
      ],
      "charge": 0.142,
      "atom_type": 3
    },
    {
      "element": "c",
      "bonds": [
        {
          "element": "c",
          "bond_type": 7,
          "charge": 0.013,
          "atom_type": 27
        },
        {
          "element": "c",
  

In [5]:
# Show the inferred schema tree; `htrunc`/`vtrunc` presumably cap the printed width and
# number of rows (HierarchicalUtils keywords).
printtree(sch, htrunc=25, vtrunc=25)

[34m[Dict][39m[90m 	# updated = 100[39m
[34m  ├─── lumo: [39m[39m[Scalar - Float64], 98 unique values[90m 	# updated = 100[39m
[34m  ├─── inda: [39m[39m[Scalar - Int64], 1 unique values[90m 	# updated = 100[39m
[34m  ├─── logp: [39m[39m[Scalar - Float64,Int64], 62 unique values[90m 	# updated = 100[39m
[34m  ├─── ind1: [39m[39m[Scalar - Int64], 2 unique values[90m 	# updated = 100[39m
[34m  ╰── atoms: [39m[31m[List][39m[90m 	# updated = 100[39m
[34m             [39m[31m  ╰── [39m[32m[Dict][39m[90m 	# updated = 2529[39m
[34m             [39m[31m      [39m[32m  ├──── element: [39m[39m[Scalar - String], 6 unique values[90m 	# updated = 2529[39m
[34m             [39m[31m      [39m[32m  ├────── bonds: [39m[33m[List][39m[90m 	# updated = 2529[39m
[34m             [39m[31m      [39m[32m  │              [39m[33m  ╰── [39m[36m[Dict][39m[90m 	# updated = 5402[39m
[34m             [39m[31m      [39m[32m  │              [39

In [6]:
# Final data pipeline: load both splits, shift labels by one (presumably 0/1 -> 1/2) so
# they line up with the 1-based class indices `predict` returns, infer the schema and
# extractor from the training split only, and reuse that extractor for the test split.
x_train, y_train = MLDatasets.Mutagenesis.traindata();
y_train .= y_train .+ 1;
sch = JsonGrinder.schema(x_train)
extractor = suggestextractor(sch, (scalar_extractors = default_scalar_extractor(),))
ds_train = Mill.catobs(map(extractor, x_train))

x_test, y_test = MLDatasets.Mutagenesis.testdata();
y_test .= y_test .+ 1;
ds_test = Mill.catobs(map(extractor, x_test));

# Define training utility functions

In [7]:
"""
    train!(m, x, y; niter=1000, opt=ADAM(0.1), cb=iter->())

Fit `m` in place by taking `niter` gradient steps on `SumProductSet.sl_loss(m, x, y)`
with optimiser `opt`, updating the parameters collected by `Flux.params(m)`.

`cb` is called as `cb(0)` before the first step and as `cb(i)` after step `i`.
Returns `nothing`.
"""
function train!(m, x, y; niter::Int=1000, opt=ADAM(0.1), cb=iter->())
    θ = Flux.params(m)
    loss() = SumProductSet.sl_loss(m, x, y)

    cb(0)
    for step in 1:niter
        grads = gradient(loss, θ)
        Flux.Optimise.update!(opt, θ, grads)
        cb(step)
    end
    return nothing
end

train! (generic function with 1 method)

In [8]:
"""
    predict(x, model=m)

Return the predicted class index for every observation in `x`. For each column of
`logjnt(model, x)` (presumably classes × observations), this is the position of its
largest entry, computed with `Flux.onecold`.

`model` defaults to the global `m`, which is looked up at call time. So `predict(x)`
keeps using whatever model the notebook last bound to `m`.
"""
# No softmax: it is monotone, so it cannot change the argmax. Skipping it also avoids
# underflow-induced ties, and the NaN it can produce when every score is -Inf.
predict(x, model=m) = Flux.onecold(logjnt(model, x))

#28 (generic function with 1 method)

In [9]:
"""
    accuracy(y, x)

Fraction of observations in `x` whose predicted class (via `predict`) equals the
corresponding label in `y`.
"""
accuracy(y, x) = mean(predict(x) .== y)

"""
    status(iter, x_trn, y_trn, x_tst, y_tst)

Print one progress line with the training and test accuracy at iteration `iter`.
Returns `nothing`.
"""
function status(iter, x_trn, y_trn, x_tst, y_tst)
    train_acc = accuracy(y_trn, x_trn)
    test_acc = accuracy(y_tst, x_tst)
    @printf("Epoch %i - acc: | %.3f  %.3f | \n", iter, train_acc, test_acc)
    return nothing
end

status (generic function with 1 method)

# Initialize a model whose structure mirrors the training data, then train it

In [10]:
# Seed the global RNG so the random draws below (and any random initialisation inside
# `reflectinmodel`) are reproducible.
Random.seed!(1234);
# Draw a random length-`d` probability vector from a symmetric Dirichlet whose every
# concentration parameter equals 10*d.
dir_rand = d->rand(Distributions.Dirichlet(d, 10*d))
# Leaf constructors handed to `reflectinmodel`; `d` is presumably the leaf dimension
# (number of categories / features) — confirm against SumProductSet.
f_cat = d->_Categorical(log.(dir_rand(d))) # categorical leaves, from the log of a Dirichlet draw (presumably log-probabilities)
f_cont = d->gmm(2, d)  # continuous leaves; presumably a 2-component Gaussian mixture

# Build a model mirroring the structure of one extracted sample; the `2` presumably sets
# the number of components of the root SumNode (one per class) — confirm.
m = reflectinmodel(ds_train[1], 2; f_cont=f_cont, f_cat=f_cat)
# Print train/test accuracy before training and after every iteration.
cb = i -> status(i, ds_train, y_train, ds_test, y_test)
train!(m, ds_train, y_train; niter=10, opt=ADAM(0.25), cb=cb)

Epoch 0 - acc: | 0.610  0.682 | 
Epoch 1 - acc: | 0.620  0.682 | 
Epoch 2 - acc: | 0.820  0.705 | 
Epoch 3 - acc: | 0.790  0.750 | 
Epoch 4 - acc: | 0.790  0.659 | 
Epoch 5 - acc: | 0.790  0.705 | 
Epoch 6 - acc: | 0.790  0.727 | 
Epoch 7 - acc: | 0.790  0.727 | 
Epoch 8 - acc: | 0.790  0.727 | 
Epoch 9 - acc: | 0.790  0.727 | 
Epoch 10 - acc: | 0.790  0.727 | 


In [11]:
# Total number of scalar parameters: sum of `length` over every array in `Flux.params(m)`.
sum(length, Flux.params(m))

516

In [13]:
# Show the structure of the trained model, truncated like the schema printout above.
printtree(m, htrunc=25, vtrunc=25)

[34mSumNode[39m
[34m  ├── [39m[31mProductNode[39m
[34m  │   [39m[31m  ├─── lumo: [39m[39m_Categorical
[34m  │   [39m[31m  ├─── inda: [39m[39m_Categorical
[34m  │   [39m[31m  ├─── logp: [39m[39m_Categorical
[34m  │   [39m[31m  ├─── ind1: [39m[39m_Categorical
[34m  │   [39m[31m  ╰── atoms: [39m[32mSetNode[39m
[34m  │   [39m[31m             [39m[32m  ├── c: [39m[39m_Poisson
[34m  │   [39m[31m             [39m[32m  ╰── f: [39m[33mProductNode[39m
[34m  │   [39m[31m             [39m[32m         [39m[33m  ├──── element: [39m[39m_Categorical
[34m  │   [39m[31m             [39m[32m         [39m[33m  ├────── bonds: [39m[36mSetNode[39m
[34m  │   [39m[31m             [39m[32m         [39m[33m  │              [39m[36m  ├── c: [39m[39m_Poisson
[34m  │   [39m[31m             [39m[32m         [39m[33m  │              [39m[36m  ╰── f: [39m[35mProductNode[39m
[34m  │   [39m[31m             [39m[32m         [39