In [1]:
using Flux,Random,Optimisers,Statistics

In [2]:
Threads.nthreads()

8

In [3]:
struct HJBEquation
    dim::Int64
    total_time::Float64
    Ndis::Int64
    delta_t::Float64
    sigma::Float64
    x_init::Vector{Float64}
    f::Function
    g::Function
end

function HJBEquation(eqn_config::Dict,f,g)
    HJBEquation(eqn_config["dim"],
    eqn_config["total_time"],
    eqn_config["Ndis"],
    eqn_config["total_time"]/eqn_config["Ndis"],
    sqrt(2),
    zeros32(eqn_config["dim"]),
    f,
    g)
end

function sample(eqn::HJBEquation,num_sample)
    dw_sample=randn((eqn.dim,num_sample,eqn.Ndis))*eqn.delta_t^0.5
    x_sample = zeros((eqn.dim,num_sample, eqn.Ndis + 1))
    x_sample[begin:end,begin:end,1]=ones((eqn.dim,num_sample)).*eqn.x_init
    for i in 1:eqn.Ndis-1
       @views x_sample[begin:end, begin:end, i + 1]=x_sample[begin:end, begin:end, i] .+ eqn.sigma .* dw_sample[begin:end, begin:end, i]
    end
    return dw_sample, x_sample
end

function f_t(t,x,y,z)
    return sum(z.^2,dims=1)./2
end

function g_T(t,x)
    return log.((1.0 .+ sum(x.^2,dims=1) ./ 2.0))
end

g_T (generic function with 1 method)

In [4]:
mutable struct FF_subnet
    model
end

function FF_subnet(eqn::HJBEquation)
    modelNN=Chain(
    BatchNorm(eqn.dim, initβ=zeros, initγ=ones,ϵ=1e-6, momentum= 0.99),
    Dense(eqn.dim=>eqn.dim+10,bias=false,init=rand),
    BatchNorm(eqn.dim+10, initβ=zeros, initγ=ones,ϵ=1e-6, momentum= 0.99),
    relu,
    Dense(eqn.dim+10=>eqn.dim+10,bias=false,init=rand), 
    BatchNorm(eqn.dim+10, initβ=zeros, initγ=ones,ϵ=1e-6, momentum= 0.99),  
    relu,
    Dense(eqn.dim+10=>eqn.dim,bias=false,init=rand), 
    BatchNorm(eqn.dim, initβ=zeros, initγ=ones,ϵ=1e-6, momentum= 0.99),  
    )
    FF_subnet(modelNN)
end

(subnet::FF_subnet)(x) =subnet.model(x)
Flux.@functor FF_subnet

mutable struct GlobalModel
    eqn::HJBEquation
    subnets
    y_init
    z_init
    times
end

mutable struct trainableVariable
    arr
end

Flux.@functor trainableVariable (arr,)

function GlobalModel(eqn::HJBEquation)
    subnets=[FF_subnet(eqn) for _ in 1:eqn.Ndis]
    for i in 1:length(subnets)
        Flux.trainmode!(subnets[i].model)
    end
    y_init=trainableVariable(rand(Float64,(1)))
    z_init=trainableVariable((rand(Float64,(eqn.dim,1)).*0.2).-0.1)
    times=range(0, stop = eqn.Ndis*eqn.delta_t, length = eqn.Ndis) |> collect
    GlobalModel(eqn,subnets,y_init,z_init,times)
end

function testMode!(glob::GlobalModel)
    for i in 1:length(glob.subnets)
        Flux.testmode!(glob.subnets[i].model)
    end
end

function trainMode!(glob::GlobalModel)
    for i in 1:length(glob.subnets)
        Flux.trainmode!(glob.subnets[i].model)
    end
end

function call_train(glob::GlobalModel,inputs)
    dw, x = inputs
    Nmuestras=size(dw)[2]
    y=transpose((repeat(glob.y_init.arr,Nmuestras)))
    z=ones((glob.eqn.dim,Nmuestras)).*glob.z_init.arr
    #println(size(z))

    for i in 1:glob.eqn.Ndis-1
         y=@views y.-glob.eqn.delta_t.*(glob.eqn.f(glob.times[i],x[begin:end,begin:end,i],y,z)).+sum(z.*dw[:, :, i],dims=1)
         z=@views glob.subnets[i](x[begin:end,begin:end, i + 1])./(glob.eqn.dim)
    end  
    y=@views y.-glob.eqn.delta_t*(glob.eqn.f(glob.times[end],x[begin:end,begin:end,end-1],y,z)).+sum(z.*dw[:, :, end],dims=1)
    return y
end

(glob::GlobalModel)(inputs) =call_train(glob,inputs)
Flux.@functor GlobalModel

In [5]:
y_init=trainableVariable(rand(Float64,(1)))

trainableVariable([0.3162416203668966])

In [6]:
Flux.params(y_init)

Params([[0.3162416203668966]])

In [7]:
DELTA_CLIP=50.0
function loss(glob,y_terminal,inputs)
    dw,x=inputs
    #y_terminal=glob(inputs)
    delta = @views y_terminal .- glob.eqn.g(glob.eqn.total_time, x[begin:end,begin:end, end])
    return mean(ifelse.(delta .< DELTA_CLIP, delta.^2, 2 * DELTA_CLIP * abs.(delta) .- DELTA_CLIP^2))
end

function loss_test(glob,y_terminal,inputs)
    dw,x=inputs
    testmode!(glob)
    #y_terminal=call_test(glob,inputs)
    delta = @views y_terminal .- glob.eqn.g(glob.eqn.total_time, x[begin:end,begin:end, end])
    trainmode!(glob)
    return mean(ifelse.(delta .< DELTA_CLIP, delta.^2, 2 * DELTA_CLIP * abs.(delta) .- DELTA_CLIP^2))
end

function solve_deepBSDE(glob::GlobalModel)
    #opt = Optimisers.Adam(0.01)
    #opt_state = Optimisers.setup(opt, glob)
    optim = Flux.setup(Flux.Adam(0.01), glob)
    
    for epoch in 1:2000
        losses = Float32[]
        input=sample(glob.eqn,64)
        
        grad = Flux.gradient(glob) do m
        result = m(input)
        loss(glob,result, input)
        end
        #val, grads = Flux.withgradient(glob.subnets) do m
          # Any code inside here is differentiated.
          # Evaluation of the model and loss must be inside!
        #result = glob(input)
        #loss(glob, input)
        #end

        if epoch%100==0
            input=sample(glob.eqn,254)
            println("Epoch ",epoch, " losses ",loss_test(glob,glob(input),input), " ysol ", glob.y_init.arr)
        end

        Flux.update!(optim, glob, grad[1])
        #state, model = Optimisers.update(opt_state, glob, grads[1])
    end
    
end

solve_deepBSDE (generic function with 1 method)

In [8]:
eqn=HJBEquation(Dict("dim"=>100,"total_time"=>1.0,"Ndis"=>20),f_t,g_T)

HJBEquation(100, 1.0, 20, 0.05, 1.4142135623730951, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0  …  0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], f_t, g_T)

In [9]:
glob=GlobalModel(eqn);

In [13]:
@time solve_deepBSDE(glob)

Epoch 100 losses 0.0010985343393734761 ysol [0.002235794382079895]
Epoch 200 losses 0.0006759316471598907 ysol [0.0007356184860515199]
Epoch 300 losses 0.0006366787494693209 ysol [0.0009809053512971162]
Epoch 400 losses 0.0006045692572417405 ysol [0.0036235622728816594]
Epoch 500 losses 0.00033670910229452233 ysol [0.005285969774767549]
Epoch 600 losses 0.00023802028464593698 ysol [-0.0007802545719589988]
Epoch 700 losses 0.00019928899909859993 ysol [-0.0009619016502547542]
Epoch 800 losses 0.00014721747720366368 ysol [-0.0011069563984999163]
Epoch 900 losses 0.00010195619396543539 ysol [-0.0006469187795215655]
Epoch 1000 losses 8.918661471768875e-5 ysol [-0.00022831149472520598]
Epoch 1100 losses 0.00010494229143276654 ysol [-0.0020568123039538555]
Epoch 1200 losses 8.720548314167359e-5 ysol [-0.000374358288205359]
Epoch 1300 losses 0.00010319218123648247 ysol [-0.0022051327938945905]
Epoch 1400 losses 0.00040852345210203934 ysol [-0.007316377259454516]
Epoch 1500 losses 0.00029667681

In [11]:
size(glob(sample(glob.eqn,78)))

(1, 78)

In [12]:
Flux.params(glob)[1]

100-element Vector{Float64}:
  2.78837991832156e-12
 -4.869334577670193e-13
  2.6766023539961064e-13
 -9.130613476982703e-13
 -6.315040857483158e-13
  1.481527898799976e-12
  3.488196485801134e-13
 -1.3234127328425388e-12
 -1.5358675779559863e-12
 -7.30435852192345e-13
  9.07099309251287e-13
  1.2692232387353922e-12
 -1.7957855480943528e-12
  ⋮
 -1.8296418364459685e-12
  1.3333398595443609e-12
  3.4304582318406017e-13
 -1.21425427744357e-12
  2.761783013682222e-14
 -3.426596762812489e-14
 -1.0870591558452859e-12
  2.6241101813390985e-12
 -1.437519933469093e-12
 -1.5268056569604891e-12
  2.2940062873969003e-13
 -4.614527453247167e-13