In [1]:
using JuMP
using MosekTools
using DynamicPolynomials
using MultivariatePolynomials
using TSSOS
using LinearAlgebra, Random, Plots, Distributions, IterTools, Combinatorics, CSV, Statistics, MLDatasets, DataFrames, Revise, Clustering, Distances
includet("UnivariateModels.jl")

In [None]:
using MLDatasets

"""
    load_mnist_subset(digits; normalize=true)

Load the MNIST *training* images whose label is one of `digits` (e.g. `[1, 4, 7]`).

# Arguments
- `digits`: integer labels to keep; every entry must lie in `0:9`.

# Keywords
- `normalize=true`: guarantee pixel values in `[0, 1]`. Raw 0–255 intensities are divided
  by 255; data that the loader already delivers in `[0, 1]` are left untouched.
  With `normalize=false` the values are returned exactly as loaded.

# Returns
`(X, y)` where
- `X` is an `N × 784` `Float32` matrix (`N` = number of selected images); each row is a
  flattened 28×28 image,
- `y` is a `Vector{Int}` with the corresponding labels.

# Throws
- `ArgumentError` if any entry of `digits` is outside `0:9`.
"""
function load_mnist_subset(digits::AbstractVector{<:Integer}; normalize::Bool=true)
    # Explicit check instead of `@assert`: assertions may be disabled and must not be
    # relied on for validating caller input.
    all(d -> 0 <= d <= 9, digits) ||
        throw(ArgumentError("digits must be between 0 and 9, got $(digits)"))

    # NOTE(review): the element type of `imgs` depends on the MLDatasets version
    # (fixed-point / Float32 values already in [0, 1], or raw integer intensities) — confirm.
    imgs, labs = MNIST.traindata()          # imgs: 28×28×60000, labs: 60000 labels

    # Boolean mask of the images whose label was requested.
    mask = in.(labs, Ref(digits))

    # select and flatten
    sel = imgs[:, :, mask]                       # 28×28×N
    N = size(sel, 3)
    X = reshape(Float32.(sel), 28 * 28, N)'      # N×784 (adjoint of the 784×N data)

    # Rescale only when the data really are on the 0–255 scale; dividing values that are
    # already in [0, 1] would squash them into [0, 1/255]. `init` covers an empty selection.
    if normalize && maximum(X; init=0.0f0) > 1.0f0
        X ./= 255.0f0
    end
    y = Int.(labs[mask])

    return X, y
end


In [None]:
# Select the training images of digits 0, 1 and 2; `normalize=false` skips the division
# by 255 inside `load_mnist_subset`, so the values are used as the loader delivers them.
digs = [0, 1, 2]
X, y = load_mnist_subset(digs; normalize=false)
# Sanity checks: matrix shape, number of images per digit, pixel value range and the
# fraction of non-zero entries.
@show size(X)
@show map(d -> count(==(d), y), digs)
@show extrema(X), count(!=(0f0), X) / length(X)
# NOTE(review): this repeats the per-digit tally printed two lines above and additionally
# binds a global named `counts`; StatsBase (loaded in a later cell) appears to export a
# function with the same name — confirm the shadowing is intended.
@show counts = map(d -> sum(y .== d), digs)  # number of images per digit

In [None]:
# Polynomial variables; the names suggest the mean and the standard deviation of a
# mixture component (presumably — confirm against UnivariateModels.jl).
@polyvar m
@polyvar sigma

# One quadratic per variable. Each product is non-negative exactly between its two roots,
# i.e. for m in [0.0001, 1.0001] and sigma in [0.0001, 1.000].
# NOTE(review): the upper bounds differ (1.0001 vs 1.000) — confirm this is intended.
Sm = [(m - 0.0001) * (1.0001 - m)]
Ssig = [(sigma - 0.0001) * (1.000 - sigma)]
S = [Sm..., Ssig...]

println()
println("Support of the mixing measure")

# Rescale every polynomial so that its largest coefficient has magnitude 1.
S_normalized = map(S) do p
    p / maximum(abs, coefficients(p))
end
display(S_normalized)

In [None]:
# For each of the 784 columns of X, call univariate_SOS_model_Gaussian_W2 at every order
# 1..max_order. RESW2[dim][d] is whatever that call returns for column `dim`, order `d`.
# NOTE(review): the trailing arguments `true` and 0.00001 are passed through unchanged;
# their meaning is defined in UnivariateModels.jl — confirm there.
max_order = 4
RESW2 = Any[]
foreach(1:784) do dim
    hierarchy = Any[]
    for order in 1:max_order
        println("  d = $order")
        result = univariate_SOS_model_Gaussian_W2(
            order, m, sigma, S_normalized, X[:, dim], true, 0.00001
        )
        push!(hierarchy, result)
    end
    push!(RESW2, hierarchy)
end

In [None]:
# W2 runs with 0.1 as the last argument: one call to univariate_SOS_model_Gaussian_W2 per
# column of X (784 columns) and per order 1..max_order; results land in RESW2bige[dim][d].
# NOTE(review): what `true` and 0.1 control is defined in UnivariateModels.jl — confirm.
max_order = 4
RESW2bige = Any[]
for dim in 1:784
    relax = Any[
        begin
            println("  d = $d")
            univariate_SOS_model_Gaussian_W2(d, m, sigma, S_normalized, X[:, dim], true, 0.1)
        end for d in 1:max_order
    ]
    push!(RESW2bige, relax)
end

In [None]:
# Run analyse_relaxations on the stored W2 results of every dimension; A[i] is the value
# returned for RESW2[i].
# NOTE(review): the extra arguments 4 and 1 are forwarded as-is; 4 presumably matches
# max_order — confirm against analyse_relaxations.
A = Any[]
foreach(1:784) do i
    println("dimension = ", i)
    push!(A, analyse_relaxations(RESW2[i], 4, 1))
end

In [None]:
# Analyse the W2 results stored in RESW2bige, one dimension at a time; Abige[i] is the
# value analyse_relaxations returns for RESW2bige[i].
# NOTE(review): the meaning of the trailing arguments 4 and 1 is not visible here — confirm.
Abige = Any[]
for i in 1:784
    println("dimension = $i")
    analysis = analyse_relaxations(RESW2bige[i], 4, 1)
    push!(Abige, analysis)
end

In [None]:
using LinearAlgebra
using Plots
import StatsBase: countmap, mode  #


"""
    numeric_rank(M; reltol=1e-6)

Return the numerical rank of `M`: the number of eigenvalues of `Symmetric(Matrix(M))`
whose magnitude strictly exceeds `reltol` times the largest eigenvalue magnitude.

`M` is treated as symmetric (`Symmetric` reads only the upper triangle). An empty or
all-zero matrix has rank `0`.
"""
function numeric_rank(M; reltol=1e-6)
    A = Matrix(M)
    # Guard the reduction below: `maximum` throws on an empty collection.
    isempty(A) && return 0
    vals = eigvals(Symmetric(A))
    thr = reltol * maximum(abs, vals)
    # Predicate form avoids allocating the temporary `abs.(vals)`.
    return count(v -> abs(v) > thr, vals)
end


# Histogram of the estimated rank over all entries of A.
#
# `StatsBase.mode` is called qualified below, but `import StatsBase: countmap, mode` binds
# only those two functions and not the module name, so the module is imported here too.
import StatsBase

# One rank per dimension, computed by rank_by_energy from the first component of each
# analysis result (presumably a moment matrix — confirm against analyse_relaxations).
ranks = [rank_by_energy(A[i][1]; energy_tol=1e-6) for i in eachindex(A)]

# summary
minr, maxr = extrema(ranks)
mrank = StatsBase.mode(ranks)
cm = StatsBase.countmap(ranks)          # rank -> number of dimensions with that rank
xs = sort!(collect(keys(cm)))           # distinct ranks, ascending
ys = [cm[x] for x in xs]                # matching frequencies

println("min rank = $minr, max rank = $maxr, mode = $mrank")
for (k, n) in zip(xs, ys)
    println("rank $k : ", n)
end

# plot the sticks
plot(xs, ys;
     seriestype = :sticks,
     marker = :circle, ms = 5, lw = 3,
     xticks = xs,
     xlims = (first(xs) - 0.5, last(xs) + 0.5),
     xlabel = "Estimated order", ylabel = "Number of dimensions",
     legend = false)

# Offset the count labels upward (3% of the tallest stick, at least 5 counts) so they do
# not overlap the markers, then leave headroom above the tallest label.
dy = max(5, 0.03 * maximum(ys))
annotate!([(x, y + dy, text(string(y), 9, :center, :bottom)) for (x, y) in zip(xs, ys)]...)
ylims!(0, maximum(ys) + 4dy)

In [None]:
using LinearAlgebra
using Plots
import StatsBase: countmap, mode  

"""
    numeric_rank(M; reltol=1e-6)

Numerical rank via eigenvalues with a relative tolerance: count the eigenvalues of
`Symmetric(Matrix(M))` whose absolute value is strictly greater than `reltol` times the
largest absolute eigenvalue.

Only the upper triangle of `M` is read (`Symmetric` default). Returns `0` for an empty
matrix and for the zero matrix.
"""
function numeric_rank(M; reltol=1e-6)
    dense = Matrix(M)
    # `maximum` below throws on an empty collection, so settle the empty case first.
    if isempty(dense)
        return 0
    end
    vals = eigvals(Symmetric(dense))
    thr = reltol * maximum(abs, vals)
    # Count with a predicate rather than materialising `abs.(vals)`.
    return count(v -> abs(v) > thr, vals)
end


# Histogram of the estimated rank over all entries of Abige.
#
# The qualified call `StatsBase.mode` needs the module name itself in scope;
# `import StatsBase: countmap, mode` binds only the two functions, so import the module.
import StatsBase

# One rank per dimension, computed by rank_by_energy from the first component of each
# analysis result (presumably a moment matrix — confirm against analyse_relaxations).
ranks = [rank_by_energy(Abige[i][1]; energy_tol=1e-6) for i in eachindex(Abige)]

# summary
minr, maxr = extrema(ranks)
mrank = StatsBase.mode(ranks)           # qualified to avoid conflicts
cm = StatsBase.countmap(ranks)          # rank -> frequency
xs = sort!(collect(keys(cm)))           # distinct ranks, ascending
ys = [cm[x] for x in xs]                # matching frequencies

println("min rank = $minr, max rank = $maxr, mode = $mrank")
for (k, n) in zip(xs, ys)
    println("rank $k : ", n)
end

# plot the sticks
plot(xs, ys;
     seriestype = :sticks,
     marker = :circle, ms = 5, lw = 3,
     xticks = xs,
     xlims = (first(xs) - 0.5, last(xs) + 0.5),
     xlabel = "Estimated order", ylabel = "Number of dimensions",
     legend = false)

# offset labels upward to avoid overlap: 3% of max (at least 5 counts)
dy = max(5, 0.03 * maximum(ys))
annotate!([(x, y + dy, text(string(y), 9, :center, :bottom)) for (x, y) in zip(xs, ys)]...)
ylims!(0, maximum(ys) + 4dy)            # give headroom


In [35]:
using LinearAlgebra
using Plots
using StatsBase


In [None]:
# For every column of X (784 of them) call univariate_SOS_model_Gaussian_TV at each order
# 1..max_order; RESTV[dim][d] stores the value returned for column `dim` and order `d`.
# NOTE(review): the last two arguments (`true`, 0.00001) are defined by
# UnivariateModels.jl — confirm their meaning there.
max_order = 4
RESTV = Any[]
for dim in 1:784
    per_order = Any[]
    for d in 1:max_order
        println("  d = $d")
        solved = univariate_SOS_model_Gaussian_TV(
            d, m, sigma, S_normalized, X[:, dim], true, 0.00001
        )
        push!(per_order, solved)
    end
    push!(RESTV, per_order)
end

In [None]:
# TV runs with 0.1 as the last argument: univariate_SOS_model_Gaussian_TV is called once
# per column of X (784 columns) and per order 1..max_order; RESTVbige[dim][d] keeps the
# returned value.
# NOTE(review): what `true` and 0.1 control is defined in UnivariateModels.jl — confirm.
max_order = 4
RESTVbige = Any[]
for dim in 1:784
    relax = map(1:max_order) do d
        println("  d = $d")
        univariate_SOS_model_Gaussian_TV(d, m, sigma, S_normalized, X[:, dim], true, 0.1)
    end
    # Keep the untyped container the downstream cells were written against.
    push!(RESTVbige, convert(Vector{Any}, relax))
end

In [None]:
# Analyse the stored TV results of each dimension; B[i] is what analyse_relaxations
# returns for RESTV[i].
# NOTE(review): the trailing arguments 4 and 1 are forwarded unchanged; 4 presumably
# matches max_order — confirm against analyse_relaxations.
B = Any[]
for i in 1:784
    println("dimension = $i")
    push!(B, analyse_relaxations(RESTV[i], 4, 1))
end
#extract_CF(TW2[1], TW2[end], size(TW2[2],1), 1, 1)

In [None]:
# Run analyse_relaxations on the TV results stored in RESTVbige, one dimension at a time;
# Bbige[i] holds the value returned for RESTVbige[i].
# NOTE(review): the meaning of the trailing arguments 4 and 1 is not visible here — confirm.
Bbige = Any[]
foreach(1:784) do i
    println("dimension = ", i)
    outcome = analyse_relaxations(RESTVbige[i], 4, 1)
    push!(Bbige, outcome)
end
#extract_CF(TW2[1], TW2[end], size(TW2[2],1), 1, 1)

In [None]:
#ranks = [numeric_rank(B[i][1]; reltol=1e-3) for i in eachindex(B)]
# Estimated rank per dimension: rank_by_energy applied to the first component of every
# entry of B (presumably a moment matrix — confirm against analyse_relaxations).
ranks = [rank_by_energy(res[1]; energy_tol=1e-6) for res in B]

# Summary statistics and the frequency table rank -> number of dimensions.
minr, maxr = extrema(ranks)
mrank = StatsBase.mode(ranks)
cm = countmap(ranks)
xs = sort(collect(keys(cm)))      # distinct ranks, ascending
ys = map(r -> cm[r], xs)          # matching frequencies

println("min rank = $minr, max rank = $maxr, mode = $mrank")
foreach(k -> println("rank $k : ", cm[k]), xs)

# Stick plot: one stick per observed rank, height = number of dimensions.
plot(
    xs,
    ys;
    seriestype=:sticks,
    marker=:circle,
    ms=5,
    lw=3,
    xticks=xs,
    xlims=(minimum(xs) - 0.5, maximum(xs) + 0.5),
    xlabel="Estimated order",
    ylabel="Number of dimensions",
    legend=false,
)

# Lift the count labels by 3% of the tallest stick (at least 5 counts) to avoid overlap.
dy = max(5, 0.03 * maximum(ys))
annotate!([(r, c + dy, text(string(c), 9, :center, :bottom)) for (r, c) in zip(xs, ys)]...)
ylims!(0, maximum(ys) + 4dy)      # headroom above the labels


In [None]:
#ranks = [numeric_rank(B[i][1]; reltol=1e-3) for i in eachindex(B)]
# Estimated rank per dimension: rank_by_energy applied to the first component of every
# entry of Bbige (presumably a moment matrix — confirm against analyse_relaxations).
ranks = map(eachindex(Bbige)) do i
    rank_by_energy(Bbige[i][1]; energy_tol=1e-6)
end

# Frequency table first, then the summary line.
cm = countmap(ranks)                  # rank -> frequency
xs = sort!(collect(keys(cm)))         # distinct ranks, ascending
ys = [cm[r] for r in xs]              # matching frequencies
minr = minimum(ranks)
maxr = maximum(ranks)
mrank = StatsBase.mode(ranks)         # qualified to avoid conflicts

println("min rank = $minr, max rank = $maxr, mode = $mrank")
for (k, freq) in zip(xs, ys)
    println("rank $k : ", freq)
end

# Stick plot of the rank frequencies.
plot(xs, ys;
     seriestype = :sticks,
     marker = :circle,
     ms = 5,
     lw = 3,
     xticks = xs,
     xlims = (minimum(xs) - 0.5, maximum(xs) + 0.5),
     xlabel = "Estimated order",
     ylabel = "Number of dimensions",
     legend = false)

# Offset the labels upward (3% of the maximum, at least 5 counts) to avoid overlap.
dy = max(5, 0.03 * maximum(ys))
annotate!([(r, c + dy, text(string(c), 9, :center, :bottom)) for (r, c) in zip(xs, ys)]...)
ylims!(0, maximum(ys) + 4dy)          # give headroom
