In [None]:
using PyCall
using PyPlot
using ForwardDiff, Flux
using LinearAlgebra
using Distributions
cmd = pyimport("cmdtools")

In [2]:
# Double-well diffusion system from cmdtools, discretised on N grid points over
# [-1.5, 1.5] with beta = 3. Q is densified into a Julia matrix and K = exp(0.5 * Q);
# presumably Q is the generator (rate matrix), making K the transition matrix for
# lag time 0.5 — its rows are used as probability vectors by `sample` further down.
N = 13

sys = cmd.systems.diffusion.DoubleWell(nx = N, beta = 3, xlims = (-1.5, 1.5))
Q = sys.Q.todense()
K = exp(0.5 * Q)

LoadError: UndefVarError: cmd not defined

In [None]:
# NOTE(review): `'` turns a Vector into a 1×n row, and matplotlib draws each column of
# a 2-D array as its own series — confirm the orientation of sys.u matches the intent.
PyPlot.plot(sys.u')

In [None]:
# Heat map of the matrix K.
PyPlot.imshow(K)

In [None]:
# Boundary data for the committor: NaN marks interior (unknown) states,
# the second state is fixed to 0 and the second-to-last state to 1.
bnd = fill(NaN, size(K, 1))
bnd[2] = 0
bnd[end - 1] = 1

# Indices of the fixed (non-NaN) states.
bndinds = findall(!isnan, bnd)
bnd

In [None]:
"""
    committor(K, b)

Solve the linear system defining the committor of the chain with matrix `K`.

`b` carries the boundary data: an entry that is `NaN` marks an interior state whose
value must satisfy `q[i] = sum_j K[i, j] * q[j]`; every other entry pins `q[i] = b[i]`.
Neither argument is modified. Returns the solution vector `q`.
"""
function committor(K, b)
    A = copy(K)
    rhs = copy(b)
    for (i, bi) in enumerate(b)
        if !isnan(bi)
            # Boundary row: replace with the identity row; rhs keeps the prescribed value.
            A[i, :] .= 0
            A[i, i] = 1
            continue
        end
        # Interior row: (K - I) q = 0.
        rhs[i] = 0
        A[i, i] -= 1
    end
    return A \ rhs
end

# Reference committor for K with the boundary data in bnd; used to measure model error.
c = committor(K, bnd)

In [None]:
# Draw a successor of state `x` from row `x` of the (global) matrix K.
# The row is clipped at zero and renormalised first: a numerically computed matrix
# exponential can leave tiny negative round-off entries, which Categorical's
# probability-vector check rejects.
function sample(x)
    p = max.(K[x, :], 0)
    return rand(Categorical(p ./ sum(p)))
end

# Dataset of `nx` start states drawn uniformly from 1:N, each paired with a vector
# of `ny` successors drawn with `sample`.
function random_data(nx, ny)
    starts = [rand(1:N) for _ in 1:nx]
    return [(s, [sample(s) for _ in 1:ny]) for s in starts]
end

# Dataset covering every state 1:N once, each paired with `ny` sampled successors.
function linear_data(ny)
    return [(s, [sample(s) for _ in 1:ny]) for s in 1:N]
end

data = linear_data(100)

In [9]:
# Per-sample residual: absolute gap between the model's mean over the sampled
# successors `ys` and its value at the start state `x`.
function loss(c, x, ys)
    successor_mean = mean(c(y) for y in ys)
    return abs(successor_mean - c(x))
end

# Average of `loss` over a batch of `(x, ys)` tuples.
function loss_batch(c, batch)
    return mean(loss(c, item...) for item in batch)
end

loss_batch (generic function with 1 method)

In [10]:
"""
    VectorModel(c, bnd, bndinds)

Tabular committor model with one free value per state.

- `c`: trainable vector of values, indexed by state.
- `bnd`: boundary values, indexed by state (only entries at `bndinds` are read).
- `bndinds`: indices of the boundary states; for those the model returns the fixed
  value `bnd[x]` instead of the trainable `c[x]`.

Fields are type parameters (rather than untyped `Any`) so field accesses are
concretely typed.
"""
struct VectorModel{C,B,I}
    c::C
    bnd::B
    bndinds::I
end

# Evaluate at state index `x`: fixed value on boundary states, trainable entry otherwise.
function (m::VectorModel)(x)
    if x in m.bndinds
        return m.bnd[x]
    else
        return m.c[x]
    end
end

# Only the vector `c` is optimised; the boundary data stays fixed during training.
Flux.trainable(m::VectorModel) = (m.c,)

# Randomly initialised tabular model, evaluated once at state 5 as a smoke test.
model1 = VectorModel(rand(N), bnd, bndinds)
model1(5)

LoadError: UndefVarError: bnd not defined

In [None]:
"""
    NNModel(nn, bnd, bndinds)

Committor model backed by a neural network evaluated on the scalar state.

- `nn`: trainable network; it is called on the one-element vector `[x]` and the first
  entry of its output is the model value.
- `bnd`: boundary values, indexed by state (only entries at `bndinds` are read).
- `bndinds`: indices of the boundary states, where the fixed value is returned
  instead of the network output.

Fields are type parameters (rather than untyped `Any`) so field accesses are
concretely typed.
"""
struct NNModel{F,B,I}
    nn::F
    bnd::B
    bndinds::I
end

# Evaluate at state `x`. Boundary states return their fixed value (`Int(x)` also accepts
# an integral non-Int state such as 2.0). Other states go through the network, which
# works on arrays, so the scalar is wrapped in a vector and the result unwrapped.
function (m::NNModel)(x)
    if x in m.bndinds
        return m.bnd[Int(x)]
    else
        return m.nn([x])[1]
    end
end

# Only the network's parameters are optimised; the boundary data stays fixed.
Flux.trainable(m::NNModel) = (m.nn,)

# Small network model (1 -> 10 -> 1), evaluated once at state 5 as a smoke test.
# NOTE(review): no activation is passed to Dense, so with Flux's default (identity)
# this chain is an affine map of x — confirm that is intended.
model2 = NNModel(Chain(Dense(1, 10), Dense(10, 1)), bnd, bndinds)
model2(5)

In [None]:
# Root-mean-square deviation of a vector of per-state values from the reference
# committor `c` (a global): norm(x - c) / sqrt(length(x)).
error(x::AbstractVector) = norm(x - c) / sqrt(length(x))

# Same measure for a model `f`, evaluated on the states 1:N (global N).
error(f) = error([f(x) for x in 1:N])

# Defining `error` here creates a new function that shadows Base.error, and the
# untyped `error(f)` method would otherwise swallow ordinary `error("message")` calls
# (trying to call the string as a model). Forward message-style calls to Base.error
# so they still raise an ErrorException.
error(msg::AbstractString) = Base.error(msg)
error(parts...) = Base.error(parts...)

In [None]:
"""
    learn(model, nx=100, ny=100; lr=0.1)

Train `model` with ADAM (learning rate `lr`) by minimising `loss_batch`.

`nx` is the number of optimisation steps. Each step uses its own dataset from
`linear_data(ny)`, i.e. every state with `ny` sampled successors. After every step the
batch loss and `error(model)` are recorded.

Returns `(model, losses, errors)`. The model's trainable arrays are updated in place
by the optimiser.
"""
function learn(model, nx=100, ny=100; lr=0.1)
    errors = Float64[]
    losses = Float64[]

    opt = ADAM(lr)

    # All training data is drawn up front: nx datasets, one per optimisation step,
    # each with ny successor samples per state.
    data_batch = [linear_data(ny) for _ in 1:nx]

    # The Params collection references the model's arrays, which update! mutates in
    # place, so it only needs to be built once.
    ps = Flux.params(model)

    @time for d in data_batch
        l, pb = Flux.pullback(ps) do
            loss_batch(model, d)
        end
        grad = pb(1)
        Flux.Optimise.update!(opt, ps, grad)
        push!(errors, error(model))
        push!(losses, l)
    end

    # Guard against nx == 0, where losses[end] would throw a BoundsError.
    if isempty(losses)
        @warn "learn: no training steps were run (nx = $nx)"
    else
        @show losses[end], errors[end]
    end
    return model, losses, errors
end

# Train the tabular model from a random start; the commented line trains a small
# network model instead.
model, losses, errors = learn(VectorModel(rand(N), bnd, bndinds))
# model, losses, errors = learn(NNModel(Chain(Dense(1, 10), Dense(10, 10), Dense(10, 1)), bnd, bndinds))



In [None]:
# First figure: error against the reference committor per step, on a log scale.
PyPlot.plot(errors)
PyPlot.yscale(:log)

# Second figure: training loss per step (uncomment yscale for a log axis).
PyPlot.figure()
PyPlot.plot(losses)
# PyPlot.yscale(:log)

In [None]:
# Learned values per state next to the reference committor (columns: model, exact).
chat = map(model, 1:N)

[chat c]

In [None]:
# Loss of the trained model on a large fresh sample.
loss_batch(model, linear_data(100_000))

In [None]:
# Baseline: loss of the exact committor on a large fresh sample, for comparison with
# the trained model.
truth = VectorModel(c, bnd, bndinds)
loss_batch(truth, linear_data(100_000))