# Import Required Libraries

In [None]:
import sys
import os

# Hop from the notebook's folder up two levels and then into ``src`` so that the
# project modules imported in the next cell (PMF, general_utils,
# visualization_utils) can presumably be found relative to the working directory.
# NOTE(review): every hop is relative to the current directory, so re-running
# this cell in the same kernel moves further up the tree — restart the kernel
# (or skip this cell) before re-executing it.
for _hop in ("..", "..", "./src"):
    os.chdir(_hop)
del _hop
# Alternative to changing directory: sys.path.append("./src")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from IPython import display
import pylab as pl

from PMF import *
from general_utils import *
from visualization_utils import *

import warnings

# Install a process-wide "ignore" filter for every warning category.
# NOTE(review): this also hides NumPy RuntimeWarnings (e.g. divide-by-zero in
# the gain correction further down), which would otherwise flag degenerate
# components.
warnings.filterwarnings(action="ignore")

notebook_name = "Nonnegative_Antisparse_Copula"  # experiment tag; not referenced in the cells shown

# Source Generation and Mixing Scenario

In [None]:
# Problem size: N samples of NumberofSources latent sources, observed through
# NumberofMixtures noisy linear mixtures.
N = 10000
NumberofSources = 5
NumberofMixtures = 8

# Latent sources from the project's copula generator (rho=0.0, df=4,
# no decreasing-correlation profile).
S = generate_correlated_copula_sources(
    rho=0.0,
    df=4,
    n_sources=NumberofSources,
    size_sources=N,
    decreasing_correlation=False,
)

print("The following is the correlation matrix of sources")
display_matrix(np.corrcoef(S))

# Mixing matrix with i.i.d. N(0, 1) entries, shape (NumberofMixtures, NumberofSources),
# applied to the sources to form the noiseless mixtures.
A = np.random.randn(NumberofMixtures, NumberofSources)
X = A.dot(S)

# Target SNR handed to addWGN (presumably in dB — confirm in general_utils).
SNR = 30
X, NoisePart = addWGN(X, SNR, return_noise=True)

# Realised input SNR in dB: summed per-channel mean power of the clean part
# (assuming addWGN returns (X + noise, noise)) over that of the noise.
SNRinp = 10 * np.log10(
    np.mean((X - NoisePart) ** 2, axis=1).sum()
    / np.mean(NoisePart**2, axis=1).sum()
)
print("The following is the mixture matrix A")
display_matrix(A)
print(f"Input SNR is : {SNRinp}")

# Visualize Generated Sources and Mixtures

In [None]:
# Plot the first 100 samples of every source row, then of every mixture row.
subplot_1D_signals(
    S[:, :100],
    colorcode=None,
    figsize=(15.2, 9),
    title="Original Signals",
)
subplot_1D_signals(
    X[:, :100],
    colorcode=None,
    figsize=(15, 18),
    title="Mixture Signals",
)

# Algorithm Hyperparameter Selection and Weight Initialization

In [None]:
# Model dimensions come straight from the data: rows of S are sources,
# rows of X are mixtures.
s_dim, x_dim = S.shape[0], X.shape[0]
debug_iteration_point = 200  # NOTE(review): not referenced in the visible cells
# Ground truth (S, A) is handed to the model — presumably so it can score its
# estimates during fitting (see model.SNR_list later in the notebook).
model = PMF(s_dim=s_dim, x_dim=x_dim, set_ground_truth=True, Sgt=S, Agt=A)

# Run the PMF Nonnegative Antisparse Algorithm on Mixture Signals

In [None]:
# Batch nonnegative-antisparse fit on the noisy mixtures.
# NOTE(review): the meaning of Lt, lambda_ and tau is defined inside PMF and is
# not visible here.
model.fit_batch_nnantisparse(
    X,
    Lt=50,
    lambda_=10,
    tau=1e-8,
    plot_in_jupyter=True,
)

# Calculate Resulting Component SNRs and Overall SINR

In [None]:
# Display the model's SNR_list attribute (presumably filled in during the fit
# because ground truth was supplied to PMF — confirm in PMF).
model.SNR_list

In [None]:
# Source estimates produced by the fit.
Y = model.S_

# Align the estimates with the ground truth (sign and ordering), then rescale
# each row by its least-squares gain c_i = <y_i, s_i> / <y_i, y_i> so that the
# amplitudes are comparable to S.
# NOTE(review): a row of Y_ that is identically zero makes c_i = 0/0 (NaN);
# the warning is hidden by the global "ignore" filter.
Y_ = signed_and_permutation_corrected_sources(S, Y)
coef_ = ((Y_ * S).sum(axis=1) / (Y_ * Y_).sum(axis=1))[:, np.newaxis]
Y_ = coef_ * Y_

print(f"Component SNR Values : {snr_jit(S, Y_)}\n")

# Overall SINR in dB; the helper presumably returns the linear SINR as its
# first element.
# NOTE(review): arguments are passed as (Y_, S) here but as (S, Y_) to
# snr_jit — confirm both helpers' expected order.
SINRwsm = 10 * np.log10(CalculateSINRjit(Y_, S)[0])

print(f"Overall SINR : {SINRwsm}")

# Visualize Extracted Signals Compared to Original Sources

In [None]:
# Plot the first 100 samples of the corrected estimates, then the same window
# of the ground-truth sources, for visual comparison.
subplot_1D_signals(
    Y_[:, :100],
    colorcode=None,
    figsize=(15.2, 9),
    title="Extracted Signals (Sign and Permutation Corrected)",
)
subplot_1D_signals(
    S[:, :100],
    colorcode=None,
    figsize=(15.2, 9),
    title="Original Signals",
)

In [None]:
# Sanity check: S is indexed rows-as-sources elsewhere (S.shape[0] is s_dim),
# so this is expected to be (NumberofSources, N).
S.shape

In [None]:
# Sanity check: A was drawn as (NumberofMixtures, NumberofSources), i.e. (8, 5).
A.shape

In [None]:
# Scratch setup on random data for trying out the step-bound and gradient
# expressions in the following cells.
# NOTE(review): this rebinds ``Y``, which previously held the model's
# extracted sources (model.S_); re-run the evaluation cell if Y is needed again.
Lt = 5
Agt = np.random.standard_normal((10, 5))
H = np.random.standard_normal((10000, 5))
W = np.random.standard_normal((10, 5))
Y = Agt @ H.T

In [None]:
# Lt times the spectral norm (largest singular value) of W^T W — the shape of a
# Lipschitz-style step bound. Left as a bare expression so the cell displays it.
(Lt * np.linalg.norm(W.T @ W, 2))

In [None]:
# W^T (W H^T - Y): the gradient of 0.5 * ||W H^T - Y||_F^2 with respect to H^T.
# Left as a bare expression so the cell displays it.
np.dot(W.T, np.dot(W, H.T) - Y)