In [1]:
import h5py
import torch
from pykeops.torch import LazyTensor
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.mixture import GaussianMixture
from matplotlib.patches import Ellipse
import numpy as np
from math import pi, cos, sin
from scipy.integrate import quad
from sklearn.manifold import TSNE

In [2]:
f = h5py.File('./MNIST data/train_0.h5', 'r')
data = f['data']
label = f['label']

In [3]:
# parameters
numOfClass = 10
numOfImagesEachClass = 100
numOfGaussian = 10
etaW = 1e-3
etaC = 1e-2
epsilon = 1

In [5]:
# data cleaning: choose images from each of the classes
selectedData, selectedLabel = ImageChoosing(numOfImagesEachClass, numOfClass)

# putting 1000 images together
finalData = np.zeros((256 * numOfImagesEachClass * numOfClass, 4))
finalData = selectedData.reshape((256 * numOfImagesEachClass * numOfClass, 4))

In [6]:
# initializing C
gmm = GaussianMixture(n_components = numOfGaussian)
gmm.fit(finalData)
means = gmm.means_
covariances = gmm.covariances_
C = covariances
C = torch.from_numpy(C)
C.requires_grad_(True)

tensor([[[ 1.0254e-01, -1.0172e-10, -6.5024e-03, -9.5170e-03],
         [-1.0172e-10,  1.0000e-06,  3.9966e-11,  2.3592e-11],
         [-6.5024e-03,  3.9966e-11,  5.7827e-02,  1.5372e-03],
         [-9.5170e-03,  2.3592e-11,  1.5372e-03,  4.9389e-02]],

        [[ 1.6869e-02, -4.5345e-11, -1.8350e-03, -2.2840e-05],
         [-4.5345e-11,  1.0000e-06,  2.0341e-10,  3.2203e-13],
         [-1.8350e-03,  2.0341e-10,  9.0035e-02,  5.9610e-06],
         [-2.2840e-05,  3.2203e-13,  5.9610e-06,  1.6763e-05]],

        [[ 8.0517e-02,  4.9278e-11,  9.9877e-03, -7.6654e-03],
         [ 4.9278e-11,  1.0000e-06,  2.1197e-11, -1.7911e-11],
         [ 9.9877e-03,  2.1197e-11,  6.7267e-02, -3.0416e-03],
         [-7.6654e-03, -1.7911e-11, -3.0416e-03,  7.3761e-03]],

        [[ 9.0149e-02,  2.5793e-11, -4.4032e-03, -4.5932e-05],
         [ 2.5793e-11,  1.0000e-06,  9.1025e-11,  5.3471e-14],
         [-4.4032e-03,  9.1025e-11,  1.1156e-01, -3.7898e-05],
         [-4.5932e-05,  5.3471e-14, -3.7898e-05, 

In [7]:
# initializing W
W = np.random.random_sample((numOfClass * numOfImagesEachClass, numOfGaussian))

# normalizing W
W = torch.from_numpy(W)
temp = NormalizingW(W)
W = temp
W.requires_grad_(True)
W.retain_grad()
print(W)

tensor([[0.1133, 0.0489, 0.0999,  ..., 0.1555, 0.0417, 0.1349],
        [0.1231, 0.1243, 0.0755,  ..., 0.0487, 0.1203, 0.0594],
        [0.0701, 0.1650, 0.0258,  ..., 0.0597, 0.2229, 0.0149],
        ...,
        [0.2290, 0.1129, 0.0765,  ..., 0.0125, 0.0154, 0.1164],
        [0.0570, 0.1471, 0.0754,  ..., 0.1179, 0.1229, 0.1335],
        [0.0704, 0.0821, 0.0712,  ..., 0.1060, 0.1834, 0.0265]],
       dtype=torch.float64, requires_grad=True)


In [None]:
# graph
graphW = np.random.random_sample((numOfClass * numOfImagesEachClass, numOfGaussian))
for i in range(graphW.shape[0]):
    for j in range(graphW.shape[1]):
        graphW[i][j] = W[i][j]
tsne = TSNE(n_components=2, random_state=0)
W_2d = tsne.fit_transform(graphW)
target_ids = range(numOfClass * numOfImagesEachClass)

from matplotlib import pyplot as plt
plt.figure(figsize=(6, 5))
label = selectedLabel.astype(int)
colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'yellow', 'orange', 'purple']
for i in range(numOfClass * numOfImagesEachClass):
    plt.scatter(W_2d[i,0], W_2d[i,1], color = colors[label[i]])
plt.legend()
plt.show()

In [8]:
# initializing flag
flag = 1

# compute X
X = GenerateX()
X = torch.from_numpy(X)

# compute Xhat
Xhat = GenerateXhat()

In [9]:
# compute objective function E
E = GenerateE()
E.requires_grad_(True)

tensor(8856.5335, dtype=torch.float64, grad_fn=<TraceBackward>)

In [10]:
# begin gradient descent
lambda_ = 0
YCNew = C
YW = torch.sqrt(W)
YWNew = W

while True:
    print("Error of iteration", flag, "is:", E.item())
    EOld = E
    lambdaNew_ = (1 + np.sqrt(1 + 4 * np.power(lambda_, 2))) / 2
    gama = (1 - lambda_) / lambdaNew_
    lambda_ = lambdaNew_
    WGrad = torch.autograd.grad(E, W, create_graph = True, allow_unused = True)[0]
    CGrad = torch.autograd.grad(E, C, create_graph = True, allow_unused = True)[0]
    CGrad = NormalizingCGrad(CGrad)
    WGrad = NormalizingWGrad(WGrad)
    YWNew = YW - torch.mul(WGrad, etaW)
    YWNew = (1 - gama) * YWNew + gama * YW
    YW = YWNew
    W = torch.pow(YWNew, 2)
    NormalizingW(W)
    for j in range(numOfGaussian):
        YCNew[j] = ExponentialMap(C[j], - etaC * CGrad[j])
        C[j] = ExponentialMap(YCNew[j], gama * LogMap(YCNew[j], C[j]))
    Xhat = GenerateXhat()
    E = GenerateE()
    if torch.abs(E - EOld) < epsilon:
        print("The Final Error is:", E.item())
        break
    flag += 1

Error of iteration 1 is: 8856.533521764843
Error of iteration 2 is: 7728.3553831674035
Error of iteration 3 is: 7759.703750067056
Error of iteration 4 is: 7350.892297558691


KeyboardInterrupt: 

In [None]:
# graph tsne
graphW = np.random.random_sample((numOfClass * numOfImagesEachClass, numOfGaussian))
for i in range(graphW.shape[0]):
    for j in range(graphW.shape[1]):
        graphW[i][j] = W[i][j]
tsne = TSNE(n_components=2, random_state=0)
W_2d = tsne.fit_transform(graphW)
target_ids = range(numOfClass * numOfImagesEachClass)

plt.figure(figsize=(6, 5))
label = selectedLabel.astype(int)
colors = ['r', 'g', 'b', 'c', 'm', 'y', 'k', 'yellow', 'orange', 'purple']
for i in range(numOfClass * numOfImagesEachClass):
    plt.scatter(W_2d[i,0], W_2d[i,1], color = colors[label[i]])
plt.legend()
plt.show()

In [4]:
# helper functions
def ImageChoosing(numEach, numOfClass):
    selectedData = np.zeros((numEach * numOfClass, 256, 4))
    selectedLabel = np.zeros(numEach * numOfClass)
    count = np.zeros(numOfClass, dtype = int)
    for i in range(label.size):
        thisData = data[i]
        thisLabel = label[i]
        thisCount = count[thisLabel]
        if thisCount < 100:
            selectedData[thisLabel * 100 + thisCount] = data[i]
            selectedLabel[thisLabel * 100 + thisCount] = label[i]
            count[thisLabel] += 1
        else:
            exit = True
            for j in range(10):
                if count[j] < 100:
                    exit = False
            if exit:
                break
    return selectedData, selectedLabel

def NormalizingW(W):
    for i in range(W.shape[0]):
        sumOfW = torch.sum(W[i])
        for j in range(W.shape[1]):
            W[i][j] = W[i][j] / sumOfW
    return W

def GenerateX():
    result = np.zeros((numOfClass * numOfImagesEachClass, 4, 4))
    for i in range(numOfClass * numOfImagesEachClass):
        thisData = selectedData[i]
        gmm = GaussianMixture(n_components = 1)
        gmm.fit(thisData)
        means = gmm.means_
        covariances = gmm.covariances_
        result[i] = covariances
    return result

def GenerateXhat():
    result = np.zeros((numOfClass * numOfImagesEachClass, 4, 4))
    result = torch.from_numpy(result)
    for i in range(numOfClass * numOfImagesEachClass):
        A = 0
        B = 0
        for j in range(numOfGaussian):
            A += W[i][j] * C[j]
            B += W[i][j] * torch.inverse(C[j])
        firstTerm = sqrtMatrix(torch.inverse(B))
        secondTerm = sqrtMatrix(torch.mm(torch.mm(sqrtMatrix(B),A),sqrtMatrix(B)))
        thirdTerm = firstTerm
        result[i] = torch.mm(torch.mm(firstTerm, secondTerm), thirdTerm)
    return result

def sqrtMatrix(matrix):
    u, s, v = torch.svd(matrix)
    newS = torch.sqrt(torch.diag(s))
    return torch.mm(torch.mm(u, newS), v.t())
    
def GenerateE():
    E = 0
    for i in range(numOfClass * numOfImagesEachClass):
        E += (1 / 4) * (torch.mm(torch.inverse(X[i]), Xhat[i]) + torch.mm(torch.inverse(Xhat[i]), X[i]) - 8)
    return torch.trace(E)

def ExponentialMap(x, V):
    u, s, v = torch.svd(V)
    temp = torch.exp(s)
    newS = torch.diag(temp)
    midTerm = torch.mm(torch.mm(u, newS), u.t())
    return torch.mm(torch.mm(sqrtMatrix(x), midTerm), sqrtMatrix(x))

def LogMap(x, Y):
    u, s, v = torch.svd(Y)
    newS = torch.diag(torch.log(s))
    midTerm = torch.mm(torch.mm(u, newS), u.t())
    return torch.mm(torch.mm(sqrtMatrix(x), midTerm), sqrtMatrix(x))
    
def NormalizingWGrad(grad):
    temp = grad
    for i in range(grad.shape[0]):
        sumOfGrad = torch.sum(grad[i])
        for j in range(grad.shape[1]):
            temp[i][j] = grad[i][j] / sumOfGrad
    return temp

def NormalizingCGrad(grad):
    temp = grad
    for i in range(grad.shape[0]):
        sumOfGrad = torch.sum(grad[i])
        for j in range(grad.shape[1]):
            for k in range(grad.shape[2]):
                temp[i][j][k] = grad[i][j][k] / sumOfGrad
    return temp

def CheckW(W):
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            if W[i][j] < 0:
                print("W has negative components on", i)

def CheckC(C):
    for i in range(C.shape[0]):
        for j in range(C.shape[1]):
            if C[i][j][j] < 0:
                print("C has negative components on the diag on the", j, "row of ", i, "element")

def CheckYCNew(C):
    for i in range(C.shape[0]):
        for j in range(C.shape[1]):
            if C[i][j][j] < 0:
                print("YCNew has negative components on the diag on the", j, "row of ", i, "element")

In [None]:
print(W)