Feature/multilevel #378

Draft · wants to merge 3 commits into base: main
34 changes: 20 additions & 14 deletions src/EnsembleKalmanInversion.jl
@@ -19,14 +19,13 @@ Provides a failsafe update that
- updates the successful ensemble according to the EKI update,
- updates the failed ensemble by sampling from the updated successful ensemble.
"""
function FailureHandler(process::Inversion, method::SampleSuccGauss)
function FailureHandler(::Inversion, ::SampleSuccGauss)
function failsafe_update(ekp, u, g, y, obs_noise_cov, failed_ens)
successful_ens = filter(x -> !(x in failed_ens), collect(1:size(g, 2)))
n_failed = length(failed_ens)
u[:, successful_ens] =
eki_update(ekp, u[:, successful_ens], g[:, successful_ens], y[:, successful_ens], obs_noise_cov)
(failed_ens, sample_transform, sample_dim) = get_correlations(ekp.level_scheduler, failed_ens)
u = eki_update(ekp, u, g, y, obs_noise_cov; ignored_indices = failed_ens)
if !isempty(failed_ens)
u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed)
new_samples = sample_empirical_gaussian(ekp.rng, ekp, u, sample_dim; ignored_indices = failed_ens)
u[:, failed_ens] = new_samples[:, sample_transform]
end
return u
end
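The handler above replaces failed columns by resampling from the empirical Gaussian of the surviving ensemble, with `get_correlations` supplying the multilevel bookkeeping so that coupled fine/coarse columns receive the same draw. A minimal single-level sketch of the underlying idea, using a hypothetical helper rather than the package's `sample_empirical_gaussian`:

```julia
using Statistics, LinearAlgebra, Random

# Single-level sketch of the SampleSuccGauss idea (hypothetical helper, not the
# package function): failed columns are replaced with draws from the empirical
# Gaussian fitted to the successful columns.
function resample_failed(rng, u, failed_ens)
    ok = setdiff(1:size(u, 2), failed_ens)
    μ = mean(u[:, ok], dims = 2)
    Σ = Symmetric(cov(u[:, ok], dims = 2)) + 1e-8 * I  # jitter guards against a singular covariance
    u[:, failed_ens] .= μ .+ cholesky(Σ).L * randn(rng, size(u, 1), length(failed_ens))
    return u
end

resample_failed(Random.MersenneTwister(42), randn(3, 10), [2, 7])
```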
@@ -39,7 +38,8 @@ end
u::AbstractMatrix{FT},
g::AbstractMatrix{FT},
y::AbstractMatrix{FT},
obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}},
obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}};
ignored_indices = [],
) where {FT <: Real, IT, CT <: Real}

Returns the updated parameter vectors given their current values and
@@ -53,14 +53,17 @@ function eki_update(
u::AbstractMatrix{FT},
g::AbstractMatrix{FT},
y::AbstractMatrix{FT},
obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}},
obs_noise_cov::Union{AbstractMatrix{CT}, UniformScaling{CT}};
ignored_indices = [],
) where {FT <: Real, IT, CT <: Real}

cov_est = cov([u; g], dims = 2, corrected = false) # [(N_par + N_obs)×(N_par + N_obs)]
cov_est = compute_cov(ekp, [u; g]; corrected = false, ignored_indices) # [(N_par + N_obs)×(N_par + N_obs)]

# Localization
cov_localized = ekp.localizer.localize(cov_est)

cov_uu, cov_ug, cov_gg = get_cov_blocks(cov_localized, size(u, 1))
cov_gg = posdef(cov_gg)

# N_obs × N_obs \ [N_obs × N_ens]
# --> tmp is [N_obs × N_ens]
@@ -108,9 +111,10 @@ function update_ensemble!(
# g: N_obs × N_ens
u = get_u_final(ekp)
N_obs = size(g, 1)
cov_init = cov(u, dims = 2)

if ekp.verbose
cov_init = compute_cov(ekp, u; corrected = true)

if get_N_iterations(ekp) == 0
@info "Iteration 0 (prior)"
@info "Covariance trace: $(tr(cov_init))"
@@ -123,7 +127,9 @@

# Scale noise using Δt
scaled_obs_noise_cov = ekp.obs_noise_cov / ekp.Δt[end]
noise = sqrt(scaled_obs_noise_cov) * rand(ekp.rng, MvNormal(zeros(N_obs), I), ekp.N_ens)
independent_noise_dim = get_N_indep(ekp.level_scheduler)
noise = sqrt(scaled_obs_noise_cov) * rand(ekp.rng, MvNormal(zeros(N_obs), I), independent_noise_dim)
noise = transform_noise(ekp.level_scheduler, noise)

# Add obs_mean (N_obs) to each column of noise (N_obs × N_ens) if
# G is deterministic
@@ -143,10 +149,10 @@
# Store error
compute_error!(ekp)

# Diagnostics
cov_new = cov(u, dims = 2)

if ekp.verbose
# Diagnostics
cov_new = compute_cov(ekp, u; corrected = true)

@info "Covariance-weighted error: $(get_error(ekp)[end])\nCovariance trace: $(tr(cov_new))\nCovariance trace ratio (current/previous): $(tr(cov_new)/tr(cov_init))"
end

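The refactor routes the ensemble statistics through `compute_cov` (so the level scheduler can weight each level group) and regularizes `cov_gg` with `posdef`, but the Kalman step itself is unchanged. A minimal, self-contained single-level sketch of that step, with no localization or level scheduler (illustrative names, not the package API):

```julia
using Statistics, LinearAlgebra

# Minimal single-level sketch of the EKI step (illustrative, not the package's eki_update):
# u is N_par × N_ens, g is N_obs × N_ens, y holds one perturbed observation per column,
# and Γ is the observation-noise covariance already scaled by 1/Δt, as in update_ensemble! above.
function eki_step_sketch(u, g, y, Γ)
    N_par = size(u, 1)
    C = cov([u; g], dims = 2, corrected = false)  # joint (parameter, output) covariance
    cov_ug = C[1:N_par, N_par+1:end]              # N_par × N_obs cross-covariance
    cov_gg = C[N_par+1:end, N_par+1:end]          # N_obs × N_obs output covariance
    tmp = (cov_gg + Γ) \ (y - g)                  # N_obs × N_ens
    return u + cov_ug * tmp                       # updated ensemble
end
```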
37 changes: 33 additions & 4 deletions src/EnsembleKalmanProcess.jl
@@ -32,12 +32,14 @@ abstract type FailureHandlingMethod end
# Accelerators
abstract type Accelerator end

# Level schedulers
abstract type LevelScheduler end


"Failure handling method that ignores forward model failures"
struct IgnoreFailures <: FailureHandlingMethod end

""""
"""
SampleSuccGauss <: FailureHandlingMethod

Failure handling method that substitutes failed ensemble members by new samples from
@@ -130,6 +132,8 @@ struct EnsembleKalmanProcess{
scheduler::LRS
"accelerator object that informs EK update steps, stores additional state variables as needed"
accelerator::ACC
""
level_scheduler::LevelScheduler
"stored vector of timesteps used in each EK iteration"
Δt::Vector{FT}
"the particular EK process (`Inversion` or `Sampler` or `Unscented` or `TransformInversion` or `SparseInversion`)"
@@ -151,6 +155,7 @@ function EnsembleKalmanProcess(
process::P;
scheduler::Union{Nothing, LRS} = nothing,
accelerator::Union{Nothing, ACC} = nothing,
level_scheduler::Union{Nothing, LS} = nothing,
Δt = nothing,
rng::AbstractRNG = Random.GLOBAL_RNG,
failure_handler_method::FM = IgnoreFailures(),
@@ -160,6 +165,7 @@
FT <: AbstractFloat,
LRS <: LearningRateScheduler,
ACC <: Accelerator,
LS <: LevelScheduler,
P <: Process,
FM <: FailureHandlingMethod,
LM <: LocalizationMethod,
@@ -221,6 +227,17 @@
end
end

# set up level scheduler
ls = if isnothing(level_scheduler)
SingleLevelScheduler(N_ens, LevelInfinity())
else
if !(typeof(process) <: Inversion)
throw(ArgumentError("Only `Inversion` (EKI) can currently be used with multilevel Monte Carlo."))
end

level_scheduler
end

# failure handler
fh = FailureHandler(process, failure_handler_method)
# localizer
@@ -239,6 +256,7 @@
err,
lrs,
acc,
ls,
Δt,
process,
rng,
@@ -503,20 +521,24 @@ get_error(ekp::EnsembleKalmanProcess) = ekp.err
"""
sample_empirical_gaussian(
rng::AbstractRNG,
ekp::EnsembleKalmanProcess,
u::AbstractMatrix{FT},
n::IT;
inflation::Union{FT, Nothing} = nothing,
ignored_indices = [],
) where {FT <: Real, IT <: Int}

Returns `n` samples from an empirical Gaussian based on point estimates `u`, adding inflation if the covariance is singular.
"""
function sample_empirical_gaussian(
rng::AbstractRNG,
ekp::EnsembleKalmanProcess,
u::AbstractMatrix{FT},
n::IT;
inflation::Union{FT, Nothing} = nothing,
ignored_indices = [],
) where {FT <: Real, IT <: Int}
cov_u_new = Symmetric(cov(u, dims = 2))
cov_u_new = Symmetric(posdef(compute_cov(ekp, u; corrected = true, ignored_indices)))
if !isposdef(cov_u_new)
@warn string("Sample covariance matrix over ensemble is singular.", "\n Applying variance inflation.")
if isnothing(inflation)
@@ -525,16 +547,18 @@ function sample_empirical_gaussian(
end
cov_u_new = cov_u_new + inflation * I
end
mean_u_new = mean(u, dims = 2)
mean_u_new = compute_mean(ekp, u; ignored_indices)
return mean_u_new .+ sqrt(cov_u_new) * rand(rng, MvNormal(zeros(length(mean_u_new[:])), I), n)
end

function sample_empirical_gaussian(
ekp::EnsembleKalmanProcess,
u::AbstractMatrix{FT},
n::IT;
inflation::Union{FT, Nothing} = nothing,
ignored_indices = [],
) where {FT <: Real, IT <: Int}
return sample_empirical_gaussian(Random.GLOBAL_RNG, u, n, inflation = inflation)
return sample_empirical_gaussian(Random.GLOBAL_RNG, ekp, u, n; inflation, ignored_indices)
end
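The resampling helper now takes the `ekp` so it can use the level-aware statistics, and `ignored_indices` lets the failure handlers exclude failed columns from the fit. A hedged call sketch (assumes an already-constructed `ekp`; the indices are illustrative):

```julia
# Replace two columns with draws from the empirical Gaussian fitted to the
# remaining columns of the current ensemble (indices 4 and 6 are illustrative).
u = get_u_final(ekp)
new_cols = sample_empirical_gaussian(ekp.rng, ekp, u, 2; ignored_indices = [4, 6])
```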


@@ -691,6 +715,8 @@ function update_ensemble!(
end


include("SampleStatistics.jl")

## include the different types of Processes and their exports:

# struct Inversion
@@ -719,3 +745,6 @@ include("UnscentedKalmanInversion.jl")

# struct Accelerator
include("Accelerators.jl")

# Level schedulers
include("Multilevel.jl")
2 changes: 1 addition & 1 deletion src/EnsembleTransformKalmanInversion.jl
@@ -32,7 +32,7 @@ function FailureHandler(process::TransformInversion, method::SampleSuccGauss)
n_failed = length(failed_ens)
u[:, successful_ens] = etki_update(ekp, u[:, successful_ens], g[:, successful_ens], y, obs_noise_cov)
if !isempty(failed_ens)
u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed)
u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, ekp, u, n_failed; ignored_indices = failed_ens)
end
return u
end
96 changes: 96 additions & 0 deletions src/Multilevel.jl
@@ -0,0 +1,96 @@
export SingleLevelScheduler, MultilevelScheduler, get_N_ens, get_N_indep, levels, transform_noise

struct LevelInfinity end

const SingleLevelType{IT} = Union{IT, LevelInfinity}

struct MultilevelScheduler{IT <: Integer} <: LevelScheduler
Js::Dict{IT, IT}
N_indep::IT
N_ens::IT

function MultilevelScheduler(Js::Dict{IT, IT}) where {IT <: Integer}
N_indep = sum(values(Js))
N_ens = sum(J * (level == 0 ? 1 : 2) for (level, J) in Js)

new{IT}(Js, N_indep, N_ens)
end
end

struct SingleLevelScheduler{IT <: Integer} <: LevelScheduler
N_ens::IT
level::SingleLevelType{IT}

function SingleLevelScheduler(N_ens::IT, level::SingleLevelType{IT} = LevelInfinity()) where {IT <: Integer}
new{IT}(N_ens, level)
end
end


get_N_ens(ms::MultilevelScheduler) = ms.N_ens

get_N_indep(ms::MultilevelScheduler) = ms.N_indep

levels(ms::MultilevelScheduler) = begin
vcat(
fill(0, ms.Js[0]),
(fill(l, ms.Js[l]) for l in sort(collect(keys(ms.Js))) if l != 0)...,
(fill(l - 1, ms.Js[l]) for l in sort(collect(keys(ms.Js))) if l != 0)...,
)
end

statistic_groups(ms::MultilevelScheduler) = begin
groups = []

offset = ms.N_indep - ms.Js[0]

index = 0
for level in sort(collect(keys(ms.Js)))
J = ms.Js[level]
push!(groups, (index+1:index+J, 1))
if level > 0
push!(groups, (index+offset+1:index+offset+J, -1))
end

index += J
end

groups
end

transform_noise(ms::MultilevelScheduler, noise::AbstractMatrix{FT}) where {FT <: Real} = begin
@assert size(noise, 2) == ms.N_indep

noise[:, vcat(1:ms.N_indep, ms.Js[0]+1:ms.N_indep)]
end

get_correlations(ms::MultilevelScheduler, indices::AbstractVector{IT}) where {IT <: Integer} = begin
num_uncorrelated = 0
new_indices = map(indices) do i
if i <= ms.Js[0]
num_uncorrelated += 1
i # There is no correlated index
elseif i <= ms.N_indep
i + (ms.N_indep - ms.Js[0])
else
i - (ms.N_indep - ms.Js[0])
end
end
all_indices = sort!(unique!(vcat(indices, new_indices)))
num_correlated = (length(all_indices) - num_uncorrelated) ÷ 2
noise_dim = num_correlated + num_uncorrelated
all_indices, vcat(1:num_uncorrelated, num_uncorrelated+1:noise_dim, num_uncorrelated+1:noise_dim), noise_dim
end


get_N_ens(sls::SingleLevelScheduler) = sls.N_ens

get_N_indep(sls::SingleLevelScheduler) = sls.N_ens

levels(sls::SingleLevelScheduler) = fill(sls.level, sls.N_ens)

statistic_groups(sls::SingleLevelScheduler) = [(1:sls.N_ens, 1)]

transform_noise(::SingleLevelScheduler, noise::AbstractMatrix{FT}) where {FT <: Real} = noise

get_correlations(::SingleLevelScheduler, indices::AbstractVector{IT}) where {IT <: Integer} = (indices, 1:length(indices), length(indices))
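The scheduler wraps a `Dict` mapping each level to the number of independent samples on that level; every level above 0 contributes coupled fine/coarse column pairs, which is why `N_ens` exceeds `N_indep`. A small worked example (the commented values are what the functions above return for this `Dict`; `statistic_groups` is internal, not exported):

```julia
ms = MultilevelScheduler(Dict(0 => 3, 1 => 2))

get_N_indep(ms)      # 5  (3 level-0 members + 2 independent level-1 pairs)
get_N_ens(ms)        # 7  (each level-1 pair occupies a fine and a coarse column)
levels(ms)           # [0, 0, 0, 1, 1, 0, 0]; columns 4-5 are fine, 6-7 their coarse partners

# Telescoping weights consumed by compute_mean / compute_cov:
statistic_groups(ms) # [(1:3, 1), (4:5, 1), (6:7, -1)]

# Noise drawn once per independent member is duplicated onto the correlated columns:
noise = randn(2, get_N_indep(ms))
transform_noise(ms, noise)  # 2 × 7, with columns 6-7 equal to columns 4-5
```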
24 changes: 24 additions & 0 deletions src/SampleStatistics.jl
@@ -0,0 +1,24 @@
# included in EnsembleKalmanProcess.jl

export compute_mean, compute_cov

function posdef(mat)
S, V = eigen(mat)
V = V[:, (S .> 0)]
S = S[S .> 0]
V * diagm(S) * V'
end

function compute_mean(ekp::EnsembleKalmanProcess, x; ignored_indices = [])
reduce(statistic_groups(ekp.level_scheduler); init = 0) do acc, (indices, multiplier)
indices = setdiff(indices, ignored_indices)
multiplier * mean(x[:, indices]; dims = 2) .+ acc
end
end

function compute_cov(ekp::EnsembleKalmanProcess, x; corrected, ignored_indices = [])
reduce(statistic_groups(ekp.level_scheduler); init = 0) do acc, (indices, multiplier)
indices = setdiff(indices, ignored_indices)
multiplier * cov(x[:, indices]; corrected, dims = 2) .+ acc
end
end
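`compute_mean` and `compute_cov` implement the multilevel telescoping estimator: the level-0 group enters with weight +1, and each finer level adds its fine-column statistic and subtracts its coarse-partner statistic. The snippet below reproduces the mean reduction directly from the groups of the scheduler example above (plain `Statistics` calls, no `ekp` needed) and shows the effect of `posdef`:

```julia
using Statistics, LinearAlgebra

# Telescoping mean over the groups of MultilevelScheduler(Dict(0 => 3, 1 => 2))
groups = [(1:3, 1), (4:5, 1), (6:7, -1)]
x = randn(2, 7)  # stand-in for a parameter (or model-output) ensemble
ml_mean = sum(m * mean(x[:, idx]; dims = 2) for (idx, m) in groups)  # 2 × 1

# posdef projects a symmetric matrix onto its positive-eigenvalue subspace,
# so the result is positive semidefinite (possibly rank-deficient):
A = [1.0 0.0; 0.0 -0.5]
S, V = eigen(A)
V[:, S .> 0] * diagm(S[S .> 0]) * V[:, S .> 0]'  # == [1.0 0.0; 0.0 0.0]
```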
2 changes: 1 addition & 1 deletion src/SparseEnsembleKalmanInversion.jl
@@ -56,7 +56,7 @@ function FailureHandler(process::SparseInversion, method::SampleSuccGauss)
u[:, successful_ens] =
sparse_eki_update(ekp, u[:, successful_ens], g[:, successful_ens], y[:, successful_ens], obs_noise_cov)
if !isempty(failed_ens)
u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, u[:, successful_ens], n_failed)
u[:, failed_ens] = sample_empirical_gaussian(ekp.rng, ekp, u, n_failed; ignored_indices = failed_ens)
end
return u
end
6 changes: 4 additions & 2 deletions test/EnsembleKalmanProcess/runtests.jl
@@ -960,16 +960,18 @@ end
rng = Random.MersenneTwister(rng_seed)

u = rand(10, 4)
ekp = EKP.EnsembleKalmanProcess(u, [1.;], [1.;;], Inversion())
@test_logs (:warn, r"Sample covariance matrix over ensemble is singular.") match_mode = :any sample_empirical_gaussian(
ekp,
u,
2,
)

u2 = rand(rng, 5, 20)
@test all(
isapprox.(
sample_empirical_gaussian(copy(rng), u2, 2),
sample_empirical_gaussian(copy(rng), u2, 2, inflation = 0.0);
sample_empirical_gaussian(copy(rng), ekp, u2, 2),
sample_empirical_gaussian(copy(rng), ekp, u2, 2, inflation = 0.0);
atol = 1e-8,
),
)