In [1]:
using Flux,Random

In [2]:
"""
    add_dim(x::AbstractArray)

Return `x` reshaped so that it has one extra trailing dimension of length 1,
e.g. a length-`n` vector becomes an `n×1` matrix.

Accepts any `AbstractArray` (ranges, views, transposes, ...), not only `Array`.
"""
add_dim(x::AbstractArray) = reshape(x, (size(x)..., 1))

add_dim (generic function with 1 method)

In [3]:
"""
    HJBEquation(eqn_config)

Problem description built from a dictionary-like `eqn_config` that provides the
keys `"dim"`, `"total_time"` and `"Ndis"`.

# Fields
- `dim`: value of `eqn_config["dim"]`.
- `total_time`: value of `eqn_config["total_time"]`.
- `Ndis`: value of `eqn_config["Ndis"]`.
- `delta_t`: `total_time / Ndis`.
- `sigma`: fixed to `sqrt(2)`.
- `x_init`: zero vector of length `dim`.
"""
struct HJBEquation
    dim::Int64
    total_time::Float64
    Ndis::Int64
    delta_t::Float64
    sigma::Float64
    x_init::Vector{Float64}
    function HJBEquation(eqn_config)
        dim = eqn_config["dim"]
        horizon = eqn_config["total_time"]
        nsteps = eqn_config["Ndis"]
        step = horizon / nsteps
        origin = zeros(dim)
        return new(dim, horizon, nsteps, step, sqrt(2), origin)
    end
end

"""
    sample(eqn::HJBEquation, num_sample)

Draw `num_sample` random paths for `eqn`.

# Returns
A tuple `(dw_sample, x_sample)`:
- `dw_sample`: `(num_sample, eqn.dim, eqn.Ndis)` array of normal increments
  scaled by `sqrt(eqn.delta_t)`.
- `x_sample`: `(num_sample, eqn.dim, eqn.Ndis + 1)` array whose first slice is
  `eqn.x_init` (repeated for every sample) and whose slice `i + 1` is slice `i`
  plus `eqn.sigma` times increment `i`.
"""
function sample(eqn::HJBEquation, num_sample)
    dw_sample = randn((num_sample, eqn.dim, eqn.Ndis)) .* sqrt(eqn.delta_t)
    x_sample = zeros((num_sample, eqn.dim, eqn.Ndis + 1))
    # Every path starts at x_init (row vector broadcast over the samples).
    x_sample[:, :, 1] .= transpose(eqn.x_init)
    # Iterate over ALL Ndis increments so that the terminal slice Ndis + 1 is
    # filled as well; stopping at Ndis - 1 would leave it at zero.
    for i in 1:eqn.Ndis
        x_sample[:, :, i + 1] = x_sample[:, :, i] .+ eqn.sigma .* dw_sample[:, :, i]
    end
    return dw_sample, x_sample
end

"""
    f_t(t, x, y, z)

Return half of the row-wise sum of squares of `z` (summing along dimension 2,
which is kept as a singleton dimension). `t`, `x` and `y` are accepted but not
used.
"""
function f_t(t, x, y, z)
    squared = z .^ 2
    rowsums = sum(squared; dims=2)
    return rowsums ./ 2
end

"""
    g_T(t, x)

Return the scalar `log(1 + sum(x.^2) / 2)`, where the sum runs over every
element of `x`. `t` is accepted but not used.
"""
function g_T(t, x)
    half_sq_sum = sum(x .^ 2) / 2.0
    return log(1.0 + half_sq_sum)
end

g_T (generic function with 1 method)

In [22]:
"""
    FF_subnet(eqn)

Wrapper around a Flux `Chain` stored in the `model` field. The chain maps
`eqn.dim` features to `eqn.dim` features through two hidden `Dense` layers of
width `eqn.dim + 10` (all without bias), with a `BatchNorm` before the first
`Dense` layer and after every `Dense` layer, and `relu` after the two hidden
batch norms.

Calling an `FF_subnet` forwards the argument to `model`.
"""
mutable struct FF_subnet
    model
    function FF_subnet(eqn)
        hidden = eqn.dim + 10
        # All batch-norm layers share the same hyper-parameters.
        # NOTE(review): Flux weights the *new* batch statistics by `momentum`;
        # 0.99 looks like a Keras-style value -- confirm this is intended.
        batchnorm(width) = BatchNorm(
            width, initβ=zeros32, initγ=ones32, ϵ=1f-6, momentum=0.99f0
        )
        layers = Chain(
            batchnorm(eqn.dim),
            Dense(eqn.dim => hidden, bias=false),
            batchnorm(hidden),
            relu,
            Dense(hidden => hidden, bias=false),
            batchnorm(hidden),
            relu,
            Dense(hidden => eqn.dim, bias=false),
            batchnorm(eqn.dim),
        )
        return new(layers)
    end
end

# Make the wrapper callable: delegate to the wrapped chain.
(subnet::FF_subnet)(x) = subnet.model(x)

"""
    GlobalModel(eqn)
    GlobalModel(eqn, subnets, y_init, z_init)

Container for the full model attached to `eqn`.

# Fields
- `eqn`: the `HJBEquation` being solved.
- `subnets`: vector of `eqn.Ndis` Flux chains (the `model` of one `FF_subnet` each).
- `y_init`: 1-element `Float64` vector drawn uniformly from `[0, 1)`.
- `z_init`: `1 × eqn.dim` `Float64` matrix drawn uniformly from `[-0.1, 0.1)`.

The one-argument form creates fresh sub-networks and random initial values.
The four-argument form stores the given fields as they are.
"""
mutable struct GlobalModel
    eqn::HJBEquation
    subnets
    y_init
    z_init
    function GlobalModel(eqn)
        subnets = [FF_subnet(eqn).model for _ in 1:eqn.Ndis]
        y_init = rand(Float64, 1)
        z_init = rand(Float64, 1, eqn.dim) .* 0.2 .- 0.1
        return new(eqn, subnets, y_init, z_init)
    end
    # Field-wise constructor. `Flux.@functor GlobalModel` rebuilds a model by
    # calling the type with one argument per field (e.g. in `fmap`, `gpu`,
    # `f32`); without this method that reconstruction raises a MethodError.
    GlobalModel(eqn, subnets, y_init, z_init) = new(eqn, subnets, y_init, z_init)
end

"""
    call_train(glob::GlobalModel, inputs)

Run the forward pass of `glob` in training mode.

# Arguments
- `glob`: the model; its `y_init`, `z_init`, `subnets` and `eqn` fields are used.
- `inputs`: a pair `(dw, x)` as returned by `sample`, indexed here as
  `dw[:, :, i]` and `x[:, :, i]`.

# Returns
An `N × 1` matrix `y` (with `N = size(dw, 1)`) obtained by starting from
`glob.y_init` and applying, for every time step,
`y <- y - delta_t * f_t(t, x, y, z) + sum(z .* dw, dims=2)`.

All sub-networks are switched to train mode as a side effect.
"""
function call_train(glob::GlobalModel, inputs)
    dw, x = inputs
    eqn = glob.eqn
    nsamples = size(dw, 1)
    # Left end points of the time steps: 0, Δt, ..., (Ndis - 1)Δt.
    # (`range(0, stop=T, length=Ndis)` would have spacing T / (Ndis - 1).)
    tgrid = (0:eqn.Ndis-1) .* eqn.delta_t

    # Start from the initial values stored in the model, not from
    # identically named global variables.
    y = add_dim(repeat(glob.y_init, nsamples))
    z = repeat(glob.z_init, nsamples)

    for net in glob.subnets
        Flux.trainmode!(net)
    end

    for i in 1:eqn.Ndis-1
        drift = f_t(tgrid[i], x[:, :, i], y, z)
        y = y .- eqn.delta_t .* drift .+ sum(z .* dw[:, :, i], dims=2)
        # Sub-networks take features along the first dimension, hence the
        # transposes around the call.
        z = transpose(glob.subnets[i](transpose(x[:, :, i + 1])) ./ eqn.dim)
    end

    # Last step: uses the final increment and the state one slice before the end.
    drift = f_t(tgrid[end], x[:, :, end-1], y, z)
    y = y .- eqn.delta_t .* drift .+ sum(z .* dw[:, :, end], dims=2)
    return y
end

# Calling a GlobalModel on `inputs` runs the training-mode forward pass.
function (glob::GlobalModel)(inputs)
    return call_train(glob, inputs)
end

# Register GlobalModel with Flux's functor machinery.
Flux.@functor GlobalModel

In [None]:
"""
    solve_deepBSDE(glob)

Placeholder with no implementation yet: ignores `glob` and returns `nothing`.
"""
function solve_deepBSDE(glob)
    return nothing
end

In [5]:
# Build the problem: dim = 10, total_time = 1.0, Ndis = 20 (so delta_t = 0.05).
eqn=HJBEquation(Dict("dim"=>10,"total_time"=>1.0,"Ndis"=>20))

HJBEquation(10, 1.0, 20, 0.05, 1.4142135623730951, [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

In [6]:
# Draw 15 sample paths: `dw` holds the increments, `x` the corresponding states.
dw,x=sample(eqn,15);

In [7]:
# Global initial values: a 1-element vector uniform in [0, 1) and a 1×10 row
# with entries uniform in [-0.1, 0.1).
y_init=rand(Float64,(1))
z_init=(rand(Float64,(1,10)).*0.2).-0.1

1×10 Matrix{Float64}:
 -0.0152314  -0.0333775  -0.0516761  …  0.0218367  0.05788  -0.00954501

In [8]:
# Instantiate the full model for `eqn`.
glob=GlobalModel(eqn);

In [24]:
# Forward pass on a fresh batch of 15 sampled paths.
call_train(glob,sample(eqn,15))

15×1 Matrix{Float64}:
 -0.1816409617031584
  0.2669341484650338
  0.298206514335149
  0.07727183081209783
 -0.16004407132607112
  0.40517618651196113
  0.04795252346218922
 -0.502067128496787
  0.05583168023991775
 -0.03467979468288429
  0.042213335922807585
  0.5704554025856478
  0.498791604781371
 -0.11476183924710223
  0.15150688454253197

In [14]:
# NOTE(review): `dx` is not defined anywhere in the visible cells, so this call
# fails with an UndefVarError; presumably `x` or `dw` was intended -- confirm.
f_t(0.0,dx[begin:end,begin:end,:],add_dim((repeat(y_init,15))),repeat(z_init,15))

LoadError: UndefVarError: dx not defined

In [15]:
# Row-wise sum of (z_init tiled to 15 rows) times the increments of step 2.
sum(repeat(z_init,15).*dw[:, :, 2],dims=2)

15×1 Matrix{Float64}:
 -0.007972280864167217
 -0.04433928471026934
  0.004536559344606637
  0.034349846732544737
 -0.07256664692756726
 -0.0037852283860949727
 -0.027476607661145114
 -0.048866595311670945
 -0.04238879849711959
 -0.06113915931863397
  0.020591908938789212
 -0.00841841537964974
 -0.14398784152082558
 -0.014705320647642496
  0.05373210543405884

In [16]:
# Tile the 1×10 row `z_init` into 15 identical rows.
repeat(z_init,15)

15×10 Matrix{Float64}:
 -0.0152314  -0.0333775  -0.0516761  …  0.0218367  0.05788  -0.00954501
 -0.0152314  -0.0333775  -0.0516761     0.0218367  0.05788  -0.00954501
 -0.0152314  -0.0333775  -0.0516761     0.0218367  0.05788  -0.00954501
 -0.0152314  -0.0333775  -0.0516761     0.0218367  0.05788  -0.00954501
 -0.0152314  -0.0333775  -0.0516761     0.0218367  0.05788  -0.00954501
 -0.0152314  -0.0333775  -0.0516761  …  0.0218367  0.05788  -0.00954501
 -0.0152314  -0.0333775  -0.0516761     0.0218367  0.05788  -0.00954501
 -0.0152314  -0.0333775  -0.0516761     0.0218367  0.05788  -0.00954501
 -0.0152314  -0.0333775  -0.0516761     0.0218367  0.05788  -0.00954501
 -0.0152314  -0.0333775  -0.0516761     0.0218367  0.05788  -0.00954501
 -0.0152314  -0.0333775  -0.0516761  …  0.0218367  0.05788  -0.00954501
 -0.0152314  -0.0333775  -0.0516761     0.0218367  0.05788  -0.00954501
 -0.0152314  -0.0333775  -0.0516761     0.0218367  0.05788  -0.00954501
 -0.0152314  -0.0333775  -0.0516761     0

In [17]:
"""
    Person(name, age)

Immutable record holding a person's `name` and `age`.
"""
struct Person
    name::String  # stored as a String
    age::Int64    # stored as a 64-bit integer
end