In [1]:
import sys
import os

# Switch into the project's `src` directory so the `from WSMBSS import *` import
# below resolves. The absolute target is computed once (from the directory the
# kernel was started in) and reused afterwards. The previous chain of relative
# chdir("..") calls moved further up the tree on every re-run of this cell.
# NOTE(review): assumes the notebook sits two levels below the project root,
# next to which lives `src` -- confirm against the repository layout.
if "SRC_DIR" not in globals():
    SRC_DIR = os.path.abspath(os.path.join(os.getcwd(), "..", "..", "src"))
os.chdir(SRC_DIR)
# sys.path.append("./src")

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy.io
from tqdm import tqdm
from scipy.stats import invgamma, chi2, t
from WSMBSS import *
from numba import njit
from IPython import display
import pylab as pl
import mne 
from mne.preprocessing import ICA
import warnings
warnings.filterwarnings("ignore")
# Seed NumPy's legacy global RNG so the synthetic sources, mixing matrix and
# noise drawn below are reproducible on Restart & Run All.
RANDOM_SEED = 1569
np.random.seed(RANDOM_SEED)

# Uncomment to live-reload edited project modules during development:
# %load_ext autoreload
# %autoreload 2

# Identifier for this experiment (not used in the visible cells).
notebook_name = 'Antisparse_Copula'

In [3]:
# Problem configuration: source correlation, sample count and sizes.
rho = 0.6
N = 500000
NumberofSources = 4
NumberofMixtures = 8

# Short aliases for the number of mixtures and sources.
M = NumberofMixtures
r = NumberofSources

SNR = 30  # target input SNR in dB
# Noise amplitude relative to the mixture amplitude (reused for the noise below
# instead of re-deriving 10 ** (-SNR / 20) inline).
NoiseAmp = 10 ** (-SNR / 20)  # * np.sqrt(NumberofSources)

# Correlated copula sources mapped affinely by 2*S - 1.
# NOTE(review): assumes the generator returns values in [0, 1], so that S lands
# in [-1, 1] (the antisparse box used by the neural dynamics) -- confirm.
S = generate_correlated_copula_sources(rho=rho, df=4, n_sources=NumberofSources,
                                       size_sources=N, decreasing_correlation=True)
S = 2 * S - 1

# Random mixing matrix, rescaled row by row so that each noiseless mixture
# has standard deviation INPUT_STD.
INPUT_STD = 0.5
A = np.random.standard_normal(size=(NumberofMixtures, NumberofSources))
A = A / np.std(A @ S, axis=1, keepdims=True) * INPUT_STD
Xn = A @ S

# Additive white Gaussian noise scaled to the requested SNR.
Noisecomp = np.random.randn(A.shape[0], S.shape[1]) * NoiseAmp * INPUT_STD
X = Xn + Noisecomp
SNRinp = 20 * np.log10(np.std(Xn) / np.std(Noisecomp))

print("The following is the mixture matrix A")
display_matrix(A)
print("Input SNR is : {}".format(SNRinp))

The following is the mixture matrix A


<IPython.core.display.Math object>

Input SNR is : 30.002407782786484


In [4]:
# Start/stop values later passed to OnlineWSMBSS as gamma_start / gamma_stop;
# strongly correlated sources (rho > 0.4) use the smaller pair.
MUS, gamma_stop = (0.25, 5*1e-4) if rho > 0.4 else (0.6, 1e-3)

# Neural-dynamics convergence tolerance and iteration cap.
OUTPUT_COMP_TOL = 1e-6
MAX_OUT_ITERATIONS = 3000

# Per-layer gain initialisation and bounds, plus weight / Gamma scalings.
LayerGains = [1, 1]
LayerMinimumGains = [0.2, 0.2]
LayerMaximumGains = [1e6, 5]
WScalings = [0.005, 0.005]
GamScalings = [2, 1]

zeta = 5*1e-5
beta = 0.5
muD = [1.125, 0.2]

# Network dimensions read off the data; the hidden layer has one unit per source.
s_dim, samples = S.shape[0], S.shape[1]
x_dim = X.shape[0]
h_dim = s_dim

# Identity initialisation of the two feed-forward weight matrices.
W_HX = np.eye(h_dim, x_dim)
W_YH = np.eye(s_dim, h_dim)

In [5]:
# Debug interval for the (currently disabled) batch fit below.
debug_iteration_point = 10000

# Online WSM-BSS network; the ground-truth S and A are handed to the model
# via set_ground_truth=True.
model = OnlineWSMBSS(
    # layer sizes
    s_dim=s_dim, x_dim=x_dim, h_dim=h_dim,
    # schedule and hyperparameters
    gamma_start=MUS, gamma_stop=gamma_stop, beta=beta, zeta=zeta,
    muD=muD, WScalings=WScalings, GamScalings=GamScalings,
    # initial feed-forward weights
    W_HX=W_HX, W_YH=W_YH,
    # layer gains and their bounds
    DScalings=LayerGains,
    LayerMinimumGains=LayerMinimumGains,
    LayerMaximumGains=LayerMaximumGains,
    neural_OUTPUT_COMP_TOL=OUTPUT_COMP_TOL,
    # ground truth
    set_ground_truth=True, S=S, A=A,
)

# Batch fit, disabled. NOTE(review): this used to reference an undefined
# `modelWSM`; the network built above is named `model`.
# model.fit_batch_antisparse(X, n_epochs = 1,
#                            neural_lr_start = 0.75,
#                            neural_lr_stop = 0.05,
#                            debug_iteration_point = debug_iteration_point,
#                            plot_in_jupyter = True,
#                            )

In [6]:
@njit
def run_neural_dynamics_antisparse_jit(x_current, h, y, M_H, M_Y, W_HX, W_YH, D1, D2, beta, zeta, 
                                       neural_dynamic_iterations, lr_start, lr_stop, OUTPUT_COMP_TOL):
    """Reference (un-optimised) antisparse neural dynamics for one input.

    Alternates the hidden-layer voltage update (v -> h) and the output-layer
    voltage update (u -> y, then clipped to [-1, 1]). It stops once both v and u
    change by at most OUTPUT_COMP_TOL in relative norm, or after
    neural_dynamic_iterations steps. Every matrix product is recomputed on each
    pass. The step size decays as lr_start / (1 + 0.005 k), floored at lr_stop.

    NOTE(review): x_current/h/y are presumably 1-D vectors and D1/D2 diagonal
    gain matrices -- inferred from the products below, confirm at call sites.

    NOTE(review): the initial v omits the (1 - zeta) factor that the in-loop
    h <- v mapping uses. This only matters when h starts non-zero.

    Returns (h, y, number_of_iterations_run).
    """
    # Diagonal / off-diagonal split of both lateral matrices.
    gam_h = np.diag(np.diag(M_H))
    lat_h = M_H - gam_h
    gam_y = np.diag(np.diag(M_Y))
    lat_y = M_Y - gam_y

    v = ((1 - beta) * gam_h + beta * D1 @ gam_h @ D1) @ h
    u = gam_y @ D2 @ y

    prev = {'v': np.zeros_like(v), 'u': np.zeros_like(u)}
    keep_going = True
    n_iter = 0
    while keep_going and n_iter < neural_dynamic_iterations:
        n_iter += 1
        step = max(lr_start / (1 + n_iter * 0.005), lr_stop)

        # Hidden layer (all operators rebuilt every iteration in this version).
        dv = (-v + (1 - zeta) * beta * D1 @ W_HX @ x_current
              - ((1 - zeta) * (1 - beta) * lat_h + (1 - zeta) * beta * D1 @ lat_h @ D1) @ h
              + (1 - zeta) * (1 - beta) * W_YH.T @ D2 @ y)
        v = v + step * dv
        h = v / np.diag(gam_h * ((1 - zeta) * (1 - beta) + (1 - zeta) * beta * D1 ** 2))

        # Output layer followed by the [-1, 1] box constraint.
        du = -u + W_YH @ h - lat_y @ D2 @ y
        u = u + step * du
        y = u / np.diag(gam_y * D2)
        y = y * (y >= -1.0) * (y <= 1.0) + (y > 1.0) * 1.0 - 1.0 * (y < -1.0)

        # Keep iterating while either voltage still moves by more than the tolerance.
        keep_going = ((np.linalg.norm(v - prev['v']) / np.linalg.norm(v) > OUTPUT_COMP_TOL)
                      or (np.linalg.norm(u - prev['u']) / np.linalg.norm(u) > OUTPUT_COMP_TOL))
        prev['v'] = v
        prev['u'] = u

    return h, y, n_iter


@njit
def run_neural_dynamics_antisparse_jitV2(x_current, h, y, M_H, M_Y, W_HX, W_YH, D1, D2, beta, zeta, 
                                         neural_dynamic_iterations, lr_start, lr_stop, OUTPUT_COMP_TOL):
    """Antisparse neural dynamics with the loop-invariant operators precomputed.

    Follows the same recursion and relative-change stopping rule as the
    reference version. Every matrix that stays fixed inside the loop is built
    once up front. D1 and D2 are given as full (diagonal) matrices.

    NOTE(review): the initial v omits the (1 - zeta) factor used by the in-loop
    h <- v mapping. This only matters when h starts non-zero.

    Returns (h, y, number_of_iterations_run).
    """
    gam_h = np.diag(np.diag(M_H))
    lat_h = M_H - gam_h
    gam_y = np.diag(np.diag(M_Y))
    lat_y = M_Y - gam_y

    v = ((1 - beta) * gam_h + beta * D1 @ gam_h @ D1) @ h
    u = gam_y @ D2 @ y

    # Operators that do not change across iterations.
    ff_h = (1 - zeta) * beta * D1 @ W_HX                     # applied to x_current
    lateral_h = ((1 - zeta) * (1 - beta) * lat_h  + (1- zeta) * beta * D1 @ lat_h @ D1)  # applied to h
    fb_h = (1 - zeta) * (1 - beta) * W_YH.T @ D2             # applied to y in the hidden update
    gain_h = (1 - zeta) * gam_h * ((1 - beta) + beta * D1 ** 2)  # its diagonal maps v -> h
    lateral_y = lat_y @ D2                                   # applied to y in the output update
    gain_y = gam_y * (D2)                                    # its diagonal maps u -> y

    prev = {'v': np.zeros_like(v), 'u': np.zeros_like(u)}
    keep_going = True
    n_iter = 0
    while keep_going and n_iter < neural_dynamic_iterations:
        n_iter += 1
        step = max(lr_start / (1 + n_iter * 0.005), lr_stop)

        dv = -v + ff_h @ x_current - lateral_h @ h + fb_h @ y
        v = v + step * dv
        h = v / np.diag(gain_h)

        du = -u + W_YH @ h - lateral_y @ y
        u = u + step * du
        y = u / np.diag(gain_y)
        y = y * (y >= -1.0) * (y <= 1.0) + (y > 1.0) * 1.0 - 1.0 * (y < -1.0)

        keep_going = ((np.linalg.norm(v - prev['v']) / np.linalg.norm(v) > OUTPUT_COMP_TOL)
                      or (np.linalg.norm(u - prev['u']) / np.linalg.norm(u) > OUTPUT_COMP_TOL))
        prev['v'] = v
        prev['u'] = u

    return h, y, n_iter


@njit
def run_neural_dynamics_antisparse_jitV3(x_current, h, y, M_H, M_Y, W_HX, W_YH, D1, D2, beta, zeta, 
                                         neural_dynamic_iterations, lr_start, lr_stop, OUTPUT_COMP_TOL):
    """Antisparse neural dynamics with precomputed operators and fused updates.

    The loop-invariant operators are built once, and each layer's voltage
    increment is a single expression. The initial voltages come from the same
    diagonal gain matrices that map v -> h and u -> y inside the loop. D1 and
    D2 are full (diagonal) matrices.

    Returns (h, y, number_of_iterations_run).
    """
    gam_h = np.diag(np.diag(M_H))
    lat_h = M_H - gam_h
    gam_y = np.diag(np.diag(M_Y))
    lat_y = M_Y - gam_y

    ff_h = (1 - zeta) * beta * D1 @ W_HX                     # applied to x_current
    lateral_h = ((1 - zeta) * (1 - beta) * lat_h  + (1- zeta) * beta * D1 @ lat_h @ D1)  # applied to h
    fb_h = (1 - zeta) * (1 - beta) * W_YH.T @ D2             # applied to y in the hidden update
    gain_h = (1 - zeta) * gam_h * ((1 - beta) + beta * D1 ** 2)  # its diagonal maps v -> h
    lateral_y = lat_y @ D2                                   # applied to y in the output update
    gain_y = gam_y * (D2)                                    # its diagonal maps u -> y

    # Initial voltages consistent with the in-loop h <- v and y <- u maps.
    v = gain_h @ h
    u = gain_y @ y

    prev = {'v': np.zeros_like(v), 'u': np.zeros_like(u)}
    keep_going = True
    n_iter = 0
    while keep_going and n_iter < neural_dynamic_iterations:
        n_iter += 1
        step = max(lr_start / (1 + n_iter * 0.005), lr_stop)

        dv = -v + ff_h @ x_current - lateral_h @ h + fb_h @ y
        v = v + step * dv
        h = v / np.diag(gain_h)

        du = -u + W_YH @ h - lateral_y @ y
        u = u + step * du
        y = u / np.diag(gain_y)
        y = y * (y >= -1.0) * (y <= 1.0) + (y > 1.0) * 1.0 - 1.0 * (y < -1.0)

        keep_going = ((np.linalg.norm(v - prev['v']) / np.linalg.norm(v) > OUTPUT_COMP_TOL)
                      or (np.linalg.norm(u - prev['u']) / np.linalg.norm(u) > OUTPUT_COMP_TOL))
        prev['v'] = v
        prev['u'] = u

    return h, y, n_iter

@njit
def run_neural_dynamics_antisparse_jitV4(x_current, h, y, M_H, M_Y, W_HX, W_YH, D1, D2, beta, zeta, 
                                         neural_dynamic_iterations, lr_start, lr_stop, OUTPUT_COMP_TOL):
    """Broadcasting variant of the antisparse neural dynamics.

    Same recursion and stopping rule as the diagonal-matrix variants. Here D1
    and D2 are the diagonal gains given as column vectors of shape (n, 1),
    e.g. np.diag(D).reshape(-1, 1), so every diagonal matrix product becomes
    an elementwise broadcast.

    Fixes:
      * The v -> h and u -> y gains were built as Gamma (n,) * d (n, 1), which
        broadcasts to a dense (n, n) matrix. Only its diagonal was right, so the
        initial voltages `v = gain @ h` / `u = gain @ y` picked up spurious
        off-diagonal terms whenever h or y started non-zero. The gains are now
        1-D per-unit vectors applied elementwise.
      * The nested `offdiag` helper referenced an undefined `diag` in its
        `return_diag=False` branch. The diagonal split is now done inline, and
        the unused `ddiag` helper is removed.

    Returns (h, y, number_of_iterations_run).
    """
    # Column vectors -> flat per-unit gains (basic slicing works for any layout).
    d1_vec = D1[:, 0]
    d2_vec = D2[:, 0]

    # Diagonal (as 1-D vectors) and off-diagonal parts of the lateral matrices.
    Gamma_H = np.diag(M_H)
    M_hat_H = M_H - np.diag(Gamma_H)
    Gamma_Y = np.diag(M_Y)
    M_hat_Y = M_Y - np.diag(Gamma_Y)

    # Diagonal scalings written as broadcasts:
    # D1 * W == diag(d1) @ W and W * D2.T == W @ diag(d2).
    mat_factor1 = (1 - zeta) * beta * (D1 * W_HX)
    mat_factor2 = ((1 - zeta) * (1 - beta) * M_hat_H  + (1- zeta) * beta * ((D1 * M_hat_H) * D1.T))
    mat_factor3 = (1 - zeta) * (1 - beta) * (W_YH.T * D2.T)
    mat_factor5 = M_hat_Y * D2.T
    # Per-unit voltage -> activation gains, kept 1-D so they apply elementwise.
    gain_h = (1 - zeta) * Gamma_H * ((1 - beta) + beta * d1_vec ** 2)
    gain_y = Gamma_Y * d2_vec

    v = gain_h * h
    u = gain_y * y

    PreviousMembraneVoltages = {'v': np.zeros_like(v), 'u': np.zeros_like(u)}
    MembraneVoltageNotSettled = True
    OutputCounter = 0
    while MembraneVoltageNotSettled and OutputCounter < neural_dynamic_iterations:
        OutputCounter += 1
        MUV = max(lr_start / (1 + OutputCounter * 0.005), lr_stop)

        delv = -v + mat_factor1 @ x_current - mat_factor2 @ h + mat_factor3 @ y
        v = v + MUV * delv
        h = v / gain_h

        delu = -u + W_YH @ h - mat_factor5 @ y
        u = u + MUV * delu
        y = u / gain_y
        y = np.clip(y, -1, 1)

        # Relative-change stopping rule on both voltage vectors.
        MembraneVoltageNotSettled = (
            (np.linalg.norm(v - PreviousMembraneVoltages['v']) / np.linalg.norm(v) > OUTPUT_COMP_TOL)
            or (np.linalg.norm(u - PreviousMembraneVoltages['u']) / np.linalg.norm(u) > OUTPUT_COMP_TOL))
        PreviousMembraneVoltages['v'] = v
        PreviousMembraneVoltages['u'] = u

    return h, y, OutputCounter

@njit( parallel=True )
def run_neural_dynamics_antisparse_jitV5(x_current, h, y, M_H, M_Y, W_HX, W_YH, D1, D2, beta, zeta, 
                                         neural_dynamic_iterations, lr_start, lr_stop, OUTPUT_COMP_TOL):
    """Broadcasting antisparse neural dynamics, compiled with numba parallel=True.

    D1 and D2 are the diagonal gains given as column vectors of shape (n, 1),
    e.g. np.diag(D).reshape(-1, 1). Every diagonal matrix product is an
    elementwise broadcast.

    Fixes:
      * The v -> h and u -> y gains were built as Gamma (n,) * d (n, 1), which
        broadcasts to a dense (n, n) matrix. Only its diagonal was right, so the
        initial voltages picked up spurious off-diagonal terms whenever h or y
        started non-zero. The gains are now 1-D vectors applied elementwise.
      * The nested `offdiag` helper referenced an undefined `diag` in its
        `return_diag=False` branch. The diagonal split is now inline, and the
        unused `ddiag` helper is removed.

    Returns (h, y, number_of_iterations_run).
    """
    # Column vectors -> flat per-unit gains (basic slicing works for any layout).
    d1_vec = D1[:, 0]
    d2_vec = D2[:, 0]

    # Diagonal (as 1-D vectors) and off-diagonal parts of the lateral matrices.
    Gamma_H = np.diag(M_H)
    M_hat_H = M_H - np.diag(Gamma_H)
    Gamma_Y = np.diag(M_Y)
    M_hat_Y = M_Y - np.diag(Gamma_Y)

    # Diagonal scalings written as broadcasts:
    # D1 * W == diag(d1) @ W and W * D2.T == W @ diag(d2).
    mat_factor1 = (1 - zeta) * beta * (D1 * W_HX)
    mat_factor2 = ((1 - zeta) * (1 - beta) * M_hat_H  + (1- zeta) * beta * ((D1 * M_hat_H) * D1.T))
    mat_factor3 = (1 - zeta) * (1 - beta) * (W_YH.T * D2.T)
    mat_factor5 = M_hat_Y * D2.T
    # Per-unit voltage -> activation gains, kept 1-D so they apply elementwise.
    gain_h = (1 - zeta) * Gamma_H * ((1 - beta) + beta * d1_vec ** 2)
    gain_y = Gamma_Y * d2_vec

    v = gain_h * h
    u = gain_y * y

    PreviousMembraneVoltages = {'v': np.zeros_like(v), 'u': np.zeros_like(u)}
    MembraneVoltageNotSettled = True
    OutputCounter = 0
    while MembraneVoltageNotSettled and OutputCounter < neural_dynamic_iterations:
        OutputCounter += 1
        MUV = max(lr_start / (1 + OutputCounter * 0.005), lr_stop)

        delv = -v + mat_factor1 @ x_current - mat_factor2 @ h + mat_factor3 @ y
        v = v + MUV * delv
        h = v / gain_h

        delu = -u + W_YH @ h - mat_factor5 @ y
        u = u + MUV * delu
        y = u / gain_y
        y = np.clip(y, -1, 1)

        # Relative-change stopping rule on both voltage vectors.
        MembraneVoltageNotSettled = (
            (np.linalg.norm(v - PreviousMembraneVoltages['v']) / np.linalg.norm(v) > OUTPUT_COMP_TOL)
            or (np.linalg.norm(u - PreviousMembraneVoltages['u']) / np.linalg.norm(u) > OUTPUT_COMP_TOL))
        PreviousMembraneVoltages['v'] = v
        PreviousMembraneVoltages['u'] = u

    return h, y, OutputCounter

In [7]:
# Snapshot of the network state used by the standalone benchmarks.
W_HX, W_YH = model.W_HX, model.W_YH
M_H, M_Y = model.M_H, model.M_Y
D1, D2 = model.D1, model.D2
# Column-vector form of the diagonal gains, as taken by the broadcasting variants.
d1, d2 = (np.diag(D).reshape(-1, 1) for D in (D1, D2))

neural_dynamic_iterations = 750
lr_start, lr_stop = 0.75, 0.05
OUTPUT_COMP_TOL = 1e-6

def ddiag(A):
    """Return a matrix that keeps only the main diagonal of ``A``."""
    return np.diag(np.diag(A))

# Diagonal / off-diagonal split of the lateral matrices.
Gamma_H = ddiag(M_H)
M_hat_H = M_H - Gamma_H
Gamma_Y = ddiag(M_Y)
M_hat_Y = M_Y - Gamma_Y

# Zero activation buffers; the benchmarks use the first sample (column 0).
# Note: x_current, y and h are strided column views into X, Y and H, not copies.
H = np.zeros((h_dim, samples))
Y = np.zeros((s_dim, samples))
x_current = X[:, 0]
y = Y[:, 0]
h = H[:, 0]

In [23]:
# Run every variant on the same sample and starting state. V1-V3 take the
# diagonal gain matrices D1/D2; V4 takes the column vectors d1/d2.
tail_args = (beta, zeta, neural_dynamic_iterations, lr_start, lr_stop, OUTPUT_COMP_TOL)
h1, y1, _ = run_neural_dynamics_antisparse_jit(x_current, h, y, M_H, M_Y, W_HX, W_YH, D1, D2, *tail_args)
h2, y2, _ = run_neural_dynamics_antisparse_jitV2(x_current, h, y, M_H, M_Y, W_HX, W_YH, D1, D2, *tail_args)
h3, y3, _ = run_neural_dynamics_antisparse_jitV3(x_current, h, y, M_H, M_Y, W_HX, W_YH, D1, D2, *tail_args)
h4, y4, _ = run_neural_dynamics_antisparse_jitV4(x_current, h, y, M_H, M_Y, W_HX, W_YH, d1, d2, *tail_args)

# V5 (parallel=True) is currently disabled:
# h5, y5, _ = run_neural_dynamics_antisparse_jitV5(x_current, h, y, M_H, M_Y, W_HX, W_YH, d1, d2, *tail_args)

In [9]:
# Norm of the difference between consecutive variants (expected ~0 if they agree).
tuple(np.linalg.norm(a - b) for a, b in ((h1, h2), (y1, y2), (h2, h3), (y2, y3), (h3, h4), (y3, y4)))

(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)

In [10]:
# Benchmark the reference version (all products recomputed every iteration; D1/D2 as matrices).
%timeit run_neural_dynamics_antisparse_jit(x_current, h, y, M_H, M_Y, W_HX, W_YH, D1, D2, beta, zeta, neural_dynamic_iterations, lr_start, lr_stop, OUTPUT_COMP_TOL)

197 µs ± 27.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [11]:
# Benchmark V4 (column-vector gains d1/d2, broadcasting instead of diagonal matmuls).
%timeit run_neural_dynamics_antisparse_jitV4(x_current, h, y, M_H, M_Y, W_HX, W_YH, d1, d2, beta, zeta, neural_dynamic_iterations, lr_start, lr_stop, OUTPUT_COMP_TOL)

104 µs ± 82.9 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [12]:
# Benchmark V2 (loop-invariant matrix products precomputed before the loop).
%timeit run_neural_dynamics_antisparse_jitV2(x_current, h, y, M_H, M_Y, W_HX, W_YH, D1, D2, beta, zeta, neural_dynamic_iterations, lr_start, lr_stop, OUTPUT_COMP_TOL)

149 µs ± 20.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [13]:
# Benchmark V3 (precomputed products plus single-expression voltage updates).
%timeit run_neural_dynamics_antisparse_jitV3(x_current, h, y, M_H, M_Y, W_HX, W_YH, D1, D2, beta, zeta, neural_dynamic_iterations, lr_start, lr_stop, OUTPUT_COMP_TOL)

109 µs ± 98.4 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [24]:
# %timeit run_neural_dynamics_antisparse_jitV5(x_current, h, y, M_H, M_Y, W_HX, W_YH, d1, d2, beta, zeta, neural_dynamic_iterations, lr_start, lr_stop, OUTPUT_COMP_TOL)

In [25]:
# d1 is a column vector, which V4/V5 rely on for D1 * W_HX and D1.T broadcasting.
d1.shape

(4, 1)

In [29]:
# np.diag on a matrix yields a flat 1-D vector, hence the reshape(-1, 1) used to build d1.
np.diag(D1).shape

(4,)