In [1]:
using Flux,Random,Optimisers,Statistics

In [2]:
Threads.nthreads()

8

In [13]:
struct HJBEquation
    dim::Int64
    total_time::Float32
    Ndis::Int64
    delta_t::Float32
    sigma::Float32
    x_init::Vector{Float32}
    f::Function
    g::Function
end

function HJBEquation(eqn_config::Dict,f,g)
    HJBEquation(eqn_config["dim"],
    eqn_config["total_time"],
    eqn_config["Ndis"],
    eqn_config["total_time"]/eqn_config["Ndis"],
    sqrt(2),
    zeros32(eqn_config["dim"]),
    f,
    g)
end

function sample(eqn::HJBEquation,num_sample)
    dw_sample=randn((eqn.dim,num_sample,eqn.Ndis))*eqn.delta_t^0.5
    x_sample = zeros((eqn.dim,num_sample, eqn.Ndis + 1))
    x_sample[begin:end,begin:end,1]=ones((eqn.dim,num_sample)).*eqn.x_init
    for i in 1:eqn.Ndis-1
        x_sample[begin:end, begin:end, i + 1]=x_sample[begin:end, begin:end, i] .+ eqn.sigma .* dw_sample[begin:end, begin:end, i]
    end
    return dw_sample, x_sample
end

function f_t(t,x,y,z)
    return sum(z.^2,dims=1)./2
end

function g_T(t,x)
    return log.((1.0 .+ sum(x.^2,dims=1) ./ 2.0))
end

g_T (generic function with 1 method)

In [4]:
mutable struct FF_subnet
    model
end

function FF_subnet(eqn::HJBEquation)
    modelNN=Chain(
    BatchNorm(eqn.dim, initβ=zeros32, initγ=ones32,ϵ=1e-6, momentum= 0.99),
    Dense(eqn.dim=>eqn.dim+10,bias=false,init=rand32),
    BatchNorm(eqn.dim+10, initβ=zeros32, initγ=ones32,ϵ=1e-6, momentum= 0.99),
    relu,
    Dense(eqn.dim+10=>eqn.dim+10,bias=false,init=rand32), 
    BatchNorm(eqn.dim+10, initβ=zeros32, initγ=ones32,ϵ=1e-6, momentum= 0.99),  
    relu,
    Dense(eqn.dim+10=>eqn.dim,bias=false,init=rand32), 
    BatchNorm(eqn.dim, initβ=zeros32, initγ=ones32,ϵ=1e-6, momentum= 0.99),  
    )
    FF_subnet(modelNN)
end

(subnet::FF_subnet)(x) =subnet.model(x)
Flux.@functor FF_subnet

mutable struct GlobalModel
    eqn::HJBEquation
    subnets
    y_init
    z_init
    times
end

mutable struct trainableVariable
    arr
end

Flux.@functor trainableVariable (arr,)

function GlobalModel(eqn::HJBEquation)
    subnets=[FF_subnet(eqn) for _ in 1:eqn.Ndis]
    for i in 1:length(subnets)
        Flux.trainmode!(subnets[i].model)
    end
    y_init=trainableVariable(rand(Float32,(1)))
    z_init=trainableVariable((rand(Float32,(eqn.dim,1)).*0.2).-0.1)
    times=range(0, stop = eqn.Ndis*eqn.delta_t, length = eqn.Ndis) |> collect
    GlobalModel(eqn,subnets,y_init,z_init,times)
end

function testMode!(glob::GlobalModel)
    for i in 1:length(glob.subnets)
        Flux.testmode!(glob.subnets[i].model)
    end
end

function trainMode!(glob::GlobalModel)
    for i in 1:length(glob.subnets)
        Flux.trainmode!(glob.subnets[i].model)
    end
end

function call_train(glob::GlobalModel,inputs)
    dw, x = inputs
    Nmuestras=size(dw)[2]
    y=transpose((repeat(glob.y_init.arr,Nmuestras)))
    z=ones((glob.eqn.dim,Nmuestras)).*glob.z_init.arr
    #println(size(z))

    for i in 1:glob.eqn.Ndis-1
        y=y.-glob.eqn.delta_t.*(glob.eqn.f(glob.times[i],x[begin:end,begin:end,i],y,z)).+sum(z.*dw[:, :, i],dims=1)
        z=glob.subnets[i](x[begin:end,begin:end, i + 1])./(glob.eqn.dim)
    end  
    y=y.-glob.eqn.delta_t*(glob.eqn.f(glob.times[end],x[begin:end,begin:end,end-1],y,z)).+sum(z.*dw[:, :, end],dims=1)
    return y
end

(glob::GlobalModel)(inputs) =call_train(glob,inputs)
Flux.@functor GlobalModel

In [5]:
y_init=trainableVariable(rand(Float64,(1)))

trainableVariable([0.880815858164645])

In [6]:
Flux.params(y_init)

Params([[0.880815858164645]])

In [7]:
DELTA_CLIP=50.0
function loss(glob,y_terminal,inputs)
    dw,x=inputs
    #y_terminal=glob(inputs)
    delta = y_terminal .- glob.eqn.g(glob.eqn.total_time, x[begin:end,begin:end, end])
    return mean(ifelse.(delta .< DELTA_CLIP, delta.^2, 2 * DELTA_CLIP * abs.(delta) .- DELTA_CLIP^2))
end

function loss_test(glob,y_terminal,inputs)
    dw,x=inputs
    testmode!(glob)
    #y_terminal=call_test(glob,inputs)
    delta = y_terminal .- glob.eqn.g(glob.eqn.total_time, x[begin:end,begin:end, end])
    trainmode!(glob)
    return mean(ifelse.(delta .< DELTA_CLIP, delta.^2, 2 * DELTA_CLIP * abs.(delta) .- DELTA_CLIP^2))
end

function solve_deepBSDE(glob::GlobalModel)
    #opt = Optimisers.Adam(0.01)
    #opt_state = Optimisers.setup(opt, glob)
    optim = Flux.setup(Flux.Adam(0.01), glob)
    
    for epoch in 1:2000
        losses = Float32[]
        input=sample(glob.eqn,64)
        
        grad = Flux.gradient(glob) do m
        result = m(input)
        loss(glob,result, input)
        end
        #val, grads = Flux.withgradient(glob.subnets) do m
          # Any code inside here is differentiated.
          # Evaluation of the model and loss must be inside!
        #result = glob(input)
        #loss(glob, input)
        #end

        if epoch%100==0
            input=sample(glob.eqn,254)
            println("Epoch ",epoch, " losses ",loss_test(glob,glob(input),input), " ysol ", glob.y_init.arr)
        end

        Flux.update!(optim, glob, grad[1])
        #state, model = Optimisers.update(opt_state, glob, grads[1])
    end
    
end

solve_deepBSDE (generic function with 1 method)

In [8]:
eqn=HJBEquation(Dict("dim"=>100,"total_time"=>1.0,"Ndis"=>20),f_t,g_T)

HJBEquation(100, 1.0f0, 20, 0.05f0, 1.4142135f0, Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], f_t, g_T)

In [9]:
glob=GlobalModel(eqn);

In [14]:
@time solve_deepBSDE(glob)

Epoch 100 losses 0.002656905927631626 ysol Float32[0.0031304776]
Epoch 200 losses 0.0026959720447918057 ysol Float32[0.00649954]
Epoch 300 losses 0.0026251476063191073 ysol Float32[0.004626618]
Epoch 400 losses 0.002188021751344348 ysol Float32[0.007943555]
Epoch 500 losses 0.0013781577292144892 ysol Float32[0.0028352172]
Epoch 600 losses 0.0013344294357706136 ysol Float32[0.004278153]
Epoch 700 losses 0.000954004813989094 ysol Float32[0.0010694381]
Epoch 800 losses 0.0011014290347316534 ysol Float32[0.0014601664]
Epoch 900 losses 0.0005692993337299386 ysol Float32[-0.0022162013]
Epoch 1000 losses 0.000586309260443065 ysol Float32[0.0005684359]
Epoch 1100 losses 0.00028486178572682466 ysol Float32[0.0009460431]
Epoch 1200 losses 0.00023481761281217642 ysol Float32[-0.00090554805]
Epoch 1300 losses 0.00015093630504002816 ysol Float32[0.00015742563]
Epoch 1400 losses 0.00010842360808944489 ysol Float32[-0.00010397883]
Epoch 1500 losses 0.00010666544947496423 ysol Float32[0.0009546122]
Ep

In [11]:
size(glob(sample(glob.eqn,78)))

(1, 78)

In [12]:
Flux.params(glob)[1]

100-element Vector{Float32}:
  0.0010579661
  0.0006942776
  0.00045408803
 -0.0010029796
 -0.0023774996
  0.0006423336
  7.4286127f-6
 -4.1688258f-5
 -0.0020096388
 -0.0030830482
 -0.0014785125
 -0.00071486476
  0.00023742882
  ⋮
 -0.00083389273
  0.0005112303
 -0.0009998005
 -0.0009897274
  0.0015123197
 -0.0015001412
  0.00086951075
  0.0010303209
 -0.00016403117
  4.4514305f-5
  0.00082801876
 -0.0020559207