# Import Required Libraries

In [None]:
import sys
import os

# Switch the working directory from the notebook's folder to the project's
# ``src`` directory, presumably so the project modules imported in the next
# cell (e.g. LDMIBSS) can be found. NOTE(review): assumes the notebook sits
# one level below the directory that contains ``src`` — confirm.
os.chdir("..")
os.chdir("./src")
# Alternative kept for reference: extend the import path instead of changing
# the working directory.
# sys.path.append("./src")

In [None]:
from IPython import display
from IPython.display import clear_output
import pylab as pl
import numpy as np
import matplotlib.pyplot as plt
from numba import njit, jit
from time import time

# from helpers import *
from LDMIBSS import *

# Reproducibility: seeding is disabled, so draws from np.random (e.g. the
# mixing matrix A below) change on every run. Uncomment to make them fixed.
# np.random.seed(100)
# Reload all imported modules before executing each cell, so edits to the
# project sources (e.g. LDMIBSS) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Source Generation and Mixing Scenario

In [None]:
# Problem size: N samples of NumberofSources sources, observed through
# NumberofMixtures linear mixtures.
N = 500000
NumberofSources = 5
NumberofMixtures = 10

# Ground-truth sources (rows are sources, columns are samples) from the
# project helper generate_correlated_copula_sources.
S = generate_correlated_copula_sources(
    rho=0.7,
    df=4,
    n_sources=NumberofSources,
    size_sources=N,
    decreasing_correlation=False,
)
print("The following is the correlation matrix of sources")
display_matrix(np.corrcoef(S))

# Mixing matrix of shape (NumberofMixtures, NumberofSources) with i.i.d.
# N(0, 1) entries, applied to the sources.
A = np.random.randn(NumberofMixtures, NumberofSources)
X = np.dot(A, S)

# Add noise at the requested SNR (dB) via the project helper addWGN and keep
# the noise realization so the achieved input SNR can be measured.
SNR = 30
X, NoisePart = addWGN(X, SNR, return_noise=True)

# Achieved input SNR in dB: summed per-mixture power of the noiseless part
# (X - NoisePart) divided by the summed per-mixture noise power.
SNRinp = 10 * np.log10(
    np.mean((X - NoisePart) ** 2, axis=1).sum()
    / np.mean(NoisePart**2, axis=1).sum()
)

print("The following is the mixture matrix A")
display_matrix(A)
print(f"Input SNR is : {SNRinp}")

# Visualize Generated Sources and Mixtures

In [None]:
# Plot the first 100 samples of the sources and of the mixtures
# (transposed so that rows are samples).
subplot_1D_signals(
    S.T[:100],
    title="Original Signals",
    figsize=(15.2, 9),
    colorcode=None,
)
subplot_1D_signals(
    X.T[:100],
    title="Mixture Signals",
    figsize=(15, 18),
    colorcode=None,
)

# Algorithm Hyperparameter Selection and Weight Initialization

In [None]:
# Presumably the forgetting factors of the output / error statistics
# (both equal to 1 - 0.01).
lambday = lambdae = 1 - 1e-1 / 10
s_dim, x_dim = S.shape[0], X.shape[0]

# Initial inverse output covariance (identity) and inverse error covariance
# (a large multiple of the identity).
By = np.eye(s_dim)
Be = 10000 * np.eye(s_dim)

# The convergence plot's x-axis is expressed in units of this many iterations.
debug_iteration_point = 10000

model = OnlineLDMIBSS(
    s_dim=s_dim, x_dim=x_dim,
    muW=30 * 1e-3,
    lambday=lambday, lambdae=lambdae,
    By=By, Be=Be,
    neural_OUTPUT_COMP_TOL=1e-6,
    set_ground_truth=True, S=S, A=A,
)

# Run CorInfoMax Algorithm on Mixture Signals

In [None]:
# Train the model on the mixtures X in a single pass (n_epochs=1), with
# sample order shuffled (shuffle=True).
# NOTE(review): the remaining argument meanings are inferred from their names
# and should be checked against OnlineLDMIBSS.fit_batch_nnantisparse:
# neural_dynamic_iterations / neural_lr_start / neural_lr_stop presumably
# configure the inner neural-dynamics solver, and debug_iteration_point the
# interval at which model.SINR_list is recorded.
model.fit_batch_nnantisparse(
    X=X,
    n_epochs=1,
    neural_dynamic_iterations=500,
    plot_in_jupyter=True,
    neural_lr_start=0.9,
    neural_lr_stop=0.000001,
    debug_iteration_point=debug_iteration_point,
    shuffle=True,
)

# Visualize SINR Convergence 

In [None]:
# Larger tick labels for the convergence figure. ``plt.rcParams`` is the very
# same object as ``matplotlib.rcParams``; using it avoids depending on ``mpl``
# being re-exported by ``from LDMIBSS import *`` (``mpl`` is never imported
# in this notebook).
plt.rcParams["xtick.labelsize"] = 18
plt.rcParams["ytick.labelsize"] = 18

# model.SINR_list is filled during training; each entry presumably corresponds
# to one block of debug_iteration_point iterations (see the x-axis label).
plot_convergence_plot(
    model.SINR_list,
    xlabel="Number of Iterations / {}".format(debug_iteration_point),
    ylabel="SINR (dB)",
    title="SINR Convergence Plot",
    colorcode=None,
    linewidth=1.8,
)

print("Final SINR: {}".format(np.array(model.SINR_list[-1])))

# Calculate Resulting Component SNRs and Overall SINR

In [None]:
# Apply the learned overall mapping to the mixtures, then undo the sign and
# permutation ambiguity against the true sources (result: samples x sources).
Wf = model.compute_overall_mapping(return_mapping=True)
Y_ = model.signed_and_permutation_corrected_sources(S.T, (Wf @ X).T)

# Per-component least-squares gain c_k minimizing ||S.T[:, k] - c_k Y_[:, k]||^2.
coef_ = np.sum(Y_ * S.T, axis=0) / np.sum(Y_ * Y_, axis=0)
Y_ = coef_ * Y_

print(f"Component SNR Values : {snr(S.T, Y_)}\n")

# Element [0] of CalculateSINR is converted with 10*log10, so it is presumably
# a linear power ratio.
SINR = 10 * np.log10(model.CalculateSINR(Y_.T, S)[0])
print(f"Overall SINR : {SINR}")

# Visualize Extracted Signals Compared to Original Sources

In [None]:
# Recompute sign/permutation-corrected estimates for plotting (no least-squares
# gain is applied here).
# NOTE: the trailing ``.T`` puts Y_ in the same orientation as Y = Wf @ X
# (sources x samples), i.e. transposed relative to the Y_ built in the
# SNR cells; the plotting cell below relies on this orientation.
Wf = model.compute_overall_mapping(return_mapping=True)
Y = Wf @ X
Y_ = model.signed_and_permutation_corrected_sources(S.T, Y.T).T
# Show the shapes to confirm the orientations.
Y_.shape, X.shape, S.shape

In [None]:
# Side-by-side look at the first 100 samples of the recovered signals and of
# the ground-truth sources.
subplot_1D_signals(
    X=Y_.T[:100],
    title="Extracted Signals (Sign and Permutation Corrected)",
    figsize=(15.2, 9),
    colorcode=None,
)
subplot_1D_signals(
    X=S.T[:100],
    title="Original Signals",
    figsize=(15.2, 9),
    colorcode=None,
)

In [None]:
# Rebuild the least-squares-scaled estimates (samples x sources), then
# evaluate snr on all components at once by flattening both arrays into a
# single column.
Wf = model.compute_overall_mapping(return_mapping=True)
Y_ = model.signed_and_permutation_corrected_sources(S.T, (Wf @ X).T)
coef_ = np.sum(Y_ * S.T, axis=0) / np.sum(Y_ * Y_, axis=0)
Y_ = coef_ * Y_
snr(np.reshape(S.T, (-1, 1)), np.reshape(Y_, (-1, 1)))

In [None]:
# Hand-computed global SINR, as a cross-check of model.CalculateSINR:
#   1. G  : least-squares map from the centred sources to the centred estimates.
#   2. indmax[k] : index of the source matched to estimate k.
#   3. GG : keeps a single gain per estimate, fitted by scalar least squares
#           against its matched (centred) source.
#   4. SINR = ||ZZ||_F^2 / ||Out - ZZ||_F^2, where ZZ is the part of each
#      estimate explained by its matched source (plus the estimate's mean).
# Assumes Y_ is samples x sources (as built by the scaled-SNR cell), so Out is
# sources x samples with r rows — TODO confirm when cells run out of order.
Out = Y_.T
r = S.shape[0]
Out_mean = np.reshape(np.mean(Out, 1), (r, 1))
S_centered = S - np.reshape(np.mean(S, 1), (r, 1))
G = np.dot(Out - Out_mean, np.linalg.pinv(S_centered))
# np.mod folds any helper index >= r back into 0..r-1.
# NOTE(review): indmax is not checked to be a true permutation; duplicates
# would assign two estimates to the same source.
indmax = np.mod(find_permutation_between_source_and_estimation(Out.T, S.T), r)
GG = np.zeros((r, r))
for kk in range(r):
    # Centre each row once instead of recomputing its mean three times.
    src = S[indmax[kk], :] - np.mean(S[indmax[kk], :])
    est = Out[kk, :] - np.mean(Out[kk, :])
    GG[kk, indmax[kk]] = np.dot(est, src) / np.dot(src, src)
ZZ = GG @ S_centered + Out_mean
E = Out - ZZ
MSE = np.linalg.norm(E, "fro") ** 2
SigPow = np.linalg.norm(ZZ, "fro") ** 2
SINR = SigPow / MSE

In [None]:
# Hand-computed SINR (linear power ratio SigPow / MSE) expressed in dB.
np.log10(SINR) * 10

In [None]:
# Estimate-to-source assignment (estimate k -> source indmax[k]); np.mod folds
# any index >= r returned by the helper back into 0..r-1. The previous
# argmax-over-G assignment was immediately overwritten and has been dropped.
indmax = np.mod(find_permutation_between_source_and_estimation(Out.T, S.T), r)
indmax

In [None]:
# Least-squares gain matrix between centred sources and centred estimates.
# For a good separation one would expect one dominant entry per row.
display_matrix(G)

In [None]:
# Current shape of the estimate array (orientation depends on the last cell run).
np.shape(Y_)

In [None]:
# NOTE(review): both arguments come from Y_, so this compares estimated
# component 3 with estimated component 4 rather than an estimate with its
# source (cf. snr(S.T, Y_) above) — confirm this is intended (e.g. as a
# crosstalk check) and not snr(S.T[:, 3], Y_[:, 3]).
snr(Y_[:, 3], Y_[:, 4])

In [None]:
# Re-display the hand-computed SINR in decibels.
np.log10(SINR) * 10

In [None]:
# Raw output of the project's permutation helper for (estimates, sources),
# i.e. before the np.mod(..., r) folding used when building indmax.
find_permutation_between_source_and_estimation(Out.T, S.T)

In [None]:
def find_permutation_between_source_and_estimation2(S, Y):
    """For every column of ``S``, find the most correlated column of ``Y``.

    Rows are samples and columns are signals. In this notebook the function
    is called as ``(Out.T, S.T)``: ``S`` holds the estimates produced by the
    separation algorithm and ``Y`` the true sources, so ``perm[k]`` is the
    source matched to estimate ``k``.

    Parameters
    ----------
    S : array_like, shape (n_samples, n)
        Signals to be matched.
    Y : array_like, shape (n_samples, m)
        Candidate signals the columns of ``S`` are matched against.

    Returns
    -------
    perm : ndarray of int, shape (n,)
        ``perm[j]`` is the index in ``0..m-1`` of the column of ``Y`` whose
        absolute correlation with ``S[:, j]`` is largest.

    Notes
    -----
    Only the ``Y``-versus-``S`` cross-correlation block is searched. Taking
    the argmax over the full stacked correlation matrix could pick another
    column of ``S`` itself (an index >= m) whenever two signals of ``S`` are
    more correlated with each other than with any column of ``Y``; callers
    then had to fold the result with ``np.mod``, which yields a wrong match.
    The number of columns of ``S`` and ``Y`` may now also differ.
    """
    S = np.asarray(S)
    Y = np.asarray(Y)
    n_y = Y.shape[1]
    # Stacked variables: columns of Y first (indices 0..n_y-1), then columns
    # of S. Keep rows = Y, columns = S.
    corr = np.corrcoef(Y.T, S.T)
    cross = np.abs(corr[:n_y, n_y:])
    perm = np.argmax(cross, axis=0)
    return perm

In [None]:
# Estimate-to-source assignment from the local variant defined above. The
# modulus follows the number of sources (rows of S) instead of a hard-coded
# 5, so the cell stays correct if NumberofSources changes.
np.mod(find_permutation_between_source_and_estimation2(Out.T, S.T), S.shape[0])

In [None]:
# Hand-computed SINR, linear ratio converted to dB.
np.log10(SINR) * 10

In [None]:
# Current estimate-to-source assignment (estimate k -> source indmax[k]).
indmax

In [None]:
# SINR of the current estimates as computed by the model's own routine, in dB.
np.log10(model.CalculateSINR(Y_.T, S)[0]) * 10

In [None]:
# Re-display the least-squares gain matrix G from the manual SINR computation.
display_matrix(G)

In [None]:
# Overall source-to-output map: X = A @ S + noise, so Wf @ X = (Wf @ A) @ S
# + Wf @ noise. For a successful separation this product should be close to
# a scaled permutation matrix.
display_matrix(Wf @ A)