# Fitting Gaussian Mixture on a data sample using features from a pre-trained Model

Consider the set of all $K$-dimensional categorical distributions given by

\begin{equation}
    \mathcal{C} = \bigg\{ \mathbf{c} \in \mathbb{R}^{K}_{+} :  \mathbf{c} \geq 0,\ \mathbf{c}^{\rm T}\mathbf{1} = 1 \bigg\} \subset \mathbb{R}^{K}_{+}.
\end{equation}

In this code we consider the set of gaussian mixtures for the set of distributions $\{ \nu_{i} \}_{i = 1}^{K} $ given by given by:
\begin{equation}
		\mathcal{GM} \bigg(\mathbf{c}, \{ \nu_{i} \}_{i = 1}^{K} \bigg) = \bigg\{ \sum_{i = 1}^{K} c_{i} \nu_{i}, \ \forall \mathbf{c} \in \mathcal{C} \bigg\}
	\end{equation}
For the mixture of gaussian distributions $\{ \nu_{i} \}_{i = 1}^{K} $ with means $\{ \mu_{i} \}_{i = 1}^{K} $ and covariance matrices $\{ \Sigma_{i} \}_{i = 1}^{K} $, mixture means and covariance are


\begin{gather*}
 \bar{\mu}_{\mathbf{c}} = \sum_{k =1}^{K} c_{k} \mu_{k}\  \text{,  }
 \,\,\,\, \Sigma_{\mathbf{c}} = \sum_{k =1}^{K} c_k \big( \Sigma_k + \mu_{k} \mu_{k}^{\top} - \bar{\mu}_{\mathbf{c}} \bar{\mu}_{\mathbf{c}}^\top \big)
\end{gather*}

For a data sample $\nu_{*}$, we employed Frank-Wolfe Based optimization routine to find the best matching mean, covaraince pair from the the set of all possible mixtures of $\{ \nu_{i} \}_{i = 1}^{K} $ to fit with mean  and covaraince of $\nu_{*}$.

In [43]:
#%%
import numpy as np
import scipy as sp
from scipy.optimize import check_grad
from scipy.optimize import approx_fprime
from matplotlib import pyplot as plt
import tensorflow as tf
#%%
'''
    mx: mean of source with size: n x 1
    M: means of target distributions with size: n x K, where K indicates number of data-classes.
    covx: covariance matrix of source with size: n x 1
    covM: array of covariance matrices of target distributions with size: K x n x n, where K indicates number of data-classes.
    max_iter: max_iterations of Frank Wolfe
    tol: tolerance for gradient check
    divg: divergences to be used for computations
    max_iter: maximum of frank wolfe iterations
    min_tol: minimum tolerance for gardient checking
    print_iter: priniting values of maixture weights for each iteration
'''
class GM_FW:
    def __init__(self, mux, M, covx, covM, c0,  max_iter = None, min_tol = None, print_iter = None):
        # instance attributes which should be initilzied with class
        self.mux = mux
        self.M = M
        self.covx = covx
        self.covM = covM
        self.c0 = c0

        
        self.sqrt_covx = sp.linalg.sqrtm(covx)
        self.I = np.identity(np.size(mux))

        # attributes to be keep track of computations
        self.c = np.zeros_like(c0)
        self.mubar = np.zeros_like(mux)
        self.hat_covM = self.covM + self.M.T.reshape([self.M.shape[1], self.M.shape[0], 1]) @ self.M.T.reshape([M.shape[1], 1, M.shape[0]])
        self.tilde_covM = np.zeros_like(covM)
        self.tilde_sigma_c = np.zeros([covM.shape[1], covM.shape[2]])
        self.zeta_c = np.zeros([covM.shape[1], covM.shape[2]])
        
        if max_iter == None:
            self.max_iter = 1000
        else:
            self.max_iter = max_iter

        if min_tol == None:
            self.min_tol = 1e-6
        else:
            self.min_tol = min_tol
        
        if print_iter == None:
            self.print_iter = False
        elif print_iter == False:
            self.print_iter = False
        elif print_iter == True:
            self.print_iter = True
        else:
            self.print_iter = False

        self.obj_vals = np.zeros([self.max_iter])
        self.grads = np.zeros([self.max_iter, np.size(self.c)])
        self.weights = np.zeros([self.max_iter, np.size(self.c)])
    
    def reset_values(self):
        # self.mux = mux
        # self.M = M
        # self.covx = covx
        # self.covM = covM
        # self.c0 = c0
        # self.divg = divg

        self.sqrt_covx = sp.linalg.sqrtm(self.covx)
        self.I = np.identity(np.size(self.mux))

        # attributes to be keep track of computations
        self.c = np.zeros_like(self.c0)
        self.mubar = np.zeros_like(self.mux)
        self.hat_covM = self.covM + \
            self.M.T.reshape([self.M.shape[1], self.M.shape[0], 1]
                         ) @ self.M.T.reshape([self.M.shape[1], 1, self.M.shape[0]])
        self.tilde_covM = np.zeros_like(self.covM)
        self.tilde_sigma_c = np.zeros([self.covM.shape[1], self.covM.shape[2]])
        self.zeta_c = np.zeros([self.covM.shape[1], self.covM.shape[2]])

        # if max_iter == None:
        #     self.max_iter = 1000
        # else:
        #     self.max_iter = max_iter

        # if min_tol == None:
        #     self.min_tol = 1e-6
        # else:
        #     self.min_tol = min_tol

        # if print_iter == None:
        #     self.print_iter = False
        # elif print_iter == False:
        #     self.print_iter = False
        # elif print_iter == True:
        #     self.print_iter = True
        # else:
        #     self.print_iter = False

        self.obj_vals = np.zeros([self.max_iter])
        self.grads = np.zeros([self.max_iter, np.size(self.c)])
        self.weights = np.zeros([self.max_iter, np.size(self.c)])

    def initialize_weights(self):
        self.c = self.c0

    def update_mubar(self):
        self.mubar =  np.sum(self.c[np.newaxis,:] * self.M, axis = 1, keepdims = True)
    
    def compute_mmd(self):
        return np.linalg.norm(self.mux - self.M@self.c)**2
    
    def compute_mmd_grad(self):
        return 2*self.M.T @ (self.M @ self.c[:,np.newaxis] - self.mux)

    def update_tilde_covM(self):
        self.tilde_covM = self.hat_covM - (self.mubar@self.mubar.T)[np.newaxis,:,:]

        
    def update_tilde_sigma_c(self):
        self.tilde_sigma_c = np.sum(
            self.c[:, np.newaxis, np.newaxis] * self.tilde_covM, axis=0)
            
    def update_zeta_c(self):
        temp = sp.linalg.sqrtm(self.sqrt_covx @ self.tilde_sigma_c @ self.sqrt_covx)
        self.zeta_c = (temp + temp.T)/2

    def compute_bures(self):
        temp =  np.trace( self.covx + self.tilde_sigma_c - 2 * self.zeta_c )
        return temp
    
    def compute_bures_grad_tilde_sigma_c(self):
        return self.I - self.sqrt_covx @ sp.linalg.pinv(self.zeta_c) @ self.sqrt_covx
    
    def compute_bures_grad_c(self):
        G = self.compute_bures_grad_tilde_sigma_c()
        rho = (self.M.T.reshape(self.M.shape[1], self.M.shape[0],1) @ self.mubar.T) + self.mubar @ self.M.T.reshape(self.M.shape[1],1,self. M.shape[0])
        temp = np.sum(G[np.newaxis, :, :] * (self.hat_covM - rho ), axis=(1, 2))[:, np.newaxis]
        return temp
    
    def compute_frechet(self):
        B = self.compute_bures()
        mmd = self.compute_mmd()
        return B + mmd

    def compute_frechet_grad(self):
        GB = self.compute_bures_grad_c()
        Gmmd = self.compute_mmd_grad()
        return GB + Gmmd


    def FW_frechet_routine(self):
        self.initialize_weights()
        for i in range(0, self.max_iter):
            self.update_mubar()
            self.update_tilde_covM()
            self.update_tilde_sigma_c()
            self.update_zeta_c()
            self.compute_bures_grad_c()
            self.obj_vals[i] = self.compute_frechet()
            g = self.compute_frechet_grad()
            #print(g)
            self.grads[i, :] = np.squeeze(g)
            s = np.zeros_like(self.c)
            idx = np.argmin(g)
            s[idx] = 1
            gamma = 2 / (2 + i)
            self.c = (1 - gamma) * self.c + gamma * s
            
            self.weights[i, :] = self.c
            if self.print_iter == True:
                print("iter = i")
                print(self.c)
    
    def FW_bures_routine(self):

        # commented part must be uncommented for gradient-checking
        
        # def bures_grad_num(c_i, epsilon):
            
        #     num_grad = np.zeros_like(c_i)
        #     I = np.identity(c_i.size)
        #     for i in range(0, num_grad.size):
        #         num_grad[i] = (bures(c_i + epsilon * I[:,i]) - bures(c_i) ) / epsilon
            
        #     return num_grad
            

        # def bures(c_i):
            
        #     mubar =  np.sum(c_i[np.newaxis,:] * self.M, axis = 1, keepdims = True)
        #     # k, n, n - n,n => k,n,n
        #     tilde_covM = self.hat_covM 
        #     #k,.,. * k,n,n => n,n
        #     tilde_sigma_c = np.sum(c_i[:, np.newaxis, np.newaxis] * tilde_covM, axis=0) - mubar@mubar.T

        #     #n,n
        #     zeta_ci = sp.linalg.sqrtm(self.sqrt_covx @ tilde_sigma_c @ self.sqrt_covx)

        #     #n,n
        #     zeta_ci = (zeta_ci + zeta_ci.T)/2

        #     #scalar
        #     temp =  np.trace( tilde_sigma_c + self.covx  - 2 * zeta_ci )
        #     return temp
        
        # def bures_grad(c_i):

        #     mubar =  np.sum(c_i[np.newaxis,:] * self.M, axis = 1, keepdims = True)

        #     tilde_covM = self.hat_covM - (mubar@mubar.T)[np.newaxis,:,:]

        #     tilde_sigma_c = np.sum(c_i[:, np.newaxis, np.newaxis] * tilde_covM, axis=0)

        #     zeta_ci = sp.linalg.sqrtm(self.sqrt_covx @ tilde_sigma_c @ self.sqrt_covx)
        #     zeta_ci = (zeta_ci + zeta_ci.T)/2

        #     G = self.I - self.sqrt_covx @ np.linalg.pinv(zeta_ci) @ self.sqrt_covx

        #     rho = (M.T.reshape(M.shape[1],M.shape[0],1) @ mubar.T) +  mubar @ M.T.reshape(M.shape[1],1,M.shape[0])

        #     return np.sum(G[np.newaxis, :, :] * (self.hat_covM - rho ), axis=(1, 2))
            

        self.initialize_weights()
        for i in range(0, self.max_iter):
            self.update_mubar()
            self.update_tilde_covM()
            self.update_tilde_sigma_c()
            self.update_zeta_c()
            self.obj_vals[i] = self.compute_bures()
            g = self.compute_bures_grad_c()
            
            
            # print("g")
            # print(np.squeeze(g))
            # c_i = self.c
            
            # print(np.squeeze(bures_grad(c_i)))
            #  print(np.squeeze(self.compute_bures_grad_c()))

            # eps = 10 * np.sqrt(np.finfo(float).eps)
            # ga = sp.optimize.approx_fprime(c_i, bures, np.array(eps*np.ones(np.size(g))).T)
            # ga = bures_grad_num(c_i, eps)
            # print("ga")
            # print(ga)
            
            # print(sp.optimize.check_grad(bures, bures_grad, c_i))
            
            s = np.zeros_like(self.c)
            idx = np.argmin(g)
            s[idx] = 1
            gamma = 2 / (2 + i)
            self.c = (1 - gamma) * self.c + gamma * s
            self.grads[i, :] = np.squeeze(g)
            self.weights[i, :] = self.c

            if self.print_iter == True:
                print("iter =  ",i)
                print(self.c)

    def FW_MMD_routine(self):
        
        # commented part must be uncommented for gradient-checking
        
        def mmd(c_i):
            return (np.linalg.norm(self.M@c_i[:,np.newaxis] - self.mux))**2

        def mmd_grad(c_i):
            return np.squeeze(2 * self.M.T @ (self.M @ c_i[:,np.newaxis] - self.mux))
        
        self.initialize_weights()
        for i in range(0, self.max_iter):
            self.update_mubar()
            self.update_tilde_covM()
            self.obj_vals[i] = self.compute_mmd()

            g = self.compute_mmd_grad()

            # print("g")
            c_i = self.c
            # print(np.squeeze(mmd_grad(c_i)))
            # #print(np.squeeze(self.compute_mmd_grad()))

            # eps = np.sqrt(np.finfo(float).eps) * 100
            
            # print("ga")
            # ga = sp.optimize.approx_fprime(c_i, mmd, np.array([eps ,eps, eps]).T)
            # print(ga)

            print(sp.optimize.check_grad(mmd, mmd_grad, c_i))

            s = np.zeros_like(self.c)
            idx = np.argmin(g)
            s[idx] = 1
            gamma = 2 / (2 + i)
            self.c = (1 - gamma) * self.c  + gamma * s 
            self.grads[i,:] = np.squeeze(g)
            self.weights[i,:] = self.c

            if self.print_iter == True:
                print("iter = ",i)
                print(self.c)


def Hellinger(p, q):
  '''
  Function to evaluate Hellinger Distance between distributions categorical distributions p and q
  input: distributions p and q
  outputs: Hillnger distance between p and q

  '''
  assert np.all(p >= 0), "first distribution must be non-negative"
  assert np.all(q >= 0), "second distribution must be non-negative"
  assert p.size == q.size, "input vectors must be of same size"

  p = p.reshape([p.size, 1])
  q = q.reshape([q.size, 1])

  return (1/np.sqrt(2)) * np.linalg.norm(np.sqrt(p) - np.sqrt(q))


In [32]:
import numpy as np
import scipy as sp

M = np.array([[0, -0.5, 0.5],[1, 0, 0]])
# M = np.array([[0, 0, 0], [0, 0, 0]])
#M = np.ones_like(M)
mux = np.array([[np.sqrt(3)/2], [np.sqrt(3)/2] ])
covx = np.array([[10,6],[6,8]])
covM = np.zeros([3,2,2])
covM[0] = np.array([[1, 0.5],[0.5, 1]])
covM[1] = np.array([[5, 0.6], [0.6, 7]])
covM[2] = np.array([[4, 1], [1, 2]])
c0 = np.ones(3)/3
 
iterations = 10

test = GM_FW(mux, M, covx, covM, c0, max_iter = iterations)
test.reset_values()
#test.FW_bures_routine()
test.FW_frechet_routine()



In [34]:
num_ptx = 10
dim = 100
M = np.random.rand(dim, num_ptx)
#M = np.zeros([dim, num_ptx])
# M = np.ones_like(M)
mux = np.random.rand(dim).reshape([dim, 1])
A = np.random.rand(dim, dim)
covx = A @ A.T
covM = np.zeros([num_ptx, dim, dim])
for i in range(0, num_ptx):
    A = np.random.rand(dim, dim)
    covM[i] = A @ A.T

iterations = 10

c0 = np.ones(num_ptx)/(num_ptx)
test = GM_FW(mux, M, covx, covM, c0, max_iter=iterations)
test.reset_values()
test.FW_MMD_routine()


8.908369880698093e-07
1.4338831199137095e-06
2.1190743935323343e-06
1.2017297546573584e-06
1.758248343906939e-06
1.5532142638477044e-06
1.984057302369691e-06
9.763875361878837e-07
1.429359211585103e-06
1.4957066229305563e-06


In [5]:
test.c

array([0.14545455, 0.10909091, 0.09090909, 0.07272727, 0.05454545,
       0.18181818, 0.12727273, 0.18181818, 0.03636364, 0.        ])

In [37]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import keras

from sklearn.decomposition import PCA

# Data Processing before PCA 
# N_samples indicates the number of samples to be drawn from data
N_samples = 1
Hellinger_distances = np.zeros([N_samples, 1])
# sample_sz indicates the number of samples to be drawn from data
sample_sz = 400


tf.keras.datasets.mnist.load_data(path="mnist.npz")
(X_train, y), (X_test, yt) = keras.datasets.mnist.load_data()

X_train = X_train / 255
y = y / 1
X_test = X_test / 255
yt = yt / 1

X = X_train.reshape(X_train.shape[0], X_train.shape[1] * X_train.shape[2])
Xt = X_test.reshape(X_test.shape[0], X_test.shape[1] * X_test.shape[2])


# PCA from scikit-learn

PCA_components = 130
pca = PCA(n_components = PCA_components)
pca.fit(X)
# np.sum(pca.explained_variance_ratio_)
P = pca.components_
X_hat = X@P.T
Xt_hat = Xt@P.T

# Creation and Testing and Training samples

n_classes = np.size(np.unique(y))
means_i = np.zeros([n_classes, PCA_components])
covs_i = np.zeros([n_classes, PCA_components, PCA_components])

Training_Data = {}
Training_Lables = {}
for i in range(0, 10):
  Training_Data[str(i)] = X_hat[y == i]
  Training_Lables[str(i)] = y[y == i]
  means_i[i] = np.mean(Training_Data[str(i)], axis = 0)
  covs_i[i] = np.cov(Training_Data[str(i)].T)


Testing_Data = {}
Testing_Lables = {}
for i in range(0, 10):
  Testing_Data[str(i)] = Xt_hat[yt == i]
  Testing_Lables[str(i)] = yt[yt == i]
  
# Slicing thorugh test set

Hellinger_metric = np.zeros(N_samples)

for sample_iter in range(0, N_samples): 
  sz = np.random.randint(10, 2000, [10])
  s = np.random.dirichlet(sz, 1).T
  ints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
  p = s
  counts_sample = np.floor(sample_sz * p)
  sample = np.empty([0, PCA_components], dtype = float)
  sample_labels = np.empty(0, dtype = float)
  counter = 0
  for k in ints:
    idx = np.random.choice(np.arange(0, Testing_Data[str(k)].shape[0]), size = np.int(counts_sample[counter]), replace = False )
    sample = np.append(sample, Testing_Data[str(k)][idx], axis = 0)
    sample_labels = np.append(sample_labels, Testing_Lables[str(k)][idx], axis = 0)
    counter = counter + 1

  mean_x = np.mean(sample, axis = 0)
  covx = np.cov(sample.T)

  mux = mean_x[np.newaxis,:].T
  M = means_i.T
  covM = covs_i
  


  # Frank_Wolfe is invoked
  c0 = np.ones(n_classes)/n_classes
  iterations = 100
  
  
  MNIST = GM_FW(mux, M, covx, covM, c0, max_iter = iterations, print_iter = False)
  
  MNIST.FW_MMD_routine()
  Hellinger_metric[sample_iter] = Hellinger(MNIST.c, s)
  print(Hellinger_metric[sample_iter])

    

2.2277942513955988e-06
2.654202485201261e-06
2.35024122947716e-06
2.22603766228557e-06
2.3482181066627652e-06
2.3032625194424033e-06
2.2379821952121423e-06
2.273219604305405e-06
2.22814934868718e-06
2.245359971973463e-06
2.2565322210198877e-06
2.1494028608798355e-06
2.2425251480078697e-06
2.2278583126094528e-06
2.2242892452144593e-06
2.2492752374190014e-06
2.2166061526749827e-06
2.234325624245001e-06
2.2278584061832107e-06
2.221347600570517e-06
2.2200777402478455e-06
2.215564219095068e-06
2.2484332521796323e-06
2.2027097204167375e-06
2.2064581717903113e-06
2.2348467570808267e-06
2.230336004154402e-06
2.229052435439908e-06
2.243806536570899e-06
2.2265105359636444e-06
2.2338759972911976e-06
2.2337609002480707e-06
2.2295498818014583e-06
2.2169410671830086e-06
2.2252357822054494e-06
2.2449086863542976e-06
2.206436683080944e-06
2.218239348339371e-06
2.2199426564530545e-06
2.2193198725430436e-06
2.251637211465761e-06
2.238393413504612e-06
2.224641532194242e-06
2.2559877363241316e-06
2.234447

In [25]:
MNIST.c

array([0.16772277, 0.02910891, 0.17544554, 0.11742574, 0.12653465,
       0.11207921, 0.06673267, 0.05920792, 0.        , 0.14574257])

In [26]:
s

array([[0.14657956],
       [0.0202841 ],
       [0.17381758],
       [0.12953466],
       [0.12931951],
       [0.11252194],
       [0.09168747],
       [0.04448029],
       [0.00187804],
       [0.14989685]])

In [9]:
def Hellinger(p,q):
  '''
  Function to evaluate Hellinger Distance between distributions categorical distributions p and q
  input: distributions p and q
  outputs: Hillnger distance between p and q

  '''
  assert np.all(p >= 0), "first distribution must be non-negative" 
  assert np.all(q >= 0), "second distribution must be non-negative"
  assert p.size == q.size, "input vectors must be of same size"

  p = p.reshape([p.size, 1]) 
  q = q.reshape([q.size, 1]) 

  return (1/np.sqrt(2)) * np.linalg.norm( np.sqrt(p) - np.sqrt(q))

In [10]:
1 - Hellinger(MNIST.c, s)

0.9481152573588495

In [29]:
s

array([[0.0210801 ],
       [0.166506  ],
       [0.06654268],
       [0.06074725],
       [0.12808842],
       [0.10591242],
       [0.08285953],
       [0.08196471],
       [0.16542776],
       [0.12087113]])

In [30]:
MNIST.c

array([0.01797203, 0.1681998 , 0.0655025 , 0.05763437, 0.13961039,
       0.091998  , 0.09507493, 0.06078322, 0.15694106, 0.14628372])

In [42]:
np.linalg.norm(MNIST.grads, axis = 1)/2048

array([0.01130345, 0.56102229, 0.07662679, 0.0279048 , 0.0183663 ,
       0.01852785, 0.01588996, 0.01739065, 0.01058555, 0.00706925,
       0.00643785, 0.00685781, 0.00556978, 0.00546221, 0.00457756,
       0.00547457, 0.00430023, 0.00428731, 0.00340144, 0.00578065,
       0.008231  , 0.00732366, 0.00694462, 0.00776018, 0.00646413,
       0.00545903, 0.00555747, 0.00621676, 0.00550672, 0.00534905,
       0.00452354, 0.00396553, 0.00377714, 0.00310301, 0.00293698,
       0.00286797, 0.00344374, 0.00275388, 0.00289212, 0.00299938,
       0.00229725, 0.00234716, 0.00197026, 0.0018147 , 0.00192196,
       0.00212481, 0.00209173, 0.00178811, 0.00156238, 0.00257819,
       0.00193827, 0.00151559, 0.00147174, 0.00138873, 0.00248851,
       0.0040003 , 0.00389536, 0.00356178, 0.00401818, 0.00361892,
       0.0033748 , 0.00336329, 0.00315605, 0.00286371, 0.0027559 ,
       0.00291554, 0.00257194, 0.00271428, 0.00255357, 0.00225775,
       0.00242274, 0.00221548, 0.00193186, 0.00195018, 0.00166