In [None]:
import sys
import os

# Switch from the notebooks directory into the sibling ``src`` directory so the
# project module ``LDMIBSS`` can be imported. Guarded so that re-running this
# cell (or Run All on an already-initialised kernel) does not walk further up
# the tree; a single chdir also avoids leaving the kernel in ``..`` when
# ``src`` does not exist.
if os.path.basename(os.getcwd()) != "src":
    os.chdir(os.path.join("..", "src"))

In [None]:
from IPython import display
from IPython.display import clear_output
import pylab as pl
import numpy as np
import matplotlib.pyplot as plt
from numba import njit, jit
from time import time

# from helpers import *
from LDMIBSS import *

# Reload edited modules (e.g. LDMIBSS) automatically before each cell runs,
# so changes to the library source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Seed NumPy's global RNG so that the sources, the mixing matrix and (assuming
# addWGN draws from np.random — TODO confirm in LDMIBSS) the additive noise are
# identical on every Restart & Run All.
SEED = 42
np.random.seed(SEED)

N = 10000  # number of samples
NumberofSources = 5
NumberofMixtures = 10

# Non-negative exponential sources, normalised so each column of S sums to 1
# (every sample lies on the probability simplex).
S = np.random.exponential(scale=1.0, size=(NumberofSources, N))
S = S / np.sum(S, axis=0)
print("The following is the correlation matrix of sources")
display_matrix(np.corrcoef(S))

# Draw on a fresh figure so the scatter never lands on a figure that
# display_matrix may have left as the current one.
fig, ax = plt.subplots()
ax.scatter(S[0, :], S[1, :])
ax.set(xlabel="Source 1", ylabel="Source 2", title="Source samples (first two components)")

# Mixing matrix A (NumberofMixtures x NumberofSources) with i.i.d. N(0, 1) entries.
A = np.random.randn(NumberofMixtures, NumberofSources)
X = np.dot(A, S)

SNR = 30
X, NoisePart = addWGN(X, SNR, return_noise=True)

# Realised input SNR in dB: power of the noise-free mixtures over noise power.
SNRinp = 10 * np.log10(
    np.sum(np.mean((X - NoisePart) ** 2, axis=1))
    / np.sum(np.mean(NoisePart**2, axis=1))
)
print("The following is the mixture matrix A")
display_matrix(A)
print("Input SNR is : {}".format(SNRinp))

In [None]:
# Problem dimensions come straight from the data: one row per source in S,
# one row per mixture in X.
x_dim = X.shape[0]
s_dim = S.shape[0]
# Iteration spacing used by the fit / convergence plot below (the plot's
# x-axis is labelled "Number of Iterations / debug_iteration_point").
debug_iteration_point = 200
# Ground truth S and A are handed to the model, presumably so it can report
# SINR while fitting — confirm in LDMIBSS.
model = LDMIBSS(
    s_dim=s_dim,
    x_dim=x_dim,
    set_ground_truth=True,
    S=S,
    A=A,
)

In [None]:
# Batch fit on the noisy mixtures X. Hyper-parameter meanings (mu_start,
# epsilon, method) are defined by LDMIBSS.fit_batch_simplex — see its docs.
model.fit_batch_simplex(
    X,
    method="correlation",
    n_iterations=10000,
    mu_start=100,
    epsilon=1e-5,
    debug_iteration_point=debug_iteration_point,
    plot_in_jupyter=True,
)

In [None]:
# Enlarge tick labels for the convergence figure. Use plt.rcParams (the same
# object as matplotlib.rcParams) instead of `mpl`, which this notebook never
# imports and would only exist if leaked through `from LDMIBSS import *`.
plt.rcParams["xtick.labelsize"] = 18
plt.rcParams["ytick.labelsize"] = 18
# One SINR entry is recorded every `debug_iteration_point` iterations.
plot_convergence_plot(
    model.SINR_list,
    xlabel="Number of Iterations / {}".format(debug_iteration_point),
    ylabel="SINR (dB)",
    title="SINR Convergence Plot",
    colorcode=None,
    linewidth=1.8,
)

print("Final SINR: {}".format(np.array(model.SINR_list[-1])))

In [None]:
# Separate the mixtures with the learned W and express the SINR in decibels.
# NOTE(review): CalculateSINR appears to return a sequence whose first entry is
# the linear SINR — confirm in LDMIBSS.
W = model.W
Y = W @ X
np.log10(CalculateSINR(Y, S)[0]) * 10

In [None]:
# Recovered sources, aligned to S (sign and permutation), then rescaled per
# component by the least-squares gain <y, s> / <y, y> so the SNR reflects
# waveform error rather than an arbitrary amplitude mismatch.
Y = W @ X
print(Y.shape, X.shape, S.shape)
Y_ = signed_and_permutation_corrected_sources(S.T, Y.T)
coef_ = np.sum(Y_ * S.T, axis=0) / np.sum(Y_ * Y_, axis=0)
Y_ = Y_ * coef_
snr(S.T, Y_)

In [None]:
# Visual comparison: the first samples of the recovered sources (aligned to S
# by sign and permutation) next to the true sources. No least-squares
# rescaling is applied here, so amplitudes need not match S.
# Window length and figure settings are shared so both panels stay comparable.
n_plot_samples = 100
plot_figsize = (15.2, 9)

Y = W @ X
Y_ = signed_and_permutation_corrected_sources(S.T, Y.T).T
subplot_1D_signals(
    X=Y_.T[:n_plot_samples],
    title="Extracted Signals (Sign and Permutation Corrected)",
    figsize=plot_figsize,
    colorcode=None,
)
subplot_1D_signals(
    X=S.T[:n_plot_samples],
    title="Original Signals",
    figsize=plot_figsize,
    colorcode=None,
)

In [None]:
# def projection_simplex(V, z=1, axis=None):
#     """
#     Euclidean projection of x onto the simplex, scaled by z:
#         P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2
#     z: float or array
#         If array, len(z) must be compatible with V
#     axis: None or int
#         axis=None: project V by P(V.ravel(); z)
#         axis=1: project each V[i] by P(V[i]; z[i])
#         axis=0: project each V[:, j] by P(V[:, j]; z[j])
#     """
#     if axis == 1:
#         n_features = V.shape[1]
#         U = np.sort(V, axis=1)[:, ::-1]
#         z = np.ones(len(V)) * z
#         cssv = np.cumsum(U, axis=1) - z[:, np.newaxis]
#         ind = np.arange(n_features) + 1
#         cond = U - cssv / ind > 0
#         rho = np.count_nonzero(cond, axis=1)
#         theta = cssv[np.arange(len(V)), rho - 1] / rho
#         return np.maximum(V - theta[:, np.newaxis], 0)

#     elif axis == 0:
#         return projection_simplex(V.T, z, axis=1).T

#     else:
#         V = V.ravel().reshape(1, -1)
#         return projection_simplex(V, z, axis=1).ravel()

In [None]:
# %timeit ProjectColstoSimplex(S)

In [None]:
# np.linalg.norm(projection_simplex(S, 1, 0) - ProjectColstoSimplex(S))

In [None]:
# %timeit projection_simplex(S, 1, 0)

In [None]:
# %timeit ProjectRowstoL1NormBall(S.T).T