In [1]:
using Flux.Tracker
using Flux
using Flux:throttle,glorot_uniform
using BSON:@save
using BSON:@load
using Base.Iterators: repeated
using Flux:@treelike
using PyPlot
using FastGaussQuadrature
using SparseGrids
using LinearAlgebra
using Flux: @epochs
using DataFrames
using CSV
using Sobol
using Pkg; Pkg.activate("cuda"); Pkg.instantiate()
using CuArrays
using ForwardDiff
using CUDAnative
# Extra `CuArrays.culiteral_pow` methods for the literal power `x^2` when `x` is a
# single-partial ForwardDiff dual number (Float32 and Float64 variants); the square
# is computed as a plain product `x*x`.
# NOTE(review): presumably required so that `.^2` broadcasts over tracked GPU arrays
# compile inside CUDA kernels — confirm against the installed CuArrays/Flux versions.
CuArrays.culiteral_pow(::typeof(^), x::ForwardDiff.Dual{Nothing,Float32,1}, ::Val{2}) = x*x
CuArrays.culiteral_pow(::typeof(^), x::ForwardDiff.Dual{Nothing,Float64,1}, ::Val{2}) = x*x

"""
    f(x)

Evaluate the source term `2π² * Σᵢ cos(π * x[i, j])` for every column `j` of `x`.

# Arguments
- `x`: matrix whose columns are sample points (one row per coordinate).

# Returns
A `1 × size(x, 2)` row with one value per sample point.

The sum runs over all rows of `x`. The original spelling
`sum(... for i=1:16, dims=1)` looked like a keyword argument, but inside a generator
`dims=1` is parsed as a second iteration variable over the scalar `1`, so it had no
effect and is dropped here.
"""
f(x) = 2 * π^2 .* sum(cos.(π * x[i, :]') for i in axes(x, 1))
"""
    acti(x)

Activation `x / (1 + exp(-x))`, i.e. `x` times the logistic sigmoid of `x`, evaluated
element-wise with `CUDAnative.exp`.
"""
acti(x) = x ./ (1 .+ CUDAnative.exp.(.-x))


"""
    Block(W1, W2, b1, b2, σ)
    Block(W1, W2, b1, b2)

Container for a two-layer block: two weight arrays (`W1`, `W2`, sharing one type), two
bias arrays (`b1`, `b2`, sharing one type) and an activation `σ`. The four-argument
form uses `identity` as the activation.
"""
struct Block{F,S,T}
    W1::S   # weights of the first affine map
    W2::S   # weights of the second affine map
    b1::T   # bias of the first affine map
    b2::T   # bias of the second affine map
    σ::F    # activation applied after each affine map
end

function Block(W1, W2, b1, b2)
    return Block(W1, W2, b1, b2, identity)
end

"""
    Block(in1, in2, out1, out2, σ = identity; initW1, initW2, initb1, initb2)

Build a `Block` from layer sizes. `W1` is created as `initW1(out1, in1)`, `W2` as
`initW2(out2, in2)`, and the biases as `initb1(out1)` and `initb2(out2)`; each array is
wrapped in `param` before being stored.

# Keywords
- `initW1`, `initW2`: weight initialisers (default `glorot_uniform`).
- `initb1`, `initb2`: bias initialisers (default `zeros`).
"""
function Block(in1::Integer, in2::Integer, out1::Integer, out2::Integer, σ = identity;
               initW1 = glorot_uniform, initW2 = glorot_uniform,
               initb1 = zeros, initb2 = zeros)
    # Initialisers are invoked in field order (W1, W2, b1, b2).
    W1 = param(initW1(out1, in1))
    W2 = param(initW2(out2, in2))
    b1 = param(initb1(out1))
    b2 = param(initb2(out2))
    return Block(W1, W2, b1, b2, σ)
end

# Register `Block` with Flux via `@treelike`.
# NOTE(review): presumably this is what lets `params(M)` and `gpu` reach the fields of
# each `Block` inside the model — confirm against the Flux version in use.
@treelike Block

# Forward pass of a `Block`: two affine maps, each followed by the activation, plus a
# skip connection that adds the input `x` back onto the result. The final `.+ x`
# means the output of the second map must be broadcast-compatible with `x`.
function (a::Block)(x)
    act = a.σ
    hidden = act.(a.W1 * x .+ a.b1)
    return act.(a.W2 * hidden .+ a.b2) .+ x
end





# Global model: Dense(16 => 48), five 48-wide `Block`s with the `acti` activation,
# then Dense(48 => 1); the whole chain is moved with `gpu`.
M = gpu(Chain(
    Dense(16, 48),
    [Block(48, 48, 48, 48, acti) for _ in 1:5]...,
    Dense(48, 1),
))

"""
    lossFD(N)

Monte Carlo estimate, over `N` uniformly random points in the 16-dimensional unit
cube, of

    0.5 * |∇M|^2 + 0.5 * π^2 * M^2 - f * M

where each partial derivative of the global model `M` is approximated by a forward
difference with one shared random step drawn from `rand(1) / 100`.

Returns the accumulated scalar loss.
"""
function lossFD(N)
    # Keep the draw order: sample points first, then the finite-difference step.
    samples = gpu(rand(16, N))
    step = gpu(rand(1) / 100)
    base = M(samples)

    total = 0
    for axis in 1:16
        # Unit vector along `axis`, used to shift every sample in that direction.
        unit = gpu([j == axis ? 1.0 : 0.0 for j in 1:16])
        shifted = M(samples .+ step .* unit)
        total += (sum(0.5 * ((shifted .- base) ./ step) .^ 2) / N)[1]
    end

    total += (sum(0.5 * (π^2 * base .^ 2)) / N)[1]
    total += (sum(-f(samples) .* base) / N)[1]
    return total
end


"""
    lossBC(β)

Boundary penalty for the global model `M`, scaled by `β`.

For every coordinate axis, 100 random points are placed on the face where that
coordinate is 0 and 100 on the face where it is 1 (the same 15-dimensional random
sample is reused for all faces). On each pair of faces the derivative of `M` along the
fixed axis is approximated by a central difference with a random step drawn from
`rand(1) / 100`; its squares are summed and divided by 100.

Returns the accumulated scalar penalty.
"""
function lossBC(β)
    # Keep the draw order: step first, then the shared face sample.
    step = gpu(rand(1) / 100)
    pts = zeros(16, 3200)
    face = rand(15, 100)

    for axis in 1:16
        offset = 200 * (axis - 1)
        above = face[1:axis-1, :]
        below = face[axis:end, :]
        # Columns offset+1 .. offset+100 sit on x[axis] = 0, the next 100 on x[axis] = 1.
        pts[:, offset+1:offset+100] = vcat(above, zeros(100)', below)
        pts[:, offset+101:offset+200] = vcat(above, ones(100)', below)
    end
    pts = gpu(pts)

    total = 0
    for axis in 1:16
        unit = gpu([j == axis ? 1.0 : 0.0 for j in 1:16])
        slab = pts[:, 200*(axis-1)+1:200*axis]
        forward = M(slab .+ step .* unit)
        backward = M(slab .- step .* unit)
        total += β * (sum(((1/2*forward .- 1/2*backward) ./ step) .^ 2) / 100)[1]
    end
    return total
end

"""
    loss1(N)

Training objective: the boundary penalty `lossBC` with weight 50 plus the interior
loss `lossFD` on `N` sample points. `lossBC` is evaluated first.
"""
loss1(N) = lossBC(50) + lossFD(N)


"""
    loss_true(xx=1, yy=1, xx1=1, xx2=2)

Root-mean-square difference between the global model `M` and the reference function
`Σᵢ cos(π * xᵢ)` on the next 5000 points of the global Sobol sequence `truep`.

The four positional parameters are not used by the body; they are kept so existing
calls keep working. Every call advances `truep`, so successive calls evaluate on
different points.
"""
function loss_true(xx=1, yy=1, xx1=1, xx2=2)
    points = zeros(16, 5000)
    for i in axes(points, 2)
        points[:, i] = next!(truep)
    end
    points = points |> gpu
    # The original generator ended in `, dims=1`. Inside a generator that is a second
    # iteration variable over the scalar 1, not a keyword of `sum`, so it had no
    # effect; it is removed to avoid suggesting otherwise.
    F_true(x) = sum(cos.(pi * x[i, :]') for i in 1:16)
    errors = sqrt(sum((M(points) - F_true(points)) .^ 2) / 5000)
    return errors
end


# Global Sobol sequences. `truep` (16-dimensional) is consumed by `loss_true`.
# NOTE(review): `inner` and `BC` are not referenced by any code visible in this
# notebook (the loss functions draw their points with `rand`) — confirm whether they
# are still needed.
inner=SobolSeq(16)
truep=SobolSeq(16)
BC=SobolSeq(15)


"""
    test(N; traintime=2000)

Train the global model `M` with ADAM for `traintime` steps on `loss1(N)` and log, after
every step, the step number, `loss_true()`, `lossFD(N)` and `lossBC(1)`.

# Arguments
- `N`: number of interior sample points passed to the loss functions; also used as
  the prefix of the output file names.

# Keywords
- `traintime`: number of optimisation steps.

# Returns
The `DataFrame` with columns `time`, `absolute_error`, `loss_FD` and `loss_BC`.

# Side effects
Writes `<N>rand16Dexm2.csv`, `<N>randpoint16Dexm2m.bson` and
`<N>randpoint16Dexm2w.bson` whenever the step counter reaches a multiple of 1000, and
once more after training finishes. The final write was missing before: the periodic
check runs after the counter is incremented, so the last logged rows and the final
weights were never saved.
"""
function test(N; traintime=2000)
    errdf = DataFrame(time=Int[], absolute_error=Float64[], loss_FD=Float64[],
                      loss_BC=Float64[])

    # Persist the log so far, the model, and its raw (untracked) weights.
    function checkpoint()
        CSV.write("$(N)rand16Dexm2.csv", errdf)
        @save "$(N)randpoint16Dexm2m.bson" M
        weights = Tracker.data.(params(M))
        @save "$(N)randpoint16Dexm2w.bson" weights
        return nothing
    end

    cntr = 1
    evalcb = function ()
        loss_true1 = loss_true()
        loss_FD = lossFD(N)
        loss_BC = lossBC(1)
        push!(errdf, [cntr, Tracker.data(loss_true1), Tracker.data(loss_FD),
                      Tracker.data(loss_BC)])
        cntr += 1
        if cntr % 1000 == 0
            @show(cntr, loss_FD, loss_true1, loss_BC)
            checkpoint()
        end
        return nothing
    end

    opt = ADAM()
    # One entry per training step; `zip` wraps each `N` in a 1-tuple for `loss1`.
    dataset = [(N) for i = 1:traintime]
    Flux.train!(loss1, params(M), zip(dataset), opt, cb=evalcb)

    # Make sure the last steps are on disk even when they fall between two
    # periodic checkpoints.
    checkpoint()
    return errdf
end


┌ Info: Precompiling Flux [587475ba-b771-5e3f-ad9e-33799f191a9c]
└ @ Base loading.jl:1242
┌ Info: Precompiling BSON [fbb218c0-5317-5bc6-957e-2ee96dd4b1f0]
└ @ Base loading.jl:1242
┌ Info: Precompiling SparseGrids [bafbe729-afc6-5148-bb4f-226bf3d46895]
└ @ Base loading.jl:1242
┌ Info: Precompiling DataFrames [a93c6f00-e57d-5684-b7b6-d8193f3e46c0]
└ @ Base loading.jl:1242
┌ Info: Precompiling CSV [336ed68f-0bac-5ca0-87d4-7b16caf5d00b]
└ @ Base loading.jl:1242


[32m[1mActivating[22m[39m new environment at `~/16Dexm2/cuda/Project.toml`
[32m[1m  Updating[22m[39m registry at `~/.julia/registries/General`
[32m[1m  Updating[22m[39m git-repo `https://github.com/JuliaRegistries/General.git`
[?25l[2K[?25h[32m[1m Resolving[22m[39m package versions...


┌ Info: Building the CUDAnative run-time library for your sm_61 device, this might take a while...
└ @ CUDAnative /root/.julia/packages/CUDAnative/LkH1v/src/compiler/rtlib.jl:168


test (generic function with 1 method)

In [2]:
# Warm-up run with 500 sample points, then continue training the same global model
# with progressively larger sample sizes. `errdf` keeps the log of the latest run.
errdf = test(500, traintime=30000)
for n in (1000, 2000, 5000, 10000)
    errdf = test(n, traintime=50000)
end

cntr = 1000
loss_FD = -19.525688630449324 (tracked)
loss_true1 = 2.481211f0 (tracked)
loss_BC = 0.06296989680272257 (tracked)
cntr = 2000
loss_FD = -27.146347471131246 (tracked)
loss_true1 = 2.3172364f0 (tracked)
loss_BC = 0.02617622278537293 (tracked)
cntr = 3000
loss_FD = -34.95413641679771 (tracked)
loss_true1 = 2.1820488f0 (tracked)
loss_BC = 0.030396658428438694 (tracked)
cntr = 4000
loss_FD = -47.12749156956052 (tracked)
loss_true1 = 1.9424953f0 (tracked)
loss_BC = 0.059307627318419855 (tracked)
cntr = 5000
loss_FD = -46.935382064491094 (tracked)
loss_true1 = 1.7081114f0 (tracked)
loss_BC = 0.05073101561654039 (tracked)
cntr = 6000
loss_FD = -55.19652986351254 (tracked)
loss_true1 = 1.5925113f0 (tracked)
loss_BC = 0.06196336761231943 (tracked)
cntr = 7000
loss_FD = -47.498971982456 (tracked)
loss_true1 = 1.5749954f0 (tracked)
loss_BC = 0.02862421738848757 (tracked)
cntr = 8000
loss_FD = -51.373548135244704 (tracked)
loss_true1 = 1.563626f0 (tracked)
loss_BC = 0.0240225681477729 (

cntr = 36000
loss_FD = -74.51880077015657 (tracked)
loss_true1 = 0.08752469f0 (tracked)
loss_BC = 0.014232981480811766 (tracked)
cntr = 37000
loss_FD = -81.71168973696639 (tracked)
loss_true1 = 0.16153972f0 (tracked)
loss_BC = 0.01232945590953666 (tracked)
cntr = 38000
loss_FD = -70.92073387211462 (tracked)
loss_true1 = 0.08599121f0 (tracked)
loss_BC = 0.01924758563898435 (tracked)
cntr = 39000
loss_FD = -78.57381526597402 (tracked)
loss_true1 = 0.13024403f0 (tracked)
loss_BC = 0.010238559178423276 (tracked)
cntr = 40000
loss_FD = -71.09109083500921 (tracked)
loss_true1 = 0.12472065f0 (tracked)
loss_BC = 0.015708422643285484 (tracked)
cntr = 41000
loss_FD = -81.37323300970925 (tracked)
loss_true1 = 0.09926249f0 (tracked)
loss_BC = 0.01422299314051003 (tracked)
cntr = 42000
loss_FD = -82.08097259940382 (tracked)
loss_true1 = 0.048950583f0 (tracked)
loss_BC = 0.015109917508829956 (tracked)
cntr = 43000
loss_FD = -83.76401245337084 (tracked)
loss_true1 = 0.06289853f0 (tracked)
loss_BC = 0

cntr = 50000
loss_FD = -85.21540641512742 (tracked)
loss_true1 = 0.07433978f0 (tracked)
loss_BC = 0.0034780593790309355 (tracked)
cntr = 1000
loss_FD = -79.31783955877832 (tracked)
loss_true1 = 0.045748882f0 (tracked)
loss_BC = 0.00277814649482918 (tracked)
cntr = 2000
loss_FD = -78.80496590412795 (tracked)
loss_true1 = 0.032260925f0 (tracked)
loss_BC = 0.003695299862312681 (tracked)
cntr = 3000
loss_FD = -80.60986692702329 (tracked)
loss_true1 = 0.05516353f0 (tracked)
loss_BC = 0.004189870427622248 (tracked)
cntr = 4000
loss_FD = -76.20336972645492 (tracked)
loss_true1 = 0.040186226f0 (tracked)
loss_BC = 0.0038915915081914284 (tracked)
cntr = 5000
loss_FD = -48.381349176445546 (tracked)
loss_true1 = 1.7906055f0 (tracked)
loss_BC = 0.040886130881543366 (tracked)
cntr = 6000
loss_FD = -79.90842662366163 (tracked)
loss_true1 = 0.13280821f0 (tracked)
loss_BC = 0.02846981828543376 (tracked)
cntr = 7000
loss_FD = -82.95770035870768 (tracked)
loss_true1 = 0.09035478f0 (tracked)
loss_BC = 0.0

cntr = 14000
loss_FD = -77.75169743275734 (tracked)
loss_true1 = 0.053894784f0 (tracked)
loss_BC = 0.0015702744401064868 (tracked)
cntr = 15000
loss_FD = -79.27562670482729 (tracked)
loss_true1 = 0.03358482f0 (tracked)
loss_BC = 0.003489297905881515 (tracked)
cntr = 16000
loss_FD = -77.9441903877835 (tracked)
loss_true1 = 0.054629065f0 (tracked)
loss_BC = 0.0018128575548400354 (tracked)
cntr = 17000
loss_FD = -77.33105141731536 (tracked)
loss_true1 = 0.039264955f0 (tracked)
loss_BC = 0.002412758875492693 (tracked)
cntr = 18000
loss_FD = -78.92500229497068 (tracked)
loss_true1 = 0.03508818f0 (tracked)
loss_BC = 0.002673374926039376 (tracked)
cntr = 19000
loss_FD = -79.91226960461597 (tracked)
loss_true1 = 0.06318046f0 (tracked)
loss_BC = 0.003875079980190545 (tracked)
cntr = 20000
loss_FD = -79.50309327167355 (tracked)
loss_true1 = 0.029892536f0 (tracked)
loss_BC = 0.0025578366983867243 (tracked)
cntr = 21000
loss_FD = -78.24988916898329 (tracked)
loss_true1 = 0.028103089f0 (tracked)
lo