In [1]:
using Flux,Random,Optimisers,Statistics,.Threads

In [2]:
add_dim(x::Array) = reshape(x, (size(x)...,1))

add_dim (generic function with 1 method)

In [3]:
struct HJBEquation
    dim::Int64
    total_time::Float64
    Ndis::Int64
    delta_t::Float64
    sigma::Float64
    x_init::Vector{Float64}
    function HJBEquation(eqn_config)
        new(eqn_config["dim"],
            eqn_config["total_time"],
            eqn_config["Ndis"],
            eqn_config["total_time"]/eqn_config["Ndis"],
            sqrt(2),
            zeros(eqn_config["dim"]))
    end
end 

function sample(eqn::HJBEquation,num_sample)
    dw_sample=randn((num_sample,eqn.dim,eqn.Ndis))*eqn.delta_t^0.5
    x_sample = zeros((num_sample, eqn.dim, eqn.Ndis + 1))
    x_sample[begin:end,begin:end,1]=transpose(eqn.x_init).*ones((num_sample, eqn.dim))
    for i in 1:eqn.Ndis-1
        x_sample[begin:end, begin:end, i + 1]=x_sample[begin:end, begin:end, i] .+ eqn.sigma .* dw_sample[begin:end, begin:end, i]
    end
    return dw_sample, x_sample
end

function f_t(t,x,y,z)
    return sum(z.^2,dims=2)./2
end

function g_T(t,x)
    return log.((1.0 .+ sum(x.^2,dims=2) ./ 2.0))
end

g_T (generic function with 1 method)

In [4]:
mutable struct FF_subnet
    model
    function FF_subnet(eqn)
        modelNN=Chain(
        BatchNorm(eqn.dim, initβ=zeros, initγ=ones,ϵ=1e-6, momentum= 0.99),
        Dense(eqn.dim=>eqn.dim+100,bias=false,init=rand),
        BatchNorm(eqn.dim+100, initβ=zeros, initγ=ones,ϵ=1e-6, momentum= 0.99),
        relu,
        Dense(eqn.dim+100=>eqn.dim+100,bias=false,init=rand), 
        BatchNorm(eqn.dim+100, initβ=zeros, initγ=ones,ϵ=1e-6, momentum= 0.99),  
        relu,
        Dense(eqn.dim+100=>eqn.dim,bias=false,init=rand), 
        BatchNorm(eqn.dim, initβ=zeros, initγ=ones,ϵ=1e-6, momentum= 0.99),  
        )
        new(modelNN)
    end
    (subnet::FF_subnet)(x) =subnet.model(x)
end

mutable struct GlobalModel
    eqn::HJBEquation
    subnets
    y_init
    z_init
    function GlobalModel(eqn)
        subnets=[FF_subnet(eqn).model for _ in 1:eqn.Ndis]
        y_init=rand(Float64,(1))
        z_init=(rand(Float64,(1,eqn.dim)).*0.2).-0.1
        new(eqn,subnets,y_init,z_init)
    end
end

function call_train(glob::GlobalModel,inputs)
    dw, x = inputs
    Nmuestras=size(dw)[1]
    time=range(0, stop = glob.eqn.Ndis*glob.eqn.delta_t, length = glob.eqn.Ndis) |> collect
    y=add_dim((repeat(y_init,Nmuestras)))
    z=repeat(z_init,Nmuestras)
     
    for i in 1:length(glob.subnets)
        Flux.trainmode!(glob.subnets[i])
    end
    
    for i in 1:glob.eqn.Ndis-1
        y=y-glob.eqn.delta_t*(f_t(time[i],x[begin:end,begin:end,:],y,z))+sum(z.*dw[:, :, i],dims=2)
        z=transpose(glob.subnets[i](transpose(x[begin:end,begin:end, i + 1]))./(glob.eqn.dim))
    end
        
    y=y-glob.eqn.delta_t*(f_t(time[end],x[begin:end,begin:end,end-1],y,z))+sum(z.*dw[:, :, end],dims=2)
    return y
end

function call_test(glob::GlobalModel,inputs)
    dw, x = inputs
    Nmuestras=size(dw)[1]
    time=range(0, stop = glob.eqn.Ndis*glob.eqn.delta_t, length = glob.eqn.Ndis) |> collect
    y=add_dim((repeat(y_init,Nmuestras)))
    z=repeat(z_init,Nmuestras)
     
    for i in 1:length(glob.subnets)
        Flux.testmode!(glob.subnets[i])
    end
    
    for i in 1:glob.eqn.Ndis-1
        y=y-glob.eqn.delta_t*(f_t(time[i],x[begin:end,begin:end,:],y,z))+sum(z.*dw[:, :, i],dims=2)
        z=transpose(glob.subnets[i](transpose(x[begin:end,begin:end, i + 1]))./(glob.eqn.dim))
    end
        
    y=y-glob.eqn.delta_t*(f_t(time[end],x[begin:end,begin:end,end-1],y,z))+sum(z.*dw[:, :, end],dims=2)
    return y
end

(glob::GlobalModel)(inputs) =call_train(glob,inputs)
Flux.@functor GlobalModel

In [5]:
DELTA_CLIP=50.0
function loss(glob,inputs)
    dw,x=inputs
    y_terminal=glob(inputs)
    delta = y_terminal .- g_T(glob.eqn.total_time, x[begin:end,begin:end, end])
    return mean(ifelse.(delta .< DELTA_CLIP, delta.^2, 2 * DELTA_CLIP * abs.(delta) .- DELTA_CLIP^2))
end

function loss_test(glob,inputs)
    dw,x=inputs
    y_terminal=call_test(glob,inputs)
    delta = y_terminal .- g_T(glob.eqn.total_time, x[begin:end,begin:end, end])
    return mean(ifelse.(delta .< DELTA_CLIP, delta.^2, 2 * DELTA_CLIP * abs.(delta) .- DELTA_CLIP^2))
end

function solve_deepBSDE(glob::GlobalModel)
    #opt = Optimisers.Adam(0.01)
    #opt_state = Optimisers.setup(opt, glob)
    optim = Flux.setup(Flux.Adam(0.01), glob.subnets)
    
    my_log = []
    for epoch in 1:2000
        losses = Float32[]
        input=sample(glob.eqn,64)

        val, grads = Flux.withgradient(glob.subnets) do m
          # Any code inside here is differentiated.
          # Evaluation of the model and loss must be inside!
        result = glob(input)
        loss(glob, input)
        end

        # Save the loss from the forward pass. (Done outside of gradient.)
        push!(losses, val)

        # Detect loss of Inf or NaN. Print a warning, and then skip update!
        if !isfinite(val)
          @warn "loss is $val on item $i" epoch
          continue
        end
        if epoch%100==0
            println("Epoch ",epoch, " losses ",loss_test(glob,sample(glob.eqn,254)))
        end

        Flux.update!(optim, glob.subnets, grads[1])
        #state, model = Optimisers.update(opt_state, glob, grads[1])
    end
    
end

solve_deepBSDE (generic function with 1 method)

In [1]:
eqn=HJBEquation(Dict("dim"=>100,"total_time"=>1.0,"Ndis"=>20))

LoadError: UndefVarError: HJBEquation not defined

In [7]:
dw,x=sample(eqn,15);

In [8]:
y_init=rand(Float64,(1))
z_init=(rand(Float64,(1,10)).*0.2).-0.1

1×10 Matrix{Float64}:
 -0.010744  -0.0168559  0.085374  …  -0.00974729  0.0153075  0.0177779

In [9]:
glob=GlobalModel(eqn);

In [10]:
@time solve_deepBSDE(glob)

Epoch 100 losses 0.7948490405567581
Epoch 200 losses 0.7755333527766887
Epoch 300 losses 0.8033820337035619
Epoch 400 losses 0.7496778880126584
Epoch 500 losses 0.7595727514283246
Epoch 600 losses 0.8635234437202844
Epoch 700 losses 0.857349550372989
Epoch 800 losses 0.7767400126213
Epoch 900 losses 0.8270851965812074
Epoch 1000 losses 0.7726272830609059
Epoch 1100 losses 0.7295494389715612
Epoch 1200 losses 0.6746803115249647
Epoch 1300 losses 0.796480948826693
Epoch 1400 losses 0.8577938825165058
Epoch 1500 losses 0.7826781960710344
Epoch 1600 losses 0.8026195889658705
Epoch 1700 losses 0.8072196031427402
Epoch 1800 losses 0.8454634639948534
Epoch 1900 losses 0.74170846484631
Epoch 2000 losses 0.8033973525332977
127.328162 seconds (345.05 M allocations: 140.478 GiB, 8.42% gc time, 30.64% compilation time)
