Skip to content

Commit

Permalink
allow Float32 and a mix of Float32/Float64 in fitvertlen
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander-Barth committed Jan 17, 2024
1 parent d0ba9aa commit 589457e
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/DIVAnd_aexerr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ function DIVAnd_aexerr(mask, pmn, xi, x, f, len, epsilon2; rng=Random.GLOBAL_RNG
#@show sum(restrictedlist[1:size(f)[1]]),sum(restrictedlist[size(f)[1]+1:end])
# #############################################################
Batdatapoints = DIVAnd_erroratdatapoints(s1; restrictedlist = restrictedlist)
epsilonforB = ones(Float64, size(ffake)[1]) .* epsilon2fake
epsilonforB = ones(size(ffake)[1]) .* epsilon2fake
epsilonforB[restrictedlist] .= 1.0 / 100.0
Batdatapoints[.!restrictedlist] .= 1.0
Bmean = mean(Batdatapoints[restrictedlist])
Expand Down
4 changes: 2 additions & 2 deletions src/DIVAnd_constr_fractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ function DIVAnd_constr_fractions(s, epsfractions)

A = spzeros(nwithoutfractions, s.sv.n)
l = size(A, 1)
#yo = ones(Float64,l)
yo = zeros(Float64, l)
#yo = ones(l)
yo = zeros(l)
R = Diagonal(epsfractions .* ones(l))


Expand Down
4 changes: 2 additions & 2 deletions src/DIVAnd_datainboundingbox.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ function DIVAnd_datainboundingbox(xi, x, f; Rmatrix = ())

n = length(xi)

maxxi = zeros(Float64, n)
minxi = zeros(Float64, n)
maxxi = zeros(n)
minxi = zeros(n)

sel = trues(size(x[1], 1))

Expand Down
8 changes: 4 additions & 4 deletions src/DIVAnd_diagapp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ function DIVAnd_diagapp(
# Allocate arrays once
eij = zeros(Int, size(pmn[1]))
#@show size(eij)
diagerror = zeros(Float64, size(pmn[1])) .* NaN
tutuu = zeros(Float64, size(pmn[1]))
diagerror = zeros(size(pmn[1])) .* NaN
tutuu = zeros(size(pmn[1]))
tutu = statevector_pack(sv, (eij,))
#@show size(tutu)
z = zeros(Float64, size(P)[1])
zs = zeros(Float64, size(P)[1])
z = zeros(size(P)[1])
zs = zeros(size(P)[1])

if Binv
y = zeros(Float64, size(P)[1])
Expand Down
4 changes: 2 additions & 2 deletions src/diva.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ function diva3d(
lonr, latr, depthr, TS = if length(xi) == 4
xi
else
(xi[1], xi[2], Float64[0.0], xi[3])
(xi[1], xi[2], [0.0], xi[3])
end

checkdepth(depthr)
Expand All @@ -241,7 +241,7 @@ function diva3d(
lon, lat, depth, time = if length(xi) == 4
x
else
(x[1], x[2], Float64[0.0], x[3])
(x[1], x[2], [0.0], x[3])
end

# anamorphosis transform
Expand Down
25 changes: 12 additions & 13 deletions src/fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,11 +311,11 @@ function Base.iterate(iter::RandomCoupels, state = (0, copy(iter.rng)))
return ((i, j), (count + 1, rng))
end

mutable struct VertRandomCoupels{TRNG <: AbstractRNG}
zlevel::Float64 # depth in meters
mutable struct VertRandomCoupels{TRNG <: AbstractRNG,Tz,Tx,Ts}
zlevel::Tz # depth in meters
zindex::Vector{Int}
x::NTuple{3,Vector{Float64}}
searchxy::Float64 # in meters
x::NTuple{3,Vector{Tx}}
searchxy::Ts # in meters
maxntries::Int
count::Int
rng::TRNG
Expand Down Expand Up @@ -773,8 +773,8 @@ function fithorzlen(
x,
value::Vector{T},
z;
tolrel::T = 1e-4,
smoothz::T = 100.0,
tolrel = T(1e-4),
smoothz = T(100.0),
smoothk = 3,
searchz = 50.0,
progress = (iter, var, len, fitness) -> nothing,
Expand Down Expand Up @@ -876,16 +876,16 @@ function fitvertlen(
x,
value::Vector{T},
z;
smoothz::T = 100.0,
smoothk::T = 3.0,
searchz = 10.0,
searchxy::T = 1_000.0, # meters
maxntries::Int = 10000,
smoothz = T(100.0),
smoothk = T(3.0),
searchz = T(10.0),
searchxy = T(1_000.0), # meters
maxntries = 10000,
maxnsamp = 50,
progress = (iter, var, len, fitness) -> nothing,
distfun = (xi, xj) -> sqrt(sum(abs2, xi - xj)),
limitfun = (z, len) -> len,
epsilon2 = ones(size(value)),
epsilon2 = ones(T,size(value)),
min_rqual = 0.5,
rng = Random.GLOBAL_RNG,
) where {T}
Expand Down Expand Up @@ -962,5 +962,4 @@ function fitvertlen(
end

return lenoptf, Dict(:var0 => var0opt, :len => lenopt, :fitinfos => fitinfos)

end
13 changes: 10 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -607,16 +607,23 @@ function random(
)

n = size(s.iB, 1)::Int
z = randn(rng,n, Nens)
z = randn(rng, n, Nens)

F = cholesky(s.iB::SparseMatrixCSC{T,Int})
F_UP = F.UP

# P pivoting matrix
# P pivoting matrix, L lower triangular matrix
# s.iB == P'*L*L'*P
# F[:UP] == L'*P
# F.UP == L'*P

ff = F_UP \ z

# covariance of ff
# ff * ff' = inv(F_UP) * z * z' * inv(F_UP)'
# = inv(F_UP) * inv(F_UP)'
# = inv(F_UP' * F_UP')
# = inv(s.iB)

field = DIVAnd.unpackens(s.sv, ff)[1]::Array{T,N + 1}
return field
end
Expand Down
29 changes: 26 additions & 3 deletions test/test_fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ mask, (pm, pn, po), (xi, yi, zi) =
lenx = leny = lenz = 0.2
Nens = 1
Random.seed!(rng,12345)
field = DIVAnd.random(mask, (pm, pn, po), (lenx, leny, lenz), Nens, rng = rng)
# pivoting is not stable
#field = DIVAnd.random(mask, (pm, pn, po), (lenx, leny, lenz), Nens, rng = rng)
field = @. sin(xi/lenx) * cos(yi/leny) * cos(zi/lenz)


z = [0.3, 0.5, 0.7]
Expand All @@ -158,8 +160,7 @@ x = (xi[s], yi[s], zi[s])
v = field[s]
epsilon2 = ones(length(x[3])) + x[3][:] .^ 2

fitlenz,
dbinfo = @test_logs (:info, r".*at*") match_mode = :any DIVAnd.fitvertlen(
fitlenz, dbinfo = @test_logs (:info, r".*at*") match_mode = :any DIVAnd.fitvertlen(
x,
v,
z;
Expand All @@ -168,6 +169,28 @@ dbinfo = @test_logs (:info, r".*at*") match_mode = :any DIVAnd.fitvertlen(
);
@test median(fitlenz) lenz rtol = 0.5

# mix Float32 and Float64
T = Float32
fitlenz,dbinfo = @test_logs (:info, r".*at*") match_mode = :any DIVAnd.fitvertlen(
map(v -> T.(v),x),
v,
T.(z);
epsilon2 = T.(epsilon2),
rng = rng,
);
@test median(fitlenz) lenz rtol = 0.5

# just Float32
fitlenz,dbinfo = @test_logs (:info, r".*at*") match_mode = :any DIVAnd.fitvertlen(
map(v -> T.(v),x),
T.(v),
T.(z);
epsilon2 = T.(epsilon2),
rng = rng,
);

@test median(fitlenz) lenz rtol = 0.5


# RandomCoupels iterators

Expand Down

0 comments on commit 589457e

Please sign in to comment.