Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend tests for GPEmulator.jl and Observations.jl #62

Merged
merged 1 commit into from
Jun 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 51 additions & 52 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ version = "0.3.3"

[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "fd04049c7dd78cfef0b06cdc1f0f181467655712"
git-tree-sha1 = "0fac443759fa829ed8066db6cf1077d888bb6573"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "1.1.0"
version = "2.0.2"

[[ArnoldiMethod]]
deps = ["DelimitedFiles", "LinearAlgebra", "Random", "SparseArrays", "StaticArrays", "Test"]
Expand All @@ -38,21 +38,21 @@ version = "3.5.0+3"

[[ArrayInterface]]
deps = ["LinearAlgebra", "Requires", "SparseArrays"]
git-tree-sha1 = "649c08a5a3a513f4662673d3777fe6ccb4df9f5d"
git-tree-sha1 = "851de9a8acd7b8863aa2ec2af0a44f375502c878"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "2.8.7"
version = "2.9.0"

[[ArrayLayouts]]
deps = ["FillArrays", "LinearAlgebra"]
git-tree-sha1 = "89182776a99b69964e995cc2f1e37b5fc3476d56"
git-tree-sha1 = "a3254b3780a3544838ca0b7e23b1e9b06eb71bd8"
uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
version = "0.3.4"
version = "0.3.5"

[[BandedMatrices]]
deps = ["ArrayLayouts", "FillArrays", "LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "fd300e252fa1d96c75884cfa37fd6a5402c79d4b"
git-tree-sha1 = "195ceb173f0759ca595770fac3b379e51579e5e7"
uuid = "aae01518-5342-5314-be14-df237901396f"
version = "0.15.12"
version = "0.15.13"

[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
Expand Down Expand Up @@ -94,15 +94,15 @@ version = "0.8.1"

[[ChainRules]]
deps = ["ChainRulesCore", "LinearAlgebra", "Reexport", "Requires", "Statistics"]
git-tree-sha1 = "85f130f2c5ce208a5a395b550802398d2fcc5ee6"
git-tree-sha1 = "76cd719cb7ab57bd2687dcb3b186c4f99820a79d"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.6.4"
version = "0.6.5"

[[ChainRulesCore]]
deps = ["MuladdMacro"]
git-tree-sha1 = "32e2c6e44d4fdd985b5688b5e85c1f6892cf3d15"
git-tree-sha1 = "c384e0e4fe6bfeb6bec0d41f71cc5e391cd110ba"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.8.0"
version = "0.8.1"

[[Cloudy]]
deps = ["Coverage", "DifferentialEquations", "DocStringExtensions", "ForwardDiff", "HCubature", "LinearAlgebra", "Optim", "PyPlot", "SpecialFunctions", "TaylorSeries", "Test"]
Expand Down Expand Up @@ -187,15 +187,15 @@ version = "1.3.0"

[[DataFrames]]
deps = ["CategoricalArrays", "Compat", "DataAPI", "Future", "InvertedIndices", "IteratorInterfaceExtensions", "Missings", "PooledArrays", "Printf", "REPL", "Reexport", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"]
git-tree-sha1 = "02f08ae77249b7f6d4186b081a016fb7454c616f"
git-tree-sha1 = "e516e72bfb40809b7709cda7bfb39e82ec492d68"
uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
version = "0.21.2"
version = "0.21.3"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "be680f1ad03c0a03796aa3fda5a2180df7f83b46"
git-tree-sha1 = "edad9434967fdc0a2631a65d902228400642120c"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.18"
version = "0.17.19"

[[DataValueInterfaces]]
git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6"
Expand All @@ -218,9 +218,9 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"

[[DiffEqBase]]
deps = ["ArrayInterface", "ChainRulesCore", "ConsoleProgressMonitor", "DataStructures", "Distributed", "DocStringExtensions", "FunctionWrappers", "IterativeSolvers", "IteratorInterfaceExtensions", "LabelledArrays", "LinearAlgebra", "Logging", "LoggingExtras", "MuladdMacro", "Parameters", "Printf", "ProgressLogging", "RecipesBase", "RecursiveArrayTools", "RecursiveFactorization", "Requires", "Roots", "SparseArrays", "StaticArrays", "Statistics", "SuiteSparse", "TableTraits", "TerminalLoggers", "TreeViews", "ZygoteRules"]
git-tree-sha1 = "ae65fac7d9933f3d039c0296b5d41bf8c3d8f4ea"
git-tree-sha1 = "eb3cfba5228aceca0024d9a15086d82ef8330d8e"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
version = "6.38.4"
version = "6.39.1"

[[DiffEqCallbacks]]
deps = ["DataStructures", "DiffEqBase", "ForwardDiff", "LinearAlgebra", "NLsolve", "OrdinaryDiffEq", "RecipesBase", "RecursiveArrayTools", "StaticArrays"]
Expand All @@ -230,9 +230,9 @@ version = "2.13.3"

[[DiffEqFinancial]]
deps = ["DiffEqBase", "DiffEqNoiseProcess", "LinearAlgebra", "Markdown", "RandomNumbers"]
git-tree-sha1 = "f0c6f2b0b9fa463a90da06142e45ecf8e0b70bac"
git-tree-sha1 = "db08e0def560f204167c58fd0637298e13f58f73"
uuid = "5a0ffddc-d203-54b0-88ba-2c03c0fc2e67"
version = "2.3.0"
version = "2.4.0"

[[DiffEqJump]]
deps = ["ArrayInterface", "Compat", "DataStructures", "DiffEqBase", "FunctionWrappers", "LinearAlgebra", "Parameters", "PoissonRandom", "Random", "RandomNumbers", "RecursiveArrayTools", "StaticArrays", "Statistics", "TreeViews"]
Expand All @@ -242,9 +242,9 @@ version = "6.9.2"

[[DiffEqNoiseProcess]]
deps = ["DataStructures", "DiffEqBase", "Distributions", "LinearAlgebra", "PoissonRandom", "Random", "RandomNumbers", "RecipesBase", "RecursiveArrayTools", "Requires", "ResettableStacks", "StaticArrays", "Statistics"]
git-tree-sha1 = "fc9ba5c47246d1e6c15ae36ce9f5e67b6ffc06b7"
git-tree-sha1 = "474bba439ce886baab756744c54436d7628ef05e"
uuid = "77a26b50-5914-5dd7-bc55-306e6241c503"
version = "4.2.0"
version = "4.3.0"

[[DiffEqPhysics]]
deps = ["DiffEqBase", "DiffEqCallbacks", "ForwardDiff", "LinearAlgebra", "Printf", "Random", "RecipesBase", "RecursiveArrayTools", "Reexport", "StaticArrays"]
Expand Down Expand Up @@ -317,10 +317,10 @@ uuid = "2904ab23-551e-5aed-883f-487f97af5226"
version = "0.2.1"

[[ExponentialUtilities]]
deps = ["LinearAlgebra", "Printf", "SparseArrays"]
git-tree-sha1 = "1672dedeacaab85345fd359ad56dde8fb5d48a45"
deps = ["LinearAlgebra", "Printf", "Requires", "SparseArrays"]
git-tree-sha1 = "91f7498b66205431fe3e35833cda97a22b1ab6a5"
uuid = "d4d017d3-3776-5f7e-afef-a10c40355c18"
version = "1.6.0"
version = "1.7.0"

[[FastGaussQuadrature]]
deps = ["LinearAlgebra", "SpecialFunctions"]
Expand All @@ -330,9 +330,9 @@ version = "0.4.2"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "44f561e293987ffc84272cd3d2b14b0b93123d63"
git-tree-sha1 = "bf726ba7ce99e00d10bf63c031285fb9ab3676ae"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.8.10"
version = "0.8.11"

[[FiniteDiff]]
deps = ["ArrayInterface", "LinearAlgebra", "Requires", "SparseArrays", "StaticArrays"]
Expand Down Expand Up @@ -453,9 +453,9 @@ version = "0.2.0"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "72fc0a39d5899091ff2d4cdaa64cb5e4862cf813"
git-tree-sha1 = "d9c6e1efcaa6c2fcd043da812a62b3e489a109a3"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "1.5.2"
version = "1.7.0"

[[LaTeXStrings]]
git-tree-sha1 = "de44b395389b84fd681394d4e8d39ef14e3a2ea8"
Expand All @@ -481,7 +481,6 @@ uuid = "1d6d02ad-be62-4b6b-8a6d-2f90e265016e"
version = "0.1.2"

[[LibGit2]]
deps = ["Printf"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"

[[Libdl]]
Expand Down Expand Up @@ -513,9 +512,9 @@ version = "0.4.1"

[[LoopVectorization]]
deps = ["DocStringExtensions", "LinearAlgebra", "OffsetArrays", "SIMDPirates", "SLEEFPirates", "UnPack", "VectorizationBase"]
git-tree-sha1 = "59f7e9fddaae12967a0c0903aff2d06a8813e2b1"
git-tree-sha1 = "f49302d088dadda9dad58e65883ce24413b8c1f4"
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
version = "0.8.5"
version = "0.8.7"

[[METIS_jll]]
deps = ["Libdl", "Pkg"]
Expand Down Expand Up @@ -546,9 +545,9 @@ version = "1.0.2"

[[MbedTLS_jll]]
deps = ["Libdl", "Pkg"]
git-tree-sha1 = "c83f5a1d038f034ad0549f9ee4d5fac3fb429e33"
git-tree-sha1 = "f85473aeb7a2561a5c58c06c4868971ebe2bcbff"
uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
version = "2.16.0+2"
version = "2.16.6+0"

[[Missings]]
deps = ["DataAPI"]
Expand Down Expand Up @@ -653,12 +652,12 @@ version = "0.12.1"

[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "eb3e09940c0d7ae01b01d9291ebad7b081c844d3"
git-tree-sha1 = "20ef902ea02f7000756a4bc19f7b9c24867c6211"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "1.0.5"
version = "1.0.6"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[PoissonRandom]]
Expand Down Expand Up @@ -739,15 +738,15 @@ version = "1.0.1"

[[RecursiveArrayTools]]
deps = ["ArrayInterface", "LinearAlgebra", "RecipesBase", "Requires", "StaticArrays", "Statistics", "ZygoteRules"]
git-tree-sha1 = "96e71928efa701fa5a6df0f88b51f05ceed70f2c"
git-tree-sha1 = "0ffe36b65f0fc4967a42a673c1a9ffa65724dee6"
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
version = "2.4.4"
version = "2.5.0"

[[RecursiveFactorization]]
deps = ["LinearAlgebra", "LoopVectorization"]
git-tree-sha1 = "09217cb106dd826de9960986207175b52e3035f2"
deps = ["LinearAlgebra", "LoopVectorization", "VectorizationBase"]
git-tree-sha1 = "04bc629fc40d612e1a048c61c3fcbbe1adc3b641"
uuid = "f2c3362d-daeb-58d1-803e-2bc74f2840b4"
version = "0.1.2"
version = "0.1.3"

[[Reexport]]
deps = ["Pkg"]
Expand Down Expand Up @@ -790,9 +789,9 @@ uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"

[[SIMDPirates]]
deps = ["VectorizationBase"]
git-tree-sha1 = "74bf6ed250c21651955bdb36b2b12320374c49ae"
git-tree-sha1 = "18dca6ff298fdde2d5d837f8aaba6d54302ebee3"
uuid = "21efa798-c60a-11e8-04d3-e1a92915a26a"
version = "0.8.7"
version = "0.8.10"

[[SLEEFPirates]]
deps = ["Libdl", "SIMDPirates", "VectorizationBase"]
Expand Down Expand Up @@ -846,9 +845,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SparseDiffTools]]
deps = ["Adapt", "ArrayInterface", "Compat", "DataStructures", "FiniteDiff", "ForwardDiff", "LightGraphs", "LinearAlgebra", "Requires", "SparseArrays", "VertexSafeGraphs"]
git-tree-sha1 = "bfe68e0d914952932594b3c838f08463b0841037"
git-tree-sha1 = "567fd5758c8271b81cb6497f1bddf1a2d0dd09af"
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
version = "1.8.0"
version = "1.9.0"

[[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl"]
Expand Down Expand Up @@ -966,15 +965,15 @@ uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[[Unitful]]
deps = ["ConstructionBase", "LinearAlgebra", "Random"]
git-tree-sha1 = "3714b55de06b11b2aa788b8643d6e91f13648be5"
git-tree-sha1 = "a061dada333813818aa7454f93c63a5cab6ea981"
uuid = "1986cc42-f94f-5a68-af5c-568840ba703d"
version = "1.2.1"
version = "1.3.0"

[[VectorizationBase]]
deps = ["CpuId", "LLVM", "Libdl", "LinearAlgebra"]
git-tree-sha1 = "bcadc352d9c81b0ef9ceebe822d30128b779f56b"
git-tree-sha1 = "ed02d6b61057bb6ddf7e8b1dccfec907cc064b36"
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
version = "0.12.8"
version = "0.12.13"

[[VersionParsing]]
git-tree-sha1 = "80229be1f670524750d905f8fc8148e5a8c4537f"
Expand All @@ -989,9 +988,9 @@ version = "0.1.2"

[[Zygote]]
deps = ["AbstractFFTs", "ArrayLayouts", "ChainRules", "FillArrays", "ForwardDiff", "Future", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "Random", "Requires", "Statistics", "ZygoteRules"]
git-tree-sha1 = "6d0f78976db6dbea9a36865efe068e6e2a5db6ed"
git-tree-sha1 = "2e2c82549fb0414df10469082fd001e2ede8547c"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.4.21"
version = "0.4.22"

[[ZygoteRules]]
deps = ["MacroTools"]
Expand Down
60 changes: 55 additions & 5 deletions src/EKI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export EKIObj
export construct_initial_ensemble
export compute_error
export update_ensemble!

export find_eki_step

"""
EKIObj{FT<:AbstractFloat, IT<:Int}
Expand All @@ -36,6 +36,8 @@ struct EKIObj{FT<:AbstractFloat, IT<:Int}
g::Vector{Array{FT, 2}}
"vector of errors"
err::Vector{FT}
"vector of timesteps used in each EKI iteration"
Δt::Vector{FT}
end

# outer constructors
Expand All @@ -53,9 +55,11 @@ function EKIObj(parameters::Array{FT, 2},
# observations
g = Vector{FT}[]
# error store
err = []
err = FT[]
# timestep store
Δt = FT[]

EKIObj{FT,IT}(u, parameter_names, t_mean, t_cov, N_ens, g, err)
EKIObj{FT,IT}(u, parameter_names, t_mean, t_cov, N_ens, g, err, Δt)
end


Expand Down Expand Up @@ -91,9 +95,10 @@ function compute_error(eki)
end


function update_ensemble!(eki::EKIObj{FT}, g) where {FT}
function update_ensemble!(eki::EKIObj{FT}, g; cov_threshold::FT=0.01, Δt_new = nothing) where {FT}
# u: N_ens x N_params
u = eki.u[end]
cov_init = cov(eki.u[end], dims=1)

u_bar = fill(FT(0), size(u)[2])
# g: N_ens x N_data
Expand All @@ -102,6 +107,14 @@ function update_ensemble!(eki::EKIObj{FT}, g) where {FT}
cov_ug = fill(FT(0), size(u)[2], size(g)[2])
cov_gg = fill(FT(0), size(g)[2], size(g)[2])

if !isnothing(Δt_new)
push!(eki.Δt, Δt_new)
elseif isnothing(Δt_new) && isempty(eki.Δt)
push!(eki.Δt, FT(1))
else
push!(eki.Δt, eki.Δt[end])
end

# update means/covs with new param/observation pairs u, g
for j = 1:eki.N_ens

Expand All @@ -123,7 +136,7 @@ function update_ensemble!(eki::EKIObj{FT}, g) where {FT}
cov_gg = cov_gg / eki.N_ens - g_bar * g_bar'

# update the parameters (with additive noise too)
noise = rand(MvNormal(zeros(size(g)[2]), eki.cov), eki.N_ens) # N_data * N_ens
noise = rand(MvNormal(zeros(size(g)[2]), eki.cov/eki.Δt[end]), eki.N_ens) # N_data * N_ens
y = (eki.g_t .+ noise)' # add g_t (N_data) to each column of noise (N_data x N_ens), then transp. into N_ens x N_data
tmp = (cov_gg + eki.cov) \ (y - g)' # N_data x N_data \ [N_ens x N_data - N_ens x N_data]' --> tmp is N_data x N_ens
u += (cov_ug * tmp)' # N_ens x N_params
Expand All @@ -134,6 +147,43 @@ function update_ensemble!(eki::EKIObj{FT}, g) where {FT}

compute_error(eki)

# Check convergence
cov_new = cov(eki.u[end], dims=1)
cov_ratio = det(cov_new)/det(cov_init)
if cov_ratio < cov_threshold
@warn string("New ensemble covariance determinant is less than ",cov_threshold," times its former value.
Consider reducing the EKI time step.")
end
end


"""
find_eki_step(eki::EKIObj{FT}, g::Array{FT, 2}; cov_threshold::FT=0.01) where {FT}
Find largest step for the EKI solver that leads to a reduction of the determinant of the sample
covariance matrix no greater than cov_threshold.
"""
function find_eki_step(eki::EKIObj{FT}, g::Array{FT, 2}; cov_threshold::FT=0.01) where {FT}
accept_step = false
if !isempty(eki.Δt)
Δt = deepcopy(eki.Δt[end])
else
Δt = FT(1)
end
# u: N_ens x N_params
cov_init = cov(eki.u[end], dims=1)
while accept_step == false
eki_copy = deepcopy(eki)
update_ensemble!(eki_copy, g, Δt_new=Δt)
cov_new = cov(eki_copy.u[end], dims=1)
if det(cov_new) > cov_threshold*det(cov_init)
accept_step = true
else
Δt = Δt/2
end
end

return Δt

end

end # module EKI
Loading