In [None]:
import sys
import os

# The project modules (WSMBSS, general_utils, visualization_utils) live in the
# sibling ``src`` directory and are presumably resolved via the working
# directory (IPython puts '' on sys.path). Only change directory when we are
# not already there, so re-running this cell does not climb further up the
# directory tree.
if os.path.basename(os.getcwd()) != "src":
    os.chdir(os.path.join("..", "src"))
# sys.path.append("./src")

import numpy as np
import matplotlib.pyplot as plt
from IPython import display
import pylab as pl

from WSMBSS import *
from general_utils import *
from visualization_utils import *

import warnings

# Silence all warnings for this exploratory notebook. Note: this also hides
# numerical RuntimeWarnings such as overflow or divide-by-zero.
warnings.filterwarnings("ignore")

notebook_name = "Simplex"

In [None]:
NumberofSources = 5
NumberofMixtures = 10
N = 500000
np.random.seed(0)
# Sample each column of S uniformly from the probability simplex: normalising
# i.i.d. Exp(1) draws column-wise gives a Dirichlet(1, ..., 1) sample.
# https://stackoverflow.com/questions/65154622/sample-uniformly-at-random-from-a-simplex-in-python
S = np.random.exponential(scale=1.0, size=(NumberofSources, N))
S = S / np.sum(S, axis=0)

# Optional additive white Gaussian noise on the mixtures (off by default).
# SNR is only used when ADD_NOISE is True.
ADD_NOISE = False
SNR = 30
A = np.random.randn(NumberofMixtures, NumberofSources)
X = np.dot(A, S)

if ADD_NOISE:
    X, NoisePart = addWGN(X, SNR, return_noise=True)
    SNRinp = 10 * np.log10(
        np.sum(np.mean((X - NoisePart) ** 2, axis=1))
        / np.sum(np.mean(NoisePart**2, axis=1))
    )

print("The following is the mixture matrix A")
display_matrix(A)
if ADD_NOISE:
    print("Input SNR is : {}".format(SNRinp))

# Two source coordinates against each other: the samples lie on the simplex.
fig, ax = plt.subplots()
ax.scatter(S[0, :], S[2, :])
ax.set(xlabel="S[0, :]", ylabel="S[2, :]", title="Simplex sources: component 0 vs 2")
plt.show()

In [None]:
def offdiag(A, return_diag=False):
    """Split a square matrix into its off-diagonal part and its diagonal.

    Args:
        A (np.ndarray): Square 2-D array.
        return_diag (bool, optional): If True, also return the diagonal
            entries of ``A`` as a 1-D array. Defaults to False.

    Returns:
        np.ndarray | tuple[np.ndarray, np.ndarray]: ``A`` with its diagonal
        set to zero; when ``return_diag`` is True, the tuple
        ``(off_diagonal_part, diagonal_entries)``.
    """
    # Both branches need the diagonal, so compute it unconditionally
    # (previously the default branch referenced an undefined ``diag``).
    diag = np.diag(A)
    off_diagonal = A - np.diag(diag)
    if return_diag:
        return off_diagonal, diag
    return off_diagonal

In [None]:
# Learning-rate schedules (start -> stop) for the M and W weights of the two
# layers.
MUS = 0.25
gammaM_start, gammaW_start = [MUS] * 2, [MUS] * 2
gammaM_stop, gammaW_stop = [1e-3] * 2, [1e-3] * 2

# Tolerance passed to the model as neural_OUTPUT_COMP_TOL; MAX_OUT_ITERATIONS
# is not passed to the model in this notebook.
OUTPUT_COMP_TOL = 1e-5
MAX_OUT_ITERATIONS = 3000

# Per-layer settings (two entries, presumably one per layer).
LayerGains = [1] * 2
LayerMinimumGains = [1] * 2
LayerMaximumGains = [1e6, 1.001]
WScalings = [0.0033] * 2
GamScalings = [0.01] * 2
zeta, beta = 1e-4, 0.5
muD = [20, 1e-2]

# Dimensions: the hidden layer has as many units as there are sources.
s_dim, samples = S.shape
x_dim = X.shape[0]
h_dim = s_dim

# Identity initialisation of the feed-forward weight matrices.
W_HX = np.eye(h_dim, x_dim)
W_YH = np.eye(s_dim, h_dim)

In [None]:
# Seed the global RNG so any random initialisation inside the model is
# reproducible.
np.random.seed(100)
debug_iteration_point = 25000  # not used by the cells below

model = OnlineWSMBSS(
    # layer sizes
    s_dim=s_dim, x_dim=x_dim, h_dim=h_dim,
    # learning-rate schedules
    gammaM_start=gammaM_start, gammaM_stop=gammaM_stop,
    gammaW_start=gammaW_start, gammaW_stop=gammaW_stop,
    # scalar hyperparameters
    beta=beta, zeta=zeta, muD=muD,
    # initial weights and their scalings
    WScalings=WScalings, W_HX=W_HX, W_YH=W_YH,
    # layer gain settings
    DScalings=LayerGains,
    LayerMinimumGains=LayerMinimumGains,
    LayerMaximumGains=LayerMaximumGains,
    neural_OUTPUT_COMP_TOL=OUTPUT_COMP_TOL,
    # ground-truth sources and mixing matrix handed to the model
    set_ground_truth=True, S=S, A=A,
)

In [None]:
# Neural-dynamics settings for a 750-iteration single-sample run.
hidden_layer_gain = 2
lr_rule = "divide_by_slow_loop_index"
neural_dynamic_iterations = 750
neural_lr_decay_multiplier = 0.005
neural_lr_start = neural_lr_stop = 0.5
stlambd_lr = 0.01
OUTPUT_COMP_TOL = 1e-5

# Aliases (not copies) of the model's current parameters.
W_HX, W_YH = model.W_HX, model.W_YH
M_H, M_Y = model.M_H, model.M_Y
D1, D2 = model.D1, model.D2

In [None]:
np.random.seed(100)
# Random states for every sample; only one column is used below. (The former
# np.zeros allocations were overwritten immediately and have been removed;
# np.zeros does not draw from the RNG, so the values here are unchanged.)
H = np.random.randn(h_dim, samples)  # *0.05
Y = np.random.randn(s_dim, samples)  # *0.05

i_sample = 287

x_current = X[:, i_sample]  # Take one input
# y and h are views into Y and H, so in-place updates to them show up there.
y = Y[:, i_sample]
h = H[:, i_sample]

In [None]:
# Run the simplex neural dynamics on the selected sample with the settings
# above. Arguments are positional; the call's return value is the cell output.
model.run_neural_dynamics_simplex_jit(
    x_current, h, y,  # input sample and starting states
    M_H, M_Y, W_HX, W_YH, D1, D2,
    beta, zeta,
    neural_dynamic_iterations,
    neural_lr_start, neural_lr_stop, lr_rule, neural_lr_decay_multiplier,
    stlambd_lr, hidden_layer_gain, OUTPUT_COMP_TOL,
)

In [None]:
# Same settings as the earlier run, but only 10 dynamics iterations.
hidden_layer_gain = 2
lr_rule = "divide_by_slow_loop_index"
neural_dynamic_iterations = 10
neural_lr_decay_multiplier = 0.005
neural_lr_start = neural_lr_stop = 0.5
stlambd_lr = 0.01
OUTPUT_COMP_TOL = 1e-5

# Re-alias the model's current parameters (references, not copies).
W_HX, W_YH = model.W_HX, model.W_YH
M_H, M_Y = model.M_H, model.M_Y
D1, D2 = model.D1, model.D2

In [None]:
# Fresh random states (same seed as before); H is drawn before Y.
np.random.seed(100)
H, Y = np.random.randn(h_dim, samples), np.random.randn(s_dim, samples)

i_sample = 0

# Take one input; y and h are views into Y and H.
x_current, y, h = X[:, i_sample], Y[:, i_sample], H[:, i_sample]

In [None]:
# Inspect the mixture sample selected in the previous cell.
x_current

In [None]:
np.random.seed(100)
# Replace each weight matrix by a standard-normal matrix of the same shape.
# The comprehension draws in the order W_HX, W_YH, M_H, M_Y, so the seeded
# values are the same as drawing them one by one.
W_HX, W_YH, M_H, M_Y = [
    np.random.randn(*mat.shape) for mat in (W_HX, W_YH, M_H, M_Y)
]

# Split the lateral matrices into off-diagonal parts and diagonal entries.
M_hat_H, Gamma_H = offdiag(M_H, True)
M_hat_Y, Gamma_Y = offdiag(M_Y, True)

display_matrix(W_YH)
display_matrix(W_HX)
display_matrix(M_Y)
display_matrix(M_H)

In [None]:
# Inspect D1 and D2 (taken from the model in an earlier cell).
D1, D2

In [None]:
# Diagonal matrix holding
#   (1 - zeta) * beta * diag(M_H @ diag(D1) @ M_H.T - W_HX @ W_HX.T).
# M_H.T is required so the quadratic form equals sum_j M_H[i, j]**2 * D1[j],
# matching the row-sum expressions evaluated in the following cells. M_H is a
# random, non-symmetric matrix at this point, so M_H @ diag(D1) @ M_H differs.
(1 - zeta) * beta * np.diag(
    np.diag(
        M_H @ np.diag(D1.reshape(-1)) @ M_H.T
        - W_HX @ W_HX.T
    )
)

In [None]:
# Squared row norms of W_HX, taken as the diagonal of W_HX @ W_HX.T.
(np.diag(W_HX @ W_HX.T))

In [None]:
# Same quantity as the previous cell, via row sums of squared entries.
np.sum(np.abs(W_HX) ** 2, axis=1)

In [None]:
# Squared row norms of M_H, taken as the diagonal of M_H @ M_H.T.
(np.diag(M_H @ M_H.T))

In [None]:
# Same quantity as the previous cell, via row sums of squared entries.
np.sum((np.abs(M_H) ** 2), axis=1)

In [None]:
# D1-weighted quadratic form: diagonal of M_H @ diag(D1) @ M_H.T, i.e.
# sum_j M_H[i, j]**2 * D1[j] for every row i.
np.diag(
    M_H
    @ np.diag(
        D1.reshape(
            -1,
        )
    )
    @ M_H.T
)

In [None]:
# Vectorised version of the previous cell. D1 appears to be an (h_dim, 1)
# column (that is what lets `D1 * W_HX` broadcast in the mat_factor cell), so
# D1.T is a row that scales column j of |M_H|**2 by D1[j] before summing.
np.sum((np.abs(M_H) ** 2) * D1.T, axis=1)

In [None]:
# Hidden layer, as a column vector:
#   (1 - zeta) * beta * (sum_j M_H[i, j]**2 * D1[j] - ||W_HX[i, :]||**2) + zeta / D1[i]
# The row sums are reshaped to a column so they line up with D1.
(1 - zeta) * beta * (
    np.sum((np.abs(M_H) ** 2) * D1.T, axis=1) - np.sum(np.abs(W_HX) ** 2, axis=1)
).reshape(-1, 1) + zeta * (1 / D1)

In [None]:
# Output-layer analogue: weighted by (1 - beta) and built from M_Y, W_YH, D2.
(1 - zeta) * (1 - beta) * (
    np.sum((np.abs(M_Y) ** 2) * D2.T, axis=1) - np.sum(np.abs(W_YH) ** 2, axis=1)
).reshape(-1, 1) + zeta * (1 / D2)

In [None]:
# Precomputed coefficient arrays. D1 and D2 behave as column vectors here
# (required for `D1 * W_HX` and `W_YH.T * D2.T` to broadcast), so `D * M`
# scales row i by D[i] (= diag(D) @ M) and `M * D.T` scales column j by D[j]
# (= M @ diag(D)).
# (1 - zeta) * beta * diag(D1) @ W_HX
mat_factor1 = (1 - zeta) * beta * (D1 * W_HX)
# (1 - zeta) * [(1 - beta) * M_hat_H + beta * diag(D1) @ M_hat_H @ diag(D1)]
mat_factor2 = (1 - zeta) * (1 - beta) * M_hat_H + (1 - zeta) * beta * (
    (D1 * M_hat_H) * D1.T
)
# (1 - zeta) * (1 - beta) * W_YH.T @ diag(D2)
mat_factor3 = (1 - zeta) * (1 - beta) * (W_YH.T * D2.T)
# Row of shape (1, h_dim): (1 - zeta) * Gamma_H[i] * ((1 - beta) + beta * D1[i]**2)
mat_factor4 = (1 - zeta) * Gamma_H * ((1 - beta) + beta * (D1.T) ** 2)
# M_hat_Y @ diag(D2)
# NOTE(review): factors 5 and 6 carry no (1 - zeta) weighting, unlike
# factors 1-4 — confirm this is intended.
mat_factor5 = M_hat_Y * D2.T
# Row of shape (1, s_dim): Gamma_Y[j] * D2[j]
mat_factor6 = Gamma_Y * D2.T

In [None]:
# Current hidden-state sample.
h

In [None]:
# Elementwise product of mat_factor4's single row with h.
mat_factor4[0] * h

In [None]:
# Same values: broadcast the (1, h_dim) row against h, then take row 0.
(mat_factor4 * h)[0]

In [None]:
# Current output-state sample.
y

In [None]:
# Row vector Gamma_Y * D2.T.
mat_factor6

In [None]:
# Elementwise product of mat_factor6's single row with y.
mat_factor6[0] * y

In [None]:
# Individual terms of the combined expression evaluated in the last cell of
# this group (presumably the feed-forward, lateral and feedback terms).
mat_factor1 @ x_current

In [None]:
mat_factor2 @ h

In [None]:
mat_factor3 @ y

In [None]:
# Combined: mat_factor1 @ x - mat_factor2 @ h + mat_factor3 @ y.
mat_factor1 @ x_current - mat_factor2 @ h + mat_factor3 @ y

In [None]:
# Gamma_H is 1-D and D1**2 is a column, so the product broadcasts to an
# (h_dim, h_dim) matrix; np.diag then keeps entry [i, i], i.e.
# Gamma_H[i] * ((1 - zeta)(1 - beta) + (1 - zeta) beta D1[i]**2). These are
# the same values as mat_factor4[0] (next cell).
np.diag(Gamma_H * ((1 - zeta) * (1 - beta) + (1 - zeta) * beta * D1**2))

In [None]:
# Reference values for the previous cell.
mat_factor4[0]

In [None]:
# M_hat_Y @ diag(D2).
mat_factor5

In [None]:
mat_factor5 @ y

In [None]:
# Output-layer analogue of mat_factor4's formula, with Gamma_Y and D2.
(1 - zeta) * Gamma_Y * ((1 - beta) + beta * D2.T**2)

In [None]:
# Variant that keeps only the (1 - beta) * D2**2 term.
(1 - zeta) * Gamma_Y * ((1 - beta) * D2.T**2)

In [None]:
# Matrix form diag(Gamma_Y) @ diag(D2); its diagonal equals mat_factor6[0].
np.diag(Gamma_Y) @ np.diag(
    D2.reshape(
        -1,
    )
)  # - Gamma_Y * D2.T

In [None]:
# Column form of mat_factor6: entry j is Gamma_Y[j] * D2[j].
Gamma_Y.reshape(-1, 1) * D2

In [None]:
# Same expression as mat_factor4 (row of shape (1, h_dim)).
(1 - zeta) * Gamma_H * ((1 - beta) + beta * (D1.T) ** 2)

In [None]:
# Matrix form of mat_factor4:
#   (1 - zeta) * diag(Gamma_H) @ ((1 - beta) * I + beta * diag(D1)**2)
# The scalar (1 - beta) must multiply the identity. Adding it to the whole
# matrix would also fill every off-diagonal entry. The diagonal of the result
# equals mat_factor4[0].
(1 - zeta) * np.diag(Gamma_H) @ (
    (1 - beta) * np.eye(Gamma_H.size) + beta * np.diag(D1.reshape(-1)) ** 2
)

In [None]:
# Run the simplex neural dynamics again, now with the random weights and the
# 10-iteration settings. Arguments are positional; the return value is the
# cell output.
model.run_neural_dynamics_simplex_jit(
    x_current, h, y,  # input sample and starting states
    M_H, M_Y, W_HX, W_YH, D1, D2,
    beta, zeta,
    neural_dynamic_iterations,
    neural_lr_start, neural_lr_stop, lr_rule, neural_lr_decay_multiplier,
    stlambd_lr, hidden_layer_gain, OUTPUT_COMP_TOL,
)

In [None]:
# Re-split M_H and M_Y after the dynamics call, in case it modified them in
# place (TODO confirm whether run_neural_dynamics_simplex_jit does).
M_hat_H, Gamma_H = offdiag(M_H, True)
M_hat_Y, Gamma_Y = offdiag(M_Y, True)

In [None]:
# mat_factor4's formula evaluated with the re-split Gamma_H.
(1 - zeta) * Gamma_H * ((1 - beta) + beta * D1.T**2)

In [None]:
# Diagonal matrix with entries Gamma_Y[j] * D2[j] (mat_factor6 as a matrix).
np.diag((Gamma_Y * D2.T)[0])

In [None]:
# M_hat_Y @ diag(D2) (mat_factor5's formula) with the re-split M_hat_Y.
M_hat_Y * D2.T