In [None]:
using OptimalTransport
import OptimalTransport.Dual: Dual
# using MLDatasets: MLDatasets
using StatsBase
using Plots;
default(; palette=:Set1_3)
using LogExpFunctions
using NNlib: NNlib
using LinearAlgebra
using Distances
using Base.Iterators
using NMF
using Optim

In [None]:

"""
    simplex_norm!(x; dims=1)

Rescale `x` in place so that its entries sum to one along `dims`, and return `x`.

With the default `dims=1` every column of a matrix is divided by its column sum.
No guard is applied for slices whose sum is zero.
"""
function simplex_norm!(x; dims=1)
    # The totals are materialised first, so the in-place division below does not
    # read partially updated values.
    totals = sum(x; dims=dims)
    x ./= totals
    return x
end

In [None]:

"""
    E_star(x; dims=1)

Return the log-sum-exp of `x` reduced along `dims` (a thin wrapper around
`LogExpFunctions.logsumexp`).

It is used as the penalty term in the dual objectives of this notebook.
"""
E_star(x; dims=1) = logsumexp(x; dims=dims);

In [None]:
"""
    E_star_grad(x; dims=1)

Return the softmax of `x` along `dims`, i.e. the gradient of `E_star(x; dims=dims)`
(log-sum-exp) with respect to `x`.

The `dims` keyword is forwarded to `NNlib.softmax`, so it matches the reduction
dimension used by `E_star`. The default `dims=1` normalises each column.
"""
function E_star_grad(x; dims=1)
    # Forward `dims` instead of hard-coding 1 so the keyword is actually honoured.
    return NNlib.softmax(x; dims=dims)
end;
     

In [None]:

"""
    dual_obj_weights(X, K, ε, D, G, ρ1)

Evaluate the dual objective of the weights sub-problem at the dual variable `G`.

The value is the sum of the entropic semi-dual term `Dual.ot_entropic_semidual(X, G, ε, K)`
plus `ρ1` times the summed log-sum-exp penalty `E_star(-D' * G / ρ1)`.
"""
function dual_obj_weights(X, K, ε, D, G, ρ1)
    transport_term = sum(Dual.ot_entropic_semidual(X, G, ε, K))
    penalty_term = ρ1 * sum(E_star(-D' * G / ρ1))
    return transport_term + penalty_term
end
"""
    dual_obj_weights_grad!(∇, X, K, ε, D, G, ρ1)

Write the gradient of `dual_obj_weights` with respect to `G` into `∇` and return `∇`.

The gradient is the semi-dual gradient minus `D` times the softmax
(`E_star_grad`) of `-D' * G / ρ1`.
"""
function dual_obj_weights_grad!(∇, X, K, ε, D, G, ρ1)
    transport_grad = Dual.ot_entropic_semidual_grad(X, G, ε, K)
    penalty_grad = D * E_star_grad(-D' * G / ρ1)
    ∇ .= transport_grad - penalty_grad
    return ∇
end
"""
    dual_obj_dict(X, K, ε, Λ, G, ρ2)

Evaluate the dual objective of the dictionary sub-problem at the dual variable `G`.

The value is the sum of the entropic semi-dual term `Dual.ot_entropic_semidual(X, G, ε, K)`
plus `ρ2` times the summed log-sum-exp penalty `E_star(-G * Λ' / ρ2)`.
"""
function dual_obj_dict(X, K, ε, Λ, G, ρ2)
    transport_term = sum(Dual.ot_entropic_semidual(X, G, ε, K))
    penalty_term = ρ2 * sum(E_star(-G * Λ' / ρ2))
    return transport_term + penalty_term
end
"""
    dual_obj_dict_grad!(∇, X, K, ε, Λ, G, ρ2)

Write the gradient of `dual_obj_dict` with respect to `G` into `∇` and return `∇`.

The gradient is the semi-dual gradient minus the softmax (`E_star_grad`) of
`-G * Λ' / ρ2` multiplied on the right by `Λ`.
"""
function dual_obj_dict_grad!(∇, X, K, ε, Λ, G, ρ2)
    transport_grad = Dual.ot_entropic_semidual_grad(X, G, ε, K)
    penalty_grad = E_star_grad(-G * Λ' / ρ2) * Λ
    ∇ .= transport_grad - penalty_grad
    return ∇
end;

In [None]:

"""
    getprimal_weights(D, G, ρ1)

Map the dual variable `G` of the weights sub-problem back to primal weights.

Returns the column-wise softmax of `-D' * G / ρ1`, so every column of the result
sums to one.
"""
function getprimal_weights(D, G, ρ1)
    scores = -D' * G / ρ1
    return NNlib.softmax(scores; dims=1)
end
"""
    getprimal_dict(Λ, G, ρ2)

Map the dual variable `G` of the dictionary sub-problem back to a primal dictionary.

Returns the column-wise softmax of `-G * Λ' / ρ2`, so every column of the result
sums to one.
"""
function getprimal_dict(Λ, G, ρ2)
    scores = -G * Λ' / ρ2
    return NNlib.softmax(scores; dims=1)
end;

In [None]:
"""
    solve_weights(X, K, ε, D, ρ1; alg, options)

Solve the weights sub-problem for a fixed dictionary `D` and return the primal weights.

The dual objective `dual_obj_weights` is minimised with `Optim.optimize`, starting
from an all-zero dual variable with the same shape as `X`. The minimiser is then
converted to primal weights with `getprimal_weights`.

# Keywords
- `alg`: the Optim optimiser passed to `optimize` (required).
- `options`: the `Optim.Options` passed to `optimize` (required).
"""
function solve_weights(X, K, ε, D, ρ1; alg, options)
    objective(g) = dual_obj_weights(X, K, ε, D, g, ρ1)
    gradient!(∇, g) = dual_obj_weights_grad!(∇, X, K, ε, D, g, ρ1)
    g0 = zero.(X)
    opt = optimize(objective, gradient!, g0, alg, options)
    return getprimal_weights(D, Optim.minimizer(opt), ρ1)
end
"""
    solve_dict(X, K, ε, Λ, ρ2; alg, options)

Solve the dictionary sub-problem for fixed weights `Λ` and return the primal dictionary.

The dual objective `dual_obj_dict` is minimised with `Optim.optimize`, starting
from an all-zero dual variable with the same shape as `X`. The minimiser is then
converted to a primal dictionary with `getprimal_dict`.

# Keywords
- `alg`: the Optim optimiser passed to `optimize` (required).
- `options`: the `Optim.Options` passed to `optimize` (required).
"""
function solve_dict(X, K, ε, Λ, ρ2; alg, options)
    objective(g) = dual_obj_dict(X, K, ε, Λ, g, ρ2)
    gradient!(∇, g) = dual_obj_dict_grad!(∇, X, K, ε, Λ, g, ρ2)
    g0 = zero.(X)
    opt = optimize(objective, gradient!, g0, alg, options)
    return getprimal_dict(Λ, Optim.minimizer(opt), ρ2)
end;

In [None]:

"""
    f(x, μ, σ)

Evaluate the unnormalised bump `exp(-((x - μ) / σ)^2)` at every element of `x`.

`μ` is the centre and `σ` a width scale (no factor of 1/2 is applied, so `σ` is not
a standard deviation). With `σ = 1` this reduces to `exp(-(x - μ)^2)`.
"""
# `σ` was previously accepted but ignored; dividing by it keeps σ = 1 results unchanged.
f(x, μ, σ) = exp.(-((x .- μ) ./ σ) .^ 2)
# Synthetic data set: N columns, each a random non-negative combination of three bumps
# evaluated on the grid `coord`, with centres jittered by σ * randn() around +6, 0 and -6.
coord = range(-12, 12; length=100)
N = 100
σ = 1
# `reduce(hcat, …)` concatenates the columns without splatting N arguments into `hcat`.
X = reduce(
    hcat,
    [
        rand() * f(coord, σ * randn() + 6, 1) +
        rand() * f(coord, σ * randn(), 1) +
        rand() * f(coord, σ * randn() - 6, 1) for _ in 1:N
    ],
)
# Normalise every column to sum to one; the call mutates X, so no rebinding is needed.
simplex_norm!(X);

In [None]:
# Overlay all columns of X against the grid as faint blue curves.
plot(
    coord,
    X;
    title="Input data",
    color=:blue,
    alpha=0.1,
    legend=nothing,
)

In [None]:
# Baseline for comparison: rank-k NMF of X computed by NMF.jl with `alg=:multmse`,
# followed by a plot of the columns of the factor `W` over the grid.
k = 3
result = nnmf(X, k; alg=:multmse)
plot(
    coord,
    result.W;
    palette=:Set1_3,
    title="NMF with Frobenius loss",
)

In [None]:
# Ground cost: pairwise squared Euclidean distances between grid points,
# rescaled in place so that the mean entry equals one.
C = pairwise(SqEuclidean(), coord)
C ./= mean(C);

In [None]:

# ε: entropic regularisation used in the semi-dual term and the kernel exp(-C / ε).
# ρ1, ρ2: penalty scales for the weights and the dictionary sub-problems.
ε = 0.025
ρ1 = 5e-2
ρ2 = 5e-2;

In [None]:

# Gibbs kernel of the ground cost C at regularisation ε (element-wise exponential).
K = @. exp(-C / ε);

In [None]:

# Random initial dictionary: one column per component, each column normalised to sum to one.
D = simplex_norm!(rand(size(X, 1), k); dims=1)
# Random initial weights: one column per column of X, each column (dims=1) normalised
# to sum to one.
Λ = simplex_norm!(rand(k, size(X, 2)); dims=1);

In [None]:

# Alternating minimisation: update the dictionary D with the weights Λ held fixed,
# then update Λ with the new D. Both updates overwrite the existing arrays in place.
n_iter = 10
for iter in 1:n_iter
    @info "Wasserstein-NMF: iteration $iter"
    # The same solver settings are used for both sub-problems in this iteration.
    solver_options = Optim.Options(;
        iterations=250, g_tol=1e-4, show_trace=false, show_every=10
    )
    D .= solve_dict(X, K, ε, Λ, ρ2; alg=LBFGS(), options=solver_options)
    Λ .= solve_weights(X, K, ε, D, ρ1; alg=LBFGS(), options=solver_options)
end

In [None]:
# Plot the columns of the learned dictionary D over the grid.
plot(
    coord,
    D;
    palette=:Set1_3,
    title="NMF with Wasserstein loss",
)

In [None]:
using DelimitedFiles

# Save the normalised data matrix X as a comma-separated file.
# NOTE(review): this cell writes under "data/" while the coord export further down
# writes under "../data/"; confirm which directory is intended.
let path = "data/X_data.csv"
    # writedlm does not create missing directories, so make sure the target exists.
    mkpath(dirname(path))
    writedlm(path, X, ',')
end


In [None]:
# Save the grid coordinates as a comma-separated file, one value per line.
# NOTE(review): this cell writes under "../data/" while the X export above writes
# under "data/"; confirm which directory is intended.
let path = "../data/coord_data.csv"
    # writedlm does not create missing directories, so make sure the target exists.
    mkpath(dirname(path))
    writedlm(path, coord, ',')
end


In [None]:
using JuWassNMF

In [None]:
# Evaluate the bare name so the notebook displays the binding. Presumably this is a
# function exported by JuWassNMF; confirm it resolves after `using JuWassNMF`.
wasserstein_nmf