In [1]:
import pandas as pd
import numpy as np

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

import seaborn as sns

from scipy import signal, linalg

import warnings

from sklearn.datasets import load_digits

# Implementation of FastICA

http://www.ccs.neu.edu/home/jaa/CS6800.11F/Topics/Papers/Hyvarinen97.pdf

https://www.cs.helsinki.fi/u/ahyvarin/papers/NN00new.pdf

https://www.cs.helsinki.fi/u/ahyvarin/papers/bookfinal_ICA.pdf

## Test Dataset

In [3]:
digits = load_digits()
features = digits.images.reshape(-1, 64)
labels = digits.target

## Loss Function (negative entropy)

Using log cos approximation

In [51]:
def tanh(x, alpha = 1.0):

    x *= alpha
    
    gx = np.tanh(x, x)
    g_x = np.empty(x.shape[0])
    
    for i, gx_i in enumerate(gx):
        g_x[i] = (alpha * (1 - gx_i ** 2)).mean()
        
    return gx, g_x


## Whitening Components (using SVD)

In [42]:
# Whiten and pre-process components
def whiten_components(X, n_components):
    
    n, p = X.shape

    X_mean = X.mean(axis=-1)
    
    # Subtract the mean for 0 mean
    X -= X_mean[:, np.newaxis]
    
    # Preprocessing by PCA
    u, d, _ = linalg.svd(X, full_matrices=False)
    
    # Whitening matrix
    whitening = (u / d).T[:n_components]
    
    # Project data onto the principal components using the whitening matrix
    X1 = np.dot(whitening, X)
    X1 *= np.sqrt(p)
    
    # Return whitened components, whitening matrix, and mean of components
    return X1, whitening, X_mean

## Symmetric Decorrelation

In [43]:
# Symmetric decorrelation of un_mixing matrix
# https://ieeexplore.ieee.org/document/398721/
# Ensures no vectors are privileged over others
def symmetric_decorrelation(un_mixing):
    
    # Find eigenvalues and eigenvectors of initial weight matrix
    eig_values, eig_vectors = linalg.eigh(np.dot(un_mixing, un_mixing.T))
    # Symmetric decorrelation equation
    sym_un_mixing = np.dot(np.dot(eig_vectors * (1 / np.sqrt(eig_values)), eig_vectors.T), un_mixing)
    
    return sym_un_mixing


# Parallel ICA Implementation

In [44]:
def parallel_ica(X, init_un_mixing, alpha = 1.0, max_iter = 1000, tol = 1e-4, return_iter = False):
    
    # Symmetric decorrelation of initial un-mixing components 
    un_mixing = symmetric_decorrelation(init_un_mixing)
    
    
    p = float(X.shape[1])
    
    # Iteratively update the un-mixing matrix
    for i in range(max_iter):
        
        # Function and derivative 
        gwtx, g_wtx = tanh(np.dot(un_mixing, X), alpha)
        
        
        new_un_mixing = symmetric_decorrelation(np.dot(gwtx, X.T) / p - g_wtx[:, np.newaxis] * un_mixing)
        
        # Calculate negative entropy based on logcosh
        lim = max(abs(abs(np.diag(np.dot(new_un_mixing, un_mixing.T))) - 1))
        
        # Update un-mixing 
        un_mixing = new_un_mixing

        # Check for convergence
        if lim < tol:
            break
            
    else:
        warnings.warn('FastICA algorithm did not converge. Considering increasing '
                      'tolerance or increasing the maximum number of iterations.')
        
    if return_iter:
        return un_mixing, i + 1
    else: 
        return un_mixing

# Complete Algorithm

In [54]:
# X = mixing * sources
# sources = un-mixing * whitening * X
def perform_fastica(X, n_components, alpha = 1.0, max_iter = 200, tol = 1e-4):
    
    X1 = X.T
    # Whiten components by subtracting mean
    X1, whitening, X_mean = whiten_components(X1, n_components)
    
    # initial un_mixing components
    init_un_mixing = np.asarray(np.random.normal(size = (n_components, n_components)))
    
    # Solve ica using the parallel ica algorithm
    un_mixing = parallel_ica(X1, init_un_mixing, alpha, max_iter, tol)

    # Calculate the sources
    sources = np.dot(np.dot(un_mixing, whitening), X.T).T
    
    # Calculate the mixing matrix
    w = np.dot(un_mixing, whitening)
    mixing = linalg.pinv(w)
    
    # Return mixing matrix, sources, and mean of X
    return mixing, sources, X_mean

## Inverse ICA Transform

In [55]:
def inverse_fastica(mixing, source, X_mean):
    # Inverse transform
    X = np.dot(sources, mixing.T)
    # Add back in mean
    X += X_mean
    
    return X

# Comparison to Scikit-learn ICA

In [56]:
from sklearn.decomposition import FastICA

In [57]:
ica = FastICA(n_components=2)
ica.fit(features)

FastICA(algorithm='parallel', fun='logcosh', fun_args=None, max_iter=200,
    n_components=2, random_state=None, tol=0.0001, w_init=None,
    whiten=True)

In [58]:
ica.mean_

array([ 0.00000000e+00,  2.00229371e-15, -8.04352566e-15,  4.64205438e-15,
        7.26803765e-16,  7.61154572e-17, -3.33375818e-16, -1.31225025e-16,
       -5.86929256e-18, -5.38929303e-15,  3.63278319e-16,  6.40556056e-16,
        7.88832920e-16, -4.32573040e-15,  6.69470044e-16,  3.92470326e-17,
        3.00275131e-17, -3.07007248e-15, -2.52070670e-16, -5.53566962e-17,
       -3.54931567e-15, -7.37924530e-16, -1.65328704e-16,  3.44735988e-16,
       -1.51559035e-19, -5.06859749e-16, -1.41703257e-15,  3.50427657e-16,
       -2.05116330e-16,  4.91043550e-16,  7.30411836e-15,  6.02374763e-19,
        0.00000000e+00, -2.87805394e-15, -5.91130434e-16, -1.35426203e-16,
       -1.04040933e-15, -3.36341355e-16,  9.26260861e-15,  0.00000000e+00,
       -3.26865536e-18,  1.92018540e-16, -6.81777024e-15, -1.00828268e-16,
       -3.80256020e-15,  2.59978769e-16,  3.83789952e-16, -3.22193271e-17,
       -9.16912854e-17,  5.05129852e-16, -1.28506616e-17,  3.67136606e-15,
        3.45040264e-15,  

In [59]:
mixing, sources, mean = perform_fastica(features, n_components=2)

(64, 1797)


In [60]:
mean

array([ 0.00000000e+00,  2.00229371e-15, -8.04352566e-15,  4.64205438e-15,
        7.26803765e-16,  7.61154572e-17, -3.33375818e-16, -1.31225025e-16,
       -5.86929256e-18, -5.38929303e-15,  3.63278319e-16,  6.40556056e-16,
        7.88832920e-16, -4.32573040e-15,  6.69470044e-16,  3.92470326e-17,
        3.00275131e-17, -3.07007248e-15, -2.52070670e-16, -5.53566962e-17,
       -3.54931567e-15, -7.37924530e-16, -1.65328704e-16,  3.44735988e-16,
       -1.51559035e-19, -5.06859749e-16, -1.41703257e-15,  3.50427657e-16,
       -2.05116330e-16,  4.91043550e-16,  7.30411836e-15,  6.02374763e-19,
        0.00000000e+00, -2.87805394e-15, -5.91130434e-16, -1.35426203e-16,
       -1.04040933e-15, -3.36341355e-16,  9.26260861e-15,  0.00000000e+00,
       -3.26865536e-18,  1.92018540e-16, -6.81777024e-15, -1.00828268e-16,
       -3.80256020e-15,  2.59978769e-16,  3.83789952e-16, -3.22193271e-17,
       -9.16912854e-17,  5.05129852e-16, -1.28506616e-17,  3.67136606e-15,
        3.45040264e-15,  

In [41]:
ica.mixing_.shape

(64, 2)

In [18]:
sources.shape

(1797, 2)

In [21]:
mixing.shape

(64, 2)

In [23]:
ica.whitening_.shape

(2, 64)

In [24]:
X_ica = ica.transform(features)

In [25]:
X_ica.shape

(1797, 2)

In [26]:
X_ica 

array([[-0.03295949, -0.02683043],
       [ 0.03488186,  0.01407798],
       [ 0.01912196,  0.00174326],
       ...,
       [ 0.01965296, -0.00693867],
       [-0.0253482 , -0.01104789],
       [-0.01172893, -0.00928464]])

In [27]:
sources

array([[-0.03028885, -0.02503701],
       [ 0.03763316,  0.01573736],
       [ 0.02184893,  0.00343378],
       ...,
       [ 0.02236279, -0.00524918],
       [-0.0226464 , -0.00926953],
       [-0.00902368, -0.00753317]])

In [29]:
ica.mean_.shape

(64,)

In [30]:
mean.shape

(64,)

In [31]:
ica.mean_

array([0.00000000e+00, 3.03839733e-01, 5.20478575e+00, 1.18358375e+01,
       1.18480801e+01, 5.78185865e+00, 1.36227045e+00, 1.29660545e-01,
       5.56483027e-03, 1.99387869e+00, 1.03823038e+01, 1.19794101e+01,
       1.02793545e+01, 8.17584864e+00, 1.84641068e+00, 1.07957707e-01,
       2.78241514e-03, 2.60155815e+00, 9.90317195e+00, 6.99276572e+00,
       7.09794101e+00, 7.80634391e+00, 1.78853645e+00, 5.00834725e-02,
       1.11296605e-03, 2.46967168e+00, 9.09126322e+00, 8.82136895e+00,
       9.92710072e+00, 7.55147468e+00, 2.31775181e+00, 2.22593211e-03,
       0.00000000e+00, 2.33945465e+00, 7.66722315e+00, 9.07178631e+00,
       1.03016138e+01, 8.74401781e+00, 2.90929327e+00, 0.00000000e+00,
       8.90372844e-03, 1.58375070e+00, 6.88146912e+00, 7.22815804e+00,
       7.67223150e+00, 8.23650529e+00, 3.45631608e+00, 2.72676683e-02,
       7.23427935e-03, 7.04507513e-01, 7.50695604e+00, 9.53923205e+00,
       9.41624930e+00, 8.75848637e+00, 3.72509738e+00, 2.06455203e-01,
      

In [32]:
mean

array([ 0.00000000e+00, -1.23910033e-15,  4.92378042e-15,  3.92044030e-15,
        1.36118162e-15, -1.73286229e-15, -1.47214214e-15, -4.54993738e-16,
       -5.88087669e-18,  5.09442238e-15,  1.78920750e-15, -1.86828850e-15,
        5.33005903e-15, -7.62835043e-15,  2.86989872e-15,  5.59220017e-16,
        3.26460092e-17,  3.32535582e-15,  1.96615123e-15, -3.14050399e-15,
        5.26580572e-15,  2.44508550e-15, -2.36130907e-16, -4.61071545e-16,
       -1.87662907e-18, -1.02063909e-15, -2.78365101e-15, -3.34908012e-15,
       -3.08514730e-15, -4.52738694e-15, -4.62722669e-15, -1.04257171e-17,
        0.00000000e+00,  1.74126465e-15, -4.07316547e-15, -2.69863894e-16,
       -2.91018060e-15, -4.08255634e-16, -5.83790729e-15,  0.00000000e+00,
       -1.76542142e-17,  1.04411626e-15,  5.59893441e-15, -5.58509524e-16,
        6.54592932e-15, -2.27357859e-15,  1.97331794e-15,  1.56482290e-16,
        6.86340399e-17, -5.92118946e-16,  8.58028791e-16, -5.83420037e-15,
        3.58138054e-15,  

In [34]:
features.T.mean(axis=-1)

array([ 0.00000000e+00,  2.08279569e-15, -8.82445048e-15, -1.53812534e-14,
        9.35379888e-16, -1.18028384e-15, -1.05004733e-15, -1.89331022e-16,
        8.84544864e-18, -5.38936254e-15,  3.83542823e-16,  2.41987843e-15,
        2.39022306e-15,  1.12141793e-14,  1.52972299e-15,  1.56339419e-16,
       -8.26291170e-17, -2.91981859e-15,  7.84878871e-16, -2.08674974e-15,
       -9.57862368e-15,  1.46892947e-15, -2.17966991e-16,  3.28425533e-16,
       -1.71083121e-18, -7.84878871e-16, -1.30681343e-15,  3.83542823e-16,
       -2.05116330e-16,  2.33733764e-15,  8.44313181e-15, -8.13012861e-18,
        0.00000000e+00, -3.19833197e-15, -1.16792744e-15, -1.35426203e-16,
       -1.04040933e-15, -3.92192307e-16,  9.35577591e-15,  0.00000000e+00,
       -1.62486731e-17,  3.13605569e-16, -6.98285182e-15, -1.32584230e-16,
       -2.92228988e-15,  9.61822596e-16,  8.38999926e-16,  1.26471671e-16,
       -9.25697486e-17,  5.05624109e-16,  4.98210265e-16,  3.13114402e-15,
       -6.17523716e-15,  