# Setup

In [None]:
using IntervalArithmetic, Combinatorics, Polynomials, Serialization, Base.Threads, Random, LaTeXStrings, LinearAlgebra

In [None]:
K = 150
Ms = collect(K:-1:0)
# indices for different frequencies: the Ms[k+1]+1 coefficients of frequency 2k+1
# occupy positions indices[k+1]:indices[k+2]-1
indices = [1; cumsum(Ms.+1).+1]
N = indices[end]-1

# approximate solution
ū = interval.(deserialize("ubar"));

# regularised Vandermonde matrices for 2p and p+1 products
V̄4 = deserialize("V4r");
V̄6 = deserialize("V6r");

# Both matrices are addressed column-wise through `indices` (and broadcast against ū'),
# so each must have exactly N columns; fail early instead of silently redefining N.
size(V̄4, 2) == N || error("V4r has $(size(V̄4, 2)) columns, expected N = $N")
size(V̄6, 2) == N || error("V6r has $(size(V̄6, 2)) columns, expected N = $N")
N4 = size(V̄4, 1);
N6 = size(V̄6, 1);

# cut ū after frequency 2*K₀-1
K₀ = 13
ū[indices[K₀+1]:end] .=interval(0);

# diagonal operator: entry k + 3/2 + m for mode m = 0:Ms[k+1] of frequency 2k+1
𝔏 = Diagonal(reduce(vcat, [interval(k) .+ interval(3//2) .+ interval.(collect(0:Ms[k+1])) for k=0:K]))
p = interval(3)
d = 2
# All diagonal entries are ≥ 3/2 > 0, so the maximum lies on the diagonal; reading only
# diag(𝔏) avoids visiting the N² structural zeros of the Diagonal.
λₘ = maximum(diag(𝔏)) + interval(1)
λ₀ = interval(3//2);
Z = interval(4)*interval(BigFloat, π)

In [None]:
# Euclidean norm of an interval vector, used for the Z²¹ and Z¹² bounds.
# Restricted to interval element types: a method for every `Vector` would replace
# LinearAlgebra.norm session-wide (Float64 vectors used by other packages, and complex
# vectors, for which sqrt(sum(v.^2)) is not the Euclidean norm).
# `v.^2` rather than `v .* v`, so the square of an interval containing 0 stays non-negative.
LinearAlgebra.norm(v::Vector{<:Interval}) = sqrt(sum(v.^2))
# weights consumed by L∞; presumably bounds on sup|ψ̂ᵢ| (file "suppsi") — confirm
sups = deserialize("suppsi")

function L6(u)
    # Interval enclosure of ||u³||².
    # blocks[n+1] collects, row by row of V̄6, the contribution of the coefficients of
    # frequency 2n+1. A triple product of cosines expands as
    #   cos a·cos b·cos c = ¼[cos(a+b+c) + cos(a+b−c) + cos(a−b+c) + cos(−a+b+c)],
    # so each triple feeds four output-frequency columns. Only ordered triples
    # i ≥ k ≥ j are visited, weighted by their number of permutations (6, 3, 3 or 1) × ¼.
    weighted = V̄6 .* u'
    blocks = [sum(weighted[:, indices[n+1]:indices[n+2]-1], dims = 2)[:] for n = 0:K]
    proj = zeros(Interval{Float64}, (N6, 3*K+2))
    for i = 0:K₀-1, k = 0:K₀-1, j = 0:K₀-1
        i >= k >= j || continue
        cube = blocks[k+1] .* blocks[j+1] .* blocks[i+1]
        term = if i > k > j
            cube * interval(3//2)   # all distinct: 6 permutations
        elseif i == j
            cube / interval(4)      # all equal: 1 permutation
        else
            cube * interval(3//4)   # exactly two equal: 3 permutations
        end
        for col in (k+i+j+2, abs(2*(k+i-j)+1)÷2+1, abs(2*(k-i-j)-1)÷2+1, abs(2*(-k+i-j)-1)÷2+1)
            proj[:, col] += term
        end
    end
    return sum(proj.^interval(2))
end

# Squared ℓ² norm of a coefficient vector: Σᵢ uᵢ², squaring with an interval exponent.
L2(u) = sum(u .^ interval(2))

# Squared 𝔏-weighted norm of a coefficient vector: Σᵢ (𝔏u)ᵢ², with the global diagonal 𝔏.
H2(u) = sum((𝔏 * u) .^ 2)

# rigorous upper bound of the 2-norm of a matrix
#
# 2×2 case: closed-form largest singular value of A = [a b; c d],
#     σ = sqrt(‖A‖_F² + sqrt(((b+c)² + (a−d)²)((b−c)² + (a+d)²))) / √2;
#   when the enclosure is guaranteed only its upper end is returned (as a thin
#   interval), otherwise the raw enclosure is returned.
# General case: ‖A‖₂ ≤ sqrt(‖A‖₁‖A‖_∞), with the largest column and row sums of |A|
#   taken at their upper ends; errors unless every entry of A is guaranteed.
function op_norm(A)
    if size(A) != (2,2)
        all(isguaranteed.(A)) || error("matrix not guaranteed")
        col_bound = interval(maximum(sup.(sum(abs.(A), dims = 1))))
        row_bound = interval(maximum(sup.(sum(abs.(A), dims = 2))))
        return sqrt(col_bound*row_bound)
    end
    a, b = A[1,1], A[1,2]
    c, d = A[2,1], A[2,2]
    disc = ((b+c)^2+(a-d)^2)*((b-c)^2+(a+d)^2)
    σ = sqrt(sum(A.^2) + sqrt(disc))/sqrt(interval(2))
    return isguaranteed(σ) ? interval(sup(σ)) : σ
end

# Sup-norm bound of the function with coefficients u: Σᵢ |uᵢ|·sups[i], using the global weights `sups`.
L∞(u) = sum(abs.(u) .* sups)

# Add to `acc` the contribution of the work item (i, k, j) to Gᵢⱼ = <ψ̂ᵢ, u²ψ̂ⱼ>.
# The pair of u-frequencies (k, j) is used only for k ≥ j: weight 1/2 when k > j (covers
# the symmetric pair (j, k) as well) and 1/4 when k == j. The product with the basis
# block of frequency i lands on four output frequencies l. Only l ≤ K with l ≥ i is
# computed; an l > i block is also mirrored onto its transpose position.
function _gram_accumulate!(acc, P, i, k, j)
    k < j && return acc
    weight = k > j ? interval(2) : interval(4)
    cols_i = indices[i+1]:indices[i+2]-1
    X = (P[k+1] .* P[j+1] / weight) .* V̄4[:, cols_i]
    for l in (k+i+j+1, abs(2*(k+i-j)+1)÷2, abs(2*(k-i-j)-1)÷2, abs(2*(-k+i-j)-1)÷2)
        (l <= K && l >= i) || continue
        cols_l = indices[l+1]:indices[l+2]-1
        Y = (V̄4[:, cols_l])' * X
        @views acc[cols_l, cols_i] .+= Y
        if l > i
            @views acc[cols_i, cols_l] .+= Y'
        end
    end
    return acc
end

function Gram(u)
    # computes the Gram matrix Gᵢⱼ = <ψ̂ᵢ,u²ψ̂ⱼ>
    U = V̄4.*u'
    P = [sum(U[:,indices[k+1]:indices[k+2]-1], dims = 2)[:] for k=0:K₀-1]
    # shuffled so that expensive and cheap items are spread over the chunks
    work = shuffle(vec([(i, k, j) for i=0:K, k=0:K₀-1, j=0:K₀-1]))
    nchunks = min(Threads.nthreads(), length(work))
    # One task per chunk, each with its own accumulator. Indexing shared buffers by
    # Threads.threadid() is unsafe: tasks may migrate, and ids can exceed nthreads().
    tasks = map(Iterators.partition(work, cld(length(work), nchunks))) do chunk
        Threads.@spawn begin
            acc = zeros(Interval{Float64}, (N, N))
            for (i, k, j) in chunk
                _gram_accumulate!(acc, P, i, k, j)
            end
            acc
        end
    end
    return sum(fetch.(tasks))
end

# Multi-threaded interval matrix product C = A*B, computed one column of C at a time.
# For each column of B only the rows holding non-zero entries take part in the product.
# Columns are visited in random order for load balance; every iteration writes its own
# column of C, so the threads never write to the same entries.
function multiply(A::Matrix{Interval{Float64}}, B::Matrix{Interval{Float64}})
    rows, inner = size(A)
    inner_b, cols = size(B)
    inner == inner_b || error("dimension mismatch")
    C = zeros(Interval{Float64}, (rows, cols))
    Threads.@threads for col in shuffle(1:cols)
        nz = .!iszero.(B[:, col])
        C[:, col] = A[:, nz] * B[nz, col]
    end
    return C
end

function compute_norms(u)
    # computes ||u²ψ̂ⱼ||²: entry indices[i+1]+m encloses the squared norm for mode m
    # of frequency 2i+1.
    U = V̄6.*u'
    P = [sum(U[:,indices[k+1]:indices[k+2]-1], dims = 2)[:] for k=0:K₀-1]
    u_prod = [P[k+1].*P[j+1] for k=0:K₀-1, j=0:K₀-1]
    norms6 = zeros(Interval{Float64}, N)
    ind = shuffle(collect(0:K))
    Threads.@threads for i in ind
        # One scratch buffer per frequency (allocated inside the iteration, so never
        # shared between threads), reset for every mode m instead of re-allocated.
        freq_proj = zeros(Interval{Float64}, (N6, 3*K+2))
        for m = 0:Ms[i+1]
            fill!(freq_proj, zero(Interval{Float64}))
            W = V̄6[:,indices[i+1]+m]
            for k=0:K₀-1, j=0:K₀-1
                # product of three cosines → four output frequencies, each with weight 1/4
                X = W.*(u_prod[k+1,j+1])/interval(4)
                @views for c in (k+i+j+2, abs(2*(k+i-j)+1)÷2+1, abs(2*(k-i-j)-1)÷2+1, abs(2*(-k+i-j)-1)÷2+1)
                    freq_proj[:, c] .+= X
                end
            end
            norms6[indices[i+1]+m] = sum(freq_proj.^2)
        end
    end
    return norms6
end

# Proof

In [None]:
# Gram matrix of ū²: Gᵢⱼ = <ψ̂ᵢ, ū²ψ̂ⱼ>, so G*ū gives the coefficients of ū³ on the first N modes
G = Gram(ū);

In [None]:
# ||ū³||² minus the squared coefficients of ū³ on the first N modes: the squared norm of
# the part of ū³ outside the projection (assumes the ψ̂ᵢ are orthonormal — confirm)
Fū∞ = (L6(ū)- L2(G*ū))

In [None]:
# PₙF(ū)
# F(u) = u − 𝔏⁻¹(u/(p−1) + u³) restricted to the first N modes; G*ū supplies ū²·ū
PFū = ū - inv(𝔏)*(ū/(p-interval(1))+G*ū);
# println(norm(PFū))

In [None]:
# int[i] encloses ||ū²ψ̂ᵢ||² (full norm, not projected); used for the tail weights w
int = compute_norms(ū);

In [None]:
# PₙDF(ū)Pₙ = I − 𝔏⁻¹/(p−1) − 3𝔏⁻¹G: derivative of u ↦ u − 𝔏⁻¹(u/(p−1) + u³)
# (cf. PFū), where G represents multiplication by ū² on the first N modes.
# The parentheses around 𝔏\G are essential. Julia gives *, / and \ the same precedence
# and associates them to the left, so `interval(3)*𝔏\G` would mean (3𝔏)\G = 𝔏⁻¹G/3.
DFū = interval.(I(N)) - inv(𝔏)/(p-interval(1)) - interval(3)*(𝔏\G)
# approximate numerical inverse of PₙDF(ū)Pₙ
Aₙ = interval.(inv(mid.(DFū)));

In [None]:
# Y bound: sqrt(||𝔏 Aₙ PₙF(ū)||² + Fū∞). The finite part is measured with the 𝔏-weighted
# norm (H2) and the part outside the projection is added through Fū∞.
Y = (sqrt(H2(Aₙ*PFū)+Fū∞))

In [None]:
# bound on ||𝔏(I − Aₙ·DFū)𝔏⁻¹||₂: how far Aₙ is from inverting DFū, in the 𝔏-weighted norm
Z¹¹ = op_norm(𝔏*(interval.(I(N)) - multiply(Aₙ,DFū))*inv(𝔏))

In [None]:
# w[i]: norm of the part of ū²ψ̂ᵢ outside the first N modes, from ||ū²ψ̂ᵢ||² − Σⱼ Gᵢⱼ²
# (assumes orthonormal ψ̂ — confirm). abs guards against an enclosure dipping below 0
# through cancellation.
w = sqrt.(abs.([int[i] - L2(G[i,:]) for i=1:N]));

In [None]:
# Z²¹ = p·||𝔏⁻¹w||
Z²¹ = p*norm(𝔏\w)

In [None]:
# Z¹² = p·|| |𝔏Aₙ𝔏⁻¹| w || / λₘ (entrywise absolute value of the weighted approximate inverse)
Z¹² = p*norm(abs.(𝔏*Aₙ*inv(𝔏))*w)/λₘ

In [None]:
# upper bound on sup|ū| from the coefficient weights `sups`
supū = L∞(ū)
# Z²² = (1/(p−1) + p·sup|ū|)/λₘ, with λₘ the value used for the modes outside the projection
Z²² = (interval(1)/(p-interval(1))+p*supū)/λₘ

In [None]:
# display the 2×2 matrix of block bounds
[Z¹¹ Z¹² ; Z²¹ Z²²]

In [None]:
# Z₁: 2-norm bound of the 2×2 block-bound matrix (closed-form 2×2 branch of op_norm)
Z₁ = op_norm([Z¹¹ Z¹² ; Z²¹ Z²²])

In [None]:
# Calculation of the L²-norm of the Hessian of η and of the sup of η
# C₀, C₁, C₂ are decimal constants taken as given here (not derived in this notebook);
# I"…" encloses each decimal value rigorously.
C₀ = I"0.56419"
C₁ = I"0.79789"
C₂ = I"0.23033"
# Z is the global 4π from the setup cell and d = 2 the dimension
C2 = sqrt(Z)*(interval(2//d)*C₀ + sqrt(interval(2//d))*C₁ + C₂)

In [None]:
# op_n = max(bound on ||𝔏Aₙ𝔏⁻¹||₂, 1), kept as a thin interval at the upper end
op_n = interval(max(sup(op_norm(𝔏*Aₙ*inv(𝔏))), 1))
# second- and third-order coefficients of the radii polynomial
Z₂ = interval(4)*C2*supū*op_n
Z₃ = interval(4)*C2^2*op_n

In [None]:
# δ̄: positive root of Z₃δ²/2 + Z₂δ − (1 − Z₁) = 0, i.e. the radius where P'(δ) = 0
δ̄ = (-Z₂+sqrt(Z₂^2 +interval(2)*Z₃-interval(2)*Z₁*Z₃))/Z₃

In [None]:
# radii polynomial P(δ) = Z₃δ³/6 + Z₂δ²/2 − (1 − Z₁)δ + Y
P(δ) = Z₃/interval(6)*δ^3 + Z₂/interval(2)*δ^2 - (interval(1) - Z₁)*δ + Y

In [None]:
# Candidate radius: a fixed multiple of the first-order estimate Y/(1 − Z₁).
δ̲ = Y/(interval(1)-Z₁)*interval(1.1856)
# Validation needs all of:
# - Z₁ < 1, so that δ̲ and δ̄ are meaningful;
# - P(δ̲) < 0;
# - δ̲ < δ̄, so that Z₁ + Z₂δ + Z₃δ²/2 < 1 on the ball, i.e. P'(δ̲) < 0.
# Checking P(δ̲) < 0 alone leaves δ̄ unused and can accept a radius past the turning point of P.
if sup(Z₁) < 1 && sup(P(δ̲)) < 0 && sup(δ̲) < inf(δ̄)
    println("δ̲ is validated")
else
    println("δ̲ could not be validated")
end

# Plotting

In [None]:
# Radial profiles of ū on r ∈ [0, 4] (step 1/100): φᵣ[k+1, :] is the radial part of
# frequency 2k+1, built from the Laguerre-type basis functions with midpoint coefficients.
rloc = big.(collect(0:400))/100;

φᵣ = zeros((K₀, length(rloc)))
for k in 0:K₀-1
    # Loop-local names must not reuse the globals `p` (exponent) and `Z` (4π): at notebook
    # top level (soft scope) assigning to them here would overwrite the values used by
    # the proof cells.
    radial = ((rloc/2).^(2*k+1)).*exp.(-rloc.^2/4)
    for m=0:big(Ms[k+1])
        # generalised Laguerre polynomial L_m^(2k+1), exact rational coefficients
        Lₘ = Polynomial([(-1)^j*big(binomial(m+2*k+1,m-j)//factorial(j)) for j=0:m])
        normalisation = sqrt(big(factorial(m+2*k+1)//factorial(m)))
        φᵣ[k+1,:] += mid(ū[indices[k+1]+m])*radial.* Lₘ.(rloc.^2/4)/normalisation
    end
end

In [None]:
using Plots

In [None]:
# angular grid ϑ ∈ [0, 2π] (101 points); c[:, k+1] = cos((2k+1)ϑ) for k = 0:K₀-1
ϑ = big.(collect(0:100))/100*2*big(π)
c = cos.(ϑ.*collect(1:2:2*K₀-1)');

In [None]:
# surface of √2·Σₖ φᵣ[k+1, r]·cos((2k+1)ϑ) over the polar grid (x, y) = (r cos ϑ, r sin ϑ)
surface(rloc.*cos.(ϑ'),rloc.*sin.(ϑ'), sqrt(big(2))*(c*φᵣ)', xlabel = L"$x$", ylabel = L"$y$", zlabel = L"$\overline{u}\,(x, y)$", colorbar = false, dpi = 800)

In [None]:
# save the current figure as a PNG under the name "asymmetric_heat"
png("asymmetric_heat")