In [None]:
using DrWatson
@quickactivate 

using Flux, Graphs, Statistics, Plots, GraphPlot, Zygote
using Cairo, Compose
using SumProductSet
import Mill

# Unsupervised node clustering on Zachary's karate-club graph

In [None]:
# Zachary's karate-club graph shipped with Graphs.jl.
g = smallgraph(:karate)
vs = collect(vertices(g))
# Node features: one one-hot column per vertex (an identity encoding).
x = Flux.onehotbatch(vs, vs)
# One bag per vertex listing the indices of its neighbours, i.e. presumably each bag
# groups the feature columns of that vertex's neighbourhood. Built with the public
# `neighbors` accessor instead of reaching into the graph's internal adjacency field.
ds_train = Mill.BagNode(x, [neighbors(g, v) for v in vs])

In [None]:
# Random probability vector of length `d`: `d` uniform draws normalised to sum to one.
# Defined as a function so every categorical leaf gets its own random initialisation
# (the previous version evaluated `rand(d)` once with an undefined global `d`).
dir_rand(d) = (r = rand(d); r ./ sum(r))

# Factory for the categorical leaves used by `reflectinmodel`. The log of the
# probabilities is passed — presumably SumProductSet's `Categorical` is parameterised
# by log-probabilities, as `Poisson(log(4))` below suggests; confirm.
f_cat = d -> Categorical(log.(dir_rand(d)))

# Build a 2-component model whose structure mirrors the first observation of
# `ds_train` (assumed to serve as a schema template — confirm).
m = reflectinmodel(ds_train[1], 2; f_card=()->Poisson(log(4)), f_cat=f_cat)
m

In [None]:
"""
    train!(m, x; niter=1000, opt=ADAM(1.))

Fit the parameters of `m` to the data `x` in place by taking `niter` steps of the
optimiser `opt` on the objective `SumProductSet.ul_loss(m, x)`, using Flux's
implicit-parameter (`Flux.params`) API. Returns `nothing`.
"""
function train!(m, x; niter::Int=1000, opt=ADAM(1.))
    θ = Flux.params(m)
    objective() = SumProductSet.ul_loss(m, x)
    for _ in 1:niter
        grads = gradient(objective, θ)
        Flux.Optimise.update!(opt, θ, grads)
    end
    return nothing
end

In [None]:
# Fill colour per cluster index: cluster 1 is yellow, cluster 2 is red.
colors = [colorant"yellow", colorant"red"]

# Label each node with its vertex index.
nodelabel = 1:nv(g)

# Layout callback for `gplot`: it ignores its argument and always lays out `g` with
# the same integer second argument to `spring_layout` (presumably an RNG seed in
# GraphPlot), so the plots below should share node positions — confirm.
layout = let seed = 3
    _ -> spring_layout(g, seed)
end

In [None]:
train!(m, ds_train; opt=ADAM(1.), niter=200)

"""
    predict(model, x)

Return a vector of cluster indices, one per column of `logjnt(model, x)`, by taking
the argmax down each column (rows presumably index mixture components and columns
observations — confirm against SumProductSet's `logjnt`).
"""
predict(model, x) = [argmax(col) for col in eachcol(logjnt(model, x))]
# One-argument form kept for the cells below; it uses whatever the global `m` is
# bound to at call time.
predict(x) = predict(m, x)

clusters = predict(ds_train);
gplot(g, nodefillc=colors[clusters], layout=layout, nodelabel=nodelabel)

In [None]:
# Second experiment: a fresh 2-component model with a `Poisson(log(3))` cardinality
# leaf, trained for 200 steps with learning rate 0.1, then re-plotted.
m = reflectinmodel(ds_train[1], 2;
                   f_card = () -> Poisson(log(3)),
                   f_cat = f_cat)
train!(m, ds_train; niter=200, opt=ADAM(0.1))
clusters = predict(ds_train)
gplot(g; nodefillc=colors[clusters], layout=layout, nodelabel=nodelabel)

In [None]:
# draw(PDF("karate_presentation_clustered.pdf", 16cm, 9cm), gplot(g,nodefillc=colors[clusters], layout=layout, nodelabel=nodelabel))