In [1]:
using Statistics
using Plots
using FFTW
using Statistics
using Optim
using Images, FileIO, ImageIO
using Printf
using Revise
using Profile
using LinearAlgebra
using JLD2
using Random
using Distributions
using FITSIO
using LineSearches
using Flux
using StatsBase
using SparseArrays

push!(LOAD_PATH, pwd()*"/../../../main")
using DHC_2DUtils
push!(LOAD_PATH, pwd()*"/../../../scratch_NM")
using Deriv_Utils_New
using Data_Utils
using Visualization
using ReconFuncs
include("../../../main/compute.jl")

#FUNCTIONS
"""
    realspace_filter(Nx, f_i, f_v)

Return the inverse FFT of an `Nx × Nx` complex array that is zero everywhere except
at the positions `f_i`, where it holds the matching entries of `f_v`.

# Arguments
- `Nx`: side length of the square array that is built.
- `f_i`: indices into the `Nx × Nx` array (one per stored value).
- `f_v`: values written at those indices; `f_v[k]` goes to position `f_i[k]`.

# Returns
The complex-valued result of `ifft` applied to the filled array (it is not reduced
to its real part).
"""
function realspace_filter(Nx, f_i, f_v)
    fourier_plane = zeros(ComplexF64, Nx, Nx)
    # Scatter the sparse (index, value) pairs into the dense Fourier-domain array.
    for (k, position) in enumerate(f_i)
        fourier_plane[position] = f_v[k]
    end
    return ifft(fourier_plane)
end

"""
    DHC_compute_S20r_noisy_so(image, filter_hash, sigim; doS2=true, doS20=false,
                              apodize=false, norm=false, iso=false, FFTthreads=1,
                              filter_hash2=filter_hash, coeff_mask=nothing)

Compute the coefficient vector of `image` (S0, S1 and, when requested, S2 and S20)
where the S20 block has a `sigim`-dependent correction matrix added to it, and
build a covariance matrix from the S20 derivative.

# Arguments
- `image`: square `Float64` matrix (a non-square image raises an error).
- `filter_hash`: `Dict` providing `"filt_index"`, `"filt_value"` and `"npix"`;
  `"S1_iso_mat"` / `"S2_iso_mat"` are read only when `iso` is true.
- `sigim`: `Float64` matrix; it is squared element-wise inside the computation and
  also scales the rows of the derivative matrix (presumably the per-pixel noise
  standard deviation — confirm with the caller).

# Keywords
- `doS2`, `doS20`: include the S2 / S20 blocks in the coefficient vector.
- `apodize`: pass `image` through `apodizer` before anything else.
- `norm`: subtract the mean and scale the image before the transform.
- `iso`: multiply S1/S2/S20 by the iso matrices stored in `filter_hash`.
- `FFTthreads`: forwarded to `FFTW.set_num_threads`.
- `filter_hash2`: filters used for the S2 block (defaults to `filter_hash`).
- `coeff_mask`: optional mask over the full coefficient vector. Its first `Nf + 2`
  entries (S0 and S1) must all be zero.

# Returns
`(sall, cov)` where `sall` is the coefficient vector and `cov = G'G` with
`G = (dA * reshape(dS20dp, Nx*Nx, Nf*Nf)) .* sigim`. When `coeff_mask` is given the
result is `(sall[coeff_mask], cov[m, m])` with `m = coeff_mask[Nf+3:end]`.
"""
function DHC_compute_S20r_noisy_so(image::Array{Float64,2}, filter_hash::Dict, sigim::Array{Float64,2};
    doS2::Bool=true, doS20::Bool=false, apodize=false, norm=false, iso=false, FFTthreads=1, filter_hash2::Dict=filter_hash, coeff_mask=nothing)
    # coeff_mask is not used while computing; it is only applied to the outputs at the end.
    # Original author's note: when using coeff_mask, assumes any one of doS12, doS2 and doS20 is true.
    Nf = size(filter_hash["filt_index"])[1]
    Nx = size(image)[1]

    # The apodizer derivative is requested with the same arguments in either case,
    # so only the image itself depends on the flag.
    ap_img = apodize ? apodizer(image) : image
    dA = get_dApodizer(Nx, Dict([(:apodize => apodize)]))

    # Coefficient computation with the sigim-dependent terms added to S20.
    #   image        - input image
    #   filter_hash  - filter hash used for the first-order pass
    #   filter_hash2 - filters for the S2 block
    #   sigim        - squared element-wise to give the variance map used below
    #   doS2         - compute S2 coeffs
    #   doS20        - compute S20 coeffs
    #   norm         - subtract the mean and divide by sqrt(Nx*Ny*variance)
    #   iso          - apply the iso matrices from filter_hash
    #   FFTthreads   - number of threads handed to FFTW
    #   normS1, normS1iso - divide S2 rows by S1 (or its iso-averaged version)
    function DHC_compute_biased(image::Array{Float64,2}, filter_hash::Dict, filter_hash2::Dict, sigim::Array{Float64,2};
        doS2::Bool=true, doS20::Bool=false, norm=true, iso=false, FFTthreads=2, normS1::Bool=false, normS1iso::Bool=false)
        # Keep these sizes private to this nested function. Without the declaration the
        # assignments below would rebind the enclosing function's Nx and Nf, which
        # turns them into boxed (type-unstable) captured variables.
        local Nx, Ny, Nf

        FFTW.set_num_threads(FFTthreads)

        (Nx, Ny)  = size(image)
        if Nx != Ny error("Input image must be square") end
        (Nf, )    = size(filter_hash["filt_index"])
        if Nf == 0  error("filter hash corrupted") end
        @assert Nx==filter_hash["npix"] "Filter size should match npix"
        @assert Nx==filter_hash2["npix"] "Filter2 size should match npix"
        @assert !(normS1 && normS1iso) "normS1 and normS1iso are mutually exclusive"

        # allocate coeff arrays (out_coeff stays an untyped vector, as callers receive it today)
        out_coeff = []
        S0  = zeros(Float64, 2)
        S1  = zeros(Float64, Nf)
        if doS2  S2  = zeros(Float64, Nf, Nf) end  # traditional 2nd order
        if doS20 S20 = zeros(Float64, Nf, Nf) end  # real space correlation
        anyrd = doS2 | doS20             # compute real domain with iFFT

        # allocate image arrays for internal use
        if anyrd
            im_rd_0_1  = Array{Float64, 3}(undef, Nx, Ny, Nf)
            ψsqfac = Array{Float64, 4}(undef, Nf, Nx, Ny, 3)
            Iψfac = Array{ComplexF64, 4}(undef, Nf, Nx, Ny, 2)
            ψfaccross = Array{Float64, 2}(undef, Nf, Nf)
            sozoterms = Array{Float64, 4}(undef, Nf, Nx, Ny, 2)
            rsψmat = Array{ComplexF64, 3}(undef, Nf, Nx, Ny)
        end
        varim  = sigim.^2
        Pf = plan_fft(varim)
        fvar = Pf*(varim)

        ## 0th Order
        S0[1]   = mean(image)
        norm_im = image.-S0[1]
        S0[2]   = sum(norm_im .* norm_im)/(Nx*Ny)
        if norm
            norm_im ./= sqrt(Nx*Ny*S0[2])
        else
            norm_im = copy(image)
        end

        append!(out_coeff,S0[:])

        ## 1st Order
        im_fd_0 = Pf*(norm_im)

        # unpack filter_hash
        f_ind   = filter_hash["filt_index"]  # per-filter index lists
        f_val   = filter_hash["filt_value"]  # per-filter values matching f_ind

        zarr = zeros(ComplexF64, Nx, Ny)  # temporary array to fill with zvals

        # make a FFTW "plan" for an array of the given size and type
        if anyrd
            P = plan_ifft(im_fd_0) end  # P is an operator, P*im is ifft(im)

        ## Main 1st Order and Precompute 2nd Order
        for f = 1:Nf
            S1tot = 0.0
            f_i = f_ind[f]  # index list for filter
            f_v = f_val[f]  # Values for f_i
            if length(f_i) > 0
                for i in eachindex(f_i)
                    ind       = f_i[i]
                    zval      = f_v[i] * im_fd_0[ind]
                    S1tot    += abs2(zval)
                    zarr[ind] = zval        # filter*image in Fourier domain
                end
                S1[f] = S1tot/(Nx*Ny)  # image power
                if anyrd
                    im_rd_0_1[:,:,f] .= abs2.(P*zarr)
                    rsψmat[f, :, :] = realspace_filter(Nx, f_i, f_v)
                    frealψ = Pf*(real.(rsψmat[f, :, :]))
                    fimψ = Pf*(imag.(rsψmat[f, :, :]))
                    ψsqfac[f, :, :, 1] = real.(rsψmat[f, :, :])
                    ψsqfac[f, :, :, 2] = imag.(rsψmat[f, :, :])

                    Iψfac[f, :, :, 1] = P*(im_fd_0 .* frealψ) #(I ✪ ψR)
                    Iψfac[f, :, :, 2] = P*(im_fd_0 .* fimψ)   #(I ✪ ψI)
                end

                # reset only the entries touched by this filter
                zarr[f_i] .= 0
            end
        end

        append!(out_coeff, iso ? filter_hash["S1_iso_mat"]*S1 : S1)

        if normS1iso
            S1iso = vec(reshape(filter_hash["S1_iso_mat"]*S1,(1,:))*filter_hash["S1_iso_mat"])
        end

        # we stored the abs()^2, so take sqrt (this is faster to do all at once)
        if anyrd
            im_rd_0_1 .= sqrt.(im_rd_0_1)
            for f1=1:Nf
                #Precompute SO_λ1.ZO_λ2 terms
                fsqrsψ = Pf*(real.(rsψmat[f1, :, :]).^2) #F(ψR^2)
                fsqimψ = Pf*(imag.(rsψmat[f1, :, :]).^2) #F(ψI^2)
                powrs = fsqrsψ + fsqimψ #F(|ψ|^2)
                sozoterms[f1, :, :, 1] = (P*(fvar .* powrs)) ./ im_rd_0_1[:, :, f1]                   #(σ2 ✪ |ψ|^2)/|I ✪ ψ|
                sozoterms[f1, :, :, 2] = real.((Iψfac[f1, :, :, 1].^2) .* (P*(fvar .* fsqrsψ)))     #(I ✪ ψR)^2 . (σ2 ✪ ψR^2)
                sozoterms[f1, :, :, 2] += real.((Iψfac[f1, :, :, 2].^2) .* (P*(fvar .* fsqimψ)))    #(I ✪ ψI)^2 . (σ2 ✪ ψI^2)
                sozoterms[f1, :, :, 2] += real.(2*Iψfac[f1, :, :, 1].*Iψfac[f1, :, :, 2].* (P*(fvar .* (Pf*(imag.(rsψmat[f1, :, :]).*real.(rsψmat[f1, :, :]))))))  #2(I ✪ ψR)(I ✪ ψI) . (σ2 ✪ ψRψI)
                sozoterms[f1, :, :, 2] = sozoterms[f1, :, :, 2]./(im_rd_0_1[:, :, f1].^3)
                sozoterms[f1, :, :, 1] -= sozoterms[f1, :, :, 2]
            end
            for f1=1:Nf
                for f2=1:Nf
                    val1 = Pf*(ψsqfac[f1, :, :, 1] .* ψsqfac[f2, :, :, 1]) #F{ψ_λ1R.ψ_λ2R}
                    term1 = (P*(fvar .* val1) .* Iψfac[f1, :, :, 1]) .* Iψfac[f2, :, :, 1]
                    term2 = (P*(fvar .* (Pf*(ψsqfac[f1, :, :, 1] .* ψsqfac[f2, :, :, 2]))).* Iψfac[f1, :, :, 1]) .* Iψfac[f2, :, :, 2]
                    term3 = (P*(fvar .* (Pf*(ψsqfac[f1, :, :, 2] .* ψsqfac[f2, :, :, 1]))).* Iψfac[f1, :, :, 2]) .* Iψfac[f2, :, :, 1]
                    term4 = (P*(fvar .* (Pf*(ψsqfac[f1, :, :, 2] .* ψsqfac[f2, :, :, 2]))).* Iψfac[f1, :, :, 2]) .* Iψfac[f2, :, :, 2]
                    so1zo2 = 0.5 .* im_rd_0_1[:, :, f2] .* sozoterms[f1, :, :, 1] #T0_λ2 . T2_λ1
                    so2zo1 = 0.5 .* im_rd_0_1[:, :, f1] .* sozoterms[f2, :, :, 1] #T0_λ1 . T2_λ2

                    combined = (term1 + term2 + term3 + term4)./(im_rd_0_1[:, :, f1] .* im_rd_0_1[:, :, f2]) #fo1fo2
                    combined += (so2zo1 + so1zo2)
                    comsum = sum(combined)
                    # Only the real part is stored; the imaginary part is reported through
                    # the logger (visible when debug logging is enabled) instead of being
                    # printed unconditionally on every (f1, f2) pair.
                    @debug "Imaginary part of summed S20 correction" f1 f2 imag_part=imag(comsum)
                    ψfaccross[f1, f2] = real(comsum)
                end
            end
        end

        if doS2
            f_ind2   = filter_hash2["filt_index"]  # per-filter index lists for the S2 block
            f_val2   = filter_hash2["filt_value"]

            ## Traditional second order
            for f1 = 1:Nf
                thisim = fft(im_rd_0_1[:,:,f1])  # Could try rfft here
                # Loop over f2 and do second-order convolution
                if normS1
                    normS1pwr = S1[f1]
                elseif normS1iso
                    normS1pwr = S1iso[f1]
                else
                    normS1pwr = 1
                end

                for f2 = 1:Nf
                    f_i = f_ind2[f2]  # index list for filter
                    f_v = f_val2[f2]  # Values for f_i
                    # sum im^2 = sum(|fft|^2/npix)
                    S2[f1,f2] = sum(abs2.(f_v .* thisim[f_i]))/(Nx*Ny)/normS1pwr
                end
            end
            # The iso matrix is looked up only when it is actually applied.
            append!(out_coeff, iso ? filter_hash["S2_iso_mat"]*S2[:] : S2[:])
        end

        # Real domain 2nd order
        if doS20
            Amat = reshape(im_rd_0_1, Nx*Ny, Nf)
            S20  = Amat' * Amat
            S20 .+= ψfaccross #assuming the nx*ny factor above was for the parseval correction
            append!(out_coeff, iso ? filter_hash["S2_iso_mat"]*S20[:] : S20[:])
        end


        return out_coeff
    end

    sall = DHC_compute_biased(ap_img, filter_hash, filter_hash2, sigim, doS2=doS2, doS20=doS20, norm=norm, iso=iso, FFTthreads=FFTthreads)
    dS20dp = wst_S20_deriv(ap_img, filter_hash)
    ap_dS20dp = dA * reshape(dS20dp, Nx*Nx, Nf*Nf)
    # Scale each pixel row of the derivative matrix by the matching sigim entry.
    G =  ap_dS20dp .* reshape(sigim, Nx*Nx, 1)
    cov = G'*G

    if coeff_mask !== nothing
        @assert count(!iszero, coeff_mask[1:Nf+2])==0 "This function only handles S20r"
        @assert length(coeff_mask)==length(sall) "Mask must have the same length as the total number of coefficients"
        # Skip the 2 S0 and Nf S1 entries; the remainder indexes the covariance.
        s20rmask = coeff_mask[Nf+3:end]
        return sall[coeff_mask], cov[s20rmask, s20rmask]
    else
        return sall, cov
    end

end



##0) What smoothing kernel gives the best MSE, Abs Frac Res and PowSpec?
# Hard-coded stand-ins for the command-line arguments and the numeric job id
# that the batch version of this script would receive.
ARGS_buffer = ["reg", "nonapd", "noiso", "../../../scratch_NM/StandardizedExp/Nx64/", "full_3losstest", "Full+Eps"]
ENV_buffer= "1000"
numfile = Base.parse(Int, ENV_buffer)
println(numfile, ARGS_buffer[1], ARGS_buffer[2])

# Map a two-valued string option to a Bool: `on` -> true, `off` -> false,
# anything else raises an error carrying `errmsg`.
function _parse_flag(arg, on, off, errmsg)
    arg == on && return true
    arg == off || error(errmsg)
    return false
end

logbool = _parse_flag(ARGS_buffer[1], "log", "reg", "Invalid log arg")
apdbool = _parse_flag(ARGS_buffer[2], "apd", "nonapd", "Invalid apd arg")
isobool = _parse_flag(ARGS_buffer[3], "iso", "noiso", "Invalid iso arg")


direc = ARGS_buffer[4] #"../StandardizedExp/Nx64/noisy_stdtrue/" #Change
datfile = direc * "Data_" * string(numfile) * ".jld2" #Replace w SLURM array
loaddf = load(datfile)
true_img = loaddf["true_img"]
init = loaddf["init"]

kbins = collect(1.0:32.0)
# The reference power spectrum does not depend on the kernel width, so it is
# computed once instead of once per loop iteration.
true_ps = calc_1dps(true_img, kbins)
for l in [0.8, 0.9, 1.0, 1.1, 1.2]
    # Smooth the initial image with a Gaussian kernel of width l and compare to truth.
    apdsmoothed = imfilter(init, Kernel.gaussian(l))
    smoothps = calc_1dps(apdsmoothed, kbins)
    resid = apdsmoothed .- true_img
    println("Mean Abs Frac, Smoothed = ", round(mean(abs.(resid ./ true_img)), digits=3))
    println("MSE = ", round(mean(resid.^2), digits=5))
    # NOTE(review): the label says "Frac Res" but the value printed is the mean
    # absolute (not fractional) power-spectrum residual — confirm which is intended.
    println("Power Spec Frac Res, Smoothed = ", round(mean(abs.(smoothps .- true_ps)), digits=3))
end

##3) Can true+emp noisy do better than smoothing with pixwise reg?

1000regnonapd
Mean Abs Frac, Smoothed = 0.126
MSE = 1.0e-5
Power Spec Frac Res, Smoothed = 0.306
Mean Abs Frac, Smoothed = 0.114
MSE = 1.0e-5
Power Spec Frac Res, Smoothed = 0.281
Mean Abs Frac, Smoothed = 0.107
MSE = 1.0e-5
Power Spec Frac Res, Smoothed = 0.272
Mean Abs Frac, Smoothed = 0.1
MSE = 1.0e-5
Power Spec Frac Res, Smoothed = 0.278
Mean Abs Frac, Smoothed = 0.096
MSE = 1.0e-5
Power Spec Frac Res, Smoothed = 0.294
