In [1]:
using Flux.Tracker
using Flux
using Flux:throttle,glorot_uniform
using BSON:@save
using BSON:@load
using Base.Iterators: repeated
using Flux:@treelike
using PyPlot
using FastGaussQuadrature
using SparseGrids
using LinearAlgebra
using Flux: @epochs
using DataFrames
using CSV
using Sobol
using Pkg; Pkg.activate("cuda"); Pkg.instantiate()
using CuArrays
using ForwardDiff
using CUDAnative
CuArrays.culiteral_pow(::typeof(^), x::ForwardDiff.Dual{Nothing,Float32,1}, ::Val{2}) = x*x
CuArrays.culiteral_pow(::typeof(^), x::ForwardDiff.Dual{Nothing,Float64,1}, ::Val{2}) = x*x

f(x)=2*π^2 .*sum(cos.(π*x[i,:]') for i=1:8, dims=1)
acti(x)=@. x/(1+CUDAnative.exp(-x))


struct Block{F,S,T}
  W1::S
  W2::S
  b1::T
  b2::T
  σ::F
end

Block(W1 , W2, b1, b2) = Block(W1,W2 ,b1 ,b2, identity)

function Block(in1::Integer,in2::Integer, out1::Integer,out2::Integer, σ = identity;
  initW1 = glorot_uniform, initW2 =glorot_uniform, initb1 = zeros,initb2=zeros)
  return Block(param(initW1(out1, in1)),param(initW2(out2,in2)), param(initb1(out1)), param(initb2(out2)),σ)
end

@treelike Block

function (a::Block)(x)
  W1,W2, b1,b2, σ = a.W1,a.W2,a.b1, a.b2, a.σ
  σ.(W2 *σ.(W1*x .+ b1) .+ b2) .+ x
end





M=Chain(
Dense(8,30),
Block(30,30,30,30,acti),
Block(30,30,30,30,acti),
Block(30,30,30,30,acti),
Block(30,30,30,30,acti),
Dense(30,1)
)|>gpu

#=
@load "test1FD.bson" M
@load "test2FD.bson" weights
Flux.loadparams!(M, weights)
=#


#=

function loss(s,y=1)
  loss = sum((M(s) - 0.25*(s[1,:].^2+s[2,:].^2 .-1)').^2)
  return loss
end
=#

function lossFD(N)
  notes=zeros(8,N)
  for i=1:N
    notes[:,i]=next!(inner)
  end
  notes =notes |>gpu
  d=rand(1)/100 |>gpu
  losses=0
  MMM=M(notes)
  for i=1:8
    l1=zeros(8)
    l1[i]=1
    l1 = l1 |>gpu
    losses  += (sum(0.5*((M(notes .+ d.*l1) .- MMM)./d).^2)/N)[1]
  end
  losses += (sum(0.5*(π^2 * (MMM).^2))/N)[1]
  losses += (sum(-f(notes) .*MMM)/N)[1]
  return losses
end

function lossBC(β)
  d=(rand(1)/100)|>gpu
  point=zeros(8,1600)
  p=zeros(7,100)
  for i=1:100
    p[:,i]=next!(BC)
  end
  for i=1:8
    point[:,200*(i-1)+1:200*(i-1)+100]=vcat(p[1:i-1,:],zeros(100)',p[i:end,:])
    point[:,200*(i-1)+101:200*i]=vcat(p[1:i-1,:], ones(100)',p[i:end,:])
  end
    point=point|>gpu
  losses=0
  for i=1:8
    l2=zeros(8)
    l1=zeros(8)
    l2[i]=2
    l1[i]=1
    l1=l1|>gpu
    l2=l2|>gpu
    losses +=β*(sum((( 1/2*M(point[:,200*(i-1)+1:200*i] .+ d.*l1) .- 1/2*M(point[:,200*(i-1)+1:200*i] .- d.*l1))./d).^2)/100)[1]
  end
  return losses
end

function loss1(N)
  return lossBC(50)+lossFD(N)
end



function loss_true(xx=1,yy=1,xx1=1,xx2=2)
  points=zeros(8,5000)
    for i=1:5000
      points[:,i]=next!(truep)
    end
    points=points  |>gpu
  F_true(x)= sum(cos.(pi*x[i,:]') for i=1:8, dims=1)
  errors=sqrt(sum((M(points)-F_true(points)).^2)/2000)
  return errors
end
truep=SobolSeq(8)

inner=SobolSeq(8)
BC=SobolSeq(7)
function test(N;traintime=2000)
  errdf=DataFrame(time=Int[],absolute_error=Float64[],loss_FD=Float64[])
  cntr=1
  evalcb = function()
    loss_true1=loss_true()
    loss_FD = lossFD(N)
    push!(errdf,[Tracker.data(cntr),Tracker.data(loss_true1),Tracker.data(loss_FD)])
    cntr+=1
    if cntr%500==0
      @show(cntr,loss_FD,loss_true1)
      CSV.write("$(N)Sobol8Dexm2.csv",errdf)
      @save "$(N)Sobolpoint8Dexm2m.bson" M
      weights=Tracker.data.(params(M))
      @save "$(N)Sobolpoint8Dexm2w.bson" weights
    end
  end
  θ=Flux.params(M)
  opt=ADAM()
  dataset=[(N) for i=1:traintime]
  Flux.train!(loss1, params(M), zip(dataset), opt, cb=evalcb)
  return errdf
end


[32m[1mActivating[22m[39m new environment at `~/exm2/cuda/Project.toml`
[32m[1m  Updating[22m[39m registry at `~/.julia/registries/General`
[32m[1m  Updating[22m[39m git-repo `https://github.com/JuliaRegistries/General.git`
[?25l[2K[?25h[32m[1m Resolving[22m[39m package versions...


test (generic function with 1 method)

In [2]:
errdf=test(100,traintime=50000)
for i=1:4
    N=[500 1000 2000 10000]
    errdf=test(N[i],traintime=20000)
end

cntr = 500
loss_FD = -1.2521597459485652 (tracked)
loss_true1 = 3.1087825f0 (tracked)
cntr = 1000
loss_FD = -14.084606125098937 (tracked)
loss_true1 = 2.5380523f0 (tracked)
cntr = 1500
loss_FD = -26.63269239455864 (tracked)
loss_true1 = 1.588723f0 (tracked)
cntr = 2000
loss_FD = -32.58551398775146 (tracked)
loss_true1 = 1.467154f0 (tracked)
cntr = 2500
loss_FD = -37.60506623301963 (tracked)
loss_true1 = 1.0906321f0 (tracked)
cntr = 3000
loss_FD = -38.44732586646599 (tracked)
loss_true1 = 0.80061495f0 (tracked)
cntr = 3500
loss_FD = -40.48010410588239 (tracked)
loss_true1 = 0.19731796f0 (tracked)
cntr = 4000
loss_FD = -38.052154361669864 (tracked)
loss_true1 = 0.15191132f0 (tracked)
cntr = 4500
loss_FD = -39.5481588721866 (tracked)
loss_true1 = 0.1280506f0 (tracked)
cntr = 5000
loss_FD = -39.209583357782876 (tracked)
loss_true1 = 0.13242835f0 (tracked)
cntr = 5500
loss_FD = -38.69811539801696 (tracked)
loss_true1 = 0.12648755f0 (tracked)
cntr = 6000
loss_FD = -16.270090249693173 (tracke

cntr = 47000
loss_FD = -43.99135803167002 (tracked)
loss_true1 = 0.049810193f0 (tracked)
cntr = 47500
loss_FD = -35.341105672248624 (tracked)
loss_true1 = 0.05736922f0 (tracked)
cntr = 48000
loss_FD = -39.51647656618111 (tracked)
loss_true1 = 0.030720329f0 (tracked)
cntr = 48500
loss_FD = -36.25767896550621 (tracked)
loss_true1 = 0.029083934f0 (tracked)
cntr = 49000
loss_FD = -36.43010672224858 (tracked)
loss_true1 = 0.052894205f0 (tracked)
cntr = 49500
loss_FD = -43.158748633181204 (tracked)
loss_true1 = 0.046350338f0 (tracked)
cntr = 50000
loss_FD = -41.13268401150836 (tracked)
loss_true1 = 0.03663211f0 (tracked)
cntr = 500
loss_FD = -39.462604102616126 (tracked)
loss_true1 = 0.025822747f0 (tracked)
cntr = 1000
loss_FD = -40.39840048824662 (tracked)
loss_true1 = 0.027199639f0 (tracked)
cntr = 1500
loss_FD = -39.94950433013162 (tracked)
loss_true1 = 0.024186525f0 (tracked)
cntr = 2000
loss_FD = -38.30208606064873 (tracked)
loss_true1 = 0.028570598f0 (tracked)
cntr = 2500
loss_FD = -39

cntr = 3500
loss_FD = -39.60449004116862 (tracked)
loss_true1 = 0.018797502f0 (tracked)
cntr = 4000
loss_FD = -39.37678501879231 (tracked)
loss_true1 = 0.016674247f0 (tracked)
cntr = 4500
loss_FD = -39.400668678684184 (tracked)
loss_true1 = 0.017744467f0 (tracked)
cntr = 5000
loss_FD = -39.69342506812495 (tracked)
loss_true1 = 0.017859627f0 (tracked)
cntr = 5500
loss_FD = -39.410875843426886 (tracked)
loss_true1 = 0.017413378f0 (tracked)
cntr = 6000
loss_FD = -39.42192854170191 (tracked)
loss_true1 = 0.017095588f0 (tracked)
cntr = 6500
loss_FD = -39.86679456405062 (tracked)
loss_true1 = 0.018038535f0 (tracked)
cntr = 7000
loss_FD = -39.23831026595291 (tracked)
loss_true1 = 0.01601497f0 (tracked)
cntr = 7500
loss_FD = -39.40370741652148 (tracked)
loss_true1 = 0.019878082f0 (tracked)
cntr = 8000
loss_FD = -39.46352881191068 (tracked)
loss_true1 = 0.018057322f0 (tracked)
cntr = 8500
loss_FD = -39.536655421618846 (tracked)
loss_true1 = 0.016304981f0 (tracked)
cntr = 9000
loss_FD = -39.4651