In [1]:
import sys
import os
# Move from this notebook's folder to the project's ``src`` directory (two
# levels up), presumably so that ``from WSMBSS import *`` below resolves.
# Guarded so re-running the cell does not keep climbing the directory tree.
if os.path.basename(os.getcwd()) != "src":
    os.chdir(os.path.join("..", "..", "src"))
# sys.path.append("./src")

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.io
from tqdm import tqdm
from scipy.stats import invgamma, chi2, t
from WSMBSS import *
from numba import njit
from IPython import display
import pylab as pl
import mne 
from mne.preprocessing import ICA
import warnings
notebook_name = 'Antisparse_Copula'

# Hide every warning category for the remainder of the kernel session.
warnings.simplefilter("ignore")

# Seed numpy's global RNG: the mixing matrix and noise drawn later (and
# presumably the copula sources) all come from np.random.
np.random.seed(1569)

# Uncomment while editing WSMBSS.py to pick up changes without restarting:
# %load_ext autoreload
# %autoreload 2

In [3]:
# --- Problem setup: correlated copula sources mixed by a random matrix ---
rho = 0.6          # correlation parameter passed to the copula source generator
N = 500000         # number of samples
NumberofSources = 4
NumberofMixtures = 8

M = NumberofMixtures
r = NumberofSources

SNR = 30  # target input SNR in dB
NoiseAmp = 10 ** (-SNR / 20)  # linear amplitude ratio for the target SNR

S = generate_correlated_copula_sources(rho = rho, df = 4, n_sources = NumberofSources, size_sources = N ,
                                       decreasing_correlation = True)
S = 2 * S - 1  # affine map; presumably takes sources from [0, 1] to [-1, 1]

INPUT_STD = 0.5
A = np.random.standard_normal(size=(NumberofMixtures, NumberofSources))
# Rescale each row of A so every noiseless mixture has standard deviation INPUT_STD.
X = A @ S
for MM in range(A.shape[0]):
    stdx = np.std(X[MM, :])
    A[MM, :] = A[MM, :] / stdx * INPUT_STD
Xn = A @ S
# White noise scaled by the same NoiseAmp defined above (single source of truth).
Noisecomp = np.random.randn(A.shape[0], S.shape[1]) * NoiseAmp * INPUT_STD
X = Xn + Noisecomp
SNRinp = 20 * np.log10(np.std(Xn) / np.std(Noisecomp))

print("The following is the mixture matrix A")
display_matrix(A)
print("Input SNR is : {}".format(SNRinp))

The following is the mixture matrix A


<IPython.core.display.Math object>

Input SNR is : 30.002407782786484


In [4]:
# Step-size (gamma) schedule handed to OnlineWSMBSS as gamma_start/gamma_stop:
# strongly correlated sources (rho > 0.4) use smaller start and stop values.
MUS, gamma_stop = (0.25, 5*1e-4) if rho > 0.4 else (0.6, 1e-3)

# Convergence tolerance for the neural dynamics, and an iteration budget
# (MAX_OUT_ITERATIONS is not used in the visible cells).
OUTPUT_COMP_TOL = 1e-6
MAX_OUT_ITERATIONS = 3000

# Two-element per-layer settings (presumably [hidden layer, output layer]).
LayerGains = [1, 1]
LayerMinimumGains = [0.2, 0.2]
LayerMaximumGains = [1e6, 5]
WScalings = [0.005, 0.005]
GamScalings = [2, 1]

# Scalar weights and per-layer gain rates forwarded to OnlineWSMBSS.
zeta = 5*1e-5
beta = 0.5
muD = [1.125, 0.2]

# Network dimensions: one hidden unit per source.
s_dim, x_dim = S.shape[0], X.shape[0]
h_dim = s_dim
samples = S.shape[1]

# Identity initialisation of the feedforward weights.
W_HX = np.eye(h_dim, x_dim)
W_YH = np.eye(s_dim, h_dim)

In [5]:
debug_iteration_point = 10000

# Construct the online WSM-BSS network. The ground-truth S and A are handed to
# the model (set_ground_truth=True), presumably for performance monitoring.
model = OnlineWSMBSS(
    s_dim=s_dim, x_dim=x_dim, h_dim=h_dim,
    gamma_start=MUS, gamma_stop=gamma_stop,
    beta=beta, zeta=zeta, muD=muD,
    WScalings=WScalings, GamScalings=GamScalings,
    W_HX=W_HX, W_YH=W_YH,
    DScalings=LayerGains,
    LayerMinimumGains=LayerMinimumGains,
    LayerMaximumGains=LayerMaximumGains,
    neural_OUTPUT_COMP_TOL=OUTPUT_COMP_TOL,
    set_ground_truth=True, S=S, A=A,
)

# Training is not run in this notebook; to train, call e.g.
# model.fit_batch_antisparse(X, n_epochs=1,
#                            neural_lr_start=0.75,
#                            neural_lr_stop=0.05,
#                            debug_iteration_point=debug_iteration_point,
#                            plot_in_jupyter=True)

In [6]:
@njit
def run_neural_dynamics_antisparse_jitV4(x_current, h, y, M_H, M_Y, W_HX, W_YH, D1, D2, beta, zeta,
                                         neural_dynamic_iterations, lr_start, lr_stop, OUTPUT_COMP_TOL):
    """Iterate the two-layer recurrent dynamics for a single input sample.

    Parameters
    ----------
    x_current : 1-D input sample of length x_dim.
    h, y : initial hidden / output activations, 1-D of length h_dim / s_dim.
    M_H, M_Y : square lateral weight matrices; each is split into its diagonal
        (Gamma) and off-diagonal (M_hat) parts.
    W_HX, W_YH : feedforward weights of shapes (h_dim, x_dim) and (s_dim, h_dim).
    D1, D2 : per-unit gains as column vectors of shape (h_dim, 1) and (s_dim, 1),
        e.g. ``np.diag(D).reshape(-1, 1)``.
    beta, zeta : scalar weighting coefficients.
    neural_dynamic_iterations : maximum number of iterations.
    lr_start, lr_stop : the step size at iteration k is
        max(lr_start / (1 + 0.005 k), lr_stop).
    OUTPUT_COMP_TOL : stop once the relative change of both membrane-voltage
        vectors v and u is at most this value.

    Returns
    -------
    (h, y, OutputCounter) : final activations (y clipped to [-1, 1]) and the
        number of iterations performed.
    """
    # Diagonal / off-diagonal split of the lateral weights.
    Gamma_H = np.diag(M_H)
    M_hat_H = M_H - np.diag(Gamma_H)
    Gamma_Y = np.diag(M_Y)
    M_hat_Y = M_Y - np.diag(Gamma_Y)

    # Gains as flat vectors for the elementwise (diagonal) terms.
    d1 = D1[:, 0]
    d2 = D2[:, 0]

    # Loop-invariant matrices; `D * X` scales rows and `X * D.T` scales columns of X.
    mat_factor1 = (1 - zeta) * beta * (D1 * W_HX)
    mat_factor2 = ((1 - zeta) * (1 - beta) * M_hat_H + (1 - zeta) * beta * ((D1 * M_hat_H) * D1.T))
    mat_factor3 = (1 - zeta) * (1 - beta) * (W_YH.T * D2.T)
    mat_factor5 = M_hat_Y * D2.T
    # Diagonal leak gains kept as 1-D vectors. Multiplying the (n,) diagonal by an
    # (n, 1) gain column would broadcast to a dense rank-one (n, n) matrix, and
    # `matrix @ h` would then couple every unit, so they are applied elementwise.
    gain_H = (1 - zeta) * Gamma_H * ((1 - beta) + beta * d1 ** 2)
    gain_Y = Gamma_Y * d2

    # Membrane voltages consistent with the initial activations.
    v = gain_H * h
    u = gain_Y * y

    v_prev = np.zeros_like(v)
    u_prev = np.zeros_like(u)
    not_settled = True
    OutputCounter = 0
    while not_settled and OutputCounter < neural_dynamic_iterations:
        OutputCounter += 1
        MUV = max(lr_start / (1 + OutputCounter * 0.005), lr_stop)

        delv = -v + mat_factor1 @ x_current - mat_factor2 @ h + mat_factor3 @ y
        v = v + MUV * delv
        h = v / gain_H

        delu = -u + W_YH @ h - mat_factor5 @ y
        u = u + MUV * delu
        y = np.clip(u / gain_Y, -1, 1)

        # Relative-change test written as a product rather than a ratio: under
        # numba's default error model a scalar 0.0 / 0.0 raises
        # ZeroDivisionError, which the ratio form hit whenever v or u was zero.
        not_settled = ((np.linalg.norm(v - v_prev) > OUTPUT_COMP_TOL * np.linalg.norm(v))
                       or (np.linalg.norm(u - u_prev) > OUTPUT_COMP_TOL * np.linalg.norm(u)))
        v_prev = v
        u_prev = u

    return h, y, OutputCounter

In [7]:
# Pull the current network state out of the model.
W_HX, W_YH = model.W_HX, model.W_YH
M_H, M_Y = model.M_H, model.M_Y
D1, D2 = model.D1, model.D2
# Column-vector forms of the diagonal gains, as used by the jitted helpers.
d1 = np.diag(D1).reshape(-1, 1)
d2 = np.diag(D2).reshape(-1, 1)

# Neural-dynamics solver settings.
neural_dynamic_iterations = 750
lr_start, lr_stop = 0.75, 0.05
OUTPUT_COMP_TOL = 1e-6


def ddiag(A):
    """Return a copy of the square matrix A with all off-diagonal entries zeroed."""
    return np.diag(np.diag(A))


# Split each lateral weight matrix into diagonal (Gamma) and off-diagonal (M_hat) parts.
Gamma_H, Gamma_Y = ddiag(M_H), ddiag(M_Y)
M_hat_H, M_hat_Y = M_H - Gamma_H, M_Y - Gamma_Y

# Zero-initialised activation buffers; only the first sample is used below.
H = np.zeros((h_dim, samples))
Y = np.zeros((s_dim, samples))
x_current = X[:, 0]  # a single input sample
y = Y[:, 0]
h = H[:, 0]

In [8]:
# Run the network dynamics on the first input sample, starting from the
# zero-initialised h and y of the previous cell.
h4,y4,_ =run_neural_dynamics_antisparse_jitV4(x_current, h, y, M_H, M_Y, W_HX, W_YH, d1, d2, beta, zeta, 
                                    neural_dynamic_iterations, lr_start, lr_stop, OUTPUT_COMP_TOL)

In [9]:
# WL1 = np.linalg.inv((1 -zeta) * beta * ((d1 * M_H) * d1.T) + (1 - zeta) * (1 - beta) * M_H - (1 -zeta) * (1 - beta) * W_YH.T @ np.linalg.inv(M_Y) @ W_YH) @ ((1 - zeta) * beta * (d1 * W_HX))
# WL1

In [10]:
# A = (1 -zeta) * (beta * ((d1 * M_H) * d1.T) +  (1 - beta) * (M_H -  W_YH.T @ np.linalg.solve(M_Y, W_YH)))
# b = ((1 - zeta) * beta * (d1 * W_HX))

In [11]:
# np.linalg.solve(A, b) - WL1

In [12]:
# XX = np.random.randn(W_YH.shape[0], W_YH.shape[1])
# WL2 = np.linalg.inv(D2) @ np.linalg.inv(M_Y) @ XX
# WL2

In [13]:
# np.linalg.solve(M_Y * d2.T, XX) - WL2

In [14]:
# np.linalg.inv(M_Y) @ W_YH

In [49]:
def compute_overall_mapping(beta, zeta, D1, D2, M_H, M_Y, W_HX, W_YH):
    """Closed-form linear separator W with y = W @ x at the dynamics' fixed point.

    Setting delv = delu = 0 in run_neural_dynamics_antisparse_jitV4 (ignoring the
    clip of y to [-1, 1]) gives h = WL1 @ x and y = WL2 @ h; this returns WL2 @ WL1.

    Parameters
    ----------
    beta, zeta : scalar weighting coefficients.
    D1, D2 : diagonal gain matrices (2-D) of the hidden and output layers.
    M_H, M_Y : lateral weight matrices of the hidden and output layers.
    W_HX, W_YH : feedforward weights x -> h and h -> y.

    Returns
    -------
    W : array of shape (s_dim, x_dim).
    """
    # x -> h: solve S_H @ WL1 = (1 - zeta) * beta * D1 @ W_HX with
    # S_H = (1-zeta)[beta D1 M_H D1 + (1-beta)(M_H - W_YH^T M_Y^{-1} W_YH)].
    # Linear solves replace explicit inverses for accuracy and speed.
    M_Y_inv_W_YH = np.linalg.solve(M_Y, W_YH)
    system = ((1 - zeta) * beta * D1 @ M_H @ D1
              + (1 - zeta) * (1 - beta) * M_H
              - (1 - zeta) * (1 - beta) * W_YH.T @ M_Y_inv_W_YH)
    WL1 = np.linalg.solve(system, (1 - zeta) * beta * D1 @ W_HX)

    # h -> y: WL2 = D2^{-1} M_Y^{-1} W_YH = (M_Y D2)^{-1} W_YH.
    WL2 = np.linalg.solve(M_Y @ D2, W_YH)

    # Separator
    return WL2 @ WL1

@njit
def compute_overall_mapping_jit(beta, zeta, D1, D2, M_H, M_Y, W_HX, W_YH):
    """Numba version of compute_overall_mapping taking gains as column vectors.

    D1 and D2 are (n, 1) columns, so ``D * X`` scales the rows of X and
    ``X * D.T`` scales its columns; no diagonal matrices are formed.
    Returns the (s_dim, x_dim) separator WL2 @ WL1.
    """
    scale = 1 - zeta
    # x -> h mapping from one linear solve.
    lateral_y = np.linalg.solve(M_Y, W_YH)
    system = scale * (beta * ((D1 * M_H) * D1.T) + (1 - beta) * (M_H - W_YH.T @ lateral_y))
    rhs = scale * beta * (D1 * W_HX)
    x_to_h = np.linalg.solve(system, rhs)
    # h -> y mapping: (M_Y diag(d2))^{-1} W_YH.
    h_to_y = np.linalg.solve(M_Y * D2.T, W_YH)
    return h_to_y @ x_to_h

In [50]:
# Check: the matrix-D and column-vector-D implementations should agree (expect all zeros).
compute_overall_mapping(beta, zeta, D1, D2, M_H, M_Y, W_HX, W_YH) - compute_overall_mapping_jit(beta, zeta, d1, d2, M_H, M_Y, W_HX, W_YH)

array([[0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.]])

In [51]:
# Timing: numba version (already compiled by the call in the previous cell).
%timeit compute_overall_mapping_jit(beta, zeta, d1, d2, M_H, M_Y, W_HX, W_YH)

26.2 µs ± 1.17 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [52]:
# Timing: NumPy reference version, for comparison with the numba one.
%timeit compute_overall_mapping(beta, zeta, D1, D2, M_H, M_Y, W_HX, W_YH)

70.3 µs ± 1.72 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


# Faster `snr` function with Numba

In [37]:
def snr(S_original, S_noisy):
    """Per-column signal-to-noise ratio in dB.

    Each column is treated as one signal. The noise is ``S_original - S_noisy``,
    and the result is ``10 * log10(sum(signal**2) / sum(noise**2))`` per column.
    """
    residual = S_original - S_noisy
    signal_power = np.sum(S_original ** 2, axis=0)
    noise_power = np.sum(residual ** 2, axis=0)
    return 10 * np.log10(signal_power / noise_power)

# Same computation as `snr`, compiled with numba; parallel=True lets numba
# parallelise the array expressions.
@njit(parallel=True)
def snr_jit(S_original, S_noisy):
    """Per-column SNR in dB: 10*log10(sum(S_original**2) / sum((S_original - S_noisy)**2))."""
    residual = S_original - S_noisy
    signal_power = (S_original ** 2).sum(axis=0)
    noise_power = (residual ** 2).sum(axis=0)
    return 10 * np.log10(signal_power / noise_power)

In [38]:
# Draw a fresh set of sources and add noise via addWGN (presumably white
# Gaussian) at SNR dB, then report the realised SNR.
# NOTE(review): this rebinds the global S and SNR that earlier cells used to build the model.
S = generate_correlated_copula_sources(rho = rho, df = 4, n_sources = NumberofSources, size_sources = N , 
                                       decreasing_correlation = True)
S = 2 * S - 1
SNR = 30
Snoisy, NoisePart = addWGN(S, SNR, return_noise = True)
# Realised SNR in dB: mean power of (Snoisy - NoisePart) over mean noise power.
10 * np.log10(np.sum(np.mean((Snoisy - NoisePart)**2, axis = 1)) / np.sum(np.mean(NoisePart**2, axis = 1)))

29.998920710595733

In [39]:
# Per-source SNR in dB (inputs transposed so that each source is a column).
snr_jit(S.T, Snoisy.T)

array([30.00611817, 29.9971599 , 29.99770972, 29.994705  ])

In [40]:
# NumPy reference; should match the numba result of the previous cell.
snr(S.T, Snoisy.T)

array([30.00611817, 29.9971599 , 29.99770972, 29.994705  ])

In [41]:
# Timing: numba version of the per-source SNR.
%timeit snr_jit(S.T, Snoisy.T)

5.99 ms ± 92.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [42]:
# Timing: NumPy reference version of the per-source SNR.
%timeit snr(S.T, Snoisy.T)

26.4 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [32]:
def CalculateSINR(Out,S, compute_permutation = True):
    """Signal-to-interference-plus-noise ratio of estimated sources (linear, not dB).

    Parameters
    ----------
    Out : (r, T) array of estimated sources, one per row.
    S : (r, T) array of true sources, one per row.
    compute_permutation : if True, output row k is matched to the source with the
        largest |G[k, :]| entry; otherwise output k is assumed to estimate source k.

    Returns
    -------
    SINR : SigPow / MSE.
    SigPow : squared Frobenius norm of the fitted signal part ZZ (includes the
        output means).
    MSE : squared Frobenius norm of the residual Out - ZZ.
    G : least-squares mixing estimate (Out - mean) @ pinv(S - mean).
    """
    n_src = S.shape[0]
    S_mu = S.mean(axis=1)
    Out_mu = Out.mean(axis=1)
    S_centered = S - S_mu.reshape(n_src, 1)
    Out_centered = Out - Out_mu.reshape(n_src, 1)

    # G is needed for the matching and is always returned.
    G = np.dot(Out_centered, np.linalg.pinv(S_centered))
    if compute_permutation:
        match = np.argmax(np.abs(G), 1)
    else:
        match = np.arange(0, n_src)

    # One scalar gain per output: projection of output k onto its matched source.
    GG = np.zeros((n_src, n_src))
    for k, j in enumerate(match):
        src_j = S_centered[j]
        GG[k, j] = np.dot(Out_centered[k], src_j) / np.dot(src_j, src_j)

    ZZ = GG @ S_centered + Out_mu.reshape(n_src, 1)
    E = Out - ZZ
    MSE = np.linalg.norm(E, 'fro') ** 2
    SigPow = np.linalg.norm(ZZ, 'fro') ** 2
    return SigPow / MSE, SigPow, MSE, G

@njit
def CalculateSINRjit(Out,S, compute_permutation = True):
    """Numba-compiled counterpart of CalculateSINR with the same inputs and outputs.

    Returns (SINR, SigPow, MSE, G). SINR is linear (not dB); G is the
    least-squares mixing estimate used to match outputs to sources.
    """
    n_src = S.shape[0]
    # Row means computed one row at a time (numba's mean does not take an axis).
    S_mu = np.array([S[i, :].mean() for i in range(S.shape[0])])
    Out_mu = np.array([Out[i, :].mean() for i in range(Out.shape[0])])
    S_centered = S - np.reshape(S_mu, (n_src, 1))
    Out_centered = Out - np.reshape(Out_mu, (n_src, 1))

    G = np.dot(Out_centered, np.linalg.pinv(S_centered))
    if compute_permutation:
        match = np.abs(G).argmax(1).astype(np.int64)
    else:
        match = np.arange(0, n_src)

    # Per-output projection gain onto the matched (centred) source.
    GG = np.zeros((n_src, n_src))
    for k in range(n_src):
        j = match[k]
        src_j = S[j, :] - S_mu[j]
        GG[k, j] = np.dot(Out[k, :] - Out_mu[k], src_j) / np.dot(src_j, src_j)

    ZZ = GG @ S_centered + np.reshape(Out_mu, (n_src, 1))
    E = Out - ZZ
    MSE = np.linalg.norm(E) ** 2
    SigPow = np.linalg.norm(ZZ) ** 2
    return SigPow / MSE, SigPow, MSE, G

In [33]:
# Frobenius-norm difference between the G matrices returned by the two
# implementations (expect round-off level, ~1e-15).
np.linalg.norm(CalculateSINRjit(Snoisy,S)[-1] - CalculateSINR(Snoisy,S)[-1])

1.7936506102875025e-15

In [36]:
# Timing: numba version (slower than the NumPy version in the recorded run).
%timeit CalculateSINRjit(Snoisy,S)

699 ms ± 147 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [35]:
# Timing: NumPy reference version of the SINR computation.
%timeit CalculateSINR(Snoisy,S)

300 ms ± 2.49 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [53]:
# Numba-decorated per-sample outer product.
# NOTE(review): `jit` is not imported explicitly in this notebook (only `njit`
# is); it presumably comes from `from WSMBSS import *`. The recorded %timeit
# results show no speed-up over the plain NumPy `outer_prod_broadcasting`.
@jit(parallel = True)
def outer_prod_broadcasting_jit(A, B):
    """Broadcasting trick: per-sample outer products.

    For 2-D A of shape (T, n) and B of shape (T, m) the result has shape
    (T, n, m) with out[t, i, j] = A[t, i] * B[t, j].
    """
    return A[...,None]*B[:,None]

def find_permutation_between_source_and_estimation(S,Y):
    """
    Match every true source to the estimated component most similar to it.

    S    : (n_samples, n_sources) original source matrix, one source per column.
    Y    : (n_samples, n_sources) estimated sources (output of a BSS or ICA
           algorithm), one estimate per column.

    Returns an integer array ``perm`` such that ``Y[:, perm[j]]`` is the estimate
    maximising |<Y_i, S_j>| / (||Y_i|| * ||S_j||), i.e. ``Y[:, perm]`` is aligned
    column-by-column with ``S``. This is the permutation of the source separation
    algorithm.
    """
    # Cross products C[i, j] = <Y_i, S_j>. Equivalent to summing the
    # (n_samples, n, n) per-sample outer-product tensor over samples, without
    # materialising it.
    cross = Y.T @ S
    # Normalise by the outer product of the column norms so that the scale of
    # each estimate does not bias the match (cosine similarity).
    norms = np.outer(np.linalg.norm(Y, axis=0), np.linalg.norm(S, axis=0))
    perm = np.argmax(np.abs(cross) / norms, axis=0)
    return perm

def signed_and_permutation_corrected_sources(S,Y):
    """Reorder and sign-flip the columns of Y so that they line up with the columns of S.

    Column j of the result is ``Y[:, perm[j]]`` multiplied by the sign of its inner
    product with ``S[:, j]``; a zero inner product yields a zero column.
    """
    perm = find_permutation_between_source_and_estimation(S,Y)
    Y_aligned = Y[:, perm]
    return np.sign((Y_aligned * S).sum(axis=0)) * Y_aligned

In [54]:
# Test data for the permutation / SINR helpers: Y is a copy of the sources with
# rows reordered as (2, 1, 0, 3), plus noise added by addWGN at 24 dB.
# NOTE(review): this rebinds the globals S, Y and SNRinp used by earlier cells.
S = generate_correlated_copula_sources(rho = rho, df = 4, n_sources = NumberofSources, size_sources = N , 
                                       decreasing_correlation = True)
S = 2 * S - 1

Y = S[[2,1,0,3],:]
Y, NoisePart = addWGN(Y, 24, return_noise = True)

# Realised SNR in dB: mean power of (Y - NoisePart) over mean noise power.
SNRinp = 10 * np.log10(np.sum(np.mean((Y - NoisePart)**2, axis = 1)) / np.sum(np.mean(NoisePart**2, axis = 1)))
print("Input SNR is : {}".format(SNRinp))

Input SNR is : 23.994731034308998


In [55]:
# Per-sample outer products of the S and Y columns, shape (N, 4, 4).
# NOTE(review): this displays a very large array; consider `.shape` or a slice instead.
outer_prod_broadcasting_jit(S.T, Y.T)

array([[[-6.20607790e-01, -3.55771856e-01,  5.04593164e-01,
         -6.53838588e-01],
        [ 4.19333880e-01,  2.40388850e-01, -3.40944818e-01,
          4.41787352e-01],
        [ 7.50545492e-01,  4.30260411e-01, -6.10240689e-01,
          7.90733879e-01],
        [ 7.99900030e-01,  4.58553571e-01, -6.50369032e-01,
          8.42731134e-01]],

       [[ 2.61191869e-01,  3.87944658e-01,  6.93774996e-01,
          6.33866176e-01],
        [ 1.61099330e-01,  2.39278600e-01,  4.27910286e-01,
          3.90959400e-01],
        [ 9.63370903e-02,  1.43088143e-01,  2.55889529e-01,
          2.33792971e-01],
        [ 2.65030466e-01,  3.93646073e-01,  7.03971037e-01,
          6.43181768e-01]],

       [[-7.01248803e-02,  5.81639816e-02,  3.32431936e-02,
         -1.16174565e-01],
        [-1.14771333e-01,  9.51952810e-02,  5.44081589e-02,
         -1.90139499e-01],
        [ 1.31382289e-01, -1.08972978e-01, -6.22826996e-02,
          2.17658556e-01],
        [ 2.31847457e-01, -1.92302235e-

In [58]:
# Timing: numba-decorated outer product.
%timeit outer_prod_broadcasting_jit(S.T, Y.T)

44.1 ms ± 1.45 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [59]:
# Timing: plain version (from WSMBSS) of the same outer product.
%timeit outer_prod_broadcasting(S.T, Y.T)

43.5 ms ± 140 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [35]:
# SINR in dB of the noisy, permuted copy Y relative to S (numba version).
10*np.log10(CalculateSINRjit(Y,S)[0])

24.005319336373987

In [36]:
# SINR in dB of the noisy, permuted copy Y relative to S (NumPy reference).
10*np.log10(CalculateSINR(Y,S)[0])

24.00531933637399