In [1]:
using Flux.Tracker
using Flux
using Flux:throttle,glorot_uniform
using BSON:@save
using BSON:@load
using Base.Iterators: repeated
using Flux:@treelike
using PyPlot
using FastGaussQuadrature
using SparseGrids
using LinearAlgebra
using Flux: @epochs
using DataFrames
using CSV
using Sobol
using Pkg; Pkg.activate("cuda"); Pkg.instantiate()
using CuArrays
using ForwardDiff
using CUDAnative
# Workaround: give CuArrays an explicit method for the literal square `x^2` when `x`
# is a first-order ForwardDiff dual number (Float32 or Float64), computing it as `x*x`.
# NOTE(review): presumably needed because Tracker differentiates GPU broadcasts (e.g.
# the `.^2` terms in the losses) through ForwardDiff duals and CuArrays has no usable
# default for them — confirm against the installed CuArrays/ForwardDiff versions.
# These methods extend a CuArrays function for a ForwardDiff type (type piracy).
CuArrays.culiteral_pow(::typeof(^), x::ForwardDiff.Dual{Nothing,Float32,1}, ::Val{2}) = x*x
CuArrays.culiteral_pow(::typeof(^), x::ForwardDiff.Dual{Nothing,Float64,1}, ::Val{2}) = x*x

"""
    f(x)

Source term `2π² * (cos(πx₁) + cos(πx₂))`, evaluated column-wise on a matrix whose
first two rows hold the point coordinates; returns a `1×N` row.
"""
# A generator passed to `sum` takes no `dims` keyword (a trailing `, dims=1` would be
# read as a second loop variable); adding the two 1×N rows already yields a 1×N row.
f(x) = 2 * π^2 .* sum(cos.(π * x[i, :]') for i in 1:2)
# Swish/SiLU activation x / (1 + e^(-x)), spelled with explicit dotted operators so it
# applies elementwise to arrays as well as to scalars. `CUDAnative.exp` is used,
# presumably so the exponential can be evaluated inside GPU broadcast kernels.
acti(x) = x ./ (1 .+ CUDAnative.exp.(.-x))


"""
    Block{F,S,T}

Residual block holding two affine maps and an activation.

Fields
- `W1`, `W2`: weight matrices (sharing the type `S`)
- `b1`, `b2`: bias vectors (sharing the type `T`)
- `σ`: activation applied elementwise after each affine map
"""
struct Block{F,S,T}
    W1::S
    W2::S
    b1::T
    b2::T
    σ::F
end

# Four-argument form: leaving out the activation means `identity` is used.
Block(W1, W2, b1, b2) = Block(W1, W2, b1, b2, identity)

"""
    Block(in1, in2, out1, out2, σ = identity; initW1, initW2, initb1, initb2)

Create a `Block` whose parameters are wrapped with `param`: `W1` is `out1×in1`, `W2` is
`out2×in2`, `b1` has length `out1` and `b2` length `out2`. Weights default to
`glorot_uniform`, biases to `zeros`.

NOTE(review): the forward pass needs `in2 == out1` and `out2 == in1` (for the residual
sum) unless broadcasting applies; this is not checked here.
"""
function Block(in1::Integer, in2::Integer, out1::Integer, out2::Integer, σ = identity;
               initW1 = glorot_uniform, initW2 = glorot_uniform,
               initb1 = zeros, initb2 = zeros)
    # Initialise in the order W1, W2, b1, b2 (same random draw order as before).
    weight1 = param(initW1(out1, in1))
    weight2 = param(initW2(out2, in2))
    bias1 = param(initb1(out1))
    bias2 = param(initb2(out2))
    return Block(weight1, weight2, bias1, bias2, σ)
end

# Let Flux traverse `Block`'s fields (used e.g. by `params` and `gpu`).
@treelike Block

# Forward pass of a `Block`: σ.(W2 * σ.(W1 * x .+ b1) .+ b2) .+ x — two
# affine + activation stages, with the input `x` added back as a skip connection.
function (blk::Block)(x)
    σ = blk.σ
    hidden = σ.(blk.W1 * x .+ blk.b1)          # materialised before the second product
    return σ.(blk.W2 * hidden .+ blk.b2) .+ x  # second stage and residual in one broadcast
end




# Global model read by the loss functions: a 2→10 dense layer, four residual `Block`s
# of width 10 using `acti`, and a 10→1 dense layer, moved to the GPU.
# Layers are constructed in chain order (input, blocks 1–4, output).
M = let
    input_layer = Dense(2, 10)
    residual_blocks = [Block(10, 10, 10, 10, acti) for _ in 1:4]
    output_layer = Dense(10, 1)
    gpu(Chain(input_layer, residual_blocks..., output_layer))
end

"""
    lossFD(N)

Monte-Carlo estimate, over `N` uniform random points in `[0,1)²`, of

    mean( ½|∇M|² + ½π² M² − f·M )

for the global model `M`. Each partial derivative is a one-sided forward difference
using one random step `h ∈ [0, 0.01)` shared by both directions.
"""
function lossFD(N)
    pts = rand(2, N) |> gpu
    h = rand(1) / 100 |> gpu
    u = M(pts)
    total = 0
    for dir in 1:2
        unit = zeros(2)
        unit[dir] = 1
        unit = gpu(unit)
        # Squared forward-difference quotient along `dir`.
        total += (sum(0.5 * ((M(pts .+ h .* unit) .- u) ./ h) .^ 2) / N)[1]
    end
    total += (sum(0.5 * (π^2 * (u .^ 2))) / N)[1]
    total += (sum(-f(pts) .* u) / N)[1]
    return total
end



"""
    lossBC(β)

Boundary penalty, weighted by `β`, on the derivative of the global model `M` normal to
the faces of the unit square.

For each direction `i ∈ (1, 2)`, 100 random positions are placed on the face `xᵢ = 0`
and the same 100 positions on the face `xᵢ = 1`. There `∂M/∂xᵢ` is approximated by a
central difference with one random step `d ∈ [0, 0.01)`. The squared values over these
200 points are summed and divided by 100.
"""
function lossBC(β)
    d = (rand(1) / 100) |> gpu
    # Columns 200(i-1)+1 : 200(i-1)+100 lie on xᵢ = 0, the next 100 on xᵢ = 1;
    # the free coordinate takes the shared random values `p`.
    point = zeros(2, 400)
    p = rand(1, 100)
    for i in 1:2
        point[:, 200*(i-1)+1:200*(i-1)+100] = vcat(p[1:i-1, :], zeros(100)', p[i:end, :])
        point[:, 200*(i-1)+101:200*i] = vcat(p[1:i-1, :], ones(100)', p[i:end, :])
    end
    point = point |> gpu
    losses = 0
    for i in 1:2
        l1 = zeros(2)
        l1[i] = 1
        l1 = l1 |> gpu
        # Slice once: the +d and -d evaluations use the same boundary points.
        face = point[:, 200*(i-1)+1:200*i]
        fwd = 1/2 * M(face .+ d .* l1)
        bwd = 1/2 * M(face .- d .* l1)
        losses += β * (sum(((fwd .- bwd) ./ d) .^ 2) / 100)[1]
    end
    return losses
end

# Training objective: boundary penalty with weight 50 plus the interior energy term on
# `N` random points (`lossBC` is evaluated before `lossFD`).
loss1(N) = lossBC(50) + lossFD(N)

"""
    loss_true(xx=1, yy=1, xx1=1, xx2=2; npoints=2000)

Root-mean-square difference between the global model `M` and the reference function
`cos(πx₁) + cos(πx₂)` on `npoints` uniform random points in `[0,1)²`.

The positional arguments are not used; they remain only so existing calls that pass
them keep working.
"""
function loss_true(xx=1, yy=1, xx1=1, xx2=2; npoints=2000)
    points = rand(2, npoints) |> gpu
    # A generator passed to `sum` takes no `dims` keyword; adding the two 1×N rows
    # already gives a 1×N row matching the model output.
    F_true(x) = sum(cos.(pi * x[i, :]') for i in 1:2)
    errors = sqrt(sum((M(points) - F_true(points)) .^ 2) / npoints)
    return errors
end

# NOTE(review): `inner` and `BC` are not referenced anywhere in the visible code.
inner=SobolSeq(2)
BC=SobolSeq(1)

"""
    test(N; traintime=2000)

Train the global model `M` with ADAM on `loss1(N)` for `traintime` steps. Return a
`DataFrame` with one row per step: the step counter, the RMSE from `loss_true`, and
fresh evaluations of `lossFD(N)` and `lossBC(1)`.

Every 1000 steps the losses are printed, the table is written to `<N>rand2Dexm2.csv`,
and the model and its raw weights are written to BSON files. Both are moved to the CPU
before saving.
"""
function test(N; traintime=2000)
    errdf = DataFrame(time=Int[], absolute_error=Float64[], loss_FD=Float64[], loss_BC=Float64[])
    cntr = 1
    evalcb = function ()
        loss_true1 = loss_true()
        loss_FD = lossFD(N)
        loss_BC = lossBC(1)
        push!(errdf, [cntr, Tracker.data(loss_true1), Tracker.data(loss_FD), Tracker.data(loss_BC)])
        cntr += 1
        if cntr % 1000 == 0
            @show(cntr, loss_FD, loss_true1, loss_BC)
            CSV.write("$(N)rand2Dexm2.csv", errdf)
            # GPU arrays hold device buffers that cannot be restored from a BSON file in
            # another session, so save CPU copies (still stored under the key `M`).
            let M = cpu(M)
                @save "$(N)randpoint2Dexm2m.bson" M
            end
            weights = cpu.(Tracker.data.(params(M)))
            @save "$(N)randpoint2Dexm2w.bson" weights
        end
    end
    opt = ADAM()
    # Each step calls `loss1(N)`: `train!` splats the 1-tuple `(N,)` into the loss.
    Flux.train!(loss1, params(M), repeated((N,), traintime), opt, cb=evalcb)
    return errdf
end


[32m[1mActivating[22m[39m new environment at `~/2Dexm2/cuda/Project.toml`
[32m[1m  Updating[22m[39m registry at `~/.julia/registries/General`
[32m[1m  Updating[22m[39m git-repo `https://github.com/JuliaRegistries/General.git`


test (generic function with 1 method)

In [None]:
# Train a freshly initialised model for each interior sample size.
# `global` is required: inside a top-level `for` loop a plain `M = ...` can bind a
# loop-local `M` (e.g. when run as a script). `test` reads the global `M`, so it would
# otherwise keep training the previous model.
for n in (250, 500, 1000, 2000)
    global M = Chain(
        Dense(2, 10),
        Block(10, 10, 10, 10, acti),
        Block(10, 10, 10, 10, acti),
        Block(10, 10, 10, 10, acti),
        Block(10, 10, 10, 10, acti),
        Dense(10, 1),
    ) |> gpu
    global errdf = test(n, traintime=100000)
end

cntr = 1000
loss_FD = -7.903569005139207 (tracked)
loss_true1 = 0.46685845f0 (tracked)
loss_BC = 0.003965439620820977 (tracked)
cntr = 2000
loss_FD = -11.634361023088005 (tracked)
loss_true1 = 0.020110203f0 (tracked)
loss_BC = 0.0004659452429950963 (tracked)
cntr = 3000
loss_FD = -10.942559085653706 (tracked)
loss_true1 = 0.023106731f0 (tracked)
loss_BC = 0.0003838518650108326 (tracked)
cntr = 4000
loss_FD = -10.116635241208403 (tracked)
loss_true1 = 0.05469218f0 (tracked)
loss_BC = 0.00018462389427014039 (tracked)
cntr = 5000
loss_FD = -9.337664303317037 (tracked)
loss_true1 = 0.00967024f0 (tracked)
loss_BC = 0.00014354424296791682 (tracked)
cntr = 6000
loss_FD = -9.519744221502279 (tracked)
loss_true1 = 0.009619538f0 (tracked)
loss_BC = 0.0003055643070125184 (tracked)
cntr = 7000
loss_FD = -12.13728814091478 (tracked)
loss_true1 = 0.020308843f0 (tracked)
loss_BC = 0.0006841213231191515 (tracked)
cntr = 8000
loss_FD = -10.47404126586476 (tracked)
loss_true1 = 0.014249951f0 (tracked)
l

cntr = 64000
loss_FD = -9.749672708945376 (tracked)
loss_true1 = 0.017957633f0 (tracked)
loss_BC = 0.0004050347574311396 (tracked)
cntr = 65000
loss_FD = -9.127710832859577 (tracked)
loss_true1 = 0.015654743f0 (tracked)
loss_BC = 0.00016194481025464107 (tracked)
cntr = 66000
loss_FD = -9.344679590003633 (tracked)
loss_true1 = 0.020640887f0 (tracked)
loss_BC = 3.2825688806068855e-5 (tracked)
cntr = 67000
loss_FD = -8.294041169051733 (tracked)
loss_true1 = 0.00974909f0 (tracked)
loss_BC = 4.78536574790026e-5 (tracked)
cntr = 68000
loss_FD = -9.506550506017916 (tracked)
loss_true1 = 0.0277714f0 (tracked)
loss_BC = 0.00026756985556422065 (tracked)
cntr = 69000
loss_FD = -10.30949649225498 (tracked)
loss_true1 = 0.018616829f0 (tracked)
loss_BC = 0.00014156905136750901 (tracked)
cntr = 70000
loss_FD = -7.82298119850784 (tracked)
loss_true1 = 0.010961526f0 (tracked)
loss_BC = 8.194480129020957e-5 (tracked)
cntr = 71000
loss_FD = -9.219278590963278 (tracked)
loss_true1 = 0.026426315f0 (tracked

cntr = 27000
loss_FD = -10.715403664944935 (tracked)
loss_true1 = 0.010435427f0 (tracked)
loss_BC = 0.00016048591582034 (tracked)
cntr = 28000
loss_FD = -10.041692065793832 (tracked)
loss_true1 = 0.0062782443f0 (tracked)
loss_BC = 0.0001767338165824587 (tracked)
cntr = 29000
loss_FD = -9.048089994289692 (tracked)
loss_true1 = 0.034002293f0 (tracked)
loss_BC = 0.00010258787154101998 (tracked)
cntr = 30000
loss_FD = -9.49448840110955 (tracked)
loss_true1 = 0.020087928f0 (tracked)
loss_BC = 0.0005692023246942378 (tracked)
cntr = 31000
loss_FD = -10.029679278008153 (tracked)
loss_true1 = 0.03530808f0 (tracked)
loss_BC = 1.6056215314095638e-5 (tracked)
cntr = 32000
loss_FD = -9.03024057432878 (tracked)
loss_true1 = 0.008527895f0 (tracked)
loss_BC = 0.0002247395560724794 (tracked)
cntr = 33000
loss_FD = -8.414202228298178 (tracked)
loss_true1 = 0.018796707f0 (tracked)
loss_BC = 0.00013264072575126714 (tracked)
cntr = 34000
loss_FD = -9.540843298994972 (tracked)
loss_true1 = 0.012290847f0 (tr