In [151]:
using JuMP
using MosekTools
using DynamicPolynomials
using MultivariatePolynomials
using TSSOS, GaussianMixtures
using LinearAlgebra, Random, Plots, Distributions, IterTools, Combinatorics, CSV, Statistics, MLDatasets, DataFrames, Revise, Clustering, Distances, Colors
includet("Functions_Mixtures.jl")
includet("SeparationExperiment.jl")

### Locally necessary functions

In [273]:
"""
    find_nearest_indices(X, centers)

For every column of `centers`, return the row index of the point of `X` that is
closest to it in squared Euclidean distance.

# Arguments
- `X`: `N × d` matrix, one data point per row.
- `centers`: `d × r` matrix, one center per column. Any `AbstractMatrix` is
  accepted, so a lazily transposed matrix (`A'`) can be passed directly.

# Returns
A `Vector{Int}` of length `r`. Ties go to the smallest row index (`argmin`).

# Throws
- `DimensionMismatch` if the points and the centers do not have the same dimension.
"""
function find_nearest_indices(X::AbstractMatrix{<:Real}, centers::AbstractMatrix{<:Real})
    # Without this check a 1-row `centers` would silently broadcast against the points.
    size(X, 2) == size(centers, 1) || throw(DimensionMismatch(
        "points have dimension $(size(X, 2)) but centers have dimension $(size(centers, 1))"))

    indices = Int[]
    sizehint!(indices, size(centers, 2))
    for c in eachcol(centers)
        dists = [sum(abs2, view(X, i, :) .- c) for i in axes(X, 1)]
        push!(indices, argmin(dists))
    end
    return indices
end
    using StatsBase  # for countmap, combinations

"""
    adjusted_rand_index(labels_true, labels_pred)

Compute the Adjusted Rand Index (ARI) between two labelings of the same samples.

The value is `1.0` for identical partitions (up to a renaming of the labels),
close to `0` for independent partitions, and can be negative.

# Arguments
- `labels_true`: reference label of each sample.
- `labels_pred`: predicted label of each sample (same length).

# Returns
A `Float64`. When the index is otherwise undefined (fewer than two samples, or
both partitions are a single cluster / all singletons) the two partitions are
necessarily identical and `1.0` is returned instead of `NaN`.

# Throws
- `DimensionMismatch` if the two label vectors differ in length.
"""
function adjusted_rand_index(labels_true::AbstractVector{<:Integer},
                             labels_pred::AbstractVector{<:Integer})
    n = length(labels_true)
    n == length(labels_pred) || throw(DimensionMismatch(
        "labels_true has length $n but labels_pred has length $(length(labels_pred))"))

    # With fewer than two samples there are no pairs to compare.
    n < 2 && return 1.0

    # Map each distinct label to its row / column of the contingency table.
    row_of = Dict(l => i for (i, l) in enumerate(sort(unique(labels_true))))
    col_of = Dict(l => j for (j, l) in enumerate(sort(unique(labels_pred))))

    # Contingency table: n_ij = number of samples with true label i and predicted label j
    contingency = zeros(Int, length(row_of), length(col_of))
    for (l_true, l_pred) in zip(labels_true, labels_pred)
        contingency[row_of[l_true], col_of[l_pred]] += 1
    end

    # Number of unordered pairs among x items (0 for x = 0 or 1)
    comb2(x) = (x * (x - 1)) ÷ 2

    index = sum(comb2, contingency)                    # ∑_ij C(n_ij, 2)
    sum_ai = sum(comb2, sum(contingency; dims=2))      # ∑_i C(a_i, 2), true cluster sizes
    sum_bj = sum(comb2, sum(contingency; dims=1))      # ∑_j C(b_j, 2), predicted cluster sizes

    # Multiply in floating point so the product cannot overflow for large n.
    expected_index = float(sum_ai) * sum_bj / comb2(n)
    max_index = (sum_ai + sum_bj) / 2

    # The denominator vanishes only when both partitions are trivial (and hence equal).
    denominator = max_index - expected_index
    denominator == 0 && return 1.0

    return (index - expected_index) / denominator
end
"""
    align_labels(true_labels, pred_labels)

Relabel `pred_labels` so that it agrees with `true_labels` on as many samples as
possible.

Every permutation of the distinct values of `true_labels` is tried as a
relabeling of `pred_labels`; the first permutation reaching the highest number
of matches wins. Each predicted label is looked up in a dictionary keyed by the
true label set, so a predicted label outside that set raises a `KeyError`.

Returns a new vector with the relabeled predictions.
"""
function align_labels(true_labels::Vector{Int}, pred_labels::Vector{Int})
    label_set = sort(unique(true_labels))

    top_hits = -1
    winner = similar(pred_labels)

    for candidate in permutations(label_set)
        relabel = Dict(zip(label_set, candidate))
        remapped = map(l -> relabel[l], pred_labels)
        hits = count(remapped .== true_labels)

        # Strict comparison: on ties the earliest permutation is kept.
        if hits > top_hits
            top_hits = hits
            copyto!(winner, remapped)
        end
    end

    return winner
end

using GaussianMixtures 
using DataFrames
using DataFrames: groupby
using StatsBase       
using StatsPlots 

# Pick `k` distinct rows of `X` (sampling without replacement) and return them
# as a matrix, one selected row per row. Both `X` and `k` are read from the
# enclosing (global) scope, so they must be defined before calling.
function random_means()
  picked_rows = sample(1:size(X, 1), k; replace=false)
  return X[picked_rows, :]
end

"""
    run_em(μ0; X, nIter=100, tol=1e-6, varfloor=0.0001)

Fit a diagonal-covariance Gaussian mixture to `X` with EM, starting from the
means `μ0`.

# Arguments
- `μ0`: `k × d` matrix of initial means, one component per row (an `Adjoint` is
  materialized into a plain `Matrix`).

# Keywords
- `X`: data matrix passed to `GMM` and `em!`; its per-column variances give the
  initial variance of every component.
- `nIter`: number of EM iterations requested from `em!`.
- `tol`: threshold on the absolute change between consecutive entries of the
  log-likelihood history.
- `varfloor`: forwarded to `em!`.

# Returns
A named tuple `(iterations, final_ll, logl, model)` where
- `logl` is the value returned by `em!` (assumed to be the per-iteration
  log-likelihood history — confirm against GaussianMixtures),
- `iterations` is the first index `i ≥ 2` with `|logl[i] - logl[i-1]| < tol`, or
  `length(logl)` if the change never drops below `tol`,
- `final_ll` is `logl[end]`,
- `model` is the fitted `GMM`.
"""
function run_em(μ0; X, nIter=100, tol=1e-6, varfloor=0.0001)
    # 1) materialize any Adjoint into a real matrix
    μ0 = Matrix(μ0)
    k = size(μ0, 1)

    # 2) data-based diagonal variances, repeated for every component (k×d)
    σ2 = vec(var(X; dims=1))
    Σ0 = repeat(σ2', k, 1)

    # 3) build the GMM container (nInit=0 is meant to skip the k-means initialisation)
    g = GMM(k, X; kind=:diag, nInit=0)

    # 4) overwrite the initial parameters: given means, shared variances, uniform weights
    g.μ .= μ0
    g.Σ .= Σ0
    g.w .= 1 / k

    # 5) run EM for nIter iterations, with varfloor control
    logl = em!(g, X; nIter=nIter, varfloor=varfloor)

    # 6) first iteration at which |Δℓ| < tol. `findfirst` on the range
    #    2:length(logl) returns a position in that range, so the matching
    #    iteration index is that position plus one.
    hit = findfirst(i -> abs(logl[i] - logl[i - 1]) < tol, 2:length(logl))
    iters = hit === nothing ? length(logl) : hit + 1

    return (
      iterations = iters,
      final_ll   = logl[end],
      logl       = logl,
      model      = g,
    )
end

run_em (generic function with 1 method)

## K=5, relatively well separated, non-spherical

#### Data generation

In [165]:
# Parameters of the synthetic benchmark. `c` and `ecc` are forwarded as keywords
# to `generate_multiple_gmms_heteroscedastic` (defined outside this cell);
# presumably a separation constant and an eccentricity — TODO confirm there.
c = 5.0
ecc= 0.25
K=5   # second positional argument of the generator (number of components per the section title)
n=2   # third positional argument of the generator; also passed as `n` to the relaxations below
nb_parameter_choices = 50   # how many mixtures to generate
seed_parameters=1
Random.seed!(seed_parameters)   # seed the global RNG before drawing the mixture parameters
gmms_50_025_5 = generate_multiple_gmms_heteroscedastic(nb_parameter_choices, K, n; ecc=ecc, c=c);

In [167]:
# Draw N samples from every generated mixture and keep, per mixture, the sample
# matrix (stored transposed) and the component label of each sample.
N = 1000
seed_data = 10
Random.seed!(seed_data)

all_data = Vector{Matrix{Float64}}(undef, length(gmms_50_025_5))
all_labels = Vector{Vector{Int}}(undef, length(gmms_50_025_5))

for (mix_index, mix) in enumerate(gmms_50_025_5)
    # One mean vector per column of `mix.means`
    means = [mix.means[:, col] for col in axes(mix.means, 2)]
    k = length(means)

    # The sampler gets its own seed per mixture: base seed shifted by the index
    samples, labels = generate_gaussian_mixtures(
        k, means, mix.covariances, mix.weights;
        seed=seed_data + mix_index, n_samples=N,
    )

    all_data[mix_index] = samples'
    all_labels[mix_index] = labels
end

In [169]:
# Build one scatter plot per mixture (rescaled samples coloured by label, with
# the per-label empirical means overlaid) and collect them in `plot_list`.
n_configs = length(all_data)
# Grid shape; only referenced by the commented-out `plot(...)` call after this loop
n_rows = 10
n_cols = 5

plot_list = []

for i in 1:n_configs
    # all_data stores each sample matrix transposed; transpose back before use
    samples = all_data[i]'
    labels = all_labels[i]
    
    # Normalize: rescale each column with `scale_to_minus1_1` (defined elsewhere)
    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)
    
    # Optional: compute empirical mean per label (each entry is a 1×d row matrix)
    unique_labels = sort(unique(labels))
    means_per_label = [mean(samples_scaled[labels .== l, :], dims=1) for l in unique_labels]

    # Scatter plot of the first two coordinates, one colour per label
    p = scatter(samples_scaled[:, 1], samples_scaled[:, 2],
                group=labels,
                markersize=2, alpha=0.6, legend=false,
                #xlabel="x₁", ylabel="x₂",
                title="Mixture $i", xlims=(-0.2, 1.2), ylims=(-0.2, 1.2))

    # Overlay cluster means
    for m in means_per_label
        scatter!(p, [m[1]], [m[2]], color=:yellow, marker=:circle, ms=6)
    end

    push!(plot_list, p)
end

# Display grid layout
#plot(plot_list..., layout=(n_rows, n_cols), size=(1400, 2400))

#### $S_{\theta}$ description

In [None]:
# Polynomial description of the parameter set: one polynomial per coordinate of
# the mean `m` and of `sigma`. Each polynomial is nonnegative exactly on an
# interval: m[i] ∈ [0, 1] and sigma[i] ∈ [0.05, 0.5].
# NOTE(review): presumably these are used as "≥ 0" constraints by the
# relaxations that receive `S_normalized` — confirm in Functions_Mixtures.jl.
n=2
@polyvar m[1:n]
@polyvar sigma[1:n]
Sm=[(m[1])*(1-m[1]), (m[2])*(1-m[2])]
Ssig=[(sigma[1]-0.05)*(0.5-sigma[1]), (sigma[2]-0.05)*(0.5-sigma[2])]
# Concatenate the mean and sigma polynomials into one list
S=[vcat(Sm)...,vcat(Ssig)...]
println()
println("Support of the mixing measure")
# Divide every polynomial by its largest absolute coefficient
S_normalized=[S[i]/maximum(abs.(coefficients(S[i]))) for i=1:length(S)]
display(S_normalized)

#### W2

In [None]:
# Solve the W2 relaxations of orders 1..max_order for every mixture and keep
# the list of per-order results for each mixture in `all_results_NS`.
trace_penalization = true
vareps = 1e-3
max_order = 4

all_results_NS = Vector{Vector}(undef, length(all_data))

for (idx, stored) in enumerate(all_data)
    println("\n>>> Running on GMM config $idx")

    # Stored matrices are transposed; bring them back and rescale column-wise
    samples = permutedims(stored)
    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)

    res = []
    for d in 1:max_order
        println("  d = $d")
        relaxation = multivariate_Gaussian_W2(n, d, m, sigma, S_normalized,
                                              samples_scaled, trace_penalization, vareps)
        push!(res, relaxation)
    end

    all_results_NS[idx] = res
end

In [None]:
# Analyse the W2 relaxations of every mixture at the highest order and collect
# the outcome of `analyse_relaxations_W2orTV` per mixture in `RES_NS`.
d = max_order
energy_tol = 1e-4
RES_NS = []

for i in eachindex(all_data)
    println(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Flatness check for mixture $i")
    analysis = analyse_relaxations_W2orTV(all_results_NS[i], d, n, energy_tol)
    push!(RES_NS, analysis)
    println("___________________________________________________________________")
    println()
end

#### TV

In [None]:
# Solve the TV relaxations of orders 1..max_order for every mixture and keep
# the list of per-order results for each mixture in `all_resultsTV_NS`.
trace_penalization = true
vareps = 1e-3
max_order = 4

all_resultsTV_NS = Vector{Vector}(undef, length(all_data))

for (idx, stored) in enumerate(all_data)
    println("\n>>> Running on GMM config $idx")

    # Stored matrices are transposed; bring them back and rescale column-wise
    samples = permutedims(stored)
    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)

    resTV = []
    for d in 1:max_order
        println("  d = $d")
        relaxation = multivariate_Gaussian_TV(n, d, m, sigma, S_normalized,
                                              samples_scaled, trace_penalization, vareps)
        push!(resTV, relaxation)
    end

    all_resultsTV_NS[idx] = resTV
end

In [None]:
# Analyse the TV relaxations of every mixture at the highest order and collect
# the outcome of `analyse_relaxations_W2orTV` per mixture in `RESTV_NS`.
d = max_order
energy_tol = 1e-3
RESTV_NS = []

for i in eachindex(all_data)
    println(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Flatness check for mixture $i")
    analysis = analyse_relaxations_W2orTV(all_resultsTV_NS[i], d, n, energy_tol)
    push!(RESTV_NS, analysis)
    println("___________________________________________________________________")
    println()
end

#### Extraction

In [None]:
# Extract candidate component means from the W2 relaxation of every mixture
# (via `extract_CF`), store them next to the empirical per-label means, and
# build one diagnostic scatter plot per mixture.
rank_mm = 5     # passed as the last argument of `extract_CF`
d = max_order  # final relaxation order you used earlier

plot_list = []

CF_points_NS=[]   # per mixture: matrix with one extracted point per column
true_means=[]     # per mixture: matrix with one empirical label mean per column
for i in 1:length(all_data)
    
    # Get samples and labels
    samples = all_data[i]' |> Matrix  # now shape is 1000 × 2
    labels = all_labels[i]

    # Extract RES for this config
    M = RES_NS[i][1]  # moment matrix
    L = RES_NS[i][end]  # last submatrix (or adjust as needed)

    # Curto-Fialkow flat extension extraction
    # (third argument: binomial(n + d - 1, d - 1), computed from the dimension and order)
    curto_f_points_NS = extract_CF(M, L, binomial(n + d - 1, d - 1), n, rank_mm)
    # Order the points by their first coordinate, then stack them as columns
    sorted_cf_NS = sort(curto_f_points_NS, by = p -> p[1])
    CF_NS = hcat(sorted_cf_NS...)  # 2 × r matrix
    push!(CF_points_NS,CF_NS)

    # Normalize samples
    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)

    # Empirical means from true labels, ordered by first coordinate like the CF points
    cluster_means = [vec(mean(samples_scaled[labels .== k, :], dims=1)) for k in sort(unique(labels))]
    sorted_means = sort(cluster_means, by = m -> m[1])
    trum = hcat(sorted_means...)
    push!(true_means, trum)

    

    # Plotting
    p = scatter(samples_scaled[:,1], samples_scaled[:,2],
                group=labels, markersize=2, alpha=0.6, legend=false,
                title="Mixture $i in ℝ²",
                #xlabel="X₁", ylabel="X₂", 
                xlims=(-0.2, 1.2), ylims=(-0.2, 1.2))
    
    # Yellow circles: empirical means; white squares: extracted points.
    # The labels are not shown because the plot is created with legend=false.
    scatter!(p, trum[1, :], trum[2, :], marker=(:o, 8), color=:yellow, label="True Means")
    scatter!(p, CF_NS[1, :], CF_NS[2, :], marker=(:s, 6), color=:white, label="CDK Means")

    push!(plot_list, p)
end

#plot(plot_list..., layout=(10, 5), size=(1400, 2400))

In [241]:
# Same extraction as for the W2 relaxation, applied to the TV results
# (`RESTV_NS`): extracted points go to `CF_pointsTV_NS`, plots to `plot_listTV_NS`.
# `true_means` is rebuilt from scratch here.
rank_mm = 5     # passed as the last argument of `extract_CF`
d = max_order  # final relaxation order you used earlier

plot_listTV_NS = []

CF_pointsTV_NS=[]   # per mixture: matrix with one extracted point per column
true_means=[]       # per mixture: matrix with one empirical label mean per column
for i in 1:length(all_data)
    
    # Get samples and labels
    samples = all_data[i]' |> Matrix  # now shape is 1000 × 2
    labels = all_labels[i]

    # Extract RES for this config
    M = RESTV_NS[i][1]  # moment matrix
    L = RESTV_NS[i][end]  # last submatrix (or adjust as needed)

    # Curto-Fialkow flat extension extraction
    # (third argument: binomial(n + d - 1, d - 1), computed from the dimension and order)
    curto_f_pointsTV_NS = extract_CF(M, L, binomial(n + d - 1, d - 1), n, rank_mm)
    # Order the points by their first coordinate, then stack them as columns
    sorted_cfTV_NS = sort(curto_f_pointsTV_NS, by = p -> p[1])
    CFTV_NS = hcat(sorted_cfTV_NS...)  # 2 × r matrix
    push!(CF_pointsTV_NS,CFTV_NS)

    # Normalize samples
    samples_scaled = hcat(scale_to_minus1_1.(eachcol(samples))...)

    # Empirical means from true labels, ordered by first coordinate like the CF points
    cluster_means = [vec(mean(samples_scaled[labels .== k, :], dims=1)) for k in sort(unique(labels))]
    sorted_means = sort(cluster_means, by = m -> m[1])
    trum = hcat(sorted_means...)
    push!(true_means, trum)

    

    # Plotting
    p = scatter(samples_scaled[:,1], samples_scaled[:,2],
                group=labels, markersize=2, alpha=0.6, legend=false,
                title="Mixture $i in ℝ²",
                #xlabel="X₁", ylabel="X₂", 
                xlims=(-0.2, 1.2), ylims=(-0.2, 1.2))
    
    # Yellow circles: empirical means; white squares: extracted points.
    # The labels are not shown because the plot is created with legend=false.
    scatter!(p, trum[1, :], trum[2, :], marker=(:o, 8), color=:yellow, label="True Means")
    scatter!(p, CFTV_NS[1, :], CFTV_NS[2, :], marker=(:s, 6), color=:white, label="CDK Means")

    push!(plot_listTV_NS, p)
end

#plot(plot_list..., layout=(10, 5), size=(1400, 2400))

#### Impact on $k$-means

In [245]:
# Compare k-means seeded from the Curto–Fialkow points (W2 and TV relaxations)
# with k-means seeded from random data points, mixture by mixture. For every
# run we record the objective, the iteration count, the ARI against the true
# labels and the number of misclassified samples after label alignment.
results_NS = []
Random.seed!(123)

# Number of random restarts per mixture. It has its own name so that the global
# `N` (sample size used for data generation) is not rebound from inside the loop.
n_random_restarts = 100

# Run k-means on the columns of `Xt`, seeded with the data points whose indices
# are `seed_indices`, and score the resulting partition against `labels`.
function _kmeans_summary(Xt, k, seed_indices, labels)
    result = kmeans(Xt, k; init=seed_indices, maxiter=100, display=:none)
    return (
        obj  = result.totalcost,
        iter = result.iterations,
        ARI  = adjusted_rand_index(labels, result.assignments),
        mis  = sum(align_labels(labels, result.assignments) .!= labels),
    )
end

# Return the centers as a plain matrix with one center per column (2 × r).
# The transposed case is materialized so that callees typed on `Matrix` accept it.
_as_column_centers(C) = size(C, 1) == 2 ? C : Matrix(C')

for i in eachindex(all_data)
    # Get mixture data & labels
    X = all_data[i]' |> Matrix
    X_scaled = hcat(scale_to_minus1_1.(eachcol(X))...)  # Normalization
    labels = all_labels[i]
    k = length(unique(labels))
    Xt = X_scaled'   # kmeans expects one sample per column

    # --- CF initialization (W2): seed with the data points nearest to the CF points ---
    seeds_w2 = find_nearest_indices(X_scaled, _as_column_centers(CF_points_NS[i]))
    run_w2 = _kmeans_summary(Xt, k, seeds_w2, labels)

    # --- CF initialization (TV) ---
    seeds_tv = find_nearest_indices(X_scaled, _as_column_centers(CF_pointsTV_NS[i]))
    run_tv = _kmeans_summary(Xt, k, seeds_tv, labels)

    # --- Random initialization, repeated n_random_restarts times ---
    objs_rnd = Float64[]
    iters_rnd = Int[]
    ARIs_rnd = Float64[]
    mis_rnd = Int[]

    for rep in 1:n_random_restarts
        # Distinct seed points: sampling with replacement could pick the same
        # point twice and start k-means with duplicated centers.
        rand_indices = sample(1:size(X_scaled, 1), k; replace=false)
        run_rnd = _kmeans_summary(Xt, k, rand_indices, labels)
        push!(objs_rnd, run_rnd.obj)
        push!(iters_rnd, run_rnd.iter)
        push!(ARIs_rnd, run_rnd.ARI)
        push!(mis_rnd, run_rnd.mis)
    end

    # Store summary (including the full per-restart lists)
    push!(results_NS, (
        i = i,
        obj_cf_NS = run_w2.obj,
        iter_cf_NS = run_w2.iter,
        ARI_cf_NS = run_w2.ARI,
        mis_cf_NS = run_w2.mis,
        obj_cfTV_NS = run_tv.obj,
        iter_cfTV_NS = run_tv.iter,
        ARI_cfTV_NS = run_tv.ARI,
        mis_cfTV_NS = run_tv.mis,
        objs_rnd = objs_rnd,
        iters_rnd = iters_rnd,
        ARIs_rnd = ARIs_rnd,
        mis_rnd = mis_rnd,
        mean_obj_rnd = mean(objs_rnd),
        mean_iter_rnd = mean(iters_rnd),
        mean_ARI_rnd = mean(ARIs_rnd),
        mean_mis_rnd = mean(mis_rnd),
        std_iter_rnd = std(iters_rnd)  # precomputed for plotting
    ))
end

In [None]:
# Extract data
iter_cf_NS = [r.iter_cf_NS for r in results_NS]
iter_cfTV_NS = [r.iter_cfTV_NS for r in results_NS]
iter_rnd_mean = [mean(r.iters_rnd) for r in results_NS]
iter_rnd_std = [std(r.iters_rnd) for r in results_NS]

# Sort by CF iterations
sort_idx_NS = sortperm(iter_cf_NS)
iter_cf_sorted_NS = iter_cf_NS[sort_idx_NS]
iter_rnd_mean_sorted = iter_rnd_mean[sort_idx_NS]
iter_rnd_std_sorted = iter_rnd_std[sort_idx_NS]
iter_cfTV_sorted_NS=iter_cfTV_NS[sort_idx_NS]

mix_ids_sorted = 1:length(results_NS)  # new x-axis = sorted mixture index

# Plot CF line
plot(mix_ids_sorted, iter_cf_sorted_NS;
    label="Curto-Fialkow-W2",
    lw=2, marker=:circle,
    xlabel="Mixture (sorted by CF)", ylabel="Iterations",
    #title="K-means Iterations: CF vs Random Initialization",
    legend=:topright,
    size=(900, 600))

plot!(mix_ids_sorted, iter_cfTV_sorted_NS;
    label="Curto-Fialkow-TV",
    lw=2, marker=:diamond,
    color=:red
   )


# Add Random line with ribbon
plot!(mix_ids_sorted, iter_rnd_mean_sorted;
    ribbon=iter_rnd_std_sorted,
    label="Random ± 1 std",
    lw=2, marker=:square,
   color=:orange,
)

In [None]:
# Extract misclassification data
mis_cf_NS = [r.mis_cf_NS for r in results_NS]
mis_cfTV_NS = [r.mis_cfTV_NS for r in results_NS]
mis_rnd_mean = [mean(r.mis_rnd) for r in results_NS]
mis_rnd_std = [std(r.mis_rnd) for r in results_NS]


# Sort by CF misclassification
sort_idx_NS = sortperm(mis_cf_NS)
mis_cf_sorted_NS = mis_cf_NS[sort_idx_NS]
mis_cfTV_sorted_NS = mis_cfTV_NS[sort_idx_NS]
mis_rnd_mean_sorted = mis_rnd_mean[sort_idx_NS]
mis_rnd_std_sorted = mis_rnd_std[sort_idx_NS]

mix_ids_sorted = 1:length(results_NS)

# Plot CF misclassifications
plot(mix_ids_sorted, mis_cf_sorted_NS;
    label="Curto_Fialkow",
    lw=2, marker=:circle,
    xlabel="Mixture (sorted by CF)", ylabel="Misclassifications",
    title="Misclassifications: CF vs Random Initialization",
    legend=:topright,
    size=(900, 600))

plot!(mix_ids_sorted, mis_cfTV_sorted_NS;
    label="Curto-Fialkow-TV",
    lw=2, marker=:diamond,
    color=:red
   )


# Add Random line with ribbon
plot!(mix_ids_sorted, mis_rnd_mean_sorted;
    ribbon=mis_rnd_std_sorted,
    label="Random ± 1 std",
    lw=2, marker=:square,
   color=:orange,
)

In [None]:
# TikZ export of the k-means iteration comparison (CF-W2, CF-TV, random restarts).
using Plots, LaTeXStrings, PGFPlotsX, Statistics
pgfplotsx()  # TikZ backend

# --- data (NS) ---
iter_cf_NS        = [r.iter_cf_NS   for r in results_NS]
iter_cfTV_NS      = [r.iter_cfTV_NS for r in results_NS]
iter_rnd_mean_NS  = [mean(r.iters_rnd) for r in results_NS]
iter_rnd_std_NS   = [std(r.iters_rnd)  for r in results_NS]

# Sort by CF iterations (NS)
sort_idx_NS              = sortperm(iter_cf_NS)
iter_cf_sorted_NS        = iter_cf_NS[sort_idx_NS]
iter_cfTV_sorted_NS      = iter_cfTV_NS[sort_idx_NS]
iter_rnd_mean_sorted_NS  = iter_rnd_mean_NS[sort_idx_NS]
iter_rnd_std_sorted_NS   = iter_rnd_std_NS[sort_idx_NS]
mix_ids_sorted_NS        = 1:length(results_NS)   # x-axis: rank in the sorted order

# --- styling to match papers ---
# markersize is set explicitly, like in the other paper-style figures, so that
# all exported figures use the same marker size whatever the cell order.
default(
    size=(420,280),
    grid=false,
    framestyle=:box,
    legend=:topright,
    linewidth=2,
    markerstrokewidth=0.8,
    markersize=3,
)

plt = plot(mix_ids_sorted_NS, iter_cf_sorted_NS;
    label=L"\text{Curto–Fialkow–W2}",
    marker=:circle,
    xlabel=L"\text{Mixture (sorted by CF)}",
    ylabel=L"\text{Iterations}",
)

plot!(plt, mix_ids_sorted_NS, iter_cfTV_sorted_NS;
    label=L"\text{Curto–Fialkow–TV}",
    marker=:diamond,
)

# Random restarts: mean curve with a ± 1 standard deviation ribbon
plot!(plt, mix_ids_sorted_NS, iter_rnd_mean_sorted_NS;
    ribbon=iter_rnd_std_sorted_NS,
    label=L"\text{Random} \pm \text{ 1 std}",
    marker=:square,
)

savefig(plt, "iterations_K5kmeans.tikz")

In [None]:
# TikZ export of the k-means misclassification comparison
# (CF-W2, CF-TV, random restarts); written to "misclassifications_K5kmeans.tikz".
using Plots, LaTeXStrings, PGFPlotsX, Statistics
pgfplotsx()  # TikZ backend

# --- data (NS) ---
mis_cf_NS        = [r.mis_cf_NS    for r in results_NS]
mis_cfTV_NS      = [r.mis_cfTV_NS  for r in results_NS]
mis_rnd_mean_NS  = [mean(r.mis_rnd) for r in results_NS]
mis_rnd_std_NS   = [std(r.mis_rnd)  for r in results_NS]

# Sort by CF misclassification (NS)
sort_idx_NS               = sortperm(mis_cf_NS)
mis_cf_sorted_NS          = mis_cf_NS[sort_idx_NS]
mis_cfTV_sorted_NS        = mis_cfTV_NS[sort_idx_NS]
mis_rnd_mean_sorted_NS    = mis_rnd_mean_NS[sort_idx_NS]
mis_rnd_std_sorted_NS     = mis_rnd_std_NS[sort_idx_NS]
# x-axis: rank of the mixture in the sorted order
mix_ids_sorted_NS         = 1:length(results_NS)

# --- styling: process-wide Plots defaults, which persist for later plots ---
default(
    size=(420,280),
    grid=false,
    framestyle=:box,
    legend=:topright,
    linewidth=2,
    markerstrokewidth=0.8,
    markersize=3,
)

plt = plot(mix_ids_sorted_NS, mis_cf_sorted_NS;
    label=L"\text{Curto–Fialkow–W2}",
    marker=:circle,
    xlabel=L"\text{Mixture (sorted by CF)}",
    ylabel=L"\text{Misclassifications}",
)

plot!(plt, mix_ids_sorted_NS, mis_cfTV_sorted_NS;
    label=L"\text{Curto–Fialkow–TV}",
    marker=:diamond,
)

# Random restarts: mean curve with a ± 1 standard deviation ribbon
plot!(plt, mix_ids_sorted_NS, mis_rnd_mean_sorted_NS;
    ribbon=mis_rnd_std_sorted_NS,
    label=L"\text{Random} \pm \text{ 1 std}",
    marker=:square,
)

savefig(plt, "misclassifications_K5kmeans.tikz")

#### EM

In [None]:
 # for results and plotting

# Compare EM started from the Curto–Fialkow points (W2 and TV) with EM started
# from randomly chosen data points; one row of `results` per EM run.
Random.seed!(1)

nrand = 100   # random restarts per dataset
nIter = 100   # EM iterations per run
tol   = 1e-6  # convergence tolerance handed to run_em

# Empty results table: dataset index, initialisation tag, iterations, final log-likelihood
results = DataFrame(
  dataset    = Int[],
  init       = String[],
  iterations = Int[],
  final_ll   = Float64[],
)

for i in eachindex(all_data)
  # Recover the samples of this dataset and rescale them column-wise
  X  = Matrix(all_data[i]')
  Xs = hcat(scale_to_minus1_1.(eachcol(X))...)
  k  = size(CF_points_NS[i], 2)

  # Deterministic initialisations, one mean per row after transposition
  inits = (
    "CF"    => Matrix(CF_points_NS[i]'),
    "CF_TV" => Matrix(CF_pointsTV_NS[i]'),
  )
  for (tag, μ_init) in inits
    fit = run_em(μ_init; X = Xs, nIter = nIter, tol = tol)
    push!(results, (i, tag, fit.iterations, fit.final_ll))
  end

  # Random initialisations: k distinct data points as starting means
  for _ in 1:nrand
    chosen = sample(1:size(Xs,1), k; replace=false)
    fit = run_em(Xs[chosen, :]; X = Xs, nIter = nIter, tol = tol)
    push!(results, (i, "random", fit.iterations, fit.final_ll))
  end
end

In [86]:
# Final EM log-likelihood per dataset: CF-W2 and CF-TV runs versus the mean
# (± 1 std) over the random restarts, with datasets ordered by the CF-W2 value.
# NOTE(review): this rebinds the global `N`, which the data-generation cell
# uses as the number of samples per mixture.
N = maximum(results.dataset)

# pre‐allocate arrays
cf_vals   = Float64[]
cf_vals_TV   = Float64[]

rand_means = Float64[]
rand_stds  = Float64[]

for i in 1:N
    # rows of this dataset; exactly the first "CF" / "CF_TV" row is used
    sub = results[results.dataset .== i, :]
    push!(cf_vals,   sub.final_ll[sub.init .== "CF"][1])
    push!(cf_vals_TV,   sub.final_ll[sub.init .== "CF_TV"][1])
    rands = sub.final_ll[sub.init .== "random"]
    push!(rand_means, mean(rands))
    push!(rand_stds,  std(rands))
end

# x‐axis = dataset index
xs = 1:N




# 1) Compute the sort order of cf_vals
order = sortperm(cf_vals)    # gives the indices that sort cf_vals ascending

# 2) Reorder everything
xs_sorted        = xs[order]   # dataset indices in sorted order (not used by the plots below)
cf_sorted        = cf_vals[order]
cf_sorted_TV        = cf_vals_TV[order]

rand_means_sorted = rand_means[order]
rand_stds_sorted  = rand_stds[order]

# 3) Plot the sorted curves against their rank (`xs`), not the dataset index
using Plots

plot(
  xs, cf_sorted;
  label   = "Curto-Fialkow-W2",
  lw      = 2,
  marker  = :circle,
  xlabel  = "Dataset (sorted by CF fit)",
  ylabel  = "Final log‑likelihood",
  title   = "CF vs random init across datasets",
  legend  = :bottomright,
  size=(900,600)
)

# added to the current plot (no explicit plot handle is passed)
plot!(
  xs, cf_sorted_TV;
  fillalpha = 0.2,
  label     = "Curto-Fialkow-TV",
  lw        = 2,
  marker    = :diamond,
  color     = :red,
)


# random restarts: mean curve with a ± 1 standard deviation ribbon
plot!(
  xs, rand_means_sorted;
  ribbon    = rand_stds_sorted,
  fillalpha = 0.2,
  label     = "Random ± 1 std",
  lw        = 2,
  marker    = :diamond,
  color     = :orange,
)

In [None]:
# TikZ export of the final EM log-likelihood comparison, with datasets ordered
# by the CF-W2 log-likelihood.
using DataFrames, Plots, LaTeXStrings, PGFPlotsX, Statistics
pgfplotsx()  # TikZ backend

# Unique datasets (robust even if not 1:N)
datasets = sort(unique(results.dataset))

cf_vals      = Float64[]
cf_vals_TV   = Float64[]
rand_means   = Float64[]
rand_stds    = Float64[]

for d in datasets
    sub = results[results.dataset .== d, :]
    push!(cf_vals,    first(sub.final_ll[sub.init .== "CF"]))
    push!(cf_vals_TV, first(sub.final_ll[sub.init .== "CF_TV"]))
    r = sub.final_ll[sub.init .== "random"]
    push!(rand_means, mean(r))
    push!(rand_stds,  std(r))
end

# Sort by CF fit
order                 = sortperm(cf_vals)
xs_sorted             = collect(1:length(datasets))[order]  # positions in `datasets`, in sorted order
cf_sorted             = cf_vals[order]
cf_sorted_TV          = cf_vals_TV[order]
rand_means_sorted     = rand_means[order]
rand_stds_sorted      = rand_stds[order]

# x-axis = rank in the sorted order. Plotting against `xs_sorted` would put
# every value back at its original dataset position and connect the points in
# sorted order, which gives a zig-zag instead of the sorted curve the x-label
# describes.
xs_rank               = 1:length(datasets)

# Styling
default(size=(420,280), grid=false, framestyle=:box, legend=:bottomright,
        linewidth=2, markerstrokewidth=0.8, markersize=3)

plt = plot(xs_rank, cf_sorted;
    label=L"\text{Curto–Fialkow–W2}",
    marker=:circle,
    xlabel=L"\text{Dataset (sorted by CF fit)}",
    ylabel=L"\text{Final log-likelihood}",
    title=L"\text{CF vs.\ random init across datasets}",
)

plot!(plt, xs_rank, cf_sorted_TV;
    label=L"\text{Curto–Fialkow–TV}",
    marker=:diamond,
)

# Random restarts: mean curve with a ± 1 standard deviation ribbon
plot!(plt, xs_rank, rand_means_sorted;
    ribbon=rand_stds_sorted,
    label=L"\text{Random} \pm \text{ 1 std}",
    marker=:square,
    color=:orange,
)

# Show or save
display(plt)                     # for REPL/VSCode/Jupyter
savefig(plt, "final_ll_NS.tikz") # for LaTeX inclusion


In [None]:
using DataFrames, Statistics, Plots
gr()  # inline preview backend

# Mean and standard deviation of the final log-likelihood per (dataset, init)
g = combine(groupby(results, [:dataset, :init]),
            :final_ll => mean => :mean_ll,
            :final_ll => std  => :std_ll)

# One table per initialisation, with the value columns renamed for the join
cf   = rename!(g[g.init .== "CF", [:dataset, :mean_ll]], :mean_ll => :cf)
cfTV = rename!(g[g.init .== "CF_TV", [:dataset, :mean_ll]], :mean_ll => :cfTV)
rnd  = rename!(g[g.init .== "random", [:dataset, :mean_ll, :std_ll]],
               :mean_ll => :rnd_mean, :std_ll => :rnd_std)

# Align the three tables on the dataset and order the rows by the CF-W2 value
T = sort!(innerjoin(innerjoin(cf, cfTV, on=:dataset), rnd, on=:dataset), :cf)

xs  = 1:nrow(T)     # rank in the sorted order
y1  = T.cf
y2  = T.cfTV
ym  = T.rnd_mean
ys  = T.rnd_std

# Connected lines with markers; only the random series gets a ribbon
default(size=(560,360), grid=false, framestyle=:box, legend=:bottomright,
        linewidth=2, markerstrokewidth=0.8, markersize=3)

plt = plot(xs, y1; label="Curto–Fialkow–W2", marker=:circle, linestyle=:solid)
plot!(plt, xs, y2; label="Curto–Fialkow–TV", marker=:diamond, linestyle=:solid)
plot!(plt, xs, ym;
      ribbon=ys,
      label="Random ± 1 std",
      marker=:square, linestyle=:solid,
      fillalpha=0.25)

xlabel!(plt, "Dataset (sorted by CF fit)")
ylabel!(plt, "Final log-likelihood")
display(plt)


In [None]:
# TikZ version of the aggregated log-likelihood figure. It reuses the globals
# `xs`, `y1`, `y2`, `ym`, `ys` left behind by the cell that built the preview.
# NOTE(review): the output path is the same as the one used by the earlier
# TikZ export of the log-likelihood figure, so the file is overwritten.
using PGFPlotsX, LaTeXStrings
pgfplotsx()
plt_tikz = plot(xs, y1; label=L"\text{Curto–Fialkow–W2}", marker=:circle)
plot!(plt_tikz, xs, y2; label=L"\text{Curto–Fialkow–TV}", marker=:diamond)
plot!(plt_tikz, xs, ym; ribbon=ys, label=L"\text{Random} \pm \text{ 1 std}",
      marker=:square, fillalpha=0.25)
# No plot handle is passed: the axis labels go to the current plot
xlabel!(L"\text{Dataset (sorted by CF fit)}")
ylabel!(L"\text{Final log-likelihood}")
savefig(plt_tikz, "final_ll_NS.tikz")   # remember \usepgfplotslibrary{fillbetween} in LaTeX


In [None]:
using DataFrames, Statistics, Plots

# Mean and standard deviation of the EM iteration count per (dataset, init)
g = combine(groupby(results, [:dataset, :init]),
            :iterations => mean => :mean_it,
            :iterations => std  => :std_it)

# One table per initialisation, with the value columns renamed for the join
cf   = rename!(g[g.init .== "CF", [:dataset, :mean_it]], :mean_it => :cf)
cfTV = rename!(g[g.init .== "CF_TV", [:dataset, :mean_it]], :mean_it => :cfTV)
rnd  = rename!(g[g.init .== "random", [:dataset, :mean_it, :std_it]],
               :mean_it => :rnd_mean, :std_it => :rnd_std)

# Align the three tables on the dataset and order the rows by CF-W2 iterations
T = sort!(innerjoin(innerjoin(cf, cfTV, on=:dataset), rnd, on=:dataset), :cf)

xs  = 1:nrow(T)     # rank in the sorted order
y1  = T.cf
y2  = T.cfTV
ym  = T.rnd_mean
ys  = T.rnd_std

# Notebook preview with the GR backend
gr()
default(size=(560,360), grid=false, framestyle=:box, legend=:bottomright,
        linewidth=2, markerstrokewidth=0.8, markersize=3)

plt = plot(xs, y1; label="Curto–Fialkow–W2", marker=:circle, linestyle=:solid)
plot!(plt, xs, y2; label="Curto–Fialkow–TV", marker=:diamond, linestyle=:solid)
# Only the random series carries a ribbon
plot!(plt, xs, ym;
      ribbon=ys,
      label="Random ± 1 std",
      marker=:square, linestyle=:solid,
      fillalpha=0.25)

xlabel!(plt, "Dataset (sorted by CF iterations)")
ylabel!(plt, "EM iterations to convergence")
display(plt)


In [None]:
# TikZ version of the EM iteration figure, written to "iterations_K5EM.tikz".
# It reuses the globals `xs`, `y1`, `y2`, `ym`, `ys` left behind by the cell
# that built the preview.
using PGFPlotsX, LaTeXStrings
pgfplotsx()
plt_tikz = plot(xs, y1; label=L"\operatorname{W2}", marker=:circle)
plot!(plt_tikz, xs, y2; label=L"\operatorname{TV}",  marker=:diamond)
# Random restarts: mean curve with a ribbon of half-width `ys`
plot!(plt_tikz, xs, ym; ribbon=ys, label=L"\text{Random}",
      marker=:square, fillalpha=0.25)
#xlabel!(plt_tikz, L"\text{Iterations}")
ylabel!(plt_tikz, L"\text{Iterations}")
savefig(plt_tikz, "iterations_K5EM.tikz")
