Deflation added
jmbeckers committed Jan 5, 2024
1 parent 3e6f1ed commit 4e3ad9c
Showing 3 changed files with 135 additions and 9 deletions.
3 changes: 2 additions & 1 deletion src/DIVAnd_solve.jl
@@ -16,7 +16,7 @@ Input:
Output:
fi: analyzed field
"""
function DIVAnd_solve!(s::DIVAnd_struct{T,Ti,N,OT}, fi0, f0; btrunc = []) where {T,Ti,N,OT}
function DIVAnd_solve!(s::DIVAnd_struct{T,Ti,N,OT}, fi0, f0; btrunc = [],ZDF=nothing) where {T,Ti,N,OT}
# btrunc=[]

H = s.H
@@ -79,6 +79,7 @@ function DIVAnd_solve!(s::DIVAnd_struct{T,Ti,N,OT}, fi0, f0; btrunc = []) where
x0 = fi0,
pc! = s.preconditioner,
progress = s.progress,
ZDF=ZDF
)

if !success
6 changes: 4 additions & 2 deletions src/DIVAndrun.jl
@@ -39,6 +39,7 @@ function DIVAndrun(
coeff_laplacian::Vector{Float64} = ones(ndims(mask)),
coeff_derivative2::Vector{Float64} = zeros(ndims(mask)),
mean_Labs = nothing,
ZDF=nothing,
) where {N,T}

# Inequality constraints via loop around classical analysis (recursive call)
@@ -112,7 +113,8 @@ end
QCMETHOD = QCMETHOD,
coeff_laplacian = coeff_laplacian,
coeff_derivative2 = coeff_derivative2,
mean_Labs = mean_Labs
mean_Labs = mean_Labs,
ZDF=ZDF
)

# Calculate inequality constraints. If satisfied put ineqok true otherwise
@@ -229,7 +231,7 @@ end
fi0_pack = statevector_pack(s.sv, (fi0,))[:, 1]

#@code_warntype DIVAnd_solve!(s,fi0_pack,f0)
fi = DIVAnd_solve!(s, fi0_pack, f0; btrunc = btrunc)::Array{T,N}
fi = DIVAnd_solve!(s, fi0_pack, f0; btrunc = btrunc,ZDF=ZDF)::Array{T,N}

# @info "Done solving"

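The ZDF keyword is passed unchanged from DIVAndrun through DIVAnd_solve! down to conjugategradient, so it only takes effect when the iterative (preconditioned conjugate gradient) solver is used. Judging from testZDF and the commented-out construction in conjugategradient.jl below, ZDF is expected to be a vector of Bool masks that partition the state vector: every element of the state vector must belong to exactly one mask. A minimal, hypothetical calling sketch follows; the 1-D domain, the observations and the inversion = :pcg keyword are illustrative assumptions, not part of this commit.

using DIVAnd

# hypothetical 1-D setup (illustrative only)
mask, (pm,), (xi,) = DIVAnd_rectdom(range(0, stop = 1, length = 100))
xobs = ([0.2, 0.5, 0.8],)
f = sin.(6 .* xobs[1])

# partition the state vector into NMDF deflation masks, mirroring the
# commented-out random construction in conjugategradient.jl; for an all-sea
# mask the state vector has one entry per grid point
NMDF = 10
n = sum(mask)
ZDF = [falses(n) for i = 1:NMDF]
for i = 1:n
    ZDF[mod1(i, NMDF)][i] = true   # each element in exactly one mask (testZDF requirement)
end

fi, s = DIVAndrun(mask, (pm,), (xi,), xobs, f, 0.1, 0.1;
                  inversion = :pcg,  # assumed keyword selecting the iterative solver
                  ZDF = ZDF)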
135 changes: 129 additions & 6 deletions src/conjugategradient.jl
@@ -99,6 +99,7 @@ function conjugategradient(
minit::Int = 0,
pc! = pc_none!,
progress = (iter, x, r, tol2, fun!, b) -> nothing,
ZDF=nothing,
) where {T}

success = false
@@ -124,14 +125,119 @@
Ap = similar(x)
z = similar(x)

# gradient at initial guess
fun!(x, Ap)
r = b - Ap

##### To add deflation, change r into Pr here and keep it stored in r
# see https://link.springer.com/chapter/10.1007/978-3-030-55874-1_45
# TEST FORCE deflation true
if ZDF==nothing
deflation=false
else
deflation=true
end
#NMDF=10
#ZDF = [falses(size(x,1)) for i=1:NMDF]
#for i=1:size(x,1)
# jr=rand(1:NMDF)
# ZDF[jr][i]=true
#end

if deflation
@show typeof(ZDF),size(ZDF)


# Create matrices
function testZDF(ZDF)
isok=true
for i=1:size(ZDF[1],1)
s=0
for j=1:size(ZDF,1)
s=s+ZDF[j][i]
end
if s!==1
@show i,s
isok=false
end
end
return isok
end

if testZDF(ZDF)
# ok to go for deflation
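# EDF will hold the small deflation matrix E = Z'AZ and AZDF the columns A*ZDF[i];
# BDF and CDF are work vectors of length size(ZDF,1) reused by the projectors below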
EDF=randn(Float64,size(ZDF,1),size(ZDF,1))
AZDF=randn(Float64,size(x,1),size(ZDF,1))
BDF=zeros(Float64,size(ZDF,1))
CDF=zeros(Float64,size(ZDF,1))
# Storing AZDF is a significant memory cost; the alternative is to double the A*x calculations
for idf=1:size(ZDF,1)
fun!(float.(ZDF[idf]), Ap)
AZDF[:,idf].=Ap #
for jdf=idf:size(ZDF,1)
EDF[idf,jdf]=sum(AZDF[ZDF[jdf],idf])
#@show EDF[idf,jdf]-ZDF[jdf]'*A*ZDF[idf]
EDF[jdf,idf]=EDF[idf,jdf]
end
end
EDF=cholesky(EDF)

function projectPx!(x)
#ZDFTx
for idf=1:size(ZDF,1)
BDF[idf]= sum(x[ZDF[idf]])
end
CDF.=EDF\BDF
# x -> x - AZDF E^-1 ZDF' x
x.-=AZDF*CDF
end

function projectPTx!(fx,x)
# fx must already contain A*x (computed outside); x must be the same x used to compute fx and is mutated in place
for idf=1:size(ZDF,1)
BDF[idf]= sum(fx[ZDF[idf]])
end
CDF.=EDF\BDF
# x -> x - ZDF E^-1 ZDF' Ax
for i=1:size(ZDF,1)
x.-=ZDF[i].*CDF[i]
end
end



else
@show "ZDF not valid"
deflation=false
end

end
########
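# When deflation is active, the CG iteration below is unchanged except that the
# residual r and the product A*p are replaced by P*r and P*(A*p); the returned x
# is corrected with P' and the Z E^-1 Z' b term just before the return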


# gradient at initial guess
fun!(x, Ap)
r = b - Ap

###########
if deflation
projectPx!(r)
end

# quick exit


r2 = r ⋅ r
if r2 < tol2
GC.enable(true)
@@ -167,7 +273,12 @@ function conjugategradient(
# compute A*p
#@show k
fun!(p, Ap)


##### To add deflation, change Ap into PAp here and keep it stored in Ap
if deflation
projectPx!(Ap)
end
#####
# how far do we need to go in direction p?
# alpha is determined by linesearch
alpha[k] = zr_old / (p ⋅ Ap)
@@ -185,7 +296,11 @@
# @show "restart"
fun!(x, Ap)
r = b - Ap

##### To add deflation, change r into Pr here and keep it stored in r
if deflation
projectPx!(r)
end
####
else
r = BLAS.axpy!(-alpha[k], Ap, r)
end
@@ -228,7 +343,15 @@
end

GC.enable(true)

##### To add deflation, change x into P'x and add ZAcZHb before returning
if deflation
fun!(x, Ap)
projectPTx!(Ap,x)
Ap.=0.0
projectPTx!(b,Ap)
x.=x.-Ap
end
####
return x, success, kfinal

end
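For reference, the algebra implemented above, written for a generic symmetric positive definite system: with Z the matrix whose columns are the ZDF masks and E = Z'AZ, the iteration works on projected residuals P*r with P = I - A Z E^-1 Z', and the full solution is recovered at the end as x = P'*x_cg + Z E^-1 Z' b, which is the correction applied just before the return. The following is a self-contained sketch, independent of DIVAnd; deflated_cg and all variable names are illustrative, not part of this commit.

using LinearAlgebra

# deflated conjugate gradient on a dense SPD system, mirroring the structure
# of the code above (projected residuals, final P'x + Z*E^-1*Z'*b correction)
function deflated_cg(A, b, Z; tol = 1e-10, maxiter = length(b))
    E = cholesky(Symmetric(Z' * A * Z))   # E = Z'AZ, the small deflation matrix
    AZ = A * Z                            # corresponds to AZDF above
    P(v) = v - AZ * (E \ (Z' * v))        # P  = I - A Z E^-1 Z'
    Pt(v) = v - Z * (E \ (AZ' * v))       # P' = I - Z E^-1 Z' A

    x = zeros(length(b))
    r = P(b)                              # projected initial residual (x0 = 0)
    p = copy(r)
    rr = dot(r, r)
    for k = 1:maxiter
        Ap = P(A * p)                     # project A*p, as in the loop above
        alpha = rr / dot(p, Ap)
        x .+= alpha .* p
        r .-= alpha .* Ap
        rrnew = dot(r, r)
        sqrt(rrnew) < tol && break
        p .= r .+ (rrnew / rr) .* p
        rr = rrnew
    end
    # recover the full solution: x = P' x_cg + Z E^-1 Z' b
    return Pt(x) + Z * (E \ (Z' * b))
end

# quick check against a direct solve
n, m = 50, 5
B = randn(n, n)
A = B' * B + n * I                        # SPD test matrix
b = randn(n)
Z = zeros(n, m)
for i = 1:n
    Z[i, mod1(i, m)] = 1.0                # each row in exactly one deflation vector
end
x = deflated_cg(A, b, Z)
@show norm(A * x - b)                     # should be close to zero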
