In [1]:
from utils import UtilsSparse
import numpy as np
from typing import List
from wasserstein import NMRSpectrum

def load_data(dirname: str = "data"):
    """
    Load NMR spectral data from files.

    Reads ``<dirname>/preprocessed_comp{i}.csv`` for each component
    ``i = 0 .. len(components_names) - 1`` and ``<dirname>/preprocessed_mix.csv``
    for the mixture. Columns 0 and 1 of each comma-separated file are zipped
    into the ``confs`` pairs passed to ``NMRSpectrum``. Every returned spectrum
    has had ``trim_negative_intensities()`` and ``normalize()`` applied.

    Args:
        dirname: Directory containing the preprocessed CSV files.
            Defaults to ``"data"`` (the previously hard-coded location).

    Returns:
        Tuple[List[NMRSpectrum], NMRSpectrum]: Component spectra (index ``i``
        built from ``preprocessed_comp{i}.csv``) and the mixture spectrum.
    """
    # Component labels; only their count is used to decide how many
    # preprocessed_comp{i}.csv files to read.
    components_names = ["Pinene", "Benzyl benzoate"]
    # Proton count passed to NMRSpectrum for each component, by index.
    protons_list = [16, 12]

    def _read_confs(path: str):
        # ndmin=2 keeps a single-row file 2-D so the column slicing works
        # (plain loadtxt would return a 1-D array and [:, 0] would fail).
        data = np.loadtxt(path, delimiter=",", ndmin=2)
        return list(zip(data[:, 0], data[:, 1]))

    spectra: List[NMRSpectrum] = [
        NMRSpectrum(
            confs=_read_confs(f"{dirname}/preprocessed_comp{i}.csv"),
            protons=protons_list[i],
        )
        for i in range(len(components_names))
    ]
    # The mixture is constructed without a proton count, as before.
    mix = NMRSpectrum(confs=_read_confs(f"{dirname}/preprocessed_mix.csv"))

    for sp in [mix, *spectra]:
        sp.trim_negative_intensities()
        sp.normalize()

    return spectra, mix

In [2]:
# Load the component spectra and the mixture spectrum (see load_data).
spectra, mix = load_data()

# --- Solver parameters -------------------------------------------------------
# NOTE(review): meanings below are inferred from names and call sites only;
# confirm against UtilsSparse / joint_md before relying on them.
N = 2100         # 3rd positional argument of UtilsSparse
C = 20           # 4th positional argument of UtilsSparse
reg = 1.5        # presumably the main regularization weight — confirm
regm1 = 230      # presumably a marginal penalty (first marginal) — confirm
regm2 = 115      # presumably a marginal penalty (second marginal) — confirm
eta_G = 1e-3     # presumably the step size for the G update — confirm
eta_p = 1e-3     # presumably the step size for the p update — confirm
tol = 1e-5       # passed to joint_md as `tol`
gamma = 0.99     # passed to joint_md as `gamma`
max_iter = 1000  # passed to joint_md as its iteration limit

# Reference proportion of the first component (comp0). The second component's
# proportion is its complement, so the two sum to 1.
TRUE_P0 = 0.3865
true_p = np.array([TRUE_P0, 1 - TRUE_P0])

sparse = UtilsSparse(spectra, mix, N, C, reg, regm1, regm2)
G, p = sparse.joint_md(eta_G, eta_p, max_iter, tol=tol, gamma=gamma)

print("Final p: ", p)
print("True proportions: ", true_p)
print("Absolute error: ", np.abs(np.asarray(p) - true_p))

Iteration 0: p = [0.4951 0.5049]
Iteration 20: p = [0.4309 0.5691]
Iteration 40: p = [0.4104 0.5896]
Iteration 60: p = [0.4002 0.5998]
Iteration 80: p = [0.395 0.605]
Iteration 100: p = [0.3919 0.6081]
Iteration 120: p = [0.3902 0.6098]
Iteration 140: p = [0.3888 0.6112]
Iteration 160: p = [0.3875 0.6125]
Iteration 180: p = [0.3864 0.6136]
Iteration 200: p = [0.3854 0.6146]
Iteration 220: p = [0.3847 0.6153]
Iteration 240: p = [0.384 0.616]
Iteration 260: p = [0.3834 0.6166]
Iteration 280: p = [0.3828 0.6172]
Iteration 300: p = [0.3824 0.6176]
Iteration 320: p = [0.382 0.618]
Iteration 340: p = [0.3818 0.6182]
Iteration 360: p = [0.3816 0.6184]
Iteration 380: p = [0.3814 0.6186]
Iteration 400: p = [0.3812 0.6188]
Iteration 420: p = [0.3811 0.6189]
Iteration 440: p = [0.381 0.619]
Iteration 460: p = [0.3809 0.6191]
Iteration 480: p = [0.3808 0.6192]
Iteration 500: p = [0.3808 0.6192]
Iteration 520: p = [0.3807 0.6193]
Iteration 540: p = [0.3807 0.6193]
Iteration 560: p = [0.3807 0.6193]