In [1]:
import h5py
import torch
from pykeops.torch import LazyTensor
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.mixture import GaussianMixture
from matplotlib.patches import Ellipse
import numpy as np
from math import pi, cos, sin
from scipy.integrate import quad

In [2]:
f = h5py.File('./MNIST data/train_0.h5', 'r')
data = f['data']
label = f['label']

In [3]:
# parameters
numOfClass = 10
numOfImagesEachClass = 100
numOfGaussian = 10
eta = 1e-3
epsilon = 0.1

In [5]:
# data cleaning: choose images from each of the classes
selectedData, selectedLabel = ImageChoosing(numOfImagesEachClass, numOfClass)

# putting 1000 images together
finalData = np.zeros((256 * numOfImagesEachClass * numOfClass, 4))
finalData = selectedData.reshape((256 * numOfImagesEachClass * numOfClass, 4))

In [102]:
# initializing C
gmm = GaussianMixture(n_components = numOfGaussian)
gmm.fit(finalData)
means = gmm.means_
covariances = gmm.covariances_
C = covariances
C = torch.from_numpy(C)
C.requires_grad_(True)

tensor([[[ 1.3881e-02,  1.9484e-11, -2.3195e-03, -7.7641e-05],
         [ 1.9484e-11,  1.0000e-06,  4.7144e-11,  4.5005e-13],
         [-2.3195e-03,  4.7144e-11,  1.0032e-01,  1.2114e-04],
         [-7.7641e-05,  4.5005e-13,  1.2114e-04,  1.6436e-05]],

        [[ 9.2455e-03, -3.8658e-11, -3.4908e-03, -2.0831e-05],
         [-3.8658e-11,  1.0000e-06,  2.0765e-10,  3.3862e-13],
         [-3.4908e-03,  2.0765e-10,  8.2228e-02,  2.7266e-05],
         [-2.0831e-05,  3.3862e-13,  2.7266e-05,  1.7265e-05]],

        [[ 1.0660e-01,  2.5638e-11, -1.7477e-02, -2.8510e-05],
         [ 2.5638e-11,  1.0000e-06,  1.0211e-10,  9.0939e-14],
         [-1.7477e-02,  1.0211e-10,  1.1466e-01, -4.2041e-05],
         [-2.8510e-05,  9.0939e-14, -4.2041e-05,  1.4826e-05]],

        [[ 6.1380e-02, -5.9047e-12, -1.3720e-02, -1.9236e-03],
         [-5.9047e-12,  1.0000e-06,  2.3922e-11,  5.7013e-12],
         [-1.3720e-02,  2.3922e-11,  1.3274e-01,  2.3937e-03],
         [-1.9236e-03,  5.7013e-12,  2.3937e-03, 

In [103]:
# initializing W
W = np.random.random_sample((numOfClass * numOfImagesEachClass, numOfGaussian))

# normalizing W
W = NormalizingW(W)

W = torch.from_numpy(W)
W.requires_grad_(True)
W.retain_grad()
print(W)

tensor([[0.1159, 0.1631, 0.1301,  ..., 0.1670, 0.0777, 0.1910],
        [0.1778, 0.0914, 0.2158,  ..., 0.0143, 0.0927, 0.0715],
        [0.1282, 0.0493, 0.0693,  ..., 0.1065, 0.2155, 0.1417],
        ...,
        [0.1282, 0.0562, 0.0065,  ..., 0.1309, 0.1427, 0.1302],
        [0.1513, 0.1658, 0.1075,  ..., 0.1183, 0.1097, 0.0423],
        [0.1227, 0.1536, 0.0177,  ..., 0.0317, 0.1606, 0.1639]],
       dtype=torch.float64, requires_grad=True)


In [104]:
# initializing flag
flag = 1

# compute X
X = GenerateX()
X = torch.from_numpy(X)

# compute Xhat
Xhat = GenerateXhat()

In [105]:
# compute objective function E
E = GenerateE()
E.requires_grad_(True)

tensor(11827.1808, dtype=torch.float64, grad_fn=<TraceBackward>)

In [106]:
# begin gradient descent
lambda_ = 0
YCNew = C

while True:
    print("Error of iteration", flag, "is:", E.item())
    EOld = E
    lambdaNew_ = (1 + np.sqrt(1 + 4 * np.power(lambda_, 2))) / 2
    gama = (1 - lambda_) / lambdaNew_
    lambda_ = lambdaNew_
    WGrad = torch.autograd.grad(E, W, create_graph = True, allow_unused = True)[0]
    CGrad = torch.autograd.grad(E, C, create_graph = True, allow_unused = True)[0]
    CGrad = NormalizingGrad(CGrad)
    WGrad = NormalizingGrad(WGrad)
    YWNew = W - torch.mul(WGrad, eta)
    W = (1 - gama) * YWNew + gama * W
    CheckW(W)
    CheckC(C)
    for j in range(numOfGaussian):
        YCNew[j] = ExponentialMap(C[j], - eta * CGrad[j])
        C[j] = ExponentialMap(YCNew[j], gama * LogMap(YCNew[j], YC[j]))
    CheckYCNew(YCNew)
    CheckC(C)
    Xhat = GenerateXhat()
    E = GenerateE()
    if torch.abs(E - EOld) < epsilon:
        print("The Final Error is:", E)
        break
    flag += 1

Error of iteration 1 is: 11827.180830844158
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
YCNew has negative components on the diag
C has negative components on the diag
C has negative components on the diag
C has negative components on the diag
C 

KeyboardInterrupt: 

In [95]:
# helper functions
def ImageChoosing(numEach, numOfClass):
    selectedData = np.zeros((numEach * numOfClass, 256, 4))
    selectedLabel = np.zeros(numEach * numOfClass)
    count = np.zeros(numOfClass, dtype = int)
    for i in range(label.size):
        thisData = data[i]
        thisLabel = label[i]
        thisCount = count[thisLabel]
        if thisCount < 100:
            selectedData[thisLabel * 100 + thisCount] = data[i]
            selectedLabel[thisLabel * 100 + thisCount] = label[i]
            count[thisLabel] += 1
        else:
            exit = True
            for j in range(10):
                if count[j] < 100:
                    exit = False
            if exit:
                break
    return selectedData, selectedLabel

def NormalizingW(W):
    for i in range(W.shape[0]):
        sumOfW = np.sum(W[i])
        for j in range(W.shape[1]):
            W[i][j] = W[i][j] / sumOfW
    return W

def GenerateX():
    result = np.zeros((numOfClass * numOfImagesEachClass, 4, 4))
    for i in range(numOfClass * numOfImagesEachClass):
        thisData = selectedData[i]
        gmm = GaussianMixture(n_components = 1)
        gmm.fit(thisData)
        means = gmm.means_
        covariances = gmm.covariances_
        result[i] = covariances
    return result

def GenerateXhat():
    result = np.zeros((numOfClass * numOfImagesEachClass, 4, 4))
    result = torch.from_numpy(result)
    for i in range(numOfClass * numOfImagesEachClass):
        A = 0
        B = 0
        for j in range(numOfGaussian):
            A += W[i][j] * C[j]
            B += W[i][j] * torch.inverse(C[j])
        firstTerm = sqrtMatrix(torch.inverse(B))
        secondTerm = sqrtMatrix(torch.mm(torch.mm(sqrtMatrix(B),A),sqrtMatrix(B)))
        thirdTerm = firstTerm
        result[i] = torch.mm(torch.mm(firstTerm, secondTerm), thirdTerm)
    return result

def sqrtMatrix(matrix):
    u, s, v = torch.svd(matrix)
    newS = torch.sqrt(torch.diag(s))
    return torch.mm(torch.mm(u, newS), v.t())
    
def GenerateE():
    E = 0
    for i in range(numOfClass * numOfImagesEachClass):
        E += (1 / 4) * (torch.mm(torch.inverse(X[i]), Xhat[i]) + torch.mm(torch.inverse(Xhat[i]), X[i]) - 8)
    return torch.trace(E)

def ExponentialMap(x, V):
    svdTerm = torch.mm(torch.mm(sqrtMatrix(torch.inverse(x)), V), sqrtMatrix(torch.inverse(x)))
    u, s, v = torch.svd(svdTerm)
    temp = torch.exp(s)
    newS = torch.diag(temp)
    midTerm = torch.mm(torch.mm(u, newS), v.t())
    return torch.mm(torch.mm(sqrtMatrix(x), midTerm), sqrtMatrix(x))

def LogMap(x, Y):
    u, s, v = torch.svd(Y)
    newS = torch.diag(torch.log(s))
    midTerm = torch.mm(torch.mm(u, newS), v.t())
    return torch.mm(torch.mm(sqrtMatrix(x), midTerm), sqrtMatrix(x))
    
def NormalizingGrad(grad):
    for i in range(grad.shape[0]):
        sumOfGrad = torch.sum(grad[i])
        for j in range(grad.shape[1]):
            grad[i][j] = grad[i][j] / sumOfGrad
    return grad

def CheckW(W):
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            if W[i][j] < 0:
                print("W has negative components")

def CheckC(C):
    for i in range(C.shape[0]):
        for j in range(C.shape[1]):
            if C[i][j][j] < 0:
                print("C has negative components on the diag")

def CheckYCNew(C):
    for i in range(C.shape[0]):
        for j in range(C.shape[1]):
            if C[i][j][j] < 0:
                print("YCNew has negative components on the diag")

In [101]:
print(CGrad)

tensor([[[ 6.3311e-02,  2.7917e-06, -1.0903e-02, -2.9411e-03],
         [ 2.7917e-06, -9.3894e-09, -5.2331e-07,  3.7663e-07],
         [-1.0903e-02, -5.2331e-07, -1.5954e-03,  1.0121e-03],
         [-2.9411e-03,  3.7663e-07,  1.0121e-03,  9.6394e-01]],

        [[ 1.6240e-02,  9.4092e-08, -2.0637e-03,  1.5375e-03],
         [ 9.4092e-08, -1.2325e-08, -1.7288e-07,  4.0265e-08],
         [-2.0637e-03, -1.7288e-07, -2.2618e-03,  3.2353e-04],
         [ 1.5375e-03,  4.0267e-08,  3.2353e-04,  9.8643e-01]],

        [[ 1.9755e-03, -1.0003e-08, -4.0294e-04,  5.4565e-04],
         [-1.0003e-08, -2.0633e-09,  3.5594e-08, -7.0336e-07],
         [-4.0294e-04,  3.5594e-08, -4.9761e-04, -2.1850e-03],
         [ 5.4565e-04, -7.0336e-07, -2.1850e-03,  1.0026e+00]],

        [[ 1.0543e-02,  5.9585e-08, -2.3899e-03,  3.1896e-03],
         [ 5.9585e-08, -1.0772e-08,  2.7011e-07,  5.6891e-07],
         [-2.3899e-03,  2.7011e-07, -4.0943e-03, -5.2816e-04],
         [ 3.1896e-03,  5.6891e-07, -5.2816e-04, 