# Set up

In [None]:
using IntervalArithmetic, Combinatorics, Polynomials, Serialization, Base.Threads, Random, LaTeXStrings, LinearAlgebra

In [None]:
# You only need to run this cell once, in order to create the Vandermonde matrices for the quadratures.
# These matrices are then stored and will not need to be recomputed if you re-run the proof.
# NOTE(review): quadrature.jl is not part of this notebook; presumably it writes the
# serialized files "V4r" and "V6r" that the next cell deserializes — confirm there.
include("quadrature.jl")

We aim to find a radially asymmetric self-similar solution to the critical nonlinear Schrödinger equation with $d=2$, $\omega = -5/2$ and $\varepsilon = -1$; in our setting this is equivalent to solving the equation
$$\mathcal{L}u-\frac{7}{4}u+\frac{e^{r^2/4}}{2}u^3 = 0.$$

In [None]:
# Truncation of the basis. Frequency index k = 0..K labels the angular factor
# cos((2k+1) theta); for frequency k the radial index m runs over 0:Ms[k+1] = 0:K-k,
# so the basis keeps exactly the pairs (k, m) with k + m <= K.
K = 100
Ms = collect(K:-1:0)
# indices of different frequencies: the coefficients of frequency k occupy
# indices[k+1]:indices[k+2]-1 of a coefficient vector
indices = [1; cumsum(Ms.+1).+1]
# total number of basis functions (reassigned from the size of VÃÑ4 below)
N = indices[end]-1

# approximate solution; setprecision(128) sets the default BigFloat precision to 128 bits
# NOTE(review): the non-ASCII identifiers in this file look mis-encoded (UTF-8 shown as
# Mac Roman). The solution is assigned as `≈´` but also referred to as `uÃÑ`; in the
# original UTF-8 these are presumably the precomposed and combining-macron spellings of
# the same name, which Julia NFC-normalizes to one identifier — confirm the file encoding.
setprecision(128)
≈´ = interval.(deserialize("ubar"));

# regularised Vandermonde matrices for 2p and p+1 products (p = 3 is set in the proof
# section, so these are the 6-fold and 4-fold products)
VÃÑ4 = deserialize("V4r");
VÃÑ6 = deserialize("V6r");

# N = number of columns (one per basis function); N4, N6 = number of rows
# (presumably quadrature nodes — confirm in quadrature.jl)
N = size(VÃÑ4)[2]
N4 = size(VÃÑ4)[1];
N6 = size(VÃÑ6)[1];

# cut ubar after frequency 2*K0-1: every coefficient with frequency index k >= K0 is set
# to zero, so only cos(theta), cos(3 theta), ..., cos((2K0-1) theta) remain
K‚ÇÄ = 18
uÃÑ[indices[K‚ÇÄ+1]:end] .=interval(0);

In [None]:
# Squared weighted norm sum_n (lambda_n * u_n)^2, where lambda_n are the diagonal entries
# of the global operator ùîè (defined in the proof section).
function H2(u)
    scaled = ùîè*u
    return sum(scaled.^2)
end

# rigorous upper bound of the 2-norm of a matrix
# Rigorous upper bound for the spectral (2-)norm of a matrix of intervals.
#  * 2x2 matrix A = [a b; c d]: exact largest singular value
#        sigma = sqrt(S + sqrt(D)) / sqrt(2),  S = sum of the squared entries,
#        D = ((b+c)^2 + (a-d)^2) * ((b-c)^2 + (a+d)^2)  (= S^2 - 4 det(A)^2).
#    If the enclosure is guaranteed, only its upper endpoint is returned, as a thin interval;
#    otherwise the enclosure itself is returned.
#  * any other size: ||A||_2 <= sqrt(||A||_1 * ||A||_inf); raises an error if some entry
#    is not guaranteed.
function op_norm(A)
    if size(A) != (2,2)
        all(isguaranteed.(A)) || error("matrix not guaranteed")
        max_col_sum = interval(maximum(sup.(sum(abs.(A), dims = 1))))
        max_row_sum = interval(maximum(sup.(sum(abs.(A), dims = 2))))
        return sqrt(max_col_sum*max_row_sum)
    end
    a, b, c, d = A[1,1], A[1,2], A[2,1], A[2,2]
    discriminant = ((b+c)^2+(a-d)^2)*((b-c)^2+(a+d)^2)
    sigma = sqrt(sum(A.^2) + sqrt(discriminant))/sqrt(interval(2))
    return isguaranteed(sigma) ? interval(sup(sigma)) : sigma
end

# Upper bound for the sup norm of the function with coefficient vector u:
# sum_n |u_n| * sups[n]. The global `sups` is loaded from the file "suppsi"
# (presumably sup bounds of the individual basis functions — confirm).
L‚àû(u) = sum(abs.(u) .* sups)

# Gram matrix of multiplication by exp(r^2/4) u^2 on the truncated basis,
#     G_ij = <psi_i, exp(r^2/4) u^2 psi_j>,
# using only the first K‚ÇÄ angular frequencies of `u`. Radial integrals go through the
# regularised Vandermonde matrix VÃÑ4. Angular integrals use the product-to-sum formula for
# three cosines of odd frequencies, which maps the frequency triple (i, k, j) to the four
# output frequencies l enumerated in `_gram_accumulate!`.
# Relies on the globals VÃÑ4, indices, K, K‚ÇÄ and N. Returns an N x N matrix of
# Interval{Float64}.
function Gram(u)
    U = VÃÑ4.*u'
    # P[k+1]: frequency-k part of u evaluated through VÃÑ4 (one entry per row of VÃÑ4)
    P = [sum(U[:,indices[k+1]:indices[k+2]-1], dims = 2)[:] for k=0:K‚ÇÄ-1]
    # Each unordered pair {k, j} of frequencies of u is visited once (j <= k). The weight
    # applied in `_gram_accumulate!` accounts for the skipped order. Shuffling balances work.
    triples = shuffle([(i, k, j) for i=0:K for k=0:K‚ÇÄ-1 for j=0:k])
    isempty(triples) && return zeros(Interval{Float64}, (N, N))
    # One accumulator per task instead of one per Threads.threadid(). threadid() is not a
    # stable index for a task (tasks may migrate between threads) and can exceed
    # Threads.nthreads() when interactive threads are enabled.
    chunk_size = cld(length(triples), Threads.nthreads())
    tasks = map(Iterators.partition(triples, chunk_size)) do chunk
        Threads.@spawn begin
            acc = zeros(Interval{Float64}, (N, N))
            for (i, k, j) in chunk
                _gram_accumulate!(acc, P, i, k, j)
            end
            acc
        end
    end
    return sum(fetch.(tasks))
end

# Adds to `acc` the contribution of basis frequency i paired with the frequencies k >= j of u.
function _gram_accumulate!(acc, P, i, k, j)
    cols = indices[i+1]:indices[i+2]-1
    # Each ordered pair (k, j) carries weight 1/4 in the product-to-sum formula, so an
    # unordered pair with k > j (standing for both orders) gets 1/2.
    X = (P[k+1].*P[j+1]/(k > j ? interval(2) : interval(4))).*(VÃÑ4[:,cols])
    for l in (k+i+j+1, abs(2*(k+i-j)+1)√∑2, abs(2*(k-i-j)-1)√∑2, abs(2*(-k+i-j)-1)√∑2)
        # Keep only output frequencies inside the basis (2l+1 <= 2K+1). Outputs l < i are
        # skipped: by symmetry, that block is added as a transpose when frequency l is
        # processed.
        if i <= l <= K
            rows = indices[l+1]:indices[l+2]-1
            Y = (VÃÑ4[:,rows])'*X
            acc[rows, cols] += Y
            if l > i
                acc[cols, rows] += Y'
            end
        end
    end
    return acc
end

# Multi-threaded product A*B of Float64 interval matrices.
# Column i of the result only uses the columns of A that meet a nonzero entry of B[:, i],
# which skips work when B has many exact zeros. Each thread writes distinct columns of C.
# Columns are visited in random order to balance the load between threads.
# Throws DimensionMismatch if size(A, 2) != size(B, 1).
function multiply(A::Matrix{Interval{Float64}}, B::Matrix{Interval{Float64}})
    m, k = size(A)
    l, n = size(B)
    k == l || throw(DimensionMismatch("dimension mismatch: A has $k columns, B has $l rows"))
    C = zeros(Interval{Float64}, (m,n))
    Threads.@threads for i in shuffle(1:n)
        # positions of the nonzero entries of column i, computed once per column
        nz = findall(!iszero, view(B, :, i))
        C[:, i] = A[:, nz]*B[nz, i]
    end
    return C
end

# Squared norms ||exp(r^2/4) u^2 psi_n||^2 for every basis function psi_n (n = 1..N).
# They are computed with the regularised Vandermonde matrix VÃÑ6, using only the first K‚ÇÄ
# angular frequencies of `u`. Entry indices[i+1] + m belongs to the basis function with
# frequency index i and radial index m.
# Relies on the globals VÃÑ6, indices, Ms, K, K‚ÇÄ, N and N6.
function compute_norms(u)
    weighted = VÃÑ6.*u'
    # profiles[f+1]: frequency-f part of u evaluated through VÃÑ6 (one entry per row)
    profiles = [sum(weighted[:,indices[f+1]:indices[f+2]-1], dims = 2)[:] for f=0:K‚ÇÄ-1]
    # pair_products[a+1, b+1] = profiles[a+1] .* profiles[b+1] for every ordered pair (a, b)
    pair_products = [profiles[a+1].*profiles[b+1] for a=0:K‚ÇÄ-1, b=0:K‚ÇÄ-1]
    result = zeros(Interval{Float64}, N)
    order = shuffle(collect(0:K))
    # each iteration writes only its own entries of `result`
    Threads.@threads for i in order
        for m = 0:Ms[i+1]
            column = VÃÑ6[:,indices[i+1]+m]
            # buckets[:, l+1] accumulates the part with angular frequency 2l+1
            buckets = zeros(Interval{Float64}, (N6, 3*K+2))
            for k = 0:K‚ÇÄ-1
                for j = 0:K‚ÇÄ-1
                    # product-to-sum formula: each of the four cosine terms has weight 1/4
                    term = column.*(pair_products[k+1,j+1])/interval(4)
                    for c in (k+i+j+2, abs(2*(k+i-j)+1)√∑2+1, abs(2*(k-i-j)-1)√∑2+1, abs(2*(-k+i-j)-1)√∑2+1)
                        buckets[:,c] .+= term
                    end
                end
            end
            result[indices[i+1]+m] = sum(buckets.^2)
        end
    end
    return result
end

# Galerkin projection P_N(exp(r^2/4) u^3) of the cubic term onto the N basis functions,
# using only the first K‚ÇÄ angular frequencies of `u`. Radial integrals go through the
# regularised Vandermonde matrix VÃÑ4. Angular integrals use the product-to-sum formula for
# three cosines of odd frequencies.
# Relies on the globals VÃÑ4, indices, K, K‚ÇÄ and N. Returns a length-N vector of
# Interval{BigFloat}.
function proj(u)
    U = VÃÑ4.*u'
    # P[k+1]: frequency-k part of u evaluated through VÃÑ4 (one entry per row of VÃÑ4)
    P = [sum(U[:,indices[k+1]:indices[k+2]-1], dims = 2)[:] for k=0:K‚ÇÄ-1]
    # Each unordered triple is visited once, as i >= k >= j. `_proj_accumulate!` weights it
    # by its number of distinct orderings. Shuffling balances the work between tasks.
    triples = shuffle([(i, k, j) for i=0:K‚ÇÄ-1 for k=0:i for j=0:k])
    isempty(triples) && return zeros(Interval{BigFloat}, N)
    # One accumulator per task instead of one per Threads.threadid(). threadid() is not a
    # stable index for a task (tasks may migrate between threads) and can exceed
    # Threads.nthreads() when interactive threads are enabled.
    chunk_size = cld(length(triples), Threads.nthreads())
    tasks = map(Iterators.partition(triples, chunk_size)) do chunk
        Threads.@spawn begin
            acc = zeros(Interval{BigFloat}, N)
            for (i, k, j) in chunk
                _proj_accumulate!(acc, P, i, k, j)
            end
            acc
        end
    end
    return sum(fetch.(tasks))
end

# Adds to `acc` the contribution of the frequency triple i >= k >= j of u.
function _proj_accumulate!(acc, P, i, k, j)
    # Each cosine term of the product-to-sum formula carries weight 1/4 per ordering:
    # 6 orderings (all distinct) -> 3/2, 3 orderings (one repeat) -> 3/4, 1 ordering -> 1/4.
    X = if i > k && k > j
        P[i+1].*P[k+1].*P[j+1]*interval(3//2)
    elseif i == k && k == j
        P[i+1].*P[k+1].*P[j+1]/interval(4)
    else
        P[i+1].*P[k+1].*P[j+1]*interval(3//4)
    end
    for l in (k+i+j+1, abs(2*(k+i-j)+1)√∑2, abs(2*(k-i-j)-1)√∑2, abs(2*(-k+i-j)-1)√∑2)
        # only output frequencies 2l+1 <= 2K+1 belong to the truncated basis
        if l <= K
            block = indices[l+1]:indices[l+2]-1
            acc[block] += (VÃÑ4[:,block])'*X
        end
    end
    return acc
end

# Squared norm ||exp(r^2/4) u^3||^2, evaluated with the regularised Vandermonde matrix VÃÑ6
# and using only the first K‚ÇÄ angular frequencies of `u`.
# Column l+1 of `freq_proj` accumulates the part of exp(r^2/4) u^3 with angular frequency
# 2l+1. Each row corresponds to a row of VÃÑ6 (presumably a quadrature node — confirm in
# quadrature.jl). The result is the sum of all squared entries.
# Relies on the globals VÃÑ6, indices, K and K‚ÇÄ.
function L6(u)
    U = VÃÑ6.*u'
    # only the first K‚ÇÄ frequencies of u enter the products below
    P = [sum(U[:,indices[k+1]:indices[k+2]-1], dims = 2)[:] for k=0:K‚ÇÄ-1]
    freq_proj = zeros(Interval{BigFloat}, (N6, 3*K+2))
    # Each unordered triple is visited once, as i >= k >= j. Its weight is 1/4 (one cosine
    # term of the product-to-sum formula) times its number of distinct orderings.
    for i=0:K‚ÇÄ-1, k=0:i, j=0:k
        X = if i == k && k == j
            P[k+1].*P[j+1].*P[i+1]/interval(4)
        elseif i > k && k > j
            P[k+1].*P[j+1].*P[i+1]*interval(3//2)
        else
            P[k+1].*P[j+1].*P[i+1]*interval(3//4)
        end
        for c in (k+i+j+2, abs(2*(k+i-j)+1)√∑2+1, abs(2*(k-i-j)-1)√∑2+1, abs(2*(-k+i-j)-1)√∑2+1)
            freq_proj[:,c] .+= X
        end
    end
    return sum(freq_proj.^interval(2))
end

# Squared Euclidean norm sum_n u_n^2 of a coefficient vector of intervals.
L2(u) = sum(u .^ interval(2))

# Proof

In [None]:
# P_N(exp(r^2/4) ubar^3): projection of the cubic term onto the truncated basis.
# Computed with the Vandermonde matrices as loaded in the setup cell (presumably higher
# precision than the Float64 intervals used later).
P≈´¬≥ = proj(≈´);

In [None]:
# Tail contribution to the residual: (||exp(r^2/4) ubar^3||^2 - ||P_N(exp(r^2/4) ubar^3)||^2) / 4,
# i.e. the squared norm of the part of the cubic term outside the truncated basis, times
# (1/2)^2 (assuming an orthonormal basis).
FuÃÑ‚àû = (L6(uÃÑ)- L2(P≈´¬≥))/(interval(4))

In [None]:
# switch to Float64 precision for the remaining bounds

# Euclidean norm of a vector of intervals, used by the Z-bound cells below.
# Restricted to interval vectors. A `norm(v::Vector)` method would replace LinearAlgebra's
# overflow-safe norm for every Vector in the session (type piracy with global effects).
LinearAlgebra.norm(v::Vector{<:Interval}) = sqrt(sum(v.^2))

# regularised Vandermonde matrices for 2p and p+1 products, now as Float64 intervals
VÃÑ4 = interval.(Float64, deserialize("V4r"));
VÃÑ6 = interval.(Float64, deserialize("V6r"));

# per-basis-function bounds from the file "suppsi" (used by L‚àû)
sups = deserialize("suppsi")

# Diagonal operator: entry k + 3/2 + m for the basis function with frequency index k and
# radial index m, ordered block by block as in `indices`
ùîè = Diagonal(reduce(vcat, [interval(k) .+ interval(3//2) .+ interval.(collect(0:Ms[k+1])) for k=0:K]))

# equation parameters: nonlinearity power p, dimension d, omega, epsilon
p = interval(3)
d = 2
# The entry k + 3/2 + m depends only on k + m, and the basis keeps k + m <= K, so
# maximum + 1 = K + 5/2 is the smallest eigenvalue of the operator outside the truncation.
Œª‚Çò = maximum(ùîè) + interval(1)
# smallest eigenvalue (k = m = 0)
Œª‚ÇÄ = interval(3//2);
œâ = interval(-5//2)
Œµ = interval(-1)
# 4*pi enclosed at BigFloat precision
Z = interval(4)*interval(BigFloat, œÄ)

In [None]:
# Gram matrix G_ij = <psi_i, exp(r^2/4) ubar^2 psi_j> on the truncated basis
G = Gram(uÃÑ);

In [None]:
# P_N F(ubar), where F(u) = u - inv(L) * ((d/4 - omega/2) u + eps * P_N(exp(r^2/4) u^3) / 2)
# and L is the diagonal operator ùîè; rounded to Float64 intervals.
# With d = 2, omega = -5/2, eps = -1, a zero of F solves the equation stated at the top.
PF≈´ = interval.(Float64, ≈´ - inv(ùîè)*((interval(d//4)-œâ/interval(2))*≈´+Œµ*P≈´¬≥/interval(2)));
# println(norm(PF≈´))

In [None]:
# P_N DF(ubar) P_N for F(u) = u - inv(L) * ((d/4 - omega/2) u + eps * P_N(exp(r^2/4) u^3) / 2):
#     DF(ubar) = I - (d/4 - omega/2) inv(L) - (3 eps / 2) inv(L) G,
# since the derivative of u -> P_N(exp(r^2/4) u^3) at ubar is 3G, with G = Gram(ubar).
# The scalar factor is kept separate from the solve: `3*eps*ùîè\G/2` would parse as
# ((3 eps L) \ G) / 2, because `*`, `\` and `/` share precedence and associate left.
DF≈´ = interval.(I(N)) - (interval(d//4)-œâ/interval(2))*inv(ùîè) - interval(3)*Œµ/interval(2)*(ùîè\G);
# approximate numerical inverse of P_N DF(ubar) P_N (midpoints, then enclosed)
A‚Çô = interval.(inv(mid.(DF≈´)));

In [None]:
# Y bound: sqrt(H2(A_n P_N F(ubar)) + tail), combining the finite part, weighted by the
# diagonal operator through H2, with the tail term computed above
Y = (sqrt(H2(A‚Çô*interval.(Float64, PF≈´))+FuÃÑ‚àû))

In [None]:
# Z11: bound on ||L (I - A_n DF) inv(L)|| for the finite block (L = the diagonal operator)
Z¬π¬π = op_norm(ùîè*(interval.(I(N)) - multiply(A‚Çô,DF≈´))*inv(ùîè));

In [None]:
# int[n] = ||exp(r^2/4) ubar^2 psi_n||^2 for every basis function n
int = compute_norms(uÃÑ);

In [None]:
# w[n] = sqrt(||exp(r^2/4) ubar^2 psi_n||^2 - ||P_N(exp(r^2/4) ubar^2 psi_n)||^2): norm of the
# part outside the truncated basis (assuming an orthonormal basis). `abs` presumably guards
# against a negative lower endpoint caused by interval overestimation.
w = sqrt.(abs.([int[i] - L2(G[i,:]) for i=1:N]));

In [None]:
# Z21 = (p/2) * ||inv(L) w||
Z¬≤¬π = p*norm(ùîè\w)/interval(2);

In [None]:
# Z12 = (p/2) * || |L A_n inv(L)| w || / lambda_m
Z¬π¬≤ = p*norm(abs.(ùîè*A‚Çô*inv(ùîè))*w)/Œª‚Çò/interval(2);

In [None]:
# sup-norm bound of the approximate solution
supœÜÃÑ = L‚àû(uÃÑ)
# Z22 = (|d/4 - omega/2| + p * sup / 2) / lambda_m
Z¬≤¬≤ = (abs.(interval(d//4)-œâ/interval(2))+p*supœÜÃÑ/interval(2))/interval(Œª‚Çò);

In [None]:
# display the 2x2 matrix of Z bounds
[Z¬π¬π Z¬π¬≤ ; Z¬≤¬π Z¬≤¬≤]

In [None]:
# Z1: spectral-norm bound of the 2x2 matrix of Z bounds
Z‚ÇÅ = op_norm([Z¬π¬π Z¬π¬≤ ; Z¬≤¬π Z¬≤¬≤])

In [None]:
# The validation requires Z1 < 1.
# NOTE(review): this only prints a message; the following cells still run on failure.
if sup(Z‚ÇÅ)>=1
    println("≈´ cannot be validated")
end

In [None]:
# Calculation of the L^2-norm of the Hessian of eta and of the sup of eta.
# NOTE(review): 0.56419 and 0.79789 are upward roundings of 1/sqrt(pi) and sqrt(2/pi);
# the provenance of 0.23033 is not shown here — confirm the constants against the paper.
C‚ÇÄ = I"0.56419"
C‚ÇÅ = I"0.79789"
C‚ÇÇ = I"0.23033"
C2 = sqrt(Z)*(interval(2//d)*C‚ÇÄ + sqrt(interval(2//d))*C‚ÇÅ + C‚ÇÇ)

In [None]:
# op_n = max(||L A_n inv(L)||, 1): norm bound for the full operator A
# (presumably A_n on the finite part and the identity on the tail — confirm)
op_n = interval(max(sup(op_norm(ùîè*A‚Çô*inv(ùîè))), 1))
# second- and third-order bounds entering the radii polynomial
Z‚ÇÇ = interval(3)*C2*supœÜÃÑ*op_n
Z‚ÇÉ = interval(3)*C2^2*op_n

In [None]:
# Positive root of P'(delta) = Z3/2 delta^2 + Z2 delta - (1 - Z1), for the radii
# polynomial P defined in the next cell
Œ¥ÃÑ = (-Z‚ÇÇ+sqrt(Z‚ÇÇ^2 +interval(2)*Z‚ÇÉ-interval(2)*Z‚ÇÅ*Z‚ÇÉ))/Z‚ÇÉ

In [None]:
# radii polynomial: P(delta) = Z3/6 delta^3 + Z2/2 delta^2 - (1 - Z1) delta + Y
P(Œ¥) = Z‚ÇÉ/interval(6)*Œ¥^3 + Z‚ÇÇ/interval(2)*Œ¥^2 - (interval(1) - Z‚ÇÅ)*Œ¥ + Y

In [None]:
# Candidate radius: the root Y/(1 - Z1) of the linear part of P, enlarged by the relative
# factor 1 + 2^-21. The radius is validated if P is negative there.
# NOTE(review): nothing is printed on failure, and the radius is not compared with the root
# computed two cells above — confirm whether the theorem requires that comparison.
Œ¥Ã≤ = Y/(interval(1)-Z‚ÇÅ)*(interval(1)+interval(BigFloat,2)^(-21))
if sup(P(Œ¥Ã≤))<0
    println("Œ¥Ã≤ is validated")
end

In [None]:
# sqrt(7 Z) * radius / 2 with Z = 4*pi from the parameter cell; presumably converts the
# validated radius into the final error bound — confirm against the paper
Œ∑ = sqrt(interval(BigFloat,7)*Z)*Œ¥Ã≤/interval(2)

# Plotting

In [None]:
# radii 0, 0.01, ..., 7 at BigFloat precision
rloc = big.(collect(0:700))/100;

# Radial profiles: row k+1 holds the frequency-k component of the approximate solution at
# the radii rloc, rebuilt from its coefficients in the generalised Laguerre basis.
# The loop uses its own names (envelope, normalization). At notebook top level, assigning
# `p` or `Z` here would overwrite the proof globals p (= 3) and Z (= 4*pi).
œÜ·µ£ = zeros((K‚ÇÄ, 701))
for k in 0:K‚ÇÄ-1
    envelope = ((rloc/2).^(2*k+1)).*exp.(-rloc.^2/8)
    for m=0:big(Ms[k+1])
        # generalised Laguerre polynomial L_m^(2k+1)(x) = sum_j (-1)^j binom(m+2k+1, m-j) x^j / j!
        L‚Çò = Polynomial([(-1)^j*big(binomial(m+2*k+1,m-j)//factorial(j)) for j=0:m])
        normalization = sqrt(big(factorial(m+2*k+1)//factorial(m)))
        œÜ·µ£[k+1,:] += mid(uÃÑ[indices[k+1]+m])*envelope.* L‚Çò.(rloc.^2/4)/normalization
    end
end

In [None]:
using Plots

In [None]:
# 101 angles spanning [0, 2*pi], and the angular factors cos((2k+1) angle) for
# k = 0..K0-1 (one column per frequency)
œë = big.(collect(0:100))/100*2*big(œÄ)
c = cos.(œë.*collect(1:2:2*K‚ÇÄ-1)')

In [None]:
# Surface plot of the approximate solution on the disc of radius 7 (the range of rloc).
# The sqrt(2) factor is presumably the normalisation of the cosine basis — confirm.
surface(rloc.*cos.(œë'),rloc.*sin.(œë'), sqrt(big(2))*(c*œÜ·µ£)', xlabel = L"$x$", ylabel = L"$y$", zlabel = L"$\overline{\varphi}\,(x, y)$", colorbar = false, dpi = 800)

In [None]:
# save the current figure as asymmetric_schrodinger.png
png("asymmetric_schrodinger")