In [1]:
import h5py
import torch
from pykeops.torch import LazyTensor
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.mixture import GaussianMixture
from matplotlib.patches import Ellipse
import numpy as np
from math import pi, cos, sin
from scipy.integrate import quad

In [2]:
f = h5py.File('./MNIST data/train_0.h5', 'r')
data = f['data']
label = f['label']

In [3]:
# parameters
numOfClass = 10
numOfImagesEachClass = 100
numOfGaussian = 10
eta = 0.9
epsilon = 0.1

In [5]:
# data cleaning: choose images from each of the classes
selectedData, selectedLabel = ImageChoosing(numOfImagesEachClass, numOfClass)

# putting 1000 images together
finalData = np.zeros((256 * numOfImagesEachClass * numOfClass, 4))
finalData = selectedData.reshape((256 * numOfImagesEachClass * numOfClass, 4))

In [6]:
# initializing C
gmm = GaussianMixture(n_components = numOfGaussian)
gmm.fit(finalData)
means = gmm.means_
covariances = gmm.covariances_
C = covariances
C = torch.from_numpy(C)
C.requires_grad_(True)

tensor([[[ 4.3534e-02,  1.0462e-12, -1.2681e-02, -8.3743e-04],
         [ 1.0462e-12,  1.0000e-06,  1.0893e-10, -6.5424e-13],
         [-1.2681e-02,  1.0893e-10,  1.1686e-01,  2.4895e-05],
         [-8.3743e-04, -6.5424e-13,  2.4895e-05,  1.7499e-03]],

        [[ 4.2709e-02, -1.1817e-11, -1.6273e-02, -5.4185e-03],
         [-1.1817e-11,  1.0000e-06,  1.3627e-11,  2.8190e-12],
         [-1.6273e-02,  1.3627e-11,  1.3136e-01,  4.1727e-03],
         [-5.4185e-03,  2.8190e-12,  4.1727e-03,  4.8257e-02]],

        [[ 1.5700e-01,  6.9776e-12, -1.5606e-02, -4.3025e-03],
         [ 6.9776e-12,  1.0000e-06,  4.0864e-11, -5.6317e-12],
         [-1.5606e-02,  4.0864e-11,  1.1202e-01, -1.9629e-03],
         [-4.3025e-03, -5.6317e-12, -1.9629e-03,  5.9523e-03]],

        [[ 5.1556e-02, -8.2271e-12,  2.9351e-03,  3.8695e-05],
         [-8.2271e-12,  1.0000e-06,  3.5085e-11, -2.4916e-13],
         [ 2.9351e-03,  3.5085e-11,  4.8580e-02,  6.0292e-06],
         [ 3.8695e-05, -2.4916e-13,  6.0292e-06, 

In [7]:
# initializing W
W = np.random.random_sample((numOfClass * numOfImagesEachClass, numOfGaussian))

# normalizing W
W = NormalizingW(W)

W = torch.from_numpy(W)
W.requires_grad_(True)

tensor([[0.1324, 0.1768, 0.1790,  ..., 0.0979, 0.1782, 0.1330],
        [0.0408, 0.1422, 0.1466,  ..., 0.1614, 0.1590, 0.0832],
        [0.1412, 0.1058, 0.0808,  ..., 0.2160, 0.0536, 0.0572],
        ...,
        [0.1195, 0.0135, 0.1502,  ..., 0.1002, 0.0527, 0.2044],
        [0.1815, 0.1344, 0.0175,  ..., 0.1488, 0.0451, 0.1357],
        [0.0748, 0.1407, 0.1398,  ..., 0.1685, 0.0366, 0.1614]],
       dtype=torch.float64, requires_grad=True)

In [8]:
# initializing flag
flag = 1

# compute X
X = GenerateX()
X = torch.from_numpy(X)

# compute Xhat
Xhat = GenerateXhat()

In [9]:
# compute objective function E
E = GenerateE()
E.requires_grad_(True)

tensor(11254.5600, dtype=torch.float64, grad_fn=<TraceBackward>)

In [None]:
torch.autograd.grad(E, C, create_graph = True, allow_unused = True)

In [10]:
# begin gradient descent
lambda_ = 1
YW = W
YC = C

while True:
    print("Error of iteration", flag, "is:", E)
    EOld = E
    lambdaNew_ = (1 + np.sqrt(1 + 4 * lambda_^2)) / 2
    gama = (1 - lambdaNew_) / lambda_
    E.backward(retain_graph = True, create_graph = True)
    YWNew = W - eta * W.grad
    W = (1 - gama) * YWNew + gama * YW
    YW = YWNew
    YCNew = YC
    for j in range(numOfGaussian):
        YCNew[j] = ExponentialMap(C[j], - eta * C.grad[j])
        C[j] = ExponentialMap(YCNew[j], gama * LogMap(YCNew[j], YC[j]))
    YC = YCNew
    E = GenerateE()
    if np.abs(E - EOld) < eta:
        break
    flag += 1

Error of iteration 1 is: tensor(11254.5600, dtype=torch.float64, grad_fn=<TraceBackward>)


RuntimeError: svd_cpu: the updating process of SBDSDC did not converge (error: 3)

In [4]:
# helper functions
def ImageChoosing(numEach, numOfClass):
    selectedData = np.zeros((numEach * numOfClass, 256, 4))
    selectedLabel = np.zeros(numEach * numOfClass)
    count = np.zeros(numOfClass, dtype = int)
    for i in range(label.size):
        thisData = data[i]
        thisLabel = label[i]
        thisCount = count[thisLabel]
        if thisCount < 100:
            selectedData[thisLabel * 100 + thisCount] = data[i]
            selectedLabel[thisLabel * 100 + thisCount] = label[i]
            count[thisLabel] += 1
        else:
            exit = True
            for j in range(10):
                if count[j] < 100:
                    exit = False
            if exit:
                break
    return selectedData, selectedLabel

def NormalizingW(W):
    for i in range(W.shape[0]):
        sumOfW = np.sum(W[i])
        for j in range(W.shape[1]):
            W[i][j] = W[i][j] / sumOfW
    return W

def GenerateX():
    result = np.zeros((numOfClass * numOfImagesEachClass, 4, 4))
    for i in range(numOfClass * numOfImagesEachClass):
        thisData = selectedData[i]
        gmm = GaussianMixture(n_components = 1)
        gmm.fit(thisData)
        means = gmm.means_
        covariances = gmm.covariances_
        result[i] = covariances
    return result

def GenerateXhat():
    result = np.zeros((numOfClass * numOfImagesEachClass, 4, 4))
    result = torch.from_numpy(result)
    for i in range(numOfClass * numOfImagesEachClass):
        A = 0
        B = 0
        for j in range(numOfGaussian):
            A += W[i][j] * C[j]
            B += W[i][j] * torch.inverse(C[j])
        firstTerm = sqrtMatrix(torch.inverse(B))
        secondTerm = sqrtMatrix(torch.mm(torch.mm(sqrtMatrix(B),A),sqrtMatrix(B)))
        thirdTerm = firstTerm
        result[i] = torch.mm(torch.mm(firstTerm, secondTerm), thirdTerm)
    return result

def sqrtMatrix(matrix):
    u, s, vt = torch.svd(matrix)
    newS = torch.sqrt(torch.diag(s))
    return torch.mm(torch.mm(u, newS), vt)
    
def GenerateE():
    E = 0
    for i in range(numOfClass * numOfImagesEachClass):
        E += (1 / 4) * (torch.mm(torch.inverse(X[i]), Xhat[i]) + torch.mm(torch.inverse(Xhat[i]), X[i]) - 8)
    return torch.trace(E)

def ExponentialMap(x, V):
    return sqrtMatrix(x) * torch.exp(torch.mm(torch.mm(torch.inverse(sqrtMatrix(x)), V), torch.inverse(sqrtMatrix(x)), )) * sqrtMatrix(x)
    
def LogMap(x, Y):
    return sqrtMatrix(x) * torch.exp(torch.mm(torch.mm(torch.inverse(sqrtMatrix(x)), V), torch.inverse(sqrtMatrix(x)), )) * sqrtMatrix(x)
    
    
    
    