In [66]:
using Flux,Random

In [None]:
"""
    add_dim(x::AbstractArray)

Return `x` reshaped so that it has one extra trailing dimension of length 1,
e.g. a length-`n` vector becomes an `n×1` matrix. The result is produced by
`reshape`, so no element is changed.
"""
add_dim(x::AbstractArray) = reshape(x, (size(x)..., 1))

In [182]:
"""
    HJBEquation(eqn_config)

Problem description built from a configuration dictionary with the keys
`"dim"`, `"total_time"` and `"Ndis"`.

# Fields
- `dim`: value of `eqn_config["dim"]`.
- `total_time`: value of `eqn_config["total_time"]`.
- `Ndis`: value of `eqn_config["Ndis"]` (number of time steps).
- `delta_t`: `total_time / Ndis`.
- `sigma`: fixed to `sqrt(2)`.
- `x_init`: zero vector of length `dim`.
"""
struct HJBEquation
    dim::Int64
    total_time::Float64
    Ndis::Int64
    delta_t::Float64
    sigma::Float64
    x_init::Vector{Float64}

    function HJBEquation(eqn_config)
        ndims_state = eqn_config["dim"]
        horizon = eqn_config["total_time"]
        nsteps = eqn_config["Ndis"]
        step = horizon / nsteps
        start = zeros(ndims_state)
        return new(ndims_state, horizon, nsteps, step, sqrt(2), start)
    end
end

"""
    sample(eqn::HJBEquation, num_sample; rng=Random.default_rng())

Draw `num_sample` discretised paths for `eqn`.

# Arguments
- `eqn`: problem description providing `dim`, `Ndis`, `delta_t`, `sigma`, `x_init`.
- `num_sample`: number of paths to generate.

# Keywords
- `rng`: random number generator used for the normal draws.

# Returns
A tuple `(dw_sample, x_sample)`:
- `dw_sample`: `num_sample × dim × Ndis` array of normal draws scaled by
  `sqrt(delta_t)`.
- `x_sample`: `num_sample × dim × (Ndis + 1)` array; slice 1 holds `x_init`
  for every sample and slice `i + 1` equals slice `i` plus
  `sigma * dw_sample[:, :, i]`.
"""
function sample(eqn::HJBEquation, num_sample; rng::AbstractRNG=Random.default_rng())
    dw_sample = randn(rng, Float64, num_sample, eqn.dim, eqn.Ndis) .* sqrt(eqn.delta_t)
    x_sample = zeros(num_sample, eqn.dim, eqn.Ndis + 1)

    # Every path starts from the same initial state (broadcast over the sample axis).
    x_sample[:, :, 1] .= transpose(eqn.x_init)

    # There are Ndis increments and Ndis + 1 states, so the loop must cover all
    # Ndis increments; stopping at Ndis - 1 would leave the last state at zero.
    @views for i in 1:eqn.Ndis
        x_sample[:, :, i + 1] .= x_sample[:, :, i] .+ eqn.sigma .* dw_sample[:, :, i]
    end
    return dw_sample, x_sample
end

"""
    f_t(t, x, y, z)

Return half of the sum of squares of `z` taken along its second dimension
(the reduced dimension is kept with length 1). The arguments `t`, `x` and `y`
are accepted but not used.
"""
function f_t(t, x, y, z)
    squared = z .^ 2
    rowsums = sum(squared; dims=2)
    return rowsums ./ 2
end

"""
    g_T(t, x)

Return `log(1 + s / 2)` where `s` is the sum of squares of all elements of
`x`. The result is a single scalar; `t` is accepted but not used.
"""
function g_T(t, x)
    sq_total = sum(x .^ 2)
    # NOTE(review): the division by 2.0 applies only to the sum of squares, not
    # to `1.0 + sum`, and the sum runs over every element of `x` rather than
    # per row -- confirm both match the intended terminal condition.
    return log(1.0 + sq_total / 2.0)
end

g_T (generic function with 1 method)

In [193]:
"""
    FF_subnet(eqn)

Feed-forward sub-network wrapping a Flux `Chain` stored in `model`.

The chain maps `eqn.dim` inputs to `eqn.dim` outputs through two hidden
layers of width `eqn.dim + 10`. All `Dense` layers are bias-free, a
`BatchNorm` is applied to the input and after every `Dense` layer, and `relu`
follows the batch norm of each hidden layer.

Calling an instance, `subnet(x)`, forwards `x` to `subnet.model`.
"""
mutable struct FF_subnet
    model

    function FF_subnet(eqn)
        width = eqn.dim + 10
        # All batch-norm layers share the same settings.
        # NOTE(review): Flux's `momentum` appears to weight the *new* batch
        # statistic (library default 0.1); confirm 0.99 is the intended value.
        bn(n) = BatchNorm(n, initβ=zeros32, initγ=ones32, ϵ=1f-6, momentum=0.99f0)
        layers = (
            bn(eqn.dim),
            Dense(eqn.dim => width, bias=false),
            bn(width),
            relu,
            Dense(width => width, bias=false),
            bn(width),
            relu,
            Dense(width => eqn.dim, bias=false),
            bn(eqn.dim),
        )
        return new(Chain(layers...))
    end
end

# Make a sub-network callable like the chain it wraps.
(subnet::FF_subnet)(x) = subnet.model(x)

"""
    GlobalModel(eqn)

Container bundling the problem `eqn` with its trainable pieces.

# Fields
- `eqn`: the `HJBEquation` being solved.
- `subnets`: one `FF_subnet` per time step (`eqn.Ndis` in total).
- `y_init`: length-1 vector drawn uniformly from `[0, 1)`.
- `z_init`: `1 × eqn.dim` matrix drawn uniformly from `[-0.1, 0.1)`.
"""
mutable struct GlobalModel
    eqn::HJBEquation
    subnets::Array{FF_subnet}
    y_init
    z_init

    function GlobalModel(eqn)
        # Construction order (sub-networks, then y, then z) is kept so the
        # sequence of random draws is unchanged.
        nets = map(_ -> FF_subnet(eqn), 1:eqn.Ndis)
        y0 = rand(Float64, 1)
        z0 = 0.2 .* rand(Float64, 1, eqn.dim) .- 0.1
        return new(eqn, nets, y0, z0)
    end
end

"""
    call_train(glob::GlobalModel, inputs)

Run the forward pass of `glob` in training mode and return the final `y`.

# Arguments
- `glob`: model holding the equation, the per-step sub-networks and the
  initial values `y_init` / `z_init`.
- `inputs`: tuple `(dw, x)` as returned by `sample`; `dw` is indexed as
  `dw[:, :, i]` and `x` as `x[:, :, i]`, with the sample index first.

# Returns
An `nsamples × 1` matrix obtained by repeatedly applying
`y = y - delta_t * f_t(t, x_i, y, z) + sum(z .* dw_i, dims=2)`, where `z`
starts from `glob.z_init` and is afterwards produced by the sub-network of the
current step applied to the next state, divided by `eqn.dim`.
"""
function call_train(glob::GlobalModel, inputs)
    dw, x = inputs
    eqn = glob.eqn
    nsamples = size(dw, 1)

    # Left end point of every time step: t_i = (i - 1) * delta_t.
    times = (0:eqn.Ndis-1) .* eqn.delta_t

    # Start from the model's own initial values (not from same-named globals),
    # replicated once per sample.
    y = add_dim(repeat(glob.y_init, nsamples))
    z = repeat(glob.z_init, nsamples)

    for subnet in glob.subnets
        Flux.trainmode!(subnet.model)
    end

    for i in 1:eqn.Ndis-1
        y = y - eqn.delta_t * f_t(times[i], x[:, :, i], y, z) +
            sum(z .* dw[:, :, i]; dims=2)
        # The sub-network expects features along the first axis, hence the
        # transposes around the call.
        z = transpose(glob.subnets[i](transpose(x[:, :, i + 1])) ./ eqn.dim)
    end

    # Last step: uses the second-to-last state and the last increment.
    y = y - eqn.delta_t * f_t(times[end], x[:, :, end - 1], y, z) +
        sum(z .* dw[:, :, end]; dims=2)
    return y
end

call_train (generic function with 1 method)

In [184]:
# Problem instance: 10 dimensions, total time 1.0, 20 time steps.
eqn = HJBEquation(Dict("dim" => 10, "total_time" => 1.0, "Ndis" => 20))

HJBEquation(10, 1.0, 20, 0.05, 1.4142135623730951, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

In [185]:
# Draw 15 sample paths: `dw` holds the increments, `x` the states.
dw, x = sample(eqn, 15);

In [186]:
# Stand-alone initial values for experimenting in the cells below:
# a length-1 vector in [0, 1) and a 1×10 matrix in [-0.1, 0.1).
y_init = rand(Float64, 1)
z_init = 0.2 .* rand(Float64, 1, 10) .- 0.1

1×10 Matrix{Float64}:
 -0.099427  0.0539794  0.0566823  …  0.0223384  0.0834763  0.0834237

In [191]:
# Build the full model (sub-networks plus initial values) for `eqn`.
glob = GlobalModel(eqn);

In [194]:
# Forward pass of the model on a fresh batch of 15 sample paths.
call_train(glob, sample(eqn, 15))

15×1 Matrix{Float64}:
  0.21321425920943265
  0.017905480759636647
  0.22268700782951975
  0.037368202313630694
 -0.1641135427119533
  0.3012551539681844
 -0.1492950304924039
  0.08637380649229044
 -0.04582753569480057
 -0.039283715796431105
  0.8805039638920404
  0.35884436493633987
 -0.11358801106699018
  0.28595709936582137
 -0.24026947132979007

In [166]:
# A single sub-network, built on its own for inspection.
subnet = FF_subnet(eqn)

FF_subnet(Chain(BatchNorm(10), Dense(10 => 20; bias=false), BatchNorm(20), relu, Dense(20 => 20; bias=false), BatchNorm(20), relu, Dense(20 => 10; bias=false), BatchNorm(10)))

In [178]:
# Apply the sub-network to the second time slice; the transpose puts the
# state components along the first axis.
subnet(transpose(x[:, :, 2]))

10×15 Matrix{Float32}:
  0.181832    0.0500037  -0.132855   …  -0.0716951  -0.184757   -0.00401488
  0.0212535   0.0551969  -0.0405443      0.0101977   0.0648386   0.072182
 -0.104506   -0.0723708  -0.0291431     -0.29108     0.0278465  -0.238214
 -0.0428329  -0.103686   -0.126515       0.0270944  -0.136132   -0.179157
  0.0377601  -0.0196186  -0.18984       -0.167015   -0.0127032  -0.0356404
  0.14793    -0.0138725   0.228474   …   0.230346    0.131342    0.17053
 -0.0491887  -0.022917   -0.130278      -0.10069    -0.0374406  -0.100968
 -0.138075    0.0295732   0.136468       0.0196258   0.0787449   0.148819
 -0.106023   -0.176836   -0.435257      -0.327775   -0.383141   -0.5258
  0.240044    0.200077    0.206523       0.392456    0.071489    0.406051

In [158]:
# Evaluate f_t at time 0.0 with the initial values replicated for 15 samples.
# `x` (the sampled paths) is passed instead of `dx`, which is not defined
# anywhere in this notebook; f_t does not use its `x` argument, so the value
# is unaffected.
f_t(0.0, x[:, :, 1], add_dim(repeat(y_init, 15)), repeat(z_init, 15))

15×1 Matrix{Float64}:
 0.015958857505341734
 0.015958857505341734
 0.015958857505341734
 0.015958857505341734
 0.015958857505341734
 0.015958857505341734
 0.015958857505341734
 0.015958857505341734
 0.015958857505341734
 0.015958857505341734
 0.015958857505341734
 0.015958857505341734
 0.015958857505341734
 0.015958857505341734
 0.015958857505341734

In [165]:
# Per-sample sum over the second dimension of z_init times the second increment slice.
sum(repeat(z_init, 15) .* dw[:, :, 2]; dims=2)

15×1 Matrix{Float64}:
 -0.10161579141708288
  0.051076615267751806
 -0.08119402852187693
 -0.005199119667968805
  0.007857195776234553
  0.032651651979344616
 -0.01042587294348046
  0.04205370719434498
 -0.051828471816426386
 -0.014349873970768794
 -0.029507062195669696
  0.08586444090986241
 -0.008235016565569922
  0.02341725066795785
 -0.002070654104354565

In [126]:
# Stack 15 copies of the 1×10 row `z_init` into a 15×10 matrix.
repeat(z_init, 15, 1)

15×10 Matrix{Float64}:
 -0.0788249  0.0613694  -0.0186422  …  -0.00266041  0.0774545  -0.0433818
 -0.0788249  0.0613694  -0.0186422     -0.00266041  0.0774545  -0.0433818
 -0.0788249  0.0613694  -0.0186422     -0.00266041  0.0774545  -0.0433818
 -0.0788249  0.0613694  -0.0186422     -0.00266041  0.0774545  -0.0433818
 -0.0788249  0.0613694  -0.0186422     -0.00266041  0.0774545  -0.0433818
 -0.0788249  0.0613694  -0.0186422  …  -0.00266041  0.0774545  -0.0433818
 -0.0788249  0.0613694  -0.0186422     -0.00266041  0.0774545  -0.0433818
 -0.0788249  0.0613694  -0.0186422     -0.00266041  0.0774545  -0.0433818
 -0.0788249  0.0613694  -0.0186422     -0.00266041  0.0774545  -0.0433818
 -0.0788249  0.0613694  -0.0186422     -0.00266041  0.0774545  -0.0433818
 -0.0788249  0.0613694  -0.0186422  …  -0.00266041  0.0774545  -0.0433818
 -0.0788249  0.0613694  -0.0186422     -0.00266041  0.0774545  -0.0433818
 -0.0788249  0.0613694  -0.0186422     -0.00266041  0.0774545  -0.0433818
 -0.0788249  0.

In [3]:
"""
    Person(name, age)

Immutable record with two fields, `name` and `age`.
"""
struct Person
    name::String  # stored as a `String`
    age::Int64    # stored as a 64-bit integer
end