# Fitting a Gaussian Mixture to a Data Sample Using Features from a Pre-trained Model

Consider the set of all $K$-dimensional categorical distributions given by

\begin{equation}
    \mathcal{C} = \bigg\{ \mathbf{c} \in \mathbb{R}^{K}_{+} :  \mathbf{c} \geq 0,\ \mathbf{c}^{\rm T}\mathbf{1} = 1 \bigg\} \subset \mathbb{R}^{K}_{+}.
\end{equation}

In this code we consider the set of all mixtures of the component distributions $\{ \nu_{i} \}_{i = 1}^{K}$, given by
\begin{equation}
    \mathcal{GM} \bigg( \{ \nu_{i} \}_{i = 1}^{K} \bigg) = \bigg\{ \sum_{i = 1}^{K} c_{i} \nu_{i} \ : \ \mathbf{c} \in \mathcal{C} \bigg\}.
\end{equation}
When the components $\{ \nu_{i} \}_{i = 1}^{K}$ are Gaussian with means $\{ \mu_{i} \}_{i = 1}^{K}$ and covariance matrices $\{ \Sigma_{i} \}_{i = 1}^{K}$, the mean and covariance of the mixture $\sum_{k=1}^{K} c_{k} \nu_{k}$ are


\begin{gather*}
 \bar{\mu}_{\mathbf{c}} = \sum_{k =1}^{K} c_{k} \mu_{k}\  \text{,  }
 \,\,\,\, \tilde{\Sigma}_{\mathbf{c}} = \sum_{k =1}^{K} c_k \big( \Sigma_k + \mu_{k} \mu_{k}^{\top} - \bar{\mu}_{\mathbf{c}} \bar{\mu}_{\mathbf{c}}^\top \big)
\end{gather*}

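For a data sample $\nu_{*}$, we use a Frank-Wolfe optimization routine to find the weights $\mathbf{c} \in \mathcal{C}$ whose mixture mean and covariance $(\bar{\mu}_{\mathbf{c}}, \tilde{\Sigma}_{\mathbf{c}})$ best match the mean and covariance of $\nu_{*}$, i.e. the best-matching mixture in $\mathcal{GM}(\{ \nu_{i} \}_{i = 1}^{K})$. The mismatch is measured by the squared Euclidean distance between the means plus a Bures-type distance between the covariances.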

In [1]:
#%%
import warnings

import numpy as np
import scipy as sp
import scipy.linalg
import scipy.optimize
from scipy.optimize import check_grad
from scipy.optimize import approx_fprime
from matplotlib import pyplot as plt

class GM_FW:
    """Fit mixture weights to a sample's mean/covariance with Frank-Wolfe.

    The K component distributions have means ``M[:, k]`` and covariances
    ``covM[k]``.  For weights ``c`` on the probability simplex the mixture
    mean and covariance are

        mubar         = sum_k c_k mu_k
        tilde_sigma_c = sum_k c_k (covM[k] + mu_k mu_k^T - mubar mubar^T)

    and the ``FW_*_routine`` methods run Frank-Wolfe over ``c`` so that these
    match the source mean ``mux`` and covariance ``covx``.

    Parameters
    ----------
    mux : ndarray, shape (n, 1)
        Mean of the source (data sample).
    M : ndarray, shape (n, K)
        Means of the target distributions, one column per data class.
    covx : ndarray, shape (n, n)
        Covariance matrix of the source.
    covM : ndarray, shape (K, n, n)
        Covariance matrices of the target distributions.
    c0 : array_like, shape (K,)
        Initial mixture weights (Frank-Wolfe starting point), e.g.
        ``np.ones(K) / K``.
    eta : float
        Regularisation strength of the Bures term.  ``eta == 0`` uses the
        plain Bures expression; ``eta > 0`` adds the ``eta**2`` terms in
        :meth:`compute_bures` (presumably the entropy-regularised variant --
        confirm against the reference derivation).
    max_iter : int, optional
        Number of Frank-Wolfe iterations (default 1000).
    min_tol : float, optional
        Stored as ``self.min_tol`` (default 1e-6).  NOTE(review): no routine
        in this class currently reads it.
    print_iter : bool, optional
        If True, print the weights after every iteration (default False).
    """

    def __init__(self, mux, M, covx, covM, c0, eta, max_iter=None, min_tol=None, print_iter=None):
        # Problem data.
        self.mux = mux
        self.M = M
        self.covx = covx
        self.covM = covM
        self.c0 = c0
        self.eta = eta

        # Solver settings.
        self.max_iter = 1000 if max_iter is None else max_iter
        self.min_tol = 1e-6 if min_tol is None else min_tol
        # Only True (or a value comparing equal to it) enables printing;
        # None, False and any other value disable it.
        self.print_iter = bool(print_iter == True)  # noqa: E712

        # Derived matrices, working state and per-iteration history.
        self._allocate_state()

    def _allocate_state(self):
        """(Re)compute the derived matrices and zero all state/history buffers."""
        # sqrtm can return a complex array with negligible imaginary parts for
        # a real PSD input; keep the real part and re-symmetrise.
        sqrt_covx = np.real(sp.linalg.sqrtm(self.covx))
        self.sqrt_covx = (sqrt_covx + sqrt_covx.T) / 2
        self.I = np.identity(np.size(self.mux), dtype=np.float64)

        # Working state updated during Frank-Wolfe.
        self.c = np.zeros_like(self.c0, dtype=np.float64)
        self.mubar = np.zeros_like(self.mux, dtype=np.float64)
        # hat_covM[k] = covM[k] + mu_k mu_k^T (second-moment matrix of class k).
        n, K = self.M.shape
        mu_cols = self.M.T.reshape(K, n, 1)
        self.hat_covM = self.covM + mu_cols @ mu_cols.transpose(0, 2, 1)
        self.tilde_covM = np.zeros_like(self.covM, dtype=np.float64)
        self.tilde_sigma_c = np.zeros(self.covM.shape[1:], dtype=np.float64)
        self.zeta_c = np.zeros(self.covM.shape[1:], dtype=np.float64)

        # History: objective/gradient at the iterate before step i, and the
        # weights after step i.
        self.obj_vals = np.zeros(self.max_iter, dtype=np.float64)
        self.grads = np.zeros((self.max_iter, np.size(self.c)), dtype=np.float64)
        self.weights = np.zeros((self.max_iter, np.size(self.c)), dtype=np.float64)

    def reset_values(self):
        """Recompute derived matrices and clear all state/history buffers.

        Uses exactly the same initialisation as the constructor (including the
        real, symmetrised square root of ``covx`` and float64 buffers).
        """
        self._allocate_state()

    def initialize_weights(self):
        """Start the iterate at a float64 copy of ``c0`` (``c0`` is never mutated)."""
        self.c = np.array(self.c0, dtype=np.float64)

    def update_mubar(self):
        """Set ``mubar = M @ c`` as an (n, 1) column."""
        self.mubar = np.sum(self.c[np.newaxis, :] * self.M, axis=1, keepdims=True)

    def compute_mmd(self):
        """Return ``||mux - M c||^2``, the squared distance between the means."""
        # Flatten both sides so an (n, 1) mux and an (n,) M @ c cannot
        # broadcast into an (n, n) matrix.
        residual = np.ravel(self.mux) - self.M @ np.ravel(self.c)
        return np.linalg.norm(residual) ** 2

    def compute_mmd_grad(self):
        """Return the (K, 1) gradient ``2 M^T (M c - mux)`` of :meth:`compute_mmd`."""
        residual = self.M @ np.ravel(self.c) - np.ravel(self.mux)
        return (2 * (self.M.T @ residual))[:, np.newaxis]

    def update_tilde_covM(self):
        """Set ``tilde_covM[k] = hat_covM[k] - mubar mubar^T`` for every class k."""
        self.tilde_covM = self.hat_covM - (self.mubar @ self.mubar.T)[np.newaxis, :, :]

    def update_tilde_sigma_c(self):
        """Set the mixture covariance ``tilde_sigma_c = sum_k c_k tilde_covM[k]``."""
        self.tilde_sigma_c = np.sum(
            self.c[:, np.newaxis, np.newaxis] * self.tilde_covM, axis=0)

    def _zeta_from(self, tilde_sigma_c):
        """Return sym(real(sqrtm(S^{1/2} T S^{1/2} + eta^4/4 I))) with S=covx, T=tilde_sigma_c."""
        root = sp.linalg.sqrtm(
            self.sqrt_covx @ tilde_sigma_c @ self.sqrt_covx + self.eta ** 4 / 4 * self.I)
        root = np.real(root)
        return (root + root.T) / 2

    def update_zeta_c(self):
        """Refresh ``zeta_c`` from the current ``tilde_sigma_c``."""
        self.zeta_c = self._zeta_from(self.tilde_sigma_c)

    def _bures_from(self, tilde_sigma_c, zeta_c):
        """Evaluate the Bures term for a given mixture covariance and its zeta."""
        B = np.trace(self.covx + tilde_sigma_c - 2 * zeta_c)
        if self.eta == 0:
            return B
        sign, logdet = np.linalg.slogdet(2 * zeta_c + self.eta ** 2 * self.I)
        if sign <= 0:
            # Should not happen for PSD inputs; the log-determinant below is
            # then meaningless, so make the problem visible instead of silent.
            warnings.warn(
                "determinant of 2*zeta_c + eta^2*I is not positive; "
                "the Bures value is unreliable",
                RuntimeWarning,
            )
        B = B + self.covx.shape[0] * self.eta ** 2 * (1 - np.log(2 * self.eta ** 2))
        B = B + self.eta ** 2 * logdet
        return B

    def compute_bures(self):
        """Return the Bures term between ``covx`` and the cached ``tilde_sigma_c``.

        With ``eta == 0`` this is ``tr(covx + tilde_sigma_c - 2 zeta_c)``;
        otherwise ``n eta^2 (1 - log(2 eta^2)) + eta^2 logdet(2 zeta_c + eta^2 I)``
        is added.  Call :meth:`update_tilde_sigma_c` and :meth:`update_zeta_c`
        first.
        """
        return self._bures_from(self.tilde_sigma_c, self.zeta_c)

    def _bures_at(self, c):
        """Bures term at arbitrary weights ``c`` without touching instance state.

        The mixture covariance is formed as ``sum_k c_k hat_covM[k] - mubar mubar^T``,
        which agrees with :meth:`update_tilde_sigma_c` when ``sum(c) == 1`` and
        whose derivative in ``c`` is what :meth:`compute_bures_grad_c` returns
        (used for finite-difference gradient checks).
        """
        c = np.asarray(c, dtype=np.float64)
        mubar = np.sum(c[np.newaxis, :] * self.M, axis=1, keepdims=True)
        tilde_sigma_c = np.sum(c[:, np.newaxis, np.newaxis] * self.hat_covM, axis=0) - mubar @ mubar.T
        return self._bures_from(tilde_sigma_c, self._zeta_from(tilde_sigma_c))

    def compute_bures_grad_tilde_sigma_c(self):
        """Return ``I - S^{1/2} pinv(zeta_c + eta^2/2 I) S^{1/2}`` (gradient w.r.t. tilde_sigma_c)."""
        return self.I - self.sqrt_covx @ sp.linalg.pinv(self.zeta_c + self.eta ** 2 / 2 * self.I) @ self.sqrt_covx

    def compute_bures_grad_c(self):
        """Return the (K, 1) gradient of the Bures term w.r.t. the weights ``c``.

        Entry k is ``<G, hat_covM[k] - mu_k mubar^T - mubar mu_k^T>`` with
        ``G`` from :meth:`compute_bures_grad_tilde_sigma_c`.
        """
        G = self.compute_bures_grad_tilde_sigma_c()
        n, K = self.M.shape
        mu_cols = self.M.T.reshape(K, n, 1)
        rho = mu_cols @ self.mubar.T + self.mubar @ mu_cols.transpose(0, 2, 1)
        return np.sum(G[np.newaxis, :, :] * (self.hat_covM - rho), axis=(1, 2))[:, np.newaxis]

    def compute_frechet(self):
        """Return the Frechet objective: Bures term plus squared mean distance."""
        return self.compute_bures() + self.compute_mmd()

    def compute_frechet_grad(self):
        """Return the (K, 1) gradient of :meth:`compute_frechet` w.r.t. ``c``."""
        return self.compute_bures_grad_c() + self.compute_mmd_grad()

    def _refresh_covariance_state(self):
        """Update mubar, tilde_covM, tilde_sigma_c and zeta_c for the current ``c``."""
        self.update_mubar()
        self.update_tilde_covM()
        self.update_tilde_sigma_c()
        self.update_zeta_c()

    def _fw_step(self, g, i):
        """Apply Frank-Wolfe step ``i`` with gradient ``g`` to ``self.c``.

        The linear minimiser over the simplex is the vertex with the smallest
        gradient entry; the step size is ``2 / (2 + i)``.
        """
        s = np.zeros_like(self.c)
        s[np.argmin(g)] = 1
        gamma = 2 / (2 + i)
        self.c = (1 - gamma) * self.c + gamma * s

    def FW_frechet_routine(self):
        """Run Frank-Wolfe on the Frechet objective (Bures + squared mean gap)."""
        self.initialize_weights()
        for i in range(self.max_iter):
            self._refresh_covariance_state()
            self.obj_vals[i] = self.compute_frechet()
            g = self.compute_frechet_grad()
            self.grads[i, :] = np.squeeze(g)
            self._fw_step(g, i)
            self.weights[i, :] = self.c
            if self.print_iter:
                print("iter = ", i)
                print(self.c)

    def FW_bures_routine(self, check_gradient=True):
        """Run Frank-Wolfe on the Bures term only.

        Parameters
        ----------
        check_gradient : bool, optional
            If True (default, the historical behaviour), print the analytic
            gradient and a forward-difference approximation of it at every
            iteration.  Pass False to skip this K-evaluation-per-step check.
        """
        eps = np.sqrt(np.finfo(np.float64).eps)
        self.initialize_weights()
        for i in range(self.max_iter):
            self._refresh_covariance_state()
            self.obj_vals[i] = self.compute_bures()
            g = self.compute_bures_grad_c()

            if check_gradient:
                print("g")
                print(np.squeeze(g))
                ga = sp.optimize.approx_fprime(self.c, self._bures_at, eps * np.ones(np.size(g)))
                print("ga")
                print(ga)

            self.grads[i, :] = np.squeeze(g)
            self._fw_step(g, i)
            self.weights[i, :] = self.c

            if self.print_iter:
                print("iter =  ", i)
                print(self.c)

    def FW_MMD_routine(self):
        """Run Frank-Wolfe on the squared mean distance only."""
        self.initialize_weights()
        for i in range(self.max_iter):
            self.update_mubar()
            self.update_tilde_covM()
            self.obj_vals[i] = self.compute_mmd()
            g = self.compute_mmd_grad()
            self._fw_step(g, i)
            self.grads[i, :] = np.squeeze(g)
            self.weights[i, :] = self.c

            if self.print_iter:
                print("iter = ", i)
                print(self.c)


def Hellinger(p, q):
    """Return the Hellinger distance between two categorical distributions.

    H(p, q) = (1 / sqrt(2)) * || sqrt(p) - sqrt(q) ||_2

    Parameters
    ----------
    p, q : array_like
        Non-negative probability vectors with the same number of entries.
        Any shape is accepted (inputs are flattened), including plain lists.

    Returns
    -------
    float
        The Hellinger distance between ``p`` and ``q``.

    Raises
    ------
    AssertionError
        If either input has a negative entry or the sizes differ.
    """
    p = np.asarray(p, dtype=np.float64)
    q = np.asarray(q, dtype=np.float64)
    assert np.all(p >= 0), "first distribution must be non-negative"
    assert np.all(q >= 0), "second distribution must be non-negative"
    assert p.size == q.size, "input vectors must be of same size"

    return (1 / np.sqrt(2)) * np.linalg.norm(np.sqrt(p.ravel()) - np.sqrt(q.ravel()))


Testing on 2-D Gaussians, which are easy to visualize.

In [2]:
import numpy as np
import scipy as sp

# Toy 2-D problem: all three component means sit at the origin, so only the
# covariance part of the objective distinguishes the components.
# Variants tried earlier:
#   M = np.array([[0, -0.5, 0.5], [1, 0, 0]])  or  M = np.ones_like(M)
#   mux = np.mean(M, axis=1, keepdims=True)    or  mux = np.array([[1], [1]])
#   covx = 1/3 * covM[0] + 2/3 * covM[1] + 0/3 * covM[1]
M = np.zeros((2, 3), dtype=int)
mux = np.full((2, 1), np.sqrt(3) / 2)
covx = np.array([[10, 6],
                 [6, 8]])
covM = np.array(
    [
        [[1, 0.5], [0.5, 1]],
        [[5, 0.6], [0.6, 7]],
        [[4, 1], [1, 2]],
    ],
    dtype=np.float64,
)

iterations = 1000
c0 = np.full(3, 1 / 3)

test = GM_FW(mux, M, covx, covM, c0, eta=0.0, max_iter=iterations, print_iter=True)
test.reset_values()
# Alternative: test.FW_bures_routine()
test.FW_frechet_routine()

iter =  0
[0. 1. 0.]
iter =  1
[0.         0.33333333 0.66666667]
iter =  2
[0.         0.66666667 0.33333333]
iter =  3
[0.  0.8 0.2]
iter =  4
[0.         0.86666667 0.13333333]
iter =  5
[0.        0.9047619 0.0952381]
iter =  6
[0.         0.92857143 0.07142857]
iter =  7
[0.         0.94444444 0.05555556]
iter =  8
[0.         0.95555556 0.04444444]
iter =  9
[0.         0.78181818 0.21818182]
iter =  10
[0.         0.81818182 0.18181818]
iter =  11
[0.         0.84615385 0.15384615]
iter =  12
[0.         0.86813187 0.13186813]
iter =  13
[0.         0.88571429 0.11428571]
iter =  14
[0.  0.9 0.1]
iter =  15
[0.         0.91176471 0.08823529]
iter =  16
[0.         0.92156863 0.07843137]
iter =  17
[0.         0.92982456 0.07017544]
iter =  18
[0.         0.93684211 0.06315789]
iter =  19
[0.         0.94285714 0.05714286]
iter =  20
[0.         0.94805195 0.05194805]
iter =  21
[0.         0.95256917 0.04743083]
iter =  22
[0.         0.87318841 0.12681159]
iter =  23
[0.       

In [3]:
# Frank-Wolfe history: row i holds the mixture weights after iteration i.
test.weights

array([[0.        , 1.        , 0.        ],
       [0.        , 0.33333333, 0.66666667],
       [0.        , 0.66666667, 0.33333333],
       ...,
       [0.        , 0.95146248, 0.04853752],
       [0.        , 0.95155956, 0.04844044],
       [0.        , 0.95165634, 0.04834366]])

In [4]:
# Random high-dimensional test: num_ptx components in dim dimensions, each
# with a random PSD covariance of the form A @ A.T.
num_ptx = 10
dim = 100
jitter = 0  # diagonal shift added to every component covariance (disabled)

np.random.seed(2021)
# NOTE: the order of the np.random calls below fixes the generated data.
M = np.random.rand(dim, num_ptx)
mux = np.random.rand(dim)[:, np.newaxis]
A = np.random.rand(dim, dim)
covx = A @ A.T

covM = np.empty((num_ptx, dim, dim))
for i in range(num_ptx):
    A = np.random.rand(dim, dim)
    covM[i] = A @ A.T + jitter * np.identity(dim)

iterations = 1000
c0 = np.full(num_ptx, 1 / num_ptx)

test = GM_FW(mux, M, covx, covM, c0, eta=0, max_iter=iterations, print_iter=True)
test.FW_bures_routine()


g
[131.78478908 124.80934282 138.28859042 120.98625841 134.69296324
 121.94296326 119.74914852 119.41750775 131.74584938 127.89016922]
ga
[131.50952148 124.0515976  138.21163559 120.38177109 133.5706749
 121.866539   119.54069138 118.69256592 131.35587311 127.63020325]
iter =   0
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
g
[-754.41526927 -749.42806303 -702.68464944 -728.95897895 -701.02230022
 -787.1796086  -728.16284345  195.03671339 -715.76265185 -683.55267196]
ga
[-722.10423279 -700.56184006 -654.7858429  -754.05873108 -646.08842087
 -771.4319458  -683.09476471  211.75753784 -680.73534012 -657.09510422]
iter =   1
[0.         0.         0.         0.         0.         0.66666667
 0.         0.33333333 0.         0.        ]
g
[-20.61759858 -38.23502648  -4.88218065 -39.38048373 -19.99070089
 200.83788451 -47.05030243 105.04407249 -19.26296471 -25.23526756]
ga
[-20.33437347 -36.66832733  -5.47698212 -38.86569977 -19.63662338
 200.3463974  -46.09080505 106.53300476 -18.23244095 -25.38471985]
i

In [48]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import keras
from numpy.random import exponential
from sklearn.decomposition import PCA

np.random.seed(2021)

# --- Experiment settings ----------------------------------------------------
# N_samples: number of independent test samples to draw and fit.
N_samples = 1
# sample_sz: nominal size of each drawn sample; per-class counts are floored,
# so the realised size can be slightly smaller.
sample_sz = 400

# --- Load MNIST, flatten the images, scale pixels to [0, 1] -----------------
(X_train, y), (X_test, yt) = keras.datasets.mnist.load_data()

X_train = X_train / 255
y = y / 1  # labels as float64
X_test = X_test / 255
yt = yt / 1

X = X_train.reshape(X_train.shape[0], X_train.shape[1] * X_train.shape[2])
Xt = X_test.reshape(X_test.shape[0], X_test.shape[1] * X_test.shape[2])

# --- Project onto the leading principal directions of the training set ------
# NOTE: X @ P.T is an uncentred projection (pca.transform would subtract the
# training mean first).
PCA_components = 130
pca = PCA(n_components=PCA_components)
pca.fit(X)
# np.sum(pca.explained_variance_ratio_)
P = pca.components_
X_hat = X @ P.T
Xt_hat = Xt @ P.T

# --- Per-class training statistics and per-class test pools -----------------
# Assumes the labels are exactly 0 .. n_classes-1 (true for MNIST).
n_classes = np.size(np.unique(y))
means_i = np.zeros([n_classes, PCA_components])
covs_i = np.zeros([n_classes, PCA_components, PCA_components])

Training_Data = {}
Training_Lables = {}
for i in range(n_classes):
    Training_Data[str(i)] = X_hat[y == i]
    Training_Lables[str(i)] = y[y == i]
    means_i[i] = np.mean(Training_Data[str(i)], axis=0)
    covs_i[i] = np.cov(Training_Data[str(i)].T)

Testing_Data = {}
Testing_Lables = {}
for i in range(n_classes):
    Testing_Data[str(i)] = Xt_hat[yt == i]
    Testing_Lables[str(i)] = yt[yt == i]

FW_iterations = 10
Hellinger_metric = np.zeros([N_samples, FW_iterations])
Grads = np.zeros([N_samples, FW_iterations, n_classes])
Optimal_weights = np.zeros([N_samples, FW_iterations, n_classes])
Sample_distributions = np.zeros([N_samples, n_classes])

for sample_iter in range(N_samples):
    # Class proportions drawn uniformly on the probability simplex:
    # normalised i.i.d. Exp(1) variables are Dirichlet(1, ..., 1).
    sz = exponential(1, n_classes)
    p = sz / np.sum(sz)
    counts_sample = np.floor(sample_sz * p)
    Sample_distributions[sample_iter, :] = counts_sample / np.sum(counts_sample)

    # Draw counts_sample[k] distinct test points from each class k.
    picked = []
    picked_labels = []
    for k in range(n_classes):
        pool = Testing_Data[str(k)]
        idx = np.random.choice(np.arange(pool.shape[0]), size=int(counts_sample[k]), replace=False)
        picked.append(pool[idx])
        picked_labels.append(Testing_Lables[str(k)][idx])
    sample = np.concatenate(picked, axis=0)
    sample_labels = np.concatenate(picked_labels, axis=0)

    mean_x = np.mean(sample, axis=0)
    covx = np.cov(sample.T)

    mux = mean_x[np.newaxis, :].T
    M = means_i.T
    covM = covs_i

    # Frank-Wolfe starting from the uniform mixture.
    c0 = np.ones(n_classes) / n_classes

    MNIST = GM_FW(mux, M, covx, covM, c0, eta=0, max_iter=FW_iterations, print_iter=True)

    MNIST.FW_bures_routine()
    Optimal_weights[sample_iter] = MNIST.weights
    Grads[sample_iter] = MNIST.grads
    print("sample = ", sample_iter)

# for i in range(0, N_samples):
#     for j in range(0, FW_iterations):
#         Hellinger_metric[i, j] = Hellinger(Sample_distributions[i], Optimal_weights[i, j, :])

# np.savez("MNIST_data", Optimal_weights=Optimal_weights, Grads=Grads, \
#          Sample_distributions=Sample_distributions, Hellinger_metric=Hellinger_metric)


g
[-7.42340074 -6.95190473  2.50009402  3.59707874  1.09798436  2.20510135
 -2.33075807 -6.02562461  3.41699755 -2.83924741]
ga
[-7.42340857 -6.95190313  2.5000903   3.59708485  1.09797862  2.20509973
 -2.33076277 -6.02562341  3.41699937 -2.83924779]
iter =   0
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
g
[  20.49932518 -101.21588486  -43.41402083  -34.68102881  -68.2163266
  -26.06999409  -22.8051795   -88.54026392  -48.29881235  -82.30691392]
ga
[  20.49933767 -101.21583176  -43.41398883  -34.68100524  -68.21629071
  -26.06997871  -22.805161    -88.54021883  -48.29878974  -82.30687499]
iter =   1
[0.33333333 0.66666667 0.         0.         0.         0.
 0.         0.         0.         0.        ]
g
[  2.61162372   7.10146833 -45.8305643  -47.6776802  -67.6417745
 -42.25155992 -45.08077487 -80.20768193 -31.59505207 -76.21107197]
ga
[  2.61162484   7.10146749 -45.83054793 -47.67766154 -67.64174283
 -42.25154376 -45.08075249 -80.2076447  -31.59504402 -76.21104109]
iter =   2
[0.16666667 0.33333

In [9]:
from tempfile import TemporaryFile
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import exponential

import pickle

def LoadFile(file_name):
    """Unpickle and return the object stored in ``file_name``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this
    at trusted, locally generated files.
    """
    with open(file_name, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded


# --- Experiment settings ----------------------------------------------------
# N_samples: number of independent test samples to draw and fit.
N_samples = 500
# NOTE(review): Hellinger_distances is not used in this cell.
Hellinger_distances = np.zeros([N_samples, 1])
# sample_sz: nominal size of each drawn sample; per-class counts are floored,
# so the realised size can be slightly smaller.
sample_sz = 500

# --- Pre-computed Inception features for CIFAR-10 ---------------------------
X = LoadFile('x_train_Cifar10_inception_2048.pkl')
Xt = LoadFile('x_test_Cifar10_inception_2048.pkl')
y = np.squeeze(LoadFile('y_train_Cifar10_inception_2048.pkl'))
yt = np.squeeze(LoadFile('y_test_Cifar10_inception_2048.pkl'))

# --- Per-class training statistics and per-class test pools -----------------
# Assumes the labels are exactly 0 .. n_classes-1.
n_classes = np.size(np.unique(y))
means_i = np.zeros([n_classes, X.shape[1]])
covs_i = np.zeros([n_classes, X.shape[1], X.shape[1]])

Training_Data = {}
Training_Lables = {}
for i in range(n_classes):
    Training_Data[str(i)] = X[y == i]
    Training_Lables[str(i)] = y[y == i]
    means_i[i] = np.mean(Training_Data[str(i)], axis=0)
    covs_i[i] = np.cov(Training_Data[str(i)].T)

Testing_Data = {}
Testing_Lables = {}
for i in range(n_classes):
    Testing_Data[str(i)] = Xt[yt == i]
    Testing_Lables[str(i)] = yt[yt == i]

FW_iterations = 1000
Hellinger_metric = np.zeros([N_samples, FW_iterations])
Grads = np.zeros([N_samples, FW_iterations, n_classes])
Optimal_weights = np.zeros([N_samples, FW_iterations, n_classes])
Sample_distributions = np.zeros([N_samples, n_classes])

for sample_iter in range(N_samples):
    # Class proportions drawn uniformly on the probability simplex:
    # normalised i.i.d. Exp(1) variables are Dirichlet(1, ..., 1).
    sz = exponential(1, n_classes)
    p = sz / np.sum(sz)
    counts_sample = np.floor(sample_sz * p)
    Sample_distributions[sample_iter, :] = counts_sample / np.sum(counts_sample)

    # Draw counts_sample[k] distinct test points from each class k.
    picked = []
    picked_labels = []
    for k in range(n_classes):
        pool = Testing_Data[str(k)]
        idx = np.random.choice(np.arange(pool.shape[0]), size=int(counts_sample[k]), replace=False)
        picked.append(pool[idx])
        picked_labels.append(Testing_Lables[str(k)][idx])
    sample = np.concatenate(picked, axis=0)
    sample_labels = np.concatenate(picked_labels, axis=0)

    mean_x = np.mean(sample, axis=0)
    covx = np.cov(sample.T)

    mux = mean_x[np.newaxis, :].T
    M = means_i.T
    covM = covs_i

    # Frank-Wolfe starting from the uniform mixture.
    c0 = np.ones(n_classes) / n_classes

    CIFAR10_incpetion = GM_FW(mux, M, covx, covM, c0, eta=100, max_iter=FW_iterations, print_iter=False)

    CIFAR10_incpetion.FW_frechet_routine()
    Optimal_weights[sample_iter] = CIFAR10_incpetion.weights
    Grads[sample_iter] = CIFAR10_incpetion.grads
    print("sample = ", sample_iter)

# Hellinger distance between the true class proportions and every FW iterate.
for i in range(N_samples):
    for j in range(FW_iterations):
        Hellinger_metric[i, j] = Hellinger(Sample_distributions[i], Optimal_weights[i, j, :])

# Writes CIFAR10_Inception_data.npz to the working directory.
np.savez("CIFAR10_Inception_data", Optimal_weights=Optimal_weights, Grads=Grads,
         Sample_distributions=Sample_distributions, Hellinger_metric=Hellinger_metric)


In [55]:
# Sanity check on the class covariances: each should be exactly symmetric,
# and its eigenvalues non-negative up to round-off (tiny negative values
# indicate rank deficiency).
for i, cov in enumerate(covs_i):
    is_symmetric = np.all(cov == cov.T)
    eigenvalues = np.linalg.eigvalsh(cov)
    print(is_symmetric)
    print(eigenvalues)

True
[3.83375552e-14 1.12408687e-12 1.96340814e-04 ... 9.27784065e+04
 2.93879631e+05 3.53434081e+05]
True
[-1.81782230e-13 -2.71990769e-17  1.41385590e-03 ...  1.00987804e+05
  2.82825891e+05  2.95069214e+05]
True
[-1.92736377e-12  3.17267228e-04  4.88677242e-04 ...  6.88614494e+04
  1.95322356e+05  2.85996934e+05]
True
[-1.01653071e-14  7.37822609e-14  3.02138660e-04 ...  8.09922027e+04
  2.60595387e+05  2.68726680e+05]
True
[-4.58279472e-17  1.86354367e-15  1.61875919e-14 ...  6.79171180e+04
  1.97274161e+05  2.26980212e+05]
True
[-1.20422107e-14  7.18521108e-17  2.47701854e-12 ...  1.08263349e+05
  2.10049129e+05  2.83155286e+05]
True
[8.78894749e-15 2.00862706e-12 7.05356872e-04 ... 1.06125664e+05
 1.84869644e+05 2.55261990e+05]
True
[-2.15673930e-18  1.45737749e-14  2.83365166e-13 ...  9.34842451e+04
  2.09536203e+05  2.73619073e+05]
True
[-4.09352220e-13 -9.84400632e-14  1.77810270e-17 ...  9.32383240e+04
  2.27768808e+05  3.02323225e+05]
True
[1.15135440e-12 1.73264276e-04 2.69

In [64]:
# Frobenius distance between every pair of class covariances (row-major order).
n_cls = covs_i.shape[0]
for i in range(n_cls):
    for j in range(n_cls):
        gap = covs_i[i] - covs_i[j]
        print(np.linalg.norm(gap, ord='fro'))

0.0
171354.38769853624
163334.76941555872
156433.55390648486
186483.17991134652
198959.62698350614
219287.70661481735
192355.93137983905
171812.257392717
225244.67861852085
171354.38769853624
0.0
183973.04497587468
137040.79483937487
174330.90095647553
168470.84572068902
178332.3640390809
161124.75467175938
165910.85195796457
127089.63764524643
163334.76941555872
183973.04497587468
0.0
113643.08489506297
87899.8435381585
152830.3298313368
130726.36272509243
143896.83686515153
194906.3064224918
186771.5804566617
156433.55390648486
137040.79483937487
113643.08489506297
0.0
113176.73669471905
86993.20101188352
127905.2036329166
128177.77301419074
149654.0640443293
166405.6907993815
186483.17991134652
174330.90095647553
87899.8435381585
113176.73669471905
0.0
134035.27289964896
118721.55301428834
121699.7501063512
162356.46768927728
173223.21113889068
198959.62698350614
168470.84572068902
152830.3298313368
86993.20101188352
134035.27289964896
0.0
142213.75260295387
118748.01711187068
16634

In [71]:
# Euclidean distance between every pair of class means (columns of M).
# Iterate over M's own column count rather than covs_i.shape[0], so the loop
# cannot go out of range or skip columns if the two arrays ever differ.
n_means = M.shape[1]
for i in range(n_means):
    for j in range(n_means):
        print(np.linalg.norm(M[:, i] - M[:, j]))

0.0
341.51507470335383
422.4247859351434
377.53352322946137
454.9555941805132
389.5170101283486
548.8112469628584
427.859624925202
195.66708419437438
471.18522652555345
341.51507470335383
0.0
405.6970563031867
321.60428266701206
409.9680549422441
323.9156979529624
408.42127802194597
306.5818435677959
364.717596261466
238.39170641914615
422.4247859351434
405.6970563031867
0.0
177.35041300200206
96.28074076069788
261.4844844499535
255.47287413740645
242.40553452140895
504.6963227031405
439.1338662834322
377.53352322946137
321.60428266701206
177.35041300200206
0.0
179.11764944654195
136.0897772903951
250.62793918123165
266.1108231714618
418.7627735933172
412.4215054914685
454.9555941805132
409.9680549422441
96.28074076069788
179.11764944654195
0.0
272.56271103809877
222.34550652027463
281.56750207664334
513.6390230554665
460.07386063572505
389.5170101283486
323.9156979529624
261.4844844499535
136.0897772903951
272.56271103809877
0.0
299.93348624442456
295.16307079746946
430.3248723064433


In [70]:
# (feature_dim, n_classes): M = means_i.T, one column per class mean.
M.shape

(2048, 10)